Tapkee
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros
quadtree.hpp
Go to the documentation of this file.
1 
32 #include <math.h>
33 #include <float.h>
34 #include <stdlib.h>
35 #include <stdio.h>
36 #include <algorithm>
37 
38 #ifndef QUADTREE_H
39 #define QUADTREE_H
40 
41 namespace tsne
42 {
43 
44 using tapkee::ScalarType;
45 
46 class Cell {
47 
48 public:
49 
54 
55  bool containsPoint(ScalarType point[])
56  {
57  if(x - hw > point[0]) return false;
58  if(x + hw < point[0]) return false;
59  if(y - hh > point[1]) return false;
60  if(y + hh < point[1]) return false;
61  return true;
62  }
63 
64 };
65 
66 
67 class QuadTree
68 {
69 
70  // Fixed constants
71  static const int QT_NO_DIMS = 2;
72  static const int QT_NODE_CAPACITY = 1;
73 
74  // A buffer we use when doing force computations
76 
77  // Properties of this node in the tree
79  bool is_leaf;
80  int size;
81  int cum_size;
82 
83  // Axis-aligned bounding box stored as a center with half-dimensions to represent the boundaries of this quad tree
85 
86  // Indices in this quad tree node, corresponding center-of-mass, and list of all children
90 
91  // Children
96 
97 public:
98 
99  // Default constructor for quadtree -- build tree, too!
100  QuadTree(ScalarType* inp_data, int N) :
101  parent(NULL), is_leaf(false), size(0), cum_size(0), boundary(), data(NULL),
102  northWest(NULL), northEast(NULL), southWest(NULL), southEast(NULL)
103  {
104  // Compute mean, width, and height of current map (boundaries of quadtree)
105  ScalarType* mean_Y = new ScalarType[QT_NO_DIMS]; for(int d = 0; d < QT_NO_DIMS; d++) mean_Y[d] = .0;
106  ScalarType* min_Y = new ScalarType[QT_NO_DIMS]; for(int d = 0; d < QT_NO_DIMS; d++) min_Y[d] = DBL_MAX;
107  ScalarType* max_Y = new ScalarType[QT_NO_DIMS]; for(int d = 0; d < QT_NO_DIMS; d++) max_Y[d] = -DBL_MAX;
108  for(int n = 0; n < N; n++) {
109  for(int d = 0; d < QT_NO_DIMS; d++) {
110  mean_Y[d] += inp_data[n * QT_NO_DIMS + d];
111  if(inp_data[n * QT_NO_DIMS + d] < min_Y[d]) min_Y[d] = inp_data[n * QT_NO_DIMS + d];
112  if(inp_data[n * QT_NO_DIMS + d] > max_Y[d]) max_Y[d] = inp_data[n * QT_NO_DIMS + d];
113  }
114  }
115  for(int d = 0; d < QT_NO_DIMS; d++) mean_Y[d] /= (ScalarType) N;
116 
117  // Construct quadtree
118  init(NULL, inp_data, mean_Y[0], mean_Y[1], std::max(max_Y[0] - mean_Y[0], mean_Y[0] - min_Y[0]) + 1e-5,
119  std::max(max_Y[1] - mean_Y[1], mean_Y[1] - min_Y[1]) + 1e-5);
120  fill(N);
121  delete[] mean_Y; delete[] max_Y; delete[] min_Y;
122  }
123 
124  // Constructor for quadtree with particular size and no parent (does not fill the tree)
125  QuadTree(ScalarType* inp_data, ScalarType inp_x, ScalarType inp_y, ScalarType inp_hw, ScalarType inp_hh) :
126  parent(NULL), is_leaf(false), size(0), cum_size(0), boundary(), data(NULL),
127  northWest(NULL), northEast(NULL), southWest(NULL), southEast(NULL)
128  {
129  init(NULL, inp_data, inp_x, inp_y, inp_hw, inp_hh);
130  }
131 
132  // Constructor for quadtree with particular size and no parent -- build the tree, too!
133  QuadTree(ScalarType* inp_data, int N, ScalarType inp_x, ScalarType inp_y, ScalarType inp_hw, ScalarType inp_hh) :
134  parent(NULL), is_leaf(false), size(0), cum_size(0), boundary(), data(NULL),
135  northWest(NULL), northEast(NULL), southWest(NULL), southEast(NULL)
136  {
137  init(NULL, inp_data, inp_x, inp_y, inp_hw, inp_hh);
138  fill(N);
139  }
140 
141  // Constructor for quadtree with particular size and parent -- build the tree, too!
142  QuadTree(QuadTree* inp_parent, ScalarType* inp_data, int N, ScalarType inp_x, ScalarType inp_y, ScalarType inp_hw, ScalarType inp_hh) :
143  parent(NULL), is_leaf(false), size(0), cum_size(0), boundary(), data(NULL),
144  northWest(NULL), northEast(NULL), southWest(NULL), southEast(NULL)
145  {
146  init(inp_parent, inp_data, inp_x, inp_y, inp_hw, inp_hh);
147  fill(N);
148  }
149 
150  // Constructor for quadtree with particular size and parent (do not fill the tree)
151  QuadTree(QuadTree* inp_parent, ScalarType* inp_data, ScalarType inp_x, ScalarType inp_y, ScalarType inp_hw, ScalarType inp_hh) :
152  parent(NULL), is_leaf(false), size(0), cum_size(0), boundary(), data(NULL),
153  northWest(NULL), northEast(NULL), southWest(NULL), southEast(NULL)
154  {
155  init(inp_parent, inp_data, inp_x, inp_y, inp_hw, inp_hh);
156  }
157 
158  // Destructor for quadtree
160  {
161  delete northWest;
162  delete northEast;
163  delete southWest;
164  delete southEast;
165  }
166 
167  void setData(ScalarType* inp_data)
168  {
169  data = inp_data;
170  }
171 
173  {
174  return parent;
175  }
176 
177  //void construct(Cell boundary);
178 
179  // Insert a point into the QuadTree
180  bool insert(int new_index)
181  {
182  // Ignore objects which do not belong in this quad tree
183  ScalarType* point = data + new_index * QT_NO_DIMS;
184  if(!boundary.containsPoint(point))
185  return false;
186 
187  // Online update of cumulative size and center-of-mass
188  cum_size++;
189  ScalarType mult1 = (ScalarType) (cum_size - 1) / (ScalarType) cum_size;
190  ScalarType mult2 = 1.0 / (ScalarType) cum_size;
191  for(int d = 0; d < QT_NO_DIMS; d++) center_of_mass[d] *= mult1;
192  for(int d = 0; d < QT_NO_DIMS; d++) center_of_mass[d] += mult2 * point[d];
193 
194  // If there is space in this quad tree and it is a leaf, add the object here
195  if(is_leaf && size < QT_NODE_CAPACITY) {
196  index[size] = new_index;
197  size++;
198  return true;
199  }
200 
201  // Don't add duplicates for now (this is not very nice)
202  bool any_duplicate = false;
203  for(int n = 0; n < size; n++) {
204  bool duplicate = true;
205  for(int d = 0; d < QT_NO_DIMS; d++) {
206  if(point[d] != data[index[n] * QT_NO_DIMS + d]) { duplicate = false; break; }
207  }
208  any_duplicate = any_duplicate | duplicate;
209  }
210  if(any_duplicate) return true;
211 
212  // Otherwise, we need to subdivide the current cell
213  if(is_leaf) subdivide();
214 
215  // Find out where the point can be inserted
216  if(northWest->insert(new_index)) return true;
217  if(northEast->insert(new_index)) return true;
218  if(southWest->insert(new_index)) return true;
219  if(southEast->insert(new_index)) return true;
220 
221  // Otherwise, the point cannot be inserted (this should never happen)
222  return false;
223  }
224 
225  // Create four children which fully divide this cell into four quads of equal area
226  void subdivide()
227  {
228  // Create four children
229  northWest = new QuadTree(this, data, boundary.x - .5 * boundary.hw, boundary.y - .5 * boundary.hh, .5 * boundary.hw, .5 * boundary.hh);
230  northEast = new QuadTree(this, data, boundary.x + .5 * boundary.hw, boundary.y - .5 * boundary.hh, .5 * boundary.hw, .5 * boundary.hh);
231  southWest = new QuadTree(this, data, boundary.x - .5 * boundary.hw, boundary.y + .5 * boundary.hh, .5 * boundary.hw, .5 * boundary.hh);
232  southEast = new QuadTree(this, data, boundary.x + .5 * boundary.hw, boundary.y + .5 * boundary.hh, .5 * boundary.hw, .5 * boundary.hh);
233 
234  // Move existing points to correct children
235  for(int i = 0; i < size; i++) {
236  bool success = false;
237  if(!success) success = northWest->insert(index[i]);
238  if(!success) success = northEast->insert(index[i]);
239  if(!success) success = southWest->insert(index[i]);
240  if(!success) success = southEast->insert(index[i]);
241  index[i] = -1;
242  }
243 
244  // Empty parent node
245  size = 0;
246  is_leaf = false;
247  }
248 
249  // Checks whether the specified tree is correct
250  bool isCorrect()
251  {
252  for(int n = 0; n < size; n++) {
253  ScalarType* point = data + index[n] * QT_NO_DIMS;
254  if(!boundary.containsPoint(point)) return false;
255  }
256  if(!is_leaf) return northWest->isCorrect() &&
257  northEast->isCorrect() &&
258  southWest->isCorrect() &&
259  southEast->isCorrect();
260  else return true;
261  }
262 
263  // Rebuilds a possibly incorrect tree (LAURENS: This function is not tested yet!)
264  void rebuildTree()
265  {
266  for(int n = 0; n < size; n++) {
267  // Check whether point is erroneous
268  ScalarType* point = data + index[n] * QT_NO_DIMS;
269  if(!boundary.containsPoint(point)) {
270 
271  // Remove erroneous point
272  int rem_index = index[n];
273  for(int m = n + 1; m < size; m++) index[m - 1] = index[m];
274  index[size - 1] = -1;
275  size--;
276 
277  // Update center-of-mass and counter in all parents
278  bool done = false;
279  QuadTree* node = this;
280  while(!done) {
281  for(int d = 0; d < QT_NO_DIMS; d++) {
282  node->center_of_mass[d] = ((ScalarType) node->cum_size * node->center_of_mass[d] - point[d]) / (ScalarType) (node->cum_size - 1);
283  }
284  node->cum_size--;
285  if(node->getParent() == NULL) done = true;
286  else node = node->getParent();
287  }
288 
289  // Reinsert point in the root tree
290  node->insert(rem_index);
291  }
292  }
293 
294  // Rebuild lower parts of the tree
299  }
300 
301  // Build a list of all indices in quadtree
302  void getAllIndices(int* indices)
303  {
304  getAllIndices(indices, 0);
305  }
306 
307  int getDepth()
308  {
309  if(is_leaf) return 1;
310  return 1 + std::max(std::max(northWest->getDepth(),
311  northEast->getDepth()),
312  std::max(southWest->getDepth(),
313  southEast->getDepth()));
314  }
315 
316  // Compute non-edge forces using Barnes-Hut algorithm
317  void computeNonEdgeForces(int point_index, ScalarType theta, ScalarType neg_f[], ScalarType* sum_Q)
318  {
319 
320  // Make sure that we spend no time on empty nodes or self-interactions
321  if(cum_size == 0 || (is_leaf && size == 1 && index[0] == point_index)) return;
322 
323  // Compute distance between point and center-of-mass
324  ScalarType D = .0;
325  int ind = point_index * QT_NO_DIMS;
326  for(int d = 0; d < QT_NO_DIMS; d++) buff[d] = data[ind + d];
327  for(int d = 0; d < QT_NO_DIMS; d++) buff[d] -= center_of_mass[d];
328  for(int d = 0; d < QT_NO_DIMS; d++) D += buff[d] * buff[d];
329 
330  // Check whether we can use this node as a "summary"
331  if(is_leaf || std::max(boundary.hh, boundary.hw)/sqrt(D) < theta) {
332 
333  // Compute and add t-SNE force between point and current node
334  ScalarType Q = 1.0 / (1.0 + D);
335  *sum_Q += cum_size * Q;
336  ScalarType mult = cum_size * Q * Q;
337  for(int d = 0; d < QT_NO_DIMS; d++) neg_f[d] += mult * buff[d];
338  }
339  else {
340 
341  // Recursively apply Barnes-Hut to children
342  northWest->computeNonEdgeForces(point_index, theta, neg_f, sum_Q);
343  northEast->computeNonEdgeForces(point_index, theta, neg_f, sum_Q);
344  southWest->computeNonEdgeForces(point_index, theta, neg_f, sum_Q);
345  southEast->computeNonEdgeForces(point_index, theta, neg_f, sum_Q);
346  }
347  }
348 
349  // Computes edge forces
350  void computeEdgeForces(int* row_P, int* col_P, ScalarType* val_P, int N, ScalarType* pos_f)
351  {
352  // Loop over all edges in the graph
353  int ind1, ind2;
354  ScalarType D;
355  for(int n = 0; n < N; n++) {
356  ind1 = n * QT_NO_DIMS;
357  for(int i = row_P[n]; i < row_P[n + 1]; i++) {
358 
359  // Compute pairwise distance and Q-value
360  D = .0;
361  ind2 = col_P[i] * QT_NO_DIMS;
362  for(int d = 0; d < QT_NO_DIMS; d++) buff[d] = data[ind1 + d];
363  for(int d = 0; d < QT_NO_DIMS; d++) buff[d] -= data[ind2 + d];
364  for(int d = 0; d < QT_NO_DIMS; d++) D += buff[d] * buff[d];
365  D = val_P[i] / (1.0 + D);
366 
367  // Sum positive force
368  for(int d = 0; d < QT_NO_DIMS; d++) pos_f[ind1 + d] += D * buff[d];
369  }
370  }
371  }
372 
373  // Print out tree
374  void print()
375  {
376  if(cum_size == 0) {
377  printf("Empty node\n");
378  return;
379  }
380 
381  if(is_leaf) {
382  printf("Leaf node; data = [");
383  for(int i = 0; i < size; i++) {
384  ScalarType* point = data + index[i] * QT_NO_DIMS;
385  for(int d = 0; d < QT_NO_DIMS; d++) printf("%f, ", point[d]);
386  printf(" (index = %d)", index[i]);
387  if(i < size - 1) printf("\n");
388  else printf("]\n");
389  }
390  }
391  else {
392  printf("Intersection node with center-of-mass = [");
393  for(int d = 0; d < QT_NO_DIMS; d++) printf("%f, ", center_of_mass[d]);
394  printf("]; children are:\n");
395  northEast->print();
396  northWest->print();
397  southEast->print();
398  southWest->print();
399  }
400  }
401 
402 private:
403 
404  QuadTree(const QuadTree&);
405  QuadTree& operator=(const QuadTree&);
406 
407  void init(QuadTree* inp_parent, ScalarType* inp_data, ScalarType inp_x, ScalarType inp_y, ScalarType inp_hw, ScalarType inp_hh)
408  {
409  parent = inp_parent;
410  data = inp_data;
411  is_leaf = true;
412  size = 0;
413  cum_size = 0;
414  boundary.x = inp_x;
415  boundary.y = inp_y;
416  boundary.hw = inp_hw;
417  boundary.hh = inp_hh;
418  northWest = NULL;
419  northEast = NULL;
420  southWest = NULL;
421  southEast = NULL;
422  for(int i = 0; i < QT_NO_DIMS; i++) center_of_mass[i] = .0;
423  }
424 
425  // Build quadtree on dataset
426  void fill(int N)
427  {
428  for(int i = 0; i < N; i++) insert(i);
429  }
430 
431  // Build a list of all indices in quadtree
432  int getAllIndices(int* indices, int loc)
433  {
434 
435  // Gather indices in current quadrant
436  for(int i = 0; i < size; i++) indices[loc + i] = index[i];
437  loc += size;
438 
439  // Gather indices in children
440  if(!is_leaf) {
441  loc = northWest->getAllIndices(indices, loc);
442  loc = northEast->getAllIndices(indices, loc);
443  loc = southWest->getAllIndices(indices, loc);
444  loc = southEast->getAllIndices(indices, loc);
445  }
446  return loc;
447  }
448 
449  //bool isChild(int test_index, int start, int end);
450 };
451 
452 }
453 
454 #endif
QuadTree * southWest
Definition: quadtree.hpp:94
QuadTree * parent
Definition: quadtree.hpp:78
ScalarType * data
Definition: quadtree.hpp:87
QuadTree(ScalarType *inp_data, int N)
Definition: quadtree.hpp:100
QuadTree * northWest
Definition: quadtree.hpp:92
void rebuildTree()
Definition: quadtree.hpp:264
void getAllIndices(int *indices)
Definition: quadtree.hpp:302
ScalarType x
Definition: quadtree.hpp:50
QuadTree(ScalarType *inp_data, int N, ScalarType inp_x, ScalarType inp_y, ScalarType inp_hw, ScalarType inp_hh)
Definition: quadtree.hpp:133
QuadTree * northEast
Definition: quadtree.hpp:93
double ScalarType
default scalar value (can be overrided with TAPKEE_CUSTOM_INTERNAL_NUMTYPE define) ...
Definition: types.hpp:15
void computeNonEdgeForces(int point_index, ScalarType theta, ScalarType neg_f[], ScalarType *sum_Q)
Definition: quadtree.hpp:317
ScalarType y
Definition: quadtree.hpp:51
ScalarType buff[QT_NO_DIMS]
Definition: quadtree.hpp:75
ScalarType hh
Definition: quadtree.hpp:53
bool isCorrect()
Definition: quadtree.hpp:250
QuadTree(QuadTree *inp_parent, ScalarType *inp_data, ScalarType inp_x, ScalarType inp_y, ScalarType inp_hw, ScalarType inp_hh)
Definition: quadtree.hpp:151
QuadTree(QuadTree *inp_parent, ScalarType *inp_data, int N, ScalarType inp_x, ScalarType inp_y, ScalarType inp_hw, ScalarType inp_hh)
Definition: quadtree.hpp:142
static const int QT_NODE_CAPACITY
Definition: quadtree.hpp:72
void init(QuadTree *inp_parent, ScalarType *inp_data, ScalarType inp_x, ScalarType inp_y, ScalarType inp_hw, ScalarType inp_hh)
Definition: quadtree.hpp:407
QuadTree * getParent()
Definition: quadtree.hpp:172
QuadTree * southEast
Definition: quadtree.hpp:95
QuadTree(ScalarType *inp_data, ScalarType inp_x, ScalarType inp_y, ScalarType inp_hw, ScalarType inp_hh)
Definition: quadtree.hpp:125
int index[QT_NODE_CAPACITY]
Definition: quadtree.hpp:89
static const int QT_NO_DIMS
Definition: quadtree.hpp:71
void setData(ScalarType *inp_data)
Definition: quadtree.hpp:167
bool insert(int new_index)
Definition: quadtree.hpp:180
int getAllIndices(int *indices, int loc)
Definition: quadtree.hpp:432
void computeEdgeForces(int *row_P, int *col_P, ScalarType *val_P, int N, ScalarType *pos_f)
Definition: quadtree.hpp:350
ScalarType center_of_mass[QT_NO_DIMS]
Definition: quadtree.hpp:88
void fill(int N)
Definition: quadtree.hpp:426
QuadTree & operator=(const QuadTree &)
ScalarType hw
Definition: quadtree.hpp:52
void subdivide()
Definition: quadtree.hpp:226
bool containsPoint(ScalarType point[])
Definition: quadtree.hpp:55