/*
 * Decompiled with CFR 0.152.
 */
package com.lucidworks.spark.example.ml;

import com.lucidworks.spark.SparkApp;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Option;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.ml.PipelineModel;
import org.apache.spark.ml.tuning.CrossValidatorModel;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.types.StructType;

/**
 * Example application that scores documents one at a time with a previously
 * saved Spark ML cross-validated pipeline model.
 *
 * <p>The application reads the {@code content_txt} field of documents from the
 * {@code ml20news} collection through the {@code solr} data source, takes a
 * 10% sample (fixed seed), loads the model stored at {@code ml-pipeline-model}
 * and runs every sampled document through the best pipeline individually,
 * printing how long the scoring took.
 */
public class UseML
implements SparkApp.RDDProcessor {
    /** Location the {@link CrossValidatorModel} is loaded from. */
    private static final String MODEL_PATH = "ml-pipeline-model";

    /** Name of the field holding the text to classify. */
    private static final String TEXT_FIELD = "content_txt";

    /** Name of the column the pipeline is expected to add with its prediction. */
    private static final String PREDICTED_LABEL_FIELD = "predictedLabel";

    /**
     * Returns the name of this application.
     *
     * @return the constant {@code "use-ml"}
     */
    @Override
    public String getName() {
        return "use-ml";
    }

    /**
     * Returns the command-line options specific to this application.
     *
     * @return an empty array; this example defines no options of its own
     */
    @Override
    public Option[] getOptions() {
        return new Option[0];
    }

    /**
     * Loads a sample of documents, scores each one individually with the saved
     * model and reports timings on standard output.
     *
     * @param conf Spark configuration; {@code spark.ui.enabled} is set to
     *             {@code "false"} on it before the session is created
     * @param cli  parsed command line (not read by this example)
     * @return {@code 0} when all sampled documents were scored
     * @throws IllegalStateException if the model returns no row for a document
     * @throws Exception if reading the documents, loading the model or scoring fails
     */
    @Override
    public int run(SparkConf conf, CommandLine cli) throws Exception {
        long startMs = System.currentTimeMillis();
        conf.set("spark.ui.enabled", "false");
        SparkSession sparkSession = SparkSession.builder().config(conf).getOrCreate();
        long diffMs = System.currentTimeMillis() - startMs;
        System.out.println(">> took " + diffMs + " ms to create SparkSession");

        HashMap<String, String> options = new HashMap<>();
        options.put("zkhost", "localhost:9983");
        options.put("collection", "ml20news");
        options.put("query", TEXT_FIELD + ":[* TO *]");
        options.put("fields", TEXT_FIELD);
        Dataset<Row> solrData = sparkSession.read().format("solr").options(options).load();

        // Sample without replacement, 10% of the rows, with a fixed seed.
        Dataset<Row> sample = solrData.sample(false, 0.1, 5150L).select(TEXT_FIELD);
        List<Row> rows = sample.collectAsList();
        System.out.println(">> loaded " + rows.size() + " docs to classify");

        // The single-row data frames built below reuse the schema of the sample.
        StructType schema = sample.schema();
        CrossValidatorModel cvModel = CrossValidatorModel.load(MODEL_PATH);
        PipelineModel bestModel = (PipelineModel) cvModel.bestModel();

        startMs = System.currentTimeMillis();
        for (Row next : rows) {
            // The label is intentionally discarded: this example only measures
            // how long per-document scoring takes.
            predictLabel(sparkSession, bestModel, schema, next.getString(0));
        }
        diffMs = System.currentTimeMillis() - startMs;
        System.out.println(">> took " + diffMs + " ms to score " + rows.size() + " docs");
        return 0;
    }

    /**
     * Scores a single document by wrapping it in a one-row data frame and
     * running it through the given pipeline.
     *
     * @param sparkSession session used to build the one-row data frame
     * @param bestModel    pipeline used to transform the document
     * @param schema       schema of the one-row data frame
     * @param text         text of the document to classify
     * @return the value of the {@code predictedLabel} column of the first scored row
     * @throws IllegalStateException if the pipeline returns no rows
     */
    private static String predictLabel(SparkSession sparkSession, PipelineModel bestModel,
                                       StructType schema, String text) {
        // The cast keeps the single value from being mistaken for the varargs array.
        Row oneRow = RowFactory.create((Object) text);
        Dataset<Row> oneRowDF = sparkSession.createDataFrame(Collections.singletonList(oneRow), schema);
        List<Row> scored = bestModel.transform(oneRowDF).collectAsList();
        if (scored.isEmpty()) {
            throw new IllegalStateException("Model returned no rows for the document being scored");
        }
        Row scoredRow = scored.get(0);
        return scoredRow.getString(scoredRow.fieldIndex(PREDICTED_LABEL_FIELD));
    }
}

