package edu.stanford.nlp.ie.ner;

import edu.stanford.nlp.classify.Dataset;
import edu.stanford.nlp.classify.GeneralDataset;
import edu.stanford.nlp.classify.LinearClassifier;
import edu.stanford.nlp.classify.LinearClassifierFactory;
import edu.stanford.nlp.classify.LogPrior;
import edu.stanford.nlp.classify.NBLinearClassifierFactory;
import edu.stanford.nlp.classify.ProbabilisticClassifier;
import edu.stanford.nlp.classify.SVMLightClassifierFactory;
import edu.stanford.nlp.ie.AbstractSequenceClassifier;
import edu.stanford.nlp.international.morph.MorphoFeatures;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.ling.BasicDatum;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.Datum;
import edu.stanford.nlp.ling.Document;
import edu.stanford.nlp.ling.HasTag;
import edu.stanford.nlp.ling.HasWord;
import edu.stanford.nlp.ling.WordTag;
import edu.stanford.nlp.math.ArrayMath;
import edu.stanford.nlp.math.SloppyMath;
import edu.stanford.nlp.objectbank.ObjectBank;
import edu.stanford.nlp.process.DocumentProcessor;
import edu.stanford.nlp.process.ListProcessor;
import edu.stanford.nlp.process.PTBLexer;
import edu.stanford.nlp.sequences.BeamBestSequenceFinder;
import edu.stanford.nlp.sequences.Clique;
import edu.stanford.nlp.sequences.DocumentReaderAndWriter;
import edu.stanford.nlp.sequences.ExactBestSequenceFinder;
import edu.stanford.nlp.sequences.FeatureFactory;
import edu.stanford.nlp.sequences.PlainTextDocumentReaderAndWriter;
import edu.stanford.nlp.sequences.SeqClassifierFlags;
import edu.stanford.nlp.sequences.SequenceModel;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.util.CoreMap;
import edu.stanford.nlp.util.HashIndex;
import edu.stanford.nlp.util.Index;
import edu.stanford.nlp.util.PaddedList;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.StringUtils;
import java.io.BufferedInputStream;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Properties;
import java.util.Set;
import java.util.regex.Pattern;

/* loaded from: input_file:edu/stanford/nlp/ie/ner/CMMClassifier.class */
public class CMMClassifier<IN extends CoreLabel> extends AbstractSequenceClassifier<IN> implements DocumentProcessor, ListProcessor<Object, WordTag> {
    /** Per-token classifier used to score and pick labels for the datums built by this class. */
    private ProbabilisticClassifier<String, String> classifier;
    /** Handed to each {@link Scorer} as its set of legal tag histories (see getSequenceModel). */
    Set<List<String>> answerArrays;
    /** Path of the default serialized classifier; not referenced in the visible part of this file. */
    public static final String DEFAULT_CLASSIFIER = "/classifiers/ner-eng-ie.cmm-3-all2006.ser.gz";
    /** Position most recently scored through {@code Scorer.scoreOf}; static, so shared by all Scorers. */
    static int lastPos = -1;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:edu/stanford/nlp/ie/ner/CMMClassifier$Scorer.class */
    /**
     * Sequence model that exposes the per-position label scores of a {@link CMMClassifier}
     * to a best-sequence finder.
     *
     * <p>The model covers {@code lineInfos} minus {@code pre} leading and {@code post} trailing
     * positions. Scoring a position writes the hypothesised labels of the surrounding window
     * into the tokens' {@code AnswerAnnotation} before asking the classifier for scores.
     */
    public static class Scorer implements SequenceModel {
        private CMMClassifier classifier;
        private int[] backgroundTags;
        private Index<String> tagIndex;
        private List<? extends CoreLabel> lineInfos;
        private int pre;
        private int post;
        private Set<List<String>> legalTags;
        private static final boolean VERBOSE = false;
        /** Lazily built identity array {0, 1, ..., tagIndex.size() - 1}. */
        private int[] tagArray = null;
        /** Scores computed by the most recent cache miss in {@link #scoreOf(int[], int)}. */
        private double[] scoreCache = null;
        /** Tags of the window that {@link #scoreCache} was computed for (-1 = not recorded). */
        private int[] lastWindow = null;
        /** Position that {@link #scoreCache} was computed for; -1 while the cache is empty. */
        private int cachedPos = -1;
        private int percent = -1;
        private int num = 0;
        private long secs = System.currentTimeMillis();
        private long hit = 0;
        private long tot = 0;

        /** Fills {@link #tagArray} with every index of {@link #tagIndex}. */
        void buildTagArray() {
            int tagCount = this.tagIndex.size();
            this.tagArray = new int[tagCount];
            for (int t = 0; t < tagCount; t++) {
                this.tagArray[t] = t;
            }
        }

        /** @return number of positions to label: all tokens minus the left and right windows */
        @Override // edu.stanford.nlp.sequences.SequenceModel
        public int length() {
            return (this.lineInfos.size() - this.pre) - this.post;
        }

        @Override // edu.stanford.nlp.sequences.SequenceModel
        public int leftWindow() {
            return this.pre;
        }

        @Override // edu.stanford.nlp.sequences.SequenceModel
        public int rightWindow() {
            return this.post;
        }

        /**
         * @return only the background tag for positions inside the left window,
         *         otherwise every tag index
         */
        @Override // edu.stanford.nlp.sequences.SequenceModel
        public int[] getPossibleValues(int i) {
            if (this.tagArray == null) {
                buildTagArray();
            }
            return i < this.pre ? this.backgroundTags : this.tagArray;
        }

        /** Whole-sequence scoring is not supported by this model. */
        @Override // edu.stanford.nlp.sequences.SequenceModel
        public double scoreOf(int[] iArr) {
            throw new UnsupportedOperationException();
        }

        /**
         * Returns the score of the tag {@code iArr[i]} at position {@code i}, given the tags of
         * the surrounding window.
         *
         * <p>The score vector of the last computed (position, window) pair is cached. The cache
         * is validated against this instance's own {@link #cachedPos}: the static
         * {@code CMMClassifier.lastPos} is shared by every Scorer, so relying on it could report
         * a hit on an instance whose cache is still empty (NullPointerException) or was filled
         * for a different position (stale scores). {@code lastPos} is still updated for any
         * code that reads it.
         */
        @Override // edu.stanford.nlp.sequences.SequenceModel
        public double scoreOf(int[] iArr, int i) {
            final int left = leftWindow();
            final int right = rightWindow();
            if (this.lastWindow == null) {
                this.lastWindow = new int[left + right + 1];
                Arrays.fill(this.lastWindow, -1);
            }
            boolean cacheValid = this.scoreCache != null && i == this.cachedPos;
            for (int p = i - left; p <= i + right; p++) {
                // The tag at i itself is what we look up, so it is not part of the cache key.
                if (p != i && p >= 0) {
                    cacheValid &= iArr[p] == this.lastWindow[(p - i) + left];
                }
            }
            if (!cacheValid) {
                this.scoreCache = scoresOf(iArr, i);
                for (int p = i - left; p <= i + right; p++) {
                    if (p >= 0) {
                        this.lastWindow[(p - i) + left] = iArr[p];
                    }
                }
                this.cachedPos = i;
                CMMClassifier.lastPos = i;
            }
            return this.scoreCache[iArr[i]];
        }

        /**
         * Computes the scores of all tags at position {@code i}.
         *
         * <p>Side effect: the tags of the window around {@code i} are written into the
         * {@code AnswerAnnotation} of the corresponding tokens. When
         * {@code useObservedSequencesOnly} is set and the left-window tag history is not in
         * {@link #legalTags}, every score is -1000.
         *
         * @return one score per entry of {@link #tagIndex}; log-normalized when the
         *         classifier's {@code normalize()} returns true
         */
        @Override // edu.stanford.nlp.sequences.SequenceModel
        public double[] scoresOf(int[] iArr, int i) {
            final int left = leftWindow();
            final int windowSize = 1 + left + rightWindow();
            String[] windowTags = new String[windowSize];
            String[] history = new String[left];
            for (int w = 0; w < windowSize; w++) {
                int absPos = (i - left) + w;
                if (absPos < 0) {
                    continue;
                }
                windowTags[w] = this.tagIndex.get(iArr[absPos]);
                this.lineInfos.get(absPos).set(CoreAnnotations.AnswerAnnotation.class, windowTags[w]);
                if (w < left) {
                    history[w] = windowTags[w];
                }
            }
            double[] scores = new double[this.tagIndex.size()];
            // Test the flag first: it is cheap, and the set lookup is only needed when it is on.
            if (this.classifier.flags.useObservedSequencesOnly && !this.legalTags.contains(Arrays.asList(history))) {
                Arrays.fill(scores, -1000.0d);
                return scores;
            }
            this.num++;
            this.hit++;
            Counter<String> labelScores = this.classifier.scoresOf(this.lineInfos, i);
            int[] candidates = getPossibleValues(i);
            for (String label : labelScores.keySet()) {
                int labelId = this.tagIndex.indexOf(label);
                if (labelId > -1) {
                    for (int c = 0; c < candidates.length; c++) {
                        if (candidates[c] == labelId) {
                            scores[c] = labelScores.getCount(label);
                        }
                    }
                }
            }
            if (this.classifier.normalize()) {
                ArrayMath.logNormalize(scores);
            }
            return scores;
        }

        /**
         * Returns a copy of {@code dArr} with {@code ArrayMath.logSum(dArr)} subtracted from
         * every entry; the input array is left untouched.
         */
        static double[] recenter(double[] dArr) {
            double[] shifted = new double[dArr.length];
            double logSum = ArrayMath.logSum(dArr);
            for (int k = 0; k < dArr.length; k++) {
                shifted[k] = dArr[k] - logSum;
            }
            return shifted;
        }

        /**
         * @param list          tokens to score
         * @param index         tag index shared with the classifier
         * @param cMMClassifier classifier that supplies the scores and flags
         * @param i             size of the left window
         * @param i2            size of the right window
         * @param set           tag histories accepted when useObservedSequencesOnly is set
         */
        Scorer(List<? extends CoreLabel> list, Index<String> index, CMMClassifier cMMClassifier, int i, int i2, Set<List<String>> set) {
            this.pre = i;
            this.post = i2;
            this.lineInfos = list;
            this.tagIndex = index;
            this.classifier = cMMClassifier;
            this.legalTags = set;
            this.backgroundTags = new int[]{index.indexOf(cMMClassifier.flags.backgroundSymbol)};
        }
    }

    /** Creates a classifier configured with a default-constructed {@link SeqClassifierFlags}. */
    protected CMMClassifier() {
        super(new SeqClassifierFlags());
    }

    /**
     * Creates a classifier, passing {@code properties} straight to the superclass constructor.
     *
     * @param properties configuration handed to {@link AbstractSequenceClassifier}
     */
    public CMMClassifier(Properties properties) {
        super(properties);
    }

    /**
     * Returns the labels known to this classifier, excluding the background symbol.
     *
     * @return a fresh, mutable set; changing it does not affect the class index
     */
    public Set<String> getTags() {
        // Parameterized instead of a raw HashSet so the Set<String> return is type-checked.
        Set<String> tags = new HashSet<>(this.classIndex.objectsList());
        tags.remove(this.flags.backgroundSymbol);
        return tags;
    }

    @Override // edu.stanford.nlp.ie.AbstractSequenceClassifier
    /**
     * Labels the tokens of {@code list} in place, using sequence inference when
     * {@code flags.useSequences} is set and independent per-token decisions otherwise.
     *
     * @param list tokens to label
     * @return the same list instance that was passed in
     */
    @Override // edu.stanford.nlp.ie.AbstractSequenceClassifier
    public List<IN> classify(List<IN> list) {
        if (!this.flags.useSequences) {
            classifyNoSeq(list);
            return list;
        }
        classifySeq(list);
        return list;
    }

    /**
     * Labels every token independently of its neighbours' labels.
     *
     * <p>With {@code flags.lowerNewgeneThreshold}, a token gets the label "G" only when its
     * gazette annotation is "NEWGENE" and the classifier's "G" score exceeds
     * {@code flags.newgeneThreshold}; every other token gets the background symbol.
     * Otherwise each token gets the classifier's best label, and with {@code flags.justify}
     * a justification is requested from a linear classifier. The list is reversed for the
     * duration of the call when {@code flags.useReverse} is set.
     */
    private void classifyNoSeq(List<IN> list) {
        if (this.flags.useReverse) {
            Collections.reverse(list);
        }
        if (!this.flags.lowerNewgeneThreshold) {
            for (int pos = 0, n = list.size(); pos < n; pos++) {
                list.get(pos).set(CoreAnnotations.AnswerAnnotation.class, classOf(list, pos));
            }
            if (this.flags.justify && (this.classifier instanceof LinearClassifier)) {
                LinearClassifier linear = (LinearClassifier) this.classifier;
                for (int pos = 0, n = list.size(); pos < n; pos++) {
                    IN token = list.get(pos);
                    System.err.print("@@ Position " + pos + ": ");
                    System.err.println(token.word() + " chose " + ((String) token.get(CoreAnnotations.AnswerAnnotation.class)));
                    linear.justificationOf(makeDatum(list, pos, this.featureFactory));
                }
            }
        } else {
            System.err.println("Using NEWGENE threshold: " + this.flags.newgeneThreshold);
            for (int pos = 0, n = list.size(); pos < n; pos++) {
                IN token = list.get(pos);
                Counter<String> scores = this.classifier.scoresOf(makeDatum(list, pos, this.featureFactory));
                String answer = this.flags.backgroundSymbol;
                boolean isNewgene = "NEWGENE".equals(token.get(CoreAnnotations.GazAnnotation.class));
                // Only a scored "G" label can override the background symbol.
                if (isNewgene && scores.keySet().contains("G")) {
                    System.err.println(token.word() + ':' + scores.getCount("G"));
                    if (scores.getCount("G") > this.flags.newgeneThreshold) {
                        answer = "G";
                    }
                }
                token.set(CoreAnnotations.AnswerAnnotation.class, answer);
            }
        }
        if (this.flags.useReverse) {
            Collections.reverse(list);
        }
    }

    /**
     * Returns the classifier's label for the token at index {@code i} of {@code list}.
     *
     * @param list tokens the datum is built from
     * @param i    index of the token to classify
     * @return the label chosen by the classifier for the datum built at {@code i}
     */
    protected String classOf(List<IN> list, int i) {
        return this.classifier.classOf(makeDatum(list, i, this.featureFactory));
    }

    /**
     * Computes the regularized negative log-likelihood of the labels in {@code list}.
     *
     * <p>For each token, the log-probability of its label (normalized over all labels) is
     * subtracted from the total. For a {@link LinearClassifier}, the quadratic penalty
     * {@code w^2 / (2 * sigma^2)} is then added for every (feature, label) weight.
     *
     * @param list labelled tokens
     * @return the summed negative log-likelihood, plus the weight penalty for linear classifiers
     */
    public double loglikelihood(List<IN> list) {
        double objective = 0.0d;
        for (int pos = 0; pos < list.size(); pos++) {
            Datum<String, String> datum = makeDatum(list, pos, this.featureFactory);
            Counter<String> logProbs = this.classifier.logProbabilityOf(datum);
            double logTotal = Double.NEGATIVE_INFINITY;
            for (String label : logProbs.keySet()) {
                logTotal = SloppyMath.logAdd(logTotal, logProbs.getCount(label));
            }
            objective -= logProbs.getCount(datum.label()) - logTotal;
        }
        if (this.classifier instanceof LinearClassifier) {
            double sigmaSquared = this.flags.sigma * this.flags.sigma;
            // Typed view of the classifier: features are Strings, so no cast is needed
            // (the decompiled "(LinearClassifier) feature" cast does not compile).
            LinearClassifier<String, String> linear = (LinearClassifier<String, String>) this.classifier;
            for (String feature : linear.features()) {
                for (String label : this.classIndex) {
                    double weight = linear.weight(feature, label);
                    objective += ((weight * weight) / 2.0d) / sigmaSquared;
                }
            }
        }
        return objective;
    }

    @Override // edu.stanford.nlp.ie.AbstractSequenceClassifier
    /**
     * Builds the {@link Scorer} used for sequence inference over {@code list}.
     *
     * <p>The list is reversed in place when {@code flags.useReverse} is set and is not
     * restored here. The left window is {@code flags.maxLeft} with taggy sequences, else 1
     * with previous-tag sequences, else 0; the right window is 1 with next-tag sequences,
     * else 0.
     */
    @Override // edu.stanford.nlp.ie.AbstractSequenceClassifier
    public SequenceModel getSequenceModel(List<IN> list) {
        if (this.flags.useReverse) {
            Collections.reverse(list);
        }
        final int leftWindow;
        if (this.flags.useTaggySequences) {
            leftWindow = this.flags.maxLeft;
        } else if (this.flags.usePrevSequences) {
            leftWindow = 1;
        } else {
            leftWindow = 0;
        }
        final int rightWindow = this.flags.useNextSequences ? 1 : 0;
        return new Scorer(list, this.classIndex, this, leftWindow, rightWindow, this.answerArrays);
    }

    /**
     * Labels {@code list} with the best tag sequence found over this classifier's
     * sequence model.
     *
     * <p>Steps: build the sequence model (which reverses the list when
     * {@code flags.useReverse} is set), search for the best sequence (exact search with
     * {@code flags.useViterbi}, beam search otherwise), optionally print NEWGENE span
     * candidates to {@code System.out}, write the chosen labels into each token's
     * {@code AnswerAnnotation}, optionally print justifications, and finally undo the
     * reversal. An empty list is left untouched.
     */
    private void classifySeq(List<IN> list) {
        if (list.isEmpty()) {
            return;
        }
        SequenceModel sequenceModel = getSequenceModel(list);
        int[] bestSequence = this.flags.useViterbi ? new ExactBestSequenceFinder().bestSequence(sequenceModel) : new BeamBestSequenceFinder(this.flags.beamSize, true, true).bestSequence(sequenceModel, list.size());
        if (this.flags.lowerNewgeneThreshold) {
            System.err.println("Using NEWGENE threshold: " + this.flags.newgeneThreshold);
            // Scratch copy of the best sequence; bestSequence itself is never modified below.
            int[] iArr = new int[bestSequence.length];
            System.arraycopy(bestSequence, 0, iArr, 0, bestSequence.length);
            int indexOf = this.classIndex.indexOf("G");
            int indexOf2 = this.classIndex.indexOf(this.flags.backgroundSymbol);
            int i = 0;
            int size = list.size();
            while (i < size) {
                if ("NEWGENE".equals(list.get(i).get(CoreAnnotations.GazAnnotation.class))) {
                    // [i2, i4) is the maximal run of NEWGENE-annotated tokens starting at i.
                    int i2 = i;
                    int i3 = i;
                    while (i3 < list.size() && "NEWGENE".equals(list.get(i3).get(CoreAnnotations.GazAnnotation.class))) {
                        i3++;
                    }
                    int i4 = i3;
                    // Blank out the run plus up to 4 positions on each side with the background tag.
                    int max = Math.max(0, i2 - 4);
                    int min = Math.min(bestSequence.length, i4 + 4);
                    for (int i5 = max; i5 < min; i5++) {
                        iArr[i5] = indexOf2;
                    }
                    // Baseline: summed recentered background score over the run.
                    double d = 0.0d;
                    for (int i6 = i2; i6 < i4; i6++) {
                        d += Scorer.recenter(sequenceModel.scoresOf(iArr, i6))[indexOf2];
                    }
                    // Maps each candidate span (start, end) inside the run to its score minus the baseline.
                    ClassicCounter classicCounter = new ClassicCounter();
                    for (int i7 = i2; i7 < i4; i7++) {
                        for (int i8 = i2; i8 < i4; i8++) {
                            iArr[i8] = indexOf2;
                        }
                        // Grow the "G" span one token at a time from i7 to the end of the run.
                        for (int i9 = i7; i9 < i4; i9++) {
                            iArr[i9] = indexOf;
                            double d2 = 0.0d;
                            for (int i10 = i2; i10 < i4; i10++) {
                                // NOTE(review): this reads the score of bestSequence's tag, not of the hypothesised tag in iArr — confirm intended.
                                d2 += Scorer.recenter(sequenceModel.scoresOf(iArr, i10))[bestSequence[i10]];
                            }
                            classicCounter.incrementCount(new Pair(Integer.valueOf(i7), Integer.valueOf(i9)), d2 - d);
                        }
                    }
                    for (int i11 = i2; i11 < i4; i11++) {
                        for (int i12 = i11; i12 < i4; i12++) {
                            double count = classicCounter.getCount(new Pair(Integer.valueOf(i11), Integer.valueOf(i12)));
                            // The four spans that differ from (i11, i12) by one token at either end.
                            Pair pair = new Pair(Integer.valueOf(i11 - 1), Integer.valueOf(i12));
                            Pair pair2 = new Pair(Integer.valueOf(i11), Integer.valueOf(i12 + 1));
                            Pair pair3 = new Pair(Integer.valueOf(i11 + 1), Integer.valueOf(i12));
                            Pair pair4 = new Pair(Integer.valueOf(i11), Integer.valueOf(i12 - 1));
                            // Report the span when it reaches the threshold and strictly beats every recorded neighbouring span.
                            if (count >= this.flags.newgeneThreshold && ((!classicCounter.containsKey(pair) || count > classicCounter.getCount(pair)) && ((!classicCounter.containsKey(pair2) || count > classicCounter.getCount(pair2)) && ((!classicCounter.containsKey(pair3) || count > classicCounter.getCount(pair3)) && (!classicCounter.containsKey(pair4) || count > classicCounter.getCount(pair4)))))) {
                                StringBuilder sb = new StringBuilder();
                                IN in = list.get(i11);
                                String str = (String) in.get(CoreAnnotations.IDAnnotation.class);
                                String str2 = (String) in.get(CoreAnnotations.PositionAnnotation.class);
                                String str3 = (String) list.get(i12).get(CoreAnnotations.PositionAnnotation.class);
                                for (int i13 = i11; i13 <= i12; i13++) {
                                    sb.append(list.get(i13).word());
                                    sb.append(' ');
                                }
                                // Output line: <id>|<first position> <last position>|<space-joined words>
                                System.out.println(str + '|' + str2 + ' ' + str3 + '|' + sb.toString().trim());
                            }
                        }
                    }
                    // Restore the scratch copy before moving on to the next run.
                    for (int i14 = max; i14 < min; i14++) {
                        iArr[i14] = bestSequence[i14];
                    }
                    // Jump to the end of the run; the i++ below then steps past index i4,
                    // which is either non-NEWGENE or the end of the list.
                    i = i4;
                }
                i++;
            }
        }
        // Write the chosen label of every position into the tokens.
        int size2 = list.size();
        for (int i15 = 0; i15 < size2; i15++) {
            list.get(i15).set(CoreAnnotations.AnswerAnnotation.class, this.classIndex.get(bestSequence[i15]));
        }
        if (this.flags.justify && (this.classifier instanceof LinearClassifier)) {
            LinearClassifier linearClassifier = (LinearClassifier) this.classifier;
            if (this.flags.dump) {
                linearClassifier.dump();
            }
            int size3 = list.size();
            for (int i16 = 0; i16 < size3; i16++) {
                IN in2 = list.get(i16);
                System.err.print("@@ Position is: " + i16 + ": ");
                System.err.println(in2.word() + ' ' + ((String) in2.get(CoreAnnotations.AnswerAnnotation.class)));
                linearClassifier.justificationOf(makeDatum(list, i16, this.featureFactory));
            }
        }
        // Undo the reversal applied by getSequenceModel.
        if (this.flags.useReverse) {
            Collections.reverse(list);
        }
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Adapts the classifier to the data read from the file {@code str}.
     *
     * <p>Clears {@code flags.ocrTrain}, builds an object bank from the file with
     * {@code documentReaderAndWriter}, and delegates to
     * {@link #adapt(ObjectBank, Dataset)}.
     *
     * @param str                     name of the adaptation file
     * @param dataset                 existing dataset passed through to the delegate
     * @param documentReaderAndWriter reader used to build the object bank
     */
    public void adapt(String str, Dataset<String, String> dataset, DocumentReaderAndWriter documentReaderAndWriter) {
        this.flags.ocrTrain = false;
        adapt(makeObjectBankFromFile(str, documentReaderAndWriter), dataset);
    }

    /**
     * Builds an adaptation dataset from {@code objectBank} via
     * {@code getDataset(objectBank, dataset)} and adapts the classifier to it.
     *
     * @param objectBank adaptation documents
     * @param dataset    existing dataset passed to {@code getDataset}; may be {@code null}
     */
    public void adapt(ObjectBank<List<IN>> objectBank, Dataset<String, String> dataset) {
        adapt(getDataset(objectBank, dataset));
    }

    /**
     * Retrains the current linear classifier on {@code objectBank}, starting the optimizer
     * from the existing weights.
     *
     * <p>The current weights, addressed through the old feature index {@code index} and label
     * index {@code index2}, are copied into a flat start vector laid out as
     * {@code newFeature * numNewLabels + newLabel}. A feature or label that is missing from
     * the rebuilt dataset's indices is skipped, so its slot keeps the initial value 0.
     *
     * @param objectBank training documents
     * @param index      feature index matching the current weight matrix
     * @param index2     label index matching the current weight matrix
     * @throws ClassCastException if the current classifier is not a {@link LinearClassifier}
     */
    public void retrain(ObjectBank<List<IN>> objectBank, Index<String> index, Index<String> index2) {
        // Sizes are captured before getDataset is called, as in the original statement order.
        // NOTE(review): getDataset presumably may add entries to these indices — confirm.
        final int oldFeatureCount = index.size();
        final int oldLabelCount = index2.size();
        Dataset<String, String> dataset = getDataset(objectBank, index, index2);
        int priorType = LogPrior.LogPriorType.QUADRATIC.ordinal();
        LinearClassifier<String, String> linear = (LinearClassifier<String, String>) this.classifier;
        LinearClassifierFactory<String, String> factory = new LinearClassifierFactory<>(this.flags.tolerance, this.flags.useSum, priorType, this.flags.sigma, this.flags.epsilon, this.flags.QNsize);
        double[][] oldWeights = linear.weights();
        Index<String> newFeatures = dataset.featureIndex;
        Index<String> newLabels = dataset.labelIndex;
        final int newLabelCount = newLabels.size();
        double[] initialWeights = new double[newFeatures.size() * newLabelCount];

        // Resolve each old label once instead of once per feature.
        int[] newLabelOf = new int[oldLabelCount];
        for (int l = 0; l < oldLabelCount; l++) {
            newLabelOf[l] = newLabels.indexOf(index2.get(l));
        }
        for (int f = 0; f < oldFeatureCount; f++) {
            int newFeature = newFeatures.indexOf(index.get(f));
            if (newFeature < 0) {
                // Feature is absent from the new dataset; a negative index would make the
                // offset below negative or alias another feature's slot.
                continue;
            }
            for (int l = 0; l < oldLabelCount; l++) {
                int newLabel = newLabelOf[l];
                if (newLabel < 0) {
                    continue;
                }
                initialWeights[(newFeature * newLabelCount) + newLabel] = oldWeights[f][l];
            }
        }
        linear.setWeights(factory.trainWeights(dataset, initialWeights));
    }

    /**
     * Retrains the already-trained classifier on {@code objectBank}, reusing the feature and
     * label indices of the current linear classifier.
     *
     * <p>When no classifier has been trained yet, a message is printed and the JVM exits
     * with status -1.
     *
     * @param objectBank training documents
     */
    public void retrain(ObjectBank<List<IN>> objectBank) {
        if (this.classifier == null) {
            System.err.println("Cannot retrain before you train!");
            System.exit(-1);
        }
        LinearClassifier current = (LinearClassifier) this.classifier;
        Index<String> oldFeatures = current.featureIndex();
        Index<String> oldLabels = ((LinearClassifier) this.classifier).labelIndex();
        String banner = "Starting retrain:\t# of original features" + oldFeatures.size()
                + ", # of original labels" + oldLabels.size();
        System.err.println(banner);
        retrain(objectBank, oldFeatures, oldLabels);
    }

    @Override // edu.stanford.nlp.ie.AbstractSequenceClassifier
    /**
     * Trains the classifier on {@code collection}.
     *
     * <p>After the initial training, {@code flags.numTimesPruneFeatures} rounds of pruning
     * follow: features whose weight spread exceeds {@code flags.featureDiffThresh} are kept,
     * the dataset is reduced to them, and the classifier is retrained with
     * {@code flags.QNsize2} temporarily substituted for {@code flags.QNsize}. When
     * {@code flags.doAdaptation} is set and an adaptation file is configured, the classifier
     * is then adapted to it. The resulting classifier is printed to {@code System.err}.
     *
     * @param collection              training documents
     * @param documentReaderAndWriter reader used for the adaptation file, if any
     */
    @Override // edu.stanford.nlp.ie.AbstractSequenceClassifier
    public void train(Collection<List<IN>> collection, DocumentReaderAndWriter documentReaderAndWriter) {
        Dataset<String, String> dataset = getDataset(collection);
        train(dataset);
        for (int round = 0; round < this.flags.numTimesPruneFeatures; round++) {
            Index<String> keptFeatures = getFeaturesAboveThreshhold(dataset, this.flags.featureDiffThresh);
            System.err.println("Removing features with weight below " + this.flags.featureDiffThresh + " and retraining...");
            dataset = getDataset(dataset, keptFeatures);
            int savedQNsize = this.flags.QNsize;
            this.flags.QNsize = this.flags.QNsize2;
            try {
                train(dataset);
            } finally {
                // Restore the shared flag even when training throws, so a failed pruning
                // round does not leave QNsize permanently overridden.
                this.flags.QNsize = savedQNsize;
            }
        }
        if (this.flags.doAdaptation && this.flags.adaptFile != null) {
            adapt(this.flags.adaptFile, dataset, documentReaderAndWriter);
        }
        System.err.print("Built this classifier: ");
        if (this.classifier instanceof LinearClassifier) {
            System.err.println(((LinearClassifier) this.classifier).toString(this.flags.printClassifier, this.flags.printClassifierParam));
        } else {
            System.err.println(this.classifier.toString());
        }
    }

    /**
     * Selects the features whose weights differ across labels by more than {@code d}.
     *
     * <p>For each feature of {@code dataset}, the weights for the dataset's labels are
     * scanned; the feature is kept as soon as the difference between the largest and the
     * smallest weight seen so far exceeds {@code d}.
     *
     * @param dataset supplies the feature and label indices to scan
     * @param d       minimum weight spread a feature needs in order to be kept
     * @return a new index holding the kept features, in the dataset's feature order
     * @throws RuntimeException if the current classifier is not a {@link LinearClassifier}
     */
    public Index<String> getFeaturesAboveThreshhold(Dataset<String, String> dataset, double d) {
        if (!(this.classifier instanceof LinearClassifier)) {
            throw new RuntimeException("Attempting to remove features based on weight from a non-linear classifier");
        }
        Index<String> features = dataset.featureIndex;
        Index<String> labels = dataset.labelIndex;
        Index<String> kept = new HashIndex<>();
        LinearClassifier<String, String> linear = (LinearClassifier<String, String>) this.classifier;
        for (String feature : features) {
            double smallest = Double.POSITIVE_INFINITY;
            double biggest = Double.NEGATIVE_INFINITY;
            // The scan must stop when the labels run out: the decompiled
            // "while (true) { if (hasNext) ... }" looped forever for features that
            // never exceed the threshold.
            for (String label : labels) {
                double weight = linear.weight(feature, label);
                if (weight < smallest) {
                    smallest = weight;
                }
                if (weight > biggest) {
                    biggest = weight;
                }
                if (biggest - smallest > d) {
                    kept.add(feature);
                    break;
                }
            }
        }
        return kept;
    }

    /**
     * Builds a dataset from {@code collection} with fresh feature and label indices;
     * equivalent to {@code getDataset(collection, null, null)}.
     *
     * @param collection documents to convert into datums
     * @return the dataset built by the three-argument overload
     */
    public Dataset<String, String> getDataset(Collection<List<IN>> collection) {
        return getDataset(collection, null, null);
    }

    /**
     * Converts every token of every document in {@code collection} into a datum and
     * collects the datums in a dataset.
     *
     * <p>The given indices are used only when both are non-null; otherwise the dataset
     * creates its own. Each document is reversed while its datums are built when
     * {@code flags.useReverse} is set. Afterwards feature-count thresholds are applied,
     * either from {@code flags.featThreshFile} or, failing that, from
     * {@code flags.featureThreshold} when it is greater than 1.
     *
     * @param collection documents to convert
     * @param index      feature index to reuse, or {@code null}
     * @param index2     label index to reuse, or {@code null}
     * @return the populated dataset
     */
    public Dataset<String, String> getDataset(Collection<List<IN>> collection, Index<String> index, Index<String> index2) {
        makeAnswerArraysAndTagIndex(collection);

        // First pass: total token count, used as the dataset's initial size.
        int tokenCount = 0;
        for (List<IN> document : collection) {
            tokenCount += document.size();
        }

        System.err.println("Making Dataset...");
        final Dataset<String, String> dataset;
        if (index != null && index2 != null) {
            System.err.println("Using feature/class Index from existing Dataset...");
            System.err.println("(This is used when getting Dataset from adaptation set. We want to make the index consistent.)");
            dataset = new Dataset<>(tokenCount, index, index2);
        } else {
            dataset = new Dataset<>(tokenCount);
        }

        // Second pass: one datum per token.
        for (List<IN> document : collection) {
            if (this.flags.useReverse) {
                Collections.reverse(document);
            }
            for (int pos = 0, n = document.size(); pos < n; pos++) {
                dataset.add(makeDatum(document, pos, this.featureFactory));
            }
            if (this.flags.useReverse) {
                Collections.reverse(document);
            }
        }
        System.err.println("done.");

        if (this.flags.featThreshFile != null) {
            System.err.println("applying thresholds...");
            dataset.applyFeatureCountThreshold(getThresholds(this.flags.featThreshFile));
        } else if (this.flags.featureThreshold > 1) {
            System.err.println("Removing Features with counts < " + this.flags.featureThreshold);
            dataset.applyFeatureCountThreshold(this.flags.featureThreshold);
        }
        dataset.summaryStatistics();
        return dataset;
    }

    /**
     * Builds a dataset from {@code objectBank} on top of the given feature and label indices.
     *
     * <p>Each document is reversed while its datums are built when {@code flags.useReverse}
     * is set. Afterwards feature-count thresholds are applied, either from
     * {@code flags.featThreshFile} or, failing that, from {@code flags.featureThreshold}
     * when it is greater than 1.
     *
     * @param objectBank documents to convert
     * @param index      feature index for the dataset; must not be {@code null}
     * @param index2     label index for the dataset
     * @return the populated dataset
     * @throws NullPointerException if {@code index} is {@code null}
     */
    public Dataset<String, String> getBiasedDataset(ObjectBank<List<IN>> objectBank, Index<String> index, Index<String> index2) {
        makeAnswerArraysAndTagIndex(objectBank);
        // Previously a copy of the index and a per-datum list of unseen features were built
        // and then discarded. That dead work is gone; the null check keeps the failure
        // that copying a null index used to trigger at this point.
        if (index == null) {
            throw new NullPointerException("index");
        }
        int tokenCount = 0;
        for (List<IN> document : objectBank) {
            tokenCount += document.size();
        }
        System.err.println("Making Dataset...");
        Dataset<String, String> dataset = new Dataset<>(tokenCount, index, index2);
        for (List<IN> document : objectBank) {
            if (this.flags.useReverse) {
                Collections.reverse(document);
            }
            for (int pos = 0, n = document.size(); pos < n; pos++) {
                dataset.add(makeDatum(document, pos, this.featureFactory));
            }
            if (this.flags.useReverse) {
                Collections.reverse(document);
            }
        }
        System.err.println("done.");
        if (this.flags.featThreshFile != null) {
            System.err.println("applying thresholds...");
            dataset.applyFeatureCountThreshold(getThresholds(this.flags.featThreshFile));
        } else if (this.flags.featureThreshold > 1) {
            System.err.println("Removing Features with counts < " + this.flags.featureThreshold);
            dataset.applyFeatureCountThreshold(this.flags.featureThreshold);
        }
        dataset.summaryStatistics();
        return dataset;
    }

    /**
     * Builds a dataset from {@code objectBank}, reusing the feature and label indices of
     * {@code dataset} when one is supplied.
     *
     * @param objectBank documents to convert
     * @param dataset    dataset whose indices should be reused, or {@code null} for fresh ones
     * @return the newly built dataset
     */
    public Dataset<String, String> getDataset(ObjectBank<List<IN>> objectBank, Dataset<String, String> dataset) {
        if (dataset != null) {
            return getDataset(objectBank, dataset.featureIndex, dataset.labelIndex);
        }
        return getDataset(objectBank);
    }

    /* JADX WARN: Type inference failed for: r0v11, types: [int[], int[][]] */
    public Dataset<String, String> getDataset(Dataset<String, String> dataset, Index<String> index) {
        int[][] dataArray = dataset.getDataArray();
        int[] labelsArray = dataset.getLabelsArray();
        Index<String> index2 = dataset.featureIndex;
        int[] iArr = new int[index2.size()];
        ?? r0 = new int[dataArray.length];
        System.err.print("Building reduced dataset...");
        int size = index2.size();
        int i = 0;
        for (int i2 = 0; i2 < size; i2++) {
            iArr[i2] = index.indexOf(index2.get(i2));
            if (iArr[i2] > i) {
                i = iArr[i2];
            }
        }
        for (int i3 = 0; i3 < dataArray.length; i3++) {
            int[] iArr2 = dataArray[i3];
            int i4 = 0;
            for (int i5 : iArr2) {
                if (iArr[i5] > 0) {
                    i4++;
                }
            }
            int[] iArr3 = new int[i4];
            int i6 = 0;
            for (int i7 : iArr2) {
                int i8 = iArr[i7];
                if (i8 > 0) {
                    int i9 = i6;
                    i6++;
                    iArr3[i9] = i8;
                }
            }
            r0[i3] = iArr3;
        }
        Dataset<String, String> dataset2 = new Dataset<>(dataset.labelIndex, labelsArray, index, r0, r0.length);
        System.err.println("done.");
        if (this.flags.featThreshFile != null) {
            System.err.println("applying thresholds...");
            dataset2.applyFeatureCountThreshold(getThresholds(this.flags.featThreshFile));
        } else if (this.flags.featureThreshold > 1) {
            System.err.println("Removing Features with counts < " + this.flags.featureThreshold);
            dataset2.applyFeatureCountThreshold(this.flags.featureThreshold);
        }
        dataset2.summaryStatistics();
        return dataset2;
    }

    /**
     * Adapts the classifier to {@code dataset}; not supported when
     * {@code flags.classifierType} is "SVM" (compared case-insensitively).
     *
     * @throws UnsupportedOperationException for the SVM classifier type
     */
    private void adapt(Dataset<String, String> dataset) {
        final boolean isSvm = this.flags.classifierType.equalsIgnoreCase("SVM");
        if (!isSvm) {
            adaptMaxEnt(dataset);
            return;
        }
        throw new UnsupportedOperationException();
    }

    /**
     * Adapts the weights of the current linear classifier to {@code dataset}, using a
     * quadratic prior with {@code flags.adaptSigma}.
     *
     * @throws UnsupportedOperationException if the classifier is not a
     *         {@link LinearClassifier}, or if {@code flags.useHuber} or
     *         {@code flags.useQuartic} is set
     */
    private void adaptMaxEnt(Dataset<String, String> dataset) {
        final boolean unsupported = !(this.classifier instanceof LinearClassifier)
                || this.flags.useHuber
                || this.flags.useQuartic;
        if (unsupported) {
            throw new UnsupportedOperationException();
        }
        LinearClassifierFactory adaptFactory = new LinearClassifierFactory(
                this.flags.tolerance,
                this.flags.useSum,
                LogPrior.LogPriorType.QUADRATIC.ordinal(),
                this.flags.adaptSigma,
                this.flags.epsilon,
                this.flags.QNsize);
        ((LinearClassifier) this.classifier).adaptWeights(dataset, adaptFactory);
    }

    /**
     * Trains on {@code dataset} with the SVM trainer when {@code flags.classifierType} is
     * "SVM" (compared case-insensitively), and with the maximum-entropy trainer otherwise.
     */
    private void train(Dataset<String, String> dataset) {
        if (!this.flags.classifierType.equalsIgnoreCase("SVM")) {
            trainMaxEnt(dataset);
            return;
        }
        trainSVM(dataset);
    }

    /** Replaces the current classifier with one trained on {@code dataset} by a new {@link SVMLightClassifierFactory}. */
    private void trainSVM(Dataset<String, String> dataset) {
        this.classifier = new SVMLightClassifierFactory().trainClassifier((GeneralDataset) dataset);
    }

    /**
     * Trains a linear classifier on {@code dataset} and installs it as the current classifier.
     *
     * <p>With {@code flags.useNB} a {@link NBLinearClassifierFactory} is used. Otherwise a
     * {@link LinearClassifierFactory} is configured with the prior chosen by
     * {@code useHuber} / {@code useQuartic} (quadratic by default) and with the first
     * optimizer whose flag is set, checked in this order: useQN, useStochasticQN, useSMD,
     * useSGD, useSGDtoQN, useHybrid; conjugate gradient is the fallback.
     */
    private void trainMaxEnt(Dataset<String, String> dataset) {
        final int priorType;
        if (this.flags.useHuber) {
            priorType = LogPrior.LogPriorType.HUBER.ordinal();
        } else if (this.flags.useQuartic) {
            priorType = LogPrior.LogPriorType.QUARTIC.ordinal();
        } else {
            priorType = LogPrior.LogPriorType.QUADRATIC.ordinal();
        }

        final LinearClassifier trained;
        if (!this.flags.useNB) {
            LinearClassifierFactory factory = new LinearClassifierFactory(this.flags.tolerance, this.flags.useSum, priorType, this.flags.sigma, this.flags.epsilon, this.flags.QNsize);
            // Optimizer selection: the first flag that is set wins.
            if (this.flags.useQN) {
                factory.useQuasiNewton(this.flags.useRobustQN);
            } else if (this.flags.useStochasticQN) {
                factory.useStochasticQN(this.flags.initialGain, this.flags.stochasticBatchSize);
            } else if (this.flags.useSMD) {
                factory.useStochasticMetaDescent(this.flags.initialGain, this.flags.stochasticBatchSize, this.flags.stochasticMethod, this.flags.SGDPasses);
            } else if (this.flags.useSGD) {
                factory.useStochasticGradientDescent(this.flags.gainSGD, this.flags.stochasticBatchSize);
            } else if (this.flags.useSGDtoQN) {
                factory.useStochasticGradientDescentToQuasiNewton(this.flags);
            } else if (this.flags.useHybrid) {
                factory.useHybridMinimizer(this.flags.initialGain, this.flags.stochasticBatchSize, this.flags.stochasticMethod, this.flags.hybridCutoffIteration);
            } else {
                factory.useConjugateGradientAscent();
            }
            trained = factory.trainClassifier((GeneralDataset) dataset);
        } else {
            trained = new NBLinearClassifierFactory(this.flags.sigma).trainClassifier((GeneralDataset) dataset);
        }
        this.classifier = trained;
    }

    /**
     * Trains {@code this.classifier} through
     * {@code LinearClassifierFactory.trainClassifierSemiSup}, passing both
     * datasets, the matrix {@code dArr} and a {@code null} fourth argument.
     *
     * <p>The prior is Huber if {@code flags.useHuber}, else quartic if
     * {@code flags.useQuartic}, else quadratic. The optimizer is quasi-Newton
     * when {@code flags.useQN} is set and conjugate gradient ascent otherwise.
     *
     * @param dataset  first dataset passed to the semi-supervised trainer
     * @param dataset2 second dataset passed to the semi-supervised trainer
     * @param dArr     matrix passed through to the trainer unchanged
     */
    private void trainSemiSup(Dataset<String, String> dataset, Dataset<String, String> dataset2, double[][] dArr) {
        int priorType = this.flags.useHuber
                ? LogPrior.LogPriorType.HUBER.ordinal()
                : (this.flags.useQuartic
                        ? LogPrior.LogPriorType.QUARTIC.ordinal()
                        : LogPrior.LogPriorType.QUADRATIC.ordinal());

        LinearClassifierFactory factory = new LinearClassifierFactory(this.flags.tolerance, this.flags.useSum, priorType, this.flags.sigma, this.flags.epsilon, this.flags.QNsize);
        if (!this.flags.useQN) {
            factory.useConjugateGradientAscent();
        } else {
            factory.useQuasiNewton();
        }

        this.classifier = (LinearClassifier) factory.trainClassifierSemiSup(dataset, dataset2, dArr, null);
    }

    /**
     * Writes this classifier's state to {@code str} using Java object
     * serialization.
     *
     * <p>Objects are written in this order: {@code classifier}, {@code flags},
     * {@code featureFactory}, {@code classIndex}, {@code answerArrays},
     * {@code knownLCWords}. Progress and any failure are reported on
     * {@code System.err}; exceptions are not rethrown to the caller.
     *
     * @param str destination passed to {@code IOUtils.writeStreamFromString}
     */
    @Override // edu.stanford.nlp.ie.AbstractSequenceClassifier
    public void serializeClassifier(String str) {
        System.err.print("Serializing classifier to " + str + PTBLexer.ptb3EllipsisStr);
        try {
            // try-with-resources closes the stream even when a writeObject call
            // fails part-way; previously the stream leaked on that path.
            try (ObjectOutputStream out = IOUtils.writeStreamFromString(str)) {
                out.writeObject(this.classifier);
                out.writeObject(this.flags);
                out.writeObject(this.featureFactory);
                out.writeObject(this.classIndex);
                out.writeObject(this.answerArrays);
                out.writeObject(this.knownLCWords);
            }
            // Only reported once the stream has been closed successfully.
            System.err.println("Done.");
        } catch (Exception e) {
            System.err.println("Error serializing to " + str);
            e.printStackTrace();
        }
    }

    /**
     * Loads the classifier named by {@code DEFAULT_CLASSIFIER} via
     * {@code loadJarClassifier}, passing {@code null} as the second argument.
     */
    public void loadDefaultClassifier() {
        this.loadJarClassifier(DEFAULT_CLASSIFIER, null);
    }

    /**
     * Creates a new {@code CMMClassifier} and loads the default classifier
     * into it.
     *
     * @return the newly created classifier after {@code loadDefaultClassifier()} has run
     */
    public static CMMClassifier getDefaultClassifier() {
        CMMClassifier loaded = new CMMClassifier();
        loaded.loadDefaultClassifier();
        return loaded;
    }

    /**
     * Restores this classifier's state from a Java object stream.
     *
     * <p>Objects are read in the same order that {@code serializeClassifier}
     * writes them: classifier, flags, feature factory, class index, answer
     * arrays, known lower-case words. The read order must not be changed
     * independently of the writer.
     *
     * <p>NOTE(review): this uses native Java deserialization, which is unsafe
     * on untrusted input; callers should only pass streams from trusted model
     * files.
     *
     * @param objectInputStream stream positioned at the start of a serialized classifier
     * @param properties        if non-null, applied to the deserialized flags via
     *                          {@code setProperties}
     * @throws ClassCastException     if an object in the stream has an unexpected type
     * @throws IOException            if reading from the stream fails
     * @throws ClassNotFoundException if a serialized class cannot be resolved
     */
    @Override // edu.stanford.nlp.ie.AbstractSequenceClassifier
    public void loadClassifier(ObjectInputStream objectInputStream, Properties properties) throws ClassCastException, IOException, ClassNotFoundException {
        this.classifier = (LinearClassifier) objectInputStream.readObject();
        // Replaces this.flags wholesale with the flags stored in the model.
        this.flags = (SeqClassifierFlags) objectInputStream.readObject();
        this.featureFactory = (FeatureFactory) objectInputStream.readObject();
        if (properties != null) {
            // Caller-supplied properties are applied on top of the deserialized flags.
            this.flags.setProperties(properties);
        }
        // Runs after flags and feature factory are in place, and before the
        // remaining three objects are read.
        reinit();
        this.classIndex = (Index) objectInputStream.readObject();
        this.answerArrays = (Set) objectInputStream.readObject();
        this.knownLCWords = (Set) objectInputStream.readObject();
    }

    /**
     * Creates a new {@code CMMClassifier} and loads it from {@code file}
     * using {@code loadClassifierNoExceptions}.
     *
     * @param file the serialized classifier file
     * @return the newly created classifier
     */
    public static CMMClassifier getClassifierNoExceptions(File file) {
        CMMClassifier loaded = new CMMClassifier();
        loaded.loadClassifierNoExceptions(file);
        return loaded;
    }

    /**
     * Creates a new {@code CMMClassifier} and loads it from {@code file}
     * using {@code loadClassifier}.
     *
     * @param file the serialized classifier file
     * @return the newly created classifier
     * @throws IOException            propagated from {@code loadClassifier}
     * @throws ClassCastException     propagated from {@code loadClassifier}
     * @throws ClassNotFoundException propagated from {@code loadClassifier}
     */
    public static CMMClassifier getClassifier(File file) throws IOException, ClassCastException, ClassNotFoundException {
        CMMClassifier loaded = new CMMClassifier();
        loaded.loadClassifier(file);
        return loaded;
    }

    /**
     * Creates a new {@code CMMClassifier} and loads it from the location
     * {@code str} using {@code loadClassifierNoExceptions}.
     *
     * @param str location of the serialized classifier
     * @return the newly created classifier
     */
    public static CMMClassifier getClassifierNoExceptions(String str) {
        CMMClassifier loaded = new CMMClassifier();
        loaded.loadClassifierNoExceptions(str);
        return loaded;
    }

    /**
     * Creates a new {@code CMMClassifier} and loads it from the location
     * {@code str} using {@code loadClassifier}.
     *
     * @param str location of the serialized classifier
     * @return the newly created classifier
     * @throws IOException            propagated from {@code loadClassifier}
     * @throws ClassCastException     propagated from {@code loadClassifier}
     * @throws ClassNotFoundException propagated from {@code loadClassifier}
     */
    public static CMMClassifier getClassifier(String str) throws IOException, ClassCastException, ClassNotFoundException {
        CMMClassifier loaded = new CMMClassifier();
        loaded.loadClassifier(str);
        return loaded;
    }

    /**
     * Creates a new {@code CMMClassifier} and loads it from
     * {@code inputStream}, wrapped in a {@code BufferedInputStream}, using
     * {@code loadClassifierNoExceptions} with {@code null} properties.
     *
     * @param inputStream stream containing a serialized classifier
     * @return the newly created classifier
     */
    public static CMMClassifier getClassifierNoExceptions(InputStream inputStream) {
        CMMClassifier loaded = new CMMClassifier();
        BufferedInputStream buffered = new BufferedInputStream(inputStream);
        loaded.loadClassifierNoExceptions(buffered, (Properties) null);
        return loaded;
    }

    /**
     * Creates a new {@code CMMClassifier} and loads it from
     * {@code inputStream}, wrapped in a {@code BufferedInputStream}, using
     * {@code loadClassifier}.
     *
     * @param inputStream stream containing a serialized classifier
     * @return the newly created classifier
     * @throws IOException            propagated from {@code loadClassifier}
     * @throws ClassCastException     propagated from {@code loadClassifier}
     * @throws ClassNotFoundException propagated from {@code loadClassifier}
     */
    public static CMMClassifier getClassifier(InputStream inputStream) throws IOException, ClassCastException, ClassNotFoundException {
        CMMClassifier loaded = new CMMClassifier();
        BufferedInputStream buffered = new BufferedInputStream(inputStream);
        loaded.loadClassifier(buffered);
        return loaded;
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Populates {@code answerArrays} and {@code classIndex} from the answer
     * annotations of every sequence in {@code collection}, creating either
     * container first if it is still {@code null}.
     *
     * <p>For each start position, every run of consecutive answers of length
     * 1 up to {@code flags.maxLeft} that fits inside the sequence is added to
     * {@code answerArrays} as a list. Every individual answer is added to
     * {@code classIndex}. When {@code flags.useReverse} is set, each sequence
     * is reversed in place while it is processed and reversed back afterwards.
     *
     * @param collection the sequences to scan
     */
    private void makeAnswerArraysAndTagIndex(Collection<List<IN>> collection) {
        if (this.answerArrays == null) {
            this.answerArrays = new HashSet();
        }
        if (this.classIndex == null) {
            this.classIndex = new HashIndex();
        }
        for (List<IN> sequence : collection) {
            if (this.flags.useReverse) {
                Collections.reverse(sequence);
            }
            int length = sequence.size();
            for (int start = 0; start < length; start++) {
                for (int span = 1; span <= this.flags.maxLeft && start + span <= length; span++) {
                    // Answers at positions start .. start + span - 1.
                    String[] labels = new String[span];
                    for (int offset = 0; offset < span; offset++) {
                        labels[offset] = (String) sequence.get(start + offset).get(CoreAnnotations.AnswerAnnotation.class);
                    }
                    this.answerArrays.add(Arrays.asList(labels));
                }
            }
            for (IN token : sequence) {
                this.classIndex.add(token.get(CoreAnnotations.AnswerAnnotation.class));
            }
            if (this.flags.useReverse) {
                // Restore the caller's original ordering.
                Collections.reverse(sequence);
            }
        }
    }

    /**
     * Builds the datum for position {@code i} of {@code list}.
     *
     * <p>The list is wrapped in a {@code PaddedList} using {@code this.pad}.
     * For every clique reported by {@code featureFactory}, the clique's
     * features are passed through {@code addOtherClasses} and collected. The
     * collected features are handed to {@code printFeatures} and then become
     * the datum's features; its label is the answer annotation at position
     * {@code i}.
     *
     * @param list           the sequence being featurized
     * @param i              position within {@code list}
     * @param featureFactory source of cliques and clique features
     * @return a {@code BasicDatum} of the collected features and the answer at {@code i}
     */
    public <T extends CoreLabel> Datum<String, String> makeDatum(List<IN> list, int i, FeatureFactory featureFactory) {
        PaddedList padded = new PaddedList(list, this.pad);
        ArrayList features = new ArrayList();
        for (Clique clique : featureFactory.getCliques()) {
            Collection<String> cliqueFeatures = featureFactory.getCliqueFeatures(padded, i, clique);
            features.addAll(addOtherClasses(cliqueFeatures, padded, i, clique));
        }
        printFeatures((CoreMap) padded.get(i), features);
        return new BasicDatum(features, list.get(i).get(CoreAnnotations.AnswerAnnotation.class));
    }

    /**
     * Appends a suffix built from neighbouring answer annotations to every
     * feature, with the suffix shape chosen by {@code clique}.
     *
     * <p>The suffix is {@code '|'} followed by one or more neighbouring
     * answers joined with {@code '-'}. The previous one to five answers and
     * the next answer are available, depending on which {@code FeatureFactory}
     * clique constant {@code clique} is (compared by identity). If
     * {@code clique} matches none of the handled constants, {@code collection}
     * is returned unchanged.
     *
     * @param collection the features to decorate
     * @param list       the sequence; positions {@code i - 5 .. i - 1} and
     *                   {@code i + 1} are all read regardless of the clique
     * @param i          current position in {@code list}
     * @param clique     the clique whose features are being decorated
     * @return {@code collection} itself, or a new set of suffixed features
     */
    private static Collection<String> addOtherClasses(Collection<String> collection, List<? extends CoreLabel> list, int i, Clique clique) {
        // All neighbours are fetched up front, in this order, whatever the clique.
        String prev1 = (String) list.get(i - 1).get(CoreAnnotations.AnswerAnnotation.class);
        String prev2 = (String) list.get(i - 2).get(CoreAnnotations.AnswerAnnotation.class);
        String prev3 = (String) list.get(i - 3).get(CoreAnnotations.AnswerAnnotation.class);
        String prev4 = (String) list.get(i - 4).get(CoreAnnotations.AnswerAnnotation.class);
        String prev5 = (String) list.get(i - 5).get(CoreAnnotations.AnswerAnnotation.class);
        String next1 = (String) list.get(i + 1).get(CoreAnnotations.AnswerAnnotation.class);

        final String suffix;
        if (clique == FeatureFactory.cliqueCpC) {
            suffix = "|" + prev1;
        } else if (clique == FeatureFactory.cliqueCp2C) {
            suffix = "|" + prev2;
        } else if (clique == FeatureFactory.cliqueCp3C) {
            suffix = "|" + prev3;
        } else if (clique == FeatureFactory.cliqueCp4C) {
            suffix = "|" + prev4;
        } else if (clique == FeatureFactory.cliqueCp5C) {
            suffix = "|" + prev5;
        } else if (clique == FeatureFactory.cliqueCpCp2C) {
            suffix = "|" + prev1 + "-" + prev2;
        } else if (clique == FeatureFactory.cliqueCpCp2Cp3C) {
            suffix = "|" + prev1 + "-" + prev2 + "-" + prev3;
        } else if (clique == FeatureFactory.cliqueCpCp2Cp3Cp4C) {
            suffix = "|" + prev1 + "-" + prev2 + "-" + prev3 + "-" + prev4;
        } else if (clique == FeatureFactory.cliqueCpCp2Cp3Cp4Cp5C) {
            suffix = "|" + prev1 + "-" + prev2 + "-" + prev3 + "-" + prev4 + "-" + prev5;
        } else if (clique == FeatureFactory.cliqueCnC) {
            suffix = "|" + next1;
        } else if (clique == FeatureFactory.cliqueCpCnC) {
            suffix = "|" + prev1 + "-" + next1;
        } else {
            // Clique does not involve neighbouring classes: nothing to append.
            return collection;
        }

        Set<String> decorated = new HashSet<String>();
        for (String feature : collection) {
            decorated.add(feature + suffix);
        }
        return decorated;
    }

    /**
     * Reads a threshold file into a list of (pattern, threshold) pairs.
     *
     * <p>Each line is split at its last space character: the text before it
     * is compiled as a regular expression, the text after it is parsed as an
     * integer. Pairs are returned in file order.
     *
     * @param str path of the threshold file
     * @return one pair per line of the file
     * @throws RuntimeException wrapping any failure, including I/O errors,
     *         a line without a space, an invalid regular expression or a
     *         non-integer threshold
     */
    private static List<Pair<Pattern, Integer>> getThresholds(String str) {
        // try-with-resources closes the reader on every path; previously a
        // malformed line or read error left the file handle open.
        try (BufferedReader reader = new BufferedReader(new FileReader(str))) {
            List<Pair<Pattern, Integer>> thresholds = new ArrayList<Pair<Pattern, Integer>>();
            String line;
            while ((line = reader.readLine()) != null) {
                int separator = line.lastIndexOf(' ');
                Pattern pattern = Pattern.compile(line.substring(0, separator));
                Integer threshold = Integer.valueOf(line.substring(separator + 1));
                thresholds.add(new Pair<Pattern, Integer>(pattern, threshold));
            }
            return thresholds;
        } catch (Exception e) {
            throw new RuntimeException("Error reading threshold file", e);
        }
    }

    /* JADX WARN: Multi-variable type inference failed */
    /**
     * Trains the classifier from two files: {@code flags.trainFile} and
     * {@code flags.biasedTrainFile}.
     *
     * <p>Both files are read with the same reader/writer and turned into
     * datasets that share the same two indices. A square matrix sized by the
     * second index is then built from {@code flags.confusionMatrix}, printed
     * to {@code System.err}, and passed with both datasets to
     * {@code trainSemiSup(Dataset, Dataset, double[][])}.
     */
    public void trainSemiSup() {
        DocumentReaderAndWriter makeReaderAndWriter = makeReaderAndWriter();
        String str = this.flags.trainFile;
        String str2 = this.flags.biasedTrainFile;
        ObjectBank makeObjectBankFromFile = makeObjectBankFromFile(str, makeReaderAndWriter);
        ObjectBank makeObjectBankFromFile2 = makeObjectBankFromFile(str2, makeReaderAndWriter);
        // Shared by both datasets. NOTE(review): hashIndex is presumably the
        // feature index and hashIndex2 the label index — confirm against getDataset.
        HashIndex hashIndex = new HashIndex();
        HashIndex hashIndex2 = new HashIndex();
        Dataset<String, String> dataset = getDataset(makeObjectBankFromFile, hashIndex, hashIndex2);
        Dataset<String, String> biasedDataset = getBiasedDataset(makeObjectBankFromFile2, hashIndex, hashIndex2);
        // Start from the identity matrix.
        double[][] dArr = new double[hashIndex2.size()][hashIndex2.size()];
        for (int i = 0; i < dArr.length; i++) {
            Arrays.fill(dArr[i], 0.0d);
            dArr[i][i] = 1.0d;
        }
        // Each entry of flags.confusionMatrix has three '|'-separated fields:
        // two index keys and a numeric value. The value is stored at
        // [index of field 1][index of field 0].
        for (String str3 : this.flags.confusionMatrix.split(MorphoFeatures.KEY_VAL_DELIM)) {
            String[] split = str3.split("\\|");
            dArr[hashIndex2.indexOf(split[1])][hashIndex2.indexOf(split[0])] = Double.parseDouble(split[2]);
        }
        // Normalize row by row. NOTE(review): presumably scales each row to
        // sum to 1 — confirm against ArrayMath.normalize.
        for (double[] dArr2 : dArr) {
            ArrayMath.normalize(dArr2);
        }
        // Transpose in place by swapping each below-diagonal element with its mirror.
        for (int i2 = 0; i2 < dArr.length; i2++) {
            for (int i3 = 0; i3 < i2; i3++) {
                double d = dArr[i2][i3];
                dArr[i2][i3] = dArr[i3][i2];
                dArr[i3][i2] = d;
            }
        }
        // Dump the final matrix: entry [i5][i4] is printed as P(key i5 | key i4).
        for (int i4 = 0; i4 < dArr.length; i4++) {
            for (int i5 = 0; i5 < dArr.length; i5++) {
                System.err.println("P(" + ((String) hashIndex2.get(i5)) + '|' + ((String) hashIndex2.get(i4)) + ") = " + dArr[i5][i4]);
            }
        }
        trainSemiSup(dataset, biasedDataset, dArr);
    }

    /* JADX INFO: Access modifiers changed from: private */
    /**
     * Returns the current value of {@code flags.normalize}.
     *
     * @return {@code this.flags.normalize}
     */
    public boolean normalize() {
        return this.flags.normalize;
    }

    /**
     * Scores every class for position {@code i} of {@code list}.
     *
     * <p>A datum is built for the position with {@code makeDatum} and this
     * object's feature factory, then passed to the classifier's
     * {@code logProbabilityOf}.
     *
     * @param list the sequence being scored
     * @param i    position within {@code list}
     * @return the counter returned by {@code classifier.logProbabilityOf}
     */
    public Counter<String> scoresOf(List<IN> list, int i) {
        Datum<String, String> datum = makeDatum(list, i, this.featureFactory);
        return this.classifier.logProbabilityOf(datum);
    }

    /**
     * This implementation has an empty body: nothing is computed or printed
     * for {@code list}.
     *
     * @param list ignored
     */
    @Override // edu.stanford.nlp.ie.AbstractSequenceClassifier
    public void printProbsDocument(List<IN> list) {
    }

    /**
     * Command-line entry point.
     *
     * <p>Arguments are converted to properties and used to construct a
     * {@code CMMClassifier}. A model is then obtained by the first applicable
     * route: {@code loadClassifier} flag, {@code loadJarClassifier} flag, the
     * default classifier when no {@code trainFile} is set, semi-supervised
     * training when {@code biasedTrainFile} is set, otherwise plain training.
     * The model is serialized if {@code serializeTo} is set, and finally
     * applied to {@code testFile} (or else {@code testFiles}) and to
     * {@code textFile} when those are given.
     *
     * @param strArr command-line arguments
     * @throws Exception propagated from loading, training or classification
     */
    public static void main(String[] strArr) throws Exception {
        StringUtils.printErrInvocationString("CMMClassifier", strArr);
        Properties props = StringUtils.argsToProperties(strArr);
        CMMClassifier cmm = new CMMClassifier(props);

        // These four values are captured from the command-line flags before any
        // model is loaded; flags read further down (testFiles, baseTestDir) are
        // read from cmm.flags as it stands after loading/training.
        String testFile = cmm.flags.testFile;
        String textFile = cmm.flags.textFile;
        String loadPath = cmm.flags.loadClassifier;
        String serializeTo = cmm.flags.serializeTo;

        if (loadPath != null) {
            cmm.loadClassifierNoExceptions(loadPath, props);
        } else if (cmm.flags.loadJarClassifier != null) {
            cmm.loadJarClassifier(cmm.flags.loadJarClassifier, props);
        } else if (cmm.flags.trainFile == null) {
            cmm.loadDefaultClassifier();
        } else if (cmm.flags.biasedTrainFile != null) {
            cmm.trainSemiSup();
        } else {
            cmm.train();
        }

        if (serializeTo != null) {
            cmm.serializeClassifier(serializeTo);
        }

        if (testFile != null) {
            cmm.classifyAndWriteAnswers(testFile, cmm.makeReaderAndWriter());
        } else if (cmm.flags.testFiles != null) {
            cmm.classifyAndWriteAnswers(cmm.flags.baseTestDir, cmm.flags.testFiles, cmm.makeReaderAndWriter());
        }

        if (textFile != null) {
            cmm.classifyAndWriteAnswers(textFile, new PlainTextDocumentReaderAndWriter());
        }
    }

    /**
     * Tags a document: creates a blank copy via
     * {@code document.blankDocument()} and fills it with the result of
     * {@code process(document)}.
     *
     * @param document the document to tag
     * @return the blank document populated with the tagged words
     */
    @Override // edu.stanford.nlp.process.DocumentProcessor
    public Document<?, ?, WordTag> processDocument(Document document) {
        Document<?, ?, WordTag> tagged = document.blankDocument();
        tagged.addAll(process(document));
        return tagged;
    }

    /**
     * Classifies a list of arbitrary objects and returns word/answer pairs.
     *
     * <p>Each input element becomes a fresh {@code CoreLabel}. If the element
     * is a {@code HasWord}, its word is copied, along with its tag when it is
     * also a {@code HasTag}; otherwise the element's {@code toString()} is
     * used as the word. The labels are passed to {@code classify}, and each
     * classified item is turned into a {@code WordTag} of its word and its
     * answer annotation.
     *
     * @param list the objects to classify
     * @return one {@code WordTag} per item returned by {@code classify}
     */
    @Override // edu.stanford.nlp.process.ListProcessor
    public List<WordTag> process(List<? extends Object> list) {
        ArrayList tokens = new ArrayList();
        for (Object element : list) {
            CoreLabel label = new CoreLabel();
            if (!(element instanceof HasWord)) {
                label.setWord(element.toString());
            } else {
                label.setWord(((HasWord) element).word());
                if (element instanceof HasTag) {
                    label.setTag(((HasTag) element).tag());
                }
            }
            tokens.add(label);
        }

        List<IN> classified = classify(tokens);

        ArrayList tagged = new ArrayList();
        for (IN item : classified) {
            String answer = (String) item.get(CoreAnnotations.AnswerAnnotation.class);
            tagged.add(new WordTag(item.word(), answer));
        }
        return tagged;
    }

    /**
     * Looks up a single weight in the underlying linear classifier.
     *
     * <p>{@code this.classifier} is cast to {@code LinearClassifier}, so this
     * fails with a {@code ClassCastException} if the current classifier is of
     * another type.
     *
     * @param str  first argument passed to {@code LinearClassifier.weight}
     *             (NOTE(review): presumably the feature — confirm)
     * @param str2 second argument passed to {@code LinearClassifier.weight}
     *             (NOTE(review): presumably the class label — confirm)
     * @return the weight reported by the linear classifier
     */
    public double weight(String str, String str2) {
        // The feature is passed as a String; the previous (LinearClassifier) cast
        // on it was a decompiler artifact that cannot compile.
        return ((LinearClassifier<String, String>) this.classifier).weight(str, str2);
    }

    /**
     * Returns the weight matrix of the underlying linear classifier.
     *
     * <p>{@code this.classifier} is cast to {@code LinearClassifier}, so this
     * fails with a {@code ClassCastException} if the current classifier is of
     * another type.
     *
     * @return the array returned by {@code LinearClassifier.weights()}
     */
    public double[][] weights() {
        LinearClassifier linear = (LinearClassifier) this.classifier;
        return linear.weights();
    }

    /**
     * Classifies {@code list} by delegating to {@code classify(list)}; the two
     * {@code CoreMap} arguments are not used by this implementation.
     *
     * @param list     the sequence to classify
     * @param coreMap  ignored
     * @param coreMap2 ignored
     * @return the result of {@code classify(list)}
     */
    @Override // edu.stanford.nlp.ie.AbstractSequenceClassifier
    public List<IN> classifyWithGlobalInformation(List<IN> list, CoreMap coreMap, CoreMap coreMap2) {
        List<IN> classified = classify(list);
        return classified;
    }
}
