Direct Graphical Models  v.1.7.0
DecodeExact.cpp
#include "DecodeExact.h"
#include "macroses.h"
#include <numeric>
#include <vector>
4 
5 namespace DirectGraphicalModels
6 {
7  vec_byte_t CDecodeExact::decode(Mat &lossMatrix) const
8  {
9  DGM_IF_WARNING(!lossMatrix.empty(), "The Loss Matrix is not supported by the algorithm.");
10 
11  size_t nNodes = getGraph().getNumNodes();
12  vec_byte_t state(nNodes);
13 
14  // Calculating the potentials for every possible configuration
15  vec_float_t P = calculatePotentials();
16 
17 #ifdef DEBUG_PRINT_INFO
18  // Printing out
19  printf("nConfigurations = %zd\n", P.size());
20 
21  // Calculating the partition function
22  float Z = std::accumulate(P.cbegin(), P.cend(), 0.0f);
23 
24  setState(state, 0);
25  for (float &p: P) {
26  for (size_t n = 0; n < nNodes; n++) printf("%d ", state[n]);
27  printf(":-> %2.1f\t| %2.1f %% \n", p, p * 100 / Z);
28  incState(state);
29  }
30 #endif
31 
32  // Finding the most probable configuration
33  qword c = std::max_element(P.cbegin(), P.cend()) - P.begin();
34  setState(state, c);
35  return state;
36  }
37 
38  // Sets the <state> according to the configuration number <c>
39  void CDecodeExact::setState(vec_byte_t &state, qword c) const
40  {
41  size_t nNodes = getGraph().getNumNodes();
42  for (size_t n = 0; n < nNodes; n++) {
43  state[n] = c % getGraph().getNumStates();
44  c = (c - state[n]) / getGraph().getNumStates();
45  }
46  }
47 
48  // Increases the <state> by one
49  void CDecodeExact::incState(vec_byte_t &state) const
50  {
51  size_t nNodes = getGraph().getNumNodes();
52  for (size_t n = 0; n < nNodes; n++)
53  if (++state[n] >= getGraph().getNumStates()) state[n] = 0;
54  else break;
55  }
56 
57  // Calculates potentials for all possible configurations
58  vec_float_t CDecodeExact::calculatePotentials(void) const
59  {
60  size_t nNodes = getGraph().getNumNodes();
61  size_t nConfigurations = static_cast<size_t> (powl(getGraph().getNumStates(), static_cast<long double>(nNodes)));
62  vec_byte_t state(nNodes);
63 
64  vec_float_t res;
65  DGM_ASSERT_MSG(nConfigurations < res.max_size(), "The number of configurations %d^%zu exceeds the maximal possible size of container.", getGraph().getNumStates(), nNodes);
66  res.resize(nConfigurations, 1.0f);
67 
68  setState(state, 0);
69  Mat nPot, ePot;
70  vec_size_t vChildNodes;
71  for (float &p: res) {
72  for (size_t n = 0; n < nNodes; n++) {
73  getGraph().getNode(n, nPot);
74  p *= nPot.at<float>(state[n], 0);
75  vec_size_t vChilds;
76  getGraph().getChildNodes(n, vChilds);
77  for (size_t c: vChilds) {
78  getGraphPairwise().getEdge(n, c, ePot);
79  p *= ePot.at<float>(state[n], state[c]);
80  }
81  }
82 
83  // Old implementation with direct access to CGraphPairwise private member variables
84  // for (ptr_node_t &node : getGraphPairwise().m_vNodes) p *= node->Pot.at<float>(state[node->id], 0);
85  // for (ptr_edge_t &edge : getGraphPairwise().m_vEdges) p *= edge->Pot.at<float>(state[edge->node1], state[edge->node2]);
86  incState(state);
87  }
88 
89  return res;
90  }
91 }
virtual vec_byte_t decode(Mat &lossMatrix=EmptyMat) const
Exact decoding.
Definition: DecodeExact.cpp:7
byte getNumStates(void) const
Returns the number of states (classes).
Definition: Graph.h:99
CGraph & getGraph(void) const
Returns the reference to the graph.
Definition: Decode.h:66
virtual void getChildNodes(size_t node, vec_size_t &vNodes) const =0
Returns the set of IDs of the child nodes of the argument node.
vec_float_t calculatePotentials(void) const
Calculates potentials for all possible configurations.
Definition: DecodeExact.cpp:58
virtual size_t getNumNodes(void) const =0
Returns the number of nodes in the graph.
void incState(vec_byte_t &state) const
Increases the state by one, i.e. switches the state array to the next configuration.
Definition: DecodeExact.cpp:49
virtual void getNode(size_t node, Mat &pot) const =0
Returns the node potential.
void setState(vec_byte_t &state, qword configuration) const
Sets the state according to the configuration index `configuration`.
Definition: DecodeExact.cpp:39
virtual void getEdge(size_t srcNode, size_t dstNode, Mat &pot) const =0
Returns the edge potential.
IGraphPairwise & getGraphPairwise(void) const
Returns the pairwise graph.
Definition: DecodeExact.h:40