1#![allow(clippy::too_many_arguments)]
2
3use rand::Rng;
9
10use rust_robotics_core::{Path2D, PathPlanner, Point2D, RoboticsError};
11
12#[derive(Debug, Clone)]
14pub struct RRTNode {
15 pub x: f64,
16 pub y: f64,
17 pub path_x: Vec<f64>,
18 pub path_y: Vec<f64>,
19 pub parent: Option<usize>,
20}
21
22impl RRTNode {
23 pub fn new(x: f64, y: f64) -> Self {
24 RRTNode {
25 x,
26 y,
27 path_x: Vec::new(),
28 path_y: Vec::new(),
29 parent: None,
30 }
31 }
32 pub fn to_point(&self) -> Point2D {
33 Point2D::new(self.x, self.y)
34 }
35}
36
/// Axis-aligned rectangle `[xmin, xmax] x [ymin, ymax]`.
///
/// No ordering between the minimum and maximum limits is enforced here.
#[derive(Debug, Clone)]
pub struct AreaBounds {
    pub xmin: f64,
    pub xmax: f64,
    pub ymin: f64,
    pub ymax: f64,
}

impl AreaBounds {
    /// Creates bounds from the four individual limits.
    pub fn new(xmin: f64, xmax: f64, ymin: f64, ymax: f64) -> Self {
        Self {
            xmin,
            xmax,
            ymin,
            ymax,
        }
    }

    /// Creates bounds from an array ordered as `[xmin, xmax, ymin, ymax]`.
    pub fn from_array(area: [f64; 4]) -> Self {
        let [xmin, xmax, ymin, ymax] = area;
        Self::new(xmin, xmax, ymin, ymax)
    }
}
64
/// A circular obstacle centred at `(x, y)`.
///
/// `radius` is stored as given; it is not validated here.
#[derive(Debug, Clone)]
pub struct CircleObstacle {
    pub x: f64,
    pub y: f64,
    pub radius: f64,
}

impl CircleObstacle {
    /// Creates an obstacle of the given `radius` centred at `(x, y)`.
    pub fn new(x: f64, y: f64, radius: f64) -> Self {
        CircleObstacle { radius, y, x }
    }
}
78
/// Tuning parameters for the RRT planner.
#[derive(Debug, Clone)]
pub struct RRTConfig {
    /// Maximum distance a single expansion may grow the tree.
    pub expand_dis: f64,
    /// Step length used to discretise each expansion for collision checking.
    pub path_resolution: f64,
    /// Goal-bias threshold, compared against a uniform integer drawn from `0..=100`.
    pub goal_sample_rate: i32,
    /// Maximum number of sampling iterations before giving up.
    pub max_iter: usize,
    /// Robot radius, added to each obstacle radius during collision checks.
    pub robot_radius: f64,
}

impl Default for RRTConfig {
    /// `expand_dis = 3.0`, `path_resolution = 0.5`, `goal_sample_rate = 5`,
    /// `max_iter = 500`, `robot_radius = 0.8`.
    fn default() -> Self {
        let expand_dis = 3.0;
        let path_resolution = 0.5;
        let goal_sample_rate = 5;
        let max_iter = 500;
        let robot_radius = 0.8;
        RRTConfig {
            expand_dis,
            path_resolution,
            goal_sample_rate,
            max_iter,
            robot_radius,
        }
    }
}
100
/// Rapidly-exploring Random Tree (RRT) planner over circular obstacles.
pub struct RRTPlanner {
    /// Tuning parameters for the search.
    config: RRTConfig,
    /// Obstacles every expansion is checked against.
    obstacles: Vec<CircleObstacle>,
    /// Optional bounds that accepted nodes must lie within; `None` accepts everywhere.
    play_area: Option<AreaBounds>,
    /// Region the default sampler draws random points from.
    rand_area: AreaBounds,
    /// Search tree; after a search is reset, index 0 holds the start node.
    node_list: Vec<RRTNode>,
    /// Start node of the most recent search (assigned but not read in this file).
    _start: RRTNode,
    /// Goal node of the most recent search.
    goal: RRTNode,
}
111
112impl RRTPlanner {
113 pub fn new(
114 obstacles: Vec<CircleObstacle>,
115 rand_area: AreaBounds,
116 play_area: Option<AreaBounds>,
117 config: RRTConfig,
118 ) -> Self {
119 RRTPlanner {
120 config,
121 obstacles,
122 play_area,
123 rand_area,
124 node_list: Vec::new(),
125 _start: RRTNode::new(0.0, 0.0),
126 goal: RRTNode::new(0.0, 0.0),
127 }
128 }
129
130 pub fn from_obstacles(
131 obstacle_list: Vec<(f64, f64, f64)>,
132 rand_area: [f64; 2],
133 expand_dis: f64,
134 path_resolution: f64,
135 goal_sample_rate: i32,
136 max_iter: usize,
137 play_area: Option<[f64; 4]>,
138 robot_radius: f64,
139 ) -> Self {
140 let obstacles = obstacle_list
141 .into_iter()
142 .map(|(x, y, r)| CircleObstacle::new(x, y, r))
143 .collect();
144 let config = RRTConfig {
145 expand_dis,
146 path_resolution,
147 goal_sample_rate,
148 max_iter,
149 robot_radius,
150 };
151 let rand_bounds = AreaBounds::new(rand_area[0], rand_area[1], rand_area[0], rand_area[1]);
152 let play_bounds = play_area.map(AreaBounds::from_array);
153 Self::new(obstacles, rand_bounds, play_bounds, config)
154 }
155
156 pub fn planning(&mut self, start: [f64; 2], goal: [f64; 2]) -> Option<Vec<[f64; 2]>> {
157 let start_pt = Point2D::new(start[0], start[1]);
158 let goal_pt = Point2D::new(goal[0], goal[1]);
159 match self.plan_with_sampler(start_pt, goal_pt, |planner| planner.get_random_node()) {
162 Ok(path) => Some(path.points.iter().map(|p| [p.x, p.y]).collect()),
163 Err(_) => None,
164 }
165 }
166
167 pub fn get_tree(&self) -> &[RRTNode] {
168 &self.node_list
169 }
170 pub fn get_obstacles(&self) -> &[CircleObstacle] {
171 &self.obstacles
172 }
173
174 fn reset_search(&mut self, start: Point2D, goal: Point2D) {
175 self.node_list = vec![RRTNode::new(start.x, start.y)];
176 self._start = RRTNode::new(start.x, start.y);
177 self.goal = RRTNode::new(goal.x, goal.y);
178 }
179
180 fn steer(&self, from_node: &RRTNode, to_node: &RRTNode, extend_length: f64) -> RRTNode {
181 let mut new_node = RRTNode::new(from_node.x, from_node.y);
182 let (d, theta) = self.calc_distance_and_angle(from_node, to_node);
183 new_node.path_x = vec![new_node.x];
184 new_node.path_y = vec![new_node.y];
185 let extend_length = extend_length.min(d);
186 let n_expand = (extend_length / self.config.path_resolution).floor() as usize;
187 for _ in 0..n_expand {
188 new_node.x += self.config.path_resolution * theta.cos();
189 new_node.y += self.config.path_resolution * theta.sin();
190 new_node.path_x.push(new_node.x);
191 new_node.path_y.push(new_node.y);
192 }
193 let (d, _) = self.calc_distance_and_angle(&new_node, to_node);
194 if d <= self.config.path_resolution {
195 new_node.path_x.push(to_node.x);
196 new_node.path_y.push(to_node.y);
197 new_node.x = to_node.x;
198 new_node.y = to_node.y;
199 }
200 new_node
201 }
202
203 fn generate_final_course(&self, goal_ind: usize) -> Path2D {
204 let mut points = vec![self.goal.to_point()];
205 let mut node_index = Some(goal_ind);
206 while let Some(index) = node_index {
207 let node = &self.node_list[index];
208 points.push(node.to_point());
209 node_index = node.parent;
210 }
211 points.reverse();
212 Path2D::from_points(points)
213 }
214
215 fn calc_dist_to_goal(&self, x: f64, y: f64) -> f64 {
216 let dx = x - self.goal.x;
217 let dy = y - self.goal.y;
218 (dx * dx + dy * dy).sqrt()
219 }
220
221 fn get_random_node(&self) -> RRTNode {
222 let mut rng = rand::rng();
223 if rng.random_range(0..=100) > self.config.goal_sample_rate {
224 RRTNode::new(
225 rng.random_range(self.rand_area.xmin..=self.rand_area.xmax),
226 rng.random_range(self.rand_area.ymin..=self.rand_area.ymax),
227 )
228 } else {
229 RRTNode::new(self.goal.x, self.goal.y)
230 }
231 }
232
233 fn get_nearest_node_index(&self, rnd_node: &RRTNode) -> usize {
234 let mut min_dist = f64::INFINITY;
235 let mut min_ind = 0;
236 for (i, node) in self.node_list.iter().enumerate() {
237 let dist = (node.x - rnd_node.x).powi(2) + (node.y - rnd_node.y).powi(2);
238 if dist < min_dist {
239 min_dist = dist;
240 min_ind = i;
241 }
242 }
243 min_ind
244 }
245
246 fn check_if_outside_play_area(&self, node: &RRTNode) -> bool {
247 if let Some(ref play_area) = self.play_area {
248 if node.x < play_area.xmin
249 || node.x > play_area.xmax
250 || node.y < play_area.ymin
251 || node.y > play_area.ymax
252 {
253 return false;
254 }
255 }
256 true
257 }
258
259 fn check_collision(&self, node: &RRTNode) -> bool {
260 for obs in &self.obstacles {
261 for (&px, &py) in node.path_x.iter().zip(node.path_y.iter()) {
262 let dx = obs.x - px;
263 let dy = obs.y - py;
264 let d = (dx * dx + dy * dy).sqrt();
265 if d <= obs.radius + self.config.robot_radius {
266 return false;
267 }
268 }
269 }
270 true
271 }
272
273 fn calc_distance_and_angle(&self, from_node: &RRTNode, to_node: &RRTNode) -> (f64, f64) {
274 let dx = to_node.x - from_node.x;
275 let dy = to_node.y - from_node.y;
276 ((dx * dx + dy * dy).sqrt(), dy.atan2(dx))
277 }
278
279 pub(crate) fn plan_with_sampler<F>(
280 &mut self,
281 start: Point2D,
282 goal: Point2D,
283 mut sample_node: F,
284 ) -> Result<Path2D, RoboticsError>
285 where
286 F: FnMut(&RRTPlanner) -> RRTNode,
287 {
288 self.reset_search(start, goal);
289 for _ in 0..self.config.max_iter {
290 let rnd_node = sample_node(self);
291 let nearest_ind = self.get_nearest_node_index(&rnd_node);
292 let nearest_node = self.node_list[nearest_ind].clone();
293 let new_node = self.steer(&nearest_node, &rnd_node, self.config.expand_dis);
294 if self.check_if_outside_play_area(&new_node) && self.check_collision(&new_node) {
295 let mut new_node = new_node;
296 new_node.parent = Some(nearest_ind);
297 self.node_list.push(new_node);
298 let last = self.node_list.last().unwrap();
299 if self.calc_dist_to_goal(last.x, last.y) <= self.config.expand_dis {
300 let final_node = self.steer(last, &self.goal.clone(), self.config.expand_dis);
301 if self.check_collision(&final_node) {
302 return Ok(self.generate_final_course(self.node_list.len() - 1));
303 }
304 }
305 }
306 }
307 Err(RoboticsError::PlanningError(
308 "RRT: Cannot find path within max iterations".to_string(),
309 ))
310 }
311}
312
313impl PathPlanner for RRTPlanner {
314 fn plan(&self, start: Point2D, goal: Point2D) -> Result<Path2D, RoboticsError> {
315 let mut planner = RRTPlanner {
316 config: self.config.clone(),
317 obstacles: self.obstacles.clone(),
318 play_area: self.play_area.clone(),
319 rand_area: self.rand_area.clone(),
320 node_list: vec![RRTNode::new(start.x, start.y)],
321 _start: RRTNode::new(start.x, start.y),
322 goal: RRTNode::new(goal.x, goal.y),
323 };
324 planner.plan_with_sampler(start, goal, |planner| planner.get_random_node())
325 }
326}
327
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts two floats agree to within `1e-12`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1.0e-12,
            "expected {expected}, got {actual}"
        );
    }

    /// Asserts a point matches `[x, y]` component-wise to within `1e-12`.
    fn assert_point_close(actual: &Point2D, expected: [f64; 2]) {
        assert_close(actual.x, expected[0]);
        assert_close(actual.y, expected[1]);
    }

    /// Parses an `x,y` CSV fixture, skipping the header row and blank lines.
    fn parse_xy_fixture(csv: &str) -> Vec<[f64; 2]> {
        csv.lines()
            .skip(1)
            .filter(|line| !line.trim().is_empty())
            .map(|line| {
                let (x, y) = line
                    .split_once(',')
                    .expect("xy fixture rows must contain a comma");
                [x.parse().unwrap(), y.parse().unwrap()]
            })
            .collect()
    }

    /// Checks structural validity of a returned path.
    ///
    /// * It runs from `start` to `goal`.
    /// * No hop exceeds one expansion plus the one-resolution snap slack of `steer`.
    /// * Every waypoint clears every obstacle inflated by the robot radius.
    fn assert_valid_path(planner: &RRTPlanner, path: &Path2D, start: [f64; 2], goal: [f64; 2]) {
        let points = &path.points;
        assert!(points.len() >= 2, "path must contain at least start and goal");
        assert_point_close(&points[0], start);
        assert_point_close(points.last().unwrap(), goal);

        let max_hop = planner.config.expand_dis + planner.config.path_resolution + 1.0e-9;
        for pair in points.windows(2) {
            let hop = (pair[1].x - pair[0].x).hypot(pair[1].y - pair[0].y);
            assert!(hop <= max_hop, "hop of {hop} exceeds {max_hop}");
        }
        for p in points.iter() {
            for obs in &planner.obstacles {
                let dist = (p.x - obs.x).hypot(p.y - obs.y);
                assert!(
                    dist >= obs.radius + planner.config.robot_radius - 1.0e-9,
                    "waypoint ({}, {}) collides with obstacle at ({}, {})",
                    p.x,
                    p.y,
                    obs.x,
                    obs.y
                );
            }
        }
    }

    fn create_test_planner() -> RRTPlanner {
        let obstacles = vec![
            CircleObstacle::new(5.0, 5.0, 1.0),
            CircleObstacle::new(3.0, 6.0, 2.0),
            CircleObstacle::new(7.0, 5.0, 2.0),
        ];
        let rand_area = AreaBounds::new(-2.0, 15.0, -2.0, 15.0);
        let config = RRTConfig {
            max_iter: 1000,
            ..Default::default()
        };
        RRTPlanner::new(obstacles, rand_area, None, config)
    }

    fn create_pythonrobotics_main_planner() -> RRTPlanner {
        let obstacles = vec![
            CircleObstacle::new(5.0, 5.0, 1.0),
            CircleObstacle::new(3.0, 6.0, 2.0),
            CircleObstacle::new(3.0, 8.0, 2.0),
            CircleObstacle::new(3.0, 10.0, 2.0),
            CircleObstacle::new(7.0, 5.0, 2.0),
            CircleObstacle::new(9.0, 5.0, 2.0),
            CircleObstacle::new(8.0, 10.0, 1.0),
        ];
        let rand_area = AreaBounds::new(-2.0, 15.0, -2.0, 15.0);
        let config = RRTConfig {
            robot_radius: 0.8,
            ..Default::default()
        };
        RRTPlanner::new(obstacles, rand_area, None, config)
    }

    #[test]
    fn test_rrt_finds_path() {
        let planner = create_test_planner();
        let start = Point2D::new(0.0, 0.0);
        let goal = Point2D::new(10.0, 10.0);
        // Sampling is random, so running out of iterations is tolerated; any path that
        // is returned must still be well formed.
        match planner.plan(start, goal) {
            Ok(path) => assert_valid_path(&planner, &path, [0.0, 0.0], [10.0, 10.0]),
            Err(err) => assert!(
                matches!(err, RoboticsError::PlanningError(_)),
                "unexpected error variant"
            ),
        }
    }

    #[test]
    fn test_rrt_goal_only_sampler_without_obstacles_walks_straight_to_goal() {
        let mut planner = RRTPlanner::new(
            Vec::new(),
            AreaBounds::new(-2.0, 15.0, -2.0, 15.0),
            None,
            RRTConfig::default(),
        );
        let start = Point2D::new(0.0, 0.0);
        let goal = Point2D::new(10.0, 10.0);
        let path = planner
            .plan_with_sampler(start, goal, |_| RRTNode::new(10.0, 10.0))
            .unwrap();

        // Each expansion advances expand_dis = 3 along the diagonal. The 4th node lies
        // ~2.14 from the goal, which is within expand_dis, so it connects.
        assert_eq!(planner.node_list.len(), 5);
        assert_eq!(path.points.len(), 6);
        let step = 3.0 / std::f64::consts::SQRT_2;
        for (k, point) in path.points.iter().take(5).enumerate() {
            let expected = step * k as f64;
            assert!((point.x - expected).abs() < 1.0e-9, "x of waypoint {k}");
            assert!((point.y - expected).abs() < 1.0e-9, "y of waypoint {k}");
        }
        assert_point_close(&path.points[5], [10.0, 10.0]);
        for (i, node) in planner.node_list.iter().enumerate().skip(1) {
            assert_eq!(node.parent, Some(i - 1));
        }
        assert_valid_path(&planner, &path, [0.0, 0.0], [10.0, 10.0]);
    }

    #[test]
    fn test_rrt_config_default() {
        let config = RRTConfig::default();
        assert_eq!(config.expand_dis, 3.0);
        assert_eq!(config.max_iter, 500);
    }

    #[test]
    fn test_rrt_upstream_test_rrt_short_goal_matches_pythonrobotics_reference() {
        let mut planner = create_pythonrobotics_main_planner();
        let start = Point2D::new(0.0, 0.0);
        let goal = Point2D::new(1.0, 1.0);
        let sample = [10.455_649_682_677_358, 11.942_970_283_541_907];
        let path = planner
            .plan_with_sampler(start, goal, |_| RRTNode::new(sample[0], sample[1]))
            .unwrap();

        assert_eq!(planner.node_list.len(), 2);
        assert_eq!(path.points.len(), 3);
        assert_point_close(&path.points[0], [0.0, 0.0]);
        assert_point_close(
            &path.points[1],
            [1.976_107_921_083_293_5, 2.257_210_110_785_406],
        );
        assert_point_close(&path.points[2], [1.0, 1.0]);
        assert_eq!(planner.node_list[1].parent, Some(0));
        assert_eq!(planner.node_list[1].path_x.len(), 7);
        assert_close(planner.node_list[1].path_x[1], 0.329_351_320_180_548_95);
        assert_close(planner.node_list[1].path_y[1], 0.376_201_685_130_900_95);
    }

    #[test]
    fn test_rrt_upstream_main_seeded_path_matches_pythonrobotics_reference() {
        let mut planner = create_pythonrobotics_main_planner();
        let start = Point2D::new(0.0, 0.0);
        let goal = Point2D::new(6.0, 10.0);
        let samples = parse_xy_fixture(include_str!("testdata/rrt_main_seed12345_samples.csv"));
        let expected_path = parse_xy_fixture(include_str!("testdata/rrt_main_seed12345_path.csv"));
        let mut sample_iter = samples.iter();

        let path = planner
            .plan_with_sampler(start, goal, |_| {
                let sample = sample_iter
                    .next()
                    .expect("python reference sample sequence exhausted");
                RRTNode::new(sample[0], sample[1])
            })
            .unwrap();

        assert_eq!(planner.node_list.len(), 88);
        assert_eq!(path.points.len(), expected_path.len());
        for (actual, expected) in path.points.iter().zip(expected_path.iter()) {
            assert_point_close(actual, *expected);
        }

        // (index, [x, y], parent, path length, first three path_x, first three path_y)
        let expected_nodes = [
            (
                1,
                [1.976_107_921_083_293, 2.257_210_110_785_406],
                Some(0),
                7,
                [0.0, 0.329_351_320_180_549, 0.658_702_640_361_098],
                [0.0, 0.376_201_685_130_901, 0.752_403_370_261_802],
            ),
            (
                5,
                [1.229_513_438_270_946, 3.806_527_238_150_387],
                Some(1),
                5,
                [
                    1.976_107_921_083_293,
                    1.759_052_148_041_024,
                    1.541_996_374_998_755,
                ],
                [
                    2.257_210_110_785_406,
                    2.707_639_673_968_178,
                    3.158_069_237_150_95,
                ],
            ),
            (
                10,
                [-0.964_870_854_478_227, 5.566_933_984_574_426],
                Some(9),
                7,
                [
                    0.668_674_115_555_151,
                    0.358_283_476_598_378,
                    0.047_892_837_641_605,
                ],
                [
                    3.503_932_336_109_255,
                    3.895_924_238_127_659,
                    4.287_916_140_146_064,
                ],
            ),
            (
                20,
                [13.060_964_451_302_038, 12.199_474_225_398_257],
                Some(16),
                8,
                [
                    11.070_558_229_598_33,
                    11.395_810_000_597_592,
                    11.721_061_771_596_855,
                ],
                [
                    9.875_551_544_349_548,
                    10.255_303_154_565_809,
                    10.635_054_764_782_069,
                ],
            ),
            (
                87,
                [5.860_033_119_067_657, 10.721_216_347_248_003],
                Some(72),
                7,
                [
                    5.288_485_092_568_921,
                    5.383_743_096_985_377,
                    5.479_001_101_401_833,
                ],
                [
                    13.666_268_613_915_847,
                    13.175_426_569_471_206,
                    12.684_584_525_026_565,
                ],
            ),
        ];
        for (index, xy, parent, path_len, path_x3, path_y3) in expected_nodes {
            let node = &planner.node_list[index];
            assert_close(node.x, xy[0]);
            assert_close(node.y, xy[1]);
            assert_eq!(node.parent, parent);
            assert_eq!(node.path_x.len(), path_len);
            for (actual, expected) in node.path_x.iter().take(3).zip(path_x3.iter()) {
                assert_close(*actual, *expected);
            }
            for (actual, expected) in node.path_y.iter().take(3).zip(path_y3.iter()) {
                assert_close(*actual, *expected);
            }
        }
    }
}