package org.apache.sysds.runtime.controlprogram.federated;

import io.netty.bootstrap.Bootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.serialization.ObjectEncoder;
import io.netty.handler.ssl.SslContext;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslHandler;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.handler.timeout.ReadTimeoutHandler;
import io.netty.util.concurrent.Promise;
import java.io.Serializable;
import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Set;
import java.util.concurrent.Future;
import javax.net.ssl.SSLException;
import org.antlr.v4.runtime.atn.PredictionContext;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.paramserv.NetworkTrafficCounter;
import org.apache.sysds.runtime.meta.MetaData;

/* loaded from: input_file:org/apache/sysds/runtime/controlprogram/federated/FederatedData.class */
public class FederatedData {
    /** Class-wide logger. */
    private static final Log LOG = LogFactory.getLog(FederatedData.class.getName());
    /**
     * Addresses recorded by the three-argument constructor; iterated by
     * {@code clearFederatedWorkers()} and emptied by {@code resetFederatedSites()}.
     */
    private static final Set<InetSocketAddress> _allFedSites = new HashSet();
    /** Shared Netty event loop group; created by {@code createWorkGroup()}, released by {@code clearWorkGroup()}. */
    private static EventLoopGroup workerGroup = null;
    /** Client-side SSL context holder consulted by {@code SslConstructor()}. */
    private static SslContextMan sslInstance = null;
    /** Kind of data this handle refers to; the init methods accept only matrix and frame types. */
    private final Types.DataType _dataType;
    /** Address of the federated worker; may be null (the constructors do not reject it). */
    private final InetSocketAddress _address;
    /** File path sent as a parameter of the READ_VAR request. */
    private final String _filepath;
    /** Variable ID of the data; {@code -1} means "not initialized yet" (see {@code isInitialized()}). */
    private long _varID;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysds/runtime/controlprogram/federated/FederatedData$DataRequestHandler.class */
    public static class DataRequestHandler extends ChannelInboundHandlerAdapter {
        private Promise<FederatedResponse> _prom;

        public void setPromise(Promise<FederatedResponse> promise) {
            this._prom = promise;
        }

        public void channelRead(ChannelHandlerContext channelHandlerContext, Object obj) {
            this._prom.setSuccess((FederatedResponse) obj);
            channelHandlerContext.close();
        }

        public Promise<FederatedResponse> getProm() {
            return this._prom;
        }
    }

    /* loaded from: input_file:org/apache/sysds/runtime/controlprogram/federated/FederatedData$FederatedRequestEncoder.class */
    public static class FederatedRequestEncoder extends ObjectEncoder {
        /* JADX INFO: Access modifiers changed from: protected */
        /* JADX WARN: Multi-variable type inference failed */
        public ByteBuf allocateBuffer(ChannelHandlerContext channelHandlerContext, Serializable serializable, boolean z) throws Exception {
            int i = 256;
            if (serializable instanceof FederatedRequest[]) {
                i = 0;
                try {
                    for (FederatedRequest federatedRequest : (FederatedRequest[]) serializable) {
                        int intExact = Math.toIntExact(federatedRequest.estimateSerializationBufferSize());
                        if (PredictionContext.EMPTY_RETURN_STATE - i < intExact) {
                            throw new ArithmeticException("Overflow.");
                        }
                        i += intExact;
                    }
                } catch (ArithmeticException e) {
                    i = Integer.MAX_VALUE;
                }
            }
            return z ? channelHandlerContext.alloc().ioBuffer(i) : channelHandlerContext.alloc().heapBuffer(i);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/apache/sysds/runtime/controlprogram/federated/FederatedData$SslContextMan.class */
    /**
     * Holder for the client-side {@link SslContext}.
     *
     * <p>NOTE(review): the context is built with {@code InsecureTrustManagerFactory},
     * i.e. server certificates are not verified.
     */
    public static class SslContextMan {
        protected final SslContext context;

        private SslContextMan() {
            this.context = buildClientContext();
        }

        /**
         * Builds the client SSL context.
         *
         * @return the freshly built context
         * @throws DMLRuntimeException if the context cannot be created
         */
        private static SslContext buildClientContext() {
            try {
                SslContextBuilder builder = SslContextBuilder.forClient();
                return builder.trustManager(InsecureTrustManagerFactory.INSTANCE).build();
            } catch (SSLException sslFailure) {
                throw new DMLRuntimeException("Static SSL setup failed for client side", sslFailure);
            }
        }
    }

    /**
     * Creates an uninitialized handle (variable ID {@code -1}) and records its address
     * in the set of known federated sites.
     *
     * @param dataType type of the federated data
     * @param inetSocketAddress address of the federated worker; null is accepted and not recorded
     * @param str file path of the data on the worker
     */
    public FederatedData(Types.DataType dataType, InetSocketAddress inetSocketAddress, String str) {
        this(dataType, inetSocketAddress, str, -1L);
        if (inetSocketAddress != null) {
            _allFedSites.add(inetSocketAddress);
        }
    }

    /**
     * Creates a handle with an explicit variable ID. Unlike the three-argument
     * constructor, the address is not added to the set of known federated sites.
     *
     * @param dataType type of the federated data
     * @param inetSocketAddress address of the federated worker
     * @param str file path of the data on the worker
     * @param j variable ID of the data
     */
    public FederatedData(Types.DataType dataType, InetSocketAddress inetSocketAddress, String str, long j) {
        _dataType = dataType;
        _address = inetSocketAddress;
        _filepath = str;
        _varID = j;
    }

    /**
     * @return the address of the federated worker, possibly null
     */
    public InetSocketAddress getAddress() {
        return _address;
    }

    /**
     * Overwrites the variable ID of this handle.
     *
     * @param j the new variable ID
     */
    public void setVarID(long j) {
        _varID = j;
    }

    /**
     * @return the variable ID, or {@code -1} while the handle is uninitialized
     */
    public long getVarID() {
        return _varID;
    }

    /**
     * @return the file path given at construction
     */
    public String getFilepath() {
        return _filepath;
    }

    /**
     * @return the data type given at construction
     */
    public Types.DataType getDataType() {
        return _dataType;
    }

    /**
     * @return true once a variable ID other than the {@code -1} sentinel has been assigned
     */
    public boolean isInitialized() {
        final long uninitialized = -1L;
        return _varID != uninitialized;
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Checks whether this handle and another one point at the same worker address.
     *
     * @param federatedData the handle to compare with; may be null
     * @return true only if both addresses are non-null and equal
     */
    public boolean equalAddress(FederatedData federatedData) {
        if (federatedData == null || _address == null) {
            return false;
        }
        // equals(null) is false, which covers a missing address on the other side.
        return _address.equals(federatedData._address);
    }

    /**
     * Creates a handle with the same data type, address and file path but a different
     * variable ID.
     *
     * @param j variable ID of the copy
     * @return the new handle
     */
    public FederatedData copyWithNewID(long j) {
        FederatedData copy = new FederatedData(_dataType, _address, _filepath);
        copy._varID = j;
        return copy;
    }

    /**
     * Initializes the data on the worker without metadata.
     *
     * @param j variable ID to assign
     * @return future holding the worker's response
     */
    public synchronized Future<FederatedResponse> initFederatedData(long j) {
        final MetaData noMetaData = null;
        return initFederatedData(j, noMetaData);
    }

    /**
     * Assigns the variable ID and sends a READ_VAR request carrying the file path and
     * data type name to the worker.
     *
     * @param j variable ID to assign
     * @param metaData optional metadata passed along with the request; may be null
     * @return future holding the worker's response
     * @throws DMLRuntimeException if the handle is already initialized or the data type
     *         is neither matrix nor frame
     */
    public synchronized Future<FederatedResponse> initFederatedData(long j, MetaData metaData) {
        if (isInitialized()) {
            throw new DMLRuntimeException("Tried to init already initialized data");
        }
        boolean supported = _dataType.isMatrix() || _dataType.isFrame();
        if (!supported) {
            throw new DMLRuntimeException("Federated datatype \"" + _dataType.toString() + "\" is not supported.");
        }
        _varID = j;

        final FederatedRequest readRequest;
        if (metaData == null) {
            readRequest = new FederatedRequest(FederatedRequest.RequestType.READ_VAR, j);
        } else {
            readRequest = new FederatedRequest(FederatedRequest.RequestType.READ_VAR, j, metaData);
        }
        readRequest.appendParam(_filepath);
        readRequest.appendParam(_dataType.name());
        return executeFederatedOperation(readRequest);
    }

    /**
     * Assigns the variable ID and sends a READ_VAR request that carries the given
     * block in addition to the file path and data type name.
     *
     * @param j variable ID to assign
     * @param cacheBlock block appended as the last request parameter
     * @return future holding the worker's response
     * @throws DMLRuntimeException if the handle is already initialized or the data type
     *         is neither matrix nor frame
     */
    public synchronized Future<FederatedResponse> initFederatedDataFromLocal(long j, CacheBlock cacheBlock) {
        if (isInitialized()) {
            throw new DMLRuntimeException("Tried to init already initialized data");
        }
        boolean supported = _dataType.isMatrix() || _dataType.isFrame();
        if (!supported) {
            throw new DMLRuntimeException("Federated datatype \"" + _dataType.toString() + "\" is not supported.");
        }
        _varID = j;

        FederatedRequest readRequest = new FederatedRequest(FederatedRequest.RequestType.READ_VAR, j);
        readRequest.appendParam(_filepath);
        readRequest.appendParam(_dataType.name());
        readRequest.appendParam(cacheBlock);
        return executeFederatedOperation(readRequest);
    }

    /**
     * Sends the given requests to this handle's worker address.
     *
     * @param federatedRequestArr requests to send as one message
     * @return future holding the worker's response
     */
    public Future<FederatedResponse> executeFederatedOperation(FederatedRequest... federatedRequestArr) {
        final InetSocketAddress target = _address;
        return executeFederatedOperation(target, federatedRequestArr);
    }

    /**
     * Opens a new connection to the given worker, sends the requests as a single
     * message and returns a future that is completed with the worker's response.
     *
     * <p>The call blocks until the connection is established. If the message cannot
     * be written, the returned future is failed with the write error and the channel
     * is closed, so callers never wait on a request that was not sent.
     *
     * @param inetSocketAddress address of the federated worker
     * @param federatedRequestArr requests to send as one message
     * @return future holding the worker's response
     * @throws DMLRuntimeException if connecting or sending fails, or if the calling
     *         thread is interrupted while connecting (the interrupt flag is restored)
     */
    public static synchronized Future<FederatedResponse> executeFederatedOperation(InetSocketAddress inetSocketAddress, FederatedRequest... federatedRequestArr) {
        try {
            Bootstrap bootstrap = new Bootstrap();
            if (workerGroup == null) {
                createWorkGroup();
            }
            bootstrap.group(workerGroup);
            bootstrap.channel(NioSocketChannel.class);
            final DataRequestHandler dataRequestHandler = new DataRequestHandler();
            bootstrap.handler(createChannel(inetSocketAddress, dataRequestHandler));

            final ChannelFuture connected = bootstrap.connect(inetSocketAddress).sync();
            final Promise<FederatedResponse> promise = connected.channel().eventLoop().newPromise();
            dataRequestHandler.setPromise(promise);

            connected.channel().writeAndFlush(federatedRequestArr).addListener(written -> {
                if (!written.isSuccess()) {
                    // Without this, a failed write would leave the promise pending forever.
                    promise.tryFailure(written.cause());
                    connected.channel().close();
                }
            });
            return promise;
        } catch (InterruptedException e) {
            // Preserve the interruption for callers further up the stack.
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException("Failed sending federated operation", e);
        } catch (Exception e) {
            throw new DMLRuntimeException("Failed sending federated operation", e);
        }
    }

    /**
     * Builds the channel initializer for a connection to a federated worker.
     *
     * <p>Handlers are added in this order: traffic counter, optional SSL handler,
     * optional read timeout, then decoder, request encoder and the response handler.
     * The timeout and SSL settings are read once, when this method is called.
     *
     * @param inetSocketAddress worker address, used for the SSL handler
     * @param dataRequestHandler handler that receives the worker's response
     * @return initializer to install on the bootstrap
     */
    private static ChannelInitializer<SocketChannel> createChannel(final InetSocketAddress inetSocketAddress, final DataRequestHandler dataRequestHandler) {
        final int timeout = ConfigurationManager.getFederatedTimeout();
        final boolean useSsl = ConfigurationManager.isFederatedSSL();
        return new ChannelInitializer<SocketChannel>() {
            @Override
            public void initChannel(SocketChannel socketChannel) throws Exception {
                final ChannelPipeline handlers = socketChannel.pipeline();
                handlers.addLast("NetworkTrafficCounter", new NetworkTrafficCounter(FederatedStatistics::logServerTraffic));
                if (useSsl) {
                    handlers.addLast(createSSLHandler(socketChannel, inetSocketAddress));
                }
                // A timeout of -1 or lower disables the read timeout handler.
                if (timeout > -1) {
                    handlers.addLast(new ReadTimeoutHandler(timeout));
                }
                handlers.addLast(FederationUtils.decoder(), new FederatedRequestEncoder(), dataRequestHandler);
            }
        };
    }

    /**
     * Sends a CLEAR request to every recorded federated site and waits for the
     * responses. Failures are logged as warnings; the set of recorded sites is reset
     * in every case.
     */
    public static void clearFederatedWorkers() {
        if (_allFedSites.isEmpty()) {
            return;
        }
        try {
            final FederatedRequest clearRequest = new FederatedRequest(FederatedRequest.RequestType.CLEAR);
            final ArrayList<Future<FederatedResponse>> pending = new ArrayList<>();
            for (InetSocketAddress site : _allFedSites) {
                pending.add(executeFederatedOperation(site, clearRequest));
            }
            FederationUtils.waitFor(pending);
        } catch (Exception e) {
            LOG.warn("Failed to execute CLEAR request on existing federated sites.", e);
        } finally {
            resetFederatedSites();
        }
    }

    /**
     * Creates the SSL handler for a connection to the given worker.
     *
     * @param socketChannel channel whose allocator is used by the handler
     * @param inetSocketAddress worker address; its host address and port are passed to the handler
     * @return a new SSL handler
     */
    private static SslHandler createSSLHandler(SocketChannel socketChannel, InetSocketAddress inetSocketAddress) {
        final SslContext clientContext = SslConstructor().context;
        return clientContext.newHandler(
            socketChannel.alloc(),
            inetSocketAddress.getAddress().getHostAddress(),
            inetSocketAddress.getPort());
    }

    /**
     * Forgets all recorded federated site addresses.
     */
    public static void resetFederatedSites() {
        FederatedData._allFedSites.clear();
    }

    /**
     * Shuts down the shared event loop group, if any, and drops the reference so that
     * the next send creates a fresh group.
     *
     * <p>Synchronized on the class, like {@code createWorkGroup()} and the static
     * {@code executeFederatedOperation}, so the group cannot be nulled out between a
     * sender's null check and its use of the group.
     */
    public static synchronized void clearWorkGroup() {
        if (workerGroup != null) {
            // Only requests the shutdown; the returned future is not awaited here.
            workerGroup.shutdownGracefully();
            workerGroup = null;
        }
    }

    /**
     * Creates the shared event loop group with 8 threads unless one already exists.
     */
    public static synchronized void createWorkGroup() {
        if (workerGroup != null) {
            return;
        }
        workerGroup = new NioEventLoopGroup(8);
    }

    /**
     * Returns the shared SSL context holder, creating and caching it on first use so
     * that the SSL context is not rebuilt for every connection.
     *
     * <p>Deliberately does not lock on the class: this is reached from the channel
     * initializer, while a sender may be holding the class monitor inside the
     * synchronized {@code executeFederatedOperation}. A race can at worst build more
     * than one holder; whichever reference is read stays usable because the holder's
     * only field is final.
     *
     * @return the cached holder
     * @throws DMLRuntimeException if the SSL context cannot be created
     */
    private static SslContextMan SslConstructor() {
        SslContextMan holder = sslInstance;
        if (holder == null) {
            holder = new SslContextMan();
            sslInstance = holder;
        }
        return holder;
    }

    /**
     * Renders the handle as {@code "<SimpleClassName> <dataType> <address>:<filepath>"}.
     * A missing address is printed as {@code null} instead of raising a
     * {@link NullPointerException}.
     *
     * @return human-readable description of this handle
     */
    @Override
    public String toString() {
        // String concatenation is null-safe, unlike an explicit _address.toString().
        return getClass().getSimpleName() + " " + _dataType + " " + _address + ":" + _filepath;
    }
}
