Skip to main content

rust_robotics_planning/
bit_star.rs

1#![allow(dead_code, clippy::too_many_arguments)]
2
//! Batch Informed Trees (BIT*) path planning algorithm.
//!
//! An asymptotically optimal sampling-based planner that combines RRT*-style rewiring
//! with informed (ellipsoidal) sampling. Samples are added in batches, and candidate
//! edges are processed in order of their estimated solution cost; collision checks are
//! deferred until an edge is actually about to be used (lazy edge evaluation).
8
9use std::collections::BinaryHeap;
10use std::f64::consts::PI;
11
12use nalgebra::{Matrix2, Vector2};
13use rand::Rng;
14
15use rust_robotics_core::{Path2D, Point2D, RoboticsError, RoboticsResult};
16
17/// A vertex in the BIT* search tree.
18#[derive(Clone, Debug)]
19struct Vertex {
20    pos: Vector2<f64>,
21    /// Cost-to-come from the start vertex through the tree.
22    cost: f64,
23    /// Index of the parent vertex in the vertex list, if connected to the tree.
24    parent: Option<usize>,
25}
26
27impl Vertex {
28    fn new(x: f64, y: f64) -> Self {
29        Self {
30            pos: Vector2::new(x, y),
31            cost: f64::INFINITY,
32            parent: None,
33        }
34    }
35}
36
/// An edge in the priority queue, ordered by estimated total cost (low cost = high priority).
///
/// `Ord` is a genuine total order. Costs are compared with [`f64::total_cmp`], so a NaN
/// estimate cannot silently compare "equal" to everything and corrupt the heap. Ties are
/// broken on the endpoint indices, which keeps `cmp(..) == Equal` consistent with `==`,
/// as the `Ord`/`Eq` contracts require.
#[derive(Clone, Debug)]
struct QueueEdge {
    /// Estimated total cost of a solution through this edge.
    estimated_cost: f64,
    from: usize,
    to: usize,
}

impl PartialEq for QueueEdge {
    fn eq(&self, other: &Self) -> bool {
        // Equality is defined through `cmp` so the two can never disagree.
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for QueueEdge {}

impl PartialOrd for QueueEdge {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for QueueEdge {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Every key is reversed so that `BinaryHeap` (a max-heap) pops the cheapest edge
        // first, and among equally cheap edges the one with the smallest `(from, to)`.
        other
            .estimated_cost
            .total_cmp(&self.estimated_cost)
            .then_with(|| other.from.cmp(&self.from))
            .then_with(|| other.to.cmp(&self.to))
    }
}
63
/// Tuning parameters for the BIT* planner.
///
/// Start from [`Default::default`] and override individual fields with struct-update
/// syntax, e.g. `BITStarConfig { batch_size: 50, ..Default::default() }`.
#[derive(Clone, Debug)]
pub struct BITStarConfig {
    /// How many fresh samples are drawn at the start of every batch.
    pub batch_size: usize,
    /// Upper bound on the number of batches the planner runs.
    pub max_batches: usize,
    /// Scale factor of the connection radius `eta * (ln(n) / n)^(1/d)`; with `d = 2`
    /// this is `eta * sqrt(ln(n) / n)` for `n` vertices.
    pub eta: f64,
    /// Distance to the goal strictly below which a vertex is treated as a goal vertex.
    pub goal_threshold: f64,
}

impl Default for BITStarConfig {
    /// 100 samples per batch, at most 200 batches, `eta = 40.0`, goal threshold `0.5`.
    fn default() -> Self {
        let batch_size = 100;
        let max_batches = 200;
        let eta = 40.0;
        let goal_threshold = 0.5;
        Self {
            batch_size,
            max_batches,
            eta,
            goal_threshold,
        }
    }
}
87
88/// Batch Informed Trees (BIT*) planner.
89pub struct BITStar {
90    config: BITStarConfig,
91    start: Vector2<f64>,
92    goal: Vector2<f64>,
93    obstacles: Vec<(f64, f64, f64)>, // (x, y, radius)
94    area_min: f64,
95    area_max: f64,
96    /// All vertices (both in tree and unconnected samples).
97    vertices: Vec<Vertex>,
98    /// Set of vertex indices that are part of the tree (have finite cost).
99    tree_set: Vec<bool>,
100}
101
102impl BITStar {
103    /// Create a new BIT* planner.
104    ///
105    /// * `start` – (x, y) start position
106    /// * `goal` – (x, y) goal position
107    /// * `obstacles` – list of circular obstacles (x, y, radius)
108    /// * `rand_area` – (min, max) bounds for uniform sampling
109    /// * `config` – planner configuration
110    pub fn new(
111        start: (f64, f64),
112        goal: (f64, f64),
113        obstacles: Vec<(f64, f64, f64)>,
114        rand_area: (f64, f64),
115        config: BITStarConfig,
116    ) -> Self {
117        Self {
118            config,
119            start: Vector2::new(start.0, start.1),
120            goal: Vector2::new(goal.0, goal.1),
121            obstacles,
122            area_min: rand_area.0,
123            area_max: rand_area.1,
124            vertices: Vec::new(),
125            tree_set: Vec::new(),
126        }
127    }
128
129    /// Run the planner and return the best path found (start to goal), or `None`.
130    pub fn planning(&mut self) -> Option<Vec<[f64; 2]>> {
131        self.reset();
132
133        let mut best_cost = f64::INFINITY;
134
135        for _batch in 0..self.config.max_batches {
136            // --- Sample a new batch ---
137            self.add_samples(best_cost);
138
139            // --- Build edge queue ---
140            let mut edge_queue = self.build_edge_queue(best_cost);
141
142            // --- Process edges ---
143            while let Some(edge) = edge_queue.pop() {
144                // Prune: skip if the optimistic estimate already exceeds best_cost.
145                if edge.estimated_cost >= best_cost {
146                    break;
147                }
148
149                let from_idx = edge.from;
150                let to_idx = edge.to;
151
152                // True cost of traversing this edge.
153                let edge_cost = self.dist(from_idx, to_idx);
154                let new_cost = self.vertices[from_idx].cost + edge_cost;
155
156                // Only useful if it would improve the target vertex.
157                if new_cost >= self.vertices[to_idx].cost {
158                    continue;
159                }
160
161                // Lazy collision check.
162                if !self.collision_free(from_idx, to_idx) {
163                    continue;
164                }
165
166                // Accept this edge – wire / rewire.
167                self.vertices[to_idx].cost = new_cost;
168                self.vertices[to_idx].parent = Some(from_idx);
169                self.tree_set[to_idx] = true;
170
171                // Check whether this vertex can reach the goal cheaply.
172                let dist_to_goal = (self.vertices[to_idx].pos - self.goal).norm();
173                if dist_to_goal < self.config.goal_threshold {
174                    let total = new_cost + dist_to_goal;
175                    if total < best_cost {
176                        best_cost = total;
177                    }
178                }
179            }
180
181            // Prune vertices whose heuristic cost exceeds best_cost.
182            self.prune(best_cost);
183        }
184
185        self.extract_path(best_cost)
186    }
187
188    /// Convenience wrapper returning [`Path2D`].
189    pub fn plan_from(&mut self, start: Point2D, goal: Point2D) -> RoboticsResult<Path2D> {
190        self.start = Vector2::new(start.x, start.y);
191        self.goal = Vector2::new(goal.x, goal.y);
192
193        self.planning()
194            .map(|raw| {
195                Path2D::from_points(raw.into_iter().map(|p| Point2D::new(p[0], p[1])).collect())
196            })
197            .ok_or_else(|| {
198                RoboticsError::PlanningError(
199                    "BIT*: Cannot find path within max batches".to_string(),
200                )
201            })
202    }
203
204    // ---- Internal helpers ----
205
206    fn reset(&mut self) {
207        self.vertices.clear();
208        self.tree_set.clear();
209        // Add start vertex with zero cost.
210        let mut start_v = Vertex::new(self.start.x, self.start.y);
211        start_v.cost = 0.0;
212        self.vertices.push(start_v);
213        self.tree_set.push(true);
214    }
215
216    /// Sample new vertices using informed (ellipsoidal) sampling when a solution exists,
217    /// or uniform sampling otherwise.
218    fn add_samples(&mut self, best_cost: f64) {
219        let mut rng = rand::rng();
220        let c_min = (self.goal - self.start).norm();
221
222        for _ in 0..self.config.batch_size {
223            let pos = if best_cost < f64::INFINITY {
224                self.sample_ellipse(best_cost, c_min, &mut rng)
225            } else {
226                Vector2::new(
227                    rng.random_range(self.area_min..=self.area_max),
228                    rng.random_range(self.area_min..=self.area_max),
229                )
230            };
231
232            let mut v = Vertex::new(pos.x, pos.y);
233            v.cost = f64::INFINITY;
234            self.vertices.push(v);
235            self.tree_set.push(false);
236        }
237    }
238
239    /// Sample a point uniformly inside the prolate hyperspheroid focused on start/goal.
240    fn sample_ellipse(&self, c_best: f64, c_min: f64, rng: &mut impl Rng) -> Vector2<f64> {
241        let center = (self.start + self.goal) / 2.0;
242        let diff = self.goal - self.start;
243        let angle = diff.y.atan2(diff.x);
244        let cos_a = angle.cos();
245        let sin_a = angle.sin();
246        let rotation = Matrix2::new(cos_a, -sin_a, sin_a, cos_a);
247
248        let r1 = c_best / 2.0;
249        let r2 = (c_best * c_best - c_min * c_min).max(0.0).sqrt() / 2.0;
250
251        // Uniform sampling inside an ellipse via unit-disk transform.
252        let theta = rng.random_range(0.0..2.0 * PI);
253        let r = rng.random::<f64>().sqrt();
254        let unit = Vector2::new(r * theta.cos(), r * theta.sin());
255        let scaled = Vector2::new(r1 * unit.x, r2 * unit.y);
256
257        center + rotation * scaled
258    }
259
260    /// Compute the connection radius for the current number of vertices.
261    fn connection_radius(&self) -> f64 {
262        let n = self.vertices.len().max(2) as f64;
263        self.config.eta * (n.ln() / n).sqrt()
264    }
265
266    /// Build the edge queue: for every tree vertex, add edges to nearby non-tree or
267    /// improvable vertices.
268    fn build_edge_queue(&self, best_cost: f64) -> BinaryHeap<QueueEdge> {
269        let r = self.connection_radius();
270        let r_sq = r * r;
271        let mut queue = BinaryHeap::new();
272
273        for (i, vi) in self.vertices.iter().enumerate() {
274            if !self.tree_set[i] {
275                continue;
276            }
277            for (j, vj) in self.vertices.iter().enumerate() {
278                if i == j {
279                    continue;
280                }
281                let d_sq = (vi.pos - vj.pos).norm_squared();
282                if d_sq > r_sq {
283                    continue;
284                }
285                let edge_cost = d_sq.sqrt();
286                let new_cost = vi.cost + edge_cost;
287
288                // Skip if this cannot improve the target vertex.
289                if new_cost >= vj.cost {
290                    continue;
291                }
292
293                // Optimistic estimate of total solution cost through this edge.
294                let estimated = new_cost + (vj.pos - self.goal).norm();
295                if estimated >= best_cost {
296                    continue;
297                }
298
299                queue.push(QueueEdge {
300                    estimated_cost: estimated,
301                    from: i,
302                    to: j,
303                });
304            }
305        }
306
307        queue
308    }
309
310    /// Euclidean distance between two vertices.
311    fn dist(&self, a: usize, b: usize) -> f64 {
312        (self.vertices[a].pos - self.vertices[b].pos).norm()
313    }
314
315    /// Check that the straight-line segment between two vertices is collision-free.
316    fn collision_free(&self, a: usize, b: usize) -> bool {
317        let pa = self.vertices[a].pos;
318        let pb = self.vertices[b].pos;
319        self.segment_collision_free(pa.x, pa.y, pb.x, pb.y)
320    }
321
322    fn segment_collision_free(&self, x1: f64, y1: f64, x2: f64, y2: f64) -> bool {
323        for &(ox, oy, radius) in &self.obstacles {
324            let dd = Self::point_to_segment_dist_sq([x1, y1], [x2, y2], [ox, oy]);
325            if dd <= radius * radius {
326                return false;
327            }
328        }
329        true
330    }
331
332    fn point_to_segment_dist_sq(v: [f64; 2], w: [f64; 2], p: [f64; 2]) -> f64 {
333        let l2 = (w[0] - v[0]).powi(2) + (w[1] - v[1]).powi(2);
334        if l2 == 0.0 {
335            return (p[0] - v[0]).powi(2) + (p[1] - v[1]).powi(2);
336        }
337        let t =
338            (((p[0] - v[0]) * (w[0] - v[0]) + (p[1] - v[1]) * (w[1] - v[1])) / l2).clamp(0.0, 1.0);
339        let proj = [v[0] + t * (w[0] - v[0]), v[1] + t * (w[1] - v[1])];
340        (p[0] - proj[0]).powi(2) + (p[1] - proj[1]).powi(2)
341    }
342
343    /// Remove samples whose optimistic cost exceeds `best_cost`.
344    fn prune(&mut self, best_cost: f64) {
345        if best_cost >= f64::INFINITY {
346            return;
347        }
348
349        for i in 0..self.vertices.len() {
350            // Never prune the start vertex.
351            if i == 0 {
352                continue;
353            }
354            let heuristic = (self.vertices[i].pos - self.start).norm()
355                + (self.vertices[i].pos - self.goal).norm();
356            if heuristic > best_cost {
357                // Disconnect this vertex.
358                self.vertices[i].cost = f64::INFINITY;
359                self.vertices[i].parent = None;
360                self.tree_set[i] = false;
361            }
362        }
363    }
364
365    /// Trace back from the best goal-connected vertex to produce the path.
366    fn extract_path(&self, best_cost: f64) -> Option<Vec<[f64; 2]>> {
367        if best_cost >= f64::INFINITY {
368            return None;
369        }
370
371        // Find the best vertex that is near the goal.
372        let mut best_idx = None;
373        let mut best_total = f64::INFINITY;
374        for (i, v) in self.vertices.iter().enumerate() {
375            if !self.tree_set[i] {
376                continue;
377            }
378            let dist_to_goal = (v.pos - self.goal).norm();
379            if dist_to_goal < self.config.goal_threshold {
380                let total = v.cost + dist_to_goal;
381                if total < best_total {
382                    best_total = total;
383                    best_idx = Some(i);
384                }
385            }
386        }
387
388        let best_idx = best_idx?;
389
390        // Trace back to start.
391        let mut path = vec![[self.goal.x, self.goal.y]];
392        let mut current = best_idx;
393        loop {
394            let v = &self.vertices[current];
395            path.push([v.pos.x, v.pos.y]);
396            match v.parent {
397                Some(p) => current = p,
398                None => break,
399            }
400        }
401        path.reverse();
402        Some(path)
403    }
404
405    /// Get the current vertices for inspection.
406    pub fn get_vertices(&self) -> Vec<(f64, f64, f64)> {
407        self.vertices
408            .iter()
409            .map(|v| (v.pos.x, v.pos.y, v.cost))
410            .collect()
411    }
412
413    /// Get the obstacle list.
414    pub fn get_obstacles(&self) -> &[(f64, f64, f64)] {
415        &self.obstacles
416    }
417}
418
#[cfg(test)]
mod tests {
    use super::*;

    /// Sum of the Euclidean lengths of consecutive waypoint segments.
    fn path_length(path: &[[f64; 2]]) -> f64 {
        path.iter()
            .zip(path.iter().skip(1))
            .map(|(a, b)| ((b[0] - a[0]).powi(2) + (b[1] - a[1]).powi(2)).sqrt())
            .sum()
    }

    #[test]
    fn test_bit_star_finds_path_open_space() {
        let mut planner = BITStar::new(
            (0.0, 0.0),
            (10.0, 10.0),
            Vec::new(), // no obstacles
            (-2.0, 15.0),
            BITStarConfig {
                batch_size: 200,
                max_batches: 10,
                eta: 40.0,
                goal_threshold: 1.0,
            },
        );

        let path = planner
            .planning()
            .expect("BIT* should find a path in open space");

        // Endpoints: the path must begin near the start and finish at the goal.
        assert!(path.len() >= 2);
        let (first, last) = (path[0], path[path.len() - 1]);
        assert!(
            first[0].abs() < 1.5 && first[1].abs() < 1.5,
            "Path should start near (0,0)"
        );
        assert!(
            (last[0] - 10.0).abs() < 1.5 && (last[1] - 10.0).abs() < 1.5,
            "Path should end near (10,10)"
        );

        // The straight-line distance is ~14.14, so anything under 25 is reasonable.
        let cost = path_length(&path);
        assert!(
            cost < 25.0,
            "Path cost {} is unreasonably large for open space",
            cost
        );
    }

    #[test]
    fn test_bit_star_finds_path_around_obstacles() {
        let obstacles = vec![(5.0, 5.0, 1.0)];
        let mut planner = BITStar::new(
            (0.0, 0.0),
            (10.0, 10.0),
            obstacles.clone(),
            (-5.0, 20.0),
            BITStarConfig {
                batch_size: 200,
                max_batches: 10,
                eta: 30.0,
                goal_threshold: 2.0,
            },
        );

        let path = planner
            .planning()
            .expect("BIT* should find a path around obstacles");

        // Every returned segment must stay clear of every obstacle.
        for (a, b) in path.iter().zip(path.iter().skip(1)) {
            for &(ox, oy, r) in &obstacles {
                let dd = BITStar::point_to_segment_dist_sq(*a, *b, [ox, oy]);
                assert!(
                    dd > r * r * 0.9, // slight tolerance for numerical issues
                    "Path segment ({},{})--({},{}) collides with obstacle ({},{},{})",
                    a[0],
                    a[1],
                    b[0],
                    b[1],
                    ox,
                    oy,
                    r
                );
            }
        }
    }

    #[test]
    #[ignore = "long-running iterative improvement test"]
    fn test_bit_star_cost_improves_with_more_iterations() {
        let obstacles = vec![(5.0, 5.0, 1.5)];

        let median_costs: Vec<f64> = [3, 10, 20]
            .iter()
            .map(|&max_batches| {
                let config = BITStarConfig {
                    batch_size: 100,
                    max_batches,
                    eta: 30.0,
                    goal_threshold: 2.0,
                };
                let mut trial_costs: Vec<f64> = (0..5)
                    .filter_map(|_| {
                        BITStar::new(
                            (0.0, 0.0),
                            (10.0, 10.0),
                            obstacles.clone(),
                            (-5.0, 20.0),
                            config.clone(),
                        )
                        .planning()
                        .map(|path| path_length(&path))
                    })
                    .collect();
                assert!(
                    !trial_costs.is_empty(),
                    "At least one trial with max_batches={} should find a path",
                    max_batches
                );
                trial_costs.sort_by(|a, b| a.partial_cmp(b).unwrap());
                trial_costs[trial_costs.len() / 2]
            })
            .collect();

        // More batches should not make the median solution noticeably worse.
        assert!(
            median_costs[2] <= median_costs[0] + 1.0,
            "Cost should improve (or stay similar) with more batches: {:?}",
            median_costs
        );
    }

    #[test]
    fn test_bit_star_plan_from_returns_path2d() {
        let mut planner = BITStar::new(
            (0.0, 0.0),
            (10.0, 10.0),
            vec![],
            (-2.0, 15.0),
            BITStarConfig {
                batch_size: 200,
                max_batches: 20,
                eta: 40.0,
                goal_threshold: 1.0,
            },
        );

        let outcome = planner.plan_from(Point2D::new(0.0, 0.0), Point2D::new(10.0, 10.0));
        assert!(outcome.is_ok(), "plan_from should succeed in open space");
        assert!(outcome.unwrap().points.len() >= 2);
    }
}