1#![allow(dead_code, clippy::too_many_arguments)]
2
3use std::collections::BinaryHeap;
10use std::f64::consts::PI;
11
12use nalgebra::{Matrix2, Vector2};
13use rand::Rng;
14
15use rust_robotics_core::{Path2D, Point2D, RoboticsError, RoboticsResult};
16
17#[derive(Clone, Debug)]
19struct Vertex {
20 pos: Vector2<f64>,
21 cost: f64,
23 parent: Option<usize>,
25}
26
27impl Vertex {
28 fn new(x: f64, y: f64) -> Self {
29 Self {
30 pos: Vector2::new(x, y),
31 cost: f64::INFINITY,
32 parent: None,
33 }
34 }
35}
36
/// Candidate edge `from -> to` waiting in the edge queue.
///
/// The ordering is reversed on `estimated_cost` so that `BinaryHeap` (a max-heap) pops the
/// edge with the *lowest* estimated solution cost first. Ties are broken on the vertex
/// indices (lower indices pop first), which keeps `Ord` a genuine total order and makes it
/// agree with `PartialEq`/`Eq`.
#[derive(Clone, Debug)]
struct QueueEdge {
    /// Estimated cost of a solution that uses this edge.
    estimated_cost: f64,
    /// Index of the tree vertex the edge starts at.
    from: usize,
    /// Index of the vertex the edge would connect.
    to: usize,
}

impl PartialEq for QueueEdge {
    // Defined through `cmp` so that `a == b` holds exactly when `a.cmp(&b)` is `Equal`,
    // as the `Ord` contract requires (this also makes a NaN cost equal to itself).
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for QueueEdge {}

impl PartialOrd for QueueEdge {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for QueueEdge {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // `total_cmp` gives a total order over all `f64` values (NaN included), unlike
        // `partial_cmp`, which needed a fallback that broke transitivity.
        other
            .estimated_cost
            .total_cmp(&self.estimated_cost)
            .then_with(|| other.from.cmp(&self.from))
            .then_with(|| other.to.cmp(&self.to))
    }
}
63
/// Tuning parameters for [`BITStar`].
#[derive(Clone, Debug)]
pub struct BITStarConfig {
    /// Number of new samples added at the start of every batch.
    pub batch_size: usize,
    /// Number of batches to run before the planner stops.
    pub max_batches: usize,
    /// Scale factor of the connection radius `eta * sqrt(ln(n) / n)`.
    pub eta: f64,
    /// Distance to the goal below which a vertex is considered to have reached it.
    pub goal_threshold: f64,
}

impl Default for BITStarConfig {
    /// Defaults: 100 samples per batch, 200 batches, `eta = 40.0`, goal threshold `0.5`.
    fn default() -> Self {
        BITStarConfig {
            goal_threshold: 0.5,
            eta: 40.0,
            max_batches: 200,
            batch_size: 100,
        }
    }
}
87
88pub struct BITStar {
90 config: BITStarConfig,
91 start: Vector2<f64>,
92 goal: Vector2<f64>,
93 obstacles: Vec<(f64, f64, f64)>, area_min: f64,
95 area_max: f64,
96 vertices: Vec<Vertex>,
98 tree_set: Vec<bool>,
100}
101
102impl BITStar {
103 pub fn new(
111 start: (f64, f64),
112 goal: (f64, f64),
113 obstacles: Vec<(f64, f64, f64)>,
114 rand_area: (f64, f64),
115 config: BITStarConfig,
116 ) -> Self {
117 Self {
118 config,
119 start: Vector2::new(start.0, start.1),
120 goal: Vector2::new(goal.0, goal.1),
121 obstacles,
122 area_min: rand_area.0,
123 area_max: rand_area.1,
124 vertices: Vec::new(),
125 tree_set: Vec::new(),
126 }
127 }
128
129 pub fn planning(&mut self) -> Option<Vec<[f64; 2]>> {
131 self.reset();
132
133 let mut best_cost = f64::INFINITY;
134
135 for _batch in 0..self.config.max_batches {
136 self.add_samples(best_cost);
138
139 let mut edge_queue = self.build_edge_queue(best_cost);
141
142 while let Some(edge) = edge_queue.pop() {
144 if edge.estimated_cost >= best_cost {
146 break;
147 }
148
149 let from_idx = edge.from;
150 let to_idx = edge.to;
151
152 let edge_cost = self.dist(from_idx, to_idx);
154 let new_cost = self.vertices[from_idx].cost + edge_cost;
155
156 if new_cost >= self.vertices[to_idx].cost {
158 continue;
159 }
160
161 if !self.collision_free(from_idx, to_idx) {
163 continue;
164 }
165
166 self.vertices[to_idx].cost = new_cost;
168 self.vertices[to_idx].parent = Some(from_idx);
169 self.tree_set[to_idx] = true;
170
171 let dist_to_goal = (self.vertices[to_idx].pos - self.goal).norm();
173 if dist_to_goal < self.config.goal_threshold {
174 let total = new_cost + dist_to_goal;
175 if total < best_cost {
176 best_cost = total;
177 }
178 }
179 }
180
181 self.prune(best_cost);
183 }
184
185 self.extract_path(best_cost)
186 }
187
188 pub fn plan_from(&mut self, start: Point2D, goal: Point2D) -> RoboticsResult<Path2D> {
190 self.start = Vector2::new(start.x, start.y);
191 self.goal = Vector2::new(goal.x, goal.y);
192
193 self.planning()
194 .map(|raw| {
195 Path2D::from_points(raw.into_iter().map(|p| Point2D::new(p[0], p[1])).collect())
196 })
197 .ok_or_else(|| {
198 RoboticsError::PlanningError(
199 "BIT*: Cannot find path within max batches".to_string(),
200 )
201 })
202 }
203
204 fn reset(&mut self) {
207 self.vertices.clear();
208 self.tree_set.clear();
209 let mut start_v = Vertex::new(self.start.x, self.start.y);
211 start_v.cost = 0.0;
212 self.vertices.push(start_v);
213 self.tree_set.push(true);
214 }
215
216 fn add_samples(&mut self, best_cost: f64) {
219 let mut rng = rand::rng();
220 let c_min = (self.goal - self.start).norm();
221
222 for _ in 0..self.config.batch_size {
223 let pos = if best_cost < f64::INFINITY {
224 self.sample_ellipse(best_cost, c_min, &mut rng)
225 } else {
226 Vector2::new(
227 rng.random_range(self.area_min..=self.area_max),
228 rng.random_range(self.area_min..=self.area_max),
229 )
230 };
231
232 let mut v = Vertex::new(pos.x, pos.y);
233 v.cost = f64::INFINITY;
234 self.vertices.push(v);
235 self.tree_set.push(false);
236 }
237 }
238
239 fn sample_ellipse(&self, c_best: f64, c_min: f64, rng: &mut impl Rng) -> Vector2<f64> {
241 let center = (self.start + self.goal) / 2.0;
242 let diff = self.goal - self.start;
243 let angle = diff.y.atan2(diff.x);
244 let cos_a = angle.cos();
245 let sin_a = angle.sin();
246 let rotation = Matrix2::new(cos_a, -sin_a, sin_a, cos_a);
247
248 let r1 = c_best / 2.0;
249 let r2 = (c_best * c_best - c_min * c_min).max(0.0).sqrt() / 2.0;
250
251 let theta = rng.random_range(0.0..2.0 * PI);
253 let r = rng.random::<f64>().sqrt();
254 let unit = Vector2::new(r * theta.cos(), r * theta.sin());
255 let scaled = Vector2::new(r1 * unit.x, r2 * unit.y);
256
257 center + rotation * scaled
258 }
259
260 fn connection_radius(&self) -> f64 {
262 let n = self.vertices.len().max(2) as f64;
263 self.config.eta * (n.ln() / n).sqrt()
264 }
265
266 fn build_edge_queue(&self, best_cost: f64) -> BinaryHeap<QueueEdge> {
269 let r = self.connection_radius();
270 let r_sq = r * r;
271 let mut queue = BinaryHeap::new();
272
273 for (i, vi) in self.vertices.iter().enumerate() {
274 if !self.tree_set[i] {
275 continue;
276 }
277 for (j, vj) in self.vertices.iter().enumerate() {
278 if i == j {
279 continue;
280 }
281 let d_sq = (vi.pos - vj.pos).norm_squared();
282 if d_sq > r_sq {
283 continue;
284 }
285 let edge_cost = d_sq.sqrt();
286 let new_cost = vi.cost + edge_cost;
287
288 if new_cost >= vj.cost {
290 continue;
291 }
292
293 let estimated = new_cost + (vj.pos - self.goal).norm();
295 if estimated >= best_cost {
296 continue;
297 }
298
299 queue.push(QueueEdge {
300 estimated_cost: estimated,
301 from: i,
302 to: j,
303 });
304 }
305 }
306
307 queue
308 }
309
310 fn dist(&self, a: usize, b: usize) -> f64 {
312 (self.vertices[a].pos - self.vertices[b].pos).norm()
313 }
314
315 fn collision_free(&self, a: usize, b: usize) -> bool {
317 let pa = self.vertices[a].pos;
318 let pb = self.vertices[b].pos;
319 self.segment_collision_free(pa.x, pa.y, pb.x, pb.y)
320 }
321
322 fn segment_collision_free(&self, x1: f64, y1: f64, x2: f64, y2: f64) -> bool {
323 for &(ox, oy, radius) in &self.obstacles {
324 let dd = Self::point_to_segment_dist_sq([x1, y1], [x2, y2], [ox, oy]);
325 if dd <= radius * radius {
326 return false;
327 }
328 }
329 true
330 }
331
332 fn point_to_segment_dist_sq(v: [f64; 2], w: [f64; 2], p: [f64; 2]) -> f64 {
333 let l2 = (w[0] - v[0]).powi(2) + (w[1] - v[1]).powi(2);
334 if l2 == 0.0 {
335 return (p[0] - v[0]).powi(2) + (p[1] - v[1]).powi(2);
336 }
337 let t =
338 (((p[0] - v[0]) * (w[0] - v[0]) + (p[1] - v[1]) * (w[1] - v[1])) / l2).clamp(0.0, 1.0);
339 let proj = [v[0] + t * (w[0] - v[0]), v[1] + t * (w[1] - v[1])];
340 (p[0] - proj[0]).powi(2) + (p[1] - proj[1]).powi(2)
341 }
342
343 fn prune(&mut self, best_cost: f64) {
345 if best_cost >= f64::INFINITY {
346 return;
347 }
348
349 for i in 0..self.vertices.len() {
350 if i == 0 {
352 continue;
353 }
354 let heuristic = (self.vertices[i].pos - self.start).norm()
355 + (self.vertices[i].pos - self.goal).norm();
356 if heuristic > best_cost {
357 self.vertices[i].cost = f64::INFINITY;
359 self.vertices[i].parent = None;
360 self.tree_set[i] = false;
361 }
362 }
363 }
364
365 fn extract_path(&self, best_cost: f64) -> Option<Vec<[f64; 2]>> {
367 if best_cost >= f64::INFINITY {
368 return None;
369 }
370
371 let mut best_idx = None;
373 let mut best_total = f64::INFINITY;
374 for (i, v) in self.vertices.iter().enumerate() {
375 if !self.tree_set[i] {
376 continue;
377 }
378 let dist_to_goal = (v.pos - self.goal).norm();
379 if dist_to_goal < self.config.goal_threshold {
380 let total = v.cost + dist_to_goal;
381 if total < best_total {
382 best_total = total;
383 best_idx = Some(i);
384 }
385 }
386 }
387
388 let best_idx = best_idx?;
389
390 let mut path = vec![[self.goal.x, self.goal.y]];
392 let mut current = best_idx;
393 loop {
394 let v = &self.vertices[current];
395 path.push([v.pos.x, v.pos.y]);
396 match v.parent {
397 Some(p) => current = p,
398 None => break,
399 }
400 }
401 path.reverse();
402 Some(path)
403 }
404
405 pub fn get_vertices(&self) -> Vec<(f64, f64, f64)> {
407 self.vertices
408 .iter()
409 .map(|v| (v.pos.x, v.pos.y, v.cost))
410 .collect()
411 }
412
413 pub fn get_obstacles(&self) -> &[(f64, f64, f64)] {
415 &self.obstacles
416 }
417}
418
#[cfg(test)]
mod tests {
    use super::*;

    /// Total Euclidean length of a polyline given as `[x, y]` waypoints.
    fn path_length(path: &[[f64; 2]]) -> f64 {
        path.iter()
            .zip(path.iter().skip(1))
            .map(|(a, b)| ((b[0] - a[0]).powi(2) + (b[1] - a[1]).powi(2)).sqrt())
            .sum()
    }

    /// With no obstacles the planner must connect start and goal with a reasonably short path.
    #[test]
    fn test_bit_star_finds_path_open_space() {
        let mut planner = BITStar::new(
            (0.0, 0.0),
            (10.0, 10.0),
            Vec::new(),
            (-2.0, 15.0),
            BITStarConfig {
                batch_size: 200,
                max_batches: 10,
                eta: 40.0,
                goal_threshold: 1.0,
            },
        );

        let found = planner.planning();
        assert!(found.is_some(), "BIT* should find a path in open space");

        let waypoints = found.unwrap();
        assert!(waypoints.len() >= 2);

        let [sx, sy] = waypoints[0];
        let [gx, gy] = waypoints[waypoints.len() - 1];
        assert!(
            sx.abs() < 1.5 && sy.abs() < 1.5,
            "Path should start near (0,0)"
        );
        assert!(
            (gx - 10.0).abs() < 1.5 && (gy - 10.0).abs() < 1.5,
            "Path should end near (10,10)"
        );

        // The straight-line distance is about 14.1, so 25 leaves generous slack.
        let total_length = path_length(&waypoints);
        assert!(
            total_length < 25.0,
            "Path cost {} is unreasonably large for open space",
            total_length
        );
    }

    /// Every segment of the returned path must stay clear of the obstacle.
    #[test]
    fn test_bit_star_finds_path_around_obstacles() {
        let obstacles = vec![(5.0, 5.0, 1.0)];
        let mut planner = BITStar::new(
            (0.0, 0.0),
            (10.0, 10.0),
            obstacles.clone(),
            (-5.0, 20.0),
            BITStarConfig {
                batch_size: 200,
                max_batches: 10,
                eta: 30.0,
                goal_threshold: 2.0,
            },
        );

        let found = planner.planning();
        assert!(found.is_some(), "BIT* should find a path around obstacles");

        let waypoints = found.unwrap();
        for segment in waypoints.windows(2) {
            let [x1, y1] = segment[0];
            let [x2, y2] = segment[1];
            for &(ox, oy, r) in &obstacles {
                let dd = BITStar::point_to_segment_dist_sq(segment[0], segment[1], [ox, oy]);
                // A small tolerance (10% of the squared radius) is allowed.
                assert!(
                    dd > r * r * 0.9,
                    "Path segment ({},{})--({},{}) collides with obstacle ({},{},{})",
                    x1,
                    y1,
                    x2,
                    y2,
                    ox,
                    oy,
                    r
                );
            }
        }
    }

    /// The median path cost over several trials should not get worse with more batches.
    #[test]
    #[ignore = "long-running iterative improvement test"]
    fn test_bit_star_cost_improves_with_more_iterations() {
        let obstacles = vec![(5.0, 5.0, 1.5)];
        let batch_counts = [3, 10, 20];

        let costs: Vec<f64> = batch_counts
            .iter()
            .map(|&max_batches| {
                let config = BITStarConfig {
                    batch_size: 100,
                    max_batches,
                    eta: 30.0,
                    goal_threshold: 2.0,
                };

                let mut trial_costs: Vec<f64> = (0..5)
                    .filter_map(|_| {
                        let mut planner = BITStar::new(
                            (0.0, 0.0),
                            (10.0, 10.0),
                            obstacles.clone(),
                            (-5.0, 20.0),
                            config.clone(),
                        );
                        planner.planning().map(|path| path_length(&path))
                    })
                    .collect();

                assert!(
                    !trial_costs.is_empty(),
                    "At least one trial with max_batches={} should find a path",
                    max_batches
                );

                // Median of the successful trials.
                trial_costs.sort_by(|a, b| a.partial_cmp(b).unwrap());
                trial_costs[trial_costs.len() / 2]
            })
            .collect();

        assert!(
            costs[2] <= costs[0] + 1.0,
            "Cost should improve (or stay similar) with more batches: {:?}",
            costs
        );
    }

    /// `plan_from` must wrap a successful plan in a `Path2D` with at least two points.
    #[test]
    fn test_bit_star_plan_from_returns_path2d() {
        let mut planner = BITStar::new(
            (0.0, 0.0),
            (10.0, 10.0),
            Vec::new(),
            (-2.0, 15.0),
            BITStarConfig {
                batch_size: 200,
                max_batches: 20,
                eta: 40.0,
                goal_threshold: 1.0,
            },
        );

        let outcome = planner.plan_from(Point2D::new(0.0, 0.0), Point2D::new(10.0, 10.0));
        assert!(outcome.is_ok(), "plan_from should succeed in open space");

        let path = outcome.unwrap();
        assert!(path.points.len() >= 2);
    }
}