package org.apache.sysds.runtime.instructions.fed;

import org.apache.sysds.common.Types;
import org.apache.sysds.lops.BinaryM;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.matrix.operators.Operator;

/**
 * Base class for federated binary instructions.
 *
 * <p>The static factory {@link #parseInstruction(String)} turns a CP or Spark binary instruction
 * string into the concrete federated instruction for the operand data types. It handles
 * matrix-matrix and matrix-scalar operations. Scalar-scalar, tensor-tensor and any other
 * combination are rejected with a {@link DMLRuntimeException}.
 */
public abstract class BinaryFEDInstruction extends ComputationFEDInstruction {

    /**
     * Creates a binary federated instruction with an explicit federated-output mode.
     *
     * <p>NOTE(review): the decompiler reported that this constructor's access modifier was
     * changed from {@code protected}; visibility is left as-is to avoid breaking callers.
     *
     * @param type   federated instruction type
     * @param op     operator applied to the two inputs
     * @param in1    first input operand
     * @param in2    second input operand
     * @param out    output operand
     * @param opcode presumably the instruction opcode (see {@link #parseInstruction(String)})
     * @param istr   presumably the full instruction string
     * @param fedOut federated output mode forwarded to the superclass
     */
    public BinaryFEDInstruction(FEDInstruction.FEDType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr, FEDInstruction.FederatedOutput fedOut) {
        super(type, op, in1, in2, out, opcode, istr, fedOut);
    }

    /**
     * Creates a binary federated instruction with federated output mode
     * {@link FEDInstruction.FederatedOutput#NONE}.
     *
     * <p>NOTE(review): the decompiler reported that this constructor's access modifier was
     * changed from {@code protected}; visibility is left as-is to avoid breaking callers.
     */
    public BinaryFEDInstruction(FEDInstruction.FEDType type, Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
        this(type, op, in1, in2, out, opcode, istr, FEDInstruction.FederatedOutput.NONE);
    }

    /**
     * Creates a federated instruction with three input operands. All arguments are forwarded
     * unchanged to the superclass constructor.
     */
    public BinaryFEDInstruction(FEDInstruction.FEDType type, Operator op, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String istr) {
        super(type, op, in1, in2, in3, out, opcode, istr);
    }

    /**
     * Parses a binary instruction string into the matching federated instruction.
     *
     * <p>Spark instructions are first rewritten into their CP form. The last field is read as the
     * federated output mode when the instruction carries more than the three mandatory operand
     * fields; otherwise {@link FEDInstruction.FederatedOutput#NONE} is used.
     *
     * @param str instruction string
     * @return a matrix-matrix or matrix-scalar federated binary instruction
     * @throws DMLRuntimeException for scalar-scalar, tensor-tensor or otherwise unsupported
     *                             operand combinations, or if the output is not a matrix while
     *                             an input is
     */
    public static BinaryFEDInstruction parseInstruction(String str) {
        if (str.startsWith(Types.ExecType.SPARK.name())) {
            str = rewriteSparkInstructionToCP(str);
        }
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        InstructionUtils.checkNumFields(parts, 3, 4, 5, 6);
        String opcode = parts[0];
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand out = new CPOperand(parts[3]);
        FEDInstruction.FederatedOutput fedOut = parseFederatedOutput(parts);

        checkOutputDataType(in1, in2, out);
        Operator op = InstructionUtils.parseBinaryOrBuiltinOperator(opcode, in1, in2);

        if (in1.getDataType() == Types.DataType.SCALAR && in2.getDataType() == Types.DataType.SCALAR) {
            throw new DMLRuntimeException("Federated binary scalar scalar operations not yet supported");
        }
        if (in1.getDataType() == Types.DataType.MATRIX && in2.getDataType() == Types.DataType.MATRIX) {
            return new BinaryMatrixMatrixFEDInstruction(op, in1, in2, out, opcode, str, fedOut);
        }
        if (in1.getDataType() == Types.DataType.TENSOR && in2.getDataType() == Types.DataType.TENSOR) {
            throw new DMLRuntimeException("Federated binary tensor tensor operations not yet supported");
        }
        if ((in1.isMatrix() && in2.isScalar()) || (in2.isMatrix() && in1.isScalar())) {
            return new BinaryMatrixScalarFEDInstruction(op, in1, in2, out, opcode, str, fedOut);
        }
        throw new DMLRuntimeException("Federated binary operations not yet supported:" + opcode);
    }

    /**
     * Reads the federated output mode from the trailing field of a split instruction.
     *
     * <p>{@code parts[0]} is the opcode and {@code parts[1..3]} are in1, in2 and out, so with
     * exactly three operand fields there is no trailing flag. Previously the output operand was
     * fed to {@code valueOf} in that case, which always failed. Longer instructions keep the
     * original strict parsing.
     */
    private static FEDInstruction.FederatedOutput parseFederatedOutput(String[] parts) {
        if (parts.length <= 4) {
            return FEDInstruction.FederatedOutput.NONE;
        }
        return FEDInstruction.FederatedOutput.valueOf(parts[parts.length - 1]);
    }

    /**
     * Splits a binary instruction with exactly 3 or 4 fields after the opcode and fills the given
     * operands in place from fields 1..3.
     *
     * @return the opcode
     */
    protected static String parseBinaryInstruction(String str, CPOperand cPOperand, CPOperand cPOperand2, CPOperand cPOperand3) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        InstructionUtils.checkNumFields(parts, 3, 4);
        String opcode = parts[0];
        cPOperand.split(parts[1]);
        cPOperand2.split(parts[2]);
        cPOperand3.split(parts[3]);
        return opcode;
    }

    /**
     * Splits a binary instruction with exactly 4 fields after the opcode and fills the given
     * operands in place from fields 1..4.
     *
     * @return the opcode
     */
    protected static String parseBinaryInstruction(String str, CPOperand cPOperand, CPOperand cPOperand2, CPOperand cPOperand3, CPOperand cPOperand4) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        InstructionUtils.checkNumFields(parts, 4);
        String opcode = parts[0];
        cPOperand.split(parts[1]);
        cPOperand2.split(parts[2]);
        cPOperand3.split(parts[3]);
        cPOperand4.split(parts[4]);
        return opcode;
    }

    /**
     * Ensures that an element-wise operation with at least one matrix input writes to a matrix.
     *
     * @throws DMLRuntimeException if either input is a matrix but the output is not
     */
    protected static void checkOutputDataType(CPOperand cPOperand, CPOperand cPOperand2, CPOperand cPOperand3) {
        boolean anyMatrixInput = cPOperand.getDataType() == Types.DataType.MATRIX || cPOperand2.getDataType() == Types.DataType.MATRIX;
        if (anyMatrixInput && cPOperand3.getDataType() != Types.DataType.MATRIX) {
            throw new DMLRuntimeException("Element-wise matrix operations between variables " + cPOperand.getName() + " and " + cPOperand2.getName() + " must produce a matrix, which " + cPOperand3.getName() + " is not");
        }
    }

    /**
     * Rewrites a Spark binary instruction string into its CP equivalent.
     *
     * <p>The string is processed field by field (fields are separated by {@code '°'}):
     * <ul>
     *   <li>the leading {@code SPARK} exec type in field 0 becomes {@code CP};</li>
     *   <li>a {@code map} prefix on the opcode (field 1) is removed;</li>
     *   <li>later fields equal to {@code RIGHT}, {@code ROW_VECTOR} or {@code COL_VECTOR} are
     *       dropped.</li>
     * </ul>
     * The previous whole-string substring replacement also altered operand fields. Examples:
     * a variable named {@code mapX}, one starting with {@code RIGHT}, or one containing
     * {@code SPARK}.
     */
    private static String rewriteSparkInstructionToCP(String str) {
        final String delim = "°";
        final String spark = Types.ExecType.SPARK.name();
        final String mapPrefix = "map";
        // limit -1 keeps trailing empty fields so the rejoined string preserves them
        final String[] fields = str.split(delim, -1);
        final StringBuilder sb = new StringBuilder(str.length());
        for (int i = 0; i < fields.length; i++) {
            String field = fields[i];
            if (i == 0 && field.startsWith(spark)) {
                field = Types.ExecType.CP.name() + field.substring(spark.length());
            } else if (i == 1 && field.startsWith(mapPrefix)) {
                field = field.substring(mapPrefix.length());
            } else if (i > 1 && isSparkOnlyField(field)) {
                continue;
            }
            // field 0 is never skipped, so every later kept field needs a leading delimiter
            if (i > 0) {
                sb.append(delim);
            }
            sb.append(field);
        }
        return sb.toString();
    }

    /** Returns whether a field is Spark map-side metadata that has no meaning for CP execution. */
    private static boolean isSparkOnlyField(String field) {
        return field.equals("RIGHT")
            || field.equals(BinaryM.VectorType.ROW_VECTOR.name())
            || field.equals(BinaryM.VectorType.COL_VECTOR.name());
    }
}
