package org.apache.sysds.runtime.controlprogram.caching;

import java.io.IOException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.LazyWriteBuffer;
import org.apache.sysds.runtime.util.LocalFileUtils;

/* loaded from: input_file:org/apache/sysds/runtime/controlprogram/caching/UnifiedMemoryManager.class */
/**
 * Static manager of a unified memory area (UMM) that holds serialized cache
 * blocks in an in-memory eviction queue and accounts for the memory pinned by
 * the inputs of running operations.
 *
 * <p>All mutations of the eviction queue and of {@code _totCachedSize} are
 * performed while holding the monitor of {@code _mQueue}. {@link #init()} must
 * have been called before any method that touches the queue is used.
 */
public class UnifiedMemoryManager {
    /** Capacity limit of the UMM; set from {@code OptimizerUtils.getBufferPoolLimit()} or {@link #setUMMLimit(long)}. */
    private static long _limit;
    /** Sum of the sizes of all buffers currently held in {@code _mQueue}. */
    private static long _totCachedSize;
    /** Operation memory limit; set from {@code OptimizerUtils.getLocalMemBudget()}. */
    private static long _opMemLimit;
    /** Eviction queue of (cache file name, buffer) pairs; also serves as the lock object. */
    private static CacheEvictionQueue _mQueue;
    /** Service used to serialize buffers and to delete evicted files. */
    private static CacheMaintenanceService _fClean;
    /** Cache file names of all currently pinned objects; these are skipped by {@link #makeSpace(long)}. */
    private static final List<String> _pinnedEntries = new ArrayList<>();
    /** Estimated size of pinned objects that were not present in the queue when pinned. */
    private static long _pinnedPhysicalMemSize = 0;
    /** Estimated size of pinned objects that were present in the queue when pinned. */
    private static long _pinnedVirtualMemSize = 0;

    /**
     * Pins the given object. Does nothing when caching is inactive.
     *
     * <p>The size estimate is added to the virtual pinned size if the object is
     * already in the queue; otherwise space is made for it and the estimate is
     * added to the physical pinned size. Afterwards output memory is reserved.
     *
     * @param cacheableData object to pin
     */
    public static void pin(CacheableData<?> cacheableData) {
        if (!CacheableData.isCachingActive()) {
            return;
        }
        long estimatedSize = OptimizerUtils.estimateSize(cacheableData.getDataCharacteristics());
        if (probe(cacheableData)) {
            _pinnedVirtualMemSize += estimatedSize;
        } else {
            makeSpace(estimatedSize);
            _pinnedPhysicalMemSize += estimatedSize;
        }
        _pinnedEntries.add(cacheableData.getCacheFilePathAndName());
        reserveOutputMem();
    }

    /**
     * Makes space for the part of the operation memory limit that is not yet
     * covered by pinned inputs. Only active when the UMM is enabled and caching
     * is active.
     */
    public static void reserveOutputMem() {
        if (OptimizerUtils.isUMMEnabled() && CacheableData.isCachingActive()) {
            long pinnedTotal = _pinnedVirtualMemSize + _pinnedPhysicalMemSize;
            makeSpace(_opMemLimit - pinnedTotal);
        }
    }

    /**
     * Unpins the given object if it is currently recorded as pinned and caching
     * is active, and subtracts its size estimate from the matching counter.
     *
     * <p>NOTE(review): virtual vs. physical is decided by probing the queue at
     * unpin time, which may differ from the state at pin time.
     *
     * @param cacheableData object to unpin
     */
    public static void unpin(CacheableData<?> cacheableData) {
        if (!CacheableData.isCachingActive()) {
            return;
        }
        String fileName = cacheableData.getCacheFilePathAndName();
        if (!_pinnedEntries.contains(fileName)) {
            return;
        }
        long estimatedSize = OptimizerUtils.estimateSize(cacheableData.getDataCharacteristics());
        if (probe(cacheableData)) {
            _pinnedVirtualMemSize -= estimatedSize;
        } else {
            _pinnedPhysicalMemSize -= estimatedSize;
        }
        _pinnedEntries.remove(fileName);
    }

    /**
     * Creates a fresh eviction queue and maintenance service, loads the limits
     * from {@code OptimizerUtils} and resets all accounting state.
     */
    public static void init() {
        _mQueue = new CacheEvictionQueue();
        _fClean = new CacheMaintenanceService();
        _limit = OptimizerUtils.getBufferPoolLimit();
        _opMemLimit = (long) OptimizerUtils.getLocalMemBudget();
        _totCachedSize = 0L;
        resetPinnedState();
    }

    /**
     * Clears the eviction queue, closes the maintenance service (each only if
     * present) and resets all accounting state.
     */
    public static void cleanup() {
        if (_mQueue != null) {
            _mQueue.clear();
        }
        if (_fClean != null) {
            _fClean.close();
        }
        _totCachedSize = 0L;
        resetPinnedState();
    }

    /**
     * Resets the pinned-size counters together with the pinned-entry list, so
     * that a stale entry cannot later be unpinned against zeroed counters.
     */
    private static void resetPinnedState() {
        _pinnedEntries.clear();
        _pinnedPhysicalMemSize = 0L;
        _pinnedVirtualMemSize = 0L;
    }

    /**
     * Prints the UMM meta data and one line per cached buffer to standard out.
     *
     * @param str label of the point in time / operation the status is printed for
     */
    public static void printStatus(String str) {
        System.out.println("UMM STATUS AT " + str + " --");
        synchronized (_mQueue) {
            System.out.println("\tUMM: Meta Data: UMM limit=" + _limit
                + ", size[bytes]=" + _totCachedSize
                + ", size[elements]=" + _mQueue.size()
                + ", pinned[elements]=" + _pinnedEntries.size()
                + ", pinned[bytes]=" + _pinnedPhysicalMemSize);

            // entries are numbered downwards, starting at the queue size
            int position = _mQueue.size();
            for (Map.Entry<String, ByteBuffer> entry : _mQueue.entrySet()) {
                String fileName = entry.getKey();
                ByteBuffer buffer = entry.getValue();
                boolean shallow = buffer.isShallow();
                String typeName = shallow ? buffer._cdata.getClass().getSimpleName() : "?";
                System.out.println("\tUMM: Cached element (" + position + "): " + fileName
                    + ", " + typeName + ", " + buffer.getSize() + ", " + shallow);
                position--;
            }
        }
    }

    /**
     * Overrides the UMM capacity limit.
     *
     * @param j new limit
     */
    public static void setUMMLimit(long j) {
        _limit = j;
    }

    /**
     * @return the current UMM capacity limit
     */
    public static long getUMMSize() {
        synchronized (_mQueue) {
            return _limit;
        }
    }

    /**
     * @return the limit minus the cached size and the physically pinned size
     */
    public static long getUMMFree() {
        synchronized (_mQueue) {
            return _limit - (_totCachedSize + _pinnedPhysicalMemSize);
        }
    }

    /**
     * Reads a cache block, preferring the in-memory queue over the local file.
     * Under the LRU buffer policy a queue hit moves the entry to the end of
     * the queue.
     *
     * @param str cache file name of the block
     * @param z flag forwarded to {@code LocalFileUtils.readCacheBlockFromLocal}
     *          (presumably selects the block type - confirm against that method)
     * @return the deserialized or file-loaded cache block
     * @throws IOException if reading the block fails
     */
    public static CacheBlock<?> readBlock(String str, boolean z) throws IOException {
        ByteBuffer buffer;
        synchronized (_mQueue) {
            buffer = _mQueue.get(str);
            if (CacheableData.CACHING_BUFFER_POLICY == LazyWriteBuffer.RPolicy.LRU && buffer != null) {
                // re-insert at the end of the queue
                _mQueue.remove(str);
                _mQueue.addLast(str, buffer);
            }
        }

        if (buffer == null) {
            CacheBlock<?> fromFile = LocalFileUtils.readCacheBlockFromLocal(str, z);
            if (DMLScript.STATISTICS) {
                CacheStatistics.incrementFSHits();
            }
            return fromFile;
        }

        CacheBlock<?> fromBuffer = buffer.deserializeBlock();
        if (DMLScript.STATISTICS) {
            CacheStatistics.incrementFSBuffHits();
        }
        return fromBuffer;
    }

    /**
     * @param cacheableData object to look up
     * @return {@code true} if the object's cache file name is a key of the queue
     */
    public static boolean probe(CacheableData<?> cacheableData) {
        return _mQueue.containsKey(cacheableData.getCacheFilePathAndName());
    }

    /**
     * Evicts unpinned buffers from the front of the queue until the free space
     * is no longer below the requested amount or the queue is empty.
     *
     * @param j requested amount of free space
     * @return number of evicted buffers ({@code 0} if enough space was already free)
     * @throws DMLRuntimeException if any step of the eviction fails
     */
    public static int makeSpace(long j) {
        if (getUMMFree() > j) {
            return 0;
        }
        int numEvicted = 0;
        try {
            synchronized (_mQueue) {
                while (getUMMFree() < j && !_mQueue.isEmpty()) {
                    Map.Entry<String, ByteBuffer> victim = _mQueue.removeFirstUnpinned(_pinnedEntries);
                    String fileName = victim.getKey();
                    ByteBuffer buffer = victim.getValue();
                    if (buffer != null) {
                        buffer.checkSerialized();
                        buffer.evictBuffer(fileName);
                        buffer.freeMemory();
                        _totCachedSize -= buffer.getSize();
                        numEvicted++;
                    }
                }
            }
            if (DMLScript.STATISTICS) {
                CacheStatistics.incrementFSWrites(numEvicted);
            }
            return numEvicted;
        } catch (Exception e) {
            throw new DMLRuntimeException("Eviction request of size " + (j - getUMMFree()) + " in the UMM failed.", e);
        }
    }

    /**
     * Stores a cache block. Blocks larger than the limit, or for which
     * {@code ByteBuffer.isValidCapacity} is false, are written directly to the
     * local file; all others are appended to the queue after making space and
     * handed to the maintenance service for serialization.
     *
     * @param str cache file name of the block
     * @param cacheBlock block to store
     * @return {@code 1} for a direct file write, otherwise the number of
     *         buffers evicted to make space
     * @throws IOException if the direct file write fails
     */
    public static int writeBlock(String str, CacheBlock<?> cacheBlock) throws IOException {
        long blockSize = getCacheBlockSize(cacheBlock);

        if (blockSize > _limit || !ByteBuffer.isValidCapacity(blockSize, cacheBlock)) {
            LocalFileUtils.writeCacheBlockToLocal(str, cacheBlock);
            if (DMLScript.STATISTICS) {
                CacheStatistics.incrementFSWrites();
            }
            return 1;
        }

        int numEvicted;
        ByteBuffer buffer = new ByteBuffer(blockSize);
        synchronized (_mQueue) {
            numEvicted = makeSpace(blockSize);
            _mQueue.addLast(str, buffer);
            _totCachedSize += blockSize;
        }
        // serialization happens outside the queue lock
        _fClean.serializeData(buffer, cacheBlock);
        if (DMLScript.STATISTICS) {
            CacheStatistics.incrementBPoolWrites();
        }
        return numEvicted;
    }

    /**
     * @param cacheBlock block to measure
     * @return the in-memory size for shallow-serialized blocks, otherwise the
     *         exact serialized size
     */
    public static long getCacheBlockSize(CacheBlock<?> cacheBlock) {
        if (cacheBlock.isShallowSerialize()) {
            return cacheBlock.getInMemorySize();
        }
        return cacheBlock.getExactSerializedSize();
    }

    /**
     * Removes a block from the queue and frees its memory; if the block is not
     * in the queue, the maintenance service is asked to delete its file.
     *
     * @param str cache file name of the block
     */
    public static void deleteBlock(String str) {
        boolean notInQueue = true;
        synchronized (_mQueue) {
            ByteBuffer buffer = (ByteBuffer) _mQueue.remove(str);
            if (buffer != null) {
                _totCachedSize -= buffer.getSize();
                notInQueue = false;
                buffer.freeMemory();
            }
        }
        if (notInQueue) {
            _fClean.deleteFile(str);
        }
    }

    /**
     * Not supported.
     *
     * @throws NotImplementedException always
     */
    public void deleteAll() {
        throw new NotImplementedException();
    }

    /**
     * Evicts every buffer in the queue, pinned or not, and keeps the cached
     * size in step with the queue content.
     *
     * @throws IOException if evicting a buffer fails
     */
    public static void forceEviction() throws IOException {
        // same lock and accounting as makeSpace, so getUMMFree() stays correct afterwards
        synchronized (_mQueue) {
            while (!_mQueue.isEmpty()) {
                Map.Entry<String, ByteBuffer> evicted = _mQueue.removeFirst();
                ByteBuffer buffer = evicted.getValue();
                if (buffer != null) {
                    buffer.checkSerialized();
                    buffer.evictBuffer(evicted.getKey());
                    buffer.freeMemory();
                    _totCachedSize -= buffer.getSize();
                }
            }
        }
    }
}
