package org.deeplearning4j.models.rntn;

import java.text.DecimalFormat;
import java.util.Iterator;
import java.util.List;
import org.deeplearning4j.berkeley.Counter;
import org.deeplearning4j.eval.ConfusionMatrix;
import org.deeplearning4j.nn.layers.feedforward.autoencoder.recursive.Tree;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/models/rntn/RNTNEval.class */
/**
 * Accumulates classification results for an {@link RNTN} over a set of trees and
 * reports precision, recall and F1.
 *
 * <p>Usage: call {@link #eval(RNTN, List)} one or more times to fill the confusion
 * matrix, then call {@link #stats()}. The per-class true/false positive and false
 * negative counters are only populated by {@link #stats()} (or by the public
 * {@code increment*} methods), so {@link #precision()}, {@link #recall()} and
 * {@link #f1()} return 0 until one of those has run.
 */
public class RNTNEval {
    private Counter<Integer> truePositives = new Counter<>();
    private Counter<Integer> falsePositives = new Counter<>();
    private Counter<Integer> falseNegatives = new Counter<>();
    private ConfusionMatrix<Integer> confusionMatrix = new ConfusionMatrix<>();
    private static final Logger log = LoggerFactory.getLogger(RNTNEval.class);

    /**
     * Forward-propagates every tree through the network and records the
     * (gold label, predicted label) pairs of its nodes in the confusion matrix.
     *
     * @param rntn the network used to forward-propagate each tree
     * @param list the trees to evaluate
     */
    public void eval(RNTN rntn, List<Tree> list) {
        for (Tree tree : list) {
            rntn.forwardPropagateTree(tree);
            count(tree);
        }
    }

    /**
     * Recursively records one confusion-matrix entry per counted node.
     *
     * <p>A leaf, or a node without a prediction, is skipped together with its
     * whole subtree. Children are recorded before their parent.
     */
    private void count(Tree tree) {
        if (tree.isLeaf() || tree.prediction() == null) {
            return;
        }
        for (Tree child : tree.children()) {
            count(child);
        }
        int goldLabel = tree.goldLabel();
        // The predicted class is the index BLAS iamax returns for the prediction vector.
        int predictedLabel = Nd4j.getBlasWrapper().iamax(tree.prediction());
        this.confusionMatrix.add(Integer.valueOf(goldLabel), Integer.valueOf(predictedLabel));
    }

    /**
     * Adds {@code i2} to the true-positive count of class {@code i}.
     *
     * @param i  the class label
     * @param i2 the amount to add
     */
    public void incrementTruePositives(int i, int i2) {
        this.truePositives.incrementCount(Integer.valueOf(i), i2);
    }

    /**
     * Adds {@code i2} to the false-positive count of class {@code i}.
     *
     * @param i  the class label
     * @param i2 the amount to add
     */
    public void incrementFalsePositives(int i, int i2) {
        this.falsePositives.incrementCount(Integer.valueOf(i), i2);
    }

    /**
     * Adds {@code i2} to the false-negative count of class {@code i}.
     *
     * @param i  the class label
     * @param i2 the amount to add
     */
    public void incrementFalseNegatives(int i, int i2) {
        this.falseNegatives.incrementCount(Integer.valueOf(i), i2);
    }

    /**
     * Rebuilds the per-class counters from the confusion matrix and renders a
     * textual report listing every non-zero (actual, predicted) cell followed by
     * the macro-averaged precision, recall and F1.
     *
     * <p>The counters are reset on every call, so calling this method repeatedly
     * (or again after further {@link #eval(RNTN, List)} calls) never counts a
     * matrix cell twice. Counts added manually through the {@code increment*}
     * methods before this call are discarded.
     *
     * @return the formatted report
     */
    public String stats() {
        // Start from a clean slate: the confusion matrix is cumulative, so adding
        // onto counters left over from an earlier call would double count.
        this.truePositives = new Counter<>();
        this.falsePositives = new Counter<>();
        this.falseNegatives = new Counter<>();

        StringBuilder report = new StringBuilder().append("\n");
        List<Integer> classes = this.confusionMatrix.getClasses();
        for (Integer actual : classes) {
            for (Integer predicted : classes) {
                int count = this.confusionMatrix.getCount(actual, predicted);
                if (count != 0) {
                    report.append("\nActual Class ").append(actual)
                            .append(" was predicted with Predicted ").append(predicted)
                            .append(" with count ").append(count).append(" times\n");
                }
                // equals, not ==: the labels are boxed Integers and reference
                // equality only holds inside the small-value box cache.
                if (actual.equals(predicted)) {
                    incrementTruePositives(actual.intValue(), count);
                } else {
                    // An off-diagonal cell is a false positive for the predicted
                    // class and a false negative for the actual class.
                    incrementFalsePositives(predicted.intValue(), count);
                    incrementFalseNegatives(actual.intValue(), count);
                }
            }
        }
        DecimalFormat scoreFormat = new DecimalFormat("#.####");
        report.append("\n==========================Scores========================================");
        report.append("\n Precision: ").append(scoreFormat.format(precision()));
        report.append("\n Recall: ").append(scoreFormat.format(recall()));
        report.append("\n F1 Score: ").append(scoreFormat.format(f1()));
        report.append("\n===========================================================================");
        return report.toString();
    }

    /**
     * Macro-averaged precision: the unweighted mean of {@link #precision(int)}
     * over all classes in the confusion matrix.
     *
     * @return the mean per-class precision, or 0 when the matrix has no classes
     */
    public double precision() {
        List<Integer> classes = this.confusionMatrix.getClasses();
        if (classes.isEmpty()) {
            // Avoid 0/0 = NaN when nothing has been evaluated yet.
            return 0.0d;
        }
        double sum = 0.0d;
        for (Integer label : classes) {
            sum += precision(label.intValue());
        }
        return sum / classes.size();
    }

    /**
     * Precision of a single class: TP / (TP + FP).
     *
     * @param i the class label
     * @return the precision, or 0 when the class has no true positives
     */
    public double precision(int i) {
        Integer label = Integer.valueOf(i);
        double tp = this.truePositives.getCount(label);
        if (tp == 0.0d) {
            return 0.0d;
        }
        return tp / (tp + this.falsePositives.getCount(label));
    }

    /**
     * Macro-averaged recall: the unweighted mean of {@link #recall(int)} over
     * all classes in the confusion matrix.
     *
     * @return the mean per-class recall, or 0 when the matrix has no classes
     */
    public double recall() {
        List<Integer> classes = this.confusionMatrix.getClasses();
        if (classes.isEmpty()) {
            // Avoid 0/0 = NaN when nothing has been evaluated yet.
            return 0.0d;
        }
        double sum = 0.0d;
        for (Integer label : classes) {
            sum += recall(label.intValue());
        }
        return sum / classes.size();
    }

    /**
     * Recall of a single class: TP / (TP + FN).
     *
     * @param i the class label
     * @return the recall, or 0 when the class has no true positives
     */
    public double recall(int i) {
        Integer label = Integer.valueOf(i);
        double tp = this.truePositives.getCount(label);
        if (tp == 0.0d) {
            return 0.0d;
        }
        return tp / (tp + this.falseNegatives.getCount(label));
    }

    /**
     * F1 score computed from the macro-averaged {@link #precision()} and
     * {@link #recall()}.
     *
     * @return the harmonic mean of precision and recall, or 0 when either is 0
     */
    public double f1() {
        return harmonicMean(precision(), recall());
    }

    /**
     * F1 score of a single class, computed from that class's own precision and
     * recall.
     *
     * @param i the class label
     * @return the harmonic mean of {@link #precision(int)} and
     *         {@link #recall(int)}, or 0 when either is 0
     */
    public double f1(int i) {
        return harmonicMean(precision(i), recall(i));
    }

    /** Harmonic mean of two scores, defined as 0 when either score is 0. */
    private static double harmonicMean(double precision, double recall) {
        if (precision == 0.0d || recall == 0.0d) {
            return 0.0d;
        }
        return 2.0d * ((precision * recall) / (precision + recall));
    }
}
