package org.apache.sysds.runtime.instructions.fed;

import java.util.concurrent.Future;
import org.apache.sysds.common.Types;
import org.apache.sysds.parser.Statement;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.class */
/**
 * Federated instruction for unary aggregates over a matrix that carries a
 * {@link FederationMap}.
 *
 * <p>Execution is dispatched on the opcode: opcodes containing {@code "var"}
 * go through {@link #processVar(ExecutionContext)}, everything else through
 * {@link #processDefault(ExecutionContext)}. In both paths the instruction
 * string is wrapped by {@code FederationUtils.callInstruction} and executed
 * through the input's federation map; the result is then either left
 * federated or pulled back and combined with
 * {@code FederationUtils.aggScalar} / {@code FederationUtils.aggMatrix}.
 */
public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction {
    /**
     * Creates the instruction produced by {@link #parseInstruction(String)}.
     * All arguments are forwarded unchanged to the superclass constructor.
     *
     * @param auop aggregate operator parsed from the opcode
     * @param in operand built from part 1 of the instruction string
     * @param out operand built from part 2 of the instruction string
     * @param opcode opcode (part 0 of the instruction string)
     * @param istr full instruction string
     * @param fedOut federated-output flag parsed from the instruction string
     */
    private AggregateUnaryFEDInstruction(AggregateUnaryOperator auop, CPOperand in, CPOperand out, String opcode, String istr, FEDInstruction.FederatedOutput fedOut) {
        super(FEDInstruction.FEDType.AggregateUnary, auop, in, out, opcode, istr, fedOut);
    }

    /**
     * Subclass constructor taking three operands and a federated-output flag.
     * All arguments are forwarded unchanged to the superclass constructor.
     * NOTE(review): the two strings are presumably opcode and instruction
     * string, in the same order as in the private constructor - confirm
     * against {@code UnaryFEDInstruction}.
     */
    protected AggregateUnaryFEDInstruction(Operator op, CPOperand opnd1, CPOperand opnd2, CPOperand opnd3, String opcode, String istr, FEDInstruction.FederatedOutput fedOut) {
        super(FEDInstruction.FEDType.AggregateUnary, op, opnd1, opnd2, opnd3, opcode, istr, fedOut);
    }

    /**
     * Subclass constructor taking three operands and no federated-output flag.
     * All arguments are forwarded unchanged to the superclass constructor.
     * NOTE(review): the two strings are presumably opcode and instruction
     * string - confirm against {@code UnaryFEDInstruction}.
     */
    protected AggregateUnaryFEDInstruction(Operator op, CPOperand opnd1, CPOperand opnd2, CPOperand opnd3, String opcode, String istr) {
        super(FEDInstruction.FEDType.AggregateUnary, op, opnd1, opnd2, opnd3, opcode, istr);
    }

    /**
     * Constructor taking four operands and no federated-output flag.
     * All arguments are forwarded unchanged to the superclass constructor.
     * NOTE(review): the two strings are presumably opcode and instruction
     * string - confirm against {@code UnaryFEDInstruction}.
     */
    public AggregateUnaryFEDInstruction(Operator op, CPOperand opnd1, CPOperand opnd2, CPOperand opnd3, CPOperand opnd4, String opcode, String istr) {
        super(FEDInstruction.FEDType.AggregateUnary, op, opnd1, opnd2, opnd3, opnd4, opcode, istr);
    }

    /**
     * Builds an instruction from its string form.
     *
     * <p>The string is split with
     * {@code InstructionUtils.getInstructionPartsWithValueType}; part 0 is the
     * opcode, parts 1 and 2 become the two operands. For {@code uarimax} /
     * {@code uarimin} part 4 is parsed as an integer and passed to the
     * row-index operator parser. The federated-output flag is read from
     * part 4 when there are exactly five parts (and part 4 is not one of the
     * two row-index opcodes), otherwise from part 5.
     *
     * @param str instruction string
     * @return the parsed instruction
     */
    public static AggregateUnaryFEDInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        CPOperand in = new CPOperand(parts[1]);
        CPOperand out = new CPOperand(parts[2]);

        AggregateUnaryOperator auop;
        if (opcode.equalsIgnoreCase("uarimax") || opcode.equalsIgnoreCase("uarimin")) {
            auop = InstructionUtils.parseAggregateUnaryRowIndexOperator(opcode, Integer.parseInt(parts[4]), 1);
        } else {
            auop = InstructionUtils.parseBasicAggregateUnaryOperator(opcode);
        }

        // For instructions whose exec type is SPARK, operand 4 of the string
        // is overwritten with "-1" before the instruction is stored.
        String istr = str;
        if (InstructionUtils.getExecType(str) == Types.ExecType.SPARK) {
            istr = InstructionUtils.replaceOperand(str, 4, "-1");
        }

        boolean flagInPart4 = parts.length == 5
            && !parts[4].equals("uarimin")
            && !parts[4].equals("uarimax");
        FEDInstruction.FederatedOutput fedOut =
            FEDInstruction.FederatedOutput.valueOf(flagInPart4 ? parts[4] : parts[5]);

        return new AggregateUnaryFEDInstruction(auop, in, out, opcode, istr, fedOut);
    }

    /**
     * Executes the instruction, routing opcodes that contain {@code "var"} to
     * {@link #processVar(ExecutionContext)} and all others to
     * {@link #processDefault(ExecutionContext)}.
     *
     * @param executionContext context holding the input and output variables
     */
    @Override // org.apache.sysds.runtime.instructions.Instruction
    public void processInstruction(ExecutionContext executionContext) {
        if (getOpcode().contains("var")) {
            processVar(executionContext);
        } else {
            processDefault(executionContext);
        }
    }

    /**
     * Handles every opcode that does not contain {@code "var"}: optionally
     * rewrites the instruction string for column-federated row-index
     * aggregates, then either keeps the result federated or fetches it.
     */
    private void processDefault(ExecutionContext executionContext) {
        AggregateUnaryOperator auop = (AggregateUnaryOperator) this._optr;
        MatrixObject input = executionContext.getMatrixObject(this.input1);
        FederationMap fedMap = input.getFedMapping();

        boolean rowIndexAgg = this.instOpcode.equalsIgnoreCase("uarimax")
            || this.instOpcode.equalsIgnoreCase("uarimin");
        if (rowIndexAgg && input.isFederated(FederationMap.FType.COL)) {
            // Operand 5 is set to "2" for column-federated input.
            // NOTE(review): the meaning of "2" is defined by the receiving
            // instruction parser - confirm there.
            this.instString = InstructionUtils.replaceOperand(this.instString, 5, "2");
        }

        if (this._fedOut.isForcedFederated()) {
            processFederatedOutput(fedMap, input, executionContext);
        } else {
            processGetOutput(fedMap, auop, executionContext, input);
        }
    }

    /**
     * Executes the instruction through the federation map and registers the
     * output matrix as federated, using a copy of the input's federation map
     * re-labelled with the ID of the issued request.
     *
     * @throws DMLRuntimeException if the output operand is a scalar
     */
    private void processFederatedOutput(FederationMap federationMap, MatrixObject matrixObject, ExecutionContext executionContext) {
        if (this.output.isScalar()) {
            throw new DMLRuntimeException("Output of FED instruction, " + this.output.toString() + ", is a scalar and the output is set to be federated. Scalars cannot be federated. ");
        }
        FederatedRequest instRequest = FederationUtils.callInstruction(
            this.instString,
            this.output,
            new CPOperand[]{this.input1},
            new long[]{matrixObject.getFedMapping().getID()},
            true);
        federationMap.execute(getTID(), instRequest);

        MatrixObject result = executionContext.getMatrixObject(this.output);
        result.setFedMapping(matrixObject.getFedMapping().copyWithNewID(instRequest.getID()));
    }

    /**
     * Executes the instruction through the federation map, requests the
     * result variable back (followed by a cleanup request for it), and stores
     * the combined result in the execution context as a scalar variable or a
     * matrix output, depending on the output operand.
     */
    private void processGetOutput(FederationMap federationMap, AggregateUnaryOperator aggregateUnaryOperator, ExecutionContext executionContext, MatrixObject matrixObject) {
        FederatedRequest instRequest = FederationUtils.callInstruction(
            this.instString,
            this.output,
            new CPOperand[]{this.input1},
            new long[]{matrixObject.getFedMapping().getID()},
            true);
        Future<FederatedResponse>[] partials = federationMap.execute(
            getTID(),
            instRequest,
            new FederatedRequest(FederatedRequest.RequestType.GET_VAR, instRequest.getID()),
            federationMap.cleanup(getTID(), instRequest.getID()));

        if (this.output.isScalar()) {
            executionContext.setVariable(this.output.getName(), FederationUtils.aggScalar(aggregateUnaryOperator, partials, federationMap));
        } else {
            executionContext.setMatrixOutput(this.output.getName(), FederationUtils.aggMatrix(aggregateUnaryOperator, partials, federationMap));
        }
    }

    /**
     * Handles opcodes containing {@code "var"}.
     *
     * <p>Two requests are issued against the input's federation map: first a
     * copy of this instruction in which {@code "var"} inside the opcode is
     * replaced by {@code Statement.GAGG_FN_MEAN}, then the instruction itself.
     * Both result sets are handed to {@code FederationUtils.aggScalar} or
     * {@code FederationUtils.aggMatrix} to produce the final output.
     *
     * <p>Only called from {@link #processInstruction(ExecutionContext)} once
     * the opcode is known to contain {@code "var"}, so that check is not
     * repeated here.
     *
     * @throws DMLRuntimeException if the output is forced to be federated
     */
    private void processVar(ExecutionContext executionContext) {
        if (this._fedOut.isForcedFederated()) {
            throw new DMLRuntimeException("Output of " + toString() + " should not be federated since the instruction requires consolidation of partial results to be computed.");
        }
        AggregateUnaryOperator auop = (AggregateUnaryOperator) this._optr;
        MatrixObject input = executionContext.getMatrixObject(this.input1);
        FederationMap fedMap = input.getFedMapping();

        // First request: same instruction string with the opcode's "var"
        // swapped for the mean aggregate name.
        String meanInstString = this.instString.replace(
            getOpcode(),
            getOpcode().replace("var", Statement.GAGG_FN_MEAN));
        FederatedRequest meanRequest = FederationUtils.callInstruction(
            meanInstString,
            this.output,
            new CPOperand[]{this.input1},
            new long[]{input.getFedMapping().getID()});
        Future<FederatedResponse>[] meanPartials = fedMap.execute(
            getTID(),
            meanRequest,
            new FederatedRequest(FederatedRequest.RequestType.GET_VAR, meanRequest.getID()),
            fedMap.cleanup(getTID(), meanRequest.getID()));

        // Second request: the unmodified instruction.
        FederatedRequest varRequest = FederationUtils.callInstruction(
            this.instString,
            this.output,
            new CPOperand[]{this.input1},
            new long[]{input.getFedMapping().getID()});
        Future<FederatedResponse>[] varPartials = fedMap.execute(
            getTID(),
            varRequest,
            new FederatedRequest(FederatedRequest.RequestType.GET_VAR, varRequest.getID()),
            fedMap.cleanup(getTID(), varRequest.getID()));

        if (this.output.isScalar()) {
            executionContext.setVariable(this.output.getName(), FederationUtils.aggScalar(auop, varPartials, meanPartials, fedMap));
        } else {
            executionContext.setMatrixOutput(this.output.getName(), FederationUtils.aggMatrix(auop, varPartials, meanPartials, fedMap));
        }
    }
}
