package com.github.signaflo.timeseries.model.arima;

import org.ejml.data.DenseMatrix64F;
import org.ejml.data.RowD1Matrix64F;
import org.ejml.ops.CommonOps;

/* loaded from: input_file:com/github/signaflo/timeseries/model/arima/ArimaKalmanFilter.class */
/**
 * Kalman filter for an ARIMA model in state-space form, used to evaluate its Gaussian likelihood.
 *
 * <p>The state vector has {@code rd = r + d} elements. The leading {@code r} states are initialised
 * with the stationary ARMA covariance (see {@link #getInitialStateCovariance(double[], double[])}).
 * The trailing {@code d} states are initialised as diffuse, with a very large prior variance.
 *
 * <p>The filter runs exactly once, inside the constructor. Its results are exposed through
 * {@link #output()} and the delegating accessors.
 */
class ArimaKalmanFilter {

    /**
     * Observations whose prediction error variance is at or above this value are treated as
     * diffuse. They are left out of {@code n}, {@code ssq} and {@code sumlog}.
     */
    private static final double DIFFUSE_VARIANCE_THRESHOLD = 1.0e4;

    /** Prior variance given to each of the trailing {@code d} (diffuse) states. */
    private static final double DIFFUSE_PRIOR_VARIANCE = 1.0e6;

    private final double[] y;
    private final int r;
    private final int d;
    private final int rd;
    /** Transition matrix (rd x rd). */
    private final DenseMatrix64F transitionMatrix;
    /** State disturbance covariance m m', where m is the state space's moving-average vector. */
    private final RowD1Matrix64F stateDisturbance;
    private final RowD1Matrix64F predictedState;
    private final RowD1Matrix64F filteredState;
    private final DenseMatrix64F predictedStateCovariance;
    private final RowD1Matrix64F filteredStateCovariance;
    private final double[] predictionErrorVariance;
    /** Holds the raw prediction errors during filtering; standardized in place afterwards. */
    private final double[] predictionError;
    /** Observation (state effects) row vector, 1 x rd. */
    private final DenseMatrix64F Z;
    /** Transpose of {@link #Z}, rd x 1. */
    private final DenseMatrix64F Zt;
    /** Scratch: Z P for the current step, 1 x rd. */
    private final DenseMatrix64F ZP;
    /** Scratch: gain P Z' / f for the current step, rd x 1. */
    private final DenseMatrix64F PZtf;
    /** Scratch: (P Z' / f) Z, rd x rd. */
    private final DenseMatrix64F PZtfZ;
    private final KalmanOutput kalmanOutput;

    /**
     * Summary of one filter pass: the ingredients of the concentrated log-likelihood, plus the
     * standardized residuals.
     */
    public static class KalmanOutput {
        private final int n;
        private final double ssq;
        private final double sumlog;
        private final double sigma2;
        private final double logLikelihood;
        private final double[] residuals;

        /**
         * @param n         number of observations counted in the likelihood
         * @param ssq       sum over those observations of squared prediction error / its variance
         * @param sumlog    sum over those observations of the log prediction error variance
         * @param residuals standardized residuals; copied defensively
         */
        KalmanOutput(int n, double ssq, double sumlog, double[] residuals) {
            this.n = n;
            this.ssq = ssq;
            this.sumlog = sumlog;
            // Concentrated estimate of the innovation variance. NaN or infinite when n == 0.
            this.sigma2 = ssq / n;
            // -n/2 * (log(2*pi*sigma2) + 1) - sumlog/2
            this.logLikelihood =
                    (-n / 2.0) * (Math.log(2.0 * Math.PI * this.sigma2) + 1.0) - 0.5 * sumlog;
            this.residuals = residuals.clone();
        }

        /** Returns the weighted sum of squared prediction errors. */
        double ssq() {
            return this.ssq;
        }

        /** Returns the number of observations that contributed to the likelihood. */
        public int n() {
            return this.n;
        }

        /** Returns the sum of log prediction error variances of the contributing observations. */
        public double sumLog() {
            return this.sumlog;
        }

        /** Returns {@code ssq / n}. */
        public double sigma2() {
            return this.sigma2;
        }

        /** Returns the concentrated Gaussian log-likelihood. */
        public double logLikelihood() {
            return this.logLikelihood;
        }

        /** Returns a copy of the standardized residuals (prediction error / sqrt(variance)). */
        public double[] residuals() {
            return this.residuals.clone();
        }
    }

    /**
     * Builds the filter for the given model and immediately runs it over the model's observations.
     *
     * @param arimaStateSpace the state-space form of the ARIMA model, including its observations
     * @throws IllegalArgumentException if the model has no observations, or if its AR/MA order is
     *                                  too large for the initial covariance work arrays
     */
    public ArimaKalmanFilter(ArimaStateSpace arimaStateSpace) {
        this.y = arimaStateSpace.observations();
        if (this.y.length == 0) {
            throw new IllegalArgumentException("Cannot run the Kalman filter on an empty series.");
        }
        this.r = arimaStateSpace.r();
        this.d = arimaStateSpace.d();
        this.rd = this.r + this.d;
        this.transitionMatrix = new DenseMatrix64F(arimaStateSpace.transitionMatrix());
        DenseMatrix64F movingAverage =
                new DenseMatrix64F(this.rd, 1, true, arimaStateSpace.movingAverageVector());
        this.stateDisturbance = new DenseMatrix64F(this.rd, this.rd);
        CommonOps.multOuter(movingAverage, this.stateDisturbance);
        this.predictedState = new DenseMatrix64F(this.rd, 1, true, new double[this.rd]);
        this.filteredState = new DenseMatrix64F(this.rd, 1, true, new double[this.rd]);
        this.predictedStateCovariance = initializePredictedCovariance(arimaStateSpace);
        this.filteredStateCovariance = new DenseMatrix64F(this.rd, this.rd);
        this.predictionErrorVariance = new double[this.y.length];
        this.predictionError = new double[this.y.length];
        this.Z = new DenseMatrix64F(1, this.rd, true, arimaStateSpace.stateEffectsVector());
        this.Zt = new DenseMatrix64F(this.rd, 1, true, new double[this.rd]);
        this.ZP = new DenseMatrix64F(1, this.rd, true, new double[this.rd]);
        this.PZtf = new DenseMatrix64F(this.rd, 1, true, new double[this.rd]);
        this.PZtfZ = new DenseMatrix64F(this.rd, this.rd, true, new double[this.rd * this.rd]);
        this.kalmanOutput = filter();
    }

    /**
     * Runs the filter over every observation.
     *
     * <p>At {@code t = 0} the prediction step is skipped: the predicted state starts at zero and the
     * predicted covariance is the initial covariance. On return, {@link #predictionError} holds the
     * standardized residuals.
     */
    private KalmanOutput filter() {
        // Z' and Phi' are loop-invariant.
        CommonOps.transpose(this.Z, this.Zt);
        DenseMatrix64F transitionTranspose = this.transitionMatrix.copy();
        CommonOps.transpose(transitionTranspose);

        // Scratch buffers, reused on every step instead of being reallocated.
        DenseMatrix64F scaledGain = new DenseMatrix64F(this.rd, 1);
        DenseMatrix64F covarianceCorrection = new DenseMatrix64F(this.rd, this.rd);
        DenseMatrix64F phiP = new DenseMatrix64F(this.rd, this.rd);
        DenseMatrix64F phiPPhiT = new DenseMatrix64F(this.rd, this.rd);

        int n = 0;
        double ssq = 0.0;
        double sumLog = 0.0;
        for (int t = 0; t < this.y.length; t++) {
            if (t == 0) {
                // Predicted state is the zero vector, so the prediction error is y[0] itself.
                this.predictionError[t] = this.y[t];
            } else {
                // Prediction step: a = Phi a_filt, P = Phi P_filt Phi' + Q.
                CommonOps.mult(this.transitionMatrix, this.filteredState, this.predictedState);
                CommonOps.mult(this.transitionMatrix, this.filteredStateCovariance, phiP);
                CommonOps.mult(phiP, transitionTranspose, phiPPhiT);
                CommonOps.add(phiPPhiT, this.stateDisturbance, this.predictedStateCovariance);
                this.predictionError[t] = this.y[t] - CommonOps.dot(this.Z, this.predictedState);
            }

            // Prediction error variance f = Z P Z'.
            CommonOps.mult(this.Z, this.predictedStateCovariance, this.ZP);
            double variance = CommonOps.dot(this.ZP, this.Zt);
            this.predictionErrorVariance[t] = variance;
            double error = this.predictionError[t];
            if (variance < DIFFUSE_VARIANCE_THRESHOLD) {
                n++;
                ssq += (error * error) / variance;
                sumLog += Math.log(variance);
            }

            // Gain P Z' / f. Filtered state a + gain * error.
            CommonOps.transpose(this.ZP, this.PZtf);
            CommonOps.divide(this.PZtf, variance);
            CommonOps.scale(error, this.PZtf, scaledGain);
            CommonOps.add(this.predictedState, scaledGain, this.filteredState);

            // Filtered covariance P - (P Z' / f) Z P.
            CommonOps.mult(this.PZtf, this.Z, this.PZtfZ);
            CommonOps.mult(this.PZtfZ, this.predictedStateCovariance, covarianceCorrection);
            CommonOps.subtract(this.predictedStateCovariance, covarianceCorrection,
                    this.filteredStateCovariance);

            // Store the standardized residual in place of the raw error.
            this.predictionError[t] = error / Math.sqrt(variance);
        }
        return new KalmanOutput(n, ssq, sumLog, this.predictionError);
    }

    /**
     * Builds the rd x rd initial covariance.
     *
     * <p>The leading r x r block is the stationary ARMA covariance. The trailing d x d block is
     * diagonal with {@link #DIFFUSE_PRIOR_VARIANCE}. Off-diagonal blocks are zero.
     */
    private DenseMatrix64F initializePredictedCovariance(ArimaStateSpace arimaStateSpace) {
        DenseMatrix64F covariance = new DenseMatrix64F(this.rd, this.rd);
        double[] stationary = unpack(
                getInitialStateCovariance(arimaStateSpace.arParams(), arimaStateSpace.maParams()));
        DenseMatrix64F stationaryBlock = new DenseMatrix64F(this.r, this.r, true, stationary);
        double[] diffuse = new double[this.d * this.d];
        for (int i = 0; i < this.d; i++) {
            diffuse[i * this.d + i] = DIFFUSE_PRIOR_VARIANCE;
        }
        DenseMatrix64F diffuseBlock = new DenseMatrix64F(this.d, this.d, true, diffuse);
        CommonOps.insert(stationaryBlock, covariance, 0, 0);
        CommonOps.insert(diffuseBlock, covariance, this.r, this.r);
        return covariance;
    }

    /**
     * Computes the stationary covariance of the r-dimensional ARMA state block, with
     * {@code r = max(p, q + 1)}.
     *
     * <p>The result is returned packed: the upper triangle, row by row, of length r(r+1)/2. The
     * procedure follows Gardner, Harvey and Phillips (1980), Algorithm AS 154, as also used by R's
     * {@code getQ0}.
     * <ul>
     *   <li>Pure MA models are solved directly by back-substitution.</li>
     *   <li>Otherwise, the linear system for vec(P) is built row by row. Each row is absorbed with
     *       {@link #inclu2}, then {@link #regres} back-substitutes for the solution.</li>
     * </ul>
     *
     * @param phi   AR coefficients (length p)
     * @param theta MA coefficients (length q)
     * @throws IllegalArgumentException if the required work arrays exceed the int index range
     */
    private static double[] getInitialStateCovariance(double[] phi, double[] theta) {
        final int p = phi.length;
        final int q = theta.length;
        if (p == 0 && q == 0) {
            return new double[]{1.0};
        }
        final int r = Math.max(p, q + 1);
        final int np = triangularNumber(r);

        // V: packed upper triangle of psi psi', with psi = (1, theta_1, ..., theta_{r-1})
        // and theta_j = 0 for j > q.
        final double[] v = new double[np];
        v[0] = 1.0;
        for (int i = 1; i < r; i++) {
            v[i] = (i <= q) ? theta[i - 1] : 0.0;
        }
        int ind = r;
        for (int j = 1; j < r; j++) {
            final double vj = v[j];
            for (int i = j; i < r; i++) {
                v[ind++] = v[i] * vj;
            }
        }

        final double[] packed = new double[np];
        if (p == 0) {
            // Pure moving average: P follows from V by backward accumulation.
            int indn = np;
            ind = np;
            for (int i = 0; i < r; i++) {
                for (int j = 0; j <= i; j++) {
                    --ind;
                    packed[ind] = v[ind];
                    if (j != 0) {
                        packed[ind] += packed[--indn];
                    }
                }
            }
            return packed;
        }

        // Only the AR path needs rbar. Its size (np choose 2) overflows int arithmetic long before
        // it overflows the array index range, so compute it in long.
        final int nrbar = triangularNumber(np - 1L);
        final double[] rbar = new double[nrbar];
        final double[] thetab = new double[np];
        final double[] xnext = new double[np];
        final double[] xrow = new double[np];

        // Generate the system s * vec(P) = vec(V) one row at a time in xnext. The unknowns are
        // permuted (last r entries first) to give the rows more leading zeros.
        final int npr = np - r;
        final int npr1 = npr + 1;
        int ind1 = -1;
        int indj = npr;
        int ind2 = npr - 1;
        ind = 0;
        for (int j = 0; j < r; j++) {
            final double phij = (j < p) ? phi[j] : 0.0;
            xnext[indj++] = 0.0;
            int indi = npr1 + j;
            for (int i = j; i < r; i++) {
                final double ynext = v[ind++];
                final double phii = (i < p) ? phi[i] : 0.0;
                if (j != r - 1) {
                    xnext[indj] = -phii;
                    if (i != r - 1) {
                        xnext[indi] -= phij;
                        xnext[++ind1] = -1.0;
                    }
                }
                xnext[npr] = -phii * phij;
                if (++ind2 >= np) {
                    ind2 = 0;
                }
                xnext[ind2] += 1.0;
                inclu2(np, xnext, xrow, ynext, packed, rbar, thetab);
                xnext[ind2] = 0.0;
                if (i != r - 1) {
                    xnext[indi++] = 0.0;
                    xnext[ind1] = 0.0;
                }
            }
        }

        regres(np, nrbar, rbar, thetab, packed);

        // Undo the permutation: the last r entries move to the front, the rest shift up by r.
        System.arraycopy(packed, npr, xnext, 0, r);
        System.arraycopy(packed, 0, packed, r, npr);
        System.arraycopy(xnext, 0, packed, 0, r);
        return packed;
    }

    /**
     * Returns k(k+1)/2, computed in long so the intermediate product cannot overflow.
     *
     * @throws IllegalArgumentException if the result does not fit in an int array length
     */
    private static int triangularNumber(long k) {
        final long size = k * (k + 1) / 2;
        if (size > Integer.MAX_VALUE) {
            throw new IllegalArgumentException("AR/MA order too large: initial covariance work array "
                    + "would need " + size + " elements, exceeding the maximum array length.");
        }
        return (int) size;
    }

    /**
     * Absorbs one row ({@code xnext}, {@code ynext}) into the factorisation ({@code d},
     * {@code rbar}, {@code thetab}) using square-root-free Givens updates with unit initial weight.
     *
     * <p>{@code xrow} is scratch space and is overwritten. Stops early once a pivot with zero prior
     * diagonal ({@code d[i] == 0} before the update) has been used.
     */
    private static void inclu2(int np, double[] xnext, double[] xrow, double ynext,
                               double[] d, double[] rbar, double[] thetab) {
        System.arraycopy(xnext, 0, xrow, 0, np);
        double y = ynext;
        double weight = 1.0;
        int ithisr = 0;
        for (int i = 0; i < np; i++) {
            if (xrow[i] == 0.0) {
                // Nothing to rotate in this row; skip its slice of rbar.
                ithisr += np - i - 1;
                continue;
            }
            final double xi = xrow[i];
            final double di = d[i];
            final double dpi = di + weight * xi * xi;
            d[i] = dpi;
            final double cbar = di / dpi;
            final double sbar = weight * xi / dpi;
            weight = cbar * weight;
            for (int k = i + 1; k < np; k++) {
                final double xk = xrow[k];
                final double rbthis = rbar[ithisr];
                xrow[k] = xk - xi * rbthis;
                rbar[ithisr++] = cbar * rbthis + sbar * xk;
            }
            final double yk = y;
            y = yk - xi * thetab[i];
            thetab[i] = cbar * thetab[i] + sbar * yk;
            if (di == 0.0) {
                return;
            }
        }
    }

    /**
     * Back-substitutes through the unit upper-triangular factor {@code rbar} (packed, length
     * {@code nrbar}) to solve for {@code beta} from {@code thetab}. Works from the last unknown
     * to the first.
     */
    private static void regres(int np, int nrbar, double[] rbar, double[] thetab, double[] beta) {
        int ithisr = nrbar - 1;
        int im = np - 1;
        for (int i = 0; i < np; i++) {
            double bi = thetab[im];
            for (int jm = np - 1, j = 0; j < i; j++) {
                bi -= rbar[ithisr--] * beta[jm--];
            }
            beta[im--] = bi;
        }
    }

    /**
     * Expands a packed upper triangle (row by row, length n(n+1)/2) into a full symmetric n x n
     * matrix in row-major order.
     */
    private static double[] unpack(double[] packed) {
        final int n = (-1 + (int) Math.sqrt(1 + 8 * packed.length)) / 2;
        final double[] full = new double[n * n];
        int k = 0;
        for (int row = 0; row < n; row++) {
            for (int col = row; col < n; col++) {
                full[row * n + col] = packed[k++];
            }
        }
        // Mirror the upper triangle into the lower triangle.
        for (int row = 0; row < n - 1; row++) {
            for (int col = row + 1; col < n; col++) {
                full[row + n * col] = full[col + row * n];
            }
        }
        return full;
    }

    /** Returns the summary produced by the filter pass run in the constructor. */
    public KalmanOutput output() {
        return this.kalmanOutput;
    }

    /** Returns a copy of the standardized residuals (prediction error / sqrt(variance)). */
    double[] predictionError() {
        return this.predictionError.clone();
    }

    /** Returns the weighted sum of squared prediction errors. */
    double ssq() {
        return this.kalmanOutput.ssq();
    }

    /** Returns the number of observations that contributed to the likelihood. */
    int n() {
        return this.kalmanOutput.n();
    }

    /** Returns the sum of log prediction error variances of the contributing observations. */
    double sumLog() {
        return this.kalmanOutput.sumLog();
    }

    /** Returns the concentrated Gaussian log-likelihood. */
    double logLikelihood() {
        return this.kalmanOutput.logLikelihood();
    }
}
