Skip to main content

rust_robotics_planning/
rrt_path_smoothing.rs

1#![allow(clippy::too_many_arguments)]
2
3//! RRT shortcut path smoothing.
4//!
5//! This mirrors PythonRobotics `PathPlanning/RRT/rrt_with_pathsmoothing.py`
6//! rather than the grid line-of-sight post-processing in `path_smoothing.rs`.
7
8use rand::Rng;
9
10use rust_robotics_core::{Path2D, PathPlanner, Point2D, RoboticsError};
11
12use crate::rrt::{CircleObstacle, RRTPlanner};
13
14#[derive(Debug, Clone, Copy, PartialEq)]
15struct TargetPoint {
16    point: Point2D,
17    anchor_index: isize,
18}
19
/// Tuning parameters for the shortcut smoothing pass applied to an RRT path.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RRTPathSmoothingConfig {
    /// Number of shortcut attempts; each attempt draws one pair of arc lengths.
    pub max_iter: usize,
    /// Upper bound on the spacing between collision samples along a candidate
    /// shortcut segment.
    pub sample_step: f64,
}

impl Default for RRTPathSmoothingConfig {
    /// 1000 shortcut attempts, sampling shortcut segments every 0.2 units.
    fn default() -> Self {
        const DEFAULT_MAX_ITER: usize = 1000;
        const DEFAULT_SAMPLE_STEP: f64 = 0.2;

        RRTPathSmoothingConfig {
            max_iter: DEFAULT_MAX_ITER,
            sample_step: DEFAULT_SAMPLE_STEP,
        }
    }
}
35
36/// Shortcut-smoothed RRT planner.
37pub struct RRTPathSmoothingPlanner {
38    rrt: RRTPlanner,
39    smoothing: RRTPathSmoothingConfig,
40    robot_radius: f64,
41}
42
43impl RRTPathSmoothingPlanner {
44    pub fn new(rrt: RRTPlanner, smoothing: RRTPathSmoothingConfig, robot_radius: f64) -> Self {
45        Self {
46            rrt,
47            smoothing,
48            robot_radius,
49        }
50    }
51
52    pub fn from_obstacles(
53        obstacle_list: Vec<(f64, f64, f64)>,
54        rand_area: [f64; 2],
55        expand_dis: f64,
56        path_resolution: f64,
57        goal_sample_rate: i32,
58        max_iter: usize,
59        play_area: Option<[f64; 4]>,
60        robot_radius: f64,
61        smoothing: RRTPathSmoothingConfig,
62    ) -> Self {
63        let rrt = RRTPlanner::from_obstacles(
64            obstacle_list,
65            rand_area,
66            expand_dis,
67            path_resolution,
68            goal_sample_rate,
69            max_iter,
70            play_area,
71            robot_radius,
72        );
73        Self::new(rrt, smoothing, robot_radius)
74    }
75
76    pub fn planning(&mut self, start: [f64; 2], goal: [f64; 2]) -> Option<Vec<[f64; 2]>> {
77        let start_pt = Point2D::new(start[0], start[1]);
78        let goal_pt = Point2D::new(goal[0], goal[1]);
79        match self.plan(start_pt, goal_pt) {
80            Ok(path) => Some(path.points.iter().map(|point| [point.x, point.y]).collect()),
81            Err(_) => None,
82        }
83    }
84
85    pub fn rrt(&self) -> &RRTPlanner {
86        &self.rrt
87    }
88}
89
90pub fn get_path_length(path: &Path2D) -> f64 {
91    path.total_length()
92}
93
94fn get_target_point(path: &Path2D, target_length: f64) -> Option<TargetPoint> {
95    if path.len() < 2 {
96        return None;
97    }
98
99    let mut accumulated = 0.0;
100    let mut anchor_index = 0_isize;
101    let mut last_pair_length = 0.0;
102    for (index, segment) in path.points.windows(2).enumerate() {
103        let length = segment[0].distance(&segment[1]);
104        accumulated += length;
105        if accumulated >= target_length {
106            anchor_index = index as isize - 1;
107            last_pair_length = length;
108            break;
109        }
110    }
111
112    if last_pair_length <= f64::EPSILON {
113        return None;
114    }
115
116    let part_ratio = (accumulated - target_length) / last_pair_length;
117    let anchor = python_index(&path.points, anchor_index)?;
118    let next = python_index(&path.points, anchor_index + 1)?;
119    Some(TargetPoint {
120        point: Point2D::new(
121            anchor.x + (next.x - anchor.x) * part_ratio,
122            anchor.y + (next.y - anchor.y) * part_ratio,
123        ),
124        anchor_index,
125    })
126}
127
128fn python_index(points: &[Point2D], index: isize) -> Option<Point2D> {
129    if points.is_empty() {
130        return None;
131    }
132
133    let resolved = if index < 0 {
134        points.len() as isize + index
135    } else {
136        index
137    };
138    points.get(resolved as usize).copied()
139}
140
141fn is_point_collision(point: Point2D, obstacle_list: &[CircleObstacle], robot_radius: f64) -> bool {
142    obstacle_list.iter().any(|obstacle| {
143        point.distance(&Point2D::new(obstacle.x, obstacle.y)) <= obstacle.radius + robot_radius
144    })
145}
146
147pub fn line_collision_check(
148    first: Point2D,
149    second: Point2D,
150    obstacle_list: &[CircleObstacle],
151    robot_radius: f64,
152    sample_step: f64,
153) -> bool {
154    let dx = second.x - first.x;
155    let dy = second.y - first.y;
156    let length = dx.hypot(dy);
157
158    if length <= f64::EPSILON {
159        return !is_point_collision(first, obstacle_list, robot_radius);
160    }
161
162    let steps = (length / sample_step) as usize + 1;
163    for step in 0..=steps {
164        let t = step as f64 / steps as f64;
165        let point = Point2D::new(first.x + t * dx, first.y + t * dy);
166        if is_point_collision(point, obstacle_list, robot_radius) {
167            return false;
168        }
169    }
170
171    true
172}
173
174pub fn shortcut_path_smoothing(
175    path: &Path2D,
176    max_iter: usize,
177    obstacle_list: &[CircleObstacle],
178    robot_radius: f64,
179    sample_step: f64,
180) -> Path2D {
181    let mut rng = rand::rng();
182    shortcut_path_smoothing_with_sampler(
183        path,
184        max_iter,
185        obstacle_list,
186        robot_radius,
187        sample_step,
188        move |length| {
189            let first = rng.random_range(0.0..=length);
190            let second = rng.random_range(0.0..=length);
191            (first, second)
192        },
193    )
194}
195
196fn shortcut_path_smoothing_with_sampler<F>(
197    path: &Path2D,
198    max_iter: usize,
199    obstacle_list: &[CircleObstacle],
200    robot_radius: f64,
201    sample_step: f64,
202    mut sampler: F,
203) -> Path2D
204where
205    F: FnMut(f64) -> (f64, f64),
206{
207    let mut path = path.clone();
208    let mut length = get_path_length(&path);
209
210    for _ in 0..max_iter {
211        let (first_pick, second_pick) = sampler(length);
212        let (first_pick, second_pick) = if first_pick <= second_pick {
213            (first_pick, second_pick)
214        } else {
215            (second_pick, first_pick)
216        };
217
218        let Some(first) = get_target_point(&path, first_pick) else {
219            continue;
220        };
221        let Some(second) = get_target_point(&path, second_pick) else {
222            continue;
223        };
224
225        if first.anchor_index <= 0 || second.anchor_index <= 0 {
226            continue;
227        }
228        if second.anchor_index as usize + 1 > path.len() {
229            continue;
230        }
231        if second.anchor_index == first.anchor_index {
232            continue;
233        }
234        if !line_collision_check(
235            first.point,
236            second.point,
237            obstacle_list,
238            robot_radius,
239            sample_step,
240        ) {
241            continue;
242        }
243
244        let mut new_points = Vec::new();
245        new_points.extend_from_slice(&path.points[..=first.anchor_index as usize]);
246        new_points.push(first.point);
247        new_points.push(second.point);
248        new_points.extend_from_slice(&path.points[second.anchor_index as usize + 1..]);
249        path = Path2D::from_points(new_points);
250        length = get_path_length(&path);
251    }
252
253    path
254}
255
256impl PathPlanner for RRTPathSmoothingPlanner {
257    fn plan(&self, start: Point2D, goal: Point2D) -> Result<Path2D, RoboticsError> {
258        let raw_path = self.rrt.plan(start, goal)?;
259        Ok(shortcut_path_smoothing(
260            &raw_path,
261            self.smoothing.max_iter,
262            self.rrt.get_obstacles(),
263            self.robot_radius,
264            self.smoothing.sample_step,
265        ))
266    }
267}
268
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparisons against the reference fixtures.
    const TOLERANCE: f64 = 1.0e-12;

    /// Fixed shortcut pick schedule used with the PythonRobotics reference
    /// fixtures; replayed in order, one pair per smoothing iteration.
    const REFERENCE_PICKS: [(f64, f64); 6] = [
        (5.042_711_784_190_443, 9.593_003_338_960_78),
        (9.062_254_143_395_58, 14.126_262_302_507_822),
        (18.836_964_234_995_584, 24.764_066_566_436_444),
        (18.305_913_662_407_01, 20.679_131_888_162_48),
        (13.387_208_342_923_934, 17.057_633_477_275_18),
        (5.353_949_073_814_717, 12.050_990_819_270_86),
    ];

    fn assert_close(actual: f64, expected: f64) {
        let error = (actual - expected).abs();
        assert!(error < TOLERANCE, "expected {expected}, got {actual}");
    }

    fn assert_point_close(actual: &Point2D, expected: [f64; 2]) {
        let [expected_x, expected_y] = expected;
        assert_close(actual.x, expected_x);
        assert_close(actual.y, expected_y);
    }

    /// Parses `x,y` rows, skipping the first line (presumably a header) and
    /// any blank lines.
    fn parse_xy_fixture(csv: &str) -> Path2D {
        let mut points = Vec::new();
        for row in csv.lines().skip(1) {
            if row.trim().is_empty() {
                continue;
            }
            let (x, y) = row
                .split_once(',')
                .expect("xy fixture rows must contain a comma");
            points.push(Point2D::new(x.parse().unwrap(), y.parse().unwrap()));
        }
        Path2D::from_points(points)
    }

    fn reference_path() -> Path2D {
        parse_xy_fixture(include_str!("testdata/rrt_main_seed12345_path.csv"))
    }

    fn pythonrobotics_obstacles() -> Vec<CircleObstacle> {
        [
            (5.0, 5.0, 1.0),
            (3.0, 6.0, 2.0),
            (3.0, 8.0, 2.0),
            (3.0, 10.0, 2.0),
            (7.0, 5.0, 2.0),
            (9.0, 5.0, 2.0),
        ]
        .iter()
        .map(|&(x, y, radius)| CircleObstacle::new(x, y, radius))
        .collect()
    }

    /// Smooths the reference path by replaying `REFERENCE_PICKS`, with a 0.2
    /// collision sample step.
    fn smooth_with_reference_picks(robot_radius: f64) -> Path2D {
        let path = reference_path();
        let obstacles = pythonrobotics_obstacles();
        let picks = REFERENCE_PICKS;
        let mut schedule = picks.iter().copied();
        shortcut_path_smoothing_with_sampler(
            &path,
            picks.len(),
            &obstacles,
            robot_radius,
            0.2,
            |_| schedule.next().expect("pick schedule exhausted"),
        )
    }

    #[test]
    fn test_target_point_matches_pythonrobotics_reference() {
        let target = get_target_point(&reference_path(), 5.042_711_784_190_443).unwrap();
        assert_eq!(target.anchor_index, 1);
        assert_point_close(
            &target.point,
            [1.607_799_680_222_086_9, 3.021_515_605_739_562_5],
        );
    }

    #[test]
    fn test_shortcut_path_smoothing_fixed_schedule_matches_pythonrobotics_reference() {
        let smoothed = smooth_with_reference_picks(0.3);

        let expected = [
            [0.0, 0.0],
            [1.976_107_921_083_293, 2.257_210_110_785_406],
            [1.856_622_402_011_424, 2.505_163_940_383_516],
            [-0.798_753_011_969_844_9, 6.802_754_367_664_581],
            [-1.179_802_666_459_188_3, 9.161_294_659_229_442],
            [-0.313_449_454_851_734_9, 10.861_642_151_758_218],
            [-0.600_781_734_440_396_5, 11.906_895_396_420_353],
            [-0.267_061_812_685_109_03, 12.274_920_423_720_9],
            [1.817_773_784_492_888_7, 13.672_274_852_803_103],
            [2.390_944_751_282_898, 13.046_732_672_990_332],
            [5.705_791_577_974_288, 11.515_986_653_768_12],
            [5.860_033_119_067_657, 10.721_216_347_248_003],
            [6.0, 10.0],
        ];

        assert_eq!(smoothed.len(), expected.len());
        for (actual, want) in smoothed.points.iter().zip(expected) {
            assert_point_close(actual, want);
        }
        assert_close(smoothed.total_length(), 22.759_016_831_433_197);
    }

    #[test]
    fn test_smoothed_path_points_remain_safe() {
        let robot_radius = 0.5;
        let smoothed = smooth_with_reference_picks(robot_radius);
        let obstacles = pythonrobotics_obstacles();

        for point in &smoothed.points {
            for obstacle in &obstacles {
                let distance = point.distance(&Point2D::new(obstacle.x, obstacle.y));
                assert!(
                    distance > obstacle.radius + robot_radius,
                    "point ({:.6}, {:.6}) too close to obstacle ({:.6}, {:.6})",
                    point.x,
                    point.y,
                    obstacle.x,
                    obstacle.y
                );
            }
        }
    }
}