package jcublas.context;

import com.google.common.collect.HashBasedTable;
import com.google.common.collect.Table;
import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import jcuda.CudaException;
import jcuda.driver.CUcontext;
import jcuda.driver.CUdevice;
import jcuda.driver.CUresult;
import jcuda.driver.CUstream;
import jcuda.driver.JCudaDriver;
import jcuda.jcublas.JCublas2;
import jcuda.jcublas.cublasHandle;
import jcuda.runtime.JCuda;
import jcuda.runtime.cudaError;
import jcuda.runtime.cudaStream_t;
import org.nd4j.linalg.api.buffer.allocation.MemoryStrategy;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.jcublas.device.conf.DeviceConfiguration;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.io.ClassPathResource;

/* loaded from: input_file:jcublas/context/ContextHolder.class */
/**
 * Lazily created singleton that owns the CUDA driver contexts, streams and
 * cuBLAS handles used by this process.
 *
 * <p>Contexts are kept per device index; streams are kept per
 * (context, thread name) pair; cuBLAS handles are kept per thread name.
 * A thread is bound to a device the first time it asks for one and keeps that
 * binding afterwards.
 *
 * <p>Accessors that touch the stream tables or create contexts are
 * {@code synchronized} on this instance because the backing Guava tables are
 * plain {@code HashBasedTable}s.
 */
public class ContextHolder {
    /** System property: comma separated device indices that must never be assigned to a thread. */
    public static final String DEVICES_TO_BAN = "org.nd4j.linalg.jcuda.jcublas.ban_devices";
    /** System property: boolean flag read into {@code syncThreads} by {@link #configure()}. */
    public static final String SYNC_THREADS = "org.nd4j.linalg.jcuda.jcublas.syncthreads";

    private static final Logger log = LoggerFactory.getLogger(ContextHolder.class);
    private static ContextHolder INSTANCE;
    private static boolean syncThreads = true;

    /** Device indices excluded from random assignment; stays null if device discovery failed early. */
    private List<Integer> bannedDevices;
    private final Map<Integer, CUdevice> devices = new ConcurrentHashMap<>();
    private final Map<Integer, GpuInformation> info = new ConcurrentHashMap<>();
    private final Map<Integer, CUcontext> deviceIDContexts = new ConcurrentHashMap<>();
    private final Map<String, Integer> threadNameToDeviceNumber = new ConcurrentHashMap<>();
    private final Table<CUcontext, String, CUstream> contextStreams = HashBasedTable.create();
    private final Table<CUcontext, String, cudaStream_t> cudaStreams = HashBasedTable.create();
    private final Map<String, cublasHandle> handleMap = new ConcurrentHashMap<>();
    private int numDevices = 0;
    private final Map<Integer, DeviceConfiguration> confs = new ConcurrentHashMap<>();
    private boolean confCalled = false;
    private final AtomicBoolean shutdown = new AtomicBoolean(false);

    /**
     * Probes the driver for devices. A failure is logged and swallowed so that the
     * holder can still be constructed; {@link #configure()} retries the probe when
     * the device count is still zero.
     */
    private ContextHolder() {
        try {
            getNumDevices();
        } catch (Exception e) {
            log.warn("Unable to initialize cuda", e);
        }
    }

    /**
     * Returns the shared holder, creating and configuring it on first use.
     *
     * <p>On first use this loads {@code /cudafunctions.properties} from the classpath,
     * builds and configures the holder, copies every loaded property into the JVM
     * system properties and registers a shutdown hook that calls {@link #destroy()}.
     *
     * @return the process-wide instance
     * @throws RuntimeException wrapping the {@link IOException} if the properties
     *                          resource cannot be read
     */
    public static synchronized ContextHolder getInstance() {
        if (INSTANCE != null) {
            return INSTANCE;
        }

        Properties functionProperties = new Properties();
        ClassPathResource resource =
                new ClassPathResource("/cudafunctions.properties", ContextHolder.class.getClassLoader());
        // try-with-resources: the stream used to be left open.
        try (InputStream in = resource.getInputStream()) {
            functionProperties.load(in);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }

        // INSTANCE is published before configure() on purpose: the original ordering is kept
        // so that anything reached from configure() that asks for the instance still gets it.
        INSTANCE = new ContextHolder();
        INSTANCE.configure();

        // Make the loaded properties visible globally.
        for (String key : functionProperties.stringPropertyNames()) {
            System.getProperties().put(key, functionProperties.getProperty(key));
        }

        Runtime.getRuntime().addShutdownHook(new Thread(new Runnable() {
            @Override
            public void run() {
                ContextHolder.INSTANCE.destroy();
            }
        }));
        return INSTANCE;
    }

    /**
     * @return the number of devices this holder works with
     */
    public int deviceNum() {
        return this.numDevices;
    }

    /**
     * @return the configuration of the device bound to the calling thread
     */
    public DeviceConfiguration getConf() {
        return getConf(getDeviceForThread());
    }

    /**
     * @return the memory strategy of the calling thread's device configuration
     */
    public MemoryStrategy getMemoryStrategy() {
        return getConf().getMemoryStrategy();
    }

    /**
     * Builds one {@link DeviceConfiguration} per device. Runs at most once; later
     * calls return immediately.
     *
     * <p>For device {@code i} the classpath resource {@code devices/i} is loaded as a
     * properties file when it exists; otherwise the device gets a configuration built
     * from its index alone.
     *
     * @throws RuntimeException wrapping the {@link IOException} if an existing device
     *                          resource cannot be read
     */
    public void configure() {
        if (this.confCalled) {
            return;
        }
        syncThreads = Boolean.parseBoolean(System.getProperty(SYNC_THREADS, "true"));
        if (this.numDevices == 0) {
            getNumDevices();
        }
        for (int device = 0; device < this.numDevices; device++) {
            ClassPathResource resource =
                    new ClassPathResource("devices/" + device, ContextHolder.class.getClassLoader());
            if (!resource.exists()) {
                this.confs.put(device, new DeviceConfiguration(device));
                continue;
            }
            Properties deviceProperties = new Properties();
            // try-with-resources: the stream used to be left open.
            try (InputStream in = resource.getInputStream()) {
                deviceProperties.load(in);
                this.confs.put(device, new DeviceConfiguration(device, deviceProperties));
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }
        this.confCalled = true;
    }

    /**
     * Overrides the device count.
     *
     * @param i the number of devices to assume
     */
    public void setNumDevices(int i) {
        this.numDevices = i;
    }

    /**
     * @param i device index
     * @return the configuration stored for that device, or {@code null} if none was stored
     */
    public DeviceConfiguration getConf(int i) {
        return this.confs.get(i);
    }

    /**
     * Initialises the driver API, records the device count (forced to at least 1)
     * and parses the {@link #DEVICES_TO_BAN} system property into the banned list.
     *
     * <p>Entries are trimmed and blank entries are skipped, so {@code "0, 2"} and
     * {@code "0,,2"} are accepted. Negative values (including the default
     * {@code -1}) are ignored.
     *
     * @throws NumberFormatException if a non-blank entry is not an integer
     */
    private void getNumDevices() {
        JCudaDriver.setExceptionsEnabled(true);
        JCudaDriver.cuInit(0);
        int[] count = new int[1];
        JCudaDriver.cuDeviceGetCount(count);
        this.numDevices = count[0];
        log.debug("Found " + this.numDevices + " gpus");
        if (this.numDevices < 1) {
            this.numDevices = 1;
        }

        List<Integer> banned = new ArrayList<>();
        for (String token : System.getProperty(DEVICES_TO_BAN, "-1").split(",")) {
            String trimmed = token.trim();
            if (trimmed.isEmpty()) {
                continue;
            }
            int device = Integer.parseInt(trimmed);
            if (device >= 0) {
                banned.add(device);
            }
        }
        this.bannedDevices = banned;
    }

    /**
     * Makes the calling thread's context current, points its cuBLAS handle at its
     * runtime stream, then synchronizes the runtime stream followed by the driver stream.
     */
    public static void syncStream() {
        ContextHolder holder = getInstance();
        JCudaDriver.cuCtxSetCurrent(holder.getContext());
        cudaStream_t runtimeStream = holder.getCudaStream();
        JCublas2.cublasSetStream(holder.getHandle(), runtimeStream);
        JCuda.cudaStreamSynchronize(runtimeStream);
        JCudaDriver.cuStreamSynchronize(holder.getStream());
    }

    /**
     * Returns the device index bound to the calling thread (keyed by thread name),
     * binding it to a randomly chosen non-banned device on first use.
     *
     * <p>With one device or fewer the answer is always 0 and nothing is recorded.
     * A thread that already has a binding gets that binding back; previously this
     * case returned 0 regardless of the stored value.
     *
     * @return the device index for the calling thread
     * @throws IllegalStateException if no random generator is available, or if every
     *                               device index is banned
     */
    public int getDeviceForThread() {
        if (this.numDevices <= 1) {
            return 0;
        }
        String threadName = Thread.currentThread().getName();
        Integer assigned = this.threadNameToDeviceNumber.get(threadName);
        if (assigned != null) {
            return assigned.intValue();
        }
        if (Nd4j.getRandom() == null) {
            throw new IllegalStateException("Unable to load random class");
        }
        // Without this guard the rejection loop below would never terminate.
        if (!hasUsableDevice()) {
            throw new IllegalStateException(
                    "All " + this.numDevices + " devices are banned via " + DEVICES_TO_BAN);
        }
        Integer candidate;
        do {
            candidate = Integer.valueOf(Nd4j.getRandom().nextInt(this.numDevices));
        } while (this.bannedDevices != null && this.bannedDevices.contains(candidate));
        this.threadNameToDeviceNumber.put(threadName, candidate);
        return candidate.intValue();
    }

    /** Tells whether at least one index in {@code [0, numDevices)} is not banned. */
    private boolean hasUsableDevice() {
        if (this.bannedDevices == null) {
            return true;
        }
        for (int device = 0; device < this.numDevices; device++) {
            if (!this.bannedDevices.contains(Integer.valueOf(device))) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the cuBLAS handle cached for the calling thread's name, creating it on first use.
     *
     * @return the thread's cuBLAS handle
     * @throws CudaException if {@code cublasCreate} reports a non-zero status
     */
    public cublasHandle getHandle() {
        String threadName = Thread.currentThread().getName();
        cublasHandle cached = this.handleMap.get(threadName);
        if (cached != null) {
            return cached;
        }
        cublasHandle created = new cublasHandle();
        int status = JCublas2.cublasCreate(created);
        if (status != 0) {
            // Do not cache a handle that was never initialised.
            throw new CudaException("Failed to create a cuBLAS handle: status " + status);
        }
        this.handleMap.put(threadName, created);
        return created;
    }

    /**
     * @return the driver context of the device bound to the calling thread
     */
    public CUcontext getContext() {
        return getContext(getDeviceForThread());
    }

    /**
     * Returns the runtime-API stream cached for the calling thread in its device
     * context, creating it on first use.
     *
     * <p>The stream is created exactly once; the earlier code created it twice on
     * the same handle and lost the first one.
     *
     * @return the thread's runtime stream
     * @throws CudaException if the context cannot be made current or the stream
     *                       cannot be created
     */
    public synchronized cudaStream_t getCudaStream() {
        String threadName = Thread.currentThread().getName();
        CUcontext context = getContext(getDeviceForThread());
        cudaStream_t stream = this.cudaStreams.get(context, threadName);
        if (stream == null) {
            stream = new cudaStream_t();
            checkResult(JCudaDriver.cuCtxSetCurrent(context));
            checkRuntimeResult(JCuda.cudaStreamCreate(stream));
            this.cudaStreams.put(context, threadName, stream);
        }
        return stream;
    }

    /**
     * Returns the driver-API stream cached for the calling thread in its device
     * context, creating it (with flag value 1) on first use.
     *
     * <p>Synchronized like {@link #getCudaStream()} because the backing table is a
     * plain {@code HashBasedTable}.
     *
     * @return the thread's driver stream
     * @throws CudaException if the context cannot be made current or the stream
     *                       cannot be created
     */
    public synchronized CUstream getStream() {
        String threadName = Thread.currentThread().getName();
        CUcontext context = getContext(getDeviceForThread());
        CUstream stream = this.contextStreams.get(context, threadName);
        if (stream == null) {
            stream = new CUstream();
            checkResult(JCudaDriver.cuCtxSetCurrent(context));
            checkResult(JCudaDriver.cuStreamCreate(stream, 1));
            this.contextStreams.put(context, threadName, stream);
        }
        return stream;
    }

    /** Throws if a driver-API call made while setting up a stream returned a non-zero code. */
    private void checkResult(int i) {
        if (i != 0) {
            throw new CudaException("Failed to create a stream: " + CUresult.stringFor(i));
        }
    }

    /** Same as {@link #checkResult(int)} but decodes the code as a runtime-API error. */
    private void checkRuntimeResult(int code) {
        if (code != 0) {
            throw new CudaException("Failed to create a stream: " + cudaError.stringFor(code));
        }
    }

    /**
     * Returns the driver context for a device, creating contexts for every device
     * that does not have one yet on the first miss.
     *
     * <p>Each device gets its own {@link CUcontext} object created by a single
     * {@code cuCtxCreate} call. The earlier code reused one object for all devices
     * and created two contexts per device on it, so every index ended up mapped to
     * the handle of the last context created and the others were lost.
     *
     * @param i device index, expected in {@code [0, deviceNum())}
     * @return the context registered for that device
     * @throws IllegalArgumentException if no context exists for {@code i} after
     *                                  creation, i.e. the index is out of range
     * @throws CudaException            if a device or context cannot be obtained
     */
    public synchronized CUcontext getContext(int i) {
        CUcontext cached = this.deviceIDContexts.get(i);
        if (cached != null) {
            return cached;
        }

        JCudaDriver.cuInit(0);
        for (int device = 0; device < this.numDevices; device++) {
            if (this.deviceIDContexts.containsKey(device)) {
                continue; // never create a second context for a device that already has one
            }
            CUcontext deviceContext = new CUcontext();
            CUdevice cuDevice = createDevice(deviceContext, device);
            this.devices.put(device, cuDevice);
            this.info.put(device, new GpuInformation(cuDevice));
            this.deviceIDContexts.put(device, deviceContext);
        }

        CUcontext created = this.deviceIDContexts.get(i);
        if (created == null) {
            throw new IllegalArgumentException(
                    "No CUDA device with index " + i + " (" + this.numDevices + " device(s) available)");
        }
        return created;
    }

    /**
     * Looks up the device with the given index and creates a context on it,
     * storing the new context handle in {@code cUcontext}.
     *
     * @param cUcontext receives the created context
     * @param i         device index
     * @return the device handle
     * @throws CudaException if the device cannot be obtained or the context cannot be created
     */
    public static CUdevice createDevice(CUcontext cUcontext, int i) {
        CUdevice device = new CUdevice();
        int deviceStatus = JCudaDriver.cuDeviceGet(device, i);
        if (deviceStatus != 0) {
            throw new CudaException("Failed to obtain a device: " + CUresult.stringFor(deviceStatus));
        }
        int contextStatus = JCudaDriver.cuCtxCreate(cUcontext, 0, device);
        if (contextStatus != 0) {
            throw new CudaException("Failed to create a context: " + CUresult.stringFor(contextStatus));
        }
        return device;
    }

    /**
     * Returns the information object recorded for a device, creating the device
     * contexts first if necessary.
     *
     * @param i device index
     * @return the recorded {@link GpuInformation}
     * @throws IllegalArgumentException if {@code i} is not a known device index
     */
    public GpuInformation getInfoFor(int i) {
        getContext(i);
        return this.info.get(i);
    }

    /**
     * @return the live map from device index to device handle (not a copy)
     */
    public Map<Integer, CUdevice> getDevices() {
        return this.devices;
    }

    /**
     * @return the live map from device index to driver context (not a copy)
     */
    public Map<Integer, CUcontext> getDeviceIDContexts() {
        return this.deviceIDContexts;
    }

    /**
     * Marks the holder as shut down; repeated calls are no-ops.
     *
     * <p>NOTE(review): only the flag is flipped here — contexts, streams and cuBLAS
     * handles created by this class are not released. Confirm whether that is intended.
     */
    public synchronized void destroy() {
        this.shutdown.compareAndSet(false, true);
    }
}
