package mulan.classifier.neural.model;

import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import mulan.core.ArgumentNullException;

/* loaded from: input_file:mulan/classifier/neural/model/BasicNeuralNet.class */
/**
 * Feed-forward neural network assembled from layers of {@link Neuron}s.
 * <p>
 * Layer 0 is an input layer made of single-input neurons with a linear activation; every
 * following layer is built with the configured activation function. Each neuron of a layer is
 * connected (via {@code Neuron.addAllNeurons}) to all neurons of the next layer.
 */
public class BasicNeuralNet implements NeuralNet, Serializable {
    private static final long serialVersionUID = -8944873770650464701L;
    private final List<List<Neuron>> layers;
    /** Output of the most recent {@link #feedForward(double[])} call; null until then or after {@link #reset()}. */
    private double[] currentNetOutput;
    private final int netInputDim;
    private final int netOutputDim;

    /**
     * Creates a network with the given layer topology.
     *
     * @param netTopology number of neurons per layer; element 0 is the input dimension and the last
     *        element is the output dimension. Must have at least two entries, all positive.
     * @param biasInput value passed to every {@link Neuron} as its third constructor argument
     *        (presumably the bias input value; NOTE(review): confirm against {@code Neuron})
     * @param activationFunction class of the activation function used by all non-input layers;
     *        one instance is created per neuron through its no-argument constructor
     * @param random random generator handed to every {@link Neuron}
     * @throws IllegalArgumentException if the topology is null, has fewer than two layers or a
     *         non-positive layer size, or if the activation function cannot be instantiated
     * @throws ArgumentNullException if {@code activationFunction} is null
     */
    public BasicNeuralNet(int[] netTopology, double biasInput,
            Class<? extends ActivationFunction> activationFunction, Random random) {
        if (netTopology == null || netTopology.length < 2) {
            throw new IllegalArgumentException("The topology for neural network is not specified or is invalid. Please provide correct topology for the network.");
        }
        if (activationFunction == null) {
            throw new ArgumentNullException("activationFunction");
        }
        // Reject empty/negative layers up front instead of building a degenerate network.
        for (int layerIndex = 0; layerIndex < netTopology.length; layerIndex++) {
            if (netTopology[layerIndex] <= 0) {
                throw new IllegalArgumentException("Layer " + layerIndex
                        + " of the network topology must contain at least one neuron, but has "
                        + netTopology[layerIndex] + ".");
            }
        }
        this.netInputDim = netTopology[0];
        this.netOutputDim = netTopology[netTopology.length - 1];
        this.layers = new ArrayList<List<Neuron>>(netTopology.length);
        this.layers.add(createInputLayer(netTopology[0], biasInput, random));
        for (int layerIndex = 1; layerIndex < netTopology.length; layerIndex++) {
            int layerSize = netTopology[layerIndex];
            int inputsPerNeuron = netTopology[layerIndex - 1];
            List<Neuron> layer = new ArrayList<Neuron>(layerSize);
            for (int n = 0; n < layerSize; n++) {
                // The activation instance is created before the Neuron, matching the original
                // evaluation order (and therefore the order in which `random` is consumed).
                layer.add(new Neuron(newActivationFunction(activationFunction), inputsPerNeuron, biasInput, random));
            }
            this.layers.add(layer);
            // Fully connect every neuron of the previous layer to the new layer.
            for (Neuron previous : this.layers.get(layerIndex - 1)) {
                previous.addAllNeurons(layer);
            }
        }
    }

    /**
     * Builds the input layer: one single-input neuron with a linear activation per input dimension.
     */
    private static List<Neuron> createInputLayer(int size, double biasInput, Random random) {
        List<Neuron> inputLayer = new ArrayList<Neuron>(size);
        for (int n = 0; n < size; n++) {
            Neuron neuron = new Neuron(new ActivationLinear(), 1, biasInput, random);
            // Relies on getWeights() exposing the neuron's live weight array. Presumably index 0 is
            // the input weight and index 1 the bias weight, so the neuron forwards its input
            // unchanged. NOTE(review): confirm against Neuron.
            double[] weights = neuron.getWeights();
            weights[0] = 1.0d;
            weights[1] = 0.0d;
            inputLayer.add(neuron);
        }
        return inputLayer;
    }

    /**
     * Instantiates the activation function through its no-argument constructor.
     *
     * @throws IllegalArgumentException if the class cannot be instantiated
     */
    private static ActivationFunction newActivationFunction(Class<? extends ActivationFunction> type) {
        try {
            // NOTE(review): Class.newInstance() is deprecated since Java 9; switch to
            // getDeclaredConstructor().newInstance() once the project's minimum Java level allows.
            return type.newInstance();
        } catch (IllegalAccessException e) {
            throw new IllegalArgumentException("Failed to create activation function instance.", e);
        } catch (InstantiationException e) {
            throw new IllegalArgumentException("Failed to create activation function instance.", e);
        }
    }

    /**
     * Returns a read-only view of the neurons in the given layer.
     *
     * @param i layer index, 0 being the input layer
     * @throws IndexOutOfBoundsException if {@code i} is not a valid layer index
     */
    @Override // mulan.classifier.neural.model.NeuralNet
    public List<Neuron> getLayerUnits(int i) {
        return Collections.unmodifiableList(this.layers.get(i));
    }

    /** Returns the number of layers, including the input layer. */
    @Override // mulan.classifier.neural.model.NeuralNet
    public int getLayersCount() {
        return this.layers.size();
    }

    /**
     * Propagates an input pattern through all layers and caches the result.
     *
     * @param dArr input pattern; its length must equal {@link #getNetInputSize()}
     * @return a copy of the output layer's values; mutating it does not affect the network
     * @throws IllegalArgumentException if the pattern is null or has the wrong length
     */
    @Override // mulan.classifier.neural.model.NeuralNet
    public double[] feedForward(double[] dArr) {
        if (dArr == null || dArr.length != this.netInputDim) {
            throw new IllegalArgumentException("Specified input pattern vector is null or does not match network input dimension.");
        }
        double[] layerInput = dArr;
        // The constructor guarantees at least two layers, so this is always assigned below.
        double[] layerOutput = null;
        for (int layerIndex = 0; layerIndex < this.layers.size(); layerIndex++) {
            List<Neuron> layer = this.layers.get(layerIndex);
            layerOutput = new double[layer.size()];
            for (int n = 0; n < layerOutput.length; n++) {
                // Input-layer neurons take exactly one element; deeper neurons see the whole
                // previous layer's output.
                double[] neuronInput = layerIndex == 0 ? new double[]{layerInput[n]} : layerInput;
                layerOutput[n] = layer.get(n).processInput(neuronInput);
            }
            // layerOutput is freshly allocated per layer and never written again afterwards, so it
            // can be handed to the next layer without copying.
            layerInput = layerOutput;
        }
        this.currentNetOutput = layerOutput;
        return Arrays.copyOf(layerOutput, layerOutput.length);
    }

    /**
     * Returns a copy of the output from the most recent {@link #feedForward(double[])} call, or an
     * all-zero array of the output size if none happened since construction or {@link #reset()}.
     */
    @Override // mulan.classifier.neural.model.NeuralNet
    public double[] getOutput() {
        if (this.currentNetOutput == null) {
            return new double[this.netOutputDim];
        }
        return Arrays.copyOf(this.currentNetOutput, this.currentNetOutput.length);
    }

    /** Clears the cached output and calls {@code reset()} on every neuron of every layer. */
    @Override // mulan.classifier.neural.model.NeuralNet
    public void reset() {
        this.currentNetOutput = null;
        for (List<Neuron> layer : this.layers) {
            for (Neuron neuron : layer) {
                neuron.reset();
            }
        }
    }

    /** Returns the input dimension (size of layer 0). */
    @Override // mulan.classifier.neural.model.NeuralNet
    public int getNetInputSize() {
        return this.netInputDim;
    }

    /** Returns the output dimension (size of the last layer). */
    @Override // mulan.classifier.neural.model.NeuralNet
    public int getNetOutputSize() {
        return this.netOutputDim;
    }
}
