Skip to main content

spatialrust_gpu/kernels/
voxel_reduce.rs

1use bytemuck::{Pod, Zeroable};
2use spatialrust_core::{SpatialError, SpatialResult};
3use wgpu::util::DeviceExt;
4
5use crate::kernels::gpu_segments::GpuVoxelSegments;
6use crate::kernels::voxel_segments::VoxelSegments;
7use crate::readback::{
8    pad_u8_for_gpu_storage, read_staging_f32, read_staging_f32_and_u8, split_channel_blocks,
9    split_u8_channel_blocks, split_xyz_and_attribute_blocks, split_xyz_blocks,
10    u8_output_staging_bytes, unpack_u8_outputs_from_u32_staging,
11};
12use crate::runtime::WgpuRuntime;
13
14const WORKGROUP_SIZE: u32 = 256;
15
16#[repr(C)]
17#[derive(Clone, Copy, Debug, Pod, Zeroable)]
18struct ReduceUniform {
19    cell_count: u32,
20    point_count: u32,
21    _pad0: u32,
22    _pad1: u32,
23}
24
25/// Averages per-point `f32` values within each voxel cell on the GPU.
26pub fn reduce_voxel_average_f32(
27    runtime: &WgpuRuntime,
28    values: &[f32],
29    segments: &VoxelSegments,
30) -> SpatialResult<Vec<f32>> {
31    if segments.is_empty() {
32        return Ok(Vec::new());
33    }
34    if values.is_empty() {
35        return Err(SpatialError::InvalidArgument("cannot reduce empty value buffer".to_owned()));
36    }
37
38    let device = runtime.device();
39    let values_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
40        label: Some("voxel-reduce-values"),
41        contents: bytemuck::cast_slice(values),
42        usage: wgpu::BufferUsages::STORAGE,
43    });
44    let indices_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
45        label: Some("voxel-reduce-indices"),
46        contents: bytemuck::cast_slice(&segments.point_indices),
47        usage: wgpu::BufferUsages::STORAGE,
48    });
49    let starts_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
50        label: Some("voxel-reduce-starts"),
51        contents: bytemuck::cast_slice(&segments.cell_starts),
52        usage: wgpu::BufferUsages::STORAGE,
53    });
54
55    dispatch_voxel_reduce_f32(
56        runtime,
57        &values_buffer,
58        &indices_buffer,
59        &starts_buffer,
60        segments.len() as u32,
61        segments.point_indices.len() as u32,
62    )
63}
64
65/// Averages per-point `f32` values using GPU-resident segment buffers.
66pub fn reduce_voxel_average_f32_gpu_buffers(
67    runtime: &WgpuRuntime,
68    values: &wgpu::Buffer,
69    segments: &GpuVoxelSegments,
70) -> SpatialResult<Vec<f32>> {
71    dispatch_voxel_reduce_f32(
72        runtime,
73        values,
74        segments.point_indices_buffer(),
75        segments.cell_starts_buffer(),
76        segments.cell_count(),
77        segments.point_count(),
78    )
79}
80
81/// Uploads `f32` values and averages them within GPU-resident voxel segments.
82pub fn reduce_voxel_average_f32_gpu(
83    runtime: &WgpuRuntime,
84    values: &[f32],
85    segments: &GpuVoxelSegments,
86) -> SpatialResult<Vec<f32>> {
87    if segments.cell_count() == 0 {
88        return Ok(Vec::new());
89    }
90    if values.len() != segments.point_count() as usize {
91        return Err(SpatialError::BufferLengthMismatch {
92            expected: segments.point_count() as usize,
93            found: values.len(),
94        });
95    }
96
97    let device = runtime.device();
98    let values_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
99        label: Some("voxel-reduce-values-upload"),
100        contents: bytemuck::cast_slice(values),
101        usage: wgpu::BufferUsages::STORAGE,
102    });
103    reduce_voxel_average_f32_gpu_buffers(runtime, &values_buffer, segments)
104}
105
106/// Uploads multiple `f32` channels and averages them with one GPU submit/readback.
107pub fn reduce_voxel_average_f32_multi_gpu(
108    runtime: &WgpuRuntime,
109    channels: &[&[f32]],
110    segments: &GpuVoxelSegments,
111) -> SpatialResult<Vec<Vec<f32>>> {
112    if channels.is_empty() {
113        return Ok(Vec::new());
114    }
115    if segments.cell_count() == 0 {
116        return Ok(vec![Vec::new(); channels.len()]);
117    }
118
119    let point_count = segments.point_count() as usize;
120    for channel in channels {
121        if channel.len() != point_count {
122            return Err(SpatialError::BufferLengthMismatch {
123                expected: point_count,
124                found: channel.len(),
125            });
126        }
127    }
128
129    let device = runtime.device();
130    let queue = runtime.queue();
131    let cell_count = segments.cell_count();
132    let cells = cell_count as usize;
133    let channel_len = cells * std::mem::size_of::<f32>();
134    let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
135        label: Some("voxel-reduce-multi-staging"),
136        size: (channel_len * channels.len()) as u64,
137        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
138        mapped_at_creation: false,
139    });
140
141    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
142        label: Some("voxel-reduce-multi-encoder"),
143    });
144
145    for (channel_index, channel) in channels.iter().enumerate() {
146        let values_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
147            label: Some("voxel-reduce-multi-values"),
148            contents: bytemuck::cast_slice(channel),
149            usage: wgpu::BufferUsages::STORAGE,
150        });
151        let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
152            label: Some("voxel-reduce-multi-output"),
153            size: channel_len as u64,
154            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
155            mapped_at_creation: false,
156        });
157
158        record_voxel_reduce_f32_pass(
159            &mut encoder,
160            runtime,
161            &values_buffer,
162            segments.point_indices_buffer(),
163            segments.cell_starts_buffer(),
164            cell_count,
165            segments.point_count(),
166            &output_buffer,
167        )?;
168
169        encoder.copy_buffer_to_buffer(
170            &output_buffer,
171            0,
172            &staging_buffer,
173            (channel_len * channel_index) as u64,
174            channel_len as u64,
175        );
176    }
177
178    queue.submit(Some(encoder.finish()));
179
180    let flat = read_staging_f32(device, &staging_buffer, cells * channels.len())?;
181    Ok(split_channel_blocks(flat, channels.len(), cells))
182}
183
184/// Averages xyz and multiple f32/u8 attribute channels with one GPU submit/readback.
185pub fn reduce_voxel_centroids_xyz_and_average_multi_gpu(
186    runtime: &WgpuRuntime,
187    x: &wgpu::Buffer,
188    y: &wgpu::Buffer,
189    z: &wgpu::Buffer,
190    attribute_channels: &[&[f32]],
191    u8_attribute_channels: &[&[u8]],
192    segments: &GpuVoxelSegments,
193) -> SpatialResult<(Vec<f32>, Vec<f32>, Vec<f32>, Vec<Vec<f32>>, Vec<Vec<u8>>)> {
194    let attribute_count = attribute_channels.len();
195    let u8_attribute_count = u8_attribute_channels.len();
196    if segments.cell_count() == 0 {
197        return Ok((
198            Vec::new(),
199            Vec::new(),
200            Vec::new(),
201            vec![Vec::new(); attribute_count],
202            vec![Vec::new(); u8_attribute_count],
203        ));
204    }
205
206    let point_count = segments.point_count() as usize;
207    for channel in attribute_channels {
208        if channel.len() != point_count {
209            return Err(SpatialError::BufferLengthMismatch {
210                expected: point_count,
211                found: channel.len(),
212            });
213        }
214    }
215    for channel in u8_attribute_channels {
216        if channel.len() != point_count {
217            return Err(SpatialError::BufferLengthMismatch {
218                expected: point_count,
219                found: channel.len(),
220            });
221        }
222    }
223
224    let device = runtime.device();
225    let queue = runtime.queue();
226    let cell_count = segments.cell_count();
227    let cells = cell_count as usize;
228    let channel_len = cells * std::mem::size_of::<f32>();
229    let f32_channel_count = 3 + attribute_count;
230    let u8_staging_len = u8_output_staging_bytes(cells, u8_attribute_count);
231    let staging_size = channel_len * f32_channel_count + u8_staging_len;
232
233    let output_x = device.create_buffer(&wgpu::BufferDescriptor {
234        label: Some("voxel-reduce-xyz-out-x"),
235        size: channel_len as u64,
236        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
237        mapped_at_creation: false,
238    });
239    let output_y = device.create_buffer(&wgpu::BufferDescriptor {
240        label: Some("voxel-reduce-xyz-out-y"),
241        size: channel_len as u64,
242        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
243        mapped_at_creation: false,
244    });
245    let output_z = device.create_buffer(&wgpu::BufferDescriptor {
246        label: Some("voxel-reduce-xyz-out-z"),
247        size: channel_len as u64,
248        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
249        mapped_at_creation: false,
250    });
251    let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
252        label: Some("voxel-reduce-xyz-attrs-staging"),
253        size: staging_size as u64,
254        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
255        mapped_at_creation: false,
256    });
257
258    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
259        label: Some("voxel-reduce-xyz-attrs-encoder"),
260    });
261
262    record_voxel_reduce_xyz_pass(
263        &mut encoder,
264        runtime,
265        x,
266        y,
267        z,
268        segments.point_indices_buffer(),
269        segments.cell_starts_buffer(),
270        cell_count,
271        segments.point_count(),
272        &output_x,
273        &output_y,
274        &output_z,
275    )?;
276
277    encoder.copy_buffer_to_buffer(&output_x, 0, &staging_buffer, 0, channel_len as u64);
278    encoder.copy_buffer_to_buffer(
279        &output_y,
280        0,
281        &staging_buffer,
282        channel_len as u64,
283        channel_len as u64,
284    );
285    encoder.copy_buffer_to_buffer(
286        &output_z,
287        0,
288        &staging_buffer,
289        (channel_len * 2) as u64,
290        channel_len as u64,
291    );
292
293    for (attribute_index, channel) in attribute_channels.iter().enumerate() {
294        let values_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
295            label: Some("voxel-reduce-xyz-attrs-values"),
296            contents: bytemuck::cast_slice(channel),
297            usage: wgpu::BufferUsages::STORAGE,
298        });
299        let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
300            label: Some("voxel-reduce-xyz-attrs-output"),
301            size: channel_len as u64,
302            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
303            mapped_at_creation: false,
304        });
305
306        record_voxel_reduce_f32_pass(
307            &mut encoder,
308            runtime,
309            &values_buffer,
310            segments.point_indices_buffer(),
311            segments.cell_starts_buffer(),
312            cell_count,
313            segments.point_count(),
314            &output_buffer,
315        )?;
316
317        encoder.copy_buffer_to_buffer(
318            &output_buffer,
319            0,
320            &staging_buffer,
321            (channel_len * (3 + attribute_index)) as u64,
322            channel_len as u64,
323        );
324    }
325
326    let u8_region_offset = (channel_len * f32_channel_count) as u64;
327    for (attribute_index, channel) in u8_attribute_channels.iter().enumerate() {
328        let values_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
329            label: Some("voxel-reduce-xyz-u8-values"),
330            contents: &pad_u8_for_gpu_storage(channel),
331            usage: wgpu::BufferUsages::STORAGE,
332        });
333        let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
334            label: Some("voxel-reduce-xyz-u8-output"),
335            size: (cells * std::mem::size_of::<u32>()) as u64,
336            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
337            mapped_at_creation: false,
338        });
339
340        record_voxel_reduce_u8_pass(
341            &mut encoder,
342            runtime,
343            &values_buffer,
344            segments.point_indices_buffer(),
345            segments.cell_starts_buffer(),
346            cell_count,
347            segments.point_count(),
348            &output_buffer,
349        )?;
350
351        encoder.copy_buffer_to_buffer(
352            &output_buffer,
353            0,
354            &staging_buffer,
355            u8_region_offset + attribute_index as u64 * (cells * std::mem::size_of::<u32>()) as u64,
356            (cells * std::mem::size_of::<u32>()) as u64,
357        );
358    }
359
360    queue.submit(Some(encoder.finish()));
361
362    let (flat, u8_raw) = read_staging_f32_and_u8(
363        device,
364        &staging_buffer,
365        cells * f32_channel_count,
366        u8_staging_len,
367    )?;
368    let u8_flat = if u8_attribute_count == 0 {
369        Vec::new()
370    } else {
371        unpack_u8_outputs_from_u32_staging(u8_raw, cells, u8_attribute_count)
372    };
373    let (out_x, out_y, out_z, attributes) =
374        split_xyz_and_attribute_blocks(flat, attribute_count, cells);
375    Ok((
376        out_x,
377        out_y,
378        out_z,
379        attributes,
380        split_u8_channel_blocks(u8_flat, u8_attribute_count, cells),
381    ))
382}
383
384/// Averages xyz and gathers the first f32/u8 attribute value per voxel with one readback.
385pub fn reduce_voxel_centroids_xyz_and_gather_first_multi_gpu(
386    runtime: &WgpuRuntime,
387    x: &wgpu::Buffer,
388    y: &wgpu::Buffer,
389    z: &wgpu::Buffer,
390    attribute_channels: &[&[f32]],
391    u8_attribute_channels: &[&[u8]],
392    segments: &GpuVoxelSegments,
393) -> SpatialResult<(Vec<f32>, Vec<f32>, Vec<f32>, Vec<Vec<f32>>, Vec<Vec<u8>>)> {
394    use crate::kernels::voxel_gather::record_gather_f32_attribute_channels_to_staging;
395    use crate::kernels::voxel_gather::record_voxel_gather_u8_pass;
396
397    let attribute_count = attribute_channels.len();
398    let u8_attribute_count = u8_attribute_channels.len();
399    if segments.cell_count() == 0 {
400        return Ok((
401            Vec::new(),
402            Vec::new(),
403            Vec::new(),
404            vec![Vec::new(); attribute_count],
405            vec![Vec::new(); u8_attribute_count],
406        ));
407    }
408
409    let point_count = segments.point_count() as usize;
410    for channel in attribute_channels {
411        if channel.len() != point_count {
412            return Err(SpatialError::BufferLengthMismatch {
413                expected: point_count,
414                found: channel.len(),
415            });
416        }
417    }
418    for channel in u8_attribute_channels {
419        if channel.len() != point_count {
420            return Err(SpatialError::BufferLengthMismatch {
421                expected: point_count,
422                found: channel.len(),
423            });
424        }
425    }
426
427    let device = runtime.device();
428    let queue = runtime.queue();
429    let cell_count = segments.cell_count();
430    let cells = cell_count as usize;
431    let channel_len = cells * std::mem::size_of::<f32>();
432    let f32_channel_count = 3 + attribute_count;
433    let u8_staging_len = u8_output_staging_bytes(cells, u8_attribute_count);
434    let staging_size = channel_len * f32_channel_count + u8_staging_len;
435    let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
436        label: Some("voxel-reduce-xyz-gather-attrs-staging"),
437        size: staging_size as u64,
438        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
439        mapped_at_creation: false,
440    });
441    let output_x = device.create_buffer(&wgpu::BufferDescriptor {
442        label: Some("voxel-reduce-xyz-out-x"),
443        size: channel_len as u64,
444        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
445        mapped_at_creation: false,
446    });
447    let output_y = device.create_buffer(&wgpu::BufferDescriptor {
448        label: Some("voxel-reduce-xyz-out-y"),
449        size: channel_len as u64,
450        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
451        mapped_at_creation: false,
452    });
453    let output_z = device.create_buffer(&wgpu::BufferDescriptor {
454        label: Some("voxel-reduce-xyz-out-z"),
455        size: channel_len as u64,
456        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
457        mapped_at_creation: false,
458    });
459
460    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
461        label: Some("voxel-reduce-xyz-gather-attrs-encoder"),
462    });
463    let mut upload_recycle = Vec::new();
464    record_voxel_reduce_xyz_pass(
465        &mut encoder,
466        runtime,
467        x,
468        y,
469        z,
470        segments.point_indices_buffer(),
471        segments.cell_starts_buffer(),
472        cell_count,
473        segments.point_count(),
474        &output_x,
475        &output_y,
476        &output_z,
477    )?;
478    encoder.copy_buffer_to_buffer(&output_x, 0, &staging_buffer, 0, channel_len as u64);
479    encoder.copy_buffer_to_buffer(
480        &output_y,
481        0,
482        &staging_buffer,
483        channel_len as u64,
484        channel_len as u64,
485    );
486    encoder.copy_buffer_to_buffer(
487        &output_z,
488        0,
489        &staging_buffer,
490        (channel_len * 2) as u64,
491        channel_len as u64,
492    );
493
494    record_gather_f32_attribute_channels_to_staging(
495        &mut encoder,
496        runtime,
497        attribute_channels,
498        segments,
499        &staging_buffer,
500        channel_len as u64,
501        &mut upload_recycle,
502    )?;
503
504    let u8_region_offset = (channel_len * f32_channel_count) as u64;
505    for (attribute_index, channel) in u8_attribute_channels.iter().enumerate() {
506        let values_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
507            label: Some("voxel-gather-xyz-u8-values"),
508            contents: &pad_u8_for_gpu_storage(channel),
509            usage: wgpu::BufferUsages::STORAGE,
510        });
511        let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
512            label: Some("voxel-gather-xyz-u8-output"),
513            size: (cells * std::mem::size_of::<u32>()) as u64,
514            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
515            mapped_at_creation: false,
516        });
517        record_voxel_gather_u8_pass(
518            &mut encoder,
519            runtime,
520            &values_buffer,
521            segments.point_indices_buffer(),
522            segments.cell_starts_buffer(),
523            cell_count,
524            segments.point_count(),
525            &output_buffer,
526        )?;
527        encoder.copy_buffer_to_buffer(
528            &output_buffer,
529            0,
530            &staging_buffer,
531            u8_region_offset + attribute_index as u64 * (cells * std::mem::size_of::<u32>()) as u64,
532            (cells * std::mem::size_of::<u32>()) as u64,
533        );
534    }
535
536    queue.submit(Some(encoder.finish()));
537    let (flat, u8_raw) = read_staging_f32_and_u8(
538        device,
539        &staging_buffer,
540        cells * f32_channel_count,
541        u8_staging_len,
542    )?;
543    let u8_flat = if u8_attribute_count == 0 {
544        Vec::new()
545    } else {
546        unpack_u8_outputs_from_u32_staging(u8_raw, cells, u8_attribute_count)
547    };
548    let (out_x, out_y, out_z, attributes) =
549        split_xyz_and_attribute_blocks(flat, attribute_count, cells);
550    for buffer in upload_recycle {
551        runtime.recycle_storage(buffer.size(), buffer);
552    }
553    Ok((
554        out_x,
555        out_y,
556        out_z,
557        attributes,
558        split_u8_channel_blocks(u8_flat, u8_attribute_count, cells),
559    ))
560}
561
562pub(crate) fn record_voxel_reduce_f32_pass(
563    encoder: &mut wgpu::CommandEncoder,
564    runtime: &WgpuRuntime,
565    values: &wgpu::Buffer,
566    point_indices: &wgpu::Buffer,
567    cell_starts: &wgpu::Buffer,
568    cell_count: u32,
569    point_count: u32,
570    output_buffer: &wgpu::Buffer,
571) -> SpatialResult<()> {
572    if cell_count == 0 {
573        return Ok(());
574    }
575
576    let device = runtime.device();
577    let pipelines = runtime.pipelines();
578    let uniform = ReduceUniform { cell_count, point_count, _pad0: 0, _pad1: 0 };
579    let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
580        label: Some("voxel-reduce-uniform"),
581        contents: bytemuck::bytes_of(&uniform),
582        usage: wgpu::BufferUsages::UNIFORM,
583    });
584
585    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
586        label: Some("voxel-reduce-bind-group"),
587        layout: &pipelines.voxel_reduce.bind_group_layout,
588        entries: &[
589            wgpu::BindGroupEntry { binding: 0, resource: uniform_buffer.as_entire_binding() },
590            wgpu::BindGroupEntry { binding: 1, resource: point_indices.as_entire_binding() },
591            wgpu::BindGroupEntry { binding: 2, resource: values.as_entire_binding() },
592            wgpu::BindGroupEntry { binding: 3, resource: cell_starts.as_entire_binding() },
593            wgpu::BindGroupEntry { binding: 4, resource: output_buffer.as_entire_binding() },
594        ],
595    });
596
597    let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
598        label: Some("voxel-reduce-pass"),
599        timestamp_writes: None,
600    });
601    pass.set_pipeline(&pipelines.voxel_reduce.pipeline);
602    pass.set_bind_group(0, &bind_group, &[]);
603    pass.dispatch_workgroups(cell_count.div_ceil(WORKGROUP_SIZE), 1, 1);
604    Ok(())
605}
606
607pub(crate) fn record_voxel_reduce_u8_pass(
608    encoder: &mut wgpu::CommandEncoder,
609    runtime: &WgpuRuntime,
610    values: &wgpu::Buffer,
611    point_indices: &wgpu::Buffer,
612    cell_starts: &wgpu::Buffer,
613    cell_count: u32,
614    point_count: u32,
615    output_buffer: &wgpu::Buffer,
616) -> SpatialResult<()> {
617    if cell_count == 0 {
618        return Ok(());
619    }
620
621    let device = runtime.device();
622    let pipelines = runtime.pipelines();
623    let uniform = ReduceUniform { cell_count, point_count, _pad0: 0, _pad1: 0 };
624    let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
625        label: Some("voxel-reduce-u8-uniform"),
626        contents: bytemuck::bytes_of(&uniform),
627        usage: wgpu::BufferUsages::UNIFORM,
628    });
629
630    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
631        label: Some("voxel-reduce-u8-bind-group"),
632        layout: &pipelines.voxel_reduce.u8_bind_group_layout,
633        entries: &[
634            wgpu::BindGroupEntry { binding: 0, resource: uniform_buffer.as_entire_binding() },
635            wgpu::BindGroupEntry { binding: 1, resource: point_indices.as_entire_binding() },
636            wgpu::BindGroupEntry { binding: 2, resource: values.as_entire_binding() },
637            wgpu::BindGroupEntry { binding: 3, resource: cell_starts.as_entire_binding() },
638            wgpu::BindGroupEntry { binding: 4, resource: output_buffer.as_entire_binding() },
639        ],
640    });
641
642    let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
643        label: Some("voxel-reduce-u8-pass"),
644        timestamp_writes: None,
645    });
646    pass.set_pipeline(&pipelines.voxel_reduce.u8_pipeline);
647    pass.set_bind_group(0, &bind_group, &[]);
648    pass.dispatch_workgroups(cell_count.div_ceil(WORKGROUP_SIZE), 1, 1);
649    Ok(())
650}
651
652#[allow(clippy::too_many_arguments)]
653pub(crate) fn record_voxel_reduce_xyz_pass(
654    encoder: &mut wgpu::CommandEncoder,
655    runtime: &WgpuRuntime,
656    values_x: &wgpu::Buffer,
657    values_y: &wgpu::Buffer,
658    values_z: &wgpu::Buffer,
659    point_indices: &wgpu::Buffer,
660    cell_starts: &wgpu::Buffer,
661    cell_count: u32,
662    point_count: u32,
663    output_x: &wgpu::Buffer,
664    output_y: &wgpu::Buffer,
665    output_z: &wgpu::Buffer,
666) -> SpatialResult<()> {
667    if cell_count == 0 {
668        return Ok(());
669    }
670
671    let device = runtime.device();
672    let pipelines = runtime.pipelines();
673    let uniform = ReduceUniform { cell_count, point_count, _pad0: 0, _pad1: 0 };
674    let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
675        label: Some("voxel-reduce-xyz-uniform"),
676        contents: bytemuck::bytes_of(&uniform),
677        usage: wgpu::BufferUsages::UNIFORM,
678    });
679
680    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
681        label: Some("voxel-reduce-xyz-bind-group"),
682        layout: &pipelines.voxel_reduce.xyz_bind_group_layout,
683        entries: &[
684            wgpu::BindGroupEntry { binding: 0, resource: uniform_buffer.as_entire_binding() },
685            wgpu::BindGroupEntry { binding: 1, resource: point_indices.as_entire_binding() },
686            wgpu::BindGroupEntry { binding: 2, resource: cell_starts.as_entire_binding() },
687            wgpu::BindGroupEntry { binding: 3, resource: values_x.as_entire_binding() },
688            wgpu::BindGroupEntry { binding: 4, resource: values_y.as_entire_binding() },
689            wgpu::BindGroupEntry { binding: 5, resource: values_z.as_entire_binding() },
690            wgpu::BindGroupEntry { binding: 6, resource: output_x.as_entire_binding() },
691            wgpu::BindGroupEntry { binding: 7, resource: output_y.as_entire_binding() },
692            wgpu::BindGroupEntry { binding: 8, resource: output_z.as_entire_binding() },
693        ],
694    });
695
696    let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
697        label: Some("voxel-reduce-xyz-pass"),
698        timestamp_writes: None,
699    });
700    pass.set_pipeline(&pipelines.voxel_reduce.xyz_pipeline);
701    pass.set_bind_group(0, &bind_group, &[]);
702    pass.dispatch_workgroups(cell_count.div_ceil(WORKGROUP_SIZE), 1, 1);
703    Ok(())
704}
705
706fn dispatch_voxel_reduce_f32(
707    runtime: &WgpuRuntime,
708    values: &wgpu::Buffer,
709    point_indices: &wgpu::Buffer,
710    cell_starts: &wgpu::Buffer,
711    cell_count: u32,
712    point_count: u32,
713) -> SpatialResult<Vec<f32>> {
714    if cell_count == 0 {
715        return Ok(Vec::new());
716    }
717
718    let device = runtime.device();
719    let queue = runtime.queue();
720
721    let uniform = ReduceUniform { cell_count, point_count, _pad0: 0, _pad1: 0 };
722
723    let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
724        label: Some("voxel-reduce-uniform"),
725        contents: bytemuck::bytes_of(&uniform),
726        usage: wgpu::BufferUsages::UNIFORM,
727    });
728
729    let output_len = cell_count as usize * std::mem::size_of::<f32>();
730    let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
731        label: Some("voxel-reduce-output"),
732        size: output_len as u64,
733        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
734        mapped_at_creation: false,
735    });
736    let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
737        label: Some("voxel-reduce-staging"),
738        size: output_len as u64,
739        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
740        mapped_at_creation: false,
741    });
742
743    let pipelines = runtime.pipelines();
744
745    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
746        label: Some("voxel-reduce-bind-group"),
747        layout: &pipelines.voxel_reduce.bind_group_layout,
748        entries: &[
749            wgpu::BindGroupEntry { binding: 0, resource: uniform_buffer.as_entire_binding() },
750            wgpu::BindGroupEntry { binding: 1, resource: point_indices.as_entire_binding() },
751            wgpu::BindGroupEntry { binding: 2, resource: values.as_entire_binding() },
752            wgpu::BindGroupEntry { binding: 3, resource: cell_starts.as_entire_binding() },
753            wgpu::BindGroupEntry { binding: 4, resource: output_buffer.as_entire_binding() },
754        ],
755    });
756
757    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
758        label: Some("voxel-reduce-encoder"),
759    });
760
761    {
762        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
763            label: Some("voxel-reduce-pass"),
764            timestamp_writes: None,
765        });
766        pass.set_pipeline(&pipelines.voxel_reduce.pipeline);
767        pass.set_bind_group(0, &bind_group, &[]);
768        pass.dispatch_workgroups(cell_count.div_ceil(WORKGROUP_SIZE), 1, 1);
769    }
770
771    encoder.copy_buffer_to_buffer(&output_buffer, 0, &staging_buffer, 0, output_len as u64);
772    queue.submit(Some(encoder.finish()));
773
774    read_staging_f32(device, &staging_buffer, cell_count as usize)
775}
776
777/// Averages xyz positions within each voxel cell on the GPU.
778pub fn reduce_voxel_centroids_xyz(
779    runtime: &WgpuRuntime,
780    x: &[f32],
781    y: &[f32],
782    z: &[f32],
783    segments: &VoxelSegments,
784) -> SpatialResult<(Vec<f32>, Vec<f32>, Vec<f32>)> {
785    if x.len() != y.len() || x.len() != z.len() {
786        return Err(SpatialError::BufferLengthMismatch { expected: x.len(), found: y.len() });
787    }
788
789    let out_x = reduce_voxel_average_f32(runtime, x, segments)?;
790    let out_y = reduce_voxel_average_f32(runtime, y, segments)?;
791    let out_z = reduce_voxel_average_f32(runtime, z, segments)?;
792    Ok((out_x, out_y, out_z))
793}
794
795/// Averages xyz positions using GPU-resident buffers.
796pub fn reduce_voxel_centroids_xyz_gpu_buffers(
797    runtime: &WgpuRuntime,
798    x: &wgpu::Buffer,
799    y: &wgpu::Buffer,
800    z: &wgpu::Buffer,
801    segments: &GpuVoxelSegments,
802) -> SpatialResult<(Vec<f32>, Vec<f32>, Vec<f32>)> {
803    dispatch_voxel_reduce_xyz_f32(
804        runtime,
805        x,
806        y,
807        z,
808        segments.point_indices_buffer(),
809        segments.cell_starts_buffer(),
810        segments.cell_count(),
811        segments.point_count(),
812    )
813}
814
815fn dispatch_voxel_reduce_xyz_f32(
816    runtime: &WgpuRuntime,
817    values_x: &wgpu::Buffer,
818    values_y: &wgpu::Buffer,
819    values_z: &wgpu::Buffer,
820    point_indices: &wgpu::Buffer,
821    cell_starts: &wgpu::Buffer,
822    cell_count: u32,
823    point_count: u32,
824) -> SpatialResult<(Vec<f32>, Vec<f32>, Vec<f32>)> {
825    if cell_count == 0 {
826        return Ok((Vec::new(), Vec::new(), Vec::new()));
827    }
828
829    let device = runtime.device();
830    let queue = runtime.queue();
831    let channel_len = cell_count as usize * std::mem::size_of::<f32>();
832    let output_x = device.create_buffer(&wgpu::BufferDescriptor {
833        label: Some("voxel-reduce-xyz-out-x"),
834        size: channel_len as u64,
835        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
836        mapped_at_creation: false,
837    });
838    let output_y = device.create_buffer(&wgpu::BufferDescriptor {
839        label: Some("voxel-reduce-xyz-out-y"),
840        size: channel_len as u64,
841        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
842        mapped_at_creation: false,
843    });
844    let output_z = device.create_buffer(&wgpu::BufferDescriptor {
845        label: Some("voxel-reduce-xyz-out-z"),
846        size: channel_len as u64,
847        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
848        mapped_at_creation: false,
849    });
850    let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
851        label: Some("voxel-reduce-xyz-staging"),
852        size: (channel_len * 3) as u64,
853        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
854        mapped_at_creation: false,
855    });
856
857    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
858        label: Some("voxel-reduce-xyz-encoder"),
859    });
860    record_voxel_reduce_xyz_pass(
861        &mut encoder,
862        runtime,
863        values_x,
864        values_y,
865        values_z,
866        point_indices,
867        cell_starts,
868        cell_count,
869        point_count,
870        &output_x,
871        &output_y,
872        &output_z,
873    )?;
874    encoder.copy_buffer_to_buffer(&output_x, 0, &staging_buffer, 0, channel_len as u64);
875    encoder.copy_buffer_to_buffer(
876        &output_y,
877        0,
878        &staging_buffer,
879        channel_len as u64,
880        channel_len as u64,
881    );
882    encoder.copy_buffer_to_buffer(
883        &output_z,
884        0,
885        &staging_buffer,
886        (channel_len * 2) as u64,
887        channel_len as u64,
888    );
889    queue.submit(Some(encoder.finish()));
890
891    let flat = read_staging_f32(device, &staging_buffer, cell_count as usize * 3)?;
892    Ok(split_xyz_blocks(flat, cell_count as usize))
893}
894
#[cfg(test)]
mod tests {
    use super::{
        reduce_voxel_average_f32, reduce_voxel_average_f32_multi_gpu, reduce_voxel_centroids_xyz,
        reduce_voxel_centroids_xyz_and_average_multi_gpu,
        reduce_voxel_centroids_xyz_and_gather_first_multi_gpu,
    };
    use crate::kernels::voxel_segments::build_voxel_segments;
    use crate::runtime::WgpuRuntime;
    use spatialrust_core::SpatialError;

    /// Absolute tolerance when comparing GPU-reduced `f32` values with hand-computed references.
    const TOLERANCE: f32 = 1e-5;

    /// Asserts that `actual` is within [`TOLERANCE`] of `expected`.
    ///
    /// Unlike a bare `assert!` on the difference, a failure reports both values. A NaN `actual`
    /// always fails.
    #[track_caller]
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < TOLERANCE,
            "expected {} (within {}), got {}",
            expected,
            TOLERANCE,
            actual
        );
    }

    #[test]
    fn gpu_centroid_reduction_matches_cpu_reference() {
        let runtime = WgpuRuntime::new_headless().expect("wgpu runtime");
        let x = [0.0_f32, 0.1, 1.0, 1.1];
        let y = [0.0_f32, 0.0, 0.0, 0.0];
        let z = [0.0_f32, 0.0, 0.0, 0.0];
        let keys = vec![(0, 0, 0), (0, 0, 0), (2, 0, 0), (2, 0, 0)];
        let segments = build_voxel_segments(&keys);

        let (gpu_x, gpu_y, gpu_z) =
            reduce_voxel_centroids_xyz(&runtime, &x, &y, &z, &segments).expect("gpu reduce");

        assert_close(gpu_x[0], 0.05);
        assert_close(gpu_x[1], 1.05);
        assert_eq!(gpu_y, vec![0.0, 0.0]);
        assert_eq!(gpu_z, vec![0.0, 0.0]);

        let intensity = [0.2_f32, 0.8, 10.0, 20.0];
        let gpu_i = reduce_voxel_average_f32(&runtime, &intensity, &segments).expect("gpu average");
        assert_close(gpu_i[0], 0.5);
        assert_close(gpu_i[1], 15.0);
    }

    #[test]
    fn centroid_reduction_rejects_mismatched_channel_lengths() {
        let runtime = WgpuRuntime::new_headless().expect("wgpu runtime");
        let x = [0.0_f32, 0.1, 1.0, 1.1];
        let y = [0.0_f32; 4];
        let z_short = [0.0_f32; 3];
        let keys = vec![(0, 0, 0), (0, 0, 0), (2, 0, 0), (2, 0, 0)];
        let segments = build_voxel_segments(&keys);

        let result = reduce_voxel_centroids_xyz(&runtime, &x, &y, &z_short, &segments);
        assert!(
            matches!(result, Err(SpatialError::BufferLengthMismatch { expected: 4, .. })),
            "a short z channel must be rejected with a length-mismatch error"
        );
    }

    #[test]
    fn gpu_multi_reduce_matches_single_channel_reference() {
        use crate::kernels::voxel_sort::build_voxel_segments_gpu_from_keys;

        let runtime = WgpuRuntime::new_headless().expect("wgpu runtime");
        let intensity = [0.2_f32, 0.8, 10.0, 20.0];
        let classification = [1.0_f32, 2.0, 3.0, 4.0];
        let keys = vec![(0, 0, 0), (0, 0, 0), (2, 0, 0), (2, 0, 0)];
        let segments = build_voxel_segments_gpu_from_keys(&runtime, &keys).expect("gpu segments");

        let multi =
            reduce_voxel_average_f32_multi_gpu(&runtime, &[&intensity, &classification], &segments)
                .expect("multi reduce");

        assert_close(multi[0][0], 0.5);
        assert_close(multi[0][1], 15.0);
        assert_close(multi[1][0], 1.5);
        assert_close(multi[1][1], 3.5);
    }

    #[test]
    fn unified_xyz_and_attribute_readback_matches_staged_reference() {
        use crate::kernels::voxel_keys::compute_voxel_keys_gpu_buffers;
        use crate::kernels::voxel_sort::build_voxel_segments_gpu_from_keys_buffer;

        let runtime = WgpuRuntime::new_headless().expect("wgpu runtime");
        let x = [0.0_f32, 0.1, 1.0, 1.1];
        let y = [0.0_f32, 0.0, 0.0, 0.0];
        let z = [0.0_f32, 0.0, 0.0, 0.0];
        let intensity = [0.2_f32, 0.8, 10.0, 20.0];
        let positions =
            compute_voxel_keys_gpu_buffers(&runtime, &x, &y, &z, [0.0; 3], 2.0).expect("keys");
        let segments = build_voxel_segments_gpu_from_keys_buffer(
            &runtime,
            positions.keys_buffer(),
            positions.point_count(),
            4,
        )
        .expect("segments");

        let (out_x, out_y, out_z, attrs, _) = reduce_voxel_centroids_xyz_and_average_multi_gpu(
            &runtime,
            positions.x_buffer(),
            positions.y_buffer(),
            positions.z_buffer(),
            &[&intensity],
            &[],
            &segments,
        )
        .expect("unified reduce");

        assert_close(out_x[0], 0.05);
        assert_close(out_x[1], 1.05);
        assert_eq!(out_y, vec![0.0, 0.0]);
        assert_eq!(out_z, vec![0.0, 0.0]);
        assert_close(attrs[0][0], 0.5);
        assert_close(attrs[0][1], 15.0);

        let (_, _, _, gathered, _) = reduce_voxel_centroids_xyz_and_gather_first_multi_gpu(
            &runtime,
            positions.x_buffer(),
            positions.y_buffer(),
            positions.z_buffer(),
            &[&intensity],
            &[],
            &segments,
        )
        .expect("unified gather attrs");
        assert_close(gathered[0][0], 0.2);
        assert_close(gathered[0][1], 10.0);
    }

    #[test]
    fn unified_xyz_and_u8_attribute_readback_matches_reference() {
        use crate::kernels::voxel_keys::compute_voxel_keys_gpu_buffers;
        use crate::kernels::voxel_sort::build_voxel_segments_gpu_from_keys_buffer;

        let runtime = WgpuRuntime::new_headless().expect("wgpu runtime");
        let x = [0.0_f32, 0.1, 1.0, 1.1];
        let y = [0.0_f32, 0.0, 0.0, 0.0];
        let z = [0.0_f32, 0.0, 0.0, 0.0];
        let red = [10_u8, 20, 100, 200];
        let green = [30_u8, 40, 50, 60];
        let positions =
            compute_voxel_keys_gpu_buffers(&runtime, &x, &y, &z, [0.0; 3], 2.0).expect("keys");
        let segments = build_voxel_segments_gpu_from_keys_buffer(
            &runtime,
            positions.keys_buffer(),
            positions.point_count(),
            4,
        )
        .expect("segments");

        let (_, _, _, _, u8_attrs) = reduce_voxel_centroids_xyz_and_average_multi_gpu(
            &runtime,
            positions.x_buffer(),
            positions.y_buffer(),
            positions.z_buffer(),
            &[],
            &[&red, &green],
            &segments,
        )
        .expect("unified u8 reduce");

        assert_eq!(u8_attrs[0], vec![15, 150]);
        assert_eq!(u8_attrs[1], vec![35, 55]);

        let (_, _, _, _, gathered_u8) = reduce_voxel_centroids_xyz_and_gather_first_multi_gpu(
            &runtime,
            positions.x_buffer(),
            positions.y_buffer(),
            positions.z_buffer(),
            &[],
            &[&red, &green],
            &segments,
        )
        .expect("unified u8 gather");
        assert_eq!(gathered_u8[0], vec![10, 100]);
        assert_eq!(gathered_u8[1], vec![30, 50]);
    }
}