1use rand::Rng;
8
9use rust_robotics_core::{Path2D, PathPlanner, Point2D, RoboticsError};
10
/// A circular obstacle in the plane, given by its centre and radius.
#[derive(Clone, Debug)]
pub struct CircleObstacle {
    /// X coordinate of the centre.
    pub x: f64,
    /// Y coordinate of the centre.
    pub y: f64,
    /// Radius of the obstacle itself; the planner adds the robot radius on top
    /// of this when collision checking.
    pub radius: f64,
}

impl CircleObstacle {
    /// Builds an obstacle centred at `(x, y)` with the given `radius`.
    pub fn new(x: f64, y: f64, radius: f64) -> Self {
        CircleObstacle { x, y, radius }
    }
}
24
/// Inclusive axis-aligned rectangle from which random samples are drawn.
///
/// No validation is performed on construction.
#[derive(Clone, Debug)]
pub struct AreaBounds {
    /// Lower x limit (inclusive).
    pub xmin: f64,
    /// Upper x limit (inclusive).
    pub xmax: f64,
    /// Lower y limit (inclusive).
    pub ymin: f64,
    /// Upper y limit (inclusive).
    pub ymax: f64,
}

impl AreaBounds {
    /// Builds the rectangle `[xmin, xmax] x [ymin, ymax]`.
    pub fn new(xmin: f64, xmax: f64, ymin: f64, ymax: f64) -> Self {
        AreaBounds { xmin, xmax, ymin, ymax }
    }
}
44
/// Tuning parameters for [`RRTConnectPlanner`].
#[derive(Clone, Debug)]
pub struct RRTConnectConfig {
    /// Upper bound on how far a single extension may travel toward its target.
    pub expand_dis: f64,
    /// Spacing of the collision-checked points along an extension; a target
    /// closer than this to the last stepped point is snapped to.
    pub path_resolution: f64,
    /// Number of sampling iterations before planning gives up.
    pub max_iter: usize,
    /// Amount added to every obstacle radius during collision checks.
    pub robot_radius: f64,
}

impl Default for RRTConnectConfig {
    /// `expand_dis = 3.0`, `path_resolution = 0.5`, `max_iter = 500`,
    /// `robot_radius = 0.8`.
    fn default() -> Self {
        RRTConnectConfig {
            max_iter: 500,
            expand_dis: 3.0,
            robot_radius: 0.8,
            path_resolution: 0.5,
        }
    }
}
68
69#[derive(Debug, Clone)]
70struct RRTNode {
71 x: f64,
72 y: f64,
73 parent: Option<usize>,
74}
75
76impl RRTNode {
77 fn new(x: f64, y: f64) -> Self {
78 Self { x, y, parent: None }
79 }
80
81 fn to_point(&self) -> Point2D {
82 Point2D::new(self.x, self.y)
83 }
84}
85
/// Result of one successful extension step toward a target.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum ExtendStatus {
    /// A new node was added but it stops short of the target.
    Advanced,
    /// The new node sits exactly on the target.
    Reached,
}
91
/// How the second tree is grown toward the node just added to the first tree.
// `ExtendOnce` is only constructed by the test module (this enum is private), so
// it is dead code in non-test builds. Scope the allow to those builds only so
// real dead code is still reported when compiling tests.
#[cfg_attr(not(test), allow(dead_code))]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ConnectPolicy {
    /// Keep extending toward the target until it is reached or an extension
    /// fails (the greedy "connect" step of RRT-Connect).
    AggressiveConnect,
    /// Perform a single extension step per iteration.
    ExtendOnce,
}
98
99pub struct RRTConnectPlanner {
101 config: RRTConnectConfig,
102 obstacles: Vec<CircleObstacle>,
103 rand_area: AreaBounds,
104}
105
106impl RRTConnectPlanner {
107 pub fn new(
108 obstacles: Vec<CircleObstacle>,
109 rand_area: AreaBounds,
110 config: RRTConnectConfig,
111 ) -> Self {
112 Self {
113 config,
114 obstacles,
115 rand_area,
116 }
117 }
118
119 fn get_random_node(&self) -> RRTNode {
120 let mut rng = rand::rng();
121 RRTNode::new(
122 rng.random_range(self.rand_area.xmin..=self.rand_area.xmax),
123 rng.random_range(self.rand_area.ymin..=self.rand_area.ymax),
124 )
125 }
126
127 fn dist(ax: f64, ay: f64, bx: f64, by: f64) -> f64 {
128 let dx = ax - bx;
129 let dy = ay - by;
130 (dx * dx + dy * dy).sqrt()
131 }
132
133 fn get_nearest_node_index(tree: &[RRTNode], target: &RRTNode) -> usize {
134 tree.iter()
135 .enumerate()
136 .map(|(i, node)| {
137 let dx = node.x - target.x;
138 let dy = node.y - target.y;
139 (i, dx * dx + dy * dy)
140 })
141 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
142 .map(|(i, _)| i)
143 .unwrap_or(0)
144 }
145
146 fn point_in_collision(&self, x: f64, y: f64) -> bool {
147 self.obstacles
148 .iter()
149 .any(|obs| Self::dist(x, y, obs.x, obs.y) <= obs.radius + self.config.robot_radius)
150 }
151
152 fn steer(
153 &self,
154 from: &RRTNode,
155 to: &RRTNode,
156 parent_idx: usize,
157 ) -> Option<(RRTNode, ExtendStatus)> {
158 let dx = to.x - from.x;
159 let dy = to.y - from.y;
160 let distance = (dx * dx + dy * dy).sqrt();
161 if distance < f64::EPSILON {
162 return None;
163 }
164
165 let theta = dy.atan2(dx);
166 let step = distance.min(self.config.expand_dis);
167 let n_steps = (step / self.config.path_resolution).floor() as usize;
168
169 let mut cx = from.x;
170 let mut cy = from.y;
171 for _ in 0..n_steps {
172 cx += self.config.path_resolution * theta.cos();
173 cy += self.config.path_resolution * theta.sin();
174 if self.point_in_collision(cx, cy) {
175 return None;
176 }
177 }
178
179 let mut status = ExtendStatus::Advanced;
180 if Self::dist(cx, cy, to.x, to.y) <= self.config.path_resolution {
181 cx = to.x;
182 cy = to.y;
183 status = ExtendStatus::Reached;
184 }
185 if self.point_in_collision(cx, cy) {
186 return None;
187 }
188
189 Some((
190 RRTNode {
191 x: cx,
192 y: cy,
193 parent: Some(parent_idx),
194 },
195 status,
196 ))
197 }
198
199 fn extend_tree(
200 &self,
201 tree: &mut Vec<RRTNode>,
202 target: &RRTNode,
203 ) -> Option<(usize, ExtendStatus)> {
204 let nearest_idx = Self::get_nearest_node_index(tree, target);
205 let nearest = tree[nearest_idx].clone();
206 let (node, status) = self.steer(&nearest, target, nearest_idx)?;
207 tree.push(node);
208 Some((tree.len() - 1, status))
209 }
210
211 fn connect_tree(
212 &self,
213 tree: &mut Vec<RRTNode>,
214 target: &RRTNode,
215 policy: ConnectPolicy,
216 ) -> Option<(usize, ExtendStatus)> {
217 match policy {
218 ConnectPolicy::ExtendOnce => self.extend_tree(tree, target),
219 ConnectPolicy::AggressiveConnect => {
220 let mut latest = None;
221 while let Some((new_idx, status)) = self.extend_tree(tree, target) {
222 latest = Some((new_idx, status));
223 if status == ExtendStatus::Reached {
224 break;
225 }
226 }
227 latest
228 }
229 }
230 }
231
232 fn trace_path(tree: &[RRTNode], idx: usize) -> Vec<Point2D> {
233 let mut path = Vec::new();
234 let mut current = Some(idx);
235 while let Some(i) = current {
236 path.push(tree[i].to_point());
237 current = tree[i].parent;
238 }
239 path.reverse();
240 path
241 }
242
243 fn reconstruct_path(
244 &self,
245 tree_a: &[RRTNode],
246 idx_a: usize,
247 tree_b: &[RRTNode],
248 idx_b: usize,
249 a_is_start: bool,
250 ) -> Path2D {
251 let mut path_a = Self::trace_path(tree_a, idx_a);
252 let mut path_b = Self::trace_path(tree_b, idx_b);
253
254 if a_is_start {
255 path_b.reverse();
256 path_a.extend(path_b.into_iter().skip(1));
257 Path2D::from_points(path_a)
258 } else {
259 path_a.reverse();
260 path_b.extend(path_a.into_iter().skip(1));
261 Path2D::from_points(path_b)
262 }
263 }
264
265 fn run_with_sampler<F>(
266 &self,
267 start: Point2D,
268 goal: Point2D,
269 policy: ConnectPolicy,
270 mut sample: F,
271 ) -> Result<(Path2D, usize), RoboticsError>
272 where
273 F: FnMut() -> RRTNode,
274 {
275 let mut tree_a = vec![RRTNode::new(start.x, start.y)];
276 let mut tree_b = vec![RRTNode::new(goal.x, goal.y)];
277 let mut a_is_start = true;
278
279 for iter in 0..self.config.max_iter {
280 let rnd = sample();
281 if let Some((new_idx_a, _)) = self.extend_tree(&mut tree_a, &rnd) {
282 let target = tree_a[new_idx_a].clone();
283 if let Some((new_idx_b, status)) = self.connect_tree(&mut tree_b, &target, policy) {
284 if status == ExtendStatus::Reached {
285 let path = self
286 .reconstruct_path(&tree_a, new_idx_a, &tree_b, new_idx_b, a_is_start);
287 return Ok((path, iter + 1));
288 }
289 }
290 }
291
292 std::mem::swap(&mut tree_a, &mut tree_b);
293 a_is_start = !a_is_start;
294 }
295
296 Err(RoboticsError::PlanningError(
297 "RRTConnect: Cannot find path within max iterations".to_string(),
298 ))
299 }
300
301 fn run(
302 &self,
303 start: Point2D,
304 goal: Point2D,
305 policy: ConnectPolicy,
306 ) -> Result<(Path2D, usize), RoboticsError> {
307 self.run_with_sampler(start, goal, policy, || self.get_random_node())
308 }
309}
310
311impl PathPlanner for RRTConnectPlanner {
312 fn plan(&self, start: Point2D, goal: Point2D) -> Result<Path2D, RoboticsError> {
313 self.run(start, goal, ConnectPolicy::AggressiveConnect)
314 .map(|(path, _)| path)
315 }
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    /// Panics if any waypoint of `path` is within `radius + robot_radius` of an
    /// obstacle centre (touching counts as a collision).
    fn assert_collision_free(path: &Path2D, obstacles: &[CircleObstacle], robot_radius: f64) {
        for waypoint in &path.points {
            for obstacle in obstacles {
                let gap_x = waypoint.x - obstacle.x;
                let gap_y = waypoint.y - obstacle.y;
                let clearance = (gap_x.powi(2) + gap_y.powi(2)).sqrt();
                assert!(
                    clearance > obstacle.radius + robot_radius,
                    "path collides with obstacle at ({}, {})",
                    obstacle.x,
                    obstacle.y
                );
            }
        }
    }

    /// Returns a deterministic sampler that replays `script` in order and keeps
    /// repeating its last entry once the script runs out.
    fn scripted_sampler(script: &[[f64; 2]]) -> impl FnMut() -> RRTNode + '_ {
        let mut cursor = 0usize;
        move || {
            let [x, y] = script[cursor.min(script.len() - 1)];
            cursor += 1;
            RRTNode::new(x, y)
        }
    }

    #[test]
    fn test_rrt_connect_finds_path_no_obstacles() {
        let planner = RRTConnectPlanner::new(
            Vec::new(),
            AreaBounds::new(-2.0, 15.0, -2.0, 15.0),
            RRTConnectConfig::default(),
        );
        let path = planner
            .plan(Point2D::new(0.0, 0.0), Point2D::new(10.0, 10.0))
            .unwrap_or_else(|err| panic!("expected a path but got: {:?}", Some(err)));
        assert!(path.points.len() >= 2);
    }

    #[test]
    fn test_rrt_connect_finds_path_with_obstacles() {
        let obstacles = vec![
            CircleObstacle::new(5.0, 5.0, 1.0),
            CircleObstacle::new(3.0, 6.0, 2.0),
            CircleObstacle::new(3.0, 8.0, 2.0),
            CircleObstacle::new(3.0, 10.0, 2.0),
            CircleObstacle::new(7.0, 5.0, 2.0),
            CircleObstacle::new(9.0, 5.0, 2.0),
            CircleObstacle::new(8.0, 10.0, 1.0),
        ];
        let config = RRTConnectConfig {
            max_iter: 2000,
            ..Default::default()
        };
        let robot_radius = config.robot_radius;
        let planner = RRTConnectPlanner::new(
            obstacles.clone(),
            AreaBounds::new(-2.0, 15.0, -2.0, 15.0),
            config,
        );

        let path = planner
            .plan(Point2D::new(0.0, 0.0), Point2D::new(6.0, 10.0))
            .unwrap_or_else(|err| panic!("expected a path but got: {:?}", Some(err)));
        assert_collision_free(&path, &obstacles, robot_radius);
    }

    #[test]
    fn test_rrt_connect_requires_fewer_iterations_than_extend_only() {
        let planner = RRTConnectPlanner::new(
            Vec::new(),
            AreaBounds::new(-5.0, 20.0, -5.0, 20.0),
            RRTConnectConfig::default(),
        );
        let start = Point2D::new(0.0, 0.0);
        let goal = Point2D::new(12.0, 0.0);
        let script = [[6.0, 0.0]; 3];

        let (_, connect_iters) = planner
            .run_with_sampler(
                start,
                goal,
                ConnectPolicy::AggressiveConnect,
                scripted_sampler(&script),
            )
            .expect("connect should find a path");
        let (_, extend_iters) = planner
            .run_with_sampler(start, goal, ConnectPolicy::ExtendOnce, scripted_sampler(&script))
            .expect("extend-only should find a path");

        assert!(connect_iters < extend_iters);
    }
}