Skip to main content

rust_robotics_planning/
grid_a_star_3d.rs

1//! 3D grid-based A* path planning for aerial robots.
2//!
3//! The planner operates on a voxel grid and returns a collision-free path
4//! between two 3D points inside a bounded workspace.
5
6use std::cmp::Ordering;
7use std::collections::{BinaryHeap, HashMap, HashSet};
8
9use rust_robotics_core::{Point3D, RoboticsError, RoboticsResult};
10
/// Integer index of a cell in the voxel grid.
///
/// Components count cells from `bounds_min` along each axis. They may be
/// negative for points below `bounds_min`; such cells fall outside the grid.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct GridPoint3D {
    /// Cell index along the x axis.
    x: i32,
    /// Cell index along the y axis.
    y: i32,
    /// Cell index along the z axis.
    z: i32,
}

impl GridPoint3D {
    /// Creates an index from its three axis components.
    fn new(x: i32, y: i32, z: i32) -> Self {
        GridPoint3D { x, y, z }
    }
}
23
24#[derive(Debug, Clone)]
25pub struct Path3D {
26    pub points: Vec<Point3D>,
27}
28
29impl Path3D {
30    pub fn new(points: Vec<Point3D>) -> Self {
31        Self { points }
32    }
33
34    pub fn len(&self) -> usize {
35        self.points.len()
36    }
37
38    pub fn is_empty(&self) -> bool {
39        self.points.is_empty()
40    }
41}
42
43#[derive(Debug, Clone)]
44pub struct GridAStar3DConfig {
45    pub resolution: f64,
46    pub bounds_min: Point3D,
47    pub bounds_max: Point3D,
48    pub allow_diagonal: bool,
49}
50
51impl Default for GridAStar3DConfig {
52    fn default() -> Self {
53        Self {
54            resolution: 1.0,
55            bounds_min: Point3D::new(0.0, 0.0, 0.0),
56            bounds_max: Point3D::new(10.0, 10.0, 5.0),
57            allow_diagonal: true,
58        }
59    }
60}
61
62#[derive(Debug, Clone, Copy)]
63struct PriorityNode {
64    point: GridPoint3D,
65    cost: f64,
66    priority: f64,
67}
68
69impl Eq for PriorityNode {}
70
71impl PartialEq for PriorityNode {
72    fn eq(&self, other: &Self) -> bool {
73        self.priority == other.priority
74    }
75}
76
77impl Ord for PriorityNode {
78    fn cmp(&self, other: &Self) -> Ordering {
79        other
80            .priority
81            .partial_cmp(&self.priority)
82            .unwrap_or(Ordering::Equal)
83    }
84}
85
86impl PartialOrd for PriorityNode {
87    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
88        Some(self.cmp(other))
89    }
90}
91
92pub struct GridAStar3DPlanner {
93    config: GridAStar3DConfig,
94    max_index: GridPoint3D,
95    obstacles: HashSet<GridPoint3D>,
96    motions: Vec<(i32, i32, i32, f64)>,
97}
98
99impl GridAStar3DPlanner {
100    pub fn new(config: GridAStar3DConfig, obstacles: &[Point3D]) -> RoboticsResult<Self> {
101        validate_config(&config)?;
102
103        let max_index = GridPoint3D::new(
104            ((config.bounds_max.x - config.bounds_min.x) / config.resolution).round() as i32,
105            ((config.bounds_max.y - config.bounds_min.y) / config.resolution).round() as i32,
106            ((config.bounds_max.z - config.bounds_min.z) / config.resolution).round() as i32,
107        );
108
109        let planner = Self {
110            max_index,
111            obstacles: obstacles
112                .iter()
113                .map(|point| quantize(point, &config))
114                .collect(),
115            motions: build_motion_model(config.allow_diagonal),
116            config,
117        };
118
119        Ok(planner)
120    }
121
122    pub fn plan(&self, start: Point3D, goal: Point3D) -> RoboticsResult<Path3D> {
123        let start_grid = quantize(&start, &self.config);
124        let goal_grid = quantize(&goal, &self.config);
125
126        if !self.is_valid(start_grid) {
127            return Err(RoboticsError::PlanningError(
128                "Start point is out of bounds or occupied".to_string(),
129            ));
130        }
131
132        if !self.is_valid(goal_grid) {
133            return Err(RoboticsError::PlanningError(
134                "Goal point is out of bounds or occupied".to_string(),
135            ));
136        }
137
138        let mut open_set = BinaryHeap::new();
139        let mut came_from = HashMap::new();
140        let mut best_cost = HashMap::new();
141
142        open_set.push(PriorityNode {
143            point: start_grid,
144            cost: 0.0,
145            priority: heuristic(start_grid, goal_grid),
146        });
147        best_cost.insert(start_grid, 0.0);
148
149        while let Some(current) = open_set.pop() {
150            if current.point == goal_grid {
151                return Ok(self.reconstruct_path(goal_grid, start_grid, &came_from));
152            }
153
154            let Some(known_cost) = best_cost.get(&current.point).copied() else {
155                continue;
156            };
157            if current.cost > known_cost {
158                continue;
159            }
160
161            for (dx, dy, dz, move_cost) in &self.motions {
162                let next = GridPoint3D::new(
163                    current.point.x + dx,
164                    current.point.y + dy,
165                    current.point.z + dz,
166                );
167                if !self.is_valid(next) {
168                    continue;
169                }
170
171                let tentative_cost = current.cost + move_cost;
172                let current_best = best_cost.get(&next).copied().unwrap_or(f64::INFINITY);
173                if tentative_cost >= current_best {
174                    continue;
175                }
176
177                came_from.insert(next, current.point);
178                best_cost.insert(next, tentative_cost);
179                open_set.push(PriorityNode {
180                    point: next,
181                    cost: tentative_cost,
182                    priority: tentative_cost + heuristic(next, goal_grid),
183                });
184            }
185        }
186
187        Err(RoboticsError::PlanningError("No 3D path found".to_string()))
188    }
189
190    fn is_valid(&self, point: GridPoint3D) -> bool {
191        point.x >= 0
192            && point.y >= 0
193            && point.z >= 0
194            && point.x <= self.max_index.x
195            && point.y <= self.max_index.y
196            && point.z <= self.max_index.z
197            && !self.obstacles.contains(&point)
198    }
199
200    fn reconstruct_path(
201        &self,
202        goal: GridPoint3D,
203        start: GridPoint3D,
204        came_from: &HashMap<GridPoint3D, GridPoint3D>,
205    ) -> Path3D {
206        let mut points = vec![goal];
207        let mut current = goal;
208
209        while current != start {
210            current = came_from[&current];
211            points.push(current);
212        }
213
214        points.reverse();
215        Path3D::new(
216            points
217                .into_iter()
218                .map(|point| dequantize(point, &self.config))
219                .collect(),
220        )
221    }
222}
223
224fn validate_config(config: &GridAStar3DConfig) -> RoboticsResult<()> {
225    if config.resolution <= 0.0 {
226        return Err(RoboticsError::InvalidParameter(
227            "resolution must be positive".to_string(),
228        ));
229    }
230
231    if config.bounds_min.x > config.bounds_max.x
232        || config.bounds_min.y > config.bounds_max.y
233        || config.bounds_min.z > config.bounds_max.z
234    {
235        return Err(RoboticsError::InvalidParameter(
236            "bounds_min must not exceed bounds_max".to_string(),
237        ));
238    }
239
240    Ok(())
241}
242
243fn quantize(point: &Point3D, config: &GridAStar3DConfig) -> GridPoint3D {
244    GridPoint3D::new(
245        ((point.x - config.bounds_min.x) / config.resolution).round() as i32,
246        ((point.y - config.bounds_min.y) / config.resolution).round() as i32,
247        ((point.z - config.bounds_min.z) / config.resolution).round() as i32,
248    )
249}
250
251fn dequantize(point: GridPoint3D, config: &GridAStar3DConfig) -> Point3D {
252    Point3D::new(
253        config.bounds_min.x + point.x as f64 * config.resolution,
254        config.bounds_min.y + point.y as f64 * config.resolution,
255        config.bounds_min.z + point.z as f64 * config.resolution,
256    )
257}
258
259fn heuristic(a: GridPoint3D, b: GridPoint3D) -> f64 {
260    let dx = (a.x - b.x) as f64;
261    let dy = (a.y - b.y) as f64;
262    let dz = (a.z - b.z) as f64;
263    (dx * dx + dy * dy + dz * dz).sqrt()
264}
265
/// Builds the neighbour offsets explored from each cell, with their move costs.
///
/// The six face neighbours (cost 1) always come first. With `allow_diagonal`,
/// the 12 edge neighbours (cost sqrt 2) and 8 corner neighbours (cost sqrt 3)
/// follow. They are enumerated with `dx`, then `dy`, then `dz`, each running
/// from -1 to 1. Each cost is the Euclidean length of the offset.
fn build_motion_model(allow_diagonal: bool) -> Vec<(i32, i32, i32, f64)> {
    const FACE_STEPS: [(i32, i32, i32); 6] = [
        (1, 0, 0),
        (-1, 0, 0),
        (0, 1, 0),
        (0, -1, 0),
        (0, 0, 1),
        (0, 0, -1),
    ];

    let mut moves: Vec<(i32, i32, i32, f64)> = FACE_STEPS
        .iter()
        .map(|&(dx, dy, dz)| (dx, dy, dz, 1.0))
        .collect();

    if !allow_diagonal {
        return moves;
    }

    let neighbourhood = (-1_i32..=1).flat_map(|dx| {
        (-1_i32..=1).flat_map(move |dy| (-1_i32..=1).map(move |dz| (dx, dy, dz)))
    });

    moves.extend(
        neighbourhood
            // Offsets touching one axis are the face moves already added; the
            // all-zero offset is not a move at all.
            .filter(|&(dx, dy, dz)| dx.abs() + dy.abs() + dz.abs() >= 2)
            .map(|(dx, dy, dz)| {
                let length = f64::from(dx * dx + dy * dy + dz * dz).sqrt();
                (dx, dy, dz, length)
            }),
    );

    moves
}
295
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a planner over the cube `[0, 4]^3` with unit resolution
    /// (5 cells per axis).
    fn planner_with_config(allow_diagonal: bool, obstacles: Vec<Point3D>) -> GridAStar3DPlanner {
        GridAStar3DPlanner::new(
            GridAStar3DConfig {
                resolution: 1.0,
                bounds_min: Point3D::new(0.0, 0.0, 0.0),
                bounds_max: Point3D::new(4.0, 4.0, 4.0),
                allow_diagonal,
            },
            &obstacles,
        )
        .expect("planner should be created")
    }

    /// Asserts that consecutive waypoints differ by exactly one unit along a
    /// single axis.
    fn assert_face_adjacent_steps(path: &Path3D) {
        for pair in path.points.windows(2) {
            let (from, to) = (pair[0], pair[1]);
            let manhattan =
                (to.x - from.x).abs() + (to.y - from.y).abs() + (to.z - from.z).abs();
            assert_eq!(manhattan, 1.0, "{from:?} -> {to:?} is not a single face-adjacent step");
        }
    }

    #[test]
    fn test_invalid_config_is_rejected() {
        let result = GridAStar3DPlanner::new(
            GridAStar3DConfig {
                resolution: 0.0,
                ..Default::default()
            },
            &[],
        );

        assert!(matches!(result, Err(RoboticsError::InvalidParameter(_))));
    }

    #[test]
    fn test_grid_a_star_3d_finds_path() {
        let planner = planner_with_config(true, vec![]);
        let start = Point3D::new(0.0, 0.0, 0.0);
        let goal = Point3D::new(3.0, 2.0, 1.0);

        let path = planner.plan(start, goal).expect("path should exist");

        assert!(!path.is_empty());
        assert_eq!(path.points.first().copied(), Some(start));
        assert_eq!(path.points.last().copied(), Some(goal));
    }

    #[test]
    fn test_grid_a_star_3d_uses_diagonal_shortcut_when_enabled() {
        let planner = planner_with_config(true, vec![]);

        let path = planner
            .plan(Point3D::new(0.0, 0.0, 0.0), Point3D::new(2.0, 2.0, 2.0))
            .expect("path should exist");

        assert_eq!(path.len(), 3);
        assert_eq!(path.points[1], Point3D::new(1.0, 1.0, 1.0));
    }

    #[test]
    fn test_grid_a_star_3d_avoids_obstacles() {
        // The wall spans the grid's entire y extent at z = 0, so climbing over
        // it is the only way past. A partial wall would leave an equal-cost
        // detour around its end, and the `z > 0` assertion would then depend on
        // how the open set breaks ties.
        let obstacles: Vec<Point3D> = (0_u8..=4)
            .map(|y| Point3D::new(1.0, f64::from(y), 0.0))
            .collect();
        let planner = planner_with_config(false, obstacles.clone());

        let path = planner
            .plan(Point3D::new(0.0, 0.0, 0.0), Point3D::new(2.0, 2.0, 0.0))
            .expect("path should route over the wall");

        assert!(path.points.iter().all(|point| !obstacles.contains(point)));
        assert!(path.points.iter().any(|point| point.z > 0.0));
        // +2 in x, +2 in y, plus one climb and one descent: 6 moves, 7 waypoints.
        assert_eq!(path.len(), 7);
        assert_face_adjacent_steps(&path);
    }

    #[test]
    fn test_grid_a_star_3d_reports_no_path() {
        let planner = planner_with_config(
            false,
            vec![
                Point3D::new(0.0, 1.0, 1.0),
                Point3D::new(2.0, 1.0, 1.0),
                Point3D::new(1.0, 0.0, 1.0),
                Point3D::new(1.0, 2.0, 1.0),
                Point3D::new(1.0, 1.0, 0.0),
                Point3D::new(1.0, 1.0, 2.0),
            ],
        );

        let result = planner.plan(Point3D::new(0.0, 0.0, 0.0), Point3D::new(1.0, 1.0, 1.0));

        assert!(matches!(result, Err(RoboticsError::PlanningError(_))));
    }

    #[test]
    fn test_grid_a_star_3d_start_equals_goal() {
        let planner = planner_with_config(true, vec![]);
        let point = Point3D::new(2.0, 2.0, 2.0);

        let path = planner.plan(point, point).expect("trivial path should exist");

        assert_eq!(path.points, vec![point]);
    }

    #[test]
    fn test_grid_a_star_3d_rejects_out_of_bounds_start() {
        let planner = planner_with_config(true, vec![]);

        let result = planner.plan(Point3D::new(-5.0, 0.0, 0.0), Point3D::new(1.0, 1.0, 1.0));

        assert!(matches!(result, Err(RoboticsError::PlanningError(_))));
    }

    #[test]
    fn test_grid_a_star_3d_rejects_occupied_goal() {
        let goal = Point3D::new(3.0, 3.0, 3.0);
        let planner = planner_with_config(true, vec![goal]);

        let result = planner.plan(Point3D::new(0.0, 0.0, 0.0), goal);

        assert!(matches!(result, Err(RoboticsError::PlanningError(_))));
    }

    #[test]
    fn test_motion_model_sizes() {
        assert_eq!(build_motion_model(false).len(), 6);
        assert_eq!(build_motion_model(true).len(), 26);
    }
}