2 #include "GraphPairwise.h" 18 std::fill(edge->msg, edge->msg + nStates, 1.0f);
19 std::fill(edge->msg_temp, edge->msg_temp + nStates, 1.0f);
30 for (
size_t e_f : node->from) {
32 if (edge_from->
node1 > edge_from->
node2)
continue;
34 for (byte s = 0; s < nStates; s++) node->Pot.at<
float>(s, 0) *= edge_from->
Pot.at<
float>(src->
sol, s);
37 for (
size_t e_t : node->to) {
40 for (byte s = 0; s < nStates; s++) node->Pot.at<
float>(s, 0) *= edge_to->
msg[s];
44 minMaxLoc(node->Pot, NULL, NULL, NULL, &extremumLoc);
45 node->sol =
static_cast<byte
> (extremumLoc.y);
54 float * data =
new float[nStates];
55 float * temp =
new float[nStates];
58 for (
unsigned int i = 0; i < nIt; i++) {
59 #ifdef DEBUG_PRINT_INFO 60 if (i == 0) printf(
"\n");
61 if (i % 5 == 0) printf(
"--- It: %d ---\n", i);
65 memcpy(data, node->Pot.data, nStates *
sizeof(
float));
68 for (
size_t e_t : node->to) {
71 for (byte s = 0; s < nStates; s++) data[s] *= edge_to->
msg[s];
76 for (
size_t e_f : node->from) {
78 if (edge_from->
node1 > edge_from->
node2)
continue;
79 for (byte s = 0; s < nStates; s++) data[s] *= edge_from->
msg[s];
83 for (byte s = 0; s < nStates; s++) data[s] = static_cast<float>(
fastPow(data[s], 1.0f / MAX(nForward, nBackward)));
86 for (
size_t e_t : node->to) {
94 memcpy(data, node->Pot.data, nStates *
sizeof(
float));
97 for (
size_t e_t : node->to) {
100 for (byte s = 0; s < nStates; s++) data[s] *= edge_to->
msg[s];
105 for (
size_t e_f : node->from) {
107 if (edge_from->
node1 > edge_from->
node2)
continue;
108 for (byte s = 0; s < nStates; s++) data[s] *= edge_from->
msg[s];
114 for (byte s = 1; s < nStates; s++)
if (max < data[s]) max = data[s];
115 for (byte s = 0; s < nStates; s++) data[s] /= max;
116 for (byte s = 0; s < nStates; s++) data[s] = static_cast<float>(
fastPow(data[s], 1.0f / MAX(nForward, nBackward)));
119 for (
size_t e_f : node->from) {
135 for (byte s = 0; s < nStates; s++) temp[s] = data[s] / MAX(FLT_EPSILON, edge.
msg[s]);
137 for (byte y = 0; y < nStates; y++) {
138 float *pPot = edge.
Pot.ptr<
float>(y);
139 float max = temp[0] * pPot[0];
140 for (byte x = 1; x < nStates; x++) {
141 float val = temp[x] * pPot[x];
142 if (max < val) max = val;
148 float max = edge.
msg[0];
149 for (byte s = 1; s < nStates; s++)
if (max < edge.
msg[s]) max = edge.
msg[s];
150 for (byte s = 0; s < nStates; s++) edge.
msg[s] /= max;
byte getNumStates(void) const
Returns the number of states (classes).
size_t node2
Second (destination) node in edge.
float * msg
Message (used in message-passing algorithms): array of nStates float values (documented as Mat(size: nStates x 1; type: CV_32FC1)) ...
size_t node1
First (source) node in edge.
CGraph & getGraph(void) const
Returns the reference to the graph.
void calculateMessage(Edge &edge, float *temp, float *data)
std::unique_ptr< Node > ptr_node_t
double fastPow(double a, double b)
CGraphPairwise & getGraphPairwise(void) const
Returns the reference to the pairwise graph.
void deleteMessages(void)
Deletes memory for Edge::msg and Edge::msg_temp containers for all edges in the graph.
Mat Pot
The edge potentials: Mat(size: nStates x nStates; type: CV_32FC1)
virtual void infer(unsigned int nIt=1)
Inference.
std::unique_ptr< Edge > ptr_edge_t
void createMessages(void)
Allocates memory for Edge::msg and Edge::msg_temp containers for all edges in the graph...
virtual void calculateMessages(unsigned int nIt)
Calculates messages, associated with the edges of corresponding graphical model.