/*
 * Decompiled with CFR 0.152.
 */
package io.kroxylicious.proxy.internal;

import io.kroxylicious.proxy.bootstrap.FilterChainFactory;
import io.kroxylicious.proxy.config.PluginFactoryRegistry;
import io.kroxylicious.proxy.filter.Filter;
import io.kroxylicious.proxy.filter.FilterAndInvoker;
import io.kroxylicious.proxy.filter.NetFilter;
import io.kroxylicious.proxy.internal.ApiVersionsServiceImpl;
import io.kroxylicious.proxy.internal.KafkaAuthnHandler;
import io.kroxylicious.proxy.internal.KafkaProxyFrontendHandler;
import io.kroxylicious.proxy.internal.ResponseOrderer;
import io.kroxylicious.proxy.internal.SaslDecodePredicate;
import io.kroxylicious.proxy.internal.codec.KafkaRequestDecoder;
import io.kroxylicious.proxy.internal.codec.KafkaResponseEncoder;
import io.kroxylicious.proxy.internal.filter.ApiVersionsIntersectFilter;
import io.kroxylicious.proxy.internal.filter.BrokerAddressFilter;
import io.kroxylicious.proxy.internal.filter.EagerMetadataLearner;
import io.kroxylicious.proxy.internal.filter.NettyFilterContext;
import io.kroxylicious.proxy.internal.net.Endpoint;
import io.kroxylicious.proxy.internal.net.EndpointReconciler;
import io.kroxylicious.proxy.internal.net.VirtualClusterBinding;
import io.kroxylicious.proxy.internal.net.VirtualClusterBindingResolver;
import io.kroxylicious.proxy.model.VirtualCluster;
import io.kroxylicious.proxy.service.HostPort;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.socket.SocketChannel;
import io.netty.handler.codec.haproxy.HAProxyMessageDecoder;
import io.netty.handler.logging.LogLevel;
import io.netty.handler.logging.LoggingHandler;
import io.netty.handler.ssl.SniHandler;
import io.netty.handler.ssl.SslContext;
import io.netty.util.concurrent.Future;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.CompletionStage;
import org.apache.kafka.common.security.auth.AuthenticateCallbackHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Netty {@link ChannelInitializer} that assembles the proxy's pipeline for each new
 * {@link SocketChannel}.
 *
 * <p>The channel is first matched to a {@link VirtualClusterBinding} through the
 * {@link VirtualClusterBindingResolver}. For TLS listeners the lookup is driven by the SNI
 * hostname (via an {@link SniHandler}); for plain listeners it happens when the channel becomes
 * active. Once a binding is known, {@link #addHandlers(SocketChannel, VirtualClusterBinding)}
 * installs the Kafka codec handlers and the {@link KafkaProxyFrontendHandler}.
 */
public class KafkaProxyInitializer
        extends ChannelInitializer<SocketChannel> {
    private static final Logger LOGGER = LoggerFactory.getLogger(KafkaProxyInitializer.class);
    private final boolean haproxyProtocol;
    private final Map<KafkaAuthnHandler.SaslMechanism, AuthenticateCallbackHandler> authnHandlers;
    private final boolean tls;
    private final VirtualClusterBindingResolver virtualClusterBindingResolver;
    private final EndpointReconciler endpointReconciler;
    private final PluginFactoryRegistry pfr;
    private final FilterChainFactory filterChainFactory;

    /**
     * Creates an initializer.
     *
     * @param filterChainFactory factory asked for the custom protocol filters each time a server is selected
     * @param pfr plugin factory registry handed to the {@link NettyFilterContext} used for filter creation
     * @param tls {@code true} to resolve the binding through an SNI handler, {@code false} for the plain path
     * @param virtualClusterBindingResolver resolves an endpoint (and optional SNI hostname) to a binding
     * @param endpointReconciler passed to the {@link BrokerAddressFilter} installed on every connection
     * @param haproxyProtocol {@code true} to add an {@link HAProxyMessageDecoder} to the pipeline
     * @param authnMechanismHandlers SASL callback handlers keyed by mechanism; {@code null} is treated as empty
     */
    public KafkaProxyInitializer(FilterChainFactory filterChainFactory, PluginFactoryRegistry pfr, boolean tls, VirtualClusterBindingResolver virtualClusterBindingResolver, EndpointReconciler endpointReconciler, boolean haproxyProtocol, Map<KafkaAuthnHandler.SaslMechanism, AuthenticateCallbackHandler> authnMechanismHandlers) {
        this.pfr = pfr;
        this.endpointReconciler = endpointReconciler;
        this.haproxyProtocol = haproxyProtocol;
        this.authnHandlers = authnMechanismHandlers != null ? authnMechanismHandlers : Map.of();
        this.tls = tls;
        this.virtualClusterBindingResolver = virtualClusterBindingResolver;
        this.filterChainFactory = filterChainFactory;
    }

    /**
     * Installs the handler that will resolve the channel's virtual cluster binding, choosing the
     * TLS or plain variant according to how this initializer was configured.
     *
     * @param ch the channel being initialised
     */
    @Override
    public void initChannel(SocketChannel ch) {
        LOGGER.trace("Connection from {} to my address {}", ch.remoteAddress(), ch.localAddress());
        ChannelPipeline pipeline = ch.pipeline();
        int targetPort = ch.localAddress().getPort();
        // When the parent (listening) channel is bound to the wildcard address no specific
        // binding address is supplied to the resolver; otherwise the channel's own local host
        // address is used.
        boolean parentBoundToAnyAddress = ch.parent().localAddress().getAddress().isAnyLocalAddress();
        Optional<String> bindingAddress = parentBoundToAnyAddress
                ? Optional.empty()
                : Optional.of(ch.localAddress().getAddress().getHostAddress());
        if (this.tls) {
            this.initTlsChannel(ch, pipeline, bindingAddress, targetPort);
        }
        else {
            this.initPlainChannel(ch, pipeline, bindingAddress, targetPort);
        }
    }

    /**
     * Adds a temporary inbound handler that resolves the binding (with no SNI hostname) when the
     * channel becomes active.
     *
     * <p>On successful resolution the real handlers are added, {@code channelActive} is
     * propagated, and the temporary handler removes itself (also when adding the handlers
     * throws). When resolution itself fails the failure is fired as an exception event and the
     * temporary handler is left in the pipeline.
     */
    private void initPlainChannel(final SocketChannel ch, final ChannelPipeline pipeline, final Optional<String> bindingAddress, final int targetPort) {
        pipeline.addLast(new ChannelInboundHandlerAdapter() {

            @Override
            public void channelActive(ChannelHandlerContext ctx) {
                KafkaProxyInitializer.this.virtualClusterBindingResolver
                        .resolve(Endpoint.createEndpoint(bindingAddress, targetPort, KafkaProxyInitializer.this.tls), null)
                        .handle((binding, t) -> {
                            if (t != null) {
                                ctx.fireExceptionCaught(t);
                                return null;
                            }
                            try {
                                KafkaProxyInitializer.this.addHandlers(ch, binding);
                                ctx.fireChannelActive();
                            }
                            catch (Throwable t1) {
                                ctx.fireExceptionCaught(t1);
                            }
                            finally {
                                // 'this' is the enclosing anonymous handler, not the lambda.
                                pipeline.remove(this);
                            }
                            return null;
                        });
            }
        });
    }

    /**
     * Adds an {@link SniHandler} whose asynchronous mapping resolves the binding from the SNI
     * hostname and completes the lookup promise with the virtual cluster's downstream
     * {@link SslContext}.
     *
     * <p>The promise is failed when resolution fails, when the virtual cluster offers no SSL
     * context, or when adding the handlers throws.
     */
    private void initTlsChannel(SocketChannel ch, ChannelPipeline pipeline, Optional<String> bindingAddress, int targetPort) {
        LOGGER.debug("Adding SSL/SNI handler");
        pipeline.addLast(new SniHandler((sniHostname, promise) -> {
            try {
                CompletionStage<VirtualClusterBinding> stage = this.virtualClusterBindingResolver
                        .resolve(Endpoint.createEndpoint(bindingAddress, targetPort, this.tls), sniHostname);
                // Completes the Netty promise once the resolution completes, successfully or not.
                stage.handle((binding, t) -> {
                    try {
                        if (t != null) {
                            promise.setFailure(t);
                            return null;
                        }
                        VirtualCluster virtualCluster = binding.virtualCluster();
                        Optional<SslContext> sslContext = virtualCluster.getDownstreamSslContext();
                        if (sslContext.isEmpty()) {
                            promise.setFailure(new IllegalStateException("Virtual cluster %s does not provide SSL context".formatted(virtualCluster)));
                        }
                        else {
                            // Handlers are added before the promise is completed.
                            this.addHandlers(ch, binding);
                            promise.setSuccess(sslContext.get());
                        }
                    }
                    catch (Throwable t1) {
                        promise.setFailure(t1);
                    }
                    return null;
                });
                return promise;
            }
            catch (Throwable cause) {
                return promise.setFailure(cause);
            }
        }) {

            @Override
            protected void onLookupComplete(ChannelHandlerContext ctx, Future<SslContext> future) throws Exception {
                super.onLookupComplete(ctx, future);
                // NOTE(review): presumably re-fired so that handlers added during the lookup
                // observe channelActive - confirm against SniHandler's replacement semantics.
                ctx.fireChannelActive();
            }
        });
    }

    /**
     * Appends the proxy's handlers to the channel's pipeline for the given binding.
     *
     * <p>Order of addition: optional network logger, optional HAProxy decoder, request decoder,
     * response encoder, response orderer, optional frame logger, optional authentication handler
     * (only when SASL handlers were configured) and finally the frontend handler.
     *
     * @param ch the channel whose pipeline is extended
     * @param binding the resolved binding supplying the virtual cluster configuration
     */
    void addHandlers(SocketChannel ch, VirtualClusterBinding binding) {
        VirtualCluster virtualCluster = binding.virtualCluster();
        ChannelPipeline pipeline = ch.pipeline();
        if (virtualCluster.isLogNetwork()) {
            pipeline.addLast("networkLogger", new LoggingHandler("io.kroxylicious.proxy.internal.DownstreamNetworkLogger", LogLevel.INFO));
        }
        if (this.haproxyProtocol) {
            LOGGER.debug("Adding haproxy handler");
            pipeline.addLast("HAProxyMessageDecoder", new HAProxyMessageDecoder());
        }
        boolean authnConfigured = !this.authnHandlers.isEmpty();
        // The same predicate instance is shared by the decoder, the net filter and the frontend handler.
        SaslDecodePredicate dp = new SaslDecodePredicate(authnConfigured);
        KafkaRequestDecoder decoder = new KafkaRequestDecoder(dp, virtualCluster.socketFrameMaxSizeBytes());
        pipeline.addLast("requestDecoder", decoder);
        pipeline.addLast("responseEncoder", new KafkaResponseEncoder());
        pipeline.addLast("responseOrderer", new ResponseOrderer());
        if (virtualCluster.isLogFrames()) {
            pipeline.addLast("frameLogger", new LoggingHandler("io.kroxylicious.proxy.internal.DownstreamFrameLogger", LogLevel.INFO));
        }
        if (authnConfigured) {
            LOGGER.debug("Adding authn handler for handlers {}", this.authnHandlers);
            pipeline.addLast(new KafkaAuthnHandler(ch, this.authnHandlers));
        }
        ApiVersionsServiceImpl apiVersionService = new ApiVersionsServiceImpl();
        InitalizerNetFilter netFilter = new InitalizerNetFilter(dp, apiVersionService, ch, binding, this.pfr, this.filterChainFactory, this.endpointReconciler);
        KafkaProxyFrontendHandler frontendHandler = new KafkaProxyFrontendHandler(netFilter, dp, virtualCluster);
        pipeline.addLast("netHandler", frontendHandler);
        LOGGER.debug("{}: Initial pipeline: {}", ch, pipeline);
    }

    /**
     * {@link NetFilter} that builds the filter chain for a connection and initiates the upstream
     * connect to the binding's target.
     *
     * <p>The class name's spelling is kept as-is so existing references continue to compile.
     */
    static class InitalizerNetFilter
            implements NetFilter {
        private final SaslDecodePredicate decodePredicate;
        private final ApiVersionsServiceImpl apiVersionService;
        private final SocketChannel ch;
        private final VirtualCluster virtualCluster;
        private final VirtualClusterBinding binding;
        private final PluginFactoryRegistry pfr;
        private final FilterChainFactory filterChainFactory;
        private final EndpointReconciler endpointReconciler;

        /**
         * @param decodePredicate consulted to decide whether the API-versions filter is installed
         * @param apiVersionService service given to the {@link ApiVersionsIntersectFilter}
         * @param ch channel whose event loop is used for the filter context
         * @param binding binding providing the virtual cluster and the upstream target
         * @param pfr plugin factory registry for the filter context
         * @param filterChainFactory factory producing the custom protocol filters
         * @param endpointReconciler reconciler given to the {@link BrokerAddressFilter}
         */
        InitalizerNetFilter(SaslDecodePredicate decodePredicate, ApiVersionsServiceImpl apiVersionService, SocketChannel ch, VirtualClusterBinding binding, PluginFactoryRegistry pfr, FilterChainFactory filterChainFactory, EndpointReconciler endpointReconciler) {
            this.decodePredicate = decodePredicate;
            this.apiVersionService = apiVersionService;
            this.ch = ch;
            this.virtualCluster = binding.virtualCluster();
            this.binding = binding;
            this.pfr = pfr;
            this.filterChainFactory = filterChainFactory;
            this.endpointReconciler = endpointReconciler;
        }

        /**
         * Builds the ordered filter list and asks the context to connect to the upstream target.
         *
         * <p>Filter order: API-versions intersection (omitted when authentication offload is
         * enabled), custom protocol filters, an {@link EagerMetadataLearner} when the binding
         * restricts upstream use to metadata discovery, then the {@link BrokerAddressFilter}.
         *
         * @param context the context used to initiate the upstream connection
         * @throws IllegalStateException if the binding has no upstream target
         */
        @Override
        public void selectServer(NetFilter.NetFilterContext context) {
            List<FilterAndInvoker> apiVersionFilters = this.decodePredicate.isAuthenticationOffloadEnabled()
                    ? List.of()
                    : FilterAndInvoker.build(new ApiVersionsIntersectFilter(this.apiVersionService));
            NettyFilterContext filterContext = new NettyFilterContext(this.ch.eventLoop(), this.pfr);
            List<FilterAndInvoker> customProtocolFilters = this.filterChainFactory.createFilters(filterContext);
            List<FilterAndInvoker> brokerAddressFilters = FilterAndInvoker.build(new BrokerAddressFilter(this.virtualCluster, this.endpointReconciler));

            List<FilterAndInvoker> filters = new ArrayList<>(apiVersionFilters);
            filters.addAll(customProtocolFilters);
            if (this.binding.restrictUpstreamToMetadataDiscovery()) {
                filters.addAll(FilterAndInvoker.build(new EagerMetadataLearner()));
            }
            filters.addAll(brokerAddressFilters);

            HostPort target = this.binding.upstreamTarget();
            if (target == null) {
                throw new IllegalStateException("A target address for binding %s is not known.".formatted(this.binding));
            }
            context.initiateConnect(target, filters);
        }
    }
}

