Skip to main content

rust_robotics_planning/
closed_loop_rrt_star.rs

1#![allow(dead_code, clippy::too_many_arguments)]
2
3//! Closed-Loop RRT* (CL-RRT*) path planner
4//!
5//! Extends RRT\*-Reeds-Shepp by adding a forward-simulation verification step.
6//! After the tree is grown, candidate goal paths are tracked using a pure-pursuit
7//! controller on a unicycle (bicycle-kinematic) model. Only paths that are
8//! dynamically feasible --- the controller reaches the goal, the tracked path
9//! does not collide, and the travel distance is reasonable --- are accepted.
10//!
11//! The planner returns the full simulated trajectory (states over time) of the
12//! best feasible path, not just the geometric waypoints.
13//!
14//! # References
15//!
16//! * Kuwata, Y. et al. (2009). "Real-Time Motion Planning With Applications
17//!   to Autonomous Urban Driving." *IEEE T-CST*.
18//! * PythonRobotics `ClosedLoopRRTStar/` by Atsushi Sakai.
19
20use std::f64::consts::PI;
21
22use rust_robotics_core::types::Pose2D;
23
24use crate::rrt::{AreaBounds, CircleObstacle};
25use crate::rrt_star_reeds_shepp::{RRTStarRSConfig, RRTStarRSNode, RRTStarRSPlanner};
26
27// ---------------------------------------------------------------------------
28// Unicycle (bicycle-kinematic) model
29// ---------------------------------------------------------------------------
30
/// Tunable parameters of the kinematic bicycle model used for forward
/// simulation (referred to throughout this module as the "unicycle" model).
#[derive(Debug, Clone)]
pub struct UnicycleParams {
    /// Integration step of the simulation \[s\].
    pub dt: f64,
    /// Distance between the front and rear axles \[m\].
    pub wheelbase: f64,
    /// Largest admissible steering-angle magnitude \[rad\].
    pub steer_max: f64,
    /// Largest admissible longitudinal-acceleration magnitude \[m/s^2\].
    pub accel_max: f64,
}

impl Default for UnicycleParams {
    /// 50 ms step, 0.9 m wheelbase, 40 degree steering limit and a
    /// 5 m/s^2 acceleration limit.
    fn default() -> Self {
        let steer_limit_deg: f64 = 40.0;
        Self {
            accel_max: 5.0,
            steer_max: steer_limit_deg.to_radians(),
            wheelbase: 0.9,
            dt: 0.05,
        }
    }
}
54
/// Instantaneous kinematic state of the simulated vehicle.
#[derive(Debug, Clone, Copy)]
pub struct VehicleState {
    /// Position along the world x-axis \[m\].
    pub x: f64,
    /// Position along the world y-axis \[m\].
    pub y: f64,
    /// Heading angle \[rad\].
    pub yaw: f64,
    /// Signed longitudinal speed \[m/s\]; negative values drive backwards.
    pub v: f64,
}

impl VehicleState {
    /// Builds a state from a position, a heading and a speed.
    pub fn new(x: f64, y: f64, yaw: f64, v: f64) -> Self {
        VehicleState { x, y, yaw, v }
    }
}
69
70/// Advance the vehicle state by one time step.
71fn unicycle_update(
72    state: &VehicleState,
73    accel: f64,
74    delta: f64,
75    params: &UnicycleParams,
76) -> VehicleState {
77    let x = state.x + state.v * state.yaw.cos() * params.dt;
78    let y = state.y + state.v * state.yaw.sin() * params.dt;
79    let yaw = pi_2_pi(state.yaw + state.v / params.wheelbase * delta.tan() * params.dt);
80    let v = state.v + accel * params.dt;
81    VehicleState { x, y, yaw, v }
82}
83
84// ---------------------------------------------------------------------------
85// Pure-pursuit controller
86// ---------------------------------------------------------------------------
87
/// Tuning knobs for the pure-pursuit tracker and its proportional speed loop.
#[derive(Debug, Clone)]
pub struct PurePursuitParams {
    /// Gain of the proportional speed controller.
    pub kp: f64,
    /// Arc length ahead along the reference path used to choose the
    /// pursuit target \[m\].
    pub look_ahead: f64,
    /// Simulation horizon; the rollout stops once this much time has
    /// elapsed \[s\].
    pub max_time: f64,
    /// Radius around the goal position that counts as "reached" \[m\].
    pub goal_dis: f64,
    /// Speed magnitude at or below which the vehicle is treated as stopped;
    /// also used as the commanded speed at the very end of the path \[m/s\].
    pub stop_speed: f64,
}

impl Default for PurePursuitParams {
    fn default() -> Self {
        Self {
            look_ahead: 0.5,
            goal_dis: 0.5,
            stop_speed: 0.5,
            kp: 2.0,
            max_time: 100.0,
        }
    }
}
114
/// Time-stamped trajectory produced by a closed-loop rollout.
///
/// The vectors are parallel: index `i` of every vector describes sample `i`.
#[derive(Debug, Clone)]
pub struct SimulationResult {
    /// Sample times \[s\].
    pub t: Vec<f64>,
    /// x positions \[m\].
    pub x: Vec<f64>,
    /// y positions \[m\].
    pub y: Vec<f64>,
    /// Headings \[rad\].
    pub yaw: Vec<f64>,
    /// Signed speeds \[m/s\].
    pub v: Vec<f64>,
    /// Commanded longitudinal accelerations \[m/s^2\].
    pub accel: Vec<f64>,
    /// Commanded steering angles \[rad\].
    pub steer: Vec<f64>,
    /// `true` if the rollout ended inside the goal radius.
    pub reached_goal: bool,
}
127
/// Proportional speed controller: returns `kp * (target - current)`
/// saturated to `[-accel_max, accel_max]`.
///
/// # Panics
///
/// Panics if `accel_max` is negative or NaN (a requirement of
/// [`f64::clamp`]).
fn pid_control(target: f64, current: f64, kp: f64, accel_max: f64) -> f64 {
    let speed_error = target - current;
    let unsaturated = kp * speed_error;
    unsaturated.clamp(-accel_max, accel_max)
}
132
133/// Find the target index on the reference path for pure pursuit.
134fn calc_target_index(
135    state: &VehicleState,
136    cx: &[f64],
137    cy: &[f64],
138    look_ahead: f64,
139) -> (usize, f64) {
140    let mut min_dist = f64::INFINITY;
141    let mut min_ind = 0;
142    for (i, (&rx, &ry)) in cx.iter().zip(cy.iter()).enumerate() {
143        let d = ((state.x - rx).powi(2) + (state.y - ry).powi(2)).sqrt();
144        if d < min_dist {
145            min_dist = d;
146            min_ind = i;
147        }
148    }
149
150    let mut cumulative = 0.0;
151    while look_ahead > cumulative && (min_ind + 1) < cx.len() {
152        let dx = cx[min_ind + 1] - cx[min_ind];
153        let dy = cy[min_ind + 1] - cy[min_ind];
154        cumulative += (dx * dx + dy * dy).sqrt();
155        min_ind += 1;
156    }
157
158    (min_ind, min_dist)
159}
160
161/// Pure pursuit steering control.
162fn pure_pursuit_control(
163    state: &VehicleState,
164    cx: &[f64],
165    cy: &[f64],
166    prev_ind: usize,
167    look_ahead: f64,
168    wheelbase: f64,
169    steer_max: f64,
170) -> (f64, usize, f64) {
171    let (mut ind, dis) = calc_target_index(state, cx, cy, look_ahead);
172    if prev_ind >= ind {
173        ind = prev_ind;
174    }
175
176    let (tx, ty) = if ind < cx.len() {
177        (cx[ind], cy[ind])
178    } else {
179        ind = cx.len() - 1;
180        (cx[ind], cy[ind])
181    };
182
183    let mut alpha = (ty - state.y).atan2(tx - state.x) - state.yaw;
184    if state.v <= 0.0 {
185        alpha = PI - alpha;
186    }
187
188    let delta = (2.0 * wheelbase * alpha.sin() / look_ahead).atan2(1.0);
189    let delta = delta.clamp(-steer_max, steer_max);
190
191    (delta, ind, dis)
192}
193
194/// Build a speed profile along the reference path, inserting stop points at
195/// direction reversals (forward <-> backward).
196fn calc_speed_profile(
197    cx: &[f64],
198    cy: &[f64],
199    cyaw: &[f64],
200    target_speed: f64,
201    stop_speed: f64,
202) -> Vec<f64> {
203    let n = cx.len();
204    let mut profile = vec![target_speed; n];
205    let mut forward = true;
206    let mut is_back = false;
207
208    for i in 0..n - 1 {
209        let dx = cx[i + 1] - cx[i];
210        let dy = cy[i + 1] - cy[i];
211        let move_dir = dy.atan2(dx);
212        is_back = (move_dir - cyaw[i]).abs() >= PI / 2.0;
213
214        if dx == 0.0 && dy == 0.0 {
215            continue;
216        }
217
218        if is_back {
219            profile[i] = -target_speed;
220        } else {
221            profile[i] = target_speed;
222        }
223
224        if is_back && forward {
225            profile[i] = 0.0;
226            forward = false;
227        } else if !is_back && !forward {
228            profile[i] = 0.0;
229            forward = true;
230        }
231    }
232
233    profile[0] = 0.0;
234    if is_back {
235        profile[n - 1] = -stop_speed;
236    } else {
237        profile[n - 1] = stop_speed;
238    }
239
240    profile
241}
242
243/// Extend the path beyond its end by `look_ahead` distance so the tracker
244/// does not overshoot.
245fn extend_path(
246    cx: &[f64],
247    cy: &[f64],
248    cyaw: &[f64],
249    look_ahead: f64,
250) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
251    let mut cx = cx.to_vec();
252    let mut cy = cy.to_vec();
253    let mut cyaw = cyaw.to_vec();
254
255    let dl = 0.1_f64;
256    let steps = (look_ahead / dl) as usize + 1;
257
258    let n = cx.len();
259    let move_dir = (cy[n - 1] - cy[n.saturating_sub(3)]).atan2(cx[n - 1] - cx[n.saturating_sub(3)]);
260    let is_back = (move_dir - cyaw[n - 1]).abs() >= PI / 2.0;
261
262    for _ in 0..steps {
263        let idl = if is_back { -dl } else { dl };
264        let last_yaw = *cyaw.last().unwrap();
265        cx.push(cx.last().unwrap() + idl * last_yaw.cos());
266        cy.push(cy.last().unwrap() + idl * last_yaw.sin());
267        cyaw.push(last_yaw);
268    }
269
270    (cx, cy, cyaw)
271}
272
273/// Run the closed-loop pure-pursuit simulation on a reference path.
274fn closed_loop_prediction(
275    cx: &[f64],
276    cy: &[f64],
277    _cyaw: &[f64],
278    speed_profile: &[f64],
279    goal: [f64; 3],
280    vehicle: &UnicycleParams,
281    pp: &PurePursuitParams,
282) -> SimulationResult {
283    let mut state = VehicleState::new(0.0, 0.0, 0.0, 0.0);
284    let mut time = 0.0;
285
286    let mut result = SimulationResult {
287        t: vec![0.0],
288        x: vec![state.x],
289        y: vec![state.y],
290        yaw: vec![state.yaw],
291        v: vec![state.v],
292        accel: vec![0.0],
293        steer: vec![0.0],
294        reached_goal: false,
295    };
296
297    let (mut target_ind, _) = calc_target_index(&state, cx, cy, pp.look_ahead);
298    let max_dis = 0.5_f64;
299
300    while time <= pp.max_time {
301        let (di, new_ind, dis) = pure_pursuit_control(
302            &state,
303            cx,
304            cy,
305            target_ind,
306            pp.look_ahead,
307            vehicle.wheelbase,
308            vehicle.steer_max,
309        );
310        target_ind = new_ind;
311
312        let mut target_speed = speed_profile[target_ind.min(speed_profile.len() - 1)];
313        target_speed *= (max_dis - dis.min(max_dis - 0.1)) / max_dis;
314
315        let ai = pid_control(target_speed, state.v, pp.kp, vehicle.accel_max);
316        state = unicycle_update(&state, ai, di, vehicle);
317
318        if state.v.abs() <= pp.stop_speed && target_ind <= cx.len().saturating_sub(2) {
319            target_ind += 1;
320        }
321
322        time += vehicle.dt;
323
324        // Check goal
325        let dx = state.x - goal[0];
326        let dy = state.y - goal[1];
327        if (dx * dx + dy * dy).sqrt() <= pp.goal_dis {
328            result.reached_goal = true;
329            break;
330        }
331
332        result.t.push(time);
333        result.x.push(state.x);
334        result.y.push(state.y);
335        result.yaw.push(state.yaw);
336        result.v.push(state.v);
337        result.accel.push(ai);
338        result.steer.push(di);
339    }
340
341    result
342}
343
344// ---------------------------------------------------------------------------
345// Closed-Loop RRT* planner
346// ---------------------------------------------------------------------------
347
348/// Configuration for the Closed-Loop RRT* planner.
349#[derive(Debug, Clone)]
350pub struct ClosedLoopRRTStarConfig {
351    /// Underlying RRT*-Reeds-Shepp configuration.
352    pub rrt_config: RRTStarRSConfig,
353    /// Vehicle (unicycle) model parameters.
354    pub vehicle: UnicycleParams,
355    /// Pure-pursuit tracking parameters.
356    pub pursuit: PurePursuitParams,
357    /// Target speed for the forward simulation \[m/s\].
358    pub target_speed: f64,
359    /// Yaw threshold for checking goal candidates from the tree \[rad\].
360    pub yaw_threshold: f64,
361    /// Position threshold for checking goal candidates from the tree \[m\].
362    pub xy_threshold: f64,
363    /// If `tracked_travel / geometric_travel` exceeds this ratio, the path is
364    /// rejected as inefficient.
365    pub invalid_travel_ratio: f64,
366    /// Yaw tolerance for the final simulated heading \[rad\].
367    pub final_yaw_tolerance: f64,
368}
369
370impl Default for ClosedLoopRRTStarConfig {
371    fn default() -> Self {
372        Self {
373            rrt_config: RRTStarRSConfig::default(),
374            vehicle: UnicycleParams::default(),
375            pursuit: PurePursuitParams::default(),
376            target_speed: 10.0 / 3.6,
377            yaw_threshold: 3.0_f64.to_radians(),
378            xy_threshold: 0.5,
379            invalid_travel_ratio: 5.0,
380            final_yaw_tolerance: 30.0_f64.to_radians(),
381        }
382    }
383}
384
385/// Result of a successful CL-RRT* planning run.
386#[derive(Debug, Clone)]
387pub struct ClosedLoopRRTStarResult {
388    /// Simulated trajectory.
389    pub sim: SimulationResult,
390    /// The geometric path from the RRT* tree (Reeds-Shepp waypoints).
391    pub geometric_poses: Vec<Pose2D>,
392}
393
394/// Closed-Loop RRT* planner.
395///
396/// Internally runs [`RRTStarRSPlanner`] to build the tree, then evaluates all
397/// candidate goal paths with a closed-loop forward simulation.
398pub struct ClosedLoopRRTStarPlanner {
399    config: ClosedLoopRRTStarConfig,
400    obstacles: Vec<CircleObstacle>,
401    rand_area: AreaBounds,
402    inner: RRTStarRSPlanner,
403}
404
405impl ClosedLoopRRTStarPlanner {
406    /// Create a new CL-RRT* planner.
407    pub fn new(
408        obstacles: Vec<CircleObstacle>,
409        rand_area: AreaBounds,
410        config: ClosedLoopRRTStarConfig,
411    ) -> Self {
412        let inner = RRTStarRSPlanner::new(
413            obstacles.clone(),
414            rand_area.clone(),
415            config.rrt_config.clone(),
416        );
417        Self {
418            config,
419            obstacles,
420            rand_area,
421            inner,
422        }
423    }
424
425    /// Plan from `start` to `goal`, returning the best feasible trajectory.
426    ///
427    /// Returns `None` if no dynamically feasible path is found.
428    pub fn planning(&mut self, start: Pose2D, goal: Pose2D) -> Option<ClosedLoopRRTStarResult> {
429        // Phase 1: grow the RRT*-Reeds-Shepp tree.
430        let _ = self.inner.planning(start, goal);
431        self.select_best_feasible(goal)
432    }
433
434    /// Plan using a deterministic sampler (for testing).
435    pub fn plan_with_sampler<F>(
436        &mut self,
437        start: Pose2D,
438        goal: Pose2D,
439        sample_node: F,
440    ) -> Option<ClosedLoopRRTStarResult>
441    where
442        F: FnMut(&RRTStarRSPlanner) -> RRTStarRSNode,
443    {
444        let _ = self.inner.plan_with_sampler(start, goal, sample_node);
445        self.select_best_feasible(goal)
446    }
447
448    /// Access the internal RRT* tree.
449    pub fn get_tree(&self) -> &[RRTStarRSNode] {
450        self.inner.get_tree()
451    }
452
453    // -----------------------------------------------------------------------
454    // Private
455    // -----------------------------------------------------------------------
456
457    /// After the tree has been built, find all goal-candidate nodes and return
458    /// the one whose closed-loop simulation has the shortest travel time.
459    fn select_best_feasible(&self, goal: Pose2D) -> Option<ClosedLoopRRTStarResult> {
460        let tree = self.inner.get_tree();
461        let goal_inds = self.get_goal_indexes(tree, &goal);
462
463        let mut best_time = f64::INFINITY;
464        let mut best_result: Option<ClosedLoopRRTStarResult> = None;
465
466        for &ind in &goal_inds {
467            let path = self.generate_final_course(tree, ind);
468            if path.is_empty() || path.len() < 2 {
469                continue;
470            }
471
472            let feasibility = self.check_tracking_feasible(&path, &goal);
473            if let Some(sim) = feasibility {
474                let end_time = *sim.t.last().unwrap_or(&f64::INFINITY);
475                if end_time < best_time {
476                    best_time = end_time;
477                    let poses: Vec<Pose2D> = path
478                        .iter()
479                        .map(|&(x, y, yaw)| Pose2D::new(x, y, yaw))
480                        .collect();
481                    best_result = Some(ClosedLoopRRTStarResult {
482                        sim,
483                        geometric_poses: poses,
484                    });
485                }
486            }
487        }
488
489        best_result
490    }
491
492    /// Collect node indices that are within position and yaw tolerance of the goal.
493    fn get_goal_indexes(&self, tree: &[RRTStarRSNode], goal: &Pose2D) -> Vec<usize> {
494        let mut inds = Vec::new();
495        for (i, node) in tree.iter().enumerate() {
496            let dx = node.x - goal.x;
497            let dy = node.y - goal.y;
498            if (dx * dx + dy * dy).sqrt() > self.config.xy_threshold {
499                continue;
500            }
501            if angle_diff(node.yaw, goal.yaw).abs() > self.config.yaw_threshold {
502                continue;
503            }
504            inds.push(i);
505        }
506        inds
507    }
508
509    /// Trace back from `goal_index` to the root, returning waypoints as
510    /// `(x, y, yaw)` from start to goal.
511    fn generate_final_course(
512        &self,
513        tree: &[RRTStarRSNode],
514        goal_index: usize,
515    ) -> Vec<(f64, f64, f64)> {
516        let mut path: Vec<(f64, f64, f64)> = Vec::new();
517
518        let mut node = &tree[goal_index];
519        while node.parent.is_some() {
520            for ((&px, &py), &pyaw) in node
521                .path_x
522                .iter()
523                .rev()
524                .zip(node.path_y.iter().rev())
525                .zip(node.path_yaw.iter().rev())
526            {
527                path.push((px, py, pyaw));
528            }
529            node = &tree[node.parent.unwrap()];
530        }
531        path.push((node.x, node.y, node.yaw));
532        path.reverse();
533        path
534    }
535
536    /// Forward-simulate a path using pure pursuit and validate feasibility.
537    ///
538    /// Returns `Some(SimulationResult)` if feasible, `None` otherwise.
539    fn check_tracking_feasible(
540        &self,
541        path: &[(f64, f64, f64)],
542        goal: &Pose2D,
543    ) -> Option<SimulationResult> {
544        let cx: Vec<f64> = path.iter().map(|p| p.0).collect();
545        let cy: Vec<f64> = path.iter().map(|p| p.1).collect();
546        let cyaw: Vec<f64> = path.iter().map(|p| p.2).collect();
547
548        let goal_arr = [goal.x, goal.y, goal.yaw];
549
550        let (ecx, ecy, ecyaw) = extend_path(&cx, &cy, &cyaw, self.config.pursuit.look_ahead);
551
552        let speed_profile = calc_speed_profile(
553            &ecx,
554            &ecy,
555            &ecyaw,
556            self.config.target_speed,
557            self.config.pursuit.stop_speed,
558        );
559
560        let sim = closed_loop_prediction(
561            &ecx,
562            &ecy,
563            &ecyaw,
564            &speed_profile,
565            goal_arr,
566            &self.config.vehicle,
567            &self.config.pursuit,
568        );
569
570        if !sim.reached_goal {
571            return None;
572        }
573
574        // Final yaw check
575        if let Some(&final_yaw) = sim.yaw.last() {
576            if (pi_2_pi(final_yaw) - goal.yaw).abs() >= self.config.final_yaw_tolerance {
577                return None;
578            }
579        }
580
581        // Travel ratio check
582        let travel: f64 = sim.v.iter().map(|vi| vi.abs()).sum::<f64>() * self.config.vehicle.dt;
583        let origin_travel: f64 = path
584            .windows(2)
585            .map(|w| ((w[1].0 - w[0].0).powi(2) + (w[1].1 - w[0].1).powi(2)).sqrt())
586            .sum();
587
588        if origin_travel > 0.0 && travel / origin_travel >= self.config.invalid_travel_ratio {
589            return None;
590        }
591
592        // Collision check along simulated trajectory
593        if !self.check_sim_collision(&sim) {
594            return None;
595        }
596
597        Some(sim)
598    }
599
600    /// Check that the simulated trajectory does not collide with any obstacle.
601    fn check_sim_collision(&self, sim: &SimulationResult) -> bool {
602        for obs in &self.obstacles {
603            for (&sx, &sy) in sim.x.iter().zip(sim.y.iter()) {
604                let dx = obs.x - sx;
605                let dy = obs.y - sy;
606                let d = (dx * dx + dy * dy).sqrt();
607                if d <= obs.radius + self.config.rrt_config.robot_radius {
608                    return false;
609                }
610            }
611        }
612        true
613    }
614}
615
616// ---------------------------------------------------------------------------
617// Helpers
618// ---------------------------------------------------------------------------
619
/// Wraps an angle into the closed interval \[-pi, pi\].
///
/// The input is first reduced with the floating-point remainder `%`, whose
/// result keeps the sign of `angle`. It is then shifted by one full turn if
/// it still lies outside the interval. Exactly `±pi` is returned unchanged,
/// so both ends of the interval are reachable (the interval is *not*
/// half-open).
fn pi_2_pi(angle: f64) -> f64 {
    let full_turn = 2.0 * PI;
    let reduced = angle % full_turn;
    if reduced > PI {
        reduced - full_turn
    } else if reduced < -PI {
        reduced + full_turn
    } else {
        reduced
    }
}
631
/// Signed rotation from heading `b` to heading `a`.
///
/// Computes `a - b`, folded into \[-pi, pi\] by adding or subtracting whole
/// turns. A difference of exactly `±pi` is left as-is. The folding loops
/// once per full turn, so its cost grows with the magnitude of `a - b`.
fn angle_diff(a: f64, b: f64) -> f64 {
    let full_turn = 2.0 * PI;
    let mut delta = a - b;
    // First fold down anything above +pi, then fold up anything below -pi.
    while delta > PI {
        delta -= full_turn;
    }
    while delta < -PI {
        delta += full_turn;
    }
    delta
}
643
644// ---------------------------------------------------------------------------
645// Tests
646// ---------------------------------------------------------------------------
647
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute-tolerance float comparison used throughout these tests.
    fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() < tol
    }

    // -- Unicycle model tests --

    #[test]
    fn test_unicycle_straight_line() {
        let params = UnicycleParams::default();
        let state = VehicleState::new(0.0, 0.0, 0.0, 1.0);
        let next = unicycle_update(&state, 0.0, 0.0, &params);
        assert!(approx_eq(next.x, params.dt, 1e-12));
        assert!(approx_eq(next.y, 0.0, 1e-12));
        assert!(approx_eq(next.yaw, 0.0, 1e-12));
        assert!(approx_eq(next.v, 1.0, 1e-12));
    }

    #[test]
    fn test_unicycle_acceleration() {
        let params = UnicycleParams::default();
        let state = VehicleState::new(0.0, 0.0, 0.0, 0.0);
        let next = unicycle_update(&state, 2.0, 0.0, &params);
        assert!(approx_eq(next.v, 2.0 * params.dt, 1e-12));
        assert!(approx_eq(next.x, 0.0, 1e-12)); // v was 0 at start
    }

    #[test]
    fn test_unicycle_turning() {
        let params = UnicycleParams::default();
        let state = VehicleState::new(0.0, 0.0, 0.0, 1.0);
        let delta = 0.1; // small steering angle
        let next = unicycle_update(&state, 0.0, delta, &params);
        // Yaw should increase (turning left).
        let expected_yaw = 1.0 / params.wheelbase * delta.tan() * params.dt;
        assert!(approx_eq(next.yaw, expected_yaw, 1e-10));
    }

    // -- Pure pursuit tests --

    #[test]
    fn test_calc_target_index_nearest() {
        let cx = vec![0.0, 1.0, 2.0, 3.0, 4.0];
        let cy = vec![0.0, 0.0, 0.0, 0.0, 0.0];
        let state = VehicleState::new(0.5, 0.0, 0.0, 1.0);
        let (ind, dis) = calc_target_index(&state, &cx, &cy, 0.5);
        // Points 0 and 1 are equidistant (0.5 m); the lower index wins, and a
        // single 1 m step then covers the 0.5 m look-ahead.
        assert_eq!(ind, 1);
        assert!(approx_eq(dis, 0.5, 1e-12));
    }

    #[test]
    fn test_pid_control_clamping() {
        let a = pid_control(10.0, 0.0, 2.0, 5.0);
        assert!(approx_eq(a, 5.0, 1e-12)); // Clamped to accel_max

        let a = pid_control(-10.0, 0.0, 2.0, 5.0);
        assert!(approx_eq(a, -5.0, 1e-12)); // Clamped to -accel_max
    }

    #[test]
    fn test_pid_control_proportional() {
        let a = pid_control(1.0, 0.0, 2.0, 5.0);
        assert!(approx_eq(a, 2.0, 1e-12));
    }

    // -- Speed profile tests --

    #[test]
    fn test_speed_profile_forward() {
        let cx = vec![0.0, 1.0, 2.0, 3.0];
        let cy = vec![0.0, 0.0, 0.0, 0.0];
        let cyaw = vec![0.0, 0.0, 0.0, 0.0];
        let profile = calc_speed_profile(&cx, &cy, &cyaw, 1.0, 0.5);
        assert_eq!(profile.len(), 4);
        assert!(approx_eq(profile[0], 0.0, 1e-12)); // start is 0
        assert!(profile[1] > 0.0); // forward
        assert!(approx_eq(profile[3], 0.5, 1e-12)); // end is stop_speed
    }

    #[test]
    fn test_speed_profile_backward() {
        let cx = vec![3.0, 2.0, 1.0, 0.0];
        let cy = vec![0.0, 0.0, 0.0, 0.0];
        // Yaw pointing forward (+x) but moving backward (-x).
        let cyaw = vec![0.0, 0.0, 0.0, 0.0];
        let profile = calc_speed_profile(&cx, &cy, &cyaw, 1.0, 0.5);
        assert!(approx_eq(profile[0], 0.0, 1e-12));
        // Interior segments are driven in reverse; the end is -stop_speed.
        assert!(approx_eq(profile[1], -1.0, 1e-12));
        assert!(approx_eq(profile[2], -1.0, 1e-12));
        assert!(approx_eq(profile[3], -0.5, 1e-12));
    }

    #[test]
    fn test_speed_profile_single_point() {
        let profile = calc_speed_profile(&[1.0], &[2.0], &[0.0], 1.0, 0.5);
        assert_eq!(profile.len(), 1);
        assert!(approx_eq(profile[0], 0.5, 1e-12));
    }

    // -- Extend path tests --

    #[test]
    fn test_extend_path_length() {
        let cx = vec![0.0, 1.0, 2.0];
        let cy = vec![0.0, 0.0, 0.0];
        let cyaw = vec![0.0, 0.0, 0.0];
        let (ecx, ecy, ecyaw) = extend_path(&cx, &cy, &cyaw, 0.5);
        assert!(ecx.len() > cx.len());
        assert_eq!(ecx.len(), ecy.len());
        assert_eq!(ecx.len(), ecyaw.len());
    }

    // -- pi_2_pi tests --

    #[test]
    fn test_pi_2_pi() {
        assert!(approx_eq(pi_2_pi(0.0), 0.0, 1e-12));
        assert!(approx_eq(pi_2_pi(PI), PI, 1e-12));
        assert!(approx_eq(pi_2_pi(-PI), -PI, 1e-12));
        assert!(approx_eq(pi_2_pi(3.0 * PI), PI, 1e-10));
        assert!(approx_eq(pi_2_pi(-3.0 * PI), -PI, 1e-10));
    }

    #[test]
    fn test_angle_diff() {
        assert!(approx_eq(angle_diff(0.0, 0.0), 0.0, 1e-12));
        assert!(approx_eq(angle_diff(PI, 0.0), PI, 1e-12));
        assert!(approx_eq(angle_diff(0.0, PI), -PI, 1e-12));
    }

    #[test]
    fn test_angle_diff_multiple_turns() {
        assert!(approx_eq(angle_diff(0.25 + 4.0 * PI, 0.0), 0.25, 1e-9));
        assert!(approx_eq(angle_diff(-0.25 - 4.0 * PI, 0.0), -0.25, 1e-9));
    }

    // -- Closed-loop simulation tests --

    #[test]
    fn test_closed_loop_straight_path() {
        let cx: Vec<f64> = (0..50).map(|i| i as f64 * 0.2).collect();
        let cy = vec![0.0; cx.len()];
        let cyaw = vec![0.0; cx.len()];
        let vehicle = UnicycleParams::default();
        let pp = PurePursuitParams {
            max_time: 20.0,
            ..Default::default()
        };
        let goal = [cx[cx.len() - 6], cy[cy.len() - 6], 0.0]; // goal before extension point

        let (ecx, ecy, ecyaw) = extend_path(&cx, &cy, &cyaw, pp.look_ahead);
        let esp = calc_speed_profile(&ecx, &ecy, &ecyaw, 1.0, pp.stop_speed);

        let sim = closed_loop_prediction(&ecx, &ecy, &ecyaw, &esp, goal, &vehicle, &pp);
        // Should make progress in the +x direction.
        assert!(sim.x.last().unwrap() > &0.0);
        assert!(sim.t.len() > 1);
    }

    // -- Config defaults --

    #[test]
    fn test_config_defaults() {
        let config = ClosedLoopRRTStarConfig::default();
        assert!(approx_eq(config.target_speed, 10.0 / 3.6, 1e-10));
        assert!(approx_eq(config.yaw_threshold, 3.0_f64.to_radians(), 1e-10));
        assert!(approx_eq(config.xy_threshold, 0.5, 1e-12));
        assert!(approx_eq(config.invalid_travel_ratio, 5.0, 1e-12));
    }

    // -- Collision check --

    #[test]
    fn test_sim_collision_check_no_obstacles() {
        let config = ClosedLoopRRTStarConfig::default();
        let rand_area = AreaBounds::new(-2.0, 15.0, -2.0, 15.0);
        let planner = ClosedLoopRRTStarPlanner::new(vec![], rand_area, config);

        let sim = SimulationResult {
            t: vec![0.0, 0.1],
            x: vec![0.0, 1.0],
            y: vec![0.0, 0.0],
            yaw: vec![0.0, 0.0],
            v: vec![1.0, 1.0],
            accel: vec![0.0, 0.0],
            steer: vec![0.0, 0.0],
            reached_goal: true,
        };
        assert!(planner.check_sim_collision(&sim));
    }

    #[test]
    fn test_sim_collision_check_with_obstacle() {
        let obstacles = vec![CircleObstacle::new(0.5, 0.0, 0.3)];
        let config = ClosedLoopRRTStarConfig::default();
        let rand_area = AreaBounds::new(-2.0, 15.0, -2.0, 15.0);
        let planner = ClosedLoopRRTStarPlanner::new(obstacles, rand_area, config);

        let sim = SimulationResult {
            t: vec![0.0, 0.1],
            x: vec![0.0, 0.5],
            y: vec![0.0, 0.0],
            yaw: vec![0.0, 0.0],
            v: vec![1.0, 1.0],
            accel: vec![0.0, 0.0],
            steer: vec![0.0, 0.0],
            reached_goal: true,
        };
        assert!(!planner.check_sim_collision(&sim));
    }

    // -- Goal index selection --

    #[test]
    fn test_get_goal_indexes_filters_by_xy_and_yaw() {
        let config = ClosedLoopRRTStarConfig {
            xy_threshold: 1.0,
            yaw_threshold: 0.5,
            ..Default::default()
        };
        let rand_area = AreaBounds::new(-5.0, 20.0, -5.0, 20.0);
        let planner = ClosedLoopRRTStarPlanner::new(vec![], rand_area, config);

        let goal = Pose2D::new(10.0, 10.0, 0.0);
        let tree = vec![
            RRTStarRSNode::new(0.0, 0.0, 0.0),   // far from goal
            RRTStarRSNode::new(10.0, 10.0, 0.0), // at goal
            RRTStarRSNode::new(10.5, 10.5, 0.0), // near goal, good yaw
            RRTStarRSNode::new(10.0, 10.0, PI),  // at goal, bad yaw
        ];

        let inds = planner.get_goal_indexes(&tree, &goal);
        assert!(inds.contains(&1));
        assert!(inds.contains(&2));
        assert!(!inds.contains(&0));
        assert!(!inds.contains(&3));
    }

    // -- Generate final course --

    #[test]
    fn test_generate_final_course() {
        let config = ClosedLoopRRTStarConfig::default();
        let rand_area = AreaBounds::new(-5.0, 20.0, -5.0, 20.0);
        let planner = ClosedLoopRRTStarPlanner::new(vec![], rand_area, config);

        let mut root = RRTStarRSNode::new(0.0, 0.0, 0.0);
        root.parent = None;
        let mut child = RRTStarRSNode::new(5.0, 5.0, 0.5);
        child.parent = Some(0);
        child.path_x = vec![0.0, 2.5, 5.0];
        child.path_y = vec![0.0, 2.5, 5.0];
        child.path_yaw = vec![0.0, 0.25, 0.5];

        let tree = vec![root, child];
        let course = planner.generate_final_course(&tree, 1);
        assert!(course.len() >= 3);
        // First point should be the root.
        assert!(approx_eq(course[0].0, 0.0, 1e-12));
        assert!(approx_eq(course[0].1, 0.0, 1e-12));
        // Last point should be the end of the child's segment.
        let last = course[course.len() - 1];
        assert!(approx_eq(last.0, 5.0, 1e-12));
        assert!(approx_eq(last.1, 5.0, 1e-12));
    }

    // -- Integration test: planner construction --

    #[test]
    fn test_planner_creation() {
        let obstacles = vec![
            CircleObstacle::new(5.0, 5.0, 1.0),
            CircleObstacle::new(4.0, 6.0, 1.0),
        ];
        let rand_area = AreaBounds::new(-2.0, 20.0, -2.0, 20.0);
        let config = ClosedLoopRRTStarConfig::default();
        let planner = ClosedLoopRRTStarPlanner::new(obstacles, rand_area, config);
        assert!(planner.get_tree().is_empty());
    }

    // -- Integration test: planning on obstacle-free env --

    #[test]
    fn test_planning_no_obstacles_deterministic() {
        let rand_area = AreaBounds::new(-2.0, 15.0, -2.0, 15.0);
        let config = ClosedLoopRRTStarConfig {
            rrt_config: RRTStarRSConfig {
                max_iter: 300,
                goal_xy_threshold: 1.5,
                goal_yaw_threshold: 1.0,
                connect_circle_dist: 50.0,
                ..Default::default()
            },
            xy_threshold: 1.5,
            yaw_threshold: 1.0,
            ..Default::default()
        };
        let mut planner = ClosedLoopRRTStarPlanner::new(vec![], rand_area, config);

        let start = Pose2D::new(0.0, 0.0, 0.0);
        let goal = Pose2D::new(5.0, 0.0, 0.0);

        let mut call = 0;
        let result = planner.plan_with_sampler(start, goal, |p| {
            call += 1;
            if call % 2 == 0 {
                RRTStarRSNode::new(p.get_tree()[0].x + 5.0, 0.0, 0.0)
            } else {
                RRTStarRSNode::new(p.get_tree()[0].x + 2.5, 0.0, 0.0)
            }
        });

        // With no obstacles and a straight-line goal we expect a result, but
        // the closed-loop check may still reject it if the simulation does not
        // converge; that is acceptable.
        if let Some(res) = result {
            assert!(!res.sim.x.is_empty());
            assert!(res.sim.reached_goal);
            assert!(!res.geometric_poses.is_empty());
        }
    }

    // -- Integration test: runs without panic with obstacles --

    #[test]
    fn test_planning_with_obstacles_no_panic() {
        let obstacles = vec![
            CircleObstacle::new(5.0, 5.0, 1.0),
            CircleObstacle::new(4.0, 6.0, 1.0),
            CircleObstacle::new(4.0, 8.0, 1.0),
            CircleObstacle::new(6.0, 5.0, 1.0),
            CircleObstacle::new(7.0, 5.0, 1.0),
        ];
        let rand_area = AreaBounds::new(-2.0, 15.0, -2.0, 15.0);
        let config = ClosedLoopRRTStarConfig {
            rrt_config: RRTStarRSConfig {
                max_iter: 100,
                goal_xy_threshold: 1.5,
                goal_yaw_threshold: 1.0,
                ..Default::default()
            },
            xy_threshold: 1.5,
            yaw_threshold: 1.0,
            ..Default::default()
        };
        let mut planner = ClosedLoopRRTStarPlanner::new(obstacles, rand_area, config);

        let start = Pose2D::new(0.0, 0.0, 0.0);
        let goal = Pose2D::new(6.0, 7.0, PI / 2.0);

        // Random planning may or may not find a feasible path in 100
        // iterations; this only verifies that it does not panic.
        let _result = planner.planning(start, goal);
    }

    // -- Extend path with backward motion --

    #[test]
    fn test_extend_path_backward() {
        // Path moving in -x while yaw = 0 (backward motion).
        let cx = vec![3.0, 2.0, 1.0];
        let cy = vec![0.0, 0.0, 0.0];
        let cyaw = vec![0.0, 0.0, 0.0];
        let (ecx, _ecy, _ecyaw) = extend_path(&cx, &cy, &cyaw, 0.5);
        // Extended points should continue in the -x direction.
        assert!(ecx.last().unwrap() < &1.0);
    }

    // -- SimulationResult fields --

    #[test]
    fn test_simulation_result_default_fields() {
        let sim = SimulationResult {
            t: vec![],
            x: vec![],
            y: vec![],
            yaw: vec![],
            v: vec![],
            accel: vec![],
            steer: vec![],
            reached_goal: false,
        };
        assert!(!sim.reached_goal);
        assert!(sim.t.is_empty());
    }
}