1use bytemuck::{Pod, Zeroable};
2use spatialrust_core::{SpatialError, SpatialResult};
3use wgpu::util::DeviceExt;
4
5use crate::kernels::gpu_segments::GpuVoxelSegments;
6use crate::kernels::voxel_segments::VoxelSegments;
7use crate::readback::{
8 pad_u8_for_gpu_storage, read_staging_f32, read_staging_f32_and_u8, split_channel_blocks,
9 split_u8_channel_blocks, split_xyz_and_attribute_blocks, split_xyz_blocks,
10 u8_output_staging_bytes, unpack_u8_outputs_from_u32_staging,
11};
12use crate::runtime::WgpuRuntime;
13
/// Divisor used to size every compute dispatch in this file: each pass launches
/// `cell_count.div_ceil(WORKGROUP_SIZE)` workgroups along X.
/// NOTE(review): presumably has to match the workgroup size declared by the gather
/// shaders — confirm against the shader sources.
const WORKGROUP_SIZE: u32 = 256;
/// Number of value/output buffer slots taken by the two-channel gather pass.
const MULTI2_CHANNELS: usize = 2;
/// Number of value/output buffer slots taken by the four-channel gather passes.
const MULTI4_CHANNELS: usize = 4;
17
/// Parameter block uploaded as the uniform buffer (binding 0) of the gather passes
/// recorded in this file.
///
/// `#[repr(C)]` together with `Pod`/`Zeroable` lets the struct be turned into bytes
/// with `bytemuck::bytes_of`.
/// NOTE(review): the field order presumably mirrors the uniform struct declared by the
/// shaders — confirm before reordering anything.
#[repr(C)]
#[derive(Clone, Copy, Debug, Pod, Zeroable)]
struct GatherUniform {
    /// Number of voxel cells; the same value drives the dispatch size.
    cell_count: u32,
    /// Number of points covered by the segments.
    point_count: u32,
    /// Channel count handed to the shader: 1 for the single-channel passes, the caller's
    /// value for the multi-channel passes, 0 for the fused xyz+attrs4 pass.
    channel_count: u32,
    /// Always written as 0; brings the struct to four `u32`s (16 bytes).
    /// Presumably required for uniform-buffer alignment — confirm.
    _pad: u32,
}
26
27pub fn gather_voxel_first_f32(
29 runtime: &WgpuRuntime,
30 values: &[f32],
31 segments: &VoxelSegments,
32) -> SpatialResult<Vec<f32>> {
33 if segments.is_empty() {
34 return Ok(Vec::new());
35 }
36 if values.is_empty() {
37 return Err(SpatialError::InvalidArgument(
38 "cannot gather from empty value buffer".to_owned(),
39 ));
40 }
41
42 let device = runtime.device();
43 let values_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
44 label: Some("voxel-gather-values"),
45 contents: bytemuck::cast_slice(values),
46 usage: wgpu::BufferUsages::STORAGE,
47 });
48 let indices_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
49 label: Some("voxel-gather-indices"),
50 contents: bytemuck::cast_slice(&segments.point_indices),
51 usage: wgpu::BufferUsages::STORAGE,
52 });
53 let starts_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
54 label: Some("voxel-gather-starts"),
55 contents: bytemuck::cast_slice(&segments.cell_starts),
56 usage: wgpu::BufferUsages::STORAGE,
57 });
58
59 dispatch_voxel_gather_f32(
60 runtime,
61 &values_buffer,
62 &indices_buffer,
63 &starts_buffer,
64 segments.len() as u32,
65 segments.point_indices.len() as u32,
66 )
67}
68
69pub fn gather_voxel_first_f32_gpu_buffers(
71 runtime: &WgpuRuntime,
72 values: &wgpu::Buffer,
73 segments: &GpuVoxelSegments,
74) -> SpatialResult<Vec<f32>> {
75 dispatch_voxel_gather_f32(
76 runtime,
77 values,
78 segments.point_indices_buffer(),
79 segments.cell_starts_buffer(),
80 segments.cell_count(),
81 segments.point_count(),
82 )
83}
84
85pub fn gather_voxel_first_f32_gpu(
87 runtime: &WgpuRuntime,
88 values: &[f32],
89 segments: &GpuVoxelSegments,
90) -> SpatialResult<Vec<f32>> {
91 if segments.cell_count() == 0 {
92 return Ok(Vec::new());
93 }
94 if values.len() != segments.point_count() as usize {
95 return Err(SpatialError::BufferLengthMismatch {
96 expected: segments.point_count() as usize,
97 found: values.len(),
98 });
99 }
100
101 let device = runtime.device();
102 let values_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
103 label: Some("voxel-gather-values-upload"),
104 contents: bytemuck::cast_slice(values),
105 usage: wgpu::BufferUsages::STORAGE,
106 });
107 gather_voxel_first_f32_gpu_buffers(runtime, &values_buffer, segments)
108}
109
110pub fn gather_voxel_first_f32_multi_gpu(
112 runtime: &WgpuRuntime,
113 channels: &[&[f32]],
114 segments: &GpuVoxelSegments,
115) -> SpatialResult<Vec<Vec<f32>>> {
116 if channels.is_empty() {
117 return Ok(Vec::new());
118 }
119 if segments.cell_count() == 0 {
120 return Ok(vec![Vec::new(); channels.len()]);
121 }
122
123 let point_count = segments.point_count() as usize;
124 for channel in channels {
125 if channel.len() != point_count {
126 return Err(SpatialError::BufferLengthMismatch {
127 expected: point_count,
128 found: channel.len(),
129 });
130 }
131 }
132
133 let max_channels = runtime.max_gather_channels() as usize;
134 let device = runtime.device();
135 let empty = empty_storage_buffer(device)?;
136 let mut gathered = Vec::with_capacity(channels.len());
137
138 for chunk in channels.chunks(max_channels) {
139 let mut value_buffers = [None, None, None, None];
140 for (index, channel) in chunk.iter().enumerate() {
141 value_buffers[index] =
142 Some(device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
143 label: Some("voxel-gather-multi-values"),
144 contents: bytemuck::cast_slice(channel),
145 usage: wgpu::BufferUsages::STORAGE,
146 }));
147 }
148
149 let value_refs: [&wgpu::Buffer; MULTI4_CHANNELS] = std::array::from_fn(|index| {
150 if index < chunk.len() {
151 value_buffers[index].as_ref().expect("value buffer")
152 } else {
153 &empty
154 }
155 });
156
157 gathered.extend(dispatch_voxel_gather_multi_gpu_buffers(
158 runtime,
159 &value_refs,
160 segments,
161 chunk.len() as u32,
162 )?);
163 }
164
165 Ok(gathered)
166}
167
168pub fn gather_voxel_first_xyz_and_multi_gpu(
170 runtime: &WgpuRuntime,
171 x: &wgpu::Buffer,
172 y: &wgpu::Buffer,
173 z: &wgpu::Buffer,
174 attribute_channels: &[&[f32]],
175 u8_attribute_channels: &[&[u8]],
176 segments: &GpuVoxelSegments,
177) -> SpatialResult<(Vec<f32>, Vec<f32>, Vec<f32>, Vec<Vec<f32>>, Vec<Vec<u8>>)> {
178 let attribute_count = attribute_channels.len();
179 let u8_attribute_count = u8_attribute_channels.len();
180 if segments.cell_count() == 0 {
181 return Ok((
182 Vec::new(),
183 Vec::new(),
184 Vec::new(),
185 vec![Vec::new(); attribute_count],
186 vec![Vec::new(); u8_attribute_count],
187 ));
188 }
189
190 let point_count = segments.point_count() as usize;
191 for channel in attribute_channels {
192 if channel.len() != point_count {
193 return Err(SpatialError::BufferLengthMismatch {
194 expected: point_count,
195 found: channel.len(),
196 });
197 }
198 }
199 for channel in u8_attribute_channels {
200 if channel.len() != point_count {
201 return Err(SpatialError::BufferLengthMismatch {
202 expected: point_count,
203 found: channel.len(),
204 });
205 }
206 }
207
208 let device = runtime.device();
209 let queue = runtime.queue();
210 let cell_count = segments.cell_count();
211 let cells = cell_count as usize;
212 let channel_len = cells * std::mem::size_of::<f32>();
213 let f32_channel_count = 3 + attribute_count;
214 let u8_staging_len = u8_output_staging_bytes(cells, u8_attribute_count);
215 let staging_size = channel_len * f32_channel_count + u8_staging_len;
216 let fused_xyz_attrs4 = attribute_count == 4
217 && u8_attribute_count == 0
218 && runtime.pipelines().voxel_gather.xyz_attrs4_pipeline.is_some();
219
220 let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
221 label: Some("voxel-gather-xyz-attrs-staging"),
222 size: staging_size as u64,
223 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
224 mapped_at_creation: false,
225 });
226
227 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
228 label: Some("voxel-gather-xyz-attrs-encoder"),
229 });
230 let mut upload_recycle = Vec::new();
231
232 if fused_xyz_attrs4 {
233 for channel in attribute_channels {
234 upload_recycle
235 .push(runtime.upload_f32_storage("voxel-gather-xyz-attrs4-values", channel)?);
236 }
237 let attr_refs: [&wgpu::Buffer; MULTI4_CHANNELS] =
238 std::array::from_fn(|index| &upload_recycle[index]);
239 let packed_output = device.create_buffer(&wgpu::BufferDescriptor {
240 label: Some("voxel-gather-xyz-attrs4-packed-output"),
241 size: (channel_len * f32_channel_count) as u64,
242 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
243 mapped_at_creation: false,
244 });
245 record_voxel_gather_xyz_and_attrs4_packed_pass(
246 &mut encoder,
247 runtime,
248 x,
249 y,
250 z,
251 &attr_refs,
252 segments,
253 &packed_output,
254 )?;
255 encoder.copy_buffer_to_buffer(
256 &packed_output,
257 0,
258 &staging_buffer,
259 0,
260 (channel_len * f32_channel_count) as u64,
261 );
262 } else {
263 let output_x = device.create_buffer(&wgpu::BufferDescriptor {
264 label: Some("voxel-gather-xyz-out-x"),
265 size: channel_len as u64,
266 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
267 mapped_at_creation: false,
268 });
269 let output_y = device.create_buffer(&wgpu::BufferDescriptor {
270 label: Some("voxel-gather-xyz-out-y"),
271 size: channel_len as u64,
272 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
273 mapped_at_creation: false,
274 });
275 let output_z = device.create_buffer(&wgpu::BufferDescriptor {
276 label: Some("voxel-gather-xyz-out-z"),
277 size: channel_len as u64,
278 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
279 mapped_at_creation: false,
280 });
281
282 record_voxel_gather_xyz_pass(
283 &mut encoder,
284 runtime,
285 x,
286 y,
287 z,
288 segments,
289 &output_x,
290 &output_y,
291 &output_z,
292 )?;
293
294 encoder.copy_buffer_to_buffer(&output_x, 0, &staging_buffer, 0, channel_len as u64);
295 encoder.copy_buffer_to_buffer(
296 &output_y,
297 0,
298 &staging_buffer,
299 channel_len as u64,
300 channel_len as u64,
301 );
302 encoder.copy_buffer_to_buffer(
303 &output_z,
304 0,
305 &staging_buffer,
306 (channel_len * 2) as u64,
307 channel_len as u64,
308 );
309
310 record_gather_f32_attribute_channels_to_staging(
311 &mut encoder,
312 runtime,
313 attribute_channels,
314 segments,
315 &staging_buffer,
316 channel_len as u64,
317 &mut upload_recycle,
318 )?;
319 }
320
321 let u8_region_offset = (channel_len * f32_channel_count) as u64;
322 for (attribute_index, channel) in u8_attribute_channels.iter().enumerate() {
323 let values_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
324 label: Some("voxel-gather-xyz-u8-values"),
325 contents: &pad_u8_for_gpu_storage(channel),
326 usage: wgpu::BufferUsages::STORAGE,
327 });
328 let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
329 label: Some("voxel-gather-xyz-u8-output"),
330 size: (cells * std::mem::size_of::<u32>()) as u64,
331 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
332 mapped_at_creation: false,
333 });
334
335 record_voxel_gather_u8_pass(
336 &mut encoder,
337 runtime,
338 &values_buffer,
339 segments.point_indices_buffer(),
340 segments.cell_starts_buffer(),
341 cell_count,
342 segments.point_count(),
343 &output_buffer,
344 )?;
345
346 encoder.copy_buffer_to_buffer(
347 &output_buffer,
348 0,
349 &staging_buffer,
350 u8_region_offset + attribute_index as u64 * (cells * std::mem::size_of::<u32>()) as u64,
351 (cells * std::mem::size_of::<u32>()) as u64,
352 );
353 }
354
355 queue.submit(Some(encoder.finish()));
356
357 let (flat, u8_raw) = read_staging_f32_and_u8(
358 device,
359 &staging_buffer,
360 cells * f32_channel_count,
361 u8_staging_len,
362 )?;
363 let u8_flat = if u8_attribute_count == 0 {
364 Vec::new()
365 } else {
366 unpack_u8_outputs_from_u32_staging(u8_raw, cells, u8_attribute_count)
367 };
368 let (out_x, out_y, out_z, attributes) =
369 split_xyz_and_attribute_blocks(flat, attribute_count, cells);
370 for buffer in upload_recycle {
371 runtime.recycle_storage(buffer.size(), buffer);
372 }
373 Ok((
374 out_x,
375 out_y,
376 out_z,
377 attributes,
378 split_u8_channel_blocks(u8_flat, u8_attribute_count, cells),
379 ))
380}
381
382pub fn gather_voxel_first_xyz_and_average_multi_gpu(
384 runtime: &WgpuRuntime,
385 x: &wgpu::Buffer,
386 y: &wgpu::Buffer,
387 z: &wgpu::Buffer,
388 attribute_channels: &[&[f32]],
389 u8_attribute_channels: &[&[u8]],
390 segments: &GpuVoxelSegments,
391) -> SpatialResult<(Vec<f32>, Vec<f32>, Vec<f32>, Vec<Vec<f32>>, Vec<Vec<u8>>)> {
392 use crate::kernels::voxel_reduce::record_voxel_reduce_f32_pass;
393 use crate::kernels::voxel_reduce::record_voxel_reduce_u8_pass;
394
395 let attribute_count = attribute_channels.len();
396 let u8_attribute_count = u8_attribute_channels.len();
397 if segments.cell_count() == 0 {
398 return Ok((
399 Vec::new(),
400 Vec::new(),
401 Vec::new(),
402 vec![Vec::new(); attribute_count],
403 vec![Vec::new(); u8_attribute_count],
404 ));
405 }
406
407 let point_count = segments.point_count() as usize;
408 for channel in attribute_channels {
409 if channel.len() != point_count {
410 return Err(SpatialError::BufferLengthMismatch {
411 expected: point_count,
412 found: channel.len(),
413 });
414 }
415 }
416 for channel in u8_attribute_channels {
417 if channel.len() != point_count {
418 return Err(SpatialError::BufferLengthMismatch {
419 expected: point_count,
420 found: channel.len(),
421 });
422 }
423 }
424
425 let device = runtime.device();
426 let queue = runtime.queue();
427 let cell_count = segments.cell_count();
428 let cells = cell_count as usize;
429 let channel_len = cells * std::mem::size_of::<f32>();
430 let f32_channel_count = 3 + attribute_count;
431 let u8_staging_len = u8_output_staging_bytes(cells, u8_attribute_count);
432 let staging_size = channel_len * f32_channel_count + u8_staging_len;
433 let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
434 label: Some("voxel-gather-xyz-reduce-attrs-staging"),
435 size: staging_size as u64,
436 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
437 mapped_at_creation: false,
438 });
439 let output_x = device.create_buffer(&wgpu::BufferDescriptor {
440 label: Some("voxel-gather-xyz-out-x"),
441 size: channel_len as u64,
442 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
443 mapped_at_creation: false,
444 });
445 let output_y = device.create_buffer(&wgpu::BufferDescriptor {
446 label: Some("voxel-gather-xyz-out-y"),
447 size: channel_len as u64,
448 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
449 mapped_at_creation: false,
450 });
451 let output_z = device.create_buffer(&wgpu::BufferDescriptor {
452 label: Some("voxel-gather-xyz-out-z"),
453 size: channel_len as u64,
454 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
455 mapped_at_creation: false,
456 });
457
458 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
459 label: Some("voxel-gather-xyz-reduce-attrs-encoder"),
460 });
461 record_voxel_gather_xyz_pass(
462 &mut encoder,
463 runtime,
464 x,
465 y,
466 z,
467 segments,
468 &output_x,
469 &output_y,
470 &output_z,
471 )?;
472 encoder.copy_buffer_to_buffer(&output_x, 0, &staging_buffer, 0, channel_len as u64);
473 encoder.copy_buffer_to_buffer(
474 &output_y,
475 0,
476 &staging_buffer,
477 channel_len as u64,
478 channel_len as u64,
479 );
480 encoder.copy_buffer_to_buffer(
481 &output_z,
482 0,
483 &staging_buffer,
484 (channel_len * 2) as u64,
485 channel_len as u64,
486 );
487
488 for (attribute_index, channel) in attribute_channels.iter().enumerate() {
489 let values_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
490 label: Some("voxel-reduce-xyz-attrs-values"),
491 contents: bytemuck::cast_slice(channel),
492 usage: wgpu::BufferUsages::STORAGE,
493 });
494 let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
495 label: Some("voxel-reduce-xyz-attrs-output"),
496 size: channel_len as u64,
497 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
498 mapped_at_creation: false,
499 });
500 record_voxel_reduce_f32_pass(
501 &mut encoder,
502 runtime,
503 &values_buffer,
504 segments.point_indices_buffer(),
505 segments.cell_starts_buffer(),
506 cell_count,
507 segments.point_count(),
508 &output_buffer,
509 )?;
510 encoder.copy_buffer_to_buffer(
511 &output_buffer,
512 0,
513 &staging_buffer,
514 (channel_len * (3 + attribute_index)) as u64,
515 channel_len as u64,
516 );
517 }
518
519 let u8_region_offset = (channel_len * f32_channel_count) as u64;
520 for (attribute_index, channel) in u8_attribute_channels.iter().enumerate() {
521 let values_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
522 label: Some("voxel-reduce-xyz-u8-values"),
523 contents: &pad_u8_for_gpu_storage(channel),
524 usage: wgpu::BufferUsages::STORAGE,
525 });
526 let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
527 label: Some("voxel-reduce-xyz-u8-output"),
528 size: (cells * std::mem::size_of::<u32>()) as u64,
529 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
530 mapped_at_creation: false,
531 });
532 record_voxel_reduce_u8_pass(
533 &mut encoder,
534 runtime,
535 &values_buffer,
536 segments.point_indices_buffer(),
537 segments.cell_starts_buffer(),
538 cell_count,
539 segments.point_count(),
540 &output_buffer,
541 )?;
542 encoder.copy_buffer_to_buffer(
543 &output_buffer,
544 0,
545 &staging_buffer,
546 u8_region_offset + attribute_index as u64 * (cells * std::mem::size_of::<u32>()) as u64,
547 (cells * std::mem::size_of::<u32>()) as u64,
548 );
549 }
550
551 queue.submit(Some(encoder.finish()));
552 let (flat, u8_raw) = read_staging_f32_and_u8(
553 device,
554 &staging_buffer,
555 cells * f32_channel_count,
556 u8_staging_len,
557 )?;
558 let u8_flat = if u8_attribute_count == 0 {
559 Vec::new()
560 } else {
561 unpack_u8_outputs_from_u32_staging(u8_raw, cells, u8_attribute_count)
562 };
563 let (out_x, out_y, out_z, attributes) =
564 split_xyz_and_attribute_blocks(flat, attribute_count, cells);
565 Ok((
566 out_x,
567 out_y,
568 out_z,
569 attributes,
570 split_u8_channel_blocks(u8_flat, u8_attribute_count, cells),
571 ))
572}
573
574pub fn gather_voxel_first_xyz_gpu_buffers(
576 runtime: &WgpuRuntime,
577 x: &wgpu::Buffer,
578 y: &wgpu::Buffer,
579 z: &wgpu::Buffer,
580 segments: &GpuVoxelSegments,
581) -> SpatialResult<(Vec<f32>, Vec<f32>, Vec<f32>)> {
582 if segments.cell_count() == 0 {
583 return Ok((Vec::new(), Vec::new(), Vec::new()));
584 }
585
586 let device = runtime.device();
587 let queue = runtime.queue();
588 let cell_count = segments.cell_count();
589 let channel_len = cell_count as usize * std::mem::size_of::<f32>();
590 let output_x = device.create_buffer(&wgpu::BufferDescriptor {
591 label: Some("voxel-gather-xyz-out-x"),
592 size: channel_len as u64,
593 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
594 mapped_at_creation: false,
595 });
596 let output_y = device.create_buffer(&wgpu::BufferDescriptor {
597 label: Some("voxel-gather-xyz-out-y"),
598 size: channel_len as u64,
599 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
600 mapped_at_creation: false,
601 });
602 let output_z = device.create_buffer(&wgpu::BufferDescriptor {
603 label: Some("voxel-gather-xyz-out-z"),
604 size: channel_len as u64,
605 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
606 mapped_at_creation: false,
607 });
608 let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
609 label: Some("voxel-gather-xyz-staging"),
610 size: (channel_len * 3) as u64,
611 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
612 mapped_at_creation: false,
613 });
614
615 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
616 label: Some("voxel-gather-xyz-encoder"),
617 });
618 record_voxel_gather_xyz_pass(
619 &mut encoder,
620 runtime,
621 x,
622 y,
623 z,
624 segments,
625 &output_x,
626 &output_y,
627 &output_z,
628 )?;
629 encoder.copy_buffer_to_buffer(&output_x, 0, &staging_buffer, 0, channel_len as u64);
630 encoder.copy_buffer_to_buffer(
631 &output_y,
632 0,
633 &staging_buffer,
634 channel_len as u64,
635 channel_len as u64,
636 );
637 encoder.copy_buffer_to_buffer(
638 &output_z,
639 0,
640 &staging_buffer,
641 (channel_len * 2) as u64,
642 channel_len as u64,
643 );
644 queue.submit(Some(encoder.finish()));
645
646 let flat = read_staging_f32(device, &staging_buffer, cell_count as usize * 3)?;
647 Ok(split_xyz_blocks(flat, cell_count as usize))
648}
649
650pub(crate) fn record_voxel_gather_xyz_and_attrs4_packed_pass(
651 encoder: &mut wgpu::CommandEncoder,
652 runtime: &WgpuRuntime,
653 x: &wgpu::Buffer,
654 y: &wgpu::Buffer,
655 z: &wgpu::Buffer,
656 attribute_buffers: &[&wgpu::Buffer; MULTI4_CHANNELS],
657 segments: &GpuVoxelSegments,
658 packed_output: &wgpu::Buffer,
659) -> SpatialResult<()> {
660 if segments.cell_count() == 0 {
661 return Ok(());
662 }
663
664 let device = runtime.device();
665 let pipelines = runtime.pipelines();
666 let xyz_attrs4_pipeline =
667 pipelines.voxel_gather.xyz_attrs4_pipeline.as_ref().ok_or_else(|| {
668 SpatialError::InvalidArgument(
669 "fused xyz+4-attribute gather pipeline is unavailable on this gpu adapter"
670 .to_owned(),
671 )
672 })?;
673 let xyz_attrs4_layout =
674 pipelines.voxel_gather.xyz_attrs4_bind_group_layout.as_ref().expect("xyz attrs4 layout");
675 let cell_count = segments.cell_count();
676 let uniform = GatherUniform {
677 cell_count,
678 point_count: segments.point_count(),
679 channel_count: 0,
680 _pad: 0,
681 };
682 let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
683 label: Some("voxel-gather-xyz-attrs4-uniform"),
684 contents: bytemuck::bytes_of(&uniform),
685 usage: wgpu::BufferUsages::UNIFORM,
686 });
687 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
688 label: Some("voxel-gather-xyz-attrs4-bind-group"),
689 layout: xyz_attrs4_layout,
690 entries: &[
691 wgpu::BindGroupEntry { binding: 0, resource: uniform_buffer.as_entire_binding() },
692 wgpu::BindGroupEntry {
693 binding: 1,
694 resource: segments.point_indices_buffer().as_entire_binding(),
695 },
696 wgpu::BindGroupEntry {
697 binding: 2,
698 resource: segments.cell_starts_buffer().as_entire_binding(),
699 },
700 wgpu::BindGroupEntry { binding: 3, resource: x.as_entire_binding() },
701 wgpu::BindGroupEntry { binding: 4, resource: y.as_entire_binding() },
702 wgpu::BindGroupEntry { binding: 5, resource: z.as_entire_binding() },
703 wgpu::BindGroupEntry { binding: 6, resource: attribute_buffers[0].as_entire_binding() },
704 wgpu::BindGroupEntry { binding: 7, resource: attribute_buffers[1].as_entire_binding() },
705 wgpu::BindGroupEntry { binding: 8, resource: attribute_buffers[2].as_entire_binding() },
706 wgpu::BindGroupEntry { binding: 9, resource: attribute_buffers[3].as_entire_binding() },
707 wgpu::BindGroupEntry { binding: 10, resource: packed_output.as_entire_binding() },
708 ],
709 });
710 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
711 label: Some("voxel-gather-xyz-attrs4-pass"),
712 timestamp_writes: None,
713 });
714 pass.set_pipeline(xyz_attrs4_pipeline);
715 pass.set_bind_group(0, &bind_group, &[]);
716 pass.dispatch_workgroups(cell_count.div_ceil(WORKGROUP_SIZE), 1, 1);
717 Ok(())
718}
719
720pub(crate) fn record_gather_f32_attribute_channels_to_staging(
722 encoder: &mut wgpu::CommandEncoder,
723 runtime: &WgpuRuntime,
724 attribute_channels: &[&[f32]],
725 segments: &GpuVoxelSegments,
726 staging_buffer: &wgpu::Buffer,
727 channel_len: u64,
728 upload_recycle: &mut Vec<wgpu::Buffer>,
729) -> SpatialResult<()> {
730 if attribute_channels.is_empty() {
731 return Ok(());
732 }
733
734 let device = runtime.device();
735 let max_channels = runtime.max_gather_channels().max(1) as usize;
736
737 for chunk_start in (0..attribute_channels.len()).step_by(max_channels) {
738 let chunk_end = (chunk_start + max_channels).min(attribute_channels.len());
739 let chunk = &attribute_channels[chunk_start..chunk_end];
740 let channels_in_chunk = chunk.len();
741
742 let chunk_upload_start = upload_recycle.len();
743 for channel in chunk {
744 upload_recycle
745 .push(runtime.upload_f32_storage("voxel-gather-xyz-attrs-values", channel)?);
746 }
747 let value_buffers = &upload_recycle[chunk_upload_start..];
748
749 let output_buffers: Vec<wgpu::Buffer> = (0..channels_in_chunk)
750 .map(|_| {
751 device.create_buffer(&wgpu::BufferDescriptor {
752 label: Some("voxel-gather-xyz-attrs-output"),
753 size: channel_len,
754 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
755 mapped_at_creation: false,
756 })
757 })
758 .collect();
759
760 match channels_in_chunk {
761 1 => record_voxel_gather_f32_pass(
762 encoder,
763 runtime,
764 &value_buffers[0],
765 segments.point_indices_buffer(),
766 segments.cell_starts_buffer(),
767 segments.cell_count(),
768 segments.point_count(),
769 &output_buffers[0],
770 )?,
771 2 => {
772 let values: [&wgpu::Buffer; MULTI2_CHANNELS] =
773 [&value_buffers[0], &value_buffers[1]];
774 let outputs: [&wgpu::Buffer; MULTI2_CHANNELS] =
775 [&output_buffers[0], &output_buffers[1]];
776 record_voxel_gather_multi2_pass(encoder, runtime, &values, segments, &outputs, 2)?;
777 }
778 channel_count => {
779 if runtime.pipelines().voxel_gather.multi4_pipeline.is_some() {
780 let empty = empty_storage_buffer(device)?;
781 let dummy_outputs: [wgpu::Buffer; MULTI4_CHANNELS] =
782 std::array::from_fn(|_| {
783 device.create_buffer(&wgpu::BufferDescriptor {
784 label: Some("voxel-gather-xyz-attrs-dummy-output"),
785 size: channel_len,
786 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
787 mapped_at_creation: false,
788 })
789 });
790 let values: [&wgpu::Buffer; MULTI4_CHANNELS] = std::array::from_fn(|index| {
791 if index < channel_count {
792 &value_buffers[index]
793 } else {
794 &empty
795 }
796 });
797 let outputs: [&wgpu::Buffer; MULTI4_CHANNELS] = std::array::from_fn(|index| {
798 if index < channel_count {
799 &output_buffers[index]
800 } else {
801 &dummy_outputs[index]
802 }
803 });
804 record_voxel_gather_multi4_pass(
805 encoder,
806 runtime,
807 &values,
808 segments,
809 &outputs,
810 channel_count as u32,
811 )?;
812 } else {
813 for (local_index, channel) in chunk.iter().enumerate() {
814 let chunk_upload_start = upload_recycle.len();
815 upload_recycle.push(
816 runtime.upload_f32_storage("voxel-gather-xyz-attrs-values", channel)?,
817 );
818 record_voxel_gather_f32_pass(
819 encoder,
820 runtime,
821 &upload_recycle[chunk_upload_start],
822 segments.point_indices_buffer(),
823 segments.cell_starts_buffer(),
824 segments.cell_count(),
825 segments.point_count(),
826 &output_buffers[local_index],
827 )?;
828 }
829 }
830 }
831 }
832
833 for (local_index, output_buffer) in output_buffers.iter().enumerate() {
834 encoder.copy_buffer_to_buffer(
835 output_buffer,
836 0,
837 staging_buffer,
838 channel_len * (3 + chunk_start + local_index) as u64,
839 channel_len,
840 );
841 }
842 }
843
844 Ok(())
845}
846
847pub(crate) fn record_voxel_gather_multi2_pass(
848 encoder: &mut wgpu::CommandEncoder,
849 runtime: &WgpuRuntime,
850 values: &[&wgpu::Buffer; MULTI2_CHANNELS],
851 segments: &GpuVoxelSegments,
852 outputs: &[&wgpu::Buffer; MULTI2_CHANNELS],
853 channel_count: u32,
854) -> SpatialResult<()> {
855 let device = runtime.device();
856 let pipelines = runtime.pipelines();
857 let cell_count = segments.cell_count();
858 let uniform =
859 GatherUniform { cell_count, point_count: segments.point_count(), channel_count, _pad: 0 };
860 let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
861 label: Some("voxel-gather-multi-uniform"),
862 contents: bytemuck::bytes_of(&uniform),
863 usage: wgpu::BufferUsages::UNIFORM,
864 });
865 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
866 label: Some("voxel-gather-multi-bind-group"),
867 layout: &pipelines.voxel_gather.multi_bind_group_layout,
868 entries: &[
869 wgpu::BindGroupEntry { binding: 0, resource: uniform_buffer.as_entire_binding() },
870 wgpu::BindGroupEntry {
871 binding: 1,
872 resource: segments.point_indices_buffer().as_entire_binding(),
873 },
874 wgpu::BindGroupEntry {
875 binding: 2,
876 resource: segments.cell_starts_buffer().as_entire_binding(),
877 },
878 wgpu::BindGroupEntry { binding: 3, resource: values[0].as_entire_binding() },
879 wgpu::BindGroupEntry { binding: 4, resource: values[1].as_entire_binding() },
880 wgpu::BindGroupEntry { binding: 5, resource: outputs[0].as_entire_binding() },
881 wgpu::BindGroupEntry { binding: 6, resource: outputs[1].as_entire_binding() },
882 ],
883 });
884 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
885 label: Some("voxel-gather-multi-pass"),
886 timestamp_writes: None,
887 });
888 pass.set_pipeline(&pipelines.voxel_gather.multi_pipeline);
889 pass.set_bind_group(0, &bind_group, &[]);
890 pass.dispatch_workgroups(cell_count.div_ceil(WORKGROUP_SIZE), 1, 1);
891 Ok(())
892}
893
894pub(crate) fn record_voxel_gather_multi4_pass(
895 encoder: &mut wgpu::CommandEncoder,
896 runtime: &WgpuRuntime,
897 values: &[&wgpu::Buffer; MULTI4_CHANNELS],
898 segments: &GpuVoxelSegments,
899 outputs: &[&wgpu::Buffer; MULTI4_CHANNELS],
900 channel_count: u32,
901) -> SpatialResult<()> {
902 let device = runtime.device();
903 let pipelines = runtime.pipelines();
904 let multi4_pipeline = pipelines.voxel_gather.multi4_pipeline.as_ref().ok_or_else(|| {
905 SpatialError::InvalidArgument(
906 "4-channel gather pipeline is unavailable on this gpu adapter".to_owned(),
907 )
908 })?;
909 let multi4_layout =
910 pipelines.voxel_gather.multi4_bind_group_layout.as_ref().expect("multi4 layout");
911 let cell_count = segments.cell_count();
912 let uniform =
913 GatherUniform { cell_count, point_count: segments.point_count(), channel_count, _pad: 0 };
914 let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
915 label: Some("voxel-gather-multi4-uniform"),
916 contents: bytemuck::bytes_of(&uniform),
917 usage: wgpu::BufferUsages::UNIFORM,
918 });
919 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
920 label: Some("voxel-gather-multi4-bind-group"),
921 layout: multi4_layout,
922 entries: &[
923 wgpu::BindGroupEntry { binding: 0, resource: uniform_buffer.as_entire_binding() },
924 wgpu::BindGroupEntry {
925 binding: 1,
926 resource: segments.point_indices_buffer().as_entire_binding(),
927 },
928 wgpu::BindGroupEntry {
929 binding: 2,
930 resource: segments.cell_starts_buffer().as_entire_binding(),
931 },
932 wgpu::BindGroupEntry { binding: 3, resource: values[0].as_entire_binding() },
933 wgpu::BindGroupEntry { binding: 4, resource: values[1].as_entire_binding() },
934 wgpu::BindGroupEntry { binding: 5, resource: values[2].as_entire_binding() },
935 wgpu::BindGroupEntry { binding: 6, resource: values[3].as_entire_binding() },
936 wgpu::BindGroupEntry { binding: 7, resource: outputs[0].as_entire_binding() },
937 wgpu::BindGroupEntry { binding: 8, resource: outputs[1].as_entire_binding() },
938 wgpu::BindGroupEntry { binding: 9, resource: outputs[2].as_entire_binding() },
939 wgpu::BindGroupEntry { binding: 10, resource: outputs[3].as_entire_binding() },
940 ],
941 });
942 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
943 label: Some("voxel-gather-multi4-pass"),
944 timestamp_writes: None,
945 });
946 pass.set_pipeline(multi4_pipeline);
947 pass.set_bind_group(0, &bind_group, &[]);
948 pass.dispatch_workgroups(cell_count.div_ceil(WORKGROUP_SIZE), 1, 1);
949 Ok(())
950}
951
952pub(crate) fn record_voxel_gather_f32_pass(
953 encoder: &mut wgpu::CommandEncoder,
954 runtime: &WgpuRuntime,
955 values: &wgpu::Buffer,
956 point_indices: &wgpu::Buffer,
957 cell_starts: &wgpu::Buffer,
958 cell_count: u32,
959 point_count: u32,
960 output_buffer: &wgpu::Buffer,
961) -> SpatialResult<()> {
962 if cell_count == 0 {
963 return Ok(());
964 }
965
966 let device = runtime.device();
967 let pipelines = runtime.pipelines();
968 let uniform = GatherUniform { cell_count, point_count, channel_count: 1, _pad: 0 };
969 let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
970 label: Some("voxel-gather-uniform"),
971 contents: bytemuck::bytes_of(&uniform),
972 usage: wgpu::BufferUsages::UNIFORM,
973 });
974
975 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
976 label: Some("voxel-gather-bind-group"),
977 layout: &pipelines.voxel_gather.bind_group_layout,
978 entries: &[
979 wgpu::BindGroupEntry { binding: 0, resource: uniform_buffer.as_entire_binding() },
980 wgpu::BindGroupEntry { binding: 1, resource: point_indices.as_entire_binding() },
981 wgpu::BindGroupEntry { binding: 2, resource: values.as_entire_binding() },
982 wgpu::BindGroupEntry { binding: 3, resource: cell_starts.as_entire_binding() },
983 wgpu::BindGroupEntry { binding: 4, resource: output_buffer.as_entire_binding() },
984 ],
985 });
986
987 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
988 label: Some("voxel-gather-pass"),
989 timestamp_writes: None,
990 });
991 pass.set_pipeline(&pipelines.voxel_gather.pipeline);
992 pass.set_bind_group(0, &bind_group, &[]);
993 pass.dispatch_workgroups(cell_count.div_ceil(WORKGROUP_SIZE), 1, 1);
994 Ok(())
995}
996
997pub(crate) fn record_voxel_gather_u8_pass(
998 encoder: &mut wgpu::CommandEncoder,
999 runtime: &WgpuRuntime,
1000 values: &wgpu::Buffer,
1001 point_indices: &wgpu::Buffer,
1002 cell_starts: &wgpu::Buffer,
1003 cell_count: u32,
1004 point_count: u32,
1005 output_buffer: &wgpu::Buffer,
1006) -> SpatialResult<()> {
1007 if cell_count == 0 {
1008 return Ok(());
1009 }
1010
1011 let device = runtime.device();
1012 let pipelines = runtime.pipelines();
1013 let uniform = GatherUniform { cell_count, point_count, channel_count: 1, _pad: 0 };
1014 let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
1015 label: Some("voxel-gather-u8-uniform"),
1016 contents: bytemuck::bytes_of(&uniform),
1017 usage: wgpu::BufferUsages::UNIFORM,
1018 });
1019
1020 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
1021 label: Some("voxel-gather-u8-bind-group"),
1022 layout: &pipelines.voxel_gather.u8_bind_group_layout,
1023 entries: &[
1024 wgpu::BindGroupEntry { binding: 0, resource: uniform_buffer.as_entire_binding() },
1025 wgpu::BindGroupEntry { binding: 1, resource: point_indices.as_entire_binding() },
1026 wgpu::BindGroupEntry { binding: 2, resource: values.as_entire_binding() },
1027 wgpu::BindGroupEntry { binding: 3, resource: cell_starts.as_entire_binding() },
1028 wgpu::BindGroupEntry { binding: 4, resource: output_buffer.as_entire_binding() },
1029 ],
1030 });
1031
1032 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1033 label: Some("voxel-gather-u8-pass"),
1034 timestamp_writes: None,
1035 });
1036 pass.set_pipeline(&pipelines.voxel_gather.u8_pipeline);
1037 pass.set_bind_group(0, &bind_group, &[]);
1038 pass.dispatch_workgroups(cell_count.div_ceil(WORKGROUP_SIZE), 1, 1);
1039 Ok(())
1040}
1041
1042#[allow(clippy::too_many_arguments)]
1043pub(crate) fn record_voxel_gather_xyz_pass(
1044 encoder: &mut wgpu::CommandEncoder,
1045 runtime: &WgpuRuntime,
1046 x: &wgpu::Buffer,
1047 y: &wgpu::Buffer,
1048 z: &wgpu::Buffer,
1049 segments: &GpuVoxelSegments,
1050 output_x: &wgpu::Buffer,
1051 output_y: &wgpu::Buffer,
1052 output_z: &wgpu::Buffer,
1053) -> SpatialResult<()> {
1054 if segments.cell_count() == 0 {
1055 return Ok(());
1056 }
1057
1058 let device = runtime.device();
1059 let pipelines = runtime.pipelines();
1060 let cell_count = segments.cell_count();
1061 let point_count = segments.point_count();
1062 let uniform = GatherUniform { cell_count, point_count, channel_count: 0, _pad: 0 };
1063 let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
1064 label: Some("voxel-gather-xyz-uniform"),
1065 contents: bytemuck::bytes_of(&uniform),
1066 usage: wgpu::BufferUsages::UNIFORM,
1067 });
1068
1069 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
1070 label: Some("voxel-gather-xyz-bind-group"),
1071 layout: &pipelines.voxel_gather.xyz_bind_group_layout,
1072 entries: &[
1073 wgpu::BindGroupEntry { binding: 0, resource: uniform_buffer.as_entire_binding() },
1074 wgpu::BindGroupEntry {
1075 binding: 1,
1076 resource: segments.point_indices_buffer().as_entire_binding(),
1077 },
1078 wgpu::BindGroupEntry {
1079 binding: 2,
1080 resource: segments.cell_starts_buffer().as_entire_binding(),
1081 },
1082 wgpu::BindGroupEntry { binding: 3, resource: x.as_entire_binding() },
1083 wgpu::BindGroupEntry { binding: 4, resource: y.as_entire_binding() },
1084 wgpu::BindGroupEntry { binding: 5, resource: z.as_entire_binding() },
1085 wgpu::BindGroupEntry { binding: 6, resource: output_x.as_entire_binding() },
1086 wgpu::BindGroupEntry { binding: 7, resource: output_y.as_entire_binding() },
1087 wgpu::BindGroupEntry { binding: 8, resource: output_z.as_entire_binding() },
1088 ],
1089 });
1090
1091 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1092 label: Some("voxel-gather-xyz-pass"),
1093 timestamp_writes: None,
1094 });
1095 pass.set_pipeline(&pipelines.voxel_gather.xyz_pipeline);
1096 pass.set_bind_group(0, &bind_group, &[]);
1097 pass.dispatch_workgroups(cell_count.div_ceil(WORKGROUP_SIZE), 1, 1);
1098 Ok(())
1099}
1100
1101fn dispatch_voxel_gather_f32(
1102 runtime: &WgpuRuntime,
1103 values: &wgpu::Buffer,
1104 point_indices: &wgpu::Buffer,
1105 cell_starts: &wgpu::Buffer,
1106 cell_count: u32,
1107 point_count: u32,
1108) -> SpatialResult<Vec<f32>> {
1109 if cell_count == 0 {
1110 return Ok(Vec::new());
1111 }
1112
1113 let device = runtime.device();
1114 let queue = runtime.queue();
1115 let pipelines = runtime.pipelines();
1116
1117 let uniform = GatherUniform { cell_count, point_count, channel_count: 1, _pad: 0 };
1118
1119 let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
1120 label: Some("voxel-gather-uniform"),
1121 contents: bytemuck::bytes_of(&uniform),
1122 usage: wgpu::BufferUsages::UNIFORM,
1123 });
1124
1125 let output_len = cell_count as usize * std::mem::size_of::<f32>();
1126 let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
1127 label: Some("voxel-gather-output"),
1128 size: output_len as u64,
1129 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
1130 mapped_at_creation: false,
1131 });
1132 let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
1133 label: Some("voxel-gather-staging"),
1134 size: output_len as u64,
1135 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
1136 mapped_at_creation: false,
1137 });
1138
1139 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
1140 label: Some("voxel-gather-bind-group"),
1141 layout: &pipelines.voxel_gather.bind_group_layout,
1142 entries: &[
1143 wgpu::BindGroupEntry { binding: 0, resource: uniform_buffer.as_entire_binding() },
1144 wgpu::BindGroupEntry { binding: 1, resource: point_indices.as_entire_binding() },
1145 wgpu::BindGroupEntry { binding: 2, resource: values.as_entire_binding() },
1146 wgpu::BindGroupEntry { binding: 3, resource: cell_starts.as_entire_binding() },
1147 wgpu::BindGroupEntry { binding: 4, resource: output_buffer.as_entire_binding() },
1148 ],
1149 });
1150
1151 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
1152 label: Some("voxel-gather-encoder"),
1153 });
1154
1155 {
1156 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1157 label: Some("voxel-gather-pass"),
1158 timestamp_writes: None,
1159 });
1160 pass.set_pipeline(&pipelines.voxel_gather.pipeline);
1161 pass.set_bind_group(0, &bind_group, &[]);
1162 pass.dispatch_workgroups(cell_count.div_ceil(WORKGROUP_SIZE), 1, 1);
1163 }
1164
1165 encoder.copy_buffer_to_buffer(&output_buffer, 0, &staging_buffer, 0, output_len as u64);
1166 queue.submit(Some(encoder.finish()));
1167
1168 read_staging_f32(device, &staging_buffer, cell_count as usize)
1169}
1170
1171fn dispatch_voxel_gather_multi_gpu_buffers(
1172 runtime: &WgpuRuntime,
1173 values: &[&wgpu::Buffer; MULTI4_CHANNELS],
1174 segments: &GpuVoxelSegments,
1175 channel_count: u32,
1176) -> SpatialResult<Vec<Vec<f32>>> {
1177 let pipelines = runtime.pipelines();
1178 if channel_count > 2 && pipelines.voxel_gather.multi4_pipeline.is_some() {
1179 dispatch_voxel_gather_multi4_gpu_buffers(runtime, values, segments, channel_count)
1180 } else {
1181 if channel_count > MULTI2_CHANNELS as u32 {
1182 return Err(SpatialError::InvalidArgument(format!(
1183 "gpu adapter supports only {} channels per gather dispatch",
1184 pipelines.voxel_gather.multi_max_channels
1185 )));
1186 }
1187 dispatch_voxel_gather_multi2_gpu_buffers(runtime, values, segments, channel_count)
1188 }
1189}
1190
1191fn dispatch_voxel_gather_multi2_gpu_buffers(
1192 runtime: &WgpuRuntime,
1193 values: &[&wgpu::Buffer; MULTI4_CHANNELS],
1194 segments: &GpuVoxelSegments,
1195 channel_count: u32,
1196) -> SpatialResult<Vec<Vec<f32>>> {
1197 let device = runtime.device();
1198 let queue = runtime.queue();
1199 let pipelines = runtime.pipelines();
1200 let cell_count = segments.cell_count();
1201 let point_count = segments.point_count();
1202 let channels = channel_count as usize;
1203
1204 let uniform = GatherUniform { cell_count, point_count, channel_count, _pad: 0 };
1205 let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
1206 label: Some("voxel-gather-multi-uniform"),
1207 contents: bytemuck::bytes_of(&uniform),
1208 usage: wgpu::BufferUsages::UNIFORM,
1209 });
1210
1211 let channel_len = cell_count as usize * std::mem::size_of::<f32>();
1212 let mut outputs = [None, None];
1213 for output in outputs.iter_mut() {
1214 *output = Some(device.create_buffer(&wgpu::BufferDescriptor {
1215 label: Some("voxel-gather-multi-output"),
1216 size: channel_len as u64,
1217 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
1218 mapped_at_creation: false,
1219 }));
1220 }
1221 let output_refs: [&wgpu::Buffer; MULTI2_CHANNELS] =
1222 std::array::from_fn(|index| outputs[index].as_ref().expect("output buffer"));
1223
1224 let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
1225 label: Some("voxel-gather-multi-staging"),
1226 size: (channel_len * channels) as u64,
1227 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
1228 mapped_at_creation: false,
1229 });
1230
1231 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
1232 label: Some("voxel-gather-multi-bind-group"),
1233 layout: &pipelines.voxel_gather.multi_bind_group_layout,
1234 entries: &[
1235 wgpu::BindGroupEntry { binding: 0, resource: uniform_buffer.as_entire_binding() },
1236 wgpu::BindGroupEntry {
1237 binding: 1,
1238 resource: segments.point_indices_buffer().as_entire_binding(),
1239 },
1240 wgpu::BindGroupEntry {
1241 binding: 2,
1242 resource: segments.cell_starts_buffer().as_entire_binding(),
1243 },
1244 wgpu::BindGroupEntry { binding: 3, resource: values[0].as_entire_binding() },
1245 wgpu::BindGroupEntry { binding: 4, resource: values[1].as_entire_binding() },
1246 wgpu::BindGroupEntry { binding: 5, resource: output_refs[0].as_entire_binding() },
1247 wgpu::BindGroupEntry { binding: 6, resource: output_refs[1].as_entire_binding() },
1248 ],
1249 });
1250
1251 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
1252 label: Some("voxel-gather-multi-encoder"),
1253 });
1254 {
1255 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1256 label: Some("voxel-gather-multi-pass"),
1257 timestamp_writes: None,
1258 });
1259 pass.set_pipeline(&pipelines.voxel_gather.multi_pipeline);
1260 pass.set_bind_group(0, &bind_group, &[]);
1261 pass.dispatch_workgroups(cell_count.div_ceil(WORKGROUP_SIZE), 1, 1);
1262 }
1263
1264 for (index, output) in output_refs.iter().take(channels).enumerate() {
1265 let offset = (channel_len * index) as u64;
1266 encoder.copy_buffer_to_buffer(output, 0, &staging_buffer, offset, channel_len as u64);
1267 }
1268 queue.submit(Some(encoder.finish()));
1269
1270 let flat = read_staging_f32(device, &staging_buffer, cell_count as usize * channels)?;
1271 Ok(split_channel_blocks(flat, channels, cell_count as usize))
1272}
1273
1274fn dispatch_voxel_gather_multi4_gpu_buffers(
1275 runtime: &WgpuRuntime,
1276 values: &[&wgpu::Buffer; MULTI4_CHANNELS],
1277 segments: &GpuVoxelSegments,
1278 channel_count: u32,
1279) -> SpatialResult<Vec<Vec<f32>>> {
1280 let device = runtime.device();
1281 let queue = runtime.queue();
1282 let pipelines = runtime.pipelines();
1283 let multi4_pipeline = pipelines.voxel_gather.multi4_pipeline.as_ref().ok_or_else(|| {
1284 SpatialError::InvalidArgument(
1285 "4-channel gather pipeline is unavailable on this gpu adapter".to_owned(),
1286 )
1287 })?;
1288 let multi4_layout =
1289 pipelines.voxel_gather.multi4_bind_group_layout.as_ref().expect("multi4 layout");
1290
1291 let cell_count = segments.cell_count();
1292 let point_count = segments.point_count();
1293 let channels = channel_count as usize;
1294
1295 let uniform = GatherUniform { cell_count, point_count, channel_count, _pad: 0 };
1296 let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
1297 label: Some("voxel-gather-multi4-uniform"),
1298 contents: bytemuck::bytes_of(&uniform),
1299 usage: wgpu::BufferUsages::UNIFORM,
1300 });
1301
1302 let channel_len = cell_count as usize * std::mem::size_of::<f32>();
1303 let mut outputs = [None, None, None, None];
1304 for output in outputs.iter_mut() {
1305 *output = Some(device.create_buffer(&wgpu::BufferDescriptor {
1306 label: Some("voxel-gather-multi4-output"),
1307 size: channel_len as u64,
1308 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
1309 mapped_at_creation: false,
1310 }));
1311 }
1312 let output_refs: [&wgpu::Buffer; MULTI4_CHANNELS] =
1313 std::array::from_fn(|index| outputs[index].as_ref().expect("output buffer"));
1314
1315 let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
1316 label: Some("voxel-gather-multi4-staging"),
1317 size: (channel_len * channels) as u64,
1318 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
1319 mapped_at_creation: false,
1320 });
1321
1322 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
1323 label: Some("voxel-gather-multi4-bind-group"),
1324 layout: multi4_layout,
1325 entries: &[
1326 wgpu::BindGroupEntry { binding: 0, resource: uniform_buffer.as_entire_binding() },
1327 wgpu::BindGroupEntry {
1328 binding: 1,
1329 resource: segments.point_indices_buffer().as_entire_binding(),
1330 },
1331 wgpu::BindGroupEntry {
1332 binding: 2,
1333 resource: segments.cell_starts_buffer().as_entire_binding(),
1334 },
1335 wgpu::BindGroupEntry { binding: 3, resource: values[0].as_entire_binding() },
1336 wgpu::BindGroupEntry { binding: 4, resource: values[1].as_entire_binding() },
1337 wgpu::BindGroupEntry { binding: 5, resource: values[2].as_entire_binding() },
1338 wgpu::BindGroupEntry { binding: 6, resource: values[3].as_entire_binding() },
1339 wgpu::BindGroupEntry { binding: 7, resource: output_refs[0].as_entire_binding() },
1340 wgpu::BindGroupEntry { binding: 8, resource: output_refs[1].as_entire_binding() },
1341 wgpu::BindGroupEntry { binding: 9, resource: output_refs[2].as_entire_binding() },
1342 wgpu::BindGroupEntry { binding: 10, resource: output_refs[3].as_entire_binding() },
1343 ],
1344 });
1345
1346 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
1347 label: Some("voxel-gather-multi4-encoder"),
1348 });
1349 {
1350 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
1351 label: Some("voxel-gather-multi4-pass"),
1352 timestamp_writes: None,
1353 });
1354 pass.set_pipeline(multi4_pipeline);
1355 pass.set_bind_group(0, &bind_group, &[]);
1356 pass.dispatch_workgroups(cell_count.div_ceil(WORKGROUP_SIZE), 1, 1);
1357 }
1358
1359 for (index, output) in output_refs.iter().take(channels).enumerate() {
1360 let offset = (channel_len * index) as u64;
1361 encoder.copy_buffer_to_buffer(output, 0, &staging_buffer, offset, channel_len as u64);
1362 }
1363 queue.submit(Some(encoder.finish()));
1364
1365 let flat = read_staging_f32(device, &staging_buffer, cell_count as usize * channels)?;
1366 Ok(split_channel_blocks(flat, channels, cell_count as usize))
1367}
1368
1369fn empty_storage_buffer(device: &wgpu::Device) -> SpatialResult<wgpu::Buffer> {
1370 Ok(device.create_buffer(&wgpu::BufferDescriptor {
1371 label: Some("voxel-gather-empty"),
1372 size: 4,
1373 usage: wgpu::BufferUsages::STORAGE,
1374 mapped_at_creation: false,
1375 }))
1376}
1377
#[cfg(test)]
mod tests {
    use super::{gather_voxel_first_f32, gather_voxel_first_f32_multi_gpu};
    use crate::kernels::gpu_segments::GpuVoxelSegments;
    use crate::kernels::voxel_segments::build_voxel_segments;
    use crate::kernels::voxel_sort::build_voxel_segments_gpu_from_keys;
    use crate::runtime::WgpuRuntime;

    /// Builds GPU-resident voxel segments for `keys`, panicking on failure.
    fn gpu_segments_from_keys(runtime: &WgpuRuntime, keys: &[(i64, i64, i64)]) -> GpuVoxelSegments {
        build_voxel_segments_gpu_from_keys(runtime, keys).expect("gpu segments")
    }

    /// Asserts that `actual` is within `1e-5` of `expected`.
    #[track_caller]
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < 1e-5);
    }

    /// Two cells of two points each: the gather must return each cell's first value.
    #[test]
    fn gpu_first_gather_matches_cpu_reference() {
        let runtime = WgpuRuntime::new_headless().expect("wgpu runtime");
        let values = [0.2_f32, 0.9, 10.0, 20.0];
        let keys = vec![(0, 0, 0), (0, 0, 0), (2, 0, 0), (2, 0, 0)];
        let segments = build_voxel_segments(&keys);

        let gpu = gather_voxel_first_f32(&runtime, &values, &segments).expect("gpu gather");

        assert_close(gpu[0], 0.2);
        assert_close(gpu[1], 10.0);
    }

    /// The 2-channel gather must agree with the single-channel result per channel.
    #[test]
    fn gpu_multi_gather_matches_single_channel_reference() {
        let runtime = WgpuRuntime::new_headless().expect("wgpu runtime");
        let intensity = [0.2_f32, 0.9, 10.0, 20.0];
        let classification = [1.0_f32, 2.0, 3.0, 4.0];
        let keys = vec![(0, 0, 0), (0, 0, 0), (2, 0, 0), (2, 0, 0)];
        let segments = gpu_segments_from_keys(&runtime, &keys);

        let multi =
            gather_voxel_first_f32_multi_gpu(&runtime, &[&intensity, &classification], &segments)
                .expect("multi gather");

        // Indexed as expected[channel][cell].
        let expected = [[0.2_f32, 10.0], [1.0, 3.0]];
        for (channel, cells) in expected.iter().enumerate() {
            for (cell, &want) in cells.iter().enumerate() {
                assert_close(multi[channel][cell], want);
            }
        }
    }

    /// Exercises the 4-channel path; silently passes on adapters without it.
    #[test]
    fn gpu_multi4_gather_handles_four_channels_when_supported() {
        let runtime = WgpuRuntime::new_headless().expect("wgpu runtime");
        if runtime.max_gather_channels() < 4 {
            return;
        }

        let c0 = [0.2_f32, 0.9, 10.0, 20.0];
        let c1 = [1.0_f32, 2.0, 3.0, 4.0];
        let c2 = [5.0_f32, 6.0, 7.0, 8.0];
        let c3 = [9.0_f32, 10.0, 11.0, 12.0];
        let keys = vec![(0, 0, 0), (0, 0, 0), (2, 0, 0), (2, 0, 0)];
        let segments = gpu_segments_from_keys(&runtime, &keys);

        let multi = gather_voxel_first_f32_multi_gpu(&runtime, &[&c0, &c1, &c2, &c3], &segments)
            .expect("multi4 gather");

        assert_eq!(multi.len(), 4);
        // Indexed as expected[channel][cell].
        let expected = [[0.2_f32, 10.0], [1.0, 3.0], [5.0, 7.0], [9.0, 11.0]];
        for (channel, cells) in expected.iter().enumerate() {
            for (cell, &want) in cells.iter().enumerate() {
                assert_close(multi[channel][cell], want);
            }
        }
    }
}