1#![allow(clippy::too_many_arguments)]
2
3use rand::Rng;
9
10use rust_robotics_core::{Path2D, PathPlanner, Point2D, RoboticsError};
11
12use crate::rrt::{CircleObstacle, RRTPlanner};
13
14#[derive(Debug, Clone, Copy, PartialEq)]
15struct TargetPoint {
16 point: Point2D,
17 anchor_index: isize,
18}
19
/// Tuning parameters for the shortcut-smoothing pass that runs after RRT planning.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct RRTPathSmoothingConfig {
    /// Number of random shortcut attempts; rejected attempts still count.
    pub max_iter: usize,
    /// Upper bound on the spacing between collision-check samples along a candidate shortcut,
    /// in the same units as the path coordinates.
    pub sample_step: f64,
}

impl Default for RRTPathSmoothingConfig {
    /// 1000 shortcut attempts, collision samples at most 0.2 apart.
    fn default() -> Self {
        let max_iter = 1000;
        let sample_step = 0.2;
        Self {
            max_iter,
            sample_step,
        }
    }
}
35
36pub struct RRTPathSmoothingPlanner {
38 rrt: RRTPlanner,
39 smoothing: RRTPathSmoothingConfig,
40 robot_radius: f64,
41}
42
43impl RRTPathSmoothingPlanner {
44 pub fn new(rrt: RRTPlanner, smoothing: RRTPathSmoothingConfig, robot_radius: f64) -> Self {
45 Self {
46 rrt,
47 smoothing,
48 robot_radius,
49 }
50 }
51
52 pub fn from_obstacles(
53 obstacle_list: Vec<(f64, f64, f64)>,
54 rand_area: [f64; 2],
55 expand_dis: f64,
56 path_resolution: f64,
57 goal_sample_rate: i32,
58 max_iter: usize,
59 play_area: Option<[f64; 4]>,
60 robot_radius: f64,
61 smoothing: RRTPathSmoothingConfig,
62 ) -> Self {
63 let rrt = RRTPlanner::from_obstacles(
64 obstacle_list,
65 rand_area,
66 expand_dis,
67 path_resolution,
68 goal_sample_rate,
69 max_iter,
70 play_area,
71 robot_radius,
72 );
73 Self::new(rrt, smoothing, robot_radius)
74 }
75
76 pub fn planning(&mut self, start: [f64; 2], goal: [f64; 2]) -> Option<Vec<[f64; 2]>> {
77 let start_pt = Point2D::new(start[0], start[1]);
78 let goal_pt = Point2D::new(goal[0], goal[1]);
79 match self.plan(start_pt, goal_pt) {
80 Ok(path) => Some(path.points.iter().map(|point| [point.x, point.y]).collect()),
81 Err(_) => None,
82 }
83 }
84
85 pub fn rrt(&self) -> &RRTPlanner {
86 &self.rrt
87 }
88}
89
90pub fn get_path_length(path: &Path2D) -> f64 {
91 path.total_length()
92}
93
/// Maps an arc length along `path` to an interpolated point plus an anchor vertex index,
/// mirroring the PythonRobotics `get_target_point` helper (pinned by
/// `test_target_point_matches_pythonrobotics_reference`).
///
/// Returns `None` when:
/// - `path` has fewer than two points,
/// - `target_length` is never reached while walking the segments (it exceeds the total
///   length, or is NaN), or
/// - the segment on which it is reached has length `<= f64::EPSILON`.
///
/// NOTE(review): the result is intentionally *not* the geometric point at `target_length`.
/// If segment `i -> i + 1` is the first one whose cumulative length reaches `target_length`,
/// the anchor becomes `i - 1`. The ratio `(accumulated - target_length) / |segment i|` is then
/// applied along `anchor -> anchor + 1`, i.e. on the preceding segment. Keep this as-is unless
/// the reference fixtures in the tests are regenerated.
fn get_target_point(path: &Path2D, target_length: f64) -> Option<TargetPoint> {
    if path.len() < 2 {
        return None;
    }

    // Walk the segments until the cumulative length first reaches `target_length`.
    let mut accumulated = 0.0;
    let mut anchor_index = 0_isize;
    // Stays 0.0 if the target is never reached, which routes to the `None` return below.
    let mut last_pair_length = 0.0;
    for (index, segment) in path.points.windows(2).enumerate() {
        let length = segment[0].distance(&segment[1]);
        accumulated += length;
        if accumulated >= target_length {
            // Python-parity quirk: the anchor is one vertex *before* the segment start, so it
            // is -1 when the target lies on the first segment.
            anchor_index = index as isize - 1;
            last_pair_length = length;
            break;
        }
    }

    // Covers "never reached" and also guards the division below against a zero-length segment.
    if last_pair_length <= f64::EPSILON {
        return None;
    }

    let part_ratio = (accumulated - target_length) / last_pair_length;
    // An anchor of -1 resolves to the last vertex (Python `path[-1]`), with `next` = vertex 0.
    // shortcut_path_smoothing_with_sampler rejects anchors <= 0, so that case is never spliced.
    let anchor = python_index(&path.points, anchor_index)?;
    let next = python_index(&path.points, anchor_index + 1)?;
    Some(TargetPoint {
        point: Point2D::new(
            anchor.x + (next.x - anchor.x) * part_ratio,
            anchor.y + (next.y - anchor.y) * part_ratio,
        ),
        anchor_index,
    })
}
127
128fn python_index(points: &[Point2D], index: isize) -> Option<Point2D> {
129 if points.is_empty() {
130 return None;
131 }
132
133 let resolved = if index < 0 {
134 points.len() as isize + index
135 } else {
136 index
137 };
138 points.get(resolved as usize).copied()
139}
140
141fn is_point_collision(point: Point2D, obstacle_list: &[CircleObstacle], robot_radius: f64) -> bool {
142 obstacle_list.iter().any(|obstacle| {
143 point.distance(&Point2D::new(obstacle.x, obstacle.y)) <= obstacle.radius + robot_radius
144 })
145}
146
147pub fn line_collision_check(
148 first: Point2D,
149 second: Point2D,
150 obstacle_list: &[CircleObstacle],
151 robot_radius: f64,
152 sample_step: f64,
153) -> bool {
154 let dx = second.x - first.x;
155 let dy = second.y - first.y;
156 let length = dx.hypot(dy);
157
158 if length <= f64::EPSILON {
159 return !is_point_collision(first, obstacle_list, robot_radius);
160 }
161
162 let steps = (length / sample_step) as usize + 1;
163 for step in 0..=steps {
164 let t = step as f64 / steps as f64;
165 let point = Point2D::new(first.x + t * dx, first.y + t * dy);
166 if is_point_collision(point, obstacle_list, robot_radius) {
167 return false;
168 }
169 }
170
171 true
172}
173
174pub fn shortcut_path_smoothing(
175 path: &Path2D,
176 max_iter: usize,
177 obstacle_list: &[CircleObstacle],
178 robot_radius: f64,
179 sample_step: f64,
180) -> Path2D {
181 let mut rng = rand::rng();
182 shortcut_path_smoothing_with_sampler(
183 path,
184 max_iter,
185 obstacle_list,
186 robot_radius,
187 sample_step,
188 move |length| {
189 let first = rng.random_range(0.0..=length);
190 let second = rng.random_range(0.0..=length);
191 (first, second)
192 },
193 )
194}
195
196fn shortcut_path_smoothing_with_sampler<F>(
197 path: &Path2D,
198 max_iter: usize,
199 obstacle_list: &[CircleObstacle],
200 robot_radius: f64,
201 sample_step: f64,
202 mut sampler: F,
203) -> Path2D
204where
205 F: FnMut(f64) -> (f64, f64),
206{
207 let mut path = path.clone();
208 let mut length = get_path_length(&path);
209
210 for _ in 0..max_iter {
211 let (first_pick, second_pick) = sampler(length);
212 let (first_pick, second_pick) = if first_pick <= second_pick {
213 (first_pick, second_pick)
214 } else {
215 (second_pick, first_pick)
216 };
217
218 let Some(first) = get_target_point(&path, first_pick) else {
219 continue;
220 };
221 let Some(second) = get_target_point(&path, second_pick) else {
222 continue;
223 };
224
225 if first.anchor_index <= 0 || second.anchor_index <= 0 {
226 continue;
227 }
228 if second.anchor_index as usize + 1 > path.len() {
229 continue;
230 }
231 if second.anchor_index == first.anchor_index {
232 continue;
233 }
234 if !line_collision_check(
235 first.point,
236 second.point,
237 obstacle_list,
238 robot_radius,
239 sample_step,
240 ) {
241 continue;
242 }
243
244 let mut new_points = Vec::new();
245 new_points.extend_from_slice(&path.points[..=first.anchor_index as usize]);
246 new_points.push(first.point);
247 new_points.push(second.point);
248 new_points.extend_from_slice(&path.points[second.anchor_index as usize + 1..]);
249 path = Path2D::from_points(new_points);
250 length = get_path_length(&path);
251 }
252
253 path
254}
255
256impl PathPlanner for RRTPathSmoothingPlanner {
257 fn plan(&self, start: Point2D, goal: Point2D) -> Result<Path2D, RoboticsError> {
258 let raw_path = self.rrt.plan(start, goal)?;
259 Ok(shortcut_path_smoothing(
260 &raw_path,
261 self.smoothing.max_iter,
262 self.rrt.get_obstacles(),
263 self.robot_radius,
264 self.smoothing.sample_step,
265 ))
266 }
267}
268
// Regression tests pinning this port to reference outputs recorded from PythonRobotics.
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts two floats agree to within an absolute tolerance of 1e-12.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1.0e-12,
            "expected {expected}, got {actual}"
        );
    }

    /// Component-wise `assert_close` for a point against an `[x, y]` pair.
    fn assert_point_close(actual: &Point2D, expected: [f64; 2]) {
        assert_close(actual.x, expected[0]);
        assert_close(actual.y, expected[1]);
    }

    /// Parses a CSV fixture: the first line is a header (skipped), blank lines are ignored,
    /// and every other line is an `x,y` pair. Panics on malformed rows (test-only helper).
    fn parse_xy_fixture(csv: &str) -> Path2D {
        let points = csv
            .lines()
            .skip(1)
            .filter(|line| !line.trim().is_empty())
            .map(|line| {
                let (x, y) = line
                    .split_once(',')
                    .expect("xy fixture rows must contain a comma");
                Point2D::new(x.parse().unwrap(), y.parse().unwrap())
            })
            .collect();
        Path2D::from_points(points)
    }

    /// Obstacle set used when the fixture path was generated.
    /// NOTE(review): `CircleObstacle::new` argument order assumed to be (x, y, radius).
    fn pythonrobotics_obstacles() -> Vec<CircleObstacle> {
        vec![
            CircleObstacle::new(5.0, 5.0, 1.0),
            CircleObstacle::new(3.0, 6.0, 2.0),
            CircleObstacle::new(3.0, 8.0, 2.0),
            CircleObstacle::new(3.0, 10.0, 2.0),
            CircleObstacle::new(7.0, 5.0, 2.0),
            CircleObstacle::new(9.0, 5.0, 2.0),
        ]
    }

    // Pins the Python-parity quirk in `get_target_point` (anchor = segment index - 1).
    #[test]
    fn test_target_point_matches_pythonrobotics_reference() {
        let path = parse_xy_fixture(include_str!("testdata/rrt_main_seed12345_path.csv"));
        let target = get_target_point(&path, 5.042_711_784_190_443).unwrap();
        assert_eq!(target.anchor_index, 1);
        assert_point_close(
            &target.point,
            [1.607_799_680_222_086_9, 3.021_515_605_739_562_5],
        );
    }

    // Replays a fixed schedule of (arc length, arc length) picks instead of random draws, so
    // the smoothed path is deterministic and can be compared vertex-by-vertex with the
    // reference output (presumably recorded from the Python run named in the fixture file).
    #[test]
    fn test_shortcut_path_smoothing_fixed_schedule_matches_pythonrobotics_reference() {
        let path = parse_xy_fixture(include_str!("testdata/rrt_main_seed12345_path.csv"));
        let obstacles = pythonrobotics_obstacles();
        let picks = [
            (5.042_711_784_190_443, 9.593_003_338_960_78),
            (9.062_254_143_395_58, 14.126_262_302_507_822),
            (18.836_964_234_995_584, 24.764_066_566_436_444),
            (18.305_913_662_407_01, 20.679_131_888_162_48),
            (13.387_208_342_923_934, 17.057_633_477_275_18),
            (5.353_949_073_814_717, 12.050_990_819_270_86),
        ];
        let mut pick_iter = picks.iter().copied();

        // Robot radius 0.3, collision sample step 0.2; exactly one iteration per pick.
        let smoothed =
            shortcut_path_smoothing_with_sampler(&path, picks.len(), &obstacles, 0.3, 0.2, |_| {
                pick_iter.next().expect("pick schedule exhausted")
            });

        let expected = [
            [0.0, 0.0],
            [1.976_107_921_083_293, 2.257_210_110_785_406],
            [1.856_622_402_011_424, 2.505_163_940_383_516],
            [-0.798_753_011_969_844_9, 6.802_754_367_664_581],
            [-1.179_802_666_459_188_3, 9.161_294_659_229_442],
            [-0.313_449_454_851_734_9, 10.861_642_151_758_218],
            [-0.600_781_734_440_396_5, 11.906_895_396_420_353],
            [-0.267_061_812_685_109_03, 12.274_920_423_720_9],
            [1.817_773_784_492_888_7, 13.672_274_852_803_103],
            [2.390_944_751_282_898, 13.046_732_672_990_332],
            [5.705_791_577_974_288, 11.515_986_653_768_12],
            [5.860_033_119_067_657, 10.721_216_347_248_003],
            [6.0, 10.0],
        ];

        assert_eq!(smoothed.len(), expected.len());
        for (actual, expected) in smoothed.points.iter().zip(expected.iter()) {
            assert_point_close(actual, *expected);
        }
        assert_close(smoothed.total_length(), 22.759_016_831_433_197);
    }

    // Same pick schedule with a larger robot radius (0.5). Only the path *vertices* are
    // checked against the inflated obstacles; the segments between them are not.
    #[test]
    fn test_smoothed_path_points_remain_safe() {
        let path = parse_xy_fixture(include_str!("testdata/rrt_main_seed12345_path.csv"));
        let obstacles = pythonrobotics_obstacles();
        let picks = [
            (5.042_711_784_190_443, 9.593_003_338_960_78),
            (9.062_254_143_395_58, 14.126_262_302_507_822),
            (18.836_964_234_995_584, 24.764_066_566_436_444),
            (18.305_913_662_407_01, 20.679_131_888_162_48),
            (13.387_208_342_923_934, 17.057_633_477_275_18),
            (5.353_949_073_814_717, 12.050_990_819_270_86),
        ];
        let mut pick_iter = picks.iter().copied();
        let smoothed =
            shortcut_path_smoothing_with_sampler(&path, picks.len(), &obstacles, 0.5, 0.2, |_| {
                pick_iter.next().expect("pick schedule exhausted")
            });

        for point in &smoothed.points {
            for obstacle in &obstacles {
                let distance = point.distance(&Point2D::new(obstacle.x, obstacle.y));
                // Strict `>`: a vertex exactly on the inflated boundary counts as unsafe,
                // consistent with the `<=` used by `is_point_collision`.
                assert!(
                    distance > obstacle.radius + 0.5,
                    "point ({:.6}, {:.6}) too close to obstacle ({:.6}, {:.6})",
                    point.x,
                    point.y,
                    obstacle.x,
                    obstacle.y
                );
            }
        }
    }
}