/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.layers;

import java.lang.reflect.Constructor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.optimize.Solver;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;

/**
 * Skeleton implementation of {@link Layer} that keeps the layer input, the
 * parameter table and the configuration, and provides the forward pass built
 * on the two parameters stored under the keys {@code "W"} and {@code "b"}.
 *
 * <p>Parameters live in a {@code Map<String, INDArray>} that is supplied via
 * {@link #setParamTable(Map)}; most methods dereference it directly and will
 * fail with a {@link NullPointerException} if it has not been set.
 */
public abstract class BaseLayer implements Layer {
    /** Most recent input handed to this layer (may be {@code null} until set). */
    protected INDArray input;
    /** Parameter table keyed by parameter name (e.g. {@code "W"}, {@code "b"}). */
    protected Map<String, INDArray> params;
    /** Configuration this layer reads its hyper-parameters from. */
    protected NeuralNetConfiguration conf;
    /** Mask produced by the last call to {@link #applyDropOutIfNecessary(INDArray)}. */
    protected INDArray dropoutMask;
    /** Strategy used by {@link #initParams()} to fill the parameter table. */
    protected ParamInitializer paramInitializer;

    /**
     * Creates a layer with a configuration and no input.
     *
     * @param conf the layer configuration
     */
    public BaseLayer(NeuralNetConfiguration conf) {
        this.conf = conf;
    }

    /**
     * Creates a layer with a configuration and an initial input.
     *
     * @param conf  the layer configuration
     * @param input the initial input
     */
    public BaseLayer(NeuralNetConfiguration conf, INDArray input) {
        this.input = input;
        this.conf = conf;
    }

    /** No-op in this base class. */
    @Override
    public void setParameters(INDArray params) {
    }

    /** No-op in this base class. */
    @Override
    public void fit() {
    }

    /**
     * Replaces the configuration.
     *
     * @param conf the new configuration
     */
    @Override
    public void setConf(NeuralNetConfiguration conf) {
        this.conf = conf;
    }

    /**
     * Stores a single parameter in the parameter table.
     *
     * @param key the parameter name
     * @param val the parameter value
     */
    @Override
    public void setParam(String key, INDArray val) {
        this.params.put(key, val);
    }

    /**
     * Returns all parameters flattened into a single array.
     *
     * <p>Values are concatenated in the iteration order of the parameter map.
     * NOTE(review): {@link #setParams(INDArray)} slices in
     * {@code conf.getGradientList()} order; the two only round-trip if the map
     * iterates in that same order - confirm against the map implementation
     * used by the param initializer.
     *
     * @return the flattened parameters
     */
    @Override
    public INDArray params() {
        List<INDArray> flattenable = new ArrayList<INDArray>();
        for (INDArray value : this.params.values()) {
            flattenable.add(value);
        }
        return Nd4j.toFlattened(flattenable);
    }

    /**
     * Assigns the parameters named in {@code conf.getGradientList()} from
     * consecutive slices of a flat vector.
     *
     * @param params flat vector whose length equals the summed length of those parameters
     * @throws IllegalArgumentException if {@code params} has the wrong total length
     * @throws IllegalStateException    if a slice does not have the length of its parameter
     */
    @Override
    public void setParams(INDArray params) {
        List<String> gradientList = this.conf.getGradientList();
        int expectedLength = 0;
        for (String name : gradientList) {
            expectedLength += this.getParam(name).length();
        }
        if (params.length() != expectedLength) {
            throw new IllegalArgumentException("Unable to set parameters: must be of length " + expectedLength);
        }
        int offset = 0;
        for (String name : gradientList) {
            INDArray param = this.getParam(name);
            INDArray slice = params.get(new NDArrayIndex[]{NDArrayIndex.interval((int)offset, (int)(offset + param.length()))});
            if (param.length() != slice.length()) {
                throw new IllegalStateException("Parameter " + name + " should have been of length " + param.length() + " but was " + slice.length());
            }
            // Write into the existing array so holders of the reference see the update.
            param.assign(slice.reshape(param.shape()));
            offset += param.length();
        }
    }

    /** Delegates to the {@link ParamInitializer} with the current table and configuration. */
    @Override
    public void initParams() {
        this.paramInitializer.init(this.paramTable(), this.conf());
    }

    /**
     * @return the live parameter table (not a copy)
     */
    @Override
    public Map<String, INDArray> paramTable() {
        return this.params;
    }

    /**
     * Replaces the parameter table.
     *
     * @param paramTable the new table; stored by reference
     */
    @Override
    public void setParamTable(Map<String, INDArray> paramTable) {
        this.params = paramTable;
    }

    /**
     * Looks up one parameter.
     *
     * @param param the parameter name
     * @return the stored value, or {@code null} if the table has no such key
     */
    @Override
    public INDArray getParam(String param) {
        return this.params.get(param);
    }

    /**
     * Stores {@code x} as the layer input and computes {@code x.mmul(W)}
     * combined with the bias: horizontally stacked with {@code b} when
     * {@code conf.isConcatBiases()} is set, otherwise {@code b} is added to
     * every row.
     *
     * @param x the input; must not be {@code null}
     * @return the pre-activation output
     * @throws IllegalArgumentException if {@code x} is {@code null}
     * @throws IllegalStateException    if the product's column count differs from that of {@code b}
     */
    @Override
    public INDArray preOutput(INDArray x) {
        if (x == null) {
            throw new IllegalArgumentException("No null input allowed");
        }
        this.input = x;
        INDArray b = this.getParam("b");
        INDArray W = this.getParam("W");
        INDArray ret = this.input.mmul(W);
        if (ret.columns() != b.columns()) {
            throw new IllegalStateException("This is weird");
        }
        if (this.conf.isConcatBiases()) {
            ret = Nd4j.hstack((INDArray[])new INDArray[]{ret, b});
        } else {
            ret.addiRowVector(b);
        }
        return ret;
    }

    /**
     * @return the row count of the current input
     */
    @Override
    public int batchSize() {
        return this.input.rows();
    }

    /**
     * Applies the configured activation function to
     * {@code input.mmul(W)} with {@code b} added to every row.
     *
     * <p>NOTE(review): unlike {@link #preOutput(INDArray)} this does not
     * consult {@code conf.isConcatBiases()} - confirm that is intended.
     *
     * @return the activations for the current input
     */
    @Override
    public INDArray activate() {
        INDArray b = this.getParam("b");
        INDArray W = this.getParam("W");
        INDArray preActivation = this.getInput().mmul(W).addiRowVector(b);
        return (INDArray)this.conf.getActivationFunction().apply((Object)preActivation);
    }

    /**
     * Stores {@code input} as the layer input and returns {@link #activate()}.
     *
     * @param input the new layer input
     * @return the activations for {@code input}
     */
    @Override
    public INDArray activate(INDArray input) {
        this.input = input;
        return this.activate();
    }

    /**
     * @return {@code input.mmul(W)} with {@code b} added to every row, without the activation function
     */
    @Override
    public INDArray activationMean() {
        INDArray b = this.getParam("b");
        INDArray W = this.getParam("W");
        return this.getInput().mmul(W).addRowVector(b);
    }

    /**
     * @return the current configuration
     */
    @Override
    public NeuralNetConfiguration conf() {
        return this.conf;
    }

    /**
     * Replaces the configuration; same effect as {@link #setConf(NeuralNetConfiguration)}.
     *
     * @param conf the new configuration
     */
    @Override
    public void setConfiguration(NeuralNetConfiguration conf) {
        this.conf = conf;
    }

    /**
     * @return the current input, possibly {@code null}
     */
    @Override
    public INDArray getInput() {
        return this.input;
    }

    /**
     * Replaces the layer input.
     *
     * @param input the new input
     */
    @Override
    public void setInput(INDArray input) {
        this.input = input;
    }

    /**
     * Builds a mask with the same row and column counts as {@code input},
     * stores it in {@link #dropoutMask} and multiplies {@code input} by it in
     * place.
     *
     * <p>When {@code conf.getDropOut() > 0} the mask is the result of comparing
     * a random matrix against the drop-out value with {@code gt}; otherwise it
     * is all ones, which leaves {@code input} numerically unchanged.
     *
     * @param input the array to mask; modified in place
     */
    protected void applyDropOutIfNecessary(INDArray input) {
        int rows = input.rows();
        int columns = input.columns();
        if (this.conf.getDropOut() > 0.0) {
            this.dropoutMask = Nd4j.rand((int)rows, (int)columns).gt((Number)this.conf.getDropOut());
        } else {
            // Size the mask from the array it is multiplied into, matching the
            // drop-out branch, rather than from the configured nOut.
            this.dropoutMask = Nd4j.ones((int)rows, (int)columns);
        }
        input.muli(this.dropoutMask);
    }

    /**
     * Adds the other layer's flattened parameters, divided by
     * {@code batchSize}, to this layer's flattened parameters and stores the
     * result via {@link #setParams(INDArray)}.
     *
     * @param l         the layer whose parameters are merged in
     * @param batchSize the divisor applied to {@code l}'s parameters
     */
    @Override
    public void merge(Layer l, int batchSize) {
        INDArray scaledOther = l.params().divi((Number)batchSize);
        this.setParams(this.params().addi(scaledOther));
    }

    /**
     * Creates a copy of this layer through the runtime class's public
     * {@code (NeuralNetConfiguration, INDArray, INDArray, INDArray)}
     * constructor, passing duplicates of {@code W}, {@code b} and the input.
     *
     * @return the copy, or {@code null} if the constructor is missing or
     *         instantiation fails (the failure is printed to standard error)
     */
    @Override
    public Layer clone() {
        INDArray W = this.getParam("W");
        INDArray b = this.getParam("b");
        Layer layer = null;
        try {
            Constructor<?> c = this.getClass().getConstructor(NeuralNetConfiguration.class, INDArray.class, INDArray.class, INDArray.class);
            layer = (Layer)c.newInstance(this.conf, W.dup(), b.dup(), this.input != null ? this.input.dup() : null);
        }
        catch (Exception e) {
            // Kept as-is for compatibility: callers receive null on failure.
            e.printStackTrace();
        }
        return layer;
    }

    /**
     * @return the summed length of every array in the parameter table
     */
    @Override
    public int numParams() {
        int total = 0;
        for (INDArray val : this.params.values()) {
            total += val.length();
        }
        return total;
    }

    /**
     * Optionally replaces the input, then builds a {@link Solver} for this
     * layer from the configuration and its listeners and runs it.
     *
     * @param input the training input; when {@code null} the current input is kept
     */
    @Override
    public void fit(INDArray input) {
        if (input != null) {
            this.input = input;
        }
        Solver solver = new Solver.Builder().model(this).configure(this.conf()).listeners(this.conf.getListeners()).build();
        solver.optimize();
    }

    /**
     * @return a pair of {@code getGradient()} and {@code score()}
     */
    @Override
    public Pair<Gradient, Double> gradientAndScore() {
        return new Pair<Gradient, Double>(this.getGradient(), this.score());
    }

    /**
     * @return the current input, possibly {@code null}
     */
    @Override
    public INDArray input() {
        return this.input;
    }

    /**
     * Checks the input's row count against the configured batch size; a
     * configured batch size that is not positive disables the check.
     *
     * @throws IllegalStateException if the row count differs from the configured batch size
     */
    @Override
    public void validateInput() {
        if (this.conf.getBatchSize() > 0 && this.input.rows() != this.conf.getBatchSize()) {
            throw new IllegalStateException("Illegal batch size " + this.input.rows() + " should have been " + this.conf.getBatchSize());
        }
    }

    /**
     * Wraps the given arrays in a {@link DefaultGradient}, keyed by the names
     * in {@code conf.getGradientList()} in the same positional order.
     *
     * @param gradients one gradient per entry of the gradient list, each with
     *                  the same shape as the parameter it belongs to
     * @return the populated gradient
     * @throws IllegalArgumentException if the count or any shape does not match
     */
    protected Gradient createGradient(INDArray ... gradients) {
        DefaultGradient ret = new DefaultGradient();
        List<String> gradientList = this.conf.getGradientList();
        if (gradients.length != gradientList.size()) {
            throw new IllegalArgumentException("Unable to create gradients...not equal to number of parameters");
        }
        for (int i = 0; i < gradients.length; ++i) {
            String name = gradientList.get(i);
            INDArray paramI = this.getParam(name);
            if (!Arrays.equals(paramI.shape(), gradients[i].shape())) {
                throw new IllegalArgumentException("Gradient at index " + i + " had wrong gradient size of " + Arrays.toString(gradients[i].shape()) + " when should have been " + Arrays.toString(paramI.shape()));
            }
            ret.gradientLookupTable().put(name, gradients[i]);
        }
        return ret;
    }

    /**
     * Creates a transposed copy of this layer through the runtime class's
     * public {@code (NeuralNetConfiguration, INDArray, INDArray, INDArray)}
     * constructor. The new layer receives a cloned configuration with
     * {@code nIn} and {@code nOut} swapped, plus duplicated transposes of
     * {@code W}, {@code b} and the input.
     *
     * @return the transposed layer, or {@code null} if the constructor is
     *         missing or instantiation fails (the failure is printed to
     *         standard error)
     */
    @Override
    public Layer transpose() {
        INDArray W = this.getParam("W");
        INDArray b = this.getParam("b");
        Layer layer = null;
        try {
            Constructor<?> c = this.getClass().getConstructor(NeuralNetConfiguration.class, INDArray.class, INDArray.class, INDArray.class);
            NeuralNetConfiguration clone = this.conf.clone();
            int nIn = clone.getnOut();
            int nOut = clone.getnIn();
            clone.setnIn(nIn);
            clone.setnOut(nOut);
            // Hand over the swapped configuration so its dimensions agree with
            // the transposed weights; this layer's own configuration is untouched.
            layer = (Layer)c.newInstance(clone, W.transpose().dup(), b.transpose().dup(), this.input != null ? this.input.transpose().dup() : null);
        }
        catch (Exception e) {
            // Kept as-is for compatibility: callers receive null on failure.
            e.printStackTrace();
        }
        return layer;
    }
}

