spatialrust_features/
normal_gpu.rs1use spatialrust_core::{HasPositions3, PointCloud, SpatialResult};
2use spatialrust_gpu::{estimate_normals_gpu, estimate_normals_grid_gpu, WgpuRuntime};
3use spatialrust_math::Vec3;
4
5use crate::neighborhood::{KdTreeNeighborhood, NeighborhoodProvider};
6use crate::normal::{build_output_cloud, orient_normal_towards_viewpoint, NormalEstimationConfig};
7
8#[derive(Clone, Copy, Debug, PartialEq)]
15pub struct GpuNormalEstimator {
16 config: NormalEstimationConfig,
17}
18
19impl GpuNormalEstimator {
20 #[must_use]
22 pub const fn new(config: NormalEstimationConfig) -> Self {
23 Self { config }
24 }
25
26 #[must_use]
28 pub const fn config(&self) -> NormalEstimationConfig {
29 self.config
30 }
31
32 pub fn estimate(&self, input: &PointCloud) -> SpatialResult<PointCloud> {
34 if input.is_empty() {
35 return Ok(input.clone());
36 }
37
38 let (x, y, z) = input.positions3()?;
39 let n = input.len();
40 let runtime = WgpuRuntime::shared()?;
41
42 let gpu_normals = if let Some(radius) = self.config.search_radius {
45 estimate_normals_grid_gpu(&runtime, x, y, z, radius)?
46 } else {
47 let k = self.config.k_neighbors.max(1);
48 let neighborhood = KdTreeNeighborhood::from_point_cloud(input)?;
49 let mut flat = Vec::with_capacity(n * k);
50 for index in 0..n {
51 let mut neighbors = neighborhood.query_k(index, k)?;
52 if neighbors.is_empty() {
53 neighbors.push(index);
54 }
55 for slot in 0..k {
56 flat.push(neighbors[slot % neighbors.len()] as u32);
57 }
58 }
59 estimate_normals_gpu(&runtime, x, y, z, &flat, k as u32)?
60 };
61
62 let mut nx = Vec::with_capacity(n);
63 let mut ny = Vec::with_capacity(n);
64 let mut nz = Vec::with_capacity(n);
65 let mut curvature = Vec::with_capacity(n);
66 for (index, gpu_normal) in gpu_normals.iter().enumerate() {
67 let mut normal =
68 Vec3::new(gpu_normal.normal[0], gpu_normal.normal[1], gpu_normal.normal[2]);
69 if let Some(viewpoint) = self.config.viewpoint {
70 normal = orient_normal_towards_viewpoint(
71 normal,
72 Vec3::new(x[index], y[index], z[index]),
73 viewpoint,
74 );
75 }
76 nx.push(normal.x);
77 ny.push(normal.y);
78 nz.push(normal.z);
79 curvature.push(gpu_normal.curvature);
80 }
81
82 build_output_cloud(input, nx, ny, nz, curvature)
83 }
84}
85
#[cfg(test)]
mod tests {
    use super::GpuNormalEstimator;
    use crate::normal::NormalEstimationConfig;
    use spatialrust_core::{HasNormals3, PointCloud, PointCloudBuilder, StandardSchemas};
    use spatialrust_gpu::WgpuRuntime;
    use spatialrust_math::Vec3;

    /// Returns `true` when the shared wgpu runtime can be obtained.
    ///
    /// GPU tests return early (and therefore pass) when it cannot be obtained.
    fn gpu_available() -> bool {
        WgpuRuntime::shared().is_ok()
    }

    /// Viewpoint above the z = 0 plane, so oriented normals should point along +z.
    fn above_plane() -> Vec3 {
        Vec3::new(0.0, 0.0, 10.0)
    }

    /// Builds a `side` x `side` grid of points spaced 0.1 apart in the z = 0 plane.
    fn planar_grid(side: usize) -> PointCloud {
        let mut builder = PointCloudBuilder::new(StandardSchemas::point_xyz());
        for i in 0..side {
            for j in 0..side {
                builder
                    .push_point([i as f32 * 0.1, j as f32 * 0.1, 0.0])
                    .unwrap();
            }
        }
        builder.build().unwrap()
    }

    #[test]
    fn estimates_plane_normals_on_gpu() {
        if !gpu_available() {
            return;
        }

        let config = NormalEstimationConfig {
            k_neighbors: 8,
            min_neighbors: 3,
            viewpoint: Some(above_plane()),
            ..Default::default()
        };
        let output = GpuNormalEstimator::new(config)
            .estimate(&planar_grid(8))
            .unwrap();

        let (nx, ny, nz) = output.normals3().unwrap();
        for i in 0..output.len() {
            assert!(nz[i] > 0.99, "normal not vertical: {}", nz[i]);
            assert!(nx[i].abs() < 0.1 && ny[i].abs() < 0.1);
        }
        assert!(output.field("curvature").is_ok());
    }

    #[test]
    fn estimates_plane_normals_on_gpu_grid() {
        if !gpu_available() {
            return;
        }

        let config = NormalEstimationConfig {
            search_radius: Some(0.25),
            viewpoint: Some(above_plane()),
            ..Default::default()
        };
        let output = GpuNormalEstimator::new(config)
            .estimate(&planar_grid(12))
            .unwrap();

        let (_, _, nz) = output.normals3().unwrap();
        for &value in nz {
            assert!(value > 0.99, "normal not vertical: {value}");
        }
    }
}