package org.apache.sysds.runtime.controlprogram.paramserv;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.function.VoidFunction;
import org.apache.spark.util.LongAccumulator;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.runtime.codegen.CodegenUtils;
import org.apache.sysds.runtime.controlprogram.paramserv.rpc.PSRpcFactory;
import org.apache.sysds.runtime.controlprogram.parfor.RemoteParForUtils;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.ProgramConverter;
import scala.Tuple2;

/* loaded from: input_file:org/apache/sysds/runtime/controlprogram/paramserv/SparkPSWorker.class */
/**
 * Parameter-server worker executed as a Spark {@link VoidFunction}.
 *
 * <p>Each invocation of {@link #call(Tuple2)} receives a worker id together with a
 * (features, labels) pair of matrix blocks. It configures the worker from that input by
 * registering codegen classes, parsing the serialized program body, setting up the buffer pool,
 * creating the parameter-server proxy and installing the update/aggregation functions. It then
 * delegates to the inherited no-argument {@code call()}.
 *
 * <p>Setup, update, indexing and gradient timings as well as epoch/batch counts are added to the
 * {@link LongAccumulator}s passed to the constructor.
 */
public class SparkPSWorker extends LocalPSWorker implements VoidFunction<Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>>> {
    private static final long serialVersionUID = -8674739573419648732L;

    /** Serialized program body, parsed per worker in {@link #configureWorker}. */
    private final String _program;
    /** Forwarded to {@code RemoteParForUtils.setupBufferPool}. */
    private final boolean _isLocal;
    /** Codegen classes registered via {@code CodegenUtils.getClassSync(key, value)}; presumably class name to bytecode. */
    private final HashMap<String, byte[]> _clsMap;
    /** Spark configuration used to create the parameter-server proxy. */
    private final SparkConf _conf;
    /** Port used to create the parameter-server proxy. */
    private final int _port;
    /** Name of the aggregation function installed on the parameter-server proxy. */
    private final String _aggFunc;

    // Accumulators receiving statistics from this worker.
    private final LongAccumulator _aSetup;   // setup time (configureWorker)
    private final LongAccumulator _aWorker;  // number of workers (+1 per incWorkerNumber)
    private final LongAccumulator _aUpdate;  // local model update time
    private final LongAccumulator _aIndex;   // batch indexing time
    private final LongAccumulator _aGrad;    // gradient computation time
    private final LongAccumulator _aRPC;     // handed to the PS proxy; presumably RPC time
    private final LongAccumulator _nBatches; // number of processed batches
    private final LongAccumulator _nEpochs;  // number of processed epochs

    /**
     * Creates a Spark parameter-server worker. All state is captured in (serializable) fields
     * so the function object can be shipped to executors; per-partition setup happens lazily in
     * {@link #call(Tuple2)}.
     *
     * @param updFunc   name of the update (gradient) function
     * @param aggFunc   name of the aggregation function installed on the PS proxy
     * @param freq      parameter-server update frequency
     * @param epochs    number of epochs
     * @param batchSize mini-batch size
     * @param program   serialized program body to be parsed on the executor
     * @param isLocal   forwarded to the buffer-pool setup
     * @param clsMap    codegen classes to register before parsing the program
     * @param conf      Spark configuration for the PS proxy
     * @param port      port for the PS proxy
     * @param aSetup    accumulator for setup time
     * @param aWorker   accumulator counting workers
     * @param aUpdate   accumulator for local model update time
     * @param aIndex    accumulator for batch indexing time
     * @param aGrad     accumulator for gradient computation time
     * @param aRPC      accumulator handed to the PS proxy
     * @param nBatches  accumulator for the number of batches
     * @param nEpochs   accumulator for the number of epochs
     * @param nbatches  value stored in the inherited {@code _nbatches} field
     * @param modelAvg  whether model averaging is enabled
     */
    public SparkPSWorker(String updFunc, String aggFunc, Statement.PSFrequency freq, int epochs, long batchSize, String program, boolean isLocal, HashMap<String, byte[]> clsMap, SparkConf conf, int port, LongAccumulator aSetup, LongAccumulator aWorker, LongAccumulator aUpdate, LongAccumulator aIndex, LongAccumulator aGrad, LongAccumulator aRPC, LongAccumulator nBatches, LongAccumulator nEpochs, int nbatches, boolean modelAvg) {
        this._updFunc = updFunc;
        this._aggFunc = aggFunc;
        this._freq = freq;
        this._epochs = epochs;
        this._batchSize = batchSize;
        this._program = program;
        this._isLocal = isLocal;
        this._clsMap = clsMap;
        this._conf = conf;
        this._port = port;
        this._aSetup = aSetup;
        this._aWorker = aWorker;
        this._aUpdate = aUpdate;
        this._aIndex = aIndex;
        this._aGrad = aGrad;
        this._aRPC = aRPC;
        this._nBatches = nBatches;
        this._nEpochs = nEpochs;
        this._nbatches = nbatches;
        this._modelAvg = modelAvg;
        // Explicitly no local thread pool for the Spark worker.
        this._tpool = null;
    }

    /**
     * Returns the display name of this worker.
     *
     * @return {@code "Spark worker_<id>"} using the current worker id
     */
    @Override
    public String getWorkerName() {
        return String.format("Spark worker_%d", _workerID);
    }

    /**
     * Spark entry point: configures this worker from the given partition, records the setup
     * time, then runs the inherited no-argument {@code call()}.
     *
     * @param input (worker id, (features, labels)) tuple
     * @throws Exception propagated from configuration or the worker loop
     */
    @Override
    public void call(Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>> input) throws Exception {
        Timing setupTiming = new Timing(true);
        configureWorker(input);
        accSetupTime(setupTiming);
        call();
    }

    /**
     * Initializes the per-executor state of this worker from its input partition.
     *
     * @param input (worker id, (features, labels)) tuple
     * @throws IOException          propagated from the setup helpers
     * @throws InterruptedException propagated from the setup helpers
     */
    private void configureWorker(Tuple2<Integer, Tuple2<MatrixBlock, MatrixBlock>> input) throws IOException, InterruptedException {
        _workerID = input._1();

        // Register codegen classes first; presumably the parsed program may reference them.
        for (Map.Entry<String, byte[]> cls : _clsMap.entrySet()) {
            CodegenUtils.getClassSync(cls.getKey(), cls.getValue());
        }

        _ec = ProgramConverter.parseSparkPSBody(_program, _workerID).getEc();
        RemoteParForUtils.setupBufferPool(_workerID, _isLocal);
        _ps = PSRpcFactory.createSparkPSProxy(_conf, _port, _aRPC);
        setupUpdateFunction(_updFunc, _ec);
        _ps.setupAggFunc(_ec, _aggFunc);

        // Wrap the partition's feature and label blocks as matrix objects.
        Tuple2<MatrixBlock, MatrixBlock> data = input._2();
        setFeatures(ParamservUtils.newMatrixObject(data._1(), false));
        setLabels(ParamservUtils.newMatrixObject(data._2(), false));
    }

    /** Increments the worker-count accumulator by one. */
    @Override
    protected void incWorkerNumber() {
        _aWorker.add(1L);
    }

    /** Adds the elapsed local model update time (if timing is non-null). */
    @Override
    protected void accLocalModelUpdateTime(Timing timing) {
        addElapsed(_aUpdate, timing);
    }

    /** Adds the elapsed batch indexing time (if timing is non-null). */
    @Override
    protected void accBatchIndexingTime(Timing timing) {
        addElapsed(_aIndex, timing);
    }

    /** Adds the elapsed gradient computation time (if timing is non-null). */
    @Override
    protected void accGradientComputeTime(Timing timing) {
        addElapsed(_aGrad, timing);
    }

    /**
     * Adds the given number of epochs to the epoch accumulator.
     *
     * <p>NOTE(review): the decompiler noted this was originally {@code protected}; it is kept
     * {@code public} so no existing caller breaks.
     */
    @Override
    public void accNumEpochs(int i) {
        _nEpochs.add(i);
    }

    /**
     * Adds the given number of batches to the batch accumulator.
     *
     * <p>NOTE(review): the decompiler noted this was originally {@code protected}; it is kept
     * {@code public} so no existing caller breaks.
     */
    @Override
    public void accNumBatches(int i) {
        _nBatches.add(i);
    }

    /** Adds the elapsed setup time (if timing is non-null). */
    private void accSetupTime(Timing timing) {
        addElapsed(_aSetup, timing);
    }

    /**
     * Stops {@code timing} and adds its result, truncated to {@code long}, to {@code acc}.
     * A {@code null} timing is ignored.
     */
    private static void addElapsed(LongAccumulator acc, Timing timing) {
        if (timing != null) {
            acc.add((long) timing.stop());
        }
    }
}
