/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.pooling;

import java.util.Arrays;
import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.MaskState;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.PoolingType;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.AbstractLayer;
import org.deeplearning4j.nn.workspace.ArrayType;
import org.deeplearning4j.nn.workspace.LayerWorkspaceMgr;
import org.deeplearning4j.util.MaskedReductionUtil;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastCopyOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;
import org.nd4j.linalg.primitives.Pair;

public class GlobalPoolingLayer
extends AbstractLayer<org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer> {
    private static final int[] DEFAULT_TIMESERIES_POOL_DIMS = new int[]{2};
    private static final int[] DEFAULT_CNN_POOL_DIMS = new int[]{2, 3};
    private static final int[] DEFAULT_CNN3D_POOL_DIMS = new int[]{2, 3, 4};
    private final int[] poolingDimensions;
    private final PoolingType poolingType;
    private final int pNorm;

    public GlobalPoolingLayer(NeuralNetConfiguration conf, DataType dataType) {
        super(conf, dataType);
        org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer layerConf = (org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer)conf.getLayer();
        this.poolingDimensions = layerConf.getPoolingDimensions();
        this.poolingType = layerConf.getPoolingType();
        this.pNorm = layerConf.getPnorm();
    }

    @Override
    public boolean isPretrainLayer() {
        return false;
    }

    @Override
    public void clearNoiseWeightParams() {
    }

    @Override
    public Layer.Type type() {
        return Layer.Type.SUBSAMPLING;
    }

    @Override
    public INDArray activate(boolean training, LayerWorkspaceMgr workspaceMgr) {
        INDArray reduced2d;
        int[] poolDim;
        this.assertInputSet(false);
        if (this.input.rank() == 3) {
            poolDim = this.poolingDimensions == null ? DEFAULT_TIMESERIES_POOL_DIMS : this.poolingDimensions;
        } else if (this.input.rank() == 4) {
            poolDim = this.poolingDimensions == null ? DEFAULT_CNN_POOL_DIMS : this.poolingDimensions;
        } else if (this.input.rank() == 5) {
            poolDim = this.poolingDimensions == null ? DEFAULT_CNN3D_POOL_DIMS : this.poolingDimensions;
        } else {
            throw new UnsupportedOperationException("Received rank " + this.input.rank() + " input (shape = " + Arrays.toString(this.input.shape()) + "). Only rank 3 (time series), rank 4 (images/CNN data) and rank 5 (volumetric / CNN3D data)  are currently supported for global pooling " + this.layerId());
        }
        if (this.maskArray == null) {
            reduced2d = this.activateHelperFullArray(this.input, poolDim);
        } else if (this.input.rank() == 3) {
            reduced2d = MaskedReductionUtil.maskedPoolingTimeSeries(this.poolingType, this.input, this.maskArray, this.pNorm, this.dataType);
        } else if (this.input.rank() == 4) {
            if (this.maskArray.rank() != 4) {
                throw new UnsupportedOperationException("Only 4d mask arrays are currently supported for masked global reductions on CNN data. Got 4d activations array (shape " + Arrays.toString(this.input.shape()) + ") and " + this.maskArray.rank() + "d mask array (shape " + Arrays.toString(this.maskArray.shape()) + ")  - when used in conjunction with input data of shape [batch,channels,h,w]=" + Arrays.toString(this.input.shape()) + " 4d masks should have shape [batchSize,1,h,1] or [batchSize,1,w,1] or [batchSize,1,h,w]" + this.layerId());
            }
            reduced2d = MaskedReductionUtil.maskedPoolingConvolution(this.poolingType, this.input, this.maskArray, this.pNorm, this.dataType);
        } else {
            throw new UnsupportedOperationException("Invalid input: is rank " + this.input.rank() + " " + this.layerId());
        }
        if (((org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer)this.layerConf()).isCollapseDimensions()) {
            return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, reduced2d);
        }
        long[] inputShape = this.input.shape();
        if (this.input.rank() == 3) {
            return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, reduced2d.reshape(reduced2d.ordering(), new long[]{inputShape[0], inputShape[1], 1L}));
        }
        if (this.input.rank() == 4) {
            return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, reduced2d.reshape(reduced2d.ordering(), new long[]{inputShape[0], inputShape[1], 1L, 1L}));
        }
        return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, reduced2d.reshape(reduced2d.ordering(), new long[]{inputShape[0], inputShape[1], 1L, 1L, 1L}));
    }

    public Layer clone() {
        return new GlobalPoolingLayer(this.conf, this.dataType);
    }

    private INDArray activateHelperFullArray(INDArray inputArray, int[] poolDim) {
        switch (this.poolingType) {
            case MAX: {
                return inputArray.max(poolDim);
            }
            case AVG: {
                return inputArray.mean(poolDim);
            }
            case SUM: {
                return inputArray.sum(poolDim);
            }
            case PNORM: {
                int pnorm = ((org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer)this.layerConf()).getPnorm();
                INDArray abs = Transforms.abs((INDArray)inputArray, (boolean)true);
                Transforms.pow((INDArray)abs, (Number)pnorm, (boolean)false);
                INDArray pNorm = abs.sum(poolDim);
                return Transforms.pow((INDArray)pNorm, (Number)(1.0 / (double)pnorm), (boolean)false);
            }
        }
        throw new RuntimeException("Unknown or not supported pooling type: " + (Object)((Object)this.poolingType) + " " + this.layerId());
    }

    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon, LayerWorkspaceMgr workspaceMgr) {
        INDArray epsilonNd;
        this.assertInputSet(true);
        if (!((org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer)this.layerConf()).isCollapseDimensions() && epsilon.rank() != 2) {
            long[] origShape = epsilon.shape();
            epsilon = epsilon.reshape(epsilon.ordering(), new long[]{origShape[0], origShape[1]});
        }
        INDArray input = this.input.castTo(this.dataType);
        DefaultGradient retGradient = new DefaultGradient();
        int[] poolDim = null;
        if (input.rank() == 3) {
            poolDim = this.poolingDimensions == null ? DEFAULT_TIMESERIES_POOL_DIMS : this.poolingDimensions;
        } else if (input.rank() == 4) {
            poolDim = this.poolingDimensions == null ? DEFAULT_CNN_POOL_DIMS : this.poolingDimensions;
        } else if (input.rank() == 5) {
            poolDim = this.poolingDimensions == null ? DEFAULT_CNN3D_POOL_DIMS : this.poolingDimensions;
        }
        if (this.maskArray == null) {
            epsilonNd = this.epsilonHelperFullArray(input, epsilon, poolDim);
        } else if (input.rank() == 3) {
            epsilonNd = MaskedReductionUtil.maskedPoolingEpsilonTimeSeries(this.poolingType, input, this.maskArray, epsilon, this.pNorm);
        } else if (input.rank() == 4) {
            epsilonNd = MaskedReductionUtil.maskedPoolingEpsilonCnn(this.poolingType, input, this.maskArray, epsilon, this.pNorm, this.dataType);
        } else {
            throw new UnsupportedOperationException(this.layerId());
        }
        epsilonNd = workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, epsilonNd);
        return new Pair((Object)retGradient, (Object)epsilonNd);
    }

    private INDArray epsilonHelperFullArray(INDArray inputArray, INDArray epsilon, int[] poolDim) {
        int[] broadcastDims = new int[inputArray.rank() - poolDim.length];
        int count = 0;
        for (int i = 0; i < inputArray.rank(); ++i) {
            if (ArrayUtils.contains((int[])poolDim, (int)i)) continue;
            broadcastDims[count++] = i;
        }
        switch (this.poolingType) {
            case MAX: {
                INDArray isMax = Nd4j.getExecutioner().exec((Op)new IsMax(inputArray.dup(), poolDim));
                return Nd4j.getExecutioner().exec((BroadcastOp)new BroadcastMulOp(isMax, epsilon, isMax, broadcastDims));
            }
            case AVG: {
                int n = 1;
                for (int d : poolDim) {
                    n = (int)((long)n * inputArray.size(d));
                }
                INDArray ret = inputArray.ulike();
                Nd4j.getExecutioner().exec((BroadcastOp)new BroadcastCopyOp(ret, epsilon, ret, broadcastDims));
                ret.divi((Number)n);
                return ret;
            }
            case SUM: {
                INDArray retSum = inputArray.ulike();
                Nd4j.getExecutioner().exec((BroadcastOp)new BroadcastCopyOp(retSum, epsilon, retSum, broadcastDims));
                return retSum;
            }
            case PNORM: {
                INDArray numerator;
                int pnorm = ((org.deeplearning4j.nn.conf.layers.GlobalPoolingLayer)this.layerConf()).getPnorm();
                INDArray abs = Transforms.abs((INDArray)inputArray, (boolean)true);
                Transforms.pow((INDArray)abs, (Number)pnorm, (boolean)false);
                INDArray pNorm = Transforms.pow((INDArray)abs.sum(poolDim), (Number)(1.0 / (double)pnorm));
                if (pnorm == 2) {
                    numerator = inputArray.dup();
                } else {
                    INDArray absp2 = Transforms.pow((INDArray)Transforms.abs((INDArray)inputArray, (boolean)true), (Number)(pnorm - 2), (boolean)false);
                    numerator = inputArray.mul(absp2);
                }
                INDArray denom = Transforms.pow((INDArray)pNorm, (Number)(pnorm - 1), (boolean)false);
                denom.rdivi(epsilon);
                Nd4j.getExecutioner().execAndReturn((BroadcastOp)new BroadcastMulOp(numerator, denom, numerator, broadcastDims));
                return numerator;
            }
        }
        throw new RuntimeException("Unknown or not supported pooling type: " + (Object)((Object)this.poolingType) + " " + this.layerId());
    }

    @Override
    public Pair<INDArray, MaskState> feedForwardMaskArray(INDArray maskArray, MaskState currentMaskState, int minibatchSize) {
        this.maskArray = maskArray;
        this.maskState = null;
        return null;
    }
}

