/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.params;

import java.util.LinkedHashMap;
import java.util.Map;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.BatchNormalization;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.indexing.INDArrayIndex;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
 * {@link ParamInitializer} for {@link BatchNormalization} layers.
 *
 * <p>The layer owns two parameter vectors, {@value #GAMMA} and {@value #BETA}, each holding
 * {@code nOut} values. Both are carved out of a single flattened view: row 0, with gamma in
 * columns {@code [0, nOut)} and beta in columns {@code [nOut, 2 * nOut)}. The gradient view is
 * sliced with exactly the same layout, so both paths share {@link #rowSlice}.
 */
public class BatchNormalizationParamInitializer
implements ParamInitializer {
    /** Map key / variable name for the gamma parameter vector. */
    public static final String GAMMA = "gamma";
    /** Map key / variable name for the beta parameter vector. */
    public static final String BETA = "beta";

    /**
     * Returns the number of parameters of the layer: {@code nOut} gamma values plus {@code nOut}
     * beta values.
     *
     * @param conf configuration whose layer is a {@link BatchNormalization}
     * @param backprop ignored; the parameter count is the same either way
     * @return {@code 2 * nOut}
     */
    @Override
    public int numParams(NeuralNetConfiguration conf, boolean backprop) {
        return 2 * nOut(conf);
    }

    /**
     * Slices {@code paramView} into gamma and beta views, optionally fills them with the layer's
     * configured values, stores them in {@code params}, and registers both variable names on
     * {@code conf} (gamma first, then beta).
     *
     * @param params destination map; receives {@link #GAMMA} and {@link #BETA} entries
     * @param conf configuration whose layer is a {@link BatchNormalization}
     * @param paramView flattened parameter view; row 0 must span at least {@code 2 * nOut} columns
     * @param initializeParams when {@code true}, the views are overwritten with the configured
     *     gamma/beta values; otherwise their current contents are kept
     */
    @Override
    public void init(Map<String, INDArray> params, NeuralNetConfiguration conf, INDArray paramView, boolean initializeParams) {
        int nOut = nOut(conf);
        INDArray gammaView = rowSlice(paramView, 0, nOut);
        INDArray betaView = rowSlice(paramView, nOut, 2 * nOut);

        params.put(GAMMA, createGamma(conf, gammaView, initializeParams));
        conf.addVariable(GAMMA);
        params.put(BETA, createBeta(conf, betaView, initializeParams));
        conf.addVariable(BETA);
    }

    /**
     * Splits a flattened gradient view into per-parameter views using the same layout as
     * {@link #init}.
     *
     * @param conf configuration whose layer is a {@link BatchNormalization}
     * @param gradientView flattened gradient view; row 0 must span at least {@code 2 * nOut} columns
     * @return an insertion-ordered map with {@link #GAMMA} followed by {@link #BETA}; the values
     *     are views returned by {@code INDArray.get}, not copies made here
     */
    @Override
    public Map<String, INDArray> getGradientsFromFlattened(NeuralNetConfiguration conf, INDArray gradientView) {
        int nOut = nOut(conf);
        Map<String, INDArray> out = new LinkedHashMap<>();
        out.put(GAMMA, rowSlice(gradientView, 0, nOut));
        out.put(BETA, rowSlice(gradientView, nOut, 2 * nOut));
        return out;
    }

    /**
     * Optionally fills the beta view with the layer's configured beta value.
     *
     * @param conf configuration whose layer is a {@link BatchNormalization}
     * @param betaView the beta slice of the parameter view
     * @param initializeParams whether to overwrite {@code betaView}
     * @return {@code betaView} itself (possibly modified in place)
     */
    protected INDArray createBeta(NeuralNetConfiguration conf, INDArray betaView, boolean initializeParams) {
        if (initializeParams) {
            // Cast kept deliberately: selects INDArray.assign(Number), filling every element.
            betaView.assign((Number) layer(conf).getBeta());
        }
        return betaView;
    }

    /**
     * Optionally fills the gamma view with the layer's configured gamma value.
     *
     * @param conf configuration whose layer is a {@link BatchNormalization}
     * @param gammaView the gamma slice of the parameter view
     * @param initializeParams whether to overwrite {@code gammaView}
     * @return {@code gammaView} itself (possibly modified in place)
     */
    protected INDArray createGamma(NeuralNetConfiguration conf, INDArray gammaView, boolean initializeParams) {
        if (initializeParams) {
            // Cast kept deliberately: selects INDArray.assign(Number), filling every element.
            gammaView.assign((Number) layer(conf).getGamma());
        }
        return gammaView;
    }

    /** Returns the configuration's layer, cast to {@link BatchNormalization}. */
    private static BatchNormalization layer(NeuralNetConfiguration conf) {
        return (BatchNormalization) conf.getLayer();
    }

    /** Returns the configured output size of the batch-normalization layer. */
    private static int nOut(NeuralNetConfiguration conf) {
        return layer(conf).getNOut();
    }

    /** Returns the view of row 0 of {@code flat} over columns {@code [from, to)}. */
    private static INDArray rowSlice(INDArray flat, int from, int to) {
        return flat.get(new INDArrayIndex[]{NDArrayIndex.point(0), NDArrayIndex.interval(from, to)});
    }
}

