package io.kroxylicious.proxy.internal;

import com.fasterxml.jackson.databind.ObjectMapper;
import io.kroxylicious.proxy.filter.FilterAndInvoker;
import io.kroxylicious.proxy.filter.NetFilter;
import io.kroxylicious.proxy.frame.DecodedRequestFrame;
import io.kroxylicious.proxy.frame.DecodedResponseFrame;
import io.kroxylicious.proxy.frame.RequestFrame;
import io.kroxylicious.proxy.internal.codec.CorrelationManager;
import io.kroxylicious.proxy.internal.codec.DecodePredicate;
import io.kroxylicious.proxy.internal.codec.FrameOversizedException;
import io.kroxylicious.proxy.internal.codec.KafkaRequestEncoder;
import io.kroxylicious.proxy.internal.codec.KafkaResponseDecoder;
import io.kroxylicious.proxy.model.VirtualCluster;
import io.kroxylicious.proxy.service.HostPort;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.handler.codec.DecoderException;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import io.netty.handler.logging.LoggingHandler;
import io.netty.handler.ssl.SniCompletionEvent;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.kafka.common.message.ApiVersionsRequestData;
import org.apache.kafka.common.message.ApiVersionsResponseData;
import org.apache.kafka.common.message.ApiVersionsResponseDataJsonConverter;
import org.apache.kafka.common.message.ResponseHeaderData;
import org.apache.kafka.common.protocol.ApiKeys;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Netty handler for the client-facing (downstream) side of a proxied Kafka connection.
 *
 * <p>Until an upstream connection is active, the handler reads from the client one read at a time:
 * auto-read is disabled in {@link #channelActive}.
 * <ul>
 *   <li>An optional initial {@code HAProxyMessage} is recorded.</li>
 *   <li>An initial ApiVersions request is either answered locally (when authentication offload is
 *       enabled) or forwarded.</li>
 *   <li>The first request frame to be forwarded triggers {@code NetFilter.selectServer(this)}.</li>
 * </ul>
 * Request frames that arrive before the upstream channel is active are buffered. They are replayed
 * by {@link #outboundChannelActive}. After that, frames are forwarded directly. Inbound auto-read is
 * switched off while the upstream channel is not writable, which applies back-pressure.</p>
 */
public class KafkaProxyFrontendHandler extends ChannelInboundHandlerAdapter implements NetFilter.NetFilterContext {
    private static final Logger LOGGER = LoggerFactory.getLogger(KafkaProxyFrontendHandler.class);

    /**
     * Canned ApiVersions response, loaded from the {@code /ApiVersions-3.2.json} classpath resource.
     * It is returned to clients when authentication offload is enabled.
     */
    private static final ApiVersionsResponseData API_VERSIONS_RESPONSE;

    static {
        ObjectMapper objectMapper = new ObjectMapper();
        // try-with-resources so the stream is closed even when parsing fails.
        try (InputStream resourceAsStream = KafkaProxyFrontendHandler.class.getResourceAsStream("/ApiVersions-3.2.json")) {
            if (resourceAsStream == null) {
                throw new IllegalStateException("Classpath resource /ApiVersions-3.2.json not found");
            }
            API_VERSIONS_RESPONSE = ApiVersionsResponseDataJsonConverter.read(objectMapper.readTree(resourceAsStream), (short) 3);
        }
        catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    private final boolean logNetwork;
    private final boolean logFrames;
    private final VirtualCluster virtualCluster;
    /** Context of the upstream (broker-facing) channel; {@code null} until {@link #outboundChannelActive}. */
    private ChannelHandlerContext outboundCtx;
    private KafkaProxyBackendHandler backendHandler;
    /** True when messages have been written to the outbound channel but not yet flushed. */
    private boolean pendingFlushes;
    private final NetFilter filter;
    private final SaslDecodePredicate dp;
    private AuthenticationEvent authentication;
    private String clientSoftwareName;
    private String clientSoftwareVersion;
    private String sniHostname;
    /** Context of the client-facing channel; assigned in {@link #channelActive}. */
    private ChannelHandlerContext inboundCtx;
    /** Messages received before the outbound channel became active; set to {@code null} once replayed. */
    private List<Object> bufferedMsgs = new ArrayList<>();
    /** True when a read-complete on the inbound side must be processed once the outbound becomes active. */
    private boolean pendingReadComplete = true;
    private State state = State.START;
    /** True while inbound auto-read is disabled because of outbound back-pressure (or before activation). */
    private boolean isInboundBlocked = true;
    private HAProxyMessage haProxyMessage;

    /** Lifecycle states of the frontend handler. */
    public enum State {
        /** Initial state; nothing has been read from the client yet. */
        START,
        /** An initial {@code HAProxyMessage} has been received. */
        HA_PROXY,
        /** An initial ApiVersions request has been received. */
        API_VERSIONS,
        /** Server selection has been requested; the upstream connection is not yet established. */
        CONNECTING,
        /** The upstream connect future completed successfully. */
        CONNECTED,
        /** The upstream channel is active and messages are forwarded directly. */
        OUTBOUND_ACTIVE,
        /** An illegal transition occurred or the upstream connection failed. */
        FAILED
    }

    /**
     * Creates a frontend handler.
     *
     * @param netFilter filter asked to select the upstream server
     * @param saslDecodePredicate decode predicate; its delegate is set once the upstream is connected
     * @param virtualCluster virtual cluster configuration (logging, TLS, frame-size settings)
     */
    public KafkaProxyFrontendHandler(NetFilter netFilter, SaslDecodePredicate saslDecodePredicate, VirtualCluster virtualCluster) {
        this.filter = netFilter;
        this.dp = saslDecodePredicate;
        this.virtualCluster = virtualCluster;
        this.logNetwork = virtualCluster.isLogNetwork();
        this.logFrames = virtualCluster.isLogFrames();
    }

    /**
     * Builds (but does not throw) an {@link IllegalStateException} that records the current state.
     * As a side effect, it moves this handler to {@link State#FAILED}.
     *
     * @param str optional message prefix, may be {@code null}
     * @return the exception for the caller to throw
     */
    private IllegalStateException illegalState(String str) {
        String name = this.state.name();
        this.state = State.FAILED;
        return new IllegalStateException((str == null ? "" : str + ", ") + "state=" + name);
    }

    /** @return the current lifecycle state */
    State state() {
        return this.state;
    }

    /**
     * Called once the upstream channel is active. It replays the buffered requests, processes any
     * deferred inbound read-complete, and switches to {@link State#OUTBOUND_ACTIVE}.
     *
     * @param channelHandlerContext context of the upstream (outbound) channel
     * @throws IllegalStateException if the handler is not in {@link State#CONNECTED}
     */
    public void outboundChannelActive(ChannelHandlerContext channelHandlerContext) {
        if (this.state != State.CONNECTED) {
            throw illegalState(null);
        }
        LOGGER.trace("{}: outboundChannelActive", this.inboundCtx.channel().id());
        this.outboundCtx = channelHandlerContext;
        // Replay the requests that arrived while connecting, in arrival order.
        for (Object bufferedMsg : this.bufferedMsgs) {
            forwardOutbound(this.inboundCtx, bufferedMsg);
        }
        this.bufferedMsgs = null;
        if (this.pendingReadComplete) {
            this.pendingReadComplete = false;
            // channelReadComplete() flushes the outbound channel and may toggle auto-read on the channel
            // of the context it is given. That must be the *inbound* channel, never the upstream one.
            channelReadComplete(this.inboundCtx);
        }
        this.state = State.OUTBOUND_ACTIVE;
        // Only start auto-reading from the client if the upstream can accept more data. Otherwise stay
        // blocked and let outboundWritabilityChanged() resume reads once the upstream drains.
        if (channelHandlerContext.channel().isWritable()) {
            this.isInboundBlocked = false;
            this.inboundCtx.channel().config().setAutoRead(true);
        }
        else {
            this.isInboundBlocked = true;
            this.inboundCtx.channel().config().setAutoRead(false);
        }
    }

    /** Propagates inbound writability changes to the backend handler, once one exists. */
    @Override
    public void channelWritabilityChanged(ChannelHandlerContext channelHandlerContext) throws Exception {
        super.channelWritabilityChanged(channelHandlerContext);
        if (this.backendHandler != null) {
            this.backendHandler.inboundChannelWritabilityChanged(channelHandlerContext);
        }
    }

    /** Forwards messages upstream once active; otherwise runs the pre-connection state machine. */
    @Override
    public void channelRead(ChannelHandlerContext channelHandlerContext, Object obj) throws Exception {
        if (this.state == State.OUTBOUND_ACTIVE) {
            forwardOutbound(channelHandlerContext, obj);
        }
        else {
            handlePreOutboundActive(channelHandlerContext, obj);
        }
    }

    /**
     * Handles a message received before the outbound channel is active.
     *
     * @throws IllegalStateException if the message is not expected in the current state
     */
    private void handlePreOutboundActive(ChannelHandlerContext channelHandlerContext, Object obj) {
        if (isInitialHaProxyMessage(obj)) {
            this.haProxyMessage = (HAProxyMessage) obj;
            this.state = State.HA_PROXY;
        }
        else if (isInitialDecodedApiVersionsFrame(obj)) {
            handleApiVersionsFrame(channelHandlerContext, obj);
        }
        else if (isInitialRequestFrame(obj)) {
            bufferMsgAndSelectServer(obj);
        }
        else {
            if (!isSubsequentRequestFrame(obj)) {
                throw illegalState("Unexpected channelRead() message of " + String.valueOf(obj.getClass()));
            }
            bufferMessage(obj);
        }
    }

    /**
     * Records the client software details from an initial ApiVersions request.
     * If authentication offload is disabled, the request is buffered and server selection starts.
     * If it is enabled, the request is answered locally and another read is requested from the client.
     */
    private void handleApiVersionsFrame(ChannelHandlerContext channelHandlerContext, Object obj) {
        this.state = State.API_VERSIONS;
        // Safe: isInitialDecodedApiVersionsFrame() checked that apiKey() == API_VERSIONS.
        @SuppressWarnings("unchecked")
        DecodedRequestFrame<ApiVersionsRequestData> decodedRequestFrame = (DecodedRequestFrame<ApiVersionsRequestData>) obj;
        storeApiVersionsFeatures(decodedRequestFrame);
        if (!this.dp.isAuthenticationOffloadEnabled()) {
            bufferMsgAndSelectServer(obj);
        }
        else {
            writeApiVersionsResponse(channelHandlerContext, decodedRequestFrame);
            // Auto-read is off at this stage, so explicitly ask for the next client message.
            channelHandlerContext.channel().read();
        }
    }

    private boolean isSubsequentRequestFrame(Object obj) {
        return (this.state == State.CONNECTING || this.state == State.CONNECTED) && (obj instanceof RequestFrame);
    }

    private boolean isInitialRequestFrame(Object obj) {
        return (this.state == State.START || this.state == State.HA_PROXY || this.state == State.API_VERSIONS) && (obj instanceof RequestFrame);
    }

    private boolean isInitialHaProxyMessage(Object obj) {
        return this.state == State.START && (obj instanceof HAProxyMessage);
    }

    private boolean isInitialDecodedApiVersionsFrame(Object obj) {
        return (this.state == State.START || this.state == State.HA_PROXY)
                && (obj instanceof DecodedRequestFrame)
                && ((DecodedRequestFrame<?>) obj).apiKey() == ApiKeys.API_VERSIONS;
    }

    /** Moves to {@link State#CONNECTING}, buffers the message and asks the filter to select a server. */
    private void bufferMsgAndSelectServer(Object obj) {
        this.state = State.CONNECTING;
        bufferMessage(obj);
        this.filter.selectServer(this);
    }

    private void bufferMessage(Object obj) {
        this.bufferedMsgs.add(obj);
    }

    /**
     * Opens the upstream connection on the inbound channel's event loop and builds its pipeline.
     *
     * <p>Because every handler is added with {@code addFirst}, the final pipeline order is the reverse
     * of the call order. On success the state becomes {@link State#CONNECTED}. On failure the state
     * becomes {@link State#FAILED} and the inbound channel is closed.</p>
     *
     * @param hostPort upstream broker address
     * @param list filters to install in the upstream pipeline
     * @throws IllegalStateException if a connection has already been initiated
     */
    @Override // io.kroxylicious.proxy.filter.NetFilter.NetFilterContext
    public void initiateConnect(HostPort hostPort, List<FilterAndInvoker> list) {
        if (this.backendHandler != null) {
            throw new IllegalStateException("initiateConnect() called more than once, state=" + this.state);
        }
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("{}: Connecting to backend broker {} using filters {}", new Object[]{ this.inboundCtx.channel().id(), hostPort, list });
        }
        CorrelationManager correlationManager = new CorrelationManager();
        Channel channel = this.inboundCtx.channel();
        Bootstrap bootstrap = new Bootstrap();
        this.backendHandler = new KafkaProxyBackendHandler(this, this.inboundCtx);
        bootstrap.group(channel.eventLoop()).channel(channel.getClass()).handler(this.backendHandler).option(ChannelOption.AUTO_READ, true).option(ChannelOption.TCP_NODELAY, true);
        LOGGER.trace("Connecting to outbound {}", hostPort);
        ChannelFuture initConnection = initConnection(hostPort.host(), hostPort.port(), bootstrap);
        Channel channel2 = initConnection.channel();
        ChannelPipeline pipeline = channel2.pipeline();
        if (this.logFrames) {
            pipeline.addFirst("frameLogger", new LoggingHandler("io.kroxylicious.proxy.internal.UpstreamFrameLogger"));
        }
        addFiltersToPipeline(list, pipeline, channel);
        pipeline.addFirst("responseDecoder", new KafkaResponseDecoder(correlationManager, this.virtualCluster.socketFrameMaxSizeBytes()));
        pipeline.addFirst("requestEncoder", new KafkaRequestEncoder(correlationManager));
        if (this.logNetwork) {
            pipeline.addFirst("networkLogger", new LoggingHandler("io.kroxylicious.proxy.internal.UpstreamNetworkLogger"));
        }
        this.virtualCluster.getUpstreamSslContext().ifPresent(sslContext -> {
            pipeline.addFirst("ssl", sslContext.newHandler(channel2.alloc(), hostPort.host(), hostPort.port()));
        });
        initConnection.addListener(future -> {
            if (future.isSuccess()) {
                this.state = State.CONNECTED;
                LOGGER.trace("{}: Outbound connected", this.inboundCtx.channel().id());
                this.dp.setDelegate(DecodePredicate.forFilters(list));
            }
            else {
                this.state = State.FAILED;
                Throwable cause = future.cause();
                LOGGER.atWarn().setCause(LOGGER.isDebugEnabled() ? cause : null).log("Connection to target cluster on {} failed with: {}, closing inbound channel. Increase log level to DEBUG for stacktrace", hostPort, cause.getMessage());
                channel.close();
            }
        });
    }

    /** Starts the TCP connect; a separate method so it can be overridden (e.g. in tests). */
    ChannelFuture initConnection(String str, int i, Bootstrap bootstrap) {
        return bootstrap.connect(str, i);
    }

    /** Adds one {@link FilterHandler} per filter, each named by the filter's {@code toString()}. */
    private void addFiltersToPipeline(List<FilterAndInvoker> list, ChannelPipeline channelPipeline, Channel channel) {
        for (FilterAndInvoker filterAndInvoker : list) {
            channelPipeline.addFirst(filterAndInvoker.toString(), new FilterHandler(filterAndInvoker, 20000L, this.sniHostname, this.virtualCluster, channel));
        }
    }

    /**
     * Writes a message to the outbound channel.
     *
     * <p>If the outbound channel is writable, the write is not flushed yet; the flush happens in
     * {@link #channelReadComplete}. Otherwise the message is written and flushed immediately. If
     * there is no outbound context yet, the message is dropped and a trace message is logged.</p>
     *
     * @param channelHandlerContext the inbound context (used only for logging)
     * @param obj the message to forward
     */
    public void forwardOutbound(ChannelHandlerContext channelHandlerContext, Object obj) {
        if (this.outboundCtx == null) {
            LOGGER.trace("READ on inbound {} ignored because outbound is not active (msg: {})", channelHandlerContext.channel(), obj);
            return;
        }
        Channel channel = this.outboundCtx.channel();
        if (LOGGER.isTraceEnabled()) {
            LOGGER.trace("READ on inbound {} outbound {} (outbound.isWritable: {}, msg: {})", new Object[]{ channelHandlerContext.channel(), channel, Boolean.valueOf(channel.isWritable()), obj });
            LOGGER.trace("Outbound bytesBeforeUnwritable: {}", Long.valueOf(channel.bytesBeforeUnwritable()));
            LOGGER.trace("Outbound config: {}", channel.config());
            LOGGER.trace("Outbound is active, writing and flushing {}", obj);
        }
        if (channel.isWritable()) {
            channel.write(obj, this.outboundCtx.voidPromise());
            this.pendingFlushes = true;
        }
        else {
            channel.writeAndFlush(obj, this.outboundCtx.voidPromise());
            this.pendingFlushes = false;
        }
        LOGGER.trace("/READ");
    }

    /** Answers an ApiVersions request locally with {@link #API_VERSIONS_RESPONSE}, echoing its version and correlation id. */
    private void writeApiVersionsResponse(ChannelHandlerContext channelHandlerContext, DecodedRequestFrame<ApiVersionsRequestData> decodedRequestFrame) {
        short apiVersion = decodedRequestFrame.apiVersion();
        int correlationId = decodedRequestFrame.correlationId();
        ResponseHeaderData header = new ResponseHeaderData().setCorrelationId(correlationId);
        LOGGER.debug("{}: Writing ApiVersions response", channelHandlerContext.channel());
        channelHandlerContext.writeAndFlush(new DecodedResponseFrame(apiVersion, correlationId, header, API_VERSIONS_RESPONSE));
    }

    private void storeApiVersionsFeatures(DecodedRequestFrame<ApiVersionsRequestData> decodedRequestFrame) {
        this.clientSoftwareName = decodedRequestFrame.body().clientSoftwareName();
        this.clientSoftwareVersion = decodedRequestFrame.body().clientSoftwareVersion();
    }

    /**
     * Re-enables inbound auto-read when a previously blocked inbound side sees the outbound channel
     * become writable again.
     *
     * @throws IllegalStateException if {@code channelHandlerContext} is not the registered outbound context
     */
    public void outboundWritabilityChanged(ChannelHandlerContext channelHandlerContext) {
        if (this.outboundCtx != channelHandlerContext) {
            throw illegalState("Mismatching outboundCtx");
        }
        if (this.isInboundBlocked && channelHandlerContext.channel().isWritable()) {
            this.isInboundBlocked = false;
            this.inboundCtx.channel().config().setAutoRead(true);
        }
    }

    /**
     * Flushes any pending outbound writes. If the outbound channel is then not writable, it disables
     * auto-read on the channel of {@code channelHandlerContext}, which must be the inbound context.
     * Before the outbound channel is active, this only records that a read-complete is pending.
     */
    @Override
    public void channelReadComplete(ChannelHandlerContext channelHandlerContext) {
        if (this.outboundCtx == null) {
            LOGGER.trace("READ_COMPLETE on inbound {}, ignored because outbound is not active", channelHandlerContext.channel());
            this.pendingReadComplete = true;
            return;
        }
        Channel channel = this.outboundCtx.channel();
        if (LOGGER.isTraceEnabled()) {
            LOGGER.trace("READ_COMPLETE on inbound {} outbound {} (pendingFlushes: {}, isInboundBlocked: {}, output.isWritable: {})", new Object[]{ channelHandlerContext.channel(), channel, Boolean.valueOf(this.pendingFlushes), Boolean.valueOf(this.isInboundBlocked), Boolean.valueOf(channel.isWritable()) });
        }
        if (this.pendingFlushes) {
            this.pendingFlushes = false;
            channel.flush();
        }
        if (channel.isWritable()) {
            return;
        }
        channelHandlerContext.channel().config().setAutoRead(false);
        this.isInboundBlocked = true;
    }

    /** When the client disconnects, closes the outbound channel (if any) after flushing. */
    @Override
    public void channelInactive(ChannelHandlerContext channelHandlerContext) {
        LOGGER.trace("INACTIVE on inbound {}", channelHandlerContext.channel());
        if (this.outboundCtx == null) {
            return;
        }
        Channel channel = this.outboundCtx.channel();
        if (channel != null) {
            closeOnFlush(channel);
        }
    }

    /** Logs the exception (with a hint for oversized frames) and closes the inbound channel. */
    @Override
    public void exceptionCaught(ChannelHandlerContext channelHandlerContext, Throwable th) {
        LOGGER.warn("Netty caught exception from the frontend: {}", th.getMessage(), th);
        if (th instanceof DecoderException) {
            Throwable cause = ((DecoderException) th).getCause();
            if (cause instanceof FrameOversizedException) {
                FrameOversizedException frameOversizedException = (FrameOversizedException) cause;
                LOGGER.warn("Received over-sized frame, max frame size bytes {}, received frame size bytes {} (hint: are we decoding a Kafka frame, or something unexpected like an HTTP request{}?)", new Object[]{ Integer.valueOf(frameOversizedException.getMaxFrameSizeBytes()), Integer.valueOf(frameOversizedException.getReceivedFrameSizeBytes()), this.virtualCluster.getDownstreamSslContext().isPresent() ? "" : " or an unexpected TLS handshake" });
            }
        }
        closeOnFlush(channelHandlerContext.channel());
    }

    /**
     * Closes {@code channel} after flushing everything written so far. Does nothing if the channel is
     * not active.
     */
    public static void closeOnFlush(Channel channel) {
        if (channel.isActive()) {
            channel.writeAndFlush(Unpooled.EMPTY_BUFFER).addListener(ChannelFutureListener.CLOSE);
        }
    }

    /** Records the SNI hostname and authentication events, then passes the event on. */
    @Override
    public void userEventTriggered(ChannelHandlerContext channelHandlerContext, Object obj) throws Exception {
        if (obj instanceof SniCompletionEvent) {
            SniCompletionEvent sniCompletionEvent = (SniCompletionEvent) obj;
            if (sniCompletionEvent.isSuccess()) {
                this.sniHostname = sniCompletionEvent.hostname();
            }
        }
        else if (obj instanceof AuthenticationEvent) {
            this.authentication = (AuthenticationEvent) obj;
        }
        super.userEventTriggered(channelHandlerContext, obj);
    }

    /** @return the client host: the HAProxy source address if present, else the inbound remote address */
    @Override // io.kroxylicious.proxy.filter.NetFilter.NetFilterContext
    public String clientHost() {
        if (this.haProxyMessage != null) {
            return this.haProxyMessage.sourceAddress();
        }
        SocketAddress remoteAddress = this.inboundCtx.channel().remoteAddress();
        if (remoteAddress instanceof InetSocketAddress) {
            InetSocketAddress inetSocketAddress = (InetSocketAddress) remoteAddress;
            // getAddress() is null for an unresolved address; fall back to the host string instead of NPE.
            return inetSocketAddress.getAddress() != null
                    ? inetSocketAddress.getAddress().getHostAddress()
                    : inetSocketAddress.getHostString();
        }
        return String.valueOf(remoteAddress);
    }

    /** @return the client port: the HAProxy source port if present, else the inbound remote port, or -1 */
    @Override // io.kroxylicious.proxy.filter.NetFilter.NetFilterContext
    public int clientPort() {
        if (this.haProxyMessage != null) {
            return this.haProxyMessage.sourcePort();
        }
        SocketAddress remoteAddress = this.inboundCtx.channel().remoteAddress();
        if (remoteAddress instanceof InetSocketAddress) {
            return ((InetSocketAddress) remoteAddress).getPort();
        }
        return -1;
    }

    @Override // io.kroxylicious.proxy.filter.NetFilter.NetFilterContext
    public SocketAddress srcAddress() {
        return this.inboundCtx.channel().remoteAddress();
    }

    @Override // io.kroxylicious.proxy.filter.NetFilter.NetFilterContext
    public SocketAddress localAddress() {
        return this.inboundCtx.channel().localAddress();
    }

    /** @return the authorization id from the last {@code AuthenticationEvent}, or {@code null} if none */
    @Override // io.kroxylicious.proxy.filter.NetFilter.NetFilterContext
    public String authorizedId() {
        if (this.authentication != null) {
            return this.authentication.authorizationId();
        }
        return null;
    }

    @Override // io.kroxylicious.proxy.filter.NetFilter.NetFilterContext
    public String clientSoftwareName() {
        return this.clientSoftwareName;
    }

    @Override // io.kroxylicious.proxy.filter.NetFilter.NetFilterContext
    public String clientSoftwareVersion() {
        return this.clientSoftwareVersion;
    }

    @Override // io.kroxylicious.proxy.filter.NetFilter.NetFilterContext
    public String sniHostname() {
        return this.sniHostname;
    }

    /**
     * Captures the inbound context, disables auto-read (reads are driven manually until the upstream
     * is active) and requests the first read.
     */
    @Override
    public void channelActive(ChannelHandlerContext channelHandlerContext) throws Exception {
        this.inboundCtx = channelHandlerContext;
        LOGGER.trace("{}: channelActive", this.inboundCtx.channel().id());
        channelHandlerContext.channel().config().setAutoRead(false);
        channelHandlerContext.channel().read();
        super.channelActive(channelHandlerContext);
    }

    @Override
    public String toString() {
        // inboundCtx is only assigned in channelActive(), so guard against toString() being called earlier.
        Channel inbound = this.inboundCtx == null ? null : this.inboundCtx.channel();
        return "KafkaProxyFrontendHandler{inbound = " + inbound + ", state = " + this.state + "}";
    }
}
