/*
 * Decompiled with CFR 0.152.
 */
package io.kroxylicious.proxy.internal;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import io.kroxylicious.proxy.filter.FilterAndInvoker;
import io.kroxylicious.proxy.filter.NetFilter;
import io.kroxylicious.proxy.frame.DecodedRequestFrame;
import io.kroxylicious.proxy.frame.DecodedResponseFrame;
import io.kroxylicious.proxy.frame.RequestFrame;
import io.kroxylicious.proxy.internal.AuthenticationEvent;
import io.kroxylicious.proxy.internal.FilterHandler;
import io.kroxylicious.proxy.internal.KafkaProxyBackendHandler;
import io.kroxylicious.proxy.internal.SaslDecodePredicate;
import io.kroxylicious.proxy.internal.codec.CorrelationManager;
import io.kroxylicious.proxy.internal.codec.DecodePredicate;
import io.kroxylicious.proxy.internal.codec.FrameOversizedException;
import io.kroxylicious.proxy.internal.codec.KafkaRequestEncoder;
import io.kroxylicious.proxy.internal.codec.KafkaResponseDecoder;
import io.kroxylicious.proxy.model.VirtualCluster;
import io.kroxylicious.proxy.service.HostPort;
import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.handler.codec.DecoderException;
import io.netty.handler.codec.haproxy.HAProxyMessage;
import io.netty.handler.logging.LoggingHandler;
import io.netty.handler.ssl.SniCompletionEvent;
import io.netty.util.concurrent.GenericFutureListener;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.ArrayList;
import java.util.List;
import org.apache.kafka.common.message.ApiVersionsRequestData;
import org.apache.kafka.common.message.ApiVersionsResponseData;
import org.apache.kafka.common.message.ApiVersionsResponseDataJsonConverter;
import org.apache.kafka.common.message.ResponseHeaderData;
import org.apache.kafka.common.protocol.ApiKeys;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Netty inbound handler for a client (downstream) connection to the proxy.
 *
 * <p>Until the backend connection is active, Kafka request frames read from the client are buffered. An
 * optional PROXY-protocol header is recorded. An initial ApiVersions request is answered locally when
 * authentication offload is enabled; otherwise it is buffered like any other request. The first buffered
 * request triggers {@link NetFilter#selectServer}, which is expected to call back into
 * {@link #initiateConnect}. Once the backend is active, buffered and subsequent messages are forwarded to it
 * with simple writability-based backpressure on the client channel.
 *
 * <p>This class also acts as the {@link NetFilter.NetFilterContext}, exposing client connection details
 * (address, SNI host name, client software, authorization id) to the {@link NetFilter}.
 *
 * <p>NOTE(review): the backend channel is registered on the client channel's event loop (see
 * {@link #initiateConnect}), so all callbacks presumably run on one thread; no synchronisation is used.
 */
public class KafkaProxyFrontendHandler
extends ChannelInboundHandlerAdapter
implements NetFilter.NetFilterContext {
    private static final Logger LOGGER = LoggerFactory.getLogger(KafkaProxyFrontendHandler.class);
    /** Classpath resource from which {@link #API_VERSIONS_RESPONSE} is loaded. */
    private static final String API_VERSIONS_RESOURCE = "/ApiVersions-3.2.json";
    /** Canned ApiVersions response written to clients when authentication offload is enabled. */
    private static final ApiVersionsResponseData API_VERSIONS_RESPONSE;
    /** Whether a network-level {@link LoggingHandler} is installed on the backend pipeline. */
    private final boolean logNetwork;
    /** Whether a frame-level {@link LoggingHandler} is installed on the backend pipeline. */
    private final boolean logFrames;
    private final VirtualCluster virtualCluster;
    /** Context of the backend channel; {@code null} until {@link #outboundChannelActive} runs. */
    private ChannelHandlerContext outboundCtx;
    /** Handler of the backend channel; non-null once {@link #initiateConnect} has been called. */
    private KafkaProxyBackendHandler backendHandler;
    /** True when messages were written to the backend without a flush; flushed in {@link #channelReadComplete}. */
    private boolean pendingFlushes;
    private final NetFilter filter;
    private final SaslDecodePredicate dp;
    /** Most recent {@link AuthenticationEvent} seen by {@link #userEventTriggered}, or {@code null}. */
    private AuthenticationEvent authentication;
    /** Client software name taken from the client's initial ApiVersions request, if any. */
    private String clientSoftwareName;
    /** Client software version taken from the client's initial ApiVersions request, if any. */
    private String clientSoftwareVersion;
    /** Host name from a successful {@link SniCompletionEvent}, or {@code null}. */
    private String sniHostname;
    /** Context of the client channel; set by {@link #channelActive}. */
    private ChannelHandlerContext inboundCtx;
    /** Messages read before the backend became active; set to {@code null} once they have been forwarded. */
    private List<Object> bufferedMsgs = new ArrayList<Object>();
    /** True when a read-complete must be replayed (i.e. a flush is owed) once the backend becomes active. */
    private boolean pendingReadComplete = true;
    private State state = State.START;
    /**
     * True when client auto-read is off because the backend cannot accept more data. Starts {@code true}
     * because {@link #channelActive} turns auto-read off until the backend is ready.
     */
    private boolean isInboundBlocked = true;
    /** PROXY-protocol header received before any Kafka frame; when set it supplies the client address/port. */
    private HAProxyMessage haProxyMessage;

    KafkaProxyFrontendHandler(NetFilter filter, SaslDecodePredicate dp, VirtualCluster virtualCluster) {
        this.filter = filter;
        this.dp = dp;
        this.virtualCluster = virtualCluster;
        this.logNetwork = virtualCluster.isLogNetwork();
        this.logFrames = virtualCluster.isLogFrames();
    }

    /**
     * Moves this handler to {@link State#FAILED} and builds an exception naming the state it was in.
     *
     * @param msg optional detail message; may be {@code null}
     * @return the exception, for the caller to throw
     */
    private IllegalStateException illegalState(String msg) {
        String name = this.state.name();
        this.state = State.FAILED;
        return new IllegalStateException((msg == null ? "" : msg + ", ") + "state=" + name);
    }

    /** Returns the current lifecycle state. */
    State state() {
        return this.state;
    }

    /**
     * Called once the backend channel is active. Forwards every request buffered while connecting, replays
     * any read-complete (flush) that was deferred because the backend was not yet available, and then lets
     * the client channel auto-read if the backend can accept more data.
     *
     * @param ctx context of the backend channel
     * @throws IllegalStateException if the backend connection has not completed ({@link State#CONNECTED})
     */
    public void outboundChannelActive(ChannelHandlerContext ctx) {
        if (this.state != State.CONNECTED) {
            throw this.illegalState(null);
        }
        LOGGER.trace("{}: outboundChannelActive", this.inboundCtx.channel().id());
        this.outboundCtx = ctx;
        // The buffered messages were read from the client, so they are forwarded, and any deferred
        // read-complete is replayed, against the *client* context. Passing the backend ctx here would make
        // channelReadComplete() turn off auto-read on the backend channel rather than the client channel,
        // and nothing would ever turn it back on.
        for (Object bufferedMsg : this.bufferedMsgs) {
            this.forwardOutbound(this.inboundCtx, bufferedMsg);
        }
        this.bufferedMsgs = null;
        if (this.pendingReadComplete) {
            this.pendingReadComplete = false;
            this.channelReadComplete(this.inboundCtx);
        }
        this.state = State.OUTBOUND_ACTIVE;
        Channel inboundChannel = this.inboundCtx.channel();
        if (ctx.channel().isWritable()) {
            this.isInboundBlocked = false;
            inboundChannel.config().setAutoRead(true);
        } else {
            // The backend is already backed up: keep the client paused. outboundWritabilityChanged() turns
            // auto-read back on once the backend is writable again (the same path channelReadComplete() uses).
            this.isInboundBlocked = true;
            inboundChannel.config().setAutoRead(false);
        }
    }

    /**
     * Propagates writability changes of the client channel to the backend handler, if one exists yet.
     */
    public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception {
        super.channelWritabilityChanged(ctx);
        if (this.backendHandler != null) {
            this.backendHandler.inboundChannelWritabilityChanged(ctx);
        }
    }

    /**
     * Forwards the message to the backend once it is active; before that, runs the pre-connection state
     * machine in {@link #handlePreOutboundActive}.
     */
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        if (this.state == State.OUTBOUND_ACTIVE) {
            this.forwardOutbound(ctx, msg);
        } else {
            this.handlePreOutboundActive(ctx, msg);
        }
    }

    /**
     * Handles a message that arrives before the backend is active.
     *
     * @throws IllegalStateException if the message is not valid for the current state; the handler
     *         moves to {@link State#FAILED}
     */
    private void handlePreOutboundActive(ChannelHandlerContext ctx, Object msg) {
        if (this.isInitialHaProxyMessage(msg)) {
            this.haProxyMessage = (HAProxyMessage) msg;
            this.state = State.HA_PROXY;
        } else if (this.isInitialDecodedApiVersionsFrame(msg)) {
            this.handleApiVersionsFrame(ctx, msg);
        } else if (this.isInitialRequestFrame(msg)) {
            this.bufferMsgAndSelectServer(msg);
        } else if (this.isSubsequentRequestFrame(msg)) {
            this.bufferMessage(msg);
        } else {
            throw this.illegalState("Unexpected channelRead() message of " + String.valueOf(msg.getClass()));
        }
    }

    /**
     * Records client software details from the initial ApiVersions request. With authentication offload
     * enabled, answers it locally and requests another read. Otherwise buffers it and starts server selection.
     */
    @SuppressWarnings("unchecked")
    private void handleApiVersionsFrame(ChannelHandlerContext ctx, Object msg) {
        this.state = State.API_VERSIONS;
        // Safe: isInitialDecodedApiVersionsFrame() verified apiKey() == API_VERSIONS.
        DecodedRequestFrame<ApiVersionsRequestData> apiVersionsFrame = (DecodedRequestFrame<ApiVersionsRequestData>) msg;
        this.storeApiVersionsFeatures(apiVersionsFrame);
        if (this.dp.isAuthenticationOffloadEnabled()) {
            this.writeApiVersionsResponse(ctx, apiVersionsFrame);
            // Auto-read is off before the backend is active, so explicitly ask for the next message.
            ctx.channel().read();
        } else {
            this.bufferMsgAndSelectServer(msg);
        }
    }

    /** True for a request frame arriving while the backend connection is still being set up. */
    private boolean isSubsequentRequestFrame(Object msg) {
        return (this.state == State.CONNECTING || this.state == State.CONNECTED) && msg instanceof RequestFrame;
    }

    /** True for the first request frame that should trigger server selection. */
    private boolean isInitialRequestFrame(Object msg) {
        return (this.state == State.START || this.state == State.HA_PROXY || this.state == State.API_VERSIONS) && msg instanceof RequestFrame;
    }

    /** True for a PROXY-protocol header arriving before anything else. */
    private boolean isInitialHaProxyMessage(Object msg) {
        return this.state == State.START && msg instanceof HAProxyMessage;
    }

    /** True for a decoded ApiVersions request arriving before any other Kafka request. */
    private boolean isInitialDecodedApiVersionsFrame(Object msg) {
        return (this.state == State.START || this.state == State.HA_PROXY)
                && msg instanceof DecodedRequestFrame
                && ((DecodedRequestFrame<?>) msg).apiKey() == ApiKeys.API_VERSIONS;
    }

    /** Moves to {@link State#CONNECTING}, buffers the message and asks the filter to select a server. */
    private void bufferMsgAndSelectServer(Object msg) {
        this.state = State.CONNECTING;
        this.bufferMessage(msg);
        this.filter.selectServer(this);
    }

    /** Holds the message until {@link #outboundChannelActive} forwards it. */
    private void bufferMessage(Object msg) {
        this.bufferedMsgs.add(msg);
    }

    /**
     * Opens the backend connection to {@code remote} and builds its pipeline around the given filters.
     * Presumably invoked by the {@link NetFilter} from {@code selectServer(this)}.
     *
     * @param remote  backend broker address
     * @param filters filters to install on the backend pipeline
     * @throws IllegalStateException if a backend connection has already been initiated
     */
    @Override
    public void initiateConnect(HostPort remote, List<FilterAndInvoker> filters) {
        if (this.backendHandler != null) {
            throw new IllegalStateException("initiateConnect() has already been called, state=" + this.state);
        }
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("{}: Connecting to backend broker {} using filters {}", this.inboundCtx.channel().id(), remote, filters);
        }
        CorrelationManager correlationManager = new CorrelationManager();
        Channel inboundChannel = this.inboundCtx.channel();
        this.backendHandler = new KafkaProxyBackendHandler(this, this.inboundCtx);
        // Register the backend channel on the client's event loop, using the same channel implementation.
        Bootstrap b = new Bootstrap()
                .group(inboundChannel.eventLoop())
                .channel(inboundChannel.getClass())
                .handler(this.backendHandler)
                .option(ChannelOption.AUTO_READ, true)
                .option(ChannelOption.TCP_NODELAY, true);
        LOGGER.trace("Connecting to outbound {}", remote);
        ChannelFuture connectFuture = this.initConnection(remote.host(), remote.port(), b);
        Channel outboundChannel = connectFuture.channel();
        ChannelPipeline pipeline = outboundChannel.pipeline();
        // Every handler below is prepended with addFirst, so handlers added later sit closer to the socket.
        // Resulting head-to-tail order: ssl, networkLogger, requestEncoder, responseDecoder,
        // filters (in reverse list order), frameLogger.
        if (this.logFrames) {
            pipeline.addFirst("frameLogger", new LoggingHandler("io.kroxylicious.proxy.internal.UpstreamFrameLogger"));
        }
        this.addFiltersToPipeline(filters, pipeline, inboundChannel);
        pipeline.addFirst("responseDecoder", new KafkaResponseDecoder(correlationManager, this.virtualCluster.socketFrameMaxSizeBytes()));
        pipeline.addFirst("requestEncoder", new KafkaRequestEncoder(correlationManager));
        if (this.logNetwork) {
            pipeline.addFirst("networkLogger", new LoggingHandler("io.kroxylicious.proxy.internal.UpstreamNetworkLogger"));
        }
        this.virtualCluster.getUpstreamSslContext().ifPresent(c -> pipeline.addFirst("ssl", c.newHandler(outboundChannel.alloc(), remote.host(), remote.port())));
        connectFuture.addListener(future -> {
            if (future.isSuccess()) {
                this.state = State.CONNECTED;
                LOGGER.trace("{}: Outbound connected", this.inboundCtx.channel().id());
                this.dp.setDelegate(DecodePredicate.forFilters(filters));
            } else {
                this.state = State.FAILED;
                Throwable failureCause = future.cause();
                LOGGER.atWarn()
                        .setCause(LOGGER.isDebugEnabled() ? failureCause : null)
                        .log("Connection to target cluster on {} failed with: {}, closing inbound channel. Increase log level to DEBUG for stacktrace",
                                remote, failureCause.getMessage());
                inboundChannel.close();
            }
        });
    }

    /** Starts the connect; package-private so it can be overridden (e.g. in tests). */
    ChannelFuture initConnection(String remoteHost, int remotePort, Bootstrap b) {
        return b.connect(remoteHost, remotePort);
    }

    /** Prepends a {@link FilterHandler} for each filter, named after the filter's {@code toString()}. */
    private void addFiltersToPipeline(List<FilterAndInvoker> filters, ChannelPipeline pipeline, Channel inboundChannel) {
        for (FilterAndInvoker filter : filters) {
            pipeline.addFirst(filter.toString(), new FilterHandler(filter, 20000L, this.sniHostname, this.virtualCluster, inboundChannel));
        }
    }

    /**
     * Writes a client message to the backend channel. While the backend is writable the write is not
     * flushed; the flush is deferred to {@link #channelReadComplete}. Otherwise it is written and flushed
     * immediately. If the backend is not yet active the message is dropped (trace-logged only).
     *
     * @param ctx client channel context (used for logging)
     * @param msg message to forward
     */
    public void forwardOutbound(ChannelHandlerContext ctx, Object msg) {
        if (this.outboundCtx == null) {
            LOGGER.trace("READ on inbound {} ignored because outbound is not active (msg: {})", ctx.channel(), msg);
            return;
        }
        Channel outboundChannel = this.outboundCtx.channel();
        if (LOGGER.isTraceEnabled()) {
            LOGGER.trace("READ on inbound {} outbound {} (outbound.isWritable: {}, msg: {})", ctx.channel(), outboundChannel, outboundChannel.isWritable(), msg);
            LOGGER.trace("Outbound bytesBeforeUnwritable: {}", outboundChannel.bytesBeforeUnwritable());
            LOGGER.trace("Outbound config: {}", outboundChannel.config());
            LOGGER.trace("Outbound is active, writing and flushing {}", msg);
        }
        if (outboundChannel.isWritable()) {
            outboundChannel.write(msg, this.outboundCtx.voidPromise());
            this.pendingFlushes = true;
        } else {
            outboundChannel.writeAndFlush(msg, this.outboundCtx.voidPromise());
            this.pendingFlushes = false;
        }
        LOGGER.trace("/READ");
    }

    /** Answers the client's ApiVersions request with the canned {@link #API_VERSIONS_RESPONSE}. */
    private void writeApiVersionsResponse(ChannelHandlerContext ctx, DecodedRequestFrame<ApiVersionsRequestData> frame) {
        short apiVersion = frame.apiVersion();
        int correlationId = frame.correlationId();
        ResponseHeaderData header = new ResponseHeaderData().setCorrelationId(correlationId);
        LOGGER.debug("{}: Writing ApiVersions response", ctx.channel());
        ctx.writeAndFlush(new DecodedResponseFrame<ApiVersionsResponseData>(apiVersion, correlationId, header, API_VERSIONS_RESPONSE));
    }

    /** Remembers the client software name and version reported in an ApiVersions request. */
    private void storeApiVersionsFeatures(DecodedRequestFrame<ApiVersionsRequestData> frame) {
        ApiVersionsRequestData body = frame.body();
        this.clientSoftwareName = body.clientSoftwareName();
        this.clientSoftwareVersion = body.clientSoftwareVersion();
    }

    /**
     * Turns client auto-read back on when it was paused for backpressure and the backend is writable again.
     *
     * @param outboundCtx context of the backend channel; must be the one seen by {@link #outboundChannelActive}
     * @throws IllegalStateException if {@code outboundCtx} is not the active backend context
     */
    public void outboundWritabilityChanged(ChannelHandlerContext outboundCtx) {
        if (this.outboundCtx != outboundCtx) {
            throw this.illegalState("Mismatching outboundCtx");
        }
        if (this.isInboundBlocked && outboundCtx.channel().isWritable()) {
            this.isInboundBlocked = false;
            this.inboundCtx.channel().config().setAutoRead(true);
        }
    }

    /**
     * Flushes writes deferred by {@link #forwardOutbound} and pauses client auto-read while the backend is
     * not writable. Before the backend is active, only records that the flush is owed.
     *
     * @param ctx client channel context
     */
    public void channelReadComplete(ChannelHandlerContext ctx) {
        if (this.outboundCtx == null) {
            LOGGER.trace("READ_COMPLETE on inbound {}, ignored because outbound is not active", ctx.channel());
            this.pendingReadComplete = true;
            return;
        }
        Channel outboundChannel = this.outboundCtx.channel();
        if (LOGGER.isTraceEnabled()) {
            LOGGER.trace("READ_COMPLETE on inbound {} outbound {} (pendingFlushes: {}, isInboundBlocked: {}, output.isWritable: {})",
                    ctx.channel(), outboundChannel, this.pendingFlushes, this.isInboundBlocked, outboundChannel.isWritable());
        }
        if (this.pendingFlushes) {
            this.pendingFlushes = false;
            outboundChannel.flush();
        }
        if (!outboundChannel.isWritable()) {
            ctx.channel().config().setAutoRead(false);
            this.isInboundBlocked = true;
        }
    }

    /** Closes the backend channel (after flushing) when the client disconnects, if the backend is active. */
    public void channelInactive(ChannelHandlerContext ctx) {
        LOGGER.trace("INACTIVE on inbound {}", ctx.channel());
        if (this.outboundCtx == null) {
            return;
        }
        Channel outboundChannel = this.outboundCtx.channel();
        if (outboundChannel != null) {
            closeOnFlush(outboundChannel);
        }
    }

    /**
     * Logs the failure, adds a diagnostic hint for over-sized frames, and closes the client channel.
     */
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
        LOGGER.warn("Netty caught exception from the frontend: {}", cause.getMessage(), cause);
        if (cause instanceof DecoderException && cause.getCause() instanceof FrameOversizedException) {
            FrameOversizedException e = (FrameOversizedException) cause.getCause();
            // Without downstream TLS, a client attempting a TLS handshake would also look like a huge frame.
            String tlsHint = this.virtualCluster.getDownstreamSslContext().isPresent() ? "" : " or an unexpected TLS handshake";
            LOGGER.warn("Received over-sized frame, max frame size bytes {}, received frame size bytes {} (hint: are we decoding a Kafka frame, or something unexpected like an HTTP request{}?)",
                    e.getMaxFrameSizeBytes(), e.getReceivedFrameSizeBytes(), tlsHint);
        }
        closeOnFlush(ctx.channel());
    }

    /** If the channel is active, flushes pending writes and then closes it. */
    static void closeOnFlush(Channel ch) {
        if (ch.isActive()) {
            ch.writeAndFlush(Unpooled.EMPTY_BUFFER).addListener(ChannelFutureListener.CLOSE);
        }
    }

    /** Captures the SNI host name and authentication result; always passes the event on. */
    public void userEventTriggered(ChannelHandlerContext ctx, Object event) throws Exception {
        if (event instanceof SniCompletionEvent) {
            SniCompletionEvent sniCompletionEvent = (SniCompletionEvent) event;
            if (sniCompletionEvent.isSuccess()) {
                this.sniHostname = sniCompletionEvent.hostname();
            }
        } else if (event instanceof AuthenticationEvent) {
            this.authentication = (AuthenticationEvent) event;
        }
        super.userEventTriggered(ctx, event);
    }

    /**
     * Client host: the PROXY-protocol source address if one was received, otherwise the socket's remote
     * IP address (or host string when unresolved), or {@code String.valueOf} of a non-IP address.
     */
    @Override
    public String clientHost() {
        if (this.haProxyMessage != null) {
            return this.haProxyMessage.sourceAddress();
        }
        SocketAddress socketAddress = this.inboundCtx.channel().remoteAddress();
        if (socketAddress instanceof InetSocketAddress) {
            InetSocketAddress inetSocketAddress = (InetSocketAddress) socketAddress;
            // getAddress() is null for an unresolved address; fall back to its host string in that case.
            return inetSocketAddress.getAddress() != null
                    ? inetSocketAddress.getAddress().getHostAddress()
                    : inetSocketAddress.getHostString();
        }
        return String.valueOf(socketAddress);
    }

    /** Client port: the PROXY-protocol source port if present, else the socket's remote port, else -1. */
    @Override
    public int clientPort() {
        if (this.haProxyMessage != null) {
            return this.haProxyMessage.sourcePort();
        }
        SocketAddress socketAddress = this.inboundCtx.channel().remoteAddress();
        if (socketAddress instanceof InetSocketAddress) {
            return ((InetSocketAddress) socketAddress).getPort();
        }
        return -1;
    }

    /** The client channel's remote socket address (ignores any PROXY-protocol header). */
    @Override
    public SocketAddress srcAddress() {
        return this.inboundCtx.channel().remoteAddress();
    }

    /** The client channel's local socket address. */
    @Override
    public SocketAddress localAddress() {
        return this.inboundCtx.channel().localAddress();
    }

    /** Authorization id from the last {@link AuthenticationEvent}, or {@code null} if none was seen. */
    @Override
    public String authorizedId() {
        return this.authentication != null ? this.authentication.authorizationId() : null;
    }

    /** Client software name from the initial ApiVersions request, or {@code null}. */
    @Override
    public String clientSoftwareName() {
        return this.clientSoftwareName;
    }

    /** Client software version from the initial ApiVersions request, or {@code null}. */
    @Override
    public String clientSoftwareVersion() {
        return this.clientSoftwareVersion;
    }

    /** SNI host name from a successful TLS handshake, or {@code null}. */
    @Override
    public String sniHostname() {
        return this.sniHostname;
    }

    /**
     * Remembers the client context, turns auto-read off until the backend is ready, and requests the
     * first read explicitly.
     */
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        this.inboundCtx = ctx;
        LOGGER.trace("{}: channelActive", this.inboundCtx.channel().id());
        ctx.channel().config().setAutoRead(false);
        ctx.channel().read();
        super.channelActive(ctx);
    }

    @Override
    public String toString() {
        // inboundCtx is only set by channelActive(); don't fail if this handler is logged before that.
        Channel inbound = this.inboundCtx == null ? null : this.inboundCtx.channel();
        return "KafkaProxyFrontendHandler{inbound = " + inbound + ", state = " + this.state + "}";
    }

    static {
        ObjectMapper objectMapper = new ObjectMapper();
        try (InputStream parser = KafkaProxyFrontendHandler.class.getResourceAsStream(API_VERSIONS_RESOURCE)) {
            if (parser == null) {
                throw new IllegalStateException("Classpath resource " + API_VERSIONS_RESOURCE + " not found");
            }
            API_VERSIONS_RESPONSE = ApiVersionsResponseDataJsonConverter.read(objectMapper.readTree(parser), (short) 3);
        }
        catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /** Lifecycle of a client connection, as driven by this handler. */
    enum State {
        /** Nothing has been received from the client yet. */
        START,
        /** A PROXY-protocol header has been received. */
        HA_PROXY,
        /** The client's initial ApiVersions request has been received. */
        API_VERSIONS,
        /** Server selection / backend connect has been started; requests are being buffered. */
        CONNECTING,
        /** The backend connect succeeded; waiting for {@link #outboundChannelActive}. */
        CONNECTED,
        /** The backend is active; client messages are forwarded directly. */
        OUTBOUND_ACTIVE,
        /** The backend connect failed or an illegal transition was detected. */
        FAILED;

    }
}

