package org.nd4j.linalg.learning;

import java.io.Serializable;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.util.Shape;

/* loaded from: input_file:org/nd4j/linalg/learning/AdaGrad.class */
/**
 * AdaGrad adaptive learning-rate helper.
 *
 * <p>Keeps a per-element running sum of squared gradients ({@link #historicalGradient}) and
 * scales each incoming gradient by {@code masterStepSize / sqrt(history)}. Here {@code history}
 * is the accumulated value from before the current gradient is added. The history is created
 * lazily and filled with ones, so the very first step is scaled by exactly
 * {@code masterStepSize}.
 *
 * <p>Instances are mutable and no method is synchronized.
 */
public class AdaGrad implements Serializable {
    protected static final long serialVersionUID = -4754127927704099888L;

    /** Step size used by the constructors that do not take one explicitly. */
    private static final double DEFAULT_MASTER_STEP_SIZE = 0.1d;

    /** Global learning rate, divided by the square root of the accumulated history. */
    protected double masterStepSize = DEFAULT_MASTER_STEP_SIZE;

    /**
     * Running per-element sum of squared gradients (starting from one).
     *
     * <p>It is created lazily on first use, or assigned directly by {@link #createSubset(int)}.
     */
    public INDArray historicalGradient;

    /** Shape used by {@link #createSubset(int)} to create {@link #historicalGradient} if missing. */
    public int[] shape;

    /** Number of {@code getGradient} calls made on this instance. */
    protected int numIterations;

    /** Learning-rate decay flag; stored and copied to subsets, but not read by this class. */
    protected boolean decayLr;

    /**
     * Creates a helper for a {@code rows x columns} parameter matrix.
     *
     * @param rows number of rows of the tracked parameters
     * @param columns number of columns of the tracked parameters
     * @param masterStepSize global learning rate
     */
    public AdaGrad(int rows, int columns, double masterStepSize) {
        this.shape = new int[] {rows, columns};
        this.masterStepSize = masterStepSize;
    }

    /**
     * Creates a helper for parameters of an arbitrary shape using the default step size (0.1).
     *
     * @param shape shape of the tracked parameters; stored by reference, not copied
     */
    public AdaGrad(int[] shape) {
        this.shape = shape;
    }

    /**
     * Creates a helper for a {@code rows x columns} parameter matrix using the default step
     * size (0.1).
     *
     * @param rows number of rows of the tracked parameters
     * @param columns number of columns of the tracked parameters
     */
    public AdaGrad(int rows, int columns) {
        this(rows, columns, DEFAULT_MASTER_STEP_SIZE);
    }

    /**
     * Scales a single gradient component and records it in the history.
     *
     * <p>Returns {@code gradient * masterStepSize / sqrt(history[index])}, using the history as
     * it was before this call. Afterwards {@code gradient^2} is added to {@code history[index]}.
     *
     * @param gradient raw gradient value
     * @param index linear index of the component within the history
     * @param historyShape shape used to create the history if it does not exist yet
     * @return the AdaGrad-scaled gradient
     */
    public double getGradient(double gradient, int index, int[] historyShape) {
        if (this.historicalGradient == null) {
            this.historicalGradient = Nd4j.ones(historyShape);
        }
        double history = this.historicalGradient.getDouble(index);
        // A freshly created history holds 1.0 and sqrt(1.0) == 1.0, so the first call needs no
        // special case.
        double adjusted = gradient * (this.masterStepSize / Math.sqrt(history));
        this.historicalGradient.putScalar(index, history + gradient * gradient);
        this.numIterations++;
        return adjusted;
    }

    /**
     * Creates a new helper whose history is a copy of part of this helper's history.
     *
     * <p>If this helper has no history yet, one filled with ones is first created from
     * {@link #shape}; this mutates the current instance. When {@link #shape} is a matrix shape,
     * the subset receives a copy of {@code slice(index)}. Otherwise it receives the single
     * element at {@code index}, wrapped as a scalar. The master step size and decay flag are
     * copied to the subset.
     *
     * @param index slice (matrix case) or element (otherwise) to extract
     * @return a new, independent helper
     */
    public AdaGrad createSubset(int index) {
        if (this.historicalGradient == null) {
            this.historicalGradient = Nd4j.ones(this.shape);
        }
        AdaGrad subset;
        if (Shape.isMatrix(this.shape)) {
            subset = new AdaGrad(1, this.historicalGradient.columns());
            subset.historicalGradient = this.historicalGradient.slice(index).dup();
        } else {
            subset = new AdaGrad(1, 1);
            subset.historicalGradient = Nd4j.scalar(this.historicalGradient.getDouble(index));
        }
        subset.setMasterStepSize(this.masterStepSize);
        subset.setDecayLr(this.decayLr);
        return subset;
    }

    /**
     * Scales the gradient for one slice of the parameters in place and records it in the
     * history.
     *
     * <p>When the history is a vector, the whole history is used for scaling. Otherwise only
     * {@code historicalGradient.slice(slice)} is used. In both cases the raw squared gradient
     * is then added to {@code historicalGradient.slice(slice)}.
     *
     * @param gradient raw gradient; modified in place and returned
     * @param slice index of the history slice this gradient belongs to
     * @param historyShape shape used to create the history if it does not exist yet
     * @return {@code gradient}, scaled by {@code masterStepSize / sqrt(history)}
     * @throws IllegalArgumentException if the history is a non-vector whose slice length
     *     differs from the gradient length
     */
    public INDArray getGradient(INDArray gradient, int slice, int[] historyShape) {
        if (this.historicalGradient == null) {
            this.historicalGradient = Nd4j.ones(historyShape);
        } else if (!this.historicalGradient.isVector()
                && this.historicalGradient.slice(slice).length() != gradient.length()) {
            throw new IllegalArgumentException("Illegal gradient");
        }
        // Capture the raw squared gradient before `gradient` is scaled in place below.
        INDArray squaredGradient = gradient.mul(gradient);
        // rdivi works in place, so the sqrt must be taken on a copy that never aliases the
        // stored history (otherwise the history itself would be overwritten).
        INDArray source = this.historicalGradient.isVector()
                ? this.historicalGradient
                : this.historicalGradient.slice(slice);
        INDArray learningRates =
                Transforms.sqrt(source.dup()).rdivi(Double.valueOf(this.masterStepSize));
        gradient.muli(learningRates);
        // NOTE(review): in the vector case this adds the squared gradient to a single-element
        // slice, as the original did; confirm that is the intended shape pairing.
        this.historicalGradient.slice(slice).addi(squaredGradient);
        this.numIterations++;
        return gradient;
    }

    /**
     * Scales a full gradient in place and records it in the history.
     *
     * <p>On first use the history is created as a {@code rows x columns} array of ones matching
     * the gradient. The gradient is then scaled by {@code masterStepSize / sqrt(history)},
     * using the history from before this call, and the raw squared gradient is added to the
     * history.
     *
     * @param gradient raw gradient; modified in place and returned
     * @return {@code gradient}, scaled element-wise
     * @throws IllegalArgumentException if an existing history has a different length than the
     *     gradient
     */
    public INDArray getGradient(INDArray gradient) {
        if (this.historicalGradient == null) {
            this.historicalGradient = Nd4j.ones(gradient.rows(), gradient.columns());
        } else if (this.historicalGradient.length() != gradient.length()) {
            throw new IllegalArgumentException("Illegal gradient");
        }
        // Capture the raw squared gradient before `gradient` is scaled in place below.
        INDArray squaredGradient = gradient.mul(gradient);
        // Copy before sqrt/rdivi so the in-place division never touches the stored history.
        INDArray learningRates = Transforms.sqrt(this.historicalGradient.dup())
                .rdivi(Double.valueOf(this.masterStepSize));
        gradient.muli(learningRates);
        this.historicalGradient.addi(squaredGradient);
        this.numIterations++;
        return gradient;
    }

    /**
     * Returns the global learning rate.
     *
     * @return the master step size
     */
    public double getMasterStepSize() {
        return this.masterStepSize;
    }

    /**
     * Sets the global learning rate.
     *
     * @param masterStepSize new master step size
     */
    public void setMasterStepSize(double masterStepSize) {
        this.masterStepSize = masterStepSize;
    }

    /**
     * Returns the learning-rate decay flag.
     *
     * @return whether decay is enabled
     */
    public boolean isDecayLr() {
        return this.decayLr;
    }

    /**
     * Sets the learning-rate decay flag.
     *
     * @param decayLr new flag value
     */
    public void setDecayLr(boolean decayLr) {
        this.decayLr = decayLr;
    }
}
