//! Voxel-key compute kernel (`spatialrust_gpu/kernels/voxel_keys.rs`): computes per-point
//! voxel grid keys on the GPU.

use bytemuck::{Pod, Zeroable};
use spatialrust_core::{SpatialError, SpatialResult};
use wgpu::util::DeviceExt;

use crate::runtime::WgpuRuntime;

/// Invocations per compute workgroup, used to size the voxel-key dispatch.
///
/// NOTE(review): presumably this must equal the `@workgroup_size` declared by the
/// voxel-key WGSL shader (not visible here) — confirm before changing either side.
const WORKGROUP_SIZE: u32 = 1 << 8;

9#[repr(C)]
10#[derive(Clone, Copy, Debug, Pod, Zeroable)]
11struct VoxelKeyUniform {
12    origin: [f32; 4],
13    inv_leaf: f32,
14    point_count: u32,
15    _pad0: u32,
16    _pad1: u32,
17}
18
19#[repr(C)]
20#[derive(Clone, Copy, Debug, Default, Pod, Zeroable)]
21pub(crate) struct VoxelKeyOutput {
22    pub(crate) ix: i32,
23    pub(crate) iy: i32,
24    pub(crate) iz: i32,
25    pub(crate) _pad: i32,
26}
27
28/// GPU buffers for per-point positions and computed voxel keys.
29pub struct GpuVoxelKeyBuffers {
30    x: wgpu::Buffer,
31    y: wgpu::Buffer,
32    z: wgpu::Buffer,
33    keys: wgpu::Buffer,
34    point_count: u32,
35}
36
37impl GpuVoxelKeyBuffers {
38    /// Returns the number of source points.
39    #[must_use]
40    pub fn point_count(&self) -> u32 {
41        self.point_count
42    }
43
44    /// Returns the GPU buffer of x coordinates.
45    #[must_use]
46    pub fn x_buffer(&self) -> &wgpu::Buffer {
47        &self.x
48    }
49
50    /// Returns the GPU buffer of y coordinates.
51    #[must_use]
52    pub fn y_buffer(&self) -> &wgpu::Buffer {
53        &self.y
54    }
55
56    /// Returns the GPU buffer of z coordinates.
57    #[must_use]
58    pub fn z_buffer(&self) -> &wgpu::Buffer {
59        &self.z
60    }
61
62    /// Returns the GPU buffer of computed voxel keys.
63    #[must_use]
64    pub fn keys_buffer(&self) -> &wgpu::Buffer {
65        &self.keys
66    }
67
68    /// Returns position/key GPU buffers to the runtime upload pool.
69    pub fn recycle(self, runtime: &WgpuRuntime) {
70        runtime.recycle_storage(self.x.size(), self.x);
71        runtime.recycle_storage(self.y.size(), self.y);
72        runtime.recycle_storage(self.z.size(), self.z);
73        runtime.recycle_storage(self.keys.size(), self.keys);
74    }
75}
76
77/// Computes per-point voxel grid keys on the GPU.
78pub fn compute_voxel_keys(
79    runtime: &WgpuRuntime,
80    x: &[f32],
81    y: &[f32],
82    z: &[f32],
83    origin: [f32; 3],
84    inv_leaf: f32,
85) -> SpatialResult<Vec<(i64, i64, i64)>> {
86    if x.len() != y.len() || x.len() != z.len() {
87        return Err(SpatialError::BufferLengthMismatch { expected: x.len(), found: y.len() });
88    }
89    if x.is_empty() {
90        return Ok(Vec::new());
91    }
92
93    let buffers = compute_voxel_keys_gpu_buffers(runtime, x, y, z, origin, inv_leaf)?;
94    read_voxel_keys(runtime, &buffers)
95}
96
97/// Uploads positions and computes per-point voxel keys, keeping data on the GPU.
98pub fn compute_voxel_keys_gpu_buffers(
99    runtime: &WgpuRuntime,
100    x: &[f32],
101    y: &[f32],
102    z: &[f32],
103    origin: [f32; 3],
104    inv_leaf: f32,
105) -> SpatialResult<GpuVoxelKeyBuffers> {
106    if x.len() != y.len() || x.len() != z.len() {
107        return Err(SpatialError::BufferLengthMismatch { expected: x.len(), found: y.len() });
108    }
109    if x.is_empty() {
110        return Err(SpatialError::InvalidArgument(
111            "cannot compute voxel keys for an empty point cloud".to_owned(),
112        ));
113    }
114
115    let point_count = x.len() as u32;
116
117    let x_buffer = runtime.upload_f32_storage("voxel-key-x", x)?;
118    let y_buffer = runtime.upload_f32_storage("voxel-key-y", y)?;
119    let z_buffer = runtime.upload_f32_storage("voxel-key-z", z)?;
120
121    let keys_buffer = dispatch_voxel_keys(
122        runtime,
123        &x_buffer,
124        &y_buffer,
125        &z_buffer,
126        origin,
127        inv_leaf,
128        point_count,
129    )?;
130
131    Ok(GpuVoxelKeyBuffers { x: x_buffer, y: y_buffer, z: z_buffer, keys: keys_buffer, point_count })
132}
133
134fn read_voxel_keys(
135    runtime: &WgpuRuntime,
136    buffers: &GpuVoxelKeyBuffers,
137) -> SpatialResult<Vec<(i64, i64, i64)>> {
138    let device = runtime.device();
139    let queue = runtime.queue();
140    let point_count = buffers.point_count() as usize;
141    let output_len = point_count * std::mem::size_of::<VoxelKeyOutput>();
142
143    let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
144        label: Some("voxel-key-staging"),
145        size: output_len as u64,
146        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
147        mapped_at_creation: false,
148    });
149
150    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
151        label: Some("voxel-key-readback-encoder"),
152    });
153    encoder.copy_buffer_to_buffer(buffers.keys_buffer(), 0, &staging_buffer, 0, output_len as u64);
154    queue.submit(Some(encoder.finish()));
155
156    let slice = staging_buffer.slice(..);
157    let (sender, receiver) = std::sync::mpsc::channel();
158    slice.map_async(wgpu::MapMode::Read, move |result| {
159        let _ = sender.send(result);
160    });
161    device.poll(wgpu::Maintain::Wait);
162    receiver
163        .recv()
164        .map_err(|_| SpatialError::InvalidArgument("failed to receive wgpu map result".to_owned()))?
165        .map_err(|error| {
166            SpatialError::InvalidArgument(format!("failed to map wgpu buffer: {error}"))
167        })?;
168
169    let data = slice.get_mapped_range();
170    let outputs: &[VoxelKeyOutput] = bytemuck::cast_slice(&data);
171    let keys = outputs
172        .iter()
173        .map(|key| (i64::from(key.ix), i64::from(key.iy), i64::from(key.iz)))
174        .collect();
175    drop(data);
176    staging_buffer.unmap();
177
178    Ok(keys)
179}
180
181fn dispatch_voxel_keys(
182    runtime: &WgpuRuntime,
183    x_buffer: &wgpu::Buffer,
184    y_buffer: &wgpu::Buffer,
185    z_buffer: &wgpu::Buffer,
186    origin: [f32; 3],
187    inv_leaf: f32,
188    point_count: u32,
189) -> SpatialResult<wgpu::Buffer> {
190    let device = runtime.device();
191    let queue = runtime.queue();
192
193    let uniform = VoxelKeyUniform {
194        origin: [origin[0], origin[1], origin[2], 0.0],
195        inv_leaf,
196        point_count,
197        _pad0: 0,
198        _pad1: 0,
199    };
200
201    let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
202        label: Some("voxel-key-uniform"),
203        contents: bytemuck::bytes_of(&uniform),
204        usage: wgpu::BufferUsages::UNIFORM,
205    });
206
207    let output_len = point_count as usize * std::mem::size_of::<VoxelKeyOutput>();
208    let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
209        label: Some("voxel-key-output"),
210        size: output_len as u64,
211        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
212        mapped_at_creation: false,
213    });
214
215    let pipelines = runtime.pipelines();
216
217    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
218        label: Some("voxel-key-bind-group"),
219        layout: &pipelines.voxel_keys.bind_group_layout,
220        entries: &[
221            wgpu::BindGroupEntry { binding: 0, resource: uniform_buffer.as_entire_binding() },
222            wgpu::BindGroupEntry { binding: 1, resource: x_buffer.as_entire_binding() },
223            wgpu::BindGroupEntry { binding: 2, resource: y_buffer.as_entire_binding() },
224            wgpu::BindGroupEntry { binding: 3, resource: z_buffer.as_entire_binding() },
225            wgpu::BindGroupEntry { binding: 4, resource: output_buffer.as_entire_binding() },
226        ],
227    });
228
229    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
230        label: Some("voxel-key-encoder"),
231    });
232
233    {
234        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
235            label: Some("voxel-key-pass"),
236            timestamp_writes: None,
237        });
238        pass.set_pipeline(&pipelines.voxel_keys.pipeline);
239        pass.set_bind_group(0, &bind_group, &[]);
240        let workgroups = point_count.div_ceil(WORKGROUP_SIZE);
241        pass.dispatch_workgroups(workgroups, 1, 1);
242    }
243
244    queue.submit(Some(encoder.finish()));
245
246    Ok(output_buffer)
247}
248
#[cfg(test)]
mod tests {
    use super::compute_voxel_keys;
    use crate::runtime::WgpuRuntime;

    /// CPU reference key: floor of the scaled offset from `origin`, per axis.
    fn reference_key(point: [f32; 3], origin: [f32; 3], inv_leaf: f32) -> (i64, i64, i64) {
        let axis = |i: usize| ((point[i] - origin[i]) * inv_leaf).floor() as i64;
        (axis(0), axis(1), axis(2))
    }

    #[test]
    fn gpu_voxel_keys_match_cpu_reference() {
        // NOTE(review): presumably needs a usable wgpu adapter; `expect` panics otherwise.
        let runtime = WgpuRuntime::new_headless().expect("wgpu runtime");
        let origin = [0.0_f32; 3];
        let inv_leaf = 2.0_f32;
        // With a 0.5-unit leaf these x values land in voxels 0, 0, 2, 2.
        let xs = [0.0_f32, 0.1, 1.0, 1.1];
        let ys = [0.0_f32; 4];
        let zs = [0.0_f32; 4];

        let gpu_keys =
            compute_voxel_keys(&runtime, &xs, &ys, &zs, origin, inv_leaf).expect("gpu keys");

        let expected: Vec<(i64, i64, i64)> = (0..xs.len())
            .map(|i| reference_key([xs[i], ys[i], zs[i]], origin, inv_leaf))
            .collect();

        assert_eq!(gpu_keys, expected);
    }
}