package edu.stanford.nlp.patterns.surface;

import edu.stanford.nlp.classify.GeneralDataset;
import edu.stanford.nlp.classify.LogisticClassifier;
import edu.stanford.nlp.classify.LogisticClassifierFactory;
import edu.stanford.nlp.classify.RVFDataset;
import edu.stanford.nlp.io.IOUtils;
import edu.stanford.nlp.ling.CoreAnnotations;
import edu.stanford.nlp.ling.CoreLabel;
import edu.stanford.nlp.ling.RVFDatum;
import edu.stanford.nlp.patterns.CandidatePhrase;
import edu.stanford.nlp.patterns.DataInstance;
import edu.stanford.nlp.patterns.PatternsAnnotations;
import edu.stanford.nlp.sequences.SeqClassifierFlags;
import edu.stanford.nlp.stats.ClassicCounter;
import edu.stanford.nlp.stats.Counter;
import edu.stanford.nlp.stats.Counters;
import edu.stanford.nlp.util.CollectionValuedMap;
import edu.stanford.nlp.util.Execution;
import edu.stanford.nlp.util.Pair;
import edu.stanford.nlp.util.StringUtils;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Properties;
import java.util.Random;
import java.util.Set;
import java.util.Random;

/* loaded from: input_file:edu/stanford/nlp/patterns/surface/LearnImportantFeatures.class */
public class LearnImportantFeatures {

    @Execution.Option(name = "answerClass")
    public Class answerClass = CoreAnnotations.AnswerAnnotation.class;

    @Execution.Option(name = "answerLabel")
    public String answerLabel = "WORD";

    @Execution.Option(name = "wordClassClusterFile")
    String wordClassClusterFile = null;

    @Execution.Option(name = "thresholdWeight")
    Double thresholdWeight = null;
    Map<String, Integer> clusterIds = new HashMap();
    CollectionValuedMap<Integer, String> clusters = new CollectionValuedMap<>();

    @Execution.Option(name = "negativeWordsFiles")
    String negativeWordsFiles = null;
    HashSet<String> negativeWords = new HashSet<>();
    static final /* synthetic */ boolean $assertionsDisabled;

    public void setUp() {
        if (!$assertionsDisabled && this.wordClassClusterFile == null) {
            throw new AssertionError();
        }
        if (this.wordClassClusterFile != null) {
            Iterator<String> it = IOUtils.readLines(this.wordClassClusterFile).iterator();
            while (it.hasNext()) {
                String[] split = it.next().split("\\s+");
                int parseInt = Integer.parseInt(split[1]);
                this.clusterIds.put(split[0], Integer.valueOf(parseInt));
                this.clusters.add(Integer.valueOf(parseInt), split[0]);
            }
        }
        if (this.negativeWordsFiles != null) {
            for (String str : this.negativeWordsFiles.split("[,;]")) {
                this.negativeWords.addAll(IOUtils.linesFromFile(str));
            }
            System.out.println("number of negative words from lists " + this.negativeWords.size());
        }
    }

    public static boolean getRandomBoolean(Random random, double d) {
        return ((double) random.nextFloat()) < d;
    }

    private int sample(Map<String, DataInstance> map, Random random, Random random2, double d, double d2, int i, List<Pair<String, Integer>> list, RVFDataset<String, String> rVFDataset) {
        boolean z;
        for (Map.Entry<String, DataInstance> entry : map.entrySet()) {
            CoreLabel[] coreLabelArr = (CoreLabel[]) entry.getValue().getTokens().toArray(new CoreLabel[0]);
            for (int i2 = 0; i2 < coreLabelArr.length; i2++) {
                CoreLabel coreLabel = coreLabelArr[i2];
                if (coreLabel.get(this.answerClass).equals(this.answerLabel)) {
                    z = true;
                } else if ((!coreLabel.get(this.answerClass).equals(SeqClassifierFlags.DEFAULT_BACKGROUND_SYMBOL) || this.negativeWords.contains(coreLabel.word().toLowerCase())) && getRandomBoolean(random, d)) {
                    z = true;
                } else if (getRandomBoolean(random, d2)) {
                    i++;
                    z = true;
                } else {
                    z = false;
                }
                if (z) {
                    list.add(new Pair<>(entry.getKey(), Integer.valueOf(i2)));
                    rVFDataset.add(getDatum(coreLabelArr, i2), entry.getKey(), Integer.toString(i2));
                }
            }
        }
        return i;
    }

    public Counter<String> getTopFeatures(Iterator<Pair<Map<String, DataInstance>, File>> it, double d, double d2, String str) throws IOException, ClassNotFoundException {
        ClassicCounter classicCounter = new ClassicCounter();
        RVFDataset<String, String> rVFDataset = new RVFDataset<>();
        Random random = new Random(10L);
        Random random2 = new Random(10L);
        int i = 0;
        ArrayList arrayList = new ArrayList();
        while (it.hasNext()) {
            i = sample(it.next().first(), random, random2, d2, d, i, arrayList, rVFDataset);
        }
        System.out.println("num random chosen: " + i);
        System.out.println("Number of datums per label: " + rVFDataset.numDatumsPerLabel());
        LogisticClassifier trainClassifier = new LogisticClassifierFactory().trainClassifier((GeneralDataset) rVFDataset);
        Counter weightsAsGenericCounter = trainClassifier.weightsAsGenericCounter();
        if (!((String) trainClassifier.getLabelForInternalPositiveClass()).equals(this.answerLabel)) {
            weightsAsGenericCounter = Counters.scale(weightsAsGenericCounter, -1.0d);
        }
        if (this.thresholdWeight != null) {
            HashSet hashSet = new HashSet();
            for (Map.Entry entry : weightsAsGenericCounter.entrySet()) {
                if (Math.abs(((Double) entry.getValue()).doubleValue()) <= this.thresholdWeight.doubleValue()) {
                    hashSet.add(entry.getKey());
                }
            }
            Counters.removeKeys(weightsAsGenericCounter, hashSet);
            System.out.println("Removing " + hashSet);
        }
        IOUtils.writeStringToFile(Counters.toSortedString(weightsAsGenericCounter, weightsAsGenericCounter.size(), "%1$s:%2$f", "\n"), str, "utf8");
        return classicCounter;
    }

    private RVFDatum<String, String> getDatum(CoreLabel[] coreLabelArr, int i) {
        ClassicCounter classicCounter = new ClassicCounter();
        CoreLabel coreLabel = coreLabelArr[i];
        String str = coreLabel.get(this.answerClass).toString().equals(this.answerLabel) ? this.answerLabel : SeqClassifierFlags.DEFAULT_BACKGROUND_SYMBOL;
        CollectionValuedMap collectionValuedMap = (CollectionValuedMap) coreLabel.get(PatternsAnnotations.MatchedPhrases.class);
        if (collectionValuedMap == null) {
            collectionValuedMap = new CollectionValuedMap();
            collectionValuedMap.add(str, CandidatePhrase.createOrGet(coreLabel.word()));
        }
        Iterator it = collectionValuedMap.allValues().iterator();
        while (it.hasNext()) {
            Integer num = this.clusterIds.get(((CandidatePhrase) it.next()).getPhrase());
            if (num == null) {
                num = -1;
            }
            classicCounter.setCount("Cluster-" + num, 1.0d);
        }
        for (int max = Math.max(0, i - 0); max < i; max++) {
            CoreLabel coreLabel2 = coreLabelArr[max];
            classicCounter.incrementCount("PREV-WORD-" + coreLabel2.word());
            classicCounter.incrementCount("PREV-LEMMA-" + coreLabel2.lemma());
            classicCounter.incrementCount("PREV-TAG-" + coreLabel2.tag());
        }
        for (int i2 = i + 1; i2 < coreLabelArr.length && i2 <= i + 0; i2++) {
            CoreLabel coreLabel3 = coreLabelArr[i2];
            classicCounter.incrementCount("NEXT-WORD-" + coreLabel3.word());
            classicCounter.incrementCount("NEXT-LEMMA-" + coreLabel3.lemma());
            classicCounter.incrementCount("NEXT-TAG-" + coreLabel3.tag());
        }
        return new RVFDatum<>(classicCounter, str);
    }

    public static void main(String[] strArr) {
        try {
            LearnImportantFeatures learnImportantFeatures = new LearnImportantFeatures();
            Properties argsToPropertiesWithResolve = StringUtils.argsToPropertiesWithResolve(strArr);
            Execution.fillOptions(learnImportantFeatures, argsToPropertiesWithResolve);
            learnImportantFeatures.setUp();
            String property = argsToPropertiesWithResolve.getProperty("sentsFile");
            System.out.println("Read the sents file: " + property);
            Double.parseDouble(argsToPropertiesWithResolve.getProperty("perSelectRand"));
            Double.parseDouble(argsToPropertiesWithResolve.getProperty("perSelectNeg"));
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    static {
        $assertionsDisabled = !LearnImportantFeatures.class.desiredAssertionStatus();
    }
}
