spatialrust_segmentation/
region_growing.rs1use std::collections::VecDeque;
2
3use spatialrust_core::{
4 HasNormals3, HasPositions3, PointBuffer, PointCloud, SpatialError, SpatialResult,
5};
6use spatialrust_search::{KdTree, NearestNeighborIndex};
7
8use crate::cloud::with_labels;
9use crate::segmenter::PointCloudSegmenter;
10
/// Tuning parameters for [`RegionGrowingSegmenter`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RegionGrowingConfig {
    /// Number of nearest neighbours examined around each point that grows a region.
    pub k_neighbors: usize,
    /// Largest angle, in radians, allowed between the normals of a point and its
    /// neighbour for the neighbour to join the same region.
    pub smoothness_threshold: f32,
    /// Curvature at or below which a point that joined a region is itself used to
    /// grow the region further (only consulted when the cloud has a `curvature`
    /// field).
    pub curvature_threshold: f32,
    /// Smallest region (in points) that is kept and labelled.
    pub min_cluster_size: usize,
    /// Largest region (in points) that is kept and labelled.
    pub max_cluster_size: usize,
}
28
29impl Default for RegionGrowingConfig {
30 fn default() -> Self {
31 Self {
32 k_neighbors: 30,
33 smoothness_threshold: 0.052_359_88,
35 curvature_threshold: 1.0,
36 min_cluster_size: 1,
37 max_cluster_size: usize::MAX,
38 }
39 }
40}
41
42impl RegionGrowingConfig {
43 #[must_use]
45 pub const fn with_smoothness(smoothness_threshold: f32, k_neighbors: usize) -> Self {
46 Self {
47 k_neighbors,
48 smoothness_threshold,
49 curvature_threshold: 1.0,
50 min_cluster_size: 1,
51 max_cluster_size: usize::MAX,
52 }
53 }
54}
55
56#[derive(Clone, Debug, PartialEq)]
58pub struct RegionGrowingResult {
59 pub cloud: PointCloud,
61 pub cluster_count: usize,
63 pub cluster_sizes: Vec<usize>,
65}
66
67#[derive(Clone, Copy, Debug, PartialEq)]
74pub struct RegionGrowingSegmenter {
75 config: RegionGrowingConfig,
76}
77
78impl RegionGrowingSegmenter {
79 #[must_use]
81 pub const fn new(config: RegionGrowingConfig) -> Self {
82 Self { config }
83 }
84
85 #[must_use]
87 pub const fn config(&self) -> RegionGrowingConfig {
88 self.config
89 }
90
91 pub fn segment(&self, input: &PointCloud) -> SpatialResult<RegionGrowingResult> {
93 if input.is_empty() {
94 return Ok(RegionGrowingResult {
95 cloud: input.clone(),
96 cluster_count: 0,
97 cluster_sizes: Vec::new(),
98 });
99 }
100 if self.config.k_neighbors == 0 {
101 return Err(SpatialError::InvalidArgument(
102 "k_neighbors must be greater than zero".to_owned(),
103 ));
104 }
105 if self.config.smoothness_threshold < 0.0 {
106 return Err(SpatialError::InvalidArgument(
107 "smoothness_threshold must be non-negative".to_owned(),
108 ));
109 }
110 if self.config.min_cluster_size == 0 {
111 return Err(SpatialError::InvalidArgument(
112 "min_cluster_size must be greater than zero".to_owned(),
113 ));
114 }
115 if self.config.max_cluster_size < self.config.min_cluster_size {
116 return Err(SpatialError::InvalidArgument(
117 "max_cluster_size must be >= min_cluster_size".to_owned(),
118 ));
119 }
120
121 let (x, y, z) = input.positions3()?;
122 let (nx, ny, nz) = input.normals3()?;
123 let curvature = match input.field("curvature") {
124 Ok(PointBuffer::F32(values)) => Some(values.as_slice()),
125 _ => None,
126 };
127
128 let len = input.len();
129 let tree = KdTree::from_slices(x, y, z);
130 let cos_threshold = self.config.smoothness_threshold.cos();
131
132 let mut order: Vec<usize> = (0..len).collect();
134 if let Some(curv) = curvature {
135 order.sort_by(|&a, &b| curv[a].total_cmp(&curv[b]));
136 }
137
138 let mut processed = vec![false; len];
139 let mut labels = vec![-1_i32; len];
140 let mut cluster_sizes = Vec::new();
141 let mut cluster_id = 0_i32;
142
143 for &start in &order {
144 if processed[start] {
145 continue;
146 }
147
148 let mut seeds = VecDeque::from([start]);
149 let mut region = vec![start];
151 processed[start] = true;
152
153 while let Some(current) = seeds.pop_front() {
154 let neighbors =
155 tree.nearest_k(x[current], y[current], z[current], self.config.k_neighbors + 1);
156 for neighbor in neighbors {
157 let candidate = neighbor.index;
158 if candidate == current || processed[candidate] {
159 continue;
160 }
161 let dot = (nx[current] * nx[candidate]
163 + ny[current] * ny[candidate]
164 + nz[current] * nz[candidate])
165 .abs();
166 if dot < cos_threshold {
167 continue;
168 }
169 processed[candidate] = true;
170 region.push(candidate);
171 let flat = match curvature {
173 Some(curv) => curv[candidate] <= self.config.curvature_threshold,
174 None => true,
175 };
176 if flat {
177 seeds.push_back(candidate);
178 }
179 }
180 }
181
182 if region.len() >= self.config.min_cluster_size
183 && region.len() <= self.config.max_cluster_size
184 {
185 for index in ®ion {
186 labels[*index] = cluster_id;
187 }
188 cluster_sizes.push(region.len());
189 cluster_id += 1;
190 }
191 }
192
193 Ok(RegionGrowingResult {
194 cloud: with_labels(input, "label", labels)?,
195 cluster_count: cluster_sizes.len(),
196 cluster_sizes,
197 })
198 }
199}
200
201impl PointCloudSegmenter for RegionGrowingSegmenter {
202 fn name(&self) -> &'static str {
203 "RegionGrowingSegmenter"
204 }
205}
206
#[cfg(test)]
mod tests {
    use super::{RegionGrowingConfig, RegionGrowingSegmenter};
    use spatialrust_core::{DType, FieldSemantic, PointCloudBuilder, PointField, PointSchema};

    /// Schema with `f32` xyz positions followed by `f32` xyz normals.
    fn schema_with_normals() -> PointSchema {
        PointSchema::new()
            .with_field(PointField::scalar("x", FieldSemantic::PositionX, DType::F32))
            .with_field(PointField::scalar("y", FieldSemantic::PositionY, DType::F32))
            .with_field(PointField::scalar("z", FieldSemantic::PositionZ, DType::F32))
            .with_field(PointField::scalar("normal_x", FieldSemantic::NormalX, DType::F32))
            .with_field(PointField::scalar("normal_y", FieldSemantic::NormalY, DType::F32))
            .with_field(PointField::scalar("normal_z", FieldSemantic::NormalZ, DType::F32))
    }

    /// A 5x5 floor grid at z = 0 (normal +z) plus a 5x5 wall grid in the y = 0
    /// plane at z = 1..=5 (normal +y): 50 points on two perpendicular faces.
    fn floor_and_wall() -> spatialrust_core::PointCloud {
        let mut builder = PointCloudBuilder::new(schema_with_normals());
        for i in 0..5 {
            for j in 0..5 {
                let (xf, yf) = (i as f32, j as f32);
                builder.push_point([xf, yf, 0.0, 0.0, 0.0, 1.0]).unwrap();
                builder.push_point([xf, 0.0, yf + 1.0, 0.0, 1.0, 0.0]).unwrap();
            }
        }
        builder.build().unwrap()
    }

    fn ten_degrees_k8() -> RegionGrowingConfig {
        RegionGrowingConfig::with_smoothness(10.0_f32.to_radians(), 8)
    }

    #[test]
    fn perpendicular_faces_split_into_two_regions() {
        let input = floor_and_wall();
        let segmenter = RegionGrowingSegmenter::new(ten_degrees_k8());
        let result = segmenter.segment(&input).unwrap();
        assert_eq!(result.cluster_count, 2);
        assert_eq!(result.cluster_sizes, vec![25, 25]);
        assert!(result.cloud.field("label").is_ok());
    }

    #[test]
    fn coplanar_points_form_single_region() {
        let mut builder = PointCloudBuilder::new(schema_with_normals());
        for i in 0..5 {
            for j in 0..5 {
                builder.push_point([i as f32, j as f32, 0.0, 0.0, 0.0, 1.0]).unwrap();
            }
        }
        let input = builder.build().unwrap();

        let segmenter = RegionGrowingSegmenter::new(ten_degrees_k8());
        let result = segmenter.segment(&input).unwrap();
        assert_eq!(result.cluster_count, 1);
        assert_eq!(result.cluster_sizes, vec![25]);
    }

    #[test]
    fn regions_outside_size_bounds_are_dropped() {
        let input = floor_and_wall();
        // Both faces have 25 points, so a minimum of 26 rejects every region.
        let config = RegionGrowingConfig {
            min_cluster_size: 26,
            ..ten_degrees_k8()
        };
        let result = RegionGrowingSegmenter::new(config).segment(&input).unwrap();
        assert_eq!(result.cluster_count, 0);
        assert!(result.cluster_sizes.is_empty());
        assert!(result.cloud.field("label").is_ok());
    }

    #[test]
    fn rejects_zero_k_neighbors() {
        let input = floor_and_wall();
        let config = RegionGrowingConfig {
            k_neighbors: 0,
            ..RegionGrowingConfig::default()
        };
        assert!(RegionGrowingSegmenter::new(config).segment(&input).is_err());
    }

    #[test]
    fn rejects_negative_smoothness_threshold() {
        let input = floor_and_wall();
        let config = RegionGrowingConfig {
            smoothness_threshold: -0.1,
            ..RegionGrowingConfig::default()
        };
        assert!(RegionGrowingSegmenter::new(config).segment(&input).is_err());
    }

    #[test]
    fn rejects_max_cluster_size_below_min() {
        let input = floor_and_wall();
        let config = RegionGrowingConfig {
            min_cluster_size: 5,
            max_cluster_size: 4,
            ..RegionGrowingConfig::default()
        };
        assert!(RegionGrowingSegmenter::new(config).segment(&input).is_err());
    }

    #[test]
    fn requires_normals() {
        let mut builder = PointCloudBuilder::xyz();
        builder.push_point([0.0, 0.0, 0.0]).unwrap();
        let input = builder.build().unwrap();
        let segmenter = RegionGrowingSegmenter::new(RegionGrowingConfig::default());
        assert!(segmenter.segment(&input).is_err());
    }
}