package org.apache.sysds.runtime.instructions.fed;

import java.util.concurrent.Future;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.matrix.operators.Operator;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.class */
/**
 * Federated instruction for the aggregate-binary operation with opcode {@code ba+*}
 * (matrix multiplication) where at least one input is a federated matrix.
 *
 * <p>Depending on how the two inputs are federated, the multiplication is pushed to the
 * federated workers (broadcasting or slicing the other operand as needed). The per-worker
 * results are then either fetched and combined locally (by aggregation or binding) or kept
 * federated by installing a new federation mapping on the output. The instruction's
 * {@link FEDInstruction.FederatedOutput} setting selects between the two.
 */
public class AggregateBinaryFEDInstruction extends BinaryFEDInstruction {

    /**
     * Creates the instruction via the superclass constructor that takes no
     * {@link FEDInstruction.FederatedOutput} argument.
     *
     * @param op     matrix-multiplication operator
     * @param in1    left input operand
     * @param in2    right input operand
     * @param out    output operand
     * @param opcode instruction opcode
     * @param istr   full instruction string
     */
    public AggregateBinaryFEDInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
        super(FEDInstruction.FEDType.AggregateBinary, op, in1, in2, out, opcode, istr);
    }

    /**
     * Creates the instruction with an explicit federated-output setting.
     *
     * @param op     matrix-multiplication operator
     * @param in1    left input operand
     * @param in2    right input operand
     * @param out    output operand
     * @param opcode instruction opcode
     * @param istr   full instruction string
     * @param fedOut whether the result should be kept federated, forced local, etc.
     */
    public AggregateBinaryFEDInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr, FEDInstruction.FederatedOutput fedOut) {
        super(FEDInstruction.FEDType.AggregateBinary, op, in1, in2, out, opcode, istr, fedOut);
    }

    /**
     * Parses an instruction string of the form
     * {@code ba+* ° in1 ° in2 ° out ° k ° fedOut} (five operand fields after the opcode).
     *
     * @param str the full instruction string
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is not {@code ba+*} or the field count is wrong
     */
    public static AggregateBinaryFEDInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        if (!opcode.equalsIgnoreCase("ba+*")) {
            throw new DMLRuntimeException("AggregateBinaryInstruction.parseInstruction():: Unknown opcode " + opcode);
        }
        InstructionUtils.checkNumFields(parts, 5);
        // Evaluated in the same order as the original single-expression form, so malformed
        // input fails at the same field first. parts[4] is the integer handed to
        // getMatMultOperator (presumably the degree of parallelism).
        Operator op = InstructionUtils.getMatMultOperator(Integer.parseInt(parts[4]));
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand out = new CPOperand(parts[3]);
        FEDInstruction.FederatedOutput fedOut = FEDInstruction.FederatedOutput.valueOf(parts[5]);
        return new AggregateBinaryFEDInstruction(op, in1, in2, out, opcode, str, fedOut);
    }

    /**
     * Executes the federated matrix multiplication {@code input1 %*% input2}.
     *
     * <p>Cases, checked in this order:
     * <ol>
     *   <li>lhs column-federated, rhs row-federated, and the two mappings aligned:
     *       multiply on the workers directly;</li>
     *   <li>lhs row- or partially-federated: broadcast the rhs to the lhs workers;</li>
     *   <li>rhs row-federated: broadcast-slice the lhs to the rhs workers;</li>
     *   <li>lhs column-federated: broadcast-slice the rhs to the lhs workers;</li>
     *   <li>anything else is rejected.</li>
     * </ol>
     *
     * @param ec execution context holding the input and output variables
     * @throws DMLRuntimeException if the combination of federation types is not supported
     */
    @Override // org.apache.sysds.runtime.instructions.Instruction
    public void processInstruction(ExecutionContext ec) {
        MatrixObject mo1 = ec.getMatrixObject(this.input1);
        MatrixObject mo2 = ec.getMatrixObject(this.input2);

        if (mo1.isFederated(FederationMap.FType.COL) && mo2.isFederated(FederationMap.FType.ROW)
                && mo1.getFedMapping().isAligned(mo2.getFedMapping(), true)) {
            processAlignedColRow(ec, mo1, mo2);
        } else if (mo1.isFederated(FederationMap.FType.ROW) || mo1.isFederated(FederationMap.FType.PART)) {
            processRowOrPartFederatedLhs(ec, mo1, mo2);
        } else if (mo2.isFederated(FederationMap.FType.ROW)) {
            processRowFederatedRhs(ec, mo1, mo2);
        } else if (mo1.isFederated(FederationMap.FType.COL)) {
            processColFederatedLhs(ec, mo1, mo2);
        } else {
            throw new DMLRuntimeException("Federated AggregateBinary not supported with the following federated objects: " + mo1.isFederated() + ":" + mo1.getFedMapping() + " " + mo2.isFederated() + ":" + mo2.getFedMapping());
        }
    }

    /** Case 1: aligned column-federated lhs and row-federated rhs; partial products summed. */
    private void processAlignedColRow(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2) {
        FederationMap fedMap1 = mo1.getFedMapping();
        FederatedRequest fr1 = FederationUtils.callInstruction(this.instString, this.output,
            new CPOperand[]{this.input1, this.input2},
            new long[]{fedMap1.getID(), mo2.getFedMapping().getID()}, true);
        if (this._fedOut.isForcedFederated()) {
            fedMap1.execute(getTID(), fr1);
            setPartialOutput(fedMap1, mo1, mo2, fr1.getID(), ec);
        } else {
            FederatedRequest fr2 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr1.getID());
            // The cleanup request is built from the rhs mapping but sent through the lhs mapping,
            // exactly as before.
            FederatedRequest fr3 = mo2.getFedMapping().cleanup(getTID(), fr1.getID(), fr2.getID());
            Future<FederatedResponse>[] responses = fedMap1.execute(getTID(), fr1, fr2, fr3);
            ec.setMatrixOutput(this.output.getName(), FederationUtils.aggAdd(responses));
        }
    }

    /** Case 3: row-federated rhs; the lhs is broadcast-sliced to the rhs workers. */
    private void processRowFederatedRhs(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2) {
        FederationMap fedMap2 = mo2.getFedMapping();
        FederatedRequest[] slices = fedMap2.broadcastSliced(mo1, true);
        FederatedRequest fr2 = FederationUtils.callInstruction(this.instString, this.output,
            new CPOperand[]{this.input1, this.input2},
            new long[]{slices[0].getID(), fedMap2.getID()}, true);
        if (this._fedOut.isForcedFederated()) {
            FederatedRequest cleanup = fedMap2.cleanup(getTID(), slices[0].getID());
            fedMap2.execute(getTID(), true, slices, fr2, cleanup);
            setPartialOutput(fedMap2, mo1, mo2, fr2.getID(), ec);
        } else {
            FederatedRequest getVar = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
            FederatedRequest cleanup = fedMap2.cleanup(getTID(), slices[0].getID(), fr2.getID());
            Future<FederatedResponse>[] responses = fedMap2.execute(getTID(), slices, fr2, getVar, cleanup);
            ec.setMatrixOutput(this.output.getName(), FederationUtils.aggAdd(responses));
        }
    }

    /** Case 4: column-federated lhs; the rhs is broadcast-sliced to the lhs workers. */
    private void processColFederatedLhs(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2) {
        FederationMap fedMap1 = mo1.getFedMapping();
        FederatedRequest[] slices = fedMap1.broadcastSliced(mo2, true);
        FederatedRequest fr2 = FederationUtils.callInstruction(this.instString, this.output,
            new CPOperand[]{this.input1, this.input2},
            new long[]{fedMap1.getID(), slices[0].getID()}, true);
        if (this._fedOut.isForcedFederated()) {
            FederatedRequest cleanup = fedMap1.cleanup(getTID(), slices[0].getID());
            fedMap1.execute(getTID(), true, slices, fr2, cleanup);
            setPartialOutput(fedMap1, mo1, mo2, fr2.getID(), ec);
        } else {
            FederatedRequest getVar = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
            FederatedRequest cleanup = fedMap1.cleanup(getTID(), slices[0].getID(), fr2.getID());
            Future<FederatedResponse>[] responses = fedMap1.execute(getTID(), slices, fr2, getVar, cleanup);
            ec.setMatrixOutput(this.output.getName(), FederationUtils.aggAdd(responses));
        }
    }

    /**
     * Case 2: row- or partially-federated lhs; the full rhs is broadcast to the lhs workers.
     *
     * <p>The matrix-vector case returns a local result unless federated output is forced. The
     * general case keeps the result federated unless local output is forced.
     */
    private void processRowOrPartFederatedLhs(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2) {
        FederationMap fedMap1 = mo1.getFedMapping();
        FederatedRequest broadcast = fedMap1.broadcast(mo2);
        FederatedRequest fr2 = FederationUtils.callInstruction(this.instString, this.output,
            new CPOperand[]{this.input1, this.input2},
            new long[]{fedMap1.getID(), broadcast.getID()}, true);

        // NOTE(review): for a dimension-compatible product, mo2 rows equal mo1 columns, so this
        // test looks like it only holds for inconsistent/unknown dimensions — confirm the
        // intended matrix-vector check. Kept unchanged to preserve behavior.
        boolean matrixVector = mo2.getNumColumns() == 1 && mo2.getNumRows() != mo1.getNumColumns();
        boolean keepFederated = matrixVector
            ? this._fedOut.isForcedFederated()
            : !this._fedOut.isForcedLocal();

        if (!keepFederated) {
            FederatedRequest getVar = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
            FederatedRequest cleanup = fedMap1.cleanup(getTID(), broadcast.getID(), fr2.getID());
            Future<FederatedResponse>[] responses = fedMap1.execute(getTID(), broadcast, fr2, getVar, cleanup);
            // Partially-federated lhs yields summable partial products; otherwise row blocks are bound.
            ec.setMatrixOutput(this.output.getName(), mo1.isFederated(FederationMap.FType.PART)
                ? FederationUtils.aggAdd(responses)
                : FederationUtils.bind(responses, false));
            return;
        }

        FederatedRequest cleanup = fedMap1.cleanup(getTID(), broadcast.getID());
        // Pass the same boolean (wait) flag as the other forced-federated paths that also bundle a
        // cleanup request. The matrix-vector path previously omitted it. Presumably this blocks
        // until the workers respond, so failures surface here instead of being dropped with
        // the unused futures.
        fedMap1.execute(getTID(), true, broadcast, fr2, cleanup);

        boolean partial = matrixVector
            ? mo1.isFederated(FederationMap.FType.PART)
            : mo1.isFederated(FederationMap.FType.PART) || mo2.isFederated(FederationMap.FType.PART);
        if (partial) {
            setPartialOutput(fedMap1, mo1, mo2, fr2.getID(), ec);
        } else {
            setOutputFedMapping(fedMap1, mo1, mo2, fr2.getID(), ec);
        }
    }

    /**
     * Sets up the output variable so that it refers to per-worker partial results.
     *
     * <p>Sets the output dimensions to {@code mo1.rows x mo2.cols} with {@code mo1}'s block size,
     * and installs a copy of {@code federationMap} re-targeted to variable {@code j} via
     * {@code copyWithNewIDAndRange} over the full output shape.
     *
     * @param federationMap mapping whose workers hold the partial results
     * @param mo1           left input
     * @param mo2           right input
     * @param j             worker-side variable ID holding the result
     * @param ec            execution context providing the output variable
     */
    private void setPartialOutput(FederationMap federationMap, MatrixObject mo1, MatrixObject mo2, long j, ExecutionContext ec) {
        MatrixObject out = ec.getMatrixObject(this.output);
        out.getDataCharacteristics().set(mo1.getNumRows(), mo2.getNumColumns(), (int) mo1.getBlocksize());
        out.setFedMapping(federationMap.copyWithNewIDAndRange(mo1.getNumRows(), mo2.getNumColumns(), j));
    }

    /**
     * Sets up the output variable as a federated matrix derived from {@code federationMap}.
     *
     * <p>Sets the output dimensions to {@code mo1.rows x mo2.cols} with {@code mo1}'s block size,
     * and installs {@code federationMap.copyWithNewID(j, mo2.cols)} as its mapping.
     *
     * @param federationMap mapping whose workers hold the result blocks
     * @param mo1           left input
     * @param mo2           right input
     * @param j             worker-side variable ID holding the result
     * @param ec            execution context providing the output variable
     */
    private void setOutputFedMapping(FederationMap federationMap, MatrixObject mo1, MatrixObject mo2, long j, ExecutionContext ec) {
        MatrixObject out = ec.getMatrixObject(this.output);
        out.getDataCharacteristics().set(mo1.getNumRows(), mo2.getNumColumns(), (int) mo1.getBlocksize());
        out.setFedMapping(federationMap.copyWithNewID(j, mo2.getNumColumns()));
    }
}
