package org.apache.sysds.runtime.instructions.fed;

import java.util.concurrent.Future;
import org.apache.sysds.lops.TernaryAggregate;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.matrix.operators.Operator;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/fed/AggregateTernaryFEDInstruction.class */
/**
 * Federated execution of the ternary aggregate opcodes accepted by {@link #parseInstruction(String)}.
 * These are {@code TernaryAggregate.OPCODE_RC}, whose partial results are combined with {@code uak+},
 * and {@code TernaryAggregate.OPCODE_C}, whose non-scalar partial results are combined with {@code uack+}.
 *
 * <p>{@link #processInstruction(ExecutionContext)} tries three strategies in order:
 * <ol>
 *   <li>All three inputs are federated matrices aligned with each other (by row if input1 is row-federated,
 *       otherwise by column). The instruction runs on every worker and the partial results are aggregated
 *       locally.</li>
 *   <li>The first two inputs are aligned federated matrices and the third input is not a matrix. The scalar
 *       is broadcast and the per-worker scalar results are summed.</li>
 *   <li>The first input is federated (other than broadcast) and the third input is a matrix. The second and
 *       third inputs are sliced along the first input's federation map, and each slice is sent to its
 *       worker.</li>
 * </ol>
 * Any other combination raises a {@link DMLRuntimeException}.
 */
public class AggregateTernaryFEDInstruction extends ComputationFEDInstruction {
    /** Message used when a worker's scalar result cannot be retrieved. */
    private static final String GET_DATA_FAILED_MSG =
        "Federated Get data failed with exception on TernaryFedInstruction";

    private AggregateTernaryFEDInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand in3,
        CPOperand out, String opcode, String istr) {
        super(FEDInstruction.FEDType.AggregateTernary, op, in1, in2, in3, out, opcode, istr);
    }

    /**
     * Parses a ternary aggregate instruction string.
     *
     * @param str instruction string. Its parts are the opcode, three inputs, the output, and an integer
     *            handed to the aggregate ternary operator (presumably the thread count; confirm).
     * @return the parsed federated instruction
     * @throws DMLRuntimeException if the opcode is neither {@code TernaryAggregate.OPCODE_RC} nor
     *                             {@code TernaryAggregate.OPCODE_C}
     */
    public static AggregateTernaryFEDInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        if (!opcode.equalsIgnoreCase(TernaryAggregate.OPCODE_RC) && !opcode.equalsIgnoreCase(TernaryAggregate.OPCODE_C)) {
            throw new DMLRuntimeException("AggregateTernaryInstruction.parseInstruction():: Unknown opcode " + opcode);
        }
        InstructionUtils.checkNumFields(parts, 5);
        Operator op = InstructionUtils.parseAggregateTernaryOperator(opcode, Integer.parseInt(parts[5]));
        return new AggregateTernaryFEDInstruction(op, new CPOperand(parts[1]), new CPOperand(parts[2]),
            new CPOperand(parts[3]), new CPOperand(parts[4]), opcode, str);
    }

    /**
     * Executes the instruction using the first applicable federated strategy (see the class docs).
     *
     * @param ec execution context holding the inputs; receives the output
     * @throws DMLRuntimeException if no strategy supports the given inputs, the output is non-scalar on a
     *                             path that only supports scalar output, or a worker response fails
     */
    @Override // org.apache.sysds.runtime.instructions.Instruction
    public void processInstruction(ExecutionContext ec) {
        MatrixObject mo1 = ec.getMatrixObject(input1);
        MatrixObject mo2 = ec.getMatrixObject(input2);
        // Only a matrix third operand is fetched as a MatrixObject. A scalar (literal or variable) third
        // operand stays null here and is handled by the broadcast-scalar path.
        MatrixObject mo3 = input3.isMatrix() ? ec.getMatrixObject(input3) : null;

        if (mo3 != null && isAllFederatedAndAligned(mo1, mo2, mo3)) {
            processAllAligned(ec, mo1, mo2, mo3);
        } else if (mo3 == null && mo1.isFederated() && mo2.isFederated()
            && mo1.getFedMapping().isAligned(mo2.getFedMapping(), false)) {
            processScalarThirdInput(ec, mo1, mo2);
        } else if (mo3 != null && mo1.isFederatedExcept(FederationMap.FType.BROADCAST)) {
            processSlicedBroadcast(ec, mo1, mo2, mo3);
        } else {
            throw unsupported(mo1, mo2, mo3);
        }
    }

    /** Returns true if all three inputs are federated and mutually aligned along input1's partitioning. */
    private static boolean isAllFederatedAndAligned(MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
        if (!mo1.isFederated() || !mo2.isFederated() || !mo3.isFederated()) {
            return false;
        }
        FederationMap.AlignType alignType = mo1.isFederated(FederationMap.FType.ROW)
            ? FederationMap.AlignType.ROW
            : FederationMap.AlignType.COL;
        return mo1.getFedMapping().isAligned(mo2.getFedMapping(), alignType)
            && mo2.getFedMapping().isAligned(mo3.getFedMapping(), alignType);
    }

    /** Strategy 1: run the instruction on aligned federated inputs and aggregate the partial results. */
    private void processAllAligned(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
        FederationMap fedMap = mo1.getFedMapping();
        FederatedRequest frCall = FederationUtils.callInstruction(getInstructionString(), output,
            new CPOperand[]{input1, input2, input3},
            new long[]{fedMap.getID(), mo2.getFedMapping().getID(), mo3.getFedMapping().getID()});
        FederatedRequest frGet = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, frCall.getID());
        FederatedRequest frClean = fedMap.cleanup(getTID(), frCall.getID());
        Future<FederatedResponse>[] responses = fedMap.execute(getTID(), frCall, frGet, frClean);

        if (output.getDataType().isScalar()) {
            ec.setScalarOutput(output.getName(), FederationUtils.aggScalar(
                InstructionUtils.parseBasicAggregateUnaryOperator("uak+"), responses, fedMap));
        } else {
            // getOpcode() is the parsed opcode (OPCODE_RC or OPCODE_C, any case), never a "fed_"-prefixed name.
            String aggOpcode = TernaryAggregate.OPCODE_RC.equalsIgnoreCase(getOpcode()) ? "uak+" : "uack+";
            ec.setMatrixOutput(output.getName(), FederationUtils.aggMatrix(
                InstructionUtils.parseBasicAggregateUnaryOperator(aggOpcode), responses, fedMap));
        }
    }

    /** Strategy 2: broadcast the scalar third operand and sum the per-worker scalar results. */
    private void processScalarThirdInput(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2) {
        ensureScalarOutput();
        FederationMap fedMap = mo1.getFedMapping();
        FederatedRequest frScalar = fedMap.broadcast(ec.getScalarInput(input3));
        FederatedRequest frCall = FederationUtils.callInstruction(instString, output,
            new CPOperand[]{input1, input2, input3},
            new long[]{fedMap.getID(), mo2.getFedMapping().getID(), frScalar.getID()});
        FederatedRequest frGet = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, frCall.getID());
        FederatedRequest frClean = mo2.getFedMapping().cleanup(getTID(), frScalar.getID(), frCall.getID());
        Future<FederatedResponse>[] responses = fedMap.execute(getTID(), frScalar, frCall, frGet, frClean);
        ec.setScalarOutput(output.getName(), new DoubleObject(sumScalarResponses(responses)));
    }

    /**
     * Strategy 3: slice input2 and input3 along input1's federation map, ship each slice to its worker,
     * and sum the per-worker scalar results.
     */
    private void processSlicedBroadcast(ExecutionContext ec, MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
        ensureScalarOutput();
        FederationMap fedMap = mo1.getFedMapping();
        FederatedRequest[] slices3 = fedMap.broadcastSliced(mo3, false);
        FederatedRequest[] slices2 = fedMap.broadcastSliced(mo2, false);
        // Every worker's instruction references slice 0's ID. This relies on all slices of one
        // broadcastSliced call sharing a variable ID, as the code always assumed.
        FederatedRequest frCall = FederationUtils.callInstruction(getInstructionString(), output,
            new CPOperand[]{input1, input2, input3},
            new long[]{fedMap.getID(), slices2[0].getID(), slices3[0].getID()});
        FederatedRequest frGet = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, frCall.getID());
        // Remove the slices and the result from the workers once the result has been fetched.
        FederatedRequest frClean = fedMap.cleanup(getTID(), slices3[0].getID(), slices2[0].getID(), frCall.getID());
        // Both slice arrays are passed per partition so each worker receives its own slice of input2
        // (previously slice 0 of input2 was sent to every worker).
        Future<FederatedResponse>[] responses =
            fedMap.execute(getTID(), false, slices3, slices2, frCall, frGet, frClean);
        ec.setScalarOutput(output.getName(), new DoubleObject(sumScalarResponses(responses)));
    }

    /** Fails before any remote work if the output is not a scalar, which strategies 2 and 3 do not support. */
    private void ensureScalarOutput() {
        if (!output.getDataType().isScalar()) {
            throw new DMLRuntimeException("Not Implemented Federated Ternary Variation");
        }
    }

    /**
     * Waits for every response and sums the first data element of each as a scalar double.
     *
     * @param responses pending worker responses, each expected to carry a {@link ScalarObject} first
     * @return the sum of all worker scalars
     * @throws DMLRuntimeException if any response fails or does not hold a scalar
     */
    private static double sumScalarResponses(Future<FederatedResponse>[] responses) {
        double sum = 0.0d;
        for (Future<FederatedResponse> response : responses) {
            try {
                sum += ((ScalarObject) response.get().getData()[0]).getDoubleValue();
            } catch (InterruptedException e) {
                // Restore the interrupt flag so callers up the stack can still observe it.
                Thread.currentThread().interrupt();
                throw new DMLRuntimeException(GET_DATA_FAILED_MSG, e);
            } catch (Exception e) {
                throw new DMLRuntimeException(GET_DATA_FAILED_MSG, e);
            }
        }
        return sum;
    }

    /** Builds the error raised when no federated strategy applies; mo3 may be null. */
    private static DMLRuntimeException unsupported(MatrixObject mo1, MatrixObject mo2, MatrixObject mo3) {
        String msg = "Federated AggregateTernary not supported with the following federated objects: "
            + mo1.isFederated() + ":" + mo1.getFedMapping() + " " + mo2.isFederated() + ":" + mo2.getFedMapping();
        if (mo3 != null) {
            msg += " " + mo3.isFederated() + ":" + mo3.getFedMapping();
        }
        return new DMLRuntimeException(msg);
    }
}
