package org.apache.sysds.runtime.codegen;

import java.util.ArrayList;
import jcuda.Pointer;
import org.apache.sysds.hops.codegen.SpoofCompiler;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.gpu.context.GPUObject;
import org.apache.sysds.runtime.matrix.data.LibMatrixCUDA;

/* loaded from: input_file:org/apache/sysds/runtime/codegen/SpoofCUDAOperator.class */
public interface SpoofCUDAOperator {
    /**
     * Number of {@code long} slots describing one matrix in the arrays built by
     * {@link #prepareInputPointers}, {@link #prepareSideInputPointers} and
     * {@link #prepareOutputPointers}.
     *
     * <p>Slot layout written by those methods (slots not written stay 0):
     * <ol start="0">
     *   <li>number of non-zeros</li>
     *   <li>number of rows</li>
     *   <li>number of columns</li>
     *   <li>address of the sparse {@code rowPtr} array (sparse matrices only)</li>
     *   <li>address of the sparse {@code colInd} array (sparse matrices only)</li>
     *   <li>address of the values: sparse {@code val} array, or the dense buffer</li>
     * </ol>
     */
    int JNI_MAT_ENTRY_SIZE = 6;

    /**
     * Base class for the objects that dispatch an operator to native CUDA code
     * (presumably one subclass per floating-point precision, cf. {@code execute_sp}/{@code execute_dp}
     * — TODO confirm). It resolves the native context handle and offers a scalar-upload helper.
     */
    abstract class PrecisionProxy {
        /**
         * Native context handle registered in {@code SpoofCompiler.native_contexts} for the CUDA
         * generator API, read once when the proxy is constructed. If the lookup yields {@code null},
         * the {@code longValue()} call throws a {@link NullPointerException}.
         */
        protected final long ctx = SpoofCompiler.native_contexts.get(SpoofCompiler.GeneratorAPI.CUDA).longValue();

        /**
         * Runs the operator natively.
         *
         * @param executionContext execution context
         * @param op               operator being executed
         * @param id               operator identifier (semantics defined by implementations)
         * @param inputPtrs        input matrix entries — presumably from {@link #prepareInputPointers}
         * @param sidePtrs         side-input entries — presumably from {@link #prepareSideInputPointers}
         * @param outputPtrs       output entry — presumably from {@link #prepareOutputPointers}
         * @param scalars          scalar arguments
         * @param arg              additional argument (semantics defined by implementations)
         * @return status code returned by the implementation
         */
        public abstract int exec(ExecutionContext executionContext, SpoofCUDAOperator op, int id, long[] inputPtrs,
                long[] sidePtrs, long[] outputPtrs, ArrayList<ScalarObject> scalars, long arg);

        /**
         * Packs {@code scalars} into a {@code double[]} and copies it into a freshly allocated
         * buffer on GPU context 0. The buffer is not released here; the caller owns the returned pointer.
         *
         * @param executionContext execution context providing GPU context 0
         * @param op               operator whose name tags the allocation and the copy
         * @param sizeOfDataType   per-scalar size multiplied into the allocation size
         * @param scalars          scalar arguments to upload
         * @return pointer to the uploaded scalars
         */
        /* JADX INFO: Access modifiers changed from: protected */
        public Pointer transferScalars(ExecutionContext executionContext, SpoofCUDAOperator op, int sizeOfDataType,
                ArrayList<ScalarObject> scalars) {
            double[] hostScalars = SpoofOperator.prepInputScalars(scalars);
            // Widen before multiplying so the allocation size is computed in long arithmetic.
            long allocationSize = (long) scalars.size() * sizeOfDataType;
            Pointer devicePtr = executionContext.getGPUContext(0).allocate(op.getName(), allocationSize);
            LibMatrixCUDA.cudaSupportFunctions.hostToDevice(executionContext.getGPUContext(0), hostScalars, devicePtr,
                    op.getName());
            return devicePtr;
        }
    }

    /** @return the operator name, used here to tag GPU allocations */
    String getName();

    /** Stores the pointer to this operator's scalar buffer. */
    void setScalarPtr(Pointer pointer);

    /** @return the pointer previously stored via {@link #setScalarPtr} */
    Pointer getScalarPtr();

    /** Releases the GPU memory held for scalar arguments (implementation-defined). */
    void releaseScalarGPUMemory(ExecutionContext executionContext);

    /**
     * Builds {@link #JNI_MAT_ENTRY_SIZE}-slot entries for the first {@code numInputs} matrices.
     * Sparse matrices (per their GPU object on context 0) fill all slots; dense ones fill slots 0-2 and 5.
     *
     * @param executionContext execution context providing GPU context 0
     * @param inputs           input matrices
     * @param numInputs        number of leading matrices to describe; values above {@code inputs.size()}
     *                         cause an {@link IndexOutOfBoundsException}
     * @return array of {@code numInputs * JNI_MAT_ENTRY_SIZE} slots
     */
    default long[] prepareInputPointers(ExecutionContext executionContext, ArrayList<MatrixObject> inputs, int numInputs) {
        long[] entries = new long[numInputs * JNI_MAT_ENTRY_SIZE];
        for (int k = 0; k < numInputs; k++) {
            MatrixObject mo = inputs.get(k);
            int base = k * JNI_MAT_ENTRY_SIZE;
            if (mo.getGPUObject(executionContext.getGPUContext(0)).isSparse()) {
                entries[base] = executionContext.getGPUSparsePointerAddress(mo).nnz;
                entries[base + 1] = mo.getNumRows();
                entries[base + 2] = mo.getNumColumns();
                entries[base + 3] = GPUObject.getPointerAddress(executionContext.getGPUSparsePointerAddress(mo).rowPtr);
                entries[base + 4] = GPUObject.getPointerAddress(executionContext.getGPUSparsePointerAddress(mo).colInd);
                entries[base + 5] = GPUObject.getPointerAddress(executionContext.getGPUSparsePointerAddress(mo).val);
            } else {
                entries[base] = mo.getNnz();
                entries[base + 1] = mo.getNumRows();
                entries[base + 2] = mo.getNumColumns();
                entries[base + 5] = executionContext.getGPUDensePointerAddress(mo);
            }
        }
        return entries;
    }

    /**
     * Builds {@link #JNI_MAT_ENTRY_SIZE}-slot entries for {@code inputs[offset..size-1]}.
     *
     * <p>When {@code transposeFirstSide} is set and the first side input is dense, a new buffer is
     * allocated on GPU context 0 and filled with the transpose of that input. Its entry then points at
     * the copy and reports swapped dimensions (rows = original columns, columns = original rows).
     *
     * @param executionContext   execution context providing GPU context 0
     * @param inputs             all input matrices
     * @param offset             index of the first side input; values above {@code inputs.size()} cause a
     *                           {@link NegativeArraySizeException}
     * @param transposeFirstSide whether to pass a transposed copy of the first (dense) side input
     * @return array of {@code (inputs.size() - offset) * JNI_MAT_ENTRY_SIZE} slots
     */
    default long[] prepareSideInputPointers(ExecutionContext executionContext, ArrayList<MatrixObject> inputs, int offset,
            boolean transposeFirstSide) {
        long[] entries = new long[(inputs.size() - offset) * JNI_MAT_ENTRY_SIZE];
        for (int k = offset; k < inputs.size(); k++) {
            MatrixObject mo = inputs.get(k);
            int base = (k - offset) * JNI_MAT_ENTRY_SIZE;
            if (mo.getGPUObject(executionContext.getGPUContext(0)).isSparse()) {
                entries[base] = executionContext.getGPUSparsePointerAddress(mo).nnz;
                entries[base + 1] = mo.getNumRows();
                entries[base + 2] = mo.getNumColumns();
                entries[base + 3] = GPUObject.getPointerAddress(executionContext.getGPUSparsePointerAddress(mo).rowPtr);
                entries[base + 4] = GPUObject.getPointerAddress(executionContext.getGPUSparsePointerAddress(mo).colInd);
                entries[base + 5] = GPUObject.getPointerAddress(executionContext.getGPUSparsePointerAddress(mo).val);
            } else if (transposeFirstSide && k == offset) {
                long rows = mo.getNumRows();
                long cols = mo.getNumColumns();
                Pointer source = mo.getGPUObject(executionContext.getGPUContext(0)).getDensePointer();
                // NOTE(review): this buffer is neither freed here nor returned as a Pointer (only its
                // address is recorded) — confirm who releases it, otherwise it leaks per call.
                Pointer transposed = executionContext.getGPUContext(0).allocate(getName(),
                        rows * cols * LibMatrixCUDA.sizeOfDataType);
                LibMatrixCUDA.denseTranspose(executionContext, executionContext.getGPUContext(0), getName(), source,
                        transposed, rows, cols);
                entries[base] = mo.getNnz();
                // Dimensions are swapped to describe the transposed copy.
                entries[base + 1] = cols;
                entries[base + 2] = rows;
                entries[base + 5] = GPUObject.getPointerAddress(transposed);
            } else {
                entries[base] = mo.getNnz();
                entries[base + 1] = mo.getNumRows();
                entries[base + 2] = mo.getNumColumns();
                entries[base + 5] = executionContext.getGPUDensePointerAddress(mo);
            }
        }
        return entries;
    }

    /**
     * Builds a single {@link #JNI_MAT_ENTRY_SIZE}-slot entry for the output matrix.
     *
     * @param executionContext execution context
     * @param output           output matrix
     * @param sparse           {@code true} to describe the output through its sparse pointers,
     *                         {@code false} to use its dense buffer
     * @return array of {@link #JNI_MAT_ENTRY_SIZE} slots
     */
    default long[] prepareOutputPointers(ExecutionContext executionContext, MatrixObject output, boolean sparse) {
        long[] entry = new long[JNI_MAT_ENTRY_SIZE];
        if (sparse) {
            entry[0] = executionContext.getGPUSparsePointerAddress(output).nnz;
            entry[1] = output.getNumRows();
            entry[2] = output.getNumColumns();
            entry[3] = GPUObject.getPointerAddress(executionContext.getGPUSparsePointerAddress(output).rowPtr);
            entry[4] = GPUObject.getPointerAddress(executionContext.getGPUSparsePointerAddress(output).colInd);
            entry[5] = GPUObject.getPointerAddress(executionContext.getGPUSparsePointerAddress(output).val);
        } else {
            entry[0] = output.getNnz();
            entry[1] = output.getNumRows();
            entry[2] = output.getNumColumns();
            entry[5] = executionContext.getGPUDensePointerAddress(output);
        }
        return entry;
    }

    /**
     * Executes the operator producing a matrix result.
     *
     * @param outputName presumably the name of the output variable — TODO confirm
     */
    MatrixObject execute(ExecutionContext executionContext, ArrayList<MatrixObject> inputs,
            ArrayList<ScalarObject> scalars, String outputName);

    /** Executes the operator producing a scalar result. */
    ScalarObject execute(ExecutionContext executionContext, ArrayList<MatrixObject> inputs,
            ArrayList<ScalarObject> scalars);

    /** Native entry point; presumably the single-precision variant (by name) — TODO confirm. */
    int execute_sp(long j, long[] jArr, long[] jArr2, long[] jArr3, long[] jArr4, long j2);

    /** Native entry point; presumably the double-precision variant (by name) — TODO confirm. */
    int execute_dp(long j, long[] jArr, long[] jArr2, long[] jArr3, long[] jArr4, long j2);
}
