1#![allow(dead_code, clippy::too_many_arguments)]
2
3use rand::Rng;
9use std::cmp::Ordering;
10use std::collections::{BinaryHeap, HashMap};
11
/// Tuning parameters for [`PRMStarPlanner`].
#[derive(Debug, Clone, PartialEq)]
pub struct PRMStarConfig {
    /// Number of collision-free random samples to draw; the start and goal are
    /// added to the roadmap on top of these.
    pub n_samples: usize,
    /// Clearance every sample and roadmap edge must keep from all obstacle
    /// points; also the maximum spacing between collision probes along an edge.
    pub robot_radius: f64,
    /// Scale factor of the PRM* connection radius
    /// `gamma * sqrt(ln(n) / n) * workspace_diagonal`.
    pub gamma: f64,
}

impl Default for PRMStarConfig {
    /// 500 samples, robot radius 0.8, gamma 2.5.
    fn default() -> Self {
        PRMStarConfig {
            n_samples: 500,
            robot_radius: 0.8,
            gamma: 2.5,
        }
    }
}

impl PRMStarConfig {
    /// Checks that every parameter is usable by the planner.
    ///
    /// # Errors
    ///
    /// Returns a message describing the first violated constraint, checked in
    /// this order: `n_samples` must be non-zero, then `robot_radius` and then
    /// `gamma` must each be finite and strictly positive (NaN is rejected).
    pub fn validate(&self) -> Result<(), String> {
        let positive_finite = |value: f64| value.is_finite() && value > 0.0;

        if self.n_samples == 0 {
            Err("PRM* requires at least one sample".to_string())
        } else if !positive_finite(self.robot_radius) {
            Err("PRM* robot_radius must be positive and finite".to_string())
        } else if !positive_finite(self.gamma) {
            Err("PRM* gamma must be positive and finite".to_string())
        } else {
            Ok(())
        }
    }
}
48
/// Per-sample search state for the Dijkstra pass over the roadmap.
#[derive(Clone)]
struct Node {
    /// Sample x coordinate.
    x: f64,
    /// Sample y coordinate.
    y: f64,
    /// Best known path cost from the start; `f64::INFINITY` until reached.
    cost: f64,
    /// Index of the predecessor on the best known path, if any.
    parent: Option<usize>,
}

impl Node {
    /// Creates a node at `(x, y)` that has not been reached yet.
    fn new(x: f64, y: f64) -> Self {
        let unreached_cost = f64::INFINITY;
        Node {
            parent: None,
            cost: unreached_cost,
            y,
            x,
        }
    }
}
67
/// Entry of the Dijkstra open set: a node index keyed by its tentative cost.
///
/// The ordering is reversed on `cost` so that `BinaryHeap` (a max-heap) pops
/// the cheapest entry first; equal costs prefer the lower `index`. Costs are
/// compared with `f64::total_cmp`, so `Eq`/`Ord` stay lawful (reflexive and
/// total) even if a NaN cost ever appears.
#[derive(Clone, Debug)]
struct QueueItem {
    cost: f64,
    index: usize,
}

impl PartialEq for QueueItem {
    fn eq(&self, other: &Self) -> bool {
        // Defined via `cmp` so equality is always consistent with the ordering.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for QueueItem {}

impl Ord for QueueItem {
    fn cmp(&self, other: &Self) -> Ordering {
        // `other` vs `self` (not the reverse) turns the max-heap into a min-heap.
        other
            .cost
            .total_cmp(&self.cost)
            .then_with(|| other.index.cmp(&self.index))
    }
}

impl PartialOrd for QueueItem {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
96
/// Point set supporting radius and nearest-distance queries.
///
/// Despite its name this is not a spatial tree: every query is a linear scan
/// over all stored points, i.e. O(n) per call.
struct KDTree {
    points: Vec<(f64, f64)>,
}

impl KDTree {
    /// Wraps `points`; indices returned by queries refer to positions in it.
    fn new(points: Vec<(f64, f64)>) -> Self {
        Self { points }
    }

    /// Returns `(index, distance)` for every point within `radius` (inclusive)
    /// of `(x, y)`, in ascending index order.
    ///
    /// A negative or NaN `radius` matches nothing (previously a negative radius
    /// was squared and silently behaved like its absolute value).
    fn query_radius(&self, x: f64, y: f64, radius: f64) -> Vec<(usize, f64)> {
        if radius.is_nan() || radius < 0.0 {
            return Vec::new();
        }
        let r2 = radius * radius;
        self.points
            .iter()
            .enumerate()
            .filter_map(|(i, &(px, py))| {
                let dx = x - px;
                let dy = y - py;
                let d2 = dx * dx + dy * dy;
                (d2 <= r2).then(|| (i, d2.sqrt()))
            })
            .collect()
    }

    /// Euclidean distance from `(x, y)` to the closest stored point, or
    /// `f64::INFINITY` when the set is empty.
    fn min_distance(&self, x: f64, y: f64) -> f64 {
        // Minimise the squared distance and take one square root at the end;
        // sqrt is monotonic, so this equals taking the minimum of the roots.
        self.points
            .iter()
            .map(|&(px, py)| {
                let dx = x - px;
                let dy = y - py;
                dx * dx + dy * dy
            })
            .fold(f64::INFINITY, f64::min)
            .sqrt()
    }
}
131
/// Probabilistic roadmap planner using the asymptotically optimal PRM*
/// connection radius.
///
/// The roadmap is built once by [`PRMStarPlanner::new`]; call
/// [`PRMStarPlanner::plan`] to search it for a start-to-goal path.
pub struct PRMStarPlanner {
    /// Radius within which pairs of samples are considered for an edge.
    connection_radius: f64,
    /// Sample x coordinates; the last two entries are the start and the goal.
    sample_x: Vec<f64>,
    /// Sample y coordinates, parallel to `sample_x`.
    sample_y: Vec<f64>,
    /// Adjacency lists: `road_map[i]` holds the indices of samples joined to
    /// sample `i` by a collision-free straight edge.
    road_map: Vec<Vec<usize>>,
}
139
140impl PRMStarPlanner {
141 pub fn new(
143 ox: &[f64],
144 oy: &[f64],
145 start: (f64, f64),
146 goal: (f64, f64),
147 config: PRMStarConfig,
148 ) -> Self {
149 config.validate().expect(
150 "invalid PRM* configuration: n_samples > 0, robot_radius > 0, gamma > 0 required",
151 );
152
153 let obstacle_tree = KDTree::new(ox.iter().zip(oy.iter()).map(|(&x, &y)| (x, y)).collect());
154 let min_x = ox.iter().copied().fold(f64::INFINITY, f64::min);
155 let max_x = ox.iter().copied().fold(f64::NEG_INFINITY, f64::max);
156 let min_y = oy.iter().copied().fold(f64::INFINITY, f64::min);
157 let max_y = oy.iter().copied().fold(f64::NEG_INFINITY, f64::max);
158
159 let (sample_x, sample_y) = Self::sample_points(
160 start,
161 goal,
162 min_x,
163 max_x,
164 min_y,
165 max_y,
166 config.n_samples,
167 config.robot_radius,
168 &obstacle_tree,
169 );
170
171 let workspace_scale = ((max_x - min_x).powi(2) + (max_y - min_y).powi(2))
172 .sqrt()
173 .max(1.0);
174 let connection_radius =
175 Self::compute_connection_radius(sample_x.len(), workspace_scale, config.gamma);
176 let road_map = Self::generate_road_map(
177 &sample_x,
178 &sample_y,
179 config.robot_radius,
180 connection_radius,
181 &obstacle_tree,
182 );
183
184 Self {
185 sample_x,
186 sample_y,
187 road_map,
188 connection_radius,
189 }
190 }
191
192 fn compute_connection_radius(n: usize, workspace_scale: f64, gamma: f64) -> f64 {
193 let n_f = n as f64;
194 let radius_normalized = gamma * (n_f.ln() / n_f).sqrt();
195 (radius_normalized * workspace_scale).max(1e-3)
196 }
197
198 fn sample_points(
199 start: (f64, f64),
200 goal: (f64, f64),
201 min_x: f64,
202 max_x: f64,
203 min_y: f64,
204 max_y: f64,
205 n_samples: usize,
206 robot_radius: f64,
207 obstacle_tree: &KDTree,
208 ) -> (Vec<f64>, Vec<f64>) {
209 let mut rng = rand::rng();
210 let mut sample_x = Vec::with_capacity(n_samples + 2);
211 let mut sample_y = Vec::with_capacity(n_samples + 2);
212
213 while sample_x.len() < n_samples {
214 let x = rng.random_range(min_x..max_x);
215 let y = rng.random_range(min_y..max_y);
216 if obstacle_tree.min_distance(x, y) > robot_radius {
217 sample_x.push(x);
218 sample_y.push(y);
219 }
220 }
221
222 sample_x.push(start.0);
223 sample_y.push(start.1);
224 sample_x.push(goal.0);
225 sample_y.push(goal.1);
226
227 (sample_x, sample_y)
228 }
229
230 fn generate_road_map(
231 sample_x: &[f64],
232 sample_y: &[f64],
233 robot_radius: f64,
234 connection_radius: f64,
235 obstacle_tree: &KDTree,
236 ) -> Vec<Vec<usize>> {
237 let sample_tree = KDTree::new(
238 sample_x
239 .iter()
240 .zip(sample_y.iter())
241 .map(|(&x, &y)| (x, y))
242 .collect(),
243 );
244 let mut road_map: Vec<Vec<usize>> = vec![Vec::new(); sample_x.len()];
245
246 for (i, (&x, &y)) in sample_x.iter().zip(sample_y.iter()).enumerate() {
247 for (j, dist) in sample_tree.query_radius(x, y, connection_radius) {
248 if i == j {
249 continue;
250 }
251 if !Self::is_collision(x, y, sample_x[j], sample_y[j], robot_radius, obstacle_tree)
252 {
253 road_map[i].push(j);
254 } else {
255 let _ = dist;
256 }
257 }
258 }
259
260 road_map
261 }
262
263 fn is_collision(
264 x1: f64,
265 y1: f64,
266 x2: f64,
267 y2: f64,
268 robot_radius: f64,
269 obstacle_tree: &KDTree,
270 ) -> bool {
271 let dx = x2 - x1;
272 let dy = y2 - y1;
273 let d = (dx * dx + dy * dy).sqrt();
274
275 if d == 0.0 {
276 return false;
277 }
278
279 let step = robot_radius;
280 let n_steps = (d / step).ceil() as usize;
281 for i in 0..=n_steps {
282 let t = i as f64 / n_steps as f64;
283 let x = x1 + t * dx;
284 let y = y1 + t * dy;
285 if obstacle_tree.min_distance(x, y) <= robot_radius {
286 return true;
287 }
288 }
289 false
290 }
291
292 pub fn plan(&self) -> Option<(Vec<f64>, Vec<f64>)> {
294 let n = self.sample_x.len();
295 let start_idx = n - 2;
296 let goal_idx = n - 1;
297
298 let mut nodes: Vec<Node> = self
299 .sample_x
300 .iter()
301 .zip(self.sample_y.iter())
302 .map(|(&x, &y)| Node::new(x, y))
303 .collect();
304 nodes[start_idx].cost = 0.0;
305
306 let mut open_set = BinaryHeap::new();
307 open_set.push(QueueItem {
308 cost: 0.0,
309 index: start_idx,
310 });
311 let mut closed_set: HashMap<usize, bool> = HashMap::new();
312
313 while let Some(current) = open_set.pop() {
314 if current.index == goal_idx {
315 return Some(self.reconstruct_path(&nodes, goal_idx));
316 }
317 if closed_set.contains_key(¤t.index) {
318 continue;
319 }
320 closed_set.insert(current.index, true);
321
322 for &neighbor_idx in &self.road_map[current.index] {
323 if closed_set.contains_key(&neighbor_idx) {
324 continue;
325 }
326 let dx = nodes[neighbor_idx].x - nodes[current.index].x;
327 let dy = nodes[neighbor_idx].y - nodes[current.index].y;
328 let edge_cost = (dx * dx + dy * dy).sqrt();
329 let new_cost = nodes[current.index].cost + edge_cost;
330 if new_cost < nodes[neighbor_idx].cost {
331 nodes[neighbor_idx].cost = new_cost;
332 nodes[neighbor_idx].parent = Some(current.index);
333 open_set.push(QueueItem {
334 cost: new_cost,
335 index: neighbor_idx,
336 });
337 }
338 }
339 }
340
341 None
342 }
343
344 fn reconstruct_path(&self, nodes: &[Node], goal_idx: usize) -> (Vec<f64>, Vec<f64>) {
345 let mut path_x = Vec::new();
346 let mut path_y = Vec::new();
347 let mut current = goal_idx;
348
349 while let Some(parent) = nodes[current].parent {
350 path_x.push(nodes[current].x);
351 path_y.push(nodes[current].y);
352 current = parent;
353 }
354 path_x.push(nodes[current].x);
355 path_y.push(nodes[current].y);
356 path_x.reverse();
357 path_y.reverse();
358
359 (path_x, path_y)
360 }
361
362 pub fn get_samples(&self) -> (&[f64], &[f64]) {
364 (&self.sample_x, &self.sample_y)
365 }
366
367 pub fn connection_radius(&self) -> f64 {
369 self.connection_radius
370 }
371}
372
#[cfg(test)]
mod tests {
    use super::*;

    /// Obstacle points on the border of a `size` x `size` square, one per unit.
    fn rectangular_walls(size: usize) -> (Vec<f64>, Vec<f64>) {
        let extent = size as f64;
        let mut ox = Vec::with_capacity(4 * (size + 1));
        let mut oy = Vec::with_capacity(4 * (size + 1));
        for i in 0..=size {
            let v = i as f64;
            for (x, y) in [(v, 0.0), (v, extent), (0.0, v), (extent, v)] {
                ox.push(x);
                oy.push(y);
            }
        }
        (ox, oy)
    }

    /// Total Euclidean length of the polyline `(xs[i], ys[i])`.
    fn path_length(xs: &[f64], ys: &[f64]) -> f64 {
        xs.windows(2)
            .zip(ys.windows(2))
            .map(|(wx, wy)| {
                let dx = wx[1] - wx[0];
                let dy = wy[1] - wy[0];
                (dx * dx + dy * dy).sqrt()
            })
            .sum()
    }

    /// Checks invariants every returned path must satisfy, regardless of the
    /// random samples: paired coordinates, correct endpoints, and a length no
    /// shorter than the straight line (triangle inequality).
    fn assert_valid_path(px: &[f64], py: &[f64], start: (f64, f64), goal: (f64, f64)) {
        assert_eq!(px.len(), py.len(), "x/y path coordinates must pair up");
        assert!(px.len() >= 2, "a path needs at least the start and the goal");
        assert_eq!((px[0], py[0]), start, "path must begin at the start");
        assert_eq!(
            (px[px.len() - 1], py[py.len() - 1]),
            goal,
            "path must end at the goal"
        );
        let straight = (goal.0 - start.0).hypot(goal.1 - start.1);
        assert!(
            path_length(px, py) >= straight - 1e-9,
            "path cannot be shorter than the straight line"
        );
    }

    #[test]
    fn test_prm_star_finds_path() {
        let (ox, oy) = rectangular_walls(30);
        let config = PRMStarConfig {
            n_samples: 450,
            robot_radius: 0.8,
            gamma: 2.5,
        };
        let (start, goal) = ((2.0, 2.0), (28.0, 28.0));
        let planner = PRMStarPlanner::new(&ox, &oy, start, goal, config);

        let (px, py) = planner
            .plan()
            .expect("PRM* should find a path in free interior");
        assert_valid_path(&px, &py, start, goal);
    }

    #[test]
    fn test_prm_star_path_quality_improves_with_more_samples() {
        // Sampling is random, so "more samples => shorter path" cannot be
        // asserted deterministically. Instead require that some budget succeeds
        // and that every returned path is well-formed.
        let (ox, oy) = rectangular_walls(20);
        let (start, goal) = ((2.0, 2.0), (18.0, 18.0));

        let results: Vec<_> = [60, 200]
            .into_iter()
            .map(|n_samples| {
                let config = PRMStarConfig {
                    n_samples,
                    robot_radius: 0.8,
                    gamma: 2.5,
                };
                PRMStarPlanner::new(&ox, &oy, start, goal, config).plan()
            })
            .collect();

        assert!(
            results.iter().any(Option::is_some),
            "at least one configuration should find a path"
        );
        for (px, py) in results.iter().flatten() {
            assert_valid_path(px, py, start, goal);
        }
    }

    #[test]
    fn test_prm_star_config_defaults() {
        let config = PRMStarConfig::default();
        assert_eq!(config.n_samples, 500);
        assert!((config.robot_radius - 0.8).abs() < f64::EPSILON);
        assert!((config.gamma - 2.5).abs() < f64::EPSILON);
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_prm_star_config_validation_rejects_invalid_values() {
        let base = PRMStarConfig::default();
        let invalid = [
            PRMStarConfig { n_samples: 0, ..base },
            PRMStarConfig { robot_radius: 0.0, ..base },
            PRMStarConfig { robot_radius: f64::NAN, ..base },
            PRMStarConfig { gamma: -1.0, ..base },
            PRMStarConfig { gamma: f64::INFINITY, ..base },
        ];
        for config in &invalid {
            assert!(config.validate().is_err(), "{config:?} should be rejected");
        }
    }
}