package com.lucidworks.spark.example.ml;

import com.lucidworks.spark.SparkApp;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Option;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.tuning.CrossValidatorModel;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.StructType;

/* loaded from: input_file:com/lucidworks/spark/example/ml/UseML.class */
/**
 * Example {@link SparkApp.RDDProcessor} that loads a sample of documents from a Solr collection
 * and scores them one at a time with a previously saved cross-validated ML pipeline, printing
 * how long context creation and scoring took.
 */
public class UseML implements SparkApp.RDDProcessor {
    /** ZooKeeper connect string passed to the "solr" data source. */
    private static final String ZK_HOST = "localhost:9983";
    /** Solr collection the documents are read from. */
    private static final String COLLECTION = "ml20news";
    /** Text field that is both required by the query and fed to the model. */
    private static final String TEXT_FIELD = "content_txt";
    /** Filesystem path the CrossValidatorModel was saved to. */
    private static final String MODEL_PATH = "ml-pipeline-model";
    /** Output column read from the transformed row. */
    private static final String PREDICTION_COLUMN = "predictedLabel";
    /** Fraction of matching documents to keep (sampled without replacement). */
    private static final double SAMPLE_FRACTION = 0.1d;
    /** Fixed seed so the sample is reproducible across runs. */
    private static final long SAMPLE_SEED = 5150L;

    @Override // com.lucidworks.spark.SparkApp.RDDProcessor
    public String getName() {
        return "use-ml";
    }

    @Override // com.lucidworks.spark.SparkApp.RDDProcessor
    public Option[] getOptions() {
        return new Option[0];
    }

    /**
     * Loads a sample of documents from Solr and scores each one individually with the best model
     * of the saved CrossValidatorModel.
     *
     * @param sparkConf Spark configuration; the Spark UI is disabled on it before use
     * @param commandLine parsed command line (unused; this processor declares no options)
     * @return always {@code 0}
     * @throws Exception if reading from Solr, loading the model, or scoring fails
     */
    @Override // com.lucidworks.spark.SparkApp.RDDProcessor
    public int run(SparkConf sparkConf, CommandLine commandLine) throws Exception {
        long contextStartMs = System.currentTimeMillis();
        sparkConf.set("spark.ui.enabled", "false");
        JavaSparkContext jsc = new JavaSparkContext(sparkConf);
        try {
            SQLContext sqlContext = new SQLContext(jsc);
            System.out.println(">> took " + (System.currentTimeMillis() - contextStartMs) + " ms to create SQLContext");

            DataFrame docs = loadSample(sqlContext);
            List<Row> rows = docs.collectAsList();
            System.out.println(">> loaded " + rows.size() + " docs to classify");

            // The saved model's bestModel() is typed Model<?>, so the cast is required.
            PipelineModel bestModel = (PipelineModel) CrossValidatorModel.load(MODEL_PATH).bestModel();

            long scoringStartMs = System.currentTimeMillis();
            scoreOneAtATime(sqlContext, bestModel, rows, docs.schema());
            System.out.println(">> took " + (System.currentTimeMillis() - scoringStartMs) + " ms to score " + rows.size() + " docs");
            return 0;
        } finally {
            // This processor created the context, so it must also release it.
            jsc.stop();
        }
    }

    /** Reads a reproducible sample of the text field from the configured Solr collection. */
    private static DataFrame loadSample(SQLContext sqlContext) {
        Map<String, String> solrOptions = new HashMap<String, String>();
        solrOptions.put("zkhost", ZK_HOST);
        solrOptions.put("collection", COLLECTION);
        solrOptions.put("query", TEXT_FIELD + ":[* TO *]");
        solrOptions.put("fields", TEXT_FIELD);
        return sqlContext.read()
                .format("solr")
                .options(solrOptions)
                .load()
                .sample(false, SAMPLE_FRACTION, SAMPLE_SEED)
                .select(TEXT_FIELD, new String[0]);
    }

    /**
     * Transforms each row through the model as its own single-row DataFrame.
     *
     * Scoring per document (rather than one batched transform) is kept as in the original, so the
     * printed time reflects per-document scoring cost.
     */
    private static void scoreOneAtATime(SQLContext sqlContext, PipelineModel model,
                                        List<Row> rows, StructType schema) {
        Iterator<Row> it = rows.iterator();
        while (it.hasNext()) {
            Row input = RowFactory.create(new Object[]{it.next().getString(0)});
            DataFrame single = sqlContext.createDataFrame(Collections.singletonList(input), schema);
            Row scored = model.transform(single).collect()[0];
            // The label is looked up but discarded; fieldIndex throws if the column is absent.
            scored.getString(scored.fieldIndex(PREDICTION_COLUMN));
        }
    }
}
