package org.apache.sysds.runtime.instructions.fed;

import com.sun.tools.javac.util.List;
import java.util.Objects;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.matrix.operators.TernaryOperator;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.class */
/**
 * Federated instruction for ternary operations whose operator is obtained from
 * {@link InstructionUtils#parseTernaryOperator}. Depending on how many of the three inputs are
 * matrices (the remaining ones being scalars), the instruction is executed through the federation
 * map of one federated "anchor" input; matrix inputs that are not aligned with the anchor are first
 * broadcast via {@code FederationMap.broadcastSliced} and cleaned up afterwards.
 */
public class TernaryFEDInstruction extends ComputationFEDInstruction {

    /**
     * Result of the alignment analysis for three matrix inputs: which inputs share aligned
     * federation maps, the worker-side variable IDs to substitute for input1..input3, the broadcast
     * requests (if any) to send first, and the federated input whose map must execute the call.
     * (The decompiler reports this class was originally private; it stays public for compatibility.)
     */
    public static final class RetAlignedValues {
        /** True if at least one pair of inputs is federated and aligned. */
        public boolean _twoAligned;
        /** True if the first input is aligned with both the second and the third input. */
        public boolean _allAligned;
        /** Variable IDs substituted for input1, input2, input3, in that order. */
        public long[] _vars;
        /** Broadcast requests for the one non-aligned input; empty if nothing is broadcast. */
        public FederatedRequest[] _fr;
        /**
         * Federated input whose federation map {@link #_fr} was sliced against and which must be
         * used to execute the instruction; {@code null} if no pair is aligned.
         */
        public MatrixObject _fedAnchor;

        public RetAlignedValues(boolean twoAligned, boolean allAligned, long[] vars, FederatedRequest[] fr) {
            this(twoAligned, allAligned, vars, fr, null);
        }

        public RetAlignedValues(boolean twoAligned, boolean allAligned, long[] vars, FederatedRequest[] fr,
            MatrixObject fedAnchor) {
            this._twoAligned = twoAligned;
            this._allAligned = allAligned;
            this._vars = vars;
            this._fr = fr;
            this._fedAnchor = fedAnchor;
        }
    }

    private TernaryFEDInstruction(TernaryOperator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out,
        String opcode, String istr, FEDInstruction.FederatedOutput fedOut) {
        super(FEDInstruction.FEDType.Ternary, op, in1, in2, in3, out, opcode, istr, fedOut);
    }

    /**
     * Parses a federated ternary instruction string.
     *
     * <p>Expected parts: opcode, in1, in2, in3, out, optionally the number of threads (index 5,
     * default 1) and optionally the federated output type (index 6, default NONE).
     *
     * @param str instruction string
     * @return the parsed instruction
     */
    public static TernaryFEDInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        int numThreads = parts.length > 5 ? Integer.parseInt(parts[5]) : 1;
        TernaryOperator op = InstructionUtils.parseTernaryOperator(opcode, numThreads);
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand in3 = new CPOperand(parts[3]);
        CPOperand out = new CPOperand(parts[4]);
        // parts[6] exists as soon as there are 7 parts; the former "> 7" check silently dropped it.
        FEDInstruction.FederatedOutput fedOut = parts.length > 6
            ? FEDInstruction.FederatedOutput.valueOf(parts[6])
            : FEDInstruction.FederatedOutput.NONE;
        return new TernaryFEDInstruction(op, in1, in2, in3, out, opcode, str, fedOut);
    }

    /**
     * Dispatches on the number of matrix inputs (3, 1, or 2) and executes the instruction on the
     * federated workers. If no input is a matrix, nothing is executed.
     */
    @Override // org.apache.sysds.runtime.instructions.Instruction
    public void processInstruction(ExecutionContext ec) {
        MatrixObject mo1 = input1.isMatrix() ? ec.getMatrixObject(input1.getName()) : null;
        MatrixObject mo2 = input2.isMatrix() ? ec.getMatrixObject(input2.getName()) : null;
        MatrixObject mo3 = (input3 != null && input3.isMatrix()) ? ec.getMatrixObject(input3.getName()) : null;

        int matrixInputsCount = 0;
        for (MatrixObject mo : new MatrixObject[] {mo1, mo2, mo3}) {
            if (Objects.nonNull(mo)) {
                matrixInputsCount++;
            }
        }

        if (matrixInputsCount == 3) {
            processMatrixInput(ec, mo1, mo2, mo3);
        } else if (matrixInputsCount == 1) {
            MatrixObject mo = mo1 != null ? mo1 : (mo2 != null ? mo2 : mo3);
            CPOperand in = mo1 != null ? input1 : (mo2 != null ? input2 : input3);
            // NOTE(review): unlike the two-matrix branches below, scalar variable operands are not
            // inlined as literals here; confirm the workers can resolve them.
            processMatrixScalarInput(ec, mo, in);
        } else if (mo1 != null && mo2 != null) {
            inlineScalarOperand(ec, input3, 4);
            process2MatrixScalarInput(ec, mo1, mo2, input1, input2);
        } else if (mo2 != null && mo3 != null) {
            inlineScalarOperand(ec, input1, 2);
            process2MatrixScalarInput(ec, mo2, mo3, input2, input3);
        } else if (mo1 != null && mo3 != null) {
            inlineScalarOperand(ec, input2, 3);
            process2MatrixScalarInput(ec, mo1, mo3, input1, input3);
        }
    }

    /**
     * Replaces a non-literal scalar operand in {@link #instString} by an FP64 literal holding its
     * current value, so the instruction shipped to the workers does not reference a local variable.
     *
     * @param ec         execution context used to look up the scalar value
     * @param scalar     scalar operand (ignored if null or already a literal)
     * @param operandPos position passed to {@link InstructionUtils#replaceOperand}
     */
    private void inlineScalarOperand(ExecutionContext ec, CPOperand scalar, int operandPos) {
        if (scalar != null && !scalar.isLiteral()) {
            instString = InstructionUtils.replaceOperand(instString, operandPos,
                InstructionUtils.createLiteralOperand(ec.getScalarInput(scalar).getStringValue(), Types.ValueType.FP64));
        }
    }

    /** Executes the instruction when exactly one input is a (federated) matrix. */
    private void processMatrixScalarInput(ExecutionContext ec, MatrixObject mo, CPOperand in) {
        FederatedRequest fr = FederationUtils.callInstruction(instString, output, new CPOperand[] {in},
            new long[] {mo.getFedMapping().getID()});
        sendFederatedRequests(ec, mo, fr.getID(), fr);
    }

    /**
     * Executes the instruction when exactly two inputs are matrices. If both are federated and
     * aligned, they are used in place; otherwise the non-anchor matrix is broadcast along the
     * anchor's federation map and cleaned up after the call.
     */
    private void process2MatrixScalarInput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, CPOperand in1,
        CPOperand in2) {
        CPOperand[] varOldIn = {in1, in2};
        MatrixObject anchor;
        FederatedRequest[] broadcast = null;
        long[] varNewIn;
        if (!mo1.isFederated()) {
            // second matrix is the federated anchor; ship the first one to its workers
            anchor = mo2;
            broadcast = anchor.getFedMapping().broadcastSliced(mo1, false);
            varNewIn = new long[] {broadcast[0].getID(), anchor.getFedMapping().getID()};
        } else if (mo2.isFederated() && mo1.getFedMapping().isAligned(mo2.getFedMapping(), false)) {
            anchor = mo1;
            varNewIn = new long[] {mo1.getFedMapping().getID(), mo2.getFedMapping().getID()};
        } else {
            anchor = mo1;
            broadcast = mo1.getFedMapping().broadcastSliced(mo2, false);
            varNewIn = new long[] {mo1.getFedMapping().getID(), broadcast[0].getID()};
        }
        FederatedRequest fr = FederationUtils.callInstruction(instString, output, varOldIn, varNewIn);
        if (broadcast == null) {
            sendFederatedRequests(ec, anchor, fr.getID(), fr);
        } else {
            // always release the broadcast on the workers (previously skipped when mo1 was local)
            sendFederatedRequests(ec, anchor, fr.getID(), broadcast, fr,
                anchor.getFedMapping().cleanup(getTID(), broadcast[0].getID()));
        }
    }

    /** Sends {@code requests} through {@code mo}'s federation map; see the 6-argument overload. */
    private void sendFederatedRequests(ExecutionContext ec, MatrixObject mo, long outputId, FederatedRequest... requests) {
        sendFederatedRequests(ec, mo, outputId, null, null, requests);
    }

    /** Sends one broadcast array followed by {@code requests}; see the 6-argument overload. */
    private void sendFederatedRequests(ExecutionContext ec, MatrixObject mo, long outputId, FederatedRequest[] broadcast1,
        FederatedRequest... requests) {
        sendFederatedRequests(ec, mo, outputId, broadcast1, null, requests);
    }

    /**
     * Executes the given requests through {@code mo}'s federation map. If the federated output is
     * forced local, the result is fetched and bound into a local matrix; otherwise the output
     * variable receives a copy of {@code mo}'s federation map with ID {@code outputId}.
     *
     * @param outputId   worker-side ID of the instruction's output
     * @param broadcast1 first set of broadcast requests, or null
     * @param broadcast2 second set of broadcast requests, or null
     * @param requests   instruction and cleanup requests, executed after the broadcasts
     */
    private void sendFederatedRequests(ExecutionContext ec, MatrixObject mo, long outputId, FederatedRequest[] broadcast1,
        FederatedRequest[] broadcast2, FederatedRequest... requests) {
        if (_fedOut.isForcedLocal()) {
            processAndRetrieve(ec, mo, outputId, broadcast1, broadcast2, requests);
        } else {
            mo.getFedMapping().execute(getTID(), true, broadcast1, broadcast2, requests);
            setOutputFedMapping(ec, mo, outputId);
        }
    }

    /**
     * Executes the requests plus a trailing GET_VAR of {@code outputId} and stores the bound
     * partial results as the local output matrix. Whether {@code mo} is column-partitioned
     * ({@code FederationMap.FType.COL}) is passed on to {@link FederationUtils#bind}.
     */
    private void processAndRetrieve(ExecutionContext ec, MatrixObject mo, long outputId, FederatedRequest[] broadcast1,
        FederatedRequest[] broadcast2, FederatedRequest... requests) {
        FederatedRequest getVar = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, outputId);
        ec.setMatrixOutput(output.getName(), FederationUtils.bind(
            mo.getFedMapping().execute(getTID(), true, broadcast1, broadcast2, collectRequests(requests, getVar)),
            mo.isFederated(FederationMap.FType.COL)));
    }

    /** Returns a new array containing {@code requests} followed by {@code last}. */
    private static FederatedRequest[] collectRequests(FederatedRequest[] requests, FederatedRequest last) {
        FederatedRequest[] all = new FederatedRequest[requests.length + 1];
        System.arraycopy(requests, 0, all, 0, requests.length);
        all[requests.length] = last;
        return all;
    }

    /** Executes the instruction when all three inputs are matrices. */
    private void processMatrixInput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
        CPOperand[] varOldIn = {input1, input2, input3};
        RetAlignedValues aligned = getAlignedInputs(mo1, mo2, mo3);
        if (aligned._allAligned) {
            FederatedRequest fr = FederationUtils.callInstruction(instString, output, varOldIn,
                new long[] {mo1.getFedMapping().getID(), mo2.getFedMapping().getID(), mo3.getFedMapping().getID()});
            sendFederatedRequests(ec, mo1, fr.getID(), fr);
            return;
        }
        if (aligned._twoAligned) {
            // execute through the map the broadcast was sliced against (was always mo1 before,
            // which fails when the aligned pair is input2/input3)
            MatrixObject anchor = aligned._fedAnchor != null ? aligned._fedAnchor : mo1;
            FederatedRequest fr = FederationUtils.callInstruction(instString, output, varOldIn, aligned._vars);
            sendFederatedRequests(ec, anchor, fr.getID(), aligned._fr, fr,
                anchor.getFedMapping().cleanup(getTID(), aligned._fr[0].getID()));
            return;
        }
        processUnalignedMatrixInput(ec, mo1, mo2, mo3, varOldIn);
    }

    /**
     * Executes the instruction for three matrix inputs without any aligned pair: the first
     * federated input (left to right) is the anchor, the other two are broadcast along its map.
     * Variable IDs are assigned by input position so each operand keeps its own data.
     */
    private void processUnalignedMatrixInput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3,
        CPOperand[] varOldIn) {
        MatrixObject[] inputs = {mo1, mo2, mo3};
        int fedPos = mo1.isFederated() ? 0 : (mo2.isFederated() ? 1 : 2);
        MatrixObject anchor = inputs[fedPos];
        FederationMap fedMap = anchor.getFedMapping();

        long[] varNewIn = new long[inputs.length];
        FederatedRequest[] broadcast1 = null;
        FederatedRequest[] broadcast2 = null;
        for (int i = 0; i < inputs.length; i++) {
            if (i == fedPos) {
                varNewIn[i] = fedMap.getID();
                continue;
            }
            FederatedRequest[] broadcast = fedMap.broadcastSliced(inputs[i], false);
            varNewIn[i] = broadcast[0].getID();
            if (broadcast1 == null) {
                broadcast1 = broadcast;
            } else {
                broadcast2 = broadcast;
            }
        }

        FederatedRequest fr = FederationUtils.callInstruction(instString, output, varOldIn, varNewIn);
        sendFederatedRequests(ec, anchor, fr.getID(), broadcast1, broadcast2, fr,
            fedMap.cleanup(getTID(), broadcast1[0].getID(), broadcast2[0].getID()));
    }

    /**
     * Determines which federated inputs are aligned and, when exactly one pair is aligned,
     * broadcasts only the remaining input along that pair's map. When several pairs qualify without
     * full alignment, the (input2, input3) pair takes precedence, then (input1, input3), then
     * (input1, input2).
     */
    private RetAlignedValues getAlignedInputs(MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
        boolean aligned12 = isAlignedPair(mo1, mo2);
        boolean aligned13 = isAlignedPair(mo1, mo3);
        if (aligned12 && aligned13) {
            long[] vars = {mo1.getFedMapping().getID(), mo2.getFedMapping().getID(), mo3.getFedMapping().getID()};
            return new RetAlignedValues(true, true, vars, new FederatedRequest[0], mo1);
        }
        if (isAlignedPair(mo2, mo3)) {
            FederatedRequest[] fr = mo2.getFedMapping().broadcastSliced(mo1, false);
            long[] vars = {fr[0].getID(), mo2.getFedMapping().getID(), mo3.getFedMapping().getID()};
            return new RetAlignedValues(true, false, vars, fr, mo2);
        }
        if (aligned13) {
            FederatedRequest[] fr = mo1.getFedMapping().broadcastSliced(mo2, false);
            long[] vars = {mo1.getFedMapping().getID(), fr[0].getID(), mo3.getFedMapping().getID()};
            return new RetAlignedValues(true, false, vars, fr, mo1);
        }
        if (aligned12) {
            FederatedRequest[] fr = mo1.getFedMapping().broadcastSliced(mo3, false);
            long[] vars = {mo1.getFedMapping().getID(), mo2.getFedMapping().getID(), fr[0].getID()};
            return new RetAlignedValues(true, false, vars, fr, mo1);
        }
        return new RetAlignedValues(false, false, new long[0], new FederatedRequest[0], null);
    }

    /** True if both inputs are federated and their federation maps are aligned (non-transposed). */
    private static boolean isAlignedPair(MatrixObject a, MatrixObject b) {
        return a.isFederated() && b.isFederated() && a.getFedMapping().isAligned(b.getFedMapping(), false);
    }

    /** Sets the output variable's federation map to a copy of {@code mo}'s map with ID {@code outputId}. */
    private void setOutputFedMapping(ExecutionContext ec, MatrixObject mo, long outputId) {
        ec.getMatrixObject(output).setFedMapping(mo.getFedMapping().copyWithNewID(outputId));
    }
}
