Direct Graphical Models  v.1.7.0
TrainLinkNested.h
1 // Nested training model for pairwise link potentials
2 // Written by Sergey G. Kosov in 2016 for Project X
3 #pragma once
4 
5 #include "TrainLink.h"
6 
7 namespace DirectGraphicalModels
8 {
9  class CPriorNode;
10  class CTrainNode;
11 
12  // ============================= Nested Link Train Class =============================
21  template<class Trainer> class CTrainLinkNested : public CTrainLink
22  {
23  public:
31  CTrainLinkNested(byte nStatesBase, byte nStatesOccl, word nFeatures)
32  : CBaseRandomModel(nStatesBase * nStatesOccl)
33  , CTrainLink(nStatesBase, nStatesOccl, nFeatures)
34  {
35  word nStates = static_cast<word>(nStatesBase) * static_cast<word>(nStatesOccl);
36  DGM_ASSERT(nStates < 256);
37  m_pPrior = new CPriorNode(static_cast<byte>(nStates));
38  m_pTrainer = new Trainer(static_cast<byte>(nStates), nFeatures);
39  }
48  template<class TrainerParams> CTrainLinkNested(byte nStatesBase, byte nStatesOccl, word nFeatures, TrainerParams params)
49  : CTrainLink(nStatesBase, nStatesOccl, nFeatures)
50  , CBaseRandomModel(nStatesBase * nStatesOccl)
51  {
52  word nStates = static_cast<word>(nStatesBase) * static_cast<word>(nStatesOccl);
53  DGM_ASSERT(nStates < 256);
54  m_pPrior = new CPriorNode(static_cast<byte>(nStates));
55  m_pTrainer = new Trainer(static_cast<byte>(nStates), nFeatures, params);
56  }
57 
58  virtual ~CTrainLinkNested(void)
59  {
60  delete m_pPrior;
61  delete m_pTrainer;
62  }
63 
64  virtual void reset(void) { m_pPrior->reset(); m_pTrainer->reset(); }
65  virtual void save(const std::string &path, const std::string &name = std::string(), short idx = -1) const { m_pTrainer->save(path, name.empty() ? "CTrainLinkNested" : name, idx); }
66  virtual void load(const std::string &path, const std::string &name = std::string(), short idx = -1) { m_pTrainer->load(path, name.empty() ? "CTrainLinkNested" : name, idx); }
67 
68  virtual void addFeatureVec(const Mat &featureVector, byte gtb, byte gto)
69  {
70  byte gt = gtb + m_nStatesBase * gto;
72  m_pTrainer->addFeatureVec(featureVector, gt);
73  }
74 
75  virtual void train(bool doClean = false)
76  {
77  // Fill holes in trainig
78  Mat fv(getNumFeatures(), 1, CV_8UC1, Scalar(0));
79  Mat priors = m_pPrior->getPrior();
80  for (byte i = 0; i < priors.rows; i++)
81  if (priors.at<float>(i, 0) == 0)
82  m_pTrainer->addFeatureVec(fv, i);
83 
84 
85  m_pTrainer->train(doClean);
86  }
87 
88 
89  protected:
90  DllExport virtual void saveFile(FILE *pFile) const {}
91  DllExport virtual void loadFile(FILE *pFile) {}
100  virtual Mat calculateLinkPotentials(const Mat &featureVector) const
101  {
102  Mat pot = m_pTrainer->getNodePotentials(featureVector);
103  //pot = m_pPrior->getPrior(100);
104 
105  DGM_ASSERT_MSG(pot.rows == m_nStatesBase * m_nStatesOccl, "The length of the node potentinal vector = %d, but must be %d", pot.rows, m_nStatesBase * m_nStatesOccl);
106 
107  Mat res(m_nStatesBase + m_nStatesOccl, m_nStatesBase + m_nStatesOccl, CV_32FC1, Scalar(0));
108 
109  for (byte gto = 0; gto < m_nStatesOccl; gto++) {
110  float * pRes = res.ptr<float>(m_nStatesBase + gto);
111  for (byte gtb = 0; gtb < m_nStatesBase; gtb++) {
112  byte gt = gtb + m_nStatesBase * gto;
113  pRes[gtb] = pot.at<float>(gt, 0);
114  }
115  }
116 
117  return res;
118  }
119 
120  // CPriorNode * getPrior(void) const { return m_pPrior; }
121  // CTrainNode * getTrainer(void) const { return m_pTrainer; }
122 
123  private:
126  };
127 }
CTrainLinkNested(byte nStatesBase, byte nStatesOccl, word nFeatures, TrainerParams params)
Constructor.
virtual Mat calculateLinkPotentials(const Mat &featureVector) const
Returns the data-dependent link (inter-layer edge) potentials.
byte m_nStatesBase
Number of states (classes) at the base layer of ML-CRF.
Definition: TrainLink.h:87
word getNumFeatures(void) const
Returns number of features.
Definition: ITrain.h:37
virtual void save(const std::string &path, const std::string &name=std::string(), short idx=-1) const
Saves the training data.
void addNodeGroundTruth(const Mat &gt)
Adds ground truth values to the co-occurrence histogram vector.
Definition: PriorNode.cpp:7
virtual void reset(void)
Resets class variables.
virtual void loadFile(FILE *pFile)
Loads the random model from the file.
Base abstract class for random model training.
virtual void load(const std::string &path, const std::string &name=std::string(), short idx=-1)
Loads the training data.
CTrainLinkNested(byte nStatesBase, byte nStatesOccl, word nFeatures)
Constructor.
CTrainNode * m_pTrainer
Node trainer
Mat getNodePotentials(const Mat &featureVectors, const Mat &weights=Mat(), float Z=0.0f) const
Returns a block of node potentials, based on the block of feature vector.
Definition: TrainNode.cpp:54
virtual void load(const std::string &path, const std::string &name=std::string(), short idx=-1)
Loads the training data.
void reset(void)
Resets class variables.
Definition: Prior.cpp:19
Base abstract class for link (inter-layer edge) potentials training.
Definition: TrainLink.h:17
Node prior probability estimation class
Definition: PriorNode.h:14
Mat getPrior(float weight=1.0f) const
Returns the prior probabilities.
Definition: Prior.cpp:24
virtual void save(const std::string &path, const std::string &name=std::string(), short idx=-1) const
Saves the training data.
Base abstract class for node potentials training.
Definition: TrainNode.h:47
virtual void reset(void)=0
Resets class variables.
CPriorNode * m_pPrior
Node prior probability
virtual void saveFile(FILE *pFile) const
Saves the random model into the file.
virtual void addFeatureVec(const Mat &featureVector, byte gt)=0
Adds new feature vector.
Nested link (inter-layer edge) training class.
virtual void train(bool doClean=false)
Random model training.
byte m_nStatesOccl
Number of states (classes) at the occlusion layers of ML-CRF.
Definition: TrainLink.h:88
virtual void addFeatureVec(const Mat &featureVector, byte gtb, byte gto)
Adds a feature vector.
virtual void train(bool doClean=false)
Random model training.
Definition: TrainNode.h:92