package com.googlecode.clearnlp.nlp;

import com.googlecode.clearnlp.classification.model.StringModel;
import com.googlecode.clearnlp.classification.train.StringTrainSpace;
import com.googlecode.clearnlp.component.AbstractStatisticalComponent;
import com.googlecode.clearnlp.component.dep.CDEPBackParser;
import com.googlecode.clearnlp.component.dep.CDEPPassParser;
import com.googlecode.clearnlp.component.pos.CPOSBackTagger;
import com.googlecode.clearnlp.component.pos.CPOSTagger;
import com.googlecode.clearnlp.component.srl.CRolesetClassifier;
import com.googlecode.clearnlp.component.srl.CSRLabeler;
import com.googlecode.clearnlp.component.srl.CSenseClassifier;
import com.googlecode.clearnlp.constituent.CTLibEn;
import com.googlecode.clearnlp.dependency.DEPTree;
import com.googlecode.clearnlp.dependency.srl.SRLEval;
import com.googlecode.clearnlp.feature.xml.JointFtrXml;
import com.googlecode.clearnlp.reader.AbstractColumnReader;
import com.googlecode.clearnlp.reader.JointReader;
import com.googlecode.clearnlp.util.UTFile;
import com.googlecode.clearnlp.util.UTInput;
import com.googlecode.clearnlp.util.UTOutput;
import com.googlecode.clearnlp.util.UTXml;
import com.googlecode.clearnlp.util.pair.ObjectDoublePair;
import java.io.FileInputStream;
import java.io.PrintStream;
import java.util.Arrays;
import java.util.Random;
import org.kohsuke.args4j.Option;
import org.w3c.dom.Element;

/* loaded from: input_file:com/googlecode/clearnlp/nlp/NLPDevelop.class */
/**
 * Development-set driver: trains a statistical component on the training files and evaluates it
 * on the development files, repeating model updates for as long as the development score keeps
 * improving. The component is selected by the mode string passed to {@link #develop}.
 */
public class NLPDevelop extends NLPTrain {

    /** Directory whose files are decoded to score the component. */
    @Option(name = "-d", usage = "the directory containing development files (required)", required = true, metaVar = "<directory>")
    protected String s_devDir;

    /** Seed for every {@link Random} created during development; also part of some output suffixes. */
    @Option(name = "-r", usage = "the random seed", required = false, metaVar = "<integer>")
    protected int i_rand = 0;

    /** When set, {@link #decode} writes every processed tree to {@code <devFile>.<suffix>}. */
    @Option(name = "-g", usage = "if set, generate files", required = false, metaVar = "<boolean>")
    protected boolean b_generate = false;

    /** Creates an instance without parsing arguments or running development. */
    public NLPDevelop() {
    }

    /**
     * Parses {@code args} and immediately runs {@link #develop} with the parsed options.
     * Any exception is printed to standard error rather than propagated.
     *
     * @param args command-line arguments
     */
    public NLPDevelop(String[] args) {
        initArgs(args);
        try {
            develop(this.s_configFile, this.s_featureFiles.split(":"), this.s_trainDir, this.s_devDir, this.s_mode);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Trains the component selected by {@code mode} and evaluates it on the development files.
     *
     * <p>Mode {@code pos} uses the single development loop. Modes {@code dep}, {@code srl},
     * {@link NLPLib#MODE_POS_BACK} and {@link NLPLib#MODE_DEP_BACK} use the bootstrapping loop.
     * Modes {@link NLPLib#MODE_PRED} and {@code role}, and modes starting with
     * {@link NLPLib#MODE_SENSE}, obtain a component through {@code getTrainedComponent} and
     * decode the development files once.
     *
     * @param configFile   XML configuration file; its {@code reader} element configures the reader
     * @param featureFiles feature-template files
     * @param trainDir     directory of training files
     * @param devDir       directory of development files
     * @param mode         component mode
     * @throws IllegalArgumentException if {@code mode} is not one of the modes listed above
     * @throws Exception                propagated from reading, training, or decoding
     */
    public void develop(String configFile, String[] featureFiles, String trainDir, String devDir, String mode) throws Exception {
        Element eConfig;
        // Close the configuration stream once the DOM is built (it was previously never closed).
        // NOTE(review): assumes UTXml.getDocumentElement parses the stream eagerly — confirm.
        try (FileInputStream configIn = new FileInputStream(configFile)) {
            eConfig = UTXml.getDocumentElement(configIn);
        }

        JointFtrXml[] templates = getFeatureTemplates(featureFiles);
        String[] trainFiles = UTFile.getSortedFileListBySize(trainDir, ".*", true);
        String[] devFiles = UTFile.getSortedFileListBySize(devDir, ".*", true);
        JointReader reader = getJointReader(UTXml.getFirstElementByTagName(eConfig, "reader"));

        if (mode.equals("pos")) {
            developComponent(eConfig, reader, templates, trainFiles, devFiles, new CPOSTagger(templates, getLowerSimplifiedForms(reader, templates[0], trainFiles, -1)), mode, -1);
        } else if (mode.equals("dep")) {
            developComponentBoot(eConfig, reader, templates, trainFiles, devFiles, new CDEPPassParser(templates), mode, -1);
        } else if (mode.equals(NLPLib.MODE_PRED)) {
            decode(reader, getTrainedComponent(eConfig, templates, trainFiles, null, null, mode, 0, -1), devFiles, mode, mode);
        } else if (mode.equals("role")) {
            decode(reader, getTrainedComponent(eConfig, reader, templates, trainFiles, new CRolesetClassifier(templates), mode, -1), devFiles, mode, mode);
        } else if (mode.startsWith(NLPLib.MODE_SENSE)) {
            // The text after the last '_' in the mode is passed to the sense classifier.
            String senseKey = mode.substring(mode.lastIndexOf("_") + 1);
            decode(reader, getTrainedComponent(eConfig, reader, templates, trainFiles, new CSenseClassifier(templates, senseKey), mode, -1), devFiles, mode, mode);
        } else if (mode.equals("srl")) {
            developComponentBoot(eConfig, reader, templates, trainFiles, devFiles, new CSRLabeler(templates), mode, -1);
        } else if (mode.equals(NLPLib.MODE_POS_BACK)) {
            developComponentBoot(eConfig, reader, templates, trainFiles, devFiles, new CPOSBackTagger(templates, getLowerSimplifiedForms(reader, templates[0], trainFiles, -1)), mode, -1);
        } else if (mode.equals(NLPLib.MODE_DEP_BACK)) {
            developComponentBoot(eConfig, reader, templates, trainFiles, devFiles, new CDEPBackParser(templates), mode, -1);
        } else {
            // Previously an unknown mode silently did nothing.
            throw new IllegalArgumentException("Unknown development mode: " + mode);
        }
    }

    /**
     * Runs the single development loop. It updates every model and decodes the development files,
     * and repeats while the score keeps increasing.
     *
     * @param eConfig    configuration element; its child named {@code mode} is passed to {@code updateModel}
     * @param reader     reader used to decode the development files
     * @param templates  feature templates
     * @param trainFiles training files
     * @param devFiles   development files
     * @param lexica     lexica passed to the train spaces and the component (may be {@code null})
     * @param mode       component mode
     * @param devId      opaque index forwarded to {@code getStringTrainSpaces} (callers here pass -1)
     * @return the score of the last iteration that was followed by an improvement (the best
     *         score), or 0.0 if the first iteration did not improve on 0.0
     */
    protected double developComponent(Element eConfig, JointReader reader, JointFtrXml[] templates, String[] trainFiles, String[] devFiles, Object[] lexica, String mode, int devId) throws Exception {
        StringTrainSpace[] spaces = getStringTrainSpaces(eConfig, templates, trainFiles, null, lexica, mode, 0, devId);
        Element eTrain = UTXml.getFirstElementByTagName(eConfig, mode);
        int size = spaces.length;
        StringModel[] models = new StringModel[size];
        Random rand = new Random(this.i_rand);

        double prevScore;
        double currScore = 0.0d;
        int updateCount = 1;   // advances once per model update, not once per iteration
        int iter = 0;

        do {
            prevScore = currScore;
            for (int j = 0; j < size; j++) {
                updateModel(eTrain, spaces[j], rand, updateCount++, j);
                models[j] = (StringModel) spaces[j].getModel();
            }
            currScore = decode(reader, getComponent(templates, models, lexica, mode), devFiles, mode, Integer.toString(iter));
            iter++;
        } while (prevScore < currScore);

        return prevScore;
    }

    /**
     * Same as {@link #developComponent(Element, JointReader, JointFtrXml[], String[], String[], Object[], String, int)},
     * but derives the lexica from {@code component} first. The lexica are {@code null} when
     * {@code component} is {@code null}.
     */
    public double developComponent(Element eConfig, JointReader reader, JointFtrXml[] templates, String[] trainFiles, String[] devFiles, AbstractStatisticalComponent component, String mode, int devId) throws Exception {
        Object[] lexica = (component != null) ? getLexica(component, reader, templates, trainFiles, devId) : null;
        return developComponent(eConfig, reader, templates, trainFiles, devFiles, lexica, mode, devId);
    }

    /**
     * Runs the bootstrapping loop. Each round retrains all models (seeded with the previous
     * round's models) using the per-round development loop. It stops once a round's best score is
     * at least 0.01 below the previous round's.
     *
     * <p>NOTE(review): if successive rounds return exactly the same score, the loop condition
     * stays true and this never terminates. Consider an iteration cap.
     */
    public void developComponentBoot(Element eConfig, JointReader reader, JointFtrXml[] templates, String[] trainFiles, String[] devFiles, AbstractStatisticalComponent component, String mode, int devId) throws Exception {
        Object[] lexica = getLexica(component, reader, templates, trainFiles, devId);
        StringModel[] models = null;
        double prevScore;
        double currScore = 0.0d;
        int boot = 0;

        do {
            prevScore = currScore;
            ObjectDoublePair<StringModel[]> result = developComponent(eConfig, reader, templates, trainFiles, devFiles, lexica, models, mode, boot, devId);
            models = (StringModel[]) result.o;
            currScore = result.d;
            boot++;
        } while (-0.01d < currScore - prevScore);
    }

    /**
     * Bootstrapping loop that tunes models one at a time (see {@code developComponent2}). It
     * repeats rounds while the returned score strictly increases.
     */
    protected void developComponentBoot2(Element eConfig, JointReader reader, JointFtrXml[] templates, String[] trainFiles, String[] devFiles, AbstractStatisticalComponent component, String mode, int devId) throws Exception {
        Object[] lexica = getLexica(component, reader, templates, trainFiles, devId);
        StringModel[] models = null;
        double prevScore;
        double currScore = 0.0d;
        int boot = 0;

        do {
            prevScore = currScore;
            ObjectDoublePair<StringModel[]> result = developComponent2(eConfig, reader, templates, trainFiles, devFiles, lexica, models, mode, boot, devId);
            models = (StringModel[]) result.o;
            currScore = result.d;
            boot++;
        } while (prevScore < currScore);
    }

    /**
     * Runs one bootstrapping round. All models are updated together and the development files
     * decoded, repeatedly, until the score stops increasing. Each model's weights are then rolled
     * back to the last snapshot taken before the final (non-improving) update.
     *
     * @return the rolled-back models paired with the best score (0.0 if the first iteration did
     *         not improve on 0.0)
     */
    private ObjectDoublePair<StringModel[]> developComponent(Element eConfig, JointReader reader, JointFtrXml[] templates, String[] trainFiles, String[] devFiles, Object[] lexica, StringModel[] prevModels, String mode, int boot, int devId) throws Exception {
        StringTrainSpace[] spaces = getStringTrainSpaces(eConfig, templates, trainFiles, prevModels, lexica, mode, boot, devId);
        Element eTrain = UTXml.getFirstElementByTagName(eConfig, mode);
        int size = spaces.length;
        Random[] rands = new Random[size];
        StringModel[] models = new StringModel[size];
        double[][] backups = new double[size][];   // weights before the most recent update

        for (int j = 0; j < size; j++) {
            rands[j] = new Random(this.i_rand);
        }

        double prevScore;
        double currScore = 0.0d;
        int iter = 1;

        do {
            prevScore = currScore;
            for (int j = 0; j < size; j++) {
                if (models[j] != null) {
                    double[] weights = models[j].getWeights();
                    backups[j] = Arrays.copyOf(weights, weights.length);
                }
                updateModel(eTrain, spaces[j], rands[j], iter, j);
                models[j] = (StringModel) spaces[j].getModel();
            }
            currScore = decode(reader, getComponent(templates, models, lexica, mode), devFiles, mode, boot + CTLibEn.POS_PERIOD + iter + CTLibEn.POS_PERIOD + this.i_rand);
            iter++;
        } while (prevScore < currScore);

        // Roll back the final update. If the loop stopped after its first iteration no snapshot
        // exists, and setting null weights would corrupt the model, so keep the current weights.
        for (int j = 0; j < size; j++) {
            if (backups[j] != null) {
                models[j].setWeights(backups[j]);
            }
        }

        return new ObjectDoublePair<>(models, prevScore);
    }

    /**
     * Runs one bootstrapping round that tunes the models sequentially. Model {@code j} is updated
     * and decoded, together with the already-tuned models {@code 0..j-1}, until the score stops
     * increasing. Its last non-improving update is then rolled back.
     *
     * @return the tuned models paired with the best score of the last model tuned, or
     *         {@code prevModels} paired with -1.0 when there are no train spaces
     */
    private ObjectDoublePair<StringModel[]> developComponent2(Element eConfig, JointReader reader, JointFtrXml[] templates, String[] trainFiles, String[] devFiles, Object[] lexica, StringModel[] prevModels, String mode, int boot, int devId) throws Exception {
        StringTrainSpace[] spaces = getStringTrainSpaces(eConfig, templates, trainFiles, prevModels, lexica, mode, boot, devId);
        Element eTrain = UTXml.getFirstElementByTagName(eConfig, mode);
        int size = spaces.length;
        StringModel[] models = prevModels;
        double bestScore = -1.0d;

        for (int j = 0; j < size; j++) {
            // Grow the array by one slot, keeping models 0..j-1 (an element-wise loop, so a null
            // prevModels is fine when j == 0).
            StringModel[] tuned = models;
            models = new StringModel[j + 1];
            for (int k = 0; k < j; k++) {
                models[k] = tuned[k];
            }

            Random rand = new Random(this.i_rand);
            double[] backup = null;
            double currScore = 0.0d;
            int iter = 1;

            do {
                bestScore = currScore;
                if (models[j] != null) {
                    double[] weights = models[j].getWeights();
                    backup = Arrays.copyOf(weights, weights.length);
                }
                updateModel(eTrain, spaces[j], rand, iter, j);
                models[j] = (StringModel) spaces[j].getModel();
                currScore = decode(reader, getComponent(templates, models, lexica, mode), devFiles, mode, Integer.toString((100 * boot) + iter));
                iter++;
            } while (bestScore < currScore);

            // Without a snapshot (stopped after the first update) keep the current weights
            // instead of setting them to null.
            if (backup != null) {
                models[j].setWeights(backup);
            }
        }

        return new ObjectDoublePair<>(models, bestScore);
    }

    /**
     * Runs {@code component} over every tree in the development files, accumulates accuracy
     * counts, and returns the resulting score. If {@link #b_generate} is set, each processed tree
     * is also written to {@code <devFile>.<outputSuffix>}.
     *
     * @param reader       reader re-opened for each development file
     * @param component    component to apply to each tree
     * @param devFiles     development files
     * @param mode         mode selecting the counts layout and the score formula
     * @param outputSuffix suffix for generated output files
     * @return the score computed by {@link #getScore}
     */
    protected double decode(JointReader reader, AbstractStatisticalComponent component, String[] devFiles, String mode, String outputSuffix) throws Exception {
        int[] counts = getCounts(mode);

        for (String devFile : devFiles) {
            PrintStream fout = this.b_generate ? UTOutput.createPrintBufferedFileStream(devFile + CTLibEn.POS_PERIOD + outputSuffix) : null;
            try {
                reader.open(UTInput.createBufferedFileReader(devFile));
                try {
                    DEPTree tree;
                    while ((tree = reader.next()) != null) {
                        component.process(tree);
                        component.countAccuracy(counts);
                        if (fout != null) {
                            fout.println(toString(tree, mode) + AbstractColumnReader.DELIM_SENTENCE);
                        }
                    }
                } finally {
                    // Close even if processing throws (previously leaked on exceptions).
                    reader.close();
                }
            } finally {
                if (fout != null) {
                    fout.close();
                }
            }
        }

        return getScore(mode, counts);
    }

    /**
     * Allocates the accuracy-count array whose length matches what {@link #getScore} reads for
     * {@code mode}.
     *
     * @return a zeroed array, or {@code null} for an unrecognized mode
     */
    protected int[] getCounts(String mode) {
        if (mode.startsWith("pos") || mode.equals("role") || mode.startsWith(NLPLib.MODE_SENSE)) {
            return new int[2];
        }
        if (mode.equals("dep")) {
            return new int[4];
        }
        if (mode.equals(NLPLib.MODE_PRED) || mode.equals("srl")) {
            return new int[3];
        }
        if (mode.equals(NLPLib.MODE_DEP_BACK)) {
            return new int[5];
        }
        return null;
    }

    /**
     * Prints the evaluation for {@code mode} to standard output and returns its headline score.
     * <ul>
     *   <li>{@code pos*}, {@code role}, sense modes: accuracy = 100 * counts[1] / counts[0].</li>
     *   <li>{@code dep}: 100 * counts[1] / counts[0] (the first count printed after the total).</li>
     *   <li>pred / {@code srl}: F1 from P = 100 * counts[0] / counts[1] and
     *       R = 100 * counts[0] / counts[2].</li>
     *   <li>dep-back: 100 * counts[2] / counts[0].</li>
     * </ul>
     * A zero denominator yields 0.0 rather than NaN/Infinity.
     *
     * @return the score, or 0.0 for an unrecognized mode
     */
    protected double getScore(String mode, int[] counts) {
        double score = 0.0d;

        if (mode.startsWith("pos") || mode.equals("role") || mode.startsWith(NLPLib.MODE_SENSE)) {
            score = percent(counts[1], counts[0]);
            System.out.printf("- ACC: %5.2f (%d/%d)\n", Double.valueOf(score), Integer.valueOf(counts[1]), Integer.valueOf(counts[0]));
        } else if (mode.equals("dep")) {
            printScores(new String[]{"T", SRLEval.LAS, SRLEval.UAS, CTLibEn.POS_LS}, counts);
            score = percent(counts[1], counts[0]);
        } else if (mode.equals(NLPLib.MODE_PRED) || mode.equals("srl")) {
            double precision = percent(counts[0], counts[1]);
            double recall = percent(counts[0], counts[2]);
            score = SRLEval.getF1(precision, recall);
            System.out.printf("P: %5.2f ", Double.valueOf(precision));
            System.out.printf("R: %5.2f ", Double.valueOf(recall));
            System.out.printf("F1: %5.2f\n", Double.valueOf(score));
        } else if (mode.equals(NLPLib.MODE_DEP_BACK)) {
            printScores(new String[]{"T", CTLibEn.POS_POS, SRLEval.LAS, SRLEval.UAS, CTLibEn.POS_LS}, counts);
            score = percent(counts[2], counts[0]);
        }

        return score;
    }

    /** Prints {@code labels[i]: 100 * counts[i] / counts[0]} for every {@code i >= 1}. */
    private void printScores(String[] labels, int[] counts) {
        int total = counts[0];
        for (int i = 1; i < counts.length; i++) {
            System.out.printf("%3s: %5.2f (%d/%d)\n", labels[i], Double.valueOf(percent(counts[i], total)), Integer.valueOf(counts[i]), Integer.valueOf(total));
        }
    }

    /** Returns {@code 100 * numerator / denominator}, or 0.0 when {@code denominator} is zero. */
    private static double percent(int numerator, int denominator) {
        return (denominator == 0) ? 0.0d : (100.0d * numerator) / denominator;
    }

    /** Command-line entry point; see {@link #NLPDevelop(String[])}. */
    public static void main(String[] args) {
        new NLPDevelop(args);
    }
}
