Skip to main content

rust_robotics_planning/
theta_star.rs

1//! Theta* path planning algorithm
2//!
3//! Theta* is an any-angle path planning algorithm that extends A* by
4//! allowing paths to connect any two visible nodes, not just grid neighbors.
5//! This produces shorter, more natural paths compared to standard A*.
6//!
7//! Key features:
8//! - Line-of-sight checks to skip intermediate nodes
9//! - Produces any-angle paths (not restricted to grid directions)
//! - Near-optimal path lengths (not guaranteed to be the true shortest any-angle path)
11//!
12//! Reference: Nash, A., Daniel, K., Koenig, S., & Felner, A. (2007).
13//! "Theta*: Any-Angle Path Planning on Grids"
14
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, HashSet};

use crate::grid::{GridMap, Node};
use rust_robotics_core::{Obstacles, Path2D, PathPlanner, Point2D, RoboticsError, RoboticsResult};
20
/// Configuration for the Theta* planner.
///
/// Every field is checked by [`ThetaStarConfig::validate`] before a planner
/// is constructed from it.
#[derive(Debug, Clone, PartialEq)]
pub struct ThetaStarConfig {
    /// Grid cell size used to discretise obstacle and query coordinates.
    /// Must be positive and finite.
    pub resolution: f64,
    /// Robot radius forwarded to the grid map when it is built (obstacle
    /// clearance). Must be non-negative and finite.
    pub robot_radius: f64,
    /// Multiplier applied to the Euclidean distance-to-goal heuristic.
    /// `1.0` leaves the heuristic unweighted; values above `1.0` may make it
    /// overestimate. Must be positive and finite.
    pub heuristic_weight: f64,
}
28
29impl Default for ThetaStarConfig {
30    fn default() -> Self {
31        Self {
32            resolution: 1.0,
33            robot_radius: 0.5,
34            heuristic_weight: 1.0,
35        }
36    }
37}
38
39impl ThetaStarConfig {
40    pub fn validate(&self) -> RoboticsResult<()> {
41        if !self.resolution.is_finite() || self.resolution <= 0.0 {
42            return Err(RoboticsError::InvalidParameter(format!(
43                "resolution must be positive and finite, got {}",
44                self.resolution
45            )));
46        }
47        if !self.robot_radius.is_finite() || self.robot_radius < 0.0 {
48            return Err(RoboticsError::InvalidParameter(format!(
49                "robot_radius must be non-negative and finite, got {}",
50                self.robot_radius
51            )));
52        }
53        if !self.heuristic_weight.is_finite() || self.heuristic_weight <= 0.0 {
54            return Err(RoboticsError::InvalidParameter(format!(
55                "heuristic_weight must be positive and finite, got {}",
56                self.heuristic_weight
57            )));
58        }
59        Ok(())
60    }
61}
62
/// Open-list entry for the Theta* search heap.
///
/// `BinaryHeap` is a max-heap, so the ordering on `priority` is reversed to
/// pop the smallest priority first. Both ordering and equality use
/// `f64::total_cmp`, which is a total order (NaN included) and keeps `Eq`
/// and `Ord` consistent with each other.
#[derive(Debug)]
struct PriorityNode {
    /// Grid x index of the cell.
    x: i32,
    /// Grid y index of the cell.
    y: i32,
    /// Cost-to-come recorded when this entry was pushed.
    cost: f64,
    /// Heap key: `cost` plus the weighted heuristic.
    priority: f64,
    /// Index of the matching node in the planner's node arena.
    index: usize,
}

impl Eq for PriorityNode {}

impl PartialEq for PriorityNode {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so equality always agrees with the ordering.
        self.cmp(other) == Ordering::Equal
    }
}

impl Ord for PriorityNode {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reversed: the max-heap must yield the lowest priority first.
        other.priority.total_cmp(&self.priority)
    }
}

impl PartialOrd for PriorityNode {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
90
91pub struct ThetaStarPlanner {
92    grid_map: GridMap,
93    config: ThetaStarConfig,
94    motion: Vec<(i32, i32, f64)>,
95}
96
97impl ThetaStarPlanner {
98    pub fn new(ox: &[f64], oy: &[f64], config: ThetaStarConfig) -> Self {
99        Self::try_new(ox, oy, config).expect(
100            "invalid Theta* planner input: obstacle list must be non-empty and valid, and config values must be positive/finite",
101        )
102    }
103
104    pub fn try_new(ox: &[f64], oy: &[f64], config: ThetaStarConfig) -> RoboticsResult<Self> {
105        config.validate()?;
106        let grid_map = GridMap::try_new(ox, oy, config.resolution, config.robot_radius)?;
107        let motion = Self::get_motion_model();
108        Ok(ThetaStarPlanner {
109            grid_map,
110            config,
111            motion,
112        })
113    }
114
115    pub fn from_obstacles(ox: &[f64], oy: &[f64], resolution: f64, robot_radius: f64) -> Self {
116        let config = ThetaStarConfig {
117            resolution,
118            robot_radius,
119            ..Default::default()
120        };
121        Self::new(ox, oy, config)
122    }
123
124    pub fn from_obstacle_points(
125        obstacles: &Obstacles,
126        config: ThetaStarConfig,
127    ) -> RoboticsResult<Self> {
128        config.validate()?;
129        let grid_map = GridMap::from_obstacles(obstacles, config.resolution, config.robot_radius)?;
130        let motion = Self::get_motion_model();
131        Ok(ThetaStarPlanner {
132            grid_map,
133            config,
134            motion,
135        })
136    }
137
138    #[deprecated(note = "use plan() or plan_xy() instead")]
139    pub fn planning(&self, sx: f64, sy: f64, gx: f64, gy: f64) -> Option<(Vec<f64>, Vec<f64>)> {
140        match self.plan_xy(sx, sy, gx, gy) {
141            Ok(path) => Some((path.x_coords(), path.y_coords())),
142            Err(_) => None,
143        }
144    }
145
146    pub fn plan(&self, start: Point2D, goal: Point2D) -> RoboticsResult<Path2D> {
147        self.plan_impl(start, goal)
148    }
149
150    pub fn plan_xy(&self, sx: f64, sy: f64, gx: f64, gy: f64) -> RoboticsResult<Path2D> {
151        self.plan_impl(Point2D::new(sx, sy), Point2D::new(gx, gy))
152    }
153
154    pub fn grid_map(&self) -> &GridMap {
155        &self.grid_map
156    }
157
158    fn calc_heuristic(&self, n1_x: i32, n1_y: i32, n2_x: i32, n2_y: i32) -> f64 {
159        self.config.heuristic_weight * (((n1_x - n2_x).pow(2) + (n1_y - n2_y).pow(2)) as f64).sqrt()
160    }
161
162    fn get_motion_model() -> Vec<(i32, i32, f64)> {
163        vec![
164            (1, 0, 1.0),
165            (0, 1, 1.0),
166            (-1, 0, 1.0),
167            (0, -1, 1.0),
168            (-1, -1, std::f64::consts::SQRT_2),
169            (-1, 1, std::f64::consts::SQRT_2),
170            (1, -1, std::f64::consts::SQRT_2),
171            (1, 1, std::f64::consts::SQRT_2),
172        ]
173    }
174
175    fn line_of_sight(&self, x0: i32, y0: i32, x1: i32, y1: i32) -> bool {
176        if !self.grid_map.is_valid(x0, y0) || !self.grid_map.is_valid(x1, y1) {
177            return false;
178        }
179
180        if x0 == x1 && y0 == y1 {
181            return true;
182        }
183
184        let dx = x1 - x0;
185        let dy = y1 - y0;
186        let step_x = dx.signum();
187        let step_y = dy.signum();
188        let abs_dx = dx.abs() as f64;
189        let abs_dy = dy.abs() as f64;
190
191        let mut x = x0;
192        let mut y = y0;
193        let mut t_max_x = if step_x != 0 {
194            0.5 / abs_dx
195        } else {
196            f64::INFINITY
197        };
198        let mut t_max_y = if step_y != 0 {
199            0.5 / abs_dy
200        } else {
201            f64::INFINITY
202        };
203        let t_delta_x = if step_x != 0 {
204            1.0 / abs_dx
205        } else {
206            f64::INFINITY
207        };
208        let t_delta_y = if step_y != 0 {
209            1.0 / abs_dy
210        } else {
211            f64::INFINITY
212        };
213
214        while x != x1 || y != y1 {
215            let advance_x = t_max_x <= t_max_y;
216            let advance_y = t_max_y <= t_max_x;
217            let next_x = if advance_x { x + step_x } else { x };
218            let next_y = if advance_y { y + step_y } else { y };
219
220            if !self.grid_map.is_valid_step(x, y, next_x, next_y) {
221                return false;
222            }
223
224            x = next_x;
225            y = next_y;
226
227            if advance_x {
228                t_max_x += t_delta_x;
229            }
230            if advance_y {
231                t_max_y += t_delta_y;
232            }
233        }
234
235        true
236    }
237
238    fn euclidean_distance(&self, x1: i32, y1: i32, x2: i32, y2: i32) -> f64 {
239        (((x1 - x2).pow(2) + (y1 - y2).pow(2)) as f64).sqrt()
240    }
241
242    fn build_path(&self, goal_index: usize, node_storage: &[Node]) -> Path2D {
243        let mut points = Vec::new();
244        let mut current_index = Some(goal_index);
245        while let Some(index) = current_index {
246            let node = &node_storage[index];
247            points.push(Point2D::new(
248                self.grid_map.calc_x_position(node.x),
249                self.grid_map.calc_y_position(node.y),
250            ));
251            current_index = node.parent_index;
252        }
253        points.reverse();
254        Path2D::from_points(points)
255    }
256
257    fn ensure_query_is_valid(&self, x: i32, y: i32, label: &str) -> RoboticsResult<()> {
258        if self.grid_map.is_valid(x, y) {
259            return Ok(());
260        }
261        Err(RoboticsError::PlanningError(format!(
262            "{} position is invalid",
263            label
264        )))
265    }
266
267    fn plan_impl(&self, start: Point2D, goal: Point2D) -> RoboticsResult<Path2D> {
268        let start_x = self.grid_map.calc_x_index(start.x);
269        let start_y = self.grid_map.calc_y_index(start.y);
270        let goal_x = self.grid_map.calc_x_index(goal.x);
271        let goal_y = self.grid_map.calc_y_index(goal.y);
272
273        self.ensure_query_is_valid(start_x, start_y, "Start")?;
274        self.ensure_query_is_valid(goal_x, goal_y, "Goal")?;
275
276        let mut open_set = BinaryHeap::new();
277        let mut closed_set = HashMap::new();
278        let mut node_storage: Vec<Node> = Vec::new();
279        let mut g_values: HashMap<i32, f64> = HashMap::new();
280        let mut best_index: HashMap<i32, usize> = HashMap::new();
281
282        node_storage.push(Node::new(start_x, start_y, 0.0, None));
283        let start_index = 0;
284        let start_grid_index = self.grid_map.calc_index(start_x, start_y);
285        g_values.insert(start_grid_index, 0.0);
286        best_index.insert(start_grid_index, start_index);
287
288        open_set.push(PriorityNode {
289            x: start_x,
290            y: start_y,
291            cost: 0.0,
292            priority: self.calc_heuristic(start_x, start_y, goal_x, goal_y),
293            index: start_index,
294        });
295
296        while let Some(current) = open_set.pop() {
297            let current_grid_index = self.grid_map.calc_index(current.x, current.y);
298            if current.x == goal_x && current.y == goal_y {
299                return Ok(self.build_path(current.index, &node_storage));
300            }
301            if closed_set.contains_key(&current_grid_index) {
302                continue;
303            }
304            closed_set.insert(current_grid_index, current.index);
305
306            let current_node = &node_storage[current.index];
307            let parent_index = current_node.parent_index;
308
309            for &(dx, dy, _) in &self.motion {
310                let new_x = current.x + dx;
311                let new_y = current.y + dy;
312                let new_grid_index = self.grid_map.calc_index(new_x, new_y);
313                if !self.grid_map.is_valid_offset(current.x, current.y, dx, dy) {
314                    continue;
315                }
316                if closed_set.contains_key(&new_grid_index) {
317                    continue;
318                }
319
320                let (new_cost, new_parent_index) = if let Some(p_idx) = parent_index {
321                    let parent_node = &node_storage[p_idx];
322                    if self.line_of_sight(parent_node.x, parent_node.y, new_x, new_y) {
323                        let dist =
324                            self.euclidean_distance(parent_node.x, parent_node.y, new_x, new_y);
325                        (parent_node.cost + dist, Some(p_idx))
326                    } else {
327                        let dist = self.euclidean_distance(current.x, current.y, new_x, new_y);
328                        (current.cost + dist, Some(current.index))
329                    }
330                } else {
331                    let dist = self.euclidean_distance(current.x, current.y, new_x, new_y);
332                    (current.cost + dist, Some(current.index))
333                };
334
335                let existing_g = g_values
336                    .get(&new_grid_index)
337                    .copied()
338                    .unwrap_or(f64::INFINITY);
339                if new_cost < existing_g {
340                    g_values.insert(new_grid_index, new_cost);
341                    node_storage.push(Node::new(new_x, new_y, new_cost, new_parent_index));
342                    let new_index = node_storage.len() - 1;
343                    best_index.insert(new_grid_index, new_index);
344                    let priority = new_cost + self.calc_heuristic(new_x, new_y, goal_x, goal_y);
345                    open_set.push(PriorityNode {
346                        x: new_x,
347                        y: new_y,
348                        cost: new_cost,
349                        priority,
350                        index: new_index,
351                    });
352                }
353            }
354        }
355
356        Err(RoboticsError::PlanningError("No path found".to_string()))
357    }
358}
359
360impl PathPlanner for ThetaStarPlanner {
361    fn plan(&self, start: Point2D, goal: Point2D) -> Result<Path2D, RoboticsError> {
362        self.plan_impl(start, goal)
363    }
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use rust_robotics_core::Obstacles;

    /// 21x21 bordered map with a vertical wall at x = 10 spanning y = 5..15.
    fn create_simple_obstacles() -> (Vec<f64>, Vec<f64>) {
        let mut ox = Vec::new();
        let mut oy = Vec::new();
        for i in 0..21 {
            ox.push(i as f64);
            oy.push(0.0);
            ox.push(i as f64);
            oy.push(20.0);
            ox.push(0.0);
            oy.push(i as f64);
            ox.push(20.0);
            oy.push(i as f64);
        }
        for i in 5..15 {
            ox.push(10.0);
            oy.push(i as f64);
        }
        (ox, oy)
    }

    #[test]
    fn test_theta_star_finds_path() {
        let (ox, oy) = create_simple_obstacles();
        let planner = ThetaStarPlanner::from_obstacles(&ox, &oy, 1.0, 0.5);
        let result = planner.plan(Point2D::new(2.0, 10.0), Point2D::new(18.0, 10.0));
        assert!(result.is_ok());
        assert!(!result.unwrap().is_empty());
    }

    #[test]
    #[allow(deprecated)] // exercises the deprecated legacy API on purpose
    fn test_theta_star_legacy_interface() {
        let (ox, oy) = create_simple_obstacles();
        let planner = ThetaStarPlanner::from_obstacles(&ox, &oy, 1.0, 0.5);
        let result = planner.planning(2.0, 10.0, 18.0, 10.0);
        assert!(result.is_some());
        let (rx, ry) = result.unwrap();
        assert!(!rx.is_empty());
        assert_eq!(rx.len(), ry.len());
    }

    #[test]
    fn test_theta_star_shorter_than_a_star() {
        let (ox, oy) = create_simple_obstacles();
        let theta_planner = ThetaStarPlanner::from_obstacles(&ox, &oy, 1.0, 0.5);
        let a_star_planner = crate::a_star::AStarPlanner::from_obstacles(&ox, &oy, 1.0, 0.5);
        let start = Point2D::new(2.0, 2.0);
        let goal = Point2D::new(18.0, 18.0);
        let theta_path = theta_planner.plan(start, goal).unwrap();
        let a_star_path = a_star_planner.plan(start, goal).unwrap();
        let theta_length = theta_path.total_length();
        let a_star_length = a_star_path.total_length();
        assert!(
            theta_length <= a_star_length + 0.1,
            "Theta* path ({}) should not be significantly longer than A* path ({})",
            theta_length,
            a_star_length
        );
    }

    #[test]
    fn test_theta_star_start_equals_goal() {
        let (ox, oy) = create_simple_obstacles();
        let planner = ThetaStarPlanner::from_obstacles(&ox, &oy, 1.0, 0.5);
        let path = planner.plan_xy(5.0, 5.0, 5.0, 5.0).unwrap();
        assert!(!path.is_empty());
        assert!(path.total_length().abs() < 1e-9);
    }

    #[test]
    fn test_theta_star_rejects_out_of_bounds_queries() {
        let (ox, oy) = create_simple_obstacles();
        let planner = ThetaStarPlanner::from_obstacles(&ox, &oy, 1.0, 0.5);
        let bad_start = planner.plan_xy(-5.0, 10.0, 18.0, 10.0);
        assert!(matches!(bad_start, Err(RoboticsError::PlanningError(_))));
        let bad_goal = planner.plan_xy(2.0, 10.0, 50.0, 10.0);
        assert!(matches!(bad_goal, Err(RoboticsError::PlanningError(_))));
    }

    #[test]
    fn test_line_of_sight() {
        let (ox, oy) = create_simple_obstacles();
        let planner = ThetaStarPlanner::from_obstacles(&ox, &oy, 1.0, 0.5);
        assert!(planner.line_of_sight(2, 2, 5, 5));
        assert!(!planner.line_of_sight(5, 10, 15, 10));
    }

    #[test]
    fn test_line_of_sight_shallow_slope_in_open_space() {
        let (ox, oy) = create_simple_obstacles();
        let planner = ThetaStarPlanner::from_obstacles(&ox, &oy, 1.0, 0.5);
        // Non-45-degree segments exercise the traversal's boundary ordering.
        assert!(planner.line_of_sight(2, 2, 8, 4));
        assert!(planner.line_of_sight(8, 4, 2, 2));
    }

    #[test]
    fn test_line_of_sight_blocks_corner_cutting() {
        let open_obstacles = Obstacles::from_points(vec![
            Point2D::new(0.0, 0.0),
            Point2D::new(1.0, 0.0),
            Point2D::new(2.0, 0.0),
            Point2D::new(3.0, 0.0),
            Point2D::new(0.0, 1.0),
            Point2D::new(3.0, 1.0),
            Point2D::new(0.0, 2.0),
            Point2D::new(3.0, 2.0),
            Point2D::new(0.0, 3.0),
            Point2D::new(1.0, 3.0),
            Point2D::new(2.0, 3.0),
            Point2D::new(3.0, 3.0),
        ]);
        let open_planner =
            ThetaStarPlanner::from_obstacle_points(&open_obstacles, ThetaStarConfig::default())
                .unwrap();

        assert!(open_planner.line_of_sight(1, 1, 2, 1));

        let blocked_obstacles = Obstacles::from_points(vec![
            Point2D::new(0.0, 0.0),
            Point2D::new(1.0, 0.0),
            Point2D::new(2.0, 0.0),
            Point2D::new(3.0, 0.0),
            Point2D::new(0.0, 1.0),
            Point2D::new(3.0, 1.0),
            Point2D::new(0.0, 2.0),
            Point2D::new(3.0, 2.0),
            Point2D::new(0.0, 3.0),
            Point2D::new(1.0, 3.0),
            Point2D::new(2.0, 3.0),
            Point2D::new(3.0, 3.0),
            Point2D::new(1.0, 2.0),
            Point2D::new(2.0, 1.0),
        ]);
        let planner =
            ThetaStarPlanner::from_obstacle_points(&blocked_obstacles, ThetaStarConfig::default())
                .unwrap();

        assert!(!planner.line_of_sight(1, 1, 2, 2));
    }

    #[test]
    fn test_theta_star_from_obstacle_points() {
        let (ox, oy) = create_simple_obstacles();
        let obstacles = Obstacles::try_from_xy(&ox, &oy).unwrap();
        let planner =
            ThetaStarPlanner::from_obstacle_points(&obstacles, ThetaStarConfig::default()).unwrap();
        let path = planner.plan_xy(2.0, 10.0, 18.0, 10.0).unwrap();
        assert!(!path.is_empty());
    }

    #[test]
    fn test_theta_star_try_new_rejects_invalid_config() {
        let (ox, oy) = create_simple_obstacles();
        let config = ThetaStarConfig {
            heuristic_weight: 0.0,
            ..Default::default()
        };
        let err = match ThetaStarPlanner::try_new(&ox, &oy, config) {
            Ok(_) => panic!("expected invalid config to be rejected"),
            Err(err) => err,
        };
        assert!(matches!(err, RoboticsError::InvalidParameter(_)));
    }

    #[test]
    fn test_priority_node_pops_lowest_priority_first() {
        let mut heap = std::collections::BinaryHeap::new();
        for (index, &priority) in [3.0_f64, 1.0, 2.0].iter().enumerate() {
            heap.push(PriorityNode {
                x: 0,
                y: 0,
                cost: 0.0,
                priority,
                index,
            });
        }
        let order: Vec<usize> = std::iter::from_fn(|| heap.pop()).map(|n| n.index).collect();
        assert_eq!(order, vec![1, 2, 0]);
    }
}