1#![allow(dead_code, clippy::too_many_arguments)]
2
3use std::f64::consts::PI;
23
24use rand::Rng;
25
26use rust_robotics_core::types::Pose2D;
27
28use crate::dubins_path::DubinsPlanner;
29use crate::rrt::{AreaBounds, CircleObstacle};
30use crate::rrt_dubins::{sample_dubins_with_yaw, RRTDubinsNode};
31
/// Tuning parameters for [`RRTStarDubinsPlanner`].
#[derive(Debug, Clone)]
pub struct RRTStarDubinsConfig {
    /// Curvature forwarded to `DubinsPlanner::new`.
    /// NOTE(review): presumably the inverse of the minimum turning radius — confirm
    /// against `dubins_path`.
    pub curvature: f64,
    /// Goal bias: every random sample draws an integer from `0..=100` and yields
    /// the goal pose when the draw is `<= goal_sample_rate`.
    pub goal_sample_rate: i32,
    /// Number of sampling iterations before the search gives up.
    pub max_iter: usize,
    /// A node counts as reaching the goal only if its (x, y) Euclidean distance to
    /// the goal is at most this value.
    pub goal_xy_threshold: f64,
    /// Maximum absolute heading error (radians) for a node to count as reaching
    /// the goal.
    pub goal_yaw_threshold: f64,
    /// Added to every obstacle radius during collision checking.
    pub robot_radius: f64,
    /// Resolution argument passed to `sample_dubins_with_yaw` when discretising
    /// edges. NOTE(review): presumably the spacing between samples — confirm.
    pub path_resolution: f64,
    /// Scale factor of the shrinking RRT* neighbour radius
    /// `connect_circle_dist * sqrt(ln(n) / n)`, where `n` is the tree size + 1.
    pub connect_circle_dist: f64,
}

impl Default for RRTStarDubinsConfig {
    /// Defaults:
    /// - unit curvature, goal bias 10, 200 iterations;
    /// - 0.5 position tolerance and 1° heading tolerance;
    /// - point robot, 0.1 sampling resolution, neighbour-radius scale 50.
    fn default() -> Self {
        let one_degree = 1.0_f64.to_radians();
        Self {
            max_iter: 200,
            goal_sample_rate: 10,
            curvature: 1.0,
            path_resolution: 0.1,
            robot_radius: 0.0,
            goal_xy_threshold: 0.5,
            goal_yaw_threshold: one_degree,
            connect_circle_dist: 50.0,
        }
    }
}
67
/// Result of a successful RRT* Dubins search.
#[derive(Debug, Clone)]
pub struct RRTStarDubinsPath {
    /// Poses ordered from the start configuration to the goal configuration.
    /// The first entry is the exact start pose and the last the exact goal pose.
    /// In between come the sampled poses of each tree edge on the way there.
    pub poses: Vec<Pose2D>,
}
74
/// RRT* planner whose tree edges are Dubins curves.
///
/// Construct with [`RRTStarDubinsPlanner::new`], then call
/// [`RRTStarDubinsPlanner::planning`] (random sampling) or
/// [`RRTStarDubinsPlanner::plan_with_sampler`] (caller-supplied sampling).
pub struct RRTStarDubinsPlanner {
    /// Tuning parameters (thresholds, iteration budget, goal bias, ...).
    config: RRTStarDubinsConfig,
    /// Circular obstacles; every sampled edge point is checked against these,
    /// each inflated by `config.robot_radius`.
    obstacles: Vec<CircleObstacle>,
    /// Inclusive (x, y) bounds used when drawing random samples.
    rand_area: AreaBounds,
    /// Dubins solver built from `config.curvature`.
    dubins: DubinsPlanner,
    /// The search tree. Index 0 holds the start node once a query begins, and
    /// each node's `parent` is an index into this vector.
    node_list: Vec<RRTDubinsNode>,
    /// Start configuration of the current query (an origin placeholder until a
    /// planning call sets it).
    start: RRTDubinsNode,
    /// Goal configuration of the current query (an origin placeholder until a
    /// planning call sets it).
    goal: RRTDubinsNode,
}
85
86impl RRTStarDubinsPlanner {
87 pub fn new(
89 obstacles: Vec<CircleObstacle>,
90 rand_area: AreaBounds,
91 config: RRTStarDubinsConfig,
92 ) -> Self {
93 let dubins = DubinsPlanner::new(config.curvature);
94 Self {
95 config,
96 obstacles,
97 rand_area,
98 dubins,
99 node_list: Vec::new(),
100 start: RRTDubinsNode::new(0.0, 0.0, 0.0),
101 goal: RRTDubinsNode::new(0.0, 0.0, 0.0),
102 }
103 }
104
105 pub fn planning(&mut self, start: Pose2D, goal: Pose2D) -> Option<RRTStarDubinsPath> {
109 self.start = RRTDubinsNode::new(start.x, start.y, start.yaw);
110 self.goal = RRTDubinsNode::new(goal.x, goal.y, goal.yaw);
111 self.node_list = vec![self.start.clone()];
112
113 for _ in 0..self.config.max_iter {
114 let rnd = self.get_random_node();
115 let nearest_ind = self.get_nearest_node_index(&rnd);
116
117 if let Some(new_node) =
118 self.steer(&self.node_list[nearest_ind].clone(), &rnd, nearest_ind)
119 {
120 if self.check_collision(&new_node) {
121 let near_inds = self.find_near_nodes(&new_node);
122 if let Some(best_node) = self.choose_parent(new_node, &near_inds) {
123 let new_index = self.node_list.len();
124 self.node_list.push(best_node);
125 self.rewire(new_index, &near_inds);
126 }
127 }
128 }
129
130 if let Some(goal_idx) = self.search_best_goal_node() {
131 return Some(self.generate_final_course(goal_idx));
132 }
133 }
134
135 self.search_best_goal_node()
136 .map(|idx| self.generate_final_course(idx))
137 }
138
139 pub fn plan_with_sampler<F>(
141 &mut self,
142 start: Pose2D,
143 goal: Pose2D,
144 mut sample_node: F,
145 ) -> Option<RRTStarDubinsPath>
146 where
147 F: FnMut(&Self) -> RRTDubinsNode,
148 {
149 self.start = RRTDubinsNode::new(start.x, start.y, start.yaw);
150 self.goal = RRTDubinsNode::new(goal.x, goal.y, goal.yaw);
151 self.node_list = vec![self.start.clone()];
152
153 for _ in 0..self.config.max_iter {
154 let rnd = sample_node(self);
155 let nearest_ind = self.get_nearest_node_index(&rnd);
156
157 if let Some(new_node) =
158 self.steer(&self.node_list[nearest_ind].clone(), &rnd, nearest_ind)
159 {
160 if self.check_collision(&new_node) {
161 let near_inds = self.find_near_nodes(&new_node);
162 if let Some(best_node) = self.choose_parent(new_node, &near_inds) {
163 let new_index = self.node_list.len();
164 self.node_list.push(best_node);
165 self.rewire(new_index, &near_inds);
166 }
167 }
168 }
169
170 if let Some(goal_idx) = self.search_best_goal_node() {
171 return Some(self.generate_final_course(goal_idx));
172 }
173 }
174
175 self.search_best_goal_node()
176 .map(|idx| self.generate_final_course(idx))
177 }
178
179 pub fn get_tree(&self) -> &[RRTDubinsNode] {
181 &self.node_list
182 }
183
184 fn get_random_node(&self) -> RRTDubinsNode {
189 let mut rng = rand::rng();
190 if rng.random_range(0..=100) > self.config.goal_sample_rate {
191 RRTDubinsNode::new(
192 rng.random_range(self.rand_area.xmin..=self.rand_area.xmax),
193 rng.random_range(self.rand_area.ymin..=self.rand_area.ymax),
194 rng.random_range(-PI..=PI),
195 )
196 } else {
197 RRTDubinsNode::new(self.goal.x, self.goal.y, self.goal.yaw)
198 }
199 }
200
201 fn get_nearest_node_index(&self, rnd_node: &RRTDubinsNode) -> usize {
202 let mut min_dist = f64::INFINITY;
203 let mut min_ind = 0;
204 for (i, node) in self.node_list.iter().enumerate() {
205 let dist = (node.x - rnd_node.x).powi(2) + (node.y - rnd_node.y).powi(2);
206 if dist < min_dist {
207 min_dist = dist;
208 min_ind = i;
209 }
210 }
211 min_ind
212 }
213
214 fn steer(
216 &self,
217 from_node: &RRTDubinsNode,
218 to_node: &RRTDubinsNode,
219 from_index: usize,
220 ) -> Option<RRTDubinsNode> {
221 let from_pose = Pose2D::new(from_node.x, from_node.y, from_node.yaw);
222 let to_pose = Pose2D::new(to_node.x, to_node.y, to_node.yaw);
223
224 let dubins_path = self.dubins.plan(from_pose, to_pose).ok()?;
225
226 let (px, py, pyaw) = sample_dubins_with_yaw(&dubins_path, self.config.path_resolution);
227
228 if px.len() <= 1 {
229 return None;
230 }
231
232 let mut new_node = from_node.clone();
233 new_node.x = *px.last().unwrap();
234 new_node.y = *py.last().unwrap();
235 new_node.yaw = *pyaw.last().unwrap();
236 new_node.path_x = px;
237 new_node.path_y = py;
238 new_node.path_yaw = pyaw;
239 new_node.cost += dubins_path.total_length;
240 new_node.parent = Some(from_index);
241
242 Some(new_node)
243 }
244
245 fn calc_new_cost(&self, from_node: &RRTDubinsNode, to_node: &RRTDubinsNode) -> f64 {
247 let from_pose = Pose2D::new(from_node.x, from_node.y, from_node.yaw);
248 let to_pose = Pose2D::new(to_node.x, to_node.y, to_node.yaw);
249
250 match self.dubins.plan(from_pose, to_pose) {
251 Ok(path) => from_node.cost + path.total_length,
252 Err(_) => f64::INFINITY,
253 }
254 }
255
256 fn check_collision(&self, node: &RRTDubinsNode) -> bool {
258 for obs in &self.obstacles {
259 for (&px, &py) in node.path_x.iter().zip(node.path_y.iter()) {
260 let dx = obs.x - px;
261 let dy = obs.y - py;
262 let d = (dx * dx + dy * dy).sqrt();
263 if d <= obs.radius + self.config.robot_radius {
264 return false;
265 }
266 }
267 }
268 true
269 }
270
271 fn find_near_nodes(&self, new_node: &RRTDubinsNode) -> Vec<usize> {
273 let nnode = self.node_list.len() + 1;
274 let r = self.config.connect_circle_dist * ((nnode as f64).ln() / nnode as f64).sqrt();
275
276 self.node_list
277 .iter()
278 .enumerate()
279 .filter_map(|(i, node)| {
280 let dist_sq = (node.x - new_node.x).powi(2) + (node.y - new_node.y).powi(2);
281 if dist_sq <= r * r {
282 Some(i)
283 } else {
284 None
285 }
286 })
287 .collect()
288 }
289
290 fn choose_parent(&self, new_node: RRTDubinsNode, near_inds: &[usize]) -> Option<RRTDubinsNode> {
292 if near_inds.is_empty() {
293 return Some(new_node);
294 }
295
296 let mut best_cost = f64::INFINITY;
297 let mut best_index: Option<usize> = None;
298
299 for &i in near_inds {
300 let near_node = &self.node_list[i];
301 if let Some(t_node) = self.steer(near_node, &new_node, i) {
302 if self.check_collision(&t_node) {
303 let cost = self.calc_new_cost(near_node, &new_node);
304 if cost < best_cost {
305 best_cost = cost;
306 best_index = Some(i);
307 }
308 }
309 }
310 }
311
312 let parent_ind = best_index?;
313 let mut result = self.steer(&self.node_list[parent_ind], &new_node, parent_ind)?;
314 result.cost = best_cost;
315 Some(result)
316 }
317
318 fn rewire(&mut self, new_node_ind: usize, near_inds: &[usize]) {
320 for &i in near_inds {
321 let near_node = self.node_list[i].clone();
322 let new_node = &self.node_list[new_node_ind];
323
324 let edge_cost = self.calc_new_cost(new_node, &near_node);
325 if edge_cost >= near_node.cost {
326 continue;
327 }
328
329 if let Some(mut edge_node) = self.steer(
330 &self.node_list[new_node_ind].clone(),
331 &near_node,
332 new_node_ind,
333 ) {
334 if self.check_collision(&edge_node) {
335 edge_node.cost = edge_cost;
336 self.node_list[i] = edge_node;
337 self.propagate_cost_to_leaves(i);
338 }
339 }
340 }
341 }
342
343 fn propagate_cost_to_leaves(&mut self, parent_ind: usize) {
345 for i in 0..self.node_list.len() {
346 if let Some(p) = self.node_list[i].parent {
347 if p == parent_ind {
348 self.node_list[i].cost = self.calc_new_cost(
349 &self.node_list[parent_ind].clone(),
350 &self.node_list[i].clone(),
351 );
352 self.propagate_cost_to_leaves(i);
353 }
354 }
355 }
356 }
357
358 fn search_best_goal_node(&self) -> Option<usize> {
360 let mut candidates: Vec<usize> = Vec::new();
361
362 for (i, node) in self.node_list.iter().enumerate() {
363 let dx = node.x - self.goal.x;
364 let dy = node.y - self.goal.y;
365 let dist = (dx * dx + dy * dy).sqrt();
366 if dist > self.config.goal_xy_threshold {
367 continue;
368 }
369 let dyaw = angle_diff(node.yaw, self.goal.yaw).abs();
370 if dyaw > self.config.goal_yaw_threshold {
371 continue;
372 }
373 candidates.push(i);
374 }
375
376 if candidates.is_empty() {
377 return None;
378 }
379
380 candidates.into_iter().min_by(|&a, &b| {
381 self.node_list[a]
382 .cost
383 .partial_cmp(&self.node_list[b].cost)
384 .unwrap_or(std::cmp::Ordering::Equal)
385 })
386 }
387
388 fn generate_final_course(&self, goal_index: usize) -> RRTStarDubinsPath {
390 let mut poses: Vec<Pose2D> = vec![Pose2D::new(self.goal.x, self.goal.y, self.goal.yaw)];
391
392 let mut node = &self.node_list[goal_index];
393 while node.parent.is_some() {
394 for ((&px, &py), &pyaw) in node
395 .path_x
396 .iter()
397 .rev()
398 .zip(node.path_y.iter().rev())
399 .zip(node.path_yaw.iter().rev())
400 {
401 poses.push(Pose2D::new(px, py, pyaw));
402 }
403 node = &self.node_list[node.parent.unwrap()];
404 }
405
406 poses.push(Pose2D::new(self.start.x, self.start.y, self.start.yaw));
407 poses.reverse();
408
409 RRTStarDubinsPath { poses }
410 }
411}
412
/// Signed angular difference `a - b`, wrapped into `[-π, π]`.
///
/// Output ranges:
/// - Differences already in `[-π, π]` are returned unchanged, so both `π` and
///   `-π` are possible outputs.
/// - Differences above `π` land in `(-π, π]`.
/// - Differences below `-π` land in `[-π, π)`.
///
/// These ranges are the same as repeatedly subtracting or adding 2π.
///
/// Returns `NaN` when `a - b` is not finite. The wrap uses `rem_euclid` rather
/// than a ±2π loop because such a loop never terminates for ±∞. It also stalls
/// for huge finite values, where 2π is below half an ULP and subtracting it
/// changes nothing.
fn angle_diff(a: f64, b: f64) -> f64 {
    const TWO_PI: f64 = 2.0 * PI;
    let d = a - b;
    if !d.is_finite() {
        return f64::NAN;
    }
    if d > PI {
        // rem_euclid yields [0, 2π); fold the upper half down to get (-π, π].
        let w = d.rem_euclid(TWO_PI);
        if w > PI {
            w - TWO_PI
        } else {
            w
        }
    } else if d < -PI {
        // Fold [π, 2π) down to get [-π, π).
        let w = d.rem_euclid(TWO_PI);
        if w >= PI {
            w - TWO_PI
        } else {
            w
        }
    } else {
        d
    }
}
428
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute-tolerance float comparison used throughout these tests.
    fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() < tol
    }

    /// Circular obstacle field used by the obstacle-avoidance smoke test.
    fn create_obstacle_list() -> Vec<CircleObstacle> {
        vec![
            CircleObstacle::new(5.0, 5.0, 1.0),
            CircleObstacle::new(3.0, 6.0, 2.0),
            CircleObstacle::new(3.0, 8.0, 2.0),
            CircleObstacle::new(3.0, 10.0, 2.0),
            CircleObstacle::new(7.0, 5.0, 2.0),
            CircleObstacle::new(9.0, 5.0, 2.0),
        ]
    }

    #[test]
    fn test_config_default() {
        let config = RRTStarDubinsConfig::default();
        assert_eq!(config.curvature, 1.0);
        assert_eq!(config.max_iter, 200);
        assert_eq!(config.goal_sample_rate, 10);
        assert_eq!(config.connect_circle_dist, 50.0);
        assert!(approx_eq(config.goal_xy_threshold, 0.5, 1e-12));
    }

    #[test]
    fn test_angle_diff() {
        assert!(approx_eq(angle_diff(0.0, 0.0), 0.0, 1e-12));
        assert!(approx_eq(angle_diff(PI, 0.0), PI, 1e-12));
        assert!(approx_eq(angle_diff(0.0, PI), -PI, 1e-12));
        assert!(approx_eq(angle_diff(3.0 * PI / 2.0, -PI / 2.0), 0.0, 1e-12));
    }

    #[test]
    fn test_collision_check_no_obstacles() {
        let config = RRTStarDubinsConfig::default();
        let rand_area = AreaBounds::new(-2.0, 15.0, -2.0, 15.0);
        let planner = RRTStarDubinsPlanner::new(vec![], rand_area, config);

        let mut node = RRTDubinsNode::new(1.0, 1.0, 0.0);
        node.path_x = vec![0.0, 0.5, 1.0];
        node.path_y = vec![0.0, 0.5, 1.0];
        // `check_collision` returns true for a collision-free path.
        assert!(planner.check_collision(&node));
    }

    #[test]
    fn test_collision_check_with_obstacle() {
        let obstacles = vec![CircleObstacle::new(0.5, 0.5, 0.3)];
        let config = RRTStarDubinsConfig::default();
        let rand_area = AreaBounds::new(-2.0, 15.0, -2.0, 15.0);
        let planner = RRTStarDubinsPlanner::new(obstacles, rand_area, config);

        let mut node = RRTDubinsNode::new(1.0, 1.0, 0.0);
        node.path_x = vec![0.0, 0.5, 1.0];
        node.path_y = vec![0.0, 0.5, 1.0];
        // The middle sample sits on the obstacle centre.
        assert!(!planner.check_collision(&node));
    }

    #[test]
    fn test_find_near_nodes() {
        let config = RRTStarDubinsConfig {
            connect_circle_dist: 50.0,
            ..Default::default()
        };
        let rand_area = AreaBounds::new(-2.0, 15.0, -2.0, 15.0);
        let mut planner = RRTStarDubinsPlanner::new(vec![], rand_area, config);
        planner.node_list = vec![
            RRTDubinsNode::new(0.0, 0.0, 0.0),
            RRTDubinsNode::new(1.0, 1.0, 0.0),
            RRTDubinsNode::new(100.0, 100.0, 0.0),
        ];

        let query = RRTDubinsNode::new(0.5, 0.5, 0.0);
        let near = planner.find_near_nodes(&query);
        assert!(near.contains(&0));
        assert!(near.contains(&1));
        assert!(!near.contains(&2));
    }

    #[test]
    fn test_search_best_goal_node() {
        let config = RRTStarDubinsConfig {
            goal_xy_threshold: 1.0,
            goal_yaw_threshold: 0.5,
            ..Default::default()
        };
        let rand_area = AreaBounds::new(-2.0, 15.0, -2.0, 15.0);
        let mut planner = RRTStarDubinsPlanner::new(vec![], rand_area, config);
        planner.goal = RRTDubinsNode::new(10.0, 10.0, 0.0);

        // Only a node far from the goal: no candidate.
        planner.node_list = vec![RRTDubinsNode::new(0.0, 0.0, 0.0)];
        assert!(planner.search_best_goal_node().is_none());

        // Right position, wrong heading: still no candidate.
        let mut n = RRTDubinsNode::new(10.0, 10.0, PI);
        n.cost = 5.0;
        planner.node_list.push(n);
        assert!(planner.search_best_goal_node().is_none());

        // Within both thresholds.
        let mut n = RRTDubinsNode::new(10.2, 10.2, 0.1);
        n.cost = 8.0;
        planner.node_list.push(n);
        assert_eq!(planner.search_best_goal_node(), Some(2));
    }

    #[test]
    fn test_deterministic_planning_no_obstacles() {
        let config = RRTStarDubinsConfig {
            curvature: 1.0,
            goal_sample_rate: 10,
            max_iter: 500,
            goal_xy_threshold: 1.5,
            goal_yaw_threshold: 0.5,
            robot_radius: 0.0,
            path_resolution: 0.3,
            connect_circle_dist: 50.0,
        };
        let rand_area = AreaBounds::new(-2.0, 15.0, -2.0, 15.0);
        let mut planner = RRTStarDubinsPlanner::new(vec![], rand_area, config);

        let start = Pose2D::new(0.0, 0.0, 0.0);
        let goal = Pose2D::new(5.0, 5.0, 0.0);

        // Always sampling the goal makes the run deterministic.
        let result = planner.plan_with_sampler(start, goal, |p| {
            RRTDubinsNode::new(p.goal.x, p.goal.y, p.goal.yaw)
        });

        assert!(result.is_some(), "Should find a path with no obstacles");
        let path = result.unwrap();
        assert!(path.poses.len() >= 2);

        let first = &path.poses[0];
        let last = &path.poses[path.poses.len() - 1];
        assert!(approx_eq(first.x, 0.0, 0.5));
        assert!(approx_eq(first.y, 0.0, 0.5));
        assert!(approx_eq(last.x, 5.0, 1.5));
        assert!(approx_eq(last.y, 5.0, 1.5));
    }

    #[test]
    fn test_rewiring_improves_cost() {
        let config = RRTStarDubinsConfig {
            curvature: 1.0,
            path_resolution: 0.3,
            connect_circle_dist: 100.0,
            ..Default::default()
        };
        let rand_area = AreaBounds::new(-5.0, 20.0, -5.0, 20.0);
        let mut planner = RRTStarDubinsPlanner::new(vec![], rand_area, config);
        planner.start = RRTDubinsNode::new(0.0, 0.0, 0.0);
        planner.goal = RRTDubinsNode::new(20.0, 0.0, 0.0);

        // Node 1 lies straight ahead of the start but carries an inflated cost,
        // as if it had been reached through a long detour.
        let mut expensive = RRTDubinsNode::new(5.0, 0.0, 0.0);
        expensive.cost = 100.0;
        expensive.parent = Some(0);
        planner.node_list = vec![planner.start.clone(), expensive];

        // Node 2: a cheap node between the start and node 1, on the same heading.
        let cheap = planner
            .steer(&planner.node_list[0], &RRTDubinsNode::new(3.0, 0.0, 0.0), 0)
            .expect("steering along a free straight line should succeed");
        let cheap_cost = cheap.cost;
        assert!(cheap_cost < 100.0);
        planner.node_list.push(cheap);

        planner.rewire(2, &[1]);

        let rewired = &planner.node_list[1];
        assert_eq!(rewired.parent, Some(2), "node 1 should now hang off the cheap node");
        assert!(rewired.cost < 100.0, "rewiring must lower node 1's cost");
        assert!(rewired.cost >= cheap_cost, "a child cannot be cheaper than its parent");
    }

    #[test]
    fn test_planning_with_obstacles_runs_without_panic() {
        let obstacles = create_obstacle_list();
        let config = RRTStarDubinsConfig {
            curvature: 1.0,
            goal_sample_rate: 10,
            max_iter: 300,
            goal_xy_threshold: 1.5,
            goal_yaw_threshold: 1.0,
            robot_radius: 0.0,
            path_resolution: 0.3,
            connect_circle_dist: 50.0,
        };
        let rand_area = AreaBounds::new(-2.0, 15.0, -2.0, 15.0);
        let mut planner = RRTStarDubinsPlanner::new(obstacles, rand_area, config);

        let start = Pose2D::new(0.0, 0.0, 0.0);
        let goal = Pose2D::new(10.0, 10.0, 0.0);

        // Random sampling: a path is not guaranteed, but any path must be sane.
        let result = planner.planning(start, goal);
        if let Some(path) = result {
            assert!(path.poses.len() >= 2);
            for pose in &path.poses {
                assert!(pose.x.is_finite());
                assert!(pose.y.is_finite());
                assert!(pose.yaw.is_finite());
            }
        }
    }

    #[test]
    fn test_generate_final_course() {
        let config = RRTStarDubinsConfig::default();
        let rand_area = AreaBounds::new(-2.0, 15.0, -2.0, 15.0);
        let mut planner = RRTStarDubinsPlanner::new(vec![], rand_area, config);
        planner.start = RRTDubinsNode::new(0.0, 0.0, 0.0);
        planner.goal = RRTDubinsNode::new(10.0, 10.0, 0.0);

        let root = RRTDubinsNode::new(0.0, 0.0, 0.0);
        let mut child = RRTDubinsNode::new(5.0, 5.0, 0.5);
        child.parent = Some(0);
        child.path_x = vec![0.0, 2.5, 5.0];
        child.path_y = vec![0.0, 2.5, 5.0];
        child.path_yaw = vec![0.0, 0.25, 0.5];

        planner.node_list = vec![root, child];

        let course = planner.generate_final_course(1);
        assert!(course.poses.len() >= 3);
        assert!(approx_eq(course.poses[0].x, 0.0, 1e-12));
        let last = course.poses.last().unwrap();
        assert!(approx_eq(last.x, 10.0, 1e-12));
        assert!(approx_eq(last.y, 10.0, 1e-12));
    }

    #[test]
    fn test_choose_parent_picks_lower_cost() {
        let config = RRTStarDubinsConfig {
            curvature: 1.0,
            path_resolution: 0.3,
            connect_circle_dist: 100.0,
            ..Default::default()
        };
        let rand_area = AreaBounds::new(-5.0, 20.0, -5.0, 20.0);
        let mut planner = RRTStarDubinsPlanner::new(vec![], rand_area, config);
        planner.start = RRTDubinsNode::new(0.0, 0.0, 0.0);
        planner.goal = RRTDubinsNode::new(20.0, 0.0, 0.0);

        // Node 1 is both expensive and past the new node, so reaching (3, 0)
        // through it would need a loop.
        let mut expensive = RRTDubinsNode::new(5.0, 0.0, 0.0);
        expensive.cost = 100.0;
        expensive.parent = Some(0);

        planner.node_list = vec![planner.start.clone(), expensive];

        let new_node = RRTDubinsNode::new(3.0, 0.0, 0.0);
        let near_inds = vec![0, 1];
        let result = planner.choose_parent(new_node, &near_inds);
        assert!(result.is_some());
        let chosen = result.unwrap();
        assert_eq!(chosen.parent, Some(0));
    }
}