package edu.stanford.nlp.parser.lexparser;

import edu.stanford.nlp.international.morph.MorphoFeatures;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.ling.TaggedWord;
import edu.stanford.nlp.math.SloppyMath;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.IntCounter;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.MapFactory;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.ThreeDimensionalMap;
import edu.stanford.nlp.util.Triple;
import edu.stanford.nlp.util.TwoDimensionalMap;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;

/* loaded from: input_file:edu/stanford/nlp/parser/lexparser/SplittingGrammarExtractor.class */
public class SplittingGrammarExtractor {
    /** Presumably the inclusive lower bound of the iteration window traced by DEBUG(); DEBUG() currently inlines the literal 0 — confirm. */
    static final int MIN_DEBUG_ITERATION = 0;
    /** Presumably the exclusive upper bound of that window; equal to the lower bound, so the window is empty. */
    static final int MAX_DEBUG_ITERATION = 0;
    /** NOTE(review): not referenced in the visible code; presumably an iteration cap — confirm. */
    static final int MAX_ITERATIONS = Integer.MAX_VALUE;
    /** Parser options; supplies the language pack and the {@code tlpParams} used to build lexicons. */
    Options op;
    /** NOTE(review): not referenced in the visible code. */
    Index<String> stateIndex;
    /** Word index backing {@link #lex}. */
    Index<String> wordIndex;
    /** Tag index backing {@link #lex}. */
    Index<String> tagIndex;
    /** Start symbols from the language pack; these states are never split. */
    List<String> startSymbols;
    /** Passed to {@code Lexicon.initializeTraining} whenever a lexicon is (re)built. */
    double trainSize;
    /** The current lexicon. */
    Lexicon lex;
    /**
     * Word index, tag index and lexicon under construction during re-estimation; useNewBetas()
     * promotes them into wordIndex, tagIndex and lex and then sets these three fields to null.
     */
    transient Index<String> tempWordIndex;
    transient Index<String> tempTagIndex;
    transient Lexicon tempLex;
    /** NOTE(review): not referenced in the visible code. */
    Pair<UnaryGrammar, BinaryGrammar> bgug;
    /** Lexicon smoothing fraction; the same value 1.0E-4 appears inline in the lexicon training code — presumably this constant, confirm. */
    static final double LEX_SMOOTH = 1.0E-4d;
    /**
     * Zero. NOTE(review): the name appears where a plain 0.0 is meant (e.g. the root's outside
     * log-score), most likely a decompiler substitution.
     */
    static final double STATE_SMOOTH = 0.0d;
    /** Presumably the convergence tolerance; testConvergence() inlines the same value 1.0E-4. */
    static final double EPSILON = 1.0E-4d;
    /** Current iteration number; DEBUG() compares it against the debug window. */
    int iteration = 0;
    /** Training trees. */
    List<Tree> trees = new ArrayList();
    /** Weight of each training tree, keyed by object identity (identity-hash-map factory). */
    Counter<Tree> treeWeights = new ClassicCounter(MapFactory.identityHashMapFactory());
    /** Labels of all non-leaf nodes seen in {@link #trees}; filled by countOriginalStates(). */
    Set<String> originalStates = Generics.newHashSet();
    /** Number of substates per label; doubled by splitStateCounts(), with start symbols and ".$$." capped at 1. */
    IntCounter<String> stateSplitCounts = new IntCounter<>();
    /** Log-space weights of rules parent -> left right, indexed [parent substate][left substate][right substate]. */
    ThreeDimensionalMap<String, String, String, double[][][]> binaryBetas = new ThreeDimensionalMap<>();
    /** Log-space weights of rules parent -> child, indexed [parent substate][child substate]. */
    TwoDimensionalMap<String, String, double[][]> unaryBetas = new TwoDimensionalMap<>();
    /** Fixed seed so the random split ratios are reproducible across runs. */
    Random random = new Random(87543875943265L);

    /**
     * Reports whether verbose tracing is enabled for the current iteration.
     *
     * <p>Tracing is on only while {@code iteration} lies in the half-open window
     * [{@code MIN_DEBUG_ITERATION}, {@code MAX_DEBUG_ITERATION}). Both bounds are 0, so the window
     * is empty and this always returns {@code false}.
     *
     * @return true when debug output should be printed
     */
    boolean DEBUG() {
        final boolean atOrAfterStart = MIN_DEBUG_ITERATION <= this.iteration;
        final boolean beforeEnd = this.iteration < MAX_DEBUG_ITERATION;
        return atOrAfterStart && beforeEnd;
    }

    public SplittingGrammarExtractor(Options options) {
        this.op = options;
        this.startSymbols = Arrays.asList(options.langpack().startSymbols());
    }

    /**
     * Allocates a log-space vector whose entries are all log(0).
     *
     * @param i number of entries
     * @return a new array of length {@code i} filled with {@link Double#NEGATIVE_INFINITY}
     */
    double[] neginfDoubles(int i) {
        double[] logZeros = new double[i];
        Arrays.fill(logZeros, Double.NEGATIVE_INFINITY);
        return logZeros;
    }

    /**
     * Prints {@code tree} together with its per-node transition weights, starting at indentation 0.
     *
     * @param tree the tree to print
     * @param unaryTransitions unary transition log-weights keyed by node identity
     * @param binaryTransitions binary transition log-weights keyed by node identity
     */
    public void outputTransitions(Tree tree, IdentityHashMap<Tree, double[][]> unaryTransitions, IdentityHashMap<Tree, double[][][]> binaryTransitions) {
        final int rootDepth = 0;
        this.outputTransitions(tree, rootDepth, unaryTransitions, binaryTransitions);
    }

    /**
     * Recursively prints a tree, one node per line, indented by one space per level.
     *
     * <p>Each unary (non-preterminal) or binary node is followed by its transition table, one line
     * per entry, showing the log weight and its {@code exp}. Preterminals are printed without
     * descending into their words.
     *
     * @param tree the subtree to print
     * @param depth current indentation depth
     * @param unaryTransitions unary transition log-weights keyed by node identity
     * @param binaryTransitions binary transition log-weights keyed by node identity
     */
    public void outputTransitions(Tree tree, int depth, IdentityHashMap<Tree, double[][]> unaryTransitions, IdentityHashMap<Tree, double[][][]> binaryTransitions) {
        StringBuilder padBuilder = new StringBuilder();
        for (int k = 0; k < depth; k++) {
            padBuilder.append(' ');
        }
        final String pad = padBuilder.toString();

        System.out.print(pad);
        if (tree.isLeaf()) {
            System.out.println(tree.label().value());
            return;
        }

        Tree[] kids = tree.children();
        String label = tree.label().value();
        if (kids.length == 1) {
            System.out.println(label + " -> " + kids[0].label().value());
            if (!tree.isPreTerminal()) {
                double[][] table = unaryTransitions.get(tree);
                for (int p = 0; p < table.length; p++) {
                    for (int c = 0; c < table[0].length; c++) {
                        double logWeight = table[p][c];
                        System.out.print(pad);
                        System.out.println("  " + p + "," + c + ": " + logWeight + " | " + Math.exp(logWeight));
                    }
                }
            }
        } else {
            System.out.println(label + " -> " + kids[0].label().value() + " " + kids[1].label().value());
            double[][][] table = binaryTransitions.get(tree);
            for (int p = 0; p < table.length; p++) {
                for (int l = 0; l < table[0].length; l++) {
                    for (int r = 0; r < table[0][0].length; r++) {
                        double logWeight = table[p][l][r];
                        System.out.print(pad);
                        System.out.println("  " + p + "," + l + "," + r + ": " + logWeight + " | " + Math.exp(logWeight));
                    }
                }
            }
        }

        if (!tree.isPreTerminal()) {
            for (Tree kid : kids) {
                outputTransitions(kid, depth + 1, unaryTransitions, binaryTransitions);
            }
        }
    }

    /**
     * Dumps every unary and binary beta table to standard output.
     *
     * <p>Each entry is printed as its log weight followed by {@code exp} of that weight.
     */
    public void outputBetas() {
        System.out.println("UNARY:");
        for (String parent : this.unaryBetas.firstKeySet()) {
            for (Map.Entry<String, double[][]> rule : this.unaryBetas.get(parent).entrySet()) {
                String child = rule.getKey();
                System.out.println("  " + parent + "->" + child);
                double[][] table = rule.getValue();
                int parentStates = table.length;
                int childStates = table[0].length;
                for (int p = 0; p < parentStates; p++) {
                    for (int c = 0; c < childStates; c++) {
                        double logWeight = table[p][c];
                        System.out.println("    " + p + "->" + c + " " + logWeight + " | " + Math.exp(logWeight));
                    }
                }
            }
        }

        System.out.println("BINARY:");
        for (String parent : this.binaryBetas.firstKeySet()) {
            for (String left : this.binaryBetas.get(parent).firstKeySet()) {
                for (Map.Entry<String, double[][][]> rule : this.binaryBetas.get(parent).get(left).entrySet()) {
                    String right = rule.getKey();
                    System.out.println("  " + parent + "->" + left + "," + right);
                    double[][][] table = rule.getValue();
                    int parentStates = table.length;
                    int leftStates = table[0].length;
                    int rightStates = table[0][0].length;
                    for (int p = 0; p < parentStates; p++) {
                        for (int l = 0; l < leftStates; l++) {
                            for (int r = 0; r < rightStates; r++) {
                                double logWeight = table[p][l][r];
                                System.out.println("    " + p + "->" + l + "," + r + " " + logWeight + " | " + Math.exp(logWeight));
                            }
                        }
                    }
                }
            }
        }
    }

    /**
     * Returns the name of substate {@code i} of category {@code str}.
     *
     * <p>Start symbols and the boundary symbol {@code ".$$."} are never split and are returned
     * unchanged; every other category becomes {@code str + "^" + i}.
     *
     * @param str the base category label
     * @param i the substate number
     * @return the substate name
     */
    public String state(String str, int i) {
        if (this.startSymbols.contains(str) || str.equals(".$$.")) {
            return str;
        }
        return str + "^" + i;
    }

    /**
     * Returns the current number of substates of the label at the root of {@code tree}.
     *
     * @param tree node whose label is looked up
     * @return the split count stored for that label (0 if the label is unknown)
     */
    public int getStateSplitCount(Tree tree) {
        String label = tree.label().value();
        return getStateSplitCount(label);
    }

    /**
     * Returns the current number of substates of label {@code str}.
     *
     * @param str category label
     * @return the split count stored for that label (0 if the label is unknown)
     */
    public int getStateSplitCount(String str) {
        final int splits = this.stateSplitCounts.getIntCount(str);
        return splits;
    }

    /**
     * Rebuilds {@link #originalStates} from the training trees, then adds one to the split count of
     * every label found.
     *
     * <p>{@code stateSplitCounts} is not cleared first, so each call increments the counts again.
     */
    public void countOriginalStates() {
        this.originalStates.clear();
        for (Tree tree : this.trees) {
            countOriginalStates(tree);
        }
        // The decompiled version cast each String key to IntCounter<String>, which does not compile.
        for (String label : this.originalStates) {
            this.stateSplitCounts.incrementCount(label, 1);
        }
    }

    /**
     * Adds the label of every non-leaf node under {@code tree} (inclusive) to {@link #originalStates}.
     *
     * @param tree subtree to scan; leaves are ignored
     */
    private void countOriginalStates(Tree tree) {
        if (!tree.isLeaf()) {
            this.originalStates.add(tree.label().value());
            for (Tree child : tree.children()) {
                countOriginalStates(child);
            }
        }
    }

    /**
     * Builds a fresh lexicon and word/tag indices from the training trees, and seeds a 1x1 (or
     * 1x1x1) zero log-weight beta table for every rule that does not have one yet.
     *
     * <p>Each tree contributes with its weight from {@link #treeWeights}.
     */
    private void initialBetasAndLexicon() {
        this.wordIndex = new HashIndex<>();
        this.tagIndex = new HashIndex<>();
        this.lex = this.op.tlpParams.lex(this.op, this.wordIndex, this.tagIndex);
        this.lex.initializeTraining(this.trainSize);
        for (Tree tree : this.trees) {
            double weight = this.treeWeights.getCount(tree);
            this.lex.incrementTreesRead(weight);
            initialBetasAndLexicon(tree, 0, weight);
        }
        this.lex.finishTraining();
    }

    /**
     * Walks one tree, training the lexicon at preterminals (using substate 0 of each tag) and
     * registering a zero log-weight beta table for each unary and binary rule not seen before.
     *
     * @param tree current node
     * @param i position of the next word
     * @param d weight of the tree being processed
     * @return the word position following this subtree
     * @throws RuntimeException if an internal node has neither one nor two children
     */
    private int initialBetasAndLexicon(Tree tree, int i, double d) {
        if (tree.isLeaf()) {
            return i;
        }
        String parent = tree.label().value();
        if (tree.isPreTerminal()) {
            String word = tree.children()[0].label().value();
            this.lex.train(new TaggedWord(word, state(parent, 0)), i, d);
            return i + 1;
        }

        Tree[] kids = tree.children();
        switch (kids.length) {
            case 2: {
                String left = kids[0].label().value();
                String right = kids[1].label().value();
                if (!this.binaryBetas.contains(parent, left, right)) {
                    this.binaryBetas.put(parent, left, right, new double[][][] {{{0.0d}}});
                }
                break;
            }
            case 1: {
                String child = kids[0].label().value();
                if (!this.unaryBetas.contains(parent, child)) {
                    this.unaryBetas.put(parent, child, new double[][] {{0.0d}});
                }
                break;
            }
            default:
                throw new RuntimeException("Trees should have been binarized, expected 1 or 2 children");
        }

        int position = i;
        for (Tree kid : kids) {
            position = initialBetasAndLexicon(kid, position, d);
        }
        return position;
    }

    /**
     * Doubles the substate count of every label.
     *
     * <p>Start symbols and the boundary symbol {@code ".$$."} are capped back at 1, because those
     * states are never split.
     */
    private void splitStateCounts() {
        IntCounter<String> doubled = new IntCounter<>();
        doubled.addAll(this.stateSplitCounts);
        doubled.addAll(this.stateSplitCounts);
        // The decompiled version cast the String keys to IntCounter<String>, which does not compile.
        for (String root : this.startSymbols) {
            if (doubled.getCount(root) > 1.0d) {
                doubled.setCount(root, 1);
            }
        }
        if (doubled.getCount(".$$.") > 1.0d) {
            doubled.setCount(".$$.", 1);
        }
        this.stateSplitCounts = doubled;
    }

    /**
     * Doubles the substates of every splittable category in the unary and binary beta tables.
     *
     * <p>Beta tables hold log-space weights indexed {@code [parent][child]} (unary) or
     * {@code [parent][left][right]} (binary).
     * <ul>
     *   <li>Parent substates are duplicated verbatim unless the parent is a start symbol.</li>
     *   <li>Each child substate is split in two. Its weight is divided between the halves by a
     *       random ratio in [0.45, 0.55), which breaks the symmetry between them.</li>
     *   <li>A unary child or binary right child equal to the boundary symbol {@code ".$$."} is not
     *       split. Binary left children are always split.</li>
     * </ul>
     * Random numbers are drawn in parent/left/right loop order, so results stay reproducible for a
     * given seed.
     */
    public void splitBetas() {
        TwoDimensionalMap<String, String, double[][]> splitUnary = new TwoDimensionalMap<>();
        ThreeDimensionalMap<String, String, String, double[][][]> splitBinary = new ThreeDimensionalMap<>();

        for (String parent : this.unaryBetas.firstKeySet()) {
            for (String child : this.unaryBetas.get(parent).keySet()) {
                double[][] betas = this.unaryBetas.get(parent, child);
                int parentStates = betas.length;
                int childStates = betas[0].length;

                if (!this.startSymbols.contains(parent)) {
                    // Duplicate each parent substate; both copies inherit the same weights.
                    double[][] parentSplit = new double[parentStates * 2][childStates];
                    for (int p = 0; p < parentStates; p++) {
                        for (int c = 0; c < childStates; c++) {
                            parentSplit[2 * p][c] = betas[p][c];
                            parentSplit[2 * p + 1][c] = betas[p][c];
                        }
                    }
                    parentStates *= 2;
                    betas = parentSplit;
                }

                if (!child.equals(".$$.")) {
                    double[][] childSplit = new double[parentStates][childStates * 2];
                    for (int p = 0; p < parentStates; p++) {
                        for (int c = 0; c < childStates; c++) {
                            double share = 0.45d + (this.random.nextDouble() * 0.1d);
                            childSplit[p][2 * c] = betas[p][c] + Math.log(share);
                            childSplit[p][2 * c + 1] = betas[p][c] + Math.log(1.0d - share);
                        }
                    }
                    betas = childSplit;
                }
                splitUnary.put(parent, child, betas);
            }
        }

        for (String parent : this.binaryBetas.firstKeySet()) {
            for (String left : this.binaryBetas.get(parent).firstKeySet()) {
                for (String right : this.binaryBetas.get(parent).get(left).keySet()) {
                    double[][][] betas = this.binaryBetas.get(parent, left, right);
                    int parentStates = betas.length;
                    int leftStates = betas[0].length;
                    int rightStates = betas[0][0].length;

                    if (!this.startSymbols.contains(parent)) {
                        double[][][] parentSplit = new double[parentStates * 2][leftStates][rightStates];
                        for (int p = 0; p < parentStates; p++) {
                            for (int l = 0; l < leftStates; l++) {
                                for (int r = 0; r < rightStates; r++) {
                                    parentSplit[2 * p][l][r] = betas[p][l][r];
                                    parentSplit[2 * p + 1][l][r] = betas[p][l][r];
                                }
                            }
                        }
                        parentStates *= 2;
                        betas = parentSplit;
                    }

                    double[][][] leftSplit = new double[parentStates][leftStates * 2][rightStates];
                    for (int p = 0; p < parentStates; p++) {
                        for (int l = 0; l < leftStates; l++) {
                            for (int r = 0; r < rightStates; r++) {
                                double share = 0.45d + (this.random.nextDouble() * 0.1d);
                                leftSplit[p][2 * l][r] = betas[p][l][r] + Math.log(share);
                                leftSplit[p][2 * l + 1][r] = betas[p][l][r] + Math.log(1.0d - share);
                            }
                        }
                    }
                    leftStates *= 2;
                    betas = leftSplit;

                    if (!right.equals(".$$.")) {
                        // Read the source weights from the left-split table (betas), never from the
                        // array being filled; reading the destination would yield zeros/garbage.
                        double[][][] rightSplit = new double[parentStates][leftStates][rightStates * 2];
                        for (int p = 0; p < parentStates; p++) {
                            for (int l = 0; l < leftStates; l++) {
                                for (int r = 0; r < rightStates; r++) {
                                    double share = 0.45d + (this.random.nextDouble() * 0.1d);
                                    rightSplit[p][l][2 * r] = betas[p][l][r] + Math.log(share);
                                    rightSplit[p][l][2 * r + 1] = betas[p][l][r] + Math.log(1.0d - share);
                                }
                            }
                        }
                        betas = rightSplit;
                    }
                    splitBinary.put(parent, left, right, betas);
                }
            }
        }

        this.unaryBetas = splitUnary;
        this.binaryBetas = splitBinary;
    }

    /**
     * Runs one re-estimation pass over the training trees and installs the resulting betas.
     *
     * @param z if true, every splittable state is split (see splitBetas()) before re-estimation;
     *     convergence is only tested when {@code z} is false
     * @return whether the new betas are within tolerance of the old ones (always false when splitting)
     */
    public boolean recalculateBetas(boolean z) {
        final boolean splitting = z;
        if (splitting) {
            if (DEBUG()) {
                System.out.println("Pre-split betas");
                outputBetas();
            }
            splitBetas();
            if (DEBUG()) {
                System.out.println("Post-split betas");
                outputBetas();
            }
        }

        TwoDimensionalMap<String, String, double[][]> freshUnary = new TwoDimensionalMap<>();
        ThreeDimensionalMap<String, String, String, double[][][]> freshBinary = new ThreeDimensionalMap<>();
        recalculateTemporaryBetas(splitting, null, freshUnary, freshBinary);

        final boolean checkConvergence = !splitting;
        boolean converged = useNewBetas(checkConvergence, freshUnary, freshBinary);
        if (DEBUG()) {
            outputBetas();
        }
        return converged;
    }

    /**
     * Normalizes the freshly accumulated betas and makes them, together with the temporary lexicon
     * and indices, the current model.
     *
     * @param z whether to test convergence
     * @param newUnaryBetas accumulated unary betas; rescaled in place and then installed
     * @param newBinaryBetas accumulated binary betas; rescaled in place and then installed
     * @return true only if {@code z} is set and every beta moved by at most the tolerance
     * @throws RuntimeIOException if dumping the lexicon in debug mode fails
     */
    public boolean useNewBetas(boolean z, TwoDimensionalMap<String, String, double[][]> newUnaryBetas, ThreeDimensionalMap<String, String, String, double[][][]> newBinaryBetas) {
        rescaleTemporaryBetas(newUnaryBetas, newBinaryBetas);

        // Convergence compares against the current tables, so it must run before they are replaced.
        boolean converged = false;
        if (z) {
            converged = testConvergence(newUnaryBetas, newBinaryBetas);
        }

        this.unaryBetas = newUnaryBetas;
        this.binaryBetas = newBinaryBetas;
        this.wordIndex = this.tempWordIndex;
        this.tagIndex = this.tempTagIndex;
        this.lex = this.tempLex;

        if (DEBUG()) {
            System.out.println("LEXICON");
            try {
                // Deliberately not closed: closing the writer would close System.out.
                OutputStreamWriter stdout = new OutputStreamWriter(System.out, "utf-8");
                this.lex.writeData(stdout);
                stdout.flush();
            } catch (IOException ioe) {
                throw new RuntimeIOException(ioe);
            }
        }

        this.tempWordIndex = null;
        this.tempTagIndex = null;
        this.tempLex = null;
        return converged;
    }

    /**
     * Re-estimates betas and the lexicon from all training trees into temporary structures.
     *
     * <p>A fresh temporary lexicon with new word/tag indices is built and trained. The beta
     * accumulators passed in are filled. Nothing is installed as the current model here; see
     * useNewBetas().
     *
     * @param z forwarded to recountTree for each tree
     * @param map optional per-label substate totals; may be null
     * @param twoDimensionalMap unary beta accumulator (log space)
     * @param threeDimensionalMap binary beta accumulator (log space)
     */
    public void recalculateTemporaryBetas(boolean z, Map<String, double[]> map, TwoDimensionalMap<String, String, double[][]> twoDimensionalMap, ThreeDimensionalMap<String, String, String, double[][][]> threeDimensionalMap) {
        this.tempWordIndex = new HashIndex<>();
        this.tempTagIndex = new HashIndex<>();
        this.tempLex = this.op.tlpParams.lex(this.op, this.tempWordIndex, this.tempTagIndex);
        this.tempLex.initializeTraining(this.trainSize);
        for (Tree tree : this.trees) {
            double weight = this.treeWeights.getCount(tree);
            if (DEBUG()) {
                System.out.println("Incrementing trees read: " + weight);
            }
            this.tempLex.incrementTreesRead(weight);
            recalculateTemporaryBetas(tree, z, map, twoDimensionalMap, threeDimensionalMap);
        }
        this.tempLex.finishTraining();
    }

    /**
     * Checks whether every current beta is within {@code EPSILON} (1.0E-4, in log space) of its
     * counterpart in the new tables.
     *
     * <p>Only rules present in the current tables are compared. The new tables are expected to
     * contain the same rules with the same dimensions.
     *
     * @param twoDimensionalMap new unary betas
     * @param threeDimensionalMap new binary betas
     * @return false on the first entry that moved by more than the tolerance, true otherwise
     */
    public boolean testConvergence(TwoDimensionalMap<String, String, double[][]> twoDimensionalMap, ThreeDimensionalMap<String, String, String, double[][][]> threeDimensionalMap) {
        for (String parent : this.unaryBetas.firstKeySet()) {
            for (String child : this.unaryBetas.get(parent).keySet()) {
                double[][] before = this.unaryBetas.get(parent, child);
                double[][] after = twoDimensionalMap.get(parent, child);
                int parentStates = before.length;
                int childStates = before[0].length;
                for (int p = 0; p < parentStates; p++) {
                    for (int c = 0; c < childStates; c++) {
                        double drift = Math.abs(after[p][c] - before[p][c]);
                        if (drift > EPSILON) {
                            return false;
                        }
                    }
                }
            }
        }

        for (String parent : this.binaryBetas.firstKeySet()) {
            for (String left : this.binaryBetas.get(parent).firstKeySet()) {
                for (String right : this.binaryBetas.get(parent).get(left).keySet()) {
                    double[][][] before = this.binaryBetas.get(parent, left, right);
                    double[][][] after = threeDimensionalMap.get(parent, left, right);
                    int parentStates = before.length;
                    int leftStates = before[0].length;
                    int rightStates = before[0][0].length;
                    for (int p = 0; p < parentStates; p++) {
                        for (int l = 0; l < leftStates; l++) {
                            for (int r = 0; r < rightStates; r++) {
                                double drift = Math.abs(after[p][l][r] - before[p][l][r]);
                                if (drift > EPSILON) {
                                    return false;
                                }
                            }
                        }
                    }
                }
            }
        }
        return true;
    }

    /**
     * Re-estimates the contribution of a single tree.
     *
     * <p>Transition posteriors are computed with recountTree. They are then pushed down from the
     * root, which starts as a one-element log-weight vector holding log(tree weight).
     *
     * @param tree the training tree
     * @param z forwarded to recountTree
     * @param map optional per-label substate totals; may be null
     * @param twoDimensionalMap unary beta accumulator (log space)
     * @param threeDimensionalMap binary beta accumulator (log space)
     */
    public void recalculateTemporaryBetas(Tree tree, boolean z, Map<String, double[]> map, TwoDimensionalMap<String, String, double[][]> twoDimensionalMap, ThreeDimensionalMap<String, String, String, double[][][]> threeDimensionalMap) {
        if (DEBUG()) {
            System.out.println("Recalculating temporary betas for tree " + tree);
        }
        double logTreeWeight = Math.log(this.treeWeights.getCount(tree));
        double[] rootWeights = new double[] {logTreeWeight};

        IdentityHashMap<Tree, double[][]> unaryTransitions = new IdentityHashMap<>();
        IdentityHashMap<Tree, double[][][]> binaryTransitions = new IdentityHashMap<>();
        recountTree(tree, z, unaryTransitions, binaryTransitions);

        if (DEBUG()) {
            System.out.println("  Transitions:");
            outputTransitions(tree, unaryTransitions, binaryTransitions);
        }
        recalculateTemporaryBetas(tree, rootWeights, 0, unaryTransitions, binaryTransitions, map, twoDimensionalMap, threeDimensionalMap);
    }

    /**
     * Pushes the posterior mass of one tree top-down from {@code tree} to its descendants.
     *
     * <p>The mass is accumulated into the temporary beta tables, and the temporary lexicon is
     * trained at the preterminals. Weights are in log space, except the per-label totals written to
     * {@code map}, which are accumulated after {@link Math#exp}.
     *
     * @param tree the current node
     * @param dArr log-space weight of each substate of {@code tree}
     * @param i running word position, advanced by one at each preterminal and passed to {@code Lexicon.train}
     * @param identityHashMap per-node unary transition log-weights (as filled by recountWeights)
     * @param identityHashMap2 per-node binary transition log-weights (as filled by recountWeights)
     * @param map if non-null, receives per-label, per-substate sums of {@code exp(dArr)}
     * @param twoDimensionalMap unary beta accumulator; missing tables are created filled with -infinity
     * @param threeDimensionalMap binary beta accumulator; missing tables are created filled with -infinity
     * @return the word position following this subtree
     */
    public int recalculateTemporaryBetas(Tree tree, double[] dArr, int i, IdentityHashMap<Tree, double[][]> identityHashMap, IdentityHashMap<Tree, double[][][]> identityHashMap2, Map<String, double[]> map, TwoDimensionalMap<String, String, double[][]> twoDimensionalMap, ThreeDimensionalMap<String, String, String, double[][][]> threeDimensionalMap) {
        int recalculateTemporaryBetas;
        if (tree.isLeaf()) {
            return i;
        }
        // Optionally accumulate, in probability space, the mass reaching each substate of this label
        // (preterminals included, since this runs before the preterminal check).
        if (map != null) {
            double[] dArr2 = map.get(tree.label().value());
            if (dArr2 == null) {
                dArr2 = new double[dArr.length];
                map.put(tree.label().value(), dArr2);
            }
            for (int i2 = 0; i2 < dArr.length; i2++) {
                // Decompiled form of dArr2[i2] += Math.exp(dArr[i2]).
                double[] dArr3 = dArr2;
                int i3 = i2;
                dArr3[i3] = dArr3[i3] + Math.exp(dArr[i2]);
            }
        }
        if (tree.isPreTerminal()) {
            String value = tree.label().value();
            String value2 = tree.children()[0].label().value();
            // Total (probability-space) mass reaching this tag across all of its substates.
            double d = 0.0d;
            for (double d2 : dArr) {
                d += Math.exp(d2);
            }
            // STATE_SMOOTH is 0.0: a tag with no mass trains nothing, but the word position still advances.
            if (d <= STATE_SMOOTH) {
                return i + 1;
            }
            // Smoothing: give every substate an equal share of 1.0E-4 * d, then multiply by
            // 0.9999000099990001 == 1 / (1 + 1.0E-4) so the total trained mass stays d.
            double length = (d * 1.0E-4d) / dArr.length;
            for (int i4 = 0; i4 < dArr.length; i4++) {
                this.tempLex.train(new TaggedWord(value2, state(value, i4)), i, (Math.exp(dArr[i4]) + length) * 0.9999000099990001d);
            }
            return i + 1;
        }
        if (tree.children().length == 1) {
            // Unary node: add weighted transitions into the accumulator for parent -> child.
            String value3 = tree.label().value();
            String value4 = tree.children()[0].label().value();
            double[][] dArr4 = identityHashMap.get(tree);
            int length2 = dArr4.length;
            int length3 = dArr4[0].length;
            double[][] dArr5 = twoDimensionalMap.get(value3, value4);
            if (dArr5 == null) {
                dArr5 = new double[length2][length3];
                for (int i5 = 0; i5 < length2; i5++) {
                    for (int i6 = 0; i6 < length3; i6++) {
                        dArr5[i5][i6] = Double.NEGATIVE_INFINITY;
                    }
                }
                twoDimensionalMap.put(value3, value4, dArr5);
            }
            // neginfDoubles collects, per child substate, the mass passed down to the child.
            double[] neginfDoubles = neginfDoubles(length3);
            for (int i7 = 0; i7 < length2; i7++) {
                for (int i8 = 0; i8 < length3; i8++) {
                    double d3 = dArr4[i7][i8];
                    dArr5[i7][i8] = SloppyMath.logAdd(dArr5[i7][i8], d3 + dArr[i7]);
                    neginfDoubles[i8] = SloppyMath.logAdd(neginfDoubles[i8], d3 + dArr[i7]);
                }
            }
            recalculateTemporaryBetas = recalculateTemporaryBetas(tree.children()[0], neginfDoubles, i, identityHashMap, identityHashMap2, map, twoDimensionalMap, threeDimensionalMap);
        } else {
            // Binary node: same as above, with the mass marginalized separately onto each child.
            String value5 = tree.label().value();
            String value6 = tree.children()[0].label().value();
            String value7 = tree.children()[1].label().value();
            double[][][] dArr6 = identityHashMap2.get(tree);
            int length4 = dArr6.length;
            int length5 = dArr6[0].length;
            int length6 = dArr6[0][0].length;
            double[][][] dArr7 = threeDimensionalMap.get(value5, value6, value7);
            if (dArr7 == null) {
                dArr7 = new double[length4][length5][length6];
                for (int i9 = 0; i9 < length4; i9++) {
                    for (int i10 = 0; i10 < length5; i10++) {
                        for (int i11 = 0; i11 < length6; i11++) {
                            dArr7[i9][i10][i11] = Double.NEGATIVE_INFINITY;
                        }
                    }
                }
                threeDimensionalMap.put(value5, value6, value7, dArr7);
            }
            // Per-substate mass passed down to the left (neginfDoubles2) and right (neginfDoubles3) child.
            double[] neginfDoubles2 = neginfDoubles(length5);
            double[] neginfDoubles3 = neginfDoubles(length6);
            for (int i12 = 0; i12 < length4; i12++) {
                for (int i13 = 0; i13 < length5; i13++) {
                    for (int i14 = 0; i14 < length6; i14++) {
                        double d4 = dArr6[i12][i13][i14];
                        dArr7[i12][i13][i14] = SloppyMath.logAdd(dArr7[i12][i13][i14], d4 + dArr[i12]);
                        neginfDoubles2[i13] = SloppyMath.logAdd(neginfDoubles2[i13], d4 + dArr[i12]);
                        neginfDoubles3[i14] = SloppyMath.logAdd(neginfDoubles3[i14], d4 + dArr[i12]);
                    }
                }
            }
            // Left child first, so word positions are threaded left to right.
            recalculateTemporaryBetas = recalculateTemporaryBetas(tree.children()[1], neginfDoubles3, recalculateTemporaryBetas(tree.children()[0], neginfDoubles2, i, identityHashMap, identityHashMap2, map, twoDimensionalMap, threeDimensionalMap), identityHashMap, identityHashMap2, map, twoDimensionalMap, threeDimensionalMap);
        }
        return recalculateTemporaryBetas;
    }

    /**
     * Normalizes, in place, each parent-substate slice of every accumulated beta table.
     *
     * <p>Normalizing means subtracting the slice's log total, so the slice sums to one in
     * probability space. A slice whose log total is infinite (e.g. it received no mass) is reset to
     * the uniform distribution.
     *
     * @param twoDimensionalMap unary tables, indexed [parent][child]
     * @param threeDimensionalMap binary tables, indexed [parent][left][right]
     */
    public void rescaleTemporaryBetas(TwoDimensionalMap<String, String, double[][]> twoDimensionalMap, ThreeDimensionalMap<String, String, String, double[][][]> threeDimensionalMap) {
        for (String parent : twoDimensionalMap.firstKeySet()) {
            for (double[][] table : twoDimensionalMap.get(parent).values()) {
                int parentStates = table.length;
                int childStates = table[0].length;
                for (int p = 0; p < parentStates; p++) {
                    double[] row = table[p];
                    double logTotal = Double.NEGATIVE_INFINITY;
                    for (int c = 0; c < childStates; c++) {
                        logTotal = SloppyMath.logAdd(logTotal, row[c]);
                    }
                    if (!Double.isInfinite(logTotal)) {
                        for (int c = 0; c < childStates; c++) {
                            row[c] -= logTotal;
                        }
                    } else {
                        for (int c = 0; c < childStates; c++) {
                            row[c] = -Math.log(childStates);
                        }
                    }
                }
            }
        }

        for (String parent : threeDimensionalMap.firstKeySet()) {
            for (String left : threeDimensionalMap.get(parent).firstKeySet()) {
                for (double[][][] table : threeDimensionalMap.get(parent).get(left).values()) {
                    int parentStates = table.length;
                    int leftStates = table[0].length;
                    int rightStates = table[0][0].length;
                    for (int p = 0; p < parentStates; p++) {
                        double[][] slice = table[p];
                        double logTotal = Double.NEGATIVE_INFINITY;
                        for (int l = 0; l < leftStates; l++) {
                            for (int r = 0; r < rightStates; r++) {
                                logTotal = SloppyMath.logAdd(logTotal, slice[l][r]);
                            }
                        }
                        if (!Double.isInfinite(logTotal)) {
                            for (int l = 0; l < leftStates; l++) {
                                for (int r = 0; r < rightStates; r++) {
                                    slice[l][r] -= logTotal;
                                }
                            }
                        } else {
                            for (int l = 0; l < leftStates; l++) {
                                for (int r = 0; r < rightStates; r++) {
                                    slice[l][r] = -Math.log(leftStates * rightStates);
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    /**
     * Computes per-node transition posteriors for {@code tree}, using throwaway inside/outside tables.
     *
     * @param tree the tree to process
     * @param z forwarded to the inside pass
     * @param unaryTransitions receives the normalized transition table of each unary node
     * @param binaryTransitions receives the normalized transition table of each binary node
     */
    public void recountTree(Tree tree, boolean z, IdentityHashMap<Tree, double[][]> unaryTransitions, IdentityHashMap<Tree, double[][][]> binaryTransitions) {
        IdentityHashMap<Tree, double[]> insideScores = new IdentityHashMap<Tree, double[]>();
        IdentityHashMap<Tree, double[]> outsideScores = new IdentityHashMap<Tree, double[]>();
        recountTree(tree, z, insideScores, outsideScores, unaryTransitions, binaryTransitions);
    }

    /**
     * Runs the inside pass, then the outside pass, then turns both into normalized transition
     * posteriors for every node of {@code tree}.
     *
     * @param tree the tree to process
     * @param z forwarded to recountInside
     * @param insideScores receives inside log-scores per node
     * @param outsideScores receives outside log-scores per node
     * @param unaryTransitions receives the normalized transition table of each unary node
     * @param binaryTransitions receives the normalized transition table of each binary node
     */
    public void recountTree(Tree tree, boolean z, IdentityHashMap<Tree, double[]> insideScores, IdentityHashMap<Tree, double[]> outsideScores, IdentityHashMap<Tree, double[][]> unaryTransitions, IdentityHashMap<Tree, double[][][]> binaryTransitions) {
        recountInside(tree, z, 0, insideScores);
        if (DEBUG()) {
            double rootInside = insideScores.get(tree)[0];
            System.out.println("ROOT PROBABILITY: " + rootInside);
        }
        recountOutside(tree, insideScores, outsideScores);
        recountWeights(tree, insideScores, outsideScores, unaryTransitions, binaryTransitions);
    }

    /**
     * Computes a normalized log-space posterior over the substate combinations of every internal
     * (non-preterminal) node, recursing through the whole tree.
     *
     * <p>The unnormalized entries are:
     * <ul>
     *   <li>unary (p, c): {@code outside[p] + inside(child)[c] + beta[p][c]};</li>
     *   <li>binary (p, l, r): {@code outside[p] + inside(left)[l] + inside(right)[r] + beta[p][l][r]}.</li>
     * </ul>
     * Each parent-substate slice is then normalized by its log total. A slice whose log total is
     * infinite becomes uniform.
     *
     * @param tree the subtree to process; leaves and preterminals are skipped
     * @param insideScores inside log-scores per node
     * @param outsideScores outside log-scores per node
     * @param unaryTransitions receives the posterior table of each unary node
     * @param binaryTransitions receives the posterior table of each binary node
     */
    public void recountWeights(Tree tree, IdentityHashMap<Tree, double[]> insideScores, IdentityHashMap<Tree, double[]> outsideScores, IdentityHashMap<Tree, double[][]> unaryTransitions, IdentityHashMap<Tree, double[][][]> binaryTransitions) {
        if (tree.isLeaf() || tree.isPreTerminal()) {
            return;
        }
        Tree[] kids = tree.children();
        String parentLabel = tree.label().value();

        if (kids.length != 1) {
            Tree left = kids[0];
            Tree right = kids[1];
            double[][][] rule = this.binaryBetas.get(parentLabel, left.label().value(), right.label().value());
            double[] leftInside = insideScores.get(left);
            double[] rightInside = insideScores.get(right);
            double[] parentOutside = outsideScores.get(tree);
            int parentStates = rule.length;
            int leftStates = rule[0].length;
            int rightStates = rule[0][0].length;
            double[][][] posterior = new double[parentStates][leftStates][rightStates];
            binaryTransitions.put(tree, posterior);

            for (int p = 0; p < parentStates; p++) {
                for (int l = 0; l < leftStates; l++) {
                    for (int r = 0; r < rightStates; r++) {
                        posterior[p][l][r] = parentOutside[p] + leftInside[l] + rightInside[r] + rule[p][l][r];
                    }
                }
            }
            for (int p = 0; p < parentStates; p++) {
                double[][] slice = posterior[p];
                double logTotal = Double.NEGATIVE_INFINITY;
                for (int l = 0; l < leftStates; l++) {
                    for (int r = 0; r < rightStates; r++) {
                        logTotal = SloppyMath.logAdd(logTotal, slice[l][r]);
                    }
                }
                if (!Double.isInfinite(logTotal)) {
                    for (int l = 0; l < leftStates; l++) {
                        for (int r = 0; r < rightStates; r++) {
                            slice[l][r] -= logTotal;
                        }
                    }
                } else {
                    double uniform = -Math.log(leftStates * rightStates);
                    for (int l = 0; l < leftStates; l++) {
                        for (int r = 0; r < rightStates; r++) {
                            slice[l][r] = uniform;
                        }
                    }
                }
            }
            recountWeights(left, insideScores, outsideScores, unaryTransitions, binaryTransitions);
            recountWeights(right, insideScores, outsideScores, unaryTransitions, binaryTransitions);
            return;
        }

        Tree child = kids[0];
        double[][] rule = this.unaryBetas.get(parentLabel, child.label().value());
        double[] childInside = insideScores.get(child);
        double[] parentOutside = outsideScores.get(tree);
        int parentStates = rule.length;
        int childStates = rule[0].length;
        double[][] posterior = new double[parentStates][childStates];
        unaryTransitions.put(tree, posterior);

        for (int p = 0; p < parentStates; p++) {
            for (int c = 0; c < childStates; c++) {
                posterior[p][c] = parentOutside[p] + childInside[c] + rule[p][c];
            }
        }
        for (int p = 0; p < parentStates; p++) {
            double[] row = posterior[p];
            double logTotal = Double.NEGATIVE_INFINITY;
            for (int c = 0; c < childStates; c++) {
                logTotal = SloppyMath.logAdd(logTotal, row[c]);
            }
            if (!Double.isInfinite(logTotal)) {
                for (int c = 0; c < childStates; c++) {
                    row[c] -= logTotal;
                }
            } else {
                double uniform = -Math.log(childStates);
                for (int c = 0; c < childStates; c++) {
                    row[c] = uniform;
                }
            }
        }
        recountWeights(child, insideScores, outsideScores, unaryTransitions, binaryTransitions);
    }

    /**
     * Seeds the outside log-score table at the root of {@code tree} and then fills in the outside
     * scores of every descendant. The root receives a one-element vector holding log(1) = 0.
     *
     * @param tree root of the tree being scored
     * @param identityHashMap inside log-score vectors, keyed by node identity
     * @param identityHashMap2 outside log-score vectors to populate, keyed by node identity
     */
    public void recountOutside(Tree tree, IdentityHashMap<Tree, double[]> identityHashMap, IdentityHashMap<Tree, double[]> identityHashMap2) {
        final double[] rootOutside = {0.0d};
        identityHashMap2.put(tree, rootOutside);
        recurseOutside(tree, identityHashMap, identityHashMap2);
    }

    /**
     * Propagates outside log scores from {@code tree} (whose own outside vector must already be in
     * {@code identityHashMap2}) down to its children.
     *
     * <p>Leaves and preterminals stop the recursion. A unary node is handled by the two-node
     * overload; any other node is treated as binary and handled by the three-node overload.
     */
    public void recurseOutside(Tree tree, IdentityHashMap<Tree, double[]> identityHashMap, IdentityHashMap<Tree, double[]> identityHashMap2) {
        boolean stopsHere = tree.isLeaf() || tree.isPreTerminal();
        if (!stopsHere) {
            Tree[] kids = tree.children();
            if (kids.length == 1) {
                recountOutside(kids[0], tree, identityHashMap, identityHashMap2);
            } else {
                recountOutside(kids[0], kids[1], tree, identityHashMap, identityHashMap2);
            }
        }
    }

    /**
     * Computes the outside log scores of the only child of a unary node, then recurses into it.
     *
     * <p>For each child substate {@code c}:
     * {@code outside(child)[c] = logsum_p (unaryBetas[p][c] + outside(parent)[p])}.
     *
     * @param tree the child whose outside vector is computed
     * @param tree2 the unary parent; its outside vector must already be in {@code identityHashMap2}
     * @param identityHashMap inside log-score vectors, keyed by node identity
     * @param identityHashMap2 outside log-score vectors, keyed by node identity (updated in place)
     */
    public void recountOutside(Tree tree, Tree tree2, IdentityHashMap<Tree, double[]> identityHashMap, IdentityHashMap<Tree, double[]> identityHashMap2) {
        String parentLabel = tree2.label().value();
        String childLabel = tree.label().value();
        double[] parentOutside = identityHashMap2.get(tree2);
        double[][] transitions = this.unaryBetas.get(parentLabel, childLabel);
        int parentStates = transitions.length;
        int childStates = transitions[0].length;

        double[] childOutside = neginfDoubles(childStates);
        identityHashMap2.put(tree, childOutside);
        for (int p = 0; p < parentStates; p++) {
            double[] row = transitions[p];
            for (int c = 0; c < childStates; c++) {
                childOutside[c] = SloppyMath.logAdd(childOutside[c], row[c] + parentOutside[p]);
            }
        }
        recurseOutside(tree, identityHashMap, identityHashMap2);
    }

    /**
     * Computes the outside log scores of both children of a binary node, then recurses into each.
     *
     * <p>For left substate {@code l} and right substate {@code r}:
     * {@code outside(left)[l] = logsum_{p,r} (binaryBetas[p][l][r] + outside(parent)[p] + inside(right)[r])},
     * and symmetrically for {@code outside(right)[r]} using {@code inside(left)[l]}.
     *
     * @param tree the left child
     * @param tree2 the right child
     * @param tree3 the binary parent; its outside vector must already be in {@code identityHashMap2}
     * @param identityHashMap inside log-score vectors, keyed by node identity
     * @param identityHashMap2 outside log-score vectors, keyed by node identity (updated in place)
     */
    public void recountOutside(Tree tree, Tree tree2, Tree tree3, IdentityHashMap<Tree, double[]> identityHashMap, IdentityHashMap<Tree, double[]> identityHashMap2) {
        String parentLabel = tree3.label().value();
        String leftLabel = tree.label().value();
        String rightLabel = tree2.label().value();
        double[] leftInside = identityHashMap.get(tree);
        double[] rightInside = identityHashMap.get(tree2);
        double[] parentOutside = identityHashMap2.get(tree3);
        double[][][] transitions = this.binaryBetas.get(parentLabel, leftLabel, rightLabel);
        int parentStates = transitions.length;
        int leftStates = transitions[0].length;
        int rightStates = transitions[0][0].length;

        double[] leftOutside = neginfDoubles(leftStates);
        identityHashMap2.put(tree, leftOutside);
        double[] rightOutside = neginfDoubles(rightStates);
        identityHashMap2.put(tree2, rightOutside);

        for (int p = 0; p < parentStates; p++) {
            double parentScore = parentOutside[p];
            for (int l = 0; l < leftStates; l++) {
                double[] cell = transitions[p][l];
                for (int r = 0; r < rightStates; r++) {
                    // rule score times the parent's outside score, shared by both children
                    double shared = cell[r] + parentScore;
                    leftOutside[l] = SloppyMath.logAdd(leftOutside[l], shared + rightInside[r]);
                    rightOutside[r] = SloppyMath.logAdd(rightOutside[r], shared + leftInside[l]);
                }
            }
        }
        recurseOutside(tree, identityHashMap, identityHashMap2);
        recurseOutside(tree2, identityHashMap, identityHashMap2);
    }

    /**
     * Computes inside log-score vectors for every node of {@code tree} and stores them in
     * {@code identityHashMap}, keyed by node identity.
     *
     * <p>Preterminals are scored by the lexicon, once per substate of their tag. Unary and binary
     * nodes log-sum their children's inside vectors weighted by {@code unaryBetas} /
     * {@code binaryBetas}.
     *
     * @param tree the (sub)tree to score; must not be a leaf
     * @param z when true, preterminals whose tag is not {@code ".$$."} are treated as freshly
     *     split. The lexicon is queried only for substates {@code 0 .. count/2 - 1}, and each such
     *     score is divided between substates {@code 2k} and {@code 2k + 1} with a random weight
     *     {@code w} in [0.45, 0.55) and its complement {@code 1 - w}.
     * @param i position of the first word under {@code tree}; passed to the lexicon as the
     *     word's location (presumably the index within the sentence — confirm)
     * @param identityHashMap output map of inside log-score vectors
     * @return the position just past the last word under {@code tree}
     * @throws RuntimeException if {@code tree} is a leaf
     */
    public int recountInside(Tree tree, boolean z, int i, IdentityHashMap<Tree, double[]> identityHashMap) {
        int recountInside;
        if (tree.isLeaf()) {
            throw new RuntimeException();
        }
        if (tree.isPreTerminal()) {
            // value = the word, value2 = the (unsplit) tag label
            int stateSplitCount = getStateSplitCount(tree);
            String value = tree.children()[0].label().value();
            String value2 = tree.label().value();
            double[] dArr = new double[stateSplitCount];
            identityHashMap.put(tree, dArr);
            if (!z || value2.equals(".$$.")) {
                // Normal case: one lexicon query per substate.
                for (int i2 = 0; i2 < stateSplitCount; i2++) {
                    double score = this.lex.score(new IntTaggedWord(value, state(value2, i2), this.wordIndex, this.tagIndex), i, value, null);
                    if (DEBUG()) {
                        System.out.println("Lexicon log prob " + state(value2, i2) + "-" + value + ": " + score);
                    }
                    dArr[i2] = score;
                }
            } else {
                // Just-split case: query the lexicon for the first half of the substate indices and
                // divide each score between substates 2*i3 and 2*i3+1 with a random weight in
                // [0.45, 0.55), breaking the symmetry between the two halves of a split.
                for (int i3 = 0; i3 < stateSplitCount / 2; i3++) {
                    double score2 = this.lex.score(new IntTaggedWord(value, state(value2, i3), this.wordIndex, this.tagIndex), i, value, null);
                    double nextDouble = 0.45d + (this.random.nextDouble() * 0.1d);
                    dArr[i3 * 2] = score2 + Math.log(nextDouble);
                    dArr[(i3 * 2) + 1] = score2 + Math.log(1.0d - nextDouble);
                    if (DEBUG()) {
                        System.out.println("Lexicon log prob " + state(value2, i3) + "-" + value + ": " + score2);
                        System.out.println("  Log Split -> " + dArr[i3 * 2] + "," + dArr[(i3 * 2) + 1]);
                    }
                }
            }
            // A preterminal covers exactly one word.
            recountInside = i + 1;
        } else if (tree.children().length == 1) {
            // Unary: inside[p] = logsum_c (childInside[c] + unaryBetas[p][c])
            recountInside = recountInside(tree.children()[0], z, i, identityHashMap);
            double[] dArr2 = identityHashMap.get(tree.children()[0]);
            String value3 = tree.label().value();
            String value4 = tree.children()[0].label().value();
            double[][] dArr3 = this.unaryBetas.get(value3, value4);
            int length = dArr3.length;
            int length2 = dArr3[0].length;
            double[] neginfDoubles = neginfDoubles(length);
            identityHashMap.put(tree, neginfDoubles);
            for (int i4 = 0; i4 < length; i4++) {
                for (int i5 = 0; i5 < length2; i5++) {
                    neginfDoubles[i4] = SloppyMath.logAdd(neginfDoubles[i4], dArr2[i5] + dArr3[i4][i5]);
                }
            }
            if (DEBUG()) {
                System.out.println(value3 + " -> " + value4);
                for (int i6 = 0; i6 < length; i6++) {
                    System.out.println("  " + i6 + MorphoFeatures.KEY_VAL_DELIM + neginfDoubles[i6]);
                    for (int i7 = 0; i7 < length2; i7++) {
                        System.out.println("    " + i6 + "," + i7 + ": " + dArr3[i6][i7] + " | " + Math.exp(dArr3[i6][i7]));
                    }
                }
            }
        } else {
            // Binary: score the left child, then the right child starting at the position the left
            // child ended at; inside[p] = logsum_{l,r} (leftInside[l] + rightInside[r] + binaryBetas[p][l][r])
            recountInside = recountInside(tree.children()[1], z, recountInside(tree.children()[0], z, i, identityHashMap), identityHashMap);
            double[] dArr4 = identityHashMap.get(tree.children()[0]);
            double[] dArr5 = identityHashMap.get(tree.children()[1]);
            String value5 = tree.label().value();
            String value6 = tree.children()[0].label().value();
            String value7 = tree.children()[1].label().value();
            double[][][] dArr6 = this.binaryBetas.get(value5, value6, value7);
            int length3 = dArr6.length;
            int length4 = dArr6[0].length;
            int length5 = dArr6[0][0].length;
            double[] neginfDoubles2 = neginfDoubles(length3);
            identityHashMap.put(tree, neginfDoubles2);
            for (int i8 = 0; i8 < length3; i8++) {
                for (int i9 = 0; i9 < length4; i9++) {
                    for (int i10 = 0; i10 < length5; i10++) {
                        neginfDoubles2[i8] = SloppyMath.logAdd(neginfDoubles2[i8], dArr4[i9] + dArr5[i10] + dArr6[i8][i9][i10]);
                    }
                }
            }
            if (DEBUG()) {
                System.out.println(value5 + " -> " + value6 + "," + value7);
                for (int i11 = 0; i11 < length3; i11++) {
                    System.out.println("  " + i11 + MorphoFeatures.KEY_VAL_DELIM + neginfDoubles2[i11]);
                    for (int i12 = 0; i12 < length4; i12++) {
                        for (int i13 = 0; i13 < length5; i13++) {
                            System.out.println("    " + i11 + "," + i12 + "," + i13 + ": " + dArr6[i11][i12][i13] + " | " + Math.exp(dArr6[i11][i12][i13]));
                        }
                    }
                }
            }
        }
        return recountInside;
    }

    /**
     * Merges back the substate splits whose removal is estimated to cost the least likelihood.
     *
     * <p>Nothing happens unless {@code op.trainOptions.splitRecombineRate} is positive. Otherwise:
     * <ol>
     *   <li>the current betas are re-estimated to obtain per-state substate masses;</li>
     *   <li>{@code countMergeEffects} accumulates, over all training trees, the estimated
     *       log-likelihood change of merging each substate pair {@code (2k, 2k + 1)};</li>
     *   <li>pairs are sorted so the largest change (the cheapest merge) comes first;</li>
     *   <li>the leading {@code splitRecombineRate} fraction is merged, but never all of them.</li>
     * </ol>
     * If no candidate pairs were measured, no pairs are merged. The betas are still re-estimated
     * through the same path, just as they are when only one candidate exists.
     */
    public void mergeStates() {
        double recombineRate = this.op.trainOptions.splitRecombineRate;
        if (recombineRate <= 0.0d) {
            return;
        }
        TwoDimensionalMap<String, String, double[][]> tempUnaryBetas = new TwoDimensionalMap<>();
        ThreeDimensionalMap<String, String, String, double[][][]> tempBinaryBetas = new ThreeDimensionalMap<>();
        Map<String, double[]> totalStateMass = Generics.newHashMap();
        recalculateTemporaryBetas(false, totalStateMass, tempUnaryBetas, tempBinaryBetas);

        Map<String, double[]> mergeDeltas = Generics.newHashMap();
        for (Tree tree : this.trees) {
            countMergeEffects(tree, totalStateMass, mergeDeltas);
        }

        // One candidate per substate pair (2k, 2k + 1), identified by its even index.
        List<Triple<String, Integer, Double>> candidates = new ArrayList<>();
        for (Map.Entry<String, double[]> entry : mergeDeltas.entrySet()) {
            double[] deltas = entry.getValue();
            for (int k = 0; k < deltas.length; k++) {
                candidates.add(new Triple<String, Integer, Double>(entry.getKey(), k * 2, deltas[k]));
            }
        }

        // Largest delta first: merging those pairs loses the least estimated likelihood.
        Collections.sort(candidates, new Comparator<Triple<String, Integer, Double>>() {
            @Override
            public int compare(Triple<String, Integer, Double> a, Triple<String, Integer, Double> b) {
                return Double.compare(b.third(), a.third());
            }
        });

        // Merge the requested fraction but never every candidate. Clamping at zero keeps an empty
        // candidate list from producing subList(0, -1), which throws IllegalArgumentException.
        int mergeCount = Math.max(0, Math.min(candidates.size() - 1, (int) (candidates.size() * recombineRate)));
        List<Triple<String, Integer, Double>> merges = candidates.subList(0, mergeCount);
        System.out.println();
        System.out.println(merges);
        recalculateMergedBetas(buildMergeCorrespondence(merges));
        // Each merged pair removes one substate from its state.
        for (Triple<String, Integer, Double> merge : merges) {
            this.stateSplitCounts.decrementCount(merge.first(), 1);
        }
    }

    /**
     * Re-estimates the lexicon and betas after applying a merge correspondence.
     *
     * <p>A fresh temporary lexicon and word/tag indices are created. For every training tree, its
     * per-node transition tables are recomputed under the current grammar, collapsed with
     * {@code mergeTransitions} using {@code map}, and accumulated through
     * {@code recalculateTemporaryBetas}. Each tree is seeded with the log of its weight. The
     * accumulated totals are finally installed via {@code useNewBetas}.
     *
     * @param map state label -> array mapping each current substate to its merged index
     */
    public void recalculateMergedBetas(Map<String, int[]> map) {
        TwoDimensionalMap<String, String, double[][]> unaryTotals = new TwoDimensionalMap<>();
        ThreeDimensionalMap<String, String, String, double[][][]> binaryTotals = new ThreeDimensionalMap<>();
        this.tempWordIndex = new HashIndex<>();
        this.tempTagIndex = new HashIndex<>();
        this.tempLex = this.op.tlpParams.lex(this.op, this.tempWordIndex, this.tempTagIndex);
        this.tempLex.initializeTraining(this.trainSize);

        for (Tree tree : this.trees) {
            double weight = this.treeWeights.getCount(tree);
            double[] rootLogWeight = {Math.log(weight)};
            this.tempLex.incrementTreesRead(weight);

            IdentityHashMap<Tree, double[][]> oldUnary = new IdentityHashMap<>();
            IdentityHashMap<Tree, double[][][]> oldBinary = new IdentityHashMap<>();
            recountTree(tree, false, oldUnary, oldBinary);

            IdentityHashMap<Tree, double[][]> mergedUnary = new IdentityHashMap<>();
            IdentityHashMap<Tree, double[][][]> mergedBinary = new IdentityHashMap<>();
            mergeTransitions(tree, oldUnary, oldBinary, mergedUnary, mergedBinary, rootLogWeight, map);
            recalculateTemporaryBetas(tree, rootLogWeight, 0, mergedUnary, mergedBinary, null, unaryTotals, binaryTotals);
        }

        this.tempLex.finishTraining();
        useNewBetas(false, unaryTotals, binaryTotals);
    }

    /**
     * Rewrites the per-node transition tables under {@code tree} for the merged substate inventory.
     *
     * <p>At each unary or binary node, the node's old table is collapsed through the merge
     * correspondence. Every old entry is weighted by the log weight of its old parent substate
     * ({@code dArr}) and log-summed into the cell of the merged substates it maps to. Each merged
     * parent row is then renormalized in log space; a row with no mass becomes uniform. Before
     * recursing, each child receives a log weight vector over its OLD substates, obtained by
     * marginalizing the weighted old table.
     *
     * @param tree current node; leaves and preterminals end the recursion
     * @param identityHashMap old unary tables indexed [parentSub][childSub], keyed by node identity
     *     (presumably the per-node posteriors produced by recountTree — confirm)
     * @param identityHashMap2 old binary tables indexed [parentSub][leftSub][rightSub]
     * @param identityHashMap3 receives the merged unary tables
     * @param identityHashMap4 receives the merged binary tables
     * @param dArr log weights of the old substates of {@code tree}'s label
     * @param map state label -> array mapping each old substate to its merged index; the code
     *     takes the last entry + 1 as the number of merged substates
     */
    public void mergeTransitions(Tree tree, IdentityHashMap<Tree, double[][]> identityHashMap, IdentityHashMap<Tree, double[][][]> identityHashMap2, IdentityHashMap<Tree, double[][]> identityHashMap3, IdentityHashMap<Tree, double[][][]> identityHashMap4, double[] dArr, Map<String, int[]> map) {
        if (tree.isPreTerminal() || tree.isLeaf()) {
            return;
        }
        if (tree.children().length == 1) {
            // ---- Unary node ----
            double[][] dArr2 = identityHashMap.get(tree);
            // i / i2 = merged substate counts of the parent / child
            int[] iArr = map.get(tree.label().value());
            int i = iArr[iArr.length - 1] + 1;
            int[] iArr2 = map.get(tree.children()[0].label().value());
            int i2 = iArr2[iArr2.length - 1] + 1;
            double[][] dArr3 = new double[i][i2];
            for (int i3 = 0; i3 < i; i3++) {
                for (int i4 = 0; i4 < i2; i4++) {
                    dArr3[i3][i4] = Double.NEGATIVE_INFINITY;
                }
            }
            identityHashMap3.put(tree, dArr3);
            // Collapse old cells into merged cells, weighting by the old parent substate's weight.
            for (int i5 = 0; i5 < dArr2.length; i5++) {
                int i6 = iArr[i5];
                for (int i7 = 0; i7 < dArr2[0].length; i7++) {
                    int i8 = iArr2[i7];
                    dArr3[i6][i8] = SloppyMath.logAdd(dArr3[i6][i8], dArr2[i5][i7] + dArr[i5]);
                }
            }
            // Renormalize each merged parent row; an all -infinity row becomes uniform.
            for (int i9 = 0; i9 < i; i9++) {
                double d = Double.NEGATIVE_INFINITY;
                for (int i10 = 0; i10 < i2; i10++) {
                    d = SloppyMath.logAdd(d, dArr3[i9][i10]);
                }
                if (Double.isInfinite(d)) {
                    for (int i11 = 0; i11 < i2; i11++) {
                        dArr3[i9][i11] = -Math.log(i2);
                    }
                } else {
                    for (int i12 = 0; i12 < i2; i12++) {
                        double[] dArr4 = dArr3[i9];
                        int i13 = i12;
                        dArr4[i13] = dArr4[i13] - d;
                    }
                }
            }
            // Weights over the child's OLD substates, marginalized from the weighted old table.
            double[] neginfDoubles = neginfDoubles(dArr2[0].length);
            for (int i14 = 0; i14 < dArr2.length; i14++) {
                for (int i15 = 0; i15 < dArr2[0].length; i15++) {
                    neginfDoubles[i15] = SloppyMath.logAdd(neginfDoubles[i15], dArr2[i14][i15] + dArr[i14]);
                }
            }
            mergeTransitions(tree.children()[0], identityHashMap, identityHashMap2, identityHashMap3, identityHashMap4, neginfDoubles, map);
            return;
        }
        // ---- Binary node ----
        double[][][] dArr5 = identityHashMap2.get(tree);
        // i16 / i17 / i18 = merged substate counts of the parent / left child / right child
        int[] iArr3 = map.get(tree.label().value());
        int i16 = iArr3[iArr3.length - 1] + 1;
        int[] iArr4 = map.get(tree.children()[0].label().value());
        int i17 = iArr4[iArr4.length - 1] + 1;
        int[] iArr5 = map.get(tree.children()[1].label().value());
        int i18 = iArr5[iArr5.length - 1] + 1;
        double[][][] dArr6 = new double[i16][i17][i18];
        for (int i19 = 0; i19 < i16; i19++) {
            for (int i20 = 0; i20 < i17; i20++) {
                for (int i21 = 0; i21 < i18; i21++) {
                    dArr6[i19][i20][i21] = Double.NEGATIVE_INFINITY;
                }
            }
        }
        identityHashMap4.put(tree, dArr6);
        // Collapse old cells into merged cells, weighting by the old parent substate's weight.
        for (int i22 = 0; i22 < dArr5.length; i22++) {
            int i23 = iArr3[i22];
            for (int i24 = 0; i24 < dArr5[0].length; i24++) {
                int i25 = iArr4[i24];
                for (int i26 = 0; i26 < dArr5[0][0].length; i26++) {
                    int i27 = iArr5[i26];
                    dArr6[i23][i25][i27] = SloppyMath.logAdd(dArr6[i23][i25][i27], dArr5[i22][i24][i26] + dArr[i22]);
                }
            }
        }
        // Renormalize each merged parent slice; an all -infinity slice becomes uniform.
        for (int i28 = 0; i28 < i16; i28++) {
            double d2 = Double.NEGATIVE_INFINITY;
            for (int i29 = 0; i29 < i17; i29++) {
                for (int i30 = 0; i30 < i18; i30++) {
                    d2 = SloppyMath.logAdd(d2, dArr6[i28][i29][i30]);
                }
            }
            if (Double.isInfinite(d2)) {
                for (int i31 = 0; i31 < i17; i31++) {
                    for (int i32 = 0; i32 < i18; i32++) {
                        dArr6[i28][i31][i32] = -Math.log(i17 * i18);
                    }
                }
            } else {
                for (int i33 = 0; i33 < i17; i33++) {
                    for (int i34 = 0; i34 < i18; i34++) {
                        double[] dArr7 = dArr6[i28][i33];
                        int i35 = i34;
                        dArr7[i35] = dArr7[i35] - d2;
                    }
                }
            }
        }
        // Weights over each child's OLD substates, marginalized from the weighted old table.
        double[] neginfDoubles2 = neginfDoubles(dArr5[0].length);
        double[] neginfDoubles3 = neginfDoubles(dArr5[0][0].length);
        for (int i36 = 0; i36 < dArr5.length; i36++) {
            for (int i37 = 0; i37 < dArr5[0].length; i37++) {
                for (int i38 = 0; i38 < dArr5[0][0].length; i38++) {
                    double d3 = dArr5[i36][i37][i38];
                    neginfDoubles2[i37] = SloppyMath.logAdd(neginfDoubles2[i37], d3 + dArr[i36]);
                    neginfDoubles3[i38] = SloppyMath.logAdd(neginfDoubles3[i38], d3 + dArr[i36]);
                }
            }
        }
        mergeTransitions(tree.children()[0], identityHashMap, identityHashMap2, identityHashMap3, identityHashMap4, neginfDoubles2, map);
        mergeTransitions(tree.children()[1], identityHashMap, identityHashMap2, identityHashMap3, identityHashMap4, neginfDoubles3, map);
    }

    /**
     * Builds, for every original state, the mapping from each current substate index to its index
     * after the given merges.
     *
     * <p>Every state starts with the identity mapping. Each merge {@code (state, 2k, delta)}
     * decrements the entries above {@code 2k}, so substate {@code 2k + 1} joins {@code 2k} and all
     * higher substates shift down by one. Several merges of the same state accumulate.
     *
     * @param list merges to apply; the second component is the even index of the merged pair
     * @return state label -> array whose entry {@code s} is the merged index of substate {@code s}
     */
    Map<String, int[]> buildMergeCorrespondence(List<Triple<String, Integer, Double>> list) {
        Map<String, int[]> correspondence = Generics.newHashMap();
        for (String label : this.originalStates) {
            int[] target = new int[getStateSplitCount(label)];
            int s = 0;
            while (s < target.length) {
                target[s] = s;
                s++;
            }
            correspondence.put(label, target);
        }
        for (Triple<String, Integer, Double> merge : list) {
            String label = merge.first();
            int splits = getStateSplitCount(label);
            int firstShifted = merge.second() + 1;
            int[] target = correspondence.get(label);
            for (int s = firstShifted; s < splits; s++) {
                target[s]--;
            }
        }
        return correspondence;
    }

    /**
     * Adds one training tree's contribution to the merge-cost estimates.
     *
     * <p>Inside and outside log scores for every node are computed with {@code recountTree}. Each
     * subtree under the root is then visited by the recursive overload; the root node itself is
     * not scored.
     *
     * @param tree a training tree
     * @param map per-state substate totals, keyed by state label (used as substate masses)
     * @param map2 accumulated per-pair merge deltas, keyed by state label (updated in place)
     */
    public void countMergeEffects(Tree tree, Map<String, double[]> map, Map<String, double[]> map2) {
        IdentityHashMap<Tree, double[]> probIn = new IdentityHashMap<>();
        IdentityHashMap<Tree, double[]> probOut = new IdentityHashMap<>();
        recountTree(tree, false, probIn, probOut, new IdentityHashMap<>(), new IdentityHashMap<>());
        Tree[] children = tree.children();
        for (int c = 0; c < children.length; c++) {
            countMergeEffects(children[c], map, map2, probIn, probOut);
        }
    }

    /**
     * Recursively accumulates, for every substate pair {@code (2k, 2k + 1)} of every node's label,
     * the estimated log-likelihood change from merging that pair.
     *
     * <p>Per node and pair, the change added is {@code merged - unmerged}, where
     * {@code merged = logsum(log(m0/M) + in0, log(m1/M) + in1) + logsum(out0, out1)} and
     * {@code unmerged = logsum(in0 + out0, in1 + out1)}. Here {@code m} are the label's substate
     * masses and {@code M} is their sum. Leaves and nodes labelled {@code ".$$."} are skipped, and
     * the recursion stops below preterminals.
     *
     * <p>NOTE(review): {@code M} sums the masses of ALL substates of the label, not just the pair.
     * The split-merge formulation of Petrov et al. (2006) weights the pair by its relative
     * frequencies within the pair. Confirm whether this normalization is intended.
     *
     * @param tree current node
     * @param map per-state substate masses, keyed by state label
     * @param map2 accumulated deltas per label, one entry per pair (created on first use)
     * @param identityHashMap inside log-score vectors, keyed by node identity
     * @param identityHashMap2 outside log-score vectors, keyed by node identity
     */
    public void countMergeEffects(Tree tree, Map<String, double[]> map, Map<String, double[]> map2, IdentityHashMap<Tree, double[]> identityHashMap, IdentityHashMap<Tree, double[]> identityHashMap2) {
        if (tree.isLeaf()) {
            return;
        }
        String label = tree.label().value();
        if (label.equals(".$$.")) {
            return;
        }

        double[] stateMass = map.get(label);
        double totalMass = 0.0d;
        for (int k = 0; k < stateMass.length; k++) {
            totalMass += stateMass[k];
        }

        double[] probIn = identityHashMap.get(tree);
        double[] probOut = identityHashMap2.get(tree);
        int pairs = probIn.length / 2;
        double[] deltas = map2.get(label);
        if (deltas == null) {
            deltas = new double[pairs];
            map2.put(label, deltas);
        }

        for (int k = 0; k < pairs; k++) {
            int a = 2 * k;
            int b = a + 1;
            double mergedIn = SloppyMath.logAdd(Math.log(stateMass[a] / totalMass) + probIn[a], Math.log(stateMass[b] / totalMass) + probIn[b]);
            double mergedOut = SloppyMath.logAdd(probOut[a], probOut[b]);
            double unmerged = SloppyMath.logAdd(probIn[a] + probOut[a], probIn[b] + probOut[b]);
            deltas[k] = (deltas[k] + (mergedIn + mergedOut)) - unmerged;
        }

        if (!tree.isPreTerminal()) {
            for (Tree child : tree.children()) {
                countMergeEffects(child, map, map2, identityHashMap, identityHashMap2);
            }
        }
    }

    /**
     * Replaces {@code stateIndex} with a fresh index containing every substate name
     * {@code state(label, k)} for each label in {@code stateSplitCounts} and each
     * {@code k < getIntCount(label)}.
     */
    public void buildStateIndex() {
        this.stateIndex = new HashIndex<>();
        for (String label : this.stateSplitCounts.keySet()) {
            int splits = this.stateSplitCounts.getIntCount(label);
            for (int k = 0; k < splits; ++k) {
                String substate = state(label, k);
                this.stateIndex.indexOf(substate, true);
            }
        }
    }

    /**
     * Builds the final unary and binary grammars over split substates and stores them in
     * {@code bgug}.
     *
     * <p>Betas and per-state substate totals are re-estimated with
     * {@code recalculateTemporaryBetas}. Rules are emitted only for parent substates whose total is
     * at least 1e-4. Each rule's score is the accumulated beta minus the log of that parent total
     * (presumably turning accumulated counts into conditional log probabilities — confirm).
     */
    public void buildGrammars() {
        TwoDimensionalMap<String, String, double[][]> unaryBetaTotals = new TwoDimensionalMap<>();
        ThreeDimensionalMap<String, String, String, double[][][]> binaryBetaTotals = new ThreeDimensionalMap<>();
        Map<String, double[]> stateTotals = Generics.newHashMap();
        recalculateTemporaryBetas(false, stateTotals, unaryBetaTotals, binaryBetaTotals);

        BinaryGrammar bg = new BinaryGrammar(this.stateIndex);
        for (String parent : binaryBetaTotals.firstKeySet()) {
            int parentSplits = getStateSplitCount(parent);
            double[] parentMass = stateTotals.get(parent);
            for (String left : binaryBetaTotals.get(parent).firstKeySet()) {
                int leftSplits = getStateSplitCount(left);
                for (String right : binaryBetaTotals.get(parent).get(left).keySet()) {
                    int rightSplits = getStateSplitCount(right);
                    double[][][] betas = binaryBetaTotals.get(parent, left, right);
                    for (int p = 0; p < parentSplits; p++) {
                        if (parentMass[p] >= 1.0E-4d) {
                            int parentId = this.stateIndex.indexOf(state(parent, p));
                            double logMass = Math.log(parentMass[p]);
                            for (int l = 0; l < leftSplits; l++) {
                                int leftId = this.stateIndex.indexOf(state(left, l));
                                for (int r = 0; r < rightSplits; r++) {
                                    int rightId = this.stateIndex.indexOf(state(right, r));
                                    bg.addRule(new BinaryRule(parentId, leftId, rightId, betas[p][l][r] - logMass));
                                }
                            }
                        }
                    }
                }
            }
        }

        UnaryGrammar ug = new UnaryGrammar(this.stateIndex);
        for (String parent : unaryBetaTotals.firstKeySet()) {
            int parentSplits = getStateSplitCount(parent);
            double[] parentMass = stateTotals.get(parent);
            for (String child : unaryBetaTotals.get(parent).keySet()) {
                int childSplits = getStateSplitCount(child);
                double[][] betas = unaryBetaTotals.get(parent, child);
                for (int p = 0; p < parentSplits; p++) {
                    if (parentMass[p] >= 1.0E-4d) {
                        int parentId = this.stateIndex.indexOf(state(parent, p));
                        double logMass = Math.log(parentMass[p]);
                        for (int c = 0; c < childSplits; c++) {
                            int childId = this.stateIndex.indexOf(state(child, c));
                            ug.addRule(new UnaryRule(parentId, childId, betas[p][c] - logMass));
                        }
                    }
                }
            }
        }
        this.bgug = new Pair<>(ug, bg);
    }

    /**
     * Replaces the stored training trees and weights and recomputes {@code trainSize}.
     *
     * <p>Every tree of {@code collection} gets weight {@code d}. Trees of {@code collection2} are
     * added with weight {@code d2} only when {@code collection2} is non-null and {@code d2} is
     * non-negative. {@code trainSize} ends up as the sum of all added weights. A summary line is
     * printed to standard error.
     */
    public void saveTrees(Collection<Tree> collection, double d, Collection<Tree> collection2, double d2) {
        this.trainSize = 0.0d;
        this.trees.clear();
        this.treeWeights.clear();
        int treeCount = addWeightedTrees(collection, d);
        if (collection2 != null && d2 >= 0.0d) {
            treeCount += addWeightedTrees(collection2, d2);
        }
        System.err.println("Found " + treeCount + " trees with total weight " + this.trainSize);
    }

    /** Appends each tree of {@code source} with {@code weight}, adding it to trainSize; returns {@code source.size()}. */
    private int addWeightedTrees(Collection<Tree> source, double weight) {
        for (Tree tree : source) {
            this.trees.add(tree);
            this.treeWeights.incrementCount(tree, weight);
            this.trainSize += weight;
        }
        return source.size();
    }

    /**
     * Trains the split grammar on {@code collection} alone, with every tree weighted 1.0 and no
     * secondary tree collection.
     */
    public void extract(Collection<Tree> collection) {
        final double unusedSecondaryWeight = 0.0d;
        extract(collection, 1.0d, null, unusedSecondaryWeight);
    }

    /**
     * Runs the full split/EM/merge training loop and then builds the state index and grammars.
     *
     * <p>After saving the weighted trees and initializing betas and lexicon, each of
     * {@code op.trainOptions.splitCount} cycles splits every state's substates and re-estimates
     * once with the fresh split. It then repeats {@code recalculateBetas(false)} until that reports
     * convergence (or {@code MAX_ITERATIONS} is reached), and finally merges back splits via
     * {@code mergeStates()}.
     *
     * @param collection primary training trees
     * @param d weight of each primary tree
     * @param collection2 optional secondary trees (may be null)
     * @param d2 weight of each secondary tree; secondary trees are used only if this is non-negative
     */
    public void extract(Collection<Tree> collection, double d, Collection<Tree> collection2, double d2) {
        saveTrees(collection, d, collection2, d2);
        countOriginalStates();
        initialBetasAndLexicon();
        int cycle = 0;
        while (cycle < this.op.trainOptions.splitCount) {
            splitStateCounts();
            recalculateBetas(true);
            boolean converged = false;
            for (this.iteration = 0; !converged && this.iteration < MAX_ITERATIONS; this.iteration++) {
                if (DEBUG()) {
                    System.out.println();
                    System.out.println();
                    System.out.println("-------------------");
                    System.out.println("Iteration " + this.iteration);
                }
                converged = recalculateBetas(false);
            }
            System.err.println("Converged for cycle " + cycle + " in " + this.iteration + " iterations");
            mergeStates();
            cycle++;
        }
        buildStateIndex();
        buildGrammars();
    }
}
