package jcublas.util;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import jcuda.Pointer;
import jcuda.cuComplex;
import jcuda.cuDoubleComplex;
import jcuda.driver.CUdeviceptr;
import jcuda.driver.JCudaDriver;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.complex.IComplexDouble;
import org.nd4j.linalg.api.complex.IComplexFloat;
import org.nd4j.linalg.api.ops.ScalarOp;

/* loaded from: input_file:jcublas/util/PointerUtil.class */
/**
 * Static helpers for wrapping host values in JCuda {@link Pointer}s, converting
 * boxed arrays to primitive arrays, and computing block/thread counts.
 */
public class PointerUtil {
    /**
     * Converts an array of objects to doubles by parsing each element's
     * {@code toString()} representation.
     *
     * @param objArr the values to convert; no element may be {@code null}
     * @return a new array of the same length holding the parsed values
     * @throws NumberFormatException if an element's string form is not a parsable double
     * @throws NullPointerException if the array or one of its elements is {@code null}
     */
    public static double[] toDoubles(Object[] objArr) {
        double[] result = new double[objArr.length];
        for (int idx = 0; idx < objArr.length; idx++) {
            // Parse the string form (rather than casting) so any object with a numeric toString() works.
            result[idx] = Double.parseDouble(objArr[idx].toString());
        }
        return result;
    }

    /**
     * Wraps a double precision complex number in a pointer.
     *
     * @param iComplexDouble the complex number whose real and imaginary components are copied
     * @return a pointer to a buffer holding the real component followed by the imaginary component
     */
    public static Pointer getPointer(IComplexDouble iComplexDouble) {
        double real = iComplexDouble.realComponent().doubleValue();
        double imaginary = iComplexDouble.imaginaryComponent().doubleValue();
        return getPointer(cuDoubleComplex.cuCmplx(real, imaginary));
    }

    /**
     * Wraps a single precision complex number in a pointer.
     *
     * @param iComplexFloat the complex number whose real and imaginary components are copied
     * @return a pointer to a buffer holding the real component followed by the imaginary component
     */
    public static Pointer getPointer(IComplexFloat iComplexFloat) {
        float real = iComplexFloat.realComponent().floatValue();
        float imaginary = iComplexFloat.imaginaryComponent().floatValue();
        return getPointer(cuComplex.cuCmplx(real, imaginary));
    }

    /**
     * Copies a {@code cuDoubleComplex} into a freshly allocated 16 byte direct buffer
     * (native byte order) laid out as {@code x} then {@code y}.
     *
     * @param cudoublecomplex the value to copy
     * @return a pointer to the buffer holding the copy
     */
    public static Pointer getPointer(cuDoubleComplex cudoublecomplex) {
        // 16 bytes = two 8 byte doubles.
        ByteBuffer bytes = ByteBuffer.allocateDirect(16);
        bytes.order(ByteOrder.nativeOrder());
        DoubleBuffer doubles = bytes.asDoubleBuffer();
        doubles.put(0, cudoublecomplex.x);
        doubles.put(1, cudoublecomplex.y);
        return Pointer.to(doubles);
    }

    /**
     * Copies a {@code cuComplex} into a freshly allocated 8 byte direct buffer
     * (native byte order) laid out as {@code x} then {@code y}.
     *
     * @param cucomplex the value to copy
     * @return a pointer to the buffer holding the copy
     */
    public static Pointer getPointer(cuComplex cucomplex) {
        // 8 bytes = two 4 byte floats.
        ByteBuffer bytes = ByteBuffer.allocateDirect(8);
        bytes.order(ByteOrder.nativeOrder());
        FloatBuffer floats = bytes.asFloatBuffer();
        floats.put(0, cucomplex.x);
        floats.put(1, cucomplex.y);
        return Pointer.to(floats);
    }

    /**
     * Converts an array of objects to floats by parsing each element's
     * {@code toString()} representation.
     *
     * @param objArr the values to convert; no element may be {@code null}
     * @return a new array of the same length holding the parsed values
     * @throws NumberFormatException if an element's string form is not a parsable float
     * @throws NullPointerException if the array or one of its elements is {@code null}
     */
    public static float[] toFloats(Object[] objArr) {
        float[] result = new float[objArr.length];
        for (int idx = 0; idx < objArr.length; idx++) {
            // Parse the string form (rather than casting) so any object with a numeric toString() works.
            result[idx] = Float.parseFloat(objArr[idx].toString());
        }
        return result;
    }

    /**
     * Computes the number of blocks for {@code n} elements: the ceiling of
     * {@code n / (2 * threads)}, capped at {@code maxBlocks}, where {@code threads}
     * is {@link #getNumThreads(int, int) getNumThreads(n, maxThreads)}.
     *
     * <p>NOTE(review): the factor of two presumably reflects each thread consuming
     * two elements, as in a parallel reduction - confirm against the kernels that
     * use this value.
     *
     * @param n the number of elements
     * @param maxBlocks the upper bound on the returned block count
     * @param maxThreads the upper bound on threads per block
     * @return the block count, never larger than {@code maxBlocks}
     * @throws ArithmeticException if the computed thread count is zero
     *         (for example when {@code maxThreads} is zero)
     */
    public static int getNumBlocks(int n, int maxBlocks, int maxThreads) {
        // Long arithmetic: n + (2 * threads - 1) overflows an int when n is close to Integer.MAX_VALUE.
        long elementsPerBlock = 2L * getNumThreads(n, maxThreads);
        long blocksNeeded = (n + elementsPerBlock - 1) / elementsPerBlock;
        // The minimum is bounded by maxBlocks or by n, so the narrowing cast cannot lose information.
        return (int) Math.min((long) maxBlocks, blocksNeeded);
    }

    /**
     * Computes the number of threads for {@code n} elements: {@code maxThreads} when
     * {@code n >= 2 * maxThreads}, otherwise half of {@code n} (rounded up) rounded up
     * to a power of two.
     *
     * @param n the number of elements
     * @param maxThreads the upper bound on the returned thread count
     * @return the thread count
     */
    public static int getNumThreads(int n, int maxThreads) {
        if (n < maxThreads * 2) {
            return nextPow2((n + 1) / 2);
        }
        return maxThreads;
    }

    /**
     * Returns the smallest power of two that is greater than or equal to {@code i}.
     *
     * <p>Values of {@code i} less than or equal to one yield {@code 1}. For {@code i}
     * greater than 2^30 the result does not fit in an {@code int} and wraps to
     * {@link Integer#MIN_VALUE}.
     *
     * @param i the value to round up
     * @return the next power of two
     */
    public static int nextPow2(int i) {
        // Without this guard the bit smearing below yields 0 for i <= 0, which is not a
        // power of two and makes getNumBlocks divide by zero.
        if (i <= 1) {
            return 1;
        }
        // Smear the highest set bit of (i - 1) into every lower bit, then add one.
        int v = i - 1;
        v |= v >> 1;
        v |= v >> 2;
        v |= v >> 4;
        v |= v >> 8;
        v |= v >> 16;
        return v + 1;
    }

    /**
     * Creates a device pointer and allocates device memory for {@code length} elements
     * via {@code JCudaDriver.cuMemAlloc}.
     *
     * <p>Four bytes per element are requested for {@code FLOAT}; eight bytes per element
     * are requested for every other type.
     *
     * @param length the number of elements to allocate space for
     * @param type the element type that determines the per element size
     * @return the device pointer passed to the allocation call
     */
    public static CUdeviceptr constructAndAlloc(int length, DataBuffer.Type type) {
        CUdeviceptr devicePointer = new CUdeviceptr();
        // Multiply as long: length * 8 overflows an int once length exceeds 2^28 - 1 elements.
        long bytesPerElement = type == DataBuffer.Type.FLOAT ? 4L : 8L;
        JCudaDriver.cuMemAlloc(devicePointer, length * bytesPerElement);
        return devicePointer;
    }

    /**
     * Returns the element size in bytes used for the given data type.
     *
     * @param type the data type
     * @return {@code 8} for {@code DOUBLE}, {@code 4} for every other type
     */
    public static int sizeFor(DataBuffer.Type type) {
        return type == DataBuffer.Type.DOUBLE ? 8 : 4;
    }

    /**
     * Extracts the scalar of a scalar operation as a single element primitive array
     * whose element type matches the data type of the operation's {@code x} buffer.
     *
     * <p>Despite the name, the returned object is a {@code float[]} or {@code double[]},
     * not a {@link Pointer}.
     *
     * @param scalarOp the operation to read the scalar from
     * @return a one element {@code float[]} for {@code FLOAT} data or a one element
     *         {@code double[]} for {@code DOUBLE} data
     * @throws IllegalStateException if the operation has no scalar or the data type is
     *         neither {@code FLOAT} nor {@code DOUBLE}
     */
    public static Object getPointer(ScalarOp scalarOp) {
        if (scalarOp.scalar() != null) {
            DataBuffer.Type dataType = scalarOp.x().data().dataType();
            if (dataType == DataBuffer.Type.FLOAT) {
                return new float[]{scalarOp.scalar().floatValue()};
            }
            if (dataType == DataBuffer.Type.DOUBLE) {
                return new double[]{scalarOp.scalar().doubleValue()};
            }
        }
        throw new IllegalStateException("Unable to get pointer for scalar operation " + scalarOp);
    }

    /**
     * Wraps a single double in a pointer to a new one element array.
     *
     * @param d the value to wrap
     * @return a pointer to the one element array
     */
    public static Pointer getPointer(double d) {
        return Pointer.to(new double[]{d});
    }

    /**
     * Wraps a single float in a pointer to a new one element array.
     *
     * @param f the value to wrap
     * @return a pointer to the one element array
     */
    public static Pointer getPointer(float f) {
        return Pointer.to(new float[]{f});
    }
}
