package org.deeplearning4j.datasets.mnist.draw;

import java.io.FileInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import org.deeplearning4j.datasets.fetchers.MnistDataFetcher;
import org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator;
import org.deeplearning4j.nn.layers.BasePretrainNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/datasets/mnist/draw/LoadAndDraw.class */
/**
 * Manual visual check for a pretrained network.
 *
 * <p>Loads a Java-serialized {@link BasePretrainNetwork} from the path given as the first
 * command-line argument. It then iterates over MNIST and, for every example, draws the scaled
 * input ("REAL") next to a binomially sampled, scaled version of the network's activation for that
 * example ("TEST").
 *
 * <p>Usage: {@code LoadAndDraw <path-to-serialized-network>}
 */
public class LoadAndDraw {
    /** Number of examples per MNIST mini-batch requested from the iterator. */
    private static final int BATCH_SIZE = 60;

    /** Milliseconds to pause after drawing each REAL/TEST pair before disposing their frames. */
    private static final long DISPLAY_MILLIS = 10000L;

    /** Factor applied to both images before drawing. */
    private static final int PIXEL_SCALE = 255;

    private LoadAndDraw() {
    }

    /**
     * Entry point.
     *
     * @param strArr command-line arguments; {@code strArr[0]} must be the path of a file containing
     *     a Java-serialized {@link BasePretrainNetwork}
     * @throws IllegalArgumentException if no path argument is supplied
     * @throws Exception if the model cannot be read or deserialized, if the MNIST iterator or
     *     drawing fails, or if the thread is interrupted while pausing between examples
     */
    public static void main(String[] strArr) throws Exception {
        if (strArr == null || strArr.length < 1) {
            throw new IllegalArgumentException("Usage: LoadAndDraw <path-to-serialized-network>");
        }

        // Read the model first so a bad path is reported before the MNIST iterator is constructed.
        BasePretrainNetwork network = loadNetwork(strArr[0]);
        MnistDataSetIterator iterator =
                new MnistDataSetIterator(BATCH_SIZE, MnistDataFetcher.NUM_EXAMPLES);

        while (iterator.hasNext()) {
            DataSet batch = iterator.next();
            INDArray activations = network.activate(batch.getFeatureMatrix());
            for (int i = 0; i < batch.numExamples(); i++) {
                INDArray real = batch.get(i).getFeatureMatrix().mul(PIXEL_SCALE);
                INDArray row = activations.getRow(i);
                // Sample binomial(n=1, p=row) with the row's shape, i.e. a 0/1 image whose per-entry
                // probability is the activation, then scale it like the real image.
                INDArray sampled = Nd4j.getDistributions()
                        .createBinomial(1, row)
                        .sample(row.shape())
                        .mul(PIXEL_SCALE);

                DrawReconstruction realDrawing = new DrawReconstruction(real);
                realDrawing.title = "REAL";
                realDrawing.draw();

                DrawReconstruction testDrawing = new DrawReconstruction(sampled, 100, 100);
                testDrawing.title = "TEST";
                testDrawing.draw();

                Thread.sleep(DISPLAY_MILLIS);
                realDrawing.frame.dispose();
                testDrawing.frame.dispose();
            }
        }
    }

    /**
     * Deserializes a {@link BasePretrainNetwork} from {@code path}, always closing the underlying
     * streams, including when deserialization fails.
     *
     * <p>SECURITY: this uses Java native deserialization. Only load model files from a trusted
     * source, because a crafted file can execute code during {@code readObject}.
     *
     * @param path file containing the serialized network
     * @return the deserialized network
     * @throws IOException if the file cannot be opened or read, or if closing it fails
     * @throws ClassNotFoundException if a serialized class is not on the classpath
     * @throws ClassCastException if the file does not contain a {@link BasePretrainNetwork}
     */
    private static BasePretrainNetwork loadNetwork(String path)
            throws IOException, ClassNotFoundException {
        // Declare the streams separately so the FileInputStream is closed even if the
        // ObjectInputStream constructor fails while reading the stream header.
        try (FileInputStream fileIn = new FileInputStream(path);
                ObjectInputStream objectIn = new ObjectInputStream(fileIn)) {
            return (BasePretrainNetwork) objectIn.readObject();
        }
    }
}
