org.encog.ml.hmm.train.bw
Class BaseBaumWelch

java.lang.Object
  extended by org.encog.ml.hmm.train.bw.BaseBaumWelch
All Implemented Interfaces:
MLTrain
Direct Known Subclasses:
TrainBaumWelch, TrainBaumWelchScaled

public abstract class BaseBaumWelch
extends Object
implements MLTrain

This class provides the base implementation for Baum-Welch learning for HMMs. There are currently two implementations provided:
TrainBaumWelch - Regular Baum-Welch learning.
TrainBaumWelchScaled - Scaled Baum-Welch learning, which can handle underflows in long sequences.

References:
L. E. Baum, T. Petrie, G. Soules, and N. Weiss, "A maximization technique occurring in the statistical analysis of probabilistic functions of Markov chains," Ann. Math. Statist., vol. 41, no. 1, pp. 164-171, 1970.
"Hidden Markov Models and the Baum-Welch Algorithm," IEEE Information Theory Society Newsletter, Dec. 2003.
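The underflow problem that motivates TrainBaumWelchScaled shows up with a plain forward pass over a long sequence. The following self-contained Java sketch illustrates it under stated assumptions: the class UnderflowDemo, its methods, and its two-state, two-symbol model parameters are hypothetical and unrelated to Encog's API. The scaling shown (normalizing the forward variables at every step and summing the logarithms of the normalizers) is the textbook technique, not necessarily the exact scheme TrainBaumWelchScaled uses.

```java
// Hypothetical, self-contained demo (not Encog code): a 2-state, 2-symbol
// discrete HMM with made-up parameters.
final class UnderflowDemo {
    static final double[] PI = {0.6, 0.4};                 // initial state distribution
    static final double[][] A = {{0.7, 0.3}, {0.4, 0.6}};  // A[i][j] = P(q_{t+1}=j | q_t=i)
    static final double[][] B = {{0.9, 0.1}, {0.2, 0.8}};  // B[i][k] = P(o_t=k | q_t=i)

    /** One forward recursion: next_j = (sum_i alpha_i * A[i][j]) * B[j][symbol]. */
    static double[] step(double[] alpha, int symbol) {
        double[] next = new double[alpha.length];
        for (int j = 0; j < alpha.length; j++) {
            double sum = 0.0;
            for (int i = 0; i < alpha.length; i++) sum += alpha[i] * A[i][j];
            next[j] = sum * B[j][symbol];
        }
        return next;
    }

    /** alpha_0(i) = PI[i] * B[i][o_0]. */
    static double[] initial(int symbol) {
        double[] alpha = new double[PI.length];
        for (int i = 0; i < PI.length; i++) alpha[i] = PI[i] * B[i][symbol];
        return alpha;
    }

    /** Plain forward pass: returns P(O) directly, which underflows for long O. */
    static double unscaledProbability(int[] obs) {
        double[] alpha = initial(obs[0]);
        for (int t = 1; t < obs.length; t++) alpha = step(alpha, obs[t]);
        double p = 0.0;
        for (double v : alpha) p += v;
        return p;
    }

    /** Scaled forward pass: renormalizes alpha each step, returns log P(O) = sum_t log c_t. */
    static double scaledLogProbability(int[] obs) {
        double[] alpha = initial(obs[0]);
        double logP = 0.0;
        for (int t = 0; t < obs.length; t++) {
            if (t > 0) alpha = step(alpha, obs[t]);
            double c = 0.0;                                 // c_t = P(o_t | o_0..o_{t-1})
            for (double v : alpha) c += v;
            for (int i = 0; i < alpha.length; i++) alpha[i] /= c;
            logP += Math.log(c);
        }
        return logP;
    }

    public static void main(String[] args) {
        int[] obs = new int[10000];
        for (int t = 0; t < obs.length; t++) obs[t] = t % 2; // alternating symbols 0,1,0,1,...
        System.out.println("unscaled P(O)   = " + unscaledProbability(obs));
        System.out.println("scaled log P(O) = " + scaledLogProbability(obs));
    }
}
```

For this 10,000-step sequence the unscaled P(O) collapses to zero or a subnormal value, while the scaled log-likelihood stays finite even though it lies below about -708, the logarithm of the smallest normal double.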


Constructor Summary
BaseBaumWelch(HiddenMarkovModel hmm, MLSequenceSet training)
           
 
Method Summary
 void addStrategy(Strategy strategy)
          Training strategies can be added to improve the training results.
 boolean canContinue()
           
protected  double[][] estimateGamma(double[][][] xi, ForwardBackwardCalculator fbc)
           
abstract  double[][][] estimateXi(MLDataSet sequence, ForwardBackwardCalculator fbc, HiddenMarkovModel hmm)
           
 void finishTraining()
          Should be called once training is complete and no more iterations are needed.
abstract  ForwardBackwardCalculator generateForwardBackwardCalculator(MLDataSet sequence, HiddenMarkovModel hmm)
           
 double getError()
           
 TrainingImplementationType getImplementationType()
           
 int getIteration()
           
 MLMethod getMethod()
          Get the current best machine learning method from the training.
 List<Strategy> getStrategies()
           
 MLDataSet getTraining()
           
 boolean isTrainingDone()
           
 void iteration()
          Perform one iteration of training.
 void iteration(int count)
          Perform a number of training iterations.
 TrainingContinuation pause()
          Pause the training to continue later.
 void resume(TrainingContinuation state)
          Resume training.
 void setError(double error)
           
 void setIteration(int iteration)
          Set the current training iteration.
 
Methods inherited from class java.lang.Object
clone, equals, finalize, getClass, hashCode, notify, notifyAll, toString, wait, wait, wait
 

Constructor Detail

BaseBaumWelch

public BaseBaumWelch(HiddenMarkovModel hmm,
                     MLSequenceSet training)
Method Detail

addStrategy

public void addStrategy(Strategy strategy)
Description copied from interface: MLTrain
Training strategies can be added to improve the training results. There are a number to choose from, and several can be used at once.

Specified by:
addStrategy in interface MLTrain
Parameters:
strategy - The strategy to add.

canContinue

public boolean canContinue()
Specified by:
canContinue in interface MLTrain
Returns:
True if the training can be paused, and later continued.

estimateGamma

protected double[][] estimateGamma(double[][][] xi,
                                   ForwardBackwardCalculator fbc)
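estimateGamma has no description here. The standard relation between the two Baum-Welch quantities is that the state posterior gamma_t(i) = P(q_t = i | O) is the transition posterior xi_t(i, j) summed over the destination state j. The Java sketch below shows that relation under stated assumptions: GammaSketch is a hypothetical helper, xi is assumed to have shape [T-1][N][N], and the ForwardBackwardCalculator parameter is omitted because the sketch does not need it. It is not Encog's source.

```java
import java.util.Arrays;

// Hypothetical helper illustrating the gamma-from-xi relation; not Encog's source.
final class GammaSketch {
    /**
     * xi[t][i][j] = P(q_t = i, q_{t+1} = j | O) for t = 0..T-2 (shape [T-1][N][N]).
     * Returns gamma[t][i] = P(q_t = i | O) for t = 0..T-1 (shape [T][N]).
     */
    static double[][] estimateGamma(double[][][] xi) {
        int steps = xi.length;                        // T - 1 transitions
        int n = xi[0].length;                         // number of hidden states
        double[][] gamma = new double[steps + 1][n];  // Java zero-initializes new arrays
        // gamma_t(i) = sum_j xi_t(i, j) for every t that has an outgoing transition.
        for (int t = 0; t < steps; t++)
            for (int i = 0; i < n; i++)
                for (int j = 0; j < n; j++)
                    gamma[t][i] += xi[t][i][j];
        // The last time step has no outgoing transition, so marginalize the final
        // transition over its source state: gamma_{T-1}(j) = sum_i xi_{T-2}(i, j).
        for (int j = 0; j < n; j++)
            for (int i = 0; i < n; i++)
                gamma[steps][j] += xi[steps - 1][i][j];
        return gamma;
    }

    public static void main(String[] args) {
        // Two states, T = 3 (two transitions); each xi[t] sums to 1.
        double[][][] xi = {
            {{0.5, 0.2}, {0.1, 0.2}},
            {{0.3, 0.3}, {0.1, 0.3}}
        };
        for (double[] row : estimateGamma(xi)) System.out.println(Arrays.toString(row));
    }
}
```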

estimateXi

public abstract double[][][] estimateXi(MLDataSet sequence,
                                        ForwardBackwardCalculator fbc,
                                        HiddenMarkovModel hmm)
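Subclasses implement estimateXi. For the unscaled case, the textbook estimate is xi_t(i, j) = alpha_t(i) a_ij b_j(o_{t+1}) beta_{t+1}(j) / P(O), built from the forward and backward variables. The Java sketch below assumes a discrete HMM and uses plain arrays in place of MLDataSet, ForwardBackwardCalculator, and HiddenMarkovModel; the class XiSketch and all of its methods are hypothetical, and this is not how Encog's subclasses are written.

```java
import java.util.Arrays;

// Hypothetical sketch (not Encog's source): unscaled xi estimation for one
// discrete observation sequence, using plain arrays.
final class XiSketch {
    /** alpha[t][i] = P(o_0..o_t, q_t = i). */
    static double[][] forward(double[] pi, double[][] a, double[][] b, int[] obs) {
        int n = pi.length;
        double[][] alpha = new double[obs.length][n];
        for (int i = 0; i < n; i++) alpha[0][i] = pi[i] * b[i][obs[0]];
        for (int t = 1; t < obs.length; t++)
            for (int j = 0; j < n; j++) {
                double sum = 0.0;
                for (int i = 0; i < n; i++) sum += alpha[t - 1][i] * a[i][j];
                alpha[t][j] = sum * b[j][obs[t]];
            }
        return alpha;
    }

    /** beta[t][i] = P(o_{t+1}..o_{T-1} | q_t = i), with beta[T-1][i] = 1. */
    static double[][] backward(double[][] a, double[][] b, int[] obs) {
        int n = a.length, last = obs.length - 1;
        double[][] beta = new double[obs.length][n];
        Arrays.fill(beta[last], 1.0);
        for (int t = last - 1; t >= 0; t--)
            for (int i = 0; i < n; i++)
                for (int j = 0; j < n; j++)
                    beta[t][i] += a[i][j] * b[j][obs[t + 1]] * beta[t + 1][j];
        return beta;
    }

    /** xi[t][i][j] = alpha_t(i) * a_ij * b_j(o_{t+1}) * beta_{t+1}(j) / P(O), shape [T-1][N][N]. */
    static double[][][] estimateXi(double[][] alpha, double[][] beta,
                                   double[][] a, double[][] b, int[] obs) {
        int n = a.length, last = obs.length - 1;
        double probability = 0.0;                     // P(O) = sum_i alpha_{T-1}(i)
        for (double v : alpha[last]) probability += v;
        double[][][] xi = new double[last][n][n];
        for (int t = 0; t < last; t++)
            for (int i = 0; i < n; i++)
                for (int j = 0; j < n; j++)
                    xi[t][i][j] = alpha[t][i] * a[i][j] * b[j][obs[t + 1]]
                            * beta[t + 1][j] / probability;
        return xi;
    }

    public static void main(String[] args) {
        double[] pi = {0.6, 0.4};
        double[][] a = {{0.7, 0.3}, {0.4, 0.6}};
        double[][] b = {{0.9, 0.1}, {0.2, 0.8}};
        int[] obs = {0, 1, 1, 0};
        double[][][] xi = estimateXi(forward(pi, a, b, obs), backward(a, b, obs), a, b, obs);
        System.out.println(Arrays.deepToString(xi));
    }
}
```

Each xi[t] is a joint distribution over (q_t, q_{t+1}), so its entries sum to 1. This unscaled version underflows on long sequences for the same reason the plain forward pass does.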

finishTraining

public void finishTraining()
Description copied from interface: MLTrain
Should be called once training is complete and no more iterations are needed. Calling iteration again will simply begin the training again, and require finishTraining to be called once the new training session is complete. It is particularly important to call finishTraining for multithreaded training techniques.

Specified by:
finishTraining in interface MLTrain

generateForwardBackwardCalculator

public abstract ForwardBackwardCalculator generateForwardBackwardCalculator(MLDataSet sequence,
                                                                            HiddenMarkovModel hmm)

getError

public double getError()
Specified by:
getError in interface MLTrain
Returns:
The training error. This value is calculated as the training data is evaluated by the iteration function, which has two important consequences. First, the value returned by getError() is meaningless before the first call to iteration. Second, the error is calculated BEFORE the call to iteration applies its training update. The error is computed at this point for performance reasons.
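The one-update lag in the reported error matters when writing a stopping loop. The Java sketch below illustrates it under stated assumptions: TrainerLike is a hypothetical stand-in that mirrors only the MLTrain methods documented on this page (it is not Encog's interface), and ToyTrainer is a toy whose "model" is a single number halved on every iteration, not an HMM.

```java
// Hypothetical stand-in mirroring a subset of the documented MLTrain methods.
interface TrainerLike {
    void iteration();
    double getError();
    boolean isTrainingDone();
    void finishTraining();
    int getIteration();
}

/** Toy trainer: the "model" is one number that converges toward 0. */
final class ToyTrainer implements TrainerLike {
    private double parameter = 1.0;
    private double error = Double.NaN;   // meaningless until the first iteration()
    private int iteration = 0;

    @Override public void iteration() {
        error = Math.abs(parameter);     // error is measured BEFORE the update...
        parameter *= 0.5;                // ...and only then is the update applied
        iteration++;
    }
    @Override public double getError() { return error; }
    @Override public boolean isTrainingDone() { return false; }
    @Override public void finishTraining() { }
    @Override public int getIteration() { return iteration; }
    double currentParameter() { return parameter; }
}

final class TrainingLoopDemo {
    /** Iterate until the reported error reaches the target, training stops, or the cap is hit. */
    static int train(TrainerLike train, double targetError, int maxIterations) {
        do {
            train.iteration();
        } while (train.getError() > targetError
                && !train.isTrainingDone()
                && train.getIteration() < maxIterations);
        train.finishTraining();
        return train.getIteration();
    }

    public static void main(String[] args) {
        ToyTrainer t = new ToyTrainer();
        int n = train(t, 0.01, 100);
        System.out.println("iterations=" + n + ", reported error=" + t.getError()
                + ", error of the current model=" + t.currentParameter());
    }
}
```

Here the loop stops after 8 iterations with a reported error of 0.0078125, while the model already in hand has an error of 0.00390625: the reported value describes the model as it was before the last update.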

getImplementationType

public TrainingImplementationType getImplementationType()
Specified by:
getImplementationType in interface MLTrain
Returns:
The training implementation type.

getIteration

public int getIteration()
Specified by:
getIteration in interface MLTrain
Returns:
The current training iteration.

getMethod

public MLMethod getMethod()
Description copied from interface: MLTrain
Get the current best machine learning method from the training.

Specified by:
getMethod in interface MLTrain
Returns:
The best machine learning method.

getStrategies

public List<Strategy> getStrategies()
Specified by:
getStrategies in interface MLTrain
Returns:
The strategies to use.

getTraining

public MLDataSet getTraining()
Specified by:
getTraining in interface MLTrain
Returns:
The training data to use.

isTrainingDone

public boolean isTrainingDone()
Specified by:
isTrainingDone in interface MLTrain
Returns:
True if training can progress no further.

iteration

public void iteration()
Description copied from interface: MLTrain
Perform one iteration of training.

Specified by:
iteration in interface MLTrain
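In the textbook Baum-Welch algorithm, one iteration re-estimates the model from the gamma and xi posteriors: pi_i = gamma_0(i), a_ij = sum_{t<T-1} xi_t(i, j) / sum_{t<T-1} gamma_t(i), and b_j(k) = sum_{t: o_t = k} gamma_t(j) / sum_t gamma_t(j). The Java sketch below shows these updates for a single discrete sequence under stated assumptions: ReestimateSketch and its methods are hypothetical, plain arrays stand in for HiddenMarkovModel, and nothing here claims to be what BaseBaumWelch.iteration() does internally.

```java
import java.util.Arrays;

// Hypothetical sketch of the Baum-Welch re-estimation step for one discrete
// observation sequence; not Encog's source.
final class ReestimateSketch {
    /** New initial distribution: pi_i = gamma_0(i). */
    static double[] reestimatePi(double[][] gamma) {
        return gamma[0].clone();
    }

    /** a_ij = sum_{t<T-1} xi_t(i,j) / sum_{t<T-1} gamma_t(i). */
    static double[][] reestimateA(double[][][] xi, double[][] gamma) {
        int n = gamma[0].length;
        double[][] a = new double[n][n];
        for (int i = 0; i < n; i++) {
            double denom = 0.0;                       // expected transitions out of i
            for (int t = 0; t < xi.length; t++) denom += gamma[t][i];
            for (int j = 0; j < n; j++) {
                double num = 0.0;                     // expected transitions i -> j
                for (int t = 0; t < xi.length; t++) num += xi[t][i][j];
                a[i][j] = num / denom;
            }
        }
        return a;
    }

    /** b_j(k) = sum_{t : o_t = k} gamma_t(j) / sum_t gamma_t(j). */
    static double[][] reestimateB(double[][] gamma, int[] obs, int symbols) {
        int n = gamma[0].length;
        double[][] b = new double[n][symbols];
        for (int j = 0; j < n; j++) {
            double denom = 0.0;                       // expected time spent in j
            for (int t = 0; t < obs.length; t++) {
                denom += gamma[t][j];
                b[j][obs[t]] += gamma[t][j];          // credit the observed symbol
            }
            for (int k = 0; k < symbols; k++) b[j][k] /= denom;
        }
        return b;
    }

    public static void main(String[] args) {
        // Mutually consistent xi and gamma for T = 3, N = 2.
        double[][][] xi = {
            {{0.5, 0.2}, {0.1, 0.2}},
            {{0.3, 0.3}, {0.1, 0.3}}
        };
        double[][] gamma = {{0.7, 0.3}, {0.6, 0.4}, {0.4, 0.6}};
        int[] obs = {0, 1, 0};
        System.out.println(Arrays.toString(reestimatePi(gamma)));
        System.out.println(Arrays.deepToString(reestimateA(xi, gamma)));
        System.out.println(Arrays.deepToString(reestimateB(gamma, obs, 2)));
    }
}
```

When training on several sequences, as with the MLSequenceSet passed to the constructor, the usual approach accumulates the numerators and denominators across all sequences before dividing.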

iteration

public void iteration(int count)
Description copied from interface: MLTrain
Perform a number of training iterations.

Specified by:
iteration in interface MLTrain
Parameters:
count - The number of iterations to perform.

pause

public TrainingContinuation pause()
Description copied from interface: MLTrain
Pause the training to continue later.

Specified by:
pause in interface MLTrain
Returns:
A training continuation object.

resume

public void resume(TrainingContinuation state)
Description copied from interface: MLTrain
Resume training.

Specified by:
resume in interface MLTrain
Parameters:
state - The training continuation object to use to continue.

setError

public void setError(double error)
Specified by:
setError in interface MLTrain
Parameters:
error - The current error rate. This method is usually called by training strategies.

setIteration

public void setIteration(int iteration)
Description copied from interface: MLTrain
Set the current training iteration.

Specified by:
setIteration in interface MLTrain
Parameters:
iteration - The iteration number to set.


Copyright © 2014. All Rights Reserved.