package org.wso2.carbon.ml.core.spark.algorithms;

import hex.deeplearning.DeepLearning;
import hex.deeplearning.DeepLearningModel;
import hex.deeplearning.DeepLearningParameters;
import hex.splitframe.ShuffleSplitFrame;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.wso2.carbon.ml.commons.domain.Feature;
import org.wso2.carbon.ml.commons.domain.MLModel;
import org.wso2.carbon.ml.core.exceptions.MLModelBuilderException;
import org.wso2.carbon.ml.core.utils.DeeplearningModelUtils;
import scala.Tuple2;
import water.DKV;
import water.Key;
import water.Scope;
import water.fvec.Frame;
import water.fvec.Vec;
import water.util.FrameUtils;

/* loaded from: input_file:org/wso2/carbon/ml/core/spark/algorithms/StackedAutoencodersClassifier.class */
/**
 * Trains and evaluates an H2O {@link DeepLearning} classification model on Spark
 * {@link LabeledPoint} data.
 *
 * <p>The data is converted to an H2O {@link Frame} whose columns are the model's feature names
 * followed by the response variable. The response column is converted to a categorical column
 * before training.
 */
public class StackedAutoencodersClassifier implements Serializable {
    private static final long serialVersionUID = -3518369175759608115L;
    private static final Log log = LogFactory.getLog(StackedAutoencodersClassifier.class);

    /** H2O trainer of the most recent {@link #train} call; not serialized. */
    private transient DeepLearning deeplearning;
    /** Model produced by the most recent successful {@link #train} call; not serialized. */
    private transient DeepLearningModel dlModel;

    /**
     * Trains a deep learning classifier on the given data.
     *
     * <p>All rows are assigned to the training split (ratios 1.0 / 0.0), so the validation frame
     * handed to H2O is presumably empty — TODO confirm this is intended.
     *
     * @param javaRDD training data; when {@code null} an error is logged and {@code null} is returned
     * @param trainSamplesPerIteration value for H2O {@code _train_samples_per_iteration}; one tenth
     *     of it is used for {@code _score_training_samples}
     * @param hiddenLayerSizes value for H2O {@code _hidden}
     * @param activationType activation function name, matched case-insensitively (see
     *     {@link #getActivationType(String)})
     * @param epochs value for H2O {@code _epochs}
     * @param seed value for H2O {@code _seed}
     * @param unused not read by this implementation
     * @param modelName base of the H2O model key: {@code '.'} and {@code '-'} become {@code '_'} and
     *     {@code "_dl"} is appended
     * @param mLModel supplies the feature names and the response variable name
     * @param modelId identifier used only in the failure log message
     * @return the trained model, or {@code null} if no data was given or training failed (failures
     *     are logged, not thrown)
     */
    public DeepLearningModel train(JavaRDD<LabeledPoint> javaRDD, int trainSamplesPerIteration,
            int[] hiddenLayerSizes, String activationType, int epochs, int seed, String unused,
            String modelName, MLModel mLModel, long modelId) {
        // Drop any model from a previous call so a failed run can never return a stale model.
        this.dlModel = null;
        Scope.enter();
        try {
            if (javaRDD == null) {
                log.error("Train file not found!");
                return null;
            }
            String responseVariable = mLModel.getResponseVariable();
            Frame dataFrame = DeeplearningModelUtils.javaRDDToFrame(buildColumnNames(mLModel), javaRDD);
            int responseIndex = dataFrame.find(responseVariable);
            // Make the response categorical; the replaced Vec's key is tracked in the current
            // Scope (presumably so it is removed on Scope.exit).
            Scope.track(dataFrame.replace(responseIndex, dataFrame.vecs()[responseIndex].toEnum())._key);

            double[] splitRatios = {1.0d, 0.0d};
            Frame[] splits = ShuffleSplitFrame.shuffleSplitFrame(dataFrame,
                    FrameUtils.generateNumKeys(dataFrame._key, splitRatios.length), splitRatios, 123456789L);
            Frame trainFrame = splits[0];
            Frame validationFrame = splits[1];

            if (log.isDebugEnabled()) {
                log.debug("Creating Deeplearning parameters");
            }
            DeepLearningParameters parameters = new DeepLearningParameters();
            parameters._model_id = Key.make(modelName.replace('.', '_').replace('-', '_') + "_dl");
            parameters._train = trainFrame._key;
            parameters._valid = validationFrame._key;
            parameters._response_column = responseVariable;
            parameters._activation = getActivationType(activationType);
            parameters._hidden = hiddenLayerSizes;
            parameters._train_samples_per_iteration = trainSamplesPerIteration;
            parameters._input_dropout_ratio = 0.2d;
            parameters._l1 = 1.0E-5d;
            parameters._max_w2 = 10.0f;
            parameters._epochs = epochs;
            parameters._seed = seed;
            parameters._adaptive_rate = true;
            parameters._replicate_training_data = true;
            parameters._overwrite_with_best_model = true;
            parameters._diagnostics = false;
            parameters._classification_stop = -1.0d;
            parameters._score_interval = 60.0d;
            parameters._score_training_samples = trainSamplesPerIteration / 10;

            DKV.put(trainFrame);
            DKV.put(validationFrame);
            this.deeplearning = new DeepLearning(parameters);
            if (log.isDebugEnabled()) {
                log.debug("Start training deeplearning model ....");
            }
            try {
                this.dlModel = this.deeplearning.trainModel().get();
                if (log.isDebugEnabled()) {
                    log.debug("Successfully finished Training deeplearning model.");
                }
            } catch (RuntimeException e) {
                log.error("Error in training Stacked Autoencoder classifier model", e);
            }
        } catch (RuntimeException e) {
            // Pass the exception itself so the stack trace is not lost.
            log.error("Failed to train the deeplearning model [id] " + modelId + ". " + e.getMessage(), e);
        } finally {
            // Single exit point for the scope opened above, on every path.
            Scope.exit(new Key[0]);
        }
        return this.dlModel;
    }

    /**
     * Maps an activation function name to the H2O activation enum, ignoring case.
     *
     * @param str one of Rectifier, RectifierWithDropout, Tanh, TanhWithDropout, Maxout or
     *     MaxoutWithDropout (any case)
     * @return the matching activation; {@code RectifierWithDropout} for unknown or {@code null} names
     */
    private DeepLearningParameters.Activation getActivationType(String str) {
        if (str != null) {
            DeepLearningParameters.Activation[] supported = {
                DeepLearningParameters.Activation.Rectifier,
                DeepLearningParameters.Activation.RectifierWithDropout,
                DeepLearningParameters.Activation.Tanh,
                DeepLearningParameters.Activation.TanhWithDropout,
                DeepLearningParameters.Activation.Maxout,
                DeepLearningParameters.Activation.MaxoutWithDropout
            };
            for (DeepLearningParameters.Activation candidate : supported) {
                if (candidate.name().equalsIgnoreCase(str)) {
                    return candidate;
                }
            }
        }
        return DeepLearningParameters.Activation.RectifierWithDropout;
    }

    /**
     * Scores the given data with a trained model.
     *
     * <p>The last column of the converted frame (the response variable) is excluded from scoring
     * and used as the actual value.
     *
     * @param javaSparkContext context used to parallelize the result
     * @param deepLearningModel model to score with; must not be {@code null}
     * @param javaRDD data to score
     * @param mLModel supplies the feature names and the response variable name
     * @return pairs of (first column of the model's score output, actual response value), one per row
     * @throws MLModelBuilderException if {@code deepLearningModel} is {@code null}
     */
    public JavaPairRDD<Double, Double> test(JavaSparkContext javaSparkContext,
            DeepLearningModel deepLearningModel, JavaRDD<LabeledPoint> javaRDD, MLModel mLModel)
            throws MLModelBuilderException {
        // Validate before entering the scope so the failure path cannot leak it.
        if (deepLearningModel == null) {
            throw new MLModelBuilderException("DeeplearningModel is Null");
        }
        double[] predictions;
        double[] actuals;
        Scope.enter();
        try {
            Frame dataFrame = DeeplearningModelUtils.javaRDDToFrame(buildColumnNames(mLModel), javaRDD);
            int responseIndex = dataFrame.numCols() - 1;
            Frame featureFrame = dataFrame.subframe(0, responseIndex);
            int numRows = (int) featureFrame.numRows();
            Vec predicted = deepLearningModel.score(featureFrame).vec(0);
            Vec actual = dataFrame.vec(responseIndex);
            predictions = new double[numRows];
            actuals = new double[numRows];
            for (int row = 0; row < numRows; row++) {
                predictions[row] = predicted.at(row);
                actuals[row] = actual.at(row);
            }
        } finally {
            Scope.exit(new Key[0]);
        }
        List<Tuple2<Double, Double>> pairs = new ArrayList<>(actuals.length);
        for (int row = 0; row < actuals.length; row++) {
            pairs.add(new Tuple2<Double, Double>(predictions[row], actuals[row]));
        }
        return javaSparkContext.parallelizePairs(pairs);
    }

    /**
     * Builds the frame column names: every feature name in model order, followed by the response
     * variable name as the last column.
     */
    private static String[] buildColumnNames(MLModel mLModel) {
        List<Feature> features = mLModel.getFeatures();
        int featureCount = features.size();
        String[] columnNames = new String[featureCount + 1];
        for (int index = 0; index < featureCount; index++) {
            columnNames[index] = features.get(index).getName();
        }
        columnNames[featureCount] = mLModel.getResponseVariable();
        return columnNames;
    }
}
