Skip to main content

spatialrust_gpu/kernels/
voxel_gather.rs

1use bytemuck::{Pod, Zeroable};
2use spatialrust_core::{SpatialError, SpatialResult};
3use wgpu::util::DeviceExt;
4
5use crate::kernels::gpu_segments::GpuVoxelSegments;
6use crate::kernels::voxel_segments::VoxelSegments;
7use crate::readback::{
8    pad_u8_for_gpu_storage, read_staging_f32, read_staging_f32_and_u8, split_channel_blocks,
9    split_u8_channel_blocks, split_xyz_and_attribute_blocks, split_xyz_blocks,
10    u8_output_staging_bytes, unpack_u8_outputs_from_u32_staging,
11};
12use crate::runtime::WgpuRuntime;
13
/// Workgroup width used to size compute dispatches: gather passes in this file
/// launch `cell_count.div_ceil(WORKGROUP_SIZE)` workgroups.
// NOTE(review): presumably has to match the `@workgroup_size` declared by the
// gather shaders — confirm against the WGSL sources.
const WORKGROUP_SIZE: u32 = 256;
/// Number of value/output buffer slots taken by the two-channel gather pass.
const MULTI2_CHANNELS: usize = 2;
/// Number of value/output buffer slots taken by the four-channel gather passes.
const MULTI4_CHANNELS: usize = 4;
17
/// Parameter block handed to the gather shaders.
///
/// The passes in this file serialize it with `bytemuck::bytes_of` into a
/// `UNIFORM` buffer bound at slot 0, so the `#[repr(C)]` field order is part of
/// the contract with the shader side.
#[repr(C)]
#[derive(Clone, Copy, Debug, Pod, Zeroable)]
struct GatherUniform {
    /// Number of voxel cells to produce output for.
    cell_count: u32,
    /// Number of points referenced by the segment tables.
    point_count: u32,
    /// Channel count passed by the multi-channel passes; the fused
    /// xyz+attrs4 pass writes 0 here.
    channel_count: u32,
    // Always written as 0.
    // NOTE(review): presumably pads the block to 16 bytes for uniform layout
    // rules — confirm against the WGSL struct.
    _pad: u32,
}
26
27/// Gathers the first point's `f32` value within each voxel cell on the GPU.
28pub fn gather_voxel_first_f32(
29    runtime: &WgpuRuntime,
30    values: &[f32],
31    segments: &VoxelSegments,
32) -> SpatialResult<Vec<f32>> {
33    if segments.is_empty() {
34        return Ok(Vec::new());
35    }
36    if values.is_empty() {
37        return Err(SpatialError::InvalidArgument(
38            "cannot gather from empty value buffer".to_owned(),
39        ));
40    }
41
42    let device = runtime.device();
43    let values_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
44        label: Some("voxel-gather-values"),
45        contents: bytemuck::cast_slice(values),
46        usage: wgpu::BufferUsages::STORAGE,
47    });
48    let indices_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
49        label: Some("voxel-gather-indices"),
50        contents: bytemuck::cast_slice(&segments.point_indices),
51        usage: wgpu::BufferUsages::STORAGE,
52    });
53    let starts_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
54        label: Some("voxel-gather-starts"),
55        contents: bytemuck::cast_slice(&segments.cell_starts),
56        usage: wgpu::BufferUsages::STORAGE,
57    });
58
59    dispatch_voxel_gather_f32(
60        runtime,
61        &values_buffer,
62        &indices_buffer,
63        &starts_buffer,
64        segments.len() as u32,
65        segments.point_indices.len() as u32,
66    )
67}
68
69/// Gathers the first point's `f32` value using GPU-resident segment buffers.
70pub fn gather_voxel_first_f32_gpu_buffers(
71    runtime: &WgpuRuntime,
72    values: &wgpu::Buffer,
73    segments: &GpuVoxelSegments,
74) -> SpatialResult<Vec<f32>> {
75    dispatch_voxel_gather_f32(
76        runtime,
77        values,
78        segments.point_indices_buffer(),
79        segments.cell_starts_buffer(),
80        segments.cell_count(),
81        segments.point_count(),
82    )
83}
84
85/// Uploads `f32` values and gathers the first point per GPU-resident voxel segment.
86pub fn gather_voxel_first_f32_gpu(
87    runtime: &WgpuRuntime,
88    values: &[f32],
89    segments: &GpuVoxelSegments,
90) -> SpatialResult<Vec<f32>> {
91    if segments.cell_count() == 0 {
92        return Ok(Vec::new());
93    }
94    if values.len() != segments.point_count() as usize {
95        return Err(SpatialError::BufferLengthMismatch {
96            expected: segments.point_count() as usize,
97            found: values.len(),
98        });
99    }
100
101    let device = runtime.device();
102    let values_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
103        label: Some("voxel-gather-values-upload"),
104        contents: bytemuck::cast_slice(values),
105        usage: wgpu::BufferUsages::STORAGE,
106    });
107    gather_voxel_first_f32_gpu_buffers(runtime, &values_buffer, segments)
108}
109
110/// Gathers multiple `f32` channels in one or more GPU dispatches.
111pub fn gather_voxel_first_f32_multi_gpu(
112    runtime: &WgpuRuntime,
113    channels: &[&[f32]],
114    segments: &GpuVoxelSegments,
115) -> SpatialResult<Vec<Vec<f32>>> {
116    if channels.is_empty() {
117        return Ok(Vec::new());
118    }
119    if segments.cell_count() == 0 {
120        return Ok(vec![Vec::new(); channels.len()]);
121    }
122
123    let point_count = segments.point_count() as usize;
124    for channel in channels {
125        if channel.len() != point_count {
126            return Err(SpatialError::BufferLengthMismatch {
127                expected: point_count,
128                found: channel.len(),
129            });
130        }
131    }
132
133    let max_channels = runtime.max_gather_channels() as usize;
134    let device = runtime.device();
135    let empty = empty_storage_buffer(device)?;
136    let mut gathered = Vec::with_capacity(channels.len());
137
138    for chunk in channels.chunks(max_channels) {
139        let mut value_buffers = [None, None, None, None];
140        for (index, channel) in chunk.iter().enumerate() {
141            value_buffers[index] =
142                Some(device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
143                    label: Some("voxel-gather-multi-values"),
144                    contents: bytemuck::cast_slice(channel),
145                    usage: wgpu::BufferUsages::STORAGE,
146                }));
147        }
148
149        let value_refs: [&wgpu::Buffer; MULTI4_CHANNELS] = std::array::from_fn(|index| {
150            if index < chunk.len() {
151                value_buffers[index].as_ref().expect("value buffer")
152            } else {
153                &empty
154            }
155        });
156
157        gathered.extend(dispatch_voxel_gather_multi_gpu_buffers(
158            runtime,
159            &value_refs,
160            segments,
161            chunk.len() as u32,
162        )?);
163    }
164
165    Ok(gathered)
166}
167
168/// Gathers xyz and multiple f32/u8 attribute channels with one GPU submit/readback.
169pub fn gather_voxel_first_xyz_and_multi_gpu(
170    runtime: &WgpuRuntime,
171    x: &wgpu::Buffer,
172    y: &wgpu::Buffer,
173    z: &wgpu::Buffer,
174    attribute_channels: &[&[f32]],
175    u8_attribute_channels: &[&[u8]],
176    segments: &GpuVoxelSegments,
177) -> SpatialResult<(Vec<f32>, Vec<f32>, Vec<f32>, Vec<Vec<f32>>, Vec<Vec<u8>>)> {
178    let attribute_count = attribute_channels.len();
179    let u8_attribute_count = u8_attribute_channels.len();
180    if segments.cell_count() == 0 {
181        return Ok((
182            Vec::new(),
183            Vec::new(),
184            Vec::new(),
185            vec![Vec::new(); attribute_count],
186            vec![Vec::new(); u8_attribute_count],
187        ));
188    }
189
190    let point_count = segments.point_count() as usize;
191    for channel in attribute_channels {
192        if channel.len() != point_count {
193            return Err(SpatialError::BufferLengthMismatch {
194                expected: point_count,
195                found: channel.len(),
196            });
197        }
198    }
199    for channel in u8_attribute_channels {
200        if channel.len() != point_count {
201            return Err(SpatialError::BufferLengthMismatch {
202                expected: point_count,
203                found: channel.len(),
204            });
205        }
206    }
207
208    let device = runtime.device();
209    let queue = runtime.queue();
210    let cell_count = segments.cell_count();
211    let cells = cell_count as usize;
212    let channel_len = cells * std::mem::size_of::<f32>();
213    let f32_channel_count = 3 + attribute_count;
214    let u8_staging_len = u8_output_staging_bytes(cells, u8_attribute_count);
215    let staging_size = channel_len * f32_channel_count + u8_staging_len;
216    let fused_xyz_attrs4 = attribute_count == 4
217        && u8_attribute_count == 0
218        && runtime.pipelines().voxel_gather.xyz_attrs4_pipeline.is_some();
219
220    let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
221        label: Some("voxel-gather-xyz-attrs-staging"),
222        size: staging_size as u64,
223        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
224        mapped_at_creation: false,
225    });
226
227    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
228        label: Some("voxel-gather-xyz-attrs-encoder"),
229    });
230    let mut upload_recycle = Vec::new();
231
232    if fused_xyz_attrs4 {
233        for channel in attribute_channels {
234            upload_recycle
235                .push(runtime.upload_f32_storage("voxel-gather-xyz-attrs4-values", channel)?);
236        }
237        let attr_refs: [&wgpu::Buffer; MULTI4_CHANNELS] =
238            std::array::from_fn(|index| &upload_recycle[index]);
239        let packed_output = device.create_buffer(&wgpu::BufferDescriptor {
240            label: Some("voxel-gather-xyz-attrs4-packed-output"),
241            size: (channel_len * f32_channel_count) as u64,
242            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
243            mapped_at_creation: false,
244        });
245        record_voxel_gather_xyz_and_attrs4_packed_pass(
246            &mut encoder,
247            runtime,
248            x,
249            y,
250            z,
251            &attr_refs,
252            segments,
253            &packed_output,
254        )?;
255        encoder.copy_buffer_to_buffer(
256            &packed_output,
257            0,
258            &staging_buffer,
259            0,
260            (channel_len * f32_channel_count) as u64,
261        );
262    } else {
263        let output_x = device.create_buffer(&wgpu::BufferDescriptor {
264            label: Some("voxel-gather-xyz-out-x"),
265            size: channel_len as u64,
266            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
267            mapped_at_creation: false,
268        });
269        let output_y = device.create_buffer(&wgpu::BufferDescriptor {
270            label: Some("voxel-gather-xyz-out-y"),
271            size: channel_len as u64,
272            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
273            mapped_at_creation: false,
274        });
275        let output_z = device.create_buffer(&wgpu::BufferDescriptor {
276            label: Some("voxel-gather-xyz-out-z"),
277            size: channel_len as u64,
278            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
279            mapped_at_creation: false,
280        });
281
282        record_voxel_gather_xyz_pass(
283            &mut encoder,
284            runtime,
285            x,
286            y,
287            z,
288            segments,
289            &output_x,
290            &output_y,
291            &output_z,
292        )?;
293
294        encoder.copy_buffer_to_buffer(&output_x, 0, &staging_buffer, 0, channel_len as u64);
295        encoder.copy_buffer_to_buffer(
296            &output_y,
297            0,
298            &staging_buffer,
299            channel_len as u64,
300            channel_len as u64,
301        );
302        encoder.copy_buffer_to_buffer(
303            &output_z,
304            0,
305            &staging_buffer,
306            (channel_len * 2) as u64,
307            channel_len as u64,
308        );
309
310        record_gather_f32_attribute_channels_to_staging(
311            &mut encoder,
312            runtime,
313            attribute_channels,
314            segments,
315            &staging_buffer,
316            channel_len as u64,
317            &mut upload_recycle,
318        )?;
319    }
320
321    let u8_region_offset = (channel_len * f32_channel_count) as u64;
322    for (attribute_index, channel) in u8_attribute_channels.iter().enumerate() {
323        let values_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
324            label: Some("voxel-gather-xyz-u8-values"),
325            contents: &pad_u8_for_gpu_storage(channel),
326            usage: wgpu::BufferUsages::STORAGE,
327        });
328        let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
329            label: Some("voxel-gather-xyz-u8-output"),
330            size: (cells * std::mem::size_of::<u32>()) as u64,
331            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
332            mapped_at_creation: false,
333        });
334
335        record_voxel_gather_u8_pass(
336            &mut encoder,
337            runtime,
338            &values_buffer,
339            segments.point_indices_buffer(),
340            segments.cell_starts_buffer(),
341            cell_count,
342            segments.point_count(),
343            &output_buffer,
344        )?;
345
346        encoder.copy_buffer_to_buffer(
347            &output_buffer,
348            0,
349            &staging_buffer,
350            u8_region_offset + attribute_index as u64 * (cells * std::mem::size_of::<u32>()) as u64,
351            (cells * std::mem::size_of::<u32>()) as u64,
352        );
353    }
354
355    queue.submit(Some(encoder.finish()));
356
357    let (flat, u8_raw) = read_staging_f32_and_u8(
358        device,
359        &staging_buffer,
360        cells * f32_channel_count,
361        u8_staging_len,
362    )?;
363    let u8_flat = if u8_attribute_count == 0 {
364        Vec::new()
365    } else {
366        unpack_u8_outputs_from_u32_staging(u8_raw, cells, u8_attribute_count)
367    };
368    let (out_x, out_y, out_z, attributes) =
369        split_xyz_and_attribute_blocks(flat, attribute_count, cells);
370    for buffer in upload_recycle {
371        runtime.recycle_storage(buffer.size(), buffer);
372    }
373    Ok((
374        out_x,
375        out_y,
376        out_z,
377        attributes,
378        split_u8_channel_blocks(u8_flat, u8_attribute_count, cells),
379    ))
380}
381
382/// Gathers xyz and averages f32/u8 attribute channels with one GPU submit/readback.
383pub fn gather_voxel_first_xyz_and_average_multi_gpu(
384    runtime: &WgpuRuntime,
385    x: &wgpu::Buffer,
386    y: &wgpu::Buffer,
387    z: &wgpu::Buffer,
388    attribute_channels: &[&[f32]],
389    u8_attribute_channels: &[&[u8]],
390    segments: &GpuVoxelSegments,
391) -> SpatialResult<(Vec<f32>, Vec<f32>, Vec<f32>, Vec<Vec<f32>>, Vec<Vec<u8>>)> {
392    use crate::kernels::voxel_reduce::record_voxel_reduce_f32_pass;
393    use crate::kernels::voxel_reduce::record_voxel_reduce_u8_pass;
394
395    let attribute_count = attribute_channels.len();
396    let u8_attribute_count = u8_attribute_channels.len();
397    if segments.cell_count() == 0 {
398        return Ok((
399            Vec::new(),
400            Vec::new(),
401            Vec::new(),
402            vec![Vec::new(); attribute_count],
403            vec![Vec::new(); u8_attribute_count],
404        ));
405    }
406
407    let point_count = segments.point_count() as usize;
408    for channel in attribute_channels {
409        if channel.len() != point_count {
410            return Err(SpatialError::BufferLengthMismatch {
411                expected: point_count,
412                found: channel.len(),
413            });
414        }
415    }
416    for channel in u8_attribute_channels {
417        if channel.len() != point_count {
418            return Err(SpatialError::BufferLengthMismatch {
419                expected: point_count,
420                found: channel.len(),
421            });
422        }
423    }
424
425    let device = runtime.device();
426    let queue = runtime.queue();
427    let cell_count = segments.cell_count();
428    let cells = cell_count as usize;
429    let channel_len = cells * std::mem::size_of::<f32>();
430    let f32_channel_count = 3 + attribute_count;
431    let u8_staging_len = u8_output_staging_bytes(cells, u8_attribute_count);
432    let staging_size = channel_len * f32_channel_count + u8_staging_len;
433    let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
434        label: Some("voxel-gather-xyz-reduce-attrs-staging"),
435        size: staging_size as u64,
436        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
437        mapped_at_creation: false,
438    });
439    let output_x = device.create_buffer(&wgpu::BufferDescriptor {
440        label: Some("voxel-gather-xyz-out-x"),
441        size: channel_len as u64,
442        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
443        mapped_at_creation: false,
444    });
445    let output_y = device.create_buffer(&wgpu::BufferDescriptor {
446        label: Some("voxel-gather-xyz-out-y"),
447        size: channel_len as u64,
448        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
449        mapped_at_creation: false,
450    });
451    let output_z = device.create_buffer(&wgpu::BufferDescriptor {
452        label: Some("voxel-gather-xyz-out-z"),
453        size: channel_len as u64,
454        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
455        mapped_at_creation: false,
456    });
457
458    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
459        label: Some("voxel-gather-xyz-reduce-attrs-encoder"),
460    });
461    record_voxel_gather_xyz_pass(
462        &mut encoder,
463        runtime,
464        x,
465        y,
466        z,
467        segments,
468        &output_x,
469        &output_y,
470        &output_z,
471    )?;
472    encoder.copy_buffer_to_buffer(&output_x, 0, &staging_buffer, 0, channel_len as u64);
473    encoder.copy_buffer_to_buffer(
474        &output_y,
475        0,
476        &staging_buffer,
477        channel_len as u64,
478        channel_len as u64,
479    );
480    encoder.copy_buffer_to_buffer(
481        &output_z,
482        0,
483        &staging_buffer,
484        (channel_len * 2) as u64,
485        channel_len as u64,
486    );
487
488    for (attribute_index, channel) in attribute_channels.iter().enumerate() {
489        let values_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
490            label: Some("voxel-reduce-xyz-attrs-values"),
491            contents: bytemuck::cast_slice(channel),
492            usage: wgpu::BufferUsages::STORAGE,
493        });
494        let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
495            label: Some("voxel-reduce-xyz-attrs-output"),
496            size: channel_len as u64,
497            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
498            mapped_at_creation: false,
499        });
500        record_voxel_reduce_f32_pass(
501            &mut encoder,
502            runtime,
503            &values_buffer,
504            segments.point_indices_buffer(),
505            segments.cell_starts_buffer(),
506            cell_count,
507            segments.point_count(),
508            &output_buffer,
509        )?;
510        encoder.copy_buffer_to_buffer(
511            &output_buffer,
512            0,
513            &staging_buffer,
514            (channel_len * (3 + attribute_index)) as u64,
515            channel_len as u64,
516        );
517    }
518
519    let u8_region_offset = (channel_len * f32_channel_count) as u64;
520    for (attribute_index, channel) in u8_attribute_channels.iter().enumerate() {
521        let values_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
522            label: Some("voxel-reduce-xyz-u8-values"),
523            contents: &pad_u8_for_gpu_storage(channel),
524            usage: wgpu::BufferUsages::STORAGE,
525        });
526        let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
527            label: Some("voxel-reduce-xyz-u8-output"),
528            size: (cells * std::mem::size_of::<u32>()) as u64,
529            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
530            mapped_at_creation: false,
531        });
532        record_voxel_reduce_u8_pass(
533            &mut encoder,
534            runtime,
535            &values_buffer,
536            segments.point_indices_buffer(),
537            segments.cell_starts_buffer(),
538            cell_count,
539            segments.point_count(),
540            &output_buffer,
541        )?;
542        encoder.copy_buffer_to_buffer(
543            &output_buffer,
544            0,
545            &staging_buffer,
546            u8_region_offset + attribute_index as u64 * (cells * std::mem::size_of::<u32>()) as u64,
547            (cells * std::mem::size_of::<u32>()) as u64,
548        );
549    }
550
551    queue.submit(Some(encoder.finish()));
552    let (flat, u8_raw) = read_staging_f32_and_u8(
553        device,
554        &staging_buffer,
555        cells * f32_channel_count,
556        u8_staging_len,
557    )?;
558    let u8_flat = if u8_attribute_count == 0 {
559        Vec::new()
560    } else {
561        unpack_u8_outputs_from_u32_staging(u8_raw, cells, u8_attribute_count)
562    };
563    let (out_x, out_y, out_z, attributes) =
564        split_xyz_and_attribute_blocks(flat, attribute_count, cells);
565    Ok((
566        out_x,
567        out_y,
568        out_z,
569        attributes,
570        split_u8_channel_blocks(u8_flat, u8_attribute_count, cells),
571    ))
572}
573
574/// Gathers xyz coordinates of the first point within each voxel cell on the GPU.
575pub fn gather_voxel_first_xyz_gpu_buffers(
576    runtime: &WgpuRuntime,
577    x: &wgpu::Buffer,
578    y: &wgpu::Buffer,
579    z: &wgpu::Buffer,
580    segments: &GpuVoxelSegments,
581) -> SpatialResult<(Vec<f32>, Vec<f32>, Vec<f32>)> {
582    if segments.cell_count() == 0 {
583        return Ok((Vec::new(), Vec::new(), Vec::new()));
584    }
585
586    let device = runtime.device();
587    let queue = runtime.queue();
588    let cell_count = segments.cell_count();
589    let channel_len = cell_count as usize * std::mem::size_of::<f32>();
590    let output_x = device.create_buffer(&wgpu::BufferDescriptor {
591        label: Some("voxel-gather-xyz-out-x"),
592        size: channel_len as u64,
593        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
594        mapped_at_creation: false,
595    });
596    let output_y = device.create_buffer(&wgpu::BufferDescriptor {
597        label: Some("voxel-gather-xyz-out-y"),
598        size: channel_len as u64,
599        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
600        mapped_at_creation: false,
601    });
602    let output_z = device.create_buffer(&wgpu::BufferDescriptor {
603        label: Some("voxel-gather-xyz-out-z"),
604        size: channel_len as u64,
605        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
606        mapped_at_creation: false,
607    });
608    let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
609        label: Some("voxel-gather-xyz-staging"),
610        size: (channel_len * 3) as u64,
611        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
612        mapped_at_creation: false,
613    });
614
615    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
616        label: Some("voxel-gather-xyz-encoder"),
617    });
618    record_voxel_gather_xyz_pass(
619        &mut encoder,
620        runtime,
621        x,
622        y,
623        z,
624        segments,
625        &output_x,
626        &output_y,
627        &output_z,
628    )?;
629    encoder.copy_buffer_to_buffer(&output_x, 0, &staging_buffer, 0, channel_len as u64);
630    encoder.copy_buffer_to_buffer(
631        &output_y,
632        0,
633        &staging_buffer,
634        channel_len as u64,
635        channel_len as u64,
636    );
637    encoder.copy_buffer_to_buffer(
638        &output_z,
639        0,
640        &staging_buffer,
641        (channel_len * 2) as u64,
642        channel_len as u64,
643    );
644    queue.submit(Some(encoder.finish()));
645
646    let flat = read_staging_f32(device, &staging_buffer, cell_count as usize * 3)?;
647    Ok(split_xyz_blocks(flat, cell_count as usize))
648}
649
650pub(crate) fn record_voxel_gather_xyz_and_attrs4_packed_pass(
651    encoder: &mut wgpu::CommandEncoder,
652    runtime: &WgpuRuntime,
653    x: &wgpu::Buffer,
654    y: &wgpu::Buffer,
655    z: &wgpu::Buffer,
656    attribute_buffers: &[&wgpu::Buffer; MULTI4_CHANNELS],
657    segments: &GpuVoxelSegments,
658    packed_output: &wgpu::Buffer,
659) -> SpatialResult<()> {
660    if segments.cell_count() == 0 {
661        return Ok(());
662    }
663
664    let device = runtime.device();
665    let pipelines = runtime.pipelines();
666    let xyz_attrs4_pipeline =
667        pipelines.voxel_gather.xyz_attrs4_pipeline.as_ref().ok_or_else(|| {
668            SpatialError::InvalidArgument(
669                "fused xyz+4-attribute gather pipeline is unavailable on this gpu adapter"
670                    .to_owned(),
671            )
672        })?;
673    let xyz_attrs4_layout =
674        pipelines.voxel_gather.xyz_attrs4_bind_group_layout.as_ref().expect("xyz attrs4 layout");
675    let cell_count = segments.cell_count();
676    let uniform = GatherUniform {
677        cell_count,
678        point_count: segments.point_count(),
679        channel_count: 0,
680        _pad: 0,
681    };
682    let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
683        label: Some("voxel-gather-xyz-attrs4-uniform"),
684        contents: bytemuck::bytes_of(&uniform),
685        usage: wgpu::BufferUsages::UNIFORM,
686    });
687    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
688        label: Some("voxel-gather-xyz-attrs4-bind-group"),
689        layout: xyz_attrs4_layout,
690        entries: &[
691            wgpu::BindGroupEntry { binding: 0, resource: uniform_buffer.as_entire_binding() },
692            wgpu::BindGroupEntry {
693                binding: 1,
694                resource: segments.point_indices_buffer().as_entire_binding(),
695            },
696            wgpu::BindGroupEntry {
697                binding: 2,
698                resource: segments.cell_starts_buffer().as_entire_binding(),
699            },
700            wgpu::BindGroupEntry { binding: 3, resource: x.as_entire_binding() },
701            wgpu::BindGroupEntry { binding: 4, resource: y.as_entire_binding() },
702            wgpu::BindGroupEntry { binding: 5, resource: z.as_entire_binding() },
703            wgpu::BindGroupEntry { binding: 6, resource: attribute_buffers[0].as_entire_binding() },
704            wgpu::BindGroupEntry { binding: 7, resource: attribute_buffers[1].as_entire_binding() },
705            wgpu::BindGroupEntry { binding: 8, resource: attribute_buffers[2].as_entire_binding() },
706            wgpu::BindGroupEntry { binding: 9, resource: attribute_buffers[3].as_entire_binding() },
707            wgpu::BindGroupEntry { binding: 10, resource: packed_output.as_entire_binding() },
708        ],
709    });
710    let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
711        label: Some("voxel-gather-xyz-attrs4-pass"),
712        timestamp_writes: None,
713    });
714    pass.set_pipeline(xyz_attrs4_pipeline);
715    pass.set_bind_group(0, &bind_group, &[]);
716    pass.dispatch_workgroups(cell_count.div_ceil(WORKGROUP_SIZE), 1, 1);
717    Ok(())
718}
719
720/// Records batched first-point f32 attribute gathers into a unified xyz staging buffer.
721pub(crate) fn record_gather_f32_attribute_channels_to_staging(
722    encoder: &mut wgpu::CommandEncoder,
723    runtime: &WgpuRuntime,
724    attribute_channels: &[&[f32]],
725    segments: &GpuVoxelSegments,
726    staging_buffer: &wgpu::Buffer,
727    channel_len: u64,
728    upload_recycle: &mut Vec<wgpu::Buffer>,
729) -> SpatialResult<()> {
730    if attribute_channels.is_empty() {
731        return Ok(());
732    }
733
734    let device = runtime.device();
735    let max_channels = runtime.max_gather_channels().max(1) as usize;
736
737    for chunk_start in (0..attribute_channels.len()).step_by(max_channels) {
738        let chunk_end = (chunk_start + max_channels).min(attribute_channels.len());
739        let chunk = &attribute_channels[chunk_start..chunk_end];
740        let channels_in_chunk = chunk.len();
741
742        let chunk_upload_start = upload_recycle.len();
743        for channel in chunk {
744            upload_recycle
745                .push(runtime.upload_f32_storage("voxel-gather-xyz-attrs-values", channel)?);
746        }
747        let value_buffers = &upload_recycle[chunk_upload_start..];
748
749        let output_buffers: Vec<wgpu::Buffer> = (0..channels_in_chunk)
750            .map(|_| {
751                device.create_buffer(&wgpu::BufferDescriptor {
752                    label: Some("voxel-gather-xyz-attrs-output"),
753                    size: channel_len,
754                    usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
755                    mapped_at_creation: false,
756                })
757            })
758            .collect();
759
760        match channels_in_chunk {
761            1 => record_voxel_gather_f32_pass(
762                encoder,
763                runtime,
764                &value_buffers[0],
765                segments.point_indices_buffer(),
766                segments.cell_starts_buffer(),
767                segments.cell_count(),
768                segments.point_count(),
769                &output_buffers[0],
770            )?,
771            2 => {
772                let values: [&wgpu::Buffer; MULTI2_CHANNELS] =
773                    [&value_buffers[0], &value_buffers[1]];
774                let outputs: [&wgpu::Buffer; MULTI2_CHANNELS] =
775                    [&output_buffers[0], &output_buffers[1]];
776                record_voxel_gather_multi2_pass(encoder, runtime, &values, segments, &outputs, 2)?;
777            }
778            channel_count => {
779                if runtime.pipelines().voxel_gather.multi4_pipeline.is_some() {
780                    let empty = empty_storage_buffer(device)?;
781                    let dummy_outputs: [wgpu::Buffer; MULTI4_CHANNELS] =
782                        std::array::from_fn(|_| {
783                            device.create_buffer(&wgpu::BufferDescriptor {
784                                label: Some("voxel-gather-xyz-attrs-dummy-output"),
785                                size: channel_len,
786                                usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
787                                mapped_at_creation: false,
788                            })
789                        });
790                    let values: [&wgpu::Buffer; MULTI4_CHANNELS] = std::array::from_fn(|index| {
791                        if index < channel_count {
792                            &value_buffers[index]
793                        } else {
794                            &empty
795                        }
796                    });
797                    let outputs: [&wgpu::Buffer; MULTI4_CHANNELS] = std::array::from_fn(|index| {
798                        if index < channel_count {
799                            &output_buffers[index]
800                        } else {
801                            &dummy_outputs[index]
802                        }
803                    });
804                    record_voxel_gather_multi4_pass(
805                        encoder,
806                        runtime,
807                        &values,
808                        segments,
809                        &outputs,
810                        channel_count as u32,
811                    )?;
812                } else {
813                    for (local_index, channel) in chunk.iter().enumerate() {
814                        let chunk_upload_start = upload_recycle.len();
815                        upload_recycle.push(
816                            runtime.upload_f32_storage("voxel-gather-xyz-attrs-values", channel)?,
817                        );
818                        record_voxel_gather_f32_pass(
819                            encoder,
820                            runtime,
821                            &upload_recycle[chunk_upload_start],
822                            segments.point_indices_buffer(),
823                            segments.cell_starts_buffer(),
824                            segments.cell_count(),
825                            segments.point_count(),
826                            &output_buffers[local_index],
827                        )?;
828                    }
829                }
830            }
831        }
832
833        for (local_index, output_buffer) in output_buffers.iter().enumerate() {
834            encoder.copy_buffer_to_buffer(
835                output_buffer,
836                0,
837                staging_buffer,
838                channel_len * (3 + chunk_start + local_index) as u64,
839                channel_len,
840            );
841        }
842    }
843
844    Ok(())
845}
846
847pub(crate) fn record_voxel_gather_multi2_pass(
848    encoder: &mut wgpu::CommandEncoder,
849    runtime: &WgpuRuntime,
850    values: &[&wgpu::Buffer; MULTI2_CHANNELS],
851    segments: &GpuVoxelSegments,
852    outputs: &[&wgpu::Buffer; MULTI2_CHANNELS],
853    channel_count: u32,
854) -> SpatialResult<()> {
855    let device = runtime.device();
856    let pipelines = runtime.pipelines();
857    let cell_count = segments.cell_count();
858    let uniform =
859        GatherUniform { cell_count, point_count: segments.point_count(), channel_count, _pad: 0 };
860    let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
861        label: Some("voxel-gather-multi-uniform"),
862        contents: bytemuck::bytes_of(&uniform),
863        usage: wgpu::BufferUsages::UNIFORM,
864    });
865    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
866        label: Some("voxel-gather-multi-bind-group"),
867        layout: &pipelines.voxel_gather.multi_bind_group_layout,
868        entries: &[
869            wgpu::BindGroupEntry { binding: 0, resource: uniform_buffer.as_entire_binding() },
870            wgpu::BindGroupEntry {
871                binding: 1,
872                resource: segments.point_indices_buffer().as_entire_binding(),
873            },
874            wgpu::BindGroupEntry {
875                binding: 2,
876                resource: segments.cell_starts_buffer().as_entire_binding(),
877            },
878            wgpu::BindGroupEntry { binding: 3, resource: values[0].as_entire_binding() },
879            wgpu::BindGroupEntry { binding: 4, resource: values[1].as_entire_binding() },
880            wgpu::BindGroupEntry { binding: 5, resource: outputs[0].as_entire_binding() },
881            wgpu::BindGroupEntry { binding: 6, resource: outputs[1].as_entire_binding() },
882        ],
883    });
884    let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
885        label: Some("voxel-gather-multi-pass"),
886        timestamp_writes: None,
887    });
888    pass.set_pipeline(&pipelines.voxel_gather.multi_pipeline);
889    pass.set_bind_group(0, &bind_group, &[]);
890    pass.dispatch_workgroups(cell_count.div_ceil(WORKGROUP_SIZE), 1, 1);
891    Ok(())
892}
893
894pub(crate) fn record_voxel_gather_multi4_pass(
895    encoder: &mut wgpu::CommandEncoder,
896    runtime: &WgpuRuntime,
897    values: &[&wgpu::Buffer; MULTI4_CHANNELS],
898    segments: &GpuVoxelSegments,
899    outputs: &[&wgpu::Buffer; MULTI4_CHANNELS],
900    channel_count: u32,
901) -> SpatialResult<()> {
902    let device = runtime.device();
903    let pipelines = runtime.pipelines();
904    let multi4_pipeline = pipelines.voxel_gather.multi4_pipeline.as_ref().ok_or_else(|| {
905        SpatialError::InvalidArgument(
906            "4-channel gather pipeline is unavailable on this gpu adapter".to_owned(),
907        )
908    })?;
909    let multi4_layout =
910        pipelines.voxel_gather.multi4_bind_group_layout.as_ref().expect("multi4 layout");
911    let cell_count = segments.cell_count();
912    let uniform =
913        GatherUniform { cell_count, point_count: segments.point_count(), channel_count, _pad: 0 };
914    let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
915        label: Some("voxel-gather-multi4-uniform"),
916        contents: bytemuck::bytes_of(&uniform),
917        usage: wgpu::BufferUsages::UNIFORM,
918    });
919    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
920        label: Some("voxel-gather-multi4-bind-group"),
921        layout: multi4_layout,
922        entries: &[
923            wgpu::BindGroupEntry { binding: 0, resource: uniform_buffer.as_entire_binding() },
924            wgpu::BindGroupEntry {
925                binding: 1,
926                resource: segments.point_indices_buffer().as_entire_binding(),
927            },
928            wgpu::BindGroupEntry {
929                binding: 2,
930                resource: segments.cell_starts_buffer().as_entire_binding(),
931            },
932            wgpu::BindGroupEntry { binding: 3, resource: values[0].as_entire_binding() },
933            wgpu::BindGroupEntry { binding: 4, resource: values[1].as_entire_binding() },
934            wgpu::BindGroupEntry { binding: 5, resource: values[2].as_entire_binding() },
935            wgpu::BindGroupEntry { binding: 6, resource: values[3].as_entire_binding() },
936            wgpu::BindGroupEntry { binding: 7, resource: outputs[0].as_entire_binding() },
937            wgpu::BindGroupEntry { binding: 8, resource: outputs[1].as_entire_binding() },
938            wgpu::BindGroupEntry { binding: 9, resource: outputs[2].as_entire_binding() },
939            wgpu::BindGroupEntry { binding: 10, resource: outputs[3].as_entire_binding() },
940        ],
941    });
942    let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
943        label: Some("voxel-gather-multi4-pass"),
944        timestamp_writes: None,
945    });
946    pass.set_pipeline(multi4_pipeline);
947    pass.set_bind_group(0, &bind_group, &[]);
948    pass.dispatch_workgroups(cell_count.div_ceil(WORKGROUP_SIZE), 1, 1);
949    Ok(())
950}
951
952pub(crate) fn record_voxel_gather_f32_pass(
953    encoder: &mut wgpu::CommandEncoder,
954    runtime: &WgpuRuntime,
955    values: &wgpu::Buffer,
956    point_indices: &wgpu::Buffer,
957    cell_starts: &wgpu::Buffer,
958    cell_count: u32,
959    point_count: u32,
960    output_buffer: &wgpu::Buffer,
961) -> SpatialResult<()> {
962    if cell_count == 0 {
963        return Ok(());
964    }
965
966    let device = runtime.device();
967    let pipelines = runtime.pipelines();
968    let uniform = GatherUniform { cell_count, point_count, channel_count: 1, _pad: 0 };
969    let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
970        label: Some("voxel-gather-uniform"),
971        contents: bytemuck::bytes_of(&uniform),
972        usage: wgpu::BufferUsages::UNIFORM,
973    });
974
975    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
976        label: Some("voxel-gather-bind-group"),
977        layout: &pipelines.voxel_gather.bind_group_layout,
978        entries: &[
979            wgpu::BindGroupEntry { binding: 0, resource: uniform_buffer.as_entire_binding() },
980            wgpu::BindGroupEntry { binding: 1, resource: point_indices.as_entire_binding() },
981            wgpu::BindGroupEntry { binding: 2, resource: values.as_entire_binding() },
982            wgpu::BindGroupEntry { binding: 3, resource: cell_starts.as_entire_binding() },
983            wgpu::BindGroupEntry { binding: 4, resource: output_buffer.as_entire_binding() },
984        ],
985    });
986
987    let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
988        label: Some("voxel-gather-pass"),
989        timestamp_writes: None,
990    });
991    pass.set_pipeline(&pipelines.voxel_gather.pipeline);
992    pass.set_bind_group(0, &bind_group, &[]);
993    pass.dispatch_workgroups(cell_count.div_ceil(WORKGROUP_SIZE), 1, 1);
994    Ok(())
995}
996
997pub(crate) fn record_voxel_gather_u8_pass(
998    encoder: &mut wgpu::CommandEncoder,
999    runtime: &WgpuRuntime,
1000    values: &wgpu::Buffer,
1001    point_indices: &wgpu::Buffer,
1002    cell_starts: &wgpu::Buffer,
1003    cell_count: u32,
1004    point_count: u32,
1005    output_buffer: &wgpu::Buffer,
1006) -> SpatialResult<()> {
1007    if cell_count == 0 {
1008        return Ok(());
1009    }
1010
1011    let device = runtime.device();
1012    let pipelines = runtime.pipelines();
1013    let uniform = GatherUniform { cell_count, point_count, channel_count: 1, _pad: 0 };
1014    let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
1015        label: Some("voxel-gather-u8-uniform"),
1016        contents: bytemuck::bytes_of(&uniform),
1017        usage: wgpu::BufferUsages::UNIFORM,
1018    });
1019
1020    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
1021        label: Some("voxel-gather-u8-bind-group"),
1022        layout: &pipelines.voxel_gather.u8_bind_group_layout,
1023        entries: &[
1024            wgpu::BindGroupEntry { binding: 0, resource: uniform_buffer.as_entire_binding() },
1025            wgpu::BindGroupEntry { binding: 1, resource: point_indices.as_entire_binding() },
1026            wgpu::BindGroupEntry { binding: 2, resource: values.as_entire_binding() },
1027            wgpu::BindGroupEntry { binding: 3, resource: cell_starts.as_entire_binding() },
1028            wgpu::BindGroupEntry { binding: 4, resource: output_buffer.as_entire_binding() },
1029        ],
1030    });
1031
1032    let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1033        label: Some("voxel-gather-u8-pass"),
1034        timestamp_writes: None,
1035    });
1036    pass.set_pipeline(&pipelines.voxel_gather.u8_pipeline);
1037    pass.set_bind_group(0, &bind_group, &[]);
1038    pass.dispatch_workgroups(cell_count.div_ceil(WORKGROUP_SIZE), 1, 1);
1039    Ok(())
1040}
1041
1042#[allow(clippy::too_many_arguments)]
1043pub(crate) fn record_voxel_gather_xyz_pass(
1044    encoder: &mut wgpu::CommandEncoder,
1045    runtime: &WgpuRuntime,
1046    x: &wgpu::Buffer,
1047    y: &wgpu::Buffer,
1048    z: &wgpu::Buffer,
1049    segments: &GpuVoxelSegments,
1050    output_x: &wgpu::Buffer,
1051    output_y: &wgpu::Buffer,
1052    output_z: &wgpu::Buffer,
1053) -> SpatialResult<()> {
1054    if segments.cell_count() == 0 {
1055        return Ok(());
1056    }
1057
1058    let device = runtime.device();
1059    let pipelines = runtime.pipelines();
1060    let cell_count = segments.cell_count();
1061    let point_count = segments.point_count();
1062    let uniform = GatherUniform { cell_count, point_count, channel_count: 0, _pad: 0 };
1063    let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
1064        label: Some("voxel-gather-xyz-uniform"),
1065        contents: bytemuck::bytes_of(&uniform),
1066        usage: wgpu::BufferUsages::UNIFORM,
1067    });
1068
1069    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
1070        label: Some("voxel-gather-xyz-bind-group"),
1071        layout: &pipelines.voxel_gather.xyz_bind_group_layout,
1072        entries: &[
1073            wgpu::BindGroupEntry { binding: 0, resource: uniform_buffer.as_entire_binding() },
1074            wgpu::BindGroupEntry {
1075                binding: 1,
1076                resource: segments.point_indices_buffer().as_entire_binding(),
1077            },
1078            wgpu::BindGroupEntry {
1079                binding: 2,
1080                resource: segments.cell_starts_buffer().as_entire_binding(),
1081            },
1082            wgpu::BindGroupEntry { binding: 3, resource: x.as_entire_binding() },
1083            wgpu::BindGroupEntry { binding: 4, resource: y.as_entire_binding() },
1084            wgpu::BindGroupEntry { binding: 5, resource: z.as_entire_binding() },
1085            wgpu::BindGroupEntry { binding: 6, resource: output_x.as_entire_binding() },
1086            wgpu::BindGroupEntry { binding: 7, resource: output_y.as_entire_binding() },
1087            wgpu::BindGroupEntry { binding: 8, resource: output_z.as_entire_binding() },
1088        ],
1089    });
1090
1091    let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1092        label: Some("voxel-gather-xyz-pass"),
1093        timestamp_writes: None,
1094    });
1095    pass.set_pipeline(&pipelines.voxel_gather.xyz_pipeline);
1096    pass.set_bind_group(0, &bind_group, &[]);
1097    pass.dispatch_workgroups(cell_count.div_ceil(WORKGROUP_SIZE), 1, 1);
1098    Ok(())
1099}
1100
1101fn dispatch_voxel_gather_f32(
1102    runtime: &WgpuRuntime,
1103    values: &wgpu::Buffer,
1104    point_indices: &wgpu::Buffer,
1105    cell_starts: &wgpu::Buffer,
1106    cell_count: u32,
1107    point_count: u32,
1108) -> SpatialResult<Vec<f32>> {
1109    if cell_count == 0 {
1110        return Ok(Vec::new());
1111    }
1112
1113    let device = runtime.device();
1114    let queue = runtime.queue();
1115    let pipelines = runtime.pipelines();
1116
1117    let uniform = GatherUniform { cell_count, point_count, channel_count: 1, _pad: 0 };
1118
1119    let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
1120        label: Some("voxel-gather-uniform"),
1121        contents: bytemuck::bytes_of(&uniform),
1122        usage: wgpu::BufferUsages::UNIFORM,
1123    });
1124
1125    let output_len = cell_count as usize * std::mem::size_of::<f32>();
1126    let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
1127        label: Some("voxel-gather-output"),
1128        size: output_len as u64,
1129        usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
1130        mapped_at_creation: false,
1131    });
1132    let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
1133        label: Some("voxel-gather-staging"),
1134        size: output_len as u64,
1135        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
1136        mapped_at_creation: false,
1137    });
1138
1139    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
1140        label: Some("voxel-gather-bind-group"),
1141        layout: &pipelines.voxel_gather.bind_group_layout,
1142        entries: &[
1143            wgpu::BindGroupEntry { binding: 0, resource: uniform_buffer.as_entire_binding() },
1144            wgpu::BindGroupEntry { binding: 1, resource: point_indices.as_entire_binding() },
1145            wgpu::BindGroupEntry { binding: 2, resource: values.as_entire_binding() },
1146            wgpu::BindGroupEntry { binding: 3, resource: cell_starts.as_entire_binding() },
1147            wgpu::BindGroupEntry { binding: 4, resource: output_buffer.as_entire_binding() },
1148        ],
1149    });
1150
1151    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
1152        label: Some("voxel-gather-encoder"),
1153    });
1154
1155    {
1156        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1157            label: Some("voxel-gather-pass"),
1158            timestamp_writes: None,
1159        });
1160        pass.set_pipeline(&pipelines.voxel_gather.pipeline);
1161        pass.set_bind_group(0, &bind_group, &[]);
1162        pass.dispatch_workgroups(cell_count.div_ceil(WORKGROUP_SIZE), 1, 1);
1163    }
1164
1165    encoder.copy_buffer_to_buffer(&output_buffer, 0, &staging_buffer, 0, output_len as u64);
1166    queue.submit(Some(encoder.finish()));
1167
1168    read_staging_f32(device, &staging_buffer, cell_count as usize)
1169}
1170
1171fn dispatch_voxel_gather_multi_gpu_buffers(
1172    runtime: &WgpuRuntime,
1173    values: &[&wgpu::Buffer; MULTI4_CHANNELS],
1174    segments: &GpuVoxelSegments,
1175    channel_count: u32,
1176) -> SpatialResult<Vec<Vec<f32>>> {
1177    let pipelines = runtime.pipelines();
1178    if channel_count > 2 && pipelines.voxel_gather.multi4_pipeline.is_some() {
1179        dispatch_voxel_gather_multi4_gpu_buffers(runtime, values, segments, channel_count)
1180    } else {
1181        if channel_count > MULTI2_CHANNELS as u32 {
1182            return Err(SpatialError::InvalidArgument(format!(
1183                "gpu adapter supports only {} channels per gather dispatch",
1184                pipelines.voxel_gather.multi_max_channels
1185            )));
1186        }
1187        dispatch_voxel_gather_multi2_gpu_buffers(runtime, values, segments, channel_count)
1188    }
1189}
1190
1191fn dispatch_voxel_gather_multi2_gpu_buffers(
1192    runtime: &WgpuRuntime,
1193    values: &[&wgpu::Buffer; MULTI4_CHANNELS],
1194    segments: &GpuVoxelSegments,
1195    channel_count: u32,
1196) -> SpatialResult<Vec<Vec<f32>>> {
1197    let device = runtime.device();
1198    let queue = runtime.queue();
1199    let pipelines = runtime.pipelines();
1200    let cell_count = segments.cell_count();
1201    let point_count = segments.point_count();
1202    let channels = channel_count as usize;
1203
1204    let uniform = GatherUniform { cell_count, point_count, channel_count, _pad: 0 };
1205    let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
1206        label: Some("voxel-gather-multi-uniform"),
1207        contents: bytemuck::bytes_of(&uniform),
1208        usage: wgpu::BufferUsages::UNIFORM,
1209    });
1210
1211    let channel_len = cell_count as usize * std::mem::size_of::<f32>();
1212    let mut outputs = [None, None];
1213    for output in outputs.iter_mut() {
1214        *output = Some(device.create_buffer(&wgpu::BufferDescriptor {
1215            label: Some("voxel-gather-multi-output"),
1216            size: channel_len as u64,
1217            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
1218            mapped_at_creation: false,
1219        }));
1220    }
1221    let output_refs: [&wgpu::Buffer; MULTI2_CHANNELS] =
1222        std::array::from_fn(|index| outputs[index].as_ref().expect("output buffer"));
1223
1224    let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
1225        label: Some("voxel-gather-multi-staging"),
1226        size: (channel_len * channels) as u64,
1227        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
1228        mapped_at_creation: false,
1229    });
1230
1231    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
1232        label: Some("voxel-gather-multi-bind-group"),
1233        layout: &pipelines.voxel_gather.multi_bind_group_layout,
1234        entries: &[
1235            wgpu::BindGroupEntry { binding: 0, resource: uniform_buffer.as_entire_binding() },
1236            wgpu::BindGroupEntry {
1237                binding: 1,
1238                resource: segments.point_indices_buffer().as_entire_binding(),
1239            },
1240            wgpu::BindGroupEntry {
1241                binding: 2,
1242                resource: segments.cell_starts_buffer().as_entire_binding(),
1243            },
1244            wgpu::BindGroupEntry { binding: 3, resource: values[0].as_entire_binding() },
1245            wgpu::BindGroupEntry { binding: 4, resource: values[1].as_entire_binding() },
1246            wgpu::BindGroupEntry { binding: 5, resource: output_refs[0].as_entire_binding() },
1247            wgpu::BindGroupEntry { binding: 6, resource: output_refs[1].as_entire_binding() },
1248        ],
1249    });
1250
1251    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
1252        label: Some("voxel-gather-multi-encoder"),
1253    });
1254    {
1255        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1256            label: Some("voxel-gather-multi-pass"),
1257            timestamp_writes: None,
1258        });
1259        pass.set_pipeline(&pipelines.voxel_gather.multi_pipeline);
1260        pass.set_bind_group(0, &bind_group, &[]);
1261        pass.dispatch_workgroups(cell_count.div_ceil(WORKGROUP_SIZE), 1, 1);
1262    }
1263
1264    for (index, output) in output_refs.iter().take(channels).enumerate() {
1265        let offset = (channel_len * index) as u64;
1266        encoder.copy_buffer_to_buffer(output, 0, &staging_buffer, offset, channel_len as u64);
1267    }
1268    queue.submit(Some(encoder.finish()));
1269
1270    let flat = read_staging_f32(device, &staging_buffer, cell_count as usize * channels)?;
1271    Ok(split_channel_blocks(flat, channels, cell_count as usize))
1272}
1273
1274fn dispatch_voxel_gather_multi4_gpu_buffers(
1275    runtime: &WgpuRuntime,
1276    values: &[&wgpu::Buffer; MULTI4_CHANNELS],
1277    segments: &GpuVoxelSegments,
1278    channel_count: u32,
1279) -> SpatialResult<Vec<Vec<f32>>> {
1280    let device = runtime.device();
1281    let queue = runtime.queue();
1282    let pipelines = runtime.pipelines();
1283    let multi4_pipeline = pipelines.voxel_gather.multi4_pipeline.as_ref().ok_or_else(|| {
1284        SpatialError::InvalidArgument(
1285            "4-channel gather pipeline is unavailable on this gpu adapter".to_owned(),
1286        )
1287    })?;
1288    let multi4_layout =
1289        pipelines.voxel_gather.multi4_bind_group_layout.as_ref().expect("multi4 layout");
1290
1291    let cell_count = segments.cell_count();
1292    let point_count = segments.point_count();
1293    let channels = channel_count as usize;
1294
1295    let uniform = GatherUniform { cell_count, point_count, channel_count, _pad: 0 };
1296    let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
1297        label: Some("voxel-gather-multi4-uniform"),
1298        contents: bytemuck::bytes_of(&uniform),
1299        usage: wgpu::BufferUsages::UNIFORM,
1300    });
1301
1302    let channel_len = cell_count as usize * std::mem::size_of::<f32>();
1303    let mut outputs = [None, None, None, None];
1304    for output in outputs.iter_mut() {
1305        *output = Some(device.create_buffer(&wgpu::BufferDescriptor {
1306            label: Some("voxel-gather-multi4-output"),
1307            size: channel_len as u64,
1308            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
1309            mapped_at_creation: false,
1310        }));
1311    }
1312    let output_refs: [&wgpu::Buffer; MULTI4_CHANNELS] =
1313        std::array::from_fn(|index| outputs[index].as_ref().expect("output buffer"));
1314
1315    let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
1316        label: Some("voxel-gather-multi4-staging"),
1317        size: (channel_len * channels) as u64,
1318        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
1319        mapped_at_creation: false,
1320    });
1321
1322    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
1323        label: Some("voxel-gather-multi4-bind-group"),
1324        layout: multi4_layout,
1325        entries: &[
1326            wgpu::BindGroupEntry { binding: 0, resource: uniform_buffer.as_entire_binding() },
1327            wgpu::BindGroupEntry {
1328                binding: 1,
1329                resource: segments.point_indices_buffer().as_entire_binding(),
1330            },
1331            wgpu::BindGroupEntry {
1332                binding: 2,
1333                resource: segments.cell_starts_buffer().as_entire_binding(),
1334            },
1335            wgpu::BindGroupEntry { binding: 3, resource: values[0].as_entire_binding() },
1336            wgpu::BindGroupEntry { binding: 4, resource: values[1].as_entire_binding() },
1337            wgpu::BindGroupEntry { binding: 5, resource: values[2].as_entire_binding() },
1338            wgpu::BindGroupEntry { binding: 6, resource: values[3].as_entire_binding() },
1339            wgpu::BindGroupEntry { binding: 7, resource: output_refs[0].as_entire_binding() },
1340            wgpu::BindGroupEntry { binding: 8, resource: output_refs[1].as_entire_binding() },
1341            wgpu::BindGroupEntry { binding: 9, resource: output_refs[2].as_entire_binding() },
1342            wgpu::BindGroupEntry { binding: 10, resource: output_refs[3].as_entire_binding() },
1343        ],
1344    });
1345
1346    let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
1347        label: Some("voxel-gather-multi4-encoder"),
1348    });
1349    {
1350        let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1351            label: Some("voxel-gather-multi4-pass"),
1352            timestamp_writes: None,
1353        });
1354        pass.set_pipeline(multi4_pipeline);
1355        pass.set_bind_group(0, &bind_group, &[]);
1356        pass.dispatch_workgroups(cell_count.div_ceil(WORKGROUP_SIZE), 1, 1);
1357    }
1358
1359    for (index, output) in output_refs.iter().take(channels).enumerate() {
1360        let offset = (channel_len * index) as u64;
1361        encoder.copy_buffer_to_buffer(output, 0, &staging_buffer, offset, channel_len as u64);
1362    }
1363    queue.submit(Some(encoder.finish()));
1364
1365    let flat = read_staging_f32(device, &staging_buffer, cell_count as usize * channels)?;
1366    Ok(split_channel_blocks(flat, channels, cell_count as usize))
1367}
1368
1369fn empty_storage_buffer(device: &wgpu::Device) -> SpatialResult<wgpu::Buffer> {
1370    Ok(device.create_buffer(&wgpu::BufferDescriptor {
1371        label: Some("voxel-gather-empty"),
1372        size: 4,
1373        usage: wgpu::BufferUsages::STORAGE,
1374        mapped_at_creation: false,
1375    }))
1376}
1377
#[cfg(test)]
mod tests {
    use super::{gather_voxel_first_f32, gather_voxel_first_f32_multi_gpu};
    use crate::kernels::gpu_segments::GpuVoxelSegments;
    use crate::kernels::voxel_segments::build_voxel_segments;
    use crate::kernels::voxel_sort::build_voxel_segments_gpu_from_keys;
    use crate::runtime::WgpuRuntime;

    /// Asserts that `actual` is within `1e-5` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < 1e-5);
    }

    /// Builds GPU-resident voxel segments for `keys`, panicking on failure.
    fn gpu_segments_from_keys(runtime: &WgpuRuntime, keys: &[(i64, i64, i64)]) -> GpuVoxelSegments {
        build_voxel_segments_gpu_from_keys(runtime, keys).expect("gpu segments")
    }

    /// Checks every `expected[channel][cell]` against `gathered`, indexing so that a missing
    /// channel or cell panics rather than being skipped.
    fn assert_channels_close<const CELLS: usize>(
        gathered: &[Vec<f32>],
        expected: &[[f32; CELLS]],
    ) {
        for (channel, cells) in expected.iter().enumerate() {
            for (cell, want) in cells.iter().enumerate() {
                assert_close(gathered[channel][cell], *want);
            }
        }
    }

    // Two points in cell (0,0,0) and two in cell (2,0,0): each cell yields its first value.
    #[test]
    fn gpu_first_gather_matches_cpu_reference() {
        let runtime = WgpuRuntime::new_headless().expect("wgpu runtime");
        let values = [0.2_f32, 0.9, 10.0, 20.0];
        let keys = vec![(0, 0, 0), (0, 0, 0), (2, 0, 0), (2, 0, 0)];
        let segments = build_voxel_segments(&keys);

        let gpu = gather_voxel_first_f32(&runtime, &values, &segments).expect("gpu gather");

        assert_close(gpu[0], 0.2);
        assert_close(gpu[1], 10.0);
    }

    #[test]
    fn gpu_multi_gather_matches_single_channel_reference() {
        let runtime = WgpuRuntime::new_headless().expect("wgpu runtime");
        let intensity = [0.2_f32, 0.9, 10.0, 20.0];
        let classification = [1.0_f32, 2.0, 3.0, 4.0];
        let keys = vec![(0, 0, 0), (0, 0, 0), (2, 0, 0), (2, 0, 0)];
        let segments = gpu_segments_from_keys(&runtime, &keys);

        let multi =
            gather_voxel_first_f32_multi_gpu(&runtime, &[&intensity, &classification], &segments)
                .expect("multi gather");

        assert_channels_close(&multi, &[[0.2_f32, 10.0], [1.0, 3.0]]);
    }

    #[test]
    fn gpu_multi4_gather_handles_four_channels_when_supported() {
        let runtime = WgpuRuntime::new_headless().expect("wgpu runtime");
        // Adapters limited to fewer channels cannot run this case; treat it as a pass.
        if runtime.max_gather_channels() < 4 {
            return;
        }

        let c0 = [0.2_f32, 0.9, 10.0, 20.0];
        let c1 = [1.0_f32, 2.0, 3.0, 4.0];
        let c2 = [5.0_f32, 6.0, 7.0, 8.0];
        let c3 = [9.0_f32, 10.0, 11.0, 12.0];
        let keys = vec![(0, 0, 0), (0, 0, 0), (2, 0, 0), (2, 0, 0)];
        let segments = gpu_segments_from_keys(&runtime, &keys);

        let multi = gather_voxel_first_f32_multi_gpu(&runtime, &[&c0, &c1, &c2, &c3], &segments)
            .expect("multi4 gather");

        assert_eq!(multi.len(), 4);
        assert_channels_close(
            &multi,
            &[[0.2_f32, 10.0], [1.0, 3.0], [5.0, 7.0], [9.0, 11.0]],
        );
    }
}