/*
 * Decompiled with CFR 0.152.
 */
package org.apache.kafka.common.network;

import java.io.File;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.SelectionKey;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentLinkedQueue;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import org.apache.kafka.common.KafkaException;
import org.apache.kafka.common.config.types.Password;
import org.apache.kafka.common.metrics.Metrics;
import org.apache.kafka.common.network.ChannelBuilder;
import org.apache.kafka.common.network.Mode;
import org.apache.kafka.common.network.NetworkReceive;
import org.apache.kafka.common.network.NetworkSend;
import org.apache.kafka.common.network.Selector;
import org.apache.kafka.common.network.Send;
import org.apache.kafka.common.network.SslChannelBuilder;
import org.apache.kafka.common.network.SslTransportLayer;
import org.apache.kafka.common.security.ssl.SslFactory;
import org.apache.kafka.common.utils.MockTime;
import org.apache.kafka.common.utils.Time;
import org.apache.kafka.common.utils.Utils;
import org.apache.kafka.test.TestSslUtils;
import org.apache.kafka.test.TestUtils;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

/**
 * Tests for {@link SslTransportLayer}: SSL connections from a client {@link Selector} to an
 * in-process {@link SslEchoServer}, covering endpoint identification, client-authentication
 * modes, invalid credentials, protocol/cipher mismatches and dynamic buffer resizing.
 */
public class SslTransportLayerTest {
    /** Socket send/receive buffer size passed to every {@code Selector.connect} call. */
    private static final int BUFFER_SIZE = 4096;
    /** Upper bound on how long {@link #waitForChannelReady} waits for a handshake to finish. */
    private static final long CHANNEL_WAIT_MS = 30000L;
    private SslEchoServer server;
    private Selector selector;
    private ChannelBuilder channelBuilder;
    private CertStores serverCertStores;
    private CertStores clientCertStores;
    private Map<String, Object> sslClientConfigs;
    private Map<String, Object> sslServerConfigs;

    /**
     * Creates fresh server and client cert stores that trust each other, the matching client and
     * server configs, and a default client selector (replaced by tests that call
     * {@link #createSelector}).
     */
    @Before
    public void setup() throws Exception {
        this.serverCertStores = new CertStores(true);
        this.clientCertStores = new CertStores(false);
        this.sslServerConfigs = this.serverCertStores.getTrustingConfig(this.clientCertStores);
        this.sslClientConfigs = this.clientCertStores.getTrustingConfig(this.serverCertStores);
        this.channelBuilder = new SslChannelBuilder(Mode.CLIENT);
        this.channelBuilder.configure(this.sslClientConfigs);
        this.selector = new Selector(5000L, new Metrics(), new MockTime(), "MetricGroup",
                new LinkedHashMap<String, String>(), this.channelBuilder);
    }

    /** Closes the client selector and the echo server; the server is closed even if the selector close fails. */
    @After
    public void teardown() throws Exception {
        try {
            if (this.selector != null) {
                this.selector.close();
            }
        } finally {
            if (this.server != null) {
                this.server.close();
            }
        }
    }

    @Test
    public void testValidEndpointIdentification() throws Exception {
        String node = "0";
        this.createEchoServer(this.sslServerConfigs);
        this.sslClientConfigs.put("ssl.endpoint.identification.algorithm", "HTTPS");
        this.createSelector(this.sslClientConfigs);
        InetSocketAddress addr = new InetSocketAddress("localhost", this.server.port);
        this.selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
        this.testClientConnection(node, 100, 10);
    }

    /** Connecting by IP address with HTTPS endpoint identification enabled must fail the handshake. */
    @Test
    public void testInvalidEndpointIdentification() throws Exception {
        String node = "0";
        String serverHost = InetAddress.getLocalHost().getHostAddress();
        this.server = new SslEchoServer(this.sslServerConfigs, serverHost);
        this.server.start();
        this.sslClientConfigs.put("ssl.endpoint.identification.algorithm", "HTTPS");
        this.createSelector(this.sslClientConfigs);
        InetSocketAddress addr = new InetSocketAddress(serverHost, this.server.port);
        this.selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
        this.waitForChannelClose(node);
    }

    @Test
    public void testEndpointIdentificationDisabled() throws Exception {
        String node = "0";
        String serverHost = InetAddress.getLocalHost().getHostAddress();
        this.server = new SslEchoServer(this.sslServerConfigs, serverHost);
        this.server.start();
        this.sslClientConfigs.remove("ssl.endpoint.identification.algorithm");
        this.createSelector(this.sslClientConfigs);
        InetSocketAddress addr = new InetSocketAddress(serverHost, this.server.port);
        this.selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
        this.testClientConnection(node, 100, 10);
    }

    @Test
    public void testClientAuthenticationRequiredValidProvided() throws Exception {
        String node = "0";
        this.sslServerConfigs.put("ssl.client.auth", "required");
        this.createEchoServer(this.sslServerConfigs);
        this.createSelector(this.sslClientConfigs);
        InetSocketAddress addr = new InetSocketAddress("localhost", this.server.port);
        this.selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
        this.testClientConnection(node, 100, 10);
    }

    @Test
    public void testClientAuthenticationRequiredUntrustedProvided() throws Exception {
        String node = "0";
        this.sslServerConfigs = this.serverCertStores.getUntrustingConfig();
        this.sslServerConfigs.put("ssl.client.auth", "required");
        this.createEchoServer(this.sslServerConfigs);
        this.createSelector(this.sslClientConfigs);
        InetSocketAddress addr = new InetSocketAddress("localhost", this.server.port);
        this.selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
        this.waitForChannelClose(node);
    }

    @Test
    public void testClientAuthenticationRequiredNotProvided() throws Exception {
        String node = "0";
        this.sslServerConfigs.put("ssl.client.auth", "required");
        this.createEchoServer(this.sslServerConfigs);
        this.sslClientConfigs.remove("ssl.keystore.location");
        this.sslClientConfigs.remove("ssl.keystore.password");
        this.sslClientConfigs.remove("ssl.key.password");
        this.createSelector(this.sslClientConfigs);
        InetSocketAddress addr = new InetSocketAddress("localhost", this.server.port);
        this.selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
        this.waitForChannelClose(node);
    }

    @Test
    public void testClientAuthenticationDisabledUntrustedProvided() throws Exception {
        String node = "0";
        this.sslServerConfigs = this.serverCertStores.getUntrustingConfig();
        this.sslServerConfigs.put("ssl.client.auth", "none");
        this.createEchoServer(this.sslServerConfigs);
        this.createSelector(this.sslClientConfigs);
        InetSocketAddress addr = new InetSocketAddress("localhost", this.server.port);
        this.selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
        this.testClientConnection(node, 100, 10);
    }

    @Test
    public void testClientAuthenticationDisabledNotProvided() throws Exception {
        String node = "0";
        this.sslServerConfigs.put("ssl.client.auth", "none");
        this.createEchoServer(this.sslServerConfigs);
        this.sslClientConfigs.remove("ssl.keystore.location");
        this.sslClientConfigs.remove("ssl.keystore.password");
        this.sslClientConfigs.remove("ssl.key.password");
        this.createSelector(this.sslClientConfigs);
        InetSocketAddress addr = new InetSocketAddress("localhost", this.server.port);
        this.selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
        this.testClientConnection(node, 100, 10);
    }

    @Test
    public void testClientAuthenticationRequestedValidProvided() throws Exception {
        String node = "0";
        this.sslServerConfigs.put("ssl.client.auth", "requested");
        this.createEchoServer(this.sslServerConfigs);
        this.createSelector(this.sslClientConfigs);
        InetSocketAddress addr = new InetSocketAddress("localhost", this.server.port);
        this.selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
        this.testClientConnection(node, 100, 10);
    }

    @Test
    public void testClientAuthenticationRequestedNotProvided() throws Exception {
        String node = "0";
        this.sslServerConfigs.put("ssl.client.auth", "requested");
        this.createEchoServer(this.sslServerConfigs);
        this.sslClientConfigs.remove("ssl.keystore.location");
        this.sslClientConfigs.remove("ssl.keystore.password");
        this.sslClientConfigs.remove("ssl.key.password");
        this.createSelector(this.sslClientConfigs);
        InetSocketAddress addr = new InetSocketAddress("localhost", this.server.port);
        this.selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
        this.testClientConnection(node, 100, 10);
    }

    /**
     * Configuring a client channel builder with a wrong truststore password must fail.
     * The password is supplied as a {@link Password}, like every other password config in this
     * file, so the expected exception comes from the bad password rather than a type mismatch.
     */
    @Test
    public void testInvalidTruststorePassword() throws Exception {
        SslChannelBuilder clientBuilder = new SslChannelBuilder(Mode.CLIENT);
        try {
            this.sslClientConfigs.put("ssl.truststore.password", new Password("invalid"));
            clientBuilder.configure(this.sslClientConfigs);
            Assert.fail("SSL channel configured with invalid truststore password");
        } catch (KafkaException expected) {
            // Expected: the truststore cannot be opened with the wrong password.
        }
    }

    /**
     * Configuring a client channel builder with a wrong keystore password must fail.
     * Uses a {@link Password} value for the same reason as {@link #testInvalidTruststorePassword}.
     */
    @Test
    public void testInvalidKeystorePassword() throws Exception {
        SslChannelBuilder clientBuilder = new SslChannelBuilder(Mode.CLIENT);
        try {
            this.sslClientConfigs.put("ssl.keystore.password", new Password("invalid"));
            clientBuilder.configure(this.sslClientConfigs);
            Assert.fail("SSL channel configured with invalid keystore password");
        } catch (KafkaException expected) {
            // Expected: the keystore cannot be opened with the wrong password.
        }
    }

    @Test
    public void testInvalidKeyPassword() throws Exception {
        String node = "0";
        this.sslServerConfigs.put("ssl.key.password", new Password("invalid"));
        this.createEchoServer(this.sslServerConfigs);
        this.createSelector(this.sslClientConfigs);
        InetSocketAddress addr = new InetSocketAddress("localhost", this.server.port);
        this.selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
        this.waitForChannelClose(node);
    }

    @Test
    public void testUnsupportedTLSVersion() throws Exception {
        String node = "0";
        this.sslServerConfigs.put("ssl.enabled.protocols", Arrays.asList("TLSv1.2"));
        this.createEchoServer(this.sslServerConfigs);
        this.sslClientConfigs.put("ssl.enabled.protocols", Arrays.asList("TLSv1.1"));
        this.createSelector(this.sslClientConfigs);
        InetSocketAddress addr = new InetSocketAddress("localhost", this.server.port);
        this.selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
        this.waitForChannelClose(node);
    }

    /** Server and client each enable a different single cipher suite, so the handshake must fail. */
    @Test
    public void testUnsupportedCiphers() throws Exception {
        String node = "0";
        String[] cipherSuites = SSLContext.getDefault().getDefaultSSLParameters().getCipherSuites();
        this.sslServerConfigs.put("ssl.cipher.suites", Arrays.asList(cipherSuites[0]));
        this.createEchoServer(this.sslServerConfigs);
        this.sslClientConfigs.put("ssl.cipher.suites", Arrays.asList(cipherSuites[1]));
        this.createSelector(this.sslClientConfigs);
        InetSocketAddress addr = new InetSocketAddress("localhost", this.server.port);
        this.selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
        this.waitForChannelClose(node);
    }

    @Test
    public void testNetReadBufferResize() throws Exception {
        String node = "0";
        this.createEchoServer(this.sslServerConfigs);
        this.createSelector(this.sslClientConfigs, 10, null, null);
        InetSocketAddress addr = new InetSocketAddress("localhost", this.server.port);
        this.selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
        this.testClientConnection(node, 64000, 10);
    }

    @Test
    public void testNetWriteBufferResize() throws Exception {
        String node = "0";
        this.createEchoServer(this.sslServerConfigs);
        this.createSelector(this.sslClientConfigs, null, 10, null);
        InetSocketAddress addr = new InetSocketAddress("localhost", this.server.port);
        this.selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
        this.testClientConnection(node, 64000, 10);
    }

    @Test
    public void testApplicationBufferResize() throws Exception {
        String node = "0";
        this.createEchoServer(this.sslServerConfigs);
        this.createSelector(this.sslClientConfigs, null, null, 10);
        InetSocketAddress addr = new InetSocketAddress("localhost", this.server.port);
        this.selector.connect(node, addr, BUFFER_SIZE, BUFFER_SIZE);
        this.testClientConnection(node, 64000, 10);
    }

    /**
     * Waits for {@code node} to become ready, then exchanges {@code messageCount} messages with
     * the echo server and checks each echoed payload. The next request is sent only after
     * a previous send completes.
     *
     * @param node           connection id passed to {@code Selector.connect}
     * @param minMessageSize length of the random prefix shared by all messages
     * @param messageCount   number of request/response round trips to perform
     */
    private void testClientConnection(String node, int minMessageSize, int messageCount) throws Exception {
        String prefix = TestUtils.randomString(minMessageSize);
        int requests = 0;
        int responses = 0;
        this.waitForChannelReady(node);
        this.selector.send(new NetworkSend(node, new ByteBuffer[]{ByteBuffer.wrap((prefix + "-0").getBytes())}));
        ++requests;
        while (responses < messageCount) {
            this.selector.poll(0L);
            Assert.assertEquals("No disconnects should have occurred.", 0, this.selector.disconnected().size());
            for (NetworkReceive receive : this.selector.completedReceives()) {
                Assert.assertEquals(prefix + "-" + responses, new String(Utils.toArray(receive.payload())));
                ++responses;
            }
            // Issue one new request per completed send, while requests remain and the channel is usable.
            for (int i = 0; i < this.selector.completedSends().size() && requests < messageCount && this.selector.isChannelReady(node); ++i, ++requests) {
                this.selector.send(new NetworkSend(node, new ByteBuffer[]{ByteBuffer.wrap((prefix + "-" + requests).getBytes())}));
            }
        }
    }

    /**
     * Polls until {@code node} completes its handshake. Fails instead of hanging forever if the
     * connection is dropped, or if it is not ready within {@link #CHANNEL_WAIT_MS}.
     */
    private void waitForChannelReady(String node) throws IOException {
        long deadline = System.currentTimeMillis() + CHANNEL_WAIT_MS;
        while (!this.selector.isChannelReady(node)) {
            Assert.assertTrue("Channel " + node + " was not ready within " + CHANNEL_WAIT_MS + " ms",
                    System.currentTimeMillis() < deadline);
            this.selector.poll(1000L);
            Assert.assertFalse("Channel " + node + " was disconnected before becoming ready",
                    this.selector.disconnected().contains(node));
        }
    }

    /** Polls up to 30 times (1s timeout each) and asserts that the selector dropped {@code node}. */
    private void waitForChannelClose(String node) throws IOException {
        boolean closed = false;
        for (int i = 0; i < 30; ++i) {
            this.selector.poll(1000L);
            if (this.selector.channel(node) == null) {
                closed = true;
                break;
            }
        }
        Assert.assertTrue("Channel " + node + " was not closed", closed);
    }

    /** Starts an echo server bound to {@code localhost} on an ephemeral port. */
    private void createEchoServer(Map<String, Object> sslServerConfigs) throws Exception {
        this.server = new SslEchoServer(sslServerConfigs, "localhost");
        this.server.start();
    }

    /** Replaces the client selector with one using the given configs and the real buffer sizes. */
    private void createSelector(Map<String, Object> sslClientConfigs) {
        this.createSelector(sslClientConfigs, null, null, null);
    }

    /**
     * Replaces the client selector with one whose transport layers are {@link TestSslTransportLayer}s.
     * Each non-null size is an initial buffer-size override that grows toward the real size;
     * {@code null} uses the real size. The previous selector is closed so it does not leak.
     */
    private void createSelector(Map<String, Object> sslClientConfigs, final Integer netReadBufSize, final Integer netWriteBufSize, final Integer appBufSize) {
        if (this.selector != null) {
            this.selector.close();
        }
        this.channelBuilder = new SslChannelBuilder(Mode.CLIENT) {

            @Override
            protected SslTransportLayer buildTransportLayer(SslFactory sslFactory, String id, SelectionKey key) throws IOException {
                SocketChannel socketChannel = (SocketChannel) key.channel();
                SSLEngine sslEngine = sslFactory.createSslEngine(socketChannel.socket().getInetAddress().getHostName(), socketChannel.socket().getPort());
                TestSslTransportLayer transportLayer = new TestSslTransportLayer(id, key, sslEngine, netReadBufSize, netWriteBufSize, appBufSize);
                transportLayer.startHandshake();
                return transportLayer;
            }
        };
        this.channelBuilder.configure(sslClientConfigs);
        this.selector = new Selector(5000L, new Metrics(), new MockTime(), "MetricGroup",
                new LinkedHashMap<String, String>(), this.channelBuilder);
    }

    /**
     * SSL echo server. An {@link AcceptorThread} accepts sockets and hands them to this thread,
     * which registers them with a Kafka {@link Selector} and echoes every completed receive back
     * to its source. Echoes for a channel that already has a send in progress are queued in
     * {@code inflightSends}.
     */
    private static class SslEchoServer extends Thread {
        private final int port;
        private final ServerSocketChannel serverSocketChannel;
        private final List<SocketChannel> newChannels;
        private final List<SocketChannel> socketChannels;
        private final AcceptorThread acceptorThread;
        // Configured in the constructor but not otherwise read by the echo loop.
        private final SslFactory sslFactory;
        private final Selector selector;
        private final ConcurrentLinkedQueue<NetworkSend> inflightSends = new ConcurrentLinkedQueue<NetworkSend>();

        /**
         * @param configs    server SSL configs, used for both the SSL factory and the channel builder
         * @param serverHost host/address to bind; the port is chosen by the OS
         */
        public SslEchoServer(Map<String, ?> configs, String serverHost) throws Exception {
            this.sslFactory = new SslFactory(Mode.SERVER);
            this.sslFactory.configure(configs);
            this.serverSocketChannel = ServerSocketChannel.open();
            this.serverSocketChannel.configureBlocking(false);
            this.serverSocketChannel.socket().bind(new InetSocketAddress(serverHost, 0));
            this.port = this.serverSocketChannel.socket().getLocalPort();
            this.socketChannels = Collections.synchronizedList(new ArrayList<SocketChannel>());
            this.newChannels = Collections.synchronizedList(new ArrayList<SocketChannel>());
            SslChannelBuilder channelBuilder = new SslChannelBuilder(Mode.SERVER);
            // Use the configs this server was given, not the enclosing test's field.
            channelBuilder.configure(configs);
            this.selector = new Selector(5000L, new Metrics(), new MockTime(), "MetricGroup",
                    new LinkedHashMap<String, String>(), channelBuilder);
            this.setName("echoserver");
            this.setDaemon(true);
            this.acceptorThread = new AcceptorThread();
        }

        @Override
        public void run() {
            try {
                this.acceptorThread.start();
                while (this.serverSocketChannel.isOpen()) {
                    this.selector.poll(1000L);
                    this.registerNewChannels();
                    this.flushInflightSends();
                    for (NetworkReceive rcv : this.selector.completedReceives()) {
                        NetworkSend echo = new NetworkSend(rcv.source(), new ByteBuffer[]{rcv.payload()});
                        if (this.selector.channel(echo.destination()).hasSend()) {
                            this.inflightSends.add(echo);
                        } else {
                            this.selector.send(echo);
                        }
                    }
                }
            } catch (IOException ignored) {
                // Best effort: an I/O failure (e.g. sockets closed by close()) just ends the echo loop.
            }
        }

        /** Registers channels handed over by the acceptor and tracks them for later closing. */
        private void registerNewChannels() throws IOException {
            // Iterating a synchronizedList requires holding its lock. Clearing under the same lock
            // ensures a channel the acceptor adds mid-iteration is not silently discarded.
            synchronized (this.newChannels) {
                for (SocketChannel socketChannel : this.newChannels) {
                    this.selector.register(this.id(socketChannel), socketChannel);
                    this.socketChannels.add(socketChannel);
                }
                this.newChannels.clear();
            }
        }

        /**
         * Sends queued echoes in order until one is blocked by a send still in progress on its
         * channel. Echoes whose channel no longer exists are dropped, because they can never be delivered.
         */
        private void flushInflightSends() {
            NetworkSend pending;
            while ((pending = this.inflightSends.peek()) != null) {
                if (this.selector.channel(pending.destination()) == null) {
                    this.inflightSends.poll();
                } else if (!this.selector.channel(pending.destination()).hasSend()) {
                    this.selector.send(this.inflightSends.poll());
                } else {
                    break;
                }
            }
        }

        /** Builds the selector id {@code localAddr:localPort-remoteAddr:remotePort} for a channel. */
        private String id(SocketChannel channel) {
            return channel.socket().getLocalAddress().getHostAddress() + ":" + channel.socket().getLocalPort() + "-" + channel.socket().getInetAddress().getHostAddress() + ":" + channel.socket().getPort();
        }

        /** Closes every accepted client socket. */
        public void closeConnections() throws IOException {
            synchronized (this.socketChannels) {
                for (SocketChannel channel : this.socketChannels) {
                    channel.close();
                }
                this.socketChannels.clear();
            }
        }

        /** Stops accepting, closes client sockets, joins both threads, then releases the selector. */
        public void close() throws IOException, InterruptedException {
            this.serverSocketChannel.close();
            this.closeConnections();
            this.acceptorThread.interrupt();
            this.acceptorThread.join();
            this.interrupt();
            this.join();
            // Safe only after join(): the echo thread is the sole user of this selector.
            this.selector.close();
        }

        /** Accepts incoming connections and hands them to the echo thread via {@code newChannels}. */
        private class AcceptorThread extends Thread {
            public AcceptorThread() throws IOException {
                this.setName("acceptor");
            }

            @Override
            public void run() {
                // try-with-resources closes the NIO selector, which also releases the server socket's key.
                try (java.nio.channels.Selector acceptSelector = java.nio.channels.Selector.open()) {
                    SslEchoServer.this.serverSocketChannel.register(acceptSelector, SelectionKey.OP_ACCEPT);
                    while (SslEchoServer.this.serverSocketChannel.isOpen()) {
                        if (acceptSelector.select(1000L) <= 0) {
                            continue;
                        }
                        for (SelectionKey key : acceptSelector.selectedKeys()) {
                            if (!key.isAcceptable()) {
                                continue;
                            }
                            SocketChannel socketChannel = ((ServerSocketChannel) key.channel()).accept();
                            if (socketChannel == null) {
                                // Non-blocking accept() may find no pending connection.
                                continue;
                            }
                            socketChannel.configureBlocking(false);
                            SslEchoServer.this.newChannels.add(socketChannel);
                            SslEchoServer.this.selector.wakeup();
                        }
                        // Handled keys must be removed, or later select() calls report them as not updated.
                        acceptSelector.selectedKeys().clear();
                    }
                } catch (IOException ignored) {
                    // Best effort: the server socket was closed or the selector failed; stop accepting.
                }
            }
        }
    }

    /**
     * {@link SslTransportLayer} whose buffer sizes can start artificially small. Each override
     * doubles on every update, capped at the real size, to exercise the buffer-resize paths.
     */
    private static class TestSslTransportLayer extends SslTransportLayer {
        private final ResizeableBufferSize netReadBufSize;
        private final ResizeableBufferSize netWriteBufSize;
        private final ResizeableBufferSize appBufSize;

        public TestSslTransportLayer(String channelId, SelectionKey key, SSLEngine sslEngine, Integer netReadBufSize, Integer netWriteBufSize, Integer appBufSize) throws IOException {
            super(channelId, key, sslEngine, false);
            this.netReadBufSize = new ResizeableBufferSize(netReadBufSize);
            this.netWriteBufSize = new ResizeableBufferSize(netWriteBufSize);
            this.appBufSize = new ResizeableBufferSize(appBufSize);
        }

        /** Grows the read-size override only when the current net read buffer exists and is full. */
        @Override
        protected int netReadBufferSize() {
            ByteBuffer netReadBuffer = this.netReadBuffer();
            boolean updateBufSize = netReadBuffer != null && !netReadBuffer.hasRemaining();
            return this.netReadBufSize.updateAndGet(super.netReadBufferSize(), updateBufSize);
        }

        @Override
        protected int netWriteBufferSize() {
            return this.netWriteBufSize.updateAndGet(super.netWriteBufferSize(), true);
        }

        @Override
        protected int applicationBufferSize() {
            return this.appBufSize.updateAndGet(super.applicationBufferSize(), true);
        }

        /** Optional buffer-size override; {@code null} means "always use the real size". */
        private static class ResizeableBufferSize {
            private Integer bufSizeOverride;

            ResizeableBufferSize(Integer bufSizeOverride) {
                this.bufSizeOverride = bufSizeOverride;
            }

            /**
             * @param actualSize the real buffer size reported by the superclass
             * @param update     whether to double the override (capped at {@code actualSize}) first
             * @return the override if one is set, otherwise {@code actualSize}
             */
            int updateAndGet(int actualSize, boolean update) {
                int size = actualSize;
                if (this.bufSizeOverride != null) {
                    if (update) {
                        this.bufSizeOverride = Math.min(this.bufSizeOverride * 2, size);
                    }
                    size = this.bufSizeOverride;
                }
                return size;
            }
        }
    }

    /** Keystore/truststore configuration for one side (server or client) of a test connection. */
    private static class CertStores {
        Map<String, Object> sslConfig;

        CertStores(boolean server) throws Exception {
            String name = server ? "server" : "client";
            Mode mode = server ? Mode.SERVER : Mode.CLIENT;
            File truststoreFile = File.createTempFile(name + "TS", ".jks");
            truststoreFile.deleteOnExit();
            this.sslConfig = TestSslUtils.createSslConfig(!server, true, mode, truststoreFile, name);
            if (server) {
                this.sslConfig.put("principal.builder.class", Class.forName("org.apache.kafka.common.security.auth.DefaultPrincipalBuilder"));
            }
        }

        /** Returns a copy of this side's config whose truststore settings come from {@code truststoreConfig}. */
        private Map<String, Object> getTrustingConfig(CertStores truststoreConfig) {
            Map<String, Object> config = new HashMap<String, Object>(this.sslConfig);
            config.put("ssl.truststore.location", truststoreConfig.sslConfig.get("ssl.truststore.location"));
            config.put("ssl.truststore.password", truststoreConfig.sslConfig.get("ssl.truststore.password"));
            config.put("ssl.truststore.type", truststoreConfig.sslConfig.get("ssl.truststore.type"));
            return config;
        }

        /**
         * Returns a copy of this side's own config, which keeps its own truststore and so does not
         * trust the peer. A copy is returned so callers can mutate it without altering this object.
         */
        private Map<String, Object> getUntrustingConfig() {
            return new HashMap<String, Object>(this.sslConfig);
        }
    }
}

