Direct Graphical Models  v.1.7.0
TrainNodeKNN.h
1 // k-Nearest Neighbors training class interface
2 // Written by Sergey G. Kosov in 2017 for Project X
3 #pragma once
4 
5 #include "TrainNode.h"
6 
7 namespace DirectGraphicalModels
8 {
9  class CKDTree;
10  class CSamplesAccumulator;
11 
13  typedef struct TrainNodeKNNParams {
14  float bias;
15  size_t maxNeighbors;
16  size_t maxSamples;
17 
19  TrainNodeKNNParams(float _bias, size_t _maxNeighbors, size_t _maxSamples) : bias(_bias), maxNeighbors(_maxNeighbors), maxSamples(_maxSamples) {}
21 
23  0.1f, // Regularization CRF parameter: bias is added to all potential values
24  100, // Max number of neighbors to be used for calculating potentials
25  0 // Maximum number of samples to be used in training. 0 means using all the samples
26  );
27 
28  // ====================== k-Nearest Neighbors Train Class =====================
38  class CTrainNodeKNN : public CTrainNode
39  {
40  public:
47  DllExport CTrainNodeKNN(byte nStates, word nFeatures, TrainNodeKNNParams params = TRAIN_NODE_KNN_PARAMS_DEFAULT);
56  DllExport CTrainNodeKNN(byte nStates, word nFeatures, size_t maxSamples);
57  DllExport ~CTrainNodeKNN(void);
58 
59  DllExport void reset(void);
60  DllExport void save(const std::string &path, const std::string &name = std::string(), short idx = -1) const;
61  DllExport void load(const std::string &path, const std::string &name = std::string(), short idx = -1);
62 
63  DllExport void addFeatureVec(const Mat &featureVector, byte gt);
64  DllExport void train(bool doClean = false);
65 
66 
67  protected:
68  DllExport void saveFile(FILE *pFile) const {}
69  DllExport void loadFile(FILE *pFile) {}
70  DllExport void calculateNodePotentials(const Mat &featureVector, Mat &potential, Mat &mask) const;
71 
72 
73  protected:
76 
77 
78  private:
79  void init(TrainNodeKNNParams params); // This function is called by both constructors
80 
81 
82  private:
84  };
85 }
void load(const std::string &path, const std::string &name=std::string(), short idx=-1)
Loads the training data.
TrainNodeKNNParams(float _bias, size_t _maxNeighbors, size_t _maxSamples)
Definition: TrainNodeKNN.h:19
void save(const std::string &path, const std::string &name=std::string(), short idx=-1) const
Saves the training data.
size_t maxNeighbors
Max number of neighbors to be used for calculating potentials.
Definition: TrainNodeKNN.h:15
k-Nearest Neighbors training class.
Definition: TrainNodeKNN.h:38
void addFeatureVec(const Mat &featureVector, byte gt)
Adds new feature vector.
size_t maxSamples
Maximum number of samples to be used in training. 0 means using all the samples.
Definition: TrainNodeKNN.h:16
Class implementing k-D Tree data structure.
Definition: KDTree.h:17
k-Nearest Neighbors parameters
Definition: TrainNodeKNN.h:13
void train(bool doClean=false)
Trains the k-Nearest Neighbors model on the feature vectors added with addFeatureVec().
void calculateNodePotentials(const Mat &featureVector, Mat &potential, Mat &mask) const
Calculates the node potential, based on the feature vector.
void loadFile(FILE *pFile)
Loads the model from the file. In this class the inline body is empty, so nothing is read here.
Definition: TrainNodeKNN.h:69
Samples accumulator abstract class.
struct DirectGraphicalModels::TrainNodeKNNParams TrainNodeKNNParams
k-Nearest Neighbors parameters
float bias
Regularization CRF parameter: bias is added to all potential values.
Definition: TrainNodeKNN.h:14
void saveFile(FILE *pFile) const
Saves the model into the file. In this class the inline body is empty, so nothing is written here.
Definition: TrainNodeKNN.h:68
Base abstract class for node potentials training.
Definition: TrainNode.h:47
CTrainNodeKNN(byte nStates, word nFeatures, TrainNodeKNNParams params=TRAIN_NODE_KNN_PARAMS_DEFAULT)
Constructor.
Definition: TrainNodeKNN.cpp:9
void reset(void)
Resets class variables.
CSamplesAccumulator * m_pSamplesAcc
Samples Accumulator.
Definition: TrainNodeKNN.h:75
const TrainNodeKNNParams TRAIN_NODE_KNN_PARAMS_DEFAULT
Definition: TrainNodeKNN.h:22
void init(TrainNodeKNNParams params)