Direct Graphical Models  v.1.7.0
KDTree.cpp
1 #include "KDTree.h"
2 #include "random.h"
3 #include "parallel.h"
4 #include "mathop.h"
5 
6 namespace DirectGraphicalModels
7 {
8  namespace {
9  template<typename T>
10  pair_mat_t getBoundingBox(Mat &data)
11  {
12  Mat min, max;
13 
14  data.row(0).copyTo(min);
15  data.row(0).copyTo(max);
16  T * pMin = min.ptr<T>(0);
17  T * pMax = max.ptr<T>(0);
18  for (int y = 1; y < data.rows; y++) { // samples
19  T * pData = data.ptr<T>(y);
20  for (int x = 0; x < data.cols; x++) { // dimensions
21  if (pMin[x] > pData[x]) pMin[x] = pData[x];
22  if (pMax[x] < pData[x]) pMax[x] = pData[x];
23  } // x: dimenstions
24  } // y: samples
25 
26  return std::make_pair(min, max);
27  }
28 
29  template<typename T>
30  int getSplitDimension(pair_mat_t &boundingBox)
31  {
32  int res = 0;
33  Mat diff = boundingBox.second - boundingBox.first; // diff = max - min
34  T max = diff.at<T>(0, 0);
35  Mat maxMasc(1, diff.cols, CV_8UC1); // binary maxv-alue masc
36  for (int x = 0; x < diff.cols; x++) { // dimensions
37  if (max == diff.at<T>(0, x)) {
38  maxMasc.at<byte>(0, x) = 1;
39  }
40  else if (max > diff.at<T>(0, x)) {
41  maxMasc.at<byte>(0, x) = 0;
42  }
43  else if (max < diff.at<T>(0, x)) {
44  maxMasc.at<byte>(0, x) = 1;
45  maxMasc(Rect(0, 0, x, 1)).setTo(0);
46  max = diff.at<T>(0, x);
47  res = x;
48  }
49  } // x: dimensions
50 
51  int nMaxs = countNonZero(maxMasc);
52  if (nMaxs == 1) return res;
53 
54  // Randomly choose one of the maximums
55  int x = 0;
56  int n = random::u<int>(1, nMaxs);
57  for (x = 0; x < diff.cols; x++) { // dimensions
58  if (maxMasc.at<byte>(0, x) == 1) n--;
59  if (n == 0) break;
60  } // x: dimensions
61  return x;
62  }
63  }
64 
65  void CKDTree::save(const std::string &fileName) const
66  {
67  if (!m_root) {
68  DGM_WARNING("The k-D tree is not built");
69  return;
70  }
71  FILE *pFile = fopen(fileName.c_str(), "wb");
72 
73  // header
74  int k = m_root->getBoundingBox().first.cols;
75  fwrite(&k, sizeof(int), 1, pFile); // dimensionality
76  m_root->save(pFile);
77  fclose(pFile);
78  }
79 
80  void CKDTree::load(const std::string &fileName)
81  {
82  FILE *pFile = fopen(fileName.c_str(), "rb");
83  // header
84  int k;
85  fread(&k, sizeof(int), 1, pFile); // dimensionality
86  m_root = loadTree(pFile, k);
87  fclose(pFile);
88  }
89 
90  void CKDTree::build(Mat &keys, Mat &values)
91  {
92  if (keys.empty()) {
93  DGM_WARNING("The data is empty");
94  return;
95  }
96  DGM_ASSERT_MSG(keys.type() == CV_8UC1, "Incorrect type of the keys");
97  DGM_ASSERT_MSG(values.type() == CV_8UC1, "Incorrect type of the values");
98  DGM_ASSERT_MSG(keys.rows == values.rows, "The amount of keys (%d) does not crrespond to the amount of values (%d)", keys.rows, values.rows);
99 
100  pair_mat_t boundingBox = getBoundingBox<byte>(keys);
101  hconcat(keys, values, keys); // keys = [keys; data]
102 
103  // Delete dublicated entries
104  parallel::sortRows<byte>(keys);
105  Mat data;
106  int y = keys.rows - 1;
107  for(; y > 0; y--) {
108  if (!mathop::isEqual<byte>(keys.row(y), keys.row(y - 1))) data.push_back(keys.row(y));
109  keys.pop_back();
110  }
111  data.push_back(keys.row(0));
112  keys.pop_back();
113 
114  m_root = buildTree(data, boundingBox);
115  }
116 
117  std::vector<std::shared_ptr<const CKDNode>> CKDTree::findNearestNeighbors(const Mat &key, size_t maxNeighbors) const
118  {
119  const float searchRadius_extension = 2.0f;
120 
121  std::vector<std::shared_ptr<const CKDNode>> nearestNeighbors;
122  nearestNeighbors.reserve(maxNeighbors);
123 
124  if (!m_root) {
125  DGM_WARNING("The k-D tree is not built");
126  return nearestNeighbors;
127  }
128 
129  std::shared_ptr<const CKDNode> nearestNode = findNearestNode(key);
130  nearestNeighbors.push_back(nearestNode);
131  float searchRadius = mathop::Euclidian<byte, float>(key, nearestNode->getKey());
132  if (maxNeighbors > 1) searchRadius *= searchRadius_extension;
133  searchRadius += 0.5f;
134 // searchRadius = 255 * sqrtf(static_cast<float>(key.cols)); // infinity
135 
136  pair_mat_t searchBox;
137  searchBox.first = key - searchRadius;
138  searchBox.second = key + searchRadius;
139 
140  m_root->findNearestNeighbors(key, maxNeighbors, searchBox, searchRadius, nearestNeighbors);
141 
142  return nearestNeighbors;
143  }
144 
145  // ----------------------------------------- Private -----------------------------------------
146  std::shared_ptr<CKDNode> CKDTree::loadTree(FILE * pFile, int k)
147  {
148  byte _isLeaf;
149  fread(&_isLeaf, sizeof(byte), 1, pFile);
150  if (_isLeaf) { // --- Leaf node ---
151  Mat key(1, k, CV_8UC1);
152  byte value;
153 
154  fread(key.data, sizeof(byte), k, pFile);
155  fread(&value, sizeof(byte), 1, pFile);
156 
157  std::shared_ptr<CKDNode> res(new CKDNode(key, value));
158  return res;
159  } else { // --- Branch node ---
160  pair_mat_t boundingBox = std::make_pair(Mat(1, k, CV_8UC1), Mat(1, k, CV_8UC1));
161  byte splitVal;
162  int splitDim;
163 
164  fread(boundingBox.first.data, sizeof(byte), k, pFile);
165  fread(boundingBox.second.data, sizeof(byte), k, pFile);
166  fread(&splitVal, sizeof(byte), 1, pFile);
167  fread(&splitDim, sizeof(int), 1, pFile);
168  std::shared_ptr<CKDNode> left = loadTree(pFile, k);
169  std::shared_ptr<CKDNode> right = loadTree(pFile, k);
170 
171  std::shared_ptr<CKDNode> res(new CKDNode(boundingBox, splitVal, splitDim, left, right));
172  return res;
173  }
174  }
175 
176  // data_i = [key,val]: k + 1 entries
177  // left = [0; splitVal)
178  // right = [splitVal; end]
179  std::shared_ptr<CKDNode> CKDTree::buildTree(Mat &data, pair_mat_t &boundingBox)
180  {
181  if (data.rows == 1) {
182  std::shared_ptr<CKDNode> res(new CKDNode(lvalue_cast(data(Rect(0, 0, data.cols - 1, 1))), data.at<byte>(0, data.cols - 1)));
183  return res;
184  }
185  else if (data.rows == 2) {
186  //pair_mat_t boundingBox = getBoundingBox<byte>(data);
187  int splitDim = getSplitDimension<byte>(boundingBox);
188  byte splitVal = (data.at<byte>(0, splitDim) + data.at<byte>(1, splitDim)) / 2;
189  std::shared_ptr<CKDNode> left, right;
190  if (data.at<byte>(0, splitDim) < data.at<byte>(1, splitDim)) {
191  left = std::shared_ptr<CKDNode>(new CKDNode(lvalue_cast(data.row(0)(Rect(0, 0, data.cols - 1, 1))), data.at<byte>(0, data.cols - 1)));
192  right = std::shared_ptr<CKDNode>(new CKDNode(lvalue_cast(data.row(1)(Rect(0, 0, data.cols - 1, 1))), data.at<byte>(1, data.cols - 1)));
193  }
194  else {
195  left = std::shared_ptr<CKDNode>(new CKDNode(lvalue_cast(data.row(1)(Rect(0, 0, data.cols - 1, 1))), data.at<byte>(1, data.cols - 1)));
196  right = std::shared_ptr<CKDNode>(new CKDNode(lvalue_cast(data.row(0)(Rect(0, 0, data.cols - 1, 1))), data.at<byte>(0, data.cols - 1)));
197  }
198  std::shared_ptr<CKDNode> res(new CKDNode(boundingBox, splitVal, splitDim, left, right));
199  return res;
200  }
201  else {
202  //pair_mat_t boundingBox = getBoundingBox<byte>(data);
203  int splitDim = getSplitDimension<byte>(boundingBox);
204  //if (splitDim == 1) printf("Achtung\n");
205  parallel::sortRows<byte>(data, splitDim);
206  int splitIdx = data.rows / 2;
207  byte splitVal = data.at<byte>(splitIdx, splitDim);
208 
209  pair_mat_t boundingBoxLeft, boundingBoxRight;
210  boundingBox.first.copyTo(boundingBoxLeft.first);
211  boundingBox.second.copyTo(boundingBoxLeft.second);
212  boundingBox.first.copyTo(boundingBoxRight.first);
213  boundingBox.second.copyTo(boundingBoxRight.second);
214 
215  boundingBoxLeft.second.at<byte>(0, splitDim) = splitVal > 0 ? splitVal - 1 : 0;
216  boundingBoxRight.first.at<byte>(0, splitDim) = splitVal;
217 
218  std::shared_ptr<CKDNode> left = buildTree(lvalue_cast(data(Rect(0, 0, data.cols, splitIdx))), boundingBoxLeft);
219  std::shared_ptr<CKDNode> right = buildTree(lvalue_cast(data(Rect(0, data.rows / 2, data.cols, data.rows - data.rows / 2))), boundingBoxRight);
220  std::shared_ptr<CKDNode> res(new CKDNode(boundingBox, splitVal, splitDim, left, right));
221  return res;
222  }
223  }
224 
225  std::shared_ptr<const CKDNode> CKDTree::findNearestNode(const Mat &key) const
226  {
227  std::shared_ptr<CKDNode> node(m_root);
228 
229  while (!node->isLeaf()) {
230  std::shared_ptr<CKDNode> n = std::static_pointer_cast<CKDNode>(node);
231  if (key.at<byte>(0, n->getSplitDim()) < n->getSplitVal()) node = n->Left();
232  else node = n->Right();
233  }
234 
235  return std::static_pointer_cast<CKDNode>(node);
236  }
237 
238 }
k-D Node class for the k-D Tree data structure
Definition: KDNode.h:17
std::shared_ptr< const CKDNode > findNearestNode(const Mat &key) const
Definition: KDTree.cpp:225
std::shared_ptr< CKDNode > buildTree(Mat &data, pair_mat_t &boundingBox)
Definition: KDTree.cpp:179
std::shared_ptr< CKDNode > loadTree(FILE *pFile, int k)
Definition: KDTree.cpp:146
std::shared_ptr< CKDNode > m_root
Definition: KDTree.h:83
void build(Mat &keys, Mat &values)
Builds a k-d tree on keys with corresponding values.
Definition: KDTree.cpp:90
std::vector< std::shared_ptr< const CKDNode > > findNearestNeighbors(const Mat &key, size_t maxNeighbors) const
Finds up to maxNeighbors nearest neighbors to the key.
Definition: KDTree.cpp:117
void load(const std::string &fileName)
Loads a tree from the file.
Definition: KDTree.cpp:80
void save(const std::string &fileName) const
Saves the tree into a file.
Definition: KDTree.cpp:65
std::shared_ptr< CKDNode > Left(void) const
Returns the pointer to the left child.
Definition: KDNode.h:97