Direct Graphical Models  v.1.7.0
TrainLink.h
1 // Base abstract class for training random model links (inter-layer edges)
2 // Written by Sergey G. Kosov in 2016 for Project X
3 #pragma once
4 
5 #include "ITrain.h"
6 
7 namespace DirectGraphicalModels
8 {
9  // ============================= Link Train Class =============================
17  class CTrainLink : public ITrain
18  {
19  public:
26  DllExport CTrainLink(byte nStatesBase, byte nStatesOccl, word nFeatures)
27  : CBaseRandomModel(nStatesBase * nStatesOccl)
28  , ITrain(nStatesBase * nStatesOccl, nFeatures)
29  , m_nStatesBase(nStatesBase)
30  , m_nStatesOccl(nStatesOccl)
31  {}
32  DllExport virtual ~CTrainLink(void) = default;
33 
34 
42  DllExport void addFeatureVec(const Mat &featureVectors, const Mat &gtb, const Mat &gto);
50  DllExport void addFeatureVec(const vec_mat_t &featureVectors, const Mat &gtb, const Mat &gto);
59  DllExport virtual void addFeatureVec(const Mat &featureVector, byte gtb, byte gto) = 0;
60  DllExport virtual void train(bool doClean = false) {}
69  DllExport Mat getLinkPotentials(const Mat &featureVector, float weight = 1.0f) const;
70 
71 
72  protected:
83  DllExport virtual Mat calculateLinkPotentials(const Mat &featureVector) const = 0;
84 
85 
86  protected:
89  };
90 }
void addFeatureVec(const Mat &featureVectors, const Mat &gtb, const Mat &gto)
Adds a block of new feature vectors.
Definition: TrainLink.cpp:6
byte m_nStatesBase
Number of states (classes) at the base layer of ML-CRF.
Definition: TrainLink.h:87
Base abstract class for random model training.
Base abstract class for link (inter-layer edge) potentials training.
Definition: TrainLink.h:17
CTrainLink(byte nStatesBase, byte nStatesOccl, word nFeatures)
Constructor.
Definition: TrainLink.h:26
Mat getLinkPotentials(const Mat &featureVector, float weight=1.0f) const
Returns the link potential, based on the feature vector.
Definition: TrainLink.cpp:18
virtual Mat calculateLinkPotentials(const Mat &featureVector) const =0
Calculates the link potential, based on the feature vector.
virtual void train(bool doClean=false)
Random model training.
Definition: TrainLink.h:60
byte m_nStatesOccl
Number of states (classes) at the occlusion layers of ML-CRF.
Definition: TrainLink.h:88
Interface class for random model training.
Definition: ITrain.h:15
virtual ~CTrainLink(void)=default