package ai.djl.nn.recurrent;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDArrays;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.LayoutType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterType;
import ai.djl.nn.recurrent.RecurrentCell;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:ai/djl/nn/recurrent/GRU.class */
/**
 * GRU recurrent block built on top of {@link RecurrentCell}.
 *
 * <p>The block registers twelve weight/bias parameters (input-to-hidden and hidden-to-hidden
 * weights and biases for the {@code r}, {@code z} and {@code n} components) plus one
 * {@code state} parameter, and delegates the actual computation to the engine's {@code rnn}
 * operator with mode {@code "gru"}.
 *
 * <p>The first input is expected to use the layout (TIME, BATCH, CHANNEL); this is enforced in
 * {@link #beforeInitialize(Shape[])}.
 */
public class GRU extends RecurrentCell {
    /** Layout the first input shape is validated against in {@link #beforeInitialize(Shape[])}. */
    private static final LayoutType[] EXPECTED_LAYOUT = {LayoutType.TIME, LayoutType.BATCH, LayoutType.CHANNEL};

    /** Encoding version written as the first byte of the saved parameters and checked on load. */
    private static final byte VERSION = 1;

    /**
     * Weight and bias parameters. The order is significant: it is the order in which they are
     * concatenated for the rnn operator and in which they are saved and loaded.
     */
    private final List<Parameter> parameters;

    /** The {@code state} parameter; passed to the rnn operator after the packed weights. */
    private final Parameter state;

    /** Builder for {@link GRU} blocks. */
    public static final class Builder extends RecurrentCell.BaseBuilder<Builder> {
        /**
         * Returns this builder.
         *
         * <p>NOTE(review): the decompiler reported that the access modifier of this method was
         * changed from {@code protected}; it is kept {@code public} here as found.
         *
         * @return this builder
         */
        @Override // ai.djl.nn.recurrent.RecurrentCell.BaseBuilder
        public Builder self() {
            return this;
        }

        /**
         * Builds a {@link GRU} block from this builder.
         *
         * @return the new block
         * @throws IllegalArgumentException if {@code stateSize} or {@code numStackedLayers} is
         *     still {@code -1} (treated as "not set")
         */
        public GRU build() {
            if (this.stateSize == -1 || this.numStackedLayers == -1) {
                throw new IllegalArgumentException("Must set stateSize and numStackedLayers");
            }
            return new GRU(this);
        }
    }

    /**
     * Creates the block, registers its parameters and sets the operator mode to {@code "gru"}.
     *
     * @param builder the builder holding the configuration
     */
    GRU(Builder builder) {
        super(builder);
        // Keep this order stable: opInputs, saveParameters and loadParameters all iterate it.
        this.parameters =
                Arrays.asList(
                        new Parameter("i2rWeight", this, ParameterType.WEIGHT),
                        new Parameter("i2rBias", this, ParameterType.BIAS),
                        new Parameter("h2rWeight", this, ParameterType.WEIGHT),
                        new Parameter("h2rBias", this, ParameterType.BIAS),
                        new Parameter("i2zWeight", this, ParameterType.WEIGHT),
                        new Parameter("i2zBias", this, ParameterType.BIAS),
                        new Parameter("h2zWeight", this, ParameterType.WEIGHT),
                        new Parameter("h2zBias", this, ParameterType.BIAS),
                        new Parameter("i2nWeight", this, ParameterType.WEIGHT),
                        new Parameter("i2nBias", this, ParameterType.BIAS),
                        new Parameter("h2nWeight", this, ParameterType.WEIGHT),
                        new Parameter("h2nBias", this, ParameterType.BIAS));
        this.state = new Parameter("state", this, ParameterType.OTHER);
        this.mode = "gru";
    }

    /**
     * Runs the rnn operator on the given inputs.
     *
     * @param parameterStore the store the parameter values are fetched from
     * @param inputs the input list; the head is the data array
     * @param params additional arguments forwarded unchanged to the rnn operator
     * @return the result of the rnn operator
     */
    @Override // ai.djl.nn.Block
    public NDList forward(ParameterStore parameterStore, NDList inputs, PairList<String, Object> params) {
        NDList opInputs = opInputs(parameterStore, inputs);
        return opInputs
                .head()
                .getNDArrayInternal()
                .rnn(
                        opInputs,
                        this.mode,
                        this.stateSize,
                        this.dropRate,
                        this.numStackedLayers,
                        this.useSequenceLength,
                        this.useBidirectional,
                        this.stateOutputs,
                        params);
    }

    /**
     * Computes the output shape from the first input shape.
     *
     * @param manager unused
     * @param inputShapes the input shapes; only the first is read
     * @return a single shape made of the first two dimensions of the input and {@code stateSize}
     */
    @Override // ai.djl.nn.Block
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
        Shape inputShape = inputShapes[0];
        return new Shape[] {new Shape(inputShape.get(0), inputShape.get(1), this.stateSize)};
    }

    /**
     * Returns the parameters owned by this block.
     *
     * @return a new list holding the weight/bias parameters followed by the state parameter
     */
    @Override // ai.djl.nn.Block
    public List<Parameter> getDirectParameters() {
        List<Parameter> directParameters = new ArrayList<>(this.parameters);
        directParameters.add(this.state);
        return directParameters;
    }

    /**
     * Records the input shapes and validates the layout of the first one.
     *
     * @param inputShapes the input shapes
     */
    @Override // ai.djl.nn.AbstractBlock
    public void beforeInitialize(Shape[] inputShapes) {
        this.inputShapes = inputShapes;
        Block.validateLayout(EXPECTED_LAYOUT, inputShapes[0].getLayout());
    }

    /**
     * Returns the shape of the named parameter for the given input shapes.
     *
     * @param name the parameter name
     * @param inputShapes the input shapes; only the first is read
     * @return the parameter shape
     * @throws IllegalArgumentException if {@code name} is not a parameter of this block
     */
    @Override // ai.djl.nn.Block
    public Shape getParameterShape(String name, Shape[] inputShapes) {
        Shape inputShape = inputShapes[0];
        // Read both dimensions up front, before the name is inspected.
        // Per EXPECTED_LAYOUT, dimension 2 is the channel axis and dimension 1 the batch axis.
        long channelSize = inputShape.get(2);
        long batchSize = inputShape.get(1);
        switch (name) {
            case "i2rWeight":
            case "i2zWeight":
            case "i2nWeight":
                return new Shape(this.stateSize, channelSize);
            case "h2rWeight":
            case "h2zWeight":
            case "h2nWeight":
                return new Shape(this.stateSize, this.stateSize);
            case "h2rBias":
            case "i2rBias":
            case "h2zBias":
            case "i2zBias":
            case "h2nBias":
            case "i2nBias":
                return new Shape(this.stateSize);
            case "state":
                return new Shape(this.numStackedLayers, batchSize, this.stateSize);
            default:
                throw new IllegalArgumentException("Invalid parameter name: " + name);
        }
    }

    /**
     * Writes the version byte, then every weight/bias parameter in order, then the state.
     *
     * @param os the stream to write to
     * @throws IOException if writing fails
     */
    @Override // ai.djl.nn.Block
    public void saveParameters(DataOutputStream os) throws IOException {
        os.writeByte(VERSION);
        for (Parameter parameter : this.parameters) {
            parameter.save(os);
        }
        this.state.save(os);
    }

    /**
     * Reads the parameters back in the order written by {@link #saveParameters(DataOutputStream)}.
     *
     * @param manager the manager passed to each parameter's {@code load}
     * @param is the stream to read from
     * @throws IOException if reading fails
     * @throws MalformedModelException if the leading version byte is not {@code VERSION}
     */
    @Override // ai.djl.nn.Block
    public void loadParameters(NDManager manager, DataInputStream is) throws IOException, MalformedModelException {
        byte version = is.readByte();
        if (version != VERSION) {
            throw new MalformedModelException("Unsupported encoding version: " + version);
        }
        for (Parameter parameter : this.parameters) {
            parameter.load(manager, is);
        }
        this.state.load(manager, is);
    }

    /**
     * Assembles the input list for the rnn operator: the data array, all weight/bias parameters
     * flattened and concatenated into one array, the state, and (only when
     * {@code useSequenceLength} is set) the second element of {@code inputs}.
     *
     * @param parameterStore the store the parameter values are fetched from
     * @param inputs the inputs given to {@code forward}
     * @return the operator input list
     */
    private NDList opInputs(ParameterStore parameterStore, NDList inputs) {
        validateInputSize(inputs);
        NDArray head = inputs.head();
        Device device = head.getDevice();
        NDList result = new NDList(head);
        // The temporary list of flattened parameters is closed once the concatenated array
        // has been added to the result, whether or not an exception is thrown.
        try (NDList flattenedParameters = new NDList()) {
            for (Parameter parameter : this.parameters) {
                NDArray flattened = parameterStore.getValue(parameter, device).flatten();
                flattened.setName(parameter.getName());
                flattenedParameters.add(flattened);
            }
            result.add(NDArrays.concat(flattenedParameters));
        }
        result.add(parameterStore.getValue(this.state, device));
        if (this.useSequenceLength) {
            // Index 1 of the inputs is forwarded as the extra operator input.
            result.add(inputs.get(1));
        }
        return result;
    }
}
