package com.linecorp.armeria.server.thrift;

import com.google.common.base.Throwables;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.google.common.net.MediaType;
import com.linecorp.armeria.common.RequestContext;
import com.linecorp.armeria.common.SerializationFormat;
import com.linecorp.armeria.common.http.AggregatedHttpMessage;
import com.linecorp.armeria.common.http.HttpData;
import com.linecorp.armeria.common.http.HttpHeaderNames;
import com.linecorp.armeria.common.http.HttpHeaders;
import com.linecorp.armeria.common.http.HttpRequest;
import com.linecorp.armeria.common.http.HttpResponseWriter;
import com.linecorp.armeria.common.http.HttpStatus;
import com.linecorp.armeria.common.logging.RequestLog;
import com.linecorp.armeria.common.logging.ResponseLog;
import com.linecorp.armeria.common.thrift.ApacheThriftCall;
import com.linecorp.armeria.common.thrift.ApacheThriftReply;
import com.linecorp.armeria.common.thrift.ThriftCall;
import com.linecorp.armeria.common.thrift.ThriftProtocolFactories;
import com.linecorp.armeria.common.thrift.ThriftReply;
import com.linecorp.armeria.common.util.CompletionActions;
import com.linecorp.armeria.common.util.Functions;
import com.linecorp.armeria.internal.thrift.ThriftFunction;
import com.linecorp.armeria.server.Service;
import com.linecorp.armeria.server.ServiceRequestContext;
import com.linecorp.armeria.server.http.AbstractHttpService;
import java.util.Arrays;
import java.util.EnumSet;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.thrift.TApplicationException;
import org.apache.thrift.TBase;
import org.apache.thrift.TException;
import org.apache.thrift.TFieldIdEnum;
import org.apache.thrift.protocol.TMessage;
import org.apache.thrift.protocol.TMessageType;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.protocol.TProtocolFactory;
import org.apache.thrift.transport.TMemoryBuffer;
import org.apache.thrift.transport.TMemoryInputTransport;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * An {@link AbstractHttpService} that serves Thrift calls over HTTP.
 *
 * <p>Each POST request body is decoded as a single Thrift message. The message is turned into a
 * {@link ThriftCall} and passed to the decorated {@link Service}. The resulting {@link ThriftReply} is
 * encoded with the same Thrift protocol and written back as a {@code 200 OK} response. This applies to
 * Thrift exception replies as well.
 *
 * <p>The protocol of a request is taken from its {@code Content-Type} header. If the header is absent,
 * or its media type does not map to a known {@link SerializationFormat}, the
 * {@linkplain #defaultSerializationFormat() default format} is used. A request is rejected when:
 * <ul>
 *   <li>its {@code Content-Type} resolves to a format outside {@link #allowedSerializationFormats()},
 *       with {@code 415 Unsupported Media Type};</li>
 *   <li>its {@code Accept} header maps to a different format than the one in use, with
 *       {@code 406 Not Acceptable}.</li>
 * </ul>
 */
public class THttpService extends AbstractHttpService {

    private static final Logger logger = LoggerFactory.getLogger(THttpService.class);

    private static final String THRIFT_PROTOCOL_NOT_SUPPORTED = "Specified Thrift protocol not supported";
    private static final String ACCEPT_THRIFT_PROTOCOL_MUST_MATCH_CONTENT_TYPE =
            "Thrift protocol specified in Accept header must match the one specified in Content-Type header";

    /**
     * One reusable input protocol per thread for every Thrift {@link SerializationFormat}. Decoding a
     * request therefore resets an existing protocol instead of allocating a new one.
     */
    private static final Map<SerializationFormat, ThreadLocalTProtocol> FORMAT_TO_THREAD_LOCAL_INPUT_PROTOCOL =
            createFormatToThreadLocalTProtocolMap();

    private final Service<ThriftCall, ThriftReply> delegate;
    private final SerializationFormat defaultSerializationFormat;
    private final Set<SerializationFormat> allowedSerializationFormats;
    /** The {@link ThriftCallService} found inside {@link #delegate}; used to look up service entries. */
    private final ThriftCallService thriftService;

    /**
     * A {@link ThreadLocal} whose per-thread value is a protocol, created by the given factory, over an
     * initially empty {@link TMemoryInputTransport}.
     */
    private static final class ThreadLocalTProtocol extends ThreadLocal<TProtocol> {
        private final TProtocolFactory protoFactory;

        private ThreadLocalTProtocol(TProtocolFactory protoFactory) {
            this.protoFactory = protoFactory;
        }

        @Override
        protected TProtocol initialValue() {
            return protoFactory.getProtocol(new TMemoryInputTransport());
        }
    }

    /**
     * Creates a service for the given Thrift implementation. It uses
     * {@link SerializationFormat#THRIFT_BINARY} by default and accepts every Thrift format.
     *
     * @param implementation the Thrift service implementation, passed to {@code ThriftCallService.of}
     */
    public static THttpService of(Object implementation) {
        return of(implementation, SerializationFormat.THRIFT_BINARY);
    }

    /**
     * Creates a service for several Thrift implementations. It uses
     * {@link SerializationFormat#THRIFT_BINARY} by default and accepts every Thrift format.
     *
     * @param implementations the implementations, passed to {@code ThriftCallService.of}; presumably keyed
     *                        by the service name used as the {@code serviceName:} method-name prefix
     */
    public static THttpService of(Map<String, ?> implementations) {
        return of(implementations, SerializationFormat.THRIFT_BINARY);
    }

    /**
     * Creates a service for the given Thrift implementation that accepts every Thrift format.
     *
     * @param defaultSerializationFormat the format used when a request does not specify one
     */
    public static THttpService of(Object implementation, SerializationFormat defaultSerializationFormat) {
        return new THttpService(ThriftCallService.of(implementation), defaultSerializationFormat,
                                SerializationFormat.ofThrift());
    }

    /**
     * Creates a service for several Thrift implementations that accepts every Thrift format.
     *
     * @param defaultSerializationFormat the format used when a request does not specify one
     */
    public static THttpService of(Map<String, ?> implementations,
                                  SerializationFormat defaultSerializationFormat) {
        return new THttpService(ThriftCallService.of(implementations), defaultSerializationFormat,
                                SerializationFormat.ofThrift());
    }

    /**
     * Creates a service that accepts only {@code defaultSerializationFormat} and
     * {@code otherAllowedSerializationFormats}.
     */
    public static THttpService ofFormats(Object implementation, SerializationFormat defaultSerializationFormat,
                                         SerializationFormat... otherAllowedSerializationFormats) {
        Objects.requireNonNull(otherAllowedSerializationFormats, "otherAllowedSerializationFormats");
        return ofFormats(implementation, defaultSerializationFormat,
                         Arrays.asList(otherAllowedSerializationFormats));
    }

    /**
     * Creates a service that accepts only {@code defaultSerializationFormat} and
     * {@code otherAllowedSerializationFormats}.
     */
    public static THttpService ofFormats(Map<String, ?> implementations,
                                         SerializationFormat defaultSerializationFormat,
                                         SerializationFormat... otherAllowedSerializationFormats) {
        Objects.requireNonNull(otherAllowedSerializationFormats, "otherAllowedSerializationFormats");
        return ofFormats(implementations, defaultSerializationFormat,
                         Arrays.asList(otherAllowedSerializationFormats));
    }

    /**
     * Creates a service that accepts only {@code defaultSerializationFormat} and
     * {@code otherAllowedSerializationFormats}.
     */
    public static THttpService ofFormats(Object implementation, SerializationFormat defaultSerializationFormat,
                                         Iterable<SerializationFormat> otherAllowedSerializationFormats) {
        final Set<SerializationFormat> allowed =
                newAllowedSerializationFormats(defaultSerializationFormat, otherAllowedSerializationFormats);
        return new THttpService(ThriftCallService.of(implementation), defaultSerializationFormat, allowed);
    }

    /**
     * Creates a service that accepts only {@code defaultSerializationFormat} and
     * {@code otherAllowedSerializationFormats}.
     */
    public static THttpService ofFormats(Map<String, ?> implementations,
                                         SerializationFormat defaultSerializationFormat,
                                         Iterable<SerializationFormat> otherAllowedSerializationFormats) {
        final Set<SerializationFormat> allowed =
                newAllowedSerializationFormats(defaultSerializationFormat, otherAllowedSerializationFormats);
        return new THttpService(ThriftCallService.of(implementations), defaultSerializationFormat, allowed);
    }

    /**
     * Returns a decorator that wraps a {@link ThriftCallService}-backed service. The wrapped service uses
     * {@link SerializationFormat#THRIFT_BINARY} by default and accepts every Thrift format.
     */
    public static Function<Service<ThriftCall, ThriftReply>, THttpService> newDecorator() {
        return newDecorator(SerializationFormat.THRIFT_BINARY);
    }

    /**
     * Returns a decorator that wraps a {@link ThriftCallService}-backed service. The wrapped service
     * accepts every Thrift format.
     */
    public static Function<Service<ThriftCall, ThriftReply>, THttpService> newDecorator(
            SerializationFormat defaultSerializationFormat) {
        return service -> new THttpService(service, defaultSerializationFormat, SerializationFormat.ofThrift());
    }

    /**
     * Returns a decorator whose services accept only {@code defaultSerializationFormat} and
     * {@code otherAllowedSerializationFormats}.
     */
    public static Function<Service<ThriftCall, ThriftReply>, THttpService> newDecorator(
            SerializationFormat defaultSerializationFormat,
            SerializationFormat... otherAllowedSerializationFormats) {
        Objects.requireNonNull(otherAllowedSerializationFormats, "otherAllowedSerializationFormats");
        return newDecorator(defaultSerializationFormat, Arrays.asList(otherAllowedSerializationFormats));
    }

    /**
     * Returns a decorator whose services accept only {@code defaultSerializationFormat} and
     * {@code otherAllowedSerializationFormats}. The allowed set is computed once, when this method is
     * called.
     */
    public static Function<Service<ThriftCall, ThriftReply>, THttpService> newDecorator(
            SerializationFormat defaultSerializationFormat,
            Iterable<SerializationFormat> otherAllowedSerializationFormats) {
        final Set<SerializationFormat> allowed =
                newAllowedSerializationFormats(defaultSerializationFormat, otherAllowedSerializationFormats);
        return service -> new THttpService(service, defaultSerializationFormat, allowed);
    }

    /**
     * Returns a new mutable set holding {@code defaultSerializationFormat} and every element of
     * {@code otherAllowedSerializationFormats}.
     */
    private static Set<SerializationFormat> newAllowedSerializationFormats(
            SerializationFormat defaultSerializationFormat,
            Iterable<SerializationFormat> otherAllowedSerializationFormats) {
        Objects.requireNonNull(otherAllowedSerializationFormats, "otherAllowedSerializationFormats");
        Objects.requireNonNull(defaultSerializationFormat, "defaultSerializationFormat");
        final Set<SerializationFormat> allowed = EnumSet.of(defaultSerializationFormat);
        otherAllowedSerializationFormats.forEach(allowed::add);
        return allowed;
    }

    /**
     * Creates a new instance. Prefer the static factory methods.
     *
     * @param delegate the service that handles decoded calls; it must wrap a {@link ThriftCallService}
     * @param defaultSerializationFormat the format used when a request does not specify one
     * @param allowedSerializationFormats the formats a request may specify via {@code Content-Type};
     *                                    copied into an immutable set
     * @throws IllegalStateException if {@code delegate} does not wrap a {@link ThriftCallService}
     */
    public THttpService(Service<ThriftCall, ThriftReply> delegate, SerializationFormat defaultSerializationFormat,
                        Set<SerializationFormat> allowedSerializationFormats) {
        Objects.requireNonNull(delegate, "delegate");
        Objects.requireNonNull(defaultSerializationFormat, "defaultSerializationFormat");
        Objects.requireNonNull(allowedSerializationFormats, "allowedSerializationFormats");
        this.delegate = delegate;
        thriftService = findThriftService(delegate);
        this.defaultSerializationFormat = defaultSerializationFormat;
        this.allowedSerializationFormats = Sets.immutableEnumSet(allowedSerializationFormats);
    }

    /** Returns the {@link ThriftCallService} that {@code service} is or wraps, or throws if there is none. */
    private static ThriftCallService findThriftService(Service<?, ?> service) {
        return service.as(ThriftCallService.class).orElseThrow(
                () -> new IllegalStateException("service being decorated is not a ThriftService: " + service));
    }

    /** Returns the service entries of the underlying {@link ThriftCallService}, keyed by service name. */
    public Map<String, ThriftServiceEntry> entries() {
        return thriftService.entries();
    }

    /** Returns the immutable set of formats a request may specify via {@code Content-Type}. */
    public Set<SerializationFormat> allowedSerializationFormats() {
        return allowedSerializationFormats;
    }

    /** Returns the format used when a request does not specify a recognizable one. */
    public SerializationFormat defaultSerializationFormat() {
        return defaultSerializationFormat;
    }

    @Override
    protected void doPost(ServiceRequestContext ctx, HttpRequest req, HttpResponseWriter res) {
        final SerializationFormat serializationFormat = validateRequestAndDetermineSerializationFormat(req, res);
        if (serializationFormat == null) {
            // The request was rejected and an error response has already been written.
            return;
        }

        ctx.requestLogBuilder().serializationFormat(serializationFormat);
        req.aggregate().handle(Functions.voidFunction((aggregatedReq, cause) -> {
            if (cause != null) {
                res.respond(HttpStatus.INTERNAL_SERVER_ERROR, MediaType.PLAIN_TEXT_UTF_8,
                            Throwables.getStackTraceAsString(cause));
                return;
            }
            decodeAndInvoke(ctx, aggregatedReq, serializationFormat, res);
        })).exceptionally(CompletionActions::log);
    }

    /**
     * Determines the serialization format of {@code req} from its {@code Content-Type} and checks that
     * its {@code Accept} header agrees.
     *
     * @return the format to use, or {@code null} if the request was rejected. In that case an error
     *         response has already been written to {@code res}.
     */
    private SerializationFormat validateRequestAndDetermineSerializationFormat(HttpRequest req,
                                                                               HttpResponseWriter res) {
        final HttpHeaders headers = req.headers();

        final SerializationFormat serializationFormat;
        final CharSequence contentType = headers.get(HttpHeaderNames.CONTENT_TYPE);
        if (contentType == null) {
            serializationFormat = defaultSerializationFormat;
        } else {
            serializationFormat = SerializationFormat.fromMediaType(contentType.toString())
                                                     .orElse(defaultSerializationFormat);
            if (!allowedSerializationFormats.contains(serializationFormat)) {
                res.respond(HttpStatus.UNSUPPORTED_MEDIA_TYPE, MediaType.PLAIN_TEXT_UTF_8,
                            THRIFT_PROTOCOL_NOT_SUPPORTED);
                return null;
            }
        }

        final CharSequence accept = headers.get(HttpHeaderNames.ACCEPT);
        if (accept != null) {
            // An Accept header that maps to no known format is ignored. One that maps to a different
            // format than the request's is rejected.
            final SerializationFormat acceptedFormat =
                    SerializationFormat.fromMediaType(accept.toString()).orElse(serializationFormat);
            if (acceptedFormat != serializationFormat) {
                res.respond(HttpStatus.NOT_ACCEPTABLE, MediaType.PLAIN_TEXT_UTF_8,
                            ACCEPT_THRIFT_PROTOCOL_MUST_MATCH_CONTENT_TYPE);
                return null;
            }
        }
        return serializationFormat;
    }

    /**
     * Decodes a Thrift call from the aggregated request body and, if decoding succeeds, invokes it.
     *
     * <p>Failures are reported as follows:
     * <ul>
     *   <li>an undecodable message header yields {@code 400 Bad Request};</li>
     *   <li>an unexpected message type yields a Thrift exception reply with
     *       {@link TApplicationException#INVALID_MESSAGE_TYPE};</li>
     *   <li>an unknown method yields {@link TApplicationException#UNKNOWN_METHOD};</li>
     *   <li>undecodable arguments yield {@link TApplicationException#PROTOCOL_ERROR}.</li>
     * </ul>
     */
    private void decodeAndInvoke(ServiceRequestContext ctx, AggregatedHttpMessage req,
                                 SerializationFormat serializationFormat, HttpResponseWriter res) {
        final TProtocol inProto = FORMAT_TO_THREAD_LOCAL_INPUT_PROTOCOL.get(serializationFormat).get();
        inProto.reset();
        // ThreadLocalTProtocol always builds its protocol over a TMemoryInputTransport.
        final TMemoryInputTransport inTransport = (TMemoryInputTransport) inProto.getTransport();
        final HttpData content = req.content();
        inTransport.reset(content.array(), content.offset(), content.length());

        final TMessage header;
        final ThriftFunction func;
        final TBase<TBase<?, ?>, TFieldIdEnum> args;
        try {
            try {
                header = inProto.readMessageBegin();
            } catch (Exception e) {
                logger.debug("{} Failed to decode Thrift header:", ctx, e);
                res.respond(HttpStatus.BAD_REQUEST, MediaType.PLAIN_TEXT_UTF_8,
                            "Failed to decode Thrift header: " + Throwables.getStackTraceAsString(e));
                return;
            }

            // A name of the form "serviceName:methodName" selects the entry registered under serviceName.
            // A name without a colon selects the entry registered under the empty name.
            final int colonPos = header.name.indexOf(':');
            final String serviceName = colonPos < 0 ? "" : header.name.substring(0, colonPos);
            final String methodName = colonPos < 0 ? header.name : header.name.substring(colonPos + 1);

            final byte messageType = header.type;
            if (messageType != TMessageType.CALL && messageType != TMessageType.ONEWAY) {
                respondWithException(
                        ctx, serializationFormat, header.seqid, methodName,
                        new TApplicationException(TApplicationException.INVALID_MESSAGE_TYPE,
                                                  "unexpected TMessageType: " + typeString(messageType)),
                        res);
                return;
            }

            final ThriftServiceEntry entry = entries().get(serviceName);
            func = entry != null ? entry.metadata.function(methodName) : null;
            if (func == null) {
                respondWithException(
                        ctx, serializationFormat, header.seqid, methodName,
                        new TApplicationException(TApplicationException.UNKNOWN_METHOD,
                                                  "unknown method: " + header.name),
                        res);
                return;
            }

            try {
                args = decodeArgs(func, inProto);
            } catch (Exception e) {
                logger.debug("{} Failed to decode Thrift arguments:", ctx, e);
                respondWithException(
                        ctx, serializationFormat, header.seqid, methodName,
                        new TApplicationException(TApplicationException.PROTOCOL_ERROR,
                                                  "failed to decode arguments: " + e),
                        res);
                return;
            }
        } finally {
            // The protocol is reused by this thread; drop its reference to the request content.
            inTransport.clear();
        }

        ctx.requestLogBuilder().attr(RequestLog.RAW_RPC_REQUEST).set(new ApacheThriftCall(header, args));
        // Invoked outside the decoding try blocks so that invocation failures are never reported as
        // decoding failures.
        invoke(ctx, serializationFormat, header.seqid, header.name, func, args, res);
    }

    /** Reads the arguments struct of {@code func} from {@code inProto}, followed by the message end. */
    private static TBase<TBase<?, ?>, TFieldIdEnum> decodeArgs(ThriftFunction func, TProtocol inProto)
            throws TException {
        final TBase<TBase<?, ?>, TFieldIdEnum> args;
        if (func.isAsync()) {
            args = func.asyncFunc().getEmptyArgsInstance();
        } else {
            args = func.syncFunc().getEmptyArgsInstance();
        }
        args.read(inProto);
        inProto.readMessageEnd();
        return args;
    }

    /** Returns a human-readable name for a Thrift message type byte. */
    private static String typeString(byte type) {
        switch (type) {
            case TMessageType.CALL:
                return "CALL";
            case TMessageType.REPLY:
                return "REPLY";
            case TMessageType.EXCEPTION:
                return "EXCEPTION";
            case TMessageType.ONEWAY:
                return "ONEWAY";
            default:
                return "UNKNOWN(" + (type & 0xFF) + ')';
        }
    }

    /**
     * Passes the decoded call to {@link #delegate} and arranges for its reply to be written to {@code res}.
     *
     * <p>The request context is pushed only while {@code delegate.serve()} runs. The push handle is
     * always closed, even when {@code serve()} throws. Any exception from the delegate, or a {@code null}
     * reply, is turned into a reply via {@link #handleException}.
     */
    private void invoke(ServiceRequestContext ctx, SerializationFormat serializationFormat, int seqId,
                        String methodName, ThriftFunction func, TBase<TBase<?, ?>, TFieldIdEnum> args,
                        HttpResponseWriter res) {
        final ThriftCall call = new ThriftCall(seqId, func.serviceType(), methodName, args);
        ctx.requestLogBuilder().attr(RequestLog.RPC_REQUEST).set(call);

        final ThriftReply reply;
        try (RequestContext.PushHandle ignored = RequestContext.push(ctx)) {
            // Reject a null reply here so it is reported to the client like any other failure.
            reply = Objects.requireNonNull(delegate.serve(ctx, call), "delegate.serve() returned null");
            ctx.responseLogBuilder().attr(ResponseLog.RPC_RESPONSE).set(reply);
        } catch (Throwable cause) {
            handleException(ctx, serializationFormat, seqId, func, cause, res);
            return;
        }

        reply.handle(Functions.voidFunction((result, cause) -> {
            if (cause != null) {
                handleException(ctx, serializationFormat, seqId, func, cause, res);
                return;
            }

            if (func.isOneWay()) {
                respond(serializationFormat, HttpData.EMPTY_DATA, res);
                return;
            }

            try {
                final TBase<TBase<?, ?>, TFieldIdEnum> wrappedResult = func.newResult();
                func.setSuccess(wrappedResult, result);
                respondWithResult(ctx, serializationFormat, seqId, func.name(), wrappedResult, res);
            } catch (Throwable t) {
                handleException(ctx, serializationFormat, seqId, func, t, res);
            }
        })).exceptionally(CompletionActions::log);
    }

    /**
     * Writes {@code cause} as a reply to {@code res}.
     *
     * <p>If {@code func} accepts {@code cause} as one of its result's exceptions, the cause is encoded
     * inside a normal reply. Otherwise it is encoded as a Thrift exception message.
     */
    private static void handleException(ServiceRequestContext ctx, SerializationFormat serializationFormat,
                                        int seqId, ThriftFunction func, Throwable cause,
                                        HttpResponseWriter res) {
        final TBase<TBase<?, ?>, TFieldIdEnum> result = func.newResult();
        if (func.setException(result, cause)) {
            respondWithResult(ctx, serializationFormat, seqId, func.name(), result, res);
        } else {
            respondWithException(ctx, serializationFormat, seqId, func.name(), cause, res);
        }
    }

    /** Writes {@code result} as a {@link TMessageType#REPLY} message. */
    private static void respondWithResult(ServiceRequestContext ctx, SerializationFormat serializationFormat,
                                          int seqId, String methodName, TBase<TBase<?, ?>, TFieldIdEnum> result,
                                          HttpResponseWriter res) {
        respond(serializationFormat, encodeSuccess(ctx, serializationFormat, seqId, methodName, result), res);
    }

    /** Writes {@code cause} as a {@link TMessageType#EXCEPTION} message. */
    private static void respondWithException(ServiceRequestContext ctx, SerializationFormat serializationFormat,
                                             int seqId, String methodName, Throwable cause,
                                             HttpResponseWriter res) {
        respond(serializationFormat, encodeException(ctx, serializationFormat, seqId, methodName, cause), res);
    }

    /** Writes {@code content} as a {@code 200 OK} response with the media type of the serialization format. */
    private static void respond(SerializationFormat serializationFormat, HttpData content,
                                HttpResponseWriter res) {
        res.respond(HttpStatus.OK, serializationFormat.mediaType(), content);
    }

    /**
     * Encodes {@code result} as a {@link TMessageType#REPLY} message and records it as the raw RPC response.
     *
     * @throws IllegalStateException if the message cannot be written, for example because
     *                               {@code result} fails validation
     */
    private static HttpData encodeSuccess(ServiceRequestContext ctx, SerializationFormat serializationFormat,
                                          int seqId, String methodName,
                                          TBase<TBase<?, ?>, TFieldIdEnum> result) {
        final TMemoryBuffer outTransport = new TMemoryBuffer(128);
        final TProtocol outProto = ThriftProtocolFactories.get(serializationFormat).getProtocol(outTransport);
        try {
            final TMessage header = new TMessage(methodName, TMessageType.REPLY, seqId);
            outProto.writeMessageBegin(header);
            result.write(outProto);
            outProto.writeMessageEnd();
            ctx.responseLogBuilder().attr(ResponseLog.RAW_RPC_RESPONSE).set(new ApacheThriftReply(header, result));
            return HttpData.of(outTransport.getArray(), 0, outTransport.length());
        } catch (TException e) {
            throw new IllegalStateException("failed to encode a Thrift reply", e);
        }
    }

    /**
     * Encodes {@code cause} as a {@link TMessageType#EXCEPTION} message and records it as the raw RPC
     * response.
     *
     * <p>A {@link TApplicationException} is sent as-is. Any other throwable is wrapped into an
     * {@link TApplicationException#INTERNAL_ERROR} whose message includes the server-side stack trace.
     */
    private static HttpData encodeException(ServiceRequestContext ctx, SerializationFormat serializationFormat,
                                            int seqId, String methodName, Throwable cause) {
        final TApplicationException appException;
        if (cause instanceof TApplicationException) {
            appException = (TApplicationException) cause;
        } else {
            appException = new TApplicationException(
                    TApplicationException.INTERNAL_ERROR,
                    "internal server error:" + System.lineSeparator() +
                    "---- BEGIN server-side trace ----" + System.lineSeparator() +
                    Throwables.getStackTraceAsString(cause) +
                    "---- END server-side trace ----");
        }

        final TMemoryBuffer outTransport = new TMemoryBuffer(128);
        final TProtocol outProto = ThriftProtocolFactories.get(serializationFormat).getProtocol(outTransport);
        try {
            final TMessage header = new TMessage(methodName, TMessageType.EXCEPTION, seqId);
            outProto.writeMessageBegin(header);
            appException.write(outProto);
            outProto.writeMessageEnd();
            ctx.responseLogBuilder().attr(ResponseLog.RAW_RPC_RESPONSE)
               .set(new ApacheThriftReply(header, appException));
            return HttpData.of(outTransport.getArray(), 0, outTransport.length());
        } catch (TException e) {
            throw new IllegalStateException("failed to encode a Thrift exception reply", e);
        }
    }

    /** Builds an immutable map from every Thrift {@link SerializationFormat} to its thread-local protocol. */
    private static Map<SerializationFormat, ThreadLocalTProtocol> createFormatToThreadLocalTProtocolMap() {
        final Map<SerializationFormat, ThreadLocalTProtocol> protocols =
                SerializationFormat.ofThrift().stream().collect(Collectors.toMap(
                        Function.identity(),
                        format -> new ThreadLocalTProtocol(ThriftProtocolFactories.get(format))));
        return Maps.immutableEnumMap(protocols);
    }
}
