package ai.djl.timeseries.model.deepar;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.NDResource;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Block;
import ai.djl.nn.recurrent.LSTM;
import ai.djl.timeseries.block.FeatureEmbedder;
import ai.djl.timeseries.block.MeanScaler;
import ai.djl.timeseries.block.NopScaler;
import ai.djl.timeseries.block.Scaler;
import ai.djl.timeseries.dataset.FieldName;
import ai.djl.timeseries.distribution.output.DistributionOutput;
import ai.djl.timeseries.distribution.output.StudentTOutput;
import ai.djl.timeseries.timefeature.Lag;
import ai.djl.timeseries.timefeature.TimeFeature;
import ai.djl.timeseries.transform.ExpectedNumInstanceSampler;
import ai.djl.timeseries.transform.PredictionSplitSampler;
import ai.djl.timeseries.transform.TimeSeriesTransform;
import ai.djl.timeseries.transform.convert.AsArray;
import ai.djl.timeseries.transform.convert.VstackFeatures;
import ai.djl.timeseries.transform.feature.AddAgeFeature;
import ai.djl.timeseries.transform.feature.AddObservedValuesIndicator;
import ai.djl.timeseries.transform.feature.AddTimeFeature;
import ai.djl.timeseries.transform.field.RemoveFields;
import ai.djl.timeseries.transform.field.SelectField;
import ai.djl.timeseries.transform.field.SetField;
import ai.djl.timeseries.transform.split.InstanceSplit;
import ai.djl.training.ParameterStore;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:ai/djl/timeseries/model/deepar/DeepARNetwork.class */
/**
 * Base class shared by the DeepAR training and prediction networks.
 *
 * <p>It registers the child blocks common to both variants (distribution-argument projection,
 * static categorical feature embedder, target scaler and LSTM), implements the lagged-RNN unroll
 * both variants call, and builds the {@link TimeSeriesTransform} pipelines that turn raw dataset
 * entries into network inputs.
 */
public abstract class DeepARNetwork extends AbstractBlock {

    /** Fields selected, in this order, by the training transformation. */
    private static final String[] TRAIN_INPUT_FIELDS = {
        FieldName.FEAT_STATIC_CAT.name(),
        FieldName.FEAT_STATIC_REAL.name(),
        "PAST_" + FieldName.FEAT_TIME.name(),
        "PAST_" + FieldName.TARGET.name(),
        "PAST_" + FieldName.OBSERVED_VALUES.name(),
        "PAST_" + FieldName.IS_PAD.name(),
        "FUTURE_" + FieldName.FEAT_TIME.name(),
        "FUTURE_" + FieldName.TARGET.name(),
        "FUTURE_" + FieldName.OBSERVED_VALUES.name()
    };

    /** Fields selected, in this order, by the prediction transformation. */
    private static final String[] PRED_INPUT_FIELDS = {
        FieldName.FEAT_STATIC_CAT.name(),
        FieldName.FEAT_STATIC_REAL.name(),
        "PAST_" + FieldName.FEAT_TIME.name(),
        "PAST_" + FieldName.TARGET.name(),
        "PAST_" + FieldName.OBSERVED_VALUES.name(),
        "FUTURE_" + FieldName.FEAT_TIME.name(),
        "PAST_" + FieldName.IS_PAD.name()
    };

    protected String freq;
    protected int historyLength;
    protected int contextLength;
    protected int predictionLength;
    protected boolean useFeatDynamicReal;
    protected boolean useFeatStaticCat;
    protected boolean useFeatStaticReal;
    protected DistributionOutput distrOutput;
    protected List<Integer> cardinality;
    protected List<Integer> embeddingDimension;
    protected List<Integer> lagsSeq;
    protected int numParallelSamples;
    protected FeatureEmbedder embedder;
    protected Block paramProj;
    protected LSTM rnn;
    protected Scaler scaler;

    /**
     * Builder for {@link DeepARTrainingNetwork} and {@link DeepARPredictionNetwork}.
     *
     * <p>Defaults: 100 parallel samples, 2 LSTM layers, hidden size 40, drop rate 0.1, mean
     * scaling enabled, and a Student-t distribution output.
     */
    public static final class Builder {
        private String freq;
        private int contextLength;
        private int predictionLength;
        private boolean useFeatDynamicReal;
        private boolean useFeatStaticCat;
        private boolean useFeatStaticReal;
        private List<Integer> cardinality;
        private List<Integer> embeddingDimension;
        private List<Integer> lagsSeq;
        private int numParallelSamples = 100;
        private int numLayers = 2;
        private int hiddenSize = 40;
        private float dropRate = 0.1f;
        private boolean scaling = true;
        private DistributionOutput distrOutput = new StudentTOutput();

        /**
         * Sets the frequency string used for time features and default lags.
         *
         * @param freq the frequency string
         * @return this builder
         */
        public Builder setFreq(String freq) {
            this.freq = freq;
            return this;
        }

        /**
         * Sets the number of steps to predict.
         *
         * @param predictionLength the prediction length
         * @return this builder
         */
        public Builder setPredictionLength(int predictionLength) {
            this.predictionLength = predictionLength;
            return this;
        }

        /**
         * Sets the cardinality of each static categorical feature.
         *
         * @param cardinality one entry per categorical feature
         * @return this builder
         */
        public Builder setCardinality(List<Integer> cardinality) {
            this.cardinality = cardinality;
            return this;
        }

        /**
         * Sets the output distribution (default: Student-t).
         *
         * @param distrOutput the distribution output
         * @return this builder
         */
        public Builder optDistrOutput(DistributionOutput distrOutput) {
            this.distrOutput = distrOutput;
            return this;
        }

        /**
         * Sets the context length; 0 (the default) means "use the prediction length".
         *
         * @param contextLength the context length
         * @return this builder
         */
        public Builder optContextLength(int contextLength) {
            this.contextLength = contextLength;
            return this;
        }

        /**
         * Sets the number of parallel samples (default 100).
         *
         * @param numParallelSamples the number of samples
         * @return this builder
         */
        public Builder optNumParallelSamples(int numParallelSamples) {
            this.numParallelSamples = numParallelSamples;
            return this;
        }

        /**
         * Sets the number of LSTM layers (default 2).
         *
         * @param numLayers the number of layers
         * @return this builder
         */
        public Builder optNumLayers(int numLayers) {
            this.numLayers = numLayers;
            return this;
        }

        /**
         * Sets the LSTM state size (default 40).
         *
         * @param hiddenSize the hidden size
         * @return this builder
         */
        public Builder optHiddenSize(int hiddenSize) {
            this.hiddenSize = hiddenSize;
            return this;
        }

        /**
         * Sets the LSTM drop rate (default 0.1).
         *
         * @param dropRate the drop rate
         * @return this builder
         */
        public Builder optDropRate(float dropRate) {
            this.dropRate = dropRate;
            return this;
        }

        /**
         * Sets the embedding dimension of each static categorical feature. When unset and a
         * cardinality is given, each dimension defaults to {@code min(50, (cardinality + 1) / 2)}.
         *
         * @param embeddingDimension one entry per categorical feature
         * @return this builder
         */
        public Builder optEmbeddingDimension(List<Integer> embeddingDimension) {
            this.embeddingDimension = embeddingDimension;
            return this;
        }

        /**
         * Sets the lag indices; when unset they are derived from the frequency.
         *
         * @param lagsSeq the lag indices (must be non-empty)
         * @return this builder
         */
        public Builder optLagsSeq(List<Integer> lagsSeq) {
            this.lagsSeq = lagsSeq;
            return this;
        }

        /**
         * Sets whether dynamic real features are used.
         *
         * @param useFeatDynamicReal whether to use them
         * @return this builder
         */
        public Builder optUseFeatDynamicReal(boolean useFeatDynamicReal) {
            this.useFeatDynamicReal = useFeatDynamicReal;
            return this;
        }

        /**
         * Sets whether static categorical features are used.
         *
         * @param useFeatStaticCat whether to use them
         * @return this builder
         */
        public Builder optUseFeatStaticCat(boolean useFeatStaticCat) {
            this.useFeatStaticCat = useFeatStaticCat;
            return this;
        }

        /**
         * Sets whether static real features are used.
         *
         * @param useFeatStaticReal whether to use them
         * @return this builder
         */
        public Builder optUseFeatStaticReal(boolean useFeatStaticReal) {
            this.useFeatStaticReal = useFeatStaticReal;
            return this;
        }

        /**
         * Sets whether targets are scaled with a {@link MeanScaler} (default {@code true}) or
         * passed through a {@link NopScaler}.
         *
         * @param scaling whether to apply mean scaling
         * @return this builder
         */
        public Builder optScaling(boolean scaling) {
            this.scaling = scaling;
            return this;
        }

        /**
         * Builds the training variant of the network.
         *
         * @return a new {@link DeepARTrainingNetwork}
         */
        public DeepARTrainingNetwork buildTrainingNetwork() {
            return new DeepARTrainingNetwork(this);
        }

        /**
         * Builds the prediction variant of the network.
         *
         * @return a new {@link DeepARPredictionNetwork}
         */
        public DeepARPredictionNetwork buildPredictionNetwork() {
            return new DeepARPredictionNetwork(this);
        }
    }

    /**
     * Creates the shared network state from a builder.
     *
     * @param builder the configured builder
     * @throws IllegalArgumentException if the resolved lag sequence is empty
     */
    public DeepARNetwork(Builder builder) {
        freq = builder.freq;
        predictionLength = builder.predictionLength;
        // A context length of 0 means "unset": fall back to the prediction length.
        contextLength = builder.contextLength != 0 ? builder.contextLength : predictionLength;
        distrOutput = builder.distrOutput;
        cardinality = builder.cardinality;
        useFeatStaticReal = builder.useFeatStaticReal;
        useFeatDynamicReal = builder.useFeatDynamicReal;
        useFeatStaticCat = builder.useFeatStaticCat;
        numParallelSamples = builder.numParallelSamples;

        // NOTE(review): child blocks are registered in a fixed order (param_proj, feature_embedder,
        // scaler, rnn_lstm); presumably this order affects child naming, so keep it stable.
        paramProj = addChildBlock("param_proj", distrOutput.getArgsProj());

        if (builder.embeddingDimension != null || builder.cardinality == null) {
            embeddingDimension = builder.embeddingDimension;
        } else {
            embeddingDimension = new ArrayList<>(cardinality.size());
            for (int card : cardinality) {
                embeddingDimension.add(Math.min(50, (card + 1) / 2));
            }
        }

        lagsSeq = builder.lagsSeq != null ? builder.lagsSeq : Lag.getLagsForFreq(builder.freq);
        if (lagsSeq.isEmpty()) {
            throw new IllegalArgumentException("lagsSeq must contain at least one lag");
        }
        // The history must cover the context window plus the largest lag looked back on.
        historyLength = contextLength + Collections.max(lagsSeq);

        embedder =
                addChildBlock(
                        "feature_embedder",
                        FeatureEmbedder.builder()
                                .setCardinalities(cardinality)
                                .setEmbeddingDims(embeddingDimension)
                                .build());
        if (builder.scaling) {
            scaler =
                    addChildBlock(
                            "scaler",
                            MeanScaler.builder()
                                    .setDim(1)
                                    .optKeepDim(true)
                                    .optMinimumScale(1.0E-10f)
                                    .build());
        } else {
            scaler =
                    addChildBlock(
                            "scaler", NopScaler.builder().setDim(1).optKeepDim(true).build());
        }
        rnn =
                addChildBlock(
                        "rnn_lstm",
                        LSTM.builder()
                                .setNumLayers(builder.numLayers)
                                .setStateSize(builder.hiddenSize)
                                .optDropRate(builder.dropRate)
                                .optBatchFirst(true)
                                .optReturnState(true)
                                .build());
    }

    /**
     * Initializes the child blocks from the input shapes.
     *
     * <p>Reads {@code inputShapes[0]} (static categorical), {@code [1]} (static real, size on axis
     * 1), {@code [2]} (time features, size on axis 2) and {@code [3]} (target, trailing dims from
     * axis 2).
     */
    @Override
    protected void initializeChildBlocks(
            NDManager manager, DataType dataType, Shape... inputShapes) {
        Shape targetTail = inputShapes[3].slice(2);
        Shape contextShape = new Shape(1, contextLength).addAll(targetTail);
        scaler.initialize(manager, dataType, contextShape, contextShape);
        long scaleSize =
                scaler.getOutputShapes(new Shape[] {contextShape, contextShape})[1].get(1);

        embedder.initialize(manager, dataType, inputShapes[0]);
        long embeddedCatSize = embedder.getOutputShapes(new Shape[] {inputShapes[0]})[0].get(1);

        // Lagged values span context + (context - 1) steps, with one column per lag.
        Shape lagsShape =
                new Shape(1, 2L * contextLength - 1).addAll(targetTail).add(lagsSeq.size());
        long featSize =
                inputShapes[2].get(2) + embeddedCatSize + inputShapes[1].get(1) + scaleSize;
        Shape rnnInputShape =
                lagsShape.slice(0, lagsShape.dimension() - 1).add(lagsShape.tail() + featSize);
        rnn.initialize(manager, dataType, rnnInputShape);

        Shape rnnOutputShape = rnn.getOutputShapes(new Shape[] {rnnInputShape})[0];
        paramProj.initialize(manager, dataType, rnnOutputShape);
    }

    /**
     * Scales the target, builds lagged and static/time features, runs the LSTM and projects its
     * output to distribution arguments.
     *
     * <p>Input order: static cat, static real, past time feat, past target, past observed values,
     * future time feat, and optionally (index 6) future target.
     *
     * @param ps the parameter store
     * @param inputs the network inputs
     * @param training whether this is a training pass
     * @return the projected distribution arguments followed by scale, rnn output, static features,
     *     hidden state and cell state
     */
    public NDList unrollLaggedRnn(ParameterStore ps, NDList inputs, boolean training) {
        try (NDManager scope = inputs.getManager().newSubManager()) {
            scope.tempAttachAll(inputs);

            NDArray featStaticCat = inputs.get(0);
            NDArray featStaticReal = inputs.get(1);
            NDArray pastTimeFeat = inputs.get(2);
            NDArray pastTarget = inputs.get(3);
            NDArray pastObservedValues = inputs.get(4);
            NDArray futureTimeFeat = inputs.get(5);
            NDArray futureTarget = inputs.size() > 6 ? inputs.get(6) : null;

            NDArray context = pastTarget.get(":,{}:", -contextLength);
            NDArray observedContext = pastObservedValues.get(":,{}:", -contextLength);
            NDArray scale =
                    scaler.forward(ps, new NDList(context, observedContext), training).get(1);

            NDArray priorSequence = pastTarget.get(":,:{}", -contextLength).div(scale);
            // With a future target (training), the last future step is dropped: it is only ever
            // a prediction target, never an input.
            NDArray sequence =
                    futureTarget != null
                            ? context.concat(futureTarget.get(":, :-1"), 1).div(scale)
                            : context.div(scale);

            NDArray embeddedCat =
                    embedder.forward(ps, new NDList(featStaticCat), training).singletonOrThrow();
            NDArray staticFeat =
                    NDArrays.concat(new NDList(embeddedCat, featStaticReal, scale.log()), 1);
            NDArray repeatedStaticFeat =
                    staticFeat.expandDims(1).repeat(1, sequence.getShape().get(1));

            NDArray pastTimeWindow = pastTimeFeat.get(":, {}:", -contextLength + 1);
            NDArray timeFeat =
                    futureTimeFeat != null
                            ? pastTimeWindow.concat(futureTimeFeat, 1)
                            : pastTimeWindow;

            NDArray features = repeatedStaticFeat.concat(timeFeat, -1);
            NDArray laggedSequence = laggedSequenceValues(lagsSeq, priorSequence, sequence);
            NDArray rnnInput = laggedSequence.concat(features, -1);

            NDList rnnResult = rnn.forward(ps, new NDList(rnnInput), training);
            NDArray rnnOutput = rnnResult.get(0);
            NDArray hiddenState = rnnResult.get(1);
            NDArray cellState = rnnResult.get(2);

            NDList distrArgs = paramProj.forward(ps, new NDList(rnnOutput), training);

            scale.setName("scale");
            rnnOutput.setName("output");
            staticFeat.setName("static_feat");
            hiddenState.setName("hidden_state");
            cellState.setName("cell_state");

            NDList result =
                    distrArgs.addAll(
                            new NDList(scale, rnnOutput, staticFeat, hiddenState, cellState));
            return scope.ret(result);
        }
    }

    /**
     * Gathers, for every lag in {@code indices}, the window of {@code prior ++ sequence} that
     * lies that many steps before each position of {@code sequence}, stacked on the last axis.
     *
     * @param indices the lag indices
     * @param priorSequence values preceding {@code sequence} (time on axis 1)
     * @param sequence the values to compute lags for (time on axis 1)
     * @return the lagged values reshaped to {@code (axis0, axis1, -1)}
     * @throws IllegalArgumentException if a lag exceeds the prior sequence length
     */
    public NDArray laggedSequenceValues(
            List<Integer> indices, NDArray priorSequence, NDArray sequence) {
        int maxLag = Collections.max(indices);
        long priorLength = priorSequence.getShape().get(1);
        if (maxLag > (int) priorLength) {
            throw new IllegalArgumentException(
                    String.format(
                            "lags cannot go further than prior sequence length, found lag %d"
                                    + " while prior sequence is only %d-long",
                            maxLag, priorLength));
        }
        try (NDManager scope = NDManager.subManagerOf(priorSequence)) {
            scope.tempAttachAll(priorSequence, sequence);
            NDArray fullSequence = priorSequence.concat(sequence, 1);
            long sequenceLength = sequence.getShape().get(1);

            NDList lagValues = new NDList(indices.size());
            for (int lag : indices) {
                long begin = -lag - sequenceLength;
                long end = -lag;
                // A non-negative end would be an empty/incorrect slice; slice to the end instead.
                lagValues.add(
                        end < 0
                                ? fullSequence.get(":, {}:{}", begin, end)
                                : fullSequence.get(":, {}:", begin));
            }

            NDArray stacked = NDArrays.stack(lagValues, -1);
            return scope.ret(
                    stacked.reshape(stacked.getShape().get(0), stacked.getShape().get(1), -1));
        }
    }

    /**
     * Returns the context length.
     *
     * @return the context length
     */
    public int getContextLength() {
        return contextLength;
    }

    /**
     * Returns the history length (context length plus the largest lag).
     *
     * @return the history length
     */
    public int getHistoryLength() {
        return historyLength;
    }

    /**
     * Creates the transformation pipeline that produces training instances.
     *
     * @param manager the manager used for placeholder arrays
     * @return the transformation list
     */
    public List<TimeSeriesTransform> createTrainingTransformation(NDManager manager) {
        List<TimeSeriesTransform> transforms = createTransformation(manager);
        transforms.add(
                new InstanceSplit(
                        FieldName.TARGET,
                        FieldName.IS_PAD,
                        FieldName.START,
                        FieldName.FORECAST_START,
                        new ExpectedNumInstanceSampler(0, 0, predictionLength, 1.0d),
                        historyLength,
                        predictionLength,
                        new FieldName[] {FieldName.FEAT_TIME, FieldName.OBSERVED_VALUES},
                        distrOutput.getValueInSupport()));
        transforms.add(new SelectField(TRAIN_INPUT_FIELDS));
        return transforms;
    }

    /**
     * Creates the transformation pipeline that produces prediction instances.
     *
     * @param manager the manager used for placeholder arrays
     * @return the transformation list
     */
    public List<TimeSeriesTransform> createPredictionTransformation(NDManager manager) {
        List<TimeSeriesTransform> transforms = createTransformation(manager);
        transforms.add(
                new InstanceSplit(
                        FieldName.TARGET,
                        FieldName.IS_PAD,
                        FieldName.START,
                        FieldName.FORECAST_START,
                        PredictionSplitSampler.newValidationSplitSampler(),
                        historyLength,
                        predictionLength,
                        new FieldName[] {FieldName.FEAT_TIME, FieldName.OBSERVED_VALUES},
                        distrOutput.getValueInSupport()));
        transforms.add(new SelectField(PRED_INPUT_FIELDS));
        return transforms;
    }

    /** Builds the feature-engineering steps shared by training and prediction. */
    private List<TimeSeriesTransform> createTransformation(NDManager manager) {
        List<TimeSeriesTransform> transforms = new ArrayList<>();

        List<FieldName> removedFields = new ArrayList<>();
        removedFields.add(FieldName.FEAT_DYNAMIC_CAT);
        if (!useFeatStaticReal) {
            removedFields.add(FieldName.FEAT_STATIC_REAL);
        }
        if (!useFeatDynamicReal) {
            removedFields.add(FieldName.FEAT_DYNAMIC_REAL);
        }
        transforms.add(new RemoveFields(removedFields));

        // Unused static features are replaced by a zero placeholder so the AsArray steps below
        // (and the network inputs) always find them.
        if (!useFeatStaticCat) {
            transforms.add(new SetField(FieldName.FEAT_STATIC_CAT, manager.zeros(new Shape(1))));
        }
        // Must key on useFeatStaticReal, matching the removal above; keying on useFeatDynamicReal
        // either left the field missing or overwrote real static features with zeros.
        if (!useFeatStaticReal) {
            transforms.add(new SetField(FieldName.FEAT_STATIC_REAL, manager.zeros(new Shape(1))));
        }

        transforms.add(new AsArray(FieldName.FEAT_STATIC_CAT, 1, DataType.INT32));
        transforms.add(new AsArray(FieldName.FEAT_STATIC_REAL, 1));
        transforms.add(new AddObservedValuesIndicator(FieldName.TARGET, FieldName.OBSERVED_VALUES));
        transforms.add(
                new AddTimeFeature(
                        FieldName.START,
                        FieldName.TARGET,
                        FieldName.FEAT_TIME,
                        TimeFeature.timeFeaturesFromFreqStr(freq),
                        predictionLength,
                        freq));
        transforms.add(
                new AddAgeFeature(FieldName.TARGET, FieldName.FEAT_AGE, predictionLength, true));

        FieldName[] stackedFields =
                useFeatDynamicReal
                        ? new FieldName[] {
                            FieldName.FEAT_TIME, FieldName.FEAT_AGE, FieldName.FEAT_DYNAMIC_REAL
                        }
                        : new FieldName[] {FieldName.FEAT_TIME, FieldName.FEAT_AGE};
        transforms.add(new VstackFeatures(FieldName.FEAT_TIME, stackedFields));
        return transforms;
    }

    /**
     * Creates a new {@link Builder}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }
}
