package org.nd4j.linalg.learning;

import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Plain stochastic-gradient-descent updater: the update for a parameter is its gradient scaled by a
 * single learning rate.
 */
public class Sgd implements GradientUpdater {
    /** Learning rate used by the no-argument constructor. */
    private static final double DEFAULT_LEARNING_RATE = 0.1d;

    private double learningRate;

    /**
     * Aggregates several {@link Sgd} updaters into one whose learning rate is the arithmetic mean of
     * the aggregated learning rates.
     */
    public static class SgdAggregator implements GradientUpdaterAggregator {
        /** Running sum of every learning rate aggregated or combined so far. */
        private double lrSum;
        /** Number of learning rates contributing to {@link #lrSum}. */
        private int count = 0;

        /**
         * Builds an updater whose learning rate is the mean of all aggregated learning rates.
         *
         * @return a new {@link Sgd} using the averaged learning rate
         * @throws IllegalStateException if nothing has been aggregated yet
         */
        @Override
        public GradientUpdater getUpdater() {
            if (count == 0) {
                // 0.0 / 0 would be NaN; an Sgd with a NaN rate would silently poison every update.
                throw new IllegalStateException(
                        "Cannot create Sgd updater: no Sgd updaters have been aggregated");
            }
            return new Sgd(lrSum / count);
        }

        /**
         * Adds one updater's learning rate to the running average.
         *
         * @param gradientUpdater the updater to fold in; must be an {@link Sgd}
         * @throws UnsupportedOperationException if {@code gradientUpdater} is not an {@link Sgd}
         */
        @Override
        public void aggregate(GradientUpdater gradientUpdater) {
            if (!(gradientUpdater instanceof Sgd)) {
                throw new UnsupportedOperationException("Cannot aggregate Sgd with updater: " + gradientUpdater);
            }
            lrSum += ((Sgd) gradientUpdater).learningRate;
            count++;
        }

        /**
         * Merges another aggregator's state into this one. This aggregator is mutated; the other is
         * only read.
         *
         * @param gradientUpdaterAggregator the aggregator to merge; must be an {@link SgdAggregator}
         * @return this aggregator, after merging
         * @throws IllegalArgumentException if the argument is not an {@link SgdAggregator}
         */
        @Override
        public GradientUpdaterAggregator combine(GradientUpdaterAggregator gradientUpdaterAggregator) {
            if (!(gradientUpdaterAggregator instanceof SgdAggregator)) {
                throw new IllegalArgumentException("Cannot combine SgdAggregator with aggregator: " + gradientUpdaterAggregator);
            }
            SgdAggregator other = (SgdAggregator) gradientUpdaterAggregator;
            lrSum += other.lrSum;
            count += other.count;
            return this;
        }
    }

    /**
     * Creates an SGD updater with the given learning rate.
     *
     * @param learningRate factor by which gradients are scaled in {@link #getGradient}
     */
    public Sgd(double learningRate) {
        this.learningRate = learningRate;
    }

    /** Creates an SGD updater with the default learning rate of {@code 0.1}. */
    public Sgd() {
        this(DEFAULT_LEARNING_RATE);
    }

    /**
     * Optionally replaces the learning rate.
     *
     * <p>If at least one argument is supplied, the first is taken as the new learning rate. Any
     * {@link Number} is accepted; the original code required exactly a {@link Double}. Remaining
     * arguments are ignored.
     *
     * @param args optional arguments; {@code args[0]}, if present, must be a non-null {@link Number}
     * @throws ClassCastException if {@code args[0]} is not a {@link Number}
     * @throws NullPointerException if {@code args[0]} is {@code null}
     */
    @Override
    public void update(Object... args) {
        if (args != null && args.length > 0) {
            this.learningRate = ((Number) args[0]).doubleValue();
        }
    }

    /**
     * Scales the gradient by the learning rate using {@code muli}. Per ND4J's {@code i}-suffix
     * convention this presumably mutates {@code gradient} in place and returns it; confirm against
     * the INDArray API.
     *
     * @param gradient the raw gradient
     * @param iteration current iteration; unused by plain SGD
     * @return the scaled gradient as returned by {@code muli}
     */
    @Override
    public INDArray getGradient(INDArray gradient, int iteration) {
        return gradient.muli(learningRate);
    }

    /**
     * Creates an aggregator for averaging the learning rates of several Sgd updaters.
     *
     * @param addThis whether this updater's learning rate is added to the new aggregator
     * @return a new {@link SgdAggregator}
     */
    @Override
    public GradientUpdaterAggregator getAggregator(boolean addThis) {
        SgdAggregator aggregator = new SgdAggregator();
        if (addThis) {
            aggregator.aggregate(this);
        }
        return aggregator;
    }

    public double getLearningRate() {
        return learningRate;
    }

    public void setLearningRate(double learningRate) {
        this.learningRate = learningRate;
    }

    /**
     * Two Sgd instances are equal when both sides accept each other via {@link #canEqual} and their
     * learning rates compare equal under {@link Double#compare}. {@code Double.compare} treats NaN
     * as equal to NaN, and it treats 0.0 and -0.0 as different.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof Sgd)) {
            return false;
        }
        Sgd other = (Sgd) obj;
        return other.canEqual(this) && Double.compare(getLearningRate(), other.getLearningRate()) == 0;
    }

    /** Symmetry hook for {@link #equals}, so that subclasses can opt out of equality with Sgd. */
    protected boolean canEqual(Object obj) {
        return obj instanceof Sgd;
    }

    /** Hash derived from the learning rate's bit pattern, consistent with {@link #equals}. */
    @Override
    public int hashCode() {
        long bits = Double.doubleToLongBits(getLearningRate());
        return 59 + (int) ((bits >>> 32) ^ bits);
    }

    @Override
    public String toString() {
        return "Sgd(learningRate=" + getLearningRate() + ")";
    }
}
