package org.deeplearning4j.plot;

import org.deeplearning4j.datasets.iterator.DataSetIterator;
import org.deeplearning4j.datasets.mnist.draw.DrawReconstruction;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.dataset.DataSet;

/* loaded from: input_file:org/deeplearning4j/plot/MultiLayerNetworkReconstructionRender.class */
/**
 * Shows, one example at a time, an input from a {@link DataSetIterator} next to the matching
 * row of a {@link MultiLayerNetwork}'s output. Each image gets its own {@link DrawReconstruction}
 * window, titled {@code "REAL"} (input features) and {@code "TEST"} (network result).
 *
 * <p>When the configured layer index is negative, the network's {@code output(...)} is rendered.
 * Otherwise {@code reconstruct(features, layer)} is used.
 */
public class MultiLayerNetworkReconstructionRender {
    /**
     * Multiplier applied to both the input features and the network result before drawing, so the
     * two images share one intensity scale. Assumes values lie roughly in [0, 1] — TODO confirm
     * against the dataset/network in use.
     */
    private static final int PIXEL_SCALE = 255;

    /** How long each REAL/TEST pair stays on screen when {@link #draw()} is used, in milliseconds. */
    private static final long DEFAULT_DISPLAY_MILLIS = 10000L;

    private final DataSetIterator iter;
    private final MultiLayerNetwork network;
    private final int reconLayer;

    /**
     * Creates a renderer.
     *
     * @param dataSetIterator source of examples; it is consumed by {@link #draw()}
     * @param multiLayerNetwork network whose output or reconstruction is rendered
     * @param i layer index passed to {@code reconstruct}; a negative value selects
     *     {@code network.output(...)} instead
     */
    public MultiLayerNetworkReconstructionRender(DataSetIterator dataSetIterator, MultiLayerNetwork multiLayerNetwork, int i) {
        this.iter = dataSetIterator;
        this.network = multiLayerNetwork;
        this.reconLayer = i;
    }

    /**
     * Creates a renderer that shows the network's {@code output(...)} (no layer reconstruction).
     *
     * @param dataSetIterator source of examples; it is consumed by {@link #draw()}
     * @param multiLayerNetwork network whose output is rendered
     */
    public MultiLayerNetworkReconstructionRender(DataSetIterator dataSetIterator, MultiLayerNetwork multiLayerNetwork) {
        this(dataSetIterator, multiLayerNetwork, -1);
    }

    /**
     * Walks the remaining examples of the iterator and shows each REAL/TEST pair for ten seconds.
     *
     * @throws InterruptedException if interrupted while a pair is displayed; the open windows are
     *     still disposed
     */
    public void draw() throws InterruptedException {
        draw(DEFAULT_DISPLAY_MILLIS);
    }

    /**
     * Walks the remaining examples of the iterator and shows each REAL/TEST pair for the given time.
     *
     * @param displayMillis how long each pair stays on screen, in milliseconds; must be {@code >= 0}
     * @throws IllegalArgumentException if {@code displayMillis} is negative (checked before the
     *     iterator is touched)
     * @throws InterruptedException if interrupted while a pair is displayed; the open windows are
     *     still disposed
     */
    public void draw(long displayMillis) throws InterruptedException {
        if (displayMillis < 0) {
            throw new IllegalArgumentException("displayMillis must be >= 0, got " + displayMillis);
        }
        while (iter.hasNext()) {
            DataSet batch = iter.next();
            INDArray features = batch.getFeatureMatrix();
            INDArray output = reconLayer < 0 ? network.output(features) : network.reconstruct(features, reconLayer);
            for (int i = 0; i < batch.numExamples(); i++) {
                INDArray real = batch.get(i).getFeatureMatrix().mul(PIXEL_SCALE);
                // Same scale as the REAL image; the former 255 * 255 factor pushed values far
                // outside the 8-bit pixel range and made the two images incomparable.
                INDArray test = output.getRow(i).mul(PIXEL_SCALE);
                showPair(real, test, displayMillis);
            }
        }
    }

    /** Opens the REAL and TEST windows, waits, and always disposes whichever windows were opened. */
    private static void showPair(INDArray real, INDArray test, long displayMillis) throws InterruptedException {
        DrawReconstruction realView = new DrawReconstruction(real);
        realView.title = "REAL";
        realView.draw();
        try {
            DrawReconstruction testView = new DrawReconstruction(test);
            testView.title = "TEST";
            testView.draw();
            try {
                Thread.sleep(displayMillis);
            } finally {
                testView.frame.dispose();
            }
        } finally {
            realView.frame.dispose();
        }
    }
}
