1use bytemuck::{Pod, Zeroable};
2use spatialrust_core::SpatialResult;
3use wgpu::util::DeviceExt;
4
5use crate::kernels::gpu_segments::GpuVoxelSegments;
6use crate::kernels::voxel_compact::compact_voxel_segments_gpu_buffers;
7use crate::kernels::voxel_keys::VoxelKeyOutput;
8use crate::kernels::voxel_segments::VoxelSegments;
9use crate::runtime::WgpuRuntime;
10
/// Invocations per workgroup assumed by every dispatch in this file; dispatch
/// sizes are computed as `count.div_ceil(WORKGROUP_SIZE)`.
///
/// NOTE(review): presumably must equal the `@workgroup_size` declared in the
/// voxel sort / filter WGSL shaders — confirm before changing.
const WORKGROUP_SIZE: u32 = 256;
12
13#[repr(C)]
14#[derive(Clone, Copy, Debug, Pod, Zeroable)]
15struct SortParams {
16 padded_count: u32,
17 pair_distance: u32,
18 block_width: u32,
19 _pad: u32,
20}
21
22#[repr(C)]
23#[derive(Clone, Copy, Debug, Default, Pod, Zeroable)]
24pub(crate) struct VoxelSortEntry {
25 ix: i32,
26 iy: i32,
27 iz: i32,
28 point_index: u32,
29}
30
31#[repr(C)]
32#[derive(Clone, Copy, Debug, Pod, Zeroable)]
33struct BuildEntriesParams {
34 point_count: u32,
35 padded_count: u32,
36 _pad0: u32,
37 _pad1: u32,
38}
39
40pub fn build_voxel_segments_gpu(
42 runtime: &WgpuRuntime,
43 keys: &[(i64, i64, i64)],
44) -> SpatialResult<VoxelSegments> {
45 let gpu_segments = build_voxel_segments_gpu_from_keys(runtime, keys)?;
46 gpu_segments.to_voxel_segments(runtime)
47}
48
49pub fn build_voxel_segments_gpu_from_keys(
51 runtime: &WgpuRuntime,
52 keys: &[(i64, i64, i64)],
53) -> SpatialResult<GpuVoxelSegments> {
54 if keys.is_empty() {
55 return empty_gpu_segments(runtime);
56 }
57
58 let point_count = keys.len();
59 let padded_count = point_count.next_power_of_two();
60 let mut key_outputs = vec![VoxelKeyOutput::default(); point_count];
61 for (index, (ix, iy, iz)) in keys.iter().copied().enumerate() {
62 key_outputs[index] = VoxelKeyOutput {
63 ix: ix.clamp(i32::MIN as i64, i32::MAX as i64) as i32,
64 iy: iy.clamp(i32::MIN as i64, i32::MAX as i64) as i32,
65 iz: iz.clamp(i32::MIN as i64, i32::MAX as i64) as i32,
66 _pad: 0,
67 };
68 }
69
70 let device = runtime.device();
71 let keys_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
72 label: Some("voxel-sort-keys-input"),
73 contents: bytemuck::cast_slice(&key_outputs),
74 usage: wgpu::BufferUsages::STORAGE,
75 });
76
77 build_voxel_segments_gpu_from_keys_buffer(
78 runtime,
79 &keys_buffer,
80 point_count as u32,
81 padded_count as u32,
82 )
83}
84
85pub fn build_voxel_segments_gpu_from_keys_buffer(
87 runtime: &WgpuRuntime,
88 keys_buffer: &wgpu::Buffer,
89 point_count: u32,
90 padded_count: u32,
91) -> SpatialResult<GpuVoxelSegments> {
92 if point_count == 0 {
93 return empty_gpu_segments(runtime);
94 }
95
96 let entries_buffer =
97 build_sort_entries_from_keys_gpu(runtime, keys_buffer, point_count, padded_count)?;
98 let sorted_buffer = sort_entries_gpu(runtime, entries_buffer, padded_count)?;
99 let compact_entries =
100 filter_valid_sorted_entries(runtime, &sorted_buffer, padded_count, point_count)?;
101 compact_voxel_segments_gpu_buffers(runtime, &compact_entries, point_count)
102}
103
104fn build_sort_entries_from_keys_gpu(
105 runtime: &WgpuRuntime,
106 keys_buffer: &wgpu::Buffer,
107 point_count: u32,
108 padded_count: u32,
109) -> SpatialResult<wgpu::Buffer> {
110 let device = runtime.device();
111 let queue = runtime.queue();
112
113 let entries_buffer = device.create_buffer(&wgpu::BufferDescriptor {
114 label: Some("voxel-sort-entries-build"),
115 size: (padded_count as usize * std::mem::size_of::<VoxelSortEntry>()) as u64,
116 usage: wgpu::BufferUsages::STORAGE
117 | wgpu::BufferUsages::COPY_DST
118 | wgpu::BufferUsages::COPY_SRC,
119 mapped_at_creation: false,
120 });
121
122 let params = BuildEntriesParams { point_count, padded_count, _pad0: 0, _pad1: 0 };
123 let params_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
124 label: Some("voxel-sort-build-params"),
125 contents: bytemuck::bytes_of(¶ms),
126 usage: wgpu::BufferUsages::UNIFORM,
127 });
128
129 let pipelines = runtime.pipelines();
130
131 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
132 label: Some("voxel-sort-build-bind-group"),
133 layout: &pipelines.voxel_sort_build.bind_group_layout,
134 entries: &[
135 wgpu::BindGroupEntry { binding: 0, resource: params_buffer.as_entire_binding() },
136 wgpu::BindGroupEntry { binding: 1, resource: keys_buffer.as_entire_binding() },
137 wgpu::BindGroupEntry { binding: 2, resource: entries_buffer.as_entire_binding() },
138 ],
139 });
140
141 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
142 label: Some("voxel-sort-build-encoder"),
143 });
144 {
145 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
146 label: Some("voxel-sort-build-pass"),
147 timestamp_writes: None,
148 });
149 pass.set_pipeline(&pipelines.voxel_sort_build.pipeline);
150 pass.set_bind_group(0, &bind_group, &[]);
151 pass.dispatch_workgroups(padded_count.div_ceil(WORKGROUP_SIZE), 1, 1);
152 }
153 queue.submit(Some(encoder.finish()));
154
155 Ok(entries_buffer)
156}
157
158fn empty_gpu_segments(runtime: &WgpuRuntime) -> SpatialResult<GpuVoxelSegments> {
159 let device = runtime.device();
160 let make_empty = || {
161 device.create_buffer(&wgpu::BufferDescriptor {
162 label: Some("voxel-sort-empty"),
163 size: 4,
164 usage: wgpu::BufferUsages::STORAGE,
165 mapped_at_creation: false,
166 })
167 };
168 Ok(GpuVoxelSegments::new(0, 0, make_empty(), make_empty(), make_empty()))
169}
170
171fn filter_valid_sorted_entries(
172 runtime: &WgpuRuntime,
173 entries_buffer: &wgpu::Buffer,
174 padded_count: u32,
175 point_count: u32,
176) -> SpatialResult<wgpu::Buffer> {
177 let device = runtime.device();
178 let queue = runtime.queue();
179 let buffer_len = padded_count as u64;
180 let output_len = (point_count as usize * std::mem::size_of::<VoxelSortEntry>()) as u64;
181
182 let flags_buffer = device.create_buffer(&wgpu::BufferDescriptor {
183 label: Some("voxel-sort-filter-flags"),
184 size: buffer_len * std::mem::size_of::<u32>() as u64,
185 usage: wgpu::BufferUsages::STORAGE
186 | wgpu::BufferUsages::COPY_DST
187 | wgpu::BufferUsages::COPY_SRC,
188 mapped_at_creation: false,
189 });
190 let inclusive_buffer = device.create_buffer(&wgpu::BufferDescriptor {
191 label: Some("voxel-sort-filter-inclusive"),
192 size: buffer_len * std::mem::size_of::<u32>() as u64,
193 usage: wgpu::BufferUsages::STORAGE
194 | wgpu::BufferUsages::COPY_SRC
195 | wgpu::BufferUsages::COPY_DST,
196 mapped_at_creation: false,
197 });
198 let scan_scratch_buffer = device.create_buffer(&wgpu::BufferDescriptor {
199 label: Some("voxel-sort-filter-scan-scratch"),
200 size: buffer_len * std::mem::size_of::<u32>() as u64,
201 usage: wgpu::BufferUsages::STORAGE
202 | wgpu::BufferUsages::COPY_DST
203 | wgpu::BufferUsages::COPY_SRC,
204 mapped_at_creation: false,
205 });
206 let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
207 label: Some("voxel-sort-filter-output"),
208 size: output_len,
209 usage: wgpu::BufferUsages::STORAGE,
210 mapped_at_creation: false,
211 });
212
213 queue.write_buffer(
214 &flags_buffer,
215 0,
216 &vec![0u8; (buffer_len * std::mem::size_of::<u32>() as u64) as usize],
217 );
218 queue.write_buffer(
219 &inclusive_buffer,
220 0,
221 &vec![0u8; (buffer_len * std::mem::size_of::<u32>() as u64) as usize],
222 );
223 queue.write_buffer(
224 &scan_scratch_buffer,
225 0,
226 &vec![0u8; (buffer_len * std::mem::size_of::<u32>() as u64) as usize],
227 );
228
229 let params_buffer = device.create_buffer(&wgpu::BufferDescriptor {
230 label: Some("voxel-sort-filter-params"),
231 size: std::mem::size_of::<FilterParams>() as u64,
232 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
233 mapped_at_creation: false,
234 });
235
236 let pipelines = runtime.pipelines();
237 let layout = &pipelines.voxel_sort_filter.bind_group_layout;
238
239 let mark_bind_group = create_filter_bind_group(
240 device,
241 layout,
242 ¶ms_buffer,
243 entries_buffer,
244 &flags_buffer,
245 &scan_scratch_buffer,
246 &inclusive_buffer,
247 &output_buffer,
248 );
249 let init_bind_group = create_filter_bind_group(
250 device,
251 layout,
252 ¶ms_buffer,
253 entries_buffer,
254 &flags_buffer,
255 &scan_scratch_buffer,
256 &inclusive_buffer,
257 &output_buffer,
258 );
259
260 let dispatch_padded = padded_count.div_ceil(WORKGROUP_SIZE);
261
262 write_filter_params(queue, ¶ms_buffer, point_count, padded_count, 0);
263 {
264 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
265 label: Some("voxel-sort-filter-mark-encoder"),
266 });
267 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
268 label: Some("voxel-sort-filter-mark-pass"),
269 timestamp_writes: None,
270 });
271 pass.set_pipeline(&pipelines.voxel_sort_filter.mark);
272 pass.set_bind_group(0, &mark_bind_group, &[]);
273 pass.dispatch_workgroups(dispatch_padded, 1, 1);
274 drop(pass);
275 queue.submit(Some(encoder.finish()));
276 }
277
278 write_filter_params(queue, ¶ms_buffer, point_count, padded_count, 0);
279 {
280 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
281 label: Some("voxel-sort-filter-init-encoder"),
282 });
283 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
284 label: Some("voxel-sort-filter-init-pass"),
285 timestamp_writes: None,
286 });
287 pass.set_pipeline(&pipelines.voxel_sort_filter.init);
288 pass.set_bind_group(0, &init_bind_group, &[]);
289 pass.dispatch_workgroups(dispatch_padded, 1, 1);
290 drop(pass);
291 queue.submit(Some(encoder.finish()));
292 }
293
294 let mut scan_read = &inclusive_buffer;
295 let mut scan_write = &scan_scratch_buffer;
296 let mut stride = 1u32;
297 while stride < padded_count {
298 write_filter_params(queue, ¶ms_buffer, point_count, padded_count, stride);
299 let scan_bind_group = create_filter_bind_group(
300 device,
301 layout,
302 ¶ms_buffer,
303 entries_buffer,
304 &flags_buffer,
305 scan_read,
306 scan_write,
307 &output_buffer,
308 );
309 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
310 label: Some("voxel-sort-filter-scan-encoder"),
311 });
312 {
313 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
314 label: Some("voxel-sort-filter-scan-pass"),
315 timestamp_writes: None,
316 });
317 pass.set_pipeline(&pipelines.voxel_sort_filter.scan);
318 pass.set_bind_group(0, &scan_bind_group, &[]);
319 pass.dispatch_workgroups(dispatch_padded, 1, 1);
320 }
321 queue.submit(Some(encoder.finish()));
322 std::mem::swap(&mut scan_read, &mut scan_write);
323 stride *= 2;
324 }
325
326 let valid_count = read_filter_valid_count(device, queue, scan_read, padded_count)?;
327 if valid_count != point_count as usize {
328 return Err(spatialrust_core::SpatialError::InvalidArgument(format!(
329 "expected {point_count} sorted voxel entries, found {valid_count}"
330 )));
331 }
332
333 write_filter_params(queue, ¶ms_buffer, point_count, padded_count, 0);
334 let scatter_bind_group = create_filter_bind_group(
335 device,
336 layout,
337 ¶ms_buffer,
338 entries_buffer,
339 &flags_buffer,
340 scan_read,
341 scan_write,
342 &output_buffer,
343 );
344 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
345 label: Some("voxel-sort-filter-scatter-encoder"),
346 });
347 {
348 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
349 label: Some("voxel-sort-filter-scatter-pass"),
350 timestamp_writes: None,
351 });
352 pass.set_pipeline(&pipelines.voxel_sort_filter.scatter);
353 pass.set_bind_group(0, &scatter_bind_group, &[]);
354 pass.dispatch_workgroups(dispatch_padded, 1, 1);
355 }
356 queue.submit(Some(encoder.finish()));
357
358 Ok(output_buffer)
359}
360
361fn create_filter_bind_group(
362 device: &wgpu::Device,
363 layout: &wgpu::BindGroupLayout,
364 params_buffer: &wgpu::Buffer,
365 entries_buffer: &wgpu::Buffer,
366 flags_buffer: &wgpu::Buffer,
367 scan_in: &wgpu::Buffer,
368 scan_out: &wgpu::Buffer,
369 output_buffer: &wgpu::Buffer,
370) -> wgpu::BindGroup {
371 device.create_bind_group(&wgpu::BindGroupDescriptor {
372 label: Some("voxel-sort-filter-bind-group"),
373 layout,
374 entries: &[
375 wgpu::BindGroupEntry { binding: 0, resource: params_buffer.as_entire_binding() },
376 wgpu::BindGroupEntry { binding: 1, resource: entries_buffer.as_entire_binding() },
377 wgpu::BindGroupEntry { binding: 2, resource: flags_buffer.as_entire_binding() },
378 wgpu::BindGroupEntry { binding: 3, resource: scan_in.as_entire_binding() },
379 wgpu::BindGroupEntry { binding: 4, resource: scan_out.as_entire_binding() },
380 wgpu::BindGroupEntry { binding: 5, resource: output_buffer.as_entire_binding() },
381 ],
382 })
383}
384
385#[repr(C)]
386#[derive(Clone, Copy, Debug, Pod, Zeroable)]
387struct FilterParams {
388 point_count: u32,
389 padded_count: u32,
390 scan_stride: u32,
391 _pad: u32,
392}
393
394fn write_filter_params(
395 queue: &wgpu::Queue,
396 params_buffer: &wgpu::Buffer,
397 point_count: u32,
398 padded_count: u32,
399 scan_stride: u32,
400) {
401 let params = FilterParams { point_count, padded_count, scan_stride, _pad: 0 };
402 queue.write_buffer(params_buffer, 0, bytemuck::bytes_of(¶ms));
403}
404
405fn read_filter_valid_count(
406 device: &wgpu::Device,
407 queue: &wgpu::Queue,
408 inclusive_buffer: &wgpu::Buffer,
409 padded_count: u32,
410) -> SpatialResult<usize> {
411 let offset = ((padded_count as u64).saturating_sub(1)) * std::mem::size_of::<u32>() as u64;
412 let staging = device.create_buffer(&wgpu::BufferDescriptor {
413 label: Some("voxel-sort-filter-count-staging"),
414 size: std::mem::size_of::<u32>() as u64,
415 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
416 mapped_at_creation: false,
417 });
418
419 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
420 label: Some("voxel-sort-filter-count-encoder"),
421 });
422 encoder.copy_buffer_to_buffer(inclusive_buffer, offset, &staging, 0, staging.size());
423 queue.submit(Some(encoder.finish()));
424
425 let slice = staging.slice(..);
426 let (sender, receiver) = std::sync::mpsc::channel();
427 slice.map_async(wgpu::MapMode::Read, move |result| {
428 let _ = sender.send(result);
429 });
430 device.poll(wgpu::Maintain::Wait);
431 receiver
432 .recv()
433 .map_err(|_| {
434 spatialrust_core::SpatialError::InvalidArgument(
435 "failed to receive wgpu map result".to_owned(),
436 )
437 })?
438 .map_err(|error| {
439 spatialrust_core::SpatialError::InvalidArgument(format!(
440 "failed to map wgpu buffer: {error}"
441 ))
442 })?;
443
444 let data = slice.get_mapped_range();
445 let count = bytemuck::cast_slice::<u8, u32>(&data)[0] as usize;
446 drop(data);
447 staging.unmap();
448 Ok(count)
449}
450
451fn sort_entries_gpu(
452 runtime: &WgpuRuntime,
453 entries_buffer: wgpu::Buffer,
454 padded_count: u32,
455) -> SpatialResult<wgpu::Buffer> {
456 let device = runtime.device();
457 let queue = runtime.queue();
458 let pipelines = runtime.pipelines();
459
460 let params_buffer = device.create_buffer(&wgpu::BufferDescriptor {
461 label: Some("voxel-sort-params"),
462 size: std::mem::size_of::<SortParams>() as u64,
463 usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
464 mapped_at_creation: false,
465 });
466
467 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
468 label: Some("voxel-sort-bind-group"),
469 layout: &pipelines.voxel_sort.bind_group_layout,
470 entries: &[
471 wgpu::BindGroupEntry { binding: 0, resource: params_buffer.as_entire_binding() },
472 wgpu::BindGroupEntry { binding: 1, resource: entries_buffer.as_entire_binding() },
473 ],
474 });
475
476 let mut k = 2u32;
477 while k <= padded_count {
478 let mut j = k / 2;
479 while j >= 1 {
480 let params = SortParams { padded_count, pair_distance: j, block_width: k, _pad: 0 };
481 queue.write_buffer(¶ms_buffer, 0, bytemuck::bytes_of(¶ms));
482
483 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
484 label: Some("voxel-sort-encoder"),
485 });
486 {
487 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
488 label: Some("voxel-sort-pass"),
489 timestamp_writes: None,
490 });
491 pass.set_pipeline(&pipelines.voxel_sort.pipeline);
492 pass.set_bind_group(0, &bind_group, &[]);
493 pass.dispatch_workgroups(padded_count.div_ceil(WORKGROUP_SIZE), 1, 1);
494 }
495 queue.submit(Some(encoder.finish()));
496 j /= 2;
497 }
498 k *= 2;
499 }
500
501 Ok(entries_buffer)
502}
503
#[cfg(test)]
mod tests {
    use super::build_voxel_segments_gpu;
    use crate::kernels::voxel_segments::build_voxel_segments;
    use crate::runtime::WgpuRuntime;

    /// Asserts that the GPU segment builder agrees with the CPU reference on `keys`.
    fn assert_gpu_matches_cpu(runtime: &WgpuRuntime, keys: Vec<(i64, i64, i64)>) {
        let cpu = build_voxel_segments(&keys);
        let gpu = build_voxel_segments_gpu(runtime, &keys).expect("gpu segments");
        assert_eq!(cpu, gpu, "GPU/CPU segment mismatch for {} keys", keys.len());
    }

    #[test]
    fn gpu_segment_build_matches_cpu_reference() {
        let runtime = WgpuRuntime::new_headless().expect("wgpu runtime");
        assert_gpu_matches_cpu(
            &runtime,
            vec![(0, 0, 0), (1, 0, 0), (0, 0, 0), (1, 0, 0), (2, 1, 0)],
        );
    }

    #[test]
    fn gpu_segment_build_handles_single_key() {
        // padded_count == 1: no sort stages and no scan passes.
        let runtime = WgpuRuntime::new_headless().expect("wgpu runtime");
        assert_gpu_matches_cpu(&runtime, vec![(5, -3, 7)]);
    }

    #[test]
    fn gpu_segment_build_handles_negative_keys_with_power_of_two_count() {
        // Exactly four keys: no padding slots, all coordinates non-positive.
        let runtime = WgpuRuntime::new_headless().expect("wgpu runtime");
        assert_gpu_matches_cpu(&runtime, vec![(-1, 0, 0), (0, -1, 0), (-1, 0, 0), (0, 0, -1)]);
    }

    #[test]
    fn gpu_segment_build_spans_multiple_workgroups() {
        // 300 keys pad to 512 slots: two workgroups and nine scan passes.
        let runtime = WgpuRuntime::new_headless().expect("wgpu runtime");
        let keys = (0..300i64).map(|i| (i % 7, (i / 7) % 3, -(i % 5))).collect();
        assert_gpu_matches_cpu(&runtime, keys);
    }
}