package ai.djl.ndarray;

import ai.djl.Device;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.util.JsonUtils;
import ai.djl.util.Pair;
import com.google.gson.JsonObject;
import com.google.gson.annotations.SerializedName;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.PushbackInputStream;
import java.nio.BufferUnderflowException;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.zip.ZipEntry;
import java.util.zip.ZipInputStream;
import java.util.zip.ZipOutputStream;

/* loaded from: input_file:ai/djl/ndarray/NDList.class */
public class NDList extends ArrayList<NDArray> implements NDResource, BytesSupplier {
    private static final long serialVersionUID = 1;

    /* loaded from: input_file:ai/djl/ndarray/NDList$Encoding.class */
    public enum Encoding {
        ND_LIST,
        NPZ,
        SAFETENSORS
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:ai/djl/ndarray/NDList$SafeTensor.class */
    public static final class SafeTensor {
        String dtype;
        long[] shape;

        @SerializedName("data_offsets")
        int[] offsets;

        private SafeTensor() {
        }

        int size() {
            return this.offsets[1] - this.offsets[0];
        }
    }

    public NDList() {
    }

    public NDList(int i) {
        super(i);
    }

    public NDList(NDArray... nDArrayArr) {
        super(Arrays.asList(nDArrayArr));
    }

    public NDList(Collection<NDArray> collection) {
        super(collection);
    }

    public static NDList decode(NDManager nDManager, byte[] bArr) {
        if (bArr.length < 9) {
            throw new IllegalArgumentException("Invalid input length: " + bArr.length);
        }
        try {
            if (bArr[0] == 80 && bArr[1] == 75) {
                return decodeNumpy(nDManager, new ByteArrayInputStream(bArr));
            }
            if (bArr[0] == 57 && bArr[1] == 78 && bArr[2] == 85 && bArr[3] == 77) {
                return new NDList(NDSerializer.decode(nDManager, new ByteArrayInputStream(bArr)));
            }
            if (bArr[8] == 123) {
                return decodeSafetensors(nDManager, new ByteArrayInputStream(bArr));
            }
            ByteBuffer wrap = ByteBuffer.wrap(bArr);
            int i = wrap.getInt();
            if (i < 0) {
                throw new IllegalArgumentException("Invalid NDList size: " + i);
            }
            NDList nDList = new NDList();
            for (int i2 = 0; i2 < i; i2++) {
                nDList.add(i2, NDSerializer.decode(nDManager, wrap));
            }
            return nDList;
        } catch (IOException | BufferUnderflowException e) {
            throw new IllegalArgumentException("Invalid NDArray input", e);
        }
    }

    public static NDList decode(NDManager nDManager, InputStream inputStream) {
        try {
            byte[] bArr = new byte[9];
            new DataInputStream(inputStream).readFully(bArr);
            PushbackInputStream pushbackInputStream = new PushbackInputStream(inputStream, 9);
            pushbackInputStream.unread(bArr);
            if (bArr[0] == 80 && bArr[1] == 75) {
                return decodeNumpy(nDManager, pushbackInputStream);
            }
            if (bArr[0] == 57 && bArr[1] == 78 && bArr[2] == 85 && bArr[3] == 77) {
                return new NDList(NDSerializer.decode(nDManager, pushbackInputStream));
            }
            if (bArr[8] == 123) {
                return decodeSafetensors(nDManager, pushbackInputStream);
            }
            DataInputStream dataInputStream = new DataInputStream(pushbackInputStream);
            int readInt = dataInputStream.readInt();
            if (readInt < 0) {
                throw new IllegalArgumentException("Invalid NDList size: " + readInt);
            }
            NDList nDList = new NDList();
            for (int i = 0; i < readInt; i++) {
                nDList.add(i, nDManager.decode(dataInputStream));
            }
            return nDList;
        } catch (IOException e) {
            throw new IllegalArgumentException("Malformed data", e);
        }
    }

    private static NDList decodeSafetensors(NDManager nDManager, InputStream inputStream) throws IOException {
        DataInputStream dataInputStream = inputStream instanceof DataInputStream ? (DataInputStream) inputStream : new DataInputStream(inputStream);
        byte[] bArr = new byte[8];
        dataInputStream.readFully(bArr);
        byte[] bArr2 = new byte[Math.toIntExact(ByteBuffer.wrap(bArr).order(ByteOrder.LITTLE_ENDIAN).getLong())];
        dataInputStream.readFully(bArr2);
        String str = new String(bArr2, StandardCharsets.UTF_8);
        JsonObject jsonObject = (JsonObject) JsonUtils.GSON.fromJson(str, JsonObject.class);
        ArrayList<Pair> arrayList = new ArrayList();
        int i = 0;
        for (String str2 : jsonObject.keySet()) {
            if (!"__metadata__".equals(str2)) {
                SafeTensor safeTensor = (SafeTensor) JsonUtils.GSON.fromJson(jsonObject.get(str2), SafeTensor.class);
                if (safeTensor.offsets.length != 2) {
                    throw new IOException("Malformed safetensors metadata: " + str);
                }
                i = Math.max(i, safeTensor.offsets[1]);
                arrayList.add(new Pair(str2, safeTensor));
            }
        }
        byte[] bArr3 = new byte[i];
        dataInputStream.readFully(bArr3);
        NDList nDList = new NDList(arrayList.size());
        for (Pair pair : arrayList) {
            if (!"__metadata__".equals(pair.getKey())) {
                SafeTensor safeTensor2 = (SafeTensor) pair.getValue();
                Shape shape = new Shape(safeTensor2.shape);
                ByteBuffer wrap = ByteBuffer.wrap(bArr3, safeTensor2.offsets[0], safeTensor2.size());
                wrap.order(ByteOrder.LITTLE_ENDIAN);
                NDArray create = nDManager.create(wrap, shape, DataType.fromSafetensors(safeTensor2.dtype));
                create.setName((String) pair.getKey());
                nDList.add(create);
            }
        }
        return nDList;
    }

    private static NDList decodeNumpy(NDManager nDManager, InputStream inputStream) throws IOException {
        NDList nDList = new NDList();
        ZipInputStream zipInputStream = new ZipInputStream(inputStream);
        while (true) {
            ZipEntry nextEntry = zipInputStream.getNextEntry();
            if (nextEntry == null) {
                return nDList;
            }
            String name = nextEntry.getName();
            NDArray decodeNumpy = NDSerializer.decodeNumpy(nDManager, zipInputStream);
            if (!name.startsWith("arr_") && name.endsWith(".npy")) {
                decodeNumpy.setName(name.substring(0, name.length() - 4));
            }
            nDList.add(decodeNumpy);
        }
    }

    public NDArray get(String str) {
        Iterator<NDArray> it = iterator();
        while (it.hasNext()) {
            NDArray next = it.next();
            if (str.equals(next.getName())) {
                return next;
            }
        }
        return null;
    }

    public NDArray remove(String str) {
        int i = 0;
        Iterator<NDArray> it = iterator();
        while (it.hasNext()) {
            NDArray next = it.next();
            if (str.equals(next.getName())) {
                remove(i);
                return next;
            }
            i++;
        }
        return null;
    }

    public boolean contains(String str) {
        Iterator<NDArray> it = iterator();
        while (it.hasNext()) {
            if (str.equals(it.next().getName())) {
                return true;
            }
        }
        return false;
    }

    public NDArray head() {
        return get(0);
    }

    public NDArray singletonOrThrow() {
        if (size() != 1) {
            throw new IndexOutOfBoundsException("Incorrect number of elements in NDList.singletonOrThrow: Expected 1 and was " + size());
        }
        return get(0);
    }

    public NDList addAll(NDList nDList) {
        Iterator<NDArray> it = nDList.iterator();
        while (it.hasNext()) {
            add(it.next());
        }
        return this;
    }

    public NDList subNDList(int i) {
        return subNDList(i, size());
    }

    public NDList subNDList(int i, int i2) {
        return new NDList(subList(i, i2));
    }

    public NDList toDevice(Device device, boolean z) {
        if (!z && stream().allMatch(nDArray -> {
            return nDArray.getDevice() == device;
        })) {
            return this;
        }
        NDList nDList = new NDList(size());
        forEach(nDArray2 -> {
            nDList.add(nDArray2.toDevice(device, z));
        });
        return nDList;
    }

    @Override // ai.djl.ndarray.NDResource
    public NDManager getManager() {
        return head().getManager();
    }

    @Override // ai.djl.ndarray.NDResource
    public List<NDArray> getResourceNDArrays() {
        return this;
    }

    @Override // ai.djl.ndarray.NDResource
    public void attach(NDManager nDManager) {
        forEach(nDArray -> {
            nDArray.attach(nDManager);
        });
    }

    @Override // ai.djl.ndarray.NDResource
    public void tempAttach(NDManager nDManager) {
        forEach(nDArray -> {
            nDArray.tempAttach(nDManager);
        });
    }

    @Override // ai.djl.ndarray.NDResource
    public void detach() {
        forEach((v0) -> {
            v0.detach();
        });
    }

    public byte[] encode() {
        return encode(Encoding.ND_LIST);
    }

    public byte[] encode(Encoding encoding) {
        try {
            ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
            try {
                encode(byteArrayOutputStream, encoding);
                byte[] byteArray = byteArrayOutputStream.toByteArray();
                byteArrayOutputStream.close();
                return byteArray;
            } finally {
            }
        } catch (IOException e) {
            throw new AssertionError("NDList is not writable", e);
        }
    }

    public void encode(OutputStream outputStream) throws IOException {
        encode(outputStream, Encoding.ND_LIST);
    }

    public void encode(OutputStream outputStream, Encoding encoding) throws IOException {
        if (encoding == Encoding.NPZ) {
            ZipOutputStream zipOutputStream = new ZipOutputStream(outputStream);
            int i = 0;
            Iterator<NDArray> it = iterator();
            while (it.hasNext()) {
                NDArray next = it.next();
                String name = next.getName();
                if (name == null) {
                    zipOutputStream.putNextEntry(new ZipEntry("arr_" + i + ".npy"));
                    i++;
                } else {
                    zipOutputStream.putNextEntry(new ZipEntry(name + ".npy"));
                }
                NDSerializer.encodeAsNumpy(next, zipOutputStream);
            }
            zipOutputStream.finish();
            zipOutputStream.flush();
            return;
        }
        if (encoding != Encoding.SAFETENSORS) {
            DataOutputStream dataOutputStream = new DataOutputStream(outputStream);
            dataOutputStream.writeInt(size());
            Iterator<NDArray> it2 = iterator();
            while (it2.hasNext()) {
                NDSerializer.encode(it2.next(), dataOutputStream);
            }
            dataOutputStream.flush();
            return;
        }
        ConcurrentHashMap concurrentHashMap = new ConcurrentHashMap(size());
        int i2 = 0;
        int i3 = 0;
        Iterator<NDArray> it3 = iterator();
        while (it3.hasNext()) {
            NDArray next2 = it3.next();
            String name2 = next2.getName();
            if (name2 == null) {
                name2 = "arr_" + i2;
                i2++;
            }
            SafeTensor safeTensor = new SafeTensor();
            safeTensor.dtype = next2.getDataType().asSafetensors();
            safeTensor.shape = next2.getShape().getShape();
            int intExact = i3 + Math.toIntExact(next2.getDataType().getNumOfBytes() * next2.size());
            safeTensor.offsets = new int[]{i3, intExact};
            concurrentHashMap.put(name2, safeTensor);
            i3 = intExact;
        }
        byte[] bytes = JsonUtils.GSON.toJson(concurrentHashMap).getBytes(StandardCharsets.UTF_8);
        ByteBuffer allocate = ByteBuffer.allocate(8);
        allocate.order(ByteOrder.LITTLE_ENDIAN);
        allocate.putLong(0, bytes.length);
        outputStream.write(allocate.array());
        outputStream.write(bytes);
        Iterator<NDArray> it4 = iterator();
        while (it4.hasNext()) {
            outputStream.write(it4.next().toByteArray());
        }
    }

    @Override // ai.djl.ndarray.BytesSupplier
    public byte[] getAsBytes() {
        return encode();
    }

    @Override // ai.djl.ndarray.BytesSupplier
    public ByteBuffer toByteBuffer() {
        return ByteBuffer.wrap(encode());
    }

    public Shape[] getShapes() {
        return (Shape[]) stream().map((v0) -> {
            return v0.getShape();
        }).toArray(i -> {
            return new Shape[i];
        });
    }

    @Override // ai.djl.ndarray.NDResource, java.lang.AutoCloseable
    public void close() {
        forEach((v0) -> {
            v0.close();
        });
        clear();
    }

    @Override // java.util.AbstractCollection
    public String toString() {
        StringBuilder sb = new StringBuilder(200);
        sb.append("NDList size: ").append(size()).append('\n');
        int i = 0;
        Iterator<NDArray> it = iterator();
        while (it.hasNext()) {
            NDArray next = it.next();
            String name = next.getName();
            int i2 = i;
            i++;
            sb.append(i2).append(' ');
            if (name != null) {
                sb.append(name);
            }
            sb.append(": ").append(next.getShape()).append(' ').append(next.getDataType()).append('\n');
        }
        return sb.toString();
    }
}
