/*
 * Decompiled with CFR 0.152.
 */
package kaist.cilab.parser.berkeleyadaptation;

import edu.berkeley.nlp.PCFGLA.Binarization;
import edu.berkeley.nlp.PCFGLA.Corpus;
import edu.berkeley.nlp.PCFGLA.CorpusStatistics;
import edu.berkeley.nlp.PCFGLA.FeaturizedLexicon;
import edu.berkeley.nlp.PCFGLA.Grammar;
import edu.berkeley.nlp.PCFGLA.GrammarMerger;
import edu.berkeley.nlp.PCFGLA.GrammarTrainer;
import edu.berkeley.nlp.PCFGLA.Lexicon;
import edu.berkeley.nlp.PCFGLA.OptionParser;
import edu.berkeley.nlp.PCFGLA.ParserData;
import edu.berkeley.nlp.PCFGLA.SimpleFeaturizer;
import edu.berkeley.nlp.PCFGLA.SimpleLexicon;
import edu.berkeley.nlp.PCFGLA.SophisticatedLexicon;
import edu.berkeley.nlp.PCFGLA.StateSetTreeList;
import edu.berkeley.nlp.PCFGLA.smoothing.NoSmoothing;
import edu.berkeley.nlp.PCFGLA.smoothing.SmoothAcrossParentBits;
import edu.berkeley.nlp.syntax.StateSet;
import edu.berkeley.nlp.syntax.Tree;
import edu.berkeley.nlp.util.Numberer;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

/**
 * Trains a split/merge grammar on an in-memory treebank using the Berkeley parser's
 * {@code GrammarTrainer}, {@code GrammarMerger} and lexicon classes, and saves the results with
 * {@code ParserData.Save}.
 *
 * <p>Training assigns the process-wide static fields {@code GrammarTrainer.HORIZONTAL_MARKOVIZATION},
 * {@code GrammarTrainer.VERTICAL_MARKOVIZATION} and {@code GrammarTrainer.RANDOM}, so two trainings
 * should not run concurrently in the same JVM.
 */
public class Trainer {
    /** Number of split/merge/smooth rounds used by {@link #train(List, String)}. */
    private static final int DEFAULT_NUM_SPLITS = 2;

    /** Every {@code VALIDATION_STRIDE}-th training tree (index 0, 20, 40, ...) is also used for validation. */
    private static final int VALIDATION_STRIDE = 20;

    /** Treebank location used by {@link #main(String[])} when no path argument is given. */
    private static final String DEFAULT_TREEBANK_PATH = "C:/Lab/External_Resources/sejong_to_treebank/convertedF/BaselineF/BaselineF";

    /** Output grammar name used by {@link #main(String[])} when no output argument is given. */
    private static final String DEFAULT_OUTPUT_NAME = "KorGrammar_BerkF_ORIG";

    /**
     * Trains a grammar with {@value #DEFAULT_NUM_SPLITS} split/merge/smooth rounds and saves it.
     *
     * @param trainTrees  training trees; every 20th tree is additionally used as validation data
     * @param outFileName path of the final grammar file; intermediate grammars are written as
     *                    {@code <outFileName>_<round>_<phase>.gr}
     * @throws NullPointerException if {@code trainTrees} or {@code outFileName} is {@code null}
     */
    public void train(List<Tree<String>> trainTrees, String outFileName) {
        train(trainTrees, outFileName, DEFAULT_NUM_SPLITS);
    }

    /**
     * Trains a grammar and saves it to {@code outFileName}.
     *
     * <p>An initial grammar and lexicon are estimated from the trees. Each of the {@code numSplits}
     * rounds then runs up to three phases in order: splitting, merging and smoothing. A phase is
     * skipped when {@code opts.noSplit} is set, the merging percentage is 0, or the smoother is
     * {@code "NoSmoothing"}, respectively.
     *
     * <p>Every executed phase is followed by EM iterations ({@code GrammarTrainer.doOneEStep} plus
     * {@code optimize}). These stop once the validation likelihood has failed to reach the best value
     * for {@code opts.di} consecutive iterations (not counted before the phase's minimum), or once the
     * phase's maximum iteration count is reached. The best-validated grammar of each phase is saved as
     * {@code <outFileName>_<round>_<phase>.gr}.
     *
     * <p>Finally, the better of the last EM model and the best-validated model is saved to
     * {@code outFileName}. Progress is printed to standard output.
     *
     * @param trainTrees  training trees; every 20th tree is additionally used as validation data
     * @param outFileName path of the final grammar file
     * @param numSplits   number of split/merge/smooth rounds (treated as 0 when {@code opts.baseline} is set)
     * @throws NullPointerException     if {@code trainTrees} or {@code outFileName} is {@code null}
     * @throws IllegalArgumentException if {@code numSplits} is negative
     * @throws IllegalStateException    if the parsed options name an input grammar ({@code opts.inFile}),
     *                                  which this trainer cannot resume from
     */
    public void train(List<Tree<String>> trainTrees, String outFileName, int numSplits) {
        if (trainTrees == null) {
            throw new NullPointerException("trainTrees");
        }
        if (outFileName == null) {
            throw new NullPointerException("outFileName");
        }
        if (numSplits < 0) {
            throw new IllegalArgumentException("numSplits must not be negative: " + numSplits);
        }
        // Pass the flag and its value as separate tokens: concatenating "-out " + name and splitting
        // on spaces would break an output path containing a space into several arguments.
        OptionParser optParser = new OptionParser(GrammarTrainer.Options.class);
        GrammarTrainer.Options opts = (GrammarTrainer.Options) optParser.parse(new String[] {"-out", outFileName}, true);
        if (opts.inFile != null) {
            // The initial grammar is only ever built from the trees below; without it every later
            // step would dereference a null grammar/lexicon.
            throw new IllegalStateException("Resuming from an input grammar is not supported: " + opts.inFile);
        }
        List<Tree<String>> validationTrees = getValidationData(trainTrees);
        opts.numSplits = numSplits;
        GrammarTrainer.HORIZONTAL_MARKOVIZATION = opts.horizontalMarkovization;
        GrammarTrainer.VERTICAL_MARKOVIZATION = opts.verticalMarkovization;
        Binarization binarization = opts.binarization;
        double randomness = opts.randomization;
        GrammarTrainer.RANDOM = new Random(opts.randSeed);
        boolean baseline = opts.baseline;
        // A baseline run only estimates the initial grammar: no split/merge/smooth rounds.
        int numSplitTimes = baseline ? 0 : opts.numSplits;
        int allowedDroppingIters = opts.di;
        int maxIterations = opts.splitMaxIterations;
        int minIterations = opts.splitMinIterations;
        double[] smoothParams = new double[] {opts.smoothingParameter1, opts.smoothingParameter2};
        boolean allowMoreSubstatesThanCounts = false;
        boolean findClosedUnaryPaths = opts.findClosedUnaryPaths;
        Numberer tagNumberer = Numberer.getGlobalNumberer("tags");
        short nSubstates = opts.nSubStates;
        short[] numSubStatesArray = GrammarTrainer.initializeSubStateArray(trainTrees, validationTrees, tagNumberer, nSubstates);
        double filter = opts.filter;
        int nTrees = trainTrees.size();
        double mergingPercentage = opts.mergingPercentage;
        boolean separateMergingThreshold = opts.separateMergingThreshold;
        StateSetTreeList trainStateSetTrees = new StateSetTreeList(trainTrees, numSubStatesArray, false, tagNumberer);
        StateSetTreeList validationStateSetTrees = new StateSetTreeList(validationTrees, numSubStatesArray, false, tagNumberer);
        // Only the state-set copies are used from here on; drop this method's references to the
        // string trees (the caller may of course still hold them).
        trainTrees = null;
        validationTrees = null;
        System.gc();
        SimpleFeaturizer feat = new SimpleFeaturizer(opts.rare, opts.reallyRare);

        // Initial estimate: a preliminary lexicon is trained first and then passed to the real
        // lexicon's trainTree as its "old" lexicon, while the grammar tallies the same trees.
        Grammar grammar = new Grammar(numSubStatesArray, findClosedUnaryPaths, new NoSmoothing(), null, filter);
        Lexicon tmpLexicon = opts.simpleLexicon
                ? new SimpleLexicon(numSubStatesArray, -1, smoothParams, new NoSmoothing(), filter, trainStateSetTrees)
                : new SophisticatedLexicon(numSubStatesArray, SophisticatedLexicon.DEFAULT_SMOOTHING_CUTOFF, smoothParams, new NoSmoothing(), filter);
        if (opts.featurizedLexicon) {
            tmpLexicon = new FeaturizedLexicon(numSubStatesArray, feat, trainStateSetTrees);
        }
        int n = 0;
        for (Tree<StateSet> stateSetTree : trainStateSetTrees) {
            boolean secondHalf = n++ > nTrees / 2.0;
            try {
                tmpLexicon.trainTree(stateSetTree, randomness, null, secondHalf, false, opts.rare);
            } catch (Exception e) {
                // Best effort: report the offending tree and its cause, then skip it.
                System.out.println(stateSetTree.toString());
                System.err.println("Skipped tree in preliminary lexicon training: " + e);
            }
        }
        Lexicon lexicon = opts.simpleLexicon
                ? new SimpleLexicon(numSubStatesArray, -1, smoothParams, new NoSmoothing(), filter, trainStateSetTrees)
                : new SophisticatedLexicon(numSubStatesArray, SophisticatedLexicon.DEFAULT_SMOOTHING_CUTOFF, smoothParams, new NoSmoothing(), filter);
        if (opts.featurizedLexicon) {
            lexicon = new FeaturizedLexicon(numSubStatesArray, feat, trainStateSetTrees);
        }
        for (Tree<StateSet> stateSetTree : trainStateSetTrees) {
            // NOTE(review): n is not reset after the first pass, so it keeps counting from the first
            // pass's total and secondHalf is presumably true for every tree here. Kept unchanged —
            // confirm whether this pass was meant to restart at 0.
            boolean secondHalf = n++ > nTrees / 2.0;
            lexicon.trainTree(stateSetTree, randomness, tmpLexicon, secondHalf, false, opts.rare);
            grammar.tallyUninitializedStateSetTree(stateSetTree);
        }
        lexicon.tieRareWordStats(opts.rare);
        lexicon.optimize();
        grammar.optimize(randomness);
        Grammar previousGrammar = grammar;
        Grammar maxGrammar = grammar;
        Lexicon previousLexicon = lexicon;
        Lexicon maxLexicon = lexicon;
        double maxLikelihood = Double.NEGATIVE_INFINITY;
        int iter = 0;

        // Each round consists of three phases: splitIndex % 3 == 0 splits, 1 merges, 2 smooths.
        for (int splitIndex = 0; splitIndex < numSplitTimes * 3; splitIndex++) {
            String opString;
            int phase = splitIndex % 3;
            if (phase == 0) {
                if (opts.noSplit) {
                    continue;
                }
                System.out.println("Before splitting, we have a total of " + maxGrammar.totalSubStates() + " substates.");
                CorpusStatistics corpusStatistics = new CorpusStatistics(tagNumberer, trainStateSetTrees);
                int[] counts = corpusStatistics.getSymbolCounts();
                maxGrammar = maxGrammar.splitAllStates(randomness, counts, allowMoreSubstatesThanCounts, 0);
                maxLexicon = maxLexicon.splitAllStates(counts, allowMoreSubstatesThanCounts, 0);
                maxGrammar.setSmoother(new NoSmoothing());
                maxLexicon.setSmoother(new NoSmoothing());
                System.out.println("After splitting, we have a total of " + maxGrammar.totalSubStates() + " substates.");
                System.out.println("Rule probabilities are NOT normalized in the split, therefore the training LL is not guaranteed to improve between iteration 0 and 1!");
                opString = "splitting";
                maxIterations = opts.splitMaxIterations;
                minIterations = opts.splitMinIterations;
            } else if (phase == 1) {
                if (mergingPercentage == 0.0) {
                    continue;
                }
                double[][] mergeWeights = GrammarMerger.computeMergeWeights(maxGrammar, maxLexicon, trainStateSetTrees);
                double[][][] deltas = GrammarMerger.computeDeltas(maxGrammar, maxLexicon, mergeWeights, trainStateSetTrees);
                boolean[][][] mergeThesePairs = GrammarMerger.determineMergePairs(deltas, separateMergingThreshold, mergingPercentage, maxGrammar);
                Grammar mergedGrammar = GrammarMerger.doTheMerges(maxGrammar, maxLexicon, mergeThesePairs, mergeWeights);
                short[] newNumSubStatesArray = mergedGrammar.numSubStates;
                trainStateSetTrees = new StateSetTreeList(trainStateSetTrees, newNumSubStatesArray, false);
                validationStateSetTrees = new StateSetTreeList(validationStateSetTrees, newNumSubStatesArray, false);
                Lexicon mergedLexicon = opts.featurizedLexicon
                        ? new FeaturizedLexicon(newNumSubStatesArray, feat, trainStateSetTrees)
                        : (opts.simpleLexicon
                                ? new SimpleLexicon(newNumSubStatesArray, -1, smoothParams, maxLexicon.getSmoother(), filter, trainStateSetTrees)
                                : new SophisticatedLexicon(newNumSubStatesArray, SophisticatedLexicon.DEFAULT_SMOOTHING_CUTOFF, maxLexicon.getSmoothingParams(), maxLexicon.getSmoother(), maxLexicon.getPruningThreshold()));
                // Re-estimate only the lexicon for the merged state space; the returned training
                // likelihood is not used.
                GrammarTrainer.doOneEStep(mergedGrammar, maxLexicon, null, mergedLexicon, trainStateSetTrees, true, opts.rare);
                mergedLexicon.optimize();
                GrammarMerger.printMergingStatistics(maxGrammar, mergedGrammar);
                opString = "merging";
                maxGrammar = mergedGrammar;
                maxLexicon = mergedLexicon;
                maxIterations = opts.mergeMaxIterations;
                minIterations = opts.mergeMinIterations;
            } else {
                if (opts.smooth.equals("NoSmoothing")) {
                    continue;
                }
                System.out.println("Setting smoother for grammar and lexicon.");
                SmoothAcrossParentBits grSmoother = new SmoothAcrossParentBits(0.01, maxGrammar.splitTrees);
                SmoothAcrossParentBits lexSmoother = new SmoothAcrossParentBits(0.1, maxGrammar.splitTrees);
                maxGrammar.setSmoother(grSmoother);
                maxLexicon.setSmoother(lexSmoother);
                minIterations = maxIterations = opts.smoothMaxIterations;
                opString = "smoothing";
            }

            // EM iterations starting from the phase's result.
            previousGrammar = grammar = maxGrammar;
            previousLexicon = lexicon = maxLexicon;
            int droppingIter = 0;
            numSubStatesArray = grammar.numSubStates;
            trainStateSetTrees = new StateSetTreeList(trainStateSetTrees, numSubStatesArray, false);
            validationStateSetTrees = new StateSetTreeList(validationStateSetTrees, numSubStatesArray, false);
            maxLikelihood = GrammarTrainer.calculateLogLikelihood(maxGrammar, maxLexicon, validationStateSetTrees);
            System.out.println("After " + opString + " in the " + (splitIndex / 3 + 1) + "th round, we get a validation likelihood of " + maxLikelihood);
            iter = 0;
            do {
                System.out.println("Beginning iteration " + iter + ":");
                iter++;
                System.out.print("Calculating validation likelihood...");
                double validationLikelihood = GrammarTrainer.calculateLogLikelihood(previousGrammar, previousLexicon, validationStateSetTrees);
                System.out.println("done: " + validationLikelihood);
                System.out.print("Calculating training likelihood...");
                grammar = new Grammar(grammar.numSubStates, grammar.findClosedPaths, grammar.smoother, grammar, grammar.threshold);
                lexicon = opts.featurizedLexicon
                        ? lexicon.copyLexicon()
                        : (opts.simpleLexicon
                                ? new SimpleLexicon(grammar.numSubStates, -1, smoothParams, lexicon.getSmoother(), filter, trainStateSetTrees)
                                : new SophisticatedLexicon(grammar.numSubStates, SophisticatedLexicon.DEFAULT_SMOOTHING_CUTOFF, lexicon.getSmoothingParams(), lexicon.getSmoother(), lexicon.getPruningThreshold()));
                double trainingLikelihood = GrammarTrainer.doOneEStep(previousGrammar, previousLexicon, grammar, lexicon, trainStateSetTrees, false, opts.rare);
                System.out.println("done: " + trainingLikelihood);
                lexicon.optimize();
                grammar.optimize(0.0);
                // The validation score belongs to the previous model, so that is what becomes the best.
                if (iter < minIterations || validationLikelihood >= maxLikelihood) {
                    maxLikelihood = validationLikelihood;
                    maxGrammar = previousGrammar;
                    maxLexicon = previousLexicon;
                    droppingIter = 0;
                } else {
                    droppingIter++;
                }
                previousGrammar = grammar;
                previousLexicon = lexicon;
            } while (droppingIter < allowedDroppingIters && !baseline && iter < maxIterations);

            ParserData roundData = newParserData(maxLexicon, maxGrammar, numSubStatesArray, binarization);
            String outTmpName = outFileName + "_" + (splitIndex / 3 + 1) + "_" + opString + ".gr";
            System.out.println("Saving grammar to " + outTmpName + ".");
            reportSave(roundData.Save(outTmpName));
        }

        // The model produced by the last EM step has not been validated yet.
        System.out.print("Calculating last validation likelihood...");
        double validationLikelihood = GrammarTrainer.calculateLogLikelihood(grammar, lexicon, validationStateSetTrees);
        System.out.println("done.\n  Iteration " + iter + " (final) gives validation likelihood " + validationLikelihood);
        if (validationLikelihood > maxLikelihood) {
            maxLikelihood = validationLikelihood;
            maxGrammar = previousGrammar;
            maxLexicon = previousLexicon;
        }
        ParserData finalData = newParserData(maxLexicon, maxGrammar, numSubStatesArray, binarization);
        System.out.println("Saving grammar to " + outFileName + ".");
        System.out.println("It gives a validation data log likelihood of: " + maxLikelihood);
        reportSave(finalData.Save(outFileName));
    }

    /** Bundles a grammar and lexicon with the global numberers and current markovization settings. */
    private static ParserData newParserData(Lexicon lexicon, Grammar grammar, short[] numSubStates, Binarization binarization) {
        return new ParserData(lexicon, grammar, null, Numberer.getNumberers(), numSubStates, GrammarTrainer.VERTICAL_MARKOVIZATION, GrammarTrainer.HORIZONTAL_MARKOVIZATION, binarization);
    }

    /** Prints whether {@code ParserData.Save} reported success. */
    private static void reportSave(boolean saved) {
        System.out.println(saved ? "Saving successful." : "Saving failed!");
    }

    /**
     * Selects every {@value #VALIDATION_STRIDE}-th tree (indices 0, 20, 40, ...) as validation data.
     * The selected trees remain part of the training data.
     *
     * @param target the training trees
     * @return a new list holding the selected trees, in their original order
     */
    private static List<Tree<String>> getValidationData(List<Tree<String>> target) {
        List<Tree<String>> ret = new ArrayList<Tree<String>>(target.size() / VALIDATION_STRIDE + 1);
        int index = 0;
        for (Tree<String> t : target) {
            if (index % VALIDATION_STRIDE == 0) {
                ret.add(t);
            }
            index++;
        }
        return ret;
    }

    /**
     * Command-line entry point. Reads a Korean treebank, binarizes and filters its training trees, and
     * trains a grammar with {@link #train(List, String)}.
     *
     * @param args optional {@code [treebankPath [outputName]]}; missing entries fall back to
     *             {@link #DEFAULT_TREEBANK_PATH} and {@link #DEFAULT_OUTPUT_NAME}
     */
    public static void main(String[] args) {
        String path = args.length > 0 ? args[0] : DEFAULT_TREEBANK_PATH;
        String outFileName = args.length > 1 ? args[1] : DEFAULT_OUTPUT_NAME;
        Corpus corpus = new Corpus(path, Corpus.TreeBankType.KOREAN, 1.0, false, -1, false, false);
        // NOTE(review): 10000 is presumably the maximum sentence length kept by the filter — confirm
        // against Corpus.binarizeAndFilterTrees.
        List<Tree<String>> trees = Corpus.binarizeAndFilterTrees(corpus.getTrainTrees(), 1, 0, 10000, Binarization.RIGHT, false, false);
        Trainer trainer = new Trainer();
        trainer.train(trees, outFileName);
    }
}

