package com.googlecode.clearnlp.run;

import com.carrotsearch.hppc.cursors.ObjectCursor;
import com.googlecode.clearnlp.classification.model.StringModel;
import com.googlecode.clearnlp.classification.train.StringTrainSpace;
import com.googlecode.clearnlp.constituent.CTLibEn;
import com.googlecode.clearnlp.engine.EngineProcess;
import com.googlecode.clearnlp.engine.EngineSetter;
import com.googlecode.clearnlp.feature.xml.POSFtrXml;
import com.googlecode.clearnlp.pos.POSLib;
import com.googlecode.clearnlp.pos.POSNode;
import com.googlecode.clearnlp.pos.POSTagger;
import com.googlecode.clearnlp.reader.POSReader;
import com.googlecode.clearnlp.util.UTFile;
import com.googlecode.clearnlp.util.UTInput;
import com.googlecode.clearnlp.util.UTXml;
import com.googlecode.clearnlp.util.list.SortedDoubleArrayList;
import com.googlecode.clearnlp.util.map.Prob1DMap;
import com.googlecode.clearnlp.util.pair.Pair;
import java.io.FileInputStream;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Set;
import org.kohsuke.args4j.Option;
import org.w3c.dom.Element;

/* loaded from: input_file:com/googlecode/clearnlp/run/POSTrain.class */
/**
 * Command-line driver that trains part-of-speech tagger model(s) from a directory of
 * training files and passes the trained taggers to {@link EngineSetter#setPOSTaggers}.
 *
 * <p>Depending on the {@code -s} flag, either a single model (index 0 or 1) is trained, or
 * both models are trained for dynamic model selection. In the latter case, a negative
 * similarity threshold triggers a leave-one-file-out cross validation to estimate it.
 */
public class POSTrain extends AbstractRun {
    /** Number of models trained when dynamic model selection is requested (model indices 0 and 1). */
    public static final int MODEL_SIZE = 2;

    @Option(name = "-i", usage = "directory containg training files (required)", required = true, metaVar = "<directory>")
    protected String s_trainDir;

    @Option(name = "-c", usage = "configuration file (required)", required = true, metaVar = "<filename>")
    protected String s_configXml;

    @Option(name = "-f", usage = "feature template file (required)", required = true, metaVar = "<filename>")
    protected String s_featureXml;

    @Option(name = "-m", usage = "model file (output; required)", required = true, metaVar = "<filename>")
    protected String s_modelFile;
    /** Value of {@code -s} selecting a domain-specific model only (see the {@code -s} usage text). */
    protected final int FLAG_DOMAIN = 0;
    /** Value of {@code -s} selecting a generalized model only (see the {@code -s} usage text). */
    protected final int FLAG_GENERAL = 1;
    /** Value of {@code -s} selecting both models with dynamic model selection. */
    protected final int FLAG_DYNAMIC = 2;

    @Option(name = "-t", usage = "similarity threshold (default: -1)", required = false, metaVar = "<double>")
    protected double d_threshold = -1.0d;

    @Option(name = "-s", usage = "model type - 0|1|2 (default: 1)\n0: train only a domain-specific model\n1: train only a generalized model\n2: train both models using dynamic model selection", required = false, metaVar = "<integer>")
    protected int i_flag = 1;

    /** Creates an instance without parsing arguments or starting training. */
    public POSTrain() {
    }

    /**
     * Parses the command-line arguments and immediately runs training.
     * Any exception raised by training is printed to standard error and not rethrown.
     *
     * @param strArr the command-line arguments
     */
    public POSTrain(String[] strArr) {
        initArgs(strArr);
        try {
            run(this.s_configXml, this.s_featureXml, this.s_trainDir, this.s_modelFile, this.d_threshold, this.i_flag);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Trains the tagger(s) and hands them to {@link EngineSetter#setPOSTaggers}.
     *
     * @param str  path of the configuration XML file
     * @param str2 path of the feature template XML file
     * @param str3 directory holding the training files
     * @param str4 path of the model file (output)
     * @param d    similarity threshold; when negative and {@code i} is the dynamic flag,
     *             it is estimated by {@link #crossValidate}
     * @param i    model type flag (0, 1, or 2 for dynamic model selection)
     * @throws Exception if reading the inputs or training fails
     */
    public void run(String str, String str2, String str3, String str4, double d, int i) throws Exception {
        // Both XML inputs are consumed inside their try blocks, so the streams can be
        // released right away instead of leaking until garbage collection.
        // NOTE(review): assumes UTXml/POSFtrXml parse eagerly and do not retain the stream - confirm.
        Element config;
        try (FileInputStream configIn = new FileInputStream(str)) {
            config = UTXml.getDocumentElement(configIn);
        }
        POSReader reader = (POSReader) getReader(config).o1;
        POSFtrXml featureXml;
        try (FileInputStream featureIn = new FileInputStream(str2)) {
            featureXml = new POSFtrXml(featureIn);
        }
        String[] trainFiles = UTFile.getSortedFileList(str3);

        double threshold = d;
        POSTagger[] taggers;
        if (i == FLAG_DYNAMIC) {
            if (threshold < 0.0d) {
                threshold = crossValidate(trainFiles, reader, featureXml, config);
            }
            // -1 holds out no file: the final models are trained on every training file.
            taggers = getTrainedTaggers(config, reader, featureXml, trainFiles, -1);
        } else {
            taggers = new POSTagger[] {getTrainedTagger(config, reader, featureXml, trainFiles, -1, i)};
            taggers[0].clearFormSet();
        }
        // The last argument equals the number of taggers in both branches (2 or 1).
        EngineSetter.setPOSTaggers(str4, str2, taggers, threshold, taggers.length);
    }

    /**
     * Trains one tagger: collects the lemma set, the lexica, and the training instances
     * (three passes over the training files), then builds the model.
     *
     * @param element   root element of the configuration XML
     * @param pOSReader reader used to iterate over each training file
     * @param pOSFtrXml feature template
     * @param strArr    training file paths
     * @param i         index of the file to hold out, or -1 to use every file
     * @param i2        model index (also used to look up cutoffs in the feature template)
     * @return the trained tagger
     * @throws Exception if training fails
     */
    public POSTagger getTrainedTagger(Element element, POSReader pOSReader, POSFtrXml pOSFtrXml, String[] strArr, int i, int i2) throws Exception {
        Set<String> lemmaSet = getLemmaSet(pOSReader, pOSFtrXml, i2, strArr, i);
        Pair<Set<String>, Map<String, String>> lexica = getLexica(pOSReader, pOSFtrXml, i2, lemmaSet, strArr, i);
        StringTrainSpace space = getTrainSpace(pOSReader, pOSFtrXml, i2, lemmaSet, lexica.o1, lexica.o2, strArr, i);
        Element trainElement = UTXml.getFirstElementByTagName(element, "train");
        StringModel model = (StringModel) getModel(trainElement, space, i2);
        return new POSTagger(pOSFtrXml, lemmaSet, lexica.o1, lexica.o2, model);
    }

    /**
     * Trains {@link #MODEL_SIZE} taggers, one per model index starting at 0.
     *
     * @param element   root element of the configuration XML
     * @param pOSReader reader used to iterate over each training file
     * @param pOSFtrXml feature template
     * @param strArr    training file paths
     * @param i         index of the file to hold out, or -1 to use every file
     * @return the trained taggers, indexed by model index
     * @throws Exception if training fails
     */
    public POSTagger[] getTrainedTaggers(Element element, POSReader pOSReader, POSFtrXml pOSFtrXml, String[] strArr, int i) throws Exception {
        POSTagger[] taggers = new POSTagger[MODEL_SIZE];
        for (int modelId = 0; modelId < MODEL_SIZE; modelId++) {
            System.out.printf("===== Training model %d =====\n", modelId);
            taggers[modelId] = getTrainedTagger(element, pOSReader, pOSFtrXml, strArr, i, modelId);
        }
        return taggers;
    }

    /**
     * Collects the lemmas whose count exceeds the document-frequency cutoff of the given model.
     * Each file contributes each of its distinct lemmas once, so the count stored per lemma
     * is presumably the number of files containing it (depends on {@code Prob1DMap.add}).
     */
    private Set<String> getLemmaSet(POSReader reader, POSFtrXml featureXml, int modelId, String[] trainFiles, int heldOut) throws Exception {
        int cutoff = featureXml.getDocumentFrequency(modelId);
        Prob1DMap counts = new Prob1DMap();
        System.out.println("Collecting n-gram set:");
        System.out.println("- document frequency cutoff: " + cutoff);
        for (int f = 0; f < trainFiles.length; f++) {
            if (f == heldOut) {
                continue;
            }
            Set<String> fileLemmas = new HashSet<>();
            reader.open(UTInput.createBufferedFileReader(trainFiles[f]));
            try {
                POSNode[] nodes;
                while ((nodes = reader.next()) != null) {
                    EngineProcess.normalizeForms(nodes);
                    for (POSNode node : nodes) {
                        fileLemmas.add(node.lemma);
                    }
                }
            } finally {
                // Release the reader even if reading or normalization fails part-way.
                reader.close();
            }
            for (String lemma : fileLemmas) {
                counts.add(lemma);
            }
        }
        Set<String> kept = new HashSet<>();
        Iterator<?> cursors = counts.keys().iterator();
        while (cursors.hasNext()) {
            String lemma = (String) ((ObjectCursor<?>) cursors.next()).value;
            if (counts.get(lemma) > cutoff) {
                kept.add(lemma);
            }
        }
        System.out.printf("- lemma reduction: %d -> %d\n", counts.size(), kept.size());
        return kept;
    }

    /**
     * Runs a lexicon-collecting tagger over the training files and returns its form set
     * (filtered by the feature cutoff) and ambiguity map (filtered by the ambiguity threshold).
     */
    private Pair<Set<String>, Map<String, String>> getLexica(POSReader reader, POSFtrXml featureXml, int modelId, Set<String> lemmaSet, String[] trainFiles, int heldOut) {
        POSTagger collector = new POSTagger(lemmaSet);
        int featureCutoff = featureXml.getFeatureCutoff(modelId);
        double ambiguityThreshold = featureXml.getAmbiguityThreshold(modelId);
        System.out.println("Collecting lexica:");
        System.out.println("- lexica cutoff: " + featureCutoff);
        System.out.println("- ambiguity class threshold: " + ambiguityThreshold);
        for (int f = 0; f < trainFiles.length; f++) {
            if (f != heldOut) {
                tagFile(reader, collector, trainFiles[f]);
            }
        }
        Set<String> formSet = collector.getFormSet(featureCutoff);
        Map<String, String> ambiguityMap = collector.getAmbiguityMap(ambiguityThreshold);
        System.out.println("- # of word-forms: " + formSet.size());
        System.out.println("- # of word-forms with ambiguity classes: " + ambiguityMap.size());
        return new Pair<>(formSet, ambiguityMap);
    }

    /**
     * Runs a tagger backed by a fresh {@link StringTrainSpace} over the training files and
     * returns that train space. One progress marker is printed per processed file.
     */
    private StringTrainSpace getTrainSpace(POSReader reader, POSFtrXml featureXml, int modelId, Set<String> lemmaSet, Set<String> formSet, Map<String, String> ambiguityMap, String[] trainFiles, int heldOut) {
        StringTrainSpace space = new StringTrainSpace(false, featureXml.getLabelCutoff(modelId), featureXml.getFeatureCutoff(modelId));
        POSTagger collector = new POSTagger(featureXml, lemmaSet, formSet, ambiguityMap, space);
        System.out.println("Collecting training instances:");
        for (int f = 0; f < trainFiles.length; f++) {
            if (f != heldOut) {
                tagFile(reader, collector, trainFiles[f]);
                System.out.print(CTLibEn.POS_PERIOD);
            }
        }
        System.out.println();
        return space;
    }

    /** Opens the file with the reader, passes every sentence to the tagger, and always closes the reader. */
    private void tagFile(POSReader reader, POSTagger tagger, String filename) {
        reader.open(UTInput.createBufferedFileReader(filename));
        try {
            POSNode[] nodes;
            while ((nodes = reader.next()) != null) {
                tagger.tag(nodes);
            }
        } finally {
            reader.close();
        }
    }

    /**
     * Estimates the similarity threshold by leave-one-file-out cross validation: for each
     * file, taggers are trained on the remaining files and evaluated on the held-out one.
     * The threshold is the collected similarity at the 5% position of the sorted list,
     * rounded up to three decimal places.
     *
     * @param strArr    training file paths
     * @param pOSReader reader used to iterate over each file
     * @param pOSFtrXml feature template
     * @param element   root element of the configuration XML
     * @return the estimated threshold
     * @throws Exception if training fails
     */
    public double crossValidate(String[] strArr, POSReader pOSReader, POSFtrXml pOSFtrXml, Element element) throws Exception {
        SortedDoubleArrayList similarities = new SortedDoubleArrayList();
        for (int heldOut = 0; heldOut < strArr.length; heldOut++) {
            System.out.printf("<== Cross validation %d ==>\n", heldOut);
            POSTagger[] taggers = getTrainedTaggers(element, pOSReader, pOSFtrXml, strArr, heldOut);
            crossValidatePredict(strArr[heldOut], pOSReader, taggers, similarities);
        }
        // NOTE(review): if no similarity was collected the list is empty and get(0) is out of
        // range; what SortedDoubleArrayList does in that case is not visible here - confirm.
        int index = (int) Math.round(similarities.size() * 0.05d);
        double threshold = Math.ceil(similarities.get(index) * 1000.0d) / 1000.0d;
        System.out.println("Out-of-domain validation:");
        System.out.println("- threshold: " + threshold);
        return threshold;
    }

    /**
     * Tags every sentence of the file with both taggers, prints each tagger's accuracy, and
     * records the cosine similarity (from tagger 0) of each sentence on which tagger 0 got
     * strictly more tags right than tagger 1, provided that similarity is positive.
     */
    private void crossValidatePredict(String filename, POSReader reader, POSTagger[] taggers, SortedDoubleArrayList similarities) {
        int[] sentenceCorrect = new int[MODEL_SIZE];
        int[] totalCorrect = new int[MODEL_SIZE];
        int totalTokens = 0;
        System.out.println("Predicting: " + filename);
        reader.open(UTInput.createBufferedFileReader(filename));
        try {
            POSNode[] nodes;
            while ((nodes = reader.next()) != null) {
                // Capture the reference labels before the taggers run over the nodes.
                String[] gold = POSLib.getLabels(nodes);
                totalTokens += gold.length;
                for (int m = 0; m < MODEL_SIZE; m++) {
                    taggers[m].tag(nodes);
                    sentenceCorrect[m] = countCorrect(nodes, gold);
                    totalCorrect[m] += sentenceCorrect[m];
                }
                if (sentenceCorrect[0] > sentenceCorrect[1]) {
                    double similarity = taggers[0].getCosineSimilarity(nodes);
                    if (similarity > 0.0d) {
                        similarities.add(similarity);
                    }
                }
            }
        } finally {
            // Release the reader even if tagging fails part-way through the file.
            reader.close();
        }
        for (int m = 0; m < MODEL_SIZE; m++) {
            System.out.printf("- accuracy %d: %7.5f (%d/%d)\n", m, (100.0d * totalCorrect[m]) / totalTokens, totalCorrect[m], totalTokens);
        }
    }

    /** Counts the positions at which the expected label equals the node's current {@code pos}. */
    private int countCorrect(POSNode[] nodes, String[] gold) {
        int correct = 0;
        for (int k = 0; k < nodes.length; k++) {
            if (gold[k].equals(nodes[k].pos)) {
                correct++;
            }
        }
        return correct;
    }

    /**
     * Command-line entry point; constructing the instance parses the arguments and runs training.
     *
     * @param strArr the command-line arguments
     */
    public static void main(String[] strArr) {
        new POSTrain(strArr);
    }
}
