Direct Graphical Models  v.1.7.0
TrainNodeCvRF.cpp
1 #include "TrainNodeCvRF.h"
2 #include "SamplesAccumulator.h"
3 #include "macroses.h"
4 
5 namespace DirectGraphicalModels
6 {
7 // Constructor
8 CTrainNodeCvRF::CTrainNodeCvRF(byte nStates, word nFeatures, TrainNodeCvRFParams params) : CBaseRandomModel(nStates), CTrainNode(nStates, nFeatures)
9 {
10  init(params);
11 }

// Constructor
CTrainNodeCvRF::CTrainNodeCvRF(byte nStates, word nFeatures, size_t maxSamples) : CBaseRandomModel(nStates), CTrainNode(nStates, nFeatures)
{
    TrainNodeCvRFParams params = TRAIN_NODE_CV_RF_PARAMS_DEFAULT;
    params.maxSamples = maxSamples;
    init(params);
}

void CTrainNodeCvRF::init(TrainNodeCvRFParams params)
{
    m_pSamplesAcc = new CSamplesAccumulator(m_nStates, params.maxSamples);

    m_pRF = ml::RTrees::create();
    m_pRF->setMaxDepth(params.max_depth);
    m_pRF->setMinSampleCount(params.min_sample_count);
    m_pRF->setRegressionAccuracy(params.regression_accuracy);
    m_pRF->setUseSurrogates(params.use_surrogates);
    m_pRF->setMaxCategories(params.max_categories);
    m_pRF->setCalculateVarImportance(params.calc_var_importance);
    m_pRF->setActiveVarCount(params.nactive_vars);
    m_pRF->setTermCriteria(TermCriteria(params.term_criteria_type, params.maxCount, params.epsilon));
}

// Destructor
CTrainNodeCvRF::~CTrainNodeCvRF(void)
{
    delete m_pSamplesAcc;
}

void CTrainNodeCvRF::reset(void)
{
    m_pSamplesAcc->reset();
    m_pRF->clear();
}

void CTrainNodeCvRF::save(const std::string &path, const std::string &name, short idx) const
{
    std::string fileName = generateFileName(path, name.empty() ? "TrainNodeCvRF" : name, idx);
    m_pRF->save(fileName.c_str());
}

void CTrainNodeCvRF::load(const std::string &path, const std::string &name, short idx)
{
    std::string fileName = generateFileName(path, name.empty() ? "TrainNodeCvRF" : name, idx);
    m_pRF = Algorithm::load<ml::RTrees>(fileName.c_str());
}

void CTrainNodeCvRF::addFeatureVec(const Mat &featureVector, byte gt)
{
    m_pSamplesAcc->addSample(featureVector, gt);
}

void CTrainNodeCvRF::train(bool doClean)
{
#ifdef DEBUG_PRINT_INFO
    printf("\n");
#endif

    // Filling the <samples> and <classes>
    Mat samples, classes;
    for (byte s = 0; s < m_nStates; s++) {                  // states
        int nSamples = m_pSamplesAcc->getNumSamples(s);
#ifdef DEBUG_PRINT_INFO
        printf("State[%d] - %d of %d samples\n", s, nSamples, m_pSamplesAcc->getNumInputSamples(s));
#endif
        samples.push_back(m_pSamplesAcc->getSamplesContainer(s));
        classes.push_back(Mat(nSamples, 1, CV_32FC1, Scalar(s)));
        if (doClean) m_pSamplesAcc->release(s);             // free memory
    } // s
    samples.convertTo(samples, CV_32FC1);

    // Filling <var_type>
    Mat var_type(getNumFeatures() + 1, 1, CV_8UC1, Scalar(ml::VAR_NUMERICAL));  // all inputs are numerical
    var_type.at<byte>(getNumFeatures(), 0) = ml::VAR_CATEGORICAL;

    // Training
    try {
        m_pRF->train(ml::TrainData::create(samples, ml::ROW_SAMPLE, classes, noArray(), noArray(), noArray(), var_type));
    } catch (std::exception &e) {
        printf("EXCEPTION: %s\n", e.what());
        printf("Try to reduce the maximal depth of the forest or switch to x64.\n");
        getchar();
        exit(-1);
    }
}

Mat CTrainNodeCvRF::getFeatureImportance(void) const
{
    return m_pRF->getVarImportance();
}

void CTrainNodeCvRF::calculateNodePotentials(const Mat &featureVector, Mat &potential, Mat &mask) const
{
    Mat fv;
    featureVector.convertTo(fv, CV_32FC1);

    // Winner-takes-all potential: the predicted class gets 1.0, then a small
    // constant (0.1) is added to every class to keep the potentials non-zero
    float res = m_pRF->predict(fv.t());
    byte  s   = static_cast<byte>(res);
    potential.at<float>(s, 0) = 1.0f;
    potential += 0.1f;

    // Alternative (disabled): derive the potentials from the per-class vote counts
    //Mat votes;
    //m_pRF->getVotes(fv.t(), votes, ml::RTrees::Flags::PREDICT_MAX_VOTE);
    //int sum = 0;
    //for (int x = 0; x < votes.cols; x++) {
    //    byte s      = static_cast<byte>(votes.at<int>(0, x));
    //    int  nVotes = votes.at<int>(1, x);
    //    potential.at<float>(s, 0) = static_cast<float>(nVotes);
    //    sum += nVotes;
    //} // s
    //if (sum) potential /= sum;
}

} // namespace DirectGraphicalModels
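
For context, a node trainer defined in this file is typically driven roughly as follows. This is a minimal usage sketch, not part of the listing above: it only calls the interface visible here (the constructor, addFeatureVec(), train(), getFeatureImportance() and save()), while the parameter values, the feature-vector layout (an nFeatures x 1 byte column per sample), the dummy training data and the output path are assumptions made for illustration.

// Minimal usage sketch of the interface exercised in this file.
// The parameter values, the dummy feature vectors and the output path
// below are illustrative assumptions, not prescribed by the library.
#include "TrainNodeCvRF.h"

using namespace DirectGraphicalModels;

int main()
{
    const byte nStates   = 3;                   // number of classes (states)
    const word nFeatures = 8;                   // length of each feature vector

    TrainNodeCvRFParams params = TRAIN_NODE_CV_RF_PARAMS_DEFAULT;
    params.calc_var_importance = true;          // required for getFeatureImportance()
    params.maxCount            = 50;            // max number of trees in the forest

    CTrainNodeCvRF nodeTrainer(nStates, nFeatures, params);

    // Accumulate training samples: one nFeatures x 1 feature vector per sample,
    // paired with its ground-truth class label (dummy data here)
    for (byte gt = 0; gt < nStates; gt++)
        for (int i = 0; i < 100; i++) {
            cv::Mat fv(nFeatures, 1, CV_8UC1, cv::Scalar(10 * gt + i % 5));
            nodeTrainer.addFeatureVec(fv, gt);
        }

    nodeTrainer.train();                                        // trains the OpenCV Random Forest
    cv::Mat importance = nodeTrainer.getFeatureImportance();    // per-feature importance vector
    nodeTrainer.save("./");                                     // saves the trained forest to the current directory

    return 0;
}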