Skip to main content

rust_robotics_planning/
prm_star.rs

1#![allow(dead_code, clippy::too_many_arguments)]
2
//! Probabilistic Road-Map Star (PRM*) path planning.
//!
//! Uses the asymptotically optimal connection radius
//! `r_n = gamma * (log(n) / n)^(1/d)` with `d = 2`, scaled by the diagonal of
//! the obstacle bounding box (at least 1.0) and floored at `1e-3`.
7
8use rand::Rng;
9use std::cmp::Ordering;
10use std::collections::{BinaryHeap, HashMap};
11
/// Configuration for PRM*.
///
/// Check a configuration with [`PRMStarConfig::validate`]; the planner
/// constructor panics when handed an invalid one.
#[derive(Debug, Clone)]
pub struct PRMStarConfig {
    /// Number of free-space points to sample (start and goal are extra).
    pub n_samples: usize,
    /// Robot radius used for collision checking against obstacle points.
    pub robot_radius: f64,
    /// Scale factor `gamma` in the PRM* connection-radius formula.
    pub gamma: f64,
}

impl Default for PRMStarConfig {
    /// 500 samples, robot radius 0.8 and `gamma = 2.5`.
    fn default() -> Self {
        PRMStarConfig {
            n_samples: 500,
            robot_radius: 0.8,
            gamma: 2.5,
        }
    }
}

impl PRMStarConfig {
    /// Validate PRM* configuration.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message for the first violated rule:
    /// `n_samples` must be non-zero, and `robot_radius` and `gamma` must both
    /// be finite and strictly positive.
    pub fn validate(&self) -> Result<(), String> {
        // NaN and infinities fail `is_finite`, so they are rejected here too.
        let positive_finite = |value: f64| value.is_finite() && value > 0.0;

        if self.n_samples == 0 {
            Err("PRM* requires at least one sample".to_string())
        } else if !positive_finite(self.robot_radius) {
            Err("PRM* robot_radius must be positive and finite".to_string())
        } else if !positive_finite(self.gamma) {
            Err("PRM* gamma must be positive and finite".to_string())
        } else {
            Ok(())
        }
    }
}
48
/// Per-vertex search state used while running Dijkstra over the road map.
#[derive(Clone)]
struct Node {
    /// Vertex x coordinate.
    x: f64,
    /// Vertex y coordinate.
    y: f64,
    /// Best known path cost from the start; `f64::INFINITY` until reached.
    cost: f64,
    /// Index of the predecessor on the best known path, if any.
    parent: Option<usize>,
}

impl Node {
    /// Create an unreached node at `(x, y)`: infinite cost and no parent.
    fn new(x: f64, y: f64) -> Self {
        let cost = f64::INFINITY;
        let parent = None;
        Node { x, y, cost, parent }
    }
}
67
/// Entry in Dijkstra's open set: vertex `index` with tentative `cost`.
///
/// The ordering is reversed on `cost` so that [`BinaryHeap`] (a max-heap)
/// pops the cheapest entry first. Costs are compared with `f64::total_cmp`
/// and ties are broken on `index`. As a result, `Ord` is a genuine total
/// order and agrees with `PartialEq`/`Eq`, as `BinaryHeap` requires.
#[derive(Clone)]
struct QueueItem {
    cost: f64,
    index: usize,
}

impl PartialEq for QueueItem {
    fn eq(&self, other: &Self) -> bool {
        // Defined via `cmp` so equality can never disagree with the ordering
        // (a raw `==` on f64 would make NaN unequal to itself).
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for QueueItem {}

impl Ord for QueueItem {
    fn cmp(&self, other: &Self) -> Ordering {
        // `other` vs `self` (not the reverse) turns the max-heap into a
        // min-heap on cost; the index tie-break keeps the order total.
        other
            .cost
            .total_cmp(&self.cost)
            .then_with(|| other.index.cmp(&self.index))
    }
}

impl PartialOrd for QueueItem {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
96
/// Point set answering radius and nearest-distance queries.
///
/// Despite the name this is not a k-d tree. Both queries are brute-force
/// linear scans over every stored point, so each query is O(n).
struct KDTree {
    points: Vec<(f64, f64)>,
}

impl KDTree {
    /// Wrap `points` as-is; no spatial index is built.
    fn new(points: Vec<(f64, f64)>) -> Self {
        KDTree { points }
    }

    /// Return `(index, distance)` for every stored point within `radius`
    /// (inclusive) of `(x, y)`, in storage order.
    fn query_radius(&self, x: f64, y: f64, radius: f64) -> Vec<(usize, f64)> {
        // Compare squared distances; only take the root for accepted points.
        let radius_sq = radius * radius;
        let mut hits = Vec::new();
        for (idx, &(px, py)) in self.points.iter().enumerate() {
            let (dx, dy) = (x - px, y - py);
            let dist_sq = dx * dx + dy * dy;
            if dist_sq <= radius_sq {
                hits.push((idx, dist_sq.sqrt()));
            }
        }
        hits
    }

    /// Euclidean distance from `(x, y)` to the closest stored point, or
    /// `f64::INFINITY` when no points are stored.
    fn min_distance(&self, x: f64, y: f64) -> f64 {
        let mut nearest = f64::INFINITY;
        for &(px, py) in &self.points {
            let dist = ((x - px).powi(2) + (y - py).powi(2)).sqrt();
            nearest = f64::min(nearest, dist);
        }
        nearest
    }
}
131
132/// PRM* planner.
133pub struct PRMStarPlanner {
134    sample_x: Vec<f64>,
135    sample_y: Vec<f64>,
136    road_map: Vec<Vec<usize>>,
137    connection_radius: f64,
138}
139
140impl PRMStarPlanner {
141    /// Create a new PRM* planner.
142    pub fn new(
143        ox: &[f64],
144        oy: &[f64],
145        start: (f64, f64),
146        goal: (f64, f64),
147        config: PRMStarConfig,
148    ) -> Self {
149        config.validate().expect(
150            "invalid PRM* configuration: n_samples > 0, robot_radius > 0, gamma > 0 required",
151        );
152
153        let obstacle_tree = KDTree::new(ox.iter().zip(oy.iter()).map(|(&x, &y)| (x, y)).collect());
154        let min_x = ox.iter().copied().fold(f64::INFINITY, f64::min);
155        let max_x = ox.iter().copied().fold(f64::NEG_INFINITY, f64::max);
156        let min_y = oy.iter().copied().fold(f64::INFINITY, f64::min);
157        let max_y = oy.iter().copied().fold(f64::NEG_INFINITY, f64::max);
158
159        let (sample_x, sample_y) = Self::sample_points(
160            start,
161            goal,
162            min_x,
163            max_x,
164            min_y,
165            max_y,
166            config.n_samples,
167            config.robot_radius,
168            &obstacle_tree,
169        );
170
171        let workspace_scale = ((max_x - min_x).powi(2) + (max_y - min_y).powi(2))
172            .sqrt()
173            .max(1.0);
174        let connection_radius =
175            Self::compute_connection_radius(sample_x.len(), workspace_scale, config.gamma);
176        let road_map = Self::generate_road_map(
177            &sample_x,
178            &sample_y,
179            config.robot_radius,
180            connection_radius,
181            &obstacle_tree,
182        );
183
184        Self {
185            sample_x,
186            sample_y,
187            road_map,
188            connection_radius,
189        }
190    }
191
192    fn compute_connection_radius(n: usize, workspace_scale: f64, gamma: f64) -> f64 {
193        let n_f = n as f64;
194        let radius_normalized = gamma * (n_f.ln() / n_f).sqrt();
195        (radius_normalized * workspace_scale).max(1e-3)
196    }
197
198    fn sample_points(
199        start: (f64, f64),
200        goal: (f64, f64),
201        min_x: f64,
202        max_x: f64,
203        min_y: f64,
204        max_y: f64,
205        n_samples: usize,
206        robot_radius: f64,
207        obstacle_tree: &KDTree,
208    ) -> (Vec<f64>, Vec<f64>) {
209        let mut rng = rand::rng();
210        let mut sample_x = Vec::with_capacity(n_samples + 2);
211        let mut sample_y = Vec::with_capacity(n_samples + 2);
212
213        while sample_x.len() < n_samples {
214            let x = rng.random_range(min_x..max_x);
215            let y = rng.random_range(min_y..max_y);
216            if obstacle_tree.min_distance(x, y) > robot_radius {
217                sample_x.push(x);
218                sample_y.push(y);
219            }
220        }
221
222        sample_x.push(start.0);
223        sample_y.push(start.1);
224        sample_x.push(goal.0);
225        sample_y.push(goal.1);
226
227        (sample_x, sample_y)
228    }
229
230    fn generate_road_map(
231        sample_x: &[f64],
232        sample_y: &[f64],
233        robot_radius: f64,
234        connection_radius: f64,
235        obstacle_tree: &KDTree,
236    ) -> Vec<Vec<usize>> {
237        let sample_tree = KDTree::new(
238            sample_x
239                .iter()
240                .zip(sample_y.iter())
241                .map(|(&x, &y)| (x, y))
242                .collect(),
243        );
244        let mut road_map: Vec<Vec<usize>> = vec![Vec::new(); sample_x.len()];
245
246        for (i, (&x, &y)) in sample_x.iter().zip(sample_y.iter()).enumerate() {
247            for (j, dist) in sample_tree.query_radius(x, y, connection_radius) {
248                if i == j {
249                    continue;
250                }
251                if !Self::is_collision(x, y, sample_x[j], sample_y[j], robot_radius, obstacle_tree)
252                {
253                    road_map[i].push(j);
254                } else {
255                    let _ = dist;
256                }
257            }
258        }
259
260        road_map
261    }
262
263    fn is_collision(
264        x1: f64,
265        y1: f64,
266        x2: f64,
267        y2: f64,
268        robot_radius: f64,
269        obstacle_tree: &KDTree,
270    ) -> bool {
271        let dx = x2 - x1;
272        let dy = y2 - y1;
273        let d = (dx * dx + dy * dy).sqrt();
274
275        if d == 0.0 {
276            return false;
277        }
278
279        let step = robot_radius;
280        let n_steps = (d / step).ceil() as usize;
281        for i in 0..=n_steps {
282            let t = i as f64 / n_steps as f64;
283            let x = x1 + t * dx;
284            let y = y1 + t * dy;
285            if obstacle_tree.min_distance(x, y) <= robot_radius {
286                return true;
287            }
288        }
289        false
290    }
291
292    /// Plan path using Dijkstra's algorithm.
293    pub fn plan(&self) -> Option<(Vec<f64>, Vec<f64>)> {
294        let n = self.sample_x.len();
295        let start_idx = n - 2;
296        let goal_idx = n - 1;
297
298        let mut nodes: Vec<Node> = self
299            .sample_x
300            .iter()
301            .zip(self.sample_y.iter())
302            .map(|(&x, &y)| Node::new(x, y))
303            .collect();
304        nodes[start_idx].cost = 0.0;
305
306        let mut open_set = BinaryHeap::new();
307        open_set.push(QueueItem {
308            cost: 0.0,
309            index: start_idx,
310        });
311        let mut closed_set: HashMap<usize, bool> = HashMap::new();
312
313        while let Some(current) = open_set.pop() {
314            if current.index == goal_idx {
315                return Some(self.reconstruct_path(&nodes, goal_idx));
316            }
317            if closed_set.contains_key(&current.index) {
318                continue;
319            }
320            closed_set.insert(current.index, true);
321
322            for &neighbor_idx in &self.road_map[current.index] {
323                if closed_set.contains_key(&neighbor_idx) {
324                    continue;
325                }
326                let dx = nodes[neighbor_idx].x - nodes[current.index].x;
327                let dy = nodes[neighbor_idx].y - nodes[current.index].y;
328                let edge_cost = (dx * dx + dy * dy).sqrt();
329                let new_cost = nodes[current.index].cost + edge_cost;
330                if new_cost < nodes[neighbor_idx].cost {
331                    nodes[neighbor_idx].cost = new_cost;
332                    nodes[neighbor_idx].parent = Some(current.index);
333                    open_set.push(QueueItem {
334                        cost: new_cost,
335                        index: neighbor_idx,
336                    });
337                }
338            }
339        }
340
341        None
342    }
343
344    fn reconstruct_path(&self, nodes: &[Node], goal_idx: usize) -> (Vec<f64>, Vec<f64>) {
345        let mut path_x = Vec::new();
346        let mut path_y = Vec::new();
347        let mut current = goal_idx;
348
349        while let Some(parent) = nodes[current].parent {
350            path_x.push(nodes[current].x);
351            path_y.push(nodes[current].y);
352            current = parent;
353        }
354        path_x.push(nodes[current].x);
355        path_y.push(nodes[current].y);
356        path_x.reverse();
357        path_y.reverse();
358
359        (path_x, path_y)
360    }
361
362    /// Return sample points.
363    pub fn get_samples(&self) -> (&[f64], &[f64]) {
364        (&self.sample_x, &self.sample_y)
365    }
366
367    /// Return current connection radius.
368    pub fn connection_radius(&self) -> f64 {
369        self.connection_radius
370    }
371}
372
#[cfg(test)]
mod tests {
    use super::*;

    /// Obstacle points along the border of the square `[0, size] x [0, size]`.
    fn rectangular_walls(size: usize) -> (Vec<f64>, Vec<f64>) {
        let mut ox = Vec::new();
        let mut oy = Vec::new();
        for i in 0..=size {
            let v = i as f64;
            ox.push(v);
            oy.push(0.0);
            ox.push(v);
            oy.push(size as f64);
            ox.push(0.0);
            oy.push(v);
            ox.push(size as f64);
            oy.push(v);
        }
        (ox, oy)
    }

    /// Total Euclidean length of the polyline through `(xs[i], ys[i])`.
    fn path_length(xs: &[f64], ys: &[f64]) -> f64 {
        xs.windows(2)
            .zip(ys.windows(2))
            .map(|(wx, wy)| {
                let dx = wx[1] - wx[0];
                let dy = wy[1] - wy[0];
                (dx * dx + dy * dy).sqrt()
            })
            .sum()
    }

    /// Configuration shared by the planner tests, varying only the sample count.
    fn config_with_samples(n_samples: usize) -> PRMStarConfig {
        PRMStarConfig {
            n_samples,
            robot_radius: 0.8,
            gamma: 2.5,
        }
    }

    #[test]
    fn test_prm_star_finds_path() {
        let (ox, oy) = rectangular_walls(30);
        let start = (2.0, 2.0);
        let goal = (28.0, 28.0);
        let planner = PRMStarPlanner::new(&ox, &oy, start, goal, config_with_samples(450));

        let (px, py) = planner
            .plan()
            .expect("PRM* should find a path in free interior");
        assert_eq!(px.len(), py.len());
        assert!(px.len() >= 2);

        // The path must run from the start to the goal.
        assert_eq!((px[0], py[0]), start);
        assert_eq!((*px.last().unwrap(), *py.last().unwrap()), goal);

        // No path can be shorter than the straight line between its endpoints.
        let straight = (goal.0 - start.0).hypot(goal.1 - start.1);
        assert!(path_length(&px, &py) >= straight - 1e-9);
    }

    #[test]
    fn test_prm_star_connection_radius_shrinks_with_more_samples() {
        let (ox, oy) = rectangular_walls(20);
        let planner_low =
            PRMStarPlanner::new(&ox, &oy, (2.0, 2.0), (18.0, 18.0), config_with_samples(60));
        let planner_high =
            PRMStarPlanner::new(&ox, &oy, (2.0, 2.0), (18.0, 18.0), config_with_samples(200));

        // Vertex count is n_samples plus the start and goal.
        assert_eq!(planner_low.get_samples().0.len(), 62);
        assert_eq!(planner_high.get_samples().0.len(), 202);

        // gamma * sqrt(ln(n) / n) decreases in n, so denser maps use a
        // smaller neighbourhood; this is deterministic despite random sampling.
        assert!(planner_high.connection_radius() < planner_low.connection_radius());

        assert!(
            planner_low.plan().is_some() || planner_high.plan().is_some(),
            "at least one configuration should find a path"
        );
    }

    #[test]
    fn test_prm_star_config_defaults() {
        let config = PRMStarConfig::default();
        assert_eq!(config.n_samples, 500);
        assert!((config.robot_radius - 0.8).abs() < f64::EPSILON);
        assert!((config.gamma - 2.5).abs() < f64::EPSILON);
    }

    #[test]
    fn test_prm_star_config_validation() {
        let base = PRMStarConfig::default();
        assert!(base.validate().is_ok());
        assert!(PRMStarConfig {
            n_samples: 0,
            ..base.clone()
        }
        .validate()
        .is_err());
        assert!(PRMStarConfig {
            robot_radius: f64::NAN,
            ..base.clone()
        }
        .validate()
        .is_err());
        assert!(PRMStarConfig {
            gamma: -1.0,
            ..base.clone()
        }
        .validate()
        .is_err());
    }
}