package org.deeplearning4j.ui.weights;

import com.fasterxml.jackson.jaxrs.json.JacksonJsonProvider;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import javax.ws.rs.client.Client;
import javax.ws.rs.client.ClientBuilder;
import javax.ws.rs.client.Entity;
import javax.ws.rs.client.WebTarget;
import javax.ws.rs.core.Response;
import lombok.NonNull;
import org.deeplearning4j.nn.api.Model;
import org.deeplearning4j.optimize.api.IterationListener;
import org.deeplearning4j.ui.UiConnectionInfo;
import org.deeplearning4j.ui.UiServer;
import org.deeplearning4j.ui.UiUtils;
import org.deeplearning4j.ui.flow.FlowIterationListener;
import org.deeplearning4j.ui.providers.ObjectMapperProvider;
import org.deeplearning4j.ui.weights.HistogramBin;
import org.deeplearning4j.ui.weights.beans.CompactModelAndGradient;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/ui/weights/HistogramIterationListener.class */
/**
 * Iteration listener that streams parameter/gradient histograms to the UI server.
 *
 * <p>On the first call to {@link #iterationDone(Model, int)} and then every {@code iterations}
 * calls, it builds a {@link CompactModelAndGradient} containing:
 * <ul>
 *   <li>a {@value #BIN_COUNT}-bin histogram of every parameter and, when the model exposes them,
 *       every gradient;</li>
 *   <li>per layer, the running history of each variable's mean magnitude
 *       (L1 norm divided by element count);</li>
 *   <li>the running score history.</li>
 * </ul>
 * It POSTs that bean as JSON to the {@code update} endpoint under the {@value #SUB_PATH} path of
 * the configured UI server, tagged with the session id.
 */
public class HistogramIterationListener implements IterationListener {
    private static final Logger log = LoggerFactory.getLogger(HistogramIterationListener.class);

    /** Path segment of the UI server that receives histogram data. */
    private static final String SUB_PATH = "weights";
    /** Number of bins used for every histogram. */
    private static final int BIN_COUNT = 20;
    /** Value passed to {@code HistogramBin.Builder#setRounding} for every histogram. */
    private static final int ROUNDING = 6;
    /** Media type used for the request body and the accepted response. */
    private static final String JSON = "application/json";

    private final Client client =
            ClientBuilder.newClient().register(JacksonJsonProvider.class).register(new ObjectMapperProvider());
    private WebTarget target;
    /** Send an update every this many iterations; always at least 1. */
    private int iterations = 1;
    private int curIteration = 0;
    private final List<Double> scoreHistory = new ArrayList<>();
    /** Indexed like {@link #layerNames}: variable display name -> history of mean parameter magnitudes. */
    private final List<Map<String, List<Double>>> meanMagHistoryParams = new ArrayList<>();
    /** Indexed like {@link #layerNames}: variable display name -> history of mean gradient magnitudes. */
    private final List<Map<String, List<Double>>> meanMagHistoryUpdates = new ArrayList<>();
    /** Layer name -> its position in {@link #layerNames}. */
    private final Map<String, Integer> layerNameIndexes = new HashMap<>();
    /** Layer names in the order they were first seen. */
    private final List<String> layerNames = new ArrayList<>();
    private boolean openBrowser;
    private boolean firstIteration = true;
    private UiConnectionInfo connectionInfo;

    /**
     * Creates a listener that reports to an explicitly configured UI server.
     *
     * @param uiConnectionInfo connection details of the UI server; must not be null
     * @param i report every {@code i} iterations; values below 1 are treated as 1
     * @throws NullPointerException if {@code uiConnectionInfo} is null
     */
    public HistogramIterationListener(@NonNull UiConnectionInfo uiConnectionInfo, int i) {
        if (uiConnectionInfo == null) {
            throw new NullPointerException("connection");
        }
        this.target = this.client.target(uiConnectionInfo.getFirstPart())
                .path(uiConnectionInfo.getSecondPart(SUB_PATH))
                .path("update")
                .queryParam("sid", uiConnectionInfo.getSessionId());
        this.connectionInfo = uiConnectionInfo;
        // Clamp like the other constructor does: 0 would make iterationDone divide by zero.
        this.iterations = Math.max(1, i);
        System.out.println("UI Histogram URL: " + uiConnectionInfo.getFullAddress());
    }

    /**
     * Creates a listener reporting to the local {@link UiServer} and opens a browser on the first update.
     *
     * @param i report every {@code i} iterations; values below 1 are treated as 1
     */
    public HistogramIterationListener(int i) {
        this(i, true);
    }

    /**
     * Creates a listener reporting to the local {@link UiServer} instance.
     *
     * @param i report every {@code i} iterations; values below 1 are treated as 1
     * @param z whether to try to open a browser on the histogram page after the first update
     * @throws RuntimeException wrapping any failure while locating the UI server or building the target
     */
    public HistogramIterationListener(int i, boolean z) {
        try {
            int port = UiServer.getInstance().getPort();
            this.iterations = Math.max(1, i);
            UiConnectionInfo info = new UiConnectionInfo.Builder()
                    .enableHttps(false)
                    .setAddress(FlowIterationListener.LOCALHOST)
                    .setPort(port)
                    .build();
            this.connectionInfo = info;
            this.target = this.client.target(info.getFirstPart())
                    .path(SUB_PATH)
                    .path("update")
                    .queryParam("sid", info.getSessionId());
            this.openBrowser = z;
            String url = "http://localhost:" + port + "/" + SUB_PATH;
            System.out.println("UI Histogram URL: " + url + "?sid=" + info.getSessionId());
        } catch (Exception e) {
            log.error("Error initializing UI server", e);
            throw new RuntimeException(e);
        }
    }

    /** Always {@code false}; this listener does not track invocation state. */
    public boolean invoked() {
        return false;
    }

    /** No-op. */
    public void invoke() {
    }

    /**
     * Sends an update on the first call and then every {@code iterations} calls.
     *
     * @param model the model being trained
     * @param i the iteration number supplied by the caller (unused; an internal counter is used)
     */
    public void iterationDone(Model model, int i) {
        if (this.curIteration % this.iterations == 0) {
            sendUpdate(model);
        }
        this.curIteration++;
    }

    /** Collects histograms, magnitudes and score from {@code model} and POSTs them to the UI server. */
    @SuppressWarnings("rawtypes") // CompactModelAndGradient stores histograms as Map<String, Map>
    private void sendUpdate(Model model) {
        Map<String, Map> gradients = new LinkedHashMap<>();
        try {
            Map<String, INDArray> gradientForVariable = model.gradient().gradientForVariable();
            if (this.meanMagHistoryParams.isEmpty()) {
                presizeHistories(gradientForVariable.keySet());
            }
            recordVariables(gradientForVariable, gradients, this.meanMagHistoryUpdates);
        } catch (Exception e) {
            // Best effort: whatever gradients were collected before the failure are still sent.
            log.warn("Skipping gradients update", e);
        }

        Map<String, Map> parameters = new LinkedHashMap<>();
        recordVariables(model.paramTable(), parameters, this.meanMagHistoryParams);

        double score = model.score();
        this.scoreHistory.add(score);

        CompactModelAndGradient update = new CompactModelAndGradient();
        update.setGradients(gradients);
        update.setParameters(parameters);
        update.setScore(score);
        update.setScores(this.scoreHistory);
        update.setPath(SUB_PATH);
        update.setUpdateMagnitudes(this.meanMagHistoryUpdates);
        update.setParamMagnitudes(this.meanMagHistoryParams);
        update.setLayerNames(this.layerNames);
        update.setLastUpdateTime(System.currentTimeMillis());

        Response response = this.target.request(JSON).accept(JSON).post(Entity.entity(update, JSON));
        try {
            log.debug("{}", response);
        } finally {
            // Release the underlying connection; the response body is not needed.
            response.close();
        }

        if (this.openBrowser && this.firstIteration) {
            String url = this.connectionInfo.getFullAddress() + SUB_PATH + "?sid=" + this.connectionInfo.getSessionId();
            UiUtils.tryOpenBrowser(url, log);
            this.firstIteration = false;
        }
    }

    /**
     * Registers the layers named by {@code variableNames}. It then adds one empty map per layer
     * (up to the highest index seen, at least one) to both magnitude histories, so the two lists start equally long.
     */
    private void presizeHistories(Iterable<String> variableNames) {
        int maxIndex = 0;
        for (String name : variableNames) {
            maxIndex = Math.max(maxIndex, indexFromString(name));
        }
        for (int k = 0; k <= maxIndex; k++) {
            this.meanMagHistoryParams.add(new LinkedHashMap<String, List<Double>>());
            this.meanMagHistoryUpdates.add(new LinkedHashMap<String, List<Double>>());
        }
    }

    /**
     * For each variable, stores its histogram in {@code histograms}. It also appends its mean magnitude
     * (L1 norm / element count) to the layer's entry in {@code history}.
     * Names starting with a digit are prefixed with {@code "param_"} for display.
     */
    @SuppressWarnings("rawtypes")
    private void recordVariables(Map<String, INDArray> variables, Map<String, Map> histograms,
                                 List<Map<String, List<Double>>> history) {
        for (Map.Entry<String, INDArray> entry : variables.entrySet()) {
            String name = entry.getKey();
            INDArray values = entry.getValue();
            String displayName = !name.isEmpty() && Character.isDigit(name.charAt(0)) ? "param_" + name : name;
            histograms.put(displayName,
                    new HistogramBin.Builder(values.dup()).setBinCount(BIN_COUNT).setRounding(ROUNDING).build().getData());

            int layerIndex = indexFromString(name);
            // Grow as far as needed: a single add() is not enough if the index jumps past size().
            while (history.size() <= layerIndex) {
                history.add(new LinkedHashMap<String, List<Double>>());
            }
            Map<String, List<Double>> layerHistory = history.get(layerIndex);
            List<Double> magnitudes = layerHistory.get(displayName);
            if (magnitudes == null) {
                magnitudes = new ArrayList<>();
                layerHistory.put(displayName, magnitudes);
            }
            magnitudes.add(values.norm1Number().doubleValue() / values.length());
        }
    }

    /**
     * Returns the layer index for a variable name. The layer name is the part before the first
     * {@code '_'}, or the whole name when there is none. Layers not seen before are appended to
     * {@link #layerNames} and given the next index.
     */
    private int indexFromString(String str) {
        int underscore = str.indexOf('_');
        String layerName = underscore == -1 ? str : str.substring(0, underscore);
        Integer index = this.layerNameIndexes.get(layerName);
        if (index == null) {
            index = this.layerNames.size();
            this.layerNames.add(layerName);
            this.layerNameIndexes.put(layerName, index);
        }
        return index;
    }
}
