package org.apache.sysds.utils.stats;

import java.util.concurrent.atomic.LongAdder;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;

/* loaded from: input_file:org/apache/sysds/utils/stats/ParamServStatistics.class */
/**
 * Process-wide accumulators for parameter-server (paramserv) timing and worker-count statistics.
 *
 * <p>Every counter is a static {@link LongAdder}, so concurrent {@code acc*}/{@code inc*} calls from
 * several threads are safe. {@link #reset()} is not atomic with respect to concurrent updates (see
 * {@link LongAdder#reset()}), so it should only be called while no updates are in flight.
 *
 * <p>The report methods divide every accumulated time by 1000 and label the result "secs".
 * NOTE(review): callers are therefore presumably expected to pass milliseconds; confirm at call sites.
 */
public class ParamServStatistics {
    private static final Timing executionTimer = new Timing(false);
    private static final LongAdder executionTime = new LongAdder();
    private static final LongAdder numWorkers = new LongAdder();
    private static final LongAdder setupTime = new LongAdder();
    private static final LongAdder gradientComputeTime = new LongAdder();
    private static final LongAdder aggregationTime = new LongAdder();
    private static final LongAdder localModelUpdateTime = new LongAdder();
    private static final LongAdder modelBroadcastTime = new LongAdder();
    private static final LongAdder batchIndexTime = new LongAdder();
    private static final LongAdder rpcRequestTime = new LongAdder();
    private static final LongAdder validationTime = new LongAdder();
    private static final LongAdder fedDataPartitioningTime = new LongAdder();
    private static final LongAdder fedWorkerComputingTime = new LongAdder();
    private static final LongAdder fedGradientWeightingTime = new LongAdder();
    private static final LongAdder fedCommunicationTime = new LongAdder();
    private static final LongAdder fedNetworkTime = new LongAdder();
    private static final LongAdder heEncryption = new LongAdder();
    private static final LongAdder heAccumulation = new LongAdder();
    private static final LongAdder hePartialDecryption = new LongAdder();
    private static final LongAdder heDecryption = new LongAdder();
    private static final LongAdder fedAggregation = new LongAdder();

    /** Increments the recorded number of paramserv workers by one. */
    public static void incWorkerNumber() {
        numWorkers.increment();
    }

    /**
     * Adds {@code j} to the recorded number of paramserv workers.
     *
     * @param j number of workers to add
     */
    public static void incWorkerNumber(long j) {
        numWorkers.add(j);
    }

    /**
     * Returns the shared execution timer instance. This timer is not cleared by {@link #reset()}.
     *
     * @return the static execution {@link Timing} object
     */
    public static Timing getExecutionTimer() {
        return executionTimer;
    }

    /** @return the accumulated total execution time, in the units passed to {@link #accExecutionTime(long)} */
    public static double getExecutionTime() {
        return executionTime.doubleValue();
    }

    /** @param j time to add to the total execution time */
    public static void accExecutionTime(long j) {
        executionTime.add(j);
    }

    /** @param j time to add to the setup time */
    public static void accSetupTime(long j) {
        setupTime.add(j);
    }

    /** @param j time to add to the gradient computation time */
    public static void accGradientComputeTime(long j) {
        gradientComputeTime.add(j);
    }

    /** @param j time to add to the (global) aggregation time */
    public static void accAggregationTime(long j) {
        aggregationTime.add(j);
    }

    /** @param j time to add to the local model update time */
    public static void accLocalModelUpdateTime(long j) {
        localModelUpdateTime.add(j);
    }

    /** @param j time to add to the model broadcast time */
    public static void accModelBroadcastTime(long j) {
        modelBroadcastTime.add(j);
    }

    /** @param j time to add to the batch indexing (slicing) time */
    public static void accBatchIndexingTime(long j) {
        batchIndexTime.add(j);
    }

    /** @param j time to add to the RPC request time */
    public static void accRpcRequestTime(long j) {
        rpcRequestTime.add(j);
    }

    /** @return the accumulated validation time, in the units passed to {@link #accValidationTime(long)} */
    public static double getValidationTime() {
        return validationTime.doubleValue();
    }

    /** @param j time to add to the validation time */
    public static void accValidationTime(long j) {
        validationTime.add(j);
    }

    /**
     * @return the accumulated federated data-partitioning time; a positive value switches
     *     {@link #displayStatistics()} to the federated report layout
     */
    public static long getFedDataPartitioningTime() {
        return fedDataPartitioningTime.longValue();
    }

    /** @param j time to add to the federated data-partitioning time */
    public static void accFedDataPartitioningTime(long j) {
        fedDataPartitioningTime.add(j);
    }

    /** @param j time to add to the cumulative federated worker computing time */
    public static void accFedWorkerComputing(long j) {
        fedWorkerComputingTime.add(j);
    }

    /** @param j time to add to the cumulative federated network time */
    public static void accFedNetworkTime(long j) {
        fedNetworkTime.add(j);
    }

    /** @param j time to add to the federated aggregation time */
    public static void accFedAggregation(long j) {
        fedAggregation.add(j);
    }

    /** @param j time to add to the federated gradient weighting time */
    public static void accFedGradientWeightingTime(long j) {
        fedGradientWeightingTime.add(j);
    }

    /** @param j time to add to the cumulative federated communication time */
    public static void accFedCommunicationTime(long j) {
        fedCommunicationTime.add(j);
    }

    /** @param j time to add to the homomorphic-encryption (HE) encryption time */
    public static void accHEEncryptionTime(long j) {
        heEncryption.add(j);
    }

    /** @param j time to add to the HE accumulation time */
    public static void accHEAccumulation(long j) {
        heAccumulation.add(j);
    }

    /** @param j time to add to the HE partial decryption time */
    public static void accHEPartialDecryptionTime(long j) {
        hePartialDecryption.add(j);
    }

    /** @param j time to add to the HE decryption time */
    public static void accHEDecryptionTime(long j) {
        heDecryption.add(j);
    }

    /**
     * Clears every accumulator (including the worker count) back to zero. The
     * {@link #getExecutionTimer() execution timer} is not touched.
     */
    public static void reset() {
        executionTime.reset();
        numWorkers.reset();
        setupTime.reset();
        gradientComputeTime.reset();
        aggregationTime.reset();
        localModelUpdateTime.reset();
        modelBroadcastTime.reset();
        batchIndexTime.reset();
        rpcRequestTime.reset();
        validationTime.reset();
        fedDataPartitioningTime.reset();
        fedWorkerComputingTime.reset();
        fedGradientWeightingTime.reset();
        fedCommunicationTime.reset();
        fedNetworkTime.reset();
        heEncryption.reset();
        heAccumulation.reset();
        hePartialDecryption.reset();
        heDecryption.reset();
        fedAggregation.reset();
    }

    /**
     * Builds the human-readable paramserv statistics report.
     *
     * @return an empty string when no workers have been recorded; otherwise a multi-line report.
     *     The report uses the federated layout when any federated data-partitioning time was
     *     recorded, and the local layout otherwise.
     */
    public static String displayStatistics() {
        // No workers recorded means paramserv never ran; emit nothing.
        if (numWorkers.longValue() <= 0) {
            return "";
        }
        StringBuilder sb = new StringBuilder();
        sb.append(String.format("Paramserv total execution time:\t%.3f secs.\n", toSeconds(executionTime)));
        sb.append(String.format("Paramserv total num workers:\t%d.\n", numWorkers.longValue()));
        sb.append(String.format("Paramserv setup time:\t\t%.3f secs.\n", toSeconds(setupTime)));
        // A positive partitioning time is used as the marker that the federated backend was used.
        if (fedDataPartitioningTime.longValue() > 0) {
            sb.append(displayFedPSStatistics());
            sb.append(String.format("PS fed global model agg time:\t%.3f secs.\n", toSeconds(aggregationTime)));
        } else {
            sb.append(String.format("Paramserv grad compute time:\t%.3f secs.\n", toSeconds(gradientComputeTime)));
            // Reported as "local/global" model update time.
            sb.append(String.format("Paramserv model update time:\t%.3f/%.3f secs.\n",
                toSeconds(localModelUpdateTime), toSeconds(aggregationTime)));
            sb.append(String.format("Paramserv model broadcast time:\t%.3f secs.\n", toSeconds(modelBroadcastTime)));
            sb.append(String.format("Paramserv batch slice time:\t%.3f secs.\n", toSeconds(batchIndexTime)));
            sb.append(String.format("Paramserv RPC request time:\t%.3f secs.\n", toSeconds(rpcRequestTime)));
        }
        sb.append(String.format("Paramserv validation time:\t%.3f secs.\n", toSeconds(validationTime)));
        return sb.toString();
    }

    /** Formats the federated-specific section of {@link #displayStatistics()}. */
    private static String displayFedPSStatistics() {
        return String.format("PS fed data partitioning time:\t%.3f secs.\n", toSeconds(fedDataPartitioningTime))
            + String.format("PS fed comm time (cum):\t\t%.3f secs.\n", toSeconds(fedCommunicationTime))
            + String.format("PS fed worker comp time (cum):\t%.3f secs.\n", toSeconds(fedWorkerComputingTime))
            + String.format("PS fed grad. weigh. time (cum):\t%.3f secs.\n", toSeconds(fedGradientWeightingTime));
    }

    /**
     * Builds the report of federated network/aggregation, gradient computation and homomorphic
     * encryption (HE) timings. Unlike {@link #displayStatistics()}, this method is unconditional: it
     * always emits every line, even when all values are zero.
     *
     * @return a multi-line report
     */
    public static String displayFloStatistics() {
        return String.format("PS fed network time (cum):\t%.3f secs.\n", toSeconds(fedNetworkTime))
            + String.format("PS fed agg time:\t%.3f secs.\n", toSeconds(fedAggregation))
            + String.format("Paramserv grad compute time:\t%.3f secs.\n", toSeconds(gradientComputeTime))
            + String.format("HE PS encryption time:\t%.3f secs.\n", toSeconds(heEncryption))
            + String.format("HE PS accumulation time:\t%.3f secs.\n", toSeconds(heAccumulation))
            + String.format("HE PS partial decryption time:\t%.3f secs.\n", toSeconds(hePartialDecryption))
            + String.format("HE PS decryption time:\t%.3f secs.\n", toSeconds(heDecryption));
    }

    /**
     * Converts an accumulated time to the value printed as "secs" in the reports.
     * The accumulated value is presumably in milliseconds; this is not verified here.
     */
    private static double toSeconds(LongAdder accumulator) {
        return accumulator.doubleValue() / 1000.0d;
    }
}
