package org.neuroph.core.learning;

import java.io.Serializable;
import java.util.Iterator;
import java.util.Objects;
import org.neuroph.core.Connection;
import org.neuroph.core.Layer;
import org.neuroph.core.Neuron;
import org.neuroph.core.Weight;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.core.learning.error.ErrorFunction;
import org.neuroph.core.learning.error.MeanSquaredError;
import org.neuroph.core.learning.stop.MaxErrorStop;

/* loaded from: input_file:org/neuroph/core/learning/SupervisedLearning.class */
/**
 * Base class for supervised learning rules that train a network on a {@link DataSet} of
 * input / desired-output rows, one pattern at a time.
 *
 * <p>For every row, the input is fed through the network and the configured
 * {@link ErrorFunction} computes the pattern error from the actual and desired outputs. That
 * error vector is then passed to {@link #updateNetworkWeights(double[])}, which concrete
 * subclasses implement. When batch mode is enabled, {@link #doBatchWeightsUpdate()} runs at the
 * end of every epoch and folds each weight's {@code weightChange} into its value.
 *
 * <p>A {@link MaxErrorStop} stop condition is registered when the rule is constructed.
 */
public abstract class SupervisedLearning extends IterativeLearning implements Serializable {
    private static final long serialVersionUID = 3L;

    /** Total error reported by the error function just before the current epoch reset it. */
    protected transient double previousEpochError;

    /**
     * Number of consecutive epochs whose absolute total-error change was at most
     * {@link #minErrorChange}; reset to 0 whenever the change is larger.
     */
    private transient int minErrorChangeIterationsCount;

    /** Target error; presumably read by the MaxErrorStop registered in the constructor. */
    protected double maxError = 0.01d;

    /** Error-change threshold used by {@link #afterEpoch()} to update the counter above. */
    private double minErrorChange = Double.POSITIVE_INFINITY;

    /** Limit for {@link #minErrorChangeIterationsCount}; stored here, consumed elsewhere. */
    private int minErrorChangeIterationsLimit = Integer.MAX_VALUE;

    /** When {@code true}, weights are updated once per epoch via {@link #doBatchWeightsUpdate()}. */
    private boolean batchMode = false;

    /** Error function used for per-pattern and total error; never {@code null} via the setter. */
    private ErrorFunction errorFunction = new MeanSquaredError();

    /** Creates the rule and registers a {@link MaxErrorStop} stop condition bound to it. */
    public SupervisedLearning() {
        this.stopConditions.add(new MaxErrorStop(this));
    }

    /**
     * Sets the maximum error and then trains on the given data set.
     *
     * @param dataSet training data
     * @param d maximum error to store in {@link #maxError}
     */
    public void learn(DataSet dataSet, double d) {
        this.maxError = d;
        learn(dataSet);
    }

    /**
     * Sets the maximum error and the maximum iteration count, then trains on the given data set.
     *
     * @param dataSet training data
     * @param d maximum error to store in {@link #maxError}
     * @param i maximum number of iterations, passed to {@code setMaxIterations}
     */
    public void learn(DataSet dataSet, double d, int i) {
        this.maxError = d;
        setMaxIterations(i);
        learn(dataSet);
    }

    /**
     * Resets per-run bookkeeping before training begins: delegates to the superclass, then
     * zeroes the min-error-change counter and the previous epoch error.
     *
     * <p>The decompiler noted this was originally {@code protected}; it stays {@code public}
     * here so that existing callers of the widened method keep compiling.
     */
    @Override
    public void onStart() {
        super.onStart();
        this.minErrorChangeIterationsCount = 0;
        this.previousEpochError = 0.0d;
    }

    /**
     * Snapshots the error function's total error as {@link #previousEpochError}, then resets
     * the error function so the coming epoch accumulates from scratch.
     */
    @Override
    protected void beforeEpoch() {
        this.previousEpochError = this.errorFunction.getTotalError();
        this.errorFunction.reset();
    }

    /**
     * Updates the min-error-change counter from this epoch's error delta and, in batch mode,
     * applies the pending weight changes.
     */
    @Override
    protected void afterEpoch() {
        double absErrorChange = Math.abs(this.previousEpochError - this.errorFunction.getTotalError());
        if (absErrorChange <= this.minErrorChange) {
            this.minErrorChangeIterationsCount++;
        } else {
            // Any sufficiently large change breaks the streak of "stalled" epochs.
            this.minErrorChangeIterationsCount = 0;
        }
        if (this.batchMode) {
            doBatchWeightsUpdate();
        }
    }

    /**
     * Runs one epoch: learns each row of the data set in iteration order, stopping early if
     * the rule has been stopped.
     *
     * @param dataSet training data for this epoch
     */
    @Override
    public void doLearningEpoch(DataSet dataSet) {
        Iterator<DataSetRow> it = dataSet.iterator();
        // isStopped() is checked per row so a stop request takes effect mid-epoch.
        while (it.hasNext() && !isStopped()) {
            learnPattern(it.next());
        }
    }

    /**
     * Feeds one row through the network, computes its pattern error with the configured error
     * function, and hands the resulting error vector to {@link #updateNetworkWeights(double[])}.
     *
     * @param dataSetRow the training row (input and desired output)
     */
    protected void learnPattern(DataSetRow dataSetRow) {
        this.neuralNetwork.setInput(dataSetRow.getInput());
        this.neuralNetwork.calculate();
        double[] patternError =
                this.errorFunction.calculatePatternError(
                        this.neuralNetwork.getOutput(), dataSetRow.getDesiredOutput());
        updateNetworkWeights(patternError);
    }

    /**
     * Applies pending weight changes: for every input connection of every neuron in layers
     * {@code getLayersCount()-1} down to {@code 1}, adds {@code weightChange} to the weight's
     * value and clears {@code weightChange}.
     *
     * <p>Layer 0 is skipped, presumably because it is the input layer and has no trainable
     * incoming connections. TODO confirm against the network implementation.
     */
    protected void doBatchWeightsUpdate() {
        Layer[] layers = this.neuralNetwork.getLayers();
        for (int layerIdx = this.neuralNetwork.getLayersCount() - 1; layerIdx > 0; layerIdx--) {
            for (Neuron neuron : layers[layerIdx].getNeurons()) {
                for (Connection connection : neuron.getInputConnections()) {
                    Weight weight = connection.getWeight();
                    weight.value += weight.weightChange;
                    weight.weightChange = 0.0d;
                }
            }
        }
    }

    /** @return {@code true} if weights are updated once per epoch rather than per pattern */
    public boolean isInBatchMode() {
        return this.batchMode;
    }

    /** @param z {@code true} to apply weight changes once per epoch in {@link #afterEpoch()} */
    public void setBatchMode(boolean z) {
        this.batchMode = z;
    }

    /** @param d the maximum (target) error */
    public void setMaxError(double d) {
        this.maxError = d;
    }

    /** @return the maximum (target) error; 0.01 unless changed */
    public double getMaxError() {
        return this.maxError;
    }

    /**
     * @return the total error snapshotted at the start of the current epoch (0 right after
     *     {@link #onStart()})
     */
    public double getPreviousEpochError() {
        return this.previousEpochError;
    }

    /** @return the error-change threshold used by {@link #afterEpoch()} */
    public double getMinErrorChange() {
        return this.minErrorChange;
    }

    /** @param d the error-change threshold used by {@link #afterEpoch()} */
    public void setMinErrorChange(double d) {
        this.minErrorChange = d;
    }

    /** @return the configured limit for consecutive small-error-change epochs */
    public int getMinErrorChangeIterationsLimit() {
        return this.minErrorChangeIterationsLimit;
    }

    /** @param i the limit for consecutive small-error-change epochs */
    public void setMinErrorChangeIterationsLimit(int i) {
        this.minErrorChangeIterationsLimit = i;
    }

    /** @return the current number of consecutive small-error-change epochs */
    public int getMinErrorChangeIterationsCount() {
        return this.minErrorChangeIterationsCount;
    }

    /** @return the error function in use (a {@link MeanSquaredError} by default) */
    public ErrorFunction getErrorFunction() {
        return this.errorFunction;
    }

    /**
     * Sets the error function used to compute pattern and total error.
     *
     * @param errorFunction the new error function; must not be {@code null}
     * @throws NullPointerException if {@code errorFunction} is {@code null}
     */
    public void setErrorFunction(ErrorFunction errorFunction) {
        // Fail fast: a null error function would otherwise only surface as an NPE mid-training
        // (beforeEpoch / learnPattern) or from getTotalNetworkError().
        this.errorFunction = Objects.requireNonNull(errorFunction, "errorFunction");
    }

    /**
     * @return the error function's total error accumulated since its last reset (it is reset
     *     at the start of every epoch in {@link #beforeEpoch()})
     */
    public double getTotalNetworkError() {
        return this.errorFunction.getTotalError();
    }

    /**
     * Updates network weights for a single pattern.
     *
     * @param dArr the pattern error vector returned by the error function's
     *     {@code calculatePatternError} (presumably one entry per output neuron; TODO confirm)
     */
    protected abstract void updateNetworkWeights(double[] dArr);
}
