Skip to main content

rust_robotics_planning/
rrt_connect.rs

1//! RRT-Connect path planning algorithm.
2//!
3//! RRT-Connect aggressively grows two trees and uses a CONNECT step that keeps
4//! extending the opposite tree toward a newly added node until trapped or
5//! reached.
6
7use rand::Rng;
8
9use rust_robotics_core::{Path2D, PathPlanner, Point2D, RoboticsError};
10
/// A circular obstacle described by its centre `(x, y)` and its `radius`.
#[derive(Debug, Clone)]
pub struct CircleObstacle {
    /// X coordinate of the circle centre.
    pub x: f64,
    /// Y coordinate of the circle centre.
    pub y: f64,
    /// Radius of the circle.
    pub radius: f64,
}

impl CircleObstacle {
    /// Builds an obstacle centred at `(x, y)` with the given `radius`.
    ///
    /// The values are stored exactly as passed in; nothing is validated here.
    pub fn new(x: f64, y: f64, radius: f64) -> Self {
        CircleObstacle { x, y, radius }
    }
}
24
/// Axis-aligned rectangle from which random samples are drawn.
#[derive(Debug, Clone)]
pub struct AreaBounds {
    /// Lower bound on the x axis.
    pub xmin: f64,
    /// Upper bound on the x axis.
    pub xmax: f64,
    /// Lower bound on the y axis.
    pub ymin: f64,
    /// Upper bound on the y axis.
    pub ymax: f64,
}

impl AreaBounds {
    /// Builds the bounds from the x range followed by the y range.
    ///
    /// The values are stored as given; the ordering of each min/max pair is
    /// not checked here.
    pub fn new(xmin: f64, xmax: f64, ymin: f64, ymax: f64) -> Self {
        AreaBounds { xmin, xmax, ymin, ymax }
    }
}
44
/// Tunable parameters of the RRT-Connect planner.
#[derive(Debug, Clone)]
pub struct RRTConnectConfig {
    /// Longest distance a tree may grow in a single extension step [m].
    pub expand_dis: f64,
    /// Spacing between the points probed for collisions along a step [m].
    pub path_resolution: f64,
    /// Upper limit on the number of planning iterations.
    pub max_iter: usize,
    /// Radius of the robot, added to every obstacle radius in collision checks [m].
    pub robot_radius: f64,
}

impl Default for RRTConnectConfig {
    /// Returns the stock configuration: 3.0 m steps probed every 0.5 m,
    /// 500 iterations and a 0.8 m robot radius.
    fn default() -> Self {
        RRTConnectConfig {
            max_iter: 500,
            robot_radius: 0.8,
            path_resolution: 0.5,
            expand_dis: 3.0,
        }
    }
}
68
69#[derive(Debug, Clone)]
70struct RRTNode {
71    x: f64,
72    y: f64,
73    parent: Option<usize>,
74}
75
76impl RRTNode {
77    fn new(x: f64, y: f64) -> Self {
78        Self { x, y, parent: None }
79    }
80
81    fn to_point(&self) -> Point2D {
82        Point2D::new(self.x, self.y)
83    }
84}
85
/// Outcome of an extension step that managed to add a node to a tree.
///
/// A failed extension is not represented here; the extension helpers return
/// `None` in that case.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ExtendStatus {
    /// A node was added, but it does not coincide with the target yet.
    Advanced,
    /// The added node sits exactly on the target.
    Reached,
}
91
/// How the second tree is grown toward the node just added to the first tree.
// `ExtendOnce` is only constructed by the unit tests in this file, so a
// non-test build would otherwise report it as dead code.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ConnectPolicy {
    /// Keep extending until the target is reached or the extension fails.
    AggressiveConnect,
    /// Perform a single extension step per iteration.
    ExtendOnce,
}
98
99/// RRT-Connect planner.
100pub struct RRTConnectPlanner {
101    config: RRTConnectConfig,
102    obstacles: Vec<CircleObstacle>,
103    rand_area: AreaBounds,
104}
105
106impl RRTConnectPlanner {
107    pub fn new(
108        obstacles: Vec<CircleObstacle>,
109        rand_area: AreaBounds,
110        config: RRTConnectConfig,
111    ) -> Self {
112        Self {
113            config,
114            obstacles,
115            rand_area,
116        }
117    }
118
119    fn get_random_node(&self) -> RRTNode {
120        let mut rng = rand::rng();
121        RRTNode::new(
122            rng.random_range(self.rand_area.xmin..=self.rand_area.xmax),
123            rng.random_range(self.rand_area.ymin..=self.rand_area.ymax),
124        )
125    }
126
127    fn dist(ax: f64, ay: f64, bx: f64, by: f64) -> f64 {
128        let dx = ax - bx;
129        let dy = ay - by;
130        (dx * dx + dy * dy).sqrt()
131    }
132
133    fn get_nearest_node_index(tree: &[RRTNode], target: &RRTNode) -> usize {
134        tree.iter()
135            .enumerate()
136            .map(|(i, node)| {
137                let dx = node.x - target.x;
138                let dy = node.y - target.y;
139                (i, dx * dx + dy * dy)
140            })
141            .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
142            .map(|(i, _)| i)
143            .unwrap_or(0)
144    }
145
146    fn point_in_collision(&self, x: f64, y: f64) -> bool {
147        self.obstacles
148            .iter()
149            .any(|obs| Self::dist(x, y, obs.x, obs.y) <= obs.radius + self.config.robot_radius)
150    }
151
152    fn steer(
153        &self,
154        from: &RRTNode,
155        to: &RRTNode,
156        parent_idx: usize,
157    ) -> Option<(RRTNode, ExtendStatus)> {
158        let dx = to.x - from.x;
159        let dy = to.y - from.y;
160        let distance = (dx * dx + dy * dy).sqrt();
161        if distance < f64::EPSILON {
162            return None;
163        }
164
165        let theta = dy.atan2(dx);
166        let step = distance.min(self.config.expand_dis);
167        let n_steps = (step / self.config.path_resolution).floor() as usize;
168
169        let mut cx = from.x;
170        let mut cy = from.y;
171        for _ in 0..n_steps {
172            cx += self.config.path_resolution * theta.cos();
173            cy += self.config.path_resolution * theta.sin();
174            if self.point_in_collision(cx, cy) {
175                return None;
176            }
177        }
178
179        let mut status = ExtendStatus::Advanced;
180        if Self::dist(cx, cy, to.x, to.y) <= self.config.path_resolution {
181            cx = to.x;
182            cy = to.y;
183            status = ExtendStatus::Reached;
184        }
185        if self.point_in_collision(cx, cy) {
186            return None;
187        }
188
189        Some((
190            RRTNode {
191                x: cx,
192                y: cy,
193                parent: Some(parent_idx),
194            },
195            status,
196        ))
197    }
198
199    fn extend_tree(
200        &self,
201        tree: &mut Vec<RRTNode>,
202        target: &RRTNode,
203    ) -> Option<(usize, ExtendStatus)> {
204        let nearest_idx = Self::get_nearest_node_index(tree, target);
205        let nearest = tree[nearest_idx].clone();
206        let (node, status) = self.steer(&nearest, target, nearest_idx)?;
207        tree.push(node);
208        Some((tree.len() - 1, status))
209    }
210
211    fn connect_tree(
212        &self,
213        tree: &mut Vec<RRTNode>,
214        target: &RRTNode,
215        policy: ConnectPolicy,
216    ) -> Option<(usize, ExtendStatus)> {
217        match policy {
218            ConnectPolicy::ExtendOnce => self.extend_tree(tree, target),
219            ConnectPolicy::AggressiveConnect => {
220                let mut latest = None;
221                while let Some((new_idx, status)) = self.extend_tree(tree, target) {
222                    latest = Some((new_idx, status));
223                    if status == ExtendStatus::Reached {
224                        break;
225                    }
226                }
227                latest
228            }
229        }
230    }
231
232    fn trace_path(tree: &[RRTNode], idx: usize) -> Vec<Point2D> {
233        let mut path = Vec::new();
234        let mut current = Some(idx);
235        while let Some(i) = current {
236            path.push(tree[i].to_point());
237            current = tree[i].parent;
238        }
239        path.reverse();
240        path
241    }
242
243    fn reconstruct_path(
244        &self,
245        tree_a: &[RRTNode],
246        idx_a: usize,
247        tree_b: &[RRTNode],
248        idx_b: usize,
249        a_is_start: bool,
250    ) -> Path2D {
251        let mut path_a = Self::trace_path(tree_a, idx_a);
252        let mut path_b = Self::trace_path(tree_b, idx_b);
253
254        if a_is_start {
255            path_b.reverse();
256            path_a.extend(path_b.into_iter().skip(1));
257            Path2D::from_points(path_a)
258        } else {
259            path_a.reverse();
260            path_b.extend(path_a.into_iter().skip(1));
261            Path2D::from_points(path_b)
262        }
263    }
264
265    fn run_with_sampler<F>(
266        &self,
267        start: Point2D,
268        goal: Point2D,
269        policy: ConnectPolicy,
270        mut sample: F,
271    ) -> Result<(Path2D, usize), RoboticsError>
272    where
273        F: FnMut() -> RRTNode,
274    {
275        let mut tree_a = vec![RRTNode::new(start.x, start.y)];
276        let mut tree_b = vec![RRTNode::new(goal.x, goal.y)];
277        let mut a_is_start = true;
278
279        for iter in 0..self.config.max_iter {
280            let rnd = sample();
281            if let Some((new_idx_a, _)) = self.extend_tree(&mut tree_a, &rnd) {
282                let target = tree_a[new_idx_a].clone();
283                if let Some((new_idx_b, status)) = self.connect_tree(&mut tree_b, &target, policy) {
284                    if status == ExtendStatus::Reached {
285                        let path = self
286                            .reconstruct_path(&tree_a, new_idx_a, &tree_b, new_idx_b, a_is_start);
287                        return Ok((path, iter + 1));
288                    }
289                }
290            }
291
292            std::mem::swap(&mut tree_a, &mut tree_b);
293            a_is_start = !a_is_start;
294        }
295
296        Err(RoboticsError::PlanningError(
297            "RRTConnect: Cannot find path within max iterations".to_string(),
298        ))
299    }
300
301    fn run(
302        &self,
303        start: Point2D,
304        goal: Point2D,
305        policy: ConnectPolicy,
306    ) -> Result<(Path2D, usize), RoboticsError> {
307        self.run_with_sampler(start, goal, policy, || self.get_random_node())
308    }
309}
310
311impl PathPlanner for RRTConnectPlanner {
312    fn plan(&self, start: Point2D, goal: Point2D) -> Result<Path2D, RoboticsError> {
313        self.run(start, goal, ConnectPolicy::AggressiveConnect)
314            .map(|(path, _)| path)
315    }
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every waypoint of `path` keeps more than
    /// `obstacle radius + robot_radius` away from each obstacle centre.
    ///
    /// Only the waypoints are inspected, not the segments between them.
    fn assert_collision_free(path: &Path2D, obstacles: &[CircleObstacle], robot_radius: f64) {
        for waypoint in &path.points {
            for obstacle in obstacles {
                let clearance = obstacle.radius + robot_radius;
                let d =
                    ((waypoint.x - obstacle.x).powi(2) + (waypoint.y - obstacle.y).powi(2)).sqrt();
                assert!(
                    d > clearance,
                    "path collides with obstacle at ({}, {})",
                    obstacle.x,
                    obstacle.y
                );
            }
        }
    }

    /// An obstacle-free world must yield a path with at least two waypoints.
    #[test]
    fn test_rrt_connect_finds_path_no_obstacles() {
        let area = AreaBounds::new(-2.0, 15.0, -2.0, 15.0);
        let planner = RRTConnectPlanner::new(Vec::new(), area, RRTConnectConfig::default());

        let start = Point2D::new(0.0, 0.0);
        let goal = Point2D::new(10.0, 10.0);
        let result = planner.plan(start, goal);
        assert!(
            result.is_ok(),
            "expected a path but got: {:?}",
            result.err()
        );

        let path = result.unwrap();
        assert!(path.points.len() >= 2);
    }

    /// A cluttered world must yield a path whose waypoints avoid every obstacle.
    #[test]
    fn test_rrt_connect_finds_path_with_obstacles() {
        let obstacles: Vec<CircleObstacle> = [
            (5.0, 5.0, 1.0),
            (3.0, 6.0, 2.0),
            (3.0, 8.0, 2.0),
            (3.0, 10.0, 2.0),
            (7.0, 5.0, 2.0),
            (9.0, 5.0, 2.0),
            (8.0, 10.0, 1.0),
        ]
        .into_iter()
        .map(|(x, y, radius)| CircleObstacle::new(x, y, radius))
        .collect();

        let config = RRTConnectConfig {
            max_iter: 2000,
            ..Default::default()
        };
        let planner = RRTConnectPlanner::new(
            obstacles.clone(),
            AreaBounds::new(-2.0, 15.0, -2.0, 15.0),
            config,
        );

        let result = planner.plan(Point2D::new(0.0, 0.0), Point2D::new(6.0, 10.0));
        assert!(
            result.is_ok(),
            "expected a path but got: {:?}",
            result.err()
        );

        let path = result.unwrap();
        assert_collision_free(&path, &obstacles, RRTConnectConfig::default().robot_radius);
    }

    /// With the same fixed sample sequence, the aggressive connect policy must
    /// finish in fewer iterations than single-step extension.
    #[test]
    fn test_rrt_connect_requires_fewer_iterations_than_extend_only() {
        let planner = RRTConnectPlanner::new(
            vec![],
            AreaBounds::new(-5.0, 20.0, -5.0, 20.0),
            RRTConnectConfig::default(),
        );
        let start = Point2D::new(0.0, 0.0);
        let goal = Point2D::new(12.0, 0.0);
        let samples = [[6.0, 0.0], [6.0, 0.0], [6.0, 0.0]];

        // Replays `samples` in order (repeating the last one once exhausted)
        // and reports how many iterations the given policy needed.
        let iterations_for = |policy: ConnectPolicy, failure: &str| -> usize {
            let mut cursor = 0usize;
            let (_, iterations) = planner
                .run_with_sampler(start, goal, policy, || {
                    let [x, y] = samples[cursor.min(samples.len() - 1)];
                    cursor += 1;
                    RRTNode::new(x, y)
                })
                .expect(failure);
            iterations
        };

        let connect_iters = iterations_for(
            ConnectPolicy::AggressiveConnect,
            "connect should find a path",
        );
        let extend_iters =
            iterations_for(ConnectPolicy::ExtendOnce, "extend-only should find a path");

        assert!(connect_iters < extend_iters);
    }
}