Skip to main content

rust_robotics_planning/
wavefront_cpp.rs

//! Wavefront Coverage Path Planner
//!
//! Uses BFS wavefront expansion from a goal cell to build a distance/path
//! transform matrix, then greedily follows highest-valued unvisited neighbours
//! to produce a coverage path that visits every reachable free cell.
//!
//! Reference: "Planning paths of complete coverage of an unstructured
//! environment by a mobile robot" — Zelinsky et al.
9
10use std::collections::{HashSet, VecDeque};
11
/// Step-cost metric applied while the wavefront expands across the grid.
///
/// The enum is fieldless, so it derives `Eq` and `Hash` in addition to
/// `PartialEq`; this lets it be used as a set or map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DistanceType {
    /// Every move to one of the 8 neighbours costs 1, diagonal or not.
    Chessboard,
    /// Cardinal moves cost 1; diagonal moves cost `sqrt(2)`.
    Euclidean,
}
20
/// Selects which cost terms go into the wavefront matrix.
///
/// The enum is fieldless, so it derives `Eq` and `Hash` in addition to
/// `PartialEq`; this lets it be used as a set or map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TransformType {
    /// Pure distance transform: only the step costs are accumulated.
    Distance,
    /// Path transform: on top of the step cost, each step adds `alpha` times
    /// the neighbouring cell's distance to the nearest obstacle.
    Path,
}
29
30/// Configuration for the wavefront coverage path planner.
31#[derive(Debug, Clone)]
32pub struct WavefrontCppConfig {
33    /// Distance metric for BFS expansion.
34    pub distance_type: DistanceType,
35    /// Transform type (distance or path).
36    pub transform_type: TransformType,
37    /// Weight of obstacle proximity penalty (only used with `TransformType::Path`).
38    pub alpha: f64,
39}
40
41impl Default for WavefrontCppConfig {
42    fn default() -> Self {
43        Self {
44            distance_type: DistanceType::Chessboard,
45            transform_type: TransformType::Distance,
46            alpha: 0.01,
47        }
48    }
49}
50
/// 2D occupancy grid for wavefront planning. `true` = obstacle, `false` = free.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WavefrontGrid {
    /// Number of rows.
    pub rows: usize,
    /// Number of columns.
    pub cols: usize,
    /// Row-major obstacle data: `cells[row * cols + col]`.
    cells: Vec<bool>,
}

impl WavefrontGrid {
    /// Creates a `rows` x `cols` grid in which every cell is free.
    pub fn new(rows: usize, cols: usize) -> Self {
        Self {
            rows,
            cols,
            cells: vec![false; rows * cols],
        }
    }

    /// Creates a grid from a row-major boolean vector.
    ///
    /// `true` = obstacle, `false` = free.
    ///
    /// # Panics
    ///
    /// Panics if `cells.len() != rows * cols`.
    pub fn from_vec(rows: usize, cols: usize, cells: Vec<bool>) -> Self {
        assert_eq!(
            cells.len(),
            rows * cols,
            "cell vector length must equal rows * cols"
        );
        Self { rows, cols, cells }
    }

    /// Returns whether the cell at (`row`, `col`) is an obstacle.
    ///
    /// # Panics
    ///
    /// Panics if `row >= rows` or `col >= cols`.
    pub fn is_obstacle(&self, row: usize, col: usize) -> bool {
        self.cells[self.index(row, col)]
    }

    /// Marks the cell at (`row`, `col`) as obstacle (`true`) or free (`false`).
    ///
    /// # Panics
    ///
    /// Panics if `row >= rows` or `col >= cols`.
    pub fn set_obstacle(&mut self, row: usize, col: usize, val: bool) {
        let idx = self.index(row, col);
        self.cells[idx] = val;
    }

    /// Converts (`row`, `col`) into the flat row-major index.
    ///
    /// Both coordinates are checked separately: a too-large `col` could
    /// otherwise still produce an in-range flat index and silently address a
    /// cell in a following row.
    fn index(&self, row: usize, col: usize) -> usize {
        assert!(
            row < self.rows && col < self.cols,
            "cell ({row}, {col}) is outside the {}x{} grid",
            self.rows,
            self.cols
        );
        row * self.cols + col
    }

    /// Returns whether the signed coordinates lie inside the grid.
    fn in_bounds(&self, r: i32, c: i32) -> bool {
        r >= 0 && (r as usize) < self.rows && c >= 0 && (c as usize) < self.cols
    }

    /// Returns whether the signed coordinates name an in-grid, obstacle-free cell.
    fn is_free_signed(&self, r: i32, c: i32) -> bool {
        self.in_bounds(r, c) && !self.is_obstacle(r as usize, c as usize)
    }
}
97
/// `(d_row, d_col)` offsets of the 8 neighbours of a cell.
///
/// Entries at even positions are the cardinal moves and entries at odd
/// positions are the diagonal moves.
const INC_ORDER: [(i32, i32); 8] = [
    (0, 1),   // E
    (1, 1),   // SE
    (1, 0),   // S
    (1, -1),  // SW
    (0, -1),  // W
    (-1, -1), // NW
    (-1, 0),  // N
    (-1, 1),  // NE
];
109
110/// Compute the obstacle distance transform using BFS (chessboard distance).
111///
112/// Returns a matrix where each free cell holds its minimum chessboard distance
113/// to the nearest obstacle, and obstacle cells hold 0.
114fn obstacle_distance_transform(grid: &WavefrontGrid) -> Vec<f64> {
115    let n = grid.rows * grid.cols;
116    let mut dist = vec![f64::INFINITY; n];
117    let mut queue: VecDeque<(usize, usize)> = VecDeque::new();
118
119    // Seed: all obstacle cells have distance 0.
120    for r in 0..grid.rows {
121        for c in 0..grid.cols {
122            if grid.is_obstacle(r, c) {
123                dist[r * grid.cols + c] = 0.0;
124                queue.push_back((r, c));
125            }
126        }
127    }
128
129    while let Some((r, c)) = queue.pop_front() {
130        let cur = dist[r * grid.cols + c];
131        for &(dr, dc) in &INC_ORDER {
132            let nr = r as i32 + dr;
133            let nc = c as i32 + dc;
134            if grid.in_bounds(nr, nc) {
135                let nr = nr as usize;
136                let nc = nc as usize;
137                let nd = cur + 1.0;
138                if nd < dist[nr * grid.cols + nc] {
139                    dist[nr * grid.cols + nc] = nd;
140                    queue.push_back((nr, nc));
141                }
142            }
143        }
144    }
145
146    dist
147}
148
149/// Build the wavefront transform matrix via BFS from `src`.
150///
151/// Each free cell receives a cost indicating the wavefront expansion distance
152/// from `src`. Obstacle cells remain at `f64::INFINITY`.
153fn build_transform(
154    grid: &WavefrontGrid,
155    src: (usize, usize),
156    config: &WavefrontCppConfig,
157) -> Vec<f64> {
158    let n = grid.rows * grid.cols;
159    let mut mat = vec![f64::INFINITY; n];
160    mat[src.0 * grid.cols + src.1] = 0.0;
161
162    let costs: [f64; 8] = match config.distance_type {
163        DistanceType::Chessboard => [1.0; 8],
164        DistanceType::Euclidean => {
165            let s = std::f64::consts::SQRT_2;
166            [1.0, s, 1.0, s, 1.0, s, 1.0, s]
167        }
168    };
169
170    let obstacle_dist = match config.transform_type {
171        TransformType::Distance => vec![0.0; n],
172        TransformType::Path => obstacle_distance_transform(grid),
173    };
174
175    let mut visited = vec![false; n];
176    visited[src.0 * grid.cols + src.1] = true;
177    let mut queue: VecDeque<(usize, usize)> = VecDeque::new();
178    queue.push_back(src);
179    let mut enqueued = HashSet::new();
180    enqueued.insert(src);
181
182    while let Some((r, c)) = queue.pop_front() {
183        for (k, &(dr, dc)) in INC_ORDER.iter().enumerate() {
184            let nr = r as i32 + dr;
185            let nc = c as i32 + dc;
186            if grid.is_free_signed(nr, nc) {
187                let nr_u = nr as usize;
188                let nc_u = nc as usize;
189                let idx = nr_u * grid.cols + nc_u;
190                let cur_idx = r * grid.cols + c;
191
192                visited[cur_idx] = true;
193
194                let new_cost = mat[idx] + costs[k] + config.alpha * obstacle_dist[idx];
195                if new_cost < mat[cur_idx] {
196                    mat[cur_idx] = new_cost;
197                }
198
199                if !visited[idx] && !enqueued.contains(&(nr_u, nc_u)) {
200                    queue.push_back((nr_u, nc_u));
201                    enqueued.insert((nr_u, nc_u));
202                }
203            }
204        }
205    }
206
207    mat
208}
209
/// Pick the order in which the 8 neighbours are examined during traversal.
///
/// The order depends only on how `start` compares with `goal` along each axis
/// (row index smaller or not, column index smaller or not), which gives four
/// possible orderings. Each ordering lists the four cardinal moves first and
/// the four diagonal moves after them.
fn search_order(start: (usize, usize), goal: (usize, usize)) -> [(i32, i32); 8] {
    let start_row_smaller = (start.0 as i32) < (goal.0 as i32);
    let start_col_smaller = (start.1 as i32) < (goal.1 as i32);

    match (start_row_smaller, start_col_smaller) {
        // start row >= goal row, start col >= goal col
        (false, false) => [
            (1, 0),
            (0, 1),
            (-1, 0),
            (0, -1),
            (1, 1),
            (1, -1),
            (-1, 1),
            (-1, -1),
        ],
        // start row < goal row, start col >= goal col
        (true, false) => [
            (-1, 0),
            (0, 1),
            (1, 0),
            (0, -1),
            (-1, 1),
            (-1, -1),
            (1, 1),
            (1, -1),
        ],
        // start row >= goal row, start col < goal col
        (false, true) => [
            (1, 0),
            (0, -1),
            (-1, 0),
            (0, 1),
            (1, -1),
            (-1, -1),
            (1, 1),
            (-1, 1),
        ],
        // start row < goal row, start col < goal col
        (true, true) => [
            (-1, 0),
            (0, -1),
            (0, 1),
            (1, 0),
            (-1, -1),
            (-1, 1),
            (1, -1),
            (1, 1),
        ],
    }
}
266
267/// Run the wavefront coverage path planner.
268///
269/// Returns a path (list of `(row, col)` cells) that attempts to visit every
270/// reachable free cell, starting from `start` and ending at `goal`.
271///
272/// The algorithm:
273/// 1. Build a wavefront transform matrix via BFS from `goal`.
274/// 2. From `start`, greedily move to the unvisited neighbour with the highest
275///    transform value (farthest from goal).
276/// 3. When stuck, backtrace along the existing path to find a cell with an
277///    unvisited neighbour and resume from there.
278pub fn wavefront_cpp(
279    grid: &WavefrontGrid,
280    start: (usize, usize),
281    goal: (usize, usize),
282    config: &WavefrontCppConfig,
283) -> Vec<(usize, usize)> {
284    let transform = build_transform(grid, goal, config);
285    let order = search_order(start, goal);
286
287    let mut path: Vec<(usize, usize)> = Vec::new();
288    let mut visited = vec![false; grid.rows * grid.cols];
289    let mut current = start;
290
291    loop {
292        if current == goal {
293            path.push(current);
294            break;
295        }
296
297        let (r, c) = current;
298        path.push((r, c));
299        visited[r * grid.cols + c] = true;
300
301        // Search back through the path for a cell with an unvisited neighbour.
302        let mut best = None;
303        let mut best_val = f64::NEG_INFINITY;
304
305        for &(pr, pc) in path.iter().rev() {
306            for &(dr, dc) in &order {
307                let nr = pr as i32 + dr;
308                let nc = pc as i32 + dc;
309                if grid.is_free_signed(nr, nc) {
310                    let nr_u = nr as usize;
311                    let nc_u = nc as usize;
312                    let idx = nr_u * grid.cols + nc_u;
313                    if !visited[idx] && transform[idx] != f64::INFINITY && transform[idx] > best_val
314                    {
315                        best_val = transform[idx];
316                        best = Some((nr_u, nc_u));
317                    }
318                }
319            }
320            // If we found a candidate from the current backtrack cell, use it.
321            if best.is_some() {
322                break;
323            }
324        }
325
326        match best {
327            Some(next) => current = next,
328            None => {
329                // No reachable unvisited cell found — coverage complete.
330                break;
331            }
332        }
333    }
334
335    path
336}
337
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an obstacle-free grid of the requested size.
    fn open_grid(rows: usize, cols: usize) -> WavefrontGrid {
        WavefrontGrid::new(rows, cols)
    }

    /// Collects the distinct cells that occur in a path.
    fn distinct_cells(path: &[(usize, usize)]) -> HashSet<(usize, usize)> {
        path.iter().copied().collect()
    }

    #[test]
    fn test_simple_open_grid_visits_all_cells() {
        let grid = open_grid(5, 5);
        let path = wavefront_cpp(&grid, (4, 0), (0, 0), &WavefrontCppConfig::default());

        // The path runs from the start cell to the goal cell.
        assert_eq!(path.first().copied(), Some((4, 0)));
        assert_eq!(path.last().copied(), Some((0, 0)));

        // Every one of the 25 free cells is covered.
        assert_eq!(distinct_cells(&path).len(), 25);
    }

    #[test]
    fn test_start_equals_goal() {
        let grid = open_grid(3, 3);
        let path = wavefront_cpp(&grid, (0, 0), (0, 0), &WavefrontCppConfig::default());

        assert_eq!(path.first().copied(), Some((0, 0)));
        assert_eq!(path.last().copied(), Some((0, 0)));
    }

    #[test]
    fn test_grid_with_obstacles() {
        // 5x5 grid whose middle row (row 2) is a wall except for a gap at col 0.
        let mut grid = WavefrontGrid::new(5, 5);
        (1..5).for_each(|c| grid.set_obstacle(2, c, true));

        let path = wavefront_cpp(&grid, (4, 4), (0, 0), &WavefrontCppConfig::default());

        // The path must stay on free cells.
        for &(r, c) in &path {
            assert!(!grid.is_obstacle(r, c), "Path on obstacle at ({r}, {c})");
        }

        // 25 cells minus the 4 wall cells leaves 21 free cells to cover.
        assert_eq!(distinct_cells(&path).len(), 21);
    }

    #[test]
    fn test_euclidean_distance_type() {
        let config = WavefrontCppConfig {
            distance_type: DistanceType::Euclidean,
            transform_type: TransformType::Distance,
            alpha: 0.0,
        };
        let path = wavefront_cpp(&open_grid(4, 4), (3, 0), (0, 0), &config);

        assert_eq!(distinct_cells(&path).len(), 16);
    }

    #[test]
    fn test_path_transform_type() {
        // 6x6 grid with two adjacent obstacle cells.
        let mut grid = WavefrontGrid::new(6, 6);
        for c in [2, 3] {
            grid.set_obstacle(2, c, true);
        }

        let config = WavefrontCppConfig {
            distance_type: DistanceType::Chessboard,
            transform_type: TransformType::Path,
            alpha: 0.01,
        };
        let path = wavefront_cpp(&grid, (5, 0), (0, 0), &config);

        for &(r, c) in &path {
            assert!(!grid.is_obstacle(r, c));
        }

        // 36 cells in total, 2 of them obstacles: 34 free cells.
        assert_eq!(distinct_cells(&path).len(), 34);
    }

    #[test]
    fn test_from_vec_constructor() {
        #[rustfmt::skip]
        let layout = vec![
            false, false, false,
            false, true,  false,
            false, false, false,
        ];
        let grid = WavefrontGrid::from_vec(3, 3, layout);

        assert!(grid.is_obstacle(1, 1));
        assert!(!grid.is_obstacle(0, 0));
    }

    #[test]
    fn test_obstacle_distance_transform_basic() {
        let mut grid = WavefrontGrid::new(5, 5);
        grid.set_obstacle(2, 2, true);
        let dist = obstacle_distance_transform(&grid);
        let at = |r: usize, c: usize| dist[r * 5 + c];

        // The obstacle cell itself is at distance 0.
        assert_eq!(at(2, 2), 0.0);
        // Cells directly next to it are at distance 1.
        assert_eq!(at(1, 2), 1.0);
        assert_eq!(at(2, 1), 1.0);
        // The corner (0, 0) is two chessboard steps away.
        assert_eq!(at(0, 0), 2.0);
    }

    #[test]
    fn test_single_cell_grid() {
        let grid = open_grid(1, 1);
        let path = wavefront_cpp(&grid, (0, 0), (0, 0), &WavefrontCppConfig::default());

        assert_eq!(path, vec![(0, 0)]);
    }

    #[test]
    fn test_no_path_cells_are_duplicated() {
        let grid = open_grid(4, 4);
        let path = wavefront_cpp(&grid, (3, 3), (0, 0), &WavefrontCppConfig::default());

        assert_eq!(
            distinct_cells(&path).len(),
            path.len(),
            "Path should not contain duplicates"
        );
    }

    #[test]
    fn test_search_order_all_quadrants() {
        // Each relative placement of start and goal must yield 8 distinct moves.
        let placements = [
            ((3, 3), (0, 0)),
            ((0, 0), (3, 3)),
            ((3, 0), (0, 3)),
            ((0, 3), (3, 0)),
        ];
        for (start, goal) in placements {
            let moves: HashSet<_> = search_order(start, goal).iter().copied().collect();
            assert_eq!(moves.len(), 8);
        }
    }
}