package org.nd4j.linalg.lossfunctions.impl;

import org.nd4j.common.base.Preconditions;
import org.nd4j.common.primitives.Pair;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.DynamicCustomOp;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.same.TimesOneMinus;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.LossUtil;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.serde.jackson.shaded.NDArrayTextDeSerializer;
import org.nd4j.serde.jackson.shaded.NDArrayTextSerializer;
import org.nd4j.shade.jackson.annotation.JsonInclude;
import org.nd4j.shade.jackson.annotation.JsonProperty;
import org.nd4j.shade.jackson.databind.annotation.JsonDeserialize;
import org.nd4j.shade.jackson.databind.annotation.JsonSerialize;

/**
 * Binary cross-entropy (XENT) loss with optional per-output weights and probability clipping.
 *
 * <p>For each element the (un-negated) score term is
 * {@code labels * log(p) + (1 - labels) * log(1 - p)}, where {@code p} is the activation of
 * {@code preOutput}; the reported loss is the negated sum of these terms. When
 * {@code clipEps > 0}, {@code p} is clipped to {@code [clipEps, 1 - clipEps]} before the
 * logarithms are taken so the score stays finite.
 *
 * <p>When the activation function is {@link ActivationSoftmax}, the score term is instead
 * {@code labels * log(softmax(preOutput))} with softmax over the last dimension.
 * NOTE(review): that branch omits the {@code (1 - labels) * log(1 - p)} term and the gradient
 * path does not special-case softmax — confirm this is intended.
 *
 * <p>Optional {@code weights} (a row vector with one entry per output column) scale each
 * column's contribution to both the score and the gradient.
 */
@JsonInclude(JsonInclude.Include.NON_NULL)
public class LossBinaryXENT implements ILossFunction {

    /** Clipping epsilon used by the constructors that do not take one explicitly. */
    public static final double DEFAULT_CLIPPING_EPSILON = 1.0E-5d;

    @JsonDeserialize(using = NDArrayTextDeSerializer.class)
    @JsonSerialize(using = NDArrayTextSerializer.class)
    private final INDArray weights;
    private double clipEps;

    /** Creates an unweighted loss using {@link #DEFAULT_CLIPPING_EPSILON}. */
    public LossBinaryXENT() {
        this((INDArray) null);
    }

    /**
     * Creates a weighted loss using {@link #DEFAULT_CLIPPING_EPSILON}.
     *
     * @param weights row vector of per-output weights, or {@code null} for no weighting
     * @throws IllegalArgumentException if {@code weights} is non-null and not a row vector
     */
    public LossBinaryXENT(INDArray weights) {
        this(DEFAULT_CLIPPING_EPSILON, weights);
    }

    /**
     * Creates an unweighted loss with the given clipping epsilon.
     *
     * @param clipEps clipping epsilon in {@code [0, 0.5]}; {@code 0} disables clipping
     * @throws IllegalArgumentException if {@code clipEps} is outside {@code [0, 0.5]}
     */
    public LossBinaryXENT(double clipEps) {
        this(clipEps, null);
    }

    /**
     * Creates a loss with the given clipping epsilon and optional per-output weights.
     *
     * @param clipEps clipping epsilon in {@code [0, 0.5]}; {@code 0} disables clipping
     * @param weights row vector of per-output weights, or {@code null} for no weighting
     * @throws IllegalArgumentException if {@code weights} is non-null and not a row vector, or
     *                                  if {@code clipEps} is outside {@code [0, 0.5]} (or NaN)
     */
    public LossBinaryXENT(@JsonProperty("clipEps") double clipEps, @JsonProperty("weights") INDArray weights) {
        if (weights != null && !weights.isRowVector()) {
            throw new IllegalArgumentException("Weights array must be a row vector");
        }
        validateClipEps(clipEps);
        this.clipEps = clipEps;
        this.weights = weights;
    }

    /**
     * Rejects clipping epsilons outside {@code [0, 0.5]}. Above 0.5 the clip range
     * {@code [eps, 1 - eps]} would be empty; NaN is rejected too, because it would poison the
     * clip bounds.
     */
    private static void validateClipEps(double clipEps) {
        if (!(clipEps >= 0.0d && clipEps <= 0.5d)) {
            throw new IllegalArgumentException("Invalid clipping epsilon value: epsilon should be >= 0 (but near zero).Got: " + clipEps);
        }
    }

    /** Throws (via {@link Preconditions#throwEx}) if labels and preOutput shapes differ. */
    private static void checkEqualShapes(INDArray labels, INDArray preOutput) {
        if (!labels.equalShapes(preOutput)) {
            Preconditions.throwEx("Labels and preOutput must have equal shapes: got shapes %s vs %s",
                    new Object[]{labels.shape(), preOutput.shape()});
        }
    }

    /** Throws if the (non-null) weights vector length differs from the number of output columns. */
    private void checkWeightsLength(long numOutputs) {
        if (weights.length() != numOutputs) {
            throw new IllegalStateException("Weights vector (length " + weights.length() + ") does not match output.size(1)=" + numOutputs);
        }
    }

    /** Clips {@code probabilities} in place to {@code [clipEps, 1 - clipEps]} when {@code clipEps > 0}. */
    private void clipInPlace(INDArray probabilities) {
        if (clipEps > 0.0d) {
            Nd4j.getExecutioner().execAndReturn(DynamicCustomOp.builder("clipbyvalue")
                    .addInputs(probabilities)
                    .callInplace(true)
                    .addFloatingPointArguments(clipEps, 1.0d - clipEps)
                    .build());
        }
    }

    /**
     * Computes the un-negated, per-element score terms with weights and mask applied.
     *
     * @return a new array with the same shape as {@code preOutput}
     */
    private INDArray scoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        checkEqualShapes(labels, preOutput);
        INDArray typedLabels = labels.castTo(preOutput.dataType());

        INDArray scoreArr;
        if (activationFn instanceof ActivationSoftmax) {
            // softmax along the last dimension into a fresh array, then log in place
            INDArray logSoftmax = Nd4j.exec(new SoftMax(preOutput, preOutput.ulike(), -1))[0];
            Transforms.log(logSoftmax, false);
            scoreArr = logSoftmax.muli(typedLabels);
        } else {
            // dup() so the activation does not overwrite the caller's preOutput
            INDArray output = activationFn.getActivation(preOutput.dup(), true);
            clipInPlace(output);
            scoreArr = Transforms.log(output, true).muli(typedLabels);   // y * log(p)
            INDArray negativeTerm = output.rsubi(1.0);                    // 1 - p (reuses output)
            Transforms.log(negativeTerm, false);
            negativeTerm.muli(typedLabels.rsub(1.0));                     // (1 - y) * log(1 - p)
            scoreArr.addi(negativeTerm);
        }

        if (weights != null) {
            checkWeightsLength(preOutput.size(1));
            scoreArr.muliRowVector(weights.castTo(scoreArr.dataType()));
        }
        if (mask != null) {
            LossUtil.applyMask(scoreArr, mask);
        }
        return scoreArr;
    }

    /**
     * Returns the total loss: the negated sum of all score terms, divided by the number of
     * examples ({@code size(0)}) when {@code average} is true.
     */
    @Override
    public double computeScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) {
        INDArray scoreArr = scoreArray(labels, preOutput, activationFn, mask);
        double score = -scoreArr.sumNumber().doubleValue();
        if (average) {
            score /= scoreArr.size(0);
        }
        return score;
    }

    /** Returns the per-example loss: negated sum over dimension 1, keeping dimensions. */
    @Override
    public INDArray computeScoreArray(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        return scoreArray(labels, preOutput, activationFn, mask).sum(true, 1).muli(-1);
    }

    /**
     * Returns dL/dpreOutput. The gradient with respect to the activation output is
     * {@code (p - y) / (p * (1 - p))}, using the clipped {@code p}. It is backpropagated through
     * {@code activationFn}, then weights and mask are applied.
     */
    @Override
    public INDArray computeGradient(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask) {
        checkEqualShapes(labels, preOutput);
        INDArray typedLabels = labels.castTo(preOutput.dataType());
        INDArray output = activationFn.getActivation(preOutput.dup(), true);
        clipInPlace(output);

        // Compute the numerator first: TimesOneMinus(output) may write into output in place — TODO confirm
        INDArray numerator = output.sub(typedLabels);
        INDArray denominator = Nd4j.getExecutioner().exec(new TimesOneMinus(output));
        INDArray dLda = numerator.divi(denominator);

        if (mask != null && LossUtil.isPerOutputMasking(dLda, mask)) {
            // per-output masks are applied before backprop as well as after it
            LossUtil.applyMask(dLda, mask);
        }

        INDArray grad = activationFn.backprop(preOutput, dLda).getFirst();

        if (weights != null) {
            checkWeightsLength(output.size(1));
            grad.muliRowVector(weights.castTo(grad.dataType()));
        }
        if (mask != null) {
            LossUtil.applyMask(grad, mask);
        }
        return grad;
    }

    /** Returns {@link #computeScore} and {@link #computeGradient} for the same inputs. */
    @Override
    public Pair<Double, INDArray> computeGradientAndScore(INDArray labels, INDArray preOutput, IActivation activationFn, INDArray mask, boolean average) {
        return new Pair<>(
                computeScore(labels, preOutput, activationFn, mask, average),
                computeGradient(labels, preOutput, activationFn, mask));
    }

    @Override
    public String name() {
        return toString();
    }

    @Override
    public String toString() {
        return weights == null ? "LossBinaryXENT()" : "LossBinaryXENT(weights=" + weights + ")";
    }

    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof LossBinaryXENT)) {
            return false;
        }
        LossBinaryXENT other = (LossBinaryXENT) obj;
        if (!other.canEqual(this) || Double.compare(getClipEps(), other.getClipEps()) != 0) {
            return false;
        }
        INDArray myWeights = getWeights();
        INDArray otherWeights = other.getWeights();
        return myWeights == null ? otherWeights == null : myWeights.equals(otherWeights);
    }

    /** Symmetry hook for {@link #equals}: subclasses that add state should override this. */
    protected boolean canEqual(Object obj) {
        return obj instanceof LossBinaryXENT;
    }

    @Override
    public int hashCode() {
        // Double.hashCode(d) == (int) (bits ^ (bits >>> 32)), matching the previous manual fold
        int result = 59 + Double.hashCode(getClipEps());
        INDArray myWeights = getWeights();
        return result * 59 + (myWeights == null ? 43 : myWeights.hashCode());
    }

    /** Returns the per-output weights row vector, or {@code null} if the loss is unweighted. */
    public INDArray getWeights() {
        return weights;
    }

    /** Returns the clipping epsilon; {@code 0} means no clipping. */
    public double getClipEps() {
        return clipEps;
    }

    /**
     * Sets the clipping epsilon.
     *
     * @param clipEps clipping epsilon in {@code [0, 0.5]}; {@code 0} disables clipping
     * @throws IllegalArgumentException if {@code clipEps} is outside {@code [0, 0.5]} (or NaN)
     */
    public void setClipEps(double clipEps) {
        validateClipEps(clipEps);
        this.clipEps = clipEps;
    }
}
