/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.base;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.net.URL;
import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.io.filefilter.DirectoryFileFilter;
import org.apache.commons.io.filefilter.FileFileFilter;
import org.apache.commons.io.filefilter.IOFileFilter;
import org.deeplearning4j.util.ArchiveUtils;
import org.deeplearning4j.util.ImageLoader;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.linalg.util.FeatureUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Downloads (when missing) the LFW image archive from {@link #LFW_URL} into
 * {@code <user.home>/lfw} and turns the images found there into {@link DataSet}s.
 *
 * <p>The loader treats every sub-directory of the LFW directory as one label and
 * every entry inside such a sub-directory as an image carrying that label. Labels
 * are the indices of the sub-directories in sorted path order, so they are stable
 * between calls.
 *
 * <p>{@link #getIfNotExists()} must be called before any method that depends on
 * {@link #getNumNames()} or {@link #getNumPixelColumns()}; until then both are 0.
 * Instances hold mutable state and perform no synchronization.
 */
public class LFWLoader {
    public static final String LFW = "lfw";
    public static final String LFW_URL = "http://vis-www.cs.umass.edu/lfw/lfw.tgz";

    private static final Logger log = LoggerFactory.getLogger(LFWLoader.class);

    /** Upper bound on download attempts so a persistently bad download cannot loop forever. */
    private static final int MAX_DOWNLOAD_ATTEMPTS = 3;

    private final File baseDir = new File(System.getProperty("user.home"));
    private final File lfwDir = new File(this.baseDir, LFW);
    private final File lfwTarFile = new File(this.lfwDir, "lfw.tgz");
    private final ImageLoader loader;
    /** Absolute paths of all image files, filled by {@link #getIfNotExists()}. */
    private final List<String> images = new ArrayList<String>();
    /** Absolute paths of the label directories; the list index is the label. */
    private final List<String> outcomes = new ArrayList<String>();
    private int numNames;
    private int numPixelColumns;

    /** Creates a loader whose {@link ImageLoader} is built with width 28 and height 28. */
    public LFWLoader() {
        this(28, 28);
    }

    /**
     * Creates a loader for the given image dimensions.
     *
     * @param imageWidth  width passed to the underlying {@link ImageLoader}
     * @param imageHeight height passed to the underlying {@link ImageLoader}
     */
    public LFWLoader(int imageWidth, int imageHeight) {
        this.loader = new ImageLoader(imageWidth, imageHeight);
    }

    /**
     * Ensures the LFW directory exists and contains at least one image, downloading
     * and extracting the archive when necessary, then (re)builds the image index,
     * the label index, the label count and the per-image pixel count.
     *
     * <p>If no image can be found the directory is deleted and the download is
     * retried, at most {@value #MAX_DOWNLOAD_ATTEMPTS} times in total. Calling this
     * method repeatedly is safe: the indices are rebuilt rather than appended to.
     *
     * @throws IOException if the download, extraction or directory handling fails,
     *                     or no image is available after all attempts
     * @throws Exception   if reading the first image fails
     */
    public void getIfNotExists() throws Exception {
        File firstImage = null;
        for (int attempt = 1; attempt <= MAX_DOWNLOAD_ATTEMPTS && firstImage == null; ++attempt) {
            if (!this.lfwDir.exists()) {
                this.downloadAndExtract();
            }
            firstImage = this.findFirstImage();
            if (firstImage == null) {
                // Remove the unusable directory so the next attempt (or next call) downloads afresh.
                FileUtils.deleteDirectory(this.lfwDir);
                log.warn("Error opening first image; probably corrupt download (attempt {} of {})",
                        attempt, MAX_DOWNLOAD_ATTEMPTS);
            }
        }
        if (firstImage == null) {
            throw new IOException("No usable image found under " + this.lfwDir + " after "
                    + MAX_DOWNLOAD_ATTEMPTS + " download attempts");
        }

        this.numPixelColumns = ArrayUtil.flatten((int[][]) this.loader.fromFile(firstImage)).length;

        File[] labelDirs = this.listLabelDirectories();
        this.numNames = labelDirs.length;
        // Rebuild from scratch so a second invocation does not duplicate entries.
        this.images.clear();
        this.outcomes.clear();
        for (File dir : labelDirs) {
            this.outcomes.add(dir.getAbsolutePath());
            Collection<File> found =
                    FileUtils.listFiles(dir, FileFileFilter.FILE, DirectoryFileFilter.DIRECTORY);
            File[] sorted = found.toArray(new File[found.size()]);
            Arrays.sort(sorted);
            for (File f : sorted) {
                this.images.add(f.getAbsolutePath());
            }
        }
    }

    /**
     * Stacks the feature and label rows of the given data sets into one data set.
     *
     * @param images data sets to combine; row {@code i} of the result comes from element {@code i}
     * @return a data set with {@code images.size()} rows, {@link #getNumPixelColumns()} feature
     *         columns and {@link #getNumNames()} label columns
     */
    public DataSet convertListPairs(List<DataSet> images) {
        int rows = images.size();
        INDArray inputs = Nd4j.create(rows, this.numPixelColumns);
        INDArray outputs = Nd4j.create(rows, this.numNames);
        for (int i = 0; i < rows; ++i) {
            DataSet pair = images.get(i);
            inputs.putRow(i, pair.getFeatureMatrix());
            outputs.putRow(i, pair.getLabels());
        }
        return new DataSet(inputs, outputs);
    }

    /**
     * Loads the {@code i}-th indexed image together with its label vector. The label is the
     * position of the image's parent directory in the label index built by
     * {@link #getIfNotExists()}.
     *
     * @param i index into the image index
     * @return the image as a row vector paired with its outcome vector
     * @throws IndexOutOfBoundsException if {@code i} is outside the image index
     * @throws IllegalStateException     if the image cannot be loaded or labelled; the original
     *                                   failure is attached as the cause
     */
    public DataSet getDataFor(int i) {
        File image = new File(this.images.get(i));
        int outcome = this.outcomes.indexOf(image.getParentFile().getAbsolutePath());
        try {
            return new DataSet(this.loader.asRowVector(image),
                    FeatureUtil.toOutcomeVector(outcome, this.outcomes.size()));
        }
        catch (Exception e) {
            throw new IllegalStateException(
                    "Unable to getFromOrigin data for image " + i + " for path " + this.images.get(i), e);
        }
    }

    /**
     * Loads images label directory by label directory until at least {@code num} have been
     * collected. Whole directories are added at a time, so the result may hold more than
     * {@code num} elements, or fewer when not enough images exist.
     *
     * @param num minimum number of examples wanted
     * @return the loaded examples in label order
     * @throws IllegalArgumentException if {@code num} is negative
     * @throws Exception                if an image cannot be read
     */
    public List<DataSet> getFeatureMatrix(int num) throws Exception {
        List<DataSet> ret = new ArrayList<DataSet>(num);
        File[] labelDirs = this.listLabelDirectories();
        for (int label = 0; label < labelDirs.length; ++label) {
            ret.addAll(this.getImages(label, labelDirs[label]));
            if (ret.size() >= num) {
                break;
            }
        }
        return ret;
    }

    /**
     * Loads every image and combines them into a single data set.
     *
     * @return one data set with a row per image
     * @throws Exception if an image cannot be read
     */
    public DataSet getAllImagesAsMatrix() throws Exception {
        return this.convertListPairs(this.getImagesAsList());
    }

    /**
     * Loads the first {@code numRows} images (in label order) as a single data set. Only as
     * many label directories as needed are read.
     *
     * @param numRows number of rows wanted
     * @return one data set with exactly {@code numRows} rows
     * @throws IllegalArgumentException  if {@code numRows} is negative
     * @throws IndexOutOfBoundsException if fewer than {@code numRows} images are available
     * @throws Exception                 if an image cannot be read
     */
    public DataSet getAllImagesAsMatrix(int numRows) throws Exception {
        // getFeatureMatrix walks the directories in the same order as getImagesAsList, so its
        // first numRows elements are the same ones without loading the whole data set.
        List<DataSet> firstRows = this.getFeatureMatrix(numRows).subList(0, numRows);
        return this.convertListPairs(firstRows);
    }

    /**
     * Loads every image of every label directory.
     *
     * @return all examples, grouped by label in ascending label order
     * @throws Exception if an image cannot be read
     */
    public List<DataSet> getImagesAsList() throws Exception {
        List<DataSet> list = new ArrayList<DataSet>();
        File[] labelDirs = this.listLabelDirectories();
        for (int label = 0; label < labelDirs.length; ++label) {
            list.addAll(this.getImages(label, labelDirs[label]));
        }
        return list;
    }

    /**
     * Loads every entry of one directory as an example with the given label.
     *
     * @param label label index assigned to each image
     * @param file  directory whose entries are the images
     * @return one example per directory entry, in sorted path order
     * @throws IllegalArgumentException if {@code file} cannot be listed as a directory
     * @throws Exception                if an image cannot be read
     */
    public List<DataSet> getImages(int label, File file) throws Exception {
        File[] entries = file.listFiles();
        if (entries == null) {
            // listFiles() yields null for non-directories and on I/O errors.
            throw new IllegalArgumentException("Unable to list images in " + file);
        }
        Arrays.sort(entries);
        List<DataSet> ret = new ArrayList<DataSet>(entries.length);
        for (File f : entries) {
            ret.add(this.fromImageFile(label, f));
        }
        return ret;
    }

    /**
     * Builds a single example from an image file.
     *
     * @param label label index; encoded as an outcome vector of length {@link #getNumNames()}
     * @param image image file to read
     * @return the flattened image paired with its outcome vector
     * @throws Exception if the image cannot be read
     */
    public DataSet fromImageFile(int label, File image) throws Exception {
        INDArray outcome = FeatureUtil.toOutcomeVector(label, this.numNames);
        INDArray features = ArrayUtil.toNDArray((int[]) this.loader.flattenedImageFromFile(image));
        return new DataSet(features, outcome);
    }

    /**
     * Extracts an archive via {@link ArchiveUtils#unzipFileTo(String, String)}.
     *
     * @param baseDir directory to extract into
     * @param tarFile archive to extract
     * @throws IOException if extraction fails
     */
    public void untarFile(File baseDir, File tarFile) throws IOException {
        log.info("Untaring File: {}", tarFile);
        ArchiveUtils.unzipFileTo(tarFile.getAbsolutePath(), baseDir.getAbsolutePath());
    }

    /** @return the number of label directories counted by the last {@link #getIfNotExists()} call */
    public int getNumNames() {
        return this.numNames;
    }

    /** @return the flattened length of the first image read by the last {@link #getIfNotExists()} call */
    public int getNumPixelColumns() {
        return this.numPixelColumns;
    }

    /** Creates the LFW directory, downloads the archive into it and extracts it into the base directory. */
    private void downloadAndExtract() throws IOException {
        if (!this.lfwDir.mkdirs() && !this.lfwDir.isDirectory()) {
            throw new IOException("Unable to create directory " + this.lfwDir);
        }
        log.info("Grabbing LFW...");
        // Both ends are closed even when the transfer fails part-way through.
        try (ReadableByteChannel source = Channels.newChannel(new URL(LFW_URL).openStream());
             FileOutputStream sink = new FileOutputStream(this.lfwTarFile)) {
            sink.getChannel().transferFrom(source, 0L, Long.MAX_VALUE);
        }
        log.info("Downloaded lfw");
        this.untarFile(this.baseDir, this.lfwTarFile);
    }

    /** Returns the sub-directories of the LFW directory in sorted order; empty when it cannot be listed. */
    private File[] listLabelDirectories() {
        File[] entries = this.lfwDir.getAbsoluteFile().listFiles();
        if (entries == null) {
            return new File[0];
        }
        List<File> dirs = new ArrayList<File>(entries.length);
        for (File entry : entries) {
            // Skip stray plain files (for example a leftover archive); only directories are labels.
            if (entry.isDirectory()) {
                dirs.add(entry);
            }
        }
        File[] sorted = dirs.toArray(new File[dirs.size()]);
        // listFiles() guarantees no order; sorting keeps label indices stable across calls.
        Arrays.sort(sorted);
        return sorted;
    }

    /** Returns the first regular file found in any label directory, or {@code null} when there is none. */
    private File findFirstImage() {
        for (File dir : this.listLabelDirectories()) {
            File[] entries = dir.listFiles();
            if (entries == null) {
                continue;
            }
            Arrays.sort(entries);
            for (File entry : entries) {
                if (entry.isFile()) {
                    return entry;
                }
            }
        }
        return null;
    }
}

