/*
 * Decompiled with CFR 0.152.
 */
package org.kie.kogito.explainability.utils;

import java.util.Arrays;
import java.util.stream.IntStream;
import org.apache.commons.math3.distribution.TDistribution;
import org.apache.commons.math3.linear.MatrixUtils;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
import org.apache.commons.math3.util.Pair;
import org.kie.kogito.explainability.utils.MatrixUtilsExtensions;
import org.kie.kogito.explainability.utils.WeightedLinearRegressionResults;

/**
 * Weighted least-squares linear regression.
 *
 * <p>Static utility class: fits a linear model in which every sample carries its own weight, and
 * reports the weighted mean squared error, coefficient standard errors and two-sided p-values.
 */
public class WeightedLinearRegression {
    private WeightedLinearRegression() {
        throw new IllegalStateException("Utility class");
    }

    /**
     * Fits a weighted linear regression model.
     *
     * <p>The coefficients solve the weighted normal equations {@code (X^T W X) b = X^T W y}. Both
     * products are obtained from {@link MatrixUtilsExtensions#jointATBATandATBC}, and the first is
     * inverted with {@link MatrixUtilsExtensions#safeInvert}.
     *
     * @param features one row per sample, one column per feature
     * @param observations the observed (target) value of each sample
     * @param sampleWeights the weight of each sample; must have one entry per observation
     * @param intercept if {@code true}, a column of ones is appended as the last feature column, so
     *     the last fitted coefficient is the intercept
     * @return a {@link WeightedLinearRegressionResults} built from the coefficients, the intercept
     *     flag, the residual degrees of freedom ({@code samples - fitted coefficients}), the
     *     weighted MSE ({@code RSS / sum(weights)}), the coefficient standard errors and their
     *     p-values
     * @throws IllegalArgumentException if the number of feature rows or sample weights does not
     *     match the number of observations
     * @throws ArithmeticException if the sample weights sum to zero, or if the weighted total sum
     *     of squares of the observations is zero. Exceptions raised by
     *     {@code MatrixUtilsExtensions.safeInvert} (not visible here) may also propagate.
     */
    public static WeightedLinearRegressionResults fit(RealMatrix features, RealVector observations, RealVector sampleWeights, boolean intercept) throws IllegalArgumentException, ArithmeticException {
        checkSampleCounts(features, observations, sampleWeights);
        double weightSum = checkedWeightSum(sampleWeights);

        RealMatrix design = intercept ? adjustFeatureMatrix(features) : features;
        int nsamples = observations.getDimension();
        int nfeatures = design.getColumnDimension();
        int dof = nsamples - nfeatures;

        Pair<RealMatrix, RealVector> jointXTWXandXTWY = MatrixUtilsExtensions.jointATBATandATBC(design, sampleWeights, observations);
        RealMatrix xtWXInv = MatrixUtilsExtensions.safeInvert(jointXTWXandXTWY.getFirst());
        RealVector coefficients = xtWXInv.operate(jointXTWXandXTWY.getSecond());

        ModelSquareSums mss = getRSSandTSS(design, observations, sampleWeights, weightSum, coefficients, true);
        double mse = mss.residualSquareSum / weightSum;
        RealVector stdErrors = getStandardErrors(dof, nfeatures, xtWXInv, mss);
        RealVector pvalues = getPValues(dof, nfeatures, stdErrors, coefficients);
        return new WeightedLinearRegressionResults(coefficients, intercept, dof, mse, stdErrors, pvalues);
    }

    /**
     * Computes the weighted mean squared error ({@code RSS / sum(weights)}) of a given linear model.
     *
     * <p>{@code features} is multiplied by {@code coefficients} as-is; no intercept column is
     * appended here. If the coefficients include an intercept, the caller must supply the matching
     * column of ones.
     *
     * @param features one row per sample, one column per coefficient
     * @param observations the observed (target) value of each sample
     * @param sampleWeights the weight of each sample; must have one entry per observation
     * @param coefficients the model coefficients to evaluate
     * @return the weighted mean squared error
     * @throws IllegalArgumentException if the number of feature rows or sample weights does not
     *     match the number of observations
     * @throws ArithmeticException if the sample weights sum to zero
     */
    public static double getMSE(RealMatrix features, RealVector observations, RealVector sampleWeights, RealVector coefficients) {
        checkSampleCounts(features, observations, sampleWeights);
        double weightSum = checkedWeightSum(sampleWeights);
        ModelSquareSums mss = getRSSandTSS(features, observations, sampleWeights, weightSum, coefficients, false);
        return mss.residualSquareSum / weightSum;
    }

    /** Throws if the feature rows or the sample weights do not line up one-to-one with the observations. */
    private static void checkSampleCounts(RealMatrix features, RealVector observations, RealVector sampleWeights) {
        int nsamples = observations.getDimension();
        if (features.getRowDimension() != nsamples) {
            throw new IllegalArgumentException(String.format(
                    "Num sample mismatch: Number of rows in the features (%d) must match number of observations (%d)",
                    features.getRowDimension(), nsamples));
        }
        if (sampleWeights.getDimension() != nsamples) {
            throw new IllegalArgumentException(String.format(
                    "Num sample mismatch: Number of sample weights (%d) must match number of observations (%d)",
                    sampleWeights.getDimension(), nsamples));
        }
    }

    /** Sums the sample weights, rejecting a zero total because it is later used as a divisor. */
    private static double checkedWeightSum(RealVector sampleWeights) {
        double weightSum = Arrays.stream(sampleWeights.toArray()).sum();
        if (weightSum == 0.0) {
            throw new ArithmeticException("Weights cannot sum to zero!");
        }
        return weightSum;
    }

    /** Returns a new matrix holding {@code features} followed by an extra last column of ones (the intercept term). */
    private static RealMatrix adjustFeatureMatrix(RealMatrix features) {
        int nsamples = features.getRowDimension();
        int nfeatures = features.getColumnDimension() + 1;
        RealMatrix adjustedFeatures = MatrixUtils.createRealMatrix(nsamples, nfeatures);
        adjustedFeatures.setSubMatrix(features.getData(), 0, 0);
        adjustedFeatures.setColumnVector(nfeatures - 1, MatrixUtils.createRealVector(new double[nsamples]).mapAdd(1.0));
        return adjustedFeatures;
    }

    /**
     * Computes the weighted residual and total sums of squares of a model.
     *
     * <p>RSS = sum_i w_i (y_i - x_i . b)^2 and TSS = sum_i w_i (y_i - yBar)^2, where yBar is the
     * weighted mean of the observations.
     *
     * @param needObservationVariance if {@code true}, a zero TSS is rejected
     * @throws ArithmeticException if {@code needObservationVariance} is set and the TSS is zero
     */
    private static ModelSquareSums getRSSandTSS(RealMatrix features, RealVector observations, RealVector sampleWeights, double weightSum, RealVector coefficients, boolean needObservationVariance) {
        double yBar = sampleWeights.dotProduct(observations) / weightSum;
        RealVector residual = observations.subtract(features.operate(coefficients));
        RealVector variance = observations.mapSubtract(yBar);
        double residualSquareSum = sampleWeights.dotProduct(residual.ebeMultiply(residual));
        double totalSquareSum = sampleWeights.dotProduct(variance.ebeMultiply(variance));
        if (needObservationVariance && totalSquareSum == 0.0) {
            throw new ArithmeticException("Total variance of observations is zero. Use more samples to correct this error");
        }
        return new ModelSquareSums(residualSquareSum, totalSquareSum);
    }

    /**
     * Computes the coefficient standard errors sqrt(diag((X^T W X)^-1) * RSS / dof).
     *
     * <p>The values are not meaningful when {@code dof <= 0}. Dividing by a non-positive
     * {@code dof} typically yields NaN or infinity.
     */
    private static RealVector getStandardErrors(int dof, int nfeatures, RealMatrix invertedLSMatrix, ModelSquareSums mss) {
        double residualMeanSquare = mss.residualSquareSum / dof;
        return MatrixUtils.createRealVector(IntStream.range(0, nfeatures)
                .mapToDouble(i -> Math.sqrt(invertedLSMatrix.getEntry(i, i) * residualMeanSquare))
                .toArray());
    }

    /**
     * Computes the two-sided p-value of each coefficient under the null hypothesis that it is zero.
     *
     * <p>The test statistic is Student's t with {@code dof} degrees of freedom. When
     * {@code dof <= 0}, every p-value is {@code +Infinity}.
     */
    private static RealVector getPValues(int dof, int nfeatures, RealVector coefficientError, RealVector coefficients) {
        if (dof <= 0) {
            return MatrixUtils.createRealVector(new double[nfeatures]).mapAdd(Double.POSITIVE_INFINITY);
        }
        RealVector tvalues = coefficients.ebeDivide(coefficientError);
        TDistribution tdist = new TDistribution(dof);
        // Two-sided test: use |t|, otherwise a negative t-statistic would give a "p-value" above 1.
        return tvalues.map(t -> 2.0 * (1.0 - tdist.cumulativeProbability(Math.abs(t))));
    }

    /** Weighted residual (RSS) and total (TSS) sums of squares of a fitted model. */
    private static final class ModelSquareSums {
        public final double residualSquareSum;
        public final double totalSquareSum;

        ModelSquareSums(double residualSquareSum, double totalSquareSum) {
            this.residualSquareSum = residualSquareSum;
            this.totalSquareSum = totalSquareSum;
        }
    }
}

