Skip to main content

spatialrust_gpu/
runtime.rs

1use std::sync::{Arc, OnceLock};
2
3use spatialrust_core::{SpatialError, SpatialResult};
4
5use crate::pipeline_cache::ComputePipelineCache;
6use crate::upload_cache::GpuUploadPool;
7
/// Compute-only wgpu runtime that is created without any presentation surface.
///
/// Bundles the device/queue pair together with the lazily built pipeline cache
/// and the pooled upload buffers that operate on that device.
#[cfg(feature = "gpu-wgpu")]
pub struct WgpuRuntime {
    /// Instance the adapter was requested from. It is stored but never read by
    /// this type — presumably held so it outlives the device; confirm.
    _instance: wgpu::Instance,
    /// Logical device exposed through [`WgpuRuntime::device`].
    device: wgpu::Device,
    /// Submission queue exposed through [`WgpuRuntime::queue`].
    queue: wgpu::Queue,
    /// Pipeline cache, populated on the first [`WgpuRuntime::pipelines`] call.
    pipelines: OnceLock<ComputePipelineCache>,
    /// Gather channel count selected at construction time from the device's
    /// `max_storage_buffers_per_shader_stage` limit.
    max_gather_channels: u32,
    /// Pool backing [`WgpuRuntime::upload_f32_storage`] and
    /// [`WgpuRuntime::recycle_storage`].
    upload_pool: GpuUploadPool,
}
18
/// Storage buffers the 4-channel gather kernel needs in a single shader stage.
///
/// A device whose `max_storage_buffers_per_shader_stage` limit is at least this
/// value is reported as supporting four gather channels per dispatch.
#[cfg(feature = "gpu-wgpu")]
pub const MULTI_GATHER4_STORAGE_BUFFERS: u32 = 10;
22
/// Storage buffers the 2-channel gather kernel needs in a single shader stage.
///
/// A device that cannot satisfy the 4-channel requirement but whose
/// `max_storage_buffers_per_shader_stage` limit is at least this value is
/// reported as supporting two gather channels per dispatch.
#[cfg(feature = "gpu-wgpu")]
pub const MULTI_GATHER2_STORAGE_BUFFERS: u32 = 6;
26
/// Slot for the process-wide runtime handed out by [`WgpuRuntime::shared`].
///
/// The outcome of the first initialization attempt is stored as-is: a failure
/// is kept as its message string and returned on every later call, so
/// initialization is not retried.
#[cfg(feature = "gpu-wgpu")]
static SHARED_RUNTIME: OnceLock<Result<Arc<WgpuRuntime>, String>> = OnceLock::new();
29
#[cfg(feature = "gpu-wgpu")]
impl WgpuRuntime {
    /// Builds a fresh headless runtime on the default adapter, blocking the
    /// calling thread until the adapter and device requests have completed.
    ///
    /// Use [`Self::shared`] instead when several GPU filters run in the same
    /// process, so that they all work against one runtime.
    ///
    /// # Errors
    ///
    /// Returns [`SpatialError::InvalidArgument`] when no adapter is found or
    /// when the device request fails.
    pub fn new_headless() -> SpatialResult<Self> {
        let setup = Self::new_headless_async();
        pollster::block_on(setup)
    }

    /// Hands out the process-wide headless runtime, creating it on the first
    /// call.
    ///
    /// # Errors
    ///
    /// Returns [`SpatialError::InvalidArgument`] carrying the message of the
    /// first initialization failure; that failure is cached, so later calls
    /// report the same message.
    pub fn shared() -> SpatialResult<Arc<Self>> {
        SHARED_RUNTIME
            .get_or_init(init_shared_runtime)
            .as_ref()
            .map(Arc::clone)
            .map_err(|message| SpatialError::InvalidArgument(message.clone()))
    }

    /// Borrows the wgpu device owned by this runtime.
    #[must_use]
    pub fn device(&self) -> &wgpu::Device {
        &self.device
    }

    /// Borrows the wgpu queue owned by this runtime.
    #[must_use]
    pub fn queue(&self) -> &wgpu::Queue {
        &self.queue
    }

    /// Borrows the compute pipeline cache for this runtime's device, building
    /// it on first access.
    #[must_use]
    pub fn pipelines(&self) -> &ComputePipelineCache {
        let build = || ComputePipelineCache::new(&self.device);
        self.pipelines.get_or_init(build)
    }

    /// Reports how many attribute channels a single multi-gather dispatch can
    /// handle on this device.
    #[must_use]
    pub fn max_gather_channels(&self) -> u32 {
        self.max_gather_channels
    }

    /// Copies `data` into a storage buffer obtained through the upload pool.
    ///
    /// `label` is forwarded to the pool unchanged.
    ///
    /// # Errors
    ///
    /// Propagates whatever error the upload pool reports.
    pub fn upload_f32_storage(
        &self,
        label: &'static str,
        data: &[f32],
    ) -> SpatialResult<wgpu::Buffer> {
        let pool = &self.upload_pool;
        pool.upload_f32_storage(self, label, data)
    }

    /// Gives `buffer` back to the upload pool so that it can be reused.
    ///
    /// `byte_len` is passed through to the pool — presumably the buffer's size
    /// in bytes; confirm against `GpuUploadPool::recycle_storage`.
    pub fn recycle_storage(&self, byte_len: u64, buffer: wgpu::Buffer) {
        let pool = &self.upload_pool;
        pool.recycle_storage(byte_len, buffer);
    }

    /// Async core of [`Self::new_headless`]: creates the instance, requests an
    /// adapter and a device, then derives the gather channel count.
    async fn new_headless_async() -> SpatialResult<Self> {
        let instance_descriptor = wgpu::InstanceDescriptor {
            backends: wgpu::Backends::PRIMARY,
            ..Default::default()
        };
        let instance = wgpu::Instance::new(&instance_descriptor);

        // No surface is supplied, so adapter selection is not tied to a window.
        // NOTE(review): `LowPower` is requested even though the device below
        // uses performance memory hints — confirm this preference is intended.
        let adapter_options = wgpu::RequestAdapterOptions {
            power_preference: wgpu::PowerPreference::LowPower,
            compatible_surface: None,
            force_fallback_adapter: false,
        };
        let Some(adapter) = instance.request_adapter(&adapter_options).await else {
            return Err(SpatialError::InvalidArgument(
                "no compatible wgpu adapter found for headless compute".to_owned(),
            ));
        };

        // Request every limit the adapter advertises instead of the defaults,
        // so the storage-buffer limit read below is the adapter's own.
        let device_descriptor = wgpu::DeviceDescriptor {
            label: Some("spatialrust-wgpu"),
            required_features: wgpu::Features::empty(),
            required_limits: adapter.limits(),
            memory_hints: wgpu::MemoryHints::Performance,
        };
        let (device, queue) = match adapter.request_device(&device_descriptor, None).await {
            Ok(pair) => pair,
            Err(error) => {
                return Err(SpatialError::InvalidArgument(format!(
                    "failed to create wgpu device: {error}"
                )));
            }
        };

        let storage_buffer_limit = device.limits().max_storage_buffers_per_shader_stage;
        let max_gather_channels = max_gather_channels_for_limit(storage_buffer_limit);

        Ok(Self {
            _instance: instance,
            device,
            queue,
            pipelines: OnceLock::new(),
            max_gather_channels,
            upload_pool: GpuUploadPool::default(),
        })
    }
}
132
/// Maps a device's per-stage storage-buffer limit to the widest gather kernel
/// it can run.
///
/// Returns `4` when the limit reaches [`MULTI_GATHER4_STORAGE_BUFFERS`], `2`
/// when it reaches [`MULTI_GATHER2_STORAGE_BUFFERS`], and `1` otherwise.
#[cfg(feature = "gpu-wgpu")]
fn max_gather_channels_for_limit(storage_buffers_per_stage: u32) -> u32 {
    // Arms are ordered widest-first so the largest satisfiable kernel wins.
    match storage_buffers_per_stage {
        available if available >= MULTI_GATHER4_STORAGE_BUFFERS => 4,
        available if available >= MULTI_GATHER2_STORAGE_BUFFERS => 2,
        _ => 1,
    }
}
143
/// Initializer for the shared runtime slot.
///
/// Creates a headless runtime and wraps it in an [`Arc`]; a failure is reduced
/// to its display string so that the result can be stored in the static slot.
#[cfg(feature = "gpu-wgpu")]
fn init_shared_runtime() -> Result<Arc<WgpuRuntime>, String> {
    match WgpuRuntime::new_headless() {
        Ok(runtime) => Ok(Arc::new(runtime)),
        Err(error) => Err(error.to_string()),
    }
}
148
// These tests create real wgpu devices, so they need a usable adapter on the
// machine that runs them.
#[cfg(all(feature = "gpu-wgpu", test))]
mod tests {
    use super::{WgpuRuntime, MULTI_GATHER2_STORAGE_BUFFERS};
    use crate::pipeline_cache::ComputePipelineCache;
    use std::sync::Arc;

    /// Two calls to `shared` must return handles to the same allocation.
    #[test]
    fn shared_runtime_is_singleton() {
        let handle_a = WgpuRuntime::shared().expect("shared runtime");
        let handle_b = WgpuRuntime::shared().expect("shared runtime");
        assert!(Arc::ptr_eq(&handle_a, &handle_b));
    }

    /// The shared runtime and a freshly built one must report the same
    /// per-stage storage-buffer limit.
    #[test]
    fn shared_and_headless_use_same_device_type() {
        let shared = WgpuRuntime::shared().expect("shared runtime");
        let local = WgpuRuntime::new_headless().expect("local runtime");

        let shared_limit = shared.device().limits().max_storage_buffers_per_shader_stage;
        let local_limit = local.device().limits().max_storage_buffers_per_shader_stage;
        assert_eq!(shared_limit, local_limit);
    }

    /// Repeated `pipelines` calls must return the same cache instance.
    #[test]
    fn pipeline_cache_is_initialized_once_per_runtime() {
        let runtime = WgpuRuntime::new_headless().expect("wgpu runtime");

        // Compare addresses rather than contents to prove it is one object.
        let first: *const ComputePipelineCache = runtime.pipelines();
        let second: *const ComputePipelineCache = runtime.pipelines();
        assert_eq!(first, second);
    }

    /// The host adapter must allow at least the 2-channel gather kernel, and
    /// the runtime's channel count must match the one the pipeline cache uses.
    #[test]
    fn adapter_limits_enable_multi_channel_gather() {
        let runtime = WgpuRuntime::new_headless().expect("wgpu runtime");
        let limit = runtime.device().limits().max_storage_buffers_per_shader_stage;
        assert!(
            limit >= MULTI_GATHER2_STORAGE_BUFFERS,
            "expected at least {} storage buffers per stage, got {limit}",
            MULTI_GATHER2_STORAGE_BUFFERS
        );

        let channels = runtime.max_gather_channels();
        assert!(channels >= 2);
        assert_eq!(channels, runtime.pipelines().voxel_gather.multi_max_channels);
    }
}