/*
 * Decompiled with CFR 0.152.
 */
package io.kroxylicious.proxy.internal;

import edu.umd.cs.findbugs.annotations.NonNull;
import io.kroxylicious.proxy.filter.Filter;
import io.kroxylicious.proxy.filter.FilterAndInvoker;
import io.kroxylicious.proxy.filter.FilterContext;
import io.kroxylicious.proxy.filter.FilterInvoker;
import io.kroxylicious.proxy.filter.FilterResult;
import io.kroxylicious.proxy.filter.RequestFilterResult;
import io.kroxylicious.proxy.filter.RequestFilterResultBuilder;
import io.kroxylicious.proxy.filter.ResponseFilterResult;
import io.kroxylicious.proxy.filter.ResponseFilterResultBuilder;
import io.kroxylicious.proxy.frame.DecodedFrame;
import io.kroxylicious.proxy.frame.DecodedRequestFrame;
import io.kroxylicious.proxy.frame.DecodedResponseFrame;
import io.kroxylicious.proxy.frame.OpaqueRequestFrame;
import io.kroxylicious.proxy.frame.OpaqueResponseFrame;
import io.kroxylicious.proxy.frame.RequestFrame;
import io.kroxylicious.proxy.internal.InternalCompletionStage;
import io.kroxylicious.proxy.internal.InternalRequestFrame;
import io.kroxylicious.proxy.internal.InternalResponseFrame;
import io.kroxylicious.proxy.internal.PromiseFactory;
import io.kroxylicious.proxy.internal.filter.RequestFilterResultBuilderImpl;
import io.kroxylicious.proxy.internal.filter.ResponseFilterResultBuilderImpl;
import io.kroxylicious.proxy.internal.util.Assertions;
import io.kroxylicious.proxy.internal.util.ByteBufOutputStream;
import io.kroxylicious.proxy.model.VirtualCluster;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelDuplexHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import java.util.Objects;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.Executor;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nullable;
import org.apache.kafka.common.message.ProduceRequestData;
import org.apache.kafka.common.message.RequestHeaderData;
import org.apache.kafka.common.message.ResponseHeaderData;
import org.apache.kafka.common.protocol.ApiKeys;
import org.apache.kafka.common.protocol.ApiMessage;
import org.apache.kafka.common.utils.ByteBufferOutputStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Netty duplex handler that plugs a single {@link Filter} into a channel pipeline.
 *
 * <p>Outbound {@code write} calls carry request frames and are dispatched to the filter through
 * {@link FilterInvoker#onRequest}; inbound {@code channelRead} calls carry response frames and are
 * dispatched through {@link FilterInvoker#onResponse}. A filter may answer asynchronously: when the
 * returned stage is not yet done, auto-read on the inbound channel is switched off until the stage
 * completes, and later frames in the same direction are chained behind it via {@link #writeFuture}
 * or {@link #readFuture} so they are processed in arrival order.
 *
 * <p>The handler keeps unsynchronized mutable state ({@code writeFuture}, {@code readFuture}); it
 * presumably relies on Netty invoking it only from the channel's event loop - confirm before using
 * it from other threads.
 */
public class FilterHandler extends ChannelDuplexHandler {
    private static final Logger LOGGER = LoggerFactory.getLogger(FilterHandler.class);
    private final Filter filter;
    private final FilterInvoker invoker;
    private final long timeoutMs;
    private final String sniHostname;
    private final VirtualCluster virtualCluster;
    private final Channel inboundChannel;
    /** Tail of the chain of in-flight request processing; new requests are appended behind it. */
    private CompletableFuture<Void> writeFuture = CompletableFuture.completedFuture(null);
    /** Tail of the chain of in-flight response processing; new responses are appended behind it. */
    private CompletableFuture<Void> readFuture = CompletableFuture.completedFuture(null);
    /** Set in {@link #channelActive}; {@code null} before the channel becomes active. */
    private ChannelHandlerContext ctx;
    /** Set in {@link #channelActive}; {@code null} before the channel becomes active. */
    private PromiseFactory promiseFactory;

    /**
     * Creates a handler for one filter.
     *
     * @param filterAndInvoker the filter and the invoker used to call it; must not be {@code null}
     * @param timeoutMs time limit, in milliseconds, applied to deferred filter work and to requests
     *        sent by the filter; must be strictly positive
     * @param sniHostname value exposed to the filter through {@link FilterContext#sniHostname()}
     * @param virtualCluster the virtual cluster whose name is exposed to the filter
     * @param inboundChannel the channel whose auto-read is paused while filter work is deferred
     */
    public FilterHandler(FilterAndInvoker filterAndInvoker, long timeoutMs, String sniHostname, VirtualCluster virtualCluster, Channel inboundChannel) {
        this.filter = Objects.requireNonNull(filterAndInvoker).filter();
        this.invoker = filterAndInvoker.invoker();
        this.timeoutMs = Assertions.requireStrictlyPositive(timeoutMs, "timeout");
        this.sniHostname = sniHostname;
        this.virtualCluster = virtualCluster;
        this.inboundChannel = inboundChannel;
    }

    /**
     * Returns a short identifier for the filter used in log messages: the filter's simple class
     * name followed by {@code @} and its identity hash code.
     */
    String filterDescriptor() {
        return filter.getClass().getSimpleName() + "@" + System.identityHashCode(filter);
    }

    /**
     * Captures the handler context and builds the {@link PromiseFactory} on the channel's executor
     * before propagating the event.
     */
    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        this.ctx = ctx;
        this.promiseFactory = new PromiseFactory(ctx.executor(), timeoutMs, TimeUnit.MILLISECONDS, LOGGER.getName());
        super.channelActive(ctx);
    }

    /**
     * Handles a frame travelling from upstream towards the client.
     *
     * <ul>
     * <li>An {@link InternalResponseFrame} addressed to this filter completes the promise of the
     * request the filter sent; one addressed to another filter is dispatched immediately, without
     * queueing behind {@link #readFuture}.</li>
     * <li>A {@link DecodedResponseFrame} is dispatched to the filter, queued behind any response
     * still being processed.</li>
     * <li>An {@link OpaqueResponseFrame} is passed on untouched, also queued behind pending work.</li>
     * </ul>
     *
     * @throws IllegalStateException if {@code msg} is none of the above
     */
    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) {
        if (msg instanceof InternalResponseFrame) {
            InternalResponseFrame<?> internalFrame = (InternalResponseFrame<?>) msg;
            if (internalFrame.isRecipient(filter)) {
                completeInternalResponse(internalFrame);
            }
            else {
                readDecodedResponse(internalFrame);
            }
        }
        else if (msg instanceof DecodedResponseFrame) {
            DecodedResponseFrame<?> decodedFrame = (DecodedResponseFrame<?>) msg;
            if (readFuture.isDone()) {
                readFuture = readDecodedResponse(decodedFrame);
            }
            else {
                readFuture = readFuture.thenCompose(ignored -> {
                    // skip the filter entirely if the channel closed while this frame was queued
                    if (ctx.channel().isOpen()) {
                        return readDecodedResponse(decodedFrame);
                    }
                    return CompletableFuture.<Void> completedFuture(null);
                }).exceptionally(throwable -> null);
            }
        }
        else if (msg instanceof OpaqueResponseFrame) {
            readFuture = readFuture.whenComplete((unused, throwable) -> {
                if (ctx.channel().isOpen()) {
                    ctx.fireChannelRead(msg);
                }
            });
        }
        else {
            throw new IllegalStateException("Unexpected message reading from upstream:  " + msg);
        }
    }

    /**
     * Handles a frame travelling from the client towards upstream.
     *
     * <ul>
     * <li>An {@link InternalRequestFrame} is dispatched immediately, without queueing behind
     * {@link #writeFuture}.</li>
     * <li>A {@link DecodedRequestFrame} is dispatched to the filter, queued behind any request
     * still being processed.</li>
     * <li>An {@link OpaqueRequestFrame} (or {@link Unpooled#EMPTY_BUFFER}) is written on untouched,
     * also queued behind pending work.</li>
     * </ul>
     *
     * @throws IllegalStateException if {@code msg} is none of the above
     */
    @Override
    public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
        if (msg instanceof InternalRequestFrame) {
            writeDecodedRequest((InternalRequestFrame<?>) msg, promise);
        }
        else if (msg instanceof DecodedRequestFrame) {
            DecodedRequestFrame<?> decodedFrame = (DecodedRequestFrame<?>) msg;
            if (writeFuture.isDone()) {
                writeFuture = writeDecodedRequest(decodedFrame, promise);
            }
            else {
                writeFuture = writeFuture.thenCompose(ignored -> {
                    // skip the filter entirely if the channel closed while this frame was queued
                    if (ctx.channel().isOpen()) {
                        return writeDecodedRequest(decodedFrame, promise);
                    }
                    return CompletableFuture.<Void> completedFuture(null);
                }).exceptionally(throwable -> null);
            }
        }
        else {
            if (!(msg instanceof OpaqueRequestFrame) && msg != Unpooled.EMPTY_BUFFER) {
                throw new IllegalStateException("Unexpected message writing to upstream: " + msg);
            }
            // Re-assign the tail so later frames chain behind this write. Merely registering the
            // callback would make several queued frames sibling dependents of one future, and
            // CompletableFuture gives no guarantee that siblings run in registration order.
            writeFuture = writeFuture.whenComplete((unused, throwable) -> {
                if (ctx.channel().isOpen()) {
                    ctx.write(msg, promise);
                }
            });
        }
    }

    /**
     * Dispatches a response to the filter and wires up handling of its result.
     *
     * @return a future that completes once the filter result has been handled
     */
    private CompletableFuture<Void> readDecodedResponse(DecodedResponseFrame<?> decodedFrame) {
        InternalFilterContext filterContext = new InternalFilterContext(decodedFrame);
        CompletableFuture<ResponseFilterResult> future = dispatchDecodedResponseFrame(decodedFrame, filterContext);
        if (future.isDone()) {
            return configureResponseFilterChain(decodedFrame, future).thenApply(responseFilterResult -> null);
        }
        // the filter has not answered yet: apply the deferred-work handling, undone on completion
        return configureResponseFilterChain(decodedFrame, handleDeferredStage(decodedFrame, future))
                .whenComplete(this::deferredResponseCompleted)
                .thenApply(responseFilterResult -> null);
    }

    /** Invokes the filter's response handling and returns the resulting stage as a future. */
    private CompletableFuture<ResponseFilterResult> dispatchDecodedResponseFrame(DecodedResponseFrame<?> decodedFrame, InternalFilterContext filterContext) {
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("{}: Dispatching upstream {} response to filter {}: {}", channelDescriptor(), decodedFrame.apiKey(), filterDescriptor(), decodedFrame);
        }
        CompletionStage<ResponseFilterResult> stage = invoker.onResponse(decodedFrame.apiKey(), decodedFrame.apiVersion(),
                (ResponseHeaderData) decodedFrame.header(), (ApiMessage) decodedFrame.body(), filterContext);
        return unwrap(stage);
    }

    /**
     * Appends result validation, result handling and failure handling to the filter's future.
     * Any exception raised along the way is routed to {@link #handleFilteringException}.
     */
    private CompletableFuture<ResponseFilterResult> configureResponseFilterChain(DecodedResponseFrame<?> decodedFrame, CompletableFuture<ResponseFilterResult> future) {
        return future
                .thenApply(result -> handleResponseFilterResult(decodedFrame, validateFilterResultNonNull(result)))
                .exceptionally(t -> handleFilteringException(t, decodedFrame));
    }

    /**
     * Dispatches a request to the filter and wires up handling of its result.
     *
     * @return a future that completes once the filter result has been handled
     */
    private CompletableFuture<Void> writeDecodedRequest(DecodedRequestFrame<?> decodedFrame, ChannelPromise promise) {
        InternalFilterContext filterContext = new InternalFilterContext(decodedFrame);
        CompletableFuture<RequestFilterResult> future = dispatchDecodedRequest(decodedFrame, filterContext);
        if (future.isDone()) {
            return configureRequestFilterChain(decodedFrame, promise, future).thenApply(requestFilterResult -> null);
        }
        // the filter has not answered yet: apply the deferred-work handling, undone on completion
        return configureRequestFilterChain(decodedFrame, promise, handleDeferredStage(decodedFrame, future))
                .whenComplete(this::deferredRequestCompleted)
                .thenApply(requestFilterResult -> null);
    }

    /** Invokes the filter's request handling and returns the resulting stage as a future. */
    private CompletableFuture<RequestFilterResult> dispatchDecodedRequest(DecodedRequestFrame<?> decodedFrame, InternalFilterContext filterContext) {
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("{}: Dispatching downstream {} request to filter{}: {}", channelDescriptor(), decodedFrame.apiKey(), filterDescriptor(), decodedFrame);
        }
        CompletionStage<RequestFilterResult> stage = invoker.onRequest(decodedFrame.apiKey(), decodedFrame.apiVersion(),
                (RequestHeaderData) decodedFrame.header(), (ApiMessage) decodedFrame.body(), filterContext);
        return unwrap(stage);
    }

    /**
     * Appends result validation, result handling and failure handling to the filter's future.
     * Any exception raised along the way is routed to {@link #handleFilteringException}.
     */
    private CompletableFuture<RequestFilterResult> configureRequestFilterChain(DecodedRequestFrame<?> decodedFrame, ChannelPromise promise,
                                                                               CompletableFuture<RequestFilterResult> future) {
        return future
                .thenApply(result -> handleRequestFilterResult(decodedFrame, promise, validateFilterResultNonNull(result)))
                .exceptionally(t -> handleFilteringException(t, decodedFrame));
    }

    /**
     * Returns the {@link CompletableFuture} backing {@code stage}: the underlying future of an
     * {@link InternalCompletionStage}, otherwise whatever {@code toCompletableFuture()} yields.
     */
    private static <T> CompletableFuture<T> unwrap(CompletionStage<T> stage) {
        if (stage instanceof InternalCompletionStage) {
            return ((InternalCompletionStage<T>) stage).getUnderlyingCompletableFuture();
        }
        return stage.toCompletableFuture();
    }

    /**
     * Acts on a filter's response result: drops the response, or forwards its message (using the
     * frame's own header when the result carries none) and then closes the connection if asked to.
     *
     * @return {@code responseFilterResult}, unchanged
     */
    private ResponseFilterResult handleResponseFilterResult(DecodedResponseFrame<?> decodedFrame, ResponseFilterResult responseFilterResult) {
        if (responseFilterResult.drop()) {
            if (LOGGER.isDebugEnabled()) {
                LOGGER.debug("{}: Filter{} drops {} response", channelDescriptor(), filterDescriptor(), decodedFrame.apiKey());
            }
            return responseFilterResult;
        }
        if (responseFilterResult.message() != null) {
            ResponseHeaderData header = responseFilterResult.header() == null
                    ? (ResponseHeaderData) decodedFrame.header()
                    : (ResponseHeaderData) responseFilterResult.header();
            forwardResponse(decodedFrame, header, responseFilterResult.message());
        }
        if (responseFilterResult.closeConnection()) {
            closeConnection();
        }
        return responseFilterResult;
    }

    /**
     * Acts on a filter's request result: drops the request, or forwards its message either upstream
     * or back to the client as a short-circuit response, and then closes the connection if asked to.
     *
     * @return {@code requestFilterResult}, unchanged
     */
    private RequestFilterResult handleRequestFilterResult(DecodedRequestFrame<?> decodedFrame, ChannelPromise promise, RequestFilterResult requestFilterResult) {
        if (requestFilterResult.drop()) {
            if (LOGGER.isDebugEnabled()) {
                LOGGER.debug("{}: Filter{} drops {} request", channelDescriptor(), filterDescriptor(), decodedFrame.apiKey());
            }
            return requestFilterResult;
        }
        boolean hasMessage = requestFilterResult.message() != null;
        if (hasMessage) {
            if (requestFilterResult.shortCircuitResponse()) {
                forwardShortCircuitResponse(decodedFrame, requestFilterResult);
            }
            else {
                forwardRequest(decodedFrame, requestFilterResult, promise);
            }
        }
        if (requestFilterResult.closeConnection()) {
            if (hasMessage) {
                // flush before closing so the message just written is not left unflushed
                ctx.flush();
            }
            closeConnection();
        }
        return requestFilterResult;
    }

    /**
     * Logs a filter failure at WARN level (with the stack trace only when DEBUG is enabled) and
     * closes the connection.
     *
     * @return always {@code null}
     */
    private <F extends FilterResult> F handleFilteringException(Throwable t, DecodedFrame<?, ?> decodedFrame) {
        if (LOGGER.isWarnEnabled()) {
            String direction = decodedFrame.header() instanceof RequestHeaderData ? "request" : "response";
            // argument order follows the placeholders: channel, filter, api key, direction, cause
            LOGGER.atWarn()
                    .setMessage("{}: Filter{} for {} {} ended exceptionally - closing connection. Cause message {}")
                    .addArgument(channelDescriptor())
                    .addArgument(filterDescriptor())
                    .addArgument(decodedFrame.apiKey())
                    .addArgument(direction)
                    .addArgument(t.getMessage())
                    .setCause(LOGGER.isDebugEnabled() ? t : null)
                    .log();
        }
        closeConnection();
        return null;
    }

    /**
     * Prepares for a filter result that is not yet available: switches off auto-read on the inbound
     * channel, puts a time limit on {@code future}, and returns a dependent future whose completion
     * is dispatched on the channel's executor.
     */
    private <F extends FilterResult> CompletableFuture<F> handleDeferredStage(DecodedFrame<?, ?> decodedFrame, CompletableFuture<F> future) {
        inboundChannel.config().setAutoRead(false);
        promiseFactory.wrapWithTimeLimit(future,
                () -> "Deferred work for filter %s did not complete processing within %s ms %s %s".formatted(filterDescriptor(), timeoutMs,
                        decodedFrame instanceof DecodedRequestFrame ? "request" : "response", decodedFrame.apiKey()));
        return future.thenApplyAsync(filterResult -> filterResult, ctx.executor());
    }

    /** Re-enables auto-read after deferred response work and flushes inbound once pending reads finish. */
    private void deferredResponseCompleted(ResponseFilterResult ignored, Throwable throwable) {
        inboundChannel.config().setAutoRead(true);
        readFuture.whenComplete((u, t) -> inboundChannel.flush());
    }

    /**
     * Re-enables auto-read after deferred request work, flushes this context now and again once
     * pending writes finish, and flushes the inbound channel.
     */
    private void deferredRequestCompleted(RequestFilterResult ignored, Throwable throwable) {
        inboundChannel.config().setAutoRead(true);
        ctx.flush();
        writeFuture.whenComplete((u, t) -> ctx.flush());
        inboundChannel.flush();
    }

    /**
     * Writes the request frame onwards.
     *
     * @throws IllegalStateException if the result's message or header is not the very same object
     *         as the frame's body or header
     * @throws AssertionError if the message's class name does not end with {@code RequestData}
     */
    private void forwardRequest(DecodedRequestFrame<?> decodedFrame, RequestFilterResult requestFilterResult, ChannelPromise promise) {
        ApiMessage header = requestFilterResult.header() == null ? decodedFrame.header() : requestFilterResult.header();
        ApiMessage message = requestFilterResult.message();
        // identity checks: the frame itself is what gets written, so the result must refer to its parts
        if (decodedFrame.body() != message) {
            throw new IllegalStateException("Forwarded request message is not the body of the request frame");
        }
        if (decodedFrame.header() != header) {
            throw new IllegalStateException("Forwarded request header is not the header of the request frame");
        }
        String name = message.getClass().getName();
        if (!name.endsWith("RequestData")) {
            throw new AssertionError("Attempt to use forwardRequest with a non-request: " + name);
        }
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("{}: Forwarding request: {}", channelDescriptor(), decodedFrame);
        }
        ctx.write(decodedFrame, promise);
    }

    /**
     * Fires a response towards the client.
     *
     * <p>When {@code decodedFrame} is a request (a short-circuit response), a new
     * {@link DecodedResponseFrame} is built from the request's api version and correlation id, the
     * request's buffers are transferred to it, and a read-complete event follows the read.
     * Otherwise the response frame itself is fired.
     *
     * @throws AssertionError if the message's class name does not end with {@code ResponseData},
     *         if a short-circuit message's api key differs from the request's, or if a response
     *         frame's body or header is not the very same object as {@code message} or {@code header}
     */
    private void forwardResponse(DecodedFrame<?, ?> decodedFrame, ResponseHeaderData header, ApiMessage message) {
        String name = message.getClass().getName();
        if (!name.endsWith("ResponseData")) {
            throw new AssertionError("Attempt to use forwardResponse with a non-response: " + name);
        }
        if (decodedFrame instanceof RequestFrame) {
            if (message.apiKey() != decodedFrame.apiKey().id) {
                throw new AssertionError(
                        "Attempt to respond with ApiMessage of type " + ApiKeys.forId(message.apiKey()) + " but request is of type " + decodedFrame.apiKey());
            }
            DecodedResponseFrame<ApiMessage> responseFrame = new DecodedResponseFrame<ApiMessage>(decodedFrame.apiVersion(), decodedFrame.correlationId(), header, message);
            decodedFrame.transferBuffersTo(responseFrame);
            if (LOGGER.isDebugEnabled()) {
                // log the frame that is actually sent, not the request it answers
                LOGGER.debug("{}: Forwarding response: {}", channelDescriptor(), responseFrame);
            }
            ctx.fireChannelRead(responseFrame);
            ctx.fireChannelReadComplete();
        }
        else {
            if (decodedFrame.body() != message) {
                throw new AssertionError();
            }
            if (decodedFrame.header() != header) {
                throw new AssertionError();
            }
            if (LOGGER.isDebugEnabled()) {
                LOGGER.debug("{}: Forwarding response: {}", channelDescriptor(), decodedFrame);
            }
            ctx.fireChannelRead(decodedFrame);
        }
    }

    /**
     * Sends a filter-supplied response for {@code decodedFrame} back towards the client, stamping
     * the request's correlation id on the response header. If the request has no response, nothing
     * is sent and the attempt is only logged at DEBUG level.
     */
    private void forwardShortCircuitResponse(DecodedRequestFrame<?> decodedFrame, RequestFilterResult requestFilterResult) {
        if (decodedFrame.hasResponse()) {
            ResponseHeaderData header = requestFilterResult.header() == null ? new ResponseHeaderData() : (ResponseHeaderData) requestFilterResult.header();
            header.setCorrelationId(decodedFrame.correlationId());
            forwardResponse(decodedFrame, header, requestFilterResult.message());
        }
        else if (LOGGER.isDebugEnabled()) {
            LOGGER.debug(
                    "{}: Filter {} attempted to short-circuit respond to a message with apiKey {} that has no response in the Kafka Protocol, dropping response",
                    channelDescriptor(), filterDescriptor(), decodedFrame.apiKey());
        }
    }

    /** Closes this handler's context, logging at DEBUG level once the close operation finishes. */
    private void closeConnection() {
        ctx.close().addListener(future -> {
            if (LOGGER.isDebugEnabled()) {
                LOGGER.debug("{}: Channel closed", channelDescriptor());
            }
        });
    }

    /** Returns the string form of the channel, used as the prefix of log messages. */
    private String channelDescriptor() {
        return ctx.channel().toString();
    }

    /**
     * Completes the promise carried by a response to a request that this filter sent, using the
     * response body as the value.
     */
    private void completeInternalResponse(InternalResponseFrame<?> decodedFrame) {
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("{}: Completing {} response for request sent by this filter{}: {}", channelDescriptor(), decodedFrame.apiKey(), filterDescriptor(), decodedFrame);
        }
        CompletableFuture<?> wildcardPromise = decodedFrame.promise();
        // A wildcard-typed future cannot accept a value, so widen it. The promise's value type is
        // chosen by whoever called sendRequest and cannot be checked here.
        @SuppressWarnings("unchecked")
        CompletableFuture<Object> promise = (CompletableFuture<Object>) wildcardPromise;
        promise.complete(decodedFrame.body());
    }

    /**
     * Returns {@code f} if it is non-null.
     *
     * @throws IllegalStateException if {@code f} is {@code null}
     */
    private static <F extends FilterResult> F validateFilterResultNonNull(F f) {
        if (f == null) {
            throw new IllegalStateException("filter completion must not yield a null result");
        }
        return f;
    }

    /**
     * The {@link FilterContext} handed to the filter for one frame. It gives the filter access to
     * connection details, result builders, buffer allocation tied to the frame, and the ability to
     * send its own requests upstream.
     */
    private class InternalFilterContext implements FilterContext {
        private final DecodedFrame<?, ?> decodedFrame;

        InternalFilterContext(DecodedFrame<?, ?> decodedFrame) {
            this.decodedFrame = decodedFrame;
        }

        public String channelDescriptor() {
            return FilterHandler.this.channelDescriptor();
        }

        /**
         * Allocates an I/O buffer of the given initial capacity from the channel's allocator,
         * registers it with the frame being processed, and returns a stream writing into it.
         */
        public ByteBufferOutputStream createByteBufferOutputStream(int initialCapacity) {
            ByteBuf buffer = FilterHandler.this.ctx.alloc().ioBuffer(initialCapacity);
            decodedFrame.add(buffer);
            return new ByteBufOutputStream(buffer);
        }

        @Nullable
        public String sniHostname() {
            return FilterHandler.this.sniHostname;
        }

        public String getVirtualClusterName() {
            return FilterHandler.this.virtualCluster.getClusterName();
        }

        public RequestFilterResultBuilder requestFilterResultBuilder() {
            return new RequestFilterResultBuilderImpl();
        }

        public ResponseFilterResultBuilder responseFilterResultBuilder() {
            return new ResponseFilterResultBuilderImpl();
        }

        /** Convenience for an already-completed "forward this request" result. */
        public CompletionStage<RequestFilterResult> forwardRequest(RequestHeaderData header, ApiMessage request) {
            return requestFilterResultBuilder().forward(header, request).completed();
        }

        /** Convenience for an already-completed "forward this response" result. */
        public CompletionStage<ResponseFilterResult> forwardResponse(ResponseHeaderData header, ApiMessage response) {
            return responseFilterResultBuilder().forward(header, response).completed();
        }

        /**
         * Sends a request originated by the filter and returns a stage for its response.
         *
         * <p>The header's api key is overwritten to match {@code request} and its correlation id is
         * set to {@code -1}. The returned stage is time-limited by the handler's timeout. For a
         * PRODUCE request with {@code acks == 0} no response is expected, so the stage is completed
         * from the outcome of the channel write instead: with {@code null} on success, or
         * exceptionally with the write failure.
         *
         * @param header request header; its {@code requestApiVersion} must already be set
         * @param request request body
         * @throws NullPointerException if either argument is {@code null}
         * @throws IllegalArgumentException if the api key does not support the header's version
         */
        @NonNull
        public <M extends ApiMessage> CompletionStage<M> sendRequest(@NonNull RequestHeaderData header, @NonNull ApiMessage request) {
            Objects.requireNonNull(header);
            Objects.requireNonNull(request);
            ApiKeys apiKey = ApiKeys.forId(request.apiKey());
            header.setRequestApiKey(apiKey.id);
            header.setCorrelationId(-1);
            if (!apiKey.isVersionSupported(header.requestApiVersion())) {
                throw new IllegalArgumentException(
                        "apiKey %s does not support version %d. the supported version range for this api key is %d...%d (inclusive)."
                                .formatted(apiKey, header.requestApiVersion(), apiKey.oldestVersion(), apiKey.latestVersion()));
            }
            boolean hasResponse = apiKey != ApiKeys.PRODUCE || ((ProduceRequestData) request).acks() != 0;
            CompletableFuture<M> filterPromise = FilterHandler.this.promiseFactory.newTimeLimitedPromise(
                    () -> "Asynchronous %s request made by filter %s failed to complete within %s ms.".formatted(apiKey, FilterHandler.this.filterDescriptor(),
                            FilterHandler.this.timeoutMs));
            InternalRequestFrame<ApiMessage> frame = new InternalRequestFrame<ApiMessage>(header.requestApiVersion(), header.correlationId(), hasResponse,
                    FilterHandler.this.filter, filterPromise, header, request);
            if (LOGGER.isDebugEnabled()) {
                LOGGER.debug("{}: Sending request: {}", FilterHandler.this.channelDescriptor(), frame);
            }
            ChannelPromise writePromise = FilterHandler.this.ctx.channel().newPromise();
            FilterHandler.this.ctx.writeAndFlush(frame, writePromise);
            if (!hasResponse) {
                // no response will ever arrive, so the write outcome is the only completion signal
                writePromise.addListener(f -> {
                    if (f.isSuccess()) {
                        filterPromise.complete(null);
                    }
                    else {
                        filterPromise.completeExceptionally(f.cause());
                    }
                });
            }
            return filterPromise.minimalCompletionStage();
        }
    }
}

