/// Knot parameterisation of a Catmull–Rom spline.
///
/// Each variant selects the exponent `alpha` applied to the distance between consecutive
/// control points when spacing the knots (uniform = 0, centripetal = 0.5, chordal = 1).
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Parameterisation {
    /// `alpha = 0`: knots are equally spaced regardless of point distances.
    Uniform,
    /// `alpha = 0.5`: knot spacing is the square root of the chord length.
    Centripetal,
    /// `alpha = 1`: knot spacing equals the chord length.
    Chordal,
    /// A caller-supplied `alpha`.
    Custom(f64),
}
24
25impl Parameterisation {
26 fn alpha(self) -> f64 {
27 match self {
28 Parameterisation::Uniform => 0.0,
29 Parameterisation::Centripetal => 0.5,
30 Parameterisation::Chordal => 1.0,
31 Parameterisation::Custom(a) => a,
32 }
33 }
34}
35
36#[derive(Debug, Clone)]
38pub struct CatmullRomSpline {
39 points: Vec<(f64, f64)>,
41 param: Parameterisation,
43 points_per_segment: usize,
45}
46
47impl CatmullRomSpline {
48 pub fn new(
59 points: Vec<(f64, f64)>,
60 points_per_segment: usize,
61 param: Parameterisation,
62 ) -> Self {
63 assert!(points.len() >= 2, "At least 2 waypoints are required");
64 assert!(points_per_segment >= 2, "points_per_segment must be >= 2");
65 CatmullRomSpline {
66 points,
67 param,
68 points_per_segment,
69 }
70 }
71
72 pub fn generate_path(&self) -> Vec<(f64, f64)> {
76 let alpha = self.param.alpha();
77 let n = self.points.len();
78 let mut path: Vec<(f64, f64)> = Vec::new();
79
80 for i in 0..n - 1 {
81 let p0 = if i == 0 {
82 self.points[0]
83 } else {
84 self.points[i - 1]
85 };
86 let p1 = self.points[i];
87 let p2 = self.points[i + 1];
88 let p3 = if i + 2 < n {
89 self.points[i + 2]
90 } else {
91 self.points[n - 1]
92 };
93
94 let is_last_segment = i == n - 2;
95 let segment = Self::segment_points(
96 p0,
97 p1,
98 p2,
99 p3,
100 self.points_per_segment,
101 alpha,
102 is_last_segment,
103 );
104 path.extend(segment);
105 }
106
107 path
108 }
109
110 pub fn generate_path_with_derivatives(&self) -> (Vec<(f64, f64)>, Vec<f64>, Vec<f64>) {
115 let alpha = self.param.alpha();
116 let n = self.points.len();
117 let mut path: Vec<(f64, f64)> = Vec::new();
118 let mut yaw: Vec<f64> = Vec::new();
119 let mut curvature: Vec<f64> = Vec::new();
120
121 for i in 0..n - 1 {
122 let p0 = if i == 0 {
123 self.points[0]
124 } else {
125 self.points[i - 1]
126 };
127 let p1 = self.points[i];
128 let p2 = self.points[i + 1];
129 let p3 = if i + 2 < n {
130 self.points[i + 2]
131 } else {
132 self.points[n - 1]
133 };
134
135 let is_last_segment = i == n - 2;
136 let seg = Self::segment_points_with_derivatives(
137 p0,
138 p1,
139 p2,
140 p3,
141 self.points_per_segment,
142 alpha,
143 is_last_segment,
144 );
145 path.extend(seg.0);
146 yaw.extend(seg.1);
147 curvature.extend(seg.2);
148 }
149
150 (path, yaw, curvature)
151 }
152
153 fn knot_interval(p0: (f64, f64), p1: (f64, f64), alpha: f64) -> f64 {
155 let dx = p1.0 - p0.0;
156 let dy = p1.1 - p0.1;
157 let d2 = dx * dx + dy * dy;
158 d2.powf(alpha * 0.5)
159 }
160
161 fn segment_points(
167 p0: (f64, f64),
168 p1: (f64, f64),
169 p2: (f64, f64),
170 p3: (f64, f64),
171 num_points: usize,
172 alpha: f64,
173 include_endpoint: bool,
174 ) -> Vec<(f64, f64)> {
175 if alpha == 0.0 {
176 Self::segment_uniform(p0, p1, p2, p3, num_points, include_endpoint)
178 } else {
179 Self::segment_general(p0, p1, p2, p3, num_points, alpha, include_endpoint)
181 }
182 }
183
184 fn segment_uniform(
186 p0: (f64, f64),
187 p1: (f64, f64),
188 p2: (f64, f64),
189 p3: (f64, f64),
190 num_points: usize,
191 include_endpoint: bool,
192 ) -> Vec<(f64, f64)> {
193 let count = if include_endpoint {
194 num_points
195 } else {
196 num_points - 1
197 };
198 let mut pts = Vec::with_capacity(count);
199
200 for j in 0..count {
201 let t = j as f64 / (num_points - 1) as f64;
202 let t2 = t * t;
203 let t3 = t2 * t;
204
205 let x = 0.5
206 * ((2.0 * p1.0)
207 + (-p0.0 + p2.0) * t
208 + (2.0 * p0.0 - 5.0 * p1.0 + 4.0 * p2.0 - p3.0) * t2
209 + (-p0.0 + 3.0 * p1.0 - 3.0 * p2.0 + p3.0) * t3);
210 let y = 0.5
211 * ((2.0 * p1.1)
212 + (-p0.1 + p2.1) * t
213 + (2.0 * p0.1 - 5.0 * p1.1 + 4.0 * p2.1 - p3.1) * t2
214 + (-p0.1 + 3.0 * p1.1 - 3.0 * p2.1 + p3.1) * t3);
215
216 pts.push((x, y));
217 }
218
219 pts
220 }
221
222 fn segment_general(
225 p0: (f64, f64),
226 p1: (f64, f64),
227 p2: (f64, f64),
228 p3: (f64, f64),
229 num_points: usize,
230 alpha: f64,
231 include_endpoint: bool,
232 ) -> Vec<(f64, f64)> {
233 let t0: f64 = 0.0;
234 let t1 = t0 + Self::knot_interval(p0, p1, alpha);
235 let t2 = t1 + Self::knot_interval(p1, p2, alpha);
236 let t3 = t2 + Self::knot_interval(p2, p3, alpha);
237
238 let count = if include_endpoint {
239 num_points
240 } else {
241 num_points - 1
242 };
243 let mut pts = Vec::with_capacity(count);
244
245 for j in 0..count {
246 let frac = j as f64 / (num_points - 1) as f64;
247 let t = t1 + frac * (t2 - t1);
248
249 let a1 = Self::lerp(p0, p1, t0, t1, t);
250 let a2 = Self::lerp(p1, p2, t1, t2, t);
251 let a3 = Self::lerp(p2, p3, t2, t3, t);
252
253 let b1 = Self::lerp(a1, a2, t0, t2, t);
254 let b2 = Self::lerp(a2, a3, t1, t3, t);
255
256 let c = Self::lerp(b1, b2, t1, t2, t);
257 pts.push(c);
258 }
259
260 pts
261 }
262
263 fn segment_points_with_derivatives(
266 p0: (f64, f64),
267 p1: (f64, f64),
268 p2: (f64, f64),
269 p3: (f64, f64),
270 num_points: usize,
271 alpha: f64,
272 include_endpoint: bool,
273 ) -> (Vec<(f64, f64)>, Vec<f64>, Vec<f64>) {
274 let positions = Self::segment_points(p0, p1, p2, p3, num_points, alpha, include_endpoint);
275
276 let n = positions.len();
277 let mut yaw = Vec::with_capacity(n);
278 let mut curvature = Vec::with_capacity(n);
279
280 for i in 0..n {
281 let (dx, dy, ddx, ddy) = if i == 0 && n > 1 {
282 let dx = positions[1].0 - positions[0].0;
283 let dy = positions[1].1 - positions[0].1;
284 (dx, dy, 0.0, 0.0)
285 } else if i == n - 1 && n > 1 {
286 let dx = positions[n - 1].0 - positions[n - 2].0;
287 let dy = positions[n - 1].1 - positions[n - 2].1;
288 (dx, dy, 0.0, 0.0)
289 } else {
290 let dx = (positions[i + 1].0 - positions[i - 1].0) * 0.5;
291 let dy = (positions[i + 1].1 - positions[i - 1].1) * 0.5;
292 let ddx = positions[i + 1].0 - 2.0 * positions[i].0 + positions[i - 1].0;
293 let ddy = positions[i + 1].1 - 2.0 * positions[i].1 + positions[i - 1].1;
294 (dx, dy, ddx, ddy)
295 };
296
297 yaw.push(dy.atan2(dx));
298 let denom = (dx * dx + dy * dy).powf(1.5);
299 if denom.abs() < 1e-12 {
300 curvature.push(0.0);
301 } else {
302 curvature.push((ddy * dx - ddx * dy) / denom);
303 }
304 }
305
306 (positions, yaw, curvature)
307 }
308
309 #[inline]
311 fn lerp(p0: (f64, f64), p1: (f64, f64), t0: f64, t1: f64, t: f64) -> (f64, f64) {
312 let d = t1 - t0;
313 if d.abs() < 1e-15 {
314 return p0;
315 }
316 let a = (t1 - t) / d;
317 let b = (t - t0) / d;
318 (a * p0.0 + b * p1.0, a * p0.1 + b * p1.1)
319 }
320}
321
/// Result holder for Catmull–Rom path planning: sampled path with per-sample heading and
/// curvature, all of equal length after a successful `planning` call.
#[derive(Clone, Debug)]
pub struct CatmullRomPlanner {
    /// Sampled `(x, y)` positions along the spline.
    pub path: Vec<(f64, f64)>,
    /// Heading at each sample, in radians.
    pub yaw: Vec<f64>,
    /// Signed curvature at each sample.
    pub curvature: Vec<f64>,
}
333
334impl CatmullRomPlanner {
335 pub fn new() -> Self {
336 CatmullRomPlanner {
337 path: Vec::new(),
338 yaw: Vec::new(),
339 curvature: Vec::new(),
340 }
341 }
342
343 pub fn planning(
353 &mut self,
354 waypoints_x: Vec<f64>,
355 waypoints_y: Vec<f64>,
356 points_per_segment: usize,
357 param: Parameterisation,
358 ) -> bool {
359 if waypoints_x.len() != waypoints_y.len() || waypoints_x.len() < 2 {
360 return false;
361 }
362 if points_per_segment < 2 {
363 return false;
364 }
365
366 let points: Vec<(f64, f64)> = waypoints_x.into_iter().zip(waypoints_y).collect();
367
368 let spline = CatmullRomSpline::new(points, points_per_segment, param);
369 let (path, yaw, curvature) = spline.generate_path_with_derivatives();
370
371 self.path = path;
372 self.yaw = yaw;
373 self.curvature = curvature;
374
375 true
376 }
377}
378
379impl Default for CatmullRomPlanner {
380 fn default() -> Self {
381 Self::new()
382 }
383}
384
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute-tolerance float comparison.
    fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() < tol
    }

    #[test]
    #[should_panic(expected = "At least 2 waypoints")]
    fn test_too_few_points() {
        CatmullRomSpline::new(vec![(0.0, 0.0)], 10, Parameterisation::Uniform);
    }

    #[test]
    #[should_panic(expected = "points_per_segment must be >= 2")]
    fn test_too_few_per_segment() {
        CatmullRomSpline::new(vec![(0.0, 0.0), (1.0, 1.0)], 1, Parameterisation::Uniform);
    }

    #[test]
    fn test_uniform_matches_python_reference() {
        let way_points = vec![
            (-1.0, -2.0),
            (1.0, -1.0),
            (3.0, -2.0),
            (4.0, -1.0),
            (3.0, 1.0),
            (1.0, 2.0),
            (0.0, 2.0),
        ];
        let n_course_point = 100;

        let spline = CatmullRomSpline::new(
            way_points.clone(),
            n_course_point,
            Parameterisation::Uniform,
        );
        let path = spline.generate_path();

        // Shared segment end points are emitted once.
        assert_eq!(path.len(), (way_points.len() - 1) * (n_course_point - 1) + 1);

        assert!(approx_eq(path[0].0, -1.0, 1e-12));
        assert!(approx_eq(path[0].1, -2.0, 1e-12));

        let last = path[path.len() - 1];
        assert!(approx_eq(last.0, 0.0, 1e-12));
        assert!(approx_eq(last.1, 2.0, 1e-12));
    }

    #[test]
    fn test_passes_through_all_control_points() {
        let way_points = vec![(0.0, 0.0), (1.0, 2.0), (3.0, 1.0), (5.0, 4.0), (7.0, 0.0)];
        let n = 50;

        for param in &[
            Parameterisation::Uniform,
            Parameterisation::Centripetal,
            Parameterisation::Chordal,
        ] {
            let spline = CatmullRomSpline::new(way_points.clone(), n, *param);
            let path = spline.generate_path();

            assert!(
                approx_eq(path[0].0, way_points[0].0, 1e-10),
                "param={param:?}"
            );
            assert!(
                approx_eq(path[0].1, way_points[0].1, 1e-10),
                "param={param:?}"
            );
            let last = path.last().unwrap();
            let wlast = way_points.last().unwrap();
            assert!(approx_eq(last.0, wlast.0, 1e-10), "param={param:?}");
            assert!(approx_eq(last.1, wlast.1, 1e-10), "param={param:?}");

            // Interior waypoint i is the first sample of segment i.
            for (i, wp) in way_points
                .iter()
                .enumerate()
                .skip(1)
                .take(way_points.len() - 2)
            {
                let idx = i * (n - 1);
                assert!(
                    approx_eq(path[idx].0, wp.0, 1e-10),
                    "param={param:?}, i={i}"
                );
                assert!(
                    approx_eq(path[idx].1, wp.1, 1e-10),
                    "param={param:?}, i={i}"
                );
            }
        }
    }

    #[test]
    fn test_straight_line() {
        let way_points = vec![(0.0, 0.0), (1.0, 1.0), (2.0, 2.0), (3.0, 3.0)];
        let spline = CatmullRomSpline::new(way_points, 20, Parameterisation::Uniform);
        let path = spline.generate_path();

        for &(x, y) in &path {
            assert!(
                approx_eq(x, y, 1e-10),
                "expected x==y on a diagonal line, got ({x}, {y})"
            );
        }
    }

    #[test]
    fn test_two_points() {
        let way_points = vec![(0.0, 0.0), (10.0, 5.0)];
        let spline = CatmullRomSpline::new(way_points, 20, Parameterisation::Centripetal);
        let path = spline.generate_path();
        assert_eq!(path.len(), 20);
        assert!(approx_eq(path[0].0, 0.0, 1e-12));
        assert!(approx_eq(path[0].1, 0.0, 1e-12));
        assert!(approx_eq(path[19].0, 10.0, 1e-12));
        assert!(approx_eq(path[19].1, 5.0, 1e-12));
    }

    #[test]
    fn test_path_length_formula() {
        let way_points = vec![(0.0, 0.0), (1.0, 2.0), (3.0, 1.0), (5.0, 4.0)];
        for pps in [2usize, 3, 17] {
            for param in [
                Parameterisation::Uniform,
                Parameterisation::Centripetal,
                Parameterisation::Chordal,
                Parameterisation::Custom(0.25),
            ] {
                let path = CatmullRomSpline::new(way_points.clone(), pps, param).generate_path();
                assert_eq!(
                    path.len(),
                    (way_points.len() - 1) * (pps - 1) + 1,
                    "pps={pps}, param={param:?}"
                );
            }
        }
    }

    #[test]
    fn test_custom_alpha_matches_named_variants() {
        let way_points = vec![(0.0, 0.0), (1.0, 5.0), (4.0, 0.0), (6.0, 3.0)];
        let n = 15;
        let pairs = [
            (Parameterisation::Custom(0.0), Parameterisation::Uniform),
            (Parameterisation::Custom(0.5), Parameterisation::Centripetal),
            (Parameterisation::Custom(1.0), Parameterisation::Chordal),
        ];
        for (custom, named) in pairs {
            let a = CatmullRomSpline::new(way_points.clone(), n, custom).generate_path();
            let b = CatmullRomSpline::new(way_points.clone(), n, named).generate_path();
            assert_eq!(a, b, "custom={custom:?}, named={named:?}");
        }
    }

    #[test]
    fn test_planner_success() {
        let mut planner = CatmullRomPlanner::new();
        let ok = planner.planning(
            vec![-1.0, 1.0, 3.0, 4.0, 3.0, 1.0, 0.0],
            vec![-2.0, -1.0, -2.0, -1.0, 1.0, 2.0, 2.0],
            50,
            Parameterisation::Centripetal,
        );
        assert!(ok);
        assert!(!planner.path.is_empty());
        assert_eq!(planner.path.len(), planner.yaw.len());
        assert_eq!(planner.path.len(), planner.curvature.len());
    }

    #[test]
    fn test_planner_bad_input() {
        let mut planner = CatmullRomPlanner::new();
        // Mismatched coordinate lengths.
        assert!(!planner.planning(vec![0.0], vec![0.0, 1.0], 10, Parameterisation::Uniform));
        // Too few waypoints.
        assert!(!planner.planning(vec![0.0], vec![0.0], 10, Parameterisation::Uniform));
        // Too few samples per segment.
        assert!(!planner.planning(
            vec![0.0, 1.0],
            vec![0.0, 1.0],
            1,
            Parameterisation::Uniform
        ));
        // Rejected input must not touch the stored results.
        assert!(planner.path.is_empty());
        assert!(planner.yaw.is_empty());
        assert!(planner.curvature.is_empty());
    }

    #[test]
    fn test_centripetal_differs_from_uniform() {
        let way_points = vec![(0.0, 0.0), (1.0, 5.0), (4.0, 0.0), (6.0, 3.0)];
        let n = 30;
        let uniform =
            CatmullRomSpline::new(way_points.clone(), n, Parameterisation::Uniform).generate_path();
        let centripetal =
            CatmullRomSpline::new(way_points, n, Parameterisation::Centripetal).generate_path();

        assert_eq!(uniform.len(), centripetal.len());
        let mut diff_count = 0;
        for (u, c) in uniform.iter().zip(centripetal.iter()) {
            if (u.0 - c.0).abs() > 1e-10 || (u.1 - c.1).abs() > 1e-10 {
                diff_count += 1;
            }
        }
        assert!(diff_count > 0, "centripetal should differ from uniform");
    }

    #[test]
    fn test_yaw_and_curvature_sanity() {
        let way_points = vec![(0.0, 0.0), (2.0, 0.0), (4.0, 0.0), (6.0, 0.0)];
        let spline = CatmullRomSpline::new(way_points, 20, Parameterisation::Uniform);
        let (path, yaw, curvature) = spline.generate_path_with_derivatives();

        for &y in &yaw {
            assert!(
                approx_eq(y, 0.0, 1e-10),
                "expected yaw ~0 on a horizontal line, got {y}"
            );
        }
        for &k in &curvature {
            assert!(
                approx_eq(k, 0.0, 1e-10),
                "expected curvature ~0 on a straight line, got {k}"
            );
        }
        assert_eq!(path.len(), yaw.len());
        assert_eq!(path.len(), curvature.len());
    }
}