package org.deeplearning4j.plot;

import java.io.BufferedOutputStream;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStreamWriter;
import java.io.Serializable;
import java.io.Writer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.TreeSet;
import java.util.UUID;
import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang3.StringUtils;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.gradient.Gradient;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.core.io.ClassPathResource;

/* loaded from: input_file:org/deeplearning4j/plot/NeuralNetPlotter.class */
/**
 * Produces diagnostic plots (weight/gradient histograms, activations, filters) for network layers.
 *
 * <p>Matrices are written as comma-separated text files under {@code <tmpdir>/data/}. The bundled
 * {@code plot.py} script is copied to {@code <tmpdir>/graphs/} and invoked with {@code python},
 * receiving the data file paths and the desired image output path as command-line arguments.
 *
 * <p>NOTE(review): all path state is {@code static}, so every instance shares one layer graph
 * directory; {@link #setLayerGraphFilePath(String)} affects all instances.
 */
public class NeuralNetPlotter implements Serializable {
    // Classpath resource names always use '/', independent of the host OS separator.
    private static final ClassPathResource script = new ClassPathResource("scripts/plot.py");
    private static final Logger log = LoggerFactory.getLogger(NeuralNetPlotter.class);
    /** Format for every value written to a data file (10 fraction digits). */
    private static final String VALUE_FORMAT = "%.10f";
    /** Random id making this JVM's graph directory unique. */
    private static final String ID_FOR_SESSION = UUID.randomUUID().toString();
    private static final String localPath = System.getProperty("java.io.tmpdir") + File.separator;
    private static final String dataFilePath = localPath + "data" + File.separator;
    private static final String graphPath = localPath + "graphs" + File.separator;
    private static final String graphFilePath = graphPath + ID_FOR_SESSION + File.separator;
    // Must stay below the path fields: loadIntoTmp() reads them during class initialization.
    private static final String localPlotPath = loadIntoTmp();
    /** Directory (with trailing separator) that plot images are written into. */
    private static String layerGraphFilePath = graphFilePath;

    /**
     * Returns the directory that plot images are currently written into.
     *
     * @return the layer graph directory, ending with a file separator
     */
    public String getLayerGraphFilePath() {
        return layerGraphFilePath;
    }

    /**
     * Sets the directory that subsequent plot images are written into (shared by all instances).
     *
     * @param str the new directory; should end with a file separator because file names are
     *     appended to it directly
     */
    public void setLayerGraphFilePath(String str) {
        layerGraphFilePath = str;
    }

    /** Logs the directory holding the temporary matrix data files. */
    public static void printDataFilePath() {
        log.info("Data stored at " + dataFilePath);
    }

    /** Logs the per-session graph directory and reminds the user to delete it manually. */
    public static void printGraphFilePath() {
        log.warn("Graphs stored at " + graphFilePath + ". Warning: You must manually delete the folder when you are done.");
    }

    /**
     * Creates the data and graph directories and copies {@code plot.py} from the classpath into the
     * shared graphs directory (unless it already exists there).
     *
     * @return absolute path of the copied script
     * @throws IllegalStateException if the script cannot be copied
     */
    private static String loadIntoTmp() {
        setupDirectory(dataFilePath);
        setupDirectory(graphFilePath);
        printDataFilePath();
        printGraphFilePath();
        File file = new File(graphPath, "plot.py");
        file.deleteOnExit();
        if (!file.exists()) {
            try (InputStream in = script.getInputStream()) {
                FileUtils.writeLines(file, IOUtils.readLines(in));
            } catch (IOException e) {
                throw new IllegalStateException("Unable to load python file", e);
            }
        }
        return file.getAbsolutePath();
    }

    /**
     * Ensures the given directory exists, creating any missing parent directories.
     * Failure is logged rather than thrown.
     *
     * @param str directory path
     */
    protected static void setupDirectory(String str) {
        File dir = new File(str);
        if (dir.isDirectory()) {
            return;
        }
        // mkdirs (not mkdir): e.g. graphs/<session>/ is created before graphs/ itself exists.
        // Re-check isDirectory in case another process created it concurrently.
        if (!dir.mkdirs() && !dir.isDirectory()) {
            log.warn("Unable to create directory " + str);
        }
    }

    /**
     * Switches the image output directory to {@code <session graph dir>/<index><ClassName>/},
     * creating it if needed.
     *
     * @param layer the layer whose plots follow
     */
    public void updateGraphDirectory(Layer layer) {
        String className = layer.getClass().getName();
        String simpleName = className.substring(className.lastIndexOf('.') + 1);
        // graphFilePath already ends with a separator.
        String dir = graphFilePath + layer.getIndex() + simpleName + File.separator;
        setupDirectory(dir);
        // Always switch, even when the directory already existed (e.g. revisiting a layer);
        // returning early would keep writing into the previous layer's directory.
        setLayerGraphFilePath(dir);
    }

    /**
     * Writes a matrix to a fresh temporary file: one line per row, values comma-separated.
     * The file is deleted when the JVM exits.
     *
     * @param matrix matrix to write
     * @return path of the written file
     * @throws RuntimeException wrapping any {@link IOException}
     */
    protected String writeMatrix(INDArray matrix) {
        String path = dataFilePath + UUID.randomUUID().toString();
        File file = new File(path);
        file.deleteOnExit();
        try (Writer writer = newDataWriter(file)) {
            for (int r = 0; r < matrix.rows(); r++) {
                INDArray row = matrix.getRow(r);
                for (int c = 0; c < row.length(); c++) {
                    if (c > 0) {
                        writer.write(',');
                    }
                    writer.write(formatValue(row.getDouble(c)));
                }
                writer.write('\n');
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        return path;
    }

    /**
     * Writes a list of numbers to a fresh temporary file as a single comma-separated line
     * (no trailing newline). An empty list produces an empty file. The file is deleted when the
     * JVM exits.
     *
     * @param arrayList numeric values (typically {@link Double})
     * @return path of the written file
     * @throws ClassCastException if an element is not a {@link Number}
     * @throws RuntimeException wrapping any {@link IOException}
     */
    public String writeArray(ArrayList<?> arrayList) {
        String path = dataFilePath + UUID.randomUUID().toString();
        File file = new File(path);
        file.deleteOnExit();
        try (Writer writer = newDataWriter(file)) {
            boolean first = true;
            for (Object value : arrayList) {
                if (!first) {
                    writer.write(',');
                }
                writer.write(formatValue(((Number) value).doubleValue()));
                first = false;
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
        return path;
    }

    /** Opens {@code file} for appending as a buffered UTF-8 text writer. */
    private static Writer newDataWriter(File file) throws IOException {
        return new BufferedWriter(new OutputStreamWriter(new FileOutputStream(file, true), StandardCharsets.UTF_8));
    }

    /**
     * Formats a value for a data file. {@link Locale#ROOT} guarantees a '.' decimal separator;
     * the default locale may use ',' (e.g. German), which would collide with the field separator.
     */
    private static String formatValue(double value) {
        return String.format(Locale.ROOT, VALUE_FORMAT, value);
    }

    /**
     * Runs the plot script: {@code python plot.py <plotType> <data> <output>}.
     *
     * @param plotType plot kind understood by plot.py (e.g. "histogram", "activations")
     * @param data data argument (a file path, or comma-joined file/title pairs)
     * @param outputPath image output path handed to the script
     * @throws RuntimeException if the process cannot be started or its output cannot be read
     */
    public void renderGraph(String plotType, String data, String outputPath) {
        runPlotScript(plotType, data, outputPath);
    }

    /**
     * Runs the plot script with two extra integer arguments:
     * {@code python plot.py <plotType> <data> <output> <i> <i2>}.
     * NOTE(review): the meaning of {@code i}/{@code i2} is defined by plot.py (presumably
     * dimensions) — confirm against the script.
     *
     * @throws RuntimeException if the process cannot be started or its output cannot be read
     */
    public void renderGraph(String plotType, String data, String outputPath, int i, int i2) {
        runPlotScript(plotType, data, outputPath, String.valueOf(i), String.valueOf(i2));
    }

    /**
     * Executes {@code python <localPlotPath> args...} and logs its stdout (INFO) and any stderr
     * (ERROR). {@code args[0]} is the plot type, echoed in the log.
     */
    private static void runPlotScript(String... args) {
        log.info("Rendering " + args[0] + " graphs for data analysis...");
        List<String> command = new ArrayList<String>(args.length + 2);
        command.add("python");
        command.add(localPlotPath);
        // Each argument is passed verbatim, so paths containing spaces are not split apart
        // (Runtime.exec(String) tokenizes on whitespace).
        command.addAll(Arrays.asList(args));
        try {
            final Process process = new ProcessBuilder(command).start();
            // Drain stderr on a separate thread: reading stdout to EOF first can deadlock when
            // the script fills the stderr pipe buffer while we are blocked on stdout.
            final List<String> errorLines = new ArrayList<String>();
            Thread errorReader = new Thread(new Runnable() {
                @Override
                public void run() {
                    try (InputStream stderr = process.getErrorStream()) {
                        errorLines.addAll(IOUtils.readLines(stderr));
                    } catch (IOException e) {
                        log.warn("Unable to read plot script stderr", e);
                    }
                }
            }, "plot-stderr-reader");
            errorReader.setDaemon(true);
            errorReader.start();
            final List<String> outputLines;
            try (InputStream stdout = process.getInputStream()) {
                outputLines = IOUtils.readLines(stdout);
            }
            // join() also makes errorLines safely visible to this thread.
            errorReader.join();
            log.info("Std out " + outputLines.toString());
            if (!errorLines.isEmpty()) {
                log.error("Std error " + errorLines.toString());
            }
        } catch (IOException e) {
            log.warn("Image closed");
            throw new RuntimeException(e);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new RuntimeException(e);
        }
    }

    /**
     * Writes each matrix (flattened) to a data file and renders them in one plot. The script
     * receives a comma-joined list of alternating {@code file,title} entries.
     *
     * @param plotType plot kind understood by plot.py
     * @param titles one title per matrix, same order as {@code matrices}
     * @param matrices matrices to plot
     * @param outputPath image output path
     * @throws IllegalArgumentException if {@code titles} and {@code matrices} differ in length
     */
    public void graphPlotType(String plotType, List<String> titles, INDArray[] matrices, String outputPath) {
        if (titles.size() != matrices.length) {
            throw new IllegalArgumentException("Titles and matrix lengths must be equal");
        }
        String[] fileAndTitle = new String[matrices.length * 2];
        for (int i = 0; i < matrices.length; i++) {
            fileAndTitle[2 * i] = writeMatrix(matrices[i].ravel());
            fileAndTitle[2 * i + 1] = titles.get(i);
        }
        renderGraph(plotType, StringUtils.join(fileAndTitle, ","), outputPath);
    }

    /**
     * Plots histograms of every configured parameter of the layer followed by the matching
     * gradients, into {@code weightHistograms.png} in the current layer graph directory.
     *
     * @param layer layer whose parameters are plotted
     * @param gradient gradient holding an entry for each of the layer's configured variables
     * @throws IllegalArgumentException if the gradient lacks an entry for a configured variable
     */
    public void plotWeightHistograms(Layer layer, Gradient gradient) {
        List<String> variables = layer.conf().variables();
        // Titles are built from the same list, in the same order, as the matrices, so every
        // histogram gets its own title (sorting gradient keys separately could misalign them).
        List<String> titles = new ArrayList<String>(variables.size() * 2);
        INDArray[] matrices = new INDArray[variables.size() * 2];
        int next = 0;
        for (String variable : variables) {
            titles.add(variable);
            matrices[next++] = layer.getParam(variable);
        }
        for (String variable : variables) {
            INDArray variableGradient = gradient.getGradientFor(variable);
            if (variableGradient == null) {
                throw new IllegalArgumentException("No gradient for variable " + variable);
            }
            titles.add(variable + "-gradient");
            matrices[next++] = variableGradient;
        }
        graphPlotType("histogram", titles, matrices, layerGraphFilePath + "weightHistograms.png");
    }

    /**
     * Plots weight histograms using the layer's own current gradient.
     *
     * @param layer layer whose parameters and {@code gradient()} are plotted
     */
    public void plotWeightHistograms(Layer layer) {
        plotWeightHistograms(layer, layer.gradient());
    }

    /**
     * Plots the layer's {@code activationMean()} into {@code activationPlot.png}.
     *
     * @param layer layer to plot
     * @throws IllegalStateException if the layer has no input set
     */
    public void plotActivations(Layer layer) {
        if (layer.input() == null) {
            throw new IllegalStateException("Unable to plot; missing input");
        }
        renderGraph("activations", writeMatrix(layer.activationMean()), layerGraphFilePath + "activationPlot.png");
    }

    /**
     * Renders the layer's "W" parameter as filter images into {@code renderFilter.png}.
     * Rendering failures are logged and swallowed (best effort).
     *
     * @param layer layer whose "W" parameter is rendered
     * @param i forwarded as the last argument of {@code FilterRenderer.renderFilters} for
     *     2-D weights. NOTE(review): unused for weights of rank &gt; 2.
     */
    public void renderFilter(Layer layer, int i) {
        INDArray weights = layer.getParam("W").dup();
        FilterRenderer renderer = new FilterRenderer();
        String outputPath = layerGraphFilePath + "renderFilter.png";
        try {
            if (weights.shape().length > 2) {
                renderer.renderFilters(weights.transpose(), outputPath, weights.columns(), weights.rows(), weights.slices());
            } else {
                renderer.renderFilters(weights, outputPath, (int) Math.sqrt(weights.rows()), (int) Math.sqrt(weights.columns()), i);
            }
        } catch (Exception e) {
            // The logger already records the stack trace; no printStackTrace duplicate.
            log.error("Unable to plot filter, continuing...", e);
        }
    }

    /**
     * Plots weight/gradient histograms and activations for the layer.
     *
     * @param layer layer to plot
     * @param gradient gradient to plot alongside the parameters
     */
    public void plotNetworkGradient(Layer layer, Gradient gradient) {
        plotWeightHistograms(layer, gradient);
        plotActivations(layer);
    }

    /**
     * Plots histograms of the layer's "W" parameter and the given gradient, then activations.
     *
     * @param layer layer to plot
     * @param iNDArray gradient of "W"
     */
    public void plotNetworkGradient(Layer layer, INDArray iNDArray) {
        graphPlotType("histogram", Arrays.asList("W", "w-gradient"), new INDArray[]{layer.getParam("W"), iNDArray}, layerGraphFilePath + "weightHistograms.png");
        plotActivations(layer);
    }
}
