package org.deeplearning4j.nn.modelimport.keras.layers;

import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.modelimport.keras.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.KerasLayer;
import org.deeplearning4j.nn.modelimport.keras.UnsupportedKerasConfigurationException;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/modelimport/keras/layers/KerasDense.class */
public class KerasDense extends KerasLayer {
    private static final Logger log = LoggerFactory.getLogger(KerasDense.class);
    public static final int NUM_TRAINABLE_PARAMS = 2;
    public static final String KERAS_PARAM_NAME_W = "W";
    public static final String KERAS_PARAM_NAME_B = "b";

    public KerasDense(Map<String, Object> map) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        this(map, true);
    }

    public KerasDense(Map<String, Object> map, boolean z) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        super(map, z);
        this.layer = new DenseLayer.Builder().name(this.layerName).nOut(getNOutFromConfig(map)).dropOut(this.dropout).activation(getActivationFromConfig(map)).weightInit(getWeightInitFromConfig(map, z)).biasInit(0.0d).l1(this.weightL1Regularization).l2(this.weightL2Regularization).build();
    }

    public DenseLayer getDenseLayer() {
        return this.layer;
    }

    @Override // org.deeplearning4j.nn.modelimport.keras.KerasLayer
    public InputType getOutputType(InputType... inputTypeArr) throws InvalidKerasConfigurationException {
        InputPreProcessor inputPreprocessor = getInputPreprocessor(inputTypeArr[0]);
        return inputPreprocessor != null ? getDenseLayer().getOutputType(-1, inputPreprocessor.getOutputType(inputTypeArr[0])) : getDenseLayer().getOutputType(-1, inputTypeArr[0]);
    }

    @Override // org.deeplearning4j.nn.modelimport.keras.KerasLayer
    public int getNumParams() {
        return 2;
    }

    @Override // org.deeplearning4j.nn.modelimport.keras.KerasLayer
    public void setWeights(Map<String, INDArray> map) throws InvalidKerasConfigurationException {
        this.weights = new HashMap();
        if (!map.containsKey("W")) {
            throw new InvalidKerasConfigurationException("Parameter W does not exist in weights");
        }
        this.weights.put("W", map.get("W"));
        if (!map.containsKey("b")) {
            throw new InvalidKerasConfigurationException("Parameter b does not exist in weights");
        }
        this.weights.put("b", map.get("b"));
        if (map.size() > 2) {
            Set<String> keySet = map.keySet();
            keySet.remove("W");
            keySet.remove("b");
            String obj = keySet.toString();
            log.warn("Attemping to set weights for unknown parameters: " + obj.substring(1, obj.length() - 1));
        }
    }
}
