/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.datasets.fetchers;

import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.Random;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.base.MnistFetcher;
import org.deeplearning4j.datasets.fetchers.BaseDataFetcher;
import org.deeplearning4j.datasets.mnist.MnistManager;
import org.deeplearning4j.util.MathUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

/**
 * Data fetcher that serves MNIST examples as {@link DataSet} batches.
 *
 * <p>The four MNIST files are expected under {@link #MNIST_ROOT} (a {@code MNIST}
 * directory inside the user's home directory). If any of them is missing they are
 * downloaded through {@link MnistFetcher} when the fetcher is constructed.
 *
 * <p>Examples are visited in the sequence stored in {@link #order}; when shuffling is
 * enabled that index array is passed to {@code MathUtils.shuffleArray} on every
 * {@link #reset()}.
 */
public class MnistDataFetcher extends BaseDataFetcher {
    /** Number of examples served when the training set is selected. */
    public static final int NUM_EXAMPLES = 60000;
    /** Number of examples served when the test set is selected. */
    public static final int NUM_EXAMPLES_TEST = 10000;
    /** Base directory for the local MNIST copy: the {@code user.home} system property. */
    protected static final String TEMP_ROOT = System.getProperty("user.home");
    /** Directory (with trailing separator) that holds the MNIST files. */
    protected static final String MNIST_ROOT = TEMP_ROOT + File.separator + "MNIST" + File.separator;

    private static final String TRAIN_IMAGES_FILE = "images-idx3-ubyte";
    private static final String TRAIN_LABELS_FILE = "labels-idx1-ubyte";
    private static final String TEST_IMAGES_FILE = "t10k-images-idx3-ubyte";
    private static final String TEST_LABELS_FILE = "t10k-labels-idx1-ubyte";
    /** Every file that must be present for the local copy to count as complete. */
    private static final String[] REQUIRED_FILES = {
        TRAIN_IMAGES_FILE, TRAIN_LABELS_FILE, TEST_IMAGES_FILE, TEST_LABELS_FILE
    };
    /** Width of the one-hot label vector (digits 0-9). */
    private static final int NUM_LABELS = 10;
    /** Raw pixel values strictly above this become 1 when binarizing; all others become 0. */
    private static final float BINARIZE_THRESHOLD = 30.0f;
    /** Divisor used to scale an unsigned byte pixel into [0, 1] when not binarizing. */
    private static final float MAX_PIXEL_VALUE = 255.0f;

    /** Reader for the image and label files. */
    protected transient MnistManager man;
    /** Whether pixels are thresholded to 0/1 instead of being scaled by 1/255. */
    protected boolean binarize = true;
    /** Whether the training files (as opposed to the test files) were selected. */
    protected boolean train;
    /** Example indices in the sequence they will be served. */
    protected int[] order;
    /** Random source handed to the shuffle in {@link #reset()}. */
    protected Random rng;
    /** Whether {@link #order} is shuffled on every {@link #reset()}. */
    protected boolean shuffle;

    /**
     * Creates a shuffling fetcher over the training set, seeded from the current time.
     *
     * @param binarize whether pixels are thresholded to 0/1
     * @throws IOException if the data cannot be downloaded or opened
     */
    public MnistDataFetcher(boolean binarize) throws IOException {
        this(binarize, true, true, System.currentTimeMillis());
    }

    /**
     * Creates a fetcher, downloading MNIST first if the local copy is incomplete.
     *
     * <p>If the files cannot be opened, the whole {@link #MNIST_ROOT} directory is deleted,
     * the data is downloaded again and the open is attempted once more.
     *
     * @param binarize whether pixels are thresholded to 0/1 instead of scaled by 1/255
     * @param train    {@code true} for the training files, {@code false} for the test files
     * @param shuffle  whether the serving order is shuffled on every {@link #reset()}
     * @param rngSeed  seed for the random source used when shuffling
     * @throws IOException if downloading or opening the data fails, including on the retry
     */
    public MnistDataFetcher(boolean binarize, boolean train, boolean shuffle, long rngSeed) throws IOException {
        if (!this.mnistExists()) {
            new MnistFetcher().downloadAndUntar();
        }
        String images = MNIST_ROOT + (train ? TRAIN_IMAGES_FILE : TEST_IMAGES_FILE);
        String labels = MNIST_ROOT + (train ? TRAIN_LABELS_FILE : TEST_LABELS_FILE);
        int exampleCount = train ? NUM_EXAMPLES : NUM_EXAMPLES_TEST;
        this.totalExamples = exampleCount;
        this.man = openManager(images, labels, train);
        this.numOutcomes = NUM_LABELS;
        this.binarize = binarize;
        this.cursor = 0;
        this.inputColumns = this.man.getImages().getEntryLength();
        this.train = train;
        this.shuffle = shuffle;
        // Start from the identity ordering; reset() shuffles it when requested.
        this.order = new int[exampleCount];
        for (int i = 0; i < this.order.length; ++i) {
            this.order[i] = i;
        }
        this.rng = new Random(rngSeed);
        // NOTE(review): reset() is overridable; kept here because subclasses may rely on it.
        this.reset();
    }

    /**
     * Creates a binarizing, shuffling fetcher over the training set.
     *
     * @throws IOException if the data cannot be downloaded or opened
     */
    public MnistDataFetcher() throws IOException {
        this(true);
    }

    /**
     * Opens the MNIST files, re-downloading the data once if the first attempt fails.
     *
     * @throws IOException if the retry fails; the first failure is attached as suppressed
     */
    private static MnistManager openManager(String images, String labels, boolean train) throws IOException {
        try {
            return new MnistManager(images, labels, train);
        }
        catch (Exception firstFailure) {
            // Best-effort recovery: treat the local copy as unusable, wipe it and fetch again.
            try {
                FileUtils.deleteDirectory(new File(MNIST_ROOT));
                new MnistFetcher().downloadAndUntar();
                return new MnistManager(images, labels, train);
            }
            catch (IOException | RuntimeException retryFailure) {
                // Keep the original cause visible instead of silently dropping it.
                retryFailure.addSuppressed(firstFailure);
                throw retryFailure;
            }
        }
    }

    /** Returns {@code true} only when every required MNIST file exists under {@link #MNIST_ROOT}. */
    private boolean mnistExists() {
        for (String fileName : REQUIRED_FILES) {
            if (!new File(MNIST_ROOT, fileName).exists()) {
                return false;
            }
        }
        return true;
    }

    /**
     * Loads up to {@code numExamples} examples, starting at the current cursor, into the
     * current data set. Fewer examples are loaded when {@code hasMore()} turns false first.
     *
     * <p>Each label becomes a one-hot row of width 10; each image becomes a row of floats,
     * either thresholded to 0/1 or scaled by 1/255 depending on {@link #binarize}.
     *
     * @param numExamples maximum number of examples to load; must not be negative
     * @throws IllegalStateException    if there are no more examples to serve
     * @throws IllegalArgumentException if {@code numExamples} is negative
     */
    public void fetch(int numExamples) {
        if (!this.hasMore()) {
            throw new IllegalStateException("Unable to getFromOrigin more; there are no more images");
        }
        if (numExamples < 0) {
            throw new IllegalArgumentException("numExamples must not be negative, got " + numExamples);
        }
        float[][] featureData = new float[numExamples][];
        float[][] labelData = new float[numExamples][];
        int loaded = 0;
        while (loaded < numExamples && this.hasMore()) {
            int exampleIndex = this.order[this.cursor];
            byte[] img = this.man.readImageUnsafe(exampleIndex);
            int label = this.man.readLabel(exampleIndex);
            featureData[loaded] = this.toFeatureVector(img);
            float[] oneHot = new float[NUM_LABELS];
            oneHot[label] = 1.0f;
            labelData[loaded] = oneHot;
            ++loaded;
            ++this.cursor;
        }
        if (loaded < numExamples) {
            // Ran out of examples early: drop the unused (null) trailing rows.
            featureData = Arrays.copyOf(featureData, loaded);
            labelData = Arrays.copyOf(labelData, loaded);
        }
        INDArray features = Nd4j.create(featureData);
        INDArray labels = Nd4j.create(labelData);
        this.curr = new DataSet(features, labels);
    }

    /**
     * Converts raw image bytes to floats. Bytes are read as unsigned values (0-255) and
     * either thresholded to 0/1 or divided by 255, depending on {@link #binarize}.
     */
    private float[] toFeatureVector(byte[] img) {
        float[] featureVec = new float[img.length];
        for (int j = 0; j < img.length; ++j) {
            float pixel = img[j] & 0xFF;
            if (this.binarize) {
                featureVec[j] = pixel > BINARIZE_THRESHOLD ? 1.0f : 0.0f;
            } else {
                featureVec[j] = pixel / MAX_PIXEL_VALUE;
            }
        }
        return featureVec;
    }

    /**
     * Rewinds the cursor to the first example, clears the current data set and, when
     * shuffling is enabled, shuffles the serving order using this fetcher's random source.
     */
    public void reset() {
        this.cursor = 0;
        this.curr = null;
        if (this.shuffle) {
            MathUtils.shuffleArray(this.order, this.rng);
        }
    }

    /**
     * Returns whatever the superclass {@code next()} returns; no extra processing is applied.
     */
    public DataSet next() {
        return super.next();
    }
}

