// Source: rust_robotics_planning/spiral_spanning_tree_cpp.rs

//! Spiral Spanning Tree Coverage Path Planner
//!
//! Implements the Spiral-STC algorithm for coverage path planning on a grid.
//!
//! Reference paper: "Spiral-STC: An On-Line Coverage Algorithm of Grid Environments
//! by a Mobile Robot" by Y. Gabriely and E. Rimon.
//! <https://ieeexplore.ieee.org/abstract/document/1013479>
//!
//! The algorithm works by:
//! 1. Merging the original grid into 2x2 mega-cells.
//! 2. Building a spanning tree over the free mega-cells using a recursive DFS.
//! 3. Tracing a spiral coverage path around the spanning tree edges at original resolution.

14use std::collections::HashSet;
15
/// Position `(row, col)` of a 2x2 mega-cell on the merged (half-resolution) grid.
pub type MergedNode = (i32, i32);

/// Position `[row, col]` of a single cell on the original (full-resolution) grid.
pub type SubNode = [i32; 2];

/// Directed spanning-tree edge, stored as `(from, to)` merged nodes.
pub type TreeEdge = (MergedNode, MergedNode);

/// One piece of the coverage path: its `[entry, exit]` sub-nodes at original resolution.
pub type PathSegment = [SubNode; 2];
27
/// Compass heading of a single step between two adjacent merged nodes.
///
/// Row indices grow towards the south and column indices towards the east.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum Direction {
    /// The row index decreases.
    North,
    /// The row index increases.
    South,
    /// The column index increases.
    East,
    /// The column index decreases.
    West,
}
36
37/// Result of the coverage path planning.
38#[derive(Debug, Clone)]
39pub struct CoveragePlanResult {
40    /// Spanning tree edges on the merged grid.
41    pub edges: Vec<TreeEdge>,
42    /// Route of merged-grid nodes visited (including backtrace duplicates).
43    pub route: Vec<MergedNode>,
44    /// Coverage path segments at original grid resolution.
45    pub path: Vec<PathSegment>,
46}
47
/// Boolean occupancy map consumed by the Spiral-STC planner.
///
/// Cells live in one flat, row-major vector in which `true` marks a free cell and `false`
/// an occupied one. Both dimensions have to be even so that the cells can be grouped into
/// 2x2 mega-cells.
#[derive(Clone, Debug)]
pub struct OccupancyGrid {
    /// Number of columns.
    width: usize,
    /// Number of rows.
    height: usize,
    /// Flat cell storage: cell `(row, col)` is at `row * width + col`; `true` means free.
    data: Vec<bool>,
}

impl OccupancyGrid {
    /// Build a grid from its dimensions and row-major cell data.
    ///
    /// # Panics
    /// Panics when `width` or `height` is odd, or when `data` does not hold exactly
    /// `width * height` cells.
    pub fn new(width: usize, height: usize, data: Vec<bool>) -> Self {
        assert!(width % 2 == 0, "width must be even, got {width}");
        assert!(height % 2 == 0, "height must be even, got {height}");
        let expected_cells = width * height;
        assert_eq!(data.len(), expected_cells);
        OccupancyGrid {
            width,
            height,
            data,
        }
    }

    /// Build a grid of the given size in which every cell is free.
    ///
    /// # Panics
    /// Panics when `width` or `height` is odd.
    pub fn all_free(width: usize, height: usize) -> Self {
        let cells = vec![true; width * height];
        Self::new(width, height, cells)
    }

    /// Return whether the cell at (`row`, `col`) is free; panics on an out-of-range index.
    fn get(&self, row: usize, col: usize) -> bool {
        let index = row * self.width + col;
        self.data[index]
    }
}
88
89/// Spiral Spanning Tree Coverage Path Planner.
90pub struct SpiralSpanningTreePlanner {
91    occ: OccupancyGrid,
92    merged_height: usize,
93    merged_width: usize,
94}
95
96impl SpiralSpanningTreePlanner {
97    /// Create a planner from an occupancy grid.
98    pub fn new(occ: OccupancyGrid) -> Self {
99        let merged_height = occ.height / 2;
100        let merged_width = occ.width / 2;
101        Self {
102            occ,
103            merged_height,
104            merged_width,
105        }
106    }
107
108    /// Plan a coverage path starting from the given merged-grid node `(row, col)`.
109    ///
110    /// Returns the spanning tree edges, the route on the merged grid, and the
111    /// coverage path segments at original resolution.
112    pub fn plan(&self, start: MergedNode) -> CoveragePlanResult {
113        let mh = self.merged_height;
114        let mw = self.merged_width;
115
116        let mut visit_times = vec![vec![0u8; mw]; mh];
117        visit_times[start.0 as usize][start.1 as usize] = 1;
118
119        let mut edges = Vec::new();
120        let mut route = Vec::new();
121
122        self.build_spanning_tree(start, &mut visit_times, &mut route, &mut edges);
123
124        // Generate coverage path from the route.
125        let mut path: Vec<PathSegment> = Vec::new();
126        for idx in 0..route.len().saturating_sub(1) {
127            let cur = route[idx];
128            let next = route[idx + 1];
129            let dp = (cur.0 - next.0).unsigned_abs() + (cur.1 - next.1).unsigned_abs();
130
131            match dp {
132                0 => {
133                    // Round-trip: node revisited during backtrace.
134                    if idx > 0 {
135                        let seg = self.round_trip_path(route[idx - 1], cur);
136                        path.push(seg);
137                    }
138                }
139                1 => {
140                    path.push(self.move_segment(cur, next));
141                }
142                2 => {
143                    // Non-adjacent route nodes: insert intermediate node from spanning tree.
144                    let mid = self.intermediate_node(cur, next, &edges);
145                    path.push(self.move_segment(cur, mid));
146                    path.push(self.move_segment(mid, next));
147                }
148                _ => panic!("adjacent route node distance > 2: {dp}"),
149            }
150        }
151
152        CoveragePlanResult { edges, route, path }
153    }
154
155    /// Check whether a merged-grid cell is valid (in bounds and all 4 sub-cells free).
156    fn is_valid_merged(&self, i: i32, j: i32) -> bool {
157        if i < 0 || j < 0 {
158            return false;
159        }
160        let (ui, uj) = (i as usize, j as usize);
161        if ui >= self.merged_height || uj >= self.merged_width {
162            return false;
163        }
164        let r = 2 * ui;
165        let c = 2 * uj;
166        self.occ.get(r, c)
167            && self.occ.get(r + 1, c)
168            && self.occ.get(r, c + 1)
169            && self.occ.get(r + 1, c + 1)
170    }
171
172    /// Recursive DFS to build spanning tree and route.
173    fn build_spanning_tree(
174        &self,
175        current: MergedNode,
176        visit_times: &mut [Vec<u8>],
177        route: &mut Vec<MergedNode>,
178        edges: &mut Vec<TreeEdge>,
179    ) {
180        // Counter-clockwise neighbor order: S, E, N, W
181        const ORDER: [(i32, i32); 4] = [(1, 0), (0, 1), (-1, 0), (0, -1)];
182
183        route.push(current);
184        let mut found = false;
185
186        for &(di, dj) in &ORDER {
187            let ni = current.0 + di;
188            let nj = current.1 + dj;
189            if self.is_valid_merged(ni, nj) && visit_times[ni as usize][nj as usize] == 0 {
190                let neighbor = (ni, nj);
191                edges.push((current, neighbor));
192                found = true;
193                visit_times[ni as usize][nj as usize] = 1;
194                self.build_spanning_tree(neighbor, visit_times, route, edges);
195            }
196        }
197
198        // Backtrace from dead-end to first node with unvisited neighbor.
199        if !found {
200            let mut has_unvisited_ngb = false;
201            for node in route.clone().iter().rev() {
202                if visit_times[node.0 as usize][node.1 as usize] == 2 {
203                    continue;
204                }
205                visit_times[node.0 as usize][node.1 as usize] += 1;
206                route.push(*node);
207
208                for &(di, dj) in &ORDER {
209                    let ni = node.0 + di;
210                    let nj = node.1 + dj;
211                    if self.is_valid_merged(ni, nj) && visit_times[ni as usize][nj as usize] == 0 {
212                        has_unvisited_ngb = true;
213                        break;
214                    }
215                }
216                if has_unvisited_ngb {
217                    break;
218                }
219            }
220        }
221    }
222
223    /// Determine the cardinal direction from merged node `p` to merged node `q`.
224    fn direction(p: MergedNode, q: MergedNode) -> Direction {
225        if p.0 == q.0 && p.1 < q.1 {
226            Direction::East
227        } else if p.0 == q.0 && p.1 > q.1 {
228            Direction::West
229        } else if p.0 < q.0 && p.1 == q.1 {
230            Direction::South
231        } else if p.0 > q.0 && p.1 == q.1 {
232            Direction::North
233        } else {
234            panic!(
235                "direction: only cardinal directions supported, got {:?} -> {:?}",
236                p, q
237            );
238        }
239    }
240
241    /// Convert a merged node to one of its four sub-nodes at original resolution.
242    fn sub_node(node: MergedNode, quad: &str) -> SubNode {
243        let (r, c) = (node.0, node.1);
244        match quad {
245            "SE" => [2 * r + 1, 2 * c + 1],
246            "SW" => [2 * r + 1, 2 * c],
247            "NE" => [2 * r, 2 * c + 1],
248            "NW" => [2 * r, 2 * c],
249            _ => panic!("sub_node: invalid quadrant '{quad}'"),
250        }
251    }
252
253    /// Compute the path segment when moving from merged node `p` to adjacent `q`.
254    fn move_segment(&self, p: MergedNode, q: MergedNode) -> PathSegment {
255        match Self::direction(p, q) {
256            Direction::East => [Self::sub_node(p, "SE"), Self::sub_node(q, "SW")],
257            Direction::West => [Self::sub_node(p, "NW"), Self::sub_node(q, "NE")],
258            Direction::South => [Self::sub_node(p, "SW"), Self::sub_node(q, "NW")],
259            Direction::North => [Self::sub_node(p, "NE"), Self::sub_node(q, "SE")],
260        }
261    }
262
263    /// Compute the round-trip path segment when backtracing at a pivot node.
264    fn round_trip_path(&self, last: MergedNode, pivot: MergedNode) -> PathSegment {
265        match Self::direction(last, pivot) {
266            Direction::East => [Self::sub_node(pivot, "SE"), Self::sub_node(pivot, "NE")],
267            Direction::South => [Self::sub_node(pivot, "SW"), Self::sub_node(pivot, "SE")],
268            Direction::West => [Self::sub_node(pivot, "NW"), Self::sub_node(pivot, "SW")],
269            Direction::North => [Self::sub_node(pivot, "NE"), Self::sub_node(pivot, "NW")],
270        }
271    }
272
273    /// Find the intermediate node between two non-adjacent route nodes
274    /// by looking for a shared neighbor in the spanning tree.
275    fn intermediate_node(&self, p: MergedNode, q: MergedNode, edges: &[TreeEdge]) -> MergedNode {
276        let mut p_ngb = HashSet::new();
277        let mut q_ngb = HashSet::new();
278
279        for &(m, n) in edges {
280            if m == p {
281                p_ngb.insert(n);
282            }
283            if n == p {
284                p_ngb.insert(m);
285            }
286            if m == q {
287                q_ngb.insert(n);
288            }
289            if n == q {
290                q_ngb.insert(m);
291            }
292        }
293
294        let intersection: Vec<_> = p_ngb.intersection(&q_ngb).copied().collect();
295        assert!(
296            intersection.len() == 1,
297            "expected exactly 1 intermediate node between {p:?} and {q:?}, found {}",
298            intersection.len()
299        );
300        intersection[0]
301    }
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    /// Plan over an obstacle-free grid of `height` x `width` original cells.
    fn plan_all_free(height: usize, width: usize, start: MergedNode) -> CoveragePlanResult {
        SpiralSpanningTreePlanner::new(OccupancyGrid::all_free(width, height)).plan(start)
    }

    /// Collect the distinct merged nodes that appear anywhere in the route.
    fn visited_nodes(result: &CoveragePlanResult) -> HashSet<MergedNode> {
        result.route.iter().copied().collect()
    }

    #[test]
    fn test_small_2x2_grid() {
        // A 2x2 original grid collapses into a single merged node, so there are no edges.
        // The route still has 2 entries because the backtrace appends the node once more.
        let CoveragePlanResult { edges, route, path } = plan_all_free(2, 2, (0, 0));
        assert_eq!(route.len(), 2);
        assert!(edges.is_empty());
        assert!(path.is_empty());
    }

    #[test]
    fn test_4x4_grid_visits_all_merged_cells() {
        // 4x4 original cells give a 2x2 merged grid, i.e. 4 merged cells to visit.
        let visited = visited_nodes(&plan_all_free(4, 4, (0, 0)));

        for (r, c) in (0..2).flat_map(|r| (0..2).map(move |c| (r, c))) {
            assert!(
                visited.contains(&(r, c)),
                "merged cell ({r}, {c}) was not visited"
            );
        }
    }

    #[test]
    fn test_spanning_tree_edge_count() {
        // A spanning tree over the 4 merged cells of a 4x4 grid has 4 - 1 = 3 edges.
        let edge_count = plan_all_free(4, 4, (0, 0)).edges.len();
        assert_eq!(edge_count, 3);
    }

    #[test]
    fn test_6x6_grid_all_merged_cells_visited() {
        // 6x6 original cells give a 3x3 merged grid, i.e. 9 merged cells.
        let result = plan_all_free(6, 6, (0, 0));
        let visited = visited_nodes(&result);

        for (r, c) in (0..3).flat_map(|r| (0..3).map(move |c| (r, c))) {
            assert!(
                visited.contains(&(r, c)),
                "merged cell ({r}, {c}) not visited"
            );
        }
        // 9 nodes need 8 spanning tree edges.
        assert_eq!(result.edges.len(), 8);
    }

    #[test]
    fn test_path_segments_use_original_resolution() {
        let result = plan_all_free(4, 4, (0, 0));
        // Every end point of every segment has to lie inside the 4x4 original grid.
        for pt in result.path.iter().flatten() {
            assert!((0..4).contains(&pt[0]), "row out of bounds: {}", pt[0]);
            assert!((0..4).contains(&pt[1]), "col out of bounds: {}", pt[1]);
        }
    }

    #[test]
    fn test_obstacle_blocks_merged_cell() {
        // Blocking original cell (0,0) of a 4x4 grid invalidates merged cell (0,0).
        let data: Vec<bool> = (0..16).map(|cell| cell != 0).collect();
        let planner = SpiralSpanningTreePlanner::new(OccupancyGrid::new(4, 4, data));

        // Merged cell (0,1) is still free, so planning starts there.
        let visited = visited_nodes(&planner.plan((0, 1)));

        // The merged cell containing the blocked sub-cell must never show up in the route.
        assert!(
            !visited.contains(&(0, 0)),
            "blocked merged cell (0,0) should not be visited"
        );
    }

    #[test]
    fn test_path_not_empty_for_multi_cell() {
        let path = plan_all_free(4, 4, (0, 0)).path;
        assert!(
            !path.is_empty(),
            "path should not be empty for a multi-cell grid"
        );
    }

    #[test]
    fn test_edges_connect_adjacent_nodes() {
        let result = plan_all_free(6, 6, (1, 1));
        for &(a, b) in result.edges.iter() {
            let (d_row, d_col) = (a.0 - b.0, a.1 - b.1);
            let dist = d_row.unsigned_abs() + d_col.unsigned_abs();
            assert_eq!(dist, 1, "edge {a:?}-{b:?} is not between adjacent nodes");
        }
    }

    #[test]
    fn test_different_start_positions() {
        // Planning has to cover the whole grid no matter which free merged cell it starts in.
        let starts = [(0, 0), (1, 0), (0, 1), (1, 1)];
        for &start in starts.iter() {
            let visited = visited_nodes(&plan_all_free(4, 4, start));
            assert_eq!(
                visited.len(),
                4,
                "all 4 merged cells should be visited from start {start:?}"
            );
        }
    }

    #[test]
    fn test_larger_grid_coverage() {
        // 10x10 original cells give a 5x5 merged grid: 25 nodes and 24 tree edges.
        let result = plan_all_free(10, 10, (0, 0));
        assert_eq!(visited_nodes(&result).len(), 25);
        assert_eq!(result.edges.len(), 24);
    }

    #[test]
    #[should_panic(expected = "width must be even")]
    fn test_odd_width_panics() {
        let cells = vec![true; 12];
        OccupancyGrid::new(3, 4, cells);
    }

    #[test]
    #[should_panic(expected = "height must be even")]
    fn test_odd_height_panics() {
        let cells = vec![true; 12];
        OccupancyGrid::new(4, 3, cells);
    }

    #[test]
    fn test_sub_node_coordinates() {
        // Expected sub-node of merged node (1, 2) for each quadrant.
        let expected = [
            ("NW", [2, 4]),
            ("NE", [2, 5]),
            ("SW", [3, 4]),
            ("SE", [3, 5]),
        ];
        for &(quad, coords) in expected.iter() {
            assert_eq!(SpiralSpanningTreePlanner::sub_node((1, 2), quad), coords);
        }
    }
}