Skip to main content

spatialrust_search/
graph.rs

//! Neighborhood graph construction (k-NN and radius graphs).
//!
//! Turns a point cloud into the edge list that graph neural networks consume
//! (PyG-style `edge_index`): each point becomes a node, with a directed edge to
//! every neighbor. Built on the KD-tree, so it scales to large clouds.
6
7use spatialrust_core::{HasPositions3, PointCloud, SpatialError, SpatialResult};
8
9use crate::kdtree::KdTree;
10use crate::{NearestNeighborIndex, RadiusSearchIndex};
11
/// Directed neighborhood graph built from a point cloud.
///
/// Nodes are identified by their index in the source cloud (`0..num_nodes`).
/// Edges are stored as a flat list of `[source, target]` pairs, the same layout
/// as a PyG-style `edge_index`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NeighborGraph {
    /// How many nodes the graph has: one per input point.
    pub num_nodes: usize,
    /// Every directed edge as a `[source, target]` pair of node indices
    /// (no self-loops).
    pub edges: Vec<[u32; 2]>,
}

impl NeighborGraph {
    /// Returns how many directed edges the graph stores.
    #[must_use]
    pub fn num_edges(&self) -> usize {
        let Self { edges, .. } = self;
        edges.len()
    }

    /// Returns `true` when the graph stores no edges.
    ///
    /// A graph can have nodes and still be empty in this sense.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.num_edges() == 0
    }
}
34
35/// Builds a directed k-nearest-neighbor graph: an edge from every point to each
36/// of its `k` nearest neighbors (excluding itself).
37pub fn knn_graph(cloud: &PointCloud, k: usize) -> SpatialResult<NeighborGraph> {
38    if k == 0 {
39        return Err(SpatialError::InvalidArgument("k must be greater than zero".to_owned()));
40    }
41    let (x, y, z) = cloud.positions3()?;
42    let len = cloud.len();
43    if len == 0 {
44        return Ok(NeighborGraph { num_nodes: 0, edges: Vec::new() });
45    }
46
47    let tree = KdTree::from_slices(x, y, z);
48    let mut edges = Vec::with_capacity(len * k);
49    for i in 0..len {
50        // k + 1 because the point finds itself first.
51        for neighbor in tree.nearest_k(x[i], y[i], z[i], k + 1) {
52            if neighbor.index != i {
53                edges.push([i as u32, neighbor.index as u32]);
54            }
55        }
56    }
57    Ok(NeighborGraph { num_nodes: len, edges })
58}
59
60/// Builds a directed radius graph: an edge from every point to each other point
61/// within `radius`.
62pub fn radius_graph(cloud: &PointCloud, radius: f32) -> SpatialResult<NeighborGraph> {
63    if radius <= 0.0 || radius.is_nan() {
64        return Err(SpatialError::InvalidArgument("radius must be positive".to_owned()));
65    }
66    let (x, y, z) = cloud.positions3()?;
67    let len = cloud.len();
68    if len == 0 {
69        return Ok(NeighborGraph { num_nodes: 0, edges: Vec::new() });
70    }
71
72    let tree = KdTree::from_slices(x, y, z);
73    let mut edges = Vec::new();
74    for i in 0..len {
75        for neighbor in tree.radius_search(x[i], y[i], z[i], radius) {
76            if neighbor.index != i {
77                edges.push([i as u32, neighbor.index as u32]);
78            }
79        }
80    }
81    Ok(NeighborGraph { num_nodes: len, edges })
82}
83
#[cfg(test)]
mod tests {
    use super::{knn_graph, radius_graph};
    use spatialrust_core::{PointCloudBuilder, StandardSchemas};

    /// Builds `n` points spaced 1.0 apart along the x axis, starting at the origin.
    fn line(n: usize) -> spatialrust_core::PointCloud {
        let mut builder = PointCloudBuilder::new(StandardSchemas::point_xyz());
        for i in 0..n {
            builder.push_point([i as f32, 0.0, 0.0]).unwrap();
        }
        builder.build().unwrap()
    }

    #[test]
    fn knn_graph_has_k_edges_per_node() {
        let cloud = line(10);
        let graph = knn_graph(&cloud, 2).unwrap();
        assert_eq!(graph.num_nodes, 10);
        // Every node has exactly k outgoing edges.
        assert_eq!(graph.num_edges(), 10 * 2);
        // No self-loops.
        assert!(graph.edges.iter().all(|[s, t]| s != t));
    }

    #[test]
    fn knn_nearest_neighbor_is_adjacent_on_a_line() {
        let cloud = line(5);
        let graph = knn_graph(&cloud, 1).unwrap();
        // Point 0's single nearest neighbor is point 1.
        let from_0: Vec<u32> =
            graph.edges.iter().filter(|[s, _]| *s == 0).map(|[_, t]| *t).collect();
        assert_eq!(from_0, vec![1]);
    }

    #[test]
    fn knn_graph_with_k_equal_to_other_points_is_complete() {
        let cloud = line(4);
        // k = len - 1 links every point to every other point.
        let graph = knn_graph(&cloud, 3).unwrap();
        assert_eq!(graph.num_edges(), 4 * 3);
        assert!(graph.edges.iter().all(|[s, t]| s != t));
        let unique: std::collections::HashSet<[u32; 2]> = graph.edges.iter().copied().collect();
        assert_eq!(unique.len(), graph.num_edges());
    }

    #[test]
    fn radius_graph_links_points_within_radius() {
        let cloud = line(5);
        // Radius 1.5 reaches the immediate neighbors on each side (spacing 1.0).
        let graph = radius_graph(&cloud, 1.5).unwrap();
        // Interior nodes have 2 neighbors, the two ends have 1 each: 3*2 + 2*1 = 8.
        assert_eq!(graph.num_edges(), 8);
    }

    #[test]
    fn radius_graph_is_symmetric_without_self_loops() {
        let cloud = line(5);
        let graph = radius_graph(&cloud, 1.5).unwrap();
        assert!(graph.edges.iter().all(|[s, t]| s != t));
        // Distance is symmetric, so every edge must have its reverse.
        for &[s, t] in &graph.edges {
            assert!(graph.edges.contains(&[t, s]), "missing reverse of edge {s} -> {t}");
        }
    }

    #[test]
    fn rejects_bad_params() {
        let cloud = line(3);
        assert!(knn_graph(&cloud, 0).is_err());
        assert!(radius_graph(&cloud, 0.0).is_err());
        assert!(radius_graph(&cloud, -1.0).is_err());
        assert!(radius_graph(&cloud, f32::NAN).is_err());
    }
}