package ai.djl.basicmodelzoo.cv.classification;

import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Activation;
import ai.djl.nn.Block;
import ai.djl.nn.SequentialBlock;
import ai.djl.nn.convolutional.Conv2d;
import ai.djl.nn.core.Linear;
import ai.djl.nn.norm.BatchNorm;
import ai.djl.nn.pooling.Pool;

/* loaded from: input_file:ai/djl/basicmodelzoo/cv/classification/MobileNetV1.class */
/**
 * Factory for the MobileNetV1 image-classification network, assembled from DJL blocks.
 *
 * <p>The network is a strided 3x3 convolution stem, followed by a fixed sequence of depthwise
 * separable convolution blocks, a global average pool and a final {@link Linear} layer. The
 * channel count of every convolution is a base width from {@link #FILTERS} multiplied by the
 * builder's width multiplier and truncated to an {@code int}.
 *
 * <p>Use {@link #builder()} to configure and build the network.
 */
public final class MobileNetV1 {
    /** Base (unscaled) channel widths; layers refer to them by index. */
    static final int[] FILTERS = {32, 64, 128, 128, 256, 256, 512, 512, 1024, 1024};

    /** Epsilon used by every batch-normalization layer in the network. */
    private static final float BATCH_NORM_EPSILON = 2.0E-5f;

    /**
     * Depthwise separable layers that follow the stem, in order. Each row is {@code {input
     * FILTERS index, output FILTERS index, stride, repeat count}}.
     */
    private static final int[][] SEPARABLE_LAYERS = {
        {0, 1, 1, 1},
        {1, 2, 2, 1},
        {2, 3, 1, 1},
        {3, 4, 2, 1},
        {4, 5, 1, 1},
        {5, 6, 2, 1},
        {6, 7, 1, 5},
        {7, 8, 2, 1},
        {8, 9, 1, 1},
    };

    /** Collects the configurable options of the network and builds it. */
    public static final class Builder {
        /** Momentum handed to every batch-normalization layer; defaults to {@code 0.9}. */
        float batchNormMomentum = 0.9f;

        /** Factor applied to every base width in {@link MobileNetV1#FILTERS}; defaults to {@code 1}. */
        float widthMultiplier = 1.0f;

        /** Number of units of the final linear layer; defaults to {@code 10}. */
        long outSize = 10;

        Builder() {
        }

        /**
         * Sets the factor by which every base channel width is scaled.
         *
         * <p>The value is not checked here; {@link #build()} fails if it scales any layer below
         * one channel.
         *
         * @param f the width multiplier
         * @return this builder
         */
        public Builder optWidthMultiplier(float f) {
            this.widthMultiplier = f;
            return this;
        }

        /**
         * Sets the momentum used by every batch-normalization layer.
         *
         * @param f the batch-normalization momentum
         * @return this builder
         */
        public Builder optBatchNormMomentum(float f) {
            this.batchNormMomentum = f;
            return this;
        }

        /**
         * Sets the number of units of the final linear layer.
         *
         * @param j the output size
         * @return this builder
         */
        public Builder setOutSize(long j) {
            this.outSize = j;
            return this;
        }

        /**
         * Builds the network from the current settings.
         *
         * @return the assembled network block
         * @throws IllegalArgumentException if the width multiplier scales a layer below one channel
         */
        public Block build() {
            return MobileNetV1.mobilenet(this);
        }
    }

    private MobileNetV1() {
    }

    /**
     * Builds one depthwise separable convolution block.
     *
     * <p>The returned block holds, in order: a 3x3 convolution with padding 1 whose group count
     * and filter count both equal {@code i}, batch normalization, ReLU, and a nested sequential
     * block made of a 1x1 convolution with {@code i2} filters, batch normalization and ReLU.
     * Neither convolution uses a bias.
     *
     * @param i number of input channels; used as both groups and filters of the 3x3 convolution
     * @param i2 number of filters of the 1x1 convolution
     * @param i3 stride of the 3x3 convolution, applied to both spatial dimensions
     * @param builder supplies the batch-normalization momentum
     * @return the depthwise separable convolution block
     */
    public static Block depthSeparableConv2d(int i, int i2, int i3, Builder builder) {
        SequentialBlock pointWise = new SequentialBlock();
        pointWise
                .add(Conv2d.builder()
                        .setKernelShape(new Shape(1, 1))
                        .setFilters(i2)
                        .optBias(false)
                        .build())
                .add(batchNorm(builder))
                .add(Activation.reluBlock());

        SequentialBlock depthWise = new SequentialBlock();
        depthWise
                .add(Conv2d.builder()
                        .setKernelShape(new Shape(3, 3))
                        .optBias(false)
                        .optPadding(new Shape(1, 1))
                        .optStride(new Shape(i3, i3))
                        .optGroups(i)
                        .setFilters(i)
                        .build())
                .add(batchNorm(builder))
                .add(Activation.reluBlock());

        // The pointwise stage is nested as the last child of the depthwise block, not a sibling.
        return depthWise.add(pointWise);
    }

    /**
     * Assembles the complete network described by {@code builder}.
     *
     * @param builder the configured options
     * @return the network block
     * @throws IllegalArgumentException if the width multiplier scales a layer below one channel
     */
    public static Block mobilenet(Builder builder) {
        SequentialBlock stem = new SequentialBlock()
                .add(Conv2d.builder()
                        .setKernelShape(new Shape(3, 3))
                        .optBias(false)
                        .optStride(new Shape(2, 2))
                        .optPadding(new Shape(1, 1))
                        .setFilters(scaledFilters(0, builder))
                        .build())
                .add(batchNorm(builder))
                .add(Activation.reluBlock());

        SequentialBlock network = new SequentialBlock();
        network.add(stem);

        for (int[] layer : SEPARABLE_LAYERS) {
            int inputChannels = scaledFilters(layer[0], builder);
            int outputChannels = scaledFilters(layer[1], builder);
            int stride = layer[2];
            int repeat = layer[3];
            for (int n = 0; n < repeat; n++) {
                network.add(depthSeparableConv2d(inputChannels, outputChannels, stride, builder));
            }
        }

        network.add(Pool.globalAvgPool2dBlock())
                .add(Linear.builder().setUnits(builder.outSize).build());
        return network;
    }

    /**
     * Creates a new builder with default settings.
     *
     * @return a new {@link Builder}
     */
    public static Builder builder() {
        return new Builder();
    }

    /** Returns {@code FILTERS[index]} scaled by the width multiplier and truncated to an int. */
    private static int scaledFilters(int index, Builder builder) {
        int scaled = (int) (FILTERS[index] * builder.widthMultiplier);
        if (scaled < 1) {
            // Truncation turns small, zero, negative or NaN multipliers into a channel count
            // below one; report the multiplier instead of passing that count on.
            throw new IllegalArgumentException(
                    "widthMultiplier " + builder.widthMultiplier + " scales the base width "
                            + FILTERS[index] + " to " + scaled
                            + " channels; every layer needs at least 1");
        }
        return scaled;
    }

    /** Creates a batch-normalization block with the shared epsilon and the builder's momentum. */
    private static Block batchNorm(Builder builder) {
        return BatchNorm.builder()
                .optEpsilon(BATCH_NORM_EPSILON)
                .optMomentum(builder.batchNormMomentum)
                .build();
    }
}
