Skip to main content

rust_robotics_planning/
safe_decode_nav.rs

//! SafeDec-lite: STL-shielded constrained decoding for grid navigation.
//!
//! This reproduces the core idea of constrained decoding for safe navigation
//! policies without a learned model: a base navigation *policy* proposes a score
//! for each discrete action, and a *shield* expressed in Signal Temporal Logic
//! (STL) keeps the decoded action sequence safe.
//!
//! - The base policy here is a deterministic greedy goal-seeker — it always
//!   prefers the action that most reduces distance to the goal, which happily
//!   cuts straight through a hazard.
//! - The shield is a set of STL specifications reusing [`crate::stl_cbs`]
//!   primitives: an `always-avoid` geofence over a time interval (a hard
//!   constraint — candidates entering it are pruned) and an `eventually-reach`
//!   goal region (a soft reward that shapes the beam).
//! - [`SafeDecoder::decode`] runs a deterministic constrained beam search and
//!   returns both the greedy (unshielded) path and the shielded path, so the
//!   number of steps where the shield overrode the greedy choice — and the
//!   resulting robustness gain — is directly measurable.
//!
//! Everything is deterministic: ties are broken by action index and the beam is
//! sorted with a stable total order.
22
23use crate::stl_cbs::{
24    stl_always_avoid_robustness, stl_eventually_reach_robustness, StlCbsPath, StlRectangle2D,
25    StlTimeInterval, StlTimedCell,
26};
27use rust_robotics_core::{RoboticsError, RoboticsResult};
28
29/// A timed STL specification rectangle.
30#[derive(Debug, Clone, Copy, PartialEq)]
31pub struct TimedRegion {
32    pub region: StlRectangle2D,
33    pub interval: StlTimeInterval,
34}
35
36impl TimedRegion {
37    pub fn new(region: StlRectangle2D, interval: StlTimeInterval) -> Self {
38        Self { region, interval }
39    }
40}
41
42/// Configuration for the STL-shielded navigation decoder.
43#[derive(Debug, Clone, PartialEq)]
44pub struct SafeNavConfig {
45    pub width: i32,
46    pub height: i32,
47    /// Static obstacle map indexed `[x][y]`.
48    pub blocked: Vec<Vec<bool>>,
49    /// Goal cell the base policy seeks.
50    pub goal: (i32, i32),
51    /// Number of decoded steps.
52    pub horizon: usize,
53    /// Beam width retained at each decoding step.
54    pub beam_width: usize,
55    /// `eventually-reach` goal region and the interval it must be reached in.
56    pub reach: TimedRegion,
57    /// `always-avoid` geofences (hard shield).
58    pub avoid: Vec<TimedRegion>,
59    /// Weight on the base-policy score.
60    pub policy_weight: f64,
61    /// Weight on the eventually-reach shaping reward.
62    pub reach_weight: f64,
63    /// Clearance the shield keeps from every geofence (in cells). With `0.0` the
64    /// shield only prunes cells strictly inside a geofence; a positive margin
65    /// keeps the decoded path that far outside.
66    pub safety_margin: f64,
67    /// Allow diagonal moves in the action set.
68    pub diagonal: bool,
69}
70
71impl SafeNavConfig {
72    /// A bounded empty grid with the given goal and reach region.
73    pub fn new(
74        width: i32,
75        height: i32,
76        goal: (i32, i32),
77        reach: TimedRegion,
78    ) -> RoboticsResult<Self> {
79        if width <= 0 || height <= 0 {
80            return Err(RoboticsError::InvalidParameter(
81                "safe-nav grid dimensions must be positive".to_string(),
82            ));
83        }
84        Ok(Self {
85            width,
86            height,
87            blocked: vec![vec![false; height as usize]; width as usize],
88            goal,
89            horizon: 40,
90            beam_width: 16,
91            reach,
92            avoid: Vec::new(),
93            policy_weight: 1.0,
94            reach_weight: 0.5,
95            safety_margin: 0.0,
96            diagonal: true,
97        })
98    }
99
100    fn in_bounds(&self, x: i32, y: i32) -> bool {
101        x >= 0 && y >= 0 && x < self.width && y < self.height
102    }
103
104    fn is_blocked(&self, x: i32, y: i32) -> bool {
105        if !self.in_bounds(x, y) {
106            return true;
107        }
108        self.blocked[x as usize][y as usize]
109    }
110}
111
112/// Result of a shielded decode, paired with the greedy baseline.
113#[derive(Debug, Clone, PartialEq)]
114pub struct SafeDecodePlan {
115    /// Shielded (STL-constrained) decoded path.
116    pub shielded_path: Vec<StlTimedCell>,
117    /// Greedy, unshielded decoded path (base policy argmax).
118    pub greedy_path: Vec<StlTimedCell>,
119    /// Eventually-reach robustness of the shielded path (>= 0 means satisfied).
120    pub reach_robustness: f64,
121    /// Worst always-avoid robustness of the shielded path (>= 0 means safe).
122    pub avoid_robustness: f64,
123    /// Eventually-reach robustness of the greedy path.
124    pub greedy_reach_robustness: f64,
125    /// Worst always-avoid robustness of the greedy path (< 0 means it cut
126    /// through a geofence).
127    pub greedy_avoid_robustness: f64,
128    /// Number of steps where the shielded action differs from the greedy action.
129    pub interventions: usize,
130    /// Whether the shielded path satisfies the reach spec.
131    pub reach_satisfied: bool,
132    /// Whether the shielded path satisfies every avoid spec.
133    pub avoid_satisfied: bool,
134}
135
136#[derive(Debug, Clone)]
137struct Beam {
138    path: Vec<StlTimedCell>,
139    score: f64,
140}
141
142/// STL-shielded constrained decoder over a grid navigation policy.
143#[derive(Debug, Clone)]
144pub struct SafeDecoder {
145    config: SafeNavConfig,
146}
147
148impl SafeDecoder {
149    pub fn new(config: SafeNavConfig) -> RoboticsResult<Self> {
150        if config.width <= 0 || config.height <= 0 {
151            return Err(RoboticsError::InvalidParameter(
152                "safe-nav grid dimensions must be positive".to_string(),
153            ));
154        }
155        if config.horizon == 0 {
156            return Err(RoboticsError::InvalidParameter(
157                "safe-nav horizon must be positive".to_string(),
158            ));
159        }
160        if config.beam_width == 0 {
161            return Err(RoboticsError::InvalidParameter(
162                "safe-nav beam width must be positive".to_string(),
163            ));
164        }
165        if config.blocked.len() != config.width as usize
166            || config
167                .blocked
168                .iter()
169                .any(|c| c.len() != config.height as usize)
170        {
171            return Err(RoboticsError::InvalidParameter(
172                "safe-nav blocked map must match grid dimensions".to_string(),
173            ));
174        }
175        Ok(Self { config })
176    }
177
178    pub fn config(&self) -> &SafeNavConfig {
179        &self.config
180    }
181
182    /// Action set: 4- or 8-connected moves plus a wait.
183    fn actions(&self) -> Vec<(i32, i32)> {
184        if self.config.diagonal {
185            vec![
186                (0, 0),
187                (1, 0),
188                (-1, 0),
189                (0, 1),
190                (0, -1),
191                (1, 1),
192                (1, -1),
193                (-1, 1),
194                (-1, -1),
195            ]
196        } else {
197            vec![(0, 0), (1, 0), (-1, 0), (0, 1), (0, -1)]
198        }
199    }
200
201    /// Base-policy score for moving to `(x, y)`: higher is better. Greedy
202    /// goal-seeking — the negative Euclidean distance to the goal.
203    fn policy_score(&self, x: i32, y: i32) -> f64 {
204        let dx = (x - self.config.goal.0) as f64;
205        let dy = (y - self.config.goal.1) as f64;
206        -(dx * dx + dy * dy).sqrt()
207    }
208
209    /// Whether stepping to `(x, y)` at time `t` violates any active geofence.
210    fn violates_geofence(&self, x: i32, y: i32, t: u64) -> bool {
211        self.config.avoid.iter().any(|spec| {
212            t >= spec.interval.start
213                && t <= spec.interval.end
214                && spec.region.inside_robustness(x as f64, y as f64) > -self.config.safety_margin
215        })
216    }
217
218    /// Greedy unshielded decode: argmax base policy each step (ties by action
219    /// index), respecting only the grid bounds and static obstacles.
220    pub fn greedy_decode(&self, start: (i32, i32)) -> Vec<StlTimedCell> {
221        let mut path = vec![StlTimedCell::new(start.0, start.1, 0)];
222        let mut current = start;
223        for step in 0..self.config.horizon {
224            let t = (step + 1) as u64;
225            let mut best: Option<(f64, (i32, i32))> = None;
226            for (dx, dy) in self.actions() {
227                let nx = current.0 + dx;
228                let ny = current.1 + dy;
229                if self.config.is_blocked(nx, ny) {
230                    continue;
231                }
232                let score = self.policy_score(nx, ny);
233                if best.is_none() || score > best.unwrap().0 {
234                    best = Some((score, (nx, ny)));
235                }
236            }
237            let Some((_, next)) = best else { break };
238            current = next;
239            path.push(StlTimedCell::new(next.0, next.1, t));
240        }
241        path
242    }
243
244    /// STL-shielded decode via deterministic constrained beam search.
245    pub fn decode(&self, start: (i32, i32)) -> RoboticsResult<SafeDecodePlan> {
246        if self.config.is_blocked(start.0, start.1) {
247            return Err(RoboticsError::InvalidParameter(
248                "safe-nav start cell is blocked or out of bounds".to_string(),
249            ));
250        }
251
252        let mut beams = vec![Beam {
253            path: vec![StlTimedCell::new(start.0, start.1, 0)],
254            score: 0.0,
255        }];
256
257        for step in 0..self.config.horizon {
258            let t = (step + 1) as u64;
259            let mut next_beams: Vec<Beam> = Vec::new();
260            for beam in &beams {
261                let last = *beam.path.last().unwrap();
262                for (dx, dy) in self.actions() {
263                    let nx = last.x + dx;
264                    let ny = last.y + dy;
265                    if self.config.is_blocked(nx, ny) {
266                        continue;
267                    }
268                    // Hard shield: prune candidates entering an active geofence.
269                    if self.violates_geofence(nx, ny, t) {
270                        continue;
271                    }
272                    let reach_shaping = self
273                        .config
274                        .reach
275                        .region
276                        .inside_robustness(nx as f64, ny as f64);
277                    let step_score = self.config.policy_weight * self.policy_score(nx, ny)
278                        + self.config.reach_weight * reach_shaping;
279                    let mut path = beam.path.clone();
280                    path.push(StlTimedCell::new(nx, ny, t));
281                    next_beams.push(Beam {
282                        path,
283                        score: beam.score + step_score,
284                    });
285                }
286            }
287
288            if next_beams.is_empty() {
289                // Shield pruned everything; keep the previous beam (wait in place
290                // is always an option unless the cell itself became unsafe).
291                break;
292            }
293
294            // Deterministic prune: sort by score desc, then by path tail for ties.
295            next_beams.sort_by(|a, b| {
296                b.score
297                    .partial_cmp(&a.score)
298                    .unwrap_or(std::cmp::Ordering::Equal)
299                    .then_with(|| path_key(&a.path).cmp(&path_key(&b.path)))
300            });
301            next_beams.truncate(self.config.beam_width);
302            beams = next_beams;
303        }
304
305        // Choose the best beam that satisfies eventually-reach, else best score.
306        let mut best_satisfying: Option<&Beam> = None;
307        let mut best_any: Option<&Beam> = None;
308        for beam in &beams {
309            let reach = self.reach_robustness(&beam.path)?;
310            if best_any.is_none() || beam.score > best_any.unwrap().score {
311                best_any = Some(beam);
312            }
313            if reach >= 0.0
314                && (best_satisfying.is_none() || beam.score > best_satisfying.unwrap().score)
315            {
316                best_satisfying = Some(beam);
317            }
318        }
319        let chosen = best_satisfying.or(best_any).unwrap().clone();
320
321        let greedy_path = self.greedy_decode(start);
322        let reach_robustness = self.reach_robustness(&chosen.path)?;
323        let avoid_robustness = self.avoid_robustness(&chosen.path)?;
324        let greedy_reach_robustness = self.reach_robustness(&greedy_path)?;
325        let greedy_avoid_robustness = self.avoid_robustness(&greedy_path)?;
326        let interventions = count_interventions(&chosen.path, &greedy_path);
327
328        Ok(SafeDecodePlan {
329            shielded_path: chosen.path,
330            greedy_path,
331            reach_robustness,
332            avoid_robustness,
333            greedy_reach_robustness,
334            greedy_avoid_robustness,
335            interventions,
336            reach_satisfied: reach_robustness >= 0.0,
337            avoid_satisfied: avoid_robustness >= 0.0,
338        })
339    }
340
341    fn reach_robustness(&self, path: &[StlTimedCell]) -> RoboticsResult<f64> {
342        let wrapped = StlCbsPath {
343            agent_id: 0,
344            waypoints: path.to_vec(),
345        };
346        stl_eventually_reach_robustness(
347            &wrapped,
348            self.config.reach.region,
349            self.config.reach.interval,
350        )
351    }
352
353    fn avoid_robustness(&self, path: &[StlTimedCell]) -> RoboticsResult<f64> {
354        if self.config.avoid.is_empty() {
355            return Ok(f64::INFINITY);
356        }
357        let wrapped = StlCbsPath {
358            agent_id: 0,
359            waypoints: path.to_vec(),
360        };
361        let mut worst = f64::INFINITY;
362        for spec in &self.config.avoid {
363            let r = stl_always_avoid_robustness(&wrapped, spec.region, spec.interval)?;
364            worst = worst.min(r);
365        }
366        Ok(worst)
367    }
368}
369
370/// Integer key for deterministic tie-breaking on a path's cells.
371fn path_key(path: &[StlTimedCell]) -> Vec<(i32, i32)> {
372    path.iter().map(|c| (c.x, c.y)).collect()
373}
374
375/// Count time steps where two paths occupy different cells.
376fn count_interventions(a: &[StlTimedCell], b: &[StlTimedCell]) -> usize {
377    let n = a.len().min(b.len());
378    let mut count = 0;
379    for i in 0..n {
380        if a[i].x != b[i].x || a[i].y != b[i].y {
381            count += 1;
382        }
383    }
384    count + a.len().abs_diff(b.len())
385}
386
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a rectangle given as `(min_x, max_x, min_y, max_y)`;
    /// panics if construction fails.
    fn rect(min_x: f64, max_x: f64, min_y: f64, max_y: f64) -> StlRectangle2D {
        let built = StlRectangle2D::new(min_x, max_x, min_y, max_y);
        built.unwrap()
    }

    /// Shorthand for a time interval from `a` to `b`; panics if construction
    /// fails.
    fn interval(a: u64, b: u64) -> StlTimeInterval {
        let built = StlTimeInterval::new(a, b);
        built.unwrap()
    }

    /// Corridor scenario: the goal lies straight ahead and a hazard band
    /// (centered on y=0, between x=4 and x=6) sits on the straight-line
    /// path, so greedy decoding cuts through it.
    fn corridor() -> SafeNavConfig {
        let reach = TimedRegion::new(rect(9.0, 11.0, -1.0, 1.0), interval(0, 40));
        let hazard = TimedRegion::new(rect(4.0, 6.0, -1.5, 1.5), interval(0, 40));
        let mut config = SafeNavConfig::new(13, 9, (10, 0), reach).unwrap();
        config.avoid = vec![hazard];
        config.horizon = 30;
        config
    }

    /// Decodes the corridor scenario from the origin.
    fn corridor_plan() -> SafeDecodePlan {
        SafeDecoder::new(corridor())
            .unwrap()
            .decode((0, 0))
            .unwrap()
    }

    /// Hazard-free 7x7 grid with the goal at (5, 0).
    fn seven_by_seven() -> SafeNavConfig {
        let reach = TimedRegion::new(rect(4.0, 6.0, -1.0, 1.0), interval(0, 20));
        SafeNavConfig::new(7, 7, (5, 0), reach).unwrap()
    }

    #[test]
    fn greedy_cuts_through_hazard() {
        let plan = corridor_plan();
        // Without the shield, the greedy policy heads straight into the geofence.
        assert!(
            plan.greedy_avoid_robustness < 0.0,
            "greedy avoid robustness {} should be negative",
            plan.greedy_avoid_robustness
        );
    }

    #[test]
    fn shield_keeps_path_safe_and_reaches_goal() {
        let plan = corridor_plan();
        assert!(
            plan.avoid_satisfied,
            "shield avoid robustness {}",
            plan.avoid_robustness
        );
        assert!(plan.avoid_robustness >= 0.0);
        assert!(
            plan.reach_satisfied,
            "shield reach robustness {}",
            plan.reach_robustness
        );
        // Staying safe must have cost at least one deviation from greedy.
        assert!(plan.interventions > 0, "no interventions");
    }

    #[test]
    fn decode_is_deterministic() {
        let decoder = SafeDecoder::new(corridor()).unwrap();
        let runs = [
            decoder.decode((0, 0)).unwrap(),
            decoder.decode((0, 0)).unwrap(),
        ];
        assert_eq!(runs[0].shielded_path, runs[1].shielded_path);
        assert_eq!(runs[0].interventions, runs[1].interventions);
    }

    #[test]
    fn no_hazard_means_no_intervention() {
        let reach = TimedRegion::new(rect(7.0, 9.0, -1.0, 1.0), interval(0, 30));
        let config = SafeNavConfig::new(11, 7, (8, 0), reach).unwrap();
        let plan = SafeDecoder::new(config).unwrap().decode((0, 0)).unwrap();
        // Nothing to avoid, so the shield never has to overrule greedy.
        assert_eq!(plan.interventions, 0);
        assert!(plan.reach_satisfied);
        assert!(plan.avoid_robustness.is_infinite());
    }

    #[test]
    fn rejects_blocked_start() {
        let mut config = seven_by_seven();
        config.blocked[0][0] = true;
        let decoder = SafeDecoder::new(config).unwrap();
        assert!(decoder.decode((0, 0)).is_err());
    }

    #[test]
    fn rejects_mismatched_blocked_map() {
        let mut config = seven_by_seven();
        // Dropping a column leaves the map one short of `width`.
        config.blocked.pop();
        assert!(SafeDecoder::new(config).is_err());
    }
}