package com.amazon.randomcutforest.parkservices.calibration;

import com.amazon.randomcutforest.CommonUtils;
import com.amazon.randomcutforest.PredictiveRandomCutForest;
import com.amazon.randomcutforest.config.TransformMethod;
import com.amazon.randomcutforest.parkservices.ForecastDescriptor;
import com.amazon.randomcutforest.parkservices.RCFCaster;
import com.amazon.randomcutforest.parkservices.config.Calibration;
import com.amazon.randomcutforest.returntypes.DiVector;
import com.amazon.randomcutforest.returntypes.RangeVector;
import com.amazon.randomcutforest.returntypes.SampleSummary;
import com.amazon.randomcutforest.statistics.Deviation;
import java.util.Arrays;
import java.util.Optional;
import lombok.Generated;

/* loaded from: input_file:com/amazon/randomcutforest/parkservices/calibration/ErrorHandler.class */
/**
 * Tracks how far past forecasts were from the actual values later observed, and turns those errors
 * into calibration information for subsequent forecasts: a mean error, high/low error magnitudes,
 * an observed error distribution and an interval precision (empirical coverage).
 *
 * <p>Per-error arrays have length {@code inputLength * forecastHorizon}. Position
 * {@code h * inputLength + j} refers to input variable {@code j} at horizon {@code h}.
 *
 * <p>When an estimator is present, each point it sees is laid out as
 * {@code [lastInputs (2 * inputLength), h, forecastHorizon - h, non-negative errors (inputLength),
 * negative errors (inputLength)]}.
 */
public class ErrorHandler {

    /** Scale applied to the high/low spreads when forming {@link #errorDistribution}. */
    private static final double SPREAD_SCALE = 1.3;

    /** Number of forecasts registered through {@link #updateForecasts(RangeVector)}. */
    int sequenceIndex;
    /**
     * Spreads are widened whenever the observed coverage of a position (mean of
     * {@link #intervalPrecision}) drops below {@code 1 - percentile}.
     */
    double percentile;
    /** Number of future steps described by each forecast. */
    int forecastHorizon;
    /** Deviations created by the builder constructor receive {@code 1.0 / errorHorizon}. */
    int errorHorizon;
    /** Ring buffer of the most recent {@code forecastHorizon} forecasts. */
    protected RangeVector[] pastForecasts;
    /** Observed error distribution: mean error plus scaled upper/lower spreads, per position. */
    RangeVector errorDistribution;
    /** High/low error magnitudes per position, recomputed after every update. */
    DiVector errorRMSE;
    /** Sum of the means of the high and low error statistics, per position. */
    float[] errorMean;
    /** Per-position statistic fed 1 when an actual fell inside the forecast interval, 0 otherwise. */
    Deviation[] intervalPrecision;
    /** Per-position statistic fed non-negative errors (actual - forecast), 0 otherwise. */
    Deviation[] rmseHighDeviations;
    /** Per-position statistic fed negative errors (actual - forecast), 0 otherwise. */
    Deviation[] rmseLowDeviations;
    /** Per-variable lower clamp applied by {@link #calibrate}. */
    float[] lowerLimit;
    /** Per-variable upper clamp applied by {@link #calibrate}. */
    float[] upperLimit;
    /** The two most recent actual input vectors, oldest first; length {@code 2 * inputLength}. */
    double[] lastInputs;
    /** Optional forest used to predict error spreads; may be null. */
    PredictiveRandomCutForest estimator;
    /** Most recent data deviations, used as the error scale before enough errors are observed. */
    float[] lastDataDeviations;
    /** Not used by the logic of this class; only exposed through accessors. */
    RangeVector multipliers;
    /** Not used by the logic of this class; only exposed through accessors. */
    RangeVector adders;

    /** Fluent builder for {@link ErrorHandler}. */
    public static class Builder {
        protected int dimensions;
        protected int forecastHorizon;
        protected int shingleSize = 1;
        protected boolean useRCF = true;
        protected int errorHorizon = 100;
        protected double percentile = RCFCaster.DEFAULT_ERROR_PERCENTILE;
        protected Optional<float[]> upperLimit = Optional.empty();
        protected Optional<float[]> lowerLimit = Optional.empty();

        /** Sets the total (shingled) dimension; the input length is {@code dimensions / shingleSize}. */
        public Builder dimensions(int i) {
            this.dimensions = i;
            return this;
        }

        /** Sets the shingle size used to derive the input length. */
        public Builder shingleSize(int i) {
            this.shingleSize = i;
            return this;
        }

        /** Sets the number of future steps per forecast; must be positive. */
        public Builder forecastHorizon(int i) {
            this.forecastHorizon = i;
            return this;
        }

        /** Sets the error horizon; must lie in {@code [forecastHorizon, 1024]}. */
        public Builder errorHorizon(int i) {
            this.errorHorizon = i;
            return this;
        }

        /** Sets the percentile that controls interval widening. */
        public Builder percentile(double d) {
            this.percentile = d;
            return this;
        }

        /** Sets per-variable lower limits; length must equal the input length. */
        public Builder lowerLimit(float[] fArr) {
            this.lowerLimit = Optional.of(fArr);
            return this;
        }

        /** Sets per-variable upper limits; length must equal the input length. */
        public Builder upperLimit(float[] fArr) {
            this.upperLimit = Optional.of(fArr);
            return this;
        }

        /** Whether to create a {@link PredictiveRandomCutForest} to estimate error spreads. */
        public Builder useRCF(boolean z) {
            this.useRCF = z;
            return this;
        }

        /** Builds the handler, validating the configured parameters. */
        public ErrorHandler build() {
            return new ErrorHandler(this);
        }
    }

    /**
     * Creates a fresh handler from builder settings.
     *
     * @param builder configuration
     * @throws IllegalArgumentException if the horizons or limits are inconsistent
     */
    public ErrorHandler(Builder builder) {
        CommonUtils.checkArgument(builder.forecastHorizon > 0, "has to be positive");
        CommonUtils.checkArgument(builder.errorHorizon >= builder.forecastHorizon,
                "intervalPrecision horizon should be at least as large as forecast horizon");
        CommonUtils.checkArgument(builder.errorHorizon <= 1024, "reduce error horizon");
        this.forecastHorizon = builder.forecastHorizon;
        this.errorHorizon = builder.errorHorizon;
        this.percentile = builder.percentile;

        int inputLength = builder.dimensions / builder.shingleSize;
        int errorLength = inputLength * this.forecastHorizon;

        this.pastForecasts = new RangeVector[this.forecastHorizon];
        for (int h = 0; h < this.forecastHorizon; h++) {
            this.pastForecasts[h] = new RangeVector(errorLength);
        }
        this.sequenceIndex = 0;
        this.lastInputs = new double[2 * inputLength];
        this.rmseHighDeviations = newDeviations(errorLength, this.errorHorizon);
        this.rmseLowDeviations = newDeviations(errorLength, this.errorHorizon);
        this.intervalPrecision = newDeviations(errorLength, this.errorHorizon);
        this.errorMean = new float[errorLength];
        this.errorRMSE = new DiVector(errorLength);
        this.lastDataDeviations = new float[inputLength];
        this.errorDistribution = new RangeVector(errorLength);

        if (builder.upperLimit.isPresent()) {
            float[] upper = builder.upperLimit.get();
            CommonUtils.checkArgument(upper.length == inputLength, "incorrect length");
            this.upperLimit = Arrays.copyOf(upper, inputLength);
        } else {
            this.upperLimit = filledArray(inputLength, Float.MAX_VALUE);
        }
        if (builder.lowerLimit.isPresent()) {
            float[] lower = builder.lowerLimit.get();
            CommonUtils.checkArgument(lower.length == inputLength, "incorrect length");
            for (int j = 0; j < inputLength; j++) {
                CommonUtils.checkArgument(lower[j] <= this.upperLimit[j], "incorrect limits");
            }
            this.lowerLimit = Arrays.copyOf(lower, inputLength);
        } else {
            this.lowerLimit = filledArray(inputLength, -Float.MAX_VALUE);
        }

        if (builder.useRCF) {
            int historyLength = this.lastInputs.length;
            int estimatorDimensions = historyLength + 2 * inputLength + 2;
            double[] weights = new double[estimatorDimensions];
            Arrays.fill(weights, 1.0d);
            // the two horizon-position coordinates get weight equal to the history length
            weights[historyLength] = historyLength;
            weights[historyLength + 1] = historyLength;
            this.estimator = new PredictiveRandomCutForest.Builder().inputDimensions(estimatorDimensions)
                    .weights(weights).randomSeed(13L).outputAfter(50).transformMethod(TransformMethod.NORMALIZE)
                    .startNormalization(49).build();
        }
    }

    /**
     * Restores a handler from saved state. Limits are reset to {@code +/-Float.MAX_VALUE}; use
     * {@link #setLowerLimit}/{@link #setUpperLimit} to restore them.
     *
     * @param errorHorizon           stored error horizon; must be at least {@code forecastHorizon}
     * @param forecastHorizon        number of future steps per forecast; must be positive
     * @param sequenceIndex          number of forecasts seen so far; must be non-negative
     * @param percentile             must lie strictly within (0.01, 0.49)
     * @param inputLength            number of input variables; must be positive
     * @param pastForecastsFlattened stored forecasts as produced by {@link #getPastForecastsFlattened()}
     * @param lastDataDeviations     data deviations; length at least {@code inputLength}
     * @param lastInputs             last two input vectors; length {@code 2 * inputLength}
     * @param deviations             statistics as produced by {@link #getDeviationList()}
     * @param estimator              optional spread estimator; may be null
     * @param unused                 not read by this constructor
     * @throws IllegalArgumentException if any of the saved pieces are inconsistent
     */
    public ErrorHandler(int errorHorizon, int forecastHorizon, int sequenceIndex, double percentile, int inputLength,
            float[] pastForecastsFlattened, float[] lastDataDeviations, double[] lastInputs, Deviation[] deviations,
            PredictiveRandomCutForest estimator, float[] unused) {
        CommonUtils.checkArgument(forecastHorizon > 0, " incorrect forecast horizon");
        CommonUtils.checkArgument(errorHorizon >= forecastHorizon, "incorrect error horizon");
        CommonUtils.checkArgument(inputLength > 0, "incorrect parameters");
        CommonUtils.checkArgument(sequenceIndex >= 0, "cannot be negative");
        CommonUtils.checkArgument(Math.abs(percentile - 0.25d) < 0.24d, "has to be between (0,0.5) ");
        CommonUtils.checkArgument(deviations.length == (3 * inputLength) * forecastHorizon, "incorrect length");
        CommonUtils.checkArgument(lastInputs.length == 2 * inputLength, "incorrect length");
        this.sequenceIndex = sequenceIndex;
        this.errorHorizon = errorHorizon;
        this.percentile = percentile;
        this.forecastHorizon = forecastHorizon;
        this.pastForecasts = new RangeVector[forecastHorizon];
        this.lastInputs = Arrays.copyOf(lastInputs, lastInputs.length);

        int errorLength = forecastHorizon * inputLength;
        CommonUtils.checkArgument(lastDataDeviations.length >= inputLength, "incorrect length");
        this.lastDataDeviations = Arrays.copyOf(lastDataDeviations, lastDataDeviations.length);
        this.errorMean = new float[errorLength];
        this.errorRMSE = new DiVector(errorLength);
        this.errorDistribution = new RangeVector(errorLength);

        // deviations holds three consecutive segments: precision, high errors, low errors
        this.intervalPrecision = new Deviation[errorLength];
        this.rmseHighDeviations = new Deviation[errorLength];
        this.rmseLowDeviations = new Deviation[errorLength];
        for (int k = 0; k < errorLength; k++) {
            this.intervalPrecision[k] = deviations[k].copy();
            this.rmseHighDeviations[k] = deviations[k + errorLength].copy();
            this.rmseLowDeviations[k] = deviations[k + 2 * errorLength].copy();
        }
        this.lowerLimit = filledArray(inputLength, -Float.MAX_VALUE);
        this.upperLimit = filledArray(inputLength, Float.MAX_VALUE);
        this.estimator = estimator;

        // each stored forecast occupies 3 * errorLength floats: values, upper, lower
        int stride = 3 * errorLength;
        int storedCount = pastForecastsFlattened.length / stride;
        CommonUtils.checkArgument(storedCount * stride == pastForecastsFlattened.length, " has to be multiple of 3");
        CommonUtils.checkArgument(storedCount <= forecastHorizon, "too many stored forecasts");
        for (int k = 0; k < storedCount; k++) {
            int start = k * stride;
            this.pastForecasts[k] = new RangeVector(
                    Arrays.copyOfRange(pastForecastsFlattened, start, start + errorLength),
                    Arrays.copyOfRange(pastForecastsFlattened, start + errorLength, start + 2 * errorLength),
                    Arrays.copyOfRange(pastForecastsFlattened, start + 2 * errorLength, start + stride));
        }
        for (int k = storedCount; k < forecastHorizon; k++) {
            this.pastForecasts[k] = new RangeVector(errorLength);
        }
        recomputeErrors(this.lastInputs, inputLength);
    }

    /** Creates {@code count} deviations, each constructed with {@code 1.0 / errorHorizon}. */
    private static Deviation[] newDeviations(int count, int errorHorizon) {
        Deviation[] result = new Deviation[count];
        for (int k = 0; k < count; k++) {
            result[k] = new Deviation(1.0d / errorHorizon);
        }
        return result;
    }

    /** Returns a new array of the given length with every entry set to {@code value}. */
    private static float[] filledArray(int length, float value) {
        float[] result = new float[length];
        Arrays.fill(result, value);
        return result;
    }

    /**
     * Overwrites the upper limits; a null argument is ignored.
     *
     * @throws IllegalArgumentException if the length differs from the current limits
     */
    public void setUpperLimit(float[] fArr) {
        if (fArr != null) {
            CommonUtils.checkArgument(fArr.length == this.upperLimit.length, "incorrect Length");
            System.arraycopy(fArr, 0, this.upperLimit, 0, fArr.length);
        }
    }

    /**
     * Overwrites the lower limits entry by entry; a null argument is ignored.
     *
     * @throws IllegalArgumentException if the length differs or an entry exceeds the upper limit
     *                                  (entries before the offending one are already written)
     */
    public void setLowerLimit(float[] fArr) {
        if (fArr != null) {
            CommonUtils.checkArgument(fArr.length == this.lowerLimit.length, "incorrect Length");
            for (int j = 0; j < fArr.length; j++) {
                CommonUtils.checkArgument(fArr[j] <= this.upperLimit[j], "lower limit is higher than upper limit");
                this.lowerLimit[j] = fArr[j];
            }
        }
    }

    /**
     * Records newly observed actual values. Each stored forecast that predicted these actuals is
     * scored: the most recent one at horizon 0, the one before at horizon 1, and so on. The error
     * statistics (and the estimator, if present) are updated before all derived errors are recomputed.
     *
     * @param actuals    the observed input vector; length must equal {@link #getInputLength()}
     * @param deviations current data deviations, stored as the fallback error scale
     * @throws IllegalArgumentException if {@code actuals} has the wrong length
     */
    public void updateActuals(double[] actuals, double[] deviations) {
        int inputLength = actuals.length;
        CommonUtils.checkArgument(inputLength == getInputLength(), "incorrect input length");
        int historyLength = this.lastInputs.length;
        // drop the oldest input vector and append the newest (arraycopy is overlap-safe)
        System.arraycopy(this.lastInputs, inputLength, this.lastInputs, 0, historyLength - inputLength);
        System.arraycopy(actuals, 0, this.lastInputs, historyLength - inputLength, inputLength);

        if (this.sequenceIndex > 0) {
            int bufferLength = this.pastForecasts.length;
            int highOffset = historyLength + 2;
            int lowOffset = highOffset + inputLength;
            float[] point = new float[historyLength + 2 * inputLength + 2];
            for (int k = 0; k < historyLength; k++) {
                point[k] = (float) this.lastInputs[k];
            }
            int slot = (this.sequenceIndex + bufferLength - 1) % bufferLength;
            // only forecasts that actually exist (sequenceIndex of them) can be scored
            int scoredHorizons = Math.min(this.forecastHorizon, this.sequenceIndex);
            for (int h = 0; h < scoredHorizons; h++) {
                RangeVector forecast = this.pastForecasts[slot];
                int offset = h * inputLength;
                point[historyLength] = h;
                point[historyLength + 1] = this.forecastHorizon - h;
                for (int j = 0; j < inputLength; j++) {
                    int pos = offset + j;
                    double actual = actuals[j];
                    boolean outside = actual > forecast.upper[pos] || actual < forecast.lower[pos];
                    this.intervalPrecision[pos].update(outside ? 0.0d : 1.0d);
                    double error = actual - forecast.values[pos];
                    if (error >= 0.0d) {
                        this.rmseHighDeviations[pos].update(error);
                        this.rmseLowDeviations[pos].update(0.0d);
                        point[highOffset + j] = (float) error;
                        point[lowOffset + j] = 0.0f;
                    } else {
                        this.rmseLowDeviations[pos].update(error);
                        this.rmseHighDeviations[pos].update(0.0d);
                        point[lowOffset + j] = (float) error;
                        point[highOffset + j] = 0.0f;
                    }
                }
                if (this.estimator != null) {
                    this.estimator.update(point, 0L);
                }
                slot = (slot + bufferLength - 1) % bufferLength;
            }
        }
        this.lastDataDeviations = CommonUtils.toFloatArray(deviations);
        recomputeErrors(this.lastInputs, inputLength);
    }

    /**
     * Recomputes {@link #errorRMSE}, {@link #errorMean} and {@link #errorDistribution}.
     *
     * <p>While few errors have been observed, data deviations are used as the error scale. Over the
     * next equally long stretch the scale is blended linearly into the observed error deviations;
     * after that only the observed ones are used.
     *
     * @param inputs      recent inputs (the history), used as estimator context
     * @param inputLength number of input variables
     */
    void recomputeErrors(double[] inputs, int inputLength) {
        // floating-point division: an int quotient would make the linear blend below unreachable
        double ratio = (this.estimator != null)
                ? (double) this.sequenceIndex / this.estimator.getForest().getOutputAfter()
                : (double) this.sequenceIndex / (10 * this.forecastHorizon);
        int errorLength = this.intervalPrecision.length;
        float[] highSpread = new float[errorLength];
        float[] lowSpread = new float[errorLength];

        if (ratio < 1.0d) {
            for (int k = 0; k < errorLength; k++) {
                float dataDeviation = this.lastDataDeviations[k % inputLength];
                this.errorRMSE.low[k] = dataDeviation;
                this.errorRMSE.high[k] = dataDeviation;
                highSpread[k] = dataDeviation;
                lowSpread[k] = dataDeviation;
            }
        } else {
            if (ratio < 2.0d) {
                for (int k = 0; k < this.errorRMSE.high.length; k++) {
                    double fromData = (2.0d - ratio) * this.lastDataDeviations[k % inputLength];
                    this.errorRMSE.high[k] = fromData + (ratio - 1.0d) * this.rmseHighDeviations[k].getDeviation();
                    this.errorRMSE.low[k] = fromData + (ratio - 1.0d) * this.rmseLowDeviations[k].getDeviation();
                }
            } else {
                for (int k = 0; k < this.errorRMSE.high.length; k++) {
                    this.errorRMSE.high[k] = this.rmseHighDeviations[k].getDeviation();
                    this.errorRMSE.low[k] = this.rmseLowDeviations[k].getDeviation();
                }
            }
            if (this.estimator != null) {
                estimateSpreads(inputs, inputLength, highSpread, lowSpread);
            } else {
                for (int k = 0; k < this.errorRMSE.high.length; k++) {
                    highSpread[k] = (float) this.errorRMSE.high[k];
                    lowSpread[k] = (float) this.errorRMSE.low[k];
                }
            }
        }

        // widen spreads of positions whose observed coverage is below the target
        for (int k = 0; k < errorLength; k++) {
            double precision = this.intervalPrecision[k].getMean();
            if (precision < 1.0d - this.percentile) {
                float widening = (float) Math.max(1.0d, 1.0d / (precision + 0.1d));
                highSpread[k] = widening * highSpread[k];
                lowSpread[k] = widening * lowSpread[k];
            }
        }

        for (int k = 0; k < this.errorMean.length; k++) {
            this.errorMean[k] = (float) (this.rmseHighDeviations[k].getMean() + this.rmseLowDeviations[k].getMean());
            this.errorDistribution.values[k] = this.errorMean[k];
            this.errorDistribution.upper[k] = this.errorMean[k] + (float) (SPREAD_SCALE * highSpread[k]);
            this.errorDistribution.lower[k] = this.errorMean[k] - (float) (SPREAD_SCALE * lowSpread[k]);
        }
    }

    /**
     * Fills high/low spreads from the estimator's predicted deviations. For horizon {@code h} the
     * estimator's value gets weight {@code (H - h) / H}, and the observed error deviation gets weight
     * {@code h / H}.
     */
    private void estimateSpreads(double[] inputs, int inputLength, float[] highSpread, float[] lowSpread) {
        int historyLength = inputs.length;
        int highOffset = historyLength + 2;
        int lowOffset = highOffset + inputLength;
        float[] query = new float[historyLength + 2 * inputLength + 2];
        System.arraycopy(CommonUtils.toFloatArray(inputs), 0, query, 0, historyLength);
        int[] highMissing = new int[inputLength];
        int[] lowMissing = new int[inputLength];
        for (int j = 0; j < inputLength; j++) {
            highMissing[j] = highOffset + j;
            lowMissing[j] = lowOffset + j;
        }
        for (int h = 0; h < this.forecastHorizon; h++) {
            query[historyLength] = h;
            query[historyLength + 1] = this.forecastHorizon - h;
            SampleSummary high = this.estimator.predict(query, 0L, highMissing, 1, 0.5d, 0.7d);
            for (int j = 0; j < inputLength; j++) {
                int pos = h * inputLength + j;
                float fromEstimator = (this.forecastHorizon - h) * Math.max(0.0f, high.deviation[highOffset + j])
                        / this.forecastHorizon;
                float fromHistory = (float) (h * this.rmseHighDeviations[pos].getDeviation() / this.forecastHorizon);
                highSpread[pos] = fromEstimator + fromHistory;
            }
            SampleSummary low = this.estimator.predict(query, 0L, lowMissing, 1, 0.5d, 0.7d);
            for (int j = 0; j < inputLength; j++) {
                int pos = h * inputLength + j;
                float fromEstimator = (this.forecastHorizon - h) * Math.max(0.0f, low.deviation[lowOffset + j])
                        / this.forecastHorizon;
                float fromHistory = (float) (h * this.rmseLowDeviations[pos].getDeviation() / this.forecastHorizon);
                lowSpread[pos] = fromEstimator + fromHistory;
            }
        }
    }

    /**
     * Copies the current error information into a forecast descriptor. The interval precision array
     * is sized by the descriptor's input length. Entries beyond this handler's positions stay 0, and
     * extra positions of this handler are dropped.
     */
    public void augmentDescriptor(ForecastDescriptor forecastDescriptor) {
        float[] precision = new float[forecastDescriptor.getInputLength() * this.forecastHorizon];
        int count = Math.min(precision.length, this.intervalPrecision.length);
        for (int k = 0; k < count; k++) {
            precision[k] = (float) this.intervalPrecision[k].getMean();
        }
        forecastDescriptor.setErrorMean(this.errorMean);
        forecastDescriptor.setErrorRMSE(this.errorRMSE);
        forecastDescriptor.setObservedErrorDistribution(this.errorDistribution);
        forecastDescriptor.setIntervalPrecision(precision);
    }

    /**
     * Registers a new forecast: increments {@link #sequenceIndex} and copies the forecast's values,
     * upper and lower bounds into ring-buffer slot {@code (sequenceIndex - 1) % forecastHorizon}.
     */
    public void updateForecasts(RangeVector rangeVector) {
        this.sequenceIndex++;
        int bufferLength = this.pastForecasts.length;
        int slot = (this.sequenceIndex + bufferLength - 1) % bufferLength;
        int width = this.pastForecasts[0].values.length;
        System.arraycopy(rangeVector.values, 0, this.pastForecasts[slot].values, 0, width);
        System.arraycopy(rangeVector.upper, 0, this.pastForecasts[slot].upper, 0, width);
        System.arraycopy(rangeVector.lower, 0, this.pastForecasts[slot].lower, 0, width);
    }

    /** Returns a copy of the observed error distribution. */
    public RangeVector getErrorDistribution() {
        return new RangeVector(this.errorDistribution);
    }

    /** Returns a copy of the per-position mean errors. */
    public float[] getErrorMean() {
        return Arrays.copyOf(this.errorMean, this.errorMean.length);
    }

    /** Returns a copy of the per-position high/low error magnitudes. */
    public DiVector getErrorRMSE() {
        return new DiVector(this.errorRMSE);
    }

    /**
     * Returns copies of all statistics as one array: interval precision, then high errors, then low
     * errors (the layout the restoring constructor expects).
     */
    public Deviation[] getDeviationList() {
        int errorLength = this.intervalPrecision.length;
        Deviation[] result = new Deviation[3 * errorLength];
        for (int k = 0; k < errorLength; k++) {
            result[k] = this.intervalPrecision[k].copy();
            result[k + errorLength] = this.rmseHighDeviations[k].copy();
            result[k + 2 * errorLength] = this.rmseLowDeviations[k].copy();
        }
        return result;
    }

    /** Returns the mean of each interval-precision statistic (observed coverage per position). */
    public float[] getIntervalPrecision() {
        float[] result = new float[this.intervalPrecision.length];
        for (int k = 0; k < result.length; k++) {
            result[k] = (float) this.intervalPrecision[k].getMean();
        }
        return result;
    }

    /**
     * Adjusts a forecast in place unless {@code calibration} is {@link Calibration#NONE}.
     *
     * <p>For {@link Calibration#SIMPLE}, values are first shifted by the observed mean error. In every
     * calibrating mode, values are then clamped to the limits. Upper/lower bounds are widened to at
     * least {@code value + observed error bound} and clamped to the limits.
     *
     * @throws IllegalArgumentException if {@code dArr} or {@code rangeVector} have mismatched lengths
     */
    public void calibrate(double[] dArr, Calibration calibration, RangeVector rangeVector) {
        if (calibration == Calibration.NONE) {
            return;
        }
        int inputLength = this.intervalPrecision.length / this.forecastHorizon;
        CommonUtils.checkArgument(dArr.length == inputLength, "incorrect input");
        CommonUtils.checkArgument(this.intervalPrecision.length == rangeVector.values.length, "mismatched lengths");
        for (int k = 0; k < this.intervalPrecision.length; k++) {
            float lower = this.lowerLimit[k % inputLength];
            float upper = this.upperLimit[k % inputLength];
            float value = (calibration == Calibration.SIMPLE)
                    ? rangeVector.values[k] + this.errorDistribution.values[k]
                    : rangeVector.values[k];
            rangeVector.values[k] = Math.min(Math.max(value, lower), upper);
            rangeVector.upper[k] = Math.min(
                    Math.max(rangeVector.upper[k], rangeVector.values[k] + this.errorDistribution.upper[k]), upper);
            rangeVector.lower[k] = Math.max(
                    Math.min(rangeVector.lower[k], rangeVector.values[k] + this.errorDistribution.lower[k]), lower);
        }
    }

    /** Returns the number of input variables (half the stored history length). */
    public int getInputLength() {
        return this.lastInputs.length / 2;
    }

    /**
     * Flattens the stored forecasts (up to {@code min(sequenceIndex, forecastHorizon)} of them) as
     * consecutive {@code values, upper, lower} segments of one slot each.
     */
    public float[] getPastForecastsFlattened() {
        int stored = Math.min(this.sequenceIndex, this.pastForecasts.length);
        int errorLength = this.intervalPrecision.length;
        float[] result = new float[3 * errorLength * stored];
        for (int k = 0; k < stored; k++) {
            int start = 3 * k * errorLength;
            System.arraycopy(this.pastForecasts[k].values, 0, result, start, errorLength);
            System.arraycopy(this.pastForecasts[k].upper, 0, result, start + errorLength, errorLength);
            System.arraycopy(this.pastForecasts[k].lower, 0, result, start + 2 * errorLength, errorLength);
        }
        return result;
    }

    /** Returns a new {@link Builder}. */
    public static Builder builder() {
        return new Builder();
    }

    // ---- plain accessors (note: these expose internal arrays/objects without copying) ----

    @Generated
    public int getSequenceIndex() {
        return this.sequenceIndex;
    }

    @Generated
    public double getPercentile() {
        return this.percentile;
    }

    @Generated
    public int getForecastHorizon() {
        return this.forecastHorizon;
    }

    @Generated
    public int getErrorHorizon() {
        return this.errorHorizon;
    }

    @Generated
    public RangeVector[] getPastForecasts() {
        return this.pastForecasts;
    }

    @Generated
    public Deviation[] getRmseHighDeviations() {
        return this.rmseHighDeviations;
    }

    @Generated
    public Deviation[] getRmseLowDeviations() {
        return this.rmseLowDeviations;
    }

    @Generated
    public float[] getLowerLimit() {
        return this.lowerLimit;
    }

    @Generated
    public float[] getUpperLimit() {
        return this.upperLimit;
    }

    @Generated
    public double[] getLastInputs() {
        return this.lastInputs;
    }

    @Generated
    public PredictiveRandomCutForest getEstimator() {
        return this.estimator;
    }

    @Generated
    public float[] getLastDataDeviations() {
        return this.lastDataDeviations;
    }

    @Generated
    public RangeVector getMultipliers() {
        return this.multipliers;
    }

    @Generated
    public RangeVector getAdders() {
        return this.adders;
    }

    @Generated
    public void setSequenceIndex(int i) {
        this.sequenceIndex = i;
    }

    @Generated
    public void setPercentile(double d) {
        this.percentile = d;
    }

    @Generated
    public void setForecastHorizon(int i) {
        this.forecastHorizon = i;
    }

    @Generated
    public void setErrorHorizon(int i) {
        this.errorHorizon = i;
    }

    @Generated
    public void setPastForecasts(RangeVector[] rangeVectorArr) {
        this.pastForecasts = rangeVectorArr;
    }

    @Generated
    public void setErrorDistribution(RangeVector rangeVector) {
        this.errorDistribution = rangeVector;
    }

    @Generated
    public void setErrorRMSE(DiVector diVector) {
        this.errorRMSE = diVector;
    }

    @Generated
    public void setErrorMean(float[] fArr) {
        this.errorMean = fArr;
    }

    @Generated
    public void setIntervalPrecision(Deviation[] deviationArr) {
        this.intervalPrecision = deviationArr;
    }

    @Generated
    public void setRmseHighDeviations(Deviation[] deviationArr) {
        this.rmseHighDeviations = deviationArr;
    }

    @Generated
    public void setRmseLowDeviations(Deviation[] deviationArr) {
        this.rmseLowDeviations = deviationArr;
    }

    @Generated
    public void setLastInputs(double[] dArr) {
        this.lastInputs = dArr;
    }

    @Generated
    public void setEstimator(PredictiveRandomCutForest predictiveRandomCutForest) {
        this.estimator = predictiveRandomCutForest;
    }

    @Generated
    public void setLastDataDeviations(float[] fArr) {
        this.lastDataDeviations = fArr;
    }

    @Generated
    public void setMultipliers(RangeVector rangeVector) {
        this.multipliers = rangeVector;
    }

    @Generated
    public void setAdders(RangeVector rangeVector) {
        this.adders = rangeVector;
    }
}
