package org.apache.ctakes.relationextractor.ae.features;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.apache.ctakes.relationextractor.data.analysis.Utils;
import org.apache.ctakes.typesystem.type.syntax.WordToken;
import org.apache.ctakes.typesystem.type.textsem.IdentifiedAnnotation;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.fit.util.JCasUtil;
import org.apache.uima.jcas.JCas;
import org.cleartk.ml.Feature;

/* loaded from: input_file:org/apache/ctakes/relationextractor/ae/features/EmbeddingFeatureExtractor.class */
/**
 * Relation feature extractor that turns word-embedding vectors into features.
 *
 * <p>For a pair of arguments it emits one feature per embedding dimension for the last word
 * of each argument, the cosine similarity of those two vectors, and, when there are word
 * tokens between the arguments, the per-dimension average of the vectors of those tokens.
 * Words without an entry in the embedding map are represented by the vector stored under
 * the {@code "oov"} key.
 */
public class EmbeddingFeatureExtractor implements RelationFeaturesExtractor<IdentifiedAnnotation, IdentifiedAnnotation> {
    /** Key of the fallback vector used for words that have no vector of their own. */
    private static final String OOV_KEY = "oov";

    /** Number of dimensions, taken from the size of the "oov" vector at construction time. */
    private final int numberOfDimensions;

    /** Lower-cased word to embedding vector; the map itself is not copied. */
    private final Map<String, List<Double>> wordVectors;

    /**
     * Creates an extractor backed by the given embedding map.
     *
     * @param map word to vector mapping; must contain an {@code "oov"} entry whose size
     *            defines the number of dimensions used for every vector
     * @throws NullPointerException if {@code map} is null or has no {@code "oov"} vector
     */
    public EmbeddingFeatureExtractor(Map<String, List<Double>> map) {
        if (map == null) {
            throw new NullPointerException("word vector map must not be null");
        }
        List<Double> oovVector = map.get(OOV_KEY);
        if (oovVector == null) {
            // Same exception type as the former implicit failure, but with a usable message.
            throw new NullPointerException("word vector map must contain an \"" + OOV_KEY + "\" entry");
        }
        this.wordVectors = map;
        this.numberOfDimensions = oovVector.size();
    }

    /**
     * Builds the embedding features for one candidate argument pair.
     *
     * <p>Feature order: {@code arg1_dim_0..n-1}, {@code arg2_dim_0..n-1}, {@code arg_cos_sim},
     * and then {@code average_dim_0..n-1} only if at least one {@link WordToken} lies between
     * the two arguments.
     *
     * @param jCas CAS holding the annotations
     * @param identifiedAnnotation first argument
     * @param identifiedAnnotation2 second argument
     * @return the extracted features, in the order described above
     * @throws AnalysisEngineProcessException declared by the interface; not thrown directly here
     */
    @Override // org.apache.ctakes.relationextractor.ae.features.RelationFeaturesExtractor
    public List<Feature> extract(JCas jCas, IdentifiedAnnotation identifiedAnnotation, IdentifiedAnnotation identifiedAnnotation2) throws AnalysisEngineProcessException {
        List<Feature> features = new ArrayList<Feature>(3 * this.numberOfDimensions + 1);

        List<Double> arg1Vector = lookupVector(Utils.getLastWord(jCas, identifiedAnnotation));
        List<Double> arg2Vector = lookupVector(Utils.getLastWord(jCas, identifiedAnnotation2));

        addDimensionFeatures(features, "arg1_dim_", arg1Vector);
        addDimensionFeatures(features, "arg2_dim_", arg2Vector);
        features.add(new Feature("arg_cos_sim", Double.valueOf(computeCosineSimilarity(arg1Vector, arg2Vector))));

        List<WordToken> tokensBetween = JCasUtil.selectBetween(jCas, WordToken.class, identifiedAnnotation, identifiedAnnotation2);
        if (tokensBetween.isEmpty()) {
            return features;
        }

        // Accumulate in a primitive array: same summation order as repeated addVectors calls,
        // without allocating a fresh boxed list for every token.
        double[] sums = new double[this.numberOfDimensions];
        for (WordToken token : tokensBetween) {
            List<Double> tokenVector = lookupVector(token.getCoveredText());
            for (int dim = 0; dim < this.numberOfDimensions; dim++) {
                sums[dim] += tokenVector.get(dim).doubleValue();
            }
        }

        int tokenCount = tokensBetween.size();
        for (int dim = 0; dim < this.numberOfDimensions; dim++) {
            features.add(new Feature("average_dim_" + dim, Double.valueOf(sums[dim] / tokenCount)));
        }
        return features;
    }

    /**
     * Computes the cosine similarity of two vectors over the first
     * {@code numberOfDimensions} components.
     *
     * @param list first vector
     * @param list2 second vector
     * @return the cosine of the angle between the vectors, or {@code 0.0} when either vector
     *         has zero magnitude (the plain formula would yield NaN from 0/0)
     */
    public double computeCosineSimilarity(List<Double> list, List<Double> list2) {
        double dotProduct = 0.0d;
        double squaredNorm1 = 0.0d;
        double squaredNorm2 = 0.0d;
        for (int i = 0; i < this.numberOfDimensions; i++) {
            double a = list.get(i).doubleValue();
            double b = list2.get(i).doubleValue();
            dotProduct += a * b;
            squaredNorm1 += a * a;
            squaredNorm2 += b * b;
        }
        double denominator = Math.sqrt(squaredNorm1) * Math.sqrt(squaredNorm2);
        if (denominator == 0.0d) {
            // Similarity is undefined for a zero vector; report "no similarity" instead of NaN
            // so the value can be used as a feature.
            return 0.0d;
        }
        return dotProduct / denominator;
    }

    /**
     * Adds two vectors component-wise over the first {@code numberOfDimensions} components.
     *
     * @param list first vector
     * @param list2 second vector
     * @return a new list holding the component-wise sums; the inputs are not modified
     */
    public List<Double> addVectors(List<Double> list, List<Double> list2) {
        List<Double> sum = new ArrayList<Double>(this.numberOfDimensions);
        for (int i = 0; i < this.numberOfDimensions; i++) {
            sum.add(Double.valueOf(list.get(i).doubleValue() + list2.get(i).doubleValue()));
        }
        return sum;
    }

    /**
     * Returns the vector for the lower-cased form of {@code word}, falling back to the
     * {@code "oov"} vector when the word has no (non-null) entry.
     */
    private List<Double> lookupVector(String word) {
        // Locale.ROOT keeps the lookup key independent of the JVM default locale.
        List<Double> vector = this.wordVectors.get(word.toLowerCase(Locale.ROOT));
        return vector != null ? vector : this.wordVectors.get(OOV_KEY);
    }

    /** Appends one feature per dimension of {@code vector}, named {@code namePrefix + index}. */
    private void addDimensionFeatures(List<Feature> features, String namePrefix, List<Double> vector) {
        for (int dim = 0; dim < this.numberOfDimensions; dim++) {
            // Plain concatenation keeps the digits ASCII regardless of the default locale.
            features.add(new Feature(namePrefix + dim, vector.get(dim)));
        }
    }
}
