Skip to main content

rust_robotics_planning/
prm.rs

1#![allow(dead_code, clippy::too_many_arguments)]
2
3//! Probabilistic Road-Map (PRM) path planning
4//!
5//! Sampling-based planner that builds a roadmap of collision-free
6//! configurations and searches for a path using Dijkstra's algorithm.
7
8use rand::Rng;
9use std::cmp::Ordering;
10use std::collections::{BinaryHeap, HashMap};
11
/// How many collision-free random samples are drawn for the roadmap
/// (the start and goal vertices are appended on top of these).
const N_SAMPLE: usize = 500;
/// Neighbour count passed to the k-nearest-neighbour query when wiring edges.
const N_KNN: usize = 10;
/// Candidate edges longer than this are never added to the roadmap.
const MAX_EDGE_LEN: f64 = 30.0;
16
/// Per-vertex bookkeeping for the Dijkstra search over the roadmap.
///
/// `cost` holds the best known path length from the start vertex (infinite
/// until reached) and `parent` the predecessor index on that path, if any.
#[derive(Clone)]
struct Node {
    x: f64,
    y: f64,
    cost: f64,
    parent: Option<usize>,
}

impl Node {
    /// Creates an unreached node at `(x, y)`: infinite cost and no parent.
    fn new(x: f64, y: f64) -> Self {
        Self {
            parent: None,
            cost: f64::INFINITY,
            x,
            y,
        }
    }
}
36
/// Entry in the Dijkstra open set.
///
/// [`BinaryHeap`] is a max-heap, so the ordering is reversed on `cost`: the
/// entry with the smallest cost compares as the greatest and is popped first.
/// Equal costs are broken by `index` (smaller index pops first) so the pop
/// order is deterministic.
#[derive(Clone)]
struct QueueItem {
    /// Tentative path cost from the start vertex to `index`.
    cost: f64,
    /// Index of the roadmap vertex this entry refers to.
    index: usize,
}

impl PartialEq for QueueItem {
    fn eq(&self, other: &Self) -> bool {
        // Derived from `cmp` so `Eq` and `Ord` agree, as the `Ord` contract requires.
        // Plain `f64 ==` would also make NaN unequal to itself, breaking `Eq`'s
        // reflexivity.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for QueueItem {}

impl Ord for QueueItem {
    fn cmp(&self, other: &Self) -> Ordering {
        // `total_cmp` is a total order over every f64 (NaN included).
        // `partial_cmp(..).unwrap_or(Equal)` is not transitive once a NaN appears.
        other
            .cost
            .total_cmp(&self.cost)
            .then_with(|| other.index.cmp(&self.index))
    }
}

impl PartialOrd for QueueItem {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
66
/// Point set answering nearest-neighbour queries.
///
/// Despite the name this is not a spatial tree. Every query is a brute-force
/// linear scan over `points`, so each call costs O(n).
struct KDTree {
    points: Vec<(f64, f64)>,
}

impl KDTree {
    /// Wraps `points`; indices returned by queries refer to positions in it.
    fn new(points: Vec<(f64, f64)>) -> Self {
        KDTree { points }
    }

    /// Returns up to `k + 1` `(index, distance)` pairs for the stored points
    /// closest to `(x, y)`, sorted by ascending Euclidean distance.
    /// Equal distances are ordered by ascending index.
    ///
    /// One extra result is returned on purpose: when `(x, y)` is itself a stored
    /// point it normally comes back first with distance 0, and the caller skips it.
    /// NaN coordinates do not cause a panic.
    fn query_knn(&self, x: f64, y: f64, k: usize) -> Vec<(usize, f64)> {
        let mut distances: Vec<(usize, f64)> = self
            .points
            .iter()
            .enumerate()
            .map(|(i, (px, py))| (i, ((x - px).powi(2) + (y - py).powi(2)).sqrt()))
            .collect();

        // (distance, index) is a strict total order. The kept prefix is therefore
        // exactly what a stable sort by distance (in enumeration order) would yield.
        let by_dist =
            |a: &(usize, f64), b: &(usize, f64)| a.1.total_cmp(&b.1).then(a.0.cmp(&b.0));

        let keep = k.saturating_add(1);
        if keep < distances.len() {
            // O(n) partition: only the `keep` smallest entries need to be sorted.
            distances.select_nth_unstable_by(keep, by_dist);
            distances.truncate(keep);
        }
        distances.sort_unstable_by(by_dist);
        distances
    }

    /// Euclidean distance from `(x, y)` to the closest stored point, or
    /// `f64::INFINITY` when the set is empty.
    fn min_distance(&self, x: f64, y: f64) -> f64 {
        self.points
            .iter()
            .map(|(px, py)| ((x - px).powi(2) + (y - py).powi(2)).sqrt())
            .fold(f64::INFINITY, f64::min)
    }
}
100
/// Probabilistic road-map planner.
///
/// Stores the sampled vertices and the directed adjacency list built over
/// them. The last two vertices are always the start and the goal.
#[derive(Debug, Clone)]
pub struct PRMPlanner {
    /// x-coordinates of the roadmap vertices: random samples, then start, then goal.
    sample_x: Vec<f64>,
    /// y-coordinates, parallel to `sample_x`.
    sample_y: Vec<f64>,
    /// `road_map[i]` lists the vertices reachable from vertex `i` by one edge.
    road_map: Vec<Vec<usize>>,
}
107
108impl PRMPlanner {
109    /// Create a new PRM planner
110    pub fn new(
111        ox: &[f64],
112        oy: &[f64],
113        start: (f64, f64),
114        goal: (f64, f64),
115        robot_radius: f64,
116    ) -> Self {
117        let obstacle_tree = KDTree::new(ox.iter().zip(oy.iter()).map(|(&x, &y)| (x, y)).collect());
118
119        let min_x = ox.iter().cloned().fold(f64::INFINITY, f64::min);
120        let max_x = ox.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
121        let min_y = oy.iter().cloned().fold(f64::INFINITY, f64::min);
122        let max_y = oy.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
123
124        let (sample_x, sample_y) = Self::sample_points(
125            start,
126            goal,
127            min_x,
128            max_x,
129            min_y,
130            max_y,
131            robot_radius,
132            &obstacle_tree,
133        );
134
135        let road_map = Self::generate_road_map(&sample_x, &sample_y, robot_radius, &obstacle_tree);
136
137        PRMPlanner {
138            sample_x,
139            sample_y,
140            road_map,
141        }
142    }
143
144    fn sample_points(
145        start: (f64, f64),
146        goal: (f64, f64),
147        min_x: f64,
148        max_x: f64,
149        min_y: f64,
150        max_y: f64,
151        robot_radius: f64,
152        obstacle_tree: &KDTree,
153    ) -> (Vec<f64>, Vec<f64>) {
154        let mut rng = rand::rng();
155        let mut sample_x = Vec::new();
156        let mut sample_y = Vec::new();
157
158        while sample_x.len() < N_SAMPLE {
159            let x = rng.random_range(min_x..max_x);
160            let y = rng.random_range(min_y..max_y);
161
162            let min_dist = obstacle_tree.min_distance(x, y);
163            if min_dist > robot_radius {
164                sample_x.push(x);
165                sample_y.push(y);
166            }
167        }
168
169        sample_x.push(start.0);
170        sample_y.push(start.1);
171        sample_x.push(goal.0);
172        sample_y.push(goal.1);
173
174        (sample_x, sample_y)
175    }
176
177    fn generate_road_map(
178        sample_x: &[f64],
179        sample_y: &[f64],
180        robot_radius: f64,
181        obstacle_tree: &KDTree,
182    ) -> Vec<Vec<usize>> {
183        let sample_tree = KDTree::new(
184            sample_x
185                .iter()
186                .zip(sample_y.iter())
187                .map(|(&x, &y)| (x, y))
188                .collect(),
189        );
190
191        let mut road_map: Vec<Vec<usize>> = vec![Vec::new(); sample_x.len()];
192
193        for (i, (&x, &y)) in sample_x.iter().zip(sample_y.iter()).enumerate() {
194            let neighbors = sample_tree.query_knn(x, y, N_KNN);
195
196            for (j, dist) in neighbors {
197                if i == j {
198                    continue;
199                }
200
201                if dist > MAX_EDGE_LEN {
202                    continue;
203                }
204
205                if !Self::is_collision(x, y, sample_x[j], sample_y[j], robot_radius, obstacle_tree)
206                {
207                    road_map[i].push(j);
208                }
209            }
210        }
211
212        road_map
213    }
214
215    fn is_collision(
216        x1: f64,
217        y1: f64,
218        x2: f64,
219        y2: f64,
220        robot_radius: f64,
221        obstacle_tree: &KDTree,
222    ) -> bool {
223        let dx = x2 - x1;
224        let dy = y2 - y1;
225        let d = (dx * dx + dy * dy).sqrt();
226
227        if d == 0.0 {
228            return false;
229        }
230
231        let step = robot_radius;
232        let n_steps = (d / step).ceil() as usize;
233
234        for i in 0..=n_steps {
235            let t = i as f64 / n_steps as f64;
236            let x = x1 + t * dx;
237            let y = y1 + t * dy;
238
239            let min_dist = obstacle_tree.min_distance(x, y);
240            if min_dist <= robot_radius {
241                return true;
242            }
243        }
244
245        false
246    }
247
248    /// Plan path using Dijkstra's algorithm
249    pub fn plan(&self) -> Option<(Vec<f64>, Vec<f64>)> {
250        let n = self.sample_x.len();
251        let start_idx = n - 2;
252        let goal_idx = n - 1;
253
254        let mut nodes: Vec<Node> = self
255            .sample_x
256            .iter()
257            .zip(self.sample_y.iter())
258            .map(|(&x, &y)| Node::new(x, y))
259            .collect();
260
261        nodes[start_idx].cost = 0.0;
262
263        let mut open_set = BinaryHeap::new();
264        open_set.push(QueueItem {
265            cost: 0.0,
266            index: start_idx,
267        });
268
269        let mut closed_set: HashMap<usize, bool> = HashMap::new();
270
271        while let Some(current) = open_set.pop() {
272            if current.index == goal_idx {
273                return Some(self.reconstruct_path(&nodes, goal_idx));
274            }
275
276            if closed_set.contains_key(&current.index) {
277                continue;
278            }
279            closed_set.insert(current.index, true);
280
281            for &neighbor_idx in &self.road_map[current.index] {
282                if closed_set.contains_key(&neighbor_idx) {
283                    continue;
284                }
285
286                let dx = nodes[neighbor_idx].x - nodes[current.index].x;
287                let dy = nodes[neighbor_idx].y - nodes[current.index].y;
288                let edge_cost = (dx * dx + dy * dy).sqrt();
289                let new_cost = nodes[current.index].cost + edge_cost;
290
291                if new_cost < nodes[neighbor_idx].cost {
292                    nodes[neighbor_idx].cost = new_cost;
293                    nodes[neighbor_idx].parent = Some(current.index);
294                    open_set.push(QueueItem {
295                        cost: new_cost,
296                        index: neighbor_idx,
297                    });
298                }
299            }
300        }
301
302        None
303    }
304
305    fn reconstruct_path(&self, nodes: &[Node], goal_idx: usize) -> (Vec<f64>, Vec<f64>) {
306        let mut path_x = Vec::new();
307        let mut path_y = Vec::new();
308
309        let mut current = goal_idx;
310        while let Some(parent) = nodes[current].parent {
311            path_x.push(nodes[current].x);
312            path_y.push(nodes[current].y);
313            current = parent;
314        }
315        path_x.push(nodes[current].x);
316        path_y.push(nodes[current].y);
317
318        path_x.reverse();
319        path_y.reverse();
320
321        (path_x, path_y)
322    }
323
324    /// Get sample points for external inspection
325    pub fn get_samples(&self) -> (&[f64], &[f64]) {
326        (&self.sample_x, &self.sample_y)
327    }
328
329    /// Get road map edges for external inspection
330    pub fn get_edges(&self) -> Vec<((f64, f64), (f64, f64))> {
331        let mut edges = Vec::new();
332
333        for (i, neighbors) in self.road_map.iter().enumerate() {
334            for &j in neighbors {
335                if i < j {
336                    edges.push((
337                        (self.sample_x[i], self.sample_y[i]),
338                        (self.sample_x[j], self.sample_y[j]),
339                    ));
340                }
341            }
342        }
343
344        edges
345    }
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// Unit-spaced point obstacles along the border of the square [0, 20] x [0, 20].
    fn boundary_obstacles() -> (Vec<f64>, Vec<f64>) {
        let mut ox = Vec::new();
        let mut oy = Vec::new();
        for i in 0..20 {
            let v = i as f64;
            ox.extend([v, v, 0.0, 20.0]);
            oy.extend([0.0, 20.0, v, v]);
        }
        (ox, oy)
    }

    #[test]
    fn test_prm_creation() {
        let (ox, oy) = boundary_obstacles();
        let prm = PRMPlanner::new(&ox, &oy, (2.0, 2.0), (18.0, 18.0), 2.0);
        let (sx, sy) = prm.get_samples();
        assert!(!sx.is_empty());
        assert_eq!(sx.len(), sy.len());
        // Random samples come first; start and goal are always the last two entries.
        assert_eq!(sx.len(), N_SAMPLE + 2);
        assert_eq!((sx[sx.len() - 2], sy[sy.len() - 2]), (2.0, 2.0));
        assert_eq!((sx[sx.len() - 1], sy[sy.len() - 1]), (18.0, 18.0));
    }

    #[test]
    fn test_prm_plan_connects_start_and_goal() {
        let (ox, oy) = boundary_obstacles();
        let start = (5.0, 5.0);
        let goal = (15.0, 15.0);
        let prm = PRMPlanner::new(&ox, &oy, start, goal, 2.0);
        // Sampling is random, but 500 samples in an open 16x16 free region are
        // connected with overwhelming probability.
        let (px, py) = prm
            .plan()
            .expect("an open square with 500 samples should yield a path");
        assert_eq!(px.len(), py.len());
        assert_eq!((px[0], py[0]), start);
        assert_eq!((px[px.len() - 1], py[py.len() - 1]), goal);
    }

    #[test]
    fn test_edges_respect_max_length() {
        let (ox, oy) = boundary_obstacles();
        let prm = PRMPlanner::new(&ox, &oy, (5.0, 5.0), (15.0, 15.0), 2.0);
        let edges = prm.get_edges();
        assert!(!edges.is_empty());
        for ((x1, y1), (x2, y2)) in edges {
            assert!((x2 - x1).hypot(y2 - y1) <= MAX_EDGE_LEN);
        }
    }
}
372}