package org.apache.sysds.runtime.instructions.fed;

import java.util.ArrayList;
import java.util.concurrent.Future;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.fedplanner.FTypes;
import org.apache.sysds.lops.WeightedDivMM;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.controlprogram.federated.MatrixLineagePair;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.matrix.operators.QuaternaryOperator;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/fed/QuaternaryWDivMMFEDInstruction.class */
public class QuaternaryWDivMMFEDInstruction extends QuaternaryFEDInstruction {
    /** Operator of this instruction; its {@code wtype3} selects the WDivMM variant. */
    private final QuaternaryOperator _qop;

    /**
     * Creates a federated WDivMM (weighted divide matrix multiplication) instruction.
     * <p>
     * NOTE: JADX reported that this constructor was {@code protected} in the bytecode; it is kept
     * {@code public} so existing callers continue to compile.
     *
     * @param quaternaryOperator operator carrying the WDivMM variant
     * @param cPOperand  first operand, presumably bound to {@code input1} (X) — TODO confirm
     * @param cPOperand2 second operand, presumably bound to {@code input2} (U) — TODO confirm
     * @param cPOperand3 third operand, presumably bound to {@code input3} (V) — TODO confirm
     * @param cPOperand4 presumably the optional fourth input {@code _input4} — TODO confirm
     * @param cPOperand5 presumably the output operand — TODO confirm
     * @param str  passed through to the superclass (presumably the opcode) — TODO confirm
     * @param str2 passed through to the superclass (presumably the instruction string) — TODO confirm
     */
    public QuaternaryWDivMMFEDInstruction(QuaternaryOperator quaternaryOperator, CPOperand cPOperand, CPOperand cPOperand2, CPOperand cPOperand3, CPOperand cPOperand4, CPOperand cPOperand5, String str, String str2) {
        super(FEDInstruction.FEDType.Quaternary, quaternaryOperator, cPOperand, cPOperand2, cPOperand3, cPOperand4, cPOperand5, str, str2);
        this._qop = quaternaryOperator;
    }

    /**
     * Executes WDivMM on the federated workers that hold X ({@code input1}).
     * <p>
     * For ROW-partitioned X:
     * <ul>
     *   <li>U is used in place when it is row-federated and row-aligned with X; otherwise it is
     *       broadcast sliced.</li>
     *   <li>V is broadcast in full.</li>
     * </ul>
     * For COL-partitioned X the roles are swapped:
     * <ul>
     *   <li>U is broadcast in full.</li>
     *   <li>V is used in place when aligned (COL / COL_T); otherwise it is broadcast sliced.</li>
     * </ul>
     * The optional fourth input is either a matrix MX (MULT_MINUS_4 variants), which is used in
     * place when fully aligned and sliced otherwise, or a scalar epsilon, which is broadcast.
     * <p>
     * LEFT variants on ROW-partitioned X and RIGHT variants on COL-partitioned X produce partial
     * results. These are fetched and summed ("uak+") into a local output matrix. All other
     * supported cases keep the output federated.
     *
     * @param executionContext context holding the inputs and receiving the output
     * @throws DMLRuntimeException if X is not federated, is neither ROW nor COL partitioned, or
     *         the WDivMM variant is not BASIC, LEFT or RIGHT
     */
    @Override // org.apache.sysds.runtime.instructions.Instruction
    public void processInstruction(ExecutionContext executionContext) {
        final WeightedDivMM.WDivMMType wDivMMType = _qop.wtype3;
        MatrixObject x = executionContext.getMatrixObject(input1);
        MatrixLineagePair u = executionContext.getMatrixLineagePair(input2);
        MatrixLineagePair v = executionContext.getMatrixLineagePair(input3);
        ScalarObject eps = null;
        MatrixLineagePair mx = null;
        if (_qop.hasFourInputs()) {
            if (wDivMMType == WeightedDivMM.WDivMMType.MULT_MINUS_4_LEFT
                || wDivMMType == WeightedDivMM.WDivMMType.MULT_MINUS_4_RIGHT) {
                mx = executionContext.getMatrixLineagePair(_input4);
            } else {
                eps = readEpsilon(executionContext);
            }
        }

        // Validate everything up front so nothing is broadcast or executed remotely for an
        // unsupported configuration.
        if (!x.isFederated()) {
            throw new DMLRuntimeException("Unsupported federated inputs (X, U, V) = (" + x.isFederated() + ", " + u.isFederated() + ", " + v.isFederated() + ")");
        }
        final boolean rowPartitioned = x.isFederated(FTypes.FType.ROW);
        final boolean colPartitioned = x.isFederated(FTypes.FType.COL);
        if (!rowPartitioned && !colPartitioned) {
            throw new DMLRuntimeException("Federated WDivMM only supported for ROW or COLUMN partitioned federated data.");
        }
        if (!wDivMMType.isLeft() && !wDivMMType.isRight() && !wDivMMType.isBasic()) {
            throw new DMLRuntimeException("Federated WDivMM only supported for BASIC, LEFT or RIGHT variants.");
        }
        // Partial results must be fetched and summed locally when the reduction runs across the
        // partitioning dimension.
        final boolean aggregateOutput = (wDivMMType.isLeft() && rowPartitioned)
            || (wDivMMType.isRight() && colPartitioned);

        FederationMap fedMap = x.getFedMapping();
        ArrayList<FederatedRequest[]> slicedRequests = new ArrayList<>();
        ArrayList<FederatedRequest> broadcastRequests = new ArrayList<>();
        long[] varNewIn = new long[_qop.hasFourInputs() ? 4 : 3];
        varNewIn[0] = fedMap.getID();

        if (rowPartitioned) {
            if (u.isFederated(FTypes.FType.ROW) && fedMap.isAligned(u.getFedMapping(), FTypes.AlignType.ROW)) {
                varNewIn[1] = u.getFedMapping().getID(); // U already co-located with X
            } else {
                FederatedRequest[] uSliced = fedMap.broadcastSliced(u, false);
                varNewIn[1] = uSliced[0].getID();
                slicedRequests.add(uSliced);
            }
            FederatedRequest vBroadcast = fedMap.broadcast(v);
            varNewIn[2] = vBroadcast.getID();
            broadcastRequests.add(vBroadcast);
        } else {
            FederatedRequest uBroadcast = fedMap.broadcast(u);
            varNewIn[1] = uBroadcast.getID();
            broadcastRequests.add(uBroadcast);
            if (v.isFederated() && fedMap.isAligned(v.getFedMapping(), FTypes.AlignType.COL, FTypes.AlignType.COL_T)) {
                varNewIn[2] = v.getFedMapping().getID(); // V already co-located with X
            } else {
                FederatedRequest[] vSliced = fedMap.broadcastSliced(v, true);
                varNewIn[2] = vSliced[0].getID();
                slicedRequests.add(vSliced);
            }
        }

        // Optional fourth matrix input.
        if (mx != null) {
            if (mx.isFederated() && fedMap.isAligned(mx.getFedMapping(), FTypes.AlignType.FULL)) {
                varNewIn[3] = mx.getFedMapping().getID();
            } else {
                FederatedRequest[] mxSliced = fedMap.broadcastSliced(mx, false);
                varNewIn[3] = mxSliced[0].getID();
                slicedRequests.add(mxSliced);
            }
        }

        // Optional fourth scalar input (epsilon).
        if (eps != null) {
            FederatedRequest epsBroadcast = fedMap.broadcast(eps);
            varNewIn[3] = epsBroadcast.getID();
            broadcastRequests.add(epsBroadcast);
            // Once broadcast, epsilon is a worker-side variable rather than a literal, so the
            // literal flag in the instruction string is flipped.
            // NOTE(review): replace() rewrites EVERY "true" in the instruction string, not only
            // epsilon's flag (a variable name containing "true" would also change). It also
            // permanently mutates this instruction's string. Consider a targeted rewrite.
            instString = instString.replace("true", "false");
        }

        CPOperand[] operands = _qop.hasFourInputs()
            ? new CPOperand[]{input1, input2, input3, _input4}
            : new CPOperand[]{input1, input2, input3};
        FederatedRequest compute = FederationUtils.callInstruction(instString, output, operands, varNewIn);

        FederatedRequest[] broadcasts = broadcastRequests.toArray(new FederatedRequest[0]);
        FederatedRequest[] allRequests;
        if (aggregateOutput) {
            // Fetch the partial results, then clean up the temporary worker variable.
            FederatedRequest getRequest = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, compute.getID());
            FederatedRequest cleanup = fedMap.cleanup(getTID(), compute.getID());
            allRequests = ArrayUtils.addAll(broadcasts, compute, getRequest, cleanup);
        } else {
            // The result stays federated on the workers.
            allRequests = ArrayUtils.addAll(broadcasts, compute);
        }

        // The sliced requests must be materialised as FederatedRequest[][]. Using
        // toArray(new FederatedRequest[0]) here throws ArrayStoreException.
        Future<FederatedResponse>[] responses = slicedRequests.isEmpty()
            ? fedMap.execute(getTID(), true, allRequests)
            : fedMap.executeMultipleSlices(getTID(), true, slicedRequests.toArray(new FederatedRequest[0][]), allRequests);

        if (aggregateOutput) {
            executionContext.setMatrixOutput(output.getName(),
                FederationUtils.aggMatrix(InstructionUtils.parseBasicAggregateUnaryOperator("uak+"), responses, fedMap));
        } else {
            setFederatedOutput(x, u.getMO(), v.getMO(), executionContext, compute.getID());
        }
    }

    /**
     * Reads the scalar epsilon from {@code _input4}.
     * <p>
     * A scalar operand is returned directly. Otherwise cell (0, 0) of the matrix operand is read.
     * The matrix is then released, because {@code getMatrixInput} pins it (acquires a read lock
     * in SystemDS's ExecutionContext) and it would otherwise stay pinned.
     */
    private ScalarObject readEpsilon(ExecutionContext executionContext) {
        if (_input4.getDataType() == Types.DataType.SCALAR) {
            return executionContext.getScalarInput(_input4);
        }
        double value = executionContext.getMatrixInput(_input4).quickGetValue(0, 0);
        executionContext.releaseMatrixInput(_input4.getName());
        return new DoubleObject(value);
    }

    /**
     * Marks the output as federated, backed by the workers' result variable {@code j}, and sets
     * its dimensions.
     * <ul>
     *   <li>BASIC: same dimensions as X.</li>
     *   <li>LEFT: ncol(X) x ncol(U), with the copied map transposed.</li>
     *   <li>RIGHT: nrow(X) x ncol(V).</li>
     * </ul>
     * For LEFT and RIGHT, the federated ranges are modified along dimension 1 to the new column
     * count.
     *
     * @param matrixObject  X, whose federation map is copied under the new ID
     * @param matrixObject2 U (supplies the column count for LEFT)
     * @param matrixObject3 V (supplies the column count for RIGHT)
     * @param executionContext context holding the output object
     * @param j ID of the result variable on the workers
     */
    private void setFederatedOutput(MatrixObject matrixObject, MatrixObject matrixObject2, MatrixObject matrixObject3, ExecutionContext executionContext, long j) {
        final WeightedDivMM.WDivMMType wDivMMType = _qop.wtype3;
        MatrixObject out = executionContext.getMatrixObject(output);
        FederationMap outFedMap = matrixObject.getFedMapping().copyWithNewID(j);
        long rows = -1;
        long cols = -1;
        if (wDivMMType.isBasic()) {
            rows = matrixObject.getNumRows();
            cols = matrixObject.getNumColumns();
        } else if (wDivMMType.isLeft()) {
            rows = matrixObject.getNumColumns();
            cols = matrixObject2.getNumColumns();
            // NOTE(review): return values are ignored, so this relies on transpose() and
            // modifyFedRanges() updating the copied map in place — confirm in FederationMap.
            outFedMap.transpose().modifyFedRanges(cols, 1);
        } else if (wDivMMType.isRight()) {
            rows = matrixObject.getNumRows();
            cols = matrixObject3.getNumColumns();
            outFedMap.modifyFedRanges(cols, 1); // see NOTE(review) above on in-place mutation
        }
        out.setFedMapping(outFedMap);
        out.getDataCharacteristics().set(rows, cols, (int) matrixObject.getBlocksize());
    }
}
