package org.nd4j.linalg.lossfunctions.impl;

import org.nd4j.base.Preconditions;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.shape.OneHot;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonProperty;

/**
 * Sparse variant of {@link LossMCXENT} (multi-class cross entropy).
 * <p>
 * Labels are given as integer class indices rather than one-hot vectors. The labels array must have
 * a last dimension of size 1 (for example shape {@code [minibatch, 1]}), with one index per example.
 * Before every computation the indices are expanded into a one-hot array shaped like
 * {@code preOutput}, with depth {@code preOutput.size(-1)}. The work is then delegated to the
 * corresponding {@link LossMCXENT} method. Per-class {@code weights} and {@code softmaxClipEps} are
 * passed straight through to {@link LossMCXENT}.
 */
@JsonInclude(JsonInclude.Include.NON_NULL)
public class LossSparseMCXENT extends LossMCXENT {
    /** Softmax clipping epsilon used by the constructors that do not take one explicitly. */
    private static final double DEFAULT_SOFTMAX_CLIPPING_EPSILON = 1.0E-10d;

    /** Creates an unweighted loss with the default softmax clipping epsilon. */
    public LossSparseMCXENT() {
        this(null);
    }

    /**
     * Creates a loss with the default softmax clipping epsilon.
     *
     * @param weights per-class weights forwarded to {@link LossMCXENT}; may be {@code null} for no weighting
     */
    public LossSparseMCXENT(INDArray weights) {
        this(DEFAULT_SOFTMAX_CLIPPING_EPSILON, weights);
    }

    /**
     * Creates a loss with an explicit softmax clipping epsilon. This is also the JSON deserialization
     * constructor.
     *
     * @param softmaxClipEps clipping epsilon forwarded to {@link LossMCXENT}
     * @param weights        per-class weights forwarded to {@link LossMCXENT}; may be {@code null}
     */
    public LossSparseMCXENT(@JsonProperty("softmaxClipEps") double softmaxClipEps,
                            @JsonProperty("weights") INDArray weights) {
        super(softmaxClipEps, weights);
    }

    /**
     * Converts the sparse labels to one-hot form and returns {@link LossMCXENT#scoreArray} for them.
     *
     * @param labels       class indices, last dimension of size 1
     * @param preOutput    pre-activation network output; classes along the last dimension
     * @param activationFn activation applied to {@code preOutput}
     * @param mask         optional mask, forwarded unchanged
     * @return the unreduced score array computed by {@link LossMCXENT}
     */
    protected INDArray sparseScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        INDArray oneHotLabels = toOneHot(labels, preOutput);
        return super.scoreArray(oneHotLabels, preOutput, activationFn, mask);
    }

    /**
     * Computes the scalar score after converting the sparse labels to one-hot form.
     *
     * @param average forwarded to {@link LossMCXENT#computeScore}
     */
    @Override
    public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask,
                               boolean average) {
        INDArray oneHotLabels = toOneHot(labels, preOutput);
        return super.computeScore(oneHotLabels, preOutput, activationFn, mask, average);
    }

    /**
     * Computes one score per example. The score array from {@link #sparseScoreArray} is summed
     * along dimension 1 (keeping dimensions) and then negated.
     */
    @Override
    public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        INDArray scoreArr = sparseScoreArray(labels, preOutput, activationFn, mask);
        // NOTE(review): negation presumably turns the log-likelihood terms produced by
        // LossMCXENT.scoreArray into a positive loss - confirm against LossMCXENT.
        return scoreArr.sum(true, 1).muli(-1);
    }

    /** Computes the gradient with respect to {@code preOutput} after converting the sparse labels to one-hot form. */
    @Override
    public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        INDArray oneHotLabels = toOneHot(labels, preOutput);
        return super.computeGradient(oneHotLabels, preOutput, activationFn, mask);
    }

    /**
     * Computes both score and gradient. The labels are converted to one-hot form once and reused
     * for both computations.
     */
    @Override
    public Pair<Double, INDArray> computeGradientAndScore(INDArray labels, INDArray preOutput, IActivation activationFn,
                                                          INDArray mask, boolean average) {
        INDArray oneHotLabels = toOneHot(labels, preOutput);
        double score = super.computeScore(oneHotLabels, preOutput, activationFn, mask, average);
        INDArray gradient = super.computeGradient(oneHotLabels, preOutput, activationFn, mask);
        return new Pair<>(score, gradient);
    }

    /**
     * Expands integer class-index labels into a one-hot array shaped like {@code preOutput}.
     *
     * @param labels    class indices; last dimension must have size 1, one index per example
     * @param preOutput network output whose last dimension gives the number of classes
     * @return a new array allocated via {@code preOutput.ulike()} and filled by the OneHot op
     * @throws IllegalStateException if the label shape is invalid or the label count does not
     *                               match the number of examples in {@code preOutput}
     */
    private INDArray toOneHot(INDArray labels, INDArray preOutput) {
        Preconditions.checkState(labels.size(-1) == 1, "Labels for LossSparseMCXENT should be an array of integers with first dimension equal to minibatch size, and last dimension having size 1. Got labels array with shape %ndShape", labels);
        long numClasses = preOutput.size(-1);
        // Exactly one label index is needed per example (preOutput.length() / numClasses). Check it
        // here with a descriptive message instead of relying on the OneHot op to detect a mismatch.
        // Multiplication avoids a division by zero when numClasses == 0.
        Preconditions.checkState(labels.length() * numClasses == preOutput.length(),
                "Number of labels for LossSparseMCXENT must equal the number of examples in preOutput. Got labels array with shape %ndShape and preOutput array with shape %ndShape",
                labels, preOutput);
        INDArray oneHotLabels = preOutput.ulike();
        // Flatten to a vector of indices; OneHot writes the expanded result into oneHotLabels.
        Nd4j.exec(new OneHot(labels.reshape(labels.length()), oneHotLabels, (int) numClasses));
        return oneHotLabels;
    }

    /** Returns {@code LossSparseMCXENT()} or {@code LossSparseMCXENT(weights=...)} when weights are set. */
    @Override
    public String toString() {
        if (this.weights == null) {
            return "LossSparseMCXENT()";
        }
        return "LossSparseMCXENT(weights=" + this.weights + ")";
    }

    /** Equal when {@code obj} is a {@code LossSparseMCXENT} and {@link LossMCXENT#equals} agrees. */
    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof LossSparseMCXENT)) {
            return false;
        }
        return ((LossSparseMCXENT) obj).canEqual(this) && super.equals(obj);
    }

    /** Symmetry hook for {@link #equals(Object)}: accepts only {@code LossSparseMCXENT} instances. */
    @Override
    protected boolean canEqual(Object obj) {
        return obj instanceof LossSparseMCXENT;
    }

    /** No fields beyond those of {@link LossMCXENT}, so the parent hash is reused. */
    @Override
    public int hashCode() {
        return super.hashCode();
    }
}
