1 #include "TrainNodeMsRF.h" 5 #include "sherwood/Sherwood.h" 8 #include "sherwood/ParallelForestTrainer.h" 11 #include "sherwood/utilities/FeatureResponseFunctions.h" 12 #include "sherwood/utilities/StatisticsAggregators.h" 13 #include "sherwood/utilities/DataPointCollection.h" 14 #include "sherwood/utilities/TrainingContexts.h" 35 m_pParams = std::unique_ptr<sw::TrainingParameters>(
new sw::TrainingParameters());
56 std::string fileName =
generateFileName(path, name.empty() ?
"TrainNodeMsRF" : name, idx);
57 m_pRF->Serialize(fileName);
62 std::string fileName =
generateFileName(path, name.empty() ?
"TrainNodeMsRF" : name, idx);
73 #ifdef DEBUG_PRINT_INFO 77 sw::DataPointCollection * pData =
new sw::DataPointCollection();
82 #ifdef DEBUG_PRINT_INFO 83 printf(
"State[%d] - %d of %d samples\n", s, nSamples,
m_pSamplesAcc->getNumInputSamples(s));
85 for (
int smp = 0; smp < nSamples; smp++) {
87 byte fval =
m_pSamplesAcc->getSamplesContainer(s).at<byte>(smp, f);
88 pData->m_vData.push_back(fval);
90 pData->m_vLabels.push_back(s);
101 m_pRF = sw::ParallelForestTrainer<sw::LinearFeatureResponse, sw::HistogramAggregator>::TrainForest(random, *
m_pParams, classificationContext, *pData);
103 m_pRF = sw::ForestTrainer<sw::LinearFeatureResponse, sw::HistogramAggregator>::TrainForest(random, *
m_pParams, classificationContext, *pData);
111 std::unique_ptr<sw::DataPointCollection> testData = std::unique_ptr<sw::DataPointCollection>(
new sw::DataPointCollection());
114 float feature =
static_cast<float>(featureVector.ptr<byte>(f)[0]);
115 testData->m_vData.push_back(feature);
118 std::vector<std::vector<int>> leafNodeIndices;
119 m_pRF->Apply(*testData, leafNodeIndices);
123 for (
size_t t = 0; t <
m_pRF->TreeCount(); t++) {
124 int leafIndex = leafNodeIndices[t][index];
125 h.Aggregate(
m_pRF->GetTree((t)).GetNode(leafIndex).TrainingDataStatistics);
128 float mudiness =
static_cast<float> (0.5 * h.Entropy());
131 potential.at<
float>(s, 0) = (1.0f - mudiness) * h.GetProbability(s);
void init(TrainNodeMsRFParams params)
virtual ~CTrainNodeMsRF(void)
word getNumFeatures(void) const
Returns the number of features.
size_t maxSamples
Maximum number of samples to be used in training; 0 means that all samples are used.
std::string generateFileName(const std::string &path, const std::string &name, short idx) const
Generates the name of the data file used to store the random model parameters.
void save(const std::string &path, const std::string &name=std::string(), short idx=-1) const
Saves the training data.
void addFeatureVec(const Mat &featureVector, byte gt)
Adds a new feature vector.
void train(bool doClean=false)
Trains the random model.
std::unique_ptr< sw::TrainingParameters > m_pParams
CTrainNodeMsRF(byte nStates, word nFeatures, TrainNodeMsRFParams params=TRAIN_NODE_MS_RF_PARAMS_DEFAULT)
Constructor.
Base abstract class for random model training.
Microsoft Research Random Forest parameters.
void calculateNodePotentials(const Mat &featureVector, Mat &potential, Mat &mask) const
Calculates the node potential based on the feature vector.
void reset(void)
Resets class variables.
const TrainNodeMsRFParams TRAIN_NODE_MS_RF_PARAMS_DEFAULT
void load(const std::string &path, const std::string &name=std::string(), short idx=-1)
Loads the training data.
int num_ot_trees
Number of trees in the forest (a time / accuracy trade-off: more trees take longer to train but typically improve accuracy).
Samples accumulator abstract class.
std::unique_ptr< sw::Forest< sw::LinearFeatureResponse, sw::HistogramAggregator > > m_pRF
Random Forest classifier.
unsigned int num_of_candidate_thresholds_per_feature
Number of candidate thresholds per feature.
std::unique_ptr< CSamplesAccumulator > m_pSamplesAcc
Samples accumulator.
bool verbose
Verbose mode.
Base abstract class for node potentials training.
byte m_nStates
The number of states (classes).
int num_of_candidate_features
Number of candidate features.
int max_decision_levels
Maximum number of decision levels.