package com.github.signaflo.timeseries.model;

import com.github.signaflo.math.stats.distributions.Distribution;
import com.github.signaflo.math.stats.distributions.Normal;
import com.github.signaflo.timeseries.TimePeriod;
import com.github.signaflo.timeseries.TimeSeries;
import com.github.signaflo.timeseries.forecast.Forecast;
import lombok.NonNull;

/* loaded from: input_file:com/github/signaflo/timeseries/model/RandomWalk.class */
public final class RandomWalk implements Model {

    /** The observed series this model was fit to; never null after construction. */
    private final TimeSeries timeSeries;

    /**
     * One-step-ahead fitted values: the first observation is fitted by itself, and every later
     * observation is fitted by the observation immediately preceding it.
     */
    private final TimeSeries fittedSeries;

    /** Observed minus fitted values; the first entry is always 0. */
    private final TimeSeries residuals;

    /**
     * Fits a random walk model to the given observations.
     *
     * @param timeSeries the observed series; must contain at least one observation
     * @throws NullPointerException if {@code timeSeries} is null
     * @throws IllegalArgumentException if {@code timeSeries} has no observations
     */
    public RandomWalk(@NonNull TimeSeries timeSeries) {
        if (timeSeries == null) {
            throw new NullPointerException("observed");
        }
        if (timeSeries.size() < 1) {
            throw new IllegalArgumentException("A random walk model requires at least one observation.");
        }
        this.timeSeries = timeSeries;
        this.fittedSeries = fitSeries();
        this.residuals = calculateResiduals();
    }

    /**
     * Simulates a random walk whose increments are drawn from the given distribution.
     *
     * <p>The first value is a single draw. Each subsequent value is the previous value plus a
     * fresh draw.
     *
     * @param distribution the distribution increments are drawn from
     * @param n the number of observations to simulate; must be positive
     * @return the simulated series, built with {@link TimeSeries#from(double...)}
     * @throws NullPointerException if {@code distribution} is null
     * @throws IllegalArgumentException if {@code n < 1}
     */
    public static TimeSeries simulate(@NonNull Distribution distribution, int n) {
        if (distribution == null) {
            throw new NullPointerException("dist");
        }
        if (n < 1) {
            throw new IllegalArgumentException("the number of observations to simulate must be a positive integer.");
        }
        double[] walk = new double[n];
        walk[0] = distribution.rand();
        for (int t = 1; t < n; t++) {
            walk[t] = walk[t - 1] + distribution.rand();
        }
        return TimeSeries.from(walk);
    }

    /**
     * Simulates a random walk with normally distributed increments.
     *
     * @param mean the first argument passed to {@link Normal}; presumably the increment mean
     * @param sigma the second argument passed to {@link Normal}; presumably the increment standard
     *     deviation (confirm against {@code Normal}'s constructor)
     * @param n the number of observations to simulate; must be positive
     * @return the simulated series
     */
    public static TimeSeries simulate(double mean, double sigma, int n) {
        return simulate(new Normal(mean, sigma), n);
    }

    /**
     * Simulates a random walk with normally distributed, zero-mean increments.
     *
     * @param sigma the second argument passed to {@link Normal} (mean fixed at 0)
     * @param n the number of observations to simulate; must be positive
     * @return the simulated series
     */
    public static TimeSeries simulate(double sigma, int n) {
        return simulate(new Normal(0.0d, sigma), n);
    }

    /**
     * Simulates a random walk with {@code Normal(0, 1)} increments.
     *
     * @param n the number of observations to simulate; must be positive
     * @return the simulated series
     */
    public static TimeSeries simulate(int n) {
        return simulate(new Normal(0.0d, 1.0d), n);
    }

    /**
     * Forecasts the series forward by delegating to {@link RandomWalkForecaster}.
     *
     * <p>All point-forecast and interval computation is done by the forecaster. No work is done
     * here beyond constructing it from the observations and this model's prediction errors.
     *
     * @param steps the number of periods to forecast, passed through unchanged
     * @param alpha passed through unchanged to {@code RandomWalkForecaster#forecast}; presumably the
     *     significance level for prediction intervals (confirm)
     * @return the forecast produced by the forecaster
     */
    @Override // com.github.signaflo.timeseries.model.Model
    public Forecast forecast(int steps, double alpha) {
        return new RandomWalkForecaster(this.timeSeries, predictionErrors()).forecast(steps, alpha);
    }

    @Override // com.github.signaflo.timeseries.model.Model
    public TimeSeries observations() {
        return this.timeSeries;
    }

    @Override // com.github.signaflo.timeseries.model.Model
    public TimeSeries fittedSeries() {
        return this.fittedSeries;
    }

    @Override // com.github.signaflo.timeseries.model.Model
    public TimeSeries predictionErrors() {
        return this.residuals;
    }

    /** Builds the lag-one fitted series; the first value is fitted by the first observation itself. */
    private TimeSeries fitSeries() {
        int size = this.timeSeries.size();
        double[] fitted = new double[size];
        fitted[0] = this.timeSeries.at(0);
        for (int t = 1; t < size; t++) {
            fitted[t] = this.timeSeries.at(t - 1);
        }
        return TimeSeries.from(this.timeSeries.timePeriod(), this.timeSeries.observationTimes().get(0), fitted);
    }

    /** Computes observed minus fitted; index 0 stays 0 because the first fit equals the first observation. */
    private TimeSeries calculateResiduals() {
        int size = this.timeSeries.size();
        double[] errors = new double[size];
        for (int t = 1; t < size; t++) {
            errors[t] = this.timeSeries.at(t) - this.fittedSeries.at(t);
        }
        return TimeSeries.from(this.timeSeries.timePeriod(), this.timeSeries.observationTimes().get(0), errors);
    }

    @Override
    public String toString() {
        return "Random walk time series model";
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        RandomWalk other = (RandomWalk) obj;
        // All three fields are non-null: the constructor rejects a null series and derives the rest.
        return this.timeSeries.equals(other.timeSeries)
                && this.fittedSeries.equals(other.fittedSeries)
                && this.residuals.equals(other.residuals);
    }

    @Override
    public int hashCode() {
        int result = this.timeSeries.hashCode();
        result = 31 * result + this.fittedSeries.hashCode();
        result = 31 * result + this.residuals.hashCode();
        return result;
    }
}
