Direct Graphical Models  v.1.7.0
TrainNodeCvANN.cpp
#include "TrainNodeCvANN.h"
#include "SamplesAccumulator.h"

namespace DirectGraphicalModels
{
    // Constructor
    CTrainNodeCvANN::CTrainNodeCvANN(byte nStates, word nFeatures, TrainNodeCvANNParams params) : CBaseRandomModel(nStates), CTrainNode(nStates, nFeatures)
    {
        init(params);
    }

    // Constructor
    CTrainNodeCvANN::CTrainNodeCvANN(byte nStates, word nFeatures, size_t maxSamples) : CBaseRandomModel(nStates), CTrainNode(nStates, nFeatures)
    {
        TrainNodeCvANNParams params = TRAIN_NODE_CV_ANN_PARAMS_DEFAULT;
        params.maxSamples = maxSamples;
        init(params);
    }

    void CTrainNodeCvANN::init(TrainNodeCvANNParams params)
    {
        m_pSamplesAcc = new CSamplesAccumulator(m_nStates, params.maxSamples);

        if (params.numLayers < 2) params.numLayers = 2;
        std::vector<int> vLayers(params.numLayers);
        vLayers[0] = getNumFeatures();
        for (int i = 1; i < params.numLayers - 1; i++)
            vLayers[i] = m_nStates * 1 << (params.numLayers - i);
        vLayers[params.numLayers - 1] = m_nStates;

        m_pANN = ml::ANN_MLP::create();
        m_pANN->setLayerSizes(vLayers);
        m_pANN->setActivationFunction(ml::ANN_MLP::SIGMOID_SYM, 0.0, 0.0);
        m_pANN->setTermCriteria(TermCriteria(params.term_criteria_type, params.maxCount, params.epsilon));
        m_pANN->setTrainMethod(ml::ANN_MLP::BACKPROP, params.weightScale, params.momentumScale);
    }
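
    // For illustration: with nStates = 3, nFeatures = 16 and numLayers = 4, init() above
    // produces the layer sizes [16, 24, 12, 3]. The input layer is nFeatures wide, hidden
    // layer i holds m_nStates * 2^(numLayers - i) neurons (3 * 2^3 = 24, then 3 * 2^2 = 12),
    // and the output layer has one neuron per state (class).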

    // Destructor
    CTrainNodeCvANN::~CTrainNodeCvANN(void)
    {
        delete m_pSamplesAcc;
    }

    void CTrainNodeCvANN::reset(void)
    {
        m_pSamplesAcc->reset();
        m_pANN->clear();
    }
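
    // Note: reset() discards both the accumulated training samples and the trained
    // network state, so the model must be re-trained before it can predict again.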

    void CTrainNodeCvANN::save(const std::string &path, const std::string &name, short idx) const
    {
        std::string fileName = generateFileName(path, name.empty() ? "TrainNodeCvANN" : name, idx);
        m_pANN->save(fileName.c_str());
    }

    void CTrainNodeCvANN::load(const std::string &path, const std::string &name, short idx)
    {
        std::string fileName = generateFileName(path, name.empty() ? "TrainNodeCvANN" : name, idx);
        m_pANN = Algorithm::load<ml::ANN_MLP>(fileName.c_str());
    }

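    // Note: save() serializes the network through OpenCV's FileStorage (XML/YAML), while
    // load() re-creates it via the static Algorithm::load<>(), replacing m_pANN with a
    // new instance read from file; the topology configured in init() is overwritten.
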
    void CTrainNodeCvANN::addFeatureVec(const Mat &featureVector, byte gt)
    {
        m_pSamplesAcc->addSample(featureVector, gt);
    }

    void CTrainNodeCvANN::train(bool doClean)
    {
#ifdef DEBUG_PRINT_INFO
        printf("\n");
#endif
        // Filling the <samples> and <classes>
        Mat samples, classes;
        for (byte s = 0; s < m_nStates; s++) {          // states
            int nSamples = m_pSamplesAcc->getNumSamples(s);
#ifdef DEBUG_PRINT_INFO
            printf("State[%d] - %d of %d samples\n", s, nSamples, m_pSamplesAcc->getNumInputSamples(s));
#endif
            samples.push_back(m_pSamplesAcc->getSamplesContainer(s));
            Mat classes_s(nSamples, m_nStates, CV_32FC1, Scalar(0.0f));
            classes_s.col(s).setTo(1.0f);
            classes.push_back(classes_s);
            if (doClean) m_pSamplesAcc->release(s);     // free memory
        } // s
        samples.convertTo(samples, CV_32FC1);

        // Training
        try {
            m_pANN->train(samples, ml::ROW_SAMPLE, classes);
        }
        catch (std::exception &e) {
            printf("EXCEPTION: %s\n", e.what());
            getchar();
            exit(-1);
        }
    }
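
    // For illustration: train() one-hot encodes the ground truth, so every sample of
    // state s receives an nStates-wide target row with 1.0f in column s and 0.0f
    // elsewhere; the network is then trained on row-major samples (ml::ROW_SAMPLE)
    // against these nStates-wide target rows.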

    void CTrainNodeCvANN::calculateNodePotentials(const Mat &featureVector, Mat &potential, Mat &mask) const
    {
        Mat fv;
        featureVector.convertTo(fv, CV_32FC1);
        //float res = m_pANN->predict(fv.t());
        //byte s = static_cast<byte>(res);
        //potential.at<float>(s, 0) = 1.0f;
        //potential += 0.1f;

        m_pANN->predict(fv.t(), potential);
        for (float &pot : static_cast<Mat_<float>>(potential))  // clip negative responses of the symmetric sigmoid
            if (pot < 0) pot = 0;
        potential = potential.t();                              // return an nStates x 1 column vector
    }
}
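
A minimal usage sketch, under assumptions: the class is consumed through its public interface, and getNodePotentials() is taken to be the public accessor inherited from CTrainNode that wraps the protected calculateNodePotentials() above. The feature values below are made-up placeholders.

#include "TrainNodeCvANN.h"

using namespace DirectGraphicalModels;

int main()
{
    const byte nStates   = 3;                       // number of classes
    const word nFeatures = 2;                       // feature-vector length

    CTrainNodeCvANN trainer(nStates, nFeatures);    // uses TRAIN_NODE_CV_ANN_PARAMS_DEFAULT

    // Accumulate ground-truth samples: one nFeatures x 1 column vector per sample
    Mat fv(nFeatures, 1, CV_8UC1);
    fv.at<byte>(0, 0) = 10;                         // placeholder feature values
    fv.at<byte>(1, 0) = 250;
    trainer.addFeatureVec(fv, 0);                   // a sample of class 0
    // ... add many more samples covering every class ...

    trainer.train();

    // Query the potential of a feature vector: an nStates x 1 column of scores
    Mat pot = trainer.getNodePotentials(fv);        // assumed CTrainNode accessor
    return 0;
}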