package ai.djl.training.dataset;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.training.dataset.RandomAccessDataset;
import ai.djl.util.Progress;
import java.io.IOException;
import java.util.stream.Stream;

/* loaded from: input_file:ai/djl/training/dataset/ArrayDataset.class */
/**
 * A {@link RandomAccessDataset} backed by in-memory {@link NDArray}s.
 *
 * <p>Each entry of {@link #data} (and, if present, {@link #labels}) is indexed along its first
 * axis; the record at index {@code j} is formed by taking element {@code j} of every array.
 */
public class ArrayDataset extends RandomAccessDataset {
    /** The data arrays; every array must have the same size along axis 0. */
    protected NDArray[] data;
    /** The optional label arrays; may be {@code null}, otherwise same axis-0 size as the data. */
    protected NDArray[] labels;

    /** Builder for {@link ArrayDataset}. */
    public static final class Builder extends RandomAccessDataset.BaseBuilder<Builder> {
        private NDArray[] data;
        private NDArray[] labels;

        /** {@inheritDoc} */
        @Override // ai.djl.training.dataset.RandomAccessDataset.BaseBuilder
        public Builder self() {
            return this;
        }

        /**
         * Sets the data arrays of the dataset.
         *
         * @param nDArrayArr the data arrays; the array reference is stored as given (not copied)
         * @return this builder
         */
        public Builder setData(NDArray... nDArrayArr) {
            this.data = nDArrayArr;
            return self();
        }

        /**
         * Sets the optional label arrays of the dataset.
         *
         * @param nDArrayArr the label arrays; the array reference is stored as given (not copied)
         * @return this builder
         */
        public Builder optLabels(NDArray... nDArrayArr) {
            this.labels = nDArrayArr;
            return self();
        }

        /**
         * Builds the dataset.
         *
         * @return a new {@link ArrayDataset}
         * @throws IllegalArgumentException if no data was set, or the arrays differ in length
         */
        public ArrayDataset build() {
            if (this.data == null || this.data.length == 0) {
                throw new IllegalArgumentException("Please pass in at least one data");
            }
            return new ArrayDataset(this);
        }
    }

    /**
     * Creates a dataset from a builder.
     *
     * <p>When {@code baseBuilder} is an {@link ArrayDataset.Builder}, its data and labels are
     * copied into this dataset and validated. Other builder types leave {@link #data} and {@link
     * #labels} unset.
     *
     * @param baseBuilder the builder to read settings from
     * @throws IllegalArgumentException if an {@link ArrayDataset.Builder} has no data, or if any
     *     data or label array differs in axis-0 size from the first data array
     */
    public ArrayDataset(RandomAccessDataset.BaseBuilder<?> baseBuilder) {
        super(baseBuilder);
        if (baseBuilder instanceof Builder) {
            Builder builder = (Builder) baseBuilder;
            // The constructor is public, so validate here too rather than relying on build();
            // otherwise data[0] below fails with an NPE/ArrayIndexOutOfBoundsException.
            if (builder.data == null || builder.data.length == 0) {
                throw new IllegalArgumentException("Please pass in at least one data");
            }
            this.data = builder.data;
            this.labels = builder.labels;
            long size = this.data[0].size(0);
            requireSameLength(this.data, size);
            if (this.labels != null) {
                requireSameLength(this.labels, size);
            }
        }
    }

    /** Throws if any array's axis-0 size differs from {@code size}. */
    private static void requireSameLength(NDArray[] arrays, long size) {
        if (Stream.of(arrays).anyMatch(array -> array.size(0) != size)) {
            throw new IllegalArgumentException("All the NDArray must have the same length!");
        }
    }

    /** {@inheritDoc} Returns the axis-0 size of the first data array. */
    @Override // ai.djl.training.dataset.RandomAccessDataset
    protected long availableSize() {
        return this.data[0].size(0);
    }

    /**
     * Returns the record at index {@code j}.
     *
     * @param nDManager the manager that the resulting arrays are attached to
     * @param j the index along axis 0
     * @return a record holding element {@code j} of every data array and, if labels are
     *     present, of every label array (otherwise an empty label list)
     */
    @Override // ai.djl.training.dataset.RandomAccessDataset
    public Record get(NDManager nDManager, long j) {
        NDList datum = new NDList();
        NDList label = new NDList();
        for (NDArray array : this.data) {
            datum.add(array.get(j));
        }
        if (this.labels != null) {
            for (NDArray array : this.labels) {
                label.add(array.get(j));
            }
        }
        // Move the slices to the caller's manager so their lifetime follows the caller.
        datum.attach(nDManager);
        label.attach(nDManager);
        return new Record(datum, label);
    }

    /** No-op: the data is already in memory. */
    @Override // ai.djl.training.dataset.Dataset
    public void prepare(Progress progress) throws IOException {
    }
}
