package com.zavtech.morpheus.reference;

import com.zavtech.morpheus.array.Array;
import com.zavtech.morpheus.frame.DataFrame;
import com.zavtech.morpheus.frame.DataFrameColumns;
import com.zavtech.morpheus.frame.DataFrameException;
import com.zavtech.morpheus.frame.DataFramePCA;
import com.zavtech.morpheus.frame.DataFrameRow;
import com.zavtech.morpheus.jama.EigenDecomposition;
import com.zavtech.morpheus.jama.Matrix;
import com.zavtech.morpheus.range.Range;
import com.zavtech.morpheus.stats.StatType;
import com.zavtech.morpheus.util.Asserts;
import com.zavtech.morpheus.util.IntComparator;
import com.zavtech.morpheus.util.SortAlgorithm;
import com.zavtech.morpheus.util.Swapper;
import java.util.Optional;
import java.util.function.Function;
import java.util.function.IntFunction;
import java.util.stream.DoubleStream;
import java.util.stream.IntStream;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.SingularValueDecomposition;

/* JADX INFO: Access modifiers changed from: package-private */
/* loaded from: input_file:com/zavtech/morpheus/reference/XDataFramePCA.class */
public class XDataFramePCA<R, C> implements DataFramePCA<R, C> {
    private XDataFrame<R, C> frame;

    /* loaded from: input_file:com/zavtech/morpheus/reference/XDataFramePCA$ModelBase.class */
    private abstract class ModelBase<X, Y> implements DataFramePCA.Model<X, Y> {
        private DataFrame<X, Y> data;
        private DataFrame<Y, StatType> means;
        private DataFrame<Integer, DataFramePCA.Field> eigenValues;
        private DataFrame<Integer, Integer> eigenVectors;

        ModelBase(DataFrame<X, Y> dataFrame, boolean z) {
            this.data = dataFrame;
            if (z) {
                this.means = dataFrame.cols().stats().mean();
                this.data = dataFrame.cols().demean(false);
            }
        }

        abstract boolean calculate();

        DataFrame<X, Y> data() {
            return this.data;
        }

        public boolean isDemean() {
            return this.means != null;
        }

        boolean isDirty() {
            return this.eigenValues == null;
        }

        protected void update(double[] dArr, Matrix matrix) {
            Ordering ordering = new Ordering(dArr);
            SortAlgorithm.getDefault(false).sort(0, dArr.length, ordering, ordering);
            int[] indices = ordering.getIndices();
            this.eigenValues = DataFrame.ofDoubles((Iterable) Range.of(0, dArr.length), (Iterable) Array.of(DataFramePCA.Field.EIGENVALUE));
            this.eigenValues.applyDoubles(dataFrameValue -> {
                return dArr[indices[dataFrameValue.rowOrdinal()]];
            });
            addVariancePercentages();
            this.eigenVectors = DataFrame.of(Range.of(0, dArr.length), Integer.class, dataFrameColumns -> {
                for (int i = 0; i < indices.length; i++) {
                    int i2 = indices[i];
                    dataFrameColumns.add((DataFrameColumns) Integer.valueOf(i), Array.of(Double.class, dArr.length).applyDoubles(arrayValue -> {
                        return matrix.get(arrayValue.index(), i2);
                    }));
                }
            });
        }

        protected void update(double[] dArr, RealMatrix realMatrix) {
            Ordering ordering = new Ordering(dArr);
            SortAlgorithm.getDefault(false).sort(0, dArr.length, ordering, ordering);
            int[] indices = ordering.getIndices();
            this.eigenValues = DataFrame.ofDoubles((Iterable) Range.of(0, dArr.length), (Iterable) Array.of(DataFramePCA.Field.EIGENVALUE));
            this.eigenValues.applyDoubles(dataFrameValue -> {
                return dArr[indices[dataFrameValue.rowOrdinal()]];
            });
            addVariancePercentages();
            this.eigenVectors = DataFrame.of(Range.of(0, dArr.length), Integer.class, dataFrameColumns -> {
                for (int i = 0; i < indices.length; i++) {
                    int i2 = indices[i];
                    dataFrameColumns.add((DataFrameColumns) Integer.valueOf(i), Array.of(Double.class, dArr.length).applyDoubles(arrayValue -> {
                        return realMatrix.getEntry(arrayValue.index(), i2);
                    }));
                }
            });
        }

        private void addVariancePercentages() {
            double doubleValue = this.eigenValues.col(DataFramePCA.Field.EIGENVALUE).stats().sum().doubleValue();
            this.eigenValues.cols().add(DataFramePCA.Field.VAR_PERCENT, Double.class, dataFrameValue -> {
                return Double.valueOf(dataFrameValue.row().getDouble((DataFrameRow<R, C>) DataFramePCA.Field.EIGENVALUE) / doubleValue);
            });
            this.eigenValues.cols().add(DataFramePCA.Field.VAR_PERCENT_CUM, Double.class, dataFrameValue2 -> {
                switch (dataFrameValue2.rowOrdinal()) {
                    case 0:
                        return Double.valueOf(dataFrameValue2.row().getDouble((DataFrameRow<R, C>) DataFramePCA.Field.VAR_PERCENT));
                    default:
                        return Double.valueOf(dataFrameValue2.row().getDouble((DataFrameRow<R, C>) DataFramePCA.Field.VAR_PERCENT) + dataFrameValue2.col().getDouble(dataFrameValue2.rowOrdinal() - 1));
                }
            });
        }

        @Override // com.zavtech.morpheus.frame.DataFramePCA.Model
        public DataFrame<Integer, Integer> getEigenVectors() {
            calculate();
            return this.eigenVectors;
        }

        @Override // com.zavtech.morpheus.frame.DataFramePCA.Model
        public DataFrame<Integer, DataFramePCA.Field> getEigenValues() {
            calculate();
            return this.eigenValues;
        }

        @Override // com.zavtech.morpheus.frame.DataFramePCA.Model
        public DataFrame<X, Integer> getScores() {
            calculate();
            return getScores(this.eigenValues.rowCount());
        }

        @Override // com.zavtech.morpheus.frame.DataFramePCA.Model
        public DataFrame<X, Y> getProjection(int i) {
            calculate();
            Matrix times = getScores(i).export().asMatrix().times(this.eigenVectors.cols().select(dataFrameColumn -> {
                return dataFrameColumn.ordinal() < i;
            }).transpose().export().asMatrix());
            return !isDemean() ? this.data.mapToDoubles(dataFrameValue -> {
                return times.get(dataFrameValue.rowOrdinal(), dataFrameValue.colOrdinal());
            }) : this.data.mapToDoubles(dataFrameValue2 -> {
                return times.get(dataFrameValue2.rowOrdinal(), dataFrameValue2.colOrdinal()) + this.means.data().getDouble(dataFrameValue2.colOrdinal(), 0);
            });
        }

        @Override // com.zavtech.morpheus.frame.DataFramePCA.Model
        public DataFrame<X, Integer> getScores(int i) {
            calculate();
            Matrix times = this.data.export().asMatrix().times(this.eigenVectors.cols().select(dataFrameColumn -> {
                return dataFrameColumn.ordinal() < i;
            }).export().asMatrix());
            return DataFrame.ofDoubles(data().rows().keyArray(), Range.of(0, i), dataFrameValue -> {
                return times.get(dataFrameValue.rowOrdinal(), dataFrameValue.colOrdinal());
            });
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/zavtech/morpheus/reference/XDataFramePCA$ModelWithEVD.class */
    public class ModelWithEVD<X, Y> extends XDataFramePCA<R, C>.ModelBase<X, Y> {
        private boolean cov;

        private ModelWithEVD(XDataFrame<X, Y> xDataFrame, boolean z, boolean z2) {
            super(xDataFrame, z);
            this.cov = z2;
        }

        @Override // com.zavtech.morpheus.reference.XDataFramePCA.ModelBase
        boolean calculate() {
            if (!isDirty()) {
                return false;
            }
            if (data().colCount() > data().rowCount()) {
                throw new DataFrameException("PCA Analysis expects frame as nxp matrix where n>=p, transpose and try again");
            }
            try {
                long currentTimeMillis = System.currentTimeMillis();
                EigenDecomposition eigenDecomposition = new EigenDecomposition((this.cov ? data().cols().stats().covariance() : data().cols().stats().correlation()).export().asMatrix());
                update(eigenDecomposition.getRealEigenvalues(), eigenDecomposition.getV());
                System.out.println("Eigen Decomposition completed in " + (System.currentTimeMillis() - currentTimeMillis) + " millis");
                return true;
            } catch (Exception e) {
                throw new DataFrameException("Failed to perform SVD on input data for PCA", e);
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/zavtech/morpheus/reference/XDataFramePCA$ModelWithSVD.class */
    public class ModelWithSVD<X, Y> extends XDataFramePCA<R, C>.ModelBase<X, Y> {
        private ModelWithSVD(XDataFrame<X, Y> xDataFrame, boolean z) {
            super(xDataFrame, z);
        }

        @Override // com.zavtech.morpheus.reference.XDataFramePCA.ModelBase
        boolean calculate() {
            if (!isDirty()) {
                return false;
            }
            if (data().colCount() > data().rowCount()) {
                throw new DataFrameException("PCA Analysis expects frame as nxp matrix where n>=p, transpose and try again");
            }
            try {
                double rowCount = data().rowCount();
                long currentTimeMillis = System.currentTimeMillis();
                SingularValueDecomposition singularValueDecomposition = new SingularValueDecomposition(data().export().asApacheMatrix());
                update(DoubleStream.of(singularValueDecomposition.getSingularValues()).map(d -> {
                    return Math.pow(d, 2.0d) / (rowCount - 1.0d);
                }).toArray(), singularValueDecomposition.getV());
                System.out.println("SVD Decomposition completed in " + (System.currentTimeMillis() - currentTimeMillis) + " millis");
                return true;
            } catch (Exception e) {
                throw new DataFrameException("Failed to perform SVD on input data for PCA", e);
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:com/zavtech/morpheus/reference/XDataFramePCA$Ordering.class */
    public class Ordering implements Swapper, IntComparator {
        private double[] values;
        private int[] indices;

        Ordering(double[] dArr) {
            this.values = dArr;
            this.indices = IntStream.range(0, dArr.length).toArray();
        }

        int[] getIndices() {
            return this.indices;
        }

        @Override // com.zavtech.morpheus.util.IntComparator
        public int compare(int i, int i2) {
            return (-1) * Double.compare(this.values[this.indices[i]], this.values[this.indices[i2]]);
        }

        @Override // com.zavtech.morpheus.util.Swapper
        public void swap(int i, int i2) {
            int i3 = this.indices[i];
            this.indices[i] = this.indices[i2];
            this.indices[i2] = i3;
        }
    }

    /* JADX INFO: Access modifiers changed from: package-private */
    public XDataFramePCA(XDataFrame<R, C> xDataFrame) {
        this.frame = xDataFrame;
    }

    @Override // com.zavtech.morpheus.frame.DataFramePCA
    public <T> Optional<T> apply(boolean z, Function<DataFramePCA.Model<R, C>, Optional<T>> function) {
        return apply(z, DataFramePCA.Solver.SVD, function);
    }

    @Override // com.zavtech.morpheus.frame.DataFramePCA
    public <T> Optional<T> apply(boolean z, DataFramePCA.Solver solver, Function<DataFramePCA.Model<R, C>, Optional<T>> function) {
        Asserts.notNull(solver, "The PCA Solver cannot be null");
        Asserts.notNull(function, "The PCA lambda handler cannot be null");
        switch (solver) {
            case SVD:
                return function.apply(new ModelWithSVD(this.frame, z));
            case EVD_COV:
                return function.apply(new ModelWithEVD(this.frame, z, true));
            case EVD_COR:
                return function.apply(new ModelWithEVD(this.frame, z, false));
            default:
                throw new DataFrameException("Unsupported PCA solver specified: " + solver);
        }
    }
}
