package ai.djl.nn;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.function.Function;
import java.util.regex.Pattern;

/* loaded from: input_file:ai/djl/nn/SequentialBlock.class */
public class SequentialBlock extends AbstractBlock {
    private static final byte VERSION = 2;

    public SequentialBlock() {
        super((byte) 2);
    }

    public SequentialBlock addAll(Block... blockArr) {
        addAll(Arrays.asList(blockArr));
        return this;
    }

    public SequentialBlock addAll(Collection<Block> collection) {
        collection.forEach(this::add);
        return this;
    }

    public SequentialBlock add(Block block) {
        if (block != null) {
            addChildBlock(block.getClass().getSimpleName(), block);
        }
        return this;
    }

    public SequentialBlock add(Function<NDList, NDList> function) {
        add(new LambdaBlock(function));
        return this;
    }

    public SequentialBlock addSingleton(Function<NDArray, NDArray> function) {
        add(LambdaBlock.singleton(function));
        return this;
    }

    public void removeLastBlock() {
        this.children.remove(this.children.size() - 1);
    }

    public void replaceLastBlock(Block block) {
        removeLastBlock();
        if (block != null) {
            add(block);
        }
    }

    @Override // ai.djl.nn.AbstractBlock
    protected NDList forwardInternal(ParameterStore parameterStore, NDList nDList, boolean z, PairList<String, Object> pairList) {
        NDList nDList2 = nDList;
        Iterator<Block> it = this.children.values().iterator();
        while (it.hasNext()) {
            nDList2 = it.next().forward(parameterStore, nDList2, z);
        }
        return nDList2;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    @Override // ai.djl.nn.AbstractBlock
    public NDList forwardInternal(ParameterStore parameterStore, NDList nDList, NDList nDList2, PairList<String, Object> pairList) {
        NDList nDList3 = nDList;
        Iterator<Block> it = this.children.values().iterator();
        while (it.hasNext()) {
            nDList3 = it.next().forward(parameterStore, nDList3, nDList2, pairList);
        }
        return nDList3;
    }

    @Override // ai.djl.nn.AbstractBlock
    public void initializeChildBlocks(NDManager nDManager, DataType dataType, Shape... shapeArr) {
        Shape[] shapeArr2 = shapeArr;
        for (Block block : getChildren().values()) {
            block.initialize(nDManager, dataType, shapeArr2);
            shapeArr2 = block.getOutputShapes(shapeArr2);
        }
    }

    @Override // ai.djl.nn.Block
    public Shape[] getOutputShapes(Shape[] shapeArr) {
        if (this.children.isEmpty()) {
            throw new IllegalArgumentException("The sequential block is empty");
        }
        Shape[] shapeArr2 = shapeArr;
        Iterator<Block> it = this.children.values().iterator();
        while (it.hasNext()) {
            shapeArr2 = it.next().getOutputShapes(shapeArr2);
        }
        return shapeArr2;
    }

    @Override // ai.djl.nn.AbstractBlock
    public void loadMetadata(byte b, DataInputStream dataInputStream) throws IOException, MalformedModelException {
        if (b == VERSION) {
            readInputShapes(dataInputStream);
        } else if (b != 1) {
            throw new MalformedModelException("Unsupported encoding version: " + ((int) b));
        }
    }

    @Override // ai.djl.nn.AbstractBlock
    public String toString() {
        StringBuilder sb = new StringBuilder(200);
        sb.append("Sequential(\n");
        Iterator<Block> it = this.children.values().iterator();
        while (it.hasNext()) {
            sb.append(it.next().toString().replaceAll("(?m)^", "\t")).append('\n');
        }
        sb.append(')');
        return sb.toString();
    }
}
