package com.lucidworks.spark.fusion;

import com.fasterxml.jackson.databind.ObjectMapper;
import java.io.BufferedInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.net.URLEncoder;
import java.nio.charset.StandardCharsets;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Date;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.compress.archivers.ArchiveException;
import org.apache.commons.compress.archivers.ArchiveOutputStream;
import org.apache.commons.compress.archivers.ArchiveStreamFactory;
import org.apache.commons.compress.archivers.zip.ZipArchiveEntry;
import org.apache.commons.compress.utils.IOUtils;
import org.apache.commons.io.FileUtils;
import org.apache.http.HttpEntity;
import org.apache.http.client.entity.EntityBuilder;
import org.apache.http.client.methods.HttpPut;
import org.apache.http.client.utils.URIBuilder;
import org.apache.http.entity.ContentType;
import org.apache.http.message.BasicNameValuePair;
import org.apache.http.util.EntityUtils;
import org.apache.log4j.Logger;
import org.apache.spark.SparkContext;
import org.apache.spark.ml.util.MLWritable;
import org.apache.spark.mllib.util.Saveable;

/* loaded from: input_file:com/lucidworks/spark/fusion/FusionMLModelSupport.class */
/**
 * Static helpers that persist a Spark ML / MLlib model to disk, bundle it together with a JSON
 * model spec into a zip archive, and upload that archive to a Fusion blob endpoint with an HTTP
 * PUT.
 *
 * <p>All methods work on the local file system relative to the current working directory: the
 * model is saved to a directory named after the model id and the archive is written next to it
 * as {@code <modelId>.zip}. Neither is deleted after the upload.
 */
public class FusionMLModelSupport {
    public static Logger log = Logger.getLogger(FusionMLModelSupport.class);

    /** Address used by {@link #saveModelInLocalFusion}. */
    private static final String LOCAL_FUSION_HOST_AND_PORT = "localhost:8765";

    /** Port used when the Fusion address is not of the form {@code host:port}. */
    private static final int DEFAULT_FUSION_PORT = 8764;

    private static final String MODEL_TYPE_KEY = "modelType";
    private static final String ANALYZER_JSON_KEY = "analyzerJson";
    private static final String SPARK_MLLIB = "spark-mllib";
    private static final String SPARK_ML = "spark-ml";

    /**
     * Saves the model locally, zips it and PUTs the archive to
     * {@code http://<host>:<port>/api/apollo/blobs/<modelId>}.
     *
     * @param str Fusion address, either {@code host} or {@code host:port}
     * @param str2 forwarded as the second constructor argument of {@code FusionPipelineClient}
     *     (presumably the user name; confirm against that class)
     * @param str3 forwarded as the third constructor argument of {@code FusionPipelineClient}
     *     (presumably the password; confirm against that class)
     * @param str4 forwarded as the fourth constructor argument of {@code FusionPipelineClient}
     *     (presumably the security realm; confirm against that class)
     * @param sparkContext context used to save {@link Saveable} models
     * @param str5 model id; names the local model directory, the zip archive and the blob
     * @param obj the model; must implement {@link Saveable} or {@link MLWritable}
     * @param map model metadata, see {@link #buildModelArchive}; it is copied, never modified,
     *     and {@code null} is treated as empty
     * @throws Exception if saving, archiving or uploading the model fails
     */
    public static void saveModelInFusion(String str, String str2, String str3, String str4, SparkContext sparkContext, String str5, Object obj, Map<String, String> map) throws Exception {
        HashMap<String, String> metadata = copyMetadata(map);
        // The archive must be built first: building it rewrites the metadata that becomes the
        // query string of the PUT request.
        File archive = buildModelArchive(sparkContext, str5, obj, metadata);
        HttpPut request = buildPutRequestToFusion(str5, str, metadata, archive, "/api/apollo");
        FusionPipelineClient client = new FusionPipelineClient(request.getRequestLine().getUri(), str2, str3, str4);
        consumeQuietly(client.sendRequestToFusion(request));
    }

    /**
     * Saves the model locally, zips it and PUTs the archive to
     * {@code http://localhost:8765/api/v1/blobs/<modelId>} using a {@code FusionPipelineClient}
     * created from the request URI only.
     *
     * @param sparkContext context used to save {@link Saveable} models
     * @param str model id; names the local model directory, the zip archive and the blob
     * @param obj the model; must implement {@link Saveable} or {@link MLWritable}
     * @param map model metadata, see {@link #buildModelArchive}; it is copied, never modified,
     *     and {@code null} is treated as empty
     * @throws Exception if saving, archiving or uploading the model fails
     */
    public static void saveModelInLocalFusion(SparkContext sparkContext, String str, Object obj, Map<String, String> map) throws Exception {
        HashMap<String, String> metadata = copyMetadata(map);
        File archive = buildModelArchive(sparkContext, str, obj, metadata);
        HttpPut request = buildPutRequestToFusion(str, LOCAL_FUSION_HOST_AND_PORT, metadata, archive, "/api/v1");
        FusionPipelineClient client = new FusionPipelineClient(request.getRequestLine().getUri());
        consumeQuietly(client.sendRequestToFusion(request));
    }

    /**
     * Builds the PUT request that uploads a model archive as
     * {@code http://<host>:<port><apiPath>/blobs/<modelId>?<metadata>}.
     *
     * @param str model id, appended to the path after {@code /blobs/}
     * @param str2 Fusion address; when it splits on {@code ':'} into exactly two parts they are
     *     used as host and port, otherwise the whole string is the host and port 8764 is used
     * @param hashMap metadata sent as query parameters; values must not be {@code null}
     * @param file zip archive used as the request body
     * @param str3 API path prefix, e.g. {@code /api/v1}
     * @return the configured request with content type {@code application/zip}
     * @throws IllegalArgumentException if a metadata value is {@code null}
     * @throws NumberFormatException if the port part of {@code str2} is not an integer
     * @throws Exception if the URI cannot be built
     */
    public static HttpPut buildPutRequestToFusion(String str, String str2, HashMap<String, String> hashMap, File file, String str3) throws Exception {
        String host = str2;
        int port = DEFAULT_FUSION_PORT;
        String[] hostAndPort = str2.split(":");
        if (hostAndPort.length == 2) {
            host = hostAndPort[0];
            port = Integer.parseInt(hostAndPort[1]);
        }

        URIBuilder uriBuilder = new URIBuilder()
            .setScheme("http")
            .setHost(host)
            .setPort(port)
            .setPath(str3 + "/blobs/" + str);
        for (Map.Entry<String, String> entry : hashMap.entrySet()) {
            String value = entry.getValue();
            if (value == null) {
                throw new IllegalArgumentException("Metadata entry '" + entry.getKey() + "' has no value and cannot be sent to Fusion");
            }
            // NOTE(review): URIBuilder encodes query parameters again when the URI is built, so
            // values end up URL-encoded twice on the wire. Kept as-is because the receiving side
            // may rely on it -- confirm before removing the explicit encode.
            uriBuilder.addParameter(entry.getKey(), URLEncoder.encode(value, "UTF-8"));
        }

        HttpPut request = new HttpPut(uriBuilder.build());
        request.setHeader("Content-Type", "application/zip");
        request.setEntity(EntityBuilder.create()
            .setContentType(ContentType.create("application/zip"))
            .setFile(file)
            .build());
        return request;
    }

    /**
     * Saves the model into a local directory named after the model id, writes a
     * {@code <modelType>.json} spec into that directory and zips the directory into
     * {@code <modelId>.zip}.
     *
     * <p>The metadata map is modified in place:
     * <ul>
     *   <li>{@code modelType} is set to {@code spark-mllib} for {@link Saveable} models or
     *       {@code spark-ml} for {@link MLWritable} models when the caller did not provide it;</li>
     *   <li>{@code modelClassName} and {@code featureFields} (comma separated) are removed;
     *       the latter is copied into the spec as a list;</li>
     *   <li>for {@code spark-mllib}, the vectorizer settings ({@code analyzerJson},
     *       {@code numFeatures}, and the optional {@code normalizer}, {@code standardscaler} and
     *       {@code chisqselector} groups) are moved from the metadata into the spec;</li>
     *   <li>{@code modelSpec} is set to the file name of the generated spec.</li>
     * </ul>
     *
     * @param sparkContext context used to save {@link Saveable} models
     * @param str model id
     * @param obj the model; must implement {@link Saveable} or {@link MLWritable}
     * @param hashMap mutable metadata, updated as described above
     * @return the zip archive containing the saved model and its spec
     * @throws IllegalArgumentException if the model is {@code null} or of an unsupported type, or
     *     if the model type is {@code spark-mllib} and {@code analyzerJson} is missing
     * @throws Exception if saving the model, writing the spec or creating the archive fails
     */
    public static File buildModelArchive(SparkContext sparkContext, String str, Object obj, HashMap<String, String> hashMap) throws Exception {
        // Validate everything up front so bad input neither rotates the model directory nor
        // triggers a (potentially expensive) Spark save.
        boolean isMllibModel = obj instanceof Saveable;
        if (!isMllibModel && !(obj instanceof MLWritable)) {
            String actualType = (obj == null) ? "null" : obj.getClass().getName();
            throw new IllegalArgumentException("Provided ML model of type " + actualType + " does not implement " + Saveable.class.getName() + " or " + MLWritable.class.getName() + "!");
        }
        String modelType = hashMap.get(MODEL_TYPE_KEY);
        if (modelType == null) {
            modelType = isMllibModel ? SPARK_MLLIB : SPARK_ML;
        }
        if (SPARK_MLLIB.equals(modelType) && hashMap.get(ANALYZER_JSON_KEY) == null) {
            throw new IllegalArgumentException("Metadata entry '" + ANALYZER_JSON_KEY + "' is required for models of type " + SPARK_MLLIB);
        }

        File modelDir = getModelDir(str);
        if (isMllibModel) {
            ((Saveable) obj).save(sparkContext, modelDir.getAbsolutePath());
        } else {
            ((MLWritable) obj).write().overwrite().save(modelDir.getAbsolutePath());
        }
        hashMap.put(MODEL_TYPE_KEY, modelType);

        // LinkedHashMap keeps the top-level spec keys in a stable, readable order.
        Map<String, Object> modelSpec = new LinkedHashMap<>();
        modelSpec.put("id", str);
        modelSpec.put("modelType", modelType);
        modelSpec.put("modelClassName", obj.getClass().getName());
        hashMap.remove("modelClassName");
        String featureFields = hashMap.remove("featureFields");
        if (featureFields != null) {
            modelSpec.put("featureFields", Arrays.asList(featureFields.split(",")));
        }

        ObjectMapper objectMapper = new ObjectMapper();
        if (SPARK_MLLIB.equals(modelType)) {
            modelSpec.put("vectorizer", buildMllibVectorizerSpec(objectMapper, hashMap));
        }

        File specFile = new File(modelDir, modelType + ".json");
        // Jackson closes the writer itself by default; try-with-resources additionally releases
        // the file handle when serialization fails and surfaces I/O errors instead of hiding a
        // truncated spec.
        try (OutputStreamWriter writer = new OutputStreamWriter(new FileOutputStream(specFile), StandardCharsets.UTF_8)) {
            objectMapper.writeValue(writer, modelSpec);
        }
        hashMap.put("modelSpec", specFile.getName());

        File archive = new File(str + ".zip");
        if (archive.isFile() && !archive.delete()) {
            log.warn("Failed to delete existing model archive " + archive.getAbsolutePath() + "; attempting to overwrite it");
        }
        addFilesToZip(modelDir, archive);
        return archive;
    }

    /**
     * Returns the directory the model with the given id is saved to, creating it if needed. An
     * already existing directory is first renamed to {@code <modelId>-bak-<yyMMddHHmmss>}.
     * Failures to rename or create the directory are logged, not thrown.
     *
     * @param str model id, used verbatim as the directory path
     * @return the model directory
     */
    protected static File getModelDir(String str) {
        File modelDir = new File(str);
        if (modelDir.isDirectory()) {
            File backup = new File(str + "-bak-" + new SimpleDateFormat("yyMMddHHmmss").format(new Date()));
            if (!modelDir.renameTo(backup)) {
                log.warn("Failed to move existing model directory " + modelDir.getAbsolutePath() + " to " + backup.getAbsolutePath() + "; reusing it as-is");
            }
        }
        if (!modelDir.isDirectory() && !modelDir.mkdirs()) {
            log.warn("Failed to create model directory " + modelDir.getAbsolutePath());
        }
        return modelDir;
    }

    /**
     * Recursively adds every file below a directory to a new zip archive. Entry names are the
     * file paths relative to that directory.
     *
     * @param file directory whose files are archived
     * @param file2 zip file to create; an existing file is overwritten
     * @throws IOException if reading a source file or writing the archive fails
     * @throws ArchiveException if the zip archive stream cannot be created
     */
    protected static void addFilesToZip(File file, File file2) throws IOException, ArchiveException {
        try (FileOutputStream fileOut = new FileOutputStream(file2);
             ArchiveOutputStream zipOut = new ArchiveStreamFactory().createArchiveOutputStream("zip", fileOut)) {
            for (File source : FileUtils.listFiles(file, (String[]) null, true)) {
                zipOut.putArchiveEntry(new ZipArchiveEntry(getEntryName(file, source)));
                try (BufferedInputStream in = new BufferedInputStream(new FileInputStream(source))) {
                    IOUtils.copy(in, zipOut);
                }
                zipOut.closeArchiveEntry();
            }
            zipOut.finish();
        }
    }

    /**
     * Computes the zip entry name of a file: its path relative to the archived directory, with
     * {@code /} as separator.
     *
     * @param file the archived base directory
     * @param file2 a file located below {@code file}
     * @return the relative path of {@code file2}
     * @throws IOException if a canonical path cannot be resolved or {@code file2} is not located
     *     below {@code file}
     */
    protected static String getEntryName(File file, File file2) throws IOException {
        // Both sides must be canonical: comparing a canonical path against an absolute one breaks
        // as soon as the base contains "." / ".." segments or symbolic links.
        String basePath = file.getCanonicalPath();
        String filePath = file2.getCanonicalPath();
        String prefix = basePath.endsWith(File.separator) ? basePath : basePath + File.separator;
        if (!filePath.startsWith(prefix)) {
            throw new IOException("File " + filePath + " is not located below " + basePath);
        }
        return filePath.substring(prefix.length()).replace(File.separatorChar, '/');
    }

    /** Returns a mutable copy of the caller's metadata; {@code null} yields an empty map. */
    private static HashMap<String, String> copyMetadata(Map<String, String> map) {
        HashMap<String, String> copy = new HashMap<String, String>();
        if (map != null) {
            copy.putAll(map);
        }
        return copy;
    }

    /** Releases the response entity, logging instead of propagating any failure. */
    private static void consumeQuietly(HttpEntity entity) {
        if (entity == null) {
            return;
        }
        try {
            EntityUtils.consume(entity);
        } catch (Exception e) {
            log.warn("Failed to consume entity due to: " + e, e);
        }
    }

    /** Moves the spark-mllib vectorizer settings out of the metadata into a list of spec steps. */
    private static List<Map<String, Object>> buildMllibVectorizerSpec(ObjectMapper objectMapper, HashMap<String, String> metadata) throws IOException {
        List<Map<String, Object>> steps = new ArrayList<>();

        Map<?, ?> analyzer = objectMapper.readValue(metadata.get(ANALYZER_JSON_KEY), Map.class);
        metadata.remove(ANALYZER_JSON_KEY);
        steps.add(singleEntry("lucene-analyzer", analyzer));

        Map<String, Object> hashingTf = new HashMap<>();
        moveEntry(metadata, "numFeatures", hashingTf);
        steps.add(singleEntry("hashingTF", hashingTf));

        if (metadata.containsKey("normalizer")) {
            Map<String, Object> normalizer = new HashMap<>();
            moveEntryIfPresent(metadata, "p-norm", normalizer);
            steps.add(singleEntry("normalizer", normalizer));
            metadata.remove("normalizer");
        }

        if (metadata.containsKey("standardscaler")) {
            Map<String, Object> scaler = new HashMap<>();
            moveEntryIfPresent(metadata, "withMean", scaler);
            moveEntryIfPresent(metadata, "withStd", scaler);
            moveEntry(metadata, "mean", scaler);
            moveEntry(metadata, "std", scaler);
            steps.add(singleEntry("standardScaler", scaler));
            metadata.remove("standardscaler");
        }

        if (metadata.containsKey("chisqselector")) {
            Map<String, Object> selector = new HashMap<>();
            moveEntry(metadata, "numtopfeatures", selector);
            moveEntry(metadata, "selectedfeatures", selector);
            steps.add(singleEntry("chisqselector", selector));
            metadata.remove("chisqselector");
        }
        return steps;
    }

    /** Moves a metadata entry into the target map; a missing entry is stored as {@code null}. */
    private static void moveEntry(HashMap<String, String> metadata, String key, Map<String, Object> target) {
        target.put(key, metadata.remove(key));
    }

    /** Moves a metadata entry into the target map only when the metadata contains the key. */
    private static void moveEntryIfPresent(HashMap<String, String> metadata, String key, Map<String, Object> target) {
        if (metadata.containsKey(key)) {
            target.put(key, metadata.remove(key));
        }
    }

    /** Wraps a value in a new mutable map holding exactly one entry. */
    private static Map<String, Object> singleEntry(String key, Object value) {
        Map<String, Object> wrapper = new HashMap<>();
        wrapper.put(key, value);
        return wrapper;
    }
}
