Skip to main content

spatialrust_segmentation/
cluster.rs

1use std::collections::VecDeque;
2
3use spatialrust_core::{HasPositions3, PointCloud, SpatialError, SpatialResult};
4use spatialrust_search::{KdTree, RadiusSearchIndex};
5
6use crate::cloud::with_labels;
7use crate::segmenter::PointCloudSegmenter;
8
/// Parameters controlling Euclidean (distance-based) cluster extraction.
///
/// Two points end up in the same cluster when they are connected by a chain of
/// neighbours, each hop no longer than `cluster_tolerance`. Clusters whose point
/// count falls outside `[min_cluster_size, max_cluster_size]` are discarded.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct EuclideanClusterConfig {
    /// Neighbour search radius, in the same units as the point coordinates.
    pub cluster_tolerance: f32,
    /// Smallest point count a cluster may have and still be kept (inclusive).
    pub min_cluster_size: usize,
    /// Largest point count a cluster may have and still be kept (inclusive).
    pub max_cluster_size: usize,
}
19
20impl Default for EuclideanClusterConfig {
21    fn default() -> Self {
22        Self { cluster_tolerance: 0.02, min_cluster_size: 1, max_cluster_size: usize::MAX }
23    }
24}
25
26impl EuclideanClusterConfig {
27    /// Creates a config with the given tolerance and minimum cluster size.
28    #[must_use]
29    pub const fn with_tolerance(cluster_tolerance: f32, min_cluster_size: usize) -> Self {
30        Self { cluster_tolerance, min_cluster_size, max_cluster_size: usize::MAX }
31    }
32}
33
34/// Result of Euclidean clustering.
35#[derive(Clone, Debug, PartialEq)]
36pub struct EuclideanClusterResult {
37    /// Input points annotated with cluster labels.
38    pub cloud: PointCloud,
39    /// Number of valid clusters found.
40    pub cluster_count: usize,
41    /// Size of each cluster in label order.
42    pub cluster_sizes: Vec<usize>,
43}
44
45/// Euclidean region-growing cluster extractor.
46#[derive(Clone, Copy, Debug, PartialEq)]
47pub struct EuclideanClusterExtractor {
48    config: EuclideanClusterConfig,
49}
50
51impl EuclideanClusterExtractor {
52    /// Creates an extractor from config.
53    #[must_use]
54    pub const fn new(config: EuclideanClusterConfig) -> Self {
55        Self { config }
56    }
57
58    /// Returns the extractor config.
59    #[must_use]
60    pub const fn config(&self) -> EuclideanClusterConfig {
61        self.config
62    }
63
64    /// Clusters the input cloud and adds a `label` field.
65    pub fn extract(&self, input: &PointCloud) -> SpatialResult<EuclideanClusterResult> {
66        if input.is_empty() {
67            return Ok(EuclideanClusterResult {
68                cloud: input.clone(),
69                cluster_count: 0,
70                cluster_sizes: Vec::new(),
71            });
72        }
73        if self.config.cluster_tolerance < 0.0 {
74            return Err(SpatialError::InvalidArgument(
75                "cluster_tolerance must be non-negative".to_owned(),
76            ));
77        }
78        if self.config.min_cluster_size == 0 {
79            return Err(SpatialError::InvalidArgument(
80                "min_cluster_size must be greater than zero".to_owned(),
81            ));
82        }
83        if self.config.max_cluster_size < self.config.min_cluster_size {
84            return Err(SpatialError::InvalidArgument(
85                "max_cluster_size must be >= min_cluster_size".to_owned(),
86            ));
87        }
88
89        let (x, y, z) = input.positions3()?;
90        let tree = KdTree::from_slices(x, y, z);
91        let len = input.len();
92        let mut processed = vec![false; len];
93        let mut labels = vec![-1_i32; len];
94        let mut cluster_sizes = Vec::new();
95        let mut cluster_id = 0_i32;
96
97        for seed in 0..len {
98            if processed[seed] {
99                continue;
100            }
101
102            let mut queue = VecDeque::from([seed]);
103            let mut cluster_indices = Vec::new();
104            processed[seed] = true;
105
106            while let Some(index) = queue.pop_front() {
107                cluster_indices.push(index);
108                let neighbors =
109                    tree.radius_search(x[index], y[index], z[index], self.config.cluster_tolerance);
110                for neighbor in neighbors {
111                    let candidate = neighbor.index;
112                    if processed[candidate] {
113                        continue;
114                    }
115                    processed[candidate] = true;
116                    queue.push_back(candidate);
117                }
118            }
119
120            if cluster_indices.len() >= self.config.min_cluster_size
121                && cluster_indices.len() <= self.config.max_cluster_size
122            {
123                let cluster_size = cluster_indices.len();
124                for index in cluster_indices {
125                    labels[index] = cluster_id;
126                }
127                cluster_sizes.push(cluster_size);
128                cluster_id += 1;
129            }
130        }
131
132        Ok(EuclideanClusterResult {
133            cloud: with_labels(input, "label", labels)?,
134            cluster_count: cluster_sizes.len(),
135            cluster_sizes,
136        })
137    }
138}
139
140impl PointCloudSegmenter for EuclideanClusterExtractor {
141    fn name(&self) -> &'static str {
142        "EuclideanClusterExtractor"
143    }
144}
145
#[cfg(test)]
mod tests {
    use super::{EuclideanClusterConfig, EuclideanClusterExtractor};
    use spatialrust_core::{PointCloudBuilder, StandardSchemas};

    /// Three 3x3 grids of unit-spaced points on the z = 0 plane, placed 10 units
    /// apart so any tolerance well below the inter-grid gap keeps them separate.
    fn three_clusters() -> spatialrust_core::PointCloud {
        let mut builder = PointCloudBuilder::new(StandardSchemas::point_xyz());
        for center in [(0.0, 0.0, 0.0), (10.0, 0.0, 0.0), (0.0, 10.0, 0.0)] {
            for dx in 0..3 {
                for dy in 0..3 {
                    builder
                        .push_point([center.0 + dx as f32, center.1 + dy as f32, center.2])
                        .unwrap();
                }
            }
        }
        builder.build().unwrap()
    }

    /// Shorthand for building an extractor from the three config fields.
    fn extractor(
        cluster_tolerance: f32,
        min_cluster_size: usize,
        max_cluster_size: usize,
    ) -> EuclideanClusterExtractor {
        EuclideanClusterExtractor::new(EuclideanClusterConfig {
            cluster_tolerance,
            min_cluster_size,
            max_cluster_size,
        })
    }

    #[test]
    fn finds_three_separated_clusters() {
        let input = three_clusters();
        // 1.5 > sqrt(2), so grid diagonals connect within each 3x3 block.
        let result = extractor(1.5, 3, usize::MAX).extract(&input).unwrap();
        assert_eq!(result.cluster_count, 3);
        assert!(result.cluster_sizes.iter().all(|&size| size == 9));
        assert!(result.cloud.field("label").is_ok());
        assert_eq!(result.cloud.len(), input.len());
    }

    #[test]
    fn rejects_clusters_smaller_than_minimum() {
        let mut builder = PointCloudBuilder::new(StandardSchemas::point_xyz());
        builder.push_point([0.0, 0.0, 0.0]).unwrap();
        builder.push_point([0.1, 0.0, 0.0]).unwrap();
        let input = builder.build().unwrap();

        let result = extractor(0.5, 3, usize::MAX).extract(&input).unwrap();
        assert_eq!(result.cluster_count, 0);
        assert!(result.cluster_sizes.is_empty());
    }

    #[test]
    fn rejects_clusters_larger_than_maximum() {
        let result = extractor(1.5, 1, 8).extract(&three_clusters()).unwrap();
        assert_eq!(result.cluster_count, 0);
        assert!(result.cluster_sizes.is_empty());
    }

    #[test]
    fn size_bounds_are_inclusive() {
        let result = extractor(1.5, 9, 9).extract(&three_clusters()).unwrap();
        assert_eq!(result.cluster_sizes, vec![9, 9, 9]);
    }

    #[test]
    fn tight_tolerance_isolates_every_point() {
        // 0.5 is below the unit grid spacing, so no two points are neighbours.
        let result = extractor(0.5, 1, usize::MAX).extract(&three_clusters()).unwrap();
        assert_eq!(result.cluster_count, 27);
        assert!(result.cluster_sizes.iter().all(|&size| size == 1));
    }

    #[test]
    fn rejects_invalid_config() {
        let input = three_clusters();
        assert!(extractor(-1.0, 1, usize::MAX).extract(&input).is_err());
        assert!(extractor(1.0, 0, usize::MAX).extract(&input).is_err());
        assert!(extractor(1.0, 5, 4).extract(&input).is_err());
    }
}