package org.apache.sysds.runtime.matrix.data;

import jcuda.Pointer;
import jcuda.runtime.JCuda;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.instructions.gpu.context.CSRPointer;
import org.apache.sysds.runtime.instructions.gpu.context.GPUContext;

/* loaded from: input_file:org/apache/sysds/runtime/matrix/data/LibMatrixCuDNNInputRowFetcher.class */
/**
 * Copies single rows of a GPU-resident input matrix into one reusable device buffer.
 *
 * <p>The input is obtained either as a {@link CSRPointer} (sparse) or as a dense {@link Pointer},
 * depending on {@link LibMatrixCUDA#isInSparseFormat}. Every call to {@link #getNthRow(int)}
 * overwrites the same output buffer, so a returned pointer is only valid until the next call or
 * until {@link #close()}. {@link #close()} frees the buffer and may safely be called more than once.
 */
public class LibMatrixCuDNNInputRowFetcher extends LibMatrixCUDA implements AutoCloseable {
    GPUContext gCtx;
    String instName;
    int numColumns;
    boolean isInputInSparseFormat;
    // Either a CSRPointer (when isInputInSparseFormat) or a dense Pointer.
    Object inPointer;
    // Device buffer holding one row of numColumns elements; set to null once closed.
    Pointer outPointer;

    /**
     * Resolves the device pointer of {@code input} and allocates a one-row output buffer.
     *
     * @param gCtx     GPU context used for pointer lookup, allocation and slicing
     * @param instName instruction name forwarded to the GPU helpers
     * @param input    matrix whose rows will be fetched
     */
    public LibMatrixCuDNNInputRowFetcher(GPUContext gCtx, String instName, MatrixObject input) {
        this.gCtx = gCtx;
        this.instName = instName;
        this.numColumns = LibMatrixCUDA.toInt(input.getNumColumns());
        this.isInputInSparseFormat = LibMatrixCUDA.isInSparseFormat(gCtx, input);
        this.inPointer = this.isInputInSparseFormat
            ? LibMatrixCUDA.getSparsePointer(gCtx, input, instName)
            : LibMatrixCuDNN.getDensePointerForCuDNN(gCtx, input, instName);
        this.outPointer = gCtx.allocate(instName, rowSizeInBytes(), false);
    }

    /** Byte size of one row; widened to long first so wide rows cannot overflow int arithmetic. */
    private long rowSizeInBytes() {
        return (long) this.numColumns * sizeOfDataType;
    }

    /**
     * Copies row {@code i} of the input into the shared output buffer and returns that buffer.
     *
     * @param i row index (presumably 0-based, as passed to the slice helpers; no bounds check here)
     * @return the shared device buffer; its contents are replaced by the next call
     * @throws IllegalStateException if this fetcher has already been closed
     */
    public Pointer getNthRow(int i) {
        if (this.outPointer == null) {
            throw new IllegalStateException("LibMatrixCuDNNInputRowFetcher has already been closed");
        }
        final int lastColumn = this.numColumns - 1;
        if (this.isInputInSparseFormat) {
            // Clear stale data from the previous row before the sparse slice writes into the buffer
            // (presumably it only writes stored non-zeros; confirm against sliceSparseDense).
            // NOTE(review): the device-wide syncs around the memset are kept from the original.
            JCuda.cudaDeviceSynchronize();
            JCuda.cudaMemset(this.outPointer, 0, rowSizeInBytes());
            JCuda.cudaDeviceSynchronize();
            LibMatrixCUDA.sliceSparseDense(this.gCtx, this.instName, (CSRPointer) this.inPointer,
                this.outPointer, i, i, 0, lastColumn, this.numColumns);
        } else {
            LibMatrixCUDA.sliceDenseDense(this.gCtx, this.instName, (Pointer) this.inPointer,
                this.outPointer, i, i, 0, lastColumn, this.numColumns);
        }
        return this.outPointer;
    }

    /**
     * Frees the output buffer. Repeated calls are no-ops, so an explicit close followed by
     * try-with-resources cannot double-free the device memory.
     *
     * @throws RuntimeException wrapping a {@link DMLRuntimeException} raised while freeing
     */
    @Override // java.lang.AutoCloseable
    public void close() {
        if (this.outPointer == null) {
            return;
        }
        try {
            this.gCtx.cudaFreeHelper(null, this.outPointer, true);
            this.outPointer = null;
        } catch (DMLRuntimeException e) {
            throw new RuntimeException(e);
        }
    }
}
