package ai.djl.pytorch.engine;

import ai.djl.ndarray.NDArray;
import ai.djl.pytorch.jni.JniUtils;
import ai.djl.training.GradientCollector;
import java.util.concurrent.atomic.AtomicBoolean;

/* loaded from: input_file:ai/djl/pytorch/engine/PtGradientCollector.class */
/**
 * {@link GradientCollector} implementation for the PyTorch engine.
 *
 * <p>Constructing a collector enables grad mode via {@link JniUtils#setGradMode(boolean)}.
 * {@link #close()} switches grad mode back off if it was off when the collector was created.
 * Only one collector may be open at a time: the constructor throws {@link IllegalStateException}
 * while another instance has not yet been closed. Intended for use with try-with-resources.
 */
public final class PtGradientCollector implements GradientCollector {

    /** Set while some PtGradientCollector is open; enforces the one-at-a-time rule. */
    private static final AtomicBoolean isCollecting = new AtomicBoolean();

    /** Grad mode observed just before this collector enabled it; consulted on close. */
    private final boolean gradMode;

    /** Guards {@link #close()} so a repeated close cannot release a newer collector's slot. */
    private boolean closed;

    /**
     * Creates a collector and enables grad mode.
     *
     * @throws IllegalStateException if another PtGradientCollector is already collecting
     */
    public PtGradientCollector() {
        // Claim the slot before touching grad mode, so a rejected construction has no side effects.
        if (isCollecting.getAndSet(true)) {
            throw new IllegalStateException("A PtGradientCollector is already collecting. Only one can be collecting at a time");
        }
        boolean previous;
        try {
            previous = JniUtils.isGradMode();
            JniUtils.setGradMode(true);
        } catch (RuntimeException | Error e) {
            // Release the slot; otherwise no collector could ever be created again.
            isCollecting.set(false);
            throw e;
        }
        gradMode = previous;
    }

    /**
     * Runs backpropagation from {@code target}, seeding it with an array of ones that has the
     * target's shape and data type and lives on the target's device.
     *
     * @param target the array to backpropagate from
     */
    public void backward(NDArray target) {
        NDArray ones = target.getManager().ones(target.getShape(), target.getDataType());
        NDArray seed = ones.toDevice(target.getDevice(), false);
        try {
            backward(target, seed, false, false);
        } finally {
            // The seed is only needed for this call; free it rather than leaving it attached to
            // the target's manager. toDevice may hand back the same array, so close it only once.
            seed.close();
            if (ones != seed) {
                ones.close();
            }
        }
    }

    /**
     * Forwards to the native backward pass.
     *
     * @param target the array to backpropagate from
     * @param seed the gradient to seed {@code target} with
     * @param keepGraph first boolean flag forwarded to {@link JniUtils} (presumably "retain graph"
     *     — confirm against JniUtils.backward)
     * @param createGraph second boolean flag forwarded to {@link JniUtils} (presumably "create
     *     graph" — confirm against JniUtils.backward)
     */
    private void backward(NDArray target, NDArray seed, boolean keepGraph, boolean createGraph) {
        JniUtils.backward((PtNDArray) target, (PtNDArray) seed, keepGraph, createGraph);
    }

    /**
     * Zeroes, in place, the gradient of every array managed by the PyTorch system manager that
     * has one.
     */
    public void zeroGradients() {
        for (NDArray array : PtNDManager.getSystemManager().getManagedArrays()) {
            if (!array.hasGradient()) {
                continue;
            }
            // Fetch the gradient once; each getGradient() call may allocate a new wrapper.
            NDArray grad = array.getGradient();
            // NOTE(review): x - x yields NaN rather than 0 for Inf/NaN entries.
            grad.subi(grad);
        }
    }

    /**
     * Ends collection. Grad mode is switched off if it was off when this collector was created,
     * and the one-at-a-time slot is released. Calling close more than once has no further effect.
     */
    public void close() {
        if (closed) {
            return;
        }
        closed = true;
        try {
            if (!gradMode) {
                JniUtils.setGradMode(false);
            }
        } finally {
            // Always release the slot, even if restoring grad mode fails.
            isCollecting.set(false);
        }
    }
}
