1use bytemuck::{Pod, Zeroable};
2use spatialrust_core::{SpatialError, SpatialResult};
3use wgpu::util::DeviceExt;
4
5use crate::runtime::WgpuRuntime;
6
/// Number of points covered by one workgroup when sizing the voxel-key dispatch.
///
/// NOTE(review): presumably this has to equal the `@workgroup_size` declared by the
/// voxel-key compute shader, which is not visible from this file — confirm.
const WORKGROUP_SIZE: u32 = 256;
8
/// Parameter block uploaded as the uniform buffer bound at binding 0 of the voxel-key
/// bind group.
///
/// The struct is `#[repr(C)]` and `Pod` so it can be passed to the GPU as raw bytes via
/// `bytemuck::bytes_of`; its fields add up to 32 bytes.
#[repr(C)]
#[derive(Clone, Copy, Debug, Pod, Zeroable)]
struct VoxelKeyUniform {
    /// Grid origin stored as `[x, y, z, 0.0]`; the fourth lane is always written as zero.
    origin: [f32; 4],
    /// Scale factor for origin-relative coordinates. The CPU reference in this file's test
    /// computes `floor((p - origin) * inv_leaf)`.
    /// NOTE(review): presumably the reciprocal of the voxel leaf size — confirm.
    inv_leaf: f32,
    /// Number of points held by each coordinate buffer.
    point_count: u32,
    // Explicit padding words, always written as zero.
    // NOTE(review): presumably present to satisfy the shader-side uniform layout — confirm.
    _pad0: u32,
    _pad1: u32,
}
18
/// One element of the key buffer written by the voxel-key compute pass.
///
/// The readback path reinterprets the mapped bytes as a slice of this type, so each
/// element occupies four `i32` words (16 bytes).
#[repr(C)]
#[derive(Clone, Copy, Debug, Default, Pod, Zeroable)]
pub(crate) struct VoxelKeyOutput {
    /// Voxel index along the x axis.
    pub(crate) ix: i32,
    /// Voxel index along the y axis.
    pub(crate) iy: i32,
    /// Voxel index along the z axis.
    pub(crate) iz: i32,
    /// Padding word; it is not read when keys are converted on the CPU side.
    pub(crate) _pad: i32,
}
27
28pub struct GpuVoxelKeyBuffers {
30 x: wgpu::Buffer,
31 y: wgpu::Buffer,
32 z: wgpu::Buffer,
33 keys: wgpu::Buffer,
34 point_count: u32,
35}
36
37impl GpuVoxelKeyBuffers {
38 #[must_use]
40 pub fn point_count(&self) -> u32 {
41 self.point_count
42 }
43
44 #[must_use]
46 pub fn x_buffer(&self) -> &wgpu::Buffer {
47 &self.x
48 }
49
50 #[must_use]
52 pub fn y_buffer(&self) -> &wgpu::Buffer {
53 &self.y
54 }
55
56 #[must_use]
58 pub fn z_buffer(&self) -> &wgpu::Buffer {
59 &self.z
60 }
61
62 #[must_use]
64 pub fn keys_buffer(&self) -> &wgpu::Buffer {
65 &self.keys
66 }
67
68 pub fn recycle(self, runtime: &WgpuRuntime) {
70 runtime.recycle_storage(self.x.size(), self.x);
71 runtime.recycle_storage(self.y.size(), self.y);
72 runtime.recycle_storage(self.z.size(), self.z);
73 runtime.recycle_storage(self.keys.size(), self.keys);
74 }
75}
76
77pub fn compute_voxel_keys(
79 runtime: &WgpuRuntime,
80 x: &[f32],
81 y: &[f32],
82 z: &[f32],
83 origin: [f32; 3],
84 inv_leaf: f32,
85) -> SpatialResult<Vec<(i64, i64, i64)>> {
86 if x.len() != y.len() || x.len() != z.len() {
87 return Err(SpatialError::BufferLengthMismatch { expected: x.len(), found: y.len() });
88 }
89 if x.is_empty() {
90 return Ok(Vec::new());
91 }
92
93 let buffers = compute_voxel_keys_gpu_buffers(runtime, x, y, z, origin, inv_leaf)?;
94 read_voxel_keys(runtime, &buffers)
95}
96
97pub fn compute_voxel_keys_gpu_buffers(
99 runtime: &WgpuRuntime,
100 x: &[f32],
101 y: &[f32],
102 z: &[f32],
103 origin: [f32; 3],
104 inv_leaf: f32,
105) -> SpatialResult<GpuVoxelKeyBuffers> {
106 if x.len() != y.len() || x.len() != z.len() {
107 return Err(SpatialError::BufferLengthMismatch { expected: x.len(), found: y.len() });
108 }
109 if x.is_empty() {
110 return Err(SpatialError::InvalidArgument(
111 "cannot compute voxel keys for an empty point cloud".to_owned(),
112 ));
113 }
114
115 let point_count = x.len() as u32;
116
117 let x_buffer = runtime.upload_f32_storage("voxel-key-x", x)?;
118 let y_buffer = runtime.upload_f32_storage("voxel-key-y", y)?;
119 let z_buffer = runtime.upload_f32_storage("voxel-key-z", z)?;
120
121 let keys_buffer = dispatch_voxel_keys(
122 runtime,
123 &x_buffer,
124 &y_buffer,
125 &z_buffer,
126 origin,
127 inv_leaf,
128 point_count,
129 )?;
130
131 Ok(GpuVoxelKeyBuffers { x: x_buffer, y: y_buffer, z: z_buffer, keys: keys_buffer, point_count })
132}
133
134fn read_voxel_keys(
135 runtime: &WgpuRuntime,
136 buffers: &GpuVoxelKeyBuffers,
137) -> SpatialResult<Vec<(i64, i64, i64)>> {
138 let device = runtime.device();
139 let queue = runtime.queue();
140 let point_count = buffers.point_count() as usize;
141 let output_len = point_count * std::mem::size_of::<VoxelKeyOutput>();
142
143 let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
144 label: Some("voxel-key-staging"),
145 size: output_len as u64,
146 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
147 mapped_at_creation: false,
148 });
149
150 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
151 label: Some("voxel-key-readback-encoder"),
152 });
153 encoder.copy_buffer_to_buffer(buffers.keys_buffer(), 0, &staging_buffer, 0, output_len as u64);
154 queue.submit(Some(encoder.finish()));
155
156 let slice = staging_buffer.slice(..);
157 let (sender, receiver) = std::sync::mpsc::channel();
158 slice.map_async(wgpu::MapMode::Read, move |result| {
159 let _ = sender.send(result);
160 });
161 device.poll(wgpu::Maintain::Wait);
162 receiver
163 .recv()
164 .map_err(|_| SpatialError::InvalidArgument("failed to receive wgpu map result".to_owned()))?
165 .map_err(|error| {
166 SpatialError::InvalidArgument(format!("failed to map wgpu buffer: {error}"))
167 })?;
168
169 let data = slice.get_mapped_range();
170 let outputs: &[VoxelKeyOutput] = bytemuck::cast_slice(&data);
171 let keys = outputs
172 .iter()
173 .map(|key| (i64::from(key.ix), i64::from(key.iy), i64::from(key.iz)))
174 .collect();
175 drop(data);
176 staging_buffer.unmap();
177
178 Ok(keys)
179}
180
181fn dispatch_voxel_keys(
182 runtime: &WgpuRuntime,
183 x_buffer: &wgpu::Buffer,
184 y_buffer: &wgpu::Buffer,
185 z_buffer: &wgpu::Buffer,
186 origin: [f32; 3],
187 inv_leaf: f32,
188 point_count: u32,
189) -> SpatialResult<wgpu::Buffer> {
190 let device = runtime.device();
191 let queue = runtime.queue();
192
193 let uniform = VoxelKeyUniform {
194 origin: [origin[0], origin[1], origin[2], 0.0],
195 inv_leaf,
196 point_count,
197 _pad0: 0,
198 _pad1: 0,
199 };
200
201 let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
202 label: Some("voxel-key-uniform"),
203 contents: bytemuck::bytes_of(&uniform),
204 usage: wgpu::BufferUsages::UNIFORM,
205 });
206
207 let output_len = point_count as usize * std::mem::size_of::<VoxelKeyOutput>();
208 let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
209 label: Some("voxel-key-output"),
210 size: output_len as u64,
211 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
212 mapped_at_creation: false,
213 });
214
215 let pipelines = runtime.pipelines();
216
217 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
218 label: Some("voxel-key-bind-group"),
219 layout: &pipelines.voxel_keys.bind_group_layout,
220 entries: &[
221 wgpu::BindGroupEntry { binding: 0, resource: uniform_buffer.as_entire_binding() },
222 wgpu::BindGroupEntry { binding: 1, resource: x_buffer.as_entire_binding() },
223 wgpu::BindGroupEntry { binding: 2, resource: y_buffer.as_entire_binding() },
224 wgpu::BindGroupEntry { binding: 3, resource: z_buffer.as_entire_binding() },
225 wgpu::BindGroupEntry { binding: 4, resource: output_buffer.as_entire_binding() },
226 ],
227 });
228
229 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
230 label: Some("voxel-key-encoder"),
231 });
232
233 {
234 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
235 label: Some("voxel-key-pass"),
236 timestamp_writes: None,
237 });
238 pass.set_pipeline(&pipelines.voxel_keys.pipeline);
239 pass.set_bind_group(0, &bind_group, &[]);
240 let workgroups = point_count.div_ceil(WORKGROUP_SIZE);
241 pass.dispatch_workgroups(workgroups, 1, 1);
242 }
243
244 queue.submit(Some(encoder.finish()));
245
246 Ok(output_buffer)
247}
248
#[cfg(test)]
mod tests {
    use super::compute_voxel_keys;
    use crate::runtime::WgpuRuntime;

    /// CPU reference for one axis: `floor((value - origin) * inv_leaf)` as `i64`.
    fn reference_index(value: f32, origin: f32, inv_leaf: f32) -> i64 {
        ((value - origin) * inv_leaf).floor() as i64
    }

    /// GPU keys for a small point set must equal the keys computed on the CPU.
    #[test]
    fn gpu_voxel_keys_match_cpu_reference() {
        let runtime = WgpuRuntime::new_headless().expect("wgpu runtime");
        let xs = [0.0_f32, 0.1, 1.0, 1.1];
        let ys = [0.0_f32, 0.0, 0.0, 0.0];
        let zs = [0.0_f32, 0.0, 0.0, 0.0];
        let origin = [0.0_f32, 0.0, 0.0];
        let inv_leaf = 2.0_f32;

        let actual =
            compute_voxel_keys(&runtime, &xs, &ys, &zs, origin, inv_leaf).expect("gpu keys");

        let expected: Vec<(i64, i64, i64)> = (0..xs.len())
            .map(|index| {
                (
                    reference_index(xs[index], origin[0], inv_leaf),
                    reference_index(ys[index], origin[1], inv_leaf),
                    reference_index(zs[index], origin[2], inv_leaf),
                )
            })
            .collect();

        assert_eq!(actual, expected);
    }
}