package mulan.classifier.neural;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import mulan.classifier.neural.model.NeuralNet;
import mulan.classifier.neural.model.Neuron;

/* loaded from: input_file:mulan/classifier/neural/BPMLLAlgorithm.class */
/**
 * Back-propagation style learner that adapts the weights of a {@link NeuralNet}
 * from pairs of (input pattern, expected label vector).
 *
 * <p>A label is treated as relevant when its entry in the expected labels vector
 * equals {@code 1.0}; every other value is treated as not relevant. The error of
 * each output neuron is built from pairwise differences between the outputs of
 * relevant and non-relevant labels, and a weight-decay term is applied on every
 * weight update.
 */
public class BPMLLAlgorithm {
    private final NeuralNet neuralNet;
    private final double weightsDecayCost;

    /**
     * Creates a learner bound to the given network.
     *
     * @param neuralNet the network whose weights will be adapted; must not be null
     * @param weightsDecayCost the weight-decay regularization cost, in the range (0, 1]
     * @throws IllegalArgumentException if the network is null or the cost term is
     *         outside (0, 1] (including NaN)
     */
    public BPMLLAlgorithm(NeuralNet neuralNet, double weightsDecayCost) {
        if (neuralNet == null) {
            throw new IllegalArgumentException("The passed neural network model is null.");
        }
        // Written as a negated in-range test so that NaN is rejected as well:
        // every ordered comparison involving NaN evaluates to false.
        if (!(weightsDecayCost > 0.0d && weightsDecayCost <= 1.0d)) {
            throw new IllegalArgumentException("The weights decay regularization cost term must be greater than 0 and no more than 1. The passed value is : " + weightsDecayCost);
        }
        this.neuralNet = neuralNet;
        this.weightsDecayCost = weightsDecayCost;
    }

    /**
     * Returns the network this learner operates on.
     *
     * @return the underlying network (the same instance passed to the constructor)
     */
    public NeuralNet getNetwork() {
        return this.neuralNet;
    }

    /**
     * Returns the weight-decay regularization cost term.
     *
     * @return the cost term passed to the constructor
     */
    public double getWeightsDecayCost() {
        return this.weightsDecayCost;
    }

    /**
     * Performs one learning step for a single example: feeds the input forward,
     * back-propagates the error terms and updates the weights of every layer
     * except layer 0.
     *
     * <p>If the expected labels are all relevant or all non-relevant, no pairwise
     * error can be formed; in that case no weight is modified and
     * {@link Double#NaN} is returned.
     *
     * @param inputPattern the input vector; its length must match the network input size
     * @param expectedLabels the expected labels; its length must match the network output size
     * @param learningRate the factor applied to the error-driven part of each weight change
     * @return the sum of absolute output errors plus
     *         {@code weightsDecayCost * 0.5 * (sum of squared weights before the update)},
     *         or {@code NaN} when the example cannot be learned from
     * @throws IllegalArgumentException if either vector is null or has a wrong length
     */
    public double learn(double[] inputPattern, double[] expectedLabels, double learningRate) {
        checkExampleDimensions(inputPattern, expectedLabels);

        double[] outputErrors =
                computeErrorsForNeurons(this.neuralNet.feedForward(inputPattern), expectedLabels);
        if (outputErrors == null) {
            return Double.NaN;
        }

        int layersCount = this.neuralNet.getLayersCount();
        int outputLayer = layersCount - 1;

        // Pass 1: back-propagate the error terms through the whole network while the
        // weights are still the ones that produced the forward pass. Updating a layer
        // before its predecessor's error terms are computed would propagate the error
        // through already-modified weights.
        double weightsSquareSum = 0.0d;
        for (int layer = outputLayer; layer > 0; layer--) {
            List<Neuron> units = this.neuralNet.getLayerUnits(layer);
            if (layer == outputLayer) {
                computeOutputLayerErrorTerms(units, outputErrors);
            } else {
                computeHiddenLayerErrorTerms(units, this.neuralNet.getLayerUnits(layer + 1));
            }
            weightsSquareSum = addSquaredWeights(weightsSquareSum, units);
        }

        // Pass 2: apply the weight changes, now that every neuron holds its error term.
        for (int layer = outputLayer; layer > 0; layer--) {
            double[] layerInputs = collectOutputs(this.neuralNet.getLayerUnits(layer - 1));
            updateWeights(this.neuralNet.getLayerUnits(layer), layerInputs, learningRate);
        }

        return sumOfAbsoluteValues(outputErrors) + (this.weightsDecayCost * 0.5d * weightsSquareSum);
    }

    /**
     * Computes the error of the network for a single example without changing
     * any weight.
     *
     * @param inputPattern the input vector; its length must match the network input size
     * @param expectedLabels the expected labels; its length must match the network output size
     * @return the sum of absolute output errors plus
     *         {@code weightsDecayCost * 0.5 * (sum of squared weights)}, or
     *         {@link Double#NaN} when the labels are all relevant or all non-relevant
     * @throws IllegalArgumentException if either vector is null or has a wrong length
     */
    public double getNetworkError(double[] inputPattern, double[] expectedLabels) {
        checkExampleDimensions(inputPattern, expectedLabels);

        double[] outputErrors =
                computeErrorsForNeurons(this.neuralNet.feedForward(inputPattern), expectedLabels);
        if (outputErrors == null) {
            return Double.NaN;
        }

        double weightsSquareSum = 0.0d;
        int layersCount = this.neuralNet.getLayersCount();
        for (int layer = 1; layer < layersCount; layer++) {
            weightsSquareSum = addSquaredWeights(weightsSquareSum, this.neuralNet.getLayerUnits(layer));
        }

        return sumOfAbsoluteValues(outputErrors) + (this.weightsDecayCost * 0.5d * weightsSquareSum);
    }

    /** Rejects example vectors that are null or do not match the network dimensions. */
    private void checkExampleDimensions(double[] inputPattern, double[] expectedLabels) {
        if (inputPattern == null || inputPattern.length != this.neuralNet.getNetInputSize()) {
            throw new IllegalArgumentException("Specified input pattern vector is null or does not match the input dimension of underlying neural network model.");
        }
        if (expectedLabels == null || expectedLabels.length != this.neuralNet.getNetOutputSize()) {
            throw new IllegalArgumentException("Specified expected labels vector is null or does not match the output dimension of underlying neural network model.");
        }
    }

    /** Returns {@code runningSum} increased by the square of every weight of the given neurons. */
    private static double addSquaredWeights(double runningSum, List<Neuron> units) {
        double sum = runningSum;
        for (Neuron unit : units) {
            for (double weight : unit.getWeights()) {
                sum += weight * weight;
            }
        }
        return sum;
    }

    /** Returns the sum of the absolute values of the given numbers. */
    private static double sumOfAbsoluteValues(double[] values) {
        double sum = 0.0d;
        for (double value : values) {
            sum += Math.abs(value);
        }
        return sum;
    }

    /** Copies the current output of each neuron into an array, preserving list order. */
    private static double[] collectOutputs(List<Neuron> units) {
        int count = units.size();
        double[] outputs = new double[count];
        for (int i = 0; i < count; i++) {
            outputs[i] = units.get(i).getOutput();
        }
        return outputs;
    }

    /**
     * Adjusts, in place, the weights of every neuron of one layer.
     *
     * <p>Each weight changes by {@code learningRate * error * input} minus
     * {@code weightsDecayCost * weight}. The weight stored at index
     * {@code inputs.length} is paired with the neuron's bias input.
     *
     * @param units the neurons whose weights are updated; their error terms must be set
     * @param inputs the outputs of the preceding layer, in weight order
     * @param learningRate the factor applied to the error-driven part of the change
     */
    private void updateWeights(List<Neuron> units, double[] inputs, double learningRate) {
        int inputCount = inputs.length;
        for (Neuron unit : units) {
            double[] weights = unit.getWeights();
            double scaledError = learningRate * unit.getError();
            for (int i = 0; i < inputCount; i++) {
                weights[i] += (scaledError * inputs[i]) - (this.weightsDecayCost * weights[i]);
            }
            // The slot right after the regular inputs holds the bias weight.
            weights[inputCount] += (scaledError * unit.getBiasInput())
                    - (this.weightsDecayCost * weights[inputCount]);
        }
    }

    /**
     * Stores in each output neuron its error term: the neuron's output error
     * multiplied by the derivative of its activation function at the neuron input.
     */
    private void computeOutputLayerErrorTerms(List<Neuron> outputUnits, double[] outputErrors) {
        int count = outputUnits.size();
        for (int i = 0; i < count; i++) {
            Neuron unit = outputUnits.get(i);
            double slope = unit.getActivationFunction().derivative(unit.getNeuronInput());
            unit.setError(outputErrors[i] * slope);
        }
    }

    /**
     * Stores in each neuron of a hidden layer its error term: the sum, over the
     * neurons of the following layer, of their error term times the weight they
     * hold at this neuron's index, multiplied by the derivative of the activation
     * function at the neuron input.
     */
    private void computeHiddenLayerErrorTerms(List<Neuron> hiddenUnits, List<Neuron> nextLayerUnits) {
        int hiddenCount = hiddenUnits.size();
        for (int i = 0; i < hiddenCount; i++) {
            double propagated = 0.0d;
            for (Neuron next : nextLayerUnits) {
                propagated += next.getError() * next.getWeights()[i];
            }
            Neuron unit = hiddenUnits.get(i);
            unit.setError(propagated * unit.getActivationFunction().derivative(unit.getNeuronInput()));
        }
    }

    /**
     * Computes the error of every output neuron from pairwise output differences.
     *
     * <p>For a relevant label {@code i} the error is the sum of
     * {@code exp(-(out[i] - out[j]))} over all non-relevant labels {@code j}; for a
     * non-relevant label {@code i} it is minus the sum of
     * {@code exp(-(out[j] - out[i]))} over all relevant labels {@code j}. Both are
     * scaled by {@code 1 / (relevantCount * nonRelevantCount)}.
     *
     * @param outputs the network outputs
     * @param expectedLabels the expected labels; an entry equal to {@code 1.0} marks a relevant label
     * @return the per-neuron errors, or {@code null} when the labels are all
     *         relevant or all non-relevant
     */
    private double[] computeErrorsForNeurons(double[] outputs, double[] expectedLabels) {
        int labelCount = expectedLabels.length;

        int relevantCount = 0;
        for (double label : expectedLabels) {
            if (label == 1.0d) {
                relevantCount++;
            }
        }
        int nonRelevantCount = labelCount - relevantCount;
        if (relevantCount == 0 || nonRelevantCount == 0) {
            return null;
        }

        // Multiply as doubles so the product of the two counts cannot overflow an int.
        double normalization = 1.0d / ((double) relevantCount * nonRelevantCount);

        double[] errors = new double[labelCount];
        for (int i = 0; i < labelCount; i++) {
            boolean relevant = expectedLabels[i] == 1.0d;
            double sum = 0.0d;
            for (int j = 0; j < labelCount; j++) {
                if ((expectedLabels[j] == 1.0d) == relevant) {
                    continue; // only labels from the opposite group form a pair with i
                }
                if (relevant) {
                    sum += Math.exp(-(outputs[i] - outputs[j]));
                } else {
                    sum -= Math.exp(-(outputs[j] - outputs[i]));
                }
            }
            errors[i] = sum * normalization;
        }
        return errors;
    }
}
