package edu.stanford.nlp.parser.shiftreduce;

import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.io.RuntimeIOException;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.HasTag;
import edu.stanford.nlp.ling.HasWord;
import edu.stanford.nlp.ling.Label;
import edu.stanford.nlp.ling.TaggedWord;
import edu.stanford.nlp.parser.common.ArgUtils;
import edu.stanford.nlp.parser.common.ParserConstraint;
import edu.stanford.nlp.parser.common.ParserGrammar;
import edu.stanford.nlp.parser.common.ParserQuery;
import edu.stanford.nlp.parser.common.ParserUtils;
import edu.stanford.nlp.parser.lexparser.BinaryHeadFinder;
import edu.stanford.nlp.parser.lexparser.EvaluateTreebank;
import edu.stanford.nlp.parser.lexparser.Options;
import edu.stanford.nlp.parser.lexparser.TreeBinarizer;
import edu.stanford.nlp.parser.lexparser.TreebankLangParserParams;
import edu.stanford.nlp.parser.metrics.Eval;
import edu.stanford.nlp.parser.metrics.ParserQueryEval;
import edu.stanford.nlp.parser.shiftreduce.BinaryTransition;
import edu.stanford.nlp.parser.shiftreduce.ShiftReduceTrainOptions;
import edu.stanford.nlp.parser.shiftreduce.TreeRecorder;
import edu.stanford.nlp.stats.IntCounter;
import edu.stanford.nlp.tagger.common.Tagger;
import edu.stanford.nlp.tagger.maxent.TaggerConfig;
import edu.stanford.nlp.trees.BasicCategoryTreeTransformer;
import edu.stanford.nlp.trees.CompositeTreeTransformer;
import edu.stanford.nlp.trees.LabeledScoredTreeNode;
import edu.stanford.nlp.trees.MemoryTreebank;
import edu.stanford.nlp.trees.Tree;
import edu.stanford.nlp.trees.TreeCoreAnnotations;
import edu.stanford.nlp.trees.Treebank;
import edu.stanford.nlp.trees.TreebankLanguagePack;
import edu.stanford.nlp.trees.Trees;
import edu.stanford.nlp.util.ArrayUtils;
import edu.stanford.nlp.util.CollectionUtils;
import edu.stanford.nlp.util.Function;
import edu.stanford.nlp.util.Generics;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.ReflectionLoading;
import edu.stanford.nlp.util.ScoredComparator;
import edu.stanford.nlp.util.ScoredObject;
import edu.stanford.nlp.util.StringUtils;
import edu.stanford.nlp.util.Timing;
import edu.stanford.nlp.util.Triple;
import edu.stanford.nlp.util.concurrent.MulticoreWrapper;
import edu.stanford.nlp.util.concurrent.ThreadsafeProcessor;
import java.io.FileFilter;
import java.io.IOException;
import java.io.Serializable;
import java.text.DecimalFormat;
import java.text.NumberFormat;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Random;
import java.util.Set;

/* loaded from: input_file:edu/stanford/nlp/parser/shiftreduce/ShiftReduceParser.class */
public class ShiftReduceParser extends ParserGrammar implements Serializable {
    // Maps each transition to an integer id; scoring arrays are sized and indexed by it.
    Index<Transition> transitionIndex = new HashIndex();
    // Weight object for each feature string produced by the feature factory.
    Map<String, Weight> featureWeights = Generics.newHashMap();
    // Parser options; also the source of the language-specific parameters (tlpParams).
    ShiftReduceOptions op;
    // Turns a parser State into the list of feature strings that get scored.
    FeatureFactory featureFactory;
    // Labels consulted by findEmergencyTransition when looking for one that satisfies a constraint.
    Set<String> knownStates;
    // Labels findEmergencyTransition accepts for a lone remaining stack item once all tokens are consumed.
    Set<String> rootStates;
    // Copied together with the other state sets; not read elsewhere in the visible part of this file.
    Set<String> rootOnlyStates;
    // Extra flags appended by defaultCoreNLPFlags() when the training beam size is above 1.
    private static final String[] BEAM_FLAGS = {"-beamSize", "4"};
    // Two-decimal format used when printing model scores.
    private static final NumberFormat NF = new DecimalFormat("0.00");
    // Four-digit zero-padded format; presumably used for numbered file names - not used in the visible part of this file.
    private static final NumberFormat FILENAME = new DecimalFormat("0000");
    // Option array holding only "-forceTags"; its users are not visible in this part of the file.
    static final String[] FORCE_TAGS = {"-forceTags"};
    // Serialization version for this Serializable class.
    private static final long serialVersionUID = 1;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:edu/stanford/nlp/parser/shiftreduce/ShiftReduceParser$RetagProcessor.class */
    /**
     * Multicore worker that replaces the part-of-speech tags of a tree with the
     * ones produced by a fixed {@link Tagger}. Instances are handed to a
     * {@code MulticoreWrapper} by {@code redoTags(List, Tagger, int)}.
     */
    public static class RetagProcessor implements ThreadsafeProcessor<Tree, Tree> {
        /** Tagger applied to every tree passed to {@link #process(Tree)}. */
        Tagger tagger;

        /**
         * @param tagger the tagger used for all trees processed by this instance
         */
        public RetagProcessor(Tagger tagger) {
            this.tagger = tagger;
        }

        /**
         * Retags {@code tree} in place.
         *
         * @param tree the tree whose preterminal labels are overwritten
         * @return the same tree instance that was passed in
         */
        @Override
        public Tree process(Tree tree) {
            final Tagger retagger = tagger;
            ShiftReduceParser.redoTags(tree, retagger);
            return tree;
        }

        /**
         * Returns this very processor instead of a fresh copy, so every worker
         * shares the same tagger reference.
         *
         * NOTE(review): the decompiler renamed this method from {@code newInstance}
         * (bridge-method merge); confirm the name against the
         * {@code ThreadsafeProcessor} interface before recompiling.
         */
        @Override
        public ThreadsafeProcessor<Tree, Tree> newInstance2() {
            return this;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:edu/stanford/nlp/parser/shiftreduce/ShiftReduceParser$TrainTreeProcessor.class */
    /**
     * Multicore worker that trains on one tree at a time. The input is the
     * index of the tree to train on; the output is whatever
     * {@code trainTree} returns for that index.
     */
    public class TrainTreeProcessor implements ThreadsafeProcessor<Integer, Pair<Integer, Integer>> {
        /** Binarized training trees, addressed by the index given to {@link #process(Integer)}. */
        List<Tree> binarizedTrees;
        /** Transition sequences, one list per training tree. */
        List<List<Transition>> transitionLists;
        /** Sink to which {@code trainTree} appends the updates it generates. */
        List<Update> updates;
        /** Oracle forwarded unchanged to {@code trainTree}. */
        Oracle oracle;

        /**
         * @param list   the binarized training trees
         * @param list2  the transition lists belonging to those trees
         * @param list3  the list that collects generated updates
         * @param oracle the oracle passed on to {@code trainTree}
         */
        public TrainTreeProcessor(List<Tree> list, List<List<Transition>> list2, List<Update> list3, Oracle oracle) {
            this.oracle = oracle;
            this.updates = list3;
            this.transitionLists = list2;
            this.binarizedTrees = list;
        }

        /**
         * Trains the enclosing parser on the tree with the given index.
         *
         * @param num index of the tree to train on
         * @return the pair produced by {@code trainTree} for that tree
         */
        @Override
        public Pair<Integer, Integer> process(Integer num) {
            final int treeIndex = num.intValue();
            return ShiftReduceParser.this.trainTree(treeIndex, binarizedTrees, transitionLists, updates, oracle);
        }

        /**
         * Returns this very processor instead of a fresh copy.
         *
         * NOTE(review): the decompiler renamed this method from {@code newInstance};
         * confirm the name against the {@code ThreadsafeProcessor} interface.
         */
        @Override
        public ThreadsafeProcessor<Integer, Pair<Integer, Integer>> newInstance2() {
            return this;
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:edu/stanford/nlp/parser/shiftreduce/ShiftReduceParser$Update.class */
    /**
     * Immutable record of one pending weight adjustment: the features that
     * fired, the transition ids involved, and the size of the step.
     */
    public static class Update {
        /** Feature strings extracted from the state the update refers to. */
        final List<String> features;
        /** Id of the gold transition; the training code passes -1 when there is none to record. */
        final int goldTransition;
        /** Id of the predicted transition; the training code passes -1 when there is none to record. */
        final int predictedTransition;
        /** Step size associated with this update. */
        final float delta;

        /**
         * @param list the features of the state
         * @param i    id of the gold transition
         * @param i2   id of the predicted transition
         * @param f    step size
         */
        Update(List<String> list, int i, int i2, float f) {
            this.delta = f;
            this.predictedTransition = i2;
            this.goldTransition = i;
            this.features = list;
        }
    }

    /**
     * Creates a parser for the given options and instantiates its feature
     * factory by reflection.
     *
     * The {@code featureFactoryClass} option is split on {@code ';'}. A single
     * entry is loaded as a class name without constructor arguments. With any
     * other number of entries, each one is either a plain class name or
     * {@code ClassName(arg)}: the text between the first {@code '('} and the
     * last character is passed as a single string argument. The results are
     * wrapped in a {@link CombinationFeatureFactory}.
     *
     * @param shiftReduceOptions the options to keep and to read the factory spec from
     */
    public ShiftReduceParser(ShiftReduceOptions shiftReduceOptions) {
        this.op = shiftReduceOptions;
        final String[] specs = shiftReduceOptions.featureFactoryClass.split(";");
        if (specs.length != 1) {
            final FeatureFactory[] factories = new FeatureFactory[specs.length];
            for (int slot = 0; slot < specs.length; slot++) {
                final String spec = specs[slot];
                final int paren = spec.indexOf("(");
                if (paren < 0) {
                    factories[slot] = (FeatureFactory) ReflectionLoading.loadByReflection(spec, new Object[0]);
                } else {
                    final String className = spec.substring(0, paren);
                    // Everything after '(' except the final character of the entry.
                    final String argument = spec.substring(paren + 1, spec.length() - 1);
                    factories[slot] = (FeatureFactory) ReflectionLoading.loadByReflection(className, argument);
                }
            }
            this.featureFactory = new CombinationFeatureFactory(factories);
        } else {
            this.featureFactory = (FeatureFactory) ReflectionLoading.loadByReflection(specs[0], new Object[0]);
        }
    }

    /**
     * Creates a parser that reuses an already constructed feature factory
     * instead of loading one by reflection.
     *
     * @param shiftReduceOptions the options to keep
     * @param featureFactory     the feature factory to use as-is
     */
    private ShiftReduceParser(ShiftReduceOptions shiftReduceOptions, FeatureFactory featureFactory) {
        this.featureFactory = featureFactory;
        this.op = shiftReduceOptions;
    }

    /**
     * {@code ParserGrammar} accessor for the options object.
     *
     * @return the options this parser was constructed with
     */
    @Override
    public Options getOp() {
        final ShiftReduceOptions options = op;
        return options;
    }

    /**
     * {@code ParserGrammar} accessor for the language-specific parameters.
     *
     * @return the {@code tlpParams} field of this parser's options
     */
    @Override
    public TreebankLangParserParams getTLPParams() {
        final TreebankLangParserParams params = op.tlpParams;
        return params;
    }

    /**
     * {@code ParserGrammar} accessor for the language pack.
     *
     * @return the language pack reported by {@link #getTLPParams()}
     */
    @Override
    public TreebankLanguagePack treebankLanguagePack() {
        final TreebankLangParserParams params = getTLPParams();
        return params.treebankLanguagePack();
    }

    /**
     * Returns the language-specific default flags, followed by
     * {@code BEAM_FLAGS} when the training beam size is greater than 1.
     *
     * @return the flag array described above
     */
    @Override
    public String[] defaultCoreNLPFlags() {
        final boolean trainedWithBeam = op.trainOptions().beamSize > 1;
        final String[] languageFlags = getTLPParams().defaultCoreNLPFlags();
        if (!trainedWithBeam) {
            return languageFlags;
        }
        return (String[]) ArrayUtils.concatenate(languageFlags, BEAM_FLAGS);
    }

    /**
     * {@code ParserGrammar} capability query.
     *
     * @return always {@code true}
     */
    @Override
    public boolean requiresTags() {
        final boolean needsTaggedInput = true;
        return needsTaggedInput;
    }

    /**
     * Builds a new parser that shares this parser's options and feature
     * factory but receives its own copy of the transition index, state sets
     * and feature weights via {@link #copyWeights(ShiftReduceParser)}.
     *
     * @return the newly created copy
     */
    private ShiftReduceParser deepCopy() {
        final ShiftReduceParser duplicate = new ShiftReduceParser(op, featureFactory);
        duplicate.copyWeights(this);
        return duplicate;
    }

    /**
     * Replaces this parser's transition index, state sets and feature weights
     * with copies of the ones held by {@code shiftReduceParser}.
     *
     * The state sets are copied into new unmodifiable sets and every
     * {@link Weight} is duplicated through its copy constructor, so later
     * changes to the source's weights do not affect this parser.
     *
     * Passing this parser itself is a no-op. Without that guard the method
     * would clear its own index and weight map before reading from them and
     * silently wipe the model.
     *
     * @param shiftReduceParser the parser to copy from
     */
    public void copyWeights(ShiftReduceParser shiftReduceParser) {
        if (shiftReduceParser == this) {
            // Source and destination are the same object: clearing first would destroy the data.
            return;
        }
        this.transitionIndex.clear();
        for (Transition transition : shiftReduceParser.transitionIndex) {
            this.transitionIndex.add(transition);
        }
        this.knownStates = Collections.unmodifiableSet(Generics.newHashSet(shiftReduceParser.knownStates));
        this.rootStates = Collections.unmodifiableSet(Generics.newHashSet(shiftReduceParser.rootStates));
        this.rootOnlyStates = Collections.unmodifiableSet(Generics.newHashSet(shiftReduceParser.rootOnlyStates));
        this.featureWeights.clear();
        for (Map.Entry<String, Weight> entry : shiftReduceParser.featureWeights.entrySet()) {
            this.featureWeights.put(entry.getKey(), new Weight(entry.getValue()));
        }
    }

    /**
     * Logs the scores of the given models to {@code System.err} and returns
     * the average of the underlying parsers, computed by
     * {@link #averageModels(Collection)}.
     *
     * @param collection scored parsers to average
     * @return the averaged parser
     * @throws IllegalArgumentException if {@code collection} is empty
     */
    public static ShiftReduceParser averageScoredModels(Collection<ScoredObject<ShiftReduceParser>> collection) {
        if (collection.isEmpty()) {
            throw new IllegalArgumentException("Cannot average empty models");
        }
        System.err.print("Averaging " + collection.size() + " models with scores");
        for (ScoredObject<ShiftReduceParser> scoredModel : collection) {
            System.err.print(" " + NF.format(scoredModel.score()));
        }
        System.err.println();
        // Strips the score wrapper so that only the parsers are averaged.
        final Function<ScoredObject<ShiftReduceParser>, ShiftReduceParser> unwrap = new Function<ScoredObject<ShiftReduceParser>, ShiftReduceParser>() {
            @Override
            public ShiftReduceParser apply(ScoredObject<ShiftReduceParser> scoredObject) {
                return scoredObject.object();
            }
        };
        return averageModels(CollectionUtils.transformAsList(collection, unwrap));
    }

    /**
     * Builds a new parser whose feature weights are the average of the given
     * parsers' weights.
     *
     * The result takes its options, feature factory, transition index and
     * state sets from the first parser in iteration order. Each parser's
     * weight for a feature is added with a factor of {@code 1 / N}, where
     * {@code N} is the number of parsers; a parser that lacks a feature simply
     * contributes nothing for it.
     *
     * @param collection the parsers to average
     * @return the averaged parser
     * @throws IllegalArgumentException if {@code collection} is empty, or if
     *         the parsers do not all have an equal transition index
     */
    public static ShiftReduceParser averageModels(Collection<ShiftReduceParser> collection) {
        if (collection.isEmpty()) {
            // Same contract as averageScoredModels instead of a bare NoSuchElementException.
            throw new IllegalArgumentException("Cannot average empty models");
        }
        final ShiftReduceParser first = collection.iterator().next();
        final ShiftReduceParser averaged = new ShiftReduceParser(first.op, first.featureFactory);
        for (Transition transition : first.transitionIndex) {
            averaged.transitionIndex.add(transition);
        }
        averaged.knownStates = Collections.unmodifiableSet(Generics.newHashSet(first.knownStates));
        averaged.rootStates = Collections.unmodifiableSet(Generics.newHashSet(first.rootStates));
        averaged.rootOnlyStates = Collections.unmodifiableSet(Generics.newHashSet(first.rootOnlyStates));

        // Weights are stored per transition id, so averaging only makes sense over identical indices.
        for (ShiftReduceParser model : collection) {
            if (!model.transitionIndex.equals(averaged.transitionIndex)) {
                throw new IllegalArgumentException("Can only average models with the same transition index");
            }
        }

        // Models are visited in collection order, so each feature accumulates its
        // contributions in the same order as a feature-by-feature pass would.
        final float scale = 1.0f / collection.size();
        for (ShiftReduceParser model : collection) {
            for (Map.Entry<String, Weight> entry : model.featureWeights.entrySet()) {
                Weight sum = averaged.featureWeights.get(entry.getKey());
                if (sum == null) {
                    sum = new Weight();
                    averaged.featureWeights.put(entry.getKey(), sum);
                }
                sum.addScaled(entry.getValue(), scale);
            }
        }
        return averaged;
    }

    /**
     * {@code ParserGrammar} factory method.
     *
     * @return a new {@link ShiftReduceParserQuery} bound to this parser
     */
    @Override
    public ParserQuery parserQuery() {
        final ShiftReduceParserQuery query = new ShiftReduceParserQuery(this);
        return query;
    }

    /* JADX WARN: Can't rename method to resolve collision */
    /**
     * Parses a sentence with a fresh {@link ShiftReduceParserQuery}.
     *
     * @param list the words of the sentence
     * @return the best parse when parsing succeeds, otherwise the fallback
     *         tree built by {@code ParserUtils.xTree}
     */
    @Override
    public Tree apply(List<? extends HasWord> list) {
        final ShiftReduceParserQuery query = new ShiftReduceParserQuery(this);
        if (!query.parse(list)) {
            return ParserUtils.xTree(list);
        }
        return query.getBestParse();
    }

    /**
     * Calls {@code condense()} on every stored weight and drops the features
     * whose weight reports a size of zero afterwards.
     */
    public void condenseFeatures() {
        for (Iterator<Weight> weights = featureWeights.values().iterator(); weights.hasNext();) {
            final Weight current = weights.next();
            current.condense();
            if (current.size() != 0) {
                continue;
            }
            // Removing through the values view removes the whole feature entry.
            weights.remove();
        }
    }

    /**
     * Removes every feature whose name is not contained in {@code set}.
     *
     * @param set the feature names to keep
     */
    public void filterFeatures(Set<String> set) {
        for (Iterator<String> names = featureWeights.keySet().iterator(); names.hasNext();) {
            final String name = names.next();
            if (set.contains(name)) {
                continue;
            }
            names.remove();
        }
    }

    /**
     * Prints four summary lines to {@code System.err}: the number of features,
     * the sum of {@code Weight.size()} over all features, the total length of
     * all feature names, and the number of transitions.
     */
    public void outputStats() {
        System.err.println("Number of known features: " + featureWeights.size());

        int storedWeights = 0;
        for (Weight weight : featureWeights.values()) {
            storedWeights += weight.size();
        }
        System.err.println("Number of non-zero weights: " + storedWeights);

        int nameCharacters = 0;
        for (String featureName : featureWeights.keySet()) {
            nameCharacters += featureName.length();
        }
        System.err.println("Total word length: " + nameCharacters);

        System.err.println("Number of transitions: " + transitionIndex.size());
    }

    /**
     * {@code ParserGrammar} hook for additional evaluations.
     *
     * @return always an empty list
     */
    @Override
    public List<Eval> getExtraEvals() {
        final List<Eval> none = Collections.emptyList();
        return none;
    }

    /**
     * Creates one {@link TreeRecorder} for each of the test options
     * {@code recordBinarized} and {@code recordDebinarized} that is set.
     *
     * @return the recorders, binarized first; an empty list when neither
     *         option is set
     */
    @Override
    public List<ParserQueryEval> getParserQueryEvals() {
        final boolean wantBinarized = op.testOptions().recordBinarized != null;
        final boolean wantDebinarized = op.testOptions().recordDebinarized != null;
        if (!wantBinarized && !wantDebinarized) {
            return Collections.emptyList();
        }
        final List<ParserQueryEval> recorders = Generics.newArrayList();
        if (wantBinarized) {
            recorders.add(new TreeRecorder(TreeRecorder.Mode.BINARIZED, op.testOptions().recordBinarized));
        }
        if (wantDebinarized) {
            recorders.add(new TreeRecorder(TreeRecorder.Mode.DEBINARIZED, op.testOptions().recordDebinarized));
        }
        return recorders;
    }

    /**
     * Chooses a fallback transition for {@code state} by checking a fixed
     * sequence of cases, and returns {@code null} when none of them applies.
     *
     * Cases, in the order they are tried:
     * 1. a constraint whose span matches the top of the stack but whose label
     *    pattern the top does not satisfy: a unary transition to a known
     *    state that matches the pattern;
     * 2. a temporary node on top, with nothing else on the stack or another
     *    temporary node beneath it: a unary transition to the top's label
     *    without its first character;
     * 3. a single stack item, all tokens consumed, and a label that is not a
     *    root state: a unary transition to a root state;
     * 4. a temporary node on top, or directly beneath the top: a binary
     *    transition to that node's label without its first character.
     *
     * Unary transitions are built as {@code CompoundUnaryTransition} when
     * {@code op.compoundUnaries} is set and as {@code UnaryTransition} otherwise.
     *
     * NOTE(review): {@code state.stack.pop()} is called several times without
     * storing its result, so this code appears to rely on {@code pop()}
     * returning the rest of the stack without modifying {@code state.stack} -
     * confirm against the stack implementation.
     *
     * @param state the parser state that needs a fallback
     * @param list  constraints to honor; may be {@code null}
     * @return the chosen transition, or {@code null} if no case applies
     */
    public Transition findEmergencyTransition(State state, List<ParserConstraint> list) {
        // An empty stack offers nothing to repair.
        if (state.stack.size() == 0) {
            return null;
        }
        if (list != null) {
            Tree peek = state.stack.peek();
            for (ParserConstraint parserConstraint : list) {
                // Indices reported for the top node equal the constraint's start and end - 1,
                // but the top node does not satisfy the constraint.
                if (ShiftReduceUtils.leftIndex(peek) == parserConstraint.start && ShiftReduceUtils.rightIndex(peek) == parserConstraint.end - 1 && !ShiftReduceUtils.constraintMatchesTreeTop(peek, parserConstraint)) {
                    // Use the first known state (in set iteration order) accepted by the constraint's pattern.
                    for (String str : this.knownStates) {
                        if (parserConstraint.state.matcher(str).matches()) {
                            return this.op.compoundUnaries ? new CompoundUnaryTransition(Collections.singletonList(str), false) : new UnaryTransition(str, false);
                        }
                    }
                }
            }
        }
        // Temporary node on top, alone or sitting on another temporary node.
        if (ShiftReduceUtils.isTemporary(state.stack.peek()) && (state.stack.size() == 1 || ShiftReduceUtils.isTemporary(state.stack.pop().peek()))) {
            // substring(1) drops the label's first character - presumably the temporary-node marker; confirm.
            return this.op.compoundUnaries ? new CompoundUnaryTransition(Collections.singletonList(state.stack.peek().value().substring(1)), false) : new UnaryTransition(state.stack.peek().value().substring(1), false);
        }
        // One item left and every token consumed, but its label is not a root state.
        if (state.stack.size() == 1 && state.tokenPosition >= state.sentence.size() && !this.rootStates.contains(state.stack.peek().value())) {
            // Takes whichever root state the set's iterator yields first.
            String next = this.rootStates.iterator().next();
            return this.op.compoundUnaries ? new CompoundUnaryTransition(Collections.singletonList(next), false) : new UnaryTransition(next, false);
        }
        // The binary cases below look at the node beneath the top, so they need at least two items.
        if (state.stack.size() == 1) {
            return null;
        }
        if (ShiftReduceUtils.isTemporary(state.stack.peek())) {
            return new BinaryTransition(state.stack.peek().value().substring(1), BinaryTransition.Side.RIGHT);
        }
        if (ShiftReduceUtils.isTemporary(state.stack.pop().peek())) {
            return new BinaryTransition(state.stack.pop().peek().value().substring(1), BinaryTransition.Side.LEFT);
        }
        return null;
    }

    /**
     * Convenience form of
     * {@link #findHighestScoringTransitions(State, List, boolean, int, List)}
     * that asks for a single transition and passes no constraints.
     *
     * @param state the state the transition would be applied to
     * @param list  the features extracted from {@code state}
     * @param z     whether illegal transitions are excluded
     * @return the scored transition id, or {@code null} if no candidate was kept
     */
    public ScoredObject<Integer> findHighestScoringTransition(State state, List<String> list, boolean z) {
        final Collection<ScoredObject<Integer>> best = findHighestScoringTransitions(state, list, z, 1, null);
        return best.isEmpty() ? null : best.iterator().next();
    }

    /**
     * Scores every transition for the given features and keeps at most
     * {@code i} of them.
     *
     * Each known feature's weight adds its contribution to a score array with
     * one slot per transition id; features without a stored weight are
     * ignored. Candidates are collected in a queue ordered by
     * {@code ScoredComparator.ASCENDING_COMPARATOR}, and whenever the queue
     * grows beyond {@code i} its head is discarded (presumably the
     * lowest-scoring entry).
     *
     * @param state the state transitions are checked against
     * @param list  the features extracted from {@code state}
     * @param z     when {@code true}, only transitions legal in {@code state}
     *              under {@code list2} are considered
     * @param i     the maximum number of transitions to keep
     * @param list2 constraints forwarded to the legality check; may be {@code null}
     * @return the kept transition ids with their scores, in queue order
     */
    public Collection<ScoredObject<Integer>> findHighestScoringTransitions(State state, List<String> list, boolean z, int i, List<ParserConstraint> list2) {
        final float[] scores = new float[transitionIndex.size()];
        for (String feature : list) {
            final Weight weight = featureWeights.get(feature);
            if (weight == null) {
                continue;
            }
            weight.score(scores);
        }

        final PriorityQueue<ScoredObject<Integer>> kept = new PriorityQueue<ScoredObject<Integer>>(i + 1, ScoredComparator.ASCENDING_COMPARATOR);
        for (int id = 0; id < scores.length; id++) {
            if (z && !transitionIndex.get(id).isLegal(state, list2)) {
                continue;
            }
            kept.add(new ScoredObject<Integer>(Integer.valueOf(id), scores[id]));
            if (kept.size() > i) {
                kept.poll();
            }
        }
        return kept;
    }

    /**
     * Builds the initial parser state from the tagged yield of a tree.
     *
     * @param tree the tree supplying words and tags
     * @return the state produced by {@link #initialStateFromTaggedSentence(List)}
     */
    public static State initialStateFromGoldTagTree(Tree tree) {
        final List<? extends HasWord> taggedWords = tree.taggedYield();
        return initialStateFromTaggedSentence(taggedWords);
    }

    /**
     * Builds the initial parser state for a tagged sentence: one two-node
     * tree (tag node over word node) per input word, in sentence order.
     *
     * A word that already is a {@link CoreLabel} is used directly as the word
     * label, so its index and head annotations are overwritten on the caller's
     * object. Any other word is copied into a new {@code CoreLabel}. Indices
     * are 1-based, and both nodes receive head-word and head-tag annotations
     * pointing at the word node and the tag node respectively.
     *
     * @param list the tagged words of the sentence
     * @return the state constructed from the per-word trees
     * @throws IllegalArgumentException if a non-{@code CoreLabel} word carries
     *         no tag interface, or if any word's tag is {@code null}
     */
    public static State initialStateFromTaggedSentence(List<? extends HasWord> list) {
        final List<Tree> preterminals = Generics.newArrayList();
        for (int position = 0; position < list.size(); position++) {
            final HasWord token = list.get(position);
            final CoreLabel wordLabel;
            final String tag;
            if (!(token instanceof CoreLabel)) {
                wordLabel = new CoreLabel();
                wordLabel.setValue(token.word());
                wordLabel.setWord(token.word());
                if (!(token instanceof HasTag)) {
                    throw new IllegalArgumentException("Expected tagged words");
                }
                tag = ((HasTag) token).tag();
                wordLabel.setTag(tag);
            } else {
                wordLabel = (CoreLabel) token;
                tag = wordLabel.tag();
            }
            if (tag == null) {
                throw new IllegalArgumentException("Input word not tagged");
            }

            final CoreLabel tagLabel = new CoreLabel();
            tagLabel.setValue(tag);
            final int oneBasedIndex = position + 1;
            wordLabel.setIndex(oneBasedIndex);
            tagLabel.setIndex(oneBasedIndex);

            final LabeledScoredTreeNode wordNode = new LabeledScoredTreeNode(wordLabel);
            final LabeledScoredTreeNode tagNode = new LabeledScoredTreeNode(tagLabel);
            tagNode.addChild(wordNode);

            // Both labels point at the same head word and head tag nodes.
            wordLabel.set(TreeCoreAnnotations.HeadWordAnnotation.class, wordNode);
            wordLabel.set(TreeCoreAnnotations.HeadTagAnnotation.class, tagNode);
            tagLabel.set(TreeCoreAnnotations.HeadWordAnnotation.class, wordNode);
            tagLabel.set(TreeCoreAnnotations.HeadTagAnnotation.class, tagNode);

            preterminals.add(tagNode);
        }
        return new State(preterminals);
    }

    /**
     * Creates the options used for training.
     *
     * Built-in defaults are applied first, then the optional language
     * parameter class is loaded by reflection, and finally the caller's
     * arguments are applied. If the random seed is still 0 afterwards, a
     * random one is chosen and reported on {@code System.err}.
     *
     * @param str    class name of the language parameters to load, or
     *               {@code null} to keep the default
     * @param strArr option arguments applied on top of the defaults
     * @return the configured options
     */
    public static ShiftReduceOptions buildTrainingOptions(String str, String[] strArr) {
        final ShiftReduceOptions options = new ShiftReduceOptions();
        // NOTE(review): the value given for -debugOutputFrequency is the constant TaggerConfig.NTHREADS;
        // this looks like a decompiler constant substitution for a plain literal - confirm.
        options.setOptions("-forceTags", "-debugOutputFrequency", TaggerConfig.NTHREADS, "-quietEvaluation");
        if (str != null) {
            options.tlpParams = (TreebankLangParserParams) ReflectionLoading.loadByReflection(str, new Object[0]);
        }
        options.setOptions(strArr);
        if (options.trainOptions.randomSeed != 0) {
            return options;
        }
        options.trainOptions.randomSeed = new Random().nextLong();
        System.err.println("Random seed not set by options, using " + options.trainOptions.randomSeed);
        return options;
    }

    /**
     * Loads the trees found under a path into a memory treebank supplied by
     * the language parameters, logging progress to {@code System.err}.
     *
     * @param str        the path to load trees from
     * @param fileFilter the filter handed to the treebank's path loader
     * @return the loaded treebank
     */
    public Treebank readTreebank(String str, FileFilter fileFilter) {
        System.err.println("Loading trees from " + str);
        final MemoryTreebank loaded = op.tlpParams.memoryTreebank();
        loaded.loadPath(str, fileFilter);
        final int treeCount = loaded.size();
        System.err.println("Read in " + treeCount + " trees from " + str);
        return loaded;
    }

    /**
     * Reads a treebank with {@link #readTreebank(String, FileFilter)} and
     * binarizes it with this parser's options.
     *
     * @param str        the path to load trees from
     * @param fileFilter the filter applied while loading
     * @return the binarized trees
     */
    public List<Tree> readBinarizedTreebank(String str, FileFilter fileFilter) {
        final Treebank rawTrees = readTreebank(str, fileFilter);
        final List<Tree> binarized = binarizeTreebank(rawTrees, op);
        System.err.println("Converted trees to binarized format");
        return binarized;
    }

    /**
     * Converts a treebank into the list of trees used for training.
     *
     * Every tree is run through a {@link TreeBinarizer} followed by a
     * {@link BasicCategoryTreeTransformer}. The transformed trees then get
     * {@code CoreLabel} labels, head annotations percolated with a
     * {@link BinaryHeadFinder}, and leaf indices assigned via
     * {@code indexLeaves(1, true)}.
     *
     * @param treebank the trees to convert
     * @param options  supplies the head finder and language pack
     * @return the converted trees, in treebank iteration order
     */
    public static List<Tree> binarizeTreebank(Treebank treebank, Options options) {
        final TreeBinarizer binarizer = new TreeBinarizer(options.tlpParams.headFinder(), options.tlpParams.treebankLanguagePack(), false, false, 0, false, false, 0.0d, false, true, true);
        final BasicCategoryTreeTransformer basicCategories = new BasicCategoryTreeTransformer(options.langpack());
        final CompositeTreeTransformer pipeline = new CompositeTreeTransformer();
        pipeline.addTransformer(binarizer);
        pipeline.addTransformer(basicCategories);
        final Treebank transformed = treebank.transform(pipeline);

        final BinaryHeadFinder headFinder = new BinaryHeadFinder(options.tlpParams.headFinder());
        final List<Tree> result = Generics.newArrayList();
        for (Tree tree : transformed) {
            Trees.convertToCoreLabels(tree);
            tree.percolateHeadAnnotations(headFinder);
            tree.indexLeaves(1, true);
            result.add(tree);
        }
        return result;
    }

    /**
     * Collects the labels of all non-temporary internal nodes of the given trees.
     *
     * @param list the trees to scan
     * @return an unmodifiable set of the labels found
     */
    public static Set<String> findKnownStates(List<Tree> list) {
        final Set<String> states = Generics.newHashSet();
        for (Tree tree : list) {
            findKnownStates(tree, states);
        }
        return Collections.unmodifiableSet(states);
    }

    /**
     * Recursively adds the labels of a tree's internal nodes to {@code set}.
     * Leaves and preterminals are ignored, and temporary nodes are descended
     * into without recording their own label.
     *
     * @param tree the subtree to scan
     * @param set  the set that receives the labels
     */
    public static void findKnownStates(Tree tree, Set<String> set) {
        final boolean internalNode = !tree.isLeaf() && !tree.isPreTerminal();
        if (internalNode) {
            if (!ShiftReduceUtils.isTemporary(tree)) {
                set.add(tree.value());
            }
            final Tree[] children = tree.children();
            for (int child = 0; child < children.length; child++) {
                findKnownStates(children[child], set);
            }
        }
    }

    /**
     * Runs the tagger over the words of a tree and writes each resulting tag
     * into the corresponding preterminal label, modifying the tree in place.
     *
     * @param tree   the tree to retag
     * @param tagger the tagger producing the new tags
     * @throws AssertionError if the tagger returns a different number of
     *         words than the tree has preterminals
     */
    public static void redoTags(Tree tree, Tagger tagger) {
        final List<TaggedWord> retagged = tagger.apply((List<? extends HasWord>) tree.yieldWords());
        final List<Label> preterminals = tree.preTerminalYield();
        final int count = preterminals.size();
        if (count != retagged.size()) {
            throw new AssertionError("Tags are not the same size");
        }
        for (int position = 0; position < preterminals.size(); position++) {
            final String newTag = retagged.get(position).tag();
            preterminals.get(position).setValue(newTag);
        }
    }

    /**
     * Retags all trees in place. With {@code i == 1} the trees are processed
     * one after another on the calling thread; for any other value they are
     * fed to a {@link MulticoreWrapper} created with {@code i} and a
     * {@link RetagProcessor}, and the call returns after the wrapper has been
     * joined.
     *
     * @param list   the trees to retag
     * @param tagger the tagger producing the new tags
     * @param i      the thread count setting
     */
    public static void redoTags(List<Tree> list, Tagger tagger, int i) {
        if (i != 1) {
            final MulticoreWrapper<Tree, Tree> workers = new MulticoreWrapper<Tree, Tree>(i, new RetagProcessor(tagger));
            for (Tree tree : list) {
                workers.put(tree);
            }
            // Results are not collected: the processor changes the trees themselves.
            workers.join();
            return;
        }
        for (Tree tree : list) {
            redoTags(tree, tagger);
        }
    }

    /**
     * Tells whether some state in the collection has the same transitions as
     * {@code state}, as judged by {@code areTransitionsEqual}.
     *
     * @param collection the candidate states
     * @param state      the state to look for
     * @return {@code true} if a match exists, {@code false} otherwise
     */
    private static boolean findStateOnAgenda(Collection<State> collection, State state) {
        boolean found = false;
        for (State candidate : collection) {
            if (candidate.areTransitionsEqual(state)) {
                found = true;
                break;
            }
        }
        return found;
    }

    /* JADX INFO: Access modifiers changed from: private */
    public Pair<Integer, Integer> trainTree(int i, List<Tree> list, List<List<Transition>> list2, List<Update> list3, Oracle oracle) {
        int i2 = 0;
        int i3 = 0;
        Tree tree = list.get(i);
        ReorderingOracle reorderingOracle = (this.op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_ORACLE || this.op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM) ? new ReorderingOracle(this.op) : null;
        if (this.op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.ORACLE) {
            State initialStateFromGoldTagTree = initialStateFromGoldTagTree(tree);
            while (!initialStateFromGoldTagTree.isFinished()) {
                List<String> featurize = this.featureFactory.featurize(initialStateFromGoldTagTree);
                ScoredObject<Integer> findHighestScoringTransition = findHighestScoringTransition(initialStateFromGoldTagTree, featurize, true);
                if (findHighestScoringTransition == null) {
                    throw new AssertionError("Did not find a legal transition");
                }
                int intValue = findHighestScoringTransition.object().intValue();
                Transition transition = this.transitionIndex.get(intValue);
                OracleTransition goldTransition = oracle.goldTransition(i, initialStateFromGoldTagTree);
                if (goldTransition.isCorrect(transition)) {
                    i2++;
                    if (goldTransition.transition != null && !goldTransition.transition.equals(transition)) {
                        int indexOf = this.transitionIndex.indexOf(goldTransition.transition);
                        if (indexOf >= 0) {
                            list3.add(new Update(featurize, indexOf, -1, 1.0f));
                        }
                    }
                } else {
                    i3++;
                    list3.add(new Update(featurize, goldTransition.transition != null ? this.transitionIndex.indexOf(goldTransition.transition) : -1, intValue, 1.0f));
                }
                initialStateFromGoldTagTree = transition.apply(initialStateFromGoldTagTree);
            }
        } else if (this.op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.BEAM || this.op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM) {
            if (this.op.trainOptions().beamSize <= 0) {
                throw new IllegalArgumentException("Illegal beam size " + this.op.trainOptions().beamSize);
            }
            LinkedList newLinkedList = Generics.newLinkedList(list2.get(i));
            PriorityQueue priorityQueue = new PriorityQueue(this.op.trainOptions().beamSize + 1, ScoredComparator.ASCENDING_COMPARATOR);
            State initialStateFromGoldTagTree2 = initialStateFromGoldTagTree(tree);
            priorityQueue.add(initialStateFromGoldTagTree2);
            while (newLinkedList.size() > 0) {
                Transition transition2 = newLinkedList.get(0);
                Transition transition3 = null;
                double d = 0.0d;
                PriorityQueue priorityQueue2 = new PriorityQueue(this.op.trainOptions().beamSize + 1, ScoredComparator.ASCENDING_COMPARATOR);
                State state = null;
                State state2 = null;
                Iterator it = priorityQueue.iterator();
                while (it.hasNext()) {
                    State state3 = (State) it.next();
                    boolean z = this.op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM && initialStateFromGoldTagTree2.areTransitionsEqual(state3);
                    for (ScoredObject<Integer> scoredObject : findHighestScoringTransitions(state3, this.featureFactory.featurize(state3), true, this.op.trainOptions().beamSize, null)) {
                        State apply = this.transitionIndex.get(scoredObject.object().intValue()).apply(state3, scoredObject.score());
                        priorityQueue2.add(apply);
                        if (priorityQueue2.size() > this.op.trainOptions().beamSize) {
                            priorityQueue2.poll();
                        }
                        if (state == null || state.score() < apply.score()) {
                            state = apply;
                            state2 = state3;
                        }
                        if (z && (transition3 == null || scoredObject.score() > d)) {
                            transition3 = this.transitionIndex.get(scoredObject.object().intValue());
                            d = scoredObject.score();
                        }
                    }
                }
                if (this.op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM && transition3 == null) {
                    break;
                }
                State apply2 = transition2.apply(initialStateFromGoldTagTree2, 0.0d);
                if (!apply2.areTransitionsEqual(state)) {
                    i3++;
                    List<String> featurize2 = this.featureFactory.featurize(initialStateFromGoldTagTree2);
                    list3.add(new Update(this.featureFactory.featurize(state2), -1, this.transitionIndex.indexOf(state.transitions.peek()), 1.0f));
                    list3.add(new Update(featurize2, this.transitionIndex.indexOf(transition2), -1, 1.0f));
                    if (this.op.trainOptions().trainingMethod != ShiftReduceTrainOptions.TrainingMethod.BEAM) {
                        if (this.op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_BEAM) {
                            if (!findStateOnAgenda(priorityQueue2, apply2)) {
                                if (!reorderingOracle.reorder(initialStateFromGoldTagTree2, transition3, newLinkedList)) {
                                    break;
                                }
                                apply2 = transition3.apply(initialStateFromGoldTagTree2);
                                if (!findStateOnAgenda(priorityQueue2, apply2)) {
                                    break;
                                }
                            } else {
                                newLinkedList.remove(0);
                            }
                        } else {
                            continue;
                        }
                    } else {
                        if (!findStateOnAgenda(priorityQueue2, apply2)) {
                            break;
                        }
                        newLinkedList.remove(0);
                    }
                } else {
                    i2++;
                    newLinkedList.remove(0);
                }
                initialStateFromGoldTagTree2 = apply2;
                priorityQueue = priorityQueue2;
            }
        } else if (this.op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.REORDER_ORACLE || this.op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.EARLY_TERMINATION || this.op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.GOLD) {
            State initialStateFromGoldTagTree3 = initialStateFromGoldTagTree(tree);
            LinkedList newLinkedList2 = Generics.newLinkedList(list2.get(i));
            boolean z2 = true;
            while (newLinkedList2.size() > 0 && z2) {
                Transition transition4 = newLinkedList2.get(0);
                int indexOf2 = this.transitionIndex.indexOf(transition4);
                List<String> featurize3 = this.featureFactory.featurize(initialStateFromGoldTagTree3);
                int intValue2 = findHighestScoringTransition(initialStateFromGoldTagTree3, featurize3, false).object().intValue();
                Transition transition5 = this.transitionIndex.get(intValue2);
                if (indexOf2 == intValue2) {
                    newLinkedList2.remove(0);
                    initialStateFromGoldTagTree3 = transition4.apply(initialStateFromGoldTagTree3);
                    i2++;
                } else {
                    i3++;
                    list3.add(new Update(featurize3, indexOf2, intValue2, 1.0f));
                    switch (this.op.trainOptions().trainingMethod) {
                        case EARLY_TERMINATION:
                            z2 = false;
                            break;
                        case GOLD:
                            newLinkedList2.remove(0);
                            initialStateFromGoldTagTree3 = transition4.apply(initialStateFromGoldTagTree3);
                            break;
                        case REORDER_ORACLE:
                            z2 = reorderingOracle.reorder(initialStateFromGoldTagTree3, transition5, newLinkedList2);
                            if (z2) {
                                initialStateFromGoldTagTree3 = transition5.apply(initialStateFromGoldTagTree3);
                                break;
                            } else {
                                break;
                            }
                        default:
                            throw new IllegalArgumentException("Unexpected method " + this.op.trainOptions().trainingMethod);
                    }
                }
            }
        }
        return Pair.makePair(Integer.valueOf(i2), Integer.valueOf(i3));
    }

    /**
     * Trains on one batch of trees and reports how many transitions were predicted
     * correctly and incorrectly.
     *
     * <p>Trees are processed sequentially on the calling thread when training is
     * configured for a single thread or when no {@code multicoreWrapper} was supplied;
     * otherwise every index is handed to the wrapper and the per-tree results are
     * collected from it.
     *
     * @param list indices (into {@code list2} / {@code list3}) of the trees in this batch
     * @param list2 all training trees
     * @param list3 gold transition sequences, passed through to {@code trainTree}
     * @param list4 list handed to {@code trainTree} to receive weight updates; the same
     *     instance is returned as the first element of the result
     * @param oracle oracle passed through to {@code trainTree}; may be null
     * @param multicoreWrapper worker pool for multithreaded training, or null
     * @return triple of ({@code list4}, transitions correct, transitions wrong)
     */
    private Triple<List<Update>, Integer, Integer> trainBatch(List<Integer> list, List<Tree> list2, List<List<Transition>> list3, List<Update> list4, Oracle oracle, MulticoreWrapper<Integer, Pair<Integer, Integer>> multicoreWrapper) {
        int numCorrect = 0;
        int numWrong = 0;
        // The caller only builds a wrapper when the *resolved* thread count is not 1, so a
        // non-positive trainingThreads setting on a single-processor machine arrives here
        // with a null wrapper. Fall back to sequential training instead of dereferencing it.
        if (this.op.trainOptions.trainingThreads == 1 || multicoreWrapper == null) {
            for (Integer treeIndex : list) {
                Pair<Integer, Integer> counts = trainTree(treeIndex.intValue(), list2, list3, list4, oracle);
                numCorrect += counts.first.intValue();
                numWrong += counts.second.intValue();
            }
        } else {
            for (Integer treeIndex : list) {
                multicoreWrapper.put(treeIndex);
            }
            // NOTE(review): join(false) presumably waits for the queued work while keeping
            // the worker threads alive for the next batch -- confirm against MulticoreWrapper.
            multicoreWrapper.join(false);
            while (multicoreWrapper.peek()) {
                Pair<Integer, Integer> counts = multicoreWrapper.poll();
                numCorrect += counts.first.intValue();
                numWrong += counts.second.intValue();
            }
        }
        return new Triple<>(list4, Integer.valueOf(numCorrect), Integer.valueOf(numWrong));
    }

    /**
     * Collects the label value of the root node of every tree.
     *
     * @param list trees whose root labels should be gathered
     * @return an unmodifiable set of the distinct root label values
     */
    private static Set<String> findRootStates(List<Tree> list) {
        Set<String> rootLabels = Generics.newHashSet();
        for (Tree root : list) {
            rootLabels.add(root.value());
        }
        return Collections.unmodifiableSet(rootLabels);
    }

    /**
     * Narrows a set of root labels down to those that never label a non-root node.
     *
     * <p>Starts from a copy of {@code set} and removes every label value found on any
     * node below the root of any tree.
     *
     * @param list trees to scan
     * @param set label values known to occur at the root
     * @return an unmodifiable set of the values in {@code set} that were not seen below a root
     */
    private static Set<String> findRootOnlyStates(List<Tree> list, Set<String> set) {
        Set<String> rootOnly = Generics.newHashSet(set);
        for (Tree root : list) {
            // Start at the children so that the root's own label is not removed.
            for (Tree child : root.children()) {
                findRootOnlyStatesHelper(child, set, rootOnly);
            }
        }
        return Collections.unmodifiableSet(rootOnly);
    }

    /**
     * Removes the label value of {@code tree} and of all of its descendants from
     * {@code set2}.
     *
     * @param tree subtree to walk
     * @param set root label values; only forwarded to the recursive calls, never read here
     * @param set2 candidate set that is mutated in place
     */
    private static void findRootOnlyStatesHelper(Tree tree, Set<String> set, Set<String> set2) {
        String label = tree.value();
        set2.remove(label);
        for (Tree descendant : tree.children()) {
            findRootOnlyStatesHelper(descendant, set, set2);
        }
    }

    /**
     * Trains this parser's feature weights from one or more treebanks.
     *
     * <p>The method reads the training trees with {@code readBinarizedTreebank}, optionally
     * retags them, records the known / root / root-only states, converts the trees into
     * gold transition sequences (registering every transition in {@code transitionIndex}),
     * and then makes up to {@code trainingIterations} passes over the shuffled trees in
     * batches of {@code batchSize}, applying every {@code Update} produced by
     * {@code trainBatch} to {@code featureWeights}.
     *
     * <p>When a dev treebank is given, the model is evaluated after each iteration, the
     * best label F1 is tracked, and training stops early once
     * {@code stalledIterationLimit} iterations pass without improvement. After the loop,
     * retained models are optionally averaged, rare features are optionally filtered, and
     * {@code condenseFeatures()} is called.
     *
     * @param list training treebanks as (path, file filter) pairs
     * @param pair dev treebank as a (path, file filter) pair, or null for no dev evaluation
     * @param str path of the final serialized model; only used here to derive file names
     *     for intermediate models, and may be null
     * @param set if non-null, only features contained in this set receive weight updates
     */
    private void train(List<Pair<String, FileFilter>> list, Pair<String, FileFilter> pair, String str, Set<String> set) {
        System.err.println("Training method: " + this.op.trainOptions().trainingMethod);

        List<Tree> binarizedTrees = Generics.newArrayList();
        for (Pair<String, FileFilter> treebank : list) {
            binarizedTrees.addAll(readBinarizedTreebank(treebank.first(), treebank.second()));
        }

        // A non-positive thread setting means "use every available processor".
        int configuredThreads = this.op.trainOptions.trainingThreads;
        int nThreads = configuredThreads <= 0 ? Runtime.getRuntime().availableProcessors() : configuredThreads;

        Tagger tagger = null;
        if (this.op.testOptions.preTag) {
            Timing retagTimer = new Timing();
            tagger = Tagger.loadModel(this.op.testOptions.taggerSerializedFile);
            redoTags(binarizedTrees, tagger, nThreads);
            retagTimer.done("Retagging");
        }

        this.knownStates = findKnownStates(binarizedTrees);
        this.rootStates = findRootStates(binarizedTrees);
        this.rootOnlyStates = findRootOnlyStates(binarizedTrees, this.rootStates);
        System.err.println("Known states: " + this.knownStates);
        System.err.println("States which occur at the root: " + this.rootStates);
        // Report the root-only set here; this line previously printed rootStates again.
        System.err.println("States which only occur at the root: " + this.rootOnlyStates);

        Timing transitionTimer = new Timing();
        List<List<Transition>> transitionLists = CreateTransitionSequence.createTransitionSequences(binarizedTrees, this.op.compoundUnaries, this.rootStates, this.rootOnlyStates);
        for (List<Transition> transitions : transitionLists) {
            this.transitionIndex.addAll(transitions);
        }
        transitionTimer.done("Converting trees into transition lists");
        System.err.println("Number of transitions: " + this.transitionIndex.size());

        Random random = new Random(this.op.trainOptions.randomSeed);

        Treebank devTreebank = null;
        if (pair != null) {
            devTreebank = readTreebank(pair.first(), pair.second());
        }

        double bestScore = 0.0d;
        int bestIteration = 0;

        // Holds the best-scoring model snapshots; the queue is ordered ascending so that
        // poll() discards the lowest-scoring snapshot once the limit is exceeded.
        PriorityQueue<ScoredObject<ShiftReduceParser>> bestModels = null;
        if (this.op.trainOptions().averagedModels > 0) {
            bestModels = new PriorityQueue<ScoredObject<ShiftReduceParser>>(this.op.trainOptions().averagedModels + 1, ScoredComparator.ASCENDING_COMPARATOR);
        }

        // Tree indices are shuffled each iteration instead of the trees themselves.
        List<Integer> indices = Generics.newArrayList();
        for (int treeIndex = 0; treeIndex < binarizedTrees.size(); treeIndex++) {
            indices.add(Integer.valueOf(treeIndex));
        }

        Oracle oracle = null;
        if (this.op.trainOptions().trainingMethod == ShiftReduceTrainOptions.TrainingMethod.ORACLE) {
            oracle = new Oracle(binarizedTrees, this.op.compoundUnaries, this.rootStates);
        }

        List<Update> updates = Generics.newArrayList();
        MulticoreWrapper<Integer, Pair<Integer, Integer>> wrapper = null;
        if (nThreads != 1) {
            // Worker threads append to the shared update list concurrently.
            updates = Collections.synchronizedList(updates);
            wrapper = new MulticoreWrapper<>(this.op.trainOptions.trainingThreads, new TrainTreeProcessor(binarizedTrees, transitionLists, updates, oracle));
        }

        IntCounter<String> featureFrequencies = null;
        if (this.op.trainOptions().featureFrequencyCutoff > 1) {
            featureFrequencies = new IntCounter<String>();
        }

        for (int iteration = 1; iteration <= this.op.trainOptions.trainingIterations; iteration++) {
            Timing iterationTimer = new Timing();
            int numCorrect = 0;
            int numWrong = 0;
            Collections.shuffle(indices, random);

            for (int start = 0; start < indices.size(); start += this.op.trainOptions.batchSize) {
                int end = Math.min(start + this.op.trainOptions.batchSize, indices.size());
                Triple<List<Update>, Integer, Integer> result = trainBatch(indices.subList(start, end), binarizedTrees, transitionLists, updates, oracle, wrapper);
                numCorrect += result.second.intValue();
                numWrong += result.third.intValue();

                for (Update update : result.first) {
                    for (String feature : update.features) {
                        if (set != null && !set.contains(feature)) {
                            continue;
                        }
                        Weight weight = this.featureWeights.get(feature);
                        if (weight == null) {
                            weight = new Weight();
                            this.featureWeights.put(feature, weight);
                        }
                        weight.updateWeight(update.goldTransition, update.delta);
                        weight.updateWeight(update.predictedTransition, -update.delta);
                        if (featureFrequencies != null) {
                            // An update with a negative gold or predicted transition counts once, otherwise twice.
                            featureFrequencies.incrementCount(feature, (update.goldTransition < 0 || update.predictedTransition < 0) ? 1 : 2);
                        }
                    }
                }
                updates.clear();
            }

            iterationTimer.done("Iteration " + iteration);
            System.err.println("While training, got " + numCorrect + " transitions correct and " + numWrong + " transitions wrong");
            outputStats();

            double labelF1 = 0.0d;
            if (devTreebank != null) {
                EvaluateTreebank evaluator = new EvaluateTreebank(this.op, null, this, tagger);
                evaluator.testOnTreebank(devTreebank);
                labelF1 = evaluator.getLBScore();
                System.err.println("Label F1 after " + iteration + " iterations: " + labelF1);
                if (labelF1 <= bestScore) {
                    System.err.println("Failed to improve for " + (iteration - bestIteration) + " iteration(s) on previous best score of " + bestScore);
                    if (this.op.trainOptions.stalledIterationLimit > 0 && iteration - bestIteration >= this.op.trainOptions.stalledIterationLimit) {
                        System.err.println("Failed to improve for too long, stopping training");
                        break;
                    }
                } else {
                    System.err.println("New best dev score (previous best " + bestScore + ")");
                    bestScore = labelF1;
                    bestIteration = iteration;
                }
                System.err.println();

                // NOTE(review): snapshots are only collected when a dev treebank exists, so
                // with averagedModels > 0 and no dev set the queue stays empty -- confirm
                // that averageScoredModels copes with an empty collection.
                if (bestModels != null) {
                    bestModels.add(new ScoredObject<ShiftReduceParser>(deepCopy(), labelF1));
                    if (bestModels.size() > this.op.trainOptions().averagedModels) {
                        bestModels.poll();
                    }
                }
            }

            if (this.op.trainOptions().saveIntermediateModels && str != null && this.op.trainOptions.debugOutputFrequency > 0) {
                // Drops the last 7 characters of the path (presumably a ".ser.gz" suffix)
                // before appending the iteration number and score.
                saveModel(str.substring(0, str.length() - 7) + "-" + FILENAME.format(iteration) + "-" + NF.format(labelF1) + ".ser.gz");
            }
        }

        if (wrapper != null) {
            wrapper.join();
        }

        if (bestModels != null) {
            if (!this.op.trainOptions().cvAveragedModels || devTreebank == null) {
                copyWeights(averageScoredModels(bestModels));
            } else {
                // Drain the ascending queue and reverse it so the best model comes first,
                // then pick the prefix length that scores best on the dev treebank.
                List<ScoredObject<ShiftReduceParser>> models = Generics.newArrayList();
                while (bestModels.size() > 0) {
                    models.add(bestModels.poll());
                }
                Collections.reverse(models);
                double bestF1 = 0.0d;
                int bestSize = 0;
                for (int size = 1; size <= models.size(); size++) {
                    System.err.println("Testing with " + size + " models averaged together");
                    ShiftReduceParser averaged = averageScoredModels(models.subList(0, size));
                    EvaluateTreebank evaluator = new EvaluateTreebank(averaged.op, null, averaged, tagger);
                    evaluator.testOnTreebank(devTreebank);
                    double labelF1 = evaluator.getLBScore();
                    System.err.println("Label F1 for " + size + " models: " + labelF1);
                    if (labelF1 > bestF1) {
                        bestF1 = labelF1;
                        bestSize = size;
                    }
                }
                copyWeights(averageScoredModels(models.subList(0, bestSize)));
            }
        }

        if (featureFrequencies != null) {
            filterFeatures(featureFrequencies.keysAbove(this.op.trainOptions().featureFrequencyCutoff));
        }
        condenseFeatures();
    }

    /**
     * Forwards the given option flags to this parser's options object.
     *
     * @param strArr option flags to apply
     */
    @Override // edu.stanford.nlp.parser.common.ParserGrammar
    public void setOptionFlags(String... strArr) {
        op.setOptions(strArr);
    }

    /**
     * Loads a serialized parser and applies any extra option flags to it.
     *
     * <p>NOTE(review): the model appears to be read via object deserialization, so only
     * trusted model files should be loaded -- confirm against {@code IOUtils}.
     *
     * @param str location of the serialized parser
     * @param strArr option flags applied to the loaded parser when at least one is given
     * @return the loaded parser
     * @throws RuntimeIOException if reading fails or a serialized class cannot be found
     */
    public static ShiftReduceParser loadModel(String str, String... strArr) {
        final ShiftReduceParser loaded;
        try {
            Timing loadTimer = new Timing();
            System.err.print("Loading parser from serialized file " + str + " ...");
            loaded = (ShiftReduceParser) IOUtils.readObjectFromURLOrClasspathOrFileSystem(str);
            loadTimer.done();
        } catch (IOException | ClassNotFoundException cause) {
            throw new RuntimeIOException(cause);
        }
        if (strArr.length != 0) {
            loaded.setOptionFlags(strArr);
        }
        return loaded;
    }

    /**
     * Writes this parser to the given file.
     *
     * @param str path of the file to write
     * @throws RuntimeIOException wrapping any {@link IOException} raised while writing
     */
    public void saveModel(String str) {
        try {
            IOUtils.writeObjectToFile(this, str);
        } catch (IOException ioe) {
            throw new RuntimeIOException(ioe);
        }
    }

    /**
     * Command-line entry point: trains a parser, loads an existing one, and/or evaluates
     * a parser on a test treebank.
     *
     * <p>Recognized flags (case-insensitive): {@code -trainTreebank} (repeatable),
     * {@code -testTreebank}, {@code -devTreebank}, {@code -serializedPath} / {@code -model},
     * {@code -tlpp} and {@code -continueTraining}. Every other argument is passed on as a
     * parser option.
     *
     * <p>When training without a model path, the trained model is not written to disk but
     * can still be evaluated with {@code -testTreebank}.
     *
     * @param strArr command-line arguments
     * @throws IllegalArgumentException if neither a training treebank nor a model path is
     *     given, or if a flag that needs a value is the last argument
     */
    public static void main(String[] strArr) {
        List<String> remainingArgs = Generics.newArrayList();
        List<Pair<String, FileFilter>> trainTreebankPaths = null;
        Pair<String, FileFilter> testTreebankPath = null;
        Pair<String, FileFilter> devTreebankPath = null;
        String str = null;
        String tlppClass = null;
        String continueTraining = null;

        int argIndex = 0;
        while (argIndex < strArr.length) {
            String flag = strArr[argIndex];
            if (flag.equalsIgnoreCase("-trainTreebank")) {
                if (trainTreebankPaths == null) {
                    trainTreebankPaths = Generics.newArrayList();
                }
                trainTreebankPaths.add(ArgUtils.getTreebankDescription(strArr, argIndex, "-trainTreebank"));
                argIndex = argIndex + ArgUtils.numSubArgs(strArr, argIndex) + 1;
            } else if (flag.equalsIgnoreCase("-testTreebank")) {
                testTreebankPath = ArgUtils.getTreebankDescription(strArr, argIndex, "-testTreebank");
                argIndex = argIndex + ArgUtils.numSubArgs(strArr, argIndex) + 1;
            } else if (flag.equalsIgnoreCase("-devTreebank")) {
                devTreebankPath = ArgUtils.getTreebankDescription(strArr, argIndex, "-devTreebank");
                argIndex = argIndex + ArgUtils.numSubArgs(strArr, argIndex) + 1;
            } else if (flag.equalsIgnoreCase("-serializedPath") || flag.equalsIgnoreCase("-model")) {
                str = requireFlagValue(strArr, argIndex);
                argIndex += 2;
            } else if (flag.equalsIgnoreCase("-tlpp")) {
                tlppClass = requireFlagValue(strArr, argIndex);
                argIndex += 2;
            } else if (flag.equalsIgnoreCase("-continueTraining")) {
                continueTraining = requireFlagValue(strArr, argIndex);
                argIndex += 2;
            } else {
                remainingArgs.add(flag);
                argIndex++;
            }
        }
        String[] parserArgs = remainingArgs.toArray(new String[remainingArgs.size()]);

        if (trainTreebankPaths == null && str == null) {
            throw new IllegalArgumentException("Must specify a treebank to train from with -trainTreebank or a parser to load with -serializedPath");
        }

        ShiftReduceParser parser = null;
        if (trainTreebankPaths != null) {
            System.err.println("Training ShiftReduceParser");
            System.err.println("Initial arguments:");
            System.err.println("   " + StringUtils.join(strArr));
            if (str == null) {
                System.err.println("No -serializedPath given; the trained model will not be saved");
            }
            if (continueTraining != null) {
                parser = loadModel(continueTraining, (String[]) ArrayUtils.concatenate(FORCE_TAGS, parserArgs));
            } else {
                parser = new ShiftReduceParser(buildTrainingOptions(tlppClass, parserArgs));
            }
            ShiftReduceOptions options = parser.op;
            if (!options.trainOptions().retrainAfterCutoff || options.trainOptions().featureFrequencyCutoff <= 0) {
                parser.train(trainTreebankPaths, devTreebankPath, str, null);
            } else {
                // Two-pass training: the first pass decides which features survive the
                // frequency cutoff, the second pass trains a fresh parser on only those.
                // The temp name drops the last 7 characters of the path (presumably a
                // ".ser.gz" suffix); without a model path the first-pass model is not saved.
                String tempPath = (str == null) ? null : str.substring(0, str.length() - 7) + "-temp.ser.gz";
                parser.train(trainTreebankPaths, devTreebankPath, tempPath, null);
                if (tempPath != null) {
                    parser.saveModel(tempPath);
                }
                Set<String> allowedFeatures = parser.featureWeights.keySet();
                parser = new ShiftReduceParser(options);
                parser.train(trainTreebankPaths, devTreebankPath, str, allowedFeatures);
            }
            // Saving is optional: a null path previously caused a failure after training finished.
            if (str != null) {
                parser.saveModel(str);
            }
        }

        if (str != null && parser == null) {
            parser = loadModel(str, (String[]) ArrayUtils.concatenate(FORCE_TAGS, parserArgs));
        }

        if (testTreebankPath != null) {
            System.err.println("Loading test trees from " + testTreebankPath.first());
            MemoryTreebank testTreebank = parser.op.tlpParams.memoryTreebank();
            testTreebank.loadPath(testTreebankPath.first(), testTreebankPath.second());
            System.err.println("Loaded " + testTreebank.size() + " trees");
            new EvaluateTreebank(parser.op, null, parser).testOnTreebank(testTreebank);
        }
    }

    /** Returns the argument following the flag at {@code flagIndex}, failing clearly if it is absent. */
    private static String requireFlagValue(String[] args, int flagIndex) {
        if (flagIndex + 1 >= args.length) {
            throw new IllegalArgumentException("Missing value for argument " + args[flagIndex]);
        }
        return args[flagIndex + 1];
    }
}
