spatialrust_segmentation/
cluster.rs1use std::collections::VecDeque;
2
3use spatialrust_core::{HasPositions3, PointCloud, SpatialError, SpatialResult};
4use spatialrust_search::{KdTree, RadiusSearchIndex};
5
6use crate::cloud::with_labels;
7use crate::segmenter::PointCloudSegmenter;
8
/// Parameters controlling [`EuclideanClusterExtractor`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct EuclideanClusterConfig {
    /// Radius of the neighbour search used to connect points into the same cluster.
    pub cluster_tolerance: f32,
    /// Smallest number of points a cluster must contain to be kept (inclusive).
    pub min_cluster_size: usize,
    /// Largest number of points a cluster may contain to be kept (inclusive).
    pub max_cluster_size: usize,
}
19
20impl Default for EuclideanClusterConfig {
21 fn default() -> Self {
22 Self { cluster_tolerance: 0.02, min_cluster_size: 1, max_cluster_size: usize::MAX }
23 }
24}
25
26impl EuclideanClusterConfig {
27 #[must_use]
29 pub const fn with_tolerance(cluster_tolerance: f32, min_cluster_size: usize) -> Self {
30 Self { cluster_tolerance, min_cluster_size, max_cluster_size: usize::MAX }
31 }
32}
33
34#[derive(Clone, Debug, PartialEq)]
36pub struct EuclideanClusterResult {
37 pub cloud: PointCloud,
39 pub cluster_count: usize,
41 pub cluster_sizes: Vec<usize>,
43}
44
45#[derive(Clone, Copy, Debug, PartialEq)]
47pub struct EuclideanClusterExtractor {
48 config: EuclideanClusterConfig,
49}
50
51impl EuclideanClusterExtractor {
52 #[must_use]
54 pub const fn new(config: EuclideanClusterConfig) -> Self {
55 Self { config }
56 }
57
58 #[must_use]
60 pub const fn config(&self) -> EuclideanClusterConfig {
61 self.config
62 }
63
64 pub fn extract(&self, input: &PointCloud) -> SpatialResult<EuclideanClusterResult> {
66 if input.is_empty() {
67 return Ok(EuclideanClusterResult {
68 cloud: input.clone(),
69 cluster_count: 0,
70 cluster_sizes: Vec::new(),
71 });
72 }
73 if self.config.cluster_tolerance < 0.0 {
74 return Err(SpatialError::InvalidArgument(
75 "cluster_tolerance must be non-negative".to_owned(),
76 ));
77 }
78 if self.config.min_cluster_size == 0 {
79 return Err(SpatialError::InvalidArgument(
80 "min_cluster_size must be greater than zero".to_owned(),
81 ));
82 }
83 if self.config.max_cluster_size < self.config.min_cluster_size {
84 return Err(SpatialError::InvalidArgument(
85 "max_cluster_size must be >= min_cluster_size".to_owned(),
86 ));
87 }
88
89 let (x, y, z) = input.positions3()?;
90 let tree = KdTree::from_slices(x, y, z);
91 let len = input.len();
92 let mut processed = vec![false; len];
93 let mut labels = vec![-1_i32; len];
94 let mut cluster_sizes = Vec::new();
95 let mut cluster_id = 0_i32;
96
97 for seed in 0..len {
98 if processed[seed] {
99 continue;
100 }
101
102 let mut queue = VecDeque::from([seed]);
103 let mut cluster_indices = Vec::new();
104 processed[seed] = true;
105
106 while let Some(index) = queue.pop_front() {
107 cluster_indices.push(index);
108 let neighbors =
109 tree.radius_search(x[index], y[index], z[index], self.config.cluster_tolerance);
110 for neighbor in neighbors {
111 let candidate = neighbor.index;
112 if processed[candidate] {
113 continue;
114 }
115 processed[candidate] = true;
116 queue.push_back(candidate);
117 }
118 }
119
120 if cluster_indices.len() >= self.config.min_cluster_size
121 && cluster_indices.len() <= self.config.max_cluster_size
122 {
123 let cluster_size = cluster_indices.len();
124 for index in cluster_indices {
125 labels[index] = cluster_id;
126 }
127 cluster_sizes.push(cluster_size);
128 cluster_id += 1;
129 }
130 }
131
132 Ok(EuclideanClusterResult {
133 cloud: with_labels(input, "label", labels)?,
134 cluster_count: cluster_sizes.len(),
135 cluster_sizes,
136 })
137 }
138}
139
140impl PointCloudSegmenter for EuclideanClusterExtractor {
141 fn name(&self) -> &'static str {
142 "EuclideanClusterExtractor"
143 }
144}
145
#[cfg(test)]
mod tests {
    use super::{EuclideanClusterConfig, EuclideanClusterExtractor};
    use spatialrust_core::{PointCloudBuilder, SpatialError, StandardSchemas};

    /// Builds three 3x3 grids of unit-spaced points in the z = 0 plane, anchored at
    /// (0, 0), (10, 0) and (0, 10).
    ///
    /// Neighbours inside a grid are at most sqrt(2) apart, while separate grids are
    /// 8 units apart. A tolerance of 1.5 therefore yields three clusters of nine
    /// points each.
    fn three_clusters() -> spatialrust_core::PointCloud {
        let mut builder = PointCloudBuilder::new(StandardSchemas::point_xyz());
        for center in [(0.0, 0.0, 0.0), (10.0, 0.0, 0.0), (0.0, 10.0, 0.0)] {
            for dx in 0..3 {
                for dy in 0..3 {
                    builder
                        .push_point([center.0 + dx as f32, center.1 + dy as f32, center.2])
                        .unwrap();
                }
            }
        }
        builder.build().unwrap()
    }

    /// Shorthand for an extractor with an explicit configuration.
    fn extractor(
        cluster_tolerance: f32,
        min_cluster_size: usize,
        max_cluster_size: usize,
    ) -> EuclideanClusterExtractor {
        EuclideanClusterExtractor::new(EuclideanClusterConfig {
            cluster_tolerance,
            min_cluster_size,
            max_cluster_size,
        })
    }

    #[test]
    fn finds_three_separated_clusters() {
        let input = three_clusters();
        let result = extractor(1.5, 3, usize::MAX).extract(&input).unwrap();
        assert_eq!(result.cluster_count, 3);
        assert!(result.cluster_sizes.iter().all(|&size| size == 9));
        assert!(result.cloud.field("label").is_ok());
    }

    #[test]
    fn rejects_clusters_smaller_than_minimum() {
        let mut builder = PointCloudBuilder::new(StandardSchemas::point_xyz());
        builder.push_point([0.0, 0.0, 0.0]).unwrap();
        builder.push_point([0.1, 0.0, 0.0]).unwrap();
        let input = builder.build().unwrap();

        let result = extractor(0.5, 3, usize::MAX).extract(&input).unwrap();
        assert_eq!(result.cluster_count, 0);
    }

    #[test]
    fn rejects_clusters_larger_than_maximum() {
        let result = extractor(1.5, 3, 8).extract(&three_clusters()).unwrap();
        assert_eq!(result.cluster_count, 0);
        assert!(result.cluster_sizes.is_empty());
        assert!(result.cloud.field("label").is_ok());
    }

    #[test]
    fn cluster_size_bounds_are_inclusive() {
        let result = extractor(1.5, 9, 9).extract(&three_clusters()).unwrap();
        assert_eq!(result.cluster_count, 3);
        assert_eq!(result.cluster_sizes, vec![9, 9, 9]);
    }

    #[test]
    fn rejects_invalid_configuration() {
        let input = three_clusters();
        let bad_extractors = [
            extractor(-1.0, 1, usize::MAX),
            extractor(1.5, 0, usize::MAX),
            extractor(1.5, 5, 4),
        ];
        for bad in bad_extractors {
            assert!(matches!(
                bad.extract(&input),
                Err(SpatialError::InvalidArgument(_))
            ));
        }
    }
}