1use rand::Rng;
7
8use rust_robotics_core::{Path2D, PathPlanner, Point2D, RoboticsError};
9
10#[derive(Debug, Clone)]
12pub struct RRTNode {
13 pub x: f64,
14 pub y: f64,
15 pub parent: Option<usize>,
16}
17
18impl RRTNode {
19 pub fn new(x: f64, y: f64) -> Self {
20 RRTNode { x, y, parent: None }
21 }
22
23 fn to_point(&self) -> Point2D {
24 Point2D::new(self.x, self.y)
25 }
26}
27
/// A circular obstacle centred at `(x, y)` with the given `radius`.
///
/// When collision-checking, the planner enlarges `radius` by the configured
/// robot radius.
#[derive(Debug, Clone)]
pub struct CircleObstacle {
    pub x: f64,
    pub y: f64,
    pub radius: f64,
}

impl CircleObstacle {
    /// Creates an obstacle centred at `(x, y)` with radius `radius`.
    pub fn new(x: f64, y: f64, radius: f64) -> Self {
        CircleObstacle { radius, y, x }
    }
}
41
/// Axis-aligned rectangle `[xmin, xmax] x [ymin, ymax]` (inclusive) from which
/// the planner draws its random samples.
#[derive(Debug, Clone)]
pub struct AreaBounds {
    pub xmin: f64,
    pub xmax: f64,
    pub ymin: f64,
    pub ymax: f64,
}

impl AreaBounds {
    /// Builds bounds from the x range followed by the y range.
    pub fn new(xmin: f64, xmax: f64, ymin: f64, ymax: f64) -> Self {
        Self { xmin, xmax, ymin, ymax }
    }
}
61
/// Tuning parameters for the bidirectional RRT planner.
#[derive(Debug, Clone)]
pub struct BidirectionalRRTConfig {
    /// Maximum distance a tree grows toward a sample in one step; also the
    /// largest gap the planner will bridge when joining the two trees.
    pub expand_dis: f64,
    /// Spacing between collision-checked points along an edge.
    pub path_resolution: f64,
    /// Number of sampling iterations before the planner gives up.
    pub max_iter: usize,
    /// Robot radius, added to every obstacle radius during collision checks.
    pub robot_radius: f64,
}

impl Default for BidirectionalRRTConfig {
    /// `expand_dis = 3.0`, `path_resolution = 0.5`, `max_iter = 500`,
    /// `robot_radius = 0.8`.
    fn default() -> Self {
        let expand_dis = 3.0;
        let path_resolution = 0.5;
        let max_iter = 500;
        let robot_radius = 0.8;
        BidirectionalRRTConfig {
            expand_dis,
            path_resolution,
            max_iter,
            robot_radius,
        }
    }
}
85
86pub struct BidirectionalRRTPlanner {
88 config: BidirectionalRRTConfig,
89 obstacles: Vec<CircleObstacle>,
90 rand_area: AreaBounds,
91}
92
93impl BidirectionalRRTPlanner {
94 pub fn new(
95 obstacles: Vec<CircleObstacle>,
96 rand_area: AreaBounds,
97 config: BidirectionalRRTConfig,
98 ) -> Self {
99 BidirectionalRRTPlanner {
100 config,
101 obstacles,
102 rand_area,
103 }
104 }
105
106 fn get_random_node(&self) -> RRTNode {
109 let mut rng = rand::rng();
110 RRTNode::new(
111 rng.random_range(self.rand_area.xmin..=self.rand_area.xmax),
112 rng.random_range(self.rand_area.ymin..=self.rand_area.ymax),
113 )
114 }
115
116 fn get_nearest_node_index(tree: &[RRTNode], target: &RRTNode) -> usize {
117 tree.iter()
118 .enumerate()
119 .map(|(i, n)| {
120 let dx = n.x - target.x;
121 let dy = n.y - target.y;
122 (i, dx * dx + dy * dy)
123 })
124 .min_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
125 .map(|(i, _)| i)
126 .unwrap_or(0)
127 }
128
129 fn dist(ax: f64, ay: f64, bx: f64, by: f64) -> f64 {
130 let dx = ax - bx;
131 let dy = ay - by;
132 (dx * dx + dy * dy).sqrt()
133 }
134
135 fn steer(&self, from: &RRTNode, to: &RRTNode, parent_idx: usize) -> Option<RRTNode> {
138 let dx = to.x - from.x;
139 let dy = to.y - from.y;
140 let d = (dx * dx + dy * dy).sqrt();
141 let theta = dy.atan2(dx);
142 let step = d.min(self.config.expand_dis);
143 let n_steps = (step / self.config.path_resolution).floor() as usize;
144
145 let mut cx = from.x;
146 let mut cy = from.y;
147
148 for _ in 0..n_steps {
149 cx += self.config.path_resolution * theta.cos();
150 cy += self.config.path_resolution * theta.sin();
151 if self.point_in_collision(cx, cy) {
152 return None;
153 }
154 }
155 if Self::dist(cx, cy, to.x, to.y) <= self.config.path_resolution {
157 cx = to.x;
158 cy = to.y;
159 }
160 if self.point_in_collision(cx, cy) {
161 return None;
162 }
163
164 Some(RRTNode {
165 x: cx,
166 y: cy,
167 parent: Some(parent_idx),
168 })
169 }
170
171 fn point_in_collision(&self, x: f64, y: f64) -> bool {
172 for obs in &self.obstacles {
173 if Self::dist(x, y, obs.x, obs.y) <= obs.radius + self.config.robot_radius {
174 return true;
175 }
176 }
177 false
178 }
179
180 fn check_collision(&self, ax: f64, ay: f64, bx: f64, by: f64) -> bool {
182 let d = Self::dist(ax, ay, bx, by);
183 let n = (d / self.config.path_resolution).ceil() as usize;
184 for i in 0..=n {
185 let t = if n == 0 { 0.0 } else { i as f64 / n as f64 };
186 let x = ax + t * (bx - ax);
187 let y = ay + t * (by - ay);
188 if self.point_in_collision(x, y) {
189 return false;
190 }
191 }
192 true
193 }
194
195 fn trace_path(tree: &[RRTNode], idx: usize) -> Vec<Point2D> {
200 let mut path = Vec::new();
201 let mut current = Some(idx);
202 while let Some(i) = current {
203 path.push(tree[i].to_point());
204 current = tree[i].parent;
205 }
206 path.reverse();
207 path
208 }
209
210 fn run(&self, start: Point2D, goal: Point2D) -> Result<Path2D, RoboticsError> {
213 let mut tree_a: Vec<RRTNode> = vec![RRTNode::new(start.x, start.y)];
216 let mut tree_b: Vec<RRTNode> = vec![RRTNode::new(goal.x, goal.y)];
217
218 let mut a_is_start = true;
220
221 for _ in 0..self.config.max_iter {
222 let rnd = self.get_random_node();
224 let nearest_a = Self::get_nearest_node_index(&tree_a, &rnd);
225 if let Some(new_node) = self.steer(&tree_a[nearest_a].clone(), &rnd, nearest_a) {
226 tree_a.push(new_node);
227 let new_a_idx = tree_a.len() - 1;
228 let (na_x, na_y) = (tree_a[new_a_idx].x, tree_a[new_a_idx].y);
229
230 let nearest_b = Self::get_nearest_node_index(&tree_b, &tree_a[new_a_idx]);
232 let (nb_x, nb_y) = (tree_b[nearest_b].x, tree_b[nearest_b].y);
233
234 if Self::dist(na_x, na_y, nb_x, nb_y) <= self.config.expand_dis
235 && self.check_collision(na_x, na_y, nb_x, nb_y)
236 {
237 let mut path_a = Self::trace_path(&tree_a, new_a_idx);
239 let mut path_b = Self::trace_path(&tree_b, nearest_b);
240
241 if a_is_start {
245 path_b.reverse();
247 path_a.extend(path_b);
248 } else {
249 path_b.reverse();
251 path_b.extend(path_a);
252 path_a = path_b;
253 }
254 return Ok(Path2D::from_points(path_a));
255 }
256 }
257
258 std::mem::swap(&mut tree_a, &mut tree_b);
260 a_is_start = !a_is_start;
261 }
262
263 Err(RoboticsError::PlanningError(
264 "BidirectionalRRT: Cannot find path within max iterations".to_string(),
265 ))
266 }
267}
268
269impl PathPlanner for BidirectionalRRTPlanner {
270 fn plan(&self, start: Point2D, goal: Point2D) -> Result<Path2D, RoboticsError> {
271 self.run(start, goal)
272 }
273}
274
#[cfg(test)]
mod tests {
    use super::*;

    /// Default configuration values are the documented ones.
    #[test]
    fn test_bidir_rrt_config_defaults() {
        let cfg = BidirectionalRRTConfig::default();
        assert_eq!(cfg.expand_dis, 3.0);
        assert_eq!(cfg.path_resolution, 0.5);
        assert_eq!(cfg.max_iter, 500);
        assert_eq!(cfg.robot_radius, 0.8);
    }

    /// In free space the planner finds a path with at least two points.
    #[test]
    fn test_bidir_rrt_finds_path_no_obstacles() {
        let planner = BidirectionalRRTPlanner::new(
            vec![],
            AreaBounds::new(-2.0, 15.0, -2.0, 15.0),
            BidirectionalRRTConfig::default(),
        );
        let result = planner.plan(Point2D::new(0.0, 0.0), Point2D::new(10.0, 10.0));
        assert!(
            result.is_ok(),
            "expected a path but got: {:?}",
            result.err()
        );
        let path = result.unwrap();
        assert!(
            path.points.len() >= 2,
            "path should have at least start and goal"
        );
    }

    /// With a cluttered map and a larger iteration budget a path is found.
    #[test]
    fn test_bidir_rrt_finds_path_with_obstacles() {
        let obstacles = vec![
            CircleObstacle::new(5.0, 5.0, 1.0),
            CircleObstacle::new(3.0, 6.0, 2.0),
            CircleObstacle::new(3.0, 8.0, 2.0),
            CircleObstacle::new(3.0, 10.0, 2.0),
            CircleObstacle::new(7.0, 5.0, 2.0),
            CircleObstacle::new(9.0, 5.0, 2.0),
            CircleObstacle::new(8.0, 10.0, 1.0),
        ];
        let planner = BidirectionalRRTPlanner::new(
            obstacles,
            AreaBounds::new(-2.0, 15.0, -2.0, 15.0),
            BidirectionalRRTConfig {
                max_iter: 2000,
                ..Default::default()
            },
        );
        let result = planner.plan(Point2D::new(0.0, 0.0), Point2D::new(6.0, 10.0));
        assert!(
            result.is_ok(),
            "expected a path but got: {:?}",
            result.err()
        );
    }

    /// With no iterations allowed the planner must report failure.
    #[test]
    fn test_bidir_rrt_zero_iterations_fails() {
        let planner = BidirectionalRRTPlanner::new(
            vec![],
            AreaBounds::new(-2.0, 15.0, -2.0, 15.0),
            BidirectionalRRTConfig {
                max_iter: 0,
                ..Default::default()
            },
        );
        let result = planner.plan(Point2D::new(0.0, 0.0), Point2D::new(10.0, 10.0));
        assert!(result.is_err(), "expected failure with max_iter = 0");
    }

    /// A goal buried inside an obstacle is unreachable.
    ///
    /// The inflated radius (5.0 + 0.8) exceeds `expand_dis` (3.0), so every
    /// point the goal tree could step to collides. Every connection segment
    /// also ends at the blocked goal point, so the result is deterministic
    /// despite random sampling.
    #[test]
    fn test_bidir_rrt_goal_inside_obstacle_fails() {
        let planner = BidirectionalRRTPlanner::new(
            vec![CircleObstacle::new(10.0, 10.0, 5.0)],
            AreaBounds::new(-2.0, 15.0, -2.0, 15.0),
            BidirectionalRRTConfig {
                max_iter: 200,
                ..Default::default()
            },
        );
        let result = planner.plan(Point2D::new(0.0, 0.0), Point2D::new(10.0, 10.0));
        assert!(result.is_err(), "goal inside an obstacle must not be reachable");
    }
}