/// Pulls in CSamplesAccumulator and forward-declares the random-forest types
/// this header only refers to through pointers (e.g. the
/// std::unique_ptr<sw::Forest<sw::LinearFeatureResponse, sw::HistogramAggregator>>
/// and std::unique_ptr<sw::TrainingParameters> members), so their full
/// definitions need not be included here.
/// NOTE(review): these declarations are presumably inside `namespace sw`; the
/// namespace opening is not visible in this view — confirm.
/// NOTE(review): the bare numeric prefixes ("6", "11", "12", "13", "14", "15")
/// look like listing line numbers fused in by extraction, not C++ tokens. As
/// written, the `#include` directive on the first line would also swallow the
/// trailing `class LinearFeatureResponse;` declaration.
6 #include "SamplesAccumulator.h" 11 class LinearFeatureResponse;
12 class HistogramAggregator;
13 class DataPointCollection;
14 template<
class F,
class S>
class Forest;
15 struct TrainingParameters;
/**
 * @brief Builds a Microsoft Research Random Forest parameter set by copying each
 *        argument into the member of the same name.
 * @param _max_decision_levels Maximum number of decision levels.
 * @param _num_of_candidate_features Number of candidate features.
 * @param _num_of_candidate_thresholds_per_feature Number of candidate thresholds per feature.
 * @param _num_ot_trees Number of trees in the forest (trades training time against accuracy).
 * @param _verbose Enables verbose mode.
 * @param _maxSamples Maximum number of samples used in training; 0 means all samples are used.
 * NOTE(review): `_maxSamples` is an `int` parameter, while the `maxSamples` member is
 * documented as `size_t`. A negative argument would therefore convert to a very large
 * value — confirm callers never pass one.
 * NOTE(review): "ot" in `_num_ot_trees` / `num_ot_trees` looks like a typo for "of";
 * renaming it would break existing callers, so it is left as is.
 * NOTE(review): the leading "32" appears to be a fused listing line number, not code.
 */
32 TrainNodeMsRFParams(
int _max_decision_levels,
int _num_of_candidate_features,
unsigned int _num_of_candidate_thresholds_per_feature,
int _num_ot_trees,
bool _verbose,
int _maxSamples) :
max_decision_levels(_max_decision_levels),
num_of_candidate_features(_num_of_candidate_features),
num_of_candidate_thresholds_per_feature(_num_of_candidate_thresholds_per_feature),
num_ot_trees(_num_ot_trees),
verbose(_verbose),
maxSamples(_maxSamples) {}
/**
 * @brief Constructor.
 * @param nStates Presumably the number of states (classes) — confirm.
 * @param nFeatures Presumably the number of features per feature vector — confirm.
 * @param maxSamples Maximum number of samples to be used in training; presumably
 *        0 means all samples are used, as for TrainNodeMsRFParams::maxSamples — confirm.
 * NOTE(review): the leading "71" appears to be a fused listing line number.
 */
71 DllExport
CTrainNodeMsRF(byte nStates, word nFeatures,
size_t maxSamples);
/// @brief Resets class variables.
/// NOTE(review): the leading "80" appears to be a fused listing line number.
80 DllExport
void reset(
void);
/**
 * @brief Saves the training data. This is a const member function, so it does not
 *        modify the trainer object.
 * @param path Destination path (presumably a directory — confirm).
 * @param name Optional name; defaults to an empty string (its role in the
 *        resulting file name is not visible here).
 * @param idx Optional index; defaults to -1 (meaning not visible here — confirm).
 * NOTE(review): the leading "81" appears to be a fused listing line number.
 */
81 DllExport
void save(
const std::string &path,
const std::string &name = std::string(),
short idx = -1)
const;
/**
 * @brief Loads the training data.
 * @param path Source path (presumably the same one passed to save() — confirm).
 * @param name Optional name; defaults to an empty string.
 * @param idx Optional index; defaults to -1 (meaning not visible here — confirm).
 * NOTE(review): the leading "82" appears to be a fused listing line number.
 */
82 DllExport
void load(
const std::string &path,
const std::string &name = std::string(),
short idx = -1);
/**
 * @brief Adds a new feature vector.
 * @param featureVector The feature vector (shape and element type not visible
 *        here; presumably a single-column vector of nFeatures entries — confirm).
 * @param gt Presumably the ground-truth state (class label) for this sample — confirm.
 * NOTE(review): the leading "84" appears to be a fused listing line number.
 */
84 DllExport
void addFeatureVec(
const Mat &featureVector, byte gt);
/**
 * @brief Trains the random forest model.
 * @param doClean Defaults to false. Presumably, when true, the accumulated training
 *        samples are released after training — confirm against the implementation.
 * NOTE(review): the leading "85" appears to be a fused listing line number.
 */
85 DllExport
void train(
bool doClean =
false);
/// Random Forest classifier. It is uniquely owned by this object through
/// std::unique_ptr, and its trees use LinearFeatureResponse features and
/// HistogramAggregator statistics.
/// NOTE(review): the leading "99" appears to be a fused listing line number.
99 std::unique_ptr<sw::Forest<sw::LinearFeatureResponse, sw::HistogramAggregator>>
m_pRF;
void init(TrainNodeMsRFParams params)
virtual ~CTrainNodeMsRF(void)
struct DirectGraphicalModels::TrainNodeMsRFParams TrainNodeMsRFParams
Microsoft Research Random Forest parameters.
size_t maxSamples
Maximum number of samples to be used in training. 0 means using all the samples.
void save(const std::string &path, const std::string &name=std::string(), short idx=-1) const
Saves the training data.
void addFeatureVec(const Mat &featureVector, byte gt)
Adds a new feature vector.
void train(bool doClean=false)
Trains the random forest model.
std::unique_ptr< sw::TrainingParameters > m_pParams
CTrainNodeMsRF(byte nStates, word nFeatures, TrainNodeMsRFParams params=TRAIN_NODE_MS_RF_PARAMS_DEFAULT)
Constructor.
Microsoft Research Random Forest parameters.
void calculateNodePotentials(const Mat &featureVector, Mat &potential, Mat &mask) const
Calculates the node potential, based on the feature vector.
void loadFile(FILE *pFile)
Loads the random forest model from the file.
void reset(void)
Resets class variables.
const TrainNodeMsRFParams TRAIN_NODE_MS_RF_PARAMS_DEFAULT
void load(const std::string &path, const std::string &name=std::string(), short idx=-1)
Loads the training data.
int num_ot_trees
Number of trees in the forest (more trees improve accuracy at the cost of time).
std::unique_ptr< sw::Forest< sw::LinearFeatureResponse, sw::HistogramAggregator > > m_pRF
Random Forest classifier.
unsigned int num_of_candidate_thresholds_per_feature
Number of candidate thresholds (per feature)
std::unique_ptr< CSamplesAccumulator > m_pSamplesAcc
Samples Accumulator.
bool verbose
Verbose mode.
Abstract base class for node potential training.
void saveFile(FILE *pFile) const
Saves the random forest model to the file.
Microsoft Sherwood Random Forest training class.
TrainNodeMsRFParams(int _max_decision_levels, int _num_of_candidate_features, unsigned int _num_of_candidate_thresholds_per_feature, int _num_ot_trees, bool _verbose, int _maxSamples)
int num_of_candidate_features
Number of candidate features.
int max_decision_levels
Maximum number of decision levels.