package org.apache.sysds.runtime.controlprogram.paramserv;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption.PublicKey;
import org.apache.sysds.runtime.controlprogram.paramserv.homomorphicEncryption.SEALServer;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysds.runtime.instructions.cp.CiphertextMatrix;
import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.runtime.instructions.cp.PlaintextMatrix;
import org.apache.sysds.utils.stats.ParamServStatistics;

/* loaded from: input_file:org/apache/sysds/runtime/controlprogram/paramserv/HEParamServer.class */
/**
 * Local parameter server whose model updates are aggregated by a {@link SEALServer}
 * (homomorphic-encryption backend) instead of in plaintext.
 *
 * <p>Every call to {@link #push(int, ListObject)} is a two-phase barrier across all
 * {@code getNumWorkers()} workers:
 * <ol>
 *   <li>each worker contributes a list of {@link CiphertextMatrix} updates; the last worker to
 *   arrive accumulates them position by position via {@link SEALServer#accumulateCiphertexts};</li>
 *   <li>each worker then contributes the {@link PlaintextMatrix} partial decryptions returned by its
 *   registered {@link FederatedPSControlThread}; the last worker to arrive combines them via
 *   {@link SEALServer#average} and broadcasts the resulting model.</li>
 * </ol>
 * A round only completes once all workers have called {@code push}; workers must be registered
 * with {@link #registerThread(int, FederatedPSControlThread)} before pushing.
 */
public class HEParamServer extends LocalParamServer {
    /** Number of workers that have contributed to the current collection round. */
    private int _thread_counter;
    /** Incremented whenever a round completes; lets waiters distinguish real wake-ups from spurious ones. */
    private long _generation;
    /** Control thread per worker, indexed by worker id; entries are null until registered. */
    private final List<FederatedPSControlThread> _threads;
    /** Contribution of each worker for the current round, indexed by worker id. */
    private final List<Object> _result_buffer;
    /** Result of the most recently completed round, returned to every participating worker. */
    private Object _result;
    /** Failure raised while computing the most recently completed round, or null on success. */
    private Throwable _failure;
    private final SEALServer _seal_server;
    /** Measures the time from ciphertext aggregation until all partial decryptions have arrived. */
    private Timing _comm_timer;

    /**
     * Creates a homomorphic-encryption parameter server.
     *
     * <p>Calls {@code NativeHEHelper.initialize()} first (presumably loading the native SEAL
     * bindings — confirm in NativeHEHelper), then forwards all arguments unchanged to the
     * {@link LocalParamServer} constructor.
     */
    public static HEParamServer create(ListObject listObject, String str, Statement.PSUpdateType pSUpdateType, Statement.PSFrequency pSFrequency, ExecutionContext executionContext, int i, String str2, int i2, MatrixObject matrixObject, MatrixObject matrixObject2, int i3) {
        NativeHEHelper.initialize();
        return new HEParamServer(listObject, str, pSUpdateType, pSFrequency, executionContext, i, str2, i2, matrixObject, matrixObject2, i3);
    }

    private HEParamServer(ListObject listObject, String str, Statement.PSUpdateType pSUpdateType, Statement.PSFrequency pSFrequency, ExecutionContext executionContext, int i, String str2, int i2, MatrixObject matrixObject, MatrixObject matrixObject2, int i3) {
        super(listObject, str, pSUpdateType, pSFrequency, executionContext, i, str2, i2, matrixObject, matrixObject2, i3, true);
        _thread_counter = 0;
        _generation = 0;
        _seal_server = new SEALServer();
        _threads = Collections.synchronizedList(new ArrayList<>(i));
        // one (initially empty) slot per worker so registerThread can use set(index, ...)
        for (int worker = 0; worker < getNumWorkers(); worker++) {
            _threads.add(null);
        }
        _result_buffer = new ArrayList<>(i);
        resetResultBuffer();
    }

    /**
     * Registers the control thread of worker {@code i}; its partial decryption is requested in the
     * second phase of {@link #push(int, ListObject)}.
     */
    public void registerThread(int i, FederatedPSControlThread federatedPSControlThread) {
        _threads.set(i, federatedPSControlThread);
    }

    /** Empties the per-worker buffer and re-fills it with one null slot per worker. */
    private synchronized void resetResultBuffer() {
        _result_buffer.clear();
        for (int worker = 0; worker < getNumWorkers(); worker++) {
            _result_buffer.add(null);
        }
    }

    /** Delegates to {@link SEALServer#generateA()}. */
    public byte[] generateA() {
        return _seal_server.generateA();
    }

    /** Delegates to {@link SEALServer#aggregatePartialPublicKeys(PublicKey[])}. */
    public PublicKey aggregatePartialPublicKeys(PublicKey[] publicKeyArr) {
        return _seal_server.aggregatePartialPublicKeys(publicKeyArr);
    }

    /**
     * Barrier that collects one value per worker and applies {@code action} once all have arrived.
     *
     * <p>The last worker to arrive runs {@code action} on the contributions (ordered by worker
     * id), resets the round, and wakes the others; every worker then returns the same result.
     * If {@code action} throws, the round is still reset, the last worker rethrows, and the
     * waiting workers throw a {@link RuntimeException} wrapping that failure instead of
     * blocking forever.
     *
     * @param workerId index of the calling worker
     * @param value the calling worker's contribution
     * @param action function applied to all contributions by the last-arriving worker
     * @return the value produced by {@code action} for this round
     * @throws RuntimeException if the waiting thread is interrupted (interrupt flag is restored)
     */
    @SuppressWarnings("unchecked")
    private synchronized <T, U> U collectAndDo(int workerId, T value, Function<List<T>, U> action) {
        _result_buffer.set(workerId, value);
        _thread_counter++;

        if (_thread_counter == getNumWorkers()) {
            // copy before resetting: the buffer is reused for the next round
            List<T> inputs = new ArrayList<>(_result_buffer.size());
            for (Object contribution : _result_buffer) {
                inputs.add((T) contribution);
            }
            U output;
            try {
                output = action.apply(inputs);
                _result = output;
                _failure = null;
            } catch (RuntimeException | Error e) {
                _result = null;
                _failure = e;
                throw e;
            } finally {
                resetResultBuffer();
                _thread_counter = 0;
                _generation++;
                notifyAll();
            }
            return output;
        }

        // Loop guards against spurious wake-ups: only a completed round bumps the generation.
        final long myGeneration = _generation;
        try {
            while (_generation == myGeneration) {
                wait();
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException("thread interrupted", e);
        }
        if (_failure != null) {
            throw new RuntimeException("aggregation failed in another worker", _failure);
        }
        // Safe to read: the next round cannot complete before this worker contributes to it.
        return (U) _result;
    }

    /**
     * Accumulates, for each model position, the ciphertexts contributed by all workers.
     *
     * @param encryptedModels one list per worker; each element is expected to be a
     *        {@link CiphertextMatrix}, and all lists are assumed to have the same length
     * @return one accumulated ciphertext per model position
     */
    private CiphertextMatrix[] homomorphicAggregation(List<ListObject> encryptedModels) {
        Timing timing = DMLScript.STATISTICS ? new Timing(true) : null;
        int numMatrices = encryptedModels.get(0).getLength();
        CiphertextMatrix[] aggregated = new CiphertextMatrix[numMatrices];
        for (int matrixIdx = 0; matrixIdx < numMatrices; matrixIdx++) {
            CiphertextMatrix[] summands = new CiphertextMatrix[encryptedModels.size()];
            for (int worker = 0; worker < encryptedModels.size(); worker++) {
                summands[worker] = (CiphertextMatrix) encryptedModels.get(worker).getData(matrixIdx);
            }
            aggregated[matrixIdx] = _seal_server.accumulateCiphertexts(summands);
        }
        if (timing != null) {
            ParamServStatistics.accHEAccumulation((long) timing.stop());
        }
        return aggregated;
    }

    /**
     * Combines the accumulated ciphertexts with all workers' partial decryptions and broadcasts
     * the resulting model.
     *
     * @param encryptedSums accumulated ciphertext per model position
     * @param partialDecryptions one array per worker, indexed by model position
     * @return always {@code null}; typed {@link Void} so it can serve as a {@link Function} result
     */
    private Void homomorphicAverage(CiphertextMatrix[] encryptedSums, List<PlaintextMatrix[]> partialDecryptions) {
        Timing timing = DMLScript.STATISTICS ? new Timing(true) : null;
        int numMatrices = partialDecryptions.get(0).length;
        MatrixObject[] averaged = new MatrixObject[numMatrices];
        for (int matrixIdx = 0; matrixIdx < numMatrices; matrixIdx++) {
            PlaintextMatrix[] partials = new PlaintextMatrix[partialDecryptions.size()];
            for (int worker = 0; worker < partialDecryptions.size(); worker++) {
                partials[worker] = partialDecryptions.get(worker)[matrixIdx];
            }
            averaged[matrixIdx] = _seal_server.average(encryptedSums[matrixIdx], partials);
        }
        // start from a copy of the current model and overwrite each entry with its averaged value
        ListObject newModel = new ListObject(getResult());
        for (int idx = 0; idx < newModel.getLength(); idx++) {
            newModel.set(idx, averaged[idx]);
        }
        if (timing != null) {
            ParamServStatistics.accHEDecryptionTime((long) timing.stop());
        }
        updateAndBroadcastModel(newModel, null);
        return null;
    }

    private void startCommTimer() {
        _comm_timer = new Timing(true);
    }

    private long stopCommTimer() {
        return (long) _comm_timer.stop();
    }

    /**
     * Contributes worker {@code i}'s encrypted update and blocks until the round has finished.
     *
     * <p>Phase 1 accumulates all workers' ciphertexts. Phase 2 obtains this worker's partial
     * decryption of the accumulated ciphertexts from its registered control thread, and the last
     * worker averages them and broadcasts the new model.
     */
    @Override // org.apache.sysds.runtime.controlprogram.paramserv.LocalParamServer, org.apache.sysds.runtime.controlprogram.paramserv.ParamServer
    public void push(int i, ListObject listObject) {
        CiphertextMatrix[] aggregated = collectAndDo(i, listObject, encryptedModels -> {
            CiphertextMatrix[] sums = homomorphicAggregation(encryptedModels);
            // time from here until all partial decryptions arrive is recorded as network time
            startCommTimer();
            return sums;
        });
        collectAndDo(i, _threads.get(i).getPartialDecryption(aggregated), partials -> {
            ParamServStatistics.accFedNetworkTime(stopCommTimer());
            return homomorphicAverage(aggregated, partials);
        });
    }
}
