package org.deeplearning4j.nn.layers;

import java.lang.reflect.Constructor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.berkeley.Pair;
import org.deeplearning4j.nn.api.Layer;
import org.deeplearning4j.nn.api.ParamInitializer;
import org.deeplearning4j.nn.conf.NeuralNetConfiguration;
import org.deeplearning4j.nn.gradient.DefaultGradient;
import org.deeplearning4j.nn.gradient.Gradient;
import org.deeplearning4j.nn.params.DefaultParamInitializer;
import org.deeplearning4j.optimize.Solver;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.NDArrayIndex;

/* loaded from: input_file:org/deeplearning4j/nn/layers/BaseLayer.class */
/**
 * Skeleton {@link Layer} implementation that keeps the layer's current input, its named
 * parameter table and its {@link NeuralNetConfiguration}, and provides the affine
 * {@code input.mmul(W) + b} computations shared by concrete layers.
 *
 * <p>The weight matrix is looked up under {@link DefaultParamInitializer#WEIGHT_KEY} and the
 * bias under the key {@code "b"}. The order of parameters in the flattened views
 * ({@link #params()}, {@link #setParams(INDArray)}, {@link #createGradient(INDArray...)}) is
 * the order of {@code conf.getGradientList()}.
 *
 * <p>Neither constructor assigns {@link #params} or {@link #paramInitializer}; they must be
 * populated (e.g. via {@link #setParamTable(Map)} or by a subclass) before any method that
 * reads them is called.
 */
public abstract class BaseLayer implements Layer {
    /** Key under which the bias vector is stored in the parameter table. */
    private static final String BIAS_KEY = "b";

    /** Most recent input handed to this layer; may be {@code null} until one is supplied. */
    protected INDArray input;
    /** Parameter table, keyed by parameter name. */
    protected Map<String, INDArray> params;
    /** Configuration of this layer. */
    protected NeuralNetConfiguration conf;
    /** Mask produced by the last call to {@link #applyDropOutIfNecessary(INDArray)}. */
    protected INDArray dropoutMask;
    /** Strategy used by {@link #initParams()} to fill the parameter table. */
    protected ParamInitializer paramInitializer;

    /**
     * Creates a layer with the given configuration and no input.
     *
     * @param neuralNetConfiguration the configuration to use
     */
    public BaseLayer(NeuralNetConfiguration neuralNetConfiguration) {
        this.conf = neuralNetConfiguration;
    }

    /**
     * Creates a layer with the given configuration and initial input.
     *
     * @param neuralNetConfiguration the configuration to use
     * @param iNDArray               the initial input
     */
    public BaseLayer(NeuralNetConfiguration neuralNetConfiguration, INDArray iNDArray) {
        this.input = iNDArray;
        this.conf = neuralNetConfiguration;
    }

    /** No-op in this base class. */
    public void fit() {
    }

    /**
     * Replaces the configuration of this layer.
     *
     * @param neuralNetConfiguration the new configuration
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void setConf(NeuralNetConfiguration neuralNetConfiguration) {
        this.conf = neuralNetConfiguration;
    }

    /**
     * Stores a single parameter in the parameter table.
     *
     * @param str      the parameter key
     * @param iNDArray the parameter value
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public void setParam(String str, INDArray iNDArray) {
        this.params.put(str, iNDArray);
    }

    /**
     * Returns all parameters flattened into a single array, in the order given by
     * {@code conf.getGradientList()}.
     *
     * @return the flattened parameters
     */
    @Override // org.deeplearning4j.nn.api.Model
    public INDArray params() {
        List<INDArray> ordered = new ArrayList<>();
        for (String key : this.conf.getGradientList()) {
            ordered.add(this.params.get(key));
        }
        return Nd4j.toFlattened(ordered);
    }

    /**
     * Copies values from a flattened parameter array into the existing parameter arrays.
     *
     * <p>The flattened array is consumed in the order of {@code conf.getGradientList()}; each
     * slice is reshaped to the shape of the parameter it is assigned to.
     *
     * @param iNDArray flattened parameters whose length equals the summed length of all
     *                 parameters named in the gradient list
     * @throws IllegalArgumentException if the total length does not match
     * @throws IllegalStateException    if an extracted slice has a different length than the
     *                                  parameter it is meant for
     */
    public void setParams(INDArray iNDArray) {
        List<String> keys = this.conf.getGradientList();
        int expectedLength = 0;
        for (String key : keys) {
            expectedLength += getParam(key).length();
        }
        if (iNDArray.length() != expectedLength) {
            throw new IllegalArgumentException("Unable to set parameters: must be of length " + expectedLength);
        }
        // Walk the flattened array, carving out one slice per parameter.
        int offset = 0;
        for (String key : keys) {
            INDArray param = getParam(key);
            INDArray slice = iNDArray.get(new NDArrayIndex[]{NDArrayIndex.interval(offset, offset + param.length())});
            if (param.length() != slice.length()) {
                throw new IllegalStateException("Parameter " + key + " should have been of length " + param.length() + " but was " + slice.length());
            }
            // Assign in place so existing references to the parameter array stay valid.
            param.assign(slice.reshape(param.shape()));
            offset += param.length();
        }
    }

    /** Initialises the parameter table through the configured {@link ParamInitializer}. */
    @Override // org.deeplearning4j.nn.api.Layer
    public void initParams() {
        this.paramInitializer.init(paramTable(), conf());
    }

    /**
     * Returns the parameter table itself (not a copy).
     *
     * @return the parameter table
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public Map<String, INDArray> paramTable() {
        return this.params;
    }

    /**
     * Replaces the parameter table.
     *
     * @param map the new parameter table; stored by reference
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public void setParamTable(Map<String, INDArray> map) {
        this.params = map;
    }

    /**
     * Looks up a parameter by key.
     *
     * @param str the parameter key
     * @return the value stored under that key in the parameter table
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public INDArray getParam(String str) {
        return this.params.get(str);
    }

    /**
     * Stores the given input and computes the pre-activation output {@code input.mmul(W)}
     * combined with the bias.
     *
     * <p>When {@code conf.isConcatBiases()} is set the bias is horizontally stacked onto the
     * product; otherwise it is added to every row of the product in place.
     *
     * @param iNDArray the input; must not be {@code null}
     * @return the pre-activation output
     * @throws IllegalArgumentException if the input is {@code null}
     * @throws IllegalStateException    if the product and the bias have different column counts
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public INDArray preOutput(INDArray iNDArray) {
        if (iNDArray == null) {
            throw new IllegalArgumentException("No null input allowed");
        }
        this.input = iNDArray;
        INDArray bias = getParam(BIAS_KEY);
        INDArray product = this.input.mmul(getParam(DefaultParamInitializer.WEIGHT_KEY));
        if (product.columns() != bias.columns()) {
            throw new IllegalStateException("This is weird");
        }
        if (this.conf.isConcatBiases()) {
            return Nd4j.hstack(new INDArray[]{product, bias});
        }
        product.addiRowVector(bias);
        return product;
    }

    /**
     * Returns the number of rows of the current input.
     *
     * @return the row count of the stored input
     */
    public int batchSize() {
        return this.input.rows();
    }

    /**
     * Applies the configured activation function to {@code input.mmul(W)} plus the bias
     * row vector, using the currently stored input.
     *
     * @return the activations
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public INDArray activate() {
        return (INDArray) this.conf.getActivationFunction().apply(
                getInput().mmul(getParam(DefaultParamInitializer.WEIGHT_KEY)).addiRowVector(getParam(BIAS_KEY)));
    }

    /**
     * Stores the given input and returns {@link #activate()} for it.
     *
     * @param iNDArray the new input
     * @return the activations for that input
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public INDArray activate(INDArray iNDArray) {
        this.input = iNDArray;
        return activate();
    }

    /**
     * Returns {@code input.mmul(W)} plus the bias row vector for the stored input, without
     * applying the activation function.
     *
     * @return the affine transform of the stored input
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public INDArray activationMean() {
        return getInput().mmul(getParam(DefaultParamInitializer.WEIGHT_KEY)).addRowVector(getParam(BIAS_KEY));
    }

    /**
     * Returns the configuration of this layer.
     *
     * @return the configuration
     */
    @Override // org.deeplearning4j.nn.api.Layer, org.deeplearning4j.nn.api.Model
    public NeuralNetConfiguration conf() {
        return this.conf;
    }

    /**
     * Replaces the configuration of this layer.
     *
     * @param neuralNetConfiguration the new configuration
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public void setConfiguration(NeuralNetConfiguration neuralNetConfiguration) {
        this.conf = neuralNetConfiguration;
    }

    /**
     * Returns the currently stored input.
     *
     * @return the input, possibly {@code null}
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public INDArray getInput() {
        return this.input;
    }

    /**
     * Replaces the stored input.
     *
     * @param iNDArray the new input
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public void setInput(INDArray iNDArray) {
        this.input = iNDArray;
    }

    /**
     * Builds a mask with the same row and column counts as the given array, stores it in
     * {@link #dropoutMask} and multiplies the array by it in place.
     *
     * <p>When {@code conf.getDropOut()} is positive the mask is the element-wise result of
     * comparing a freshly generated random matrix against the drop-out value; otherwise the
     * mask is all ones, which leaves the array unchanged.
     *
     * @param iNDArray the array to mask; modified in place
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public void applyDropOutIfNecessary(INDArray iNDArray) {
        if (this.conf.getDropOut() > 0.0d) {
            this.dropoutMask = Nd4j.rand(iNDArray.rows(), iNDArray.columns()).gt(Double.valueOf(this.conf.getDropOut()));
        } else {
            // Shape the identity mask like the array it is applied to, as the branch above does.
            this.dropoutMask = Nd4j.ones(iNDArray.rows(), iNDArray.columns());
        }
        iNDArray.muli(this.dropoutMask);
    }

    /**
     * Adds the other layer's flattened parameters, divided by {@code i}, to this layer's
     * parameters.
     *
     * @param layer the layer whose parameters are merged in
     * @param i     the divisor applied to the other layer's parameters
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public void merge(Layer layer, int i) {
        setParams(params().addi(layer.params().divi(Integer.valueOf(i))));
    }

    /**
     * Creates a copy of this layer by reflectively invoking the runtime class's public
     * {@code (NeuralNetConfiguration, INDArray, INDArray, INDArray)} constructor with this
     * layer's configuration and duplicates of the weights, the bias and (if present) the input.
     *
     * @return the new layer, or {@code null} if it could not be constructed; in that case the
     *         stack trace of the failure is printed
     */
    @Override // org.deeplearning4j.nn.api.Layer
    /* renamed from: clone, reason: merged with bridge method [inline-methods] */
    public Layer m20clone() {
        INDArray weights = getParam(DefaultParamInitializer.WEIGHT_KEY);
        INDArray bias = getParam(BIAS_KEY);
        Layer copy = null;
        try {
            Constructor<?> constructor = getClass().getConstructor(NeuralNetConfiguration.class, INDArray.class, INDArray.class, INDArray.class);
            copy = (Layer) constructor.newInstance(
                    this.conf,
                    weights.dup(),
                    bias.dup(),
                    this.input != null ? this.input.dup() : null);
        } catch (Exception e) {
            e.printStackTrace();
        }
        return copy;
    }

    /**
     * Sums the lengths of all arrays in the parameter table.
     *
     * @return the total number of parameter values
     */
    @Override // org.deeplearning4j.nn.api.Model
    public int numParams() {
        int total = 0;
        for (INDArray param : this.params.values()) {
            total += param.length();
        }
        return total;
    }

    /**
     * Optionally replaces the stored input, then builds a {@link Solver} for this model from
     * the configuration and its listeners and runs {@code optimize()} on it.
     *
     * @param iNDArray the new input, or {@code null} to keep the current one
     */
    public void fit(INDArray iNDArray) {
        if (iNDArray != null) {
            this.input = iNDArray;
        }
        new Solver.Builder().model(this).configure(conf()).listeners(this.conf.getListeners()).build().optimize();
    }

    /**
     * Bundles {@code getGradient()} and {@code score()} into a pair.
     *
     * @return the gradient paired with the score
     */
    public Pair<Gradient, Double> gradientAndScore() {
        return new Pair<>(getGradient(), Double.valueOf(score()));
    }

    /**
     * Returns the currently stored input.
     *
     * @return the input, possibly {@code null}
     */
    @Override // org.deeplearning4j.nn.api.Model
    public INDArray input() {
        return this.input;
    }

    /**
     * Verifies that the stored input has as many rows as the configured batch size. The check
     * is skipped when the configured batch size is not positive.
     *
     * @throws IllegalStateException if the row count differs from the configured batch size
     */
    @Override // org.deeplearning4j.nn.api.Model
    public void validateInput() {
        if (this.conf.getBatchSize() > 0 && this.input.rows() != this.conf.getBatchSize()) {
            throw new IllegalStateException("Illegal batch size " + this.input.rows() + " should have been " + this.conf.getBatchSize());
        }
    }

    /**
     * Wraps the given arrays in a {@link DefaultGradient}, keyed by the entries of
     * {@code conf.getGradientList()} in order.
     *
     * @param iNDArrayArr one gradient per entry of the gradient list, each with the same shape
     *                    as the corresponding parameter
     * @return the populated gradient
     * @throws IllegalArgumentException if the number of arrays differs from the size of the
     *                                  gradient list, or if any array's shape differs from
     *                                  that of its parameter
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public Gradient createGradient(INDArray... iNDArrayArr) {
        DefaultGradient gradient = new DefaultGradient();
        List<String> keys = this.conf.getGradientList();
        if (iNDArrayArr.length != keys.size()) {
            throw new IllegalArgumentException("Unable to create gradients...not equal to number of parameters");
        }
        for (int i = 0; i < iNDArrayArr.length; i++) {
            String key = keys.get(i);
            INDArray param = getParam(key);
            if (!Arrays.equals(param.shape(), iNDArrayArr[i].shape())) {
                throw new IllegalArgumentException("Gradient at index " + i + " had wrong gradient size of " + Arrays.toString(iNDArrayArr[i].shape()) + " when should have been " + Arrays.toString(param.shape()));
            }
            gradient.gradientLookupTable().put(key, iNDArrayArr[i]);
        }
        return gradient;
    }

    /**
     * Creates a transposed counterpart of this layer by reflectively invoking the runtime
     * class's public {@code (NeuralNetConfiguration, INDArray, INDArray, INDArray)}
     * constructor with a cloned configuration whose {@code nIn} and {@code nOut} are swapped,
     * and with transposed duplicates of the weights, the bias and (if present) the input.
     *
     * @return the new layer, or {@code null} if it could not be constructed; in that case the
     *         stack trace of the failure is printed
     */
    @Override // org.deeplearning4j.nn.api.Layer
    public Layer transpose() {
        INDArray weights = getParam(DefaultParamInitializer.WEIGHT_KEY);
        INDArray bias = getParam(BIAS_KEY);
        Layer transposed = null;
        try {
            Constructor<?> constructor = getClass().getConstructor(NeuralNetConfiguration.class, INDArray.class, INDArray.class, INDArray.class);
            // Work on a clone so this layer's own configuration keeps its original sizes.
            NeuralNetConfiguration transposedConf = this.conf.m13clone();
            int swappedIn = transposedConf.getnOut();
            int swappedOut = transposedConf.getnIn();
            transposedConf.setnIn(swappedIn);
            transposedConf.setnOut(swappedOut);
            // The new layer must receive the swapped configuration, not this layer's.
            transposed = (Layer) constructor.newInstance(
                    transposedConf,
                    weights.transpose().dup(),
                    bias.transpose().dup(),
                    this.input != null ? this.input.transpose().dup() : null);
        } catch (Exception e) {
            e.printStackTrace();
        }
        return transposed;
    }
}
