/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.models.rntn;

import java.text.DecimalFormat;
import java.util.List;
import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.eval.ConfusionMatrix;
import org.deeplearning4j.models.rntn.RNTN;
import org.deeplearning4j.nn.layers.feedforward.autoencoder.recursive.Tree;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Evaluates an {@link RNTN} on labelled trees by recording (gold label, predicted label) pairs
 * for every internal node in a confusion matrix, and derives per-class and macro-averaged
 * precision / recall / F1 from it.
 */
public class RNTNEval {
    private Counter<Integer> truePositives = new Counter<>();
    private Counter<Integer> falsePositives = new Counter<>();
    private Counter<Integer> falseNegatives = new Counter<>();
    private final ConfusionMatrix<Integer> confusionMatrix = new ConfusionMatrix<>();
    private static final Logger log = LoggerFactory.getLogger(RNTNEval.class);

    /**
     * Forward-propagates each tree through the network and records its node predictions.
     *
     * @param rntn  the network used to compute predictions
     * @param trees the trees to evaluate
     */
    public void eval(RNTN rntn, List<Tree> trees) {
        for (Tree t : trees) {
            rntn.forwardPropagateTree(t);
            this.count(t);
        }
    }

    /**
     * Recursively records (gold label, predicted label) for every non-leaf node that has a
     * prediction. Leaves are skipped, and a node without a prediction is skipped together
     * with its whole subtree.
     */
    private void count(Tree tree) {
        if (tree.isLeaf()) {
            return;
        }
        if (tree.prediction() == null) {
            return;
        }
        for (Tree t : tree.children()) {
            this.count(t);
        }
        int treeGoldLabel = tree.goldLabel();
        // Predicted class = index chosen by the BLAS iamax over the prediction vector
        // (by BLAS convention the largest-magnitude entry).
        int predictionLabel = Nd4j.getBlasWrapper().iamax(tree.prediction());
        this.confusionMatrix.add(treeGoldLabel, predictionLabel);
    }

    /** Adds {@code i} true positives for class {@code clazz}. */
    public void incrementTruePositives(int clazz, int i) {
        this.truePositives.incrementCount(clazz, (double) i);
    }

    /** Adds {@code i} false positives for class {@code clazz}. */
    public void incrementFalsePositives(int clazz, int i) {
        this.falsePositives.incrementCount(clazz, (double) i);
    }

    /** Adds {@code i} false negatives for class {@code clazz}. */
    public void incrementFalseNegatives(int clazz, int i) {
        this.falseNegatives.incrementCount(clazz, (double) i);
    }

    /**
     * Rebuilds the true-positive / false-positive / false-negative tallies from the confusion
     * matrix and renders a report of the non-zero cells plus macro precision, recall and F1.
     *
     * <p>The tallies are reset on every call, so calling this repeatedly (for example after
     * additional {@link #eval} batches) does not double count earlier results. Any counts added
     * through the {@code increment*} methods before this call are discarded.
     *
     * @return a human-readable evaluation report
     */
    public String stats() {
        StringBuilder builder = new StringBuilder().append("\n");
        this.truePositives = new Counter<>();
        this.falsePositives = new Counter<>();
        this.falseNegatives = new Counter<>();
        List<Integer> classes = this.confusionMatrix.getClasses();
        for (Integer actual : classes) {
            for (Integer predicted : classes) {
                int count = this.confusionMatrix.getCount(actual, predicted);
                if (count != 0) {
                    builder.append("\nActual Class " + actual + " was predicted with Predicted " + predicted + " with count " + count + " times\n");
                }
                // equals(), not ==: boxed Integers outside the cache range are distinct objects.
                if (actual.equals(predicted)) {
                    this.incrementTruePositives(actual, count);
                } else {
                    // An off-diagonal cell is a false positive for the predicted class and a
                    // false negative for the actual class.
                    this.incrementFalsePositives(predicted, count);
                    this.incrementFalseNegatives(actual, count);
                }
            }
        }
        DecimalFormat df = new DecimalFormat("#.####");
        builder.append("\n==========================Scores========================================");
        builder.append("\n Precision: " + df.format(this.precision()));
        builder.append("\n Recall: " + df.format(this.recall()));
        builder.append("\n F1 Score: " + df.format(this.f1()));
        builder.append("\n===========================================================================");
        return builder.toString();
    }

    /**
     * Macro-averaged precision over all classes in the confusion matrix, based on the current
     * tallies (populated by {@link #stats()} or the {@code increment*} methods).
     *
     * @return mean per-class precision, or 0.0 if no classes have been recorded
     */
    public double precision() {
        List<Integer> classes = this.confusionMatrix.getClasses();
        if (classes.isEmpty()) {
            return 0.0;
        }
        double prec = 0.0;
        for (Integer i : classes) {
            prec += this.precision(i);
        }
        return prec / (double) classes.size();
    }

    /**
     * Precision for one class: TP / (TP + FP).
     *
     * @param i the class label
     * @return the class precision, or 0.0 when the class has no true positives
     */
    public double precision(int i) {
        double tp = this.truePositives.getCount(i);
        if (tp == 0.0) {
            return 0.0;
        }
        return tp / (tp + this.falsePositives.getCount(i));
    }

    /**
     * Macro-averaged recall over all classes in the confusion matrix, based on the current
     * tallies (populated by {@link #stats()} or the {@code increment*} methods).
     *
     * @return mean per-class recall, or 0.0 if no classes have been recorded
     */
    public double recall() {
        List<Integer> classes = this.confusionMatrix.getClasses();
        if (classes.isEmpty()) {
            return 0.0;
        }
        double r = 0.0;
        for (Integer i : classes) {
            r += this.recall(i);
        }
        return r / (double) classes.size();
    }

    /**
     * Recall for one class: TP / (TP + FN).
     *
     * @param i the class label
     * @return the class recall, or 0.0 when the class has no true positives
     */
    public double recall(int i) {
        double tp = this.truePositives.getCount(i);
        if (tp == 0.0) {
            return 0.0;
        }
        return tp / (tp + this.falseNegatives.getCount(i));
    }

    /**
     * Harmonic mean of the macro-averaged precision and recall.
     *
     * @return the F1 score, or 0.0 if either precision or recall is zero
     */
    public double f1() {
        return harmonicMean(this.precision(), this.recall());
    }

    /**
     * F1 score for a single class, combining that class's precision and recall.
     *
     * @param i the class label
     * @return the class F1 score, or 0.0 if either component is zero
     */
    public double f1(int i) {
        return harmonicMean(this.precision(i), this.recall(i));
    }

    /** 2PR / (P + R), defined as 0.0 when either input is zero. */
    private static double harmonicMean(double precision, double recall) {
        if (precision == 0.0 || recall == 0.0) {
            return 0.0;
        }
        return 2.0 * (precision * recall / (precision + recall));
    }
}

