package org.nd4j.linalg.learning;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/* loaded from: input_file:org/nd4j/linalg/learning/RmsPropUpdater.class */
/**
 * RMSProp gradient updater.
 *
 * <p>Maintains an exponentially decayed running average of squared gradients and scales each
 * incoming gradient by {@code lr / sqrt(average + Nd4j.EPS_THRESHOLD)}.
 *
 * <p>The running average is allocated lazily on the first call to {@link #getGradient} using the
 * shape of that first gradient, and is updated in place on every later call. NOTE(review): the
 * cache is never reset or reshaped, so subsequent gradients are presumably expected to have the
 * same shape as the first one — confirm with callers.
 *
 * <p>No synchronization is performed; instances carry mutable state.
 */
public class RmsPropUpdater implements GradientUpdater {
    /** Running average of squared gradients; {@code null} until the first update. */
    private INDArray lastGradient;

    /** Weight given to the previous running average on each update (default 0.5). */
    private double rmsDecay = 0.5d;

    /** Learning rate multiplied into every returned gradient (default 0.1). */
    private double lr = 0.1d;

    /** Creates an updater with learning rate 0.1 and decay 0.5 (taken from the field defaults). */
    public RmsPropUpdater() {
        // Nothing to do: the field initializers supply the defaults.
    }

    /**
     * Creates an updater with explicit hyper-parameters.
     *
     * @param lr learning rate applied to each gradient
     * @param rmsDecay weight of the previous running average of squared gradients
     */
    public RmsPropUpdater(double lr, double rmsDecay) {
        this.lr = lr;
        this.rmsDecay = rmsDecay;
    }

    /**
     * Sets the decay factor of the running average of squared gradients.
     *
     * @param rmsDecay new decay factor; not validated
     */
    public void setRmsDecay(double rmsDecay) {
        this.rmsDecay = rmsDecay;
    }

    /** @return the decay factor of the running average of squared gradients */
    public double getRmsDecay() {
        return rmsDecay;
    }

    /**
     * Sets the learning rate.
     *
     * @param lr new learning rate; not validated
     */
    public void setLR(double lr) {
        this.lr = lr;
    }

    /** @return the current learning rate */
    public double getLR() {
        return lr;
    }

    /**
     * Applies one RMSProp step.
     *
     * <p>First updates the cached average in place:
     * {@code avg = rmsDecay * avg + (1 - rmsDecay) * gradient * gradient}. Then it returns
     * {@code lr * gradient / sqrt(avg + Nd4j.EPS_THRESHOLD)}, which is computed with ND4J's
     * non-in-place {@code mul}/{@code add} on the inputs.
     *
     * @param gradient raw gradient for this step
     * @param iteration current iteration; unused by this updater
     * @return the RMSProp-scaled gradient
     */
    @Override // org.nd4j.linalg.learning.GradientUpdater
    public INDArray getGradient(INDArray gradient, int iteration) {
        if (lastGradient == null) {
            lastGradient = Nd4j.zeros(gradient.shape());
        }

        // Decay the previous average in place, then fold in the weighted squared gradient.
        INDArray decayed = lastGradient.muli(Double.valueOf(rmsDecay));
        decayed.addi(gradient.mul(gradient).muli(Double.valueOf(1.0d - rmsDecay)));

        // Epsilon guards against division by zero where the running average is still zero.
        INDArray scaled = gradient.mul(Double.valueOf(lr));
        INDArray rms = Transforms.sqrt(lastGradient.add(Double.valueOf(Nd4j.EPS_THRESHOLD)));
        return scaled.divi(rms);
    }
}
