package org.deeplearning4j.nn.conf.layers;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import org.deeplearning4j.eval.EvaluationBinary;
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.nn.api.layers.LayerConstraint;
import org.deeplearning4j.nn.conf.distribution.Distribution;
import org.deeplearning4j.nn.conf.distribution.NormalDistribution;
import org.deeplearning4j.nn.conf.dropout.IDropout;
import org.deeplearning4j.nn.conf.layers.misc.FrozenLayer;
import org.deeplearning4j.nn.conf.layers.recurrent.Bidirectional;
import org.deeplearning4j.nn.weights.WeightInit;
import org.nd4j.util.OneTimeLogger;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:org/deeplearning4j/nn/conf/layers/LayerValidation.class */
public class LayerValidation {
    private static final Logger log = LoggerFactory.getLogger(LayerValidation.class);

    /**
     * Value assigned to an L1/L2 coefficient that is unset (NaN) both on the layer and in the values
     * passed to {@link #generalValidation}.
     */
    private static final double DEFAULT_REGULARIZATION = 0.0;

    /**
     * Checks that a layer's input and output sizes have been configured.
     *
     * @param layerType  text prefixed to the error message (e.g. the layer type)
     * @param layerName  layer name for the error message; reported as {@code "(name not set)"} when {@code null}
     * @param layerIndex layer index for the error message
     * @param nIn        number of inputs; must be &gt; 0
     * @param nOut       number of outputs; must be &gt; 0
     * @throws DL4JInvalidConfigException if {@code nIn <= 0} or {@code nOut <= 0}
     */
    public static void assertNInNOutSet(String layerType, String layerName, int layerIndex, int nIn, int nOut) {
        if (nIn > 0 && nOut > 0) {
            return;
        }
        String name = (layerName == null) ? "(name not set)" : layerName;
        throw new DL4JInvalidConfigException(layerType + " (index=" + layerIndex + ", name=" + name + ") nIn=" + nIn
                + ", nOut=" + nOut + "; nIn and nOut must be > 0");
    }

    /**
     * Boxed-argument overload of
     * {@link #generalValidation(String, Layer, IDropout, double, double, double, double, Distribution, List, List, List)}.
     * A {@code null} coefficient is treated as "not set" and forwarded as {@link Double#NaN}.
     */
    public static void generalValidation(String layerName, Layer layer, IDropout iDropout, Double l2, Double l2Bias,
                    Double l1, Double l1Bias, Distribution dist, List<LayerConstraint> allParamConstraints,
                    List<LayerConstraint> weightConstraints, List<LayerConstraint> biasConstraints) {
        generalValidation(layerName, layer, iDropout, unboxOrNaN(l2), unboxOrNaN(l2Bias), unboxOrNaN(l1),
                unboxOrNaN(l1Bias), dist, allParamConstraints, weightConstraints, biasConstraints);
    }

    /**
     * Fills in a layer's unset settings from the supplied values and attaches the supplied constraints.
     *
     * <p>If {@code layer} is a {@link BaseLayer} (directly, or wrapped in a {@link FrozenLayer}), its NaN
     * L1/L2 coefficients, {@code null} dropout and weight distribution are filled in. If it is a
     * {@link Bidirectional}, this method recurses into both the forward and backward layers. In every case,
     * if {@code layer} has no constraints of its own, clones of the supplied constraints are attached,
     * each restricted to the layer's matching parameter keys.
     *
     * @param layerName           name used in warning messages
     * @param layer               layer to configure; nothing happens when {@code null}
     * @param iDropout            dropout applied when the layer has none
     * @param l2                  L2 coefficient for weights; NaN means "not set"
     * @param l2Bias              L2 coefficient for biases; NaN means "not set"
     * @param l1                  L1 coefficient for weights; NaN means "not set"
     * @param l1Bias              L1 coefficient for biases; NaN means "not set"
     * @param dist                weight distribution, only relevant when the layer uses {@link WeightInit#DISTRIBUTION}
     * @param allParamConstraints constraints cloned onto all of the layer's parameter keys; may be {@code null}
     * @param weightConstraints   constraints cloned onto the layer's weight keys; may be {@code null}
     * @param biasConstraints     constraints cloned onto the layer's bias keys; may be {@code null}
     */
    public static void generalValidation(String layerName, Layer layer, IDropout iDropout, double l2, double l2Bias,
                    double l1, double l1Bias, Distribution dist, List<LayerConstraint> allParamConstraints,
                    List<LayerConstraint> weightConstraints, List<LayerConstraint> biasConstraints) {
        if (layer == null) {
            return;
        }
        if (layer instanceof BaseLayer) {
            configureBaseLayer(layerName, (BaseLayer) layer, iDropout, l2, l2Bias, l1, l1Bias, dist);
        } else if (layer instanceof FrozenLayer && ((FrozenLayer) layer).getLayer() instanceof BaseLayer) {
            BaseLayer inner = (BaseLayer) ((FrozenLayer) layer).getLayer();
            configureBaseLayer(layerName, inner, iDropout, l2, l2Bias, l1, l1Bias, dist);
        } else if (layer instanceof Bidirectional) {
            Bidirectional bidirectional = (Bidirectional) layer;
            generalValidation(layerName, bidirectional.getFwd(), iDropout, l2, l2Bias, l1, l1Bias, dist,
                    allParamConstraints, weightConstraints, biasConstraints);
            generalValidation(layerName, bidirectional.getBwd(), iDropout, l2, l2Bias, l1, l1Bias, dist,
                    allParamConstraints, weightConstraints, biasConstraints);
        }
        applyConstraints(layer, allParamConstraints, weightConstraints, biasConstraints);
    }

    /** Attaches clones of the supplied constraints to {@code layer} unless it already has its own. */
    private static void applyConstraints(Layer layer, List<LayerConstraint> allParamConstraints,
                    List<LayerConstraint> weightConstraints, List<LayerConstraint> biasConstraints) {
        List<LayerConstraint> existing = layer.getConstraints();
        if (existing != null && !existing.isEmpty()) {
            // Constraints configured directly on the layer take precedence.
            return;
        }
        List<LayerConstraint> constraints = new ArrayList<LayerConstraint>();
        // Each initializer() lookup happens only when the matching constraint list is present.
        if (allParamConstraints != null) {
            addClonedConstraints(allParamConstraints, layer.initializer().paramKeys(layer), constraints);
        }
        if (weightConstraints != null) {
            addClonedConstraints(weightConstraints, layer.initializer().weightKeys(layer), constraints);
        }
        if (biasConstraints != null) {
            addClonedConstraints(biasConstraints, layer.initializer().biasKeys(layer), constraints);
        }
        layer.setConstraints(constraints.isEmpty() ? null : constraints);
    }

    /**
     * Appends to {@code target} a clone of every template, each scoped to its own copy of {@code paramKeys}.
     * Nothing is added when {@code paramKeys} is empty.
     */
    private static void addClonedConstraints(List<LayerConstraint> templates, List<String> paramKeys,
                    List<LayerConstraint> target) {
        if (paramKeys.isEmpty()) {
            return;
        }
        for (LayerConstraint template : templates) {
            LayerConstraint copy = template.mo37clone();
            copy.setParams(new HashSet<String>(paramKeys));
            target.add(copy);
        }
    }

    /** Fills the unset regularization, dropout and distribution settings of {@code baseLayer}. */
    private static void configureBaseLayer(String layerName, BaseLayer baseLayer, IDropout iDropout, double l2,
                    double l2Bias, double l1, double l1Bias, Distribution dist) {
        // A layer-level coefficient wins. Otherwise the supplied value is used, or DEFAULT_REGULARIZATION
        // when that is unset too.
        if (Double.isNaN(baseLayer.getL1())) {
            baseLayer.setL1(valueOrDefault(l1));
        }
        if (Double.isNaN(baseLayer.getL2())) {
            baseLayer.setL2(valueOrDefault(l2));
        }
        if (Double.isNaN(baseLayer.getL1Bias())) {
            baseLayer.setL1Bias(valueOrDefault(l1Bias));
        }
        if (Double.isNaN(baseLayer.getL2Bias())) {
            baseLayer.setL2Bias(valueOrDefault(l2Bias));
        }
        if (baseLayer.getIDropout() == null) {
            baseLayer.setIDropout(iDropout);
        }
        configureDistribution(layerName, baseLayer, dist);
    }

    /**
     * Resolves the weight distribution of {@code baseLayer}.
     *
     * <p>When the layer does not use {@link WeightInit#DISTRIBUTION}, this only warns if a distribution was
     * given anyway. Otherwise, the layer's own distribution wins, then {@code dist}, then N(0, 1).
     */
    private static void configureDistribution(String layerName, BaseLayer baseLayer, Distribution dist) {
        if (baseLayer.getWeightInit() != WeightInit.DISTRIBUTION) {
            if (dist != null || baseLayer.getDist() != null) {
                OneTimeLogger.warn(log, "Layer \"" + layerName + "\" distribution is set but will not be applied"
                        + " unless weight init is set to WeightInit.DISTRIBUTION.");
            }
            return;
        }
        if (baseLayer.getDist() != null) {
            return;
        }
        if (dist != null) {
            baseLayer.setDist(dist);
        } else {
            baseLayer.setDist(new NormalDistribution(0.0, 1.0));
            OneTimeLogger.warn(log, "Layer \"" + layerName + "\" distribution is automatically set to normal"
                    + " distribution with mean 0 and variance 1.");
        }
    }

    /** Returns {@code value}, or {@link #DEFAULT_REGULARIZATION} when {@code value} is NaN ("not set"). */
    private static double valueOrDefault(double value) {
        return Double.isNaN(value) ? DEFAULT_REGULARIZATION : value;
    }

    /** Unboxes {@code value}, mapping {@code null} to {@link Double#NaN} ("not set"). */
    private static double unboxOrNaN(Double value) {
        return value == null ? Double.NaN : value.doubleValue();
    }
}
