package edu.stanford.nlp.parser.shiftreduce;

import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.HasTag;
import edu.stanford.nlp.ling.HasWord;
import edu.stanford.nlp.ling.Label;
import edu.stanford.nlp.ling.TaggedWord;
import edu.stanford.nlp.parser.common.ArgUtils;
import edu.stanford.nlp.parser.common.ParserGrammar;
import edu.stanford.nlp.parser.common.ParserQuery;
import edu.stanford.nlp.parser.common.ParserUtils;
import edu.stanford.nlp.parser.lexparser.BinaryHeadFinder;
import edu.stanford.nlp.parser.lexparser.EvaluateTreebank;
import edu.stanford.nlp.parser.lexparser.Options;
import edu.stanford.nlp.parser.lexparser.TreeBinarizer;
import edu.stanford.nlp.parser.lexparser.TreebankLangParserParams;
import edu.stanford.nlp.parser.metrics.Eval;
import edu.stanford.nlp.parser.metrics.ParserQueryEval;
import edu.stanford.nlp.parser.shiftreduce.TreeRecorder;
import edu.stanford.nlp.tagger.common.Tagger;
import edu.stanford.nlp.trees.BasicCategoryTreeTransformer;
import edu.stanford.nlp.trees.CompositeTreeTransformer;
import edu.stanford.nlp.trees.LabeledScoredTreeNode;
import edu.stanford.nlp.trees.MemoryTreebank;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.TreeCoreAnnotations;
import edu.stanford.nlp.trees.Treebank;
import edu.stanford.nlp.trees.TreebankLanguagePack;
import edu.stanford.nlp.trees.Trees;
import edu.stanford.nlp.util.ArrayUtils;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.ReflectionLoading;
import edu.stanford.nlp.util.StringUtils;
import edu.stanford.nlp.util.Timing;
import edu.stanford.nlp.util.concurrent.MulticoreWrapper;
import edu.stanford.nlp.util.concurrent.ThreadsafeProcessor;
import edu.stanford.nlp.util.logging.Redwood;
import java.io.FileFilter;
import java.io.IOException;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import java.util.Set;

/* loaded from: input_file:edu/stanford/nlp/parser/shiftreduce/ShiftReduceParser.class */
public class ShiftReduceParser extends ParserGrammar implements Serializable {
    /** Options this parser was constructed with; also what {@link #getOp()} returns. */
    final ShiftReduceOptions op;
    /** Scoring model. May be null: the one-argument constructor leaves it unset and train() assigns it. */
    PerceptronModel model;
    private static final long serialVersionUID = 1;
    /** Logging channel used by the methods of this class. */
    private static final Redwood.RedwoodChannels log = Redwood.channels(ShiftReduceParser.class);
    /** Flags appended by defaultCoreNLPFlags() when the configured beam size is greater than 1. */
    private static final String[] BEAM_FLAGS = {"-beamSize", "4"};
    /** Flags applied first by buildTrainingOptions() and prepended by main() when training continues from a saved model. */
    private static final String[] BASIC_TRAINING_OPTIONS = {"-forceTags", "-debugOutputFrequency", "1", "-quietEvaluation"};
    /** Flags prepended by main() when a saved parser is loaded without training. */
    private static final String[] FORCE_TAGS = {"-forceTags"};

    /**
     * Adapter that lets a {@link MulticoreWrapper} retag trees: each submitted tree has its
     * preterminal labels overwritten via {@link ShiftReduceParser#redoTags(Tree, Tagger)} and is
     * then handed back unchanged in identity.
     */
    public static class RetagProcessor implements ThreadsafeProcessor<Tree, Tree> {
        /** Tagger applied to every tree passed to {@link #process(Tree)}. */
        Tagger tagger;

        /**
         * @param tagger tagger used to recompute the tags of each processed tree
         */
        public RetagProcessor(Tagger tagger) {
            this.tagger = tagger;
        }

        /**
         * Retags {@code tree} in place.
         *
         * @param tree tree whose preterminal labels are rewritten
         * @return the same {@code tree} instance
         */
        @Override
        public Tree process(Tree tree) {
            ShiftReduceParser.redoTags(tree, this.tagger);
            return tree;
        }

        /**
         * Returns this same instance rather than a copy.
         * NOTE(review): sharing one instance across workers assumes the tagger tolerates
         * concurrent use - confirm against the Tagger implementation in use.
         *
         * @return {@code this}
         */
        @Override
        public ThreadsafeProcessor<Tree, Tree> newInstance() {
            // The method must carry the interface's name; the decompiler-assigned name
            // "newInstance2" left the interface method unimplemented.
            return this;
        }
    }

    /**
     * Creates a parser with the given options and no model (the model argument is passed as null).
     *
     * @param shiftReduceOptions options stored in {@code op}
     */
    public ShiftReduceParser(ShiftReduceOptions shiftReduceOptions) {
        this(shiftReduceOptions, null);
    }

    /**
     * Creates a parser from explicit options and a model.
     *
     * @param shiftReduceOptions options stored in {@code op}
     * @param perceptronModel model stored in {@code model}; may be null
     */
    public ShiftReduceParser(ShiftReduceOptions shiftReduceOptions, PerceptronModel perceptronModel) {
        op = shiftReduceOptions;
        model = perceptronModel;
    }

    /**
     * @return the options object this parser was constructed with
     */
    @Override
    public Options getOp() {
        return op;
    }

    /**
     * @return the {@code tlpParams} held by this parser's options
     */
    @Override
    public TreebankLangParserParams getTLPParams() {
        return op.tlpParams;
    }

    /**
     * @return the language pack reported by {@link #getTLPParams()}
     */
    @Override
    public TreebankLanguagePack treebankLanguagePack() {
        TreebankLangParserParams params = getTLPParams();
        return params.treebankLanguagePack();
    }

    /**
     * Returns the default flags of the language parameters, with {@code BEAM_FLAGS} appended
     * when the configured training beam size is greater than 1.
     *
     * @return the flag array described above
     */
    @Override
    public String[] defaultCoreNLPFlags() {
        if (op.trainOptions().beamSize <= 1) {
            return getTLPParams().defaultCoreNLPFlags();
        }
        return (String[]) ArrayUtils.concatenate(getTLPParams().defaultCoreNLPFlags(), BEAM_FLAGS);
    }

    /**
     * @return a read-only view of the model's {@code knownStates}
     */
    public Set<String> knownStates() {
        Set<String> states = model.knownStates;
        return Collections.unmodifiableSet(states);
    }

    /**
     * @return whatever the model's {@code tagSet()} returns
     */
    public Set<String> tagSet() {
        return model.tagSet();
    }

    /**
     * @return always {@code true}
     */
    @Override // edu.stanford.nlp.parser.common.ParserGrammar
    public boolean requiresTags() {
        return true;
    }

    /**
     * @return a new {@code ShiftReduceParserQuery} constructed with this parser
     */
    @Override // edu.stanford.nlp.parser.common.ParserGrammar
    public ParserQuery parserQuery() {
        return new ShiftReduceParserQuery(this);
    }

    /**
     * Parses raw text by delegating to the superclass, which is only permitted when the
     * {@code preTag} test option is set.
     *
     * @param str text to parse
     * @return the tree produced by the superclass implementation
     * @throws UnsupportedOperationException if {@code preTag} is not set
     */
    @Override
    public Tree parse(String str) {
        boolean taggerConfigured = getOp().testOptions.preTag;
        if (!taggerConfigured) {
            throw new UnsupportedOperationException("Can only parse raw text if a tagger is specified, as the ShiftReduceParser cannot produce its own tags");
        }
        return super.parse(str);
    }

    /**
     * Parses a sentence with a fresh query object.
     *
     * @param list the words to parse
     * @return the query's best parse when parsing succeeds, otherwise the result of
     *         {@code ParserUtils.xTree(list)}
     */
    @Override
    public Tree parse(List<? extends HasWord> list) {
        ShiftReduceParserQuery query = new ShiftReduceParserQuery(this);
        if (query.parse(list)) {
            return query.getBestParse();
        }
        return ParserUtils.xTree(list);
    }

    /**
     * Parses a sentence with a fresh query object.
     *
     * @param list the words to parse
     * @return the query's best parse, or {@code null} when parsing does not succeed
     */
    @Override
    public Tree parseTree(List<? extends HasWord> list) {
        ShiftReduceParserQuery query = new ShiftReduceParserQuery(this);
        return query.parse(list) ? query.getBestParse() : null;
    }

    /**
     * @return always an empty list
     */
    @Override // edu.stanford.nlp.parser.common.ParserGrammar
    public List<Eval> getExtraEvals() {
        return Collections.emptyList();
    }

    /**
     * Builds the evaluators requested by the test options: a binarized tree recorder, a
     * debinarized tree recorder and a transition type evaluator, in that order, each only when
     * its option is set.
     *
     * @return the requested evaluators, or {@code Collections.emptyList()} when none is requested
     */
    @Override
    public List<ParserQueryEval> getParserQueryEvals() {
        List<ParserQueryEval> evals = Generics.newArrayList();
        if (op.testOptions().recordBinarized != null) {
            evals.add(new TreeRecorder(TreeRecorder.Mode.BINARIZED, op.testOptions().recordBinarized));
        }
        if (op.testOptions().recordDebinarized != null) {
            evals.add(new TreeRecorder(TreeRecorder.Mode.DEBINARIZED, op.testOptions().recordDebinarized));
        }
        if (op.testOptions().recordTransitionTypes) {
            evals.add(new TransitionTypeEval());
        }
        // Nothing was requested exactly when nothing was added.
        if (evals.isEmpty()) {
            return Collections.emptyList();
        }
        return evals;
    }

    /**
     * Builds an initial parser state from the tagged yield of {@code tree}.
     *
     * @param tree tree supplying the tagged words
     * @return the state produced by {@link #initialStateFromTaggedSentence(List)}
     */
    public static State initialStateFromGoldTagTree(Tree tree) {
        List<? extends HasWord> taggedWords = tree.taggedYield();
        return initialStateFromTaggedSentence(taggedWords);
    }

    /**
     * Builds an initial parser state holding one preterminal subtree (tag node over word node)
     * per input word. Words and tags are indexed from 1, and both labels of every pair receive
     * head word / head tag annotations pointing at the pair itself.
     *
     * <p>An input {@code CoreLabel} is reused as the word label, so it is modified by this call;
     * any other input is copied into a new {@code CoreLabel}.
     *
     * @param list the tagged words of the sentence
     * @return a state built from the preterminal subtrees
     * @throws IllegalArgumentException if a non-{@code CoreLabel} word is not a {@code HasTag},
     *         or if any word's tag is null
     */
    public static State initialStateFromTaggedSentence(List<? extends HasWord> list) {
        List<Tree> preterminals = Generics.newArrayList();
        int position = 0;
        for (HasWord token : list) {
            position++;
            CoreLabel wordLabel;
            String tagValue;
            if (token instanceof CoreLabel) {
                wordLabel = (CoreLabel) token;
                tagValue = wordLabel.tag();
            } else {
                wordLabel = new CoreLabel();
                wordLabel.setValue(token.word());
                wordLabel.setWord(token.word());
                if (!(token instanceof HasTag)) {
                    throw new IllegalArgumentException("Expected tagged words");
                }
                tagValue = ((HasTag) token).tag();
                wordLabel.setTag(tagValue);
            }
            if (tagValue == null) {
                throw new IllegalArgumentException("Input word not tagged");
            }

            CoreLabel tagLabel = new CoreLabel();
            tagLabel.setValue(tagValue);
            wordLabel.setIndex(position);
            tagLabel.setIndex(position);

            LabeledScoredTreeNode wordNode = new LabeledScoredTreeNode(wordLabel);
            LabeledScoredTreeNode tagNode = new LabeledScoredTreeNode(tagLabel);
            tagNode.addChild(wordNode);

            wordLabel.set(TreeCoreAnnotations.HeadWordLabelAnnotation.class, wordLabel);
            wordLabel.set(TreeCoreAnnotations.HeadTagLabelAnnotation.class, tagLabel);
            tagLabel.set(TreeCoreAnnotations.HeadWordLabelAnnotation.class, wordLabel);
            tagLabel.set(TreeCoreAnnotations.HeadTagLabelAnnotation.class, tagLabel);

            preterminals.add(tagNode);
        }
        return new State(preterminals);
    }

    /**
     * Creates training options: {@code BASIC_TRAINING_OPTIONS} are applied first, then the
     * language parameters class is loaded by reflection when {@code str} is non-null, then the
     * caller's flags are applied. A random seed of 0 is replaced by {@code System.nanoTime()}.
     *
     * @param str class name of the language parameters to load, or null to keep the default
     * @param strArr additional option flags
     * @return the configured options
     */
    public static ShiftReduceOptions buildTrainingOptions(String str, String[] strArr) {
        ShiftReduceOptions options = new ShiftReduceOptions();
        options.setOptions(BASIC_TRAINING_OPTIONS);
        if (str != null) {
            Object loaded = ReflectionLoading.loadByReflection(str, new Object[0]);
            options.tlpParams = (TreebankLangParserParams) loaded;
        }
        options.setOptions(strArr);

        boolean seedUnset = options.trainOptions.randomSeed == 0;
        if (seedUnset) {
            options.trainOptions.randomSeed = System.nanoTime();
            log.info("Random seed not set by options, using " + options.trainOptions.randomSeed);
        }
        return options;
    }

    /**
     * Loads trees from a path into a memory treebank obtained from the language parameters,
     * logging before and after the load.
     *
     * @param str path to load from
     * @param fileFilter filter passed to the treebank's {@code loadPath}
     * @return the loaded treebank
     */
    public Treebank readTreebank(String str, FileFilter fileFilter) {
        log.info("Loading trees from " + str);
        MemoryTreebank loaded = op.tlpParams.memoryTreebank();
        loaded.loadPath(str, fileFilter);
        int treeCount = loaded.size();
        log.info("Read in " + treeCount + " trees from " + str);
        return loaded;
    }

    /**
     * Reads a treebank, drops trees rejected by {@link #filterTreebank(Treebank)} and binarizes
     * the remainder with this parser's options.
     *
     * @param str path to load from
     * @param fileFilter filter passed on to {@link #readTreebank(String, FileFilter)}
     * @return the binarized trees
     */
    public List<Tree> readBinarizedTreebank(String str, FileFilter fileFilter) {
        Treebank rawTrees = readTreebank(str, fileFilter);
        List<Tree> legalTrees = filterTreebank(rawTrees);
        List<Tree> binarizedTrees = binarizeTreebank(legalTrees, op);
        log.info("Converted trees to binarized format");
        return binarizedTrees;
    }

    /**
     * Checks recursively that no node other than a preterminal has a leaf as a direct child.
     *
     * @param tree tree to check; may be null
     * @return {@code false} for a null tree or when some non-preterminal node has a leaf child
     *         (or a null child), {@code true} otherwise - including when {@code tree} itself is
     *         a leaf or a preterminal
     */
    public static boolean checkLeafBranching(Tree tree) {
        if (tree == null) {
            return false;
        }
        boolean needsDescent = !tree.isLeaf() && !tree.isPreTerminal();
        if (needsDescent) {
            for (Tree child : tree.children()) {
                boolean childAcceptable = checkLeafBranching(child) && !child.isLeaf();
                if (!childAcceptable) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * @param tree tree whose root is inspected
     * @return {@code true} when the root has exactly one child
     */
    public static boolean checkRootTransition(Tree tree) {
        int rootChildren = tree.numChildren();
        return rootChildren == 1;
    }

    /**
     * Keeps the trees accepted by {@link #isLegalTree(Tree)} and logs an error for every tree
     * that is skipped.
     *
     * @param treebank trees to filter
     * @return the accepted trees, in iteration order
     */
    public List<Tree> filterTreebank(Treebank treebank) {
        List<Tree> legalTrees = new ArrayList<>();
        for (Tree candidate : treebank) {
            if (!isLegalTree(candidate)) {
                log.error("Found an illegal tree, skipping: " + candidate);
                continue;
            }
            legalTrees.add(candidate);
        }
        return legalTrees;
    }

    /**
     * @param tree tree to check
     * @return {@code true} when the tree passes both {@link #checkLeafBranching(Tree)} and
     *         {@link #checkRootTransition(Tree)}; the second check only runs if the first passes
     */
    public static boolean isLegalTree(Tree tree) {
        if (!checkLeafBranching(tree)) {
            return false;
        }
        return checkRootTransition(tree);
    }

    /**
     * Runs every tree through a binarizer followed by a basic-category transformer, then keeps
     * the results that report themselves as binarized. Kept trees are converted to core labels,
     * given head annotations and have their leaves indexed; rejected trees are logged as
     * warnings.
     *
     * @param iterable trees to binarize
     * @param options options supplying the head finder and language pack
     * @return the successfully binarized trees, in input order
     */
    public static List<Tree> binarizeTreebank(Iterable<Tree> iterable, Options options) {
        CompositeTreeTransformer pipeline = new CompositeTreeTransformer();
        pipeline.addTransformer(TreeBinarizer.simpleTreeBinarizer(options.tlpParams.headFinder(), options.tlpParams.treebankLanguagePack()));
        pipeline.addTransformer(new BasicCategoryTreeTransformer(options.langpack()));

        // First pass: transform everything before any head finding happens.
        List<Tree> transformed = new ArrayList<>();
        for (Tree original : iterable) {
            transformed.add(pipeline.transformTree(original));
        }

        // Second pass: annotate and collect the trees that really are binarized.
        BinaryHeadFinder headFinder = new BinaryHeadFinder(options.tlpParams.headFinder());
        List<Tree> binarized = new ArrayList<>();
        for (Tree candidate : transformed) {
            if (!candidate.isBinarized()) {
                log.warn("Found a tree which was not properly binarized.  So-called binarized tree is as follows:\n" + candidate.pennString());
                continue;
            }
            Trees.convertToCoreLabels(candidate);
            candidate.percolateHeadAnnotations(headFinder);
            candidate.indexLeaves(1, true);
            binarized.add(candidate);
        }
        return binarized;
    }

    /**
     * Collects the states found by {@link #findKnownStates(Tree, Set)} over all trees.
     *
     * @param list trees to scan
     * @return an unmodifiable set of the collected values
     */
    public static Set<String> findKnownStates(List<Tree> list) {
        Set<String> collected = Generics.newHashSet();
        for (Tree root : list) {
            findKnownStates(root, collected);
        }
        return Collections.unmodifiableSet(collected);
    }

    /**
     * Adds to {@code set} the value of every node in {@code tree} that is neither a leaf nor a
     * preterminal and is not temporary according to {@code ShiftReduceUtils.isTemporary}.
     * Temporary nodes are skipped themselves but their children are still visited.
     *
     * @param tree subtree to scan
     * @param set accumulator receiving the node values
     */
    public static void findKnownStates(Tree tree, Set<String> set) {
        boolean isPhrasal = !tree.isLeaf() && !tree.isPreTerminal();
        if (isPhrasal) {
            if (!ShiftReduceUtils.isTemporary(tree)) {
                set.add(tree.value());
            }
            for (Tree child : tree.children()) {
                findKnownStates(child, set);
            }
        }
    }

    /**
     * Replaces the value of each preterminal label of {@code tree} with the tag the tagger
     * assigns to the word at the same position.
     *
     * @param tree tree to retag in place
     * @param tagger tagger applied to the tree's words
     * @throws AssertionError if the tagger output and the preterminal yield differ in size
     */
    public static void redoTags(Tree tree, Tagger tagger) {
        List<TaggedWord> retagged = tagger.apply((List<? extends HasWord>) tree.yieldWords());
        List<Label> tagLabels = tree.preTerminalYield();
        if (tagLabels.size() != retagged.size()) {
            throw new AssertionError("Tags are not the same size");
        }
        int position = 0;
        while (position < tagLabels.size()) {
            String newTag = retagged.get(position).tag();
            tagLabels.get(position).setValue(newTag);
            position++;
        }
    }

    /**
     * Retags every tree in place. With {@code i == 1} the trees are processed one after another
     * on the calling thread; for any other value they are submitted to a {@link MulticoreWrapper}
     * created with {@code i}, and the call waits on {@code join()}.
     *
     * @param list trees to retag
     * @param tagger tagger used for retagging
     * @param i thread count handed to the wrapper
     */
    public static void redoTags(List<Tree> list, Tagger tagger, int i) {
        if (i != 1) {
            MulticoreWrapper<Tree, Tree> workers = new MulticoreWrapper<>(i, new RetagProcessor(tagger));
            for (Tree tree : list) {
                workers.put(tree);
            }
            workers.join();
            return;
        }
        for (Tree tree : list) {
            redoTags(tree, tagger);
        }
    }

    /**
     * @param list trees to scan
     * @return an unmodifiable set of the root values of the given trees
     */
    private static Set<String> findRootStates(List<Tree> list) {
        Set<String> rootValues = Generics.newHashSet();
        for (Tree root : list) {
            rootValues.add(root.value());
        }
        return Collections.unmodifiableSet(rootValues);
    }

    /**
     * Starts from a copy of {@code set} and removes every value that appears on any node below
     * the root of any tree.
     *
     * @param list trees to scan
     * @param set the values that occur at roots
     * @return an unmodifiable set of the values left after the removals
     */
    private static Set<String> findRootOnlyStates(List<Tree> list, Set<String> set) {
        Set<String> rootOnly = Generics.newHashSet(set);
        for (Tree root : list) {
            Tree[] belowRoot = root.children();
            for (Tree child : belowRoot) {
                findRootOnlyStatesHelper(child, set, rootOnly);
            }
        }
        return Collections.unmodifiableSet(rootOnly);
    }

    /**
     * Removes the value of {@code tree} and of all its descendants from {@code set2}.
     *
     * @param tree subtree to walk
     * @param set passed through to the recursive calls; not otherwise read here
     * @param set2 set the values are removed from
     */
    private static void findRootOnlyStatesHelper(Tree tree, Set<String> set, Set<String> set2) {
        String nodeValue = tree.value();
        set2.remove(nodeValue);
        Tree[] children = tree.children();
        for (Tree child : children) {
            findRootOnlyStatesHelper(child, set, set2);
        }
    }

    /**
     * Replays each example's transition list from the initial state of its binarized tree. At
     * the first transition that is not legal in the current state, a diagnostic is printed to
     * standard error and the rest of that example is skipped; nothing is thrown.
     *
     * @param list training examples to check
     */
    private void verifyTransitions(List<TrainingExample> list) {
        for (TrainingExample example : list) {
            State state = initialStateFromGoldTagTree(example.binarizedTree);
            List<Transition> transitions = example.transitions;
            for (int step = 0; step < transitions.size(); step++) {
                if (!transitions.get(step).isLegal(state, null)) {
                    System.err.println("Transition list for a gold tree is illegal!");
                    System.err.println("  " + example.binarizedTree);
                    System.err.println("  " + transitions);
                    System.err.println("  First illegal transition: " + step + ": " + transitions.get(step));
                    System.err.println("  State at this time: " + state);
                    break;
                }
                state = transitions.get(step).apply(state);
            }
        }
    }

    /**
     * Trains {@code model} from the given treebanks. The steps, in order: read and binarize all
     * training treebanks; retag them when the {@code preTag} test option is set; collect known,
     * root and root-only states; convert the trees to transition sequences, index the
     * transitions and verify them; finally hand everything to {@code PerceptronModel.trainModel}
     * and store its result in {@code model}.
     *
     * @param list path/filter pairs of the training treebanks
     * @param pair path/filter pair of the dev treebank, or null for none
     * @param str path passed through to {@code PerceptronModel.trainModel}
     */
    private void train(List<Pair<String, FileFilter>> list, Pair<String, FileFilter> pair, String str) {
        log.info("Training method: " + op.trainOptions().trainingMethod);
        log.debug("Headfinder used to binarize trees: " + getTLPParams().headFinder().getClass());

        List<Tree> binarizedTrees = Generics.newArrayList();
        for (Pair<String, FileFilter> treebankSpec : list) {
            binarizedTrees.addAll(readBinarizedTreebank(treebankSpec.first(), treebankSpec.second()));
        }

        // A non-positive thread setting means "use every available processor".
        int threadCount = op.trainOptions.trainingThreads;
        if (threadCount <= 0) {
            threadCount = Runtime.getRuntime().availableProcessors();
        }

        Tagger tagger = null;
        if (op.testOptions.preTag) {
            Timing retagTimer = new Timing();
            tagger = Tagger.loadModel(op.testOptions.taggerSerializedFile);
            redoTags(binarizedTrees, tagger, threadCount);
            retagTimer.done("Retagging");
        }

        Set<String> knownStates = findKnownStates(binarizedTrees);
        Set<String> rootStates = findRootStates(binarizedTrees);
        Set<String> rootOnlyStates = findRootOnlyStates(binarizedTrees, rootStates);
        log.info("Known states: " + knownStates);
        log.info("States which occur at the root: " + rootStates);
        log.info("States which only occur at the root: " + rootOnlyStates);

        Timing transitionTimer = new Timing();
        List<TrainingExample> trainingData = CreateTransitionSequence.createTransitionSequences(binarizedTrees, op.compoundUnaries, rootStates, rootOnlyStates);
        HashIndex<Transition> transitionIndex = new HashIndex<>();
        for (TrainingExample example : trainingData) {
            transitionIndex.addAll(example.transitions);
        }
        verifyTransitions(trainingData);
        transitionTimer.done("Converting trees into transition lists");
        log.info("Number of transitions: " + transitionIndex.size());

        Random random = new Random(op.trainOptions.randomSeed);
        Treebank devTreebank = null;
        if (pair != null) {
            devTreebank = readTreebank(pair.first(), pair.second());
        }
        model = PerceptronModel.trainModel(op, transitionIndex, knownStates, rootStates, rootOnlyStates, model, str, tagger, random, trainingData, devTreebank, threadCount);
    }

    /**
     * Applies the given flags to this parser's options.
     *
     * @param strArr option flags passed to the options' {@code setOptions}
     */
    @Override
    public void setOptionFlags(String... strArr) {
        op.setOptions(strArr);
    }

    /**
     * Reads a serialized parser and, when at least one flag is given, applies the flags to it.
     *
     * @param str location handed to
     *        {@code IOUtils.readObjectAnnouncingTimingFromURLOrClasspathOrFileSystem}
     * @param strArr option flags to apply after loading
     * @return the loaded parser
     */
    public static ShiftReduceParser loadModel(String str, String... strArr) {
        Object deserialized = IOUtils.readObjectAnnouncingTimingFromURLOrClasspathOrFileSystem(log, "Loading parser from serialized file", str);
        ShiftReduceParser parser = (ShiftReduceParser) deserialized;
        if (strArr.length != 0) {
            parser.setOptionFlags(strArr);
        }
        return parser;
    }

    /**
     * Serializes this parser to a file.
     *
     * @param str path of the file to write
     * @throws RuntimeIOException wrapping any {@code IOException} raised while writing
     */
    public void saveModel(String str) {
        try {
            IOUtils.writeObjectToFile(this, str);
        } catch (IOException ioFailure) {
            throw new RuntimeIOException(ioFailure);
        }
    }

    /**
     * Command line entry point for training and/or evaluating a parser.
     *
     * <p>Recognized flags (case-insensitive): {@code -trainTreebank} (repeatable),
     * {@code -testTreebank}, {@code -devTreebank}, {@code -serializedPath} / {@code -model},
     * {@code -tlpp} and {@code -continueTraining}. Every other argument is collected and passed
     * on as an option flag.
     *
     * <p>When a training treebank is given, a parser is trained (continuing from the
     * {@code -continueTraining} model if present) and saved to the serialized path; when no
     * serialized path is given the save is skipped with a warning. Without a training treebank
     * the parser is loaded from the serialized path. If a test treebank is given, the parser is
     * then evaluated on it.
     *
     * @param strArr command line arguments
     * @throws IllegalArgumentException if neither a training treebank nor a serialized path is
     *         given, or if a flag that takes a value is the last argument
     */
    public static void main(String[] strArr) {
        List<String> remainingArgs = Generics.newArrayList();
        List<Pair<String, FileFilter>> trainTreebanks = null;
        Pair<String, FileFilter> testTreebank = null;
        Pair<String, FileFilter> devTreebank = null;
        String serializedPath = null;
        String tlppClass = null;
        String continueTrainingPath = null;

        int argIndex = 0;
        while (argIndex < strArr.length) {
            String flag = strArr[argIndex];
            if (flag.equalsIgnoreCase("-trainTreebank")) {
                if (trainTreebanks == null) {
                    trainTreebanks = Generics.newArrayList();
                }
                trainTreebanks.add(ArgUtils.getTreebankDescription(strArr, argIndex, "-trainTreebank"));
                argIndex += ArgUtils.numSubArgs(strArr, argIndex) + 1;
            } else if (flag.equalsIgnoreCase("-testTreebank")) {
                testTreebank = ArgUtils.getTreebankDescription(strArr, argIndex, "-testTreebank");
                argIndex += ArgUtils.numSubArgs(strArr, argIndex) + 1;
            } else if (flag.equalsIgnoreCase("-devTreebank")) {
                devTreebank = ArgUtils.getTreebankDescription(strArr, argIndex, "-devTreebank");
                argIndex += ArgUtils.numSubArgs(strArr, argIndex) + 1;
            } else if (flag.equalsIgnoreCase("-serializedPath") || flag.equalsIgnoreCase("-model")) {
                serializedPath = optionValue(strArr, argIndex);
                argIndex += 2;
            } else if (flag.equalsIgnoreCase("-tlpp")) {
                tlppClass = optionValue(strArr, argIndex);
                argIndex += 2;
            } else if (flag.equalsIgnoreCase("-continueTraining")) {
                continueTrainingPath = optionValue(strArr, argIndex);
                argIndex += 2;
            } else {
                remainingArgs.add(flag);
                argIndex++;
            }
        }
        String[] optionFlags = remainingArgs.toArray(new String[remainingArgs.size()]);

        if (trainTreebanks == null && serializedPath == null) {
            throw new IllegalArgumentException("Must specify a treebank to train from with -trainTreebank or a parser to load with -serializedPath");
        }

        ShiftReduceParser parser = null;
        if (trainTreebanks != null) {
            log.info("Training ShiftReduceParser");
            log.info("Initial arguments:");
            log.info("   " + StringUtils.join(strArr));
            if (serializedPath == null) {
                // Say so before the training run starts rather than failing at the save step.
                log.warn("No -serializedPath given; the trained model will not be saved");
            }
            if (continueTrainingPath != null) {
                parser = loadModel(continueTrainingPath, (String[]) ArrayUtils.concatenate(BASIC_TRAINING_OPTIONS, optionFlags));
            } else {
                parser = new ShiftReduceParser(buildTrainingOptions(tlppClass, optionFlags));
            }
            Timing trainingTimer = new Timing();
            parser.train(trainTreebanks, devTreebank, serializedPath);
            trainingTimer.done("Overall training process");
            if (serializedPath != null) {
                parser.saveModel(serializedPath);
            }
        }

        if (serializedPath != null && parser == null) {
            parser = loadModel(serializedPath, (String[]) ArrayUtils.concatenate(FORCE_TAGS, optionFlags));
        }

        // parser is non-null here: either it was trained above, or a serialized path was
        // required by the argument check and the model was loaded from it.
        if (testTreebank != null) {
            log.info("Loading test trees from " + testTreebank.first());
            MemoryTreebank testTrees = parser.op.tlpParams.memoryTreebank();
            testTrees.loadPath(testTreebank.first(), testTreebank.second());
            log.info("Loaded " + testTrees.size() + " trees");
            new EvaluateTreebank(parser.op, null, parser).testOnTreebank(testTrees);
        }
    }

    /** Returns the argument following the flag at {@code flagIndex}, or fails naming the flag if there is none. */
    private static String optionValue(String[] args, int flagIndex) {
        if (flagIndex + 1 >= args.length) {
            throw new IllegalArgumentException("Missing value for option " + args[flagIndex]);
        }
        return args[flagIndex + 1];
    }
}
