1use std::collections::{HashSet, VecDeque};
11
/// Step-cost metric between 8-connected neighbouring cells when building the
/// wavefront transform.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DistanceType {
    /// All eight moves (orthogonal and diagonal) cost `1`.
    Chessboard,
    /// Orthogonal moves cost `1`; diagonal moves cost `sqrt(2)`.
    Euclidean,
}
20
/// Kind of transform propagated outward from the planner's goal cell.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TransformType {
    /// Plain distance transform: accumulated step cost only.
    Distance,
    /// Path transform: step cost plus `alpha` times the distance to the nearest
    /// obstacle.
    Path,
}
29
30#[derive(Debug, Clone)]
32pub struct WavefrontCppConfig {
33 pub distance_type: DistanceType,
35 pub transform_type: TransformType,
37 pub alpha: f64,
39}
40
41impl Default for WavefrontCppConfig {
42 fn default() -> Self {
43 Self {
44 distance_type: DistanceType::Chessboard,
45 transform_type: TransformType::Distance,
46 alpha: 0.01,
47 }
48 }
49}
50
/// Occupancy grid used by the wavefront planner, stored row-major; `true` marks an
/// obstacle cell.
///
/// NOTE(review): `rows` and `cols` are public and writable; changing them after
/// construction desynchronises them from the cell storage.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WavefrontGrid {
    /// Number of rows.
    pub rows: usize,
    /// Number of columns.
    pub cols: usize,
    /// Row-major occupancy flags (`true` = obstacle); length `rows * cols`.
    cells: Vec<bool>,
}

impl WavefrontGrid {
    /// Creates a `rows` x `cols` grid with no obstacles.
    pub fn new(rows: usize, cols: usize) -> Self {
        Self {
            rows,
            cols,
            cells: vec![false; rows * cols],
        }
    }

    /// Builds a grid from row-major occupancy flags (`true` = obstacle).
    ///
    /// # Panics
    ///
    /// Panics if `cells.len() != rows * cols`.
    pub fn from_vec(rows: usize, cols: usize, cells: Vec<bool>) -> Self {
        assert_eq!(
            cells.len(),
            rows * cols,
            "expected rows * cols = {} cells for a {rows}x{cols} grid",
            rows * cols
        );
        Self { rows, cols, cells }
    }

    /// Returns whether the cell at (`row`, `col`) is an obstacle.
    ///
    /// # Panics
    ///
    /// Panics if the cell lies outside the grid.
    pub fn is_obstacle(&self, row: usize, col: usize) -> bool {
        self.cells[self.cell_index(row, col)]
    }

    /// Marks (`val = true`) or clears (`val = false`) an obstacle at (`row`, `col`).
    ///
    /// # Panics
    ///
    /// Panics if the cell lies outside the grid.
    pub fn set_obstacle(&mut self, row: usize, col: usize, val: bool) {
        let idx = self.cell_index(row, col);
        self.cells[idx] = val;
    }

    /// Row-major index of (`row`, `col`).
    ///
    /// Both coordinates are checked: relying only on the `Vec` bounds check would let
    /// an out-of-range `col` silently address a cell in a neighbouring row.
    fn cell_index(&self, row: usize, col: usize) -> usize {
        assert!(
            row < self.rows && col < self.cols,
            "cell ({row}, {col}) is outside the {}x{} grid",
            self.rows,
            self.cols
        );
        row * self.cols + col
    }

    /// Returns whether the signed coordinates (`r`, `c`) fall inside the grid.
    fn in_bounds(&self, r: i32, c: i32) -> bool {
        r >= 0 && (r as usize) < self.rows && c >= 0 && (c as usize) < self.cols
    }

    /// Returns whether (`r`, `c`) is inside the grid and not an obstacle.
    fn is_free_signed(&self, r: i32, c: i32) -> bool {
        self.in_bounds(r, c) && !self.is_obstacle(r as usize, c as usize)
    }
}
97
/// Offsets `(d_row, d_col)` of the eight neighbours of a cell, starting at `(0, 1)`
/// and rotating one neighbour at a time.
///
/// Even indices are orthogonal moves and odd indices are diagonal ones; cost tables
/// indexed in parallel with this array rely on that alternation.
#[rustfmt::skip]
const INC_ORDER: [(i32, i32); 8] = [
    ( 0,  1), ( 1,  1), ( 1,  0), ( 1, -1),
    ( 0, -1), (-1, -1), (-1,  0), (-1,  1),
];
109
110fn obstacle_distance_transform(grid: &WavefrontGrid) -> Vec<f64> {
115 let n = grid.rows * grid.cols;
116 let mut dist = vec![f64::INFINITY; n];
117 let mut queue: VecDeque<(usize, usize)> = VecDeque::new();
118
119 for r in 0..grid.rows {
121 for c in 0..grid.cols {
122 if grid.is_obstacle(r, c) {
123 dist[r * grid.cols + c] = 0.0;
124 queue.push_back((r, c));
125 }
126 }
127 }
128
129 while let Some((r, c)) = queue.pop_front() {
130 let cur = dist[r * grid.cols + c];
131 for &(dr, dc) in &INC_ORDER {
132 let nr = r as i32 + dr;
133 let nc = c as i32 + dc;
134 if grid.in_bounds(nr, nc) {
135 let nr = nr as usize;
136 let nc = nc as usize;
137 let nd = cur + 1.0;
138 if nd < dist[nr * grid.cols + nc] {
139 dist[nr * grid.cols + nc] = nd;
140 queue.push_back((nr, nc));
141 }
142 }
143 }
144 }
145
146 dist
147}
148
149fn build_transform(
154 grid: &WavefrontGrid,
155 src: (usize, usize),
156 config: &WavefrontCppConfig,
157) -> Vec<f64> {
158 let n = grid.rows * grid.cols;
159 let mut mat = vec![f64::INFINITY; n];
160 mat[src.0 * grid.cols + src.1] = 0.0;
161
162 let costs: [f64; 8] = match config.distance_type {
163 DistanceType::Chessboard => [1.0; 8],
164 DistanceType::Euclidean => {
165 let s = std::f64::consts::SQRT_2;
166 [1.0, s, 1.0, s, 1.0, s, 1.0, s]
167 }
168 };
169
170 let obstacle_dist = match config.transform_type {
171 TransformType::Distance => vec![0.0; n],
172 TransformType::Path => obstacle_distance_transform(grid),
173 };
174
175 let mut visited = vec![false; n];
176 visited[src.0 * grid.cols + src.1] = true;
177 let mut queue: VecDeque<(usize, usize)> = VecDeque::new();
178 queue.push_back(src);
179 let mut enqueued = HashSet::new();
180 enqueued.insert(src);
181
182 while let Some((r, c)) = queue.pop_front() {
183 for (k, &(dr, dc)) in INC_ORDER.iter().enumerate() {
184 let nr = r as i32 + dr;
185 let nc = c as i32 + dc;
186 if grid.is_free_signed(nr, nc) {
187 let nr_u = nr as usize;
188 let nc_u = nc as usize;
189 let idx = nr_u * grid.cols + nc_u;
190 let cur_idx = r * grid.cols + c;
191
192 visited[cur_idx] = true;
193
194 let new_cost = mat[idx] + costs[k] + config.alpha * obstacle_dist[idx];
195 if new_cost < mat[cur_idx] {
196 mat[cur_idx] = new_cost;
197 }
198
199 if !visited[idx] && !enqueued.contains(&(nr_u, nc_u)) {
200 queue.push_back((nr_u, nc_u));
201 enqueued.insert((nr_u, nc_u));
202 }
203 }
204 }
205 }
206
207 mat
208}
209
/// Neighbour offsets `(d_row, d_col)` tried by the greedy walk in [`wavefront_cpp`],
/// chosen by where `start` lies relative to `goal`.
///
/// The walk only replaces its current choice when it sees a strictly larger transform
/// value, so among equally valued candidates the offset listed first here wins. In
/// every branch the four orthogonal offsets precede the four diagonal ones. When
/// coordinates are equal, the first matching branch below is used.
fn search_order(start: (usize, usize), goal: (usize, usize)) -> [(i32, i32); 8] {
    // Compare as `usize`: casting to `i32` first would wrap coordinates above
    // `i32::MAX` to negative values and could select the wrong branch.
    let (sr, sc) = start;
    let (gr, gc) = goal;

    if sr >= gr && sc >= gc {
        // Start row and column are both at or past the goal's.
        [
            (1, 0),
            (0, 1),
            (-1, 0),
            (0, -1),
            (1, 1),
            (1, -1),
            (-1, 1),
            (-1, -1),
        ]
    } else if sr <= gr && sc >= gc {
        // Start row at or before the goal's, start column at or past it.
        [
            (-1, 0),
            (0, 1),
            (1, 0),
            (0, -1),
            (-1, 1),
            (-1, -1),
            (1, 1),
            (1, -1),
        ]
    } else if sr >= gr && sc <= gc {
        // Start row at or past the goal's, start column before it.
        [
            (1, 0),
            (0, -1),
            (-1, 0),
            (0, 1),
            (1, -1),
            (-1, -1),
            (1, 1),
            (-1, 1),
        ]
    } else {
        // Start row and column are both strictly before the goal's.
        [
            (-1, 0),
            (0, -1),
            (0, 1),
            (1, 0),
            (-1, -1),
            (-1, 1),
            (1, -1),
            (1, 1),
        ]
    }
}
266
267pub fn wavefront_cpp(
279 grid: &WavefrontGrid,
280 start: (usize, usize),
281 goal: (usize, usize),
282 config: &WavefrontCppConfig,
283) -> Vec<(usize, usize)> {
284 let transform = build_transform(grid, goal, config);
285 let order = search_order(start, goal);
286
287 let mut path: Vec<(usize, usize)> = Vec::new();
288 let mut visited = vec![false; grid.rows * grid.cols];
289 let mut current = start;
290
291 loop {
292 if current == goal {
293 path.push(current);
294 break;
295 }
296
297 let (r, c) = current;
298 path.push((r, c));
299 visited[r * grid.cols + c] = true;
300
301 let mut best = None;
303 let mut best_val = f64::NEG_INFINITY;
304
305 for &(pr, pc) in path.iter().rev() {
306 for &(dr, dc) in &order {
307 let nr = pr as i32 + dr;
308 let nc = pc as i32 + dc;
309 if grid.is_free_signed(nr, nc) {
310 let nr_u = nr as usize;
311 let nc_u = nc as usize;
312 let idx = nr_u * grid.cols + nc_u;
313 if !visited[idx] && transform[idx] != f64::INFINITY && transform[idx] > best_val
314 {
315 best_val = transform[idx];
316 best = Some((nr_u, nc_u));
317 }
318 }
319 }
320 if best.is_some() {
322 break;
323 }
324 }
325
326 match best {
327 Some(next) => current = next,
328 None => {
329 break;
331 }
332 }
333 }
334
335 path
336}
337
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a grid with no obstacles.
    fn blank(rows: usize, cols: usize) -> WavefrontGrid {
        WavefrontGrid::new(rows, cols)
    }

    /// Counts the distinct cells appearing in `path`.
    fn distinct(path: &[(usize, usize)]) -> usize {
        path.iter().collect::<HashSet<_>>().len()
    }

    #[test]
    fn test_simple_open_grid_visits_all_cells() {
        let config = WavefrontCppConfig::default();
        let path = wavefront_cpp(&blank(5, 5), (4, 0), (0, 0), &config);

        assert_eq!(path.first().copied(), Some((4, 0)));
        assert_eq!(path.last().copied(), Some((0, 0)));
        // 5x5 open grid: all 25 cells are expected on the path.
        assert_eq!(distinct(&path), 25);
    }

    #[test]
    fn test_start_equals_goal() {
        let config = WavefrontCppConfig::default();
        let path = wavefront_cpp(&blank(3, 3), (0, 0), (0, 0), &config);

        assert_eq!(path.first().copied(), Some((0, 0)));
        assert_eq!(path.last().copied(), Some((0, 0)));
    }

    #[test]
    fn test_grid_with_obstacles() {
        // Wall across row 2 that leaves only column 0 open.
        let mut grid = WavefrontGrid::new(5, 5);
        (1..5).for_each(|c| grid.set_obstacle(2, c, true));

        let config = WavefrontCppConfig::default();
        let path = wavefront_cpp(&grid, (4, 4), (0, 0), &config);

        for (r, c) in path.iter().copied() {
            assert!(!grid.is_obstacle(r, c), "Path on obstacle at ({r}, {c})");
        }
        // 25 cells minus the 4 wall cells.
        assert_eq!(distinct(&path), 21);
    }

    #[test]
    fn test_euclidean_distance_type() {
        let config = WavefrontCppConfig {
            distance_type: DistanceType::Euclidean,
            transform_type: TransformType::Distance,
            alpha: 0.0,
        };
        let path = wavefront_cpp(&blank(4, 4), (3, 0), (0, 0), &config);

        assert_eq!(distinct(&path), 16);
    }

    #[test]
    fn test_path_transform_type() {
        let mut grid = WavefrontGrid::new(6, 6);
        for &(r, c) in &[(2, 2), (2, 3)] {
            grid.set_obstacle(r, c, true);
        }

        let config = WavefrontCppConfig {
            distance_type: DistanceType::Chessboard,
            transform_type: TransformType::Path,
            alpha: 0.01,
        };
        let path = wavefront_cpp(&grid, (5, 0), (0, 0), &config);

        assert!(path.iter().all(|&(r, c)| !grid.is_obstacle(r, c)));
        // 36 cells minus the 2 obstacles.
        assert_eq!(distinct(&path), 34);
    }

    #[test]
    fn test_from_vec_constructor() {
        #[rustfmt::skip]
        let cells = vec![
            false, false, false,
            false, true,  false,
            false, false, false,
        ];
        let grid = WavefrontGrid::from_vec(3, 3, cells);

        assert!(grid.is_obstacle(1, 1));
        assert!(!grid.is_obstacle(0, 0));
    }

    #[test]
    fn test_obstacle_distance_transform_basic() {
        let mut grid = WavefrontGrid::new(5, 5);
        grid.set_obstacle(2, 2, true);
        let dist = obstacle_distance_transform(&grid);
        let at = |r: usize, c: usize| dist[r * 5 + c];

        assert_eq!(at(2, 2), 0.0);
        assert_eq!(at(1, 2), 1.0);
        assert_eq!(at(2, 1), 1.0);
        assert_eq!(at(0, 0), 2.0);
    }

    #[test]
    fn test_single_cell_grid() {
        let config = WavefrontCppConfig::default();
        let path = wavefront_cpp(&blank(1, 1), (0, 0), (0, 0), &config);

        assert_eq!(path, vec![(0, 0)]);
    }

    #[test]
    fn test_no_path_cells_are_duplicated() {
        let config = WavefrontCppConfig::default();
        let path = wavefront_cpp(&blank(4, 4), (3, 3), (0, 0), &config);

        assert_eq!(
            distinct(&path),
            path.len(),
            "Path should not contain duplicates"
        );
    }

    #[test]
    fn test_search_order_all_quadrants() {
        let cases = [
            ((3, 3), (0, 0)),
            ((0, 0), (3, 3)),
            ((3, 0), (0, 3)),
            ((0, 3), (3, 0)),
        ];
        for &(start, goal) in cases.iter() {
            let offsets: HashSet<(i32, i32)> =
                search_order(start, goal).iter().copied().collect();
            assert_eq!(offsets.len(), 8);
        }
    }
}