Skip to main content

spatialrust_search/
kdtree.rs

1use std::cmp::Ordering;
2
3use spatialrust_core::{HasPositions3, PointCloud, SpatialResult};
4
5use crate::{NearestNeighborIndex, Neighbor, RadiusSearchIndex, SpatialIndex};
6
/// Maximum number of points kept in a leaf: `build_node` stops splitting once
/// a range holds this many points or fewer.
const LEAF_SIZE: usize = 16;
/// Capacity of the fixed-size node stack used by `radius_count_iterative`.
const SEARCH_STACK_SIZE: usize = 64;
/// Sentinel stored in `KdNode::axis` to mark a node as a leaf.
const AXIS_LEAF: u8 = u8::MAX;
/// Sentinel node index meaning "no node": the root of an empty tree and the
/// children of a leaf.
const INVALID_NODE: u32 = u32::MAX;
11
/// Cache-friendly KD-tree for 3D point clouds.
///
/// Coordinates are kept as three parallel arrays and nodes live in one flat
/// vector, referring to each other by `u32` index.
#[derive(Clone, Debug)]
pub struct KdTree {
    /// X coordinates, copied from the input and indexed by point index.
    x: Vec<f32>,
    /// Y coordinates, copied from the input and indexed by point index.
    y: Vec<f32>,
    /// Z coordinates, copied from the input and indexed by point index.
    z: Vec<f32>,
    /// Point indices, rearranged during construction so that every node covers
    /// one contiguous `start..end` range of this vector.
    points_order: Vec<u32>,
    /// Flat node storage; children are referenced by index into this vector.
    nodes: Vec<KdNode>,
    /// Index of the root node, or `INVALID_NODE` when the tree is empty.
    root: u32,
}
22
/// One node of the tree; interior nodes and leaves share this layout.
#[derive(Clone, Copy, Debug)]
struct KdNode {
    /// Coordinate of the median point along `axis`; left at `0.0` for leaves.
    split: f32,
    /// Index of the child covering the lower half of the range, or
    /// `INVALID_NODE` for leaves.
    left: u32,
    /// Index of the child covering the upper half of the range, or
    /// `INVALID_NODE` for leaves.
    right: u32,
    /// First position in `points_order` covered by this node.
    start: u32,
    /// One past the last position in `points_order` covered by this node.
    end: u32,
    /// Split axis (0 = x, 1 = y, anything else = z), or `AXIS_LEAF` for a leaf.
    axis: u8,
}
32
33impl KdTree {
34    /// Builds a KD-tree from coordinate slices.
35    #[must_use]
36    pub fn from_slices(x: &[f32], y: &[f32], z: &[f32]) -> Self {
37        assert_eq!(x.len(), y.len());
38        assert_eq!(x.len(), z.len());
39
40        let len = x.len();
41        let mut points_order: Vec<u32> = (0..len as u32).collect();
42        let mut nodes = Vec::with_capacity(if len == 0 { 0 } else { len.div_ceil(LEAF_SIZE) * 2 });
43
44        let root = if len == 0 {
45            INVALID_NODE
46        } else {
47            build_node(x, y, z, &mut points_order, 0, len, &mut nodes)
48        };
49
50        Self { x: x.to_vec(), y: y.to_vec(), z: z.to_vec(), points_order, nodes, root }
51    }
52
53    /// Builds a KD-tree from any point cloud with XYZ positions.
54    pub fn from_point_cloud(cloud: &PointCloud) -> SpatialResult<Self> {
55        let (x, y, z) = cloud.positions3()?;
56        Ok(Self::from_slices(x, y, z))
57    }
58
59    fn point(&self, point_index: u32) -> (f32, f32, f32) {
60        let idx = point_index as usize;
61        (self.x[idx], self.y[idx], self.z[idx])
62    }
63
64    fn ordered_point(&self, order_index: u32) -> (u32, f32, f32, f32) {
65        let point_index = self.points_order[order_index as usize];
66        let (x, y, z) = self.point(point_index);
67        (point_index, x, y, z)
68    }
69
70    fn nearest_k_recursive(
71        &self,
72        node: u32,
73        qx: f32,
74        qy: f32,
75        qz: f32,
76        k: usize,
77        best: &mut KnnAccumulator,
78    ) {
79        if node == INVALID_NODE {
80            return;
81        }
82
83        let node_data = self.nodes[node as usize];
84        if node_data.axis == AXIS_LEAF {
85            let start = node_data.start as usize;
86            let end = node_data.end as usize;
87            for order_index in start..end {
88                let (index, px, py, pz) = self.ordered_point(order_index as u32);
89                best.insert(
90                    k,
91                    Neighbor {
92                        index: index as usize,
93                        distance_squared: squared_distance(px, py, pz, qx, qy, qz),
94                    },
95                );
96            }
97            return;
98        }
99
100        let diff = match node_data.axis {
101            0 => qx - node_data.split,
102            1 => qy - node_data.split,
103            _ => qz - node_data.split,
104        };
105
106        let (near, far) = if diff <= 0.0 {
107            (node_data.left, node_data.right)
108        } else {
109            (node_data.right, node_data.left)
110        };
111
112        self.nearest_k_recursive(near, qx, qy, qz, k, best);
113
114        let worst = best.prune_distance(k);
115        if diff * diff < worst || best.len() < k {
116            self.nearest_k_recursive(far, qx, qy, qz, k, best);
117        }
118    }
119
120    fn radius_recursive(
121        &self,
122        node: u32,
123        qx: f32,
124        qy: f32,
125        qz: f32,
126        radius_sq: f32,
127        out: &mut Vec<Neighbor>,
128    ) {
129        if node == INVALID_NODE {
130            return;
131        }
132
133        let node_data = self.nodes[node as usize];
134        if node_data.axis == AXIS_LEAF {
135            let start = node_data.start as usize;
136            let end = node_data.end as usize;
137            for order_index in start..end {
138                let (index, px, py, pz) = self.ordered_point(order_index as u32);
139                let distance_squared = squared_distance(px, py, pz, qx, qy, qz);
140                if distance_squared <= radius_sq {
141                    out.push(Neighbor { index: index as usize, distance_squared });
142                }
143            }
144            return;
145        }
146
147        let diff = match node_data.axis {
148            0 => qx - node_data.split,
149            1 => qy - node_data.split,
150            _ => qz - node_data.split,
151        };
152
153        let (near, far) = if diff <= 0.0 {
154            (node_data.left, node_data.right)
155        } else {
156            (node_data.right, node_data.left)
157        };
158
159        self.radius_recursive(near, qx, qy, qz, radius_sq, out);
160        if diff * diff <= radius_sq {
161            self.radius_recursive(far, qx, qy, qz, radius_sq, out);
162        }
163    }
164
165    /// Returns whether at least `target` points lie within `radius` of the
166    /// query, stopping as soon as the threshold is reached. Unlike
167    /// [`radius_search`](RadiusSearchIndex::radius_search) this allocates nothing
168    /// and early-exits, which is much faster for density tests (outlier removal).
169    #[must_use]
170    pub fn radius_reaches(&self, x: f32, y: f32, z: f32, radius: f32, target: usize) -> bool {
171        if target == 0 {
172            return true;
173        }
174        if self.is_empty() || radius < 0.0 {
175            return false;
176        }
177        self.radius_count_iterative(x, y, z, radius * radius, target)
178    }
179
180    fn radius_count_iterative(
181        &self,
182        qx: f32,
183        qy: f32,
184        qz: f32,
185        radius_sq: f32,
186        target: usize,
187    ) -> bool {
188        let mut count = 0usize;
189        let mut stack = [INVALID_NODE; SEARCH_STACK_SIZE];
190        let mut stack_len = 0usize;
191        let mut node = self.root;
192
193        loop {
194            while node != INVALID_NODE {
195                let node_data = self.nodes[node as usize];
196                if node_data.axis == AXIS_LEAF {
197                    let start = node_data.start as usize;
198                    let end = node_data.end as usize;
199                    for order_index in start..end {
200                        let (_, px, py, pz) = self.ordered_point(order_index as u32);
201                        if squared_distance(px, py, pz, qx, qy, qz) <= radius_sq {
202                            count += 1;
203                            if count >= target {
204                                return true;
205                            }
206                        }
207                    }
208                    break;
209                }
210
211                let diff = match node_data.axis {
212                    0 => qx - node_data.split,
213                    1 => qy - node_data.split,
214                    _ => qz - node_data.split,
215                };
216                let (near, far) = if diff <= 0.0 {
217                    (node_data.left, node_data.right)
218                } else {
219                    (node_data.right, node_data.left)
220                };
221
222                if diff * diff <= radius_sq {
223                    if stack_len < stack.len() {
224                        stack[stack_len] = far;
225                        stack_len += 1;
226                    } else if self
227                        .radius_count_recursive(far, qx, qy, qz, radius_sq, target, &mut count)
228                    {
229                        return true;
230                    }
231                }
232                node = near;
233            }
234
235            if stack_len == 0 {
236                return false;
237            }
238            stack_len -= 1;
239            node = stack[stack_len];
240        }
241    }
242
243    /// Accumulates points within `radius_sq` into `count`; returns `true` as soon
244    /// as `count` reaches `target` so the search can short-circuit.
245    fn radius_count_recursive(
246        &self,
247        node: u32,
248        qx: f32,
249        qy: f32,
250        qz: f32,
251        radius_sq: f32,
252        target: usize,
253        count: &mut usize,
254    ) -> bool {
255        if node == INVALID_NODE {
256            return false;
257        }
258
259        let node_data = self.nodes[node as usize];
260        if node_data.axis == AXIS_LEAF {
261            let start = node_data.start as usize;
262            let end = node_data.end as usize;
263            for order_index in start..end {
264                let (_, px, py, pz) = self.ordered_point(order_index as u32);
265                if squared_distance(px, py, pz, qx, qy, qz) <= radius_sq {
266                    *count += 1;
267                    if *count >= target {
268                        return true;
269                    }
270                }
271            }
272            return false;
273        }
274
275        let diff = match node_data.axis {
276            0 => qx - node_data.split,
277            1 => qy - node_data.split,
278            _ => qz - node_data.split,
279        };
280        let (near, far) = if diff <= 0.0 {
281            (node_data.left, node_data.right)
282        } else {
283            (node_data.right, node_data.left)
284        };
285
286        if self.radius_count_recursive(near, qx, qy, qz, radius_sq, target, count) {
287            return true;
288        }
289        if diff * diff <= radius_sq {
290            return self.radius_count_recursive(far, qx, qy, qz, radius_sq, target, count);
291        }
292        false
293    }
294}
295
296impl SpatialIndex for KdTree {
297    fn len(&self) -> usize {
298        self.x.len()
299    }
300}
301
302impl NearestNeighborIndex for KdTree {
303    fn nearest_one(&self, x: f32, y: f32, z: f32) -> Option<Neighbor> {
304        self.nearest_k(x, y, z, 1).into_iter().next()
305    }
306
307    fn nearest_k(&self, x: f32, y: f32, z: f32, k: usize) -> Vec<Neighbor> {
308        let mut best = Vec::with_capacity(k.min(self.len()));
309        self.nearest_k_into(x, y, z, k, &mut best);
310        best
311    }
312}
313
314impl KdTree {
315    /// Finds up to `k` nearest neighbors sorted by ascending distance, reusing
316    /// the caller-provided output buffer.
317    pub fn nearest_k_into(&self, x: f32, y: f32, z: f32, k: usize, out: &mut Vec<Neighbor>) {
318        self.nearest_k_unsorted_into(x, y, z, k, out);
319        out.sort_by(|a, b| {
320            a.distance_squared.partial_cmp(&b.distance_squared).unwrap_or(Ordering::Equal)
321        });
322    }
323
324    /// Finds up to `k` nearest neighbors without sorting the result, reusing the
325    /// caller-provided output buffer. This is faster for callers that only need
326    /// the neighbor set, such as covariance and mean-distance calculations.
327    pub fn nearest_k_unsorted_into(
328        &self,
329        x: f32,
330        y: f32,
331        z: f32,
332        k: usize,
333        out: &mut Vec<Neighbor>,
334    ) {
335        out.clear();
336        if self.is_empty() || k == 0 {
337            return;
338        }
339
340        out.reserve(k.min(self.len()));
341        let mut best = KnnAccumulator::new(out);
342        self.nearest_k_recursive(self.root, x, y, z, k, &mut best);
343    }
344}
345
346impl RadiusSearchIndex for KdTree {
347    fn radius_search(&self, x: f32, y: f32, z: f32, radius: f32) -> Vec<Neighbor> {
348        if self.is_empty() || radius < 0.0 {
349            return Vec::new();
350        }
351
352        let radius_sq = radius * radius;
353        let mut out = Vec::new();
354        self.radius_recursive(self.root, x, y, z, radius_sq, &mut out);
355        // Intentionally unsorted: callers count or iterate neighbors, and
356        // sorting every query dominates radius search on dense clouds.
357        out
358    }
359}
360
361#[allow(clippy::too_many_arguments)]
362fn build_node(
363    x: &[f32],
364    y: &[f32],
365    z: &[f32],
366    points_order: &mut [u32],
367    start: usize,
368    end: usize,
369    nodes: &mut Vec<KdNode>,
370) -> u32 {
371    let node_index = nodes.len() as u32;
372    nodes.push(KdNode {
373        split: 0.0,
374        left: INVALID_NODE,
375        right: INVALID_NODE,
376        start: start as u32,
377        end: end as u32,
378        axis: 0,
379    });
380
381    let count = end - start;
382    if count <= LEAF_SIZE {
383        nodes[node_index as usize].axis = AXIS_LEAF;
384        return node_index;
385    }
386
387    let axis = select_axis(x, y, z, points_order, start, end);
388    let mid = start + count / 2;
389    select_nth_by_axis(x, y, z, points_order, start, end, axis, mid);
390
391    let split_point = points_order[mid];
392    let split_value = coordinate(x, y, z, split_point, axis);
393    nodes[node_index as usize].axis = axis;
394    nodes[node_index as usize].split = split_value;
395
396    let left = build_node(x, y, z, points_order, start, mid, nodes);
397    let right = build_node(x, y, z, points_order, mid, end, nodes);
398
399    nodes[node_index as usize].left = left;
400    nodes[node_index as usize].right = right;
401    node_index
402}
403
/// Picks the split axis for `points_order[start..end]`: the axis (0 = x,
/// 1 = y, 2 = z) along which the points span the largest extent. Ties go to
/// the lowest axis number.
fn select_axis(
    x: &[f32],
    y: &[f32],
    z: &[f32],
    points_order: &[u32],
    start: usize,
    end: usize,
) -> u8 {
    let members = &points_order[start..end];

    // Extent (max - min) of the selected points along each axis.
    let extents = [x, y, z].map(|column| {
        let (low, high) = members.iter().fold(
            (f32::INFINITY, f32::NEG_INFINITY),
            |(low, high), &point_index| {
                let value = column[point_index as usize];
                (low.min(value), high.max(value))
            },
        );
        high - low
    });

    let mut widest = 0_u8;
    for axis in 1_u8..3 {
        // Strict comparison keeps the earliest axis on ties.
        if extents[axis as usize] > extents[widest as usize] {
            widest = axis;
        }
    }
    widest
}
434
/// Reorders `points_order[start..end]` so that position `nth` holds the point
/// it would hold if the range were sorted by its `axis` coordinate, with every
/// earlier position holding a coordinate no greater and every later position a
/// coordinate no smaller.
///
/// `axis` is 0 for x, 1 for y and anything else for z. Coordinates are ordered
/// with [`f32::total_cmp`], so NaN values get a deterministic place. If `nth`
/// lies outside `start..end` the range is left untouched.
///
/// This delegates to the standard library's selection routine, which stays
/// linear on average even when many points share the same coordinate; the
/// hand-rolled quickselect it replaces took quadratic time on such input.
fn select_nth_by_axis(
    x: &[f32],
    y: &[f32],
    z: &[f32],
    points_order: &mut [u32],
    start: usize,
    end: usize,
    axis: u8,
    nth: usize,
) {
    // `select_nth_unstable_by` panics on an out-of-range index; treat that
    // case as a no-op instead.
    if nth < start || nth >= end {
        return;
    }

    // Resolve the axis once rather than inside every comparison.
    let column = match axis {
        0 => x,
        1 => y,
        _ => z,
    };

    points_order[start..end].select_nth_unstable_by(nth - start, |&a, &b| {
        column[a as usize].total_cmp(&column[b as usize])
    });
}
481
/// Returns the coordinate of point `point_index` along `axis`: 0 selects `x`,
/// 1 selects `y`, and any other value selects `z`.
fn coordinate(x: &[f32], y: &[f32], z: &[f32], point_index: u32, axis: u8) -> f32 {
    let column = match axis {
        0 => x,
        1 => y,
        _ => z,
    };
    column[point_index as usize]
}
489
/// Squared Euclidean distance between `(px, py, pz)` and `(qx, qy, qz)`.
fn squared_distance(px: f32, py: f32, pz: f32, qx: f32, qy: f32, qz: f32) -> f32 {
    // Accumulate x, then y, then z so the rounding matches a plain
    // left-to-right sum of the three squares.
    [px - qx, py - qy, pz - qz]
        .iter()
        .fold(0.0_f32, |sum, delta| sum + delta * delta)
}
496
497#[derive(Debug)]
498struct KnnAccumulator<'a> {
499    neighbors: &'a mut Vec<Neighbor>,
500    worst_index: usize,
501    worst_distance_squared: f32,
502}
503
504impl<'a> KnnAccumulator<'a> {
505    fn new(neighbors: &'a mut Vec<Neighbor>) -> Self {
506        Self { neighbors, worst_index: 0, worst_distance_squared: 0.0 }
507    }
508
509    fn len(&self) -> usize {
510        self.neighbors.len()
511    }
512
513    fn prune_distance(&self, k: usize) -> f32 {
514        if self.neighbors.len() < k {
515            f32::INFINITY
516        } else {
517            self.worst_distance_squared
518        }
519    }
520
521    fn insert(&mut self, k: usize, candidate: Neighbor) {
522        if k == 0 {
523            return;
524        }
525
526        if self.neighbors.len() < k {
527            let distance_squared = candidate.distance_squared;
528            self.neighbors.push(candidate);
529            if self.neighbors.len() == 1 || distance_squared > self.worst_distance_squared {
530                self.worst_index = self.neighbors.len() - 1;
531                self.worst_distance_squared = distance_squared;
532            }
533            return;
534        }
535
536        if candidate.distance_squared >= self.worst_distance_squared {
537            return;
538        }
539
540        self.neighbors[self.worst_index] = candidate;
541        self.refresh_worst();
542    }
543
544    fn refresh_worst(&mut self) {
545        let mut worst_index = 0usize;
546        let mut worst_distance_squared = self.neighbors[0].distance_squared;
547        for (index, neighbor) in self.neighbors.iter().enumerate().skip(1) {
548            if neighbor.distance_squared > worst_distance_squared {
549                worst_index = index;
550                worst_distance_squared = neighbor.distance_squared;
551            }
552        }
553        self.worst_index = worst_index;
554        self.worst_distance_squared = worst_distance_squared;
555    }
556}
557
// Unit tests that compare the KD-tree against the brute-force helpers.
#[cfg(test)]
mod tests {
    use super::KdTree;
    use crate::{
        brute::{brute_force_knn, brute_force_radius, BruteForceIndex},
        NearestNeighborIndex, RadiusSearchIndex,
    };
    use spatialrust_core::{PointCloudBuilder, StandardSchemas};

    use crate::SpatialIndex;

    // Six hand-picked points, returned as parallel x/y/z arrays. Six is below
    // `LEAF_SIZE`, so a tree built from them is a single leaf.
    fn sample_cloud() -> (Vec<f32>, Vec<f32>, Vec<f32>) {
        (
            vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0],
            vec![0.0, 0.0, 0.0, 1.0, 2.0, 0.0],
            vec![0.0, 0.0, 0.0, 0.0, 0.0, 5.0],
        )
    }

    // `nearest_one` agrees with the brute-force index for a single query.
    #[test]
    fn nearest_one_matches_brute_force() {
        let (x, y, z) = sample_cloud();
        let tree = KdTree::from_slices(&x, &y, &z);
        let brute = BruteForceIndex::from_slices(&x, &y, &z);

        let query = (2.1_f32, 0.0, 0.0);
        assert_eq!(
            tree.nearest_one(query.0, query.1, query.2),
            brute.nearest_one(query.0, query.1, query.2)
        );
    }

    // `nearest_k` agrees with brute force, including the order of the result.
    #[test]
    fn nearest_k_matches_brute_force() {
        let (x, y, z) = sample_cloud();
        let tree = KdTree::from_slices(&x, &y, &z);
        let expected = brute_force_knn(&x, &y, &z, 1.0, 0.0, 0.0, 3);
        let actual = tree.nearest_k(1.0, 0.0, 0.0, 3);
        assert_eq!(actual, expected);
    }

    // `radius_search` returns the same set of neighbors as brute force.
    #[test]
    fn radius_search_matches_brute_force() {
        let (x, y, z) = sample_cloud();
        let tree = KdTree::from_slices(&x, &y, &z);
        let mut expected = brute_force_radius(&x, &y, &z, 2.0, 0.0, 0.0, 1.5);
        let mut actual = tree.radius_search(2.0, 0.0, 0.0, 1.5);
        // `radius_search` is unsorted, so compare as sets ordered by index.
        expected.sort_by_key(|n| n.index);
        actual.sort_by_key(|n| n.index);
        assert_eq!(actual, expected);
    }

    // `radius_reaches` is consistent with the number of `radius_search` hits.
    #[test]
    fn radius_reaches_matches_radius_search_count() {
        let (x, y, z) = sample_cloud();
        let tree = KdTree::from_slices(&x, &y, &z);
        let count = tree.radius_search(2.0, 0.0, 0.0, 1.5).len();
        // True for any target up to the real count, false beyond it.
        assert!(tree.radius_reaches(2.0, 0.0, 0.0, 1.5, count));
        assert!(!tree.radius_reaches(2.0, 0.0, 0.0, 1.5, count + 1));
        assert!(tree.radius_reaches(2.0, 0.0, 0.0, 1.5, 0));
    }

    // Randomized check of `radius_reaches` on a cloud large enough (257
    // points) to need interior nodes.
    #[test]
    fn radius_reaches_matches_brute_force_on_many_queries() {
        // Deterministic linear congruential generator mapped to [-10, 10].
        let mut state = 0x8765_4321_u32;
        let mut next = || {
            state = state.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
            (state as f32 / u32::MAX as f32) * 20.0 - 10.0
        };

        let mut x = Vec::new();
        let mut y = Vec::new();
        let mut z = Vec::new();
        for _ in 0..257 {
            x.push(next());
            y.push(next());
            z.push(next());
        }

        let tree = KdTree::from_slices(&x, &y, &z);
        for radius in [0.5_f32, 2.0, 5.0] {
            for _ in 0..32 {
                let qx = next();
                let qy = next();
                let qz = next();
                let count = brute_force_radius(&x, &y, &z, qx, qy, qz, radius).len();
                assert!(tree.radius_reaches(qx, qy, qz, radius, 0));
                if count > 0 {
                    assert!(tree.radius_reaches(qx, qy, qz, radius, count));
                }
                assert!(!tree.radius_reaches(qx, qy, qz, radius, count + 1));
            }
        }
    }

    // Randomized check of `nearest_k` for several values of `k`, including one
    // (33) larger than `LEAF_SIZE`.
    #[test]
    fn nearest_k_matches_brute_force_on_many_queries() {
        // Deterministic linear congruential generator mapped to [-10, 10].
        let mut state = 0x1234_5678_u32;
        let mut next = || {
            state = state.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
            (state as f32 / u32::MAX as f32) * 20.0 - 10.0
        };

        let mut x = Vec::new();
        let mut y = Vec::new();
        let mut z = Vec::new();
        for _ in 0..257 {
            x.push(next());
            y.push(next());
            z.push(next());
        }

        let tree = KdTree::from_slices(&x, &y, &z);
        for k in [1_usize, 2, 5, 10, 33] {
            for _ in 0..64 {
                let qx = next();
                let qy = next();
                let qz = next();
                let actual = tree.nearest_k(qx, qy, qz, k);
                let expected = brute_force_knn(&x, &y, &z, qx, qy, qz, k);
                assert_eq!(actual, expected, "k={k}, query=({qx}, {qy}, {qz})");
            }
        }
    }

    // A tree can be built from a `PointCloud` and then queried.
    #[test]
    fn builds_from_point_cloud() {
        let mut builder = PointCloudBuilder::new(StandardSchemas::point_xyz());
        builder.push_point([0.0, 0.0, 0.0]).unwrap();
        builder.push_point([1.0, 0.0, 0.0]).unwrap();
        let cloud = builder.build().unwrap();
        let tree = KdTree::from_point_cloud(&cloud).unwrap();
        assert_eq!(tree.len(), 2);
        let nearest = tree.nearest_one(0.9, 0.0, 0.0).unwrap();
        assert_eq!(nearest.index, 1);
    }

    // Identical points still produce a neighbor with the right distance.
    #[test]
    fn degenerate_points_return_valid_neighbor() {
        let x = vec![1.0, 1.0, 1.0];
        let y = vec![2.0, 2.0, 2.0];
        let z = vec![3.0, 3.0, 3.0];
        let tree = KdTree::from_slices(&x, &y, &z);
        let neighbor = tree.nearest_one(1.0, 2.0, 2.0).unwrap();
        assert_eq!(neighbor.distance_squared, 1.0);
    }
}