package ai.djl.nn;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.LayoutType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Parameter;
import ai.djl.training.ParameterStore;
import ai.djl.training.initializer.Initializer;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.function.Predicate;

/* loaded from: input_file:ai/djl/nn/Block.class */
public interface Block {

    /**
     * Runs a forward pass without any extra parameters.
     *
     * <p>Delegates to {@link #forward(ParameterStore, NDList, boolean, PairList)} with
     * {@code params} set to {@code null}.
     *
     * @param parameterStore the parameter store handed through to the implementation
     * @param inputs the input arrays
     * @param training mode flag handed through unchanged (the labeled overload passes {@code true})
     * @return the output arrays produced by the implementation
     */
    default NDList forward(ParameterStore parameterStore, NDList inputs, boolean training) {
        return forward(parameterStore, inputs, training, null);
    }

    /**
     * Runs a forward pass. This is the method every implementation must provide; the other
     * {@code forward} overloads delegate to it.
     *
     * @param parameterStore the parameter store used by the implementation
     * @param inputs the input arrays
     * @param training mode flag; the labeled overload passes {@code true}
     * @param params optional extra name/value arguments; may be {@code null}
     * @return the output arrays
     */
    NDList forward(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params);

    /**
     * Runs a forward pass for a call site that also has labels available.
     *
     * <p>The default implementation ignores {@code labels} entirely and delegates to {@link
     * #forward(ParameterStore, NDList, boolean, PairList)} with {@code training} set to {@code
     * true}. Implementations that need the labels must override this method.
     *
     * @param parameterStore the parameter store handed through to the implementation
     * @param inputs the input arrays
     * @param labels label arrays; unused by this default implementation
     * @param params optional extra name/value arguments; may be {@code null}
     * @return the output arrays
     */
    default NDList forward(
            ParameterStore parameterStore,
            NDList inputs,
            NDList labels,
            PairList<String, Object> params) {
        return forward(parameterStore, inputs, true, params);
    }

    /**
     * Sets the {@link Initializer} for parameters selected by their {@link Parameter.Type}.
     *
     * @param initializer the initializer to assign
     * @param type the parameter type used to select parameters
     */
    void setInitializer(Initializer initializer, Parameter.Type type);

    /**
     * Sets the {@link Initializer} for the parameter selected by name.
     *
     * @param initializer the initializer to assign
     * @param paramName the name used to select the parameter
     */
    void setInitializer(Initializer initializer, String paramName);

    /**
     * Sets the {@link Initializer} for every parameter accepted by {@code predicate}.
     *
     * @param initializer the initializer to assign
     * @param predicate selects the parameters to assign the initializer to
     */
    void setInitializer(Initializer initializer, Predicate<Parameter> predicate);

    /**
     * Initializes this block.
     *
     * @param manager the manager the implementation uses for any arrays it allocates
     * @param dataType the data type to initialize with
     * @param inputShapes the shapes of the inputs this block will receive
     */
    void initialize(NDManager manager, DataType dataType, Shape... inputShapes);

    /**
     * Reports whether this block has been initialized.
     *
     * @return {@code true} if initialized
     */
    boolean isInitialized();

    /**
     * Casts this block to the given data type.
     *
     * @param dataType the target data type
     */
    void cast(DataType dataType);

    /** Clears state held by this block; exactly what is cleared is implementation-defined. */
    void clear();

    /**
     * Describes the inputs this block expects.
     *
     * @return pairs of input name and input shape
     */
    PairList<String, Shape> describeInput();

    /**
     * Returns the child blocks of this block.
     *
     * @return the children
     */
    BlockList getChildren();

    /**
     * Returns the parameters declared directly on this block (NOTE(review): presumably
     * excluding children's parameters — confirm against implementations).
     *
     * @return the direct parameters
     */
    ParameterList getDirectParameters();

    /**
     * Returns the parameters of this block (NOTE(review): presumably including children's
     * parameters — confirm against implementations).
     *
     * @return the parameters
     */
    ParameterList getParameters();

    /**
     * Computes the output shapes this block produces for the given input shapes.
     *
     * @param inputShapes the input shapes
     * @return the output shapes
     */
    Shape[] getOutputShapes(Shape[] inputShapes);

    /**
     * Writes this block's parameters to a stream.
     *
     * @param os the stream to write to
     * @throws IOException if writing fails
     */
    void saveParameters(DataOutputStream os) throws IOException;

    /**
     * Reads this block's parameters from a stream.
     *
     * @param manager the manager the implementation uses for arrays it creates while loading
     * @param is the stream to read from
     * @throws IOException if reading fails
     * @throws MalformedModelException if the stream content is not a valid model
     */
    void loadParameters(NDManager manager, DataInputStream is)
            throws IOException, MalformedModelException;

    /**
     * Checks that an actual layout is compatible with an expected one.
     *
     * <p>The layouts are compatible when both arrays have the same length and, at every
     * position, the actual entry is either {@link LayoutType#UNKNOWN} (treated as a wildcard) or
     * identical to the expected entry.
     *
     * @param expectedLayout the required layout
     * @param actualLayout the layout to validate
     * @throws UnsupportedOperationException if the layouts are not compatible; the message shows
     *     both layouts
     */
    static void validateLayout(LayoutType[] expectedLayout, LayoutType[] actualLayout) {
        // Length is checked first; the loop short-circuits on the first mismatching position.
        boolean matches = actualLayout.length == expectedLayout.length;
        for (int i = 0; matches && i < actualLayout.length; i++) {
            LayoutType actual = actualLayout[i];
            matches = actual == LayoutType.UNKNOWN || actual == expectedLayout[i];
        }
        if (!matches) {
            throw new UnsupportedOperationException(
                    "Expected layout: "
                            + LayoutType.toString(expectedLayout)
                            + ", but got: "
                            + LayoutType.toString(actualLayout));
        }
    }
}
