package ai.djl.nn.recurrent;

import ai.djl.ndarray.NDList;
import ai.djl.nn.ParameterBlock;
import ai.djl.nn.recurrent.RNN;

/* loaded from: input_file:ai/djl/nn/recurrent/RecurrentCell.class */
/**
 * Base class for recurrent blocks.
 *
 * <p>Holds the configuration values copied from a {@link BaseBuilder} and offers a helper that
 * checks how many arrays an input {@link NDList} contains.
 */
public abstract class RecurrentCell extends ParameterBlock {
    /** State size, copied from the builder. */
    protected long stateSize;
    /** Drop rate, copied from the builder. */
    protected float dropRate;
    /** Number of stacked layers, copied from the builder. */
    protected int numStackedLayers;
    /** Not assigned anywhere in this class; presumably set by subclasses (TODO confirm). */
    protected String mode;
    /** When {@code true}, {@link #validateInputSize(NDList)} expects two inputs instead of one. */
    protected boolean useSequenceLength;
    /** Bidirectional flag, copied from the builder. */
    protected boolean useBidirectional;
    /** State-output flag, copied from the builder. */
    protected boolean stateOutputs;

    /**
     * Fluent builder base for {@link RecurrentCell} implementations.
     *
     * <p>Every setter stores its argument and returns {@link #self()}, so calls can be chained on
     * the concrete builder type.
     *
     * <p>The LSTM clipping fields and {@code activation} are not read by the {@link RecurrentCell}
     * constructor; presumably concrete cells consume them (TODO confirm).
     *
     * @param <T> the concrete builder type returned by the setters
     */
    public abstract static class BaseBuilder<T extends BaseBuilder> {
        protected float dropRate;
        /** Stays {@code -1} until {@link #setStateSize(int)} is called. */
        protected long stateSize = -1;
        /** Stays {@code -1} until {@link #setNumStackedLayers(int)} is called. */
        protected int numStackedLayers = -1;
        protected double lstmStateClipMin;
        protected double lstmStateClipMax;
        /** Becomes {@code true} once {@link #optLstmStateClipMin(float, float)} is called. */
        protected boolean clipLstmState;
        protected boolean useSequenceLength;
        protected boolean useBidirectional;
        protected boolean stateOutputs;
        protected RNN.Activation activation;

        /**
         * Sets the drop rate.
         *
         * @param dropRate the drop rate to store
         * @return this builder
         */
        public T optDropRate(float dropRate) {
            this.dropRate = dropRate;
            return self();
        }

        /**
         * Sets both LSTM state clipping bounds and turns clipping on.
         *
         * <p>Despite the method name, the minimum and the maximum are both stored.
         *
         * @param clipMin the lower clipping bound
         * @param clipMax the upper clipping bound
         * @return this builder
         */
        public T optLstmStateClipMin(float clipMin, float clipMax) {
            clipLstmState = true;
            lstmStateClipMin = clipMin;
            lstmStateClipMax = clipMax;
            return self();
        }

        /**
         * Sets the state size.
         *
         * @param size the state size to store
         * @return this builder
         */
        public T setStateSize(int size) {
            stateSize = size;
            return self();
        }

        /**
         * Sets the number of stacked layers.
         *
         * @param layers the layer count to store
         * @return this builder
         */
        public T setNumStackedLayers(int layers) {
            numStackedLayers = layers;
            return self();
        }

        /**
         * Sets the activation.
         *
         * @param activation the activation to store
         * @return this builder
         */
        public T setActivation(RNN.Activation activation) {
            this.activation = activation;
            return self();
        }

        /**
         * Sets whether a sequence length input is used.
         *
         * @param enabled the value stored in {@code useSequenceLength}
         * @return this builder
         */
        public T setSequenceLength(boolean enabled) {
            useSequenceLength = enabled;
            return self();
        }

        /**
         * Sets the bidirectional flag. The method name is misspelled (sic) in the public API.
         *
         * @param enabled the value stored in {@code useBidirectional}
         * @return this builder
         */
        public T optBidrectional(boolean enabled) {
            useBidirectional = enabled;
            return self();
        }

        /**
         * Sets the state-output flag.
         *
         * @param enabled the value stored in {@code stateOutputs}
         * @return this builder
         */
        public T optStateOutput(boolean enabled) {
            stateOutputs = enabled;
            return self();
        }

        /**
         * Returns this builder typed as the concrete builder class.
         *
         * @return this builder
         */
        protected abstract T self();
    }

    /**
     * Creates a cell by copying the shared settings out of the given builder.
     *
     * @param builder the builder whose state size, drop rate, layer count and flags are copied
     */
    public RecurrentCell(BaseBuilder<?> builder) {
        stateSize = builder.stateSize;
        dropRate = builder.dropRate;
        numStackedLayers = builder.numStackedLayers;
        useSequenceLength = builder.useSequenceLength;
        useBidirectional = builder.useBidirectional;
        stateOutputs = builder.stateOutputs;
    }

    /**
     * Checks that the input list holds the expected number of entries: two when
     * {@code useSequenceLength} is set, otherwise one.
     *
     * <p>NOTE(review): the decompiler reported that this method's access modifier was changed
     * from {@code protected}; it is kept {@code public} here so existing callers keep compiling.
     *
     * @param inputs the input list to check
     * @throws IllegalArgumentException if the list size differs from the expected count
     */
    public void validateInputSize(NDList inputs) {
        int expected = useSequenceLength ? 2 : 1;
        if (inputs.size() == expected) {
            return;
        }
        String message =
                "Invalid number of inputs for RNN. Size of input NDList must be "
                        + expected
                        + " when useSequenceLength is "
                        + useSequenceLength;
        throw new IllegalArgumentException(message);
    }
}
