Skip to main content

rust_robotics_planning/
dynamic_movement_primitives.rs

//! Dynamic Movement Primitives (DMP)
//!
//! Learns a trajectory from demonstration data by modelling the forcing
//! function as a weighted sum of Gaussian basis functions. The learned
//! weights can then be used to recreate the trajectory with different
//! start/goal positions or durations.
//!
//! References:
//! - <https://arxiv.org/abs/2102.03861>
//! - <https://www.frontiersin.org/articles/10.3389/fncom.2013.00138/full>
11
/// Configuration for the DMP learner.
///
/// `spring` (K) and `damper` (B) are the gains of the spring-damper
/// transformation system. During reproduction the spring term is divided by
/// `period²` and the damper term by `period`, so both gains are expressed in
/// normalised (phase) time rather than wall-clock time.
#[derive(Debug, Clone, PartialEq)]
pub struct DmpConfig {
    /// Virtual spring constant (K).
    pub spring: f64,
    /// Virtual damper coefficient (B).
    pub damper: f64,
    /// Number of Gaussian basis functions used to model the forcing term.
    pub num_basis: usize,
}
22
23impl Default for DmpConfig {
24    fn default() -> Self {
25        Self {
26            spring: 156.25,
27            damper: 25.0,
28            num_basis: 10,
29        }
30    }
31}
32
/// A learned DMP model that can reproduce trajectories.
///
/// Produced by [`Dmp::learn`]; the demonstration's sampling (number of steps
/// and step size) is retained and reused when recreating trajectories.
#[derive(Debug, Clone, PartialEq)]
pub struct Dmp {
    /// Weight matrix indexed as `weights[dim][basis_idx]`.
    weights: Vec<Vec<f64>>,
    /// Number of spatial dimensions.
    dimensions: usize,
    /// Number of time-steps in the original demonstration.
    timesteps: usize,
    /// Time-step size derived from the training data
    /// (`data_period / timesteps`).
    dt: f64,
    /// Spring constant (K).
    spring: f64,
    /// Damper coefficient (B).
    damper: f64,
    /// Number of basis functions.
    num_basis: usize,
}
51
/// Result of trajectory recreation.
///
/// [`Dmp::recreate`] pushes one entry into each vector per simulation step,
/// so `time` and `positions` have the same length.
#[derive(Debug, Clone, PartialEq)]
pub struct DmpTrajectory {
    /// Time value for each step.
    pub time: Vec<f64>,
    /// Position at each step, indexed as `positions[step][dim]`.
    pub positions: Vec<Vec<f64>>,
}
60
61impl Dmp {
62    /// Learn DMP weights from demonstration data.
63    ///
64    /// `training_data` is a slice of points where each inner slice is a
65    /// spatial position (e.g. `&[x, y]`).  `data_period` is the total time
66    /// the demonstration covers.
67    pub fn learn(training_data: &[Vec<f64>], data_period: f64, config: &DmpConfig) -> Self {
68        let timesteps = training_data.len();
69        assert!(timesteps >= 2, "need at least 2 data points");
70        let dimensions = training_data[0].len();
71        assert!(dimensions > 0, "need at least 1 dimension");
72
73        let dt = data_period / timesteps as f64;
74        let num_basis = config.num_basis;
75
76        // Centres and (shared) variance of Gaussian basis functions.
77        let centres: Vec<f64> = (0..num_basis)
78            .map(|i| i as f64 / (num_basis - 1).max(1) as f64)
79            .collect();
80        let h = 0.65 / ((1.0 / (num_basis as f64 - 1.0)).powi(2).max(1e-12));
81
82        let init_state = &training_data[0];
83        let goal_state = &training_data[timesteps - 1];
84
85        let mut all_weights: Vec<Vec<f64>> = Vec::with_capacity(dimensions);
86
87        for dim in 0..dimensions {
88            let q0 = init_state[dim];
89            let g = goal_state[dim];
90            let g_minus_q0 = g - q0;
91
92            let mut q = q0;
93            let mut qd_last = 0.0;
94
95            let mut phi_matrix: Vec<Vec<f64>> = Vec::with_capacity(timesteps);
96            let mut f_vals: Vec<f64> = Vec::with_capacity(timesteps);
97
98            for i in 0..timesteps {
99                let qd = if i + 1 < timesteps {
100                    (training_data[i + 1][dim] - training_data[i][dim]) / dt
101                } else {
102                    0.0
103                };
104
105                // Normalised basis function values.
106                let phase = i as f64 * dt / data_period;
107                let mut phi: Vec<f64> = centres
108                    .iter()
109                    .map(|&c| (-0.5 * (phase - c).powi(2) * h).exp())
110                    .collect();
111                let phi_sum: f64 = phi.iter().sum::<f64>().max(1e-12);
112                for v in &mut phi {
113                    *v /= phi_sum;
114                }
115
116                let qdd = (qd - qd_last) / dt;
117
118                let f = if g_minus_q0.abs() < 1e-12 {
119                    0.0
120                } else {
121                    (qdd * data_period.powi(2) - config.spring * (g - q)
122                        + config.damper * qd * data_period)
123                        / g_minus_q0
124                };
125
126                phi_matrix.push(phi);
127                f_vals.push(f);
128
129                qd_last = qd;
130                q += qd * dt;
131            }
132
133            // Solve least-squares:  phi_matrix * w = f_vals.
134            let w = lstsq(&phi_matrix, &f_vals, num_basis);
135            all_weights.push(w);
136        }
137
138        Self {
139            weights: all_weights,
140            dimensions,
141            timesteps,
142            dt,
143            spring: config.spring,
144            damper: config.damper,
145            num_basis,
146        }
147    }
148
149    /// Recreate a trajectory from the learned weights.
150    ///
151    /// * `init_state` — desired start position (one value per dimension).
152    /// * `goal_state` — desired goal position.
153    /// * `period`     — total time for the new trajectory.
154    pub fn recreate(&self, init_state: &[f64], goal_state: &[f64], period: f64) -> DmpTrajectory {
155        assert_eq!(init_state.len(), self.dimensions);
156        assert_eq!(goal_state.len(), self.dimensions);
157
158        let centres: Vec<f64> = (0..self.num_basis)
159            .map(|i| i as f64 / (self.num_basis - 1).max(1) as f64)
160            .collect();
161        let h = 0.65 / ((1.0 / (self.num_basis as f64 - 1.0)).powi(2).max(1e-12));
162
163        let mut q: Vec<f64> = init_state.to_vec();
164        let mut qd = vec![0.0; self.dimensions];
165
166        let mut time_vec = Vec::with_capacity(self.timesteps);
167        let mut positions = Vec::with_capacity(self.timesteps);
168        let mut time = 0.0;
169
170        for _ in 0..self.timesteps {
171            time += self.dt;
172
173            let mut qdd = vec![0.0; self.dimensions];
174
175            for dim in 0..self.dimensions {
176                let f = if time <= period {
177                    let phase = time / period;
178                    let mut phi: Vec<f64> = centres
179                        .iter()
180                        .map(|&c| (-0.5 * (phase - c).powi(2) * h).exp())
181                        .collect();
182                    let phi_sum: f64 = phi.iter().sum::<f64>().max(1e-12);
183                    for v in &mut phi {
184                        *v /= phi_sum;
185                    }
186                    phi.iter()
187                        .zip(self.weights[dim].iter())
188                        .map(|(p, w)| p * w)
189                        .sum::<f64>()
190                } else {
191                    0.0
192                };
193
194                qdd[dim] = self.spring * (goal_state[dim] - q[dim]) / period.powi(2)
195                    - self.damper * qd[dim] / period
196                    + (goal_state[dim] - init_state[dim]) * f / period.powi(2);
197            }
198
199            for dim in 0..self.dimensions {
200                qd[dim] += qdd[dim] * self.dt;
201                q[dim] += qd[dim] * self.dt;
202            }
203
204            time_vec.push(time);
205            positions.push(q.clone());
206        }
207
208        DmpTrajectory {
209            time: time_vec,
210            positions,
211        }
212    }
213
214    /// Number of spatial dimensions.
215    pub fn dimensions(&self) -> usize {
216        self.dimensions
217    }
218
219    /// Number of basis functions.
220    pub fn num_basis(&self) -> usize {
221        self.num_basis
222    }
223
224    /// Reference to the learned weight matrix.
225    pub fn weights(&self) -> &[Vec<f64>] {
226        &self.weights
227    }
228}
229
230// ------------------------------------------------------------------
231// Minimal least-squares solver (normal equations).
232// ------------------------------------------------------------------
233
234/// Solve `A * x = b` in the least-squares sense via normal equations.
235/// `A` is `(m x n)`, `b` is `(m,)`, returns `x` of length `n`.
236fn lstsq(a: &[Vec<f64>], b: &[f64], n: usize) -> Vec<f64> {
237    let m = a.len();
238    assert_eq!(b.len(), m);
239
240    // A^T A  (n x n)
241    let mut ata = vec![vec![0.0; n]; n];
242    // A^T b  (n,)
243    let mut atb = vec![0.0; n];
244
245    for row in 0..m {
246        for j in 0..n {
247            atb[j] += a[row][j] * b[row];
248            for k in j..n {
249                let v = a[row][j] * a[row][k];
250                ata[j][k] += v;
251                if k != j {
252                    ata[k][j] += v;
253                }
254            }
255        }
256    }
257
258    // Add small regularisation for numerical stability.
259    for (i, ata_row) in ata.iter_mut().enumerate() {
260        ata_row[i] += 1e-10;
261    }
262
263    // Solve via Cholesky-like Gaussian elimination (symmetric positive-definite).
264    solve_symmetric(&mut ata, &mut atb)
265}
266
/// Solve the `n x n` linear system `a * x = b` (with `n = b.len()`) by
/// Gaussian elimination without row pivoting.
///
/// Intended for symmetric positive-definite matrices such as the ridged
/// Gram matrix built by `lstsq`; for those, no row exchanges are needed.
/// Both `a` and `b` are overwritten with the reduced upper-triangular system,
/// and the solution is returned as a new vector.
fn solve_symmetric(a: &mut [Vec<f64>], b: &mut [f64]) -> Vec<f64> {
    let n = b.len();

    // Forward elimination: clear the entries below the diagonal column by column.
    for col in 0..n {
        let (done, rest) = a.split_at_mut(col + 1);
        let pivot_row = &done[col][..n];
        let pivot = pivot_row[col];
        let (b_done, b_rest) = b.split_at_mut(col + 1);
        let b_pivot = b_done[col];

        for (row, b_row) in rest.iter_mut().zip(b_rest.iter_mut()) {
            let factor = row[col] / pivot;
            for (dst, &src) in row[col..n].iter_mut().zip(&pivot_row[col..]) {
                *dst -= factor * src;
            }
            *b_row -= factor * b_pivot;
        }
    }

    // Back substitution, from the last unknown upwards.
    let mut x = vec![0.0; n];
    for i in (0..n).rev() {
        let row = &a[i];
        let residual = row[i + 1..n]
            .iter()
            .zip(&x[i + 1..])
            .fold(b[i], |acc, (&coef, &xj)| acc - coef * xj);
        x[i] = residual / row[i];
    }
    x
}
293
#[cfg(test)]
mod tests {
    use super::*;

    /// Generate a simple sine-wave demonstration in 2-D: `x = t`, `y = sin(t)`,
    /// 200 samples starting at `t = 0` with step `2π / 200`.
    fn sine_demo() -> (Vec<Vec<f64>>, f64) {
        let n = 200;
        let period = 2.0 * std::f64::consts::PI;
        let dt = period / n as f64;
        let data = (0..n)
            .map(|i| {
                let t = i as f64 * dt;
                vec![t, t.sin()]
            })
            .collect();
        (data, period)
    }

    /// Learn a DMP from the sine demonstration with the default configuration.
    fn learned_sine() -> (Dmp, Vec<Vec<f64>>, f64) {
        let (data, period) = sine_demo();
        let dmp = Dmp::learn(&data, period, &DmpConfig::default());
        (dmp, data, period)
    }

    #[test]
    fn test_dmp_learn_dimensions() {
        let (dmp, _, _) = learned_sine();
        assert_eq!(dmp.dimensions(), 2);
        assert_eq!(dmp.num_basis(), 10);
        assert_eq!(dmp.weights().len(), 2);
    }

    #[test]
    fn test_recreate_same_endpoints() {
        let (dmp, data, period) = learned_sine();
        let init = &data[0];
        let goal = data.last().unwrap();
        let traj = dmp.recreate(init, goal, period);

        assert_eq!(traj.positions.len(), data.len());

        // First position should be close to init (after one dt step).
        let first = &traj.positions[0];
        assert!((first[0] - init[0]).abs() < 1.0, "first x too far from init");

        // Last position should converge close to goal.
        let last = traj.positions.last().unwrap();
        assert!(
            (last[0] - goal[0]).abs() < 2.0,
            "last x too far from goal: {} vs {}",
            last[0],
            goal[0]
        );
        assert!(
            (last[1] - goal[1]).abs() < 2.0,
            "last y too far from goal: {} vs {}",
            last[1],
            goal[1]
        );
    }

    #[test]
    fn test_recreate_shifted_goal() {
        let (dmp, data, period) = learned_sine();
        let init = data[0].clone();
        let mut goal = data.last().unwrap().clone();
        goal[1] += 2.0; // shift goal upward

        let traj = dmp.recreate(&init, &goal, period);
        let last = traj.positions.last().unwrap();
        // Should converge towards the new goal, not the original.
        assert!(
            (last[1] - goal[1]).abs() < 2.0,
            "shifted goal: last y = {}, goal y = {}",
            last[1],
            goal[1]
        );
    }

    #[test]
    fn test_recreate_different_period() {
        let (dmp, data, period) = learned_sine();
        let init = data[0].clone();
        let goal = data.last().unwrap().clone();

        let traj_fast = dmp.recreate(&init, &goal, period * 0.5);
        let traj_slow = dmp.recreate(&init, &goal, period * 2.0);

        // Both have the same number of steps (same timesteps / dt).
        assert_eq!(traj_fast.positions.len(), traj_slow.positions.len());

        // With a shorter period the forcing function shuts off earlier,
        // so the trajectory shapes should differ.
        let mid = traj_fast.positions.len() / 2;
        let diff = (traj_fast.positions[mid][1] - traj_slow.positions[mid][1]).abs();
        assert!(
            diff > 1e-6,
            "trajectories with different periods should differ at midpoint"
        );
    }

    #[test]
    fn test_recreate_time_axis() {
        let (dmp, data, period) = learned_sine();
        let traj = dmp.recreate(&data[0], data.last().unwrap(), period);
        let dt = period / data.len() as f64;

        // The clock advances before each sample, so step k sits at (k + 1) * dt.
        assert_eq!(traj.time.len(), traj.positions.len());
        for (k, &t) in traj.time.iter().enumerate() {
            assert!((t - (k + 1) as f64 * dt).abs() < 1e-9, "time[{k}] = {t}");
        }
        assert!((traj.time.last().unwrap() - period).abs() < 1e-9);
    }

    #[test]
    fn test_1d_trajectory() {
        // Simple 1-D ramp.
        let data: Vec<Vec<f64>> = (0..50).map(|i| vec![i as f64 * 0.1]).collect();
        let period = 5.0;
        let dmp = Dmp::learn(&data, period, &DmpConfig::default());
        assert_eq!(dmp.dimensions(), 1);

        let traj = dmp.recreate(&[0.0], &[4.9], period);
        let last = traj.positions.last().unwrap();
        assert!((last[0] - 4.9).abs() < 2.0, "1-D ramp end: {} vs 4.9", last[0]);
    }

    #[test]
    fn test_constant_dimension_is_unforced() {
        // The second dimension never moves, so its forcing targets (and
        // therefore its weights) are zero and the recreation stays put.
        let data: Vec<Vec<f64>> = (0..50).map(|i| vec![i as f64 * 0.1, 1.5]).collect();
        let dmp = Dmp::learn(&data, 5.0, &DmpConfig::default());
        assert!(dmp.weights()[1].iter().all(|&w| w == 0.0));

        let traj = dmp.recreate(&[0.0, 1.5], &[4.9, 1.5], 5.0);
        assert!(traj.positions.iter().all(|p| p[1] == 1.5));
    }

    #[test]
    #[should_panic(expected = "need at least 2 data points")]
    fn test_learn_rejects_single_point() {
        let _ = Dmp::learn(&[vec![0.0]], 1.0, &DmpConfig::default());
    }

    #[test]
    fn test_lstsq_identity() {
        // Trivial system: I * x = b  =>  x = b.
        let a = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];
        let b = vec![3.0, 5.0, 7.0];
        let x = lstsq(&a, &b, 3);
        for (i, (xi, bi)) in x.iter().zip(&b).enumerate() {
            assert!((xi - bi).abs() < 1e-6, "lstsq identity failed at {i}");
        }
    }

    #[test]
    fn test_lstsq_overdetermined_line_fit() {
        // Three collinear points on y = 1 + 2t: the fit must recover (1, 2).
        let a = vec![vec![1.0, 0.0], vec![1.0, 1.0], vec![1.0, 2.0]];
        let x = lstsq(&a, &[1.0, 3.0, 5.0], 2);
        assert!((x[0] - 1.0).abs() < 1e-6, "intercept {}", x[0]);
        assert!((x[1] - 2.0).abs() < 1e-6, "slope {}", x[1]);
    }

    #[test]
    fn test_solve_symmetric_spd() {
        let mut a = vec![vec![4.0, 1.0], vec![1.0, 3.0]];
        let mut b = vec![1.0, 2.0];
        let x = solve_symmetric(&mut a, &mut b);
        assert!((x[0] - 1.0 / 11.0).abs() < 1e-12);
        assert!((x[1] - 7.0 / 11.0).abs() < 1e-12);
    }
}