Direct Graphical Models  v.1.7.0
InferChain.cpp
1 #include "InferChain.h"
2 #include "GraphPairwise.h"
3 
4 namespace DirectGraphicalModels
5 {
6  void CInferChain::calculateMessages(unsigned int)
7  {
8  float *temp = new float[getGraph().getNumNodes()];
9 
10  // Forward pass
11  std::for_each(getGraphPairwise().m_vNodes.begin(), getGraphPairwise().m_vNodes.end() - 1, [&](ptr_node_t &node) {
12  size_t nToEdges = node->to.size();
13  for (size_t e_t = 0; e_t < nToEdges; e_t++) { // outgoing edges
14  Edge *edge_to = getGraphPairwise().m_vEdges[node->to[e_t]].get(); // current outgoing edge
15  if (edge_to->node2 == node->id + 1)
16  calculateMessage(edge_to, temp, edge_to->msg);
17  } // e_t;
18  });
19 
20  // Backward pass
21  std::for_each(getGraphPairwise().m_vNodes.rbegin(), getGraphPairwise().m_vNodes.rend() - 1, [&](ptr_node_t &node) {
22  size_t nToEdges = node->to.size();
23  for (size_t e_t = 0; e_t < nToEdges; e_t++) { // outgoing edges
24  Edge *edge_to = getGraphPairwise().m_vEdges[node->to[e_t]].get(); // current outgoing edge
25  if (edge_to->node2 == node->id - 1)
26  calculateMessage(edge_to, temp, edge_to->msg);
27  } // e_t;
28  });
29 
30  delete[] temp;
31  }
32 }
size_t node2
Second (destination) node in edge.
Definition: GraphPairwise.h:34
float * msg
Message (used in message-passing algorithms): Mat(size: nStates x 1; type: CV_32FC1) ...
Definition: GraphPairwise.h:36
virtual size_t getNumNodes(void) const =0
Returns the number of nodes in the graph.
CGraph & getGraph(void) const
Returns the reference to the graph.
Definition: Infer.h:82
std::unique_ptr< Node > ptr_node_t
Definition: GraphPairwise.h:24
CGraphPairwise & getGraphPairwise(void) const
Returns the graph.
virtual void calculateMessages(unsigned int nIt)
Calculates messages for exact inference in a chain graph.
Definition: InferChain.cpp:6
void calculateMessage(Edge *edge, float *temp, float *&dst, bool maxSum=false)
Calculates one message for the specified edge `edge`.