package org.apache.sysds.runtime.instructions.gpu;

import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.gpu.GPUInstruction;
import org.apache.sysds.runtime.matrix.operators.Operator;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/gpu/ArithmeticBinaryGPUInstruction.class */
/**
 * Base class for GPU instructions that apply a binary arithmetic operator to two inputs.
 *
 * <p>{@link #parseInstruction(String)} is the factory entry point: it inspects the data types of
 * the two inputs and the output and returns either a {@link MatrixMatrixArithmeticGPUInstruction}
 * or a {@link ScalarMatrixArithmeticGPUInstruction}. Both supported combinations have a MATRIX
 * output.
 */
public abstract class ArithmeticBinaryGPUInstruction extends GPUInstruction {

    /**
     * Creates the instruction and tags it with the
     * {@code GPUInstruction.GPUINSTRUCTION_TYPE.ArithmeticBinary} instruction type.
     *
     * @param op     operator to apply (binary or scalar-binary, see {@link #parseInstruction})
     * @param in1    first input operand
     * @param in2    second input operand
     * @param out    output operand
     * @param opcode instruction opcode (first field of the instruction string)
     * @param istr   the complete instruction string this instruction was parsed from
     */
    protected ArithmeticBinaryGPUInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out, String opcode, String istr) {
        super(op, in1, in2, out, opcode, istr);
        this._gputype = GPUInstruction.GPUINSTRUCTION_TYPE.ArithmeticBinary;
    }

    /**
     * Parses a GPU binary arithmetic instruction string into a concrete instruction.
     *
     * <p>The string is expected to hold an opcode followed by three operands (two inputs and one
     * output). Supported data-type combinations are:
     * <ul>
     *   <li>MATRIX op MATRIX, MATRIX output: {@link MatrixMatrixArithmeticGPUInstruction}</li>
     *   <li>SCALAR op MATRIX or MATRIX op SCALAR, MATRIX output:
     *       {@link ScalarMatrixArithmeticGPUInstruction}</li>
     * </ul>
     *
     * @param str the full instruction string
     * @return the parsed instruction for a supported data-type combination
     * @throws DMLRuntimeException if the operand data types are not one of the supported
     *         combinations
     */
    public static ArithmeticBinaryGPUInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        InstructionUtils.checkNumFields(parts, 3);
        String opcode = parts[0];
        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand in2 = new CPOperand(parts[2]);
        CPOperand out = new CPOperand(parts[3]);
        Types.DataType dt1 = in1.getDataType();
        Types.DataType dt2 = in2.getDataType();
        Types.DataType dt3 = out.getDataType();

        boolean matrixOut = dt3 == Types.DataType.MATRIX;

        if (matrixOut && dt1 == Types.DataType.MATRIX && dt2 == Types.DataType.MATRIX) {
            // Both inputs share the same data type, so the plain binary operator is used.
            Operator op = InstructionUtils.parseBinaryOperator(opcode);
            return new MatrixMatrixArithmeticGPUInstruction(op, in1, in2, out, opcode, str);
        }

        boolean scalarLeft = dt1 == Types.DataType.SCALAR && dt2 == Types.DataType.MATRIX;
        boolean scalarRight = dt1 == Types.DataType.MATRIX && dt2 == Types.DataType.SCALAR;
        if (matrixOut && (scalarLeft || scalarRight)) {
            // Mixed inputs: the second argument is true when the scalar is the first (left) input.
            Operator op = InstructionUtils.parseScalarBinaryOperator(opcode, scalarLeft);
            return new ScalarMatrixArithmeticGPUInstruction(op, in1, in2, out, opcode, str);
        }

        // Rejected before any operator is parsed, so unsupported combinations (e.g. MATRIX with a
        // non-scalar, non-matrix type) always surface this error rather than an operator-parsing one.
        throw new DMLRuntimeException("Unsupported GPU ArithmeticInstruction. opcode=" + opcode
            + ", inputs=[" + dt1 + ", " + dt2 + "], output=" + dt3);
    }
}
