Direct Graphical Models  v.1.7.0
TrainNodeCvGMM.cpp
1 #include "TrainNodeCvGMM.h"
2 #include "SamplesAccumulator.h"
3 #include "random.h"
4 #include "macroses.h"
5 
6 namespace DirectGraphicalModels
7 {
8 // Constants
9 const double CTrainNodeCvGMM::MIN_COEFFICIENT_BASE = 32.0;
10 
11 // Constructor
12 CTrainNodeCvGMM::CTrainNodeCvGMM(byte nStates, word nFeatures, TrainNodeCvGMMParams params) : CBaseRandomModel(nStates), CTrainNode(nStates, nFeatures), m_minCoefficient(1)
13 {
14  init(params);
15 }
16 
// Convenience constructor: the caller chooses only the maximum number of stored samples
// and the number of Gaussians per state; the parameter set is then handed to init().
// NOTE(review): listing line 20 is missing from this extract; it presumably declares
// `params` (likely initialised from TRAIN_NODE_CV_GMM_PARAMS_DEFAULT) — confirm.
17 // Constructor
18 CTrainNodeCvGMM::CTrainNodeCvGMM(byte nStates, word nFeatures, size_t maxSamples, byte numGausses) : CBaseRandomModel(nStates), CTrainNode(nStates, nFeatures), m_minCoefficient(1)
19 {
21  params.maxSamples = maxSamples;
22  params.numGausses = numGausses;
23  init(params);
24 }
25 
// NOTE(review): this is presumably the body of `void CTrainNodeCvGMM::init(TrainNodeCvGMMParams params)`
// (both constructors call init(params)). Listing lines 26 (the signature) and 28 are missing
// from this extract; line 28 presumably allocates m_pSamplesAcc, which the destructor
// deletes — confirm against the full source.
27 {
29 
// Create one EM (Gaussian mixture) estimator per state, configured from `params`:
// number of mixture components, covariance matrix type and termination criteria.
// NOTE(review): models are appended with push_back, so calling this more than once would
// grow m_vpEM beyond m_nStates entries.
30  for (byte s = 0; s < m_nStates; s++) {
31  Ptr<ml::EM> pEM = ml::EM::create();
32  pEM->setClustersNumber(params.numGausses);
33  pEM->setCovarianceMatrixType(params.covariance_matrix_type);
34  pEM->setTermCriteria(TermCriteria(params.term_criteria_type, params.maxCount, params.epsilon));
35  m_vpEM.push_back(pEM);
36  }
37 }
38 
39 // Destructor
// Deletes the samples accumulator and clears m_vpEM.
// NOTE(review): listing line 40 (the destructor signature) is missing from this extract.
// m_pSamplesAcc is a raw owning pointer released with delete; unless the header disables
// copying, a copied instance would delete it twice — consider std::unique_ptr.
41 {
42  delete m_pSamplesAcc;
43  m_vpEM.clear();
44 }
45 
// NOTE(review): presumably the body of `void CTrainNodeCvGMM::reset(void)`. Listing lines 46
// (the signature) and 48 are missing from this extract; line 48 presumably resets the
// samples accumulator — confirm.
47 {
// Clear every per-state EM model.
// NOTE(review): em->clear() dereferences each element; if any Ptr is empty (OpenCV's
// Algorithm::load can return an empty Ptr), this crashes.
49  for (Ptr<ml::EM> &em : m_vpEM) em->clear();
50 }
51 
52 void CTrainNodeCvGMM::save(const std::string &path, const std::string &name, short idx) const
53 {
54  for (byte s = 0; s < m_nStates; s++) {
55  std::string fileName = generateFileName(path, name.empty() ? "TrainNodeCvGMM_" + std::to_string(s) : name + "_" + std::to_string(s), idx);
56  m_vpEM[s]->save(fileName.c_str());
57  }
58 }
59 
// Loads the per-state EM models using the same file naming as save(). For state s the
// file stem is `name` + "_" + s, or "TrainNodeCvGMM_" + s when `name` is empty, and the
// path is built with generateFileName(path, stem, idx).
60 void CTrainNodeCvGMM::load(const std::string &path, const std::string &name, short idx)
61 {
62  for (byte s = 0; s < m_nStates; s++) {
63  std::string fileName = generateFileName(path, name.empty() ? "TrainNodeCvGMM_" + std::to_string(s) : name + "_" + std::to_string(s), idx);
64  try {
// NOTE(review): OpenCV's Algorithm::load<T>() can return an empty Ptr without throwing
// (e.g. when the stored model was never trained). m_vpEM[s] would then be null and later
// dereferences would crash — consider checking the result before assigning it.
65  m_vpEM[s] = Algorithm::load<ml::EM>(fileName.c_str());
66  } catch (Exception &) {
// Best-effort handling: a failure is only reported by printing the file name. The state's
// previous model is kept and loading continues with the next state, so the caller cannot
// detect the error.
67  printf("In file: %s\n", fileName.c_str());
68  }
69  }
70 
// NOTE(review): listing line 71 is missing from this extract (presumably it recomputes
// m_minCoefficient) — confirm against the full source.
72 }
73 
74 void CTrainNodeCvGMM::addFeatureVec(const Mat &featureVector, byte gt)
75 {
76  m_pSamplesAcc->addSample(featureVector, gt);
77 }
78 
// Fits the EM model of each state to the samples accumulated for that state via
// addFeatureVec(). States with no stored samples are skipped, so their model is left as it
// was. A failed trainEM() is reported through DGM_IF_WARNING ("Error EM training!").
// doClean: when true, the sample container of each non-empty state is released after its
//          fit, regardless of whether trainEM() succeeded (assuming DGM_IF_WARNING does not
//          abort).
79 void CTrainNodeCvGMM::train(bool doClean)
80 {
81 #ifdef DEBUG_PRINT_INFO
82  printf("\n");
83 #endif
84 
85  for (byte s = 0; s < m_nStates; s++) { // states
86  int nSamples = m_pSamplesAcc->getNumSamples(s);
87 #ifdef DEBUG_PRINT_INFO
88  printf("State[%d] - %d of %d samples\n", s, nSamples, m_pSamplesAcc->getNumInputSamples(s));
89 #endif
90  if (nSamples == 0) continue;
91  DGM_IF_WARNING(!m_vpEM[s]->trainEM(m_pSamplesAcc->getSamplesContainer(s)), "Error EM training!");
92  if (doClean) m_pSamplesAcc->release(s);
93  } // s
94 
// NOTE(review): listing line 95 is missing from this extract (presumably it updates
// m_minCoefficient, which scales the potentials) — confirm against the full source.
96 }
97 
98 void CTrainNodeCvGMM::calculateNodePotentials(const Mat &featureVector, Mat &potential, Mat &mask) const
99 {
100  Mat fv;
101  featureVector.convertTo(fv, CV_64FC1);
102 
103  // Min Coefficient approach
104  for (byte s = 0; s < m_nStates; s++) { // state
105  float * pPot = potential.ptr<float>(s);
106  byte * pMask = mask.ptr<byte>(s);
107  if (m_vpEM[s]->isTrained())
108  pPot[0] = static_cast<float>(std::exp(m_vpEM[s]->predict2(fv.t(), noArray())[0]) * m_minCoefficient);
109  else {
110  // pPot[0] = 0;
111  pMask[0] = 0;
112  }
113  } // s
114 
115 
116  // Minimax approach
117  /*double min = 1.0e+150;
118  double max = 1.0e-150;
119  double *v = new double[m_nStates];
120  for (byte s = 0; s < m_nStates; s++) { // state
121  if (m_pEM[s].isTrained()) {
122  v[s] = std::exp(m_pEM[s].predict(fv)[0]);
123  if (max < v[s]) max = v[s];
124  if (min > v[s]) min = v[s];
125  }
126  }
127  for (byte s = 0; s < m_nStates; s++) {
128  v[s] /= (max - min);
129  res.at<float>(s, 0) = static_cast<float>(v[s]);
130  }
131  delete [] v;*/
132 }
133 }
void addFeatureVec(const Mat &featureVector, byte gt)
Adds new feature vector.
void init(TrainNodeCvGMMParams params)
void load(const std::string &path, const std::string &name=std::string(), short idx=-1)
Loads the training data.
int term_criteria_type
Termination criteria type (according to the two previous parameters)
void reset(void)
Resets class variables.
word getNumFeatures(void) const
Returns number of features.
Definition: ITrain.h:37
std::string generateFileName(const std::string &path, const std::string &name, short idx) const
Generates name of the data file for storing random model parameters.
int getNumSamples(byte state) const
Returns the number of stored samples in container for the state (class) state.
word numGausses
The number of Gauss functions for approximation.
size_t maxSamples
Maximum number of samples to be used in training. 0 means using all the samples.
Base abstract class for random model training.
void release(byte state)
Releases memory of container for the state (class) state.
const TrainNodeCvGMMParams TRAIN_NODE_CV_GMM_PARAMS_DEFAULT
void addSample(const Mat &featureVector, byte state)
Adds new sample to the accumulator.
void reset(void)
Resets the accumulator.
int covariance_matrix_type
Type of the covariance matrix.
std::vector< Ptr< ml::EM > > m_vpEM
Expectation Maximization for GMM parameters estimation.
void calculateNodePotentials(const Mat &featureVector, Mat &potential, Mat &mask) const
Calculates the node potential, based on the feature vector.
int maxCount
Max number of iterations.
CTrainNodeCvGMM(byte nStates, word nFeatures, TrainNodeCvGMMParams params=TRAIN_NODE_CV_GMM_PARAMS_DEFAULT)
Constructor.
Samples accumulator abstract class.
void train(bool doClean=false)
Random model training.
void save(const std::string &path, const std::string &name=std::string(), short idx=-1) const
Saves the training data.
CSamplesAccumulator * m_pSamplesAcc
Samples Accumulator.
Base abstract class for node potentials training.
Definition: TrainNode.h:47
Mat getSamplesContainer(byte state) const
Returns samples container for the state (class) state.
int getNumInputSamples(byte state) const
Returns the number of input samples in container for the state (class) state.
byte m_nStates
The number of states (classes)
OpenCV Random Forest parameters.