/*
 * Decompiled with CFR 0.152.
 */
package org.apache.lucene.classification;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.lucene.classification.ClassificationResult;
import org.apache.lucene.classification.Classifier;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.MultiTerms;
import org.apache.lucene.index.Term;
import org.apache.lucene.index.Terms;
import org.apache.lucene.index.TermsEnum;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.WildcardQuery;
import org.apache.lucene.util.BytesRef;

/**
 * A naive Bayes classifier whose statistics are read directly from an index.
 *
 * <p>Every non-empty term of {@code classFieldName} is a candidate class. A class is scored as the
 * log of its prior (documents containing that class term over documents having any class) plus,
 * for each analyzed token of the input, the log of the smoothed word probability
 * {@code (hits + 1) / (textTermFreqForClass + docsWithClass)}. The per-class log scores are then
 * normalized with a log-sum-exp so that the returned scores sum to one.
 */
public class SimpleNaiveBayesClassifier
implements Classifier<BytesRef> {
    /** Reader from which class and text statistics are read. */
    protected final IndexReader indexReader;
    /** Fields whose analyzed text provides the features; the input is tokenized once per field. */
    protected final String[] textFieldNames;
    /** Field whose (non-empty) terms are the class labels. */
    protected final String classFieldName;
    /** Analyzer used to tokenize the input text. */
    protected final Analyzer analyzer;
    /** Searcher over {@link #indexReader}, used for counting queries. */
    protected final IndexSearcher indexSearcher;
    /**
     * Optional filter; when non-null it is added as a MUST clause to the word/class count query and
     * to the fallback "documents with a class" count query. May be {@code null}.
     */
    protected final Query query;

    /**
     * Creates a classifier over the given index.
     *
     * @param indexReader reader to read statistics from
     * @param analyzer analyzer used to tokenize input text
     * @param query optional filter query, or {@code null} for none
     * @param classFieldName field holding the class labels
     * @param textFieldNames fields holding the text used as features
     */
    public SimpleNaiveBayesClassifier(IndexReader indexReader, Analyzer analyzer, Query query, String classFieldName, String ... textFieldNames) {
        this.indexReader = indexReader;
        this.indexSearcher = new IndexSearcher(this.indexReader);
        this.textFieldNames = textFieldNames;
        this.classFieldName = classFieldName;
        this.analyzer = analyzer;
        this.query = query;
    }

    /**
     * Returns the class with the highest normalized score.
     *
     * @param inputDocument text to classify
     * @return the best-scoring class; the first one wins on ties; {@code null} if there are no
     *     classes (or no score exceeds {@code -Double.MAX_VALUE}, e.g. all scores are NaN)
     * @throws IOException if reading the index fails
     */
    @Override
    public ClassificationResult<BytesRef> assignClass(String inputDocument) throws IOException {
        ClassificationResult<BytesRef> assignedClass = null;
        double maxScore = -Double.MAX_VALUE;
        for (ClassificationResult<BytesRef> candidate : this.assignClassNormalizedList(inputDocument)) {
            // strict '>' keeps the first of several equally scored classes
            if (candidate.getScore() > maxScore) {
                assignedClass = candidate;
                maxScore = candidate.getScore();
            }
        }
        return assignedClass;
    }

    /**
     * Returns all classes with their normalized scores, sorted by
     * {@link ClassificationResult}'s natural ordering.
     *
     * @param text text to classify
     * @return the sorted list of classes (empty if the class field has no terms)
     * @throws IOException if reading the index fails
     */
    @Override
    public List<ClassificationResult<BytesRef>> getClasses(String text) throws IOException {
        List<ClassificationResult<BytesRef>> assignedClasses = this.assignClassNormalizedList(text);
        Collections.sort(assignedClasses);
        return assignedClasses;
    }

    /**
     * Returns at most {@code max} classes with their normalized scores, sorted by
     * {@link ClassificationResult}'s natural ordering.
     *
     * @param text text to classify
     * @param max maximum number of results; larger values simply return every class. A negative
     *     value is rejected by {@link List#subList(int, int)}.
     * @return a view of the first {@code min(max, numberOfClasses)} sorted results
     * @throws IOException if reading the index fails
     */
    @Override
    public List<ClassificationResult<BytesRef>> getClasses(String text, int max) throws IOException {
        List<ClassificationResult<BytesRef>> assignedClasses = this.assignClassNormalizedList(text);
        Collections.sort(assignedClasses);
        // clamp so that asking for more results than there are classes does not throw
        return assignedClasses.subList(0, Math.min(max, assignedClasses.size()));
    }

    /**
     * Scores every non-empty term of the class field against the input and normalizes the scores.
     *
     * @param inputDocument text to classify
     * @return normalized results, one per non-empty class term (empty if the field has no terms)
     * @throws IOException if reading the index fails
     */
    protected List<ClassificationResult<BytesRef>> assignClassNormalizedList(String inputDocument) throws IOException {
        List<ClassificationResult<BytesRef>> assignedClasses = new ArrayList<ClassificationResult<BytesRef>>();
        Terms classes = MultiTerms.getTerms(this.indexReader, this.classFieldName);
        if (classes != null) {
            TermsEnum classesEnum = classes.iterator();
            String[] tokenizedText = this.tokenize(inputDocument);
            int docsWithClassSize = this.countDocsWithClass();
            BytesRef next;
            while ((next = classesEnum.next()) != null) {
                if (next.length > 0) {
                    // TermsEnum may reuse the BytesRef it returns and Term does not copy it, so take a
                    // private copy before it is stored in a result that outlives this iteration.
                    Term term = new Term(this.classFieldName, BytesRef.deepCopyOf(next));
                    double clVal = this.calculateLogPrior(term, docsWithClassSize)
                            + this.calculateLogLikelihood(tokenizedText, term, docsWithClassSize);
                    assignedClasses.add(new ClassificationResult<BytesRef>(term.bytes(), clVal));
                }
            }
        }
        return this.normClassificationResults(assignedClasses);
    }

    /**
     * Counts the documents that have a value in the class field.
     *
     * <p>Uses the field's doc count when available; otherwise (no terms, or doc count reported as
     * {@code -1}) it runs a wildcard count query, filtered by {@link #query} when set.
     *
     * @return number of documents with a class
     * @throws IOException if reading the index fails
     */
    protected int countDocsWithClass() throws IOException {
        Terms terms = MultiTerms.getTerms(this.indexReader, this.classFieldName);
        if (terms != null && terms.getDocCount() != -1) {
            return terms.getDocCount();
        }
        BooleanQuery.Builder q = new BooleanQuery.Builder();
        q.add(new BooleanClause(new WildcardQuery(new Term(this.classFieldName, "*")), BooleanClause.Occur.MUST));
        if (this.query != null) {
            q.add(this.query, BooleanClause.Occur.MUST);
        }
        return this.indexSearcher.count(q.build());
    }

    /**
     * Analyzes {@code text} once per text field and concatenates the resulting tokens.
     *
     * @param text input text
     * @return tokens of every field's analysis, in field order
     * @throws IOException if analysis fails
     */
    protected String[] tokenize(String text) throws IOException {
        List<String> result = new ArrayList<String>();
        for (String textFieldName : this.textFieldNames) {
            try (TokenStream tokenStream = this.analyzer.tokenStream(textFieldName, text)) {
                CharTermAttribute charTermAttribute = tokenStream.addAttribute(CharTermAttribute.class);
                tokenStream.reset();
                while (tokenStream.incrementToken()) {
                    result.add(charTermAttribute.toString());
                }
                tokenStream.end();
            }
        }
        return result.toArray(new String[0]);
    }

    /**
     * Sums, over all tokens, the log of the add-one smoothed probability of the token given the class.
     */
    private double calculateLogLikelihood(String[] tokenizedText, Term term, int docsWithClass) throws IOException {
        // The denominator depends only on the class, not on the word, so compute it once.
        double den = this.getTextTermFreqForClass(term) + (double) docsWithClass;
        double result = 0.0;
        for (String word : tokenizedText) {
            int hits = this.getWordFreqForClass(word, term);
            double num = hits + 1;
            result += Math.log(num / den);
        }
        return result;
    }

    /**
     * Estimates the number of text terms in documents of the given class. It multiplies the summed
     * per-field average of unique terms per document ({@code sumDocFreq / docCount}) by the class's
     * document frequency.
     */
    private double getTextTermFreqForClass(Term term) throws IOException {
        double avgNumberOfUniqueTerms = 0.0;
        for (String textFieldName : this.textFieldNames) {
            Terms terms = MultiTerms.getTerms(this.indexReader, textFieldName);
            if (terms == null) {
                // field has no indexed terms: it contributes nothing rather than throwing an NPE
                continue;
            }
            int fieldDocCount = terms.getDocCount();
            if (fieldDocCount <= 0) {
                // avoid dividing by zero or by the -1 "unknown" marker
                continue;
            }
            avgNumberOfUniqueTerms += (double) terms.getSumDocFreq() / (double) fieldDocCount;
        }
        int docsWithC = this.indexReader.docFreq(term);
        return avgNumberOfUniqueTerms * (double) docsWithC;
    }

    /**
     * Counts documents that contain {@code word} in any text field and have the given class,
     * restricted by {@link #query} when set.
     */
    private int getWordFreqForClass(String word, Term term) throws IOException {
        BooleanQuery.Builder anyTextField = new BooleanQuery.Builder();
        for (String textFieldName : this.textFieldNames) {
            anyTextField.add(new BooleanClause(new TermQuery(new Term(textFieldName, word)), BooleanClause.Occur.SHOULD));
        }
        BooleanQuery.Builder booleanQuery = new BooleanQuery.Builder();
        booleanQuery.add(new BooleanClause(anyTextField.build(), BooleanClause.Occur.MUST));
        booleanQuery.add(new BooleanClause(new TermQuery(term), BooleanClause.Occur.MUST));
        if (this.query != null) {
            booleanQuery.add(this.query, BooleanClause.Occur.MUST);
        }
        return this.indexSearcher.count(booleanQuery.build());
    }

    /** Returns {@code log(docFreq(class)) - log(docsWithClassSize)}. */
    private double calculateLogPrior(Term term, int docsWithClassSize) throws IOException {
        return Math.log(this.docCount(term)) - Math.log(docsWithClassSize);
    }

    /** Returns the document frequency of the class term. */
    private int docCount(Term term) throws IOException {
        return this.indexReader.docFreq(term);
    }

    /**
     * Converts per-class log scores into normalized scores with a log-sum-exp:
     * {@code exp(score - (smax + log(sum(exp(score_i - smax)))))}.
     *
     * <p>Sorts {@code assignedClasses} in place and takes its first element as {@code smax}. This
     * presumably relies on the natural ordering placing the highest score first; verify against
     * {@link ClassificationResult#compareTo}. The subtraction only aids numerical stability.
     *
     * @param assignedClasses per-class log scores; sorted in place
     * @return normalized results in the sorted order (empty if the input is empty)
     */
    protected ArrayList<ClassificationResult<BytesRef>> normClassificationResults(List<ClassificationResult<BytesRef>> assignedClasses) {
        ArrayList<ClassificationResult<BytesRef>> returnList = new ArrayList<ClassificationResult<BytesRef>>();
        if (assignedClasses.isEmpty()) {
            return returnList;
        }
        Collections.sort(assignedClasses);
        double smax = assignedClasses.get(0).getScore();
        double sumExp = 0.0;
        for (ClassificationResult<BytesRef> cr : assignedClasses) {
            sumExp += Math.exp(cr.getScore() - smax);
        }
        double logNormalizer = smax + Math.log(sumExp);
        for (ClassificationResult<BytesRef> cr : assignedClasses) {
            returnList.add(new ClassificationResult<BytesRef>(cr.getAssignedClass(), Math.exp(cr.getScore() - logNormalizer)));
        }
        return returnList;
    }
}

