package org.wso2.carbon.ml.core.spark.transformations;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.spark.api.java.function.Function;
import org.wso2.carbon.ml.commons.domain.Feature;
import org.wso2.carbon.ml.core.exceptions.MLModelBuilderException;
import org.wso2.carbon.ml.core.internal.MLModelConfigurationContext;
import org.wso2.carbon.ml.core.spark.algorithms.SparkModelUtils;

/**
 * Spark {@link Function} that rescales each component of a feature vector into the range
 * {@code [0, 1]} using per-feature minimum and maximum values.
 *
 * <p>For component {@code i} with bounds {@code min[i]} / {@code max[i]}:
 * <ul>
 *   <li>values greater than {@code max[i]} map to {@code 1.0};</li>
 *   <li>values less than {@code min[i]} map to {@code 0.0};</li>
 *   <li>if {@code min[i] == max[i]} (a constant feature) in-range values map to {@code 0.5};</li>
 *   <li>otherwise the value maps to {@code (x - min[i]) / (max[i] - min[i])}.</li>
 * </ul>
 * Comparisons use {@link Double#compare}, which orders NaN above every other value, so a NaN
 * component is treated as "greater than max" and maps to {@code 1.0}.
 */
public class Normalization implements Function<double[], double[]> {
    private static final long serialVersionUID = 4558936873487486962L;
    /** Per-feature upper bounds, indexed by vector position. Private snapshot taken at build time. */
    private final List<Double> max;
    /** Per-feature lower bounds, indexed by vector position. Private snapshot taken at build time. */
    private final List<Double> min;

    /** Collects per-feature min/max bounds and creates a {@link Normalization}. */
    public static class Builder {
        private List<Double> max = new ArrayList<>();
        private List<Double> min = new ArrayList<>();

        /**
         * Populates the bounds from the model configuration: the included features of the
         * context's facts, looked up in the context's summary statistics.
         *
         * @param mLModelConfigurationContext model configuration supplying features and statistics
         * @return this builder
         */
        public Builder init(MLModelConfigurationContext mLModelConfigurationContext) {
            setMinMax(mLModelConfigurationContext.getFacts().getIncludedFeatures(),
                    mLModelConfigurationContext.getSummaryStatsOfFeatures());
            return this;
        }

        /**
         * Populates the bounds from the given features and their summary statistics.
         *
         * @param list features, in the same order as the vector components to be normalized
         * @param map summary-statistics string per feature, keyed by {@code Feature#getName()}
         * @return this builder
         */
        public Builder minMax(List<Feature> list, Map<String, String> map) {
            setMinMax(list, map);
            return this;
        }

        /**
         * Sets the bounds directly. Note the argument order: the FIRST list holds the maxima and
         * the SECOND list holds the minima.
         *
         * @param maxValues per-feature upper bounds
         * @param minValues per-feature lower bounds
         * @return this builder
         */
        public Builder minMax(List<Double> maxValues, List<Double> minValues) {
            this.max = maxValues;
            this.min = minValues;
            return this;
        }

        /**
         * Replaces the current bounds with ones derived from {@code summaryStats}.
         * Fresh lists are built and only swapped in once every feature has been processed. This
         * means a list previously passed to {@link #minMax(List, List)} is never mutated, and
         * repeated calls do not accumulate duplicate entries.
         */
        private void setMinMax(List<Feature> features, Map<String, String> summaryStats) {
            List<Double> maxValues = new ArrayList<>(features.size());
            List<Double> minValues = new ArrayList<>(features.size());
            for (Feature feature : features) {
                // NOTE(review): a feature missing from summaryStats yields null here; how
                // SparkModelUtils.getMax/getMin treat null is not visible from this file.
                String stats = summaryStats.get(feature.getName());
                maxValues.add(Double.valueOf(SparkModelUtils.getMax(stats)));
                minValues.add(Double.valueOf(SparkModelUtils.getMin(stats)));
            }
            this.max = maxValues;
            this.min = minValues;
        }

        /**
         * Creates a {@link Normalization} from the current bounds.
         *
         * @return a new normalization function holding its own copy of the bounds
         */
        public Normalization build() {
            return new Normalization(this);
        }
    }

    /**
     * Creates the function from a builder, copying the builder's bounds. Later use of the builder
     * (or of lists handed to it) therefore cannot change this already-built, serializable function.
     *
     * @param builder source of the per-feature bounds
     */
    public Normalization(Builder builder) {
        this.max = new ArrayList<>(builder.max);
        this.min = new ArrayList<>(builder.min);
    }

    /**
     * Normalizes every component of {@code dArr} into {@code [0, 1]}.
     *
     * @param dArr feature vector; component {@code i} is scaled with bounds at index {@code i}
     * @return a new array holding the normalized values; the input is not modified
     * @throws MLModelBuilderException if normalization fails, e.g. the input is null or has more
     *     components than there are bounds
     */
    @Override
    public double[] call(double[] dArr) throws MLModelBuilderException {
        try {
            double[] normalized = new double[dArr.length];
            for (int i = 0; i < dArr.length; i++) {
                normalized[i] = scale(dArr[i], i);
            }
            return normalized;
        } catch (Exception e) {
            throw new MLModelBuilderException("An error occurred while normalizing values: " + e.getMessage(), e);
        }
    }

    /**
     * Scales one value using the bounds at {@code index}. The lower bound is only read when the
     * value is not above the upper bound, matching the original evaluation order.
     */
    private double scale(double value, int index) {
        double upper = max.get(index);
        if (Double.compare(value, upper) > 0) {
            return 1.0d;
        }
        double lower = min.get(index);
        if (Double.compare(value, lower) < 0) {
            return 0.0d;
        }
        if (Double.compare(lower, upper) == 0) {
            // Constant feature: avoid dividing by zero and place it in the middle of the range.
            return 0.5d;
        }
        return (value - lower) / (upper - lower);
    }
}
