1#![allow(dead_code, clippy::too_many_arguments)]
2
3use rand::Rng;
8
9use rust_robotics_core::{Path2D, Point2D, RoboticsError, RoboticsResult};
10
/// A single vertex of the RRT* search tree.
///
/// Besides its position, a vertex carries the sampled edge that leads to it from its
/// parent (`path_x`/`path_y`, used for collision checking), the accumulated cost of
/// reaching it, and the index of its parent in the planner's node list.
#[derive(Debug, Clone)]
pub struct Node {
    /// X coordinate of the vertex.
    pub x: f64,
    /// Y coordinate of the vertex.
    pub y: f64,
    /// X coordinates of the discretised edge from the parent to this vertex.
    pub path_x: Vec<f64>,
    /// Y coordinates of the discretised edge from the parent to this vertex.
    pub path_y: Vec<f64>,
    /// Accumulated path length from the tree root to this vertex.
    pub cost: f64,
    /// Index of the parent vertex in the node list; `None` for the root and for
    /// free-standing sample nodes.
    pub parent: Option<usize>,
}
21
22impl Node {
23 pub fn new(x: f64, y: f64) -> Self {
24 Node {
25 x,
26 y,
27 path_x: Vec::new(),
28 path_y: Vec::new(),
29 cost: 0.0,
30 parent: None,
31 }
32 }
33}
34
35pub struct RRTStar {
36 pub start: Node,
37 pub end: Node,
38 pub min_rand: f64,
39 pub max_rand: f64,
40 pub expand_dis: f64,
41 pub path_resolution: f64,
42 pub goal_sample_rate: i32,
43 pub max_iter: i32,
44 pub connect_circle_dist: f64,
45 pub search_until_max_iter: bool,
46 pub robot_radius: f64,
47 pub obstacle_list: Vec<(f64, f64, f64)>, pub node_list: Vec<Node>,
49}
50
51impl RRTStar {
52 pub fn new(
53 start: (f64, f64),
54 goal: (f64, f64),
55 obstacle_list: Vec<(f64, f64, f64)>,
56 rand_area: (f64, f64),
57 expand_dis: f64,
58 path_resolution: f64,
59 goal_sample_rate: i32,
60 max_iter: i32,
61 connect_circle_dist: f64,
62 search_until_max_iter: bool,
63 robot_radius: f64,
64 ) -> Self {
65 RRTStar {
66 start: Node::new(start.0, start.1),
67 end: Node::new(goal.0, goal.1),
68 min_rand: rand_area.0,
69 max_rand: rand_area.1,
70 expand_dis,
71 path_resolution,
72 goal_sample_rate,
73 max_iter,
74 connect_circle_dist,
75 search_until_max_iter,
76 robot_radius,
77 obstacle_list,
78 node_list: Vec::new(),
79 }
80 }
81
82 pub fn planning(&mut self) -> Option<Vec<[f64; 2]>> {
83 self.planning_with_sampler(|planner| planner.get_random_node())
84 }
85
86 fn reset_search(&mut self) {
87 self.node_list = vec![self.start.clone()];
88 }
89
90 fn planning_with_sampler<F>(&mut self, mut sample_node: F) -> Option<Vec<[f64; 2]>>
91 where
92 F: FnMut(&RRTStar) -> Node,
93 {
94 self.reset_search();
95
96 for _i in 0..self.max_iter {
97 let rnd_node = sample_node(self);
98 let nearest_ind = self.get_nearest_node_index(&rnd_node);
99 let mut new_node = self.steer(nearest_ind, &rnd_node);
100
101 if let Some(ref mut node) = new_node {
102 let near_node = &self.node_list[nearest_ind];
103 node.cost = near_node.cost + self.calc_distance(near_node, node);
104
105 if self.check_collision_free(node) {
106 let near_inds = self.find_near_nodes(node);
107 let node_with_updated_parent = self.choose_parent(node.clone(), &near_inds);
108
109 if let Some(updated_node) = node_with_updated_parent {
110 let new_index = self.node_list.len();
111 self.node_list.push(updated_node);
112 self.rewire(new_index, &near_inds);
113 } else {
114 self.node_list.push(node.clone());
115 }
116 }
117 }
118
119 if !self.search_until_max_iter && new_node.is_some() {
120 if let Some(last_index) = self.search_best_goal_node() {
121 return Some(self.generate_final_course(last_index));
122 }
123 }
124 }
125
126 if let Some(last_index) = self.search_best_goal_node() {
127 return Some(self.generate_final_course(last_index));
128 }
129
130 None
131 }
132
133 fn get_random_node(&self) -> Node {
134 let mut rng = rand::rng();
135
136 if rng.random_range(0..=100) > self.goal_sample_rate {
137 Node::new(
138 rng.random_range(self.min_rand..=self.max_rand),
139 rng.random_range(self.min_rand..=self.max_rand),
140 )
141 } else {
142 Node::new(self.end.x, self.end.y)
143 }
144 }
145
146 fn get_nearest_node_index(&self, rnd_node: &Node) -> usize {
147 let mut min_dist = f64::INFINITY;
148 let mut nearest_ind = 0;
149
150 for (i, node) in self.node_list.iter().enumerate() {
151 let dist = self.calc_distance(node, rnd_node);
152 if dist < min_dist {
153 min_dist = dist;
154 nearest_ind = i;
155 }
156 }
157
158 nearest_ind
159 }
160
161 fn steer(&self, from_ind: usize, to_node: &Node) -> Option<Node> {
162 let from_node = &self.node_list[from_ind];
163 Some(self.steer_from_node(from_node, to_node, self.expand_dis, Some(from_ind)))
164 }
165
166 fn steer_from_node(
167 &self,
168 from_node: &Node,
169 to_node: &Node,
170 extend_length: f64,
171 parent: Option<usize>,
172 ) -> Node {
173 let mut new_node = Node::new(from_node.x, from_node.y);
174 let (dist, theta) = self.calc_distance_and_angle(&new_node, to_node);
175 let extend_length = extend_length.min(dist);
176
177 new_node.path_x = vec![new_node.x];
178 new_node.path_y = vec![new_node.y];
179
180 let n_expand = (extend_length / self.path_resolution).floor() as i32;
181 for _ in 0..n_expand {
182 new_node.x += self.path_resolution * theta.cos();
183 new_node.y += self.path_resolution * theta.sin();
184 new_node.path_x.push(new_node.x);
185 new_node.path_y.push(new_node.y);
186 }
187
188 let (remaining_dist, _) = self.calc_distance_and_angle(&new_node, to_node);
189 if remaining_dist <= self.path_resolution {
190 new_node.path_x.push(to_node.x);
191 new_node.path_y.push(to_node.y);
192 new_node.x = to_node.x;
193 new_node.y = to_node.y;
194 }
195
196 new_node.parent = parent;
197 new_node
198 }
199
200 fn check_collision_free(&self, node: &Node) -> bool {
201 if node.path_x.is_empty() || node.path_y.is_empty() {
202 return true;
203 }
204
205 for &(ox, oy, size) in &self.obstacle_list {
206 for (&px, &py) in node.path_x.iter().zip(node.path_y.iter()) {
207 let d = (px - ox).powi(2) + (py - oy).powi(2);
208 if d <= (size + self.robot_radius).powi(2) {
209 return false;
210 }
211 }
212 }
213
214 true
215 }
216
217 fn find_near_nodes(&self, new_node: &Node) -> Vec<usize> {
218 let nnode = self.node_list.len() + 1;
219 let r = self.connect_circle_dist * ((nnode as f64).ln() / nnode as f64).sqrt();
220 let r = r.min(self.expand_dis);
221
222 self.node_list
223 .iter()
224 .enumerate()
225 .filter_map(|(i, node)| {
226 let dist_sq = (node.x - new_node.x).powi(2) + (node.y - new_node.y).powi(2);
227 if dist_sq <= r.powi(2) {
228 Some(i)
229 } else {
230 None
231 }
232 })
233 .collect()
234 }
235
236 fn choose_parent(&self, new_node: Node, near_inds: &[usize]) -> Option<Node> {
237 if near_inds.is_empty() {
238 return None;
239 }
240
241 let mut costs = Vec::new();
242 for &i in near_inds {
243 let near_node = &self.node_list[i];
244 let t_node = self.steer_from_node(near_node, &new_node, f64::INFINITY, Some(i));
245
246 if self.check_collision_free(&t_node) {
247 costs.push(self.calc_new_cost(near_node, &new_node));
248 } else {
249 costs.push(f64::INFINITY);
250 }
251 }
252
253 let min_cost = costs.iter().fold(f64::INFINITY, |a, &b| a.min(b));
254
255 if min_cost == f64::INFINITY {
256 return None;
257 }
258
259 let min_ind = costs.iter().position(|&x| x == min_cost)?;
260 let parent_ind = near_inds[min_ind];
261
262 let mut result_node = self.steer_from_node(
263 &self.node_list[parent_ind],
264 &new_node,
265 f64::INFINITY,
266 Some(parent_ind),
267 );
268 result_node.cost = min_cost;
269
270 Some(result_node)
271 }
272
273 fn search_best_goal_node(&self) -> Option<usize> {
274 let dist_to_goal_list: Vec<f64> = self
275 .node_list
276 .iter()
277 .map(|n| self.calc_dist_to_goal(n.x, n.y))
278 .collect();
279
280 let goal_inds: Vec<usize> = dist_to_goal_list
281 .iter()
282 .enumerate()
283 .filter_map(|(i, &dist)| {
284 if dist <= self.expand_dis {
285 Some(i)
286 } else {
287 None
288 }
289 })
290 .collect();
291
292 let safe_goal_inds: Vec<usize> = goal_inds
293 .into_iter()
294 .filter(|&goal_ind| {
295 let t_node = self.steer_from_node(
296 &self.node_list[goal_ind],
297 &self.end,
298 f64::INFINITY,
299 Some(goal_ind),
300 );
301 self.check_collision_free(&t_node)
302 })
303 .collect();
304
305 if safe_goal_inds.is_empty() {
306 return None;
307 }
308
309 let safe_goal_costs: Vec<f64> = safe_goal_inds
310 .iter()
311 .map(|&i| {
312 self.node_list[i].cost
313 + self.calc_dist_to_goal(self.node_list[i].x, self.node_list[i].y)
314 })
315 .collect();
316
317 let min_cost = safe_goal_costs.iter().fold(f64::INFINITY, |a, &b| a.min(b));
318
319 safe_goal_inds
320 .into_iter()
321 .zip(safe_goal_costs)
322 .find(|(_, cost)| *cost == min_cost)
323 .map(|(i, _)| i)
324 }
325
326 fn rewire(&mut self, new_node_ind: usize, near_inds: &[usize]) {
327 for &i in near_inds {
328 let near_node = self.node_list[i].clone();
329 let new_node = &self.node_list[new_node_ind];
330
331 let mut edge_node =
332 self.steer_from_node(new_node, &near_node, f64::INFINITY, Some(new_node_ind));
333 edge_node.cost = self.calc_new_cost(new_node, &near_node);
334
335 let no_collision = self.check_collision_free(&edge_node);
336 let improved_cost = near_node.cost > edge_node.cost;
337
338 if no_collision && improved_cost {
339 self.node_list[i] = edge_node;
340 self.propagate_cost_to_leaves(i);
341 }
342 }
343 }
344
345 fn calc_new_cost(&self, from_node: &Node, to_node: &Node) -> f64 {
346 from_node.cost + self.calc_distance(from_node, to_node)
347 }
348
349 fn propagate_cost_to_leaves(&mut self, parent_ind: usize) {
350 let parent_node = self.node_list[parent_ind].clone();
351
352 for i in 0..self.node_list.len() {
353 if let Some(node_parent) = self.node_list[i].parent {
354 if node_parent == parent_ind {
355 self.node_list[i].cost =
356 self.calc_new_cost(&parent_node, &self.node_list[i].clone());
357 self.propagate_cost_to_leaves(i);
358 }
359 }
360 }
361 }
362
363 fn generate_final_course(&self, goal_ind: usize) -> Vec<[f64; 2]> {
364 let mut path = vec![[self.end.x, self.end.y]];
365 let mut node = &self.node_list[goal_ind];
366
367 while let Some(parent_ind) = node.parent {
368 path.push([node.x, node.y]);
369 node = &self.node_list[parent_ind];
370 }
371 path.push([node.x, node.y]);
372
373 path
374 }
375
376 fn calc_dist_to_goal(&self, x: f64, y: f64) -> f64 {
377 let dx = x - self.end.x;
378 let dy = y - self.end.y;
379 (dx * dx + dy * dy).sqrt()
380 }
381
382 fn calc_distance(&self, from_node: &Node, to_node: &Node) -> f64 {
383 let dx = to_node.x - from_node.x;
384 let dy = to_node.y - from_node.y;
385 (dx * dx + dy * dy).sqrt()
386 }
387
388 fn calc_distance_and_angle(&self, from_node: &Node, to_node: &Node) -> (f64, f64) {
389 let dx = to_node.x - from_node.x;
390 let dy = to_node.y - from_node.y;
391 let d = (dx * dx + dy * dy).sqrt();
392 let theta = dy.atan2(dx);
393 (d, theta)
394 }
395
396 pub fn plan_from(&mut self, start: Point2D, goal: Point2D) -> RoboticsResult<Path2D> {
402 self.start = Node::new(start.x, start.y);
403 self.end = Node::new(goal.x, goal.y);
404
405 self.planning()
406 .map(|raw_path| {
407 Path2D::from_points(
408 raw_path
409 .into_iter()
410 .rev()
411 .map(|p| Point2D::new(p[0], p[1]))
412 .collect(),
413 )
414 })
415 .ok_or_else(|| {
416 RoboticsError::PlanningError(
417 "RRT*: Cannot find path within max iterations".to_string(),
418 )
419 })
420 }
421
422 pub fn get_tree(&self) -> &[Node] {
424 &self.node_list
425 }
426
427 pub fn get_obstacles(&self) -> &[(f64, f64, f64)] {
429 &self.obstacle_list
430 }
431}
432
// Unit tests for the RRT* planner. The seeded tests drive `planning_with_sampler` with
// fixed sample sequences so results can be compared against reference values that appear
// to come from PythonRobotics' RRT* (per the fixture/helper names — TODO confirm source).
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `actual` is within 1e-12 of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1.0e-12,
            "expected {expected}, got {actual}"
        );
    }

    /// Parses `x,y` rows from a CSV fixture, skipping the first (header) line and any
    /// blank lines. Panics on rows without a comma or with unparsable numbers.
    fn parse_xy_fixture(csv: &str) -> Vec<[f64; 2]> {
        csv.lines()
            .skip(1)
            .filter(|line| !line.trim().is_empty())
            .map(|line| {
                let (x, y) = line
                    .split_once(',')
                    .expect("xy fixture rows must contain a comma");
                [x.parse().unwrap(), y.parse().unwrap()]
            })
            .collect()
    }

    /// Planner configured with the obstacle field and parameters used by the seeded
    /// reference runs (start (0,0), goal (6,10), expand_dis 1.0, path_resolution 1.0,
    /// 300 iterations, robot_radius 0.8).
    fn create_pythonrobotics_main_planner() -> RRTStar {
        RRTStar::new(
            (0.0, 0.0),
            (6.0, 10.0),
            vec![
                (5.0, 5.0, 1.0),
                (3.0, 6.0, 2.0),
                (3.0, 8.0, 2.0),
                (3.0, 10.0, 2.0),
                (7.0, 5.0, 2.0),
                (9.0, 5.0, 2.0),
                (8.0, 10.0, 1.0),
                (6.0, 12.0, 1.0),
            ],
            (-2.0, 15.0),
            1.0,
            1.0,
            20,
            300,
            50.0,
            false,
            0.8,
        )
    }

    /// The constructor stores the given parameters unchanged.
    #[test]
    fn test_rrt_star_config() {
        let rrt = RRTStar::new(
            (0.0, 0.0),
            (6.0, 10.0),
            vec![(5.0, 5.0, 1.0)],
            (-2.0, 15.0),
            2.0,
            0.5,
            20,
            500,
            50.0,
            false,
            0.3,
        );
        assert_eq!(rrt.expand_dis, 2.0);
        assert_eq!(rrt.max_iter, 500);
    }

    /// Without obstacles and with expand_dis = 30, a single sample is added in full and the
    /// start itself is within goal-connection range. The cheapest goal connection is
    /// therefore straight from the start, giving the two-point path goal → start.
    #[test]
    fn test_rrt_star_upstream_no_obstacle_seeded_reference() {
        for robot_radius in [0.0, 0.8] {
            let mut rrt = RRTStar::new(
                (0.0, 0.0),
                (6.0, 10.0),
                vec![],
                (-2.0, 15.0),
                30.0,
                1.0,
                20,
                300,
                50.0,
                false,
                robot_radius,
            );
            let sample = [10.455_649_682_677_358, 11.942_970_283_541_907];
            let path = rrt
                .planning_with_sampler(|_| Node::new(sample[0], sample[1]))
                .unwrap();
            assert_eq!(rrt.node_list.len(), 2);
            assert_eq!(path, vec![[6.0, 10.0], [0.0, 0.0]]);
            assert_close(rrt.node_list[1].x, sample[0]);
            assert_close(rrt.node_list[1].y, sample[1]);
            assert_close(rrt.node_list[1].cost, 15.873_095_144_943_73);
            assert_eq!(rrt.node_list[1].parent, Some(0));
        }
    }

    /// Replays only the first 20 reference samples (max_iter = 20). No path exists yet,
    /// so this pins the tree prefix: 14 vertices, with selected positions, costs and
    /// parents checked.
    #[test]
    fn test_rrt_star_upstream_seeded_main_prefix_matches_pythonrobotics_reference() {
        let mut rrt = create_pythonrobotics_main_planner();
        rrt.max_iter = 20;
        let samples = parse_xy_fixture(include_str!("testdata/rrt_star_main_seed10_samples.csv"));
        let mut sample_index = 0_usize;
        let prefix_len = 20_usize;

        // The sampler panics if asked for more than `prefix_len` samples, proving the
        // planner consumed exactly one sample per iteration.
        let path = rrt.planning_with_sampler(|_| {
            let sample = samples
                .get(sample_index)
                .filter(|_| sample_index < prefix_len)
                .expect("python reference sample sequence exhausted");
            sample_index += 1;
            Node::new(sample[0], sample[1])
        });

        assert!(path.is_none());
        assert_eq!(sample_index, prefix_len);
        assert_eq!(rrt.node_list.len(), 14);

        // (node index, [x, y], cost, parent) reference entries.
        let expected_nodes = [
            (
                1,
                [-0.227_015_105_864_128, 0.973_891_237_104_79],
                1.0,
                Some(0),
            ),
            (
                2,
                [0.340_848_395_898_016, 1.797_013_976_048_647],
                2.0,
                Some(1),
            ),
            (
                5,
                [2.912_922_340_312_151, 1.655_751_229_702_901],
                5.0,
                Some(4),
            ),
            (
                10,
                [5.856_411_905_724_674, 1.320_164_679_038_256],
                8.0,
                Some(9),
            ),
            (
                13,
                [8.543_112_538_843_18, 0.823_740_039_534_573],
                11.0,
                Some(12),
            ),
        ];
        for (index, xy, cost, parent) in expected_nodes {
            let node = &rrt.node_list[index];
            assert_close(node.x, xy[0]);
            assert_close(node.y, xy[1]);
            assert_close(node.cost, cost);
            assert_eq!(node.parent, parent);
        }
    }

    /// Replays the full reference sample sequence. Checks:
    /// - every sample is consumed;
    /// - the final tree has 100 vertices;
    /// - the returned path matches the reference path fixture point by point;
    /// - selected vertices (including ones whose parents were changed by rewiring) match.
    #[test]
    fn test_rrt_star_upstream_seeded_main_matches_pythonrobotics_reference() {
        let mut rrt = create_pythonrobotics_main_planner();
        let samples = parse_xy_fixture(include_str!("testdata/rrt_star_main_seed10_samples.csv"));
        let expected_path =
            parse_xy_fixture(include_str!("testdata/rrt_star_main_seed10_path.csv"));
        let mut sample_index = 0_usize;

        let path = rrt
            .planning_with_sampler(|_| {
                let sample = samples
                    .get(sample_index)
                    .expect("python reference sample sequence exhausted");
                sample_index += 1;
                Node::new(sample[0], sample[1])
            })
            .expect("python reference run should find a path");

        assert_eq!(sample_index, samples.len());
        assert_eq!(rrt.node_list.len(), 100);
        assert_eq!(path.len(), expected_path.len());
        for (actual, expected) in path.iter().zip(expected_path.iter()) {
            assert_close(actual[0], expected[0]);
            assert_close(actual[1], expected[1]);
        }

        // (node index, [x, y], cost, parent) reference entries; e.g. node 80's parent is
        // 17, not 79, so tree structure beyond sequential growth is also verified.
        let expected_nodes = [
            (
                1,
                [-0.227_015_105_864_128, 0.973_891_237_104_79],
                1.0,
                Some(0),
            ),
            (
                2,
                [0.340_848_395_898_016, 1.797_013_976_048_647],
                2.0,
                Some(1),
            ),
            (
                5,
                [2.912_922_340_312_151, 1.655_751_229_702_901],
                5.0,
                Some(4),
            ),
            (
                10,
                [5.856_411_905_724_674, 1.320_164_679_038_256],
                8.0,
                Some(9),
            ),
            (
                20,
                [12.105_226_205_468_63, 1.607_428_066_363_632],
                14.812_039_643_502_144,
                Some(19),
            ),
            (
                40,
                [13.266_098_354_827_152, 11.032_918_978_213_733],
                25.673_392_954_630_07,
                Some(39),
            ),
            (
                60,
                [8.777_150_456_317_27, 12.593_447_860_337_104],
                31.673_392_954_630_07,
                Some(53),
            ),
            (
                80,
                [10.550_895_454_349_991, 1.108_429_868_595_935],
                13.203_336_815_700_968,
                Some(17),
            ),
            (99, [6.0, 10.0], 28.741_122_081_549_424, Some(97)),
        ];

        for (index, xy, cost, parent) in expected_nodes {
            let node = &rrt.node_list[index];
            assert_close(node.x, xy[0]);
            assert_close(node.y, xy[1]);
            assert_close(node.cost, cost);
            assert_eq!(node.parent, parent);
        }
    }
}