package org.apache.sysds.runtime.instructions.gpu;

import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.matrix.data.LibMatrixCUDA;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.utils.GPUStatistics;

/* loaded from: input_file:org/apache/sysds/runtime/instructions/gpu/MatrixMatrixArithmeticGPUInstruction.class */
/**
 * GPU instruction that applies a binary arithmetic operator to two matrix operands by
 * delegating to {@link LibMatrixCUDA#matrixMatrixArithmetic}.
 */
public class MatrixMatrixArithmeticGPUInstruction extends ArithmeticBinaryGPUInstruction {
    /**
     * Creates the instruction; every argument is forwarded unchanged, in the same order,
     * to the {@link ArithmeticBinaryGPUInstruction} constructor.
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public MatrixMatrixArithmeticGPUInstruction(Operator operator, CPOperand cPOperand, CPOperand cPOperand2, CPOperand cPOperand3, String str, String str2) {
        super(operator, cPOperand, cPOperand2, cPOperand3, str, str2);
    }

    /**
     * Runs the matrix-matrix arithmetic operation.
     *
     * <p>Steps: bump the executed-GPU-instruction counter, fetch both input matrices,
     * register the output dimensions (the larger of the two row counts and the larger of
     * the two column counts), invoke the CUDA library routine, then release both inputs
     * and the output.
     *
     * <p>The release calls are plain trailing statements, so they are not reached if an
     * earlier call throws.
     *
     * @param executionContext context used to look up the operands, set the output
     *                         metadata, obtain the GPU context and release the matrices
     */
    @Override // org.apache.sysds.runtime.instructions.gpu.GPUInstruction, org.apache.sysds.runtime.instructions.Instruction
    public void processInstruction(ExecutionContext executionContext) {
        GPUStatistics.incrementNoOfExecutedGPUInst();

        MatrixObject lhs = getMatrixInputForGPUInstruction(executionContext, _input1.getName());
        MatrixObject rhs = getMatrixInputForGPUInstruction(executionContext, _input2.getName());

        long lhsRows = lhs.getNumRows();
        long lhsCols = lhs.getNumColumns();
        long rhsRows = rhs.getNumRows();
        long rhsCols = rhs.getNumColumns();

        // With equal shapes the maximum is simply the shared size; with differing shapes
        // the output takes the larger extent along each dimension.
        long outRows = Math.max(lhsRows, rhsRows);
        long outCols = Math.max(lhsCols, rhsCols);

        // NOTE(review): the dimensions are narrowed to int here, which would wrap for
        // values above Integer.MAX_VALUE -- confirm against the setMetaData signature.
        executionContext.setMetaData(_output.getName(), (int) outRows, (int) outCols);

        // NOTE(review): the two boolean arguments are always false; their meaning is
        // defined by LibMatrixCUDA.matrixMatrixArithmetic -- confirm there.
        LibMatrixCUDA.matrixMatrixArithmetic(
                executionContext,
                executionContext.getGPUContext(0),
                getExtendedOpcode(),
                lhs,
                rhs,
                _output.getName(),
                false,
                false,
                (BinaryOperator) _optr);

        executionContext.releaseMatrixInputForGPUInstruction(_input1.getName());
        executionContext.releaseMatrixInputForGPUInstruction(_input2.getName());
        executionContext.releaseMatrixOutputForGPUInstruction(_output.getName());
    }
}
