package org.nd4j.linalg.memory.abstracts;

import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.lang.ref.ReferenceQueue;
import java.lang.ref.WeakReference;
import java.util.ArrayList;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.LinkedTransferQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import lombok.NonNull;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.Pointer;
import org.nd4j.linalg.api.buffer.DataBuffer;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.memory.conf.WorkspaceConfiguration;
import org.nd4j.linalg.api.memory.enums.AllocationPolicy;
import org.nd4j.linalg.api.memory.enums.LearningPolicy;
import org.nd4j.linalg.api.memory.enums.LocationPolicy;
import org.nd4j.linalg.api.memory.enums.MemoryKind;
import org.nd4j.linalg.api.memory.enums.ResetPolicy;
import org.nd4j.linalg.api.memory.enums.SpillPolicy;
import org.nd4j.linalg.api.memory.pointers.PagedPointer;
import org.nd4j.linalg.api.memory.pointers.PointersPair;
import org.nd4j.linalg.exception.ND4JIllegalStateException;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.memory.MemoryManager;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/nd4j/linalg/memory/abstracts/Nd4jWorkspace.class */
public abstract class Nd4jWorkspace implements MemoryWorkspace {
    private static final Logger log = LoggerFactory.getLogger(Nd4jWorkspace.class);
    // Device resolved through the affinity manager for the thread that constructed this workspace.
    protected int deviceId;
    // Id of the thread that constructed this workspace.
    protected Long threadId;
    // SCOPED by default; CIRCULAR when the reset policy is ENDOFBUFFER_REACHED (set in the constructor).
    protected MemoryWorkspace.Type workspaceType;
    // Not referenced inside this class; presumably consumed by subclasses -- TODO confirm.
    protected static final long SAFETY_OFFSET = 1024;
    protected String id;
    // Size currently assigned to the workspace buffer; 0 means "not sized yet".
    protected AtomicLong currentSize;
    // Next free offset inside the buffer; alloc() advances hostOffset and mirrors it into deviceOffset.
    protected AtomicLong hostOffset;
    protected AtomicLong deviceOffset;
    protected PointersPair workspace;
    protected MemoryManager memoryManager;
    protected AtomicBoolean isLearning;
    // When false, alloc() bypasses the workspace buffer and allocates through the memory manager.
    protected AtomicBoolean isUsed;
    // Number of allocations made while the workspace was disabled; reset on scope enter/close.
    protected AtomicLong disabledCounter;
    // Number of completed close() cycles.
    protected AtomicLong cyclesCount;
    // Incremented in close() every stepsNumber cycles.
    protected AtomicLong stepsCount;
    protected int stepsNumber;
    // Allocation totals for the previous and the current cycle.
    protected AtomicLong lastCycleAllocations;
    protected AtomicLong cycleAllocations;
    protected AtomicLong spilledAllocationsSize;
    protected AtomicLong pinnedAllocationsSize;
    // Largest cycleAllocations value observed by close().
    protected AtomicLong maxCycle;
    protected AtomicBoolean resetPlanned;
    protected AtomicBoolean isOpen;
    protected AtomicBoolean isInit;
    // Set once the over-allocation factor has been applied to currentSize.
    protected AtomicBoolean isOver;
    protected AtomicBoolean isBorrowed;
    // Nesting depth of re-entrant / out-of-scope tagged uses; close() unwinds it first.
    protected AtomicInteger tagScope;
    protected AtomicBoolean isDebug;
    protected AtomicInteger externalCount;
    protected AtomicInteger pinnedCount;
    protected AtomicBoolean trimmedMode;
    // Value of stepsCount at the moment trimmedMode was switched on.
    protected AtomicLong trimmedStep;
    protected final WorkspaceConfiguration workspaceConfiguration;
    // Allocations made outside the workspace buffer (spilled or made while disabled).
    protected List<PointersPair> externalAllocations;
    // Allocations made outside the buffer while in trimmed mode / past the initial block.
    protected Queue<PointersPair> pinnedAllocations;
    // Workspace that was current when this one was entered; restored by close().
    protected MemoryWorkspace previousWorkspace;
    // Workspace that was current when this one was borrowed; restored by close().
    protected MemoryWorkspace borrowingWorkspace;
    protected AtomicLong initialBlockSize;
    protected String guid;
    // Backing file used when the location policy is MMAP.
    protected File tempFile;
    // Incremented each time a new (non re-entrant) scope is entered.
    protected AtomicLong generationId;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* renamed from: org.nd4j.linalg.memory.abstracts.Nd4jWorkspace$1, reason: invalid class name */
    /* loaded from: input_file:org/nd4j/linalg/memory/abstracts/Nd4jWorkspace$1.class */
    // Compiler-generated switch map: translates SpillPolicy ordinals into the dense case labels
    // (1 = REALLOCATE, 2 = EXTERNAL, 3 = FAIL) used by the switch statement in alloc().
    public static /* synthetic */ class AnonymousClass1 {
        static final /* synthetic */ int[] $SwitchMap$org$nd4j$linalg$api$memory$enums$SpillPolicy = new int[SpillPolicy.values().length];

        static {
            // Each assignment is guarded separately so that an enum constant missing at runtime
            // (NoSuchFieldError) only leaves its own slot at 0, which falls into the switch default.
            try {
                $SwitchMap$org$nd4j$linalg$api$memory$enums$SpillPolicy[SpillPolicy.REALLOCATE.ordinal()] = 1;
            } catch (NoSuchFieldError e) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$memory$enums$SpillPolicy[SpillPolicy.EXTERNAL.ordinal()] = 2;
            } catch (NoSuchFieldError e2) {
            }
            try {
                $SwitchMap$org$nd4j$linalg$api$memory$enums$SpillPolicy[SpillPolicy.FAIL.ordinal()] = 3;
            } catch (NoSuchFieldError e3) {
            }
        }
    }

    /* loaded from: input_file:org/nd4j/linalg/memory/abstracts/Nd4jWorkspace$GarbageWorkspaceReference.class */
    /**
     * Weak reference to a workspace that additionally keeps the workspace's pointers, id,
     * thread id and out-of-buffer allocation collections, so they stay reachable through the
     * reference after the workspace object itself has been collected.
     */
    public static class GarbageWorkspaceReference extends WeakReference<MemoryWorkspace> {
        private PointersPair pointersPair;
        private String id;
        private Long threadId;
        private Queue<PointersPair> pinnedPointers;
        private List<PointersPair> externalPointers;
        private String key;

        /**
         * Captures the state of {@code memoryWorkspace}, which must be an {@link Nd4jWorkspace}.
         *
         * @param memoryWorkspace workspace to track
         * @param referenceQueue  queue the reference is registered with
         */
        public GarbageWorkspaceReference(MemoryWorkspace memoryWorkspace, ReferenceQueue<? super MemoryWorkspace> referenceQueue) {
            super(memoryWorkspace, referenceQueue);
            Nd4jWorkspace source = (Nd4jWorkspace) memoryWorkspace;
            this.pointersPair = source.workspace;
            this.id = memoryWorkspace.getId();
            this.threadId = memoryWorkspace.getThreadId();
            this.pinnedPointers = source.pinnedAllocations;
            this.externalPointers = source.externalAllocations;
            // Key format: "<id>_<threadId>".
            this.key = this.id + "_" + this.threadId;
        }

        public PointersPair getPointersPair() {
            return pointersPair;
        }

        public String getId() {
            return id;
        }

        public Long getThreadId() {
            return threadId;
        }

        public Queue<PointersPair> getPinnedPointers() {
            return pinnedPointers;
        }

        public List<PointersPair> getExternalPointers() {
            return externalPointers;
        }

        public String getKey() {
            return key;
        }

        public void setPointersPair(PointersPair pointersPair) {
            this.pointersPair = pointersPair;
        }

        public void setId(String str) {
            this.id = str;
        }

        public void setThreadId(Long l) {
            this.threadId = l;
        }

        public void setPinnedPointers(Queue<PointersPair> queue) {
            this.pinnedPointers = queue;
        }

        public void setExternalPointers(List<PointersPair> list) {
            this.externalPointers = list;
        }

        public void setKey(String str) {
            this.key = str;
        }

        /** Field-by-field equality over all six captured properties, compared in declaration order. */
        @Override
        public boolean equals(Object obj) {
            if (obj == this) {
                return true;
            }
            if (!(obj instanceof GarbageWorkspaceReference)) {
                return false;
            }
            GarbageWorkspaceReference other = (GarbageWorkspaceReference) obj;
            return other.canEqual(this)
                    && sameValue(getPointersPair(), other.getPointersPair())
                    && sameValue(getId(), other.getId())
                    && sameValue(getThreadId(), other.getThreadId())
                    && sameValue(getPinnedPointers(), other.getPinnedPointers())
                    && sameValue(getExternalPointers(), other.getExternalPointers())
                    && sameValue(getKey(), other.getKey());
        }

        protected boolean canEqual(Object obj) {
            return obj instanceof GarbageWorkspaceReference;
        }

        /** Hash over the same six properties, using multiplier 59 and 43 as the null placeholder. */
        @Override
        public int hashCode() {
            int result = 1;
            result = mix(result, getPointersPair());
            result = mix(result, getId());
            result = mix(result, getThreadId());
            result = mix(result, getPinnedPointers());
            result = mix(result, getExternalPointers());
            return mix(result, getKey());
        }

        @Override
        public String toString() {
            return "Nd4jWorkspace.GarbageWorkspaceReference(pointersPair=" + getPointersPair() + ", id=" + getId() + ", threadId=" + getThreadId() + ", pinnedPointers=" + getPinnedPointers() + ", externalPointers=" + getExternalPointers() + ", key=" + getKey() + ")";
        }

        /** Null-tolerant equality: two nulls match, otherwise delegates to {@code left.equals(right)}. */
        private static boolean sameValue(Object left, Object right) {
            return left == null ? right == null : left.equals(right);
        }

        /** Folds one property into the running hash. */
        private static int mix(int accumulated, Object value) {
            return (accumulated * 59) + (value == null ? 43 : value.hashCode());
        }
    }

    /**
     * Creates a workspace with the id {@code "DefaultWorkspace"}.
     *
     * @param workspaceConfiguration configuration to use; must not be null
     * @throws NullPointerException if {@code workspaceConfiguration} is null
     */
    public Nd4jWorkspace(@NonNull WorkspaceConfiguration workspaceConfiguration) {
        // The delegated constructor already rejects a null configuration with
        // NullPointerException("configuration"); a second check placed after this call
        // could never be reached with a null argument, so it has been removed.
        this(workspaceConfiguration, "DefaultWorkspace");
    }

    public Nd4jWorkspace(@NonNull WorkspaceConfiguration workspaceConfiguration, @NonNull String str) {
        this.workspaceType = MemoryWorkspace.Type.SCOPED;
        this.currentSize = new AtomicLong(0L);
        this.hostOffset = new AtomicLong(0L);
        this.deviceOffset = new AtomicLong(0L);
        this.workspace = new PointersPair();
        this.isLearning = new AtomicBoolean(true);
        this.isUsed = new AtomicBoolean(true);
        this.disabledCounter = new AtomicLong(0L);
        this.cyclesCount = new AtomicLong(0L);
        this.stepsCount = new AtomicLong(0L);
        this.stepsNumber = 1;
        this.lastCycleAllocations = new AtomicLong(0L);
        this.cycleAllocations = new AtomicLong(0L);
        this.spilledAllocationsSize = new AtomicLong(0L);
        this.pinnedAllocationsSize = new AtomicLong(0L);
        this.maxCycle = new AtomicLong(0L);
        this.resetPlanned = new AtomicBoolean(false);
        this.isOpen = new AtomicBoolean(false);
        this.isInit = new AtomicBoolean(false);
        this.isOver = new AtomicBoolean(false);
        this.isBorrowed = new AtomicBoolean(false);
        this.tagScope = new AtomicInteger(0);
        this.isDebug = new AtomicBoolean(false);
        this.externalCount = new AtomicInteger(0);
        this.pinnedCount = new AtomicInteger(0);
        this.trimmedMode = new AtomicBoolean(false);
        this.trimmedStep = new AtomicLong(0L);
        this.externalAllocations = new ArrayList();
        this.pinnedAllocations = new LinkedTransferQueue();
        this.initialBlockSize = new AtomicLong(0L);
        this.generationId = new AtomicLong(0L);
        if (workspaceConfiguration == null) {
            throw new NullPointerException("configuration");
        }
        if (str == null) {
            throw new NullPointerException("workspaceId");
        }
        this.workspaceConfiguration = workspaceConfiguration;
        this.id = str;
        this.threadId = Long.valueOf(Thread.currentThread().getId());
        this.guid = Nd4j.getWorkspaceManager().getUUID();
        this.memoryManager = Nd4j.getMemoryManager();
        this.deviceId = Nd4j.getAffinityManager().getDeviceForCurrentThread().intValue();
        this.currentSize.set(this.workspaceConfiguration.getInitialSize());
        if (this.workspaceConfiguration.getPolicyReset() == ResetPolicy.ENDOFBUFFER_REACHED) {
            this.workspaceType = MemoryWorkspace.Type.CIRCULAR;
        } else {
            this.workspaceType = MemoryWorkspace.Type.SCOPED;
        }
        if (this.workspaceConfiguration.getPolicyReset() == ResetPolicy.ENDOFBUFFER_REACHED && this.workspaceConfiguration.getPolicyAllocation() == AllocationPolicy.OVERALLOCATE) {
            if (this.workspaceConfiguration.getOverallocationLimit() < 1.0d) {
                throw new ND4JIllegalStateException("For cyclic workspace overallocation should be positive integral value.");
            }
            this.stepsNumber = (int) (this.workspaceConfiguration.getOverallocationLimit() + 1.0d);
            log.debug("Steps: {}", Integer.valueOf(this.stepsNumber));
        }
        if (workspaceConfiguration.getPolicyLocation() == LocationPolicy.MMAP) {
            if (workspaceConfiguration.getTempFilePath() != null) {
                this.tempFile = new File(workspaceConfiguration.getTempFilePath());
                if (this.tempFile.length() != 0 && this.tempFile.length() >= workspaceConfiguration.getInitialSize()) {
                    workspaceConfiguration.setInitialSize(this.tempFile.length());
                } else {
                    if (workspaceConfiguration.getInitialSize() <= 0) {
                        throw new ND4JIllegalStateException("Memory-mapped file should have positive length.");
                    }
                    try {
                        fillFile(this.tempFile, workspaceConfiguration.getInitialSize());
                    } catch (Exception e) {
                        throw new RuntimeException(e);
                    }
                }
            } else {
                if (workspaceConfiguration.getInitialSize() <= 0) {
                    throw new ND4JIllegalStateException("MMAP target file path should be non-null or workspace initialSize should be >0 for temp file");
                }
                try {
                    this.tempFile = File.createTempFile("workspace", "tempMMAP");
                    this.tempFile.deleteOnExit();
                    fillFile(this.tempFile, workspaceConfiguration.getInitialSize());
                } catch (Exception e2) {
                    throw new RuntimeException(e2);
                }
            }
        }
        init();
    }

    /** Returns the workspace type selected in the constructor (SCOPED or CIRCULAR). */
    public MemoryWorkspace.Type getWorkspaceType() {
        return workspaceType;
    }

    /* JADX WARN: Failed to calculate best type for var: r11v1 ??
    java.lang.NullPointerException
     */
    /* JADX WARN: Failed to calculate best type for var: r12v0 ??
    java.lang.NullPointerException
     */
    /* JADX WARN: Multi-variable type inference failed. Error: java.lang.NullPointerException: Cannot invoke "jadx.core.dex.instructions.args.RegisterArg.getSVar()" because the return value of "jadx.core.dex.nodes.InsnNode.getResult()" is null
    	at jadx.core.dex.visitors.typeinference.AbstractTypeConstraint.collectRelatedVars(AbstractTypeConstraint.java:31)
    	at jadx.core.dex.visitors.typeinference.AbstractTypeConstraint.<init>(AbstractTypeConstraint.java:19)
    	at jadx.core.dex.visitors.typeinference.TypeSearch$1.<init>(TypeSearch.java:376)
    	at jadx.core.dex.visitors.typeinference.TypeSearch.makeMoveConstraint(TypeSearch.java:376)
    	at jadx.core.dex.visitors.typeinference.TypeSearch.makeConstraint(TypeSearch.java:361)
    	at jadx.core.dex.visitors.typeinference.TypeSearch.collectConstraints(TypeSearch.java:341)
    	at java.base/java.util.ArrayList.forEach(ArrayList.java:1596)
    	at jadx.core.dex.visitors.typeinference.TypeSearch.run(TypeSearch.java:60)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.runMultiVariableSearch(FixTypesVisitor.java:116)
    	at jadx.core.dex.visitors.typeinference.FixTypesVisitor.visit(FixTypesVisitor.java:91)
     */
    /* JADX WARN: Not initialized variable reg: 11, insn: 0x0082: MOVE (r0 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]) = (r11 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]) A[TRY_LEAVE], block:B:47:0x0082 */
    /* JADX WARN: Not initialized variable reg: 12, insn: 0x0087: MOVE (r0 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]) = (r12 I:??[int, float, boolean, short, byte, char, OBJECT, ARRAY]), block:B:49:0x0087 */
    /* JADX WARN: Type inference failed for: r11v1, types: [java.io.BufferedOutputStream] */
    /* JADX WARN: Type inference failed for: r12v0, types: [java.lang.Throwable] */
    public static void fillFile(File file, long j) throws Exception {
        ?? r11;
        ?? r12;
        byte[] bArr = new byte[16384];
        for (int i = 0; i < bArr.length; i++) {
            bArr[i] = 0;
        }
        FileOutputStream fileOutputStream = new FileOutputStream(file);
        Throwable th = null;
        try {
            try {
                BufferedOutputStream bufferedOutputStream = new BufferedOutputStream(fileOutputStream);
                Throwable th2 = null;
                for (long j2 = 0; j2 < j; j2 += bArr.length) {
                    fileOutputStream.write(bArr);
                }
                if (bufferedOutputStream != null) {
                    if (0 != 0) {
                        try {
                            bufferedOutputStream.close();
                        } catch (Throwable th3) {
                            th2.addSuppressed(th3);
                        }
                    } else {
                        bufferedOutputStream.close();
                    }
                }
                if (fileOutputStream != null) {
                    if (0 == 0) {
                        fileOutputStream.close();
                        return;
                    }
                    try {
                        fileOutputStream.close();
                    } catch (Throwable th4) {
                        th.addSuppressed(th4);
                    }
                }
            } catch (Throwable th5) {
                if (r11 != 0) {
                    if (r12 != 0) {
                        try {
                            r11.close();
                        } catch (Throwable th6) {
                            r12.addSuppressed(th6);
                        }
                    } else {
                        r11.close();
                    }
                }
                throw th5;
            }
        } catch (Throwable th7) {
            if (fileOutputStream != null) {
                if (0 != 0) {
                    try {
                        fileOutputStream.close();
                    } catch (Throwable th8) {
                        th.addSuppressed(th8);
                    }
                } else {
                    fileOutputStream.close();
                }
            }
            throw th7;
        }
    }

    /** Returns the generation counter, which is incremented each time a new scope is entered. */
    public long getGenerationId() {
        return generationId.get();
    }

    /** Returns the current value of the step counter. */
    public long getStepNumber() {
        return stepsCount.get();
    }

    /** Returns the accumulated size of spilled allocations. */
    public long getSpilledSize() {
        return spilledAllocationsSize.get();
    }

    /** Returns the accumulated size of pinned allocations. */
    public long getPinnedSize() {
        return pinnedAllocationsSize.get();
    }

    /** Returns the initial block size recorded by the last workspace initialization. */
    public long getInitialBlockSize() {
        return initialBlockSize.get();
    }

    /** Returns the workspace that was current when this one was last entered, or null. */
    public MemoryWorkspace getParentWorkspace() {
        return previousWorkspace;
    }

    /** Returns the current device-side offset into the workspace buffer. */
    public long getDeviceOffset() {
        return deviceOffset.get();
    }

    /** Returns the current host-side offset into the workspace buffer. */
    public long getHostOffset() {
        return hostOffset.get();
    }

    /** Returns the size currently assigned to the workspace. */
    public long getCurrentSize() {
        return currentSize.get();
    }

    /**
     * Finalizes {@code currentSize}: applies the over-allocation factor once (when the
     * allocation policy is OVERALLOCATE and the limit is positive) and then caps the result at
     * the configured maximum size, if one is set. Does nothing while the size is not positive.
     */
    protected void init() {
        if (currentSize.get() <= 0) {
            return;
        }

        if (!isOver.get()
                && workspaceConfiguration.getPolicyAllocation() == AllocationPolicy.OVERALLOCATE
                && workspaceConfiguration.getOverallocationLimit() > 0.0d) {
            long extra = (long) (currentSize.get() * workspaceConfiguration.getOverallocationLimit());
            currentSize.addAndGet(extra);
            isOver.set(true);
        }

        // A non-positive maximum means "unbounded".
        boolean capped = workspaceConfiguration.getMaxSize() > 0
                && currentSize.get() > workspaceConfiguration.getMaxSize();
        if (capped) {
            currentSize.set(workspaceConfiguration.getMaxSize());
        }
    }

    /**
     * Allocates {@code j} bytes of host memory; shorthand for the four-argument overload with
     * {@link MemoryKind#HOST}.
     *
     * @param j    requested size in bytes
     * @param type element data type
     * @param z    whether the memory must be zero-initialized
     */
    public PagedPointer alloc(long j, DataBuffer.Type type, boolean z) {
        final MemoryKind kind = MemoryKind.HOST;
        return alloc(j, kind, type, z);
    }

    /** Switches the verbose allocation logging of this workspace on or off. */
    public void enableDebug(boolean z) {
        isDebug.set(z);
    }

    /**
     * Allocates a chunk of memory, preferably from the workspace buffer.
     *
     * <p>The request is first padded up to a multiple of 8 bytes. It is then served, in order
     * of preference, from: the memory manager directly when the workspace is switched off; the
     * workspace buffer when the chunk fits; the buffer again after a reset when the reset
     * policy is ENDOFBUFFER_REACHED; or, as a spilled/pinned allocation, from the memory manager
     * when the spill policy allows it.
     *
     * <p>Note: {@code memoryKind} is not consulted by this implementation; every path here
     * allocates with {@link MemoryKind#HOST}.
     *
     * @param j          requested size in bytes
     * @param memoryKind requested memory kind (only forwarded on the retry-after-reset path)
     * @param type       element data type, used to compute the capacity in elements
     * @param z          whether the memory must be zero-initialized
     * @return pointer to the allocated chunk
     * @throws ND4JIllegalStateException if the chunk does not fit and the spill policy is FAIL
     */
    public PagedPointer alloc(long j, MemoryKind memoryKind, DataBuffer.Type type, boolean z) {
        // Round the request up to the next multiple of 8. Adding the remainder itself (as the
        // previous code did) does not produce an aligned size, e.g. 9 -> 10 instead of 16.
        long remainder = j % 8;
        if (remainder != 0) {
            j += 8 - remainder;
        }
        long numElements = j / Nd4j.sizeOfDataType(type);

        if (!this.isUsed.get()) {
            // Workspace is switched off: allocate outside the buffer and track it as external.
            if (this.disabledCounter.incrementAndGet() % 10 == 0) {
                log.warn("Workspace was turned off, and wasn't enabled after {} allocations", Long.valueOf(this.disabledCounter.get()));
            }
            PagedPointer detached = new PagedPointer(this.memoryManager.allocate(j, MemoryKind.HOST, z), numElements);
            this.externalAllocations.add(new PointersPair(detached, (PagedPointer) null));
            return detached;
        }

        // "Trimmed": a circular workspace whose current cycle outgrew the initial block,
        // or one that is already in trimmed mode.
        boolean trimmed = (this.workspaceConfiguration.getPolicyReset() == ResetPolicy.ENDOFBUFFER_REACHED && j + this.cycleAllocations.get() > this.initialBlockSize.get() && this.initialBlockSize.get() > 0) || this.trimmedMode.get();
        if (trimmed && this.workspaceConfiguration.getPolicySpill() == SpillPolicy.REALLOCATE && !this.trimmedMode.get()) {
            this.trimmedMode.set(true);
            this.trimmedStep.set(this.stepsCount.get());
        }

        if (this.hostOffset.get() + j <= this.currentSize.get() && !trimmed) {
            // Fast path: carve the chunk out of the workspace buffer.
            this.cycleAllocations.addAndGet(j);
            long previousOffset = this.hostOffset.getAndAdd(j);
            this.deviceOffset.set(this.hostOffset.get());
            PagedPointer chunk = this.workspace.getHostPointer().withOffset(previousOffset, numElements);
            if (this.isDebug.get()) {
                log.info("Workspace [{}]: Allocating array of {} bytes, capacity of {} elements, prevOffset: {}; currentOffset: {}; address: {}", new Object[]{this.id, Long.valueOf(j), Long.valueOf(numElements), Long.valueOf(previousOffset), Long.valueOf(this.hostOffset.get()), Long.valueOf(chunk.address())});
            }
            if (z) {
                Pointer.memset(chunk, 0, j);
            }
            return chunk;
        }

        if (this.workspaceConfiguration.getPolicyReset() == ResetPolicy.ENDOFBUFFER_REACHED && this.currentSize.get() > 0 && !trimmed) {
            // Circular buffer ran out of room: rewind and retry.
            reset();
            this.resetPlanned.set(true);
            return alloc(j, memoryKind, type, z);
        }

        if (trimmed) {
            this.pinnedAllocationsSize.addAndGet(j);
        } else {
            this.spilledAllocationsSize.addAndGet(j);
        }
        if (this.isDebug.get()) {
            log.info("Workspace [{}]: step: {}, spilled  {} bytes, capacity of {} elements", new Object[]{this.id, Long.valueOf(this.stepsCount.get()), Long.valueOf(j), Long.valueOf(numElements)});
        }

        // Switch on the enum itself instead of going through the synthetic ordinal map.
        switch (this.workspaceConfiguration.getPolicySpill()) {
            case REALLOCATE:
            case EXTERNAL:
                this.cycleAllocations.addAndGet(j);
                if (trimmed) {
                    this.pinnedCount.incrementAndGet();
                    PagedPointer pinned = new PagedPointer(this.memoryManager.allocate(j, MemoryKind.HOST, z), numElements);
                    this.pinnedAllocations.add(new PointersPair(Long.valueOf(this.stepsCount.get()), Long.valueOf(j), pinned, (PagedPointer) null));
                    return pinned;
                }
                this.externalCount.incrementAndGet();
                PagedPointer spilled = new PagedPointer(this.memoryManager.allocate(j, MemoryKind.HOST, z), numElements);
                this.externalAllocations.add(new PointersPair(spilled, (PagedPointer) null));
                return spilled;
            case FAIL:
            default:
                throw new ND4JIllegalStateException("Can't allocate memory: Workspace is full");
        }
    }

    /**
     * Intentionally a no-op in this class: the pointer is ignored and nothing is released here.
     *
     * @param pointer pointer to release (unused)
     */
    public void free(Pointer pointer) {
    }

    /**
     * Recomputes the workspace size from the allocation statistics gathered so far and hands
     * over to {@link #init()}.
     *
     * <p>Steps, in order: optionally tear down the current buffer (when it is too small under
     * the REALLOCATE spill policy, or when trimmed mode has lasted more than two steps); bail
     * out if the workspace is already initialized or the learning policy is NONE; derive
     * {@code currentSize} from {@code maxCycle} (capped by the configured maximum, padded for
     * circular workspaces, over-allocated if configured, raised to the configured minimum);
     * release external allocations when applicable; call {@code init()}.
     */
    public void initializeWorkspace() {
        // Buffer is smaller than what was needed, growth is allowed and this is not a circular workspace: drop it.
        if ((this.currentSize.get() < this.maxCycle.get() || this.currentSize.get() < this.cycleAllocations.get()) && this.workspaceConfiguration.getPolicySpill() == SpillPolicy.REALLOCATE && ((this.workspaceConfiguration.getMaxSize() == 0 || this.maxCycle.get() < this.workspaceConfiguration.getMaxSize()) && this.workspaceConfiguration.getPolicyReset() != ResetPolicy.ENDOFBUFFER_REACHED)) {
            destroyWorkspace(true);
            this.isInit.set(false);
        }
        // Trimmed mode has been active for more than two steps: drop the buffer but keep external allocations.
        if (this.trimmedMode.get() && this.trimmedStep.get() + 2 < this.stepsCount.get()) {
            destroyWorkspace(false);
            this.isInit.set(false);
            this.isOver.set(false);
        }
        if (this.isInit.get() || this.workspaceConfiguration.getPolicyLearning() == LearningPolicy.NONE) {
            return;
        }
        // New size is the largest cycle seen so far, capped by maxSize when one is configured.
        if (this.workspaceConfiguration.getMaxSize() > 0) {
            this.currentSize.set(Math.min(this.maxCycle.get(), this.workspaceConfiguration.getMaxSize()));
        } else {
            this.currentSize.set(this.maxCycle.get());
        }
        if (this.workspaceConfiguration.getPolicyReset() == ResetPolicy.ENDOFBUFFER_REACHED) {
            // Circular workspace: add 30% headroom, then pad to a multiple of 8
            // (a size that is already a multiple of 8 still grows by 8).
            this.currentSize.set((long) (this.currentSize.get() * 1.3d));
            this.currentSize.addAndGet(8 - (this.currentSize.get() % 8));
            this.maxCycle.set(this.currentSize.get());
        }
        this.initialBlockSize.set(this.currentSize.get());
        if (!this.isOver.get() && this.workspaceConfiguration.getPolicyAllocation() == AllocationPolicy.OVERALLOCATE && this.workspaceConfiguration.getOverallocationLimit() > 0.0d && this.currentSize.get() > 0) {
            this.currentSize.set(this.currentSize.get() + ((long) (this.currentSize.get() * this.workspaceConfiguration.getOverallocationLimit())));
            // Marks over-allocation as applied so that init() below does not apply it a second time.
            this.isOver.set(true);
        }
        if (this.workspaceConfiguration.getMinSize() > 0 && this.currentSize.get() < this.workspaceConfiguration.getMinSize()) {
            this.currentSize.set(this.workspaceConfiguration.getMinSize());
        }
        if (this.externalCount.get() > 0 && (this.workspaceConfiguration.getPolicyReset() == ResetPolicy.BLOCK_LEFT || this.resetPlanned.get())) {
            clearExternalAllocations();
            this.resetPlanned.set(false);
        }
        init();
    }

    /** Returns the current value of the external (spilled) allocation counter. */
    public int getNumberOfExternalAllocations() {
        return externalCount.get();
    }

    /** Returns the current value of the pinned allocation counter. */
    public int getNumberOfPinnedAllocations() {
        return pinnedCount.get();
    }

    /** Destroys the workspace buffer and also releases the external allocations. */
    public void destroyWorkspace() {
        final boolean releaseExternals = true;
        destroyWorkspace(releaseExternals);
    }

    /**
     * Releases the workspace buffer, sets the size to 0 and rewinds the offsets.
     *
     * <p>The underlying memory is deallocated only when the host pointer wraps a
     * {@link BytePointer}; other pointer kinds are merely dropped here.
     *
     * @param z when true, external allocations are cleared as well
     */
    public void destroyWorkspace(boolean z) {
        PagedPointer host = workspace.getHostPointer();
        if (host != null) {
            Pointer backing = host.getOriginalPointer();
            // instanceof is false for null, so this also covers a missing original pointer.
            if (backing instanceof BytePointer) {
                backing.deallocate();
            }
        }

        workspace.setHostPointer((PagedPointer) null);
        currentSize.set(0L);
        reset();

        if (z) {
            clearExternalAllocations();
        }
    }

    /**
     * Makes this workspace the current one in "borrowed" mode, remembering the workspace that
     * was current before so that {@link #close()} can restore it.
     *
     * @return this workspace
     * @throws ND4JIllegalStateException if the workspace is already borrowed
     */
    public MemoryWorkspace notifyScopeBorrowed() {
        if (isBorrowed.get()) {
            throw new ND4JIllegalStateException("Workspace [" + id + "]: Can't borrow from borrowed workspace");
        }

        MemoryWorkspace lender = Nd4j.getMemoryManager().getCurrentWorkspace();
        borrowingWorkspace = lender;
        isBorrowed.set(true);
        Nd4j.getMemoryManager().setCurrentWorkspace(this);
        return this;
    }

    /** Returns the number of cycles completed so far (incremented by each full close). */
    public long getCyclesCount() {
        return cyclesCount.get();
    }

    /**
     * Leaves the current scope of this workspace.
     *
     * <p>A borrowed scope or a tagged (nested) scope is simply unwound. Otherwise the method
     * commits pending executioner work, restores the previous workspace, updates the cycle
     * and step counters and the max-cycle statistic, (re)initializes the workspace when the
     * learning policy calls for it, releases pinned allocations, handles the exit from trimmed
     * mode, and finally resets or aligns the offsets according to the reset policy.
     */
    public void close() {
        // Borrowed scope: hand control back to the workspace that was current before borrowing.
        if (this.isBorrowed.get()) {
            this.isBorrowed.set(false);
            Nd4j.getMemoryManager().setCurrentWorkspace(this.borrowingWorkspace);
            return;
        }
        // Tagged / re-entrant scope: only unwind one nesting level.
        if (this.tagScope.get() > 0) {
            if (this.tagScope.decrementAndGet() == 0) {
                Nd4j.getMemoryManager().setCurrentWorkspace(this);
                return;
            }
            return;
        }
        Nd4j.getExecutioner().commit();
        Nd4j.getMemoryManager().setCurrentWorkspace(this.previousWorkspace);
        this.isOpen.set(false);
        this.cyclesCount.incrementAndGet();
        // A step is completed every stepsNumber cycles, starting after the first cycle.
        // NOTE(review): non-short-circuit '&' evaluates the modulo even when the left side is
        // false; harmless while stepsNumber >= 1 -- confirm before changing to '&&'.
        if ((this.cyclesCount.get() > 1) & ((this.cyclesCount.get() - 1) % ((long) this.stepsNumber) == 0)) {
            this.stepsCount.incrementAndGet();
        }
        if (!this.isUsed.get()) {
            log.warn("Workspace was turned off, and wasn't ever turned on back again");
            this.isUsed.set(true);
        }
        // Track the largest amount allocated within a single cycle.
        if (this.cycleAllocations.get() > this.maxCycle.get()) {
            if (this.isDebug.get()) {
                log.info("Workspace [{}] device_{}, current cycle: {}; max cycle: {}", new Object[]{this.id, Nd4j.getAffinityManager().getDeviceForCurrentThread(), Long.valueOf(this.cycleAllocations.get()), Long.valueOf(this.maxCycle.get())});
            }
            this.maxCycle.set(this.cycleAllocations.get());
        }
        if (this.workspaceConfiguration.getPolicyLearning() != LearningPolicy.NONE && this.maxCycle.get() > 0) {
            if (this.externalCount.get() > 0 && (this.workspaceConfiguration.getPolicyReset() == ResetPolicy.BLOCK_LEFT || this.resetPlanned.get())) {
                clearExternalAllocations();
                this.resetPlanned.set(false);
            }
            // Initialize once the learning phase is over (OVER_TIME: after the configured number
            // of cycles; FIRST_LOOP: while still unsized), or re-size under the REALLOCATE policy.
            if ((this.workspaceConfiguration.getPolicyLearning() == LearningPolicy.OVER_TIME && this.workspaceConfiguration.getCyclesBeforeInitialization() == this.cyclesCount.intValue()) || (this.workspaceConfiguration.getPolicyLearning() == LearningPolicy.FIRST_LOOP && this.currentSize.get() == 0)) {
                initializeWorkspace();
            } else if (this.currentSize.get() > 0 && this.cycleAllocations.get() > 0 && this.workspaceConfiguration.getPolicySpill() == SpillPolicy.REALLOCATE && this.workspaceConfiguration.getPolicyReset() != ResetPolicy.ENDOFBUFFER_REACHED) {
                initializeWorkspace();
            }
        }
        if (this.pinnedCount.get() > 0) {
            clearPinnedAllocations(false);
        }
        // Trimmed mode has lasted more than two steps: rebuild the workspace and leave the mode.
        if (this.trimmedMode.get() && this.trimmedStep.get() + 2 < this.stepsCount.get()) {
            this.initialBlockSize.set(this.maxCycle.get());
            initializeWorkspace();
            this.trimmedMode.set(false);
            this.trimmedStep.set(0L);
            reset();
        }
        this.lastCycleAllocations.set(this.cycleAllocations.get());
        this.disabledCounter.set(0L);
        if (this.workspaceConfiguration.getPolicyReset() == ResetPolicy.BLOCK_LEFT) {
            reset();
        } else if (this.workspaceConfiguration.getPolicyReset() == ResetPolicy.ENDOFBUFFER_REACHED && this.currentSize.get() > 0) {
            // Circular workspace: advance the offsets by the unused remainder of the initial
            // block so that this cycle consumes a full block.
            long j = this.initialBlockSize.get() - this.cycleAllocations.get();
            if (j > 0 && !this.trimmedMode.get() && this.deviceOffset.get() > 0) {
                if (this.isDebug.get()) {
                    log.info("Worskpace [{}]: Align to [{}]; diff: [{}]; block size: [{}]; currentOffset: [{}]; workspaceSize: [{}]; trimmedMode: {}", new Object[]{this.id, Long.valueOf(this.initialBlockSize.get()), Long.valueOf(j), Long.valueOf(this.cycleAllocations.get()), Long.valueOf(this.deviceOffset.get()), Long.valueOf(this.currentSize.get()), Boolean.valueOf(this.trimmedMode.get())});
                }
                this.deviceOffset.getAndAdd(j);
                this.hostOffset.getAndAdd(j);
            }
        }
        this.cycleAllocations.set(0L);
    }

    /**
     * Releases pinned allocations; implemented by subclasses. Within this class it is invoked
     * from {@code close()} with {@code false} whenever the pinned counter is positive.
     *
     * @param z meaning defined by the implementation -- presumably "release everything"; TODO confirm
     */
    protected abstract void clearPinnedAllocations(boolean z);

    /**
     * Releases the allocations that were made outside the workspace buffer; implemented by
     * subclasses. Within this class it is invoked when the workspace is destroyed and when
     * a scope is entered, closed or re-initialized while external allocations are outstanding.
     */
    protected abstract void clearExternalAllocations();

    /**
     * Enters a scope of this workspace.
     *
     * <p>If this workspace is already the current, open one, only the nesting counter is
     * increased. Otherwise the previously current workspace is remembered, this workspace
     * becomes current and open, offsets and external allocations are cleaned up as dictated by
     * the reset policy, the per-cycle counters are zeroed and the generation id is advanced.
     *
     * @return this workspace
     */
    public MemoryWorkspace notifyScopeEntered() {
        MemoryWorkspace active = Nd4j.getMemoryManager().getCurrentWorkspace();

        // Re-entrant use: just record another nesting level.
        if (active == this && isOpen.get()) {
            tagScope.incrementAndGet();
            return this;
        }

        previousWorkspace = active;
        Nd4j.getMemoryManager().setCurrentWorkspace(this);
        isOpen.set(true);

        boolean resetsOnLeave = workspaceConfiguration.getPolicyReset() == ResetPolicy.BLOCK_LEFT;
        if (resetsOnLeave) {
            reset();
        }

        boolean hasExternals = externalCount.get() > 0;
        if (hasExternals && (resetsOnLeave || resetPlanned.get())) {
            clearExternalAllocations();
            resetPlanned.set(false);
        }

        cycleAllocations.set(0L);
        disabledCounter.set(0L);
        generationId.incrementAndGet();
        return this;
    }

    /** Rewinds both the host and the device offset to the start of the buffer. */
    public void reset() {
        final long start = 0L;
        hostOffset.set(start);
        deviceOffset.set(start);
    }

    /**
     * Backend-specific reset hook implemented by subclasses. It is not called from within
     * this class; presumably it clears the underlying buffer -- TODO confirm against implementations.
     */
    protected abstract void resetWorkspace();

    /**
     * Leaves the current scope by delegating to {@link #close()}.
     *
     * @return this workspace
     */
    public MemoryWorkspace notifyScopeLeft() {
        this.close();
        return this;
    }

    /**
     * Switches the workspace on or off; while off, allocations bypass the workspace buffer.
     *
     * @param z true to use the workspace, false to bypass it
     */
    public void toggleWorkspaceUse(boolean z) {
        isUsed.set(z);
    }

    /** Returns the allocation total recorded for the most recently closed cycle. */
    public long getLastCycleAllocations() {
        return lastCycleAllocations.get();
    }

    /** Returns the allocation total accumulated in the current cycle. */
    public long getThisCycleAllocations() {
        return cycleAllocations.get();
    }

    /** Returns the largest per-cycle allocation total observed so far. */
    public long getMaxCycleAllocations() {
        return maxCycle.get();
    }

    /** Returns true while a scope of this workspace is open. */
    public boolean isScopeActive() {
        return isOpen.get();
    }

    /**
     * Registers one more tagged use, which the next {@link #close()} call will unwind instead
     * of closing the scope.
     *
     * @return this workspace
     */
    public MemoryWorkspace tagOutOfScopeUse() {
        tagScope.incrementAndGet();
        return this;
    }

    /** Returns a short description containing the workspace id and its current size. */
    public String toString() {
        StringBuilder text = new StringBuilder();
        text.append("Nd4jWorkspace{id='").append(id);
        text.append("', currentSize=").append(currentSize.get());
        text.append('}');
        return text.toString();
    }

    /** Returns the device id resolved when the workspace was constructed. */
    public int getDeviceId() {
        return deviceId;
    }

    /** Returns the id of the thread that constructed this workspace. */
    public Long getThreadId() {
        return threadId;
    }

    /** Returns the workspace id supplied at construction time. */
    public String getId() {
        return id;
    }

    /** Returns the configuration this workspace was created with. */
    public WorkspaceConfiguration getWorkspaceConfiguration() {
        return workspaceConfiguration;
    }
}
