package org.apache.sysds.runtime.instructions.fed;

import org.apache.sysds.common.Types;
import org.apache.sysds.lops.WeightedCrossEntropy;
import org.apache.sysds.lops.WeightedDivMM;
import org.apache.sysds.lops.WeightedDivMMR;
import org.apache.sysds.lops.WeightedSigmoid;
import org.apache.sysds.lops.WeightedSquaredLoss;
import org.apache.sysds.lops.WeightedSquaredLossR;
import org.apache.sysds.lops.WeightedUnaryMM;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.matrix.operators.QuaternaryOperator;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/fed/QuaternaryFEDInstruction.class */
/**
 * Base class of the federated quaternary instructions.
 *
 * <p>It stores the optional fourth input operand, offers a static factory
 * ({@link #parseInstruction(String)}) that turns an instruction string into the
 * concrete instruction for its opcode, and bundles helpers for validating
 * operand data types and for setting the dimensions of the output matrix.
 */
public abstract class QuaternaryFEDInstruction extends ComputationFEDInstruction {
    /** Fourth input operand; stays {@code null} when the eight-argument constructor is used. */
    protected CPOperand _input4 = null;

    /**
     * Creates an instruction without a fourth input operand.
     *
     * <p>All arguments are forwarded unchanged, and in the same order, to the
     * superclass constructor.
     *
     * @param fEDType    federated instruction type
     * @param operator   operator of this instruction
     * @param cPOperand  first operand
     * @param cPOperand2 second operand
     * @param cPOperand3 third operand
     * @param cPOperand4 fourth operand handed to the superclass
     *                   (presumably the output operand; confirm against the superclass)
     * @param str        first string argument; {@link #parseInstruction(String)} passes the opcode here
     * @param str2       second string argument; {@link #parseInstruction(String)} passes the instruction string here
     */
    public QuaternaryFEDInstruction(FEDInstruction.FEDType fEDType, Operator operator, CPOperand cPOperand, CPOperand cPOperand2, CPOperand cPOperand3, CPOperand cPOperand4, String str, String str2) {
        super(fEDType, operator, cPOperand, cPOperand2, cPOperand3, cPOperand4, str, str2);
    }

    /**
     * Creates an instruction that carries a fourth input operand.
     *
     * <p>{@code cPOperand4} is kept in {@link #_input4}; {@code cPOperand5} takes the
     * slot that the eight-argument constructor gives to its fourth operand when
     * calling the superclass constructor.
     *
     * @param fEDType    federated instruction type
     * @param operator   operator of this instruction
     * @param cPOperand  first operand
     * @param cPOperand2 second operand
     * @param cPOperand3 third operand
     * @param cPOperand4 fourth input operand, stored in {@link #_input4}
     * @param cPOperand5 operand handed to the superclass after the first three
     *                   (presumably the output operand; confirm against the superclass)
     * @param str        first string argument; {@link #parseInstruction(String)} passes the opcode here
     * @param str2       second string argument; {@link #parseInstruction(String)} passes the instruction string here
     */
    public QuaternaryFEDInstruction(FEDInstruction.FEDType fEDType, Operator operator, CPOperand cPOperand, CPOperand cPOperand2, CPOperand cPOperand3, CPOperand cPOperand4, CPOperand cPOperand5, String str, String str2) {
        super(fEDType, operator, cPOperand, cPOperand2, cPOperand3, cPOperand5, str, str2);
        this._input4 = cPOperand4;
    }

    /**
     * Parses an instruction string into the matching federated quaternary instruction.
     *
     * <p>A string that starts with the name of {@code ExecType.SPARK} is first rewritten
     * by {@link #rewriteSparkInstructionToCP(String)}. The string is then split with
     * {@code InstructionUtils.getInstructionPartsWithValueType}; the resulting fields are
     * read as follows:
     * <ul>
     *   <li>index 0: the opcode;</li>
     *   <li>for the {@code WeightedUnaryMM} opcode, index 1 is an extra string that is
     *       given to the {@code QuaternaryOperator}, and all following operands move
     *       up by one position;</li>
     *   <li>the next three fields are the first three input operands, which must all
     *       have data type {@code MATRIX};</li>
     *   <li>for the {@code WeightedCrossEntropy}, {@code WeightedSquaredLoss} and
     *       {@code WeightedDivMM} opcodes, index 4 is the fourth input operand and the
     *       operand after it is read from index 5;</li>
     *   <li>the enum constant name of the operation type is read from index 5 for
     *       {@code WeightedSigmoid} and from index 6 for all other opcodes.</li>
     * </ul>
     *
     * @param str instruction string to parse
     * @return the concrete instruction for the opcode found in {@code str}
     * @throws DMLRuntimeException if the opcode is none of the supported ones, or if an
     *                             operand fails the data type check
     */
    public static QuaternaryFEDInstruction parseInstruction(String str) {
        if (str.startsWith(Types.ExecType.SPARK.name())) {
            str = rewriteSparkInstructionToCP(str);
        }

        final String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        final String opcode = parts[0];

        final boolean fourInputOpcode = opcode.equals(WeightedCrossEntropy.OPCODE_CP)
            || opcode.equals(WeightedSquaredLoss.OPCODE_CP)
            || opcode.equals(WeightedDivMM.OPCODE_CP);
        final boolean unaryOpcode = opcode.equals(WeightedUnaryMM.OPCODE_CP);

        // Each optional field adds one to the expected field count and shifts later indices.
        final int in4Offset = fourInputOpcode ? 1 : 0;
        final int uopOffset = unaryOpcode ? 1 : 0;
        InstructionUtils.checkNumFields(parts, 6 + in4Offset + uopOffset);

        final CPOperand in1 = new CPOperand(parts[1 + uopOffset]);
        final CPOperand in2 = new CPOperand(parts[2 + uopOffset]);
        final CPOperand in3 = new CPOperand(parts[3 + uopOffset]);
        final CPOperand last = new CPOperand(parts[4 + in4Offset + uopOffset]);
        checkDataTypes(Types.DataType.MATRIX, in1, in2, in3);

        final QuaternaryFEDInstruction parsed = fourInputOpcode
            ? parseWithFourthInput(opcode, parts, in1, in2, in3, last, str)
            : parseWithoutFourthInput(opcode, unaryOpcode, parts, in1, in2, in3, last, str);
        if (parsed != null) {
            return parsed;
        }
        throw new DMLRuntimeException("Unsupported opcode (" + opcode + ") for QuaternaryFEDInstruction.");
    }

    /**
     * Builds the instruction for the opcodes that carry a fourth input operand
     * (read from field index 4); returns {@code null} when the opcode matches none of them.
     */
    private static QuaternaryFEDInstruction parseWithFourthInput(String opcode, String[] parts, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand last, String instStr) {
        final CPOperand in4 = new CPOperand(parts[4]);

        if (opcode.equals(WeightedCrossEntropy.OPCODE_CP)) {
            final WeightedCrossEntropy.WCeMMType type = WeightedCrossEntropy.WCeMMType.valueOf(parts[6]);
            final QuaternaryOperator qop;
            if (type.hasFourInputs()) {
                checkScalarOrMatrix(in4);
                // The name of the fourth operand is parsed as the numeric argument of the operator.
                qop = new QuaternaryOperator(type, Double.parseDouble(in4.getName()));
            } else {
                qop = new QuaternaryOperator(type);
            }
            return new QuaternaryWCeMMFEDInstruction(qop, in1, in2, in3, in4, last, opcode, instStr);
        }

        if (opcode.equals(WeightedDivMM.OPCODE_CP)) {
            final WeightedDivMM.WDivMMType type = WeightedDivMM.WDivMMType.valueOf(parts[6]);
            if (type.hasFourInputs()) {
                checkScalarOrMatrix(in4);
            }
            final QuaternaryOperator qop = new QuaternaryOperator(type);
            return new QuaternaryWDivMMFEDInstruction(qop, in1, in2, in3, in4, last, opcode, instStr);
        }

        if (opcode.equals(WeightedSquaredLoss.OPCODE_CP)) {
            final WeightedSquaredLoss.WeightsType type = WeightedSquaredLoss.WeightsType.valueOf(parts[6]);
            if (type.hasFourInputs()) {
                checkDataTypes(Types.DataType.MATRIX, in4);
            }
            final QuaternaryOperator qop = new QuaternaryOperator(type);
            return new QuaternaryWSLossFEDInstruction(qop, in1, in2, in3, in4, last, opcode, instStr);
        }

        return null;
    }

    /**
     * Builds the instruction for the opcodes without a fourth input operand;
     * returns {@code null} when the opcode matches none of them.
     */
    private static QuaternaryFEDInstruction parseWithoutFourthInput(String opcode, boolean unaryOpcode, String[] parts, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand last, String instStr) {
        if (opcode.equals(WeightedSigmoid.OPCODE_CP)) {
            final WeightedSigmoid.WSigmoidType type = WeightedSigmoid.WSigmoidType.valueOf(parts[5]);
            return new QuaternaryWSigmoidFEDInstruction(new QuaternaryOperator(type), in1, in2, in3, last, opcode, instStr);
        }

        if (unaryOpcode) {
            final WeightedUnaryMM.WUMMType type = WeightedUnaryMM.WUMMType.valueOf(parts[6]);
            // Field 1 is the extra string of this opcode and goes straight into the operator.
            final QuaternaryOperator qop = new QuaternaryOperator(type, parts[1]);
            return new QuaternaryWUMMFEDInstruction(qop, in1, in2, in3, last, opcode, instStr);
        }

        return null;
    }

    /** Requires the operand to have data type {@code SCALAR} or {@code MATRIX}. */
    private static void checkScalarOrMatrix(CPOperand operand) {
        checkDataTypes(new Types.DataType[]{Types.DataType.SCALAR, Types.DataType.MATRIX}, operand);
    }

    /**
     * Requires every given operand to have exactly the given data type.
     *
     * @param dataType     the only accepted data type
     * @param cPOperandArr operands to validate
     * @throws DMLRuntimeException if any operand has a different data type
     */
    protected static void checkDataTypes(Types.DataType dataType, CPOperand... cPOperandArr) {
        final Types.DataType[] accepted = {dataType};
        checkDataTypes(accepted, cPOperandArr);
    }

    /**
     * Requires every given operand to have one of the accepted data types.
     *
     * @param dataTypeArr  accepted data types
     * @param cPOperandArr operands to validate
     * @throws DMLRuntimeException on the first operand whose data type is not in {@code dataTypeArr}
     */
    protected static void checkDataTypes(Types.DataType[] dataTypeArr, CPOperand... cPOperandArr) {
        for (int idx = 0; idx < cPOperandArr.length; idx++) {
            if (hasAcceptedDataType(cPOperandArr[idx], dataTypeArr)) {
                continue;
            }
            throw new DMLRuntimeException("Federated quaternary operations only supported with matrix inputs and scalar epsilon.");
        }
    }

    /** Tells whether the operand's data type is identical to one of the accepted types. */
    private static boolean hasAcceptedDataType(CPOperand operand, Types.DataType[] accepted) {
        boolean found = false;
        for (int idx = 0; !found && idx < accepted.length; idx++) {
            found = operand.getDataType() == accepted[idx];
        }
        return found;
    }

    /**
     * Rewrites a Spark instruction string into the CP form understood by
     * {@link #parseInstruction(String)}.
     *
     * <p>Every occurrence of the {@code SPARK} exec type name is replaced by the
     * {@code CP} name. Afterwards the first lop opcode found in the string (checked in
     * the order of the table below) is replaced by its CP opcode. If none of them is
     * present but one of the {@code WeightedDivMMR} / {@code WeightedSquaredLossR}
     * opcodes is, those are mapped to the CP opcodes of {@code WeightedDivMM} /
     * {@code WeightedSquaredLoss} and all {@code "°true"} and {@code "°false"}
     * substrings are removed. Finally {@code "°1"} is appended
     * (presumably the thread-count field of the CP format; confirm).
     */
    private static String rewriteSparkInstructionToCP(String sparkInst) {
        final String cpInst = sparkInst.replace(Types.ExecType.SPARK.name(), Types.ExecType.CP.name());

        // {opcode to look for, CP opcode}; only the first hit is rewritten.
        final String[][] directRewrites = {
            {WeightedCrossEntropy.OPCODE, WeightedCrossEntropy.OPCODE_CP},
            {WeightedDivMM.OPCODE, WeightedDivMM.OPCODE_CP},
            {WeightedSigmoid.OPCODE, WeightedSigmoid.OPCODE_CP},
            {WeightedSquaredLoss.OPCODE, WeightedSquaredLoss.OPCODE_CP},
            {WeightedUnaryMM.OPCODE, WeightedUnaryMM.OPCODE_CP},
        };
        for (String[] rewrite : directRewrites) {
            if (cpInst.contains(rewrite[0])) {
                return cpInst.replace(rewrite[0], rewrite[1]) + "°1";
            }
        }

        String rewritten = cpInst;
        if (cpInst.contains(WeightedDivMMR.OPCODE) || cpInst.contains(WeightedSquaredLossR.OPCODE)) {
            rewritten = cpInst
                .replace(WeightedDivMMR.OPCODE, WeightedDivMM.OPCODE_CP)
                .replace(WeightedSquaredLossR.OPCODE, WeightedSquaredLoss.OPCODE_CP)
                .replace("°true", "")
                .replace("°false", "");
        }
        return rewritten + "°1";
    }

    /**
     * Sets the data characteristics of this instruction's output matrix.
     *
     * <p>Rows: those of {@code matrixObject} if it has more than one row, otherwise
     * those of {@code matrixObject2}. Columns: those of {@code matrixObject} if it has
     * more than one column; otherwise the column count of {@code matrixObject3} when
     * the column count of {@code matrixObject2} equals the row count of
     * {@code matrixObject3}, else the row count of {@code matrixObject3}. The block
     * size is taken from {@code matrixObject} and narrowed to {@code int}.
     *
     * @param matrixObject     first matrix; preferred source of the dimensions and source of the block size
     * @param matrixObject2    second matrix; supplies the row count when the first has at most one row
     * @param matrixObject3    third matrix; supplies the column count when the first has at most one column
     * @param executionContext context used to look up the matrix object of the output operand
     */
    public void setOutputDataCharacteristics(MatrixObject matrixObject, MatrixObject matrixObject2, MatrixObject matrixObject3, ExecutionContext executionContext) {
        final MatrixObject outputMatrix = executionContext.getMatrixObject(this.output);

        final long numRows;
        if (matrixObject.getNumRows() > 1) {
            numRows = matrixObject.getNumRows();
        } else {
            numRows = matrixObject2.getNumRows();
        }

        final long numCols;
        if (matrixObject.getNumColumns() > 1) {
            numCols = matrixObject.getNumColumns();
        } else if (matrixObject2.getNumColumns() == matrixObject3.getNumRows()) {
            numCols = matrixObject3.getNumColumns();
        } else {
            numCols = matrixObject3.getNumRows();
        }

        outputMatrix.getDataCharacteristics().set(numRows, numCols, (int) matrixObject.getBlocksize());
    }
}
