Skip to main content

rust_robotics_planning/
cubic_spline_planner.rs

1#![allow(
2    dead_code,
3    clippy::needless_borrows_for_generic_args,
4    clippy::new_without_default,
5    clippy::ptr_arg,
6    clippy::type_complexity
7)]
8
9//! Cubic spline planner
10//!
11//! Path planner using cubic spline interpolation through waypoints.
12//!
13//! Reference:
14//! - PythonRobotics CubicSpline by Atsushi Sakai
15//! - CppRobotics cubic_spline by TAI Lei
16
17#[derive(Debug, Clone)]
18struct Spline {
19    a: Vec<f64>,
20    b: Vec<f64>,
21    c: Vec<f64>,
22    d: Vec<f64>,
23    x: Vec<f64>,
24    y: Vec<f64>,
25}
26
27impl Spline {
28    fn new(x: &Vec<f64>, y: &Vec<f64>) -> Spline {
29        let nx = x.len();
30        let mut b: Vec<f64> = Vec::with_capacity(nx);
31        let mut d: Vec<f64> = Vec::with_capacity(nx);
32        let mut h: Vec<f64> = Vec::with_capacity(nx - 1);
33        for i in 0..nx - 1 {
34            h.push(x[i + 1] - x[i]);
35        }
36        let a = y.clone();
37        let a_mat = Spline::__calc_a(&h);
38        let b_mat = Spline::__calc_b(&h, &a);
39
40        let a_mat_inv = a_mat.try_inverse().unwrap();
41
42        let c_na = a_mat_inv * b_mat;
43        let mut c: Vec<f64> = Vec::with_capacity(c_na.len());
44        for i in 0..c_na.len() {
45            c.push(c_na[i]);
46        }
47        for i in 0..nx - 1 {
48            d.push((c[i + 1] - c[i]) / (3. * h[i]));
49            let tb = (a[i + 1] - a[i]) / h[i] - h[i] * (c[i + 1] + 2.0 * c[i]) / 3.0;
50            b.push(tb);
51        }
52
53        Spline {
54            a,
55            b,
56            c,
57            d,
58            x: x.to_vec(),
59            y: y.to_vec(),
60        }
61    }
62
63    fn calc(&self, t: f64) -> f64 {
64        let i = self.__search_index(t);
65        let x = self.x[i];
66        let dx = t - x;
67        self.a[i] + self.b[i] * dx + self.c[i] * dx.powi(2) + self.d[i] * dx.powi(3)
68    }
69
70    fn calcd(&self, t: f64) -> f64 {
71        let i = self.__search_index(t);
72        let x = self.x[i];
73        let dx = t - x;
74        let b = self.b[i];
75        let c = self.c[i];
76        let d = self.d[i];
77        b + 2. * c * dx + 3. * d * dx.powi(2)
78    }
79
80    fn calcdd(&self, t: f64) -> f64 {
81        let i = self.__search_index(t);
82        let x = self.x[i];
83        let dx = t - x;
84        2. * self.c[i] + 6. * self.d[i] * dx
85    }
86
87    fn __search_index(&self, t: f64) -> usize {
88        let nx = self.x.len();
89        self.bisect(t, 0, nx)
90    }
91
92    fn __calc_a(h: &Vec<f64>) -> nalgebra::DMatrix<f64> {
93        let nx = h.len() + 1;
94        let mut a = nalgebra::DMatrix::from_diagonal_element(nx, nx, 0.0);
95        a[(0, 0)] = 1.;
96        for i in 0..nx - 1 {
97            if i != nx - 2 {
98                a[(i + 1, i + 1)] = 2.0 * (h[i] + h[i + 1]);
99            }
100            a[(i + 1, i)] = h[i];
101            a[(i, i + 1)] = h[i];
102        }
103        a[(0, 1)] = 0.;
104        a[(nx - 1, nx - 2)] = 0.;
105        a[(nx - 1, nx - 1)] = 1.;
106        a
107    }
108
109    fn __calc_b(h: &Vec<f64>, a: &Vec<f64>) -> nalgebra::DVector<f64> {
110        let nx = h.len() + 1;
111        let mut b = nalgebra::DVector::zeros(nx);
112        for i in 0..nx - 2 {
113            b[i + 1] = 3.0 * (a[i + 2] - a[i + 1]) / h[i + 1] - 3.0 * (a[i + 1] - a[i]) / h[i];
114        }
115        b
116    }
117
118    fn bisect(&self, t: f64, s: usize, e: usize) -> usize {
119        let mid = (s + e) / 2;
120        if t == self.x[mid] || e - s <= 1 {
121            mid
122        } else if t > self.x[mid] {
123            self.bisect(t, mid, e)
124        } else {
125            self.bisect(t, s, mid)
126        }
127    }
128}
129
130#[derive(Debug, Clone)]
131pub struct Spline2D {
132    pub s: Vec<f64>,
133    sx: Spline,
134    sy: Spline,
135}
136
137impl Spline2D {
138    pub fn new(x: Vec<f64>, y: Vec<f64>) -> Spline2D {
139        let s = Spline2D::__calc_s(&x, &y);
140        let sx = Spline::new(&s, &x);
141        let sy = Spline::new(&s, &y);
142
143        Spline2D { s, sx, sy }
144    }
145
146    fn __calc_s(x: &Vec<f64>, y: &Vec<f64>) -> Vec<f64> {
147        let nx = x.len();
148        let mut dx: Vec<f64> = Vec::with_capacity(nx - 1);
149        let mut dy: Vec<f64> = Vec::with_capacity(nx - 1);
150
151        for i in 0..nx - 1 {
152            dx.push(x[i + 1] - x[i]);
153            dy.push(y[i + 1] - y[i]);
154        }
155        let mut ds: Vec<f64> = Vec::with_capacity(nx - 1);
156        for i in 0..nx - 1 {
157            let dsi = (dx[i].powi(2) + dy[i].powi(2)).sqrt();
158            ds.push(dsi);
159        }
160        let mut s: Vec<f64> = Vec::with_capacity(nx);
161        s.push(0.);
162        for i in 0..nx - 1 {
163            s.push(s[i] + ds[i]);
164        }
165        s
166    }
167
168    pub fn calc_position(&self, is: f64) -> (f64, f64) {
169        let x = self.sx.calc(is);
170        let y = self.sy.calc(is);
171        (x, y)
172    }
173
174    pub fn calc_curvature(&self, is: f64) -> f64 {
175        let dx = self.sx.calcd(is);
176        let ddx = self.sx.calcdd(is);
177        let dy = self.sy.calcd(is);
178        let ddy = self.sy.calcdd(is);
179        (ddy * dx - ddx * dy) / ((dx.powi(2) + dy.powi(2)).powf(3. / 2.))
180    }
181
182    pub fn calc_yaw(&self, is: f64) -> f64 {
183        let dx = self.sx.calcd(is);
184        let dy = self.sy.calcd(is);
185        dy.atan2(dx)
186    }
187}
188
189pub fn calc_spline_course(
190    x: Vec<f64>,
191    y: Vec<f64>,
192    ds: f64,
193) -> (Vec<(f64, f64)>, Vec<f64>, Vec<f64>, Vec<f64>) {
194    let sp = Spline2D::new(x, y);
195    let s_end = sp.s[sp.s.len() - 1];
196    let mut r: Vec<(f64, f64)> = Vec::new();
197    let mut ryaw: Vec<f64> = Vec::new();
198    let mut rk: Vec<f64> = Vec::new();
199    let mut s: Vec<f64> = Vec::new();
200
201    let mut is = 0.0;
202    while is < s_end {
203        let pair = sp.clone().calc_position(is);
204        r.push(pair);
205        ryaw.push(sp.clone().calc_yaw(is));
206        rk.push(sp.clone().calc_curvature(is));
207        s.push(is);
208        is += ds;
209    }
210    (r, ryaw, rk, s)
211}
212
213pub struct CubicSplinePlanner {
214    pub path: Vec<(f64, f64)>,
215    pub yaw: Vec<f64>,
216    pub curvature: Vec<f64>,
217    pub s: Vec<f64>,
218}
219
220impl CubicSplinePlanner {
221    pub fn new() -> Self {
222        CubicSplinePlanner {
223            path: Vec::new(),
224            yaw: Vec::new(),
225            curvature: Vec::new(),
226            s: Vec::new(),
227        }
228    }
229
230    pub fn planning(&mut self, waypoints_x: Vec<f64>, waypoints_y: Vec<f64>, ds: f64) -> bool {
231        if waypoints_x.len() != waypoints_y.len() || waypoints_x.len() < 2 {
232            return false;
233        }
234
235        let (path, yaw, curvature, s) = calc_spline_course(waypoints_x, waypoints_y, ds);
236
237        self.path = path;
238        self.yaw = yaw;
239        self.curvature = curvature;
240        self.s = s;
241
242        true
243    }
244}
245
246impl Default for CubicSplinePlanner {
247    fn default() -> Self {
248        Self::new()
249    }
250}
251
#[cfg(test)]
#[allow(clippy::excessive_precision)]
mod tests {
    use super::*;

    /// `true` when `a` and `b` differ by strictly less than `tol`.
    fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() < tol
    }

    /// Asserts `approx_eq(actual, expected, tol)`, reporting the quantity and both values
    /// on failure.
    #[track_caller]
    fn assert_close(what: &str, actual: f64, expected: f64, tol: f64) {
        assert!(
            approx_eq(actual, expected, tol),
            "{what}: expected {expected:e}, got {actual:e} (tolerance {tol:e})"
        );
    }

    /// Plain and index-weighted sums of the x and y coordinates; a cheap order-sensitive
    /// summary of a whole path.
    fn path_fingerprint(path: &[(f64, f64)]) -> (f64, f64, f64, f64) {
        let sum_x: f64 = path.iter().map(|point| point.0).sum();
        let sum_y: f64 = path.iter().map(|point| point.1).sum();
        let weighted_sum_x: f64 = path
            .iter()
            .enumerate()
            .map(|(index, point)| (index + 1) as f64 * point.0)
            .sum();
        let weighted_sum_y: f64 = path
            .iter()
            .enumerate()
            .map(|(index, point)| (index + 1) as f64 * point.1)
            .sum();
        (sum_x, sum_y, weighted_sum_x, weighted_sum_y)
    }

    /// Plain and index-weighted sums of a scalar series.
    fn scalar_fingerprint(values: &[f64]) -> (f64, f64) {
        let sum: f64 = values.iter().sum();
        let weighted_sum: f64 = values
            .iter()
            .enumerate()
            .map(|(index, value)| (index + 1) as f64 * value)
            .sum();
        (sum, weighted_sum)
    }

    #[test]
    fn test_cubic_spline_planning() {
        let waypoints_x = vec![0.0, 10.0, 20.5, 30.0, 40.5, 50.0];
        let waypoints_y = vec![0.0, -6.0, 5.0, 6.5, 0.0, -4.0];
        let ds = 0.1;

        let mut planner = CubicSplinePlanner::new();
        let result = planner.planning(waypoints_x, waypoints_y, ds);

        assert!(result, "planning should succeed for valid waypoints");
        assert!(!planner.path.is_empty());
        assert!(!planner.yaw.is_empty());
        assert!(!planner.curvature.is_empty());
        assert_eq!(planner.path.len(), planner.yaw.len());
        assert_eq!(planner.path.len(), planner.curvature.len());
        assert_eq!(planner.path.len(), planner.s.len());
    }

    #[test]
    fn test_planning_rejects_invalid_waypoints() {
        let mut planner = CubicSplinePlanner::new();
        assert!(!planner.planning(vec![0.0, 1.0], vec![0.0], 0.1));
        assert!(!planner.planning(vec![0.0], vec![0.0], 0.1));
        assert!(!planner.planning(Vec::new(), Vec::new(), 0.1));
        assert!(planner.path.is_empty());
    }

    #[test]
    fn test_spline2d() {
        let x = vec![0.0, 10.0, 20.5, 30.0];
        let y = vec![0.0, -6.0, 5.0, 6.5];
        let sp = Spline2D::new(x, y);
        assert!(!sp.s.is_empty());
    }

    #[test]
    fn test_calc_spline_course_matches_upstream_main2d_example() {
        let x = vec![-2.5, 0.0, 2.5, 5.0, 7.5, 3.0, -1.0];
        let y = vec![0.7, -6.0, 5.0, 6.5, 0.0, 5.0, -2.0];
        let (path, yaw, curvature, s) = calc_spline_course(x, y, 0.1);

        assert_eq!(path.len(), 432);
        assert_eq!(yaw.len(), 432);
        assert_eq!(curvature.len(), 432);
        assert_eq!(s.len(), 432);

        let expected_samples = [
            (
                0usize,
                (-2.5, 0.7),
                -1.261_911_977_390_588_5,
                3.639_626_203_496_368e-17,
                0.0,
            ),
            (
                1,
                (-2.456_214_414_403_149, 0.562_786_288_975_573_1),
                -1.261_892_330_947_575,
                0.000_272_891_300_476_531_77,
                0.1,
            ),
            (
                10,
                (-2.063_853_173_620_386_4, -0.663_709_921_270_961_1),
                -1.259_911_718_126_614_6,
                0.002_880_022_595_677_534_5,
                1.0,
            ),
            (
                50,
                (-0.526_420_969_786_195_5, -5.097_072_155_044_978_5),
                -1.172_930_102_001_000_1,
                0.080_811_423_731_918_38,
                5.0,
            ),
            (
                431,
                (-0.999_817_756_777_344, -1.999_280_944_483_088),
                -1.819_017_577_858_351_5,
                4.894_060_652_105_897e-06,
                43.1,
            ),
        ];

        for (index, expected_pos, expected_yaw, expected_curvature, expected_s) in expected_samples
        {
            assert_close(&format!("path[{index}].0"), path[index].0, expected_pos.0, 1e-12);
            assert_close(&format!("path[{index}].1"), path[index].1, expected_pos.1, 1e-12);
            assert_close(&format!("yaw[{index}]"), yaw[index], expected_yaw, 1e-12);
            assert_close(
                &format!("curvature[{index}]"),
                curvature[index],
                expected_curvature,
                1e-12,
            );
            assert_close(&format!("s[{index}]"), s[index], expected_s, 1e-12);
        }

        let path_fp = path_fingerprint(&path);
        assert_close("path sum x", path_fp.0, 1_022.623_164_017_587_4, 1e-9);
        assert_close("path sum y", path_fp.1, 357.291_977_281_188_57, 1e-9);
        assert_close("path weighted x", path_fp.2, 291_306.562_899_991_5, 1e-6);
        assert_close("path weighted y", path_fp.3, 199_642.229_890_529_74, 1e-6);

        let yaw_fp = scalar_fingerprint(&yaw);
        assert_close("yaw sum", yaw_fp.0, 9.657_666_682_428_65, 1e-9);
        assert_close("yaw weighted", yaw_fp.1, -7_330.389_498_653_044, 1e-6);

        let curvature_fp = scalar_fingerprint(&curvature);
        assert_close("curvature sum", curvature_fp.0, 62.207_543_933_499_61, 1e-9);
        assert_close(
            "curvature weighted",
            curvature_fp.1,
            -3_950.521_036_466_579_5,
            1e-6,
        );

        let s_fp = scalar_fingerprint(&s);
        assert_close("s sum", s_fp.0, 9_309.6, 1e-9);
        assert_close("s weighted", s_fp.1, 2_687_371.200_000_001_6, 1e-3);
    }

    #[test]
    fn test_spline2d_matches_upstream_main2d_reference_states() {
        let x = vec![-2.5, 0.0, 2.5, 5.0, 7.5, 3.0, -1.0];
        let y = vec![0.7, -6.0, 5.0, 6.5, 0.0, 5.0, -2.0];
        let sp = Spline2D::new(x, y);

        assert_close("s_end", sp.s[sp.s.len() - 1], 43.100_477_702_041_04, 1e-12);

        let expected_states = [
            (
                0.0,
                (-2.5, 0.7),
                -1.261_911_977_390_588_5,
                3.639_626_203_496_368e-17,
            ),
            (
                0.1,
                (-2.456_214_414_403_149, 0.562_786_288_975_573_1),
                -1.261_892_330_947_575,
                0.000_272_891_300_476_531_77,
            ),
            (
                1.0,
                (-2.063_853_173_620_386_4, -0.663_709_921_270_961_1),
                -1.259_911_718_126_614_6,
                0.002_880_022_595_677_534_5,
            ),
            (
                5.0,
                (-0.526_420_969_786_195_5, -5.097_072_155_044_978_5),
                -1.172_930_102_001_000_1,
                0.080_811_423_731_918_38,
            ),
            (
                10.0,
                (0.277_082_035_120_894_85, -4.891_782_706_378_051),
                1.505_135_127_918_666_9,
                0.043_260_837_806_513_984,
            ),
            (
                43.000_477_702_041_04,
                (-0.961_848_258_550_467_4, -1.849_485_946_939_665_8),
                -1.819_097_133_814_203_5,
                0.001_025_039_188_263_279_5,
            ),
        ];

        for (arc, expected_pos, expected_yaw, expected_curvature) in expected_states {
            let position = sp.calc_position(arc);
            assert_close(&format!("x({arc})"), position.0, expected_pos.0, 1e-12);
            assert_close(&format!("y({arc})"), position.1, expected_pos.1, 1e-12);
            assert_close(&format!("yaw({arc})"), sp.calc_yaw(arc), expected_yaw, 1e-12);
            assert_close(
                &format!("curvature({arc})"),
                sp.calc_curvature(arc),
                expected_curvature,
                1e-12,
            );
        }
    }
}