package com.datumbox.framework.core.machinelearning.featureselection.continuous;

import com.datumbox.framework.common.Configuration;
import com.datumbox.framework.common.concurrency.ForkJoinStream;
import com.datumbox.framework.common.concurrency.StreamMethods;
import com.datumbox.framework.common.dataobjects.AssociativeArray;
import com.datumbox.framework.common.dataobjects.Dataframe;
import com.datumbox.framework.common.dataobjects.MatrixDataframe;
import com.datumbox.framework.common.dataobjects.Record;
import com.datumbox.framework.common.persistentstorage.interfaces.BigMap;
import com.datumbox.framework.common.persistentstorage.interfaces.DatabaseConnector;
import com.datumbox.framework.common.utilities.PHPMethods;
import com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer;
import com.datumbox.framework.core.machinelearning.common.abstracts.featureselectors.AbstractContinuousFeatureSelector;
import com.datumbox.framework.core.machinelearning.common.interfaces.Parallelizable;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.math3.linear.BlockRealMatrix;
import org.apache.commons.math3.linear.DiagonalMatrix;
import org.apache.commons.math3.linear.EigenDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.stat.StatUtils;

/* loaded from: input_file:com/datumbox/framework/core/machinelearning/featureselection/continuous/PCA.class */
/**
 * Principal Component Analysis feature selector.
 *
 * <p>Training ({@link #_fit(Dataframe)}) mean-centers the feature matrix, builds its sample
 * covariance matrix, eigen-decomposes it and stores the retained eigenvalues and eigenvectors
 * (components) in the {@link ModelParameters}. Transformation
 * ({@link #filterFeatures(Dataframe)}) multiplies the data matrix with the stored components and
 * replaces every record's features with the projected values.</p>
 */
public class PCA extends AbstractContinuousFeatureSelector<ModelParameters, TrainingParameters> implements Parallelizable {
    /** Whether the record rewriting in {@link #filterFeatures(Dataframe)} runs on a parallel stream. */
    private boolean parallelized;

    /** Executor used to run the per-record stream of {@link #filterFeatures(Dataframe)}. */
    protected final ForkJoinStream streamExecutor;

    /**
     * Values learned by {@link PCA#_fit(Dataframe)}.
     */
    public static class ModelParameters extends AbstractTrainer.AbstractModelParameters {
        private static final long serialVersionUID = 1L;

        /** Maps each feature to the index of its column in the data matrix. */
        @BigMap(mapType = DatabaseConnector.MapType.HASHMAP, storageHint = DatabaseConnector.StorageHint.IN_MEMORY, concurrent = false)
        private Map<Object, Integer> featureIds;

        /** Row count of the stored components matrix. */
        private int rows;

        /** Column count of the stored components matrix, i.e. the number of retained components. */
        private int cols;

        /** Per-column mean of the training matrix, indexed by column id. */
        private double[] mean;

        /** Eigenvalues of the retained components. */
        private double[] eigenValues;

        /** Retained components, one per column. */
        private double[][] components;

        /**
         * @param databaseConnector connector handed to the parent model-parameters class
         */
        protected ModelParameters(DatabaseConnector databaseConnector) {
            super(databaseConnector);
        }

        /**
         * Returns the feature-to-column map itself (not a copy); {@link PCA#_fit(Dataframe)} hands
         * this instance to the matrix builder.
         *
         * @return the live feature-to-column map
         */
        public Map<Object, Integer> getFeatureIds() {
            return this.featureIds;
        }

        /**
         * @param map the feature-to-column map to store (kept by reference)
         */
        protected void setFeatureIds(Map<Object, Integer> map) {
            this.featureIds = map;
        }

        /**
         * @return the row count of the components matrix
         */
        public int getRows() {
            return this.rows;
        }

        /**
         * @param i the row count of the components matrix
         */
        protected void setRows(int i) {
            this.rows = i;
        }

        /**
         * @return the column count of the components matrix
         */
        public int getCols() {
            return this.cols;
        }

        /**
         * @param i the column count of the components matrix
         */
        protected void setCols(int i) {
            this.cols = i;
        }

        /**
         * @return the per-column training means, passed through {@code PHPMethods.array_clone}
         */
        public double[] getMean() {
            return PHPMethods.array_clone(this.mean);
        }

        /**
         * @param dArr the per-column training means; stored via {@code PHPMethods.array_clone}
         */
        protected void setMean(double[] dArr) {
            this.mean = PHPMethods.array_clone(dArr);
        }

        /**
         * @return the retained eigenvalues, passed through {@code PHPMethods.array_clone}
         */
        public double[] getEigenValues() {
            return PHPMethods.array_clone(this.eigenValues);
        }

        /**
         * @param dArr the retained eigenvalues; stored via {@code PHPMethods.array_clone}
         */
        protected void setEigenValues(double[] dArr) {
            this.eigenValues = PHPMethods.array_clone(dArr);
        }

        /**
         * @return the components matrix, passed through {@code PHPMethods.array_clone}
         */
        public double[][] getComponents() {
            return PHPMethods.array_clone(this.components);
        }

        /**
         * @param dArr the components matrix; stored via {@code PHPMethods.array_clone}
         */
        protected void setComponents(double[][] dArr) {
            this.components = PHPMethods.array_clone(dArr);
        }
    }

    /**
     * User-supplied settings that control how many components are kept and whether they are
     * scaled.
     */
    public static class TrainingParameters extends AbstractTrainer.AbstractTrainingParameters {
        private static final long serialVersionUID = 1L;

        /** When true, each component is scaled by the square root of its eigenvalue. */
        private boolean whitened = false;

        /** Upper bound on the number of retained components; {@code null} means no bound. */
        private Integer maxDimensions = null;

        /**
         * Cumulative fraction of the total variance the retained components must reach;
         * {@code null} (or a value above 1) disables this criterion.
         */
        private Double variancePercentageThreshold = null;

        /**
         * @return whether the components are scaled by the square root of their eigenvalue
         */
        public boolean isWhitened() {
            return this.whitened;
        }

        /**
         * @param z whether the components are scaled by the square root of their eigenvalue
         */
        public void setWhitened(boolean z) {
            this.whitened = z;
        }

        /**
         * @return the maximum number of components to keep, or {@code null} for no bound
         */
        public Integer getMaxDimensions() {
            return this.maxDimensions;
        }

        /**
         * @param num the maximum number of components to keep, or {@code null} for no bound;
         *            a non-positive value is rejected when the model is fitted
         */
        public void setMaxDimensions(Integer num) {
            this.maxDimensions = num;
        }

        /**
         * @return the variance fraction to reach, or {@code null} if the criterion is disabled
         */
        public Double getVariancePercentageThreshold() {
            return this.variancePercentageThreshold;
        }

        /**
         * @param d the variance fraction to reach; {@code null} or a value above 1 disables the
         *          criterion
         */
        public void setVariancePercentageThreshold(Double d) {
            this.variancePercentageThreshold = d;
        }
    }

    /**
     * Creates the selector with parallel processing enabled.
     *
     * @param str           name passed to the parent constructor
     * @param configuration configuration passed to the parent constructor
     */
    public PCA(String str, Configuration configuration) {
        super(str, configuration, ModelParameters.class, TrainingParameters.class);
        this.parallelized = true;
        this.streamExecutor = new ForkJoinStream(kb().getConf().getConcurrencyConfig());
    }

    /** {@inheritDoc} */
    @Override // com.datumbox.framework.core.machinelearning.common.interfaces.Parallelizable
    public boolean isParallelized() {
        return this.parallelized;
    }

    /** {@inheritDoc} */
    @Override // com.datumbox.framework.core.machinelearning.common.interfaces.Parallelizable
    public void setParallelized(boolean z) {
        this.parallelized = z;
    }

    /**
     * Learns the principal components of the training data and stores the column means, the
     * retained eigenvalues and the retained components in the model parameters.
     *
     * <p>The number of retained components is the smaller of {@code maxDimensions} and the number
     * of leading components needed to reach {@code variancePercentageThreshold}, whichever of the
     * two are configured; with neither configured all components are kept.</p>
     *
     * @param dataframe the training data
     * @throws IllegalArgumentException if {@code maxDimensions} is set to a non-positive value
     */
    @Override // com.datumbox.framework.core.machinelearning.common.abstracts.AbstractTrainer
    protected void _fit(Dataframe dataframe) {
        ModelParameters modelParameters = (ModelParameters) kb().getModelParameters();
        TrainingParameters trainingParameters = (TrainingParameters) kb().getTrainingParameters();

        // Fail fast with a clear message: a non-positive limit would otherwise only surface after
        // the decomposition, as a NegativeArraySizeException or a sub-matrix index error.
        Integer maxDimensions = trainingParameters.getMaxDimensions();
        if (maxDimensions != null && maxDimensions.intValue() <= 0) {
            throw new IllegalArgumentException("maxDimensions must be a positive integer but was " + maxDimensions);
        }

        int n = modelParameters.getN().intValue();
        int d = modelParameters.getD().intValue();

        Map<Object, Integer> featureIds = modelParameters.getFeatureIds();
        RealMatrix x = MatrixDataframe.newInstance(dataframe, false, (Map) null, featureIds).getX();

        modelParameters.setMean(centerColumns(x, featureIds, n, d));

        // d x d sample covariance of the centered data.
        RealMatrix covariance = x.transpose().multiply(x).scalarMultiply(1.0d / (n - 1.0d));
        EigenDecomposition decomposition = new EigenDecomposition(covariance);
        double[] eigenValues = decomposition.getRealEigenvalues();
        RealMatrix components = decomposition.getV();

        if (trainingParameters.isWhitened()) {
            // NOTE(review): this multiplies by sqrt(eigenvalue); conventional whitening divides by
            // it. Kept as is because stored models depend on it - confirm the intent.
            double[] scale = new double[eigenValues.length];
            for (int i = 0; i < eigenValues.length; i++) {
                // Rounding can produce a slightly negative eigenvalue for a rank-deficient
                // covariance; clamp it so the square root does not turn the column into NaN.
                scale[i] = Math.sqrt(Math.max(0.0d, eigenValues[i]));
            }
            components = components.multiply(new DiagonalMatrix(scale));
        }

        Double varianceThreshold = trainingParameters.getVariancePercentageThreshold();
        if (varianceThreshold != null && varianceThreshold.doubleValue() <= 1.0d) {
            int needed = countComponentsForVariance(eigenValues, varianceThreshold.doubleValue());
            if (maxDimensions == null || maxDimensions.intValue() > needed) {
                maxDimensions = Integer.valueOf(needed);
            }
        }

        if (maxDimensions != null && maxDimensions.intValue() < d) {
            // Keeps the leading entries/columns; this relies on the decomposition returning the
            // eigenvalues (and matching eigenvectors) in descending order.
            int keep = maxDimensions.intValue();
            double[] leadingEigenValues = new double[keep];
            System.arraycopy(eigenValues, 0, leadingEigenValues, 0, keep);
            eigenValues = leadingEigenValues;
            components = components.getSubMatrix(0, components.getRowDimension() - 1, 0, keep - 1);
        }

        modelParameters.setRows(components.getRowDimension());
        modelParameters.setCols(components.getColumnDimension());
        modelParameters.setEigenValues(eigenValues);
        modelParameters.setComponents(components.getData());
    }

    /**
     * Subtracts from every mapped column of {@code x} its mean, in place.
     *
     * @param x          the data matrix to center; modified by this call
     * @param featureIds the feature-to-column map whose values select the columns to center
     * @param n          the number of rows to center and the divisor of the mean
     * @param d          the length of the returned array
     * @return the mean of each centered column, indexed by column id; unmapped entries stay 0
     */
    private static double[] centerColumns(RealMatrix x, Map<Object, Integer> featureIds, int n, int d) {
        double[] means = new double[d];
        for (Integer columnId : featureIds.values()) {
            int column = columnId.intValue();
            double total = 0.0d;
            for (double value : x.getColumn(column)) {
                total += value;
            }
            double columnMean = total / n;
            means[column] = columnMean;
            for (int row = 0; row < n; row++) {
                x.addToEntry(row, column, -columnMean);
            }
        }
        return means;
    }

    /**
     * Counts how many leading eigenvalues are needed for their share of the eigenvalue sum to
     * reach the threshold.
     *
     * @param eigenValues the eigenvalues, in the order they should be accumulated
     * @param threshold   the cumulative fraction to reach
     * @return the number of eigenvalues consumed; all of them if the threshold is never reached
     */
    private static int countComponentsForVariance(double[] eigenValues, double threshold) {
        double totalVariance = StatUtils.sum(eigenValues);
        double cumulativeFraction = 0.0d;
        int count = 0;
        for (double eigenValue : eigenValues) {
            cumulativeFraction += eigenValue / totalVariance;
            count++;
            if (cumulativeFraction >= threshold) {
                break;
            }
        }
        return count;
    }

    /**
     * Projects every record of the dataframe onto the stored components and replaces its features
     * with the projected values, keyed {@code 0..k-1}. The response, predicted response and
     * predicted probabilities of each record are carried over unchanged.
     *
     * <p>NOTE(review): the training means stored by {@link #_fit(Dataframe)} are not subtracted
     * here before projecting - confirm whether that is intended.</p>
     *
     * @param dataframe the data to transform; modified in place
     */
    @Override // com.datumbox.framework.core.machinelearning.common.abstracts.featureselectors.AbstractFeatureSelector
    protected void filterFeatures(Dataframe dataframe) {
        ModelParameters modelParameters = (ModelParameters) kb().getModelParameters();
        Map<Object, Integer> featureIds = modelParameters.getFeatureIds();

        // Filled by parseDataset; read below to find the matrix row that belongs to a record id.
        Map<Integer, Integer> recordIdsReference = new HashMap<>();
        MatrixDataframe matrixDataset = MatrixDataframe.parseDataset(dataframe, recordIdsReference, featureIds);

        RealMatrix components = new BlockRealMatrix(modelParameters.getComponents());
        final RealMatrix projected = matrixDataset.getX().multiply(components);

        this.streamExecutor.forEach(StreamMethods.stream(dataframe.entries(), isParallelized()), entry -> {
            Integer recordId = (Integer) entry.getKey();
            Record record = (Record) entry.getValue();
            int rowId = recordIdsReference.get(recordId).intValue();

            AssociativeArray projectedFeatures = new AssociativeArray();
            double[] projectedRow = projected.getRow(rowId);
            for (int componentId = 0; componentId < projectedRow.length; componentId++) {
                projectedFeatures.put(Integer.valueOf(componentId), Double.valueOf(projectedRow[componentId]));
            }

            dataframe._unsafe_set(recordId, new Record(projectedFeatures, record.getY(), record.getYPredicted(), record.getYPredictedProbabilities()));
        });

        // The feature set of every record changed, so the dataframe metadata must be rebuilt.
        dataframe.recalculateMeta();
    }
}
