package org.wso2.carbon.ml.core.spark.models;

import hex.deeplearning.DeepLearningModel;
import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import org.apache.spark.mllib.linalg.Vector;
import org.wso2.extension.siddhi.execution.ml.ModelHandler;
import water.Key;
import water.fvec.Frame;
import water.fvec.Vec;
import water.serial.ObjectTreeBinarySerializer;
import water.util.FileUtils;

/* loaded from: input_file:org/wso2/carbon/ml/core/spark/models/MLDeeplearningModel.class */
/**
 * {@link Externalizable} wrapper around an H2O {@link DeepLearningModel}.
 *
 * <p>Java serialization writes only the storage location and the label map to the object stream.
 * The H2O model itself is saved separately with {@link ObjectTreeBinarySerializer} to the file
 * named by {@code storageLocation}, and is reloaded from that file during deserialization.
 * {@link #setStorageLocation(String)} must therefore be called before serializing an instance.
 */
public class MLDeeplearningModel implements Externalizable {
    /** The wrapped H2O model; null until supplied via a constructor, setter or deserialization. */
    private DeepLearningModel dlModel;
    /**
     * Label map carried through serialization. In this class it is only assigned by
     * {@link #readExternal(ObjectInput)}. NOTE(review): key/value meaning is not visible here —
     * presumably original label to H2O enum index; confirm with the code that populates it.
     */
    private HashMap<Double, Double> labelToH2OEnumMap;
    /** Path of the file the H2O model is saved to / loaded from during (de)serialization. */
    private String storageLocation;

    /** Public no-arg constructor, required by the {@link Externalizable} contract. */
    public MLDeeplearningModel() {
    }

    /**
     * Creates a wrapper around an existing model.
     *
     * @param deepLearningModel the H2O model to wrap
     */
    public MLDeeplearningModel(DeepLearningModel deepLearningModel) {
        this.dlModel = deepLearningModel;
    }

    /**
     * Sets the file path used to persist the H2O model when this object is serialized.
     *
     * @param str path passed to {@link FileUtils#getURI(String)} on save/load
     */
    public void setStorageLocation(String str) {
        this.storageLocation = str;
    }

    /**
     * Returns the wrapped H2O model.
     *
     * @return the model, or null if none has been set
     */
    public DeepLearningModel getDeepLearningModel() {
        return this.dlModel;
    }

    /**
     * Replaces the wrapped H2O model.
     *
     * @param deepLearningModel the new model
     */
    public void setDeepLearningModel(DeepLearningModel deepLearningModel) {
        this.dlModel = deepLearningModel;
    }

    /**
     * Scores a single feature vector with the wrapped model.
     *
     * @param vector feature values, passed to the model as a dense {@code double[]}
     * @return the value returned by {@code DeepLearningModel.score(double[])}
     */
    public double predict(Vector vector) {
        return this.dlModel.score(vector.toArray());
    }

    /**
     * Scores every row of an H2O frame and returns the first column of the result.
     *
     * <p>NOTE(review): column 0 of the H2O score frame is presumably the prediction column
     * (for classifiers, the predicted class index) — confirm against the H2O version in use.
     *
     * @param frame rows to score
     * @return one value per row of {@code frame}, read from column 0 of the score frame
     * @throws IllegalArgumentException if {@code frame} has more rows than fit in a Java array
     */
    public double[] predict(Frame frame) {
        long rowCount = frame.numRows();
        // Guard the narrowing cast: an unchecked (int) cast would silently truncate or go negative.
        if (rowCount > Integer.MAX_VALUE) {
            throw new IllegalArgumentException(
                    "Frame has too many rows to return as an array: " + rowCount);
        }
        int numRows = (int) rowCount;
        Frame score = this.dlModel.score(frame);
        try {
            Vec predictions = score.vec(0);
            double[] result = new double[numRows];
            for (int i = 0; i < numRows; i++) {
                result[i] = predictions.at(i);
            }
            return result;
        } finally {
            // The score frame is created by H2O for this call only and is never exposed to the
            // caller; delete it once the values are copied out instead of leaking it per call.
            score.delete();
        }
    }

    /**
     * Builds a URI string for a file path.
     *
     * <p>Drops the first character of {@code str}, converts backslashes to forward slashes and
     * prepends {@code ModelHandler.FILE_STORAGE_PREFIX}. NOTE(review): dropping the first
     * character presumably strips a leading separator — confirm with callers. An empty
     * {@code str} throws {@link StringIndexOutOfBoundsException}.
     *
     * @param str local file path
     * @return the prefixed, forward-slash path
     */
    public String getURIStringForLocation(String str) {
        return ModelHandler.FILE_STORAGE_PREFIX + str.substring(1).replace("\\", "/");
    }

    /**
     * Writes the storage location and label map to {@code objectOutput}, then saves the H2O model
     * (by its key) to the file at the storage location.
     *
     * @throws IOException if no model or storage location is set, or if writing fails
     */
    @Override // java.io.Externalizable
    public void writeExternal(ObjectOutput objectOutput) throws IOException {
        // Validate before writing anything so a failure does not leave a half-written object.
        if (this.dlModel == null) {
            throw new IOException("Cannot serialize MLDeeplearningModel: no DeepLearningModel set");
        }
        if (this.storageLocation == null) {
            throw new IOException("Cannot serialize MLDeeplearningModel: storage location not set");
        }
        objectOutput.writeObject(this.storageLocation);
        objectOutput.writeObject(this.labelToH2OEnumMap);
        List<Key> modelKeys = new LinkedList<Key>();
        modelKeys.add(this.dlModel._key);
        new ObjectTreeBinarySerializer().save(modelKeys, FileUtils.getURI(this.storageLocation));
    }

    /**
     * Reads the storage location and label map written by {@link #writeExternal(ObjectOutput)},
     * then loads the H2O model from the file at that location.
     *
     * @throws IOException if the storage location is missing, the file holds no key, or the
     *     loaded object is not a {@link DeepLearningModel}
     * @throws ClassNotFoundException if a serialized class cannot be resolved
     */
    @Override // java.io.Externalizable
    public void readExternal(ObjectInput objectInput) throws IOException, ClassNotFoundException {
        this.storageLocation = (String) objectInput.readObject();
        @SuppressWarnings("unchecked") // written as HashMap<Double, Double> by writeExternal
        HashMap<Double, Double> labelMap = (HashMap<Double, Double>) objectInput.readObject();
        this.labelToH2OEnumMap = labelMap;
        if (this.storageLocation == null) {
            throw new IOException("Cannot deserialize MLDeeplearningModel: storage location missing");
        }
        List<Key> modelKeys =
                new ObjectTreeBinarySerializer().load(FileUtils.getURI(this.storageLocation));
        if (modelKeys == null || modelKeys.isEmpty()) {
            throw new IOException("No model key found in " + this.storageLocation);
        }
        Object model = modelKeys.get(0).get();
        if (!(model instanceof DeepLearningModel)) {
            throw new IOException("Object loaded from " + this.storageLocation
                    + " is not a DeepLearningModel: " + model);
        }
        this.dlModel = (DeepLearningModel) model;
    }

    /**
     * Alias of {@link #getDeepLearningModel()}.
     *
     * @return the wrapped model, or null if none has been set
     */
    public DeepLearningModel getDlModel() {
        return this.dlModel;
    }
}
