package org.deeplearning4j.nn.weights;

import java.util.Objects;

import org.apache.commons.math3.util.FastMath;
import org.deeplearning4j.clustering.kdtree.KDTree;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.rng.distribution.Distribution;
import org.nd4j.linalg.factory.Nd4j;

/* loaded from: input_file:org/deeplearning4j/nn/weights/WeightInitUtil.class */
public class WeightInitUtil {

    /**
     * Creates uniform random weights, centered by subtracting 0.5 and then divided by {@code divisor}.
     *
     * @param shape   shape of the array to create
     * @param divisor value every centered sample is divided by
     *                (NOTE(review): callers presumably pass the fan-in — confirm)
     * @return a new array of the given shape
     */
    public static INDArray normalized(int[] shape, int divisor) {
        return Nd4j.rand(shape).subi(0.5).divi((double) divisor);
    }

    /**
     * Samples weights uniformly from {@code [-b, b]} where
     * {@code b = 4 * sqrt(6 / (fanIn + fanOut))}.
     *
     * @param shape  shape of the array to create
     * @param fanIn  first size used in the bound
     * @param fanOut second size used in the bound
     * @return a new array of the given shape
     */
    public static INDArray uniformBasedOnInAndOut(int[] shape, int fanIn, int fanOut) {
        // Sum in double so very large layer sizes cannot overflow int before the division.
        double bound = 4.0 * Math.sqrt(6.0 / ((double) fanOut + fanIn));
        return Nd4j.rand(shape, Nd4j.getDistributions().createUniform(-bound, bound));
    }

    /**
     * Creates an array of the given shape filled with random values between {@code min} and
     * {@code max}, drawn from the global ND4J random generator.
     *
     * @param shape shape of the array to create
     * @param min   lower bound passed to {@code Nd4j.rand}
     * @param max   upper bound passed to {@code Nd4j.rand}
     * @return a new array of the given shape
     */
    public static INDArray initWeights(int[] shape, float min, float max) {
        return Nd4j.rand(shape, min, max, Nd4j.getRandom());
    }

    /**
     * Initializes a weight array according to {@code weightInit}.
     *
     * <p>For the schemes that need layer sizes, {@code shape[0]} is used as the first size
     * (the fan-in, presumably) and {@code shape[1]} as the second (the fan-out, presumably).
     * NORMALIZED, RELU and UNIFORM need at least one dimension; SIZE and XAVIER need at least two.
     *
     * @param shape        shape of the array to create; must not be null
     * @param weightInit   initialization scheme; must not be null
     * @param distribution distribution to sample from; required only for
     *                     {@link WeightInit#DISTRIBUTION}
     * @return a new array of the given shape
     * @throws NullPointerException     if {@code shape} or {@code weightInit} is null, or if
     *                                  {@code distribution} is null for DISTRIBUTION
     * @throws IllegalArgumentException if {@code shape} has too few dimensions for the scheme
     * @throws IllegalStateException    if the scheme is not handled here
     */
    public static INDArray initWeights(int[] shape, WeightInit weightInit, Distribution distribution) {
        Objects.requireNonNull(shape, "shape");
        switch (weightInit) {
            case DISTRIBUTION:
                Objects.requireNonNull(distribution,
                        "distribution must not be null for WeightInit.DISTRIBUTION");
                return distribution.sample(shape);
            case NORMALIZED:
                requireDims(shape, 1, weightInit);
                return Nd4j.rand(shape, Nd4j.getRandom()).subi(0.5).divi(shape[0]);
            case RELU:
                requireDims(shape, 1, weightInit);
                return Nd4j.randn(shape).muli(FastMath.sqrt(2.0 / shape[0]));
            case SIZE:
                requireDims(shape, 2, weightInit);
                return uniformBasedOnInAndOut(shape, shape[0], shape[1]);
            case UNIFORM: {
                requireDims(shape, 1, weightInit);
                double a = 1.0 / shape[0];
                return Nd4j.rand(shape, -a, a, Nd4j.getRandom());
            }
            case VI: {
                INDArray weights = Nd4j.rand(shape, Nd4j.getRandom());
                // Sum every dimension in a long so large shapes cannot overflow int.
                long dimSum = 0L;
                for (int dim : shape) {
                    dimSum += dim;
                }
                double bound = Math.sqrt(6.0) / Math.sqrt(dimSum + 1);
                // Map each sample x to 2*b*x - b, i.e. rescale the rand range to [-b, b].
                weights.muli(2).muli(bound).subi(bound);
                return weights;
            }
            case XAVIER:
                requireDims(shape, 2, weightInit);
                return Nd4j.randn(shape).divi(FastMath.sqrt((double) shape[0] + shape[1]));
            case ZERO:
                return Nd4j.create(shape);
            default:
                throw new IllegalStateException("Illegal weight init value");
        }
    }

    /**
     * Convenience overload for a two-dimensional {@code [rows, cols]} weight array.
     *
     * @param rows         first dimension
     * @param cols         second dimension
     * @param weightInit   initialization scheme
     * @param distribution distribution used only for {@link WeightInit#DISTRIBUTION}
     * @return a new {@code rows x cols} array
     */
    public static INDArray initWeights(int rows, int cols, WeightInit weightInit, Distribution distribution) {
        return initWeights(new int[]{rows, cols}, weightInit, distribution);
    }

    /** Fails fast with a descriptive message when {@code shape} is too short for {@code scheme}. */
    private static void requireDims(int[] shape, int minDims, WeightInit scheme) {
        if (shape.length < minDims) {
            throw new IllegalArgumentException("WeightInit." + scheme + " requires a shape with at least "
                    + minDims + " dimension(s), got " + shape.length);
        }
    }
}
