java.lang.Object
  org.encog.ml.train.BasicTraining
    org.encog.neural.networks.training.propagation.Propagation
public abstract class Propagation
Implements the basic functionality that every propagation method needs. The specifics of each propagation method are implemented by the implementors of the PropagationMethod interface.
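The description above amounts to a template-method design: the base class owns the training loop, and each propagation variant supplies only the per-weight rule through `updateWeight`. The sketch below illustrates that split in plain Java. `TrainerSketch` and `GradientDescentSketch` are hypothetical names, not Encog classes, and the sketch assumes the gradient array holds the derivative of the error with respect to each weight (Encog's own sign convention may differ).

```java
// Hypothetical base class: owns the loop, delegates the per-weight rule.
abstract class TrainerSketch {
    protected final double[] weights;
    protected final double[] gradients;     // assumed to hold dE/dw for each weight
    protected final double[] lastGradient;  // gradients from the previous step

    TrainerSketch(double[] weights) {
        this.weights = weights;
        this.gradients = new double[weights.length];
        this.lastGradient = new double[weights.length];
    }

    // Each subclass decides how much a single weight should change.
    abstract double updateWeight(double[] gradients, double[] lastGradient, int index);

    // Apply the subclass rule to every weight, then remember the gradients.
    void learn() {
        for (int i = 0; i < weights.length; i++) {
            weights[i] += updateWeight(gradients, lastGradient, i);
            lastGradient[i] = gradients[i];
        }
    }
}

// Hypothetical subclass: plain gradient descent with a fixed learning rate.
class GradientDescentSketch extends TrainerSketch {
    private final double learningRate;

    GradientDescentSketch(double[] weights, double learningRate) {
        super(weights);
        this.learningRate = learningRate;
    }

    @Override
    double updateWeight(double[] gradients, double[] lastGradient, int index) {
        // Step against the error gradient.
        return -learningRate * gradients[index];
    }
}
```

Replacing `GradientDescentSketch` with another subclass changes the learning rule without touching `learn()`.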
Field Summary | |
---|---|
protected double[] | gradients: The gradients. |
protected double | lastError: The last error. |
protected ContainsFlat | network: The network to train. |
Constructor Summary | |
---|---|
Propagation(ContainsFlat network, MLDataSet training) | Construct a propagation object. |
Method Summary | |
---|---|
void | calculateGradients(): Calculate the gradients. |
void | finishTraining(): Should be called after training has completed and the iteration method will not be called any further. |
void | fixFlatSpot(boolean b): Enable (true) or disable (false) fixing of flat spots. Default is true. |
int | getBatchSize(): Get the batch size. |
FlatNetwork | getCurrentFlatNetwork() |
double[] | getLastGradient() |
MLMethod | getMethod(): Get the current best machine learning method from the training. |
int | getThreadCount() |
abstract void | initOthers() |
void | iteration(): Perform one training iteration. |
void | iteration(int count): Perform the specified number of training iterations. |
protected void | learn(): Apply and learn. |
protected void | learnLimited(): Apply and learn. |
void | report(double[] gradients, double error, Throwable ex): Called by the worker threads to report the progress at each step. |
void | rollIteration(): Increase the iteration by one. |
void | setBatchSize(int theBatchSize): Set the batch size. |
void | setErrorFunction(ErrorFunction ef) |
void | setThreadCount(int numThreads): Set the number of threads. |
abstract double | updateWeight(double[] gradients, double[] lastGradient, int index): Update a weight; how weights are updated varies with the training method. |
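As a rough picture of how `iteration()`, `calculateGradients()`, `lastError`, and `rollIteration()` relate, the sketch below trains a hypothetical one-weight model y = w * x with full-batch gradient descent on mean squared error. `IterationSketch` and all of its members are illustrative assumptions; the real class computes gradients over a `ContainsFlat` network, possibly across several threads.

```java
// Hypothetical one-weight model y = w * x, trained by full-batch gradient descent.
class IterationSketch {
    double weight;
    double lastError;              // mean squared error measured during the last iteration
    int iterationNumber;           // incremented once per iteration, like rollIteration()
    private final double[][] data; // each row is {input, ideal}
    private final double learningRate;

    IterationSketch(double[][] data, double initialWeight, double learningRate) {
        this.data = data;
        this.weight = initialWeight;
        this.learningRate = learningRate;
    }

    // Average dE/dw over the training set and record the error along the way.
    double calculateGradient() {
        double gradient = 0.0;
        double sumSquared = 0.0;
        for (double[] row : data) {
            double diff = weight * row[0] - row[1];
            gradient += 2.0 * diff * row[0];
            sumSquared += diff * diff;
        }
        lastError = sumSquared / data.length;
        return gradient / data.length;
    }

    // One iteration: compute the gradient, apply it, advance the counter.
    void iteration() {
        weight -= learningRate * calculateGradient();
        iterationNumber++;
    }

    // Several iterations in a row.
    void iteration(int count) {
        for (int i = 0; i < count; i++) {
            iteration();
        }
    }
}
```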
Methods inherited from class org.encog.ml.train.BasicTraining |
---|
addStrategy, getError, getImplementationType, getIteration, getStrategies, getTraining, isTrainingDone, postIteration, preIteration, setError, setIteration, setTraining |
Methods inherited from class java.lang.Object |
---|
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait |
Methods inherited from interface org.encog.ml.train.MLTrain |
---|
addStrategy, canContinue, getError, getImplementationType, getIteration, getStrategies, getTraining, isTrainingDone, pause, resume, setError, setIteration |
Field Detail |
---|
protected double[] gradients
protected final ContainsFlat network
protected double lastError
Constructor Detail |
---|
public Propagation(ContainsFlat network, MLDataSet training)
Construct a propagation object.
Parameters:
network - The network.
training - The training set.
Method Detail |
---|
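public void finishTraining()
Should be called after training has completed and the iteration method will not be called any further.
Specified by: finishTraining in interface MLTrain
Overrides: finishTraining in class BasicTraining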
public FlatNetwork getCurrentFlatNetwork()
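public MLMethod getMethod()
Get the current best machine learning method from the training.
Specified by: getMethod in interface MLTrain
public void iteration()
Perform one training iteration.
Specified by: iteration in interface MLTrain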
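public void rollIteration()
Increase the iteration by one.
public void iteration(int count)
Perform the specified number of training iterations.
Specified by: iteration in interface MLTrain
Overrides: iteration in class BasicTraining
Parameters:
count - The number of training iterations.
public void setThreadCount(int numThreads)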
Set the number of threads.
Specified by: setThreadCount in interface MultiThreadable
Parameters:
numThreads - The number of threads.
public int getThreadCount()
Specified by: getThreadCount in interface MultiThreadable
public void fixFlatSpot(boolean b)
Default is true.
Parameters:
b - True to fix flat spots, false otherwise.
public void setErrorFunction(ErrorFunction ef)
public void calculateGradients()
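The `fixFlatSpot(boolean)` switch documented above concerns the flat spot of sigmoid-like activations: the derivative out * (1 - out) approaches zero as the output saturates, so the gradients computed for that unit can vanish. A common remedy, usually credited to Fahlman, adds a small constant to the derivative. The sketch below assumes an offset of 0.1 and a single sigmoid output unit; `FlatSpotSketch` is a hypothetical name, and whether Encog uses this exact offset is an assumption.

```java
// Hypothetical illustration of the flat-spot fix for a sigmoid output unit.
class FlatSpotSketch {
    // Offset added to the derivative; 0.1 is an assumed, commonly cited value.
    static final double FLAT_SPOT_OFFSET = 0.1;

    static double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    // Error signal (ideal - actual) scaled by the activation derivative.
    static double outputDelta(double netInput, double ideal, boolean fixFlatSpot) {
        double out = sigmoid(netInput);
        double derivative = out * (1.0 - out); // near zero when the unit saturates
        if (fixFlatSpot) {
            derivative += FLAT_SPOT_OFFSET;
        }
        return (ideal - out) * derivative;
    }
}
```

With a saturated unit (net input 10, ideal 0) the uncorrected delta is on the order of 1e-5, while the corrected one is roughly -0.1, so learning can still move the weights.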
public void report(double[] gradients, double error, Throwable ex)
Called by the worker threads to report the progress at each step.
Parameters:
gradients - The gradients from that worker.
error - The error for that worker.
ex - The exception.
protected void learn()
Apply and learn.
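The `report(...)` callback lets worker threads hand back their partial gradients, their error, and any exception, after which the accumulated gradients can be applied. A minimal thread-safe aggregator might look like the sketch below. `GradientAggregatorSketch` is hypothetical; averaging the reported errors is a simplification, since a real implementation might weight each worker by the number of training items it processed.

```java
// Hypothetical aggregator: workers call report(); the trainer reads the totals.
class GradientAggregatorSketch {
    private final double[] totals;
    private double errorSum;
    private int reports;
    private Throwable failure;

    GradientAggregatorSketch(int weightCount) {
        totals = new double[weightCount];
    }

    // Synchronized because several worker threads may report at the same time.
    synchronized void report(double[] gradients, double error, Throwable ex) {
        if (ex != null) {
            if (failure == null) {
                failure = ex; // keep the first failure; ignore this worker's gradients
            }
            return;
        }
        for (int i = 0; i < totals.length; i++) {
            totals[i] += gradients[i];
        }
        errorSum += error;
        reports++;
    }

    synchronized double[] totalGradients() {
        return totals.clone();
    }

    synchronized double meanError() {
        return reports == 0 ? 0.0 : errorSum / reports;
    }

    synchronized Throwable failure() {
        return failure;
    }
}
```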
protected void learnLimited()
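The summary gives `learn()` and `learnLimited()` the same one-line description. One plausible reading, and it is only an assumption here, is that the limited variant enforces a connection limit: weights whose magnitude is below a threshold are pruned to zero instead of updated. `LimitedLearnSketch` and its `limit` parameter are hypothetical.

```java
// Hypothetical contrast between an unrestricted and a "limited" weight update.
class LimitedLearnSketch {
    // Add every precomputed delta to its weight.
    static void learn(double[] weights, double[] deltas) {
        for (int i = 0; i < weights.length; i++) {
            weights[i] += deltas[i];
        }
    }

    // Same, except weights below the assumed connection limit are pinned to zero.
    static void learnLimited(double[] weights, double[] deltas, double limit) {
        for (int i = 0; i < weights.length; i++) {
            if (Math.abs(weights[i]) < limit) {
                weights[i] = 0.0;
            } else {
                weights[i] += deltas[i];
            }
        }
    }
}
```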
public abstract void initOthers()
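public abstract double updateWeight(double[] gradients, double[] lastGradient, int index)
Update a weight; how weights are updated varies with the training method.
Parameters:
gradients - The gradients.
lastGradient - The last gradients.
index - The index of the weight to update.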
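Passing `lastGradient` to `updateWeight` allows rules that compare the current gradient with the previous one. As an illustration only, the sketch below implements a simplified sign-based rule in the spirit of iRPROP-: the step for each weight grows while the gradient keeps its sign and shrinks when it flips. The class name, the constants (1.2, 0.5, and the step bounds), and the convention that `gradients` holds dE/dw are assumptions, not Encog's code.

```java
// Hypothetical sign-based (iRPROP- style) per-weight update rule.
class SignBasedUpdateSketch {
    static final double ETA_PLUS = 1.2;
    static final double ETA_MINUS = 0.5;
    static final double DELTA_MIN = 1e-6;
    static final double DELTA_MAX = 50.0;
    private final double[] updateValues; // current step size for each weight

    SignBasedUpdateSketch(int weightCount, double initialUpdate) {
        updateValues = new double[weightCount];
        java.util.Arrays.fill(updateValues, initialUpdate);
    }

    // Returns the change to add to weight[index]; gradients hold dE/dw.
    double updateWeight(double[] gradients, double[] lastGradient, int index) {
        double change = Math.signum(gradients[index] * lastGradient[index]);
        if (change > 0) {
            // Same sign as last time: accelerate.
            updateValues[index] = Math.min(updateValues[index] * ETA_PLUS, DELTA_MAX);
        } else if (change < 0) {
            // Sign flipped: we overshot, so shrink the step and skip this update.
            updateValues[index] = Math.max(updateValues[index] * ETA_MINUS, DELTA_MIN);
            gradients[index] = 0.0;
            return 0.0;
        }
        // Move against the gradient by the current step size.
        return -Math.signum(gradients[index]) * updateValues[index];
    }
}
```

The caller is expected to copy `gradients` into `lastGradient` after each call; the zeroed entry after a sign flip then prevents a second shrink on the following iteration.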
public double[] getLastGradient()
public int getBatchSize()
Get the batch size.
Specified by: getBatchSize in interface BatchSize
public void setBatchSize(int theBatchSize)
Set the batch size.
Specified by: setBatchSize in interface BatchSize
Parameters:
theBatchSize - The batch size.
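The batch size controls how many training items contribute to the gradients before the weights are updated. The sketch below only plans the batches as [start, end) index ranges, with one weight update assumed after each range. `BatchPlanSketch` is hypothetical, and treating a batch size of 0 or less as one full batch is a convention of this sketch, not a documented property of `BatchSize`.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical planner: split a training set into [start, end) batches.
class BatchPlanSketch {
    // Sketch convention: batchSize <= 0 means a single batch covering everything.
    static List<int[]> batches(int trainingSize, int batchSize) {
        List<int[]> result = new ArrayList<>();
        int step = (batchSize <= 0) ? trainingSize : batchSize;
        for (int start = 0; start < trainingSize; start += step) {
            result.add(new int[] {start, Math.min(start + step, trainingSize)});
        }
        return result;
    }
}
```

A batch size of 1 yields one update per training item (online training), while larger values trade update frequency for smoother gradient estimates.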