Direct Graphical Models  v.1.7.0
TrainNodeKNN.cpp
1 #include "TrainNodeKNN.h"
2 #include "KDTree.h"
3 #include "SamplesAccumulator.h"
4 #include "mathop.h"
5 
6 namespace DirectGraphicalModels
7 {
8  // Constructor
9  CTrainNodeKNN::CTrainNodeKNN(byte nStates, word nFeatures, TrainNodeKNNParams params) : CBaseRandomModel(nStates), CTrainNode(nStates, nFeatures)
10  {
11  init(params);
12  }
13 
14  // Constructor
15  CTrainNodeKNN::CTrainNodeKNN(byte nStates, word nFeatures, size_t maxSamples) : CBaseRandomModel(nStates), CTrainNode(nStates, nFeatures)
16  {
18  params.maxSamples = maxSamples;
19  init(params);
20  }
21 
// NOTE(review): this is the body of what is presumably `void CTrainNodeKNN::init(TrainNodeKNNParams params)`
// (cf. the `init` declaration in the tooltip list). Its header (source line 22) and source line 24 are
// missing from this listing. Line 24 presumably allocates m_pSamplesAcc, which addFeatureVec()/train()
// use and the destructor deletes. Confirm against the real source.
23  {
25  m_pTree = new CKDTree(); // heap-allocates the k-d tree; deleted in the destructor
26  m_params = params;       // kept for later use (maxNeighbors and bias in calculateNodePotentials)
27  }
28 
29  // Destructor
// Destructor body; its header (source line 30) is missing from this listing.
// Frees the samples accumulator and the k-d tree, both held through raw pointers.
// NOTE(review): no copy constructor/assignment is visible here. Copying a CTrainNodeKNN would
// double-delete these pointers unless the header disables copying. Confirm.
31  {
32  delete m_pSamplesAcc;
33  delete m_pTree;
34  }
35 
// NOTE(review): body of what is presumably `void CTrainNodeKNN::reset(void)`. Its header (source line 36)
// and source line 38 are missing from this listing. Line 38 presumably resets the samples accumulator
// (m_pSamplesAcc->reset(), cf. "Resets the accumulator" in the tooltip list). Confirm.
37  {
39  m_pTree->reset(); // discard the previously built k-d tree
40  }
41 
42  void CTrainNodeKNN::save(const std::string &path, const std::string &name, short idx) const
43  {
44  std::string fileName = generateFileName(path, name.empty() ? "TrainNodeKNN" : name, idx);
45  m_pTree->save(fileName);
46  }
47 
48  void CTrainNodeKNN::load(const std::string &path, const std::string &name, short idx)
49  {
50  std::string fileName = generateFileName(path, name.empty() ? "TrainNodeKNN" : name, idx);
51  m_pTree->load(fileName);
52  }
53 
54  void CTrainNodeKNN::addFeatureVec(const Mat &featureVector, byte gt)
55  {
56  m_pSamplesAcc->addSample(featureVector, gt);
57  }
58 
59  void CTrainNodeKNN::train(bool doClean)
60  {
61 #ifdef DEBUG_PRINT_INFO
62  printf("\n");
63 #endif
64 
65  // Filling the <samples> and <classes>
66  Mat samples, classes;
67  for (byte s = 0; s < m_nStates; s++) { // states
68  int nSamples = m_pSamplesAcc->getNumSamples(s);
69 #ifdef DEBUG_PRINT_INFO
70  printf("State[%d] - %d of %d samples\n", s, nSamples, m_pSamplesAcc->getNumInputSamples(s));
71 #endif
72  samples.push_back(m_pSamplesAcc->getSamplesContainer(s));
73  classes.push_back(Mat(nSamples, 1, CV_8UC1, Scalar(s)));
74  if (doClean) m_pSamplesAcc->release(s); // free memory
75  } // s
76 
77  // Training, e.g. building the tree
78  m_pTree->build(samples, classes);
79  }
80 
81  void CTrainNodeKNN::calculateNodePotentials(const Mat &featureVector, Mat &potential, Mat &mask) const
82  {
83  auto nearestNeighbors = m_pTree->findNearestNeighbors(featureVector.t(), m_params.maxNeighbors);
84  //float minr = mathop::Euclidian<byte, float>(featureVector.t(), nearestNeighbors.front()->getKey());
85 
86  size_t n = nearestNeighbors.size();
87  for (auto node : nearestNeighbors) {
88  byte s = node->getValue();
89 
90  //float r = mathop::Euclidian<byte, float>(featureVector.t(), node->getKey());
91  //r = 1 + r - minr;
92  //potential.at<float>(s, 0) += 0.1f / (r * r);
93 
94  potential.at<float>(s, 0) += 1.0f;
95  }
96  if (n) potential /= static_cast<double>(n);
97  potential += m_params.bias;
98  }
99 }
void load(const std::string &path, const std::string &name=std::string(), short idx=-1)
Loads the training data.
void save(const std::string &path, const std::string &name=std::string(), short idx=-1) const
Saves the training data.
std::string generateFileName(const std::string &path, const std::string &name, short idx) const
Generates the name of the data file used to store the random model parameters.
int getNumSamples(byte state) const
Returns the number of samples stored in the container for the state (class) `state`.
size_t maxNeighbors
Max number of neighbors to be used for calculating potentials.
Definition: TrainNodeKNN.h:15
void addFeatureVec(const Mat &featureVector, byte gt)
Adds new feature vector.
Base abstract class for random model training.
void release(byte state)
Releases the memory of the container for the state (class) `state`.
size_t maxSamples
Maximum number of samples to be used in training. 0 means using all the samples.
Definition: TrainNodeKNN.h:16
Class implementing k-D Tree data structure.
Definition: KDTree.h:17
void build(Mat &keys, Mat &values)
Builds a k-d tree on keys with corresponding values.
Definition: KDTree.cpp:90
void addSample(const Mat &featureVector, byte state)
Adds new sample to the accumulator.
void reset(void)
Resets the tree.
Definition: KDTree.h:38
void reset(void)
Resets the accumulator.
k-Nearest Neighbors parameters
Definition: TrainNodeKNN.h:13
void train(bool doClean=false)
Random model training.
std::vector< std::shared_ptr< const CKDNode > > findNearestNeighbors(const Mat &key, size_t maxNeighbors) const
Finds up to maxNeighbors nearest neighbors to the key.
Definition: KDTree.cpp:117
void calculateNodePotentials(const Mat &featureVector, Mat &potential, Mat &mask) const
Calculates the node potential, based on the feature vector.
Samples accumulator abstract class.
float bias
Regularization CRF parameter: bias is added to all potential values.
Definition: TrainNodeKNN.h:14
void load(const std::string &fileName)
Loads a tree from the file.
Definition: KDTree.cpp:80
Base abstract class for node potentials training.
Definition: TrainNode.h:47
CTrainNodeKNN(byte nStates, word nFeatures, TrainNodeKNNParams params=TRAIN_NODE_KNN_PARAMS_DEFAULT)
Constructor.
Definition: TrainNodeKNN.cpp:9
void reset(void)
Resets class variables.
void save(const std::string &fileName) const
Saves the tree into a file.
Definition: KDTree.cpp:65
CSamplesAccumulator * m_pSamplesAcc
Samples Accumulator.
Definition: TrainNodeKNN.h:75
Mat getSamplesContainer(byte state) const
Returns the samples container for the state (class) `state`.
const TrainNodeKNNParams TRAIN_NODE_KNN_PARAMS_DEFAULT
Definition: TrainNodeKNN.h:22
int getNumInputSamples(byte state) const
Returns the number of input samples in the container for the state (class) `state`.
void init(TrainNodeKNNParams params)
byte m_nStates
The number of states (classes)