package org.deeplearning4j.ui.weights;

import com.fasterxml.jackson.jaxrs.json.JacksonJsonProvider;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import javax.ws.rs.client.Client;
import javax.ws.rs.client.ClientBuilder;
import javax.ws.rs.client.Entity;
import javax.ws.rs.client.WebTarget;
import javax.ws.rs.core.Response;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.ui.UiServer;
import org.deeplearning4j.ui.UiUtils;
import org.deeplearning4j.ui.providers.ObjectMapperProvider;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/ui/weights/HistogramIterationListener.class */
public class HistogramIterationListener implements IterationListener {
    private static final Logger log = LoggerFactory.getLogger(HistogramIterationListener.class);
    private Client client;
    private WebTarget target;
    private int iterations;
    private ArrayList<Double> scoreHistory;
    private List<Map<String, List<Double>>> meanMagHistoryParams;
    private List<Map<String, List<Double>>> meanMagHistoryUpdates;
    private boolean openBrowser;
    private boolean firstIteration;
    private String path;
    private String subPath;

    public HistogramIterationListener(int i) {
        this(i, true, "weights");
    }

    public HistogramIterationListener(int i, boolean z, String str) {
        this.client = ClientBuilder.newClient().register(JacksonJsonProvider.class).register(new ObjectMapperProvider());
        this.iterations = 1;
        this.scoreHistory = new ArrayList<>();
        this.meanMagHistoryParams = new ArrayList();
        this.meanMagHistoryUpdates = new ArrayList();
        this.firstIteration = true;
        try {
            int port = UiServer.getInstance().getPort();
            this.iterations = i;
            this.target = this.client.target("http://localhost:" + port).path(str).path("update");
            this.openBrowser = z;
            this.path = "http://localhost:" + port + "/" + str;
            this.subPath = str;
            System.out.println("UI Histogram: " + this.path);
        } catch (Exception e) {
            log.error("Error initializing UI server", e);
            throw new RuntimeException(e);
        }
    }

    public boolean invoked() {
        return false;
    }

    public void invoke() {
    }

    public void iterationDone(Model model, int i) {
        if (i % this.iterations == 0) {
            Map gradientForVariable = model.gradient().gradientForVariable();
            if (this.meanMagHistoryParams.size() == 0) {
                int i2 = -1;
                Iterator it = gradientForVariable.keySet().iterator();
                while (it.hasNext()) {
                    i2 = Math.max(i2, indexFromString((String) it.next()));
                }
                if (i2 == -1) {
                    i2 = 0;
                }
                for (int i3 = 0; i3 <= i2; i3++) {
                    this.meanMagHistoryParams.add(new LinkedHashMap());
                    this.meanMagHistoryUpdates.add(new LinkedHashMap());
                }
            }
            LinkedHashMap linkedHashMap = new LinkedHashMap();
            for (Map.Entry entry : gradientForVariable.entrySet()) {
                String str = (String) entry.getKey();
                String str2 = "param_" + str;
                linkedHashMap.put(str2, ((INDArray) entry.getValue()).dup());
                Map<String, List<Double>> map = this.meanMagHistoryUpdates.get(indexFromString(str));
                List<Double> list = map.get(str2);
                if (list == null) {
                    list = new ArrayList();
                    map.put(str2, list);
                }
                list.add(Double.valueOf(((INDArray) entry.getValue()).norm1Number().doubleValue() / ((INDArray) entry.getValue()).length()));
            }
            Map paramTable = model.paramTable();
            LinkedHashMap linkedHashMap2 = new LinkedHashMap();
            for (Map.Entry entry2 : paramTable.entrySet()) {
                String str3 = (String) entry2.getKey();
                String str4 = "param_" + str3;
                linkedHashMap2.put(str4, ((INDArray) entry2.getValue()).dup());
                Map<String, List<Double>> map2 = this.meanMagHistoryParams.get(indexFromString(str3));
                List<Double> list2 = map2.get(str4);
                if (list2 == null) {
                    list2 = new ArrayList();
                    map2.put(str4, list2);
                }
                list2.add(Double.valueOf(((INDArray) entry2.getValue()).norm1Number().doubleValue() / ((INDArray) entry2.getValue()).length()));
            }
            double score = model.score();
            this.scoreHistory.add(Double.valueOf(score));
            ModelAndGradient modelAndGradient = new ModelAndGradient();
            modelAndGradient.setGradients(linkedHashMap);
            modelAndGradient.setParameters(linkedHashMap2);
            modelAndGradient.setScore(score);
            modelAndGradient.setScores(this.scoreHistory);
            modelAndGradient.setPath(this.subPath);
            modelAndGradient.setUpdateMagnitudes(this.meanMagHistoryUpdates);
            modelAndGradient.setParamMagnitudes(this.meanMagHistoryParams);
            modelAndGradient.setLastUpdateTime(System.currentTimeMillis());
            log.debug("{}", this.target.request(new String[]{"application/json"}).accept(new String[]{"application/json"}).post(Entity.entity(modelAndGradient, "application/json")));
            if (this.openBrowser && this.firstIteration) {
                UiUtils.tryOpenBrowser(this.path, log);
                this.firstIteration = false;
            }
        }
    }

    private static int indexFromString(String str) {
        int indexOf = str.indexOf("_");
        if (indexOf == -1) {
            return 0;
        }
        try {
            return Integer.parseInt(str.substring(0, indexOf));
        } catch (NumberFormatException e) {
            return -1;
        }
    }
}
