Skip to main content

spatialrust_gpu/kernels/
normals.rs

1use bytemuck::{Pod, Zeroable};
2use spatialrust_core::{SpatialError, SpatialResult};
3use wgpu::util::DeviceExt;
4
5use crate::runtime::WgpuRuntime;
6
7const WORKGROUP_SIZE: u32 = 256;
8
9#[repr(C)]
10#[derive(Clone, Copy, Debug, Pod, Zeroable)]
11struct NormalsUniform {
12    point_count: u32,
13    k: u32,
14    _pad0: u32,
15    _pad1: u32,
16}
17
/// Per-point normal estimation output: `(nx, ny, nz, curvature)`.
///
/// Derives `PartialEq` so results can be compared directly (e.g. in tests).
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct GpuNormal {
    /// Unit normal x/y/z. The sign is arbitrary.
    pub normal: [f32; 3],
    /// Surface variation `λmin / (λ0 + λ1 + λ2)` of the neighborhood covariance.
    /// It is mathematically in `[0, 1/3]` and `0` for degenerate neighborhoods.
    pub curvature: f32,
}
26
/// WGSL compute shader; one invocation per point.
///
/// For point `i` it:
/// 1. Reads the `k` neighbor indices `neighbors[i*k .. i*k + k]`.
/// 2. Accumulates the unnormalized 3x3 covariance of those neighbor positions about their centroid.
/// 3. Diagonalizes the covariance with a fixed number of cyclic Jacobi sweeps.
/// 4. Writes `(nx, ny, nz, curvature)`.
///
/// The normal is the eigenvector of the smallest eigenvalue. The curvature is
/// `λmin / trace`, clamped to `[0, 1/3]` to absorb rounding. When `k < 3`,
/// every point gets the fallback `(0, 0, 1, 0)`.
const NORMALS_WGSL: &str = r#"
struct Params { point_count: u32, k: u32, pad0: u32, pad1: u32, };

@group(0) @binding(0) var<uniform> params: Params;
@group(0) @binding(1) var<storage, read> xs: array<f32>;
@group(0) @binding(2) var<storage, read> ys: array<f32>;
@group(0) @binding(3) var<storage, read> zs: array<f32>;
@group(0) @binding(4) var<storage, read> neighbors: array<u32>;
@group(0) @binding(5) var<storage, read_write> out_normals: array<vec4<f32>>;

// One cyclic-Jacobi rotation in the (p, q) plane. It zeroes the off-diagonal
// entry a[p][q] of the symmetric matrix `a` (a' = P^T a P). It also
// accumulates the rotation into `v` (v' = v P), whose columns converge to the
// eigenvectors.
fn rotate(a: ptr<function, array<vec3<f32>, 3>>,
          v: ptr<function, array<vec3<f32>, 3>>,
          p: u32, q: u32) {
    let apq = (*a)[p][q];
    // Already (numerically) diagonal in this plane; also avoids dividing by ~0.
    if (abs(apq) < 1e-20) {
        return;
    }
    let app = (*a)[p][p];
    let aqq = (*a)[q][q];
    // t = tan(angle) is the smaller-magnitude root of t^2 + 2*phi*t - 1 = 0.
    let phi = 0.5 * (aqq - app) / apq;
    var t: f32;
    if (phi >= 0.0) {
        t = 1.0 / (phi + sqrt(1.0 + phi * phi));
    } else {
        t = -1.0 / (-phi + sqrt(1.0 + phi * phi));
    }
    let c = 1.0 / sqrt(1.0 + t * t);
    let s = t * c;

    // Column update: a <- a P.
    for (var r: u32 = 0u; r < 3u; r = r + 1u) {
        let arp = (*a)[r][p];
        let arq = (*a)[r][q];
        (*a)[r][p] = c * arp - s * arq;
        (*a)[r][q] = s * arp + c * arq;
    }
    // Row update: a <- P^T a.
    for (var r: u32 = 0u; r < 3u; r = r + 1u) {
        let apr = (*a)[p][r];
        let aqr = (*a)[q][r];
        (*a)[p][r] = c * apr - s * aqr;
        (*a)[q][r] = s * apr + c * aqr;
    }
    // Eigenvector accumulation: v <- v P.
    for (var r: u32 = 0u; r < 3u; r = r + 1u) {
        let vrp = (*v)[r][p];
        let vrq = (*v)[r][q];
        (*v)[r][p] = c * vrp - s * vrq;
        (*v)[r][q] = s * vrp + c * vrq;
    }
}

// Must match WORKGROUP_SIZE on the Rust side, which sizes the dispatch.
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    // The last workgroup may overshoot the point count.
    if (i >= params.point_count) {
        return;
    }
    let k = params.k;
    // Fewer than three neighbors cannot define a plane: emit a fixed fallback.
    if (k < 3u) {
        out_normals[i] = vec4<f32>(0.0, 0.0, 1.0, 0.0);
        return;
    }
    let base = i * k;
    let count = f32(k);

    var mean = vec3<f32>(0.0, 0.0, 0.0);
    for (var j: u32 = 0u; j < k; j = j + 1u) {
        let idx = neighbors[base + j];
        mean = mean + vec3<f32>(xs[idx], ys[idx], zs[idx]);
    }
    mean = mean / count;

    // Unnormalized covariance: the scale cancels in both the eigenvectors and
    // the curvature ratio.
    var c00 = 0.0; var c11 = 0.0; var c22 = 0.0;
    var c01 = 0.0; var c02 = 0.0; var c12 = 0.0;
    for (var j: u32 = 0u; j < k; j = j + 1u) {
        let idx = neighbors[base + j];
        let d = vec3<f32>(xs[idx], ys[idx], zs[idx]) - mean;
        c00 = c00 + d.x * d.x;
        c11 = c11 + d.y * d.y;
        c22 = c22 + d.z * d.z;
        c01 = c01 + d.x * d.y;
        c02 = c02 + d.x * d.z;
        c12 = c12 + d.y * d.z;
    }

    var a = array<vec3<f32>, 3>(
        vec3<f32>(c00, c01, c02),
        vec3<f32>(c01, c11, c12),
        vec3<f32>(c02, c12, c22),
    );
    var v = array<vec3<f32>, 3>(
        vec3<f32>(1.0, 0.0, 0.0),
        vec3<f32>(0.0, 1.0, 0.0),
        vec3<f32>(0.0, 0.0, 1.0),
    );

    // Fixed number of cyclic sweeps over the three off-diagonal pairs.
    for (var sweep: u32 = 0u; sweep < 16u; sweep = sweep + 1u) {
        rotate(&a, &v, 0u, 1u);
        rotate(&a, &v, 0u, 2u);
        rotate(&a, &v, 1u, 2u);
    }

    let eig = vec3<f32>(a[0][0], a[1][1], a[2][2]);
    var min_idx = 0u;
    if (eig[1] < eig[min_idx]) { min_idx = 1u; }
    if (eig[2] < eig[min_idx]) { min_idx = 2u; }

    // Column `min_idx` of v is the eigenvector of the smallest eigenvalue.
    let normal = vec3<f32>(v[0][min_idx], v[1][min_idx], v[2][min_idx]);
    let unit = normal / max(length(normal), 1e-20);

    // Rounding can leave the smallest eigenvalue slightly negative (or the
    // ratio a hair above 1/3); clamp to the mathematically valid range.
    let trace = eig[0] + eig[1] + eig[2];
    var curvature = 0.0;
    if (trace > 0.0) {
        curvature = clamp(eig[min_idx] / trace, 0.0, 1.0 / 3.0);
    }

    out_normals[i] = vec4<f32>(unit.x, unit.y, unit.z, curvature);
}
"#;
147
148/// Estimates per-point normals and curvature on the GPU.
149///
150/// `neighbors` is a flattened `point_count * k` array of indices into the point
151/// arrays, where row `i` lists the neighbors of point `i` (pad short rows by
152/// repeating the point's own index). Normal orientation is arbitrary (sign is
153/// not disambiguated); callers can flip toward a viewpoint on the CPU.
154pub fn estimate_normals_gpu(
155    runtime: &WgpuRuntime,
156    x: &[f32],
157    y: &[f32],
158    z: &[f32],
159    neighbors: &[u32],
160    k: u32,
161) -> SpatialResult<Vec<GpuNormal>> {
162    let point_count = x.len();
163    if y.len() != point_count || z.len() != point_count {
164        return Err(SpatialError::BufferLengthMismatch { expected: point_count, found: y.len() });
165    }
166    if point_count == 0 {
167        return Ok(Vec::new());
168    }
169    if k == 0 || neighbors.len() != point_count * k as usize {
170        return Err(SpatialError::InvalidArgument(format!(
171            "neighbors must have point_count*k = {} entries, got {}",
172            point_count * k as usize,
173            neighbors.len()
174        )));
175    }
176
177    let device = runtime.device();
178    let queue = runtime.queue();
179
180    let x_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
181        label: Some("normals-x"),
182        contents: bytemuck::cast_slice(x),
183        usage: wgpu::BufferUsages::STORAGE,
184    });
185    let y_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
186        label: Some("normals-y"),
187        contents: bytemuck::cast_slice(y),
188        usage: wgpu::BufferUsages::STORAGE,
189    });
190    let z_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
191        label: Some("normals-z"),
192        contents: bytemuck::cast_slice(z),
193        usage: wgpu::BufferUsages::STORAGE,
194    });
195    let neighbor_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
196        label: Some("normals-neighbors"),
197        contents: bytemuck::cast_slice(neighbors),
198        usage: wgpu::BufferUsages::STORAGE,
199    });
200    let uniform = NormalsUniform { point_count: point_count as u32, k, _pad0: 0, _pad1: 0 };
201    let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
202        label: Some("normals-uniform"),
203        contents: bytemuck::bytes_of(&uniform),
204        usage: wgpu::BufferUsages::UNIFORM,
205    });
206
207    let output_len = (point_count * std::mem::size_of::<[f32; 4]>()) as u64;
208    let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
209        label: Some("normals-output"),
210        size: output_len,
211        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
212        mapped_at_creation: false,
213    });
214
215    let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
216        label: Some("normals-shader"),
217        source: wgpu::ShaderSource::Wgsl(NORMALS_WGSL.into()),
218    });
219    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
220        label: Some("normals-pipeline"),
221        layout: None,
222        module: &module,
223        entry_point: Some("main"),
224        compilation_options: wgpu::PipelineCompilationOptions::default(),
225        cache: None,
226    });
227
228    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
229        label: Some("normals-bind-group"),
230        layout: &pipeline.get_bind_group_layout(0),
231        entries: &[
232            wgpu::BindGroupEntry { binding: 0, resource: uniform_buffer.as_entire_binding() },
233            wgpu::BindGroupEntry { binding: 1, resource: x_buffer.as_entire_binding() },
234            wgpu::BindGroupEntry { binding: 2, resource: y_buffer.as_entire_binding() },
235            wgpu::BindGroupEntry { binding: 3, resource: z_buffer.as_entire_binding() },
236            wgpu::BindGroupEntry { binding: 4, resource: neighbor_buffer.as_entire_binding() },
237            wgpu::BindGroupEntry { binding: 5, resource: output_buffer.as_entire_binding() },
238        ],
239    });
240
241    let mut encoder =
242        device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some("normals") });
243    {
244        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
245            label: Some("normals-pass"),
246            timestamp_writes: None,
247        });
248        pass.set_pipeline(&pipeline);
249        pass.set_bind_group(0, &bind_group, &[]);
250        pass.dispatch_workgroups((point_count as u32).div_ceil(WORKGROUP_SIZE), 1, 1);
251    }
252    queue.submit(Some(encoder.finish()));
253
254    let staging = device.create_buffer(&wgpu::BufferDescriptor {
255        label: Some("normals-staging"),
256        size: output_len,
257        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
258        mapped_at_creation: false,
259    });
260    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
261        label: Some("normals-readback"),
262    });
263    encoder.copy_buffer_to_buffer(&output_buffer, 0, &staging, 0, output_len);
264    queue.submit(Some(encoder.finish()));
265
266    let slice = staging.slice(..);
267    let (sender, receiver) = std::sync::mpsc::channel();
268    slice.map_async(wgpu::MapMode::Read, move |result| {
269        let _ = sender.send(result);
270    });
271    device.poll(wgpu::Maintain::Wait);
272    receiver
273        .recv()
274        .map_err(|_| SpatialError::InvalidArgument("failed to receive wgpu map result".to_owned()))?
275        .map_err(|error| {
276            SpatialError::InvalidArgument(format!("failed to map wgpu buffer: {error}"))
277        })?;
278
279    let data = slice.get_mapped_range();
280    let raw: &[[f32; 4]] = bytemuck::cast_slice(&data);
281    let normals =
282        raw.iter().map(|v| GpuNormal { normal: [v[0], v[1], v[2]], curvature: v[3] }).collect();
283    drop(data);
284    staging.unmap();
285
286    Ok(normals)
287}
288
289#[cfg(test)]
290mod tests {
291    use super::estimate_normals_gpu;
292    use crate::runtime::WgpuRuntime;
293
294    #[test]
295    fn planar_patch_has_vertical_normal() {
296        let runtime = WgpuRuntime::new_headless().expect("wgpu runtime");
297        // 5x5 grid on the z=0 plane.
298        let mut x: Vec<f32> = Vec::new();
299        let mut y: Vec<f32> = Vec::new();
300        let mut z: Vec<f32> = Vec::new();
301        for i in 0..5 {
302            for j in 0..5 {
303                x.push(i as f32 * 0.1);
304                y.push(j as f32 * 0.1);
305                z.push(0.0);
306            }
307        }
308        let n = x.len();
309        let k = 8u32;
310
311        // Brute-force k nearest neighbors per point (CPU).
312        let mut neighbors = Vec::with_capacity(n * k as usize);
313        for i in 0..n {
314            let mut order: Vec<usize> = (0..n).collect();
315            order.sort_by(|&a, &b| {
316                let da = (x[a] - x[i]).powi(2) + (y[a] - y[i]).powi(2) + (z[a] - z[i]).powi(2);
317                let db = (x[b] - x[i]).powi(2) + (y[b] - y[i]).powi(2) + (z[b] - z[i]).powi(2);
318                da.total_cmp(&db)
319            });
320            for &idx in order.iter().take(k as usize) {
321                neighbors.push(idx as u32);
322            }
323        }
324
325        let normals =
326            estimate_normals_gpu(&runtime, &x, &y, &z, &neighbors, k).expect("gpu normals");
327        assert_eq!(normals.len(), n);
328        for normal in &normals {
329            // Normal should point along Z (up or down) and curvature ~0 on a plane.
330            assert!(normal.normal[2].abs() > 0.99, "normal not vertical: {:?}", normal.normal);
331            assert!(normal.curvature < 1e-3, "curvature too high: {}", normal.curvature);
332        }
333    }
334}