package org.deeplearning4j.nn.modelimport.keras.layers;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.BatchNormalization;
import org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/modelimport/keras/layers/KerasBatchNormalization.class */
public class KerasBatchNormalization extends KerasLayer {
    private static final Logger log = LoggerFactory.getLogger(KerasBatchNormalization.class);
    public static final int LAYER_BATCHNORM_MODE_1 = 1;
    public static final int LAYER_BATCHNORM_MODE_2 = 2;
    public static final String LAYER_FIELD_GAMMA_REGULARIZER = "gamma_regularizer";
    public static final String LAYER_FIELD_BETA_REGULARIZER = "beta_regularizer";
    public static final String LAYER_FIELD_MODE = "mode";
    public static final String LAYER_FIELD_AXIS = "axis";
    public static final String LAYER_FIELD_MOMENTUM = "momentum";
    public static final String LAYER_FIELD_EPSILON = "epsilon";
    public static final int NUM_TRAINABLE_PARAMS = 4;
    public static final String PARAM_NAME_GAMMA = "gamma";
    public static final String PARAM_NAME_BETA = "beta";
    public static final String PARAM_NAME_RUNNING_MEAN = "running_mean";
    public static final String PARAM_NAME_RUNNING_STD = "running_std";

    public KerasBatchNormalization(Map<String, Object> map) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        this(map, true);
    }

    public KerasBatchNormalization(Map<String, Object> map, boolean z) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        super(map, z);
        getGammaRegularizerFromConfig(map, z);
        getBetaRegularizerFromConfig(map, z);
        getBatchNormMode(map, z);
        getBatchNormAxis(map, z);
        this.layer = new BatchNormalization.Builder().name(this.layerName).dropOut(this.dropout).minibatch(true).lockGammaBeta(false).eps(getEpsFromConfig(map)).momentum(getMomentumFromConfig(map)).build();
    }

    public BatchNormalization getBatchNormalizationLayer() {
        return this.layer;
    }

    @Override // org.deeplearning4j.nn.modelimport.keras.KerasLayer
    public InputType getOutputType(InputType... inputTypeArr) throws InvalidKerasConfigurationException {
        if (inputTypeArr.length > 1) {
            throw new InvalidKerasConfigurationException("Keras BatchNorm layer accepts only one input (received " + inputTypeArr.length + ")");
        }
        return getBatchNormalizationLayer().getOutputType(-1, inputTypeArr[0]);
    }

    @Override // org.deeplearning4j.nn.modelimport.keras.KerasLayer
    public int getNumParams() {
        return 4;
    }

    @Override // org.deeplearning4j.nn.modelimport.keras.KerasLayer
    public void setWeights(Map<String, INDArray> map) throws InvalidKerasConfigurationException {
        this.weights = new HashMap();
        if (!map.containsKey(PARAM_NAME_BETA)) {
            throw new InvalidKerasConfigurationException("Parameter beta does not exist in weights");
        }
        this.weights.put(PARAM_NAME_BETA, map.get(PARAM_NAME_BETA));
        if (!map.containsKey(PARAM_NAME_GAMMA)) {
            throw new InvalidKerasConfigurationException("Parameter gamma does not exist in weights");
        }
        this.weights.put(PARAM_NAME_GAMMA, map.get(PARAM_NAME_GAMMA));
        if (!map.containsKey(PARAM_NAME_RUNNING_MEAN)) {
            throw new InvalidKerasConfigurationException("Parameter running_mean does not exist in weights");
        }
        this.weights.put("mean", map.get(PARAM_NAME_RUNNING_MEAN));
        if (!map.containsKey(PARAM_NAME_RUNNING_STD)) {
            throw new InvalidKerasConfigurationException("Parameter running_std does not exist in weights");
        }
        this.weights.put("var", map.get(PARAM_NAME_RUNNING_STD));
        if (map.size() > 4) {
            Set<String> keySet = map.keySet();
            keySet.remove(PARAM_NAME_BETA);
            keySet.remove(PARAM_NAME_GAMMA);
            keySet.remove(PARAM_NAME_RUNNING_MEAN);
            keySet.remove(PARAM_NAME_RUNNING_STD);
            String obj = keySet.toString();
            log.warn("Attemping to set weights for unknown parameters: " + obj.substring(1, obj.length() - 1));
        }
    }

    protected double getEpsFromConfig(Map<String, Object> map) throws InvalidKerasConfigurationException {
        Map<String, Object> innerLayerConfigFromConfig = getInnerLayerConfigFromConfig(map);
        if (innerLayerConfigFromConfig.containsKey(LAYER_FIELD_EPSILON)) {
            return ((Double) innerLayerConfigFromConfig.get(LAYER_FIELD_EPSILON)).doubleValue();
        }
        throw new InvalidKerasConfigurationException("Keras BatchNorm layer config missing epsilon field");
    }

    protected double getMomentumFromConfig(Map<String, Object> map) throws InvalidKerasConfigurationException {
        Map<String, Object> innerLayerConfigFromConfig = getInnerLayerConfigFromConfig(map);
        if (innerLayerConfigFromConfig.containsKey(LAYER_FIELD_MOMENTUM)) {
            return ((Double) innerLayerConfigFromConfig.get(LAYER_FIELD_MOMENTUM)).doubleValue();
        }
        throw new InvalidKerasConfigurationException("Keras BatchNorm layer config missing momentum field");
    }

    protected void getGammaRegularizerFromConfig(Map<String, Object> map, boolean z) throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        if (getInnerLayerConfigFromConfig(map).get(LAYER_FIELD_GAMMA_REGULARIZER) != null) {
            if (z) {
                throw new UnsupportedKerasConfigurationException("Regularization for BatchNormalization gamma parameter not supported");
            }
            log.warn("Regularization for BatchNormalization gamma parameter not supported...ignoring.");
        }
    }

    protected void getBetaRegularizerFromConfig(Map<String, Object> map, boolean z) throws UnsupportedKerasConfigurationException, InvalidKerasConfigurationException {
        if (getInnerLayerConfigFromConfig(map).get(LAYER_FIELD_BETA_REGULARIZER) != null) {
            if (z) {
                throw new UnsupportedKerasConfigurationException("Regularization for BatchNormalization beta parameter not supported");
            }
            log.warn("Regularization for BatchNormalization beta parameter not supported...ignoring.");
        }
    }

    protected int getBatchNormMode(Map<String, Object> map, boolean z) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        Map<String, Object> innerLayerConfigFromConfig = getInnerLayerConfigFromConfig(map);
        if (!innerLayerConfigFromConfig.containsKey("mode")) {
            throw new InvalidKerasConfigurationException("Keras BatchNorm layer config missing mode field");
        }
        int intValue = ((Integer) innerLayerConfigFromConfig.get("mode")).intValue();
        switch (intValue) {
            case 1:
                throw new UnsupportedKerasConfigurationException("Keras BatchNormalization mode 1 (sample-wise) not supported");
            case 2:
                throw new UnsupportedKerasConfigurationException("Keras BatchNormalization (per-batch statistics during testing) 2 not supported");
            default:
                return intValue;
        }
    }

    protected int getBatchNormAxis(Map<String, Object> map, boolean z) throws InvalidKerasConfigurationException {
        return ((Integer) getInnerLayerConfigFromConfig(map).get(LAYER_FIELD_AXIS)).intValue();
    }
}
