package org.apache.sysds.runtime.instructions.gpu;

import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.functionobjects.IndexFunction;
import org.apache.sysds.runtime.functionobjects.ReduceCol;
import org.apache.sysds.runtime.functionobjects.ReduceRow;
import org.apache.sysds.runtime.instructions.InstructionUtils;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.gpu.GPUInstruction;
import org.apache.sysds.runtime.matrix.data.LibMatrixCUDA;
import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.utils.GPUStatistics;

/**
 * GPU instruction that applies an aggregate-unary operator to a single matrix input and writes the
 * result under the output operand's name.
 *
 * <p>The aggregation axis comes from the operator's {@link IndexFunction}:
 * <ul>
 *   <li>{@link ReduceRow}: the output metadata is set to {@code 1 x ncol};</li>
 *   <li>{@link ReduceCol}: the output metadata is set to {@code nrow x 1};</li>
 *   <li>anything else: no output metadata is set and the output is not released as a GPU matrix
 *       here. Presumably {@link LibMatrixCUDA#unaryAggregate} produces a scalar in that case
 *       (TODO confirm).</li>
 * </ul>
 *
 * <p>The metadata opcodes {@code nrow}, {@code ncol} and {@code length} are rejected both at parse
 * time and at execution time.
 */
public class AggregateUnaryGPUInstruction extends GPUInstruction {
    private final CPOperand _input1;
    private final CPOperand _output;

    /**
     * Creates the instruction.
     *
     * @param op     the aggregate-unary operator; expected to be an {@link AggregateUnaryOperator}
     * @param in     the matrix input operand
     * @param out    the output operand
     * @param opcode the instruction opcode
     * @param istr   the full instruction string
     */
    private AggregateUnaryGPUInstruction(Operator op, CPOperand in, CPOperand out, String opcode, String istr) {
        super(op, opcode, istr);
        this._gputype = GPUInstruction.GPUINSTRUCTION_TYPE.AggregateUnary;
        this._input1 = in;
        this._output = out;
    }

    /**
     * Parses an instruction string into an {@code AggregateUnaryGPUInstruction}.
     *
     * @param str the instruction string (opcode, input operand, output operand)
     * @return the parsed instruction
     * @throws DMLRuntimeException if the opcode is {@code nrow}, {@code ncol} or {@code length}
     */
    public static AggregateUnaryGPUInstruction parseInstruction(String str) {
        String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
        String opcode = parts[0];
        CPOperand in = new CPOperand(parts[1]);
        CPOperand out = new CPOperand(parts[2]);
        rejectMetadataOpcode(opcode);
        return new AggregateUnaryGPUInstruction(
            InstructionUtils.parseBasicAggregateUnaryOperator(opcode), in, out, opcode, str);
    }

    @Override // org.apache.sysds.runtime.instructions.gpu.GPUInstruction, org.apache.sysds.runtime.instructions.Instruction
    public void processInstruction(ExecutionContext executionContext) {
        GPUStatistics.incrementNoOfExecutedGPUInst();
        rejectMetadataOpcode(getOpcode());

        MatrixObject in1 = getMatrixInputForGPUInstruction(executionContext, this._input1.getName());
        // Keep the dimensions as long: setMetaData accepts long, and narrowing to int would
        // silently corrupt the output metadata for dimensions above Integer.MAX_VALUE.
        long rlen = in1.getNumRows();
        long clen = in1.getNumColumns();

        AggregateUnaryOperator aggOp = (AggregateUnaryOperator) this._optr;
        IndexFunction indexFunction = aggOp.indexFn;
        boolean producesVector = (indexFunction instanceof ReduceRow) || (indexFunction instanceof ReduceCol);
        if (indexFunction instanceof ReduceRow) {
            executionContext.setMetaData(this._output.getName(), 1L, clen);
        } else if (indexFunction instanceof ReduceCol) {
            executionContext.setMetaData(this._output.getName(), rlen, 1L);
        }

        LibMatrixCUDA.unaryAggregate(executionContext, executionContext.getGPUContext(0), getExtendedOpcode(),
            in1, this._output.getName(), aggOp);

        executionContext.releaseMatrixInputForGPUInstruction(this._input1.getName());
        // Only row/column reductions produce a matrix output that must be released here.
        if (producesVector) {
            executionContext.releaseMatrixOutputForGPUInstruction(this._output.getName());
        }
    }

    /**
     * Throws if {@code opcode} is one of the metadata opcodes that must never be compiled as GPU
     * instructions.
     *
     * @param opcode the opcode to check
     * @throws DMLRuntimeException for {@code nrow}, {@code ncol} or {@code length} (case-insensitive)
     */
    private static void rejectMetadataOpcode(String opcode) {
        if (opcode.equalsIgnoreCase("nrow") || opcode.equalsIgnoreCase("ncol") || opcode.equalsIgnoreCase("length")) {
            throw new DMLRuntimeException("nrow, ncol & length should not be compiled as GPU instructions!");
        }
    }
}
