package org.apache.sysds.runtime.instructions.fed;

import java.util.concurrent.Future;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.hops.fedplanner.FTypes;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.controlprogram.federated.MatrixLineagePair;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.matrix.operators.Operator;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/fed/AggregateBinaryFEDInstruction.class */
/**
 * Federated instruction for the {@code ba+*} opcode (aggregate binary, built with
 * {@code InstructionUtils.getMatMultOperator}) where at least one input is federated.
 *
 * <p>{@link #processInstruction(ExecutionContext)} handles these input layouts, checked in this order:
 * <ol>
 *   <li>lhs COL-federated, rhs ROW-federated, and both federation maps aligned
 *       ({@code AlignType.COL_T}): workers multiply their local pieces and the partial results
 *       are summed at the coordinator.</li>
 *   <li>lhs ROW-federated: rhs is broadcast to the lhs workers; the result either stays federated
 *       or is collected at the coordinator, depending on the federated-output flag and result shape.</li>
 *   <li>rhs ROW-federated: lhs is sent via {@code broadcastSliced}; partial results are summed
 *       at the coordinator.</li>
 *   <li>lhs COL-federated: rhs is sent via {@code broadcastSliced}; partial results are summed
 *       at the coordinator.</li>
 * </ol>
 * Any other combination raises a {@link DMLRuntimeException}.
 */
public class AggregateBinaryFEDInstruction extends BinaryFEDInstruction {
    private static final Log LOG = LogFactory.getLog(AggregateBinaryFEDInstruction.class.getName());

    /**
     * Creates the instruction without an explicit federated-output setting; the superclass
     * constructor determines the default.
     *
     * @param op     matrix-multiplication operator
     * @param in1    left input operand
     * @param in2    right input operand
     * @param out    output operand
     * @param opcode instruction opcode
     * @param istr   full instruction string
     */
    public AggregateBinaryFEDInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out,
        String opcode, String istr) {
        super(FEDInstruction.FEDType.AggregateBinary, op, in1, in2, out, opcode, istr);
    }

    /**
     * Creates the instruction with an explicit federated-output setting.
     *
     * @param op     matrix-multiplication operator
     * @param in1    left input operand
     * @param in2    right input operand
     * @param out    output operand
     * @param opcode instruction opcode
     * @param istr   full instruction string
     * @param fedOut requested placement of the output (forced federated / forced local / other)
     */
    public AggregateBinaryFEDInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out,
        String opcode, String istr, FEDInstruction.FederatedOutput fedOut) {
        super(FEDInstruction.FEDType.AggregateBinary, op, in1, in2, out, opcode, istr, fedOut);
    }

    /**
     * Parses a {@code ba+*} instruction string.
     *
     * <p>Expected layout after the opcode (exactly five fields): input 1, input 2, output, an
     * integer passed to {@code getMatMultOperator} (presumably the degree of parallelism — TODO
     * confirm), and the name of a {@code FederatedOutput} constant.
     *
     * @param str the instruction string
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is not {@code ba+*}
     */
    public static AggregateBinaryFEDInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        if (!opcode.equalsIgnoreCase("ba+*")) {
            throw new DMLRuntimeException("AggregateBinaryInstruction.parseInstruction():: Unknown opcode " + opcode);
        }
        InstructionUtils.checkNumFields(parts, 5);
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand out = new CPOperand(parts[3]);
        int k = Integer.parseInt(parts[4]);
        Operator op = InstructionUtils.getMatMultOperator(k);
        FEDInstruction.FederatedOutput fedOut = FEDInstruction.FederatedOutput.valueOf(parts[5]);
        return new AggregateBinaryFEDInstruction(op, in1, in2, out, opcode, str, fedOut);
    }

    /**
     * Executes the federated matrix multiplication by dispatching on the federation types of
     * the two inputs (see the class documentation for the supported cases).
     *
     * @param ec execution context holding the inputs and receiving the output
     * @throws DMLRuntimeException if the combination of federated inputs is not supported
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        MatrixLineagePair mo1 = ec.getMatrixLineagePair(input1);
        MatrixLineagePair mo2 = ec.getMatrixLineagePair(input2);

        if (mo1.isFederated(FTypes.FType.COL) && mo2.isFederated(FTypes.FType.ROW)
            && mo1.getFedMapping().isAligned(mo2.getFedMapping(), FTypes.AlignType.COL_T)) {
            processAlignedColRow(ec, mo1, mo2);
        }
        else if (mo1.isFederated(FTypes.FType.ROW)) {
            processRowFederatedLhs(ec, mo1, mo2);
        }
        else if (mo2.isFederated(FTypes.FType.ROW)) {
            processRowFederatedRhs(ec, mo1, mo2);
        }
        else if (mo1.isFederated(FTypes.FType.COL)) {
            processColFederatedLhs(ec, mo1, mo2);
        }
        else {
            throw new DMLRuntimeException("Federated AggregateBinary not supported with the following federated objects: "
                + mo1.isFederated() + ":" + mo1.getFedMapping() + " "
                + mo2.isFederated() + ":" + mo2.getFedMapping());
        }
    }

    /** Case 1: aligned COL x ROW inputs; workers compute partial products that are summed locally. */
    private void processAlignedColRow(ExecutionContext ec, MatrixLineagePair mo1, MatrixLineagePair mo2) {
        FederatedRequest frMM = callMatMult(mo1.getFedMapping().getID(), mo2.getFedMapping().getID());
        // A federated output would be a partial-aggregate map here, so the flag is ignored.
        logIfForcedFederated(mo1, mo2);
        aggregateLocally(mo1.getFedMapping(), true, ec, frMM);
    }

    /** Case 2: ROW-federated lhs; broadcast rhs, then keep the result federated or collect it. */
    private void processRowFederatedLhs(ExecutionContext ec, MatrixLineagePair mo1, MatrixLineagePair mo2) {
        FederationMap fedMap = mo1.getFedMapping();
        FederatedRequest frBroadcast = fedMap.broadcast(mo2);
        FederatedRequest frMM = callMatMult(fedMap.getID(), frBroadcast.getID());

        boolean isVector = mo2.getNumColumns() == 1;
        // Results that would be PART-federated are never kept federated.
        boolean isPartOut = mo1.isFederated(FTypes.FType.PART)
            || (!isVector && mo2.isFederated(FTypes.FType.PART));
        if (isPartOut && _fedOut.isForcedFederated()) {
            writeInfoLog(mo1, mo2);
        }

        // Collect at the coordinator for PART results, and for vector results or forced-local
        // output unless federated output is explicitly forced.
        boolean collectLocally = isPartOut
            || (!_fedOut.isForcedFederated() && (isVector || _fedOut.isForcedLocal()));
        if (collectLocally) {
            aggregateLocally(fedMap, mo1.isFederated(FTypes.FType.PART), ec, frBroadcast, frMM);
        }
        else {
            fedMap.execute(getTID(), true, frBroadcast, frMM);
            setOutputFedMapping(fedMap, mo1, mo2, frMM.getID(), ec);
        }
    }

    /** Case 3: ROW-federated rhs; send lhs via broadcastSliced and sum partial results locally. */
    private void processRowFederatedRhs(ExecutionContext ec, MatrixLineagePair mo1, MatrixLineagePair mo2) {
        FederationMap fedMap = mo2.getFedMapping();
        FederatedRequest[] frSliced = fedMap.broadcastSliced(mo1, true);
        FederatedRequest frMM = callMatMult(frSliced[0].getID(), fedMap.getID());
        logIfForcedFederated(mo1, mo2);
        aggregateLocally(fedMap, true, ec, frSliced, frMM);
    }

    /** Case 4: COL-federated lhs; send rhs via broadcastSliced and sum partial results locally. */
    private void processColFederatedLhs(ExecutionContext ec, MatrixLineagePair mo1, MatrixLineagePair mo2) {
        FederationMap fedMap = mo1.getFedMapping();
        FederatedRequest[] frSliced = fedMap.broadcastSliced(mo2, true);
        FederatedRequest frMM = callMatMult(fedMap.getID(), frSliced[0].getID());
        logIfForcedFederated(mo1, mo2);
        aggregateLocally(fedMap, true, ec, frSliced, frMM);
    }

    /**
     * Builds the federated request that runs this instruction on the workers, with the two
     * inputs bound to the given federated variable IDs.
     */
    private FederatedRequest callMatMult(long lhsID, long rhsID) {
        return FederationUtils.callInstruction(instString, output,
            new CPOperand[] {input1, input2}, new long[] {lhsID, rhsID}, true);
    }

    /** Writes the "flag ignored" info log when federated output was forced. */
    private void logIfForcedFederated(MatrixLineagePair mo1, MatrixLineagePair mo2) {
        if (_fedOut.isForcedFederated()) {
            writeInfoLog(mo1, mo2);
        }
    }

    /**
     * Logs, at info level, that the forced federated-output flag was ignored, together with the
     * federation types of both inputs ({@code null} for a non-federated input).
     */
    private void writeInfoLog(MatrixLineagePair mo1, MatrixLineagePair mo2) {
        FTypes.FType type1 = (mo1.getFedMapping() == null) ? null : mo1.getFedMapping().getType();
        FTypes.FType type2 = (mo2.getFedMapping() == null) ? null : mo2.getFedMapping().getType();
        LOG.info("Federated output flag would result in PART federated map and has been ignored in " + instString);
        LOG.info("Input 1 FType is " + type1 + " and input 2 FType " + type2);
    }

    /**
     * Marks the output as federated: sets its dimensions to (rows of lhs) x (columns of rhs) with
     * the lhs block size, and attaches a copy of {@code fedMap} that refers to {@code outputID}.
     */
    private void setOutputFedMapping(FederationMap fedMap, MatrixLineagePair mo1, MatrixLineagePair mo2,
        long outputID, ExecutionContext ec) {
        MatrixObject out = ec.getMatrixObject(output);
        out.getDataCharacteristics().set(mo1.getNumRows(), mo2.getNumColumns(), (int) mo1.getBlocksize());
        out.setFedMapping(fedMap.copyWithNewID(outputID, mo2.getNumColumns()));
    }

    /** Same as the sliced overload, without sliced broadcast requests. */
    private void aggregateLocally(FederationMap fedMap, boolean aggAdd, ExecutionContext ec,
        FederatedRequest... fr) {
        aggregateLocally(fedMap, aggAdd, ec, (FederatedRequest[]) null, fr);
    }

    /**
     * Executes the given requests, fetches the variable produced by the last request from every
     * worker, cleans that variable up, and stores the combined result as this instruction's output.
     *
     * @param fedMap   federation map the requests are executed against
     * @param aggAdd   {@code true} to combine the worker results with {@code FederationUtils.aggAdd},
     *                 {@code false} to combine them with {@code FederationUtils.bind(..., false)}
     * @param ec       execution context receiving the output matrix
     * @param frSliced per-worker sliced requests, or {@code null} if there are none
     * @param fr       requests to execute; the last one must produce the result variable
     */
    private void aggregateLocally(FederationMap fedMap, boolean aggAdd, ExecutionContext ec,
        FederatedRequest[] frSliced, FederatedRequest... fr) {
        long resultID = fr[fr.length - 1].getID();
        FederatedRequest frGet = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, resultID);
        FederatedRequest frCleanup = fedMap.cleanup(getTID(), resultID);
        FederatedRequest[] frAll = ArrayUtils.addAll(fr, frGet, frCleanup);

        Future<FederatedResponse>[] responses = (frSliced != null)
            ? fedMap.execute(getTID(), frSliced, frAll)
            : fedMap.execute(getTID(), frAll);

        ec.setMatrixOutput(output.getName(),
            aggAdd ? FederationUtils.aggAdd(responses) : FederationUtils.bind(responses, false));
    }
}
