package org.deeplearning4j.nn.learning;

import java.io.Serializable;
import org.jblas.DoubleMatrix;
import org.jblas.MatrixFunctions;

/* loaded from: input_file:org/deeplearning4j/nn/learning/AdaGrad.class */
/**
 * Element-wise step-size helper for gradient updates, named after the AdaGrad method.
 *
 * <p>{@link #getLearningRates(DoubleMatrix)} returns, element by element,
 * {@code masterStepSize * |gradient| + fudgeFactor}. When learning-rate decay is enabled
 * ({@link #setDecayLr(boolean)}), every call after the first multiplies {@code masterStepSize} by
 * {@code lrDecay}, never letting it drop below {@code minLearningRate}.
 *
 * <p>NOTE(review): {@link #historicalGradient} is allocated but never read or updated in this
 * class, so no squared-gradient history is accumulated as in textbook AdaGrad. Confirm what
 * callers expect before changing the formula.
 *
 * <p>{@link #getLearningRates(DoubleMatrix)} mutates instance state without synchronization; only
 * the {@code decayLr} accessors are {@code synchronized}.
 */
public class AdaGrad implements Serializable {
    private static final long serialVersionUID = -4754127927704099888L;
    private double masterStepSize;
    public DoubleMatrix historicalGradient;
    public DoubleMatrix adjustedGradient;
    /** Offset added to every returned rate; defaults to {@code 1e-6}. */
    public double fudgeFactor = 1.0E-6d;
    /** Copy of the gradient passed to the most recent {@link #getLearningRates} call. */
    public DoubleMatrix gradient;
    public int rows;
    public int cols;
    private int numIterations = 0;
    private double lrDecay = 0.95d;
    private boolean decayLr = false;
    private double minLearningRate = 1.0E-4d;

    /**
     * Creates a helper whose initial {@code adjustedGradient} and {@code historicalGradient}
     * matrices have the given shape.
     *
     * @param rows number of rows of the tracked parameter matrix
     * @param cols number of columns of the tracked parameter matrix
     * @param masterStepSize base step size used to scale the returned rates
     */
    public AdaGrad(int rows, int cols, double masterStepSize) {
        this.rows = rows;
        this.cols = cols;
        this.adjustedGradient = new DoubleMatrix(rows, cols);
        this.historicalGradient = new DoubleMatrix(rows, cols);
        this.masterStepSize = masterStepSize;
    }

    /**
     * Creates a helper with a base step size of {@code 0.01}.
     *
     * @param rows number of rows of the tracked parameter matrix
     * @param cols number of columns of the tracked parameter matrix
     */
    public AdaGrad(int rows, int cols) {
        this(rows, cols, 0.01d);
    }

    /**
     * Computes element-wise rates {@code step * |gradient| + fudgeFactor} for the given gradient.
     *
     * <p>Side effects:
     * <ul>
     *   <li>Stores a copy of {@code gradient} in {@link #gradient}.
     *   <li>Stores the result in {@link #adjustedGradient}.
     *   <li>Increments the iteration counter.
     *   <li>When decay is enabled and this is not the first call, decays the master step size.
     * </ul>
     *
     * @param gradient the current gradient; it is not modified
     * @return a new matrix of the same shape as {@code gradient}, also stored in
     *     {@link #adjustedGradient}
     */
    public DoubleMatrix getLearningRates(DoubleMatrix gradient) {
        this.gradient = gradient.dup();

        // The step used for THIS call is the value before this call's decay. The decayed value
        // takes effect from the next call on.
        // NOTE(review): this one-call lag may be unintentional. Confirm before changing it.
        double stepSize = this.masterStepSize;
        if (this.decayLr && this.numIterations > 0) {
            this.masterStepSize *= this.lrDecay;
            if (this.masterStepSize < this.minLearningRate) {
                this.masterStepSize = this.minLearningRate;
            }
        }
        this.numIterations++;

        // |g| computed directly. It equals sqrt(g^2), but it does not overflow to Infinity for
        // very large |g| and needs no extra temporaries. abs() returns a copy, so the in-place
        // ops below do not touch this.gradient.
        DoubleMatrix rates = MatrixFunctions.abs(this.gradient);
        rates.muli(stepSize);
        // Offset so elements with a zero gradient still get a non-zero rate. This previously
        // used a hard-coded 1e-6, which silently ignored the public fudgeFactor field.
        rates.addi(this.fudgeFactor);

        this.adjustedGradient = rates;
        return this.adjustedGradient;
    }

    /** Returns the current base step size (possibly already decayed). */
    public double getMasterStepSize() {
        return this.masterStepSize;
    }

    /** Sets the base step size used by subsequent {@link #getLearningRates} calls. */
    public void setMasterStepSize(double masterStepSize) {
        this.masterStepSize = masterStepSize;
    }

    /** Returns the multiplicative decay applied to the step size per call when decay is enabled. */
    public double getLrDecay() {
        return this.lrDecay;
    }

    /**
     * Sets the multiplicative decay applied to the step size per call when decay is enabled.
     *
     * @param lrDecay decay factor; must be strictly positive
     * @throws IllegalArgumentException if {@code lrDecay} is not strictly positive or is NaN
     */
    public void setLrDecay(double lrDecay) {
        if (!(lrDecay > 0.0d)) {
            throw new IllegalArgumentException("lrDecay must be > 0, got " + lrDecay);
        }
        this.lrDecay = lrDecay;
    }

    /** Returns the floor below which decay will not push the step size. */
    public double getMinLearningRate() {
        return this.minLearningRate;
    }

    /**
     * Sets the floor below which decay will not push the step size.
     *
     * @param minLearningRate the floor; must be non-negative
     * @throws IllegalArgumentException if {@code minLearningRate} is negative or NaN
     */
    public void setMinLearningRate(double minLearningRate) {
        if (!(minLearningRate >= 0.0d)) {
            throw new IllegalArgumentException(
                    "minLearningRate must be >= 0, got " + minLearningRate);
        }
        this.minLearningRate = minLearningRate;
    }

    /** Returns whether the step size decays on each call after the first. */
    public synchronized boolean isDecayLr() {
        return this.decayLr;
    }

    /** Enables or disables per-call decay of the step size. */
    public synchronized void setDecayLr(boolean decayLr) {
        this.decayLr = decayLr;
    }
}
