Skip to main content

rust_robotics_planning/
a_star.rs

1//! A* path planning algorithm
2//!
3//! Grid-based path planning using A* search algorithm with
4//! configurable heuristic weight.
5
6use std::cmp::Ordering;
7use std::collections::{BinaryHeap, HashMap};
8
9use crate::grid::{GridMap, Node};
10use rust_robotics_core::{Obstacles, Path2D, PathPlanner, Point2D, RoboticsError, RoboticsResult};
11
/// Configuration for A* planner
///
/// The fields are plain public values; call [`AStarConfig::validate`] to check
/// them. The planner constructors in this file run that check before building
/// the grid map.
#[derive(Debug, Clone)]
pub struct AStarConfig {
    /// Grid resolution in meters. `validate` requires a finite value
    /// strictly greater than zero.
    pub resolution: f64,
    /// Robot radius for obstacle inflation. `validate` requires a finite,
    /// non-negative value (zero is accepted).
    pub robot_radius: f64,
    /// Heuristic weight (1.0 = optimal, >1.0 = faster but suboptimal).
    /// `validate` requires a finite value strictly greater than zero.
    pub heuristic_weight: f64,
}
22
23impl Default for AStarConfig {
24    fn default() -> Self {
25        Self {
26            resolution: 1.0,
27            robot_radius: 0.5,
28            heuristic_weight: 1.0,
29        }
30    }
31}
32
33impl AStarConfig {
34    pub fn validate(&self) -> RoboticsResult<()> {
35        if !self.resolution.is_finite() || self.resolution <= 0.0 {
36            return Err(RoboticsError::InvalidParameter(format!(
37                "resolution must be positive and finite, got {}",
38                self.resolution
39            )));
40        }
41        if !self.robot_radius.is_finite() || self.robot_radius < 0.0 {
42            return Err(RoboticsError::InvalidParameter(format!(
43                "robot_radius must be non-negative and finite, got {}",
44                self.robot_radius
45            )));
46        }
47        if !self.heuristic_weight.is_finite() || self.heuristic_weight <= 0.0 {
48            return Err(RoboticsError::InvalidParameter(format!(
49                "heuristic_weight must be positive and finite, got {}",
50                self.heuristic_weight
51            )));
52        }
53
54        Ok(())
55    }
56}
57
/// Node with priority for A* open set (min-heap)
///
/// `std::collections::BinaryHeap` is a max-heap, so the ordering below is
/// reversed: the node with the *smallest* `priority` compares as the greatest
/// and is popped first.
///
/// Ordering and equality look only at `priority`. They are built on
/// [`f64::total_cmp`], which gives a genuine total order even if a NaN ever
/// reaches the heap, so the `Ord`/`Eq` contracts the heap relies on always hold.
#[derive(Debug)]
struct PriorityNode {
    /// Grid x index of the node.
    x: i32,
    /// Grid y index of the node.
    y: i32,
    /// Accumulated path cost from the start node.
    cost: f64,
    /// Heap key: accumulated cost plus the weighted heuristic.
    priority: f64,
    /// Position of the matching entry in the planner's node storage.
    index: usize,
}

impl Eq for PriorityNode {}

impl PartialEq for PriorityNode {
    fn eq(&self, other: &Self) -> bool {
        // Derived from `cmp` so that `==` and `Ord` can never disagree.
        self.cmp(other) == Ordering::Equal
    }
}

impl Ord for PriorityNode {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reverse ordering for min-heap behavior
        other.priority.total_cmp(&self.priority)
    }
}

impl PartialOrd for PriorityNode {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
91
/// A* path planner
///
/// Bundles the grid map built from the obstacle set, the planner
/// configuration, and the neighbor motion model used when expanding nodes.
pub struct AStarPlanner {
    /// Grid the search runs on; also used to convert between world
    /// coordinates and grid indices.
    grid_map: GridMap,
    /// Planner settings; the constructors validate them before storing.
    config: AStarConfig,
    /// Neighbor moves as `(dx, dy, step cost)` tuples.
    motion: Vec<(i32, i32, f64)>,
}
98
/// Search-effort statistics collected during one A* query.
///
/// All counters start at zero (`Default`) and are filled in by the planner's
/// search loop.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct AStarSearchStats {
    /// Nodes popped from the open set that were not already closed; the
    /// final pop of the goal node is included.
    pub expanded_nodes: usize,
    /// Neighbor nodes pushed onto the open set. The start node is not counted.
    pub generated_nodes: usize,
    /// Nodes popped from the open set and discarded because their grid cell
    /// was already closed.
    pub skipped_closed_nodes: usize,
    /// Largest open-set length observed right after a push.
    pub max_frontier_len: usize,
}
107
108impl AStarPlanner {
109    /// Create a new A* planner with obstacle positions
110    pub fn new(ox: &[f64], oy: &[f64], config: AStarConfig) -> Self {
111        Self::try_new(ox, oy, config).expect(
112            "invalid A* planner input: obstacle list must be non-empty and valid, and config values must be positive/finite",
113        )
114    }
115
116    /// Create a validated A* planner with obstacle positions
117    pub fn try_new(ox: &[f64], oy: &[f64], config: AStarConfig) -> RoboticsResult<Self> {
118        config.validate()?;
119        let grid_map = GridMap::try_new(ox, oy, config.resolution, config.robot_radius)?;
120        let motion = Self::get_motion_model();
121
122        Ok(AStarPlanner {
123            grid_map,
124            config,
125            motion,
126        })
127    }
128
129    /// Create from obstacle x/y vectors with default config
130    pub fn from_obstacles(ox: &[f64], oy: &[f64], resolution: f64, robot_radius: f64) -> Self {
131        let config = AStarConfig {
132            resolution,
133            robot_radius,
134            ..Default::default()
135        };
136        Self::new(ox, oy, config)
137    }
138
139    /// Create a validated A* planner from typed obstacle points
140    pub fn from_obstacle_points(
141        obstacles: &Obstacles,
142        config: AStarConfig,
143    ) -> RoboticsResult<Self> {
144        config.validate()?;
145        let grid_map = GridMap::from_obstacles(obstacles, config.resolution, config.robot_radius)?;
146        let motion = Self::get_motion_model();
147
148        Ok(AStarPlanner {
149            grid_map,
150            config,
151            motion,
152        })
153    }
154
155    /// Plan a path returning (rx, ry) vectors (legacy interface)
156    #[deprecated(note = "use plan() or plan_xy() instead")]
157    pub fn planning(&self, sx: f64, sy: f64, gx: f64, gy: f64) -> Option<(Vec<f64>, Vec<f64>)> {
158        match self.plan_xy(sx, sy, gx, gy) {
159            Ok(path) => Some((path.x_coords(), path.y_coords())),
160            Err(_) => None,
161        }
162    }
163
164    /// Plan a path without requiring the PathPlanner trait in scope
165    pub fn plan(&self, start: Point2D, goal: Point2D) -> RoboticsResult<Path2D> {
166        self.plan_impl(start, goal).map(|(path, _stats)| path)
167    }
168
169    /// Plan a path from raw coordinates without requiring the PathPlanner trait in scope
170    pub fn plan_xy(&self, sx: f64, sy: f64, gx: f64, gy: f64) -> RoboticsResult<Path2D> {
171        self.plan_impl(Point2D::new(sx, sy), Point2D::new(gx, gy))
172            .map(|(path, _stats)| path)
173    }
174
175    /// Plan a path and return per-query search-effort statistics.
176    pub fn plan_with_stats(
177        &self,
178        start: Point2D,
179        goal: Point2D,
180    ) -> RoboticsResult<(Path2D, AStarSearchStats)> {
181        self.plan_impl(start, goal)
182    }
183
184    /// Get reference to the grid map
185    pub fn grid_map(&self) -> &GridMap {
186        &self.grid_map
187    }
188
189    fn calc_heuristic(&self, n1_x: i32, n1_y: i32, n2_x: i32, n2_y: i32) -> f64 {
190        self.config.heuristic_weight * (((n1_x - n2_x).pow(2) + (n1_y - n2_y).pow(2)) as f64).sqrt()
191    }
192
193    fn get_motion_model() -> Vec<(i32, i32, f64)> {
194        // dx, dy, cost (8-connected grid)
195        vec![
196            (1, 0, 1.0),
197            (0, 1, 1.0),
198            (-1, 0, 1.0),
199            (0, -1, 1.0),
200            (-1, -1, std::f64::consts::SQRT_2),
201            (-1, 1, std::f64::consts::SQRT_2),
202            (1, -1, std::f64::consts::SQRT_2),
203            (1, 1, std::f64::consts::SQRT_2),
204        ]
205    }
206
207    fn build_path(&self, goal_index: usize, node_storage: &[Node]) -> Path2D {
208        let mut points = Vec::new();
209        let mut current_index = Some(goal_index);
210
211        while let Some(index) = current_index {
212            let node = &node_storage[index];
213            points.push(Point2D::new(
214                self.grid_map.calc_x_position(node.x),
215                self.grid_map.calc_y_position(node.y),
216            ));
217            current_index = node.parent_index;
218        }
219
220        points.reverse();
221        Path2D::from_points(points)
222    }
223
224    fn ensure_query_is_valid(&self, x: i32, y: i32, label: &str) -> RoboticsResult<()> {
225        if self.grid_map.is_valid(x, y) {
226            return Ok(());
227        }
228
229        Err(RoboticsError::PlanningError(format!(
230            "{} position is invalid",
231            label
232        )))
233    }
234
235    fn plan_impl(
236        &self,
237        start: Point2D,
238        goal: Point2D,
239    ) -> RoboticsResult<(Path2D, AStarSearchStats)> {
240        let start_x = self.grid_map.calc_x_index(start.x);
241        let start_y = self.grid_map.calc_y_index(start.y);
242        let goal_x = self.grid_map.calc_x_index(goal.x);
243        let goal_y = self.grid_map.calc_y_index(goal.y);
244
245        self.ensure_query_is_valid(start_x, start_y, "Start")?;
246        self.ensure_query_is_valid(goal_x, goal_y, "Goal")?;
247
248        let mut open_set = BinaryHeap::new();
249        let mut closed_set = HashMap::new();
250        let mut node_storage: Vec<Node> = Vec::new();
251        let mut stats = AStarSearchStats::default();
252
253        // Add start node
254        node_storage.push(Node::new(start_x, start_y, 0.0, None));
255        let start_index = 0;
256
257        open_set.push(PriorityNode {
258            x: start_x,
259            y: start_y,
260            cost: 0.0,
261            priority: self.calc_heuristic(start_x, start_y, goal_x, goal_y),
262            index: start_index,
263        });
264        stats.max_frontier_len = open_set.len();
265
266        while let Some(current) = open_set.pop() {
267            let current_grid_index = self.grid_map.calc_index(current.x, current.y);
268
269            // Skip if already in closed set
270            if closed_set.contains_key(&current_grid_index) {
271                stats.skipped_closed_nodes += 1;
272                continue;
273            }
274
275            stats.expanded_nodes += 1;
276
277            // Check if we reached the goal
278            if current.x == goal_x && current.y == goal_y {
279                return Ok((self.build_path(current.index, &node_storage), stats));
280            }
281
282            // Move to closed set
283            closed_set.insert(current_grid_index, current.index);
284
285            // Expand neighbors
286            for &(dx, dy, move_cost) in &self.motion {
287                let new_x = current.x + dx;
288                let new_y = current.y + dy;
289                let new_cost = current.cost + move_cost;
290                let new_grid_index = self.grid_map.calc_index(new_x, new_y);
291
292                // Skip if not valid or already visited
293                if !self.grid_map.is_valid_offset(current.x, current.y, dx, dy) {
294                    continue;
295                }
296                if closed_set.contains_key(&new_grid_index) {
297                    continue;
298                }
299
300                // Add to storage and open set
301                node_storage.push(Node::new(new_x, new_y, new_cost, Some(current.index)));
302                let new_index = node_storage.len() - 1;
303                stats.generated_nodes += 1;
304
305                let priority = new_cost + self.calc_heuristic(new_x, new_y, goal_x, goal_y);
306                open_set.push(PriorityNode {
307                    x: new_x,
308                    y: new_y,
309                    cost: new_cost,
310                    priority,
311                    index: new_index,
312                });
313                stats.max_frontier_len = stats.max_frontier_len.max(open_set.len());
314            }
315        }
316
317        Err(RoboticsError::PlanningError("No path found".to_string()))
318    }
319}
320
321impl PathPlanner for AStarPlanner {
322    fn plan(&self, start: Point2D, goal: Point2D) -> Result<Path2D, RoboticsError> {
323        self.plan_impl(start, goal).map(|(path, _stats)| path)
324    }
325}
326
#[cfg(test)]
mod tests {
    use super::*;
    use crate::moving_ai::{MovingAiMap, MovingAiScenario};

    use rust_robotics_core::Obstacles;

    /// Obstacle fixture: a closed square boundary spanning 0..=10 on both
    /// axes plus a short vertical wall at x = 5 covering y = 4, 5, 6.
    fn create_simple_obstacles() -> (Vec<f64>, Vec<f64>) {
        // Boundary: four points per step, one on each side of the square.
        let boundary = (0..11).flat_map(|i: i32| {
            let v = f64::from(i);
            [(v, 0.0), (v, 10.0), (0.0, v), (10.0, v)]
        });

        // Internal obstacle
        let inner_wall = (4..7).map(|i: i32| (5.0, f64::from(i)));

        boundary.chain(inner_wall).unzip()
    }

    /// The planner finds some non-empty path across the simple fixture.
    #[test]
    fn test_a_star_finds_path() {
        let (ox, oy) = create_simple_obstacles();
        let planner = AStarPlanner::from_obstacles(&ox, &oy, 1.0, 0.5);

        let result = planner.plan(Point2D::new(2.0, 2.0), Point2D::new(8.0, 8.0));
        assert!(result.is_ok());

        let path = result.unwrap();
        assert!(!path.is_empty());
    }

    /// The deprecated tuple-of-vectors interface still returns a path with
    /// matching x and y lengths.
    #[test]
    #[allow(deprecated)]
    fn test_a_star_legacy_interface() {
        let (ox, oy) = create_simple_obstacles();
        let planner = AStarPlanner::from_obstacles(&ox, &oy, 1.0, 0.5);

        let legacy = planner.planning(2.0, 2.0, 8.0, 8.0);
        assert!(legacy.is_some());

        let (xs, ys) = legacy.unwrap();
        assert!(!xs.is_empty());
        assert_eq!(xs.len(), ys.len());
    }

    /// Building the planner from typed obstacle points also yields a path.
    #[test]
    fn test_a_star_from_obstacle_points() {
        let (ox, oy) = create_simple_obstacles();
        let obstacles = Obstacles::try_from_xy(&ox, &oy).unwrap();
        let config = AStarConfig::default();
        let planner = AStarPlanner::from_obstacle_points(&obstacles, config).unwrap();

        let path = planner.plan_xy(2.0, 2.0, 8.0, 8.0).unwrap();
        assert!(!path.is_empty());
    }

    /// A zero heuristic weight must be rejected as an invalid parameter.
    #[test]
    fn test_a_star_try_new_rejects_invalid_config() {
        let (ox, oy) = create_simple_obstacles();
        let bad_config = AStarConfig {
            heuristic_weight: 0.0,
            ..Default::default()
        };

        let outcome = AStarPlanner::try_new(&ox, &oy, bad_config);
        let err = match outcome {
            Err(err) => err,
            Ok(_) => panic!("expected invalid config to be rejected"),
        };
        assert!(matches!(err, RoboticsError::InvalidParameter(_)));
    }

    /// With a map that is not anchored at the origin, the returned path must
    /// still start and end at the requested world coordinates.
    #[test]
    fn test_a_star_preserves_asymmetric_world_coordinates() {
        let mut obstacles = Obstacles::new();

        // Horizontal walls at y = -4 and y = 6.
        for x in 10..=20 {
            let wx = x as f64;
            obstacles.push(Point2D::new(wx, -4.0));
            obstacles.push(Point2D::new(wx, 6.0));
        }
        // Vertical walls at x = 10 and x = 20.
        for y in -4..=6 {
            let wy = y as f64;
            obstacles.push(Point2D::new(10.0, wy));
            obstacles.push(Point2D::new(20.0, wy));
        }

        let planner =
            AStarPlanner::from_obstacle_points(&obstacles, AStarConfig::default()).unwrap();
        let path = planner.plan_xy(12.0, -2.0, 18.0, 4.0).unwrap();

        let close = |actual: f64, expected: f64| (actual - expected).abs() < 1e-6;

        let first = path.points.first().unwrap();
        let last = path.points.last().unwrap();
        assert!(close(first.x, 12.0));
        assert!(close(first.y, -2.0));
        assert!(close(last.x, 18.0));
        assert!(close(last.y, 4.0));
    }

    /// With an unweighted heuristic the path length must equal the optimal
    /// length recorded in the MovingAI arena2 scenario file (bucket 80).
    #[test]
    #[ignore = "long-running MovingAI benchmark"]
    fn test_a_star_matches_moving_ai_arena2_bucket_80_optimal_length() {
        let map_text = include_str!("../benchdata/moving_ai/dao/arena2.map");
        let scenario_text = include_str!("../benchdata/moving_ai/dao/arena2.map.scen");

        let map = MovingAiMap::parse_str(map_text).expect("arena2 MovingAI map should parse");
        let scenario = MovingAiScenario::parse_str(scenario_text)
            .expect("arena2 MovingAI scenarios should parse")
            .into_iter()
            .find(|row| row.bucket == 80)
            .expect("arena2 MovingAI bucket 80 scenario should exist");

        let config = AStarConfig {
            resolution: 1.0,
            robot_radius: 0.5,
            heuristic_weight: 1.0,
        };
        let planner = AStarPlanner::from_obstacle_points(&map.to_obstacles(), config)
            .expect("A* planner should build from arena2 obstacles");

        let start = map
            .planning_point(scenario.start_x, scenario.start_y)
            .expect("arena2 start should be valid");
        let goal = map
            .planning_point(scenario.goal_x, scenario.goal_y)
            .expect("arena2 goal should be valid");

        let path = planner
            .plan(start, goal)
            .expect("A* should solve the arena2 bucket 80 scenario");

        let length_error = (path.total_length() - scenario.optimal_length).abs();
        assert!(
            length_error < 1e-6,
            "A* path length {} should match MovingAI optimal {} when corner cutting is disabled",
            path.total_length(),
            scenario.optimal_length
        );
    }
}
479}