package org.wso2.carbon.ml.core.h2o;

import hex.genmodel.GenModel;
import hex.genmodel.easy.EasyPredictModelWrapper;
import hex.genmodel.easy.RowData;
import hex.genmodel.easy.exception.AbstractPredictException;
import java.io.File;
import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URL;
import java.net.URLClassLoader;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.wso2.carbon.ml.commons.constants.MLConstants;
import org.wso2.carbon.ml.commons.domain.Feature;
import org.wso2.carbon.ml.commons.domain.MLModel;
import org.wso2.carbon.ml.core.exceptions.MLModelHandlerException;
import org.wso2.carbon.utils.CarbonUtils;

/* loaded from: input_file:org/wso2/carbon/ml/core/h2o/POJOPredictor.class */
public class POJOPredictor {
    List<Feature> featureList;
    MLModel mlModel;
    private GenModel rawModel;
    private EasyPredictModelWrapper model;
    private int numberOfFeatures;

    public POJOPredictor(MLModel mLModel, String str) throws MLModelHandlerException {
        String replace = extractModelName(str).replace('.', '_').replace('-', '_');
        String carbonHome = CarbonUtils.getCarbonHome();
        try {
            this.rawModel = (GenModel) new URLClassLoader(new URL[]{new File(carbonHome + MLConstants.H2O_POJO_Path).toURI().toURL(), new File(carbonHome + MLConstants.H2O_POJO_Path + "h2o-genmodel.jar").toURI().toURL()}, getClass().getClassLoader()).loadClass(replace).newInstance();
            this.model = new EasyPredictModelWrapper(this.rawModel);
            this.numberOfFeatures = mLModel.getFeatures().size();
            this.featureList = mLModel.getFeatures();
            this.mlModel = mLModel;
        } catch (ClassNotFoundException | IllegalAccessException | InstantiationException | MalformedURLException e) {
            throw new MLModelHandlerException("Error occurred while initializing POHOPredictor.", e);
        }
    }

    public Object predict(String[] strArr) throws MLModelHandlerException {
        RowData rowData = new RowData();
        for (int i = 0; i < this.numberOfFeatures; i++) {
            rowData.put(this.featureList.get(i).getName(), strArr[i]);
        }
        if (this.model.getModelCategory().name().equalsIgnoreCase("multinomial")) {
            try {
                return decodePredictedValue(Double.parseDouble(this.model.predictMultinomial(rowData).label), this.mlModel);
            } catch (AbstractPredictException e) {
                throw new MLModelHandlerException("Error occurred while predicting.", e);
            }
        }
        if (!this.model.getModelCategory().name().equalsIgnoreCase("binomial")) {
            throw new MLModelHandlerException("Unsupported deep learning model:" + this.model.getModelCategory());
        }
        try {
            return decodePredictedValue(Double.parseDouble(this.model.predictBinomial(rowData).label), this.mlModel);
        } catch (AbstractPredictException e2) {
            throw new MLModelHandlerException("Error occurred while predicting.", e2);
        }
    }

    private Object decodePredictedValue(double d, MLModel mLModel) {
        if (mLModel.getResponseIndex() == -1) {
            return Double.valueOf(d);
        }
        List<Map<String, Integer>> encodings = mLModel.getEncodings();
        Map<String, Integer> map = encodings.get(encodings.size() - 1);
        return (map == null || map.isEmpty()) ? Double.valueOf(d) : decode(map, (int) Math.round(d));
    }

    private String decode(Map<String, Integer> map, int i) {
        String findClass = findClass(map, i);
        if (findClass != null) {
            return findClass;
        }
        int closest = closest(i, map.values());
        findClass(map, closest);
        return String.valueOf(closest);
    }

    private String findClass(Map<String, Integer> map, int i) {
        for (Map.Entry<String, Integer> entry : map.entrySet()) {
            if (i == entry.getValue().intValue()) {
                return entry.getKey();
            }
        }
        return null;
    }

    private int closest(int i, Collection<Integer> collection) {
        int i2 = Integer.MAX_VALUE;
        int i3 = i;
        Iterator<Integer> it = collection.iterator();
        while (it.hasNext()) {
            int intValue = it.next().intValue();
            int abs = Math.abs(intValue - i);
            if (abs < i2) {
                i2 = abs;
                i3 = intValue;
            }
        }
        return i3;
    }

    private String extractModelName(String str) {
        return str.substring(str.lastIndexOf(File.separator) + 1);
    }
}
