Direct Graphical Models  v.1.7.0
MessagePassing.cpp
1 #include "MessagePassing.h"
2 #include "GraphPairwise.h"
3 #include "macroses.h"
4 
5 namespace DirectGraphicalModels
6 {
7 void CMessagePassing::infer(unsigned int nIt)
8 {
9  const byte nStates = getGraph().getNumStates();
10 
11  // ====================================== Initialization ======================================
12  createMessages();
13 #ifdef ENABLE_PPL
14  concurrency::parallel_for_each(getGraphPairwise().m_vEdges.begin(), getGraphPairwise().m_vEdges.end(), [nStates](ptr_edge_t &edge) {
15 #else
16  std::for_each(getGraphPairwise().m_vEdges.begin(), getGraphPairwise().m_vEdges.end(), [nStates](ptr_edge_t &edge) {
17 #endif
18  std::fill(edge->msg, edge->msg + nStates, 1.0f / nStates); // msg[] = 1 / nStates;
19  std::fill(edge->msg_temp, edge->msg_temp + nStates, 1.0f / nStates); // msg_temp[] = 1 / nStates;
20  });
21 
22  // =================================== Calculating messages ==================================
23  calculateMessages(nIt);
24 
25  // =================================== Calculating beliefs ===================================
26 #ifdef ENABLE_PPL
27  concurrency::parallel_for_each(getGraphPairwise().m_vNodes.begin(), getGraphPairwise().m_vNodes.end(), [&, nStates](ptr_node_t &node) {
28 #else
29  std::for_each(getGraphPairwise().m_vNodes.begin(), getGraphPairwise().m_vNodes.end(), [&,nStates](ptr_node_t &node) {
30 #endif
31  size_t nFromEdges = node->from.size();
32  // Don't understand the normalization step, replaced with another version.
33  //for (size_t e_f = 0; e_f < nFromEdges; e_f++) {
34  // Edge *edge_from = m_pGraph->m_vEdges[node->from[e_f]].get(); // current incoming edge
35  // float SUM_pot = 0;
36 
37  // float epsilon = FLT_EPSILON;
38  // for (byte s = 0; s < nStates; s++) { // states
39  // SUM_pot += node->Pot.at<float>(s, 0);
40  // // node.Pot.at<float>(s,0) *= edge_from->msg[s];
41  // node->Pot.at<float>(s, 0) = (epsilon + node->Pot.at<float>(s, 0)) * (epsilon + edge_from->msg[s]); // Soft multiplication
42  // } //s
43  //
44  // // Normalization
45  // float SUM_new_pot = 0;
46  // for (byte s = 0; s < nStates; s++) // states
47  // SUM_new_pot += node->Pot.at<float>(s, 0);
48  // for (byte s = 0; s < nStates; s++) { // states
49  // node->Pot.at<float>(s, 0) *= SUM_pot / SUM_new_pot;
50  // //node->Pot.at<float>(s, 0) /= SUM_new_pot;
51  // DGM_ASSERT_MSG(!std::isnan(node->Pot.at<float>(s, 0)), "The lower precision boundary for the potential of the node %zu is reached.\n \
52  // SUM_pot = %f\nSUM_new_pot = %f\n", node->id, SUM_pot, SUM_new_pot);
53 // }
54 //} // e_f
55 
56 
57  for (size_t e_f = 0; e_f < nFromEdges; e_f++) {
58  Edge *edge_from = getGraphPairwise().m_vEdges[node->from[e_f]].get(); // current incoming edge
59 
60  float epsilon = FLT_EPSILON;
61  for (byte s = 0; s < nStates; s++) { // states
62  // node.Pot.at<float>(s,0) *= edge_from->msg[s];
63  node->Pot.at<float>(s, 0) = (epsilon + node->Pot.at<float>(s, 0)) * (epsilon + edge_from->msg[s]); // Soft multiplication
64  } //s
65  } // e_f
66  // Normalization
67  float SUM_pot = 0;
68  for (byte s = 0; s < nStates; s++) // states
69  SUM_pot += node->Pot.at<float>(s, 0);
70  for (byte s = 0; s < nStates; s++) { // states
71  node->Pot.at<float>(s, 0) /= SUM_pot;
72  //node->Pot.at<float>(s, 0) /= SUM_new_pot;
73  DGM_ASSERT_MSG(!std::isnan(node->Pot.at<float>(s, 0)), "The lower precision boundary for the potential of the node %zu is reached.\n \
74  SUM_pot = %f\n", node->id, SUM_pot);
75  }
76  });
77 
79 }
80 
81 // dst: usually edge_to->msg or edge_to->msg_temp
82 void CMessagePassing::calculateMessage(Edge *edge_to, float *temp, float *&dst, bool maxSum)
83 {
84  byte s; // state indexes
85  Node * node = getGraphPairwise().m_vNodes[edge_to->node1].get(); // source node
86  size_t nFromEdges = node->from.size(); // number of incoming eges
87  const byte nStates = getGraph().getNumStates(); // number of states
88 
89  // Compute temp = product of all incoming msgs except e_t
90  for (s = 0; s < nStates; s++) temp[s] = node->Pot.at<float>(s, 0); // temp = node.Pot
91 
92  for (size_t e_f = 0; e_f < nFromEdges; e_f++) { // incoming edges
93  Edge *edge_from = getGraphPairwise().m_vEdges[node->from[e_f]].get(); // current incoming edge
94  if (edge_from->node1 != edge_to->node2)
95  for (s = 0; s < nStates; s++)
96  temp[s] *= edge_from->msg[s]; // temp = temp * msg
97  else
98  edge_from->suspend = true;
99  } // e_f
100 
101  // Compute new message: new_msg = (edge_to.Pot^2)^t x temp
102  float Z = MatMul(edge_to->Pot, temp, dst, maxSum);
103 
104  // Normalization and setting new values
105  if (Z > FLT_EPSILON)
106  for (s = 0; s < nStates; s++)
107  dst[s] /= Z;
108  else
109  for (s = 0; s < nStates; s++)
110  dst[s] = 1.0f / nStates;
111 }
112 
114 {
115  const byte nStates = getGraph().getNumStates();
116 #ifdef ENABLE_PPL
117  concurrency::parallel_for_each(getGraphPairwise().m_vEdges.begin(), getGraphPairwise().m_vEdges.end(), [nStates](ptr_edge_t &edge) {
118 #else
119  std::for_each(getGraphPairwise().m_vEdges.begin(), getGraphPairwise().m_vEdges.end(), [nStates](ptr_edge_t &edge) {
120 #endif
121  if (!edge->msg) edge->msg = new float[nStates];
122  DGM_ASSERT_MSG(edge->msg, "Out of Memory");
123 
124  if (!edge->msg_temp) edge->msg_temp = new float[nStates];
125  DGM_ASSERT_MSG(edge->msg_temp, "Out of Memory");
126  });
127 }
128 
130 {
131 #ifdef ENABLE_PPL
132  concurrency::parallel_for_each(getGraphPairwise().m_vEdges.begin(), getGraphPairwise().m_vEdges.end(), [](ptr_edge_t &edge) {
133 #else
134  std::for_each(getGraphPairwise().m_vEdges.begin(), getGraphPairwise().m_vEdges.end(), [](ptr_edge_t &edge) {
135 #endif
136  if (edge->msg) {
137  delete[] edge->msg;
138  edge->msg = NULL;
139  }
140  if (edge->msg_temp) {
141  delete[] edge->msg_temp;
142  edge->msg_temp = NULL;
143  }
144  });
145 }
146 
148 {
149 #ifdef ENABLE_PPL
150  concurrency::parallel_for_each(getGraphPairwise().m_vEdges.begin(), getGraphPairwise().m_vEdges.end(), [](ptr_edge_t &edge) { edge->msg_swap(); });
151 #else
152  std::for_each(getGraphPairwise().m_vEdges.begin(), getGraphPairwise().m_vEdges.end(), [](ptr_edge_t &edge) { edge->msg_swap(); });
153 #endif
154 }
155 
156 // dst = (M * M)^T x v
157 float CMessagePassing::MatMul(const Mat &M, const float *v, float *&dst, bool maxSum)
158 {
159  float res = 0;
160  if (!dst) dst = new float[M.cols];
161  for (int x = 0; x < M.cols; x++) {
162  float sum = 0;
163  for (int y = 0; y < M.rows; y++) {
164  float m = M.at<float>(y, x);
165  float prod = v[y] * m * m;
166  if (maxSum) { if (prod > sum) sum = prod; }
167  else sum += prod;
168  } // y
169  dst[x] = sum;
170  res += sum;
171  } // x
172  return res;
173 }
174 }
Mat Pot
Node potentials: Mat(size: nStates x 1; type: CV_32FC1)
Definition: GraphPairwise.h:16
byte getNumStates(void) const
Returns number of states (classes)
Definition: Graph.h:99
void swapMessages(void)
Swaps Edge::msg and Edge::msg_temp for all edges in the graph.
bool suspend
Flag, indicating whether the message calculation must be postponed (used in message-passing algorithm...
Definition: GraphPairwise.h:39
size_t node2
Second (destination) node in edge.
Definition: GraphPairwise.h:34
float * msg
Message (used in message-passing algorithms): Mat(size: nStates x 1; type: CV_32FC1) ...
Definition: GraphPairwise.h:36
virtual void calculateMessages(unsigned int nIt)=0
Calculates messages, associated with the edges of corresponding graphical model.
size_t node1
First (source) node in edge.
Definition: GraphPairwise.h:33
CGraph & getGraph(void) const
Returns the reference to the graph.
Definition: Infer.h:82
static float MatMul(const Mat &M, const float *v, float *&dst, bool maxSum=false)
Specific matrix multiplication.
std::unique_ptr< Node > ptr_node_t
Definition: GraphPairwise.h:24
CGraphPairwise & getGraphPairwise(void) const
Returns the graph.
void deleteMessages(void)
Deletes memory for Edge::msg and Edge::msg_temp containers for all edges in the graph.
Mat Pot
The edge potentials: Mat(size: nStates x nStates; type: CV_32FC1)
Definition: GraphPairwise.h:35
std::unique_ptr< Edge > ptr_edge_t
Definition: GraphPairwise.h:55
void createMessages(void)
Allocates memory for Edge::msg and Edge::msg_temp containers for all edges in the graph...
vec_size_t from
Array of edge ids, coming from the Parent vertices.
Definition: GraphPairwise.h:19
virtual void infer(unsigned int nIt=1)
Inference.
void calculateMessage(Edge *edge, float *temp, float *&dst, bool maxSum=false)
Calculates one message for the specified edge edge.