package org.apache.sysds.runtime.controlprogram.federated;

import java.io.Serializable;
import java.net.InetSocketAddress;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.Future;
import javax.net.ssl.SSLException;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheStatistics;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.util.ProgramConverter;
import org.apache.sysds.utils.Statistics;

/**
 * Collects and displays runtime statistics (cache, JIT compile, JVM GC and heavy-hitter
 * instructions) of all registered federated workers. Each worker is sent a
 * {@link FedStatsCollectFunction} UDF; the {@link FedStatsCollection}s returned by the workers are
 * summed into a single collection that is then rendered as text.
 */
public class FederatedStatistics {
    /**
     * Host/port pairs of all federated workers registered via {@link #registerFedWorker}.
     * NOTE(review): this set is not synchronized; confirm registration never happens concurrently.
     */
    private static final Set<Pair<String, Integer>> _fedWorkerAddresses = new HashSet<>();

    /**
     * Federated UDF that, when executed on a worker, snapshots that worker's local statistics into
     * a {@link FedStatsCollection} and returns it as the payload of a successful response.
     */
    public static class FedStatsCollectFunction extends FederatedUDF {
        private static final long serialVersionUID = 1L;

        public FedStatsCollectFunction() {
            // the UDF does not operate on any federated input objects
            super(new long[0]);
        }

        @Override
        public FederatedResponse execute(ExecutionContext ec, Data... data) {
            FedStatsCollection stats = new FedStatsCollection();
            stats.collectStats();
            return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, stats);
        }

        @Override
        public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
            // statistics collection is not lineage-traced
            return null;
        }
    }

    /**
     * Serializable snapshot of one worker's statistics. A freshly constructed instance has all
     * counters at zero, so it also serves as the accumulator when summing the snapshots of several
     * workers via {@link #aggregate(FedStatsCollection)}.
     */
    public static class FedStatsCollection implements Serializable {
        private static final long serialVersionUID = 1L;

        private CacheStatsCollection cacheStats = new CacheStatsCollection();
        /** JIT compile time in seconds. */
        private double jitCompileTime = 0;
        private GCStatsCollection gcStats = new GCStatsCollection();
        /** Instruction name -> (execution count, time); the time is displayed as seconds. */
        private HashMap<String, Pair<Long, Double>> heavyHitters = new HashMap<>();

        /** Buffer-pool cache hit/write counters and cache timings (timings in seconds). */
        public static class CacheStatsCollection implements Serializable {
            private static final long serialVersionUID = 1L;

            private long memHits = 0;
            private long linHits = 0;
            private long fsBuffHits = 0;
            private long fsHits = 0;
            private long hdfsHits = 0;
            private long linWrites = 0;
            private long fsBuffWrites = 0;
            private long fsWrites = 0;
            private long hdfsWrites = 0;
            private double acqRTime = 0;
            private double acqMTime = 0;
            private double rlsTime = 0;
            private double expTime = 0;

            protected CacheStatsCollection() {
            }

            /** Overwrites all fields with the current values from {@link CacheStatistics}. */
            public void collectStats() {
                memHits = CacheStatistics.getMemHits();
                linHits = CacheStatistics.getLinHits();
                fsBuffHits = CacheStatistics.getFSBuffHits();
                fsHits = CacheStatistics.getFSHits();
                hdfsHits = CacheStatistics.getHDFSHits();
                linWrites = CacheStatistics.getLinWrites();
                fsBuffWrites = CacheStatistics.getFSBuffWrites();
                fsWrites = CacheStatistics.getFSWrites();
                hdfsWrites = CacheStatistics.getHDFSWrites();
                // raw timings are divided by 1e9 to obtain seconds (i.e. they are nanoseconds)
                acqRTime = CacheStatistics.getAcquireRTime() / 1e9;
                acqMTime = CacheStatistics.getAcquireMTime() / 1e9;
                rlsTime = CacheStatistics.getReleaseTime() / 1e9;
                expTime = CacheStatistics.getExportTime() / 1e9;
            }

            /** Adds every counter and timing of {@code other} to this collection. */
            public void aggregate(CacheStatsCollection other) {
                memHits += other.memHits;
                linHits += other.linHits;
                fsBuffHits += other.fsBuffHits;
                fsHits += other.fsHits;
                hdfsHits += other.hdfsHits;
                linWrites += other.linWrites;
                fsBuffWrites += other.fsBuffWrites;
                fsWrites += other.fsWrites;
                hdfsWrites += other.hdfsWrites;
                acqRTime += other.acqRTime;
                acqMTime += other.acqMTime;
                rlsTime += other.rlsTime;
                expTime += other.expTime;
            }
        }

        /** JVM garbage-collection count and time (time in seconds). */
        public static class GCStatsCollection implements Serializable {
            private static final long serialVersionUID = 1L;

            private long gcCount = 0;
            private double gcTime = 0;

            protected GCStatsCollection() {
            }

            /** Overwrites both fields with the current JVM GC values from {@link Statistics}. */
            public void collectStats() {
                gcCount = Statistics.getJVMgcCount();
                // milliseconds -> seconds
                gcTime = Statistics.getJVMgcTime() / 1000.0;
            }

            /** Adds the GC count and time of {@code other} to this collection. */
            public void aggregate(GCStatsCollection other) {
                gcCount += other.gcCount;
                gcTime += other.gcTime;
            }
        }

        protected FedStatsCollection() {
        }

        /** Overwrites this collection with the statistics of the local JVM. */
        public void collectStats() {
            cacheStats.collectStats();
            // milliseconds -> seconds
            jitCompileTime = Statistics.getJITCompileTime() / 1000.0;
            gcStats.collectStats();
            heavyHitters = Statistics.getHeavyHittersHashMap();
        }

        /**
         * Adds all statistics of {@code other} to this collection. Heavy hitters with the same
         * instruction name are merged by summing their counts and their times.
         */
        public void aggregate(FedStatsCollection other) {
            cacheStats.aggregate(other.cacheStats);
            jitCompileTime += other.jitCompileTime;
            gcStats.aggregate(other.gcStats);
            // distinct lambda parameter names: a lambda parameter may not shadow an enclosing one
            other.heavyHitters.forEach((instName, countAndTime) ->
                heavyHitters.merge(instName, countAndTime, (existing, added) ->
                    new ImmutablePair<Long, Double>(
                        existing.getLeft() + added.getLeft(),
                        existing.getRight() + added.getRight())));
        }
    }

    /**
     * Registers a federated worker whose statistics should be collected. Registering the same
     * host/port twice has no effect, since addresses are kept in a set.
     *
     * @param str  host name or address of the worker
     * @param i    port of the worker
     */
    public static void registerFedWorker(String str, int i) {
        _fedWorkerAddresses.add(new ImmutablePair<>(str, i));
    }

    /**
     * Renders the list of registered federated worker addresses, one {@code host:port} per line.
     *
     * @return the formatted address list
     */
    public static String displayFedWorkers() {
        StringBuilder sb = new StringBuilder();
        sb.append("Federated Worker Addresses:\n");
        for (Pair<String, Integer> fedAddr : _fedWorkerAddresses) {
            sb.append(String.format("  %s:%d", fedAddr.getLeft(), fedAddr.getRight()));
            sb.append(ProgramConverter.NEWLINE);
        }
        return sb.toString();
    }

    /**
     * Collects the statistics of all registered workers and renders their sum.
     *
     * @param i  maximum number of heavy-hitter instructions to display
     * @return the formatted statistics report
     * @throws DMLRuntimeException if a worker's response cannot be retrieved
     */
    public static String displayFedStatistics(int i) {
        FedStatsCollection fedStats = collectFedStats();
        StringBuilder sb = new StringBuilder();
        sb.append("SystemDS Federated Statistics:\n");
        sb.append(displayCacheStats(fedStats.cacheStats));
        sb.append(String.format("Total JIT compile time:\t\t%.3f sec.\n", fedStats.jitCompileTime));
        sb.append(displayGCStats(fedStats.gcStats));
        sb.append(displayHeavyHitters(fedStats.heavyHitters, i));
        return sb.toString();
    }

    /** Renders cache hit counts, write counts and timings (in seconds), one category per line. */
    public static String displayCacheStats(FedStatsCollection.CacheStatsCollection cacheStatsCollection) {
        FedStatsCollection.CacheStatsCollection cs = cacheStatsCollection;
        return String.format("Cache hits (Mem/Li/WB/FS/HDFS):\t%d/%d/%d/%d/%d.\n",
                cs.memHits, cs.linHits, cs.fsBuffHits, cs.fsHits, cs.hdfsHits)
            + String.format("Cache writes (Li/WB/FS/HDFS):\t%d/%d/%d/%d.\n",
                cs.linWrites, cs.fsBuffWrites, cs.fsWrites, cs.hdfsWrites)
            + String.format("Cache times (ACQr/m, RLS, EXP):\t%.3f/%.3f/%.3f/%.3f sec.\n",
                cs.acqRTime, cs.acqMTime, cs.rlsTime, cs.expTime);
    }

    /** Renders the total JVM GC count and GC time (in seconds). */
    public static String displayGCStats(FedStatsCollection.GCStatsCollection gCStatsCollection) {
        return String.format("Total JVM GC count:\t\t%d.\n", gCStatsCollection.gcCount)
            + String.format("Total JVM GC time:\t\t%.3f sec.\n", gCStatsCollection.gcTime);
    }

    /** Renders the top 10 heavy-hitter instructions; see {@link #displayHeavyHitters(HashMap, int)}. */
    public static String displayHeavyHitters(HashMap<String, Pair<Long, Double>> hashMap) {
        return displayHeavyHitters(hashMap, 10);
    }

    /**
     * Renders a table of the most expensive instructions, ordered by time descending. Instruction
     * names longer than {@code DMLScript.STATISTICS_MAX_WRAP_LEN} are wrapped onto continuation
     * lines.
     *
     * @param hashMap  instruction name -> (execution count, time)
     * @param i        maximum number of rows to display; negative values are treated as 0
     * @return the formatted heavy-hitter table
     */
    public static String displayHeavyHitters(HashMap<String, Pair<Long, Double>> hashMap, int i) {
        // most expensive instruction first; ties keep the (arbitrary) map iteration order
        ArrayList<Map.Entry<String, Pair<Long, Double>>> entries = new ArrayList<>(hashMap.entrySet());
        entries.sort((a, b) -> Double.compare(b.getValue().getRight(), a.getValue().getRight()));

        StringBuilder sb = new StringBuilder();
        sb.append("Heavy hitter instructions:\n");
        int numRows = Math.min(Math.max(i, 0), entries.size());
        DecimalFormat timeFormat = new DecimalFormat("#,##0.000");

        // column widths: at least the header label, at most the wrap length for instruction names
        int numWidth = String.valueOf(numRows).length();
        int instWidth = "Instruction".length();
        int timeWidth = "Time(s)".length();
        int countWidth = "Count".length();
        for (int r = 0; r < numRows; r++) {
            Map.Entry<String, Pair<Long, Double>> e = entries.get(r);
            instWidth = Math.max(instWidth, e.getKey().length());
            timeWidth = Math.max(timeWidth, timeFormat.format(e.getValue().getRight()).length());
            countWidth = Math.max(countWidth, String.valueOf(e.getValue().getLeft()).length());
        }
        instWidth = Math.min(instWidth, DMLScript.STATISTICS_MAX_WRAP_LEN);

        // the header format is also used for the continuation lines of wrapped instruction names
        String textRowFormat = " %" + numWidth + "s  %-" + instWidth + "s  %" + timeWidth + "s  %" + countWidth + "s";
        String dataRowFormat = " %" + numWidth + "d  %-" + instWidth + "s  %" + timeWidth + "s  %" + countWidth + "d";
        sb.append(String.format(textRowFormat, "#", "Instruction", "Time(s)", "Count"));
        sb.append(ProgramConverter.NEWLINE);

        for (int r = 0; r < numRows; r++) {
            Map.Entry<String, Pair<Long, Double>> e = entries.get(r);
            String[] instLines = Statistics.wrap(e.getKey(), instWidth);
            if (instLines.length == 0)
                instLines = new String[] {""}; // never drop a row, even for an empty name
            String time = timeFormat.format(e.getValue().getRight());
            long count = e.getValue().getLeft();

            sb.append(String.format(dataRowFormat, r + 1, instLines[0], time, count));
            sb.append(ProgramConverter.NEWLINE);
            for (int l = 1; l < instLines.length; l++) {
                sb.append(String.format(textRowFormat, "", instLines[l], "", ""));
                sb.append(ProgramConverter.NEWLINE);
            }
        }
        return sb.toString();
    }

    /**
     * Waits for the statistics responses of all reachable workers and sums them.
     *
     * @return the aggregated statistics (all zero if no worker responded)
     * @throws DMLRuntimeException if retrieving or aggregating any response fails
     */
    private static FedStatsCollection collectFedStats() {
        Future<FederatedResponse>[] responses = getFederatedResponses();
        FedStatsCollection aggregated = new FedStatsCollection();
        for (Future<FederatedResponse> response : responses) {
            try {
                Object[] data = response.get().getData();
                // payloads that are missing or of another type carry no stats and are skipped
                if (data != null && data.length > 0 && data[0] instanceof FedStatsCollection)
                    aggregated.aggregate((FedStatsCollection) data[0]);
            } catch (Exception e) {
                if (e instanceof InterruptedException)
                    Thread.currentThread().interrupt(); // preserve the interrupt for callers
                throw new DMLRuntimeException("Exception of type " + e.getClass().toString()
                    + " thrown while getting the federated stats of the federated response: ", e);
            }
        }
        return aggregated;
    }

    /**
     * Sends a {@link FedStatsCollectFunction} request to every registered worker. This is
     * best-effort: workers whose request cannot be issued are skipped (with a console message for
     * SSL and unexpected errors) instead of failing the whole collection.
     *
     * @return pending responses of all workers the request was successfully sent to
     */
    private static Future<FederatedResponse>[] getFederatedResponses() {
        ArrayList<Future<FederatedResponse>> responses = new ArrayList<>();
        for (Pair<String, Integer> fedAddr : _fedWorkerAddresses) {
            InetSocketAddress isa = new InetSocketAddress(fedAddr.getLeft(), fedAddr.getRight());
            try {
                FederatedRequest request = new FederatedRequest(
                    FederatedRequest.RequestType.EXEC_UDF, -1L, new FedStatsCollectFunction());
                responses.add(FederatedData.executeFederatedOperation(isa, request));
            } catch (SSLException e) {
                System.out.println("SSLException while getting the federated stats from "
                    + isa.toString() + ": " + e.getMessage());
            } catch (DMLRuntimeException ignored) {
                // deliberately silent: such workers are simply skipped
                // NOTE(review): presumably raised for offline/unreachable workers — confirm
            } catch (Exception e) {
                System.out.println("Exeption of type " + e.getClass().getName()
                    + " thrown while getting stats from federated worker: " + e.getMessage());
            }
        }
        @SuppressWarnings("unchecked") // generic array creation is impossible; elements are typed
        Future<FederatedResponse>[] result = responses.toArray(new Future[0]);
        return result;
    }
}
