Direct Graphical Models  v.1.7.0
TrainNodeMsRF.cpp
#include "TrainNodeMsRF.h"

#ifdef USE_SHERWOOD

#include "sherwood/Sherwood.h"

#ifdef ENABLE_PPL
#include "sherwood/ParallelForestTrainer.h"	// for parallel computing
#endif

#include "sherwood/utilities/FeatureResponseFunctions.h"
#include "sherwood/utilities/StatisticsAggregators.h"
#include "sherwood/utilities/DataPointCollection.h"
#include "sherwood/utilities/TrainingContexts.h"

namespace DirectGraphicalModels
{
	// Constructor
	CTrainNodeMsRF::CTrainNodeMsRF(byte nStates, word nFeatures, TrainNodeMsRFParams params) : CBaseRandomModel(nStates), CTrainNode(nStates, nFeatures)
	{
		init(params);
	}

	// Constructor
	CTrainNodeMsRF::CTrainNodeMsRF(byte nStates, word nFeatures, size_t maxSamples) : CBaseRandomModel(nStates), CTrainNode(nStates, nFeatures)
	{
		TrainNodeMsRFParams params = TRAIN_NODE_MS_RF_PARAMS_DEFAULT;
		params.maxSamples = maxSamples;
		init(params);
	}

	void CTrainNodeMsRF::init(TrainNodeMsRFParams params)
	{
		m_pSamplesAcc = std::unique_ptr<CSamplesAccumulator>(new CSamplesAccumulator(m_nStates, params.maxSamples));
		m_pParams = std::unique_ptr<sw::TrainingParameters>(new sw::TrainingParameters());
		// Sherwood training parameters, filled from params
		m_pParams->MaxDecisionLevels = params.max_decision_levels - 1;
		m_pParams->NumberOfCandidateFeatures = params.num_of_candidate_features;
		m_pParams->NumberOfCandidateThresholdsPerFeature = params.num_of_candidate_thresholds_per_feature;
		m_pParams->NumberOfTrees = params.num_ot_trees;
		m_pParams->Verbose = params.verbose;
	}

	// Destructor
	CTrainNodeMsRF::~CTrainNodeMsRF(void)
	{}

	void CTrainNodeMsRF::reset(void)
	{
		m_pSamplesAcc->reset();
		m_pRF.reset();
	}

	void CTrainNodeMsRF::save(const std::string &path, const std::string &name, short idx) const
	{
		std::string fileName = generateFileName(path, name.empty() ? "TrainNodeMsRF" : name, idx);
		m_pRF->Serialize(fileName);
	}

	void CTrainNodeMsRF::load(const std::string &path, const std::string &name, short idx)
	{
		std::string fileName = generateFileName(path, name.empty() ? "TrainNodeMsRF" : name, idx);
		m_pRF = sw::Forest<sw::LinearFeatureResponse, sw::HistogramAggregator>::Deserialize(fileName);
	}

	void CTrainNodeMsRF::addFeatureVec(const Mat &featureVector, byte gt)
	{
		m_pSamplesAcc->addSample(featureVector, gt);
	}

	void CTrainNodeMsRF::train(bool doClean)
	{
#ifdef DEBUG_PRINT_INFO
		printf("\n");
#endif
		// Filling <pData>
		sw::DataPointCollection *pData = new sw::DataPointCollection();
		pData->m_dimension = getNumFeatures();

		for (byte s = 0; s < m_nStates; s++) {			// states
			int nSamples = m_pSamplesAcc->getNumSamples(s);
#ifdef DEBUG_PRINT_INFO
			printf("State[%d] - %d of %d samples\n", s, nSamples, m_pSamplesAcc->getNumInputSamples(s));
#endif
			for (int smp = 0; smp < nSamples; smp++) {
				for (word f = 0; f < getNumFeatures(); f++) {	// features
					byte fval = m_pSamplesAcc->getSamplesContainer(s).at<byte>(smp, f);
					pData->m_vData.push_back(fval);
				} // f
				pData->m_vLabels.push_back(s);
			} // smp
			if (doClean) m_pSamplesAcc->release(s);		// releases memory
		} // s
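		// At this point m_vData holds the accumulated samples flattened row-major
		// (m_dimension = getNumFeatures() values per data point) and m_vLabels
		// holds the corresponding class label of each sample.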

		// Training
		sw::Random random;
		sw::ClassificationTrainingContext classificationContext(m_nStates, getNumFeatures());
#ifdef ENABLE_PPL
		// Use this function with caution - it is not verified!
		m_pRF = sw::ParallelForestTrainer<sw::LinearFeatureResponse, sw::HistogramAggregator>::TrainForest(random, *m_pParams, classificationContext, *pData);
#else
		m_pRF = sw::ForestTrainer<sw::LinearFeatureResponse, sw::HistogramAggregator>::TrainForest(random, *m_pParams, classificationContext, *pData);
#endif

		delete pData;
	}

	void CTrainNodeMsRF::calculateNodePotentials(const Mat &featureVector, Mat &potential, Mat &mask) const
	{
		std::unique_ptr<sw::DataPointCollection> testData = std::unique_ptr<sw::DataPointCollection>(new sw::DataPointCollection());
		testData->m_dimension = getNumFeatures();
		for (word f = 0; f < getNumFeatures(); f++) {
			float feature = static_cast<float>(featureVector.ptr<byte>(f)[0]);
			testData->m_vData.push_back(feature);
		}

		std::vector<std::vector<int>> leafNodeIndices;
		m_pRF->Apply(*testData, leafNodeIndices);

		sw::HistogramAggregator h(m_nStates);
		int index = 0;
		for (size_t t = 0; t < m_pRF->TreeCount(); t++) {
			int leafIndex = leafNodeIndices[t][index];
			h.Aggregate(m_pRF->GetTree(t).GetNode(leafIndex).TrainingDataStatistics);
		} // t

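		// The leaf statistics of all trees have been merged into histogram h.
		// Half of its entropy serves as a "mudiness" factor: each class probability
		// is scaled by (1 - mudiness), so the less certain the forest, the flatter
		// the resulting node potentials.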
		float mudiness = static_cast<float>(0.5 * h.Entropy());

		for (byte s = 0; s < m_nStates; s++)
			potential.at<float>(s, 0) = (1.0f - mudiness) * h.GetProbability(s);
	}
}
#endif
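
Below is a minimal usage sketch (not part of the library sources) showing how this node trainer could be driven end to end. It relies only on the calls visible in this file plus OpenCV routines assumed to be available through the DGM headers; the sample count, random feature values, and the "models/" output path are illustrative placeholders.

#include "TrainNodeMsRF.h"

#ifdef USE_SHERWOOD
int main()
{
	using namespace DirectGraphicalModels;

	const byte nStates   = 2;	// number of classes
	const word nFeatures = 3;	// dimensionality of one feature vector
	CTrainNodeMsRF nodeTrainer(nStates, nFeatures);

	// Accumulate synthetic training samples: one CV_8UC1 column-vector of
	// nFeatures values per sample, together with its ground-truth class label
	for (int i = 0; i < 100; i++) {
		cv::Mat fv(nFeatures, 1, CV_8UC1);
		cv::randu(fv, cv::Scalar(0), cv::Scalar(255));
		nodeTrainer.addFeatureVec(fv, static_cast<byte>(i % nStates));
	}

	nodeTrainer.train();		// builds the Sherwood random forest
	nodeTrainer.save("models/");	// "models/" is a placeholder output directory
	return 0;
}
#endif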