package ai.djl.ndarray.internal;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.util.Utils;
import java.lang.management.ManagementFactory;
import java.util.Arrays;
import java.util.Locale;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/* loaded from: input_file:ai/djl/ndarray/internal/NDFormat.class */
/**
 * Base class of the text formatters used to render the contents of an {@link NDArray}.
 *
 * <p>The static {@code format(NDArray, int, int, int, int)} entry point writes a one-line header
 * and then delegates to one of the nested formatters, chosen by the array's data type. Each
 * formatter supplies the per-element rendering through {@code format(Number)}.
 */
public abstract class NDFormat {
    /** Fraction-digit count used as the bound when a formatter switches to exponential notation. */
    private static final int PRECISION = 8;
    /** Line separator read from the {@code line.separator} system property. */
    private static final String LF = System.getProperty("line.separator");
    /**
     * Matches the text produced by the {@code %16e} conversion for a non-negative finite value:
     * optional leading whitespace, one digit, a dot, the mantissa fraction, {@code e}, a sign and
     * the exponent digits. Group 1 is the mantissa fraction without its trailing zeros; group 2 is
     * the exponent digits.
     */
    private static final Pattern PATTERN = Pattern.compile("\\s*\\d\\.(\\d*?)0*e[+-](\\d+)");
    /**
     * {@code true} when any JVM input argument starts with {@code -agentlib:jdwp}. When set, the
     * static {@code format} method returns the header only and never reads the array contents.
     */
    private static final boolean DEBUG = ManagementFactory.getRuntimeMXBean().getInputArguments().stream().anyMatch(str -> {
        return str.startsWith("-agentlib:jdwp");
    });

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:ai/djl/ndarray/internal/NDFormat$BooleanFormat.class */
    /** Formatter for boolean arrays; every element is rendered five characters wide. */
    public static final class BooleanFormat extends NDFormat {
        private BooleanFormat() {
        }

        /**
         * Renders one element as a boolean literal.
         *
         * @param number the element; it is treated as false when its byte value is zero
         * @return {@code "false"}, or {@code " true"} (with a leading space so both have equal width)
         */
        @Override
        public CharSequence format(Number number) {
            if (number.byteValue() == 0) {
                return "false";
            }
            return " true";
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:ai/djl/ndarray/internal/NDFormat$FloatFormat.class */
    /**
     * Formatter for floating point arrays.
     *
     * <p>{@link #init(NDArray)} scans the values once and decides between fixed-point and
     * exponential notation, the number of fraction digits, and the column width.
     * {@link #format(Number)} then renders each element with those settings.
     */
    public static final class FloatFormat extends NDFormat {
        /** Whether elements are rendered in exponential notation. */
        private boolean exponential;
        /** Number of fraction digits chosen by {@code init}. */
        private int precision;
        /** Column width used to right-align fixed-point values and nan/inf markers. */
        private int totalLength;

        private FloatFormat() {
        }

        /**
         * Inspects every element of the array and derives the notation, precision and width.
         *
         * <p>Fixed-point notation is used when the largest magnitude is at most 1e8, the smallest
         * non-zero magnitude is at least 1e-4 and their ratio is at most 1000; otherwise
         * exponential notation is used.
         *
         * @param nDArray the array about to be printed
         * @throws AssertionError if the {@code %16e} rendering of a value does not match the
         *     expected pattern
         */
        @Override
        public void init(NDArray nDArray) {
            int maxIntPartLen = 0;
            int maxFractionLen = 0;
            int maxMantissaFractionLen = 0;
            boolean hasNegative = false;
            double largest = 0.0d;
            double smallest = Double.MAX_VALUE;

            for (Number number : nDArray.toArray()) {
                double value = number.doubleValue();
                boolean negative = value < 0.0d;
                if (negative) {
                    hasNegative = true;
                }

                if (!Double.isFinite(value)) {
                    // "nan" / "inf" need 3 columns, "-inf" needs 4.
                    int markerLen = negative ? 4 : 3;
                    if (this.totalLength < markerLen) {
                        this.totalLength = markerLen;
                    }
                    continue;
                }

                double abs = Math.abs(value);
                String scientific = String.format(Locale.ENGLISH, "%16e", abs);
                Matcher matcher = NDFormat.PATTERN.matcher(scientific);
                if (!matcher.matches()) {
                    throw new AssertionError("Invalid decimal value: " + scientific);
                }
                // Significant digits after the first one, trailing zeros stripped.
                int mantissaFractionLen = matcher.group(1).length();
                if (maxMantissaFractionLen < mantissaFractionLen) {
                    maxMantissaFractionLen = mantissaFractionLen;
                }

                int intPartLen;
                int fractionLen;
                if (abs >= 1.0d) {
                    int intDigits = ((int) Math.log10(abs)) + 1;
                    // The sign widens the integer column but is not a digit: the fraction length
                    // must be derived from the unsigned digit count, otherwise a negative value
                    // such as -1.5 loses its only fraction digit and is printed as "-2.".
                    intPartLen = negative ? intDigits + 1 : intDigits;
                    fractionLen = mantissaFractionLen + 1 - intDigits;
                } else {
                    intPartLen = negative ? 2 : 1;
                    // For magnitudes below one the exponent digits give the number of leading
                    // fraction positions before the mantissa fraction.
                    fractionLen = mantissaFractionLen + Integer.parseInt(matcher.group(2));
                }
                if (maxIntPartLen < intPartLen) {
                    maxIntPartLen = intPartLen;
                }
                if (maxFractionLen < fractionLen) {
                    maxFractionLen = fractionLen;
                }

                if (abs > largest) {
                    largest = abs;
                }
                if (abs < smallest && abs > 0.0d) {
                    smallest = abs;
                }
            }

            double ratio = largest / smallest;
            boolean fixedPoint = largest <= 1.0E8d && smallest >= 1.0E-4d && ratio <= 1000.0d;
            if (fixedPoint) {
                this.precision = Math.min(4, maxFractionLen);
                // integer part + decimal point + fraction digits
                int width = maxIntPartLen + this.precision + 1;
                if (this.totalLength < width) {
                    this.totalLength = width;
                }
                return;
            }

            this.exponential = true;
            this.precision = Math.min(NDFormat.PRECISION, maxMantissaFractionLen);
            this.totalLength = this.precision + 4;
            if (hasNegative) {
                this.totalLength++;
            }
        }

        /**
         * Renders one element using the settings chosen by {@link #init(NDArray)}.
         *
         * <p>NaN and infinities are rendered as {@code nan}, {@code inf} and {@code -inf},
         * right-aligned to the column width. In fixed-point notation trailing zeros of the
         * fraction are replaced by spaces so that columns stay aligned.
         *
         * @param number the element to render
         * @return the rendered text
         */
        @Override
        public CharSequence format(Number number) {
            double value = number.doubleValue();
            if (Double.isNaN(value)) {
                return padToWidth("nan");
            }
            if (Double.isInfinite(value)) {
                return padToWidth(value > 0.0d ? "inf" : "-inf");
            }
            if (this.exponential) {
                // NOTE(review): init caps precision at PRECISION, so this max() always yields
                // PRECISION. Kept as-is to preserve the output; computed locally so that
                // formatting an element no longer mutates the formatter's state.
                int digits = Math.max(NDFormat.PRECISION, this.precision);
                return String.format(Locale.ENGLISH, "% ." + digits + "e", value);
            }
            if (this.precision == 0) {
                // Whole numbers are printed with a trailing dot, e.g. "12.".
                String wholePattern = "%" + (this.totalLength - 1) + '.' + this.precision + "f.";
                return String.format(Locale.ENGLISH, wholePattern, value);
            }

            String fixedPattern = "%" + this.totalLength + '.' + this.precision + 'f';
            char[] chars = String.format(Locale.ENGLISH, fixedPattern, value).toCharArray();
            // Blank out trailing zeros; the scan stops at the first non-zero character.
            for (int pos = chars.length - 1; pos >= 0 && chars[pos] == '0'; pos--) {
                chars[pos] = ' ';
            }
            return new String(chars);
        }

        /** Right-aligns a nan/inf marker to the current column width. */
        private String padToWidth(String marker) {
            return String.format(Locale.ENGLISH, "%" + this.totalLength + "s", marker);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:ai/djl/ndarray/internal/NDFormat$HexFormat.class */
    /** Formatter that renders each element as a two-digit upper-case hexadecimal byte. */
    public static final class HexFormat extends NDFormat {
        private HexFormat() {
        }

        /**
         * Renders one element as {@code 0x} followed by two hexadecimal digits.
         *
         * @param number the element; only its byte value is used
         * @return the rendered text, for example {@code 0x0A}
         */
        @Override
        public CharSequence format(Number number) {
            Byte raw = number.byteValue();
            return String.format(Locale.ENGLISH, "0x%02X", raw);
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:ai/djl/ndarray/internal/NDFormat$IntFormat.class */
    /**
     * Formatter for integer arrays.
     *
     * <p>Values are printed as right-aligned decimal integers unless the largest magnitude
     * reaches 1e8, in which case exponential notation is used.
     */
    public static final class IntFormat extends NDFormat {
        /** Whether elements are rendered in exponential notation. */
        private boolean exponential;
        /** Number of fraction digits used in exponential notation. */
        private int precision;
        /** Minimum column width used for decimal rendering. */
        private int totalLength;

        private IntFormat() {
        }

        /**
         * Scans the array and chooses between decimal and exponential rendering.
         *
         * <p>An array with exactly one element is always printed in decimal with no padding.
         *
         * @param nDArray the array about to be printed
         */
        @Override
        public void init(NDArray nDArray) {
            Number[] values = nDArray.toArray();
            if (values.length == 1) {
                this.totalLength = 1;
                return;
            }

            long maxAbs = 0;
            long maxNegativeAbs = 0;
            for (Number number : values) {
                long value = number.longValue();
                long abs = Math.abs(value);
                if (value < 0 && abs > maxNegativeAbs) {
                    maxNegativeAbs = abs;
                }
                if (abs > maxAbs) {
                    maxAbs = abs;
                }
            }

            if (maxAbs < 1.0E8d) {
                int unsignedWidth = maxAbs != 0 ? ((int) Math.log10(maxAbs)) + 1 : 1;
                // One extra column for the minus sign; the width is 2 even without negatives.
                int signedWidth = maxNegativeAbs != 0 ? ((int) Math.log10(maxNegativeAbs)) + 2 : 2;
                this.totalLength = Math.max(unsignedWidth, signedWidth);
            } else {
                this.exponential = true;
                this.precision = Math.min(NDFormat.PRECISION, ((int) Math.log10(maxAbs)) + 1);
            }
        }

        /**
         * Renders one element using the settings chosen by {@link #init(NDArray)}.
         *
         * @param number the element to render
         * @return the element in exponential notation, or as a right-aligned decimal integer
         */
        @Override
        public CharSequence format(Number number) {
            if (this.exponential) {
                // Convert through double, not float: a float keeps only about 7 significant
                // decimal digits, so with up to 8 fraction digits a value such as 123456789
                // would be printed as 1.23456792e+08.
                return String.format(Locale.ENGLISH, "% ." + this.precision + "e", number.doubleValue());
            }
            return String.format(Locale.ENGLISH, "%" + this.totalLength + "d", number.longValue());
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:ai/djl/ndarray/internal/NDFormat$StringFormat.class */
    /** Formatter for string arrays; it bypasses the element-wise rendering entirely. */
    public static final class StringFormat extends NDFormat {
        private StringFormat() {
        }

        /**
         * Not used for string arrays.
         *
         * @param number ignored
         * @return always {@code null}
         */
        @Override
        public CharSequence format(Number number) {
            return null;
        }

        /**
         * Returns the array's strings rendered by {@link Arrays#toString(Object[])}.
         *
         * <p>Nothing is appended to {@code sb}, and whatever it already contains is not part of
         * the returned text. The four limit arguments are ignored.
         *
         * @param sb ignored
         * @param nDArray the string array to render
         * @param i ignored
         * @param i2 ignored
         * @param i3 ignored
         * @param i4 ignored
         * @return the bracketed, comma-separated list of strings
         */
        @Override
        protected String dump(StringBuilder sb, NDArray nDArray, int i, int i2, int i3, int i4) {
            String[] texts = nDArray.toStringArray();
            return Arrays.toString(texts);
        }
    }

    /**
     * Renders an array as text: a header line followed by the formatted contents.
     *
     * <p>The header consists of the array name (or {@code ND} when it has none), its shape,
     * device and data type, plus {@code hasGradient} when applicable. When the JVM runs with a
     * JDWP agent only the header is returned. Otherwise a formatter is chosen by data type and
     * the four limits are passed on to its {@code dump} method unchanged.
     *
     * @param nDArray the array to render
     * @param i first limit passed to {@code dump} (compared there against the element count)
     * @param i2 second limit passed to {@code dump} (compared there against the dimension count)
     * @param i3 third limit passed to {@code dump}
     * @param i4 fourth limit passed to {@code dump}
     * @return the rendered text
     */
    public static String format(NDArray nDArray, int i, int i2, int i3, int i4) {
        StringBuilder sb = new StringBuilder(1000);

        String label = nDArray.getName();
        if (label == null) {
            sb.append("ND: ");
        } else {
            sb.append(label);
            sb.append(": ");
        }
        sb.append(nDArray.getShape());
        sb.append(' ');
        sb.append(nDArray.getDevice());
        sb.append(' ');
        sb.append(nDArray.getDataType());
        if (nDArray.hasGradient()) {
            sb.append(" hasGradient");
        }
        if (DEBUG) {
            return sb.toString();
        }

        DataType dataType = nDArray.getDataType();
        NDFormat formatter;
        if (dataType == DataType.UINT8) {
            formatter = new HexFormat();
        } else if (dataType == DataType.BOOLEAN) {
            formatter = new BooleanFormat();
        } else if (dataType == DataType.STRING) {
            formatter = new StringFormat();
        } else if (dataType.isInteger()) {
            formatter = new IntFormat();
        } else {
            formatter = new FloatFormat();
        }
        return formatter.dump(sb, nDArray, i, i2, i3, i4);
    }

    /**
     * Renders a single array element as text.
     *
     * @param number the element to render
     * @return the text for that element
     */
    protected abstract CharSequence format(Number number);

    /**
     * Hook that lets a formatter inspect the array before any element is rendered. The default
     * implementation does nothing.
     *
     * @param nDArray the array about to be printed
     */
    protected void init(NDArray nDArray) {
    }

    /**
     * Appends the formatted contents of the array to {@code sb} and returns the full text.
     *
     * <p>A line separator is written first. Then, in order of precedence: an empty array is
     * rendered as {@code []}; a zero-dimensional array as its single formatted element; an array
     * with more than {@code i} elements or more than {@code i2} dimensions as a short
     * "Exceed max print ..." marker; anything else through the recursive row dump.
     *
     * <p>{@link #init(NDArray)} is invoked only on the branches that actually format elements.
     * The formatters in this file implement it by walking every element of the array, which is
     * wasted work when the outcome is {@code []} or a marker.
     *
     * @param sb buffer that already holds the header; the contents are appended to it
     * @param nDArray the array to render
     * @param i maximum number of elements that will be printed
     * @param i2 maximum number of dimensions that will be printed
     * @param i3 maximum number of sub-arrays printed per dimension
     * @param i4 maximum number of elements printed per innermost row
     * @return the complete text held by {@code sb}
     */
    protected String dump(StringBuilder sb, NDArray nDArray, int i, int i2, int i3, int i4) {
        sb.append(LF);
        long size = nDArray.size();
        long dimension = nDArray.getShape().dimension();
        if (size == 0) {
            sb.append("[]").append(LF);
        } else if (dimension == 0) {
            init(nDArray);
            sb.append(format(nDArray.toArray()[0])).append(LF);
        } else if (size > i) {
            sb.append("[ Exceed max print size ]");
        } else if (dimension > i2) {
            sb.append("[ Exceed max print dimension ]");
        } else {
            init(nDArray);
            dump(sb, nDArray, 0, true, i3, i4);
        }
        return sb.toString();
    }

    /**
     * Recursively appends one bracketed level of the array to {@code sb}.
     *
     * <p>A one-dimensional array is written as a single row. For higher dimensions each
     * sub-array along the first axis is written on its own line, up to {@code i2} of them,
     * followed by a "... N more" note when some were left out.
     *
     * @param sb buffer to append to
     * @param nDArray the (sub-)array to render
     * @param i nesting depth, used as the indentation width; 0 for the outermost level
     * @param z {@code true} when this is the first sub-array of its parent, in which case the
     *     indentation is skipped because the opening bracket follows the parent's directly
     * @param i2 maximum number of sub-arrays printed per level
     * @param i3 maximum number of elements printed per innermost row
     */
    private void dump(StringBuilder sb, NDArray nDArray, int i, boolean z, int i2, int i3) {
        if (!z) {
            Utils.pad(sb, ' ', i);
        }
        sb.append('[');
        Shape shape = nDArray.getShape();
        if (shape.dimension() == 1) {
            append(sb, nDArray.toArray(), i3);
        } else {
            long head = shape.head();
            long limit = Math.min(head, i2);
            for (int row = 0; row < limit; row++) {
                // Each slice is closed as soon as it has been printed, even if printing fails.
                try (NDArray slice = nDArray.get(row)) {
                    dump(sb, slice, i + 1, row == 0, i2, i3);
                }
            }
            long remaining = head - limit;
            if (remaining > 0) {
                Utils.pad(sb, ' ', i + 1);
                sb.append("... ").append(remaining).append(" more");
            }
            Utils.pad(sb, ' ', i);
        }
        // Only nested levels get a trailing comma after the closing bracket.
        if (i == 0) {
            sb.append(']').append(LF);
        } else {
            sb.append("],").append(LF);
        }
    }

    /**
     * Appends the elements of one row, separated by {@code ", "}.
     *
     * <p>The first element is always written when the row is not empty. Further elements are
     * written while their index is below {@code i}; if any are left over, a
     * {@code ", ... N more"} note is appended.
     *
     * @param sb buffer to append to
     * @param numberArr the row's elements
     * @param i maximum number of elements to print
     */
    private void append(StringBuilder sb, Number[] numberArr, int i) {
        int total = numberArr.length;
        if (total == 0) {
            return;
        }
        long shown = Math.min(total, i);
        sb.append(format(numberArr[0]));
        int index = 1;
        while (index < shown) {
            sb.append(", ").append(format(numberArr[index]));
            index++;
        }
        long hidden = total - shown;
        if (hidden > 0) {
            sb.append(", ... ");
            sb.append(hidden);
            sb.append(" more");
        }
    }
}
