package ai.djl.training.dataset;

import ai.djl.Device;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.Batchifier;

/* loaded from: input_file:ai/djl/training/dataset/Batch.class */
/**
 * A batch of input data and labels together with the {@link NDManager} that owns their arrays.
 *
 * <p>Both constructors attach {@code data} and {@code labels} to the supplied manager, and
 * {@link #close()} closes that manager. Batches returned by {@link #split(Device[], boolean)} are
 * constructed with this batch's manager, so they share it with the batch they came from.
 */
public class Batch implements AutoCloseable {

    /** Owner of the batch's arrays; set to {@code null} once the batch has been closed. */
    private NDManager manager;
    private final NDList data;
    private final NDList labels;
    private final Batchifier dataBatchifier;
    private final Batchifier labelBatchifier;
    private final int size;
    private final long progress;
    private final long progressTotal;

    /**
     * Creates a batch without progress information ({@link #getProgress()} and {@link
     * #getProgressTotal()} both return 0).
     *
     * @param manager the manager that {@code data} and {@code labels} are attached to
     * @param data the input data of the batch
     * @param labels the labels of the batch
     * @param size the number of examples in the batch
     * @param dataBatchifier the batchifier used to split {@code data}; may be {@code null}, in
     *     which case {@link #split(Device[], boolean)} across several devices fails
     * @param labelBatchifier the batchifier used to split {@code labels}; same nullability as
     *     {@code dataBatchifier}
     */
    public Batch(
            NDManager manager,
            NDList data,
            NDList labels,
            int size,
            Batchifier dataBatchifier,
            Batchifier labelBatchifier) {
        this(manager, data, labels, size, dataBatchifier, labelBatchifier, 0L, 0L);
    }

    /**
     * Creates a batch with progress information.
     *
     * @param manager the manager that {@code data} and {@code labels} are attached to
     * @param data the input data of the batch
     * @param labels the labels of the batch
     * @param size the number of examples in the batch
     * @param dataBatchifier the batchifier used to split {@code data}; may be {@code null}
     * @param labelBatchifier the batchifier used to split {@code labels}; may be {@code null}
     * @param progress the progress value reported by {@link #getProgress()}
     * @param progressTotal the total reported by {@link #getProgressTotal()}
     */
    public Batch(
            NDManager manager,
            NDList data,
            NDList labels,
            int size,
            Batchifier dataBatchifier,
            Batchifier labelBatchifier,
            long progress,
            long progressTotal) {
        this.manager = manager;
        // Attach the arrays so that closing this batch's manager also releases them.
        data.attach(manager);
        labels.attach(manager);
        this.data = data;
        this.labels = labels;
        this.size = size;
        this.dataBatchifier = dataBatchifier;
        this.labelBatchifier = labelBatchifier;
        this.progress = progress;
        this.progressTotal = progressTotal;
    }

    /**
     * Returns the manager that owns this batch's arrays.
     *
     * @return the manager, or {@code null} if this batch has already been closed
     */
    public NDManager getManager() {
        return manager;
    }

    /**
     * Returns the input data of this batch.
     *
     * @return the data list
     */
    public NDList getData() {
        return data;
    }

    /**
     * Returns the labels of this batch.
     *
     * @return the label list
     */
    public NDList getLabels() {
        return labels;
    }

    /**
     * Returns the number of examples in this batch.
     *
     * @return the batch size
     */
    public int getSize() {
        return size;
    }

    /**
     * Returns the progress value this batch was created with (0 if none was given).
     *
     * @return the progress value
     */
    public long getProgress() {
        return progress;
    }

    /**
     * Returns the progress total this batch was created with (0 if none was given).
     *
     * @return the progress total
     */
    public long getProgressTotal() {
        return progressTotal;
    }

    /**
     * Closes this batch's manager and forgets it. Calling this method again is a no-op.
     *
     * <p>Batches obtained from {@link #split(Device[], boolean)} hold the same manager, so closing
     * any one of them closes that shared manager.
     */
    @Override // java.lang.AutoCloseable
    public void close() {
        if (manager == null) {
            return; // already closed; previously this threw a NullPointerException
        }
        manager.close();
        manager = null;
    }

    /**
     * Splits this batch into one sub-batch per slice, moving each slice to its device.
     *
     * <p>With a single device, the batch is re-wrapped as-is if its first data array is already on
     * that device, otherwise its data and labels are moved there. With several devices, the data and
     * labels are split by their respective batchifiers and slice {@code i} is moved to {@code
     * devices[i]}.
     *
     * <p>Sub-batch sizes are derived arithmetically: {@code size / sliceCount} for every slice, with
     * the remainder added to the last one. NOTE(review): this matches the real slice contents only if
     * the batchifier distributes examples that way — confirm against the batchifier in use.
     *
     * @param devices the target devices; must not be empty
     * @param evenSplit passed through to {@link Batchifier#split}
     * @return the sub-batches, all sharing this batch's manager
     * @throws IllegalArgumentException if {@code devices} is empty
     * @throws IllegalStateException if a batchifier is missing, or if data and labels are split into
     *     a different number of slices
     */
    public Batch[] split(Device[] devices, boolean evenSplit) {
        int deviceCount = devices.length;
        if (deviceCount == 0) {
            throw new IllegalArgumentException("At least one device is required to split a batch");
        }
        if (deviceCount == 1) {
            Device device = devices[0];
            if (data.head().getDevice().equals(device)) {
                // Already on the target device: reuse the arrays without copying.
                return new Batch[] {
                    new Batch(
                            manager, data, labels, size, dataBatchifier, labelBatchifier,
                            progress, progressTotal)
                };
            }
            return new Batch[] {
                new Batch(
                        manager,
                        data.asInDevice(device, true),
                        labels.asInDevice(device, true),
                        size,
                        dataBatchifier,
                        labelBatchifier,
                        progress,
                        progressTotal)
            };
        }

        NDList[] dataSlices = split(data, dataBatchifier, deviceCount, evenSplit);
        NDList[] labelSlices = split(labels, labelBatchifier, deviceCount, evenSplit);
        int sliceCount = dataSlices.length;
        if (labelSlices.length != sliceCount) {
            throw new IllegalStateException(
                    "Data was split into " + sliceCount + " parts but labels into "
                            + labelSlices.length + " parts");
        }
        if (sliceCount == 0) {
            return new Batch[0];
        }

        // Divide by the number of slices actually produced, not the device count: a batchifier
        // may return fewer slices than requested (e.g. fewer examples than devices).
        int baseSize = size / sliceCount;
        Batch[] batches = new Batch[sliceCount];
        for (int i = 0; i < sliceCount; i++) {
            // The last slice absorbs the remainder so that the sub-sizes always sum to size.
            int subSize = (i == sliceCount - 1) ? size - i * baseSize : baseSize;
            batches[i] =
                    new Batch(
                            manager,
                            dataSlices[i].asInDevice(devices[i], true),
                            labelSlices[i].asInDevice(devices[i], true),
                            subSize,
                            dataBatchifier,
                            labelBatchifier,
                            progress,
                            progressTotal);
        }
        return batches;
    }

    /**
     * Splits {@code list} into {@code numOfSlices} parts with the given batchifier.
     *
     * @throws IllegalStateException if {@code batchifier} is {@code null}
     */
    private NDList[] split(NDList list, Batchifier batchifier, int numOfSlices, boolean evenSplit) {
        if (batchifier == null) {
            throw new IllegalStateException(
                    "Split can only be called on a batch containing a batchifier");
        }
        return batchifier.split(list, numOfSlices, evenSplit);
    }
}
