package ai.djl.training.dataset;

import ai.djl.Device;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.Batchifier;
import ai.djl.translate.Pipeline;
import java.io.IOException;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.NoSuchElementException;
import java.util.Queue;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/djl/training/dataset/DataIterable.class */
/**
 * Iterates over a {@link RandomAccessDataset} one {@link Batch} at a time.
 *
 * <p>The index lists that make up each batch come from the {@link Sampler} passed to the
 * constructor. When an {@link ExecutorService} is supplied, batches are loaded ahead of time as
 * {@link Future}s kept in a FIFO queue; otherwise each batch is loaded synchronously inside
 * {@link #next()}.
 *
 * <p>The object is its own {@link Iterator} ({@link #iterator()} returns {@code this}), so it can
 * be traversed only once.
 *
 * <p>If the system property {@code ai.djl.dataiterator.autoclose} is {@code true} (the default),
 * the internal {@link NDManager} is closed by {@link #hasNext()} as soon as it reports that no
 * batch is left.
 */
public class DataIterable implements Iterable<Batch>, Iterator<Batch> {
    private static final Logger logger = LoggerFactory.getLogger(DataIterable.class);

    private final RandomAccessDataset dataset;
    /** Sub-manager of the caller's manager; parent of every per-batch sub-manager. */
    private final NDManager manager;
    private final Batchifier dataBatchifier;
    private final Batchifier labelBatchifier;
    /** Optional transform applied to the data of every record; may be {@code null}. */
    private final Pipeline pipeline;
    /** Optional transform applied to the batchified labels; may be {@code null}. */
    private final Pipeline targetPipeline;
    /** Executor used for pre-fetching; {@code null} selects synchronous loading. */
    private final ExecutorService executor;
    /** Optional device the batch is copied to; may be {@code null}. */
    private final Device device;
    private final Iterator<List<Long>> sample;
    /** Pending pre-fetched batches in submission order; {@code null} when there is no executor. */
    private final Queue<Future<Batch>> queue;
    /** Running count of the samples handed out so far, used as the batch progress value. */
    private final AtomicInteger progressCounter;
    private final boolean autoClose;

    /**
     * Loads the batch for a fixed list of sample indices by delegating to
     * {@link DataIterable#fetch(List, int)}.
     *
     * <p>The progress value is taken from the shared counter in the constructor, i.e. when the
     * task is created rather than when it runs.
     */
    public class PreFetchCallable implements Callable<Batch> {
        private final List<Long> indices;
        private final int progress;

        /**
         * Creates the task and reserves its progress value.
         *
         * @param list the dataset indices that make up the batch
         */
        public PreFetchCallable(List<Long> list) {
            this.indices = list;
            // Counter value before this batch's samples are added.
            this.progress = DataIterable.this.progressCounter.getAndAdd(list.size());
        }

        /**
         * Loads the batch.
         *
         * @return the loaded batch
         * @throws IOException if {@link DataIterable#fetch(List, int)} fails
         */
        @Override
        public Batch call() throws IOException {
            return DataIterable.this.fetch(this.indices, this.progress);
        }
    }

    /**
     * Creates the iterable and, if an executor is given, immediately submits up to {@code i}
     * pre-fetch tasks.
     *
     * @param randomAccessDataset the dataset to read records from
     * @param nDManager the parent manager; a sub-manager named {@code "dataIter"} is created from it
     * @param sampler produces the index lists, one per batch
     * @param batchifier combines the per-record data lists into one batch
     * @param batchifier2 combines the per-record label lists into one batch
     * @param pipeline optional transform for each record's data, or {@code null}
     * @param pipeline2 optional transform for the batchified labels, or {@code null}
     * @param executorService executor used for pre-fetching, or {@code null} for synchronous loading
     * @param i number of batches to submit for pre-fetching up front (ignored without an executor)
     * @param device optional device to copy each batch to, or {@code null}
     */
    public DataIterable(RandomAccessDataset randomAccessDataset, NDManager nDManager, Sampler sampler, Batchifier batchifier, Batchifier batchifier2, Pipeline pipeline, Pipeline pipeline2, ExecutorService executorService, int i, Device device) {
        this.dataset = randomAccessDataset;
        this.manager = nDManager.newSubManager();
        this.manager.setName("dataIter");
        this.dataBatchifier = batchifier;
        this.labelBatchifier = batchifier2;
        this.pipeline = pipeline;
        this.targetPipeline = pipeline2;
        this.executor = executorService;
        this.device = device;
        this.progressCounter = new AtomicInteger(0);
        this.autoClose = Boolean.parseBoolean(System.getProperty("ai.djl.dataiterator.autoclose", "true"));
        this.sample = sampler.sample(randomAccessDataset);
        if (executorService != null) {
            this.queue = new LinkedList<>();
            for (int submitted = 0; submitted < i; submitted++) {
                preFetch();
            }
        } else {
            this.queue = null;
        }
    }

    /**
     * Returns this object; the iterable and the iterator are the same instance.
     *
     * @return {@code this}
     */
    @Override
    public Iterator<Batch> iterator() {
        return this;
    }

    /**
     * Reports whether another batch is available.
     *
     * <p>With an executor, this checks the queue of pending pre-fetched batches; without one, it
     * asks the sampler. When no batch is left and auto-close is enabled, the internal manager is
     * closed as a side effect.
     *
     * @return {@code true} if {@link #next()} can return another batch
     */
    @Override
    public boolean hasNext() {
        boolean more = this.executor != null ? !this.queue.isEmpty() : this.sample.hasNext();
        if (!more && this.autoClose) {
            this.manager.close();
        }
        return more;
    }

    /**
     * Returns the next batch.
     *
     * <p>Without an executor the batch is loaded on the calling thread. With an executor, one more
     * pre-fetch task is submitted (if the sampler has indices left) and the oldest pending batch
     * is awaited.
     *
     * @return the next batch
     * @throws NoSuchElementException if pre-fetching is enabled and no pending batch is left
     * @throws IllegalStateException if loading the batch fails or the wait is interrupted
     */
    @Override
    public Batch next() {
        if (this.executor == null) {
            List<Long> indices = this.sample.next();
            try {
                // NOTE(review): this passes the counter value AFTER adding the batch size, whereas
                // PreFetchCallable records the value BEFORE adding it, so the two modes report
                // different progress for the same batch. Confirm which one is intended.
                return fetch(indices, this.progressCounter.addAndGet(indices.size()));
            } catch (IOException e) {
                logger.error(e.getMessage(), e);
                throw new IllegalStateException("Data loading failed", e);
            }
        }
        // Keep the pipeline full before blocking on the oldest pending batch.
        preFetch();
        Future<Batch> pending = this.queue.poll();
        if (pending == null) {
            // poll() returns null on an empty queue; honour the Iterator contract instead of
            // failing with a NullPointerException.
            throw new NoSuchElementException("No more batches available");
        }
        try {
            return pending.get();
        } catch (InterruptedException e) {
            // Restore the interrupt flag so code further up the stack can observe it.
            Thread.currentThread().interrupt();
            logger.error(e.getMessage(), e);
            throw new IllegalStateException("Data loading failed", e);
        } catch (ExecutionException e) {
            logger.error(e.getMessage(), e);
            throw new IllegalStateException("Data loading failed", e);
        }
    }

    /**
     * Loads the records at the given indices and assembles them into a {@link Batch}.
     *
     * <p>Each record is read through a fresh sub-manager named {@code "dataIter fetch"}, which is
     * handed to the returned batch. The data pipeline (if any) is applied per record before
     * batchifying; the target pipeline (if any) is applied to the batchified labels. If a device
     * is configured, both lists are copied to it. If anything fails, the sub-manager is closed
     * before the exception propagates.
     *
     * @param list the dataset indices that make up the batch
     * @param i the progress value to record on the batch
     * @return the assembled batch
     * @throws IOException if reading from the dataset or a transform fails with an I/O error
     */
    public Batch fetch(List<Long> list, int i) throws IOException {
        NDManager subManager = this.manager.newSubManager();
        try {
            subManager.setName("dataIter fetch");
            int batchSize = list.size();
            NDList[] data = new NDList[batchSize];
            NDList[] labels = new NDList[batchSize];
            for (int idx = 0; idx < batchSize; idx++) {
                Record record = this.dataset.get(subManager, list.get(idx).longValue());
                NDList recordData = record.getData();
                if (this.pipeline != null) {
                    recordData = this.pipeline.transform(recordData);
                }
                data[idx] = recordData;
                labels[idx] = record.getLabels();
            }
            NDList batchData = this.dataBatchifier.batchify(data);
            NDList batchLabels = this.labelBatchifier.batchify(labels);
            // The per-record lists are no longer needed once they have been batchified.
            Arrays.stream(data).forEach(NDList::close);
            Arrays.stream(labels).forEach(NDList::close);
            if (this.targetPipeline != null) {
                batchLabels = this.targetPipeline.transform(batchLabels);
            }
            if (this.device != null) {
                batchData = batchData.toDevice(this.device, false);
                batchLabels = batchLabels.toDevice(this.device, false);
            }
            return new Batch(subManager, batchData, batchLabels, batchSize, this.dataBatchifier, this.labelBatchifier, i, this.dataset.size());
        } catch (IOException | RuntimeException e) {
            // No Batch took over the sub-manager, so release it here.
            subManager.close();
            throw e;
        }
    }

    /** Submits a pre-fetch task for the next index list, if the sampler has one left. */
    private void preFetch() {
        if (this.sample.hasNext()) {
            this.queue.offer(this.executor.submit(new PreFetchCallable(this.sample.next())));
        }
    }
}
