/*
 * Decompiled with CFR 0.152.
 */
package com.github.signaflo.data.regression;

import com.github.signaflo.data.DoublePair;
import com.github.signaflo.data.regression.LinearRegression;
import com.github.signaflo.data.regression.LinearRegressionPrediction;
import com.github.signaflo.data.regression.LinearRegressionPredictor;
import com.github.signaflo.data.regression.MultipleLinearRegression;
import com.github.signaflo.data.regression.MultipleLinearRegressionPrediction;
import com.github.signaflo.math.linear.doubles.Matrix;
import com.github.signaflo.math.linear.doubles.QuadraticForm;
import com.github.signaflo.math.linear.doubles.Vector;
import com.github.signaflo.math.stats.distributions.StudentsT;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

/**
 * Produces point estimates, fit standard errors, confidence intervals and prediction
 * intervals from a fitted {@link MultipleLinearRegression} model.
 *
 * <p>Intervals are of the form {@code estimate ± t * se}. Here {@code t} is the
 * {@code 1 - alpha/2} quantile of a Student's t distribution with
 * {@link #degreesOfFreedom} degrees of freedom.
 *
 * <p>Instances are obtained through {@link #from(MultipleLinearRegression)}.
 */
public final class MultipleLinearRegressionPredictor
implements LinearRegressionPredictor {
    private final LinearRegression model;
    /** Matrix built from the model's {@code XtXInverse()}; used for the fit standard error. */
    private final Matrix XtXInverse;
    /** Degrees of freedom for the Student's t critical value. */
    private final int degreesOfFreedom;

    private MultipleLinearRegressionPredictor(MultipleLinearRegression model) {
        Objects.requireNonNull(model, "model");
        this.model = model;
        this.XtXInverse = Matrix.create((double[][])model.XtXInverse());
        // NOTE(review): computed as response length minus designMatrix().length. This is the
        // usual n - p only if designMatrix() stores one array per coefficient column. Confirm
        // the layout: with one array per observation this would evaluate to 0.
        this.degreesOfFreedom = model.response().length - model.designMatrix().length;
    }

    /**
     * Creates a predictor for the given fitted model.
     *
     * @param model the fitted regression model; must not be {@code null}
     * @return a predictor backed by {@code model}
     * @throws NullPointerException if {@code model} is {@code null}
     */
    public static MultipleLinearRegressionPredictor from(MultipleLinearRegression model) {
        return new MultipleLinearRegressionPredictor(model);
    }

    /**
     * Predicts the response for a single observation of the predictor variables.
     *
     * <p>If the model has an intercept, the constant term is added to {@code observation}
     * before prediction.
     *
     * @param observation the predictor values, without the intercept term
     * @param alpha significance level; intervals use the {@code 1 - alpha/2} t quantile
     *              (presumably expected in (0, 1); not validated here)
     * @return the estimate, fit standard error, confidence interval and prediction interval
     */
    @Override
    public LinearRegressionPrediction predict(Vector observation, double alpha) {
        Vector predictor = this.predictorWithIntercept(observation);
        return this.predictAugmented(predictor, this.criticalValue(alpha));
    }

    /**
     * Predicts the response for every row of {@code observations}.
     *
     * <p>Each row is augmented with the intercept term when the model has one.
     *
     * @param observations one observation per row, without the intercept term
     * @param alpha significance level used for all intervals
     * @return one prediction per row, in row order (a mutable list)
     */
    @Override
    public List<LinearRegressionPrediction> predict(Matrix observations, double alpha) {
        return this.predictRows(observations, alpha, true);
    }

    /**
     * Predicts the response for every row of a design matrix.
     *
     * <p>Each row is used exactly as given: no intercept term is added.
     *
     * @param designMatrix one row per observation, already laid out like the coefficient vector
     * @param alpha significance level used for all intervals
     * @return one prediction per row, in row order (a mutable list)
     */
    public List<LinearRegressionPrediction> predictDesignMatrix(Matrix designMatrix, double alpha) {
        return this.predictRows(designMatrix, alpha, false);
    }

    /** Predicts each row of {@code rows}, adding the intercept term first when requested. */
    private List<LinearRegressionPrediction> predictRows(Matrix rows, double alpha, boolean addIntercept) {
        int rowCount = rows.nrow();
        List<LinearRegressionPrediction> predictions = new ArrayList<>(rowCount);
        if (rowCount == 0) {
            // Matches the previous behaviour: no t distribution is built for an empty input.
            return predictions;
        }
        // The critical value depends only on alpha and the degrees of freedom, so compute it
        // once instead of once per row.
        double tValue = this.criticalValue(alpha);
        for (int i = 0; i < rowCount; ++i) {
            Vector row = rows.getRow(i);
            Vector predictor = addIntercept ? this.predictorWithIntercept(row) : row;
            predictions.add(this.predictAugmented(predictor, tValue));
        }
        return predictions;
    }

    /**
     * Builds a full prediction for a predictor vector that is already laid out like the
     * coefficient vector.
     */
    private LinearRegressionPrediction predictAugmented(Vector predictor, double tValue) {
        double estimate = this.estimate(predictor);
        double seFit = this.standardErrorFit(predictor);
        DoublePair confidenceInterval = this.getInterval(estimate, tValue, seFit);
        // The prediction interval adds the residual variance to the variance of the fit.
        double sePrediction = Math.sqrt(this.model.sigma2() + seFit * seFit);
        DoublePair predictionInterval = this.getInterval(estimate, tValue, sePrediction);
        return new MultipleLinearRegressionPrediction(estimate, seFit, confidenceInterval, predictionInterval);
    }

    /** Returns the {@code 1 - alpha/2} quantile of Student's t with the model's degrees of freedom. */
    private double criticalValue(double alpha) {
        StudentsT t = new StudentsT(this.degreesOfFreedom);
        return t.quantile(1.0 - alpha / 2.0);
    }

    /** Returns {@code (sampleEstimate - tValue*se, sampleEstimate + tValue*se)}. */
    private DoublePair getInterval(double sampleEstimate, double tValue, double standardError) {
        double halfWidth = tValue * standardError;
        return new DoublePair(sampleEstimate - halfWidth, sampleEstimate + halfWidth);
    }

    /** Returns the dot product of {@code data} with the model's coefficient vector. */
    private double estimate(Vector data) {
        return data.dotProduct(Vector.from((double[])this.model.beta()));
    }

    /**
     * Adds the constant {@code 1.0} via {@link Vector#push(double)} when the model has an
     * intercept. Otherwise returns {@code newData} unchanged.
     */
    private Vector predictorWithIntercept(Vector newData) {
        // NOTE(review): assumes push() places 1.0 at the same position as the intercept
        // coefficient in beta(). Confirm the ordering.
        return this.model.hasIntercept() ? newData.push(1.0) : newData;
    }

    /** Returns {@code sqrt(sigma2 * q)}, where q is the quadratic form of the predictor with XtXInverse. */
    private double standardErrorFit(Vector predictor) {
        double product = QuadraticForm.multiply(predictor, this.XtXInverse);
        return Math.sqrt(this.model.sigma2() * product);
    }

    /** Two predictors are equal when their underlying models are equal; the other fields are derived from the model. */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || this.getClass() != o.getClass()) {
            return false;
        }
        MultipleLinearRegressionPredictor that = (MultipleLinearRegressionPredictor)o;
        return this.model.equals(that.model);
    }

    @Override
    public int hashCode() {
        return this.model.hashCode();
    }

    @Override
    public String toString() {
        return "MultipleLinearRegressionPredictor(model=" + this.model + ", XtXInverse=" + this.XtXInverse + ", degreesOfFreedom=" + this.degreesOfFreedom + ")";
    }
}

