Direct Graphical Models  v.1.7.0
TrainNodeCvGMM.h
1 // Gaussian Mixture Model (based on OpenCV) training class interface
2 // Written by Sergey G. Kosov in 2012 for Project X
3 #pragma once
4 
5 #include "TrainNode.h"
6 
7 namespace DirectGraphicalModels
8 {
9  class CSamplesAccumulator;
10 
12  typedef struct TrainNodeCvGMMParams {
13  word numGausses;
15  int maxCount;
16  double epsilon;
18  size_t maxSamples;
19 
21  TrainNodeCvGMMParams(word _numGausses, int _covariance_matrix_type, int _maxCount, double _epsilon, int _term_criteria_type, int _maxSamples) : numGausses(_numGausses), covariance_matrix_type(_covariance_matrix_type), maxCount(_maxCount), epsilon(_epsilon), term_criteria_type(_term_criteria_type), maxSamples(_maxSamples) {}
23 
25  16, // Number of Gaussians
26  ml::EM::COV_MAT_DIAGONAL, // Covariance matrix type
27  100, // Max number of iterations
28  0.01, // GMM accuracy
29  TermCriteria::MAX_ITER | TermCriteria::EPS, // Termination criteria (according to the two previous parameters)
30  0 // Maximum number of samples to be used in training. 0 means using all the samples
31  );
32 
33  // =========================== OpenCV GMM Train Class ===========================
39  class CTrainNodeCvGMM : public CTrainNode // Gaussian Mixture Model (based on OpenCV) training class; derives from CTrainNode
40  {
41  public:
48  DllExport CTrainNodeCvGMM(byte nStates, word nFeatures, TrainNodeCvGMMParams params = TRAIN_NODE_CV_GMM_PARAMS_DEFAULT); // Constructor taking a full parameter set; params defaults to TRAIN_NODE_CV_GMM_PARAMS_DEFAULT
58  DllExport CTrainNodeCvGMM(byte nStates, word nFeatures, size_t maxSamples, byte nGausses = TRAIN_NODE_CV_GMM_PARAMS_DEFAULT.numGausses); // Constructor taking maxSamples and nGausses directly. NOTE(review): nGausses is a byte while numGausses is a word, so the default is narrowed - confirm this is intended
59  DllExport virtual ~CTrainNodeCvGMM(void); // Virtual destructor
60 
61  DllExport void reset(void); // Resets class variables
62  DllExport void save(const std::string &path, const std::string &name = std::string(), short idx = -1) const; // Saves the training data; name defaults to an empty string and idx to -1
63  DllExport void load(const std::string &path, const std::string &name = std::string(), short idx = -1); // Loads the training data; same defaults as save()
64 
65  DllExport void addFeatureVec(const Mat &featureVector, byte gt); // Adds a new feature vector; gt is presumably the ground-truth state of the sample - TODO confirm
66 
67  DllExport void train(bool doClean = false); // Trains the model; doClean defaults to false
68 
69 
70  protected:
71  DllExport void saveFile(FILE *pFile) const { } // Empty body: does nothing in this class
72  DllExport void loadFile(FILE *pFile) { } // Empty body: does nothing in this class
73  DllExport void calculateNodePotentials(const Mat &featureVector, Mat &potential, Mat &mask) const; // Calculates the node potential, based on the feature vector
74 
75 
76  private:
77  void init(TrainNodeCvGMMParams params); // This function is called by both constructors
78 
79 
80  private:
81  static const double MIN_COEFFICIENT_BASE; // Only declared here; its value is not visible in this header
82 
83 
84  protected:
85  std::vector<Ptr<ml::EM>> m_vpEM; // Expectation Maximization objects for GMM parameters estimation
87 
88  private:
89  long double m_minCoefficient; // = 1; // auxiliary coefficient for scaling Gaussian coefficients
90  };
91 }
92 
void addFeatureVec(const Mat &featureVector, byte gt)
Adds new feature vector.
void init(TrainNodeCvGMMParams params)
void load(const std::string &path, const std::string &name=std::string(), short idx=-1)
Loads the training data.
int term_criteria_type
Termination criteria type (according to the two previous parameters)
void reset(void)
Resets class variables.
void saveFile(FILE *pFile) const
Saves the random model into the file.
word numGausses
The number of Gauss functions for approximation.
size_t maxSamples
Maximum number of samples to be used in training. 0 means using all the samples.
const TrainNodeCvGMMParams TRAIN_NODE_CV_GMM_PARAMS_DEFAULT
int covariance_matrix_type
Type of the covariance matrix.
void loadFile(FILE *pFile)
Loads the random model from the file.
std::vector< Ptr< ml::EM > > m_vpEM
Expectation Maximization for GMM parameters estimation.
void calculateNodePotentials(const Mat &featureVector, Mat &potential, Mat &mask) const
Calculates the node potential, based on the feature vector.
int maxCount
Max number of iterations.
CTrainNodeCvGMM(byte nStates, word nFeatures, TrainNodeCvGMMParams params=TRAIN_NODE_CV_GMM_PARAMS_DEFAULT)
Constructor.
Samples accumulator abstract class.
void train(bool doClean=false)
Random model training.
void save(const std::string &path, const std::string &name=std::string(), short idx=-1) const
Saves the training data.
CSamplesAccumulator * m_pSamplesAcc
Samples Accumulator.
Base abstract class for node potentials training.
Definition: TrainNode.h:47
OpenCV Gaussian Mixture Model training class.
TrainNodeCvGMMParams(word _numGausses, int _covariance_matrix_type, int _maxCount, double _epsilon, int _term_criteria_type, int _maxSamples)
struct DirectGraphicalModels::TrainNodeCvGMMParams TrainNodeCvGMMParams
OpenCV Gaussian Mixture Model parameters.
OpenCV Gaussian Mixture Model parameters.