1#![allow(clippy::too_many_arguments)]
2
3use nalgebra::{Matrix3, Vector3};
16
17#[derive(Debug, Clone)]
23pub struct MptgConfig {
24 pub wheel_base: f64,
26 pub ds: f64,
28 pub velocity: f64,
30 pub h: Vector3<f64>,
32 pub max_iter: usize,
34 pub cost_th: f64,
36}
37
38impl Default for MptgConfig {
39 fn default() -> Self {
40 Self {
41 wheel_base: 1.0,
42 ds: 0.1,
43 velocity: 10.0 / 3.6,
44 h: Vector3::new(0.5, 0.02, 0.02),
45 max_iter: 100,
46 cost_th: 0.1,
47 }
48 }
49}
50
/// Goal pose for the trajectory, expressed in the vehicle's start frame
/// (the trajectory always starts at the origin with heading 0).
#[derive(Debug, Clone, Copy)]
pub struct TargetState {
    /// Goal x position.
    pub x: f64,
    /// Goal y position.
    pub y: f64,
    /// Goal heading in radians.
    pub yaw: f64,
}
62
63impl TargetState {
64 pub fn new(x: f64, y: f64, yaw: f64) -> Self {
65 Self { x, y, yaw }
66 }
67}
68
69#[derive(Debug, Clone)]
75pub struct MptgResult {
76 pub x: Vec<f64>,
78 pub y: Vec<f64>,
80 pub yaw: Vec<f64>,
82 pub params: Vector3<f64>,
84}
85
/// Pose of the kinematic bicycle model used during forward simulation.
struct BicycleState {
    /// Position along the start frame's x axis.
    x: f64,
    /// Position along the start frame's y axis.
    y: f64,
    /// Heading in radians, kept wrapped by `pi2pi` after every update.
    yaw: f64,
}
95
96impl BicycleState {
97 fn new() -> Self {
98 Self {
99 x: 0.0,
100 y: 0.0,
101 yaw: 0.0,
102 }
103 }
104
105 fn update(&mut self, v: f64, delta: f64, dt: f64, wheel_base: f64) {
106 self.x += v * self.yaw.cos() * dt;
107 self.y += v * self.yaw.sin() * dt;
108 self.yaw += v / wheel_base * delta.tan() * dt;
109 self.yaw = pi2pi(self.yaw);
110 }
111}
112
/// Wraps an angle in radians into the range `[-PI, PI]`.
///
/// Angles already inside `[-PI, PI]` are returned unchanged. Larger magnitudes are first
/// reduced with the truncated remainder `%`, which is exact and keeps the sign of `angle`.
/// The result is then shifted by at most one full turn. So, for example,
/// `3.0 * PI` maps to `PI` and `-3.0 * PI` maps to `-PI`.
///
/// Non-finite inputs (`NaN`, `+inf`, `-inf`) yield `NaN` rather than looping forever.
fn pi2pi(angle: f64) -> f64 {
    use std::f64::consts::PI;
    const TWO_PI: f64 = 2.0 * PI;

    // A plain subtract-2PI loop takes O(|angle|) iterations. It never terminates for
    // infinities, or for magnitudes where `angle - 2PI == angle`. The remainder gives a
    // value in (-2PI, 2PI) in O(1), with the same sign as `angle`.
    let reduced = angle % TWO_PI;
    if reduced > PI {
        reduced - TWO_PI
    } else if reduced < -PI {
        reduced + TWO_PI
    } else {
        reduced
    }
}
126
127fn quad_interp(t: (f64, f64, f64), k: (f64, f64, f64)) -> (f64, f64, f64) {
129 let mat = Matrix3::new(
130 t.0 * t.0,
131 t.0,
132 1.0,
133 t.1 * t.1,
134 t.1,
135 1.0,
136 t.2 * t.2,
137 t.2,
138 1.0,
139 );
140 let rhs = Vector3::new(k.0, k.1, k.2);
141 let coef = mat.try_inverse().expect("quad_interp: singular matrix") * rhs;
142 (coef[0], coef[1], coef[2])
143}
144
/// Evaluates the quadratic `a*t^2 + b*t + c` for `coef = (a, b, c)` at time `t`.
///
/// Despite the name, the trajectory generators in this file pass the value straight to
/// `BicycleState::update` as the steering input `delta`.
fn eval_curvature(coef: (f64, f64, f64), t: f64) -> f64 {
    let (a, b, c) = coef;
    a * t * t + b * t + c
}
150
151fn generate_trajectory(
158 s: f64,
159 km: f64,
160 kf: f64,
161 k0: f64,
162 cfg: &MptgConfig,
163) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
164 let n = (s / cfg.ds).round().max(1.0) as usize;
165 let time = s / cfg.velocity;
166 let dt = time / n as f64;
167
168 let coef = quad_interp((0.0, time / 2.0, time), (k0, km, kf));
169
170 let mut state = BicycleState::new();
171 let mut xs = vec![state.x];
172 let mut ys = vec![state.y];
173 let mut yaws = vec![state.yaw];
174
175 for i in 0..n {
176 let t = i as f64 * dt;
177 let delta = eval_curvature(coef, t);
178 state.update(cfg.velocity, delta, dt, cfg.wheel_base);
179 xs.push(state.x);
180 ys.push(state.y);
181 yaws.push(state.yaw);
182 }
183
184 (xs, ys, yaws)
185}
186
187fn generate_last_state(s: f64, km: f64, kf: f64, k0: f64, cfg: &MptgConfig) -> (f64, f64, f64) {
189 let n = (s / cfg.ds).round().max(1.0) as usize;
190 let time = s / cfg.velocity;
191 let dt = time / n as f64;
192
193 let coef = quad_interp((0.0, time / 2.0, time), (k0, km, kf));
194
195 let mut state = BicycleState::new();
196 for i in 0..n {
197 let t = i as f64 * dt;
198 let delta = eval_curvature(coef, t);
199 state.update(cfg.velocity, delta, dt, cfg.wheel_base);
200 }
201
202 (state.x, state.y, state.yaw)
203}
204
205fn calc_diff(target: &TargetState, x: f64, y: f64, yaw: f64) -> Vector3<f64> {
211 Vector3::new(target.x - x, target.y - y, pi2pi(target.yaw - yaw))
212}
213
214fn calc_jacobian(
216 target: &TargetState,
217 p: &Vector3<f64>,
218 k0: f64,
219 cfg: &MptgConfig,
220) -> Matrix3<f64> {
221 let h = &cfg.h;
222 let mut cols: [Vector3<f64>; 3] = [Vector3::zeros(); 3];
223
224 for dim in 0..3 {
225 let mut pp = *p;
226 let mut pn = *p;
227 pp[dim] += h[dim];
228 pn[dim] -= h[dim];
229
230 let (xp, yp, yawp) = generate_last_state(pp[0], pp[1], pp[2], k0, cfg);
231 let dp = calc_diff(target, xp, yp, yawp);
232
233 let (xn, yn, yawn) = generate_last_state(pn[0], pn[1], pn[2], k0, cfg);
234 let dn = calc_diff(target, xn, yn, yawn);
235
236 cols[dim] = (dp - dn) / (2.0 * h[dim]);
237 }
238
239 Matrix3::from_columns(&cols)
240}
241
242fn select_learning_rate(
244 dp: &Vector3<f64>,
245 p: &Vector3<f64>,
246 k0: f64,
247 target: &TargetState,
248 cfg: &MptgConfig,
249) -> f64 {
250 let mut best_alpha = 1.0;
251 let mut min_cost = f64::MAX;
252
253 let mut alpha = 1.0;
254 while alpha < 2.0 {
255 let tp = p + alpha * dp;
256 let (xc, yc, yawc) = generate_last_state(tp[0], tp[1], tp[2], k0, cfg);
257 let dc = calc_diff(target, xc, yc, yawc);
258 let cost = dc.norm();
259 if cost < min_cost {
260 best_alpha = alpha;
261 min_cost = cost;
262 }
263 alpha += 0.5;
264 }
265
266 best_alpha
267}
268
269pub fn optimize_trajectory(
283 target: &TargetState,
284 k0: f64,
285 init_p: Vector3<f64>,
286 cfg: &MptgConfig,
287) -> Option<MptgResult> {
288 let mut p = init_p;
289
290 for _ in 0..cfg.max_iter {
291 let (xc, yc, yawc) = generate_trajectory(p[0], p[1], p[2], k0, cfg);
292
293 let last_x = *xc.last().unwrap();
294 let last_y = *yc.last().unwrap();
295 let last_yaw = *yawc.last().unwrap();
296
297 let dc = calc_diff(target, last_x, last_y, last_yaw);
298 let cost = dc.norm();
299
300 if cost <= cfg.cost_th {
301 return Some(MptgResult {
302 x: xc,
303 y: yc,
304 yaw: yawc,
305 params: p,
306 });
307 }
308
309 let j = calc_jacobian(target, &p, k0, cfg);
310 let j_inv = j.try_inverse()?;
311 let dp = -j_inv * dc;
312
313 let alpha = select_learning_rate(&dp, &p, k0, target, cfg);
314 p += alpha * dp;
315 }
316
317 None
319}
320
/// One row of the trajectory lookup table: a reached terminal pose and the parameters that
/// produced it.
#[derive(Debug, Clone, Copy)]
pub struct LookupEntry {
    /// Terminal x position.
    pub x: f64,
    /// Terminal y position.
    pub y: f64,
    /// Terminal heading in radians.
    pub yaw: f64,
    /// Arc length of the trajectory. It is not read elsewhere in this module, hence the
    /// `dead_code` allowance.
    #[allow(dead_code)]
    pub s: f64,
    /// Steering knot at the trajectory midpoint.
    pub km: f64,
    /// Steering knot at the trajectory end.
    pub kf: f64,
}

/// Returns the entry whose terminal pose is closest to `(tx, ty, tyaw)`.
///
/// Distance is the squared Euclidean distance in `(x, y, yaw)` space; yaw differences are
/// not wrapped. Ties keep the earliest entry. Entries whose distance is NaN are treated as
/// infinitely far, so they are never preferred over a finite candidate. Returns `None` only
/// for an empty table.
pub fn search_nearest_in_lookup_table(
    tx: f64,
    ty: f64,
    tyaw: f64,
    table: &[LookupEntry],
) -> Option<&LookupEntry> {
    let sq_dist = |e: &LookupEntry| {
        let d = (tx - e.x).powi(2) + (ty - e.y).powi(2) + (tyaw - e.yaw).powi(2);
        // A NaN compared via `partial_cmp` used to fall back to `Equal`, letting a corrupted
        // entry win; mapping it to +inf gives a well-defined total order.
        if d.is_nan() {
            f64::INFINITY
        } else {
            d
        }
    };

    // Compute each distance once rather than twice per comparison.
    table
        .iter()
        .map(|e| (sq_dist(e), e))
        .min_by(|a, b| a.0.total_cmp(&b.0))
        .map(|(_, e)| e)
}
351
352pub fn generate_lookup_table(
358 x_range: &[f64],
359 y_range: &[f64],
360 yaw_range: &[f64],
361 k0: f64,
362 cfg: &MptgConfig,
363) -> Vec<LookupEntry> {
364 let mut table = vec![LookupEntry {
365 x: 1.0,
366 y: 0.0,
367 yaw: 0.0,
368 s: 1.0,
369 km: 0.0,
370 kf: 0.0,
371 }];
372
373 for &yaw in yaw_range {
374 for &y in y_range {
375 for &x in x_range {
376 let best = search_nearest_in_lookup_table(x, y, yaw, &table).unwrap();
377 let target = TargetState::new(x, y, yaw);
378 let s_init = (x * x + y * y).sqrt();
379 let init_p = Vector3::new(s_init, best.km, best.kf);
380
381 if let Some(result) = optimize_trajectory(&target, k0, init_p, cfg) {
382 let last_x = *result.x.last().unwrap();
383 let last_y = *result.y.last().unwrap();
384 let last_yaw = *result.yaw.last().unwrap();
385 table.push(LookupEntry {
386 x: last_x,
387 y: last_y,
388 yaw: last_yaw,
389 s: result.params[0],
390 km: result.params[1],
391 kf: result.params[2],
392 });
393 }
394 }
395 }
396 }
397
398 table
399}
400
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Configuration shared by tests that need no custom settings.
    fn default_cfg() -> MptgConfig {
        MptgConfig::default()
    }

    /// Final sample of a trajectory component (trajectories always hold the start pose).
    fn tail(samples: &[f64]) -> f64 {
        *samples.last().unwrap()
    }

    #[test]
    fn test_pi2pi() {
        // Odd multiples of PI land on the boundary; in-range values pass through unchanged.
        let cases = [(3.0 * PI, PI), (-3.0 * PI, -PI), (0.5, 0.5)];
        for &(input, expected) in cases.iter() {
            assert!((pi2pi(input) - expected).abs() < 1e-10);
        }
    }

    #[test]
    fn test_quad_interp_linear() {
        // Samples of k(t) = 2t + 1: the fitted quadratic term must vanish.
        let (a, b, c) = quad_interp((0.0, 1.0, 2.0), (1.0, 3.0, 5.0));
        for &(got, want) in [(a, 0.0), (b, 2.0), (c, 1.0)].iter() {
            assert!((got - want).abs() < 1e-10);
        }
    }

    #[test]
    fn test_generate_trajectory_straight() {
        let cfg = default_cfg();
        // Zero steering everywhere: a straight run of length 5 along +x.
        let (xs, ys, yaws) = generate_trajectory(5.0, 0.0, 0.0, 0.0, &cfg);
        assert!(xs.len() > 2);

        let last_x = tail(&xs);
        assert!((last_x - 5.0).abs() < 0.5, "Expected ~5.0, got {last_x}");

        let last_y = tail(&ys).abs();
        assert!(last_y < 0.1, "Expected ~0.0, got {last_y}");

        let last_yaw = tail(&yaws).abs();
        assert!(last_yaw < 0.1, "Expected ~0.0, got {last_yaw}");
    }

    #[test]
    fn test_generate_last_state_matches_trajectory() {
        let cfg = default_cfg();
        let (s, km, kf, k0) = (6.0, 0.1, -0.05, 0.0);
        let (xs, ys, yaws) = generate_trajectory(s, km, kf, k0, &cfg);
        let (lx, ly, lyaw) = generate_last_state(s, km, kf, k0, &cfg);
        assert!((tail(&xs) - lx).abs() < 1e-10, "x mismatch");
        assert!((tail(&ys) - ly).abs() < 1e-10, "y mismatch");
        assert!((tail(&yaws) - lyaw).abs() < 1e-10, "yaw mismatch");
    }

    #[test]
    fn test_optimize_trajectory_90deg() {
        let cfg = default_cfg();
        let target = TargetState::new(5.0, 2.0, PI / 2.0);
        let k0 = 0.0;
        let init_p = Vector3::new(6.0, 0.0, 0.0);

        let res = match optimize_trajectory(&target, k0, init_p, &cfg) {
            Some(res) => res,
            None => panic!("Optimization should converge"),
        };

        let last_x = tail(&res.x);
        let last_y = tail(&res.y);
        let last_yaw = tail(&res.yaw);

        assert!(
            (last_x - target.x).abs() < cfg.cost_th,
            "x error too large: {last_x} vs {}",
            target.x
        );
        assert!(
            (last_y - target.y).abs() < cfg.cost_th,
            "y error too large: {last_y} vs {}",
            target.y
        );
        assert!(
            pi2pi(last_yaw - target.yaw).abs() < cfg.cost_th,
            "yaw error too large"
        );
    }

    #[test]
    fn test_optimize_trajectory_straight_ahead() {
        let cfg = default_cfg();
        let target = TargetState::new(10.0, 0.0, 0.0);
        let init_p = Vector3::new(10.0, 0.0, 0.0);
        assert!(
            optimize_trajectory(&target, 0.0, init_p, &cfg).is_some(),
            "Straight-ahead should converge"
        );
    }

    #[test]
    fn test_optimize_trajectory_negative_yaw() {
        let cfg = default_cfg();
        let target = TargetState::new(5.0, -2.0, -PI / 4.0);
        let init_p = Vector3::new(6.0, 0.0, 0.0);
        assert!(
            optimize_trajectory(&target, 0.0, init_p, &cfg).is_some(),
            "Negative yaw target should converge"
        );
    }

    #[test]
    fn test_lookup_table_generation() {
        // Looser threshold so the coarse grid converges quickly.
        let cfg = MptgConfig {
            max_iter: 100,
            cost_th: 0.3,
            ..Default::default()
        };

        let x_range = [10.0, 15.0];
        let y_range = [0.0, 5.0];
        let yaw_range = [0.0];

        let table = generate_lookup_table(&x_range, &y_range, &yaw_range, 0.0, &cfg);

        assert!(
            table.len() > 1,
            "Lookup table should contain more than just the seed"
        );
    }

    #[test]
    fn test_search_nearest_in_lookup_table() {
        let seed = LookupEntry {
            x: 1.0,
            y: 0.0,
            yaw: 0.0,
            s: 1.0,
            km: 0.0,
            kf: 0.0,
        };
        let far = LookupEntry {
            x: 10.0,
            y: 5.0,
            yaw: 0.5,
            s: 11.0,
            km: 0.1,
            kf: 0.05,
        };
        let table = [seed, far];

        let nearest = search_nearest_in_lookup_table(9.0, 4.0, 0.4, &table).unwrap();
        assert!((nearest.x - 10.0).abs() < 1e-10);
    }

    #[test]
    fn test_search_nearest_empty_table() {
        let table: Vec<LookupEntry> = Vec::new();
        assert!(search_nearest_in_lookup_table(1.0, 0.0, 0.0, &table).is_none());
    }

    #[test]
    fn test_config_default() {
        let cfg = MptgConfig::default();
        assert!((cfg.wheel_base - 1.0).abs() < 1e-10);
        assert!((cfg.ds - 0.1).abs() < 1e-10);
        assert_eq!(cfg.max_iter, 100);
    }
}