package org.apache.sysds.runtime.controlprogram.paramserv;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import org.apache.sysds.common.Types;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.FunctionProgramBlock;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.parfor.stat.Timing;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;

/* loaded from: input_file:org/apache/sysds/runtime/controlprogram/paramserv/PSWorker.class */
/**
 * Base class for parameter-server workers.
 *
 * <p>A worker keeps its configuration (id, epochs, batch size, update frequency), a reference to
 * the {@link ParamServer}, its data partition ({@link #setFeatures}, {@link #setLabels}) and a
 * pre-built {@link FunctionCallCPInstruction} that invokes the user-provided update function.
 * Subclasses supply the worker name and the statistics hooks.
 */
public abstract class PSWorker implements Serializable {
    private static final long serialVersionUID = -3510485051178200118L;

    /** Identifier of this worker, as passed to the constructor. */
    protected int _workerID;
    /** Configured number of epochs. */
    protected int _epochs;
    /** Configured batch size. */
    protected long _batchSize;
    /** Execution context used to resolve and run the update function. */
    protected ExecutionContext _ec;
    /** Parameter server this worker talks to. */
    protected ParamServer _ps;
    /** The single (list-typed) output parameter of the update function. */
    protected DataIdentifier _output;
    /** Prepared call instruction for the update function. */
    protected FunctionCallCPInstruction _inst;
    /** Feature matrix assigned to this worker. */
    protected MatrixObject _features;
    /** Label matrix assigned to this worker. */
    protected MatrixObject _labels;
    /** Function key (namespace + name) of the update function. */
    protected String _updFunc;
    /** Update frequency configured for this worker. */
    protected Statement.PSFrequency _freq;

    /**
     * Creates an unconfigured worker; subclasses are responsible for populating the fields and
     * calling {@link #setupUpdateFunction(String, ExecutionContext)} themselves.
     */
    public PSWorker() {
    }

    /**
     * Creates a worker and immediately prepares the update-function call.
     *
     * <p>NOTE(review): this constructor invokes the overridable
     * {@link #setupUpdateFunction(String, ExecutionContext)}; an override in a subclass would run
     * before that subclass's own fields are initialized.
     *
     * @param workerID  identifier of this worker
     * @param updFunc   function key of the update function (split via {@code DMLProgram.splitFunctionKey})
     * @param freq      update frequency
     * @param epochs    number of epochs
     * @param batchSize batch size
     * @param ec        execution context used to look up the update function
     * @param ps        parameter server
     * @throws DMLRuntimeException if the update function has an invalid signature
     */
    public PSWorker(int workerID, String updFunc, Statement.PSFrequency freq, int epochs, long batchSize,
            ExecutionContext ec, ParamServer ps) {
        _workerID = workerID;
        _updFunc = updFunc;
        _freq = freq;
        _epochs = epochs;
        _batchSize = batchSize;
        _ec = ec;
        _ps = ps;
        setupUpdateFunction(updFunc, ec);
    }

    /**
     * Resolves the update function, builds the instruction that calls it, and validates its
     * signature.
     *
     * <p>The function must declare a matrix input named {@code Statement.PS_FEATURES}, a matrix
     * input named {@code Statement.PS_LABELS}, a list input named {@code Statement.PS_MODEL}, an
     * optional list input named {@code Statement.PS_HYPER_PARAMS}, and exactly one list output.
     * On success {@link #_inst} and {@link #_output} are set.
     *
     * @param updFunc function key of the update function
     * @param ec      execution context whose program contains the function
     * @throws DMLRuntimeException if the signature does not match the expectations above
     */
    public void setupUpdateFunction(String updFunc, ExecutionContext ec) {
        String[] keyParts = DMLProgram.splitFunctionKey(updFunc);
        String namespace = keyParts[0];
        String fname = keyParts[1];
        FunctionProgramBlock func = ec.getProgram().getFunctionProgramBlock(namespace, fname, false);
        List<DataIdentifier> inputs = func.getInputParams();
        List<DataIdentifier> outputs = func.getOutputParams();

        // Bind every declared input by its own name, value type and data type.
        CPOperand[] boundInputs = inputs.stream()
            .map(input -> new CPOperand(input.getName(), input.getValueType(), input.getDataType()))
            .toArray(CPOperand[]::new);
        ArrayList<String> outputNames = outputs.stream()
            .map(DataIdentifier::getName)
            .collect(Collectors.toCollection(ArrayList::new));
        _inst = new FunctionCallCPInstruction(namespace, fname, false, boundInputs,
            func.getInputParamNames(), outputNames, "update function");

        checkInput(updFunc, false, inputs, Types.DataType.MATRIX, Statement.PS_FEATURES);
        checkInput(updFunc, false, inputs, Types.DataType.MATRIX, Statement.PS_LABELS);
        checkInput(updFunc, false, inputs, Types.DataType.LIST, Statement.PS_MODEL);
        checkInput(updFunc, true, inputs, Types.DataType.LIST, Statement.PS_HYPER_PARAMS);

        if (outputs.size() != 1) {
            throw new DMLRuntimeException(String.format("The output of the '%s' function should provide one list containing the gradients.", updFunc));
        }
        if (outputs.get(0).getDataType() != Types.DataType.LIST) {
            throw new DMLRuntimeException(String.format("The output of the '%s' function should be type of list.", updFunc));
        }
        _output = outputs.get(0);
    }

    /**
     * Verifies that exactly one input named {@code pname} with data type {@code dt} is declared.
     *
     * @param fname    function key used in the error message (the function being validated)
     * @param optional if true, an input that is not declared under {@code pname} at all is accepted
     * @param inputs   declared input parameters of the function
     * @param dt       required data type
     * @param pname    required parameter name
     * @throws DMLRuntimeException if the requirement is not met
     */
    private static void checkInput(String fname, boolean optional, List<DataIdentifier> inputs,
            Types.DataType dt, String pname) {
        // An optional parameter that is not declared at all needs no further checks.
        if (optional && inputs.stream().noneMatch(input -> pname.equals(input.getName()))) {
            return;
        }
        long matches = inputs.stream()
            .filter(input -> input.getDataType() == dt && pname.equals(input.getName()))
            .count();
        if (matches != 1) {
            throw new DMLRuntimeException(String.format(
                "The '%s' function should provide an input of '%s' type named '%s'.", fname, dt, pname));
        }
    }

    /** Assigns the feature matrix of this worker. */
    public void setFeatures(MatrixObject features) {
        _features = features;
    }

    /** Assigns the label matrix of this worker. */
    public void setLabels(MatrixObject labels) {
        _labels = labels;
    }

    /** @return the feature matrix assigned via {@link #setFeatures}, or {@code null} if unset */
    public MatrixObject getFeatures() {
        return _features;
    }

    /** @return the label matrix assigned via {@link #setLabels}, or {@code null} if unset */
    public MatrixObject getLabels() {
        return _labels;
    }

    /** @return a name identifying this worker (format defined by subclasses) */
    public abstract String getWorkerName();

    /** Statistics hook; presumably counts this worker — semantics defined by subclasses. */
    protected abstract void incWorkerNumber();

    /** Statistics hook for time spent on local model updates, measured by {@code timing}. */
    protected abstract void accLocalModelUpdateTime(Timing timing);

    /** Statistics hook for time spent on batch indexing, measured by {@code timing}. */
    protected abstract void accBatchIndexingTime(Timing timing);

    /** Statistics hook for time spent computing gradients, measured by {@code timing}. */
    protected abstract void accGradientComputeTime(Timing timing);

    /** Statistics hook for processed epochs; no-op by default, subclasses may override. */
    public void accNumEpochs(int n) {
    }

    /** Statistics hook for processed batches; no-op by default, subclasses may override. */
    public void accNumBatches(int n) {
    }
}
