package io.kroxylicious.proxy.internal;

import edu.umd.cs.findbugs.annotations.NonNull;
import io.kroxylicious.proxy.filter.Filter;
import io.kroxylicious.proxy.filter.FilterAndInvoker;
import io.kroxylicious.proxy.filter.FilterContext;
import io.kroxylicious.proxy.filter.FilterInvoker;
import io.kroxylicious.proxy.filter.FilterResult;
import io.kroxylicious.proxy.filter.RequestFilterResult;
import io.kroxylicious.proxy.filter.RequestFilterResultBuilder;
import io.kroxylicious.proxy.filter.ResponseFilterResult;
import io.kroxylicious.proxy.filter.ResponseFilterResultBuilder;
import io.kroxylicious.proxy.frame.DecodedFrame;
import io.kroxylicious.proxy.frame.DecodedRequestFrame;
import io.kroxylicious.proxy.frame.DecodedResponseFrame;
import io.kroxylicious.proxy.frame.OpaqueRequestFrame;
import io.kroxylicious.proxy.frame.OpaqueResponseFrame;
import io.kroxylicious.proxy.frame.RequestFrame;
import io.kroxylicious.proxy.internal.filter.RequestFilterResultBuilderImpl;
import io.kroxylicious.proxy.internal.filter.ResponseFilterResultBuilderImpl;
import io.kroxylicious.proxy.internal.util.Assertions;
import io.kroxylicious.proxy.internal.util.ByteBufOutputStream;
import io.kroxylicious.proxy.model.VirtualCluster;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import javax.annotation.Nullable;
import org.apache.kafka.common.message.ProduceRequestData;
import org.apache.kafka.common.message.RequestHeaderData;
import org.apache.kafka.common.message.ResponseHeaderData;
import org.apache.kafka.common.protocol.ApiKeys;
import org.apache.kafka.common.protocol.ApiMessage;
import org.apache.kafka.common.utils.ByteBufferOutputStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:io/kroxylicious/proxy/internal/FilterHandler.class */
/**
 * Netty duplex handler that applies a single {@link Filter} to the Kafka frames passing through it.
 * <p>
 * Outbound writes carry requests and are dispatched to {@link FilterInvoker#onRequest}; inbound reads
 * carry responses and are dispatched to {@link FilterInvoker#onResponse}. Decoded frames and opaque
 * frames are serialised behind {@link #writeFuture} / {@link #readFuture}, so that when the filter
 * completes a frame asynchronously, frames that arrive later cannot overtake it.
 * <p>
 * Requests created via {@link InternalFilterContext#sendRequest} travel as {@link InternalRequestFrame}s
 * and bypass the write queue; the matching {@link InternalResponseFrame}s are delivered to the promise
 * of the filter that sent them instead of being dispatched to it as ordinary responses.
 */
public class FilterHandler extends ChannelDuplexHandler {
    private static final Logger LOGGER = LoggerFactory.getLogger(FilterHandler.class);

    private final Filter filter;
    private final FilterInvoker invoker;
    /** Time limit, in milliseconds, for deferred filter work and for filter-initiated requests. */
    private final long timeoutMs;
    /** SNI hostname exposed to the filter; may be {@code null}. */
    private final String sniHostname;
    private final VirtualCluster virtualCluster;
    /** Channel whose auto-read is suspended while a filter result is deferred (presumably the client-facing side — confirm). */
    private final Channel inboundChannel;
    /** Tail of the outbound (request) processing chain; decoded and opaque requests are queued behind it. */
    private CompletableFuture<Void> writeFuture = CompletableFuture.completedFuture(null);
    /** Tail of the inbound (response) processing chain; decoded and opaque responses are queued behind it. */
    private CompletableFuture<Void> readFuture = CompletableFuture.completedFuture(null);
    /** This handler's context; assigned in {@link #channelActive}. */
    private ChannelHandlerContext ctx;
    /** Creates time-limited promises on the channel's executor; assigned in {@link #channelActive}. */
    private PromiseFactory promiseFactory;

    /**
     * {@link FilterContext} handed to the filter while it processes one decoded frame. Buffers allocated
     * through this context are attached to that frame.
     */
    public class InternalFilterContext implements FilterContext {
        private final DecodedFrame<?, ?> decodedFrame;

        InternalFilterContext(DecodedFrame<?, ?> decodedFrame) {
            this.decodedFrame = decodedFrame;
        }

        public String channelDescriptor() {
            // Must stay qualified: an unqualified call would recurse into this very method.
            return FilterHandler.this.channelDescriptor();
        }

        public ByteBufferOutputStream createByteBufferOutputStream(int initialCapacity) {
            ByteBuf buffer = FilterHandler.this.ctx.alloc().ioBuffer(initialCapacity);
            // Attach the buffer to the frame (presumably so it is released together with the frame — confirm).
            this.decodedFrame.add(buffer);
            return new ByteBufOutputStream(buffer);
        }

        @Nullable
        public String sniHostname() {
            return FilterHandler.this.sniHostname;
        }

        public String getVirtualClusterName() {
            return FilterHandler.this.virtualCluster.getClusterName();
        }

        public RequestFilterResultBuilder requestFilterResultBuilder() {
            return new RequestFilterResultBuilderImpl();
        }

        public ResponseFilterResultBuilder responseFilterResultBuilder() {
            return new ResponseFilterResultBuilderImpl();
        }

        public CompletionStage<RequestFilterResult> forwardRequest(RequestHeaderData requestHeaderData, ApiMessage apiMessage) {
            return requestFilterResultBuilder().forward(requestHeaderData, apiMessage).completed();
        }

        public CompletionStage<ResponseFilterResult> forwardResponse(ResponseHeaderData responseHeaderData, ApiMessage apiMessage) {
            return responseFilterResultBuilder().forward(responseHeaderData, apiMessage).completed();
        }

        /**
         * Sends a request originated by the filter itself.
         * <p>
         * The header's api key is overwritten from the message and its correlation id is set to -1
         * before the version check, so the header is mutated even if this method throws.
         *
         * @param requestHeaderData header to send; its {@code requestApiVersion} selects the encoding
         * @param apiMessage        request body
         * @return a stage completed with the response body; for PRODUCE with acks=0 it is completed
         *         with {@code null} once the write succeeds (or exceptionally if it fails)
         * @throws IllegalArgumentException if the api key does not support the requested version
         */
        @NonNull
        public <M extends ApiMessage> CompletionStage<M> sendRequest(@NonNull RequestHeaderData requestHeaderData, @NonNull ApiMessage apiMessage) {
            Objects.requireNonNull(requestHeaderData);
            Objects.requireNonNull(apiMessage);
            ApiKeys apiKey = ApiKeys.forId(apiMessage.apiKey());
            requestHeaderData.setRequestApiKey(apiKey.id);
            // Placeholder correlation id; presumably a real one is assigned further along the pipeline — confirm.
            requestHeaderData.setCorrelationId(-1);
            if (!apiKey.isVersionSupported(requestHeaderData.requestApiVersion())) {
                throw new IllegalArgumentException("apiKey %s does not support version %d. the supported version range for this api key is %d...%d (inclusive).".formatted(
                        apiKey, requestHeaderData.requestApiVersion(), apiKey.oldestVersion(), apiKey.latestVersion()));
            }
            // A PRODUCE with acks=0 gets no response, so its promise is completed from the write outcome below.
            boolean hasResponse = !(apiKey == ApiKeys.PRODUCE && ((ProduceRequestData) apiMessage).acks() == 0);
            CompletableFuture<M> filterPromise = FilterHandler.this.promiseFactory.newTimeLimitedPromise(
                    () -> "Asynchronous %s request made by filter %s failed to complete within %s ms.".formatted(
                            apiKey, FilterHandler.this.filterDescriptor(), FilterHandler.this.timeoutMs));
            InternalRequestFrame<?> frame = new InternalRequestFrame<>(requestHeaderData.requestApiVersion(), requestHeaderData.correlationId(), hasResponse,
                    FilterHandler.this.filter, filterPromise, requestHeaderData, apiMessage);
            if (LOGGER.isDebugEnabled()) {
                LOGGER.debug("{}: Sending request: {}", FilterHandler.this.channelDescriptor(), frame);
            }
            ChannelPromise writePromise = FilterHandler.this.ctx.channel().newPromise();
            FilterHandler.this.ctx.writeAndFlush(frame, writePromise);
            if (!hasResponse) {
                writePromise.addListener(future -> {
                    if (future.isSuccess()) {
                        filterPromise.complete(null);
                    }
                    else {
                        filterPromise.completeExceptionally(future.cause());
                    }
                });
            }
            return filterPromise.minimalCompletionStage();
        }
    }

    /**
     * @param filterAndInvoker the filter to apply and the invoker used to dispatch frames to it
     * @param timeoutMs        time limit in milliseconds, validated by {@code Assertions.requireStrictlyPositive}
     * @param sniHostname      SNI hostname exposed to the filter; may be {@code null}
     * @param virtualCluster   virtual cluster whose name is exposed to the filter
     * @param inboundChannel   channel whose auto-read is toggled around deferred filter work
     */
    public FilterHandler(FilterAndInvoker filterAndInvoker, long timeoutMs, String sniHostname, VirtualCluster virtualCluster, Channel inboundChannel) {
        this.filter = Objects.requireNonNull(filterAndInvoker).filter();
        this.invoker = filterAndInvoker.invoker();
        this.timeoutMs = Assertions.requireStrictlyPositive(timeoutMs, "timeout");
        this.sniHostname = sniHostname;
        this.virtualCluster = virtualCluster;
        this.inboundChannel = inboundChannel;
    }

    /** Returns {@code SimpleClassName@identityHash} of the filter, for log and error messages. */
    String filterDescriptor() {
        return this.filter.getClass().getSimpleName() + "@" + System.identityHashCode(this.filter);
    }

    /** Captures the context, creates the promise factory on the channel's executor, then propagates the event. */
    @Override
    public void channelActive(ChannelHandlerContext context) throws Exception {
        this.ctx = context;
        this.promiseFactory = new PromiseFactory(context.executor(), this.timeoutMs, TimeUnit.MILLISECONDS, LOGGER.getName());
        super.channelActive(context);
    }

    /**
     * Handles an inbound frame (a response).
     * <ul>
     * <li>{@link InternalResponseFrame}s addressed to this filter complete the filter's own promise.</li>
     * <li>Other internal responses skip the read queue.</li>
     * <li>Decoded responses are dispatched to the filter in arrival order.</li>
     * <li>Opaque responses are passed on once every earlier response has been processed.</li>
     * </ul>
     *
     * @throws IllegalStateException for any other message type
     */
    @Override
    public void channelRead(ChannelHandlerContext context, Object msg) {
        if (msg instanceof InternalResponseFrame) {
            InternalResponseFrame<?> internalFrame = (InternalResponseFrame<?>) msg;
            if (internalFrame.isRecipient(this.filter)) {
                completeInternalResponse(internalFrame);
            }
            else {
                // Jump the queue: responses to another filter's requests must flow back to their sender.
                readDecodedResponse(internalFrame);
            }
        }
        else if (msg instanceof DecodedResponseFrame) {
            DecodedResponseFrame<?> decodedFrame = (DecodedResponseFrame<?>) msg;
            if (this.readFuture.isDone()) {
                this.readFuture = readDecodedResponse(decodedFrame);
            }
            else {
                this.readFuture = this.readFuture
                        .thenCompose(ignored -> context.channel().isOpen()
                                ? readDecodedResponse(decodedFrame)
                                : CompletableFuture.<Void> completedFuture(null))
                        // Keep the chain alive so later responses are still processed.
                        .exceptionally(error -> null);
            }
        }
        else if (msg instanceof OpaqueResponseFrame) {
            this.readFuture = this.readFuture.whenComplete((ignored, error) -> {
                if (context.channel().isOpen()) {
                    context.fireChannelRead(msg);
                }
            });
        }
        else {
            throw new IllegalStateException("Unexpected message reading from upstream:  " + msg);
        }
    }

    /**
     * Handles an outbound frame (a request).
     * <ul>
     * <li>{@link InternalRequestFrame}s skip the write queue.</li>
     * <li>Decoded requests are dispatched to the filter in order.</li>
     * <li>Opaque requests and {@link Unpooled#EMPTY_BUFFER} are written once every earlier request has
     * been processed.</li>
     * </ul>
     *
     * @throws IllegalStateException for any other message type
     */
    @Override
    public void write(ChannelHandlerContext context, Object msg, ChannelPromise promise) throws Exception {
        if (msg instanceof InternalRequestFrame) {
            writeDecodedRequest((DecodedRequestFrame<?>) msg, promise);
        }
        else if (msg instanceof DecodedRequestFrame) {
            DecodedRequestFrame<?> decodedFrame = (DecodedRequestFrame<?>) msg;
            if (this.writeFuture.isDone()) {
                this.writeFuture = writeDecodedRequest(decodedFrame, promise);
            }
            else {
                this.writeFuture = this.writeFuture
                        .thenCompose(ignored -> context.channel().isOpen()
                                ? writeDecodedRequest(decodedFrame, promise)
                                : CompletableFuture.<Void> completedFuture(null))
                        // Keep the chain alive so later requests are still processed.
                        .exceptionally(error -> null);
            }
        }
        else if (msg instanceof OpaqueRequestFrame || msg == Unpooled.EMPTY_BUFFER) {
            // Chain into writeFuture (as channelRead does for readFuture). Otherwise a decoded request
            // queued after this frame could run first, because CompletableFuture does not order the
            // dependents of a single future.
            this.writeFuture = this.writeFuture.whenComplete((ignored, error) -> {
                if (context.channel().isOpen()) {
                    context.write(msg, promise);
                }
            });
        }
        else {
            throw new IllegalStateException("Unexpected message writing to upstream: " + msg);
        }
    }

    /**
     * Runs one response through the filter. If the filter's stage was not already complete, the work
     * is treated as deferred (see {@link #handleDeferredStage}).
     */
    private CompletableFuture<Void> readDecodedResponse(DecodedResponseFrame<?> decodedResponseFrame) {
        CompletableFuture<ResponseFilterResult> filterStage = dispatchDecodedResponseFrame(decodedResponseFrame, new InternalFilterContext(decodedResponseFrame));
        if (filterStage.isDone()) {
            return configureResponseFilterChain(decodedResponseFrame, filterStage).thenApply(ignored -> null);
        }
        return configureResponseFilterChain(decodedResponseFrame, handleDeferredStage(decodedResponseFrame, filterStage))
                .whenComplete(this::deferredResponseCompleted)
                .thenApply(ignored -> null);
    }

    /** Invokes the filter's response handler for the frame. */
    private CompletableFuture<ResponseFilterResult> dispatchDecodedResponseFrame(DecodedResponseFrame<?> decodedResponseFrame, InternalFilterContext filterContext) {
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("{}: Dispatching upstream {} response to filter {}: {}", channelDescriptor(), decodedResponseFrame.apiKey(), filterDescriptor(), decodedResponseFrame);
        }
        CompletionStage<ResponseFilterResult> stage = this.invoker.onResponse(decodedResponseFrame.apiKey(), decodedResponseFrame.apiVersion(),
                decodedResponseFrame.header(), decodedResponseFrame.body(), filterContext);
        return toCompletableFuture(stage);
    }

    /** Validates and applies the filter's response result; any failure closes the connection. */
    private CompletableFuture<ResponseFilterResult> configureResponseFilterChain(DecodedResponseFrame<?> decodedResponseFrame, CompletableFuture<ResponseFilterResult> filterStage) {
        return filterStage
                .thenApply(FilterHandler::validateFilterResultNonNull)
                .thenApply(result -> handleResponseFilterResult(decodedResponseFrame, result))
                .exceptionally(error -> handleFilteringException(error, decodedResponseFrame));
    }

    /**
     * Runs one request through the filter. If the filter's stage was not already complete, the work
     * is treated as deferred (see {@link #handleDeferredStage}).
     */
    private CompletableFuture<Void> writeDecodedRequest(DecodedRequestFrame<?> decodedRequestFrame, ChannelPromise channelPromise) {
        CompletableFuture<RequestFilterResult> filterStage = dispatchDecodedRequest(decodedRequestFrame, new InternalFilterContext(decodedRequestFrame));
        if (filterStage.isDone()) {
            return configureRequestFilterChain(decodedRequestFrame, channelPromise, filterStage).thenApply(ignored -> null);
        }
        return configureRequestFilterChain(decodedRequestFrame, channelPromise, handleDeferredStage(decodedRequestFrame, filterStage))
                .whenComplete(this::deferredRequestCompleted)
                .thenApply(ignored -> null);
    }

    /** Invokes the filter's request handler for the frame. */
    private CompletableFuture<RequestFilterResult> dispatchDecodedRequest(DecodedRequestFrame<?> decodedRequestFrame, InternalFilterContext filterContext) {
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("{}: Dispatching downstream {} request to filter {}: {}", channelDescriptor(), decodedRequestFrame.apiKey(), filterDescriptor(), decodedRequestFrame);
        }
        CompletionStage<RequestFilterResult> stage = this.invoker.onRequest(decodedRequestFrame.apiKey(), decodedRequestFrame.apiVersion(),
                decodedRequestFrame.header(), decodedRequestFrame.body(), filterContext);
        return toCompletableFuture(stage);
    }

    /**
     * Converts a filter's stage to a future. An {@link InternalCompletionStage} is unwrapped to its
     * underlying future; any other stage goes through {@link CompletionStage#toCompletableFuture()}.
     */
    @SuppressWarnings({ "unchecked", "rawtypes" })
    private static <T> CompletableFuture<T> toCompletableFuture(CompletionStage<T> stage) {
        if (stage instanceof InternalCompletionStage) {
            return ((InternalCompletionStage) stage).getUnderlyingCompletableFuture();
        }
        return stage.toCompletableFuture();
    }

    /** Validates and applies the filter's request result; any failure closes the connection. */
    private CompletableFuture<RequestFilterResult> configureRequestFilterChain(DecodedRequestFrame<?> decodedRequestFrame, ChannelPromise channelPromise,
                                                                              CompletableFuture<RequestFilterResult> filterStage) {
        return filterStage
                .thenApply(FilterHandler::validateFilterResultNonNull)
                .thenApply(result -> handleRequestFilterResult(decodedRequestFrame, channelPromise, result))
                .exceptionally(error -> handleFilteringException(error, decodedRequestFrame));
    }

    /**
     * Applies a response result: drop it, or forward the (possibly replaced) header with its message,
     * then close the connection if requested. Dropping returns before the close check.
     */
    private ResponseFilterResult handleResponseFilterResult(DecodedResponseFrame<?> decodedResponseFrame, ResponseFilterResult responseFilterResult) {
        if (responseFilterResult.drop()) {
            if (LOGGER.isDebugEnabled()) {
                LOGGER.debug("{}: Filter {} drops {} response", channelDescriptor(), filterDescriptor(), decodedResponseFrame.apiKey());
            }
            return responseFilterResult;
        }
        if (responseFilterResult.message() != null) {
            ResponseHeaderData header = responseFilterResult.header() == null
                    ? decodedResponseFrame.header()
                    : (ResponseHeaderData) responseFilterResult.header();
            forwardResponse(decodedResponseFrame, header, responseFilterResult.message());
        }
        if (responseFilterResult.closeConnection()) {
            closeConnection();
        }
        return responseFilterResult;
    }

    /**
     * Applies a request result: drop it, short-circuit with a response, or forward the request, then
     * close the connection if requested. Any pending write is flushed before closing.
     */
    private RequestFilterResult handleRequestFilterResult(DecodedRequestFrame<?> decodedRequestFrame, ChannelPromise channelPromise, RequestFilterResult requestFilterResult) {
        if (requestFilterResult.drop()) {
            if (LOGGER.isDebugEnabled()) {
                LOGGER.debug("{}: Filter {} drops {} request", channelDescriptor(), filterDescriptor(), decodedRequestFrame.apiKey());
            }
            return requestFilterResult;
        }
        if (requestFilterResult.message() != null) {
            if (requestFilterResult.shortCircuitResponse()) {
                forwardShortCircuitResponse(decodedRequestFrame, requestFilterResult);
            }
            else {
                forwardRequest(decodedRequestFrame, requestFilterResult, channelPromise);
            }
        }
        if (requestFilterResult.closeConnection()) {
            if (requestFilterResult.message() != null) {
                this.ctx.flush();
            }
            closeConnection();
        }
        return requestFilterResult;
    }

    /**
     * Logs a filter failure at WARN (with the stack trace only when DEBUG is enabled) and closes the
     * connection.
     *
     * @return always {@code null}, so the result can be used from {@code exceptionally}
     */
    private <F extends FilterResult> F handleFilteringException(Throwable th, DecodedFrame<?, ?> decodedFrame) {
        if (LOGGER.isWarnEnabled()) {
            String direction = decodedFrame.header() instanceof RequestHeaderData ? "request" : "response";
            // Argument order must match the placeholders: channel, filter, direction, api key, cause.
            LOGGER.atWarn()
                    .setMessage("{}: Filter {} for {} {} ended exceptionally - closing connection. Cause message {}")
                    .addArgument(channelDescriptor())
                    .addArgument(filterDescriptor())
                    .addArgument(direction)
                    .addArgument(decodedFrame.apiKey())
                    .addArgument(th.getMessage())
                    .setCause(LOGGER.isDebugEnabled() ? th : null)
                    .log();
        }
        closeConnection();
        return null;
    }

    /**
     * Prepares a filter stage that did not complete synchronously. It does three things:
     * <ul>
     * <li>suspends auto-read on {@link #inboundChannel};</li>
     * <li>arms the time limit through {@code PromiseFactory#wrapWithTimeLimit} (presumably failing the
     * future on expiry — confirm);</li>
     * <li>returns a stage that completes on this handler's executor.</li>
     * </ul>
     */
    private <F extends FilterResult> CompletableFuture<F> handleDeferredStage(DecodedFrame<?, ?> decodedFrame, CompletableFuture<F> filterStage) {
        this.inboundChannel.config().setAutoRead(false);
        this.promiseFactory.wrapWithTimeLimit(filterStage,
                () -> "Deferred work for filter %s did not complete processing within %s ms %s %s".formatted(
                        filterDescriptor(), this.timeoutMs, decodedFrame instanceof DecodedRequestFrame ? "request" : "response", decodedFrame.apiKey()));
        return filterStage.thenApplyAsync(filterResult -> filterResult, this.ctx.executor());
    }

    /** Re-enables auto-read after a deferred response and flushes the inbound channel once queued responses finish. */
    private void deferredResponseCompleted(ResponseFilterResult responseFilterResult, Throwable th) {
        this.inboundChannel.config().setAutoRead(true);
        this.readFuture.whenComplete((ignored, error) -> this.inboundChannel.flush());
    }

    /** Re-enables auto-read after a deferred request and flushes now and again once queued requests finish. */
    private void deferredRequestCompleted(RequestFilterResult requestFilterResult, Throwable th) {
        this.inboundChannel.config().setAutoRead(true);
        this.ctx.flush();
        this.writeFuture.whenComplete((ignored, error) -> this.ctx.flush());
        this.inboundChannel.flush();
    }

    /**
     * Writes the original request frame onward. The filter must have returned the frame's own header
     * (or none) and its own body instance.
     *
     * @throws IllegalStateException if the result's header or body is not the frame's own instance
     * @throws AssertionError        if the message is not a {@code *RequestData}
     */
    private void forwardRequest(DecodedRequestFrame<?> decodedRequestFrame, RequestFilterResult requestFilterResult, ChannelPromise channelPromise) {
        RequestHeaderData header = requestFilterResult.header() == null ? decodedRequestFrame.header() : requestFilterResult.header();
        ApiMessage message = requestFilterResult.message();
        if (decodedRequestFrame.body() != message) {
            throw new IllegalStateException("Filter " + filterDescriptor() + " forwarded a request body that is not the body of the frame being filtered");
        }
        if (decodedRequestFrame.header() != header) {
            throw new IllegalStateException("Filter " + filterDescriptor() + " forwarded a request header that is not the header of the frame being filtered");
        }
        String name = message.getClass().getName();
        if (!name.endsWith("RequestData")) {
            throw new AssertionError("Attempt to use forwardRequest with a non-request: " + name);
        }
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("{}: Forwarding request: {}", channelDescriptor(), decodedRequestFrame);
        }
        this.ctx.write(decodedRequestFrame, channelPromise);
    }

    /**
     * Sends a response up the read path. There are two cases:
     * <ul>
     * <li>{@code decodedFrame} is a response: it is forwarded as-is, and the header and body must be
     * its own instances.</li>
     * <li>{@code decodedFrame} is a request (short-circuit): a new response frame is built, takes over
     * the request's buffers, and is fired followed by a read-complete event.</li>
     * </ul>
     */
    private void forwardResponse(DecodedFrame<?, ?> decodedFrame, ResponseHeaderData responseHeaderData, ApiMessage apiMessage) {
        String name = apiMessage.getClass().getName();
        if (!name.endsWith("ResponseData")) {
            throw new AssertionError("Attempt to use forwardResponse with a non-response: " + name);
        }
        if (!(decodedFrame instanceof RequestFrame)) {
            if (decodedFrame.body() != apiMessage) {
                throw new AssertionError();
            }
            if (decodedFrame.header() != responseHeaderData) {
                throw new AssertionError();
            }
            if (LOGGER.isDebugEnabled()) {
                LOGGER.debug("{}: Forwarding response: {}", channelDescriptor(), decodedFrame);
            }
            this.ctx.fireChannelRead(decodedFrame);
            return;
        }
        if (apiMessage.apiKey() != decodedFrame.apiKey().id) {
            throw new AssertionError("Attempt to respond with ApiMessage of type " + ApiKeys.forId(apiMessage.apiKey()) + " but request is of type " + decodedFrame.apiKey());
        }
        DecodedResponseFrame<?> responseFrame = new DecodedResponseFrame<>(decodedFrame.apiVersion(), decodedFrame.correlationId(), responseHeaderData, apiMessage);
        decodedFrame.transferBuffersTo(responseFrame);
        if (LOGGER.isDebugEnabled()) {
            // Log the frame actually being forwarded, not the request it answers.
            LOGGER.debug("{}: Forwarding response: {}", channelDescriptor(), responseFrame);
        }
        this.ctx.fireChannelRead(responseFrame);
        this.ctx.fireChannelReadComplete();
    }

    /**
     * Answers a request directly with the filter's response. The response header's correlation id is
     * set to the request's. If the api key has no response in the protocol, the response is discarded
     * and a debug message is logged.
     */
    private void forwardShortCircuitResponse(DecodedRequestFrame<?> decodedRequestFrame, RequestFilterResult requestFilterResult) {
        if (!decodedRequestFrame.hasResponse()) {
            if (LOGGER.isDebugEnabled()) {
                LOGGER.debug("{}: Filter {} attempted to short-circuit respond to a message with apiKey {} that has no response in the Kafka Protocol, dropping response",
                        channelDescriptor(), filterDescriptor(), decodedRequestFrame.apiKey());
            }
            return;
        }
        ResponseHeaderData responseHeader = requestFilterResult.header() == null
                ? new ResponseHeaderData()
                : (ResponseHeaderData) requestFilterResult.header();
        responseHeader.setCorrelationId(decodedRequestFrame.correlationId());
        forwardResponse(decodedRequestFrame, responseHeader, requestFilterResult.message());
    }

    /** Closes the channel through this handler's context and logs at DEBUG once the close completes. */
    private void closeConnection() {
        this.ctx.close().addListener(future -> {
            if (LOGGER.isDebugEnabled()) {
                LOGGER.debug("{}: Channel closed", channelDescriptor());
            }
        });
    }

    /** Returns the channel's {@code toString()}, used as the prefix of log messages. */
    private String channelDescriptor() {
        return this.ctx.channel().toString();
    }

    /** Completes the sending filter's promise with the body of the response to its own request. */
    private void completeInternalResponse(InternalResponseFrame<?> internalResponseFrame) {
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("{}: Completing {} response for request sent by this filter {}: {}", channelDescriptor(), internalResponseFrame.apiKey(), filterDescriptor(),
                    internalResponseFrame);
        }
        internalResponseFrame.promise().complete(internalResponseFrame.body());
    }

    /**
     * Rejects a {@code null} filter result.
     *
     * @throws IllegalStateException if {@code result} is {@code null}
     */
    private static <F extends FilterResult> F validateFilterResultNonNull(F result) {
        if (result == null) {
            throw new IllegalStateException("filter completion must not yield a null result");
        }
        return result;
    }
}
