package ai.djl.timeseries.model.deepar;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.timeseries.distribution.Distribution;
import ai.djl.timeseries.model.deepar.DeepARNetwork;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.util.Iterator;

/**
 * {@link DeepARNetwork} variant whose forward pass draws sample paths step by step.
 *
 * <p>The forward pass first runs {@code unrollLaggedRnn} once, then repeats the resulting tensors
 * {@code numParallelSamples} times and feeds each drawn sample back into {@code rnn} until
 * {@code predictionLength} samples have been collected.
 */
public class DeepARPredictionNetwork extends DeepARNetwork {

    /**
     * Creates the network by delegating to the {@link DeepARNetwork} constructor.
     *
     * @param builder the builder handed unchanged to the superclass constructor
     */
    public DeepARPredictionNetwork(DeepARNetwork.Builder builder) {
        super(builder);
    }

    /**
     * Draws {@code numParallelSamples} sample paths of length {@code predictionLength}.
     *
     * @param parameterStore passed through to {@code unrollLaggedRnn}, {@code rnn.forward} and
     *     {@code paramProj.forward}
     * @param nDList the positional inputs; elements 0 to 5 are read
     * @param z flag passed unchanged to {@code unrollLaggedRnn}, {@code rnn.forward} and
     *     {@code paramProj.forward} (presumably the training flag -- confirm against the parent)
     * @param pairList not read by this method
     * @return a single-element list holding the concatenated samples reshaped to
     *     {@code (-1, numParallelSamples, predictionLength)}
     */
    protected NDList forwardInternal(ParameterStore parameterStore, NDList nDList, boolean z, PairList<String, Object> pairList) {
        // NOTE(review): the local names below follow the usual DeepAR input convention; the code
        // itself only relies on the positions 0..5 -- confirm the meaning against DeepARNetwork.
        NDArray featStaticCat = nDList.get(0);
        NDArray featStaticReal = nDList.get(1);
        NDArray pastTimeFeat = nDList.get(2);
        NDArray pastTarget = nDList.get(3);
        NDArray pastObservedValues = nDList.get(4);
        NDArray futureTimeFeat = nDList.get(5);

        // Only the first slice along axis 1 of the sixth input takes part in the initial unroll.
        NDList unrolled =
                unrollLaggedRnn(
                        parameterStore,
                        new NDList(
                                featStaticCat,
                                featStaticReal,
                                pastTimeFeat,
                                pastTarget,
                                pastObservedValues,
                                futureTimeFeat.get(":, :1")),
                        z);
        NDList rnnState = new NDList(unrolled.get("hidden_state"), unrolled.get("cell_state"));

        // Repeat every distribution argument along axis 0 and keep its name so that it can still
        // be looked up by name afterwards.
        NDList repeatedArgs = new NDList(distrOutput.getArgsArray().length);
        for (String argName : distrOutput.getArgsArray()) {
            NDArray repeatedArg = unrolled.get(argName).repeat(0, numParallelSamples);
            repeatedArg.setName(argName);
            repeatedArgs.add(repeatedArg);
        }

        NDArray repeatedScale = unrolled.get("scale").repeat(0, numParallelSamples);
        NDArray repeatedStaticFeat =
                unrolled.get("static_feat").repeat(0, numParallelSamples).expandDims(1);
        NDArray repeatedPastTarget = pastTarget.repeat(0, numParallelSamples).div(repeatedScale);
        NDArray repeatedTimeFeat = futureTimeFeat.repeat(0, numParallelSamples);

        // The recurrent state arrays are repeated along axis 1, unlike the tensors above.
        NDList repeatedState = new NDList(rnnState.size());
        for (NDArray stateArray : rnnState) {
            repeatedState.add(stateArray.repeat(1, numParallelSamples));
        }

        // First sample: built from the last step of the unrolled distribution arguments.
        NDArray sample = outputDistribution(repeatedArgs, repeatedScale, 1).sample();
        NDList samples = new NDList(predictionLength);
        samples.add(sample);

        for (int step = 1; step < predictionLength; step++) {
            NDArray scaledSample = sample.div(repeatedScale);
            NDArray lags = laggedSequenceValues(lagsSeq, repeatedPastTarget, scaledSample);
            NDArray stepFeatures =
                    repeatedStaticFeat.concat(
                            repeatedTimeFeat.get(":, {}:{}", step, step + 1), -1);
            NDList rnnInput = new NDList(lags.concat(stepFeatures, -1)).addAll(repeatedState);

            NDList rnnOutput = rnn.forward(parameterStore, rnnInput, z);
            NDArray hidden = rnnOutput.get(0);
            // Everything after the first output becomes the state for the next step.
            repeatedState = rnnOutput.subNDList(1);
            repeatedPastTarget = repeatedPastTarget.concat(scaledSample, 1);

            NDList projectedArgs = paramProj.forward(parameterStore, new NDList(hidden), z);
            sample = outputDistribution(projectedArgs, repeatedScale, 0).sample();
            samples.add(sample);
        }

        NDArray allSamples = NDArrays.concat(samples, 1);
        return new NDList(allSamples.reshape(-1, numParallelSamples, predictionLength));
    }

    /**
     * Reports the single output shape
     * {@code (shapeArr[0].head(), numParallelSamples, predictionLength)}.
     *
     * @param shapeArr the input shapes; only the first one is read
     * @return a one-element array with the shape described above
     */
    public Shape[] getOutputShapes(Shape[] shapeArr) {
        Shape outputShape = new Shape(shapeArr[0].head(), numParallelSamples, predictionLength);
        return new Shape[] {outputShape};
    }

    /**
     * Builds a distribution from the given arguments and scale via
     * {@code distrOutput.distributionBuilder()}.
     *
     * @param distrArgs the distribution arguments
     * @param scale the scale handed to the builder's {@code optScale}
     * @param trailingLength when positive, every argument is first cut down to its last
     *     {@code trailingLength} entries along axis 1; otherwise the arguments are used as given
     * @return the built distribution
     */
    private Distribution outputDistribution(NDList distrArgs, NDArray scale, int trailingLength) {
        NDList effectiveArgs;
        if (trailingLength <= 0) {
            effectiveArgs = distrArgs;
        } else {
            effectiveArgs = new NDList(distrArgs.size());
            for (NDArray arg : distrArgs) {
                NDArray tail = arg.get(":, {}:", -trailingLength);
                // Slicing yields a new array, so the name is carried over explicitly.
                tail.setName(arg.getName());
                effectiveArgs.add(tail);
            }
        }
        return distrOutput.distributionBuilder().setDistrArgs(effectiveArgs).optScale(scale).build();
    }
}
