// spatialrust_gpu/kernels/normals_grid.rs

1use bytemuck::{Pod, Zeroable};
2use spatialrust_core::{SpatialError, SpatialResult};
3use wgpu::util::DeviceExt;
4
5use crate::kernels::normals::GpuNormal;
6use crate::runtime::WgpuRuntime;
7
/// Invocations per compute workgroup; must equal the `@workgroup_size` declared
/// in the WGSL entry point, since the dispatch size is derived from it.
const WORKGROUP_SIZE: u32 = 256;
/// Cap on the number of dense grid cells. It bounds the size of the per-cell
/// offset table built on the host and uploaded to the GPU; inputs that would
/// exceed it are rejected so callers can fall back to the CPU/KD-tree path.
const MAX_CELLS: u64 = 64 * 1_000_000;
12
13#[repr(C)]
14#[derive(Clone, Copy, Debug, Pod, Zeroable)]
15struct GridUniform {
16    origin: [f32; 4],
17    dims: [u32; 4], // dimx, dimy, dimz, point_count
18    inv_cell: f32,
19    radius_sq: f32,
20    _pad0: f32,
21    _pad1: f32,
22}
23
/// WGSL compute shader for the grid radius search.
///
/// Each invocation handles one point. It gathers radius neighbors from the 27
/// surrounding cells and builds their covariance. It then diagonalizes the
/// covariance with cyclic Jacobi sweeps and writes `(normal.xyz, curvature)`
/// as a `vec4<f32>`.
///
/// Neighbor positions are accumulated as offsets from the query point. This
/// way, large absolute coordinates do not cancel catastrophically in f32 when
/// the centroid is subtracted.
const NORMALS_GRID_WGSL: &str = r#"
struct Params {
    origin: vec4<f32>,
    dims: vec4<u32>,
    inv_cell: f32,
    radius_sq: f32,
    pad0: f32,
    pad1: f32,
};

@group(0) @binding(0) var<uniform> params: Params;
@group(0) @binding(1) var<storage, read> xs: array<f32>;
@group(0) @binding(2) var<storage, read> ys: array<f32>;
@group(0) @binding(3) var<storage, read> zs: array<f32>;
@group(0) @binding(4) var<storage, read> sorted: array<u32>;
@group(0) @binding(5) var<storage, read> cell_start: array<u32>;
@group(0) @binding(6) var<storage, read_write> out_normals: array<vec4<f32>>;

// One Jacobi rotation that zeroes the symmetric off-diagonal entry a[p][q];
// the same rotation is accumulated into the eigenvector matrix v.
fn rotate(a: ptr<function, array<vec3<f32>, 3>>,
          v: ptr<function, array<vec3<f32>, 3>>,
          p: u32, q: u32) {
    let apq = (*a)[p][q];
    if (abs(apq) < 1e-20) {
        return;
    }
    let app = (*a)[p][p];
    let aqq = (*a)[q][q];
    let phi = 0.5 * (aqq - app) / apq;
    var t: f32;
    if (phi >= 0.0) {
        t = 1.0 / (phi + sqrt(1.0 + phi * phi));
    } else {
        t = -1.0 / (-phi + sqrt(1.0 + phi * phi));
    }
    let c = 1.0 / sqrt(1.0 + t * t);
    let s = t * c;
    for (var r: u32 = 0u; r < 3u; r = r + 1u) {
        let arp = (*a)[r][p];
        let arq = (*a)[r][q];
        (*a)[r][p] = c * arp - s * arq;
        (*a)[r][q] = s * arp + c * arq;
    }
    for (var r: u32 = 0u; r < 3u; r = r + 1u) {
        let apr = (*a)[p][r];
        let aqr = (*a)[q][r];
        (*a)[p][r] = c * apr - s * aqr;
        (*a)[q][r] = s * apr + c * aqr;
    }
    for (var r: u32 = 0u; r < 3u; r = r + 1u) {
        let vrp = (*v)[r][p];
        let vrq = (*v)[r][q];
        (*v)[r][p] = c * vrp - s * vrq;
        (*v)[r][q] = s * vrp + c * vrq;
    }
}

// Grid cell index of `value` along one axis, clamped into [0, dim).
fn cell_coord(value: f32, origin: f32, inv_cell: f32, dim: u32) -> i32 {
    let c = i32(floor((value - origin) * inv_cell));
    return clamp(c, 0, i32(dim) - 1);
}

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    if (i >= params.dims.w) {
        return;
    }
    let center = vec3<f32>(xs[i], ys[i], zs[i]);
    let dimx = params.dims.x;
    let dimy = params.dims.y;
    let dimz = params.dims.z;

    let cx = cell_coord(center.x, params.origin.x, params.inv_cell, dimx);
    let cy = cell_coord(center.y, params.origin.y, params.inv_cell, dimy);
    let cz = cell_coord(center.z, params.origin.z, params.inv_cell, dimz);

    // First pass: centroid of the radius neighbors across the 27 adjacent cells,
    // accumulated as offsets from `center` to preserve f32 precision.
    var mean = vec3<f32>(0.0, 0.0, 0.0);
    var count = 0.0;
    for (var dz = -1; dz <= 1; dz = dz + 1) {
        let nz = cz + dz;
        if (nz < 0 || nz >= i32(dimz)) { continue; }
        for (var dy = -1; dy <= 1; dy = dy + 1) {
            let ny = cy + dy;
            if (ny < 0 || ny >= i32(dimy)) { continue; }
            for (var dx = -1; dx <= 1; dx = dx + 1) {
                let nx = cx + dx;
                if (nx < 0 || nx >= i32(dimx)) { continue; }
                let cid = (u32(nz) * dimy + u32(ny)) * dimx + u32(nx);
                let begin = cell_start[cid];
                let end = cell_start[cid + 1u];
                for (var s = begin; s < end; s = s + 1u) {
                    let j = sorted[s];
                    let rel = vec3<f32>(xs[j], ys[j], zs[j]) - center;
                    if (dot(rel, rel) <= params.radius_sq) {
                        mean = mean + rel;
                        count = count + 1.0;
                    }
                }
            }
        }
    }

    if (count < 3.0) {
        out_normals[i] = vec4<f32>(0.0, 0.0, 1.0, 0.0);
        return;
    }
    mean = mean / count;

    // Second pass: covariance of the same neighbor set about the relative centroid.
    var c00 = 0.0; var c11 = 0.0; var c22 = 0.0;
    var c01 = 0.0; var c02 = 0.0; var c12 = 0.0;
    for (var dz = -1; dz <= 1; dz = dz + 1) {
        let nz = cz + dz;
        if (nz < 0 || nz >= i32(dimz)) { continue; }
        for (var dy = -1; dy <= 1; dy = dy + 1) {
            let ny = cy + dy;
            if (ny < 0 || ny >= i32(dimy)) { continue; }
            for (var dx = -1; dx <= 1; dx = dx + 1) {
                let nx = cx + dx;
                if (nx < 0 || nx >= i32(dimx)) { continue; }
                let cid = (u32(nz) * dimy + u32(ny)) * dimx + u32(nx);
                let begin = cell_start[cid];
                let end = cell_start[cid + 1u];
                for (var s = begin; s < end; s = s + 1u) {
                    let j = sorted[s];
                    let rel = vec3<f32>(xs[j], ys[j], zs[j]) - center;
                    if (dot(rel, rel) <= params.radius_sq) {
                        let dd = rel - mean;
                        c00 = c00 + dd.x * dd.x;
                        c11 = c11 + dd.y * dd.y;
                        c22 = c22 + dd.z * dd.z;
                        c01 = c01 + dd.x * dd.y;
                        c02 = c02 + dd.x * dd.z;
                        c12 = c12 + dd.y * dd.z;
                    }
                }
            }
        }
    }

    var a = array<vec3<f32>, 3>(
        vec3<f32>(c00, c01, c02),
        vec3<f32>(c01, c11, c12),
        vec3<f32>(c02, c12, c22),
    );
    var v = array<vec3<f32>, 3>(
        vec3<f32>(1.0, 0.0, 0.0),
        vec3<f32>(0.0, 1.0, 0.0),
        vec3<f32>(0.0, 0.0, 1.0),
    );
    for (var sweep: u32 = 0u; sweep < 16u; sweep = sweep + 1u) {
        rotate(&a, &v, 0u, 1u);
        rotate(&a, &v, 0u, 2u);
        rotate(&a, &v, 1u, 2u);
    }

    // The normal is the eigenvector of the smallest eigenvalue; curvature is
    // that eigenvalue's share of the trace.
    let eig = vec3<f32>(a[0][0], a[1][1], a[2][2]);
    var min_idx = 0u;
    if (eig[1] < eig[min_idx]) { min_idx = 1u; }
    if (eig[2] < eig[min_idx]) { min_idx = 2u; }
    let normal = vec3<f32>(v[0][min_idx], v[1][min_idx], v[2][min_idx]);
    let len = max(sqrt(dot(normal, normal)), 1e-20);
    let unit = normal / len;
    let trace = eig[0] + eig[1] + eig[2];
    var curvature = 0.0;
    if (trace > 0.0) {
        curvature = eig[min_idx] / trace;
    }
    out_normals[i] = vec4<f32>(unit.x, unit.y, unit.z, curvature);
}
"#;
198
199/// Estimates per-point normals and curvature with a fully GPU radius neighbor
200/// search over a uniform grid.
201///
202/// The grid (cell size = `radius`) is built on the CPU with a counting sort
203/// (O(n)); the per-point neighbor gather, covariance, and eigen-decomposition
204/// all run on the GPU. Returns `SpatialError::InvalidArgument` when the bounding
205/// grid would exceed an internal cell cap (caller should fall back to the CPU
206/// KD-tree path).
207pub fn estimate_normals_grid_gpu(
208    runtime: &WgpuRuntime,
209    x: &[f32],
210    y: &[f32],
211    z: &[f32],
212    radius: f32,
213) -> SpatialResult<Vec<GpuNormal>> {
214    let point_count = x.len();
215    if y.len() != point_count || z.len() != point_count {
216        return Err(SpatialError::BufferLengthMismatch { expected: point_count, found: y.len() });
217    }
218    if point_count == 0 {
219        return Ok(Vec::new());
220    }
221    if radius <= 0.0 || radius.is_nan() {
222        return Err(SpatialError::InvalidArgument("grid radius must be positive".to_owned()));
223    }
224
225    let (origin, dims) = grid_bounds(x, y, z, radius)?;
226    let (sorted, cell_start) = build_grid(x, y, z, origin, dims, radius);
227
228    let device = runtime.device();
229    let queue = runtime.queue();
230    let inv_cell = 1.0 / radius;
231
232    let storage = wgpu::BufferUsages::STORAGE;
233    let x_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
234        label: Some("ng-x"),
235        contents: bytemuck::cast_slice(x),
236        usage: storage,
237    });
238    let y_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
239        label: Some("ng-y"),
240        contents: bytemuck::cast_slice(y),
241        usage: storage,
242    });
243    let z_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
244        label: Some("ng-z"),
245        contents: bytemuck::cast_slice(z),
246        usage: storage,
247    });
248    let sorted_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
249        label: Some("ng-sorted"),
250        contents: bytemuck::cast_slice(&sorted),
251        usage: storage,
252    });
253    let cell_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
254        label: Some("ng-cell-start"),
255        contents: bytemuck::cast_slice(&cell_start),
256        usage: storage,
257    });
258    let uniform = GridUniform {
259        origin: [origin[0], origin[1], origin[2], 0.0],
260        dims: [dims[0], dims[1], dims[2], point_count as u32],
261        inv_cell,
262        radius_sq: radius * radius,
263        _pad0: 0.0,
264        _pad1: 0.0,
265    };
266    let uniform_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
267        label: Some("ng-uniform"),
268        contents: bytemuck::bytes_of(&uniform),
269        usage: wgpu::BufferUsages::UNIFORM,
270    });
271
272    let output_len = (point_count * std::mem::size_of::<[f32; 4]>()) as u64;
273    let output_buf = device.create_buffer(&wgpu::BufferDescriptor {
274        label: Some("ng-output"),
275        size: output_len,
276        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
277        mapped_at_creation: false,
278    });
279
280    let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
281        label: Some("ng-shader"),
282        source: wgpu::ShaderSource::Wgsl(NORMALS_GRID_WGSL.into()),
283    });
284    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
285        label: Some("ng-pipeline"),
286        layout: None,
287        module: &module,
288        entry_point: Some("main"),
289        compilation_options: wgpu::PipelineCompilationOptions::default(),
290        cache: None,
291    });
292    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
293        label: Some("ng-bind-group"),
294        layout: &pipeline.get_bind_group_layout(0),
295        entries: &[
296            wgpu::BindGroupEntry { binding: 0, resource: uniform_buf.as_entire_binding() },
297            wgpu::BindGroupEntry { binding: 1, resource: x_buf.as_entire_binding() },
298            wgpu::BindGroupEntry { binding: 2, resource: y_buf.as_entire_binding() },
299            wgpu::BindGroupEntry { binding: 3, resource: z_buf.as_entire_binding() },
300            wgpu::BindGroupEntry { binding: 4, resource: sorted_buf.as_entire_binding() },
301            wgpu::BindGroupEntry { binding: 5, resource: cell_buf.as_entire_binding() },
302            wgpu::BindGroupEntry { binding: 6, resource: output_buf.as_entire_binding() },
303        ],
304    });
305
306    let mut encoder =
307        device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some("ng") });
308    {
309        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
310            label: Some("ng-pass"),
311            timestamp_writes: None,
312        });
313        pass.set_pipeline(&pipeline);
314        pass.set_bind_group(0, &bind_group, &[]);
315        pass.dispatch_workgroups((point_count as u32).div_ceil(WORKGROUP_SIZE), 1, 1);
316    }
317    queue.submit(Some(encoder.finish()));
318
319    let staging = device.create_buffer(&wgpu::BufferDescriptor {
320        label: Some("ng-staging"),
321        size: output_len,
322        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
323        mapped_at_creation: false,
324    });
325    let mut encoder =
326        device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some("ng-rb") });
327    encoder.copy_buffer_to_buffer(&output_buf, 0, &staging, 0, output_len);
328    queue.submit(Some(encoder.finish()));
329
330    let slice = staging.slice(..);
331    let (sender, receiver) = std::sync::mpsc::channel();
332    slice.map_async(wgpu::MapMode::Read, move |result| {
333        let _ = sender.send(result);
334    });
335    device.poll(wgpu::Maintain::Wait);
336    receiver
337        .recv()
338        .map_err(|_| SpatialError::InvalidArgument("failed to receive wgpu map result".to_owned()))?
339        .map_err(|error| {
340            SpatialError::InvalidArgument(format!("failed to map wgpu buffer: {error}"))
341        })?;
342    let data = slice.get_mapped_range();
343    let raw: &[[f32; 4]] = bytemuck::cast_slice(&data);
344    let normals =
345        raw.iter().map(|v| GpuNormal { normal: [v[0], v[1], v[2]], curvature: v[3] }).collect();
346    drop(data);
347    staging.unmap();
348
349    Ok(normals)
350}
351
352pub(crate) fn grid_bounds(
353    x: &[f32],
354    y: &[f32],
355    z: &[f32],
356    radius: f32,
357) -> SpatialResult<([f32; 3], [u32; 3])> {
358    let mut min = [f32::INFINITY; 3];
359    let mut max = [f32::NEG_INFINITY; 3];
360    for index in 0..x.len() {
361        for (axis, value) in [x[index], y[index], z[index]].into_iter().enumerate() {
362            min[axis] = min[axis].min(value);
363            max[axis] = max[axis].max(value);
364        }
365    }
366    let inv_cell = 1.0 / radius;
367    let mut dims = [0u32; 3];
368    for axis in 0..3 {
369        let span = ((max[axis] - min[axis]) * inv_cell).floor() as i64 + 1;
370        dims[axis] = span.max(1) as u32;
371    }
372    let cells = dims[0] as u64 * dims[1] as u64 * dims[2] as u64;
373    if cells > MAX_CELLS {
374        return Err(SpatialError::InvalidArgument(format!(
375            "grid would need {cells} cells (cap {MAX_CELLS}); use a larger radius or the CPU path"
376        )));
377    }
378    Ok((min, dims))
379}
380
/// Buckets point indices by grid cell with a two-pass counting sort.
///
/// Returns `(sorted, cell_start)` in CSR form. `cell_start` has one entry per
/// cell plus a trailing total, and the points of cell `c` are
/// `sorted[cell_start[c]..cell_start[c + 1]]`, in ascending input order.
///
/// Each coordinate is floored into a cell index and clamped per axis, so
/// points outside the grid land in the nearest edge cell. Cells are linearized
/// with x varying fastest.
pub(crate) fn build_grid(
    x: &[f32],
    y: &[f32],
    z: &[f32],
    origin: [f32; 3],
    dims: [u32; 3],
    radius: f32,
) -> (Vec<u32>, Vec<u32>) {
    let inv_cell = 1.0 / radius;
    let [dim_x, dim_y, dim_z] = dims.map(|d| d as usize);

    // Cell index along one axis, clamped into [0, dim).
    let axis_cell = |value: f32, start: f32, dim: usize| -> usize {
        let raw = ((value - start) * inv_cell).floor() as i64;
        raw.clamp(0, dim as i64 - 1) as usize
    };
    let linear_cell = |point: usize| -> usize {
        let ix = axis_cell(x[point], origin[0], dim_x);
        let iy = axis_cell(y[point], origin[1], dim_y);
        let iz = axis_cell(z[point], origin[2], dim_z);
        (iz * dim_y + iy) * dim_x + ix
    };

    // Pass 1: histogram of points per cell.
    let mut offsets = vec![0u32; dim_x * dim_y * dim_z + 1];
    for point in 0..x.len() {
        offsets[linear_cell(point)] += 1;
    }
    // Exclusive prefix sum: each slot becomes the first index of its cell's run,
    // and the trailing slot ends up holding the total point count.
    let mut running = 0u32;
    for slot in &mut offsets {
        let count = std::mem::replace(slot, running);
        running += count;
    }

    // Pass 2: scatter each point into the next free slot of its cell.
    let mut next_free = offsets.clone();
    let mut sorted = vec![0u32; x.len()];
    for point in 0..x.len() {
        let cell = linear_cell(point);
        sorted[next_free[cell] as usize] = point as u32;
        next_free[cell] += 1;
    }
    (sorted, offsets)
}
427
#[cfg(test)]
mod tests {
    use super::{build_grid, estimate_normals_grid_gpu, grid_bounds};
    use crate::runtime::WgpuRuntime;

    /// End-to-end GPU check; requires a headless wgpu adapter.
    #[test]
    fn planar_patch_has_vertical_normal() {
        let runtime = WgpuRuntime::new_headless().expect("wgpu runtime");
        let mut x: Vec<f32> = Vec::new();
        let mut y: Vec<f32> = Vec::new();
        let mut z: Vec<f32> = Vec::new();
        for i in 0..12 {
            for j in 0..12 {
                x.push(i as f32 * 0.1);
                y.push(j as f32 * 0.1);
                z.push(0.0);
            }
        }
        let normals = estimate_normals_grid_gpu(&runtime, &x, &y, &z, 0.25).expect("grid normals");
        assert_eq!(normals.len(), x.len());
        for normal in &normals {
            assert!(normal.normal[2].abs() > 0.99, "normal not vertical: {:?}", normal.normal);
            assert!(normal.curvature < 1e-3, "curvature too high: {}", normal.curvature);
        }
    }

    /// CPU-only: counting sort yields CSR offsets and clamps out-of-grid points.
    #[test]
    fn build_grid_produces_csr_offsets() {
        let x: [f32; 6] = [0.0, 1.5, 0.2, 2.9, -4.0, 9.0];
        let zeros = [0.0f32; 6];
        let (sorted, cell_start) = build_grid(&x, &zeros, &zeros, [0.0; 3], [3, 1, 1], 1.0);
        assert_eq!(cell_start, vec![0, 3, 4, 6]);
        assert_eq!(sorted, vec![0, 2, 4, 1, 3, 5]);
    }

    /// CPU-only: the origin is the per-axis minimum and dims cover the span.
    #[test]
    fn grid_bounds_reports_origin_and_dims() {
        let x = [0.0f32, 2.5];
        let y = [1.0f32, 1.0];
        let z = [-1.0f32, 0.5];
        let (origin, dims) = grid_bounds(&x, &y, &z, 1.0).expect("small grid");
        assert_eq!(origin, [0.0, 1.0, -1.0]);
        assert_eq!(dims, [3, 1, 2]);
    }

    /// CPU-only: grids above the cell cap are rejected so callers can fall back.
    #[test]
    fn grid_bounds_rejects_oversized_grid() {
        let span = [0.0f32, 1000.0];
        assert!(grid_bounds(&span, &span, &span, 1.0).is_err());
    }
}