spatialrust_search/
graph.rs1use spatialrust_core::{HasPositions3, PointCloud, SpatialError, SpatialResult};
8
9use crate::kdtree::KdTree;
10use crate::{NearestNeighborIndex, RadiusSearchIndex};
11
/// A directed neighbour graph over the points of a point cloud.
///
/// Graphs built by [`knn_graph`] and [`radius_graph`] use the point's index in the
/// source cloud as its node id, so node ids range over `0..num_nodes`. The default
/// value is the empty graph (no nodes, no edges), the same value those builders
/// return for an empty cloud.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct NeighborGraph {
    /// Number of nodes in the graph (for the builders in this module, the number of
    /// points in the cloud the graph was built from).
    pub num_nodes: usize,
    /// Directed edges, each stored as a `[source, target]` pair of node ids.
    pub edges: Vec<[u32; 2]>,
}
20
21impl NeighborGraph {
22 #[must_use]
24 pub fn num_edges(&self) -> usize {
25 self.edges.len()
26 }
27
28 #[must_use]
30 pub fn is_empty(&self) -> bool {
31 self.edges.is_empty()
32 }
33}
34
35pub fn knn_graph(cloud: &PointCloud, k: usize) -> SpatialResult<NeighborGraph> {
38 if k == 0 {
39 return Err(SpatialError::InvalidArgument("k must be greater than zero".to_owned()));
40 }
41 let (x, y, z) = cloud.positions3()?;
42 let len = cloud.len();
43 if len == 0 {
44 return Ok(NeighborGraph { num_nodes: 0, edges: Vec::new() });
45 }
46
47 let tree = KdTree::from_slices(x, y, z);
48 let mut edges = Vec::with_capacity(len * k);
49 for i in 0..len {
50 for neighbor in tree.nearest_k(x[i], y[i], z[i], k + 1) {
52 if neighbor.index != i {
53 edges.push([i as u32, neighbor.index as u32]);
54 }
55 }
56 }
57 Ok(NeighborGraph { num_nodes: len, edges })
58}
59
60pub fn radius_graph(cloud: &PointCloud, radius: f32) -> SpatialResult<NeighborGraph> {
63 if radius <= 0.0 || radius.is_nan() {
64 return Err(SpatialError::InvalidArgument("radius must be positive".to_owned()));
65 }
66 let (x, y, z) = cloud.positions3()?;
67 let len = cloud.len();
68 if len == 0 {
69 return Ok(NeighborGraph { num_nodes: 0, edges: Vec::new() });
70 }
71
72 let tree = KdTree::from_slices(x, y, z);
73 let mut edges = Vec::new();
74 for i in 0..len {
75 for neighbor in tree.radius_search(x[i], y[i], z[i], radius) {
76 if neighbor.index != i {
77 edges.push([i as u32, neighbor.index as u32]);
78 }
79 }
80 }
81 Ok(NeighborGraph { num_nodes: len, edges })
82}
83
#[cfg(test)]
mod tests {
    use super::{knn_graph, radius_graph};
    use spatialrust_core::{PointCloudBuilder, StandardSchemas};

    /// Builds an XYZ cloud of `n` points spaced one unit apart along the x axis.
    fn line(n: usize) -> spatialrust_core::PointCloud {
        let mut builder = PointCloudBuilder::new(StandardSchemas::point_xyz());
        for i in 0..n {
            builder.push_point([i as f32, 0.0, 0.0]).unwrap();
        }
        builder.build().unwrap()
    }

    #[test]
    fn knn_graph_has_k_edges_per_node() {
        let cloud = line(10);
        let graph = knn_graph(&cloud, 2).unwrap();
        assert_eq!(graph.num_nodes, 10);
        assert_eq!(graph.num_edges(), 10 * 2);
        // Each node individually gets exactly k out-edges, not just k on average.
        for node in 0..10u32 {
            let out_degree = graph.edges.iter().filter(|[s, _]| *s == node).count();
            assert_eq!(out_degree, 2, "out-degree of node {node}");
        }
        // No self-loops.
        assert!(graph.edges.iter().all(|[s, t]| s != t));
    }

    #[test]
    fn knn_nearest_neighbor_is_adjacent_on_a_line() {
        let cloud = line(5);
        let graph = knn_graph(&cloud, 1).unwrap();
        // The single nearest neighbour of the first point is the next one along the line.
        let from_0: Vec<u32> =
            graph.edges.iter().filter(|[s, _]| *s == 0).map(|[_, t]| *t).collect();
        assert_eq!(from_0, vec![1]);
    }

    #[test]
    fn radius_graph_links_points_within_radius() {
        let cloud = line(5);
        let graph = radius_graph(&cloud, 1.5).unwrap();
        // Only the 4 adjacent pairs lie within 1.5, each linked in both directions.
        assert_eq!(graph.num_edges(), 8);
        assert!(graph.edges.iter().all(|[s, t]| s.abs_diff(*t) == 1));
        assert!(graph.edges.iter().all(|&[s, t]| graph.edges.contains(&[t, s])));
    }

    #[test]
    fn rejects_bad_params() {
        let cloud = line(3);
        assert!(knn_graph(&cloud, 0).is_err());
        assert!(radius_graph(&cloud, 0.0).is_err());
        assert!(radius_graph(&cloud, -1.0).is_err());
        assert!(radius_graph(&cloud, f32::NAN).is_err());
    }
}