package org.deeplearning4j.optimize.solvers;

import java.util.Collection;
import java.util.Iterator;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.optimize.api.StepFunction;
import org.deeplearning4j.optimize.api.TrainingListener;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/optimize/solvers/StochasticGradientDescent.class */
/**
 * Stochastic gradient descent optimizer: each call to {@link #optimize(LayerWorkspaceMgr)}
 * computes one gradient and applies a single parameter update from it.
 */
public class StochasticGradientDescent extends BaseOptimizer {
    private static final Logger log = LoggerFactory.getLogger(StochasticGradientDescent.class);

    /**
     * Creates an SGD optimizer. All arguments are passed through unchanged to {@link BaseOptimizer}.
     *
     * @param neuralNetConfiguration the network configuration
     * @param stepFunction           the step function used to apply gradients to the parameters
     * @param collection             the training listeners
     * @param model                  the model whose parameters are optimized
     */
    public StochasticGradientDescent(NeuralNetConfiguration neuralNetConfiguration, StepFunction stepFunction,
                                     Collection<TrainingListener> collection, Model model) {
        super(neuralNetConfiguration, stepFunction, collection, model);
    }

    /**
     * Performs one SGD iteration.
     *
     * <p>The steps, in order:
     * <ol>
     *   <li>Compute the gradient.</li>
     *   <li>Apply it to the model parameters: through the gradients accumulator when one is set,
     *       otherwise directly through the step function.</li>
     *   <li>Write the parameters back to the model.</li>
     *   <li>Notify every training listener.</li>
     *   <li>Increment the model's iteration count and call {@code applyConstraints}.</li>
     * </ol>
     *
     * @param layerWorkspaceMgr workspace manager passed to the gradient computation
     * @return always {@code true}
     */
    @Override
    public boolean optimize(LayerWorkspaceMgr layerWorkspaceMgr) {
        Gradient gradient = (Gradient) gradientAndScore(layerWorkspaceMgr).getFirst();
        INDArray params = this.model.params();

        if (this.accumulator != null) {
            // Hand this step's gradient to the accumulator, then let it apply the update to params
            // through the step function.
            this.accumulator.storeUpdate(gradient.gradient());
            this.accumulator.applyUpdate(this.stepFunction, params, gradient.gradient());
        } else {
            this.stepFunction.step(params, gradient.gradient());
        }

        // Push the updated array back into the model.
        // NOTE(review): may be a no-op when params() returns a live view -- confirm per model type.
        this.model.setParams(params);

        int iterationCount = BaseOptimizer.getIterationCount(this.model);
        int epochCount = BaseOptimizer.getEpochCount(this.model);

        // Listeners run inside the scope returned by scopeOutOfWorkspaces(), presumably so their
        // allocations are not tied to the active workspace. try-with-resources closes the scope
        // exactly once, and skips close() if the scope is null. Only the listener loop is inside it,
        // so later failures cannot trigger a second close().
        try (MemoryWorkspace workspace = Nd4j.getMemoryManager().scopeOutOfWorkspaces()) {
            for (TrainingListener listener : this.trainingListeners) {
                listener.iterationDone(this.model, iterationCount, epochCount);
            }
        }

        BaseOptimizer.incrementIterationCount(this.model, 1);
        applyConstraints(this.model);
        return true;
    }

    /** No-op: SGD does no per-line preprocessing. */
    @Override
    public void preProcessLine() {
    }

    /**
     * No-op: SGD does nothing after a step.
     *
     * @param iNDArray ignored
     */
    @Override
    public void postStep(INDArray iNDArray) {
    }
}
