Skip to main content

spatialrust_gpu/kernels/
covariances_grid.rs

1use bytemuck::{Pod, Zeroable};
2use spatialrust_core::{SpatialError, SpatialResult};
3use wgpu::util::DeviceExt;
4
5use crate::kernels::normals_grid::{build_grid, grid_bounds};
6use crate::runtime::WgpuRuntime;
7
/// Number of shader invocations per workgroup along X.
///
/// The host uses this to compute how many workgroups to dispatch; it has to
/// equal the literal in the shader's `@workgroup_size(256)` attribute.
const WORKGROUP_SIZE: u32 = 256;
9
10#[repr(C)]
11#[derive(Clone, Copy, Debug, Pod, Zeroable)]
12struct CovUniform {
13    origin: [f32; 4],
14    dims: [u32; 4], // dimx, dimy, dimz, point_count
15    inv_cell: f32,
16    radius_sq: f32,
17    epsilon: f32,
18    _pad: f32,
19}
20
/// The six independent entries of one point's symmetric, plane-regularized
/// 3x3 covariance, stored in the order `[c00, c11, c22, c01, c02, c12]`
/// (diagonal first, then the upper off-diagonal terms).
pub type GpuCovariance = [f32; 6];
24
/// WGSL compute shader that produces one plane-regularized covariance per point.
///
/// Bindings (group 0):
/// * 0 — `Params` uniform (grid origin, grid dims + point count, `inv_cell`,
///   `radius_sq`, `epsilon`).
/// * 1..=3 — point coordinates as separate `x`, `y`, `z` arrays.
/// * 4 — point indices sorted by grid cell.
/// * 5 — per-cell start offsets into the sorted array; entry `cid + 1` is read
///   as the end of cell `cid`.
/// * 6 — output, `2 * point_count` `vec4` rows: row `i` holds
///   `(c00, c11, c22, c01)` and row `point_count + i` holds `(c02, c12, 0, 0)`.
///
/// Each invocation scans the 3x3x3 block of cells around its point, gathers the
/// neighbors within the radius, builds their scatter matrix, diagonalizes it
/// with cyclic Jacobi rotations, and rebuilds the matrix with the smallest
/// eigenvalue replaced by `epsilon` and the other two by `1`. Points with fewer
/// than three neighbors get `epsilon` on the diagonal and zeros elsewhere.
const COV_GRID_WGSL: &str = r#"
// Must stay layout-compatible with the host-side uniform struct.
struct Params {
    origin: vec4<f32>,
    dims: vec4<u32>, // dimx, dimy, dimz, point_count
    inv_cell: f32,
    radius_sq: f32,
    epsilon: f32,
    pad: f32,
};

@group(0) @binding(0) var<uniform> params: Params;
@group(0) @binding(1) var<storage, read> xs: array<f32>;
@group(0) @binding(2) var<storage, read> ys: array<f32>;
@group(0) @binding(3) var<storage, read> zs: array<f32>;
@group(0) @binding(4) var<storage, read> sorted: array<u32>;
@group(0) @binding(5) var<storage, read> cell_start: array<u32>;
@group(0) @binding(6) var<storage, read_write> out_cov: array<vec4<f32>>;

// One Jacobi rotation in the (p, q) plane: zeroes a[p][q] and accumulates the
// rotation into the columns of v.
fn rotate(a: ptr<function, array<vec3<f32>, 3>>,
          v: ptr<function, array<vec3<f32>, 3>>,
          p: u32, q: u32) {
    let apq = (*a)[p][q];
    if (abs(apq) < 1e-20) {
        return;
    }
    let app = (*a)[p][p];
    let aqq = (*a)[q][q];
    let phi = 0.5 * (aqq - app) / apq;
    // Smaller-magnitude root of t^2 + 2*phi*t - 1 = 0.
    var t: f32;
    if (phi >= 0.0) {
        t = 1.0 / (phi + sqrt(1.0 + phi * phi));
    } else {
        t = -1.0 / (-phi + sqrt(1.0 + phi * phi));
    }
    let c = 1.0 / sqrt(1.0 + t * t);
    let s = t * c;
    // a <- a * J (columns p and q).
    for (var r: u32 = 0u; r < 3u; r = r + 1u) {
        let arp = (*a)[r][p];
        let arq = (*a)[r][q];
        (*a)[r][p] = c * arp - s * arq;
        (*a)[r][q] = s * arp + c * arq;
    }
    // a <- J^T * a (rows p and q).
    for (var r: u32 = 0u; r < 3u; r = r + 1u) {
        let apr = (*a)[p][r];
        let aqr = (*a)[q][r];
        (*a)[p][r] = c * apr - s * aqr;
        (*a)[q][r] = s * apr + c * aqr;
    }
    // v <- v * J, so the columns of v converge to the eigenvectors.
    for (var r: u32 = 0u; r < 3u; r = r + 1u) {
        let vrp = (*v)[r][p];
        let vrq = (*v)[r][q];
        (*v)[r][p] = c * vrp - s * vrq;
        (*v)[r][q] = s * vrp + c * vrq;
    }
}

// Grid cell index along one axis, clamped into [0, dim - 1].
fn cell_coord(value: f32, origin: f32, inv_cell: f32, dim: u32) -> i32 {
    let c = i32(floor((value - origin) * inv_cell));
    return clamp(c, 0, i32(dim) - 1);
}

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let i = gid.x;
    // The last workgroup may be only partially filled.
    if (i >= params.dims.w) {
        return;
    }
    let px = xs[i];
    let py = ys[i];
    let pz = zs[i];
    let dimx = params.dims.x;
    let dimy = params.dims.y;
    let dimz = params.dims.z;
    let cx = cell_coord(px, params.origin.x, params.inv_cell, dimx);
    let cy = cell_coord(py, params.origin.y, params.inv_cell, dimy);
    let cz = cell_coord(pz, params.origin.z, params.inv_cell, dimz);

    // Pass 1: neighbor count and centroid. The centroid is accumulated as an
    // offset from the query point so the sums stay small in magnitude and keep
    // f32 precision even when the cloud sits far from the coordinate origin.
    var mean = vec3<f32>(0.0, 0.0, 0.0);
    var count = 0.0;
    for (var dz = -1; dz <= 1; dz = dz + 1) {
        let nz = cz + dz;
        if (nz < 0 || nz >= i32(dimz)) { continue; }
        for (var dy = -1; dy <= 1; dy = dy + 1) {
            let ny = cy + dy;
            if (ny < 0 || ny >= i32(dimy)) { continue; }
            for (var dx = -1; dx <= 1; dx = dx + 1) {
                let nx = cx + dx;
                if (nx < 0 || nx >= i32(dimx)) { continue; }
                let cid = (u32(nz) * dimy + u32(ny)) * dimx + u32(nx);
                let range_start = cell_start[cid];
                let range_end = cell_start[cid + 1u];
                for (var s = range_start; s < range_end; s = s + 1u) {
                    let j = sorted[s];
                    let d = vec3<f32>(xs[j] - px, ys[j] - py, zs[j] - pz);
                    if (dot(d, d) <= params.radius_sq) {
                        mean = mean + d;
                        count = count + 1.0;
                    }
                }
            }
        }
    }

    // Too few neighbors: emit an isotropic epsilon-scaled covariance. Both
    // rows are written so the result never depends on the buffer's prior
    // contents.
    if (count < 3.0) {
        out_cov[i] = vec4<f32>(params.epsilon, params.epsilon, params.epsilon, 0.0);
        out_cov[params.dims.w + i] = vec4<f32>(0.0, 0.0, 0.0, 0.0);
        return;
    }
    mean = mean / count;

    // Pass 2: scatter matrix about the centroid, over the same neighbor set.
    var c00 = 0.0; var c11 = 0.0; var c22 = 0.0;
    var c01 = 0.0; var c02 = 0.0; var c12 = 0.0;
    for (var dz = -1; dz <= 1; dz = dz + 1) {
        let nz = cz + dz;
        if (nz < 0 || nz >= i32(dimz)) { continue; }
        for (var dy = -1; dy <= 1; dy = dy + 1) {
            let ny = cy + dy;
            if (ny < 0 || ny >= i32(dimy)) { continue; }
            for (var dx = -1; dx <= 1; dx = dx + 1) {
                let nx = cx + dx;
                if (nx < 0 || nx >= i32(dimx)) { continue; }
                let cid = (u32(nz) * dimy + u32(ny)) * dimx + u32(nx);
                let range_start = cell_start[cid];
                let range_end = cell_start[cid + 1u];
                for (var s = range_start; s < range_end; s = s + 1u) {
                    let j = sorted[s];
                    let rel = vec3<f32>(xs[j] - px, ys[j] - py, zs[j] - pz);
                    if (dot(rel, rel) <= params.radius_sq) {
                        // Same value as (neighbor - centroid), expressed with
                        // query-relative offsets.
                        let dd = rel - mean;
                        c00 = c00 + dd.x * dd.x;
                        c11 = c11 + dd.y * dd.y;
                        c22 = c22 + dd.z * dd.z;
                        c01 = c01 + dd.x * dd.y;
                        c02 = c02 + dd.x * dd.z;
                        c12 = c12 + dd.y * dd.z;
                    }
                }
            }
        }
    }

    var a = array<vec3<f32>, 3>(
        vec3<f32>(c00, c01, c02),
        vec3<f32>(c01, c11, c12),
        vec3<f32>(c02, c12, c22),
    );
    var v = array<vec3<f32>, 3>(
        vec3<f32>(1.0, 0.0, 0.0),
        vec3<f32>(0.0, 1.0, 0.0),
        vec3<f32>(0.0, 0.0, 1.0),
    );
    // Fixed number of cyclic Jacobi sweeps over the three off-diagonal pairs.
    for (var sweep: u32 = 0u; sweep < 16u; sweep = sweep + 1u) {
        rotate(&a, &v, 0u, 1u);
        rotate(&a, &v, 0u, 2u);
        rotate(&a, &v, 1u, 2u);
    }

    // GICP plane regularization: rebuild covariance with eigenvalues (eps, 1, 1),
    // smallest eigenvalue (surface normal) -> eps.
    let eig = vec3<f32>(a[0][0], a[1][1], a[2][2]);
    var min_idx = 0u;
    if (eig[1] < eig[min_idx]) { min_idx = 1u; }
    if (eig[2] < eig[min_idx]) { min_idx = 2u; }

    // reg = sum over eigenvectors of lambda * (axis * axis^T), upper triangle only.
    var reg = array<f32, 6>(0.0, 0.0, 0.0, 0.0, 0.0, 0.0);
    for (var col = 0u; col < 3u; col = col + 1u) {
        var lambda = 1.0;
        if (col == min_idx) {
            lambda = params.epsilon;
        }
        let ax = v[0][col];
        let ay = v[1][col];
        let az = v[2][col];
        reg[0] = reg[0] + lambda * ax * ax;
        reg[1] = reg[1] + lambda * ay * ay;
        reg[2] = reg[2] + lambda * az * az;
        reg[3] = reg[3] + lambda * ax * ay;
        reg[4] = reg[4] + lambda * ax * az;
        reg[5] = reg[5] + lambda * ay * az;
    }

    out_cov[i] = vec4<f32>(reg[0], reg[1], reg[2], reg[3]);
    out_cov[params.dims.w + i] = vec4<f32>(reg[4], reg[5], 0.0, 0.0);
}
"#;
208
209/// Estimates per-point plane-regularized covariances on the GPU via a uniform
210/// grid radius neighbor search.
211///
212/// Returns one [`GpuCovariance`] per point — the unique elements of a covariance
213/// matrix whose eigenvalues have been set to `(epsilon, 1, 1)` (the GICP
214/// plane-to-plane model). Grid construction (counting sort) runs on the CPU; the
215/// neighbor gather, covariance, and eigen-decomposition run on the GPU.
216pub fn estimate_plane_covariances_grid_gpu(
217    runtime: &WgpuRuntime,
218    x: &[f32],
219    y: &[f32],
220    z: &[f32],
221    radius: f32,
222    epsilon: f32,
223) -> SpatialResult<Vec<GpuCovariance>> {
224    let point_count = x.len();
225    if y.len() != point_count || z.len() != point_count {
226        return Err(SpatialError::BufferLengthMismatch { expected: point_count, found: y.len() });
227    }
228    if point_count == 0 {
229        return Ok(Vec::new());
230    }
231    if radius <= 0.0 || radius.is_nan() {
232        return Err(SpatialError::InvalidArgument("grid radius must be positive".to_owned()));
233    }
234
235    let (origin, dims) = grid_bounds(x, y, z, radius)?;
236    let (sorted, cell_start) = build_grid(x, y, z, origin, dims, radius);
237
238    let device = runtime.device();
239    let queue = runtime.queue();
240    let storage = wgpu::BufferUsages::STORAGE;
241
242    let x_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
243        label: Some("cg-x"),
244        contents: bytemuck::cast_slice(x),
245        usage: storage,
246    });
247    let y_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
248        label: Some("cg-y"),
249        contents: bytemuck::cast_slice(y),
250        usage: storage,
251    });
252    let z_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
253        label: Some("cg-z"),
254        contents: bytemuck::cast_slice(z),
255        usage: storage,
256    });
257    let sorted_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
258        label: Some("cg-sorted"),
259        contents: bytemuck::cast_slice(&sorted),
260        usage: storage,
261    });
262    let cell_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
263        label: Some("cg-cell-start"),
264        contents: bytemuck::cast_slice(&cell_start),
265        usage: storage,
266    });
267    let uniform = CovUniform {
268        origin: [origin[0], origin[1], origin[2], 0.0],
269        dims: [dims[0], dims[1], dims[2], point_count as u32],
270        inv_cell: 1.0 / radius,
271        radius_sq: radius * radius,
272        epsilon,
273        _pad: 0.0,
274    };
275    let uniform_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
276        label: Some("cg-uniform"),
277        contents: bytemuck::bytes_of(&uniform),
278        usage: wgpu::BufferUsages::UNIFORM,
279    });
280
281    // Two vec4 rows per point: row0 = (c00,c11,c22,c01), row1 = (c02,c12,_,_).
282    let output_len = (2 * point_count * std::mem::size_of::<[f32; 4]>()) as u64;
283    let output_buf = device.create_buffer(&wgpu::BufferDescriptor {
284        label: Some("cg-output"),
285        size: output_len,
286        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
287        mapped_at_creation: false,
288    });
289
290    let module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
291        label: Some("cg-shader"),
292        source: wgpu::ShaderSource::Wgsl(COV_GRID_WGSL.into()),
293    });
294    let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
295        label: Some("cg-pipeline"),
296        layout: None,
297        module: &module,
298        entry_point: Some("main"),
299        compilation_options: wgpu::PipelineCompilationOptions::default(),
300        cache: None,
301    });
302    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
303        label: Some("cg-bind-group"),
304        layout: &pipeline.get_bind_group_layout(0),
305        entries: &[
306            wgpu::BindGroupEntry { binding: 0, resource: uniform_buf.as_entire_binding() },
307            wgpu::BindGroupEntry { binding: 1, resource: x_buf.as_entire_binding() },
308            wgpu::BindGroupEntry { binding: 2, resource: y_buf.as_entire_binding() },
309            wgpu::BindGroupEntry { binding: 3, resource: z_buf.as_entire_binding() },
310            wgpu::BindGroupEntry { binding: 4, resource: sorted_buf.as_entire_binding() },
311            wgpu::BindGroupEntry { binding: 5, resource: cell_buf.as_entire_binding() },
312            wgpu::BindGroupEntry { binding: 6, resource: output_buf.as_entire_binding() },
313        ],
314    });
315
316    let mut encoder =
317        device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some("cg") });
318    {
319        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
320            label: Some("cg-pass"),
321            timestamp_writes: None,
322        });
323        pass.set_pipeline(&pipeline);
324        pass.set_bind_group(0, &bind_group, &[]);
325        pass.dispatch_workgroups((point_count as u32).div_ceil(WORKGROUP_SIZE), 1, 1);
326    }
327    queue.submit(Some(encoder.finish()));
328
329    let staging = device.create_buffer(&wgpu::BufferDescriptor {
330        label: Some("cg-staging"),
331        size: output_len,
332        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
333        mapped_at_creation: false,
334    });
335    let mut encoder =
336        device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: Some("cg-rb") });
337    encoder.copy_buffer_to_buffer(&output_buf, 0, &staging, 0, output_len);
338    queue.submit(Some(encoder.finish()));
339
340    let slice = staging.slice(..);
341    let (sender, receiver) = std::sync::mpsc::channel();
342    slice.map_async(wgpu::MapMode::Read, move |result| {
343        let _ = sender.send(result);
344    });
345    device.poll(wgpu::Maintain::Wait);
346    receiver
347        .recv()
348        .map_err(|_| SpatialError::InvalidArgument("failed to receive wgpu map result".to_owned()))?
349        .map_err(|error| {
350            SpatialError::InvalidArgument(format!("failed to map wgpu buffer: {error}"))
351        })?;
352    let data = slice.get_mapped_range();
353    let rows: &[[f32; 4]] = bytemuck::cast_slice(&data);
354    let mut out = Vec::with_capacity(point_count);
355    for i in 0..point_count {
356        let row0 = rows[i];
357        let row1 = rows[point_count + i];
358        out.push([row0[0], row0[1], row0[2], row0[3], row1[0], row1[1]]);
359    }
360    drop(data);
361    staging.unmap();
362
363    Ok(out)
364}
365
#[cfg(test)]
mod tests {
    use super::estimate_plane_covariances_grid_gpu;
    use crate::runtime::WgpuRuntime;

    /// A flat 12x12 lattice (0.1 spacing) in the z = 0 plane should yield
    /// covariances shaped like a disk: ~1 in-plane, ~eps along z.
    #[test]
    fn planar_patch_covariance_is_disk() {
        const SIDE: usize = 12;
        const TOTAL: usize = SIDE * SIDE;

        let runtime = WgpuRuntime::new_headless().expect("wgpu runtime");

        // Row-major lattice: x varies with the slow index, y with the fast one.
        let x: Vec<f32> = (0..TOTAL).map(|k| (k / SIDE) as f32 * 0.1).collect();
        let y: Vec<f32> = (0..TOTAL).map(|k| (k % SIDE) as f32 * 0.1).collect();
        let z = vec![0.0_f32; TOTAL];

        let eps = 1e-3_f32;
        let cov = estimate_plane_covariances_grid_gpu(&runtime, &x, &y, &z, 0.25, eps)
            .expect("gpu covariances");
        assert_eq!(cov.len(), x.len());

        // Expected shape is ~diag(1, 1, eps); only the diagonal is checked.
        for entry in &cov {
            let (c00, c11, c22) = (entry[0], entry[1], entry[2]);
            assert!((c22 - eps).abs() < 1e-2, "c22 not ~eps: {c22}");
            assert!(c00 > 0.5 && c11 > 0.5, "in-plane variance too small: {c00},{c11}");
        }
    }
}