package org.nd4j.linalg.learning;

import java.io.Serializable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/* loaded from: input_file:org/nd4j/linalg/learning/AdaDelta.class */
/**
 * AdaDelta gradient updater.
 *
 * <p>Keeps two running averages shaped like the gradient, both decayed by {@code rho}:
 * <ul>
 *   <li>{@code msg}: running mean of squared gradients, {@code E[g^2]}</li>
 *   <li>{@code msdx}: running mean of squared updates, {@code E[dx^2]}</li>
 * </ul>
 * Each accumulator is updated as {@code E = rho * E + (1 - rho) * x^2}.
 * {@link #getGradient(INDArray, int)} returns
 * {@code sqrt(E[dx^2] + eps) / sqrt(E[g^2] + eps) * g}, where {@code eps = Nd4j.EPS_THRESHOLD}.
 *
 * <p>Instances are stateful: the accumulators are lazily allocated as zeros on the first call
 * and mutated in place on every subsequent call.
 */
public class AdaDelta implements Serializable, GradientUpdater {
    /** Decay rate used by the no-arg constructor. */
    private static final double DEFAULT_RHO = 0.95d;

    /** Running average of squared gradients, E[g^2]; null until the first update. */
    private INDArray msg;
    /** Running average of squared updates, E[dx^2]; null until the first update. */
    private INDArray msdx;
    /** Weight kept from the previous running averages; (1 - rho) weights the new squared term. */
    private double rho;

    /**
     * Creates an updater with the given decay rate.
     *
     * @param rho weight applied to the previous running averages on each update
     */
    public AdaDelta(double rho) {
        this.rho = rho;
    }

    /** Creates an updater with a decay rate of 0.95. */
    public AdaDelta() {
        this(DEFAULT_RHO);
    }

    /** @return the running average of squared gradients, or null before the first update */
    public INDArray getMsg() {
        return this.msg;
    }

    /** @param msg replacement running average of squared gradients */
    public void setMsg(INDArray msg) {
        this.msg = msg;
    }

    /** @return the running average of squared updates, or null before the first update */
    public INDArray getMsdx() {
        return this.msdx;
    }

    /** @param msdx replacement running average of squared updates */
    public void setMsdx(INDArray msdx) {
        this.msdx = msdx;
    }

    /** @return the decay rate */
    public double getRho() {
        return this.rho;
    }

    /** @param rho the new decay rate */
    public void setRho(double rho) {
        this.rho = rho;
    }

    /**
     * Updates the running averages with {@code gradient} and returns the AdaDelta step.
     *
     * <p>NOTE(review): the returned array is the scaled step; presumably the caller subtracts it
     * from the parameters. Confirm against callers.
     *
     * @param gradient  current gradient; it is read but not modified
     * @param iteration unused by this updater
     * @return a new array holding {@code sqrt(E[dx^2] + eps) / sqrt(E[g^2] + eps) * gradient}
     */
    @Override // org.nd4j.linalg.learning.GradientUpdater
    public INDArray getGradient(INDArray gradient, int iteration) {
        if (this.msg == null) {
            this.msg = Nd4j.zeros(gradient.shape());
        }
        if (this.msdx == null) {
            this.msdx = Nd4j.zeros(gradient.shape());
        }

        // E[g^2] = rho * E[g^2] + (1 - rho) * g^2.
        // The (1 - rho) weight applies to g^2 only; it must not scale the whole accumulator.
        this.msg.muli(this.rho);
        this.msg.addi(gradient.mul(gradient).muli(1.0d - this.rho));

        // dx = sqrt(E[dx^2] + eps) / sqrt(E[g^2] + eps) * g, using E[dx^2] from the previous step.
        INDArray update = Transforms.sqrt(this.msdx.add(Nd4j.EPS_THRESHOLD))
                .divi(Transforms.sqrt(this.msg.add(Nd4j.EPS_THRESHOLD)))
                .muli(gradient);

        // E[dx^2] = rho * E[dx^2] + (1 - rho) * dx^2
        this.msdx.muli(this.rho);
        this.msdx.addi(update.mul(update).muli(1.0d - this.rho));

        return update;
    }
}
