package org.nd4j.linalg.learning;

import java.io.Serializable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/* loaded from: input_file:org/nd4j/linalg/learning/AdaDelta.class */
/**
 * AdaDelta-style gradient updater.
 *
 * <p>Two running averages are kept, both decayed by {@code rho}:
 * <ul>
 *   <li>{@code msg}  - running average of the squared gradients;</li>
 *   <li>{@code msdx} - running average of the squared updates returned by
 *       {@link #getGradient(INDArray, int)}.</li>
 * </ul>
 * Both arrays are created lazily on the first call to
 * {@link #getGradient(INDArray, int)}, so they are {@code null} until then
 * (unless set explicitly through the setters).
 */
public class AdaDelta implements Serializable, GradientUpdater {
    /** Decay rate used by the no-argument constructor. */
    private static final double DEFAULT_RHO = 0.95d;

    /** Running average of squared gradients; {@code null} until first use. */
    private INDArray msg;
    /** Running average of squared updates; {@code null} until first use. */
    private INDArray msdx;
    /** Decay rate applied to both running averages. */
    private double rho;

    /**
     * Averages the state of several {@link AdaDelta} updaters.
     *
     * <p>The aggregator keeps element-wise sums of {@code msg} and {@code msdx},
     * the sum of {@code rho}, and the number of updaters seen; {@link #getUpdater()}
     * divides the sums by that count. An updater whose state arrays are still
     * {@code null} contributes nothing to the array sums (i.e. it is treated as
     * all zeros) but is still counted.
     */
    public static class AdaDeltaAggregator implements GradientUpdaterAggregator {
        private INDArray msgSum;
        private INDArray msdxSum;
        private double rhoSum;
        private int count;

        /**
         * Builds an updater holding the average of everything aggregated so far.
         *
         * @return a new {@link AdaDelta} whose {@code rho}, {@code msg} and {@code msdx}
         *         are the accumulated sums divided by the number of aggregated updaters;
         *         a state array is left {@code null} if no aggregated updater had it set
         * @throws IllegalStateException if nothing has been aggregated yet
         */
        @Override
        public GradientUpdater getUpdater() {
            if (this.count == 0) {
                // Previously this path produced rho = NaN followed by a NullPointerException.
                throw new IllegalStateException(
                        "Cannot create an AdaDelta updater: no updaters have been aggregated");
            }
            AdaDelta averaged = new AdaDelta(this.rhoSum / this.count);
            if (this.msgSum != null) {
                averaged.setMsg(this.msgSum.div(this.count));
            }
            if (this.msdxSum != null) {
                averaged.setMsdx(this.msdxSum.div(this.count));
            }
            return averaged;
        }

        /**
         * Adds one updater's state to the running sums.
         *
         * @param gradientUpdater the updater to include; must be an {@link AdaDelta}
         * @throws UnsupportedOperationException if {@code gradientUpdater} is not an {@link AdaDelta}
         */
        @Override
        public void aggregate(GradientUpdater gradientUpdater) {
            if (!(gradientUpdater instanceof AdaDelta)) {
                throw new UnsupportedOperationException("Cannot aggregate AdaDelta with updater: " + gradientUpdater);
            }
            AdaDelta adaDelta = (AdaDelta) gradientUpdater;
            this.msgSum = accumulate(this.msgSum, adaDelta.msg);
            this.msdxSum = accumulate(this.msdxSum, adaDelta.msdx);
            this.rhoSum += adaDelta.rho;
            this.count++;
        }

        /**
         * Merges another aggregator's sums and count into this one.
         *
         * @param gradientUpdaterAggregator the aggregator to merge; must be an {@link AdaDeltaAggregator}
         * @return this aggregator, after merging
         * @throws IllegalArgumentException if the argument is not an {@link AdaDeltaAggregator}
         */
        @Override
        public GradientUpdaterAggregator combine(GradientUpdaterAggregator gradientUpdaterAggregator) {
            if (!(gradientUpdaterAggregator instanceof AdaDeltaAggregator)) {
                throw new IllegalArgumentException("Cannot combine AdaDeltaAggregator with aggregator: " + gradientUpdaterAggregator);
            }
            AdaDeltaAggregator other = (AdaDeltaAggregator) gradientUpdaterAggregator;
            // Either side may still be empty (null sums), e.g. an aggregator that has
            // not been given any updater yet; accumulate() handles both cases.
            this.msgSum = accumulate(this.msgSum, other.msgSum);
            this.msdxSum = accumulate(this.msdxSum, other.msdxSum);
            this.rhoSum += other.rhoSum;
            this.count += other.count;
            return this;
        }

        /**
         * Null-safe accumulation of {@code term} into {@code sum}.
         *
         * @return {@code sum} after adding {@code term} in place; a copy of {@code term}
         *         when {@code sum} is {@code null}; {@code sum} unchanged when
         *         {@code term} is {@code null}
         */
        private static INDArray accumulate(INDArray sum, INDArray term) {
            if (term == null) {
                return sum;
            }
            if (sum == null) {
                // Copy so that later in-place additions never touch the contributor's array.
                return term.dup();
            }
            sum.addi(term);
            return sum;
        }
    }

    /** Creates an updater with the default decay rate (0.95). */
    public AdaDelta() {
        this(DEFAULT_RHO);
    }

    /**
     * Creates an updater with the given decay rate.
     *
     * @param d the decay rate {@code rho}
     */
    public AdaDelta(double d) {
        this.rho = d;
    }

    /** Does nothing: this updater takes no configuration through this hook. */
    @Override
    public void update(Object... objArr) {
    }

    /**
     * Advances the running averages with the given gradient and returns the scaled update.
     *
     * <p>Steps, with {@code g} the incoming gradient and {@code eps = Nd4j.EPS_THRESHOLD}:
     * <pre>
     *   msg  = rho * msg  + (1 - rho) * g^2
     *   dx   = g * sqrt(msdx + eps) / sqrt(msg + eps)
     *   msdx = rho * msdx + (1 - rho) * dx^2
     * </pre>
     *
     * <p>The gradient is scaled with {@code muli}, and the result of that call is what is
     * returned, so callers should expect the passed array itself to be modified.
     *
     * @param iNDArray the raw gradient; its shape is used to create the state arrays on first use
     * @param i        not used by this implementation (presumably the iteration number)
     * @return the scaled update {@code dx}
     */
    @Override
    public INDArray getGradient(INDArray iNDArray, int i) {
        if (this.msg == null) {
            this.msg = Nd4j.zeros(iNDArray.shape());
        }
        if (this.msdx == null) {
            this.msdx = Nd4j.zeros(iNDArray.shape());
        }
        double oneMinusRho = 1.0d - this.rho;

        // msg = rho * msg + (1 - rho) * g^2.
        // The (1 - rho) factor must scale g^2 before it is added; adding the scalar
        // to msg and then multiplying by g^2 would yield (rho * msg + (1 - rho)) * g^2.
        this.msg.muli(this.rho);
        this.msg.addi(iNDArray.mul(iNDArray).muli(oneMinusRho));

        // add(...) yields a new array, so the sqrt calls (invoked with the flag false,
        // presumably meaning "do not duplicate") never touch the state arrays themselves.
        INDArray rmsUpdate = Transforms.sqrt(this.msdx.add(Nd4j.EPS_THRESHOLD), false);
        INDArray rmsGradient = Transforms.sqrt(this.msg.add(Nd4j.EPS_THRESHOLD), false);
        INDArray update = iNDArray.muli(rmsUpdate.divi(rmsGradient));

        // msdx = rho * msdx + (1 - rho) * dx^2
        this.msdx.muli(this.rho);
        this.msdx.addi(update.mul(update).muli(oneMinusRho));
        return update;
    }

    /**
     * Creates an aggregator for this updater type.
     *
     * @param z if {@code true}, the returned aggregator already includes this updater's state
     * @return a new {@link AdaDeltaAggregator}
     */
    @Override
    public GradientUpdaterAggregator getAggregator(boolean z) {
        AdaDeltaAggregator aggregator = new AdaDeltaAggregator();
        if (z) {
            aggregator.aggregate(this);
        }
        return aggregator;
    }

    /** @return the running average of squared gradients, or {@code null} if not yet created */
    public INDArray getMsg() {
        return this.msg;
    }

    /** @return the running average of squared updates, or {@code null} if not yet created */
    public INDArray getMsdx() {
        return this.msdx;
    }

    /** @return the decay rate */
    public double getRho() {
        return this.rho;
    }

    /** @param iNDArray the new running average of squared gradients (stored by reference) */
    public void setMsg(INDArray iNDArray) {
        this.msg = iNDArray;
    }

    /** @param iNDArray the new running average of squared updates (stored by reference) */
    public void setMsdx(INDArray iNDArray) {
        this.msdx = iNDArray;
    }

    /** @param d the new decay rate */
    public void setRho(double d) {
        this.rho = d;
    }

    /**
     * Two updaters are equal when their {@code msg}, {@code msdx} and {@code rho} are equal
     * and the other object accepts the comparison via {@link #canEqual(Object)}.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof AdaDelta)) {
            return false;
        }
        AdaDelta other = (AdaDelta) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        INDArray thisMsg = getMsg();
        INDArray otherMsg = other.getMsg();
        if (thisMsg == null ? otherMsg != null : !thisMsg.equals(otherMsg)) {
            return false;
        }
        INDArray thisMsdx = getMsdx();
        INDArray otherMsdx = other.getMsdx();
        if (thisMsdx == null ? otherMsdx != null : !thisMsdx.equals(otherMsdx)) {
            return false;
        }
        return Double.compare(getRho(), other.getRho()) == 0;
    }

    /**
     * Hook that lets subclasses refuse equality with instances of other types.
     *
     * @param obj the object asking to be compared
     * @return {@code true} if {@code obj} is an {@link AdaDelta}
     */
    protected boolean canEqual(Object obj) {
        return obj instanceof AdaDelta;
    }

    /** Hash over {@code msg}, {@code msdx} and {@code rho}, consistent with {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        final int prime = 59;
        int result = 1;
        INDArray thisMsg = getMsg();
        result = result * prime + (thisMsg == null ? 0 : thisMsg.hashCode());
        INDArray thisMsdx = getMsdx();
        result = result * prime + (thisMsdx == null ? 0 : thisMsdx.hashCode());
        long rhoBits = Double.doubleToLongBits(getRho());
        result = result * prime + (int) ((rhoBits >>> 32) ^ rhoBits);
        return result;
    }

    @Override
    public String toString() {
        return "AdaDelta(msg=" + getMsg() + ", msdx=" + getMsdx() + ", rho=" + getRho() + ")";
    }
}
