package org.apache.sysds.runtime.instructions.fed;

import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.matrix.operators.Operator;

/**
 * Federated instruction for a binary operation between two matrix inputs
 * ({@code input1}, {@code input2}) producing {@code output}.
 *
 * <p>Execution strategy, chosen by the shape and location of the right input:
 * <ul>
 *   <li>Federated right input: supported only if the left input is federated and both
 *       federation mappings are aligned; the instruction is sent to the workers referencing
 *       both federated data objects directly.</li>
 *   <li>Local right input with more than one row and exactly one column (column vector):
 *       the vector is distributed with {@code broadcastSliced} before the instruction runs.</li>
 *   <li>Any other local right input: the whole matrix is distributed with {@code broadcast}.</li>
 * </ul>
 * In the two broadcast cases a cleanup request for the broadcast data is submitted in the same
 * {@code execute} call as the instruction itself.
 *
 * <p>The output takes over the left input's data characteristics and a copy of its federation
 * mapping, so the left input must be federated in every case.
 */
public class BinaryMatrixMatrixFEDInstruction extends BinaryFEDInstruction {

    /**
     * Creates a federated matrix-matrix binary instruction.
     *
     * <p>NOTE(review): the decompiler reported this constructor as originally {@code protected};
     * it is kept {@code public} so existing callers are unaffected.
     *
     * @param op     the binary operator
     * @param in1    left input operand
     * @param in2    right input operand
     * @param out    output operand
     * @param opcode presumably the instruction opcode (forwarded unchanged to the superclass)
     * @param istr   presumably the full instruction string (forwarded unchanged to the superclass)
     */
    public BinaryMatrixMatrixFEDInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
        super(FEDInstruction.FEDType.Binary, op, in1, in2, out, opcode, istr);
    }

    /**
     * Executes the binary operation on the federated workers and registers the federated output.
     *
     * @param ec execution context holding the input and output matrix objects
     * @throws DMLRuntimeException if the right input is federated but the left input is not
     *         federated or the two mappings are not aligned, or if the right input is local and
     *         the left input is not federated
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        MatrixObject mo1 = ec.getMatrixObject(input1);
        MatrixObject mo2 = ec.getMatrixObject(input2);

        FederatedRequest fr;
        if (mo2.isFederated()) {
            fr = executeWithFederatedRhs(mo1, mo2);
        } else {
            // All local-rhs strategies dispatch through the left input's federation mapping;
            // without this guard a local left input surfaces as a NullPointerException.
            if (!mo1.isFederated()) {
                throw new DMLRuntimeException("Federated matrix-matrix binary operations require a federated "
                    + "left input when the right input is local.");
            }
            fr = isColumnVector(mo2)
                ? executeWithSlicedBroadcast(mo1, mo2)
                : executeWithFullBroadcast(mo1, mo2);
        }

        // The result stays on the left input's workers: reuse its dimensions and a copy of its
        // federation mapping re-labelled with the ID of the executed instruction request.
        MatrixObject out = ec.getMatrixObject(output);
        out.getDataCharacteristics().set(mo1.getDataCharacteristics());
        out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr.getID()));
    }

    /** Runs the instruction on workers when both inputs are federated with aligned mappings. */
    private FederatedRequest executeWithFederatedRhs(MatrixObject mo1, MatrixObject mo2) {
        if (!mo1.isFederated() || !mo1.getFedMapping().isAligned(mo2.getFedMapping(), false)) {
            throw new DMLRuntimeException("Matrix-matrix binary operations with a federated right input are not supported yet.");
        }
        FederatedRequest fr = FederationUtils.callInstruction(instString, output,
            new CPOperand[] {input1, input2},
            new long[] {mo1.getFedMapping().getID(), mo2.getFedMapping().getID()});
        mo1.getFedMapping().execute(getTID(), true, fr);
        return fr;
    }

    /** Distributes a local column vector with broadcastSliced, runs the instruction, then cleans up. */
    private FederatedRequest executeWithSlicedBroadcast(MatrixObject mo1, MatrixObject mo2) {
        // NOTE(review): the meaning of the boolean flag of broadcastSliced is not visible here.
        FederatedRequest[] slices = mo1.getFedMapping().broadcastSliced(mo2, false);
        // Only the first slice's ID is referenced; presumably all slices share one variable ID —
        // confirm against FederationMap.broadcastSliced.
        FederatedRequest fr = FederationUtils.callInstruction(instString, output,
            new CPOperand[] {input1, input2},
            new long[] {mo1.getFedMapping().getID(), slices[0].getID()});
        mo1.getFedMapping().execute(getTID(), true, slices, fr,
            mo1.getFedMapping().cleanup(getTID(), slices[0].getID()));
        return fr;
    }

    /** Broadcasts the whole local right input, runs the instruction, then cleans up the broadcast. */
    private FederatedRequest executeWithFullBroadcast(MatrixObject mo1, MatrixObject mo2) {
        FederatedRequest bc = mo1.getFedMapping().broadcast(mo2);
        FederatedRequest fr = FederationUtils.callInstruction(instString, output,
            new CPOperand[] {input1, input2},
            new long[] {mo1.getFedMapping().getID(), bc.getID()});
        mo1.getFedMapping().execute(getTID(), true, bc, fr,
            mo1.getFedMapping().cleanup(getTID(), bc.getID()));
        return fr;
    }

    /** Returns true if the matrix has more than one row and exactly one column. */
    private static boolean isColumnVector(MatrixObject mo) {
        return mo.getNumRows() > 1 && mo.getNumColumns() == 1;
    }
}
