package org.apache.sysds.runtime.matrix.data;

import jcuda.Pointer;
import jcuda.jcublas.cublasHandle;
import jcuda.jcusparse.cusparseHandle;
import jcuda.runtime.JCuda;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.gpu.context.CSRPointer;
import org.apache.sysds.runtime.instructions.gpu.context.GPUContext;

/**
 * GPU matrix multiplication for SystemDS.
 *
 * <p>Dense x dense products go through cuBLAS (dot / gemv / gemm). Any product with a sparse (CSR)
 * operand goes through cuSPARSE (csrgemm / csrmv / csrmm2); the sparse x dense case additionally
 * uses cuBLAS geam to transpose an intermediate result.
 *
 * <p>NOTE(review): {@link CuMatMultParameters#rowToColumnMajor()} suggests the device buffers are
 * row-major while the libraries expect column-major data — confirm against GPUObject.
 */
public class LibMatrixCuMatMult extends LibMatrixCUDA {
    private static final Log LOG = LogFactory.getLog(LibMatrixCuMatMult.class.getName());

    // Numeric values of the JCuda enums passed to the cuBLAS/cuSPARSE/CUDA runtime wrappers.
    private static final int CUBLAS_OP_N = 0; // cublasOperation.CUBLAS_OP_N
    private static final int CUBLAS_OP_T = 1; // cublasOperation.CUBLAS_OP_T
    private static final int CUSPARSE_OPERATION_NON_TRANSPOSE = 0; // cusparseOperation
    private static final int CUSPARSE_OPERATION_TRANSPOSE = 1; // cusparseOperation
    private static final int CUDA_MEMCPY_HOST_TO_DEVICE = 1; // cudaMemcpyKind.cudaMemcpyHostToDevice

    /**
     * Dimension bookkeeping for one product {@code op(left) %*% op(right)}, where {@code op} is the
     * identity or a transpose depending on the corresponding flag.
     *
     * <p>{@link #m}, {@link #n} and {@link #k} are the BLAS problem sizes: rows of {@code op(left)},
     * columns of {@code op(right)}, and the columns of {@code op(left)} (the inner dimension).
     * {@link #lda}, {@link #ldb} and {@link #ldc} are the matching leading dimensions. All six are
     * recomputed whenever the operand description changes.
     */
    public static class CuMatMultParameters {
        public int m;
        public int n;
        public int k;
        public int lda;
        public int ldb;
        public int ldc;
        public long leftNumRows;
        public long leftNumCols;
        public long rightNumRows;
        public long rightNumCols;
        private boolean isLeftTransposed;
        private boolean isRightTransposed;

        /**
         * Describes the product {@code op(left) %*% op(right)}.
         *
         * @param leftNumRows rows of the stored left operand
         * @param leftNumCols columns of the stored left operand
         * @param rightNumRows rows of the stored right operand
         * @param rightNumCols columns of the stored right operand
         * @param isLeftTransposed whether the left operand is used transposed
         * @param isRightTransposed whether the right operand is used transposed
         * @throws DMLRuntimeException if a derived dimension is -1 ("Incorrect dimensions")
         */
        public CuMatMultParameters(long leftNumRows, long leftNumCols, long rightNumRows, long rightNumCols,
                boolean isLeftTransposed, boolean isRightTransposed) {
            this.leftNumRows = leftNumRows;
            this.leftNumCols = leftNumCols;
            this.rightNumRows = rightNumRows;
            this.rightNumCols = rightNumCols;
            this.isLeftTransposed = isLeftTransposed;
            this.isRightTransposed = isRightTransposed;
            setDimensions();
        }

        /**
         * Exchanges the roles of the two operands in place and recomputes the BLAS dimensions.
         *
         * <p>Three things are swapped:
         * <ul>
         *   <li>the transpose flags;</li>
         *   <li>{@code leftNumRows} with {@code rightNumCols};</li>
         *   <li>{@code leftNumCols} with {@code rightNumRows}.</li>
         * </ul>
         * This reinterprets {@code C = A B} as {@code t(C) = t(B) t(A)}. That is the usual way to
         * run a product on row-major buffers through a column-major BLAS.
         */
        public void rowToColumnMajor() {
            boolean oldLeftTransposed = isLeftTransposed;
            isLeftTransposed = isRightTransposed;
            isRightTransposed = oldLeftTransposed;

            long oldLeftNumRows = leftNumRows;
            leftNumRows = rightNumCols;
            rightNumCols = oldLeftNumRows;

            long oldLeftNumCols = leftNumCols;
            leftNumCols = rightNumRows;
            rightNumRows = oldLeftNumCols;

            setDimensions();
        }

        /**
         * Checks that the inner dimension of {@code op(left)} equals the row count of {@code op(right)}.
         *
         * @throws DMLRuntimeException on a mismatch; the message lists both inner dimensions, the
         *     four stored dimensions and both transpose flags
         */
        private void validate() {
            int rightInnerDim = LibMatrixCUDA.toInt(isRightTransposed ? rightNumCols : rightNumRows);
            if (k != rightInnerDim) {
                throw new DMLRuntimeException("Dimension mismatch: " + k + " != " + rightInnerDim
                        + " [" + leftNumRows + "," + leftNumCols + "," + rightNumRows + "," + rightNumCols
                        + "], " + isLeftTransposed + " " + isRightTransposed);
            }
        }

        /**
         * Recomputes m, n, k and the leading dimensions from the stored shapes and transpose flags.
         *
         * @throws DMLRuntimeException if any of m, n or k evaluates to -1
         */
        private void setDimensions() {
            m = LibMatrixCUDA.toInt(isLeftTransposed ? leftNumCols : leftNumRows);
            n = LibMatrixCUDA.toInt(isRightTransposed ? rightNumRows : rightNumCols);
            k = LibMatrixCUDA.toInt(isLeftTransposed ? leftNumRows : leftNumCols);
            // Leading dimension = row count of the operand as stored (before op() is applied).
            lda = isLeftTransposed ? k : m;
            ldb = isRightTransposed ? n : k;
            ldc = m;
            if (m == -1 || n == -1 || k == -1) {
                throw new DMLRuntimeException("Incorrect dimensions");
            }
        }
    }

    /**
     * Computes {@code op(left) %*% op(right)} on the GPU and stores it in the variable
     * {@code outputName}. Here {@code op} transposes its argument when the matching flag is set.
     *
     * <p>The kernel is chosen from the operands' current GPU formats:
     * <ul>
     *   <li>sparse x sparse runs cuSPARSE csrgemm and installs a sparse (CSR) output;</li>
     *   <li>every other combination produces a dense output.</li>
     * </ul>
     *
     * @param ec execution context that owns the output variable
     * @param gCtx GPU context holding the operands
     * @param instName instruction name, passed through to the allocation/lookup helpers
     * @param left left operand
     * @param right right operand
     * @param outputName name of the output variable in {@code ec}
     * @param isLeftTransposed whether to multiply with {@code t(left)}
     * @param isRightTransposed whether to multiply with {@code t(right)}
     * @return the output matrix object looked up from {@code ec}
     * @throws DMLRuntimeException e.g. when the operand dimensions are invalid or incompatible
     */
    public static MatrixObject matmult(ExecutionContext ec, GPUContext gCtx, String instName, MatrixObject left,
            MatrixObject right, String outputName, boolean isLeftTransposed, boolean isRightTransposed) {
        boolean isLeftSparse = isInSparseFormat(gCtx, left);
        boolean isRightSparse = isInSparseFormat(gCtx, right);
        MatrixObject output = ec.getMatrixObject(outputName);
        long outRLen = isLeftTransposed ? left.getNumColumns() : left.getNumRows();
        long outCLen = isRightTransposed ? right.getNumRows() : right.getNumColumns();
        // Built up front for every path: the constructor rejects invalid dimensions early.
        CuMatMultParameters params = new CuMatMultParameters(left.getNumRows(), left.getNumColumns(),
                right.getNumRows(), right.getNumColumns(), isLeftTransposed, isRightTransposed);

        if (isLeftSparse && isRightSparse) {
            // sparse %*% sparse -> sparse output
            params.validate();
            int transA = cusparseOp(isLeftTransposed);
            int transB = cusparseOp(isRightTransposed);
            ec.allocateGPUMatrixObject(outputName, outRLen, outCLen);
            CSRPointer leftPtr = left.getGPUObject(gCtx).getJcudaSparseMatrixPtr();
            CSRPointer rightPtr = right.getGPUObject(gCtx).getJcudaSparseMatrixPtr();
            cusparseHandle handle = getCusparseHandle(gCtx);
            CSRPointer outPtr = CSRPointer.allocateForMatrixMultiply(gCtx, handle, leftPtr, transA, rightPtr,
                    transB, params.m, params.n, params.k);
            cudaSupportFunctions.cusparsecsrgemm(handle, transA, transB, params.m, params.n, params.k,
                    leftPtr.descr, toInt(leftPtr.nnz), leftPtr.val, leftPtr.rowPtr, leftPtr.colInd,
                    rightPtr.descr, toInt(rightPtr.nnz), rightPtr.val, rightPtr.rowPtr, rightPtr.colInd,
                    outPtr.descr, outPtr.val, outPtr.rowPtr, outPtr.colInd);
            output.getGPUObject(gCtx).setSparseMatrixCudaPointer(outPtr);
        } else if (isLeftSparse) {
            // sparse %*% dense -> dense output
            getDenseMatrixOutputForGPUInstruction(ec, instName, outputName, outRLen, outCLen);
            Pointer outPtr = getDensePointer(gCtx, output, instName);
            CSRPointer leftPtr = left.getGPUObject(gCtx).getJcudaSparseMatrixPtr();
            Pointer rightPtr = getDensePointer(gCtx, right, instName);
            sparseDenseMatMult(gCtx, instName, outPtr, leftPtr, rightPtr, left.getNumRows(), left.getNumColumns(),
                    right.getNumRows(), right.getNumColumns(), outRLen, outCLen, isLeftTransposed,
                    isRightTransposed);
        } else if (isRightSparse) {
            // dense %*% sparse -> dense output
            getDenseMatrixOutputForGPUInstruction(ec, instName, outputName, outRLen, outCLen);
            cusparseHandle handle = getCusparseHandle(gCtx);
            Pointer outPtr = getDensePointer(gCtx, output, instName);
            Pointer leftPtr = getDensePointer(gCtx, left, instName);
            CSRPointer rightPtr = right.getGPUObject(gCtx).getJcudaSparseMatrixPtr();
            denseSparseMatMult(handle, instName, outPtr, leftPtr, rightPtr, params);
        } else {
            // dense %*% dense -> dense output
            getDenseMatrixOutputForGPUInstruction(ec, instName, outputName, outRLen, outCLen);
            cublasHandle handle = getCublasHandle(gCtx);
            Pointer outPtr = getDensePointer(gCtx, output, instName);
            Pointer leftPtr = getDensePointer(gCtx, left, instName);
            Pointer rightPtr = getDensePointer(gCtx, right, instName);
            denseDenseMatMult(handle, instName, outPtr, leftPtr, rightPtr, params);
        }
        return output;
    }

    /**
     * Computes the dense product {@code C = op(A) %*% op(B)} for a sparse (CSR) {@code A} and dense {@code B}.
     *
     * <p>It evaluates {@code t(C) = t(op(B)) %*% t(op(A))} via {@link #denseSparseMatMult}.
     * <ul>
     *   <li>If the output is a vector (one row or one column), the result is written straight into {@code C}.</li>
     *   <li>Otherwise it goes to a temporary device buffer. That buffer is transposed into {@code C}
     *       with cuBLAS geam and then released.</li>
     * </ul>
     *
     * @param gCtx GPU context used for handles and temporary allocation
     * @param instName instruction name used for allocation bookkeeping
     * @param C dense device output
     * @param A sparse left operand
     * @param B dense right operand
     * @param leftNumRows rows of the stored A
     * @param leftNumColumns columns of the stored A
     * @param rightNumRows rows of the stored B
     * @param rightNumColumns columns of the stored B
     * @param outRLen rows of the output
     * @param outCLen columns of the output
     * @param isLeftTransposed whether A is used transposed
     * @param isRightTransposed whether B is used transposed
     */
    public static void sparseDenseMatMult(GPUContext gCtx, String instName, Pointer C, CSRPointer A, Pointer B,
            long leftNumRows, long leftNumColumns, long rightNumRows, long rightNumColumns, long outRLen,
            long outCLen, boolean isLeftTransposed, boolean isRightTransposed) {
        // t(C) = t(B) %*% t(A): swap operand roles and flip both transpose flags. Built before the
        // temporary allocation so a dimension error cannot leak the buffer.
        CuMatMultParameters params = new CuMatMultParameters(rightNumRows, rightNumColumns, leftNumRows,
                leftNumColumns, !isRightTransposed, !isLeftTransposed);
        // A vector output has the same memory layout as its transpose, so no transpose step is needed.
        boolean isVectorOutput = outRLen == 1 || outCLen == 1;
        Pointer result = isVectorOutput ? C : gCtx.allocate(instName, outRLen * outCLen * sizeOfDataType, false);
        denseSparseMatMult(getCusparseHandle(gCtx), instName, result, B, A, params);
        if (isVectorOutput) {
            return;
        }
        // C = t(result); geam with beta = 0 ignores its second input, hence the empty Pointer.
        cudaSupportFunctions.cublasgeam(gCtx.getCublasHandle(), CUBLAS_OP_T, CUBLAS_OP_T, toInt(outCLen),
                toInt(outRLen), one(), result, toInt(outRLen), zero(), new Pointer(), toInt(outRLen), C,
                toInt(outCLen));
        if (!DMLScript.EAGER_CUDA_FREE) {
            JCuda.cudaDeviceSynchronize();
        }
        gCtx.cudaFreeHelper(instName, result, DMLScript.EAGER_CUDA_FREE);
    }

    /**
     * Multiplies a dense operand {@code A} with a sparse (CSR) operand {@code B} into the dense output {@code C}.
     *
     * <p>If {@code op(A)} has a single row, cuSPARSE csrmv is used. Otherwise {@code param} is
     * converted with {@link CuMatMultParameters#rowToColumnMajor()}, validated, and cuSPARSE csrmm2
     * is called.
     *
     * <p>Note: the second path mutates {@code param} in place.
     *
     * @param handle cuSPARSE handle
     * @param instName instruction name (currently unused here)
     * @param C dense device output
     * @param A dense operand described by the "left" fields of {@code param}
     * @param B sparse operand described by the "right" fields of {@code param}
     * @param param product description
     */
    private static void denseSparseMatMult(cusparseHandle handle, String instName, Pointer C, Pointer A,
            CSRPointer B, CuMatMultParameters param) {
        boolean isLeftRowVector = (param.leftNumRows == 1 && !param.isLeftTransposed)
                || (param.leftNumCols == 1 && param.isLeftTransposed);
        if (isLeftRowVector) {
            LOG.debug(" GPU Sparse-Dense Matrix Vector ");
            int transB = reverseCusparseOp(cusparseOp(param.isLeftTransposed));
            cudaSupportFunctions.cusparsecsrmv(handle, transB, toInt(param.rightNumRows),
                    toInt(param.rightNumCols), toInt(B.nnz), one(), B.descr, B.val, B.rowPtr, B.colInd, A,
                    zero(), C);
            return;
        }
        // Capture the sparse operand's stored shape before rowToColumnMajor() swaps the roles.
        int sparseNumRows = toInt(param.rightNumRows);
        int sparseNumCols = toInt(param.rightNumCols);
        param.rowToColumnMajor();
        param.validate();
        int transA = reverseCusparseOp(cusparseOp(param.isLeftTransposed));
        int transB = cusparseOp(param.isRightTransposed);
        LOG.debug(" GPU Sparse-Dense Matrix Multiply (rhs transpose) ");
        cudaSupportFunctions.cusparsecsrmm2(handle, transA, transB, sparseNumRows, param.n, sparseNumCols,
                toInt(B.nnz), one(), B.descr, B.val, B.rowPtr, B.colInd, A, param.ldb, zero(), C, param.ldc);
    }

    /**
     * Computes the dense product {@code C = op(A) %*% op(B)} with cuBLAS.
     *
     * <p>{@code param} is first converted with {@link CuMatMultParameters#rowToColumnMajor()}, which
     * also swaps which buffer plays the left role. The cuBLAS routine is chosen from the converted
     * output shape {@code m x n}:
     * <ul>
     *   <li>{@code 1 x 1}: dot product;</li>
     *   <li>{@code m == 1} or {@code n == 1}: gemv;</li>
     *   <li>otherwise: gemm.</li>
     * </ul>
     *
     * <p>Note: mutates {@code param} in place.
     *
     * @param handle cuBLAS handle
     * @param instName instruction name (currently unused here)
     * @param C dense device output
     * @param A dense left operand
     * @param B dense right operand
     * @param param product description
     */
    private static void denseDenseMatMult(cublasHandle handle, String instName, Pointer C, Pointer A, Pointer B,
            CuMatMultParameters param) {
        param.rowToColumnMajor();
        param.validate();
        int transA = cublasOp(param.isLeftTransposed);
        int transB = cublasOp(param.isRightTransposed);
        // After rowToColumnMajor() the original right buffer plays the left role and vice versa.
        Pointer colMajorLeft = B;
        Pointer colMajorRight = A;

        if (param.m == 1 && param.n == 1) {
            LOG.debug(" GPU Dense-dense Vector Product");
            // cublasdot writes into this host array; the value is then copied to the device output C.
            // NOTE(review): only sizeOfDataType bytes of the double[] are copied; confirm this is
            // correct when running in single precision.
            double[] result = { 0 };
            cudaSupportFunctions.cublasdot(handle, param.k, colMajorLeft, 1, colMajorRight, 1, Pointer.to(result));
            JCuda.cudaMemcpy(C, Pointer.to(result), sizeOfDataType, CUDA_MEMCPY_HOST_TO_DEVICE);
        } else if (param.m == 1) {
            LOG.debug(" GPU Dense Vector-Matrix Multiply");
            // Row vector times matrix, expressed as gemv on the (reverse-)transposed matrix.
            int opB = reverseCublasOp(transB);
            int rows = opB == CUBLAS_OP_T ? param.k : param.n;
            int cols = opB == CUBLAS_OP_T ? param.n : param.k;
            cudaSupportFunctions.cublasgemv(handle, opB, rows, cols, one(), colMajorRight, param.ldb,
                    colMajorLeft, 1, zero(), C, 1);
        } else if (param.n == 1) {
            LOG.debug(" GPU Dense Matrix-Vector Multiply");
            int rows = transA == CUBLAS_OP_N ? param.m : param.k;
            int cols = transA == CUBLAS_OP_N ? param.k : param.m;
            cudaSupportFunctions.cublasgemv(handle, transA, rows, cols, one(), colMajorLeft, param.lda,
                    colMajorRight, 1, zero(), C, 1);
        } else {
            LOG.debug(" GPU Dense-Dense Matrix Multiply ");
            cudaSupportFunctions.cublasgemm(handle, transA, transB, param.m, param.n, param.k, one(),
                    colMajorLeft, param.lda, colMajorRight, param.ldb, zero(), C, param.ldc);
        }
    }

    /** Maps a transpose flag to the matching cuSPARSE operation code. */
    private static int cusparseOp(boolean isTransposed) {
        return isTransposed ? CUSPARSE_OPERATION_TRANSPOSE : CUSPARSE_OPERATION_NON_TRANSPOSE;
    }

    /** Maps a transpose flag to the matching cuBLAS operation code. */
    private static int cublasOp(boolean isTransposed) {
        return isTransposed ? CUBLAS_OP_T : CUBLAS_OP_N;
    }

    /** Flips a cuBLAS operation code: transpose becomes non-transpose, anything else becomes transpose. */
    private static int reverseCublasOp(int op) {
        return op == CUBLAS_OP_T ? CUBLAS_OP_N : CUBLAS_OP_T;
    }

    /** Flips a cuSPARSE operation code: transpose becomes non-transpose, anything else becomes transpose. */
    private static int reverseCusparseOp(int op) {
        return op == CUSPARSE_OPERATION_TRANSPOSE ? CUSPARSE_OPERATION_NON_TRANSPOSE : CUSPARSE_OPERATION_TRANSPOSE;
    }
}
