package ai.djl.nn.recurrent;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.LayoutType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterType;
import ai.djl.nn.recurrent.RNN;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.IOException;
import java.util.Iterator;

/**
 * Common base for recurrent blocks. It registers the per-layer, per-direction {@code i2h}/{@code
 * h2h} weight and bias parameters, converts the batch-major input to time-major layout, and
 * delegates the actual computation to the engine's {@code rnn} operator.
 *
 * <p>The {@link #mode} and {@link #gates} fields are not assigned anywhere in this class;
 * presumably concrete subclasses set them before the block is used (TODO confirm).
 *
 * <p>Decompiled from {@code ai/djl/nn/recurrent/RecurrentBlock.class}.
 */
public abstract class RecurrentBlock extends AbstractBlock {
    /** Metadata encoding version handed to the superclass and checked by {@link #loadMetadata}. */
    private static final byte VERSION = 2;

    /** Layout that {@link #beforeInitialize(Shape[])} requires of the first input shape. */
    private static final LayoutType[] EXPECTED_LAYOUT = {LayoutType.BATCH, LayoutType.TIME, LayoutType.CHANNEL};

    /** Size of the hidden state; used as the last dimension of state shapes. */
    protected long stateSize;

    /** Drop rate forwarded to the {@code rnn} operator. */
    protected float dropRate;

    /** Number of stacked recurrent layers; one parameter set is registered per layer. */
    protected int numStackedLayers;

    /** Operator mode string forwarded to {@code rnn}; not assigned in this class. */
    protected String mode;

    /** Whether a second input (the sequence lengths) is expected and forwarded to the operator. */
    protected boolean useSequenceLength;

    /** 1 for a unidirectional block, 2 for a bidirectional one. */
    protected int numDirections;

    /** Gate-count multiplier used when sizing parameters; not assigned in this class. */
    protected int gates;

    /** Whether {@link #forward} appends the operator's second output to its result. */
    protected boolean stateOutputs;

    /** Optional initial state for the next forward pass; cleared after every {@link #forward}. */
    protected NDArray beginState;

    /**
     * Fluent builder base holding the options shared by recurrent blocks.
     *
     * <p>Decompiled from {@code ai/djl/nn/recurrent/RecurrentBlock$BaseBuilder.class}.
     *
     * @param <T> the concrete builder type returned by the fluent setters
     */
    public abstract static class BaseBuilder<T extends BaseBuilder> {
        protected float dropRate;
        /** Defaults to -1 until {@link #setStateSize(int)} is called. */
        protected long stateSize = -1;
        /** Defaults to -1 until {@link #setNumStackedLayers(int)} is called. */
        protected int numStackedLayers = -1;
        // The three LSTM clip fields and `activation` are not read in this file;
        // presumably they are consumed by concrete builders/blocks (TODO confirm).
        protected double lstmStateClipMin;
        protected double lstmStateClipMax;
        protected boolean clipLstmState;
        protected boolean useSequenceLength;
        protected boolean useBidirectional;
        protected boolean stateOutputs;
        protected RNN.Activation activation;

        /**
         * Sets the drop rate.
         *
         * @param dropRate the drop rate to store
         * @return this builder
         */
        public T optDropRate(float dropRate) {
            this.dropRate = dropRate;
            return self();
        }

        /**
         * Sets the hidden state size.
         *
         * @param stateSize the state size to store
         * @return this builder
         */
        public T setStateSize(int stateSize) {
            this.stateSize = stateSize;
            return self();
        }

        /**
         * Sets the number of stacked layers.
         *
         * @param numStackedLayers the layer count to store
         * @return this builder
         */
        public T setNumStackedLayers(int numStackedLayers) {
            this.numStackedLayers = numStackedLayers;
            return self();
        }

        /**
         * Sets whether the block expects a sequence-length input in addition to the data.
         *
         * @param useSequenceLength {@code true} to require the extra input
         * @return this builder
         */
        public T setSequenceLength(boolean useSequenceLength) {
            this.useSequenceLength = useSequenceLength;
            return self();
        }

        /**
         * Sets whether the block is bidirectional. The misspelled method name is part of the
         * existing public interface and is kept for compatibility.
         *
         * @param useBidirectional {@code true} for a bidirectional block
         * @return this builder
         */
        public T optBidrectional(boolean useBidirectional) {
            this.useBidirectional = useBidirectional;
            return self();
        }

        /**
         * Sets whether the block also returns the operator's state output.
         *
         * @param stateOutputs {@code true} to append the state to the forward result
         * @return this builder
         */
        public T optStateOutput(boolean stateOutputs) {
            this.stateOutputs = stateOutputs;
            return self();
        }

        /**
         * Returns this builder typed as the concrete builder class.
         *
         * @return this builder
         */
        protected abstract T self();
    }

    /**
     * Copies the builder options and registers the block's parameters.
     *
     * <p>Parameters are named {@code <direction>_<layer>_<i2h|h2h>_<TYPE>} and are registered with
     * all weights first, then all biases; within each type the order is layer, direction, then
     * {@code i2h} before {@code h2h}. {@link #opInputs} concatenates them in registration order.
     *
     * @param baseBuilder the builder carrying the configuration
     */
    public RecurrentBlock(BaseBuilder<?> baseBuilder) {
        super(VERSION);
        this.stateSize = baseBuilder.stateSize;
        this.dropRate = baseBuilder.dropRate;
        this.numStackedLayers = baseBuilder.numStackedLayers;
        this.useSequenceLength = baseBuilder.useSequenceLength;
        this.stateOutputs = baseBuilder.stateOutputs;
        this.numDirections = baseBuilder.useBidirectional ? 2 : 1;

        ParameterType[] parameterTypes = {ParameterType.WEIGHT, ParameterType.BIAS};
        String[] directions = baseBuilder.useBidirectional ? new String[]{"l", "r"} : new String[]{"l"};
        String[] gateNames = {"i2h", "h2h"};
        for (ParameterType parameterType : parameterTypes) {
            for (int layer = 0; layer < this.numStackedLayers; layer++) {
                for (String direction : directions) {
                    for (String gateName : gateNames) {
                        String name = String.format("%s_%s_%s_%s", direction, layer, gateName, parameterType.name());
                        addParameter(new Parameter(name, this, parameterType));
                    }
                }
            }
        }
    }

    /**
     * Checks that the input list has one element, or two when {@link #useSequenceLength} is set.
     * (The decompiler reports that this method's visibility was widened from {@code protected}.)
     *
     * @param nDList the inputs passed to the block
     * @throws IllegalArgumentException if the list size does not match the expected count
     */
    public void validateInputSize(NDList nDList) {
        int expectedSize = this.useSequenceLength ? 2 : 1;
        if (nDList.size() != expectedSize) {
            throw new IllegalArgumentException("Invalid number of inputs for RNN. Size of input NDList must be " + expectedSize + " when useSequenceLength is " + this.useSequenceLength);
        }
    }

    /**
     * Sets whether {@link #forward} appends the state output to its result.
     *
     * @param z the new value of {@link #stateOutputs}
     */
    public final void setStateOutputs(boolean z) {
        this.stateOutputs = z;
    }

    /**
     * Runs the {@code rnn} operator on the prepared inputs and returns its first output with axes
     * 0 and 1 swapped back, followed by the operator's second output when {@link #stateOutputs}
     * is set. Any begin state is cleared afterwards, so it applies to a single call only.
     *
     * @param parameterStore the store used to resolve parameter values
     * @param nDList the block inputs
     * @param training not read by this implementation
     * @param pairList additional arguments forwarded to the operator
     * @return the output list
     */
    @Override // ai.djl.nn.Block
    public NDList forward(ParameterStore parameterStore, NDList nDList, boolean training, PairList<String, Object> pairList) {
        NDList operatorInputs = opInputs(parameterStore, nDList);
        // NOTE(review): the literal `true` is presumably the operator's own state-output flag,
        // which would explain why index 1 is always available below - confirm against rnn().
        NDList operatorOutputs = operatorInputs.head().getNDArrayInternal().rnn(operatorInputs, this.mode, this.stateSize, this.dropRate, this.numStackedLayers, this.useSequenceLength, isBidirectional(), true, pairList);
        NDList result = new NDList(operatorOutputs.head().transpose(1, 0, 2));
        if (this.stateOutputs) {
            result.add(operatorOutputs.get(1));
        }
        resetBeginStates();
        return result;
    }

    /**
     * Stores the first element of the given list as the initial state for the next forward pass.
     *
     * @param nDList a list whose first element is the begin state
     */
    public void setBeginStates(NDList nDList) {
        this.beginState = nDList.get(0);
    }

    /** Clears the stored begin state so the next forward pass starts from zeros. */
    protected void resetBeginStates() {
        this.beginState = null;
    }

    /**
     * Computes the output shapes from the first input shape: the main output swaps dimensions 0
     * and 1 of the input and has {@code stateSize * numDirections} as its last dimension. When
     * {@link #stateOutputs} is set, a second shape for the state is appended.
     *
     * @param nDManager not read by this implementation
     * @param shapeArr the input shapes; only index 0 is read
     * @return the output shapes
     */
    @Override // ai.djl.nn.Block
    public Shape[] getOutputShapes(NDManager nDManager, Shape[] shapeArr) {
        Shape inputShape = shapeArr[0];
        Shape outputShape = new Shape(inputShape.get(1), inputShape.get(0), this.stateSize * this.numDirections);
        if (!this.stateOutputs) {
            return new Shape[]{outputShape};
        }
        Shape stateShape = new Shape(this.numStackedLayers * this.numDirections, inputShape.get(1), this.stateSize);
        return new Shape[]{outputShape, stateShape};
    }

    /**
     * Records the input shapes, validates the layout of the first one, and then replaces it with
     * a layout-less shape whose first two dimensions are swapped.
     *
     * <p>The replacement is written into the caller's array, and {@code inputShapes} refers to
     * that same array, so both observe the swapped shape.
     *
     * @param shapeArr the input shapes; index 0 is overwritten
     */
    @Override // ai.djl.nn.AbstractBlock
    public void beforeInitialize(Shape[] shapeArr) {
        this.inputShapes = shapeArr;
        Shape inputShape = shapeArr[0];
        Block.validateLayout(EXPECTED_LAYOUT, inputShape.getLayout());
        shapeArr[0] = new Shape(inputShape.get(1), inputShape.get(0), inputShape.get(2));
    }

    /**
     * Returns the shape of the named parameter. The layer index is parsed from the second
     * underscore-separated token of the name.
     *
     * @param str the parameter name, in the format produced by the constructor
     * @param shapeArr the input shapes; dimension 2 of index 0 is the first layer's input size
     * @return the parameter shape
     * @throws IllegalArgumentException if the name matches none of BIAS, i2h or h2h
     */
    @Override // ai.djl.nn.AbstractBlock, ai.djl.nn.Block
    public Shape getParameterShape(String str, Shape[] shapeArr) {
        int layer = Integer.parseInt(str.split("_")[1]);
        // Layers above the first consume the previous layer's output instead of the raw input.
        long inputSize = layer > 0 ? this.stateSize * this.numDirections : shapeArr[0].get(2);
        long gatedStateSize = this.gates * this.stateSize;
        if (str.contains("BIAS")) {
            return new Shape(gatedStateSize);
        }
        if (str.contains("i2h")) {
            return new Shape(gatedStateSize, inputSize);
        }
        if (str.contains("h2h")) {
            return new Shape(gatedStateSize, this.stateSize);
        }
        throw new IllegalArgumentException("Invalid parameter name");
    }

    /**
     * Loads block metadata. Version 2 reads the input shapes from the stream, version 1 reads
     * nothing, and any other version is rejected.
     *
     * @param b the encoding version found in the stream
     * @param dataInputStream the stream positioned at the block's metadata
     * @throws IOException if reading the input shapes fails
     * @throws MalformedModelException if the version is neither 1 nor 2
     */
    @Override // ai.djl.nn.AbstractBlock
    public void loadMetadata(byte b, DataInputStream dataInputStream) throws IOException, MalformedModelException {
        if (b == VERSION) {
            readInputShapes(dataInputStream);
        } else if (b != 1) {
            throw new MalformedModelException("Unsupported encoding version: " + ((int) b));
        }
    }

    /**
     * Reports whether this block was configured with two directions.
     * (The decompiler reports that this method's visibility was widened from {@code protected}.)
     *
     * @return {@code true} if {@link #numDirections} is 2
     */
    public boolean isBidirectional() {
        return this.numDirections == 2;
    }

    /**
     * Builds the operator input list: the layout-converted data, all parameters flattened and
     * concatenated in registration order, the begin state (zeros when none was set), and the
     * sequence lengths when {@link #useSequenceLength} is set.
     *
     * @param parameterStore the store used to resolve parameter values on the data's device
     * @param nDList the block inputs
     * @return the inputs for the {@code rnn} operator
     * @throws IllegalArgumentException if the number of inputs is wrong
     */
    protected NDList opInputs(ParameterStore parameterStore, NDList nDList) {
        validateInputSize(nDList);
        // Read dimension 0 before the layout conversion swaps the first two axes.
        long batchSize = nDList.head().getShape().get(0);
        NDList convertedInputs = updateInputLayoutToTNC(nDList);
        NDArray data = convertedInputs.head();
        Device device = data.getDevice();

        NDList operatorInputs = new NDList(data);
        // The flattened copies are only needed until they are concatenated, so the resource scope
        // ends there; later failures must not close the list a second time.
        try (NDList flattenedParameters = new NDList()) {
            for (Parameter parameter : this.parameters.values()) {
                flattenedParameters.add(parameterStore.getValue(parameter, device).flatten());
            }
            operatorInputs.add(NDArrays.concat(flattenedParameters));
        }

        if (this.beginState != null) {
            operatorInputs.add(this.beginState);
        } else {
            Shape stateShape = new Shape(this.numStackedLayers * this.numDirections, batchSize, this.stateSize);
            operatorInputs.add(convertedInputs.head().getManager().zeros(stateShape));
        }
        if (this.useSequenceLength) {
            operatorInputs.add(convertedInputs.get(1));
        }
        return operatorInputs;
    }

    /**
     * Returns a new list whose first element is the data input with axes 0 and 1 swapped. When
     * {@link #useSequenceLength} is set, the second input is carried over unchanged so that
     * {@link #opInputs} can forward it; otherwise the input must be a singleton.
     * (The decompiler reports that this method's visibility was widened from {@code protected}.)
     *
     * @param nDList the block inputs
     * @return the converted inputs
     */
    public NDList updateInputLayoutToTNC(NDList nDList) {
        if (!this.useSequenceLength) {
            return new NDList(nDList.singletonOrThrow().transpose(1, 0, 2));
        }
        // With sequence lengths there are two inputs, so singletonOrThrow() would always fail.
        return new NDList(nDList.get(0).transpose(1, 0, 2), nDList.get(1));
    }
}
