java.lang.Object
  extended by org.encog.ml.train.BasicTraining
      extended by org.encog.neural.networks.training.lma.LevenbergMarquardtTraining
public class LevenbergMarquardtTraining
extends BasicTraining
implements MultiThreadable
Trains a neural network using the Levenberg-Marquardt algorithm (LMA). This training technique is based on the mathematical technique of the same name. The LMA interpolates between the Gauss-Newton algorithm (GNA) and the method of gradient descent (similar to what is used by backpropagation). The lambda parameter determines the degree to which GNA and gradient descent are used: a lower lambda results in heavier use of GNA, whereas a higher lambda results in heavier use of gradient descent.

Each iteration starts with a low lambda, which is scaled up whenever a step fails to improve the neural network. At some point lambda is high enough that the training method reverts entirely to gradient descent. This allows the neural network to be trained effectively in cases where GNA provides the optimal training time, while keeping the ability to fall back to the more primitive gradient descent method.

Note that LMA finds only a local minimum, not a global minimum.

References:
- http://www.heatonresearch.com/wiki/LMA
- http://en.wikipedia.org/wiki/Levenberg%E2%80%93Marquardt_algorithm
- http://en.wikipedia.org/wiki/Finite_difference_method
- http://crsouza.blogspot.com/2009/11/neural-network-learning-by-levenberg_18.html
- http://mathworld.wolfram.com/FiniteDifference.html
- http://www-alg.ist.hokudai.ac.jp/~jan/alpha.pdf
- http://www.inference.phy.cam.ac.uk/mackay/Bayes_FAQ.html
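The iteration described above can be sketched in plain Java. This is a minimal, hypothetical example, not Encog's implementation: instead of a network it fits a two-parameter curve y = a·exp(b·x) (a model with a single output), solving (JᵀJ + λI)δ = Jᵀr at each step, where J is the Jacobian of the model with respect to (a, b) and r holds the residuals. The class name `LmaSketch`, the constants `SCALE` and `LAMBDA_MAX`, and their values are assumptions; the two constants only play roles similar to the `SCALE_LAMBDA` and `LAMBDA_MAX` fields documented on this page.

```java
import java.util.Arrays;

/**
 * Hypothetical Levenberg-Marquardt sketch (not Encog's implementation).
 * Fits y = a * exp(b * x) by least squares, damping with lambda * I.
 */
public class LmaSketch {
    // Assumed values; they play roles similar to SCALE_LAMBDA and LAMBDA_MAX.
    static final double SCALE = 10.0;
    static final double LAMBDA_MAX = 1e25;

    /** Sum of squared residuals for parameters (a, b). */
    static double sse(double a, double b, double[] xs, double[] ys) {
        double sum = 0.0;
        for (int i = 0; i < xs.length; i++) {
            double r = ys[i] - a * Math.exp(b * xs[i]);
            sum += r * r;
        }
        return sum;
    }

    /** Runs at most maxIter LM iterations from (a, b) and returns {a, b}. */
    static double[] fit(double[] xs, double[] ys, double a, double b, int maxIter) {
        double lambda = 0.01; // start low: mostly Gauss-Newton
        double err = sse(a, b, xs, ys);
        for (int iter = 0; iter < maxIter; iter++) {
            // Accumulate J^T J (symmetric 2x2) and J^T r; J holds d(model)/d(a, b).
            double h00 = 0, h01 = 0, h11 = 0, g0 = 0, g1 = 0;
            for (int i = 0; i < xs.length; i++) {
                double e = Math.exp(b * xs[i]);
                double ja = e;             // d/da of a*exp(b*x)
                double jb = a * xs[i] * e; // d/db of a*exp(b*x)
                double r = ys[i] - a * e;  // residual
                h00 += ja * ja; h01 += ja * jb; h11 += jb * jb;
                g0 += ja * r;   g1 += jb * r;
            }
            boolean improved = false;
            while (!improved && lambda <= LAMBDA_MAX) {
                // Solve (J^T J + lambda * I) * delta = J^T r by Cramer's rule.
                double m00 = h00 + lambda, m11 = h11 + lambda;
                double det = m00 * m11 - h01 * h01;
                double da = (g0 * m11 - h01 * g1) / det;
                double db = (m00 * g1 - h01 * g0) / det;
                double trial = sse(a + da, b + db, xs, ys);
                if (trial < err) {
                    a += da; b += db; err = trial;
                    lambda /= SCALE; // step helped: lean toward Gauss-Newton
                    improved = true;
                } else {
                    lambda *= SCALE; // step hurt: lean toward gradient descent
                }
            }
            if (!improved) break; // lambda exceeded its cap without progress
        }
        return new double[] {a, b};
    }

    public static void main(String[] args) {
        double[] xs = {0.0, 0.5, 1.0, 1.5, 2.0};
        double[] ys = new double[xs.length];
        for (int i = 0; i < xs.length; i++) ys[i] = 2.0 * Math.exp(0.5 * xs[i]);
        System.out.println(Arrays.toString(fit(xs, ys, 1.0, 0.0, 200)));
    }
}
```

With exact data generated from a = 2, b = 0.5 and a start at (1, 0), the fit should land close to (2, 0.5). The sketch damps with λI, as in Levenberg's original formulation; Marquardt's variant scales the damping by diag(JᵀJ) instead. This page does not say which form Encog uses.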
Field Summary
| Modifier and Type | Field | Description |
|---|---|---|
| static double | LAMBDA_MAX | The maximum value for lambda. |
| static double | SCALE_LAMBDA | The amount to scale lambda by. |
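Why scaling lambda up pushes training toward gradient descent can be checked numerically. In this hypothetical sketch (not Encog code), the matrix `h` stands in for JᵀJ and the vector `g` for Jᵀr, and each step δ solves (h + λI)δ = g. As λ approaches zero, δ approaches the Gauss-Newton step h⁻¹g; as λ grows, λδ approaches g, so δ becomes a short step of length about |g|/λ along the steepest-descent direction. The name `LambdaLimits` and the numbers are illustrative assumptions.

```java
/** Hypothetical demo: how lambda moves an LM step between Gauss-Newton and gradient descent. */
public class LambdaLimits {
    /** Solves (h + lambda * I) * delta = g for a symmetric 2x2 h by Cramer's rule. */
    static double[] step(double[][] h, double[] g, double lambda) {
        double m00 = h[0][0] + lambda, m11 = h[1][1] + lambda, m01 = h[0][1];
        double det = m00 * m11 - m01 * m01;
        return new double[] {(g[0] * m11 - m01 * g[1]) / det, (m00 * g[1] - m01 * g[0]) / det};
    }

    public static void main(String[] args) {
        double[][] h = {{4.0, 1.0}, {1.0, 3.0}}; // stands in for J^T J
        double[] g = {1.0, 2.0};                 // stands in for J^T r (the descent direction)
        for (double lambda : new double[] {0.0, 1e-6, 1.0, 1e6}) {
            double[] d = step(h, g, lambda);
            // Small lambda: delta is close to the Gauss-Newton step (1/11, 7/11).
            // Large lambda: lambda * delta is close to g, a short gradient-descent step.
            System.out.printf("lambda=%-8g delta=(%.6g, %.6g) lambda*delta=(%.6g, %.6g)%n",
                    lambda, d[0], d[1], lambda * d[0], lambda * d[1]);
        }
    }
}
```

This is also why a cap such as `LAMBDA_MAX` is useful: once λ is very large, further increases only shrink an already gradient-like step.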
Constructor Summary
| Constructor | Description |
|---|---|
| LevenbergMarquardtTraining(BasicNetwork network, MLDataSet training) | Construct the LMA object. |
| LevenbergMarquardtTraining(BasicNetwork network, MLDataSet training, ComputeHessian h) | Construct the LMA object. |
Method Summary
| Modifier and Type | Method | Description |
|---|---|---|
| boolean | canContinue() | |
| ComputeHessian | getHessian() | Returns the object used to compute the Hessian. |
| MLMethod | getMethod() | Get the current best machine learning method from the training. |
| int | getThreadCount() | Returns the number of threads to use. |
| void | iteration() | Perform one iteration. |
| TrainingContinuation | pause() | Pause the training to continue later. |
| void | resume(TrainingContinuation state) | Resume training. |
| void | setThreadCount(int numThreads) | Set the number of threads to use. |
| void | updateWeights() | Update the weights in the neural network. |
Methods inherited from class org.encog.ml.train.BasicTraining:
addStrategy, finishTraining, getError, getImplementationType, getIteration, getStrategies, getTraining, isTrainingDone, iteration, postIteration, preIteration, setError, setIteration, setTraining

Methods inherited from class java.lang.Object:
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Field Detail

public static final double SCALE_LAMBDA
The amount to scale lambda by.

public static final double LAMBDA_MAX
The maximum value for lambda.
Constructor Detail

public LevenbergMarquardtTraining(BasicNetwork network, MLDataSet training)
Construct the LMA object.
Parameters:
network - The network to train. Must have a single output neuron.
training - The training data to use. Must be indexable.

public LevenbergMarquardtTraining(BasicNetwork network, MLDataSet training, ComputeHessian h)
Construct the LMA object.
Parameters:
network - The network to train. Must have a single output neuron.
training - The training data to use. Must be indexable.

Method Detail
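public boolean canContinue()
Specified by: canContinue in interface MLTrain

public MLMethod getMethod()
Get the current best machine learning method from the training.
Specified by: getMethod in interface MLTrain

public void iteration()
Perform one iteration.
Specified by: iteration in interface MLTrain

public TrainingContinuation pause()
Pause the training to continue later.
Specified by: pause in interface MLTrain

public void resume(TrainingContinuation state)
Resume training.
Specified by: resume in interface MLTrain
Parameters:
state - The training continuation object to use to continue.

public void updateWeights()
Update the weights in the neural network.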

public ComputeHessian getHessian()
Returns the object used to compute the Hessian.
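The class description cites finite-difference references, and the second constructor accepts a `ComputeHessian`. As an assumption about the kind of quantity such an object provides, the following hypothetical sketch (not Encog's `ComputeHessian` API) estimates the Jacobian of a single-output model by central differences and forms JᵀJ, the Gauss-Newton approximation of the Hessian of half the sum of squared errors. The names `FdHessian` and `Model` and the step size are illustrative.

```java
import java.util.Arrays;

/** Hypothetical sketch: Gauss-Newton Hessian J^T J built from a finite-difference Jacobian. */
public class FdHessian {
    /** Hypothetical single-output model f(params, x), standing in for a network. */
    interface Model {
        double eval(double[] params, double x);
    }

    /** Central-difference Jacobian: one row per sample, one column per parameter. */
    static double[][] jacobian(Model f, double[] p, double[] xs, double step) {
        double[][] j = new double[xs.length][p.length];
        for (int k = 0; k < p.length; k++) {
            double[] plus = p.clone();
            double[] minus = p.clone();
            plus[k] += step;
            minus[k] -= step;
            for (int i = 0; i < xs.length; i++) {
                j[i][k] = (f.eval(plus, xs[i]) - f.eval(minus, xs[i])) / (2 * step);
            }
        }
        return j;
    }

    /** Approximates the Hessian of 0.5 * (sum of squared errors) as J^T J. */
    static double[][] gaussNewtonHessian(double[][] j) {
        int n = j[0].length;
        double[][] h = new double[n][n];
        for (double[] row : j) {
            for (int a = 0; a < n; a++) {
                for (int b = 0; b < n; b++) {
                    h[a][b] += row[a] * row[b];
                }
            }
        }
        return h;
    }

    public static void main(String[] args) {
        Model line = (p, x) -> p[0] * x + p[1]; // y = m * x + c
        double[] xs = {0.0, 1.0, 2.0};
        double[][] h = gaussNewtonHessian(jacobian(line, new double[] {1.0, 0.0}, xs, 1e-5));
        System.out.println(Arrays.deepToString(h));
    }
}
```

For the linear model y = m·x + c on x = 0, 1, 2, each Jacobian row is [x, 1], so JᵀJ should be close to [[5, 3], [3, 3]]. A chain-rule (backpropagation-style) Jacobian would avoid the extra model evaluations that finite differences need.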

public int getThreadCount()
Returns the number of threads to use.
Specified by: getThreadCount in interface MultiThreadable

public void setThreadCount(int numThreads)
Set the number of threads to use.
Specified by: setThreadCount in interface MultiThreadable
Parameters:
numThreads - The number of threads to use, or zero to automatically determine the count based on the number of cores.
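The `numThreads` convention (zero means derive the count from the number of cores) can be illustrated with standard Java concurrency. This hypothetical sketch is not Encog's threading code: it resolves zero to `Runtime.getRuntime().availableProcessors()` and splits a sum of squared errors over that many tasks. That an LMA trainer parallelizes its per-sample error, gradient, and Hessian accumulation in a similar way is an assumption; the names `ThreadedSse`, `resolveThreads`, and `sumSquaredErrors` are illustrative.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;

/** Hypothetical sketch of a thread-count setting where 0 means "use the core count". */
public class ThreadedSse {
    static int resolveThreads(int numThreads) {
        if (numThreads < 0) throw new IllegalArgumentException("numThreads must be >= 0");
        return numThreads == 0 ? Runtime.getRuntime().availableProcessors() : numThreads;
    }

    /** Sums squared errors over contiguous chunks, one task per chunk. */
    static double sumSquaredErrors(double[] actual, double[] ideal, int numThreads) {
        // Never start more threads than there are samples (but at least one).
        int threads = Math.min(resolveThreads(numThreads), Math.max(1, actual.length));
        ExecutorService pool = Executors.newFixedThreadPool(threads);
        try {
            List<Future<Double>> parts = new ArrayList<>();
            int chunk = (actual.length + threads - 1) / threads; // ceiling division
            for (int start = 0; start < actual.length; start += chunk) {
                final int from = start;
                final int to = Math.min(start + chunk, actual.length);
                parts.add(pool.submit(() -> {
                    double s = 0.0;
                    for (int i = from; i < to; i++) {
                        double d = ideal[i] - actual[i];
                        s += d * d;
                    }
                    return s;
                }));
            }
            double total = 0.0;
            for (Future<Double> part : parts) total += part.get();
            return total;
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt(); // preserve the interrupt status
            throw new IllegalStateException(e);
        } catch (ExecutionException e) {
            throw new IllegalStateException(e.getCause());
        } finally {
            pool.shutdown();
        }
    }

    public static void main(String[] args) {
        double[] actual = {0.1, 0.9, 0.8, 0.2};
        double[] ideal = {0.0, 1.0, 1.0, 0.0};
        System.out.println(sumSquaredErrors(actual, ideal, 0));
    }
}
```

Splitting into contiguous chunks and summing the partial results keeps each task independent, so no locking is needed; the result matches a single-threaded sum up to floating-point rounding.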