package org.apache.sysds.runtime.instructions.gpu;

import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;

/**
 * Base class for GPU instructions that evaluate a binary builtin function such as
 * {@code solve}, {@code min} or {@code max}.
 *
 * <p>Concrete instances are obtained through {@link #parseInstruction(String)}, which
 * dispatches to a subclass based on the opcode and the data types of the two inputs.
 */
public abstract class BuiltinBinaryGPUInstruction extends GPUInstruction {
    /** Number of inputs of the builtin; stored at construction but not read by this class. */
    private final int _arity;

    /**
     * Creates a binary builtin GPU instruction.
     *
     * @param op     operator wrapping the builtin function object
     * @param in1    first input operand
     * @param in2    second input operand
     * @param out    output operand
     * @param opcode instruction opcode (e.g. {@code "solve"})
     * @param istr   complete instruction string this instruction was parsed from
     * @param arity  number of inputs of the builtin
     */
    public BuiltinBinaryGPUInstruction(Operator op, CPOperand in1, CPOperand in2, CPOperand out,
            String opcode, String istr, int arity) {
        super(op, in1, in2, out, opcode, istr);
        _arity = arity;
    }

    /**
     * Parses a binary builtin GPU instruction string.
     *
     * <p>The string is split into an opcode followed by three operands (two inputs, one
     * output). Supported combinations are:
     * <ul>
     *   <li>{@code solve} with two matrix inputs, yielding a
     *       {@link MatrixMatrixBuiltinGPUInstruction};</li>
     *   <li>{@code min} or {@code max} with one matrix and one scalar input, yielding a
     *       {@link ScalarMatrixBuiltinGPUInstruction}.</li>
     * </ul>
     *
     * @param str the instruction string
     * @return the parsed instruction
     * @throws DMLRuntimeException if an input is a matrix but the output is not, if both
     *         inputs are scalars, or if the opcode is not supported for the input types
     *         (field-count validation is delegated to {@code InstructionUtils.checkNumFields})
     */
    public static BuiltinBinaryGPUInstruction parseInstruction(String str) {
        CPOperand in1 = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
        CPOperand in2 = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);
        CPOperand out = new CPOperand("", Types.ValueType.UNKNOWN, Types.DataType.UNKNOWN);

        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        InstructionUtils.checkNumFields(parts, 3);

        String opcode = parts[0];
        in1.split(parts[1]);
        in2.split(parts[2]);
        out.split(parts[3]);

        Types.DataType dt1 = in1.getDataType();
        Types.DataType dt2 = in2.getDataType();

        // An operation with any matrix input must write its result into a matrix.
        if ((dt1 == Types.DataType.MATRIX || dt2 == Types.DataType.MATRIX)
                && out.getDataType() != Types.DataType.MATRIX) {
            throw new DMLRuntimeException("Element-wise matrix operations between variables "
                + in1.getName() + " and " + in2.getName() + " must produce a matrix, which "
                + out.getName() + " is not");
        }

        Builtin func = Builtin.getBuiltinFnObject(opcode);

        boolean isMatrixMatrix = dt1 == Types.DataType.MATRIX && dt2 == Types.DataType.MATRIX;
        boolean isMatrixScalar = (dt1 == Types.DataType.MATRIX && dt2 == Types.DataType.SCALAR)
            || (dt1 == Types.DataType.SCALAR && dt2 == Types.DataType.MATRIX);

        if (dt1 == Types.DataType.SCALAR && dt2 == Types.DataType.SCALAR) {
            throw new DMLRuntimeException("GPU : Unsupported GPU builtin operations on 2 scalars");
        }
        if (isMatrixMatrix && opcode.equals("solve")) {
            return new MatrixMatrixBuiltinGPUInstruction(
                new BinaryOperator(func), in1, in2, out, opcode, str, 2);
        }
        if (isMatrixScalar && (opcode.equals("min") || opcode.equals("max"))) {
            return new ScalarMatrixBuiltinGPUInstruction(
                new BinaryOperator(func), in1, in2, out, opcode, str, 2);
        }

        // Describe the operands that were actually seen; previously this always claimed
        // "a matrix and a scalar", which was wrong e.g. for matrix-matrix min/max.
        String operands;
        if (isMatrixScalar) {
            operands = "a matrix and a scalar";
        } else if (isMatrixMatrix) {
            operands = "two matrices";
        } else {
            operands = dt1 + " and " + dt2;
        }
        throw new DMLRuntimeException(
            "GPU : Unsupported GPU builtin operations on " + operands + ":" + opcode);
    }
}
