Skip to main content

rust_robotics_planning/
hierarchical_mapf.rs

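//! Hierarchical MAPF replanning foundation.
//!
//! This module adds a region-level trigger layer above STL-CBS. It first plans
//! independent shortest paths, finds coarse region conflicts (two or more agents
//! in the same region at the same timestep), then replans only the affected
//! agent groups with CBS. If cell-level conflicts remain after the group
//! replans, it falls back to replanning all agents together with CBS.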
6
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet, VecDeque};

use rust_robotics_core::{RoboticsError, RoboticsResult};

use crate::stl_cbs::{StlCbsAgent, StlCbsConfig, StlCbsPath, StlCbsPlanner};
12
13/// Agent query for hierarchical MAPF.
14#[derive(Debug, Clone, Copy, PartialEq, Eq)]
15pub struct HierarchicalMapfAgent2D {
16    pub id: usize,
17    pub start: (i32, i32),
18    pub goal: (i32, i32),
19}
20
21impl HierarchicalMapfAgent2D {
22    pub fn new(id: usize, start: (i32, i32), goal: (i32, i32)) -> Self {
23        Self { id, start, goal }
24    }
25
26    fn as_stl_cbs_agent(self) -> StlCbsAgent {
27        StlCbsAgent::new(self.id, self.start, self.goal)
28    }
29}
30
/// Coarse region id for hierarchical grouping: the index of the
/// `region_width` x `region_height` tile that contains a cell.
///
/// Field order matters: the derived `Ord` compares `rx` before `ry`.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct HierarchicalMapfRegion2D {
    /// Tile column index (`x / region_width` in the planner).
    pub rx: i32,
    /// Tile row index (`y / region_height` in the planner).
    pub ry: i32,
}

impl HierarchicalMapfRegion2D {
    /// Creates a region id from its tile coordinates.
    pub fn new(rx: i32, ry: i32) -> Self {
        HierarchicalMapfRegion2D { rx, ry }
    }
}
43
/// Planner configuration: grid size, obstacles, region tiling, and search limits.
#[derive(Clone, Debug, PartialEq)]
pub struct HierarchicalMapfConfig2D {
    /// Grid width in cells.
    pub width: i32,
    /// Grid height in cells.
    pub height: i32,
    /// Per-cell obstacle flags indexed as `obstacle_map[x][y]`: `width` columns of
    /// `height` entries each (presumably `true` marks a blocked cell — confirm in
    /// `stl_cbs`).
    pub obstacle_map: Vec<Vec<bool>>,
    /// Region tile width in cells.
    pub region_width: i32,
    /// Region tile height in cells.
    pub region_height: i32,
    /// Inclusive time horizon: conflicts are checked for `t` in `0..=max_time`.
    /// Also forwarded to the low-level planner.
    pub max_time: u64,
    /// Node budget forwarded to the low-level CBS search.
    pub max_cbs_nodes: usize,
}

impl HierarchicalMapfConfig2D {
    /// Creates a `width` x `height` grid with an all-`false` obstacle map, tiled into
    /// `region_width` x `region_height` regions, with `max_time = 96` and
    /// `max_cbs_nodes = 4096`.
    ///
    /// Negative dimensions produce an empty obstacle map rather than panicking.
    /// Non-positive sizes are rejected when the config is passed to
    /// `HierarchicalMapfPlanner2D::new`.
    pub fn new(width: i32, height: i32, region_width: i32, region_height: i32) -> Self {
        let columns = width.max(0) as usize;
        let cells_per_column = height.max(0) as usize;
        Self {
            width,
            height,
            obstacle_map: vec![vec![false; cells_per_column]; columns],
            region_width,
            region_height,
            max_time: 96,
            max_cbs_nodes: 4_096,
        }
    }
}
69
70/// Compressed route through coarse regions for one agent.
71#[derive(Debug, Clone, PartialEq, Eq)]
72pub struct HierarchicalMapfRegionRoute2D {
73    pub agent_id: usize,
74    pub regions: Vec<HierarchicalMapfRegion2D>,
75}
76
77/// A coarse conflict used to trigger local replanning.
78#[derive(Debug, Clone, PartialEq, Eq)]
79pub struct HierarchicalMapfRegionConflict2D {
80    pub region: HierarchicalMapfRegion2D,
81    pub t: u64,
82    pub agent_ids: Vec<usize>,
83}
84
/// One group of agents that was replanned together with CBS.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct HierarchicalMapfReplannedGroup2D {
    /// Agents planned together in this CBS call.
    pub agent_ids: Vec<usize>,
    /// `total_cost` reported by the CBS plan for this group.
    pub cbs_total_cost: u64,
    /// `conflicts_resolved` reported by the CBS plan for this group.
    pub cbs_conflicts_resolved: usize,
}
92
93/// Hierarchical MAPF plan summary.
94#[derive(Debug, Clone, PartialEq)]
95pub struct HierarchicalMapfPlan2D {
96    pub paths: Vec<StlCbsPath>,
97    pub independent_paths: Vec<StlCbsPath>,
98    pub region_routes: Vec<HierarchicalMapfRegionRoute2D>,
99    pub region_conflicts: Vec<HierarchicalMapfRegionConflict2D>,
100    pub replanned_groups: Vec<HierarchicalMapfReplannedGroup2D>,
101    pub independent_cell_conflicts: usize,
102    pub final_cell_conflicts: usize,
103    pub total_cost: u64,
104    pub fallback_full_replan: bool,
105}
106
107/// Region-triggered MAPF replanner.
108#[derive(Debug, Clone, PartialEq)]
109pub struct HierarchicalMapfPlanner2D {
110    config: HierarchicalMapfConfig2D,
111    low_level: StlCbsPlanner,
112}
113
114impl HierarchicalMapfPlanner2D {
115    pub fn new(config: HierarchicalMapfConfig2D) -> RoboticsResult<Self> {
116        validate_config(&config)?;
117        let low_level = StlCbsPlanner::new(StlCbsConfig {
118            width: config.width,
119            height: config.height,
120            obstacle_map: config.obstacle_map.clone(),
121            max_time: config.max_time,
122            max_cbs_nodes: config.max_cbs_nodes,
123            allow_wait: true,
124        })?;
125        Ok(Self { config, low_level })
126    }
127
128    pub fn config(&self) -> &HierarchicalMapfConfig2D {
129        &self.config
130    }
131
132    pub fn region_of_cell(&self, x: i32, y: i32) -> RoboticsResult<HierarchicalMapfRegion2D> {
133        if !self.is_in_bounds(x, y) {
134            return Err(RoboticsError::InvalidParameter(
135                "hierarchical MAPF cell is out of bounds".to_string(),
136            ));
137        }
138        Ok(HierarchicalMapfRegion2D::new(
139            x / self.config.region_width,
140            y / self.config.region_height,
141        ))
142    }
143
144    pub fn region_route_for_path(
145        &self,
146        path: &StlCbsPath,
147    ) -> RoboticsResult<HierarchicalMapfRegionRoute2D> {
148        if path.waypoints.is_empty() {
149            return Err(RoboticsError::InvalidParameter(
150                "hierarchical MAPF path must contain waypoints".to_string(),
151            ));
152        }
153        let mut regions = Vec::new();
154        for waypoint in &path.waypoints {
155            let region = self.region_of_cell(waypoint.x, waypoint.y)?;
156            if regions.last().copied() != Some(region) {
157                regions.push(region);
158            }
159        }
160        Ok(HierarchicalMapfRegionRoute2D {
161            agent_id: path.agent_id,
162            regions,
163        })
164    }
165
166    pub fn region_conflicts(
167        &self,
168        paths: &[StlCbsPath],
169    ) -> RoboticsResult<Vec<HierarchicalMapfRegionConflict2D>> {
170        let mut conflicts = Vec::new();
171        for t in 0..=self.config.max_time {
172            let mut occupancy: HashMap<HierarchicalMapfRegion2D, BTreeSet<usize>> = HashMap::new();
173            for path in paths {
174                let position = path.position_at(t);
175                let region = self.region_of_cell(position.x, position.y)?;
176                occupancy.entry(region).or_default().insert(path.agent_id);
177            }
178            for (region, agents) in occupancy {
179                if agents.len() > 1 {
180                    conflicts.push(HierarchicalMapfRegionConflict2D {
181                        region,
182                        t,
183                        agent_ids: agents.into_iter().collect(),
184                    });
185                }
186            }
187        }
188        Ok(conflicts)
189    }
190
191    pub fn plan(
192        &self,
193        agents: &[HierarchicalMapfAgent2D],
194    ) -> RoboticsResult<HierarchicalMapfPlan2D> {
195        validate_agents(agents)?;
196        let cbs_agents = agents
197            .iter()
198            .copied()
199            .map(HierarchicalMapfAgent2D::as_stl_cbs_agent)
200            .collect::<Vec<_>>();
201        let independent_paths = self.low_level.plan_independent(&cbs_agents)?;
202        let region_routes = independent_paths
203            .iter()
204            .map(|path| self.region_route_for_path(path))
205            .collect::<RoboticsResult<Vec<_>>>()?;
206        let region_conflicts = self.region_conflicts(&independent_paths)?;
207        let independent_cell_conflicts =
208            cell_conflict_count(&independent_paths, self.config.max_time);
209
210        let mut paths = independent_paths.clone();
211        let mut replanned_groups = Vec::new();
212        let affected_components = connected_agent_components(&region_conflicts);
213
214        for component in affected_components {
215            if component.len() <= 1 {
216                continue;
217            }
218            let group_agents = cbs_agents
219                .iter()
220                .copied()
221                .filter(|agent| component.contains(&agent.id))
222                .collect::<Vec<_>>();
223            let group_plan = self.low_level.plan(&group_agents)?;
224            for replanned_path in &group_plan.paths {
225                if let Some(index) = paths
226                    .iter()
227                    .position(|path| path.agent_id == replanned_path.agent_id)
228                {
229                    paths[index] = replanned_path.clone();
230                }
231            }
232            replanned_groups.push(HierarchicalMapfReplannedGroup2D {
233                agent_ids: sorted_component(&component),
234                cbs_total_cost: group_plan.total_cost,
235                cbs_conflicts_resolved: group_plan.conflicts_resolved,
236            });
237        }
238
239        let mut fallback_full_replan = false;
240        let mut final_cell_conflicts = cell_conflict_count(&paths, self.config.max_time);
241        if final_cell_conflicts > 0 {
242            let full_plan = self.low_level.plan(&cbs_agents)?;
243            paths = full_plan.paths;
244            replanned_groups.push(HierarchicalMapfReplannedGroup2D {
245                agent_ids: agents.iter().map(|agent| agent.id).collect(),
246                cbs_total_cost: full_plan.total_cost,
247                cbs_conflicts_resolved: full_plan.conflicts_resolved,
248            });
249            fallback_full_replan = true;
250            final_cell_conflicts = cell_conflict_count(&paths, self.config.max_time);
251        }
252
253        Ok(HierarchicalMapfPlan2D {
254            total_cost: paths.iter().map(StlCbsPath::arrival_time).sum(),
255            paths,
256            independent_paths,
257            region_routes,
258            region_conflicts,
259            replanned_groups,
260            independent_cell_conflicts,
261            final_cell_conflicts,
262            fallback_full_replan,
263        })
264    }
265
266    fn is_in_bounds(&self, x: i32, y: i32) -> bool {
267        x >= 0 && y >= 0 && x < self.config.width && y < self.config.height
268    }
269}
270
271pub fn cell_conflict_count(paths: &[StlCbsPath], max_time: u64) -> usize {
272    let mut count = 0;
273    for t in 0..=max_time {
274        for i in 0..paths.len() {
275            for j in i + 1..paths.len() {
276                let a = paths[i].position_at(t);
277                let b = paths[j].position_at(t);
278                if a.x == b.x && a.y == b.y {
279                    count += 1;
280                }
281                if t > 0 {
282                    let a_prev = paths[i].position_at(t - 1);
283                    let b_prev = paths[j].position_at(t - 1);
284                    if a_prev.x == b.x && a_prev.y == b.y && b_prev.x == a.x && b_prev.y == a.y {
285                        count += 1;
286                    }
287                }
288            }
289        }
290    }
291    count
292}
293
294fn connected_agent_components(
295    conflicts: &[HierarchicalMapfRegionConflict2D],
296) -> Vec<HashSet<usize>> {
297    let mut adjacency: HashMap<usize, HashSet<usize>> = HashMap::new();
298    for conflict in conflicts {
299        for &agent in &conflict.agent_ids {
300            adjacency.entry(agent).or_default();
301        }
302        for i in 0..conflict.agent_ids.len() {
303            for j in i + 1..conflict.agent_ids.len() {
304                adjacency
305                    .entry(conflict.agent_ids[i])
306                    .or_default()
307                    .insert(conflict.agent_ids[j]);
308                adjacency
309                    .entry(conflict.agent_ids[j])
310                    .or_default()
311                    .insert(conflict.agent_ids[i]);
312            }
313        }
314    }
315
316    let mut visited = HashSet::new();
317    let mut components = Vec::new();
318    for &seed in adjacency.keys() {
319        if visited.contains(&seed) {
320            continue;
321        }
322        let mut component = HashSet::new();
323        let mut queue = VecDeque::new();
324        queue.push_back(seed);
325        visited.insert(seed);
326        while let Some(agent) = queue.pop_front() {
327            component.insert(agent);
328            if let Some(neighbors) = adjacency.get(&agent) {
329                for &neighbor in neighbors {
330                    if visited.insert(neighbor) {
331                        queue.push_back(neighbor);
332                    }
333                }
334            }
335        }
336        components.push(component);
337    }
338    components
339}
340
/// Returns the agent ids of `component` in ascending order.
fn sorted_component(component: &HashSet<usize>) -> Vec<usize> {
    let mut ids = Vec::with_capacity(component.len());
    ids.extend(component.iter().copied());
    // Equal integers are indistinguishable, so an unstable sort gives the exact result.
    ids.sort_unstable();
    ids
}
346
347fn validate_config(config: &HierarchicalMapfConfig2D) -> RoboticsResult<()> {
348    if config.width <= 0
349        || config.height <= 0
350        || config.region_width <= 0
351        || config.region_height <= 0
352    {
353        return Err(RoboticsError::InvalidParameter(
354            "hierarchical MAPF dimensions and region sizes must be positive".to_string(),
355        ));
356    }
357    if config.max_time == 0 || config.max_cbs_nodes == 0 {
358        return Err(RoboticsError::InvalidParameter(
359            "hierarchical MAPF max_time and max_cbs_nodes must be positive".to_string(),
360        ));
361    }
362    if config.obstacle_map.len() != config.width as usize {
363        return Err(RoboticsError::InvalidParameter(
364            "hierarchical MAPF obstacle_map x-dimension must match width".to_string(),
365        ));
366    }
367    for column in &config.obstacle_map {
368        if column.len() != config.height as usize {
369            return Err(RoboticsError::InvalidParameter(
370                "hierarchical MAPF obstacle_map y-dimension must match height".to_string(),
371            ));
372        }
373    }
374    Ok(())
375}
376
377fn validate_agents(agents: &[HierarchicalMapfAgent2D]) -> RoboticsResult<()> {
378    if agents.is_empty() {
379        return Err(RoboticsError::InvalidParameter(
380            "hierarchical MAPF requires at least one agent".to_string(),
381        ));
382    }
383    let mut ids = HashSet::new();
384    for agent in agents {
385        if !ids.insert(agent.id) {
386            return Err(RoboticsError::InvalidParameter(
387                "hierarchical MAPF agent ids must be unique".to_string(),
388            ));
389        }
390    }
391    Ok(())
392}
393
#[cfg(test)]
mod tests {
    use super::*;

    /// 12x8 obstacle-free grid tiled into 4x4 regions (3x2 regions), horizon 18.
    fn planner() -> HierarchicalMapfPlanner2D {
        HierarchicalMapfPlanner2D::new(HierarchicalMapfConfig2D {
            max_time: 18,
            ..HierarchicalMapfConfig2D::new(12, 8, 4, 4)
        })
        .unwrap()
    }

    /// Two head-on agents on row 3, one vertical crosser, and one short mover.
    fn agents() -> Vec<HierarchicalMapfAgent2D> {
        vec![
            HierarchicalMapfAgent2D::new(0, (0, 3), (11, 3)),
            HierarchicalMapfAgent2D::new(1, (11, 3), (0, 3)),
            HierarchicalMapfAgent2D::new(2, (5, 0), (5, 7)),
            HierarchicalMapfAgent2D::new(3, (0, 7), (3, 7)),
        ]
    }

    #[test]
    fn region_routes_are_compressed() {
        let planner = planner();
        let paths = planner
            .low_level
            .plan_independent(
                &agents()
                    .iter()
                    .copied()
                    .map(HierarchicalMapfAgent2D::as_stl_cbs_agent)
                    .collect::<Vec<_>>(),
            )
            .unwrap();
        let route = planner.region_route_for_path(&paths[0]).unwrap();

        assert_eq!(route.agent_id, 0);
        assert!(route.regions.len() < paths[0].waypoints.len());
        assert_eq!(route.regions[0], HierarchicalMapfRegion2D::new(0, 0));
    }

    #[test]
    fn hierarchical_replanning_resolves_cell_conflicts() {
        let planner = planner();
        let plan = planner.plan(&agents()).unwrap();

        assert!(plan.independent_cell_conflicts > 0);
        assert_eq!(plan.final_cell_conflicts, 0);
        assert!(!plan.replanned_groups.is_empty());
        assert!(!plan.region_conflicts.is_empty());
        assert_eq!(plan.paths.len(), agents().len());
    }

    #[test]
    fn region_of_cell_maps_cells_and_rejects_out_of_bounds() {
        let planner = planner();
        assert_eq!(
            planner.region_of_cell(5, 3).unwrap(),
            HierarchicalMapfRegion2D::new(1, 0)
        );
        assert_eq!(
            planner.region_of_cell(11, 7).unwrap(),
            HierarchicalMapfRegion2D::new(2, 1)
        );
        assert!(planner.region_of_cell(-1, 0).is_err());
        assert!(planner.region_of_cell(12, 0).is_err());
        assert!(planner.region_of_cell(0, 8).is_err());
    }

    #[test]
    fn single_agent_needs_no_replanning() {
        let planner = planner();
        let plan = planner
            .plan(&[HierarchicalMapfAgent2D::new(0, (0, 0), (3, 0))])
            .unwrap();

        assert_eq!(plan.paths.len(), 1);
        assert!(plan.region_conflicts.is_empty());
        assert!(plan.replanned_groups.is_empty());
        assert_eq!(plan.independent_cell_conflicts, 0);
        assert_eq!(plan.final_cell_conflicts, 0);
        assert!(!plan.fallback_full_replan);
        assert_eq!(plan.paths, plan.independent_paths);
    }

    #[test]
    fn rejects_empty_and_duplicate_agents() {
        let planner = planner();
        assert!(planner.plan(&[]).is_err());
        let duplicate = [
            HierarchicalMapfAgent2D::new(7, (0, 0), (1, 0)),
            HierarchicalMapfAgent2D::new(7, (0, 1), (1, 1)),
        ];
        assert!(planner.plan(&duplicate).is_err());
    }

    #[test]
    fn rejects_invalid_configuration() {
        assert!(HierarchicalMapfPlanner2D::new(HierarchicalMapfConfig2D::new(0, 8, 4, 4)).is_err());
        assert!(HierarchicalMapfPlanner2D::new(HierarchicalMapfConfig2D::new(12, 8, 0, 4)).is_err());

        let mut zero_nodes = HierarchicalMapfConfig2D::new(12, 8, 4, 4);
        zero_nodes.max_cbs_nodes = 0;
        assert!(HierarchicalMapfPlanner2D::new(zero_nodes).is_err());

        let mut short_x = HierarchicalMapfConfig2D::new(12, 8, 4, 4);
        short_x.obstacle_map.pop();
        assert!(HierarchicalMapfPlanner2D::new(short_x).is_err());

        let mut short_y = HierarchicalMapfConfig2D::new(12, 8, 4, 4);
        short_y.obstacle_map[0].pop();
        assert!(HierarchicalMapfPlanner2D::new(short_y).is_err());
    }

    #[test]
    fn cell_conflict_count_of_no_paths_is_zero() {
        assert_eq!(cell_conflict_count(&[], 10), 0);
    }
}