Skip to main content

spatialrust_features/
normal_gpu.rs

1use spatialrust_core::{HasPositions3, PointCloud, SpatialResult};
2use spatialrust_gpu::{estimate_normals_gpu, estimate_normals_grid_gpu, WgpuRuntime};
3use spatialrust_math::Vec3;
4
5use crate::neighborhood::{KdTreeNeighborhood, NeighborhoodProvider};
6use crate::normal::{build_output_cloud, orient_normal_towards_viewpoint, NormalEstimationConfig};
7
8/// GPU-accelerated normal estimator.
9///
10/// Neighbor search runs on the CPU (KD-tree); the per-point covariance analysis
11/// and eigen-decomposition run on the GPU via wgpu. Output matches
12/// [`crate::NormalEstimator`]: `normal_x/y/z` and `curvature` fields, optionally
13/// oriented toward a viewpoint.
14#[derive(Clone, Copy, Debug, PartialEq)]
15pub struct GpuNormalEstimator {
16    config: NormalEstimationConfig,
17}
18
19impl GpuNormalEstimator {
20    /// Creates a GPU normal estimator from config.
21    #[must_use]
22    pub const fn new(config: NormalEstimationConfig) -> Self {
23        Self { config }
24    }
25
26    /// Returns the config.
27    #[must_use]
28    pub const fn config(&self) -> NormalEstimationConfig {
29        self.config
30    }
31
32    /// Estimates normals on the GPU, returning a cloud with normal/curvature fields.
33    pub fn estimate(&self, input: &PointCloud) -> SpatialResult<PointCloud> {
34        if input.is_empty() {
35            return Ok(input.clone());
36        }
37
38        let (x, y, z) = input.positions3()?;
39        let n = input.len();
40        let runtime = WgpuRuntime::shared()?;
41
42        // With a search radius, run the neighbor search entirely on the GPU via a
43        // uniform grid. Otherwise use CPU KD-tree k-NN feeding the GPU eigensolver.
44        let gpu_normals = if let Some(radius) = self.config.search_radius {
45            estimate_normals_grid_gpu(&runtime, x, y, z, radius)?
46        } else {
47            let k = self.config.k_neighbors.max(1);
48            let neighborhood = KdTreeNeighborhood::from_point_cloud(input)?;
49            let mut flat = Vec::with_capacity(n * k);
50            for index in 0..n {
51                let mut neighbors = neighborhood.query_k(index, k)?;
52                if neighbors.is_empty() {
53                    neighbors.push(index);
54                }
55                for slot in 0..k {
56                    flat.push(neighbors[slot % neighbors.len()] as u32);
57                }
58            }
59            estimate_normals_gpu(&runtime, x, y, z, &flat, k as u32)?
60        };
61
62        let mut nx = Vec::with_capacity(n);
63        let mut ny = Vec::with_capacity(n);
64        let mut nz = Vec::with_capacity(n);
65        let mut curvature = Vec::with_capacity(n);
66        for (index, gpu_normal) in gpu_normals.iter().enumerate() {
67            let mut normal =
68                Vec3::new(gpu_normal.normal[0], gpu_normal.normal[1], gpu_normal.normal[2]);
69            if let Some(viewpoint) = self.config.viewpoint {
70                normal = orient_normal_towards_viewpoint(
71                    normal,
72                    Vec3::new(x[index], y[index], z[index]),
73                    viewpoint,
74                );
75            }
76            nx.push(normal.x);
77            ny.push(normal.y);
78            nz.push(normal.z);
79            curvature.push(gpu_normal.curvature);
80        }
81
82        build_output_cloud(input, nx, ny, nz, curvature)
83    }
84}
85
#[cfg(test)]
mod tests {
    use super::GpuNormalEstimator;
    use crate::normal::NormalEstimationConfig;
    use spatialrust_core::{HasNormals3, PointCloud, PointCloudBuilder, StandardSchemas};
    use spatialrust_gpu::WgpuRuntime;
    use spatialrust_math::Vec3;

    /// Returns `false` when no GPU or software adapter is available.
    ///
    /// Tests use this to skip gracefully rather than fail.
    fn gpu_available() -> bool {
        WgpuRuntime::shared().is_ok()
    }

    /// Builds a `side x side` grid of points on the z = 0 plane with 0.1 spacing.
    fn planar_grid(side: usize) -> PointCloud {
        let mut builder = PointCloudBuilder::new(StandardSchemas::point_xyz());
        for i in 0..side {
            for j in 0..side {
                builder.push_point([i as f32 * 0.1, j as f32 * 0.1, 0.0]).unwrap();
            }
        }
        builder.build().unwrap()
    }

    /// Asserts that `output` has one normal per point and that every normal
    /// points straight up. This holds for a z = 0 plane oriented toward a
    /// viewpoint above it. Also asserts that a curvature field was written.
    fn assert_upward_plane_normals(output: &PointCloud) {
        let (nx, ny, nz) = output.normals3().unwrap();
        assert_eq!(nz.len(), output.len(), "one normal per point expected");
        for ((&x, &y), &z) in nx.iter().zip(ny).zip(nz) {
            assert!(z > 0.99, "normal not vertical: {z}");
            assert!(x.abs() < 0.1 && y.abs() < 0.1, "normal tilted: ({x}, {y}, {z})");
        }
        assert!(output.field("curvature").is_ok());
    }

    #[test]
    fn estimates_plane_normals_on_gpu() {
        if !gpu_available() {
            return;
        }

        // Pin `search_radius` to `None` so this test always covers the CPU
        // KD-tree + GPU eigensolver path, whatever `Default` provides.
        let estimator = GpuNormalEstimator::new(NormalEstimationConfig {
            k_neighbors: 8,
            min_neighbors: 3,
            search_radius: None,
            viewpoint: Some(Vec3::new(0.0, 0.0, 10.0)),
            ..Default::default()
        });
        let output = estimator.estimate(&planar_grid(8)).unwrap();

        assert_upward_plane_normals(&output);
    }

    #[test]
    fn estimates_plane_normals_on_gpu_grid() {
        // Radius mode runs the neighbor search entirely on the GPU.
        if !gpu_available() {
            return;
        }

        let estimator = GpuNormalEstimator::new(NormalEstimationConfig {
            search_radius: Some(0.25),
            viewpoint: Some(Vec3::new(0.0, 0.0, 10.0)),
            ..Default::default()
        });
        let output = estimator.estimate(&planar_grid(12)).unwrap();

        assert_upward_plane_normals(&output);
    }
}