Direct Graphical Models  v.1.7.0
TrainNodeCvRF.h
1 // Random Forest (based on OpenCV) training class interface
2 // Written by Sergey G. Kosov in 2012 for Project X
3 #pragma once
4 
5 #include "TrainNode.h"
6 
7 namespace DirectGraphicalModels
8 {
9  class RTrees;
10  class CSamplesAccumulator;
11 
/// OpenCV Random Forest parameters.
/// Member declaration order matches the constructor's initializer list (avoids -Wreorder).
typedef struct TrainNodeCvRFParams {
	int		max_depth;				///< Max depth
	int		min_sample_count;		///< Min sample count (1% of all data)
	float	regression_accuracy;	///< Regression accuracy (0 means N/A here)
	bool	use_surrogates;			///< Compute surrogate split, no missing data
	int		max_categories;			///< Max number of categories (use sub-optimal algorithm for larger numbers)
	bool	calc_var_importance;	///< Calculate variable importance (must be true in order to use CTrainNodeCvRF::getFeatureImportance())
	int		nactive_vars;			///< Number of variables randomly selected at each node and used to find the best split(s) (0 means sqrt(nFeatures))
	int		maxCount;				///< Max number of trees in the forest (time / accuracy)
	double	epsilon;				///< Forest accuracy
	int		term_criteria_type;		///< Termination criteria type (according to the two previous parameters)
	size_t	maxSamples;				///< Maximum number of samples to be used in training. 0 means using all the samples

	/// Constructor: stores every argument in the member of the same name.
	TrainNodeCvRFParams(int _max_depth, int _min_sample_count, float _regression_accuracy, bool _use_surrogates, int _max_categories, bool _calc_var_importance, int _nactive_vars, int _maxCount, double _epsilon, int _term_criteria_type, size_t _maxSamples)
		: max_depth(_max_depth)
		, min_sample_count(_min_sample_count)
		, regression_accuracy(_regression_accuracy)
		, use_surrogates(_use_surrogates)
		, max_categories(_max_categories)
		, calc_var_importance(_calc_var_importance)
		, nactive_vars(_nactive_vars)
		, maxCount(_maxCount)
		, epsilon(_epsilon)
		, term_criteria_type(_term_criteria_type)
		, maxSamples(_maxSamples)
	{}
} TrainNodeCvRFParams;
31  25, // Max depth
32  5, // Min sample count (1% of all data)
33  0, // Regression accuracy (0 means N/A here)
34  false, // Compute surrogate split, no missing data
35  15, // Max number of categories (use sub-optimal algorithm for larger numbers)
36  false, // Calculate variable importance
37  4, // Number of variables randomly selected at node and used to find the best split(s). 0 means sqrt(nFeatures)
38  100, // Max number of trees in the forest (time / accuracy)
39  0.01, // Forest accuracy
40  TermCriteria::MAX_ITER | TermCriteria::EPS, // Termination cirteria (according the the two previous parameters)
41  0 // Maximum number of samples to be used in training. 0 means using all the samples
42  );
43 
44  // =========================== OpenCV RF Train Class ===========================
// OpenCV Random Forest training class (derives from CTrainNode, the base abstract class for node potentials training).
50  class CTrainNodeCvRF : public CTrainNode
51  {
52  public:
// Constructor.
// nStates: number of states; nFeatures: number of features; params: Random Forest parameters (defaults to TRAIN_NODE_CV_RF_PARAMS_DEFAULT).
59  DllExport CTrainNodeCvRF(byte nStates, word nFeatures, TrainNodeCvRFParams params = TRAIN_NODE_CV_RF_PARAMS_DEFAULT);
// Constructor taking only the maximum number of training samples.
// NOTE(review): presumably the remaining parameters fall back to the defaults — confirm in the implementation file.
68  DllExport CTrainNodeCvRF(byte nStates, word nFeatures, size_t maxSamples);
69  DllExport ~CTrainNodeCvRF(void);
70 
// Resets class variables.
71  DllExport void reset(void);
// Saves the training data.
72  DllExport void save(const std::string &path, const std::string &name = std::string(), short idx = -1) const;
// Loads the training data.
73  DllExport void load(const std::string &path, const std::string &name = std::string(), short idx = -1);
74 
// Adds new feature vector (presumably gt is the ground-truth state of the sample — confirm).
75  DllExport void addFeatureVec(const Mat &featureVector, byte gt);
// Random model training (doClean defaults to false; its exact effect is defined in the implementation file).
76  DllExport void train(bool doClean = false);
77 
// Returns the feature importance vector.
// Requires TrainNodeCvRFParams::calc_var_importance to be true.
84  DllExport Mat getFeatureImportance(void) const;
85 
86 
87  protected:
// Documented as saving/loading the random model into/from the file, but both bodies are empty here.
// NOTE(review): the model appears to be persisted through save()/load() instead — confirm.
88  DllExport void saveFile(FILE *pFile) const { }
89  DllExport void loadFile(FILE *pFile) { }
// Calculates the node potential, based on the feature vector.
90  DllExport void calculateNodePotentials(const Mat &featureVector, Mat &potential, Mat &mask) const;
91 
92 
93  protected:
// Random Forest.
94  Ptr<ml::RTrees> m_pRF;
// NOTE(review): source line 95 is missing from this extract; per the Doxygen cross-reference it declares
// `CSamplesAccumulator *m_pSamplesAcc` (Samples Accumulator).
96 
97 
98  private:
99  void init(TrainNodeCvRFParams params); // This function is called by both constructors
100 
101 
102  private:
// NOTE(review): source line 103 (a private member declaration) is missing from this extract — restore it from the original header.
104  };
105 }
106 
void init(TrainNodeCvRFParams params)
void load(const std::string &path, const std::string &name=std::string(), short idx=-1)
Loads the training data.
CTrainNodeCvRF(byte nStates, word nFeatures, TrainNodeCvRFParams params=TRAIN_NODE_CV_RF_PARAMS_DEFAULT)
Constructor.
size_t maxSamples
Maximum number of samples to be used in training. 0 means using all the samples.
Definition: TrainNodeCvRF.h:24
Ptr< ml::RTrees > m_pRF
Random Forest.
Definition: TrainNodeCvRF.h:94
OpenCV Random Forest parameters.
Definition: TrainNodeCvRF.h:13
int max_categories
Max number of categories (use sub-optimal algorithm for larger numbers)
Definition: TrainNodeCvRF.h:18
bool use_surrogates
Compute surrogate split, no missing data.
Definition: TrainNodeCvRF.h:17
void addFeatureVec(const Mat &featureVector, byte gt)
Adds new feature vector.
Mat getFeatureImportance(void) const
Returns the feature importance vector.
void save(const std::string &path, const std::string &name=std::string(), short idx=-1) const
Saves the training data.
CSamplesAccumulator * m_pSamplesAcc
Samples Accumulator.
Definition: TrainNodeCvRF.h:95
int min_sample_count
Min sample count (1% of all data)
Definition: TrainNodeCvRF.h:15
void reset(void)
Resets class variables.
void calculateNodePotentials(const Mat &featureVector, Mat &potential, Mat &mask) const
Calculates the node potential, based on the feature vector.
OpenCV Random Forest training class.
Definition: TrainNodeCvRF.h:50
void loadFile(FILE *pFile)
Loads the random model from the file.
Definition: TrainNodeCvRF.h:89
Samples accumulator abstract class.
const TrainNodeCvRFParams TRAIN_NODE_CV_RF_PARAMS_DEFAULT
Definition: TrainNodeCvRF.h:30
int term_criteria_type
Termination criteria type (according to the two previous parameters)
Definition: TrainNodeCvRF.h:23
void saveFile(FILE *pFile) const
Saves the random model into the file.
Definition: TrainNodeCvRF.h:88
Base abstract class for node potentials training.
Definition: TrainNode.h:47
bool calc_var_importance
Calculate variable importance (must be true in order to use the CTrainNodeCvRF::getFeatureImportance function)
Definition: TrainNodeCvRF.h:19
struct DirectGraphicalModels::TrainNodeCvRFParams TrainNodeCvRFParams
OpenCV Random Forest parameters.
void train(bool doClean=false)
Random model training.
TrainNodeCvRFParams(int _max_depth, int _min_sample_count, float _regression_accuracy, bool _use_surrogates, int _max_categories, bool _calc_var_importance, int _nactive_vars, int _maxCount, double _epsilon, int _term_criteria_type, size_t _maxSamples)
Definition: TrainNodeCvRF.h:27
int maxCount
Max number of trees in the forest (time / accuracy)
Definition: TrainNodeCvRF.h:21
float regression_accuracy
Regression accuracy (0 means N/A here)
Definition: TrainNodeCvRF.h:16
int nactive_vars
Number of variables randomly selected at each node and used to find the best split(s) (0 means \(\sqrt{nFeatures}\)).
Definition: TrainNodeCvRF.h:20