1use crate::stl_cbs::{
24 stl_always_avoid_robustness, stl_eventually_reach_robustness, StlCbsPath, StlRectangle2D,
25 StlTimeInterval, StlTimedCell,
26};
27use rust_robotics_core::{RoboticsError, RoboticsResult};
28
29#[derive(Debug, Clone, Copy, PartialEq)]
31pub struct TimedRegion {
32 pub region: StlRectangle2D,
33 pub interval: StlTimeInterval,
34}
35
36impl TimedRegion {
37 pub fn new(region: StlRectangle2D, interval: StlTimeInterval) -> Self {
38 Self { region, interval }
39 }
40}
41
42#[derive(Debug, Clone, PartialEq)]
44pub struct SafeNavConfig {
45 pub width: i32,
46 pub height: i32,
47 pub blocked: Vec<Vec<bool>>,
49 pub goal: (i32, i32),
51 pub horizon: usize,
53 pub beam_width: usize,
55 pub reach: TimedRegion,
57 pub avoid: Vec<TimedRegion>,
59 pub policy_weight: f64,
61 pub reach_weight: f64,
63 pub safety_margin: f64,
67 pub diagonal: bool,
69}
70
71impl SafeNavConfig {
72 pub fn new(
74 width: i32,
75 height: i32,
76 goal: (i32, i32),
77 reach: TimedRegion,
78 ) -> RoboticsResult<Self> {
79 if width <= 0 || height <= 0 {
80 return Err(RoboticsError::InvalidParameter(
81 "safe-nav grid dimensions must be positive".to_string(),
82 ));
83 }
84 Ok(Self {
85 width,
86 height,
87 blocked: vec![vec![false; height as usize]; width as usize],
88 goal,
89 horizon: 40,
90 beam_width: 16,
91 reach,
92 avoid: Vec::new(),
93 policy_weight: 1.0,
94 reach_weight: 0.5,
95 safety_margin: 0.0,
96 diagonal: true,
97 })
98 }
99
100 fn in_bounds(&self, x: i32, y: i32) -> bool {
101 x >= 0 && y >= 0 && x < self.width && y < self.height
102 }
103
104 fn is_blocked(&self, x: i32, y: i32) -> bool {
105 if !self.in_bounds(x, y) {
106 return true;
107 }
108 self.blocked[x as usize][y as usize]
109 }
110}
111
112#[derive(Debug, Clone, PartialEq)]
114pub struct SafeDecodePlan {
115 pub shielded_path: Vec<StlTimedCell>,
117 pub greedy_path: Vec<StlTimedCell>,
119 pub reach_robustness: f64,
121 pub avoid_robustness: f64,
123 pub greedy_reach_robustness: f64,
125 pub greedy_avoid_robustness: f64,
128 pub interventions: usize,
130 pub reach_satisfied: bool,
132 pub avoid_satisfied: bool,
134}
135
136#[derive(Debug, Clone)]
137struct Beam {
138 path: Vec<StlTimedCell>,
139 score: f64,
140}
141
142#[derive(Debug, Clone)]
144pub struct SafeDecoder {
145 config: SafeNavConfig,
146}
147
148impl SafeDecoder {
149 pub fn new(config: SafeNavConfig) -> RoboticsResult<Self> {
150 if config.width <= 0 || config.height <= 0 {
151 return Err(RoboticsError::InvalidParameter(
152 "safe-nav grid dimensions must be positive".to_string(),
153 ));
154 }
155 if config.horizon == 0 {
156 return Err(RoboticsError::InvalidParameter(
157 "safe-nav horizon must be positive".to_string(),
158 ));
159 }
160 if config.beam_width == 0 {
161 return Err(RoboticsError::InvalidParameter(
162 "safe-nav beam width must be positive".to_string(),
163 ));
164 }
165 if config.blocked.len() != config.width as usize
166 || config
167 .blocked
168 .iter()
169 .any(|c| c.len() != config.height as usize)
170 {
171 return Err(RoboticsError::InvalidParameter(
172 "safe-nav blocked map must match grid dimensions".to_string(),
173 ));
174 }
175 Ok(Self { config })
176 }
177
178 pub fn config(&self) -> &SafeNavConfig {
179 &self.config
180 }
181
182 fn actions(&self) -> Vec<(i32, i32)> {
184 if self.config.diagonal {
185 vec![
186 (0, 0),
187 (1, 0),
188 (-1, 0),
189 (0, 1),
190 (0, -1),
191 (1, 1),
192 (1, -1),
193 (-1, 1),
194 (-1, -1),
195 ]
196 } else {
197 vec![(0, 0), (1, 0), (-1, 0), (0, 1), (0, -1)]
198 }
199 }
200
201 fn policy_score(&self, x: i32, y: i32) -> f64 {
204 let dx = (x - self.config.goal.0) as f64;
205 let dy = (y - self.config.goal.1) as f64;
206 -(dx * dx + dy * dy).sqrt()
207 }
208
209 fn violates_geofence(&self, x: i32, y: i32, t: u64) -> bool {
211 self.config.avoid.iter().any(|spec| {
212 t >= spec.interval.start
213 && t <= spec.interval.end
214 && spec.region.inside_robustness(x as f64, y as f64) > -self.config.safety_margin
215 })
216 }
217
218 pub fn greedy_decode(&self, start: (i32, i32)) -> Vec<StlTimedCell> {
221 let mut path = vec![StlTimedCell::new(start.0, start.1, 0)];
222 let mut current = start;
223 for step in 0..self.config.horizon {
224 let t = (step + 1) as u64;
225 let mut best: Option<(f64, (i32, i32))> = None;
226 for (dx, dy) in self.actions() {
227 let nx = current.0 + dx;
228 let ny = current.1 + dy;
229 if self.config.is_blocked(nx, ny) {
230 continue;
231 }
232 let score = self.policy_score(nx, ny);
233 if best.is_none() || score > best.unwrap().0 {
234 best = Some((score, (nx, ny)));
235 }
236 }
237 let Some((_, next)) = best else { break };
238 current = next;
239 path.push(StlTimedCell::new(next.0, next.1, t));
240 }
241 path
242 }
243
244 pub fn decode(&self, start: (i32, i32)) -> RoboticsResult<SafeDecodePlan> {
246 if self.config.is_blocked(start.0, start.1) {
247 return Err(RoboticsError::InvalidParameter(
248 "safe-nav start cell is blocked or out of bounds".to_string(),
249 ));
250 }
251
252 let mut beams = vec![Beam {
253 path: vec![StlTimedCell::new(start.0, start.1, 0)],
254 score: 0.0,
255 }];
256
257 for step in 0..self.config.horizon {
258 let t = (step + 1) as u64;
259 let mut next_beams: Vec<Beam> = Vec::new();
260 for beam in &beams {
261 let last = *beam.path.last().unwrap();
262 for (dx, dy) in self.actions() {
263 let nx = last.x + dx;
264 let ny = last.y + dy;
265 if self.config.is_blocked(nx, ny) {
266 continue;
267 }
268 if self.violates_geofence(nx, ny, t) {
270 continue;
271 }
272 let reach_shaping = self
273 .config
274 .reach
275 .region
276 .inside_robustness(nx as f64, ny as f64);
277 let step_score = self.config.policy_weight * self.policy_score(nx, ny)
278 + self.config.reach_weight * reach_shaping;
279 let mut path = beam.path.clone();
280 path.push(StlTimedCell::new(nx, ny, t));
281 next_beams.push(Beam {
282 path,
283 score: beam.score + step_score,
284 });
285 }
286 }
287
288 if next_beams.is_empty() {
289 break;
292 }
293
294 next_beams.sort_by(|a, b| {
296 b.score
297 .partial_cmp(&a.score)
298 .unwrap_or(std::cmp::Ordering::Equal)
299 .then_with(|| path_key(&a.path).cmp(&path_key(&b.path)))
300 });
301 next_beams.truncate(self.config.beam_width);
302 beams = next_beams;
303 }
304
305 let mut best_satisfying: Option<&Beam> = None;
307 let mut best_any: Option<&Beam> = None;
308 for beam in &beams {
309 let reach = self.reach_robustness(&beam.path)?;
310 if best_any.is_none() || beam.score > best_any.unwrap().score {
311 best_any = Some(beam);
312 }
313 if reach >= 0.0
314 && (best_satisfying.is_none() || beam.score > best_satisfying.unwrap().score)
315 {
316 best_satisfying = Some(beam);
317 }
318 }
319 let chosen = best_satisfying.or(best_any).unwrap().clone();
320
321 let greedy_path = self.greedy_decode(start);
322 let reach_robustness = self.reach_robustness(&chosen.path)?;
323 let avoid_robustness = self.avoid_robustness(&chosen.path)?;
324 let greedy_reach_robustness = self.reach_robustness(&greedy_path)?;
325 let greedy_avoid_robustness = self.avoid_robustness(&greedy_path)?;
326 let interventions = count_interventions(&chosen.path, &greedy_path);
327
328 Ok(SafeDecodePlan {
329 shielded_path: chosen.path,
330 greedy_path,
331 reach_robustness,
332 avoid_robustness,
333 greedy_reach_robustness,
334 greedy_avoid_robustness,
335 interventions,
336 reach_satisfied: reach_robustness >= 0.0,
337 avoid_satisfied: avoid_robustness >= 0.0,
338 })
339 }
340
341 fn reach_robustness(&self, path: &[StlTimedCell]) -> RoboticsResult<f64> {
342 let wrapped = StlCbsPath {
343 agent_id: 0,
344 waypoints: path.to_vec(),
345 };
346 stl_eventually_reach_robustness(
347 &wrapped,
348 self.config.reach.region,
349 self.config.reach.interval,
350 )
351 }
352
353 fn avoid_robustness(&self, path: &[StlTimedCell]) -> RoboticsResult<f64> {
354 if self.config.avoid.is_empty() {
355 return Ok(f64::INFINITY);
356 }
357 let wrapped = StlCbsPath {
358 agent_id: 0,
359 waypoints: path.to_vec(),
360 };
361 let mut worst = f64::INFINITY;
362 for spec in &self.config.avoid {
363 let r = stl_always_avoid_robustness(&wrapped, spec.region, spec.interval)?;
364 worst = worst.min(r);
365 }
366 Ok(worst)
367 }
368}
369
370fn path_key(path: &[StlTimedCell]) -> Vec<(i32, i32)> {
372 path.iter().map(|c| (c.x, c.y)).collect()
373}
374
375fn count_interventions(a: &[StlTimedCell], b: &[StlTimedCell]) -> usize {
377 let n = a.len().min(b.len());
378 let mut count = 0;
379 for i in 0..n {
380 if a[i].x != b[i].x || a[i].y != b[i].y {
381 count += 1;
382 }
383 }
384 count + a.len().abs_diff(b.len())
385}
386
#[cfg(test)]
mod tests {
    use super::*;

    /// Axis-aligned rectangle; panics on invalid bounds.
    fn rect(min_x: f64, max_x: f64, min_y: f64, max_y: f64) -> StlRectangle2D {
        StlRectangle2D::new(min_x, max_x, min_y, max_y).unwrap()
    }

    /// Time window `[a, b]`; panics on an invalid window.
    fn interval(a: u64, b: u64) -> StlTimeInterval {
        StlTimeInterval::new(a, b).unwrap()
    }

    /// Builds a 13x9 open grid with goal (10, 0) and reach box x in [9, 11],
    /// y in [-1, 1].
    ///
    /// A hazard box x in [4, 6], y in [-1.5, 1.5] sits on the straight line
    /// from the origin and is active for t in [0, 40]. The horizon is 30 steps.
    fn corridor() -> SafeNavConfig {
        let reach = TimedRegion::new(rect(9.0, 11.0, -1.0, 1.0), interval(0, 40));
        let mut config = SafeNavConfig::new(13, 9, (10, 0), reach).unwrap();
        config
            .avoid
            .push(TimedRegion::new(rect(4.0, 6.0, -1.5, 1.5), interval(0, 40)));
        config.horizon = 30;
        config
    }

    /// Decodes `config` from the origin, panicking on any error.
    fn decode_from_origin(config: SafeNavConfig) -> SafeDecodePlan {
        SafeDecoder::new(config).unwrap().decode((0, 0)).unwrap()
    }

    #[test]
    fn greedy_cuts_through_hazard() {
        let plan = decode_from_origin(corridor());
        assert!(
            plan.greedy_avoid_robustness < 0.0,
            "greedy avoid robustness {} should be negative",
            plan.greedy_avoid_robustness
        );
    }

    #[test]
    fn shield_keeps_path_safe_and_reaches_goal() {
        let plan = decode_from_origin(corridor());
        assert!(
            plan.avoid_satisfied,
            "shield avoid robustness {}",
            plan.avoid_robustness
        );
        assert!(plan.avoid_robustness >= 0.0);
        assert!(
            plan.reach_satisfied,
            "shield reach robustness {}",
            plan.reach_robustness
        );
        assert!(plan.interventions > 0, "no interventions");
    }

    #[test]
    fn decode_is_deterministic() {
        let decoder = SafeDecoder::new(corridor()).unwrap();
        let (first, second) = (
            decoder.decode((0, 0)).unwrap(),
            decoder.decode((0, 0)).unwrap(),
        );
        assert_eq!(first.shielded_path, second.shielded_path);
        assert_eq!(first.interventions, second.interventions);
    }

    #[test]
    fn no_hazard_means_no_intervention() {
        let reach = TimedRegion::new(rect(7.0, 9.0, -1.0, 1.0), interval(0, 30));
        let plan = decode_from_origin(SafeNavConfig::new(11, 7, (8, 0), reach).unwrap());
        assert_eq!(plan.interventions, 0);
        assert!(plan.reach_satisfied);
        assert!(plan.avoid_robustness.is_infinite());
    }

    #[test]
    fn rejects_blocked_start() {
        let reach = TimedRegion::new(rect(4.0, 6.0, -1.0, 1.0), interval(0, 20));
        let mut config = SafeNavConfig::new(7, 7, (5, 0), reach).unwrap();
        config.blocked[0][0] = true;
        let decoder = SafeDecoder::new(config).unwrap();
        assert!(decoder.decode((0, 0)).is_err());
    }

    #[test]
    fn rejects_mismatched_blocked_map() {
        let reach = TimedRegion::new(rect(4.0, 6.0, -1.0, 1.0), interval(0, 20));
        let mut config = SafeNavConfig::new(7, 7, (5, 0), reach).unwrap();
        config.blocked.pop();
        assert!(SafeDecoder::new(config).is_err());
    }
}