package edu.stanford.nlp.ie.machinereading;

import edu.stanford.nlp.classify.LinearClassifier;
import edu.stanford.nlp.ie.machinereading.structure.AnnotationUtils;
import edu.stanford.nlp.ie.machinereading.structure.RelationMention;
import edu.stanford.nlp.ie.machinereading.structure.RelationMentionFactory;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.StringUtils;
import java.io.PrintWriter;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:edu/stanford/nlp/ie/machinereading/RelationExtractorResultsPrinter.class */
/**
 * Prints a per-label precision / recall / F1 table for relation extraction output, followed by a
 * {@code Total} row aggregated over the labels.
 *
 * <p>Every row has the tab-separated columns {@code Label, Correct, Predict, Actual, Precn, Recall,
 * F}; the three score columns are percentages with one fractional digit. Labels for which
 * {@link RelationMention#isUnrelatedLabel(String)} returns true get their own row but are left out
 * of the {@code Total} row. A ratio whose denominator is zero is reported as 0.
 */
public class RelationExtractorResultsPrinter extends ResultsPrinter {

    /** Forwarded to {@code AnnotationUtils.getAllRelations} when collecting gold relations. */
    protected boolean createUnrelatedRelations;

    /** Factory handed to {@code AnnotationUtils} when reading relations out of sentences. */
    protected final RelationMentionFactory relationMentionFactory;

    /** Width that label names are padded or trimmed to in the first column. */
    private static final int MAX_LABEL_LENGTH = 31;

    /** Separator placed between the columns of the printed table. */
    private static final String COLUMN_SEPARATOR = "\t";

    public RelationExtractorResultsPrinter(RelationMentionFactory relationMentionFactory) {
        this(relationMentionFactory, true);
    }

    public RelationExtractorResultsPrinter() {
        this(new RelationMentionFactory(), true);
    }

    public RelationExtractorResultsPrinter(boolean createUnrelatedRelations) {
        this(new RelationMentionFactory(), createUnrelatedRelations);
    }

    /**
     * @param relationMentionFactory factory used to build relation mentions; must not be null when
     *     {@link #printResults} is called
     * @param createUnrelatedRelations forwarded to {@code AnnotationUtils.getAllRelations} when the
     *     gold relations are enumerated
     */
    public RelationExtractorResultsPrinter(
            RelationMentionFactory relationMentionFactory, boolean createUnrelatedRelations) {
        this.createUnrelatedRelations = createUnrelatedRelations;
        this.relationMentionFactory = relationMentionFactory;
    }

    /**
     * Scores predicted relations against gold relations sentence by sentence and prints the table.
     *
     * <p>For every gold relation, each predicted relation over the same two arguments is counted
     * as one (predicted label, gold label) observation.
     *
     * @param printWriter destination for the table
     * @param goldStandard gold-annotated sentences
     * @param extractorOutput sentences annotated by the extractor, aligned with {@code goldStandard}
     */
    @Override
    public void printResults(
            PrintWriter printWriter, List<CoreMap> goldStandard, List<CoreMap> extractorOutput) {
        ResultsPrinter.align(goldStandard, extractorOutput);
        assert relationMentionFactory != null
                : "ERROR: RelationExtractorResultsPrinter.relationMentionFactory cannot be null in printResults!";

        ClassicCounter<Pair<String, String>> results = new ClassicCounter<>();
        ClassicCounter<String> goldLabelCount = new ClassicCounter<>();
        for (int i = 0; i < goldStandard.size(); i++) {
            for (RelationMention gold : AnnotationUtils.getAllRelations(
                    relationMentionFactory, goldStandard.get(i), createUnrelatedRelations)) {
                List<RelationMention> predictions = AnnotationUtils.getRelations(
                        relationMentionFactory, extractorOutput.get(i), gold.getArg(0), gold.getArg(1));
                goldLabelCount.incrementCount(gold.getType());
                for (RelationMention predicted : predictions) {
                    results.incrementCount(new Pair<>(predicted.getType(), gold.getType()));
                }
            }
        }
        printResultsInternal(printWriter, results, goldLabelCount);
    }

    /**
     * Prints the header, one row per gold label (sorted), and the {@code Total} row.
     *
     * @param printWriter destination for the table
     * @param results observation counts keyed by (predicted label, gold label)
     * @param labelCount gold instances per label, or {@code null} to derive it from the gold side of
     *     {@code results}; its keys are exactly the labels that get a row
     */
    private void printResultsInternal(
            PrintWriter printWriter,
            Counter<Pair<String, String>> results,
            ClassicCounter<String> labelCount) {
        ClassicCounter<String> correct = new ClassicCounter<>();
        ClassicCounter<String> predictionCount = new ClassicCounter<>();
        boolean countGoldLabels = labelCount == null;
        ClassicCounter<String> goldCount = countGoldLabels ? new ClassicCounter<>() : labelCount;

        for (Pair<String, String> predictedActual : results.keySet()) {
            String predicted = predictedActual.first;
            String actual = predictedActual.second;
            double count = results.getCount(predictedActual);
            if (predicted.equals(actual)) {
                correct.incrementCount(actual, count);
            }
            predictionCount.incrementCount(predicted, count);
            if (countGoldLabels) {
                goldCount.incrementCount(actual, count);
            }
        }

        DecimalFormat percent = new DecimalFormat();
        percent.setMaximumFractionDigits(1);
        percent.setMinimumFractionDigits(1);

        double totalCorrect = 0.0;
        double totalPredicted = 0.0;
        double totalActual = 0.0;

        printWriter.println("Label\tCorrect\tPredict\tActual\tPrecn\tRecall\tF");
        // NOTE(review): rows come only from gold labels, so predictions of a label that never
        // occurs in the gold data are neither printed nor counted in the Total precision.
        List<String> labels = new ArrayList<>(goldCount.keySet());
        Collections.sort(labels);
        for (String label : labels) {
            double numCorrect = correct.getCount(label);
            double numPredicted = predictionCount.getCount(label);
            double numActual = goldCount.getCount(label);
            printWriter.println(formatRow(
                    StringUtils.padOrTrim(label, MAX_LABEL_LENGTH),
                    numCorrect, numPredicted, numActual, percent));
            // Unrelated labels still get a row, but do not contribute to the Total row.
            if (!RelationMention.isUnrelatedLabel(label)) {
                totalCorrect += numCorrect;
                totalPredicted += numPredicted;
                totalActual += numActual;
            }
        }
        printWriter.println(formatRow("Total", totalCorrect, totalPredicted, totalActual, percent));
    }

    /**
     * Builds one table row: the label, the three raw counts, then precision, recall and F1 as
     * percentages formatted with {@code percent}.
     */
    private static String formatRow(
            String label, double correct, double predicted, double actual, DecimalFormat percent) {
        double precision = safeDivide(correct, predicted);
        // Guarded so that a label set with no (related) gold instances reports 0 instead of NaN.
        double recall = safeDivide(correct, actual);
        double f1 = precision + recall > 0.0
                ? 2.0 * precision * recall / (precision + recall)
                : 0.0;
        return label
                + COLUMN_SEPARATOR + correct
                + COLUMN_SEPARATOR + predicted
                + COLUMN_SEPARATOR + actual
                + COLUMN_SEPARATOR + percent.format(100.0 * precision)
                + COLUMN_SEPARATOR + percent.format(100.0 * recall)
                + COLUMN_SEPARATOR + percent.format(100.0 * f1);
    }

    /** Returns {@code numerator / denominator}, or 0 when the denominator is not positive. */
    private static double safeDivide(double numerator, double denominator) {
        return denominator > 0.0 ? numerator / denominator : 0.0;
    }

    /**
     * Scores parallel lists of gold and predicted labels and prints the table.
     *
     * @param printWriter destination for the table
     * @param goldStandard gold label of each instance
     * @param extractorOutput predicted label of each instance, parallel to {@code goldStandard}
     * @throws IllegalArgumentException if the two lists differ in length
     */
    @Override
    public void printResultsUsingLabels(
            PrintWriter printWriter, List<String> goldStandard, List<String> extractorOutput) {
        // Checked explicitly: with assertions disabled a mismatch would otherwise silently
        // truncate or fail with an IndexOutOfBoundsException.
        if (goldStandard.size() != extractorOutput.size()) {
            throw new IllegalArgumentException(
                    "goldStandard and extractorOutput must have the same size, but got "
                            + goldStandard.size() + " and " + extractorOutput.size());
        }
        ClassicCounter<Pair<String, String>> results = new ClassicCounter<>();
        for (int i = 0; i < goldStandard.size(); i++) {
            results.incrementCount(new Pair<>(extractorOutput.get(i), goldStandard.get(i)));
        }
        printResultsInternal(printWriter, results, null);
    }
}
