1use spatialrust_core::{
8 DType, FieldSemantic, HasPositions3, PointBuffer, PointBufferSet, PointCloud, PointField,
9 SpatialError, SpatialResult,
10};
11use spatialrust_search::KdTree;
12
13use crate::filter::PointCloudFilter;
14
/// Tuning parameters for [`StatisticalOutlierRemoval`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct StatisticalOutlierConfig {
    /// How many nearest neighbors contribute to each point's mean neighbor distance.
    /// `StatisticalOutlierRemoval::keep_mask` rejects a value of zero.
    pub k_neighbors: usize,
    /// Standard-deviation multiplier: a point is kept while its mean neighbor distance is at
    /// most `mean + std_mul * std`, where `mean` and `std` are taken over the per-point mean
    /// distances of the whole cloud.
    pub std_mul: f32,
}

impl Default for StatisticalOutlierConfig {
    /// Defaults to 16 neighbors and a multiplier of `1.0`.
    fn default() -> Self {
        Self::new(16, 1.0)
    }
}

impl StatisticalOutlierConfig {
    /// Builds a configuration from an explicit neighbor count and multiplier.
    ///
    /// No validation happens here; the values are stored exactly as given.
    #[must_use]
    pub const fn new(k_neighbors: usize, std_mul: f32) -> Self {
        StatisticalOutlierConfig { k_neighbors, std_mul }
    }
}
38
39#[derive(Clone, Copy, Debug, PartialEq)]
45pub struct StatisticalOutlierRemoval {
46 config: StatisticalOutlierConfig,
47}
48
49impl StatisticalOutlierRemoval {
50 #[must_use]
52 pub const fn new(config: StatisticalOutlierConfig) -> Self {
53 Self { config }
54 }
55
56 #[must_use]
58 pub const fn config(&self) -> StatisticalOutlierConfig {
59 self.config
60 }
61
62 pub fn keep_mask(&self, input: &PointCloud) -> SpatialResult<Vec<bool>> {
64 if self.config.k_neighbors == 0 {
65 return Err(SpatialError::InvalidArgument(
66 "k_neighbors must be greater than zero".to_owned(),
67 ));
68 }
69 let len = input.len();
70 if len == 0 {
71 return Ok(Vec::new());
72 }
73
74 let (x, y, z) = input.positions3()?;
75 let tree = KdTree::from_slices(x, y, z);
76
77 let mut mean_dist = vec![0.0_f32; len];
79 fill_mean_neighbor_distances(self.config.k_neighbors, &tree, x, y, z, &mut mean_dist);
80
81 let n = len as f64;
82 let mean: f64 = mean_dist.iter().map(|&d| d as f64).sum::<f64>() / n;
83 let variance: f64 = mean_dist.iter().map(|&d| (d as f64 - mean).powi(2)).sum::<f64>() / n;
84 let std = variance.sqrt();
85 let threshold = mean + self.config.std_mul as f64 * std;
86
87 Ok(mean_dist.iter().map(|&d| d as f64 <= threshold).collect())
88 }
89}
90
91impl PointCloudFilter for StatisticalOutlierRemoval {
92 fn name(&self) -> &'static str {
93 "StatisticalOutlierRemoval"
94 }
95
96 fn filter(&self, input: &PointCloud) -> SpatialResult<PointCloud> {
97 let mask = self.keep_mask(input)?;
98 gather_mask(input, &mask)
99 }
100}
101
102fn fill_mean_neighbor_distances(
103 k_neighbors: usize,
104 tree: &KdTree,
105 x: &[f32],
106 y: &[f32],
107 z: &[f32],
108 mean_dist: &mut [f32],
109) {
110 let worker_count = outlier_worker_count(mean_dist.len());
111 if worker_count == 1 {
112 fill_mean_neighbor_distances_chunk(k_neighbors, tree, x, y, z, 0, mean_dist);
113 return;
114 }
115
116 let chunk_size = mean_dist.len().div_ceil(worker_count);
117 std::thread::scope(|scope| {
118 for (chunk_index, chunk) in mean_dist.chunks_mut(chunk_size).enumerate() {
119 let start = chunk_index * chunk_size;
120 scope.spawn(move || {
121 fill_mean_neighbor_distances_chunk(k_neighbors, tree, x, y, z, start, chunk);
122 });
123 }
124 });
125}
126
/// Chooses how many worker threads to use for a cloud of `len` points.
///
/// One worker is allowed per 16 384 points (never fewer than one), capped by the parallelism
/// the platform reports; when that cannot be queried the answer is a single worker.
fn outlier_worker_count(len: usize) -> usize {
    const POINTS_PER_WORKER: usize = 16_384;

    let by_size = std::cmp::max(len / POINTS_PER_WORKER, 1);
    match std::thread::available_parallelism() {
        Ok(cores) => cores.get().min(by_size),
        Err(_) => 1,
    }
}
132
133fn fill_mean_neighbor_distances_chunk(
134 k_neighbors: usize,
135 tree: &KdTree,
136 x: &[f32],
137 y: &[f32],
138 z: &[f32],
139 start: usize,
140 mean_dist: &mut [f32],
141) {
142 let mut neighbors = Vec::with_capacity(k_neighbors.saturating_add(1));
143 for (offset, mean) in mean_dist.iter_mut().enumerate() {
144 let i = start + offset;
145 tree.nearest_k_unsorted_into(
146 x[i],
147 y[i],
148 z[i],
149 k_neighbors.saturating_add(1),
150 &mut neighbors,
151 );
152 let mut sum = 0.0_f32;
153 let mut count = 0_u32;
154 for neighbor in &neighbors {
155 if neighbor.index == i {
156 continue;
157 }
158 sum += neighbor.distance_squared.sqrt();
159 count += 1;
160 }
161 *mean = if count == 0 { 0.0 } else { sum / count as f32 };
162 }
163}
164
/// Tuning parameters for [`RadiusOutlierRemoval`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RadiusOutlierConfig {
    /// Search radius around each point. `RadiusOutlierRemoval::keep_mask` rejects values
    /// that are zero, negative, or NaN.
    pub radius: f32,
    /// Minimum number of neighbors a point needs inside `radius` to be kept; the filter
    /// queries the tree for one more than this value.
    pub min_neighbors: usize,
}

impl Default for RadiusOutlierConfig {
    /// Defaults to a radius of `0.5` and a minimum of 4 neighbors.
    fn default() -> Self {
        Self::new(0.5, 4)
    }
}

impl RadiusOutlierConfig {
    /// Builds a configuration from an explicit radius and neighbor count.
    ///
    /// No validation happens here; the values are stored exactly as given.
    #[must_use]
    pub const fn new(radius: f32, min_neighbors: usize) -> Self {
        RadiusOutlierConfig { radius, min_neighbors }
    }
}
187
188#[derive(Clone, Copy, Debug, PartialEq)]
194pub struct RadiusOutlierRemoval {
195 config: RadiusOutlierConfig,
196}
197
198impl RadiusOutlierRemoval {
199 #[must_use]
201 pub const fn new(config: RadiusOutlierConfig) -> Self {
202 Self { config }
203 }
204
205 #[must_use]
207 pub const fn config(&self) -> RadiusOutlierConfig {
208 self.config
209 }
210
211 pub fn keep_mask(&self, input: &PointCloud) -> SpatialResult<Vec<bool>> {
213 if self.config.radius <= 0.0 || self.config.radius.is_nan() {
214 return Err(SpatialError::InvalidArgument("radius must be positive".to_owned()));
215 }
216 let len = input.len();
217 if len == 0 {
218 return Ok(Vec::new());
219 }
220
221 let (x, y, z) = input.positions3()?;
222 let tree = KdTree::from_slices(x, y, z);
223
224 let target = self.config.min_neighbors + 1;
228 let mut keep = vec![false; len];
229 fill_radius_reaches_mask(&tree, x, y, z, self.config.radius, target, &mut keep);
230 Ok(keep)
231 }
232}
233
234impl PointCloudFilter for RadiusOutlierRemoval {
235 fn name(&self) -> &'static str {
236 "RadiusOutlierRemoval"
237 }
238
239 fn filter(&self, input: &PointCloud) -> SpatialResult<PointCloud> {
240 let mask = self.keep_mask(input)?;
241 gather_mask(input, &mask)
242 }
243}
244
245fn fill_radius_reaches_mask(
246 tree: &KdTree,
247 x: &[f32],
248 y: &[f32],
249 z: &[f32],
250 radius: f32,
251 target: usize,
252 keep: &mut [bool],
253) {
254 let worker_count = outlier_worker_count(keep.len());
255 if worker_count == 1 {
256 fill_radius_reaches_mask_chunk(tree, x, y, z, radius, target, 0, keep);
257 return;
258 }
259
260 let chunk_size = keep.len().div_ceil(worker_count);
261 std::thread::scope(|scope| {
262 for (chunk_index, chunk) in keep.chunks_mut(chunk_size).enumerate() {
263 let start = chunk_index * chunk_size;
264 scope.spawn(move || {
265 fill_radius_reaches_mask_chunk(tree, x, y, z, radius, target, start, chunk);
266 });
267 }
268 });
269}
270
271fn fill_radius_reaches_mask_chunk(
272 tree: &KdTree,
273 x: &[f32],
274 y: &[f32],
275 z: &[f32],
276 radius: f32,
277 target: usize,
278 start: usize,
279 keep: &mut [bool],
280) {
281 for (offset, keep_point) in keep.iter_mut().enumerate() {
282 let i = start + offset;
283 *keep_point = tree.radius_reaches(x[i], y[i], z[i], radius, target);
284 }
285}
286
287fn gather_mask(input: &PointCloud, mask: &[bool]) -> SpatialResult<PointCloud> {
289 if let Some(output) = gather_xyz_mask(input, mask)? {
290 return Ok(output);
291 }
292
293 let indices: Vec<usize> =
294 mask.iter().enumerate().filter_map(|(i, &keep)| keep.then_some(i)).collect();
295
296 let mut buffers = PointBufferSet::new();
297 for field in input.schema().fields() {
298 let source = input.field(&field.name)?;
299 buffers.insert(field.name.clone(), gather_buffer(source, &indices));
300 }
301 PointCloud::try_from_parts(input.schema().clone(), buffers, input.metadata().clone())
302}
303
304fn gather_xyz_mask(input: &PointCloud, mask: &[bool]) -> SpatialResult<Option<PointCloud>> {
305 let schema = input.schema();
306 if schema.len() != 3 {
307 return Ok(None);
308 }
309
310 let Some(x_field) = xyz_f32_field(input, FieldSemantic::PositionX) else {
311 return Ok(None);
312 };
313 let Some(y_field) = xyz_f32_field(input, FieldSemantic::PositionY) else {
314 return Ok(None);
315 };
316 let Some(z_field) = xyz_f32_field(input, FieldSemantic::PositionZ) else {
317 return Ok(None);
318 };
319
320 let (x, y, z) = input.positions3()?;
321 let output_len = mask.iter().filter(|&&keep| keep).count();
322 let mut out_x = Vec::with_capacity(output_len);
323 let mut out_y = Vec::with_capacity(output_len);
324 let mut out_z = Vec::with_capacity(output_len);
325
326 for (index, &keep) in mask.iter().enumerate() {
327 if keep {
328 out_x.push(x[index]);
329 out_y.push(y[index]);
330 out_z.push(z[index]);
331 }
332 }
333
334 let mut buffers = PointBufferSet::new();
335 buffers.insert(x_field.name.clone(), PointBuffer::from_f32(out_x));
336 buffers.insert(y_field.name.clone(), PointBuffer::from_f32(out_y));
337 buffers.insert(z_field.name.clone(), PointBuffer::from_f32(out_z));
338 PointCloud::try_from_parts(schema.clone(), buffers, input.metadata().clone()).map(Some)
339}
340
341fn xyz_f32_field(input: &PointCloud, semantic: FieldSemantic) -> Option<&PointField> {
342 let field = input.schema().find_semantic(semantic)?;
343 (field.dtype == DType::F32 && field.components == 1).then_some(field)
344}
345
346fn gather_buffer(source: &PointBuffer, indices: &[usize]) -> PointBuffer {
347 match source {
348 PointBuffer::F32(v) => PointBuffer::from_f32(indices.iter().map(|&i| v[i]).collect()),
349 PointBuffer::F64(v) => PointBuffer::F64(indices.iter().map(|&i| v[i]).collect()),
350 PointBuffer::U8(v) => PointBuffer::U8(indices.iter().map(|&i| v[i]).collect()),
351 PointBuffer::U16(v) => PointBuffer::U16(indices.iter().map(|&i| v[i]).collect()),
352 PointBuffer::U32(v) => PointBuffer::U32(indices.iter().map(|&i| v[i]).collect()),
353 PointBuffer::I32(v) => PointBuffer::I32(indices.iter().map(|&i| v[i]).collect()),
354 }
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use spatialrust_core::{DType, FieldSemantic, PointField, PointSchema};

    /// Builds an XYZ-only `f32` cloud from `[x, y, z]` triples.
    fn cloud_from_xyz(points: &[[f32; 3]]) -> PointCloud {
        // Extracts one coordinate column (0 = x, 1 = y, 2 = z) as a point buffer.
        let column =
            |axis: usize| PointBuffer::from_f32(points.iter().map(|p| p[axis]).collect());

        let schema = PointSchema::new()
            .with_field(PointField::scalar("x", FieldSemantic::PositionX, DType::F32))
            .with_field(PointField::scalar("y", FieldSemantic::PositionY, DType::F32))
            .with_field(PointField::scalar("z", FieldSemantic::PositionZ, DType::F32));

        let mut buffers = PointBufferSet::new();
        buffers.insert("x".to_owned(), column(0));
        buffers.insert("y".to_owned(), column(1));
        buffers.insert("z".to_owned(), column(2));
        PointCloud::try_from_parts(schema, buffers, Default::default()).unwrap()
    }

    /// A 6x6 unit grid in the z = 0 plane plus one far-away point; returns the cloud and the
    /// index of that far point.
    fn grid_with_outlier() -> (PointCloud, usize) {
        let mut points: Vec<[f32; 3]> = (0..6)
            .flat_map(|ix| (0..6).map(move |iy| [ix as f32, iy as f32, 0.0]))
            .collect();
        let outlier_index = points.len();
        points.push([100.0, 100.0, 100.0]);
        (cloud_from_xyz(&points), outlier_index)
    }

    #[test]
    fn sor_removes_isolated_speckle() {
        let (cloud, outlier) = grid_with_outlier();
        let config = StatisticalOutlierConfig::new(8, 1.0);
        let filter = StatisticalOutlierRemoval::new(config);

        let mask = filter.keep_mask(&cloud).unwrap();
        assert!(!mask[outlier], "the far speckle must be dropped");
        assert!(mask[..outlier].iter().all(|&k| k));

        // The filtered cloud must be exactly one point shorter.
        let out = filter.filter(&cloud).unwrap();
        assert_eq!(out.len(), cloud.len() - 1);
    }

    #[test]
    fn ror_removes_isolated_speckle() {
        let (cloud, outlier) = grid_with_outlier();
        let config = RadiusOutlierConfig::new(1.5, 2);
        let filter = RadiusOutlierRemoval::new(config);

        let mask = filter.keep_mask(&cloud).unwrap();
        assert!(!mask[outlier], "the far speckle has no neighbors in radius");
        assert!(mask[..outlier].iter().all(|&k| k));
    }

    #[test]
    fn empty_cloud_is_passthrough() {
        let cloud = cloud_from_xyz(&[]);

        let sor = StatisticalOutlierRemoval::new(StatisticalOutlierConfig::default());
        assert_eq!(sor.filter(&cloud).unwrap().len(), 0);

        let ror = RadiusOutlierRemoval::new(RadiusOutlierConfig::default());
        assert_eq!(ror.filter(&cloud).unwrap().len(), 0);
    }

    #[test]
    fn invalid_params_error() {
        let cloud = cloud_from_xyz(&[[0.0, 0.0, 0.0]]);

        // Zero neighbors is rejected by the statistical filter.
        let sor = StatisticalOutlierRemoval::new(StatisticalOutlierConfig::new(0, 1.0));
        assert!(sor.keep_mask(&cloud).is_err());

        // A zero radius is rejected by the radius filter.
        let ror = RadiusOutlierRemoval::new(RadiusOutlierConfig::new(0.0, 1));
        assert!(ror.keep_mask(&cloud).is_err());
    }
}