package org.apache.sysds.runtime.instructions.fed;

import java.util.concurrent.Future;
import org.antlr.v4.runtime.tree.xpath.XPath;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/fed/CumulativeOffsetFEDInstruction.class */
/**
 * Federated instruction for the offset step of cumulative aggregates ({@code bcumoffk+},
 * {@code bcumoff*}, {@code bcumoff+*}, {@code bcumoffmin}, {@code bcumoffmax}).
 *
 * <p>When the opcode starts with {@code bcumoff} and the first input is row-federated, per-partition
 * offsets are assembled at the coordinator from the workers' partial results and applied back on the
 * workers (see {@link #processCumulativeInstruction(ExecutionContext)}). Otherwise the second input is
 * broadcast (sliced) to the workers and the instruction string is executed there unchanged.
 */
public class CumulativeOffsetFEDInstruction extends BinaryFEDInstruction {
    /** Message used whenever collecting a worker response fails. */
    private static final String GET_DATA_ERROR_MSG =
        "Federated Get data failed with exception on CumulativeOffsetFEDInstruction";

    /** Unary cumulative operator matching the opcode, or {@code null} for an unrecognized opcode. */
    private UnaryOperator _uop;

    /**
     * Creates the instruction and resolves the unary cumulative operator for the opcode.
     *
     * @param op        operator handed to the parent class ({@code null} when built by {@link #parseInstruction})
     * @param in1       first input operand
     * @param in2       second input operand
     * @param out       output operand
     * @param init      parsed init value (not used by this class)
     * @param broadcast parsed broadcast flag (not used by this class)
     * @param opcode    instruction opcode
     * @param istr      complete instruction string
     */
    private CumulativeOffsetFEDInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out,
        double init, boolean broadcast, String opcode, String istr) {
        super(FEDInstruction.FEDType.CumsumOffset, op, in1, in2, out, opcode, istr);
        this._uop = createUnaryOperator(opcode);
    }

    /** Maps a {@code bcumoff*} opcode to its unary cumulative operator; {@code null} if unknown. */
    private static UnaryOperator createUnaryOperator(String opcode) {
        if (opcode == null) {
            return null;
        }
        switch (opcode) {
            case "bcumoffk+":
                return new UnaryOperator(Builtin.getBuiltinFnObject("ucumk+"));
            case "bcumoff*":
                return new UnaryOperator(Builtin.getBuiltinFnObject("ucum*"));
            case "bcumoff+*":
                return new UnaryOperator(Builtin.getBuiltinFnObject("ucumk+*"));
            case "bcumoffmin":
                return new UnaryOperator(Builtin.getBuiltinFnObject("ucummin"));
            case "bcumoffmax":
                return new UnaryOperator(Builtin.getBuiltinFnObject("ucummax"));
            default:
                return null;
        }
    }

    /**
     * Parses an instruction string with fields opcode, in1, in2, out, init and broadcast flag.
     *
     * @param str instruction string
     * @return the parsed instruction
     */
    public static CumulativeOffsetFEDInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        InstructionUtils.checkNumFields(parts, 5);
        String opcode = parts[0];
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand out = new CPOperand(parts[3]);
        double init = Double.parseDouble(parts[4]);
        boolean broadcast = Boolean.parseBoolean(parts[5]);
        return new CumulativeOffsetFEDInstruction(null, in1, in2, out, init, broadcast, opcode, str);
    }

    @Override // org.apache.sysds.runtime.instructions.Instruction
    public void processInstruction(ExecutionContext executionContext) {
        MatrixObject mo1 = executionContext.getMatrixObject(this.input1);
        MatrixObject mo2 = executionContext.getMatrixObject(this.input2);
        if (getOpcode().startsWith("bcumoff") && mo1.isFederated(FederationMap.FType.ROW)) {
            processCumulativeInstruction(executionContext);
            return;
        }
        // Broadcast the second input in slices aligned with mo1's partitions and run the
        // instruction on every worker; the output is registered under the call's ID first.
        FederatedRequest[] frSliced = mo1.getFedMapping().broadcastSliced(mo2, false);
        FederatedRequest frCall = FederationUtils.callInstruction(this.instString, this.output,
            new CPOperand[] {this.input1, this.input2},
            new long[] {mo1.getFedMapping().getID(), frSliced[0].getID()});
        FederatedRequest frPut = new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, frCall.getID(),
            mo1.getDataCharacteristics(), mo1.getDataType());
        mo1.getFedMapping().execute(getTID(), true, frSliced, frPut, frCall);
        setOutputFedMapping(executionContext, mo1, frCall.getID());
    }

    /**
     * Runs the offset computation for a row-federated first input.
     *
     * <p>For {@code bcumoff+*} the instruction is executed on the workers and their results are
     * pulled back to derive one carry-in scalar per partition. For the other opcodes a column
     * aggregate ({@code uac...}) is pulled back per worker and cumulated at the coordinator. In both
     * cases the derived values are broadcast back, and the instruction is finally re-run on the workers.
     *
     * @param executionContext execution context holding the input and output variables
     */
    public void processCumulativeInstruction(ExecutionContext executionContext) {
        MatrixObject mo1 = executionContext.getMatrixObject(this.input1.getName());
        MatrixObject mo2 = executionContext.getMatrixObject(this.input2.getName());
        String opcode = getOpcode();
        MatrixObject out;
        if (opcode.equalsIgnoreCase("bcumoff+*")) {
            DataCharacteristics outDc = executionContext.getDataCharacteristics(this.output.getName());
            long outId = FederationUtils.getNextFedDataID();
            FederatedRequest frPut = new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, outId, outDc,
                mo1.getDataType());
            FederatedRequest frBroadcast = mo1.getFedMapping().broadcast(mo2);
            FederatedRequest frCall = FederationUtils.callInstruction(this.instString, this.output, outId,
                new CPOperand[] {this.input1, this.input2},
                new long[] {mo1.getFedMapping().getID(), frBroadcast.getID()}, Types.ExecType.SPARK, false);
            FederatedRequest frGet = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, frCall.getID());
            Future<FederatedResponse>[] responses =
                mo1.getFedMapping().execute(getTID(), true, frPut, frBroadcast, frCall, frGet);
            out = setOutputFedMapping(executionContext, mo1, frCall.getID());
            setScalingValues(executionContext, mo1, out, getScalars(mo1, responses));
        } else {
            // e.g. bcumoffk+ -> uack+ (column aggregate), and -> "+" (binary op applying the offsets)
            String aggOpcode = opcode.replace("bcumoff", "uac");
            String binaryOpcode = opcode.replace(opcode.contains("bcumoffk") ? "bcumoffk" : "bcumoff", "");
            double init = getInitValue(opcode);
            MatrixBlock offsets = getResultBlock(modifyAndGetInstruction(aggOpcode, mo1),
                (int) mo1.getNumColumns(), init, this._uop);
            out = executionContext.getMatrixObject(this.output);
            setScalingValues(binaryOpcode, executionContext, mo1, out, offsets, init);
        }
        processCumulative(out, mo2);
    }

    /**
     * Returns the neutral starting value used for the coordinator-side offsets:
     * 0 for {@code bcumoffk+}, 1 for {@code bcumoff*}, {@link Double#MAX_VALUE} for
     * {@code bcumoffmin} and {@code -Double.MAX_VALUE} otherwise.
     */
    private static double getInitValue(String opcode) {
        if (opcode.equalsIgnoreCase("bcumoffk+")) {
            return 0d;
        }
        if (opcode.equalsIgnoreCase("bcumoff*")) {
            return 1d;
        }
        if (opcode.equalsIgnoreCase("bcumoffmin")) {
            return Double.MAX_VALUE;
        }
        return -Double.MAX_VALUE;
    }

    /**
     * Executes a single-input variant of this instruction on every worker and fetches the results.
     *
     * <p>The opcode is replaced by {@code newOpcode}. Operand 3 and then operand 4 (twice) are removed,
     * and {@code SINGLE_BLOCK} is appended. Assuming operand 0 is the exec type, this drops the second
     * input, the init value and the broadcast flag — TODO confirm against the instruction format.
     *
     * @param newOpcode opcode to run on the workers
     * @param mo1       row-federated first input
     * @return one future per worker, each resolving to a response holding a {@link MatrixBlock}
     */
    private Future<FederatedResponse>[] modifyAndGetInstruction(String newOpcode, MatrixObject mo1) {
        String inst = InstructionUtils.concatOperands(
            InstructionUtils.removeOperand(
                InstructionUtils.removeOperand(
                    InstructionUtils.removeOperand(
                        InstructionUtils.replaceOperand(this.instString, 1, newOpcode), 3), 4), 4),
            AggBinaryOp.SparkAggType.SINGLE_BLOCK.name());
        long outId = FederationUtils.getNextFedDataID();
        FederatedRequest frPut = new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, outId,
            new MatrixCharacteristics(-1L, -1L), mo1.getDataType());
        FederatedRequest frCall = FederationUtils.callInstruction(inst, this.output, outId,
            new CPOperand[] {this.input1}, new long[] {mo1.getFedMapping().getID()}, Types.ExecType.SPARK, false);
        FederatedRequest frGet = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, frCall.getID());
        return mo1.getFedMapping().execute(getTID(), true, frPut, frCall, frGet);
    }

    /**
     * Final step: broadcasts the second input and re-runs the instruction on the workers.
     *
     * <p>Operand 2 of the instruction string is replaced by the output (presumably the first input
     * slot — TODO confirm). Afterwards the output's data characteristics are refreshed; for
     * {@code bcumoff+*} the output and every federated range are narrowed to a single column.
     *
     * @param out intermediate output, already federated
     * @param mo2 second input to broadcast
     */
    private void processCumulative(MatrixObject out, MatrixObject mo2) {
        String inst = InstructionUtils.replaceOperand(this.instString, 2, InstructionUtils.createOperand(this.output));
        FederatedRequest frBroadcast = out.getFedMapping().broadcast(mo2);
        FederatedRequest frCall = FederationUtils.callInstruction(inst, this.output, out.getFedMapping().getID(),
            new CPOperand[] {this.output, this.input2},
            new long[] {out.getFedMapping().getID(), frBroadcast.getID()}, Types.ExecType.SPARK, false);
        out.getFedMapping().execute(getTID(), true, frBroadcast, frCall);
        out.setFedMapping(out.getFedMapping().copyWithNewID(frCall.getID()));
        if (getOpcode().equalsIgnoreCase("bcumoff+*")) {
            out.getDataCharacteristics().set(out.getNumRows(), 1L, (int) out.getBlocksize());
            for (int i = 0; i < out.getFedMapping().getFederatedRanges().length; i++) {
                out.getFedMapping().getFederatedRanges()[i].setEndDim(1, 1L);
            }
        } else {
            out.getDataCharacteristics().set(out.getNumRows(), out.getNumColumns(), (int) out.getBlocksize());
        }
    }

    /**
     * Cumulates the workers' column aggregates into per-partition offsets at the coordinator.
     *
     * <p>Row 0 keeps {@code init}, and row {@code k+1} receives worker {@code k}'s aggregate (the last
     * worker's aggregate is not needed). The cumulative unary operator is then applied over the rows.
     *
     * @param responses worker responses, one {@code 1 x ncol} block each
     * @param ncol      number of columns of the first input
     * @param init      neutral starting value for the operator
     * @param uop       unary cumulative operator
     * @return a {@code responses.length x ncol} block of offsets
     */
    private static MatrixBlock getResultBlock(Future<FederatedResponse>[] responses, int ncol, double init,
        UnaryOperator uop) {
        MatrixBlock partials = new MatrixBlock(responses.length, ncol, init);
        for (int k = 0; k < responses.length - 1; k++) {
            partials.copy(k + 1, k + 1, 0, ncol - 1, getResponseBlock(responses[k]), true);
        }
        return partials.unaryOperations(uop, (MatrixValue) new MatrixBlock());
    }

    /**
     * Computes one carry-in scalar per row partition for {@code bcumoff+*}.
     *
     * <p>The last value of each worker's result is written into column 0 of the first aggregate matrix.
     * Applying {@code ucumk+*} to it and shifting down by one row gives the carried offsets (row 0 stays
     * 0). The result is {@code agg[:,1] * offsets + agg[:,0]}, taken from the second aggregate matrix.
     * With a single partition there is nothing to carry, so the shifted copy is skipped: the empty slice
     * would otherwise fail.
     *
     * @param mo1       row-federated first input
     * @param responses responses of the {@code bcumoff+*} execution, one per worker
     * @return a {@code responses.length x 1} block
     */
    private MatrixBlock getScalars(MatrixObject mo1, Future<FederatedResponse>[] responses) {
        MatrixBlock[] aggMatrices = getAggMatrices(mo1);
        MatrixBlock lastValues = aggMatrices[0];
        MatrixBlock firstRows = aggMatrices[1];
        for (int i = 0; i < responses.length; i++) {
            MatrixBlock partial = getResponseBlock(responses[i]);
            lastValues.setValue(i, 0, partial.getValue(partial.getNumRows() - 1, 0));
        }
        MatrixBlock carry = new MatrixBlock(responses.length, 1, 0d);
        if (carry.getNumRows() > 1) {
            MatrixBlock cumulative = lastValues.unaryOperations(
                new UnaryOperator(Builtin.getBuiltinFnObject("ucumk+*")), (MatrixValue) new MatrixBlock());
            carry.copy(1, carry.getNumRows() - 1, 0, 0, cumulative.slice(0, lastValues.getNumRows() - 2), true);
        }
        return firstRows.slice(0, firstRows.getNumRows() - 1, 1, 1)
            .binaryOperations(InstructionUtils.parseBinaryOperator("*"), (MatrixValue) carry,
                (MatrixValue) new MatrixBlock())
            .binaryOperationsInPlace(InstructionUtils.parseBinaryOperator("+"),
                (MatrixValue) firstRows.slice(0, firstRows.getNumRows() - 1, 0, 0));
    }

    /**
     * Runs {@code ucum*} on every worker and collects two {@code n x 2} matrices (one row per worker).
     *
     * <p>In the first matrix, column 1 holds the last-row value of column 1 and column 0 is left at 0
     * for the caller. The second matrix holds the first row of each worker's result.
     *
     * @param mo1 row-federated first input
     * @return {@code {lastValues, firstRows}}
     */
    private MatrixBlock[] getAggMatrices(MatrixObject mo1) {
        Future<FederatedResponse>[] responses = modifyAndGetInstruction("ucum*", mo1);
        MatrixBlock lastValues = new MatrixBlock(responses.length, 2, 0d);
        MatrixBlock firstRows = new MatrixBlock(responses.length, 2, 0d);
        for (int i = 0; i < responses.length; i++) {
            MatrixBlock partial = getResponseBlock(responses[i]);
            lastValues.setValue(i, 1, partial.getValue(partial.getNumRows() - 1, 1));
            firstRows.copy(i, i, 0, 1, partial.slice(0, 0), true);
        }
        return new MatrixBlock[] {lastValues, firstRows};
    }

    /**
     * Applies the {@code bcumoff+*} carry-in scalars on the workers via a ternary instruction.
     *
     * <p>Two dense helper matrices shaped like {@code mo1} are built:
     * <ul>
     *   <li>the first is 1 everywhere except 0 at column 0 of each later partition's first row;</li>
     *   <li>the second is 0 everywhere except the partition's scalar at that position.</li>
     * </ul>
     * They are registered temporarily, broadcast in slices, and removed again even on failure.
     *
     * @param ec      execution context used for the temporary variables
     * @param mo1     row-federated first input
     * @param out     output whose federation mapping is updated
     * @param scalars per-partition scalars from {@link #getScalars}
     */
    private void setScalingValues(ExecutionContext ec, MatrixObject mo1, MatrixObject out, MatrixBlock scalars) {
        int nrow = (int) mo1.getNumRows();
        int ncol = (int) mo1.getNumColumns();
        MatrixBlock mask = new MatrixBlock(nrow, ncol, 1d);
        MatrixBlock carry = new MatrixBlock(nrow, ncol, 0d);
        for (int i = 1; i < scalars.getNumRows(); i++) {
            int beginRow = (int) mo1.getFedMapping().getFederatedRanges()[i].getBeginDims()[0];
            mask.setValue(beginRow, 0, 0d);
            carry.setValue(beginRow, 0, scalars.getValue(i, 0));
        }
        MatrixObject maskMo = ExecutionContext.createMatrixObject(mask);
        long maskId = FederationUtils.getNextFedDataID();
        ec.setVariable(String.valueOf(maskId), maskMo);
        MatrixObject carryMo = ExecutionContext.createMatrixObject(carry);
        long carryId = FederationUtils.getNextFedDataID();
        ec.setVariable(String.valueOf(carryId), carryMo);
        CPOperand maskOp = new CPOperand(String.valueOf(maskId), Types.ValueType.FP64, Types.DataType.MATRIX);
        CPOperand carryOp = new CPOperand(String.valueOf(carryId), Types.ValueType.FP64, Types.DataType.MATRIX);
        try {
            String ternaryInst = InstructionUtils.constructTernaryString(this.instString, maskOp, this.input1,
                carryOp, this.output);
            FederatedRequest[] frMask = mo1.getFedMapping().broadcastSliced(maskMo, false);
            FederatedRequest[] frCarry = mo1.getFedMapping().broadcastSliced(carryMo, false);
            FederatedRequest frCall = FederationUtils.callInstruction(ternaryInst, this.output,
                new CPOperand[] {this.input1, maskOp, carryOp},
                new long[] {mo1.getFedMapping().getID(), frMask[0].getID(), frCarry[0].getID()});
            mo1.getFedMapping().execute(getTID(), true, frMask, frCarry, frCall);
            out.setFedMapping(mo1.getFedMapping().copyWithNewID(frCall.getID()));
        } finally {
            ec.removeVariable(maskOp.getName());
            ec.removeVariable(carryOp.getName());
        }
    }

    /**
     * Applies the coordinator-side offsets on the workers via a binary instruction.
     *
     * <p>A dense helper matrix shaped like {@code mo1} and filled with {@code init} receives, at the first
     * row of each later partition, that partition's offset row. It is registered temporarily, broadcast
     * in slices and removed again even on failure.
     *
     * @param binaryOpcode opcode of the binary instruction (e.g. {@code +}, {@code *}, {@code min}, {@code max})
     * @param ec           execution context used for the temporary variable
     * @param mo1          row-federated first input
     * @param out          output whose federation mapping is updated
     * @param offsets      per-partition offsets from {@link #getResultBlock}
     * @param init         neutral fill value for all other cells
     */
    private void setScalingValues(String binaryOpcode, ExecutionContext ec, MatrixObject mo1, MatrixObject out,
        MatrixBlock offsets, double init) {
        MatrixBlock offsetBlock = new MatrixBlock((int) mo1.getNumRows(), (int) mo1.getNumColumns(), init);
        for (int i = 1; i < offsets.getNumRows(); i++) {
            int beginRow = (int) mo1.getFedMapping().getFederatedRanges()[i].getBeginDims()[0];
            offsetBlock.copy(beginRow, beginRow, 0, (int) (mo1.getNumColumns() - 1), offsets.slice(i, i), true);
        }
        MatrixObject offsetMo = ExecutionContext.createMatrixObject(offsetBlock);
        long offsetId = FederationUtils.getNextFedDataID();
        ec.setVariable(String.valueOf(offsetId), offsetMo);
        CPOperand offsetOp = new CPOperand(String.valueOf(offsetId), Types.ValueType.FP64, Types.DataType.MATRIX);
        try {
            String binaryInst = InstructionUtils.constructBinaryInstString(this.instString, binaryOpcode,
                this.input1, offsetOp, this.output);
            long outId = FederationUtils.getNextFedDataID();
            FederatedRequest frPut = new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, outId,
                new MatrixCharacteristics(-1L, -1L), Types.DataType.MATRIX);
            FederatedRequest[] frOffsets = mo1.getFedMapping().broadcastSliced(offsetMo, false);
            FederatedRequest frCall = FederationUtils.callInstruction(binaryInst, this.output, outId,
                new CPOperand[] {this.input1, offsetOp},
                new long[] {mo1.getFedMapping().getID(), frOffsets[0].getID()}, Types.ExecType.SPARK, false);
            mo1.getFedMapping().execute(getTID(), true, frOffsets, frPut, frCall);
            out.setFedMapping(mo1.getFedMapping().copyWithNewID(frCall.getID()));
        } finally {
            ec.removeVariable(offsetOp.getName());
        }
    }

    /**
     * Waits for a worker response and returns its first data object as a matrix block.
     *
     * @throws DMLRuntimeException if waiting or reading the response fails (the interrupt flag is
     *                             restored when the wait was interrupted)
     */
    private static MatrixBlock getResponseBlock(Future<FederatedResponse> response) {
        try {
            return (MatrixBlock) response.get().getData()[0];
        } catch (Exception e) {
            if (e instanceof InterruptedException) {
                Thread.currentThread().interrupt();
            }
            throw new DMLRuntimeException(GET_DATA_ERROR_MSG, e);
        }
    }

    /**
     * Copies {@code matrixObject}'s data characteristics to the output variable and federates the
     * output with the same ranges under the new ID {@code j}.
     *
     * @return the output matrix object
     */
    private MatrixObject setOutputFedMapping(ExecutionContext executionContext, MatrixObject matrixObject, long j) {
        MatrixObject out = executionContext.getMatrixObject(this.output);
        out.getDataCharacteristics().set(matrixObject.getDataCharacteristics());
        out.setFedMapping(matrixObject.getFedMapping().copyWithNewID(j));
        return out;
    }
}
