Skip to main content

rust_robotics_planning/
rrt_star.rs

1#![allow(dead_code, clippy::too_many_arguments)]
2
3//! RRT* (Rapidly-exploring Random Tree Star) path planning algorithm
4//!
5//! An optimized version of RRT that rewires the tree to find shorter paths.
6
7use rand::Rng;
8
9use rust_robotics_core::{Path2D, Point2D, RoboticsError, RoboticsResult};
10
/// A vertex of the RRT* search tree.
///
/// `(x, y)` is the vertex position. `path_x`/`path_y` hold the intermediate points of the edge
/// that reaches this vertex from its parent (empty for a freshly constructed node), `cost` is the
/// accumulated path length from the tree root, and `parent` is the index of the parent vertex in
/// the planner's node list (`None` for the root).
#[derive(Debug, Clone)]
pub struct Node {
    /// X coordinate of the vertex.
    pub x: f64,
    /// Y coordinate of the vertex.
    pub y: f64,
    /// X coordinates of the edge points leading to this vertex.
    pub path_x: Vec<f64>,
    /// Y coordinates of the edge points leading to this vertex.
    pub path_y: Vec<f64>,
    /// Accumulated path cost from the tree root.
    pub cost: f64,
    /// Index of the parent vertex in the planner's node list, if any.
    pub parent: Option<usize>,
}

impl Node {
    /// Create a detached node at `(x, y)` with zero cost, no parent and an empty edge path.
    pub fn new(x: f64, y: f64) -> Self {
        Self {
            parent: None,
            cost: 0.0,
            path_y: vec![],
            path_x: vec![],
            y,
            x,
        }
    }
}
34
35pub struct RRTStar {
36    pub start: Node,
37    pub end: Node,
38    pub min_rand: f64,
39    pub max_rand: f64,
40    pub expand_dis: f64,
41    pub path_resolution: f64,
42    pub goal_sample_rate: i32,
43    pub max_iter: i32,
44    pub connect_circle_dist: f64,
45    pub search_until_max_iter: bool,
46    pub robot_radius: f64,
47    pub obstacle_list: Vec<(f64, f64, f64)>, // (x, y, radius)
48    pub node_list: Vec<Node>,
49}
50
51impl RRTStar {
52    pub fn new(
53        start: (f64, f64),
54        goal: (f64, f64),
55        obstacle_list: Vec<(f64, f64, f64)>,
56        rand_area: (f64, f64),
57        expand_dis: f64,
58        path_resolution: f64,
59        goal_sample_rate: i32,
60        max_iter: i32,
61        connect_circle_dist: f64,
62        search_until_max_iter: bool,
63        robot_radius: f64,
64    ) -> Self {
65        RRTStar {
66            start: Node::new(start.0, start.1),
67            end: Node::new(goal.0, goal.1),
68            min_rand: rand_area.0,
69            max_rand: rand_area.1,
70            expand_dis,
71            path_resolution,
72            goal_sample_rate,
73            max_iter,
74            connect_circle_dist,
75            search_until_max_iter,
76            robot_radius,
77            obstacle_list,
78            node_list: Vec::new(),
79        }
80    }
81
82    pub fn planning(&mut self) -> Option<Vec<[f64; 2]>> {
83        self.planning_with_sampler(|planner| planner.get_random_node())
84    }
85
86    fn reset_search(&mut self) {
87        self.node_list = vec![self.start.clone()];
88    }
89
90    fn planning_with_sampler<F>(&mut self, mut sample_node: F) -> Option<Vec<[f64; 2]>>
91    where
92        F: FnMut(&RRTStar) -> Node,
93    {
94        self.reset_search();
95
96        for _i in 0..self.max_iter {
97            let rnd_node = sample_node(self);
98            let nearest_ind = self.get_nearest_node_index(&rnd_node);
99            let mut new_node = self.steer(nearest_ind, &rnd_node);
100
101            if let Some(ref mut node) = new_node {
102                let near_node = &self.node_list[nearest_ind];
103                node.cost = near_node.cost + self.calc_distance(near_node, node);
104
105                if self.check_collision_free(node) {
106                    let near_inds = self.find_near_nodes(node);
107                    let node_with_updated_parent = self.choose_parent(node.clone(), &near_inds);
108
109                    if let Some(updated_node) = node_with_updated_parent {
110                        let new_index = self.node_list.len();
111                        self.node_list.push(updated_node);
112                        self.rewire(new_index, &near_inds);
113                    } else {
114                        self.node_list.push(node.clone());
115                    }
116                }
117            }
118
119            if !self.search_until_max_iter && new_node.is_some() {
120                if let Some(last_index) = self.search_best_goal_node() {
121                    return Some(self.generate_final_course(last_index));
122                }
123            }
124        }
125
126        if let Some(last_index) = self.search_best_goal_node() {
127            return Some(self.generate_final_course(last_index));
128        }
129
130        None
131    }
132
133    fn get_random_node(&self) -> Node {
134        let mut rng = rand::rng();
135
136        if rng.random_range(0..=100) > self.goal_sample_rate {
137            Node::new(
138                rng.random_range(self.min_rand..=self.max_rand),
139                rng.random_range(self.min_rand..=self.max_rand),
140            )
141        } else {
142            Node::new(self.end.x, self.end.y)
143        }
144    }
145
146    fn get_nearest_node_index(&self, rnd_node: &Node) -> usize {
147        let mut min_dist = f64::INFINITY;
148        let mut nearest_ind = 0;
149
150        for (i, node) in self.node_list.iter().enumerate() {
151            let dist = self.calc_distance(node, rnd_node);
152            if dist < min_dist {
153                min_dist = dist;
154                nearest_ind = i;
155            }
156        }
157
158        nearest_ind
159    }
160
161    fn steer(&self, from_ind: usize, to_node: &Node) -> Option<Node> {
162        let from_node = &self.node_list[from_ind];
163        Some(self.steer_from_node(from_node, to_node, self.expand_dis, Some(from_ind)))
164    }
165
166    fn steer_from_node(
167        &self,
168        from_node: &Node,
169        to_node: &Node,
170        extend_length: f64,
171        parent: Option<usize>,
172    ) -> Node {
173        let mut new_node = Node::new(from_node.x, from_node.y);
174        let (dist, theta) = self.calc_distance_and_angle(&new_node, to_node);
175        let extend_length = extend_length.min(dist);
176
177        new_node.path_x = vec![new_node.x];
178        new_node.path_y = vec![new_node.y];
179
180        let n_expand = (extend_length / self.path_resolution).floor() as i32;
181        for _ in 0..n_expand {
182            new_node.x += self.path_resolution * theta.cos();
183            new_node.y += self.path_resolution * theta.sin();
184            new_node.path_x.push(new_node.x);
185            new_node.path_y.push(new_node.y);
186        }
187
188        let (remaining_dist, _) = self.calc_distance_and_angle(&new_node, to_node);
189        if remaining_dist <= self.path_resolution {
190            new_node.path_x.push(to_node.x);
191            new_node.path_y.push(to_node.y);
192            new_node.x = to_node.x;
193            new_node.y = to_node.y;
194        }
195
196        new_node.parent = parent;
197        new_node
198    }
199
200    fn check_collision_free(&self, node: &Node) -> bool {
201        if node.path_x.is_empty() || node.path_y.is_empty() {
202            return true;
203        }
204
205        for &(ox, oy, size) in &self.obstacle_list {
206            for (&px, &py) in node.path_x.iter().zip(node.path_y.iter()) {
207                let d = (px - ox).powi(2) + (py - oy).powi(2);
208                if d <= (size + self.robot_radius).powi(2) {
209                    return false;
210                }
211            }
212        }
213
214        true
215    }
216
217    fn find_near_nodes(&self, new_node: &Node) -> Vec<usize> {
218        let nnode = self.node_list.len() + 1;
219        let r = self.connect_circle_dist * ((nnode as f64).ln() / nnode as f64).sqrt();
220        let r = r.min(self.expand_dis);
221
222        self.node_list
223            .iter()
224            .enumerate()
225            .filter_map(|(i, node)| {
226                let dist_sq = (node.x - new_node.x).powi(2) + (node.y - new_node.y).powi(2);
227                if dist_sq <= r.powi(2) {
228                    Some(i)
229                } else {
230                    None
231                }
232            })
233            .collect()
234    }
235
236    fn choose_parent(&self, new_node: Node, near_inds: &[usize]) -> Option<Node> {
237        if near_inds.is_empty() {
238            return None;
239        }
240
241        let mut costs = Vec::new();
242        for &i in near_inds {
243            let near_node = &self.node_list[i];
244            let t_node = self.steer_from_node(near_node, &new_node, f64::INFINITY, Some(i));
245
246            if self.check_collision_free(&t_node) {
247                costs.push(self.calc_new_cost(near_node, &new_node));
248            } else {
249                costs.push(f64::INFINITY);
250            }
251        }
252
253        let min_cost = costs.iter().fold(f64::INFINITY, |a, &b| a.min(b));
254
255        if min_cost == f64::INFINITY {
256            return None;
257        }
258
259        let min_ind = costs.iter().position(|&x| x == min_cost)?;
260        let parent_ind = near_inds[min_ind];
261
262        let mut result_node = self.steer_from_node(
263            &self.node_list[parent_ind],
264            &new_node,
265            f64::INFINITY,
266            Some(parent_ind),
267        );
268        result_node.cost = min_cost;
269
270        Some(result_node)
271    }
272
273    fn search_best_goal_node(&self) -> Option<usize> {
274        let dist_to_goal_list: Vec<f64> = self
275            .node_list
276            .iter()
277            .map(|n| self.calc_dist_to_goal(n.x, n.y))
278            .collect();
279
280        let goal_inds: Vec<usize> = dist_to_goal_list
281            .iter()
282            .enumerate()
283            .filter_map(|(i, &dist)| {
284                if dist <= self.expand_dis {
285                    Some(i)
286                } else {
287                    None
288                }
289            })
290            .collect();
291
292        let safe_goal_inds: Vec<usize> = goal_inds
293            .into_iter()
294            .filter(|&goal_ind| {
295                let t_node = self.steer_from_node(
296                    &self.node_list[goal_ind],
297                    &self.end,
298                    f64::INFINITY,
299                    Some(goal_ind),
300                );
301                self.check_collision_free(&t_node)
302            })
303            .collect();
304
305        if safe_goal_inds.is_empty() {
306            return None;
307        }
308
309        let safe_goal_costs: Vec<f64> = safe_goal_inds
310            .iter()
311            .map(|&i| {
312                self.node_list[i].cost
313                    + self.calc_dist_to_goal(self.node_list[i].x, self.node_list[i].y)
314            })
315            .collect();
316
317        let min_cost = safe_goal_costs.iter().fold(f64::INFINITY, |a, &b| a.min(b));
318
319        safe_goal_inds
320            .into_iter()
321            .zip(safe_goal_costs)
322            .find(|(_, cost)| *cost == min_cost)
323            .map(|(i, _)| i)
324    }
325
326    fn rewire(&mut self, new_node_ind: usize, near_inds: &[usize]) {
327        for &i in near_inds {
328            let near_node = self.node_list[i].clone();
329            let new_node = &self.node_list[new_node_ind];
330
331            let mut edge_node =
332                self.steer_from_node(new_node, &near_node, f64::INFINITY, Some(new_node_ind));
333            edge_node.cost = self.calc_new_cost(new_node, &near_node);
334
335            let no_collision = self.check_collision_free(&edge_node);
336            let improved_cost = near_node.cost > edge_node.cost;
337
338            if no_collision && improved_cost {
339                self.node_list[i] = edge_node;
340                self.propagate_cost_to_leaves(i);
341            }
342        }
343    }
344
345    fn calc_new_cost(&self, from_node: &Node, to_node: &Node) -> f64 {
346        from_node.cost + self.calc_distance(from_node, to_node)
347    }
348
349    fn propagate_cost_to_leaves(&mut self, parent_ind: usize) {
350        let parent_node = self.node_list[parent_ind].clone();
351
352        for i in 0..self.node_list.len() {
353            if let Some(node_parent) = self.node_list[i].parent {
354                if node_parent == parent_ind {
355                    self.node_list[i].cost =
356                        self.calc_new_cost(&parent_node, &self.node_list[i].clone());
357                    self.propagate_cost_to_leaves(i);
358                }
359            }
360        }
361    }
362
363    fn generate_final_course(&self, goal_ind: usize) -> Vec<[f64; 2]> {
364        let mut path = vec![[self.end.x, self.end.y]];
365        let mut node = &self.node_list[goal_ind];
366
367        while let Some(parent_ind) = node.parent {
368            path.push([node.x, node.y]);
369            node = &self.node_list[parent_ind];
370        }
371        path.push([node.x, node.y]);
372
373        path
374    }
375
376    fn calc_dist_to_goal(&self, x: f64, y: f64) -> f64 {
377        let dx = x - self.end.x;
378        let dy = y - self.end.y;
379        (dx * dx + dy * dy).sqrt()
380    }
381
382    fn calc_distance(&self, from_node: &Node, to_node: &Node) -> f64 {
383        let dx = to_node.x - from_node.x;
384        let dy = to_node.y - from_node.y;
385        (dx * dx + dy * dy).sqrt()
386    }
387
388    fn calc_distance_and_angle(&self, from_node: &Node, to_node: &Node) -> (f64, f64) {
389        let dx = to_node.x - from_node.x;
390        let dy = to_node.y - from_node.y;
391        let d = (dx * dx + dy * dy).sqrt();
392        let theta = dy.atan2(dx);
393        (d, theta)
394    }
395
396    /// Plan a path from the given start to goal, returning a [`Path2D`].
397    ///
398    /// This is a convenience wrapper around [`planning()`](Self::planning) that accepts
399    /// [`Point2D`], sets the start/goal, runs the planner, and returns [`Path2D`].
400    /// Requires `&mut self` because the underlying algorithm mutates internal state.
401    pub fn plan_from(&mut self, start: Point2D, goal: Point2D) -> RoboticsResult<Path2D> {
402        self.start = Node::new(start.x, start.y);
403        self.end = Node::new(goal.x, goal.y);
404
405        self.planning()
406            .map(|raw_path| {
407                Path2D::from_points(
408                    raw_path
409                        .into_iter()
410                        .rev()
411                        .map(|p| Point2D::new(p[0], p[1]))
412                        .collect(),
413                )
414            })
415            .ok_or_else(|| {
416                RoboticsError::PlanningError(
417                    "RRT*: Cannot find path within max iterations".to_string(),
418                )
419            })
420    }
421
422    /// Get the tree nodes for external inspection
423    pub fn get_tree(&self) -> &[Node] {
424        &self.node_list
425    }
426
427    /// Get the obstacle list
428    pub fn get_obstacles(&self) -> &[(f64, f64, f64)] {
429        &self.obstacle_list
430    }
431}
432
#[cfg(test)]
mod tests {
    use super::*;

    /// Expected state of one tree node: `(index, [x, y], cost, parent)`.
    type ExpectedNode = (usize, [f64; 2], f64, Option<usize>);

    /// Early nodes that both PythonRobotics seed-10 reference runs must reproduce.
    const SEED10_EARLY_NODES: [ExpectedNode; 4] = [
        (
            1,
            [-0.227_015_105_864_128, 0.973_891_237_104_79],
            1.0,
            Some(0),
        ),
        (
            2,
            [0.340_848_395_898_016, 1.797_013_976_048_647],
            2.0,
            Some(1),
        ),
        (
            5,
            [2.912_922_340_312_151, 1.655_751_229_702_901],
            5.0,
            Some(4),
        ),
        (
            10,
            [5.856_411_905_724_674, 1.320_164_679_038_256],
            8.0,
            Some(9),
        ),
    ];

    /// Assert that `actual` is within `1e-12` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        let error = (actual - expected).abs();
        assert!(error < 1.0e-12, "expected {expected}, got {actual}");
    }

    /// Check position, cost and parent of every listed node, in order.
    fn assert_nodes(nodes: &[Node], expected: &[ExpectedNode]) {
        for &(index, [x, y], cost, parent) in expected {
            let node = &nodes[index];
            assert_close(node.x, x);
            assert_close(node.y, y);
            assert_close(node.cost, cost);
            assert_eq!(node.parent, parent);
        }
    }

    /// Parse an `x,y` CSV fixture, skipping the header row and blank lines.
    fn parse_xy_fixture(csv: &str) -> Vec<[f64; 2]> {
        let mut rows: Vec<[f64; 2]> = Vec::new();
        for line in csv.lines().skip(1) {
            if line.trim().is_empty() {
                continue;
            }
            let (x, y) = line
                .split_once(',')
                .expect("xy fixture rows must contain a comma");
            rows.push([x.parse().unwrap(), y.parse().unwrap()]);
        }
        rows
    }

    /// Planner configured like the PythonRobotics `rrt_star.py` main example.
    fn create_pythonrobotics_main_planner() -> RRTStar {
        let obstacles = vec![
            (5.0, 5.0, 1.0),
            (3.0, 6.0, 2.0),
            (3.0, 8.0, 2.0),
            (3.0, 10.0, 2.0),
            (7.0, 5.0, 2.0),
            (9.0, 5.0, 2.0),
            (8.0, 10.0, 1.0),
            (6.0, 12.0, 1.0),
        ];
        RRTStar::new(
            (0.0, 0.0),
            (6.0, 10.0),
            obstacles,
            (-2.0, 15.0),
            1.0,   // expand_dis
            1.0,   // path_resolution
            20,    // goal_sample_rate
            300,   // max_iter
            50.0,  // connect_circle_dist
            false, // search_until_max_iter
            0.8,   // robot_radius
        )
    }

    #[test]
    fn test_rrt_star_config() {
        let rrt = RRTStar::new(
            (0.0, 0.0),
            (6.0, 10.0),
            vec![(5.0, 5.0, 1.0)],
            (-2.0, 15.0),
            2.0,   // expand_dis
            0.5,   // path_resolution
            20,    // goal_sample_rate
            500,   // max_iter
            50.0,  // connect_circle_dist
            false, // search_until_max_iter
            0.3,   // robot_radius
        );
        assert_eq!(rrt.expand_dis, 2.0);
        assert_eq!(rrt.max_iter, 500);
    }

    #[test]
    fn test_rrt_star_upstream_no_obstacle_seeded_reference() {
        let sample = [10.455_649_682_677_358, 11.942_970_283_541_907];
        for robot_radius in [0.0, 0.8] {
            let mut rrt = RRTStar::new(
                (0.0, 0.0),
                (6.0, 10.0),
                vec![],
                (-2.0, 15.0),
                30.0,
                1.0,
                20,
                300,
                50.0,
                false,
                robot_radius,
            );
            let path = rrt
                .planning_with_sampler(|_| Node::new(sample[0], sample[1]))
                .unwrap();

            assert_eq!(rrt.node_list.len(), 2);
            assert_eq!(path, vec![[6.0, 10.0], [0.0, 0.0]]);

            let added = &rrt.node_list[1];
            assert_close(added.x, sample[0]);
            assert_close(added.y, sample[1]);
            assert_close(added.cost, 15.873_095_144_943_73);
            assert_eq!(added.parent, Some(0));
        }
    }

    #[test]
    fn test_rrt_star_upstream_seeded_main_prefix_matches_pythonrobotics_reference() {
        let mut rrt = create_pythonrobotics_main_planner();
        rrt.max_iter = 20;
        let samples = parse_xy_fixture(include_str!("testdata/rrt_star_main_seed10_samples.csv"));
        let prefix_len = 20_usize;
        let mut consumed = 0_usize;

        let path = rrt.planning_with_sampler(|_| {
            let &[x, y] = samples
                .get(consumed)
                .filter(|_| consumed < prefix_len)
                .expect("python reference sample sequence exhausted");
            consumed += 1;
            Node::new(x, y)
        });

        assert!(path.is_none());
        assert_eq!(consumed, prefix_len);
        assert_eq!(rrt.node_list.len(), 14);

        assert_nodes(&rrt.node_list, &SEED10_EARLY_NODES);
        assert_nodes(
            &rrt.node_list,
            &[(
                13,
                [8.543_112_538_843_18, 0.823_740_039_534_573],
                11.0,
                Some(12),
            )],
        );
    }

    #[test]
    fn test_rrt_star_upstream_seeded_main_matches_pythonrobotics_reference() {
        let mut rrt = create_pythonrobotics_main_planner();
        let samples = parse_xy_fixture(include_str!("testdata/rrt_star_main_seed10_samples.csv"));
        let expected_path =
            parse_xy_fixture(include_str!("testdata/rrt_star_main_seed10_path.csv"));
        let mut consumed = 0_usize;

        let path = rrt
            .planning_with_sampler(|_| {
                let &[x, y] = samples
                    .get(consumed)
                    .expect("python reference sample sequence exhausted");
                consumed += 1;
                Node::new(x, y)
            })
            .expect("python reference run should find a path");

        assert_eq!(consumed, samples.len());
        assert_eq!(rrt.node_list.len(), 100);
        assert_eq!(path.len(), expected_path.len());
        for (actual, expected) in path.iter().zip(&expected_path) {
            assert_close(actual[0], expected[0]);
            assert_close(actual[1], expected[1]);
        }

        assert_nodes(&rrt.node_list, &SEED10_EARLY_NODES);
        assert_nodes(
            &rrt.node_list,
            &[
                (
                    20,
                    [12.105_226_205_468_63, 1.607_428_066_363_632],
                    14.812_039_643_502_144,
                    Some(19),
                ),
                (
                    40,
                    [13.266_098_354_827_152, 11.032_918_978_213_733],
                    25.673_392_954_630_07,
                    Some(39),
                ),
                (
                    60,
                    [8.777_150_456_317_27, 12.593_447_860_337_104],
                    31.673_392_954_630_07,
                    Some(53),
                ),
                (
                    80,
                    [10.550_895_454_349_991, 1.108_429_868_595_935],
                    13.203_336_815_700_968,
                    Some(17),
                ),
                (99, [6.0, 10.0], 28.741_122_081_549_424, Some(97)),
            ],
        );
    }
}