package org.nd4j.linalg.learning;

import java.io.Serializable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;

/* loaded from: input_file:org/nd4j/linalg/learning/Nesterovs.class */
/**
 * Gradient updater implementing Nesterov momentum.
 *
 * <p>The velocity state {@code v} is normally a view into an externally supplied state array
 * (installed via {@link #setStateViewArray}); {@link #getGradient} updates that array in place on
 * every call so the external state array always reflects the current velocity.
 */
public class Nesterovs implements Serializable, GradientUpdater {
    /** Momentum coefficient (mu); defaults to 0.5 when not supplied to a constructor. */
    private double momentum = 0.5d;
    /** Velocity state; null until {@link #setStateViewArray} or {@link #setV} is called. */
    private volatile INDArray v;
    /** Learning rate; defaults to 0.1 when not supplied to a constructor. */
    private double learningRate = 0.1d;

    /**
     * Aggregates several {@link Nesterovs} updaters by summing their momentum, learning rate and
     * velocity state, and produces an updater holding the averages.
     */
    public static class NesterovsAggregator implements GradientUpdaterAggregator {
        private INDArray vSum;
        private double lrSum;
        private double momentumSum;
        private int count = 0;

        /**
         * Builds a new updater whose momentum, learning rate and velocity are the means of
         * everything aggregated so far.
         *
         * @return a new {@link Nesterovs} updater holding the averaged state
         * @throws IllegalStateException if nothing has been aggregated yet
         */
        @Override
        public GradientUpdater getUpdater() {
            if (this.count == 0 || this.vSum == null) {
                throw new IllegalStateException(
                        "Cannot create updater: no Nesterovs updaters have been aggregated");
            }
            Nesterovs nesterovs = new Nesterovs(this.momentumSum / this.count, this.lrSum / this.count);
            nesterovs.setV(this.vSum.div(Integer.valueOf(this.count)));
            return nesterovs;
        }

        /**
         * Adds one updater's state to the running sums.
         *
         * @param gradientUpdater the updater to add; must be a {@link Nesterovs}
         * @throws UnsupportedOperationException if {@code gradientUpdater} is not a {@link Nesterovs}
         * @throws IllegalStateException if the updater has no velocity state yet
         */
        @Override
        public void aggregate(GradientUpdater gradientUpdater) {
            if (!(gradientUpdater instanceof Nesterovs)) {
                throw new UnsupportedOperationException("Cannot aggregate Nesterovs with updater: " + gradientUpdater);
            }
            Nesterovs nesterovs = (Nesterovs) gradientUpdater;
            // Read the volatile field once so the null check and the use see the same array.
            INDArray otherV = nesterovs.v;
            if (otherV == null) {
                throw new IllegalStateException(
                        "Cannot aggregate Nesterovs updater that has not been initialized with view state");
            }
            if (this.vSum == null) {
                // Copy so later in-place additions don't mutate the source updater's state.
                this.vSum = otherV.dup();
                this.lrSum = nesterovs.learningRate;
                this.momentumSum = nesterovs.momentum;
            } else {
                this.vSum.addi(otherV);
                this.lrSum += nesterovs.learningRate;
                this.momentumSum += nesterovs.momentum;
            }
            this.count++;
        }

        /**
         * Merges another aggregator's sums into this one.
         *
         * @param gradientUpdaterAggregator the aggregator to merge; must be a {@link NesterovsAggregator}
         * @return this aggregator
         * @throws IllegalArgumentException if the argument is not a {@link NesterovsAggregator}
         */
        @Override
        public GradientUpdaterAggregator combine(GradientUpdaterAggregator gradientUpdaterAggregator) {
            if (!(gradientUpdaterAggregator instanceof NesterovsAggregator)) {
                throw new IllegalArgumentException("Cannot combine NesterovsAggregator with aggregator: " + gradientUpdaterAggregator);
            }
            NesterovsAggregator other = (NesterovsAggregator) gradientUpdaterAggregator;
            if (other.vSum == null) {
                // The other aggregator is empty: nothing to merge.
                return this;
            }
            if (this.vSum == null) {
                // This aggregator is empty: take a copy so the two never share a mutable sum.
                this.vSum = other.vSum.dup();
            } else {
                this.vSum.addi(other.vSum);
            }
            this.lrSum += other.lrSum;
            this.momentumSum += other.momentumSum;
            this.count += other.count;
            return this;
        }
    }

    /**
     * Returns the state size needed for a parameter array of the given length: one velocity value
     * per parameter.
     */
    @Override
    public int stateSizeForInputSize(int inputSize) {
        return inputSize;
    }

    /**
     * Installs {@code viewArray}, reshaped to the gradient's shape without copying, as the
     * velocity state.
     *
     * @param viewArray row-vector state array backing the velocity
     * @param gradientShape shape the state should be viewed as
     * @param gradientOrder {@code 'f'} for Fortran order, anything else for C order
     * @param initialize if true, the view array is zeroed before use
     * @throws IllegalArgumentException if {@code viewArray} is not a row vector
     * @throws IllegalStateException if the array cannot be reshaped without a copy; the previous
     *     state is left untouched in that case
     */
    @Override
    public void setStateViewArray(INDArray viewArray, int[] gradientShape, char gradientOrder, boolean initialize) {
        if (!viewArray.isRowVector()) {
            throw new IllegalArgumentException("Invalid input: expect row vector input");
        }
        if (initialize) {
            viewArray.assign((Number) 0);
        }
        INDArray reshaped = Shape.newShapeNoCopy(viewArray, gradientShape, gradientOrder == 'f');
        if (reshaped == null) {
            throw new IllegalStateException("Could not correctly reshape gradient view array");
        }
        this.v = reshaped;
    }

    /**
     * Creates an updater with the given hyperparameters.
     *
     * @param momentum momentum coefficient (mu)
     * @param learningRate learning rate
     */
    public Nesterovs(double momentum, double learningRate) {
        this.momentum = momentum;
        this.learningRate = learningRate;
    }

    /**
     * Creates an updater with the given momentum and the default learning rate (0.1).
     *
     * @param momentum momentum coefficient (mu)
     */
    public Nesterovs(double momentum) {
        this.momentum = momentum;
    }

    /**
     * Updates hyperparameters from positional arguments.
     *
     * <p>{@code args[0]} (if present) is the new learning rate and {@code args[1]} (if present) the
     * new momentum. Any {@link Number} is accepted. Missing arguments leave the corresponding
     * value unchanged.
     *
     * @param args learning rate, then momentum; may be empty
     * @throws ClassCastException if a supplied argument is not a {@link Number}
     */
    @Override
    public void update(Object... args) {
        if (args == null) {
            return;
        }
        if (args.length > 0) {
            this.learningRate = ((Number) args[0]).doubleValue();
        }
        if (args.length > 1) {
            this.momentum = ((Number) args[1]).doubleValue();
        }
    }

    /**
     * Applies one Nesterov momentum step and writes the resulting update into {@code gradient}.
     *
     * <p>The velocity is updated in place as {@code v = momentum * v - learningRate * gradient}, so
     * a view installed by {@link #setStateViewArray} keeps tracking the state. The value written
     * into {@code gradient} is {@code momentum * vPrev - (1 + momentum) * v}. It is presumably
     * subtracted from the parameters by the caller — NOTE(review): confirm the sign convention
     * against the caller.
     *
     * @param gradient raw gradient; overwritten with the update and returned
     * @param iteration current iteration (unused)
     * @return {@code gradient}, now holding the update
     * @throws IllegalStateException if no velocity state has been set
     */
    @Override
    public INDArray getGradient(INDArray gradient, int iteration) {
        INDArray velocity = this.v;
        if (velocity == null) {
            throw new IllegalStateException("Updater has not been initialized with view state");
        }
        // Snapshot the previous velocity; the state array itself is modified in place below.
        INDArray vPrev = velocity.dup();
        // v = mu * v - lr * g, written into the existing state array (not a new allocation).
        velocity.muli(Double.valueOf(this.momentum)).subi(gradient.mul(Double.valueOf(this.learningRate)));
        // update = mu * vPrev + (-mu - 1) * v
        gradient.assign(vPrev.muli(Double.valueOf(this.momentum))
                .addi(velocity.mul(Double.valueOf((-this.momentum) - 1.0d))));
        return gradient;
    }

    /**
     * Creates an aggregator for this updater type.
     *
     * @param addThis if true, this updater's state is added to the new aggregator
     * @return a new {@link NesterovsAggregator}
     */
    @Override
    public GradientUpdaterAggregator getAggregator(boolean addThis) {
        NesterovsAggregator aggregator = new NesterovsAggregator();
        if (addThis) {
            aggregator.aggregate(this);
        }
        return aggregator;
    }

    public double getMomentum() {
        return this.momentum;
    }

    /** Returns the live velocity state array (not a copy); may be null before initialization. */
    public INDArray getV() {
        return this.v;
    }

    public double getLearningRate() {
        return this.learningRate;
    }

    public void setMomentum(double momentum) {
        this.momentum = momentum;
    }

    /** Replaces the velocity state with {@code v} (stored by reference, not copied). */
    public void setV(INDArray v) {
        this.v = v;
    }

    public void setLearningRate(double learningRate) {
        this.learningRate = learningRate;
    }

    /** Equality over momentum, velocity state and learning rate. */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof Nesterovs)) {
            return false;
        }
        Nesterovs other = (Nesterovs) obj;
        if (!other.canEqual(this) || Double.compare(getMomentum(), other.getMomentum()) != 0) {
            return false;
        }
        INDArray thisV = getV();
        INDArray otherV = other.getV();
        if (thisV == null ? otherV != null : !thisV.equals(otherV)) {
            return false;
        }
        return Double.compare(getLearningRate(), other.getLearningRate()) == 0;
    }

    /** Hook allowing subclasses to restrict equality to their own type. */
    protected boolean canEqual(Object obj) {
        return obj instanceof Nesterovs;
    }

    /** Hash over momentum, velocity state and learning rate, consistent with {@link #equals}. */
    @Override
    public int hashCode() {
        long momentumBits = Double.doubleToLongBits(getMomentum());
        int result = 59 + (int) ((momentumBits >>> 32) ^ momentumBits);
        INDArray velocity = getV();
        result = (result * 59) + (velocity == null ? 0 : velocity.hashCode());
        long lrBits = Double.doubleToLongBits(getLearningRate());
        return (result * 59) + (int) ((lrBits >>> 32) ^ lrBits);
    }

    @Override
    public String toString() {
        return "Nesterovs(momentum=" + getMomentum() + ", v=" + getV() + ", learningRate=" + getLearningRate() + ")";
    }

    /** Creates an updater with default momentum (0.5) and learning rate (0.1). */
    public Nesterovs() {
    }
}
