package ai.djl.timeseries.model.deepar;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.Shape;
import ai.djl.timeseries.model.deepar.DeepARNetwork;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;

/* loaded from: input_file:ai/djl/timeseries/model/deepar/DeepARTrainingNetwork.class */
/**
 * {@link DeepARNetwork} variant whose forward pass returns the distribution arguments produced by
 * {@code unrollLaggedRnn}, followed by the {@code "scale"} array and a {@code "loss_weights"}
 * array.
 *
 * <p>NOTE(review): the local variable names used below follow the customary DeepAR input ordering
 * (static categorical, static real, past time features, past target, past observed values, [slot
 * 5 unused], future time features, future target, future observed values). Only the indices are
 * established by this class; confirm the meaning of each slot against the code that assembles
 * the input list.
 */
public final class DeepARTrainingNetwork extends DeepARNetwork {

    /**
     * Creates the network from a builder; everything is delegated to the superclass constructor.
     *
     * <p>Decompiler note (JADX): the access modifier was changed from package-private.
     *
     * @param builder the builder handed to {@link DeepARNetwork}
     */
    public DeepARTrainingNetwork(DeepARNetwork.Builder builder) {
        super(builder);
    }

    /**
     * Runs the unrolled network and packs its outputs.
     *
     * <p>The returned list contains, in order: one array per name returned by {@code
     * distrOutput.getArgsArray()} (looked up by name in the unroll output), the unroll output
     * named {@code "scale"}, and an array named {@code "loss_weights"}.
     *
     * @param parameterStore forwarded unchanged to {@code unrollLaggedRnn}
     * @param nDList the inputs; indices 0-4 and 6-8 are read, index 5 is not read here
     * @param z boolean flag forwarded unchanged to {@code unrollLaggedRnn} (presumably the
     *     training flag of the DJL block API; confirm)
     * @param pairList not used by this method
     * @return the distribution arguments, then the scale, then the loss weights
     */
    protected NDList forwardInternal(ParameterStore parameterStore, NDList nDList, boolean z, PairList<String, Object> pairList) {
        NDArray featStaticCat = nDList.get(0);
        NDArray featStaticReal = nDList.get(1);
        NDArray pastTimeFeat = nDList.get(2);
        NDArray pastTarget = nDList.get(3);
        NDArray pastObservedValues = nDList.get(4);
        // Index 5 is skipped: this method never reads it.
        NDArray futureTimeFeat = nDList.get(6);
        NDArray futureTarget = nDList.get(7);
        NDArray futureObservedValues = nDList.get(8);

        // The unroll receives everything read above except the array at index 8.
        NDList unrollInputs =
                new NDList(
                        featStaticCat,
                        featStaticReal,
                        pastTimeFeat,
                        pastTarget,
                        pastObservedValues,
                        futureTimeFeat,
                        futureTarget);
        NDList unrollOutput = unrollLaggedRnn(parameterStore, unrollInputs, z);

        // Keep the trailing (contextLength - 1) entries along axis 1 of the array at index 4,
        // then append the array at index 8 along that same axis.
        NDArray pastTail = pastObservedValues.get(":, {}:", -contextLength + 1);
        NDArray lossWeights = pastTail.concat(futureObservedValues, 1);
        lossWeights.setName("loss_weights");

        String[] argNames = distrOutput.getArgsArray();
        // Capacity: one slot per distribution argument, plus scale, plus loss weights.
        NDList result = new NDList(argNames.length + 2);
        for (String argName : argNames) {
            result.add(unrollOutput.get(argName));
        }
        result.add(unrollOutput.get("scale"));
        result.add(lossWeights);
        return result;
    }

    /**
     * Computes the output shapes matching the layout produced by {@code forwardInternal}: the
     * shapes reported by {@code paramProj}, then the second output shape of {@code scaler}, then
     * the shape of input 8 with dimension 1 enlarged by {@code contextLength - 1}.
     *
     * @param shapeArr the input shapes; indices 0, 1, 2, 3 and 8 are read
     * @return the projected argument shapes followed by the scale shape and the loss-weights shape
     */
    public Shape[] getOutputShapes(Shape[] shapeArr) {
        // Dimensions of input 3 from index 2 onwards.
        Shape targetShape = shapeArr[3].slice(2);

        Shape contextShape = new Shape(1, contextLength).addAll(targetShape);
        Shape scaleShape = scaler.getOutputShapes(new Shape[] {contextShape, contextShape})[1];
        long scaleSize = scaleShape.get(1);

        long embeddedSize = embedder.getOutputShapes(new Shape[] {shapeArr[0]})[0].get(1);

        // (1, 2 * contextLength - 1, <target dims>, number of lags)
        Shape lagsShape =
                new Shape(1, (contextLength * 2) - 1).addAll(targetShape).add(lagsSeq.size());

        // The RNN input keeps every dimension but the last; the last one becomes the sum of the
        // lag count and the sizes read from inputs 2 and 1, the embedder and the scaler.
        Shape leadingDims = lagsShape.slice(0, lagsShape.dimension() - 1);
        long lastDim =
                lagsShape.tail() + shapeArr[2].get(2) + embeddedSize + shapeArr[1].get(1) + scaleSize;
        Shape rnnInputShape = leadingDims.add(lastDim);

        Shape rnnOutputShape = rnn.getOutputShapes(new Shape[] {rnnInputShape})[0];
        Shape[] argShapes = paramProj.getOutputShapes(new Shape[] {rnnOutputShape});

        // Copy the dimensions of input 8 so the caller's shape is not modified, then grow axis 1.
        long[] lossWeightDims = new long[shapeArr[8].dimension()];
        System.arraycopy(shapeArr[8].getShape(), 0, lossWeightDims, 0, lossWeightDims.length);
        lossWeightDims[1] += contextLength - 1;
        Shape lossWeightsShape = new Shape(lossWeightDims);

        int argCount = argShapes.length;
        Shape[] result = new Shape[argCount + 2];
        System.arraycopy(argShapes, 0, result, 0, argCount);
        result[argCount] = scaleShape;
        result[argCount + 1] = lossWeightsShape;
        return result;
    }
}
