Skip to main content

rust_robotics_planning/
model_predictive_trajectory_generator.rs

1#![allow(clippy::too_many_arguments)]
2
3//! Model Predictive Trajectory Generator
4//!
5//! Generates smooth trajectories using numerical optimization (Newton's method)
6//! to connect an initial state to a target state. The trajectory is parameterized
7//! by arc length `s`, mid-point curvature `km`, and final curvature `kf`. A
8//! quadratic curvature profile interpolated from `(k0, km, kf)` is integrated
9//! via a bicycle kinematic model to produce the path. The Jacobian is computed
10//! numerically and used to iteratively refine the parameters until the terminal
11//! state error falls below a threshold.
12//!
13//! Reference: <https://github.com/AtsushiSakai/PythonRobotics/tree/master/PathPlanning/ModelPredictiveTrajectoryGenerator>
14
15use nalgebra::{Matrix3, Vector3};
16
17// ---------------------------------------------------------------------------
18// Configuration
19// ---------------------------------------------------------------------------
20
21/// Parameters for the trajectory generator.
22#[derive(Debug, Clone)]
23pub struct MptgConfig {
24    /// Wheel base length \[m\]
25    pub wheel_base: f64,
26    /// Step distance for trajectory discretization \[m\]
27    pub ds: f64,
28    /// Constant forward velocity used in the motion model \[m/s\]
29    pub velocity: f64,
30    /// Finite-difference step sizes for numerical Jacobian `(h_s, h_km, h_kf)`
31    pub h: Vector3<f64>,
32    /// Maximum number of optimization iterations
33    pub max_iter: usize,
34    /// Cost (terminal error norm) convergence threshold
35    pub cost_th: f64,
36}
37
38impl Default for MptgConfig {
39    fn default() -> Self {
40        Self {
41            wheel_base: 1.0,
42            ds: 0.1,
43            velocity: 10.0 / 3.6,
44            h: Vector3::new(0.5, 0.02, 0.02),
45            max_iter: 100,
46            cost_th: 0.1,
47        }
48    }
49}
50
51// ---------------------------------------------------------------------------
52// Target state
53// ---------------------------------------------------------------------------
54
/// A 2-D pose used as the optimization target.
///
/// The pose is expressed in the frame of the trajectory start, where the
/// simulated vehicle begins at the origin with zero heading.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct TargetState {
    /// Target x position \[m\]
    pub x: f64,
    /// Target y position \[m\]
    pub y: f64,
    /// Target heading \[rad\]
    pub yaw: f64,
}

impl TargetState {
    /// Create a target pose from its position and heading.
    pub fn new(x: f64, y: f64, yaw: f64) -> Self {
        Self { x, y, yaw }
    }
}
68
69// ---------------------------------------------------------------------------
70// Result
71// ---------------------------------------------------------------------------
72
73/// Result of a successful trajectory optimization.
74#[derive(Debug, Clone)]
75pub struct MptgResult {
76    /// X coordinates along the trajectory
77    pub x: Vec<f64>,
78    /// Y coordinates along the trajectory
79    pub y: Vec<f64>,
80    /// Yaw angles along the trajectory \[rad\]
81    pub yaw: Vec<f64>,
82    /// Optimized parameters `(s, km, kf)`
83    pub params: Vector3<f64>,
84}
85
86// ---------------------------------------------------------------------------
87// Internal bicycle-model state
88// ---------------------------------------------------------------------------
89
90struct BicycleState {
91    x: f64,
92    y: f64,
93    yaw: f64,
94}
95
96impl BicycleState {
97    fn new() -> Self {
98        Self {
99            x: 0.0,
100            y: 0.0,
101            yaw: 0.0,
102        }
103    }
104
105    fn update(&mut self, v: f64, delta: f64, dt: f64, wheel_base: f64) {
106        self.x += v * self.yaw.cos() * dt;
107        self.y += v * self.yaw.sin() * dt;
108        self.yaw += v / wheel_base * delta.tan() * dt;
109        self.yaw = pi2pi(self.yaw);
110    }
111}
112
113// ---------------------------------------------------------------------------
114// Helpers
115// ---------------------------------------------------------------------------
116
/// Wrap an angle into the interval `[-pi, pi]`.
///
/// Angles already inside the interval (including exactly `±pi`) are returned
/// unchanged. The reduction uses the exact floating-point remainder, so it
/// runs in constant time for any input. Non-finite input (`±inf`, `NaN`)
/// yields `NaN`.
fn pi2pi(angle: f64) -> f64 {
    use std::f64::consts::{PI, TAU};

    // `%` is the exact remainder truncated toward zero, so `wrapped` lies in
    // (-TAU, TAU) and at most one further shift is needed. A subtract-in-a-loop
    // reduction would never terminate for ±inf, and for large finite angles it
    // needs |angle| / TAU iterations. Once TAU is below half an ulp,
    // `angle - TAU == angle` and it never terminates at all.
    let wrapped = angle % TAU;
    if wrapped > PI {
        wrapped - TAU
    } else if wrapped < -PI {
        wrapped + TAU
    } else {
        wrapped
    }
}
126
/// Fit the quadratic `k(t) = a*t^2 + b*t + c` through three `(t, k)` points.
///
/// Returns the coefficients `(a, b, c)`. The fit uses Newton divided
/// differences, which is exact for any three pairwise-distinct nodes. It
/// avoids forming and inverting the 3x3 Vandermonde matrix.
///
/// If the nodes are not pairwise distinct, no unique quadratic exists. This
/// happens, for example, with a zero-duration trajectory, where all three
/// nodes collapse to `t = 0`. In that case the constant profile `k(t) = k.0`
/// is returned instead of panicking, so the value at the first node is still
/// honored.
fn quad_interp(t: (f64, f64, f64), k: (f64, f64, f64)) -> (f64, f64, f64) {
    let (t0, t1, t2) = t;
    let (k0, k1, k2) = k;

    if t0 == t1 || t1 == t2 || t0 == t2 {
        return (0.0, 0.0, k0);
    }

    // First- and second-order divided differences.
    let d01 = (k1 - k0) / (t1 - t0);
    let d12 = (k2 - k1) / (t2 - t1);
    let a = (d12 - d01) / (t2 - t0);

    // Expand the Newton form k0 + d01*(t - t0) + a*(t - t0)*(t - t1)
    // into monomial coefficients.
    let b = d01 - a * (t0 + t1);
    let c = k0 - d01 * t0 + a * t0 * t1;
    (a, b, c)
}
144
/// Evaluate the quadratic curvature profile with coefficients `(a, b, c)` at
/// time `t`, i.e. `a*t^2 + b*t + c`.
#[inline]
fn eval_curvature(coef: (f64, f64, f64), t: f64) -> f64 {
    let (a, b, c) = coef;
    let quadratic = a * t * t;
    let linear = b * t;
    quadratic + linear + c
}
150
151// ---------------------------------------------------------------------------
152// Trajectory generation (forward simulation)
153// ---------------------------------------------------------------------------
154
155/// Generate a full trajectory given parameters `(s, km, kf)` and initial
156/// curvature `k0`.
157fn generate_trajectory(
158    s: f64,
159    km: f64,
160    kf: f64,
161    k0: f64,
162    cfg: &MptgConfig,
163) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
164    let n = (s / cfg.ds).round().max(1.0) as usize;
165    let time = s / cfg.velocity;
166    let dt = time / n as f64;
167
168    let coef = quad_interp((0.0, time / 2.0, time), (k0, km, kf));
169
170    let mut state = BicycleState::new();
171    let mut xs = vec![state.x];
172    let mut ys = vec![state.y];
173    let mut yaws = vec![state.yaw];
174
175    for i in 0..n {
176        let t = i as f64 * dt;
177        let delta = eval_curvature(coef, t);
178        state.update(cfg.velocity, delta, dt, cfg.wheel_base);
179        xs.push(state.x);
180        ys.push(state.y);
181        yaws.push(state.yaw);
182    }
183
184    (xs, ys, yaws)
185}
186
187/// Generate only the last state (used in Jacobian computation for efficiency).
188fn generate_last_state(s: f64, km: f64, kf: f64, k0: f64, cfg: &MptgConfig) -> (f64, f64, f64) {
189    let n = (s / cfg.ds).round().max(1.0) as usize;
190    let time = s / cfg.velocity;
191    let dt = time / n as f64;
192
193    let coef = quad_interp((0.0, time / 2.0, time), (k0, km, kf));
194
195    let mut state = BicycleState::new();
196    for i in 0..n {
197        let t = i as f64 * dt;
198        let delta = eval_curvature(coef, t);
199        state.update(cfg.velocity, delta, dt, cfg.wheel_base);
200    }
201
202    (state.x, state.y, state.yaw)
203}
204
205// ---------------------------------------------------------------------------
206// Optimization internals
207// ---------------------------------------------------------------------------
208
209/// Terminal state error.
210fn calc_diff(target: &TargetState, x: f64, y: f64, yaw: f64) -> Vector3<f64> {
211    Vector3::new(target.x - x, target.y - y, pi2pi(target.yaw - yaw))
212}
213
214/// Numerical Jacobian via central differences.
215fn calc_jacobian(
216    target: &TargetState,
217    p: &Vector3<f64>,
218    k0: f64,
219    cfg: &MptgConfig,
220) -> Matrix3<f64> {
221    let h = &cfg.h;
222    let mut cols: [Vector3<f64>; 3] = [Vector3::zeros(); 3];
223
224    for dim in 0..3 {
225        let mut pp = *p;
226        let mut pn = *p;
227        pp[dim] += h[dim];
228        pn[dim] -= h[dim];
229
230        let (xp, yp, yawp) = generate_last_state(pp[0], pp[1], pp[2], k0, cfg);
231        let dp = calc_diff(target, xp, yp, yawp);
232
233        let (xn, yn, yawn) = generate_last_state(pn[0], pn[1], pn[2], k0, cfg);
234        let dn = calc_diff(target, xn, yn, yawn);
235
236        cols[dim] = (dp - dn) / (2.0 * h[dim]);
237    }
238
239    Matrix3::from_columns(&cols)
240}
241
242/// Line-search to pick a learning rate that minimizes cost.
243fn select_learning_rate(
244    dp: &Vector3<f64>,
245    p: &Vector3<f64>,
246    k0: f64,
247    target: &TargetState,
248    cfg: &MptgConfig,
249) -> f64 {
250    let mut best_alpha = 1.0;
251    let mut min_cost = f64::MAX;
252
253    let mut alpha = 1.0;
254    while alpha < 2.0 {
255        let tp = p + alpha * dp;
256        let (xc, yc, yawc) = generate_last_state(tp[0], tp[1], tp[2], k0, cfg);
257        let dc = calc_diff(target, xc, yc, yawc);
258        let cost = dc.norm();
259        if cost < min_cost {
260            best_alpha = alpha;
261            min_cost = cost;
262        }
263        alpha += 0.5;
264    }
265
266    best_alpha
267}
268
269// ---------------------------------------------------------------------------
270// Public API
271// ---------------------------------------------------------------------------
272
273/// Run the trajectory optimization.
274///
275/// Given a `target` pose, an initial curvature `k0`, and an initial parameter
276/// guess `init_p = (s, km, kf)`, iteratively refine the parameters using
277/// Newton's method until the terminal-state error is below the configured
278/// threshold.
279///
280/// Returns `Some(MptgResult)` on success, or `None` if the optimization fails
281/// to converge or encounters a singular Jacobian.
282pub fn optimize_trajectory(
283    target: &TargetState,
284    k0: f64,
285    init_p: Vector3<f64>,
286    cfg: &MptgConfig,
287) -> Option<MptgResult> {
288    let mut p = init_p;
289
290    for _ in 0..cfg.max_iter {
291        let (xc, yc, yawc) = generate_trajectory(p[0], p[1], p[2], k0, cfg);
292
293        let last_x = *xc.last().unwrap();
294        let last_y = *yc.last().unwrap();
295        let last_yaw = *yawc.last().unwrap();
296
297        let dc = calc_diff(target, last_x, last_y, last_yaw);
298        let cost = dc.norm();
299
300        if cost <= cfg.cost_th {
301            return Some(MptgResult {
302                x: xc,
303                y: yc,
304                yaw: yawc,
305                params: p,
306            });
307        }
308
309        let j = calc_jacobian(target, &p, k0, cfg);
310        let j_inv = j.try_inverse()?;
311        let dp = -j_inv * dc;
312
313        let alpha = select_learning_rate(&dp, &p, k0, target, cfg);
314        p += alpha * dp;
315    }
316
317    // Did not converge within max_iter
318    None
319}
320
321// ---------------------------------------------------------------------------
322// Lookup table
323// ---------------------------------------------------------------------------
324
/// A single entry in the lookup table. It holds the terminal pose
/// `(x, y, yaw)` reached by a solved trajectory together with that
/// trajectory's parameters `(s, km, kf)`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LookupEntry {
    /// Terminal x position \[m\]
    pub x: f64,
    /// Terminal y position \[m\]
    pub y: f64,
    /// Terminal heading \[rad\]
    pub yaw: f64,
    /// Arc length `s` \[m\]. It is stored for callers but not read inside this
    /// module, hence the `dead_code` allowance.
    #[allow(dead_code)]
    pub s: f64,
    /// Mid-point curvature parameter `km`
    pub km: f64,
    /// Final curvature parameter `kf`
    pub kf: f64,
}
336
337/// Search the lookup table for the entry closest (in Euclidean + yaw sense)
338/// to the query `(tx, ty, tyaw)`.
339pub fn search_nearest_in_lookup_table(
340    tx: f64,
341    ty: f64,
342    tyaw: f64,
343    table: &[LookupEntry],
344) -> Option<&LookupEntry> {
345    table.iter().min_by(|a, b| {
346        let da = (tx - a.x).powi(2) + (ty - a.y).powi(2) + (tyaw - a.yaw).powi(2);
347        let db = (tx - b.x).powi(2) + (ty - b.y).powi(2) + (tyaw - b.yaw).powi(2);
348        da.partial_cmp(&db).unwrap_or(std::cmp::Ordering::Equal)
349    })
350}
351
352/// Generate a lookup table of pre-computed trajectory parameters.
353///
354/// For each combination of `(x, y, yaw)` in the given ranges, the optimizer
355/// is run to find `(s, km, kf)`. The nearest existing entry is used as the
356/// initial guess for each new target, bootstrapping convergence.
357pub fn generate_lookup_table(
358    x_range: &[f64],
359    y_range: &[f64],
360    yaw_range: &[f64],
361    k0: f64,
362    cfg: &MptgConfig,
363) -> Vec<LookupEntry> {
364    let mut table = vec![LookupEntry {
365        x: 1.0,
366        y: 0.0,
367        yaw: 0.0,
368        s: 1.0,
369        km: 0.0,
370        kf: 0.0,
371    }];
372
373    for &yaw in yaw_range {
374        for &y in y_range {
375            for &x in x_range {
376                let best = search_nearest_in_lookup_table(x, y, yaw, &table).unwrap();
377                let target = TargetState::new(x, y, yaw);
378                let s_init = (x * x + y * y).sqrt();
379                let init_p = Vector3::new(s_init, best.km, best.kf);
380
381                if let Some(result) = optimize_trajectory(&target, k0, init_p, cfg) {
382                    let last_x = *result.x.last().unwrap();
383                    let last_y = *result.y.last().unwrap();
384                    let last_yaw = *result.yaw.last().unwrap();
385                    table.push(LookupEntry {
386                        x: last_x,
387                        y: last_y,
388                        yaw: last_yaw,
389                        s: result.params[0],
390                        km: result.params[1],
391                        kf: result.params[2],
392                    });
393                }
394            }
395        }
396    }
397
398    table
399}
400
401// ---------------------------------------------------------------------------
402// Tests
403// ---------------------------------------------------------------------------
404
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    fn default_cfg() -> MptgConfig {
        MptgConfig::default()
    }

    /// Absolute-tolerance float comparison used throughout these tests.
    fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() < tol
    }

    #[test]
    fn test_pi2pi() {
        assert!(approx_eq(pi2pi(3.0 * PI), PI, 1e-10));
        assert!(approx_eq(pi2pi(-3.0 * PI), -PI, 1e-10));
        assert!(approx_eq(pi2pi(0.5), 0.5, 1e-10));
    }

    #[test]
    fn test_quad_interp_linear() {
        // f(t) = 2t + 1 => a=0, b=2, c=1
        let (a, b, c) = quad_interp((0.0, 1.0, 2.0), (1.0, 3.0, 5.0));
        assert!(approx_eq(a, 0.0, 1e-10));
        assert!(approx_eq(b, 2.0, 1e-10));
        assert!(approx_eq(c, 1.0, 1e-10));
    }

    #[test]
    fn test_quad_interp_quadratic() {
        // f(t) = t^2 => a=1, b=0, c=0
        let (a, b, c) = quad_interp((0.0, 1.0, 2.0), (0.0, 1.0, 4.0));
        assert!(approx_eq(a, 1.0, 1e-10));
        assert!(approx_eq(b, 0.0, 1e-10));
        assert!(approx_eq(c, 0.0, 1e-10));
    }

    #[test]
    fn test_eval_curvature() {
        assert!(approx_eq(eval_curvature((1.0, 2.0, 3.0), 2.0), 11.0, 1e-12));
        assert!(approx_eq(eval_curvature((0.0, 0.0, 7.0), 5.0), 7.0, 1e-12));
    }

    #[test]
    fn test_bicycle_update_straight() {
        let mut state = BicycleState::new();
        state.update(1.0, 0.0, 0.5, 1.0);
        assert!(approx_eq(state.x, 0.5, 1e-12));
        assert!(approx_eq(state.y, 0.0, 1e-12));
        assert!(approx_eq(state.yaw, 0.0, 1e-12));
    }

    #[test]
    fn test_generate_trajectory_straight() {
        let cfg = default_cfg();
        // Zero curvature throughout => straight line
        let (xs, ys, yaws) = generate_trajectory(5.0, 0.0, 0.0, 0.0, &cfg);
        assert!(xs.len() > 2);
        // Final x should be approximately s=5.0
        let last_x = *xs.last().unwrap();
        assert!(approx_eq(last_x, 5.0, 0.5), "Expected ~5.0, got {last_x}");
        // y should stay near zero
        let last_y = ys.last().unwrap().abs();
        assert!(last_y < 0.1, "Expected ~0.0, got {last_y}");
        // yaw should stay near zero
        let last_yaw = yaws.last().unwrap().abs();
        assert!(last_yaw < 0.1, "Expected ~0.0, got {last_yaw}");
    }

    #[test]
    fn test_generate_trajectory_sample_count() {
        let cfg = default_cfg();
        // round(5.0 / 0.1) = 50 steps, plus the initial pose.
        let (xs, ys, yaws) = generate_trajectory(5.0, 0.0, 0.0, 0.0, &cfg);
        assert_eq!(xs.len(), 51);
        assert_eq!(ys.len(), 51);
        assert_eq!(yaws.len(), 51);
        assert_eq!((xs[0], ys[0], yaws[0]), (0.0, 0.0, 0.0));
    }

    #[test]
    fn test_generate_last_state_matches_trajectory() {
        let cfg = default_cfg();
        let (xs, ys, yaws) = generate_trajectory(6.0, 0.1, -0.05, 0.0, &cfg);
        let (lx, ly, lyaw) = generate_last_state(6.0, 0.1, -0.05, 0.0, &cfg);
        assert!(approx_eq(*xs.last().unwrap(), lx, 1e-10), "x mismatch");
        assert!(approx_eq(*ys.last().unwrap(), ly, 1e-10), "y mismatch");
        assert!(approx_eq(*yaws.last().unwrap(), lyaw, 1e-10), "yaw mismatch");
    }

    #[test]
    fn test_calc_diff_wraps_yaw() {
        let target = TargetState::new(1.0, 2.0, PI - 0.1);
        let d = calc_diff(&target, 0.5, 0.5, -PI + 0.1);
        assert!(approx_eq(d[0], 0.5, 1e-12));
        assert!(approx_eq(d[1], 1.5, 1e-12));
        // The raw difference is 2*pi - 0.2; wrapped it is -0.2.
        assert!(approx_eq(d[2], -0.2, 1e-10));
    }

    #[test]
    fn test_jacobian_arc_length_column_straight_line() {
        let cfg = default_cfg();
        let target = TargetState::new(5.0, 0.0, 0.0);
        let p = Vector3::new(5.0, 0.0, 0.0);
        let j = calc_jacobian(&target, &p, 0.0, &cfg);
        // With zero curvature the end pose is (s, 0, 0), so d(target.x - x)/ds = -1.
        assert!(approx_eq(j[(0, 0)], -1.0, 1e-9));
        assert!(approx_eq(j[(1, 0)], 0.0, 1e-12));
        assert!(approx_eq(j[(2, 0)], 0.0, 1e-12));
    }

    #[test]
    fn test_select_learning_rate_prefers_exact_step() {
        let cfg = default_cfg();
        let target = TargetState::new(5.0, 0.0, 0.0);
        let p = Vector3::new(4.0, 0.0, 0.0);
        let dp = Vector3::new(1.0, 0.0, 0.0);
        // alpha = 1.0 lands on s = 5 (zero error); alpha = 1.5 overshoots to s = 5.5.
        assert_eq!(select_learning_rate(&dp, &p, 0.0, &target, &cfg), 1.0);
    }

    #[test]
    fn test_optimize_trajectory_90deg() {
        let cfg = default_cfg();
        let target = TargetState::new(5.0, 2.0, PI / 2.0);
        let k0 = 0.0;
        let init_p = Vector3::new(6.0, 0.0, 0.0);

        let res = optimize_trajectory(&target, k0, init_p, &cfg)
            .expect("Optimization should converge");

        let last_x = *res.x.last().unwrap();
        let last_y = *res.y.last().unwrap();
        let last_yaw = *res.yaw.last().unwrap();

        assert!(
            approx_eq(last_x, target.x, cfg.cost_th),
            "x error too large: {last_x} vs {}",
            target.x
        );
        assert!(
            approx_eq(last_y, target.y, cfg.cost_th),
            "y error too large: {last_y} vs {}",
            target.y
        );
        assert!(
            pi2pi(last_yaw - target.yaw).abs() < cfg.cost_th,
            "yaw error too large"
        );
    }

    #[test]
    fn test_optimize_trajectory_straight_ahead() {
        let cfg = default_cfg();
        let target = TargetState::new(10.0, 0.0, 0.0);
        let init_p = Vector3::new(10.0, 0.0, 0.0);

        let result = optimize_trajectory(&target, 0.0, init_p, &cfg);
        assert!(result.is_some(), "Straight-ahead should converge");
    }

    #[test]
    fn test_optimize_trajectory_negative_yaw() {
        let cfg = default_cfg();
        let target = TargetState::new(5.0, -2.0, -PI / 4.0);
        let init_p = Vector3::new(6.0, 0.0, 0.0);

        let result = optimize_trajectory(&target, 0.0, init_p, &cfg);
        assert!(result.is_some(), "Negative yaw target should converge");
    }

    #[test]
    fn test_optimize_trajectory_zero_iterations_returns_none() {
        let cfg = MptgConfig {
            max_iter: 0,
            ..Default::default()
        };
        let target = TargetState::new(10.0, 0.0, 0.0);
        let init_p = Vector3::new(10.0, 0.0, 0.0);
        assert!(optimize_trajectory(&target, 0.0, init_p, &cfg).is_none());
    }

    #[test]
    fn test_lookup_table_generation() {
        let cfg = MptgConfig {
            max_iter: 100,
            cost_th: 0.3,
            ..Default::default()
        };

        let x_range: Vec<f64> = vec![10.0, 15.0];
        let y_range: Vec<f64> = vec![0.0, 5.0];
        let yaw_range: Vec<f64> = vec![0.0];

        let table = generate_lookup_table(&x_range, &y_range, &yaw_range, 0.0, &cfg);

        // Should have the seed entry plus some solved entries
        assert!(
            table.len() > 1,
            "Lookup table should contain more than just the seed"
        );
    }

    #[test]
    fn test_search_nearest_in_lookup_table() {
        let table = vec![
            LookupEntry {
                x: 1.0,
                y: 0.0,
                yaw: 0.0,
                s: 1.0,
                km: 0.0,
                kf: 0.0,
            },
            LookupEntry {
                x: 10.0,
                y: 5.0,
                yaw: 0.5,
                s: 11.0,
                km: 0.1,
                kf: 0.05,
            },
        ];
        let nearest = search_nearest_in_lookup_table(9.0, 4.0, 0.4, &table).unwrap();
        assert!(approx_eq(nearest.x, 10.0, 1e-10));
    }

    #[test]
    fn test_search_nearest_returns_first_on_tie() {
        let entry = LookupEntry {
            x: 2.0,
            y: 1.0,
            yaw: 0.0,
            s: 2.0,
            km: 0.0,
            kf: 0.0,
        };
        let table = vec![entry, entry];
        let nearest = search_nearest_in_lookup_table(2.0, 1.0, 0.0, &table).unwrap();
        assert!(std::ptr::eq(nearest, &table[0]));
    }

    #[test]
    fn test_search_nearest_empty_table() {
        let table: Vec<LookupEntry> = vec![];
        assert!(search_nearest_in_lookup_table(1.0, 0.0, 0.0, &table).is_none());
    }

    #[test]
    fn test_config_default() {
        let cfg = MptgConfig::default();
        assert!(approx_eq(cfg.wheel_base, 1.0, 1e-10));
        assert!(approx_eq(cfg.ds, 0.1, 1e-10));
        assert!(approx_eq(cfg.velocity, 10.0 / 3.6, 1e-10));
        assert_eq!(cfg.max_iter, 100);
    }
}