use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap, HashSet};

use crate::grid::{GridMap, Node};
use rust_robotics_core::{Obstacles, Path2D, PathPlanner, Point2D, RoboticsError, RoboticsResult};

/// Tuning parameters for [`ThetaStarPlanner`].
///
/// Hand-built values can be checked with [`ThetaStarConfig::validate`].
#[derive(Debug, Clone, PartialEq)]
pub struct ThetaStarConfig {
    /// Grid resolution forwarded to `GridMap`; must be finite and strictly positive.
    pub resolution: f64,
    /// Robot radius forwarded to `GridMap` (presumably used to inflate obstacles —
    /// confirm against `GridMap`); must be finite and non-negative.
    pub robot_radius: f64,
    /// Multiplier applied to the straight-line heuristic; must be finite and strictly
    /// positive. Values above `1.0` make the heuristic inadmissible, so returned paths
    /// may be longer than optimal.
    pub heuristic_weight: f64,
}

impl Default for ThetaStarConfig {
    /// `resolution = 1.0`, `robot_radius = 0.5`, `heuristic_weight = 1.0`.
    fn default() -> Self {
        Self {
            resolution: 1.0,
            robot_radius: 0.5,
            heuristic_weight: 1.0,
        }
    }
}

39impl ThetaStarConfig {
40 pub fn validate(&self) -> RoboticsResult<()> {
41 if !self.resolution.is_finite() || self.resolution <= 0.0 {
42 return Err(RoboticsError::InvalidParameter(format!(
43 "resolution must be positive and finite, got {}",
44 self.resolution
45 )));
46 }
47 if !self.robot_radius.is_finite() || self.robot_radius < 0.0 {
48 return Err(RoboticsError::InvalidParameter(format!(
49 "robot_radius must be non-negative and finite, got {}",
50 self.robot_radius
51 )));
52 }
53 if !self.heuristic_weight.is_finite() || self.heuristic_weight <= 0.0 {
54 return Err(RoboticsError::InvalidParameter(format!(
55 "heuristic_weight must be positive and finite, got {}",
56 self.heuristic_weight
57 )));
58 }
59 Ok(())
60 }
61}
62
/// Open-list entry for the Theta* search.
///
/// `BinaryHeap` is a max-heap, so the ordering is reversed: the entry with the
/// *smallest* `priority` (f = g + h) compares greatest and is popped first.
/// Only `priority` takes part in comparisons. `f64::total_cmp` is used so that
/// `Eq` and `Ord` agree with each other even for NaN.
#[derive(Debug)]
struct PriorityNode {
    /// Grid x index of the cell.
    x: i32,
    /// Grid y index of the cell.
    y: i32,
    /// Path cost (g) from the start to this cell.
    cost: f64,
    /// Search priority: `cost` plus the weighted heuristic.
    priority: f64,
    /// Index of the corresponding node in the planner's node storage.
    index: usize,
}

impl Eq for PriorityNode {}

impl PartialEq for PriorityNode {
    fn eq(&self, other: &Self) -> bool {
        // Must agree with `cmp`; `==` on f64 would make NaN unequal to itself.
        self.priority.total_cmp(&other.priority) == Ordering::Equal
    }
}

impl Ord for PriorityNode {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reversed so the max-heap yields the lowest priority first.
        other.priority.total_cmp(&self.priority)
    }
}

impl PartialOrd for PriorityNode {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

91pub struct ThetaStarPlanner {
92 grid_map: GridMap,
93 config: ThetaStarConfig,
94 motion: Vec<(i32, i32, f64)>,
95}
96
97impl ThetaStarPlanner {
98 pub fn new(ox: &[f64], oy: &[f64], config: ThetaStarConfig) -> Self {
99 Self::try_new(ox, oy, config).expect(
100 "invalid Theta* planner input: obstacle list must be non-empty and valid, and config values must be positive/finite",
101 )
102 }
103
104 pub fn try_new(ox: &[f64], oy: &[f64], config: ThetaStarConfig) -> RoboticsResult<Self> {
105 config.validate()?;
106 let grid_map = GridMap::try_new(ox, oy, config.resolution, config.robot_radius)?;
107 let motion = Self::get_motion_model();
108 Ok(ThetaStarPlanner {
109 grid_map,
110 config,
111 motion,
112 })
113 }
114
115 pub fn from_obstacles(ox: &[f64], oy: &[f64], resolution: f64, robot_radius: f64) -> Self {
116 let config = ThetaStarConfig {
117 resolution,
118 robot_radius,
119 ..Default::default()
120 };
121 Self::new(ox, oy, config)
122 }
123
124 pub fn from_obstacle_points(
125 obstacles: &Obstacles,
126 config: ThetaStarConfig,
127 ) -> RoboticsResult<Self> {
128 config.validate()?;
129 let grid_map = GridMap::from_obstacles(obstacles, config.resolution, config.robot_radius)?;
130 let motion = Self::get_motion_model();
131 Ok(ThetaStarPlanner {
132 grid_map,
133 config,
134 motion,
135 })
136 }
137
138 #[deprecated(note = "use plan() or plan_xy() instead")]
139 pub fn planning(&self, sx: f64, sy: f64, gx: f64, gy: f64) -> Option<(Vec<f64>, Vec<f64>)> {
140 match self.plan_xy(sx, sy, gx, gy) {
141 Ok(path) => Some((path.x_coords(), path.y_coords())),
142 Err(_) => None,
143 }
144 }
145
146 pub fn plan(&self, start: Point2D, goal: Point2D) -> RoboticsResult<Path2D> {
147 self.plan_impl(start, goal)
148 }
149
150 pub fn plan_xy(&self, sx: f64, sy: f64, gx: f64, gy: f64) -> RoboticsResult<Path2D> {
151 self.plan_impl(Point2D::new(sx, sy), Point2D::new(gx, gy))
152 }
153
154 pub fn grid_map(&self) -> &GridMap {
155 &self.grid_map
156 }
157
158 fn calc_heuristic(&self, n1_x: i32, n1_y: i32, n2_x: i32, n2_y: i32) -> f64 {
159 self.config.heuristic_weight * (((n1_x - n2_x).pow(2) + (n1_y - n2_y).pow(2)) as f64).sqrt()
160 }
161
162 fn get_motion_model() -> Vec<(i32, i32, f64)> {
163 vec![
164 (1, 0, 1.0),
165 (0, 1, 1.0),
166 (-1, 0, 1.0),
167 (0, -1, 1.0),
168 (-1, -1, std::f64::consts::SQRT_2),
169 (-1, 1, std::f64::consts::SQRT_2),
170 (1, -1, std::f64::consts::SQRT_2),
171 (1, 1, std::f64::consts::SQRT_2),
172 ]
173 }
174
175 fn line_of_sight(&self, x0: i32, y0: i32, x1: i32, y1: i32) -> bool {
176 if !self.grid_map.is_valid(x0, y0) || !self.grid_map.is_valid(x1, y1) {
177 return false;
178 }
179
180 if x0 == x1 && y0 == y1 {
181 return true;
182 }
183
184 let dx = x1 - x0;
185 let dy = y1 - y0;
186 let step_x = dx.signum();
187 let step_y = dy.signum();
188 let abs_dx = dx.abs() as f64;
189 let abs_dy = dy.abs() as f64;
190
191 let mut x = x0;
192 let mut y = y0;
193 let mut t_max_x = if step_x != 0 {
194 0.5 / abs_dx
195 } else {
196 f64::INFINITY
197 };
198 let mut t_max_y = if step_y != 0 {
199 0.5 / abs_dy
200 } else {
201 f64::INFINITY
202 };
203 let t_delta_x = if step_x != 0 {
204 1.0 / abs_dx
205 } else {
206 f64::INFINITY
207 };
208 let t_delta_y = if step_y != 0 {
209 1.0 / abs_dy
210 } else {
211 f64::INFINITY
212 };
213
214 while x != x1 || y != y1 {
215 let advance_x = t_max_x <= t_max_y;
216 let advance_y = t_max_y <= t_max_x;
217 let next_x = if advance_x { x + step_x } else { x };
218 let next_y = if advance_y { y + step_y } else { y };
219
220 if !self.grid_map.is_valid_step(x, y, next_x, next_y) {
221 return false;
222 }
223
224 x = next_x;
225 y = next_y;
226
227 if advance_x {
228 t_max_x += t_delta_x;
229 }
230 if advance_y {
231 t_max_y += t_delta_y;
232 }
233 }
234
235 true
236 }
237
238 fn euclidean_distance(&self, x1: i32, y1: i32, x2: i32, y2: i32) -> f64 {
239 (((x1 - x2).pow(2) + (y1 - y2).pow(2)) as f64).sqrt()
240 }
241
242 fn build_path(&self, goal_index: usize, node_storage: &[Node]) -> Path2D {
243 let mut points = Vec::new();
244 let mut current_index = Some(goal_index);
245 while let Some(index) = current_index {
246 let node = &node_storage[index];
247 points.push(Point2D::new(
248 self.grid_map.calc_x_position(node.x),
249 self.grid_map.calc_y_position(node.y),
250 ));
251 current_index = node.parent_index;
252 }
253 points.reverse();
254 Path2D::from_points(points)
255 }
256
257 fn ensure_query_is_valid(&self, x: i32, y: i32, label: &str) -> RoboticsResult<()> {
258 if self.grid_map.is_valid(x, y) {
259 return Ok(());
260 }
261 Err(RoboticsError::PlanningError(format!(
262 "{} position is invalid",
263 label
264 )))
265 }
266
267 fn plan_impl(&self, start: Point2D, goal: Point2D) -> RoboticsResult<Path2D> {
268 let start_x = self.grid_map.calc_x_index(start.x);
269 let start_y = self.grid_map.calc_y_index(start.y);
270 let goal_x = self.grid_map.calc_x_index(goal.x);
271 let goal_y = self.grid_map.calc_y_index(goal.y);
272
273 self.ensure_query_is_valid(start_x, start_y, "Start")?;
274 self.ensure_query_is_valid(goal_x, goal_y, "Goal")?;
275
276 let mut open_set = BinaryHeap::new();
277 let mut closed_set = HashMap::new();
278 let mut node_storage: Vec<Node> = Vec::new();
279 let mut g_values: HashMap<i32, f64> = HashMap::new();
280 let mut best_index: HashMap<i32, usize> = HashMap::new();
281
282 node_storage.push(Node::new(start_x, start_y, 0.0, None));
283 let start_index = 0;
284 let start_grid_index = self.grid_map.calc_index(start_x, start_y);
285 g_values.insert(start_grid_index, 0.0);
286 best_index.insert(start_grid_index, start_index);
287
288 open_set.push(PriorityNode {
289 x: start_x,
290 y: start_y,
291 cost: 0.0,
292 priority: self.calc_heuristic(start_x, start_y, goal_x, goal_y),
293 index: start_index,
294 });
295
296 while let Some(current) = open_set.pop() {
297 let current_grid_index = self.grid_map.calc_index(current.x, current.y);
298 if current.x == goal_x && current.y == goal_y {
299 return Ok(self.build_path(current.index, &node_storage));
300 }
301 if closed_set.contains_key(¤t_grid_index) {
302 continue;
303 }
304 closed_set.insert(current_grid_index, current.index);
305
306 let current_node = &node_storage[current.index];
307 let parent_index = current_node.parent_index;
308
309 for &(dx, dy, _) in &self.motion {
310 let new_x = current.x + dx;
311 let new_y = current.y + dy;
312 let new_grid_index = self.grid_map.calc_index(new_x, new_y);
313 if !self.grid_map.is_valid_offset(current.x, current.y, dx, dy) {
314 continue;
315 }
316 if closed_set.contains_key(&new_grid_index) {
317 continue;
318 }
319
320 let (new_cost, new_parent_index) = if let Some(p_idx) = parent_index {
321 let parent_node = &node_storage[p_idx];
322 if self.line_of_sight(parent_node.x, parent_node.y, new_x, new_y) {
323 let dist =
324 self.euclidean_distance(parent_node.x, parent_node.y, new_x, new_y);
325 (parent_node.cost + dist, Some(p_idx))
326 } else {
327 let dist = self.euclidean_distance(current.x, current.y, new_x, new_y);
328 (current.cost + dist, Some(current.index))
329 }
330 } else {
331 let dist = self.euclidean_distance(current.x, current.y, new_x, new_y);
332 (current.cost + dist, Some(current.index))
333 };
334
335 let existing_g = g_values
336 .get(&new_grid_index)
337 .copied()
338 .unwrap_or(f64::INFINITY);
339 if new_cost < existing_g {
340 g_values.insert(new_grid_index, new_cost);
341 node_storage.push(Node::new(new_x, new_y, new_cost, new_parent_index));
342 let new_index = node_storage.len() - 1;
343 best_index.insert(new_grid_index, new_index);
344 let priority = new_cost + self.calc_heuristic(new_x, new_y, goal_x, goal_y);
345 open_set.push(PriorityNode {
346 x: new_x,
347 y: new_y,
348 cost: new_cost,
349 priority,
350 index: new_index,
351 });
352 }
353 }
354 }
355
356 Err(RoboticsError::PlanningError("No path found".to_string()))
357 }
358}
359
360impl PathPlanner for ThetaStarPlanner {
361 fn plan(&self, start: Point2D, goal: Point2D) -> Result<Path2D, RoboticsError> {
362 self.plan_impl(start, goal)
363 }
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use rust_robotics_core::Obstacles;

    /// Border cells of the 4x4 square from (0, 0) to (3, 3). Cells (1..=2, 1..=2)
    /// inside it stay open.
    const RING: [(f64, f64); 12] = [
        (0.0, 0.0),
        (1.0, 0.0),
        (2.0, 0.0),
        (3.0, 0.0),
        (0.0, 1.0),
        (3.0, 1.0),
        (0.0, 2.0),
        (3.0, 2.0),
        (0.0, 3.0),
        (1.0, 3.0),
        (2.0, 3.0),
        (3.0, 3.0),
    ];

    /// A walled 20x20 arena plus a partial wall at x = 10 covering y = 5..=14.
    fn create_simple_obstacles() -> (Vec<f64>, Vec<f64>) {
        let border = (0..21_i32).flat_map(|i| {
            let v = f64::from(i);
            [(v, 0.0), (v, 20.0), (0.0, v), (20.0, v)]
        });
        let partition = (5..15_i32).map(|i| (10.0, f64::from(i)));
        border.chain(partition).unzip()
    }

    /// `RING` followed by `extra`, in that order.
    fn ring_obstacles(extra: &[(f64, f64)]) -> Obstacles {
        let points = RING
            .iter()
            .chain(extra)
            .map(|&(x, y)| Point2D::new(x, y))
            .collect::<Vec<_>>();
        Obstacles::from_points(points)
    }

    /// Planner over the arena from `create_simple_obstacles` (resolution 1, radius 0.5).
    fn arena_planner() -> ThetaStarPlanner {
        let (ox, oy) = create_simple_obstacles();
        ThetaStarPlanner::from_obstacles(&ox, &oy, 1.0, 0.5)
    }

    #[test]
    fn test_theta_star_finds_path() {
        let path = arena_planner()
            .plan(Point2D::new(2.0, 10.0), Point2D::new(18.0, 10.0))
            .expect("arena has a route around the partial wall");
        assert!(!path.is_empty());
    }

    #[test]
    // Exercises the deprecated wrapper on purpose.
    #[allow(deprecated)]
    fn test_theta_star_legacy_interface() {
        let (rx, ry) = arena_planner()
            .planning(2.0, 10.0, 18.0, 10.0)
            .expect("legacy interface should find a path");
        assert!(!rx.is_empty());
        assert_eq!(rx.len(), ry.len());
    }

    #[test]
    fn test_theta_star_shorter_than_a_star() {
        let (ox, oy) = create_simple_obstacles();
        let (start, goal) = (Point2D::new(2.0, 2.0), Point2D::new(18.0, 18.0));

        let theta_length = ThetaStarPlanner::from_obstacles(&ox, &oy, 1.0, 0.5)
            .plan(start, goal)
            .unwrap()
            .total_length();
        let a_star_length = crate::a_star::AStarPlanner::from_obstacles(&ox, &oy, 1.0, 0.5)
            .plan(start, goal)
            .unwrap()
            .total_length();

        assert!(
            theta_length <= a_star_length + 0.1,
            "Theta* path ({}) should not be significantly longer than A* path ({})",
            theta_length,
            a_star_length
        );
    }

    #[test]
    fn test_line_of_sight() {
        let planner = arena_planner();
        assert!(planner.line_of_sight(2, 2, 5, 5));
        assert!(!planner.line_of_sight(5, 10, 15, 10));
    }

    #[test]
    fn test_line_of_sight_blocks_corner_cutting() {
        let open_planner =
            ThetaStarPlanner::from_obstacle_points(&ring_obstacles(&[]), ThetaStarConfig::default())
                .unwrap();
        assert!(open_planner.line_of_sight(1, 1, 2, 1));

        // Blocking both off-diagonal interior cells must stop the diagonal hop.
        let blocked = ring_obstacles(&[(1.0, 2.0), (2.0, 1.0)]);
        let planner =
            ThetaStarPlanner::from_obstacle_points(&blocked, ThetaStarConfig::default()).unwrap();
        assert!(!planner.line_of_sight(1, 1, 2, 2));
    }

    #[test]
    fn test_theta_star_from_obstacle_points() {
        let (ox, oy) = create_simple_obstacles();
        let obstacles = Obstacles::try_from_xy(&ox, &oy).unwrap();
        let path = ThetaStarPlanner::from_obstacle_points(&obstacles, ThetaStarConfig::default())
            .and_then(|planner| planner.plan_xy(2.0, 10.0, 18.0, 10.0))
            .unwrap();
        assert!(!path.is_empty());
    }

    #[test]
    fn test_theta_star_try_new_rejects_invalid_config() {
        let (ox, oy) = create_simple_obstacles();
        let config = ThetaStarConfig {
            heuristic_weight: 0.0,
            ..Default::default()
        };
        let outcome = ThetaStarPlanner::try_new(&ox, &oy, config);
        assert!(
            matches!(outcome, Err(RoboticsError::InvalidParameter(_))),
            "expected invalid config to be rejected"
        );
    }
}