package org.apache.sysds.runtime.lineage;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.concurrent.ExecutorService;
import java.util.stream.Collectors;
import jcuda.Pointer;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.gpu.context.GPUContext;
import org.apache.sysds.runtime.instructions.gpu.context.GPUObject;
import org.apache.sysds.runtime.lineage.LineageCacheConfig;
import org.apache.sysds.runtime.matrix.data.LibMatrixCUDA;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.meta.DataCharacteristics;

/* loaded from: input_file:org/apache/sysds/runtime/lineage/LineageGPUCacheEviction.class */
/**
 * Bookkeeping for GPU pointers held by the lineage cache.
 *
 * <p>Every cached pointer is mapped to its lineage cache entry in {@code GPUCacheEntries}.
 * A pointer that is currently referenced ("live") carries a reference count in
 * {@code livePointers}. When that count drops to zero, its entry moves into a free queue in
 * {@code freeQueues}. Free queues are keyed by the pointer's allocated size. Entries in a free
 * queue can be recycled ({@link #pollFirstFreeEntry}, {@link #pollFistFreeNotExact}) or
 * released ({@link #removeAllEntries}).
 *
 * <p>NOTE(review): the static maps are plain {@link HashMap}s and no locking is done here;
 * callers are presumably responsible for serializing access — confirm.
 */
public class LineageGPUCacheEviction {
    private static GPUContext _gpuContext = null;
    public static ExecutorService gpuEvictionThread = null;
    /**
     * Free (non-live) cached entries, grouped by allocated pointer size. Each queue is ordered by
     * {@code LineageCacheConfig.LineageGPUCacheComparator}. Empty queues are pruned eagerly.
     */
    private static HashMap<Long, TreeSet<LineageCacheEntry>> freeQueues = new HashMap<>();
    /** Reference counts of cached pointers that are currently live. */
    private static HashMap<Pointer, Integer> livePointers = new HashMap<>();
    /** All cached pointers mapped to their lineage cache entries. */
    private static HashMap<Pointer, LineageCacheEntry> GPUCacheEntries = new HashMap<>();

    /** Clears all GPU cache bookkeeping and drops the reference to the eviction thread. */
    public static void resetEviction() {
        gpuEvictionThread = null;
        freeQueues.clear();
        livePointers.clear();
        GPUCacheEntries.clear();
    }

    /**
     * Registers one more live reference to a cached pointer.
     *
     * <p>When the pointer transitions from free to live (count becomes 1), its entry is removed
     * from the free queue so it can no longer be recycled.
     *
     * @param pointer the cached GPU pointer being referenced
     */
    public static void incrementLiveCount(Pointer pointer) {
        if (livePointers.merge(pointer, 1, Integer::sum) == 1) {
            long size = getPointerSize(pointer);
            TreeSet<LineageCacheEntry> queue = freeQueues.get(size);
            // A missing queue means the entry is not in any free queue; nothing to unlink.
            if (queue != null) {
                queue.remove(GPUCacheEntries.get(pointer));
                if (queue.isEmpty()) {
                    freeQueues.remove(size);
                }
            }
        }
    }

    /**
     * Releases one live reference to a cached pointer. When the last reference is released,
     * the pointer's entry is appended to the free queue for its allocation size.
     *
     * @param pointer the cached GPU pointer being released; must currently be live
     */
    public static void decrementLiveCount(Pointer pointer) {
        // Returning null from the remapping function drops the key once the count reaches zero.
        Integer remaining = livePointers.compute(pointer, (p, count) -> count == 1 ? null : count - 1);
        if (remaining == null) {
            freeQueues
                .computeIfAbsent(getPointerSize(pointer),
                    k -> new TreeSet<>(LineageCacheConfig.LineageGPUCacheComparator))
                .add(GPUCacheEntries.get(pointer));
        }
    }

    /**
     * @param pointer a GPU pointer
     * @return true if the pointer currently has live references recorded here
     */
    public static boolean probeLiveCachedPointers(Pointer pointer) {
        return livePointers.containsKey(pointer);
    }

    /**
     * Starts tracking a newly GPU-cached entry as live with a reference count of one.
     * Null-valued entries are ignored.
     *
     * @param lineageCacheEntry the entry to track
     * @throws DMLRuntimeException if the entry holds a scalar value
     */
    public static void addEntry(LineageCacheEntry lineageCacheEntry) {
        if (lineageCacheEntry.isNullVal()) {
            return;
        }
        if (lineageCacheEntry.isScalarValue()) {
            throw new DMLRuntimeException("Scalars are never stored in GPU. Lineage: " + lineageCacheEntry._key);
        }
        lineageCacheEntry.initiateScoreGPU(LineageCacheEviction._removelist);
        livePointers.put(lineageCacheEntry.getGPUPointer(), 1);
        GPUCacheEntries.put(lineageCacheEntry.getGPUPointer(), lineageCacheEntry);
    }

    /** Refreshes the timestamp of an entry whose status is {@code GPUCACHED}; others are ignored. */
    public static void maintainOrder(LineageCacheEntry lineageCacheEntry) {
        if (lineageCacheEntry.getCacheStatus() != LineageCacheConfig.LineageCacheStatus.GPUCACHED) {
            return;
        }
        lineageCacheEntry.updateTimestamp();
    }

    /**
     * Removes one entry from the given lineage cache map and bumps the entry key's counter in
     * {@code LineageCacheEviction._removelist}.
     */
    protected static void removeSingleEntry(Map<LineageItem, LineageCacheEntry> map, LineageCacheEntry lineageCacheEntry) {
        map.remove(lineageCacheEntry._key);
        LineageCacheEviction._removelist.merge(lineageCacheEntry._key, 1, Integer::sum);
    }

    /**
     * Removes an entry from the lineage cache. If the entry has an original item, every entry
     * chained from that original item via {@code _nextEntry} is removed as well.
     */
    private static void removeEntry(LineageCacheEntry entry) {
        Map<LineageItem, LineageCacheEntry> cache = LineageCache.getLineageCache();
        if (entry._origItem == null) {
            removeSingleEntry(cache, entry);
            return;
        }
        for (LineageCacheEntry e = cache.get(entry._origItem); e != null; e = e._nextEntry) {
            removeSingleEntry(cache, e);
        }
    }

    /**
     * Frees a fraction of the entries in every free queue.
     *
     * <p>For each size queue, {@code (int) (queueSize * d)} entries are released. At least one
     * entry per queue is always released, so a small fraction still makes progress. Every polled
     * entry is freed on the device. The original implementation polled one entry past the limit
     * and dropped it without freeing its pointer.
     *
     * @param d fraction of each free queue to release (1.0 releases everything)
     */
    public static void removeAllEntries(double d) {
        // Snapshot the keys: polling may prune emptied queues from freeQueues.
        for (Long size : new ArrayList<>(freeQueues.keySet())) {
            TreeSet<LineageCacheEntry> queue = freeQueues.get(size);
            if (queue == null) {
                continue;
            }
            int toFree = Math.max(1, (int) (queue.size() * d));
            for (int freed = 0; freed < toFree; freed++) {
                // Poll only once we know the entry will be freed, so nothing is leaked.
                LineageCacheEntry entry = pollFirstFreeEntry(size);
                if (entry == null) {
                    break;
                }
                _gpuContext.getMemoryManager().guardedCudaFree(entry.getGPUPointer());
                if (DMLScript.STATISTICS) {
                    LineageCacheStatistics.incrementGpuDel();
                }
            }
        }
    }

    /** Sets the GPU context used for pointer sizes, frees and device-to-host copies. */
    public static void setGPUContext(GPUContext gPUContext) {
        _gpuContext = gPUContext;
    }

    /** @return true if no free queue exists (empty queues are pruned, so no free entries remain) */
    public static boolean isGPUCacheFreeQEmpty() {
        return freeQueues.isEmpty();
    }

    /**
     * Removes and returns the first free entry whose pointer has exactly the given size.
     *
     * <p>The returned entry is also removed from the lineage cache and from GPU entry
     * tracking. Its pointer is NOT freed; the caller owns it.
     *
     * @param j allocation size of the requested pointer
     * @return the polled entry, or null if no free entry of that size exists
     * @throws DMLRuntimeException if the polled entry's pointer is still live
     */
    public static LineageCacheEntry pollFirstFreeEntry(long j) {
        TreeSet<LineageCacheEntry> queue = freeQueues.get(j);
        if (queue == null || queue.isEmpty()) {
            freeQueues.remove(j);
            return null;
        }
        LineageCacheEntry entry = queue.pollFirst();
        // Prune eagerly so isGPUCacheFreeQEmpty and size lookups never see empty queues.
        if (queue.isEmpty()) {
            freeQueues.remove(j);
        }
        if (probeLiveCachedPointers(entry.getGPUPointer())) {
            throw new DMLRuntimeException("Recycling live pointer: " + entry._key);
        }
        removeEntry(entry);
        GPUCacheEntries.remove(entry.getGPUPointer());
        return entry;
    }

    /**
     * Polls a free entry whose pointer size is not necessarily exactly {@code j}.
     *
     * <ul>
     *   <li>If {@code j} exceeds every cached size, the entry comes from the largest non-empty
     *       queue.</li>
     *   <li>Otherwise it comes from the smallest non-empty queue whose size is at least
     *       {@code j}.</li>
     * </ul>
     *
     * @param j requested allocation size
     * @return the polled entry, or null if no suitable free entry exists
     */
    public static LineageCacheEntry pollFistFreeNotExact(long j) {
        ArrayList<Long> sizes = new ArrayList<>(freeQueues.keySet());
        if (sizes.isEmpty()) {
            return null;
        }
        Collections.sort(sizes);
        if (j > sizes.get(sizes.size() - 1)) {
            for (int i = sizes.size() - 1; i >= 0; i--) {
                LineageCacheEntry entry = pollFirstFreeEntry(sizes.get(i));
                if (entry != null) {
                    return entry;
                }
            }
            return null;
        }
        for (long size : sizes) {
            if (size >= j) {
                LineageCacheEntry entry = pollFirstFreeEntry(size);
                if (entry != null) {
                    return entry;
                }
            }
        }
        return null;
    }

    /** @return the total number of entries across all free queues */
    public static int numPointersCached() {
        return freeQueues.values().stream().mapToInt(TreeSet::size).sum();
    }

    /** @return the sum of (queue size key x number of entries) over all free queues */
    public static long totalMemoryCached() {
        long total = 0;
        for (Map.Entry<Long, TreeSet<LineageCacheEntry>> e : freeQueues.entrySet()) {
            total += e.getKey() * e.getValue().size();
        }
        return total;
    }

    /** @return the allocated size of the pointer as reported by the GPU memory manager */
    protected static long getPointerSize(Pointer pointer) {
        return _gpuContext.getMemoryManager().getSizeAllocatedGPUPointer(pointer);
    }

    /** @return a new mutable set of the GPU pointers of all entries in the free queues */
    public static Set<Pointer> getAllCachedPointers() {
        return freeQueues.values().stream()
            .flatMap(TreeSet::stream)
            .map(LineageCacheEntry::getGPUPointer)
            .collect(Collectors.toCollection(HashSet::new));
    }

    /**
     * Copies an entry's GPU data into a host {@link MatrixBlock}, updates the measured D2H
     * bandwidth, and stores the block in the host lineage cache. Space is made first when the
     * block would exceed the cache threshold.
     *
     * @param lineageCacheEntry a dense, GPU-cached entry
     * @return the entry's GPU pointer, captured before {@code setValue} replaces the entry's value
     */
    public static Pointer copyToHostCache(LineageCacheEntry lineageCacheEntry) {
        long start = System.nanoTime();
        MatrixBlock block = pointerToMatrixBlock(lineageCacheEntry);
        adjustD2HTransferSpeed(lineageCacheEntry.getSize(), (System.nanoTime() - start) / 1.0E9d);
        // Capture the pointer before setValue, which presumably replaces it — confirm.
        Pointer gpuPointer = lineageCacheEntry.getGPUPointer();
        long inMemorySize = block.getInMemorySize();
        // A single lock suffices: the original re-acquired this same (reentrant) monitor redundantly.
        synchronized (LineageCache.getLineageCache()) {
            if (!LineageCacheEviction.isBelowThreshold(inMemorySize)) {
                LineageCacheEviction.makeSpace(LineageCache.getLineageCache(), inMemorySize);
            }
            LineageCacheEviction.updateSize(inMemorySize, true);
            lineageCacheEntry.setValue(block);
            LineageCacheEviction.addEntry(lineageCacheEntry);
        }
        return gpuPointer;
    }

    /**
     * Folds a new device-to-host bandwidth observation into
     * {@code LineageCacheConfig.D2HCOPYBANDWIDTH} using an exponential moving average.
     * Observations above {@code D2HMAXBANDWIDTH} are discarded as outliers.
     *
     * @param d transferred size (divided by 2^20 to get MB, so presumably bytes)
     * @param d2 elapsed transfer time in seconds
     */
    private static void adjustD2HTransferSpeed(double d, double d2) {
        final double alpha = 0.5;
        double observed = (d / 1048576.0d) / d2;
        if (observed > LineageCacheConfig.D2HMAXBANDWIDTH) {
            return;
        }
        LineageCacheConfig.D2HCOPYBANDWIDTH = alpha * observed + (1.0d - alpha) * LineageCacheConfig.D2HCOPYBANDWIDTH;
    }

    /**
     * Copies a dense GPU pointer into a newly allocated dense host {@link MatrixBlock} and
     * recomputes its non-zero count.
     *
     * @throws DMLRuntimeException if the entry's pointer is sparse
     */
    private static MatrixBlock pointerToMatrixBlock(LineageCacheEntry lineageCacheEntry) {
        DataCharacteristics dc = lineageCacheEntry.getDataCharacteristics();
        if (!lineageCacheEntry.isDensePointer()) {
            throw new DMLRuntimeException("Sparse pointers should not be cached in GPU. Lineage: " + lineageCacheEntry._key);
        }
        MatrixBlock block = new MatrixBlock(GPUObject.toIntExact(dc.getRows()), GPUObject.toIntExact(dc.getCols()), false);
        block.allocateDenseBlock();
        LibMatrixCUDA.cudaSupportFunctions.deviceToHost(_gpuContext, lineageCacheEntry.getGPUPointer(), block.getDenseBlockValues(), null, true);
        block.recomputeNonZeros();
        return block;
    }

    /**
     * Stops tracking a GPU pointer. When {@code z} is true, the entry is also removed from the
     * lineage cache.
     *
     * <p>NOTE(review): this does not touch {@code freeQueues} or {@code livePointers}; callers
     * presumably ensure the pointer is in neither — confirm.
     */
    public static void removeFromDeviceCache(LineageCacheEntry lineageCacheEntry, Pointer pointer, boolean z) {
        if (z) {
            LineageCache.removeEntry(lineageCacheEntry._key);
        }
        GPUCacheEntries.remove(pointer);
    }
}
