Direct Graphical Models  v.1.7.0
InferTRW.cpp
#include "InferTRW.h"

#include <vector>

#include "GraphPairwise.h"
#include "macroses.h"
4 
5 namespace DirectGraphicalModels
6 {
7 void CInferTRW::infer(unsigned int nIt)
8 {
9  const byte nStates = getGraph().getNumStates(); // number of states (classes)
10 
11  // ====================================== Initialization ======================================
13 #ifdef ENABLE_PPL
14  concurrency::parallel_for_each(getGraphPairwise().m_vEdges.begin(), getGraphPairwise().m_vEdges.end(), [nStates](ptr_edge_t &edge) {
15 #else
16  std::for_each(getGraphPairwise().m_vEdges.begin(), getGraphPairwise().m_vEdges.end(), [nStates](ptr_edge_t &edge) {
17 #endif
18  std::fill(edge->msg, edge->msg + nStates, 1.0f);
19  std::fill(edge->msg_temp, edge->msg_temp + nStates, 1.0f);
20  });
21 
22  // =================================== Calculating messages ==================================
23 
24  calculateMessages(nIt);
25 
26  // =================================== Calculating beliefs ===================================
27 
28  for (ptr_node_t &node : getGraphPairwise().m_vNodes) {
29  // backward edges
30  for (size_t e_f : node->from) {
31  Edge * edge_from = getGraphPairwise().m_vEdges[e_f].get();
32  if (edge_from->node1 > edge_from->node2) continue;
33  Node *src = getGraphPairwise().m_vNodes[edge_from->node1].get();
34  for (byte s = 0; s < nStates; s++) node->Pot.at<float>(s, 0) *= edge_from->Pot.at<float>(src->sol, s);
35  }
36  // forward edges
37  for (size_t e_t : node->to) {
38  Edge *edge_to = getGraphPairwise().m_vEdges[e_t].get();
39  if (edge_to->node1 > edge_to->node2) continue;
40  for (byte s = 0; s < nStates; s++) node->Pot.at<float>(s, 0) *= edge_to->msg[s];
41  }
42 
43  Point extremumLoc;
44  minMaxLoc(node->Pot, NULL, NULL, NULL, &extremumLoc);
45  node->sol = static_cast<byte> (extremumLoc.y);
46  }
47 
49 }
50 
51 void CInferTRW::calculateMessages(unsigned int nIt)
52 {
53  const byte nStates = getGraph().getNumStates(); // number of states
54  float * data = new float[nStates];
55  float * temp = new float[nStates];
56 
57  // main loop
58  for (unsigned int i = 0; i < nIt; i++) { // iterations
59 #ifdef DEBUG_PRINT_INFO
60  if (i == 0) printf("\n");
61  if (i % 5 == 0) printf("--- It: %d ---\n", i);
62 #endif
63  // Forward pass
64  std::for_each(getGraphPairwise().m_vNodes.begin(), getGraphPairwise().m_vNodes.end(), [&](ptr_node_t &node) {
65  memcpy(data, node->Pot.data, nStates * sizeof(float)); // data = node.pot
66 
67  int nForward = 0;
68  for (size_t e_t : node->to) {
69  Edge *edge_to = getGraphPairwise().m_vEdges[e_t].get();
70  if (edge_to->node1 > edge_to->node2) continue;
71  for (byte s = 0; s < nStates; s++) data[s] *= edge_to->msg[s]; // data = node.pot * edge_to.msg
72  nForward++;
73  } // e_t
74 
75  int nBackward = 0;
76  for (size_t e_f : node->from) {
77  Edge *edge_from = getGraphPairwise().m_vEdges[e_f].get();
78  if (edge_from->node1 > edge_from->node2) continue;
79  for (byte s = 0; s < nStates; s++) data[s] *= edge_from->msg[s]; // data = node.pot * edge_to.msg * edge_from.msg
80  nBackward++;
81  } // e_f
82 
83  for (byte s = 0; s < nStates; s++) data[s] = static_cast<float>(fastPow(data[s], 1.0f / MAX(nForward, nBackward)));
84 
85  // pass messages from i to nodes with higher m_ordering
86  for (size_t e_t : node->to) {
87  Edge *edge_to = getGraphPairwise().m_vEdges[e_t].get();
88  if (edge_to->node1 < edge_to->node2) calculateMessage(*edge_to, temp, data);
89  } // e_t
90  });
91 
92  // Backward pass
93  std::for_each(getGraphPairwise().m_vNodes.rbegin(), getGraphPairwise().m_vNodes.rend(), [&](ptr_node_t &node) {
94  memcpy(data, node->Pot.data, nStates * sizeof(float)); // data = node.pot
95 
96  int nForward = 0;
97  for (size_t e_t : node->to) {
98  Edge *edge_to = getGraphPairwise().m_vEdges[e_t].get();
99  if (edge_to->node1 > edge_to->node2) continue;
100  for (byte s = 0; s < nStates; s++) data[s] *= edge_to->msg[s];
101  nForward++;
102  } // e_t
103 
104  int nBackward = 0;
105  for (size_t e_f : node->from) {
106  Edge *edge_from = getGraphPairwise().m_vEdges[e_f].get();
107  if (edge_from->node1 > edge_from->node2) continue;
108  for (byte s = 0; s < nStates; s++) data[s] *= edge_from->msg[s];
109  nBackward++;
110  } // e_f
111 
112  // normalize data
113  float max = data[0];
114  for (byte s = 1; s < nStates; s++) if (max < data[s]) max = data[s];
115  for (byte s = 0; s < nStates; s++) data[s] /= max;
116  for (byte s = 0; s < nStates; s++) data[s] = static_cast<float>(fastPow(data[s], 1.0f / MAX(nForward, nBackward)));
117 
118  // pass messages from i to nodes with smaller m_ordering
119  for (size_t e_f : node->from) {
120  Edge *edge_from = getGraphPairwise().m_vEdges[e_f].get();
121  if (edge_from->node1 < edge_from->node2) calculateMessage(*edge_from, temp, data);
122  } // e_f
123  }); // All Nodes
124  } // iterations
125 
126  delete[] data;
127  delete[] temp;
128 }
129 
130 // Updates edge->msg = F(data, edge.Pot)
131 void CInferTRW::calculateMessage(Edge &edge, float *temp, float *data)
132 {
133  const byte nStates = getGraph().getNumStates();
134 
135  for (byte s = 0; s < nStates; s++) temp[s] = data[s] / MAX(FLT_EPSILON, edge.msg[s]); // tmp = gamma * data / edge.msg
136 
137  for (byte y = 0; y < nStates; y++) {
138  float *pPot = edge.Pot.ptr<float>(y);
139  float max = temp[0] * pPot[0]; // vMin = tmp + edge.Pot(0, kdest)
140  for (byte x = 1; x < nStates; x++) {
141  float val = temp[x] * pPot[x];
142  if (max < val) max = val;
143  }
144  edge.msg[y] = max;
145  }
146 
147  // Normalization
148  float max = edge.msg[0];
149  for (byte s = 1; s < nStates; s++) if (max < edge.msg[s]) max = edge.msg[s];
150  for (byte s = 0; s < nStates; s++) edge.msg[s] /= max;
151 }
152 }
byte getNumStates(void) const
Returns number of states (classes)
Definition: Graph.h:99
size_t node2
Second (destination) node in edge.
Definition: GraphPairwise.h:34
float * msg
Message (used in message-passing algorithms): array of nStates float values ...
Definition: GraphPairwise.h:36
size_t node1
First (source) node in edge.
Definition: GraphPairwise.h:33
CGraph & getGraph(void) const
Returns the reference to the graph.
Definition: Infer.h:82
void calculateMessage(Edge &edge, float *temp, float *data)
Definition: InferTRW.cpp:131
std::unique_ptr< Node > ptr_node_t
Definition: GraphPairwise.h:24
double fastPow(double a, double b)
Definition: macroses.h:45
CGraphPairwise & getGraphPairwise(void) const
Returns the reference to the pairwise graph.
void deleteMessages(void)
Frees the memory of the Edge::msg and Edge::msg_temp containers for all edges in the graph.
Mat Pot
The edge potentials: Mat(size: nStates x nStates; type: CV_32FC1)
Definition: GraphPairwise.h:35
virtual void infer(unsigned int nIt=1)
Runs nIt iterations of TRW message passing and stores the arg-max state of each node's belief in Node::sol.
Definition: InferTRW.cpp:7
std::unique_ptr< Edge > ptr_edge_t
Definition: GraphPairwise.h:55
void createMessages(void)
Allocates memory for Edge::msg and Edge::msg_temp containers for all edges in the graph...
virtual void calculateMessages(unsigned int nIt)
Calculates the messages associated with the edges of the corresponding graphical model.
Definition: InferTRW.cpp:51