1use bytemuck::{Pod, Zeroable};
2use spatialrust_core::{SpatialError, SpatialResult};
3use wgpu::util::DeviceExt;
4
5use crate::kernels::gpu_segments::GpuVoxelSegments;
6use crate::kernels::voxel_segments::VoxelSegments;
7use crate::readback::{
8 pad_u8_for_gpu_storage, read_staging_f32, read_staging_f32_and_u8, split_channel_blocks,
9 split_u8_channel_blocks, split_xyz_and_attribute_blocks, split_xyz_blocks,
10 u8_output_staging_bytes, unpack_u8_outputs_from_u32_staging,
11};
12use crate::runtime::WgpuRuntime;
13
/// Divisor that turns a cell count into a workgroup count for every dispatch
/// recorded in this file (`cell_count.div_ceil(WORKGROUP_SIZE)`).
// NOTE(review): presumably mirrors the `@workgroup_size` declared by the
// voxel-reduce shaders — confirm against the shader sources.
const WORKGROUP_SIZE: u32 = 256;
15
/// Parameter block uploaded as a uniform buffer (binding 0) for every
/// voxel-reduce pass recorded in this file.
///
/// `#[repr(C)]` with four `u32` fields gives a fixed 16-byte layout that can
/// be turned into bytes through `bytemuck`.
// NOTE(review): the two padding words presumably exist to match the shader
// side uniform struct's size/alignment — confirm against the WGSL source.
#[repr(C)]
#[derive(Clone, Copy, Debug, Pod, Zeroable)]
struct ReduceUniform {
    /// Number of cells handed to the kernel; also drives the dispatch size.
    cell_count: u32,
    /// Number of points described by the segment tables.
    point_count: u32,
    /// Padding, always written as zero.
    _pad0: u32,
    /// Padding, always written as zero.
    _pad1: u32,
}
24
25pub fn reduce_voxel_average_f32(
27 runtime: &WgpuRuntime,
28 values: &[f32],
29 segments: &VoxelSegments,
30) -> SpatialResult<Vec<f32>> {
31 if segments.is_empty() {
32 return Ok(Vec::new());
33 }
34 if values.is_empty() {
35 return Err(SpatialError::InvalidArgument("cannot reduce empty value buffer".to_owned()));
36 }
37
38 let device = runtime.device();
39 let values_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
40 label: Some("voxel-reduce-values"),
41 contents: bytemuck::cast_slice(values),
42 usage: wgpu::BufferUsages::STORAGE,
43 });
44 let indices_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
45 label: Some("voxel-reduce-indices"),
46 contents: bytemuck::cast_slice(&segments.point_indices),
47 usage: wgpu::BufferUsages::STORAGE,
48 });
49 let starts_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
50 label: Some("voxel-reduce-starts"),
51 contents: bytemuck::cast_slice(&segments.cell_starts),
52 usage: wgpu::BufferUsages::STORAGE,
53 });
54
55 dispatch_voxel_reduce_f32(
56 runtime,
57 &values_buffer,
58 &indices_buffer,
59 &starts_buffer,
60 segments.len() as u32,
61 segments.point_indices.len() as u32,
62 )
63}
64
65pub fn reduce_voxel_average_f32_gpu_buffers(
67 runtime: &WgpuRuntime,
68 values: &wgpu::Buffer,
69 segments: &GpuVoxelSegments,
70) -> SpatialResult<Vec<f32>> {
71 dispatch_voxel_reduce_f32(
72 runtime,
73 values,
74 segments.point_indices_buffer(),
75 segments.cell_starts_buffer(),
76 segments.cell_count(),
77 segments.point_count(),
78 )
79}
80
81pub fn reduce_voxel_average_f32_gpu(
83 runtime: &WgpuRuntime,
84 values: &[f32],
85 segments: &GpuVoxelSegments,
86) -> SpatialResult<Vec<f32>> {
87 if segments.cell_count() == 0 {
88 return Ok(Vec::new());
89 }
90 if values.len() != segments.point_count() as usize {
91 return Err(SpatialError::BufferLengthMismatch {
92 expected: segments.point_count() as usize,
93 found: values.len(),
94 });
95 }
96
97 let device = runtime.device();
98 let values_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
99 label: Some("voxel-reduce-values-upload"),
100 contents: bytemuck::cast_slice(values),
101 usage: wgpu::BufferUsages::STORAGE,
102 });
103 reduce_voxel_average_f32_gpu_buffers(runtime, &values_buffer, segments)
104}
105
106pub fn reduce_voxel_average_f32_multi_gpu(
108 runtime: &WgpuRuntime,
109 channels: &[&[f32]],
110 segments: &GpuVoxelSegments,
111) -> SpatialResult<Vec<Vec<f32>>> {
112 if channels.is_empty() {
113 return Ok(Vec::new());
114 }
115 if segments.cell_count() == 0 {
116 return Ok(vec![Vec::new(); channels.len()]);
117 }
118
119 let point_count = segments.point_count() as usize;
120 for channel in channels {
121 if channel.len() != point_count {
122 return Err(SpatialError::BufferLengthMismatch {
123 expected: point_count,
124 found: channel.len(),
125 });
126 }
127 }
128
129 let device = runtime.device();
130 let queue = runtime.queue();
131 let cell_count = segments.cell_count();
132 let cells = cell_count as usize;
133 let channel_len = cells * std::mem::size_of::<f32>();
134 let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
135 label: Some("voxel-reduce-multi-staging"),
136 size: (channel_len * channels.len()) as u64,
137 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
138 mapped_at_creation: false,
139 });
140
141 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
142 label: Some("voxel-reduce-multi-encoder"),
143 });
144
145 for (channel_index, channel) in channels.iter().enumerate() {
146 let values_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
147 label: Some("voxel-reduce-multi-values"),
148 contents: bytemuck::cast_slice(channel),
149 usage: wgpu::BufferUsages::STORAGE,
150 });
151 let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
152 label: Some("voxel-reduce-multi-output"),
153 size: channel_len as u64,
154 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
155 mapped_at_creation: false,
156 });
157
158 record_voxel_reduce_f32_pass(
159 &mut encoder,
160 runtime,
161 &values_buffer,
162 segments.point_indices_buffer(),
163 segments.cell_starts_buffer(),
164 cell_count,
165 segments.point_count(),
166 &output_buffer,
167 )?;
168
169 encoder.copy_buffer_to_buffer(
170 &output_buffer,
171 0,
172 &staging_buffer,
173 (channel_len * channel_index) as u64,
174 channel_len as u64,
175 );
176 }
177
178 queue.submit(Some(encoder.finish()));
179
180 let flat = read_staging_f32(device, &staging_buffer, cells * channels.len())?;
181 Ok(split_channel_blocks(flat, channels.len(), cells))
182}
183
184pub fn reduce_voxel_centroids_xyz_and_average_multi_gpu(
186 runtime: &WgpuRuntime,
187 x: &wgpu::Buffer,
188 y: &wgpu::Buffer,
189 z: &wgpu::Buffer,
190 attribute_channels: &[&[f32]],
191 u8_attribute_channels: &[&[u8]],
192 segments: &GpuVoxelSegments,
193) -> SpatialResult<(Vec<f32>, Vec<f32>, Vec<f32>, Vec<Vec<f32>>, Vec<Vec<u8>>)> {
194 let attribute_count = attribute_channels.len();
195 let u8_attribute_count = u8_attribute_channels.len();
196 if segments.cell_count() == 0 {
197 return Ok((
198 Vec::new(),
199 Vec::new(),
200 Vec::new(),
201 vec![Vec::new(); attribute_count],
202 vec![Vec::new(); u8_attribute_count],
203 ));
204 }
205
206 let point_count = segments.point_count() as usize;
207 for channel in attribute_channels {
208 if channel.len() != point_count {
209 return Err(SpatialError::BufferLengthMismatch {
210 expected: point_count,
211 found: channel.len(),
212 });
213 }
214 }
215 for channel in u8_attribute_channels {
216 if channel.len() != point_count {
217 return Err(SpatialError::BufferLengthMismatch {
218 expected: point_count,
219 found: channel.len(),
220 });
221 }
222 }
223
224 let device = runtime.device();
225 let queue = runtime.queue();
226 let cell_count = segments.cell_count();
227 let cells = cell_count as usize;
228 let channel_len = cells * std::mem::size_of::<f32>();
229 let f32_channel_count = 3 + attribute_count;
230 let u8_staging_len = u8_output_staging_bytes(cells, u8_attribute_count);
231 let staging_size = channel_len * f32_channel_count + u8_staging_len;
232
233 let output_x = device.create_buffer(&wgpu::BufferDescriptor {
234 label: Some("voxel-reduce-xyz-out-x"),
235 size: channel_len as u64,
236 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
237 mapped_at_creation: false,
238 });
239 let output_y = device.create_buffer(&wgpu::BufferDescriptor {
240 label: Some("voxel-reduce-xyz-out-y"),
241 size: channel_len as u64,
242 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
243 mapped_at_creation: false,
244 });
245 let output_z = device.create_buffer(&wgpu::BufferDescriptor {
246 label: Some("voxel-reduce-xyz-out-z"),
247 size: channel_len as u64,
248 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
249 mapped_at_creation: false,
250 });
251 let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
252 label: Some("voxel-reduce-xyz-attrs-staging"),
253 size: staging_size as u64,
254 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
255 mapped_at_creation: false,
256 });
257
258 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
259 label: Some("voxel-reduce-xyz-attrs-encoder"),
260 });
261
262 record_voxel_reduce_xyz_pass(
263 &mut encoder,
264 runtime,
265 x,
266 y,
267 z,
268 segments.point_indices_buffer(),
269 segments.cell_starts_buffer(),
270 cell_count,
271 segments.point_count(),
272 &output_x,
273 &output_y,
274 &output_z,
275 )?;
276
277 encoder.copy_buffer_to_buffer(&output_x, 0, &staging_buffer, 0, channel_len as u64);
278 encoder.copy_buffer_to_buffer(
279 &output_y,
280 0,
281 &staging_buffer,
282 channel_len as u64,
283 channel_len as u64,
284 );
285 encoder.copy_buffer_to_buffer(
286 &output_z,
287 0,
288 &staging_buffer,
289 (channel_len * 2) as u64,
290 channel_len as u64,
291 );
292
293 for (attribute_index, channel) in attribute_channels.iter().enumerate() {
294 let values_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
295 label: Some("voxel-reduce-xyz-attrs-values"),
296 contents: bytemuck::cast_slice(channel),
297 usage: wgpu::BufferUsages::STORAGE,
298 });
299 let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
300 label: Some("voxel-reduce-xyz-attrs-output"),
301 size: channel_len as u64,
302 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
303 mapped_at_creation: false,
304 });
305
306 record_voxel_reduce_f32_pass(
307 &mut encoder,
308 runtime,
309 &values_buffer,
310 segments.point_indices_buffer(),
311 segments.cell_starts_buffer(),
312 cell_count,
313 segments.point_count(),
314 &output_buffer,
315 )?;
316
317 encoder.copy_buffer_to_buffer(
318 &output_buffer,
319 0,
320 &staging_buffer,
321 (channel_len * (3 + attribute_index)) as u64,
322 channel_len as u64,
323 );
324 }
325
326 let u8_region_offset = (channel_len * f32_channel_count) as u64;
327 for (attribute_index, channel) in u8_attribute_channels.iter().enumerate() {
328 let values_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
329 label: Some("voxel-reduce-xyz-u8-values"),
330 contents: &pad_u8_for_gpu_storage(channel),
331 usage: wgpu::BufferUsages::STORAGE,
332 });
333 let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
334 label: Some("voxel-reduce-xyz-u8-output"),
335 size: (cells * std::mem::size_of::<u32>()) as u64,
336 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
337 mapped_at_creation: false,
338 });
339
340 record_voxel_reduce_u8_pass(
341 &mut encoder,
342 runtime,
343 &values_buffer,
344 segments.point_indices_buffer(),
345 segments.cell_starts_buffer(),
346 cell_count,
347 segments.point_count(),
348 &output_buffer,
349 )?;
350
351 encoder.copy_buffer_to_buffer(
352 &output_buffer,
353 0,
354 &staging_buffer,
355 u8_region_offset + attribute_index as u64 * (cells * std::mem::size_of::<u32>()) as u64,
356 (cells * std::mem::size_of::<u32>()) as u64,
357 );
358 }
359
360 queue.submit(Some(encoder.finish()));
361
362 let (flat, u8_raw) = read_staging_f32_and_u8(
363 device,
364 &staging_buffer,
365 cells * f32_channel_count,
366 u8_staging_len,
367 )?;
368 let u8_flat = if u8_attribute_count == 0 {
369 Vec::new()
370 } else {
371 unpack_u8_outputs_from_u32_staging(u8_raw, cells, u8_attribute_count)
372 };
373 let (out_x, out_y, out_z, attributes) =
374 split_xyz_and_attribute_blocks(flat, attribute_count, cells);
375 Ok((
376 out_x,
377 out_y,
378 out_z,
379 attributes,
380 split_u8_channel_blocks(u8_flat, u8_attribute_count, cells),
381 ))
382}
383
384pub fn reduce_voxel_centroids_xyz_and_gather_first_multi_gpu(
386 runtime: &WgpuRuntime,
387 x: &wgpu::Buffer,
388 y: &wgpu::Buffer,
389 z: &wgpu::Buffer,
390 attribute_channels: &[&[f32]],
391 u8_attribute_channels: &[&[u8]],
392 segments: &GpuVoxelSegments,
393) -> SpatialResult<(Vec<f32>, Vec<f32>, Vec<f32>, Vec<Vec<f32>>, Vec<Vec<u8>>)> {
394 use crate::kernels::voxel_gather::record_gather_f32_attribute_channels_to_staging;
395 use crate::kernels::voxel_gather::record_voxel_gather_u8_pass;
396
397 let attribute_count = attribute_channels.len();
398 let u8_attribute_count = u8_attribute_channels.len();
399 if segments.cell_count() == 0 {
400 return Ok((
401 Vec::new(),
402 Vec::new(),
403 Vec::new(),
404 vec![Vec::new(); attribute_count],
405 vec![Vec::new(); u8_attribute_count],
406 ));
407 }
408
409 let point_count = segments.point_count() as usize;
410 for channel in attribute_channels {
411 if channel.len() != point_count {
412 return Err(SpatialError::BufferLengthMismatch {
413 expected: point_count,
414 found: channel.len(),
415 });
416 }
417 }
418 for channel in u8_attribute_channels {
419 if channel.len() != point_count {
420 return Err(SpatialError::BufferLengthMismatch {
421 expected: point_count,
422 found: channel.len(),
423 });
424 }
425 }
426
427 let device = runtime.device();
428 let queue = runtime.queue();
429 let cell_count = segments.cell_count();
430 let cells = cell_count as usize;
431 let channel_len = cells * std::mem::size_of::<f32>();
432 let f32_channel_count = 3 + attribute_count;
433 let u8_staging_len = u8_output_staging_bytes(cells, u8_attribute_count);
434 let staging_size = channel_len * f32_channel_count + u8_staging_len;
435 let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
436 label: Some("voxel-reduce-xyz-gather-attrs-staging"),
437 size: staging_size as u64,
438 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
439 mapped_at_creation: false,
440 });
441 let output_x = device.create_buffer(&wgpu::BufferDescriptor {
442 label: Some("voxel-reduce-xyz-out-x"),
443 size: channel_len as u64,
444 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
445 mapped_at_creation: false,
446 });
447 let output_y = device.create_buffer(&wgpu::BufferDescriptor {
448 label: Some("voxel-reduce-xyz-out-y"),
449 size: channel_len as u64,
450 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
451 mapped_at_creation: false,
452 });
453 let output_z = device.create_buffer(&wgpu::BufferDescriptor {
454 label: Some("voxel-reduce-xyz-out-z"),
455 size: channel_len as u64,
456 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
457 mapped_at_creation: false,
458 });
459
460 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
461 label: Some("voxel-reduce-xyz-gather-attrs-encoder"),
462 });
463 let mut upload_recycle = Vec::new();
464 record_voxel_reduce_xyz_pass(
465 &mut encoder,
466 runtime,
467 x,
468 y,
469 z,
470 segments.point_indices_buffer(),
471 segments.cell_starts_buffer(),
472 cell_count,
473 segments.point_count(),
474 &output_x,
475 &output_y,
476 &output_z,
477 )?;
478 encoder.copy_buffer_to_buffer(&output_x, 0, &staging_buffer, 0, channel_len as u64);
479 encoder.copy_buffer_to_buffer(
480 &output_y,
481 0,
482 &staging_buffer,
483 channel_len as u64,
484 channel_len as u64,
485 );
486 encoder.copy_buffer_to_buffer(
487 &output_z,
488 0,
489 &staging_buffer,
490 (channel_len * 2) as u64,
491 channel_len as u64,
492 );
493
494 record_gather_f32_attribute_channels_to_staging(
495 &mut encoder,
496 runtime,
497 attribute_channels,
498 segments,
499 &staging_buffer,
500 channel_len as u64,
501 &mut upload_recycle,
502 )?;
503
504 let u8_region_offset = (channel_len * f32_channel_count) as u64;
505 for (attribute_index, channel) in u8_attribute_channels.iter().enumerate() {
506 let values_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
507 label: Some("voxel-gather-xyz-u8-values"),
508 contents: &pad_u8_for_gpu_storage(channel),
509 usage: wgpu::BufferUsages::STORAGE,
510 });
511 let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
512 label: Some("voxel-gather-xyz-u8-output"),
513 size: (cells * std::mem::size_of::<u32>()) as u64,
514 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
515 mapped_at_creation: false,
516 });
517 record_voxel_gather_u8_pass(
518 &mut encoder,
519 runtime,
520 &values_buffer,
521 segments.point_indices_buffer(),
522 segments.cell_starts_buffer(),
523 cell_count,
524 segments.point_count(),
525 &output_buffer,
526 )?;
527 encoder.copy_buffer_to_buffer(
528 &output_buffer,
529 0,
530 &staging_buffer,
531 u8_region_offset + attribute_index as u64 * (cells * std::mem::size_of::<u32>()) as u64,
532 (cells * std::mem::size_of::<u32>()) as u64,
533 );
534 }
535
536 queue.submit(Some(encoder.finish()));
537 let (flat, u8_raw) = read_staging_f32_and_u8(
538 device,
539 &staging_buffer,
540 cells * f32_channel_count,
541 u8_staging_len,
542 )?;
543 let u8_flat = if u8_attribute_count == 0 {
544 Vec::new()
545 } else {
546 unpack_u8_outputs_from_u32_staging(u8_raw, cells, u8_attribute_count)
547 };
548 let (out_x, out_y, out_z, attributes) =
549 split_xyz_and_attribute_blocks(flat, attribute_count, cells);
550 for buffer in upload_recycle {
551 runtime.recycle_storage(buffer.size(), buffer);
552 }
553 Ok((
554 out_x,
555 out_y,
556 out_z,
557 attributes,
558 split_u8_channel_blocks(u8_flat, u8_attribute_count, cells),
559 ))
560}
561
562pub(crate) fn record_voxel_reduce_f32_pass(
563 encoder: &mut wgpu::CommandEncoder,
564 runtime: &WgpuRuntime,
565 values: &wgpu::Buffer,
566 point_indices: &wgpu::Buffer,
567 cell_starts: &wgpu::Buffer,
568 cell_count: u32,
569 point_count: u32,
570 output_buffer: &wgpu::Buffer,
571) -> SpatialResult<()> {
572 if cell_count == 0 {
573 return Ok(());
574 }
575
576 let device = runtime.device();
577 let pipelines = runtime.pipelines();
578 let uniform = ReduceUniform { cell_count, point_count, _pad0: 0, _pad1: 0 };
579 let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
580 label: Some("voxel-reduce-uniform"),
581 contents: bytemuck::bytes_of(&uniform),
582 usage: wgpu::BufferUsages::UNIFORM,
583 });
584
585 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
586 label: Some("voxel-reduce-bind-group"),
587 layout: &pipelines.voxel_reduce.bind_group_layout,
588 entries: &[
589 wgpu::BindGroupEntry { binding: 0, resource: uniform_buffer.as_entire_binding() },
590 wgpu::BindGroupEntry { binding: 1, resource: point_indices.as_entire_binding() },
591 wgpu::BindGroupEntry { binding: 2, resource: values.as_entire_binding() },
592 wgpu::BindGroupEntry { binding: 3, resource: cell_starts.as_entire_binding() },
593 wgpu::BindGroupEntry { binding: 4, resource: output_buffer.as_entire_binding() },
594 ],
595 });
596
597 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
598 label: Some("voxel-reduce-pass"),
599 timestamp_writes: None,
600 });
601 pass.set_pipeline(&pipelines.voxel_reduce.pipeline);
602 pass.set_bind_group(0, &bind_group, &[]);
603 pass.dispatch_workgroups(cell_count.div_ceil(WORKGROUP_SIZE), 1, 1);
604 Ok(())
605}
606
607pub(crate) fn record_voxel_reduce_u8_pass(
608 encoder: &mut wgpu::CommandEncoder,
609 runtime: &WgpuRuntime,
610 values: &wgpu::Buffer,
611 point_indices: &wgpu::Buffer,
612 cell_starts: &wgpu::Buffer,
613 cell_count: u32,
614 point_count: u32,
615 output_buffer: &wgpu::Buffer,
616) -> SpatialResult<()> {
617 if cell_count == 0 {
618 return Ok(());
619 }
620
621 let device = runtime.device();
622 let pipelines = runtime.pipelines();
623 let uniform = ReduceUniform { cell_count, point_count, _pad0: 0, _pad1: 0 };
624 let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
625 label: Some("voxel-reduce-u8-uniform"),
626 contents: bytemuck::bytes_of(&uniform),
627 usage: wgpu::BufferUsages::UNIFORM,
628 });
629
630 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
631 label: Some("voxel-reduce-u8-bind-group"),
632 layout: &pipelines.voxel_reduce.u8_bind_group_layout,
633 entries: &[
634 wgpu::BindGroupEntry { binding: 0, resource: uniform_buffer.as_entire_binding() },
635 wgpu::BindGroupEntry { binding: 1, resource: point_indices.as_entire_binding() },
636 wgpu::BindGroupEntry { binding: 2, resource: values.as_entire_binding() },
637 wgpu::BindGroupEntry { binding: 3, resource: cell_starts.as_entire_binding() },
638 wgpu::BindGroupEntry { binding: 4, resource: output_buffer.as_entire_binding() },
639 ],
640 });
641
642 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
643 label: Some("voxel-reduce-u8-pass"),
644 timestamp_writes: None,
645 });
646 pass.set_pipeline(&pipelines.voxel_reduce.u8_pipeline);
647 pass.set_bind_group(0, &bind_group, &[]);
648 pass.dispatch_workgroups(cell_count.div_ceil(WORKGROUP_SIZE), 1, 1);
649 Ok(())
650}
651
652#[allow(clippy::too_many_arguments)]
653pub(crate) fn record_voxel_reduce_xyz_pass(
654 encoder: &mut wgpu::CommandEncoder,
655 runtime: &WgpuRuntime,
656 values_x: &wgpu::Buffer,
657 values_y: &wgpu::Buffer,
658 values_z: &wgpu::Buffer,
659 point_indices: &wgpu::Buffer,
660 cell_starts: &wgpu::Buffer,
661 cell_count: u32,
662 point_count: u32,
663 output_x: &wgpu::Buffer,
664 output_y: &wgpu::Buffer,
665 output_z: &wgpu::Buffer,
666) -> SpatialResult<()> {
667 if cell_count == 0 {
668 return Ok(());
669 }
670
671 let device = runtime.device();
672 let pipelines = runtime.pipelines();
673 let uniform = ReduceUniform { cell_count, point_count, _pad0: 0, _pad1: 0 };
674 let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
675 label: Some("voxel-reduce-xyz-uniform"),
676 contents: bytemuck::bytes_of(&uniform),
677 usage: wgpu::BufferUsages::UNIFORM,
678 });
679
680 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
681 label: Some("voxel-reduce-xyz-bind-group"),
682 layout: &pipelines.voxel_reduce.xyz_bind_group_layout,
683 entries: &[
684 wgpu::BindGroupEntry { binding: 0, resource: uniform_buffer.as_entire_binding() },
685 wgpu::BindGroupEntry { binding: 1, resource: point_indices.as_entire_binding() },
686 wgpu::BindGroupEntry { binding: 2, resource: cell_starts.as_entire_binding() },
687 wgpu::BindGroupEntry { binding: 3, resource: values_x.as_entire_binding() },
688 wgpu::BindGroupEntry { binding: 4, resource: values_y.as_entire_binding() },
689 wgpu::BindGroupEntry { binding: 5, resource: values_z.as_entire_binding() },
690 wgpu::BindGroupEntry { binding: 6, resource: output_x.as_entire_binding() },
691 wgpu::BindGroupEntry { binding: 7, resource: output_y.as_entire_binding() },
692 wgpu::BindGroupEntry { binding: 8, resource: output_z.as_entire_binding() },
693 ],
694 });
695
696 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
697 label: Some("voxel-reduce-xyz-pass"),
698 timestamp_writes: None,
699 });
700 pass.set_pipeline(&pipelines.voxel_reduce.xyz_pipeline);
701 pass.set_bind_group(0, &bind_group, &[]);
702 pass.dispatch_workgroups(cell_count.div_ceil(WORKGROUP_SIZE), 1, 1);
703 Ok(())
704}
705
706fn dispatch_voxel_reduce_f32(
707 runtime: &WgpuRuntime,
708 values: &wgpu::Buffer,
709 point_indices: &wgpu::Buffer,
710 cell_starts: &wgpu::Buffer,
711 cell_count: u32,
712 point_count: u32,
713) -> SpatialResult<Vec<f32>> {
714 if cell_count == 0 {
715 return Ok(Vec::new());
716 }
717
718 let device = runtime.device();
719 let queue = runtime.queue();
720
721 let uniform = ReduceUniform { cell_count, point_count, _pad0: 0, _pad1: 0 };
722
723 let uniform_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
724 label: Some("voxel-reduce-uniform"),
725 contents: bytemuck::bytes_of(&uniform),
726 usage: wgpu::BufferUsages::UNIFORM,
727 });
728
729 let output_len = cell_count as usize * std::mem::size_of::<f32>();
730 let output_buffer = device.create_buffer(&wgpu::BufferDescriptor {
731 label: Some("voxel-reduce-output"),
732 size: output_len as u64,
733 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
734 mapped_at_creation: false,
735 });
736 let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
737 label: Some("voxel-reduce-staging"),
738 size: output_len as u64,
739 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
740 mapped_at_creation: false,
741 });
742
743 let pipelines = runtime.pipelines();
744
745 let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
746 label: Some("voxel-reduce-bind-group"),
747 layout: &pipelines.voxel_reduce.bind_group_layout,
748 entries: &[
749 wgpu::BindGroupEntry { binding: 0, resource: uniform_buffer.as_entire_binding() },
750 wgpu::BindGroupEntry { binding: 1, resource: point_indices.as_entire_binding() },
751 wgpu::BindGroupEntry { binding: 2, resource: values.as_entire_binding() },
752 wgpu::BindGroupEntry { binding: 3, resource: cell_starts.as_entire_binding() },
753 wgpu::BindGroupEntry { binding: 4, resource: output_buffer.as_entire_binding() },
754 ],
755 });
756
757 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
758 label: Some("voxel-reduce-encoder"),
759 });
760
761 {
762 let mut pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
763 label: Some("voxel-reduce-pass"),
764 timestamp_writes: None,
765 });
766 pass.set_pipeline(&pipelines.voxel_reduce.pipeline);
767 pass.set_bind_group(0, &bind_group, &[]);
768 pass.dispatch_workgroups(cell_count.div_ceil(WORKGROUP_SIZE), 1, 1);
769 }
770
771 encoder.copy_buffer_to_buffer(&output_buffer, 0, &staging_buffer, 0, output_len as u64);
772 queue.submit(Some(encoder.finish()));
773
774 read_staging_f32(device, &staging_buffer, cell_count as usize)
775}
776
777pub fn reduce_voxel_centroids_xyz(
779 runtime: &WgpuRuntime,
780 x: &[f32],
781 y: &[f32],
782 z: &[f32],
783 segments: &VoxelSegments,
784) -> SpatialResult<(Vec<f32>, Vec<f32>, Vec<f32>)> {
785 if x.len() != y.len() || x.len() != z.len() {
786 return Err(SpatialError::BufferLengthMismatch { expected: x.len(), found: y.len() });
787 }
788
789 let out_x = reduce_voxel_average_f32(runtime, x, segments)?;
790 let out_y = reduce_voxel_average_f32(runtime, y, segments)?;
791 let out_z = reduce_voxel_average_f32(runtime, z, segments)?;
792 Ok((out_x, out_y, out_z))
793}
794
795pub fn reduce_voxel_centroids_xyz_gpu_buffers(
797 runtime: &WgpuRuntime,
798 x: &wgpu::Buffer,
799 y: &wgpu::Buffer,
800 z: &wgpu::Buffer,
801 segments: &GpuVoxelSegments,
802) -> SpatialResult<(Vec<f32>, Vec<f32>, Vec<f32>)> {
803 dispatch_voxel_reduce_xyz_f32(
804 runtime,
805 x,
806 y,
807 z,
808 segments.point_indices_buffer(),
809 segments.cell_starts_buffer(),
810 segments.cell_count(),
811 segments.point_count(),
812 )
813}
814
815fn dispatch_voxel_reduce_xyz_f32(
816 runtime: &WgpuRuntime,
817 values_x: &wgpu::Buffer,
818 values_y: &wgpu::Buffer,
819 values_z: &wgpu::Buffer,
820 point_indices: &wgpu::Buffer,
821 cell_starts: &wgpu::Buffer,
822 cell_count: u32,
823 point_count: u32,
824) -> SpatialResult<(Vec<f32>, Vec<f32>, Vec<f32>)> {
825 if cell_count == 0 {
826 return Ok((Vec::new(), Vec::new(), Vec::new()));
827 }
828
829 let device = runtime.device();
830 let queue = runtime.queue();
831 let channel_len = cell_count as usize * std::mem::size_of::<f32>();
832 let output_x = device.create_buffer(&wgpu::BufferDescriptor {
833 label: Some("voxel-reduce-xyz-out-x"),
834 size: channel_len as u64,
835 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
836 mapped_at_creation: false,
837 });
838 let output_y = device.create_buffer(&wgpu::BufferDescriptor {
839 label: Some("voxel-reduce-xyz-out-y"),
840 size: channel_len as u64,
841 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
842 mapped_at_creation: false,
843 });
844 let output_z = device.create_buffer(&wgpu::BufferDescriptor {
845 label: Some("voxel-reduce-xyz-out-z"),
846 size: channel_len as u64,
847 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
848 mapped_at_creation: false,
849 });
850 let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
851 label: Some("voxel-reduce-xyz-staging"),
852 size: (channel_len * 3) as u64,
853 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
854 mapped_at_creation: false,
855 });
856
857 let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
858 label: Some("voxel-reduce-xyz-encoder"),
859 });
860 record_voxel_reduce_xyz_pass(
861 &mut encoder,
862 runtime,
863 values_x,
864 values_y,
865 values_z,
866 point_indices,
867 cell_starts,
868 cell_count,
869 point_count,
870 &output_x,
871 &output_y,
872 &output_z,
873 )?;
874 encoder.copy_buffer_to_buffer(&output_x, 0, &staging_buffer, 0, channel_len as u64);
875 encoder.copy_buffer_to_buffer(
876 &output_y,
877 0,
878 &staging_buffer,
879 channel_len as u64,
880 channel_len as u64,
881 );
882 encoder.copy_buffer_to_buffer(
883 &output_z,
884 0,
885 &staging_buffer,
886 (channel_len * 2) as u64,
887 channel_len as u64,
888 );
889 queue.submit(Some(encoder.finish()));
890
891 let flat = read_staging_f32(device, &staging_buffer, cell_count as usize * 3)?;
892 Ok(split_xyz_blocks(flat, cell_count as usize))
893}
894
#[cfg(test)]
mod tests {
    use super::{
        reduce_voxel_average_f32, reduce_voxel_average_f32_multi_gpu, reduce_voxel_centroids_xyz,
        reduce_voxel_centroids_xyz_and_average_multi_gpu,
        reduce_voxel_centroids_xyz_and_gather_first_multi_gpu,
    };
    use crate::kernels::voxel_segments::build_voxel_segments;
    use crate::runtime::WgpuRuntime;

    // Shared fixture: four points on the x axis that the tests expect to
    // fall into two cells of two points each.
    const POINTS_X: [f32; 4] = [0.0, 0.1, 1.0, 1.1];
    const POINTS_Y: [f32; 4] = [0.0, 0.0, 0.0, 0.0];
    const POINTS_Z: [f32; 4] = [0.0, 0.0, 0.0, 0.0];
    const INTENSITY: [f32; 4] = [0.2, 0.8, 10.0, 20.0];

    /// True when `actual` is within the tolerance used by every float check here.
    fn close(actual: f32, expected: f32) -> bool {
        (actual - expected).abs() < 1e-5
    }

    /// Host-side segments: per-axis centroids and a single averaged channel.
    #[test]
    fn gpu_centroid_reduction_matches_cpu_reference() {
        let runtime = WgpuRuntime::new_headless().expect("wgpu runtime");
        let keys = vec![(0, 0, 0), (0, 0, 0), (2, 0, 0), (2, 0, 0)];
        let segments = build_voxel_segments(&keys);

        let (gpu_x, gpu_y, gpu_z) =
            reduce_voxel_centroids_xyz(&runtime, &POINTS_X, &POINTS_Y, &POINTS_Z, &segments)
                .expect("gpu reduce");

        assert!(close(gpu_x[0], 0.05));
        assert!(close(gpu_x[1], 1.05));
        assert_eq!(gpu_y, vec![0.0, 0.0]);
        assert_eq!(gpu_z, vec![0.0, 0.0]);

        let gpu_i = reduce_voxel_average_f32(&runtime, &INTENSITY, &segments).expect("gpu average");
        assert!(close(gpu_i[0], 0.5));
        assert!(close(gpu_i[1], 15.0));
    }

    /// GPU segments built from keys: two channels reduced in one submission.
    #[test]
    fn gpu_multi_reduce_matches_single_channel_reference() {
        use crate::kernels::voxel_sort::build_voxel_segments_gpu_from_keys;

        let runtime = WgpuRuntime::new_headless().expect("wgpu runtime");
        let classification = [1.0_f32, 2.0, 3.0, 4.0];
        let keys = vec![(0, 0, 0), (0, 0, 0), (2, 0, 0), (2, 0, 0)];
        let segments = build_voxel_segments_gpu_from_keys(&runtime, &keys).expect("gpu segments");

        let multi =
            reduce_voxel_average_f32_multi_gpu(&runtime, &[&INTENSITY, &classification], &segments)
                .expect("multi reduce");

        let expected = [[0.5_f32, 15.0], [1.5, 3.5]];
        for (channel, wanted) in multi.iter().zip(expected.iter()) {
            assert!(close(channel[0], wanted[0]));
            assert!(close(channel[1], wanted[1]));
        }
    }

    /// Unified readback with one f32 attribute: averaged, then gathered.
    #[test]
    fn unified_xyz_and_attribute_readback_matches_staged_reference() {
        use crate::kernels::voxel_keys::compute_voxel_keys_gpu_buffers;
        use crate::kernels::voxel_sort::build_voxel_segments_gpu_from_keys_buffer;

        let runtime = WgpuRuntime::new_headless().expect("wgpu runtime");
        let positions = compute_voxel_keys_gpu_buffers(
            &runtime, &POINTS_X, &POINTS_Y, &POINTS_Z, [0.0; 3], 2.0,
        )
        .expect("keys");
        let segments = build_voxel_segments_gpu_from_keys_buffer(
            &runtime,
            positions.keys_buffer(),
            positions.point_count(),
            4,
        )
        .expect("segments");

        let (out_x, out_y, out_z, attrs, _) = reduce_voxel_centroids_xyz_and_average_multi_gpu(
            &runtime,
            positions.x_buffer(),
            positions.y_buffer(),
            positions.z_buffer(),
            &[&INTENSITY],
            &[],
            &segments,
        )
        .expect("unified reduce");

        assert!(close(out_x[0], 0.05));
        assert!(close(out_x[1], 1.05));
        assert_eq!(out_y, vec![0.0, 0.0]);
        assert_eq!(out_z, vec![0.0, 0.0]);
        assert!(close(attrs[0][0], 0.5));
        assert!(close(attrs[0][1], 15.0));

        let (_, _, _, gathered, _) = reduce_voxel_centroids_xyz_and_gather_first_multi_gpu(
            &runtime,
            positions.x_buffer(),
            positions.y_buffer(),
            positions.z_buffer(),
            &[&INTENSITY],
            &[],
            &segments,
        )
        .expect("unified gather attrs");
        assert!(close(gathered[0][0], 0.2));
        assert!(close(gathered[0][1], 10.0));
    }

    /// Unified readback with two u8 attributes: averaged, then gathered.
    #[test]
    fn unified_xyz_and_u8_attribute_readback_matches_reference() {
        use crate::kernels::voxel_keys::compute_voxel_keys_gpu_buffers;
        use crate::kernels::voxel_sort::build_voxel_segments_gpu_from_keys_buffer;

        let runtime = WgpuRuntime::new_headless().expect("wgpu runtime");
        let red = [10_u8, 20, 100, 200];
        let green = [30_u8, 40, 50, 60];
        let positions = compute_voxel_keys_gpu_buffers(
            &runtime, &POINTS_X, &POINTS_Y, &POINTS_Z, [0.0; 3], 2.0,
        )
        .expect("keys");
        let segments = build_voxel_segments_gpu_from_keys_buffer(
            &runtime,
            positions.keys_buffer(),
            positions.point_count(),
            4,
        )
        .expect("segments");

        let (_, _, _, _, u8_attrs) = reduce_voxel_centroids_xyz_and_average_multi_gpu(
            &runtime,
            positions.x_buffer(),
            positions.y_buffer(),
            positions.z_buffer(),
            &[],
            &[&red, &green],
            &segments,
        )
        .expect("unified u8 reduce");

        assert_eq!(u8_attrs[0], vec![15, 150]);
        assert_eq!(u8_attrs[1], vec![35, 55]);

        let (_, _, _, _, gathered_u8) = reduce_voxel_centroids_xyz_and_gather_first_multi_gpu(
            &runtime,
            positions.x_buffer(),
            positions.y_buffer(),
            positions.z_buffer(),
            &[],
            &[&red, &green],
            &segments,
        )
        .expect("unified u8 gather");
        assert_eq!(gathered_u8[0], vec![10, 100]);
        assert_eq!(gathered_u8[1], vec![30, 50]);
    }
}