Direct Graphical Models v.1.7.0
MessagePassing.h
1 // Base abstract class for message passing algorithms used for exact and approximate inference
2 // Written by Sergey G. Kosov in 2015 for Project X
3 #pragma once
4 
5 #include "Infer.h"
6 #include "GraphPairwise.h"
7 
8 namespace DirectGraphicalModels
9 {
10  struct Edge;
11 
12  // ==================== Message Passing Base Abstract Class ==================
18  class CMessagePassing : public CInfer
19  {
20  public:
25  DllExport CMessagePassing(CGraphPairwise &graph) : CInfer(graph) {}
26  DllExport virtual ~CMessagePassing(void) {}
27 
28  DllExport virtual void infer(unsigned int nIt = 1);
29 
30 
31  protected:
36  CGraphPairwise & getGraphPairwise(void) const { return dynamic_cast<CGraphPairwise &>(getGraph()); }
42  virtual void calculateMessages(unsigned int nIt) = 0;
52  void calculateMessage(Edge *edge, float *temp, float *&dst, bool maxSum = false);
56  void createMessages(void);
60  void deleteMessages(void);
64  void swapMessages(void);
75  static float MatMul(const Mat &M, const float *v, float *&dst, bool maxSum = false);
76  };
77 }
void swapMessages(void)
Swaps Edge::msg and Edge::msg_temp for all edges in the graph.
Abstract base class for message-passing inference algorithms.
virtual void calculateMessages(unsigned int nIt)=0
Calculates the messages associated with the edges of the corresponding graphical model.
CGraph & getGraph(void) const
Returns the reference to the graph.
Definition: Infer.h:82
static float MatMul(const Mat &M, const float *v, float *&dst, bool maxSum=false)
Specific matrix multiplication.
Base abstract class for random model inference.
Definition: Infer.h:19
CGraphPairwise & getGraphPairwise(void) const
Returns the graph.
void deleteMessages(void)
Frees the memory of the Edge::msg and Edge::msg_temp containers of all edges in the graph.
CMessagePassing(CGraphPairwise &graph)
Constructor.
void createMessages(void)
Allocates memory for Edge::msg and Edge::msg_temp containers for all edges in the graph...
virtual void infer(unsigned int nIt=1)
Inference.
void calculateMessage(Edge *edge, float *temp, float *&dst, bool maxSum=false)
Calculates one message for the specified edge (parameter `edge`).