package com.googlecode.clearnlp.run;

import com.googlecode.clearnlp.classification.model.StringModel;
import com.googlecode.clearnlp.classification.train.StringTrainSpace;
import com.googlecode.clearnlp.constituent.CTLibEn;
import com.googlecode.clearnlp.dependency.DEPNode;
import com.googlecode.clearnlp.dependency.DEPParser;
import com.googlecode.clearnlp.dependency.DEPTree;
import com.googlecode.clearnlp.engine.EngineSetter;
import com.googlecode.clearnlp.feature.xml.DEPFtrXml;
import com.googlecode.clearnlp.reader.DEPReader;
import com.googlecode.clearnlp.util.UTFile;
import com.googlecode.clearnlp.util.UTInput;
import com.googlecode.clearnlp.util.UTXml;
import com.googlecode.clearnlp.util.map.Prob1DMap;
import java.io.FileInputStream;
import java.util.ArrayList;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import org.kohsuke.args4j.Option;
import org.w3c.dom.Element;

/* loaded from: input_file:com/googlecode/clearnlp/run/DEPTrain.class */
public class DEPTrain extends AbstractRun {

    /** Directory holding the training files; bound to the required {@code -i} option. */
    @Option(name = "-i", usage = "input directory containing training files (required)", required = true, metaVar = "<directory>")
    protected String s_trainDir;

    /** Path of the configuration XML file; bound to the required {@code -c} option. */
    @Option(name = "-c", usage = "configuration file (required)", required = true, metaVar = "<filename>")
    protected String s_configXml;

    /** Path of the feature template XML file; bound to the required {@code -f} option. */
    @Option(name = "-f", usage = "feature template file (required)", required = true, metaVar = "<filename>")
    protected String s_featureXml;

    /** Path the trained model is written to; bound to the required {@code -m} option. */
    @Option(name = "-m", usage = "model file (output; required)", required = true, metaVar = "<filename>")
    protected String s_modelFile;
    // NOTE(review): not referenced anywhere in the visible part of this class; presumably kept
    // for subclasses or configuration lookups elsewhere — confirm before removing.
    protected final String LEXICON_PUNCTUATION = "punctuation";

    /** Number of bootstrapping passes run after the initial training pass; defaults to 2. */
    @Option(name = "-n", usage = "bootstrapping level (default: 2)", required = false, metaVar = "<integer>")
    protected int n_boot = 2;

    /**
     * When set, {@code run} saves the model of every pass under the model file name followed by
     * a separator and the pass number, and does not write the un-suffixed model file.
     */
    @Option(name = "-sb", usage = "if set, save all intermediate bootstrapping models", required = false, metaVar = "<boolean>")
    protected boolean b_saveAllModels = false;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/googlecode/clearnlp/run/DEPTrain$TrainTask.class */
    /**
     * Unit of work that feeds every dependency tree of one training file to its own parser.
     * The parser is built around the train space passed to the constructor, so whatever the
     * parser records while parsing ends up in that space.
     */
    public class TrainTask implements Runnable {
        DEPParser d_parser;
        DEPReader d_reader;

        /**
         * Builds the parser for this task and opens the training file.
         *
         * @param element          configuration element handed to {@code getReader} of the enclosing trainer
         * @param dEPFtrXml        feature template given to the parser
         * @param set              lexica given to the parser
         * @param str              path of the training file to read
         * @param stringModel      model from a previous pass, or {@code null} for the first pass
         * @param stringTrainSpace train space the parser is constructed with
         */
        public TrainTask(Element element, DEPFtrXml dEPFtrXml, Set<String> set, String str, StringModel stringModel, StringTrainSpace stringTrainSpace) {
            if (stringModel == null) {
                this.d_parser = new DEPParser(dEPFtrXml, set, stringTrainSpace);
            } else {
                this.d_parser = new DEPParser(dEPFtrXml, set, stringModel, stringTrainSpace);
            }
            this.d_reader = (DEPReader) DEPTrain.this.getReader(element).o1;
            this.d_reader.open(UTInput.createBufferedFileReader(str));
        }

        /**
         * Parses every tree the reader yields, then prints one progress marker.
         * The reader is closed even when reading or parsing fails, so a failing task does not
         * leave its training file open; the progress marker is printed only on success.
         */
        @Override // java.lang.Runnable
        public void run() {
            try {
                DEPTree tree;
                while ((tree = this.d_reader.next()) != null) {
                    this.d_parser.parse(tree);
                }
            } finally {
                this.d_reader.close();
            }
            System.out.print(CTLibEn.POS_PERIOD);
        }
    }

    /**
     * Creates a trainer without reading any arguments or starting training; the option fields
     * keep their declared defaults.
     */
    public DEPTrain() {
    }

    /**
     * Command-line entry constructor: passes the arguments to {@code initArgs} (expected to fill
     * the {@code @Option} fields) and then runs training with those field values.
     * Any exception thrown by training is reported via its stack trace and is not rethrown.
     *
     * @param strArr command-line arguments
     */
    public DEPTrain(String[] strArr) {
        initArgs(strArr);

        try {
            run(s_configXml, s_featureXml, s_trainDir, s_modelFile, n_boot);
        }
        catch (Exception failure) {
            failure.printStackTrace();
        }
    }

    /**
     * Trains a dependency parser and saves the resulting model(s).
     * One initial pass is trained without a previous model, followed by {@code i} bootstrapping
     * passes that each receive the model of the pass before. When {@code b_saveAllModels} is
     * set, every pass is saved under {@code str4} plus a separator and the pass number and no
     * un-suffixed model is written; otherwise only the final parser is saved under {@code str4}.
     *
     * @param str  path of the configuration XML file
     * @param str2 path of the feature template XML file
     * @param str3 directory containing the training files
     * @param str4 path (or path prefix) of the model file to write
     * @param i    number of bootstrapping passes after the initial pass
     * @throws Exception if a file cannot be read or training fails
     */
    public void run(String str, String str2, String str3, String str4, int i) throws Exception {
        // Both XML inputs are closed as soon as they have been consumed; the original code left
        // the streams open for the rest of the (potentially long) training run.
        Element documentElement;
        FileInputStream configIn = new FileInputStream(str);
        try {
            documentElement = UTXml.getDocumentElement(configIn);
        } finally {
            configIn.close();
        }

        // NOTE(review): assumes DEPFtrXml reads the whole stream inside its constructor — confirm.
        DEPFtrXml dEPFtrXml;
        FileInputStream featureIn = new FileInputStream(str2);
        try {
            dEPFtrXml = new DEPFtrXml(featureIn);
        } finally {
            featureIn.close();
        }

        String[] trainFiles = UTFile.getSortedFileListBySize(str3, ".*", true);
        // -1 means "skip no file" in both getLexica and getTrainedParser.
        Set<String> lexica = getLexica(documentElement, dEPFtrXml, trainFiles, -1);

        DEPParser parser = getTrainedParser(documentElement, dEPFtrXml, lexica, trainFiles, null, -1, 0);
        if (this.b_saveAllModels) {
            EngineSetter.setDEPParser(str4 + CTLibEn.POS_PERIOD + 0, str2, parser);
        }

        for (int level = 1; level <= i; level++) {
            parser = getTrainedParser(documentElement, dEPFtrXml, lexica, trainFiles, parser.getModel(), -1, level);
            if (this.b_saveAllModels) {
                EngineSetter.setDEPParser(str4 + CTLibEn.POS_PERIOD + level, str2, parser);
            }
        }

        // Every pass was already saved above when b_saveAllModels is set.
        if (!this.b_saveAllModels) {
            EngineSetter.setDEPParser(str4, str2, parser);
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Collects the word forms of all nodes carrying the feature template's punctuation label
     * from the training files and returns the set produced by {@code Prob1DMap.toSet} with the
     * template's punctuation cutoff. One progress marker is printed per file read.
     *
     * @param element   configuration element handed to {@code getReader}
     * @param dEPFtrXml feature template supplying the punctuation label and cutoff
     * @param strArr    paths of the training files
     * @param i         index of a file to skip, or a value outside the array range (e.g. -1) to read all
     * @return the collected punctuation forms
     * @throws Exception if a training file cannot be read
     */
    public Set<String> getLexica(Element element, DEPFtrXml dEPFtrXml, String[] strArr, int i) throws Exception {
        DEPReader reader = (DEPReader) getReader(element).o1;
        Prob1DMap punctuationCounts = new Prob1DMap();
        String punctuationLabel = dEPFtrXml.getPunctuationLabel();
        int fileCount = strArr.length;
        System.out.println("Collecting lexica:");
        for (int fileIndex = 0; fileIndex < fileCount; fileIndex++) {
            if (fileIndex == i) {
                continue;
            }
            reader.open(UTInput.createBufferedFileReader(strArr[fileIndex]));
            try {
                DEPTree tree;
                while ((tree = reader.next()) != null) {
                    collectLexica(tree, punctuationCounts, punctuationLabel);
                }
                System.out.print(CTLibEn.POS_PERIOD);
            } finally {
                // Close the current file even when reading fails part-way through.
                reader.close();
            }
        }
        System.out.println();
        return punctuationCounts.toSet(dEPFtrXml.getPunctuationCutoff());
    }

    /**
     * Adds to {@code prob1DMap} the form of every node in the tree whose label matches
     * {@code str}. Scanning starts at index 1, so the node at index 0 is never examined
     * (presumably the artificial root — confirm).
     */
    private void collectLexica(DEPTree dEPTree, Prob1DMap prob1DMap, String str) {
        final int nodeCount = dEPTree.size();
        int index = 1;

        while (index < nodeCount) {
            DEPNode node = dEPTree.get(index);
            index++;

            if (!node.isLabel(str)) {
                continue;
            }
            prob1DMap.add(node.form);
        }
    }

    /**
     * Runs one training pass: every training file (except the one at index {@code i}) is
     * processed by its own {@link TrainTask} on a fixed-size thread pool, the per-file train
     * spaces are merged into the first one, and a parser is built from the model trained on
     * the merged space.
     *
     * @param element     configuration element; its first {@code train} child configures the
     *                    thread count and the model training
     * @param dEPFtrXml   feature template (also supplies label and feature cutoffs)
     * @param set         lexica given to each parser
     * @param strArr      paths of the training files
     * @param stringModel model of the previous pass, or {@code null} for the initial pass
     * @param i           index of a file to skip, or a value outside the array range (e.g. -1) to use all
     * @param i2          pass number; not used by this implementation
     * @return a parser wrapping the newly trained model
     * @throws IllegalArgumentException if no training file is left to process
     * @throws InterruptedException     if the calling thread is interrupted while waiting for the tasks
     * @throws Exception                if a task cannot be created or model training fails
     */
    public DEPParser getTrainedParser(Element element, DEPFtrXml dEPFtrXml, Set<String> set, String[] strArr, StringModel stringModel, int i, int i2) throws Exception {
        int fileCount = strArr.length;
        int labelCutoff = dEPFtrXml.getLabelCutoff(0);
        int featureCutoff = dEPFtrXml.getFeatureCutoff(0);
        Element trainElement = UTXml.getFirstElementByTagName(element, "train");
        ExecutorService executor = Executors.newFixedThreadPool(getNumOfThreads(trainElement));
        ArrayList<StringTrainSpace> spaces = new ArrayList<StringTrainSpace>();
        System.out.println("Collecting training instances:");
        try {
            for (int fileIndex = 0; fileIndex < fileCount; fileIndex++) {
                if (fileIndex == i) {
                    continue;
                }
                StringTrainSpace space = new StringTrainSpace(false, labelCutoff, featureCutoff);
                spaces.add(space);
                executor.execute(new TrainTask(element, dEPFtrXml, set, strArr[fileIndex], stringModel, space));
            }
        } finally {
            // Always release the pool: if creating a task fails, idle non-daemon workers would
            // otherwise keep the JVM alive after the failure has been reported.
            executor.shutdown();
        }
        try {
            executor.awaitTermination(Long.MAX_VALUE, TimeUnit.NANOSECONDS);
        } catch (InterruptedException e) {
            // Workers may still be filling their train spaces, so merging now would read
            // incomplete data; stop instead and keep the interrupt visible to the caller.
            executor.shutdownNow();
            Thread.currentThread().interrupt();
            throw e;
        }
        System.out.println();
        if (spaces.isEmpty()) {
            throw new IllegalArgumentException("No training files to process (" + fileCount + " file(s) given, skipped index " + i + ")");
        }
        StringTrainSpace merged = spaces.get(0);
        int spaceCount = spaces.size();
        if (spaceCount > 1) {
            System.out.println("Merging training instances:");
            for (int k = 1; k < spaceCount; k++) {
                StringTrainSpace other = spaces.get(k);
                merged.appendSpace(other);
                // Clear each space right after it has been appended to the merged one.
                other.clear();
                System.out.print(CTLibEn.POS_PERIOD);
            }
            System.out.println();
        }
        return new DEPParser(dEPFtrXml, set, (StringModel) getModel(trainElement, merged, 0));
    }

    /**
     * Prints three score lines (LAS, UAS, LS) to standard output. Each line shows
     * {@code 100 * iArr[k] / iArr[0]} with two decimals followed by the raw counts, where
     * {@code k} is 1, 2 and 3 respectively and {@code iArr[0]} is the shared denominator.
     *
     * @param iArr counts; must hold at least four entries
     */
    protected void printScores(int[] iArr) {
        printScoreLine("- LAS: %5.2f (%d/%d)\n", iArr, 1);
        printScoreLine("- UAS: %5.2f (%d/%d)\n", iArr, 2);
        printScoreLine("- LS : %5.2f (%d/%d)\n", iArr, 3);
    }

    /** Formats one score line: percentage of {@code counts[index]} over {@code counts[0]}, then both counts. */
    private void printScoreLine(String format, int[] counts, int index) {
        double percent = (100.0d * counts[index]) / counts[0];
        System.out.printf(format, percent, counts[index], counts[0]);
    }

    /**
     * Program entry point. Constructing {@link DEPTrain} with the arguments is all that is
     * needed: that constructor reads the arguments and runs training itself.
     *
     * @param strArr command-line arguments
     */
    public static void main(String[] strArr) {
        new DEPTrain(strArr);
    }
}
