package ai.djl.ndarray;

import ai.djl.Device;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.internal.NDArrayEx;
import ai.djl.ndarray.internal.NDFormat;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.ndarray.types.SparseFormat;
import ai.djl.util.Float16Utils;
import ai.djl.util.cuda.CudaLibrary;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.LongBuffer;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.function.Function;
import java.util.stream.IntStream;
import java.util.stream.LongStream;

/**
 * An n-dimensional array abstraction.
 *
 * <p>Most operations are abstract and supplied by an engine-specific implementation. The default
 * methods in this interface are written purely in terms of those abstract operations plus {@link
 * #toByteBuffer()}.
 *
 * <p>NOTE(review): this source was recovered from bytecode (JADX). Parameter names were restored to
 * descriptive names; the precise semantics of abstract methods are defined by the implementations.
 */
public interface NDArray extends NDResource, BytesSupplier {

    /**
     * Decodes an {@code NDArray} from bytes by delegating to {@code manager.decode(byteArray)}.
     *
     * @param manager the manager used to decode the array
     * @param byteArray the encoded bytes (for example, produced by {@link #encode()})
     * @return the decoded array
     */
    static NDArray decode(NDManager manager, byte[] byteArray) {
        return manager.decode(byteArray);
    }

    // ---- Identity and metadata (implemented by the engine) ----

    String getName();

    void setName(String name);

    String getUid();

    DataType getDataType();

    Device getDevice();

    Shape getShape();

    SparseFormat getSparseFormat();

    /**
     * Returns whether this array uses a sparse storage format.
     *
     * @return {@code true} unless {@link #getSparseFormat()} is {@link SparseFormat#DENSE}
     */
    default boolean isSparse() {
        return getSparseFormat() != SparseFormat.DENSE;
    }

    /**
     * Returns whether this array is a scalar, as reported by {@link Shape#isScalar()}.
     *
     * @return {@code true} if the shape is scalar
     */
    default boolean isScalar() {
        return getShape().isScalar();
    }

    /**
     * Encodes this array to bytes using {@link NDSerializer#encode(NDArray)}.
     *
     * @return the encoded bytes
     */
    default byte[] encode() {
        return NDSerializer.encode(this);
    }

    // ---- Conversion (implemented by the engine) ----

    NDArray toDevice(Device device, boolean copy);

    NDArray toType(DataType dataType, boolean copy);

    // ---- Gradient support (implemented by the engine) ----

    void setRequiresGradient(boolean requiresGrad);

    NDArray getGradient();

    boolean hasGradient();

    NDArray stopGradient();

    /**
     * Returns {@code this * scale + stopGradient() * (1 - scale)}.
     *
     * <p>Numerically the two terms sum back to this array's value. Assuming {@link #stopGradient()}
     * blocks gradient flow (confirm in the implementation), only the first term contributes a
     * gradient, so the gradient is scaled by {@code scale}.
     *
     * @param scale the gradient scaling factor
     * @return the gradient-scaled array
     */
    default NDArray scaleGradient(double scale) {
        NDArray scaledPath = mul(scale);
        NDArray frozenPath = stopGradient().mul(1.0d - scale);
        return scaledPath.add(frozenPath);
    }

    /**
     * Returns the size of the given axis, as reported by {@link Shape#size(int...)}.
     *
     * @param axis the axis to query
     * @return the number of elements along {@code axis}
     */
    default long size(int axis) {
        return getShape().size(axis);
    }

    /**
     * Returns the total number of elements, as reported by {@link Shape#size()}.
     *
     * @return the element count
     */
    default long size() {
        return getShape().size();
    }

    // ---- Reading contents ----

    /**
     * Copies the contents into a {@code double[]}.
     *
     * @return the elements as doubles
     * @throws IllegalStateException if the data type is not {@link DataType#FLOAT64}
     */
    default double[] toDoubleArray() {
        if (getDataType() != DataType.FLOAT64) {
            throw new IllegalStateException("DataType mismatch, Required double Actual " + getDataType());
        }
        DoubleBuffer doubles = toByteBuffer().asDoubleBuffer();
        double[] result = new double[doubles.remaining()];
        doubles.get(result);
        return result;
    }

    /**
     * Copies the contents into a {@code float[]}. {@link DataType#FLOAT16} data is widened via
     * {@link Float16Utils#fromByteBuffer}.
     *
     * @return the elements as floats
     * @throws IllegalStateException if the data type is neither FLOAT16 nor FLOAT32
     */
    default float[] toFloatArray() {
        if (getDataType() == DataType.FLOAT16) {
            return Float16Utils.fromByteBuffer(toByteBuffer());
        }
        if (getDataType() != DataType.FLOAT32) {
            throw new IllegalStateException("DataType mismatch, Required float, Actual " + getDataType());
        }
        FloatBuffer floats = toByteBuffer().asFloatBuffer();
        float[] result = new float[floats.remaining()];
        floats.get(result);
        return result;
    }

    /**
     * Copies the contents into an {@code int[]}.
     *
     * @return the elements as ints
     * @throws IllegalStateException if the data type is not {@link DataType#INT32}
     */
    default int[] toIntArray() {
        if (getDataType() != DataType.INT32) {
            throw new IllegalStateException("DataType mismatch, Required int Actual " + getDataType());
        }
        IntBuffer ints = toByteBuffer().asIntBuffer();
        int[] result = new int[ints.remaining()];
        ints.get(result);
        return result;
    }

    /**
     * Copies the contents into a {@code long[]}.
     *
     * @return the elements as longs
     * @throws IllegalStateException if the data type is not {@link DataType#INT64}
     */
    default long[] toLongArray() {
        if (getDataType() != DataType.INT64) {
            throw new IllegalStateException("DataType mismatch, Required long Actual " + getDataType());
        }
        LongBuffer longs = toByteBuffer().asLongBuffer();
        long[] result = new long[longs.remaining()];
        longs.get(result);
        return result;
    }

    /**
     * Returns the remaining bytes of {@link #toByteBuffer()} as a byte array. No data-type check
     * is performed.
     *
     * <p>The buffer's backing array is returned directly only when it is exactly the readable
     * region (zero offset, zero position, remaining == array length). Otherwise the readable bytes
     * are copied. {@link ByteBuffer#array()} always exposes the whole backing array, including any
     * bytes outside the buffer's view.
     *
     * @return the raw bytes
     */
    default byte[] toByteArray() {
        ByteBuffer bytes = toByteBuffer();
        if (bytes.hasArray()
                && bytes.arrayOffset() == 0
                && bytes.position() == 0
                && bytes.remaining() == bytes.array().length) {
            return bytes.array();
        }
        byte[] result = new byte[bytes.remaining()];
        bytes.get(result);
        return result;
    }

    /**
     * Reads every remaining byte of {@link #toByteBuffer()} as an unsigned value in [0, 255]. No
     * data-type check is performed.
     *
     * @return the bytes widened to unsigned ints
     */
    default int[] toUint8Array() {
        ByteBuffer bytes = toByteBuffer();
        int[] result = new int[bytes.remaining()];
        for (int i = 0; i < result.length; i++) {
            result[i] = bytes.get() & 0xFF;
        }
        return result;
    }

    /**
     * Copies the contents into a {@code boolean[]}, mapping each non-zero byte to {@code true}.
     *
     * @return the elements as booleans
     * @throws IllegalStateException if the data type is not {@link DataType#BOOLEAN}
     */
    default boolean[] toBooleanArray() {
        if (getDataType() != DataType.BOOLEAN) {
            throw new IllegalStateException("DataType mismatch, Required boolean Actual " + getDataType());
        }
        ByteBuffer bytes = toByteBuffer();
        boolean[] result = new boolean[bytes.remaining()];
        for (int i = 0; i < result.length; i++) {
            result[i] = bytes.get() != 0;
        }
        return result;
    }

    /**
     * Decodes string contents using UTF-8; equivalent to {@code toStringArray(UTF_8)}.
     *
     * @return the string elements
     */
    default String[] toStringArray() {
        return toStringArray(StandardCharsets.UTF_8);
    }

    String[] toStringArray(Charset charset);

    /**
     * Returns the contents boxed as {@link Number}s.
     *
     * <p>The runtime array type depends on the data type: FLOAT16/FLOAT32 produce a {@code
     * Number[]} of {@link Float}; FLOAT64 a {@code Double[]}; INT32 and UINT8 an {@code
     * Integer[]} (UINT8 values are unsigned); INT64 a {@code Long[]}; BOOLEAN and INT8 a {@code
     * Byte[]} of raw bytes.
     *
     * @return the boxed elements
     * @throws IllegalStateException for any other data type
     */
    default Number[] toArray() {
        DataType dataType = getDataType();
        switch (dataType) {
            case FLOAT16:
            case FLOAT32:
                float[] floatArray = toFloatArray();
                return IntStream.range(0, floatArray.length)
                        .mapToObj(i -> Float.valueOf(floatArray[i]))
                        .toArray(Number[]::new);
            case FLOAT64:
                return Arrays.stream(toDoubleArray()).boxed().toArray(Double[]::new);
            case INT32:
                return Arrays.stream(toIntArray()).boxed().toArray(Integer[]::new);
            case INT64:
                return Arrays.stream(toLongArray()).boxed().toArray(Long[]::new);
            case BOOLEAN:
            case INT8:
                ByteBuffer byteBuffer = toByteBuffer();
                Byte[] boxedBytes = new Byte[byteBuffer.remaining()];
                for (int i = 0; i < boxedBytes.length; i++) {
                    boxedBytes[i] = byteBuffer.get();
                }
                return boxedBytes;
            case UINT8:
                return Arrays.stream(toUint8Array()).boxed().toArray(Integer[]::new);
            default:
                throw new IllegalStateException("Unsupported DataType: " + dataType);
        }
    }

    // ---- Writing contents ----

    void set(Buffer buffer);

    /** Writes {@code data} into this array via {@link #set(Buffer)}. */
    default void set(float[] data) {
        set(FloatBuffer.wrap(data));
    }

    /** Writes {@code data} into this array via {@link #set(Buffer)}. */
    default void set(int[] data) {
        set(IntBuffer.wrap(data));
    }

    /** Writes {@code data} into this array via {@link #set(Buffer)}. */
    default void set(double[] data) {
        set(DoubleBuffer.wrap(data));
    }

    /** Writes {@code data} into this array via {@link #set(Buffer)}. */
    default void set(long[] data) {
        set(LongBuffer.wrap(data));
    }

    /** Writes {@code data} into this array via {@link #set(Buffer)}. */
    default void set(byte[] data) {
        set(ByteBuffer.wrap(data));
    }

    /**
     * Assigns {@code value} to the region selected by {@code index}, using the engine's indexer
     * for this array's manager.
     */
    default void set(NDIndex index, NDArray value) {
        getNDArrayInternal().getIndexer(getManager()).set(this, index, value);
    }

    /**
     * Assigns the scalar {@code value} to the region selected by {@code index}, using the engine's
     * indexer for this array's manager.
     */
    default void set(NDIndex index, Number value) {
        getNDArrayInternal().getIndexer(getManager()).set(this, index, value);
    }

    /**
     * Replaces the region selected by {@code index} with {@code function} applied to that region
     * (read via {@link #get(NDIndex)}).
     */
    default void set(NDIndex index, Function<NDArray, NDArray> function) {
        set(index, function.apply(get(index)));
    }

    /**
     * Assigns {@code value} where the array {@code index} selects, using the {@code "{}"} NDIndex
     * placeholder form.
     */
    default void set(NDArray index, Number value) {
        set(new NDIndex("{}", index), value);
    }

    /** Delegates to the engine indexer's {@code setScalar} for this array's manager. */
    default void setScalar(NDIndex index, Number value) {
        getNDArrayInternal().getIndexer(getManager()).setScalar(this, index, value);
    }

    // ---- Indexing ----

    /** Returns the region selected by {@code index}, attached to this array's manager. */
    default NDArray get(NDIndex index) {
        return get(getManager(), index);
    }

    /** Returns the region selected by {@code index}, using the indexer for {@code manager}. */
    default NDArray get(NDManager manager, NDIndex index) {
        return getNDArrayInternal().getIndexer(manager).get(this, index);
    }

    /** Returns the region selected by the array {@code index} via the {@code "{}"} placeholder. */
    default NDArray get(NDArray index) {
        return get(new NDIndex("{}", index));
    }

    /** Returns the region selected by an NDIndex built from {@code indices} and {@code args}. */
    default NDArray get(String indices, Object... args) {
        return get(new NDIndex(indices, args));
    }

    /** Returns the region selected by the fixed {@code indices}. */
    default NDArray get(long... indices) {
        return get(new NDIndex(indices));
    }

    /** Returns the region selected by the fixed {@code indices}, using {@code manager}. */
    default NDArray get(NDManager manager, long... indices) {
        return get(manager, new NDIndex(indices));
    }

    NDArray gather(NDArray index, int axis);

    /** Equivalent to {@code take(getManager(), index)}. */
    default NDArray take(NDArray index) {
        return take(getManager(), index);
    }

    NDArray take(NDManager manager, NDArray index);

    NDArray put(NDArray index, NDArray value);

    /**
     * Returns the single-element region selected by {@code indices}.
     *
     * @throws IllegalArgumentException if the selection does not contain exactly one element
     */
    default NDArray getScalar(long... indices) {
        NDArray value = get(new NDIndex(indices));
        if (value.size() != 1) {
            throw new IllegalArgumentException("The supplied Index does not produce a scalar");
        }
        return value;
    }

    /** Reads one element as a long; requires {@link DataType#INT64}. */
    default long getLong(long... indices) {
        return getScalar(indices).toLongArray()[0];
    }

    /** Reads one element as a double; requires {@link DataType#FLOAT64}. */
    default double getDouble(long... indices) {
        return getScalar(indices).toDoubleArray()[0];
    }

    /** Reads one element as a float; requires FLOAT16 or FLOAT32. */
    default float getFloat(long... indices) {
        return getScalar(indices).toFloatArray()[0];
    }

    /** Reads one element as an int; requires {@link DataType#INT32}. */
    default int getInt(long... indices) {
        return getScalar(indices).toIntArray()[0];
    }

    /** Reads the first raw byte of the selected element (no data-type check). */
    default byte getByte(long... indices) {
        return getScalar(indices).toByteArray()[0];
    }

    /** Reads the selected element's first byte as an unsigned value in [0, 255]. */
    default int getUint8(long... indices) {
        return getByte(indices) & 0xFF;
    }

    /** Reads one element as a boolean; requires {@link DataType#BOOLEAN}. */
    default boolean getBoolean(long... indices) {
        return getScalar(indices).toBooleanArray()[0];
    }

    // ---- Copying and creation ----

    void copyTo(NDArray array);

    /**
     * Creates a new array with the same shape, data type, device and name, and copies this
     * array's contents into it with {@link #copyTo(NDArray)}.
     *
     * @return the copy
     */
    default NDArray duplicate() {
        NDArray copy = getManager().create(getShape(), getDataType(), getDevice());
        copy.setName(getName());
        copyTo(copy);
        return copy;
    }

    /** Equivalent to {@code booleanMask(index, 0)}. */
    default NDArray booleanMask(NDArray index) {
        return booleanMask(index, 0);
    }

    NDArray booleanMask(NDArray index, int axis);

    NDArray sequenceMask(NDArray sequenceLength, float value);

    NDArray sequenceMask(NDArray sequenceLength);

    /** Returns zeros with this array's shape, data type and device. */
    default NDArray zerosLike() {
        return getManager().zeros(getShape(), getDataType(), getDevice());
    }

    /** Returns ones with this array's shape, data type and device. */
    default NDArray onesLike() {
        return getManager().ones(getShape(), getDataType(), getDevice());
    }

    /**
     * Returns a newly created array with this array's shape, data type and device, consistent
     * with {@link #zerosLike()} and {@link #onesLike()}.
     *
     * <p>Contents are whatever {@link NDManager#create(Shape, DataType, Device)} produces;
     * presumably uninitialized. Confirm against the manager implementation.
     *
     * @return the new array
     */
    default NDArray like() {
        return getManager().create(getShape(), getDataType(), getDevice());
    }

    // ---- Equality ----

    boolean contentEquals(Number number);

    boolean contentEquals(NDArray other);

    /** Returns whether {@code other} has exactly the same {@link Shape}. */
    default boolean shapeEquals(NDArray other) {
        return getShape().equals(other.getShape());
    }

    /** Equivalent to {@code allClose(other, 1e-5, 1e-8, false)}. */
    default boolean allClose(NDArray other) {
        return allClose(other, 1.0E-5d, 1.0E-8d, false);
    }

    /**
     * Returns whether both arrays have the same shape and every element pair {@code (a, b)} (from
     * this array and {@code other}) satisfies {@code |a - b| <= atol + rtol * |b|}.
     *
     * <p>NaN never compares close, unless {@code equalNan} is set and both values are NaN.
     * Infinite values compare close only to an identical infinity (same sign). Without that
     * special case, {@code rtol * |inf|} would make any finite value "close" to infinity.
     *
     * <p>Elements are compared via {@link #toArray()} and {@link Number#doubleValue()}.
     *
     * @param other the array to compare with
     * @param rtol relative tolerance
     * @param atol absolute tolerance
     * @param equalNan whether two NaNs compare as close
     * @return {@code true} if all elements are close
     */
    default boolean allClose(NDArray other, double rtol, double atol, boolean equalNan) {
        if (!shapeEquals(other)) {
            return false;
        }
        Number[] mine = toArray();
        Number[] theirs = other.toArray();
        for (int i = 0; i < mine.length; i++) {
            double a = mine[i].doubleValue();
            double b = theirs[i].doubleValue();
            boolean aNan = Double.isNaN(a);
            boolean bNan = Double.isNaN(b);
            if (aNan || bNan) {
                if (equalNan && aNan && bNan) {
                    continue;
                }
                return false;
            }
            if (Double.isInfinite(a) || Double.isInfinite(b)) {
                if (a != b) {
                    return false;
                }
                continue;
            }
            if (Math.abs(a - b) > atol + rtol * Math.abs(b)) {
                return false;
            }
        }
        return true;
    }

    // ---- Comparison operators against a scalar or another array (implemented by the engine) ----

    NDArray eq(Number n);

    NDArray eq(NDArray other);

    NDArray neq(Number n);

    NDArray neq(NDArray other);

    NDArray gt(Number n);

    NDArray gt(NDArray other);

    NDArray gte(Number n);

    NDArray gte(NDArray other);

    NDArray lt(Number n);

    NDArray lt(NDArray other);

    NDArray lte(Number n);

    NDArray lte(NDArray other);

    // ---- Arithmetic against a scalar or another array (implemented by the engine) ----

    NDArray add(Number n);

    NDArray add(NDArray other);

    NDArray sub(Number n);

    NDArray sub(NDArray other);

    NDArray mul(Number n);

    NDArray mul(NDArray other);

    NDArray div(Number n);

    NDArray div(NDArray other);

    NDArray mod(Number n);

    NDArray mod(NDArray other);

    NDArray pow(Number n);

    NDArray pow(NDArray other);

    // ---- "i"-suffixed variants: presumably in-place; confirm in the implementation ----

    NDArray addi(Number n);

    NDArray addi(NDArray other);

    NDArray subi(Number n);

    NDArray subi(NDArray other);

    NDArray muli(Number n);

    NDArray muli(NDArray other);

    NDArray divi(Number n);

    NDArray divi(NDArray other);

    NDArray modi(Number n);

    NDArray modi(NDArray other);

    NDArray powi(Number n);

    NDArray powi(NDArray other);

    NDArray sign();

    NDArray signi();

    NDArray maximum(Number n);

    NDArray maximum(NDArray other);

    NDArray minimum(Number n);

    NDArray minimum(NDArray other);

    NDArray neg();

    NDArray negi();

    // ---- Unary math functions (implemented by the engine) ----

    NDArray abs();

    NDArray square();

    NDArray sqrt();

    NDArray cbrt();

    NDArray floor();

    NDArray ceil();

    NDArray round();

    NDArray trunc();

    NDArray exp();

    NDArray gammaln();

    NDArray log();

    NDArray log10();

    NDArray log2();

    NDArray sin();

    NDArray cos();

    NDArray tan();

    NDArray asin();

    NDArray acos();

    NDArray atan();

    NDArray sinh();

    NDArray cosh();

    NDArray tanh();

    NDArray asinh();

    NDArray acosh();

    NDArray atanh();

    NDArray toDegrees();

    NDArray toRadians();

    // ---- Reductions; single-argument overloads pass keepDims = false ----

    NDArray max();

    /** Equivalent to {@code max(axes, false)}. */
    default NDArray max(int[] axes) {
        return max(axes, false);
    }

    NDArray max(int[] axes, boolean keepDims);

    NDArray min();

    /** Equivalent to {@code min(axes, false)}. */
    default NDArray min(int[] axes) {
        return min(axes, false);
    }

    NDArray min(int[] axes, boolean keepDims);

    NDArray sum();

    /** Equivalent to {@code sum(axes, false)}. */
    default NDArray sum(int[] axes) {
        return sum(axes, false);
    }

    NDArray sum(int[] axes, boolean keepDims);

    NDArray prod();

    /** Equivalent to {@code prod(axes, false)}. */
    default NDArray prod(int[] axes) {
        return prod(axes, false);
    }

    NDArray prod(int[] axes, boolean keepDims);

    NDArray mean();

    /** Equivalent to {@code mean(axes, false)}. */
    default NDArray mean(int[] axes) {
        return mean(axes, false);
    }

    NDArray mean(int[] axes, boolean keepDims);

    /** Equivalent to {@code normalize(2.0, 1, 1e-12)}. */
    default NDArray normalize() {
        return normalize(2.0d, 1L, 1.0E-12d);
    }

    /** Equivalent to {@code normalize(exponent, dim, 1e-12)}. */
    default NDArray normalize(double exponent, long dim) {
        return normalize(exponent, dim, 1.0E-12d);
    }

    NDArray normalize(double exponent, long dim, double eps);

    NDArray rotate90(int times, int[] axes);

    /** Equivalent to {@code trace(0, 0, 1)}. */
    default NDArray trace() {
        return trace(0, 0, 1);
    }

    /** Equivalent to {@code trace(offset, 0, 1)}. */
    default NDArray trace(int offset) {
        return trace(offset, 0, 1);
    }

    NDArray trace(int offset, int axis1, int axis2);

    // ---- Splitting and reshaping ----

    /** Equivalent to {@code split(sections, 0)}. */
    default NDList split(long sections) {
        return split(sections, 0);
    }

    /** Equivalent to {@code split(indices, 0)}. */
    default NDList split(long[] indices) {
        return split(indices, 0);
    }

    /**
     * Splits along {@code axis} into {@code sections} equal parts by passing the start offsets
     * {@code [0, s, 2s, ...]} (with {@code s = axisSize / sections}) to {@link #split(long[],
     * int)}.
     *
     * @param sections the number of equal parts; must be positive
     * @param axis the axis to split along
     * @return the parts
     * @throws IllegalArgumentException if {@code sections} is not positive or does not evenly
     *     divide the axis size
     */
    default NDList split(long sections, int axis) {
        if (sections <= 0) {
            throw new IllegalArgumentException("sections must be positive, got " + sections);
        }
        long axisSize = getShape().getShape()[axis];
        if (axisSize % sections != 0) {
            throw new IllegalArgumentException("array split does not result in an equal division");
        }
        long sectionSize = axisSize / sections;
        long[] offsets = LongStream.range(0L, sections).map(k -> k * sectionSize).toArray();
        return split(offsets, axis);
    }

    NDList split(long[] indices, int axis);

    NDArray flatten();

    /** Equivalent to {@code reshape(new Shape(newShape))}. */
    default NDArray reshape(long... newShape) {
        return reshape(new Shape(newShape));
    }

    NDArray reshape(Shape shape);

    NDArray expandDims(int axis);

    /** Squeezes every axis whose size is exactly 1. */
    default NDArray squeeze() {
        long[] dims = getShape().getShape();
        int[] unitAxes = IntStream.range(0, dims.length).filter(d -> dims[d] == 1).toArray();
        return squeeze(unitAxes);
    }

    /** Equivalent to {@code squeeze(new int[] {axis})}. */
    default NDArray squeeze(int axis) {
        return squeeze(new int[] {axis});
    }

    NDArray squeeze(int[] axes);

    /** Equivalent to {@code stack(array, 0)}. */
    default NDArray stack(NDArray array) {
        return stack(array, 0);
    }

    /** Stacks this array with {@code array} along {@code axis} via the engine's internal ops. */
    default NDArray stack(NDArray array, int axis) {
        return getNDArrayInternal().stack(new NDList(array), axis);
    }

    /** Equivalent to {@code concat(array, 0)}. */
    default NDArray concat(NDArray array) {
        return concat(array, 0);
    }

    /** Concatenates this array with {@code array} along {@code axis} via the engine's internal ops. */
    default NDArray concat(NDArray array, int axis) {
        return getNDArrayInternal().concat(new NDList(array), axis);
    }

    // ---- Logical operators (implemented by the engine) ----

    NDArray logicalAnd(NDArray other);

    NDArray logicalOr(NDArray other);

    NDArray logicalXor(NDArray other);

    NDArray logicalNot();

    // ---- Sorting and cumulative ops ----

    /** Equivalent to {@code argSort(-1, true)}. */
    default NDArray argSort() {
        return argSort(-1, true);
    }

    /** Equivalent to {@code argSort(axis, true)}. */
    default NDArray argSort(int axis) {
        return argSort(axis, true);
    }

    NDArray argSort(int axis, boolean ascending);

    NDArray sort();

    NDArray sort(int axis);

    NDArray softmax(int axis);

    NDArray logSoftmax(int axis);

    NDArray cumSum();

    NDArray cumSum(int axis);

    void intern(NDArray replaced);

    NDArray isInfinite();

    NDArray inverse();

    NDArray isNaN();

    // ---- Tiling and repetition (implemented by the engine) ----

    NDArray tile(long repeats);

    NDArray tile(int axis, long repeats);

    NDArray tile(long[] repeats);

    NDArray tile(Shape desiredShape);

    NDArray repeat(long repeats);

    NDArray repeat(int axis, long repeats);

    NDArray repeat(long[] repeats);

    NDArray repeat(Shape desiredShape);

    NDArray dot(NDArray other);

    NDArray matMul(NDArray other);

    NDArray clip(Number min, Number max);

    /**
     * Swaps two axes by transposing with the identity permutation, with {@code axis1} and
     * {@code axis2} exchanged.
     *
     * <p>Negative axes count from the end ({@code -1} is the last axis), matching the {@code -1}
     * default used by {@link #argSort()}. Axes still out of range after adjustment throw {@link
     * ArrayIndexOutOfBoundsException}.
     *
     * @param axis1 the first axis
     * @param axis2 the second axis
     * @return the transposed array
     */
    default NDArray swapAxes(int axis1, int axis2) {
        int dimension = getShape().dimension();
        int first = axis1 < 0 ? axis1 + dimension : axis1;
        int second = axis2 < 0 ? axis2 + dimension : axis2;
        int[] order = IntStream.range(0, dimension).toArray();
        int saved = order[first];
        order[first] = order[second];
        order[second] = saved;
        return transpose(order);
    }

    NDArray flip(int... axes);

    NDArray transpose();

    NDArray transpose(int... axes);

    NDArray broadcast(Shape shape);

    /** Equivalent to {@code broadcast(new Shape(shape))}. */
    default NDArray broadcast(long... shape) {
        return broadcast(new Shape(shape));
    }

    NDArray argMax();

    NDArray argMax(int axis);

    NDArray argMin();

    NDArray argMin(int axis);

    NDArray percentile(Number percentile);

    NDArray percentile(Number percentile, int[] axes);

    NDArray median();

    NDArray median(int[] axes);

    NDArray toDense();

    NDArray toSparse(SparseFormat fmt);

    NDArray nonzero();

    /** Returns whether the shape contains zero elements. */
    default boolean isEmpty() {
        return getShape().size() == 0;
    }

    /**
     * Returns whether all elements are non-zero: the BOOLEAN-cast sum equals {@link #size()}.
     */
    default NDArray all() {
        return toType(DataType.BOOLEAN, false).sum().eq(size());
    }

    /** Returns whether any element is non-zero: the BOOLEAN-cast sum is greater than 0. */
    default NDArray any() {
        return toType(DataType.BOOLEAN, false).sum().gt(0);
    }

    /** Returns whether no element is non-zero: the BOOLEAN-cast sum equals 0. */
    default NDArray none() {
        return toType(DataType.BOOLEAN, false).sum().eq(0);
    }

    /** Counts non-zero elements as the sum of the BOOLEAN-cast array. */
    default NDArray countNonzero() {
        return toType(DataType.BOOLEAN, false).sum();
    }

    /** Counts non-zero elements along {@code axis} as the BOOLEAN-cast sum over that axis. */
    default NDArray countNonzero(int axis) {
        return toType(DataType.BOOLEAN, false).sum(new int[] {axis});
    }

    NDArray erfinv();

    NDArrayEx getNDArrayInternal();

    /** Equivalent to {@code toDebugString(100, 10, 10, 20)}. */
    default String toDebugString() {
        return toDebugString(100, 10, 10, 20);
    }

    /**
     * Formats this array with {@link NDFormat#format}, forwarding the four limits unchanged.
     * NOTE(review): parameter meanings are defined by NDFormat; the names here are best-effort.
     */
    default String toDebugString(int maxSize, int maxDepth, int maxRows, int maxColumns) {
        return NDFormat.format(this, maxSize, maxDepth, maxRows, maxColumns);
    }

    @Override
    void close();

    // ---- Norms ----

    /** Equivalent to {@code norm(false)}. */
    default NDArray norm() {
        return norm(false);
    }

    /** Equivalent to {@code norm(axes, false)}. */
    default NDArray norm(int[] axes) {
        return norm(axes, false);
    }

    NDArray norm(boolean keepDims);

    /** Equivalent to {@code norm(2, axes, keepDims)}. */
    default NDArray norm(int[] axes, boolean keepDims) {
        return norm(2, axes, keepDims);
    }

    NDArray norm(int ord, int[] axes, boolean keepDims);

    // ---- One-hot encoding ----

    /** Equivalent to {@code oneHot(depth, 1f, 0f, DataType.FLOAT32)}. */
    default NDArray oneHot(int depth) {
        return oneHot(depth, 1.0f, 0.0f, DataType.FLOAT32);
    }

    /** Equivalent to {@code oneHot(depth, 1f, 0f, dataType)}. */
    default NDArray oneHot(int depth, DataType dataType) {
        return oneHot(depth, 1.0f, 0.0f, dataType);
    }

    NDArray oneHot(int depth, float onValue, float offValue, DataType dataType);

    NDArray batchDot(NDArray other);
}
