package edu.stanford.nlp.parser.lexparser;

import edu.stanford.nlp.classify.LinearClassifier;
import edu.stanford.nlp.io.NumberRangeFileFilter;
import edu.stanford.nlp.io.NumberRangesFileFilter;
import edu.stanford.nlp.ling.CategoryWordTag;
import edu.stanford.nlp.ling.CategoryWordTagFactory;
import edu.stanford.nlp.ling.Label;
import edu.stanford.nlp.ling.StringLabelFactory;
import edu.stanford.nlp.ling.Word;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.trees.DiskTreebank;
import edu.stanford.nlp.trees.HeadFinder;
import edu.stanford.nlp.trees.LabeledScoredTreeFactory;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.TreeFactory;
import edu.stanford.nlp.trees.TreeTransformer;
import edu.stanford.nlp.trees.TreebankLanguagePack;
import java.io.FileFilter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;

/**
 * A {@link TreeTransformer} that prepares treebank trees for grammar training.
 *
 * <p>For each input tree it:
 * <ol>
 *   <li>annotates the tree, using a {@link TreeAnnotator} or a minimal annotator that only adds head
 *       words and tags;</li>
 *   <li>optionally applies a {@link PostSplitter};</li>
 *   <li>optionally records annotated rule and state counts;</li>
 *   <li>relabels the root with the start symbol and appends an extra pre-terminal as its last child;</li>
 *   <li>binarizes the tree with a {@link TreeBinarizer};</li>
 *   <li>optionally collapses unary chains (see {@link ToCNFTransformer}).</li>
 * </ol>
 *
 * <p>Instances keep unsynchronized mutable statistics and decrement the shared
 * {@code trainOptions.printTreeTransformations} counter, so they are not designed for concurrent use.
 */
public class TreeAnnotatorAndBinarizer implements TreeTransformer {

    /** Word used for the extra leaf that {@link #addRoot(Tree)} appends under the root. */
    private static final String BOUNDARY_WORD = ".$.";

    /** Category/tag used for the extra pre-terminal that {@link #addRoot(Tree)} appends. */
    private static final String BOUNDARY_TAG = ".$$.";

    /** Stateless unary-chain collapser; one shared instance suffices. */
    private static final TreeTransformer TO_CNF = new ToCNFTransformer();

    private final TreeFactory tf;
    private final TreebankLanguagePack tlp;
    private final TreeTransformer annotator;
    private final TreeBinarizer binarizer;
    /** Non-null only if {@code trainOptions.selectivePostSplit} was set at construction time. */
    private final PostSplitter postSplitter;
    private final boolean forceCNF;
    /** Non-null only if {@code trainOptions.printAnnotatedRuleCounts} was set at construction time. */
    private final ClassicCounter<Tree> annotatedRuleCounts;
    /** Non-null only if {@code trainOptions.printAnnotatedStateCounts} was set at construction time. */
    private final ClassicCounter<String> annotatedStateCounts;
    private final TrainOptions trainOptions;

    /**
     * Minimal annotator. It copies the tree skeleton and labels every phrasal node with a
     * {@link CategoryWordTag} that carries its head word and head tag. Leaves become {@link Word}s.
     */
    static class TreeNullAnnotator implements TreeTransformer {
        private final TreeFactory tf = new LabeledScoredTreeFactory(new CategoryWordTagFactory());
        private final HeadFinder hf;

        public TreeNullAnnotator(HeadFinder headFinder) {
            this.hf = headFinder;
        }

        @Override
        public Tree transformTree(Tree tree) {
            // Work on a copy so that the caller's tree is not relabelled.
            return transformTreeHelper(tree.treeSkeletonCopy(this.tf));
        }

        /** Relabels {@code tree} in place (bottom-up) and returns it; {@code null} is returned as-is. */
        private Tree transformTreeHelper(Tree tree) {
            if (tree == null) {
                return null;
            }
            String category = tree.label().value();
            if (tree.isLeaf()) {
                tree.setLabel(new Word(category));
                return tree;
            }
            // Children first, so that a non-leaf head already carries a CategoryWordTag label.
            for (Tree child : tree.children()) {
                transformTreeHelper(child);
            }
            String word;
            String tag;
            Tree head = this.hf.determineHead(tree);
            if (head == null) {
                System.err.println("ERROR: null head for tree\n" + tree.toString());
                word = null;
                tag = null;
            } else if (head.isLeaf()) {
                tag = category;
                word = head.label().value();
            } else {
                CategoryWordTag headLabel = (CategoryWordTag) head.label();
                word = headLabel.word();
                tag = headLabel.tag();
            }
            tree.setLabel(new CategoryWordTag(category, word, tag));
            return tree;
        }
    }

    /**
     * Collapses chains of unary nodes into a single node. The node's label joins the chain's labels,
     * each prefixed with {@code &}; labels starting with {@code @} are skipped. Nodes with more than
     * one child, pre-terminals, and nodes whose label starts with {@code "ROOT"} are copied with their
     * label unchanged. This mirrors {@code CNFTransformers$ToCNFTransformer}.
     */
    private static final class ToCNFTransformer implements TreeTransformer {
        @Override
        public Tree transformTree(Tree tree) {
            if (tree.isLeaf()) {
                return tree.treeFactory().newLeaf(tree.label());
            }
            Tree[] children = tree.children();
            if (children.length > 1 || tree.isPreTerminal() || tree.label().value().startsWith("ROOT")) {
                return tree.treeFactory().newTreeNode(tree.label(), transformAll(children));
            }
            // Walk down the unary chain, collecting the labels of non-intermediate ("@") nodes.
            Tree bottom = tree;
            List<String> chainLabels = new ArrayList<>();
            while (bottom.children().length == 1 && !bottom.isPrePreTerminal()) {
                String label = bottom.label().value();
                if (!label.startsWith("@")) {
                    chainLabels.add(label);
                }
                bottom = bottom.children()[0];
            }
            String bottomLabel = bottom.label().value();
            if (!bottomLabel.startsWith("@")) {
                chainLabels.add(bottomLabel);
            }
            String newLabel;
            if (chainLabels.size() > 1) {
                StringBuilder sb = new StringBuilder();
                for (String label : chainLabels) {
                    sb.append("&").append(label);
                }
                newLabel = sb.toString();
            } else if (chainLabels.size() == 1) {
                newLabel = chainLabels.get(0);
            } else {
                // The whole chain consisted of intermediate nodes: drop this node.
                return transformTree(tree.children()[0]);
            }
            Label label = tree.label().labelFactory().newLabel(newLabel);
            return tree.treeFactory().newTreeNode(label, transformAll(bottom.children()));
        }

        /** Transforms each tree of {@code trees} and returns the results in order. */
        private List<Tree> transformAll(Tree[] trees) {
            Tree[] out = new Tree[trees.length];
            for (int k = 0; k < trees.length; k++) {
                out[k] = transformTree(trees[k]);
            }
            return Arrays.asList(out);
        }
    }

    /**
     * Creates a transformer that uses {@code tlpParams.headFinder()} both for annotation and for
     * binarization.
     *
     * @param tlpParams language-specific parameters
     * @param forceCNF whether to collapse unary chains after binarization
     * @param insideFactor forwarded unchanged as the third argument of the {@link TreeBinarizer} constructor
     * @param doSubcategorization if true annotate with {@link TreeAnnotator}, otherwise only add head word/tag
     * @param options parser options; {@code options.trainOptions} is read (and later mutated)
     */
    public TreeAnnotatorAndBinarizer(TreebankLangParserParams tlpParams, boolean forceCNF, boolean insideFactor,
                                     boolean doSubcategorization, Options options) {
        this(tlpParams.headFinder(), tlpParams.headFinder(), tlpParams, forceCNF, insideFactor,
             doSubcategorization, options);
    }

    /**
     * Creates a transformer with separate head finders for annotation and binarization.
     *
     * @param annotationHF head finder used by the annotator
     * @param binarizationHF head finder used by the {@link TreeBinarizer}
     * @param tlpParams language-specific parameters
     * @param forceCNF whether to collapse unary chains after binarization
     * @param insideFactor forwarded unchanged as the third argument of the {@link TreeBinarizer} constructor
     * @param doSubcategorization if true annotate with {@link TreeAnnotator}, otherwise only add head word/tag
     * @param options parser options; {@code options.trainOptions} is read (and later mutated)
     */
    public TreeAnnotatorAndBinarizer(HeadFinder annotationHF, HeadFinder binarizationHF,
                                     TreebankLangParserParams tlpParams, boolean forceCNF, boolean insideFactor,
                                     boolean doSubcategorization, Options options) {
        this.trainOptions = options.trainOptions;
        if (doSubcategorization) {
            this.annotator = new TreeAnnotator(annotationHF, tlpParams, options);
        } else {
            this.annotator = new TreeNullAnnotator(annotationHF);
        }
        this.binarizer = new TreeBinarizer(binarizationHF, tlpParams.treebankLanguagePack(), insideFactor,
                                           trainOptions.markovFactor, trainOptions.markovOrder,
                                           trainOptions.compactGrammar() > 0, trainOptions.compactGrammar() > 1,
                                           trainOptions.HSEL_CUT, trainOptions.markFinalStates);
        if (trainOptions.selectivePostSplit) {
            this.postSplitter = new PostSplitter(tlpParams, options);
        } else {
            this.postSplitter = null;
        }
        this.tf = new LabeledScoredTreeFactory(new CategoryWordTagFactory());
        this.tlp = tlpParams.treebankLanguagePack();
        this.forceCNF = forceCNF;
        if (trainOptions.printAnnotatedRuleCounts) {
            this.annotatedRuleCounts = new ClassicCounter<Tree>();
        } else {
            this.annotatedRuleCounts = null;
        }
        if (trainOptions.printAnnotatedStateCounts) {
            this.annotatedStateCounts = new ClassicCounter<String>();
        } else {
            this.annotatedStateCounts = null;
        }
    }

    /** Dumps the {@link PostSplitter}'s statistics, if selective post-splitting is enabled. */
    public void dumpStats() {
        if (trainOptions.selectivePostSplit && postSplitter != null) {
            postSplitter.dumpStats();
        }
    }

    /** Forwards the selective-split flag to the underlying {@link TreeBinarizer}. */
    public void setDoSelectiveSplit(boolean doSelectiveSplit) {
        binarizer.setDoSelectiveSplit(doSelectiveSplit);
    }

    /**
     * Relabels the root of {@code tree} with the start symbol and appends a pre-terminal
     * {@code .$$. -> .$.} as its last child. The tree is modified in place.
     *
     * <p>If {@code tree} is a leaf, the wrapper root created here cannot be returned through this
     * {@code void} method. Callers that may pass a leaf should not rely on the argument being rooted
     * afterwards.
     *
     * @param tree tree to modify
     */
    public void addRoot(Tree tree) {
        addRootNode(tree);
    }

    /**
     * Does the work of {@link #addRoot(Tree)} and returns the actual root. When {@code tree} is a
     * leaf, the returned root is a new node wrapping it.
     */
    private Tree addRootNode(Tree tree) {
        Tree root = tree;
        if (root.isLeaf()) {
            System.err.println("Warning: tree is leaf: " + root);
            root = tf.newTreeNode(tlp.startSymbol(), Collections.singletonList(root));
        }
        root.setLabel(new CategoryWordTag(tlp.startSymbol(), BOUNDARY_WORD, BOUNDARY_TAG));
        List<Tree> boundaryLeaf = new ArrayList<>();
        boundaryLeaf.add(tf.newLeaf(new Word(BOUNDARY_WORD)));
        Tree boundaryPreTerminal =
            tf.newTreeNode(new CategoryWordTag(BOUNDARY_TAG, BOUNDARY_WORD, BOUNDARY_TAG), boundaryLeaf);
        List<Tree> children = root.getChildrenAsList();
        children.add(boundaryPreTerminal);
        root.setChildren(children);
        return root;
    }

    /**
     * Annotates, optionally post-splits, roots, binarizes and optionally CNF-collapses {@code tree}.
     * While {@code trainOptions.printTreeTransformations > 0}, each stage is printed and the counter
     * is decremented once per call.
     *
     * @param tree input treebank tree
     * @return the transformed tree
     */
    @Override
    public Tree transformTree(Tree tree) {
        if (trainOptions.printTreeTransformations > 0) {
            TrainOptions.printTrainTree(null, "ORIGINAL TREE:", tree);
        }
        Tree annotated = annotator.transformTree(tree);
        if (trainOptions.selectivePostSplit) {
            annotated = postSplitter.transformTree(annotated);
        }
        if (trainOptions.printTreeTransformations > 0) {
            TrainOptions.printTrainTree(trainOptions.printAnnotatedPW, "ANNOTATED TREE:", annotated);
        }
        if (trainOptions.printAnnotatedRuleCounts) {
            // Count rules over plain string labels, so that head word/tag do not split the counts.
            Tree plain = annotated.deepCopy(new LabeledScoredTreeFactory(), new StringLabelFactory());
            Iterator<Tree> rules = plain.localTrees().iterator();
            while (rules.hasNext()) {
                annotatedRuleCounts.incrementCount(rules.next());
            }
        }
        if (trainOptions.printAnnotatedStateCounts) {
            Iterator<Tree> nodes = annotated.iterator();
            while (nodes.hasNext()) {
                Tree node = nodes.next();
                if (!node.isLeaf()) {
                    annotatedStateCounts.incrementCount(node.label().value());
                }
            }
        }
        // Use the returned root: for a leaf input, addRootNode creates a new wrapping node.
        Tree rooted = addRootNode(annotated);
        Tree binarized = binarizer.transformTree(rooted);
        if (trainOptions.printTreeTransformations > 0) {
            TrainOptions.printTrainTree(trainOptions.printBinarizedPW, "BINARIZED TREE:", binarized);
            trainOptions.printTreeTransformations--;
        }
        if (forceCNF) {
            binarized = TO_CNF.transformTree(binarized);
        }
        return binarized;
    }

    /**
     * Prints the collected annotated rule counts to stderr. Counts exist only if
     * {@code trainOptions.printAnnotatedRuleCounts} was set when this object was constructed.
     */
    public void printRuleCounts() {
        if (annotatedRuleCounts == null) {
            System.err.println("Annotated rule counts were not collected (printAnnotatedRuleCounts was off at construction).");
            return;
        }
        System.err.println();
        for (Tree rule : annotatedRuleCounts.keySet()) {
            System.err.print(annotatedRuleCounts.getCount(rule) + LinearClassifier.TEXT_SERIALIZATION_DELIMITER
                             + rule.label().value() + " -->");
            for (Tree child : rule.getChildrenAsList()) {
                System.err.print(" ");
                System.err.print(child.label().value());
            }
            System.err.println();
        }
    }

    /**
     * Prints the collected annotated state counts to stderr, sorted by state name. Counts exist only
     * if {@code trainOptions.printAnnotatedStateCounts} was set when this object was constructed.
     */
    public void printStateCounts() {
        if (annotatedStateCounts == null) {
            System.err.println("Annotated state counts were not collected (printAnnotatedStateCounts was off at construction).");
            return;
        }
        System.err.println();
        System.err.println("Annotated state counts");
        List<String> states = new ArrayList<>(annotatedStateCounts.keySet());
        Collections.sort(states);
        for (String state : states) {
            System.err.println(state + LinearClassifier.TEXT_SERIALIZATION_DELIMITER + annotatedStateCounts.getCount(state));
        }
    }

    /**
     * Counts how many arguments directly after {@code args[index]} do not start with {@code '-'}.
     * Empty-string arguments count as sub-arguments.
     */
    private static int numSubArgs(String[] args, int index) {
        int last = index;
        while (last + 1 < args.length && !args[last + 1].startsWith("-")) {
            last++;
        }
        return last - index;
    }

    /**
     * Command-line entry point. Reads a treebank given by
     * {@code -train treebankPath [fileRanges | low high]} (other options are passed to
     * {@link Options#setOption}). Prints each original tree next to its annotated, binarized form.
     */
    public static void main(String[] args) {
        Options options = new Options();
        String treebankPath = null;
        FileFilter trainFilter = null;
        int i = 0;
        while (i < args.length && args[i].startsWith("-")) {
            if (args[i].equalsIgnoreCase("-train")) {
                int numSubArgs = numSubArgs(args, i);
                i++;
                if (numSubArgs < 1) {
                    throw new RuntimeException("Error: -train option must have treebankPath as first argument.");
                }
                treebankPath = args[i];
                i++;
                if (numSubArgs == 2) {
                    // Read the range spec first, then advance past it.
                    trainFilter = new NumberRangesFileFilter(args[i], true);
                    i++;
                } else if (numSubArgs >= 3) {
                    trainFilter = new NumberRangeFileFilter(Integer.parseInt(args[i]), Integer.parseInt(args[i + 1]), true);
                    i += 2;
                }
            } else {
                i = options.setOption(args, i);
            }
        }
        if (i < args.length || treebankPath == null) {
            System.err.println("usage: java TreeAnnotatorAndBinarizer options*");
            System.err.println("  Options are like for lexicalized parser including -train treebankPath [fileRange]");
            System.exit(0);
        }
        System.err.println("Annotating from treebank dir: " + treebankPath);
        DiskTreebank treebank = options.tlpParams.diskTreebank();
        if (trainFilter == null) {
            treebank.loadPath(treebankPath);
        } else {
            treebank.loadPath(treebankPath, trainFilter);
        }
        List<Tree> binarizedTrees = LexicalizedParser.getAnnotatedBinaryTreebankFromTreebank(treebank, null, options).first();
        Iterator<Tree> originals = treebank.iterator();
        for (Tree binarized : binarizedTrees) {
            if (!originals.hasNext()) {
                System.err.println("Warning: fewer original trees than binarized trees; stopping.");
                break;
            }
            System.out.println("Original tree:");
            originals.next().pennPrint();
            System.out.println("Binarized tree:");
            binarized.pennPrint();
            System.out.println();
        }
    }
}
