/// A disc along the band: a waypoint together with its clearance radius.
#[derive(Debug, Clone, PartialEq)]
pub struct Bubble {
    /// Centre `[x, y]` in grid-cell coordinates (distance-field lookups round it to a cell).
    pub pos: [f64; 2],
    /// Bubble radius. In `ElasticBands` this is the distance-field value at `pos`, clamped to
    /// `[min_bubble_radius, max_bubble_radius]` of the band's configuration.
    pub radius: f64,
}
21
/// Tuning parameters for `ElasticBands`.
#[derive(Debug, Clone, PartialEq)]
pub struct ElasticBandsConfig {
    /// Upper bound applied to every bubble radius.
    pub max_bubble_radius: f64,
    /// Lower bound applied to every bubble radius, even where the distance field reports
    /// less clearance (e.g. a point inside an obstacle). Must not exceed `max_bubble_radius`.
    pub min_bubble_radius: f64,
    /// Influence radius of obstacles: repulsion acts only on bubbles whose radius is below it.
    pub rho0: f64,
    /// Contraction gain; scales the sum of unit vectors from a bubble towards its neighbours.
    pub kc: f64,
    /// Repulsion gain; multiplied by `(rho0 - radius)` and by a central-difference estimate
    /// of the *negated* distance gradient, so a negative value pushes bubbles away from
    /// obstacles.
    pub kr: f64,
    /// Overlap factor: neighbours `a`, `b` overlap when
    /// `|a.pos - b.pos| <= lambda * (a.radius + b.radius)`.
    pub lambda: f64,
    /// Step `h` of the central finite difference used for the repulsive gradient. It is not
    /// the motion step, which is the bubble's own radius.
    pub step_size: f64,
    /// Number of relaxation passes run by `ElasticBands::optimise`.
    pub max_iter: usize,
}
42
43impl Default for ElasticBandsConfig {
44 fn default() -> Self {
45 Self {
46 max_bubble_radius: 100.0,
47 min_bubble_radius: 10.0,
48 rho0: 20.0,
49 kc: 0.05,
50 kr: -0.1,
51 lambda: 0.7,
52 step_size: 3.0,
53 max_iter: 50,
54 }
55 }
56}
57
/// Per-cell distance to the nearest obstacle on a rectangular grid.
///
/// Cells are indexed `data[x][y]` with `x < width` and `y < height`.
#[derive(Debug, Clone)]
pub struct DistanceField {
    /// Distance values, outer index `x`, inner index `y`.
    data: Vec<Vec<f64>>,
    /// Number of cells along `x` (length of the outer vector).
    width: usize,
    /// Number of cells along `y` (length of every inner vector).
    height: usize,
}
67
68impl DistanceField {
69 pub fn from_obstacle_grid(obstacles: &[Vec<bool>]) -> Self {
75 let width = obstacles.len();
76 assert!(width > 0, "obstacle grid must be non-empty");
77 let height = obstacles[0].len();
78 assert!(height > 0, "obstacle grid rows must be non-empty");
79
80 let large = (width + height) as f64;
81 let mut dist = vec![vec![large; height]; width];
82
83 for ix in 0..width {
85 for iy in 0..height {
86 if obstacles[ix][iy] {
87 dist[ix][iy] = 0.0;
88 }
89 }
90 }
91
92 for ix in 0..width {
94 for iy in 0..height {
95 if ix > 0 {
96 dist[ix][iy] = dist[ix][iy].min(dist[ix - 1][iy] + 1.0);
97 }
98 if iy > 0 {
99 dist[ix][iy] = dist[ix][iy].min(dist[ix][iy - 1] + 1.0);
100 }
101 if ix > 0 && iy > 0 {
102 dist[ix][iy] =
103 dist[ix][iy].min(dist[ix - 1][iy - 1] + std::f64::consts::SQRT_2);
104 }
105 if ix > 0 && iy + 1 < height {
106 dist[ix][iy] =
107 dist[ix][iy].min(dist[ix - 1][iy + 1] + std::f64::consts::SQRT_2);
108 }
109 }
110 }
111
112 for ix in (0..width).rev() {
114 for iy in (0..height).rev() {
115 if ix + 1 < width {
116 dist[ix][iy] = dist[ix][iy].min(dist[ix + 1][iy] + 1.0);
117 }
118 if iy + 1 < height {
119 dist[ix][iy] = dist[ix][iy].min(dist[ix][iy + 1] + 1.0);
120 }
121 if ix + 1 < width && iy + 1 < height {
122 dist[ix][iy] =
123 dist[ix][iy].min(dist[ix + 1][iy + 1] + std::f64::consts::SQRT_2);
124 }
125 if ix + 1 < width && iy > 0 {
126 dist[ix][iy] =
127 dist[ix][iy].min(dist[ix + 1][iy - 1] + std::f64::consts::SQRT_2);
128 }
129 }
130 }
131
132 Self {
133 data: dist,
134 width,
135 height,
136 }
137 }
138
139 pub fn query(&self, x: f64, y: f64) -> f64 {
141 let ix = (x.round() as isize).clamp(0, self.width as isize - 1) as usize;
142 let iy = (y.round() as isize).clamp(0, self.height as isize - 1) as usize;
143 self.data[ix][iy]
144 }
145
146 pub fn width(&self) -> usize {
147 self.width
148 }
149
150 pub fn height(&self) -> usize {
151 self.height
152 }
153}
154
155pub struct ElasticBands {
160 pub bubbles: Vec<Bubble>,
162 config: ElasticBandsConfig,
163 distance_field: DistanceField,
164}
165
166impl ElasticBands {
167 pub fn new(
170 initial_path: &[[f64; 2]],
171 distance_field: DistanceField,
172 config: ElasticBandsConfig,
173 ) -> Self {
174 let bubbles: Vec<Bubble> = initial_path
175 .iter()
176 .map(|p| {
177 let r = distance_field.query(p[0], p[1]);
178 Bubble {
179 pos: *p,
180 radius: r.clamp(config.min_bubble_radius, config.max_bubble_radius),
181 }
182 })
183 .collect();
184
185 let mut band = Self {
186 bubbles,
187 config,
188 distance_field,
189 };
190 band.maintain_overlap();
191 band
192 }
193
194 pub fn from_obstacles(
196 initial_path: &[[f64; 2]],
197 obstacles: &[Vec<bool>],
198 config: ElasticBandsConfig,
199 ) -> Self {
200 let df = DistanceField::from_obstacle_grid(obstacles);
201 Self::new(initial_path, df, config)
202 }
203
204 pub fn optimise(&mut self) {
206 for _ in 0..self.config.max_iter {
207 self.update_bubbles();
208 }
209 }
210
211 pub fn update_bubbles(&mut self) {
213 let n = self.bubbles.len();
214 let mut new_bubbles = Vec::with_capacity(n);
215
216 for i in 0..n {
217 if i == 0 || i == n - 1 {
218 new_bubbles.push(self.bubbles[i].clone());
219 continue;
220 }
221
222 let fc = self.contraction_force(i);
223 let fr = self.repulsive_force(i);
224 let f_total = [fc[0] + fr[0], fc[1] + fr[1]];
225
226 let v = [
228 self.bubbles[i - 1].pos[0] - self.bubbles[i + 1].pos[0],
229 self.bubbles[i - 1].pos[1] - self.bubbles[i + 1].pos[1],
230 ];
231 let v_norm_sq = v[0] * v[0] + v[1] * v[1] + 1e-6;
232
233 let f_dot_v = f_total[0] * v[0] + f_total[1] * v[1];
235 let f_star = [
236 f_total[0] - f_dot_v * v[0] / v_norm_sq,
237 f_total[1] - f_dot_v * v[1] / v_norm_sq,
238 ];
239
240 let alpha = self.bubbles[i].radius;
241 let max_x = (self.distance_field.width() as f64 - 1.0).max(0.0);
242 let max_y = (self.distance_field.height() as f64 - 1.0).max(0.0);
243 let new_x = (self.bubbles[i].pos[0] + alpha * f_star[0]).clamp(0.0, max_x);
244 let new_y = (self.bubbles[i].pos[1] + alpha * f_star[1]).clamp(0.0, max_y);
245
246 let r = self.distance_field.query(new_x, new_y);
247 new_bubbles.push(Bubble {
248 pos: [new_x, new_y],
249 radius: r.clamp(self.config.min_bubble_radius, self.config.max_bubble_radius),
250 });
251 }
252
253 self.bubbles = new_bubbles;
254 self.maintain_overlap();
255 }
256
257 pub fn path(&self) -> Vec<[f64; 2]> {
259 self.bubbles.iter().map(|b| b.pos).collect()
260 }
261
262 fn contraction_force(&self, i: usize) -> [f64; 2] {
267 if i == 0 || i == self.bubbles.len() - 1 {
268 return [0.0, 0.0];
269 }
270 let prev = &self.bubbles[i - 1].pos;
271 let next = &self.bubbles[i + 1].pos;
272 let cur = &self.bubbles[i].pos;
273
274 let dp = [prev[0] - cur[0], prev[1] - cur[1]];
275 let dn = [next[0] - cur[0], next[1] - cur[1]];
276 let dp_len = (dp[0] * dp[0] + dp[1] * dp[1]).sqrt() + 1e-6;
277 let dn_len = (dn[0] * dn[0] + dn[1] * dn[1]).sqrt() + 1e-6;
278
279 [
280 self.config.kc * (dp[0] / dp_len + dn[0] / dn_len),
281 self.config.kc * (dp[1] / dp_len + dn[1] / dn_len),
282 ]
283 }
284
285 fn repulsive_force(&self, i: usize) -> [f64; 2] {
286 let b = &self.bubbles[i];
287 if b.radius >= self.config.rho0 {
288 return [0.0, 0.0];
289 }
290
291 let h = self.config.step_size;
292 let grad_x = (self.distance_field.query(b.pos[0] - h, b.pos[1])
293 - self.distance_field.query(b.pos[0] + h, b.pos[1]))
294 / (2.0 * h);
295 let grad_y = (self.distance_field.query(b.pos[0], b.pos[1] - h)
296 - self.distance_field.query(b.pos[0], b.pos[1] + h))
297 / (2.0 * h);
298
299 let scale = self.config.kr * (self.config.rho0 - b.radius);
300 [scale * grad_x, scale * grad_y]
301 }
302
303 fn maintain_overlap(&mut self) {
304 let mut i = 0;
306 while i < self.bubbles.len().saturating_sub(1) {
307 let dist = dist2d(&self.bubbles[i].pos, &self.bubbles[i + 1].pos);
308 let threshold =
309 self.config.lambda * (self.bubbles[i].radius + self.bubbles[i + 1].radius);
310 if dist > threshold {
311 let mid = [
312 (self.bubbles[i].pos[0] + self.bubbles[i + 1].pos[0]) * 0.5,
313 (self.bubbles[i].pos[1] + self.bubbles[i + 1].pos[1]) * 0.5,
314 ];
315 let r = self
316 .distance_field
317 .query(mid[0], mid[1])
318 .clamp(self.config.min_bubble_radius, self.config.max_bubble_radius);
319 self.bubbles.insert(
320 i + 1,
321 Bubble {
322 pos: mid,
323 radius: r,
324 },
325 );
326 i += 2;
327 } else {
328 i += 1;
329 }
330 }
331
332 let mut i = 1;
334 while i < self.bubbles.len().saturating_sub(1) {
335 let prev = &self.bubbles[i - 1];
336 let next = &self.bubbles[i + 1];
337 let dist = dist2d(&prev.pos, &next.pos);
338 let threshold = self.config.lambda * (prev.radius + next.radius);
339 if dist <= threshold {
340 self.bubbles.remove(i);
341 } else {
342 i += 1;
343 }
344 }
345 }
346}
347
/// Euclidean distance between two 2-D points.
fn dist2d(a: &[f64; 2], b: &[f64; 2]) -> f64 {
    let dx = a[0] - b[0];
    let dy = a[1] - b[1];
    (dx * dx + dy * dy).sqrt()
}
351
#[cfg(test)]
mod tests {
    use super::*;

    /// 100×100 grid with a solid rectangular obstacle covering x ∈ [40, 60), y ∈ [35, 65).
    fn make_obstacle_grid() -> Vec<Vec<bool>> {
        let (w, h) = (100, 100);
        let mut grid = vec![vec![false; h]; w];
        for row in grid.iter_mut().take(60).skip(40) {
            for cell in row.iter_mut().take(65).skip(35) {
                *cell = true;
            }
        }
        grid
    }

    /// Total length of a polyline.
    fn path_len(pts: &[[f64; 2]]) -> f64 {
        pts.windows(2).map(|w| dist2d(&w[0], &w[1])).sum()
    }

    #[test]
    fn test_distance_field_basic() {
        let grid = make_obstacle_grid();
        let df = DistanceField::from_obstacle_grid(&grid);
        // Inside the obstacle the distance is zero.
        assert_eq!(df.query(50.0, 50.0), 0.0);
        // The corner cell is well clear of the obstacle.
        assert!(df.query(5.0, 5.0) > 10.0);
        assert_eq!(df.width(), 100);
        assert_eq!(df.height(), 100);
    }

    #[test]
    fn test_bubble_creation() {
        let config = ElasticBandsConfig::default();
        let (min_r, max_r) = (config.min_bubble_radius, config.max_bubble_radius);
        let df = DistanceField::from_obstacle_grid(&make_obstacle_grid());
        // One waypoint next to the obstacle (clearance below the minimum), one far from it.
        let band = ElasticBands::new(&[[10.0, 10.0], [39.0, 50.0], [90.0, 90.0]], df, config);
        assert!(!band.bubbles.is_empty());
        for bubble in &band.bubbles {
            assert!(
                bubble.radius >= min_r && bubble.radius <= max_r,
                "radius {} outside [{min_r}, {max_r}]",
                bubble.radius
            );
        }
    }

    #[test]
    fn test_elastic_bands_path_shortens() {
        let grid = make_obstacle_grid();
        let config = ElasticBandsConfig {
            max_iter: 30,
            ..Default::default()
        };

        // Detour arching over the obstacle.
        let initial_path = vec![
            [10.0, 50.0],
            [20.0, 80.0],
            [35.0, 90.0],
            [50.0, 85.0],
            [65.0, 90.0],
            [80.0, 80.0],
            [90.0, 50.0],
        ];
        let initial_length = path_len(&initial_path);

        let mut band = ElasticBands::from_obstacles(&initial_path, &grid, config);
        band.optimise();
        let final_length = path_len(&band.path());

        assert!(
            final_length < initial_length,
            "path should shorten: initial {initial_length:.1}, final {final_length:.1}"
        );
    }

    #[test]
    fn test_elastic_bands_endpoints_fixed() {
        let grid = make_obstacle_grid();
        let config = ElasticBandsConfig::default();

        let initial_path = vec![[10.0, 50.0], [50.0, 80.0], [90.0, 50.0]];
        let mut band = ElasticBands::from_obstacles(&initial_path, &grid, config);
        band.optimise();

        // Endpoints are copied verbatim on every step, so both coordinates must match exactly.
        let path = band.path();
        assert_eq!(path.first(), Some(&[10.0, 50.0]), "start should remain fixed");
        assert_eq!(path.last(), Some(&[90.0, 50.0]), "end should remain fixed");
    }

    #[test]
    fn test_maintain_overlap_inserts_bubbles() {
        let grid = vec![vec![false; 100]; 100];
        let config = ElasticBandsConfig::default();

        let initial_path = vec![[0.0, 0.0], [99.0, 99.0]];
        let band = ElasticBands::from_obstacles(&initial_path, &grid, config);
        assert!(
            band.bubbles.len() > 2,
            "bubbles should be inserted between distant endpoints"
        );
    }

    #[test]
    fn test_no_obstacle_path_contracts() {
        let grid = vec![vec![false; 100]; 100];
        let config = ElasticBandsConfig {
            max_iter: 20,
            ..Default::default()
        };

        let initial_path = vec![
            [10.0, 10.0],
            [30.0, 50.0],
            [50.0, 10.0],
            [70.0, 50.0],
            [90.0, 10.0],
        ];
        let initial_length = path_len(&initial_path);

        let mut band = ElasticBands::from_obstacles(&initial_path, &grid, config);
        band.optimise();

        assert!(path_len(&band.path()) < initial_length);
    }

    #[test]
    fn test_degenerate_paths() {
        let grid = vec![vec![false; 20]; 20];

        let mut empty = ElasticBands::from_obstacles(&[], &grid, ElasticBandsConfig::default());
        empty.optimise();
        assert!(empty.path().is_empty());

        let mut single =
            ElasticBands::from_obstacles(&[[3.0, 4.0]], &grid, ElasticBandsConfig::default());
        single.optimise();
        assert_eq!(single.path(), vec![[3.0, 4.0]]);
    }
}