package com.github.signaflo.data.regression;

import com.github.signaflo.data.DoublePair;
import com.github.signaflo.math.linear.doubles.Matrix;
import com.github.signaflo.math.linear.doubles.QuadraticForm;
import com.github.signaflo.math.linear.doubles.Vector;
import com.github.signaflo.math.stats.distributions.StudentsT;
import java.util.ArrayList;
import java.util.List;

/* loaded from: input_file:com/github/signaflo/data/regression/MultipleLinearRegressionPredictor.class */
/**
 * Produces point estimates, confidence intervals and prediction intervals from a fitted
 * {@link MultipleLinearRegression}.
 *
 * <p>For a predictor row {@code x} (including the intercept term when the model has one):
 * <ul>
 *   <li>estimate = {@code x · beta}</li>
 *   <li>standard error of the fit = {@code sqrt(sigma2 * x' (X'X)^-1 x)}</li>
 *   <li>confidence interval = {@code estimate ± t * seFit}</li>
 *   <li>prediction interval = {@code estimate ± t * sqrt(sigma2 + seFit^2)}</li>
 * </ul>
 * where {@code t} is the Student-t quantile at {@code 1 - alpha / 2} with the model's residual
 * degrees of freedom, so each interval is two-sided with nominal coverage {@code 1 - alpha}.
 */
public class MultipleLinearRegressionPredictor implements LinearRegressionPredictor {
    /** The fitted model supplying coefficients ({@code beta}) and residual variance ({@code sigma2}). */
    private final LinearRegression model;
    /** Matrix form of the model's {@code (X'X)^-1}, used for the standard error of the fit. */
    private final Matrix XtXInverse;
    /**
     * Residual degrees of freedom: {@code response().length - designMatrix().length}.
     * NOTE(review): assumes the outer dimension of {@code designMatrix()} counts the fitted
     * coefficients (i.e. the design matrix is stored column-wise) — confirm against the model.
     */
    private final int degreesOfFreedom;

    private MultipleLinearRegressionPredictor(MultipleLinearRegression multipleLinearRegression) {
        this.model = multipleLinearRegression;
        this.XtXInverse = Matrix.create(multipleLinearRegression.XtXInverse());
        this.degreesOfFreedom = multipleLinearRegression.response().length
                - multipleLinearRegression.designMatrix().length;
    }

    /**
     * Creates a predictor for the given fitted model.
     *
     * @param multipleLinearRegression the fitted model to predict from.
     * @return a new predictor backed by {@code multipleLinearRegression}.
     */
    public static MultipleLinearRegressionPredictor from(MultipleLinearRegression multipleLinearRegression) {
        return new MultipleLinearRegressionPredictor(multipleLinearRegression);
    }

    /** Returns the symmetric interval {@code center ± criticalValue * standardError}. */
    private DoublePair getInterval(double center, double criticalValue, double standardError) {
        double halfWidth = criticalValue * standardError;
        return new DoublePair(center - halfWidth, center + halfWidth);
    }

    /**
     * Predicts the response for a single observation of the predictor variables.
     *
     * @param vector the predictor values, WITHOUT the intercept term; it is added here when the
     *               model has an intercept.
     * @param d      the significance level alpha; intervals have nominal coverage {@code 1 - d}.
     * @return the estimate, its standard error, and the confidence and prediction intervals.
     */
    @Override // com.github.signaflo.data.regression.LinearRegressionPredictor
    public LinearRegressionPrediction predict(Vector vector, double d) {
        // Build the intercept-augmented row once instead of once per derived quantity.
        return predictWithIntercept(predictorWithIntercept(vector), d);
    }

    /** Predicts from a row that already contains the intercept term (if the model uses one). */
    private LinearRegressionPrediction predictWithIntercept(Vector vector, double d) {
        return predictFrom(vector, criticalValue(d));
    }

    /**
     * Core prediction for a full design row given a precomputed Student-t critical value.
     * The standard error of the fit is computed once and shared by both intervals.
     */
    private LinearRegressionPrediction predictFrom(Vector row, double criticalValue) {
        double estimate = estimate(row);
        double seFit = standardErrorFit(row);
        DoublePair confidence = getInterval(estimate, criticalValue, seFit);
        // Prediction intervals add the residual variance to the variance of the fitted mean.
        DoublePair prediction = getInterval(estimate, criticalValue,
                Math.sqrt(this.model.sigma2() + (seFit * seFit)));
        return new MultipleLinearRegressionPrediction(estimate, seFit, confidence, prediction);
    }

    /** Two-sided Student-t critical value {@code t_{1 - alpha/2}} at the residual degrees of freedom. */
    private double criticalValue(double alpha) {
        return new StudentsT(this.degreesOfFreedom).quantile(1.0d - (alpha / 2.0d));
    }

    /** Dot product of a full design row with the model coefficients. */
    private double estimate(Vector vector) {
        return vector.dotProduct(Vector.from(this.model.beta()));
    }

    /**
     * Adds the intercept term {@code 1.0} to the row via {@link Vector#push(double)} when the model
     * has an intercept; otherwise returns the row unchanged.
     * NOTE(review): relies on {@code push} placing the value where the intercept column sits in
     * the design matrix — confirm.
     */
    private Vector predictorWithIntercept(Vector vector) {
        return this.model.hasIntercept() ? vector.push(1.0d) : vector;
    }

    /** Standard error of the fitted mean: {@code sqrt(sigma2 * x' (X'X)^-1 x)}. */
    private double standardErrorFit(Vector vector) {
        return Math.sqrt(this.model.sigma2() * QuadraticForm.multiply(vector, this.XtXInverse));
    }

    /**
     * Predicts the response for every row of the given matrix of predictor values.
     *
     * @param matrix rows of predictor values, WITHOUT the intercept term.
     * @param d      the significance level alpha; intervals have nominal coverage {@code 1 - d}.
     * @return one prediction per row, in row order.
     */
    @Override // com.github.signaflo.data.regression.LinearRegressionPredictor
    public List<LinearRegressionPrediction> predict(Matrix matrix, double d) {
        return predictRows(matrix, d, true);
    }

    /**
     * Predicts the response for every row of a design matrix whose rows already include the
     * intercept term (if the model uses one); rows are used as-is.
     *
     * @param matrix full design rows.
     * @param d      the significance level alpha; intervals have nominal coverage {@code 1 - d}.
     * @return one prediction per row, in row order.
     */
    public List<LinearRegressionPrediction> predictDesignMatrix(Matrix matrix, double d) {
        return predictRows(matrix, d, false);
    }

    /** Shared row loop; the critical value depends only on alpha, so it is computed once. */
    private List<LinearRegressionPrediction> predictRows(Matrix matrix, double alpha, boolean addIntercept) {
        int rows = matrix.nrow();
        List<LinearRegressionPrediction> predictions = new ArrayList<>(rows);
        if (rows == 0) {
            // Nothing to predict; don't construct the t distribution (matches original behavior).
            return predictions;
        }
        double criticalValue = criticalValue(alpha);
        for (int i = 0; i < rows; i++) {
            Vector row = matrix.getRow(i);
            predictions.add(predictFrom(addIntercept ? predictorWithIntercept(row) : row, criticalValue));
        }
        return predictions;
    }

    /** Two predictors are equal when they wrap equal models (all other state derives from the model). */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        return this.model.equals(((MultipleLinearRegressionPredictor) obj).model);
    }

    @Override
    public int hashCode() {
        return this.model.hashCode();
    }

    @Override
    public String toString() {
        return "MultipleLinearRegressionPredictor(model=" + this.model + ", XtXInverse=" + this.XtXInverse + ", degreesOfFreedom=" + this.degreesOfFreedom + ")";
    }
}
