package org.apache.sysds.runtime.instructions.fed;

import org.apache.sysds.lops.Append;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.cp.AggregateBinaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.BinaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.MMChainCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MMTSJCPInstruction;
import org.apache.sysds.runtime.instructions.cp.MultiReturnParameterizedBuiltinCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ReorgCPInstruction;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
import org.apache.sysds.runtime.instructions.spark.AggregateUnarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.AppendGAlignedSPInstruction;
import org.apache.sysds.runtime.instructions.spark.AppendGSPInstruction;
import org.apache.sysds.runtime.instructions.spark.MapmmSPInstruction;
import org.apache.sysds.runtime.instructions.spark.WriteSPInstruction;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.class */
/**
 * Rewrites CP and Spark instructions into federated (FED) instructions when the
 * variables they read are federated objects in the given execution context.
 *
 * <p>Instructions that do not qualify for a rewrite are returned unchanged.
 */
public class FEDInstructionUtils {

    /**
     * Replaces a CP instruction by its federated counterpart if a relevant input is federated.
     *
     * <p>Every FED instruction produced here is tagged with {@code executionContext.getTID()}.
     *
     * @param instruction      the CP instruction about to be executed
     * @param executionContext the context from which the instruction's input variables are read
     * @return the federated replacement, or {@code instruction} itself if no rewrite applies
     */
    public static Instruction checkAndReplaceCP(Instruction instruction, ExecutionContext executionContext) {
        FEDInstruction fedInst = null;
        if (instruction instanceof AggregateBinaryCPInstruction) {
            AggregateBinaryCPInstruction abInst = (AggregateBinaryCPInstruction) instruction;
            if (abInst.input1.isMatrix() && abInst.input2.isMatrix()) {
                MatrixObject mo1 = executionContext.getMatrixObject(abInst.input1);
                MatrixObject mo2 = executionContext.getMatrixObject(abInst.input2);
                // only operands federated with FType.ROW trigger the rewrite here
                if (mo1.isFederated(FederationMap.FType.ROW) || mo2.isFederated(FederationMap.FType.ROW)) {
                    fedInst = AggregateBinaryFEDInstruction.parseInstruction(instruction.getInstructionString());
                }
            }
        } else if (instruction instanceof MMChainCPInstruction) {
            MMChainCPInstruction chainInst = (MMChainCPInstruction) instruction;
            if (executionContext.getMatrixObject(chainInst.input1).isFederated()) {
                fedInst = MMChainFEDInstruction.parseInstruction(chainInst.getInstructionString());
            }
        } else if (instruction instanceof MMTSJCPInstruction) {
            MMTSJCPInstruction tsmmInst = (MMTSJCPInstruction) instruction;
            if (executionContext.getMatrixObject(tsmmInst.input1).isFederated()) {
                fedInst = TsmmFEDInstruction.parseInstruction(tsmmInst.getInstructionString());
            }
        } else if (instruction instanceof AggregateUnaryCPInstruction) {
            AggregateUnaryCPInstruction auInst = (AggregateUnaryCPInstruction) instruction;
            // containsVariable guards the lookup; only the DEFAULT aggregate type is rewritten
            if (auInst.input1.isMatrix()
                && executionContext.containsVariable(auInst.input1)
                && executionContext.getMatrixObject(auInst.input1).isFederated()
                && auInst.getAUType() == AggregateUnaryCPInstruction.AUType.DEFAULT) {
                fedInst = AggregateUnaryFEDInstruction.parseInstruction(instruction.getInstructionString());
            }
        } else if (instruction instanceof BinaryCPInstruction) {
            BinaryCPInstruction binInst = (BinaryCPInstruction) instruction;
            boolean federatedLeft = binInst.input1.isMatrix()
                && executionContext.getMatrixObject(binInst.input1).isFederated();
            boolean federatedRight = !federatedLeft && binInst.input2.isMatrix()
                && executionContext.getMatrixObject(binInst.input2).isFederated();
            if (federatedLeft || federatedRight) {
                // append is a binary CP instruction but has its own FED implementation
                fedInst = binInst.getOpcode().equals(Append.OPCODE)
                    ? AppendFEDInstruction.parseInstruction(instruction.getInstructionString())
                    : BinaryFEDInstruction.parseInstruction(instruction.getInstructionString());
            }
        } else if (instruction instanceof ParameterizedBuiltinCPInstruction) {
            ParameterizedBuiltinCPInstruction pbInst = (ParameterizedBuiltinCPInstruction) instruction;
            String opcode = pbInst.getOpcode();
            boolean supportedOpcode = opcode.equals("replace")
                || opcode.equals("transformdecode")
                || opcode.equals("transformapply");
            // getTarget is only evaluated for supported opcodes (and at most once).
            // Previously transformdecode/transformapply returned early and skipped setTID;
            // all three opcodes now share the common TID tagging below.
            if (supportedOpcode && pbInst.getTarget(executionContext).isFederated()) {
                fedInst = ParameterizedBuiltinFEDInstruction.parseInstruction(pbInst.getInstructionString());
            }
        } else if (instruction instanceof MultiReturnParameterizedBuiltinCPInstruction) {
            MultiReturnParameterizedBuiltinCPInstruction mrInst = (MultiReturnParameterizedBuiltinCPInstruction) instruction;
            if (mrInst.getOpcode().equals("transformencode")
                && mrInst.input1.isFrame()
                && executionContext.getCacheableData(mrInst.input1).isFederated()) {
                fedInst = MultiReturnParameterizedBuiltinFEDInstruction.parseInstruction(mrInst.getInstructionString());
            }
        } else if ((instruction instanceof ReorgCPInstruction) && instruction.getOpcode().equals("r'")) {
            // only transpose ("r'") is rewritten among the reorg operations
            ReorgCPInstruction reorgInst = (ReorgCPInstruction) instruction;
            if (executionContext.getCacheableData(reorgInst.input1).isFederated()) {
                fedInst = ReorgFEDInstruction.parseInstruction(reorgInst.getInstructionString());
            }
        }
        return tagOrKeep(fedInst, instruction, executionContext);
    }

    /**
     * Replaces a Spark instruction by a federated (or, for writes, CP) instruction if its
     * first input is a federated matrix.
     *
     * <p>FED instructions produced here are tagged with {@code executionContext.getTID()}; the
     * CP write instruction returned for {@link WriteSPInstruction} is not a FED instruction and
     * is returned as parsed.
     *
     * @param instruction      the Spark instruction about to be executed
     * @param executionContext the context from which the instruction's input variables are read
     * @return the replacement instruction, or {@code instruction} itself if no rewrite applies
     */
    public static Instruction checkAndReplaceSP(Instruction instruction, ExecutionContext executionContext) {
        FEDInstruction fedInst = null;
        if (instruction instanceof MapmmSPInstruction) {
            MapmmSPInstruction mapmm = (MapmmSPInstruction) instruction;
            if (isFederatedMatrix(executionContext.getVariable(mapmm.input1))) {
                // built directly from mapmm's operator/operands with opcode "ba+*" rather than reparsed.
                // NOTE(review): "FED..." looks like a placeholder instruction string — confirm it is never parsed.
                fedInst = new AggregateBinaryFEDInstruction(mapmm.getOperator(), mapmm.input1, mapmm.input2,
                    mapmm.output, "ba+*", "FED...");
            }
        } else if (instruction instanceof AggregateUnarySPInstruction) {
            AggregateUnarySPInstruction auInst = (AggregateUnarySPInstruction) instruction;
            if (isFederatedMatrix(executionContext.getVariable(auInst.input1))) {
                fedInst = AggregateUnaryFEDInstruction.parseInstruction(instruction.getInstructionString());
            }
        } else if (instruction instanceof WriteSPInstruction) {
            WriteSPInstruction writeInst = (WriteSPInstruction) instruction;
            if (isFederatedMatrix(executionContext.getVariable(writeInst.input1))) {
                // a federated matrix is written through the CP variable instruction instead of Spark
                return VariableCPInstruction.parseInstruction(writeInst.getInstructionString());
            }
        } else if (instruction instanceof AppendGAlignedSPInstruction) {
            AppendGAlignedSPInstruction alignedAppend = (AppendGAlignedSPInstruction) instruction;
            if (isFederatedMatrix(executionContext.getVariable(alignedAppend.input1))) {
                fedInst = AppendFEDInstruction.parseInstruction(alignedAppend.getInstructionString());
            }
        } else if (instruction instanceof AppendGSPInstruction) {
            AppendGSPInstruction generalAppend = (AppendGSPInstruction) instruction;
            if (isFederatedMatrix(executionContext.getVariable(generalAppend.input1))) {
                fedInst = AppendFEDInstruction.parseInstruction(generalAppend.getInstructionString());
            }
        }
        return tagOrKeep(fedInst, instruction, executionContext);
    }

    /**
     * Returns {@code fedInst} tagged with the context's thread ID, or {@code original} if no
     * federated replacement was produced.
     */
    private static Instruction tagOrKeep(FEDInstruction fedInst, Instruction original, ExecutionContext executionContext) {
        if (fedInst == null) {
            return original;
        }
        fedInst.setTID(executionContext.getTID());
        return fedInst;
    }

    /** Returns true if {@code data} is a {@link MatrixObject} whose {@code isFederated()} is true. */
    private static boolean isFederatedMatrix(Data data) {
        return (data instanceof MatrixObject) && ((MatrixObject) data).isFederated();
    }
}
