Skip to main content

rust_robotics_planning/
elastic_bands.rs

//! Elastic Bands path deformation algorithm.
//!
//! Deforms an initial path using internal contraction forces and external
//! repulsive forces computed from a distance field. Each waypoint is
//! surrounded by a "bubble" whose radius equals the clearance to the nearest
//! obstacle (clamped to configured bounds). The bubble chain is maintained by
//! inserting and deleting bubbles to satisfy an overlap constraint.
//!
//! Reference:
//! - S. Quinlan and O. Khatib, "Elastic Bands: Connecting Path Planning and
//!   Control", IEEE ICRA 1993.
//!   <http://www8.cs.umu.se/research/ifor/dl/Control/elastic%20bands.pdf>
12
/// One link of the elastic band: a disc of free space around a waypoint.
#[derive(Clone, Debug)]
pub struct Bubble {
    /// Waypoint position as \[x, y\].
    pub pos: [f64; 2],
    /// Free-space radius: the obstacle clearance at `pos`, clamped to the
    /// configured minimum/maximum bubble radius.
    pub radius: f64,
}
21
/// Tuning parameters for the [`ElasticBands`] planner.
///
/// Every field is public, so individual values can be overridden with
/// struct-update syntax, e.g.
/// `ElasticBandsConfig { max_iter: 10, ..Default::default() }`.
#[derive(Clone, Debug)]
pub struct ElasticBandsConfig {
    /// Upper bound applied when clamping a bubble radius.
    pub max_bubble_radius: f64,
    /// Lower bound applied when clamping a bubble radius.
    pub min_bubble_radius: f64,
    /// Radius threshold (`rho0`): bubbles with a radius at or above this value
    /// feel no repulsive force.
    pub rho0: f64,
    /// Gain of the internal contraction force.
    pub kc: f64,
    /// Gain of the external repulsive force; normally negative.
    pub kr: f64,
    /// Overlap factor (`lambda`) used when inserting and deleting bubbles.
    pub lambda: f64,
    /// Offset `h` of the central-difference gradient estimate.
    pub step_size: f64,
    /// Number of iterations run by [`ElasticBands::optimise`].
    pub max_iter: usize,
}

impl Default for ElasticBandsConfig {
    /// Radii clamped to `[10, 100]`, `rho0 = 20`, `kc = 0.05`, `kr = -0.1`,
    /// `lambda = 0.7`, `step_size = 3`, and 50 iterations.
    fn default() -> Self {
        ElasticBandsConfig {
            min_bubble_radius: 10.0,
            max_bubble_radius: 100.0,
            rho0: 20.0,
            kc: 0.05,
            kr: -0.1,
            lambda: 0.7,
            step_size: 3.0,
            max_iter: 50,
        }
    }
}
57
/// A distance field stored on a 2-D grid of unit cells.
///
/// `data[ix][iy]` holds the (approximate) Euclidean distance from cell
/// `(ix, iy)` to the nearest occupied cell. Occupied cells store `0.0`; every
/// value is non-negative (the field is unsigned).
#[derive(Debug, Clone)]
pub struct DistanceField {
    data: Vec<Vec<f64>>,
    width: usize,
    height: usize,
}

impl DistanceField {
    /// Build a distance field from a binary obstacle grid.
    ///
    /// `obstacles` is indexed as `obstacles[ix][iy]`. The outer length is the
    /// grid width, every inner row has length equal to the grid height, and
    /// `true` marks an occupied cell. Distances come from a two-pass
    /// 8-neighbour chamfer sweep (axis steps cost `1`, diagonal steps cost
    /// `sqrt(2)`), which approximates the Euclidean distance.
    ///
    /// If the grid contains no obstacles, every cell holds `width + height`.
    /// This value is larger than any distance reachable inside the grid.
    ///
    /// # Panics
    ///
    /// Panics if the grid is empty, if its first row is empty, or if the rows
    /// do not all have the same length.
    pub fn from_obstacle_grid(obstacles: &[Vec<bool>]) -> Self {
        let width = obstacles.len();
        assert!(width > 0, "obstacle grid must be non-empty");
        let height = obstacles[0].len();
        assert!(height > 0, "obstacle grid rows must be non-empty");
        // Without this check, a short row would panic with a bare index error
        // and a long row would be silently truncated.
        assert!(
            obstacles.iter().all(|row| row.len() == height),
            "obstacle grid rows must all have the same length"
        );

        const DIAG: f64 = std::f64::consts::SQRT_2;
        let large = (width + height) as f64;

        // Seed: 0 at obstacles, `large` everywhere else.
        let mut dist: Vec<Vec<f64>> = obstacles
            .iter()
            .map(|row| row.iter().map(|&occ| if occ { 0.0 } else { large }).collect())
            .collect();

        // Forward pass: relax each cell against the neighbours already visited
        // in raster order (the whole previous column and the previous cell in
        // this column).
        for ix in 0..width {
            for iy in 0..height {
                let mut d = dist[ix][iy];
                if ix > 0 {
                    d = d.min(dist[ix - 1][iy] + 1.0);
                    if iy > 0 {
                        d = d.min(dist[ix - 1][iy - 1] + DIAG);
                    }
                    if iy + 1 < height {
                        d = d.min(dist[ix - 1][iy + 1] + DIAG);
                    }
                }
                if iy > 0 {
                    d = d.min(dist[ix][iy - 1] + 1.0);
                }
                dist[ix][iy] = d;
            }
        }

        // Backward pass: same idea in reverse raster order (the whole next
        // column and the next cell in this column).
        for ix in (0..width).rev() {
            for iy in (0..height).rev() {
                let mut d = dist[ix][iy];
                if ix + 1 < width {
                    d = d.min(dist[ix + 1][iy] + 1.0);
                    if iy + 1 < height {
                        d = d.min(dist[ix + 1][iy + 1] + DIAG);
                    }
                    if iy > 0 {
                        d = d.min(dist[ix + 1][iy - 1] + DIAG);
                    }
                }
                if iy + 1 < height {
                    d = d.min(dist[ix][iy + 1] + 1.0);
                }
                dist[ix][iy] = d;
            }
        }

        Self {
            data: dist,
            width,
            height,
        }
    }

    /// Query the distance at a continuous position (nearest-cell lookup).
    ///
    /// Coordinates are rounded to the nearest cell and clamped into the grid.
    /// Out-of-range queries therefore return the value of the closest edge
    /// cell.
    pub fn query(&self, x: f64, y: f64) -> f64 {
        let ix = (x.round() as isize).clamp(0, self.width as isize - 1) as usize;
        let iy = (y.round() as isize).clamp(0, self.height as isize - 1) as usize;
        self.data[ix][iy]
    }

    /// Number of cells along x (the outer length of the source grid).
    pub fn width(&self) -> usize {
        self.width
    }

    /// Number of cells along y (the length of each row of the source grid).
    pub fn height(&self) -> usize {
        self.height
    }
}
154
155/// Elastic Bands planner.
156///
157/// Iteratively deforms a path (represented as a chain of [`Bubble`]s) to be
158/// shorter and further from obstacles.
159pub struct ElasticBands {
160    /// Current chain of bubbles.
161    pub bubbles: Vec<Bubble>,
162    config: ElasticBandsConfig,
163    distance_field: DistanceField,
164}
165
166impl ElasticBands {
167    /// Create a new elastic band from an initial path and a pre-computed
168    /// distance field.
169    pub fn new(
170        initial_path: &[[f64; 2]],
171        distance_field: DistanceField,
172        config: ElasticBandsConfig,
173    ) -> Self {
174        let bubbles: Vec<Bubble> = initial_path
175            .iter()
176            .map(|p| {
177                let r = distance_field.query(p[0], p[1]);
178                Bubble {
179                    pos: *p,
180                    radius: r.clamp(config.min_bubble_radius, config.max_bubble_radius),
181                }
182            })
183            .collect();
184
185        let mut band = Self {
186            bubbles,
187            config,
188            distance_field,
189        };
190        band.maintain_overlap();
191        band
192    }
193
194    /// Create from an initial path and a binary obstacle grid.
195    pub fn from_obstacles(
196        initial_path: &[[f64; 2]],
197        obstacles: &[Vec<bool>],
198        config: ElasticBandsConfig,
199    ) -> Self {
200        let df = DistanceField::from_obstacle_grid(obstacles);
201        Self::new(initial_path, df, config)
202    }
203
204    /// Run the full optimisation for `config.max_iter` iterations.
205    pub fn optimise(&mut self) {
206        for _ in 0..self.config.max_iter {
207            self.update_bubbles();
208        }
209    }
210
211    /// Perform a single optimisation step: move bubbles then maintain overlap.
212    pub fn update_bubbles(&mut self) {
213        let n = self.bubbles.len();
214        let mut new_bubbles = Vec::with_capacity(n);
215
216        for i in 0..n {
217            if i == 0 || i == n - 1 {
218                new_bubbles.push(self.bubbles[i].clone());
219                continue;
220            }
221
222            let fc = self.contraction_force(i);
223            let fr = self.repulsive_force(i);
224            let f_total = [fc[0] + fr[0], fc[1] + fr[1]];
225
226            // Direction between neighbours (for tangent removal).
227            let v = [
228                self.bubbles[i - 1].pos[0] - self.bubbles[i + 1].pos[0],
229                self.bubbles[i - 1].pos[1] - self.bubbles[i + 1].pos[1],
230            ];
231            let v_norm_sq = v[0] * v[0] + v[1] * v[1] + 1e-6;
232
233            // Remove tangential component.
234            let f_dot_v = f_total[0] * v[0] + f_total[1] * v[1];
235            let f_star = [
236                f_total[0] - f_dot_v * v[0] / v_norm_sq,
237                f_total[1] - f_dot_v * v[1] / v_norm_sq,
238            ];
239
240            let alpha = self.bubbles[i].radius;
241            let max_x = (self.distance_field.width() as f64 - 1.0).max(0.0);
242            let max_y = (self.distance_field.height() as f64 - 1.0).max(0.0);
243            let new_x = (self.bubbles[i].pos[0] + alpha * f_star[0]).clamp(0.0, max_x);
244            let new_y = (self.bubbles[i].pos[1] + alpha * f_star[1]).clamp(0.0, max_y);
245
246            let r = self.distance_field.query(new_x, new_y);
247            new_bubbles.push(Bubble {
248                pos: [new_x, new_y],
249                radius: r.clamp(self.config.min_bubble_radius, self.config.max_bubble_radius),
250            });
251        }
252
253        self.bubbles = new_bubbles;
254        self.maintain_overlap();
255    }
256
257    /// Extract the optimised path as a vector of `[x, y]` points.
258    pub fn path(&self) -> Vec<[f64; 2]> {
259        self.bubbles.iter().map(|b| b.pos).collect()
260    }
261
262    // ------------------------------------------------------------------
263    // Internal helpers
264    // ------------------------------------------------------------------
265
266    fn contraction_force(&self, i: usize) -> [f64; 2] {
267        if i == 0 || i == self.bubbles.len() - 1 {
268            return [0.0, 0.0];
269        }
270        let prev = &self.bubbles[i - 1].pos;
271        let next = &self.bubbles[i + 1].pos;
272        let cur = &self.bubbles[i].pos;
273
274        let dp = [prev[0] - cur[0], prev[1] - cur[1]];
275        let dn = [next[0] - cur[0], next[1] - cur[1]];
276        let dp_len = (dp[0] * dp[0] + dp[1] * dp[1]).sqrt() + 1e-6;
277        let dn_len = (dn[0] * dn[0] + dn[1] * dn[1]).sqrt() + 1e-6;
278
279        [
280            self.config.kc * (dp[0] / dp_len + dn[0] / dn_len),
281            self.config.kc * (dp[1] / dp_len + dn[1] / dn_len),
282        ]
283    }
284
285    fn repulsive_force(&self, i: usize) -> [f64; 2] {
286        let b = &self.bubbles[i];
287        if b.radius >= self.config.rho0 {
288            return [0.0, 0.0];
289        }
290
291        let h = self.config.step_size;
292        let grad_x = (self.distance_field.query(b.pos[0] - h, b.pos[1])
293            - self.distance_field.query(b.pos[0] + h, b.pos[1]))
294            / (2.0 * h);
295        let grad_y = (self.distance_field.query(b.pos[0], b.pos[1] - h)
296            - self.distance_field.query(b.pos[0], b.pos[1] + h))
297            / (2.0 * h);
298
299        let scale = self.config.kr * (self.config.rho0 - b.radius);
300        [scale * grad_x, scale * grad_y]
301    }
302
303    fn maintain_overlap(&mut self) {
304        // Insert bubbles where neighbours are too far apart.
305        let mut i = 0;
306        while i < self.bubbles.len().saturating_sub(1) {
307            let dist = dist2d(&self.bubbles[i].pos, &self.bubbles[i + 1].pos);
308            let threshold =
309                self.config.lambda * (self.bubbles[i].radius + self.bubbles[i + 1].radius);
310            if dist > threshold {
311                let mid = [
312                    (self.bubbles[i].pos[0] + self.bubbles[i + 1].pos[0]) * 0.5,
313                    (self.bubbles[i].pos[1] + self.bubbles[i + 1].pos[1]) * 0.5,
314                ];
315                let r = self
316                    .distance_field
317                    .query(mid[0], mid[1])
318                    .clamp(self.config.min_bubble_radius, self.config.max_bubble_radius);
319                self.bubbles.insert(
320                    i + 1,
321                    Bubble {
322                        pos: mid,
323                        radius: r,
324                    },
325                );
326                i += 2;
327            } else {
328                i += 1;
329            }
330        }
331
332        // Delete redundant bubbles.
333        let mut i = 1;
334        while i < self.bubbles.len().saturating_sub(1) {
335            let prev = &self.bubbles[i - 1];
336            let next = &self.bubbles[i + 1];
337            let dist = dist2d(&prev.pos, &next.pos);
338            let threshold = self.config.lambda * (prev.radius + next.radius);
339            if dist <= threshold {
340                self.bubbles.remove(i);
341            } else {
342                i += 1;
343            }
344        }
345    }
346}
347
/// Euclidean distance between two points in the plane.
fn dist2d(a: &[f64; 2], b: &[f64; 2]) -> f64 {
    let [ax, ay] = *a;
    let [bx, by] = *b;
    let (dx, dy) = (ax - bx, ay - by);
    (dx * dx + dy * dy).sqrt()
}
351
#[cfg(test)]
mod tests {
    use super::*;

    /// Total polyline length of `pts`.
    fn path_length(pts: &[[f64; 2]]) -> f64 {
        pts.windows(2).map(|w| dist2d(&w[0], &w[1])).sum()
    }

    /// 100x100 grid with a rectangular obstacle covering cells
    /// `x in 40..60`, `y in 35..65` (half-open ranges).
    fn make_obstacle_grid() -> Vec<Vec<bool>> {
        let mut grid = vec![vec![false; 100]; 100];
        for column in &mut grid[40..60] {
            column[35..65].fill(true);
        }
        grid
    }

    #[test]
    fn test_distance_field_basic() {
        let grid = make_obstacle_grid();
        let df = DistanceField::from_obstacle_grid(&grid);
        // Inside the obstacle the distance is 0.
        assert_eq!(df.query(50.0, 50.0), 0.0);
        // A cell directly next to the obstacle is one step away.
        assert_eq!(df.query(39.0, 50.0), 1.0);
        // Far from the obstacle the distance is large.
        assert!(df.query(5.0, 5.0) > 10.0);
    }

    #[test]
    fn test_bubble_creation() {
        let config = ElasticBandsConfig::default();
        let grid = make_obstacle_grid();
        let df = DistanceField::from_obstacle_grid(&grid);
        let r = df.query(10.0, 10.0);
        let bubble = Bubble {
            pos: [10.0, 10.0],
            radius: r.clamp(config.min_bubble_radius, config.max_bubble_radius),
        };
        assert!(bubble.radius >= config.min_bubble_radius);
        assert!(bubble.radius <= config.max_bubble_radius);
    }

    #[test]
    fn test_elastic_bands_path_shortens() {
        let grid = make_obstacle_grid();
        let config = ElasticBandsConfig {
            max_iter: 30,
            ..Default::default()
        };

        // Deliberately detoured path that goes above the obstacle.
        let initial_path = vec![
            [10.0, 50.0],
            [20.0, 80.0],
            [35.0, 90.0],
            [50.0, 85.0],
            [65.0, 90.0],
            [80.0, 80.0],
            [90.0, 50.0],
        ];
        let initial_length = path_length(&initial_path);

        let mut band = ElasticBands::from_obstacles(&initial_path, &grid, config);
        band.optimise();
        let final_length = path_length(&band.path());

        // The contraction force should shorten the path.
        assert!(
            final_length < initial_length,
            "path should shorten: initial {initial_length:.1}, final {final_length:.1}"
        );
    }

    #[test]
    fn test_elastic_bands_endpoints_fixed() {
        let grid = make_obstacle_grid();
        let config = ElasticBandsConfig::default();

        let initial_path = vec![[10.0, 50.0], [50.0, 80.0], [90.0, 50.0]];
        let mut band = ElasticBands::from_obstacles(&initial_path, &grid, config);
        band.optimise();

        // Endpoints are copied unchanged every step, so compare exactly and on
        // both coordinates.
        let path = band.path();
        assert_eq!(path.first(), Some(&[10.0, 50.0]), "start point should remain fixed");
        assert_eq!(path.last(), Some(&[90.0, 50.0]), "end point should remain fixed");
    }

    #[test]
    fn test_maintain_overlap_inserts_bubbles() {
        let grid = vec![vec![false; 100]; 100];
        let config = ElasticBandsConfig::default();

        // Two points very far apart — overlap maintenance should insert bubbles.
        let initial_path = vec![[0.0, 0.0], [99.0, 99.0]];
        let band = ElasticBands::from_obstacles(&initial_path, &grid, config);
        assert!(
            band.bubbles.len() > 2,
            "bubbles should be inserted between distant endpoints"
        );
    }

    #[test]
    fn test_no_obstacle_path_contracts() {
        // With no obstacles the repulsive force is zero and the contraction
        // force should pull the path straighter.
        let grid = vec![vec![false; 100]; 100];
        let config = ElasticBandsConfig {
            max_iter: 20,
            ..Default::default()
        };

        let initial_path = vec![
            [10.0, 10.0],
            [30.0, 50.0],
            [50.0, 10.0],
            [70.0, 50.0],
            [90.0, 10.0],
        ];
        let initial_length = path_length(&initial_path);

        let mut band = ElasticBands::from_obstacles(&initial_path, &grid, config);
        band.optimise();

        let final_length = path_length(&band.path());
        assert!(final_length < initial_length);
    }
}