Direct Graphical Models  v.1.7.0
InferTree.cpp
#include "InferTree.h"
#include "GraphPairwise.h"

#include <deque>
#include <vector>
3 
4 namespace DirectGraphicalModels
5 {
6 void CInferTree::calculateMessages(unsigned int)
7 {
8  byte nStates = getGraph().getNumStates();
9  size_t nNodes = getGraph().getNumNodes();
10 
11  // ====================================== Initialization ======================================
12  for (ptr_edge_t &edge: getGraphPairwise().m_vEdges) {
13  delete[] edge->msg;
14  edge->msg = NULL;
15  edge->suspend = false;
16  }
17 
18  // =================================== Computing messages ===================================
19  size_t * nFromEdges = new size_t[nNodes]; // Count number of neighbors
20  std::deque<size_t> nodeQueue;
21  for (ptr_node_t &node: getGraphPairwise().m_vNodes) {
22  nFromEdges[node->id] = node->from.size(); // number of incoming edges
23  if (nFromEdges[node->id] <= 1) nodeQueue.push_back(node->id); // Add all leafs to the queue
24  }
25 
26 
27  float *temp = new float[nStates];
28  while (!nodeQueue.empty()) {
29  //for (size_t q = 0; q < nodeQueue.size(); q++) printf("%d, ", nodeQueue[q]); printf("\n");
30 
31  size_t n = nodeQueue.front(); // n - node with one neighbour
32  nodeQueue.pop_front();
33 
34  Node *node = getGraphPairwise().m_vNodes[n].get(); // Node with one neighbour
35  size_t nToEdges = node->to.size();
36 
37  bool allSuspend = true;
38  for (size_t e_t = 0; e_t < nToEdges; e_t++) {
39  Edge *edge_to = getGraphPairwise().m_vEdges[node->to[e_t]].get();
40  if (!edge_to->suspend) {
41  allSuspend = false;
42  break;
43  }
44  }
45 
46  if (allSuspend) { // Now prepare messages for suspending edges
47  for (size_t e_t = 0; e_t < nToEdges; e_t++) {
48  Edge *edge_to = getGraphPairwise().m_vEdges[node->to[e_t]].get();
49  if (edge_to->msg) continue;
50 
51  calculateMessage(edge_to, temp, edge_to->msg);
52 
53  size_t n2 = edge_to->node2;
54  nFromEdges[n2]--;
55  if (nFromEdges[n2] <= 1) nodeQueue.push_back(n2);
56  }
57  } else { // Prepare messages for all non-suspending edges
58  for (size_t e_t = 0; e_t < nToEdges; e_t++) {
59  Edge * edge_to = getGraphPairwise().m_vEdges[node->to[e_t]].get();
60  if (edge_to->suspend) continue;
61  if (edge_to->msg) continue;
62 
63  calculateMessage(edge_to, temp, edge_to->msg);
64 
65  size_t n2 = edge_to->node2;
66  nFromEdges[n2]--;
67  if (nFromEdges[n2] <= 1) nodeQueue.push_back(n2);
68  }
69  }
70  } // while
71 
72  delete[] temp;
73  delete[] nFromEdges;
74 }
75 }
vec_size_t to
Array of edge ids, pointing to the Child vertices.
Definition: GraphPairwise.h:18
byte getNumStates(void) const
Returns number of states (classes)
Definition: Graph.h:99
bool suspend
Flag indicating whether the message calculation must be postponed (used in message-passing algorithm...
Definition: GraphPairwise.h:39
size_t node2
Second (destination) node in edge.
Definition: GraphPairwise.h:34
float * msg
Message (used in message-passing algorithms): array of nStates values of type float ...
Definition: GraphPairwise.h:36
virtual void calculateMessages(unsigned int nIt)
Calculates messages for exact inference in a tree graph.
Definition: InferTree.cpp:6
virtual size_t getNumNodes(void) const =0
Returns the number of nodes in the graph.
CGraph & getGraph(void) const
Returns the reference to the graph.
Definition: Infer.h:82
std::unique_ptr< Node > ptr_node_t
Definition: GraphPairwise.h:24
CGraphPairwise & getGraphPairwise(void) const
Returns the graph.
std::unique_ptr< Edge > ptr_edge_t
Definition: GraphPairwise.h:55
void calculateMessage(Edge *edge, float *temp, float *&dst, bool maxSum=false)
Calculates one message for the specified edge (edge).