package edu.stanford.nlp.coref.neural;

import edu.stanford.nlp.coref.CorefAlgorithm;
import edu.stanford.nlp.coref.CorefProperties;
import edu.stanford.nlp.coref.CorefUtils;
import edu.stanford.nlp.coref.data.Dictionaries;
import edu.stanford.nlp.coref.data.Document;
import edu.stanford.nlp.coref.data.Mention;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.neural.Embedding;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.logging.Redwood;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import org.ejml.simple.SimpleMatrix;

/* loaded from: input_file:edu/stanford/nlp/coref/neural/NeuralCorefAlgorithm.class */
/**
 * Neural mention-ranking coreference algorithm.
 *
 * <p>Every mention is encoded once with a neural model. Then, for each mention that has candidate
 * antecedents (as selected by {@link CorefUtils#heuristicFilter}), the candidate with the highest
 * pairwise score is linked via {@link CorefUtils#mergeCoreferenceClusters}. Linking happens only if
 * that score strictly exceeds the mention's greedyness-adjusted anaphoricity score.
 */
public class NeuralCorefAlgorithm implements CorefAlgorithm {
    private static final Redwood.RedwoodChannels log = Redwood.channels(NeuralCorefAlgorithm.class);

    /**
     * Multiplier applied to {@code (greedyness - 0.5)} before it is subtracted from a mention's
     * anaphoricity score. With greedyness 0.5 there is no offset; larger greedyness lowers the bar
     * an antecedent must clear, so more links are made.
     */
    private static final double GREEDYNESS_SCALE = 50.0;

    /** Linking bias read from the properties; values above 0.5 make linking more likely. */
    private final double greedyness;
    /** Passed to {@link CorefUtils#heuristicFilter} to bound candidate antecedents. */
    private final int maxMentionDistance;
    /** Passed to {@link CorefUtils#heuristicFilter} for candidates that share a string match. */
    private final int maxMentionDistanceWithStringMatch;
    private final CategoricalFeatureExtractor featureExtractor;
    private final EmbeddingExtractor embeddingExtractor;
    private final NeuralCorefModel model;

    /**
     * Loads the coref model and pretrained embeddings named in {@code properties}.
     *
     * @param properties configuration (model path, embeddings path, distance limits, greedyness)
     * @param dictionaries dictionaries used by the categorical feature extractor
     */
    public NeuralCorefAlgorithm(Properties properties, Dictionaries dictionaries) {
        this.greedyness = NeuralCorefProperties.greedyness(properties);
        this.maxMentionDistance = CorefProperties.maxMentionDistance(properties);
        this.maxMentionDistanceWithStringMatch =
            CorefProperties.maxMentionDistanceWithStringMatch(properties);
        // The model must be loaded before the embedding extractor, which reuses its word embeddings.
        this.model = (NeuralCorefModel) IOUtils.readObjectAnnouncingTimingFromURLOrClasspathOrFileSystem(
            log, "Loading coref model", NeuralCorefProperties.modelPath(properties));
        Embedding pretrainedEmbeddings =
            (Embedding) IOUtils.readObjectAnnouncingTimingFromURLOrClasspathOrFileSystem(
                log, "Loading coref embeddings", NeuralCorefProperties.pretrainedEmbeddingsPath(properties));
        this.embeddingExtractor = new EmbeddingExtractor(
            CorefProperties.conll(properties), pretrainedEmbeddings, this.model.getWordEmbeddings());
        this.featureExtractor = new CategoricalFeatureExtractor(properties, dictionaries);
    }

    /**
     * Resolves coreference in {@code document}, merging coreference clusters in place.
     *
     * <p>For each mention with candidate antecedents, the candidate with the highest pairwise
     * score is merged with the mention. This happens only if that score strictly exceeds the
     * mention's anaphoricity score minus the greedyness offset. Ties keep the earlier candidate.
     *
     * @param document the document to resolve; its coreference clusters are modified
     */
    @Override // edu.stanford.nlp.coref.CorefAlgorithm
    public void runCoref(Document document) {
        List<Mention> sortedMentions = CorefUtils.getSortedMentions(document);
        Map<Integer, List<Mention>> mentionsByHeadIndex = groupByHeadIndex(sortedMentions);

        // Encode each mention once; the results are reused for every pair it participates in.
        SimpleMatrix documentEmbedding = embeddingExtractor.getDocumentEmbedding(document);
        Map<Integer, SimpleMatrix> antecedentEmbeddings = new HashMap<>();
        Map<Integer, SimpleMatrix> anaphorEmbeddings = new HashMap<>();
        ClassicCounter<Integer> anaphoricityScores = new ClassicCounter<>();
        for (Mention mention : sortedMentions) {
            SimpleMatrix mentionEmbedding =
                embeddingExtractor.getMentionEmbeddings(mention, documentEmbedding);
            antecedentEmbeddings.put(mention.mentionID, model.getAntecedentEmbedding(mentionEmbedding));
            anaphorEmbeddings.put(mention.mentionID, model.getAnaphorEmbedding(mentionEmbedding));
            anaphoricityScores.incrementCount(mention.mentionID, model.getAnaphoricityScore(
                mentionEmbedding,
                featureExtractor.getAnaphoricityFeatures(mention, document, mentionsByHeadIndex)));
        }

        Map<Integer, List<Integer>> candidateAntecedents = CorefUtils.heuristicFilter(
            sortedMentions, maxMentionDistance, maxMentionDistanceWithStringMatch);
        double greedynessOffset = GREEDYNESS_SCALE * (greedyness - 0.5);
        for (Map.Entry<Integer, List<Integer>> entry : candidateAntecedents.entrySet()) {
            int anaphor = entry.getKey();
            // A candidate must beat the mention's own (offset) anaphoricity score to be linked.
            double bestScore = anaphoricityScores.getCount(anaphor) - greedynessOffset;
            Integer bestAntecedent = null;
            for (int candidate : entry.getValue()) {
                double score = model.getPairwiseScore(
                    antecedentEmbeddings.get(candidate),
                    anaphorEmbeddings.get(anaphor),
                    featureExtractor.getPairFeatures(
                        new Pair<>(candidate, anaphor), document, mentionsByHeadIndex));
                if (score > bestScore) { // strict: on ties the earlier candidate wins
                    bestScore = score;
                    bestAntecedent = candidate;
                }
            }
            if (bestAntecedent != null) {
                CorefUtils.mergeCoreferenceClusters(new Pair<>(bestAntecedent, anaphor), document);
            }
        }
    }

    /** Groups mentions by {@code headIndex}, keeping their input order within each group. */
    private static Map<Integer, List<Mention>> groupByHeadIndex(List<Mention> mentions) {
        Map<Integer, List<Mention>> byHead = new HashMap<>();
        for (Mention mention : mentions) {
            byHead.computeIfAbsent(mention.headIndex, k -> new ArrayList<>()).add(mention);
        }
        return byHead;
    }
}
