org.encog.neural.networks.training.propagation
Class Propagation

java.lang.Object
  extended by org.encog.ml.train.BasicTraining
      extended by org.encog.neural.networks.training.propagation.Propagation
All Implemented Interfaces:
MLTrain, BatchSize, Train, MultiThreadable
Direct Known Subclasses:
Backpropagation, ManhattanPropagation, QuickPropagation, ResilientPropagation, ScaledConjugateGradient

public abstract class Propagation
extends BasicTraining
implements Train, MultiThreadable, BatchSize

Implements the basic functionality needed by each of the propagation methods. The specifics of each propagation method are implemented inside the PropagationMethod interface implementors.
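Judging from the abstract `updateWeight` and `initOthers` methods listed below, the base class follows a Template Method structure: it runs the shared steps (compute gradients, walk the weights), and each subclass supplies only the rule that turns a gradient into a weight change. The following is a minimal, self-contained sketch of that structure, not Encog's code. `ToyPropagation`, `ToyBackprop`, and the quadratic toy error that stands in for a network and training set are all hypothetical; the real class works on a `ContainsFlat` network and an `MLDataSet`.

```java
import java.util.Arrays;

// Hypothetical, simplified illustration of the Propagation template; not Encog source code.
class PropagationSketch {

    // Stand-in for Propagation: owns the weights and the gradient buffers.
    abstract static class ToyPropagation {
        protected final double[] weights;
        protected final double[] targets; // toy "training set": minimize sum of (w[i] - t[i])^2
        protected final double[] gradients;
        protected final double[] lastGradient;
        protected double lastError;

        ToyPropagation(double[] weights, double[] targets) {
            this.weights = weights;
            this.targets = targets;
            this.gradients = new double[weights.length];
            this.lastGradient = new double[weights.length];
        }

        // Shared step: compute dE/dw for every weight and record the error before updating.
        void calculateGradients() {
            double error = 0;
            for (int i = 0; i < weights.length; i++) {
                double diff = weights[i] - targets[i];
                gradients[i] = 2 * diff;
                error += diff * diff;
            }
            lastError = error;
        }

        // Shared step: the subclass decides how far each weight moves.
        void learn() {
            for (int i = 0; i < weights.length; i++) {
                weights[i] += updateWeight(gradients, lastGradient, i);
            }
        }

        void iteration() {
            calculateGradients();
            learn();
        }

        // The only method a concrete training method must supply.
        abstract double updateWeight(double[] gradients, double[] lastGradient, int index);
    }

    // Plain gradient descent as the simplest concrete subclass.
    static class ToyBackprop extends ToyPropagation {
        private final double learningRate;

        ToyBackprop(double[] weights, double[] targets, double learningRate) {
            super(weights, targets);
            this.learningRate = learningRate;
        }

        @Override
        double updateWeight(double[] gradients, double[] lastGradient, int index) {
            return -learningRate * gradients[index];
        }
    }

    public static void main(String[] args) {
        ToyBackprop train = new ToyBackprop(new double[] {0, 0}, new double[] {1, -2}, 0.25);
        for (int i = 0; i < 20; i++) {
            train.iteration();
        }
        System.out.println(Arrays.toString(train.weights) + " error=" + train.lastError);
    }
}
```

In this toy the gradients hold dE/dw, so the update rule negates them; the sign convention of the real gradient array should be checked before porting a rule.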

Author:
jheaton

Field Summary
protected  double[] gradients
          The gradients.
protected  double lastError
          The last error.
protected  ContainsFlat network
          The network to train.
 
Constructor Summary
Propagation(ContainsFlat network, MLDataSet training)
          Construct a propagation object.
 
Method Summary
 void calculateGradients()
          Calculate the gradients.
 void finishTraining()
          Should be called after training has completed and the iteration method will not be called any further.
 void fixFlatSpot(boolean b)
          Enable or disable the flat spot fix. Default is true.
 int getBatchSize()
          The batch size.
 FlatNetwork getCurrentFlatNetwork()
          Returns the current flat network.
 double[] getLastGradient()
          Returns the last gradient.
 MLMethod getMethod()
          Get the current best machine learning method from the training.
 int getThreadCount()
          Get the number of threads (0 to determine automatically based on core count).
abstract  void initOthers()
           
 void iteration()
          Perform one training iteration.
 void iteration(int count)
          Perform the specified number of training iterations.
protected  void learn()
          Apply and learn.
protected  void learnLimited()
          Apply and learn.
 void report(double[] gradients, double error, Throwable ex)
          Called by the worker threads to report the progress at each step.
 void rollIteration()
          Increase the iteration by one.
 void setBatchSize(int theBatchSize)
          Set the batch size.
 void setErrorFunction(ErrorFunction ef)
           
 void setThreadCount(int numThreads)
          Set the number of threads.
abstract  double updateWeight(double[] gradients, double[] lastGradient, int index)
          Update a weight. The means by which weights are updated vary depending on the training method.
 
Methods inherited from class org.encog.ml.train.BasicTraining
addStrategy, getError, getImplementationType, getIteration, getStrategies, getTraining, isTrainingDone, postIteration, preIteration, setError, setIteration, setTraining
 
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 
Methods inherited from interface org.encog.ml.train.MLTrain
addStrategy, canContinue, getError, getImplementationType, getIteration, getStrategies, getTraining, isTrainingDone, pause, resume, setError, setIteration
 

Field Detail

gradients

protected double[] gradients
The gradients.


network

protected final ContainsFlat network
The network to train.


lastError

protected double lastError
The last error.

Constructor Detail

Propagation

public Propagation(ContainsFlat network,
                   MLDataSet training)
Construct a propagation object.

Parameters:
network - The network.
training - The training set.
Method Detail

finishTraining

public void finishTraining()
Should be called after training has completed and the iteration method will not be called any further.

Specified by:
finishTraining in interface MLTrain
Overrides:
finishTraining in class BasicTraining
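Since `finishTraining()` marks the point after which `iteration()` will not be called again, a safe calling pattern is to invoke it exactly once in a `finally` block. The sketch below assumes, purely for illustration, a trainer that keeps a worker pool alive across iterations; `PooledTrainer` and its pool are hypothetical and do not describe what Encog's `finishTraining()` actually releases.

```java
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

// Hypothetical trainer that owns worker threads; not Encog's implementation.
class FinishTrainingSketch {

    static class PooledTrainer {
        private final ExecutorService workers = Executors.newFixedThreadPool(2);
        private int iteration;

        // One iteration hands its work to the pool and waits for it to finish.
        void iteration() {
            try {
                workers.submit(() -> { /* per-iteration gradient work would run here */ }).get();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                throw new IllegalStateException("interrupted", e);
            } catch (ExecutionException e) {
                throw new IllegalStateException("iteration failed", e.getCause());
            }
            iteration++;
        }

        int getIteration() {
            return iteration;
        }

        // Release the worker threads; iteration() must not be called afterwards.
        void finishTraining() {
            workers.shutdown();
            try {
                workers.awaitTermination(5, TimeUnit.SECONDS);
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
            }
        }

        boolean isFinished() {
            return workers.isTerminated();
        }
    }

    public static void main(String[] args) {
        PooledTrainer train = new PooledTrainer();
        try {
            for (int i = 0; i < 10; i++) {
                train.iteration();
            }
        } finally {
            train.finishTraining(); // runs even if an iteration throws
        }
        System.out.println(train.getIteration() + " iterations, finished=" + train.isFinished());
    }
}
```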

getCurrentFlatNetwork

public FlatNetwork getCurrentFlatNetwork()
Returns:
the currentFlatNetwork

getMethod

public MLMethod getMethod()
Get the current best machine learning method from the training.

Specified by:
getMethod in interface MLTrain
Returns:
The best machine learning method.

iteration

public void iteration()
Perform one training iteration.

Specified by:
iteration in interface MLTrain

rollIteration

public void rollIteration()
Increase the iteration by one.


iteration

public void iteration(int count)
Perform the specified number of training iterations. This can be more efficient than single training iterations. This is particularly true if you are training with a GPU.

Specified by:
iteration in interface MLTrain
Overrides:
iteration in class BasicTraining
Parameters:
count - The number of training iterations.
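One way to combine `iteration(int count)` with an error target is to train in chunks and check the error between chunks. The sketch below uses a hypothetical `Trainer` interface that only mirrors the `iteration()`, `iteration(int)`, and `getError()` names documented on this page, plus a toy trainer whose error halves on every iteration; neither is part of Encog.

```java
// Hypothetical sketch; the Trainer interface only mirrors the documented method names.
class IterationCountSketch {

    interface Trainer {
        void iteration();

        double getError();

        // Default: repeat single iterations. An implementation may override this
        // to run all `count` iterations in one call (for example, on a GPU).
        default void iteration(int count) {
            for (int i = 0; i < count; i++) {
                iteration();
            }
        }
    }

    // Toy trainer whose error halves on every iteration.
    static class HalvingTrainer implements Trainer {
        private double error = 1.0;
        int calls;

        @Override
        public void iteration() {
            error /= 2;
            calls++;
        }

        @Override
        public double getError() {
            return error;
        }
    }

    // Train in chunks until the error is at or below `target`, or `maxIterations` have run.
    // Returns the number of iterations performed.
    static int trainUntil(Trainer train, double target, int chunk, int maxIterations) {
        if (chunk < 1) {
            throw new IllegalArgumentException("chunk must be >= 1");
        }
        int done = 0;
        while (done < maxIterations && train.getError() > target) {
            int n = Math.min(chunk, maxIterations - done);
            train.iteration(n);
            done += n;
        }
        return done;
    }

    public static void main(String[] args) {
        HalvingTrainer train = new HalvingTrainer();
        int done = trainUntil(train, 1e-6, 10, 1000);
        System.out.println(done + " iterations, error=" + train.getError());
    }
}
```

Checking the error only between chunks can overshoot the target by up to `chunk - 1` iterations; that is the price of letting each `iteration(count)` call run uninterrupted.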

setThreadCount

public void setThreadCount(int numThreads)
Set the number of threads. Specify zero to tell Encog to automatically determine the best number of threads for the processor. If OpenCL is used as the target device, then this value is not used.

Specified by:
setThreadCount in interface MultiThreadable
Parameters:
numThreads - The number of threads.
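With `numThreads` set to zero the trainer has to pick a thread count itself and then divide the training set among the worker threads that later call `report`. The helpers below are a hypothetical sketch of that idea: mapping zero to `Runtime.getRuntime().availableProcessors()` and splitting rows into contiguous ranges are assumptions, not Encog's documented heuristic.

```java
// Hypothetical helpers; Encog's own workload heuristic may differ.
class ThreadCountSketch {

    // 0 means "pick automatically"; here that is assumed to be one thread per available core.
    static int resolveThreadCount(int requested, int availableProcessors) {
        if (requested < 0) {
            throw new IllegalArgumentException("thread count must be >= 0");
        }
        return requested == 0 ? Math.max(1, availableProcessors) : requested;
    }

    // Split `records` training rows into contiguous, inclusive [low, high] ranges, one per worker.
    static int[][] partition(int records, int threads) {
        int workers = Math.max(1, Math.min(threads, records));
        int base = records / workers;
        int extra = records % workers;
        int[][] ranges = new int[workers][2];
        int start = 0;
        for (int w = 0; w < workers; w++) {
            int size = base + (w < extra ? 1 : 0); // the first `extra` workers take one more row
            ranges[w][0] = start;
            ranges[w][1] = start + size - 1;
            start += size;
        }
        return ranges;
    }

    public static void main(String[] args) {
        int threads = resolveThreadCount(0, Runtime.getRuntime().availableProcessors());
        for (int[] r : partition(10, threads)) {
            System.out.println("rows " + r[0] + ".." + r[1]);
        }
    }
}
```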

getThreadCount

public int getThreadCount()
Specified by:
getThreadCount in interface MultiThreadable
Returns:
The number of threads to use, 0 to automatically determine based on core count.

fixFlatSpot

public void fixFlatSpot(boolean b)
Default is true. Call this with false to disable the flat spot fix. For more information on the flat spot, see: http://www.heatonresearch.com/wiki/Flat_Spot

Parameters:
b - True to fix flat spots, false otherwise.
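The flat spot is the region where a sigmoid neuron saturates: its derivative y(1 - y) approaches zero, so the gradient, and with it learning, nearly stops. A common remedy adds a small constant to the derivative. The sketch below illustrates that idea only; the constant 0.1 and the helper names are assumptions and are not taken from Encog's source.

```java
// Sketch of the flat-spot fix for a sigmoid; the 0.1 constant is an assumption.
class FlatSpotSketch {
    static final double FLAT_SPOT = 0.1;

    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    // Sigmoid derivative written in terms of the output y: y * (1 - y).
    // With the fix enabled, a small constant keeps saturated neurons learning.
    static double derivative(double x, boolean fixFlatSpot) {
        double y = sigmoid(x);
        double d = y * (1.0 - y);
        return fixFlatSpot ? d + FLAT_SPOT : d;
    }

    public static void main(String[] args) {
        for (double x : new double[] {0, 5, 10}) {
            System.out.printf("x=%4.1f  plain=%.6f  fixed=%.6f%n",
                    x, derivative(x, false), derivative(x, true));
        }
    }
}
```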

setErrorFunction

public void setErrorFunction(ErrorFunction ef)

calculateGradients

public void calculateGradients()
Calculate the gradients.


report

public void report(double[] gradients,
                   double error,
                   Throwable ex)
Called by the worker threads to report the progress at each step.

Parameters:
gradients - The gradients from that worker.
error - The error for that worker.
ex - The exception.
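Because several worker threads report into one trainer, the reporting method has to combine their partial results safely. The sketch below is a hypothetical aggregator in the spirit of `report(gradients, error, ex)`; treating a non-null `ex` as a failure to rethrow later and a null `ex` as success are assumptions, as are the `collect` and `meanError` helpers.

```java
import java.util.Arrays;

// Hypothetical aggregator in the spirit of report(gradients, error, ex); not Encog's code.
class ReportSketch {

    static class GradientAggregator {
        private final double[] totalGradients;
        private double totalError;
        private int reports;
        private Throwable failure;

        GradientAggregator(int weightCount) {
            totalGradients = new double[weightCount];
        }

        // Called by each worker; synchronized because workers report concurrently.
        synchronized void report(double[] gradients, double error, Throwable ex) {
            if (ex != null) {
                if (failure == null) {
                    failure = ex; // keep the first failure only
                }
                return;
            }
            for (int i = 0; i < gradients.length; i++) {
                totalGradients[i] += gradients[i];
            }
            totalError += error;
            reports++;
        }

        // After all workers are done: rethrow the first failure, else return the summed gradients.
        synchronized double[] collect() {
            if (failure != null) {
                throw new IllegalStateException("worker failed", failure);
            }
            return totalGradients.clone();
        }

        // Mean of the per-worker errors (exact only if every worker saw the same number of rows).
        synchronized double meanError() {
            return reports == 0 ? 0 : totalError / reports;
        }
    }

    public static void main(String[] args) throws InterruptedException {
        GradientAggregator agg = new GradientAggregator(2);
        Thread a = new Thread(() -> agg.report(new double[] {1, 2}, 0.5, null));
        Thread b = new Thread(() -> agg.report(new double[] {3, 4}, 0.25, null));
        a.start();
        b.start();
        a.join();
        b.join();
        System.out.println(Arrays.toString(agg.collect()) + " meanError=" + agg.meanError());
    }
}
```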

learn

protected void learn()
Apply and learn.


learnLimited

protected void learnLimited()
Apply and learn. This is the same as learn, but it checks whether any of the weights are below the limit threshold; any such weights are zeroed out. Having two methods allows the regular learn method, which is the one usually used, to be as fast as possible.
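The difference between the two methods can be sketched as two loops over the weight array, one with and one without the threshold check. This is a hypothetical illustration, not Encog source: it assumes the threshold compares the weight's absolute value and that a zeroed weight skips its update in that pass, and the source states neither detail.

```java
import java.util.Arrays;
import java.util.function.IntToDoubleFunction;

// Hypothetical contrast between learn() and learnLimited(); not Encog source code.
class LearnLimitedSketch {

    // learn(): apply every update with no extra checks (the fast path).
    static void learn(double[] weights, IntToDoubleFunction update) {
        for (int i = 0; i < weights.length; i++) {
            weights[i] += update.applyAsDouble(i);
        }
    }

    // learnLimited(): weights whose magnitude is below `limit` are zeroed instead of updated.
    static void learnLimited(double[] weights, IntToDoubleFunction update, double limit) {
        for (int i = 0; i < weights.length; i++) {
            if (Math.abs(weights[i]) < limit) {
                weights[i] = 0;
            } else {
                weights[i] += update.applyAsDouble(i);
            }
        }
    }

    public static void main(String[] args) {
        double[] a = {0.05, -0.5, 0.002, 1.0};
        double[] b = a.clone();
        learn(a, i -> 0.1);
        learnLimited(b, i -> 0.1, 0.01);
        System.out.println("learn:        " + Arrays.toString(a));
        System.out.println("learnLimited: " + Arrays.toString(b));
    }
}
```

Keeping the check in a separate method means the common `learn` path pays nothing for it, which matches the trade-off described above.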


initOthers

public abstract void initOthers()

updateWeight

public abstract double updateWeight(double[] gradients,
                                    double[] lastGradient,
                                    int index)
Update a weight. The means by which weights are updated vary depending on the training method.

Parameters:
gradients - The gradients.
lastGradient - The last gradients.
index - The index.
Returns:
The update value.
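To show how different training methods can fill in the same `updateWeight(gradients, lastGradient, index)` signature, the sketch below gives three rules loosely modeled on the ideas behind Backpropagation (gradient descent with momentum), ManhattanPropagation (fixed step, sign only), and ResilientPropagation (a simplified iRPROP- variant). These are illustrations, not the subclasses' actual implementations; the constants, the class names, and the convention that `gradients` holds dE/dw and the return value is added to the weight are all assumptions.

```java
import java.util.Arrays;

// Hypothetical sketches of three update rules. gradients[] holds dE/dw here,
// and the returned value is added to the weight.
class UpdateRuleSketch {

    // Gradient descent with momentum.
    static class Momentum {
        private final double rate;
        private final double momentum;
        private final double[] lastDelta;

        Momentum(int weightCount, double rate, double momentum) {
            this.rate = rate;
            this.momentum = momentum;
            this.lastDelta = new double[weightCount];
        }

        double updateWeight(double[] gradients, double[] lastGradient, int index) {
            double delta = -rate * gradients[index] + momentum * lastDelta[index];
            lastDelta[index] = delta;
            return delta;
        }
    }

    // Fixed-size step against the sign of the gradient; the magnitude is ignored.
    static class Manhattan {
        private final double step;

        Manhattan(double step) {
            this.step = step;
        }

        double updateWeight(double[] gradients, double[] lastGradient, int index) {
            return -Math.signum(gradients[index]) * step;
        }
    }

    // Simplified iRPROP-: per-weight steps grow while the gradient keeps its sign, shrink when it flips.
    static class SimpleRprop {
        private static final double ETA_PLUS = 1.2;
        private static final double ETA_MINUS = 0.5;
        private static final double MAX_STEP = 50;
        private static final double MIN_STEP = 1e-6;
        private final double[] step;

        SimpleRprop(int weightCount, double initialStep) {
            step = new double[weightCount];
            Arrays.fill(step, initialStep);
        }

        double updateWeight(double[] gradients, double[] lastGradient, int index) {
            double change = gradients[index] * lastGradient[index];
            if (change > 0) {
                step[index] = Math.min(step[index] * ETA_PLUS, MAX_STEP);
            } else if (change < 0) {
                step[index] = Math.max(step[index] * ETA_MINUS, MIN_STEP);
                lastGradient[index] = 0; // sign flipped: skip this step and forget the old sign
                return 0;
            }
            lastGradient[index] = gradients[index];
            return -Math.signum(gradients[index]) * step[index];
        }
    }

    public static void main(String[] args) {
        double[] last = new double[1];
        SimpleRprop rprop = new SimpleRprop(1, 0.1);
        for (double g : new double[] {1, 1, -1, -1}) {
            System.out.println("g=" + g + " delta=" + rprop.updateWeight(new double[] {g}, last, 0));
        }
    }
}
```

Only the RPROP-style rule reads and writes `lastGradient`, which suggests why the signature carries it even though simpler rules ignore it.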

getLastGradient

public double[] getLastGradient()
Returns:
the lastGradient

getBatchSize

public int getBatchSize()
The batch size. Specify 1 for pure online training. Specify 0 for pure batch training (complete training set in one batch). Otherwise specify the batch size for batch training.

Specified by:
getBatchSize in interface BatchSize
Returns:
The batch size.

setBatchSize

public void setBatchSize(int theBatchSize)
Set the batch size. Specify 1 for pure online training. Specify 0 for pure batch training (complete training set in one batch). Otherwise specify the batch size for batch training.

Specified by:
setBatchSize in interface BatchSize
Parameters:
theBatchSize - The batch size.
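The three batch-size cases described above can be made concrete by listing the row ranges that one epoch would process, and therefore how many times the weights would be updated. `batches` is a hypothetical helper; it assumes rows are taken in order and that the weights are updated once per batch, which is the usual meaning of batch training but is not spelled out here.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical helper showing what the batch-size values mean for one epoch.
class BatchSizeSketch {

    // Returns [start, endExclusive) row ranges; weights would be updated once per range.
    // batchSize 1 = online, 0 = whole training set in one batch, n = batches of n rows.
    static List<int[]> batches(int records, int batchSize) {
        if (batchSize < 0) {
            throw new IllegalArgumentException("batch size must be >= 0");
        }
        int size = (batchSize == 0) ? records : batchSize;
        List<int[]> out = new ArrayList<>();
        for (int start = 0; start < records; start += size) {
            out.add(new int[] {start, Math.min(start + size, records)}); // last batch may be short
        }
        return out;
    }

    public static void main(String[] args) {
        for (int batchSize : new int[] {1, 0, 4}) {
            System.out.println("batchSize=" + batchSize + " -> "
                    + batches(10, batchSize).size() + " weight updates per epoch");
        }
    }
}
```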


Copyright © 2014. All Rights Reserved.