1use std::cmp::Ordering;
2
3use spatialrust_core::{HasPositions3, PointCloud, SpatialResult};
4
5use crate::{NearestNeighborIndex, Neighbor, RadiusSearchIndex, SpatialIndex};
6
/// Maximum number of points held by a leaf; `build_node` stops splitting ranges this small.
const LEAF_SIZE: usize = 16;
/// Capacity of the fixed-size traversal stack used by `radius_count_iterative`; once it is
/// full, further pending subtrees are searched recursively instead.
const SEARCH_STACK_SIZE: usize = 64;
/// Value of `KdNode::axis` that marks a leaf node.
const AXIS_LEAF: u8 = u8::MAX;
/// Sentinel node index meaning "no node": the root of an empty tree, or a leaf's children.
const INVALID_NODE: u32 = u32::MAX;
11
/// A static three-dimensional k-d tree over an owned copy of point coordinates.
///
/// Points are referred to by their position in the coordinate slices the tree was built
/// from. Nodes live in a flat vector and each node covers a contiguous range of
/// `points_order`.
#[derive(Clone, Debug)]
pub struct KdTree {
    /// X coordinates, indexed by original point index.
    x: Vec<f32>,
    /// Y coordinates, indexed by original point index.
    y: Vec<f32>,
    /// Z coordinates, indexed by original point index.
    z: Vec<f32>,
    /// Permutation of the point indices, reordered during construction so that every node's
    /// points occupy the range `start..end` of this vector.
    points_order: Vec<u32>,
    /// All tree nodes; children are referenced by their index in this vector.
    nodes: Vec<KdNode>,
    /// Index of the root node in `nodes`, or `INVALID_NODE` when the tree is empty.
    root: u32,
}
22
/// One node of a [`KdTree`].
///
/// A leaf has `axis == AXIS_LEAF` and scans its points directly. An internal node splits its
/// points on `axis` at `split`.
#[derive(Clone, Copy, Debug)]
struct KdNode {
    /// Splitting coordinate on `axis`: points under `left` are `<=` it and points under
    /// `right` are `>=` it. Left at `0.0` for leaves.
    split: f32,
    /// Index of the left child in `KdTree::nodes`, or `INVALID_NODE` for a leaf.
    left: u32,
    /// Index of the right child in `KdTree::nodes`, or `INVALID_NODE` for a leaf.
    right: u32,
    /// First position (inclusive) in `KdTree::points_order` covered by this node.
    start: u32,
    /// One past the last position in `KdTree::points_order` covered by this node.
    end: u32,
    /// Split axis (`0` = x, `1` = y, any other non-leaf value = z), or `AXIS_LEAF`.
    axis: u8,
}
32
33impl KdTree {
34 #[must_use]
36 pub fn from_slices(x: &[f32], y: &[f32], z: &[f32]) -> Self {
37 assert_eq!(x.len(), y.len());
38 assert_eq!(x.len(), z.len());
39
40 let len = x.len();
41 let mut points_order: Vec<u32> = (0..len as u32).collect();
42 let mut nodes = Vec::with_capacity(if len == 0 { 0 } else { len.div_ceil(LEAF_SIZE) * 2 });
43
44 let root = if len == 0 {
45 INVALID_NODE
46 } else {
47 build_node(x, y, z, &mut points_order, 0, len, &mut nodes)
48 };
49
50 Self { x: x.to_vec(), y: y.to_vec(), z: z.to_vec(), points_order, nodes, root }
51 }
52
53 pub fn from_point_cloud(cloud: &PointCloud) -> SpatialResult<Self> {
55 let (x, y, z) = cloud.positions3()?;
56 Ok(Self::from_slices(x, y, z))
57 }
58
59 fn point(&self, point_index: u32) -> (f32, f32, f32) {
60 let idx = point_index as usize;
61 (self.x[idx], self.y[idx], self.z[idx])
62 }
63
64 fn ordered_point(&self, order_index: u32) -> (u32, f32, f32, f32) {
65 let point_index = self.points_order[order_index as usize];
66 let (x, y, z) = self.point(point_index);
67 (point_index, x, y, z)
68 }
69
70 fn nearest_k_recursive(
71 &self,
72 node: u32,
73 qx: f32,
74 qy: f32,
75 qz: f32,
76 k: usize,
77 best: &mut KnnAccumulator,
78 ) {
79 if node == INVALID_NODE {
80 return;
81 }
82
83 let node_data = self.nodes[node as usize];
84 if node_data.axis == AXIS_LEAF {
85 let start = node_data.start as usize;
86 let end = node_data.end as usize;
87 for order_index in start..end {
88 let (index, px, py, pz) = self.ordered_point(order_index as u32);
89 best.insert(
90 k,
91 Neighbor {
92 index: index as usize,
93 distance_squared: squared_distance(px, py, pz, qx, qy, qz),
94 },
95 );
96 }
97 return;
98 }
99
100 let diff = match node_data.axis {
101 0 => qx - node_data.split,
102 1 => qy - node_data.split,
103 _ => qz - node_data.split,
104 };
105
106 let (near, far) = if diff <= 0.0 {
107 (node_data.left, node_data.right)
108 } else {
109 (node_data.right, node_data.left)
110 };
111
112 self.nearest_k_recursive(near, qx, qy, qz, k, best);
113
114 let worst = best.prune_distance(k);
115 if diff * diff < worst || best.len() < k {
116 self.nearest_k_recursive(far, qx, qy, qz, k, best);
117 }
118 }
119
120 fn radius_recursive(
121 &self,
122 node: u32,
123 qx: f32,
124 qy: f32,
125 qz: f32,
126 radius_sq: f32,
127 out: &mut Vec<Neighbor>,
128 ) {
129 if node == INVALID_NODE {
130 return;
131 }
132
133 let node_data = self.nodes[node as usize];
134 if node_data.axis == AXIS_LEAF {
135 let start = node_data.start as usize;
136 let end = node_data.end as usize;
137 for order_index in start..end {
138 let (index, px, py, pz) = self.ordered_point(order_index as u32);
139 let distance_squared = squared_distance(px, py, pz, qx, qy, qz);
140 if distance_squared <= radius_sq {
141 out.push(Neighbor { index: index as usize, distance_squared });
142 }
143 }
144 return;
145 }
146
147 let diff = match node_data.axis {
148 0 => qx - node_data.split,
149 1 => qy - node_data.split,
150 _ => qz - node_data.split,
151 };
152
153 let (near, far) = if diff <= 0.0 {
154 (node_data.left, node_data.right)
155 } else {
156 (node_data.right, node_data.left)
157 };
158
159 self.radius_recursive(near, qx, qy, qz, radius_sq, out);
160 if diff * diff <= radius_sq {
161 self.radius_recursive(far, qx, qy, qz, radius_sq, out);
162 }
163 }
164
165 #[must_use]
170 pub fn radius_reaches(&self, x: f32, y: f32, z: f32, radius: f32, target: usize) -> bool {
171 if target == 0 {
172 return true;
173 }
174 if self.is_empty() || radius < 0.0 {
175 return false;
176 }
177 self.radius_count_iterative(x, y, z, radius * radius, target)
178 }
179
180 fn radius_count_iterative(
181 &self,
182 qx: f32,
183 qy: f32,
184 qz: f32,
185 radius_sq: f32,
186 target: usize,
187 ) -> bool {
188 let mut count = 0usize;
189 let mut stack = [INVALID_NODE; SEARCH_STACK_SIZE];
190 let mut stack_len = 0usize;
191 let mut node = self.root;
192
193 loop {
194 while node != INVALID_NODE {
195 let node_data = self.nodes[node as usize];
196 if node_data.axis == AXIS_LEAF {
197 let start = node_data.start as usize;
198 let end = node_data.end as usize;
199 for order_index in start..end {
200 let (_, px, py, pz) = self.ordered_point(order_index as u32);
201 if squared_distance(px, py, pz, qx, qy, qz) <= radius_sq {
202 count += 1;
203 if count >= target {
204 return true;
205 }
206 }
207 }
208 break;
209 }
210
211 let diff = match node_data.axis {
212 0 => qx - node_data.split,
213 1 => qy - node_data.split,
214 _ => qz - node_data.split,
215 };
216 let (near, far) = if diff <= 0.0 {
217 (node_data.left, node_data.right)
218 } else {
219 (node_data.right, node_data.left)
220 };
221
222 if diff * diff <= radius_sq {
223 if stack_len < stack.len() {
224 stack[stack_len] = far;
225 stack_len += 1;
226 } else if self
227 .radius_count_recursive(far, qx, qy, qz, radius_sq, target, &mut count)
228 {
229 return true;
230 }
231 }
232 node = near;
233 }
234
235 if stack_len == 0 {
236 return false;
237 }
238 stack_len -= 1;
239 node = stack[stack_len];
240 }
241 }
242
243 fn radius_count_recursive(
246 &self,
247 node: u32,
248 qx: f32,
249 qy: f32,
250 qz: f32,
251 radius_sq: f32,
252 target: usize,
253 count: &mut usize,
254 ) -> bool {
255 if node == INVALID_NODE {
256 return false;
257 }
258
259 let node_data = self.nodes[node as usize];
260 if node_data.axis == AXIS_LEAF {
261 let start = node_data.start as usize;
262 let end = node_data.end as usize;
263 for order_index in start..end {
264 let (_, px, py, pz) = self.ordered_point(order_index as u32);
265 if squared_distance(px, py, pz, qx, qy, qz) <= radius_sq {
266 *count += 1;
267 if *count >= target {
268 return true;
269 }
270 }
271 }
272 return false;
273 }
274
275 let diff = match node_data.axis {
276 0 => qx - node_data.split,
277 1 => qy - node_data.split,
278 _ => qz - node_data.split,
279 };
280 let (near, far) = if diff <= 0.0 {
281 (node_data.left, node_data.right)
282 } else {
283 (node_data.right, node_data.left)
284 };
285
286 if self.radius_count_recursive(near, qx, qy, qz, radius_sq, target, count) {
287 return true;
288 }
289 if diff * diff <= radius_sq {
290 return self.radius_count_recursive(far, qx, qy, qz, radius_sq, target, count);
291 }
292 false
293 }
294}
295
296impl SpatialIndex for KdTree {
297 fn len(&self) -> usize {
298 self.x.len()
299 }
300}
301
302impl NearestNeighborIndex for KdTree {
303 fn nearest_one(&self, x: f32, y: f32, z: f32) -> Option<Neighbor> {
304 self.nearest_k(x, y, z, 1).into_iter().next()
305 }
306
307 fn nearest_k(&self, x: f32, y: f32, z: f32, k: usize) -> Vec<Neighbor> {
308 let mut best = Vec::with_capacity(k.min(self.len()));
309 self.nearest_k_into(x, y, z, k, &mut best);
310 best
311 }
312}
313
314impl KdTree {
315 pub fn nearest_k_into(&self, x: f32, y: f32, z: f32, k: usize, out: &mut Vec<Neighbor>) {
318 self.nearest_k_unsorted_into(x, y, z, k, out);
319 out.sort_by(|a, b| {
320 a.distance_squared.partial_cmp(&b.distance_squared).unwrap_or(Ordering::Equal)
321 });
322 }
323
324 pub fn nearest_k_unsorted_into(
328 &self,
329 x: f32,
330 y: f32,
331 z: f32,
332 k: usize,
333 out: &mut Vec<Neighbor>,
334 ) {
335 out.clear();
336 if self.is_empty() || k == 0 {
337 return;
338 }
339
340 out.reserve(k.min(self.len()));
341 let mut best = KnnAccumulator::new(out);
342 self.nearest_k_recursive(self.root, x, y, z, k, &mut best);
343 }
344}
345
346impl RadiusSearchIndex for KdTree {
347 fn radius_search(&self, x: f32, y: f32, z: f32, radius: f32) -> Vec<Neighbor> {
348 if self.is_empty() || radius < 0.0 {
349 return Vec::new();
350 }
351
352 let radius_sq = radius * radius;
353 let mut out = Vec::new();
354 self.radius_recursive(self.root, x, y, z, radius_sq, &mut out);
355 out
358 }
359}
360
361#[allow(clippy::too_many_arguments)]
362fn build_node(
363 x: &[f32],
364 y: &[f32],
365 z: &[f32],
366 points_order: &mut [u32],
367 start: usize,
368 end: usize,
369 nodes: &mut Vec<KdNode>,
370) -> u32 {
371 let node_index = nodes.len() as u32;
372 nodes.push(KdNode {
373 split: 0.0,
374 left: INVALID_NODE,
375 right: INVALID_NODE,
376 start: start as u32,
377 end: end as u32,
378 axis: 0,
379 });
380
381 let count = end - start;
382 if count <= LEAF_SIZE {
383 nodes[node_index as usize].axis = AXIS_LEAF;
384 return node_index;
385 }
386
387 let axis = select_axis(x, y, z, points_order, start, end);
388 let mid = start + count / 2;
389 select_nth_by_axis(x, y, z, points_order, start, end, axis, mid);
390
391 let split_point = points_order[mid];
392 let split_value = coordinate(x, y, z, split_point, axis);
393 nodes[node_index as usize].axis = axis;
394 nodes[node_index as usize].split = split_value;
395
396 let left = build_node(x, y, z, points_order, start, mid, nodes);
397 let right = build_node(x, y, z, points_order, mid, end, nodes);
398
399 nodes[node_index as usize].left = left;
400 nodes[node_index as usize].right = right;
401 node_index
402}
403
/// Picks the axis (`0` = x, `1` = y, `2` = z) along which the points in
/// `points_order[start..end]` have the largest extent (max − min).
///
/// Ties go to the lowest axis. An empty range yields axis `0`.
fn select_axis(
    x: &[f32],
    y: &[f32],
    z: &[f32],
    points_order: &[u32],
    start: usize,
    end: usize,
) -> u8 {
    let members = &points_order[start..end];
    let mut extents = [0.0_f32; 3];
    for (extent, column) in extents.iter_mut().zip([x, y, z]) {
        let (lowest, highest) = members.iter().fold(
            (f32::INFINITY, f32::NEG_INFINITY),
            |(lowest, highest), &point_index| {
                let value = column[point_index as usize];
                (lowest.min(value), highest.max(value))
            },
        );
        *extent = highest - lowest;
    }

    // Only a strictly larger extent displaces the current choice, so ties keep the lower axis.
    let widest = (1..3).fold(0_usize, |best, axis| {
        if extents[axis] > extents[best] {
            axis
        } else {
            best
        }
    });
    widest as u8
}
434
435fn select_nth_by_axis(
436 x: &[f32],
437 y: &[f32],
438 z: &[f32],
439 points_order: &mut [u32],
440 start: usize,
441 end: usize,
442 axis: u8,
443 nth: usize,
444) {
445 let mut left = start;
446 let mut right = end;
447 while left < right {
448 let pivot = partition_by_axis(x, y, z, points_order, left, right, axis);
449 match nth.cmp(&pivot) {
450 Ordering::Less => right = pivot,
451 Ordering::Greater => left = pivot + 1,
452 Ordering::Equal => break,
453 }
454 }
455}
456
457fn partition_by_axis(
458 x: &[f32],
459 y: &[f32],
460 z: &[f32],
461 points_order: &mut [u32],
462 start: usize,
463 end: usize,
464 axis: u8,
465) -> usize {
466 let pivot_index = (start + end) / 2;
467 points_order.swap(start, pivot_index);
468 let pivot_point = points_order[start];
469 let pivot_value = coordinate(x, y, z, pivot_point, axis);
470
471 let mut store = start + 1;
472 for i in (start + 1)..end {
473 if coordinate(x, y, z, points_order[i], axis) < pivot_value {
474 points_order.swap(i, store);
475 store += 1;
476 }
477 }
478 points_order.swap(start, store - 1);
479 store - 1
480}
481
/// Returns coordinate `axis` of point `point_index`: `0` selects `x`, `1` selects `y`, and
/// any other value selects `z`.
fn coordinate(x: &[f32], y: &[f32], z: &[f32], point_index: u32, axis: u8) -> f32 {
    let column = match axis {
        0 => x,
        1 => y,
        _ => z,
    };
    column[point_index as usize]
}
489
/// Squared Euclidean distance between `(px, py, pz)` and `(qx, qy, qz)`.
///
/// Summed as `dx² + dy² + dz²` in that order, without fused multiply-add, so results are
/// bit-reproducible across the searches that compare them.
fn squared_distance(px: f32, py: f32, pz: f32, qx: f32, qy: f32, qz: f32) -> f32 {
    let [dx, dy, dz] = [px - qx, py - qy, pz - qz];
    dx * dx + dy * dy + dz * dz
}
496
/// Bounded collector for the best `k` candidates of a nearest-neighbor search.
///
/// Candidates are stored unsorted in the borrowed `neighbors` vector. The current worst
/// candidate is tracked so that a closer newcomer can replace it in place.
#[derive(Debug)]
struct KnnAccumulator<'a> {
    /// Output buffer; holds at most the `k` passed to `insert`.
    neighbors: &'a mut Vec<Neighbor>,
    /// Position in `neighbors` of the candidate with the largest squared distance.
    worst_index: usize,
    /// Squared distance of that worst candidate; meaningless while `neighbors` is empty.
    worst_distance_squared: f32,
}
503
504impl<'a> KnnAccumulator<'a> {
505 fn new(neighbors: &'a mut Vec<Neighbor>) -> Self {
506 Self { neighbors, worst_index: 0, worst_distance_squared: 0.0 }
507 }
508
509 fn len(&self) -> usize {
510 self.neighbors.len()
511 }
512
513 fn prune_distance(&self, k: usize) -> f32 {
514 if self.neighbors.len() < k {
515 f32::INFINITY
516 } else {
517 self.worst_distance_squared
518 }
519 }
520
521 fn insert(&mut self, k: usize, candidate: Neighbor) {
522 if k == 0 {
523 return;
524 }
525
526 if self.neighbors.len() < k {
527 let distance_squared = candidate.distance_squared;
528 self.neighbors.push(candidate);
529 if self.neighbors.len() == 1 || distance_squared > self.worst_distance_squared {
530 self.worst_index = self.neighbors.len() - 1;
531 self.worst_distance_squared = distance_squared;
532 }
533 return;
534 }
535
536 if candidate.distance_squared >= self.worst_distance_squared {
537 return;
538 }
539
540 self.neighbors[self.worst_index] = candidate;
541 self.refresh_worst();
542 }
543
544 fn refresh_worst(&mut self) {
545 let mut worst_index = 0usize;
546 let mut worst_distance_squared = self.neighbors[0].distance_squared;
547 for (index, neighbor) in self.neighbors.iter().enumerate().skip(1) {
548 if neighbor.distance_squared > worst_distance_squared {
549 worst_index = index;
550 worst_distance_squared = neighbor.distance_squared;
551 }
552 }
553 self.worst_index = worst_index;
554 self.worst_distance_squared = worst_distance_squared;
555 }
556}
557
#[cfg(test)]
mod tests {
    use super::KdTree;
    use crate::{
        brute::{brute_force_knn, brute_force_radius, BruteForceIndex},
        NearestNeighborIndex, RadiusSearchIndex, SpatialIndex,
    };
    use spatialrust_core::{PointCloudBuilder, StandardSchemas};

    fn sample_cloud() -> (Vec<f32>, Vec<f32>, Vec<f32>) {
        (
            vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0],
            vec![0.0, 0.0, 0.0, 1.0, 2.0, 0.0],
            vec![0.0, 0.0, 0.0, 0.0, 0.0, 5.0],
        )
    }

    #[test]
    fn nearest_one_matches_brute_force() {
        let (x, y, z) = sample_cloud();
        let tree = KdTree::from_slices(&x, &y, &z);
        let brute = BruteForceIndex::from_slices(&x, &y, &z);

        let query = (2.1_f32, 0.0, 0.0);
        assert_eq!(
            tree.nearest_one(query.0, query.1, query.2),
            brute.nearest_one(query.0, query.1, query.2)
        );
    }

    #[test]
    fn nearest_k_matches_brute_force() {
        let (x, y, z) = sample_cloud();
        let tree = KdTree::from_slices(&x, &y, &z);
        let expected = brute_force_knn(&x, &y, &z, 1.0, 0.0, 0.0, 3);
        let actual = tree.nearest_k(1.0, 0.0, 0.0, 3);
        assert_eq!(actual, expected);
    }

    #[test]
    fn radius_search_matches_brute_force() {
        let (x, y, z) = sample_cloud();
        let tree = KdTree::from_slices(&x, &y, &z);
        let mut expected = brute_force_radius(&x, &y, &z, 2.0, 0.0, 0.0, 1.5);
        let mut actual = tree.radius_search(2.0, 0.0, 0.0, 1.5);
        expected.sort_by_key(|n| n.index);
        actual.sort_by_key(|n| n.index);
        assert_eq!(actual, expected);
    }

    #[test]
    fn radius_reaches_matches_radius_search_count() {
        let (x, y, z) = sample_cloud();
        let tree = KdTree::from_slices(&x, &y, &z);
        let count = tree.radius_search(2.0, 0.0, 0.0, 1.5).len();
        assert!(tree.radius_reaches(2.0, 0.0, 0.0, 1.5, count));
        assert!(!tree.radius_reaches(2.0, 0.0, 0.0, 1.5, count + 1));
        assert!(tree.radius_reaches(2.0, 0.0, 0.0, 1.5, 0));
    }

    #[test]
    fn radius_reaches_matches_brute_force_on_many_queries() {
        let mut state = 0x8765_4321_u32;
        let mut next = || {
            state = state.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
            (state as f32 / u32::MAX as f32) * 20.0 - 10.0
        };

        let mut x = Vec::new();
        let mut y = Vec::new();
        let mut z = Vec::new();
        for _ in 0..257 {
            x.push(next());
            y.push(next());
            z.push(next());
        }

        let tree = KdTree::from_slices(&x, &y, &z);
        for radius in [0.5_f32, 2.0, 5.0] {
            for _ in 0..32 {
                let qx = next();
                let qy = next();
                let qz = next();
                let count = brute_force_radius(&x, &y, &z, qx, qy, qz, radius).len();
                assert!(tree.radius_reaches(qx, qy, qz, radius, 0));
                if count > 0 {
                    assert!(tree.radius_reaches(qx, qy, qz, radius, count));
                }
                assert!(!tree.radius_reaches(qx, qy, qz, radius, count + 1));
            }
        }
    }

    #[test]
    fn nearest_k_matches_brute_force_on_many_queries() {
        let mut state = 0x1234_5678_u32;
        let mut next = || {
            state = state.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
            (state as f32 / u32::MAX as f32) * 20.0 - 10.0
        };

        let mut x = Vec::new();
        let mut y = Vec::new();
        let mut z = Vec::new();
        for _ in 0..257 {
            x.push(next());
            y.push(next());
            z.push(next());
        }

        let tree = KdTree::from_slices(&x, &y, &z);
        for k in [1_usize, 2, 5, 10, 33] {
            for _ in 0..64 {
                let qx = next();
                let qy = next();
                let qz = next();
                let actual = tree.nearest_k(qx, qy, qz, k);
                let expected = brute_force_knn(&x, &y, &z, qx, qy, qz, k);
                assert_eq!(actual, expected, "k={k}, query=({qx}, {qy}, {qz})");
            }
        }
    }

    #[test]
    fn builds_from_point_cloud() {
        let mut builder = PointCloudBuilder::new(StandardSchemas::point_xyz());
        builder.push_point([0.0, 0.0, 0.0]).unwrap();
        builder.push_point([1.0, 0.0, 0.0]).unwrap();
        let cloud = builder.build().unwrap();
        let tree = KdTree::from_point_cloud(&cloud).unwrap();
        assert_eq!(tree.len(), 2);
        let nearest = tree.nearest_one(0.9, 0.0, 0.0).unwrap();
        assert_eq!(nearest.index, 1);
    }

    #[test]
    fn degenerate_points_return_valid_neighbor() {
        let x = vec![1.0, 1.0, 1.0];
        let y = vec![2.0, 2.0, 2.0];
        let z = vec![3.0, 3.0, 3.0];
        let tree = KdTree::from_slices(&x, &y, &z);
        let neighbor = tree.nearest_one(1.0, 2.0, 2.0).unwrap();
        assert_eq!(neighbor.distance_squared, 1.0);
    }

    #[test]
    fn empty_tree_answers_every_query_with_nothing() {
        let tree = KdTree::from_slices(&[], &[], &[]);
        assert_eq!(tree.len(), 0);
        assert_eq!(tree.nearest_one(0.0, 0.0, 0.0), None);
        assert!(tree.nearest_k(0.0, 0.0, 0.0, 3).is_empty());
        assert!(tree.radius_search(0.0, 0.0, 0.0, 1.0).is_empty());
        // A target of zero is trivially reached, even without points.
        assert!(tree.radius_reaches(0.0, 0.0, 0.0, 1.0, 0));
        assert!(!tree.radius_reaches(0.0, 0.0, 0.0, 1.0, 1));
    }

    #[test]
    fn negative_radius_finds_nothing() {
        let (x, y, z) = sample_cloud();
        let tree = KdTree::from_slices(&x, &y, &z);
        assert!(tree.radius_search(2.0, 0.0, 0.0, -1.0).is_empty());
        assert!(!tree.radius_reaches(2.0, 0.0, 0.0, -1.0, 1));
    }

    #[test]
    fn k_zero_or_larger_than_cloud() {
        let (x, y, z) = sample_cloud();
        let tree = KdTree::from_slices(&x, &y, &z);
        assert!(tree.nearest_k(0.0, 0.0, 0.0, 0).is_empty());
        let all = tree.nearest_k(0.0, 0.0, 0.0, 100);
        assert_eq!(all.len(), x.len());
        assert!(all.windows(2).all(|pair| pair[0].distance_squared <= pair[1].distance_squared));
    }

    #[test]
    fn identical_points_are_all_found() {
        let n = 1_000;
        let x = vec![1.0_f32; n];
        let y = vec![2.0_f32; n];
        let z = vec![3.0_f32; n];
        let tree = KdTree::from_slices(&x, &y, &z);
        assert_eq!(tree.len(), n);
        assert_eq!(tree.radius_search(1.0, 2.0, 3.0, 0.0).len(), n);
        assert!(tree.radius_reaches(1.0, 2.0, 3.0, 0.0, n));
        assert!(!tree.radius_reaches(1.0, 2.0, 3.0, 0.0, n + 1));
        let nearest = tree.nearest_k(0.0, 2.0, 3.0, 7);
        assert_eq!(nearest.len(), 7);
        assert!(nearest.iter().all(|neighbor| neighbor.distance_squared == 1.0));
    }

    #[test]
    fn quantized_cloud_matches_brute_force() {
        // Coordinates on a coarse integer grid repeat heavily. This exercises construction and
        // search with many equal split values. Distances are exact in f32 here, and ties
        // make neighbor identity ambiguous, so k-NN is compared on distances only.
        let mut x = Vec::new();
        let mut y = Vec::new();
        let mut z = Vec::new();
        for i in 0..600_u32 {
            x.push((i % 7) as f32);
            y.push((i / 7 % 5) as f32);
            z.push((i % 3) as f32);
        }

        let tree = KdTree::from_slices(&x, &y, &z);
        for (qx, qy, qz) in [(3.5_f32, 2.25_f32, 0.5_f32), (0.0, 0.0, 0.0), (6.25, 4.5, 2.0)] {
            for k in [1_usize, 4, 25, 90] {
                let actual: Vec<f32> =
                    tree.nearest_k(qx, qy, qz, k).iter().map(|n| n.distance_squared).collect();
                let expected: Vec<f32> = brute_force_knn(&x, &y, &z, qx, qy, qz, k)
                    .iter()
                    .map(|n| n.distance_squared)
                    .collect();
                assert_eq!(actual, expected, "k={k}, query=({qx}, {qy}, {qz})");
            }

            let mut actual = tree.radius_search(qx, qy, qz, 1.3);
            let mut expected = brute_force_radius(&x, &y, &z, qx, qy, qz, 1.3);
            actual.sort_by_key(|n| n.index);
            expected.sort_by_key(|n| n.index);
            assert_eq!(actual, expected, "query=({qx}, {qy}, {qz})");
        }
    }
}