package org.apache.sysds.runtime.instructions.fed;

import org.apache.commons.lang3.ArrayUtils;
import org.apache.sysds.hops.fedplanner.FTypes;
import org.apache.sysds.lops.SortKeys;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.CentralMomentCPInstruction;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.IndexingCPInstruction;
import org.apache.sysds.runtime.instructions.cp.QuantileSortCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ReorgCPInstruction;
import org.apache.sysds.runtime.instructions.cp.ReshapeCPInstruction;
import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.UnaryMatrixCPInstruction;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.instructions.spark.AggregateUnarySPInstruction;
import org.apache.sysds.runtime.instructions.spark.CentralMomentSPInstruction;
import org.apache.sysds.runtime.instructions.spark.IndexingSPInstruction;
import org.apache.sysds.runtime.instructions.spark.QuantileSortSPInstruction;
import org.apache.sysds.runtime.instructions.spark.ReblockSPInstruction;
import org.apache.sysds.runtime.instructions.spark.ReorgSPInstruction;
import org.apache.sysds.runtime.instructions.spark.UnaryMatrixSPInstruction;
import org.apache.sysds.runtime.instructions.spark.UnarySPInstruction;
import org.apache.sysds.runtime.matrix.operators.Operator;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.class */
/**
 * Base class for federated instructions that have one primary input, up to two further inputs, and one output.
 *
 * <p>Besides the shared constructors, this class provides the static {@code parseInstruction} factories that decide
 * whether a CP or Spark unary instruction is replaced by a federated counterpart. It also provides helpers for
 * splitting unary instruction strings into operands.
 */
public abstract class UnaryFEDInstruction extends ComputationFEDInstruction {

    /**
     * Aggregate opcodes that, on the Spark path, are only converted to a federated instruction when the input is
     * row-partitioned.
     */
    private static final String[] ROW_FED_ONLY_AGG_OPCODES = {"uarimin", "uarimax"};

    /**
     * Creates an instruction with a single input and no federated-output flag.
     *
     * @param type federated instruction type
     * @param op operator
     * @param in input operand
     * @param out output operand
     * @param opcode opcode, forwarded to the superclass
     * @param istr instruction string, forwarded to the superclass
     */
    public UnaryFEDInstruction(FEDInstruction.FEDType type, Operator op, CPOperand in, CPOperand out, String opcode,
        String istr) {
        this(type, op, in, null, null, out, opcode, istr);
    }

    /**
     * Creates an instruction with a single input and the given federated-output flag.
     *
     * @param type federated instruction type
     * @param op operator
     * @param in input operand
     * @param out output operand
     * @param opcode opcode, forwarded to the superclass
     * @param istr instruction string, forwarded to the superclass
     * @param fedOut federated output flag
     */
    public UnaryFEDInstruction(FEDInstruction.FEDType type, Operator op, CPOperand in, CPOperand out, String opcode,
        String istr, FEDInstruction.FederatedOutput fedOut) {
        this(type, op, in, null, null, out, opcode, istr, fedOut);
    }

    /**
     * Creates an instruction with two inputs and no federated-output flag.
     *
     * @param type federated instruction type
     * @param op operator
     * @param in1 first input operand
     * @param in2 second input operand
     * @param out output operand
     * @param opcode opcode, forwarded to the superclass
     * @param istr instruction string, forwarded to the superclass
     */
    public UnaryFEDInstruction(FEDInstruction.FEDType type, Operator op, CPOperand in1, CPOperand in2,
        CPOperand out, String opcode, String istr) {
        this(type, op, in1, in2, null, out, opcode, istr);
    }

    /**
     * Creates an instruction with two inputs and the given federated-output flag.
     *
     * @param type federated instruction type
     * @param op operator
     * @param in1 first input operand
     * @param in2 second input operand
     * @param out output operand
     * @param opcode opcode, forwarded to the superclass
     * @param istr instruction string, forwarded to the superclass
     * @param fedOut federated output flag
     */
    public UnaryFEDInstruction(FEDInstruction.FEDType type, Operator op, CPOperand in1, CPOperand in2,
        CPOperand out, String opcode, String istr, FEDInstruction.FederatedOutput fedOut) {
        this(type, op, in1, in2, null, out, opcode, istr, fedOut);
    }

    /**
     * Creates an instruction with three inputs and {@link FEDInstruction.FederatedOutput#NONE}.
     *
     * @param type federated instruction type
     * @param op operator
     * @param in1 first input operand
     * @param in2 second input operand
     * @param in3 third input operand
     * @param out output operand
     * @param opcode opcode, forwarded to the superclass
     * @param istr instruction string, forwarded to the superclass
     */
    public UnaryFEDInstruction(FEDInstruction.FEDType type, Operator op, CPOperand in1, CPOperand in2,
        CPOperand in3, CPOperand out, String opcode, String istr) {
        this(type, op, in1, in2, in3, out, opcode, istr, FEDInstruction.FederatedOutput.NONE);
    }

    /**
     * Canonical constructor; all other constructors delegate here.
     *
     * @param type federated instruction type
     * @param op operator
     * @param in1 first input operand
     * @param in2 second input operand (may be {@code null})
     * @param in3 third input operand (may be {@code null})
     * @param out output operand
     * @param opcode opcode, forwarded to the superclass
     * @param istr instruction string, forwarded to the superclass
     * @param fedOut federated output flag
     */
    protected UnaryFEDInstruction(FEDInstruction.FEDType type, Operator op, CPOperand in1, CPOperand in2,
        CPOperand in3, CPOperand out, String opcode, String istr, FEDInstruction.FederatedOutput fedOut) {
        super(type, op, in1, in2, in3, out, opcode, istr, fedOut);
    }

    /**
     * Creates the federated counterpart of a CP unary instruction.
     *
     * <p>The input must be federated with a type other than {@link FTypes.FType#BROADCAST}. Additional restrictions
     * apply per instruction kind (for example, federation type for quantile sort, or {@code ucumk+*} not on
     * column-partitioned inputs).
     *
     * @param inst the CP unary instruction to convert
     * @param ec execution context used to look up the instruction's input
     * @return the federated instruction, or {@code null} if no federated conversion applies
     */
    public static UnaryFEDInstruction parseInstruction(UnaryCPInstruction inst, ExecutionContext ec) {
        if (inst instanceof IndexingCPInstruction) {
            IndexingCPInstruction iinst = (IndexingCPInstruction) inst;
            if ((iinst.input1.isMatrix() || iinst.input1.isFrame())
                && ec.getCacheableData(iinst.input1).isFederatedExcept(FTypes.FType.BROADCAST)) {
                return IndexingFEDInstruction.parseInstruction(iinst);
            }
            return null;
        }
        if (inst instanceof ReorgCPInstruction && isFederatedReorgOpcode(inst.getOpcode())) {
            ReorgCPInstruction rinst = (ReorgCPInstruction) inst;
            CacheableData<?> in = ec.getCacheableData(rinst.input1);
            if ((in instanceof MatrixObject || in instanceof FrameObject)
                && in.isFederatedExcept(FTypes.FType.BROADCAST)) {
                return ReorgFEDInstruction.parseInstruction(rinst);
            }
            return null;
        }

        // All remaining conversions require a federated (non-broadcast) matrix input that exists in the context.
        if (inst.input1 == null || !inst.input1.isMatrix() || !ec.containsVariable(inst.input1)) {
            return null;
        }
        MatrixObject mo = ec.getMatrixObject(inst.input1);
        if (!mo.isFederatedExcept(FTypes.FType.BROADCAST)) {
            return null;
        }

        if (inst instanceof CentralMomentCPInstruction) {
            return CentralMomentFEDInstruction.parseInstruction((CentralMomentCPInstruction) inst);
        }
        if (inst instanceof QuantileSortCPInstruction) {
            // Supported for row-partitioned inputs, or column-partitioned inputs with a single federated range.
            boolean supported = mo.isFederated(FTypes.FType.ROW)
                || (mo.getFedMapping().getFederatedRanges().length == 1 && mo.isFederated(FTypes.FType.COL));
            return supported ? QuantileSortFEDInstruction.parseInstruction((QuantileSortCPInstruction) inst) : null;
        }
        if (inst instanceof ReshapeCPInstruction) {
            return ReshapeFEDInstruction.parseInstruction((ReshapeCPInstruction) inst);
        }
        if (inst instanceof AggregateUnaryCPInstruction
            && ((AggregateUnaryCPInstruction) inst).getAUType() == AggregateUnaryCPInstruction.AUType.DEFAULT) {
            return AggregateUnaryFEDInstruction.parseInstruction((AggregateUnaryCPInstruction) inst);
        }
        if (!(inst instanceof UnaryMatrixCPInstruction) || !UnaryMatrixFEDInstruction.isValidOpcode(inst.getOpcode())) {
            return null;
        }
        // ucumk+* is not converted for column-partitioned inputs.
        if (inst.getOpcode().equalsIgnoreCase("ucumk+*") && mo.isFederated(FTypes.FType.COL)) {
            return null;
        }
        return UnaryMatrixFEDInstruction.parseInstruction((UnaryMatrixCPInstruction) inst);
    }

    /**
     * Creates the federated counterpart of a Spark unary instruction.
     *
     * <p>The input must be federated with a type other than {@link FTypes.FType#BROADCAST}. Additional restrictions
     * apply per instruction kind (for example, {@code uarimin}/{@code uarimax} only on row-partitioned inputs).
     *
     * @param inst the Spark unary instruction to convert
     * @param ec execution context used to look up the instruction's input
     * @return the federated instruction, or {@code null} if no federated conversion applies
     */
    public static UnaryFEDInstruction parseInstruction(UnarySPInstruction inst, ExecutionContext ec) {
        if (inst instanceof IndexingSPInstruction) {
            IndexingSPInstruction iinst = (IndexingSPInstruction) inst;
            if ((iinst.input1.isMatrix() || iinst.input1.isFrame())
                && ec.getCacheableData(iinst.input1).isFederatedExcept(FTypes.FType.BROADCAST)) {
                return IndexingFEDInstruction.parseInstruction(iinst);
            }
            return null;
        }
        if (inst instanceof CentralMomentSPInstruction) {
            CentralMomentSPInstruction cinst = (CentralMomentSPInstruction) inst;
            return isFederatedMatrix(ec.getVariable(cinst.input1))
                ? CentralMomentFEDInstruction.parseInstruction(cinst) : null;
        }
        if (inst instanceof QuantileSortSPInstruction) {
            QuantileSortSPInstruction qinst = (QuantileSortSPInstruction) inst;
            return isFederatedMatrix(ec.getVariable(qinst.input1))
                ? QuantileSortFEDInstruction.parseInstruction(qinst) : null;
        }
        if (inst instanceof AggregateUnarySPInstruction) {
            AggregateUnarySPInstruction ainst = (AggregateUnarySPInstruction) inst;
            Data in = ec.getVariable(ainst.input1);
            if (!isFederatedMatrix(in)) {
                return null;
            }
            boolean rowOnly = ArrayUtils.contains(ROW_FED_ONLY_AGG_OPCODES, ainst.getOpcode());
            if (!rowOnly || ((MatrixObject) in).getFedMapping().getType() == FTypes.FType.ROW) {
                return AggregateUnaryFEDInstruction.parseInstruction(ainst);
            }
            return null;
        }
        if (inst instanceof ReorgSPInstruction && isFederatedReorgOpcode(inst.getOpcode())) {
            ReorgSPInstruction rinst = (ReorgSPInstruction) inst;
            CacheableData<?> in = ec.getCacheableData(rinst.input1);
            if ((in instanceof MatrixObject || in instanceof FrameObject) && in.isFederated()
                && in.isFederatedExcept(FTypes.FType.BROADCAST)) {
                return ReorgFEDInstruction.parseInstruction(rinst);
            }
            return null;
        }
        if (inst instanceof ReblockSPInstruction && inst.input1 != null
            && (inst.input1.isFrame() || inst.input1.isMatrix())) {
            if (ec.getCacheableData(inst.input1).isFederatedExcept(FTypes.FType.BROADCAST)) {
                return ReblockFEDInstruction.parseInstruction((ReblockSPInstruction) inst);
            }
            return null;
        }

        // All remaining conversions require a federated (non-broadcast) matrix input that exists in the context.
        if (inst.input1 == null || !inst.input1.isMatrix() || !ec.containsVariable(inst.input1)) {
            return null;
        }
        MatrixObject mo = ec.getMatrixObject(inst.input1);
        if (!mo.isFederatedExcept(FTypes.FType.BROADCAST)) {
            return null;
        }

        // No opcode-based "cm" fallback is needed here: every CentralMomentSPInstruction was handled above.
        // The former cast-based fallback could only throw ClassCastException.
        String opcode = inst.getOpcode();
        if (opcode.equalsIgnoreCase(SortKeys.OPCODE)) {
            if (mo.getFedMapping().getFederatedRanges().length == 1) {
                return QuantileSortFEDInstruction.parseInstruction(inst.getInstructionString(), false);
            }
            return null;
        }
        if (opcode.equalsIgnoreCase("rshape")) {
            return ReshapeFEDInstruction.parseInstruction(inst.getInstructionString());
        }
        if (inst instanceof UnaryMatrixSPInstruction && UnaryMatrixFEDInstruction.isValidOpcode(opcode)) {
            return UnaryMatrixFEDInstruction.parseInstruction((UnaryMatrixSPInstruction) inst);
        }
        return null;
    }

    /**
     * Returns whether the reorg opcode has a federated implementation ({@code r'}, {@code rdiag}, {@code rev}).
     */
    private static boolean isFederatedReorgOpcode(String opcode) {
        return opcode.equals("r'") || opcode.equals("rdiag") || opcode.equals("rev");
    }

    /**
     * Returns whether the given variable is a matrix that is federated with a type other than broadcast.
     */
    private static boolean isFederatedMatrix(Data data) {
        return data instanceof MatrixObject
            && ((MatrixObject) data).isFederated()
            && ((MatrixObject) data).isFederatedExcept(FTypes.FType.BROADCAST);
    }

    /**
     * Parses a unary instruction with one input and one output.
     *
     * <p>Accepts 2, 3 or 4 fields. Only the first two operands (input, output) are consumed; any further fields are
     * ignored here.
     *
     * @param str instruction string
     * @param cPOperand input operand to populate
     * @param cPOperand2 output operand to populate
     * @return the opcode
     */
    public static String parseUnaryInstruction(String str, CPOperand cPOperand, CPOperand cPOperand2) {
        if (InstructionUtils.checkNumFields(str, 2, 3, 4) == 2) {
            return parse(str, cPOperand, null, null, cPOperand2);
        }
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        cPOperand.split(parts[1]);
        cPOperand2.split(parts[2]);
        return parts[0];
    }

    /**
     * Parses a unary instruction with exactly two inputs and one output (3 fields).
     *
     * @return the opcode
     */
    protected static String parseUnaryInstruction(String str, CPOperand cPOperand, CPOperand cPOperand2,
        CPOperand cPOperand3) {
        InstructionUtils.checkNumFields(str, 3);
        return parse(str, cPOperand, cPOperand2, null, cPOperand3);
    }

    /**
     * Parses a unary instruction with exactly three inputs and one output (4 fields).
     *
     * @return the opcode
     */
    protected static String parseUnaryInstruction(String str, CPOperand cPOperand, CPOperand cPOperand2,
        CPOperand cPOperand3, CPOperand cPOperand4) {
        InstructionUtils.checkNumFields(str, 4);
        return parse(str, cPOperand, cPOperand2, cPOperand3, cPOperand4);
    }

    /**
     * Splits the instruction into opcode, 1-3 inputs and a trailing output operand.
     *
     * <p>The part count is validated before any operand is modified, so a malformed instruction leaves all operands
     * untouched.
     *
     * @param str instruction string
     * @param in1 first input (always populated)
     * @param in2 second input (populated when there are at least 4 parts)
     * @param in3 third input (populated when there are 5 parts)
     * @param out output operand, taken from the last part
     * @return the opcode
     * @throws DMLRuntimeException if the instruction does not have 3 to 5 parts
     */
    private static String parse(String str, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        int numParts = parts.length;
        if (numParts < 3 || numParts > 5) {
            throw new DMLRuntimeException("Unexpected number of operands in the instruction: " + str);
        }
        out.split(parts[numParts - 1]);
        in1.split(parts[1]);
        if (numParts >= 4) {
            in2.split(parts[2]);
        }
        if (numParts == 5) {
            in3.split(parts[3]);
        }
        return parts[0];
    }

    /**
     * Reads the federated-output flag at position {@code i} of the instruction parts (opcode at position 0).
     *
     * @param str instruction string
     * @param i index of the flag within the parts
     * @return the parsed flag, or {@link FEDInstruction.FederatedOutput#NONE} if the instruction has no part at
     *     index {@code i}
     */
    public static FEDInstruction.FederatedOutput parseFedOutFlag(String str, int i) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        if (parts.length <= i) {
            return FEDInstruction.FederatedOutput.NONE;
        }
        return FEDInstruction.FederatedOutput.valueOf(parts[i]);
    }
}
