/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.util;

import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
import org.deeplearning4j.exception.DL4JInvalidConfigException;
import org.deeplearning4j.nn.conf.layers.BaseOutputLayer;
import org.deeplearning4j.nn.conf.layers.CnnLossLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.LossLayer;
import org.deeplearning4j.nn.conf.layers.RnnLossLayer;
import org.deeplearning4j.nn.conf.ocnn.OCNNOutputLayer;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationCube;
import org.nd4j.linalg.activations.impl.ActivationELU;
import org.nd4j.linalg.activations.impl.ActivationHardTanH;
import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.activations.impl.ActivationLReLU;
import org.nd4j.linalg.activations.impl.ActivationPReLU;
import org.nd4j.linalg.activations.impl.ActivationRReLU;
import org.nd4j.linalg.activations.impl.ActivationRationalTanh;
import org.nd4j.linalg.activations.impl.ActivationReLU;
import org.nd4j.linalg.activations.impl.ActivationReLU6;
import org.nd4j.linalg.activations.impl.ActivationSELU;
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
import org.nd4j.linalg.activations.impl.ActivationSoftPlus;
import org.nd4j.linalg.activations.impl.ActivationSoftSign;
import org.nd4j.linalg.activations.impl.ActivationSoftmax;
import org.nd4j.linalg.activations.impl.ActivationSwish;
import org.nd4j.linalg.activations.impl.ActivationTanH;
import org.nd4j.linalg.activations.impl.ActivationThresholdedReLU;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT;
import org.nd4j.linalg.lossfunctions.impl.LossMCXENT;

/**
 * Static helpers that sanity-check the configuration of output and loss layers (activation function,
 * loss function and nOut) and reject combinations that are very likely to be configuration mistakes.
 *
 * <p>All checks signal a problem by throwing {@link DL4JInvalidConfigException}.
 */
public final class OutputLayerUtil {
    /**
     * Activation function classes whose output is not bounded to the range 0 to 1. Membership is tested
     * by exact runtime class (see {@link #activationExceedsZeroOneRange}), so subclasses are not matched.
     * The set is unmodifiable.
     */
    private static final Set<Class<?>> OUTSIDE_ZERO_ONE_RANGE = createOutsideZeroOneRangeSet();

    /** Suffix appended to every validation error message, explaining how the check can be disabled. */
    private static final String COMMON_MSG = "\nThis configuration validation check can be disabled for MultiLayerConfiguration and ComputationGraphConfiguration using validateOutputLayerConfig(false), however this is not recommended.";

    private OutputLayerUtil() {
        // Static utility class: not instantiable.
    }

    /**
     * Validates an output or loss layer configuration.
     *
     * <p>Layers that are a {@link BaseOutputLayer} (excluding {@link OCNNOutputLayer}), a {@link LossLayer},
     * a {@link RnnLossLayer} or a {@link CnnLossLayer} have their activation function, loss function and nOut
     * checked via {@link #validateOutputLayerConfiguration}. Any other layer type (including {@code null}) is
     * ignored.
     *
     * @param layerName name of the layer; used only in exception messages
     * @param layer     layer configuration to validate
     * @throws DL4JInvalidConfigException if the layer's configuration is invalid
     */
    public static void validateOutputLayer(String layerName, Layer layer) {
        IActivation activation;
        ILossFunction loss;
        long nOut;
        boolean isLossLayer;
        // OCNNOutputLayer is deliberately excluded from validation.
        if (layer instanceof BaseOutputLayer && !(layer instanceof OCNNOutputLayer)) {
            BaseOutputLayer outputLayer = (BaseOutputLayer) layer;
            activation = outputLayer.getActivationFn();
            loss = outputLayer.getLossFn();
            nOut = outputLayer.getNOut();
            isLossLayer = false;
        } else if (layer instanceof LossLayer) {
            LossLayer lossLayer = (LossLayer) layer;
            activation = lossLayer.getActivationFn();
            loss = lossLayer.getLossFn();
            nOut = lossLayer.getNOut();
            isLossLayer = true;
        } else if (layer instanceof RnnLossLayer) {
            RnnLossLayer rnnLossLayer = (RnnLossLayer) layer;
            activation = rnnLossLayer.getActivationFn();
            loss = rnnLossLayer.getLossFn();
            nOut = rnnLossLayer.getNOut();
            isLossLayer = true;
        } else if (layer instanceof CnnLossLayer) {
            CnnLossLayer cnnLossLayer = (CnnLossLayer) layer;
            activation = cnnLossLayer.getActivationFn();
            loss = cnnLossLayer.getLossFn();
            nOut = cnnLossLayer.getNOut();
            isLossLayer = true;
        } else {
            // Not an output/loss layer: nothing to validate.
            return;
        }
        validateOutputLayerConfiguration(layerName, nOut, isLossLayer, activation, loss);
    }

    /**
     * Validates an output layer configuration given its individual components.
     *
     * <p>Rejected combinations, checked in this order:
     * <ul>
     *   <li>softmax activation with {@code nOut == 1} on a non-loss layer;</li>
     *   <li>a loss function expecting probabilities ({@link LossMCXENT}, {@link LossBinaryXENT}) with an
     *       activation that is not bounded to 0..1;</li>
     *   <li>softmax activation with {@link LossBinaryXENT};</li>
     *   <li>sigmoid activation with {@link LossMCXENT}.</li>
     * </ul>
     *
     * @param layerName    name of the layer; used only in exception messages
     * @param nOut         number of outputs of the layer
     * @param isLossLayer  whether the layer is a loss layer (LossLayer, RnnLossLayer or CnnLossLayer)
     * @param activation   activation function of the layer; dereferenced when {@code lossFunction} expects
     *                     probabilities, so it must then be non-null
     * @param lossFunction loss function of the layer
     * @throws DL4JInvalidConfigException if the combination is invalid
     */
    public static void validateOutputLayerConfiguration(String layerName, long nOut, boolean isLossLayer, IActivation activation, ILossFunction lossFunction) {
        // The nOut=1 softmax check is skipped for loss layers.
        if (!isLossLayer && nOut == 1L && activation instanceof ActivationSoftmax) {
            throw invalidConfig(layerName, "Softmax + nOut=1 networks are not supported. Softmax cannot be used with nOut=1 as the output will always be exactly 1.0 regardless of the input.");
        }
        if (lossFunctionExpectsProbability(lossFunction) && activationExceedsZeroOneRange(activation, isLossLayer)) {
            throw invalidConfig(layerName, "loss function " + lossFunction + " expects activations to be in the range 0 to 1 (probabilities) but activation function " + activation + " does not bound values to this 0 to 1 range. This indicates a likely invalid network configuration.");
        }
        if (activation instanceof ActivationSoftmax && lossFunction instanceof LossBinaryXENT) {
            throw invalidConfig(layerName, "softmax activation function in combination with LossBinaryXENT (binary cross entropy loss function). For multi-class classification, use softmax + MCXENT (multi-class cross entropy); for binary multi-label classification, use sigmoid + XENT.");
        }
        if (activation instanceof ActivationSigmoid && lossFunction instanceof LossMCXENT) {
            throw invalidConfig(layerName, "sigmoid activation function in combination with LossMCXENT (multi-class cross entropy loss function). For multi-class classification, use softmax + MCXENT (multi-class cross entropy); for binary multi-label classification, use sigmoid + XENT.");
        }
    }

    /**
     * Returns whether the loss function expects its input activations to be probabilities (range 0 to 1).
     *
     * @param lf loss function to check; {@code null} yields {@code false}
     * @return {@code true} for {@link LossMCXENT} and {@link LossBinaryXENT} (including subclasses)
     */
    public static boolean lossFunctionExpectsProbability(ILossFunction lf) {
        return lf instanceof LossMCXENT || lf instanceof LossBinaryXENT;
    }

    /**
     * Returns whether the activation function may produce values outside the range 0 to 1.
     *
     * <p>Only activation classes listed in {@link #OUTSIDE_ZERO_ONE_RANGE} (exact class match) are reported;
     * any other activation is assumed to be bounded. As a special case, identity activation on a loss layer
     * is not reported.
     *
     * @param activation  activation function to check; must not be {@code null}
     * @param isLossLayer whether the activation belongs to a loss layer
     * @return {@code true} if the activation is considered unbounded with respect to 0..1
     */
    public static boolean activationExceedsZeroOneRange(IActivation activation, boolean isLossLayer) {
        if (!OUTSIDE_ZERO_ONE_RANGE.contains(activation.getClass())) {
            return false;
        }
        // Identity on a loss layer is exempt. Presumably this is because the loss layer passes through the
        // output of a preceding layer that may already be bounded (e.g. softmax); this can miss some invalid
        // configurations. NOTE(review): confirm intent.
        return !(isLossLayer && activation instanceof ActivationIdentity);
    }

    /** Builds a configuration exception with the standard prefix and the common suffix message. */
    private static DL4JInvalidConfigException invalidConfig(String layerName, String detail) {
        return new DL4JInvalidConfigException("Invalid output layer configuration for layer \"" + layerName + "\": " + detail + " " + COMMON_MSG);
    }

    /** Creates the unmodifiable set of activation classes whose output is not bounded to 0..1. */
    private static Set<Class<?>> createOutsideZeroOneRangeSet() {
        Set<Class<?>> classes = new HashSet<>();
        classes.add(ActivationCube.class);
        classes.add(ActivationELU.class);
        classes.add(ActivationHardTanH.class);
        classes.add(ActivationIdentity.class);
        classes.add(ActivationLReLU.class);
        classes.add(ActivationPReLU.class);
        classes.add(ActivationRationalTanh.class);
        classes.add(ActivationReLU.class);
        classes.add(ActivationReLU6.class);
        classes.add(ActivationRReLU.class);
        classes.add(ActivationSELU.class);
        classes.add(ActivationSoftPlus.class);
        classes.add(ActivationSoftSign.class);
        classes.add(ActivationSwish.class);
        classes.add(ActivationTanH.class);
        classes.add(ActivationThresholdedReLU.class);
        return Collections.unmodifiableSet(classes);
    }
}

