package com.googlecode.clearnlp.run;

import com.googlecode.clearnlp.classification.model.StringModel;
import com.googlecode.clearnlp.classification.train.StringTrainSpace;
import com.googlecode.clearnlp.constituent.CTLibEn;
import com.googlecode.clearnlp.dependency.DEPTree;
import com.googlecode.clearnlp.dependency.srl.SRLabeler;
import com.googlecode.clearnlp.engine.EngineSetter;
import com.googlecode.clearnlp.feature.xml.SRLFtrXml;
import com.googlecode.clearnlp.reader.SRLReader;
import com.googlecode.clearnlp.util.UTFile;
import com.googlecode.clearnlp.util.UTInput;
import com.googlecode.clearnlp.util.UTXml;
import com.googlecode.clearnlp.util.pair.Pair;
import java.io.FileInputStream;
import java.util.Set;
import org.kohsuke.args4j.Option;
import org.w3c.dom.Element;

/* loaded from: input_file:com/googlecode/clearnlp/run/SRLTrain.class */
public class SRLTrain extends AbstractRun {

    @Option(name = "-i", usage = "the directory containing training files (input; required)", required = true, metaVar = "<directory>")
    protected String s_trainDir;

    @Option(name = "-c", usage = "the configuration file (input; required)", required = true, metaVar = "<filename>")
    protected String s_configXml;

    @Option(name = "-f", usage = "the feature file (input; required)", required = true, metaVar = "<filename>")
    protected String s_featureXml;

    @Option(name = "-m", usage = "the model file (output; required)", required = true, metaVar = "<filename>")
    protected String s_modelFile;

    @Option(name = "-n", usage = "the bootstrapping level (default: 2)", required = false, metaVar = "<integer>")
    protected int n_boot = 2;

    @Option(name = "-sb", usage = "if set, save all intermediate bootstrapping models", required = false, metaVar = "<boolean>")
    protected boolean b_saveAllModels = false;

    /** Creates a trainer without parsing arguments or running any training. */
    public SRLTrain() {
    }

    /**
     * Parses the command-line arguments and immediately runs training.
     * Any exception raised by training is printed to standard error and not rethrown.
     *
     * @param args command-line arguments (see the {@code @Option} fields)
     */
    public SRLTrain(String[] args) {
        initArgs(args);
        try {
            run(s_configXml, s_featureXml, s_trainDir, s_modelFile, n_boot);
        } catch (Exception e) {
            // Command-line entry point: report the failure rather than propagate it.
            e.printStackTrace();
        }
    }

    /**
     * Trains an SRL labeler and writes the resulting model(s).
     * One initial model is trained, then {@code bootLevels} further rounds are trained,
     * each seeded with the previous round's models.
     * If {@link #b_saveAllModels} is set, every round is saved as {@code modelFile + "." + round}.
     * Otherwise only the final model is saved, as {@code modelFile}.
     *
     * @param configXml  path to the configuration XML
     * @param featureXml path to the feature XML
     * @param trainDir   directory holding the training files
     * @param modelFile  output model path (or prefix, when saving all rounds)
     * @param bootLevels number of bootstrapping rounds after the initial training
     */
    private void run(String configXml, String featureXml, String trainDir, String modelFile, int bootLevels) throws Exception {
        Element eConfig;
        try (FileInputStream in = new FileInputStream(configXml)) {
            eConfig = UTXml.getDocumentElement(in);
        }
        SRLReader reader = (SRLReader) getReader(eConfig).o1;

        // NOTE(review): assumes SRLFtrXml reads the whole stream in its constructor,
        // so closing the stream right after construction is safe -- confirm.
        SRLFtrXml xml;
        try (FileInputStream in = new FileInputStream(featureXml)) {
            xml = new SRLFtrXml(in);
        }

        String[] trainFiles = UTFile.getSortedFileList(trainDir);
        Pair<Set<String>, Set<String>> downUp = getDownUpSets(reader, xml, trainFiles, -1);

        // Round 0: train from scratch (no previous models).
        SRLabeler labeler = getTrainedLabeler(eConfig, reader, xml, trainFiles, null, downUp.o1, downUp.o2, -1);
        if (b_saveAllModels) {
            EngineSetter.setSRLabeler(modelFile + CTLibEn.POS_PERIOD + 0, featureXml, labeler);
        }

        // Bootstrapping rounds: each one is seeded with the models from the previous round.
        for (int round = 1; round <= bootLevels; round++) {
            labeler = getTrainedLabeler(eConfig, reader, xml, trainFiles, labeler.getModels(), downUp.o1, downUp.o2, -1);
            if (b_saveAllModels) {
                EngineSetter.setSRLabeler(modelFile + CTLibEn.POS_PERIOD + round, featureXml, labeler);
            }
        }

        if (!b_saveAllModels) {
            EngineSetter.setSRLabeler(modelFile, featureXml, labeler);
        }
    }

    /**
     * Collects the down-path and up-path sets from the training files.
     * A default-constructed {@link SRLabeler} is run over every tree.
     * The result is filtered by the cutoffs in {@code xml}.
     *
     * @param reader     reader used to parse each training file
     * @param xml        feature configuration supplying the down/up cutoffs
     * @param trainFiles training file paths
     * @param devId      index into {@code trainFiles} of a file to skip; {@code -1} skips none
     * @return pair of (down-path set, up-path set)
     */
    public Pair<Set<String>, Set<String>> getDownUpSets(SRLReader reader, SRLFtrXml xml, String[] trainFiles, int devId) {
        SRLabeler labeler = new SRLabeler();

        System.out.println("Collecting lexica:");
        labelFiles(reader, labeler, trainFiles, devId);
        System.out.println();

        int downCutoff = xml.getDownCutoff();
        int upCutoff = xml.getUpCutoff();
        Set<String> downSet = labeler.getDownSet(downCutoff);
        Set<String> upSet = labeler.getUpSet(upCutoff);

        System.out.printf("- down-paths: size = %d, cutoff = %d\n", Integer.valueOf(downSet.size()), Integer.valueOf(downCutoff));
        System.out.printf("- up-paths  : size = %d, cutoff = %d\n", Integer.valueOf(upSet.size()), Integer.valueOf(upCutoff));
        return new Pair<>(downSet, upSet);
    }

    /**
     * Collects training instances from the training files.
     * It then trains one model per train space and returns a labeler built from those models.
     *
     * @param eConfig    configuration root; its first {@code train} element configures model training
     * @param reader     reader used to parse each training file
     * @param xml        feature configuration
     * @param trainFiles training file paths
     * @param models     models from a previous round, or {@code null} to start from scratch
     * @param sDown      down-path set (see {@link #getDownUpSets})
     * @param sUp        up-path set (see {@link #getDownUpSets})
     * @param devId      index into {@code trainFiles} of a file to skip; {@code -1} skips none
     * @return a new labeler wrapping the freshly trained models
     * @throws Exception if model training fails
     */
    public SRLabeler getTrainedLabeler(Element eConfig, SRLReader reader, SRLFtrXml xml, String[] trainFiles, StringModel[] models, Set<String> sDown, Set<String> sUp, int devId) throws Exception {
        StringTrainSpace[] spaces = new StringTrainSpace[2];
        for (int i = 0; i < spaces.length; i++) {
            // NOTE(review): every space uses cutoff index 0; confirm this is intended rather than index i.
            spaces[i] = new StringTrainSpace(false, xml.getLabelCutoff(0), xml.getFeatureCutoff(0));
        }

        SRLabeler collector = (models == null)
                ? new SRLabeler(xml, spaces, sDown, sUp)
                : new SRLabeler(xml, models, spaces, sDown, sUp);

        System.out.println("Collecting training instances:");
        labelFiles(reader, collector, trainFiles, devId);
        System.out.println();

        // The training configuration is the same for every model, so look it up once.
        Element eTrain = UTXml.getFirstElementByTagName(eConfig, "train");
        StringModel[] trained = new StringModel[spaces.length];
        for (int i = 0; i < trained.length; i++) {
            trained[i] = (StringModel) getModel(eTrain, spaces[i], i);
        }
        return new SRLabeler(xml, trained, sDown, sUp);
    }

    /**
     * Runs {@code labeler} over every tree in every file of {@code files}, except the file at
     * index {@code skipId}. A progress dot is printed after each file.
     * The reader is always closed, even if labeling fails part-way.
     */
    private static void labelFiles(SRLReader reader, SRLabeler labeler, String[] files, int skipId) {
        for (int i = 0; i < files.length; i++) {
            if (i == skipId) {
                continue;
            }
            reader.open(UTInput.createBufferedFileReader(files[i]));
            try {
                for (DEPTree tree = reader.next(); tree != null; tree = reader.next()) {
                    labeler.label(tree);
                }
                System.out.print(CTLibEn.POS_PERIOD);
            } finally {
                reader.close();
            }
        }
    }

    /** Command-line entry point. */
    public static void main(String[] args) {
        new SRLTrain(args);
    }
}
