package org.deeplearning4j.evaluation;

import java.awt.Color;
import java.io.File;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.io.FileUtils;
import org.deeplearning4j.eval.ROC;
import org.deeplearning4j.eval.ROCMultiClass;
import org.deeplearning4j.ui.api.Component;
import org.deeplearning4j.ui.api.LengthUnit;
import org.deeplearning4j.ui.components.chart.ChartLine;
import org.deeplearning4j.ui.components.chart.style.StyleChart;
import org.deeplearning4j.ui.components.component.ComponentDiv;
import org.deeplearning4j.ui.components.component.style.StyleDiv;
import org.deeplearning4j.ui.components.table.ComponentTable;
import org.deeplearning4j.ui.components.table.style.StyleTable;
import org.deeplearning4j.ui.components.text.ComponentText;
import org.deeplearning4j.ui.components.text.style.StyleText;
import org.deeplearning4j.ui.standalone.StaticPageUtil;

/* loaded from: input_file:org/deeplearning4j/evaluation/EvaluationTools.class */
public class EvaluationTools {
    private static final String ROC_TITLE = "ROC: TPR/Recall (y) vs. FPR (x)";
    private static final String PR_TITLE = "Precision (y) vs. Recall (x)";
    private static final String PR_THRESHOLD_TITLE = "Precision and Recall (y) vs. Classifier Threshold (x)";
    private static final double CHART_WIDTH_PX = 600.0d;
    private static final double CHART_HEIGHT_PX = 400.0d;
    private static final StyleChart CHART_STYLE = new StyleChart.Builder().width(CHART_WIDTH_PX, LengthUnit.Px).height(CHART_HEIGHT_PX, LengthUnit.Px).margin(LengthUnit.Px, 60, 60, 40, 10).strokeWidth(2.0d).seriesColors(new Color[]{Color.BLUE, Color.LIGHT_GRAY}).build();
    private static final StyleChart CHART_STYLE_PRECISION_RECALL = new StyleChart.Builder().width(CHART_WIDTH_PX, LengthUnit.Px).height(CHART_HEIGHT_PX, LengthUnit.Px).margin(LengthUnit.Px, 60, 60, 40, 10).strokeWidth(2.0d).seriesColors(new Color[]{Color.BLUE, Color.GREEN}).build();
    private static final StyleTable TABLE_STYLE = new StyleTable.Builder().backgroundColor(Color.WHITE).headerColor(Color.LIGHT_GRAY).borderWidth(1).columnWidths(LengthUnit.Percent, new double[]{50.0d, 50.0d}).width(CHART_HEIGHT_PX, LengthUnit.Px).height(200.0d, LengthUnit.Px).build();
    private static final StyleDiv OUTER_DIV_STYLE = new StyleDiv.Builder().width(1200.0d, LengthUnit.Px).height(CHART_HEIGHT_PX, LengthUnit.Px).build();
    private static final StyleDiv INNER_DIV_STYLE = new StyleDiv.Builder().width(CHART_WIDTH_PX, LengthUnit.Px).floatValue(StyleDiv.FloatValue.left).build();
    private static final StyleDiv PAD_DIV_STYLE = new StyleDiv.Builder().width(CHART_WIDTH_PX, LengthUnit.Px).height(100.0d, LengthUnit.Px).floatValue(StyleDiv.FloatValue.left).build();
    private static final ComponentDiv PAD_DIV = new ComponentDiv(PAD_DIV_STYLE, new Component[0]);
    private static final StyleText HEADER_TEXT_STYLE = new StyleText.Builder().color(Color.BLACK).fontSize(16.0d).underline(true).build();
    private static final StyleDiv HEADER_DIV_STYLE = new StyleDiv.Builder().width(1050.0d, LengthUnit.Px).height(30.0d, LengthUnit.Px).backgroundColor(Color.LIGHT_GRAY).margin(LengthUnit.Px, 5, 5, 200, 10).floatValue(StyleDiv.FloatValue.left).build();
    private static final StyleDiv HEADER_DIV_PAD_STYLE = new StyleDiv.Builder().width(1200.0d, LengthUnit.Px).height(150.0d, LengthUnit.Px).backgroundColor(Color.WHITE).build();
    private static final StyleDiv HEADER_DIV_TEXT_PAD_STYLE = new StyleDiv.Builder().width(120.0d, LengthUnit.Px).height(30.0d, LengthUnit.Px).backgroundColor(Color.LIGHT_GRAY).floatValue(StyleDiv.FloatValue.left).build();
    private static final ComponentTable INFO_TABLE = new ComponentTable.Builder(new StyleTable.Builder().backgroundColor(Color.WHITE).borderWidth(0).build()).content((String[][]) new String[]{new String[]{"Precision", "(true positives) / (true positives + false positives)"}, new String[]{"True Positive Rate (Recall)", "(true positives) / (data positives)"}, new String[]{"False Positive Rate", "(false positives) / (data negatives)"}}).build();

    private EvaluationTools() {
    }

    public static void exportRocChartsToHtmlFile(ROC roc, File file) throws IOException {
        FileUtils.writeStringToFile(file, rocChartToHtml(roc));
    }

    public static void exportRocChartsToHtmlFile(ROCMultiClass rOCMultiClass, File file) throws Exception {
        FileUtils.writeStringToFile(file, rocChartToHtml(rOCMultiClass));
    }

    public static String rocChartToHtml(ROC roc) {
        return StaticPageUtil.renderHTML(new Component[]{getRocFromPoints(ROC_TITLE, roc.getResultsAsArray(), roc.getCountActualPositive(), roc.getCountActualNegative(), roc.calculateAUC()), getPRCharts(PR_TITLE, PR_THRESHOLD_TITLE, roc.getPrecisionRecallCurve())});
    }

    public static String rocChartToHtml(ROCMultiClass rOCMultiClass) {
        return rocChartToHtml(rOCMultiClass, null);
    }

    public static String rocChartToHtml(ROCMultiClass rOCMultiClass, List<String> list) {
        long[] countActualPositive = rOCMultiClass.getCountActualPositive();
        long[] countActualNegative = rOCMultiClass.getCountActualNegative();
        ArrayList arrayList = new ArrayList(countActualPositive.length);
        for (int i = 0; i < countActualPositive.length; i++) {
            double[][] resultsAsArray = rOCMultiClass.getResultsAsArray(i);
            String str = "Class " + i;
            if (list != null && list.size() > i) {
                str = str + " (" + list.get(i) + ")";
            }
            arrayList.add(new ComponentDiv(HEADER_DIV_PAD_STYLE, new Component[0]));
            ComponentDiv componentDiv = new ComponentDiv(HEADER_DIV_TEXT_PAD_STYLE, new Component[0]);
            ComponentDiv componentDiv2 = new ComponentDiv(HEADER_DIV_STYLE, new Component[]{new ComponentText(str + " vs. All", HEADER_TEXT_STYLE)});
            Component rocFromPoints = getRocFromPoints(ROC_TITLE, resultsAsArray, countActualPositive[i], countActualNegative[i], rOCMultiClass.calculateAUC(i));
            Component pRCharts = getPRCharts(PR_TITLE, PR_THRESHOLD_TITLE, rOCMultiClass.getPrecisionRecallCurve(i));
            arrayList.add(componentDiv);
            arrayList.add(componentDiv2);
            arrayList.add(rocFromPoints);
            arrayList.add(pRCharts);
        }
        return StaticPageUtil.renderHTML(arrayList);
    }

    /* JADX WARN: Type inference failed for: r1v17, types: [java.lang.String[], java.lang.String[][]] */
    private static Component getRocFromPoints(String str, double[][] dArr, long j, long j2, double d) {
        double[] dArr2 = {0.0d, 1.0d};
        Component build = new ChartLine.Builder(str, CHART_STYLE).setXMin(Double.valueOf(0.0d)).setXMax(Double.valueOf(1.0d)).setYMin(Double.valueOf(0.0d)).setYMax(Double.valueOf(1.0d)).addSeries("ROC", dArr[0], dArr[1]).addSeries("", dArr2, dArr2).build();
        return new ComponentDiv(OUTER_DIV_STYLE, new Component[]{new ComponentDiv(INNER_DIV_STYLE, new Component[]{PAD_DIV, new ComponentTable.Builder(TABLE_STYLE).header(new String[]{"Field", "Value"}).content((String[][]) new String[]{new String[]{"AUC", String.format("%.5f", Double.valueOf(d))}, new String[]{"Total Data Positive Count", String.valueOf(j)}, new String[]{"Total Data Negative Count", String.valueOf(j2)}}).build(), PAD_DIV, INFO_TABLE}), new ComponentDiv(INNER_DIV_STYLE, new Component[]{build})});
    }

    private static Component getPRCharts(String str, String str2, List<ROC.PrecisionRecallPoint> list) {
        return new ComponentDiv(OUTER_DIV_STYLE, new Component[]{new ComponentDiv(INNER_DIV_STYLE, new Component[]{getPrecisionRecallCurve(str, list)}), new ComponentDiv(INNER_DIV_STYLE, new Component[]{getPrecisionRecallVsThreshold(str2, list)})});
    }

    private static Component getPrecisionRecallCurve(String str, List<ROC.PrecisionRecallPoint> list) {
        double[] dArr = new double[list.size()];
        double[] dArr2 = new double[dArr.length];
        for (int i = 0; i < dArr.length; i++) {
            ROC.PrecisionRecallPoint precisionRecallPoint = list.get(i);
            dArr[i] = precisionRecallPoint.getRecall();
            dArr2[i] = precisionRecallPoint.getPrecision();
        }
        return new ChartLine.Builder(str, CHART_STYLE).setXMin(Double.valueOf(0.0d)).setXMax(Double.valueOf(1.0d)).setYMin(Double.valueOf(0.0d)).setYMax(Double.valueOf(1.0d)).addSeries("P vs R", dArr, dArr2).build();
    }

    private static Component getPrecisionRecallVsThreshold(String str, List<ROC.PrecisionRecallPoint> list) {
        double[] dArr = new double[list.size()];
        double[] dArr2 = new double[dArr.length];
        double[] dArr3 = new double[dArr.length];
        for (int i = 0; i < dArr.length; i++) {
            ROC.PrecisionRecallPoint precisionRecallPoint = list.get(i);
            dArr3[i] = precisionRecallPoint.getClassiferThreshold();
            dArr[i] = precisionRecallPoint.getRecall();
            dArr2[i] = precisionRecallPoint.getPrecision();
        }
        return new ChartLine.Builder(str, CHART_STYLE_PRECISION_RECALL).setXMin(Double.valueOf(0.0d)).setXMax(Double.valueOf(1.0d)).setYMin(Double.valueOf(0.0d)).setYMax(Double.valueOf(1.0d)).addSeries("Precision", dArr3, dArr2).addSeries("Recall", dArr3, dArr).showLegend(true).build();
    }
}
