package org.apache.sysds.runtime.instructions.gpu;

import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.matrix.data.LibMatrixCUDA;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.utils.GPUStatistics;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/gpu/MatrixMatrixAxpyGPUInstruction.class */
/**
 * GPU instruction for the fused {@code +*} / {@code -*} opcodes. It combines two matrix inputs
 * with a scalar constant via {@link LibMatrixCUDA#axpy}.
 *
 * <p>The scale factor handed to {@code axpy} is the scalar operand times {@link #multiplier}.
 * The multiplier is {@code -1} when the opcode is {@code -*} and {@code 1} otherwise.
 */
public class MatrixMatrixAxpyGPUInstruction extends ArithmeticBinaryGPUInstruction {
    /** Scalar operand (second field of the instruction string) used to scale the second input. */
    CPOperand constant;
    /** Sign applied to {@link #constant}: {@code -1} for {@code -*}, otherwise {@code 1}. */
    int multiplier;

    /**
     * @param op         parsed operator forwarded to the superclass
     * @param in1        first matrix input
     * @param constant   scalar operand
     * @param multiplier sign applied to the scalar ({@code 1} or {@code -1})
     * @param in2        second matrix input
     * @param out        output operand
     * @param opcode     instruction opcode
     * @param istr       full instruction string
     */
    private MatrixMatrixAxpyGPUInstruction(Operator op, CPOperand in1, CPOperand constant, int multiplier,
            CPOperand in2, CPOperand out, String opcode, String istr) {
        super(op, in1, in2, out, opcode, istr);
        this.constant = constant;
        this.multiplier = multiplier;
    }

    /**
     * Parses an instruction string of the form {@code opcode in1 scalar in2 out}.
     *
     * @param str the full instruction string
     * @return the parsed instruction; only the matrix/matrix/matrix variant is supported
     * @throws DMLRuntimeException if the second operand is not a scalar, or if the operand
     *                             data types describe an unsupported combination
     */
    public static MatrixMatrixAxpyGPUInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        InstructionUtils.checkNumFields(parts, 4);
        String opcode = parts[0];
        // "-*" subtracts the scaled second input; every other accepted opcode adds it.
        int multiplier = opcode.equals("-*") ? -1 : 1;

        CPOperand in1 = new CPOperand(parts[1]);
        CPOperand constant = new CPOperand(parts[2]);
        if (constant.getDataType() != Types.DataType.SCALAR) {
            throw new DMLRuntimeException("Expected second operand to be a scalar");
        }
        CPOperand in2 = new CPOperand(parts[3]);
        CPOperand out = new CPOperand(parts[4]);

        Types.DataType dt1 = in1.getDataType();
        Types.DataType dt2 = in2.getDataType();
        Types.DataType dtOut = out.getDataType();

        // Operator parsing happens before the type dispatch below (as in the original), so an
        // unparsable opcode is reported by InstructionUtils first.
        Operator operator = (dt1 != dt2)
            ? InstructionUtils.parseScalarBinaryOperator(opcode, dt1 == Types.DataType.SCALAR)
            : InstructionUtils.parseTernaryOperator(opcode);

        if (dt1 == Types.DataType.MATRIX && dt2 == Types.DataType.MATRIX && dtOut == Types.DataType.MATRIX) {
            return new MatrixMatrixAxpyGPUInstruction(operator, in1, constant, multiplier, in2, out, opcode, str);
        }
        boolean mixedScalarMatrix = (dt1 == Types.DataType.SCALAR && dt2 == Types.DataType.MATRIX)
            || (dt1 == Types.DataType.MATRIX && dt2 == Types.DataType.SCALAR);
        if (dtOut == Types.DataType.MATRIX && mixedScalarMatrix) {
            throw new DMLRuntimeException("Unsupported GPU PlusMult/MinusMult ArithmeticInstruction.");
        }
        throw new DMLRuntimeException("Unsupported GPU ArithmeticInstruction.");
    }

    /**
     * Acquires both matrix inputs and the scalar constant, validates the input dimensions, and
     * sets the output metadata to the first input's dimensions. It then delegates to
     * {@link LibMatrixCUDA#axpy} with scale {@code multiplier * constant} and releases inputs
     * and output.
     *
     * @param executionContext the execution context holding the operands
     * @throws DMLRuntimeException if the input dimensions are neither identical nor a valid
     *                             matrix/vector combination
     */
    @Override // org.apache.sysds.runtime.instructions.gpu.GPUInstruction, org.apache.sysds.runtime.instructions.Instruction
    public void processInstruction(ExecutionContext executionContext) {
        GPUStatistics.incrementNoOfExecutedGPUInst();
        MatrixObject in1 = getMatrixInputForGPUInstruction(executionContext, this._input1.getName());
        MatrixObject in2 = getMatrixInputForGPUInstruction(executionContext, this._input2.getName());
        ScalarObject scalar = executionContext.getScalarInput(this.constant);

        long rlen1 = in1.getNumRows();
        long clen1 = in1.getNumColumns();
        long rlen2 = in2.getNumRows();
        long clen2 = in2.getNumColumns();
        if (!isValidMMOperation(rlen1, rlen2, clen1, clen2) && !isValidMVOperation(rlen1, rlen2, clen1, clen2)) {
            // Report the actual dimensions of both inputs. The previous code concatenated the
            // not-yet-constructed exception into its own message, which does not compile.
            throw new DMLRuntimeException("Incorrect dimensions of inputs in GPU axpy operation. input1:"
                + rlen1 + " X " + clen1 + " and input2:" + rlen2 + " X " + clen2);
        }

        // NOTE(review): narrowing long -> int cast kept from the original; dimensions beyond
        // Integer.MAX_VALUE would be truncated here.
        executionContext.setMetaData(this._output.getName(), (int) rlen1, (int) clen1);
        LibMatrixCUDA.axpy(executionContext, executionContext.getGPUContext(0), getExtendedOpcode(),
            in1, in2, this._output.getName(), this.multiplier * scalar.getDoubleValue());
        executionContext.releaseMatrixInputForGPUInstruction(this._input1.getName());
        executionContext.releaseMatrixInputForGPUInstruction(this._input2.getName());
        executionContext.releaseMatrixOutputForGPUInstruction(this._output.getName());
    }

    /** True when both inputs have identical dimensions. */
    private static boolean isValidMMOperation(long rlen1, long rlen2, long clen1, long clen2) {
        return rlen1 == rlen2 && clen1 == clen2;
    }

    /**
     * True when the second input is a vector compatible with the first: either a column vector
     * with the same number of rows, or a row vector with the same number of columns.
     */
    private static boolean isValidMVOperation(long rlen1, long rlen2, long clen1, long clen2) {
        return (rlen1 == rlen2 && clen2 == 1) || (rlen2 == 1 && clen1 == clen2);
    }
}
