1use spatialrust_core::{
2 DType, FieldSemantic, HasPositions3, PointBuffer, PointBufferSet, PointCloud, PointField,
3 PointSchema, SpatialError, SpatialResult,
4};
5use spatialrust_math::{symmetric_eigen3, Mat3, Vec3};
6use spatialrust_search::{KdTree, Neighbor, RadiusSearchIndex};
7
8use crate::estimator::FeatureEstimator;
9
/// Parameters for [`NormalEstimator`].
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct NormalEstimationConfig {
    /// Number of nearest neighbours (not counting the query point itself) gathered for each
    /// point when `search_radius` is `None`.
    pub k_neighbors: usize,
    /// When `Some`, neighbours are gathered with a radius search of this radius instead of a
    /// k-nearest search, and `k_neighbors` does not cap the neighbourhood size. Must be
    /// non-negative.
    pub search_radius: Option<f32>,
    /// Minimum number of neighbours (not counting the query point) required to estimate a
    /// normal; points with fewer neighbours are reported as invalid.
    pub min_neighbors: usize,
    /// When `Some`, each estimated normal is flipped if necessary so that it points towards
    /// this position.
    pub viewpoint: Option<Vec3<f32>>,
}
22
23impl Default for NormalEstimationConfig {
24 fn default() -> Self {
25 Self { k_neighbors: 20, search_radius: None, min_neighbors: 3, viewpoint: None }
26 }
27}
28
29impl NormalEstimationConfig {
30 #[must_use]
32 pub const fn k_neighbors(k_neighbors: usize) -> Self {
33 Self { k_neighbors, search_radius: None, min_neighbors: 3, viewpoint: None }
34 }
35}
36
/// Diagnostics returned next to the output cloud by
/// [`NormalEstimator::estimate_with_diagnostics`].
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct NormalEstimationResult {
    /// Number of points for which a normal was estimated.
    pub valid_count: usize,
    /// Number of points left without a normal (too few neighbours, or the estimate failed);
    /// their normal components are NaN in the output cloud.
    pub invalid_count: usize,
}
45
/// Estimates per-point surface normals and curvature from the covariance of each point's
/// neighbourhood, configured by a [`NormalEstimationConfig`].
#[derive(Clone, Debug, PartialEq)]
pub struct NormalEstimator {
    config: NormalEstimationConfig,
}
51
52impl NormalEstimator {
53 #[must_use]
55 pub const fn new(config: NormalEstimationConfig) -> Self {
56 Self { config }
57 }
58
59 #[must_use]
61 pub const fn config(&self) -> NormalEstimationConfig {
62 self.config
63 }
64
65 pub fn estimate_with_diagnostics(
67 &self,
68 input: &PointCloud,
69 ) -> SpatialResult<(PointCloud, NormalEstimationResult)> {
70 if input.is_empty() {
71 return Ok((input.clone(), NormalEstimationResult::default()));
72 }
73 if self.config.search_radius.is_some_and(|radius| radius < 0.0) {
74 return Err(SpatialError::InvalidArgument("search_radius must be non-negative".into()));
75 }
76
77 let (x, y, z) = input.positions3()?;
78 let tree = KdTree::from_slices(x, y, z);
79
80 let mut nx = vec![f32::NAN; input.len()];
81 let mut ny = vec![f32::NAN; input.len()];
82 let mut nz = vec![f32::NAN; input.len()];
83 let mut curvature = vec![0.0_f32; input.len()];
84 let mut valid_count = 0usize;
85 let mut invalid_count = 0usize;
86
87 let worker_count = normal_worker_count(input.len());
88 if worker_count == 1 {
89 let chunk = estimate_normal_range(self.config, &tree, x, y, z, 0, input.len());
90 nx = chunk.nx;
91 ny = chunk.ny;
92 nz = chunk.nz;
93 curvature = chunk.curvature;
94 valid_count = chunk.valid_count;
95 invalid_count = chunk.invalid_count;
96 } else {
97 let chunk_size = input.len().div_ceil(worker_count);
98 let chunks = std::thread::scope(|scope| {
99 let mut handles = Vec::new();
100 let config = self.config;
101 let tree_ref = &tree;
102 for start in (0..input.len()).step_by(chunk_size) {
103 let end = (start + chunk_size).min(input.len());
104 handles.push(scope.spawn(move || {
105 estimate_normal_range(config, tree_ref, x, y, z, start, end)
106 }));
107 }
108
109 handles
110 .into_iter()
111 .map(|handle| handle.join().expect("normal estimation worker panicked"))
112 .collect::<Vec<_>>()
113 });
114
115 for chunk in chunks {
116 let end = chunk.start + chunk.nx.len();
117 nx[chunk.start..end].copy_from_slice(&chunk.nx);
118 ny[chunk.start..end].copy_from_slice(&chunk.ny);
119 nz[chunk.start..end].copy_from_slice(&chunk.nz);
120 curvature[chunk.start..end].copy_from_slice(&chunk.curvature);
121 valid_count += chunk.valid_count;
122 invalid_count += chunk.invalid_count;
123 }
124 }
125
126 let output = build_output_cloud(input, nx, ny, nz, curvature)?;
127 Ok((output, NormalEstimationResult { valid_count, invalid_count }))
128 }
129}
130
131impl FeatureEstimator for NormalEstimator {
132 fn name(&self) -> &'static str {
133 "NormalEstimator"
134 }
135
136 fn estimate(&self, input: &PointCloud) -> SpatialResult<PointCloud> {
137 self.estimate_with_diagnostics(input).map(|(cloud, _)| cloud)
138 }
139}
140
/// Normals and curvature computed for the contiguous point range
/// `start..start + nx.len()`; every vector is indexed relative to `start`.
#[derive(Debug)]
struct NormalChunk {
    /// Index of the first point covered by this chunk.
    start: usize,
    /// Normal x components; NaN where no normal was estimated.
    nx: Vec<f32>,
    /// Normal y components; NaN where no normal was estimated.
    ny: Vec<f32>,
    /// Normal z components; NaN where no normal was estimated.
    nz: Vec<f32>,
    /// Curvature per point; `0.0` where no normal was estimated.
    curvature: Vec<f32>,
    /// Number of points in this chunk that received a normal.
    valid_count: usize,
    /// Number of points in this chunk that did not receive a normal.
    invalid_count: usize,
}
151
/// Chooses how many threads to use for `len` points: one per `POINTS_PER_WORKER` points,
/// at least one, and never more than the reported available parallelism (falling back to
/// one if it cannot be queried).
fn normal_worker_count(len: usize) -> usize {
    const POINTS_PER_WORKER: usize = 16_384;
    let cores = match std::thread::available_parallelism() {
        Ok(count) => count.get(),
        Err(_) => 1,
    };
    let wanted = std::cmp::max(len / POINTS_PER_WORKER, 1);
    std::cmp::min(cores, wanted)
}
157
158fn estimate_normal_range(
159 config: NormalEstimationConfig,
160 tree: &KdTree,
161 x: &[f32],
162 y: &[f32],
163 z: &[f32],
164 start: usize,
165 end: usize,
166) -> NormalChunk {
167 let len = end - start;
168 let mut nx = vec![f32::NAN; len];
169 let mut ny = vec![f32::NAN; len];
170 let mut nz = vec![f32::NAN; len];
171 let mut curvature = vec![0.0_f32; len];
172 let mut valid_count = 0usize;
173 let mut invalid_count = 0usize;
174 let mut neighbor_buffer = Vec::with_capacity(config.k_neighbors.saturating_add(1));
175 let mut index_buffer = Vec::with_capacity(config.k_neighbors);
176
177 for index in start..end {
178 query_neighbors_into(config, tree, x, y, z, index, &mut neighbor_buffer, &mut index_buffer);
179 let local = index - start;
180 if index_buffer.len() < config.min_neighbors {
181 invalid_count += 1;
182 continue;
183 }
184
185 let Some((normal, curv)) = estimate_normal_from_neighbors(x, y, z, index, &index_buffer)
186 else {
187 invalid_count += 1;
188 continue;
189 };
190
191 let oriented = if let Some(viewpoint) = config.viewpoint {
192 orient_normal_towards_viewpoint(normal, point_xyz(x, y, z, index), viewpoint)
193 } else {
194 normal
195 };
196
197 nx[local] = oriented.x;
198 ny[local] = oriented.y;
199 nz[local] = oriented.z;
200 curvature[local] = curv;
201 valid_count += 1;
202 }
203
204 NormalChunk { start, nx, ny, nz, curvature, valid_count, invalid_count }
205}
206
207fn query_neighbors_into(
208 config: NormalEstimationConfig,
209 tree: &KdTree,
210 x: &[f32],
211 y: &[f32],
212 z: &[f32],
213 index: usize,
214 neighbor_buffer: &mut Vec<Neighbor>,
215 index_buffer: &mut Vec<usize>,
216) {
217 index_buffer.clear();
218 if let Some(radius) = config.search_radius {
219 for neighbor in tree.radius_search(x[index], y[index], z[index], radius) {
220 if neighbor.index != index {
221 index_buffer.push(neighbor.index);
222 }
223 }
224 } else {
225 tree.nearest_k_unsorted_into(
226 x[index],
227 y[index],
228 z[index],
229 config.k_neighbors.saturating_add(1),
230 neighbor_buffer,
231 );
232 for neighbor in neighbor_buffer.iter() {
233 if neighbor.index != index {
234 index_buffer.push(neighbor.index);
235 if index_buffer.len() == config.k_neighbors {
236 break;
237 }
238 }
239 }
240 }
241}
242
243#[must_use]
245pub fn orient_normal_towards_viewpoint(
246 mut normal: Vec3<f32>,
247 point: Vec3<f32>,
248 viewpoint: Vec3<f32>,
249) -> Vec3<f32> {
250 let view_direction =
251 Vec3::new(viewpoint.x - point.x, viewpoint.y - point.y, viewpoint.z - point.z);
252 if normal.dot(view_direction) < 0.0 {
253 normal.x = -normal.x;
254 normal.y = -normal.y;
255 normal.z = -normal.z;
256 }
257 normal.normalize()
258}
259
260fn point_xyz(x: &[f32], y: &[f32], z: &[f32], index: usize) -> Vec3<f32> {
261 Vec3::new(x[index], y[index], z[index])
262}
263
264fn estimate_normal_from_neighbors(
265 x: &[f32],
266 y: &[f32],
267 z: &[f32],
268 _center_index: usize,
269 neighbors: &[usize],
270) -> Option<(Vec3<f32>, f32)> {
271 let mut mean_x = 0.0_f32;
272 let mut mean_y = 0.0_f32;
273 let mut mean_z = 0.0_f32;
274 for &index in neighbors {
275 mean_x += x[index];
276 mean_y += y[index];
277 mean_z += z[index];
278 }
279 let count = neighbors.len() as f32;
280 mean_x /= count;
281 mean_y /= count;
282 mean_z /= count;
283
284 let mut c00 = 0.0_f32;
285 let mut c11 = 0.0_f32;
286 let mut c22 = 0.0_f32;
287 let mut c01 = 0.0_f32;
288 let mut c02 = 0.0_f32;
289 let mut c12 = 0.0_f32;
290 for &index in neighbors {
291 let dx = x[index] - mean_x;
292 let dy = y[index] - mean_y;
293 let dz = z[index] - mean_z;
294 c00 += dx * dx;
295 c11 += dy * dy;
296 c22 += dz * dz;
297 c01 += dx * dy;
298 c02 += dx * dz;
299 c12 += dy * dz;
300 }
301 let inv = 1.0 / count;
302 smallest_eigenpair_for_covariance(
303 c00 * inv,
304 c11 * inv,
305 c22 * inv,
306 c01 * inv,
307 c02 * inv,
308 c12 * inv,
309 )
310}
311
312fn smallest_eigenpair_for_covariance(
313 c00: f32,
314 c11: f32,
315 c22: f32,
316 c01: f32,
317 c02: f32,
318 c12: f32,
319) -> Option<(Vec3<f32>, f32)> {
320 let eigenvalues = symmetric_eigenvalues3(c00, c11, c22, c01, c02, c12);
321 let lambda = eigenvalues[0];
322 let normal =
323 eigenvector_for_eigenvalue(c00, c11, c22, c01, c02, c12, lambda).unwrap_or_else(|| {
324 let covariance = Mat3::<f64>::from_rows(
325 [c00 as f64, c01 as f64, c02 as f64],
326 [c01 as f64, c11 as f64, c12 as f64],
327 [c02 as f64, c12 as f64, c22 as f64],
328 );
329 let eigen = symmetric_eigen3(covariance);
330 Vec3::new(
331 eigen.eigenvectors.m[0][0] as f32,
332 eigen.eigenvectors.m[1][0] as f32,
333 eigen.eigenvectors.m[2][0] as f32,
334 )
335 .normalize()
336 });
337
338 let sum = eigenvalues[0] + eigenvalues[1] + eigenvalues[2];
339 let curvature = if sum > 0.0 { eigenvalues[0] / sum } else { 0.0 };
340 Some((normal.normalize(), curvature))
341}
342
/// Eigenvalues, in ascending order, of the symmetric 3x3 matrix
/// `[[c00, c01, c02], [c01, c11, c12], [c02, c12, c22]]`, computed with the closed-form
/// trigonometric method.
///
/// The tolerances below are absolute (`f32::EPSILON`), so callers should pass a matrix
/// scaled to roughly unit magnitude. Non-finite input never panics; the result then
/// contains NaN in unspecified positions.
fn symmetric_eigenvalues3(c00: f32, c11: f32, c22: f32, c01: f32, c02: f32, c12: f32) -> [f32; 3] {
    let p1 = c01 * c01 + c02 * c02 + c12 * c12;
    if p1 <= f32::EPSILON {
        // (Numerically) diagonal: the eigenvalues are the diagonal entries.
        let mut values = [c00, c11, c22];
        // `total_cmp` instead of `partial_cmp(..).unwrap()`, which panics on NaN.
        values.sort_unstable_by(f32::total_cmp);
        return values;
    }

    // Shift by the mean eigenvalue `q` and scale by `p` so that B = (A - qI) / p has
    // eigenvalues 2cos(phi + 2k*pi/3) with cos(3 phi) = det(B) / 2.
    let q = (c00 + c11 + c22) / 3.0;
    let b00 = c00 - q;
    let b11 = c11 - q;
    let b22 = c22 - q;
    let p2 = b00 * b00 + b11 * b11 + b22 * b22 + 2.0 * p1;
    let p = (p2 / 6.0).sqrt();
    if p <= f32::EPSILON {
        return [q, q, q];
    }

    let inv_p = 1.0 / p;
    let n00 = b00 * inv_p;
    let n11 = b11 * inv_p;
    let n22 = b22 * inv_p;
    let n01 = c01 * inv_p;
    let n02 = c02 * inv_p;
    let n12 = c12 * inv_p;
    let det = n00 * (n11 * n22 - n12 * n12) - n01 * (n01 * n22 - n12 * n02)
        + n02 * (n01 * n12 - n11 * n02);
    // Clamp guards `acos` against round-off just outside [-1, 1].
    let r = (det * 0.5).clamp(-1.0, 1.0);
    let phi = r.acos() / 3.0;

    let largest = q + 2.0 * p * phi.cos();
    let smallest = q + 2.0 * p * (phi + 2.0 * std::f32::consts::PI / 3.0).cos();
    // The trace equals the sum of the eigenvalues.
    let middle = 3.0 * q - largest - smallest;
    let mut values = [smallest, middle, largest];
    values.sort_unstable_by(f32::total_cmp);
    values
}
380
381fn eigenvector_for_eigenvalue(
382 c00: f32,
383 c11: f32,
384 c22: f32,
385 c01: f32,
386 c02: f32,
387 c12: f32,
388 lambda: f32,
389) -> Option<Vec3<f32>> {
390 let row0 = Vec3::new(c00 - lambda, c01, c02);
391 let row1 = Vec3::new(c01, c11 - lambda, c12);
392 let row2 = Vec3::new(c02, c12, c22 - lambda);
393
394 let candidates = [row0.cross(row1), row0.cross(row2), row1.cross(row2)];
395 let mut best = candidates[0];
396 let mut best_norm = best.length_squared();
397 for candidate in candidates.into_iter().skip(1) {
398 let norm = candidate.length_squared();
399 if norm > best_norm {
400 best = candidate;
401 best_norm = norm;
402 }
403 }
404
405 if best_norm <= 1e-24 {
406 None
407 } else {
408 Some(best.normalize())
409 }
410}
411
412pub(crate) fn build_output_cloud(
413 input: &PointCloud,
414 nx: Vec<f32>,
415 ny: Vec<f32>,
416 nz: Vec<f32>,
417 curvature: Vec<f32>,
418) -> SpatialResult<PointCloud> {
419 let mut schema = input.schema().clone();
420 ensure_field(&mut schema, "normal_x", FieldSemantic::NormalX, DType::F32);
421 ensure_field(&mut schema, "normal_y", FieldSemantic::NormalY, DType::F32);
422 ensure_field(&mut schema, "normal_z", FieldSemantic::NormalZ, DType::F32);
423 ensure_field(&mut schema, "curvature", FieldSemantic::Curvature, DType::F32);
424
425 let mut buffers = PointBufferSet::new();
426 for field in input.schema().fields() {
427 let source = input.field(&field.name)?;
428 buffers.insert(field.name.clone(), clone_buffer(source)?);
429 }
430 buffers.insert("normal_x".to_owned(), PointBuffer::from_f32(nx));
431 buffers.insert("normal_y".to_owned(), PointBuffer::from_f32(ny));
432 buffers.insert("normal_z".to_owned(), PointBuffer::from_f32(nz));
433 buffers.insert("curvature".to_owned(), PointBuffer::from_f32(curvature));
434
435 PointCloud::try_from_parts(schema, buffers, input.metadata().clone())
436}
437
438fn ensure_field(schema: &mut PointSchema, name: &str, semantic: FieldSemantic, dtype: DType) {
439 if schema.find_semantic(semantic).is_none() {
440 *schema = schema.clone().with_field(PointField::scalar(name, semantic, dtype));
441 }
442}
443
444fn clone_buffer(buffer: &PointBuffer) -> SpatialResult<PointBuffer> {
445 Ok(match buffer {
446 PointBuffer::F32(values) => PointBuffer::from_f32(values.clone()),
447 PointBuffer::F64(values) => PointBuffer::F64(values.clone()),
448 PointBuffer::U8(values) => PointBuffer::U8(values.clone()),
449 PointBuffer::U16(values) => PointBuffer::U16(values.clone()),
450 PointBuffer::U32(values) => PointBuffer::U32(values.clone()),
451 PointBuffer::I32(values) => PointBuffer::I32(values.clone()),
452 })
453}
454
#[cfg(test)]
mod tests {
    use super::{orient_normal_towards_viewpoint, NormalEstimationConfig, NormalEstimator};
    use crate::FeatureEstimator;
    use spatialrust_core::{HasNormals3, PointCloud, PointCloudBuilder, StandardSchemas};
    use spatialrust_math::Vec3;

    /// Builds an XYZ cloud containing `points` in iteration order.
    fn cloud_from(points: impl IntoIterator<Item = [f32; 3]>) -> PointCloud {
        let mut builder = PointCloudBuilder::new(StandardSchemas::point_xyz());
        for point in points {
            builder.push_point(point).unwrap();
        }
        builder.build().unwrap()
    }

    /// 5x5 grid with unit spacing in the plane z = 0.
    fn plane_cloud() -> PointCloud {
        cloud_from((0..5).flat_map(|x| (0..5).map(move |y| [x as f32, y as f32, 0.0])))
    }

    /// 7x7 grid with spacing 0.2 on the plane z = 0.2 x - 0.3 y + 0.1.
    fn tilted_plane_cloud() -> PointCloud {
        cloud_from((0..7).flat_map(|x| {
            (0..7).map(move |y| {
                let (fx, fy) = (x as f32 * 0.2, y as f32 * 0.2);
                [fx, fy, 0.2 * fx - 0.3 * fy + 0.1]
            })
        }))
    }

    #[test]
    fn estimates_plane_normals_upwards() {
        let input = plane_cloud();
        let config = NormalEstimationConfig {
            k_neighbors: 8,
            min_neighbors: 3,
            viewpoint: Some(Vec3::new(0.0, 0.0, 10.0)),
            ..NormalEstimationConfig::default()
        };
        let (output, stats) =
            NormalEstimator::new(config).estimate_with_diagnostics(&input).unwrap();
        assert_eq!(stats.valid_count, input.len());
        assert_eq!(stats.invalid_count, 0);

        let (_, _, nz) = output.normals3().unwrap();
        for &value in nz {
            assert!((value - 1.0).abs() < 0.1, "expected upward normal, got {value}");
        }
    }

    #[test]
    fn estimates_tilted_plane_normals() {
        let input = tilted_plane_cloud();
        let config = NormalEstimationConfig {
            k_neighbors: 12,
            min_neighbors: 3,
            viewpoint: Some(Vec3::new(0.0, 0.0, 10.0)),
            ..NormalEstimationConfig::default()
        };
        let output = NormalEstimator::new(config).estimate(&input).unwrap();
        let (nx, ny, nz) = output.normals3().unwrap();
        let expected = Vec3::new(-0.2, 0.3, 1.0).normalize();

        let normals = (0..input.len()).map(|i| Vec3::new(nx[i], ny[i], nz[i]).normalize());
        for actual in normals {
            assert!(actual.dot(expected) > 0.98, "tilted plane normal was {actual:?}");
        }
    }

    #[test]
    fn orient_normal_towards_viewpoint_works() {
        let oriented = orient_normal_towards_viewpoint(
            Vec3::new(0.0, 0.0, -1.0),
            Vec3::new(0.0, 0.0, 0.0),
            Vec3::new(0.0, 0.0, 1.0),
        );
        assert!(oriented.z > 0.0);
    }

    #[test]
    fn adds_curvature_field() {
        let estimator = NormalEstimator::new(NormalEstimationConfig::k_neighbors(10));
        let output = estimator.estimate(&plane_cloud()).unwrap();
        assert!(output.field("curvature").is_ok());
    }
}