package org.apache.sysds.runtime.instructions.fed;

import java.util.ArrayList;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.sysds.common.Types;
import org.apache.sysds.hops.fedplanner.FTypes;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.controlprogram.federated.MatrixLineagePair;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.cp.QuaternaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.instructions.spark.QuaternarySPInstruction;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.QuaternaryOperator;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/fed/QuaternaryWCeMMFEDInstruction.class */
/**
 * Federated variant of the quaternary WCeMM instruction.
 *
 * <p>The first input (X) must be a federated matrix that is either ROW or COL
 * partitioned. The second and third inputs (U, V) are reused in place when they
 * are federated and aligned with X, and are otherwise broadcast (sliced or in
 * full) to the federated sites. An optional fourth input is broadcast as a
 * scalar. The per-site partial results are fetched and summed up ("uak+") into
 * the scalar output variable.
 */
public class QuaternaryWCeMMFEDInstruction extends QuaternaryFEDInstruction {
    /**
     * Creates the federated instruction from its operator and operands.
     *
     * @param operator   operator of the instruction (cast to {@link QuaternaryOperator} on execution)
     * @param in1        first input operand (X)
     * @param in2        second input operand (U)
     * @param in3        third input operand (V)
     * @param in4        optional fourth input operand
     * @param out        output operand
     * @param opcode     instruction opcode
     * @param instString full instruction string
     */
    public QuaternaryWCeMMFEDInstruction(Operator operator, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand out, String opcode, String instString) {
        super(FEDInstruction.FEDType.Quaternary, operator, in1, in2, in3, in4, out, opcode, instString);
    }

    /**
     * Builds the federated instruction from an already parsed CP instruction.
     *
     * @param quaternaryCPInstruction the CP instruction whose operator, operands, opcode and
     *                                instruction string are taken over
     * @return the corresponding federated instruction
     */
    public static QuaternaryWCeMMFEDInstruction parseInstruction(QuaternaryCPInstruction quaternaryCPInstruction) {
        return new QuaternaryWCeMMFEDInstruction(
            quaternaryCPInstruction.getOperator(),
            quaternaryCPInstruction.input1,
            quaternaryCPInstruction.input2,
            quaternaryCPInstruction.input3,
            quaternaryCPInstruction.getInput4(),
            quaternaryCPInstruction.output,
            quaternaryCPInstruction.getOpcode(),
            quaternaryCPInstruction.getInstructionString());
    }

    /**
     * Builds the federated instruction from an already parsed Spark instruction.
     * The Spark instruction string is first rewritten to its CP form; the opcode
     * is taken from the first part of the rewritten string.
     *
     * @param quaternarySPInstruction the Spark instruction to convert
     * @return the corresponding federated instruction
     */
    public static QuaternaryWCeMMFEDInstruction parseInstruction(QuaternarySPInstruction quaternarySPInstruction) {
        String cpInstString = rewriteSparkInstructionToCP(quaternarySPInstruction.getInstructionString());
        String opcode = InstructionUtils.getInstructionPartsWithValueType(cpInstString)[0];
        return new QuaternaryWCeMMFEDInstruction(
            quaternarySPInstruction.getOperator(),
            quaternarySPInstruction.input1,
            quaternarySPInstruction.input2,
            quaternarySPInstruction.input3,
            quaternarySPInstruction.getInput4(),
            quaternarySPInstruction.output,
            opcode,
            cpInstString);
    }

    /**
     * Executes the instruction on the federated sites and stores the aggregated
     * scalar result under the output variable name.
     *
     * @param executionContext context used to resolve the inputs and to store the output
     * @throws DMLRuntimeException if X is not federated, or is federated but neither
     *                             ROW nor COL partitioned
     */
    @Override
    public void processInstruction(ExecutionContext executionContext) {
        QuaternaryOperator qop = (QuaternaryOperator) this._optr;
        MatrixObject x = executionContext.getMatrixObject(this.input1);
        MatrixLineagePair u = executionContext.getMatrixLineagePair(this.input2);
        MatrixLineagePair v = executionContext.getMatrixLineagePair(this.input3);
        ScalarObject fourthInput = qop.hasFourInputs() ? readFourthInput(executionContext) : null;

        if (!x.isFederated()) {
            throw new DMLRuntimeException("Unsupported federated inputs (X, U, V) = (" + x.isFederated() + ", " + u.isFederated() + ", " + v.isFederated() + ")");
        }

        FederationMap fedMap = x.getFedMapping();
        FederatedRequest[] slicedBroadcast = null;
        ArrayList<FederatedRequest> broadcasts = new ArrayList<>();
        // IDs of the federated variables that replace the instruction's input operands
        long[] inputIds = new long[fourthInput != null ? 4 : 3];
        inputIds[0] = fedMap.getID();

        if (x.isFederated(FTypes.FType.ROW)) {
            if (u.isFederated(FTypes.FType.ROW) && fedMap.isAligned(u.getFedMapping(), FTypes.AlignType.ROW)) {
                // U is already federated and row-aligned with X: reuse it in place
                inputIds[1] = u.getFedMapping().getID();
            } else {
                slicedBroadcast = fedMap.broadcastSliced(u, false);
                inputIds[1] = slicedBroadcast[0].getID();
            }
            FederatedRequest vBroadcast = fedMap.broadcast(v);
            inputIds[2] = vBroadcast.getID();
            broadcasts.add(vBroadcast);
        } else if (x.isFederated(FTypes.FType.COL)) {
            FederatedRequest uBroadcast = fedMap.broadcast(u);
            inputIds[1] = uBroadcast.getID();
            broadcasts.add(uBroadcast);
            if (v.isFederated() && fedMap.isAligned(v.getFedMapping(), FTypes.AlignType.COL, FTypes.AlignType.COL_T)) {
                // V is already federated and column-aligned with X: reuse it in place
                inputIds[2] = v.getFedMapping().getID();
            } else {
                slicedBroadcast = fedMap.broadcastSliced(v, true);
                inputIds[2] = slicedBroadcast[0].getID();
            }
        } else {
            throw new DMLRuntimeException("Federated WCeMM only supported for ROW or COLUMN partitioned federated data.");
        }

        if (fourthInput != null) {
            FederatedRequest scalarBroadcast = fedMap.broadcast(fourthInput);
            inputIds[3] = scalarBroadcast.getID();
            broadcasts.add(scalarBroadcast);
            // NOTE(review): presumably flips the literal flag of the fourth operand, which is a
            // broadcast variable from here on - confirm. This replaces EVERY occurrence of
            // "true" in the instruction string and permanently mutates instString.
            this.instString = this.instString.replace("true", "false");
        }

        CPOperand[] inputOperands = fourthInput == null
            ? new CPOperand[]{this.input1, this.input2, this.input3}
            : new CPOperand[]{this.input1, this.input2, this.input3, this._input4};
        FederatedRequest computeRequest = FederationUtils.callInstruction(this.instString, this.output, inputOperands, inputIds);
        FederatedRequest fetchRequest = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, computeRequest.getID());
        FederatedRequest cleanupRequest = fedMap.cleanup(getTID(), computeRequest.getID());

        // request order: broadcasts, compute, fetch result, cleanup of the computed variable
        FederatedRequest[] requests = ArrayUtils.addAll(
            broadcasts.toArray(new FederatedRequest[0]), computeRequest, fetchRequest, cleanupRequest);

        // sum up the partial results of all federated sites into the scalar output
        executionContext.setVariable(this.output.getName(), FederationUtils.aggScalar(
            InstructionUtils.parseBasicAggregateUnaryOperator("uak+"),
            slicedBroadcast == null
                ? fedMap.execute(getTID(), true, requests)
                : fedMap.execute(getTID(), true, slicedBroadcast, requests)));
    }

    /**
     * Resolves the fourth input as a scalar. A scalar operand is read directly;
     * a matrix operand contributes its cell (0, 0) and is released again right
     * after the read so that it does not stay pinned by this instruction.
     *
     * @param executionContext context used to resolve the operand
     * @return the fourth input as a scalar object
     */
    private ScalarObject readFourthInput(ExecutionContext executionContext) {
        if (this._input4.getDataType() == Types.DataType.SCALAR) {
            return executionContext.getScalarInput(this._input4);
        }
        String name = this._input4.getName();
        double value = executionContext.getMatrixInput(name).quickGetValue(0, 0);
        // every getMatrixInput must be paired with a release of the same variable
        executionContext.releaseMatrixInput(name);
        return new DoubleObject(value);
    }
}
