/// Tunable parameters used when learning a [`Dmp`].
///
/// `spring` and `damper` are the gains of the spring-damper part of the
/// movement equation; `num_basis` controls how many Gaussian basis functions
/// are used to approximate the forcing term.
#[derive(Clone, Debug)]
pub struct DmpConfig {
    /// Spring gain multiplying the distance to the goal (`goal - position`).
    pub spring: f64,
    /// Damper gain multiplying the velocity.
    pub damper: f64,
    /// Number of Gaussian basis functions (and therefore weights) per dimension.
    pub num_basis: usize,
}
22
23impl Default for DmpConfig {
24 fn default() -> Self {
25 Self {
26 spring: 156.25,
27 damper: 25.0,
28 num_basis: 10,
29 }
30 }
31}
32
/// A learned dynamic movement primitive.
///
/// Produced by `Dmp::learn` and replayed with `Dmp::recreate`.
#[derive(Clone, Debug)]
pub struct Dmp {
    /// Forcing-term weights: one inner vector per dimension, one entry per
    /// basis function.
    weights: Vec<Vec<f64>>,
    /// Number of dimensions of the demonstrated trajectory.
    dimensions: usize,
    /// Number of samples in the demonstration; also the number of samples
    /// generated when the movement is recreated.
    timesteps: usize,
    /// Integration step: the demonstration period divided by `timesteps`.
    dt: f64,
    /// Spring gain copied from the configuration used for learning.
    spring: f64,
    /// Damper gain copied from the configuration used for learning.
    damper: f64,
    /// Number of Gaussian basis functions per dimension.
    num_basis: usize,
}
51
/// A trajectory generated by replaying a [`Dmp`].
///
/// `time[i]` is the time stamp of `positions[i]`.
#[derive(Clone, Debug)]
pub struct DmpTrajectory {
    /// Time stamp of every generated sample.
    pub time: Vec<f64>,
    /// Generated positions: one inner vector per sample, holding one value
    /// per dimension.
    pub positions: Vec<Vec<f64>>,
}
60
61impl Dmp {
62 pub fn learn(training_data: &[Vec<f64>], data_period: f64, config: &DmpConfig) -> Self {
68 let timesteps = training_data.len();
69 assert!(timesteps >= 2, "need at least 2 data points");
70 let dimensions = training_data[0].len();
71 assert!(dimensions > 0, "need at least 1 dimension");
72
73 let dt = data_period / timesteps as f64;
74 let num_basis = config.num_basis;
75
76 let centres: Vec<f64> = (0..num_basis)
78 .map(|i| i as f64 / (num_basis - 1).max(1) as f64)
79 .collect();
80 let h = 0.65 / ((1.0 / (num_basis as f64 - 1.0)).powi(2).max(1e-12));
81
82 let init_state = &training_data[0];
83 let goal_state = &training_data[timesteps - 1];
84
85 let mut all_weights: Vec<Vec<f64>> = Vec::with_capacity(dimensions);
86
87 for dim in 0..dimensions {
88 let q0 = init_state[dim];
89 let g = goal_state[dim];
90 let g_minus_q0 = g - q0;
91
92 let mut q = q0;
93 let mut qd_last = 0.0;
94
95 let mut phi_matrix: Vec<Vec<f64>> = Vec::with_capacity(timesteps);
96 let mut f_vals: Vec<f64> = Vec::with_capacity(timesteps);
97
98 for i in 0..timesteps {
99 let qd = if i + 1 < timesteps {
100 (training_data[i + 1][dim] - training_data[i][dim]) / dt
101 } else {
102 0.0
103 };
104
105 let phase = i as f64 * dt / data_period;
107 let mut phi: Vec<f64> = centres
108 .iter()
109 .map(|&c| (-0.5 * (phase - c).powi(2) * h).exp())
110 .collect();
111 let phi_sum: f64 = phi.iter().sum::<f64>().max(1e-12);
112 for v in &mut phi {
113 *v /= phi_sum;
114 }
115
116 let qdd = (qd - qd_last) / dt;
117
118 let f = if g_minus_q0.abs() < 1e-12 {
119 0.0
120 } else {
121 (qdd * data_period.powi(2) - config.spring * (g - q)
122 + config.damper * qd * data_period)
123 / g_minus_q0
124 };
125
126 phi_matrix.push(phi);
127 f_vals.push(f);
128
129 qd_last = qd;
130 q += qd * dt;
131 }
132
133 let w = lstsq(&phi_matrix, &f_vals, num_basis);
135 all_weights.push(w);
136 }
137
138 Self {
139 weights: all_weights,
140 dimensions,
141 timesteps,
142 dt,
143 spring: config.spring,
144 damper: config.damper,
145 num_basis,
146 }
147 }
148
149 pub fn recreate(&self, init_state: &[f64], goal_state: &[f64], period: f64) -> DmpTrajectory {
155 assert_eq!(init_state.len(), self.dimensions);
156 assert_eq!(goal_state.len(), self.dimensions);
157
158 let centres: Vec<f64> = (0..self.num_basis)
159 .map(|i| i as f64 / (self.num_basis - 1).max(1) as f64)
160 .collect();
161 let h = 0.65 / ((1.0 / (self.num_basis as f64 - 1.0)).powi(2).max(1e-12));
162
163 let mut q: Vec<f64> = init_state.to_vec();
164 let mut qd = vec![0.0; self.dimensions];
165
166 let mut time_vec = Vec::with_capacity(self.timesteps);
167 let mut positions = Vec::with_capacity(self.timesteps);
168 let mut time = 0.0;
169
170 for _ in 0..self.timesteps {
171 time += self.dt;
172
173 let mut qdd = vec![0.0; self.dimensions];
174
175 for dim in 0..self.dimensions {
176 let f = if time <= period {
177 let phase = time / period;
178 let mut phi: Vec<f64> = centres
179 .iter()
180 .map(|&c| (-0.5 * (phase - c).powi(2) * h).exp())
181 .collect();
182 let phi_sum: f64 = phi.iter().sum::<f64>().max(1e-12);
183 for v in &mut phi {
184 *v /= phi_sum;
185 }
186 phi.iter()
187 .zip(self.weights[dim].iter())
188 .map(|(p, w)| p * w)
189 .sum::<f64>()
190 } else {
191 0.0
192 };
193
194 qdd[dim] = self.spring * (goal_state[dim] - q[dim]) / period.powi(2)
195 - self.damper * qd[dim] / period
196 + (goal_state[dim] - init_state[dim]) * f / period.powi(2);
197 }
198
199 for dim in 0..self.dimensions {
200 qd[dim] += qdd[dim] * self.dt;
201 q[dim] += qd[dim] * self.dt;
202 }
203
204 time_vec.push(time);
205 positions.push(q.clone());
206 }
207
208 DmpTrajectory {
209 time: time_vec,
210 positions,
211 }
212 }
213
214 pub fn dimensions(&self) -> usize {
216 self.dimensions
217 }
218
219 pub fn num_basis(&self) -> usize {
221 self.num_basis
222 }
223
224 pub fn weights(&self) -> &[Vec<f64>] {
226 &self.weights
227 }
228}
229
230fn lstsq(a: &[Vec<f64>], b: &[f64], n: usize) -> Vec<f64> {
237 let m = a.len();
238 assert_eq!(b.len(), m);
239
240 let mut ata = vec![vec![0.0; n]; n];
242 let mut atb = vec![0.0; n];
244
245 for row in 0..m {
246 for j in 0..n {
247 atb[j] += a[row][j] * b[row];
248 for k in j..n {
249 let v = a[row][j] * a[row][k];
250 ata[j][k] += v;
251 if k != j {
252 ata[k][j] += v;
253 }
254 }
255 }
256 }
257
258 for (i, ata_row) in ata.iter_mut().enumerate() {
260 ata_row[i] += 1e-10;
261 }
262
263 solve_symmetric(&mut ata, &mut atb)
265}
266
/// Solves the square system `a * x = b` by Gaussian elimination without
/// pivoting followed by back-substitution.
///
/// The system size is `b.len()`. Both `a` and `b` are overwritten with the
/// eliminated (upper-triangular) system. No check is made for zero pivots.
// Index-based loops are kept on purpose: the row/column index is also needed
// to slice the neighbouring rows and the right-hand side.
#[allow(clippy::needless_range_loop)]
fn solve_symmetric(a: &mut [Vec<f64>], b: &mut [f64]) -> Vec<f64> {
    let n = b.len();

    // Forward elimination: clear column `col` in every row below the pivot.
    for col in 0..n {
        let (head, tail) = a.split_at_mut(col + 1);
        let pivot_row = &head[col];
        let pivot = pivot_row[col];

        let (rhs_head, rhs_tail) = b.split_at_mut(col + 1);
        let pivot_rhs = rhs_head[col];

        for (row, rhs) in tail.iter_mut().zip(rhs_tail.iter_mut()) {
            let factor = row[col] / pivot;
            for (entry, &above) in row[col..n].iter_mut().zip(&pivot_row[col..n]) {
                *entry -= factor * above;
            }
            *rhs -= factor * pivot_rhs;
        }
    }

    // Back-substitution from the last unknown upwards.
    let mut x = vec![0.0; n];
    for i in (0..n).rev() {
        let residual = a[i][i + 1..n]
            .iter()
            .zip(&x[i + 1..])
            .fold(b[i], |acc, (&coef, &known)| acc - coef * known);
        x[i] = residual / a[i][i];
    }
    x
}
293
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 2-D demonstration `(t, sin t)` with 200 samples spaced
    /// `2π / 200` apart, starting at `t = 0`, and returns it with the period.
    fn sine_demo() -> (Vec<Vec<f64>>, f64) {
        const SAMPLES: usize = 200;
        let period = 2.0 * std::f64::consts::PI;
        let step = period / SAMPLES as f64;

        let mut data = Vec::with_capacity(SAMPLES);
        for k in 0..SAMPLES {
            let t = k as f64 * step;
            data.push(vec![t, t.sin()]);
        }
        (data, period)
    }

    /// Learns a DMP from the sine demonstration with the default config.
    fn learned_sine() -> (Dmp, Vec<Vec<f64>>, f64) {
        let (data, period) = sine_demo();
        let dmp = Dmp::learn(&data, period, &DmpConfig::default());
        (dmp, data, period)
    }

    /// Learning reports the dimensionality and basis count it was given.
    #[test]
    fn test_dmp_learn_dimensions() {
        let (dmp, _, _) = learned_sine();
        assert_eq!(dmp.dimensions(), 2);
        assert_eq!(dmp.num_basis(), 10);
        assert_eq!(dmp.weights().len(), 2);
    }

    /// Replaying with the demonstrated endpoints stays near those endpoints.
    #[test]
    fn test_recreate_same_endpoints() {
        let (dmp, data, period) = learned_sine();

        let init = data.first().unwrap();
        let goal = data.last().unwrap();
        let traj = dmp.recreate(init, goal, period);

        assert_eq!(traj.positions.len(), data.len());

        let first = traj.positions.first().unwrap();
        assert!(
            (first[0] - init[0]).abs() < 1.0,
            "first x too far from init"
        );

        let last = traj.positions.last().unwrap();
        assert!(
            (last[0] - goal[0]).abs() < 2.0,
            "last x too far from goal: {} vs {}",
            last[0],
            goal[0]
        );
        assert!(
            (last[1] - goal[1]).abs() < 2.0,
            "last y too far from goal: {} vs {}",
            last[1],
            goal[1]
        );
    }

    /// Replaying towards a goal shifted by 2 in `y` ends near the new goal.
    #[test]
    fn test_recreate_shifted_goal() {
        let (dmp, data, period) = learned_sine();

        let init = data[0].clone();
        let mut goal = data[data.len() - 1].clone();
        goal[1] += 2.0;

        let traj = dmp.recreate(&init, &goal, period);
        let last = traj.positions.last().unwrap();
        assert!(
            (last[1] - goal[1]).abs() < 2.0,
            "shifted goal: last y = {}, goal y = {}",
            last[1],
            goal[1]
        );
    }

    /// Different replay periods give equally long but different trajectories.
    #[test]
    fn test_recreate_different_period() {
        let (dmp, data, period) = learned_sine();

        let init = data[0].clone();
        let goal = data[data.len() - 1].clone();

        let traj_fast = dmp.recreate(&init, &goal, period * 0.5);
        let traj_slow = dmp.recreate(&init, &goal, period * 2.0);

        assert_eq!(traj_fast.positions.len(), traj_slow.positions.len());

        let mid = traj_fast.positions.len() / 2;
        let fast_y = traj_fast.positions[mid][1];
        let slow_y = traj_slow.positions[mid][1];
        let diff = (fast_y - slow_y).abs();
        assert!(
            diff > 1e-6,
            "trajectories with different periods should differ at midpoint"
        );
    }

    /// A 1-D ramp from 0.0 to 4.9 can be learned and replayed.
    #[test]
    fn test_1d_trajectory() {
        let mut data: Vec<Vec<f64>> = Vec::with_capacity(50);
        for i in 0..50 {
            data.push(vec![i as f64 * 0.1]);
        }
        let period = 5.0;
        let dmp = Dmp::learn(&data, period, &DmpConfig::default());
        assert_eq!(dmp.dimensions(), 1);

        let traj = dmp.recreate(&[0.0], &[4.9], period);
        let last = traj.positions.last().unwrap();
        assert!(
            (last[0] - 4.9).abs() < 2.0,
            "1-D ramp end: {} vs 4.9",
            last[0]
        );
    }

    /// With an identity design matrix the solution equals the targets.
    #[test]
    fn test_lstsq_identity() {
        let a = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];
        let b = vec![3.0, 5.0, 7.0];
        let x = lstsq(&a, &b, 3);

        assert_eq!(x.len(), b.len());
        for (i, (got, want)) in x.iter().zip(&b).enumerate() {
            assert!((got - want).abs() < 1e-6, "lstsq identity failed at {i}");
        }
    }
}