Skip to main content

spatialrust_filtering/
mls.rs

//! Moving Least Squares (MLS) surface smoothing.
//!
//! For each point a local reference plane is fit to its neighborhood, a bivariate
//! polynomial height field is fit over that plane by weighted least squares, and
//! the point is projected onto the polynomial surface. This removes scanner noise
//! while preserving curvature far better than a plain average, and yields cleaner
//! normals for downstream estimation.
8
9use spatialrust_core::{
10    FieldSemantic, HasPositions3, PointBuffer, PointBufferSet, PointCloud, SpatialError,
11    SpatialResult,
12};
13use spatialrust_math::{solve_linear_system, symmetric_eigen3, LeastSquaresResult, Mat3, Vec3};
14use spatialrust_search::{KdTree, RadiusSearchIndex};
15
16use crate::filter::PointCloudFilter;
17
/// Configuration for [`MlsSmoothing`].
///
/// Start from [`MlsConfig::default`] or [`MlsConfig::with_radius`] and override individual
/// fields as needed.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct MlsConfig {
    /// Radius of the neighborhood gathered around each point for its local fit.
    pub search_radius: f32,
    /// Degree of the height polynomial fitted over the local plane: 1 fits a plane,
    /// 2 fits a quadratic surface.
    pub polynomial_order: u8,
    /// Smallest neighborhood that is still smoothed; sparser points keep their position.
    pub min_neighbors: usize,
}

impl Default for MlsConfig {
    /// Quadratic fit over a 0.1-radius neighborhood, requiring at least 6 neighbors.
    fn default() -> Self {
        Self::with_radius(0.1)
    }
}

impl MlsConfig {
    /// Builds a quadratic-fit config (at least 6 neighbors) for the given `search_radius`.
    #[must_use]
    pub const fn with_radius(search_radius: f32) -> Self {
        Self { polynomial_order: 2, min_neighbors: 6, search_radius }
    }
}
42
43/// Moving Least Squares smoothing filter.
44#[derive(Clone, Copy, Debug, PartialEq)]
45pub struct MlsSmoothing {
46    config: MlsConfig,
47}
48
49impl MlsSmoothing {
50    /// Creates a smoother from config.
51    #[must_use]
52    pub const fn new(config: MlsConfig) -> Self {
53        Self { config }
54    }
55
56    /// Returns the smoother config.
57    #[must_use]
58    pub const fn config(&self) -> MlsConfig {
59        self.config
60    }
61
62    /// Returns the smoothed XYZ positions, one per input point.
63    pub fn smoothed_positions(&self, input: &PointCloud) -> SpatialResult<Vec<Vec3<f32>>> {
64        if self.config.search_radius <= 0.0 || self.config.search_radius.is_nan() {
65            return Err(SpatialError::InvalidArgument("search_radius must be positive".to_owned()));
66        }
67        if self.config.polynomial_order > 2 {
68            return Err(SpatialError::InvalidArgument(
69                "polynomial_order must be 1 or 2".to_owned(),
70            ));
71        }
72
73        let (x, y, z) = input.positions3()?;
74        let len = input.len();
75        let tree = KdTree::from_slices(x, y, z);
76        // Gaussian weight scale: ~1/2 of the radius averages out noise while
77        // still down-weighting the far neighborhood to preserve curvature.
78        let h_sq = (self.config.search_radius / 2.0).powi(2).max(1e-12);
79
80        let mut out = Vec::with_capacity(len);
81        for i in 0..len {
82            let p = Vec3::new(x[i], y[i], z[i]);
83            let neighbors = tree.radius_search(p.x, p.y, p.z, self.config.search_radius);
84            if neighbors.len() < self.config.min_neighbors {
85                out.push(p);
86                continue;
87            }
88            // Leave-one-out: exclude the query point itself so the surface is
89            // determined by its neighbors, which smooths the point's own noise.
90            let pts: Vec<Vec3<f32>> = neighbors
91                .iter()
92                .filter(|n| n.index != i)
93                .map(|n| Vec3::new(x[n.index], y[n.index], z[n.index]))
94                .collect();
95            out.push(project_point(p, &pts, self.config.polynomial_order, h_sq).unwrap_or(p));
96        }
97        Ok(out)
98    }
99}
100
101impl PointCloudFilter for MlsSmoothing {
102    fn name(&self) -> &'static str {
103        "MlsSmoothing"
104    }
105
106    fn filter(&self, input: &PointCloud) -> SpatialResult<PointCloud> {
107        let smoothed = self.smoothed_positions(input)?;
108        build_output(input, &smoothed)
109    }
110}
111
112/// Fits a local frame + height polynomial and returns the projection of `p`.
113fn project_point(p: Vec3<f32>, neighbors: &[Vec3<f32>], order: u8, h_sq: f32) -> Option<Vec3<f32>> {
114    // Weighted centroid and covariance about the query point.
115    let mut sum_w = 0.0_f64;
116    let mut centroid = [0.0_f64; 3];
117    let mut weights = Vec::with_capacity(neighbors.len());
118    for q in neighbors {
119        let d_sq = (*q - p).length_squared();
120        let w = f64::from((-d_sq / h_sq).exp());
121        weights.push(w);
122        sum_w += w;
123        centroid[0] += w * f64::from(q.x);
124        centroid[1] += w * f64::from(q.y);
125        centroid[2] += w * f64::from(q.z);
126    }
127    if sum_w < 1e-12 {
128        return None;
129    }
130    let c = Vec3::new(
131        (centroid[0] / sum_w) as f32,
132        (centroid[1] / sum_w) as f32,
133        (centroid[2] / sum_w) as f32,
134    );
135
136    let mut cov = [[0.0_f64; 3]; 3];
137    for (q, &w) in neighbors.iter().zip(&weights) {
138        let d = *q - c;
139        let d = [f64::from(d.x), f64::from(d.y), f64::from(d.z)];
140        for r in 0..3 {
141            for col in 0..3 {
142                cov[r][col] += w * d[r] * d[col];
143            }
144        }
145    }
146    let covariance = Mat3::<f64>::from_rows(cov[0], cov[1], cov[2]);
147    let eigen = symmetric_eigen3(covariance);
148    // Smallest eigenvector (column 0) is the plane normal.
149    let normal = Vec3::new(
150        eigen.eigenvectors.m[0][0] as f32,
151        eigen.eigenvectors.m[1][0] as f32,
152        eigen.eigenvectors.m[2][0] as f32,
153    )
154    .normalize();
155
156    // In-plane orthonormal basis (u, w).
157    let helper =
158        if normal.x.abs() < 0.9 { Vec3::new(1.0, 0.0, 0.0) } else { Vec3::new(0.0, 1.0, 0.0) };
159    let u = normal.cross(helper).normalize();
160    let v = normal.cross(u);
161
162    // Query point's in-plane coordinates relative to the centroid.
163    let rel = p - c;
164    let qu = rel.dot(u);
165    let qv = rel.dot(v);
166
167    // Fit height(s, t) = sum coeff_k * basis_k(s, t) by weighted least squares.
168    let basis = |s: f32, t: f32| match order {
169        0 => vec![1.0_f64],
170        1 => vec![1.0, f64::from(s), f64::from(t)],
171        _ => vec![
172            1.0,
173            f64::from(s),
174            f64::from(t),
175            f64::from(s * s),
176            f64::from(s * t),
177            f64::from(t * t),
178        ],
179    };
180    let terms = basis(0.0, 0.0).len();
181    if neighbors.len() < terms {
182        // Not enough support for this order: fall back to the plane.
183        return Some(c + scale(u, qu) + scale(v, qv));
184    }
185
186    let mut ata = vec![vec![0.0_f64; terms]; terms];
187    let mut atb = vec![0.0_f64; terms];
188    for (q, &w) in neighbors.iter().zip(&weights) {
189        let d = *q - c;
190        let s = d.dot(u);
191        let t = d.dot(v);
192        let height = f64::from(d.dot(normal));
193        let row = basis(s, t);
194        for a in 0..terms {
195            atb[a] += w * row[a] * height;
196            for b in 0..terms {
197                ata[a][b] += w * row[a] * row[b];
198            }
199        }
200    }
201
202    let coeffs = match solve_linear_system(ata, atb) {
203        LeastSquaresResult::Solved(c) => c,
204        LeastSquaresResult::Singular => return Some(c + scale(u, qu) + scale(v, qv)),
205    };
206    let query_basis = basis(qu, qv);
207    let height: f64 = coeffs.iter().zip(&query_basis).map(|(c, b)| c * b).sum();
208
209    Some(c + scale(u, qu) + scale(v, qv) + scale(normal, height as f32))
210}
211
212fn scale(v: Vec3<f32>, s: f32) -> Vec3<f32> {
213    Vec3::new(v.x * s, v.y * s, v.z * s)
214}
215
216/// Rebuilds the cloud with smoothed positions, preserving every other field.
217fn build_output(input: &PointCloud, positions: &[Vec3<f32>]) -> SpatialResult<PointCloud> {
218    let schema = input.schema().clone();
219    let name_for = |sem| schema.find_semantic(sem).map(|f| f.name.clone());
220    let (xn, yn, zn) = (
221        name_for(FieldSemantic::PositionX),
222        name_for(FieldSemantic::PositionY),
223        name_for(FieldSemantic::PositionZ),
224    );
225
226    let mut buffers = PointBufferSet::new();
227    for field in schema.fields() {
228        let name = &field.name;
229        let buffer = if Some(name) == xn.as_ref() {
230            PointBuffer::from_f32(positions.iter().map(|p| p.x).collect())
231        } else if Some(name) == yn.as_ref() {
232            PointBuffer::from_f32(positions.iter().map(|p| p.y).collect())
233        } else if Some(name) == zn.as_ref() {
234            PointBuffer::from_f32(positions.iter().map(|p| p.z).collect())
235        } else {
236            clone_buffer(input.field(name)?)
237        };
238        buffers.insert(name.clone(), buffer);
239    }
240    PointCloud::try_from_parts(schema, buffers, input.metadata().clone())
241}
242
243fn clone_buffer(buffer: &PointBuffer) -> PointBuffer {
244    match buffer {
245        PointBuffer::F32(v) => PointBuffer::from_f32(v.clone()),
246        PointBuffer::F64(v) => PointBuffer::F64(v.clone()),
247        PointBuffer::U8(v) => PointBuffer::U8(v.clone()),
248        PointBuffer::U16(v) => PointBuffer::U16(v.clone()),
249        PointBuffer::U32(v) => PointBuffer::U32(v.clone()),
250        PointBuffer::I32(v) => PointBuffer::I32(v.clone()),
251    }
252}
253
#[cfg(test)]
mod tests {
    use super::*;
    use spatialrust_core::{PointCloudBuilder, StandardSchemas};

    /// Builds an XYZ-only cloud from explicit coordinates, preserving their order.
    fn cloud_of(points: &[[f32; 3]]) -> PointCloud {
        let mut builder = PointCloudBuilder::new(StandardSchemas::point_xyz());
        for &point in points {
            builder.push_point(point).unwrap();
        }
        builder.build().unwrap()
    }

    /// A noisy plane: MLS should pull points back toward z = 0.
    #[test]
    fn flattens_noisy_plane() {
        // Deterministic pseudo-noise so the test is stable.
        let mut seed = 12345_u64;
        let mut noise = || {
            seed = seed.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
            // Top 32 bits mapped to a centered [-0.02, 0.02] perturbation.
            let unit = (seed >> 32) as u32 as f32 / u32::MAX as f32;
            (unit * 2.0 - 1.0) * 0.02
        };

        let mut builder = PointCloudBuilder::new(StandardSchemas::point_xyz());
        let mut z_before = Vec::new();
        for i in 0..25 {
            for j in 0..25 {
                let z = noise();
                z_before.push(z);
                builder.push_point([i as f32 * 0.05, j as f32 * 0.05, z]).unwrap();
            }
        }
        let cloud = builder.build().unwrap();

        // Order-1 (plane) fit is the right tool for a planar surface.
        let smoother = MlsSmoothing::new(MlsConfig {
            search_radius: 0.2,
            polynomial_order: 1,
            min_neighbors: 6,
        });
        let out = smoother.filter(&cloud).unwrap();
        assert_eq!(out.len(), cloud.len());

        let (_, _, z) = out.positions3().unwrap();
        let interior = |k: usize| {
            let (i, j) = (k / 25, k % 25);
            (4..21).contains(&i) && (4..21).contains(&j)
        };
        // Compare the RMS deviation from the true plane (z = 0) before/after.
        let rms = |get: &dyn Fn(usize) -> f32| {
            let vals: Vec<f32> = (0..cloud.len()).filter(|&k| interior(k)).map(get).collect();
            (vals.iter().map(|v| v * v).sum::<f32>() / vals.len() as f32).sqrt()
        };
        let before = rms(&|k| z_before[k]);
        let after = rms(&|k| z[k]);
        assert!(after < before * 0.6, "MLS did not flatten: rms {after} vs {before}");
    }

    /// Points whose neighborhood is too sparse are returned exactly as given.
    #[test]
    fn leaves_isolated_points_in_place() {
        let points = [[0.0, 0.0, 0.0], [10.0, 0.0, 0.0], [0.0, 10.0, 0.5]];
        let cloud = cloud_of(&points);
        let out = MlsSmoothing::new(MlsConfig::with_radius(0.1))
            .smoothed_positions(&cloud)
            .unwrap();
        assert_eq!(out.len(), points.len());
        for (got, want) in out.iter().zip(&points) {
            assert_eq!([got.x, got.y, got.z], *want);
        }
    }

    #[test]
    fn rejects_bad_params() {
        let cloud = cloud_of(&[[0.0, 0.0, 0.0]]);
        for radius in [0.0, -1.0, f32::NAN] {
            assert!(
                MlsSmoothing::new(MlsConfig::with_radius(radius))
                    .smoothed_positions(&cloud)
                    .is_err(),
                "radius {radius} was accepted"
            );
        }
        assert!(MlsSmoothing::new(MlsConfig {
            search_radius: 0.1,
            polynomial_order: 3,
            min_neighbors: 6
        })
        .smoothed_positions(&cloud)
        .is_err());
    }
}