org.encog.neural.networks.training.propagation.quick
public class QuickPropagation extends Propagation implements LearningRate
Modifier and Type | Field and Description |
---|---|
static String | LAST_GRADIENTS - Continuation tag for the last gradients. |
Fields inherited from class Propagation: gradients, network
Constructor and Description |
---|
QuickPropagation(ContainsFlat network, MLDataSet training) - Construct a QPROP trainer for flat networks. |
QuickPropagation(ContainsFlat network, MLDataSet training, double theLearningRate) - Construct a QPROP trainer for flat networks. |
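A minimal usage sketch built around these constructors. It assumes an Encog 3.x BasicNetwork (which implements ContainsFlat) and an MLDataSet that have already been prepared; the error threshold and epoch limit are illustrative only, not values prescribed by this page.

```java
import org.encog.ml.data.MLDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.training.propagation.quick.QuickPropagation;

public class QuickPropExample {

    public static void trainNetwork(BasicNetwork network, MLDataSet training) {
        // Start with the suggested learning rate of 2.0 and drop it
        // if training fails to converge.
        QuickPropagation train = new QuickPropagation(network, training, 2.0);

        int epoch = 1;
        do {
            train.iteration(); // one QPROP training iteration
            System.out.println("Epoch #" + epoch + " Error: " + train.getError());
            epoch++;
        } while (train.getError() > 0.01 && epoch <= 1000);

        train.finishTraining();
    }
}
```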
Modifier and Type | Method and Description |
---|---|
boolean | canContinue() |
double[] | getLastDelta() |
double | getLearningRate() |
double | getOutputEpsilon() |
double | getShrink() |
void | initOthers() - Perform training-method-specific initialization. |
boolean | isValidResume(TrainingContinuation state) - Determine if the specified continuation object is valid to resume with. |
TrainingContinuation | pause() - Pause the training. |
void | resume(TrainingContinuation state) - Resume training. |
void | setBatchSize(int theBatchSize) - Batch sizes other than 0 are not supported and are not allowed. |
void | setLearningRate(double rate) - Set the learning rate; this value is essentially a percent. |
void | setOutputEpsilon(double theOutputEpsilon) |
void | setShrink(double s) |
double | updateWeight(double[] gradients, double[] lastGradient, int index) - Update a weight. |
Methods inherited from class Propagation: calculateGradients, finishTraining, fixFlatSpot, getBatchSize, getCurrentFlatNetwork, getLastGradient, getMethod, getThreadCount, iteration, iteration, learn, learnLimited, report, rollIteration, setErrorFunction, setThreadCount
Methods inherited from class BasicTraining: addStrategy, getError, getImplementationType, getIteration, getStrategies, getTraining, isTrainingDone, postIteration, preIteration, setError, setIteration, setTraining
Methods inherited from class java.lang.Object: clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
Methods inherited from interface MLTrain: addStrategy, getError, getImplementationType, getIteration, getStrategies, getTraining, isTrainingDone, setError, setIteration
public static final String LAST_GRADIENTS
public QuickPropagation(ContainsFlat network, MLDataSet training)
Parameters:
network - The network to train.
training - The training data.

public QuickPropagation(ContainsFlat network, MLDataSet training, double theLearningRate)
Parameters:
network - The network to train.
training - The training data.
theLearningRate - The learning rate. 2 is a good suggestion as a learning rate to start with. If it fails to converge, then drop it. Just like backprop, except QPROP can take higher learning rates.

public boolean canContinue()
Specified by:
canContinue in interface MLTrain
public double[] getLastDelta()
public double getLearningRate()
Specified by:
getLearningRate in interface LearningRate
public boolean isValidResume(TrainingContinuation state)
Parameters:
state - The continuation object to check.

public TrainingContinuation pause()
public void resume(TrainingContinuation state)
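The continuation methods (canContinue(), pause(), isValidResume(TrainingContinuation) and resume(TrainingContinuation)) can be combined for checkpoint-style training, roughly as sketched below; the TrainingContinuation import path and the surrounding setup are assumptions based on Encog 3.x rather than anything stated on this page.

```java
import org.encog.ml.data.MLDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.training.propagation.TrainingContinuation;
import org.encog.neural.networks.training.propagation.quick.QuickPropagation;

public class QuickPropPauseResume {

    public static void pauseAndResume(BasicNetwork network, MLDataSet training) {
        QuickPropagation first = new QuickPropagation(network, training);
        first.iteration();

        // Capture the trainer state (the last gradients) if supported.
        TrainingContinuation state = first.canContinue() ? first.pause() : null;

        // Later, a fresh trainer can pick up where the first one stopped.
        QuickPropagation second = new QuickPropagation(network, training);
        if (state != null && second.isValidResume(state)) {
            second.resume(state);
        }
        second.iteration();
    }
}
```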
public void setLearningRate(double rate)
Specified by:
setLearningRate in interface LearningRate
Parameters:
rate - The learning rate.

public double getOutputEpsilon()
public double getShrink()
public void setShrink(double s)
Parameters:
s - the shrink to set

public void setOutputEpsilon(double theOutputEpsilon)
Parameters:
theOutputEpsilon - the outputEpsilon to set

public void initOthers()
Specified by:
initOthers in class Propagation
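Because setLearningRate(double), setOutputEpsilon(double) and setShrink(double) are plain setters, the QPROP-specific parameters can also be adjusted after construction. The values in the following sketch are purely illustrative; they are not defaults or recommendations taken from this page.

```java
import org.encog.ml.data.MLDataSet;
import org.encog.neural.networks.BasicNetwork;
import org.encog.neural.networks.training.propagation.quick.QuickPropagation;

public class QuickPropTuning {

    public static QuickPropagation configure(BasicNetwork network, MLDataSet training) {
        QuickPropagation train = new QuickPropagation(network, training);

        // Illustrative values only; tune them for the problem at hand.
        train.setLearningRate(2.0);
        train.setOutputEpsilon(1.0);
        train.setShrink(0.1);

        return train;
    }
}
```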
public double updateWeight(double[] gradients, double[] lastGradient, int index)
Specified by:
updateWeight in class Propagation
Parameters:
gradients - The gradients.
lastGradient - The last gradients.
index - The index.

public void setBatchSize(int theBatchSize)
Specified by:
setBatchSize in interface BatchSize
Overrides:
setBatchSize in class Propagation
Parameters:
theBatchSize - The batch size.

Copyright © 2014. All Rights Reserved.