package org.nd4j.linalg.factory.ops;

import org.nd4j.autodiff.loss.LossReduce;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.loss.AbsoluteDifferenceLoss;
import org.nd4j.linalg.api.ops.impl.loss.CosineDistanceLoss;
import org.nd4j.linalg.api.ops.impl.loss.CtcLoss;
import org.nd4j.linalg.api.ops.impl.loss.HingeLoss;
import org.nd4j.linalg.api.ops.impl.loss.HuberLoss;
import org.nd4j.linalg.api.ops.impl.loss.L2Loss;
import org.nd4j.linalg.api.ops.impl.loss.LogLoss;
import org.nd4j.linalg.api.ops.impl.loss.LogPoissonLoss;
import org.nd4j.linalg.api.ops.impl.loss.MeanPairwiseSquaredErrorLoss;
import org.nd4j.linalg.api.ops.impl.loss.MeanSquaredErrorLoss;
import org.nd4j.linalg.api.ops.impl.loss.SigmoidCrossEntropyLoss;
import org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyLoss;
import org.nd4j.linalg.api.ops.impl.loss.SparseSoftmaxCrossEntropyLossWithLogits;
import org.nd4j.linalg.factory.NDValidation;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Eager, {@link INDArray}-based entry points for the loss ops in
 * {@code org.nd4j.linalg.api.ops.impl.loss}.
 *
 * <p>Every method validates its array inputs with {@link NDValidation} (validation failures are
 * reported by that class), executes the corresponding op via {@link Nd4j#exec}, and returns the
 * op's first output array. Overloads that omit the {@link LossReduce} argument use
 * {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}.
 */
public class NDLoss {

    /** Reduction used by every overload that does not take an explicit {@link LossReduce}. */
    private static final LossReduce DEFAULT_REDUCTION = LossReduce.MEAN_BY_NONZERO_WEIGHT_COUNT;

    /**
     * Computes the absolute-difference loss with an explicit reduction.
     *
     * @param label       label array (validated as numerical)
     * @param predictions predictions array (validated as numerical)
     * @param weights     weights array (validated as numerical)
     * @param lossReduce  reduction forwarded to the op
     * @return the first output of {@link AbsoluteDifferenceLoss}
     */
    public INDArray absoluteDifference(INDArray label, INDArray predictions, INDArray weights,
            LossReduce lossReduce) {
        NDValidation.validateNumerical("absoluteDifference", "label", label);
        NDValidation.validateNumerical("absoluteDifference", "predictions", predictions);
        NDValidation.validateNumerical("absoluteDifference", "weights", weights);
        return Nd4j.exec(new AbsoluteDifferenceLoss(label, predictions, weights, lossReduce))[0];
    }

    /**
     * Computes the absolute-difference loss using
     * {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}.
     *
     * @see #absoluteDifference(INDArray, INDArray, INDArray, LossReduce)
     */
    public INDArray absoluteDifference(INDArray label, INDArray predictions, INDArray weights) {
        return absoluteDifference(label, predictions, weights, DEFAULT_REDUCTION);
    }

    /**
     * Computes the cosine-distance loss with an explicit reduction.
     *
     * @param label       label array (validated as numerical)
     * @param predictions predictions array (validated as numerical)
     * @param weights     weights array (validated as numerical)
     * @param lossReduce  reduction forwarded to the op
     * @param dimension   dimension index forwarded to {@link CosineDistanceLoss}
     * @return the first output of {@link CosineDistanceLoss}
     */
    public INDArray cosineDistance(INDArray label, INDArray predictions, INDArray weights,
            LossReduce lossReduce, int dimension) {
        NDValidation.validateNumerical("cosineDistance", "label", label);
        NDValidation.validateNumerical("cosineDistance", "predictions", predictions);
        NDValidation.validateNumerical("cosineDistance", "weights", weights);
        return Nd4j.exec(new CosineDistanceLoss(label, predictions, weights, lossReduce, dimension))[0];
    }

    /**
     * Computes the cosine-distance loss using {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}.
     *
     * @see #cosineDistance(INDArray, INDArray, INDArray, LossReduce, int)
     */
    public INDArray cosineDistance(INDArray label, INDArray predictions, INDArray weights, int dimension) {
        return cosineDistance(label, predictions, weights, DEFAULT_REDUCTION, dimension);
    }

    /**
     * Computes the CTC loss.
     *
     * @param targetLabels       target label array (validated as numerical)
     * @param logitInput         logit input array (validated as numerical)
     * @param targetLabelLengths target label lengths (validated as numerical)
     * @param logitInputLengths  logit input lengths (validated as numerical)
     * @return the first output of {@link CtcLoss}
     */
    public INDArray ctcLoss(INDArray targetLabels, INDArray logitInput, INDArray targetLabelLengths,
            INDArray logitInputLengths) {
        NDValidation.validateNumerical("ctcLoss", "targetLabels", targetLabels);
        NDValidation.validateNumerical("ctcLoss", "logitInput", logitInput);
        NDValidation.validateNumerical("ctcLoss", "targetLabelLengths", targetLabelLengths);
        NDValidation.validateNumerical("ctcLoss", "logitInputLengths", logitInputLengths);
        return Nd4j.exec(new CtcLoss(targetLabels, logitInput, targetLabelLengths, logitInputLengths))[0];
    }

    /**
     * Computes the hinge loss with an explicit reduction.
     *
     * @param label       label array (validated as numerical)
     * @param predictions predictions array (validated as numerical)
     * @param weights     weights array (validated as numerical)
     * @param lossReduce  reduction forwarded to the op
     * @return the first output of {@link HingeLoss}
     */
    public INDArray hingeLoss(INDArray label, INDArray predictions, INDArray weights, LossReduce lossReduce) {
        NDValidation.validateNumerical("hingeLoss", "label", label);
        NDValidation.validateNumerical("hingeLoss", "predictions", predictions);
        NDValidation.validateNumerical("hingeLoss", "weights", weights);
        return Nd4j.exec(new HingeLoss(label, predictions, weights, lossReduce))[0];
    }

    /**
     * Computes the hinge loss using {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}.
     *
     * @see #hingeLoss(INDArray, INDArray, INDArray, LossReduce)
     */
    public INDArray hingeLoss(INDArray label, INDArray predictions, INDArray weights) {
        return hingeLoss(label, predictions, weights, DEFAULT_REDUCTION);
    }

    /**
     * Computes the Huber loss with an explicit reduction.
     *
     * @param label       label array (validated as numerical)
     * @param predictions predictions array (validated as numerical)
     * @param weights     weights array (validated as numerical)
     * @param lossReduce  reduction forwarded to the op
     * @param delta       Huber threshold forwarded to {@link HuberLoss}
     * @return the first output of {@link HuberLoss}
     */
    public INDArray huberLoss(INDArray label, INDArray predictions, INDArray weights, LossReduce lossReduce,
            double delta) {
        NDValidation.validateNumerical("huberLoss", "label", label);
        NDValidation.validateNumerical("huberLoss", "predictions", predictions);
        NDValidation.validateNumerical("huberLoss", "weights", weights);
        return Nd4j.exec(new HuberLoss(label, predictions, weights, lossReduce, delta))[0];
    }

    /**
     * Computes the Huber loss using {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}.
     *
     * @see #huberLoss(INDArray, INDArray, INDArray, LossReduce, double)
     */
    public INDArray huberLoss(INDArray label, INDArray predictions, INDArray weights, double delta) {
        return huberLoss(label, predictions, weights, DEFAULT_REDUCTION, delta);
    }

    /**
     * Computes the L2 loss of a single array.
     *
     * @param var input array (validated as numerical)
     * @return the first output of {@link L2Loss}
     */
    public INDArray l2Loss(INDArray var) {
        NDValidation.validateNumerical("l2Loss", "var", var);
        return Nd4j.exec(new L2Loss(var))[0];
    }

    /**
     * Computes the log loss with an explicit reduction.
     *
     * @param label       label array (validated as numerical)
     * @param predictions predictions array (validated as numerical)
     * @param weights     weights array (validated as numerical)
     * @param lossReduce  reduction forwarded to the op
     * @param epsilon     epsilon value forwarded to {@link LogLoss}
     * @return the first output of {@link LogLoss}
     */
    public INDArray logLoss(INDArray label, INDArray predictions, INDArray weights, LossReduce lossReduce,
            double epsilon) {
        NDValidation.validateNumerical("logLoss", "label", label);
        NDValidation.validateNumerical("logLoss", "predictions", predictions);
        NDValidation.validateNumerical("logLoss", "weights", weights);
        return Nd4j.exec(new LogLoss(label, predictions, weights, lossReduce, epsilon))[0];
    }

    /**
     * Computes the log loss without a weights array, using
     * {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} and an epsilon of {@code 0.0}.
     *
     * @param label       label array (validated as numerical)
     * @param predictions predictions array (validated as numerical)
     * @return the first output of {@link LogLoss}
     */
    public INDArray logLoss(INDArray label, INDArray predictions) {
        NDValidation.validateNumerical("logLoss", "label", label);
        NDValidation.validateNumerical("logLoss", "predictions", predictions);
        // Weights are passed as null, so this overload deliberately does not validate them and does not
        // delegate to the five-argument overload (which would run validateNumerical on a null weights array).
        return Nd4j.exec(new LogLoss(label, predictions, null, DEFAULT_REDUCTION, 0.0d))[0];
    }

    /**
     * Computes the log-Poisson loss with an explicit reduction.
     *
     * @param label       label array (validated as numerical)
     * @param predictions predictions array (validated as numerical)
     * @param weights     weights array (validated as numerical)
     * @param lossReduce  reduction forwarded to the op
     * @param full        flag forwarded to {@link LogPoissonLoss} (presumably selects the "full" loss
     *                    variant — confirm against the op)
     * @return the first output of {@link LogPoissonLoss}
     */
    public INDArray logPoisson(INDArray label, INDArray predictions, INDArray weights, LossReduce lossReduce,
            boolean full) {
        NDValidation.validateNumerical("logPoisson", "label", label);
        NDValidation.validateNumerical("logPoisson", "predictions", predictions);
        NDValidation.validateNumerical("logPoisson", "weights", weights);
        return Nd4j.exec(new LogPoissonLoss(label, predictions, weights, lossReduce, full))[0];
    }

    /**
     * Computes the log-Poisson loss using {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}.
     *
     * @see #logPoisson(INDArray, INDArray, INDArray, LossReduce, boolean)
     */
    public INDArray logPoisson(INDArray label, INDArray predictions, INDArray weights, boolean full) {
        return logPoisson(label, predictions, weights, DEFAULT_REDUCTION, full);
    }

    /**
     * Computes the mean pairwise squared error loss with an explicit reduction.
     *
     * @param label       label array (validated as numerical)
     * @param predictions predictions array (validated as numerical)
     * @param weights     weights array (validated as numerical)
     * @param lossReduce  reduction forwarded to the op
     * @return the first output of {@link MeanPairwiseSquaredErrorLoss}
     */
    public INDArray meanPairwiseSquaredError(INDArray label, INDArray predictions, INDArray weights,
            LossReduce lossReduce) {
        NDValidation.validateNumerical("meanPairwiseSquaredError", "label", label);
        NDValidation.validateNumerical("meanPairwiseSquaredError", "predictions", predictions);
        NDValidation.validateNumerical("meanPairwiseSquaredError", "weights", weights);
        return Nd4j.exec(new MeanPairwiseSquaredErrorLoss(label, predictions, weights, lossReduce))[0];
    }

    /**
     * Computes the mean pairwise squared error loss using
     * {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}.
     *
     * @see #meanPairwiseSquaredError(INDArray, INDArray, INDArray, LossReduce)
     */
    public INDArray meanPairwiseSquaredError(INDArray label, INDArray predictions, INDArray weights) {
        return meanPairwiseSquaredError(label, predictions, weights, DEFAULT_REDUCTION);
    }

    /**
     * Computes the mean squared error loss with an explicit reduction.
     *
     * @param label       label array (validated as numerical)
     * @param predictions predictions array (validated as numerical)
     * @param weights     weights array (validated as numerical)
     * @param lossReduce  reduction forwarded to the op
     * @return the first output of {@link MeanSquaredErrorLoss}
     */
    public INDArray meanSquaredError(INDArray label, INDArray predictions, INDArray weights,
            LossReduce lossReduce) {
        NDValidation.validateNumerical("meanSquaredError", "label", label);
        NDValidation.validateNumerical("meanSquaredError", "predictions", predictions);
        NDValidation.validateNumerical("meanSquaredError", "weights", weights);
        return Nd4j.exec(new MeanSquaredErrorLoss(label, predictions, weights, lossReduce))[0];
    }

    /**
     * Computes the mean squared error loss using {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT}.
     *
     * @see #meanSquaredError(INDArray, INDArray, INDArray, LossReduce)
     */
    public INDArray meanSquaredError(INDArray label, INDArray predictions, INDArray weights) {
        return meanSquaredError(label, predictions, weights, DEFAULT_REDUCTION);
    }

    /**
     * Computes the sigmoid cross-entropy loss with an explicit reduction.
     *
     * @param label            label array (validated as numerical)
     * @param predictionLogits prediction logits (validated as numerical)
     * @param weights          weights array (validated as numerical)
     * @param lossReduce       reduction forwarded to the op
     * @param labelSmoothing   smoothing value forwarded to {@link SigmoidCrossEntropyLoss}
     * @return the first output of {@link SigmoidCrossEntropyLoss}
     */
    public INDArray sigmoidCrossEntropy(INDArray label, INDArray predictionLogits, INDArray weights,
            LossReduce lossReduce, double labelSmoothing) {
        NDValidation.validateNumerical("sigmoidCrossEntropy", "label", label);
        NDValidation.validateNumerical("sigmoidCrossEntropy", "predictionLogits", predictionLogits);
        NDValidation.validateNumerical("sigmoidCrossEntropy", "weights", weights);
        return Nd4j.exec(
                new SigmoidCrossEntropyLoss(label, predictionLogits, weights, lossReduce, labelSmoothing))[0];
    }

    /**
     * Computes the sigmoid cross-entropy loss using
     * {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} and a smoothing value of {@code 0.0}.
     *
     * @see #sigmoidCrossEntropy(INDArray, INDArray, INDArray, LossReduce, double)
     */
    public INDArray sigmoidCrossEntropy(INDArray label, INDArray predictionLogits, INDArray weights) {
        return sigmoidCrossEntropy(label, predictionLogits, weights, DEFAULT_REDUCTION, 0.0d);
    }

    /**
     * Computes the softmax cross-entropy loss with an explicit reduction.
     *
     * @param oneHotLabels     label array (validated as numerical)
     * @param logitPredictions prediction logits (validated as numerical)
     * @param weights          weights array (validated as numerical)
     * @param lossReduce       reduction forwarded to the op
     * @param labelSmoothing   smoothing value forwarded to {@link SoftmaxCrossEntropyLoss}
     * @return the first output of {@link SoftmaxCrossEntropyLoss}
     */
    public INDArray softmaxCrossEntropy(INDArray oneHotLabels, INDArray logitPredictions, INDArray weights,
            LossReduce lossReduce, double labelSmoothing) {
        NDValidation.validateNumerical("softmaxCrossEntropy", "oneHotLabels", oneHotLabels);
        NDValidation.validateNumerical("softmaxCrossEntropy", "logitPredictions", logitPredictions);
        NDValidation.validateNumerical("softmaxCrossEntropy", "weights", weights);
        return Nd4j.exec(
                new SoftmaxCrossEntropyLoss(oneHotLabels, logitPredictions, weights, lossReduce, labelSmoothing))[0];
    }

    /**
     * Computes the softmax cross-entropy loss using
     * {@link LossReduce#MEAN_BY_NONZERO_WEIGHT_COUNT} and a smoothing value of {@code 0.0}.
     *
     * @see #softmaxCrossEntropy(INDArray, INDArray, INDArray, LossReduce, double)
     */
    public INDArray softmaxCrossEntropy(INDArray oneHotLabels, INDArray logitPredictions, INDArray weights) {
        return softmaxCrossEntropy(oneHotLabels, logitPredictions, weights, DEFAULT_REDUCTION, 0.0d);
    }

    /**
     * Computes the sparse softmax cross-entropy loss from logits.
     *
     * @param logits logits array (validated as numerical)
     * @param labels labels array (validated as integer)
     * @return the first output of {@link SparseSoftmaxCrossEntropyLossWithLogits}
     */
    public INDArray sparseSoftmaxCrossEntropy(INDArray logits, INDArray labels) {
        NDValidation.validateNumerical("sparseSoftmaxCrossEntropy", "logits", logits);
        NDValidation.validateInteger("sparseSoftmaxCrossEntropy", "labels", labels);
        return Nd4j.exec(new SparseSoftmaxCrossEntropyLossWithLogits(logits, labels))[0];
    }
}
