package ai.djl.training.dataset;

import ai.djl.Device;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.translate.Batchifier;

/* loaded from: input_file:ai/djl/training/dataset/Batch.class */
/**
 * A mini-batch consisting of input data, the matching labels, and the {@link NDManager} both lists
 * are attached to.
 *
 * <p>Both lists are attached to the supplied manager at construction time, and {@link #close()}
 * closes that manager. Batches returned by {@link #split(Device[], boolean)} are built on the same
 * manager as this batch, so closing any one of them closes the shared manager.
 */
public class Batch implements AutoCloseable {
    /** Manager the data and labels are attached to; {@code null} once this batch is closed. */
    private NDManager manager;

    private final NDList data;
    private final NDList labels;

    /** Used to slice data/labels in {@link #split(Device[], boolean)}; may be {@code null}. */
    private final Batchifier batchifier;

    /**
     * Creates a batch without a batchifier. Such a batch can only be "split" onto a single device.
     *
     * @param manager the manager to attach {@code data} and {@code labels} to
     * @param data the input data of the batch
     * @param labels the labels of the batch
     */
    public Batch(NDManager manager, NDList data, NDList labels) {
        this(manager, data, labels, null);
    }

    /**
     * Creates a batch.
     *
     * @param manager the manager to attach {@code data} and {@code labels} to
     * @param data the input data of the batch
     * @param labels the labels of the batch
     * @param batchifier used by {@link #split(Device[], boolean)} when splitting across more than
     *     one device; may be {@code null}
     */
    public Batch(NDManager manager, NDList data, NDList labels, Batchifier batchifier) {
        this.manager = manager;
        data.attach(manager);
        labels.attach(manager);
        this.data = data;
        this.labels = labels;
        this.batchifier = batchifier;
    }

    /**
     * Returns the manager backing this batch.
     *
     * @return the manager, or {@code null} if this batch has been closed
     */
    public NDManager getManager() {
        return this.manager;
    }

    /**
     * Returns the input data of this batch.
     *
     * @return the data list
     */
    public NDList getData() {
        return this.data;
    }

    /**
     * Returns the labels of this batch.
     *
     * @return the label list
     */
    public NDList getLabels() {
        return this.labels;
    }

    /**
     * Closes the manager backing this batch. Calling this more than once is a no-op.
     */
    @Override // java.lang.AutoCloseable
    public void close() {
        // The field is nulled after closing; without this guard a second close() would NPE.
        if (this.manager == null) {
            return;
        }
        this.manager.close();
        this.manager = null;
    }

    /**
     * Splits this batch into one batch per device.
     *
     * <p>With a single device, no batchifier is needed. If the first data array already lives on
     * that device, a batch wrapping the same lists is returned. Otherwise data and labels are
     * copied there with {@code asInDevice(device, true)}. With several devices, data and labels are
     * sliced by the batchifier and each slice is copied to the device at the same index.
     *
     * <p>Every returned batch shares this batch's manager.
     *
     * @param devices the target devices
     * @param evenSplit forwarded unchanged to {@link Batchifier#split}
     * @return the per-device batches
     * @throws IllegalStateException if more than one device is given and this batch has no
     *     batchifier
     */
    public Batch[] split(Device[] devices, boolean evenSplit) {
        int deviceCount = devices.length;
        if (deviceCount == 1) {
            Device target = devices[0];
            // NOTE(review): only the first data array's device is inspected; labels are assumed
            // to be co-located with the data.
            if (this.data.head().getDevice().equals(target)) {
                return new Batch[] {new Batch(this.manager, this.data, this.labels, this.batchifier)};
            }
            NDList movedData = this.data.asInDevice(target, true);
            NDList movedLabels = this.labels.asInDevice(target, true);
            return new Batch[] {new Batch(this.manager, movedData, movedLabels, this.batchifier)};
        }

        NDList[] dataSlices = splitList(this.data, deviceCount, evenSplit);
        NDList[] labelSlices = splitList(this.labels, deviceCount, evenSplit);
        Batch[] result = new Batch[dataSlices.length];
        for (int i = 0; i < dataSlices.length; i++) {
            Device target = devices[i];
            result[i] =
                    new Batch(
                            this.manager,
                            dataSlices[i].asInDevice(target, true),
                            labelSlices[i].asInDevice(target, true),
                            this.batchifier);
        }
        return result;
    }

    /** Delegates to the batchifier to cut {@code list} into {@code slices} parts. */
    private NDList[] splitList(NDList list, int slices, boolean evenSplit) {
        if (this.batchifier == null) {
            throw new IllegalStateException(
                    "Split can only be called on a batch containing a batchifier");
        }
        return this.batchifier.split(list, slices, evenSplit);
    }
}
