/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn;

import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.OutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.distribution.RealDistribution;
import org.apache.commons.math3.random.MersenneTwister;
import org.apache.commons.math3.random.RandomGenerator;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.datasets.DataSet;
import org.deeplearning4j.gradient.NeuralNetworkGradientListener;
import org.deeplearning4j.gradient.multilayer.MultiLayerGradientListener;
import org.deeplearning4j.nn.HiddenLayer;
import org.deeplearning4j.nn.LogisticRegression;
import org.deeplearning4j.nn.NeuralNetwork;
import org.deeplearning4j.nn.Persistable;
import org.deeplearning4j.nn.activation.ActivationFunction;
import org.deeplearning4j.nn.activation.Sigmoid;
import org.deeplearning4j.nn.gradient.LogisticRegressionGradient;
import org.deeplearning4j.nn.gradient.MultiLayerGradient;
import org.deeplearning4j.nn.gradient.NeuralNetworkGradient;
import org.deeplearning4j.optimize.MultiLayerNetworkOptimizer;
import org.deeplearning4j.rng.SynchronizedRandomGenerator;
import org.deeplearning4j.transformation.MatrixTransform;
import org.deeplearning4j.util.MatrixUtil;
import org.deeplearning4j.util.SerializationUtils;
import org.jblas.DoubleMatrix;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public abstract class BaseMultiLayerNetwork
implements Serializable,
Persistable {
    /** Class-wide logger; {@code final} because it is never reassigned. */
    private static final Logger log = LoggerFactory.getLogger(BaseMultiLayerNetwork.class);
    private static final long serialVersionUID = -5029161847383716484L;
    /** Number of input columns; {@link #initializeLayers} rejects input of any other width. */
    private int nIns;
    /** Output size of each hidden layer, indexed by layer. */
    private int[] hiddenLayerSizes;
    /** Number of outputs of the logistic regression output layer. */
    private int nOuts;
    /** Number of hidden layers. */
    private int nLayers;
    /** Hidden layers used for forward propagation; indexed in parallel with {@link #layers}. */
    private HiddenLayer[] sigmoidLayers;
    /** Logistic regression output layer. */
    private LogisticRegression logLayer;
    private RandomGenerator rng;
    /** Weight initialization distribution; {@link #init()} defaults it to mean 0, sd 0.01 when unset. */
    private RealDistribution dist;
    private double momentum = 0.1;
    private DoubleMatrix input;
    private DoubleMatrix labels;
    private MultiLayerNetworkOptimizer optimizer;
    private ActivationFunction activation = new Sigmoid();
    private boolean toDecode;
    /** L2 coefficient, only used when {@link #useRegularization} is set. */
    private double l2 = 0.01;
    private boolean shouldInit = true;
    /** Negative means "unset": {@link #fanIn()} then falls back to {@code 1 / nIns}. */
    private double fanIn = -1.0;
    private int renderWeightsEveryNEpochs = -1;
    private boolean useRegularization = false;
    /** Per-layer-index transforms applied by {@link #applyTransforms()}. */
    private Map<Integer, MatrixTransform> weightTransforms = new HashMap<Integer, MatrixTransform>();
    private Map<Integer, MatrixTransform> hiddenBiasTransforms = new HashMap<Integer, MatrixTransform>();
    private Map<Integer, MatrixTransform> visibleBiasTransforms = new HashMap<Integer, MatrixTransform>();
    private boolean shouldBackProp = true;
    /** When true, {@link #backProp} runs exactly the requested epochs instead of stopping on convergence. */
    private boolean forceNumEpochs = false;
    /** When non-zero, back prop passes the hidden-bias delta through {@code MatrixUtil.scalarMinus}. */
    private double sparsity = 0.0;
    /** Optional per-column statistics used to normalize input before predict/reconstruct. */
    private DoubleMatrix columnSums;
    private DoubleMatrix columnMeans;
    private DoubleMatrix columnStds;
    /** Set once {@link #init()} completes; {@link #initializeLayers} then only feeds forward. */
    private boolean initCalled = false;
    /** True: forward prop uses hidden-layer activations; false: pretrain-layer hidden samples. */
    private boolean useHiddenActivationsForwardProp = true;
    /** When true, back prop scales updates by each layer's AdaGrad learning rates instead of lr. */
    private boolean useAdaGrad = false;
    public double learningRateUpdate = 0.95;
    /** Pretrain layers, one per hidden layer. */
    private NeuralNetwork[] layers;
    public double errorTolerance = 1.0E-4;
    protected Map<Integer, List<NeuralNetworkGradientListener>> gradientListeners = new HashMap<Integer, List<NeuralNetworkGradientListener>>();
    /** Notified with every gradient produced by {@link #getGradient(Object[])}. */
    protected List<MultiLayerGradientListener> multiLayerGradientListeners = new ArrayList<MultiLayerGradientListener>();
    protected double dropOut = 0.0;
    /** When true, gradients are divided by the number of input rows. */
    protected boolean normalizeByInputRows = false;
    protected NeuralNetwork.OptimizationAlgorithm optimizationAlgorithm;
    protected NeuralNetwork.LossFunction lossFunction;

    /** Creates an unconfigured network; callers are expected to populate fields before use. */
    protected BaseMultiLayerNetwork() {
    }

    /**
     * Creates a network without input or labels; layers are initialized lazily once input is set.
     *
     * @param nIns number of input columns
     * @param hiddenLayerSizes output size of each hidden layer; must have {@code nLayers} entries
     * @param nOuts number of outputs of the output layer
     * @param nLayers number of hidden layers
     * @param rng random generator, or {@code null} for a seeded default
     */
    protected BaseMultiLayerNetwork(int nIns, int[] hiddenLayerSizes, int nOuts, int nLayers, RandomGenerator rng) {
        this(nIns, hiddenLayerSizes, nOuts, nLayers, rng, null, null);
    }

    /**
     * Creates a network and, when {@code input} is non-null, initializes all layers from it.
     *
     * @param nIn number of input columns
     * @param hiddenLayerSizes output size of each hidden layer; must have {@code nLayers} entries
     * @param nOuts number of outputs of the output layer
     * @param nLayers number of hidden layers
     * @param rng random generator, or {@code null} for a seeded default
     * @param input training input (copied), may be {@code null}
     * @param labels training labels (copied), may be {@code null}
     * @throws IllegalArgumentException if {@code hiddenLayerSizes} is null or its length differs from {@code nLayers}
     */
    protected BaseMultiLayerNetwork(int nIn, int[] hiddenLayerSizes, int nOuts, int nLayers, RandomGenerator rng, DoubleMatrix input, DoubleMatrix labels) {
        if (hiddenLayerSizes == null || hiddenLayerSizes.length != nLayers) {
            throw new IllegalArgumentException("The number of hidden layer sizes must be equivalent to the nLayers argument which is a value of " + nLayers);
        }
        this.nIns = nIn;
        this.hiddenLayerSizes = hiddenLayerSizes;
        // The 5-arg constructor passes null for both; dup() unconditionally always threw NPE.
        this.input = input == null ? null : input.dup();
        this.labels = labels == null ? null : labels.dup();
        this.nOuts = nOuts;
        this.nLayers = nLayers;
        this.sigmoidLayers = new HiddenLayer[nLayers];
        this.layers = this.createNetworkLayers(nLayers);
        this.rng = rng == null ? new SynchronizedRandomGenerator((RandomGenerator)new MersenneTwister(123)) : rng;
        if (input != null) {
            this.initializeLayers(input);
        }
    }

    /**
     * Returns the effective fan-in used for weight initialization.
     *
     * @return the configured fan-in, or {@code 1 / nIns} when the configured value is negative (unset)
     */
    public double fanIn() {
        return this.fanIn < 0.0 ? 1.0 / (double)this.nIns : this.fanIn;
    }

    /**
     * Verifies that the network is wired consistently. Each hidden layer must match its pretrain
     * layer's weight and hidden-bias sizes. Consecutive layers must chain (out of layer i == in of
     * layer i + 1). The last hidden layer must feed the output layer.
     *
     * @throws IllegalStateException if any of the size relations does not hold
     */
    private void dimensionCheck() {
        for (int i = 0; i < this.nLayers; ++i) {
            HiddenLayer hidden = this.sigmoidLayers[i];
            NeuralNetwork pretrain = this.layers[i];
            hidden.getW().assertSameSize(pretrain.getW());
            hidden.getB().assertSameSize(pretrain.gethBias());
            boolean hasNext = i + 1 < this.nLayers;
            if (hasNext) {
                HiddenLayer nextHidden = this.sigmoidLayers[i + 1];
                NeuralNetwork nextPretrain = this.layers[i + 1];
                if (nextHidden.getnIn() != hidden.getnOut()) {
                    throw new IllegalStateException("Invalid structure: hidden layer in for " + (i + 1) + " not equal to number of ins " + i);
                }
                if (pretrain.getnHidden() != nextPretrain.getnVisible()) {
                    throw new IllegalStateException("Invalid structure: network hidden for " + (i + 1) + " not equal to number of visible " + i);
                }
            }
        }
        HiddenLayer lastHidden = this.sigmoidLayers[this.sigmoidLayers.length - 1];
        if (lastHidden.getnOut() != this.logLayer.getnIn()) {
            throw new IllegalStateException("Number of outputs for final hidden layer not equal to the number of logistic input units for output layer");
        }
    }

    /**
     * Wraps the network rng in one {@link SynchronizedRandomGenerator}. That single shared
     * instance is handed to every pretrain layer and every hidden layer. The output layer is not
     * touched.
     */
    public void synchonrizeRng() {
        SynchronizedRandomGenerator shared = new SynchronizedRandomGenerator(this.rng);
        int layer = 0;
        while (layer < this.nLayers) {
            this.layers[layer].setRng(shared);
            this.sigmoidLayers[layer].setRng(shared);
            layer++;
        }
    }

    /**
     * Calls {@code resetAdaGrad(lr)} on every pretrain layer and then on the output layer.
     *
     * @param lr learning rate forwarded to each layer
     */
    public void resetAdaGrad(double lr) {
        NeuralNetwork[] pretrain = this.layers;
        for (int idx = 0; idx < this.nLayers; idx++) {
            pretrain[idx].resetAdaGrad(lr);
        }
        this.logLayer.resetAdaGrad(lr);
    }

    /**
     * Returns the mean reconstruction cross entropy over all pretrain layers.
     *
     * @return average of each layer's {@code getReConstructionCrossEntropy()} (NaN when there are no layers)
     */
    public double getReconstructionCrossEntropy() {
        double total = 0.0;
        for (int layer = 0; layer < this.nLayers; layer++) {
            total += this.layers[layer].getReConstructionCrossEntropy();
        }
        return total / (double)this.nLayers;
    }

    /**
     * Turns this network into a decoder for {@code network}. The pretrain layers (with their rngs)
     * and hidden sizes are cloned in reverse order. Input and output sizes are swapped.
     * Distribution and rng are shared.
     *
     * <p>NOTE(review): {@code sigmoidLayers} is allocated here but never filled, the output layer
     * is left as-is, and the array returned by {@code createNetworkLayers} is discarded. Confirm
     * that callers populate these before feeding forward.
     *
     * @param network the network to mirror
     */
    public void asDecoder(BaseMultiLayerNetwork network) {
        this.createNetworkLayers(network.nLayers + 1);
        int depth = network.nLayers;
        this.layers = new NeuralNetwork[depth];
        this.sigmoidLayers = new HiddenLayer[depth];
        this.hiddenLayerSizes = new int[depth];
        this.nIns = network.nOuts;
        this.nOuts = network.nIns;
        this.nLayers = depth;
        this.dist = network.dist;
        for (int src = depth - 1, dst = 0; src >= 0; src--, dst++) {
            NeuralNetwork copy = network.layers[src].clone();
            copy.setRng(network.layers[src].getRng());
            this.layers[dst] = copy;
            this.hiddenLayerSizes[dst] = network.hiddenLayerSizes[src];
        }
        this.rng = network.rng;
        this.shouldInit = false;
    }

    /**
     * Validates {@code input}, stores a copy of it and then either builds all layers (first call)
     * or just feeds the input forward (once {@link #init()} has run).
     *
     * @param input input rows; must have exactly {@code nIns} columns
     * @throws IllegalArgumentException if {@code input} is null, has the wrong width, or any hidden size is below 1
     */
    public void initializeLayers(DoubleMatrix input) {
        if (input == null) {
            throw new IllegalArgumentException("Unable to initialize layers with empty input");
        }
        if (input.columns != this.nIns) {
            throw new IllegalArgumentException(String.format("Unable to train on number of inputs; columns should be equal to number of inputs. Number of inputs was %d while number of columns was %d", this.nIns, input.columns));
        }
        if (this.layers == null) {
            this.layers = new NeuralNetwork[this.nLayers];
        }
        for (int size : this.hiddenLayerSizes) {
            if (size < 1) {
                throw new IllegalArgumentException("All hidden layer sizes must be >= 1");
            }
        }
        this.input = input.dup();
        if (this.initCalled) {
            this.feedForward(input);
        } else {
            this.init();
        }
    }

    /**
     * Builds every layer from scratch. For each hidden layer it creates one {@link HiddenLayer}
     * and one pretrain {@link NeuralNetwork}, then it creates the logistic regression output layer.
     *
     * <p>When an input is present, layer {@code i > 0} is built from the output of layer
     * {@code i - 1}. That output is the hidden-layer activation or the pretrain-layer sample,
     * depending on {@code useHiddenActivationsForwardProp}, so construction order matters.
     * Afterwards the structure is validated and any registered weight/bias transforms are applied.
     *
     * @throws IllegalStateException if fewer than one layer is configured or the built layers are inconsistent
     */
    public void init() {
        if (this.nLayers < 1) {
            throw new IllegalStateException("Unable to create network layers; number specified is less than 1");
        }
        DoubleMatrix layerInput = this.input;
        // Networks created through the no-arg constructor may have neither an rng nor a sigmoid
        // layer array; fall back to the same defaults the full constructor uses instead of NPE-ing.
        if (this.rng == null) {
            this.rng = new SynchronizedRandomGenerator((RandomGenerator)new MersenneTwister(123));
        }
        if (this.dist == null) {
            this.dist = new NormalDistribution(this.rng, 0.0, 0.01, 1.0E-9);
        }
        if (this.sigmoidLayers == null || this.sigmoidLayers.length != this.nLayers) {
            this.sigmoidLayers = new HiddenLayer[this.nLayers];
        }
        this.layers = new NeuralNetwork[this.nLayers];
        for (int i = 0; i < this.nLayers; ++i) {
            int inputSize = i == 0 ? this.nIns : this.hiddenLayerSizes[i - 1];
            if (i > 0 && this.input != null) {
                // the previous layer's output becomes this layer's input
                layerInput = this.useHiddenActivationsForwardProp ? this.sigmoidLayers[i - 1].sampleHiddenGivenVisible() : this.getLayers()[i - 1].sampleHiddenGivenVisible(layerInput).getSecond();
            }
            this.sigmoidLayers[i] = this.createHiddenLayer(i, inputSize, this.hiddenLayerSizes[i], this.activation, this.rng, layerInput, this.dist);
            // the pretrain layer shares the hidden layer's weight matrix and bias
            this.layers[i] = this.createLayer(layerInput, inputSize, this.hiddenLayerSizes[i], this.sigmoidLayers[i].getW(), this.sigmoidLayers[i].getB(), null, this.rng, i);
        }
        this.logLayer = new LogisticRegression.Builder().useAdaGrad(this.useAdaGrad).normalizeByInputRows(this.normalizeByInputRows).useRegularization(this.useRegularization).numberOfInputs(this.hiddenLayerSizes[this.nLayers - 1]).numberOfOutputs(this.nOuts).withL2(this.l2).build();
        this.dimensionCheck();
        this.applyTransforms();
        this.initCalled = true;
    }

    /**
     * Copies this network's fan-in and weight-render interval onto {@code network}.
     *
     * @param network the layer to configure
     */
    protected void initializeNetwork(NeuralNetwork network) {
        double configuredFanIn = fanIn;
        network.setFanIn(configuredFanIn);
        network.setRenderEpochs(renderWeightsEveryNEpochs);
    }

    /**
     * Fine-tunes using the labels already stored on this network.
     *
     * @param lr learning rate
     * @param epochs number of epochs
     */
    public void finetune(double lr, int epochs) {
        DoubleMatrix storedLabels = labels;
        finetune(storedLabels, lr, epochs);
    }

    /**
     * Loads a data set. Its first element becomes the input and is fed forward; its second
     * element becomes the labels of this network and of the output layer.
     *
     * @param data features/labels pair
     */
    public void initialize(DataSet data) {
        DoubleMatrix features = (DoubleMatrix)data.getFirst();
        setInput(features);
        feedForward(features);
        labels = (DoubleMatrix)data.getSecond();
        logLayer.setLabels(labels);
    }

    /**
     * Computes the gradient of every pretrain layer and of the output layer, then notifies all
     * registered {@link MultiLayerGradientListener}s.
     *
     * @param params extra arguments forwarded unchanged to each layer's {@code getGradient}. When
     *     present, {@code params[1]} is the output layer's learning rate and may be any
     *     {@link Number}; it defaults to 0.01.
     * @return the per-layer gradients together with the output-layer gradient
     */
    public MultiLayerGradient getGradient(Object[] params) {
        ArrayList<NeuralNetworkGradient> gradient = new ArrayList<NeuralNetworkGradient>();
        for (NeuralNetwork network : this.layers) {
            gradient.add(network.getGradient(params));
        }
        double lr = 0.01;
        if (params != null && params.length >= 2) {
            // Accept Integer/Float/... too; a plain (Double) cast threw ClassCastException for them.
            lr = ((Number)params[1]).doubleValue();
        }
        // refresh every layer's input (including the output layer's) before asking for its gradient
        this.feedForward(this.input);
        LogisticRegressionGradient g2 = this.logLayer.getGradient(lr);
        MultiLayerGradient ret = new MultiLayerGradient(gradient, g2);
        if (this.multiLayerGradientListeners != null && !this.multiLayerGradientListeners.isEmpty()) {
            for (MultiLayerGradientListener listener : this.multiLayerGradientListeners) {
                listener.onMultiLayerGradient(ret);
            }
        }
        return ret;
    }

    /**
     * Propagates the current input through all hidden layers and the output layer.
     *
     * <p>Each pretrain layer and hidden layer has its input set to the activation coming from the
     * previous layer. The next activation comes from the hidden layer
     * ({@code useHiddenActivationsForwardProp}) or from the pretrain layer's hidden sample.
     *
     * @return {@code nLayers + 2} matrices: the input, each hidden activation, and the output prediction
     */
    public List<DoubleMatrix> feedForward() {
        DoubleMatrix currInput = this.input;
        ArrayList<DoubleMatrix> activations = new ArrayList<DoubleMatrix>();
        activations.add(currInput);
        for (int i = 0; i < this.nLayers; ++i) {
            this.getLayers()[i].setInput(currInput);
            // Previously this stored the raw network input on every hidden layer, so hidden
            // layer i > 0 kept a matrix of the wrong width as its input.
            this.getSigmoidLayers()[i].setInput(currInput);
            currInput = this.useHiddenActivationsForwardProp ? this.getSigmoidLayers()[i].activate(currInput) : this.getLayers()[i].sampleHiddenGivenVisible(currInput).getSecond();
            activations.add(currInput);
        }
        this.logLayer.setInput(currInput);
        activations.add(this.getLogLayer().predict(currInput));
        return activations;
    }

    /**
     * Stores {@code input} (not copied) as the network input and runs {@link #feedForward()}.
     *
     * @param input input rows
     * @return the activations of every layer
     * @throws IllegalStateException if {@code input} is null
     */
    public List<DoubleMatrix> feedForward(DoubleMatrix input) {
        if (input != null) {
            this.input = input;
            return this.feedForward();
        }
        throw new IllegalStateException("Unable to perform feed forward; no input found");
    }

    /**
     * Runs a forward pass on the current input and back-propagates the error from the output.
     *
     * <p>Appends {@code nLayers + 2} (gradient, delta) pairs to {@code deltaRet}, indexed like the
     * activations returned by {@link #feedForward()}: 0 = input, 1..nLayers = hidden layers,
     * nLayers + 1 = output. The gradient at index {@code nLayers + 1} is always {@code null};
     * only its delta is set.
     *
     * <p>NOTE(review): {@code gradients[i]} is {@code deltas[i + 1]^T * activations[i]}. If a
     * layer's W is (in x out), this has shape (out x in), i.e. the transpose of W. backPropStep
     * adds it to W directly, so confirm the orientation of W.
     * NOTE(review): when {@code normalizeByInputRows} is set, gradients are divided by the row
     * count here AND again in backPropStep. Confirm the double normalization is intended.
     *
     * @param deltaRet receives the (gradient, delta) pairs in layer order
     */
    private void computeDeltas(List<Pair<DoubleMatrix, DoubleMatrix>> deltaRet) {
        int i;
        DoubleMatrix[] gradients = new DoubleMatrix[this.nLayers + 2];
        DoubleMatrix[] deltas = new DoubleMatrix[this.nLayers + 2];
        // NOTE(review): the first hidden layer's activation derivative is applied to every layer,
        // including the output produced by the logistic layer. Confirm this is intended.
        ActivationFunction derivative = this.getSigmoidLayers()[0].getActivationFunction();
        DoubleMatrix delta = null;
        List<DoubleMatrix> activations = this.feedForward(this.getInput());
        // weights[0..nLayers-1] = pretrain layer weights, weights[nLayers] = output layer weights
        ArrayList<DoubleMatrix> weights = new ArrayList<DoubleMatrix>();
        for (int j = 0; j < this.getLayers().length; ++j) {
            weights.add(this.getLayers()[j].getW());
        }
        weights.add(this.getLogLayer().getW());
        DoubleMatrix labels = this.labels;
        for (i = this.nLayers + 1; i >= 0; --i) {
            if (i >= this.nLayers + 1) {
                // output delta: (prediction - labels) scaled by the activation derivative
                DoubleMatrix z = activations.get(i);
                deltas[i] = delta = labels.sub(z).neg().muli(derivative.applyDerivative(z));
                continue;
            }
            // propagate the next layer's delta back through that layer's weights
            delta = deltas[i + 1];
            DoubleMatrix w = ((DoubleMatrix)weights.get(i)).transpose();
            DoubleMatrix z = activations.get(i);
            DoubleMatrix zDerivative = derivative.applyDerivative(z);
            DoubleMatrix error = delta.mmul(w);
            error.muli(zDerivative);
            deltas[i] = error.dup();
            DoubleMatrix lastLayerDelta = deltas[i + 1].transpose();
            DoubleMatrix newGradient = lastLayerDelta.mmul(z);
            if (this.normalizeByInputRows) {
                newGradient.divi((double)this.getInput().rows);
            }
            gradients[i] = newGradient;
        }
        for (i = 0; i < gradients.length; ++i) {
            deltaRet.add(new Pair<DoubleMatrix, DoubleMatrix>(gradients[i], deltas[i]));
        }
    }

    /**
     * Creates a copy of this network. An empty instance of the same concrete class is built and
     * then filled via {@link #update(BaseMultiLayerNetwork)}. Pretrain, hidden and output layers
     * are cloned; other fields (input, labels, rng, transform maps, ...) are shared references.
     *
     * @return the copy
     */
    @Override
    public BaseMultiLayerNetwork clone() {
        // Typed as BaseMultiLayerNetwork: declaring it as Object did not compile against the return type.
        BaseMultiLayerNetwork ret = new Builder<BaseMultiLayerNetwork>().withClazz(this.getClass()).buildEmpty();
        ret.update(this);
        return ret;
    }

    /**
     * Fine-tunes the network with repeated back-propagation steps.
     *
     * <p>When {@code forceNumEpochs} is set, exactly {@code epochs} steps are run. Otherwise each
     * step is followed by output-layer training. A step is accepted when the negative log
     * likelihood decreases, and a snapshot of the accepted state is taken. Training stops when an
     * accepted improvement is smaller than {@code 1e-5}. A step that does not improve (including
     * a NaN likelihood) restores the last accepted snapshot; training stops after 3 such steps.
     *
     * @param lr learning rate for each step and for the output layer
     * @param epochs step count when forcing epochs; otherwise forwarded to {@code trainTillConvergence}
     */
    public void backProp(double lr, int epochs) {
        if (this.forceNumEpochs) {
            BaseMultiLayerNetwork revert = this.clone();
            for (int i = 0; i < epochs; ++i) {
                this.backPropStep(revert, lr, i);
            }
            return;
        }
        final int tolerance = 3;
        final double changeTolerance = 1.0E-5;
        // primitive double: the former boxed Double comparison (!=) compared references
        double lastEntropy = this.negativeLogLikelihood();
        BaseMultiLayerNetwork revert = this.clone();
        int count = 0;
        int numOver = 0;
        while (numOver < tolerance) {
            this.backPropStep(revert, lr, ++count);
            this.getLogLayer().trainTillConvergence(lr, epochs);
            double entropy = this.negativeLogLikelihood();
            if (entropy < lastEntropy) {
                double diff = Math.abs(entropy - lastEntropy);
                if (diff < changeTolerance) {
                    log.info("Not enough of a change on back prop...breaking");
                    break;
                }
                lastEntropy = entropy;
                log.info("New negative log likelihood " + lastEntropy);
                this.getLogLayer().trainTillConvergence(lr, epochs);
                // remember the accepted state so a later bad step reverts here, not to the start
                revert = this.clone();
                continue;
            }
            // No improvement (entropy >= lastEntropy, or NaN, which previously looped forever).
            this.update(revert);
            log.info("Last change no good...reverting");
            ++numOver;
        }
    }

    /**
     * Applies one back-propagation update in place. Each pretrain layer's weights and hidden bias
     * are updated from {@code computeDeltas}, and the result is mirrored onto the matching hidden
     * layer. Finally, the output layer's gradient is added to its weights.
     *
     * <p>NOTE(review): {@code revert} and {@code epoch} are not used in this body.
     * NOTE(review): with {@code useRegularization} the update is multiplied element-wise by
     * {@code l2 * W}. A conventional L2 penalty adds a term proportional to W instead. Confirm
     * the intent.
     * NOTE(review): with {@code normalizeByInputRows} the gradient was already divided by the row
     * count in computeDeltas and is divided again here.
     * NOTE(review): the output-layer update is neither scaled by {@code lr} nor sign-adjusted.
     *
     * @param revert snapshot of the network (unused here)
     * @param lr learning rate used when AdaGrad is disabled and for the hidden bias
     * @param epoch current step number (unused here)
     */
    protected void backPropStep(BaseMultiLayerNetwork revert, double lr, int epoch) {
        ArrayList<Pair<DoubleMatrix, DoubleMatrix>> deltas = new ArrayList<Pair<DoubleMatrix, DoubleMatrix>>();
        this.computeDeltas(deltas);
        for (int l = 0; l < this.nLayers; ++l) {
            // weight gradient of pretrain layer l (mutated in place below)
            DoubleMatrix add = (DoubleMatrix)((Pair)deltas.get(l)).getFirst();
            if (this.isUseAdaGrad()) {
                add.muli(this.getLayers()[l].getAdaGrad().getLearningRates(add));
            } else {
                add.muli(lr);
            }
            if (this.normalizeByInputRows) {
                add.divi((double)this.input.rows);
            }
            if (this.useRegularization) {
                add.muli(this.getLayers()[l].getW().mul(this.l2));
            }
            this.getLayers()[l].getW().addi(add);
            // keep the forward-prop hidden layer in sync with the updated pretrain weights
            this.getSigmoidLayers()[l].setW(this.layers[l].getW());
            // hidden bias update uses the delta at layer l's output (activation index l + 1)
            DoubleMatrix deltaColumnSums = ((DoubleMatrix)((Pair)deltas.get(l + 1)).getSecond()).columnSums();
            if (this.normalizeByInputRows) {
                deltaColumnSums.divi((double)this.input.rows);
            }
            if (this.sparsity != 0.0) {
                deltaColumnSums = MatrixUtil.scalarMinus(this.sparsity, deltaColumnSums);
            }
            this.getLayers()[l].gethBias().addi(deltaColumnSums.mul(lr));
            this.getSigmoidLayers()[l].setB(this.getLayers()[l].gethBias());
        }
        this.getLogLayer().getW().addi((DoubleMatrix)((Pair)deltas.get(this.nLayers)).getFirst());
    }

    /**
     * Fine-tunes the network with a fresh {@link MultiLayerNetworkOptimizer}.
     *
     * @param labels new labels, or {@code null} to keep the stored ones
     * @param lr learning rate
     * @param epochs number of epochs
     */
    public void finetune(DoubleMatrix labels, double lr, int epochs) {
        boolean replaceLabels = labels != null;
        if (replaceLabels) {
            this.labels = labels;
        }
        this.optimizer = new MultiLayerNetworkOptimizer(this, lr);
        this.optimizer.optimize(this.labels, lr, epochs);
    }

    /**
     * Predicts the output for {@code x}.
     *
     * <p>{@code x} is first normalized column-wise with any configured column sums, means and
     * standard deviations. The normalization works on a copy, so the caller's matrix is not
     * modified. The normalized rows are then fed forward.
     *
     * @param x input rows
     * @return the output layer's prediction (the last activation of the forward pass)
     * @throws IllegalStateException if {@code x} is null
     */
    public DoubleMatrix predict(DoubleMatrix x) {
        // Normalize BEFORE the forward pass. Previously the forward pass ran on the raw input and
        // the normalization was applied afterwards, so it never affected the prediction.
        DoubleMatrix normalized = this.normalizeColumns(x);
        List<DoubleMatrix> activations = this.feedForward(normalized);
        return activations.get(activations.size() - 1);
    }

    /**
     * Propagates {@code x} through the first {@code layerNum} hidden layers. {@code x} is first
     * normalized the same way as in {@link #predict}, on a copy.
     *
     * @param x input rows
     * @param layerNum number of hidden layers to apply, in {@code [0, nLayers]}
     * @return the activation after {@code layerNum} hidden layers (the normalized input for 0)
     * @throws IllegalArgumentException if {@code layerNum} is out of range
     */
    public DoubleMatrix reconstruct(DoubleMatrix x, int layerNum) {
        if (layerNum > this.nLayers || layerNum < 0) {
            throw new IllegalArgumentException("Layer number " + layerNum + " does not exist");
        }
        DoubleMatrix current = this.normalizeColumns(x);
        for (int i = 0; i < layerNum; ++i) {
            current = this.sigmoidLayers[i].activate(current);
        }
        return current;
    }

    /**
     * Propagates {@code x} through all hidden layers; see {@link #reconstruct(DoubleMatrix, int)}.
     *
     * @param x input rows
     * @return the last hidden layer's activation
     */
    public DoubleMatrix reconstruct(DoubleMatrix x) {
        return this.reconstruct(x, this.sigmoidLayers.length);
    }

    /**
     * Returns a copy of {@code x} with each column divided by its column sum, then centered on
     * its column mean, then divided by its column standard deviation. Each statistic is applied
     * only when configured. {@code null} is returned unchanged.
     */
    private DoubleMatrix normalizeColumns(DoubleMatrix x) {
        if (x == null) {
            return null;
        }
        DoubleMatrix ret = x.dup();
        if (this.columnSums == null && this.columnMeans == null && this.columnStds == null) {
            return ret;
        }
        for (int i = 0; i < ret.columns; ++i) {
            // getColumn returns a copy; it is written back with putColumn
            DoubleMatrix col = ret.getColumn(i);
            if (this.columnSums != null) {
                col.divi(this.columnSums.get(0, i));
            }
            if (this.columnMeans != null) {
                col.subi(this.columnMeans.get(0, i));
            }
            if (this.columnStds != null) {
                col.divi(this.columnStds.get(0, i));
            }
            ret.putColumn(i, col);
        }
        return ret;
    }

    /**
     * Writes this network to {@code os} via {@code SerializationUtils.writeObject}.
     *
     * @param os destination stream
     */
    @Override
    public void write(OutputStream os) {
        BaseMultiLayerNetwork self = this;
        SerializationUtils.writeObject(self, os);
    }

    /**
     * Reads a network from {@code is} via {@code SerializationUtils.readObject} and copies its
     * state into this instance.
     *
     * @param is source stream
     */
    @Override
    public void load(InputStream is) {
        this.update((BaseMultiLayerNetwork)SerializationUtils.readObject(is));
    }

    /**
     * Deserializes a network from {@code is} with an {@link ObjectInputStream}. The stream is
     * not closed.
     *
     * <p>NOTE(review): native Java deserialization; only use this with trusted input.
     *
     * @param is source stream
     * @return the deserialized network
     * @throws RuntimeException wrapping any failure while reading
     */
    public static BaseMultiLayerNetwork loadFromFile(InputStream is) {
        try {
            ObjectInputStream in = new ObjectInputStream(is);
            log.info("Loading network model...");
            return (BaseMultiLayerNetwork)in.readObject();
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Copies the configuration and state of {@code network} into this network.
     *
     * <p>Pretrain layers, hidden layers and the output layer are cloned. All other fields (input,
     * labels, rng, distribution, transform maps, normalization statistics, ...) are copied by
     * reference.
     *
     * @param network the network to copy from
     */
    protected synchronized void update(BaseMultiLayerNetwork network) {
        if (network.layers != null && network.layers.length > 0) {
            // Size from the source. this.nLayers has not been copied yet and may be stale (e.g. 0
            // for a network built with the no-arg constructor), which produced an empty array or
            // an ArrayIndexOutOfBoundsException.
            NeuralNetwork[] sourceLayers = network.getLayers();
            this.layers = new NeuralNetwork[sourceLayers.length];
            for (int i = 0; i < sourceLayers.length; ++i) {
                this.layers[i] = sourceLayers[i].clone();
            }
        }
        this.normalizeByInputRows = network.normalizeByInputRows;
        this.useAdaGrad = network.useAdaGrad;
        this.hiddenLayerSizes = network.hiddenLayerSizes;
        if (network.logLayer != null) {
            this.logLayer = network.logLayer.clone();
        }
        this.nIns = network.nIns;
        this.nLayers = network.nLayers;
        this.nOuts = network.nOuts;
        this.rng = network.rng;
        this.dist = network.dist;
        this.activation = network.activation;
        this.useRegularization = network.useRegularization;
        this.columnMeans = network.columnMeans;
        this.columnStds = network.columnStds;
        this.columnSums = network.columnSums;
        this.errorTolerance = network.errorTolerance;
        this.renderWeightsEveryNEpochs = network.renderWeightsEveryNEpochs;
        this.forceNumEpochs = network.forceNumEpochs;
        this.input = network.input;
        this.l2 = network.l2;
        this.fanIn = network.fanIn;
        this.labels = network.labels;
        this.momentum = network.momentum;
        this.learningRateUpdate = network.learningRateUpdate;
        this.shouldBackProp = network.shouldBackProp;
        this.weightTransforms = network.weightTransforms;
        this.sparsity = network.sparsity;
        this.toDecode = network.toDecode;
        this.visibleBiasTransforms = network.visibleBiasTransforms;
        this.hiddenBiasTransforms = network.hiddenBiasTransforms;
        this.dropOut = network.dropOut;
        this.optimizationAlgorithm = network.optimizationAlgorithm;
        this.lossFunction = network.lossFunction;
        // Also carry over initialization state and forward-prop mode. Without initCalled, a
        // cloned/loaded network would be re-randomized by the next initializeLayers() call.
        this.initCalled = network.initCalled;
        this.shouldInit = network.shouldInit;
        this.useHiddenActivationsForwardProp = network.useHiddenActivationsForwardProp;
        if (network.sigmoidLayers != null && network.sigmoidLayers.length > 0) {
            HiddenLayer[] sourceHidden = network.getSigmoidLayers();
            this.sigmoidLayers = new HiddenLayer[sourceHidden.length];
            for (int i = 0; i < sourceHidden.length; ++i) {
                this.sigmoidLayers[i] = sourceHidden[i].clone();
            }
        }
    }

    /**
     * Returns the output layer's negative log likelihood, the objective {@code backProp} monitors.
     *
     * @return the output layer's negative log likelihood
     */
    public double negativeLogLikelihood() {
        LogisticRegression output = this.logLayer;
        return output.negativeLogLikelihood();
    }

    /**
     * Trains the network; the algorithm is defined by subclasses.
     * NOTE(review): presumably (input, labels, extra parameters); confirm against subclasses.
     */
    public abstract void trainNetwork(DoubleMatrix input, DoubleMatrix labels, Object[] params);

    /**
     * Pretrains the network; the algorithm is defined by subclasses.
     * NOTE(review): presumably (input, extra parameters); confirm against subclasses.
     */
    public abstract void pretrain(DoubleMatrix input, Object[] params);

    /**
     * Applies any registered per-layer-index transforms to each pretrain layer's weights, hidden
     * bias and visible bias. Each transformed matrix is stored back on the layer.
     *
     * @throws IllegalStateException if no layers exist yet
     */
    protected void applyTransforms() {
        if (this.layers == null || this.layers.length < 1) {
            throw new IllegalStateException("Layers not initialized");
        }
        int index = 0;
        for (NeuralNetwork layer : this.layers) {
            if (this.weightTransforms.containsKey(index)) {
                layer.setW((DoubleMatrix)this.weightTransforms.get(index).apply(layer.getW()));
            }
            if (this.hiddenBiasTransforms.containsKey(index)) {
                layer.sethBias((DoubleMatrix)this.getHiddenBiasTransforms().get(index).apply(layer.gethBias()));
            }
            if (this.visibleBiasTransforms.containsKey(index)) {
                layer.setvBias((DoubleMatrix)this.getVisibleBiasTransforms().get(index).apply(layer.getvBias()));
            }
            index++;
        }
    }

    /**
     * Creates the pretrain layer at {@code index}. {@link #init()} passes the layer input, input
     * size, hidden size, the matching hidden layer's weights and bias, {@code null} as the sixth
     * argument, the network rng and the layer index.
     * NOTE(review): the sixth argument is presumably a visible bias; confirm against subclasses.
     */
    public abstract NeuralNetwork createLayer(DoubleMatrix input, int nVisible, int nHidden, DoubleMatrix weights, DoubleMatrix hBias, DoubleMatrix vBias, RandomGenerator rng, int index);

    /**
     * Returns an array for {@code nLayers} pretrain layers; the constructor stores the result as the layer array.
     */
    public abstract NeuralNetwork[] createNetworkLayers(int nLayers);

    /**
     * Builds the hidden layer used for forward propagation. {@code index} is not used by this
     * default implementation.
     *
     * @param index layer position
     * @param nIn input size
     * @param nOut output size
     * @param activation activation function
     * @param rng random generator
     * @param layerInput input for the layer (may be null)
     * @param dist weight initialization distribution
     * @return the new hidden layer
     */
    public HiddenLayer createHiddenLayer(int index, int nIn, int nOut, ActivationFunction activation, RandomGenerator rng, DoubleMatrix layerInput, RealDistribution dist) {
        return new HiddenLayer.Builder()
                .nIn(nIn)
                .nOut(nOut)
                .withActivation(activation)
                .withRng(rng)
                .withInput(layerInput)
                .dist(dist)
                .build();
    }

    /**
     * Merges {@code network} into this one layer by layer. Each hidden layer is re-synced with
     * its merged pretrain layer, and the output layers are merged too.
     *
     * @param network network with the same number of layers
     * @param batchSize forwarded to each layer's merge
     * @throws IllegalArgumentException if the layer counts differ
     */
    public void merge(BaseMultiLayerNetwork network, int batchSize) {
        if (network.nLayers != this.nLayers) {
            throw new IllegalArgumentException("Unable to merge networks that are not of equal length");
        }
        for (int i = 0; i < this.nLayers; ++i) {
            NeuralNetwork mine = this.layers[i];
            mine.merge(network.layers[i], batchSize);
            HiddenLayer hidden = this.getSigmoidLayers()[i];
            hidden.setB(mine.gethBias());
            hidden.setW(mine.getW());
        }
        this.getLogLayer().merge(network.logLayer, batchSize);
    }

    /**
     * Rebuilds this network as the layer-reversed copy of {@code network}. Pretrain layers,
     * hidden layers and hidden sizes are cloned in reverse order. A new output layer maps the
     * last (reversed) hidden size back to the width of {@code network}'s input.
     *
     * @param network the network to encode; its input must be set
     */
    public void encode(BaseMultiLayerNetwork network) {
        this.createNetworkLayers(network.nLayers);
        // Size everything from the source network. Previously hiddenLayerSizes used this
        // network's (possibly different) nLayers, and sigmoidLayers reused whatever array existed.
        this.nLayers = network.nLayers;
        this.layers = new NeuralNetwork[this.nLayers];
        this.sigmoidLayers = new HiddenLayer[this.nLayers];
        this.hiddenLayerSizes = new int[this.nLayers];
        int count = 0;
        // Include layer 0 (was `i > 0`). Otherwise the last slot stays empty and the output
        // layer below is created with 0 inputs.
        for (int i = this.nLayers - 1; i >= 0; --i) {
            NeuralNetwork n = network.layers[i].clone();
            HiddenLayer l = network.sigmoidLayers[i].clone();
            this.layers[count] = n;
            this.sigmoidLayers[count] = l;
            this.hiddenLayerSizes[count] = network.hiddenLayerSizes[i];
            ++count;
        }
        this.logLayer = new LogisticRegression(this.hiddenLayerSizes[this.nLayers - 1], network.input.columns);
    }

    /** Returns the current training labels. */
    public DoubleMatrix getLabels() {
        return labels;
    }

    /** Returns the logistic regression output layer. */
    public LogisticRegression getLogLayer() {
        return logLayer;
    }

    /**
     * Sets the network input (not copied). The layers are initialized from it when none exist yet.
     *
     * @param input the new input; {@code null} just clears it
     */
    public void setInput(DoubleMatrix input) {
        this.input = input;
        boolean needsLayers = input != null && this.layers == null;
        if (needsLayers) {
            this.initializeLayers(input);
        }
    }

    /** Returns whether back propagation is enabled. */
    public boolean isShouldBackProp() {
        return shouldBackProp;
    }

    /** Returns the configured optimization algorithm. */
    public NeuralNetwork.OptimizationAlgorithm getOptimizationAlgorithm() {
        return optimizationAlgorithm;
    }

    /** Sets the optimization algorithm. */
    public void setOptimizationAlgorithm(NeuralNetwork.OptimizationAlgorithm optimizationAlgorithm) {
        this.optimizationAlgorithm = optimizationAlgorithm;
    }

    /** Returns the configured loss function. */
    public NeuralNetwork.LossFunction getLossFunction() {
        return lossFunction;
    }

    /** Sets the loss function. */
    public void setLossFunction(NeuralNetwork.LossFunction lossFunction) {
        this.lossFunction = lossFunction;
    }

    /** Returns the current network input. */
    public DoubleMatrix getInput() {
        return input;
    }

    /** Returns the hidden-layer array (not a copy). */
    public synchronized HiddenLayer[] getSigmoidLayers() {
        return sigmoidLayers;
    }

    /** Returns the pretrain-layer array (not a copy). */
    public synchronized NeuralNetwork[] getLayers() {
        return layers;
    }

    /** Returns whether back prop runs a fixed number of epochs. */
    public boolean isForceNumEpochs() {
        return forceNumEpochs;
    }

    /** Returns the per-column sums used for input normalization. */
    public DoubleMatrix getColumnSums() {
        return columnSums;
    }

    /** Sets the per-column sums used for input normalization. */
    public void setColumnSums(DoubleMatrix columnSums) {
        this.columnSums = columnSums;
    }

    /** Returns the hidden layer sizes (not a copy). */
    public int[] getHiddenLayerSizes() {
        return hiddenLayerSizes;
    }

    /** Sets the hidden layer sizes. */
    public void setHiddenLayerSizes(int[] hiddenLayerSizes) {
        this.hiddenLayerSizes = hiddenLayerSizes;
    }

    /** Returns the network random generator. */
    public RandomGenerator getRng() {
        return rng;
    }

    /** Sets the network random generator. */
    public void setRng(RandomGenerator rng) {
        this.rng = rng;
    }

    /** Returns the weight initialization distribution. */
    public RealDistribution getDist() {
        return dist;
    }

    /** Sets the weight initialization distribution. */
    public void setDist(RealDistribution dist) {
        this.dist = dist;
    }

    /** Returns the optimizer created by the last fine-tuning run. */
    public MultiLayerNetworkOptimizer getOptimizer() {
        return optimizer;
    }

    /** Sets the optimizer. */
    public void setOptimizer(MultiLayerNetworkOptimizer optimizer) {
        this.optimizer = optimizer;
    }

    /** Returns the hidden-layer activation function. */
    public ActivationFunction getActivation() {
        return activation;
    }

    /** Sets the hidden-layer activation function. */
    public void setActivation(ActivationFunction activation) {
        this.activation = activation;
    }

    /** Returns the toDecode flag. */
    public boolean isToDecode() {
        return toDecode;
    }

    /** Sets the toDecode flag. */
    public void setToDecode(boolean toDecode) {
        this.toDecode = toDecode;
    }

    /** Returns the shouldInit flag. */
    public boolean isShouldInit() {
        return shouldInit;
    }

    /** Sets the shouldInit flag. */
    public void setShouldInit(boolean shouldInit) {
        this.shouldInit = shouldInit;
    }

    /** Returns the raw configured fan-in (negative when unset). */
    public double getFanIn() {
        return fanIn;
    }

    /** Sets the fan-in. */
    public void setFanIn(double fanIn) {
        this.fanIn = fanIn;
    }

    /** Returns the weight-render interval in epochs. */
    public int getRenderWeightsEveryNEpochs() {
        return renderWeightsEveryNEpochs;
    }

    /** Sets the weight-render interval in epochs. */
    public void setRenderWeightsEveryNEpochs(int renderWeightsEveryNEpochs) {
        this.renderWeightsEveryNEpochs = renderWeightsEveryNEpochs;
    }

    /** Returns the per-layer weight transforms (live map). */
    public Map<Integer, MatrixTransform> getWeightTransforms() {
        return weightTransforms;
    }

    /** Sets the per-layer weight transforms. */
    public void setWeightTransforms(Map<Integer, MatrixTransform> weightTransforms) {
        this.weightTransforms = weightTransforms;
    }

    /** Returns the sparsity target. */
    public double getSparsity() {
        return sparsity;
    }

    /** Sets the sparsity target. */
    public void setSparsity(double sparsity) {
        this.sparsity = sparsity;
    }

    /** Returns the learning-rate update factor. */
    public double getLearningRateUpdate() {
        return learningRateUpdate;
    }

    /** Sets the learning-rate update factor. */
    public void setLearningRateUpdate(double learningRateUpdate) {
        this.learningRateUpdate = learningRateUpdate;
    }

    /** Returns the error tolerance. */
    public double getErrorTolerance() {
        return errorTolerance;
    }

    /** Sets the error tolerance. */
    public void setErrorTolerance(double errorTolerance) {
        this.errorTolerance = errorTolerance;
    }

    /** Sets the training labels (not copied). */
    public void setLabels(DoubleMatrix labels) {
        this.labels = labels;
    }

    /** Sets whether back prop runs a fixed number of epochs. */
    public void setForceNumEpochs(boolean forceNumEpochs) {
        this.forceNumEpochs = forceNumEpochs;
    }

    /** Returns the per-column means used for input normalization. */
    public DoubleMatrix getColumnMeans() {
        return columnMeans;
    }

    /** Sets the per-column means used for input normalization. */
    public void setColumnMeans(DoubleMatrix columnMeans) {
        this.columnMeans = columnMeans;
    }

    /** Returns the per-column standard deviations used for input normalization. */
    public DoubleMatrix getColumnStds() {
        return columnStds;
    }

    /** Sets the per-column standard deviations used for input normalization. */
    public void setColumnStds(DoubleMatrix columnStds) {
        this.columnStds = columnStds;
    }

    public boolean isUseAdaGrad() {
        return this.useAdaGrad;
    }

    public void setUseAdaGrad(boolean useAdaGrad) {
        this.useAdaGrad = useAdaGrad;
    }

    public boolean isNormalizeByInputRows() {
        return this.normalizeByInputRows;
    }

    public void setNormalizeByInputRows(boolean normalizeByInputRows) {
        this.normalizeByInputRows = normalizeByInputRows;
    }

    public boolean isUseHiddenActivationsForwardProp() {
        return this.useHiddenActivationsForwardProp;
    }

    /**
     * Sets whether hidden activations are used for forward propagation.
     *
     * @param useHiddenActivationsForwardProp the new flag value
     */
    public void setUseHiddenActivationsForwardProp(final boolean useHiddenActivationsForwardProp) {
        this.useHiddenActivationsForwardProp = useHiddenActivationsForwardProp;
    }

    /**
     * Returns the configured drop-out value (the {@code Builder} default is {@code 0.0}).
     *
     * @return the drop-out setting
     */
    public double getDropOut() {
        return dropOut;
    }

    /**
     * Sets the drop-out value.
     *
     * @param dropOut the new drop-out setting
     */
    public void setDropOut(final double dropOut) {
        this.dropOut = dropOut;
    }

    /**
     * Returns the hidden-bias transforms, keyed by layer index.
     * The live map is returned (not a copy); {@code Builder.build()} populates it via {@code putAll}.
     *
     * @return the mutable layer-index to hidden-bias {@link MatrixTransform} map
     */
    public Map<Integer, MatrixTransform> getHiddenBiasTransforms() {
        return hiddenBiasTransforms;
    }

    /**
     * Returns the visible-bias transforms, keyed by layer index.
     * The live map is returned (not a copy); {@code Builder.build()} populates it via {@code putAll}.
     *
     * @return the mutable layer-index to visible-bias {@link MatrixTransform} map
     */
    public Map<Integer, MatrixTransform> getVisibleBiasTransforms() {
        return visibleBiasTransforms;
    }

    /**
     * Returns the configured number of inputs.
     *
     * @return the input count
     */
    public int getnIns() {
        return nIns;
    }

    /**
     * Sets the number of inputs.
     *
     * @param nIns the new input count
     */
    public void setnIns(final int nIns) {
        this.nIns = nIns;
    }

    /**
     * Returns the configured number of outputs.
     *
     * @return the output count
     */
    public int getnOuts() {
        return nOuts;
    }

    /**
     * Sets the number of outputs.
     *
     * @param nOuts the new output count
     */
    public void setnOuts(final int nOuts) {
        this.nOuts = nOuts;
    }

    /**
     * Returns the configured number of layers.
     * {@code Builder.hiddenLayerSizes} sets this to the length of the hidden-layer-sizes array.
     *
     * @return the layer count
     */
    public int getnLayers() {
        return nLayers;
    }

    /**
     * Sets the number of layers.
     *
     * @param nLayers the new layer count
     */
    public void setnLayers(final int nLayers) {
        this.nLayers = nLayers;
    }

    /**
     * Returns the configured momentum.
     *
     * @return the momentum setting
     */
    public double getMomentum() {
        return momentum;
    }

    /**
     * Sets the momentum.
     *
     * @param momentum the new momentum setting
     */
    public void setMomentum(final double momentum) {
        this.momentum = momentum;
    }

    /**
     * Returns the configured L2 coefficient (the {@code Builder} default is {@code 0.01}).
     *
     * @return the L2 setting
     */
    public double getL2() {
        return l2;
    }

    /**
     * Sets the L2 coefficient.
     *
     * @param l2 the new L2 setting
     */
    public void setL2(final double l2) {
        this.l2 = l2;
    }

    /**
     * Reports whether regularization is enabled (the {@code Builder} default is {@code false}).
     *
     * @return the regularization flag
     */
    public boolean isUseRegularization() {
        return useRegularization;
    }

    /**
     * Enables or disables regularization.
     *
     * @param useRegularization the new regularization flag
     */
    public void setUseRegularization(final boolean useRegularization) {
        this.useRegularization = useRegularization;
    }

    /**
     * Replaces the array of hidden (sigmoid) layers.
     *
     * @param sigmoidLayers the layer array; stored by reference, not copied
     */
    public void setSigmoidLayers(final HiddenLayer[] sigmoidLayers) {
        this.sigmoidLayers = sigmoidLayers;
    }

    /**
     * Replaces the logistic-regression output layer.
     *
     * @param logLayer the new output layer
     */
    public void setLogLayer(final LogisticRegression logLayer) {
        this.logLayer = logLayer;
    }

    /**
     * Sets the back-propagation flag (the {@code Builder} default is {@code true}).
     *
     * @param shouldBackProp the new flag value
     */
    public void setShouldBackProp(final boolean shouldBackProp) {
        this.shouldBackProp = shouldBackProp;
    }

    /**
     * Replaces the array of pretrain layers.
     *
     * @param layers the layer array; stored by reference, not copied
     */
    public void setLayers(final NeuralNetwork[] layers) {
        this.layers = layers;
    }

    /**
     * Fluent builder for {@link BaseMultiLayerNetwork} subclasses.
     *
     * <p>Set the concrete network type with {@link #withClazz(Class)}, configure it with the
     * fluent methods, then call {@link #build()} to obtain a configured and initialized network,
     * or {@link #buildEmpty()} for a bare instance created through the no-argument constructor.
     *
     * @param <E> the concrete network type produced by this builder
     */
    public static class Builder<E extends BaseMultiLayerNetwork> {
        protected Class<? extends BaseMultiLayerNetwork> clazz;
        private int nIns;
        private int[] hiddenLayerSizes;
        private int nOuts;
        private int nLayers;
        private RandomGenerator rng = new MersenneTwister(1234);
        private DoubleMatrix input;
        private DoubleMatrix labels;
        private ActivationFunction activation;
        private boolean decode = false;
        private double fanIn = -1.0;
        private int renderWeightsEveryNEpochs = -1;
        private double l2 = 0.01;
        private boolean useRegularization = false;
        private double momentum;
        private RealDistribution dist;
        protected Map<Integer, MatrixTransform> weightTransforms = new HashMap<Integer, MatrixTransform>();
        protected boolean backProp = true;
        protected boolean shouldForceEpochs = false;
        private double sparsity = 0.0;
        private Map<Integer, MatrixTransform> hiddenBiasTransforms = new HashMap<Integer, MatrixTransform>();
        private Map<Integer, MatrixTransform> visibleBiasTransforms = new HashMap<Integer, MatrixTransform>();
        private boolean useAdaGrad = false;
        private Map<Integer, List<NeuralNetworkGradientListener>> gradientListeners = new HashMap<Integer, List<NeuralNetworkGradientListener>>();
        private List<MultiLayerGradientListener> multiLayerGradientListeners = new ArrayList<MultiLayerGradientListener>();
        private boolean normalizeByInputRows = false;
        private boolean useHiddenActivationsForwardProp = true;
        private double dropOut = 0.0;
        private NeuralNetwork.LossFunction lossFunction = NeuralNetwork.LossFunction.RECONSTRUCTION_CROSSENTROPY;
        private NeuralNetwork.OptimizationAlgorithm optimizationAlgo = NeuralNetwork.OptimizationAlgorithm.CONJUGATE_GRADIENT;

        /** Sets the optimization algorithm (default {@code CONJUGATE_GRADIENT}). */
        public Builder<E> withOptimizationAlgorithm(NeuralNetwork.OptimizationAlgorithm optimizationAlgo) {
            this.optimizationAlgo = optimizationAlgo;
            return this;
        }

        /** Sets the loss function (default {@code RECONSTRUCTION_CROSSENTROPY}). */
        public Builder<E> withLossFunction(NeuralNetwork.LossFunction lossFunction) {
            this.lossFunction = lossFunction;
            return this;
        }

        /** Sets the drop-out value (default {@code 0.0}). */
        public Builder<E> withDropOut(double dropOut) {
            this.dropOut = dropOut;
            return this;
        }

        /** Sets whether hidden activations are used for forward propagation (default {@code true}). */
        public Builder<E> useHiddenActivationsForwardProp(boolean useHiddenActivationsForwardProp) {
            this.useHiddenActivationsForwardProp = useHiddenActivationsForwardProp;
            return this;
        }

        /** Sets the normalize-by-input-rows flag (default {@code false}). */
        public Builder<E> normalizeByInputRows(boolean normalizeByInputRows) {
            this.normalizeByInputRows = normalizeByInputRows;
            return this;
        }

        /** Appends the given multi-layer gradient listeners to those already registered. */
        public Builder<E> withMultiLayerGradientListeners(List<MultiLayerGradientListener> multiLayerGradientListeners) {
            this.multiLayerGradientListeners.addAll(multiLayerGradientListeners);
            return this;
        }

        /** Merges the given per-layer gradient listeners into those already registered. */
        public Builder<E> withGradientListeners(Map<Integer, List<NeuralNetworkGradientListener>> gradientListeners) {
            this.gradientListeners.putAll(gradientListeners);
            return this;
        }

        /** Enables or disables AdaGrad (default {@code false}). */
        public Builder<E> useAdaGrad(boolean useAdaGrad) {
            this.useAdaGrad = useAdaGrad;
            return this;
        }

        /** Sets the sparsity value (default {@code 0.0}). */
        public Builder<E> withSparsity(double sparsity) {
            this.sparsity = sparsity;
            return this;
        }

        /** Replaces the visible-bias transforms, keyed by layer index. */
        public Builder<E> withVisibleBiasTransforms(Map<Integer, MatrixTransform> visibleBiasTransforms) {
            this.visibleBiasTransforms = visibleBiasTransforms;
            return this;
        }

        /** Replaces the hidden-bias transforms, keyed by layer index. */
        public Builder<E> withHiddenBiasTransforms(Map<Integer, MatrixTransform> hiddenBiasTransforms) {
            this.hiddenBiasTransforms = hiddenBiasTransforms;
            return this;
        }

        /** Turns on the force-number-of-epochs flag on the built network. */
        public Builder<E> forceEpochs() {
            this.shouldForceEpochs = true;
            return this;
        }

        /** Turns off the back-propagation flag on the built network. */
        public Builder<E> disableBackProp() {
            this.backProp = false;
            return this;
        }

        /** Registers a weight transform for the given layer index, replacing any previous one. */
        public Builder<E> transformWeightsAt(int layer, MatrixTransform transform) {
            this.weightTransforms.put(layer, transform);
            return this;
        }

        /** Merges the given layer-index to weight-transform entries into those already registered. */
        public Builder<E> transformWeightsAt(Map<Integer, MatrixTransform> transforms) {
            this.weightTransforms.putAll(transforms);
            return this;
        }

        /** Sets the distribution passed to the network; ignored by {@link #build()} when {@code null}. */
        public Builder<E> withDist(RealDistribution dist) {
            this.dist = dist;
            return this;
        }

        /** Sets the momentum. */
        public Builder<E> withMomentum(double momentum) {
            this.momentum = momentum;
            return this;
        }

        /** Enables or disables regularization (default {@code false}). */
        public Builder<E> useRegularization(boolean useRegularization) {
            this.useRegularization = useRegularization;
            return this;
        }

        /** Sets the L2 coefficient (default {@code 0.01}). */
        public Builder<E> withL2(double l2) {
            this.l2 = l2;
            return this;
        }

        /** Sets the render-weights-every-N-epochs value (default {@code -1}). */
        public Builder<E> renderWeights(int everyN) {
            this.renderWeightsEveryNEpochs = everyN;
            return this;
        }

        /**
         * Sets the fan-in value (default {@code -1.0}).
         *
         * @param fanIn the fan-in; must not be {@code null} (it is unboxed)
         */
        public Builder<E> withFanIn(Double fanIn) {
            this.fanIn = fanIn;
            return this;
        }

        /** Sets the activation function; ignored by {@link #build()} when {@code null}. */
        public Builder<E> withActivation(ActivationFunction activation) {
            this.activation = activation;
            return this;
        }

        /** Sets the number of inputs. */
        public Builder<E> numberOfInputs(int nIns) {
            this.nIns = nIns;
            return this;
        }

        /** Sets the decode flag (default {@code false}). */
        public Builder<E> decodeNetwork(boolean decode) {
            this.decode = decode;
            return this;
        }

        /**
         * Sets the hidden layer sizes; the layer count becomes the array length.
         *
         * @param hiddenLayerSizes one entry per hidden layer; stored by reference. A {@code null}
         *        array sets the layer count to 0, and {@link #build()} will then reject it.
         */
        public Builder<E> hiddenLayerSizes(int[] hiddenLayerSizes) {
            this.hiddenLayerSizes = hiddenLayerSizes;
            this.nLayers = hiddenLayerSizes == null ? 0 : hiddenLayerSizes.length;
            return this;
        }

        /** Sets the number of outputs. */
        public Builder<E> numberOfOutPuts(int nOuts) {
            this.nOuts = nOuts;
            return this;
        }

        /** Sets the random generator (default: {@code MersenneTwister} seeded with 1234). */
        public Builder<E> withRng(RandomGenerator gen) {
            this.rng = gen;
            return this;
        }

        /** Sets the input matrix passed to the network. */
        public Builder<E> withInput(DoubleMatrix input) {
            this.input = input;
            return this;
        }

        /** Sets the labels matrix passed to the network. */
        public Builder<E> withLabels(DoubleMatrix labels) {
            this.labels = labels;
            return this;
        }

        /** Sets the concrete network class to instantiate; it needs an accessible no-arg constructor. */
        public Builder<E> withClazz(Class<? extends BaseMultiLayerNetwork> clazz) {
            this.clazz = clazz;
            return this;
        }

        /**
         * Instantiates the configured network class without applying any builder settings.
         *
         * @return a new, unconfigured network instance
         * @throws IllegalStateException if no network class was set via {@link #withClazz(Class)}
         * @throws RuntimeException wrapping any reflective instantiation failure
         */
        @SuppressWarnings("unchecked") // E is the type the caller bound clazz to
        public E buildEmpty() {
            return (E) this.newNetworkInstance();
        }

        /**
         * Instantiates the configured network class, applies every builder setting and calls
         * {@code init()} on it.
         *
         * @return the configured and initialized network
         * @throws IllegalStateException if no hidden layer sizes or no network class were set
         * @throws RuntimeException wrapping any reflective instantiation failure
         */
        public E build() {
            // Validate before doing any reflective work or touching a half-built instance.
            if (this.hiddenLayerSizes == null) {
                throw new IllegalStateException("Unable to build network, no hidden layer sizes defined");
            }
            @SuppressWarnings("unchecked") // E is the type the caller bound clazz to
            E result = (E) this.newNetworkInstance();
            // Configure through the base type: the listener fields below are not reachable
            // through the type variable E.
            BaseMultiLayerNetwork network = result;
            network.setNormalizeByInputRows(this.normalizeByInputRows);
            network.setnOuts(this.nOuts);
            network.setnIns(this.nIns);
            network.setHiddenLayerSizes(this.hiddenLayerSizes);
            network.setnLayers(this.nLayers);
            network.setRng(this.rng);
            network.setShouldBackProp(this.backProp);
            network.setSigmoidLayers(new HiddenLayer[network.getnLayers()]);
            network.setUseHiddenActivationsForwardProp(this.useHiddenActivationsForwardProp);
            network.setToDecode(this.decode);
            network.setInput(this.input);
            network.setMomentum(this.momentum);
            network.setLabels(this.labels);
            network.setFanIn(this.fanIn);
            network.setSparsity(this.sparsity);
            network.setRenderWeightsEveryNEpochs(this.renderWeightsEveryNEpochs);
            network.setL2(this.l2);
            network.setForceNumEpochs(this.shouldForceEpochs);
            network.setUseRegularization(this.useRegularization);
            network.setUseAdaGrad(this.useAdaGrad);
            network.setDropOut(this.dropOut);
            network.setOptimizationAlgorithm(this.optimizationAlgo);
            network.setLossFunction(this.lossFunction);
            // Optional settings: keep the network's own defaults when not configured here.
            if (this.activation != null) {
                network.setActivation(this.activation);
            }
            if (this.dist != null) {
                network.setDist(this.dist);
            }
            network.getWeightTransforms().putAll(this.weightTransforms);
            network.getVisibleBiasTransforms().putAll(this.visibleBiasTransforms);
            network.getHiddenBiasTransforms().putAll(this.hiddenBiasTransforms);
            network.gradientListeners.putAll(this.gradientListeners);
            network.multiLayerGradientListeners.addAll(this.multiLayerGradientListeners);
            network.init();
            return result;
        }

        /**
         * Reflectively creates an instance of {@link #clazz} through its no-argument constructor.
         *
         * @throws IllegalStateException if no network class has been configured
         * @throws RuntimeException wrapping any reflection failure (missing or inaccessible
         *         constructor, abstract class, or an exception thrown by the constructor)
         */
        private BaseMultiLayerNetwork newNetworkInstance() {
            if (this.clazz == null) {
                throw new IllegalStateException("Unable to build network, no network class defined (see withClazz)");
            }
            try {
                // Class.newInstance() is deprecated: it rethrows checked constructor exceptions unwrapped.
                return this.clazz.getDeclaredConstructor().newInstance();
            }
            catch (ReflectiveOperationException e) {
                throw new RuntimeException(e);
            }
        }
    }
}

