package com.lucidworks.spark.example.ml;

import com.lucidworks.spark.SparkApp;
import com.lucidworks.spark.ml.feature.LuceneTextAnalyzerTransformer;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Option;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.Pipeline;
import org.apache.spark.ml.PipelineStage;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.NaiveBayes;
import org.apache.spark.ml.classification.OneVsRest;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.ml.feature.HashingTF;
import org.apache.spark.ml.feature.IndexToString;
import org.apache.spark.ml.feature.StringIndexer;
import org.apache.spark.ml.feature.StringIndexerModel;
import org.apache.spark.ml.tuning.CrossValidator;
import org.apache.spark.ml.tuning.CrossValidatorModel;
import org.apache.spark.ml.tuning.ParamGridBuilder;
import org.apache.spark.mllib.evaluation.MulticlassMetrics;
import org.apache.spark.mllib.linalg.Matrix;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import scala.collection.JavaConversions$;

/* loaded from: input_file:com/lucidworks/spark/example/ml/MLPipeline.class */
public class MLPipeline implements SparkApp.RDDProcessor {
    @Override // com.lucidworks.spark.SparkApp.RDDProcessor
    public String getName() {
        return "ml-pipeline";
    }

    @Override // com.lucidworks.spark.SparkApp.RDDProcessor
    public Option[] getOptions() {
        return new Option[]{Option.builder().hasArg().required(false).desc("Query to identify documents in the training set").longOpt("query").build(), Option.builder().hasArg().required(false).desc("Field in Solr containing the label for each document in the training set").longOpt("labelField").build(), Option.builder().hasArg().required(false).desc("Comma-separated list of field(s) in Solr containing the text content for each document in the training set").longOpt("contentFields").build(), Option.builder().hasArg().required(false).desc("Classifier type: either NaiveBayes or LogisticRegression").longOpt("classifier").build(), Option.builder().hasArg().required(false).desc("Fraction (0 to 1) of full dataset to sample from Solr, default is 1").longOpt("sample").build()};
    }

    @Override // com.lucidworks.spark.SparkApp.RDDProcessor
    public int run(SparkConf sparkConf, CommandLine commandLine) throws Exception {
        NaiveBayes naiveBayes;
        SQLContext sQLContext = new SQLContext(new JavaSparkContext(sparkConf));
        String optionValue = commandLine.getOptionValue("zkHost", "localhost:9983");
        String optionValue2 = commandLine.getOptionValue("collection", "ml20news");
        String optionValue3 = commandLine.getOptionValue("query", "content_txt:[* TO *] AND newsgroup_s:[* TO *]");
        String optionValue4 = commandLine.getOptionValue("labelField", "newsgroup_s");
        String optionValue5 = commandLine.getOptionValue("contentFields", "content_txt,subject");
        HashMap hashMap = new HashMap();
        hashMap.put("zkhost", optionValue);
        hashMap.put("collection", optionValue2);
        hashMap.put("query", optionValue3);
        hashMap.put("fields", "id," + optionValue4 + "," + optionValue5);
        DataFrame sample = sQLContext.read().format("solr").options(hashMap).load().sample(false, Double.parseDouble(commandLine.getOptionValue("sample", "1.0")));
        PipelineStage fit = new StringIndexer().setInputCol(optionValue4).setOutputCol("label").fit(sample);
        String[] split = optionValue5.split(",");
        for (int i = 0; i < split.length; i++) {
            split[i] = split[i].trim();
        }
        List asList = Arrays.asList(json("{ 'analyzers': [{ 'name': 'ws_tok', 'tokenizer': { 'type': 'whitespace' }}],\n'fields': [{ 'regex': '.+', 'analyzer': 'ws_tok' }]}\n"), json("{ 'analyzers': [{ 'name': 'std_tok_lower', 'tokenizer': { 'type': 'standard' },\n                'filters': [{ 'type': 'lowercase' }]}],\n  'fields': [{ 'regex': '.+', 'analyzer': 'std_tok_lower' }]}\n"));
        PipelineStage outputCol = new LuceneTextAnalyzerTransformer().setInputCols(split).setOutputCol("words");
        PipelineStage outputCol2 = new HashingTF().setInputCol("words").setOutputCol("features");
        if ("NaiveBayes".equals(commandLine.getOptionValue("classifier", "LogisticRegression"))) {
            naiveBayes = new NaiveBayes();
        } else {
            NaiveBayes classifier = new OneVsRest().setClassifier(new LogisticRegression().setMaxIter(10));
            classifier.setLabelCol("label");
            naiveBayes = classifier;
        }
        System.out.println("Using estimator: " + naiveBayes);
        PipelineStage labels = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(fit.labels());
        Pipeline stages = new Pipeline().setStages(new PipelineStage[]{fit, outputCol, outputCol2, naiveBayes, labels});
        DataFrame[] randomSplit = sample.randomSplit(new double[]{0.7d, 0.3d});
        DataFrame dataFrame = randomSplit[0];
        DataFrame dataFrame2 = randomSplit[1];
        MulticlassClassificationEvaluator metricName = new MulticlassClassificationEvaluator().setLabelCol("label").setPredictionCol("prediction").setMetricName("precision");
        ParamGridBuilder addGrid = new ParamGridBuilder().addGrid(outputCol2.numFeatures(), new int[]{1000, 5000}).addGrid(outputCol.analysisSchema(), JavaConversions$.MODULE$.asScalaIterable(asList)).addGrid(outputCol.prefixTokensWithInputCol());
        if (naiveBayes instanceof LogisticRegression) {
            addGrid.addGrid(((LogisticRegression) naiveBayes).regParam(), new double[]{0.1d, 0.01d});
        } else if (naiveBayes instanceof NaiveBayes) {
            addGrid.addGrid(naiveBayes.smoothing(), new double[]{1.0d, 0.5d});
        }
        CrossValidatorModel fit2 = new CrossValidator().setEstimator(stages).setEvaluator(metricName).setEstimatorParamMaps(addGrid.build()).setNumFolds(3).fit(dataFrame);
        System.out.println("Best model params: " + Arrays.toString(fit2.bestModel().params()));
        fit2.write().overwrite().save("ml-pipeline-model");
        DataFrame transform = CrossValidatorModel.load("ml-pipeline-model").transform(dataFrame2);
        transform.cache();
        System.out.println("Cross-Fold Test Error = " + (1.0d - metricName.evaluate(transform)));
        for (Row row : transform.select("id", new String[]{optionValue4, "predictedLabel"}).sample(false, 0.1d).collect()) {
            System.out.println(row.get(0) + ": actual=" + row.get(1) + ", predicted=" + row.get(2));
        }
        MulticlassMetrics multiclassMetrics = new MulticlassMetrics(transform.select("prediction", new String[]{"label"}));
        Matrix confusionMatrix = multiclassMetrics.confusionMatrix();
        System.out.println("Confusion Matrix");
        System.out.println(confusionMatrix);
        System.out.println();
        System.out.println("F-Measure: " + multiclassMetrics.fMeasure());
        System.out.println("label\tfpr\n");
        String[] labels2 = labels.getLabels();
        for (int i2 = 0; i2 < labels2.length; i2++) {
            System.out.print(labels2[i2]);
            System.out.print("\t");
            System.out.print(multiclassMetrics.falsePositiveRate(i2));
            System.out.println();
        }
        return 0;
    }

    private String json(String str) {
        return str.replaceAll("'", "\"");
    }
}
