package com.linecorp.armeria.server.thrift;

import com.linecorp.armeria.common.Scheme;
import com.linecorp.armeria.common.SerializationFormat;
import com.linecorp.armeria.common.ServiceInvocationContext;
import com.linecorp.armeria.common.SessionProtocol;
import com.linecorp.armeria.common.thrift.ThriftProtocolFactories;
import com.linecorp.armeria.common.thrift.ThriftUtil;
import com.linecorp.armeria.common.util.Exceptions;
import com.linecorp.armeria.server.ServiceCodec;
import com.linecorp.armeria.server.ServiceConfig;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.util.concurrent.Promise;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.thrift.AsyncProcessFunction;
import org.apache.thrift.ProcessFunction;
import org.apache.thrift.TApplicationException;
import org.apache.thrift.TBase;
import org.apache.thrift.TBaseAsyncProcessor;
import org.apache.thrift.TBaseProcessor;
import org.apache.thrift.TException;
import org.apache.thrift.TFieldIdEnum;
import org.apache.thrift.protocol.TMessage;
import org.apache.thrift.protocol.TProtocol;
import org.apache.thrift.protocol.TProtocolFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:com/linecorp/armeria/server/thrift/ThriftServiceCodec.class */
final class ThriftServiceCodec implements ServiceCodec {
    /** Format used when a request does not identify a recognizable serialization format. */
    private final SerializationFormat defaultSerializationFormat;

    /** Formats a request's Content-Type is allowed to select; exposed as an unmodifiable view. */
    private final Set<SerializationFormat> allowedSerializationFormats;

    /** The Thrift service implementation object this codec was created for. */
    private final Object service;

    /** Thrift method name to function metadata; filled in by {@code registerFunction} during construction. */
    private final Map<String, ThriftFunction> functions = new HashMap<>();

    // Pre-built exceptions (passed through Exceptions.clearTrace) that are attached to rejected HTTP requests.
    private static final Exception HTTP_METHOD_NOT_ALLOWED_EXCEPTION = (Exception) Exceptions.clearTrace(new IllegalArgumentException("HTTP method not allowed"));
    private static final Exception THRIFT_PROTOCOL_NOT_SUPPORTED = (Exception) Exceptions.clearTrace(new IllegalArgumentException("Specified Thrift protocol not supported"));
    private static final Exception ACCEPT_THRIFT_PROTOCOL_MUST_MATCH_CONTENT_TYPE = (Exception) Exceptions.clearTrace(new IllegalArgumentException("Thrift protocol specified in Accept header must match the one specified in Content-Type header"));

    private static final Logger logger = LoggerFactory.getLogger(ThriftServiceCodec.class);

    /** Per-format, per-thread protocol instances used for decoding requests. */
    private static final Map<SerializationFormat, ThreadLocalTProtocol> FORMAT_TO_THREAD_LOCAL_IN_PROTOCOL = createFormatToThreadLocalTProtocolMap();

    /** Per-format, per-thread protocol instances used for encoding responses. */
    private static final Map<SerializationFormat, ThreadLocalTProtocol> FORMAT_TO_THREAD_LOCAL_OUT_PROTOCOL = createFormatToThreadLocalTProtocolMap();

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/linecorp/armeria/server/thrift/ThriftServiceCodec$InvalidHttpRequestException.class */
    /**
     * Signals that an HTTP request cannot be handled as a Thrift call.
     * Carries the HTTP status to answer with; the underlying reason is stored as the cause.
     */
    public static final class InvalidHttpRequestException extends Exception {
        private static final long serialVersionUID = -8742741687997488293L;

        /** Status of the HTTP response built for the rejected request. */
        private final HttpResponseStatus httpResponseStatus;

        private InvalidHttpRequestException(HttpResponseStatus status, Exception cause) {
            super(cause);
            httpResponseStatus = status;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/linecorp/armeria/server/thrift/ThriftServiceCodec$ThreadLocalTProtocol.class */
    /**
     * A {@link ThreadLocal} whose per-thread value is a {@link TProtocol} created by the given factory
     * on top of a fresh {@code TByteBufTransport}.
     */
    public static final class ThreadLocalTProtocol extends ThreadLocal<TProtocol> {
        /** Factory used to build the protocol the first time a thread asks for it. */
        private final TProtocolFactory protoFactory;

        private ThreadLocalTProtocol(TProtocolFactory factory) {
            protoFactory = factory;
        }

        /* JADX INFO: Access modifiers changed from: protected */
        /* JADX WARN: Can't rename method to resolve collision */
        /** Creates the protocol for the calling thread, bound to a new {@code TByteBufTransport}. */
        @Override // java.lang.ThreadLocal
        public TProtocol initialValue() {
            final TByteBufTransport transport = new TByteBufTransport();
            return protoFactory.getProtocol(transport);
        }
    }

    /* loaded from: input_file:com/linecorp/armeria/server/thrift/ThriftServiceCodec$ThriftDecodeFailureResult.class */
    /**
     * Decode result used when a Thrift message was read but could not be turned into an invocation.
     * Keeps the message's serialization format, sequence id, method name and (optionally) its
     * parameters, and converts the sequence id and parameters to their exposed forms on first access.
     */
    private static final class ThriftDecodeFailureResult extends ServiceCodec.DefaultDecodeResult {
        private final SerializationFormat serializationFormat;
        private final int seqId;
        private final String method;
        private final TBase<TBase<?, ?>, TFieldIdEnum> params;

        /** Cached string form of {@link #seqId}; {@code null} until first requested. */
        private String seqIdStr;

        /** Cached Java view of {@link #params}; {@code null} until first requested. */
        private List<Object> paramList;

        ThriftDecodeFailureResult(SerializationFormat format, Object errorResponse, Throwable cause, int seqId, String method, TBase<TBase<?, ?>, TFieldIdEnum> params) {
            super(errorResponse, cause);
            this.serializationFormat = format;
            this.seqId = seqId;
            this.method = method;
            this.params = params;
        }

        @Override // com.linecorp.armeria.server.ServiceCodec.DefaultDecodeResult, com.linecorp.armeria.server.ServiceCodec.DecodeResult
        public SerializationFormat decodedSerializationFormat() {
            return serializationFormat;
        }

        @Override // com.linecorp.armeria.server.ServiceCodec.DefaultDecodeResult, com.linecorp.armeria.server.ServiceCodec.DecodeResult
        public Optional<String> decodedInvocationId() {
            // Convert the sequence id only once and reuse the string afterwards.
            if (seqIdStr == null) {
                seqIdStr = ThriftUtil.seqIdToString(seqId);
            }
            return Optional.of(seqIdStr);
        }

        @Override // com.linecorp.armeria.server.ServiceCodec.DefaultDecodeResult, com.linecorp.armeria.server.ServiceCodec.DecodeResult
        public Optional<String> decodedMethod() {
            return Optional.of(method);
        }

        @Override // com.linecorp.armeria.server.ServiceCodec.DefaultDecodeResult, com.linecorp.armeria.server.ServiceCodec.DecodeResult
        public Optional<List<Object>> decodedParams() {
            if (params != null) {
                // Convert the Thrift argument struct only once and reuse the list afterwards.
                if (paramList == null) {
                    paramList = ThriftUtil.toJavaParams(params);
                }
                return Optional.of(paramList);
            }
            return Optional.empty();
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Creates a codec for the given Thrift service implementation.
     *
     * <p>Every interface implemented by the service's class (directly, through superclasses or through
     * super-interfaces) is inspected; the process maps obtained for those interfaces are registered
     * under their Thrift method names.
     *
     * @param obj the Thrift service implementation
     * @param serializationFormat the default serialization format
     * @param set the allowed serialization formats; wrapped in an unmodifiable view, not copied
     * @throws NullPointerException if any argument is {@code null}
     * @throws IllegalArgumentException if no Thrift function could be found for the service, if two
     *         interfaces contribute the same method name, or if function metadata cannot be built
     */
    public ThriftServiceCodec(Object obj, SerializationFormat serializationFormat, Set<SerializationFormat> set) {
        Objects.requireNonNull(set, "allowedSerializationFormats");
        this.service = Objects.requireNonNull(obj, "service");
        this.defaultSerializationFormat = Objects.requireNonNull(serializationFormat, "defaultSerializationFormat");
        this.allowedSerializationFormats = Collections.unmodifiableSet(set);

        // Method names seen so far, used to detect clashes between different service interfaces.
        final Set<String> methodNames = new HashSet<>();
        final Class<?> serviceClass = obj.getClass();
        final ClassLoader serviceClassLoader = serviceClass.getClassLoader();

        final Set<Class<?>> interfaces = new HashSet<>();
        getAllInterfaces(serviceClass, interfaces);

        for (Class<?> iface : interfaces) {
            final Map<String, AsyncProcessFunction<?, ?, ?>> asyncProcessMap = getThriftAsyncProcessMap(obj, iface, serviceClassLoader);
            if (asyncProcessMap != null) {
                asyncProcessMap.forEach((name, func) -> registerFunction(methodNames, serviceClass, name, func));
            }
            final Map<String, ProcessFunction<?, ?>> processMap = getThriftProcessMap(obj, iface, serviceClassLoader);
            if (processMap != null) {
                processMap.forEach((name, func) -> registerFunction(methodNames, serviceClass, name, func));
            }
        }

        if (this.functions.isEmpty()) {
            throw new IllegalArgumentException('\'' + serviceClass.getName() + "' is not a Thrift service implementation.");
        }
    }

    /**
     * Records {@code str} as a known method name and stores the metadata for its process function.
     *
     * @param set method names registered so far; {@code str} is added to it
     * @param cls the service class, used only in error messages
     * @param str the Thrift method name
     * @param obj a {@link ProcessFunction} or an {@link AsyncProcessFunction}
     * @throws IllegalArgumentException if the name is already registered or the metadata cannot be built
     */
    private void registerFunction(Set<String> set, Class<?> cls, String str, Object obj) {
        checkDuplicateMethodName(set, cls, str);
        set.add(str);
        try {
            final ThriftFunction function;
            if (obj instanceof ProcessFunction) {
                function = new ThriftFunction((ProcessFunction<?, ?>) obj);
            } else {
                function = new ThriftFunction((AsyncProcessFunction<?, ?, ?>) obj);
            }
            functions.put(str, function);
        } catch (Exception e) {
            final String message = "failed to retrieve function metadata: " + cls.getName() + '.' + str + "()";
            throw new IllegalArgumentException(message, e);
        }
    }

    /**
     * Fails if {@code str} has already been registered as a Thrift method name.
     *
     * @param set method names registered so far
     * @param cls the service class, used only in the error message
     * @param str the method name about to be registered
     * @throws IllegalArgumentException if {@code set} already contains {@code str}
     */
    private static void checkDuplicateMethodName(Set<String> set, Class<?> cls, String str) {
        if (!set.contains(str)) {
            return;
        }
        final String message = '\'' + cls.getName() + "' implements multiple Thrift service interfaces with a duplicate method name: " + str;
        throw new IllegalArgumentException(message);
    }

    /**
     * Looks up the synchronous process map for a service interface named {@code ...$Iface}.
     *
     * <p>The sibling class {@code ...$Processor} is loaded through {@code classLoader}, instantiated
     * with {@code obj}, and asked for its process-map view.
     *
     * @param obj the service implementation handed to the processor's constructor
     * @param cls the candidate interface
     * @param classLoader the class loader used to find the processor class
     * @return the process map, or {@code null} if {@code cls} is not an {@code $Iface} interface, the
     *         processor class is not a {@link TBaseProcessor}, or anything fails along the way
     */
    private static Map<String, ProcessFunction<?, ?>> getThriftProcessMap(Object obj, Class<?> cls, ClassLoader classLoader) {
        final String ifaceName = cls.getName();
        if (!ifaceName.endsWith("$Iface")) {
            return null;
        }
        // Strip "Iface" but keep the '$' so that "Foo$Iface" becomes "Foo$Processor".
        final String processorName = ifaceName.substring(0, ifaceName.length() - "Iface".length()) + "Processor";
        try {
            final Class<?> processorClass = Class.forName(processorName, false, classLoader);
            if (!TBaseProcessor.class.isAssignableFrom(processorClass)) {
                return null;
            }
            final TBaseProcessor processor = (TBaseProcessor) processorClass.getConstructor(cls).newInstance(obj);
            return processor.getProcessMapView();
        } catch (Exception e) {
            logger.debug("Failed to retrieve the process map from: {}", cls, e);
            return null;
        }
    }

    /**
     * Looks up the asynchronous process map for a service interface named {@code ...$AsyncIface}.
     *
     * <p>The sibling class {@code ...$AsyncProcessor} is loaded through {@code classLoader},
     * instantiated with {@code obj}, and asked for its process-map view.
     *
     * @param obj the service implementation handed to the processor's constructor
     * @param cls the candidate interface
     * @param classLoader the class loader used to find the processor class
     * @return the process map, or {@code null} if {@code cls} is not an {@code $AsyncIface} interface,
     *         the processor class is not a {@link TBaseAsyncProcessor}, or anything fails along the way
     */
    private static Map<String, AsyncProcessFunction<?, ?, ?>> getThriftAsyncProcessMap(Object obj, Class<?> cls, ClassLoader classLoader) {
        final String ifaceName = cls.getName();
        if (!ifaceName.endsWith("$AsyncIface")) {
            return null;
        }
        // Strip "AsyncIface" but keep the '$' so that "Foo$AsyncIface" becomes "Foo$AsyncProcessor".
        final String processorName = ifaceName.substring(0, ifaceName.length() - "AsyncIface".length()) + "AsyncProcessor";
        try {
            final Class<?> processorClass = Class.forName(processorName, false, classLoader);
            if (!TBaseAsyncProcessor.class.isAssignableFrom(processorClass)) {
                return null;
            }
            final TBaseAsyncProcessor processor = (TBaseAsyncProcessor) processorClass.getConstructor(cls).newInstance(obj);
            return processor.getProcessMapView();
        } catch (Exception e) {
            logger.debug("Failed to retrieve the asynchronous process map from:: {}", cls, e);
            return null;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /** Returns the Thrift service implementation this codec was constructed with. */
    public Object thriftService() {
        return service;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /** Returns the unmodifiable view of the serialization formats this codec accepts. */
    public Set<SerializationFormat> allowedSerializationFormats() {
        return allowedSerializationFormats;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /** Returns the serialization format used when a request does not select one. */
    public SerializationFormat defaultSerializationFormat() {
        return defaultSerializationFormat;
    }

    /**
     * Decodes one Thrift call from {@code byteBuf}.
     *
     * <p>The serialization format is determined from {@code obj} first; an invalid HTTP request is
     * answered with an HTTP-level error response. Otherwise the message header is read and checked:
     * only CALL (1) and ONEWAY (4) messages addressed to a registered method are accepted. Any
     * Thrift-level problem yields a {@link ThriftDecodeFailureResult} carrying an already-encoded
     * {@link TApplicationException} response.
     *
     * @param serviceConfig supplies the logger name for the invocation context
     * @param channel the channel whose allocator is used for error responses
     * @param sessionProtocol combined with the detected format into the invocation's scheme
     * @param str passed through unchanged to the invocation context
     * @param str2 passed through unchanged to the invocation context
     * @param str3 passed through unchanged to the invocation context
     * @param byteBuf the encoded Thrift message
     * @param obj the original request; inspected as an {@link HttpRequest} when it is one
     * @param promise not used by this method
     * @return the invocation context on success, otherwise a failure result
     * @throws Exception if reading the message header fails
     */
    @Override // com.linecorp.armeria.server.ServiceCodec
    public ServiceCodec.DecodeResult decodeRequest(ServiceConfig serviceConfig, Channel channel, SessionProtocol sessionProtocol, String str, String str2, String str3, ByteBuf byteBuf, Object obj, Promise<Object> promise) throws Exception {
        final SerializationFormat format;
        try {
            format = validateRequestAndDetermineSerializationFormat(obj);
        } catch (InvalidHttpRequestException e) {
            // The request is not acceptable at the HTTP level; answer with the mapped status.
            return new ServiceCodec.DefaultDecodeResult(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, e.httpResponseStatus), e.getCause());
        }

        final TProtocol inProto = FORMAT_TO_THREAD_LOCAL_IN_PROTOCOL.get(format).get();
        inProto.reset();
        final TByteBufTransport inTransport = (TByteBufTransport) inProto.getTransport();
        inTransport.reset(byteBuf);
        try {
            final TMessage header = inProto.readMessageBegin();
            final byte type = header.type;
            final int seqId = header.seqid;
            final String method = header.name;

            // 1 = CALL, 4 = ONEWAY (see typeString); nothing else is a valid request.
            if (type != 1 && type != 4) {
                return newDecodeFailure(channel, format, method, seqId,
                        new TApplicationException(TApplicationException.INVALID_MESSAGE_TYPE, "unexpected TMessageType: " + typeString(type)));
            }

            final ThriftFunction func = this.functions.get(method);
            if (func == null) {
                return newDecodeFailure(channel, format, method, seqId,
                        new TApplicationException(TApplicationException.UNKNOWN_METHOD, "unknown method: " + method));
            }

            try {
                final TBase args;
                if (func.isAsync()) {
                    args = (TBase) func.asyncFunc().getEmptyArgsInstance();
                } else {
                    args = func.syncFunc().getEmptyArgsInstance();
                }
                args.read(inProto);
                inProto.readMessageEnd();
                return new ThriftServiceInvocationContext(channel, Scheme.of(format, sessionProtocol), str, str2, str3, serviceConfig.loggerName(), obj, func, seqId, args);
            } catch (Exception e) {
                return newDecodeFailure(channel, format, method, seqId,
                        new TApplicationException(TApplicationException.PROTOCOL_ERROR, "argument decode failure: " + e));
            }
        } finally {
            // Always detach the request buffer from the thread-local transport, whatever the outcome.
            inTransport.clear();
        }
    }

    /** Encodes {@code cause} as a Thrift exception response and wraps it into a decode-failure result. */
    private static ThriftDecodeFailureResult newDecodeFailure(Channel channel, SerializationFormat format, String method, int seqId, TApplicationException cause) {
        final ByteBuf encoded = encodeException(channel.alloc(), format, method, seqId, cause);
        return new ThriftDecodeFailureResult(format, encoded, cause, seqId, method, null);
    }

    /**
     * Always returns {@code false}, regardless of the given invocation context.
     */
    @Override // com.linecorp.armeria.server.ServiceCodec
    public boolean failureResponseFailsSession(ServiceInvocationContext serviceInvocationContext) {
        return false;
    }

    /**
     * Encodes the outcome of a successful invocation.
     *
     * @param serviceInvocationContext the context created by {@code decodeRequest}
     * @param obj either a ready-made Thrift result struct or the plain return value to wrap in one
     * @return the encoded reply, or {@code null} for a one-way function
     */
    @Override // com.linecorp.armeria.server.ServiceCodec
    public ByteBuf encodeResponse(ServiceInvocationContext serviceInvocationContext, Object obj) throws Exception {
        final ThriftServiceInvocationContext ctx = (ThriftServiceInvocationContext) serviceInvocationContext;
        final ThriftFunction func = ctx.func;
        if (func.isOneway()) {
            // One-way calls produce no reply.
            return null;
        }

        final TBase<TBase<?, ?>, TFieldIdEnum> result;
        if (!func.isResult(obj)) {
            // Wrap the plain return value into a fresh result struct.
            result = func.newResult();
            func.setSuccess(result, obj);
        } else {
            result = (TBase) obj;
        }
        return encodeSuccess(ctx, result);
    }

    /**
     * Encodes the outcome of a failed invocation.
     *
     * <p>For a one-way function the failure is encoded as a Thrift exception message. Otherwise the
     * function's result struct is given the chance to carry the exception; if it accepts it, the
     * struct is sent as a regular reply, else the failure is encoded as a Thrift exception message.
     * Anything thrown while doing so is itself encoded as a Thrift exception message.
     *
     * @param serviceInvocationContext the context created by {@code decodeRequest}
     * @param th the failure to report
     * @return the encoded response
     */
    @Override // com.linecorp.armeria.server.ServiceCodec
    public ByteBuf encodeFailureResponse(ServiceInvocationContext serviceInvocationContext, Throwable th) throws Exception {
        final ThriftServiceInvocationContext ctx = (ThriftServiceInvocationContext) serviceInvocationContext;
        final ThriftFunction func = ctx.func;
        if (func.isOneway()) {
            return encodeException(ctx, th);
        }
        try {
            final TBase<TBase<?, ?>, TFieldIdEnum> result = func.newResult();
            if (func.setException(result, th)) {
                return encodeSuccess(ctx, result);
            }
            return encodeException(ctx, th);
        } catch (Throwable failure) {
            return encodeException(ctx, failure);
        }
    }

    /**
     * Serializes {@code tBase} as a Thrift REPLY message (type 2) into a newly allocated buffer,
     * using the calling thread's output protocol for the context's serialization format.
     *
     * <p>If writing fails, the freshly allocated buffer is released before the error propagates,
     * so it is not leaked.
     *
     * @param thriftServiceInvocationContext supplies the allocator, format, method name and sequence id
     * @param tBase the result struct to write
     * @return the buffer holding the encoded reply; the caller owns it
     * @throws Error wrapping the {@link TException} if the protocol fails to write
     */
    private static ByteBuf encodeSuccess(ThriftServiceInvocationContext thriftServiceInvocationContext, TBase<TBase<?, ?>, TFieldIdEnum> tBase) {
        final TProtocol outProto = FORMAT_TO_THREAD_LOCAL_OUT_PROTOCOL.get(thriftServiceInvocationContext.scheme().serializationFormat()).get();
        outProto.reset();
        final TByteBufTransport outTransport = (TByteBufTransport) outProto.getTransport();
        final ByteBuf buffer = thriftServiceInvocationContext.alloc().buffer();
        outTransport.reset(buffer);
        boolean success = false;
        try {
            outProto.writeMessageBegin(new TMessage(thriftServiceInvocationContext.method(), (byte) 2, thriftServiceInvocationContext.seqId));
            tBase.write(outProto);
            outProto.writeMessageEnd();
            success = true;
            return buffer;
        } catch (TException e) {
            throw new Error(e);
        } finally {
            // Detach the buffer from the thread-local transport on every path.
            outTransport.clear();
            if (!success) {
                // Nobody will ever see this buffer; give it back to the allocator.
                buffer.release();
            }
        }
    }

    /**
     * Encodes {@code th} as a Thrift exception message for the given invocation.
     * A {@link TApplicationException} is sent as is; any other throwable is wrapped in a
     * {@code TApplicationException} of type 6 whose message is {@code th.toString()}.
     */
    private static ByteBuf encodeException(ThriftServiceInvocationContext thriftServiceInvocationContext, Throwable th) {
        final TApplicationException appException;
        if (th instanceof TApplicationException) {
            appException = (TApplicationException) th;
        } else {
            appException = new TApplicationException(6, th.toString());
        }
        return encodeException(
                thriftServiceInvocationContext.alloc(),
                thriftServiceInvocationContext.scheme().serializationFormat(),
                thriftServiceInvocationContext.method(),
                thriftServiceInvocationContext.seqId,
                appException);
    }

    /**
     * Serializes {@code tApplicationException} as a Thrift EXCEPTION message (type 3) into a newly
     * allocated buffer, using the calling thread's output protocol for {@code serializationFormat}.
     *
     * <p>If writing fails, the freshly allocated buffer is released before the error propagates,
     * so it is not leaked.
     *
     * @param byteBufAllocator allocator for the response buffer
     * @param serializationFormat selects the output protocol
     * @param str the Thrift method name placed in the message header
     * @param i the sequence id placed in the message header
     * @param tApplicationException the exception to write
     * @return the buffer holding the encoded exception; the caller owns it
     * @throws Error wrapping the {@link TException} if the protocol fails to write
     */
    private static ByteBuf encodeException(ByteBufAllocator byteBufAllocator, SerializationFormat serializationFormat, String str, int i, TApplicationException tApplicationException) {
        final TProtocol outProto = FORMAT_TO_THREAD_LOCAL_OUT_PROTOCOL.get(serializationFormat).get();
        outProto.reset();
        final TByteBufTransport outTransport = (TByteBufTransport) outProto.getTransport();
        final ByteBuf buffer = byteBufAllocator.buffer();
        outTransport.reset(buffer);
        boolean success = false;
        try {
            outProto.writeMessageBegin(new TMessage(str, (byte) 3, i));
            tApplicationException.write(outProto);
            outProto.writeMessageEnd();
            success = true;
            return buffer;
        } catch (TException e) {
            throw new Error(e);
        } finally {
            // Detach the buffer from the thread-local transport on every path.
            outTransport.clear();
            if (!success) {
                // Nobody will ever see this buffer; give it back to the allocator.
                buffer.release();
            }
        }
    }

    /**
     * Validates an incoming request and picks the serialization format to decode it with.
     *
     * <p>Anything that is not an {@link HttpRequest} gets the default format without further checks.
     * For an HTTP request:
     * <ul>
     *   <li>the method must be POST, otherwise 405 Method Not Allowed;</li>
     *   <li>a Content-Type header selects the format (falling back to the default when it does not
     *       map to one) and that format must be allowed, otherwise 415 Unsupported Media Type;</li>
     *   <li>an Accept header that maps to a format must map to the same one, otherwise
     *       406 Not Acceptable.</li>
     * </ul>
     *
     * @param obj the original request object
     * @return the serialization format to use
     * @throws InvalidHttpRequestException if the HTTP request violates one of the rules above
     */
    private SerializationFormat validateRequestAndDetermineSerializationFormat(Object obj) throws InvalidHttpRequestException {
        if (!(obj instanceof HttpRequest)) {
            return this.defaultSerializationFormat;
        }
        final HttpRequest httpRequest = (HttpRequest) obj;

        // Compare by value: an HttpMethod instance is not guaranteed to be the POST singleton.
        if (!HttpMethod.POST.equals(httpRequest.method())) {
            throw new InvalidHttpRequestException(HttpResponseStatus.METHOD_NOT_ALLOWED, HTTP_METHOD_NOT_ALLOWED_EXCEPTION);
        }

        final SerializationFormat serializationFormat;
        final String contentType = httpRequest.headers().get(HttpHeaderNames.CONTENT_TYPE);
        if (contentType != null) {
            serializationFormat = SerializationFormat.fromMimeType(contentType).orElse(this.defaultSerializationFormat);
            if (!this.allowedSerializationFormats.contains(serializationFormat)) {
                throw new InvalidHttpRequestException(HttpResponseStatus.UNSUPPORTED_MEDIA_TYPE, THRIFT_PROTOCOL_NOT_SUPPORTED);
            }
        } else {
            serializationFormat = this.defaultSerializationFormat;
        }

        final String accept = httpRequest.headers().get(HttpHeaderNames.ACCEPT);
        if (accept == null || SerializationFormat.fromMimeType(accept).orElse(serializationFormat) == serializationFormat) {
            return serializationFormat;
        }
        throw new InvalidHttpRequestException(HttpResponseStatus.NOT_ACCEPTABLE, ACCEPT_THRIFT_PROTOCOL_MUST_MATCH_CONTENT_TYPE);
    }

    /**
     * Returns a readable name for a Thrift message type byte: 1 is CALL, 2 is REPLY, 3 is EXCEPTION
     * and 4 is ONEWAY. Any other value is rendered as {@code UNKNOWN(n)} with {@code n} being the
     * byte interpreted as unsigned.
     */
    private static String typeString(byte b) {
        if (b == 1) {
            return "CALL";
        }
        if (b == 2) {
            return "REPLY";
        }
        if (b == 3) {
            return "EXCEPTION";
        }
        if (b == 4) {
            return "ONEWAY";
        }
        final int unsigned = b & 0xFF;
        return "UNKNOWN(" + unsigned + ')';
    }

    /**
     * Builds an unmodifiable map holding one {@link ThreadLocalTProtocol} for every format returned by
     * {@code SerializationFormat.ofThrift()}, each backed by that format's protocol factory.
     */
    private static Map<SerializationFormat, ThreadLocalTProtocol> createFormatToThreadLocalTProtocolMap() {
        final Map<SerializationFormat, ThreadLocalTProtocol> byFormat = SerializationFormat.ofThrift().stream()
                .collect(Collectors.toMap(
                        Function.identity(),
                        format -> new ThreadLocalTProtocol(ThriftProtocolFactories.get(format))));
        return Collections.unmodifiableMap(byFormat);
    }

    /**
     * Collects into {@code set} every interface implemented by {@code cls}: those declared on the
     * class itself, on each of its superclasses, and all of their super-interfaces.
     *
     * @param cls the class (or interface) to start from; {@code null} collects nothing
     * @param set receives the interfaces; an interface already present is not traversed again
     */
    private static void getAllInterfaces(Class<?> cls, Set<Class<?>> set) {
        for (Class<?> current = cls; current != null; current = current.getSuperclass()) {
            for (Class<?> iface : current.getInterfaces()) {
                if (!set.add(iface)) {
                    // Already visited through another path.
                    continue;
                }
                getAllInterfaces(iface, set);
            }
        }
    }
}
