package org.neuroph.nnet.learning;

import java.util.Iterator;
import org.neuroph.core.NeuralNetwork;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.core.learning.SupervisedLearning;
import org.neuroph.util.NeuralNetworkCODEC;

/* loaded from: input_file:org/neuroph/nnet/learning/SimulatedAnnealingLearning.class */
/**
 * Supervised learning rule that searches for better network weights by
 * repeatedly perturbing them at random and keeping only the perturbations
 * that lower the error measured on the training set.
 *
 * <p>The size of each perturbation is scaled by a temperature that is reset to
 * {@code startTemperature} at the beginning of every epoch and multiplied by a
 * constant factor after every cycle, so that it decays geometrically towards
 * {@code stopTemperature}.
 */
public class SimulatedAnnealingLearning extends SupervisedLearning {
    private static final long serialVersionUID = 1;

    /** Network whose weights are read into and written back from the weight arrays. */
    protected NeuralNetwork network;

    /** Temperature each epoch starts from. */
    private double startTemperature;

    /** Temperature the cooling schedule decays towards. */
    private double stopTemperature;

    /** Number of perturb/evaluate cycles performed per epoch. */
    private int cycles;

    /** Current temperature; scales the random perturbation in {@link #randomize()}. */
    protected double temperature;

    /** Working copy of the network weights (the candidate being evaluated). */
    private double[] weights;

    /** Weights of the best candidate found so far. */
    private double[] bestWeights;

    /**
     * Creates the learning rule and snapshots the current weights of the given network.
     *
     * @param neuralNetwork network to train
     * @param d             start temperature
     * @param d2            stop temperature
     * @param i             number of cycles per epoch
     */
    public SimulatedAnnealingLearning(NeuralNetwork neuralNetwork, double d, double d2, int i) {
        this.network = neuralNetwork;
        this.temperature = d;
        this.startTemperature = d;
        this.stopTemperature = d2;
        this.cycles = i;

        int weightCount = NeuralNetworkCODEC.determineArraySize(neuralNetwork);
        this.weights = new double[weightCount];
        this.bestWeights = new double[weightCount];
        NeuralNetworkCODEC.network2array(neuralNetwork, this.weights);
        NeuralNetworkCODEC.network2array(neuralNetwork, this.bestWeights);
    }

    /**
     * Creates the learning rule with start temperature 10, stop temperature 2
     * and 1000 cycles per epoch.
     *
     * @param neuralNetwork network to train
     */
    public SimulatedAnnealingLearning(NeuralNetwork neuralNetwork) {
        this(neuralNetwork, 10.0d, 2.0d, 1000);
    }

    /**
     * Returns the network passed to the constructor.
     *
     * @return the network this rule reads weights from and writes weights to
     */
    public NeuralNetwork getNetwork() {
        return this.network;
    }

    /**
     * Perturbs every working weight by a random amount in
     * {@code (-0.5, 0.5] * temperature / startTemperature} and writes the
     * perturbed weights into the network.
     */
    public void randomize() {
        for (int i = 0; i < this.weights.length; i++) {
            double step = ((0.5d - Math.random()) / this.startTemperature) * this.temperature;
            this.weights[i] = this.weights[i] + step;
        }
        NeuralNetworkCODEC.array2network(this.weights, this.network);
    }

    /**
     * Runs every row of the data set through the network and sums, per row,
     * the squared output error divided by twice the number of outputs.
     * Iteration ends early if learning has been stopped.
     *
     * @param dataSet rows to evaluate
     * @return the accumulated error over the rows that were evaluated
     */
    private double determineError(DataSet dataSet) {
        double totalError = 0.0d;
        Iterator<DataSetRow> rows = dataSet.iterator();
        while (rows.hasNext() && !isStopped()) {
            DataSetRow row = rows.next();
            this.neuralNetwork.setInput(row.getInput());
            this.neuralNetwork.calculate();
            double[] actual = this.neuralNetwork.getOutput();
            double[] desired = row.getDesiredOutput();

            // The error vector was previously a null placeholder, which made every
            // call fail with a NullPointerException. The per-output error is
            // accumulated directly here; its sign is irrelevant because it is squared.
            // Assumes the desired output has one value per network output.
            double squaredErrorSum = 0.0d;
            for (int i = 0; i < actual.length; i++) {
                double diff = desired[i] - actual[i];
                squaredErrorSum += diff * diff;
            }
            totalError += squaredErrorSum / (2 * actual.length);
        }
        return totalError;
    }

    /**
     * Performs one annealing epoch: starting from the current weights, runs
     * {@code cycles} perturb/evaluate steps, keeping a candidate only when it
     * lowers the error, and leaves the best weights found in the network.
     *
     * @param dataSet training set used to evaluate each candidate
     */
    @Override // org.neuroph.core.learning.SupervisedLearning, org.neuroph.core.learning.IterativeLearning
    public void doLearningEpoch(DataSet dataSet) {
        System.arraycopy(this.weights, 0, this.bestWeights, 0, this.weights.length);
        double bestError = determineError(dataSet);
        this.temperature = this.startTemperature;

        // Constant per-cycle factor; applying it (cycles - 1) times takes the
        // temperature from startTemperature to stopTemperature.
        double coolingFactor =
                Math.exp(Math.log(this.stopTemperature / this.startTemperature) / (this.cycles - 1));

        for (int cycle = 0; cycle < this.cycles; cycle++) {
            randomize();
            double candidateError = determineError(dataSet);
            if (candidateError < bestError) {
                // Improvement: the candidate becomes the new best.
                System.arraycopy(this.weights, 0, this.bestWeights, 0, this.weights.length);
                bestError = candidateError;
            } else {
                // No improvement: discard the candidate and continue from the best.
                System.arraycopy(this.bestWeights, 0, this.weights, 0, this.weights.length);
            }
            NeuralNetworkCODEC.array2network(this.bestWeights, this.network);
            this.temperature *= coolingFactor;
        }

        // NOTE(review): determineError() does not report to the error function, so
        // this total may not reflect bestError computed above - confirm intent.
        this.previousEpochError = getErrorFunction().getTotalError();
        if (hasReachedStopCondition()) {
            stopLearning();
        }
    }

    /**
     * Intentionally empty: this rule writes weights through
     * {@link NeuralNetworkCODEC} rather than from an output error vector.
     *
     * @param dArr ignored
     */
    @Override // org.neuroph.core.learning.SupervisedLearning
    protected void updateNetworkWeights(double[] dArr) {
    }
}
