/*
 * Decompiled with CFR 0.152.
 */
package org.wso2.carbon.ml.siddhi.extension;

import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.wso2.carbon.ml.commons.domain.Feature;
import org.wso2.carbon.ml.commons.domain.MLModel;
import org.wso2.carbon.ml.core.exceptions.MLInputAdapterException;
import org.wso2.carbon.ml.core.exceptions.MLModelHandlerException;
import org.wso2.carbon.ml.core.factories.DatasetType;
import org.wso2.carbon.ml.core.h2o.POJOPredictor;
import org.wso2.carbon.ml.core.impl.MLIOFactory;
import org.wso2.carbon.ml.core.impl.Predictor;
import org.wso2.carbon.ml.core.interfaces.MLInputAdapter;
import org.wso2.carbon.ml.core.utils.MLCoreServiceValueHolder;

/**
 * Loads a serialized {@link MLModel} from a storage location and exposes
 * prediction and model-metadata accessors on top of it.
 *
 * <p>The storage location is a string of the form {@code <prefix>:<path>},
 * where the prefix is {@link #FILE_STORAGE_PREFIX} or
 * {@link #REGISTRY_STORAGE_PREFIX}. A string whose leading segment is neither
 * prefix is passed, unchanged, to the file input adapter.
 */
public class ModelHandler {
    public static final String FILE_STORAGE_PREFIX = "file";
    public static final String REGISTRY_STORAGE_PREFIX = "registry";
    public static final String PATH_TO_GOVERNANCE_REGISTRY = "/_system/governance";

    /** The deserialized model; assigned once in the constructor. */
    private final MLModel mlModel;

    // NOTE(review): nothing in this class assigns modelId, so the Predictor
    // instances below always receive 0 — confirm that this is intended.
    private long modelId;

    /**
     * Creates a handler by reading and deserializing the model stored at the
     * given location.
     *
     * @param modelStorageLocation location of the serialized model, e.g.
     *        {@code file:/path/to/model} or {@code registry:/path/in/registry}
     * @throws IllegalArgumentException if a storage prefix is given without a path
     * @throws IOException if the model stream cannot be read
     * @throws ClassNotFoundException if a class of the serialized object graph is unavailable
     */
    public ModelHandler(String modelStorageLocation) throws ClassNotFoundException, URISyntaxException, MLInputAdapterException, IOException {
        this.mlModel = ModelHandler.retrieveModel(modelStorageLocation);
    }

    /**
     * Resolves the storage type and path from {@code modelStorageLocation},
     * reads the stream through the matching input adapter and deserializes it.
     */
    private static MLModel retrieveModel(String modelStorageLocation) throws URISyntaxException, MLInputAdapterException, IOException, ClassNotFoundException {
        // Split on the FIRST colon only: the path itself may contain colons
        // (e.g. "file:C:\models\m"), and an unlimited split would truncate it.
        String[] modelStorage = modelStorageLocation.trim().split(":", 2);
        String prefix = modelStorage[0];
        String storageType;
        String location;
        if (REGISTRY_STORAGE_PREFIX.equals(prefix)) {
            String registryPath = requirePath(modelStorage, modelStorageLocation);
            // A leading governance-registry root is stripped before the path is handed to the adapter.
            location = registryPath.startsWith(PATH_TO_GOVERNANCE_REGISTRY)
                    ? registryPath.substring(PATH_TO_GOVERNANCE_REGISTRY.length())
                    : registryPath;
            storageType = prefix;
        } else if (FILE_STORAGE_PREFIX.equals(prefix)) {
            location = requirePath(modelStorage, modelStorageLocation);
            storageType = DatasetType.FILE.getValue();
        } else {
            // No recognised prefix: treat the whole (untrimmed) string as a file location.
            location = modelStorageLocation;
            storageType = DatasetType.FILE.getValue();
        }

        MLIOFactory ioFactory = new MLIOFactory(MLCoreServiceValueHolder.getInstance().getMlProperties());
        MLInputAdapter inputAdapter = ioFactory.getInputAdapter(storageType + ".in");

        // NOTE(review): this is native Java deserialization of whatever the
        // storage location points at; it is only safe if that location is trusted.
        // Both streams are closed even when construction or readObject() fails.
        try (InputStream in = inputAdapter.read(location);
             ObjectInputStream ois = new ObjectInputStream(in)) {
            return (MLModel) ois.readObject();
        }
    }

    /**
     * Returns the path segment of a split storage location, rejecting a
     * prefix that is not followed by a path.
     */
    private static String requirePath(String[] modelStorage, String modelStorageLocation) {
        if (modelStorage.length < 2 || modelStorage[1].isEmpty()) {
            throw new IllegalArgumentException("Model storage location '" + modelStorageLocation
                    + "' has a storage prefix but no path");
        }
        return modelStorage[1];
    }

    /**
     * Predicts for a single data row and converts the first prediction to the
     * requested output type.
     *
     * @param data feature values of one row
     * @param outputType target type name (see {@link #castValue(String, String)})
     * @return the first prediction converted to {@code outputType}
     * @throws MLModelHandlerException if the prediction fails
     * @throws NumberFormatException if the prediction cannot be parsed as the requested numeric type
     */
    public Object predict(String[] data, String outputType) throws MLModelHandlerException {
        List<String[]> rows = new ArrayList<String[]>();
        rows.add(data);
        Predictor predictor = new Predictor(this.modelId, this.mlModel, rows);
        return firstPrediction(predictor, outputType);
    }

    /**
     * Same as {@link #predict(String[], String)}, but constructs the predictor
     * with the given percentile (and {@code false} for its trailing flag).
     *
     * @param data feature values of one row
     * @param outputType target type name (see {@link #castValue(String, String)})
     * @param percentile percentile value forwarded to the {@link Predictor}
     * @return the first prediction converted to {@code outputType}
     * @throws MLModelHandlerException if the prediction fails
     * @throws NumberFormatException if the prediction cannot be parsed as the requested numeric type
     */
    public Object predict(String[] data, String outputType, double percentile) throws MLModelHandlerException {
        List<String[]> rows = new ArrayList<String[]>();
        rows.add(data);
        Predictor predictor = new Predictor(this.modelId, this.mlModel, rows, percentile, false);
        return firstPrediction(predictor, outputType);
    }

    /**
     * Predicts through the supplied {@link POJOPredictor} instead of the
     * loaded model and converts the result to the requested output type.
     *
     * @param data feature values of one row
     * @param outputType target type name (see {@link #castValue(String, String)})
     * @param pojoPredictor predictor that produces the prediction
     * @return the prediction converted to {@code outputType}
     * @throws MLModelHandlerException if the prediction fails
     * @throws NumberFormatException if the prediction cannot be parsed as the requested numeric type
     */
    public Object predict(String[] data, String outputType, POJOPredictor pojoPredictor) throws MLModelHandlerException {
        String predictionStr = pojoPredictor.predict(data).toString();
        return castValue(outputType, predictionStr);
    }

    /** Runs the predictor and converts its first result to {@code outputType}. */
    private static Object firstPrediction(Predictor predictor, String outputType) throws MLModelHandlerException {
        List<?> predictions = predictor.predict();
        String predictionStr = predictions.get(0).toString();
        return castValue(outputType, predictionStr);
    }

    /**
     * Converts the textual prediction to the type named by {@code outputType}
     * (case-insensitive): {@code double}, {@code float}, {@code integer}/{@code int},
     * {@code long} or {@code boolean}/{@code bool}. Any other type name returns
     * the string unchanged.
     *
     * @throws NumberFormatException if {@code value} is not parseable as the requested numeric type
     */
    private static Object castValue(String outputType, String value) {
        if (outputType.equalsIgnoreCase("double")) {
            return Double.parseDouble(value);
        }
        if (outputType.equalsIgnoreCase("float")) {
            return Float.valueOf(Float.parseFloat(value));
        }
        if (outputType.equalsIgnoreCase("integer") || outputType.equalsIgnoreCase("int")) {
            return Integer.parseInt(value);
        }
        if (outputType.equalsIgnoreCase("long")) {
            return Long.parseLong(value);
        }
        if (outputType.equalsIgnoreCase("boolean") || outputType.equalsIgnoreCase("bool")) {
            return Boolean.parseBoolean(value);
        }
        return value;
    }

    /**
     * Builds a map from each model feature's name to its index.
     *
     * @return a new map of feature name to feature index
     */
    public Map<String, Integer> getFeatures() {
        List<Feature> features = this.mlModel.getFeatures();
        Map<String, Integer> featureIndexMap = new HashMap<String, Integer>();
        for (Feature feature : features) {
            featureIndexMap.put(feature.getName(), feature.getIndex());
        }
        return featureIndexMap;
    }

    /** Returns the model's new-to-old indices list, as reported by the model. */
    public List<Integer> getNewToOldIndicesList() {
        return this.mlModel.getNewToOldIndicesList();
    }

    /** Returns the response variable recorded in the model. */
    public String getResponseVariable() {
        return this.mlModel.getResponseVariable();
    }

    /** Returns the algorithm class recorded in the model. */
    public String getAlgorithmClass() {
        return this.mlModel.getAlgorithmClass();
    }

    /** Returns the underlying deserialized model. */
    public MLModel getMlModel() {
        return this.mlModel;
    }
}

