package org.nd4j.linalg.lossfunctions;

import java.util.Arrays;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/* loaded from: input_file:org/nd4j/linalg/lossfunctions/LossFunctions.class */
/**
 * Static helpers that score an output array against a labels array using a {@link LossFunction}.
 */
public class LossFunctions {

    /** Loss functions accepted by the {@code score} methods of {@link LossFunctions}. */
    public enum LossFunction {
        MSE,
        EXPLL,
        XENT,
        MCXENT,
        RMSE_XENT,
        SQUARED_LOSS,
        RECONSTRUCTION_CROSSENTROPY,
        NEGATIVELOGLIKELIHOOD,
        /** Not supported by the five-argument {@code score}; it throws for this value. */
        CUSTOM
    }

    /**
     * Scores {@code z} against {@code labels} by delegating to {@code LossCalculation}.
     *
     * @param labels the expected values
     * @param lossFunction the loss function to apply
     * @param z the output being scored
     * @param l2 forwarded to {@code LossCalculation.builder().l2(...)}
     * @param l1 forwarded to {@code LossCalculation.builder().l1(...)}
     * @param useRegularization forwarded to {@code LossCalculation.builder().useRegularization(...)}
     * @return the score computed by {@code LossCalculation}
     */
    public static double score(INDArray labels, LossFunction lossFunction, INDArray z,
                               double l2, double l1, boolean useRegularization) {
        return LossCalculation.builder()
                .l1(l1)
                .lossFunction(lossFunction)
                .l2(l2)
                .labels(labels)
                .z(z)
                .useRegularization(useRegularization)
                .build()
                .score();
    }

    /**
     * Computes the loss of {@code output} against {@code labels}, optionally adds a constant
     * regularization term, and divides the result by {@code labels.size(0)}.
     *
     * <p>{@code Nd4j.ENFORCE_NUMERICAL_STABILITY} is forced to {@code true} while the loss is
     * computed and is always restored to its previous value afterwards, including when this
     * method throws.
     *
     * @param labels expected values; must have the same shape as {@code output}
     * @param lossFunction the loss to compute; {@link LossFunction#CUSTOM} is not supported
     * @param output the values being scored
     * @param l2 when {@code useRegularization} is true, {@code 0.5 * l2} is added to the loss
     *     before it is divided by {@code labels.size(0)}
     * @param useRegularization whether to add the {@code 0.5 * l2} term
     * @return the (optionally regularized) loss divided by {@code labels.size(0)}
     * @throws IllegalArgumentException if the shapes of {@code labels} and {@code output} differ
     * @throws IllegalStateException if {@code lossFunction} is {@link LossFunction#CUSTOM}
     */
    public static double score(INDArray labels, LossFunction lossFunction, INDArray output,
                               double l2, boolean useRegularization) {
        if (!Arrays.equals(labels.shape(), output.shape())) {
            throw new IllegalArgumentException("Output and labels must be same length");
        }
        boolean previousEnforce = Nd4j.ENFORCE_NUMERICAL_STABILITY;
        Nd4j.ENFORCE_NUMERICAL_STABILITY = true;
        try {
            double loss = rawLoss(labels, lossFunction, output);
            if (useRegularization) {
                // NOTE(review): no weights are available here, so this "regularization" is only
                // the constant 0.5 * l2; it shifts the score but does not depend on parameters.
                loss += 0.5d * l2;
            }
            return loss / labels.size(0);
        } finally {
            // Restore the global flag even when the loss computation throws (e.g. for CUSTOM),
            // so the forced setting does not leak into unrelated ND4J code.
            Nd4j.ENFORCE_NUMERICAL_STABILITY = previousEnforce;
        }
    }

    /** Computes the loss for {@code lossFunction} before regularization and normalization. */
    private static double rawLoss(INDArray labels, LossFunction lossFunction, INDArray output) {
        switch (lossFunction) {
            case CUSTOM:
                throw new IllegalStateException(
                        "Unable to score custom operation. Please define an alternative mechanism");
            case RECONSTRUCTION_CROSSENTROPY:
                // Same element-wise expression as XENT, but per-row sums are averaged, not summed.
                return crossEntropyTerms(labels, output).sum(1).meanNumber().doubleValue();
            case MCXENT:
            case NEGATIVELOGLIKELIHOOD:
                // -sum(labels * log(output)). Exactly one log: log(log(output)) would be NaN for
                // every output in (0, 1).
                return -labels.mul(Transforms.log(output)).sumNumber().doubleValue();
            case XENT:
                return crossEntropyTerms(labels, output).sum(1).sumNumber().doubleValue();
            case RMSE_XENT:
                // sqrt((labels - output)^2) element-wise, summed over all entries.
                return Transforms.sqrt(Transforms.pow(labels.sub(output), Double.valueOf(2.0d)))
                        .sum(1).sumNumber().doubleValue();
            case MSE:
                return 0.5d * Transforms.pow(labels.sub(output), 2).sum(1).sumNumber().doubleValue();
            case EXPLL:
                return output.sub(labels.mul(Transforms.log(output)))
                        .sum(1).sumNumber().doubleValue();
            case SQUARED_LOSS:
                return Transforms.pow(labels.sub(output), 2).sum(1).sumNumber().doubleValue();
            default:
                throw new IllegalArgumentException("Unsupported loss function: " + lossFunction);
        }
    }

    /**
     * Element-wise {@code (labels * log(output) + (1 - labels)) * (1 - log(output))}, shared by
     * XENT and RECONSTRUCTION_CROSSENTROPY.
     *
     * <p>NOTE(review): conventional binary cross-entropy is
     * {@code -(labels * log(output) + (1 - labels) * log(1 - output))}; this expression differs.
     * It is kept unchanged because callers may depend on the existing scores. Confirm intent.
     */
    private static INDArray crossEntropyTerms(INDArray labels, INDArray output) {
        return labels.mul(Transforms.log(output))
                .add(labels.rsub((Number) 1))
                .muli(Transforms.log(output).rsubi((Number) 1));
    }
}
