Skip to main content

spatialrust_filtering/
outlier.rs

1//! Neighborhood-based outlier removal filters.
2//!
3//! Both filters use a KD-tree over the input positions and drop points whose
4//! local neighborhood looks sparse, which removes scanner speckle and stray
5//! returns before downstream estimation (normals, registration, segmentation).
6
7use spatialrust_core::{
8    DType, FieldSemantic, HasPositions3, PointBuffer, PointBufferSet, PointCloud, PointField,
9    SpatialError, SpatialResult,
10};
11use spatialrust_search::KdTree;
12
13use crate::filter::PointCloudFilter;
14
/// Tuning parameters for [`StatisticalOutlierRemoval`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct StatisticalOutlierConfig {
    /// How many nearest neighbors contribute to each point's mean distance.
    pub k_neighbors: usize,
    /// Multiplier applied to the global standard deviation: a point is
    /// removed when its mean neighbor distance is above
    /// `global_mean + std_mul * global_std`.
    pub std_mul: f32,
}
24
25impl Default for StatisticalOutlierConfig {
26    fn default() -> Self {
27        Self { k_neighbors: 16, std_mul: 1.0 }
28    }
29}
30
31impl StatisticalOutlierConfig {
32    /// Creates a config from the neighbor count and std multiplier.
33    #[must_use]
34    pub const fn new(k_neighbors: usize, std_mul: f32) -> Self {
35        Self { k_neighbors, std_mul }
36    }
37}
38
39/// Statistical Outlier Removal (SOR).
40///
41/// For each point the mean distance to its `k` nearest neighbors is computed.
42/// Assuming those means are roughly Gaussian, points whose mean distance is
43/// more than `std_mul` standard deviations above the global mean are dropped.
44#[derive(Clone, Copy, Debug, PartialEq)]
45pub struct StatisticalOutlierRemoval {
46    config: StatisticalOutlierConfig,
47}
48
49impl StatisticalOutlierRemoval {
50    /// Creates a filter from config.
51    #[must_use]
52    pub const fn new(config: StatisticalOutlierConfig) -> Self {
53        Self { config }
54    }
55
56    /// Returns the filter config.
57    #[must_use]
58    pub const fn config(&self) -> StatisticalOutlierConfig {
59        self.config
60    }
61
62    /// Computes the keep mask without materializing the filtered cloud.
63    pub fn keep_mask(&self, input: &PointCloud) -> SpatialResult<Vec<bool>> {
64        if self.config.k_neighbors == 0 {
65            return Err(SpatialError::InvalidArgument(
66                "k_neighbors must be greater than zero".to_owned(),
67            ));
68        }
69        let len = input.len();
70        if len == 0 {
71            return Ok(Vec::new());
72        }
73
74        let (x, y, z) = input.positions3()?;
75        let tree = KdTree::from_slices(x, y, z);
76
77        // Mean distance to the k nearest neighbors (excluding the point itself).
78        let mut mean_dist = vec![0.0_f32; len];
79        fill_mean_neighbor_distances(self.config.k_neighbors, &tree, x, y, z, &mut mean_dist);
80
81        let n = len as f64;
82        let mean: f64 = mean_dist.iter().map(|&d| d as f64).sum::<f64>() / n;
83        let variance: f64 = mean_dist.iter().map(|&d| (d as f64 - mean).powi(2)).sum::<f64>() / n;
84        let std = variance.sqrt();
85        let threshold = mean + self.config.std_mul as f64 * std;
86
87        Ok(mean_dist.iter().map(|&d| d as f64 <= threshold).collect())
88    }
89}
90
91impl PointCloudFilter for StatisticalOutlierRemoval {
92    fn name(&self) -> &'static str {
93        "StatisticalOutlierRemoval"
94    }
95
96    fn filter(&self, input: &PointCloud) -> SpatialResult<PointCloud> {
97        let mask = self.keep_mask(input)?;
98        gather_mask(input, &mask)
99    }
100}
101
102fn fill_mean_neighbor_distances(
103    k_neighbors: usize,
104    tree: &KdTree,
105    x: &[f32],
106    y: &[f32],
107    z: &[f32],
108    mean_dist: &mut [f32],
109) {
110    let worker_count = outlier_worker_count(mean_dist.len());
111    if worker_count == 1 {
112        fill_mean_neighbor_distances_chunk(k_neighbors, tree, x, y, z, 0, mean_dist);
113        return;
114    }
115
116    let chunk_size = mean_dist.len().div_ceil(worker_count);
117    std::thread::scope(|scope| {
118        for (chunk_index, chunk) in mean_dist.chunks_mut(chunk_size).enumerate() {
119            let start = chunk_index * chunk_size;
120            scope.spawn(move || {
121                fill_mean_neighbor_distances_chunk(k_neighbors, tree, x, y, z, start, chunk);
122            });
123        }
124    });
125}
126
/// Picks how many worker threads to use for a cloud of `len` points.
///
/// One worker is allowed per 16 384 points (at least one), capped by the
/// parallelism the platform reports; if that cannot be queried, one worker
/// is used.
fn outlier_worker_count(len: usize) -> usize {
    const POINTS_PER_WORKER: usize = 16_384;

    let by_size = std::cmp::max(len / POINTS_PER_WORKER, 1);
    match std::thread::available_parallelism() {
        Ok(cores) => by_size.min(cores.get()),
        Err(_) => 1,
    }
}
132
133fn fill_mean_neighbor_distances_chunk(
134    k_neighbors: usize,
135    tree: &KdTree,
136    x: &[f32],
137    y: &[f32],
138    z: &[f32],
139    start: usize,
140    mean_dist: &mut [f32],
141) {
142    let mut neighbors = Vec::with_capacity(k_neighbors.saturating_add(1));
143    for (offset, mean) in mean_dist.iter_mut().enumerate() {
144        let i = start + offset;
145        tree.nearest_k_unsorted_into(
146            x[i],
147            y[i],
148            z[i],
149            k_neighbors.saturating_add(1),
150            &mut neighbors,
151        );
152        let mut sum = 0.0_f32;
153        let mut count = 0_u32;
154        for neighbor in &neighbors {
155            if neighbor.index == i {
156                continue;
157            }
158            sum += neighbor.distance_squared.sqrt();
159            count += 1;
160        }
161        *mean = if count == 0 { 0.0 } else { sum / count as f32 };
162    }
163}
164
/// Tuning parameters for [`RadiusOutlierRemoval`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RadiusOutlierConfig {
    /// Neighborhood search radius, given as a plain (not squared) distance.
    pub radius: f32,
    /// How many *other* points must lie within `radius` for a point to be
    /// kept; the point itself does not count.
    pub min_neighbors: usize,
}
173
174impl Default for RadiusOutlierConfig {
175    fn default() -> Self {
176        Self { radius: 0.5, min_neighbors: 4 }
177    }
178}
179
180impl RadiusOutlierConfig {
181    /// Creates a config from the radius and minimum neighbor count.
182    #[must_use]
183    pub const fn new(radius: f32, min_neighbors: usize) -> Self {
184        Self { radius, min_neighbors }
185    }
186}
187
188/// Radius Outlier Removal (ROR).
189///
190/// Drops every point that has fewer than `min_neighbors` other points within
191/// `radius`. Unlike SOR this uses an absolute density threshold, so it is robust
192/// when outliers are clustered rather than isolated.
193#[derive(Clone, Copy, Debug, PartialEq)]
194pub struct RadiusOutlierRemoval {
195    config: RadiusOutlierConfig,
196}
197
198impl RadiusOutlierRemoval {
199    /// Creates a filter from config.
200    #[must_use]
201    pub const fn new(config: RadiusOutlierConfig) -> Self {
202        Self { config }
203    }
204
205    /// Returns the filter config.
206    #[must_use]
207    pub const fn config(&self) -> RadiusOutlierConfig {
208        self.config
209    }
210
211    /// Computes the keep mask without materializing the filtered cloud.
212    pub fn keep_mask(&self, input: &PointCloud) -> SpatialResult<Vec<bool>> {
213        if self.config.radius <= 0.0 || self.config.radius.is_nan() {
214            return Err(SpatialError::InvalidArgument("radius must be positive".to_owned()));
215        }
216        let len = input.len();
217        if len == 0 {
218            return Ok(Vec::new());
219        }
220
221        let (x, y, z) = input.positions3()?;
222        let tree = KdTree::from_slices(x, y, z);
223
224        // The query point itself is in the tree, so requiring `min_neighbors`
225        // *other* points within radius means reaching `min_neighbors + 1` total.
226        // `radius_reaches` early-exits at that threshold without allocating.
227        let target = self.config.min_neighbors + 1;
228        let mut keep = vec![false; len];
229        fill_radius_reaches_mask(&tree, x, y, z, self.config.radius, target, &mut keep);
230        Ok(keep)
231    }
232}
233
234impl PointCloudFilter for RadiusOutlierRemoval {
235    fn name(&self) -> &'static str {
236        "RadiusOutlierRemoval"
237    }
238
239    fn filter(&self, input: &PointCloud) -> SpatialResult<PointCloud> {
240        let mask = self.keep_mask(input)?;
241        gather_mask(input, &mask)
242    }
243}
244
245fn fill_radius_reaches_mask(
246    tree: &KdTree,
247    x: &[f32],
248    y: &[f32],
249    z: &[f32],
250    radius: f32,
251    target: usize,
252    keep: &mut [bool],
253) {
254    let worker_count = outlier_worker_count(keep.len());
255    if worker_count == 1 {
256        fill_radius_reaches_mask_chunk(tree, x, y, z, radius, target, 0, keep);
257        return;
258    }
259
260    let chunk_size = keep.len().div_ceil(worker_count);
261    std::thread::scope(|scope| {
262        for (chunk_index, chunk) in keep.chunks_mut(chunk_size).enumerate() {
263            let start = chunk_index * chunk_size;
264            scope.spawn(move || {
265                fill_radius_reaches_mask_chunk(tree, x, y, z, radius, target, start, chunk);
266            });
267        }
268    });
269}
270
271fn fill_radius_reaches_mask_chunk(
272    tree: &KdTree,
273    x: &[f32],
274    y: &[f32],
275    z: &[f32],
276    radius: f32,
277    target: usize,
278    start: usize,
279    keep: &mut [bool],
280) {
281    for (offset, keep_point) in keep.iter_mut().enumerate() {
282        let i = start + offset;
283        *keep_point = tree.radius_reaches(x[i], y[i], z[i], radius, target);
284    }
285}
286
287/// Builds a new cloud from the points where `mask` is true, preserving schema.
288fn gather_mask(input: &PointCloud, mask: &[bool]) -> SpatialResult<PointCloud> {
289    if let Some(output) = gather_xyz_mask(input, mask)? {
290        return Ok(output);
291    }
292
293    let indices: Vec<usize> =
294        mask.iter().enumerate().filter_map(|(i, &keep)| keep.then_some(i)).collect();
295
296    let mut buffers = PointBufferSet::new();
297    for field in input.schema().fields() {
298        let source = input.field(&field.name)?;
299        buffers.insert(field.name.clone(), gather_buffer(source, &indices));
300    }
301    PointCloud::try_from_parts(input.schema().clone(), buffers, input.metadata().clone())
302}
303
304fn gather_xyz_mask(input: &PointCloud, mask: &[bool]) -> SpatialResult<Option<PointCloud>> {
305    let schema = input.schema();
306    if schema.len() != 3 {
307        return Ok(None);
308    }
309
310    let Some(x_field) = xyz_f32_field(input, FieldSemantic::PositionX) else {
311        return Ok(None);
312    };
313    let Some(y_field) = xyz_f32_field(input, FieldSemantic::PositionY) else {
314        return Ok(None);
315    };
316    let Some(z_field) = xyz_f32_field(input, FieldSemantic::PositionZ) else {
317        return Ok(None);
318    };
319
320    let (x, y, z) = input.positions3()?;
321    let output_len = mask.iter().filter(|&&keep| keep).count();
322    let mut out_x = Vec::with_capacity(output_len);
323    let mut out_y = Vec::with_capacity(output_len);
324    let mut out_z = Vec::with_capacity(output_len);
325
326    for (index, &keep) in mask.iter().enumerate() {
327        if keep {
328            out_x.push(x[index]);
329            out_y.push(y[index]);
330            out_z.push(z[index]);
331        }
332    }
333
334    let mut buffers = PointBufferSet::new();
335    buffers.insert(x_field.name.clone(), PointBuffer::from_f32(out_x));
336    buffers.insert(y_field.name.clone(), PointBuffer::from_f32(out_y));
337    buffers.insert(z_field.name.clone(), PointBuffer::from_f32(out_z));
338    PointCloud::try_from_parts(schema.clone(), buffers, input.metadata().clone()).map(Some)
339}
340
341fn xyz_f32_field(input: &PointCloud, semantic: FieldSemantic) -> Option<&PointField> {
342    let field = input.schema().find_semantic(semantic)?;
343    (field.dtype == DType::F32 && field.components == 1).then_some(field)
344}
345
346fn gather_buffer(source: &PointBuffer, indices: &[usize]) -> PointBuffer {
347    match source {
348        PointBuffer::F32(v) => PointBuffer::from_f32(indices.iter().map(|&i| v[i]).collect()),
349        PointBuffer::F64(v) => PointBuffer::F64(indices.iter().map(|&i| v[i]).collect()),
350        PointBuffer::U8(v) => PointBuffer::U8(indices.iter().map(|&i| v[i]).collect()),
351        PointBuffer::U16(v) => PointBuffer::U16(indices.iter().map(|&i| v[i]).collect()),
352        PointBuffer::U32(v) => PointBuffer::U32(indices.iter().map(|&i| v[i]).collect()),
353        PointBuffer::I32(v) => PointBuffer::I32(indices.iter().map(|&i| v[i]).collect()),
354    }
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use spatialrust_core::{DType, FieldSemantic, PointField, PointSchema};

    /// Builds an xyz-only `f32` cloud from `[x, y, z]` triples.
    fn cloud_from_xyz(points: &[[f32; 3]]) -> PointCloud {
        // Extracts one coordinate axis as a buffer.
        let column =
            |axis: usize| PointBuffer::from_f32(points.iter().map(|p| p[axis]).collect());

        let schema = PointSchema::new()
            .with_field(PointField::scalar("x", FieldSemantic::PositionX, DType::F32))
            .with_field(PointField::scalar("y", FieldSemantic::PositionY, DType::F32))
            .with_field(PointField::scalar("z", FieldSemantic::PositionZ, DType::F32));

        let mut buffers = PointBufferSet::new();
        buffers.insert("x".to_owned(), column(0));
        buffers.insert("y".to_owned(), column(1));
        buffers.insert("z".to_owned(), column(2));
        PointCloud::try_from_parts(schema, buffers, Default::default()).unwrap()
    }

    /// A 6x6 unit-spaced planar grid followed by one far-away speckle point;
    /// returns the cloud and the index of the speckle.
    fn grid_with_outlier() -> (PointCloud, usize) {
        let mut points: Vec<[f32; 3]> = (0..6)
            .flat_map(|ix| (0..6).map(move |iy| [ix as f32, iy as f32, 0.0]))
            .collect();
        let outlier_index = points.len();
        points.push([100.0, 100.0, 100.0]);
        (cloud_from_xyz(&points), outlier_index)
    }

    #[test]
    fn sor_removes_isolated_speckle() {
        let (cloud, outlier) = grid_with_outlier();
        let config = StatisticalOutlierConfig::new(8, 1.0);
        let filter = StatisticalOutlierRemoval::new(config);

        let mask = filter.keep_mask(&cloud).unwrap();
        assert!(!mask[outlier], "the far speckle must be dropped");
        // Every dense grid point should survive.
        let (dense, _) = mask.split_at(outlier);
        assert!(dense.iter().all(|&kept| kept));

        let out = filter.filter(&cloud).unwrap();
        assert_eq!(out.len(), cloud.len() - 1);
    }

    #[test]
    fn ror_removes_isolated_speckle() {
        let (cloud, outlier) = grid_with_outlier();
        let config = RadiusOutlierConfig::new(1.5, 2);
        let filter = RadiusOutlierRemoval::new(config);

        let mask = filter.keep_mask(&cloud).unwrap();
        assert!(!mask[outlier], "the far speckle has no neighbors in radius");
        let (dense, _) = mask.split_at(outlier);
        assert!(dense.iter().all(|&kept| kept));
    }

    #[test]
    fn empty_cloud_is_passthrough() {
        let cloud = cloud_from_xyz(&[]);

        let sor = StatisticalOutlierRemoval::new(StatisticalOutlierConfig::default());
        let sor_out = sor.filter(&cloud).unwrap();
        assert_eq!(sor_out.len(), 0);

        let ror = RadiusOutlierRemoval::new(RadiusOutlierConfig::default());
        let ror_out = ror.filter(&cloud).unwrap();
        assert_eq!(ror_out.len(), 0);
    }

    #[test]
    fn invalid_params_error() {
        let cloud = cloud_from_xyz(&[[0.0, 0.0, 0.0]]);

        let zero_k = StatisticalOutlierRemoval::new(StatisticalOutlierConfig::new(0, 1.0));
        assert!(zero_k.keep_mask(&cloud).is_err());

        let zero_radius = RadiusOutlierRemoval::new(RadiusOutlierConfig::new(0.0, 1));
        assert!(zero_radius.keep_mask(&cloud).is_err());
    }
}