package com.googlecode.clearnlp.nlp;

import com.carrotsearch.hppc.ObjectIntOpenHashMap;
import com.googlecode.clearnlp.classification.model.StringModel;
import com.googlecode.clearnlp.classification.train.StringTrainSpace;
import com.googlecode.clearnlp.component.AbstractStatisticalComponent;
import com.googlecode.clearnlp.component.dep.CDEPBackParser;
import com.googlecode.clearnlp.component.dep.CDEPPassParser;
import com.googlecode.clearnlp.component.pos.CPOSBackTagger;
import com.googlecode.clearnlp.component.pos.CPOSTagger;
import com.googlecode.clearnlp.component.srl.CPredIdentifier;
import com.googlecode.clearnlp.component.srl.CRolesetClassifier;
import com.googlecode.clearnlp.component.srl.CSRLabeler;
import com.googlecode.clearnlp.component.srl.CSenseClassifier;
import com.googlecode.clearnlp.constituent.CTLibEn;
import com.googlecode.clearnlp.dependency.DEPTree;
import com.googlecode.clearnlp.engine.EngineProcess;
import com.googlecode.clearnlp.feature.xml.JointFtrXml;
import com.googlecode.clearnlp.reader.JointReader;
import com.googlecode.clearnlp.util.UTFile;
import com.googlecode.clearnlp.util.UTInput;
import com.googlecode.clearnlp.util.UTXml;
import com.googlecode.clearnlp.util.map.Prob1DMap;
import java.io.BufferedOutputStream;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.zip.ZipOutputStream;
import org.kohsuke.args4j.Option;
import org.w3c.dom.Element;

/* loaded from: input_file:com/googlecode/clearnlp/nlp/NLPTrain.class */
/**
 * Command-line trainer for the statistical NLP components.
 *
 * <p>Given a configuration file, one or more feature template files, a directory of
 * training files and a mode string, this class collects lexica, gathers training
 * instances (one task per training file, run on a fixed-size thread pool), trains the
 * models, optionally repeats training for {@link #n_boot} bootstrapping rounds, and
 * writes the resulting models to a zip file.
 */
public class NLPTrain extends AbstractNLP {

    @Option(name = "-c", usage = "configuration file (required)", required = true, metaVar = "<filename>")
    protected String s_configFile;

    @Option(name = "-f", usage = "feature template files delimited by ':' (required)", required = true, metaVar = "<filename>")
    protected String s_featureFiles;

    @Option(name = "-i", usage = "input directory containing training files (required)", required = true, metaVar = "<directory>")
    protected String s_trainDir;

    @Option(name = "-m", usage = "model file (output; required)", required = true, metaVar = "<filename>")
    protected String s_modelFile;

    @Option(name = "-z", usage = "mode (pos|morph|dep|pred|role|srl)", required = true, metaVar = "<string>")
    protected String s_mode;

    /** Separator between the feature template file names given to the {@code -f} option. */
    protected final String DELIM_FILES = ":";

    // The usage text states the value the field is actually initialized to.
    @Option(name = "-n", usage = "bootstrapping level (default: 0)", required = false, metaVar = "<integer>")
    protected int n_boot = 0;

    @Option(name = "-margin", usage = "margin between the 1st and 2nd predictions (default: 0.5)", required = false, metaVar = "<double>")
    protected double d_margin = 0.5d;

    // The field is an int, so the meta variable is advertised as an integer.
    @Option(name = "-beams", usage = "the size of beam (default: 0)", required = false, metaVar = "<integer>")
    protected int n_beams = 0;

    /**
     * Feeds every tree of one training file to a component.
     *
     * <p>Each task owns its own reader, configured from the {@code reader} element of
     * the configuration and opened on the file when the task is constructed.
     */
    public class TrainTask implements Runnable {
        final AbstractStatisticalComponent j_component;
        final JointReader j_reader;

        /**
         * @param element the configuration element containing a {@code reader} child.
         * @param str path of the training file to read.
         * @param abstractStatisticalComponent the component that processes each tree.
         */
        public TrainTask(Element element, String str, AbstractStatisticalComponent abstractStatisticalComponent) {
            this.j_reader = NLPTrain.this.getJointReader(UTXml.getFirstElementByTagName(element, "reader"));
            this.j_reader.open(UTInput.createBufferedFileReader(str));
            this.j_component = abstractStatisticalComponent;
        }

        /** Processes all trees in the file, closes the reader, and prints a progress mark. */
        @Override // java.lang.Runnable
        public void run() {
            try {
                DEPTree tree;
                while ((tree = this.j_reader.next()) != null) {
                    this.j_component.process(tree);
                }
            } finally {
                // Release the file even when reading or processing fails.
                this.j_reader.close();
            }
            System.out.print(CTLibEn.POS_PERIOD);
        }
    }

    /** Creates a trainer without parsing arguments or starting training. */
    public NLPTrain() {
    }

    /**
     * Parses the command-line arguments and runs training immediately.
     *
     * <p>Any exception raised by training is printed to standard error and not rethrown.
     *
     * @param strArr the command-line arguments.
     */
    public NLPTrain(String[] strArr) {
        initArgs(strArr);
        try {
            train(this.s_configFile, this.s_featureFiles.split(DELIM_FILES), this.s_trainDir, this.s_modelFile, this.s_mode);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Trains a component on every file of a directory and saves its models.
     *
     * @param str path of the XML configuration file.
     * @param strArr paths of the feature template files.
     * @param str2 directory containing the training files.
     * @param str3 path of the model zip file to write.
     * @param str4 the training mode (see {@link #getComponent(Element, JointReader, JointFtrXml[], String[], int, String)}).
     * @throws Exception if a file cannot be read or written, or training fails.
     */
    public void train(String str, String[] strArr, String str2, String str3, String str4) throws Exception {
        Element config;
        FileInputStream configStream = new FileInputStream(str);
        try {
            config = UTXml.getDocumentElement(configStream);
        } finally {
            configStream.close();
        }

        JointReader reader = getJointReader(UTXml.getFirstElementByTagName(config, "reader"));
        JointFtrXml[] templates = getFeatureTemplates(strArr);
        String[] trainFiles = UTFile.getSortedFileListBySize(str2, ".*", true);
        // -1 matches no file index, so every training file is used.
        AbstractStatisticalComponent component = getComponent(config, reader, templates, trainFiles, -1, str4);

        // The output file is only opened once training has succeeded.
        ZipOutputStream modelStream = new ZipOutputStream(new BufferedOutputStream(new FileOutputStream(str3)));
        try {
            component.saveModels(modelStream);
        } finally {
            // Closing an already closed ZipOutputStream is a no-op, so this is safe
            // whether or not saveModels closes the stream itself.
            modelStream.close();
        }
    }

    /**
     * Trains and returns the component for the requested mode.
     *
     * <p>For every mode except {@link NLPLib#MODE_PRED}, a lexicon-collecting component
     * is built first and handed to
     * {@link #getTrainedComponent(Element, JointReader, JointFtrXml[], String[], AbstractStatisticalComponent, String, int)}.
     * The predicate mode trains a single round with no models and no lexica. Note that
     * {@link NLPLib#MODE_POS_BACK} collects its lexica with a plain {@link CPOSTagger}.
     *
     * @param element the root configuration element.
     * @param jointReader reader used to scan the training files.
     * @param jointFtrXmlArr the feature templates.
     * @param strArr the training file paths.
     * @param i index in {@code strArr} of a file to leave out; a value matching no index uses all files.
     * @param str the training mode.
     * @return the trained component.
     * @throws IllegalArgumentException if the mode is not supported.
     */
    protected AbstractStatisticalComponent getComponent(Element element, JointReader jointReader, JointFtrXml[] jointFtrXmlArr, String[] strArr, int i, String str) {
        if (str.equals(NLPLib.MODE_PRED)) {
            return getTrainedComponent(element, jointFtrXmlArr, strArr, null, null, str, 0, i);
        }

        AbstractStatisticalComponent collector;
        if (str.equals("pos") || str.equals(NLPLib.MODE_POS_BACK)) {
            collector = new CPOSTagger(jointFtrXmlArr, getLowerSimplifiedForms(jointReader, jointFtrXmlArr[0], strArr, i));
        } else if (str.equals("dep")) {
            collector = new CDEPPassParser(jointFtrXmlArr);
        } else if (str.equals("role")) {
            collector = new CRolesetClassifier(jointFtrXmlArr);
        } else if (str.startsWith(NLPLib.MODE_SENSE)) {
            collector = new CSenseClassifier(jointFtrXmlArr, str.substring(str.lastIndexOf("_") + 1));
        } else if (str.equals("srl")) {
            collector = new CSRLabeler(jointFtrXmlArr);
        } else if (str.equals(NLPLib.MODE_DEP_BACK)) {
            collector = new CDEPBackParser(jointFtrXmlArr);
        } else {
            throw new IllegalArgumentException("The requested mode '" + str + "' is not supported.");
        }
        return getTrainedComponent(element, jointReader, jointFtrXmlArr, strArr, collector, str, i);
    }

    /**
     * Builds a component for the given mode from already trained models.
     *
     * @param jointFtrXmlArr the feature templates.
     * @param stringModelArr the trained models.
     * @param objArr the lexica.
     * @param str the mode.
     * @return a new component wrapping the models.
     * @throws IllegalArgumentException if the mode is not supported.
     */
    public AbstractStatisticalComponent getComponent(JointFtrXml[] jointFtrXmlArr, StringModel[] stringModelArr, Object[] objArr, String str) {
        if (str.equals("pos")) {
            return new CPOSTagger(jointFtrXmlArr, stringModelArr, objArr);
        }
        if (str.equals("dep")) {
            return new CDEPPassParser(jointFtrXmlArr, stringModelArr, objArr);
        }
        if (str.equals(NLPLib.MODE_PRED)) {
            return new CPredIdentifier(jointFtrXmlArr, stringModelArr, objArr);
        }
        if (str.equals("role")) {
            return new CRolesetClassifier(jointFtrXmlArr, stringModelArr, objArr);
        }
        if (str.startsWith(NLPLib.MODE_SENSE)) {
            return new CSenseClassifier(jointFtrXmlArr, stringModelArr, objArr, str.substring(str.lastIndexOf("_") + 1));
        }
        if (str.equals("srl")) {
            return new CSRLabeler(jointFtrXmlArr, stringModelArr, objArr);
        }
        if (str.equals(NLPLib.MODE_POS_BACK)) {
            return new CPOSBackTagger(jointFtrXmlArr, stringModelArr, objArr, this.d_margin);
        }
        if (str.equals(NLPLib.MODE_DEP_BACK)) {
            return new CDEPBackParser(jointFtrXmlArr, stringModelArr, objArr, this.d_margin, this.n_beams);
        }
        throw new IllegalArgumentException("The requested mode '" + str + "' is not supported.");
    }

    /**
     * Collects lexica with the given component and then trains for
     * {@code n_boot + 1} rounds, feeding the models of each round into the next.
     *
     * @param element the root configuration element.
     * @param jointReader reader used to scan the training files for lexica.
     * @param jointFtrXmlArr the feature templates.
     * @param strArr the training file paths.
     * @param abstractStatisticalComponent the component used to collect lexica.
     * @param str the training mode.
     * @param i index in {@code strArr} of a file to leave out.
     * @return the component of the last round, or {@code null} if {@code n_boot} is negative.
     */
    public AbstractStatisticalComponent getTrainedComponent(Element element, JointReader jointReader, JointFtrXml[] jointFtrXmlArr, String[] strArr, AbstractStatisticalComponent abstractStatisticalComponent, String str, int i) {
        Object[] lexica = getLexica(abstractStatisticalComponent, jointReader, jointFtrXmlArr, strArr, i);
        AbstractStatisticalComponent trained = null;
        StringModel[] models = null;
        for (int round = 0; round <= this.n_boot; round++) {
            trained = getTrainedComponent(element, jointFtrXmlArr, strArr, models, lexica, str, round, i);
            models = trained.getModels();
        }
        return trained;
    }

    /**
     * Loads one feature template per file.
     *
     * @param strArr paths of the feature template files.
     * @return the templates, in the same order as the paths.
     * @throws Exception if a file cannot be opened or a template cannot be built from it.
     */
    public JointFtrXml[] getFeatureTemplates(String[] strArr) throws Exception {
        int count = strArr.length;
        JointFtrXml[] templates = new JointFtrXml[count];
        for (int k = 0; k < count; k++) {
            FileInputStream in = new FileInputStream(strArr[k]);
            try {
                templates[k] = new JointFtrXml(in);
            } finally {
                in.close();
            }
        }
        return templates;
    }

    /**
     * Runs the component over every training file (except index {@code i}) and returns
     * the lexica it reports afterwards.
     *
     * @param abstractStatisticalComponent the component that processes each tree.
     * @param jointReader the reader, reopened for each file.
     * @param jointFtrXmlArr the feature templates (not used by this implementation).
     * @param strArr the training file paths.
     * @param i index in {@code strArr} of a file to leave out.
     * @return the result of {@code getLexica()} on the component.
     */
    public Object[] getLexica(AbstractStatisticalComponent abstractStatisticalComponent, JointReader jointReader, JointFtrXml[] jointFtrXmlArr, String[] strArr, int i) {
        System.out.println("Collecting lexica:");
        for (int fileIndex = 0; fileIndex < strArr.length; fileIndex++) {
            if (fileIndex == i) {
                continue;
            }
            jointReader.open(UTInput.createBufferedFileReader(strArr[fileIndex]));
            try {
                DEPTree tree;
                while ((tree = jointReader.next()) != null) {
                    abstractStatisticalComponent.process(tree);
                }
            } finally {
                jointReader.close();
            }
            System.out.print(CTLibEn.POS_PERIOD);
        }
        System.out.println();
        return abstractStatisticalComponent.getLexica();
    }

    /**
     * Gathers the lower simplified word forms of the training files.
     *
     * <p>For each file (except index {@code i}) the distinct forms of all nodes from
     * index 1 onward are added to a frequency map once per file; the result is that
     * map's {@code toSet} for the template's document frequency cutoff.
     *
     * @param jointReader the reader, reopened for each file.
     * @param jointFtrXml the template supplying the document frequency cutoff.
     * @param strArr the training file paths.
     * @param i index in {@code strArr} of a file to leave out.
     * @return the selected word forms.
     */
    public Set<String> getLowerSimplifiedForms(JointReader jointReader, JointFtrXml jointFtrXml, String[] strArr, int i) {
        Set<String> formsInFile = new HashSet<String>();
        Prob1DMap documentCounts = new Prob1DMap();
        System.out.println("Collecting word-forms:");
        for (int fileIndex = 0; fileIndex < strArr.length; fileIndex++) {
            if (fileIndex == i) {
                continue;
            }
            jointReader.open(UTInput.createBufferedFileReader(strArr[fileIndex]));
            formsInFile.clear();
            try {
                DEPTree tree;
                while ((tree = jointReader.next()) != null) {
                    EngineProcess.normalizeForms(tree);
                    int size = tree.size();
                    // Node 0 is skipped; iteration starts at index 1.
                    for (int node = 1; node < size; node++) {
                        formsInFile.add(tree.get(node).lowerSimplifiedForm);
                    }
                }
            } finally {
                jointReader.close();
            }
            documentCounts.addAll(formsInFile);
            System.out.print(CTLibEn.POS_PERIOD);
        }
        System.out.println();
        return documentCounts.toSet(jointFtrXml.getDocumentFrequencyCutoff());
    }

    /**
     * Runs one training round: collects the training spaces, trains one model per
     * space, and wraps the models in a component.
     *
     * <p>For the role and sense modes every model is trained with model index 0;
     * for all other modes the index of the space is used.
     *
     * @param element the root configuration element.
     * @param jointFtrXmlArr the feature templates.
     * @param strArr the training file paths.
     * @param stringModelArr models of the previous round, or {@code null} for the first.
     * @param objArr the lexica.
     * @param str the training mode; also the tag name of its configuration element.
     * @param i the bootstrapping round.
     * @param i2 index in {@code strArr} of a file to leave out.
     * @return a component built from the newly trained models.
     */
    public AbstractStatisticalComponent getTrainedComponent(Element element, JointFtrXml[] jointFtrXmlArr, String[] strArr, StringModel[] stringModelArr, Object[] objArr, String str, int i, int i2) {
        StringTrainSpace[] spaces = getStringTrainSpaces(element, jointFtrXmlArr, strArr, stringModelArr, objArr, str, i, i2);
        Element modeElement = UTXml.getFirstElementByTagName(element, str);
        boolean singleModelIndex = str.equals("role") || str.startsWith(NLPLib.MODE_SENSE);
        StringModel[] models = new StringModel[spaces.length];
        for (int k = 0; k < spaces.length; k++) {
            int modelIndex = singleModelIndex ? 0 : k;
            models[k] = (StringModel) getModel(modeElement, spaces[k], modelIndex, i);
            // The instances are no longer needed once the model is trained.
            spaces[k].clear();
        }
        return getComponent(jointFtrXmlArr, models, objArr, str);
    }

    /**
     * Collects training instances from all files in parallel and merges them.
     *
     * <p>Each file (except index {@code i2}) gets its own set of training spaces and
     * its own component; the tasks run on a fixed pool whose size comes from the
     * mode's configuration element. Afterwards the spaces of all files are appended,
     * per model, to those of the first file.
     *
     * @param element the root configuration element.
     * @param jointFtrXmlArr the feature templates.
     * @param strArr the training file paths.
     * @param stringModelArr models of the previous round, or {@code null}.
     * @param objArr the lexica.
     * @param str the training mode.
     * @param i the bootstrapping round.
     * @param i2 index in {@code strArr} of a file to leave out.
     * @return the merged training spaces, one per model.
     * @throws IndexOutOfBoundsException if no training file was processed.
     */
    public StringTrainSpace[] getStringTrainSpaces(Element element, JointFtrXml[] jointFtrXmlArr, String[] strArr, StringModel[] stringModelArr, Object[] objArr, String str, int i, int i2) {
        Element modeElement = UTXml.getFirstElementByTagName(element, str);
        int threads = getNumOfThreads(modeElement);
        ArrayList<StringTrainSpace[]> perFileSpaces = new ArrayList<StringTrainSpace[]>();
        ExecutorService pool = Executors.newFixedThreadPool(threads);
        System.out.println("Collecting training instances:");
        try {
            for (int fileIndex = 0; fileIndex < strArr.length; fileIndex++) {
                if (fileIndex == i2) {
                    continue;
                }
                StringTrainSpace[] spaces = getStringTrainSpaces(jointFtrXmlArr, objArr, str, i);
                perFileSpaces.add(spaces);
                pool.execute(new TrainTask(element, strArr[fileIndex], getComponent(jointFtrXmlArr, spaces, stringModelArr, objArr, str)));
            }
        } finally {
            // Always shut the pool down so its worker threads cannot outlive a failure
            // while submitting tasks (for example an unreadable training file).
            pool.shutdown();
        }
        try {
            pool.awaitTermination(Long.MAX_VALUE, TimeUnit.NANOSECONDS);
        } catch (InterruptedException e) {
            e.printStackTrace();
            // Preserve the interrupt for callers further up the stack.
            Thread.currentThread().interrupt();
        }
        System.out.println();

        if (perFileSpaces.isEmpty()) {
            throw new IndexOutOfBoundsException("No training files were processed for mode '" + str + "'.");
        }

        StringTrainSpace[] first = perFileSpaces.get(0);
        int fileCount = perFileSpaces.size();
        StringTrainSpace[] merged = new StringTrainSpace[first.length];
        for (int model = 0; model < first.length; model++) {
            merged[model] = first[model];
            if (fileCount > 1) {
                System.out.println("Merging training instances:");
                for (int file = 1; file < fileCount; file++) {
                    StringTrainSpace other = perFileSpaces.get(file)[model];
                    merged[model].appendSpace(other);
                    other.clear();
                    System.out.print(CTLibEn.POS_PERIOD);
                }
                System.out.println();
            }
        }
        return merged;
    }

    /**
     * Builds a component for the given mode that collects training instances.
     *
     * <p>For the dep, srl, pos-back and dep-back modes the previous round's models are
     * passed to the component when they are not {@code null}; the other modes ignore them.
     *
     * @param jointFtrXmlArr the feature templates.
     * @param stringTrainSpaceArr the training spaces the component fills.
     * @param stringModelArr models of the previous round, or {@code null}.
     * @param objArr the lexica.
     * @param str the training mode.
     * @return a new component in training configuration.
     * @throws IllegalArgumentException if the mode is not supported.
     */
    protected AbstractStatisticalComponent getComponent(JointFtrXml[] jointFtrXmlArr, StringTrainSpace[] stringTrainSpaceArr, StringModel[] stringModelArr, Object[] objArr, String str) {
        boolean firstRound = stringModelArr == null;
        if (str.equals("pos")) {
            return new CPOSTagger(jointFtrXmlArr, stringTrainSpaceArr, objArr);
        }
        if (str.equals("dep")) {
            if (firstRound) {
                return new CDEPPassParser(jointFtrXmlArr, stringTrainSpaceArr, objArr);
            }
            return new CDEPPassParser(jointFtrXmlArr, stringTrainSpaceArr, stringModelArr, objArr);
        }
        if (str.equals(NLPLib.MODE_PRED)) {
            return new CPredIdentifier(jointFtrXmlArr, stringTrainSpaceArr, objArr);
        }
        if (str.equals("role")) {
            return new CRolesetClassifier(jointFtrXmlArr, stringTrainSpaceArr, objArr);
        }
        if (str.startsWith(NLPLib.MODE_SENSE)) {
            return new CSenseClassifier(jointFtrXmlArr, stringTrainSpaceArr, objArr, str.substring(str.lastIndexOf("_") + 1));
        }
        if (str.equals("srl")) {
            if (firstRound) {
                return new CSRLabeler(jointFtrXmlArr, stringTrainSpaceArr, objArr);
            }
            return new CSRLabeler(jointFtrXmlArr, stringTrainSpaceArr, stringModelArr, objArr);
        }
        if (str.equals(NLPLib.MODE_POS_BACK)) {
            if (firstRound) {
                return new CPOSBackTagger(jointFtrXmlArr, stringTrainSpaceArr, objArr, this.d_margin);
            }
            return new CPOSBackTagger(jointFtrXmlArr, stringTrainSpaceArr, stringModelArr, objArr, this.d_margin);
        }
        if (str.equals(NLPLib.MODE_DEP_BACK)) {
            if (firstRound) {
                return new CDEPBackParser(jointFtrXmlArr, stringTrainSpaceArr, objArr, this.d_margin, this.n_beams);
            }
            return new CDEPBackParser(jointFtrXmlArr, stringTrainSpaceArr, stringModelArr, objArr, this.d_margin, this.n_beams);
        }
        throw new IllegalArgumentException("The requested mode '" + str + "' is not supported.");
    }

    /**
     * Creates the empty training spaces for one file.
     *
     * <ul>
     *   <li>role and sense modes: one space per entry of the map stored at
     *       {@code objArr[1]}, all using the first template;</li>
     *   <li>srl mode: two spaces using the first template;</li>
     *   <li>dep-back mode after the first round: one space per template, using
     *       cutoff index 1;</li>
     *   <li>otherwise: one space per template, using cutoff index 0.</li>
     * </ul>
     *
     * @param jointFtrXmlArr the feature templates.
     * @param objArr the lexica.
     * @param str the training mode.
     * @param i the bootstrapping round.
     * @return the new training spaces.
     */
    protected StringTrainSpace[] getStringTrainSpaces(JointFtrXml[] jointFtrXmlArr, Object[] objArr, String str, int i) {
        if (str.equals("role") || str.startsWith(NLPLib.MODE_SENSE)) {
            return getStringTrainSpaces(jointFtrXmlArr[0], ((ObjectIntOpenHashMap<?>) objArr[1]).size());
        }
        if (str.equals("srl")) {
            return getStringTrainSpaces(jointFtrXmlArr[0], 2);
        }
        if (i > 0 && str.equals(NLPLib.MODE_DEP_BACK)) {
            return getStringTrainSpaces(jointFtrXmlArr, 1);
        }
        return getStringTrainSpaces(jointFtrXmlArr);
    }

    /** Creates one training space per template using cutoff index 0. */
    private StringTrainSpace[] getStringTrainSpaces(JointFtrXml[] jointFtrXmlArr) {
        return getStringTrainSpaces(jointFtrXmlArr, 0);
    }

    /**
     * Creates one training space per template.
     *
     * @param jointFtrXmlArr the feature templates.
     * @param i the index passed to each template's label and feature cutoff lookups.
     */
    private StringTrainSpace[] getStringTrainSpaces(JointFtrXml[] jointFtrXmlArr, int i) {
        StringTrainSpace[] spaces = new StringTrainSpace[jointFtrXmlArr.length];
        for (int k = 0; k < spaces.length; k++) {
            JointFtrXml template = jointFtrXmlArr[k];
            spaces[k] = new StringTrainSpace(false, template.getLabelCutoff(i), template.getFeatureCutoff(i));
        }
        return spaces;
    }

    /**
     * Creates {@code i} training spaces that all use the cutoffs at index 0 of a
     * single template.
     *
     * @param jointFtrXml the feature template.
     * @param i the number of spaces to create.
     */
    private StringTrainSpace[] getStringTrainSpaces(JointFtrXml jointFtrXml, int i) {
        StringTrainSpace[] spaces = new StringTrainSpace[i];
        for (int k = 0; k < i; k++) {
            spaces[k] = new StringTrainSpace(false, jointFtrXml.getLabelCutoff(0), jointFtrXml.getFeatureCutoff(0));
        }
        return spaces;
    }

    /**
     * Command-line entry point; training runs inside the constructor.
     *
     * @param strArr the command-line arguments.
     */
    public static void main(String[] strArr) {
        new NLPTrain(strArr);
    }
}
