org.encog.neural.networks.training.propagation.resilient
Class ResilientPropagation

java.lang.Object
  extended by org.encog.ml.train.BasicTraining
      extended by org.encog.neural.networks.training.propagation.Propagation
          extended by org.encog.neural.networks.training.propagation.resilient.ResilientPropagation
All Implemented Interfaces:
MLTrain, BatchSize, Train, MultiThreadable

public class ResilientPropagation
extends Propagation

One problem with the backpropagation algorithm is that the magnitude of the partial derivative is usually too large or too small. Further, the learning rate is a single value for the entire neural network. The resilient propagation learning algorithm uses a special update value (similar to the learning rate) for every neuron connection. Further, these update values are automatically determined, unlike the learning rate of the backpropagation algorithm. For most training situations, we suggest that the resilient propagation algorithm (this class) be used for training.

There are a total of three parameters that must be provided to the resilient training algorithm. Defaults are provided for each, and in nearly all cases these defaults are acceptable. This makes the resilient propagation algorithm one of the easiest and most efficient training algorithms available.

It is also important to note that RPROP does not work well with online training. You should always use a batch size bigger than one; typically, the larger the better. By default a batch size of zero is used, and zero means to include the entire training set in the batch.

The optional parameters are:

zeroTolerance - How close to zero a number can be and still be considered zero. The default is 0.00000000000000001.
initialUpdate - The initial update value for each matrix value. The default is 0.1.
maxStep - The largest amount that an update value can step. The default is 50.

Usually you will not need to change these, and you should use the constructor that does not require them.
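For orientation, a minimal usage sketch, assuming a standard Encog 3.x setup; the XOR data, the 2-3-1 network topology, and the 0.01 error target are illustrative choices, not requirements:

import org.encog.engine.network.activation.ActivationSigmoid;
import org.encog.ml.data.MLDataSet;
import org.encog.ml.data.basic.BasicMLDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.layers.BasicLayer;
import org.encog.neural.networks.training.propagation.resilient.ResilientPropagation;

public class RPROPExample {
    public static void main(String[] args) {
        // Illustrative XOR training data; any MLDataSet will do.
        double[][] input = { {0, 0}, {1, 0}, {0, 1}, {1, 1} };
        double[][] ideal = { {0}, {1}, {1}, {0} };
        MLDataSet trainingSet = new BasicMLDataSet(input, ideal);

        // A small feedforward network; BasicNetwork implements ContainsFlat.
        BasicNetwork network = new BasicNetwork();
        network.addLayer(new BasicLayer(null, true, 2));
        network.addLayer(new BasicLayer(new ActivationSigmoid(), true, 3));
        network.addLayer(new BasicLayer(new ActivationSigmoid(), false, 1));
        network.getStructure().finalizeStructure();
        network.reset();

        // Default parameters; the default batch size of zero trains on the full set each iteration.
        ResilientPropagation train = new ResilientPropagation(network, trainingSet);

        do {
            train.iteration();
        } while (train.getError() > 0.01);
        train.finishTraining();
    }
}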

Author:
jheaton

Field Summary
static String LAST_GRADIENTS
          Continuation tag for the last gradients.
static String UPDATE_VALUES
          Continuation tag for the update values.
 
Fields inherited from class org.encog.neural.networks.training.propagation.Propagation
gradients, lastError, network
 
Constructor Summary
ResilientPropagation(ContainsFlat network, MLDataSet training)
          Construct an RPROP trainer using the defaults for all training parameters.
ResilientPropagation(ContainsFlat network, MLDataSet training, double initialUpdate, double maxStep)
          Construct a resilient training object, allowing the training parameters to be specified.
 
Method Summary
 boolean canContinue()
          True, as RPROP can continue training from a saved state.
 RPROPType getRPROPType()
          The type of RPROP used.
 double[] getUpdateValues()
          The RPROP update values.
 void initOthers()
          Perform training-method-specific initialization.
 boolean isValidResume(TrainingContinuation state)
          Determine if the specified continuation object is valid to resume with.
 TrainingContinuation pause()
          Pause the training.
 void resume(TrainingContinuation state)
          Resume training.
 void setRPROPType(RPROPType t)
          Set the type of RPROP to use.
 double updateiWeightMinus(double[] gradients, double[] lastGradient, int index)
          Calculate the weight change using the iRPROP- update rule.
 double updateiWeightPlus(double[] gradients, double[] lastGradient, int index)
          Calculate the weight change using the iRPROP+ update rule.
 double updateWeight(double[] gradients, double[] lastGradient, int index)
          Calculate the amount to change the weight by.
 double updateWeightMinus(double[] gradients, double[] lastGradient, int index)
          Calculate the weight change using the RPROP- update rule.
 double updateWeightPlus(double[] gradients, double[] lastGradient, int index)
          Calculate the weight change using the RPROP+ update rule.
 
Methods inherited from class org.encog.neural.networks.training.propagation.Propagation
calculateGradients, finishTraining, fixFlatSpot, getBatchSize, getCurrentFlatNetwork, getLastGradient, getMethod, getThreadCount, iteration, iteration, learn, learnLimited, report, rollIteration, setBatchSize, setErrorFunction, setThreadCount
 
Methods inherited from class org.encog.ml.train.BasicTraining
addStrategy, getError, getImplementationType, getIteration, getStrategies, getTraining, isTrainingDone, postIteration, preIteration, setError, setIteration, setTraining
 
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 
Methods inherited from interface org.encog.ml.train.MLTrain
addStrategy, getError, getImplementationType, getIteration, getStrategies, getTraining, isTrainingDone, setError, setIteration
 

Field Detail

LAST_GRADIENTS

public static final String LAST_GRADIENTS
Continuation tag for the last gradients.

See Also:
Constant Field Values

UPDATE_VALUES

public static final String UPDATE_VALUES
Continuation tag for the update values.

See Also:
Constant Field Values
Constructor Detail

ResilientPropagation

public ResilientPropagation(ContainsFlat network,
                            MLDataSet training)
Construct an RPROP trainer using the defaults for all training parameters. Usually this is the constructor to use, as the resilient training algorithm is designed so that the default parameters are acceptable for nearly all problems.

Parameters:
network - The network to train.
training - The training data to use.

ResilientPropagation

public ResilientPropagation(ContainsFlat network,
                            MLDataSet training,
                            double initialUpdate,
                            double maxStep)
Construct a resilient training object, allowing the training parameters to be specified. Usually the default parameters are acceptable for the resilient training algorithm, so you should normally use the other constructor, which makes use of the default values. An illustrative call is shown after the parameter list.

Parameters:
network - The network to train.
training - The training set to use.
initialUpdate - The initial update value; the amount that all deltas are initially set to.
maxStep - The maximum that a delta can reach.
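
For example, to override the defaults (a sketch; the values 0.05 and 25.0 are arbitrary illustrations, and network and trainingSet are assumed to exist as in the class-level sketch above):

// Start with smaller update values and a tighter step cap than the defaults of 0.1 and 50.
ResilientPropagation train = new ResilientPropagation(network, trainingSet, 0.05, 25.0);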
Method Detail

canContinue

public boolean canContinue()
Returns:
True, as RPROP can continue.

isValidResume

public boolean isValidResume(TrainingContinuation state)
Determine if the specified continuation object is valid to resume with.

Parameters:
state - The continuation object to check.
Returns:
True if the specified continuation object is valid for this training method and network.

pause

public TrainingContinuation pause()
Pause the training.

Returns:
A training continuation object to continue with.

resume

public void resume(TrainingContinuation state)
Resume training.

Parameters:
state - The training state to return to.
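
A minimal pause/resume sketch, assuming an existing trainer named train over network and trainingSet as in the class-level example, and an import of org.encog.neural.networks.training.propagation.TrainingContinuation:

// Capture the RPROP state (the last gradients and update values) as a continuation object.
TrainingContinuation cont = train.pause();

// Later: build a new trainer over the same network and data, check the state, and resume.
ResilientPropagation resumed = new ResilientPropagation(network, trainingSet);
if (resumed.isValidResume(cont)) {
    resumed.resume(cont);
}
resumed.iteration();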

setRPROPType

public void setRPROPType(RPROPType t)
Set the type of RPROP to use. The default is RPROPp (RPROP+), the classic RPROP variant.

Parameters:
t - The type.
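
For example, assuming the iRPROPp constant of the RPROPType enum and a trainer named train:

train.setRPROPType(RPROPType.iRPROPp); // switch from the default RPROPp to iRPROP+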

getRPROPType

public RPROPType getRPROPType()
Returns:
The type of RPROP used.

initOthers

public void initOthers()
Perform training-method-specific initialization.

Specified by:
initOthers in class Propagation

updateWeight

public double updateWeight(double[] gradients,
                           double[] lastGradient,
                           int index)
Calculate the amount to change the weight by.

Specified by:
updateWeight in class Propagation
Parameters:
gradients - The gradients.
lastGradient - The last gradients.
index - The index to update.
Returns:
The amount to change the weight by.
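
For reference, a minimal sketch of the classic RPROP update rule that this family of methods is built around; updateValues, lastDelta, maxStep and zeroTolerance are illustrative local names, the 1.2/0.5 growth and shrink factors are the standard values from Riedmiller and Braun, and the sign conventions may differ from this class's actual code:

// A sketch of the classic RPROP+ rule (illustrative, not a transcript of this class).
static double rpropPlusUpdate(double[] gradients, double[] lastGradient,
                              double[] updateValues, double[] lastDelta,
                              int index, double maxStep, double zeroTolerance) {
    double sameSign = gradients[index] * lastGradient[index];
    double weightChange;
    if (sameSign > 0) {
        // Gradient kept its sign: grow the update value (capped at maxStep) and step along the gradient.
        updateValues[index] = Math.min(updateValues[index] * 1.2, maxStep);
        weightChange = Math.signum(gradients[index]) * updateValues[index];
    } else if (sameSign < 0) {
        // Sign flip: the last step overshot, so shrink the update value and revert the previous change.
        updateValues[index] = Math.max(updateValues[index] * 0.5, zeroTolerance);
        weightChange = -lastDelta[index];
        gradients[index] = 0; // so the stored last gradient is zero and the next call takes the zero branch
    } else {
        // One of the two gradients is zero: step with the current update value, unchanged.
        weightChange = Math.signum(gradients[index]) * updateValues[index];
    }
    lastDelta[index] = weightChange;
    return weightChange; // the amount to change the weight by
}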

updateWeightPlus

public double updateWeightPlus(double[] gradients,
                               double[] lastGradient,
                               int index)

updateWeightMinus

public double updateWeightMinus(double[] gradients,
                                double[] lastGradient,
                                int index)

updateiWeightPlus

public double updateiWeightPlus(double[] gradients,
                                double[] lastGradient,
                                int index)

updateiWeightMinus

public double updateiWeightMinus(double[] gradients,
                                 double[] lastGradient,
                                 int index)

getUpdateValues

public double[] getUpdateValues()
Returns:
The RPROP update values.

