package org.nd4j.linalg.learning;

import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/* loaded from: input_file:org/nd4j/linalg/learning/RmsProp.class */
/**
 * RMSProp gradient updater.
 *
 * <p>Keeps a decaying running average of squared gradients ({@code lastGradient}) and scales each
 * incoming gradient by {@code learningRate / (sqrt(lastGradient) + EPSILON)}.
 */
public class RmsProp implements GradientUpdater {
    /** Learning rate used by the no-arg constructor. */
    private static final double DEFAULT_LEARNING_RATE = 0.1d;
    /** Decay factor used by the no-arg constructor. */
    private static final double DEFAULT_RMS_DECAY = 0.95d;
    /** Small constant added before dividing, to avoid division by zero. */
    private static final double EPSILON = 1.0E-8d;

    // Running (decayed) average of squared gradients; created lazily by the first getGradient call.
    private INDArray lastGradient;
    private double rmsDecay;
    private double learningRate;

    /**
     * Averages the state of several {@link RmsProp} updaters.
     *
     * <p>Learning rate and decay are averaged over every aggregated updater. The squared-gradient
     * history is averaged only over updaters that actually have one (non-null {@code lastGradient}),
     * so updaters that have not yet processed a gradient can be aggregated safely.
     */
    public static class RmsPropAggregator implements GradientUpdaterAggregator {
        private INDArray lastGradientSum;
        private double rmsDecaySum;
        private double lrSum;
        // Number of updaters aggregated; denominator for the hyperparameter averages.
        private int count = 0;
        // Number of updaters that contributed a non-null lastGradient; denominator for lastGradientSum.
        private int gradientCount = 0;

        /**
         * Builds a new updater whose hyperparameters and gradient history are the averages of
         * everything aggregated so far.
         *
         * @return a fresh {@link RmsProp}; its {@code lastGradient} is null if no aggregated updater had one
         * @throws IllegalStateException if nothing has been aggregated yet
         */
        @Override
        public GradientUpdater getUpdater() {
            if (count == 0) {
                throw new IllegalStateException(
                        "Cannot create RmsProp updater: no updaters have been aggregated");
            }
            RmsProp rmsProp = new RmsProp(lrSum / count, rmsDecaySum / count);
            if (lastGradientSum != null) {
                // div (not divi) so this aggregator's running sum is left untouched
                rmsProp.setLastGradient(lastGradientSum.div(gradientCount));
            }
            return rmsProp;
        }

        /**
         * Adds one updater's state to the running sums.
         *
         * @param gradientUpdater must be an {@link RmsProp}
         * @throws UnsupportedOperationException if {@code gradientUpdater} is not an {@link RmsProp}
         */
        @Override
        public void aggregate(GradientUpdater gradientUpdater) {
            if (!(gradientUpdater instanceof RmsProp)) {
                throw new UnsupportedOperationException(
                        "Cannot aggregate RmsPropAggregator with updater: " + gradientUpdater);
            }
            RmsProp rmsProp = (RmsProp) gradientUpdater;
            rmsDecaySum += rmsProp.rmsDecay;
            lrSum += rmsProp.learningRate;
            count++;
            addGradient(rmsProp.lastGradient, 1);
        }

        /**
         * Merges another aggregator's sums into this one.
         *
         * @param gradientUpdaterAggregator must be an {@link RmsPropAggregator}; may be empty
         * @return this aggregator
         * @throws IllegalArgumentException if the argument is not an {@link RmsPropAggregator}
         */
        @Override
        public GradientUpdaterAggregator combine(GradientUpdaterAggregator gradientUpdaterAggregator) {
            if (!(gradientUpdaterAggregator instanceof RmsPropAggregator)) {
                throw new IllegalArgumentException(
                        "Cannot combine RmsPropAggregator with aggregator: " + gradientUpdaterAggregator);
            }
            RmsPropAggregator other = (RmsPropAggregator) gradientUpdaterAggregator;
            rmsDecaySum += other.rmsDecaySum;
            lrSum += other.lrSum;
            count += other.count;
            addGradient(other.lastGradientSum, other.gradientCount);
            return this;
        }

        // Adds a gradient sum representing `contributions` updaters; null (no history) is ignored.
        private void addGradient(INDArray gradient, int contributions) {
            if (gradient == null) {
                return;
            }
            if (lastGradientSum == null) {
                // dup so later in-place additions never mutate the caller's array
                lastGradientSum = gradient.dup();
            } else {
                lastGradientSum.addi(gradient);
            }
            gradientCount += contributions;
        }
    }

    /** Creates an updater with learning rate 0.1 and decay 0.95. */
    public RmsProp() {
        this(DEFAULT_LEARNING_RATE, DEFAULT_RMS_DECAY);
    }

    /**
     * Creates an updater with the given hyperparameters.
     *
     * @param d  learning rate applied to each scaled gradient
     * @param d2 decay factor of the running squared-gradient average
     */
    public RmsProp(double d, double d2) {
        this.learningRate = d;
        this.rmsDecay = d2;
    }

    /**
     * Updates hyperparameters at runtime.
     *
     * @param objArr optional arguments; if present, {@code objArr[0]} (any {@link Number}) becomes the
     *               new learning rate
     */
    @Override
    public void update(Object... objArr) {
        if (objArr.length > 0) {
            learningRate = ((Number) objArr[0]).doubleValue();
        }
    }

    /**
     * Applies the RMSProp rule to a gradient.
     *
     * <p>First updates {@code lastGradient = rmsDecay * lastGradient + (1 - rmsDecay) * gradient^2}, then
     * returns {@code gradient * learningRate / (sqrt(lastGradient) + EPSILON)}.
     *
     * @param iNDArray raw gradient; it is modified in place and returned
     * @param i        iteration number; not used by this updater
     * @return {@code iNDArray}, scaled in place
     */
    @Override
    public INDArray getGradient(INDArray iNDArray, int i) {
        if (lastGradient == null) {
            // History is seeded with EPSILON rather than exact zeros.
            lastGradient = Nd4j.zeros(iNDArray.shape()).add(EPSILON);
        }
        INDArray weightedSquare = iNDArray.mul(iNDArray).muli(1.0d - rmsDecay);
        lastGradient.muli(rmsDecay).addi(weightedSquare);
        // sqrt with dup=true so the stored history is not overwritten
        INDArray denominator = Transforms.sqrt(lastGradient, true).addi(EPSILON);
        return iNDArray.muli(learningRate).divi(denominator);
    }

    /**
     * Creates an aggregator for averaging RmsProp updaters.
     *
     * @param z if true, this updater is aggregated into the returned aggregator immediately
     * @return a new {@link RmsPropAggregator}
     */
    @Override
    public GradientUpdaterAggregator getAggregator(boolean z) {
        RmsPropAggregator rmsPropAggregator = new RmsPropAggregator();
        if (z) {
            rmsPropAggregator.aggregate(this);
        }
        return rmsPropAggregator;
    }

    /** @return the running squared-gradient average, or null before the first {@code getGradient} call */
    public INDArray getLastGradient() {
        return lastGradient;
    }

    public double getRmsDecay() {
        return rmsDecay;
    }

    public double getLearningRate() {
        return learningRate;
    }

    /** Replaces the running squared-gradient average (the array is stored, not copied). */
    public void setLastGradient(INDArray iNDArray) {
        this.lastGradient = iNDArray;
    }

    public void setRmsDecay(double d) {
        this.rmsDecay = d;
    }

    public void setLearningRate(double d) {
        this.learningRate = d;
    }

    /** Two updaters are equal when their gradient history and both hyperparameters are equal. */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof RmsProp)) {
            return false;
        }
        RmsProp other = (RmsProp) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        INDArray mine = getLastGradient();
        INDArray theirs = other.getLastGradient();
        if (mine == null ? theirs != null : !mine.equals(theirs)) {
            return false;
        }
        return Double.compare(getRmsDecay(), other.getRmsDecay()) == 0
                && Double.compare(getLearningRate(), other.getLearningRate()) == 0;
    }

    // Allows subclasses to opt out of equality with plain RmsProp instances.
    protected boolean canEqual(Object obj) {
        return obj instanceof RmsProp;
    }

    @Override
    public int hashCode() {
        INDArray grad = getLastGradient();
        int result = 59 + (grad == null ? 0 : grad.hashCode());
        long decayBits = Double.doubleToLongBits(getRmsDecay());
        result = result * 59 + (int) ((decayBits >>> 32) ^ decayBits);
        long lrBits = Double.doubleToLongBits(getLearningRate());
        return result * 59 + (int) ((lrBits >>> 32) ^ lrBits);
    }

    @Override
    public String toString() {
        return "RmsProp(lastGradient=" + getLastGradient() + ", rmsDecay=" + getRmsDecay()
                + ", learningRate=" + getLearningRate() + ")";
    }
}
