package org.deeplearning4j.nn.params;

import java.util.Map;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.weights.WeightInitUtil;
import org.nd4j.linalg.api.ndarray.INDArray;

/**
 * Parameter initializer for recursive layers.
 *
 * <p>All parameters are sized from {@code nIn} of the layer configuration. Each is created through
 * {@link WeightInitUtil#initWeights} using that configuration's weight-init scheme, activation
 * function and distribution:
 *
 * <ul>
 *   <li>{@link #W} with shape {@code {2 * nIn, nIn}}
 *   <li>{@link #U} with shape {@code {nIn, 2 * nIn}}
 *   <li>{@link #BIAS} with shape {@code {2 * nIn}}
 *   <li>{@link #C} with shape {@code {nIn}}
 * </ul>
 *
 * <p>Every key is also appended, in the order above, to the configuration's gradient list.
 */
public class RecursiveParamInitializer extends DefaultParamInitializer {
    /** Map key of the weight parameter with shape {@code {2 * nIn, nIn}}. */
    public static final String W = "w";
    /** Map key of the weight parameter with shape {@code {nIn, 2 * nIn}}. */
    public static final String U = "u";
    /** Map key of the bias parameter with shape {@code {2 * nIn}}. */
    public static final String BIAS = "b";
    /** Map key of the parameter with shape {@code {nIn}}. */
    public static final String C = "c";

    /**
     * Creates the {@link #W}, {@link #U}, {@link #BIAS} and {@link #C} parameters and registers
     * their keys for gradient tracking.
     *
     * @param map destination for the created parameters, keyed by the constants of this class;
     *     existing entries under those keys are overwritten
     * @param neuralNetConfiguration supplies {@code nIn}, the weight-init scheme, activation
     *     function and distribution; its gradient list receives the four keys
     */
    @Override
    public void init(Map<String, INDArray> map, NeuralNetConfiguration neuralNetConfiguration) {
        final int nIn = neuralNetConfiguration.getnIn();
        final int twiceNIn = nIn * 2;

        // Keep the creation order W, U, b, c: initWeights presumably draws random values, so
        // reordering these calls could change the parameters produced for a given seed.
        map.put(W, initParam(new int[] {twiceNIn, nIn}, neuralNetConfiguration));
        map.put(U, initParam(new int[] {nIn, twiceNIn}, neuralNetConfiguration));
        // NOTE(review): bias and c are initialized with the weight-init scheme rather than zeros;
        // preserved as-is — confirm this is intended.
        map.put(BIAS, initParam(new int[] {twiceNIn}, neuralNetConfiguration));
        map.put(C, initParam(new int[] {nIn}, neuralNetConfiguration));

        for (String key : new String[] {W, U, BIAS, C}) {
            neuralNetConfiguration.getGradientList().add(key);
        }
    }

    /**
     * Creates one parameter of the given shape using the configuration's weight-init settings.
     *
     * @param shape dimensions of the parameter to create
     * @param conf configuration supplying weight-init scheme, activation function and distribution
     * @return the newly initialized parameter array
     */
    private static INDArray initParam(int[] shape, NeuralNetConfiguration conf) {
        return WeightInitUtil.initWeights(
                shape, conf.getWeightInit(), conf.getActivationFunction(), conf.getDist());
    }
}
