1use spatialrust_core::{
10 FieldSemantic, HasPositions3, PointBuffer, PointBufferSet, PointCloud, SpatialError,
11 SpatialResult,
12};
13use spatialrust_math::{solve_linear_system, symmetric_eigen3, LeastSquaresResult, Mat3, Vec3};
14use spatialrust_search::{KdTree, RadiusSearchIndex};
15
16use crate::filter::PointCloudFilter;
17
/// Tuning parameters for [`MlsSmoothing`].
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct MlsConfig {
    /// Radius of the neighbourhood gathered around each point. The smoother's
    /// Gaussian weighting kernel uses a bandwidth of half this radius. Must be
    /// strictly positive.
    pub search_radius: f32,
    /// Degree of the local height polynomial: 0 (constant), 1 (plane) or
    /// 2 (quadric). Values above 2 are rejected by [`MlsSmoothing`].
    pub polynomial_order: u8,
    /// Minimum neighbourhood size required before a point is smoothed; sparser
    /// points are passed through unchanged.
    pub min_neighbors: usize,
}

impl Default for MlsConfig {
    /// Radius [`MlsConfig::DEFAULT_SEARCH_RADIUS`] with the other default settings.
    fn default() -> Self {
        Self::with_radius(Self::DEFAULT_SEARCH_RADIUS)
    }
}

impl MlsConfig {
    /// Neighbourhood radius used by [`Default::default`].
    pub const DEFAULT_SEARCH_RADIUS: f32 = 0.1;
    /// Polynomial order used by [`Default::default`] and [`MlsConfig::with_radius`].
    pub const DEFAULT_POLYNOMIAL_ORDER: u8 = 2;
    /// Minimum neighbourhood size used by [`Default::default`] and
    /// [`MlsConfig::with_radius`].
    pub const DEFAULT_MIN_NEIGHBORS: usize = 6;

    /// Default configuration with a custom `search_radius`.
    ///
    /// Both this constructor and [`Default::default`] read the shared
    /// `DEFAULT_*` constants, so the two cannot drift apart.
    #[must_use]
    pub const fn with_radius(search_radius: f32) -> Self {
        Self {
            search_radius,
            polynomial_order: Self::DEFAULT_POLYNOMIAL_ORDER,
            min_neighbors: Self::DEFAULT_MIN_NEIGHBORS,
        }
    }
}
42
43#[derive(Clone, Copy, Debug, PartialEq)]
45pub struct MlsSmoothing {
46 config: MlsConfig,
47}
48
49impl MlsSmoothing {
50 #[must_use]
52 pub const fn new(config: MlsConfig) -> Self {
53 Self { config }
54 }
55
56 #[must_use]
58 pub const fn config(&self) -> MlsConfig {
59 self.config
60 }
61
62 pub fn smoothed_positions(&self, input: &PointCloud) -> SpatialResult<Vec<Vec3<f32>>> {
64 if self.config.search_radius <= 0.0 || self.config.search_radius.is_nan() {
65 return Err(SpatialError::InvalidArgument("search_radius must be positive".to_owned()));
66 }
67 if self.config.polynomial_order > 2 {
68 return Err(SpatialError::InvalidArgument(
69 "polynomial_order must be 1 or 2".to_owned(),
70 ));
71 }
72
73 let (x, y, z) = input.positions3()?;
74 let len = input.len();
75 let tree = KdTree::from_slices(x, y, z);
76 let h_sq = (self.config.search_radius / 2.0).powi(2).max(1e-12);
79
80 let mut out = Vec::with_capacity(len);
81 for i in 0..len {
82 let p = Vec3::new(x[i], y[i], z[i]);
83 let neighbors = tree.radius_search(p.x, p.y, p.z, self.config.search_radius);
84 if neighbors.len() < self.config.min_neighbors {
85 out.push(p);
86 continue;
87 }
88 let pts: Vec<Vec3<f32>> = neighbors
91 .iter()
92 .filter(|n| n.index != i)
93 .map(|n| Vec3::new(x[n.index], y[n.index], z[n.index]))
94 .collect();
95 out.push(project_point(p, &pts, self.config.polynomial_order, h_sq).unwrap_or(p));
96 }
97 Ok(out)
98 }
99}
100
101impl PointCloudFilter for MlsSmoothing {
102 fn name(&self) -> &'static str {
103 "MlsSmoothing"
104 }
105
106 fn filter(&self, input: &PointCloud) -> SpatialResult<PointCloud> {
107 let smoothed = self.smoothed_positions(input)?;
108 build_output(input, &smoothed)
109 }
110}
111
112fn project_point(p: Vec3<f32>, neighbors: &[Vec3<f32>], order: u8, h_sq: f32) -> Option<Vec3<f32>> {
114 let mut sum_w = 0.0_f64;
116 let mut centroid = [0.0_f64; 3];
117 let mut weights = Vec::with_capacity(neighbors.len());
118 for q in neighbors {
119 let d_sq = (*q - p).length_squared();
120 let w = f64::from((-d_sq / h_sq).exp());
121 weights.push(w);
122 sum_w += w;
123 centroid[0] += w * f64::from(q.x);
124 centroid[1] += w * f64::from(q.y);
125 centroid[2] += w * f64::from(q.z);
126 }
127 if sum_w < 1e-12 {
128 return None;
129 }
130 let c = Vec3::new(
131 (centroid[0] / sum_w) as f32,
132 (centroid[1] / sum_w) as f32,
133 (centroid[2] / sum_w) as f32,
134 );
135
136 let mut cov = [[0.0_f64; 3]; 3];
137 for (q, &w) in neighbors.iter().zip(&weights) {
138 let d = *q - c;
139 let d = [f64::from(d.x), f64::from(d.y), f64::from(d.z)];
140 for r in 0..3 {
141 for col in 0..3 {
142 cov[r][col] += w * d[r] * d[col];
143 }
144 }
145 }
146 let covariance = Mat3::<f64>::from_rows(cov[0], cov[1], cov[2]);
147 let eigen = symmetric_eigen3(covariance);
148 let normal = Vec3::new(
150 eigen.eigenvectors.m[0][0] as f32,
151 eigen.eigenvectors.m[1][0] as f32,
152 eigen.eigenvectors.m[2][0] as f32,
153 )
154 .normalize();
155
156 let helper =
158 if normal.x.abs() < 0.9 { Vec3::new(1.0, 0.0, 0.0) } else { Vec3::new(0.0, 1.0, 0.0) };
159 let u = normal.cross(helper).normalize();
160 let v = normal.cross(u);
161
162 let rel = p - c;
164 let qu = rel.dot(u);
165 let qv = rel.dot(v);
166
167 let basis = |s: f32, t: f32| match order {
169 0 => vec![1.0_f64],
170 1 => vec![1.0, f64::from(s), f64::from(t)],
171 _ => vec![
172 1.0,
173 f64::from(s),
174 f64::from(t),
175 f64::from(s * s),
176 f64::from(s * t),
177 f64::from(t * t),
178 ],
179 };
180 let terms = basis(0.0, 0.0).len();
181 if neighbors.len() < terms {
182 return Some(c + scale(u, qu) + scale(v, qv));
184 }
185
186 let mut ata = vec![vec![0.0_f64; terms]; terms];
187 let mut atb = vec![0.0_f64; terms];
188 for (q, &w) in neighbors.iter().zip(&weights) {
189 let d = *q - c;
190 let s = d.dot(u);
191 let t = d.dot(v);
192 let height = f64::from(d.dot(normal));
193 let row = basis(s, t);
194 for a in 0..terms {
195 atb[a] += w * row[a] * height;
196 for b in 0..terms {
197 ata[a][b] += w * row[a] * row[b];
198 }
199 }
200 }
201
202 let coeffs = match solve_linear_system(ata, atb) {
203 LeastSquaresResult::Solved(c) => c,
204 LeastSquaresResult::Singular => return Some(c + scale(u, qu) + scale(v, qv)),
205 };
206 let query_basis = basis(qu, qv);
207 let height: f64 = coeffs.iter().zip(&query_basis).map(|(c, b)| c * b).sum();
208
209 Some(c + scale(u, qu) + scale(v, qv) + scale(normal, height as f32))
210}
211
212fn scale(v: Vec3<f32>, s: f32) -> Vec3<f32> {
213 Vec3::new(v.x * s, v.y * s, v.z * s)
214}
215
216fn build_output(input: &PointCloud, positions: &[Vec3<f32>]) -> SpatialResult<PointCloud> {
218 let schema = input.schema().clone();
219 let name_for = |sem| schema.find_semantic(sem).map(|f| f.name.clone());
220 let (xn, yn, zn) = (
221 name_for(FieldSemantic::PositionX),
222 name_for(FieldSemantic::PositionY),
223 name_for(FieldSemantic::PositionZ),
224 );
225
226 let mut buffers = PointBufferSet::new();
227 for field in schema.fields() {
228 let name = &field.name;
229 let buffer = if Some(name) == xn.as_ref() {
230 PointBuffer::from_f32(positions.iter().map(|p| p.x).collect())
231 } else if Some(name) == yn.as_ref() {
232 PointBuffer::from_f32(positions.iter().map(|p| p.y).collect())
233 } else if Some(name) == zn.as_ref() {
234 PointBuffer::from_f32(positions.iter().map(|p| p.z).collect())
235 } else {
236 clone_buffer(input.field(name)?)
237 };
238 buffers.insert(name.clone(), buffer);
239 }
240 PointCloud::try_from_parts(schema, buffers, input.metadata().clone())
241}
242
243fn clone_buffer(buffer: &PointBuffer) -> PointBuffer {
244 match buffer {
245 PointBuffer::F32(v) => PointBuffer::from_f32(v.clone()),
246 PointBuffer::F64(v) => PointBuffer::F64(v.clone()),
247 PointBuffer::U8(v) => PointBuffer::U8(v.clone()),
248 PointBuffer::U16(v) => PointBuffer::U16(v.clone()),
249 PointBuffer::U32(v) => PointBuffer::U32(v.clone()),
250 PointBuffer::I32(v) => PointBuffer::I32(v.clone()),
251 }
252}
253
#[cfg(test)]
mod tests {
    use super::*;
    use spatialrust_core::{PointCloudBuilder, StandardSchemas};

    /// Side length, in points, of the synthetic grid used by `flattens_noisy_plane`.
    const GRID: usize = 25;

    /// A noisy grid lying on z = 0 should have a clearly lower interior RMS
    /// height after MLS smoothing.
    #[test]
    fn flattens_noisy_plane() {
        // Small deterministic LCG producing noise in [-0.02, 0.02].
        let mut state = 12345_u64;
        let mut noise = || {
            state = state.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
            let unit = (state >> 32) as u32 as f32 / u32::MAX as f32;
            (unit * 2.0 - 1.0) * 0.02
        };

        let mut builder = PointCloudBuilder::new(StandardSchemas::point_xyz());
        let mut z_before = Vec::with_capacity(GRID * GRID);
        for row in 0..GRID {
            for col in 0..GRID {
                let height = noise();
                z_before.push(height);
                builder.push_point([row as f32 * 0.05, col as f32 * 0.05, height]).unwrap();
            }
        }
        let cloud = builder.build().unwrap();

        let config = MlsConfig { search_radius: 0.2, polynomial_order: 1, min_neighbors: 6 };
        let out = MlsSmoothing::new(config).filter(&cloud).unwrap();
        assert_eq!(out.len(), cloud.len());

        let (_, _, z_after) = out.positions3().unwrap();
        // Skip a 4-point border where neighbourhoods are one-sided.
        let in_interior = |k: usize| {
            let band = 4..GRID - 4;
            band.contains(&(k / GRID)) && band.contains(&(k % GRID))
        };
        let rms = |height_at: &dyn Fn(usize) -> f32| {
            let samples: Vec<f32> =
                (0..cloud.len()).filter(|&k| in_interior(k)).map(height_at).collect();
            (samples.iter().map(|h| h * h).sum::<f32>() / samples.len() as f32).sqrt()
        };
        let before = rms(&|k| z_before[k]);
        let after = rms(&|k| z_after[k]);
        assert!(after < before * 0.6, "MLS did not flatten: rms {after} vs {before}");
    }

    /// A zero radius and a cubic polynomial order must both be rejected.
    #[test]
    fn rejects_bad_params() {
        let mut builder = PointCloudBuilder::new(StandardSchemas::point_xyz());
        builder.push_point([0.0, 0.0, 0.0]).unwrap();
        let cloud = builder.build().unwrap();

        let zero_radius = MlsSmoothing::new(MlsConfig::with_radius(0.0));
        assert!(zero_radius.smoothed_positions(&cloud).is_err());

        let cubic = MlsSmoothing::new(MlsConfig {
            search_radius: 0.1,
            polynomial_order: 3,
            min_neighbors: 6,
        });
        assert!(cubic.smoothed_positions(&cloud).is_err());
    }
}