Direct Graphical Models  v.1.7.0
EdgeModelPotts.cpp
1 #include "EdgeModelPotts.h"
2 
3 namespace DirectGraphicalModels {
4  // Constructor
5  CEdgeModelPotts::CEdgeModelPotts(const Mat& features, float weight, const std::function<void(const Mat& src, Mat& dst)>& semiMetricFunction, bool perPixelNormalization)
6  : IEdgeModel()
7  , m_pLattice(std::make_unique<CPermutohedral>())
8  , m_weight(weight)
9  , m_norm(features.rows, 1, CV_32FC1, Scalar(1))
10  , m_function(semiMetricFunction)
11  {
12  m_pLattice->init(features);
13 
14  // Compute the normalization factor
15  m_pLattice->compute(m_norm, m_norm);
16 
17  if (perPixelNormalization)
18  for (int n = 0; n < m_norm.rows; n++)
19  m_norm.at<float>(n, 0) = 1.0f / (m_norm.at<float>(n, 0) + FLT_EPSILON);
20  else {
21  float mean_norm = static_cast<float>(sum(m_norm)[0]);
22  mean_norm = m_norm.rows / mean_norm;
23  m_norm.setTo(mean_norm);
24  }
25  }
26 
27  // dst = e^(w * norm * f(Lattice.compute(src)))
28  void CEdgeModelPotts::apply(const Mat &src, Mat &dst) const
29  {
30  m_pLattice->compute(src, dst); // dst = Lattice.compute(src)
31 
32 #ifdef ENABLE_PPL
33  concurrency::parallel_for(0, dst.rows, [&](int n) {
34 #else
35  for (int n = 0; n < dst.rows; n++) { // nodes
36 #endif
37  if (m_function) m_function(dst.row(n), lvalue_cast(dst.row(n))); // With the SemiMetric function
38 
39  // dst.row(n) *= m_weight * m_norm.at<float>(n, 0);
40  // Using expressive notation for sake of efficiency
41  float* pDst = dst.ptr<float>(n);
42  float k = m_weight * m_norm.at<float>(n, 0);
43  for (int s = 0; s < dst.cols; s++) pDst[s] *= k;
44  }
45 #ifdef ENABLE_PPL
46  );
47 #endif
48  exp(dst, dst);
49  }
50 
51 }
std::unique_ptr< CPermutohedral > m_pLattice
Pointer to the permutohedral lattice.
Interface class for edge models used in dense graphical models.
Definition: IEdgeModel.h:14
STL namespace.
void apply(const Mat &src, Mat &dst) const override
Applies an edge model to the node potentials of a dense graph.
Mat m_norm
Array with normalization factors.
CEdgeModelPotts(const Mat &features, float weight=1.0f, const std::function< void(const Mat &src, Mat &dst)> &semiMetricFunction={}, bool perPixelNormalization=true)
Constructor.