1use nalgebra::{Matrix3, Vector3};
7
8use super::motion_model::{normalize_angle, MotionModel};
9
/// Default lookup table, embedded into the binary at compile time from
/// `lookup_table.csv` and parsed by `LookupTable::from_csv`.
/// NOTE(review): expected to hold `x,y,yaw,s,km,kf` rows — confirm against the file.
const DEFAULT_LOOKUP_TABLE_CSV: &str = include_str!("lookup_table.csv");

/// Optimization variables `[s, km, kf]`, in the same order as
/// `LookupTableEntry::params`. Component 0 is kept positive by the solver;
/// presumably it is the arc length — TODO confirm against `MotionModel`.
pub type TrajectoryParams = Vector3<f64>;
/// A generated trajectory `(x, y, yaw, params)`: the sampled `x`, `y` and `yaw`
/// vectors from the motion model plus the parameters that produced them.
pub type GeneratedTrajectory = (Vec<f64>, Vec<f64>, Vec<f64>, TrajectoryParams);

/// Desired final pose `[x, y, yaw]` of a trajectory.
pub type TargetState = Vector3<f64>;
21
22#[derive(Debug, Clone)]
24pub struct TrajectoryGeneratorConfig {
25 pub max_iter: usize,
27 pub cost_threshold: f64,
29 pub h: Vector3<f64>,
31 pub k0: f64,
33}
34
35impl Default for TrajectoryGeneratorConfig {
36 fn default() -> Self {
37 Self {
38 max_iter: 100,
39 cost_threshold: 0.1,
40 h: Vector3::new(0.5, 0.02, 0.02), k0: 0.0,
42 }
43 }
44}
45
/// Finds trajectory parameters `[s, km, kf]` whose simulated end state matches a
/// target pose. It uses a Newton iteration with a finite-difference Jacobian
/// (see `optimize`).
pub struct TrajectoryGenerator {
    /// Model that simulates a trajectory for a given parameter set.
    pub(crate) motion_model: MotionModel,
    /// Solver settings; `k0` can be changed afterwards with `set_k0`.
    config: TrajectoryGeneratorConfig,
}
51
52impl TrajectoryGenerator {
53 pub fn new(motion_model: MotionModel, config: TrajectoryGeneratorConfig) -> Self {
54 Self {
55 motion_model,
56 config,
57 }
58 }
59
60 pub fn with_defaults() -> Self {
61 Self::new(
62 MotionModel::with_defaults(),
63 TrajectoryGeneratorConfig::default(),
64 )
65 }
66
67 pub fn set_k0(&mut self, k0: f64) {
69 self.config.k0 = k0;
70 }
71
72 pub fn generate(&self, params: &TrajectoryParams) -> (Vec<f64>, Vec<f64>, Vec<f64>) {
74 self.motion_model.generate_trajectory(
75 params[0], self.config.k0,
77 params[1], params[2], )
80 }
81
82 fn calc_diff(&self, params: &TrajectoryParams, target: &TargetState) -> Vector3<f64> {
84 let (x_final, y_final, yaw_final) = self.motion_model.generate_trajectory_final_state(
85 params[0],
86 self.config.k0,
87 params[1],
88 params[2],
89 );
90
91 Vector3::new(
92 x_final - target[0],
93 y_final - target[1],
94 normalize_angle(yaw_final - target[2]),
95 )
96 }
97
98 fn calc_cost(&self, params: &TrajectoryParams, target: &TargetState) -> f64 {
100 let diff = self.calc_diff(params, target);
101 diff.norm()
102 }
103
104 fn calc_jacobian(&self, params: &TrajectoryParams, target: &TargetState) -> Matrix3<f64> {
106 let h = &self.config.h;
107 let mut jacobian = Matrix3::zeros();
108 let diff_current = self.calc_diff(params, target);
109
110 for i in 0..3 {
111 let mut params_plus = *params;
112 params_plus[i] += h[i];
113 let mut params_minus = *params;
114 params_minus[i] -= h[i];
115
116 let diff_plus = self.calc_diff(¶ms_plus, target);
117 let diff_minus = self.calc_diff(¶ms_minus, target);
118
119 for j in 0..3 {
120 let delta = if i == 0 && params_minus[i] <= 0.0 {
121 diff_plus[j] - diff_current[j]
122 } else {
123 diff_plus[j] - diff_minus[j]
124 };
125 let denom = if i == 0 && params_minus[i] <= 0.0 {
126 h[i]
127 } else {
128 2.0 * h[i]
129 };
130 jacobian[(j, i)] = delta / denom;
131 }
132 }
133
134 jacobian
135 }
136
137 fn line_search(
139 &self,
140 params: &TrajectoryParams,
141 dp: &Vector3<f64>,
142 target: &TargetState,
143 ) -> f64 {
144 let alphas = [1.0, 1.5];
145 let mut best_alpha = 1.0;
146 let mut min_cost = f64::MAX;
147
148 for &alpha in &alphas {
149 let new_params = params + alpha * dp;
150 if new_params[0] > 0.0 {
151 let cost = self.calc_cost(&new_params, target);
152 if cost < min_cost {
153 min_cost = cost;
154 best_alpha = alpha;
155 }
156 }
157 }
158
159 best_alpha
160 }
161
162 pub fn optimize(
164 &self,
165 target: &TargetState,
166 init_params: &TrajectoryParams,
167 ) -> Option<TrajectoryParams> {
168 let mut params = *init_params;
169
170 for _iter in 0..self.config.max_iter {
171 let cost = self.calc_cost(¶ms, target);
172
173 if cost < self.config.cost_threshold {
174 return Some(params);
175 }
176
177 let jacobian = self.calc_jacobian(¶ms, target);
178
179 let diff = self.calc_diff(¶ms, target);
180
181 if let Some(j_inv) = jacobian.try_inverse() {
182 let dp = -j_inv * diff;
183
184 let alpha = self.line_search(¶ms, &dp, target);
185 params += alpha * dp;
186
187 if params[0] < 0.1 {
188 params[0] = 0.1;
189 }
190 } else {
191 return None;
192 }
193 }
194
195 None
196 }
197
198 pub fn generate_optimized(
200 &self,
201 target: &TargetState,
202 init_params: &TrajectoryParams,
203 ) -> Option<GeneratedTrajectory> {
204 let params = self.optimize(target, init_params)?;
205 let (x, y, yaw) = self.generate(¶ms);
206 Some((x, y, yaw, params))
207 }
208}
209
/// One lookup-table row: a reachable pose `(x, y, yaw)` and the trajectory
/// parameters `(s, km, kf)` stored with it.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct LookupTableEntry {
    /// Final x position.
    pub x: f64,
    /// Final y position.
    pub y: f64,
    /// Final heading (radians presumably — confirm against the table source).
    pub yaw: f64,
    /// First trajectory parameter (component 0 of `TrajectoryParams`).
    pub s: f64,
    /// Second trajectory parameter (component 1 of `TrajectoryParams`).
    pub km: f64,
    /// Third trajectory parameter (component 2 of `TrajectoryParams`).
    pub kf: f64,
}
220
221impl LookupTableEntry {
222 pub fn new(x: f64, y: f64, yaw: f64, s: f64, km: f64, kf: f64) -> Self {
223 Self {
224 x,
225 y,
226 yaw,
227 s,
228 km,
229 kf,
230 }
231 }
232
233 pub fn target(&self) -> TargetState {
235 Vector3::new(self.x, self.y, self.yaw)
236 }
237
238 pub fn params(&self) -> TrajectoryParams {
240 Vector3::new(self.s, self.km, self.kf)
241 }
242
243 pub fn distance_to(&self, target: &TargetState) -> f64 {
245 let dx = self.x - target[0];
246 let dy = self.y - target[1];
247 let dyaw = self.yaw - target[2];
248 (dx * dx + dy * dy + dyaw * dyaw).sqrt()
249 }
250}
251
/// Collection of `LookupTableEntry` rows searchable by pose.
/// NOTE(review): presumably used to pick an initial guess for
/// `TrajectoryGenerator::optimize` — confirm at call sites.
#[derive(Debug, Clone)]
pub struct LookupTable {
    /// Entries in the order they were parsed or added.
    entries: Vec<LookupTableEntry>,
}
257
258impl LookupTable {
259 pub fn new() -> Self {
260 Self {
261 entries: Vec::new(),
262 }
263 }
264
265 pub fn from_csv(csv_data: &str) -> Self {
267 let mut entries = Vec::new();
268
269 for line in csv_data.lines() {
270 let line = line.trim();
271 if line.is_empty() || line.starts_with('#') || line.starts_with("x,") {
272 continue;
273 }
274
275 let parts: Vec<&str> = line.split(',').collect();
276 if parts.len() >= 6 {
277 if let (Ok(x), Ok(y), Ok(yaw), Ok(s), Ok(km), Ok(kf)) = (
278 parts[0].trim().parse::<f64>(),
279 parts[1].trim().parse::<f64>(),
280 parts[2].trim().parse::<f64>(),
281 parts[3].trim().parse::<f64>(),
282 parts[4].trim().parse::<f64>(),
283 parts[5].trim().parse::<f64>(),
284 ) {
285 entries.push(LookupTableEntry::new(x, y, yaw, s, km, kf));
286 }
287 }
288 }
289
290 Self { entries }
291 }
292
293 pub fn generate_default() -> Self {
295 Self::from_csv(DEFAULT_LOOKUP_TABLE_CSV)
296 }
297
298 pub fn find_nearest(&self, target: &TargetState) -> Option<&LookupTableEntry> {
300 self.entries.iter().min_by(|a, b| {
301 a.distance_to(target)
302 .partial_cmp(&b.distance_to(target))
303 .unwrap_or(std::cmp::Ordering::Equal)
304 })
305 }
306
307 pub fn add(&mut self, entry: LookupTableEntry) {
309 self.entries.push(entry);
310 }
311
312 pub fn len(&self) -> usize {
314 self.entries.len()
315 }
316
317 pub fn is_empty(&self) -> bool {
319 self.entries.is_empty()
320 }
321
322 pub fn to_csv(&self) -> String {
324 let mut csv = String::from("x,y,yaw,s,km,kf\n");
325 for entry in &self.entries {
326 csv.push_str(&format!(
327 "{},{},{},{},{},{}\n",
328 entry.x, entry.y, entry.yaw, entry.s, entry.km, entry.kf
329 ));
330 }
331 csv
332 }
333}
334
335impl Default for LookupTable {
336 fn default() -> Self {
337 Self::generate_default()
338 }
339}
340
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_trajectory_generator_straight() {
        let generator = TrajectoryGenerator::with_defaults();
        let params = Vector3::new(5.0, 0.0, 0.0);

        let (x, y, _yaw) = generator.generate(&params);

        assert!(x.len() > 1);
        let final_x = x.last().unwrap();
        let final_y = y.last().unwrap();
        assert!(*final_x > 4.0);
        assert!(final_y.abs() < 0.1);
    }

    #[test]
    fn test_trajectory_generator_turn() {
        let generator = TrajectoryGenerator::with_defaults();
        let params = Vector3::new(5.0, 0.1, 0.1);

        let (x, y, _yaw) = generator.generate(&params);

        assert!(x.len() > 1);
        let final_y = y.last().unwrap();
        assert!(*final_y > 0.0);
    }

    #[test]
    fn test_calc_diff() {
        let generator = TrajectoryGenerator::with_defaults();
        let target = Vector3::new(5.0, 0.0, 0.0);
        let params = Vector3::new(5.0, 0.0, 0.0);

        let diff = generator.calc_diff(&params, &target);

        assert!(diff.norm() < 1.0);
    }

    #[test]
    fn test_optimize_straight() {
        let generator = TrajectoryGenerator::with_defaults();
        let target = Vector3::new(10.0, 0.0, 0.0);
        let init_params = Vector3::new(10.0, 0.0, 0.0);

        let result = generator.optimize(&target, &init_params);

        assert!(result.is_some());
        let params = result.unwrap();
        assert!(params[0] > 0.0);
        // A successful solve must actually meet the configured tolerance.
        assert!(generator.calc_cost(&params, &target) < generator.config.cost_threshold);
    }

    #[test]
    fn test_optimize_turn() {
        let generator = TrajectoryGenerator::with_defaults();
        let target = Vector3::new(8.0, 3.0, 0.5);
        let init_params = Vector3::new(10.0, 0.05, 0.05);

        let result = generator.optimize(&target, &init_params);

        if let Some(params) = result {
            assert!(params[0] > 0.0);
        }
    }

    #[test]
    fn test_optimize_matches_upstream_lane_reference() {
        let generator = TrajectoryGenerator::with_defaults();
        let target = Vector3::new(10.0, 9.0, 0.0);
        let init_params = Vector3::new(13.45362404707371, 0.1482242831571022, -0.5606578442626601);

        let params = generator.optimize(&target, &init_params).unwrap();

        assert!((params[0] - 14.806296460297).abs() < 1e-9);
        assert!((params[1] - 0.148478839778).abs() < 1e-9);
        assert!((params[2] - -0.57288113757).abs() < 1e-9);
    }

    #[test]
    fn test_lookup_table_default() {
        let table = LookupTable::generate_default();
        assert!(!table.is_empty());
        assert_eq!(table.len(), 81);
    }

    #[test]
    fn test_lookup_table_default_matches_upstream_reference_rows() {
        let table = LookupTable::generate_default();
        let first = &table.entries[0];
        let last = table.entries.last().unwrap();

        assert!((first.x - 1.0).abs() < 1e-12);
        assert!((first.y - 0.0).abs() < 1e-12);
        assert!((first.yaw - 0.0).abs() < 1e-12);
        assert!((first.s - 1.0).abs() < 1e-12);
        assert!((first.km - 0.0).abs() < 1e-12);
        assert!((first.kf - 0.0).abs() < 1e-12);

        assert!((last.x - 24.960019173190652).abs() < 1e-12);
        assert!((last.y - 17.98909417109214).abs() < 1e-12);
        assert!((last.yaw - 0.011594018486178026).abs() < 1e-12);
        assert!((last.s - 33.0995680641525).abs() < 1e-12);
        assert!((last.km - 0.05634561447882407).abs() < 1e-12);
        assert!((last.kf - -0.22402297280749597).abs() < 1e-12);
    }

    #[test]
    fn test_lookup_table_find_nearest() {
        let table = LookupTable::generate_default();
        let target = Vector3::new(10.0, 0.0, 0.0);

        let nearest = table.find_nearest(&target);
        assert!(nearest.is_some());
    }

    #[test]
    fn test_lookup_table_find_nearest_picks_closest_entry() {
        let mut table = LookupTable::new();
        table.add(LookupTableEntry::new(1.0, 0.0, 0.0, 1.0, 0.0, 0.0));
        table.add(LookupTableEntry::new(10.0, 1.0, 0.0, 10.0, 0.0, 0.0));
        table.add(LookupTableEntry::new(20.0, 0.0, 0.0, 20.0, 0.0, 0.0));

        let nearest = table.find_nearest(&Vector3::new(9.0, 0.0, 0.0)).unwrap();
        assert!((nearest.x - 10.0).abs() < 1e-12);

        assert!(LookupTable::new()
            .find_nearest(&Vector3::new(0.0, 0.0, 0.0))
            .is_none());
    }

    #[test]
    fn test_lookup_table_csv() {
        let table = LookupTable::generate_default();
        let csv = table.to_csv();

        assert!(csv.contains("x,y,yaw,s,km,kf"));

        let parsed = LookupTable::from_csv(&csv);
        assert_eq!(table.len(), parsed.len());
        // `Display` for f64 round-trips exactly, so values must match bit for bit.
        for (a, b) in table.entries.iter().zip(parsed.entries.iter()) {
            assert_eq!(a.x.to_bits(), b.x.to_bits());
            assert_eq!(a.y.to_bits(), b.y.to_bits());
            assert_eq!(a.yaw.to_bits(), b.yaw.to_bits());
            assert_eq!(a.s.to_bits(), b.s.to_bits());
            assert_eq!(a.km.to_bits(), b.km.to_bits());
            assert_eq!(a.kf.to_bits(), b.kf.to_bits());
        }
    }

    #[test]
    fn test_lookup_table_from_csv_skips_non_data_lines() {
        let csv = "# comment\n\
                   x,y,yaw,s,km,kf\n\
                   \n\
                   1.0, 2.0, 0.5, 3.0, 0.1, -0.2, extra\n\
                   bad,row\n\
                   1.0,2.0,oops,3.0,0.1,0.2\n";

        let table = LookupTable::from_csv(csv);

        assert_eq!(table.len(), 1);
        let entry = &table.entries[0];
        assert!((entry.x - 1.0).abs() < 1e-12);
        assert!((entry.y - 2.0).abs() < 1e-12);
        assert!((entry.yaw - 0.5).abs() < 1e-12);
        assert!((entry.s - 3.0).abs() < 1e-12);
        assert!((entry.km - 0.1).abs() < 1e-12);
        assert!((entry.kf - -0.2).abs() < 1e-12);
    }

    #[test]
    fn test_lookup_entry_distance() {
        let entry = LookupTableEntry::new(10.0, 0.0, 0.0, 10.0, 0.0, 0.0);
        let target = Vector3::new(10.0, 0.0, 0.0);

        assert!(entry.distance_to(&target) < 0.001);
    }
}