package ai.djl.nn.norm;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterBlock;
import ai.djl.nn.ParameterType;
import ai.djl.training.ParameterStore;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;

/* loaded from: input_file:ai/djl/nn/norm/BatchNorm.class */
/**
 * Batch-normalization block.
 *
 * <p>The block owns four parameters named {@code gamma}, {@code beta}, {@code runningMean} and
 * {@code runningVar}. Each of them is one-dimensional with a length equal to the size of the
 * first input shape along the configured axis (see {@link #beforeInitialize(Shape[])} and
 * {@link #getParameterShape(String, Shape[])}). The numerical work is delegated to the
 * {@code batchNorm} operation exposed by the input array's internal engine API.
 *
 * <p>Instances are created through {@link Builder}.
 */
public class BatchNorm extends ParameterBlock {
    /** Version tag written as the first byte of the serialized parameters. */
    private static final byte VERSION = 1;

    /** Number of arrays {@link #forward} accepts in its input list. */
    private static final int EXPECTED_INPUTS = 1;

    private final int axis;
    private final float epsilon;
    private final float momentum;
    /** Size of the first input shape along {@link #axis}; set on initialization or on load. */
    private long inChannels;
    private final boolean center;
    private final boolean scale;
    private final Parameter gamma;
    private final Parameter beta;
    private final Parameter runningMean =
            new Parameter("runningMean", this, ParameterType.RUNNING_MEAN, false);
    private final Parameter runningVar =
            new Parameter("runningVar", this, ParameterType.RUNNING_VAR, false);

    /**
     * Fluent builder for {@link BatchNorm}.
     *
     * <p>Defaults: axis {@code 1}, epsilon {@code 1e-5}, momentum {@code 0.9}, center
     * {@code true}, scale {@code true}.
     */
    public static final class Builder {
        // Literal 1 on purpose: the default axis is unrelated to the serialization VERSION.
        private int axis = 1;
        private float epsilon = 1.0E-5f;
        private float momentum = 0.9f;
        private boolean center = true;
        private boolean scale = true;

        /**
         * Sets the axis whose size determines the length of every parameter.
         *
         * @param i the axis index
         * @return this builder
         */
        public Builder optAxis(int i) {
            this.axis = i;
            return this;
        }

        /**
         * Sets the flag passed to the {@code beta} parameter on construction.
         *
         * @param z the center flag
         * @return this builder
         */
        public Builder optCenter(boolean z) {
            this.center = z;
            return this;
        }

        /**
         * Sets the flag passed to the {@code gamma} parameter on construction.
         *
         * @param z the scale flag
         * @return this builder
         */
        public Builder optScale(boolean z) {
            this.scale = z;
            return this;
        }

        /**
         * Sets the epsilon value forwarded to the batch-norm operation.
         *
         * @param f the epsilon value
         * @return this builder
         */
        public Builder optEpsilon(float f) {
            this.epsilon = f;
            return this;
        }

        /**
         * Sets the momentum value forwarded to the batch-norm operation.
         *
         * @param f the momentum value
         * @return this builder
         */
        public Builder optMomentum(float f) {
            this.momentum = f;
            return this;
        }

        /**
         * Creates the block from the current builder state.
         *
         * @return a new {@link BatchNorm}
         */
        public BatchNorm build() {
            return new BatchNorm(this);
        }
    }

    /**
     * Copies the configuration out of the builder and creates the {@code gamma} and {@code beta}
     * parameters.
     *
     * @param builder the configured builder
     */
    BatchNorm(Builder builder) {
        this.axis = builder.axis;
        this.epsilon = builder.epsilon;
        this.momentum = builder.momentum;
        this.center = builder.center;
        this.scale = builder.scale;
        // NOTE(review): the trailing boolean is presumably the "requires gradient" switch of
        // Parameter, i.e. gamma is trained only when scaling and beta only when centering -
        // confirm against the Parameter constructor.
        this.gamma = new Parameter("gamma", this, ParameterType.GAMMA, this.scale);
        this.beta = new Parameter("beta", this, ParameterType.BETA, this.center);
    }

    /**
     * Runs the batch-norm operation on a single input array.
     *
     * @param parameterStore the store the parameter values are fetched from
     * @param nDList the inputs; must contain exactly one array
     * @param pairList extra arguments handed through to the underlying operation unchanged
     * @return the result of the underlying {@code batchNorm} call
     * @throws IllegalArgumentException if {@code nDList} does not contain exactly one array
     */
    @Override // ai.djl.nn.Block
    public NDList forward(ParameterStore parameterStore, NDList nDList, PairList<String, Object> pairList) {
        NDList opInputs = opInputs(parameterStore, nDList);
        return opInputs.head()
                .getNDArrayInternal()
                .batchNorm(opInputs, this.epsilon, this.momentum, this.axis, pairList);
    }

    /**
     * Returns the output shapes, which are the first input shape unchanged.
     *
     * @param nDManager unused
     * @param shapeArr the input shapes; only the first entry is read
     * @return a one-element array holding {@code shapeArr[0]}
     */
    @Override // ai.djl.nn.Block
    public Shape[] getOutputShapes(NDManager nDManager, Shape[] shapeArr) {
        return new Shape[]{shapeArr[0]};
    }

    /**
     * Lists the parameters owned by this block.
     *
     * @return {@code gamma}, {@code beta}, {@code runningMean}, {@code runningVar}, in that order
     */
    @Override // ai.djl.nn.Block
    public List<Parameter> getDirectParameters() {
        return Arrays.asList(this.gamma, this.beta, this.runningMean, this.runningVar);
    }

    /**
     * Records the input shapes and derives the channel count from the first shape along the
     * configured axis.
     *
     * @param shapeArr the input shapes; only the first entry is used for the channel count
     */
    @Override // ai.djl.nn.AbstractBlock
    public void beforeInitialize(Shape[] shapeArr) {
        this.inputShapes = shapeArr;
        this.inChannels = shapeArr[0].size(this.axis);
    }

    /**
     * Returns the shape of one of this block's parameters. All four parameters share the same
     * one-dimensional shape of length {@code inChannels}.
     *
     * @param str the parameter name: {@code gamma}, {@code beta}, {@code runningMean} or
     *     {@code runningVar}
     * @param shapeArr unused
     * @return a shape with the single dimension {@code inChannels}
     * @throws IllegalArgumentException if {@code str} is not one of the four known names
     */
    @Override // ai.djl.nn.Block
    public Shape getParameterShape(String str, Shape[] shapeArr) {
        switch (str) {
            case "gamma":
            case "beta":
            case "runningMean":
            case "runningVar":
                return new Shape(this.inChannels);
            default:
                throw new IllegalArgumentException("Invalid parameter name");
        }
    }

    /**
     * Assembles the operand list for the batch-norm operation: the input array followed by the
     * current values of gamma, beta, runningMean and runningVar, each fetched for the input's
     * device.
     *
     * @param parameterStore the store the parameter values are fetched from
     * @param nDList the inputs; must contain exactly one array
     * @return the five-element operand list
     * @throws IllegalArgumentException if {@code nDList} does not contain exactly one array
     */
    private NDList opInputs(ParameterStore parameterStore, NDList nDList) {
        if (nDList.size() != EXPECTED_INPUTS) {
            throw new IllegalArgumentException("BatchNorm requires exactly 1 NDArray");
        }
        NDArray input = nDList.singletonOrThrow();
        Device device = input.getDevice();
        return new NDList(
                input,
                parameterStore.getValue(this.gamma, device),
                parameterStore.getValue(this.beta, device),
                parameterStore.getValue(this.runningMean, device),
                parameterStore.getValue(this.runningVar, device));
    }

    /**
     * Writes the block state: the version byte, the channel count, then gamma, beta, runningMean
     * and runningVar. {@link #loadParameters} reads the fields back in the same order.
     *
     * @param dataOutputStream the stream to write to
     * @throws IOException if writing fails
     */
    @Override // ai.djl.nn.Block
    public void saveParameters(DataOutputStream dataOutputStream) throws IOException {
        dataOutputStream.writeByte(VERSION);
        dataOutputStream.writeLong(this.inChannels);
        this.gamma.save(dataOutputStream);
        this.beta.save(dataOutputStream);
        this.runningMean.save(dataOutputStream);
        this.runningVar.save(dataOutputStream);
    }

    /**
     * Restores the block state written by {@link #saveParameters}.
     *
     * @param nDManager the manager handed to each parameter's {@code load} call
     * @param dataInputStream the stream to read from
     * @throws IOException if reading fails
     * @throws MalformedModelException if the leading version byte is not the supported version
     */
    @Override // ai.djl.nn.Block
    public void loadParameters(NDManager nDManager, DataInputStream dataInputStream) throws IOException, MalformedModelException {
        byte readByte = dataInputStream.readByte();
        if (readByte != VERSION) {
            throw new MalformedModelException("Unsupported encoding version: " + ((int) readByte));
        }
        this.inChannels = dataInputStream.readLong();
        this.gamma.load(nDManager, dataInputStream);
        this.beta.load(nDManager, dataInputStream);
        this.runningMean.load(nDManager, dataInputStream);
        this.runningVar.load(nDManager, dataInputStream);
    }
}
