package org.apache.ctakes.temporal.ae.feature;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.apache.ctakes.constituency.parser.treekernel.TreeExtractor;
import org.apache.ctakes.constituency.parser.util.AnnotationTreeUtils;
import org.apache.ctakes.core.resource.FileLocator;
import org.apache.ctakes.relationextractor.ae.features.RelationFeaturesExtractor;
import org.apache.ctakes.typesystem.type.syntax.TreebankNode;
import org.apache.ctakes.typesystem.type.textsem.IdentifiedAnnotation;
import org.apache.ctakes.utils.distsem.WordEmbeddings;
import org.apache.ctakes.utils.distsem.WordVector;
import org.apache.ctakes.utils.distsem.WordVectorReader;
import org.apache.uima.analysis_engine.AnalysisEngineProcessException;
import org.apache.uima.jcas.JCas;
import org.cleartk.ml.Feature;
import org.cleartk.ml.feature.extractor.CleartkExtractorException;

/* loaded from: input_file:org/apache/ctakes/temporal/ae/feature/RelationSyntacticETEmbeddingFeatureExtractor.class */
/**
 * Relation feature extractor that describes the syntactic connection between two annotations
 * with dense features: the constituency-tree path(s) linking the two annotation nodes are looked
 * up in a table of pre-trained path embeddings and the resulting vectors are averaged.
 *
 * <p>One feature named {@code syntactic_average_dim_<i>} is emitted per embedding dimension.
 */
public class RelationSyntacticETEmbeddingFeatureExtractor implements RelationFeaturesExtractor<IdentifiedAnnotation, IdentifiedAnnotation> {
    /** Key requested from the embeddings when a path and all of its truncations are unknown. */
    private static final String UNKNOWN_KEY = "<unk>";

    /** Separator placed between node types in a path string. */
    private static final String PATH_SEPARATOR = "-";

    /** Prefix of every emitted feature name; the dimension index is appended to it. */
    private static final String FEATURE_PREFIX = "syntactic_average_dim_";

    /** Dimensionality reported by the loaded embeddings. */
    private final int numberOfDimensions;

    /** Embedding table keyed by path string. */
    private final WordEmbeddings paths;

    /**
     * Loads the path embeddings from the given resource.
     *
     * @param str location of the embeddings resource, resolved through {@code FileLocator}
     * @throws CleartkExtractorException if the resource cannot be opened or read; the original
     *         {@link IOException} is kept as the cause
     */
    public RelationSyntacticETEmbeddingFeatureExtractor(String str) throws CleartkExtractorException {
        WordEmbeddings embeddings;
        // try-with-resources guarantees the stream is released even when reading fails.
        // NOTE(review): the reader may already close the stream itself; closing twice is harmless.
        try (java.io.InputStream in = FileLocator.getAsStream(str)) {
            embeddings = WordVectorReader.getEmbeddings(in);
        } catch (IOException e) {
            // The exception is propagated with its cause, so no separate stack-trace dump here.
            throw new CleartkExtractorException(e);
        }
        this.paths = embeddings;
        this.numberOfDimensions = embeddings.getDimensionality();
    }

    /**
     * Extracts the averaged path-embedding features for a pair of annotations.
     *
     * @param jCas the CAS holding the annotations and their parse trees
     * @param identifiedAnnotation the first argument of the candidate relation
     * @param identifiedAnnotation2 the second argument of the candidate relation
     * @return one feature per embedding dimension, or an empty list when no tree (or no tree
     *         node) is available for the annotations
     * @throws AnalysisEngineProcessException declared by the extractor contract
     */
    public List<Feature> extract(JCas jCas, IdentifiedAnnotation identifiedAnnotation, IdentifiedAnnotation identifiedAnnotation2) throws AnalysisEngineProcessException {
        List<Feature> features = new ArrayList<>();
        // The copy is only used as an availability check for the first annotation's tree.
        if (AnnotationTreeUtils.getTreeCopy(jCas, AnnotationTreeUtils.getAnnotationTree(jCas, identifiedAnnotation)) == null) {
            return features;
        }
        TreebankNode firstNode = AnnotationTreeUtils.annotationNode(jCas, identifiedAnnotation);
        TreebankNode secondNode = AnnotationTreeUtils.annotationNode(jCas, identifiedAnnotation2);
        // NOTE(review): whether annotationNode can return null is not visible from this class;
        // fall back to "no features" (as for a missing tree) instead of risking an NPE.
        if (firstNode == null || secondNode == null) {
            return features;
        }

        List<String> treePaths = collectPaths(firstNode, secondNode);

        // Sum the vectors of all paths in a primitive accumulator, then average.
        double[] sums = new double[this.numberOfDimensions];
        for (String treePath : treePaths) {
            WordVector vector = lookupVector(treePath);
            for (int i = 0; i < this.numberOfDimensions; i++) {
                sums[i] += vector.getValue(i);
            }
        }
        int pathCount = treePaths.size();
        for (int i = 0; i < this.numberOfDimensions; i++) {
            // Plain concatenation keeps feature names independent of the JVM default locale.
            features.add(new Feature(FEATURE_PREFIX + i, Double.valueOf(sums[i] / pathCount)));
        }
        return features;
    }

    /**
     * Builds the path string(s) connecting two tree nodes.
     *
     * <p>When one node's span covers the other's, a single path from the covered node up to the
     * covering node is produced. Otherwise two paths are produced, one from each node up to the
     * node returned by {@code TreeExtractor.getLCA}.
     */
    private static List<String> collectPaths(TreebankNode firstNode, TreebankNode secondNode) {
        List<String> result = new ArrayList<>(2);
        boolean firstCoversSecond = firstNode.getBegin() <= secondNode.getBegin() && firstNode.getEnd() >= secondNode.getEnd();
        boolean secondCoversFirst = secondNode.getBegin() <= firstNode.getBegin() && secondNode.getEnd() >= firstNode.getEnd();
        if (firstCoversSecond) {
            result.add(getPathBetweenNodes(secondNode, firstNode, ""));
        } else if (secondCoversFirst) {
            result.add(getPathBetweenNodes(firstNode, secondNode, ""));
        } else {
            TreebankNode lca = TreeExtractor.getLCA(firstNode, secondNode);
            result.add(getPathBetweenNodes(firstNode, lca, ""));
            result.add(getPathBetweenNodes(secondNode, lca, ""));
        }
        return result;
    }

    /**
     * Finds the embedding for a path, backing off to shorter paths: while the key is unknown its
     * last segment is dropped; when nothing is left to drop the {@code <unk>} entry is requested.
     */
    private WordVector lookupVector(String treePath) {
        String key = treePath;
        while (!this.paths.containsKey(key)) {
            String shorter = removeTail(key);
            if (shorter == null) {
                return this.paths.getVector(UNKNOWN_KEY);
            }
            key = shorter;
        }
        return this.paths.getVector(key);
    }

    /**
     * Drops the last separator-delimited segment of a path.
     *
     * @return the shortened path, or {@code null} when there is no separator after position 0
     */
    private static String removeTail(String str) {
        int lastSeparator = str.lastIndexOf(PATH_SEPARATOR);
        return lastSeparator > 0 ? str.substring(0, lastSeparator) : null;
    }

    /**
     * Walks from {@code start} towards the root, prepending each visited node type to the path.
     * The walk ends when the parent is {@code target} (its type is prepended too) or when a node
     * without a parent is reached.
     *
     * @param start node the walk begins at
     * @param target ancestor at which the walk stops; may be {@code null}, in which case the
     *        walk runs up to the root
     * @param suffix path text already collected below {@code start}; {@code ""} for none
     */
    private static String getPathBetweenNodes(TreebankNode start, TreebankNode target, String suffix) {
        TreebankNode current = start;
        String path = suffix;
        while (true) {
            path = "".equals(path) ? current.getNodeType() : current.getNodeType() + PATH_SEPARATOR + path;
            TreebankNode parent = current.getParent();
            if (parent == null) {
                return path;
            }
            if (parent == target) {
                return parent.getNodeType() + PATH_SEPARATOR + path;
            }
            current = parent;
        }
    }

    /**
     * Computes a smoothed cosine similarity over the first {@code numberOfDimensions} values of
     * two vectors. Each squared norm starts at 0.01, so the denominator is never zero.
     *
     * @param wordVector first vector
     * @param wordVector2 second vector
     * @return dot product divided by the product of the smoothed norms
     */
    public double computeCosineSimilarity(WordVector wordVector, WordVector wordVector2) {
        double dot = 0.0d;
        double squaredNorm1 = 0.01d;
        double squaredNorm2 = 0.01d;
        for (int i = 0; i < this.numberOfDimensions; i++) {
            dot += wordVector.getValue(i) * wordVector2.getValue(i);
            squaredNorm1 += Math.pow(wordVector.getValue(i), 2.0d);
            squaredNorm2 += Math.pow(wordVector2.getValue(i), 2.0d);
        }
        return dot / (Math.sqrt(squaredNorm1) * Math.sqrt(squaredNorm2));
    }

    /**
     * Computes a smoothed cosine similarity over the first {@code numberOfDimensions} entries of
     * two lists. Each squared norm starts at 0.01, so the denominator is never zero.
     *
     * @param list first vector; must hold at least {@code numberOfDimensions} entries
     * @param list2 second vector; must hold at least {@code numberOfDimensions} entries
     * @return dot product divided by the product of the smoothed norms
     */
    public double computeCosineSimilarity(List<Double> list, List<Double> list2) {
        double dot = 0.0d;
        double squaredNorm1 = 0.01d;
        double squaredNorm2 = 0.01d;
        for (int i = 0; i < this.numberOfDimensions; i++) {
            double a = list.get(i).doubleValue();
            double b = list2.get(i).doubleValue();
            dot += a * b;
            squaredNorm1 += Math.pow(a, 2.0d);
            squaredNorm2 += Math.pow(b, 2.0d);
        }
        return dot / (Math.sqrt(squaredNorm1) * Math.sqrt(squaredNorm2));
    }

    /**
     * Adds a vector to a running sum element by element.
     *
     * @param list current sums; must hold at least {@code numberOfDimensions} entries
     * @param wordVector vector to add
     * @return a new list with {@code numberOfDimensions} entries; {@code list} is not modified
     */
    public List<Double> addVectors(List<Double> list, WordVector wordVector) {
        List<Double> result = new ArrayList<>(this.numberOfDimensions);
        for (int i = 0; i < this.numberOfDimensions; i++) {
            result.add(Double.valueOf(list.get(i).doubleValue() + wordVector.getValue(i)));
        }
        return result;
    }
}
