package com.zavtech.morpheus.reference.regress;

import com.zavtech.morpheus.array.Array;
import com.zavtech.morpheus.frame.DataFrame;
import com.zavtech.morpheus.frame.DataFrameContent;
import com.zavtech.morpheus.frame.DataFrameException;
import com.zavtech.morpheus.frame.DataFrameLeastSquares;
import com.zavtech.morpheus.range.Range;
import com.zavtech.morpheus.util.text.printer.Printer;
import java.io.ByteArrayOutputStream;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;
import java.util.stream.Stream;
import org.apache.commons.math3.distribution.FDistribution;
import org.apache.commons.math3.distribution.TDistribution;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.QRDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;

/* loaded from: input_file:com/zavtech/morpheus/reference/regress/XDataFrameLeastSquares.class */
abstract class XDataFrameLeastSquares<R, C> implements DataFrameLeastSquares<R, C> {

    /** Every result field, in declaration order; the column keys of the beta and intercept frames. */
    private static final List<DataFrameLeastSquares.Field> fields = Arrays.asList(DataFrameLeastSquares.Field.values());

    private String name;
    private double alpha;
    private C regressand;
    private List<C> regressors;
    private DataFrameLeastSquares.Solver solver;
    private DataFrame<R, C> frame;
    private double tss;
    private double rss;
    private double ess;
    private double stdError;
    private double threshold;
    private double rSquared;
    private double rSquaredAdj;
    private double errorVariance;
    private boolean hasIntercept;
    private double fValue;
    private double fValueProbability;
    private long runtimeMillis;
    private DataFrame<C, DataFrameLeastSquares.Field> betas;
    /** Residuals of the last successful fit; {@code null} means the model must be (re)fitted. */
    private DataFrame<R, String> residuals;
    private DataFrame<String, DataFrameLeastSquares.Field> intercept;

    /**
     * Creates a least squares model of {@code regressand} against {@code regressors}.
     * Defaults: significance level 0.05, QR solver, QR rank threshold 0.
     *
     * @param name         the model name shown in {@link #toString()}
     * @param frame        the frame holding the regressand and regressor columns
     * @param regressand   the column key of the dependent variable
     * @param regressors   the column keys of the independent variables (defensively copied)
     * @param hasIntercept true to add a constant column to the design matrix
     * @throws DataFrameException if {@code regressors} is empty
     */
    public XDataFrameLeastSquares(String name, DataFrame<R, C> frame, C regressand, List<C> regressors, boolean hasIntercept) {
        if (regressors.isEmpty()) {
            throw new DataFrameException("At least one regressor must be specified");
        }
        this.name = name;
        this.alpha = 0.05d;
        this.frame = frame;
        this.threshold = 0.0d;
        this.solver = DataFrameLeastSquares.Solver.QR;
        this.regressand = regressand;
        this.hasIntercept = hasIntercept;
        this.regressors = new ArrayList<>(regressors);
        this.betas = DataFrame.ofDoubles(this.regressors, fields);
        this.intercept = DataFrame.ofDoubles("Intercept", fields);
    }

    /** Performs the actual fit; implemented by subclasses. */
    abstract void compute();

    /**
     * Fits the model and records the wall-clock runtime in milliseconds.
     * If fitting fails, the model is marked dirty again so that no partially computed
     * results are ever served by the getters.
     *
     * @throws DataFrameException if the regression fails for any reason
     */
    @Override
    public void fit() {
        try {
            final long start = System.nanoTime();
            compute();
            this.runtimeMillis = (System.nanoTime() - start) / 1000000L;
        } catch (DataFrameException ex) {
            this.residuals = null;
            throw ex;
        } catch (Exception ex) {
            this.residuals = null;
            throw new DataFrameException("Failed while running linear regression of " + getRegressand() + " on " + getRegressors(), ex);
        }
    }

    /** Returns the frame this model regresses over. */
    public DataFrame<R, C> frame() {
        return this.frame;
    }

    /** True when no successful fit is cached (see {@link #residuals}). */
    private boolean isDirty() {
        return this.residuals == null;
    }

    /** Fits the model lazily, only if no current results exist. */
    private void computeIf() {
        if (isDirty()) {
            fit();
        }
    }

    /**
     * Builds the response vector y from the regressand column, one entry per frame row ordinal.
     */
    public RealVector createY() {
        final int rowCount = frame.rows().count();
        final int column = frame.cols().ordinalOf(regressand);
        final ArrayRealVector y = new ArrayRealVector(rowCount);
        for (int row = 0; row < rowCount; row++) {
            y.setEntry(row, frame.data().getDouble(row, column));
        }
        return y;
    }

    /**
     * Builds the design matrix X, one row per frame row ordinal.
     * When the model has an intercept, column 0 is a column of ones and the regressors follow;
     * otherwise the columns are exactly the regressors, in order.
     */
    public RealMatrix createX() {
        final int rowCount = frame.rows().count();
        final int offset = hasIntercept() ? 1 : 0;
        final int[] columns = regressors.stream().mapToInt(regressor -> frame.cols().ordinalOf(regressor)).toArray();
        final Array2DRowRealMatrix x = new Array2DRowRealMatrix(rowCount, columns.length + offset);
        for (int row = 0; row < rowCount; row++) {
            if (offset == 1) {
                x.setEntry(row, 0, 1.0d);
            }
            for (int j = 0; j < columns.length; j++) {
                x.setEntry(row, j + offset, frame.data().getDouble(row, columns[j]));
            }
        }
        return x;
    }

    /**
     * Solves the regression of {@code y} on {@code x} and populates every summary statistic:
     * sums of squares, F-statistic and its tail probability, R-squared (plain and adjusted),
     * parameter standard errors, t-stats, p-values and confidence intervals.
     *
     * @param y the response vector
     * @param x the design matrix, laid out as produced by {@link #createX()}
     */
    protected void compute(RealVector y, RealMatrix x) {
        final int n = frame.rows().count();
        final int k = regressors.size();
        final int p = k + (hasIntercept() ? 1 : 0);
        // computeBeta also sets rss, errorVariance, stdError and residuals.
        final RealMatrix estimates = computeBeta(y, x);
        final RealVector betaVector = estimates.getColumnVector(0);
        final RealVector varianceVector = estimates.getColumnVector(1);
        this.tss = computeTSS(y);
        this.ess = tss - rss;
        this.fValue = (ess / k) / (rss / (n - p));
        this.fValueProbability = 1.0d - new FDistribution(k, n - p).cumulativeProbability(fValue);
        this.rSquared = 1.0d - (rss / tss);
        this.rSquaredAdj = 1.0d - ((rss * (n - (hasIntercept() ? 1 : 0))) / (tss * (n - p)));
        computeParameterStdErrors(varianceVector);
        computeParameterSignificance(betaVector);
    }

    /**
     * Estimates the coefficients with the configured solver.
     * The non-QR path uses the normal equations, {@code beta = (X'X)^-1 X'y}.
     *
     * @return a p x 2 matrix: column 0 holds the coefficients, column 1 their variances
     */
    private RealMatrix computeBeta(RealVector y, RealMatrix x) {
        if (solver == DataFrameLeastSquares.Solver.QR) {
            return computeBetaQR(y, x);
        }
        final RealMatrix xTranspose = x.transpose();
        final RealMatrix xtxInverse = new LUDecomposition(xTranspose.multiply(x)).getSolver().getInverse();
        final RealVector beta = xtxInverse.multiply(xTranspose).operate(y);
        return collectEstimates(y, x, beta, xtxInverse);
    }

    /**
     * Estimates the coefficients via QR decomposition. The unscaled covariance is
     * {@code R^-1 (R^-1)'}, computed from the leading p x p block of R.
     *
     * @return a p x 2 matrix: column 0 holds the coefficients, column 1 their variances
     */
    private RealMatrix computeBetaQR(RealVector y, RealMatrix x) {
        final int p = x.getColumnDimension();
        final QRDecomposition qr = new QRDecomposition(x, threshold);
        final RealVector beta = qr.getSolver().solve(y);
        final RealMatrix rInverse = new LUDecomposition(qr.getR().getSubMatrix(0, p - 1, 0, p - 1)).getSolver().getInverse();
        return collectEstimates(y, x, beta, rInverse.multiply(rInverse.transpose()));
    }

    /**
     * Stores residual-based statistics (RSS, error variance, standard error, residuals frame)
     * and packs the coefficients with their variances (unscaled covariance diagonal times the error variance).
     *
     * @return a p x 2 matrix: column 0 holds the coefficients, column 1 their variances
     */
    private RealMatrix collectEstimates(RealVector y, RealMatrix x, RealVector beta, RealMatrix unscaledCovariance) {
        final int n = x.getRowDimension();
        final int p = x.getColumnDimension();
        final RealVector residualVector = y.subtract(x.operate(beta));
        this.rss = residualVector.dotProduct(residualVector);
        this.errorVariance = rss / (n - p);
        this.stdError = Math.sqrt(errorVariance);
        this.residuals = createResidualsFrame(residualVector);
        final Array2DRowRealMatrix result = new Array2DRowRealMatrix(p, 2);
        for (int j = 0; j < p; j++) {
            result.setEntry(j, 0, beta.getEntry(j));
            result.setEntry(j, 1, unscaledCovariance.getEntry(j, j) * errorVariance);
        }
        return result;
    }

    /** Wraps the residual vector in a single-column frame keyed like the source frame's rows. */
    private DataFrame<R, String> createResidualsFrame(RealVector residualVector) {
        final DataFrame<R, String> result = DataFrame.ofDoubles(frame.rows().keyArray(), "Residuals");
        result.applyDoubles(value -> residualVector.getEntry(value.rowOrdinal()));
        return result;
    }

    /**
     * Writes the standard error (square root of each variance) into the intercept frame,
     * if the model has one, and into the beta frame.
     *
     * @param varianceVector coefficient variances, ordered like the design matrix columns
     */
    private void computeParameterStdErrors(RealVector varianceVector) {
        try {
            final int offset = hasIntercept() ? 1 : 0;
            if (hasIntercept()) {
                intercept.data().setDouble(0, DataFrameLeastSquares.Field.STD_ERROR, Math.sqrt(varianceVector.getEntry(0)));
            }
            for (int i = 0; i < regressors.size(); i++) {
                betas.data().setDouble(i, DataFrameLeastSquares.Field.STD_ERROR, Math.sqrt(varianceVector.getEntry(i + offset)));
            }
        } catch (Exception ex) {
            throw new DataFrameException("Failed to calculate regression coefficient standard errors", ex);
        }
    }

    /**
     * Computes parameter values, t-stats, two-sided p-values and (1 - alpha) confidence intervals.
     * The t-distribution uses n - p degrees of freedom, where p counts the intercept only when present.
     * Without an intercept, the intercept frame is reset: PARAMETER is 0 (the model has no constant term)
     * and every other field is NaN. This also clears statistics left over from an earlier fit.
     *
     * @param betaVector coefficients, ordered like the design matrix columns
     */
    private void computeParameterSignificance(RealVector betaVector) {
        try {
            final int offset = hasIntercept() ? 1 : 0;
            final int degreesOfFreedom = frame.rows().count() - (regressors.size() + offset);
            final TDistribution distribution = new TDistribution(degreesOfFreedom);
            final double tCritical = distribution.inverseCumulativeProbability(1.0d - (alpha / 2.0d));
            if (hasIntercept()) {
                storeSignificance(intercept, 0, betaVector.getEntry(0), distribution, tCritical);
            } else {
                for (DataFrameLeastSquares.Field field : fields) {
                    intercept.data().setDouble(0, field, Double.NaN);
                }
                intercept.data().setDouble(0, DataFrameLeastSquares.Field.PARAMETER, 0.0d);
            }
            for (int i = 0; i < regressors.size(); i++) {
                storeSignificance(betas, i, betaVector.getEntry(i + offset), distribution, tCritical);
            }
        } catch (Exception ex) {
            throw new DataFrameException("Failed to compute regression coefficient t-stats and p-values", ex);
        }
    }

    /**
     * Writes PARAMETER, T_STAT, P_VALUE, CI_LOWER and CI_UPPER for one row of {@code target}.
     * Reads that row's STD_ERROR, which must already be populated.
     */
    private <K> void storeSignificance(DataFrame<K, DataFrameLeastSquares.Field> target, int row, double parameter, TDistribution distribution, double tCritical) {
        final double stdErr = target.data().getDouble(row, DataFrameLeastSquares.Field.STD_ERROR);
        final double tStat = parameter / stdErr;
        final double pValue = distribution.cumulativeProbability(-Math.abs(tStat)) * 2.0d;
        final double halfWidth = stdErr * tCritical;
        target.data().setDouble(row, DataFrameLeastSquares.Field.PARAMETER, parameter);
        target.data().setDouble(row, DataFrameLeastSquares.Field.T_STAT, tStat);
        target.data().setDouble(row, DataFrameLeastSquares.Field.P_VALUE, pValue);
        target.data().setDouble(row, DataFrameLeastSquares.Field.CI_LOWER, parameter - halfWidth);
        target.data().setDouble(row, DataFrameLeastSquares.Field.CI_UPPER, parameter + halfWidth);
    }

    /**
     * Total sum of squares. It is centred on the mean when the model has an intercept,
     * and uncentred ({@code y'y}) otherwise.
     */
    protected double computeTSS(RealVector y) {
        if (!hasIntercept()) {
            return y.dotProduct(y);
        }
        final double[] values = y.toArray();
        final double mean = DoubleStream.of(values).average().orElse(Double.NaN);
        final ArrayRealVector deviations = new ArrayRealVector(DoubleStream.of(values).map(v -> v - mean).toArray());
        return deviations.dotProduct(deviations);
    }

    @Override
    public int getN() {
        return this.frame.rowCount();
    }

    @Override
    public C getRegressand() {
        return this.regressand;
    }

    @Override
    public boolean hasIntercept() {
        return this.hasIntercept;
    }

    /** Returns a read-only view of the regressor keys. */
    @Override
    public List<C> getRegressors() {
        return Collections.unmodifiableList(this.regressors);
    }

    @Override
    public double getAlpha() {
        return this.alpha;
    }

    @Override
    public double getRSquared() {
        computeIf();
        return this.rSquared;
    }

    @Override
    public double getRSquaredAdj() {
        computeIf();
        return this.rSquaredAdj;
    }

    @Override
    public double getStdError() {
        computeIf();
        return this.stdError;
    }

    @Override
    public double getFValue() {
        computeIf();
        return this.fValue;
    }

    @Override
    public double getFValueProbability() {
        computeIf();
        return this.fValueProbability;
    }

    @Override
    public double getTotalSumOfSquares() {
        computeIf();
        return this.tss;
    }

    @Override
    public double getExplainedSumOfSquares() {
        computeIf();
        return this.ess;
    }

    @Override
    public double getResidualSumOfSquares() {
        computeIf();
        return this.rss;
    }

    @Override
    public double getInterceptValue(DataFrameLeastSquares.Field field) {
        computeIf();
        return this.intercept.data().getDouble(0, field);
    }

    @Override
    public double getBetaValue(C regressor, DataFrameLeastSquares.Field field) {
        computeIf();
        return this.betas.data().getDouble(regressor, field);
    }

    @Override
    public DataFrame<C, DataFrameLeastSquares.Field> getBetas() {
        computeIf();
        return this.betas;
    }

    @Override
    public DataFrame<String, DataFrameLeastSquares.Field> getIntercept() {
        computeIf();
        return this.intercept;
    }

    @Override
    public DataFrame<R, String> getResiduals() {
        computeIf();
        return this.residuals;
    }

    /**
     * Returns a single-column frame ("Fitted") holding intercept + sum(beta_j * x_j) for each row.
     * The intercept term is added only when the model has one.
     */
    @Override
    public DataFrame<R, String> getFittedValues() {
        try {
            computeIf();
            final double interceptValue = hasIntercept() ? getInterceptValue(DataFrameLeastSquares.Field.PARAMETER) : 0.0d;
            final int k = regressors.size();
            final double[] coefficients = new double[k];
            final int[] columns = new int[k];
            for (int j = 0; j < k; j++) {
                final C regressor = regressors.get(j);
                coefficients[j] = getBetaValue(regressor, DataFrameLeastSquares.Field.PARAMETER);
                columns[j] = frame.cols().ordinalOf(regressor);
            }
            // The result rows are created from frame.rows().keyArray(), so row ordinals line up with the source frame.
            return DataFrame.ofDoubles(frame.rows().keyArray(), Array.of("Fitted"), value -> {
                final int row = value.rowOrdinal();
                double fitted = interceptValue;
                for (int j = 0; j < k; j++) {
                    fitted += frame.data().getDouble(row, columns[j]) * coefficients[j];
                }
                return fitted;
            });
        } catch (Exception ex) {
            throw new DataFrameException("Failed to compute regression fitted values", ex);
        }
    }

    /**
     * Durbin-Watson statistic, {@code sum_{t>=1}(e_t - e_{t-1})^2 / sum_{t>=0} e_t^2}.
     * The denominator includes every residual, the first one among them.
     */
    @Override
    public double getDurbinWatsonStatistic() {
        try {
            computeIf();
            final int rowCount = residuals.rowCount();
            double sumOfSquares = 0.0d;
            double sumOfSquaredDiffs = 0.0d;
            double previous = 0.0d;
            for (int i = 0; i < rowCount; i++) {
                final double current = residuals.data().getDouble(i, 0);
                sumOfSquares += current * current;
                if (i > 0) {
                    final double diff = current - previous;
                    sumOfSquaredDiffs += diff * diff;
                }
                previous = current;
            }
            return sumOfSquaredDiffs / sumOfSquares;
        } catch (Exception ex) {
            throw new DataFrameException("Failed to compute the Durbin-Watson Statistic", ex);
        }
    }

    /**
     * Returns the residual autocorrelation for each lag key in {@code Range.of(0, maxLag)}.
     */
    @Override
    public DataFrame<Integer, String> getResidualsAcf(int maxLag) {
        try {
            computeIf();
            return DataFrame.ofDoubles(Range.of(0, maxLag), Array.of("Residual(ACF)"), value -> {
                final int lag = (Integer) value.rowKey();
                return residuals.colAt(0).stats().autocorr(lag).doubleValue();
            });
        } catch (Exception ex) {
            throw new DataFrameException("Failed to compute the autocorrelation function of residuals", ex);
        }
    }

    /** Sets the solver; a change marks the model for refitting. */
    @Override
    public DataFrameLeastSquares<R, C> withSolver(DataFrameLeastSquares.Solver solver) {
        if (solver != this.solver) {
            this.solver = solver;
            this.residuals = null;
        }
        return this;
    }

    /** Sets the significance level used for confidence intervals; a change marks the model for refitting. */
    @Override
    public DataFrameLeastSquares<R, C> withAlpha(double alpha) {
        if (alpha != this.alpha) {
            this.alpha = alpha;
            this.residuals = null;
        }
        return this;
    }

    /** Toggles the intercept term; a change marks the model for refitting. */
    @Override
    public DataFrameLeastSquares<R, C> withIntercept(boolean hasIntercept) {
        if (this.hasIntercept != hasIntercept) {
            this.hasIntercept = hasIntercept;
            this.residuals = null;
        }
        return this;
    }

    /**
     * Renders a text report: a header block with model-level statistics, followed by the parameter table.
     * Fits the model first if needed.
     */
    @Override
    public String toString() {
        computeIf();
        final String title = "Linear Regression Results";
        final int n = frame.rows().count();
        final int k = regressors.size();
        final int p = k + (hasIntercept() ? 1 : 0);
        final String text = toText(100, getSummary());
        // The second line of the printed table sets the width of the whole report.
        final String[] tableLines = text.trim().split("\n");
        final int length = tableLines.length > 1 ? tableLines[1].length() : tableLines[0].length();
        final int width = Math.max(1, (length - 4) / 4);
        final DecimalFormat decimalFormat = new DecimalFormat("0.###E0;-0.###E0");
        final String str = "%-" + width + "s%" + width + "s    %-" + width + "s%" + width + "s";
        final int leftPad = Math.max(0, (length / 2) - (title.length() / 2));
        final int rightPad = Math.max(0, length - leftPad - title.length());
        final StringBuilder sb = new StringBuilder();
        sb.append("\n").append(lineOf('=', length));
        sb.append("\n").append(lineOf(' ', leftPad)).append(title).append(lineOf(' ', rightPad));
        sb.append("\n").append(lineOf('=', length));
        sb.append("\n").append(String.format(str, "Model:", name, "R-Squared:", String.format("%.4f", getRSquared())));
        sb.append("\n").append(String.format(str, "Observations:", n, "R-Squared(adjusted):", String.format("%.4f", getRSquaredAdj())));
        sb.append("\n").append(String.format(str, "DF Model:", k, "F-Statistic:", String.format("%.4f", getFValue())));
        sb.append("\n").append(String.format(str, "DF Residuals:", n - p, "F-Statistic(Prob):", decimalFormat.format(getFValueProbability())));
        sb.append("\n").append(String.format(str, "Standard Error:", String.format("%.4f", getStdError()), "Runtime(millis)", runtimeMillis));
        sb.append("\n").append(String.format(str, "Durbin-Watson:", String.format("%.4f", getDurbinWatsonStatistic()), "", ""));
        sb.append("\n").append(lineOf('=', length));
        sb.append(text);
        sb.append("\n").append(lineOf('=', length));
        return sb.toString();
    }

    /** Returns {@code count} copies of {@code c}; an empty string when {@code count} is not positive. */
    private String lineOf(char c, int count) {
        final char[] chars = new char[Math.max(0, count)];
        Arrays.fill(chars, c);
        return new String(chars);
    }

    /**
     * Builds the parameter table: an optional "Intercept" row first, then one row per regressor.
     * Rows are resolved by ordinal, so a regressor that happens to be named "Intercept" cannot be confused with the constant term.
     */
    private DataFrame<Object, DataFrameLeastSquares.Field> getSummary() {
        final int offset = hasIntercept() ? 1 : 0;
        final List<Object> rowKeys = new ArrayList<>(regressors.size() + offset);
        if (hasIntercept()) {
            rowKeys.add("Intercept");
        }
        rowKeys.addAll(regressors);
        return DataFrame.ofDoubles(rowKeys, fields, value -> {
            final DataFrameLeastSquares.Field field = (DataFrameLeastSquares.Field) value.colKey();
            final int rowOrdinal = value.rowOrdinal();
            if (rowOrdinal < offset) {
                return intercept.data().getDouble(0, field);
            }
            return betas.data().getDouble(rowOrdinal - offset, field);
        });
    }

    /** Prints {@code dataFrame} (up to {@code maxRows} rows) to a string with per-field number formats. */
    private String toText(int maxRows, DataFrame<?, DataFrameLeastSquares.Field> dataFrame) {
        final String decimal = "#.####;-#.####";
        final ByteArrayOutputStream buffer = new ByteArrayOutputStream(1024);
        dataFrame.out().print(maxRows, buffer, formats -> {
            formats.setPrinter(DataFrameLeastSquares.Field.PARAMETER, Printer.ofDouble(decimal));
            formats.setPrinter(DataFrameLeastSquares.Field.STD_ERROR, Printer.ofDouble(decimal));
            formats.setPrinter(DataFrameLeastSquares.Field.T_STAT, Printer.ofDouble(decimal));
            formats.setPrinter(DataFrameLeastSquares.Field.P_VALUE, Printer.ofDouble("0.###E0;-0.###E0"));
            formats.setPrinter(DataFrameLeastSquares.Field.CI_LOWER, Printer.ofDouble(decimal));
            formats.setPrinter(DataFrameLeastSquares.Field.CI_UPPER, Printer.ofDouble(decimal));
        });
        return new String(buffer.toByteArray());
    }
}
