spatialrust_segmentation/
dbscan.rs1use std::collections::VecDeque;
2
3use spatialrust_core::{HasPositions3, PointCloud, SpatialError, SpatialResult};
4use spatialrust_search::{KdTree, RadiusSearchIndex};
5
6use crate::cloud::with_labels;
7use crate::segmenter::PointCloudSegmenter;
8
/// Tuning parameters for DBSCAN clustering.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DbscanConfig {
    /// Neighborhood radius handed to the radius search for every queried point.
    /// Segmentation rejects values that are zero, negative, or NaN.
    pub eps: f32,
    /// Minimum number of radius-search results a point needs in order to start or
    /// extend a cluster. Segmentation rejects zero.
    pub min_points: usize,
}
19
20impl Default for DbscanConfig {
21 fn default() -> Self {
22 Self { eps: 0.5, min_points: 10 }
23 }
24}
25
26impl DbscanConfig {
27 #[must_use]
29 pub const fn new(eps: f32, min_points: usize) -> Self {
30 Self { eps, min_points }
31 }
32}
33
34#[derive(Clone, Debug, PartialEq)]
36pub struct DbscanResult {
37 pub cloud: PointCloud,
39 pub cluster_count: usize,
41 pub cluster_sizes: Vec<usize>,
43 pub noise_count: usize,
45}
46
47#[derive(Clone, Copy, Debug, PartialEq)]
55pub struct DbscanSegmenter {
56 config: DbscanConfig,
57}
58
59impl DbscanSegmenter {
60 #[must_use]
62 pub const fn new(config: DbscanConfig) -> Self {
63 Self { config }
64 }
65
66 #[must_use]
68 pub const fn config(&self) -> DbscanConfig {
69 self.config
70 }
71
72 pub fn segment(&self, input: &PointCloud) -> SpatialResult<DbscanResult> {
74 if input.is_empty() {
75 return Ok(DbscanResult {
76 cloud: input.clone(),
77 cluster_count: 0,
78 cluster_sizes: Vec::new(),
79 noise_count: 0,
80 });
81 }
82 if self.config.eps <= 0.0 || self.config.eps.is_nan() {
83 return Err(SpatialError::InvalidArgument("eps must be positive".to_owned()));
84 }
85 if self.config.min_points == 0 {
86 return Err(SpatialError::InvalidArgument(
87 "min_points must be greater than zero".to_owned(),
88 ));
89 }
90
91 let (x, y, z) = input.positions3()?;
92 let tree = KdTree::from_slices(x, y, z);
93 let len = input.len();
94
95 let mut labels = vec![-1_i32; len];
98 let mut visited = vec![false; len];
99 let mut cluster_sizes = Vec::new();
100 let mut cluster_id = 0_i32;
101
102 for seed in 0..len {
103 if visited[seed] {
104 continue;
105 }
106 visited[seed] = true;
107
108 let seed_neighbors = tree.radius_search(x[seed], y[seed], z[seed], self.config.eps);
109 if seed_neighbors.len() < self.config.min_points {
110 continue;
113 }
114
115 labels[seed] = cluster_id;
117 let mut size = 1_usize;
118 let mut queue: VecDeque<usize> =
119 seed_neighbors.iter().map(|n| n.index).filter(|&i| i != seed).collect();
120
121 while let Some(current) = queue.pop_front() {
122 if labels[current] == -1 {
123 labels[current] = cluster_id;
125 size += 1;
126 }
127 if visited[current] {
128 continue;
129 }
130 visited[current] = true;
131
132 let neighbors =
133 tree.radius_search(x[current], y[current], z[current], self.config.eps);
134 if neighbors.len() >= self.config.min_points {
135 for neighbor in neighbors {
137 if !visited[neighbor.index] || labels[neighbor.index] == -1 {
138 queue.push_back(neighbor.index);
139 }
140 }
141 }
142 }
143
144 cluster_sizes.push(size);
145 cluster_id += 1;
146 }
147
148 let noise_count = labels.iter().filter(|&&l| l == -1).count();
149 Ok(DbscanResult {
150 cloud: with_labels(input, "label", labels)?,
151 cluster_count: cluster_sizes.len(),
152 cluster_sizes,
153 noise_count,
154 })
155 }
156}
157
158impl PointCloudSegmenter for DbscanSegmenter {
159 fn name(&self) -> &'static str {
160 "DbscanSegmenter"
161 }
162}
163
#[cfg(test)]
mod tests {
    use super::{DbscanConfig, DbscanSegmenter};
    use spatialrust_core::{PointCloudBuilder, StandardSchemas};

    /// Builds two 3x3 grids with 0.2 spacing (anchored at x = 0 and x = 10, in the z = 0
    /// plane) followed by one distant point at (100, 100, 100): 19 points in total.
    fn two_blobs_with_noise() -> spatialrust_core::PointCloud {
        let mut cloud = PointCloudBuilder::new(StandardSchemas::point_xyz());
        let origins = [(0.0, 0.0, 0.0), (10.0, 0.0, 0.0)];
        for (origin_x, origin_y, _) in origins {
            for col in 0..3 {
                let px = origin_x + col as f32 * 0.2;
                for row in 0..3 {
                    let py = origin_y + row as f32 * 0.2;
                    cloud.push_point([px, py, 0.0]).unwrap();
                }
            }
        }
        cloud.push_point([100.0, 100.0, 100.0]).unwrap();
        cloud.build().unwrap()
    }

    /// Shorthand for a segmenter with the given parameters.
    fn segmenter(eps: f32, min_points: usize) -> DbscanSegmenter {
        DbscanSegmenter::new(DbscanConfig::new(eps, min_points))
    }

    #[test]
    fn finds_two_clusters_and_isolates_noise() {
        let cloud = two_blobs_with_noise();
        let outcome = segmenter(0.5, 4).segment(&cloud).unwrap();

        assert_eq!(outcome.cluster_count, 2);
        assert_eq!(outcome.noise_count, 1);
        for size in &outcome.cluster_sizes {
            assert!(*size == 9);
        }
    }

    #[test]
    fn all_noise_when_min_points_too_high() {
        let cloud = two_blobs_with_noise();
        let outcome = segmenter(0.5, 50).segment(&cloud).unwrap();

        assert_eq!(outcome.cluster_count, 0);
        assert_eq!(outcome.noise_count, cloud.len());
    }

    #[test]
    fn invalid_params_error() {
        let cloud = two_blobs_with_noise();

        // A zero radius and a zero neighbor threshold must both be rejected.
        let zero_eps = segmenter(0.0, 4).segment(&cloud);
        assert!(zero_eps.is_err());
        let zero_min_points = segmenter(0.5, 0).segment(&cloud);
        assert!(zero_min_points.is_err());
    }
}