package org.nd4j.linalg.learning;

import java.util.Arrays;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.learning.config.AdaGrad;
import org.nd4j.linalg.ops.transforms.Transforms;

/* loaded from: input_file:org/nd4j/linalg/learning/AdaGradUpdater.class */
/**
 * AdaGrad gradient updater.
 *
 * <p>Keeps a running element-wise sum of squared gradients in {@link #historicalGradient} and
 * rescales each incoming gradient in place by {@code learningRate / (sqrt(history) + epsilon)}.
 * The learning rate and the epsilon used during the update are read from the {@link AdaGrad}
 * configuration on every call to {@link #applyUpdater(INDArray, int, int)}.
 */
public class AdaGradUpdater implements GradientUpdater<AdaGrad> {
    /** Running sum of squared gradients; {@code null} until a state view array has been set. */
    public INDArray historicalGradient;
    /** Not read by the update logic in this class; exposed through accessors and equals/hashCode/toString only. */
    public int[] shape;
    /** Not read by {@link #applyUpdater}; the effective learning rate comes from the configuration. */
    protected double learningRate = 0.1d;
    /** Not read by the update logic in this class; exposed through accessors only. */
    protected int numIterations = 0;
    /** Value written into the state view array when it is initialized by {@link #setStateViewArray}. */
    private double epsilon = 1.0E-6d;
    /** Order character handed to {@link #setStateViewArray}; reused when duplicating the history. */
    private char gradientReshapeOrder;
    private AdaGrad config;

    /**
     * Creates an updater driven by the given configuration.
     *
     * @param adaGrad configuration supplying the learning rate and epsilon for each update
     */
    public AdaGradUpdater(AdaGrad adaGrad) {
        this.config = adaGrad;
    }

    /**
     * Installs the array that backs the historical (squared) gradient state.
     *
     * <p>The reshaped view is validated before any field of this updater is modified, so a failed
     * reshape leaves a previously installed state untouched instead of replacing it with
     * {@code null}.
     *
     * @param viewArray     row vector used as backing storage for the state
     * @param gradientShape shape the state is reshaped to (without copying)
     * @param gradientOrder order character; {@code 'f'} selects the {@code true} flag passed to
     *                      {@code Shape.newShapeNoCopy}
     * @param initialize    when {@code true}, every element of {@code viewArray} is first set to
     *                      this updater's {@code epsilon} field
     * @throws IllegalArgumentException if {@code viewArray} is not a row vector
     * @throws IllegalStateException    if the view cannot be reshaped without a copy
     */
    @Override
    public void setStateViewArray(INDArray viewArray, int[] gradientShape, char gradientOrder, boolean initialize) {
        if (!viewArray.isRowVector()) {
            throw new IllegalArgumentException("Invalid input: expect row vector input");
        }
        if (initialize) {
            // NOTE(review): this uses the updater's own epsilon field, not config.getEpsilon();
            // kept as-is to preserve existing behavior - confirm whether that is intended.
            viewArray.assign(this.epsilon);
        }

        // Reshape into a local first: committing the result only after the null check keeps the
        // updater's current state intact when the reshape is not possible.
        INDArray reshaped = Shape.newShapeNoCopy(viewArray, gradientShape, gradientOrder == 'f');
        if (reshaped == null) {
            throw new IllegalStateException("Could not correctly reshape gradient view array");
        }

        this.historicalGradient = reshaped;
        this.gradientReshapeOrder = gradientOrder;
    }

    /**
     * Applies the AdaGrad update to {@code gradient} in place.
     *
     * <p>Steps: {@code history += gradient * gradient}, then
     * {@code gradient *= learningRate / (sqrt(history) + epsilon)}.
     *
     * @param gradient  gradient to rescale; modified in place
     * @param iteration first argument forwarded to {@code config.getLearningRate}
     *                  (presumably the iteration count - confirm against {@code GradientUpdater})
     * @param epoch     second argument forwarded to {@code config.getLearningRate}
     *                  (presumably the epoch count - confirm against {@code GradientUpdater})
     * @throws IllegalStateException if no state view array has been set
     */
    @Override
    public void applyUpdater(INDArray gradient, int iteration, int epoch) {
        if (this.historicalGradient == null) {
            throw new IllegalStateException("Updater has not been initialized with view state");
        }
        double lr = this.config.getLearningRate(iteration, epoch);
        double eps = this.config.getEpsilon();

        // Accumulate the squared gradient into the persistent state.
        this.historicalGradient.addi(gradient.mul(gradient));

        // Work on a copy of the history so the in-place sqrt/add/divide below never touch the state.
        INDArray scale = Transforms.sqrt(this.historicalGradient.dup(this.gradientReshapeOrder), false);
        scale.addi(eps);
        // rdivi is the reversed in-place division: scale = lr / scale.
        scale.rdivi(lr);

        gradient.muli(scale);
    }

    /** @return the accumulated squared-gradient state, or {@code null} if none has been set */
    public INDArray getHistoricalGradient() {
        return this.historicalGradient;
    }

    /** @return the stored shape array (the internal reference, not a copy) */
    public int[] getShape() {
        return this.shape;
    }

    /** @return the value of the {@code learningRate} field (not the configuration's learning rate) */
    public double getLearningRate() {
        return this.learningRate;
    }

    /** @return the value of the {@code numIterations} field */
    public int getNumIterations() {
        return this.numIterations;
    }

    /** @return the value of the {@code epsilon} field (not the configuration's epsilon) */
    public double getEpsilon() {
        return this.epsilon;
    }

    /** @return the order character recorded by the last successful {@link #setStateViewArray} or setter call */
    public char getGradientReshapeOrder() {
        return this.gradientReshapeOrder;
    }

    /** @return the configuration this updater reads its learning rate and epsilon from */
    @Override
    public AdaGrad getConfig() {
        return this.config;
    }

    /** @param iNDArray new squared-gradient state; stored as-is without reshaping or validation */
    public void setHistoricalGradient(INDArray iNDArray) {
        this.historicalGradient = iNDArray;
    }

    /** @param iArr new shape array; the reference is stored without copying */
    public void setShape(int[] iArr) {
        this.shape = iArr;
    }

    /** @param d new value for the {@code learningRate} field */
    public void setLearningRate(double d) {
        this.learningRate = d;
    }

    /** @param i new value for the {@code numIterations} field */
    public void setNumIterations(int i) {
        this.numIterations = i;
    }

    /** @param d new value for the {@code epsilon} field */
    public void setEpsilon(double d) {
        this.epsilon = d;
    }

    /** @param c new order character used when duplicating the history in {@link #applyUpdater} */
    public void setGradientReshapeOrder(char c) {
        this.gradientReshapeOrder = c;
    }

    /** @param adaGrad new configuration */
    public void setConfig(AdaGrad adaGrad) {
        this.config = adaGrad;
    }

    /**
     * Field-by-field equality over the state array, shape, learning rate, iteration count,
     * epsilon, reshape order and configuration.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof AdaGradUpdater)) {
            return false;
        }
        AdaGradUpdater other = (AdaGradUpdater) obj;
        // Gives subclasses a veto so that equality stays symmetric across the hierarchy.
        if (!other.canEqual(this)) {
            return false;
        }

        INDArray mine = getHistoricalGradient();
        INDArray theirs = other.getHistoricalGradient();
        if (mine == null ? theirs != null : !mine.equals(theirs)) {
            return false;
        }
        if (!Arrays.equals(getShape(), other.getShape())) {
            return false;
        }
        if (Double.compare(getLearningRate(), other.getLearningRate()) != 0) {
            return false;
        }
        if (getNumIterations() != other.getNumIterations()) {
            return false;
        }
        if (Double.compare(getEpsilon(), other.getEpsilon()) != 0) {
            return false;
        }
        if (getGradientReshapeOrder() != other.getGradientReshapeOrder()) {
            return false;
        }

        AdaGrad myConfig = getConfig();
        AdaGrad theirConfig = other.getConfig();
        return myConfig == null ? theirConfig == null : myConfig.equals(theirConfig);
    }

    /**
     * @param obj candidate for comparison
     * @return {@code true} if {@code obj} is an {@code AdaGradUpdater}
     */
    protected boolean canEqual(Object obj) {
        return obj instanceof AdaGradUpdater;
    }

    /** Hash over the same fields, in the same order, as {@link #equals(Object)}. */
    @Override
    public int hashCode() {
        final int prime = 59;
        final int nullHash = 43;

        INDArray history = getHistoricalGradient();
        long lrBits = Double.doubleToLongBits(getLearningRate());
        long epsBits = Double.doubleToLongBits(getEpsilon());
        AdaGrad cfg = getConfig();

        int result = 1;
        result = result * prime + (history == null ? nullHash : history.hashCode());
        result = result * prime + Arrays.hashCode(getShape());
        result = result * prime + (int) ((lrBits >>> 32) ^ lrBits);
        result = result * prime + getNumIterations();
        result = result * prime + (int) ((epsBits >>> 32) ^ epsBits);
        result = result * prime + getGradientReshapeOrder();
        result = result * prime + (cfg == null ? nullHash : cfg.hashCode());
        return result;
    }

    @Override
    public String toString() {
        return "AdaGradUpdater(historicalGradient=" + getHistoricalGradient()
                + ", shape=" + Arrays.toString(getShape())
                + ", learningRate=" + getLearningRate()
                + ", numIterations=" + getNumIterations()
                + ", epsilon=" + getEpsilon()
                + ", gradientReshapeOrder=" + getGradientReshapeOrder()
                + ", config=" + getConfig() + ")";
    }
}
