org.encog.neural.networks.training.propagation
public abstract class Propagation extends BasicTraining implements Train, MultiThreadable, BatchSize
Modifier and Type | Field and Description |
---|---|
protected double[] |
gradients
The gradients.
|
protected ContainsFlat |
network
The network to train.
|
Constructor and Description |
---|
Propagation(ContainsFlat network,
MLDataSet training)
Construct a propagation object.
|
Modifier and Type | Method and Description |
---|---|
void |
calculateGradients()
Calculate the gradients.
|
void |
finishTraining()
Should be called after training has completed and the iteration method
will not be called any further.
|
void |
fixFlatSpot(boolean b)
Set whether flat spots should be fixed. Default is true.
|
int |
getBatchSize()
The batch size.
|
FlatNetwork |
getCurrentFlatNetwork() |
double[] |
getLastGradient() |
MLMethod |
getMethod()
Get the current best machine learning method from the training.
|
int |
getThreadCount() |
abstract void |
initOthers() |
void |
iteration()
Perform one training iteration.
|
void |
iteration(int count)
Perform the specified number of training iterations.
|
protected void |
learn()
Apply and learn.
|
protected void |
learnLimited()
Apply and learn.
|
void |
report(double[] gradients,
double error,
Throwable ex)
Called by the worker threads to report the progress at each step.
|
void |
rollIteration()
Increase the iteration by one.
|
void |
setBatchSize(int theBatchSize)
Set the batch size.
|
void |
setErrorFunction(ErrorFunction ef) |
void |
setThreadCount(int numThreads)
Set the number of threads.
|
abstract double |
updateWeight(double[] gradients,
double[] lastGradient,
int index)
Update a weight. The means by which weights are updated varies depending on
the training.
|
addStrategy, getError, getImplementationType, getIteration, getStrategies, getTraining, isTrainingDone, postIteration, preIteration, setError, setIteration, setTraining
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
addStrategy, canContinue, getError, getImplementationType, getIteration, getStrategies, getTraining, isTrainingDone, pause, resume, setError, setIteration
protected double[] gradients
protected final ContainsFlat network
public Propagation(ContainsFlat network, MLDataSet training)
network
- The network.
training
- The training set.
public void finishTraining()
finishTraining
in interface MLTrain
finishTraining
in class BasicTraining
public FlatNetwork getCurrentFlatNetwork()
public MLMethod getMethod()
public void iteration()
public void rollIteration()
public void iteration(int count)
iteration
in interface MLTrain
iteration
in class BasicTraining
count
- The number of training iterations.
public void setThreadCount(int numThreads)
setThreadCount
in interface MultiThreadable
numThreads
- The number of threads.
public int getThreadCount()
getThreadCount
in interface MultiThreadable
public void fixFlatSpot(boolean b)
b
- True to fix flat spots, false otherwise.
public void setErrorFunction(ErrorFunction ef)
public void calculateGradients()
public void report(double[] gradients, double error, Throwable ex)
gradients
- The gradients from that worker.
error
- The error for that worker.
ex
- The exception.
protected void learn()
protected void learnLimited()
public abstract void initOthers()
public abstract double updateWeight(double[] gradients, double[] lastGradient, int index)
gradients
- The gradients.
lastGradient
- The last gradients.
index
- The index.
public double[] getLastGradient()
public int getBatchSize()
getBatchSize
in interface BatchSize
public void setBatchSize(int theBatchSize)
setBatchSize
in interface BatchSize
theBatchSize
- The batch size.
Copyright © 2014. All Rights Reserved.