//! Normal-based region growing segmentation for point clouds
//! (`region_growing.rs` in the `spatialrust_segmentation` crate).

use std::collections::VecDeque;
use std::f32::consts::FRAC_PI_2;

use spatialrust_core::{
    HasNormals3, HasPositions3, PointBuffer, PointCloud, SpatialError, SpatialResult,
};
use spatialrust_search::{KdTree, NearestNeighborIndex};

use crate::cloud::with_labels;
use crate::segmenter::PointCloudSegmenter;

/// Tuning parameters for normal-based region growing segmentation.
///
/// Start from [`Default::default`] or [`RegionGrowingConfig::with_smoothness`]
/// and override individual fields with struct-update syntax as needed.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct RegionGrowingConfig {
    /// How many nearest neighbors are examined around each seed point.
    pub k_neighbors: usize,
    /// Largest angle, in radians, allowed between the normals of a seed and a
    /// neighbor for the neighbor to join the region.
    pub smoothness_threshold: f32,
    /// Largest curvature a point may have and still become a growth seed.
    ///
    /// Only consulted when the input cloud has an `F32` `curvature` field. Points
    /// above this value may still join a region, but growth does not continue
    /// from them (the point that opens a region always seeds it).
    pub curvature_threshold: f32,
    /// Smallest region size, in points, that is kept.
    pub min_cluster_size: usize,
    /// Largest region size, in points, that is kept.
    pub max_cluster_size: usize,
}

impl Default for RegionGrowingConfig {
    /// 30 neighbors, a ~3 degree smoothness angle (a common region-growing
    /// choice), curvature threshold 1.0, and regions of any size kept.
    fn default() -> Self {
        Self::with_smoothness(0.052_359_88, 30)
    }
}

impl RegionGrowingConfig {
    /// Builds a config from a smoothness angle (radians) and a neighbor count,
    /// leaving every other field at its [`Default`] value.
    #[must_use]
    pub const fn with_smoothness(smoothness_threshold: f32, k_neighbors: usize) -> Self {
        Self {
            smoothness_threshold,
            k_neighbors,
            curvature_threshold: 1.0,
            min_cluster_size: 1,
            max_cluster_size: usize::MAX,
        }
    }
}

56/// Result of region growing segmentation.
57#[derive(Clone, Debug, PartialEq)]
58pub struct RegionGrowingResult {
59    /// Input points annotated with region labels (`label` field, `-1` = unassigned).
60    pub cloud: PointCloud,
61    /// Number of accepted regions.
62    pub cluster_count: usize,
63    /// Size of each region in label order.
64    pub cluster_sizes: Vec<usize>,
65}
66
67/// Normal-based region growing segmenter.
68///
69/// Grows smooth regions by connecting neighboring points whose normals differ by
70/// less than [`RegionGrowingConfig::smoothness_threshold`], seeding growth from
71/// the flattest (lowest-curvature) points first. The input cloud must carry
72/// normals (e.g. from normal estimation).
73#[derive(Clone, Copy, Debug, PartialEq)]
74pub struct RegionGrowingSegmenter {
75    config: RegionGrowingConfig,
76}
77
78impl RegionGrowingSegmenter {
79    /// Creates a segmenter from config.
80    #[must_use]
81    pub const fn new(config: RegionGrowingConfig) -> Self {
82        Self { config }
83    }
84
85    /// Returns the segmenter config.
86    #[must_use]
87    pub const fn config(&self) -> RegionGrowingConfig {
88        self.config
89    }
90
91    /// Segments the input cloud into smooth regions, adding a `label` field.
92    pub fn segment(&self, input: &PointCloud) -> SpatialResult<RegionGrowingResult> {
93        if input.is_empty() {
94            return Ok(RegionGrowingResult {
95                cloud: input.clone(),
96                cluster_count: 0,
97                cluster_sizes: Vec::new(),
98            });
99        }
100        if self.config.k_neighbors == 0 {
101            return Err(SpatialError::InvalidArgument(
102                "k_neighbors must be greater than zero".to_owned(),
103            ));
104        }
105        if self.config.smoothness_threshold < 0.0 {
106            return Err(SpatialError::InvalidArgument(
107                "smoothness_threshold must be non-negative".to_owned(),
108            ));
109        }
110        if self.config.min_cluster_size == 0 {
111            return Err(SpatialError::InvalidArgument(
112                "min_cluster_size must be greater than zero".to_owned(),
113            ));
114        }
115        if self.config.max_cluster_size < self.config.min_cluster_size {
116            return Err(SpatialError::InvalidArgument(
117                "max_cluster_size must be >= min_cluster_size".to_owned(),
118            ));
119        }
120
121        let (x, y, z) = input.positions3()?;
122        let (nx, ny, nz) = input.normals3()?;
123        let curvature = match input.field("curvature") {
124            Ok(PointBuffer::F32(values)) => Some(values.as_slice()),
125            _ => None,
126        };
127
128        let len = input.len();
129        let tree = KdTree::from_slices(x, y, z);
130        let cos_threshold = self.config.smoothness_threshold.cos();
131
132        // Seed order: flattest points first when curvature is available.
133        let mut order: Vec<usize> = (0..len).collect();
134        if let Some(curv) = curvature {
135            order.sort_by(|&a, &b| curv[a].total_cmp(&curv[b]));
136        }
137
138        let mut processed = vec![false; len];
139        let mut labels = vec![-1_i32; len];
140        let mut cluster_sizes = Vec::new();
141        let mut cluster_id = 0_i32;
142
143        for &start in &order {
144            if processed[start] {
145                continue;
146            }
147
148            let mut seeds = VecDeque::from([start]);
149            // Each accepted point is pushed to `region` exactly once, at discovery.
150            let mut region = vec![start];
151            processed[start] = true;
152
153            while let Some(current) = seeds.pop_front() {
154                let neighbors =
155                    tree.nearest_k(x[current], y[current], z[current], self.config.k_neighbors + 1);
156                for neighbor in neighbors {
157                    let candidate = neighbor.index;
158                    if candidate == current || processed[candidate] {
159                        continue;
160                    }
161                    // Smoothness test: treat antiparallel normals as aligned.
162                    let dot = (nx[current] * nx[candidate]
163                        + ny[current] * ny[candidate]
164                        + nz[current] * nz[candidate])
165                        .abs();
166                    if dot < cos_threshold {
167                        continue;
168                    }
169                    processed[candidate] = true;
170                    region.push(candidate);
171                    // Flat points keep the region growing; rough points stay leaves.
172                    let flat = match curvature {
173                        Some(curv) => curv[candidate] <= self.config.curvature_threshold,
174                        None => true,
175                    };
176                    if flat {
177                        seeds.push_back(candidate);
178                    }
179                }
180            }
181
182            if region.len() >= self.config.min_cluster_size
183                && region.len() <= self.config.max_cluster_size
184            {
185                for index in &region {
186                    labels[*index] = cluster_id;
187                }
188                cluster_sizes.push(region.len());
189                cluster_id += 1;
190            }
191        }
192
193        Ok(RegionGrowingResult {
194            cloud: with_labels(input, "label", labels)?,
195            cluster_count: cluster_sizes.len(),
196            cluster_sizes,
197        })
198    }
199}
200
201impl PointCloudSegmenter for RegionGrowingSegmenter {
202    fn name(&self) -> &'static str {
203        "RegionGrowingSegmenter"
204    }
205}
206
#[cfg(test)]
mod tests {
    use super::{RegionGrowingConfig, RegionGrowingSegmenter};
    use spatialrust_core::{
        DType, FieldSemantic, PointCloud, PointCloudBuilder, PointField, PointSchema,
    };

    /// Schema with `f32` positions and normals, the minimum the segmenter needs.
    fn schema_with_normals() -> PointSchema {
        PointSchema::new()
            .with_field(PointField::scalar("x", FieldSemantic::PositionX, DType::F32))
            .with_field(PointField::scalar("y", FieldSemantic::PositionY, DType::F32))
            .with_field(PointField::scalar("z", FieldSemantic::PositionZ, DType::F32))
            .with_field(PointField::scalar("normal_x", FieldSemantic::NormalX, DType::F32))
            .with_field(PointField::scalar("normal_y", FieldSemantic::NormalY, DType::F32))
            .with_field(PointField::scalar("normal_z", FieldSemantic::NormalZ, DType::F32))
    }

    /// A floor (normal +Z) and a wall (normal +Y) meeting along the y=0 edge.
    fn floor_and_wall() -> PointCloud {
        let mut builder = PointCloudBuilder::new(schema_with_normals());
        for i in 0..5 {
            for j in 0..5 {
                let (xf, yf) = (i as f32, j as f32);
                // floor: z = 0, normal up
                builder.push_point([xf, yf, 0.0, 0.0, 0.0, 1.0]).unwrap();
                // wall: y = 0, rising in z, normal +Y
                builder.push_point([xf, 0.0, yf + 1.0, 0.0, 1.0, 0.0]).unwrap();
            }
        }
        builder.build().unwrap()
    }

    /// A 5x5 grid of coplanar points at z = 0, all with normal +Z.
    fn flat_grid() -> PointCloud {
        let mut builder = PointCloudBuilder::new(schema_with_normals());
        for i in 0..5 {
            for j in 0..5 {
                builder.push_point([i as f32, j as f32, 0.0, 0.0, 0.0, 1.0]).unwrap();
            }
        }
        builder.build().unwrap()
    }

    /// The 10-degree, 8-neighbor config shared by the geometry tests.
    fn ten_degree_config() -> RegionGrowingConfig {
        RegionGrowingConfig::with_smoothness(10.0_f32.to_radians(), 8)
    }

    #[test]
    fn perpendicular_faces_split_into_two_regions() {
        let input = floor_and_wall();
        let segmenter = RegionGrowingSegmenter::new(ten_degree_config());
        let result = segmenter.segment(&input).unwrap();
        assert_eq!(result.cluster_count, 2);
        assert!(result.cloud.field("label").is_ok());
    }

    #[test]
    fn coplanar_points_form_single_region() {
        let input = flat_grid();
        let segmenter = RegionGrowingSegmenter::new(ten_degree_config());
        let result = segmenter.segment(&input).unwrap();
        assert_eq!(result.cluster_count, 1);
        assert_eq!(result.cluster_sizes, vec![25]);
    }

    #[test]
    fn regions_outside_size_bounds_are_discarded() {
        // The flat grid forms one 25-point region; both bounds exclude it.
        let input = flat_grid();
        for config in [
            RegionGrowingConfig { min_cluster_size: 26, ..ten_degree_config() },
            RegionGrowingConfig { max_cluster_size: 24, ..ten_degree_config() },
        ] {
            let result = RegionGrowingSegmenter::new(config).segment(&input).unwrap();
            assert_eq!(result.cluster_count, 0, "config {config:?}");
            assert!(result.cluster_sizes.is_empty(), "config {config:?}");
        }
    }

    #[test]
    fn rejects_invalid_configs() {
        let input = flat_grid();
        let base = ten_degree_config();
        let invalid = [
            RegionGrowingConfig { k_neighbors: 0, ..base },
            RegionGrowingConfig { smoothness_threshold: -0.1, ..base },
            RegionGrowingConfig { min_cluster_size: 0, ..base },
            RegionGrowingConfig { min_cluster_size: 10, max_cluster_size: 5, ..base },
        ];
        for config in invalid {
            assert!(
                RegionGrowingSegmenter::new(config).segment(&input).is_err(),
                "expected {config:?} to be rejected"
            );
        }
    }

    #[test]
    fn requires_normals() {
        let mut builder = PointCloudBuilder::xyz();
        builder.push_point([0.0, 0.0, 0.0]).unwrap();
        let input = builder.build().unwrap();
        let segmenter = RegionGrowingSegmenter::new(RegionGrowingConfig::default());
        assert!(segmenter.segment(&input).is_err());
    }
}