/*
 * Decompiled with CFR 0.152.
 */
package kaist.cilab.parser.berkeleyadaptation;

import edu.berkeley.nlp.PCFGLA.Binarization;
import edu.berkeley.nlp.PCFGLA.Corpus;
import edu.berkeley.nlp.syntax.Tree;
import java.io.FileOutputStream;
import java.io.FileWriter;
import java.io.IOException;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.PrintWriter;
import java.util.LinkedList;
import java.util.List;
import kaist.cilab.parser.berkeleyadaptation.Tester;
import kaist.cilab.parser.berkeleyadaptation.Trainer;

/**
 * Runs k-fold cross-validation of a parser grammar over a treebank loaded as
 * {@code Corpus.TreeBankType.KOREAN}, using {@link Trainer} to train and {@link Tester} to
 * evaluate each fold.
 */
public class CrossValidator {
    /** Number of folds used by {@link #main(String[])}. */
    private static final int DEFAULT_FOLDS = 10;

    /**
     * Runs {@code fold}-fold cross-validation over the trees in a treebank.
     *
     * <p>For each fold {@code i} (0-based), the tree at position {@code p} goes to the test set
     * when {@code p % fold == i} and to the training set otherwise. A grammar is trained on the
     * binarized/filtered training trees and evaluated on the test trees. The gold and parser
     * output for each fold are written (UTF-8) to {@code TESTING_<grammarName>_GOLD_<k>} and
     * {@code RESULT_<grammarName>_GOLD_<k>}, where {@code k} is the 1-based fold number.
     *
     * @param fold number of folds; must be at least 2 and no larger than the number of trees
     * @param treebankPath path handed to {@link Corpus} to load the treebank
     * @param grammarName name passed to both the trainer and the tester, and used in the
     *     per-fold output file names
     * @return {@code [precision, recall, f1]}: precision and recall averaged over all folds, and
     *     the F1 score computed from those two averages (0.0 when both averages are 0)
     * @throws IllegalArgumentException if {@code fold} is out of range
     * @throws Exception propagated from corpus loading, training, testing or file output
     */
    public double[] crossValidate(int fold, String treebankPath, String grammarName) throws Exception {
        // With fewer than 2 folds the training set would be empty (and 0 folds divides by zero).
        if (fold < 2) {
            throw new IllegalArgumentException("fold must be at least 2, got " + fold);
        }
        Corpus corpus = new Corpus(treebankPath, Corpus.TreeBankType.KOREAN, 1.0, false, -1, false, false);
        List<Tree<String>> trees = corpus.getTrainTrees();
        // More folds than trees would leave some test folds empty.
        if (fold > trees.size()) {
            throw new IllegalArgumentException(
                    "fold (" + fold + ") exceeds the number of trees (" + trees.size() + ")");
        }
        System.err.println("Total number of Trees: " + trees.size());

        double precisionSum = 0.0;
        double recallSum = 0.0;
        for (int i = 0; i < fold; ++i) {
            double[] performance = this.runFold(trees, i, fold, grammarName);
            precisionSum += performance[0];
            recallSum += performance[1];
        }

        double[] ret = new double[3];
        ret[0] = precisionSum / fold;
        ret[1] = recallSum / fold;
        ret[2] = f1(ret[0], ret[1]);
        return ret;
    }

    /**
     * Trains on every fold except {@code foldIdx}, tests on fold {@code foldIdx}, and returns the
     * tester's scores for that fold ({@code [P, R, F]}, as printed below).
     */
    private double[] runFold(List<Tree<String>> trees, int foldIdx, int foldNum, String grammarName) throws Exception {
        int foldNo = foldIdx + 1;
        System.out.println("FOLD " + foldNo + "/" + foldNum);

        List<Tree<String>> trainingData = new LinkedList<Tree<String>>();
        LinkedList<Tree<String>> testingData = new LinkedList<Tree<String>>();
        this.getFoldData(trees, foldIdx, foldNum, trainingData, testingData);
        // NOTE(review): positional args presumably map to the Berkeley signature
        // (trees, verticalAnnotations, horizontalAnnotations, sentenceMaxLength, binarization,
        // manualAnnotation, verbose) -- confirm against the Corpus version in use.
        trainingData = Corpus.binarizeAndFilterTrees(trainingData, 1, 0, 10000, Binarization.RIGHT, false, false);
        System.err.println("Total number of Training Trees: " + trainingData.size());
        System.err.println("Total number of Testing Trees: " + testingData.size());

        Trainer trainer = new Trainer();
        trainer.train(trainingData, grammarName);

        Tester tester = new Tester();
        double[] performance;
        // Close both writers even if testing fails, so file handles are not leaked per fold.
        PrintWriter goldPrint = openUtf8Writer("TESTING_" + grammarName + "_GOLD_" + foldNo);
        try {
            PrintWriter ansPrint = openUtf8Writer("RESULT_" + grammarName + "_GOLD_" + foldNo);
            try {
                performance = tester.doTest(testingData, grammarName, goldPrint, ansPrint);
            } finally {
                ansPrint.close();
            }
        } finally {
            goldPrint.close();
        }
        System.out.println("P: " + performance[0] + ", R: " + performance[1] + ", F: " + performance[2]);
        return performance;
    }

    /** Opens {@code fileName} for writing as UTF-8 text, truncating any existing file. */
    private static PrintWriter openUtf8Writer(String fileName) throws IOException {
        return new PrintWriter(new OutputStreamWriter(new FileOutputStream(fileName), "UTF-8"));
    }

    /** Harmonic mean of precision and recall; returns 0.0 (instead of NaN) when both are 0. */
    private static double f1(double precision, double recall) {
        double denominator = precision + recall;
        return denominator == 0.0 ? 0.0 : 2.0 * precision * recall / denominator;
    }

    /**
     * Splits {@code trees} round-robin: the tree at position {@code p} is appended to
     * {@code testingData} when {@code p % foldNum == foldIdx}, otherwise to {@code trainingData}.
     */
    private void getFoldData(List<Tree<String>> trees, int foldIdx, int foldNum, List<Tree<String>> trainingData, List<Tree<String>> testingData) {
        int cnt = 0;
        for (Tree<String> tr : trees) {
            if (cnt % foldNum == foldIdx) {
                testingData.add(tr);
            } else {
                trainingData.add(tr);
            }
            ++cnt;
        }
    }

    /**
     * Command-line entry point: runs {@value #DEFAULT_FOLDS}-fold cross-validation and writes an
     * overall precision/recall/F1 report.
     *
     * @param args {@code args[0]} treebank path, {@code args[1]} report output file,
     *     {@code args[2]} grammar name
     */
    public static void main(String[] args) {
        if (args.length < 3) {
            System.err.println("Usage: CrossValidator <treebankPath> <reportFile> <grammarName>");
            System.exit(2);
            return;
        }
        String treebankPath = args[0];
        String reportPath = args[1];
        String grammarName = args[2];
        CrossValidator cv = new CrossValidator();
        try {
            double[] totalPerf = cv.crossValidate(DEFAULT_FOLDS, treebankPath, grammarName);
            PrintWriter pw = new PrintWriter(new FileWriter(reportPath));
            try {
                pw.println("FOLD: " + DEFAULT_FOLDS);
                pw.println("OVERALL PERFORMANCE");
                pw.println(" PRECISION: " + totalPerf[0]);
                pw.println(" RECALL   : " + totalPerf[1]);
                pw.println(" F1 SCORE : " + totalPerf[2]);
            } finally {
                pw.close();
            }
        }
        catch (Exception e) {
            e.printStackTrace();
            // Signal failure to calling scripts instead of exiting with status 0.
            System.exit(1);
        }
    }
}

