package org.apache.sysds.runtime.instructions.fed;

import java.util.ArrayList;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.matrix.operators.Operator;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/fed/QuaternaryWUMMFEDInstruction.class */
/**
 * Federated instruction for the quaternary WUMM operation.
 *
 * <p>The first input (called X in the error message) must be federated, either
 * ROW or COL partitioned. The other two inputs (U and V) are either reused
 * through their own federated mapping when it is aligned with X, or sent to
 * the sites holding X via a broadcast.
 */
public class QuaternaryWUMMFEDInstruction extends QuaternaryFEDInstruction {
    /**
     * Creates the instruction by forwarding all arguments to the superclass
     * with the federated type fixed to {@code FEDType.Quaternary}.
     *
     * <p>NOTE(review): the four operands presumably map to input1, input2,
     * input3 and output in that order; confirm against QuaternaryFEDInstruction.
     *
     * @param operator   operator passed through to the superclass
     * @param cPOperand  first operand
     * @param cPOperand2 second operand
     * @param cPOperand3 third operand
     * @param cPOperand4 fourth operand
     * @param str        first string argument of the superclass constructor
     * @param str2       second string argument of the superclass constructor
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public QuaternaryWUMMFEDInstruction(Operator operator, CPOperand cPOperand, CPOperand cPOperand2, CPOperand cPOperand3, CPOperand cPOperand4, String str, String str2) {
        super(FEDInstruction.FEDType.Quaternary, operator, cPOperand, cPOperand2, cPOperand3, cPOperand4, str, str2);
    }

    /**
     * Executes the instruction on the federated mapping of the first input.
     *
     * <p>Steps: resolve the ids of the three inputs as seen through X's
     * federated mapping (reusing aligned mappings, broadcasting otherwise),
     * build the instruction call request, execute all requests on X's mapping,
     * and attach a copy of that mapping carrying the call request's id to the
     * output matrix object.
     *
     * @param executionContext context used to look up the input and output matrix objects
     * @throws DMLRuntimeException if the first input is not federated, or is
     *                             federated but neither ROW nor COL partitioned
     */
    @Override // org.apache.sysds.runtime.instructions.Instruction
    public void processInstruction(ExecutionContext executionContext) {
        MatrixObject x = executionContext.getMatrixObject(this.input1);
        MatrixObject u = executionContext.getMatrixObject(this.input2);
        MatrixObject v = executionContext.getMatrixObject(this.input3);
        if (!x.isFederated()) {
            throw new DMLRuntimeException("Unsupported federated inputs (X, U, V) = (" + x.isFederated() + ", " + u.isFederated() + ", " + v.isFederated() + ")");
        }

        FederationMap fedMapping = x.getFedMapping();
        // Stays null unless one of U / V has to be broadcast in slices.
        FederatedRequest[] slicedRequests = null;
        // Exactly one full broadcast is issued per supported partitioning:
        // V for ROW partitioned X, U for COL partitioned X.
        FederatedRequest broadcast;
        // Ids of X, U and V (in this order) handed to the instruction call.
        long[] inputIds = new long[3];
        inputIds[0] = fedMapping.getID();

        if (x.isFederated(FederationMap.FType.ROW)) {
            if (u.isFederated(FederationMap.FType.ROW) && fedMapping.isAligned(u.getFedMapping(), FederationMap.AlignType.ROW)) {
                // U is already aligned with X: reuse its mapping id.
                inputIds[1] = u.getFedMapping().getID();
            } else {
                slicedRequests = fedMapping.broadcastSliced(u, false);
                inputIds[1] = slicedRequests[0].getID();
            }
            broadcast = fedMapping.broadcast(v);
            inputIds[2] = broadcast.getID();
        } else if (x.isFederated(FederationMap.FType.COL)) {
            // The full broadcast of U is created before V is handled, as in the original order.
            broadcast = fedMapping.broadcast(u);
            inputIds[1] = broadcast.getID();
            if (v.isFederated() && fedMapping.isAligned(v.getFedMapping(), FederationMap.AlignType.COL, FederationMap.AlignType.COL_T)) {
                // V is already aligned with X: reuse its mapping id.
                inputIds[2] = v.getFedMapping().getID();
            } else {
                slicedRequests = fedMapping.broadcastSliced(v, true);
                inputIds[2] = slicedRequests[0].getID();
            }
        } else {
            throw new DMLRuntimeException("Federated WUMM only supported for ROW or COLUMN partitioned federated data.");
        }

        FederatedRequest callInstruction = FederationUtils.callInstruction(this.instString, this.output, new CPOperand[]{this.input1, this.input2, this.input3}, inputIds);
        // Built directly: the previous concatenation with an always-empty,
        // raw-typed list produced this same two-element array.
        FederatedRequest[] requests = {broadcast, callInstruction};
        if (slicedRequests == null) {
            fedMapping.execute(getTID(), true, requests);
        } else {
            fedMapping.execute(getTID(), true, slicedRequests, requests);
        }

        // The output reuses X's mapping under the id of the instruction call request.
        executionContext.getMatrixObject(this.output).setFedMapping(fedMapping.copyWithNewID(callInstruction.getID()));
        setOutputDataCharacteristics(x, u, v, executionContext);
    }
}
