Skip to main content

spatialrust_gpu/kernels/
voxel_sort.rs

1use bytemuck::{Pod, Zeroable};
2use spatialrust_core::SpatialResult;
3use wgpu::util::DeviceExt;
4
5use crate::kernels::gpu_segments::GpuVoxelSegments;
6use crate::kernels::voxel_compact::compact_voxel_segments_gpu_buffers;
7use crate::kernels::voxel_keys::VoxelKeyOutput;
8use crate::kernels::voxel_segments::VoxelSegments;
9use crate::runtime::WgpuRuntime;
10
/// Threads per workgroup assumed when sizing every dispatch in this module
/// (`count.div_ceil(WORKGROUP_SIZE)`); presumably must match the `@workgroup_size` declared by
/// the voxel sort shaders — confirm before changing.
const WORKGROUP_SIZE: u32 = 1 << 8;
12
13#[repr(C)]
14#[derive(Clone, Copy, Debug, Pod, Zeroable)]
15struct SortParams {
16    padded_count: u32,
17    pair_distance: u32,
18    block_width: u32,
19    _pad: u32,
20}
21
22#[repr(C)]
23#[derive(Clone, Copy, Debug, Default, Pod, Zeroable)]
24pub(crate) struct VoxelSortEntry {
25    ix: i32,
26    iy: i32,
27    iz: i32,
28    point_index: u32,
29}
30
31#[repr(C)]
32#[derive(Clone, Copy, Debug, Pod, Zeroable)]
33struct BuildEntriesParams {
34    point_count: u32,
35    padded_count: u32,
36    _pad0: u32,
37    _pad1: u32,
38}
39
40/// Sorts per-point voxel keys on the GPU and compacts them into segments.
41pub fn build_voxel_segments_gpu(
42    runtime: &WgpuRuntime,
43    keys: &[(i64, i64, i64)],
44) -> SpatialResult<VoxelSegments> {
45    let gpu_segments = build_voxel_segments_gpu_from_keys(runtime, keys)?;
46    gpu_segments.to_voxel_segments(runtime)
47}
48
49/// Builds GPU-resident voxel segments from CPU-side keys.
50pub fn build_voxel_segments_gpu_from_keys(
51    runtime: &WgpuRuntime,
52    keys: &[(i64, i64, i64)],
53) -> SpatialResult<GpuVoxelSegments> {
54    if keys.is_empty() {
55        return empty_gpu_segments(runtime);
56    }
57
58    let point_count = keys.len();
59    let padded_count = point_count.next_power_of_two();
60    let mut key_outputs = vec![VoxelKeyOutput::default(); point_count];
61    for (index, (ix, iy, iz)) in keys.iter().copied().enumerate() {
62        key_outputs[index] = VoxelKeyOutput {
63            ix: ix.clamp(i32::MIN as i64, i32::MAX as i64) as i32,
64            iy: iy.clamp(i32::MIN as i64, i32::MAX as i64) as i32,
65            iz: iz.clamp(i32::MIN as i64, i32::MAX as i64) as i32,
66            _pad: 0,
67        };
68    }
69
70    let device = runtime.device();
71    let keys_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
72        label: Some("voxel-sort-keys-input"),
73        contents: bytemuck::cast_slice(&key_outputs),
74        usage: wgpu::BufferUsages::STORAGE,
75    });
76
77    build_voxel_segments_gpu_from_keys_buffer(
78        runtime,
79        &keys_buffer,
80        point_count as u32,
81        padded_count as u32,
82    )
83}
84
85/// Builds GPU-resident voxel segments from a GPU keys buffer.
86pub fn build_voxel_segments_gpu_from_keys_buffer(
87    runtime: &WgpuRuntime,
88    keys_buffer: &wgpu::Buffer,
89    point_count: u32,
90    padded_count: u32,
91) -> SpatialResult<GpuVoxelSegments> {
92    if point_count == 0 {
93        return empty_gpu_segments(runtime);
94    }
95
96    let entries_buffer =
97        build_sort_entries_from_keys_gpu(runtime, keys_buffer, point_count, padded_count)?;
98    let sorted_buffer = sort_entries_gpu(runtime, entries_buffer, padded_count)?;
99    let compact_entries =
100        filter_valid_sorted_entries(runtime, &sorted_buffer, padded_count, point_count)?;
101    compact_voxel_segments_gpu_buffers(runtime, &compact_entries, point_count)
102}
103
104fn build_sort_entries_from_keys_gpu(
105    runtime: &WgpuRuntime,
106    keys_buffer: &wgpu::Buffer,
107    point_count: u32,
108    padded_count: u32,
109) -> SpatialResult<wgpu::Buffer> {
110    let device = runtime.device();
111    let queue = runtime.queue();
112
113    let entries_buffer = device.create_buffer(&wgpu::BufferDescriptor {
114        label: Some("voxel-sort-entries-build"),
115        size: (padded_count as usize * std::mem::size_of::<VoxelSortEntry>()) as u64,
116        usage: wgpu::BufferUsages::STORAGE
117            | wgpu::BufferUsages::COPY_DST
118            | wgpu::BufferUsages::COPY_SRC,
119        mapped_at_creation: false,
120    });
121
122    let params = BuildEntriesParams { point_count, padded_count, _pad0: 0, _pad1: 0 };
123    let params_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
124        label: Some("voxel-sort-build-params"),
125        contents: bytemuck::bytes_of(&params),
126        usage: wgpu::BufferUsages::UNIFORM,
127    });
128
129    let pipelines = runtime.pipelines();
130
131    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
132        label: Some("voxel-sort-build-bind-group"),
133        layout: &pipelines.voxel_sort_build.bind_group_layout,
134        entries: &[
135            wgpu::BindGroupEntry { binding: 0, resource: params_buffer.as_entire_binding() },
136            wgpu::BindGroupEntry { binding: 1, resource: keys_buffer.as_entire_binding() },
137            wgpu::BindGroupEntry { binding: 2, resource: entries_buffer.as_entire_binding() },
138        ],
139    });
140
141    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
142        label: Some("voxel-sort-build-encoder"),
143    });
144    {
145        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
146            label: Some("voxel-sort-build-pass"),
147            timestamp_writes: None,
148        });
149        pass.set_pipeline(&pipelines.voxel_sort_build.pipeline);
150        pass.set_bind_group(0, &bind_group, &[]);
151        pass.dispatch_workgroups(padded_count.div_ceil(WORKGROUP_SIZE), 1, 1);
152    }
153    queue.submit(Some(encoder.finish()));
154
155    Ok(entries_buffer)
156}
157
158fn empty_gpu_segments(runtime: &WgpuRuntime) -> SpatialResult<GpuVoxelSegments> {
159    let device = runtime.device();
160    let make_empty = || {
161        device.create_buffer(&wgpu::BufferDescriptor {
162            label: Some("voxel-sort-empty"),
163            size: 4,
164            usage: wgpu::BufferUsages::STORAGE,
165            mapped_at_creation: false,
166        })
167    };
168    Ok(GpuVoxelSegments::new(0, 0, make_empty(), make_empty(), make_empty()))
169}
170
171fn filter_valid_sorted_entries(
172    runtime: &WgpuRuntime,
173    entries_buffer: &wgpu::Buffer,
174    padded_count: u32,
175    point_count: u32,
176) -> SpatialResult<wgpu::Buffer> {
177    let device = runtime.device();
178    let queue = runtime.queue();
179    let buffer_len = padded_count as u64;
180    let output_len = (point_count as usize * std::mem::size_of::<VoxelSortEntry>()) as u64;
181
182    let flags_buffer = device.create_buffer(&wgpu::BufferDescriptor {
183        label: Some("voxel-sort-filter-flags"),
184        size: buffer_len * std::mem::size_of::<u32>() as u64,
185        usage: wgpu::BufferUsages::STORAGE
186            | wgpu::BufferUsages::COPY_DST
187            | wgpu::BufferUsages::COPY_SRC,
188        mapped_at_creation: false,
189    });
190    let inclusive_buffer = device.create_buffer(&wgpu::BufferDescriptor {
191        label: Some("voxel-sort-filter-inclusive"),
192        size: buffer_len * std::mem::size_of::<u32>() as u64,
193        usage: wgpu::BufferUsages::STORAGE
194            | wgpu::BufferUsages::COPY_SRC
195            | wgpu::BufferUsages::COPY_DST,
196        mapped_at_creation: false,
197    });
198    let scan_scratch_buffer = device.create_buffer(&wgpu::BufferDescriptor {
199        label: Some("voxel-sort-filter-scan-scratch"),
200        size: buffer_len * std::mem::size_of::<u32>() as u64,
201        usage: wgpu::BufferUsages::STORAGE
202            | wgpu::BufferUsages::COPY_DST
203            | wgpu::BufferUsages::COPY_SRC,
204        mapped_at_creation: false,
205    });
206    let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
207        label: Some("voxel-sort-filter-output"),
208        size: output_len,
209        usage: wgpu::BufferUsages::STORAGE,
210        mapped_at_creation: false,
211    });
212
213    queue.write_buffer(
214        &flags_buffer,
215        0,
216        &vec![0u8; (buffer_len * std::mem::size_of::<u32>() as u64) as usize],
217    );
218    queue.write_buffer(
219        &inclusive_buffer,
220        0,
221        &vec![0u8; (buffer_len * std::mem::size_of::<u32>() as u64) as usize],
222    );
223    queue.write_buffer(
224        &scan_scratch_buffer,
225        0,
226        &vec![0u8; (buffer_len * std::mem::size_of::<u32>() as u64) as usize],
227    );
228
229    let params_buffer = device.create_buffer(&wgpu::BufferDescriptor {
230        label: Some("voxel-sort-filter-params"),
231        size: std::mem::size_of::<FilterParams>() as u64,
232        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
233        mapped_at_creation: false,
234    });
235
236    let pipelines = runtime.pipelines();
237    let layout = &pipelines.voxel_sort_filter.bind_group_layout;
238
239    let mark_bind_group = create_filter_bind_group(
240        device,
241        layout,
242        &params_buffer,
243        entries_buffer,
244        &flags_buffer,
245        &scan_scratch_buffer,
246        &inclusive_buffer,
247        &output_buffer,
248    );
249    let init_bind_group = create_filter_bind_group(
250        device,
251        layout,
252        &params_buffer,
253        entries_buffer,
254        &flags_buffer,
255        &scan_scratch_buffer,
256        &inclusive_buffer,
257        &output_buffer,
258    );
259
260    let dispatch_padded = padded_count.div_ceil(WORKGROUP_SIZE);
261
262    write_filter_params(queue, &params_buffer, point_count, padded_count, 0);
263    {
264        let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
265            label: Some("voxel-sort-filter-mark-encoder"),
266        });
267        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
268            label: Some("voxel-sort-filter-mark-pass"),
269            timestamp_writes: None,
270        });
271        pass.set_pipeline(&pipelines.voxel_sort_filter.mark);
272        pass.set_bind_group(0, &mark_bind_group, &[]);
273        pass.dispatch_workgroups(dispatch_padded, 1, 1);
274        drop(pass);
275        queue.submit(Some(encoder.finish()));
276    }
277
278    write_filter_params(queue, &params_buffer, point_count, padded_count, 0);
279    {
280        let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
281            label: Some("voxel-sort-filter-init-encoder"),
282        });
283        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
284            label: Some("voxel-sort-filter-init-pass"),
285            timestamp_writes: None,
286        });
287        pass.set_pipeline(&pipelines.voxel_sort_filter.init);
288        pass.set_bind_group(0, &init_bind_group, &[]);
289        pass.dispatch_workgroups(dispatch_padded, 1, 1);
290        drop(pass);
291        queue.submit(Some(encoder.finish()));
292    }
293
294    let mut scan_read = &inclusive_buffer;
295    let mut scan_write = &scan_scratch_buffer;
296    let mut stride = 1u32;
297    while stride < padded_count {
298        write_filter_params(queue, &params_buffer, point_count, padded_count, stride);
299        let scan_bind_group = create_filter_bind_group(
300            device,
301            layout,
302            &params_buffer,
303            entries_buffer,
304            &flags_buffer,
305            scan_read,
306            scan_write,
307            &output_buffer,
308        );
309        let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
310            label: Some("voxel-sort-filter-scan-encoder"),
311        });
312        {
313            let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
314                label: Some("voxel-sort-filter-scan-pass"),
315                timestamp_writes: None,
316            });
317            pass.set_pipeline(&pipelines.voxel_sort_filter.scan);
318            pass.set_bind_group(0, &scan_bind_group, &[]);
319            pass.dispatch_workgroups(dispatch_padded, 1, 1);
320        }
321        queue.submit(Some(encoder.finish()));
322        std::mem::swap(&mut scan_read, &mut scan_write);
323        stride *= 2;
324    }
325
326    let valid_count = read_filter_valid_count(device, queue, scan_read, padded_count)?;
327    if valid_count != point_count as usize {
328        return Err(spatialrust_core::SpatialError::InvalidArgument(format!(
329            "expected {point_count} sorted voxel entries, found {valid_count}"
330        )));
331    }
332
333    write_filter_params(queue, &params_buffer, point_count, padded_count, 0);
334    let scatter_bind_group = create_filter_bind_group(
335        device,
336        layout,
337        &params_buffer,
338        entries_buffer,
339        &flags_buffer,
340        scan_read,
341        scan_write,
342        &output_buffer,
343    );
344    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
345        label: Some("voxel-sort-filter-scatter-encoder"),
346    });
347    {
348        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
349            label: Some("voxel-sort-filter-scatter-pass"),
350            timestamp_writes: None,
351        });
352        pass.set_pipeline(&pipelines.voxel_sort_filter.scatter);
353        pass.set_bind_group(0, &scatter_bind_group, &[]);
354        pass.dispatch_workgroups(dispatch_padded, 1, 1);
355    }
356    queue.submit(Some(encoder.finish()));
357
358    Ok(output_buffer)
359}
360
361fn create_filter_bind_group(
362    device: &wgpu::Device,
363    layout: &wgpu::BindGroupLayout,
364    params_buffer: &wgpu::Buffer,
365    entries_buffer: &wgpu::Buffer,
366    flags_buffer: &wgpu::Buffer,
367    scan_in: &wgpu::Buffer,
368    scan_out: &wgpu::Buffer,
369    output_buffer: &wgpu::Buffer,
370) -> wgpu::BindGroup {
371    device.create_bind_group(&wgpu::BindGroupDescriptor {
372        label: Some("voxel-sort-filter-bind-group"),
373        layout,
374        entries: &[
375            wgpu::BindGroupEntry { binding: 0, resource: params_buffer.as_entire_binding() },
376            wgpu::BindGroupEntry { binding: 1, resource: entries_buffer.as_entire_binding() },
377            wgpu::BindGroupEntry { binding: 2, resource: flags_buffer.as_entire_binding() },
378            wgpu::BindGroupEntry { binding: 3, resource: scan_in.as_entire_binding() },
379            wgpu::BindGroupEntry { binding: 4, resource: scan_out.as_entire_binding() },
380            wgpu::BindGroupEntry { binding: 5, resource: output_buffer.as_entire_binding() },
381        ],
382    })
383}
384
385#[repr(C)]
386#[derive(Clone, Copy, Debug, Pod, Zeroable)]
387struct FilterParams {
388    point_count: u32,
389    padded_count: u32,
390    scan_stride: u32,
391    _pad: u32,
392}
393
394fn write_filter_params(
395    queue: &wgpu::Queue,
396    params_buffer: &wgpu::Buffer,
397    point_count: u32,
398    padded_count: u32,
399    scan_stride: u32,
400) {
401    let params = FilterParams { point_count, padded_count, scan_stride, _pad: 0 };
402    queue.write_buffer(params_buffer, 0, bytemuck::bytes_of(&params));
403}
404
405fn read_filter_valid_count(
406    device: &wgpu::Device,
407    queue: &wgpu::Queue,
408    inclusive_buffer: &wgpu::Buffer,
409    padded_count: u32,
410) -> SpatialResult<usize> {
411    let offset = ((padded_count as u64).saturating_sub(1)) * std::mem::size_of::<u32>() as u64;
412    let staging = device.create_buffer(&wgpu::BufferDescriptor {
413        label: Some("voxel-sort-filter-count-staging"),
414        size: std::mem::size_of::<u32>() as u64,
415        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
416        mapped_at_creation: false,
417    });
418
419    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
420        label: Some("voxel-sort-filter-count-encoder"),
421    });
422    encoder.copy_buffer_to_buffer(inclusive_buffer, offset, &staging, 0, staging.size());
423    queue.submit(Some(encoder.finish()));
424
425    let slice = staging.slice(..);
426    let (sender, receiver) = std::sync::mpsc::channel();
427    slice.map_async(wgpu::MapMode::Read, move |result| {
428        let _ = sender.send(result);
429    });
430    device.poll(wgpu::Maintain::Wait);
431    receiver
432        .recv()
433        .map_err(|_| {
434            spatialrust_core::SpatialError::InvalidArgument(
435                "failed to receive wgpu map result".to_owned(),
436            )
437        })?
438        .map_err(|error| {
439            spatialrust_core::SpatialError::InvalidArgument(format!(
440                "failed to map wgpu buffer: {error}"
441            ))
442        })?;
443
444    let data = slice.get_mapped_range();
445    let count = bytemuck::cast_slice::<u8, u32>(&data)[0] as usize;
446    drop(data);
447    staging.unmap();
448    Ok(count)
449}
450
451fn sort_entries_gpu(
452    runtime: &WgpuRuntime,
453    entries_buffer: wgpu::Buffer,
454    padded_count: u32,
455) -> SpatialResult<wgpu::Buffer> {
456    let device = runtime.device();
457    let queue = runtime.queue();
458    let pipelines = runtime.pipelines();
459
460    let params_buffer = device.create_buffer(&wgpu::BufferDescriptor {
461        label: Some("voxel-sort-params"),
462        size: std::mem::size_of::<SortParams>() as u64,
463        usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
464        mapped_at_creation: false,
465    });
466
467    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
468        label: Some("voxel-sort-bind-group"),
469        layout: &pipelines.voxel_sort.bind_group_layout,
470        entries: &[
471            wgpu::BindGroupEntry { binding: 0, resource: params_buffer.as_entire_binding() },
472            wgpu::BindGroupEntry { binding: 1, resource: entries_buffer.as_entire_binding() },
473        ],
474    });
475
476    let mut k = 2u32;
477    while k <= padded_count {
478        let mut j = k / 2;
479        while j >= 1 {
480            let params = SortParams { padded_count, pair_distance: j, block_width: k, _pad: 0 };
481            queue.write_buffer(&params_buffer, 0, bytemuck::bytes_of(&params));
482
483            let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
484                label: Some("voxel-sort-encoder"),
485            });
486            {
487                let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
488                    label: Some("voxel-sort-pass"),
489                    timestamp_writes: None,
490                });
491                pass.set_pipeline(&pipelines.voxel_sort.pipeline);
492                pass.set_bind_group(0, &bind_group, &[]);
493                pass.dispatch_workgroups(padded_count.div_ceil(WORKGROUP_SIZE), 1, 1);
494            }
495            queue.submit(Some(encoder.finish()));
496            j /= 2;
497        }
498        k *= 2;
499    }
500
501    Ok(entries_buffer)
502}
503
#[cfg(test)]
mod tests {
    use super::build_voxel_segments_gpu;
    use crate::kernels::voxel_segments::build_voxel_segments;
    use crate::runtime::WgpuRuntime;

    /// The GPU path must reproduce the CPU reference segmentation, including keys that
    /// repeat (two points each in voxels (0, 0, 0) and (1, 0, 0)).
    #[test]
    fn gpu_segment_build_matches_cpu_reference() {
        let runtime = WgpuRuntime::new_headless().expect("wgpu runtime");
        let keys = vec![(0, 0, 0), (1, 0, 0), (0, 0, 0), (1, 0, 0), (2, 1, 0)];

        let expected = build_voxel_segments(&keys);
        let actual = build_voxel_segments_gpu(&runtime, &keys).expect("gpu segments");

        assert_eq!(expected, actual);
    }
}