package org.apache.sysds.runtime.instructions.fed;

import org.apache.sysds.hops.fedplanner.FTypes;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.controlprogram.federated.MatrixLineagePair;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.class */
public class BinaryMatrixMatrixFEDInstruction extends BinaryFEDInstruction {

    /**
     * Creates the instruction; all arguments are forwarded unchanged to {@link BinaryFEDInstruction}
     * together with {@code FEDType.Binary}.
     *
     * <p>NOTE(review): the decompiler reported the original modifier as {@code protected}; it is
     * kept {@code public} here so existing callers keep compiling.
     */
    public BinaryMatrixMatrixFEDInstruction(Operator operator, CPOperand cPOperand, CPOperand cPOperand2, CPOperand cPOperand3, String str, String str2, FEDInstruction.FederatedOutput federatedOutput) {
        super(FEDInstruction.FEDType.Binary, operator, cPOperand, cPOperand2, cPOperand3, str, str2, federatedOutput);
    }

    /**
     * Executes the binary matrix-matrix operation on the federated worker(s) and registers a
     * federation mapping for the output variable.
     *
     * <p>After an optional operand swap (commutative ops with equal dims only), dispatch is:
     * <ul>
     *   <li>rhs federated with a type other than BROADCAST: if the lhs is federated and aligned
     *       with the rhs, both are referenced in place; otherwise the lhs is sliced along the rhs
     *       federation map and shipped to the rhs workers;</li>
     *   <li>rhs BROADCAST-federated and lhs local: the lhs is broadcast to the rhs workers;</li>
     *   <li>otherwise the lhs must be federated and the rhs is broadcast (in full or sliced) to
     *       the lhs workers.</li>
     * </ul>
     *
     * @param executionContext context used to resolve the operands and the output variable
     * @throws DMLRuntimeException if the federation layout of the inputs is not supported
     */
    @Override // org.apache.sysds.runtime.instructions.Instruction
    public void processInstruction(ExecutionContext executionContext) {
        MatrixLineagePair lhs = executionContext.getMatrixLineagePair(this.input1);
        MatrixLineagePair rhs = executionContext.getMatrixLineagePair(this.input2);

        // Canonicalization: move a federated rhs to the lhs so the lhs-federated paths apply.
        // input1/input2 keep their positions in the instruction, which is only valid because
        // the operator is commutative and the shapes are equal.
        if (!lhs.isFederated() && rhs.isFederated()
            && lhs.getDataCharacteristics().equalDims(rhs.getDataCharacteristics())
            && ((BinaryOperator) this._optr).isCommutative()) {
            MatrixLineagePair tmp = lhs;
            lhs = rhs;
            rhs = tmp;
        }

        final FederatedRequest fr;   // request that runs the instruction on the workers
        final MatrixObject fedMo;    // federated operand whose mapping the output derives from
        if (rhs.isFederatedExcept(FTypes.FType.BROADCAST)) {
            // The aligned and sliced paths are mutually exclusive: running both would execute
            // the operation twice on the workers.
            if (lhs.isFederated() && areAligned(lhs, rhs)) {
                fr = executeAligned(lhs, rhs);
            } else {
                fr = executeWithSlicedLhs(lhs, rhs);
            }
            fedMo = rhs.getMO();
        } else if (rhs.isFederated(FTypes.FType.BROADCAST) && !lhs.isFederated()) {
            fr = executeWithBroadcastLhs(lhs, rhs);
            fedMo = rhs.getMO();
        } else {
            fr = executeOnFederatedLhs(lhs, rhs);
            fedMo = lhs.getMO();
        }

        if (lhs.isFederated(FTypes.FType.PART) && !rhs.isFederated()) {
            setOutputFedMappingPart(lhs.getMO(), rhs.getMO(), fr.getID(), executionContext);
        } else {
            if (!fedMo.isFederated()) {
                throw new DMLRuntimeException("Input is not federated, so the output FedMapping cannot be set!");
            }
            setOutputFedMapping(fedMo,
                Math.max(lhs.getNumRows(), rhs.getNumRows()),
                Math.max(lhs.getNumColumns(), rhs.getNumColumns()),
                fr.getID(), executionContext);
        }
    }

    /** Checks whether the federation maps are aligned along ROW if lhs is row-federated, else COL. */
    private static boolean areAligned(MatrixLineagePair lhs, MatrixLineagePair rhs) {
        FTypes.AlignType alignType = lhs.isFederated(FTypes.FType.ROW) ? FTypes.AlignType.ROW : FTypes.AlignType.COL;
        return lhs.getFedMapping().isAligned(rhs.getFedMapping(), new FTypes.AlignType[]{alignType});
    }

    /** Runs the instruction on aligned federated partitions of both operands, without data transfer. */
    private FederatedRequest executeAligned(MatrixLineagePair lhs, MatrixLineagePair rhs) {
        FederatedRequest fr = FederationUtils.callInstruction(this.instString, this.output,
            new CPOperand[]{this.input1, this.input2},
            new long[]{lhs.getFedMapping().getID(), rhs.getFedMapping().getID()}, true);
        rhs.getFedMapping().execute(getTID(), true, fr);
        return fr;
    }

    /** Slices the lhs along the rhs federation map, ships the slices and runs on the rhs workers. */
    private FederatedRequest executeWithSlicedLhs(MatrixLineagePair lhs, MatrixLineagePair rhs) {
        FederationMap fedMap = rhs.getFedMapping();
        FederatedRequest[] slices = fedMap.broadcastSliced(lhs, false);
        FederatedRequest fr = FederationUtils.callInstruction(this.instString, this.output,
            new CPOperand[]{this.input1, this.input2},
            new long[]{slices[0].getID(), fedMap.getID()}, true);
        fedMap.execute(getTID(), true, slices, fr);
        return fr;
    }

    /** Broadcasts the local lhs in full to the (BROADCAST-federated) rhs workers and runs there. */
    private FederatedRequest executeWithBroadcastLhs(MatrixLineagePair lhs, MatrixLineagePair rhs) {
        FederationMap fedMap = rhs.getFedMapping();
        FederatedRequest broadcast = fedMap.broadcast(lhs);
        // The lhs broadcast binds to input1 and the rhs federated data to input2, preserving
        // operand order for non-commutative operators.
        FederatedRequest fr = FederationUtils.callInstruction(this.instString, this.output,
            new CPOperand[]{this.input1, this.input2},
            new long[]{broadcast.getID(), fedMap.getID()}, true);
        fedMap.execute(getTID(), true, broadcast, fr);
        return fr;
    }

    /** Dispatches the lhs-federated cases; the rhs is shipped to the lhs workers. */
    private FederatedRequest executeOnFederatedLhs(MatrixLineagePair lhs, MatrixLineagePair rhs) {
        if (lhs.isFederated(FTypes.FType.FULL)) {
            if (lhs.getFedMapping().getSize() != 1) {
                throw new DMLRuntimeException("Matrix-matrix binary operations with a full partitioned federated input with multiple partitions are not supported yet.");
            }
            return executeWithBroadcastRhs(lhs, rhs);
        }
        // A row vector for a row-federated lhs (or column vector for column-federated) is needed
        // in full by every worker.
        if ((lhs.isFederated(FTypes.FType.ROW) && rhs.getNumRows() == 1)
            || (lhs.isFederated(FTypes.FType.COL) && rhs.getNumColumns() == 1)) {
            return executeWithBroadcastRhs(lhs, rhs);
        }
        if (lhs.isFederated(FTypes.FType.ROW) ^ lhs.isFederated(FTypes.FType.COL)) {
            FederationMap fedMap = lhs.getFedMapping();
            FederatedRequest[] slices = fedMap.broadcastSliced(rhs, false);
            FederatedRequest fr = FederationUtils.callInstruction(this.instString, this.output,
                new CPOperand[]{this.input1, this.input2},
                new long[]{fedMap.getID(), slices[0].getID()}, true);
            fedMap.execute(getTID(), true, slices, fr);
            return fr;
        }
        if (!lhs.isFederated(FTypes.FType.PART) || rhs.isFederated()) {
            throw new DMLRuntimeException("Matrix-matrix binary operations are only supported with a row partitioned or column partitioned federated input yet.");
        }
        return executeWithBroadcastRhs(lhs, rhs);
    }

    /** Broadcasts the rhs in full to the lhs workers and runs the instruction there. */
    private FederatedRequest executeWithBroadcastRhs(MatrixLineagePair lhs, MatrixLineagePair rhs) {
        FederationMap fedMap = lhs.getFedMapping();
        FederatedRequest broadcast = fedMap.broadcast(rhs);
        FederatedRequest fr = FederationUtils.callInstruction(this.instString, this.output,
            new CPOperand[]{this.input1, this.input2},
            new long[]{fedMap.getID(), broadcast.getID()}, true);
        fedMap.execute(getTID(), true, broadcast, fr);
        return fr;
    }

    /**
     * Sets the output metadata and mapping for a PART-federated lhs combined with a local rhs.
     *
     * <p>Output dims are the element-wise maximum of both inputs (matching {@code processInstruction}
     * for the other layouts), so a broadcast vector rhs does not shrink the output.
     *
     * @param matrixObject  PART-federated input
     * @param matrixObject2 local input
     * @param j             ID of the federated request that produced the output
     * @param executionContext context holding the output variable
     */
    private void setOutputFedMappingPart(MatrixObject matrixObject, MatrixObject matrixObject2, long j, ExecutionContext executionContext) {
        MatrixObject out = executionContext.getMatrixObject(this.output);
        long rows = Math.max(matrixObject.getNumRows(), matrixObject2.getNumRows());
        long cols = Math.max(matrixObject.getNumColumns(), matrixObject2.getNumColumns());
        out.getDataCharacteristics().set(rows, cols, (int) matrixObject.getBlocksize());
        out.setFedMapping(matrixObject.getFedMapping().copyWithNewIDAndRange(rows, cols, j));
    }

    /**
     * Sets the output metadata and a copy of {@code matrixObject}'s federation map (re-keyed to
     * {@code j}) on the output variable, adjusting fed ranges if the output dims differ.
     *
     * @param matrixObject     federated input whose mapping the output inherits
     * @param j                number of output rows
     * @param j2               number of output columns
     * @param j3               ID of the federated request that produced the output
     * @param executionContext context holding the output variable
     */
    private void setOutputFedMapping(MatrixObject matrixObject, long j, long j2, long j3, ExecutionContext executionContext) {
        MatrixObject out = executionContext.getMatrixObject(this.output);
        FederationMap outMap = matrixObject.getFedMapping().copyWithNewID(j3);
        if (matrixObject.getNumRows() != j || matrixObject.getNumColumns() != j2) {
            // COL-federated inputs get their dimension 0 (rows) adjusted, all others dimension 1.
            // NOTE(review): presumably modifyFedRanges sets the range end of that dimension.
            if (matrixObject.isFederated(FTypes.FType.COL)) {
                outMap.modifyFedRanges(j, 0);
            } else {
                outMap.modifyFedRanges(j2, 1);
            }
        }
        out.getDataCharacteristics().set(matrixObject.getDataCharacteristics()).setRows(j).setCols(j2);
        out.setFedMapping(outMap);
    }
}
