package org.apache.sysds.runtime.matrix.data;

import jcuda.CudaException;
import jcuda.Pointer;
import jcuda.jcudnn.JCudnn;
import jcuda.jcudnn.cudnnActivationDescriptor;
import jcuda.jcudnn.cudnnHandle;
import jcuda.jcudnn.cudnnStatus;
import jcuda.jcudnn.cudnnTensorDescriptor;
import jcuda.runtime.JCuda;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.gpu.context.CSRPointer;
import org.apache.sysds.runtime.instructions.gpu.context.ExecutionConfig;
import org.apache.sysds.runtime.instructions.gpu.context.GPUContext;
import org.apache.sysds.runtime.instructions.gpu.context.JCudaKernels;
import org.apache.sysds.runtime.matrix.data.LibMatrixDNN;

/* loaded from: input_file:org/apache/sysds/runtime/matrix/data/LibMatrixCuDNN.class */
public class LibMatrixCuDNN extends LibMatrixCUDA {
    // Not referenced by name in the visible part of this class.
    // NOTE(review): the literal `false` passed to getNnz(...) below presumably is this flag after
    // constant inlining — confirm against the original source.
    private static final boolean RECOMPUTE_DENSE_NNZ = false;
    // Mutable, subclass-visible setting initialised to 0; it is not read in the visible part of this class.
    // NOTE(review): presumably selects a cuDNN convolution algorithm preference — confirm against its users.
    protected static int CONVOLUTION_PREFERENCE = 0;
    // Class-wide logger; the methods below emit a trace-level message before issuing GPU calls.
    private static final Log LOG = LogFactory.getLog(LibMatrixCuDNN.class.getName());

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Looks up the cuDNN handle associated with a GPU context.
     *
     * @param gPUContext GPU context to query
     * @return whatever {@code gPUContext.getCudnnHandle()} returns
     */
    public static cudnnHandle getCudnnHandle(GPUContext gPUContext) {
        final cudnnHandle handle = gPUContext.getCudnnHandle();
        return handle;
    }

    /**
     * Convolution followed by an in-place bias addition.
     *
     * <p>First delegates to {@code conv2d} with {@code matrixObject} and {@code matrixObject3} as
     * inputs and {@code matrixObject4} as the destination, then calls {@code biasAdd} with
     * {@code matrixObject4} as both input and output and {@code matrixObject2} as the bias operand.
     *
     * @param gPUContext    GPU context used for both steps
     * @param str           instruction name forwarded to both steps
     * @param matrixObject  first convolution input (presumably the image — confirm against callers)
     * @param matrixObject2 operand passed to {@code biasAdd} as its second matrix
     * @param matrixObject3 second convolution input (presumably the filter — confirm against callers)
     * @param matrixObject4 destination of the convolution, then updated by {@code biasAdd}
     * @param d             forwarded to {@code conv2d} unchanged (used there as a memory budget)
     */
    public static void conv2dBiasAdd(GPUContext gPUContext, String str, MatrixObject matrixObject, MatrixObject matrixObject2, MatrixObject matrixObject3, MatrixObject matrixObject4, int i, int i2, int i3, int i4, int i5, int i6, int i7, int i8, int i9, int i10, int i11, int i12, int i13, double d) {
        // Step 1: convolution; all shape arguments (i .. i13) are forwarded untouched.
        conv2d(
                gPUContext, str,
                matrixObject, matrixObject3, matrixObject4,
                i, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11, i12, i13,
                d);

        // Step 2: bias addition reads and writes the convolution result.
        biasAdd(
                gPUContext, str,
                matrixObject4, matrixObject2, matrixObject4);
    }

    /**
     * Allocates a device buffer and launches either the {@code sparse_dense_im2col} or the
     * {@code dense_dense_im2col} kernel to fill it from {@code matrixObject}.
     *
     * <p>Local names below follow the presumed N/C/H/W, R/S, P/Q convention of the callers —
     * TODO confirm against the kernel signatures.
     *
     * @param gPUContext   GPU context used for allocation and kernel launch
     * @param str          instruction name passed to allocation and pointer helpers
     * @param matrixObject input matrix
     * @param z            {@code true} to read the input through its sparse (CSR) pointer
     * @return the freshly allocated buffer, or {@code null} when the sparse input has zero nonzeros
     * @throws DMLRuntimeException if the sparse input reports a negative nonzero count
     */
    private static Pointer denseIm2col(GPUContext gPUContext, String str, MatrixObject matrixObject, boolean z, long j, long j2, long j3, long j4, int i, int i2, int i3, int i4, int i5, int i6, int i7, int i8) {
        // Pure arithmetic shared by both kernels; computing it up front has no side effects.
        final long outputBytes = j2 * i * i2 * j * i7 * i8 * sizeOfDataType;
        final long chw = j2 * j3 * j4;
        final long hw = j3 * j4;
        final int pq = i7 * i8;
        final int rs = i * i2;
        final long npq = j * i7 * i8;

        if (!z) {
            final long numInputCells = j * j2 * j3 * j4;
            final Pointer denseResult = gPUContext.allocate(str, outputBytes, false);
            getCudaKernels(gPUContext).launchKernel(
                    "dense_dense_im2col",
                    ExecutionConfig.getConfigForSimpleVectorOperations(toInt(numInputCells)),
                    getDensePointerForCuDNN(gPUContext, matrixObject, str), denseResult,
                    Long.valueOf(numInputCells), Long.valueOf(chw), Long.valueOf(hw), Long.valueOf(j4),
                    Integer.valueOf(i), Integer.valueOf(i2), Integer.valueOf(i7), Integer.valueOf(i8),
                    Integer.valueOf(pq), Integer.valueOf(rs), Long.valueOf(npq),
                    Integer.valueOf(i5), Integer.valueOf(i6), Integer.valueOf(i3), Integer.valueOf(i4));
            return denseResult;
        }

        final CSRPointer csr = getSparsePointer(gPUContext, matrixObject, str);
        if (csr.nnz < 0) {
            throw new DMLRuntimeException("Unknown number of nonzeroes in denseIm2col");
        }
        if (csr.nnz == 0) {
            // Nothing to expand: no buffer is allocated for an empty sparse input.
            return null;
        }
        final Pointer sparseResult = gPUContext.allocate(str, outputBytes, false);
        getCudaKernels(gPUContext).launchKernel(
                "sparse_dense_im2col",
                ExecutionConfig.getConfigForSimpleVectorOperations(toInt(csr.nnz)),
                csr.val, csr.rowPtr, csr.colInd, sparseResult,
                Long.valueOf(csr.nnz), Long.valueOf(j), Long.valueOf(chw), Long.valueOf(hw), Long.valueOf(j4),
                Integer.valueOf(i), Integer.valueOf(i2), Integer.valueOf(i7), Integer.valueOf(i8),
                Integer.valueOf(pq), Integer.valueOf(rs), Long.valueOf(npq),
                Integer.valueOf(i5), Integer.valueOf(i6), Integer.valueOf(i3), Integer.valueOf(i4));
        return sparseResult;
    }

    /**
     * Forward 2D convolution on the GPU.
     *
     * <p>Returns without doing any work when either input reports zero nonzeros. Otherwise one of
     * two strategies is used:
     * <ul>
     *   <li>if {@code matrixObject2} is in sparse format and the estimated intermediates fit the
     *       budget: im2col of {@code matrixObject}, a sparse-dense matrix multiply, and the
     *       {@code reorg_knpq} kernel;</li>
     *   <li>otherwise cuDNN forward convolution, either on the whole batch or row by row through a
     *       {@code LibMatrixCuDNNInputRowFetcher} when the estimated overhead exceeds {@code d}.</li>
     * </ul>
     *
     * <p>The integer arguments are forwarded in order to
     * {@code LibMatrixCuDNNConvolutionAlgorithm.cudnnGetConvolutionForwardAlgorithm}; presumably
     * N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q — TODO confirm against callers.
     *
     * @param gPUContext    GPU context
     * @param str           instruction name forwarded to helpers
     * @param matrixObject  input bound to the {@code nchwTensorDesc} descriptor (the image)
     * @param matrixObject2 input bound to the {@code filterDesc} descriptor (the filter)
     * @param matrixObject3 output matrix
     * @param d             intermediate memory budget compared against size estimates
     * @throws DMLRuntimeException if any tensor has too many elements for cuDNN
     */
    public static void conv2d(GPUContext gPUContext, String str, MatrixObject matrixObject, MatrixObject matrixObject2, MatrixObject matrixObject3, int i, int i2, int i3, int i4, int i5, int i6, int i7, int i8, int i9, int i10, int i11, int i12, int i13, double d) {
        // Widen to long BEFORE multiplying: an int product can wrap around and then slip past the
        // maxNumElementsOfCuDNNTensor guard below.
        final long chw = (long) i2 * i3 * i4;
        final long kpq = (long) i5 * i12 * i13;
        final long crs = (long) i2 * i6 * i7;
        final long nchw = i * chw;
        final long nkpq = i * kpq;
        final long kcrs = i5 * crs;
        final long npq = (long) i * i12 * i13;

        final boolean filterIsSparse = isInSparseFormat(gPUContext, matrixObject2);
        if (getNnz(gPUContext, str, matrixObject2, false) == 0) {
            return;
        }
        final boolean imageIsSparse = isInSparseFormat(gPUContext, matrixObject);
        if (getNnz(gPUContext, str, matrixObject, false) == 0) {
            return;
        }
        final Pointer outputPtr = getDensePointerForCuDNN(gPUContext, matrixObject3, str);
        if (nchw >= maxNumElementsOfCuDNNTensor || nkpq >= maxNumElementsOfCuDNNTensor || kcrs >= maxNumElementsOfCuDNNTensor) {
            throwCuDNNDimensionError(i, chw, i5, crs, i, kpq);
            return;
        }

        if (filterIsSparse && OptimizerUtils.estimateSizeExactSparsity(crs, npq, 1.0d) + OptimizerUtils.estimateSizeExactSparsity(i5, npq, 1.0d) < Math.min(LibMatrixCuDNNConvolutionAlgorithm.MAX_WORKSPACE_LIMIT_BYTES, d)) {
            // Sparse-filter path: im2col, then sparse x dense multiply, then reorder into the output.
            Pointer im2colPtr = denseIm2col(gPUContext, str, matrixObject, imageIsSparse, i, i2, i3, i4, i6, i7, i8, i9, i10, i11, i12, i13);
            CSRPointer sparseFilterPtr = matrixObject2.getGPUObject(gPUContext).getJcudaSparseMatrixPtr();
            Pointer matMultOut = gPUContext.allocate(str, nkpq * sizeOfDataType, false);
            LibMatrixCuMatMult.sparseDenseMatMult(gPUContext, str, matMultOut, sparseFilterPtr, im2colPtr, i5, crs, crs, npq, i5, npq, false, false);
            gPUContext.cudaFreeHelper(str, im2colPtr, DMLScript.EAGER_CUDA_FREE);
            getCudaKernels(gPUContext).launchKernel("reorg_knpq", ExecutionConfig.getConfigForSimpleVectorOperations(toInt(nkpq)), matMultOut, outputPtr, Long.valueOf(nkpq), Long.valueOf(npq), Long.valueOf(kpq), Integer.valueOf(i12 * i13));
            gPUContext.cudaFreeHelper(str, matMultOut, DMLScript.EAGER_CUDA_FREE);
            return;
        }

        // cuDNN path: sparse inputs are converted to dense, so account for that extra memory.
        final double overhead = (filterIsSparse ? OptimizerUtils.estimateSizeExactSparsity(i5, crs, 1.0d) : DataExpression.DEFAULT_DELIM_FILL_VALUE) + (imageIsSparse ? OptimizerUtils.estimateSizeExactSparsity(i, chw, 1.0d) : DataExpression.DEFAULT_DELIM_FILL_VALUE);
        final Pointer filterPtr = getDensePointerForCuDNN(gPUContext, matrixObject2, str);
        final long workspaceLimit = (long) (d - overhead);
        // Process the full batch at once only when the overhead fits; otherwise one row at a time.
        final int batch = overhead <= d ? i : 1;

        // try-with-resources guarantees the algorithm descriptor and the row fetcher are released
        // even when a cuDNN call throws.
        try (LibMatrixCuDNNConvolutionAlgorithm algo = LibMatrixCuDNNConvolutionAlgorithm.cudnnGetConvolutionForwardAlgorithm(gPUContext, str, batch, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11, i12, i13, workspaceLimit)) {
            if (batch == i) {
                cudnnConv2d(gPUContext, str, getDensePointerForCuDNN(gPUContext, matrixObject, str), filterPtr, outputPtr, algo);
            } else {
                try (LibMatrixCuDNNInputRowFetcher imageRows = new LibMatrixCuDNNInputRowFetcher(gPUContext, str, matrixObject)) {
                    for (int n = 0; n < i; n++) {
                        // Row n of the output starts n * kpq elements into the output buffer.
                        cudnnConv2d(gPUContext, str, imageRows.getNthRow(n), filterPtr, outputPtr.withByteOffset(n * kpq * sizeOfDataType), algo);
                    }
                }
            }
        }
    }

    /**
     * Applies cuDNN softmax to {@code matrixObject} and writes the result to the output variable
     * {@code str2}.
     *
     * <p>The input is described as a 4D tensor of shape (rows, columns, 1, 1). The output matrix
     * is allocated with the same number of rows and columns and pre-filled before the cuDNN call.
     *
     * @param executionContext execution context used to look up and allocate the output
     * @param gPUContext       GPU context supplying the cuDNN handle
     * @param str              instruction name forwarded to pointer helpers
     * @param matrixObject     input matrix
     * @param str2             name of the output variable
     */
    public static void softmax(ExecutionContext executionContext, GPUContext gPUContext, String str, MatrixObject matrixObject, String str2) {
        if (LOG.isTraceEnabled()) {
            LOG.trace("GPU : softmax, GPUContext=" + gPUContext);
        }
        final cudnnTensorDescriptor tensorDesc = allocateTensorDescriptor(toInt(matrixObject.getNumRows()), toInt(matrixObject.getNumColumns()), 1, 1);
        try {
            final Pointer srcPtr = getDensePointerForCuDNN(gPUContext, matrixObject, str);
            final MatrixObject output = executionContext.getMatrixObject(str2);
            executionContext.allocateGPUMatrixObject(str2, matrixObject.getNumRows(), matrixObject.getNumColumns());
            output.getGPUObject(gPUContext).allocateAndFillDense(DataExpression.DEFAULT_DELIM_FILL_VALUE);
            // The two literal 1s are the softmax algorithm and mode codes.
            // NOTE(review): per the cuDNN enums these are CUDNN_SOFTMAX_ACCURATE and
            // CUDNN_SOFTMAX_MODE_CHANNEL — confirm against the JCudnn constants.
            JCudnn.cudnnSoftmaxForward(gPUContext.getCudnnHandle(), 1, 1, one(), tensorDesc, srcPtr, zero(), tensorDesc, getDensePointerForCuDNN(gPUContext, output, str));
        } finally {
            // Release the descriptor even when allocation or the cuDNN call fails.
            JCudnn.cudnnDestroyTensorDescriptor(tensorDesc);
        }
    }

    /**
     * Creates a cuDNN tensor descriptor and configures it as a 4D tensor with the given extents,
     * format code 0 and the data type {@code LibMatrixCUDA.CUDNN_DATA_TYPE}.
     *
     * <p>The caller owns the returned descriptor and is responsible for destroying it.
     *
     * @param i  first extent
     * @param i2 second extent
     * @param i3 third extent
     * @param i4 fourth extent
     * @return the configured descriptor
     */
    private static cudnnTensorDescriptor allocateTensorDescriptor(int i, int i2, int i3, int i4) {
        // NOTE(review): format code 0 is presumably CUDNN_TENSOR_NCHW — confirm.
        final int tensorFormat = 0;
        final cudnnTensorDescriptor descriptor = new cudnnTensorDescriptor();
        JCudnn.cudnnCreateTensorDescriptor(descriptor);
        JCudnn.cudnnSetTensor4dDescriptor(
                descriptor, tensorFormat, LibMatrixCUDA.CUDNN_DATA_TYPE,
                i, i2, i3, i4);
        return descriptor;
    }

    /**
     * Always throws a {@link DMLRuntimeException} reporting that a matrix exceeds the cuDNN
     * element limit {@code maxNumElementsOfCuDNNTensor}.
     *
     * <p>The previous body referenced the exception variable inside its own initializer and never
     * printed two of the four dimensions; the message now reports all four values.
     *
     * @param j  rows of the input matrix
     * @param j2 columns of the input matrix
     * @param j3 rows of the output matrix
     * @param j4 columns of the output matrix
     * @throws DMLRuntimeException always
     */
    private static void throwCuDNNDimensionError(long j, long j2, long j3, long j4) {
        throw new DMLRuntimeException("The dimensions of input/output matrices is too large to execute a CuDNN kernel. Max CuDNN matrix size:" + maxNumElementsOfCuDNNTensor + ". Given input matrix dimensions: [" + j + "," + j2 + "]. Output dimension:  [" + j3 + "," + j4 + "].");
    }

    /**
     * Always throws a {@link DMLRuntimeException} reporting that one of three matrices exceeds the
     * cuDNN element limit {@code maxNumElementsOfCuDNNTensor}.
     *
     * <p>The previous body referenced the exception variable inside its own initializer and never
     * printed three of the six dimensions; the message now reports all six values.
     *
     * @param j  rows of the first input matrix
     * @param j2 columns of the first input matrix
     * @param j3 rows of the second input matrix
     * @param j4 columns of the second input matrix
     * @param j5 rows of the output matrix
     * @param j6 columns of the output matrix
     * @throws DMLRuntimeException always
     */
    private static void throwCuDNNDimensionError(long j, long j2, long j3, long j4, long j5, long j6) {
        throw new DMLRuntimeException("The dimensions of input/output matrices is too large to execute a CuDNN kernel. Max CuDNN matrix size:" + maxNumElementsOfCuDNNTensor + ". Given input matrix dimensions: [" + j + "," + j2 + "], [" + j3 + "," + j4 + "]. Output dimension: [" + j5 + "," + j6 + "]");
    }

    /**
     * Issues {@code cudnnConvolutionForward} using the descriptors, algorithm and workspace held
     * by {@code libMatrixCuDNNConvolutionAlgorithm}.
     *
     * @param gPUContext GPU context supplying the cuDNN handle
     * @param str        instruction name (not used by this method)
     * @param pointer    data bound to {@code nchwTensorDesc}
     * @param pointer2   data bound to {@code filterDesc}
     * @param pointer3   destination bound to {@code nkpqTensorDesc}
     * @param libMatrixCuDNNConvolutionAlgorithm descriptors, algorithm choice and workspace
     * @throws DMLRuntimeException if cuDNN returns a nonzero status or a {@code CudaException} is raised
     */
    private static void cudnnConv2d(GPUContext gPUContext, String str, Pointer pointer, Pointer pointer2, Pointer pointer3, LibMatrixCuDNNConvolutionAlgorithm libMatrixCuDNNConvolutionAlgorithm) {
        if (LOG.isTraceEnabled()) {
            LOG.trace("GPU : conv2d, GPUContext=" + gPUContext);
        }
        final LibMatrixCuDNNConvolutionAlgorithm algo = libMatrixCuDNNConvolutionAlgorithm;
        try {
            final int status = JCudnn.cudnnConvolutionForward(
                    getCudnnHandle(gPUContext),
                    one(), algo.nchwTensorDesc, pointer,
                    algo.filterDesc, pointer2,
                    algo.convDesc, algo.algo, algo.workSpace, algo.sizeInBytes,
                    zero(), algo.nkpqTensorDesc, pointer3);
            if (status == 0) {
                return;
            }
            throw new DMLRuntimeException("Could not executed cudnnConvolutionForward: " + cudnnStatus.stringFor(status));
        } catch (CudaException e) {
            throw new DMLRuntimeException("Error in conv2d in GPUContext " + gPUContext.toString() + " from Thread " + Thread.currentThread().toString(), e);
        }
    }

    /**
     * Computes the filter gradient of a 2D convolution with cuDNN.
     *
     * <p>Returns without doing any work when either input reports zero nonzeros. When the
     * estimated dense-conversion overhead exceeds {@code d}, rows are processed one at a time:
     * each row's gradient is computed into a zeroed temporary buffer and accumulated into the
     * output with the {@code inplace_add} kernel.
     *
     * <p>The integer arguments are forwarded in order to
     * {@code LibMatrixCuDNNConvolutionAlgorithm.cudnnGetConvolutionBackwardFilterAlgorithm};
     * presumably N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q — TODO confirm.
     *
     * @param gPUContext    GPU context
     * @param str           instruction name forwarded to helpers
     * @param matrixObject  input bound to {@code nchwTensorDesc} (the image)
     * @param matrixObject2 input bound to {@code nkpqTensorDesc} (the output gradient)
     * @param matrixObject3 output bound to {@code filterDesc} (the filter gradient)
     * @param d             intermediate memory budget compared against size estimates
     * @throws DMLRuntimeException if any tensor has too many elements for cuDNN
     */
    public static void conv2dBackwardFilter(GPUContext gPUContext, String str, MatrixObject matrixObject, MatrixObject matrixObject2, MatrixObject matrixObject3, int i, int i2, int i3, int i4, int i5, int i6, int i7, int i8, int i9, int i10, int i11, int i12, int i13, double d) {
        // Widen to long BEFORE multiplying so large shapes cannot wrap and bypass the size guard.
        final long chw = (long) i2 * i3 * i4;
        final long kpq = (long) i5 * i12 * i13;
        final long crs = (long) i2 * i6 * i7;
        final long nchw = i * chw;
        final long nkpq = i * kpq;
        final long kcrs = i5 * crs;

        final boolean doutIsSparse = isInSparseFormat(gPUContext, matrixObject2);
        if (getNnz(gPUContext, str, matrixObject2, false) == 0) {
            return;
        }
        final boolean imageIsSparse = isInSparseFormat(gPUContext, matrixObject);
        if (getNnz(gPUContext, str, matrixObject, false) == 0) {
            return;
        }
        if (nchw >= maxNumElementsOfCuDNNTensor || nkpq >= maxNumElementsOfCuDNNTensor || kcrs >= maxNumElementsOfCuDNNTensor) {
            throwCuDNNDimensionError(i, chw, i, kpq, i5, crs);
            return;
        }

        final Pointer dwPtr = getDensePointerForCuDNN(gPUContext, matrixObject3, str);
        final double overhead = (imageIsSparse ? OptimizerUtils.estimateSizeExactSparsity(i, chw, 1.0d) : DataExpression.DEFAULT_DELIM_FILL_VALUE) + (doutIsSparse ? OptimizerUtils.estimateSizeExactSparsity(i, kpq, 1.0d) : DataExpression.DEFAULT_DELIM_FILL_VALUE);
        final long workspaceLimit = (long) (d - overhead);
        final int batch = overhead <= d ? i : 1;

        // try-with-resources / finally ensure the algorithm descriptor, both row fetchers and the
        // temporary buffer are released even when a GPU call throws.
        try (LibMatrixCuDNNConvolutionAlgorithm algo = LibMatrixCuDNNConvolutionAlgorithm.cudnnGetConvolutionBackwardFilterAlgorithm(gPUContext, str, batch, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11, i12, i13, workspaceLimit)) {
            if (batch == i) {
                cudnnConv2dBackwardFilter(gPUContext, str, getDensePointerForCuDNN(gPUContext, matrixObject, str), getDensePointerForCuDNN(gPUContext, matrixObject2, str), dwPtr, algo);
            } else {
                try (LibMatrixCuDNNInputRowFetcher imageRows = new LibMatrixCuDNNInputRowFetcher(gPUContext, str, matrixObject);
                        LibMatrixCuDNNInputRowFetcher doutRows = new LibMatrixCuDNNInputRowFetcher(gPUContext, str, matrixObject2)) {
                    final long tmpBytes = kcrs * sizeOfDataType;
                    final Pointer tmpDw = gPUContext.allocate(str, tmpBytes, false);
                    try {
                        for (int n = 0; n < i; n++) {
                            // Clear the scratch gradient, compute row n's contribution, then accumulate it.
                            JCuda.cudaMemset(tmpDw, 0, tmpBytes);
                            cudnnConv2dBackwardFilter(gPUContext, str, imageRows.getNthRow(n), doutRows.getNthRow(n), tmpDw, algo);
                            getCudaKernels(gPUContext).launchKernel("inplace_add", ExecutionConfig.getConfigForSimpleMatrixOperations(i5, toInt(crs)), tmpDw, dwPtr, Integer.valueOf(i5), Integer.valueOf(toInt(crs)));
                        }
                    } finally {
                        gPUContext.cudaFreeHelper(str, tmpDw, true);
                    }
                }
            }
        }
    }

    /**
     * Issues {@code cudnnConvolutionBackwardFilter} using the descriptors, algorithm and workspace
     * held by {@code libMatrixCuDNNConvolutionAlgorithm}.
     *
     * @param gPUContext GPU context supplying the cuDNN handle
     * @param str        instruction name (not used by this method)
     * @param pointer    data bound to {@code nchwTensorDesc}
     * @param pointer2   data bound to {@code nkpqTensorDesc}
     * @param pointer3   destination bound to {@code filterDesc}
     * @param libMatrixCuDNNConvolutionAlgorithm descriptors, algorithm choice and workspace
     * @throws DMLRuntimeException if cuDNN returns a nonzero status or a {@code CudaException} is raised
     */
    private static void cudnnConv2dBackwardFilter(GPUContext gPUContext, String str, Pointer pointer, Pointer pointer2, Pointer pointer3, LibMatrixCuDNNConvolutionAlgorithm libMatrixCuDNNConvolutionAlgorithm) {
        if (LOG.isTraceEnabled()) {
            LOG.trace("GPU : conv2dBackwardFilter, GPUContext=" + gPUContext);
        }
        final LibMatrixCuDNNConvolutionAlgorithm algo = libMatrixCuDNNConvolutionAlgorithm;
        try {
            final int status = JCudnn.cudnnConvolutionBackwardFilter(
                    getCudnnHandle(gPUContext),
                    one(), algo.nchwTensorDesc, pointer,
                    algo.nkpqTensorDesc, pointer2,
                    algo.convDesc, algo.algo, algo.workSpace, algo.sizeInBytes,
                    zero(), algo.filterDesc, pointer3);
            if (status == 0) {
                return;
            }
            throw new DMLRuntimeException("Could not executed cudnnConvolutionBackwardFilter: " + cudnnStatus.stringFor(status));
        } catch (CudaException e) {
            throw new DMLRuntimeException("Error in conv2d in GPUContext " + gPUContext.toString() + " from Thread " + Thread.currentThread().toString(), e);
        }
    }

    /**
     * Computes the input (data) gradient of a 2D convolution with cuDNN.
     *
     * <p>Returns without doing any work when either input reports zero nonzeros. When the
     * estimated dense-conversion overhead exceeds {@code d}, the output gradient is processed one
     * row at a time through a {@code LibMatrixCuDNNInputRowFetcher}.
     *
     * <p>The integer arguments are forwarded in order to
     * {@code LibMatrixCuDNNConvolutionAlgorithm.cudnnGetConvolutionBackwardDataAlgorithm};
     * presumably N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q — TODO confirm.
     *
     * @param gPUContext    GPU context
     * @param str           instruction name forwarded to helpers
     * @param matrixObject  input bound to {@code filterDesc} (the filter)
     * @param matrixObject2 input bound to {@code nkpqTensorDesc} (the output gradient)
     * @param matrixObject3 output bound to {@code nchwTensorDesc} (the data gradient)
     * @param d             intermediate memory budget compared against size estimates
     * @throws DMLRuntimeException if any tensor has too many elements for cuDNN
     */
    public static void conv2dBackwardData(GPUContext gPUContext, String str, MatrixObject matrixObject, MatrixObject matrixObject2, MatrixObject matrixObject3, int i, int i2, int i3, int i4, int i5, int i6, int i7, int i8, int i9, int i10, int i11, int i12, int i13, double d) {
        // Widen to long BEFORE multiplying so large shapes cannot wrap and bypass the size guard.
        final long chw = (long) i2 * i3 * i4;
        final long kpq = (long) i5 * i12 * i13;
        final long crs = (long) i2 * i6 * i7;
        final long nchw = i * chw;
        final long nkpq = i * kpq;
        final long kcrs = i5 * crs;

        final boolean filterIsSparse = isInSparseFormat(gPUContext, matrixObject);
        if (getNnz(gPUContext, str, matrixObject, false) == 0) {
            return;
        }
        final boolean doutIsSparse = isInSparseFormat(gPUContext, matrixObject2);
        if (getNnz(gPUContext, str, matrixObject2, false) == 0) {
            return;
        }
        if (nchw >= maxNumElementsOfCuDNNTensor || nkpq >= maxNumElementsOfCuDNNTensor || kcrs >= maxNumElementsOfCuDNNTensor) {
            throwCuDNNDimensionError(i, chw, i, kpq, i5, crs);
            return;
        }

        final double overhead = (filterIsSparse ? OptimizerUtils.estimateSizeExactSparsity(i5, crs, 1.0d) : DataExpression.DEFAULT_DELIM_FILL_VALUE) + (doutIsSparse ? OptimizerUtils.estimateSizeExactSparsity(i, kpq, 1.0d) : DataExpression.DEFAULT_DELIM_FILL_VALUE);
        final Pointer filterPtr = getDensePointerForCuDNN(gPUContext, matrixObject, str);
        final Pointer dxPtr = getDensePointerForCuDNN(gPUContext, matrixObject3, str);
        final long workspaceLimit = (long) (d - overhead);
        final int batch = overhead <= d ? i : 1;

        // try-with-resources ensures the algorithm descriptor and the row fetcher are released
        // even when a cuDNN call throws.
        try (LibMatrixCuDNNConvolutionAlgorithm algo = LibMatrixCuDNNConvolutionAlgorithm.cudnnGetConvolutionBackwardDataAlgorithm(gPUContext, str, batch, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11, i12, i13, workspaceLimit)) {
            if (batch == i) {
                cudnnConv2dBackwardData(gPUContext, str, filterPtr, getDensePointerForCuDNN(gPUContext, matrixObject2, str), dxPtr, algo);
            } else {
                try (LibMatrixCuDNNInputRowFetcher doutRows = new LibMatrixCuDNNInputRowFetcher(gPUContext, str, matrixObject2)) {
                    for (int n = 0; n < i; n++) {
                        // Argument order must match the full-batch call: filter first (bound to
                        // filterDesc), then the dout row (bound to nkpqTensorDesc). The previous
                        // code passed them swapped on this row-wise path.
                        cudnnConv2dBackwardData(gPUContext, str, filterPtr, doutRows.getNthRow(n), dxPtr.withByteOffset(n * chw * sizeOfDataType), algo);
                    }
                }
            }
        }
    }

    /**
     * Issues {@code cudnnConvolutionBackwardData} using the descriptors, algorithm and workspace
     * held by {@code libMatrixCuDNNConvolutionAlgorithm}.
     *
     * @param gPUContext GPU context supplying the cuDNN handle
     * @param str        instruction name (not used by this method)
     * @param pointer    data bound to {@code filterDesc}
     * @param pointer2   data bound to {@code nkpqTensorDesc}
     * @param pointer3   destination bound to {@code nchwTensorDesc}
     * @param libMatrixCuDNNConvolutionAlgorithm descriptors, algorithm choice and workspace
     * @throws DMLRuntimeException if cuDNN returns a nonzero status or a {@code CudaException} is raised
     */
    private static void cudnnConv2dBackwardData(GPUContext gPUContext, String str, Pointer pointer, Pointer pointer2, Pointer pointer3, LibMatrixCuDNNConvolutionAlgorithm libMatrixCuDNNConvolutionAlgorithm) {
        if (LOG.isTraceEnabled()) {
            LOG.trace("GPU : conv2dBackwardData, GPUContext=" + gPUContext);
        }
        final LibMatrixCuDNNConvolutionAlgorithm algo = libMatrixCuDNNConvolutionAlgorithm;
        try {
            final int status = JCudnn.cudnnConvolutionBackwardData(
                    getCudnnHandle(gPUContext),
                    one(), algo.filterDesc, pointer,
                    algo.nkpqTensorDesc, pointer2,
                    algo.convDesc, algo.algo, algo.workSpace, algo.sizeInBytes,
                    zero(), algo.nchwTensorDesc, pointer3);
            if (status == 0) {
                return;
            }
            throw new DMLRuntimeException("Could not executed cudnnConvolutionBackwardData: " + cudnnStatus.stringFor(status));
        } catch (CudaException e) {
            throw new DMLRuntimeException("Error in conv2d in GPUContext " + gPUContext.toString() + " from Thread " + Thread.currentThread().toString(), e);
        }
    }

    /**
     * Forward pooling with cuDNN.
     *
     * <p>When the input is sparse and its estimated dense size exceeds {@code d}, rows are pooled
     * one at a time through a {@code LibMatrixCuDNNInputRowFetcher}; otherwise the whole batch is
     * pooled in one call.
     *
     * <p>The integer arguments are forwarded in order to {@code cudnnPoolingHelper}; presumably
     * N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q — TODO confirm against callers.
     *
     * @param gPUContext    GPU context
     * @param str           instruction name forwarded to helpers
     * @param matrixObject  input matrix
     * @param matrixObject2 output matrix
     * @param poolingType   pooling variant forwarded to the descriptor factory
     * @param d             intermediate memory budget compared against the size estimate
     * @throws DMLRuntimeException if input or output has too many elements for cuDNN
     */
    public static void pooling(GPUContext gPUContext, String str, MatrixObject matrixObject, MatrixObject matrixObject2, int i, int i2, int i3, int i4, int i5, int i6, int i7, int i8, int i9, int i10, int i11, int i12, int i13, LibMatrixDNN.PoolingType poolingType, double d) {
        // Widen to long BEFORE multiplying so large shapes cannot wrap and bypass the size guard.
        final long chw = (long) i2 * i3 * i4;
        final long cpq = (long) i2 * i12 * i13;
        final long nchw = i * chw;
        final long ncpq = i * cpq;
        if (nchw >= maxNumElementsOfCuDNNTensor || ncpq >= maxNumElementsOfCuDNNTensor) {
            throwCuDNNDimensionError(i, chw, i, cpq);
            return;
        }

        final long overhead = isInSparseFormat(gPUContext, matrixObject) ? OptimizerUtils.estimateSizeExactSparsity(i, chw, 1.0d) : 0L;
        final Pointer outputPtr = getDensePointerForCuDNN(gPUContext, matrixObject2, str);
        if (overhead <= d) {
            cudnnPoolingHelper(gPUContext, str, getDensePointerForCuDNN(gPUContext, matrixObject, str), outputPtr, i, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11, i12, i13, poolingType);
            return;
        }

        // Row-wise fallback; the fetcher is closed on both normal and exceptional exit.
        try (LibMatrixCuDNNInputRowFetcher inputRows = new LibMatrixCuDNNInputRowFetcher(gPUContext, str, matrixObject)) {
            for (int n = 0; n < i; n++) {
                // Batch size 1 per call; row n of the output starts n * cpq elements in.
                cudnnPoolingHelper(gPUContext, str, inputRows.getNthRow(n), outputPtr.withByteOffset(n * cpq * sizeOfDataType), 1, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11, i12, i13, poolingType);
            }
        }
    }

    /**
     * Builds pooling descriptors for the given shape and issues {@code cudnnPoolingForward}.
     *
     * <p>The descriptors are now closed on every exit path; previously they were only closed when
     * the cuDNN call succeeded.
     *
     * @param gPUContext  GPU context supplying the cuDNN handle
     * @param str         instruction name forwarded to the descriptor factory
     * @param pointer     input data bound to {@code xDesc}
     * @param pointer2    output data bound to {@code yDesc}
     * @param poolingType pooling variant forwarded to the descriptor factory
     * @throws DMLRuntimeException if cuDNN returns a nonzero status or a {@code CudaException} is raised
     */
    private static void cudnnPoolingHelper(GPUContext gPUContext, String str, Pointer pointer, Pointer pointer2, int i, int i2, int i3, int i4, int i5, int i6, int i7, int i8, int i9, int i10, int i11, int i12, int i13, LibMatrixDNN.PoolingType poolingType) {
        if (LOG.isTraceEnabled()) {
            LOG.trace("GPU : perform pooling, GPUContext=" + gPUContext);
        }
        try (LibMatrixCuDNNPoolingDescriptors desc = LibMatrixCuDNNPoolingDescriptors.cudnnPoolingDescriptors(gPUContext, str, i, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11, i12, i13, poolingType)) {
            final int status = JCudnn.cudnnPoolingForward(getCudnnHandle(gPUContext), desc.poolingDesc, one(), desc.xDesc, pointer, zero(), desc.yDesc, pointer2);
            if (status != 0) {
                throw new DMLRuntimeException("Could not executed cudnnPoolingForward: " + cudnnStatus.stringFor(status));
            }
        } catch (CudaException e) {
            throw new DMLRuntimeException("Error in conv2d in GPUContext " + gPUContext.toString() + " from Thread " + Thread.currentThread().toString(), e);
        }
    }

    /**
     * Backward pooling with cuDNN.
     *
     * <p>When the estimated dense-conversion overhead of the two mandatory inputs exceeds
     * {@code d}, rows are processed one at a time through row fetchers; otherwise the whole batch
     * is processed in one call. {@code matrixObject3} is optional: when it is {@code null}, a
     * {@code null} pointer is passed to the helper, which then recomputes the forward output.
     *
     * <p>The integer arguments are forwarded in order to {@code cudnnPoolingBackwardHelper};
     * presumably N, C, H, W, K, R, S, pad_h, pad_w, stride_h, stride_w, P, Q — TODO confirm.
     *
     * @param gPUContext    GPU context
     * @param str           instruction name forwarded to helpers
     * @param matrixObject  forward input (bound to {@code xDesc} in the helper)
     * @param matrixObject2 output gradient (bound to {@code dyDesc} in the helper)
     * @param matrixObject3 forward pooling output, or {@code null} if not available
     * @param matrixObject4 destination for the input gradient
     * @param poolingType   pooling variant forwarded to the descriptor factory
     * @param d             intermediate memory budget compared against the size estimates
     * @throws DMLRuntimeException if a tensor has too many elements for cuDNN
     */
    public static void poolingBackward(GPUContext gPUContext, String str, MatrixObject matrixObject, MatrixObject matrixObject2, MatrixObject matrixObject3, MatrixObject matrixObject4, int i, int i2, int i3, int i4, int i5, int i6, int i7, int i8, int i9, int i10, int i11, int i12, int i13, LibMatrixDNN.PoolingType poolingType, double d) {
        // Widen to long BEFORE multiplying so large shapes cannot wrap and bypass the size guard.
        final long chw = (long) i2 * i3 * i4;
        final long cpq = (long) i2 * i12 * i13;
        final long nchw = i * chw;
        final long ncpq = i * cpq;
        final boolean poolOutputProvided = matrixObject3 != null;
        if (nchw >= maxNumElementsOfCuDNNTensor || ncpq >= maxNumElementsOfCuDNNTensor) {
            throwCuDNNDimensionError(i, chw, i, cpq);
            return;
        }

        final long inputOverhead = isInSparseFormat(gPUContext, matrixObject) ? OptimizerUtils.estimateSizeExactSparsity(i, chw, 1.0d) : 0L;
        final long doutOverhead = isInSparseFormat(gPUContext, matrixObject2) ? OptimizerUtils.estimateSizeExactSparsity(i, cpq, 1.0d) : 0L;
        final Pointer dxPtr = getDensePointerForCuDNN(gPUContext, matrixObject4, str);
        if (inputOverhead + doutOverhead <= d) {
            cudnnPoolingBackwardHelper(gPUContext, str, getDensePointerForCuDNN(gPUContext, matrixObject, str), getDensePointerForCuDNN(gPUContext, matrixObject2, str), poolOutputProvided ? getDensePointerForCuDNN(gPUContext, matrixObject3, str) : null, dxPtr, i, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11, i12, i13, poolingType);
            return;
        }

        // Row-wise fallback. try-with-resources closes every fetcher that was opened, also when a
        // helper call throws; a null resource (no pooling output supplied) is skipped on close.
        try (LibMatrixCuDNNInputRowFetcher inputRows = new LibMatrixCuDNNInputRowFetcher(gPUContext, str, matrixObject);
                LibMatrixCuDNNInputRowFetcher doutRows = new LibMatrixCuDNNInputRowFetcher(gPUContext, str, matrixObject2);
                LibMatrixCuDNNInputRowFetcher poolOutputRows = poolOutputProvided ? new LibMatrixCuDNNInputRowFetcher(gPUContext, str, matrixObject3) : null) {
            for (int n = 0; n < i; n++) {
                // Batch size 1 per call; row n of the gradient starts n * chw elements in.
                cudnnPoolingBackwardHelper(gPUContext, str, inputRows.getNthRow(n), doutRows.getNthRow(n), poolOutputProvided ? poolOutputRows.getNthRow(n) : null, dxPtr.withByteOffset(n * chw * sizeOfDataType), 1, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11, i12, i13, poolingType);
            }
        }
    }

    /**
     * Builds backward-pooling descriptors and issues {@code cudnnPoolingBackward}.
     *
     * <p>If {@code pointer3} is {@code null}, a temporary buffer is allocated, filled by a
     * {@code cudnnPoolingForward} call, used as the forward output, and freed before returning.
     * Descriptors and the temporary buffer are released on every exit path.
     *
     * @param gPUContext  GPU context supplying the cuDNN handle and allocator
     * @param str         instruction name forwarded to allocation and descriptor helpers
     * @param pointer     forward input, bound to {@code xDesc}
     * @param pointer2    output gradient, bound to {@code dyDesc}
     * @param pointer3    forward pooling output bound to {@code yDesc}, or {@code null} to recompute it
     * @param pointer4    destination for the input gradient, bound to {@code dxDesc}
     * @param poolingType pooling variant forwarded to the descriptor factory
     * @throws DMLRuntimeException if cuDNN returns a nonzero status or a {@code CudaException} is raised
     */
    private static void cudnnPoolingBackwardHelper(GPUContext gPUContext, String str, Pointer pointer, Pointer pointer2, Pointer pointer3, Pointer pointer4, int i, int i2, int i3, int i4, int i5, int i6, int i7, int i8, int i9, int i10, int i11, int i12, int i13, LibMatrixDNN.PoolingType poolingType) {
        if (LOG.isTraceEnabled()) {
            LOG.trace("GPU : maxpoolingBackward, GPUContext=" + gPUContext);
        }
        final boolean poolOutputProvided = pointer3 != null;
        // Either the caller's buffer or a temporary one owned (and freed) by this method.
        Pointer poolOutput = pointer3;
        try (LibMatrixCuDNNPoolingDescriptors desc = LibMatrixCuDNNPoolingDescriptors.cudnnPoolingBackwardDescriptors(gPUContext, str, i, i2, i3, i4, i5, i6, i7, i8, i9, i10, i11, i12, i13, poolingType)) {
            if (!poolOutputProvided) {
                // Widen before multiplying so the byte count cannot overflow int arithmetic.
                final long numBytes = (long) i * i2 * i12 * i13 * sizeOfDataType;
                poolOutput = gPUContext.allocate(str, numBytes, false);
                final int forwardStatus = JCudnn.cudnnPoolingForward(getCudnnHandle(gPUContext), desc.poolingDesc, one(), desc.xDesc, pointer, zero(), desc.yDesc, poolOutput);
                if (forwardStatus != 0) {
                    throw new DMLRuntimeException("Could not executed cudnnPoolingForward before cudnnPoolingBackward: " + cudnnStatus.stringFor(forwardStatus));
                }
            }
            final int backwardStatus = JCudnn.cudnnPoolingBackward(getCudnnHandle(gPUContext), desc.poolingDesc, one(), desc.yDesc, poolOutput, desc.dyDesc, pointer2, desc.xDesc, pointer, zero(), desc.dxDesc, pointer4);
            if (backwardStatus != 0) {
                throw new DMLRuntimeException("Could not executed cudnnPoolingBackward: " + cudnnStatus.stringFor(backwardStatus));
            }
        } catch (CudaException e) {
            throw new DMLRuntimeException("Error in conv2d in GPUContext " + gPUContext.toString() + " from Thread " + Thread.currentThread().toString(), e);
        } finally {
            // Only free what this method allocated; the null check covers a failed allocation.
            if (!poolOutputProvided && poolOutput != null) {
                gPUContext.cudaFreeHelper(str, poolOutput, DMLScript.EAGER_CUDA_FREE);
            }
        }
    }

    /**
     * Runs {@code cudnnActivationForward} on the dense data of {@code matrixObject}, writing to
     * {@code pointer}. The same tensor descriptor is used for input and output.
     *
     * <p>The activation descriptor created here is now destroyed before returning; previously it
     * was never released.
     *
     * @param gPUContext            GPU context supplying the cuDNN handle
     * @param str                   instruction name forwarded to the pointer helper
     * @param matrixObject          input matrix
     * @param pointer               destination buffer
     * @param cudnntensordescriptor descriptor for both input and output; owned by the caller
     * @throws DMLRuntimeException if a {@code CudaException} is raised
     */
    private static void cudnnReLU(GPUContext gPUContext, String str, MatrixObject matrixObject, Pointer pointer, cudnnTensorDescriptor cudnntensordescriptor) {
        try {
            if (LOG.isTraceEnabled()) {
                LOG.trace("GPU : performCuDNNReLU, GPUContext=" + gPUContext);
            }
            final Pointer srcPtr = getDensePointerForCuDNN(gPUContext, matrixObject, str);
            final cudnnActivationDescriptor activationDesc = new cudnnActivationDescriptor();
            JCudnn.cudnnCreateActivationDescriptor(activationDesc);
            try {
                // NOTE(review): mode 1 / NaN option 1 are presumably CUDNN_ACTIVATION_RELU and
                // CUDNN_PROPAGATE_NAN; -1.0 is the coefficient argument — confirm against JCudnn.
                JCudnn.cudnnSetActivationDescriptor(activationDesc, 1, 1, -1.0d);
                JCudnn.cudnnActivationForward(getCudnnHandle(gPUContext), activationDesc, one(), cudnntensordescriptor, srcPtr, zero(), cudnntensordescriptor, pointer);
            } finally {
                JCudnn.cudnnDestroyActivationDescriptor(activationDesc);
            }
        } catch (CudaException e) {
            throw new DMLRuntimeException("Error in conv2d in GPUContext " + gPUContext.toString() + " from Thread " + Thread.currentThread().toString(), e);
        }
    }

    /**
     * Computes ReLU of {@code matrixObject} on the GPU and stores it in the output variable
     * {@code str2}.
     *
     * <p>When the number of cells reaches {@code maxNumElementsOfCuDNNTensor} the custom
     * {@code "relu"} CUDA kernel is launched; otherwise the work is delegated to cuDNN. The
     * tensor descriptor created for the cuDNN path is released even if the cuDNN call fails.
     *
     * <p>NOTE(review): both dense-pointer lookups used here reject inputs with more than
     * {@code maxNumElementsOfCuDNNTensor} cells, so the custom-kernel branch appears reachable
     * only when the cell count equals that limit exactly - confirm whether that is intended.
     *
     * @param executionContext execution context holding the output variable
     * @param gPUContext       GPU context; must be the one registered with the execution context
     * @param str              instruction name
     * @param matrixObject     input matrix
     * @param str2             name of the output variable
     * @throws DMLRuntimeException if the GPU contexts do not match or a GPU call fails
     */
    public static void relu(ExecutionContext executionContext, GPUContext gPUContext, String str, MatrixObject matrixObject, String str2) {
        if (executionContext.getGPUContext(0) != gPUContext) {
            throw new DMLRuntimeException("GPU : Invalid internal state, the GPUContext set with the ExecutionContext is not the same used to run this LibMatrixCUDA function");
        }
        long numRows = matrixObject.getNumRows();
        long numColumns = matrixObject.getNumColumns();
        Pointer denseOutputPointer = getDenseOutputPointer(executionContext, gPUContext, str, str2, numRows, numColumns);
        if (numRows * numColumns >= maxNumElementsOfCuDNNTensor) {
            if (LOG.isTraceEnabled()) {
                LOG.trace("GPU : relu custom kernel, GPUContext=" + gPUContext);
            }
            getCudaKernels(gPUContext).launchKernel("relu", ExecutionConfig.getConfigForSimpleMatrixOperations(toInt(numRows), toInt(numColumns)), getDensePointerForCuDNN(gPUContext, matrixObject, str), denseOutputPointer, Integer.valueOf(toInt(numRows)), Integer.valueOf(toInt(numColumns)));
            return;
        }
        cudnnTensorDescriptor tensorDesc = new cudnnTensorDescriptor();
        JCudnn.cudnnCreateTensorDescriptor(tensorDesc);
        try {
            // Laid out as [rows, 1, 1, cols]; format 0 is cuDNN's NCHW tensor format.
            JCudnn.cudnnSetTensor4dDescriptor(tensorDesc, 0, CUDNN_DATA_TYPE, toInt(numRows), 1, 1, toInt(numColumns));
            cudnnReLU(gPUContext, str, matrixObject, denseOutputPointer, tensorDesc);
        } finally {
            // Release the native descriptor on both the success and the failure path.
            JCudnn.cudnnDestroyTensorDescriptor(tensorDesc);
        }
    }

    /**
     * Looks up the GPU-instruction input named {@code str2} and returns its dense device pointer
     * after checking that it has exactly {@code j} rows and {@code j2} columns.
     *
     * @param executionContext execution context that owns the input variable
     * @param gPUContext       GPU context used for the pointer lookup
     * @param str              instruction name
     * @param str2             name of the input variable
     * @param j                expected number of rows
     * @param j2               expected number of columns
     * @return dense device pointer of the input
     * @throws DMLRuntimeException if the dimensions do not match or the lookup fails
     */
    static Pointer getDenseInputPointer(ExecutionContext executionContext, GPUContext gPUContext, String str, String str2, long j, long j2) throws DMLRuntimeException {
        final MatrixObject input = executionContext.getMatrixInputForGPUInstruction(str2, str);
        final int expectedRows = toInt(j);
        final int expectedCols = toInt(j2);
        return getDensePointerForCuDNN(gPUContext, input, str, expectedRows, expectedCols);
    }

    /**
     * Prepares the dense GPU output named {@code str2} with the given dimensions and returns its
     * device pointer.
     *
     * @param executionContext execution context that owns the output variable
     * @param gPUContext       GPU context used for the pointer lookup
     * @param str              instruction name
     * @param str2             name of the output variable
     * @param j                number of rows of the output
     * @param j2               number of columns of the output
     * @return dense device pointer of the output
     * @throws DMLRuntimeException if the dimensions do not match or the lookup fails
     */
    static Pointer getDenseOutputPointer(ExecutionContext executionContext, GPUContext gPUContext, String str, String str2, long j, long j2) throws DMLRuntimeException {
        // The matrix object is resolved first, exactly as before, and only then is the dense
        // output set up; the helper's return value is intentionally not used.
        final MatrixObject output = executionContext.getMatrixObject(str2);
        getDenseMatrixOutputForGPUInstruction(executionContext, str, str2, j, j2);

        final int expectedRows = toInt(j);
        final int expectedCols = toInt(j2);
        return getDensePointerForCuDNN(gPUContext, output, str, expectedRows, expectedCols);
    }

    /**
     * Runs a single-layer unidirectional LSTM forward pass by delegating to
     * {@code singleLayerUnidirectionalRNNForward} with the RNN mode fixed to {@code "lstm"}.
     *
     * <p>Note the argument reordering: {@code pointer2} is passed last among the pointers (as the
     * weight pointer of the delegate), while {@code pointer3} and {@code pointer4} become its
     * second and third pointer arguments.
     *
     * @param executionContext execution context holding the output variables
     * @param gPUContext       GPU context to run on
     * @param str              instruction name
     * @param pointer          first pointer argument, forwarded unchanged
     * @param pointer2         forwarded as the delegate's fourth pointer argument
     * @param pointer3         forwarded as the delegate's second pointer argument
     * @param pointer4         forwarded as the delegate's third pointer argument
     * @param z                flag forwarded unchanged to the delegate
     * @param str2             name of the first output variable
     * @param str3             name of the second output variable
     * @param i                first dimension argument, forwarded unchanged
     * @param i2               second dimension argument, forwarded unchanged
     * @param i3               third dimension argument, forwarded unchanged
     * @param i4               fourth dimension argument, forwarded unchanged
     * @throws DMLRuntimeException if the forward pass fails
     */
    public static void lstm(ExecutionContext executionContext, GPUContext gPUContext, String str, Pointer pointer, Pointer pointer2, Pointer pointer3, Pointer pointer4, boolean z, String str2, String str3, int i, int i2, int i3, int i4) throws DMLRuntimeException {
        final String rnnMode = "lstm";
        singleLayerUnidirectionalRNNForward(
            executionContext, gPUContext, str,
            pointer, pointer3, pointer4, pointer2,
            str2, str3, rnnMode, z,
            i, i2, i3, i4);
    }

    /**
     * Executes {@code cudnnRNNForwardTraining} for a single-layer unidirectional RNN and writes
     * the results to the output variables {@code str2} and (for LSTM) {@code str3}.
     *
     * <p>Fixes relative to the previous version:
     * <ul>
     *   <li>buffer sizes in bytes are computed in {@code long} arithmetic, so large
     *       {@code i * i4 * i2} products no longer overflow {@code int} before allocation;</li>
     *   <li>the cuDNN algorithm object is closed exactly once via try-with-resources (the old
     *       control flow could call {@code close()} a second time when post-processing failed);</li>
     *   <li>temporary device buffers are released on the failure path as well.</li>
     * </ul>
     *
     * @param executionContext execution context holding the output variables
     * @param gPUContext       GPU context to run on
     * @param str              instruction name
     * @param pointer          passed to cuDNN as the {@code x} data pointer
     * @param pointer2         passed to cuDNN as the {@code hx} data pointer
     * @param pointer3         passed to cuDNN as the {@code cx} data pointer
     * @param pointer4         passed to cuDNN as the {@code w} data pointer
     * @param str2             name of the main output variable
     * @param str3             name of the {@code cy} output variable (only used in "lstm" mode)
     * @param str4             RNN mode; {@code "lstm"} (case-insensitive) enables the {@code cy} output
     * @param z                if true the full sequence output ({@code i x i4*i2}) is produced,
     *                         otherwise only the {@code hy} output ({@code i x i2})
     * @param i                first dimension of the buffers (rows of the outputs)
     * @param i2               width of the {@code hy}/{@code cy} outputs
     * @param i3               forwarded to the algorithm object
     * @param i4               sequence length handed to cuDNN
     * @throws DMLRuntimeException if allocation or a GPU call fails
     */
    private static void singleLayerUnidirectionalRNNForward(ExecutionContext executionContext, GPUContext gPUContext, String str, Pointer pointer, Pointer pointer2, Pointer pointer3, Pointer pointer4, String str2, String str3, String str4, boolean z, int i, int i2, int i3, int i4) throws DMLRuntimeException {
        boolean hasCarry = str4.equalsIgnoreCase("lstm");
        // Widen to long before multiplying: the byte count easily exceeds Integer.MAX_VALUE.
        Pointer cudnnY = gPUContext.allocate(str, (long) i * i4 * i2 * sizeOfDataType, false);
        try {
            // When the whole sequence is requested, hy is only a scratch buffer.
            Pointer hy = !z ? getDenseOutputPointer(executionContext, gPUContext, str, str2, i, i2) : gPUContext.allocate(str, (long) i * i2 * sizeOfDataType, false);
            try {
                Pointer cy = hasCarry ? getDenseOutputPointer(executionContext, gPUContext, str, str3, i, i2) : new Pointer();
                try (LibMatrixCuDNNRnnAlgorithm algo = new LibMatrixCuDNNRnnAlgorithm(executionContext, gPUContext, str, str4, i, i4, i2, i3, true, pointer4)) {
                    JCudnn.cudnnRNNForwardTraining(gPUContext.getCudnnHandle(), algo.rnnDesc, i4,
                        algo.xDesc, pointer,
                        algo.hxDesc, pointer2,
                        algo.cxDesc, pointer3,
                        algo.wDesc, pointer4,
                        algo.yDesc, cudnnY,
                        algo.hyDesc, hy,
                        algo.cyDesc, cy,
                        algo.workSpace, algo.sizeInBytes,
                        algo.reserveSpace, algo.reserveSpaceSizeInBytes);
                }
            } finally {
                if (z) {
                    // Scratch hy buffer is freed before the sequence output is prepared.
                    gPUContext.cudaFreeHelper(str, hy, DMLScript.EAGER_CUDA_FREE);
                }
            }
            if (z) {
                LibMatrixCUDA.getCudaKernels(gPUContext).launchKernel("prepare_lstm_output", ExecutionConfig.getConfigForSimpleVectorOperations(i * i4 * i2), getDenseOutputPointer(executionContext, gPUContext, str, str2, i, i4 * i2), cudnnY, Integer.valueOf(i), Integer.valueOf(i4), Integer.valueOf(i2), Integer.valueOf(i * i4 * i2));
            }
        } finally {
            gPUContext.cudaFreeHelper(str, cudnnY, DMLScript.EAGER_CUDA_FREE);
        }
    }

    /**
     * Computes the LSTM backward pass with cuDNN and fills the gradient output variables.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>the incoming gradient {@code str2} is rearranged by the
     *       {@code prepare_lstm_backward_gradients} kernel into a temporary buffer;</li>
     *   <li>{@code cudnnRNNForwardTraining} is re-run; its {@code y} output is consumed by the
     *       two backward calls below;</li>
     *   <li>{@code cudnnRNNBackwardData} writes into the outputs {@code str7}, {@code str8} and a
     *       temporary buffer that {@code prepare_lstm_dinput} converts into output {@code str4};</li>
     *   <li>{@code cudnnRNNBackwardWeights} writes into a temporary buffer that
     *       {@code prepare_lstm_dweight} splits into outputs {@code str5} and {@code str6}.</li>
     * </ol>
     *
     * <p>Fixes relative to the previous version: buffer sizes in bytes are computed in
     * {@code long} arithmetic to avoid {@code int} overflow, and the algorithm object is managed
     * by try-with-resources so it is closed exactly once (the old flow could call
     * {@code close()} twice if the first {@code close()} threw).
     *
     * @param executionContext execution context holding the input/output variables
     * @param gPUContext       GPU context to run on
     * @param str              instruction name
     * @param pointer          passed to cuDNN as the {@code x} data pointer
     * @param pointer2         passed to cuDNN as the {@code hx} data pointer
     * @param pointer3         passed to cuDNN as the {@code cx} data pointer
     * @param pointer4         passed to cuDNN as the {@code w} data pointer
     * @param str2             name of the incoming output-gradient variable
     * @param str3             name of the input variable passed as {@code dcy}
     * @param str4             name of the output variable filled by {@code prepare_lstm_dinput}
     * @param str5             name of the first output of {@code prepare_lstm_dweight} ({@code (i3+i2) x 4*i2})
     * @param str6             name of the second output of {@code prepare_lstm_dweight} ({@code 1 x 4*i2})
     * @param str7             name of the output variable passed as {@code dhx}
     * @param str8             name of the output variable passed as {@code dcx}
     * @param z                whether {@code str2} has {@code i4*i2} columns (true) or {@code i2} columns
     * @param i                number of rows of the per-sequence matrices
     * @param i2               width of the hidden/cell state matrices
     * @param i3               per-time-step width of the {@code str4} output
     * @param i4               sequence length handed to cuDNN
     * @throws DMLRuntimeException if allocation or a GPU call fails
     */
    public static void lstmBackward(ExecutionContext executionContext, GPUContext gPUContext, String str, Pointer pointer, Pointer pointer2, Pointer pointer3, Pointer pointer4, String str2, String str3, String str4, String str5, String str6, String str7, String str8, boolean z, int i, int i2, int i3, int i4) throws DMLRuntimeException {
        // Widen to long before multiplying: the byte counts easily exceed Integer.MAX_VALUE.
        Pointer cudnnDy = gPUContext.allocate(str, (long) i * i4 * i2 * sizeOfDataType, false);
        int numGradientElems = z ? i * i4 * i2 : i * i2;
        LibMatrixCUDA.getCudaKernels(gPUContext).launchKernel("prepare_lstm_backward_gradients",
            ExecutionConfig.getConfigForSimpleVectorOperations(numGradientElems),
            getDenseInputPointer(executionContext, gPUContext, str, str2, i, z ? i4 * i2 : i2),
            cudnnDy,
            Integer.valueOf(i), Integer.valueOf(i4), Integer.valueOf(i2),
            Integer.valueOf(numGradientElems), Integer.valueOf(z ? 1 : 0));
        executionContext.releaseMatrixInputForGPUInstruction(str2);

        Pointer cudnnY = gPUContext.allocate(str, (long) i * i4 * i2 * sizeOfDataType, false);
        try (LibMatrixCuDNNRnnAlgorithm algo = new LibMatrixCuDNNRnnAlgorithm(executionContext, gPUContext, str, "lstm", i, i4, i2, i3, true, pointer4)) {
            // Forward pass is repeated here; hy/cy results are discarded (empty pointers).
            JCudnn.cudnnRNNForwardTraining(gPUContext.getCudnnHandle(), algo.rnnDesc, i4,
                algo.xDesc, pointer,
                algo.hxDesc, pointer2,
                algo.cxDesc, pointer3,
                algo.wDesc, pointer4,
                algo.yDesc, cudnnY,
                algo.hyDesc, new Pointer(),
                algo.cyDesc, new Pointer(),
                algo.workSpace, algo.sizeInBytes,
                algo.reserveSpace, algo.reserveSpaceSizeInBytes);

            Pointer cudnnDx = gPUContext.allocate(str, (long) i * i4 * i3 * LibMatrixCUDA.sizeOfDataType, false);
            JCudnn.cudnnRNNBackwardData(gPUContext.getCudnnHandle(), algo.rnnDesc, i4,
                algo.yDesc, cudnnY,
                algo.dyDesc, cudnnDy,
                algo.dhyDesc, new Pointer(),
                algo.dcyDesc, getDenseInputPointer(executionContext, gPUContext, str, str3, i, i2),
                algo.wDesc, pointer4,
                algo.hxDesc, pointer2,
                algo.cxDesc, pointer3,
                algo.dxDesc, cudnnDx,
                algo.dhxDesc, getDenseOutputPointer(executionContext, gPUContext, str, str7, i, i2),
                algo.dcxDesc, getDenseOutputPointer(executionContext, gPUContext, str, str8, i, i2),
                algo.workSpace, algo.sizeInBytes,
                algo.reserveSpace, algo.reserveSpaceSizeInBytes);
            gPUContext.cudaFreeHelper(str, cudnnDy, DMLScript.EAGER_CUDA_FREE);
            executionContext.releaseMatrixInputForGPUInstruction(str3);
            executionContext.releaseMatrixOutputForGPUInstruction(str7);
            executionContext.releaseMatrixOutputForGPUInstruction(str8);

            LibMatrixCUDA.getCudaKernels(gPUContext).launchKernel("prepare_lstm_dinput",
                ExecutionConfig.getConfigForSimpleVectorOperations(i * i4 * i3),
                getDenseOutputPointer(executionContext, gPUContext, str, str4, i, i4 * i3),
                cudnnDx,
                Integer.valueOf(i), Integer.valueOf(i3), Integer.valueOf(i4 * i3), Integer.valueOf(i * i4 * i3));
            executionContext.releaseMatrixOutputForGPUInstruction(str4);
            gPUContext.cudaFreeHelper(str, cudnnDx, DMLScript.EAGER_CUDA_FREE);

            Pointer cudnnDw = gPUContext.allocate(str, (long) (i3 + i2 + 2) * 4 * i2 * LibMatrixCUDA.sizeOfDataType, false);
            JCudnn.cudnnRNNBackwardWeights(gPUContext.getCudnnHandle(), algo.rnnDesc, i4,
                algo.xDesc, pointer,
                algo.hxDesc, pointer2,
                algo.yDesc, cudnnY,
                algo.workSpace, algo.sizeInBytes,
                algo.dwDesc, cudnnDw,
                algo.reserveSpace, algo.reserveSpaceSizeInBytes);
            LibMatrixCUDA.getCudaKernels(gPUContext).launchKernel("prepare_lstm_dweight",
                ExecutionConfig.getConfigForSimpleVectorOperations((i3 + i2 + 2) * 4 * i2),
                getDenseOutputPointer(executionContext, gPUContext, str, str5, i3 + i2, 4 * i2),
                getDenseOutputPointer(executionContext, gPUContext, str, str6, 1L, 4 * i2),
                cudnnDw,
                Integer.valueOf(i3), Integer.valueOf(i2));
            gPUContext.cudaFreeHelper(str, cudnnDw, DMLScript.EAGER_CUDA_FREE);
            executionContext.releaseMatrixOutputForGPUInstruction(str5);
            executionContext.releaseMatrixOutputForGPUInstruction(str6);
            gPUContext.cudaFreeHelper(str, cudnnY, DMLScript.EAGER_CUDA_FREE);
        }
    }

    /**
     * Runs {@code cudnnBatchNormalizationForwardTraining} on {@code matrixObject} and writes the
     * normalized result to {@code matrixObject6}.
     *
     * <p>Before the cuDNN call, the contents of {@code matrixObject4} and {@code matrixObject5}
     * are copied device-to-device into {@code matrixObject7} and {@code matrixObject8}, which are
     * then handed to cuDNN as the running mean/variance buffers, so the inputs are not modified.
     *
     * <p>Both tensor descriptors created here are now destroyed in a {@code finally} block;
     * previously they were never released, leaking a native handle pair on every call.
     *
     * @param gPUContext     GPU context to run on
     * @param str            instruction name
     * @param matrixObject   input matrix; its row count is used as the batch size
     * @param matrixObject2  scale column vector; its row count is used as the channel count
     * @param matrixObject3  bias column vector
     * @param matrixObject4  running mean input (column vector)
     * @param matrixObject5  running variance input (column vector)
     * @param matrixObject6  normalized output
     * @param matrixObject7  updated running mean output
     * @param matrixObject8  updated running variance output
     * @param d              epsilon passed to cuDNN
     * @param d2             exponential average factor passed to cuDNN
     * @param matrixObject9  output for the saved mean
     * @param matrixObject10 output for the saved inverse variance
     * @throws DMLRuntimeException if validation fails or cuDNN returns a non-zero status
     */
    public static void batchNormalizationForwardTraining(GPUContext gPUContext, String str, MatrixObject matrixObject, MatrixObject matrixObject2, MatrixObject matrixObject3, MatrixObject matrixObject4, MatrixObject matrixObject5, MatrixObject matrixObject6, MatrixObject matrixObject7, MatrixObject matrixObject8, double d, double d2, MatrixObject matrixObject9, MatrixObject matrixObject10) throws DMLRuntimeException {
        if (LOG.isTraceEnabled()) {
            LOG.trace("GPU : batchNormalizationForwardTraining, GPUContext=" + gPUContext);
        }
        int batchSize = toInt(matrixObject.getNumRows());
        int numChannels = toInt(matrixObject2.getNumRows());
        long numColumns = matrixObject.getNumColumns();
        validateBatchNormalizationDimensions(matrixObject2, matrixObject3, matrixObject4, matrixObject5, numChannels);

        cudnnTensorDescriptor nchwDesc = null;
        cudnnTensorDescriptor scaleDesc = null;
        try {
            nchwDesc = allocateNCHWDescriptors(gPUContext, batchSize, numChannels, numColumns, new MatrixObject[]{matrixObject}, new MatrixObject[]{matrixObject6});
            scaleDesc = allocateTensorDescriptor(1, numChannels, 1, 1);

            Pointer imagePtr = getDensePointerForCuDNN(gPUContext, matrixObject, str);
            Pointer retPtr = getDensePointerForCuDNN(gPUContext, matrixObject6, str);
            Pointer biasPtr = getDensePointerForCuDNN(gPUContext, matrixObject3, str);
            Pointer scalePtr = getDensePointerForCuDNN(gPUContext, matrixObject2, str);
            Pointer runningMeanPtr = getDensePointerForCuDNN(gPUContext, matrixObject4, str);
            Pointer runningVarPtr = getDensePointerForCuDNN(gPUContext, matrixObject5, str);
            Pointer retRunningMeanPtr = getDensePointerForCuDNN(gPUContext, matrixObject7, str);
            Pointer retRunningVarPtr = getDensePointerForCuDNN(gPUContext, matrixObject8, str);

            // Kind 3 is cudaMemcpyDeviceToDevice: seed the output buffers with the current
            // running statistics so cuDNN updates the copies rather than the inputs.
            JCuda.cudaMemcpy(retRunningMeanPtr, runningMeanPtr, numChannels * sizeOfDataType, 3);
            JCuda.cudaMemcpy(retRunningVarPtr, runningVarPtr, numChannels * sizeOfDataType, 3);

            // Mode 1 is cuDNN's spatial batch-normalization mode.
            checkStatus(JCudnn.cudnnBatchNormalizationForwardTraining(getCudnnHandle(gPUContext), 1, one(), zero(),
                nchwDesc, imagePtr, nchwDesc, retPtr,
                scaleDesc, scalePtr, biasPtr,
                d2, retRunningMeanPtr, retRunningVarPtr, d,
                getDensePointerForCuDNN(gPUContext, matrixObject9, str),
                getDensePointerForCuDNN(gPUContext, matrixObject10, str)));
        } finally {
            if (scaleDesc != null) {
                JCudnn.cudnnDestroyTensorDescriptor(scaleDesc);
            }
            if (nchwDesc != null) {
                JCudnn.cudnnDestroyTensorDescriptor(nchwDesc);
            }
        }
    }

    /**
     * Runs {@code cudnnBatchNormalizationForwardInference} on {@code matrixObject} and writes the
     * normalized result to {@code matrixObject6}.
     *
     * <p>Both tensor descriptors created here are now destroyed in a {@code finally} block;
     * previously they were never released, leaking a native handle pair on every call.
     *
     * @param gPUContext    GPU context to run on
     * @param str           instruction name
     * @param matrixObject  input matrix; its row count is used as the batch size
     * @param matrixObject2 scale column vector; its row count is used as the channel count
     * @param matrixObject3 bias column vector
     * @param matrixObject4 running mean (column vector)
     * @param matrixObject5 running variance (column vector)
     * @param matrixObject6 normalized output
     * @param d             epsilon passed to cuDNN
     * @throws DMLRuntimeException if validation fails or cuDNN returns a non-zero status
     */
    public static void batchNormalizationForwardInference(GPUContext gPUContext, String str, MatrixObject matrixObject, MatrixObject matrixObject2, MatrixObject matrixObject3, MatrixObject matrixObject4, MatrixObject matrixObject5, MatrixObject matrixObject6, double d) throws DMLRuntimeException {
        if (LOG.isTraceEnabled()) {
            LOG.trace("GPU : batchNormalizationForwardInference, GPUContext=" + gPUContext);
        }
        int batchSize = toInt(matrixObject.getNumRows());
        int numChannels = toInt(matrixObject2.getNumRows());
        long numColumns = matrixObject.getNumColumns();
        validateBatchNormalizationDimensions(matrixObject2, matrixObject3, matrixObject4, matrixObject5, numChannels);

        cudnnTensorDescriptor nchwDesc = null;
        cudnnTensorDescriptor scaleDesc = null;
        try {
            nchwDesc = allocateNCHWDescriptors(gPUContext, batchSize, numChannels, numColumns, new MatrixObject[]{matrixObject}, new MatrixObject[]{matrixObject6});
            scaleDesc = allocateTensorDescriptor(1, numChannels, 1, 1);
            // Mode 1 is cuDNN's spatial batch-normalization mode.
            checkStatus(JCudnn.cudnnBatchNormalizationForwardInference(getCudnnHandle(gPUContext), 1, one(), zero(),
                nchwDesc, getDensePointerForCuDNN(gPUContext, matrixObject, str),
                nchwDesc, getDensePointerForCuDNN(gPUContext, matrixObject6, str),
                scaleDesc,
                getDensePointerForCuDNN(gPUContext, matrixObject2, str),
                getDensePointerForCuDNN(gPUContext, matrixObject3, str),
                getDensePointerForCuDNN(gPUContext, matrixObject4, str),
                getDensePointerForCuDNN(gPUContext, matrixObject5, str),
                d));
        } finally {
            if (scaleDesc != null) {
                JCudnn.cudnnDestroyTensorDescriptor(scaleDesc);
            }
            if (nchwDesc != null) {
                JCudnn.cudnnDestroyTensorDescriptor(nchwDesc);
            }
        }
    }

    /**
     * Runs {@code cudnnBatchNormalizationBackward} and writes the three gradient outputs.
     *
     * <p>Both tensor descriptors created here are now destroyed in a {@code finally} block;
     * previously they were never released, leaking a native handle pair on every call.
     *
     * @param gPUContext    GPU context to run on
     * @param str           instruction name
     * @param matrixObject  forward-pass input; its row count is used as the batch size
     * @param matrixObject2 incoming gradient, same layout as {@code matrixObject}
     * @param matrixObject3 scale column vector; its row count is used as the channel count
     * @param matrixObject4 output gradient with the layout of {@code matrixObject}
     * @param matrixObject5 output passed to cuDNN as the scale gradient
     * @param matrixObject6 output passed to cuDNN as the bias gradient
     * @param d             epsilon passed to cuDNN
     * @param matrixObject7 saved mean from the forward pass
     * @param matrixObject8 saved inverse variance from the forward pass
     * @throws DMLRuntimeException if cuDNN returns a non-zero status
     */
    public static void batchNormalizationBackward(GPUContext gPUContext, String str, MatrixObject matrixObject, MatrixObject matrixObject2, MatrixObject matrixObject3, MatrixObject matrixObject4, MatrixObject matrixObject5, MatrixObject matrixObject6, double d, MatrixObject matrixObject7, MatrixObject matrixObject8) throws DMLRuntimeException {
        if (LOG.isTraceEnabled()) {
            LOG.trace("GPU : batchNormalizationBackward, GPUContext=" + gPUContext);
        }
        int batchSize = toInt(matrixObject.getNumRows());
        int numChannels = toInt(matrixObject3.getNumRows());

        cudnnTensorDescriptor nchwDesc = null;
        cudnnTensorDescriptor scaleDesc = null;
        try {
            nchwDesc = allocateNCHWDescriptors(gPUContext, batchSize, numChannels, matrixObject.getNumColumns(), new MatrixObject[]{matrixObject, matrixObject2}, new MatrixObject[]{matrixObject4});
            scaleDesc = allocateTensorDescriptor(1, numChannels, 1, 1);
            // Mode 1 is cuDNN's spatial batch-normalization mode.
            checkStatus(JCudnn.cudnnBatchNormalizationBackward(getCudnnHandle(gPUContext), 1, one(), zero(), one(), zero(),
                nchwDesc, getDensePointerForCuDNN(gPUContext, matrixObject, str),
                nchwDesc, getDensePointerForCuDNN(gPUContext, matrixObject2, str),
                nchwDesc, getDensePointerForCuDNN(gPUContext, matrixObject4, str),
                scaleDesc,
                getDensePointerForCuDNN(gPUContext, matrixObject3, str),
                getDensePointerForCuDNN(gPUContext, matrixObject5, str),
                getDensePointerForCuDNN(gPUContext, matrixObject6, str),
                d,
                getDensePointerForCuDNN(gPUContext, matrixObject7, str),
                getDensePointerForCuDNN(gPUContext, matrixObject8, str)));
        } finally {
            if (scaleDesc != null) {
                JCudnn.cudnnDestroyTensorDescriptor(scaleDesc);
            }
            if (nchwDesc != null) {
                JCudnn.cudnnDestroyTensorDescriptor(nchwDesc);
            }
        }
    }

    /**
     * Verifies that the four batch-normalization parameter matrices are column vectors with
     * exactly {@code i} rows.
     *
     * <p>The previous version built each error message from the (not yet assigned) exception
     * variable instead of the offending column count; the messages now report the actual
     * {@code [rows, columns]} that were found.
     *
     * @param matrixObject  scale
     * @param matrixObject2 bias
     * @param matrixObject3 running mean
     * @param matrixObject4 running variance
     * @param i             expected number of rows
     * @throws DMLRuntimeException for the first matrix (in the order above) that is not an
     *                             {@code i x 1} column vector
     */
    private static void validateBatchNormalizationDimensions(MatrixObject matrixObject, MatrixObject matrixObject2, MatrixObject matrixObject3, MatrixObject matrixObject4, int i) throws DMLRuntimeException {
        requireBatchNormColumnVector(matrixObject, "scale", i);
        requireBatchNormColumnVector(matrixObject2, "bias", i);
        requireBatchNormColumnVector(matrixObject3, "running mean", i);
        requireBatchNormColumnVector(matrixObject4, "running variance", i);
    }

    /** Throws unless {@code vector} has exactly {@code expectedRows} rows and one column. */
    private static void requireBatchNormColumnVector(MatrixObject vector, String label, int expectedRows) throws DMLRuntimeException {
        long rows = vector.getNumRows();
        long cols = vector.getNumColumns();
        if (rows != expectedRows || cols != 1) {
            throw new DMLRuntimeException("Incorrect dimensions for " + label + ". Expected a column vector of size " + expectedRows + ", but found [" + rows + ", " + cols + "]");
        }
    }

    /**
     * Creates a 4-d tensor descriptor of shape {@code [i, i2, h, w]} for a matrix whose rows hold
     * {@code j} values each.
     *
     * <p>The per-channel size is {@code j / i2}. If that size is a perfect square, {@code h} and
     * {@code w} are both set to its square root; otherwise {@code h} is the whole size and
     * {@code w} is 1. The {@code gPUContext} and the two matrix arrays are not read here.
     *
     * @param gPUContext       unused
     * @param i                first (batch) dimension of the descriptor
     * @param i2               second (channel) dimension of the descriptor
     * @param j                number of columns of the underlying matrix
     * @param matrixObjectArr  unused
     * @param matrixObjectArr2 unused
     * @return the newly created descriptor; the caller is responsible for destroying it
     * @throws DMLRuntimeException if {@code j} exceeds {@code MAX_NUMCELLS_CP_DENSE * i2}
     */
    private static cudnnTensorDescriptor allocateNCHWDescriptors(GPUContext gPUContext, int i, int i2, long j, MatrixObject[] matrixObjectArr, MatrixObject[] matrixObjectArr2) throws DMLRuntimeException {
        final boolean imageTooLarge = j > OptimizerUtils.MAX_NUMCELLS_CP_DENSE * i2;
        if (imageTooLarge) {
            throw new DMLRuntimeException("image size (height*width) should be less than 2147483647");
        }

        final int cellsPerChannel = (int) (j / i2);
        final double root = Math.sqrt(cellsPerChannel);
        final boolean perfectSquare = (root == ((int) root));
        final int height = perfectSquare ? (int) root : cellsPerChannel;
        final int width = perfectSquare ? height : 1;

        final cudnnTensorDescriptor descriptor = new cudnnTensorDescriptor();
        JCudnn.cudnnCreateTensorDescriptor(descriptor);
        // Format 0 is cuDNN's NCHW tensor layout.
        JCudnn.cudnnSetTensor4dDescriptor(descriptor, 0, CUDNN_DATA_TYPE, i, i2, height, width);
        return descriptor;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Returns the dense device pointer of {@code matrixObject} after checking that its cell count
     * does not exceed {@code maxNumElementsOfCuDNNTensor}.
     *
     * <p>The previous version built the error message from the (not yet assigned) exception
     * variable and dropped the column count; the message now reports {@code [rows X columns]}.
     *
     * @param gPUContext   GPU context used for the pointer lookup
     * @param matrixObject matrix whose dense pointer is requested
     * @param str          instruction name
     * @return dense device pointer of the matrix
     * @throws DMLRuntimeException if the matrix has more cells than cuDNN tensors may hold
     */
    public static Pointer getDensePointerForCuDNN(GPUContext gPUContext, MatrixObject matrixObject, String str) {
        long rows = matrixObject.getNumRows();
        long cols = matrixObject.getNumColumns();
        long numElems = rows * cols;
        if (numElems > maxNumElementsOfCuDNNTensor) {
            throw new DMLRuntimeException("CuDNN restriction: the size of input tensor cannot have greater than 2 giga-elements, but has " + numElems + " (i.e. [" + rows + " X " + cols + "]). Hint: try reducing the mini-batch size.");
        }
        return getDensePointer(gPUContext, matrixObject, str);
    }

    /**
     * Returns the dense device pointer of {@code matrixObject} after three checks: the matrix has
     * exactly {@code i} rows and {@code i2} columns, its cell count does not exceed
     * {@code maxNumElementsOfCuDNNTensor}, and the size recorded by the memory manager for the
     * pointer equals {@code cells * sizeOfDataType}.
     *
     * <p>The previous version built all three error messages from the (not yet assigned)
     * exception variable; they now report the actual dimensions and the actual allocated size.
     *
     * @param gPUContext   GPU context used for the pointer lookup and the size check
     * @param matrixObject matrix whose dense pointer is requested
     * @param str          instruction name
     * @param i            expected number of rows
     * @param i2           expected number of columns
     * @return dense device pointer of the matrix
     * @throws DMLRuntimeException if any of the three checks fails
     */
    public static Pointer getDensePointerForCuDNN(GPUContext gPUContext, MatrixObject matrixObject, String str, int i, int i2) throws DMLRuntimeException {
        long rows = matrixObject.getNumRows();
        long cols = matrixObject.getNumColumns();
        long numElems = rows * cols;
        if (rows != i || cols != i2) {
            throw new DMLRuntimeException("Expected input of size:[" + i + ", " + i2 + "], but found [" + rows + ", " + cols + "].");
        }
        if (numElems > maxNumElementsOfCuDNNTensor) {
            throw new DMLRuntimeException("CuDNN restriction: the size of input tensor cannot have greater than 2 giga-elements, but has " + numElems + " (i.e. [" + rows + " X " + cols + "]). Hint: try reducing the mini-batch size.");
        }
        Pointer densePointer = getDensePointer(gPUContext, matrixObject, str);
        long allocatedBytes = gPUContext.getMemoryManager().getSizeAllocatedGPUPointer(densePointer);
        long expectedBytes = numElems * sizeOfDataType;
        if (allocatedBytes != expectedBytes) {
            throw new DMLRuntimeException("Incorrect pointer: expected size:" + expectedBytes + ", but found " + allocatedBytes);
        }
        return densePointer;
    }

    /**
     * Converts a non-zero cuDNN status code into an exception; a status of 0 is accepted
     * silently.
     *
     * @param i status code returned by a cuDNN call
     * @throws DMLRuntimeException if {@code i} is not 0; the message contains the textual form
     *                             of the status
     */
    protected static void checkStatus(int i) {
        final boolean succeeded = (i == 0);
        if (succeeded) {
            return;
        }
        final String statusText = cudnnStatus.stringFor(i);
        throw new DMLRuntimeException("Error status returned by CuDNN:" + statusText);
    }
}
