package ai.djl.training.initializer;

import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import java.util.Objects;

/**
 * An {@link Initializer} that draws values whose spread is scaled by the parameter's fan-in
 * and/or fan-out.
 *
 * <p>For a shape of rank {@code >= 2}, {@code shape[1]} is treated as the input dimension and
 * {@code shape[0]} as the output dimension. The product of any remaining axes is treated as the
 * receptive-field size and multiplies both fans. A {@link FactorType} then picks the
 * {@code factor} (fan-in, fan-out, or their average). Values are drawn with scale
 * {@code s = sqrt(magnitude / factor)} according to the configured {@link RandomType}.
 *
 * <p>The no-argument constructor uses {@code UNIFORM}, {@code AVG} and a magnitude of
 * {@code 3}. That gives the uniform bound {@code sqrt(6 / (fanIn + fanOut))}.
 */
public class XavierInitializer implements Initializer {
    private final RandomType randomType;
    private final FactorType factorType;
    private final double magnitude;

    /** Selects which fan value the magnitude is divided by. */
    public enum FactorType {
        /** Use the average of fan-in and fan-out. */
        AVG,
        /** Use fan-in only. */
        IN,
        /** Use fan-out only. */
        OUT
    }

    /** Selects the distribution that values are sampled from. */
    public enum RandomType {
        /** Sample uniformly from {@code [-s, s]}. */
        UNIFORM,
        /** Sample from a normal distribution with mean {@code 0} and scale {@code s}. */
        GAUSSIAN
    }

    /**
     * Creates an initializer with the given distribution, fan factor and magnitude.
     *
     * @param randomType the distribution to sample from; must not be {@code null}
     * @param factorType which fan value to scale by; must not be {@code null}
     * @param magnitude the numerator of the scale {@code sqrt(magnitude / factor)}; must be
     *     finite and non-negative
     * @throws NullPointerException if {@code randomType} or {@code factorType} is {@code null}
     * @throws IllegalArgumentException if {@code magnitude} is negative, NaN or infinite
     */
    public XavierInitializer(RandomType randomType, FactorType factorType, double magnitude) {
        // Fail fast here rather than with an NPE from the switch inside initialize().
        this.randomType = Objects.requireNonNull(randomType, "randomType");
        this.factorType = Objects.requireNonNull(factorType, "factorType");
        // A negative or non-finite magnitude would make sqrt() yield NaN/Infinity bounds.
        if (!Double.isFinite(magnitude) || magnitude < 0.0d) {
            throw new IllegalArgumentException(
                    "magnitude must be a finite, non-negative number, got: " + magnitude);
        }
        this.magnitude = magnitude;
    }

    /** Creates an initializer using {@code UNIFORM} sampling, {@code AVG} factor and magnitude 3. */
    public XavierInitializer() {
        this(RandomType.UNIFORM, FactorType.AVG, 3.0d);
    }

    /**
     * Creates an array of the given shape and type filled with values sampled according to this
     * initializer's configuration.
     *
     * @param manager the manager used to allocate the array; its device is used for placement
     * @param shape the shape of the parameter; must have at least two dimensions
     * @param dataType the data type of the returned array
     * @return the newly sampled array
     * @throws IllegalArgumentException if {@code shape} has fewer than two dimensions
     * @throws IllegalStateException if the computed fan factor is not positive
     */
    @Override // ai.djl.training.initializer.Initializer
    public NDArray initialize(NDManager manager, Shape shape, DataType dataType) {
        long dimension = shape.dimension();
        if (dimension < 2) {
            throw new IllegalArgumentException("XavierInitializer cannot be applied to Shape with dimension: " + dimension + ", it requires shape to be at least 2D.");
        }
        // All axes after the first two form the receptive field. They presumably hold
        // spatial/kernel extents (TODO confirm for non-conv layouts). The field size
        // scales both fans.
        double receptiveFieldSize = dimension == 2 ? 1.0d : shape.slice(2).size();
        double fanIn = shape.get(1) * receptiveFieldSize;
        double fanOut = shape.head() * receptiveFieldSize;
        double factor = computeFactor(fanIn, fanOut);
        // A zero factor (empty axis) would divide by zero. A negative or NaN factor
        // (e.g. from an unknown -1 dimension) would make sqrt() return NaN.
        if (!(factor > 0.0d)) {
            throw new IllegalStateException("Xavier initializer factor is " + factor + ", please check your input shape.");
        }
        double scale = StrictMath.sqrt(magnitude / factor);
        switch (randomType) {
            case UNIFORM:
                return manager.randomUniform(Double.valueOf(-scale), Double.valueOf(scale), shape, dataType, manager.getDevice());
            case GAUSSIAN:
                return manager.randomNormal(0, Double.valueOf(scale), shape, dataType, manager.getDevice());
            default:
                throw new IllegalArgumentException("Invalid randomType");
        }
    }

    /** Returns the fan value selected by {@link #factorType}. */
    private double computeFactor(double fanIn, double fanOut) {
        switch (factorType) {
            case AVG:
                return (fanIn + fanOut) / 2.0d;
            case IN:
                return fanIn;
            case OUT:
                return fanOut;
            default:
                throw new IllegalArgumentException("Invalid factor type, valid types are: avg, in, out");
        }
    }
}
