package org.apache.sysds.runtime.instructions.fed;

import java.util.concurrent.Future;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.fedplanner.FTypes;
import org.apache.sysds.lops.MapMult;
import org.apache.sysds.lops.PMMJ;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.controlprogram.federated.MatrixLineagePair;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.instructions.spark.AggregateBinarySPInstruction;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;

/**
 * Federated instruction for the Spark matrix-multiplication opcodes ({@link MapMult#OPCODE},
 * {@link PMMJ#OPCODE}, {@code cpmm}, {@code rmm}) when at least one input is a federated matrix.
 *
 * <p>Depending on how the two inputs are federated, the multiplication is pushed down to the
 * workers (broadcasting or slicing the other operand as needed). The partial results are then
 * either summed or bound locally, or kept federated as the output's federation mapping.
 */
public class MMFEDInstruction extends BinaryFEDInstruction {
    private MMFEDInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
        super(FEDInstruction.FEDType.MAPMM, op, in1, in2, out, opcode, istr);
    }

    /**
     * Creates a federated instruction from an existing Spark aggregate-binary instruction,
     * reusing its operator, operands, opcode and instruction string.
     *
     * @param spInst the Spark instruction to convert
     * @return the equivalent federated instruction
     */
    public static MMFEDInstruction parseInstruction(AggregateBinarySPInstruction spInst) {
        return new MMFEDInstruction(spInst.getOperator(), spInst.input1, spInst.input2, spInst.output,
            spInst.getOpcode(), spInst.getInstructionString());
    }

    /**
     * Parses a federated matrix-multiplication instruction string.
     *
     * @param str the instruction string; part 0 is the opcode, parts 1-3 are the two inputs and the output
     * @return the parsed instruction, using a single-threaded matrix-mult operator
     * @throws DMLRuntimeException if the opcode is not one of the supported matrix-mult opcodes
     */
    public static MMFEDInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        if (!ArrayUtils.contains(new String[] {MapMult.OPCODE, PMMJ.OPCODE, "cpmm", "rmm"}, opcode)) {
            // Fixed: the message previously named MapmmSPInstruction, which is not the class parsing here.
            throw new DMLRuntimeException("MMFEDInstruction.parseInstruction():: Unknown opcode " + opcode);
        }
        Operator op = InstructionUtils.getMatMultOperator(1);
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand out = new CPOperand(parts[3]);
        return new MMFEDInstruction(op, in1, in2, out, opcode, str);
    }

    /**
     * Executes the multiplication {@code input1 %*% input2} on the federated workers.
     *
     * <p>Supported federation layouts, checked in this order:
     * <ol>
     *   <li>lhs column-federated, rhs row-federated, and the mappings aligned ({@code COL_T});</li>
     *   <li>lhs row- or part-federated (rhs is broadcast);</li>
     *   <li>rhs row-federated (lhs is sliced and broadcast);</li>
     *   <li>lhs column-federated (rhs is sliced and broadcast).</li>
     * </ol>
     *
     * @param ec the execution context holding the inputs and receiving the output
     * @throws DMLRuntimeException if none of the supported layouts applies
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        MatrixLineagePair mo1 = ec.getMatrixLineagePair(input1);
        MatrixLineagePair mo2 = ec.getMatrixLineagePair(input2);

        // The pushed-down instruction is issued under the same ID as this PUT_VAR. That ID is
        // allocated before branch selection, so even the unsupported case below consumes one.
        long outputId = FederationUtils.getNextFedDataID();
        FederatedRequest frPutOut = new FederatedRequest(FederatedRequest.RequestType.PUT_VAR, outputId,
            new MatrixCharacteristics(-1L, -1L), Types.DataType.MATRIX);

        // Case 1: aligned column/row federation. Every worker multiplies its two local pieces,
        // so the per-worker results are partial sums.
        if (mo1.isFederated(FTypes.FType.COL) && mo2.isFederated(FTypes.FType.ROW)
            && mo1.getFedMapping().isAligned(mo2.getFedMapping(), FTypes.AlignType.COL_T)) {
            FederatedRequest frMM = FederationUtils.callInstruction(instString, output, outputId,
                new CPOperand[] {input1, input2},
                new long[] {mo1.getFedMapping().getID(), mo2.getFedMapping().getID()},
                Types.ExecType.SPARK, false);
            if (_fedOut.isForcedFederated()) {
                mo1.getFedMapping().execute(getTID(), frPutOut, frMM);
                setPartialOutput(mo1.getFedMapping(), mo1.getMO(), mo2.getMO(), frMM.getID(), ec);
            } else {
                aggregateLocally(mo1.getFedMapping(), true, ec, frPutOut, frMM);
            }
            return;
        }

        // Case 2: lhs row- or part-federated. Broadcast the whole rhs to lhs's workers.
        if (mo1.isFederated(FTypes.FType.ROW) || mo1.isFederated(FTypes.FType.PART)) {
            FederatedRequest frBroadcast = mo1.getFedMapping().broadcast(mo2);
            FederatedRequest frMM = FederationUtils.callInstruction(instString, output, outputId,
                new CPOperand[] {input1, input2},
                new long[] {mo1.getFedMapping().getID(), frBroadcast.getID()},
                Types.ExecType.SPARK, false);
            boolean isVector = mo2.getNumColumns() == 1;
            // Part-federated lhs (MV and MM), or part-federated rhs for MM only, yields partial outputs.
            boolean isPartOut = mo1.isFederated(FTypes.FType.PART)
                || (!isVector && mo2.isFederated(FTypes.FType.PART));
            if (isPartOut && _fedOut.isForcedFederated()) {
                mo1.getFedMapping().execute(getTID(), true, frPutOut, frBroadcast, frMM);
                setPartialOutput(mo1.getFedMapping(), mo1.getMO(), mo2.getMO(), frMM.getID(), ec);
            } else if ((!_fedOut.isForcedFederated() && (isVector || _fedOut.isForcedLocal())) || isPartOut) {
                // Sum the partial results only when lhs is part-federated; otherwise bind the row blocks.
                aggregateLocally(mo1.getFedMapping(), mo1.isFederated(FTypes.FType.PART), ec,
                    frPutOut, frBroadcast, frMM);
            } else {
                mo1.getFedMapping().execute(getTID(), true, frPutOut, frBroadcast, frMM);
                setOutputFedMapping(mo1.getFedMapping(), mo1.getMO(), mo2.getMO(), frMM.getID(), ec);
            }
            return;
        }

        // Case 3: rhs row-federated. Slice lhs to match rhs's federation and broadcast the slices.
        if (mo2.isFederated(FTypes.FType.ROW)) {
            FederatedRequest[] frSliced = mo2.getFedMapping().broadcastSliced(mo1, true);
            FederatedRequest frMM = FederationUtils.callInstruction(instString, output, outputId,
                new CPOperand[] {input1, input2},
                new long[] {frSliced[0].getID(), mo2.getFedMapping().getID()},
                Types.ExecType.SPARK, false);
            if (_fedOut.isForcedFederated()) {
                mo2.getFedMapping().execute(getTID(), true, frSliced, frPutOut, frMM);
                setPartialOutput(mo2.getFedMapping(), mo1.getMO(), mo2.getMO(), frMM.getID(), ec);
            } else {
                aggregateLocally(mo2.getFedMapping(), true, ec, frSliced, frPutOut, frMM);
            }
            return;
        }

        // Case 4: lhs column-federated (without an aligned row-federated rhs). Slice rhs to match
        // lhs's federation; the only layout left that is supported.
        if (!mo1.isFederated(FTypes.FType.COL)) {
            throw new DMLRuntimeException("Federated AggregateBinary not supported with the following federated objects: "
                + mo1.isFederated() + ":" + mo1.getFedMapping() + " " + mo2.isFederated() + ":" + mo2.getFedMapping());
        }
        FederatedRequest[] frSliced = mo1.getFedMapping().broadcastSliced(mo2, true);
        FederatedRequest frMM = FederationUtils.callInstruction(instString, output, outputId,
            new CPOperand[] {input1, input2},
            new long[] {mo1.getFedMapping().getID(), frSliced[0].getID()},
            Types.ExecType.SPARK, false);
        if (_fedOut.isForcedFederated()) {
            mo1.getFedMapping().execute(getTID(), true, frSliced, frPutOut, frMM);
            setPartialOutput(mo1.getFedMapping(), mo1.getMO(), mo2.getMO(), frMM.getID(), ec);
        } else {
            aggregateLocally(mo1.getFedMapping(), true, ec, frSliced, frPutOut, frMM);
        }
    }

    /**
     * Keeps the result federated as partial results.
     *
     * <p>Sets the output to {@code rows(lhs) x cols(rhs)} with lhs's block size, and derives the
     * federation mapping from {@code fedMap} via {@code copyWithNewIDAndRange}. NOTE(review):
     * presumably every worker range then spans the full output, i.e. each worker holds a partial
     * sum — confirm against FederationMap.
     *
     * @param fedMap the mapping the pushed-down instruction ran on
     * @param lhs the left input
     * @param rhs the right input
     * @param outputId the worker-side ID holding the results
     * @param ec the execution context holding the output
     */
    private void setPartialOutput(FederationMap fedMap, MatrixObject lhs, MatrixObject rhs, long outputId,
        ExecutionContext ec) {
        MatrixObject out = ec.getMatrixObject(output);
        out.getDataCharacteristics().set(lhs.getNumRows(), rhs.getNumColumns(), lhs.getBlocksize());
        out.setFedMapping(fedMap.copyWithNewIDAndRange(lhs.getNumRows(), rhs.getNumColumns(), outputId));
    }

    /**
     * Keeps the result federated with the same partitioning as {@code fedMap}.
     *
     * <p>Sets the output to {@code rows(lhs) x cols(rhs)} with lhs's block size, and derives the
     * mapping via {@code copyWithNewID(outputId, cols(rhs))}. NOTE(review): presumably this keeps
     * the row ranges and only adjusts the column count — confirm.
     *
     * @param fedMap the mapping the pushed-down instruction ran on
     * @param lhs the left input
     * @param rhs the right input
     * @param outputId the worker-side ID holding the results
     * @param ec the execution context holding the output
     */
    private void setOutputFedMapping(FederationMap fedMap, MatrixObject lhs, MatrixObject rhs, long outputId,
        ExecutionContext ec) {
        MatrixObject out = ec.getMatrixObject(output);
        out.getDataCharacteristics().set(lhs.getNumRows(), rhs.getNumColumns(), lhs.getBlocksize());
        out.setFedMapping(fedMap.copyWithNewID(outputId, rhs.getNumColumns()));
    }

    /** Runs {@code fr} on the workers and aggregates locally; shorthand for the overload without slices. */
    private void aggregateLocally(FederationMap fedMap, boolean aggAdd, ExecutionContext ec, FederatedRequest... fr) {
        aggregateLocally(fedMap, aggAdd, ec, null, fr);
    }

    /**
     * Executes the given requests, fetches the result variable, cleans it up on the workers, and
     * combines the responses into the local output matrix.
     *
     * @param fedMap the federation mapping to execute on
     * @param aggAdd {@code true} to sum the worker results, {@code false} to bind them
     *               ({@code FederationUtils.bind(..., false)})
     * @param ec the execution context receiving the output
     * @param frSliced optional per-worker sliced requests, executed first; may be {@code null}
     * @param fr the requests to execute; the LAST one must produce the result variable
     */
    private void aggregateLocally(FederationMap fedMap, boolean aggAdd, ExecutionContext ec,
        FederatedRequest[] frSliced, FederatedRequest... fr) {
        long resultId = fr[fr.length - 1].getID();
        FederatedRequest frGet = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, resultId);
        FederatedRequest frCleanup = fedMap.cleanup(getTID(), resultId);
        FederatedRequest[] frAll = ArrayUtils.addAll(fr, new FederatedRequest[] {frGet, frCleanup});
        Future<FederatedResponse>[] responses = (frSliced != null)
            ? fedMap.execute(getTID(), frSliced, frAll)
            : fedMap.execute(getTID(), frAll);
        ec.setMatrixOutput(output.getName(),
            aggAdd ? FederationUtils.aggAdd(responses) : FederationUtils.bind(responses, false));
    }
}
