1use std::io::Read;
4use std::sync::Arc;
5
6use copc_streaming::ByteSource;
7
8use spatialrust_core::PointCloud;
9
10use crate::copc::query::{CopcFileInfo, CopcQuery};
11use crate::copc::reader::{read_copc_from_byte_source, read_header_info};
12use crate::error::{copc_format, IoError};
13
/// Parallelism limit assigned to a freshly constructed [`HttpByteSource`]: the
/// maximum number of ranges a multi-range read places in one batch.
const DEFAULT_MAX_PARALLEL_RANGES: usize = 8;
15
/// A remote file reachable over HTTP(S), read through byte-range requests.
///
/// Values are created with [`HttpByteSource::new`], which rejects URLs that do
/// not start with `http://` or `https://`.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct HttpByteSource {
    /// Address every request is sent to.
    url: String,
    /// Largest number of ranges grouped into one batch by a multi-range read.
    max_parallel_ranges: usize,
}
22
23impl HttpByteSource {
24 pub fn new(url: impl Into<String>) -> Result<Self, IoError> {
26 let url = url.into();
27 validate_http_url(&url)?;
28 Ok(Self { url, max_parallel_ranges: DEFAULT_MAX_PARALLEL_RANGES })
29 }
30
31 #[must_use]
33 pub fn with_max_parallel_ranges(mut self, max_parallel_ranges: usize) -> Self {
34 self.max_parallel_ranges = max_parallel_ranges.max(1);
35 self
36 }
37
38 #[must_use]
40 pub fn url(&self) -> &str {
41 &self.url
42 }
43
44 #[must_use]
46 pub fn max_parallel_ranges(&self) -> usize {
47 self.max_parallel_ranges
48 }
49}
50
51impl ByteSource for HttpByteSource {
52 async fn read_range(
53 &self,
54 offset: u64,
55 length: u64,
56 ) -> Result<Vec<u8>, copc_streaming::CopcError> {
57 fetch_http_range(&self.url, offset, length)
58 }
59
60 async fn read_ranges(
61 &self,
62 ranges: &[(u64, u64)],
63 ) -> Result<Vec<Vec<u8>>, copc_streaming::CopcError> {
64 fetch_http_ranges_parallel(&self.url, ranges, self.max_parallel_ranges)
65 }
66
67 async fn size(&self) -> Result<Option<u64>, copc_streaming::CopcError> {
68 fetch_http_size(&self.url)
69 }
70}
71
72pub fn read_copc_url(url: &str) -> Result<PointCloud, IoError> {
74 read_copc_url_with_query(url, None)
75}
76
77pub fn read_copc_url_with_query(
79 url: &str,
80 query: Option<&CopcQuery>,
81) -> Result<PointCloud, IoError> {
82 if let Some(query) = query {
83 query.validate()?;
84 }
85 validate_http_url(url)?;
86 let source = HttpByteSource::new(url)?;
87 pollster::block_on(read_copc_from_byte_source(source, query))
88}
89
90pub fn read_copc_url_info(url: &str) -> Result<CopcFileInfo, IoError> {
92 validate_http_url(url)?;
93 let source = HttpByteSource::new(url)?;
94 pollster::block_on(async { read_header_info(source).await.map(|(_, info)| info) })
95}
96
97fn fetch_http_range(
98 url: &str,
99 offset: u64,
100 length: u64,
101) -> Result<Vec<u8>, copc_streaming::CopcError> {
102 if length == 0 {
103 return Ok(Vec::new());
104 }
105
106 let end = offset.saturating_add(length.saturating_sub(1));
107 let response = ureq::get(url)
108 .set("Range", &format!("bytes={offset}-{end}"))
109 .call()
110 .map_err(|error| copc_streaming::CopcError::ByteSource(Box::new(error)))?;
111
112 let status = response.status();
113 if status != 200 && status != 206 {
114 return Err(copc_streaming::CopcError::ByteSource(Box::new(std::io::Error::new(
115 std::io::ErrorKind::InvalidData,
116 format!("unexpected HTTP status {status} for range request"),
117 ))));
118 }
119
120 let mut bytes = Vec::with_capacity(length as usize);
121 response
122 .into_reader()
123 .take(length)
124 .read_to_end(&mut bytes)
125 .map_err(copc_streaming::CopcError::Io)?;
126 Ok(bytes)
127}
128
129fn fetch_http_ranges_parallel(
130 url: &str,
131 ranges: &[(u64, u64)],
132 max_parallel_ranges: usize,
133) -> Result<Vec<Vec<u8>>, copc_streaming::CopcError> {
134 if ranges.is_empty() {
135 return Ok(Vec::new());
136 }
137
138 let mut results = Vec::with_capacity(ranges.len());
139 let url = Arc::new(url.to_owned());
140
141 for batch in ranges.chunks(max_parallel_ranges.max(1)) {
142 let batch_results = read_range_batch(Arc::clone(&url), batch)?;
143 results.extend(batch_results);
144 }
145
146 Ok(results)
147}
148
149fn read_range_batch(
150 url: Arc<String>,
151 ranges: &[(u64, u64)],
152) -> Result<Vec<Vec<u8>>, copc_streaming::CopcError> {
153 if ranges.len() == 1 {
154 let (offset, length) = ranges[0];
155 return Ok(vec![fetch_http_range(url.as_str(), offset, length)?]);
156 }
157
158 std::thread::scope(|scope| {
159 let mut handles = Vec::with_capacity(ranges.len());
160 for (index, &(offset, length)) in ranges.iter().enumerate() {
161 let url = Arc::clone(&url);
162 handles.push(scope.spawn(move || {
163 let bytes = fetch_http_range(url.as_str(), offset, length)?;
164 Ok::<_, copc_streaming::CopcError>((index, bytes))
165 }));
166 }
167
168 let mut batch = vec![Vec::new(); ranges.len()];
169 for handle in handles {
170 let (index, bytes) = handle.join().map_err(|_| {
171 copc_streaming::CopcError::ByteSource(Box::new(std::io::Error::other(
172 "parallel HTTP range worker panicked",
173 )))
174 })??;
175 batch[index] = bytes;
176 }
177 Ok(batch)
178 })
179}
180
181fn fetch_http_size(url: &str) -> Result<Option<u64>, copc_streaming::CopcError> {
182 if let Ok(response) = ureq::head(url).call() {
183 if let Some(total) = response.header("Content-Length").and_then(parse_u64_header) {
184 return Ok(Some(total));
185 }
186 }
187
188 let response = ureq::get(url)
189 .set("Range", "bytes=0-0")
190 .call()
191 .map_err(|error| copc_streaming::CopcError::ByteSource(Box::new(error)))?;
192
193 if let Some(total) = response.header("Content-Range").and_then(parse_content_range_total) {
194 return Ok(Some(total));
195 }
196
197 if let Some(total) = response.header("Content-Length").and_then(parse_u64_header) {
198 return Ok(Some(total));
199 }
200
201 Ok(None)
202}
203
204fn validate_http_url(url: &str) -> Result<(), IoError> {
205 if url.starts_with("http://") || url.starts_with("https://") {
206 Ok(())
207 } else {
208 Err(copc_format(format!(
209 "COPC HTTP sources require an http:// or https:// URL, got `{url}`"
210 )))
211 }
212}
213
/// Parses a header value as an unsigned 64-bit integer.
///
/// Surrounding whitespace is ignored; anything that is not a valid `u64`
/// yields `None`.
fn parse_u64_header(value: &str) -> Option<u64> {
    let trimmed = value.trim();
    match trimmed.parse::<u64>() {
        Ok(number) => Some(number),
        Err(_) => None,
    }
}
217
/// Extracts the total size from a `Content-Range` header value such as
/// `bytes 0-0/12345`.
///
/// The text between the first and second `/` (or the end of the value) is
/// trimmed and parsed as a `u64`. A value without `/`, or with a total that is
/// not a number (for example `*`), yields `None`.
fn parse_content_range_total(value: &str) -> Option<u64> {
    let mut segments = value.split('/');
    // Skip the `<unit> <first>-<last>` part that precedes the slash.
    segments.next()?;
    let total = segments.next()?;
    total.trim().parse::<u64>().ok()
}
221
#[cfg(test)]
mod tests {
    use super::{
        fetch_http_ranges_parallel, parse_content_range_total, read_range_batch, validate_http_url,
        HttpByteSource,
    };
    use copc_streaming::ByteSource;
    use std::io::{Read, Write};
    use std::net::{TcpListener, TcpStream};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn validates_http_urls() {
        assert!(validate_http_url("https://example.com/cloud.copc.laz").is_ok());
        assert!(validate_http_url("http://example.com/cloud.copc.laz").is_ok());
        assert!(validate_http_url("/tmp/local.copc.laz").is_err());
        assert!(validate_http_url("ftp://example.com/cloud.copc.laz").is_err());
    }

    #[test]
    fn parses_content_range_total() {
        assert_eq!(parse_content_range_total("bytes 0-0/12345"), Some(12345));
        assert_eq!(parse_content_range_total("bytes 0-0/*"), None);
        assert_eq!(parse_content_range_total("12345"), None);
    }

    #[test]
    fn constructs_http_source() {
        let source = HttpByteSource::new("https://example.com/cloud.copc.laz").unwrap();
        assert_eq!(source.url(), "https://example.com/cloud.copc.laz");
        assert_eq!(source.max_parallel_ranges(), 8);
        // A zero limit is raised to one.
        assert_eq!(source.with_max_parallel_ranges(0).max_parallel_ranges(), 1);
    }

    #[test]
    fn read_ranges_fetches_multiple_byte_ranges() {
        let payload = b"0123456789ABCDEF";
        let requests = Arc::new(AtomicUsize::new(0));
        let requests_server = Arc::clone(&requests);

        // The listener is non-blocking so the server thread can give up at the
        // deadline instead of hanging forever in `accept`.
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        listener.set_nonblocking(true).unwrap();
        let addr = listener.local_addr().unwrap();

        let server = thread::spawn(move || {
            let deadline = std::time::Instant::now() + Duration::from_secs(5);
            while requests_server.load(Ordering::SeqCst) < 3 {
                if std::time::Instant::now() > deadline {
                    panic!("timed out waiting for HTTP range requests");
                }
                let Ok((mut stream, _)) = listener.accept() else {
                    thread::sleep(Duration::from_millis(10));
                    continue;
                };
                serve_test_range(&mut stream, payload, &requests_server);
            }
        });

        let url = format!("http://{addr}/cloud.copc.laz");
        let source = HttpByteSource::new(&url).unwrap().with_max_parallel_ranges(3);
        let ranges = vec![(0, 4), (4, 4), (8, 4)];
        let results = pollster::block_on(source.read_ranges(&ranges)).unwrap();

        assert_eq!(results.len(), 3);
        assert_eq!(results[0], b"0123".to_vec());
        assert_eq!(results[1], b"4567".to_vec());
        assert_eq!(results[2], b"89AB".to_vec());
        assert_eq!(requests.load(Ordering::SeqCst), 3);
        server.join().unwrap();
    }

    #[test]
    fn fetch_ranges_batches_by_parallelism_limit() {
        let url = refused_local_url();
        let ranges = vec![(0, 1); 5];
        let err = fetch_http_ranges_parallel(&url, &ranges, 2).unwrap_err();
        assert!(matches!(
            err,
            copc_streaming::CopcError::ByteSource(_) | copc_streaming::CopcError::Io(_)
        ));
    }

    #[test]
    fn single_range_batch_delegates_to_fetch() {
        let result = read_range_batch(Arc::new(refused_local_url()), &[(0, 1)]);
        assert!(result.is_err());
    }

    /// Returns an `http://` URL on the loopback interface that nothing is
    /// listening on, so requests fail without relying on external DNS or
    /// network access.
    fn refused_local_url() -> String {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        // Closing the listener frees the port again; connecting to it is then
        // expected to be refused.
        drop(listener);
        format!("http://{addr}/cloud.copc.laz")
    }

    /// Answers one HTTP range request on `stream` with the matching slice of
    /// `payload` and bumps `requests`.
    fn serve_test_range(stream: &mut TcpStream, payload: &[u8], requests: &AtomicUsize) {
        // A socket accepted from a non-blocking listener may itself be
        // non-blocking on some platforms; force blocking reads so `read`
        // cannot fail with `WouldBlock`, and bound the wait with a timeout.
        stream.set_nonblocking(false).unwrap();
        stream.set_read_timeout(Some(Duration::from_secs(5))).unwrap();

        // The request may arrive in several segments: keep reading until the
        // blank line that ends the header section has been seen.
        let mut raw_request = Vec::new();
        let mut buffer = [0_u8; 512];
        while !raw_request.windows(4).any(|window| window == b"\r\n\r\n") {
            let read = stream.read(&mut buffer).unwrap();
            assert!(read > 0, "connection closed before the request headers were complete");
            raw_request.extend_from_slice(&buffer[..read]);
        }

        let request = std::str::from_utf8(&raw_request).unwrap();
        let range = request
            .lines()
            .find_map(|line| line.strip_prefix("Range: bytes="))
            .expect("missing Range header");
        let (start, end) = range
            .split_once('-')
            .and_then(|(start, end)| Some((start.parse::<u64>().ok()?, end.parse::<u64>().ok()?)))
            .expect("invalid Range header");
        let start = start as usize;
        let end = end as usize;
        let body = payload[start..=end].to_vec();

        requests.fetch_add(1, Ordering::SeqCst);
        let response = format!(
            "HTTP/1.1 206 Partial Content\r\nContent-Length: {}\r\nContent-Range: bytes {start}-{end}/{}\r\n\r\n",
            body.len(),
            payload.len()
        );
        stream.write_all(response.as_bytes()).unwrap();
        stream.write_all(&body).unwrap();
    }
}