package org.nd4j.linalg.learning;

import java.util.HashMap;
import java.util.Map;
import lombok.NonNull;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.shape.Shape;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;
import org.nd4j.linalg.learning.config.AdaMax;

/* loaded from: input_file:org/nd4j/linalg/learning/AdaMaxUpdater.class */
/**
 * {@link GradientUpdater} implementation driven by an {@link AdaMax} configuration.
 *
 * <p>The updater holds two state arrays, {@code m} and {@code u}, which are either supplied
 * through a state map ({@link #setState(Map, boolean)}) or carved out of a single view array
 * ({@link #setStateViewArray(INDArray, long[], char, boolean)}). The actual update is delegated
 * to the {@code org.nd4j.linalg.api.ops.impl.updaters.AdaMaxUpdater} op executed via
 * {@link Nd4j#exec}.
 *
 * <p>Instances are mutable; no synchronization is performed by this class.
 */
public class AdaMaxUpdater implements GradientUpdater<AdaMax> {
    /** State-map key under which the {@code m} array is stored. */
    public static final String M_STATE = "M";

    /**
     * State-map key under which the {@code u} array is stored.
     *
     * <p>NOTE(review): the literal is {@code "V"} even though the field is named {@code u}.
     * It is left unchanged because existing state maps presumably rely on this key.
     */
    public static final String U_STATE = "V";

    private final AdaMax config;
    private INDArray m;
    private INDArray u;
    private char gradientReshapeOrder;

    /**
     * Creates an updater with no state; state must be provided before
     * {@link #applyUpdater(INDArray, int, int)} is called.
     *
     * @param adaMax configuration supplying learning rate, beta1, beta2 and epsilon
     */
    public AdaMaxUpdater(AdaMax adaMax) {
        this.config = adaMax;
    }

    /**
     * Replaces the updater state with the arrays found in the given map.
     *
     * @param map state map; must contain exactly the keys {@link #M_STATE} and {@link #U_STATE}
     * @param z   not used by this implementation
     * @throws NullPointerException  if {@code map} is {@code null}
     * @throws IllegalStateException if the map does not contain exactly the two expected keys
     */
    @Override
    public void setState(@NonNull Map<String, INDArray> map, boolean z) {
        if (map == null) {
            throw new NullPointerException("stateMap is marked non-null but is null");
        }
        if (map.size() != 2 || !map.containsKey(M_STATE) || !map.containsKey(U_STATE)) {
            throw new IllegalStateException(
                    "State map should contain only keys [" + M_STATE + "," + U_STATE + "] but has keys " + map.keySet());
        }
        this.m = map.get(M_STATE);
        this.u = map.get(U_STATE);
    }

    /**
     * Returns the current state as a newly created map.
     *
     * @return a new map with {@link #M_STATE} mapped to {@code m} and {@link #U_STATE} mapped
     *         to {@code u}; the values are {@code null} if no state has been set yet
     */
    @Override
    public Map<String, INDArray> getState() {
        Map<String, INDArray> state = new HashMap<>();
        state.put(M_STATE, this.m);
        state.put(U_STATE, this.u);
        return state;
    }

    /**
     * Derives both state arrays from a single row-vector view: the first half of the view
     * becomes {@code m}, the second half becomes {@code u}, and each half is reshaped without
     * copying to the requested shape.
     *
     * <p>The updater's fields are only modified once both reshapes have succeeded, so a failed
     * call leaves any previously configured state untouched.
     *
     * @param iNDArray row vector backing both state arrays
     * @param jArr     target shape for each state array
     * @param c        reshape order; {@code 'f'} selects the f-order reshape, any other value
     *                 the alternative order
     * @param z        if {@code true}, the whole view is zeroed before it is split
     * @throws IllegalArgumentException if {@code iNDArray} is not a row vector
     * @throws IllegalStateException    if either half cannot be reshaped without a copy
     */
    @Override
    public void setStateViewArray(INDArray iNDArray, long[] jArr, char c, boolean z) {
        if (!iNDArray.isRowVector()) {
            throw new IllegalArgumentException("Invalid input: expect row vector input");
        }
        if (z) {
            iNDArray.assign(0);
        }

        long length = iNDArray.length();
        long half = length / 2;
        boolean fOrder = c == 'f';

        INDArray mView = iNDArray.get(NDArrayIndex.point(0L), NDArrayIndex.interval(0L, half));
        INDArray uView = iNDArray.get(NDArrayIndex.point(0L), NDArrayIndex.interval(half, length));

        // Reshape into locals first: newShapeNoCopy signals failure by returning null, and the
        // fields must not be overwritten with a null/partial result in that case.
        INDArray mReshaped = Shape.newShapeNoCopy(mView, jArr, fOrder);
        INDArray uReshaped = Shape.newShapeNoCopy(uView, jArr, fOrder);
        if (mReshaped == null || uReshaped == null) {
            throw new IllegalStateException("Could not correctly reshape gradient view arrays");
        }

        this.m = mReshaped;
        this.u = uReshaped;
        this.gradientReshapeOrder = c;
    }

    /**
     * Executes the AdaMax updater op on the given array using the current state.
     *
     * @param iNDArray array handed to the op as its first argument (presumably the gradient,
     *                 updated in place by the op; confirm against the op implementation)
     * @param i        first argument to {@code config.getLearningRate} and last argument of the
     *                 op (presumably the iteration counter; confirm against the interface)
     * @param i2       second argument to {@code config.getLearningRate} (presumably the epoch
     *                 counter; confirm against the interface)
     * @throws IllegalStateException if either state array has not been set
     */
    @Override
    public void applyUpdater(INDArray iNDArray, int i, int i2) {
        if (this.m == null || this.u == null) {
            throw new IllegalStateException("Updater has not been initialized with view state");
        }
        double learningRate = this.config.getLearningRate(i, i2);
        // Argument order expected by the op: input array, u state, m state, then hyper-parameters.
        Nd4j.exec(new org.nd4j.linalg.api.ops.impl.updaters.AdaMaxUpdater(
                iNDArray,
                this.u,
                this.m,
                learningRate,
                this.config.getBeta1(),
                this.config.getBeta2(),
                this.config.getEpsilon(),
                i));
    }

    /**
     * @return the configuration this updater was constructed with
     */
    @Override
    public AdaMax getConfig() {
        return this.config;
    }

    /**
     * @return the {@code m} state array, or {@code null} if not yet set
     */
    public INDArray getM() {
        return this.m;
    }

    /**
     * @return the {@code u} state array, or {@code null} if not yet set
     */
    public INDArray getU() {
        return this.u;
    }

    /**
     * @return the order character recorded by the last successful
     *         {@link #setStateViewArray(INDArray, long[], char, boolean)} call or set explicitly
     */
    public char getGradientReshapeOrder() {
        return this.gradientReshapeOrder;
    }

    /**
     * @param iNDArray new {@code m} state array (stored by reference, not validated)
     */
    public void setM(INDArray iNDArray) {
        this.m = iNDArray;
    }

    /**
     * @param iNDArray new {@code u} state array (stored by reference, not validated)
     */
    public void setU(INDArray iNDArray) {
        this.u = iNDArray;
    }

    /**
     * @param c new reshape order character
     */
    public void setGradientReshapeOrder(char c) {
        this.gradientReshapeOrder = c;
    }

    /**
     * Two updaters are equal when their reshape order, configuration and both state arrays
     * are equal.
     */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof AdaMaxUpdater)) {
            return false;
        }
        AdaMaxUpdater other = (AdaMaxUpdater) obj;
        return other.canEqual(this)
                && getGradientReshapeOrder() == other.getGradientReshapeOrder()
                && nullSafeEquals(getConfig(), other.getConfig())
                && nullSafeEquals(getM(), other.getM())
                && nullSafeEquals(getU(), other.getU());
    }

    /**
     * Hook that lets subclasses restrict which instances may compare equal to them.
     *
     * @param obj candidate object
     * @return {@code true} if {@code obj} is an {@code AdaMaxUpdater}
     */
    protected boolean canEqual(Object obj) {
        return obj instanceof AdaMaxUpdater;
    }

    /**
     * Hash over the same fields used by {@link #equals(Object)}, using multiplier 59 and 43
     * as the stand-in hash for {@code null} fields.
     */
    @Override
    public int hashCode() {
        int result = 59 + getGradientReshapeOrder();
        result = result * 59 + nullSafeHash(getConfig());
        result = result * 59 + nullSafeHash(getM());
        result = result * 59 + nullSafeHash(getU());
        return result;
    }

    @Override
    public String toString() {
        return "AdaMaxUpdater(config=" + getConfig()
                + ", m=" + getM()
                + ", u=" + getU()
                + ", gradientReshapeOrder=" + getGradientReshapeOrder() + ")";
    }

    /** Null-tolerant equality that always defers to {@code a.equals(b)} when {@code a} is non-null. */
    private static boolean nullSafeEquals(Object a, Object b) {
        return a == null ? b == null : a.equals(b);
    }

    /** Hash of {@code o}, or 43 when {@code o} is {@code null}. */
    private static int nullSafeHash(Object o) {
        return o == null ? 43 : o.hashCode();
    }
}
