package org.deeplearning4j.nn.layers.factory;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;
import org.deeplearning4j.nn.api.LayerFactory;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.AutoEncoder;
import org.deeplearning4j.nn.conf.layers.BatchNormalization;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.conf.layers.DenseLayer;
import org.deeplearning4j.nn.conf.layers.GRU;
import org.deeplearning4j.nn.conf.layers.GravesLSTM;
import org.deeplearning4j.nn.conf.layers.ImageLSTM;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.LocalResponseNormalization;
import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.RBM;
import org.deeplearning4j.nn.conf.layers.RecursiveAutoEncoder;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.conf.layers.SubsamplingLayer;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.optimize.api.IterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;

/* loaded from: input_file:org/deeplearning4j/nn/layers/factory/DefaultLayerFactory.class */
/**
 * Layer factory that maps a layer <em>configuration</em> class
 * ({@code org.deeplearning4j.nn.conf.layers.*}) to the matching runtime layer implementation
 * ({@code org.deeplearning4j.nn.layers.*}).
 *
 * <p>It then wires listeners, index, parameter table and configuration into the created layer.
 * Parameters are initialized by {@link DefaultParamInitializer} unless a subclass overrides
 * {@link #initializer()}.
 */
public class DefaultLayerFactory implements LayerFactory {
    /**
     * Prototype configuration instance created from the class given to the constructor.
     * It is used by {@link #getInstance} to choose the runtime layer type, and by
     * {@code equals}/{@code hashCode}.
     */
    protected Layer layerConfig;

    /**
     * Creates a factory for the given layer configuration class.
     *
     * @param cls layer configuration class; it must have a no-arg constructor accessible from
     *     this class
     * @throws RuntimeException if {@code cls} is null or cannot be instantiated; the underlying
     *     exception is kept as the cause
     */
    public DefaultLayerFactory(Class<? extends Layer> cls) {
        try {
            // Constructor.newInstance wraps constructor failures in InvocationTargetException,
            // unlike the deprecated Class.newInstance, which rethrows checked exceptions unchecked.
            this.layerConfig = cls.getDeclaredConstructor().newInstance();
        } catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Creates a layer and delegates to {@link #create(NeuralNetConfiguration, Collection, int)}.
     *
     * @param neuralNetConfiguration configuration for the new layer
     * @param i layer index passed to {@code setIndex}
     * @param i2 not used by this implementation
     * @param collection iteration listeners attached to the layer
     * @return the configured layer
     */
    @Override // org.deeplearning4j.nn.api.LayerFactory
    public <E extends org.deeplearning4j.nn.api.Layer> E create(NeuralNetConfiguration neuralNetConfiguration, int i, int i2, Collection<IterationListener> collection) {
        return create(neuralNetConfiguration, collection, i);
    }

    /**
     * Creates a layer with no listeners and index 0.
     *
     * @param neuralNetConfiguration configuration for the new layer
     * @return the configured layer
     */
    @Override // org.deeplearning4j.nn.api.LayerFactory
    public <E extends org.deeplearning4j.nn.api.Layer> E create(NeuralNetConfiguration neuralNetConfiguration) {
        // A fresh mutable list is passed, matching the original behaviour, so the layer may add
        // listeners to it later. Explicit type argument keeps this Java 7 compatible.
        return create(neuralNetConfiguration, new ArrayList<IterationListener>(), 0);
    }

    /**
     * Instantiates the runtime layer for {@link #layerConfig} and configures it.
     *
     * @param neuralNetConfiguration configuration for the new layer; also used for param init
     * @param collection iteration listeners attached to the layer
     * @param i layer index passed to {@code setIndex}
     * @return the configured layer
     * @throws RuntimeException if {@link #layerConfig} is not a supported layer type
     */
    @Override // org.deeplearning4j.nn.api.LayerFactory
    @SuppressWarnings("unchecked") // E is chosen by the caller; the runtime type comes from layerConfig
    public <E extends org.deeplearning4j.nn.api.Layer> E create(NeuralNetConfiguration neuralNetConfiguration, Collection<IterationListener> collection, int i) {
        E layer = (E) getInstance(neuralNetConfiguration);
        layer.setListeners(collection);
        layer.setIndex(i);
        layer.setParamTable(getParams(neuralNetConfiguration));
        layer.setConf(neuralNetConfiguration);
        return layer;
    }

    /**
     * Maps the configuration type held in {@link #layerConfig} to a new runtime layer instance.
     *
     * <p>Checks run top to bottom and the first match wins. NOTE(review): if any configuration
     * class listed later subclasses one listed earlier, its branch is unreachable. Verify the
     * config class hierarchy before adding types.
     *
     * @param neuralNetConfiguration configuration passed to the runtime layer constructor
     * @return a new, not yet configured runtime layer
     * @throws RuntimeException if {@link #layerConfig} matches none of the known types
     */
    protected org.deeplearning4j.nn.api.Layer getInstance(NeuralNetConfiguration neuralNetConfiguration) {
        if (this.layerConfig instanceof DenseLayer) {
            return new org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer(neuralNetConfiguration);
        }
        if (this.layerConfig instanceof AutoEncoder) {
            return new org.deeplearning4j.nn.layers.feedforward.autoencoder.AutoEncoder(neuralNetConfiguration);
        }
        if (this.layerConfig instanceof RBM) {
            return new org.deeplearning4j.nn.layers.feedforward.rbm.RBM(neuralNetConfiguration);
        }
        if (this.layerConfig instanceof ImageLSTM) {
            return new org.deeplearning4j.nn.layers.recurrent.ImageLSTM(neuralNetConfiguration);
        }
        if (this.layerConfig instanceof GravesLSTM) {
            return new org.deeplearning4j.nn.layers.recurrent.GravesLSTM(neuralNetConfiguration);
        }
        if (this.layerConfig instanceof GRU) {
            return new org.deeplearning4j.nn.layers.recurrent.GRU(neuralNetConfiguration);
        }
        if (this.layerConfig instanceof OutputLayer) {
            return new org.deeplearning4j.nn.layers.OutputLayer(neuralNetConfiguration);
        }
        if (this.layerConfig instanceof RnnOutputLayer) {
            return new org.deeplearning4j.nn.layers.recurrent.RnnOutputLayer(neuralNetConfiguration);
        }
        if (this.layerConfig instanceof RecursiveAutoEncoder) {
            return new org.deeplearning4j.nn.layers.feedforward.autoencoder.recursive.RecursiveAutoEncoder(neuralNetConfiguration);
        }
        if (this.layerConfig instanceof ConvolutionLayer) {
            return new org.deeplearning4j.nn.layers.convolution.ConvolutionLayer(neuralNetConfiguration);
        }
        if (this.layerConfig instanceof SubsamplingLayer) {
            return new org.deeplearning4j.nn.layers.convolution.subsampling.SubsamplingLayer(neuralNetConfiguration);
        }
        if (this.layerConfig instanceof BatchNormalization) {
            return new org.deeplearning4j.nn.layers.normalization.BatchNormalization(neuralNetConfiguration);
        }
        if (this.layerConfig instanceof LocalResponseNormalization) {
            return new org.deeplearning4j.nn.layers.normalization.LocalResponseNormalization(neuralNetConfiguration);
        }
        throw new RuntimeException("unknown layer type: " + this.layerConfig);
    }

    /**
     * Builds the parameter table for a new layer using {@link #initializer()}.
     *
     * @param neuralNetConfiguration configuration passed to the initializer
     * @return an insertion-ordered, synchronized map populated by the initializer
     */
    protected Map<String, INDArray> getParams(NeuralNetConfiguration neuralNetConfiguration) {
        ParamInitializer paramInitializer = initializer();
        // LinkedHashMap keeps the initializer's insertion order. The synchronized wrapper is
        // kept from the original; its thread-safety motivation is not visible here.
        Map<String, INDArray> params = Collections.synchronizedMap(new LinkedHashMap<String, INDArray>());
        paramInitializer.init(params, neuralNetConfiguration);
        return params;
    }

    /**
     * Returns the parameter initializer used by {@link #getParams}.
     *
     * @return a new {@link DefaultParamInitializer} on every call
     */
    @Override // org.deeplearning4j.nn.api.LayerFactory
    public ParamInitializer initializer() {
        return new DefaultParamInitializer();
    }

    /** Two factories are equal when their prototype layer configurations are equal (null-safe). */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof DefaultLayerFactory)) {
            return false;
        }
        DefaultLayerFactory other = (DefaultLayerFactory) obj;
        return Objects.equals(this.layerConfig, other.layerConfig);
    }

    /** Hash of the prototype layer configuration, or 0 when it is null. */
    @Override
    public int hashCode() {
        return Objects.hashCode(this.layerConfig);
    }
}
