org.encog.neural.networks.training.lma
Class LevenbergMarquardtTraining

java.lang.Object
  extended by org.encog.ml.train.BasicTraining
      extended by org.encog.neural.networks.training.lma.LevenbergMarquardtTraining
All Implemented Interfaces:
MLTrain, MultiThreadable

public class LevenbergMarquardtTraining
extends BasicTraining
implements MultiThreadable

Trains a neural network using the Levenberg-Marquardt algorithm (LMA). This training technique is based on the mathematical technique of the same name. LMA interpolates between the Gauss-Newton algorithm (GNA) and the method of gradient descent (similar to what is used by backpropagation). The lambda parameter determines the degree to which GNA and gradient descent are used. A lower lambda results in heavier use of GNA, whereas a higher lambda results in heavier use of gradient descent. Each iteration starts with a low lambda that is increased if the iteration does not improve the neural network sufficiently. Once lambda is high enough, the training method reverts entirely to gradient descent. This allows the neural network to be trained effectively in cases where GNA provides the optimal training time, while retaining the ability to fall back to the more primitive gradient descent method. LMA finds only a local minimum, not a global minimum.

References:
http://www.heatonresearch.com/wiki/LMA
http://en.wikipedia.org/wiki/Levenberg%E2%80%93Marquardt_algorithm
http://en.wikipedia.org/wiki/Finite_difference_method
http://crsouza.blogspot.com/2009/11/neural-network-learning-by-levenberg_18.html
http://mathworld.wolfram.com/FiniteDifference.html
http://www-alg.ist.hokudai.ac.jp/~jan/alpha.pdf
http://www.inference.phy.cam.ac.uk/mackay/Bayes_FAQ.html
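
For orientation, a minimal training-loop sketch in the style of other Encog examples. The XOR data set, the three-layer network topology, and the 0.01 error target are illustrative assumptions, not part of this class's contract:

    import org.encog.engine.network.activation.ActivationSigmoid;
    import org.encog.ml.data.MLDataSet;
    import org.encog.ml.data.basic.BasicMLDataSet;
    import org.encog.ml.train.MLTrain;
    import org.encog.neural.networks.BasicNetwork;
    import org.encog.neural.networks.layers.BasicLayer;
    import org.encog.neural.networks.training.lma.LevenbergMarquardtTraining;

    // Illustrative XOR training data (an assumption for this sketch).
    double[][] input = { {0, 0}, {1, 0}, {0, 1}, {1, 1} };
    double[][] ideal = { {0}, {1}, {1}, {0} };
    MLDataSet trainingSet = new BasicMLDataSet(input, ideal);

    // A small feedforward network with a single output neuron.
    BasicNetwork network = new BasicNetwork();
    network.addLayer(new BasicLayer(null, true, 2));
    network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 3));
    network.addLayer(new BasicLayer(new ActivationSigmoid(), false, 1));
    network.getStructure().finalizeStructure();
    network.reset();

    // Train with LMA until the (arbitrary) error target is reached.
    MLTrain train = new LevenbergMarquardtTraining(network, trainingSet);
    int epoch = 1;
    do {
        train.iteration();
        System.out.println("Epoch #" + epoch + " Error: " + train.getError());
        epoch++;
    } while (train.getError() > 0.01);
    train.finishTraining();
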


Field Summary
static double LAMBDA_MAX
          The maximum value allowed for lambda.
static double SCALE_LAMBDA
          The amount to scale the lambda by.
 
Constructor Summary
LevenbergMarquardtTraining(BasicNetwork network, MLDataSet training)
          Construct the LMA object.
LevenbergMarquardtTraining(BasicNetwork network, MLDataSet training, ComputeHessian h)
          Construct the LMA object.
 
Method Summary
 boolean canContinue()
           
 ComputeHessian getHessian()
           
 MLMethod getMethod()
          Get the current best machine learning method from the training.
 int getThreadCount()
           
 void iteration()
          Perform one iteration.
 TrainingContinuation pause()
          Pause the training to continue later.
 void resume(TrainingContinuation state)
          Resume training.
 void setThreadCount(int numThreads)
          Set the number of threads to use.
 void updateWeights()
          Update the weights in the neural network.
 
Methods inherited from class org.encog.ml.train.BasicTraining
addStrategy, finishTraining, getError, getImplementationType, getIteration, getStrategies, getTraining, isTrainingDone, iteration, postIteration, preIteration, setError, setIteration, setTraining
 
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 

Field Detail

SCALE_LAMBDA

public static final double SCALE_LAMBDA
The amount to scale the lambda by.

See Also:
Constant Field Values

LAMBDA_MAX

public static final double LAMBDA_MAX
The maximum value allowed for lambda.

See Also:
Constant Field Values
Constructor Detail

LevenbergMarquardtTraining

public LevenbergMarquardtTraining(BasicNetwork network,
                                  MLDataSet training)
Construct the LMA object.

Parameters:
network - The network to train. Must have a single output neuron.
training - The training data to use. Must be indexable.

LevenbergMarquardtTraining

public LevenbergMarquardtTraining(BasicNetwork network,
                                  MLDataSet training,
                                  ComputeHessian h)
Construct the LMA object.

Parameters:
network - The network to train. Must have a single output neuron.
training - The training data to use. Must be indexable.
h - The Hessian calculation method to use.
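
As an illustrative sketch (not from the original Javadoc), the three-argument constructor can supply an explicit Hessian strategy. HessianFD, a finite-difference ComputeHessian implementation, is assumed to be available here; the network and trainingSet are the ones from the class-description sketch:

    // Sketch: explicitly choose the Hessian calculation strategy.
    // HessianFD is assumed as one available ComputeHessian implementation;
    // the two-argument constructor otherwise picks a default for you.
    ComputeHessian hessian = new HessianFD();
    LevenbergMarquardtTraining train =
        new LevenbergMarquardtTraining(network, trainingSet, hessian);
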
Method Detail

canContinue

public boolean canContinue()
Specified by:
canContinue in interface MLTrain
Returns:
True if the training can be paused and later continued.

getMethod

public MLMethod getMethod()
Description copied from interface: MLTrain
Get the current best machine learning method from the training.

Specified by:
getMethod in interface MLTrain
Returns:
The trained network.

iteration

public void iteration()
Perform one iteration.

Specified by:
iteration in interface MLTrain

pause

public TrainingContinuation pause()
Pause the training to continue later.

Specified by:
pause in interface MLTrain
Returns:
A training continuation object.

resume

public void resume(TrainingContinuation state)
Resume training.

Specified by:
resume in interface MLTrain
Parameters:
state - The training continuation object to use to continue.
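
A hedged sketch of the pause/resume contract inherited from MLTrain, reusing the train object from the earlier sketch; whether this trainer actually supports continuation is reported by canContinue():

    // Only pause and resume if the trainer supports continuation.
    if (train.canContinue()) {
        TrainingContinuation state = train.pause();
        // ... the state object could be stored and reloaded later ...
        train.resume(state);
    }
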

updateWeights

public void updateWeights()
Update the weights in the neural network.


getHessian

public ComputeHessian getHessian()
Returns:
The Hessian calculation method used.

getThreadCount

public int getThreadCount()
Specified by:
getThreadCount in interface MultiThreadable
Returns:
The number of threads to use, or 0 to determine automatically based on the core count.

setThreadCount

public void setThreadCount(int numThreads)
Description copied from interface: MultiThreadable
Set the number of threads to use.

Specified by:
setThreadCount in interface MultiThreadable
Parameters:
numThreads - The number of threads to use, or zero to automatically determine based on core count.
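
A short hedged sketch of the MultiThreadable contract, again reusing the network and trainingSet from the class-description sketch; the thread count of 4 is an arbitrary illustration:

    LevenbergMarquardtTraining lma =
        new LevenbergMarquardtTraining(network, trainingSet);
    lma.setThreadCount(4);  // use four worker threads
    // lma.setThreadCount(0);  // or let Encog choose based on available cores
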

