/*
 * Decompiled with CFR 0.152.
 */
package org.apache.solr.handler;

import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import org.apache.lucene.analysis.Analyzer;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.comp.StreamComparator;
import org.apache.solr.client.solrj.io.stream.StreamContext;
import org.apache.solr.client.solrj.io.stream.TupleStream;
import org.apache.solr.client.solrj.io.stream.expr.Explanation;
import org.apache.solr.client.solrj.io.stream.expr.Expressible;
import org.apache.solr.client.solrj.io.stream.expr.StreamExplanation;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionParameter;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
import org.apache.solr.common.SolrException;
import org.apache.solr.core.SolrCore;

public class ClassifyStream
extends TupleStream
implements Expressible {
    private TupleStream docStream;
    private TupleStream modelStream;
    private String field;
    private String analyzerField;
    private Tuple modelTuple;
    Analyzer analyzer;
    private Map<CharSequence, Integer> termToIndex;
    private List<Double> idfs;
    private List<Double> modelWeights;

    public ClassifyStream(StreamExpression expression, StreamFactory factory) throws IOException {
        List<StreamExpression> streamExpressions = factory.getExpressionOperandsRepresentingTypes(expression, Expressible.class, TupleStream.class);
        if (streamExpressions.size() != 2) {
            throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - expecting two stream but found %d", expression, streamExpressions.size()));
        }
        this.modelStream = factory.constructStream(streamExpressions.get(0));
        this.docStream = factory.constructStream(streamExpressions.get(1));
        StreamExpressionNamedParameter fieldParameter = factory.getNamedOperand(expression, "field");
        if (fieldParameter == null) {
            throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - field parameter must be specified", expression, streamExpressions.size()));
        }
        this.analyzerField = this.field = fieldParameter.getParameter().toString();
        StreamExpressionNamedParameter analyzerFieldParameter = factory.getNamedOperand(expression, "analyzerField");
        if (analyzerFieldParameter != null) {
            this.analyzerField = analyzerFieldParameter.getParameter().toString();
        }
    }

    @Override
    public void setStreamContext(StreamContext context) {
        Object solrCoreObj = context.get("solr-core");
        if (solrCoreObj == null || !(solrCoreObj instanceof SolrCore)) {
            throw new SolrException(SolrException.ErrorCode.INVALID_STATE, "StreamContext must have SolrCore in solr-core key");
        }
        this.analyzer = ((SolrCore)solrCoreObj).getLatestSchema().getFieldType(this.analyzerField).getIndexAnalyzer();
        this.docStream.setStreamContext(context);
        this.modelStream.setStreamContext(context);
    }

    @Override
    public List<TupleStream> children() {
        ArrayList<TupleStream> l = new ArrayList<TupleStream>();
        l.add(this.docStream);
        l.add(this.modelStream);
        return l;
    }

    @Override
    public void open() throws IOException {
        this.docStream.open();
        this.modelStream.open();
    }

    @Override
    public void close() throws IOException {
        this.docStream.close();
        this.modelStream.close();
    }

    @Override
    public Tuple read() throws IOException {
        if (this.modelTuple == null) {
            this.modelTuple = this.modelStream.read();
            if (this.modelTuple == null || this.modelTuple.EOF) {
                throw new IOException("Model tuple not found for classify stream!");
            }
            this.termToIndex = new HashMap<CharSequence, Integer>();
            List<String> terms = this.modelTuple.getStrings("terms_ss");
            for (int i = 0; i < terms.size(); ++i) {
                this.termToIndex.put(terms.get(i), i);
            }
            this.idfs = this.modelTuple.getDoubles("idfs_ds");
            this.modelWeights = this.modelTuple.getDoubles("weights_ds");
        }
        Tuple docTuple = this.docStream.read();
        if (docTuple.EOF) {
            return docTuple;
        }
        String text = docTuple.getString(this.field);
        double[] tfs = new double[this.termToIndex.size()];
        TokenStream tokenStream = this.analyzer.tokenStream(this.analyzerField, text);
        CharTermAttribute termAtt = (CharTermAttribute)tokenStream.getAttribute(CharTermAttribute.class);
        tokenStream.reset();
        int termCount = 0;
        while (tokenStream.incrementToken()) {
            ++termCount;
            if (!this.termToIndex.containsKey(termAtt.toString())) continue;
            int n = this.termToIndex.get(termAtt.toString());
            tfs[n] = tfs[n] + 1.0;
        }
        tokenStream.end();
        tokenStream.close();
        ArrayList<Double> tfidfs = new ArrayList<Double>(this.termToIndex.size());
        tfidfs.add(1.0);
        for (int i = 0; i < tfs.length; ++i) {
            if (tfs[i] != 0.0) {
                tfs[i] = 1.0 + Math.log(tfs[i]);
            }
            tfidfs.add(this.idfs.get(i) * tfs[i]);
        }
        double total = 0.0;
        for (int i = 0; i < tfidfs.size(); ++i) {
            total += (Double)tfidfs.get(i) * this.modelWeights.get(i);
        }
        double score = total * (double)((float)(1.0 / Math.sqrt(termCount)));
        double positiveProb = this.sigmoid(total);
        docTuple.put("probability_d", positiveProb);
        docTuple.put("score_d", score);
        return docTuple;
    }

    private double sigmoid(double in) {
        double d = 1.0 / (1.0 + Math.exp(-in));
        return d;
    }

    @Override
    public StreamComparator getStreamSort() {
        return null;
    }

    @Override
    public StreamExpressionParameter toExpression(StreamFactory factory) throws IOException {
        return this.toExpression(factory, true);
    }

    private StreamExpression toExpression(StreamFactory factory, boolean includeStreams) throws IOException {
        StreamExpression expression = new StreamExpression(factory.getFunctionName(this.getClass()));
        if (includeStreams) {
            if (this.docStream instanceof Expressible && this.modelStream instanceof Expressible) {
                expression.addParameter(((Expressible)((Object)this.modelStream)).toExpression(factory));
                expression.addParameter(((Expressible)((Object)this.docStream)).toExpression(factory));
            } else {
                throw new IOException("This ClassifyStream contains a non-expressible TupleStream - it cannot be converted to an expression");
            }
        }
        expression.addParameter(new StreamExpressionNamedParameter("field", this.field));
        expression.addParameter(new StreamExpressionNamedParameter("analyzerField", this.analyzerField));
        return expression;
    }

    @Override
    public Explanation toExplanation(StreamFactory factory) throws IOException {
        StreamExplanation explanation = new StreamExplanation(this.getStreamNodeId().toString());
        explanation.setFunctionName(factory.getFunctionName(this.getClass()));
        explanation.setImplementingClass(this.getClass().getName());
        explanation.setExpressionType("stream-decorator");
        explanation.setExpression(this.toExpression(factory, false).toString());
        explanation.addChild(this.docStream.toExplanation(factory));
        explanation.addChild(this.modelStream.toExplanation(factory));
        return explanation;
    }
}

