Skip to main content

spatialrust_features/
normal.rs

1use spatialrust_core::{
2    DType, FieldSemantic, HasPositions3, PointBuffer, PointBufferSet, PointCloud, PointField,
3    PointSchema, SpatialError, SpatialResult,
4};
5use spatialrust_math::{symmetric_eigen3, Mat3, Vec3};
6use spatialrust_search::{KdTree, Neighbor, RadiusSearchIndex};
7
8use crate::estimator::FeatureEstimator;
9
/// Configuration for covariance-based normal estimation.
///
/// Neighborhoods are gathered either by fixed-`k` nearest-neighbor search (the default)
/// or, when `search_radius` is set, by radius search. In both modes the query point
/// itself is excluded from its own neighborhood.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct NormalEstimationConfig {
    /// Number of nearest neighbors (not counting the query point) to use when
    /// `search_radius` is `None`. Not used to limit neighborhoods in radius mode.
    pub k_neighbors: usize,
    /// Optional radius search instead of fixed `k`. Every point found within the radius
    /// is used (there is no `k_neighbors` cap). Negative values are rejected by
    /// `NormalEstimator::estimate_with_diagnostics`.
    pub search_radius: Option<f32>,
    /// Minimum number of neighbors (not counting the query point) required to estimate
    /// a valid normal; points with fewer neighbors keep NaN normals and count as invalid.
    pub min_neighbors: usize,
    /// Optional viewpoint used to orient normals consistently: when set, each normal is
    /// flipped so that it points towards the viewpoint. When `None`, the sign of each
    /// normal is whatever the eigen decomposition produced.
    pub viewpoint: Option<Vec3<f32>>,
}
22
23impl Default for NormalEstimationConfig {
24    fn default() -> Self {
25        Self { k_neighbors: 20, search_radius: None, min_neighbors: 3, viewpoint: None }
26    }
27}
28
29impl NormalEstimationConfig {
30    /// Creates a k-NN normal estimation config.
31    #[must_use]
32    pub const fn k_neighbors(k_neighbors: usize) -> Self {
33        Self { k_neighbors, search_radius: None, min_neighbors: 3, viewpoint: None }
34    }
35}
36
/// Result metadata for normal estimation.
///
/// Every input point is counted exactly once, so `valid_count + invalid_count` equals
/// the number of points in the input cloud.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct NormalEstimationResult {
    /// Number of points that received a finite normal and curvature.
    pub valid_count: usize,
    /// Number of points left with NaN normals (too few neighbors, or no normal could be
    /// fitted to the neighborhood).
    pub invalid_count: usize,
}
45
/// Covariance-based normal estimator.
///
/// For each point, the covariance of its neighborhood is computed. The normal is the
/// eigenvector of the smallest eigenvalue. The curvature is the surface variation
/// `lambda_min / (lambda_0 + lambda_1 + lambda_2)`.
#[derive(Clone, Debug, PartialEq)]
pub struct NormalEstimator {
    /// Neighborhood and orientation settings used by every estimation call.
    config: NormalEstimationConfig,
}
51
52impl NormalEstimator {
53    /// Creates a normal estimator from config.
54    #[must_use]
55    pub const fn new(config: NormalEstimationConfig) -> Self {
56        Self { config }
57    }
58
59    /// Returns the estimator config.
60    #[must_use]
61    pub const fn config(&self) -> NormalEstimationConfig {
62        self.config
63    }
64
65    /// Estimates normals and curvature, returning output cloud and diagnostics.
66    pub fn estimate_with_diagnostics(
67        &self,
68        input: &PointCloud,
69    ) -> SpatialResult<(PointCloud, NormalEstimationResult)> {
70        if input.is_empty() {
71            return Ok((input.clone(), NormalEstimationResult::default()));
72        }
73        if self.config.search_radius.is_some_and(|radius| radius < 0.0) {
74            return Err(SpatialError::InvalidArgument("search_radius must be non-negative".into()));
75        }
76
77        let (x, y, z) = input.positions3()?;
78        let tree = KdTree::from_slices(x, y, z);
79
80        let mut nx = vec![f32::NAN; input.len()];
81        let mut ny = vec![f32::NAN; input.len()];
82        let mut nz = vec![f32::NAN; input.len()];
83        let mut curvature = vec![0.0_f32; input.len()];
84        let mut valid_count = 0usize;
85        let mut invalid_count = 0usize;
86
87        let worker_count = normal_worker_count(input.len());
88        if worker_count == 1 {
89            let chunk = estimate_normal_range(self.config, &tree, x, y, z, 0, input.len());
90            nx = chunk.nx;
91            ny = chunk.ny;
92            nz = chunk.nz;
93            curvature = chunk.curvature;
94            valid_count = chunk.valid_count;
95            invalid_count = chunk.invalid_count;
96        } else {
97            let chunk_size = input.len().div_ceil(worker_count);
98            let chunks = std::thread::scope(|scope| {
99                let mut handles = Vec::new();
100                let config = self.config;
101                let tree_ref = &tree;
102                for start in (0..input.len()).step_by(chunk_size) {
103                    let end = (start + chunk_size).min(input.len());
104                    handles.push(scope.spawn(move || {
105                        estimate_normal_range(config, tree_ref, x, y, z, start, end)
106                    }));
107                }
108
109                handles
110                    .into_iter()
111                    .map(|handle| handle.join().expect("normal estimation worker panicked"))
112                    .collect::<Vec<_>>()
113            });
114
115            for chunk in chunks {
116                let end = chunk.start + chunk.nx.len();
117                nx[chunk.start..end].copy_from_slice(&chunk.nx);
118                ny[chunk.start..end].copy_from_slice(&chunk.ny);
119                nz[chunk.start..end].copy_from_slice(&chunk.nz);
120                curvature[chunk.start..end].copy_from_slice(&chunk.curvature);
121                valid_count += chunk.valid_count;
122                invalid_count += chunk.invalid_count;
123            }
124        }
125
126        let output = build_output_cloud(input, nx, ny, nz, curvature)?;
127        Ok((output, NormalEstimationResult { valid_count, invalid_count }))
128    }
129}
130
131impl FeatureEstimator for NormalEstimator {
132    fn name(&self) -> &'static str {
133        "NormalEstimator"
134    }
135
136    fn estimate(&self, input: &PointCloud) -> SpatialResult<PointCloud> {
137        self.estimate_with_diagnostics(input).map(|(cloud, _)| cloud)
138    }
139}
140
/// Estimation results for the contiguous point range `start..start + nx.len()`.
///
/// Index `i` of each vector corresponds to point `start + i`.
#[derive(Debug)]
struct NormalChunk {
    /// Index of the first point covered by this chunk.
    start: usize,
    /// Normal x components (NaN for invalid points).
    nx: Vec<f32>,
    /// Normal y components (NaN for invalid points).
    ny: Vec<f32>,
    /// Normal z components (NaN for invalid points).
    nz: Vec<f32>,
    /// Surface-variation curvature (0 for invalid points).
    curvature: Vec<f32>,
    /// Points in this range that received a normal.
    valid_count: usize,
    /// Points in this range left without a normal.
    invalid_count: usize,
}
151
/// Chooses how many worker threads to use for `len` points.
///
/// Each worker gets at least `16_384` points, presumably to amortize thread start-up.
/// The count is capped by `std::thread::available_parallelism`, which is treated as 1
/// when it cannot be queried. The result is never below 1.
fn normal_worker_count(len: usize) -> usize {
    const MIN_POINTS_PER_WORKER: usize = 16_384;
    let cores = match std::thread::available_parallelism() {
        Ok(count) => count.get(),
        Err(_) => 1,
    };
    let worthwhile = std::cmp::max(len / MIN_POINTS_PER_WORKER, 1);
    std::cmp::min(cores, worthwhile)
}
157
158fn estimate_normal_range(
159    config: NormalEstimationConfig,
160    tree: &KdTree,
161    x: &[f32],
162    y: &[f32],
163    z: &[f32],
164    start: usize,
165    end: usize,
166) -> NormalChunk {
167    let len = end - start;
168    let mut nx = vec![f32::NAN; len];
169    let mut ny = vec![f32::NAN; len];
170    let mut nz = vec![f32::NAN; len];
171    let mut curvature = vec![0.0_f32; len];
172    let mut valid_count = 0usize;
173    let mut invalid_count = 0usize;
174    let mut neighbor_buffer = Vec::with_capacity(config.k_neighbors.saturating_add(1));
175    let mut index_buffer = Vec::with_capacity(config.k_neighbors);
176
177    for index in start..end {
178        query_neighbors_into(config, tree, x, y, z, index, &mut neighbor_buffer, &mut index_buffer);
179        let local = index - start;
180        if index_buffer.len() < config.min_neighbors {
181            invalid_count += 1;
182            continue;
183        }
184
185        let Some((normal, curv)) = estimate_normal_from_neighbors(x, y, z, index, &index_buffer)
186        else {
187            invalid_count += 1;
188            continue;
189        };
190
191        let oriented = if let Some(viewpoint) = config.viewpoint {
192            orient_normal_towards_viewpoint(normal, point_xyz(x, y, z, index), viewpoint)
193        } else {
194            normal
195        };
196
197        nx[local] = oriented.x;
198        ny[local] = oriented.y;
199        nz[local] = oriented.z;
200        curvature[local] = curv;
201        valid_count += 1;
202    }
203
204    NormalChunk { start, nx, ny, nz, curvature, valid_count, invalid_count }
205}
206
207fn query_neighbors_into(
208    config: NormalEstimationConfig,
209    tree: &KdTree,
210    x: &[f32],
211    y: &[f32],
212    z: &[f32],
213    index: usize,
214    neighbor_buffer: &mut Vec<Neighbor>,
215    index_buffer: &mut Vec<usize>,
216) {
217    index_buffer.clear();
218    if let Some(radius) = config.search_radius {
219        for neighbor in tree.radius_search(x[index], y[index], z[index], radius) {
220            if neighbor.index != index {
221                index_buffer.push(neighbor.index);
222            }
223        }
224    } else {
225        tree.nearest_k_unsorted_into(
226            x[index],
227            y[index],
228            z[index],
229            config.k_neighbors.saturating_add(1),
230            neighbor_buffer,
231        );
232        for neighbor in neighbor_buffer.iter() {
233            if neighbor.index != index {
234                index_buffer.push(neighbor.index);
235                if index_buffer.len() == config.k_neighbors {
236                    break;
237                }
238            }
239        }
240    }
241}
242
243/// Orients a normal to point towards the viewpoint when possible.
244#[must_use]
245pub fn orient_normal_towards_viewpoint(
246    mut normal: Vec3<f32>,
247    point: Vec3<f32>,
248    viewpoint: Vec3<f32>,
249) -> Vec3<f32> {
250    let view_direction =
251        Vec3::new(viewpoint.x - point.x, viewpoint.y - point.y, viewpoint.z - point.z);
252    if normal.dot(view_direction) < 0.0 {
253        normal.x = -normal.x;
254        normal.y = -normal.y;
255        normal.z = -normal.z;
256    }
257    normal.normalize()
258}
259
260fn point_xyz(x: &[f32], y: &[f32], z: &[f32], index: usize) -> Vec3<f32> {
261    Vec3::new(x[index], y[index], z[index])
262}
263
264fn estimate_normal_from_neighbors(
265    x: &[f32],
266    y: &[f32],
267    z: &[f32],
268    _center_index: usize,
269    neighbors: &[usize],
270) -> Option<(Vec3<f32>, f32)> {
271    let mut mean_x = 0.0_f32;
272    let mut mean_y = 0.0_f32;
273    let mut mean_z = 0.0_f32;
274    for &index in neighbors {
275        mean_x += x[index];
276        mean_y += y[index];
277        mean_z += z[index];
278    }
279    let count = neighbors.len() as f32;
280    mean_x /= count;
281    mean_y /= count;
282    mean_z /= count;
283
284    let mut c00 = 0.0_f32;
285    let mut c11 = 0.0_f32;
286    let mut c22 = 0.0_f32;
287    let mut c01 = 0.0_f32;
288    let mut c02 = 0.0_f32;
289    let mut c12 = 0.0_f32;
290    for &index in neighbors {
291        let dx = x[index] - mean_x;
292        let dy = y[index] - mean_y;
293        let dz = z[index] - mean_z;
294        c00 += dx * dx;
295        c11 += dy * dy;
296        c22 += dz * dz;
297        c01 += dx * dy;
298        c02 += dx * dz;
299        c12 += dy * dz;
300    }
301    let inv = 1.0 / count;
302    smallest_eigenpair_for_covariance(
303        c00 * inv,
304        c11 * inv,
305        c22 * inv,
306        c01 * inv,
307        c02 * inv,
308        c12 * inv,
309    )
310}
311
312fn smallest_eigenpair_for_covariance(
313    c00: f32,
314    c11: f32,
315    c22: f32,
316    c01: f32,
317    c02: f32,
318    c12: f32,
319) -> Option<(Vec3<f32>, f32)> {
320    let eigenvalues = symmetric_eigenvalues3(c00, c11, c22, c01, c02, c12);
321    let lambda = eigenvalues[0];
322    let normal =
323        eigenvector_for_eigenvalue(c00, c11, c22, c01, c02, c12, lambda).unwrap_or_else(|| {
324            let covariance = Mat3::<f64>::from_rows(
325                [c00 as f64, c01 as f64, c02 as f64],
326                [c01 as f64, c11 as f64, c12 as f64],
327                [c02 as f64, c12 as f64, c22 as f64],
328            );
329            let eigen = symmetric_eigen3(covariance);
330            Vec3::new(
331                eigen.eigenvectors.m[0][0] as f32,
332                eigen.eigenvectors.m[1][0] as f32,
333                eigen.eigenvectors.m[2][0] as f32,
334            )
335            .normalize()
336        });
337
338    let sum = eigenvalues[0] + eigenvalues[1] + eigenvalues[2];
339    let curvature = if sum > 0.0 { eigenvalues[0] / sum } else { 0.0 };
340    Some((normal.normalize(), curvature))
341}
342
/// Returns the eigenvalues of `[[c00, c01, c02], [c01, c11, c12], [c02, c12, c22]]`
/// in ascending order.
///
/// The matrix is first divided by its largest absolute entry, so the tolerances in
/// `normalized_symmetric_eigenvalues3` are relative to the matrix magnitude. Otherwise
/// a matrix with entries around `1e-4` would have its off-diagonal terms fall below
/// `f32::EPSILON` and be wrongly treated as diagonal. The eigenvalues are scaled back
/// before returning.
///
/// An all-zero matrix yields `[0.0; 3]`. A matrix with any non-finite entry yields
/// `[f32::NAN; 3]` instead of panicking.
fn symmetric_eigenvalues3(c00: f32, c11: f32, c22: f32, c01: f32, c02: f32, c12: f32) -> [f32; 3] {
    let entries = [c00, c11, c22, c01, c02, c12];
    if entries.iter().any(|value| !value.is_finite()) {
        return [f32::NAN; 3];
    }
    let scale = entries.iter().fold(0.0_f32, |largest, value| largest.max(value.abs()));
    if scale == 0.0 {
        return [0.0; 3];
    }
    // Divide rather than multiply by 1/scale: 1/scale can overflow for subnormal scales.
    let [a00, a11, a22, a01, a02, a12] = entries.map(|value| value / scale);

    let mut values = normalized_symmetric_eigenvalues3(a00, a11, a22, a01, a02, a12);
    values.sort_by(f32::total_cmp);
    values.map(|value| value * scale)
}

/// Returns the unsorted eigenvalues of a symmetric 3x3 matrix whose entries lie in
/// `[-1, 1]`.
///
/// Uses the trigonometric closed form. With `q = tr(A) / 3`,
/// `p = sqrt(tr((A - qI)^2) / 6)` and `B = (A - qI) / p`, the eigenvalues are
/// `q + 2p * cos(phi + 2*pi*k/3)` for `k = 0, 1, 2`, where
/// `phi = acos(det(B) / 2) / 3`.
fn normalized_symmetric_eigenvalues3(
    a00: f32,
    a11: f32,
    a22: f32,
    a01: f32,
    a02: f32,
    a12: f32,
) -> [f32; 3] {
    let p1 = a01 * a01 + a02 * a02 + a12 * a12;
    if p1 <= f32::EPSILON {
        // Off-diagonal coupling is negligible at unit scale: the diagonal is the spectrum.
        return [a00, a11, a22];
    }

    let q = (a00 + a11 + a22) / 3.0;
    let b00 = a00 - q;
    let b11 = a11 - q;
    let b22 = a22 - q;
    // p2 >= 2 * p1 > 0 here, so p is strictly positive and the divisions below are safe.
    let p2 = b00 * b00 + b11 * b11 + b22 * b22 + 2.0 * p1;
    let p = (p2 / 6.0).sqrt();

    let n00 = b00 / p;
    let n11 = b11 / p;
    let n22 = b22 / p;
    let n01 = a01 / p;
    let n02 = a02 / p;
    let n12 = a12 / p;
    let det = n00 * (n11 * n22 - n12 * n12) - n01 * (n01 * n22 - n12 * n02)
        + n02 * (n01 * n12 - n11 * n02);
    // Rounding can push det(B)/2 slightly outside [-1, 1], where acos is undefined.
    let r = (det * 0.5).clamp(-1.0, 1.0);
    let phi = r.acos() / 3.0;

    let largest = q + 2.0 * p * phi.cos();
    let smallest = q + 2.0 * p * (phi + 2.0 * std::f32::consts::PI / 3.0).cos();
    // The trace is preserved, so the middle eigenvalue follows from the other two.
    let middle = 3.0 * q - largest - smallest;
    [smallest, middle, largest]
}
380
381fn eigenvector_for_eigenvalue(
382    c00: f32,
383    c11: f32,
384    c22: f32,
385    c01: f32,
386    c02: f32,
387    c12: f32,
388    lambda: f32,
389) -> Option<Vec3<f32>> {
390    let row0 = Vec3::new(c00 - lambda, c01, c02);
391    let row1 = Vec3::new(c01, c11 - lambda, c12);
392    let row2 = Vec3::new(c02, c12, c22 - lambda);
393
394    let candidates = [row0.cross(row1), row0.cross(row2), row1.cross(row2)];
395    let mut best = candidates[0];
396    let mut best_norm = best.length_squared();
397    for candidate in candidates.into_iter().skip(1) {
398        let norm = candidate.length_squared();
399        if norm > best_norm {
400            best = candidate;
401            best_norm = norm;
402        }
403    }
404
405    if best_norm <= 1e-24 {
406        None
407    } else {
408        Some(best.normalize())
409    }
410}
411
412pub(crate) fn build_output_cloud(
413    input: &PointCloud,
414    nx: Vec<f32>,
415    ny: Vec<f32>,
416    nz: Vec<f32>,
417    curvature: Vec<f32>,
418) -> SpatialResult<PointCloud> {
419    let mut schema = input.schema().clone();
420    ensure_field(&mut schema, "normal_x", FieldSemantic::NormalX, DType::F32);
421    ensure_field(&mut schema, "normal_y", FieldSemantic::NormalY, DType::F32);
422    ensure_field(&mut schema, "normal_z", FieldSemantic::NormalZ, DType::F32);
423    ensure_field(&mut schema, "curvature", FieldSemantic::Curvature, DType::F32);
424
425    let mut buffers = PointBufferSet::new();
426    for field in input.schema().fields() {
427        let source = input.field(&field.name)?;
428        buffers.insert(field.name.clone(), clone_buffer(source)?);
429    }
430    buffers.insert("normal_x".to_owned(), PointBuffer::from_f32(nx));
431    buffers.insert("normal_y".to_owned(), PointBuffer::from_f32(ny));
432    buffers.insert("normal_z".to_owned(), PointBuffer::from_f32(nz));
433    buffers.insert("curvature".to_owned(), PointBuffer::from_f32(curvature));
434
435    PointCloud::try_from_parts(schema, buffers, input.metadata().clone())
436}
437
438fn ensure_field(schema: &mut PointSchema, name: &str, semantic: FieldSemantic, dtype: DType) {
439    if schema.find_semantic(semantic).is_none() {
440        *schema = schema.clone().with_field(PointField::scalar(name, semantic, dtype));
441    }
442}
443
444fn clone_buffer(buffer: &PointBuffer) -> SpatialResult<PointBuffer> {
445    Ok(match buffer {
446        PointBuffer::F32(values) => PointBuffer::from_f32(values.clone()),
447        PointBuffer::F64(values) => PointBuffer::F64(values.clone()),
448        PointBuffer::U8(values) => PointBuffer::U8(values.clone()),
449        PointBuffer::U16(values) => PointBuffer::U16(values.clone()),
450        PointBuffer::U32(values) => PointBuffer::U32(values.clone()),
451        PointBuffer::I32(values) => PointBuffer::I32(values.clone()),
452    })
453}
454
#[cfg(test)]
mod tests {
    use super::{orient_normal_towards_viewpoint, NormalEstimationConfig, NormalEstimator};
    use crate::FeatureEstimator;
    use spatialrust_core::{HasNormals3, PointCloudBuilder, StandardSchemas};
    use spatialrust_math::Vec3;

    /// 5x5 unit grid in the z = 0 plane.
    fn plane_cloud() -> spatialrust_core::PointCloud {
        let mut builder = PointCloudBuilder::new(StandardSchemas::point_xyz());
        for x in 0..5 {
            for y in 0..5 {
                builder.push_point([x as f32, y as f32, 0.0]).unwrap();
            }
        }
        builder.build().unwrap()
    }

    /// 7x7 grid (spacing 0.2) on the plane z = 0.2x - 0.3y + 0.1.
    fn tilted_plane_cloud() -> spatialrust_core::PointCloud {
        let mut builder = PointCloudBuilder::new(StandardSchemas::point_xyz());
        for x in 0..7 {
            for y in 0..7 {
                let fx = x as f32 * 0.2;
                let fy = y as f32 * 0.2;
                let z = 0.2 * fx - 0.3 * fy + 0.1;
                builder.push_point([fx, fy, z]).unwrap();
            }
        }
        builder.build().unwrap()
    }

    #[test]
    fn estimates_plane_normals_upwards() {
        let input = plane_cloud();
        let estimator = NormalEstimator::new(NormalEstimationConfig {
            k_neighbors: 8,
            min_neighbors: 3,
            viewpoint: Some(Vec3::new(0.0, 0.0, 10.0)),
            ..NormalEstimationConfig::default()
        });
        let (output, stats) = estimator.estimate_with_diagnostics(&input).unwrap();
        assert_eq!(stats.valid_count, input.len());
        assert_eq!(stats.invalid_count, 0);

        let (_, _, nz) = output.normals3().unwrap();
        for value in nz {
            assert!((*value - 1.0).abs() < 0.1, "expected upward normal, got {value}");
        }
    }

    #[test]
    fn estimates_tilted_plane_normals() {
        let input = tilted_plane_cloud();
        let estimator = NormalEstimator::new(NormalEstimationConfig {
            k_neighbors: 12,
            min_neighbors: 3,
            viewpoint: Some(Vec3::new(0.0, 0.0, 10.0)),
            ..NormalEstimationConfig::default()
        });
        let output = estimator.estimate(&input).unwrap();
        let (nx, ny, nz) = output.normals3().unwrap();
        let expected = Vec3::new(-0.2, 0.3, 1.0).normalize();

        for index in 0..input.len() {
            let actual = Vec3::new(nx[index], ny[index], nz[index]).normalize();
            assert!(actual.dot(expected) > 0.98, "tilted plane normal was {actual:?}");
        }
    }

    #[test]
    fn radius_search_estimates_plane_normals() {
        let input = plane_cloud();
        // 2.1 leaves every grid point, corners included, with at least three
        // non-collinear neighbors.
        let estimator = NormalEstimator::new(NormalEstimationConfig {
            search_radius: Some(2.1),
            min_neighbors: 3,
            viewpoint: Some(Vec3::new(0.0, 0.0, 10.0)),
            ..NormalEstimationConfig::default()
        });
        let (output, stats) = estimator.estimate_with_diagnostics(&input).unwrap();
        assert_eq!(stats.valid_count, input.len());
        assert_eq!(stats.invalid_count, 0);

        let (_, _, nz) = output.normals3().unwrap();
        for value in nz {
            assert!((*value - 1.0).abs() < 0.1, "expected upward normal, got {value}");
        }
    }

    #[test]
    fn rejects_negative_search_radius() {
        let estimator = NormalEstimator::new(NormalEstimationConfig {
            search_radius: Some(-1.0),
            ..NormalEstimationConfig::default()
        });
        assert!(estimator.estimate_with_diagnostics(&plane_cloud()).is_err());
    }

    #[test]
    fn orient_normal_towards_viewpoint_works() {
        let normal = Vec3::new(0.0, 0.0, -1.0);
        let point = Vec3::new(0.0, 0.0, 0.0);
        let viewpoint = Vec3::new(0.0, 0.0, 1.0);
        let oriented = orient_normal_towards_viewpoint(normal, point, viewpoint);
        assert!(oriented.z > 0.0);
    }

    #[test]
    fn adds_curvature_field() {
        let input = plane_cloud();
        let estimator = NormalEstimator::new(NormalEstimationConfig::k_neighbors(10));
        let output = estimator.estimate(&input).unwrap();
        assert!(output.field("curvature").is_ok());
    }
}