package org.nd4j.linalg.factory.ops;

import org.nd4j.common.base.Preconditions;
import org.nd4j.enums.PadMode;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.broadcast.BiasAdd;
import org.nd4j.linalg.api.ops.impl.layers.convolution.BatchNorm;
import org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU;
import org.nd4j.linalg.api.ops.impl.scalar.PRelu;
import org.nd4j.linalg.api.ops.impl.scalar.RectifiedLinear;
import org.nd4j.linalg.api.ops.impl.scalar.Relu6;
import org.nd4j.linalg.api.ops.impl.transforms.Pad;
import org.nd4j.linalg.api.ops.impl.transforms.ReluLayer;
import org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU;
import org.nd4j.linalg.api.ops.impl.transforms.custom.DotProductAttention;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LayerNorm;
import org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.custom.XwPlusB;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SigmoidDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftmaxBp;
import org.nd4j.linalg.api.ops.impl.transforms.strict.ELU;
import org.nd4j.linalg.api.ops.impl.transforms.strict.GELU;
import org.nd4j.linalg.api.ops.impl.transforms.strict.HardSigmoid;
import org.nd4j.linalg.api.ops.impl.transforms.strict.HardTanh;
import org.nd4j.linalg.api.ops.impl.transforms.strict.LogSigmoid;
import org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELU;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SELU;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Sigmoid;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftPlus;
import org.nd4j.linalg.api.ops.impl.transforms.strict.SoftSign;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Swish;
import org.nd4j.linalg.api.ops.impl.transforms.strict.Tanh;
import org.nd4j.linalg.api.ops.random.impl.DropOut;
import org.nd4j.linalg.api.ops.random.impl.DropOutInverted;
import org.nd4j.linalg.factory.NDValidation;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/nd4j/linalg/factory/ops/NDNN.class */
/**
 * Instance-method wrappers around neural-network ops operating on {@link INDArray}s.
 *
 * <p>Every method follows the same three steps: each array argument is handed to
 * {@code NDValidation.validateNumerical} together with the op name and an argument label, the
 * matching op object is constructed, and the op is run through {@code Nd4j.exec}. Where
 * {@code Nd4j.exec} hands back an array of results, only element 0 is returned to the caller.
 *
 * <p>Parameter names are decompiler-generated; the label strings passed to validation are the
 * best indication of each argument's role and are quoted in the per-method documentation.
 */
public class NDNN {

    /**
     * Runs the {@link CReLU} op.
     *
     * @param iNDArray op input, validated under the label "x"
     * @return first output of the op
     */
    public INDArray cReLU(INDArray iNDArray) {
        NDValidation.validateNumerical("CReLU", "x", iNDArray);
        final CReLU op = new CReLU(iNDArray);
        final INDArray[] outputs = Nd4j.exec(op);
        return outputs[0];
    }

    /**
     * Runs the {@link BatchNorm} op.
     *
     * @param iNDArray  validated under the label "input"
     * @param iNDArray2 validated under the label "mean"
     * @param iNDArray3 validated under the label "variance"
     * @param iNDArray4 validated under the label "gamma"
     * @param iNDArray5 validated under the label "beta"
     * @param d         scalar forwarded to the op (presumably epsilon - TODO confirm)
     * @param iArr      "axis" values; at least one is required by the precondition check
     * @return first output of the op
     */
    public INDArray batchNorm(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, INDArray iNDArray4, INDArray iNDArray5, double d, int... iArr) {
        NDValidation.validateNumerical("batchNorm", "input", iNDArray);
        NDValidation.validateNumerical("batchNorm", "mean", iNDArray2);
        NDValidation.validateNumerical("batchNorm", "variance", iNDArray3);
        NDValidation.validateNumerical("batchNorm", "gamma", iNDArray4);
        NDValidation.validateNumerical("batchNorm", "beta", iNDArray5);
        final int axisCount = iArr.length;
        Preconditions.checkArgument(axisCount >= 1, "axis has incorrect size/length. Expected: axis.length >= 1, got %s", axisCount);
        final BatchNorm op = new BatchNorm(iNDArray, iNDArray2, iNDArray3, iNDArray4, iNDArray5, d, iArr);
        final INDArray[] outputs = Nd4j.exec(op);
        return outputs[0];
    }

    /**
     * Runs the {@link BiasAdd} op.
     *
     * @param iNDArray  validated under the label "input"
     * @param iNDArray2 validated under the label "bias"
     * @param z         flag forwarded to the op (presumably the NCHW/NHWC format switch - TODO confirm)
     * @return first output of the op
     */
    public INDArray biasAdd(INDArray iNDArray, INDArray iNDArray2, boolean z) {
        NDValidation.validateNumerical("biasAdd", "input", iNDArray);
        NDValidation.validateNumerical("biasAdd", "bias", iNDArray2);
        final BiasAdd op = new BiasAdd(iNDArray, iNDArray2, z);
        final INDArray[] outputs = Nd4j.exec(op);
        return outputs[0];
    }

    /**
     * Runs the {@link DotProductAttention} op; the op's last constructor argument is fixed to
     * {@code false} (presumably "also return attention weights" - TODO confirm).
     *
     * @param iNDArray  validated under the label "queries"
     * @param iNDArray2 validated under the label "keys"
     * @param iNDArray3 validated under the label "values"
     * @param iNDArray4 validated under the label "mask"
     * @param z         flag forwarded to the op (presumably "scaled" - TODO confirm)
     * @return first output of the op
     */
    public INDArray dotProductAttention(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, INDArray iNDArray4, boolean z) {
        NDValidation.validateNumerical("dotProductAttention", "queries", iNDArray);
        NDValidation.validateNumerical("dotProductAttention", "keys", iNDArray2);
        NDValidation.validateNumerical("dotProductAttention", "values", iNDArray3);
        NDValidation.validateNumerical("dotProductAttention", "mask", iNDArray4);
        final DotProductAttention op = new DotProductAttention(iNDArray, iNDArray2, iNDArray3, iNDArray4, z, false);
        final INDArray[] outputs = Nd4j.exec(op);
        return outputs[0];
    }

    /**
     * Runs the {@link DropOut} op.
     *
     * @param iNDArray validated under the label "input"
     * @param d        scalar forwarded to the op (presumably a probability - TODO confirm which one)
     * @return result of executing the op
     */
    public INDArray dropout(INDArray iNDArray, double d) {
        NDValidation.validateNumerical("dropout", "input", iNDArray);
        final DropOut op = new DropOut(iNDArray, d);
        return Nd4j.exec(op);
    }

    /**
     * Runs the {@link DropOutInverted} op.
     *
     * @param iNDArray validated under the label "input"
     * @param d        scalar forwarded to the op (presumably a probability - TODO confirm which one)
     * @return result of executing the op
     */
    public INDArray dropoutInverted(INDArray iNDArray, double d) {
        NDValidation.validateNumerical("dropoutInverted", "input", iNDArray);
        final DropOutInverted op = new DropOutInverted(iNDArray, d);
        return Nd4j.exec(op);
    }

    /**
     * Runs the {@link ELU} op.
     *
     * @param iNDArray op input, validated under the label "x"
     * @return first output of the op
     */
    public INDArray elu(INDArray iNDArray) {
        NDValidation.validateNumerical("elu", "x", iNDArray);
        final ELU op = new ELU(iNDArray);
        final INDArray[] outputs = Nd4j.exec(op);
        return outputs[0];
    }

    /**
     * Runs the {@link GELU} op.
     *
     * @param iNDArray op input, validated under the label "x"
     * @return result of executing the op
     */
    public INDArray gelu(INDArray iNDArray) {
        NDValidation.validateNumerical("gelu", "x", iNDArray);
        final GELU op = new GELU(iNDArray);
        return Nd4j.exec(op);
    }

    /**
     * Runs the {@link HardSigmoid} op.
     *
     * @param iNDArray op input, validated under the label "x"
     * @return result of executing the op
     */
    public INDArray hardSigmoid(INDArray iNDArray) {
        NDValidation.validateNumerical("hardSigmoid", "x", iNDArray);
        final HardSigmoid op = new HardSigmoid(iNDArray);
        return Nd4j.exec(op);
    }

    /**
     * Runs the {@link HardTanh} op.
     *
     * @param iNDArray op input, validated under the label "x"
     * @return result of executing the op
     */
    public INDArray hardTanh(INDArray iNDArray) {
        NDValidation.validateNumerical("hardTanh", "x", iNDArray);
        final HardTanh op = new HardTanh(iNDArray);
        return Nd4j.exec(op);
    }

    /**
     * Runs the {@link HardTanhDerivative} op.
     *
     * @param iNDArray op input, validated under the label "x"
     * @return result of executing the op
     */
    public INDArray hardTanhDerivative(INDArray iNDArray) {
        NDValidation.validateNumerical("hardTanhDerivative", "x", iNDArray);
        final HardTanhDerivative op = new HardTanhDerivative(iNDArray);
        return Nd4j.exec(op);
    }

    /**
     * Runs the {@link LayerNorm} op with an explicit bias array.
     *
     * @param iNDArray  validated under the label "input"
     * @param iNDArray2 validated under the label "gain"
     * @param iNDArray3 validated under the label "bias"
     * @param z         flag forwarded to the op (presumably "channels first" - TODO confirm)
     * @param iArr      "dimensions" values; at least one is required by the precondition check
     * @return first output of the op
     */
    public INDArray layerNorm(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, boolean z, int... iArr) {
        NDValidation.validateNumerical("layerNorm", "input", iNDArray);
        NDValidation.validateNumerical("layerNorm", "gain", iNDArray2);
        NDValidation.validateNumerical("layerNorm", "bias", iNDArray3);
        final int dimensionCount = iArr.length;
        Preconditions.checkArgument(dimensionCount >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensionCount);
        final LayerNorm op = new LayerNorm(iNDArray, iNDArray2, iNDArray3, z, iArr);
        final INDArray[] outputs = Nd4j.exec(op);
        return outputs[0];
    }

    /**
     * Runs the {@link LayerNorm} op without a bias: {@code null} is passed in the bias position.
     *
     * @param iNDArray  validated under the label "input"
     * @param iNDArray2 validated under the label "gain"
     * @param z         flag forwarded to the op (presumably "channels first" - TODO confirm)
     * @param iArr      "dimensions" values; at least one is required by the precondition check
     * @return first output of the op
     */
    public INDArray layerNorm(INDArray iNDArray, INDArray iNDArray2, boolean z, int... iArr) {
        NDValidation.validateNumerical("layerNorm", "input", iNDArray);
        NDValidation.validateNumerical("layerNorm", "gain", iNDArray2);
        final int dimensionCount = iArr.length;
        Preconditions.checkArgument(dimensionCount >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensionCount);
        // The cast keeps the null typed as INDArray, exactly as in the original constructor call.
        final LayerNorm op = new LayerNorm(iNDArray, iNDArray2, (INDArray) null, z, iArr);
        final INDArray[] outputs = Nd4j.exec(op);
        return outputs[0];
    }

    /**
     * Runs the {@link LeakyReLU} op.
     *
     * @param iNDArray op input, validated under the label "x"
     * @param d        scalar forwarded to the op (presumably alpha, the negative-side slope - TODO confirm)
     * @return result of executing the op
     */
    public INDArray leakyRelu(INDArray iNDArray, double d) {
        NDValidation.validateNumerical("leakyRelu", "x", iNDArray);
        final LeakyReLU op = new LeakyReLU(iNDArray, d);
        return Nd4j.exec(op);
    }

    /**
     * Runs the {@link LeakyReLUDerivative} op.
     *
     * @param iNDArray op input, validated under the label "x"
     * @param d        scalar forwarded to the op (presumably alpha, the negative-side slope - TODO confirm)
     * @return result of executing the op
     */
    public INDArray leakyReluDerivative(INDArray iNDArray, double d) {
        NDValidation.validateNumerical("leakyReluDerivative", "x", iNDArray);
        final LeakyReLUDerivative op = new LeakyReLUDerivative(iNDArray, d);
        return Nd4j.exec(op);
    }

    /**
     * Runs the {@link XwPlusB} op.
     *
     * @param iNDArray  validated under the label "input"
     * @param iNDArray2 validated under the label "weights"
     * @param iNDArray3 validated under the label "bias"
     * @return first output of the op
     */
    public INDArray linear(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3) {
        NDValidation.validateNumerical("linear", "input", iNDArray);
        NDValidation.validateNumerical("linear", "weights", iNDArray2);
        NDValidation.validateNumerical("linear", "bias", iNDArray3);
        final XwPlusB op = new XwPlusB(iNDArray, iNDArray2, iNDArray3);
        final INDArray[] outputs = Nd4j.exec(op);
        return outputs[0];
    }

    /**
     * Runs the {@link LogSigmoid} op.
     *
     * @param iNDArray op input, validated under the label "x"
     * @return result of executing the op
     */
    public INDArray logSigmoid(INDArray iNDArray) {
        NDValidation.validateNumerical("logSigmoid", "x", iNDArray);
        final LogSigmoid op = new LogSigmoid(iNDArray);
        return Nd4j.exec(op);
    }

    /**
     * Runs the {@link LogSoftMax} op using its single-argument constructor.
     *
     * @param iNDArray op input, validated under the label "x"
     * @return first output of the op
     */
    public INDArray logSoftmax(INDArray iNDArray) {
        NDValidation.validateNumerical("logSoftmax", "x", iNDArray);
        final LogSoftMax op = new LogSoftMax(iNDArray);
        final INDArray[] outputs = Nd4j.exec(op);
        return outputs[0];
    }

    /**
     * Runs the {@link LogSoftMax} op with an extra integer argument.
     *
     * @param iNDArray op input, validated under the label "x"
     * @param i        integer forwarded to the op (presumably the dimension - TODO confirm)
     * @return first output of the op
     */
    public INDArray logSoftmax(INDArray iNDArray, int i) {
        NDValidation.validateNumerical("logSoftmax", "x", iNDArray);
        final LogSoftMax op = new LogSoftMax(iNDArray, i);
        final INDArray[] outputs = Nd4j.exec(op);
        return outputs[0];
    }

    /**
     * Runs the {@link MultiHeadDotProductAttention} op; the op's last constructor argument is
     * fixed to {@code false} (presumably "also return attention weights" - TODO confirm).
     *
     * @param iNDArray  validated under the label "queries"
     * @param iNDArray2 validated under the label "keys"
     * @param iNDArray3 validated under the label "values"
     * @param iNDArray4 validated under the label "Wq"
     * @param iNDArray5 validated under the label "Wk"
     * @param iNDArray6 validated under the label "Wv"
     * @param iNDArray7 validated under the label "Wo"
     * @param iNDArray8 validated under the label "mask"
     * @param z         flag forwarded to the op (presumably "scaled" - TODO confirm)
     * @return first output of the op
     */
    public INDArray multiHeadDotProductAttention(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3, INDArray iNDArray4, INDArray iNDArray5, INDArray iNDArray6, INDArray iNDArray7, INDArray iNDArray8, boolean z) {
        NDValidation.validateNumerical("multiHeadDotProductAttention", "queries", iNDArray);
        NDValidation.validateNumerical("multiHeadDotProductAttention", "keys", iNDArray2);
        NDValidation.validateNumerical("multiHeadDotProductAttention", "values", iNDArray3);
        NDValidation.validateNumerical("multiHeadDotProductAttention", "Wq", iNDArray4);
        NDValidation.validateNumerical("multiHeadDotProductAttention", "Wk", iNDArray5);
        NDValidation.validateNumerical("multiHeadDotProductAttention", "Wv", iNDArray6);
        NDValidation.validateNumerical("multiHeadDotProductAttention", "Wo", iNDArray7);
        NDValidation.validateNumerical("multiHeadDotProductAttention", "mask", iNDArray8);
        final MultiHeadDotProductAttention op = new MultiHeadDotProductAttention(
                iNDArray, iNDArray2, iNDArray3, iNDArray4, iNDArray5, iNDArray6, iNDArray7, iNDArray8, z, false);
        final INDArray[] outputs = Nd4j.exec(op);
        return outputs[0];
    }

    /**
     * Runs the {@link Pad} op with a caller-chosen {@link PadMode}.
     *
     * @param iNDArray  validated under the label "input"
     * @param iNDArray2 validated under the label "padding"
     * @param padMode   mode forwarded to the op
     * @param d         scalar forwarded to the op (presumably the constant fill value - TODO confirm)
     * @return first output of the op
     */
    public INDArray pad(INDArray iNDArray, INDArray iNDArray2, PadMode padMode, double d) {
        NDValidation.validateNumerical("pad", "input", iNDArray);
        NDValidation.validateNumerical("pad", "padding", iNDArray2);
        final Pad op = new Pad(iNDArray, iNDArray2, padMode, d);
        final INDArray[] outputs = Nd4j.exec(op);
        return outputs[0];
    }

    /**
     * Runs the {@link Pad} op with the mode fixed to {@link PadMode#CONSTANT}.
     *
     * @param iNDArray  validated under the label "input"
     * @param iNDArray2 validated under the label "padding"
     * @param d         scalar forwarded to the op (presumably the constant fill value - TODO confirm)
     * @return first output of the op
     */
    public INDArray pad(INDArray iNDArray, INDArray iNDArray2, double d) {
        NDValidation.validateNumerical("pad", "input", iNDArray);
        NDValidation.validateNumerical("pad", "padding", iNDArray2);
        final Pad op = new Pad(iNDArray, iNDArray2, PadMode.CONSTANT, d);
        final INDArray[] outputs = Nd4j.exec(op);
        return outputs[0];
    }

    /**
     * Runs the {@link PreciseGELU} op.
     *
     * @param iNDArray op input, validated under the label "x"
     * @return result of executing the op
     */
    public INDArray preciseGelu(INDArray iNDArray) {
        NDValidation.validateNumerical("preciseGelu", "x", iNDArray);
        final PreciseGELU op = new PreciseGELU(iNDArray);
        return Nd4j.exec(op);
    }

    /**
     * Runs the {@link PRelu} op.
     *
     * @param iNDArray  validated under the label "input"
     * @param iNDArray2 validated under the label "alpha"
     * @param iArr      "sharedAxes" values; at least one is required by the precondition check
     * @return first output of the op
     */
    public INDArray prelu(INDArray iNDArray, INDArray iNDArray2, int... iArr) {
        NDValidation.validateNumerical("prelu", "input", iNDArray);
        NDValidation.validateNumerical("prelu", "alpha", iNDArray2);
        final int sharedAxisCount = iArr.length;
        Preconditions.checkArgument(sharedAxisCount >= 1, "sharedAxes has incorrect size/length. Expected: sharedAxes.length >= 1, got %s", sharedAxisCount);
        final PRelu op = new PRelu(iNDArray, iNDArray2, iArr);
        final INDArray[] outputs = Nd4j.exec(op);
        return outputs[0];
    }

    /**
     * Runs the {@link RectifiedLinear} op.
     *
     * @param iNDArray op input, validated under the label "x"
     * @param d        scalar forwarded to the op (presumably the cutoff - TODO confirm)
     * @return result of executing the op
     */
    public INDArray relu(INDArray iNDArray, double d) {
        NDValidation.validateNumerical("relu", "x", iNDArray);
        final RectifiedLinear op = new RectifiedLinear(iNDArray, d);
        return Nd4j.exec(op);
    }

    /**
     * Runs the {@link Relu6} op.
     *
     * @param iNDArray op input, validated under the label "x"
     * @param d        scalar forwarded to the op (presumably the cutoff - TODO confirm)
     * @return result of executing the op
     */
    public INDArray relu6(INDArray iNDArray, double d) {
        NDValidation.validateNumerical("relu6", "x", iNDArray);
        final Relu6 op = new Relu6(iNDArray, d);
        return Nd4j.exec(op);
    }

    /**
     * Runs the {@link ReluLayer} op.
     *
     * @param iNDArray  validated under the label "input"
     * @param iNDArray2 validated under the label "weights"
     * @param iNDArray3 validated under the label "bias"
     * @return first output of the op
     */
    public INDArray reluLayer(INDArray iNDArray, INDArray iNDArray2, INDArray iNDArray3) {
        NDValidation.validateNumerical("reluLayer", "input", iNDArray);
        NDValidation.validateNumerical("reluLayer", "weights", iNDArray2);
        NDValidation.validateNumerical("reluLayer", "bias", iNDArray3);
        final ReluLayer op = new ReluLayer(iNDArray, iNDArray2, iNDArray3);
        final INDArray[] outputs = Nd4j.exec(op);
        return outputs[0];
    }

    /**
     * Runs the {@link SELU} op.
     *
     * @param iNDArray op input, validated under the label "x"
     * @return result of executing the op
     */
    public INDArray selu(INDArray iNDArray) {
        NDValidation.validateNumerical("selu", "x", iNDArray);
        final SELU op = new SELU(iNDArray);
        return Nd4j.exec(op);
    }

    /**
     * Runs the {@link Sigmoid} op.
     *
     * @param iNDArray op input, validated under the label "x"
     * @return result of executing the op
     */
    public INDArray sigmoid(INDArray iNDArray) {
        NDValidation.validateNumerical("sigmoid", "x", iNDArray);
        final Sigmoid op = new Sigmoid(iNDArray);
        return Nd4j.exec(op);
    }

    /**
     * Runs the {@link SigmoidDerivative} op.
     *
     * @param iNDArray  validated under the label "x"
     * @param iNDArray2 validated under the label "wrt"
     * @return first output of the op
     */
    public INDArray sigmoidDerivative(INDArray iNDArray, INDArray iNDArray2) {
        NDValidation.validateNumerical("sigmoidDerivative", "x", iNDArray);
        NDValidation.validateNumerical("sigmoidDerivative", "wrt", iNDArray2);
        final SigmoidDerivative op = new SigmoidDerivative(iNDArray, iNDArray2);
        final INDArray[] outputs = Nd4j.exec(op);
        return outputs[0];
    }

    /**
     * Runs the {@link SoftMax} op with a caller-supplied integer argument.
     *
     * @param iNDArray op input, validated under the label "x"
     * @param i        integer forwarded to the op (presumably the dimension - TODO confirm)
     * @return first output of the op
     */
    public INDArray softmax(INDArray iNDArray, int i) {
        NDValidation.validateNumerical("softmax", "x", iNDArray);
        final SoftMax op = new SoftMax(iNDArray, i);
        final INDArray[] outputs = Nd4j.exec(op);
        return outputs[0];
    }

    /**
     * Runs the {@link SoftMax} op with {@code -1} as its integer argument (presumably meaning
     * "last dimension" - TODO confirm).
     *
     * @param iNDArray op input, validated under the label "x"
     * @return first output of the op
     */
    public INDArray softmax(INDArray iNDArray) {
        NDValidation.validateNumerical("softmax", "x", iNDArray);
        final SoftMax op = new SoftMax(iNDArray, -1);
        final INDArray[] outputs = Nd4j.exec(op);
        return outputs[0];
    }

    /**
     * Runs the {@link SoftmaxBp} op.
     *
     * @param iNDArray  validated under the label "x"
     * @param iNDArray2 validated under the label "wrt"
     * @param i         integer forwarded to the op in boxed form (presumably the dimension - TODO confirm)
     * @return first output of the op
     */
    public INDArray softmaxDerivative(INDArray iNDArray, INDArray iNDArray2, int i) {
        NDValidation.validateNumerical("softmaxDerivative", "x", iNDArray);
        NDValidation.validateNumerical("softmaxDerivative", "wrt", iNDArray2);
        final SoftmaxBp op = new SoftmaxBp(iNDArray, iNDArray2, Integer.valueOf(i));
        final INDArray[] outputs = Nd4j.exec(op);
        return outputs[0];
    }

    /**
     * Runs the {@link SoftPlus} op.
     *
     * @param iNDArray op input, validated under the label "x"
     * @return result of executing the op
     */
    public INDArray softplus(INDArray iNDArray) {
        NDValidation.validateNumerical("softplus", "x", iNDArray);
        final SoftPlus op = new SoftPlus(iNDArray);
        return Nd4j.exec(op);
    }

    /**
     * Runs the {@link SoftSign} op.
     *
     * @param iNDArray op input, validated under the label "x"
     * @return result of executing the op
     */
    public INDArray softsign(INDArray iNDArray) {
        NDValidation.validateNumerical("softsign", "x", iNDArray);
        final SoftSign op = new SoftSign(iNDArray);
        return Nd4j.exec(op);
    }

    /**
     * Runs the {@link SoftSignDerivative} op.
     *
     * @param iNDArray op input, validated under the label "x"
     * @return result of executing the op
     */
    public INDArray softsignDerivative(INDArray iNDArray) {
        NDValidation.validateNumerical("softsignDerivative", "x", iNDArray);
        final SoftSignDerivative op = new SoftSignDerivative(iNDArray);
        return Nd4j.exec(op);
    }

    /**
     * Runs the {@link Swish} op.
     *
     * @param iNDArray op input, validated under the label "x"
     * @return result of executing the op
     */
    public INDArray swish(INDArray iNDArray) {
        NDValidation.validateNumerical("swish", "x", iNDArray);
        final Swish op = new Swish(iNDArray);
        return Nd4j.exec(op);
    }

    /**
     * Runs the {@link Tanh} op.
     *
     * @param iNDArray op input, validated under the label "x"
     * @return result of executing the op
     */
    public INDArray tanh(INDArray iNDArray) {
        NDValidation.validateNumerical("tanh", "x", iNDArray);
        final Tanh op = new Tanh(iNDArray);
        return Nd4j.exec(op);
    }
}
