spatialrust_filtering/
fps.rs1use spatialrust_core::{
9 HasPositions3, PointBuffer, PointBufferSet, PointCloud, SpatialError, SpatialResult,
10};
11
12use crate::filter::PointCloudFilter;
13
/// Parameters for [`FarthestPointSampling`].
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct FarthestPointSamplingConfig {
    /// Number of points to select. Must be non-zero; when it is at least the
    /// number of points in the cloud, every point is kept.
    pub sample_size: usize,
    /// Index of the point sampling starts from. Must be smaller than the number
    /// of points in the cloud (an empty cloud accepts only `0`).
    pub seed_index: usize,
}

impl Default for FarthestPointSamplingConfig {
    /// Selects 1024 points, starting from index 0.
    fn default() -> Self {
        Self { sample_size: 1024, seed_index: 0 }
    }
}

impl FarthestPointSamplingConfig {
    /// Creates a configuration selecting `sample_size` points, seeded at index 0.
    #[must_use]
    pub const fn new(sample_size: usize) -> Self {
        Self { sample_size, seed_index: 0 }
    }

    /// Returns a copy of this configuration that starts sampling from `seed_index`.
    #[must_use]
    pub const fn with_seed_index(self, seed_index: usize) -> Self {
        Self { sample_size: self.sample_size, seed_index }
    }
}
37#[derive(Clone, Copy, Debug, PartialEq)]
39pub struct FarthestPointSampling {
40 config: FarthestPointSamplingConfig,
41}
42
43impl FarthestPointSampling {
44 #[must_use]
46 pub const fn new(config: FarthestPointSamplingConfig) -> Self {
47 Self { config }
48 }
49
50 #[must_use]
52 pub const fn config(&self) -> FarthestPointSamplingConfig {
53 self.config
54 }
55
56 pub fn select(&self, input: &PointCloud) -> SpatialResult<Vec<usize>> {
58 if self.config.sample_size == 0 {
59 return Err(SpatialError::InvalidArgument(
60 "sample_size must be greater than zero".to_owned(),
61 ));
62 }
63 let len = input.len();
64 if self.config.seed_index >= len.max(1) {
65 return Err(SpatialError::InvalidArgument("seed_index is out of range".to_owned()));
66 }
67 if len == 0 {
68 return Ok(Vec::new());
69 }
70 if self.config.sample_size >= len {
71 return Ok((0..len).collect());
72 }
73
74 let (x, y, z) = input.positions3()?;
75
76 let mut selected = Vec::with_capacity(self.config.sample_size);
80 let mut min_dist = vec![f32::INFINITY; len];
81 let mut current = self.config.seed_index;
82 selected.push(current);
83
84 for _ in 1..self.config.sample_size {
85 let (cx, cy, cz) = (x[current], y[current], z[current]);
86 let mut best = 0_usize;
87 let mut best_dist = -1.0_f32;
88 for i in 0..len {
89 let dx = x[i] - cx;
90 let dy = y[i] - cy;
91 let dz = z[i] - cz;
92 let d = dx * dx + dy * dy + dz * dz;
93 if d < min_dist[i] {
94 min_dist[i] = d;
95 }
96 if min_dist[i] > best_dist {
97 best_dist = min_dist[i];
98 best = i;
99 }
100 }
101 current = best;
102 selected.push(current);
103 }
104
105 Ok(selected)
106 }
107}
108
109impl PointCloudFilter for FarthestPointSampling {
110 fn name(&self) -> &'static str {
111 "FarthestPointSampling"
112 }
113
114 fn filter(&self, input: &PointCloud) -> SpatialResult<PointCloud> {
115 let indices = self.select(input)?;
116 gather_indices(input, &indices)
117 }
118}
119
120fn gather_indices(input: &PointCloud, indices: &[usize]) -> SpatialResult<PointCloud> {
122 let mut buffers = PointBufferSet::new();
123 for field in input.schema().fields() {
124 let source = input.field(&field.name)?;
125 buffers.insert(field.name.clone(), gather_buffer(source, indices));
126 }
127 PointCloud::try_from_parts(input.schema().clone(), buffers, input.metadata().clone())
128}
129
130fn gather_buffer(source: &PointBuffer, indices: &[usize]) -> PointBuffer {
131 match source {
132 PointBuffer::F32(v) => PointBuffer::from_f32(indices.iter().map(|&i| v[i]).collect()),
133 PointBuffer::F64(v) => PointBuffer::F64(indices.iter().map(|&i| v[i]).collect()),
134 PointBuffer::U8(v) => PointBuffer::U8(indices.iter().map(|&i| v[i]).collect()),
135 PointBuffer::U16(v) => PointBuffer::U16(indices.iter().map(|&i| v[i]).collect()),
136 PointBuffer::U32(v) => PointBuffer::U32(indices.iter().map(|&i| v[i]).collect()),
137 PointBuffer::I32(v) => PointBuffer::I32(indices.iter().map(|&i| v[i]).collect()),
138 }
139}
140
#[cfg(test)]
mod tests {
    use super::*;
    use spatialrust_core::{PointCloudBuilder, StandardSchemas};

    /// Builds an `n x n` grid of points in the z = 0 plane, enumerated row-major
    /// over `(i, j)`, so point `(i, j)` has index `i * n + j`.
    fn grid(n: usize) -> PointCloud {
        let mut builder = PointCloudBuilder::new(StandardSchemas::point_xyz());
        let coords = (0..n).flat_map(|i| (0..n).map(move |j| (i as f32, j as f32)));
        for (px, py) in coords {
            builder.push_point([px, py, 0.0]).unwrap();
        }
        builder.build().unwrap()
    }

    /// Shorthand for a sampler keeping `sample_size` points, seeded at index 0.
    fn sampler(sample_size: usize) -> FarthestPointSampling {
        FarthestPointSampling::new(FarthestPointSamplingConfig::new(sample_size))
    }

    #[test]
    fn selects_requested_count() {
        let out = sampler(16).filter(&grid(10)).unwrap();
        assert_eq!(out.len(), 16);
    }

    #[test]
    fn samples_are_spread_out() {
        let cloud = grid(10);
        let indices = sampler(4).select(&cloud).unwrap();
        // The seed is index 0 at (0, 0); the farthest point from it is (9, 9),
        // which is index 99.
        assert_eq!(&indices[..2], &[0, 99]);
    }

    #[test]
    fn oversampling_returns_all_points() {
        let cloud = grid(3);
        let out = sampler(100).filter(&cloud).unwrap();
        assert_eq!(out.len(), cloud.len());
    }

    #[test]
    fn rejects_bad_params() {
        let cloud = grid(3);

        let zero_samples = sampler(0).select(&cloud);
        assert!(zero_samples.is_err());

        let bad_seed = FarthestPointSampling::new(FarthestPointSamplingConfig {
            sample_size: 4,
            seed_index: 999,
        });
        assert!(bad_seed.select(&cloud).is_err());
    }
}