package org.wso2.carbon.ml.core.utils;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.wso2.carbon.ml.core.spark.summary.DeeplearningModelSummary;
import org.wso2.carbon.ml.core.spark.summary.PredictedVsActual;
import org.wso2.carbon.ml.core.spark.summary.TestResultDataPoint;
import scala.Tuple2;
import water.fvec.Frame;
import water.fvec.Vec;

/* loaded from: input_file:org/wso2/carbon/ml/core/utils/DeeplearningModelUtils.class */
/**
 * Static helpers for summarising deep-learning model test results and for converting
 * in-memory data into H2O {@link Frame}s.
 */
public class DeeplearningModelUtils {

    /**
     * Builds a {@link DeeplearningModelSummary} from a test set and the model's predictions for it.
     *
     * <p>Test point {@code i} is paired with prediction pair {@code i}. Points that are
     * {@code null} or have no feature vector are skipped without shifting that pairing. Points
     * beyond the number of available predictions are ignored.
     *
     * <p>The summary's error is the fraction of (predicted, actual) pairs whose two values are not
     * equal under {@link Double#equals(Object)}. It is {@code NaN} when there are no predictions.
     *
     * @param javaSparkContext context used to draw the random sample of test results
     * @param javaRDD          test data points, in the same order as {@code javaPairRDD}
     * @param javaPairRDD      (predicted, actual) pairs, one per test point
     * @return the populated summary
     */
    public static DeeplearningModelSummary getDeeplearningModelSummary(JavaSparkContext javaSparkContext,
            JavaRDD<LabeledPoint> javaRDD, JavaPairRDD<Double, Double> javaPairRDD) {
        DeeplearningModelSummary summary = new DeeplearningModelSummary();

        // Collect the predictions once. Both the predicted-vs-actual list and the error rate are
        // derived from this local copy, so no further Spark jobs run over javaPairRDD.
        List<Tuple2<Double, Double>> predictionPairs = javaPairRDD.collect();
        List<PredictedVsActual> predictedVsActuals = new ArrayList<>(predictionPairs.size());
        long mismatches = 0;
        for (Tuple2<Double, Double> pair : predictionPairs) {
            PredictedVsActual predictedVsActual = new PredictedVsActual();
            predictedVsActual.setPredicted(pair._1().doubleValue());
            predictedVsActual.setActual(pair._2().doubleValue());
            predictedVsActuals.add(predictedVsActual);
            if (!pair._1().equals(pair._2())) {
                mismatches++;
            }
        }

        List<LabeledPoint> testPoints = javaRDD.collect();
        // Bound by both lists so that a prediction set shorter than the test set cannot overflow.
        int pairable = Math.min(testPoints.size(), predictedVsActuals.size());
        List<TestResultDataPoint> testResultDataPoints = new ArrayList<>(pairable);
        for (int i = 0; i < pairable; i++) {
            LabeledPoint point = testPoints.get(i);
            if (point == null || point.features() == null) {
                // Skip the point but keep the index, so later points stay aligned with their predictions.
                continue;
            }
            TestResultDataPoint dataPoint = new TestResultDataPoint();
            dataPoint.setPredictedVsActual(predictedVsActuals.get(i));
            dataPoint.setFeatureValues(point.features().toArray());
            testResultDataPoints.add(dataPoint);
        }

        int sampleSize = MLCoreServiceValueHolder.getInstance().getSummaryStatSettings().getSampleSize();
        // Only go through Spark when a random sample (with replacement) is actually required.
        List<TestResultDataPoint> sample = testResultDataPoints.size() > sampleSize
                ? javaSparkContext.parallelize(testResultDataPoints).takeSample(true, sampleSize)
                : testResultDataPoints;
        summary.setTestResultDataPointsSample(sample);
        summary.setPredictedVsActuals(predictedVsActuals);
        // With no predictions the error rate is undefined; report NaN explicitly rather than via 0/0.
        summary.setError(predictionPairs.isEmpty()
                ? Double.NaN
                : (double) mismatches / predictionPairs.size());
        return summary;
    }

    /**
     * Converts an RDD of labeled points into an H2O {@link Frame}.
     *
     * <p>The frame has one column per feature, followed by a final label column. The number of
     * features is taken from the first point. Labels are truncated to {@code int} before storing.
     *
     * @param strArr  column names: one per feature, then the label column
     * @param javaRDD the labeled points to convert; collected to the driver
     * @return a frame with {@code featureCount + 1} columns and one row per point
     * @throws IllegalArgumentException if {@code javaRDD} contains no points
     */
    public static Frame javaRDDToFrame(String[] strArr, JavaRDD<LabeledPoint> javaRDD) {
        List<LabeledPoint> points = javaRDD.collect();
        if (points.isEmpty()) {
            throw new IllegalArgumentException("Cannot build a Frame from an empty RDD of LabeledPoints");
        }
        int rowCount = points.size();
        int featureCount = points.get(0).features().size();

        // Materialise each feature vector once, instead of once per (row, column) cell.
        double[][] rows = new double[rowCount][];
        for (int r = 0; r < rowCount; r++) {
            rows[r] = points.get(r).features().toArray();
        }

        Vec[] columns = new Vec[featureCount + 1];
        for (int c = 0; c < featureCount; c++) {
            Vec column = Vec.makeZero(rowCount);
            for (int r = 0; r < rowCount; r++) {
                column.set(r, rows[r][c]);
            }
            columns[c] = column;
        }

        Vec labels = Vec.makeZero(rowCount);
        for (int r = 0; r < rowCount; r++) {
            // Labels are deliberately truncated to int before storing (same as the original code).
            labels.set(r, (int) points.get(r).label());
        }
        columns[featureCount] = labels;
        return new Frame(strArr, columns);
    }

    /**
     * Converts row-major data into an H2O {@link Frame}, with one column per array index.
     *
     * <p>The column count is taken from the first row. Every row is assumed to have at least
     * that many entries.
     *
     * @param strArr column names, one per column
     * @param list   rows of values
     * @return a frame with one row per list element
     * @throws IllegalArgumentException if {@code list} is empty
     */
    public static Frame doubleArrayListToFrame(String[] strArr, List<double[]> list) {
        if (list.isEmpty()) {
            throw new IllegalArgumentException("Cannot build a Frame from an empty list of rows");
        }
        int rowCount = list.size();
        int columnCount = list.get(0).length;
        Vec[] columns = new Vec[columnCount];
        for (int c = 0; c < columnCount; c++) {
            Vec column = Vec.makeZero(rowCount);
            for (int r = 0; r < rowCount; r++) {
                column.set(r, list.get(r)[c]);
            }
            columns[c] = column;
        }
        return new Frame(strArr, columns);
    }
}
