Skip to main content

rust_robotics_planning/
rrt.rs

1#![allow(clippy::too_many_arguments)]
2
3//! RRT (Rapidly-exploring Random Tree) path planning algorithm
4//!
5//! Sampling-based path planning algorithm that builds a tree by
6//! randomly sampling the configuration space.
7
8use rand::Rng;
9
10use rust_robotics_core::{Path2D, PathPlanner, Point2D, RoboticsError};
11
12/// Internal node for RRT tree
13#[derive(Debug, Clone)]
14pub struct RRTNode {
15    pub x: f64,
16    pub y: f64,
17    pub path_x: Vec<f64>,
18    pub path_y: Vec<f64>,
19    pub parent: Option<usize>,
20}
21
22impl RRTNode {
23    pub fn new(x: f64, y: f64) -> Self {
24        RRTNode {
25            x,
26            y,
27            path_x: Vec::new(),
28            path_y: Vec::new(),
29            parent: None,
30        }
31    }
32    pub fn to_point(&self) -> Point2D {
33        Point2D::new(self.x, self.y)
34    }
35}
36
/// Axis-aligned rectangle `[xmin, xmax] × [ymin, ymax]`.
///
/// The planner uses it both for the region random samples are drawn from and
/// for the optional play area that new nodes must stay inside.
#[derive(Debug, Clone)]
pub struct AreaBounds {
    /// Lower x limit.
    pub xmin: f64,
    /// Upper x limit.
    pub xmax: f64,
    /// Lower y limit.
    pub ymin: f64,
    /// Upper y limit.
    pub ymax: f64,
}

impl AreaBounds {
    /// Builds bounds from explicit limits. No validation is performed.
    pub fn new(xmin: f64, xmax: f64, ymin: f64, ymax: f64) -> Self {
        Self { xmin, xmax, ymin, ymax }
    }

    /// Builds bounds from an array laid out as `[xmin, xmax, ymin, ymax]`.
    pub fn from_array(area: [f64; 4]) -> Self {
        let [xmin, xmax, ymin, ymax] = area;
        Self::new(xmin, xmax, ymin, ymax)
    }
}
64
/// Circular obstacle described by its centre `(x, y)` and `radius`.
#[derive(Debug, Clone)]
pub struct CircleObstacle {
    /// X coordinate of the centre.
    pub x: f64,
    /// Y coordinate of the centre.
    pub y: f64,
    /// Obstacle radius; the planner inflates it by the robot radius when
    /// checking collisions.
    pub radius: f64,
}

impl CircleObstacle {
    /// Creates an obstacle centred at `(x, y)` with the given `radius`.
    pub fn new(x: f64, y: f64, radius: f64) -> Self {
        CircleObstacle { radius, y, x }
    }
}
78
/// Tuning parameters for the RRT planner.
#[derive(Debug, Clone)]
pub struct RRTConfig {
    /// Maximum distance a single tree extension may travel. It is also the
    /// distance from the goal within which the planner tries to connect to it.
    pub expand_dis: f64,
    /// Step length used to discretise each extension for collision checking.
    pub path_resolution: f64,
    /// Goal bias: the goal itself is used as the sample whenever a uniform
    /// integer drawn from `0..=100` does not exceed this value (roughly a
    /// percentage).
    pub goal_sample_rate: i32,
    /// Number of sampling iterations before the planner gives up.
    pub max_iter: usize,
    /// Clearance added to every obstacle radius during collision checks.
    pub robot_radius: f64,
}

impl Default for RRTConfig {
    /// `expand_dis = 3.0`, `path_resolution = 0.5`, `goal_sample_rate = 5`,
    /// `max_iter = 500`, `robot_radius = 0.8`.
    fn default() -> Self {
        RRTConfig {
            max_iter: 500,
            expand_dis: 3.0,
            robot_radius: 0.8,
            path_resolution: 0.5,
            goal_sample_rate: 5,
        }
    }
}
100
101/// RRT path planner
102pub struct RRTPlanner {
103    config: RRTConfig,
104    obstacles: Vec<CircleObstacle>,
105    play_area: Option<AreaBounds>,
106    rand_area: AreaBounds,
107    node_list: Vec<RRTNode>,
108    _start: RRTNode,
109    goal: RRTNode,
110}
111
112impl RRTPlanner {
113    pub fn new(
114        obstacles: Vec<CircleObstacle>,
115        rand_area: AreaBounds,
116        play_area: Option<AreaBounds>,
117        config: RRTConfig,
118    ) -> Self {
119        RRTPlanner {
120            config,
121            obstacles,
122            play_area,
123            rand_area,
124            node_list: Vec::new(),
125            _start: RRTNode::new(0.0, 0.0),
126            goal: RRTNode::new(0.0, 0.0),
127        }
128    }
129
130    pub fn from_obstacles(
131        obstacle_list: Vec<(f64, f64, f64)>,
132        rand_area: [f64; 2],
133        expand_dis: f64,
134        path_resolution: f64,
135        goal_sample_rate: i32,
136        max_iter: usize,
137        play_area: Option<[f64; 4]>,
138        robot_radius: f64,
139    ) -> Self {
140        let obstacles = obstacle_list
141            .into_iter()
142            .map(|(x, y, r)| CircleObstacle::new(x, y, r))
143            .collect();
144        let config = RRTConfig {
145            expand_dis,
146            path_resolution,
147            goal_sample_rate,
148            max_iter,
149            robot_radius,
150        };
151        let rand_bounds = AreaBounds::new(rand_area[0], rand_area[1], rand_area[0], rand_area[1]);
152        let play_bounds = play_area.map(AreaBounds::from_array);
153        Self::new(obstacles, rand_bounds, play_bounds, config)
154    }
155
156    pub fn planning(&mut self, start: [f64; 2], goal: [f64; 2]) -> Option<Vec<[f64; 2]>> {
157        let start_pt = Point2D::new(start[0], start[1]);
158        let goal_pt = Point2D::new(goal[0], goal[1]);
159        // Plan on `self` (not a clone) so `get_tree` exposes the search tree
160        // afterwards.
161        match self.plan_with_sampler(start_pt, goal_pt, |planner| planner.get_random_node()) {
162            Ok(path) => Some(path.points.iter().map(|p| [p.x, p.y]).collect()),
163            Err(_) => None,
164        }
165    }
166
167    pub fn get_tree(&self) -> &[RRTNode] {
168        &self.node_list
169    }
170    pub fn get_obstacles(&self) -> &[CircleObstacle] {
171        &self.obstacles
172    }
173
174    fn reset_search(&mut self, start: Point2D, goal: Point2D) {
175        self.node_list = vec![RRTNode::new(start.x, start.y)];
176        self._start = RRTNode::new(start.x, start.y);
177        self.goal = RRTNode::new(goal.x, goal.y);
178    }
179
180    fn steer(&self, from_node: &RRTNode, to_node: &RRTNode, extend_length: f64) -> RRTNode {
181        let mut new_node = RRTNode::new(from_node.x, from_node.y);
182        let (d, theta) = self.calc_distance_and_angle(from_node, to_node);
183        new_node.path_x = vec![new_node.x];
184        new_node.path_y = vec![new_node.y];
185        let extend_length = extend_length.min(d);
186        let n_expand = (extend_length / self.config.path_resolution).floor() as usize;
187        for _ in 0..n_expand {
188            new_node.x += self.config.path_resolution * theta.cos();
189            new_node.y += self.config.path_resolution * theta.sin();
190            new_node.path_x.push(new_node.x);
191            new_node.path_y.push(new_node.y);
192        }
193        let (d, _) = self.calc_distance_and_angle(&new_node, to_node);
194        if d <= self.config.path_resolution {
195            new_node.path_x.push(to_node.x);
196            new_node.path_y.push(to_node.y);
197            new_node.x = to_node.x;
198            new_node.y = to_node.y;
199        }
200        new_node
201    }
202
203    fn generate_final_course(&self, goal_ind: usize) -> Path2D {
204        let mut points = vec![self.goal.to_point()];
205        let mut node_index = Some(goal_ind);
206        while let Some(index) = node_index {
207            let node = &self.node_list[index];
208            points.push(node.to_point());
209            node_index = node.parent;
210        }
211        points.reverse();
212        Path2D::from_points(points)
213    }
214
215    fn calc_dist_to_goal(&self, x: f64, y: f64) -> f64 {
216        let dx = x - self.goal.x;
217        let dy = y - self.goal.y;
218        (dx * dx + dy * dy).sqrt()
219    }
220
221    fn get_random_node(&self) -> RRTNode {
222        let mut rng = rand::rng();
223        if rng.random_range(0..=100) > self.config.goal_sample_rate {
224            RRTNode::new(
225                rng.random_range(self.rand_area.xmin..=self.rand_area.xmax),
226                rng.random_range(self.rand_area.ymin..=self.rand_area.ymax),
227            )
228        } else {
229            RRTNode::new(self.goal.x, self.goal.y)
230        }
231    }
232
233    fn get_nearest_node_index(&self, rnd_node: &RRTNode) -> usize {
234        let mut min_dist = f64::INFINITY;
235        let mut min_ind = 0;
236        for (i, node) in self.node_list.iter().enumerate() {
237            let dist = (node.x - rnd_node.x).powi(2) + (node.y - rnd_node.y).powi(2);
238            if dist < min_dist {
239                min_dist = dist;
240                min_ind = i;
241            }
242        }
243        min_ind
244    }
245
246    fn check_if_outside_play_area(&self, node: &RRTNode) -> bool {
247        if let Some(ref play_area) = self.play_area {
248            if node.x < play_area.xmin
249                || node.x > play_area.xmax
250                || node.y < play_area.ymin
251                || node.y > play_area.ymax
252            {
253                return false;
254            }
255        }
256        true
257    }
258
259    fn check_collision(&self, node: &RRTNode) -> bool {
260        for obs in &self.obstacles {
261            for (&px, &py) in node.path_x.iter().zip(node.path_y.iter()) {
262                let dx = obs.x - px;
263                let dy = obs.y - py;
264                let d = (dx * dx + dy * dy).sqrt();
265                if d <= obs.radius + self.config.robot_radius {
266                    return false;
267                }
268            }
269        }
270        true
271    }
272
273    fn calc_distance_and_angle(&self, from_node: &RRTNode, to_node: &RRTNode) -> (f64, f64) {
274        let dx = to_node.x - from_node.x;
275        let dy = to_node.y - from_node.y;
276        ((dx * dx + dy * dy).sqrt(), dy.atan2(dx))
277    }
278
279    pub(crate) fn plan_with_sampler<F>(
280        &mut self,
281        start: Point2D,
282        goal: Point2D,
283        mut sample_node: F,
284    ) -> Result<Path2D, RoboticsError>
285    where
286        F: FnMut(&RRTPlanner) -> RRTNode,
287    {
288        self.reset_search(start, goal);
289        for _ in 0..self.config.max_iter {
290            let rnd_node = sample_node(self);
291            let nearest_ind = self.get_nearest_node_index(&rnd_node);
292            let nearest_node = self.node_list[nearest_ind].clone();
293            let new_node = self.steer(&nearest_node, &rnd_node, self.config.expand_dis);
294            if self.check_if_outside_play_area(&new_node) && self.check_collision(&new_node) {
295                let mut new_node = new_node;
296                new_node.parent = Some(nearest_ind);
297                self.node_list.push(new_node);
298                let last = self.node_list.last().unwrap();
299                if self.calc_dist_to_goal(last.x, last.y) <= self.config.expand_dis {
300                    let final_node = self.steer(last, &self.goal.clone(), self.config.expand_dis);
301                    if self.check_collision(&final_node) {
302                        return Ok(self.generate_final_course(self.node_list.len() - 1));
303                    }
304                }
305            }
306        }
307        Err(RoboticsError::PlanningError(
308            "RRT: Cannot find path within max iterations".to_string(),
309        ))
310    }
311}
312
313impl PathPlanner for RRTPlanner {
314    fn plan(&self, start: Point2D, goal: Point2D) -> Result<Path2D, RoboticsError> {
315        let mut planner = RRTPlanner {
316            config: self.config.clone(),
317            obstacles: self.obstacles.clone(),
318            play_area: self.play_area.clone(),
319            rand_area: self.rand_area.clone(),
320            node_list: vec![RRTNode::new(start.x, start.y)],
321            _start: RRTNode::new(start.x, start.y),
322            goal: RRTNode::new(goal.x, goal.y),
323        };
324        planner.plan_with_sampler(start, goal, |planner| planner.get_random_node())
325    }
326}
327
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts two floats agree to within 1e-12.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1.0e-12,
            "expected {expected}, got {actual}"
        );
    }

    /// Asserts a point matches `[x, y]` to within 1e-12 per coordinate.
    fn assert_point_close(actual: &Point2D, expected: [f64; 2]) {
        assert_close(actual.x, expected[0]);
        assert_close(actual.y, expected[1]);
    }

    /// Parses an `x,y` CSV fixture, skipping the header row and blank lines.
    fn parse_xy_fixture(csv: &str) -> Vec<[f64; 2]> {
        csv.lines()
            .skip(1)
            .filter(|line| !line.trim().is_empty())
            .map(|line| {
                let (x, y) = line
                    .split_once(',')
                    .expect("xy fixture rows must contain a comma");
                [x.parse().unwrap(), y.parse().unwrap()]
            })
            .collect()
    }

    fn create_test_planner() -> RRTPlanner {
        let obstacles = vec![
            CircleObstacle::new(5.0, 5.0, 1.0),
            CircleObstacle::new(3.0, 6.0, 2.0),
            CircleObstacle::new(7.0, 5.0, 2.0),
        ];
        let rand_area = AreaBounds::new(-2.0, 15.0, -2.0, 15.0);
        let config = RRTConfig {
            max_iter: 1000,
            ..Default::default()
        };
        RRTPlanner::new(obstacles, rand_area, None, config)
    }

    fn create_pythonrobotics_main_planner() -> RRTPlanner {
        let obstacles = vec![
            CircleObstacle::new(5.0, 5.0, 1.0),
            CircleObstacle::new(3.0, 6.0, 2.0),
            CircleObstacle::new(3.0, 8.0, 2.0),
            CircleObstacle::new(3.0, 10.0, 2.0),
            CircleObstacle::new(7.0, 5.0, 2.0),
            CircleObstacle::new(9.0, 5.0, 2.0),
            CircleObstacle::new(8.0, 10.0, 1.0),
        ];
        let rand_area = AreaBounds::new(-2.0, 15.0, -2.0, 15.0);
        let config = RRTConfig {
            robot_radius: 0.8,
            ..Default::default()
        };
        RRTPlanner::new(obstacles, rand_area, None, config)
    }

    #[test]
    fn test_rrt_finds_path() {
        let planner = create_test_planner();
        let start = [0.0, 0.0];
        let goal = [10.0, 10.0];
        let result = planner.plan(
            Point2D::new(start[0], start[1]),
            Point2D::new(goal[0], goal[1]),
        );
        // RRT is randomised, so failing to find a path within `max_iter` is an
        // acceptable outcome; any path that *is* returned must be well-formed.
        if let Ok(path) = result {
            let first = path.points.first().expect("a found path has waypoints");
            let last = path.points.last().expect("a found path has waypoints");
            assert_point_close(first, start);
            assert_point_close(last, goal);
            for point in path.points.iter() {
                for obs in planner.get_obstacles() {
                    let dist = ((point.x - obs.x).powi(2) + (point.y - obs.y).powi(2)).sqrt();
                    assert!(
                        dist > obs.radius,
                        "waypoint ({}, {}) lies inside an obstacle",
                        point.x,
                        point.y
                    );
                }
            }
        }
    }

    #[test]
    fn test_rrt_config_default() {
        let config = RRTConfig::default();
        assert_eq!(config.expand_dis, 3.0);
        assert_eq!(config.max_iter, 500);
    }

    #[test]
    fn test_rrt_upstream_test_rrt_short_goal_matches_pythonrobotics_reference() {
        let mut planner = create_pythonrobotics_main_planner();
        let start = Point2D::new(0.0, 0.0);
        let goal = Point2D::new(1.0, 1.0);
        let sample = [10.455_649_682_677_358, 11.942_970_283_541_907];
        let path = planner
            .plan_with_sampler(start, goal, |_| RRTNode::new(sample[0], sample[1]))
            .unwrap();

        assert_eq!(planner.node_list.len(), 2);
        assert_eq!(path.points.len(), 3);
        assert_point_close(&path.points[0], [0.0, 0.0]);
        assert_point_close(
            &path.points[1],
            [1.976_107_921_083_293_5, 2.257_210_110_785_406],
        );
        assert_point_close(&path.points[2], [1.0, 1.0]);
        assert_eq!(planner.node_list[1].parent, Some(0));
        assert_eq!(planner.node_list[1].path_x.len(), 7);
        assert_close(planner.node_list[1].path_x[1], 0.329_351_320_180_548_95);
        assert_close(planner.node_list[1].path_y[1], 0.376_201_685_130_900_95);
    }

    #[test]
    fn test_rrt_upstream_main_seeded_path_matches_pythonrobotics_reference() {
        let mut planner = create_pythonrobotics_main_planner();
        let start = Point2D::new(0.0, 0.0);
        let goal = Point2D::new(6.0, 10.0);
        let samples = parse_xy_fixture(include_str!("testdata/rrt_main_seed12345_samples.csv"));
        let expected_path = parse_xy_fixture(include_str!("testdata/rrt_main_seed12345_path.csv"));
        let mut sample_iter = samples.iter();

        let path = planner
            .plan_with_sampler(start, goal, |_| {
                let sample = sample_iter
                    .next()
                    .expect("python reference sample sequence exhausted");
                RRTNode::new(sample[0], sample[1])
            })
            .unwrap();

        assert_eq!(planner.node_list.len(), 88);
        assert_eq!(path.points.len(), expected_path.len());
        for (actual, expected) in path.points.iter().zip(expected_path.iter()) {
            assert_point_close(actual, *expected);
        }

        let expected_nodes = [
            (
                1,
                [1.976_107_921_083_293, 2.257_210_110_785_406],
                Some(0),
                7,
                [0.0, 0.329_351_320_180_549, 0.658_702_640_361_098],
                [0.0, 0.376_201_685_130_901, 0.752_403_370_261_802],
            ),
            (
                5,
                [1.229_513_438_270_946, 3.806_527_238_150_387],
                Some(1),
                5,
                [
                    1.976_107_921_083_293,
                    1.759_052_148_041_024,
                    1.541_996_374_998_755,
                ],
                [
                    2.257_210_110_785_406,
                    2.707_639_673_968_178,
                    3.158_069_237_150_95,
                ],
            ),
            (
                10,
                [-0.964_870_854_478_227, 5.566_933_984_574_426],
                Some(9),
                7,
                [
                    0.668_674_115_555_151,
                    0.358_283_476_598_378,
                    0.047_892_837_641_605,
                ],
                [
                    3.503_932_336_109_255,
                    3.895_924_238_127_659,
                    4.287_916_140_146_064,
                ],
            ),
            (
                20,
                [13.060_964_451_302_038, 12.199_474_225_398_257],
                Some(16),
                8,
                [
                    11.070_558_229_598_33,
                    11.395_810_000_597_592,
                    11.721_061_771_596_855,
                ],
                [
                    9.875_551_544_349_548,
                    10.255_303_154_565_809,
                    10.635_054_764_782_069,
                ],
            ),
            (
                87,
                [5.860_033_119_067_657, 10.721_216_347_248_003],
                Some(72),
                7,
                [
                    5.288_485_092_568_921,
                    5.383_743_096_985_377,
                    5.479_001_101_401_833,
                ],
                [
                    13.666_268_613_915_847,
                    13.175_426_569_471_206,
                    12.684_584_525_026_565,
                ],
            ),
        ];
        for (index, xy, parent, path_len, path_x3, path_y3) in expected_nodes {
            let node = &planner.node_list[index];
            assert_close(node.x, xy[0]);
            assert_close(node.y, xy[1]);
            assert_eq!(node.parent, parent);
            assert_eq!(node.path_x.len(), path_len);
            for (actual, expected) in node.path_x.iter().take(3).zip(path_x3.iter()) {
                assert_close(*actual, *expected);
            }
            for (actual, expected) in node.path_y.iter().take(3).zip(path_y3.iter()) {
                assert_close(*actual, *expected);
            }
        }
    }
}