Skip to main content

rust_robotics_planning/
eta3_spline.rs

1#![allow(clippy::excessive_precision, clippy::too_many_arguments)]
2
3//! Eta^3 spline path and trajectory planner
4//!
5//! Generates smooth paths for wheeled mobile robots using eta^3 polynomial
6//! splines. Each segment is a 7th-degree parametric curve connecting two
7//! poses (position + heading), shaped by eta parameters and curvature
8//! constraints.
9//!
10//! Reference:
11//! - \[eta^3-Splines for the Smooth Path Generation of Wheeled Mobile Robots\]
12//!   (<https://ieeexplore.ieee.org/document/4339545/>)
13
/// Planar pose: a point `(x, y)` together with a heading `theta` \[rad\].
#[derive(Debug, Clone, Copy)]
pub struct Pose2D {
    /// X coordinate of the position.
    pub x: f64,
    /// Y coordinate of the position.
    pub y: f64,
    /// Heading angle \[rad\].
    pub theta: f64,
}

impl Pose2D {
    /// Builds a pose from its position and heading.
    pub fn new(x: f64, y: f64, theta: f64) -> Self {
        Pose2D { x, y, theta }
    }
}
27
/// Shaping parameters for an eta^3 spline segment.
///
/// Six values `[eta0..eta5]` controlling the curve shape. As consumed by
/// `Eta3PathSegment::new`, the even-indexed entries (`values[0]`,
/// `values[2]`, `values[4]`) multiply terms built from the start heading,
/// while the odd-indexed entries (`values[1]`, `values[3]`, `values[5]`)
/// multiply terms built from the end heading.
#[derive(Debug, Clone, Copy)]
pub struct EtaParams {
    /// Raw parameter values, interleaved start/end as described above.
    pub values: [f64; 6],
}

impl EtaParams {
    /// Wraps an explicit set of six shaping values.
    pub fn new(values: [f64; 6]) -> Self {
        Self { values }
    }

    /// All six shaping values set to zero.
    pub fn zeros() -> Self {
        Self { values: [0.0; 6] }
    }
}
46
/// Endpoint curvature data for one segment.
///
/// Layout: `[kappa_a, kappa_dot_a, kappa_b, kappa_dot_b]`, with the `_a`
/// entries belonging to the segment start and the `_b` entries to its end.
#[derive(Debug, Clone, Copy)]
pub struct KappaParams {
    /// Raw values in the order documented on the type.
    pub values: [f64; 4],
}

impl KappaParams {
    /// Wraps an explicit set of four curvature values.
    pub fn new(values: [f64; 4]) -> Self {
        KappaParams { values }
    }

    /// Zero curvature and zero curvature derivative at both ends.
    pub fn zeros() -> Self {
        KappaParams::new([0.0, 0.0, 0.0, 0.0])
    }
}
65
66/// A single eta^3 path segment connecting two poses.
67///
68/// Internally stores 2x8 polynomial coefficients (x and y, degree 0..7).
69#[derive(Debug, Clone)]
70pub struct Eta3PathSegment {
71    pub start_pose: Pose2D,
72    pub end_pose: Pose2D,
73    /// Polynomial coefficients: `coeffs[dim][degree]` where dim 0=x, 1=y.
74    coeffs: [[f64; 8]; 2],
75    /// Precomputed total arc length of this segment.
76    pub segment_length: f64,
77}
78
79impl Eta3PathSegment {
80    /// Create a new segment from start/end poses with shaping and curvature
81    /// parameters.
82    pub fn new(start_pose: Pose2D, end_pose: Pose2D, eta: &EtaParams, kappa: &KappaParams) -> Self {
83        let e = &eta.values;
84        let k = &kappa.values;
85
86        let ca = start_pose.theta.cos();
87        let sa = start_pose.theta.sin();
88        let cb = end_pose.theta.cos();
89        let sb = end_pose.theta.sin();
90
91        let dx = end_pose.x - start_pose.x;
92        let dy = end_pose.y - start_pose.y;
93
94        let mut coeffs = [[0.0f64; 8]; 2];
95
96        // u^0 (constant)
97        coeffs[0][0] = start_pose.x;
98        coeffs[1][0] = start_pose.y;
99
100        // u^1 (linear)
101        coeffs[0][1] = e[0] * ca;
102        coeffs[1][1] = e[0] * sa;
103
104        // u^2 (quadratic)
105        coeffs[0][2] = 0.5 * e[2] * ca - 0.5 * e[0].powi(2) * k[0] * sa;
106        coeffs[1][2] = 0.5 * e[2] * sa + 0.5 * e[0].powi(2) * k[0] * ca;
107
108        // u^3 (cubic)
109        let cubic_curv = e[0].powi(3) * k[1] + 3.0 * e[0] * e[2] * k[0];
110        coeffs[0][3] = (1.0 / 6.0) * e[4] * ca - (1.0 / 6.0) * cubic_curv * sa;
111        coeffs[1][3] = (1.0 / 6.0) * e[4] * sa + (1.0 / 6.0) * cubic_curv * ca;
112
113        // u^4 (quartic)
114        {
115            let t1 = 35.0 * dx;
116            let t2 = (20.0 * e[0] + 5.0 * e[2] + (2.0 / 3.0) * e[4]) * ca;
117            let t3 = (5.0 * e[0].powi(2) * k[0]
118                + (2.0 / 3.0) * e[0].powi(3) * k[1]
119                + 2.0 * e[0] * e[2] * k[0])
120                * sa;
121            let t4 = (15.0 * e[1] - 2.5 * e[3] + (1.0 / 6.0) * e[5]) * cb;
122            let t5 = (2.5 * e[1].powi(2) * k[2]
123                - (1.0 / 6.0) * e[1].powi(3) * k[3]
124                - 0.5 * e[1] * e[3] * k[2])
125                * sb;
126            coeffs[0][4] = t1 - t2 + t3 - t4 - t5;
127
128            let t1 = 35.0 * dy;
129            let t2 = (20.0 * e[0] + 5.0 * e[2] + (2.0 / 3.0) * e[4]) * sa;
130            let t3 = (5.0 * e[0].powi(2) * k[0]
131                + (2.0 / 3.0) * e[0].powi(3) * k[1]
132                + 2.0 * e[0] * e[2] * k[0])
133                * ca;
134            let t4 = (15.0 * e[1] - 2.5 * e[3] + (1.0 / 6.0) * e[5]) * sb;
135            let t5 = (2.5 * e[1].powi(2) * k[2]
136                - (1.0 / 6.0) * e[1].powi(3) * k[3]
137                - 0.5 * e[1] * e[3] * k[2])
138                * cb;
139            coeffs[1][4] = t1 - t2 - t3 - t4 + t5;
140        }
141
142        // u^5 (quintic)
143        {
144            let t1 = -84.0 * dx;
145            let t2 = (45.0 * e[0] + 10.0 * e[2] + e[4]) * ca;
146            let t3 =
147                (10.0 * e[0].powi(2) * k[0] + e[0].powi(3) * k[1] + 3.0 * e[0] * e[2] * k[0]) * sa;
148            let t4 = (39.0 * e[1] - 7.0 * e[3] + 0.5 * e[5]) * cb;
149            let t5 =
150                (7.0 * e[1].powi(2) * k[2] - 0.5 * e[1].powi(3) * k[3] - 1.5 * e[1] * e[3] * k[2])
151                    * sb;
152            coeffs[0][5] = t1 + t2 - t3 + t4 + t5;
153
154            let t1 = -84.0 * dy;
155            let t2 = (45.0 * e[0] + 10.0 * e[2] + e[4]) * sa;
156            let t3 =
157                (10.0 * e[0].powi(2) * k[0] + e[0].powi(3) * k[1] + 3.0 * e[0] * e[2] * k[0]) * ca;
158            let t4 = (39.0 * e[1] - 7.0 * e[3] + 0.5 * e[5]) * sb;
159            let t5 =
160                -(7.0 * e[1].powi(2) * k[2] - 0.5 * e[1].powi(3) * k[3] - 1.5 * e[1] * e[3] * k[2])
161                    * cb;
162            coeffs[1][5] = t1 + t2 + t3 + t4 + t5;
163        }
164
165        // u^6 (sextic)
166        {
167            let t1 = 70.0 * dx;
168            let t2 = (36.0 * e[0] + 7.5 * e[2] + (2.0 / 3.0) * e[4]) * ca;
169            let t3 = (7.5 * e[0].powi(2) * k[0]
170                + (2.0 / 3.0) * e[0].powi(3) * k[1]
171                + 2.0 * e[0] * e[2] * k[0])
172                * sa;
173            let t4 = (34.0 * e[1] - 6.5 * e[3] + 0.5 * e[5]) * cb;
174            let t5 =
175                -(6.5 * e[1].powi(2) * k[2] - 0.5 * e[1].powi(3) * k[3] - 1.5 * e[1] * e[3] * k[2])
176                    * sb;
177            coeffs[0][6] = t1 - t2 + t3 - t4 + t5;
178
179            let t1 = 70.0 * dy;
180            let t2 = -(36.0 * e[0] + 7.5 * e[2] + (2.0 / 3.0) * e[4]) * sa;
181            let t3 = -(7.5 * e[0].powi(2) * k[0]
182                + (2.0 / 3.0) * e[0].powi(3) * k[1]
183                + 2.0 * e[0] * e[2] * k[0])
184                * ca;
185            let t4 = -(34.0 * e[1] - 6.5 * e[3] + 0.5 * e[5]) * sb;
186            let t5 =
187                (6.5 * e[1].powi(2) * k[2] - 0.5 * e[1].powi(3) * k[3] - 1.5 * e[1] * e[3] * k[2])
188                    * cb;
189            coeffs[1][6] = t1 + t2 + t3 + t4 + t5;
190        }
191
192        // u^7 (septic)
193        {
194            let t1 = -20.0 * dx;
195            let t2 = (10.0 * e[0] + 2.0 * e[2] + (1.0 / 6.0) * e[4]) * ca;
196            let t3 = -(2.0 * e[0].powi(2) * k[0]
197                + (1.0 / 6.0) * e[0].powi(3) * k[1]
198                + 0.5 * e[0] * e[2] * k[0])
199                * sa;
200            let t4 = (10.0 * e[1] - 2.0 * e[3] + (1.0 / 6.0) * e[5]) * cb;
201            let t5 = (2.0 * e[1].powi(2) * k[2]
202                - (1.0 / 6.0) * e[1].powi(3) * k[3]
203                - 0.5 * e[1] * e[3] * k[2])
204                * sb;
205            coeffs[0][7] = t1 + t2 + t3 + t4 + t5;
206
207            let t1 = -20.0 * dy;
208            let t2 = (10.0 * e[0] + 2.0 * e[2] + (1.0 / 6.0) * e[4]) * sa;
209            let t3 = (2.0 * e[0].powi(2) * k[0]
210                + (1.0 / 6.0) * e[0].powi(3) * k[1]
211                + 0.5 * e[0] * e[2] * k[0])
212                * ca;
213            let t4 = (10.0 * e[1] - 2.0 * e[3] + (1.0 / 6.0) * e[5]) * sb;
214            let t5 = -(2.0 * e[1].powi(2) * k[2]
215                - (1.0 / 6.0) * e[1].powi(3) * k[3]
216                - 0.5 * e[1] * e[3] * k[2])
217                * cb;
218            coeffs[1][7] = t1 + t2 + t3 + t4 + t5;
219        }
220
221        let segment_length =
222            gauss_legendre_integrate(|u| Self::s_dot_from_coeffs(&coeffs, u), 0.0, 1.0);
223
224        Self {
225            start_pose,
226            end_pose,
227            coeffs,
228            segment_length,
229        }
230    }
231
232    /// Evaluate the position (x, y) at parameter `u` in \[0, 1\].
233    pub fn calc_point(&self, u: f64) -> (f64, f64) {
234        let powers = [
235            1.0,
236            u,
237            u.powi(2),
238            u.powi(3),
239            u.powi(4),
240            u.powi(5),
241            u.powi(6),
242            u.powi(7),
243        ];
244        let x: f64 = self.coeffs[0].iter().zip(&powers).map(|(c, p)| c * p).sum();
245        let y: f64 = self.coeffs[1].iter().zip(&powers).map(|(c, p)| c * p).sum();
246        (x, y)
247    }
248
249    /// First derivative (dx/du, dy/du) at parameter `u`.
250    pub fn calc_first_deriv(&self, u: f64) -> (f64, f64) {
251        let dpowers = [
252            1.0,
253            2.0 * u,
254            3.0 * u.powi(2),
255            4.0 * u.powi(3),
256            5.0 * u.powi(4),
257            6.0 * u.powi(5),
258            7.0 * u.powi(6),
259        ];
260        let dx: f64 = self.coeffs[0][1..]
261            .iter()
262            .zip(&dpowers)
263            .map(|(c, p)| c * p)
264            .sum();
265        let dy: f64 = self.coeffs[1][1..]
266            .iter()
267            .zip(&dpowers)
268            .map(|(c, p)| c * p)
269            .sum();
270        (dx, dy)
271    }
272
273    /// Second derivative (d^2x/du^2, d^2y/du^2) at parameter `u`.
274    pub fn calc_second_deriv(&self, u: f64) -> (f64, f64) {
275        let ddpowers = [
276            2.0,
277            6.0 * u,
278            12.0 * u.powi(2),
279            20.0 * u.powi(3),
280            30.0 * u.powi(4),
281            42.0 * u.powi(5),
282        ];
283        let ddx: f64 = self.coeffs[0][2..]
284            .iter()
285            .zip(&ddpowers)
286            .map(|(c, p)| c * p)
287            .sum();
288        let ddy: f64 = self.coeffs[1][2..]
289            .iter()
290            .zip(&ddpowers)
291            .map(|(c, p)| c * p)
292            .sum();
293        (ddx, ddy)
294    }
295
296    /// Rate of change of arc length with respect to u: ||dr/du||.
297    /// Clamped to a minimum of 1e-6 to avoid division by zero.
298    pub fn s_dot(&self, u: f64) -> f64 {
299        Self::s_dot_from_coeffs(&self.coeffs, u)
300    }
301
302    fn s_dot_from_coeffs(coeffs: &[[f64; 8]; 2], u: f64) -> f64 {
303        let dpowers = [
304            1.0,
305            2.0 * u,
306            3.0 * u.powi(2),
307            4.0 * u.powi(3),
308            5.0 * u.powi(4),
309            6.0 * u.powi(5),
310            7.0 * u.powi(6),
311        ];
312        let dx: f64 = coeffs[0][1..]
313            .iter()
314            .zip(&dpowers)
315            .map(|(c, p)| c * p)
316            .sum();
317        let dy: f64 = coeffs[1][1..]
318            .iter()
319            .zip(&dpowers)
320            .map(|(c, p)| c * p)
321            .sum();
322        (dx * dx + dy * dy).sqrt().max(1e-6)
323    }
324
325    /// Arc length from u=0 to `u_end`.
326    pub fn arc_length(&self, u_end: f64) -> f64 {
327        gauss_legendre_integrate(|u| self.s_dot(u), 0.0, u_end)
328    }
329}
330
331/// A multi-segment eta^3 path composed of contiguous segments.
332#[derive(Debug, Clone)]
333pub struct Eta3Path {
334    pub segments: Vec<Eta3PathSegment>,
335}
336
337impl Eta3Path {
338    /// Create a path from a list of segments.
339    ///
340    /// Panics if the segment list is empty.
341    pub fn new(segments: Vec<Eta3PathSegment>) -> Self {
342        assert!(!segments.is_empty(), "At least one segment is required");
343        Self { segments }
344    }
345
346    /// Evaluate the path at a normalised parameter `u` in \[0, num_segments\].
347    ///
348    /// Integer values correspond to segment boundaries.
349    pub fn calc_path_point(&self, u: f64) -> (f64, f64) {
350        let n = self.segments.len();
351        let (seg_idx, local_u) = if (u - n as f64).abs() < 1e-12 || u >= n as f64 {
352            (n - 1, 1.0)
353        } else {
354            let idx = u.floor() as usize;
355            (idx.min(n - 1), u - idx as f64)
356        };
357        self.segments[seg_idx].calc_point(local_u)
358    }
359
360    /// Generate a sampled path as a vector of (x, y) points.
361    ///
362    /// `num_points` is the total number of samples across all segments.
363    pub fn sample(&self, num_points: usize) -> Vec<(f64, f64)> {
364        let n = self.segments.len() as f64;
365        (0..num_points)
366            .map(|i| {
367                let u = n * i as f64 / (num_points - 1) as f64;
368                self.calc_path_point(u)
369            })
370            .collect()
371    }
372
373    /// Total arc length of all segments.
374    pub fn total_length(&self) -> f64 {
375        self.segments.iter().map(|s| s.segment_length).sum()
376    }
377}
378
379// ---------------------------------------------------------------------------
380// Trajectory (velocity-profiled path)
381// ---------------------------------------------------------------------------
382
/// Configuration for the trapezoidal-with-jerk velocity profile.
#[derive(Debug, Clone, Copy)]
pub struct TrajectoryConfig {
    /// Upper bound on the linear velocity.
    pub max_vel: f64,
    /// Linear velocity at the start of the trajectory.
    pub v0: f64,
    /// Linear acceleration at the start of the trajectory.
    pub a0: f64,
    /// Acceleration magnitude used in the constant-acceleration sections.
    pub max_accel: f64,
    /// Jerk magnitude used in the acceleration ramps.
    pub max_jerk: f64,
}

impl TrajectoryConfig {
    /// Configuration that starts from rest (`v0 = 0`, `a0 = 0`).
    pub fn new(max_vel: f64, max_accel: f64, max_jerk: f64) -> Self {
        Self {
            max_vel,
            v0: 0.0,
            a0: 0.0,
            max_accel,
            max_jerk,
        }
    }
}

/// Seven-section velocity profile (jerk-limited S-curve).
///
/// Sections, in order: jerk up, constant acceleration, jerk down to cruise,
/// cruise, jerk into deceleration, constant deceleration, jerk back to rest.
#[derive(Debug, Clone)]
struct VelocityProfile {
    /// Duration of each of the 7 sections.
    times: [f64; 7],
    /// Velocity at the end of each section.
    vels: [f64; 7],
    /// Arc-length traversed in each section.
    seg_lengths: [f64; 7],
    /// Peak velocity actually used (configured limit or the achievable one).
    #[allow(dead_code)]
    max_vel: f64,
    /// Velocity at `time = 0`.
    v0: f64,
    /// Acceleration at `time = 0`.
    a0: f64,
    max_accel: f64,
    max_jerk: f64,
    /// Sum of all section durations.
    total_time: f64,
    /// Path length the profile was computed for.
    total_length: f64,
}

impl VelocityProfile {
    /// Builds the profile for a path of `total_length` under `config`.
    ///
    /// The peak velocity is the smaller of `config.max_vel` and the velocity
    /// at which the six non-cruise sections alone cover `total_length`.
    fn compute(config: &TrajectoryConfig, total_length: f64) -> Self {
        let max_jerk = config.max_jerk;
        let max_accel = config.max_accel;
        let v0 = config.v0;
        let a0 = config.a0;

        // Section 0: max jerk from a0 up to max acceleration.
        let delta_a = max_accel - a0;
        let t_s1 = delta_a / max_jerk;
        let v_s1 = v0 + a0 * t_s1 + max_jerk * t_s1.powi(2) / 2.0;
        let s_s1 = v0 * t_s1 + a0 * t_s1.powi(2) / 2.0 + max_jerk * t_s1.powi(3) / 6.0;

        // Final section (section 6): jerk from -max_accel back to zero.
        let t_sf = max_accel / max_jerk;
        let v_sf = max_jerk * t_sf.powi(2) / 2.0;
        let s_sf = max_jerk * t_sf.powi(3) / 6.0;

        // Achievable peak velocity. Adding up the lengths of sections
        // 0, 1, 2, 4, 5, 6 as computed below for a peak velocity v gives
        //   v^2 / a + v * a / j
        //     + s_s1 + s_sf - a^3 / (3 j^2) + (v_sf^2 - v_s1^2) / (2 a),
        // and setting that equal to total_length yields the quadratic solved
        // here. Using any other polynomial makes the sections overshoot or
        // undershoot the path when no cruise section exists.
        let a_coeff = 1.0 / max_accel;
        let b_coeff = max_accel / max_jerk;
        let c_coeff = s_s1 + s_sf - max_accel.powi(3) / (3.0 * max_jerk.powi(2))
            + (v_sf.powi(2) - v_s1.powi(2)) / (2.0 * max_accel)
            - total_length;

        let discriminant = b_coeff.powi(2) - 4.0 * a_coeff * c_coeff;
        let v_max_achievable = (-b_coeff + discriminant.max(0.0).sqrt()) / (2.0 * a_coeff);
        let max_vel = config.max_vel.min(v_max_achievable);

        let mut times = [0.0f64; 7];
        let mut vels = [0.0f64; 7];
        let mut seg_lengths = [0.0f64; 7];

        // Section 0
        times[0] = t_s1;
        vels[0] = v_s1;
        seg_lengths[0] = s_s1;

        // Section 1: accelerate at max_accel until section 2 can finish at max_vel.
        let dv1 = (max_vel - max_jerk * (max_accel / max_jerk).powi(2) / 2.0) - vels[0];
        times[1] = dv1 / max_accel;
        vels[1] = vels[0] + max_accel * times[1];
        seg_lengths[1] = vels[0] * times[1] + max_accel * times[1].powi(2) / 2.0;

        // Section 2: decrease acceleration to 0.
        times[2] = max_accel / max_jerk;
        vels[2] = vels[1] + max_accel * times[2] - max_jerk * times[2].powi(2) / 2.0;
        seg_lengths[2] = vels[1] * times[2] + max_accel * times[2].powi(2) / 2.0
            - max_jerk * times[2].powi(3) / 6.0;

        // Section 4: negative jerk, starting from max_vel.
        times[4] = max_accel / max_jerk;
        vels[4] = max_vel - max_jerk * times[4].powi(2) / 2.0;
        seg_lengths[4] = max_vel * times[4] - max_jerk * times[4].powi(3) / 6.0;

        // Section 5: decelerate at max rate.
        let dv5 = vels[4] - v_sf;
        times[5] = dv5 / max_accel;
        vels[5] = vels[4] - max_accel * times[5];
        seg_lengths[5] = vels[4] * times[5] - max_accel * times[5].powi(2) / 2.0;

        // Section 6: final jerk to zero velocity.
        times[6] = t_sf;
        vels[6] = vels[5] - max_jerk * t_sf.powi(2) / 2.0;
        seg_lengths[6] = s_sf;

        // Section 3: cruise. Its end velocity is max_vel whether or not the
        // section has any duration; `query` starts the section-4 deceleration
        // from `vels[3]`, so it must be set even when there is no cruise.
        vels[3] = max_vel;
        let used: f64 = seg_lengths.iter().sum();
        if used < total_length {
            seg_lengths[3] = total_length - used;
            times[3] = seg_lengths[3] / max_vel;
        }

        let total_time: f64 = times.iter().sum();

        Self {
            times,
            vels,
            seg_lengths,
            max_vel,
            v0,
            a0,
            max_accel,
            max_jerk,
            total_time,
            total_length,
        }
    }

    /// Compute (linear_velocity, arc_length, linear_acceleration) at a given time.
    ///
    /// At or after `total_time` the result is `(0, total_length, 0)`.
    fn query(&self, time: f64) -> (f64, f64, f64) {
        // Time / distance accumulated before section `n` begins.
        let cum_time = |n: usize| -> f64 { self.times[..n].iter().sum() };
        let cum_len = |n: usize| -> f64 { self.seg_lengths[..n].iter().sum() };
        let a_max = self.max_accel;
        let j_max = self.max_jerk;

        if time <= self.times[0] {
            // Section 0 starts from (v0, a0); the a0 terms keep this branch
            // consistent with the section-0 end state stored by `compute`.
            let v = self.v0 + self.a0 * time + j_max * time.powi(2) / 2.0;
            let s = self.v0 * time + self.a0 * time.powi(2) / 2.0 + j_max * time.powi(3) / 6.0;
            let a = self.a0 + j_max * time;
            (v, s, a)
        } else if time <= cum_time(2) {
            let dt = time - cum_time(1);
            let v = self.vels[0] + a_max * dt;
            let s = cum_len(1) + self.vels[0] * dt + a_max * dt.powi(2) / 2.0;
            (v, s, a_max)
        } else if time <= cum_time(3) {
            let dt = time - cum_time(2);
            let v = self.vels[1] + a_max * dt - j_max * dt.powi(2) / 2.0;
            let s = cum_len(2) + self.vels[1] * dt + a_max * dt.powi(2) / 2.0
                - j_max * dt.powi(3) / 6.0;
            let a = a_max - j_max * dt;
            (v, s, a)
        } else if time <= cum_time(4) {
            let dt = time - cum_time(3);
            let v = self.vels[3];
            let s = cum_len(3) + self.vels[3] * dt;
            (v, s, 0.0)
        } else if time <= cum_time(5) {
            let dt = time - cum_time(4);
            let v = self.vels[3] - j_max * dt.powi(2) / 2.0;
            let s = cum_len(4) + self.vels[3] * dt - j_max * dt.powi(3) / 6.0;
            let a = -j_max * dt;
            (v, s, a)
        } else if time <= cum_time(6) {
            let dt = time - cum_time(5);
            let v = self.vels[4] - a_max * dt;
            let s = cum_len(5) + self.vels[4] * dt - a_max * dt.powi(2) / 2.0;
            (v, s, -a_max)
        } else if time < self.total_time {
            let dt = time - cum_time(6);
            let v = self.vels[5] - a_max * dt + j_max * dt.powi(2) / 2.0;
            let s = cum_len(6) + self.vels[5] * dt - a_max * dt.powi(2) / 2.0
                + j_max * dt.powi(3) / 6.0;
            let a = -a_max + j_max * dt;
            (v, s, a)
        } else {
            (0.0, self.total_length, 0.0)
        }
    }
}
569
/// Snapshot of the trajectory at one instant.
#[derive(Debug, Clone, Copy)]
pub struct TrajectoryState {
    /// X coordinate on the path.
    pub x: f64,
    /// Y coordinate on the path.
    pub y: f64,
    /// Heading, taken from the direction of the path tangent.
    pub theta: f64,
    /// Speed along the path from the velocity profile.
    pub linear_velocity: f64,
    /// Rate of change of the heading.
    pub angular_velocity: f64,
}
579
580/// Eta^3 spline trajectory: a path with a jerk-limited velocity profile.
581#[derive(Debug, Clone)]
582pub struct Eta3Trajectory {
583    path: Eta3Path,
584    profile: VelocityProfile,
585    /// Cumulative arc lengths at segment boundaries (starts with 0).
586    cum_lengths: Vec<f64>,
587}
588
589impl Eta3Trajectory {
590    /// Build a trajectory from path segments and kinematic constraints.
591    pub fn new(segments: Vec<Eta3PathSegment>, config: TrajectoryConfig) -> Self {
592        let path = Eta3Path::new(segments);
593        let total_length = path.total_length();
594        let profile = VelocityProfile::compute(&config, total_length);
595
596        let mut cum_lengths = Vec::with_capacity(path.segments.len() + 1);
597        cum_lengths.push(0.0);
598        let mut acc = 0.0;
599        for seg in &path.segments {
600            acc += seg.segment_length;
601            cum_lengths.push(acc);
602        }
603
604        Self {
605            path,
606            profile,
607            cum_lengths,
608        }
609    }
610
611    /// Total trajectory time \[s\].
612    pub fn total_time(&self) -> f64 {
613        self.profile.total_time
614    }
615
616    /// Evaluate the trajectory state at a given time.
617    pub fn calc_traj_point(&self, time: f64) -> TrajectoryState {
618        let (linear_velocity, s, linear_accel) = self.profile.query(time);
619
620        // Find which path segment contains arc-length s
621        let n = self.path.segments.len();
622        let mut seg_id = 0;
623        for i in (0..n).rev() {
624            if s >= self.cum_lengths[i] {
625                seg_id = i;
626                break;
627            }
628        }
629        if seg_id >= n {
630            seg_id = n - 1;
631        }
632
633        let ui = if seg_id == n - 1 && (s - self.cum_lengths[n]).abs() < 1e-9 {
634            1.0
635        } else {
636            let local_s = s - self.cum_lengths[seg_id];
637            self.get_interp_param(seg_id, local_s)
638        };
639
640        let seg = &self.path.segments[seg_id];
641        let pos = seg.calc_point(ui);
642        let d = seg.calc_first_deriv(ui);
643        let dd = seg.calc_second_deriv(ui);
644        let su = seg.s_dot(ui);
645
646        let angular_velocity = if su.abs() > 1e-6 && linear_velocity.abs() > 1e-6 {
647            let ut = linear_velocity / su;
648            let utt = linear_accel / su - (d.0 * dd.0 + d.1 * dd.1) / su.powi(2) * ut;
649            let xt = d.0 * ut;
650            let yt = d.1 * ut;
651            let xtt = dd.0 * ut.powi(2) + d.0 * utt;
652            let ytt = dd.1 * ut.powi(2) + d.1 * utt;
653            (ytt * xt - xtt * yt) / linear_velocity.powi(2)
654        } else {
655            0.0
656        };
657
658        TrajectoryState {
659            x: pos.0,
660            y: pos.1,
661            theta: d.1.atan2(d.0),
662            linear_velocity,
663            angular_velocity,
664        }
665    }
666
667    /// Newton's method to find `u` such that arc_length(u) == target_s.
668    fn get_interp_param(&self, seg_id: usize, target_s: f64) -> f64 {
669        let seg = &self.path.segments[seg_id];
670        // Initial guess proportional to target fraction
671        let mut ui = if seg.segment_length > 1e-9 {
672            (target_s / seg.segment_length).clamp(0.0, 1.0)
673        } else {
674            0.0
675        };
676        let tol = 1e-3;
677        for _ in 0..50 {
678            let f = seg.arc_length(ui) - target_s;
679            if f.abs() < tol {
680                break;
681            }
682            let fp = seg.s_dot(ui);
683            ui -= f / fp;
684            ui = ui.clamp(0.0, 1.0);
685        }
686        ui
687    }
688
689    /// Sample the full trajectory at uniform time intervals.
690    ///
691    /// Returns a vector of `TrajectoryState` at `num_points` evenly spaced times
692    /// from 0 to `total_time`.
693    pub fn sample(&self, num_points: usize) -> Vec<TrajectoryState> {
694        let dt = self.total_time() / (num_points - 1).max(1) as f64;
695        (0..num_points)
696            .map(|i| self.calc_traj_point(dt * i as f64))
697            .collect()
698    }
699}
700
701// ---------------------------------------------------------------------------
// Numerical integration (5-point Gauss-Legendre on a fixed uniform subdivision)
703// ---------------------------------------------------------------------------
704
/// Applies the 5-point Gauss-Legendre rule to `f` over the interval \[a, b\].
fn gauss_legendre_5(f: impl Fn(f64) -> f64, a: f64, b: f64) -> f64 {
    // (node, weight) pairs of the rule on the reference interval [-1, 1].
    const RULE: [(f64, f64); 5] = [
        (-0.906_179_845_938_664, 0.236_926_885_056_189_1),
        (-0.538_469_310_105_683, 0.478_628_670_499_366_5),
        (0.0, 0.568_888_888_888_889),
        (0.538_469_310_105_683, 0.478_628_670_499_366_5),
        (0.906_179_845_938_664, 0.236_926_885_056_189_1),
    ];

    // Affine map from [-1, 1] onto [a, b]: x = half_width * node + centre.
    let half_width = (b - a) / 2.0;
    let centre = (a + b) / 2.0;

    let weighted_sum = RULE.iter().fold(0.0, |acc, &(node, weight)| {
        acc + weight * f(half_width * node + centre)
    });
    weighted_sum * half_width
}
731
732/// Adaptive Gauss-Legendre integration with subdivision.
733fn gauss_legendre_integrate(f: impl Fn(f64) -> f64, a: f64, b: f64) -> f64 {
734    // Use 16 sub-intervals for good accuracy on polynomial-like integrands
735    let n = 16;
736    let h = (b - a) / n as f64;
737    let mut total = 0.0;
738    for i in 0..n {
739        let lo = a + i as f64 * h;
740        let hi = lo + h;
741        total += gauss_legendre_5(&f, lo, hi);
742    }
743    total
744}
745
746#[cfg(test)]
747mod tests {
748    use super::*;
749    use std::f64::consts::PI;
750
751    fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
752        (a - b).abs() < tol
753    }
754
755    // -----------------------------------------------------------------------
756    // Path segment tests
757    // -----------------------------------------------------------------------
758
759    #[test]
760    fn test_segment_endpoints() {
761        let start = Pose2D::new(0.0, 0.0, 0.0);
762        let end = Pose2D::new(4.0, 3.0, 0.0);
763        let eta = EtaParams::new([4.0, 4.0, 0.0, 0.0, 0.0, 0.0]);
764        let kappa = KappaParams::zeros();
765        let seg = Eta3PathSegment::new(start, end, &eta, &kappa);
766
767        let p0 = seg.calc_point(0.0);
768        assert!(approx_eq(p0.0, 0.0, 1e-12));
769        assert!(approx_eq(p0.1, 0.0, 1e-12));
770
771        let p1 = seg.calc_point(1.0);
772        assert!(approx_eq(p1.0, 4.0, 1e-10));
773        assert!(approx_eq(p1.1, 3.0, 1e-10));
774    }
775
776    #[test]
777    fn test_segment_with_heading() {
778        let start = Pose2D::new(0.0, 0.0, PI / 4.0);
779        let end = Pose2D::new(5.0, 5.0, PI / 4.0);
780        let eta = EtaParams::new([5.0, 5.0, 0.0, 0.0, 0.0, 0.0]);
781        let kappa = KappaParams::zeros();
782        let seg = Eta3PathSegment::new(start, end, &eta, &kappa);
783
784        let p0 = seg.calc_point(0.0);
785        assert!(approx_eq(p0.0, 0.0, 1e-12));
786        assert!(approx_eq(p0.1, 0.0, 1e-12));
787
788        let p1 = seg.calc_point(1.0);
789        assert!(approx_eq(p1.0, 5.0, 1e-10));
790        assert!(approx_eq(p1.1, 5.0, 1e-10));
791    }
792
793    #[test]
794    fn test_segment_arc_length_positive() {
795        let start = Pose2D::new(0.0, 0.0, 0.0);
796        let end = Pose2D::new(4.0, 3.0, 0.0);
797        let eta = EtaParams::new([4.27, 4.27, 0.0, 0.0, 0.0, 0.0]);
798        let kappa = KappaParams::zeros();
799        let seg = Eta3PathSegment::new(start, end, &eta, &kappa);
800
801        assert!(seg.segment_length > 0.0);
802        // Arc length should be at least the straight-line distance
803        let straight = ((4.0f64).powi(2) + (3.0f64).powi(2)).sqrt();
804        assert!(seg.segment_length >= straight - 1e-6);
805    }
806
807    #[test]
808    fn test_segment_derivatives_at_start() {
809        let start = Pose2D::new(1.0, 2.0, 0.5);
810        let end = Pose2D::new(5.0, 4.0, 0.0);
811        let eta = EtaParams::new([3.0, 3.0, 0.0, 0.0, 0.0, 0.0]);
812        let kappa = KappaParams::zeros();
813        let seg = Eta3PathSegment::new(start, end, &eta, &kappa);
814
815        // At u=0, the first derivative direction should match the start heading
816        let d = seg.calc_first_deriv(0.0);
817        let heading = d.1.atan2(d.0);
818        assert!(approx_eq(heading, 0.5, 1e-10));
819    }
820
821    // -----------------------------------------------------------------------
822    // Multi-segment path tests
823    // -----------------------------------------------------------------------
824
825    #[test]
826    fn test_path_continuity() {
827        let segments = vec![
828            Eta3PathSegment::new(
829                Pose2D::new(0.0, 0.0, 0.0),
830                Pose2D::new(4.0, 1.5, 0.0),
831                &EtaParams::new([4.27, 4.27, 0.0, 0.0, 0.0, 0.0]),
832                &KappaParams::zeros(),
833            ),
834            Eta3PathSegment::new(
835                Pose2D::new(4.0, 1.5, 0.0),
836                Pose2D::new(5.5, 1.5, 0.0),
837                &EtaParams::zeros(),
838                &KappaParams::zeros(),
839            ),
840        ];
841        let path = Eta3Path::new(segments);
842
843        // At the junction u=1.0, path should be continuous
844        let p_end_seg0 = path.calc_path_point(1.0 - 1e-12);
845        let p_start_seg1 = path.calc_path_point(1.0);
846
847        assert!(approx_eq(p_end_seg0.0, p_start_seg1.0, 1e-4));
848        assert!(approx_eq(p_end_seg0.1, p_start_seg1.1, 1e-4));
849    }
850
851    #[test]
852    fn test_path_sample() {
853        let seg = Eta3PathSegment::new(
854            Pose2D::new(0.0, 0.0, 0.0),
855            Pose2D::new(4.0, 3.0, 0.0),
856            &EtaParams::new([4.0, 4.0, 0.0, 0.0, 0.0, 0.0]),
857            &KappaParams::zeros(),
858        );
859        let path = Eta3Path::new(vec![seg]);
860        let pts = path.sample(101);
861        assert_eq!(pts.len(), 101);
862
863        // First and last should match endpoints
864        assert!(approx_eq(pts[0].0, 0.0, 1e-12));
865        assert!(approx_eq(pts[0].1, 0.0, 1e-12));
866        assert!(approx_eq(pts[100].0, 4.0, 1e-10));
867        assert!(approx_eq(pts[100].1, 3.0, 1e-10));
868    }
869
#[test]
fn test_varying_eta_produces_different_paths() {
    let start = Pose2D::new(0.0, 0.0, 0.0);
    let end = Pose2D::new(4.0, 3.0, 0.0);
    let kappa = KappaParams::zeros();

    // Same endpoints and curvature; only the eta shaping values vary.
    let build = |eta: [f64; 6]| Eta3PathSegment::new(start, end, &EtaParams::new(eta), &kappa);
    let baseline = build([2.0, 2.0, 0.0, 0.0, 0.0, 0.0]);
    let reshaped = build([2.0, 2.0, 5.0, 5.0, 0.0, 0.0]);

    // Compare the two curves at the middle of the parameter range (u = 0.5):
    // changing eta[2] and eta[3] must move that point.
    let mid_a = baseline.calc_point(0.5);
    let mid_b = reshaped.calc_point(0.5);
    let dx = mid_a.0 - mid_b.0;
    let dy = mid_a.1 - mid_b.1;
    let dist = (dx.powi(2) + dy.powi(2)).sqrt();

    assert!(
        dist > 1e-6,
        "Different eta should produce different paths, dist={dist}"
    );
}
898
#[test]
fn test_curvature_params_effect() {
    let start = Pose2D::new(5.5, 1.5, 0.0);
    let end = Pose2D::new(7.4377, 1.8235, 0.6667);
    let eta = EtaParams::new([1.88, 1.88, 0.0, 0.0, 0.0, 0.0]);

    // Identical endpoints and eta values; only the end-side curvature
    // entries (kappa_b, kappa_dot_b) differ between the two segments.
    let flat_kappa = KappaParams::zeros();
    let curved_kappa = KappaParams::new([0.0, 0.0, 1.0, 1.0]);
    let seg_flat = Eta3PathSegment::new(start, end, &eta, &flat_kappa);
    let seg_curved = Eta3PathSegment::new(start, end, &eta, &curved_kappa);

    // Euclidean distance between the two curves at u = 0.5.
    let mid_flat = seg_flat.calc_point(0.5);
    let mid_curved = seg_curved.calc_point(0.5);
    let dx = mid_flat.0 - mid_curved.0;
    let dy = mid_flat.1 - mid_curved.1;
    let dist = (dx.powi(2) + dy.powi(2)).sqrt();

    assert!(dist > 1e-6, "Curvature params should affect the path shape");
}
914
915    // -----------------------------------------------------------------------
916    // Reference test (Table 1 from the paper, matches Python test3)
917    // -----------------------------------------------------------------------
918
#[test]
fn test_reference_multi_segment_path() {
    // Waypoints of the five-segment reference path; each interior pose is the
    // end of one segment and the start of the next.
    let origin = Pose2D::new(0.0, 0.0, 0.0);
    let after_lane_change = Pose2D::new(4.0, 1.5, 0.0);
    let after_line = Pose2D::new(5.5, 1.5, 0.0);
    let after_spiral = Pose2D::new(7.4377, 1.8235, 0.6667);
    let after_twirl = Pose2D::new(7.8, 4.3, 1.8);
    let goal = Pose2D::new(5.4581, 5.8064, 3.3416);

    let lane_change = Eta3PathSegment::new(
        origin,
        after_lane_change,
        &EtaParams::new([4.27, 4.27, 0.0, 0.0, 0.0, 0.0]),
        &KappaParams::zeros(),
    );
    let line = Eta3PathSegment::new(
        after_lane_change,
        after_line,
        &EtaParams::zeros(),
        &KappaParams::zeros(),
    );
    let cubic_spiral = Eta3PathSegment::new(
        after_line,
        after_spiral,
        &EtaParams::new([1.88, 1.88, 0.0, 0.0, 0.0, 0.0]),
        &KappaParams::new([0.0, 0.0, 1.0, 1.0]),
    );
    let twirl_arc = Eta3PathSegment::new(
        after_spiral,
        after_twirl,
        &EtaParams::new([7.0, 10.0, 10.0, -10.0, 4.0, 4.0]),
        &KappaParams::new([1.0, 1.0, 0.5, 0.0]),
    );
    let circular_arc = Eta3PathSegment::new(
        after_twirl,
        goal,
        &EtaParams::new([2.98, 2.98, 0.0, 0.0, 0.0, 0.0]),
        &KappaParams::new([0.5, 0.0, 0.5, 0.0]),
    );

    let path = Eta3Path::new(vec![
        lane_change,
        line,
        cubic_spiral,
        twirl_arc,
        circular_arc,
    ]);
    let pts = path.sample(1001);

    // The first sample must sit on the origin, the last one on the goal.
    let first = &pts[0];
    assert!(approx_eq(first.0, 0.0, 1e-10));
    assert!(approx_eq(first.1, 0.0, 1e-10));
    let last = &pts[1000];
    assert!(approx_eq(last.0, 5.4581, 1e-4));
    assert!(approx_eq(last.1, 5.8064, 1e-4));

    // Loose sanity bound on the overall path length: strictly between 10 and
    // 40 units.
    let length = path.total_length();
    assert!(length > 10.0 && length < 40.0);
}
972
973    // -----------------------------------------------------------------------
974    // Trajectory tests
975    // -----------------------------------------------------------------------
976
#[test]
fn test_trajectory_basic() {
    let start = Pose2D::new(0.0, 0.0, 0.0);
    let goal = Pose2D::new(4.0, 3.0, 0.0);
    let eta = EtaParams::new([4.27, 4.27, 0.0, 0.0, 0.0, 0.0]);
    let segment = Eta3PathSegment::new(start, goal, &eta, &KappaParams::zeros());

    let traj = Eta3Trajectory::new(vec![segment], TrajectoryConfig::new(0.5, 0.5, 5.0));

    // A non-degenerate trajectory needs a strictly positive duration.
    let duration = traj.total_time();
    assert!(duration > 0.0);

    // Initial state: on the start position and at rest.
    let initial = traj.calc_traj_point(0.0);
    assert!(approx_eq(initial.x, 0.0, 1e-6));
    assert!(approx_eq(initial.y, 0.0, 1e-6));
    assert!(approx_eq(initial.linear_velocity, 0.0, 1e-6));

    // Final state: near the goal (position tolerance 0.1) and at rest again.
    let terminal = traj.calc_traj_point(duration);
    assert!(approx_eq(terminal.x, 4.0, 0.1));
    assert!(approx_eq(terminal.y, 3.0, 0.1));
    assert!(approx_eq(terminal.linear_velocity, 0.0, 1e-6));
}
1003
#[test]
fn test_trajectory_sample() {
    let start = Pose2D::new(0.0, 0.0, 0.0);
    let goal = Pose2D::new(4.0, 3.0, 0.0);
    let eta = EtaParams::new([4.0, 4.0, 0.0, 0.0, 0.0, 0.0]);
    let segment = Eta3PathSegment::new(start, goal, &eta, &KappaParams::zeros());

    let traj = Eta3Trajectory::new(vec![segment], TrajectoryConfig::new(1.0, 0.5, 5.0));

    // Requesting 101 samples must yield exactly 101 states.
    let states = traj.sample(101);
    assert_eq!(states.len(), 101);

    // Every sampled speed has to stay inside [0, 1.0], up to a small
    // numerical tolerance. The 1.0 presumably mirrors the velocity limit
    // passed to the config above.
    for state in states.iter() {
        let v = state.linear_velocity;
        assert!(v <= 1.0 + 1e-6);
        assert!(v >= -1e-6);
    }
}
1025
#[test]
fn test_trajectory_monotonic_position() {
    // Start and goal both lie on the x-axis with heading 0, so the sampled x
    // coordinate is expected to never decrease along the trajectory.
    let start = Pose2D::new(0.0, 0.0, 0.0);
    let goal = Pose2D::new(10.0, 0.0, 0.0);
    let eta = EtaParams::new([5.0, 5.0, 0.0, 0.0, 0.0, 0.0]);
    let segment = Eta3PathSegment::new(start, goal, &eta, &KappaParams::zeros());

    let traj = Eta3Trajectory::new(vec![segment], TrajectoryConfig::new(2.0, 1.0, 5.0));
    let states = traj.sample(200);

    // Walk over consecutive pairs; `prev_idx` is the index of the earlier
    // state in each pair. A 1e-6 slack absorbs floating-point noise.
    for (prev_idx, pair) in states.windows(2).enumerate() {
        let prev = &pair[0];
        let curr = &pair[1];
        assert!(
            curr.x >= prev.x - 1e-6,
            "x should increase monotonically for a straight segment: x[{}]={} < x[{}]={}",
            prev_idx + 1,
            curr.x,
            prev_idx,
            prev.x,
        );
    }
}
1051
// -----------------------------------------------------------------------
// Numerical integration (Gauss-Legendre quadrature) tests
// -----------------------------------------------------------------------
1055
#[test]
fn test_gauss_legendre_accuracy() {
    // Polynomial integrand: the integral of x^2 over [0, 1] is exactly 1/3.
    let square_integral = gauss_legendre_integrate(|x| x * x, 0.0, 1.0);
    let square_exact = 1.0 / 3.0;
    assert!(approx_eq(square_integral, square_exact, 1e-12));

    // Transcendental integrand: the integral of sin(x) over [0, pi] is
    // exactly 2.
    let sine_integral = gauss_legendre_integrate(|x| x.sin(), 0.0, PI);
    let sine_exact = 2.0;
    assert!(approx_eq(sine_integral, sine_exact, 1e-10));
}
1066}