Skip to main content

rust_robotics_planning/
batch_informed_rrt_star.rs

1#![allow(dead_code, clippy::needless_borrows_for_generic_args)]
2
3//! Batch Informed RRT* path planning algorithm
4//!
5//! Extends Informed RRT* with batch sampling for improved convergence.
6//! Instead of sampling one point per iteration, a batch of samples is drawn
7//! from the informed ellipsoidal region (or free space if no solution exists
8//! yet) and processed together. After each batch the best solution cost is
9//! updated and the sampling ellipsoid shrinks, focusing future samples on
10//! the promising region.
11//!
12//! Reference: <https://arxiv.org/abs/1405.5848>
13
14use rand::Rng;
15use std::f64::consts::PI;
16
17use rust_robotics_core::{Path2D, Point2D, RoboticsError, RoboticsResult};
18
/// A single vertex of the RRT search tree.
#[derive(Debug, Clone)]
pub struct Node {
    /// X coordinate of the vertex.
    pub x: f64,
    /// Y coordinate of the vertex.
    pub y: f64,
    /// Cost stored for this vertex; [`Node::new`] initialises it to zero.
    pub cost: f64,
    /// Index of the parent vertex, or `None` when no parent is set.
    pub parent: Option<usize>,
}

impl Node {
    /// Builds a parentless node at `(x, y)` with a cost of zero.
    pub fn new(x: f64, y: f64) -> Self {
        Self {
            parent: None,
            cost: 0.0,
            x,
            y,
        }
    }
}
38
/// Tunable parameters for the Batch Informed RRT* planner.
#[derive(Debug, Clone)]
pub struct BatchInformedRRTStarConfig {
    /// How many samples are drawn in one batch.
    pub batch_size: usize,
    /// Upper limit on the number of batches processed.
    pub max_batches: usize,
    /// Length of a single tree-extension step.
    pub expand_dis: f64,
    /// Percentage of time to sample the goal directly (0-100).
    pub goal_sample_rate: i32,
    /// Circular obstacles, each given as `(x, y, radius)`.
    pub obstacle_list: Vec<(f64, f64, f64)>,
    /// Sampling bounds as `(min, max)`.
    pub rand_area: (f64, f64),
}

impl Default for BatchInformedRRTStarConfig {
    /// Returns the stock configuration: 10 batches of 50 samples, a 0.5 step,
    /// a goal sample rate of 10, no obstacles and sampling bounds `(-2, 15)`.
    fn default() -> Self {
        let rand_area = (-2.0, 15.0);
        let obstacle_list = Vec::new();

        BatchInformedRRTStarConfig {
            batch_size: 50,
            max_batches: 10,
            expand_dis: 0.5,
            goal_sample_rate: 10,
            obstacle_list,
            rand_area,
        }
    }
}
68
69/// Batch Informed RRT* planner.
70///
71/// Combines RRT* tree construction (nearest-neighbour expansion, rewiring) with
72/// informed ellipsoidal sampling. Samples are generated in batches: after each
73/// batch the best-known solution cost is updated and the sampling ellipsoid is
74/// tightened, leading to faster convergence towards the optimal path.
75pub struct BatchInformedRRTStar {
76    pub start: Node,
77    pub goal: Node,
78    pub min_rand: f64,
79    pub max_rand: f64,
80    pub expand_dis: f64,
81    pub goal_sample_rate: i32,
82    pub batch_size: usize,
83    pub max_batches: usize,
84    pub obstacle_list: Vec<(f64, f64, f64)>,
85    pub node_list: Vec<Node>,
86}
87
88impl BatchInformedRRTStar {
89    /// Create a new Batch Informed RRT* planner.
90    pub fn new(start: (f64, f64), goal: (f64, f64), config: BatchInformedRRTStarConfig) -> Self {
91        BatchInformedRRTStar {
92            start: Node::new(start.0, start.1),
93            goal: Node::new(goal.0, goal.1),
94            min_rand: config.rand_area.0,
95            max_rand: config.rand_area.1,
96            expand_dis: config.expand_dis,
97            goal_sample_rate: config.goal_sample_rate,
98            batch_size: config.batch_size,
99            max_batches: config.max_batches,
100            obstacle_list: config.obstacle_list,
101            node_list: Vec::new(),
102        }
103    }
104
105    /// Run the planner and return the best path found (goal to start order), or `None`.
106    pub fn planning(&mut self) -> Option<Vec<[f64; 2]>> {
107        self.planning_with_sampler(|planner, c_best, c_min, x_center, rotation_matrix| {
108            planner.informed_sample(c_best, c_min, x_center, rotation_matrix)
109        })
110    }
111
112    fn sampling_frame(&self) -> (f64, [f64; 2], [[f64; 2]; 2]) {
113        let c_min =
114            ((self.start.x - self.goal.x).powi(2) + (self.start.y - self.goal.y).powi(2)).sqrt();
115        let x_center = [
116            (self.start.x + self.goal.x) / 2.0,
117            (self.start.y + self.goal.y) / 2.0,
118        ];
119        let a1 = [
120            (self.goal.x - self.start.x) / c_min,
121            (self.goal.y - self.start.y) / c_min,
122        ];
123        let e_theta = a1[1].atan2(a1[0]);
124        let cos_theta = e_theta.cos();
125        let sin_theta = e_theta.sin();
126        let rotation_matrix = [[cos_theta, -sin_theta], [sin_theta, cos_theta]];
127
128        (c_min, x_center, rotation_matrix)
129    }
130
131    fn reset_search(&mut self) {
132        self.node_list = vec![self.start.clone()];
133    }
134
    /// Core planning loop that processes samples in batches.
    ///
    /// The tree is reset to the start node, then `max_batches` batches of
    /// `batch_size` samples are processed. `sample_node` is called once per
    /// sample with the planner, the current best cost (`f64::INFINITY` until a
    /// path is found), and the sampling frame from `sampling_frame()`.
    ///
    /// The best cost is updated as soon as a shorter path is found, so later
    /// samples in the same batch already see it. After each batch, once a path
    /// exists, the tree is pruned against the best cost.
    ///
    /// Returns the shortest path found (goal to start order), or `None`.
    fn planning_with_sampler<F>(&mut self, mut sample_node: F) -> Option<Vec<[f64; 2]>>
    where
        F: FnMut(&BatchInformedRRTStar, f64, f64, &[f64; 2], &[[f64; 2]; 2]) -> [f64; 2],
    {
        self.reset_search();
        let mut c_best = f64::INFINITY;
        let mut path = None;

        // The frame depends only on start and goal, so it is computed once.
        let (c_min, x_center, rotation_matrix) = self.sampling_frame();

        for _batch in 0..self.max_batches {
            // Generate and process a batch of samples
            for _sample in 0..self.batch_size {
                let rnd = sample_node(self, c_best, c_min, &x_center, &rotation_matrix);
                let n_ind = self.get_nearest_list_index(&rnd);
                let nearest_node = &self.node_list[n_ind];

                // Steer from the nearest tree node towards the sample by one step.
                let theta = (rnd[1] - nearest_node.y).atan2(rnd[0] - nearest_node.x);
                let new_node = self.get_new_node(theta, n_ind, nearest_node);
                let d = self.line_cost(nearest_node, &new_node);

                // `check_collision` returns true when the segment is free.
                let no_collision = self.check_collision(nearest_node, theta, d);

                if no_collision {
                    // The near set is computed before the push, so it never
                    // contains the new node itself.
                    let near_inds = self.find_near_nodes(&new_node);
                    let new_node = self.choose_parent(new_node, &near_inds);

                    let new_node_index = self.node_list.len();
                    self.node_list.push(new_node);
                    self.rewire(new_node_index, &near_inds);

                    // A candidate solution needs the new node within one step
                    // of the goal and a free segment from it to the goal.
                    if self.is_near_goal(&self.node_list[new_node_index])
                        && self.check_segment_collision(
                            self.node_list[new_node_index].x,
                            self.node_list[new_node_index].y,
                            self.goal.x,
                            self.goal.y,
                        )
                    {
                        let temp_path = self.get_final_course(new_node_index);
                        let temp_path_len = self.get_path_len(&temp_path);
                        // Keep only strict improvements.
                        if temp_path_len < c_best {
                            path = Some(temp_path);
                            c_best = temp_path_len;
                        }
                    }
                }
            }

            // After each batch, prune nodes that cannot improve the solution.
            // This keeps the tree lean and focused on the informed region.
            // `path` stores coordinates, not indices, so pruning cannot invalidate it.
            if c_best < f64::INFINITY {
                self.prune_nodes(c_best, c_min, &x_center, &rotation_matrix);
            }
        }

        path
    }
197
198    /// Remove leaf nodes whose heuristic lower-bound cost exceeds `c_best`.
199    ///
200    /// A node is prunable when it is a leaf (no children) and its cost-to-come
201    /// plus straight-line cost-to-goal exceeds the best known solution cost.
202    /// Pruning is done iteratively until no more leaves can be removed.
203    fn prune_nodes(
204        &mut self,
205        c_best: f64,
206        _c_min: f64,
207        _x_center: &[f64; 2],
208        _rotation_matrix: &[[f64; 2]; 2],
209    ) {
210        // Iteratively remove leaves that cannot improve the solution.
211        loop {
212            let n = self.node_list.len();
213            if n <= 1 {
214                break;
215            }
216
217            // Find which nodes are parents (have children).
218            let mut is_parent = vec![false; n];
219            for node in &self.node_list {
220                if let Some(p) = node.parent {
221                    if p < n {
222                        is_parent[p] = true;
223                    }
224                }
225            }
226
227            // Collect indices of leaf nodes (not parents, not root) that exceed c_best.
228            let mut to_remove = Vec::new();
229            for (i, &is_par) in is_parent.iter().enumerate().skip(1) {
230                if !is_par {
231                    let heuristic = self.node_list[i].cost + self.dist_to_goal(&self.node_list[i]);
232                    if heuristic > c_best {
233                        to_remove.push(i);
234                    }
235                }
236            }
237
238            if to_remove.is_empty() {
239                break;
240            }
241
242            // Remove in reverse order to preserve indices.
243            to_remove.sort_unstable();
244            for &idx in to_remove.iter().rev() {
245                self.node_list.remove(idx);
246                // Update parent references that point beyond the removed index.
247                for node in &mut self.node_list {
248                    if let Some(ref mut p) = node.parent {
249                        if *p == idx {
250                            // This should not happen since idx was a leaf,
251                            // but handle defensively.
252                            node.parent = None;
253                        } else if *p > idx {
254                            *p -= 1;
255                        }
256                    }
257                }
258            }
259        }
260    }
261
262    fn dist_to_goal(&self, node: &Node) -> f64 {
263        ((node.x - self.goal.x).powi(2) + (node.y - self.goal.y).powi(2)).sqrt()
264    }
265
266    fn choose_parent(&self, mut new_node: Node, near_inds: &[usize]) -> Node {
267        if near_inds.is_empty() {
268            return new_node;
269        }
270
271        let mut d_list = Vec::new();
272        for &i in near_inds {
273            let dx = new_node.x - self.node_list[i].x;
274            let dy = new_node.y - self.node_list[i].y;
275            let d = (dx * dx + dy * dy).sqrt();
276            let theta = dy.atan2(dx);
277            if self.check_collision(&self.node_list[i], theta, d) {
278                d_list.push(self.node_list[i].cost + d);
279            } else {
280                d_list.push(f64::INFINITY);
281            }
282        }
283
284        let min_cost = d_list.iter().fold(f64::INFINITY, |a, &b| a.min(b));
285        if let Some(min_index) = d_list.iter().position(|&x| x == min_cost) {
286            if min_cost != f64::INFINITY {
287                new_node.cost = min_cost;
288                new_node.parent = Some(near_inds[min_index]);
289            }
290        }
291
292        new_node
293    }
294
295    fn find_near_nodes(&self, new_node: &Node) -> Vec<usize> {
296        let n_node = self.node_list.len();
297        let r = 50.0 * ((n_node as f64).ln() / n_node as f64).sqrt();
298        let mut near_inds = Vec::new();
299
300        for (i, node) in self.node_list.iter().enumerate() {
301            let d_sq = (node.x - new_node.x).powi(2) + (node.y - new_node.y).powi(2);
302            if d_sq <= r * r {
303                near_inds.push(i);
304            }
305        }
306
307        near_inds
308    }
309
310    fn informed_sample(
311        &self,
312        c_max: f64,
313        c_min: f64,
314        x_center: &[f64; 2],
315        rotation_matrix: &[[f64; 2]; 2],
316    ) -> [f64; 2] {
317        if c_max < f64::INFINITY {
318            let x_ball = self.sample_unit_ball();
319            self.informed_sample_from_unit_ball(c_max, c_min, x_center, rotation_matrix, x_ball)
320        } else {
321            self.sample_free_space()
322        }
323    }
324
325    fn informed_sample_from_unit_ball(
326        &self,
327        c_max: f64,
328        c_min: f64,
329        x_center: &[f64; 2],
330        rotation_matrix: &[[f64; 2]; 2],
331        x_ball: [f64; 2],
332    ) -> [f64; 2] {
333        let r = [c_max / 2.0, (c_max * c_max - c_min * c_min).sqrt() / 2.0];
334        let scaled = [r[0] * x_ball[0], r[1] * x_ball[1]];
335        let rotated = [
336            rotation_matrix[0][0] * scaled[0] + rotation_matrix[0][1] * scaled[1],
337            rotation_matrix[1][0] * scaled[0] + rotation_matrix[1][1] * scaled[1],
338        ];
339
340        [rotated[0] + x_center[0], rotated[1] + x_center[1]]
341    }
342
343    fn sample_unit_ball(&self) -> [f64; 2] {
344        let mut rng = rand::rng();
345        let a: f64 = rng.random();
346        let b: f64 = rng.random();
347
348        Self::sample_unit_ball_from_uniforms(a, b)
349    }
350
351    fn sample_unit_ball_from_uniforms(a: f64, b: f64) -> [f64; 2] {
352        let (a, b) = if b < a { (b, a) } else { (a, b) };
353        let sample = (b * (2.0 * PI * a / b).cos(), b * (2.0 * PI * a / b).sin());
354        [sample.0, sample.1]
355    }
356
357    fn sample_free_space(&self) -> [f64; 2] {
358        let mut rng = rand::rng();
359        if rng.random_range(0..=100) > self.goal_sample_rate {
360            [
361                rng.random_range(self.min_rand..=self.max_rand),
362                rng.random_range(self.min_rand..=self.max_rand),
363            ]
364        } else {
365            [self.goal.x, self.goal.y]
366        }
367    }
368
369    fn get_path_len(&self, path: &[[f64; 2]]) -> f64 {
370        let mut path_len = 0.0;
371        for i in 1..path.len() {
372            let dx = path[i][0] - path[i - 1][0];
373            let dy = path[i][1] - path[i - 1][1];
374            path_len += (dx * dx + dy * dy).sqrt();
375        }
376        path_len
377    }
378
379    fn line_cost(&self, node1: &Node, node2: &Node) -> f64 {
380        ((node1.x - node2.x).powi(2) + (node1.y - node2.y).powi(2)).sqrt()
381    }
382
383    fn get_nearest_list_index(&self, rnd: &[f64; 2]) -> usize {
384        let mut min_dist = f64::INFINITY;
385        let mut min_index = 0;
386
387        for (i, node) in self.node_list.iter().enumerate() {
388            let dist = (node.x - rnd[0]).powi(2) + (node.y - rnd[1]).powi(2);
389            if dist < min_dist {
390                min_dist = dist;
391                min_index = i;
392            }
393        }
394
395        min_index
396    }
397
398    fn get_new_node(&self, theta: f64, n_ind: usize, nearest_node: &Node) -> Node {
399        let mut new_node = nearest_node.clone();
400        new_node.x += self.expand_dis * theta.cos();
401        new_node.y += self.expand_dis * theta.sin();
402        new_node.cost += self.expand_dis;
403        new_node.parent = Some(n_ind);
404        new_node
405    }
406
407    fn is_near_goal(&self, node: &Node) -> bool {
408        let d = self.line_cost(node, &self.goal);
409        d < self.expand_dis
410    }
411
412    fn rewire(&mut self, new_node_index: usize, near_inds: &[usize]) {
413        for &i in near_inds {
414            let near_node = &self.node_list[i];
415            let new_node = &self.node_list[new_node_index];
416
417            let d =
418                ((near_node.x - new_node.x).powi(2) + (near_node.y - new_node.y).powi(2)).sqrt();
419            let s_cost = new_node.cost + d;
420
421            if near_node.cost > s_cost {
422                let theta = (new_node.y - near_node.y).atan2(new_node.x - near_node.x);
423                if self.check_collision(near_node, theta, d) {
424                    self.node_list[i].parent = Some(new_node_index);
425                    self.node_list[i].cost = s_cost;
426                }
427            }
428        }
429    }
430
431    fn distance_squared_point_to_segment(&self, v: [f64; 2], w: [f64; 2], p: [f64; 2]) -> f64 {
432        if v[0] == w[0] && v[1] == w[1] {
433            return (p[0] - v[0]).powi(2) + (p[1] - v[1]).powi(2);
434        }
435
436        let l2 = (w[0] - v[0]).powi(2) + (w[1] - v[1]).powi(2);
437        let t =
438            (((p[0] - v[0]) * (w[0] - v[0]) + (p[1] - v[1]) * (w[1] - v[1])) / l2).clamp(0.0, 1.0);
439        let projection = [v[0] + t * (w[0] - v[0]), v[1] + t * (w[1] - v[1])];
440        (p[0] - projection[0]).powi(2) + (p[1] - projection[1]).powi(2)
441    }
442
443    fn check_segment_collision(&self, x1: f64, y1: f64, x2: f64, y2: f64) -> bool {
444        for &(ox, oy, size) in &self.obstacle_list {
445            let dd = self.distance_squared_point_to_segment([x1, y1], [x2, y2], [ox, oy]);
446            if dd <= size * size {
447                return false;
448            }
449        }
450        true
451    }
452
453    fn check_collision(&self, near_node: &Node, theta: f64, d: f64) -> bool {
454        let end_x = near_node.x + theta.cos() * d;
455        let end_y = near_node.y + theta.sin() * d;
456        self.check_segment_collision(near_node.x, near_node.y, end_x, end_y)
457    }
458
459    fn get_final_course(&self, last_index: usize) -> Vec<[f64; 2]> {
460        let mut path = vec![[self.goal.x, self.goal.y]];
461        let mut current_index = last_index;
462
463        while let Some(parent_index) = self.node_list[current_index].parent {
464            let node = &self.node_list[current_index];
465            path.push([node.x, node.y]);
466            current_index = parent_index;
467        }
468
469        path.push([self.start.x, self.start.y]);
470        path
471    }
472
473    /// Plan a path from the given start to goal, returning a [`Path2D`].
474    ///
475    /// This is a convenience wrapper around [`planning()`](Self::planning) that accepts
476    /// [`Point2D`], sets the start/goal, runs the planner, and returns [`Path2D`].
477    pub fn plan_from(&mut self, start: Point2D, goal: Point2D) -> RoboticsResult<Path2D> {
478        self.start = Node::new(start.x, start.y);
479        self.goal = Node::new(goal.x, goal.y);
480
481        self.planning()
482            .map(|raw_path| {
483                Path2D::from_points(
484                    raw_path
485                        .into_iter()
486                        .rev()
487                        .map(|p| Point2D::new(p[0], p[1]))
488                        .collect(),
489                )
490            })
491            .ok_or_else(|| {
492                RoboticsError::PlanningError(
493                    "BatchInformedRRT*: Cannot find path within max batches".to_string(),
494                )
495            })
496    }
497
498    /// Get the tree nodes for external inspection.
499    pub fn get_tree(&self) -> &[Node] {
500        &self.node_list
501    }
502
503    /// Get the obstacle list.
504    pub fn get_obstacles(&self) -> &[(f64, f64, f64)] {
505        &self.obstacle_list
506    }
507}
508
#[cfg(test)]
impl BatchInformedRRTStar {
    /// Test-only entry point that forwards directly to `prune_nodes`.
    fn prune_nodes_for_test(
        &mut self,
        c_best: f64,
        c_min: f64,
        x_center: &[f64; 2],
        rotation_matrix: &[[f64; 2]; 2],
    ) {
        Self::prune_nodes(self, c_best, c_min, x_center, rotation_matrix)
    }
}
521
522#[cfg(test)]
523mod tests {
524    use super::*;
525
526    fn assert_close(actual: f64, expected: f64) {
527        assert!(
528            (actual - expected).abs() < 1.0e-12,
529            "expected {expected}, got {actual}"
530        );
531    }
532
533    fn default_obstacles() -> Vec<(f64, f64, f64)> {
534        vec![
535            (5.0, 5.0, 0.5),
536            (9.0, 6.0, 1.0),
537            (7.0, 5.0, 1.0),
538            (1.0, 5.0, 1.0),
539            (3.0, 6.0, 1.0),
540            (7.0, 9.0, 1.0),
541        ]
542    }
543
544    fn create_planner(
545        obstacles: Vec<(f64, f64, f64)>,
546        batch_size: usize,
547        max_batches: usize,
548    ) -> BatchInformedRRTStar {
549        BatchInformedRRTStar::new(
550            (0.0, 0.0),
551            (5.0, 10.0),
552            BatchInformedRRTStarConfig {
553                batch_size,
554                max_batches,
555                expand_dis: 0.5,
556                goal_sample_rate: 10,
557                obstacle_list: obstacles,
558                rand_area: (-2.0, 15.0),
559            },
560        )
561    }
562
563    #[test]
564    fn test_batch_informed_rrt_star_config() {
565        let planner = create_planner(vec![(5.0, 5.0, 0.5)], 50, 10);
566        assert_eq!(planner.expand_dis, 0.5);
567        assert_eq!(planner.batch_size, 50);
568        assert_eq!(planner.max_batches, 10);
569    }
570
571    #[test]
572    fn test_sampling_frame_matches_informed_rrt_star() {
573        let planner = create_planner(default_obstacles(), 50, 10);
574        let (c_min, x_center, rotation_matrix) = planner.sampling_frame();
575
576        assert_close(c_min, 11.180_339_887_498_949);
577        assert_close(x_center[0], 2.5);
578        assert_close(x_center[1], 5.0);
579        assert_close(rotation_matrix[0][0], 0.447_213_595_499_958);
580        assert_close(rotation_matrix[0][1], -0.894_427_190_999_916);
581        assert_close(rotation_matrix[1][0], 0.894_427_190_999_916);
582        assert_close(rotation_matrix[1][1], 0.447_213_595_499_958);
583    }
584
585    #[test]
586    fn test_sample_unit_ball() {
587        let expected = [
588            ((0.2, 0.8), [0.0, 0.8]),
589            ((0.1, 0.4), [0.0, 0.4]),
590            ((0.33, 0.9), [-0.602_217_545_722_972, 0.668_830_342_929_655]),
591        ];
592
593        for ((a, b), xy) in expected {
594            let sample = BatchInformedRRTStar::sample_unit_ball_from_uniforms(a, b);
595            assert_close(sample[0], xy[0]);
596            assert_close(sample[1], xy[1]);
597        }
598    }
599
600    #[test]
601    fn test_informed_ellipse_sampling() {
602        let planner = create_planner(default_obstacles(), 50, 10);
603        let (c_min, x_center, rotation_matrix) = planner.sampling_frame();
604        let unit_ball_samples = [
605            [0.0, 0.8],
606            [0.0, 0.4],
607            [-0.602_217_545_722_972, 0.668_830_342_929_655],
608        ];
609        let expected = [
610            (
611                12.0,
612                [
613                    [0.940_512_904_830_566, 5.779_743_547_584_717],
614                    [1.720_256_452_415_283, 5.389_871_773_792_358],
615                    [-0.419_709_604_196_265, 2.420_056_693_659_169],
616                ],
617            ),
618            (
619                14.0,
620                [
621                    [-0.514_630_989_026_684, 6.507_315_494_513_342],
622                    [0.992_684_505_486_658, 5.753_657_747_256_671],
623                    [-1.905_584_965_017_868, 2.489_694_689_330_144],
624                ],
625            ),
626        ];
627
628        for (c_best, xy_samples) in expected {
629            for (x_ball, xy) in unit_ball_samples.iter().zip(xy_samples.iter()) {
630                let sample = planner.informed_sample_from_unit_ball(
631                    c_best,
632                    c_min,
633                    &x_center,
634                    &rotation_matrix,
635                    *x_ball,
636                );
637                assert_close(sample[0], xy[0]);
638                assert_close(sample[1], xy[1]);
639            }
640        }
641    }
642
    /// Replays a fixed sample sequence through one batch and checks selected
    /// tree nodes (position, cost, parent) against recorded values.
    #[test]
    fn test_seeded_single_batch_matches_informed_rrt_star_logic() {
        // Use deterministic samples to verify the tree-building logic is correct.
        // A single batch with N samples should behave identically to InformedRRTStar
        // with N iterations (no pruning occurs because c_best starts infinite within
        // the first batch).
        let mut planner = create_planner(default_obstacles(), 20, 1);
        // One entry per sample of the batch; `[5.0, 10.0]` is the goal itself.
        let samples = [
            [10.455_649_682_677_358, 11.942_970_283_541_907],
            [12.537_351_811_401_535, 13.840_339_744_778_298],
            [7.622_138_868_390_643, 0.748_693_006_799_259],
            [12.860_963_734_685_752, 2.438_529_994_751_022],
            [0.963_840_532_303_441, 7.404_758_454_678_607],
            [10.548_525_179_081_42, 10.458_784_524_738_785],
            [14.636_879_924_696_97, 5.006_029_679_968_117],
            [0.831_739_168_751_872, 1.497_141_169_178_794],
            [5.0, 10.0],
            [12.230_194_783_259_385, 3.474_659_848_212_157],
            [3.771_802_130_764_652, 14.447_201_800_957_814],
            [10.657_010_682_397_589, -1.941_271_617_163_375],
            [12.803_038_468_843_097, 11.104_183_754_785_783],
            [2.871_774_581_756_395, 9.093_185_216_538_995],
            [13.054_119_941_473_152, 7.827_462_465_227_494],
            [4.375_040_433_962_334, 14.322_895_906_129_173],
            [10.059_569_437_230_238, 12.022_177_097_262_272],
            [5.0, 10.0],
            [0.036_638_892_026_463, 6.486_823_094_475_769],
            [11.632_287_039_256_342, 7.436_450_081_952_417],
        ];
        let mut sample_index = 0_usize;

        // The sampler ignores all planner state and hands out the next fixed sample.
        let path = planner.planning_with_sampler(|_, _, _, _, _| {
            let sample = *samples
                .get(sample_index)
                .expect("sample sequence exhausted");
            sample_index += 1;
            sample
        });

        assert!(path.is_none());
        // Every sample must have been consumed exactly once.
        assert_eq!(sample_index, samples.len());
        // Tree should have 20 nodes (same as InformedRRTStar with 20 iterations)
        assert_eq!(planner.node_list.len(), 20);

        // (tree index, [x, y], cost, parent) for a handful of spot-checked nodes.
        let expected_nodes = [
            (
                1,
                [0.329_351_320_180_549, 0.376_201_685_130_901],
                0.5,
                Some(0),
            ),
            (
                2,
                [0.665_203_546_734_372, 0.746_611_298_827_467],
                0.999_962_094_343_993,
                Some(0),
            ),
            (
                5,
                [1.607_494_092_207_533, 1.315_569_836_427_121],
                2.077_200_339_639_631,
                Some(0),
            ),
            (
                10,
                [3.085_807_934_983_019, 2.339_074_098_076_378],
                3.872_141_300_094_301,
                Some(0),
            ),
            (
                15,
                [3.900_997_195_468_872, 3.858_949_080_306_712],
                5.487_191_187_069_758,
                Some(0),
            ),
        ];

        for (index, xy, cost, parent) in expected_nodes {
            let node = &planner.node_list[index];
            assert_close(node.x, xy[0]);
            assert_close(node.y, xy[1]);
            assert_close(node.cost, cost);
            assert_eq!(node.parent, parent);
        }
    }
728
729    #[test]
730    fn test_finds_path_open_space() {
731        // With no obstacles and a large expand_dis, a path should be found quickly.
732        let mut planner = BatchInformedRRTStar::new(
733            (0.0, 0.0),
734            (5.0, 10.0),
735            BatchInformedRRTStarConfig {
736                batch_size: 100,
737                max_batches: 10,
738                expand_dis: 5.0_f64.hypot(10.0),
739                goal_sample_rate: 100,
740                obstacle_list: vec![],
741                rand_area: (-2.0, 15.0),
742            },
743        );
744
745        let path = planner.planning();
746        assert!(path.is_some(), "Should find a path in open space");
747
748        let path = path.unwrap();
749        assert!(path.len() >= 2);
750        // Path starts at goal, ends at start (goal-to-start order).
751        let first = path.first().unwrap();
752        let last = path.last().unwrap();
753        assert_close(first[0], 5.0);
754        assert_close(first[1], 10.0);
755        assert_close(last[0], 0.0);
756        assert_close(last[1], 0.0);
757    }
758
759    #[test]
760    fn test_finds_path_with_obstacles() {
761        let mut planner = create_planner(default_obstacles(), 200, 20);
762
763        let path = planner.planning();
764        assert!(
765            path.is_some(),
766            "Should find a path around obstacles with enough samples"
767        );
768
769        let path = path.unwrap();
770        // Verify the path is collision-free.
771        for window in path.windows(2) {
772            let (x1, y1) = (window[0][0], window[0][1]);
773            let (x2, y2) = (window[1][0], window[1][1]);
774            assert!(
775                planner.check_segment_collision(x1, y1, x2, y2),
776                "Path segment ({},{})--({},{}) should be collision-free",
777                x1,
778                y1,
779                x2,
780                y2
781            );
782        }
783    }
784
785    #[test]
786    fn test_plan_from_returns_path2d() {
787        let mut planner = BatchInformedRRTStar::new(
788            (0.0, 0.0),
789            (5.0, 10.0),
790            BatchInformedRRTStarConfig {
791                batch_size: 200,
792                max_batches: 20,
793                expand_dis: 0.5,
794                goal_sample_rate: 10,
795                obstacle_list: vec![],
796                rand_area: (-2.0, 15.0),
797            },
798        );
799
800        let result = planner.plan_from(Point2D::new(0.0, 0.0), Point2D::new(5.0, 10.0));
801        assert!(result.is_ok(), "plan_from should succeed in open space");
802        let path = result.unwrap();
803        assert!(path.points.len() >= 2);
804        // Path2D is start-to-goal order (reversed from raw).
805        assert_close(path.points.first().unwrap().x, 0.0);
806        assert_close(path.points.first().unwrap().y, 0.0);
807        assert_close(path.points.last().unwrap().x, 5.0);
808        assert_close(path.points.last().unwrap().y, 10.0);
809    }
810
811    #[test]
812    fn test_batch_processing_improves_or_maintains_cost() {
813        // Run with 1 large batch vs. multiple smaller batches. The multi-batch
814        // version benefits from pruning and tighter ellipsoidal sampling.
815        let obstacles = default_obstacles();
816
817        let mut single_batch_costs = Vec::new();
818        let mut multi_batch_costs = Vec::new();
819
820        for _ in 0..5 {
821            let mut p1 = create_planner(obstacles.clone(), 400, 1);
822            if let Some(path) = p1.planning() {
823                single_batch_costs.push(p1.get_path_len(&path));
824            }
825
826            let mut p2 = create_planner(obstacles.clone(), 100, 4);
827            if let Some(path) = p2.planning() {
828                multi_batch_costs.push(p2.get_path_len(&path));
829            }
830        }
831
832        // Both configurations should find paths at least some of the time.
833        assert!(
834            !single_batch_costs.is_empty() || !multi_batch_costs.is_empty(),
835            "At least one configuration should find a path"
836        );
837    }
838
839    #[test]
840    fn test_prune_nodes_removes_dominated_leaf() {
841        // Deterministic check: a leaf whose cost + straight-line-to-goal exceeds
842        // c_best must be removed (see `prune_nodes`). Comparing two stochastic
843        // planning runs was flaky across CI runners due to `thread_rng()`.
844        let mut planner = create_planner(vec![], 10, 1);
845        planner.node_list = vec![
846            Node {
847                x: 0.0,
848                y: 0.0,
849                cost: 0.0,
850                parent: None,
851            },
852            Node {
853                x: 100.0,
854                y: 100.0,
855                cost: 5.0,
856                parent: Some(0),
857            },
858        ];
859        let (c_min, x_center, rotation_matrix) = planner.sampling_frame();
860        let c_best = 15.0;
861        assert!(
862            planner.node_list[1].cost + planner.dist_to_goal(&planner.node_list[1]) > c_best,
863            "fixture leaf should be prunable"
864        );
865        planner.prune_nodes_for_test(c_best, c_min, &x_center, &rotation_matrix);
866        assert_eq!(planner.node_list.len(), 1);
867    }
868}