/*
 * Decompiled with CFR 0.152.
 */
package alluxio.grpc;

import alluxio.conf.AlluxioConfiguration;
import alluxio.conf.PropertyKey;
import alluxio.grpc.GrpcChannelKey;
import alluxio.grpc.GrpcConnection;
import alluxio.grpc.GrpcConnectionKey;
import alluxio.grpc.GrpcNetworkGroup;
import alluxio.network.ChannelType;
import alluxio.util.CommonUtils;
import alluxio.util.WaitForOptions;
import alluxio.util.network.NettyUtils;
import alluxio.util.network.tls.SslContextProvider;
import com.google.common.base.Preconditions;
import io.grpc.ConnectivityState;
import io.grpc.ManagedChannel;
import io.grpc.netty.NettyChannelBuilder;
import io.netty.channel.Channel;
import io.netty.channel.EventLoopGroup;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import javax.annotation.concurrent.ThreadSafe;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@ThreadSafe
public class GrpcConnectionPool {
    private static final Logger LOG = LoggerFactory.getLogger(GrpcConnectionPool.class);
    public static final GrpcConnectionPool INSTANCE = new GrpcConnectionPool();
    private ConcurrentMap<GrpcConnectionKey, CountingReference<ManagedChannel>> mChannels = new ConcurrentHashMap<GrpcConnectionKey, CountingReference<ManagedChannel>>();
    private ConcurrentMap<GrpcNetworkGroup, CountingReference<EventLoopGroup>> mEventLoops = new ConcurrentHashMap<GrpcNetworkGroup, CountingReference<EventLoopGroup>>();
    private ConcurrentMap<GrpcNetworkGroup, AtomicLong> mNetworkGroupCounters = new ConcurrentHashMap<GrpcNetworkGroup, AtomicLong>();
    private SslContextProvider mSslContextProvider;

    public GrpcConnectionPool() {
        for (GrpcNetworkGroup group : GrpcNetworkGroup.values()) {
            this.mNetworkGroupCounters.put(group, new AtomicLong());
        }
    }

    private synchronized SslContextProvider getSslContextProvider(AlluxioConfiguration conf) {
        if (this.mSslContextProvider == null) {
            this.mSslContextProvider = SslContextProvider.Factory.create(conf);
        }
        return this.mSslContextProvider;
    }

    public GrpcConnection acquireConnection(GrpcChannelKey channelKey, AlluxioConfiguration conf) {
        GrpcConnectionKey connectionKey = this.getConnectionKey(channelKey, conf);
        CountingReference connectionRef = this.mChannels.compute(connectionKey, (key, ref) -> {
            boolean shutdownExistingConnection = false;
            int existingRefCount = 0;
            if (ref != null) {
                if (this.waitForConnectionReady((ManagedChannel)((CountingReference)ref).get(), conf)) {
                    LOG.debug("Acquiring an existing connection. ConnectionKey: {}. Ref-count: {}", key, (Object)((CountingReference)ref).getRefCount());
                    return ((CountingReference)ref).reference();
                }
                shutdownExistingConnection = true;
            }
            if (shutdownExistingConnection) {
                existingRefCount = ((CountingReference)ref).getRefCount();
                LOG.debug("Shutting down an existing unhealthy connection. ConnectionKey: {}. Ref-count: {}", key, (Object)existingRefCount);
                this.shutdownManagedChannel((ManagedChannel)((CountingReference)ref).get(), conf);
            }
            LOG.debug("Creating a new managed channel. ConnectionKey: {}. Ref-count:{}", key, (Object)existingRefCount);
            ManagedChannel managedChannel = this.createManagedChannel(channelKey, conf);
            return new CountingReference(managedChannel, existingRefCount).reference();
        });
        return new GrpcConnection(connectionKey, (ManagedChannel)connectionRef.get(), conf);
    }

    public void releaseConnection(GrpcConnectionKey connectionKey, AlluxioConfiguration conf) {
        this.mChannels.compute(connectionKey, (key, ref) -> {
            Preconditions.checkNotNull((Object)ref, (Object)"Cannot release nonexistent connection");
            LOG.debug("Releasing connection for: {}. Ref-count: {}", key, (Object)((CountingReference)ref).getRefCount());
            if (((CountingReference)ref).dereference() == 0) {
                LOG.debug("Shutting down connection after: {}", (Object)connectionKey);
                this.shutdownManagedChannel((ManagedChannel)((CountingReference)ref).get(), conf);
                this.releaseNetworkEventLoop(connectionKey.getChannelKey());
                return null;
            }
            return ref;
        });
    }

    private GrpcConnectionKey getConnectionKey(GrpcChannelKey channelKey, AlluxioConfiguration conf) {
        long groupIndex = ((AtomicLong)this.mNetworkGroupCounters.get((Object)channelKey.getNetworkGroup())).incrementAndGet();
        long maxConnectionsForGroup = conf.getLong(PropertyKey.Template.USER_NETWORK_MAX_CONNECTIONS.format(channelKey.getNetworkGroup().getPropertyCode()));
        return new GrpcConnectionKey(channelKey, (int)(groupIndex %= maxConnectionsForGroup));
    }

    private ManagedChannel createManagedChannel(GrpcChannelKey channelKey, AlluxioConfiguration conf) {
        NettyChannelBuilder channelBuilder;
        SocketAddress address = channelKey.getServerAddress().getSocketAddress();
        if (address instanceof InetSocketAddress) {
            InetSocketAddress inetServerAddress = (InetSocketAddress)address;
            channelBuilder = NettyChannelBuilder.forAddress((String)inetServerAddress.getHostName(), (int)inetServerAddress.getPort());
        } else {
            channelBuilder = NettyChannelBuilder.forAddress((SocketAddress)address);
        }
        channelBuilder = this.applyGroupDefaults(channelKey, channelBuilder, conf);
        return channelBuilder.build();
    }

    private NettyChannelBuilder applyGroupDefaults(GrpcChannelKey key, NettyChannelBuilder channelBuilder, AlluxioConfiguration conf) {
        long keepAliveTimeMs = conf.getMs(PropertyKey.Template.USER_NETWORK_KEEPALIVE_TIME_MS.format(key.getNetworkGroup().getPropertyCode()));
        long keepAliveTimeoutMs = conf.getMs(PropertyKey.Template.USER_NETWORK_KEEPALIVE_TIMEOUT_MS.format(key.getNetworkGroup().getPropertyCode()));
        long inboundMessageSizeBytes = conf.getBytes(PropertyKey.Template.USER_NETWORK_MAX_INBOUND_MESSAGE_SIZE.format(key.getNetworkGroup().getPropertyCode()));
        long flowControlWindow = conf.getBytes(PropertyKey.Template.USER_NETWORK_FLOWCONTROL_WINDOW.format(key.getNetworkGroup().getPropertyCode()));
        Class<? extends Channel> channelType = NettyUtils.getChannelClass(!(key.getServerAddress().getSocketAddress() instanceof InetSocketAddress), PropertyKey.Template.USER_NETWORK_NETTY_CHANNEL.format(key.getNetworkGroup().getPropertyCode()), conf);
        EventLoopGroup eventLoopGroup = this.acquireNetworkEventLoop(key, conf);
        channelBuilder.keepAliveTime(keepAliveTimeMs, TimeUnit.MILLISECONDS);
        channelBuilder.keepAliveTimeout(keepAliveTimeoutMs, TimeUnit.MILLISECONDS);
        channelBuilder.maxInboundMessageSize((int)inboundMessageSizeBytes);
        channelBuilder.flowControlWindow((int)flowControlWindow);
        channelBuilder.channelType(channelType);
        channelBuilder.eventLoopGroup(eventLoopGroup);
        channelBuilder.usePlaintext();
        if (key.getNetworkGroup() == GrpcNetworkGroup.SECRET) {
            channelBuilder.sslContext(this.getSslContextProvider(conf).getSelfSignedClientSslContext());
            channelBuilder.useTransportSecurity();
        } else if (conf.getBoolean(PropertyKey.NETWORK_TLS_ENABLED)) {
            channelBuilder.sslContext(this.getSslContextProvider(conf).getClientSslContext());
            channelBuilder.useTransportSecurity();
        }
        return channelBuilder;
    }

    private boolean waitForConnectionReady(ManagedChannel managedChannel, AlluxioConfiguration conf) {
        long healthCheckTimeoutMs = conf.getMs(PropertyKey.NETWORK_CONNECTION_HEALTH_CHECK_TIMEOUT);
        try {
            Boolean res = CommonUtils.waitForResult("channel to be ready", () -> {
                ConnectivityState currentState = managedChannel.getState(true);
                switch (currentState) {
                    case READY: {
                        return true;
                    }
                    case TRANSIENT_FAILURE: 
                    case SHUTDOWN: {
                        return false;
                    }
                    case IDLE: 
                    case CONNECTING: {
                        return null;
                    }
                }
                return null;
            }, b -> b != null, WaitForOptions.defaults().setTimeoutMs((int)healthCheckTimeoutMs));
            return res;
        }
        catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        }
        catch (TimeoutException e) {
            return false;
        }
    }

    private void shutdownManagedChannel(ManagedChannel managedChannel, AlluxioConfiguration conf) {
        if (!managedChannel.isShutdown()) {
            long gracefulTimeoutMs = conf.getMs(PropertyKey.NETWORK_CONNECTION_SHUTDOWN_GRACEFUL_TIMEOUT);
            managedChannel.shutdown();
            try {
                if (!managedChannel.awaitTermination(gracefulTimeoutMs, TimeUnit.MILLISECONDS)) {
                    LOG.warn("Timed out gracefully shutting down connection: {}. ", (Object)managedChannel);
                }
            }
            catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }
        if (!managedChannel.isTerminated()) {
            long timeoutMs = conf.getMs(PropertyKey.NETWORK_CONNECTION_SHUTDOWN_TIMEOUT);
            managedChannel.shutdownNow();
            try {
                if (!managedChannel.awaitTermination(timeoutMs, TimeUnit.MILLISECONDS)) {
                    LOG.warn("Timed out forcefully shutting down connection: {}. ", (Object)managedChannel);
                }
            }
            catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }
    }

    private EventLoopGroup acquireNetworkEventLoop(GrpcChannelKey channelKey, AlluxioConfiguration conf) {
        return (EventLoopGroup)this.mEventLoops.compute(channelKey.getNetworkGroup(), (key, v) -> {
            if (v != null) {
                GrpcConnectionPool.LOG.debug("Acquiring an existing event-loop for {}. Ref-Count:{}", (Object)channelKey, (Object)((CountingReference)v).getRefCount());
                ((CountingReference)v).reference();
                return v;
            }
            ChannelType nettyChannelType = NettyUtils.getChannelType(PropertyKey.Template.USER_NETWORK_NETTY_CHANNEL.format(key.getPropertyCode()), conf);
            int nettyWorkerThreadCount = conf.getInt(PropertyKey.Template.USER_NETWORK_NETTY_WORKER_THREADS.format(key.getPropertyCode()));
            v = new CountingReference(NettyUtils.createEventLoop(nettyChannelType, nettyWorkerThreadCount, String.format("alluxio-client-netty-event-loop-%s-%%d", key.name()), true), 1);
            GrpcConnectionPool.LOG.debug("Created a new event loop. NetworkGroup: {}. NettyChannelType: {}, NettyThreadCount: {}", new Object[]{key, nettyChannelType, nettyWorkerThreadCount});
            return v;
        }).get();
    }

    private void releaseNetworkEventLoop(GrpcChannelKey channelKey) {
        this.mEventLoops.compute(channelKey.getNetworkGroup(), (key, ref) -> {
            Preconditions.checkNotNull((Object)ref, (Object)"Cannot release nonexistent event-loop");
            LOG.debug("Releasing event-loop for: {}. Ref-count: {}", (Object)channelKey, (Object)((CountingReference)ref).getRefCount());
            if (((CountingReference)ref).dereference() == 0) {
                LOG.debug("Shutting down event-loop: {}", ((CountingReference)ref).get());
                ((EventLoopGroup)((CountingReference)ref).get()).shutdownGracefully();
                return null;
            }
            return ref;
        });
    }

    private class CountingReference<T> {
        private T mObject;
        private AtomicInteger mRefCount;

        private CountingReference(T object, int initialRefCount) {
            this.mObject = object;
            this.mRefCount = new AtomicInteger(initialRefCount);
        }

        private CountingReference reference() {
            this.mRefCount.incrementAndGet();
            return this;
        }

        private int dereference() {
            return this.mRefCount.decrementAndGet();
        }

        private int getRefCount() {
            return this.mRefCount.get();
        }

        private T get() {
            return this.mObject;
        }
    }
}

