package jcublas.fft;

import java.util.Iterator;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import jcuda.jcufft.JCufft;
import jcuda.jcufft.cufftHandle;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.fft.DefaultFFTInstance;
import org.nd4j.linalg.jcublas.fft.ops.JCudaVectorFFT;

/* loaded from: input_file:jcublas/fft/JcudaFft.class */
/**
 * FFT implementation backed by NVIDIA cuFFT through JCufft.
 *
 * <p>Lazily creates one {@link cufftHandle} per calling thread (see {@link #getHandle()}) and
 * registers a JVM shutdown hook that destroys every handle created by this instance.
 *
 * <p>NOTE(review): handles belonging to threads that have terminated are only released by the
 * shutdown hook, not when the thread dies.
 */
public class JcudaFft extends DefaultFFTInstance {
    /**
     * cuFFT handles keyed by the id of the thread that created them. Thread ids are unique among
     * live threads (unlike thread names, which may collide or be changed at runtime), so two live
     * threads never share the same handle.
     */
    private final Map<Long, cufftHandle> handles = new ConcurrentHashMap<>();

    /**
     * Registers a shutdown hook that destroys this instance's handles and enables JCufft
     * exceptions.
     *
     * <p>NOTE(review): every instance registers its own hook, which keeps the instance reachable
     * until JVM exit; creating many instances accumulates hooks.
     */
    public JcudaFft() {
        Runtime.getRuntime().addShutdownHook(new Thread(new Runnable() {
            @Override
            public void run() {
                destroyAllHandles();
            }
        }));
        JCufft.setExceptionsEnabled(true);
    }

    /**
     * Returns the cuFFT handle for the calling thread, creating it with {@code cufftCreate} on
     * first use.
     *
     * <p>If a thread id is reused after its thread terminated, the new thread receives the handle
     * the terminated thread left behind, which is no longer in use by anyone else.
     *
     * @return the calling thread's handle; never {@code null}
     */
    public cufftHandle getHandle() {
        Long threadId = Long.valueOf(Thread.currentThread().getId());
        cufftHandle handle = this.handles.get(threadId);
        if (handle == null) {
            // Only the owning (live) thread ever reads or writes its own key, so this
            // check-then-put cannot race with another thread for the same entry.
            handle = new cufftHandle();
            JCufft.cufftCreate(handle);
            this.handles.put(threadId, handle);
        }
        return handle;
    }

    /**
     * Creates the CUDA vector FFT op for the given input.
     *
     * @param input the array to transform
     * @param n forwarded unchanged to {@link JCudaVectorFFT}; presumably the transform length —
     *     TODO confirm against JCudaVectorFFT
     * @return a new {@link JCudaVectorFFT} wrapping {@code input}
     */
    protected Op getFftOp(INDArray input, int n) {
        return new JCudaVectorFFT(input, n);
    }

    /**
     * Destroys every handle created by this instance.
     *
     * <p>A failure while destroying one handle does not stop the others from being destroyed.
     * The first failure is rethrown once all handles have been attempted, with any later failures
     * attached as suppressed exceptions.
     */
    private void destroyAllHandles() {
        RuntimeException firstFailure = null;
        for (cufftHandle handle : this.handles.values()) {
            try {
                JCufft.cufftDestroy(handle);
            } catch (RuntimeException e) {
                if (firstFailure == null) {
                    firstFailure = e;
                } else {
                    firstFailure.addSuppressed(e);
                }
            }
        }
        if (firstFailure != null) {
            throw firstFailure;
        }
    }
}
