package org.deeplearning4j.gradient.multilayer;

import java.util.ArrayList;
import java.util.List;
import org.deeplearning4j.nn.gradient.MultiLayerGradient;
import org.deeplearning4j.plot.NeuralNetPlotter;
import org.jblas.DoubleMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/gradient/multilayer/WeightPlotListener.class */
/**
 * {@link MultiLayerGradientListener} that keeps a sliding window of the most recently received
 * {@link MultiLayerGradient}s and re-plots the whole window after every update.
 *
 * <p>For each retained gradient, the weight gradient ({@code getwGradient()}) of the first entry
 * returned by {@code getGradients()} is plotted (presumably the first layer; confirm against
 * {@link MultiLayerGradient}). Plots are titled by their zero-based position in the window,
 * oldest first.
 *
 * <p>NOTE(review): the window is a plain, unsynchronized {@link ArrayList}; confirm callers never
 * invoke this listener concurrently.
 */
public class WeightPlotListener implements MultiLayerGradientListener {
    private static final long serialVersionUID = -2476819215506562426L;

    /** Maximum number of gradients retained (and therefore plotted) at any time. */
    private static final int MAX_GRADIENTS = 5;

    private static final Logger log = LoggerFactory.getLogger(WeightPlotListener.class);

    /** Most recently received gradients, oldest first; never holds more than MAX_GRADIENTS. */
    private final List<MultiLayerGradient> gradients = new ArrayList<MultiLayerGradient>();

    /**
     * Records a new gradient, evicts the oldest entries so that only the latest
     * {@value #MAX_GRADIENTS} remain, then re-plots the window.
     *
     * @param multiLayerGradient the gradient just computed; appended to the end of the window
     */
    @Override
    public void onMultiLayerGradient(MultiLayerGradient multiLayerGradient) {
        gradients.add(multiLayerGradient);
        // Trim from the front so the window always holds the newest entries.
        while (gradients.size() > MAX_GRADIENTS) {
            gradients.remove(0);
        }
        plot();
    }

    /**
     * Plots the weight gradient of the first entry of every gradient currently in the window,
     * titling each matrix with its zero-based index (oldest first).
     */
    public void plot() {
        final int count = gradients.size();
        DoubleMatrix[] matrices = new DoubleMatrix[count];
        String[] titles = new String[count];
        log.info("Plotting {} matrices", count);
        for (int i = 0; i < count; i++) {
            titles[i] = String.valueOf(i);
            matrices[i] = gradients.get(i).getGradients().get(0).getwGradient();
        }
        new NeuralNetPlotter().plotMatrices(titles, matrices);
    }
}
