spatialrust_gpu/
runtime.rs1use std::sync::{Arc, OnceLock};
2
3use spatialrust_core::{SpatialError, SpatialResult};
4
5use crate::pipeline_cache::ComputePipelineCache;
6use crate::upload_cache::GpuUploadPool;
7
#[cfg(feature = "gpu-wgpu")]
/// Bundle of wgpu handles and per-device caches used for compute work.
///
/// Field order is significant: Rust drops fields in declaration order, so the
/// order below must not be rearranged casually.
pub struct WgpuRuntime {
    /// Instance the adapter was requested from. It is never read after
    /// construction (hence the underscore) — presumably retained so it lives
    /// as long as the device; confirm before removing.
    _instance: wgpu::Instance,
    /// Logical device returned by the adapter's `request_device`.
    device: wgpu::Device,
    /// Queue returned together with `device`.
    queue: wgpu::Queue,
    /// Pipeline cache, left empty at construction and filled on the first
    /// call to `pipelines()`.
    pipelines: OnceLock<ComputePipelineCache>,
    /// Channel count derived at construction from the device's
    /// `max_storage_buffers_per_shader_stage` limit.
    max_gather_channels: u32,
    /// Pool that `upload_f32_storage` and `recycle_storage` delegate to.
    upload_pool: GpuUploadPool,
}
18
#[cfg(feature = "gpu-wgpu")]
/// Minimum `max_storage_buffers_per_shader_stage` value at which
/// `max_gather_channels_for_limit` reports four gather channels.
pub const MULTI_GATHER4_STORAGE_BUFFERS: u32 = 10;
22
#[cfg(feature = "gpu-wgpu")]
/// Minimum `max_storage_buffers_per_shader_stage` value at which
/// `max_gather_channels_for_limit` reports two gather channels (when the
/// four-channel threshold is not met).
pub const MULTI_GATHER2_STORAGE_BUFFERS: u32 = 6;
26
#[cfg(feature = "gpu-wgpu")]
/// Process-wide slot behind `WgpuRuntime::shared`.
///
/// The slot is written at most once. It stores either the shared runtime or
/// the stringified initialization error, so a failed first attempt is kept
/// and reported again instead of being retried.
static SHARED_RUNTIME: OnceLock<Result<Arc<WgpuRuntime>, String>> = OnceLock::new();
29
#[cfg(feature = "gpu-wgpu")]
impl WgpuRuntime {
    /// Creates a standalone runtime with no presentation surface.
    ///
    /// Blocks the calling thread (via `pollster`) until the asynchronous
    /// adapter and device requests have completed.
    ///
    /// # Errors
    ///
    /// Returns [`SpatialError::InvalidArgument`] when no adapter is found or
    /// the device request fails.
    pub fn new_headless() -> SpatialResult<Self> {
        let setup = Self::new_headless_async();
        pollster::block_on(setup)
    }

    /// Returns the process-wide runtime, initializing it on first use.
    ///
    /// The result of the first initialization is stored in `SHARED_RUNTIME`,
    /// so every call hands out a clone of the same `Arc` — or, if that first
    /// attempt failed, the same error message.
    ///
    /// # Errors
    ///
    /// Returns [`SpatialError::InvalidArgument`] carrying the stored message
    /// when initialization failed.
    pub fn shared() -> SpatialResult<Arc<Self>> {
        SHARED_RUNTIME
            .get_or_init(init_shared_runtime)
            .as_ref()
            .map(Arc::clone)
            .map_err(|message| SpatialError::InvalidArgument(message.clone()))
    }

    /// Borrows the wgpu device owned by this runtime.
    #[must_use]
    pub fn device(&self) -> &wgpu::Device {
        &self.device
    }

    /// Borrows the wgpu queue owned by this runtime.
    #[must_use]
    pub fn queue(&self) -> &wgpu::Queue {
        &self.queue
    }

    /// Borrows the compute pipeline cache, constructing it from the device on
    /// the first call and reusing that instance afterwards.
    #[must_use]
    pub fn pipelines(&self) -> &ComputePipelineCache {
        let build_cache = || ComputePipelineCache::new(&self.device);
        self.pipelines.get_or_init(build_cache)
    }

    /// Gather channel count selected for this device at construction time.
    #[must_use]
    pub fn max_gather_channels(&self) -> u32 {
        self.max_gather_channels
    }

    /// Forwards `label` and `data` to the upload pool's `upload_f32_storage`,
    /// passing this runtime along, and returns the pool's result unchanged.
    pub fn upload_f32_storage(
        &self,
        label: &'static str,
        data: &[f32],
    ) -> SpatialResult<wgpu::Buffer> {
        let pool = &self.upload_pool;
        pool.upload_f32_storage(self, label, data)
    }

    /// Forwards `buffer` and its `byte_len` to the upload pool's
    /// `recycle_storage`.
    pub fn recycle_storage(&self, byte_len: u64, buffer: wgpu::Buffer) {
        let pool = &self.upload_pool;
        pool.recycle_storage(byte_len, buffer);
    }

    /// Asynchronous constructor behind [`Self::new_headless`].
    ///
    /// Creates an instance restricted to the primary backends, requests an
    /// adapter without a surface, then requests a device using the adapter's
    /// own limits as the required limits.
    async fn new_headless_async() -> SpatialResult<Self> {
        let instance_desc = wgpu::InstanceDescriptor {
            backends: wgpu::Backends::PRIMARY,
            ..Default::default()
        };
        let instance = wgpu::Instance::new(&instance_desc);

        // NOTE(review): `LowPower` is kept as-is. For a compute runtime that
        // also asks for `MemoryHints::Performance`, confirm that preferring
        // the low-power adapter is intentional.
        let adapter_options = wgpu::RequestAdapterOptions {
            power_preference: wgpu::PowerPreference::LowPower,
            compatible_surface: None,
            force_fallback_adapter: false,
        };
        let Some(adapter) = instance.request_adapter(&adapter_options).await else {
            return Err(SpatialError::InvalidArgument(
                "no compatible wgpu adapter found for headless compute".to_owned(),
            ));
        };

        // Require everything the adapter offers so the device exposes the
        // adapter's full storage-buffer limit, which is read back below.
        let device_desc = wgpu::DeviceDescriptor {
            label: Some("spatialrust-wgpu"),
            required_features: wgpu::Features::empty(),
            required_limits: adapter.limits(),
            memory_hints: wgpu::MemoryHints::Performance,
        };
        let (device, queue) = match adapter.request_device(&device_desc, None).await {
            Ok(handles) => handles,
            Err(error) => {
                return Err(SpatialError::InvalidArgument(format!(
                    "failed to create wgpu device: {error}"
                )));
            }
        };

        // Must be computed before `device` is moved into the struct.
        let storage_buffer_limit = device.limits().max_storage_buffers_per_shader_stage;
        let max_gather_channels = max_gather_channels_for_limit(storage_buffer_limit);

        Ok(Self {
            _instance: instance,
            device,
            queue,
            pipelines: OnceLock::new(),
            max_gather_channels,
            upload_pool: GpuUploadPool::default(),
        })
    }
}
132
#[cfg(feature = "gpu-wgpu")]
/// Maps a per-stage storage-buffer limit to a gather channel count.
///
/// Returns 4 when the limit reaches `MULTI_GATHER4_STORAGE_BUFFERS`, 2 when it
/// reaches `MULTI_GATHER2_STORAGE_BUFFERS`, and 1 otherwise.
fn max_gather_channels_for_limit(storage_buffers_per_stage: u32) -> u32 {
    match storage_buffers_per_stage {
        limit if limit >= MULTI_GATHER4_STORAGE_BUFFERS => 4,
        limit if limit >= MULTI_GATHER2_STORAGE_BUFFERS => 2,
        _ => 1,
    }
}
143
#[cfg(feature = "gpu-wgpu")]
/// Initializer passed to `SHARED_RUNTIME.get_or_init`.
///
/// Builds a headless runtime and wraps it in an `Arc`; on failure the error
/// is converted to its string form so it can be stored in the static slot.
fn init_shared_runtime() -> Result<Arc<WgpuRuntime>, String> {
    match WgpuRuntime::new_headless() {
        Ok(runtime) => Ok(Arc::new(runtime)),
        Err(error) => Err(error.to_string()),
    }
}
148
#[cfg(all(feature = "gpu-wgpu", test))]
mod tests {
    use super::{WgpuRuntime, MULTI_GATHER2_STORAGE_BUFFERS};
    use crate::pipeline_cache::ComputePipelineCache;
    use std::sync::Arc;

    // Two calls to `shared()` must hand back the very same allocation.
    #[test]
    fn shared_runtime_is_singleton() {
        let first = WgpuRuntime::shared().expect("shared runtime");
        let second = WgpuRuntime::shared().expect("shared runtime");
        let same_allocation = Arc::ptr_eq(&first, &second);
        assert!(same_allocation);
    }

    // The shared runtime and a freshly built one must report the same
    // per-stage storage-buffer limit.
    #[test]
    fn shared_and_headless_use_same_device_type() {
        let shared = WgpuRuntime::shared().expect("shared runtime");
        let local = WgpuRuntime::new_headless().expect("local runtime");

        let shared_limit = shared.device().limits().max_storage_buffers_per_shader_stage;
        let local_limit = local.device().limits().max_storage_buffers_per_shader_stage;
        assert_eq!(shared_limit, local_limit);
    }

    // Repeated `pipelines()` calls must return a reference to one cache.
    #[test]
    fn pipeline_cache_is_initialized_once_per_runtime() {
        let runtime = WgpuRuntime::new_headless().expect("wgpu runtime");
        let first: *const ComputePipelineCache = runtime.pipelines();
        let second: *const ComputePipelineCache = runtime.pipelines();
        assert_eq!(first, second);
    }

    // The adapter under test must allow at least the two-channel gather, and
    // the runtime's channel count must agree with the pipeline cache.
    #[test]
    fn adapter_limits_enable_multi_channel_gather() {
        let runtime = WgpuRuntime::new_headless().expect("wgpu runtime");
        let limit = runtime.device().limits().max_storage_buffers_per_shader_stage;
        let required = MULTI_GATHER2_STORAGE_BUFFERS;
        assert!(
            limit >= required,
            "expected at least {} storage buffers per stage, got {limit}",
            required
        );

        let channels = runtime.max_gather_channels();
        assert!(channels >= 2);
        assert_eq!(channels, runtime.pipelines().voxel_gather.multi_max_channels);
    }
}