1use std::cmp::Ordering;
7use std::collections::{BinaryHeap, HashMap, HashSet};
8
9use rust_robotics_core::{Point3D, RoboticsError, RoboticsResult};
10
/// Integer cell index into the planning grid, counted in cells from `bounds_min` on each axis.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
struct GridPoint3D {
    x: i32,
    y: i32,
    z: i32,
}

impl GridPoint3D {
    /// Builds a cell index from its three integer components.
    fn new(x: i32, y: i32, z: i32) -> Self {
        GridPoint3D { x, y, z }
    }
}
23
24#[derive(Debug, Clone)]
25pub struct Path3D {
26 pub points: Vec<Point3D>,
27}
28
29impl Path3D {
30 pub fn new(points: Vec<Point3D>) -> Self {
31 Self { points }
32 }
33
34 pub fn len(&self) -> usize {
35 self.points.len()
36 }
37
38 pub fn is_empty(&self) -> bool {
39 self.points.is_empty()
40 }
41}
42
43#[derive(Debug, Clone)]
44pub struct GridAStar3DConfig {
45 pub resolution: f64,
46 pub bounds_min: Point3D,
47 pub bounds_max: Point3D,
48 pub allow_diagonal: bool,
49}
50
51impl Default for GridAStar3DConfig {
52 fn default() -> Self {
53 Self {
54 resolution: 1.0,
55 bounds_min: Point3D::new(0.0, 0.0, 0.0),
56 bounds_max: Point3D::new(10.0, 10.0, 5.0),
57 allow_diagonal: true,
58 }
59 }
60}
61
62#[derive(Debug, Clone, Copy)]
63struct PriorityNode {
64 point: GridPoint3D,
65 cost: f64,
66 priority: f64,
67}
68
69impl Eq for PriorityNode {}
70
71impl PartialEq for PriorityNode {
72 fn eq(&self, other: &Self) -> bool {
73 self.priority == other.priority
74 }
75}
76
77impl Ord for PriorityNode {
78 fn cmp(&self, other: &Self) -> Ordering {
79 other
80 .priority
81 .partial_cmp(&self.priority)
82 .unwrap_or(Ordering::Equal)
83 }
84}
85
86impl PartialOrd for PriorityNode {
87 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
88 Some(self.cmp(other))
89 }
90}
91
92pub struct GridAStar3DPlanner {
93 config: GridAStar3DConfig,
94 max_index: GridPoint3D,
95 obstacles: HashSet<GridPoint3D>,
96 motions: Vec<(i32, i32, i32, f64)>,
97}
98
99impl GridAStar3DPlanner {
100 pub fn new(config: GridAStar3DConfig, obstacles: &[Point3D]) -> RoboticsResult<Self> {
101 validate_config(&config)?;
102
103 let max_index = GridPoint3D::new(
104 ((config.bounds_max.x - config.bounds_min.x) / config.resolution).round() as i32,
105 ((config.bounds_max.y - config.bounds_min.y) / config.resolution).round() as i32,
106 ((config.bounds_max.z - config.bounds_min.z) / config.resolution).round() as i32,
107 );
108
109 let planner = Self {
110 max_index,
111 obstacles: obstacles
112 .iter()
113 .map(|point| quantize(point, &config))
114 .collect(),
115 motions: build_motion_model(config.allow_diagonal),
116 config,
117 };
118
119 Ok(planner)
120 }
121
122 pub fn plan(&self, start: Point3D, goal: Point3D) -> RoboticsResult<Path3D> {
123 let start_grid = quantize(&start, &self.config);
124 let goal_grid = quantize(&goal, &self.config);
125
126 if !self.is_valid(start_grid) {
127 return Err(RoboticsError::PlanningError(
128 "Start point is out of bounds or occupied".to_string(),
129 ));
130 }
131
132 if !self.is_valid(goal_grid) {
133 return Err(RoboticsError::PlanningError(
134 "Goal point is out of bounds or occupied".to_string(),
135 ));
136 }
137
138 let mut open_set = BinaryHeap::new();
139 let mut came_from = HashMap::new();
140 let mut best_cost = HashMap::new();
141
142 open_set.push(PriorityNode {
143 point: start_grid,
144 cost: 0.0,
145 priority: heuristic(start_grid, goal_grid),
146 });
147 best_cost.insert(start_grid, 0.0);
148
149 while let Some(current) = open_set.pop() {
150 if current.point == goal_grid {
151 return Ok(self.reconstruct_path(goal_grid, start_grid, &came_from));
152 }
153
154 let Some(known_cost) = best_cost.get(¤t.point).copied() else {
155 continue;
156 };
157 if current.cost > known_cost {
158 continue;
159 }
160
161 for (dx, dy, dz, move_cost) in &self.motions {
162 let next = GridPoint3D::new(
163 current.point.x + dx,
164 current.point.y + dy,
165 current.point.z + dz,
166 );
167 if !self.is_valid(next) {
168 continue;
169 }
170
171 let tentative_cost = current.cost + move_cost;
172 let current_best = best_cost.get(&next).copied().unwrap_or(f64::INFINITY);
173 if tentative_cost >= current_best {
174 continue;
175 }
176
177 came_from.insert(next, current.point);
178 best_cost.insert(next, tentative_cost);
179 open_set.push(PriorityNode {
180 point: next,
181 cost: tentative_cost,
182 priority: tentative_cost + heuristic(next, goal_grid),
183 });
184 }
185 }
186
187 Err(RoboticsError::PlanningError("No 3D path found".to_string()))
188 }
189
190 fn is_valid(&self, point: GridPoint3D) -> bool {
191 point.x >= 0
192 && point.y >= 0
193 && point.z >= 0
194 && point.x <= self.max_index.x
195 && point.y <= self.max_index.y
196 && point.z <= self.max_index.z
197 && !self.obstacles.contains(&point)
198 }
199
200 fn reconstruct_path(
201 &self,
202 goal: GridPoint3D,
203 start: GridPoint3D,
204 came_from: &HashMap<GridPoint3D, GridPoint3D>,
205 ) -> Path3D {
206 let mut points = vec![goal];
207 let mut current = goal;
208
209 while current != start {
210 current = came_from[¤t];
211 points.push(current);
212 }
213
214 points.reverse();
215 Path3D::new(
216 points
217 .into_iter()
218 .map(|point| dequantize(point, &self.config))
219 .collect(),
220 )
221 }
222}
223
224fn validate_config(config: &GridAStar3DConfig) -> RoboticsResult<()> {
225 if config.resolution <= 0.0 {
226 return Err(RoboticsError::InvalidParameter(
227 "resolution must be positive".to_string(),
228 ));
229 }
230
231 if config.bounds_min.x > config.bounds_max.x
232 || config.bounds_min.y > config.bounds_max.y
233 || config.bounds_min.z > config.bounds_max.z
234 {
235 return Err(RoboticsError::InvalidParameter(
236 "bounds_min must not exceed bounds_max".to_string(),
237 ));
238 }
239
240 Ok(())
241}
242
243fn quantize(point: &Point3D, config: &GridAStar3DConfig) -> GridPoint3D {
244 GridPoint3D::new(
245 ((point.x - config.bounds_min.x) / config.resolution).round() as i32,
246 ((point.y - config.bounds_min.y) / config.resolution).round() as i32,
247 ((point.z - config.bounds_min.z) / config.resolution).round() as i32,
248 )
249}
250
251fn dequantize(point: GridPoint3D, config: &GridAStar3DConfig) -> Point3D {
252 Point3D::new(
253 config.bounds_min.x + point.x as f64 * config.resolution,
254 config.bounds_min.y + point.y as f64 * config.resolution,
255 config.bounds_min.z + point.z as f64 * config.resolution,
256 )
257}
258
259fn heuristic(a: GridPoint3D, b: GridPoint3D) -> f64 {
260 let dx = (a.x - b.x) as f64;
261 let dy = (a.y - b.y) as f64;
262 let dz = (a.z - b.z) as f64;
263 (dx * dx + dy * dy + dz * dz).sqrt()
264}
265
/// Returns the permitted single-step moves as `(dx, dy, dz, cost)` tuples.
///
/// The six axis-aligned unit moves always come first, in the order +x, -x, +y, -y, +z, -z.
/// With `allow_diagonal`, the 20 remaining moves of the 26-cell neighbourhood follow.
/// They are enumerated with dx, then dy, then dz, each from -1 to 1, and each costs its
/// Euclidean length (√2 or √3).
fn build_motion_model(allow_diagonal: bool) -> Vec<(i32, i32, i32, f64)> {
    const AXIS_MOVES: [(i32, i32, i32, f64); 6] = [
        (1, 0, 0, 1.0),
        (-1, 0, 0, 1.0),
        (0, 1, 0, 1.0),
        (0, -1, 0, 1.0),
        (0, 0, 1, 1.0),
        (0, 0, -1, 1.0),
    ];

    let mut motions = AXIS_MOVES.to_vec();

    if allow_diagonal {
        let offsets = (-1_i32..=1)
            .flat_map(|dx| (-1_i32..=1).map(move |dy| (dx, dy)))
            .flat_map(|(dx, dy)| (-1_i32..=1).map(move |dz| (dx, dy, dz)));

        motions.extend(
            offsets
                // Offsets touching fewer than two axes are either "stay put" or already
                // listed in AXIS_MOVES.
                .filter(|&(dx, dy, dz)| dx.abs() + dy.abs() + dz.abs() >= 2)
                .map(|(dx, dy, dz)| {
                    let squared_length = dx * dx + dy * dy + dz * dz;
                    (dx, dy, dz, f64::from(squared_length).sqrt())
                }),
        );
    }

    motions
}
295
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a unit-resolution planner over a grid with indices 0..=4 on every axis.
    fn planner_with_config(allow_diagonal: bool, obstacles: Vec<Point3D>) -> GridAStar3DPlanner {
        GridAStar3DPlanner::new(
            GridAStar3DConfig {
                resolution: 1.0,
                bounds_min: Point3D::new(0.0, 0.0, 0.0),
                bounds_max: Point3D::new(4.0, 4.0, 4.0),
                allow_diagonal,
            },
            &obstacles,
        )
        .expect("planner should be created")
    }

    #[test]
    fn test_invalid_config_is_rejected() {
        let result = GridAStar3DPlanner::new(
            GridAStar3DConfig {
                resolution: 0.0,
                ..Default::default()
            },
            &[],
        );

        assert!(matches!(result, Err(RoboticsError::InvalidParameter(_))));
    }

    #[test]
    fn test_grid_a_star_3d_finds_path() {
        let planner = planner_with_config(true, vec![]);
        let start = Point3D::new(0.0, 0.0, 0.0);
        let goal = Point3D::new(3.0, 2.0, 1.0);

        let path = planner.plan(start, goal).expect("path should exist");

        assert!(!path.is_empty());
        assert_eq!(path.points.first().copied(), Some(start));
        assert_eq!(path.points.last().copied(), Some(goal));
    }

    #[test]
    fn test_grid_a_star_3d_uses_diagonal_shortcut_when_enabled() {
        let planner = planner_with_config(true, vec![]);

        let path = planner
            .plan(Point3D::new(0.0, 0.0, 0.0), Point3D::new(2.0, 2.0, 2.0))
            .expect("path should exist");

        assert_eq!(path.len(), 3);
        assert_eq!(path.points[1], Point3D::new(1.0, 1.0, 1.0));
    }

    #[test]
    fn test_grid_a_star_3d_axis_only_path_has_manhattan_length() {
        let planner = planner_with_config(false, vec![]);

        let path = planner
            .plan(Point3D::new(0.0, 0.0, 0.0), Point3D::new(2.0, 2.0, 2.0))
            .expect("path should exist");

        // Without diagonals the optimum is six unit moves, i.e. seven waypoints.
        assert_eq!(path.len(), 7);
    }

    #[test]
    fn test_grid_a_star_3d_steps_are_single_cell_moves() {
        let planner = planner_with_config(true, vec![]);

        let path = planner
            .plan(Point3D::new(0.0, 0.0, 0.0), Point3D::new(4.0, 1.0, 3.0))
            .expect("path should exist");

        for pair in path.points.windows(2) {
            let (a, b) = (pair[0], pair[1]);
            let step = [(b.x - a.x).abs(), (b.y - a.y).abs(), (b.z - a.z).abs()];
            assert!(step.iter().all(|&d| d <= 1.0), "step too long: {pair:?}");
            assert!(step.iter().any(|&d| d > 0.0), "zero-length step: {pair:?}");
        }
    }

    #[test]
    fn test_grid_a_star_3d_start_equals_goal_returns_single_point() {
        let planner = planner_with_config(true, vec![]);
        let point = Point3D::new(2.0, 2.0, 2.0);

        let path = planner.plan(point, point).expect("trivial path should exist");

        assert_eq!(path.points, vec![point]);
    }

    #[test]
    fn test_grid_a_star_3d_avoids_obstacles() {
        let obstacles = vec![
            Point3D::new(1.0, 0.0, 0.0),
            Point3D::new(1.0, 1.0, 0.0),
            Point3D::new(1.0, 2.0, 0.0),
        ];
        let planner = planner_with_config(false, obstacles.clone());

        let path = planner
            .plan(Point3D::new(0.0, 0.0, 0.0), Point3D::new(2.0, 2.0, 0.0))
            .expect("path should route around the wall");

        assert!(path.points.iter().all(|point| !obstacles.contains(point)));
        assert!(path.points.iter().any(|point| point.z > 0.0));
        assert!(path.len() > 4);
    }

    #[test]
    fn test_grid_a_star_3d_rejects_occupied_start() {
        let start = Point3D::new(0.0, 0.0, 0.0);
        let planner = planner_with_config(true, vec![start]);

        let result = planner.plan(start, Point3D::new(2.0, 2.0, 2.0));

        assert!(matches!(result, Err(RoboticsError::PlanningError(_))));
    }

    #[test]
    fn test_grid_a_star_3d_rejects_out_of_bounds_goal() {
        let planner = planner_with_config(true, vec![]);

        let result = planner.plan(Point3D::new(0.0, 0.0, 0.0), Point3D::new(10.0, 0.0, 0.0));

        assert!(matches!(result, Err(RoboticsError::PlanningError(_))));
    }

    #[test]
    fn test_grid_a_star_3d_reports_no_path() {
        let planner = planner_with_config(
            false,
            vec![
                Point3D::new(0.0, 1.0, 1.0),
                Point3D::new(2.0, 1.0, 1.0),
                Point3D::new(1.0, 0.0, 1.0),
                Point3D::new(1.0, 2.0, 1.0),
                Point3D::new(1.0, 1.0, 0.0),
                Point3D::new(1.0, 1.0, 2.0),
            ],
        );

        let result = planner.plan(Point3D::new(0.0, 0.0, 0.0), Point3D::new(1.0, 1.0, 1.0));

        assert!(matches!(result, Err(RoboticsError::PlanningError(_))));
    }
}