package org.cleartk.ml.tksvmlight.model;

import com.google.common.annotations.Beta;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
import org.apache.commons.lang.NotImplementedException;
import org.cleartk.ml.tksvmlight.TreeFeatureVector;
import org.cleartk.util.treebank.TopTreebankNode;
import org.cleartk.util.treebank.TreebankFormatParser;
import org.cleartk.util.treebank.TreebankNode;

/**
 * Computes tree kernels between the tree forests held by two {@link TreeFeatureVector}s.
 *
 * <p>Two per-tree kernels are implemented:
 * <ul>
 * <li>{@link KernelType#SUBSET}: trees are parsed with {@link TreebankFormatParser} and compared
 * by summing, over every node pair, a lambda-weighted count of common subtrees.</li>
 * <li>{@link KernelType#PARTIAL}: trees are parsed with {@code SimpleDepTree.fromString} and
 * compared with a partial-tree kernel weighted by {@code lambda} and {@code mu}.</li>
 * </ul>
 * Only {@link ForestSumMethod#SEQUENTIAL} forest summation is implemented.
 *
 * <p>Parsed trees, per-tree normalization constants and (when {@link #getUseCache()} is true)
 * partial-tree node-pair deltas are cached per instance and never evicted. The parse and delta
 * caches are plain {@link HashMap}s without synchronization, so an instance should not be used
 * from several threads concurrently without external locking.
 */
@Beta
public class TreeKernel {
    /** Caps the number of child-sequence terms summed per node pair by the partial-tree kernel. */
    public static final int MAX_CHILDREN = 5;

    /** Suggested decay factor; callers pass the actual lambda to the constructor. */
    public static final double LAMBDA_DEFAULT = 0.4d;

    /** Weight applied to every matching node pair by the partial-tree kernel (always used). */
    public static final double MU_DEFAULT = 0.4d;

    /** Number of powers lambda^0 .. lambda^(n-1) precomputed by {@link #initExponents()}. */
    private static final int LAMBDA_POWER_CACHE_SIZE = 10;

    /** Kinds of forest summation; {@link #evaluate} supports only {@link #SEQUENTIAL}. */
    public enum ForestSumMethod {
        SEQUENTIAL,
        ALL_PAIRS
    }

    /** Kinds of tree kernel; only {@link #SUBSET} and {@link #PARTIAL} are implemented. */
    public enum KernelType {
        SUBSET,
        SUBTREE,
        SUBSET_BOW,
        PARTIAL
    }

    private final double lambda;
    private final double lambdaSquared;
    private final double mu = MU_DEFAULT;
    private final boolean normalize;
    private final ForestSumMethod sumMethod;
    private final KernelType kernelType;
    private final double[] lambdaPowers = new double[LAMBDA_POWER_CACHE_SIZE];
    private boolean useCache = true;

    /** Self-similarity K(t, t) per serialized tree string, used only when normalizing. */
    private final ConcurrentHashMap<String, Double> normalizers = new ConcurrentHashMap<>();

    /** Parsed constituency trees, keyed by their serialized form (SUBSET kernel). */
    HashMap<String, TopTreebankNode> trees = new HashMap<>();

    /** Parsed dependency trees, keyed by their serialized form (PARTIAL kernel). */
    HashMap<String, SimpleDepTree> depTrees = new HashMap<>();

    /** Partial-tree kernel delta per (node, node) pair; consulted only when useCache is true. */
    HashMap<SimpleDepTree, HashMap<SimpleDepTree, Double>> cache = new HashMap<>();

    /**
     * Creates a tree kernel.
     *
     * @param lambda decay factor applied to matched structure (see {@link #LAMBDA_DEFAULT})
     * @param forestSumMethod how per-tree kernels are combined; {@link #evaluate} only supports
     *     {@link ForestSumMethod#SEQUENTIAL}
     * @param kernelType which per-tree kernel to compute; only SUBSET and PARTIAL are supported
     * @param normalize if true, each per-tree value K(a, b) is divided by sqrt(K(a, a) * K(b, b))
     */
    public TreeKernel(double lambda, ForestSumMethod forestSumMethod, KernelType kernelType,
            boolean normalize) {
        this.lambda = lambda;
        this.lambdaSquared = lambda * lambda;
        this.sumMethod = forestSumMethod;
        this.kernelType = kernelType;
        this.normalize = normalize;
        initExponents();
    }

    /** Recomputes the cached powers lambda^0 .. lambda^9 used by the partial-tree kernel. */
    public void initExponents() {
        for (int exponent = 0; exponent < this.lambdaPowers.length; exponent++) {
            this.lambdaPowers[exponent] = Math.pow(this.lambda, exponent);
        }
    }

    /**
     * Computes the forest kernel between two feature vectors by sequential summation: the i-th
     * tree of the first vector is compared with the i-th tree of the second, pairing trees by the
     * iteration order of {@code getTrees().values()}.
     *
     * <p>When the forests differ in size, only positions present in both are summed, so the two
     * arguments are treated alike and a size mismatch never causes an index error.
     *
     * @param vector1 first feature vector
     * @param vector2 second feature vector
     * @return the sum of the per-position tree kernel values (0 for empty forests)
     * @throws NotImplementedException if the summation method is not SEQUENTIAL, or if a tree
     *     pair is compared under a kernel type other than SUBSET or PARTIAL
     */
    public double evaluate(TreeFeatureVector vector1, TreeFeatureVector vector2) {
        if (this.sumMethod != ForestSumMethod.SEQUENTIAL) {
            throw new NotImplementedException("The only summation method implemented is Sequential!");
        }
        Iterator<?> forest1 = vector1.getTrees().values().iterator();
        Iterator<?> forest2 = vector2.getTrees().values().iterator();
        double sum = 0.0d;
        // Stop at the shorter forest instead of indexing past the end of the second one.
        while (forest1.hasNext() && forest2.hasNext()) {
            sum += treeSim((String) forest1.next(), (String) forest2.next());
        }
        return sum;
    }

    /** Computes the configured per-tree kernel for one pair of serialized trees. */
    private double treeSim(String tree1, String tree2) {
        if (this.kernelType == KernelType.SUBSET) {
            return sst(tree1, tree2);
        }
        if (this.kernelType == KernelType.PARTIAL) {
            return ptk(tree1, tree2);
        }
        throw new NotImplementedException("The only kernel types implemented are SST and PTK!");
    }

    /** Subset-tree kernel between two serialized constituency trees, optionally normalized. */
    private double sst(String str, String str2) {
        TopTreebankNode tree1 = getParsedTree(str);
        TopTreebankNode tree2 = getParsedTree(str2);
        double similarity = sim(tree1, tree2);
        if (!this.normalize) {
            return similarity;
        }
        Double self1 = this.normalizers.get(str);
        if (self1 == null) {
            self1 = sim(tree1, tree1);
            this.normalizers.put(str, self1);
        }
        Double self2 = this.normalizers.get(str2);
        if (self2 == null) {
            self2 = sim(tree2, tree2);
            this.normalizers.put(str2, self2);
        }
        return cosineNormalize(similarity, self1, self2);
    }

    /** Returns the parsed form of a constituency tree string, parsing and caching it on first use. */
    private TopTreebankNode getParsedTree(String treeString) {
        TopTreebankNode tree = this.trees.get(treeString);
        if (tree == null) {
            tree = TreebankFormatParser.parse(treeString);
            this.trees.put(treeString, tree);
        }
        return tree;
    }

    /**
     * Divides a kernel value by sqrt(selfSim1 * selfSim2). Returns 0 when that denominator is 0
     * (e.g. with lambda == 0 every kernel value is 0) instead of the NaN that 0/0 would produce.
     */
    private static double cosineNormalize(double similarity, double selfSim1, double selfSim2) {
        double denominator = Math.sqrt(selfSim1 * selfSim2);
        return denominator == 0.0d ? 0.0d : similarity / denominator;
    }

    /** Sums {@link #numCommonSubtrees} over every (node of tree 1, node of tree 2) pair. */
    private double sim(TreebankNode treebankNode, TreebankNode treebankNode2) {
        List<TreebankNode> nodes2 = getNodeList(treebankNode2);
        double total = 0.0d;
        for (TreebankNode node1 : getNodeList(treebankNode)) {
            for (TreebankNode node2 : nodes2) {
                total += numCommonSubtrees(node1, node2);
            }
        }
        return total;
    }

    /**
     * Weighted common-subtree count rooted at two nodes.
     *
     * @return 0 if the nodes differ in type, child count, or any child's type; lambda for two
     *     leaves with equal values (0 if the values differ); otherwise
     *     lambda * prod_i (1 + numCommonSubtrees(child1_i, child2_i))
     */
    private double numCommonSubtrees(TreebankNode treebankNode, TreebankNode treebankNode2) {
        List<TreebankNode> children1 = treebankNode.getChildren();
        List<TreebankNode> children2 = treebankNode2.getChildren();
        int childCount = children1.size();
        if (childCount != children2.size() || !treebankNode.getType().equals(treebankNode2.getType())) {
            return 0.0d;
        }
        if (treebankNode.isLeaf() && treebankNode2.isLeaf()) {
            return treebankNode.getValue().equals(treebankNode2.getValue()) ? this.lambda : 0.0d;
        }
        // The productions must match: every child pair has to share a type.
        for (int i = 0; i < childCount; i++) {
            if (!children1.get(i).getType().equals(children2.get(i).getType())) {
                return 0.0d;
            }
        }
        double product = 1.0d;
        for (int i = 0; i < childCount; i++) {
            product *= 1.0d + numCommonSubtrees(children1.get(i), children2.get(i));
        }
        return this.lambda * product;
    }

    /** Partial-tree kernel between two serialized dependency trees, optionally normalized. */
    private double ptk(String str, String str2) {
        SimpleDepTree tree1 = getParsedDepTree(str);
        SimpleDepTree tree2 = getParsedDepTree(str2);
        if (!this.normalize) {
            return ptkSim(tree1, tree2);
        }
        Double self1 = this.normalizers.get(str);
        if (self1 == null) {
            self1 = ptkSim(tree1, tree1);
            this.normalizers.put(str, self1);
        }
        Double self2 = this.normalizers.get(str2);
        if (self2 == null) {
            self2 = ptkSim(tree2, tree2);
            this.normalizers.put(str2, self2);
        }
        return cosineNormalize(ptkSim(tree1, tree2), self1, self2);
    }

    /** Returns the parsed form of a dependency tree string, parsing and caching it on first use. */
    private SimpleDepTree getParsedDepTree(String treeString) {
        SimpleDepTree tree = this.depTrees.get(treeString);
        if (tree == null) {
            tree = SimpleDepTree.fromString(treeString);
            this.depTrees.put(treeString, tree);
        }
        return tree;
    }

    /** Sums the partial-tree delta over every node pair of the two trees whose {@code cat} matches. */
    private double ptkSim(SimpleDepTree simpleDepTree, SimpleDepTree simpleDepTree2) {
        List<SimpleDepTree> nodes1 = getDepNodeList(simpleDepTree);
        List<SimpleDepTree> nodes2 = getDepNodeList(simpleDepTree2);
        double total = 0.0d;
        for (SimpleDepTree node1 : nodes1) {
            for (SimpleDepTree node2 : nodes2) {
                if (!node1.cat.equals(node2.cat)) {
                    continue;
                }
                // A leaf contributes mu * lambda^2 directly; ptkDelta would return the same value
                // for a node without children, but this path skips the delta cache.
                total += node1.isLeaf() ? this.mu * this.lambdaSquared : ptkDelta(node1, node2);
            }
        }
        return total;
    }

    /**
     * Partial-tree delta for two nodes: 0 if their {@code cat} differs, otherwise
     * mu * lambda^2 * (1 + sum over p = 1..min(|children1|, |children2|, MAX_CHILDREN) of
     * ptkDeltaP(children1, children2, p)). Results are memoized in {@link #cache} when enabled.
     */
    private double ptkDelta(SimpleDepTree simpleDepTree, SimpleDepTree simpleDepTree2) {
        if (!simpleDepTree.cat.equals(simpleDepTree2.cat)) {
            return 0.0d;
        }
        if (this.useCache) {
            HashMap<SimpleDepTree, Double> row = this.cache.get(simpleDepTree);
            if (row != null) {
                Double cached = row.get(simpleDepTree2);
                if (cached != null) {
                    return cached;
                }
            }
        }
        int maxTerms = Math.min(
                Math.min(simpleDepTree.children.size(), simpleDepTree2.children.size()), MAX_CHILDREN);
        double childSum = 1.0d;
        for (int p = 1; p <= maxTerms; p++) {
            childSum += ptkDeltaP(simpleDepTree.children, simpleDepTree2.children, p);
        }
        double delta = this.mu * this.lambdaSquared * childSum;
        if (this.useCache) {
            // Re-fetch the row: it may have been created while recursing into the children.
            HashMap<SimpleDepTree, Double> row = this.cache.get(simpleDepTree);
            if (row == null) {
                row = new HashMap<>();
                this.cache.put(simpleDepTree, row);
            }
            row.put(simpleDepTree2, delta);
        }
        return delta;
    }

    /** Child-sequence term for the full child lists (prefixes ending at the last children). */
    private double ptkDeltaP(List<SimpleDepTree> list, List<SimpleDepTree> list2, int p) {
        return ptkDeltaP(list, list2, list.size() - 1, list2.size() - 1, p);
    }

    /**
     * Child-sequence term for the prefixes ending at indices {@code end1} / {@code end2}: the
     * delta of those two children plus, for every earlier pair (i, j), lambda^((end1 - i) +
     * (end2 - j)) times the term for the prefixes ending at (i, j).
     *
     * <p>NOTE(review): {@code p} is decremented on each recursive call but never bounds the
     * recursion, which is limited only by the prefix indices. Confirm this matches the intended
     * partial-tree-kernel formulation.
     */
    private double ptkDeltaP(List<SimpleDepTree> list, List<SimpleDepTree> list2, int end1, int end2,
            int p) {
        double total = ptkDelta(list.get(end1), list2.get(end2));
        for (int i = 0; i < end1; i++) {
            for (int j = 0; j < end2; j++) {
                int gap = (end1 - i) + (end2 - j);
                total += lambdaPower(gap) * ptkDeltaP(list, list2, i, j, p - 1);
            }
        }
        return total;
    }

    /** Returns lambda^exponent, using the precomputed table when the exponent falls inside it. */
    private double lambdaPower(int exponent) {
        return exponent < this.lambdaPowers.length
                ? this.lambdaPowers[exponent]
                : Math.pow(this.lambda, exponent);
    }

    /** Returns every node of the tree in breadth-first order, starting with the root. */
    private static List<TreebankNode> getNodeList(TreebankNode root) {
        List<TreebankNode> nodes = new ArrayList<>();
        nodes.add(root);
        // The list grows while being scanned; each visited node appends its children.
        for (int i = 0; i < nodes.size(); i++) {
            nodes.addAll(nodes.get(i).getChildren());
        }
        return nodes;
    }

    /** Returns every node of the dependency tree in breadth-first order, starting with the root. */
    private static List<SimpleDepTree> getDepNodeList(SimpleDepTree root) {
        List<SimpleDepTree> nodes = Lists.newArrayList();
        nodes.add(root);
        for (int i = 0; i < nodes.size(); i++) {
            nodes.addAll(nodes.get(i).children);
        }
        return nodes;
    }

    /** Enables or disables memoization of partial-tree node-pair deltas (existing entries are kept). */
    public void setUseCache(boolean useCache) {
        this.useCache = useCache;
    }

    /** Returns whether partial-tree node-pair deltas are memoized. */
    public boolean getUseCache() {
        return this.useCache;
    }
}
