/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.spark.iterator;

import java.util.UUID;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import lombok.NonNull;
import org.apache.spark.TaskContext;
import org.apache.spark.TaskContextHelper;
import org.deeplearning4j.datasets.iterator.AsyncMultiDataSetIterator;
import org.deeplearning4j.datasets.iterator.callbacks.DataSetCallback;
import org.deeplearning4j.datasets.iterator.callbacks.DefaultCallback;
import org.nd4j.linalg.dataset.api.MultiDataSet;
import org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * An {@link AsyncMultiDataSetIterator} variant intended for use inside Spark tasks.
 *
 * <p>On construction it captures the Spark {@link TaskContext} returned by
 * {@link TaskContext#get()} on the constructing thread. It later re-installs that context through
 * {@link #externalCall()}. The background prefetch thread is created, bound to the requested
 * device via the Nd4j affinity manager, marked as a daemon and started before the constructor
 * returns.
 *
 * <p>Requested queue sizes below {@value #MIN_QUEUE_SIZE} are raised to {@value #MIN_QUEUE_SIZE}.
 * This applies both to the buffer the convenience constructors allocate and to the prefetch size.
 */
public class SparkAMDSI
extends AsyncMultiDataSetIterator {
    private static final Logger log = LoggerFactory.getLogger(SparkAMDSI.class);

    /** Smallest queue size used for prefetching; smaller requests are raised to this value. */
    private static final int MIN_QUEUE_SIZE = 2;

    /**
     * Spark task context captured at construction time and re-installed by {@link #externalCall()}.
     * NOTE(review): presumably null when the iterator is built outside a running Spark task.
     */
    protected TaskContext context;

    /** Performs no initialization of its own; used by the full constructor and by subclasses. */
    protected SparkAMDSI() {
    }

    /**
     * Wraps {@code baseIterator} with a queue size of 8.
     *
     * @param baseIterator iterator to prefetch from
     */
    public SparkAMDSI(MultiDataSetIterator baseIterator) {
        this(baseIterator, 8);
    }

    /**
     * Wraps {@code iterator} using the caller-supplied buffer, with {@code useWorkspace = true}.
     *
     * @param iterator  iterator to prefetch from
     * @param queueSize prefetch size (raised to at least {@value #MIN_QUEUE_SIZE})
     * @param queue     buffer the prefetch thread fills; its capacity is the caller's choice
     */
    public SparkAMDSI(MultiDataSetIterator iterator, int queueSize, BlockingQueue<MultiDataSet> queue) {
        this(iterator, queueSize, queue, true);
    }

    /**
     * Wraps {@code baseIterator} with a freshly allocated buffer.
     *
     * @param baseIterator iterator to prefetch from
     * @param queueSize    buffer capacity and prefetch size (raised to at least {@value #MIN_QUEUE_SIZE})
     */
    public SparkAMDSI(MultiDataSetIterator baseIterator, int queueSize) {
        this(baseIterator, queueSize, newBuffer(queueSize));
    }

    /**
     * Wraps {@code baseIterator} with a freshly allocated buffer.
     *
     * @param baseIterator iterator to prefetch from
     * @param queueSize    buffer capacity and prefetch size (raised to at least {@value #MIN_QUEUE_SIZE})
     * @param useWorkspace whether the base iterator should use workspaces
     */
    public SparkAMDSI(MultiDataSetIterator baseIterator, int queueSize, boolean useWorkspace) {
        this(baseIterator, queueSize, newBuffer(queueSize), useWorkspace);
    }

    /**
     * Wraps {@code baseIterator} with a freshly allocated buffer and a {@link DefaultCallback},
     * binding the prefetch thread to {@code deviceId}.
     *
     * @param baseIterator iterator to prefetch from
     * @param queueSize    buffer capacity and prefetch size (raised to at least {@value #MIN_QUEUE_SIZE})
     * @param useWorkspace whether the base iterator should use workspaces
     * @param deviceId     device the prefetch thread is attached to
     */
    public SparkAMDSI(MultiDataSetIterator baseIterator, int queueSize, boolean useWorkspace, Integer deviceId) {
        this(baseIterator, queueSize, newBuffer(queueSize), useWorkspace, new DefaultCallback(), deviceId);
    }

    /**
     * Wraps {@code baseIterator} with a freshly allocated buffer and the given callback.
     *
     * @param baseIterator iterator to prefetch from
     * @param queueSize    buffer capacity and prefetch size (raised to at least {@value #MIN_QUEUE_SIZE})
     * @param useWorkspace whether the base iterator should use workspaces
     * @param callback     callback stored on the base iterator
     */
    public SparkAMDSI(MultiDataSetIterator baseIterator, int queueSize, boolean useWorkspace, DataSetCallback callback) {
        this(baseIterator, queueSize, newBuffer(queueSize), useWorkspace, callback);
    }

    /**
     * Wraps {@code iterator} using the caller-supplied buffer and no callback.
     *
     * <p>NOTE(review): this passes a {@code null} callback, unlike the {@code deviceId} overload,
     * which supplies a {@link DefaultCallback}. It assumes the base class tolerates a null
     * callback; confirm against {@code AsyncMultiDataSetIterator}.
     *
     * @param iterator     iterator to prefetch from
     * @param queueSize    prefetch size (raised to at least {@value #MIN_QUEUE_SIZE})
     * @param queue        buffer the prefetch thread fills
     * @param useWorkspace whether the base iterator should use workspaces
     */
    public SparkAMDSI(MultiDataSetIterator iterator, int queueSize, BlockingQueue<MultiDataSet> queue, boolean useWorkspace) {
        this(iterator, queueSize, queue, useWorkspace, null);
    }

    /**
     * Wraps {@code iterator} using the caller-supplied buffer, binding the prefetch thread to the
     * device currently associated with the calling thread.
     *
     * @param iterator     iterator to prefetch from
     * @param queueSize    prefetch size (raised to at least {@value #MIN_QUEUE_SIZE})
     * @param queue        buffer the prefetch thread fills
     * @param useWorkspace whether the base iterator should use workspaces
     * @param callback     callback stored on the base iterator
     */
    public SparkAMDSI(MultiDataSetIterator iterator, int queueSize, BlockingQueue<MultiDataSet> queue, boolean useWorkspace, DataSetCallback callback) {
        this(iterator, queueSize, queue, useWorkspace, callback, Nd4j.getAffinityManager().getDeviceForCurrentThread());
    }

    /**
     * Full constructor.
     *
     * <p>Configures the base iterator state and resets {@code iterator} if it supports reset. It
     * then captures the current Spark {@link TaskContext} and starts the daemon prefetch thread
     * attached to {@code deviceId}.
     *
     * <p>NOTE(review): the prefetch thread is an inner class holding {@code SparkAMDSI.this} and is
     * started from this constructor. Fields added by a subclass will not yet be initialized when
     * the thread begins running.
     *
     * @param iterator     iterator to prefetch from
     * @param queueSize    prefetch size (raised to at least {@value #MIN_QUEUE_SIZE})
     * @param queue        buffer the prefetch thread fills
     * @param useWorkspace whether the base iterator should use workspaces
     * @param callback     callback stored on the base iterator (may be null, see the 4-arg overload)
     * @param deviceId     device the prefetch thread is attached to
     */
    public SparkAMDSI(MultiDataSetIterator iterator, int queueSize, BlockingQueue<MultiDataSet> queue, boolean useWorkspace, DataSetCallback callback, Integer deviceId) {
        this();
        queueSize = effectiveQueueSize(queueSize);
        this.callback = callback;
        this.buffer = queue;
        this.backedIterator = iterator;
        this.useWorkspaces = useWorkspace;
        this.prefetchSize = queueSize;
        // The random suffix gives every instance its own workspace id.
        this.workspaceId = "SAMDSI_ITER-" + UUID.randomUUID().toString();
        this.deviceId = deviceId;
        // Begin from the start of the data when the underlying iterator allows it.
        if (iterator.resetSupported()) {
            this.backedIterator.reset();
        }
        this.thread = new SparkPrefetchThread(this.buffer, iterator, this.terminator);
        // Written before start(); Thread.start() makes this write visible to the new thread.
        this.context = TaskContext.get();
        Nd4j.getAffinityManager().attachThreadToDevice(this.thread, deviceId);
        this.thread.setDaemon(true);
        this.thread.start();
    }

    /**
     * Raises {@code queueSize} to {@link #MIN_QUEUE_SIZE} if it is smaller.
     *
     * <p>This keeps the internally allocated buffer valid: {@link LinkedBlockingQueue} rejects a
     * capacity of zero or less. It also keeps the buffer capacity consistent with the clamped
     * prefetch size.
     */
    private static int effectiveQueueSize(int queueSize) {
        return Math.max(MIN_QUEUE_SIZE, queueSize);
    }

    /** Allocates the bounded prefetch buffer used by the convenience constructors. */
    private static BlockingQueue<MultiDataSet> newBuffer(int queueSize) {
        return new LinkedBlockingQueue<MultiDataSet>(effectiveQueueSize(queueSize));
    }

    /**
     * Installs the {@link TaskContext} captured at construction time on the current thread, via
     * {@link TaskContextHelper}.
     *
     * <p>NOTE(review): presumably invoked by the base class on the prefetch thread, so that
     * Spark-dependent code running there sees the originating task's context. Confirm in
     * {@code AsyncMultiDataSetIterator}.
     */
    protected void externalCall() {
        TaskContextHelper.setTaskContext(this.context);
    }

    /**
     * Prefetch thread used by {@link SparkAMDSI}. It adds non-null validation of its arguments on
     * top of the base {@code AsyncPrefetchThread}.
     */
    protected class SparkPrefetchThread
    extends AsyncMultiDataSetIterator.AsyncPrefetchThread {
        /**
         * @param queue      buffer to fill; must not be null
         * @param iterator   source iterator; must not be null
         * @param terminator marker element passed to the base thread; must not be null
         * @throws NullPointerException if any argument is null
         */
        protected SparkPrefetchThread(@NonNull BlockingQueue<MultiDataSet> queue, @NonNull MultiDataSetIterator iterator, MultiDataSet terminator) {
            // NOTE(review): this is the decompiler's rendering and passes the enclosing instance
            // explicitly. If AsyncPrefetchThread is a non-static inner class, the source form is
            // super(queue, iterator, terminator). Confirm against the base class.
            super((AsyncMultiDataSetIterator)SparkAMDSI.this, queue, iterator, terminator);
            // Java requires super(...) to come first, so the null checks run after the base
            // constructor.
            if (queue == null) {
                throw new NullPointerException("queue");
            }
            if (iterator == null) {
                throw new NullPointerException("iterator");
            }
            if (terminator == null) {
                throw new NullPointerException("terminator");
            }
        }
    }
}

