Skip to main content

spatialrust_filtering/
fps.rs

//! Farthest Point Sampling (FPS).
//!
//! Greedily selects a subset that is spread as evenly as possible over the
//! cloud: each new point is the one farthest from everything chosen so far. This
//! is the standard downsampling step for learned point-cloud models (PointNet++
//! and friends), where uniform spatial coverage matters more than a fixed grid.
7
8use spatialrust_core::{
9    HasPositions3, PointBuffer, PointBufferSet, PointCloud, SpatialError, SpatialResult,
10};
11
12use crate::filter::PointCloudFilter;
13
/// Settings that control a [`FarthestPointSampling`] run.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct FarthestPointSamplingConfig {
    /// How many points the sampler should keep.
    pub sample_size: usize,
    /// Index of the point the selection starts from; every later pick follows
    /// deterministically from this seed.
    pub seed_index: usize,
}
22
23impl Default for FarthestPointSamplingConfig {
24    fn default() -> Self {
25        Self { sample_size: 1024, seed_index: 0 }
26    }
27}
28
29impl FarthestPointSamplingConfig {
30    /// Creates a config keeping `sample_size` points, seeded from index 0.
31    #[must_use]
32    pub const fn new(sample_size: usize) -> Self {
33        Self { sample_size, seed_index: 0 }
34    }
35}
36
37/// Farthest Point Sampling downsampler.
38#[derive(Clone, Copy, Debug, PartialEq)]
39pub struct FarthestPointSampling {
40    config: FarthestPointSamplingConfig,
41}
42
43impl FarthestPointSampling {
44    /// Creates a sampler from config.
45    #[must_use]
46    pub const fn new(config: FarthestPointSamplingConfig) -> Self {
47        Self { config }
48    }
49
50    /// Returns the sampler config.
51    #[must_use]
52    pub const fn config(&self) -> FarthestPointSamplingConfig {
53        self.config
54    }
55
56    /// Returns the selected point indices in selection order.
57    pub fn select(&self, input: &PointCloud) -> SpatialResult<Vec<usize>> {
58        if self.config.sample_size == 0 {
59            return Err(SpatialError::InvalidArgument(
60                "sample_size must be greater than zero".to_owned(),
61            ));
62        }
63        let len = input.len();
64        if self.config.seed_index >= len.max(1) {
65            return Err(SpatialError::InvalidArgument("seed_index is out of range".to_owned()));
66        }
67        if len == 0 {
68            return Ok(Vec::new());
69        }
70        if self.config.sample_size >= len {
71            return Ok((0..len).collect());
72        }
73
74        let (x, y, z) = input.positions3()?;
75
76        // `min_dist[i]` = squared distance from point i to the nearest selected
77        // point. Seed it from the first chosen point, then repeatedly take the
78        // current maximum and relax the array against the new selection.
79        let mut selected = Vec::with_capacity(self.config.sample_size);
80        let mut min_dist = vec![f32::INFINITY; len];
81        let mut current = self.config.seed_index;
82        selected.push(current);
83
84        for _ in 1..self.config.sample_size {
85            let (cx, cy, cz) = (x[current], y[current], z[current]);
86            let mut best = 0_usize;
87            let mut best_dist = -1.0_f32;
88            for i in 0..len {
89                let dx = x[i] - cx;
90                let dy = y[i] - cy;
91                let dz = z[i] - cz;
92                let d = dx * dx + dy * dy + dz * dz;
93                if d < min_dist[i] {
94                    min_dist[i] = d;
95                }
96                if min_dist[i] > best_dist {
97                    best_dist = min_dist[i];
98                    best = i;
99                }
100            }
101            current = best;
102            selected.push(current);
103        }
104
105        Ok(selected)
106    }
107}
108
109impl PointCloudFilter for FarthestPointSampling {
110    fn name(&self) -> &'static str {
111        "FarthestPointSampling"
112    }
113
114    fn filter(&self, input: &PointCloud) -> SpatialResult<PointCloud> {
115        let indices = self.select(input)?;
116        gather_indices(input, &indices)
117    }
118}
119
120/// Gathers the selected indices into a new cloud, preserving schema.
121fn gather_indices(input: &PointCloud, indices: &[usize]) -> SpatialResult<PointCloud> {
122    let mut buffers = PointBufferSet::new();
123    for field in input.schema().fields() {
124        let source = input.field(&field.name)?;
125        buffers.insert(field.name.clone(), gather_buffer(source, indices));
126    }
127    PointCloud::try_from_parts(input.schema().clone(), buffers, input.metadata().clone())
128}
129
130fn gather_buffer(source: &PointBuffer, indices: &[usize]) -> PointBuffer {
131    match source {
132        PointBuffer::F32(v) => PointBuffer::from_f32(indices.iter().map(|&i| v[i]).collect()),
133        PointBuffer::F64(v) => PointBuffer::F64(indices.iter().map(|&i| v[i]).collect()),
134        PointBuffer::U8(v) => PointBuffer::U8(indices.iter().map(|&i| v[i]).collect()),
135        PointBuffer::U16(v) => PointBuffer::U16(indices.iter().map(|&i| v[i]).collect()),
136        PointBuffer::U32(v) => PointBuffer::U32(indices.iter().map(|&i| v[i]).collect()),
137        PointBuffer::I32(v) => PointBuffer::I32(indices.iter().map(|&i| v[i]).collect()),
138    }
139}
140
#[cfg(test)]
mod tests {
    use super::*;
    use spatialrust_core::{PointCloudBuilder, StandardSchemas};

    /// Builds an `n` x `n` planar grid (z = 0), pushing the points row by row.
    fn grid(n: usize) -> PointCloud {
        let mut builder = PointCloudBuilder::new(StandardSchemas::point_xyz());
        let cells = (0..n).flat_map(|row| (0..n).map(move |col| (row, col)));
        for (row, col) in cells {
            builder.push_point([row as f32, col as f32, 0.0]).unwrap();
        }
        builder.build().unwrap()
    }

    /// Sampler that keeps `sample_size` points, seeded at index 0.
    fn sampler(sample_size: usize) -> FarthestPointSampling {
        FarthestPointSampling::new(FarthestPointSamplingConfig::new(sample_size))
    }

    #[test]
    fn selects_requested_count() {
        let cloud = grid(10);
        let sampled = sampler(16).filter(&cloud).unwrap();
        assert_eq!(sampled.len(), 16);
    }

    #[test]
    fn samples_are_spread_out() {
        // FPS maximizes spacing, so starting from the corner (0,0) at index 0
        // the second pick has to be the opposite corner (9,9), i.e. index 99.
        let cloud = grid(10);
        let picked = sampler(4).select(&cloud).unwrap();
        assert_eq!(picked[0], 0);
        assert_eq!(picked[1], 99);
    }

    #[test]
    fn oversampling_returns_all_points() {
        let cloud = grid(3);
        let sampled = sampler(100).filter(&cloud).unwrap();
        assert_eq!(sampled.len(), cloud.len());
    }

    #[test]
    fn rejects_bad_params() {
        let cloud = grid(3);

        // A sample size of zero is invalid.
        let zero_samples = sampler(0);
        assert!(zero_samples.select(&cloud).is_err());

        // A seed index beyond the 9 available points is invalid.
        let bad_seed = FarthestPointSampling::new(FarthestPointSamplingConfig {
            sample_size: 4,
            seed_index: 999,
        });
        assert!(bad_seed.select(&cloud).is_err());
    }
}