Direct Graphical Models  v.1.7.0
TrainNodeMsRF.h
1 // Random Forest (based on Microsoft Sherwood library) training class interface
2 // Written by Sergey G. Kosov in 2013 for Project X
3 #pragma once
4 
5 #include "TrainNode.h"
6 #include "SamplesAccumulator.h"
7 
8 //#ifdef USE_SHERWOOD
9 
// Forward declarations of the Microsoft Sherwood library types referenced in this header.
// Only the names are introduced here; the complete definitions live in the Sherwood sources.
namespace MicrosoftResearch
{
	namespace Cambridge
	{
		namespace Sherwood
		{
			class LinearFeatureResponse;
			class HistogramAggregator;
			class DataPointCollection;
			template<class F, class S> class Forest;
			struct TrainingParameters;
		}
	}
}
17 
19 
20 namespace DirectGraphicalModels
21 {
23  typedef struct TrainNodeMsRFParams {
28  bool verbose;
29  size_t maxSamples;
30 
32  TrainNodeMsRFParams(int _max_decision_levels, int _num_of_candidate_features, unsigned int _num_of_candidate_thresholds_per_feature, int _num_ot_trees, bool _verbose, int _maxSamples) : max_decision_levels(_max_decision_levels), num_of_candidate_features(_num_of_candidate_features), num_of_candidate_thresholds_per_feature(_num_of_candidate_thresholds_per_feature), num_ot_trees(_num_ot_trees), verbose(_verbose), maxSamples(_maxSamples) {}
34 
36  10, // Maximum number of the decision levels
37  10, // Number of candidate features
38  10, // Number of candidate thresholds (per feature)
39  10, // Number of trees in the forest (time / accuracy)
40  false, // Verbose mode
41  0 // Maximum number of samples to be used in training. 0 means using all the samples
42  );
43 
44  // =========================== Microsoft RF Train Class ===========================
52  class CTrainNodeMsRF : public CTrainNode
53  {
54  public:
61  DllExport CTrainNodeMsRF(byte nStates, word nFeatures, TrainNodeMsRFParams params = TRAIN_NODE_MS_RF_PARAMS_DEFAULT);
71  DllExport CTrainNodeMsRF(byte nStates, word nFeatures, size_t maxSamples);
72  DllExport virtual ~CTrainNodeMsRF(void);
73 
80  DllExport void reset(void);
81  DllExport void save(const std::string &path, const std::string &name = std::string(), short idx = -1) const;
82  DllExport void load(const std::string &path, const std::string &name = std::string(), short idx = -1);
83 
84  DllExport void addFeatureVec(const Mat &featureVector, byte gt);
85  DllExport void train(bool doClean = false);
86 
87 
88  protected:
89  DllExport void saveFile(FILE *pFile) const { }
90  DllExport void loadFile(FILE *pFile) { }
91  DllExport void calculateNodePotentials(const Mat &featureVector, Mat &potential, Mat &mask) const;
92 
93 
94  private:
95  void init(TrainNodeMsRFParams params); // This function is called by both constructors
96 
97 
98  private:
99  std::unique_ptr<sw::Forest<sw::LinearFeatureResponse, sw::HistogramAggregator>> m_pRF;
100  std::unique_ptr<CSamplesAccumulator> m_pSamplesAcc;
101  std::unique_ptr<sw::TrainingParameters> m_pParams;
102  };
103 }
104 //#endif
void init(TrainNodeMsRFParams params)
struct DirectGraphicalModels::TrainNodeMsRFParams TrainNodeMsRFParams
Microsoft Research Random Forest parameters.
size_t maxSamples
Maximum number of samples to be used in training. 0 means using all the samples.
Definition: TrainNodeMsRF.h:29
void save(const std::string &path, const std::string &name=std::string(), short idx=-1) const
Saves the training data.
void addFeatureVec(const Mat &featureVector, byte gt)
Adds new feature vector.
void train(bool doClean=false)
Random model training.
std::unique_ptr< sw::TrainingParameters > m_pParams
CTrainNodeMsRF(byte nStates, word nFeatures, TrainNodeMsRFParams params=TRAIN_NODE_MS_RF_PARAMS_DEFAULT)
Constructor.
Microsoft Research Random Forest parameters.
Definition: TrainNodeMsRF.h:23
void calculateNodePotentials(const Mat &featureVector, Mat &potential, Mat &mask) const
Calculates the node potential, based on the feature vector.
void loadFile(FILE *pFile)
Loads the random model from the file.
Definition: TrainNodeMsRF.h:90
void reset(void)
Resets class variables.
const TrainNodeMsRFParams TRAIN_NODE_MS_RF_PARAMS_DEFAULT
Definition: TrainNodeMsRF.h:35
void load(const std::string &path, const std::string &name=std::string(), short idx=-1)
Loads the training data.
int num_ot_trees
Number of trees in the forest (time / accuracy)
Definition: TrainNodeMsRF.h:27
std::unique_ptr< sw::Forest< sw::LinearFeatureResponse, sw::HistogramAggregator > > m_pRF
Random Forest classifier.
Definition: TrainNodeMsRF.h:99
unsigned int num_of_candidate_thresholds_per_feature
Number of candidate thresholds (per feature)
Definition: TrainNodeMsRF.h:26
std::unique_ptr< CSamplesAccumulator > m_pSamplesAcc
Samples Accumulator.
Base abstract class for node potentials training.
Definition: TrainNode.h:47
void saveFile(FILE *pFile) const
Saves the random model into the file.
Definition: TrainNodeMsRF.h:89
Microsoft Sherwood Random Forest training class.
Definition: TrainNodeMsRF.h:52
TrainNodeMsRFParams(int _max_decision_levels, int _num_of_candidate_features, unsigned int _num_of_candidate_thresholds_per_feature, int _num_ot_trees, bool _verbose, int _maxSamples)
Definition: TrainNodeMsRF.h:32
int num_of_candidate_features
Number of candidate features.
Definition: TrainNodeMsRF.h:25
int max_decision_levels
Maximum number of the decision levels.
Definition: TrainNodeMsRF.h:24