/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers.normalization;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.conf.layers.ConvolutionLayer;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.layers.BaseLayer;
import org.deeplearning4j.optimize.api.ConvexOptimizer;
import org.deeplearning4j.optimize.api.IterationListener;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.BroadcastOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastAddOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastDivOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp;
import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastSubOp;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.ops.transforms.Transforms;

/**
 * Batch normalization layer for rank-3 or rank-4 activations.
 *
 * <p>The input is viewed as {@code [batch, channels, rest]} (see {@link #getShape(INDArray)}),
 * where {@code channels} is the length of the {@code "gamma"} parameter. Mean and variance are
 * computed per channel over axes 0 and 2, the normalized activations are scaled by
 * {@code "gamma"} and shifted by {@code "beta"}, and running estimates are kept in the
 * {@code "avgMean"} / {@code "avgVar"} parameters.
 *
 * <p>NOTE(review): this layer reports zero parameters ({@link #numParams()}) and returns an empty
 * {@link Gradient} from {@link #backpropGradient(INDArray)}. Gamma's averaged gradient is only
 * accumulated into the {@code "gammaGradient"} entry of the parameter table, and beta's gradient
 * is not stored anywhere. Confirm how (and whether) these are meant to be consumed.
 */
public class BatchNormalization
extends BaseLayer<ConvolutionLayer> {
    /** sqrt(variance + eps) per channel, cached by preOutput for backprop. */
    private INDArray std;
    /** Configuration returned by {@link #conf()}. */
    private NeuralNetConfiguration conf;
    private int index = 0;
    private List<IterationListener> listeners = new ArrayList<IterationListener>();
    private Map<String, INDArray> params = new LinkedHashMap<String, INDArray>();
    /** {@code [batch, channels, rest]} view of the last input, cached by preOutput. */
    private int[] shape;
    private Gradient gradient;
    /** Normalized input (x - mean) / std in the {@link #shape} layout, cached for backprop. */
    private INDArray xHat;

    public BatchNormalization(NeuralNetConfiguration conf) {
        super(conf);
        // conf() returns this class's own field, so set it here; otherwise conf() is null
        // (and preOutput fails) until setConf() is called.
        this.conf = conf;
    }

    @Override
    public double calcL2() {
        return 0.0;
    }

    @Override
    public double calcL1() {
        return 0.0;
    }

    @Override
    public Layer.Type type() {
        return Layer.Type.CONVOLUTIONAL;
    }

    @Override
    public Gradient error(INDArray input) {
        return null;
    }

    @Override
    public INDArray derivativeActivation(INDArray input) {
        return null;
    }

    @Override
    public Gradient calcGradient(Gradient layerError, INDArray indArray) {
        return null;
    }

    /**
     * Backpropagates through the normalization using the statistics cached by the most recent
     * {@link #preOutput(INDArray, Layer.TrainingMode)} call.
     *
     * <p>With {@code m = batch * rest} elements per channel, the input gradient is
     * {@code dx = gamma / (m * std) * (m * dy - sum(dy) - xHat * sum(dy * xHat))}, where the sums
     * run over axes {0, 2}. {@code sum(dy * xHat) / m} is added to the {@code "gammaGradient"}
     * parameter entry.
     *
     * @param epsilon gradient of the loss with respect to this layer's output
     * @return an empty gradient paired with the gradient with respect to the input, in
     *     {@code epsilon}'s shape
     */
    @Override
    public Pair<Gradient, INDArray> backpropGradient(INDArray epsilon) {
        INDArray dy = epsilon.reshape(this.shape);
        INDArray normalized = this.xHat.reshape(this.shape);
        int m = this.shape[0] * this.shape[2];

        // Per-channel reductions over the batch and trailing axes.
        INDArray dBeta = dy.sum(new int[]{0, 2});
        INDArray dGamma = dy.mul(normalized).sum(new int[]{0, 2});

        // Accumulate only gamma's (batch-averaged) gradient; beta's gradient must not leak in.
        INDArray gammaGradient = this.getParam("gammaGradient");
        Nd4j.getExecutioner().execAndReturn((BroadcastOp)new BroadcastAddOp(gammaGradient, dGamma.div((Number)m), gammaGradient, 1));

        // dx = gamma / (m * std) * (m * dy - dBeta - xHat * dGamma), broadcast along channels (axis 1).
        INDArray dx = dy.mul((Number)m);
        Nd4j.getExecutioner().execAndReturn((BroadcastOp)new BroadcastSubOp(dx, dBeta, dx, 1));
        INDArray xHatTimesDGamma = Nd4j.getExecutioner().execAndReturn((BroadcastOp)new BroadcastMulOp(normalized, dGamma, normalized.dup(), 1));
        dx.subi(xHatTimesDGamma);
        INDArray coefficients = this.getParam("gamma").div(this.std).divi((Number)m);
        Nd4j.getExecutioner().execAndReturn((BroadcastOp)new BroadcastMulOp(dx, coefficients, dx, 1));

        DefaultGradient g = new DefaultGradient();
        this.gradient = g;
        // Hand the previous layer a gradient shaped like the one it received from us.
        return new Pair<Gradient, INDArray>(g, dx.reshape(epsilon.shape()));
    }

    @Override
    public void merge(Layer layer, int batchSize) {
    }

    @Override
    public INDArray activationMean() {
        return null;
    }

    @Override
    public void update(Gradient gradient) {
    }

    @Override
    public void fit() {
    }

    @Override
    public void update(INDArray gradient, String paramType) {
    }

    @Override
    public double score() {
        return 0.0;
    }

    @Override
    public void computeGradientAndScore() {
    }

    @Override
    public void accumulateScore(double accum) {
    }

    @Override
    public INDArray params() {
        return Nd4j.create((int)0);
    }

    @Override
    public int numParams() {
        return 0;
    }

    @Override
    public void setParams(INDArray params) {
    }

    @Override
    public void fit(INDArray data) {
    }

    @Override
    public void iterate(INDArray input) {
    }

    @Override
    public Gradient gradient() {
        return this.gradient;
    }

    @Override
    public Pair<Gradient, Double> gradientAndScore() {
        return new Pair<Gradient, Double>(this.gradient(), this.score());
    }

    @Override
    public int batchSize() {
        return 0;
    }

    @Override
    public NeuralNetConfiguration conf() {
        return this.conf;
    }

    @Override
    public void setConf(NeuralNetConfiguration conf) {
        this.conf = conf;
    }

    @Override
    public INDArray input() {
        return null;
    }

    @Override
    public void validateInput() {
    }

    @Override
    public ConvexOptimizer getOptimizer() {
        return null;
    }

    @Override
    public INDArray getParam(String param) {
        return this.params.get(param);
    }

    @Override
    public void initParams() {
    }

    @Override
    public Map<String, INDArray> paramTable() {
        return this.params;
    }

    @Override
    public void setParamTable(Map<String, INDArray> paramTable) {
        this.params = paramTable;
    }

    @Override
    public void setParam(String key, INDArray val) {
        this.params.put(key, val);
    }

    @Override
    public void clear() {
    }

    @Override
    public INDArray preOutput(INDArray x) {
        return this.preOutput(x, Layer.TrainingMode.TRAIN);
    }

    /**
     * Normalizes {@code x} per channel and applies {@code gamma * xHat + beta}.
     *
     * <p>Outside TEST mode, and when the layer config's {@code useBatchMean} flag is false, the
     * statistics come from this batch and the running {@code "avgMean"} / {@code "avgVar"}
     * parameters are updated. Otherwise the running parameters are used as-is. The input
     * array is not modified.
     *
     * @param x rank-3 or rank-4 input
     * @param training training mode
     * @return the normalized activations, in {@code x}'s shape
     */
    @Override
    public INDArray preOutput(INDArray x, Layer.TrainingMode training) {
        int[] activationShape = this.getShape(x);
        org.deeplearning4j.nn.conf.layers.BatchNormalization layerConf = (org.deeplearning4j.nn.conf.layers.BatchNormalization)this.conf().getLayer();
        this.shape = activationShape;
        boolean useBatchStatistics = training != Layer.TrainingMode.TEST && !layerConf.isUseBatchMean();

        // View the input as [batch, channels, rest] so that reducing over axes {0, 2} yields
        // per-channel statistics for rank-4 input as well as rank-3 input.
        INDArray input3d = x.reshape(activationShape);

        INDArray mean;
        INDArray var;
        if (useBatchStatistics) {
            mean = input3d.mean(new int[]{0, 2});
            var = input3d.var(new int[]{0, 2});
            var.addi((Number)layerConf.getEps());
        } else {
            mean = this.getParam("avgMean");
            var = this.getParam("avgVar");
        }
        this.std = Transforms.sqrt(var);

        // Per-channel vectors broadcast along axis 1 (channels). The subtraction targets a copy so
        // the caller's input array is never overwritten.
        INDArray xMu = Nd4j.getExecutioner().execAndReturn((BroadcastOp)new BroadcastSubOp(input3d, mean, input3d.dup(), 1));
        this.xHat = Nd4j.getExecutioner().execAndReturn((BroadcastOp)new BroadcastDivOp(xMu, this.std, xMu, 1));

        // y = gamma * xHat + beta
        INDArray gamma = this.getParam("gamma");
        INDArray beta = this.getParam("beta");
        INDArray out = Nd4j.getExecutioner().execAndReturn((BroadcastOp)new BroadcastMulOp(this.xHat, gamma, this.xHat.dup(), 1));
        out = Nd4j.getExecutioner().execAndReturn((BroadcastOp)new BroadcastAddOp(out, beta, out, 1));

        if (useBatchStatistics) {
            this.updateRunningStatistics(layerConf, activationShape, mean, var);
        }
        return out.reshape(x.shape());
    }

    /**
     * Folds this batch's statistics into the running "avgMean" / "avgVar" parameters as
     * {@code avg = decay * avg + (1 - decay) * batchStat}. The variance term is additionally
     * scaled by {@code m / max(m - 1, 1)}.
     */
    private void updateRunningStatistics(org.deeplearning4j.nn.conf.layers.BatchNormalization layerConf, int[] activationShape, INDArray mean, INDArray var) {
        double decay;
        if (layerConf.isFinetune()) {
            // Cumulative average over the N batches seen so far: the previous average keeps
            // weight (N-1)/N and the newest batch gets 1/N. Hence decay = 1 - 1/N, not 1/N.
            layerConf.setN(layerConf.getN() + 1);
            decay = 1.0 - 1.0 / (double)layerConf.getN();
        } else {
            decay = layerConf.getDecay();
        }
        int m = activationShape[0] * activationShape[2];
        // NOTE(review): m/(m-1) is a biased->unbiased variance correction. If ND4J's var() is
        // already bias-corrected, this corrects twice. Confirm.
        double adjust = (double)m / Math.max((double)m - 1.0, 1.0);
        this.getParam("avgMean").muli((Number)decay);
        this.getParam("avgMean").addi(mean.mul((Number)(1.0 - decay)));
        this.getParam("avgVar").muli((Number)decay);
        this.getParam("avgVar").addi(var.mul((Number)((1.0 - decay) * adjust)));
    }

    @Override
    public int numParams(boolean backwards) {
        return 0;
    }

    @Override
    public INDArray activate(Layer.TrainingMode training) {
        throw new UnsupportedOperationException();
    }

    @Override
    public INDArray activate(INDArray input, Layer.TrainingMode training) {
        return this.preOutput(input, training);
    }

    @Override
    public INDArray preOutput(INDArray x, boolean training) {
        return this.preOutput(x, training ? Layer.TrainingMode.TRAIN : Layer.TrainingMode.TEST);
    }

    @Override
    public INDArray activate(boolean training) {
        throw new UnsupportedOperationException();
    }

    @Override
    public INDArray activate(INDArray input, boolean training) {
        return this.preOutput(input, training);
    }

    @Override
    public INDArray activate() {
        throw new UnsupportedOperationException();
    }

    @Override
    public INDArray activate(INDArray input) {
        throw new UnsupportedOperationException();
    }

    @Override
    public Layer transpose() {
        throw new UnsupportedOperationException();
    }

    @Override
    public Layer clone() {
        throw new UnsupportedOperationException();
    }

    @Override
    public Collection<IterationListener> getListeners() {
        return this.listeners;
    }

    @Override
    public void setListeners(IterationListener ... listeners) {
        this.listeners = new ArrayList<IterationListener>(Arrays.asList(listeners));
    }

    @Override
    public void setListeners(Collection<IterationListener> listeners) {
        this.listeners = new ArrayList<IterationListener>(listeners);
    }

    @Override
    public void setIndex(int index) {
        this.index = index;
    }

    @Override
    public int getIndex() {
        return this.index;
    }

    @Override
    public void setInput(INDArray input) {
    }

    @Override
    public void setInputMiniBatchSize(int size) {
    }

    @Override
    public int getInputMiniBatchSize() {
        return 0;
    }

    /**
     * Computes the {@code [batch, channels, rest]} view used for normalization.
     *
     * <p>{@code batch} is {@code x.size(0)}, {@code channels} is the length of the "gamma"
     * parameter, and {@code rest} is whatever remains of {@code x.length()} (at least 1).
     *
     * @param x rank-3 or rank-4 input
     * @return the three-element shape {@code {batch, channels, rest}}
     * @throws IllegalStateException if {@code x} is not rank 3 or 4
     * @throws IllegalArgumentException if {@code x.length()} is not {@code batch * channels * rest}
     */
    public int[] getShape(INDArray x) {
        if (x.rank() != 3 && x.rank() != 4) {
            throw new IllegalStateException("Unable to process input of rank " + x.rank());
        }
        // The batch axis is axis 0 for both ranks (rank 4 previously used size(1), i.e. channels).
        int leadDim = x.size(0);
        int cDim = this.getParam("gamma").length();
        int rdim = Math.max(1, (int)Math.round((double)x.length() / ((double)leadDim * (double)cDim)));
        // Multiply in long so a large input cannot overflow into a false match.
        if ((long)leadDim * cDim * rdim != x.length()) {
            throw new IllegalArgumentException("Illegal input for batch size");
        }
        return new int[]{leadDim, cDim, rdim};
    }
}

