package edu.stanford.nlp.coref.statistical;

import edu.stanford.nlp.coref.statistical.Clusterer;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.util.Pair;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/* loaded from: input_file:edu/stanford/nlp/coref/statistical/EvalUtils.class */
public class EvalUtils {

    /* loaded from: input_file:edu/stanford/nlp/coref/statistical/EvalUtils$AbstractEvaluator.class */
    public static abstract class AbstractEvaluator implements Evaluator {
        public double pNum;
        public double pDen;
        public double rNum;
        public double rDen;

        @Override // edu.stanford.nlp.coref.statistical.EvalUtils.Evaluator
        public void update(List<List<Integer>> list, List<Clusterer.Cluster> list2, Map<Integer, List<Integer>> map, Map<Integer, Clusterer.Cluster> map2) {
            List<List<Integer>> list3 = (List) list2.stream().map(cluster -> {
                return cluster.mentions;
            }).collect(Collectors.toList());
            Map<Integer, List<Integer>> map3 = (Map) map2.entrySet().stream().collect(Collectors.toMap((v0) -> {
                return v0.getKey();
            }, entry -> {
                return ((Clusterer.Cluster) entry.getValue()).mentions;
            }));
            Pair<Double, Double> score = getScore(list3, map);
            Pair<Double, Double> score2 = getScore(list, map3);
            this.pNum += score.first.doubleValue();
            this.pDen += score.second.doubleValue();
            this.rNum += score2.first.doubleValue();
            this.rDen += score2.second.doubleValue();
        }

        @Override // edu.stanford.nlp.coref.statistical.EvalUtils.Evaluator
        public double getF1() {
            return EvalUtils.f1(this.pNum, this.pDen, this.rNum, this.rDen);
        }

        public double getRecall() {
            if (this.pNum == 0.0d) {
                return 0.0d;
            }
            return this.pNum / this.pDen;
        }

        public double getPrecision() {
            if (this.rNum == 0.0d) {
                return 0.0d;
            }
            return this.rNum / this.rDen;
        }

        public abstract Pair<Double, Double> getScore(List<List<Integer>> list, Map<Integer, List<Integer>> map);
    }

    /* loaded from: input_file:edu/stanford/nlp/coref/statistical/EvalUtils$B3Evaluator.class */
    /**
     * B-cubed metric as computed here: singleton clusters being evaluated are skipped, and
     * overlap with singleton reference clusters earns no credit.
     */
    public static class B3Evaluator extends AbstractEvaluator {
        /**
         * Computes the B-cubed numerator and denominator for {@code clusters}.
         *
         * <p>For each cluster c with more than one mention, mentions are grouped by the reference
         * cluster they map to. The squared size of every group whose reference cluster is not a
         * singleton is summed and divided by |c|. Mentions missing from the map add nothing to the
         * numerator but still count toward |c|. Clusters of size 0 or 1 are skipped. Previously an
         * empty cluster added 0/0 (NaN) to the numerator.
         *
         * @return (sum of per-cluster credit, total mentions in the scored clusters)
         */
        @Override // edu.stanford.nlp.coref.statistical.EvalUtils.AbstractEvaluator
        public Pair<Double, Double> getScore(List<List<Integer>> clusters,
                                             Map<Integer, List<Integer>> mentionToReference) {
            double numerator = 0.0d;
            int denominator = 0;
            for (List<Integer> cluster : clusters) {
                if (cluster.size() <= 1) {
                    continue;
                }
                ClassicCounter<List<Integer>> overlap = new ClassicCounter<>();
                for (int mention : cluster) {
                    List<Integer> reference = mentionToReference.get(mention);
                    if (reference != null) {
                        overlap.incrementCount(reference);
                    }
                }
                double correct = 0.0d;
                for (Map.Entry<List<Integer>, Double> entry : overlap.entrySet()) {
                    if (entry.getKey().size() != 1) {
                        double count = entry.getValue();
                        correct += count * count;
                    }
                }
                numerator += correct / cluster.size();
                denominator += cluster.size();
            }
            return new Pair<>(numerator, (double) denominator);
        }
    }

    /* loaded from: input_file:edu/stanford/nlp/coref/statistical/EvalUtils$CombinedEvaluator.class */
    /**
     * Evaluator whose F1 is {@code mucWeight * F1(MUC) + (1 - mucWeight) * F1(B3)}.
     *
     * <p>When the weight is exactly 0 the MUC evaluator is never updated or consulted. When it is
     * exactly 1 the same holds for the B-cubed evaluator.
     */
    public static class CombinedEvaluator implements Evaluator {
        private final B3Evaluator b3Evaluator = new B3Evaluator();
        private final MUCEvaluator mucEvaluator = new MUCEvaluator();
        private final double mucWeight;

        /** @param mucWeight weight given to MUC F1; B-cubed F1 receives {@code 1 - mucWeight} */
        public CombinedEvaluator(double mucWeight) {
            this.mucWeight = mucWeight;
        }

        @Override // edu.stanford.nlp.coref.statistical.EvalUtils.Evaluator
        public void update(List<List<Integer>> gold, List<Clusterer.Cluster> clusters,
                           Map<Integer, List<Integer>> mentionToGold,
                           Map<Integer, Clusterer.Cluster> mentionToSystem) {
            boolean needsB3 = mucWeight != 1.0d;
            boolean needsMuc = mucWeight != 0.0d;
            if (needsB3) {
                b3Evaluator.update(gold, clusters, mentionToGold, mentionToSystem);
            }
            if (needsMuc) {
                mucEvaluator.update(gold, clusters, mentionToGold, mentionToSystem);
            }
        }

        @Override // edu.stanford.nlp.coref.statistical.EvalUtils.Evaluator
        public double getF1() {
            double mucPart = 0.0d;
            if (mucWeight != 0.0d) {
                mucPart = mucWeight * mucEvaluator.getF1();
            }
            double b3Part = 0.0d;
            if (mucWeight != 1.0d) {
                b3Part = (1.0d - mucWeight) * b3Evaluator.getF1();
            }
            return mucPart + b3Part;
        }
    }

    /* loaded from: input_file:edu/stanford/nlp/coref/statistical/EvalUtils$Evaluator.class */
    /** Accumulates a coreference score over one or more documents. */
    public interface Evaluator {
        /**
         * Adds one document's clustering to the running totals.
         *
         * @param gold            reference clusters, each a list of mention ids
         * @param clusters        system clusters
         * @param mentionToGold   mention id to the reference cluster containing it
         * @param mentionToSystem mention id to the system cluster containing it
         */
        void update(List<List<Integer>> gold, List<Clusterer.Cluster> clusters,
                    Map<Integer, List<Integer>> mentionToGold,
                    Map<Integer, Clusterer.Cluster> mentionToSystem);

        /** Returns the F1 score of everything accumulated so far. */
        double getF1();
    }

    /* loaded from: input_file:edu/stanford/nlp/coref/statistical/EvalUtils$MUCEvaluator.class */
    /** Link-based MUC metric. */
    public static class MUCEvaluator extends AbstractEvaluator {
        /**
         * Computes MUC link counts for {@code clusters}.
         *
         * <p>A cluster of k mentions contributes k - 1 links to the denominator. It contributes
         * (mapped mentions - distinct reference clusters they map to) to the numerator. Mentions
         * missing from the map therefore each act as their own partition. Empty clusters are
         * skipped; previously they subtracted one link from the denominator.
         *
         * @return (correct links, total links)
         */
        @Override // edu.stanford.nlp.coref.statistical.EvalUtils.AbstractEvaluator
        public Pair<Double, Double> getScore(List<List<Integer>> clusters,
                                             Map<Integer, List<Integer>> mentionToReference) {
            int correctLinks = 0;
            int totalLinks = 0;
            for (List<Integer> cluster : clusters) {
                if (cluster.isEmpty()) {
                    continue;
                }
                totalLinks += cluster.size() - 1;
                int mapped = 0;
                HashSet<List<Integer>> partitions = new HashSet<>();
                for (int mention : cluster) {
                    List<Integer> reference = mentionToReference.get(mention);
                    if (reference != null) {
                        mapped++;
                        partitions.add(reference);
                    }
                }
                correctLinks += mapped - partitions.size();
            }
            return new Pair<>((double) correctLinks, (double) totalLinks);
        }
    }

    /**
     * Scores a single clustering with a fresh {@link CombinedEvaluator}.
     *
     * @param mucWeight       weight of MUC F1; B-cubed F1 receives {@code 1 - mucWeight}
     * @param gold            reference clusters, each a list of mention ids
     * @param clusters        system clusters
     * @param mentionToGold   mention id to its reference cluster
     * @param mentionToSystem mention id to its system cluster
     * @return the weighted combination of MUC and B-cubed F1
     */
    public static double getCombinedF1(double mucWeight, List<List<Integer>> gold,
                                       List<Clusterer.Cluster> clusters,
                                       Map<Integer, List<Integer>> mentionToGold,
                                       Map<Integer, Clusterer.Cluster> mentionToSystem) {
        Evaluator evaluator = new CombinedEvaluator(mucWeight);
        evaluator.update(gold, clusters, mentionToGold, mentionToSystem);
        return evaluator.getF1();
    }

    /**
     * Returns the harmonic mean of precision {@code pNum / pDen} and recall {@code rNum / rDen}.
     *
     * <p>A zero numerator yields a ratio of 0 without dividing, which avoids 0/0. The result is 0
     * whenever precision is 0.
     */
    public static double f1(double pNum, double pDen, double rNum, double rDen) {
        double precision = (pNum != 0.0d) ? pNum / pDen : 0.0d;
        double recall = (rNum != 0.0d) ? rNum / rDen : 0.0d;
        return precision == 0.0d ? 0.0d : 2.0d * precision * recall / (precision + recall);
    }
}
