// Source path: rust_robotics_planning/state_lattice/trajectory_generator.rs

//! Model Predictive Trajectory Generator
//!
//! Optimizes the trajectory parameters `[s, km, kf]` so that the rolled-out trajectory
//! reaches a target state, using the Newton-Raphson method with a finite-difference
//! Jacobian. Based on the PythonRobotics implementation.

use std::fmt::Write as _;

use nalgebra::{Matrix3, Vector3};

use super::motion_model::{normalize_angle, MotionModel};

10const DEFAULT_LOOKUP_TABLE_CSV: &str = include_str!("lookup_table.csv");
11
12/// Trajectory parameters: [s, km, kf]
13/// - s: arc length
14/// - km: middle curvature
15/// - kf: final curvature
16pub type TrajectoryParams = Vector3<f64>;
17pub type GeneratedTrajectory = (Vec<f64>, Vec<f64>, Vec<f64>, TrajectoryParams);
18
19/// Target state: [x, y, yaw]
20pub type TargetState = Vector3<f64>;
21
22/// Trajectory generator configuration
23#[derive(Debug, Clone)]
24pub struct TrajectoryGeneratorConfig {
25    /// Maximum optimization iterations
26    pub max_iter: usize,
27    /// Convergence threshold for cost
28    pub cost_threshold: f64,
29    /// Finite difference step for Jacobian
30    pub h: Vector3<f64>,
31    /// Initial curvature (steering at start)
32    pub k0: f64,
33}
34
35impl Default for TrajectoryGeneratorConfig {
36    fn default() -> Self {
37        Self {
38            max_iter: 100,
39            cost_threshold: 0.1,
40            h: Vector3::new(0.5, 0.02, 0.02), // [ds, dkm, dkf]
41            k0: 0.0,
42        }
43    }
44}
45
46/// Model Predictive Trajectory Generator
47pub struct TrajectoryGenerator {
48    pub(crate) motion_model: MotionModel,
49    config: TrajectoryGeneratorConfig,
50}
51
52impl TrajectoryGenerator {
53    pub fn new(motion_model: MotionModel, config: TrajectoryGeneratorConfig) -> Self {
54        Self {
55            motion_model,
56            config,
57        }
58    }
59
60    pub fn with_defaults() -> Self {
61        Self::new(
62            MotionModel::with_defaults(),
63            TrajectoryGeneratorConfig::default(),
64        )
65    }
66
67    /// Set initial curvature
68    pub fn set_k0(&mut self, k0: f64) {
69        self.config.k0 = k0;
70    }
71
72    /// Generate trajectory with given parameters
73    pub fn generate(&self, params: &TrajectoryParams) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
74        self.motion_model.generate_trajectory(
75            params[0], // s
76            self.config.k0,
77            params[1], // km
78            params[2], // kf
79        )
80    }
81
82    /// Calculate difference between current trajectory endpoint and target
83    fn calc_diff(&self, params: &TrajectoryParams, target: &TargetState) -> Vector3<f64> {
84        let (x_final, y_final, yaw_final) = self.motion_model.generate_trajectory_final_state(
85            params[0],
86            self.config.k0,
87            params[1],
88            params[2],
89        );
90
91        Vector3::new(
92            x_final - target[0],
93            y_final - target[1],
94            normalize_angle(yaw_final - target[2]),
95        )
96    }
97
98    /// Calculate cost (L2 norm of difference)
99    fn calc_cost(&self, params: &TrajectoryParams, target: &TargetState) -> f64 {
100        let diff = self.calc_diff(params, target);
101        diff.norm()
102    }
103
104    /// Calculate Jacobian matrix using finite differences
105    fn calc_jacobian(&self, params: &TrajectoryParams, target: &TargetState) -> Matrix3<f64> {
106        let h = &self.config.h;
107        let mut jacobian = Matrix3::zeros();
108        let diff_current = self.calc_diff(params, target);
109
110        for i in 0..3 {
111            let mut params_plus = *params;
112            params_plus[i] += h[i];
113            let mut params_minus = *params;
114            params_minus[i] -= h[i];
115
116            let diff_plus = self.calc_diff(&params_plus, target);
117            let diff_minus = self.calc_diff(&params_minus, target);
118
119            for j in 0..3 {
120                let delta = if i == 0 && params_minus[i] <= 0.0 {
121                    diff_plus[j] - diff_current[j]
122                } else {
123                    diff_plus[j] - diff_minus[j]
124                };
125                let denom = if i == 0 && params_minus[i] <= 0.0 {
126                    h[i]
127                } else {
128                    2.0 * h[i]
129                };
130                jacobian[(j, i)] = delta / denom;
131            }
132        }
133
134        jacobian
135    }
136
137    /// Line search to find optimal step size
138    fn line_search(
139        &self,
140        params: &TrajectoryParams,
141        dp: &Vector3<f64>,
142        target: &TargetState,
143    ) -> f64 {
144        let alphas = [1.0, 1.5];
145        let mut best_alpha = 1.0;
146        let mut min_cost = f64::MAX;
147
148        for &alpha in &alphas {
149            let new_params = params + alpha * dp;
150            if new_params[0] > 0.0 {
151                let cost = self.calc_cost(&new_params, target);
152                if cost < min_cost {
153                    min_cost = cost;
154                    best_alpha = alpha;
155                }
156            }
157        }
158
159        best_alpha
160    }
161
162    /// Optimize trajectory to reach target state
163    pub fn optimize(
164        &self,
165        target: &TargetState,
166        init_params: &TrajectoryParams,
167    ) -> Option<TrajectoryParams> {
168        let mut params = *init_params;
169
170        for _iter in 0..self.config.max_iter {
171            let cost = self.calc_cost(&params, target);
172
173            if cost < self.config.cost_threshold {
174                return Some(params);
175            }
176
177            let jacobian = self.calc_jacobian(&params, target);
178
179            let diff = self.calc_diff(&params, target);
180
181            if let Some(j_inv) = jacobian.try_inverse() {
182                let dp = -j_inv * diff;
183
184                let alpha = self.line_search(&params, &dp, target);
185                params += alpha * dp;
186
187                if params[0] < 0.1 {
188                    params[0] = 0.1;
189                }
190            } else {
191                return None;
192            }
193        }
194
195        None
196    }
197
198    /// Generate optimized trajectory to reach target
199    pub fn generate_optimized(
200        &self,
201        target: &TargetState,
202        init_params: &TrajectoryParams,
203    ) -> Option<GeneratedTrajectory> {
204        let params = self.optimize(target, init_params)?;
205        let (x, y, yaw) = self.generate(&params);
206        Some((x, y, yaw, params))
207    }
208}
209
/// One row of the trajectory lookup table.
///
/// It pairs a target pose `(x, y, yaw)` with trajectory parameters `(s, km, kf)`.
#[derive(Debug, Clone)]
pub struct LookupTableEntry {
    /// Target x position.
    pub x: f64,
    /// Target y position.
    pub y: f64,
    /// Target heading.
    pub yaw: f64,
    /// Arc length parameter.
    pub s: f64,
    /// Middle curvature parameter.
    pub km: f64,
    /// Final curvature parameter.
    pub kf: f64,
}

221impl LookupTableEntry {
222    pub fn new(x: f64, y: f64, yaw: f64, s: f64, km: f64, kf: f64) -> Self {
223        Self {
224            x,
225            y,
226            yaw,
227            s,
228            km,
229            kf,
230        }
231    }
232
233    /// Get target state from entry
234    pub fn target(&self) -> TargetState {
235        Vector3::new(self.x, self.y, self.yaw)
236    }
237
238    /// Get trajectory parameters from entry
239    pub fn params(&self) -> TrajectoryParams {
240        Vector3::new(self.s, self.km, self.kf)
241    }
242
243    /// Calculate distance to a target state
244    pub fn distance_to(&self, target: &TargetState) -> f64 {
245        let dx = self.x - target[0];
246        let dy = self.y - target[1];
247        let dyaw = self.yaw - target[2];
248        (dx * dx + dy * dy + dyaw * dyaw).sqrt()
249    }
250}
251
252/// Lookup table for efficient trajectory initialization
253#[derive(Debug, Clone)]
254pub struct LookupTable {
255    entries: Vec<LookupTableEntry>,
256}
257
258impl LookupTable {
259    pub fn new() -> Self {
260        Self {
261            entries: Vec::new(),
262        }
263    }
264
265    /// Create lookup table from CSV data
266    pub fn from_csv(csv_data: &str) -> Self {
267        let mut entries = Vec::new();
268
269        for line in csv_data.lines() {
270            let line = line.trim();
271            if line.is_empty() || line.starts_with('#') || line.starts_with("x,") {
272                continue;
273            }
274
275            let parts: Vec<&str> = line.split(',').collect();
276            if parts.len() >= 6 {
277                if let (Ok(x), Ok(y), Ok(yaw), Ok(s), Ok(km), Ok(kf)) = (
278                    parts[0].trim().parse::<f64>(),
279                    parts[1].trim().parse::<f64>(),
280                    parts[2].trim().parse::<f64>(),
281                    parts[3].trim().parse::<f64>(),
282                    parts[4].trim().parse::<f64>(),
283                    parts[5].trim().parse::<f64>(),
284                ) {
285                    entries.push(LookupTableEntry::new(x, y, yaw, s, km, kf));
286                }
287            }
288        }
289
290        Self { entries }
291    }
292
293    /// Generate default lookup table with common trajectories
294    pub fn generate_default() -> Self {
295        Self::from_csv(DEFAULT_LOOKUP_TABLE_CSV)
296    }
297
298    /// Find nearest entry to target state
299    pub fn find_nearest(&self, target: &TargetState) -> Option<&LookupTableEntry> {
300        self.entries.iter().min_by(|a, b| {
301            a.distance_to(target)
302                .partial_cmp(&b.distance_to(target))
303                .unwrap_or(std::cmp::Ordering::Equal)
304        })
305    }
306
307    /// Add entry to table
308    pub fn add(&mut self, entry: LookupTableEntry) {
309        self.entries.push(entry);
310    }
311
312    /// Get number of entries
313    pub fn len(&self) -> usize {
314        self.entries.len()
315    }
316
317    /// Check if table is empty
318    pub fn is_empty(&self) -> bool {
319        self.entries.is_empty()
320    }
321
322    /// Convert to CSV string
323    pub fn to_csv(&self) -> String {
324        let mut csv = String::from("x,y,yaw,s,km,kf\n");
325        for entry in &self.entries {
326            csv.push_str(&format!(
327                "{},{},{},{},{},{}\n",
328                entry.x, entry.y, entry.yaw, entry.s, entry.km, entry.kf
329            ));
330        }
331        csv
332    }
333}
334
335impl Default for LookupTable {
336    fn default() -> Self {
337        Self::generate_default()
338    }
339}
340
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_trajectory_generator_straight() {
        let generator = TrajectoryGenerator::with_defaults();
        let params = Vector3::new(5.0, 0.0, 0.0);

        let (x, y, _yaw) = generator.generate(&params);

        assert!(x.len() > 1);
        let final_x = x.last().unwrap();
        let final_y = y.last().unwrap();
        assert!(*final_x > 4.0);
        assert!(final_y.abs() < 0.1);
    }

    #[test]
    fn test_trajectory_generator_turn() {
        let generator = TrajectoryGenerator::with_defaults();
        let params = Vector3::new(5.0, 0.1, 0.1);

        let (x, y, _yaw) = generator.generate(&params);

        assert!(x.len() > 1);
        let final_y = y.last().unwrap();
        assert!(*final_y > 0.0);
    }

    #[test]
    fn test_calc_diff() {
        let generator = TrajectoryGenerator::with_defaults();
        let target = Vector3::new(5.0, 0.0, 0.0);
        let params = Vector3::new(5.0, 0.0, 0.0);

        let diff = generator.calc_diff(&params, &target);

        assert!(diff.norm() < 1.0);
    }

    #[test]
    fn test_optimize_straight() {
        let generator = TrajectoryGenerator::with_defaults();
        let target = Vector3::new(10.0, 0.0, 0.0);
        let init_params = Vector3::new(10.0, 0.0, 0.0);

        let result = generator.optimize(&target, &init_params);

        assert!(result.is_some());
        let params = result.unwrap();
        assert!(params[0] > 0.0);
    }

    #[test]
    fn test_optimize_turn() {
        let generator = TrajectoryGenerator::with_defaults();
        let target = Vector3::new(8.0, 3.0, 0.5);
        let init_params = Vector3::new(10.0, 0.05, 0.05);

        let result = generator.optimize(&target, &init_params);

        // Convergence is not guaranteed for this target; when it does converge,
        // the arc length must respect the positivity clamp.
        if let Some(params) = result {
            assert!(params[0] > 0.0);
        }
    }

    #[test]
    fn test_optimize_matches_upstream_lane_reference() {
        let generator = TrajectoryGenerator::with_defaults();
        let target = Vector3::new(10.0, 9.0, 0.0);
        let init_params = Vector3::new(13.45362404707371, 0.1482242831571022, -0.5606578442626601);

        let params = generator.optimize(&target, &init_params).unwrap();

        assert!((params[0] - 14.806296460297).abs() < 1e-9);
        assert!((params[1] - 0.148478839778).abs() < 1e-9);
        assert!((params[2] - -0.57288113757).abs() < 1e-9);
    }

    #[test]
    fn test_lookup_table_default() {
        let table = LookupTable::generate_default();
        assert!(!table.is_empty());
        assert_eq!(table.len(), 81);
    }

    #[test]
    fn test_lookup_table_new_is_empty_but_default_is_loaded() {
        assert!(LookupTable::new().is_empty());
        assert_eq!(LookupTable::default().len(), 81);
    }

    #[test]
    fn test_lookup_table_default_matches_upstream_reference_rows() {
        let table = LookupTable::generate_default();
        let first = &table.entries[0];
        let last = table.entries.last().unwrap();

        assert!((first.x - 1.0).abs() < 1e-12);
        assert!((first.y - 0.0).abs() < 1e-12);
        assert!((first.yaw - 0.0).abs() < 1e-12);
        assert!((first.s - 1.0).abs() < 1e-12);
        assert!((first.km - 0.0).abs() < 1e-12);
        assert!((first.kf - 0.0).abs() < 1e-12);

        assert!((last.x - 24.960019173190652).abs() < 1e-12);
        assert!((last.y - 17.98909417109214).abs() < 1e-12);
        assert!((last.yaw - 0.011594018486178026).abs() < 1e-12);
        assert!((last.s - 33.0995680641525).abs() < 1e-12);
        assert!((last.km - 0.05634561447882407).abs() < 1e-12);
        assert!((last.kf - -0.22402297280749597).abs() < 1e-12);
    }

    #[test]
    fn test_lookup_table_find_nearest() {
        let table = LookupTable::generate_default();
        let target = Vector3::new(10.0, 0.0, 0.0);

        let nearest = table.find_nearest(&target);
        assert!(nearest.is_some());
    }

    #[test]
    fn test_lookup_table_find_nearest_picks_closest_entry() {
        let mut table = LookupTable::new();
        table.add(LookupTableEntry::new(1.0, 0.0, 0.0, 1.0, 0.0, 0.0));
        table.add(LookupTableEntry::new(10.0, 1.0, 0.0, 10.1, 0.0, 0.0));
        table.add(LookupTableEntry::new(20.0, 0.0, 0.0, 20.0, 0.0, 0.0));

        let nearest = table.find_nearest(&Vector3::new(9.0, 0.0, 0.0)).unwrap();
        assert_eq!(nearest.x, 10.0);
        assert_eq!(nearest.s, 10.1);

        assert!(LookupTable::new()
            .find_nearest(&Vector3::new(0.0, 0.0, 0.0))
            .is_none());
    }

    #[test]
    fn test_lookup_table_from_csv_skips_non_data_lines() {
        let csv = "# comment\n\nx,y,yaw,s,km,kf\n1,2,3,4,5,6\nbad,row\n1,2,3,4,5,abc\n 7 , 8 , 9 , 10 , 11 , 12 , 13\n";

        let table = LookupTable::from_csv(csv);

        assert_eq!(table.len(), 2);
        assert_eq!(table.entries[0].x, 1.0);
        assert_eq!(table.entries[0].kf, 6.0);
        // Surrounding whitespace is trimmed and extra columns are ignored.
        assert_eq!(table.entries[1].x, 7.0);
        assert_eq!(table.entries[1].kf, 12.0);
    }

    #[test]
    fn test_lookup_table_csv() {
        let table = LookupTable::generate_default();
        let csv = table.to_csv();

        assert!(csv.starts_with("x,y,yaw,s,km,kf\n"));

        let parsed = LookupTable::from_csv(&csv);
        assert_eq!(table.len(), parsed.len());
        // `f64` Display output parses back to the same value, so every field must survive.
        for (original, reparsed) in table.entries.iter().zip(&parsed.entries) {
            assert_eq!(original.x, reparsed.x);
            assert_eq!(original.y, reparsed.y);
            assert_eq!(original.yaw, reparsed.yaw);
            assert_eq!(original.s, reparsed.s);
            assert_eq!(original.km, reparsed.km);
            assert_eq!(original.kf, reparsed.kf);
        }
    }

    #[test]
    fn test_lookup_entry_distance() {
        let entry = LookupTableEntry::new(10.0, 0.0, 0.0, 10.0, 0.0, 0.0);
        let target = Vector3::new(10.0, 0.0, 0.0);

        assert!(entry.distance_to(&target) < 0.001);
    }
}