Skip to main content

rust_robotics_planning/
branchout_multimodal.rs

1//! BranchOut-lite multimodal driving planner.
2//!
3//! This is a deterministic 2-D reproduction slice of BranchOut's core
4//! multimodal planning idea: emit multiple plausible driving trajectories with
5//! mixture weights, then evaluate distributional coverage instead of only a
6//! single ground-truth path.
7
8use rust_robotics_core::{RoboticsError, RoboticsResult};
9
/// Numerical tolerance used for near-zero comparisons and likelihood floors.
const EPS: f64 = 1.0e-9;
11
/// Ego state on a lane-aligned 2-D road: position plus forward speed.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BranchOutPose2D {
    /// Longitudinal position along the road.
    pub x: f64,
    /// Lateral position; lane `k` is centred at `k * lane_width`.
    pub y: f64,
    /// Longitudinal speed (the planners advance `x` by `speed * dt`).
    pub speed: f64,
}

impl BranchOutPose2D {
    /// Builds a pose from its position and speed.
    pub fn new(x: f64, y: f64, speed: f64) -> Self {
        BranchOutPose2D { x, y, speed }
    }
}
25
/// Circular traffic or road obstacle, modelled as a disk.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BranchOutObstacle2D {
    /// Longitudinal position of the disk centre.
    pub x: f64,
    /// Lateral position of the disk centre.
    pub y: f64,
    /// Disk radius; clearance is measured from this boundary.
    pub radius: f64,
}

impl BranchOutObstacle2D {
    /// Builds an obstacle disk centred at `(x, y)` with the given `radius`.
    pub fn new(x: f64, y: f64, radius: f64) -> Self {
        BranchOutObstacle2D { x, y, radius }
    }
}
39
/// Coarse driving command/mode used by the compact GMM-like head.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BranchOutDecisionMode2D {
    /// Stay in the current lane at the scene's desired speed.
    KeepLane,
    /// Stay in the current lane and slow to a stop behind the nearest obstacle ahead.
    Yield,
    /// Move one lane toward +y, bounded by the outermost lane.
    LaneChangeLeft,
    /// Move one lane toward -y, bounded by the outermost lane.
    LaneChangeRight,
}

impl BranchOutDecisionMode2D {
    /// Returns a stable kebab-case label for this mode.
    pub fn label(self) -> &'static str {
        let text = match self {
            BranchOutDecisionMode2D::LaneChangeRight => "lane-change-right",
            BranchOutDecisionMode2D::LaneChangeLeft => "lane-change-left",
            BranchOutDecisionMode2D::Yield => "yield",
            BranchOutDecisionMode2D::KeepLane => "keep-lane",
        };
        text
    }
}
59
60/// Lane-level driving scene for a compact multimodal planner.
61#[derive(Debug, Clone, PartialEq)]
62pub struct BranchOutDrivingScene2D {
63    pub start: BranchOutPose2D,
64    pub lane_width: f64,
65    pub lane_count_each_side: i32,
66    pub route_length: f64,
67    pub desired_speed: f64,
68    pub obstacles: Vec<BranchOutObstacle2D>,
69}
70
71impl BranchOutDrivingScene2D {
72    pub fn simple_overtake() -> Self {
73        Self {
74            start: BranchOutPose2D::new(0.0, 0.0, 2.2),
75            lane_width: 1.2,
76            lane_count_each_side: 1,
77            route_length: 9.0,
78            desired_speed: 2.2,
79            obstacles: vec![BranchOutObstacle2D::new(4.1, 0.0, 0.42)],
80        }
81    }
82
83    /// Overtake scene with enough lateral room that a lane change clears the
84    /// stalled obstacle (the closed-loop planner prefers to overtake here).
85    pub fn wide_overtake() -> Self {
86        Self {
87            start: BranchOutPose2D::new(0.0, 0.0, 2.2),
88            lane_width: 1.6,
89            lane_count_each_side: 1,
90            route_length: 9.0,
91            desired_speed: 2.2,
92            obstacles: vec![BranchOutObstacle2D::new(4.1, 0.0, 0.42)],
93        }
94    }
95
96    /// Single-lane blocked scene with no room to pass, so the closed-loop
97    /// planner must yield behind the obstacle rather than overtake.
98    pub fn forced_yield() -> Self {
99        Self {
100            start: BranchOutPose2D::new(0.0, 0.0, 2.2),
101            lane_width: 1.2,
102            lane_count_each_side: 0,
103            route_length: 9.0,
104            desired_speed: 2.2,
105            obstacles: vec![BranchOutObstacle2D::new(4.1, 0.0, 0.42)],
106        }
107    }
108
109    pub fn lane_center(&self, lane_index: i32) -> f64 {
110        lane_index as f64 * self.lane_width
111    }
112
113    pub fn nearest_lane_index(&self, y: f64) -> i32 {
114        (y / self.lane_width).round().clamp(
115            -(self.lane_count_each_side as f64),
116            self.lane_count_each_side as f64,
117        ) as i32
118    }
119}
120
121/// Planner configuration for BranchOut-lite.
122#[derive(Debug, Clone, PartialEq)]
123pub struct BranchOutPlannerConfig2D {
124    pub horizon_steps: usize,
125    pub dt: f64,
126    pub ego_radius: f64,
127    pub probability_temperature: f64,
128    pub progress_weight: f64,
129    pub collision_weight: f64,
130    pub lane_weight: f64,
131    pub comfort_weight: f64,
132    pub route_weight: f64,
133    pub modes: Vec<BranchOutDecisionMode2D>,
134}
135
136impl Default for BranchOutPlannerConfig2D {
137    fn default() -> Self {
138        Self {
139            horizon_steps: 28,
140            dt: 0.12,
141            ego_radius: 0.32,
142            probability_temperature: 4.0,
143            progress_weight: 1.4,
144            collision_weight: 80.0,
145            lane_weight: 12.0,
146            comfort_weight: 0.35,
147            route_weight: 0.12,
148            modes: vec![
149                BranchOutDecisionMode2D::KeepLane,
150                BranchOutDecisionMode2D::Yield,
151                BranchOutDecisionMode2D::LaneChangeLeft,
152                BranchOutDecisionMode2D::LaneChangeRight,
153            ],
154        }
155    }
156}
157
158/// One mode trajectory with a GMM-like mixture probability.
159#[derive(Debug, Clone, PartialEq)]
160pub struct BranchOutTrajectory2D {
161    pub mode: BranchOutDecisionMode2D,
162    pub probability: f64,
163    pub cost: f64,
164    pub poses: Vec<BranchOutPose2D>,
165    pub collision_risk: f64,
166    pub comfort_cost: f64,
167    pub route_completion: f64,
168}
169
170impl BranchOutTrajectory2D {
171    pub fn final_pose(&self) -> BranchOutPose2D {
172        *self
173            .poses
174            .last()
175            .expect("validated BranchOut trajectory is non-empty")
176    }
177}
178
179/// Planner result with one trajectory per decision mode.
180#[derive(Debug, Clone, PartialEq)]
181pub struct BranchOutPlan2D {
182    pub trajectories: Vec<BranchOutTrajectory2D>,
183}
184
185impl BranchOutPlan2D {
186    pub fn best(&self) -> Option<&BranchOutTrajectory2D> {
187        self.trajectories
188            .iter()
189            .max_by(|a, b| a.probability.total_cmp(&b.probability))
190    }
191
192    pub fn probability_sum(&self) -> f64 {
193        self.trajectories
194            .iter()
195            .map(|trajectory| trajectory.probability)
196            .sum()
197    }
198}
199
/// Multimodal evaluation metrics inspired by BranchOut's distributional focus.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BranchOutMultimodalMetrics2D {
    /// Number of trajectories in the evaluated plan.
    pub mode_count: usize,
    /// Mean Euclidean distance between final positions over all trajectory pairs.
    pub mean_pairwise_final_distance: f64,
    /// Mean discrete Fréchet distance over all trajectory pairs.
    pub mean_pairwise_frechet: f64,
    /// Per ground truth, the smallest Fréchet distance to any prediction, averaged.
    pub min_ground_truth_frechet: f64,
    /// Mean NLL of ground-truth endpoints under a Gaussian mixture on predicted endpoints.
    pub negative_log_likelihood: f64,
    /// Jensen–Shannon divergence between predicted and ground-truth speed histograms.
    pub speed_jsd: f64,
    /// Probability-weighted route completion of the predicted trajectories.
    pub expected_route_completion: f64,
}
211
/// Configuration for a receding-horizon closed-loop BranchOut rollout.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BranchOutClosedLoopConfig2D {
    /// Number of control steps to execute (must be non-zero).
    pub steps: usize,
    /// Steps whose minimum time-to-collision is strictly below this value count as risky.
    pub ttc_threshold: f64,
    /// Route fraction in `[0, 1]` at or above which the goal counts as reached.
    pub goal_completion: f64,
    /// Upper bound on lateral speed while steering toward the selected mode's lane.
    pub max_lateral_speed: f64,
}

impl Default for BranchOutClosedLoopConfig2D {
    /// 40 steps, 1.5 TTC threshold, 95% goal completion, 0.9 lateral speed cap.
    fn default() -> Self {
        BranchOutClosedLoopConfig2D {
            max_lateral_speed: 0.9,
            goal_completion: 0.95,
            ttc_threshold: 1.5,
            steps: 40,
        }
    }
}
235
236/// Closed-loop driving metrics from a receding-horizon BranchOut rollout.
237#[derive(Debug, Clone, PartialEq)]
238pub struct BranchOutClosedLoopMetrics2D {
239    pub steps: usize,
240    pub route_completion: f64,
241    pub reached_goal: bool,
242    pub collision_steps: usize,
243    pub no_collision_rate: f64,
244    pub min_clearance: f64,
245    pub mean_comfort_cost: f64,
246    pub min_time_to_collision: f64,
247    pub risky_ttc_steps: usize,
248    /// Realized closed-loop ego path (start pose plus one pose per step).
249    pub executed_path: Vec<BranchOutPose2D>,
250    /// Decision mode selected at each control step.
251    pub mode_sequence: Vec<BranchOutDecisionMode2D>,
252}
253
254/// Deterministic multimodal planner over coarse driving modes.
255#[derive(Debug, Clone, PartialEq)]
256pub struct BranchOutPlanner2D {
257    config: BranchOutPlannerConfig2D,
258}
259
260impl BranchOutPlanner2D {
261    pub fn new(config: BranchOutPlannerConfig2D) -> RoboticsResult<Self> {
262        validate_config(&config)?;
263        Ok(Self { config })
264    }
265
266    pub fn config(&self) -> &BranchOutPlannerConfig2D {
267        &self.config
268    }
269
270    pub fn plan(&self, scene: &BranchOutDrivingScene2D) -> RoboticsResult<BranchOutPlan2D> {
271        validate_scene(scene)?;
272        let mut trajectories = self
273            .config
274            .modes
275            .iter()
276            .map(|&mode| self.rollout_mode(scene, mode))
277            .collect::<RoboticsResult<Vec<_>>>()?;
278        assign_mixture_probabilities(&mut trajectories, self.config.probability_temperature)?;
279        Ok(BranchOutPlan2D { trajectories })
280    }
281
282    pub fn evaluate_multimodal(
283        &self,
284        plan: &BranchOutPlan2D,
285        ground_truths: &[Vec<BranchOutPose2D>],
286    ) -> RoboticsResult<BranchOutMultimodalMetrics2D> {
287        validate_plan(plan)?;
288        if ground_truths.is_empty() {
289            return Err(RoboticsError::InvalidParameter(
290                "BranchOut ground_truths must be non-empty".to_string(),
291            ));
292        }
293        for ground_truth in ground_truths {
294            validate_poses(ground_truth)?;
295        }
296
297        let mean_pairwise_final_distance = mean_pairwise_final_distance(&plan.trajectories);
298        let mean_pairwise_frechet = mean_pairwise_frechet(&plan.trajectories);
299        let min_ground_truth_frechet = ground_truths
300            .iter()
301            .map(|gt| {
302                plan.trajectories
303                    .iter()
304                    .map(|trajectory| discrete_frechet(&trajectory.poses, gt))
305                    .fold(f64::INFINITY, f64::min)
306            })
307            .sum::<f64>()
308            / ground_truths.len() as f64;
309        let negative_log_likelihood = trajectory_set_nll(&plan.trajectories, ground_truths, 0.75);
310        let speed_jsd = speed_jsd(&plan.trajectories, ground_truths, 8, 4.0);
311        let expected_route_completion = plan
312            .trajectories
313            .iter()
314            .map(|trajectory| trajectory.probability * trajectory.route_completion)
315            .sum();
316
317        Ok(BranchOutMultimodalMetrics2D {
318            mode_count: plan.trajectories.len(),
319            mean_pairwise_final_distance,
320            mean_pairwise_frechet,
321            min_ground_truth_frechet,
322            negative_log_likelihood,
323            speed_jsd,
324            expected_route_completion,
325        })
326    }
327
328    /// Run a receding-horizon closed-loop rollout: at every control step
329    /// re-plan from the current ego pose, commit the first step of the
330    /// highest-probability mode, advance the (optionally moving) obstacles, and
331    /// accumulate closed-loop driving metrics.
332    ///
333    /// `obstacle_velocities` must match `scene.obstacles` in length; pass zeros
334    /// for static traffic.
335    pub fn simulate_closed_loop(
336        &self,
337        scene: &BranchOutDrivingScene2D,
338        obstacle_velocities: &[(f64, f64)],
339        config: BranchOutClosedLoopConfig2D,
340    ) -> RoboticsResult<BranchOutClosedLoopMetrics2D> {
341        validate_scene(scene)?;
342        if obstacle_velocities.len() != scene.obstacles.len() {
343            return Err(RoboticsError::InvalidParameter(
344                "BranchOut obstacle_velocities length must match scene.obstacles".to_string(),
345            ));
346        }
347        for &(vx, vy) in obstacle_velocities {
348            if !vx.is_finite() || !vy.is_finite() {
349                return Err(RoboticsError::InvalidParameter(
350                    "BranchOut obstacle velocity must be finite".to_string(),
351                ));
352            }
353        }
354        if config.steps == 0
355            || !config.ttc_threshold.is_finite()
356            || config.ttc_threshold <= 0.0
357            || !config.goal_completion.is_finite()
358            || !(0.0..=1.0).contains(&config.goal_completion)
359            || !config.max_lateral_speed.is_finite()
360            || config.max_lateral_speed <= 0.0
361        {
362            return Err(RoboticsError::InvalidParameter(
363                "BranchOut closed-loop config steps/ttc_threshold/goal_completion/max_lateral_speed are invalid"
364                    .to_string(),
365            ));
366        }
367
368        let dt = self.config.dt;
369        let mut ego = scene.start;
370        let mut obstacles = scene.obstacles.clone();
371        let mut executed_path = vec![ego];
372        let mut mode_sequence = Vec::with_capacity(config.steps);
373
374        let mut collision_steps = 0;
375        let mut min_clearance = f64::INFINITY;
376        let mut min_time_to_collision = f64::INFINITY;
377        let mut risky_ttc_steps = 0;
378
379        for _ in 0..config.steps {
380            let current_scene = BranchOutDrivingScene2D {
381                start: ego,
382                obstacles: obstacles.clone(),
383                ..scene.clone()
384            };
385            let plan = self.plan(&current_scene)?;
386            let mode = plan
387                .best()
388                .expect("validated plan always has at least one trajectory")
389                .mode;
390
391            // Receding-horizon commit: BranchOut selects the mode; the closed
392            // loop tracks that mode's target lane with a bounded lateral rate
393            // and a first-order speed law (the per-mode rollout's lateral curve
394            // is back-loaded, so replaying only its first step would barely
395            // steer). This keeps lane changes physically realizable.
396            let start_lane = current_scene.nearest_lane_index(ego.y);
397            let target_y =
398                current_scene.lane_center(mode_target_lane(&current_scene, start_lane, mode));
399            let desired_speed = match mode {
400                BranchOutDecisionMode2D::Yield => yield_speed(&current_scene, ego.x),
401                _ => current_scene.desired_speed,
402            };
403            let max_lateral_step = config.max_lateral_speed * dt;
404            let mut next = ego;
405            next.speed += 0.35 * (desired_speed - ego.speed);
406            next.x += next.speed * dt;
407            next.y += (target_y - ego.y).clamp(-max_lateral_step, max_lateral_step);
408            let ego_velocity = ((next.x - ego.x) / dt, (next.y - ego.y) / dt);
409
410            // Advance obstacles, then evaluate clearance/TTC at the committed
411            // state against the obstacles' new positions.
412            for (obstacle, &(vx, vy)) in obstacles.iter_mut().zip(obstacle_velocities) {
413                obstacle.x += vx * dt;
414                obstacle.y += vy * dt;
415            }
416
417            let mut step_min_ttc = f64::INFINITY;
418            let mut step_min_clearance = f64::INFINITY;
419            for (obstacle, &velocity) in obstacles.iter().zip(obstacle_velocities) {
420                let radius_sum = obstacle.radius + self.config.ego_radius;
421                let clearance =
422                    point_distance((next.x, next.y), (obstacle.x, obstacle.y)) - radius_sum;
423                step_min_clearance = step_min_clearance.min(clearance);
424                let ttc = time_to_collision(
425                    (next.x, next.y),
426                    ego_velocity,
427                    (obstacle.x, obstacle.y),
428                    velocity,
429                    radius_sum,
430                );
431                step_min_ttc = step_min_ttc.min(ttc);
432            }
433            min_clearance = min_clearance.min(step_min_clearance);
434            if step_min_clearance < 0.0 {
435                collision_steps += 1;
436            }
437            min_time_to_collision = min_time_to_collision.min(step_min_ttc);
438            if step_min_ttc < config.ttc_threshold {
439                risky_ttc_steps += 1;
440            }
441
442            ego = next;
443            executed_path.push(ego);
444            mode_sequence.push(mode);
445        }
446
447        if obstacles.is_empty() {
448            min_clearance = f64::INFINITY;
449        }
450        let steps = config.steps;
451        let route_completion = (ego.x / scene.route_length).clamp(0.0, 1.0);
452        let mean_comfort_cost = closed_loop_comfort(&executed_path, dt);
453
454        Ok(BranchOutClosedLoopMetrics2D {
455            steps,
456            route_completion,
457            reached_goal: route_completion >= config.goal_completion && collision_steps == 0,
458            collision_steps,
459            no_collision_rate: (steps - collision_steps) as f64 / steps as f64,
460            min_clearance,
461            mean_comfort_cost,
462            min_time_to_collision,
463            risky_ttc_steps,
464            executed_path,
465            mode_sequence,
466        })
467    }
468
469    fn rollout_mode(
470        &self,
471        scene: &BranchOutDrivingScene2D,
472        mode: BranchOutDecisionMode2D,
473    ) -> RoboticsResult<BranchOutTrajectory2D> {
474        let start_lane = scene.nearest_lane_index(scene.start.y);
475        let target_lane = mode_target_lane(scene, start_lane, mode);
476        let target_y = scene.lane_center(target_lane);
477        let mut poses = Vec::with_capacity(self.config.horizon_steps + 1);
478        let mut pose = scene.start;
479        poses.push(pose);
480
481        for step in 1..=self.config.horizon_steps {
482            let phase = step as f64 / self.config.horizon_steps as f64;
483            let smooth = smoothstep(phase);
484            let desired_speed = match mode {
485                BranchOutDecisionMode2D::Yield => yield_speed(scene, pose.x),
486                _ => scene.desired_speed,
487            };
488            pose.speed += 0.35 * (desired_speed - pose.speed);
489            pose.x += pose.speed * self.config.dt;
490            pose.y = scene.start.y + (target_y - scene.start.y) * smooth;
491            poses.push(pose);
492        }
493
494        let (collision_risk, lane_penalty, comfort_cost) = self.cost_terms(scene, &poses);
495        let final_pose = *poses
496            .last()
497            .expect("BranchOut rollout always contains start and horizon poses");
498        let progress_error = (scene.route_length - final_pose.x).max(0.0);
499        let route_completion = (final_pose.x / scene.route_length).clamp(0.0, 1.0);
500        let target_route_y = scene.lane_center(start_lane);
501        let route_deviation = (final_pose.y - target_route_y).abs();
502        let cost = self.config.progress_weight * progress_error
503            + self.config.collision_weight * collision_risk
504            + self.config.lane_weight * lane_penalty
505            + self.config.comfort_weight * comfort_cost
506            + self.config.route_weight * route_deviation;
507
508        Ok(BranchOutTrajectory2D {
509            mode,
510            probability: 0.0,
511            cost,
512            poses,
513            collision_risk,
514            comfort_cost,
515            route_completion,
516        })
517    }
518
519    fn cost_terms(
520        &self,
521        scene: &BranchOutDrivingScene2D,
522        poses: &[BranchOutPose2D],
523    ) -> (f64, f64, f64) {
524        let mut collision_risk = 0.0;
525        let mut lane_penalty = 0.0;
526        let mut comfort_cost = 0.0;
527        let road_half_width = (scene.lane_count_each_side as f64 + 0.5) * scene.lane_width;
528
529        for pose in poses {
530            for obstacle in &scene.obstacles {
531                let clearance = point_distance((pose.x, pose.y), (obstacle.x, obstacle.y))
532                    - obstacle.radius
533                    - self.config.ego_radius;
534                if clearance < 0.0 {
535                    collision_risk += (1.0 - clearance).powi(2);
536                } else {
537                    collision_risk += 0.03 / (clearance + 0.3);
538                }
539            }
540            if pose.y.abs() > road_half_width {
541                lane_penalty += (pose.y.abs() - road_half_width).powi(2);
542            }
543        }
544
545        for window in poses.windows(3) {
546            let ay0 = window[1].y - window[0].y;
547            let ay1 = window[2].y - window[1].y;
548            comfort_cost += (ay1 - ay0).powi(2) / (self.config.dt * self.config.dt);
549            comfort_cost += (window[2].speed - window[1].speed).powi(2);
550        }
551
552        let norm = poses.len() as f64;
553        (
554            collision_risk / norm,
555            lane_penalty / norm,
556            comfort_cost / norm,
557        )
558    }
559}
560
561fn assign_mixture_probabilities(
562    trajectories: &mut [BranchOutTrajectory2D],
563    temperature: f64,
564) -> RoboticsResult<()> {
565    let min_cost = trajectories
566        .iter()
567        .map(|trajectory| trajectory.cost)
568        .fold(f64::INFINITY, f64::min);
569    let mut weight_sum = 0.0;
570    for trajectory in trajectories.iter_mut() {
571        trajectory.probability = (-(trajectory.cost - min_cost) / temperature).exp();
572        weight_sum += trajectory.probability;
573    }
574    if weight_sum <= 0.0 || !weight_sum.is_finite() {
575        return Err(RoboticsError::PlanningError(
576            "BranchOut mixture weights collapsed".to_string(),
577        ));
578    }
579    for trajectory in trajectories {
580        trajectory.probability /= weight_sum;
581    }
582    Ok(())
583}
584
585fn mode_target_lane(
586    scene: &BranchOutDrivingScene2D,
587    start_lane: i32,
588    mode: BranchOutDecisionMode2D,
589) -> i32 {
590    match mode {
591        BranchOutDecisionMode2D::KeepLane | BranchOutDecisionMode2D::Yield => start_lane,
592        BranchOutDecisionMode2D::LaneChangeLeft => (start_lane + 1).min(scene.lane_count_each_side),
593        BranchOutDecisionMode2D::LaneChangeRight => {
594            (start_lane - 1).max(-scene.lane_count_each_side)
595        }
596    }
597}
598
599fn nearest_obstacle_x(scene: &BranchOutDrivingScene2D) -> Option<f64> {
600    scene
601        .obstacles
602        .iter()
603        .filter(|obstacle| obstacle.x >= scene.start.x)
604        .map(|obstacle| obstacle.x)
605        .min_by(f64::total_cmp)
606}
607
608fn yield_speed(scene: &BranchOutDrivingScene2D, ego_x: f64) -> f64 {
609    let Some(obstacle_x) = nearest_obstacle_x(scene) else {
610        return scene.desired_speed;
611    };
612    let stop_x = obstacle_x - 1.25;
613    if ego_x >= stop_x {
614        0.0
615    } else {
616        let distance_to_stop = (stop_x - ego_x).max(0.0);
617        (0.75 * scene.desired_speed).min(distance_to_stop)
618    }
619}
620
621fn mean_pairwise_final_distance(trajectories: &[BranchOutTrajectory2D]) -> f64 {
622    let mut sum = 0.0;
623    let mut count = 0;
624    for i in 0..trajectories.len() {
625        for j in i + 1..trajectories.len() {
626            let a = trajectories[i].final_pose();
627            let b = trajectories[j].final_pose();
628            sum += point_distance((a.x, a.y), (b.x, b.y));
629            count += 1;
630        }
631    }
632    if count == 0 {
633        0.0
634    } else {
635        sum / count as f64
636    }
637}
638
639fn mean_pairwise_frechet(trajectories: &[BranchOutTrajectory2D]) -> f64 {
640    let mut sum = 0.0;
641    let mut count = 0;
642    for i in 0..trajectories.len() {
643        for j in i + 1..trajectories.len() {
644            sum += discrete_frechet(&trajectories[i].poses, &trajectories[j].poses);
645            count += 1;
646        }
647    }
648    if count == 0 {
649        0.0
650    } else {
651        sum / count as f64
652    }
653}
654
655fn trajectory_set_nll(
656    trajectories: &[BranchOutTrajectory2D],
657    ground_truths: &[Vec<BranchOutPose2D>],
658    sigma: f64,
659) -> f64 {
660    let variance = sigma * sigma;
661    let normalizer = 2.0 * std::f64::consts::PI * variance;
662    let mut nll = 0.0;
663    for ground_truth in ground_truths {
664        let gt_final = *ground_truth
665            .last()
666            .expect("validated ground-truth trajectory is non-empty");
667        let likelihood = trajectories
668            .iter()
669            .map(|trajectory| {
670                let final_pose = trajectory.final_pose();
671                let distance_sq =
672                    squared_distance((final_pose.x, final_pose.y), (gt_final.x, gt_final.y));
673                trajectory.probability * (-0.5 * distance_sq / variance).exp() / normalizer
674            })
675            .sum::<f64>()
676            .max(EPS);
677        nll -= likelihood.ln();
678    }
679    nll / ground_truths.len() as f64
680}
681
682fn speed_jsd(
683    trajectories: &[BranchOutTrajectory2D],
684    ground_truths: &[Vec<BranchOutPose2D>],
685    bins: usize,
686    max_speed: f64,
687) -> f64 {
688    let mut predicted = vec![0.0; bins];
689    for trajectory in trajectories {
690        for pose in &trajectory.poses {
691            predicted[speed_bin(pose.speed, bins, max_speed)] += trajectory.probability;
692        }
693    }
694    normalize_distribution(&mut predicted);
695
696    let mut truth = vec![0.0; bins];
697    for trajectory in ground_truths {
698        for pose in trajectory {
699            truth[speed_bin(pose.speed, bins, max_speed)] += 1.0;
700        }
701    }
702    normalize_distribution(&mut truth);
703
704    let mixture = predicted
705        .iter()
706        .zip(&truth)
707        .map(|(p, q)| 0.5 * (p + q))
708        .collect::<Vec<_>>();
709    0.5 * kl_divergence(&predicted, &mixture) + 0.5 * kl_divergence(&truth, &mixture)
710}
711
/// Maps `speed` to one of `bins` equal-width histogram bins over `[0, max_speed]`.
///
/// Speeds below zero fall into the first bin and speeds at or above
/// `max_speed` into the last. Degenerate parameters do not panic: `bins == 0`
/// behaves like a single bin, and a non-positive or NaN `max_speed` maps
/// every speed to bin 0.
fn speed_bin(speed: f64, bins: usize, max_speed: f64) -> usize {
    // `bins - 1` would underflow for zero bins.
    let last_bin = bins.saturating_sub(1);
    // `f64::clamp(0.0, max_speed)` panics for a negative or NaN upper bound,
    // and a zero range has no meaningful bin width.
    if max_speed.is_nan() || max_speed <= 0.0 {
        return 0;
    }
    ((speed.clamp(0.0, max_speed) / max_speed) * bins as f64)
        .floor()
        .min(last_bin as f64) as usize
}
717
/// Rescales `values` in place so they sum to one.
///
/// Leaves the slice untouched when its sum is zero, negative, or NaN.
fn normalize_distribution(values: &mut [f64]) {
    let total: f64 = values.iter().sum();
    if total.is_nan() || total <= 0.0 {
        return;
    }
    values.iter_mut().for_each(|value| *value /= total);
}
726
/// Kullback–Leibler divergence `KL(p || q)` in nats.
///
/// Only entries where both `p` and `q` are positive contribute; an entry with
/// `p > 0` and `q == 0` is skipped rather than producing infinity.
fn kl_divergence(p: &[f64], q: &[f64]) -> f64 {
    p.iter()
        .zip(q.iter())
        .filter_map(|(&pi, &qi)| (pi > 0.0 && qi > 0.0).then(|| pi * (pi / qi).ln()))
        .sum()
}
734
735fn discrete_frechet(a: &[BranchOutPose2D], b: &[BranchOutPose2D]) -> f64 {
736    let mut ca = vec![vec![0.0; b.len()]; a.len()];
737    for i in 0..a.len() {
738        for j in 0..b.len() {
739            let distance = point_distance((a[i].x, a[i].y), (b[j].x, b[j].y));
740            ca[i][j] = if i == 0 && j == 0 {
741                distance
742            } else if i == 0 {
743                ca[i][j - 1].max(distance)
744            } else if j == 0 {
745                ca[i - 1][j].max(distance)
746            } else {
747                ca[i - 1][j]
748                    .min(ca[i - 1][j - 1])
749                    .min(ca[i][j - 1])
750                    .max(distance)
751            };
752        }
753    }
754    ca[a.len() - 1][b.len() - 1]
755}
756
757/// Time until the ego and obstacle disks (combined radius `radius_sum`) first
758/// touch, given current positions and constant velocities. Returns `INFINITY`
759/// when they are separating or never intersect; `0.0` when already overlapping.
760fn time_to_collision(
761    ego: (f64, f64),
762    ego_velocity: (f64, f64),
763    obstacle: (f64, f64),
764    obstacle_velocity: (f64, f64),
765    radius_sum: f64,
766) -> f64 {
767    let px = obstacle.0 - ego.0;
768    let py = obstacle.1 - ego.1;
769    let vx = obstacle_velocity.0 - ego_velocity.0;
770    let vy = obstacle_velocity.1 - ego_velocity.1;
771    let distance_sq = px * px + py * py;
772    let radius_sq = radius_sum * radius_sum;
773    if distance_sq <= radius_sq {
774        return 0.0;
775    }
776    let a = vx * vx + vy * vy;
777    if a <= EPS {
778        return f64::INFINITY;
779    }
780    let b = 2.0 * (px * vx + py * vy);
781    let c = distance_sq - radius_sq;
782    let discriminant = b * b - 4.0 * a * c;
783    if discriminant < 0.0 {
784        return f64::INFINITY;
785    }
786    let root = (-b - discriminant.sqrt()) / (2.0 * a);
787    if root >= 0.0 {
788        root
789    } else {
790        f64::INFINITY
791    }
792}
793
794/// Mean closed-loop comfort cost: lateral jerk plus longitudinal acceleration
795/// over the realized path, matching the per-mode comfort term.
796fn closed_loop_comfort(path: &[BranchOutPose2D], dt: f64) -> f64 {
797    if path.len() < 3 {
798        return 0.0;
799    }
800    let mut comfort = 0.0;
801    for window in path.windows(3) {
802        let ay0 = window[1].y - window[0].y;
803        let ay1 = window[2].y - window[1].y;
804        comfort += (ay1 - ay0).powi(2) / (dt * dt);
805        comfort += (window[2].speed - window[1].speed).powi(2);
806    }
807    comfort / (path.len() - 2) as f64
808}
809
/// Cubic Hermite ease `3u^2 - 2u^3` with `u = t` clamped to `[0, 1]`.
fn smoothstep(t: f64) -> f64 {
    let u = t.clamp(0.0, 1.0);
    let u_sq = u * u;
    u_sq * (3.0 - 2.0 * u)
}
814
/// Squared Euclidean distance between two 2-D points.
fn squared_distance(a: (f64, f64), b: (f64, f64)) -> f64 {
    let ((ax, ay), (bx, by)) = (a, b);
    let (dx, dy) = (ax - bx, ay - by);
    dx * dx + dy * dy
}
820
821fn point_distance(a: (f64, f64), b: (f64, f64)) -> f64 {
822    squared_distance(a, b).sqrt()
823}
824
825fn validate_config(config: &BranchOutPlannerConfig2D) -> RoboticsResult<()> {
826    if config.horizon_steps == 0 || config.modes.is_empty() {
827        return Err(RoboticsError::InvalidParameter(
828            "BranchOut horizon_steps and modes must be non-empty".to_string(),
829        ));
830    }
831    for (label, value) in [
832        ("BranchOut dt", config.dt),
833        ("BranchOut ego_radius", config.ego_radius),
834        (
835            "BranchOut probability_temperature",
836            config.probability_temperature,
837        ),
838        ("BranchOut progress_weight", config.progress_weight),
839        ("BranchOut collision_weight", config.collision_weight),
840        ("BranchOut lane_weight", config.lane_weight),
841        ("BranchOut comfort_weight", config.comfort_weight),
842        ("BranchOut route_weight", config.route_weight),
843    ] {
844        if value < 0.0 || !value.is_finite() {
845            return Err(RoboticsError::InvalidParameter(format!(
846                "{label} must be finite and non-negative"
847            )));
848        }
849    }
850    if config.dt <= 0.0 || config.ego_radius <= 0.0 || config.probability_temperature <= 0.0 {
851        return Err(RoboticsError::InvalidParameter(
852            "BranchOut dt, ego_radius, and probability_temperature must be positive".to_string(),
853        ));
854    }
855    Ok(())
856}
857
858fn validate_scene(scene: &BranchOutDrivingScene2D) -> RoboticsResult<()> {
859    validate_pose(scene.start)?;
860    if scene.lane_width <= 0.0
861        || !scene.lane_width.is_finite()
862        || scene.route_length <= 0.0
863        || !scene.route_length.is_finite()
864        || scene.desired_speed <= 0.0
865        || !scene.desired_speed.is_finite()
866    {
867        return Err(RoboticsError::InvalidParameter(
868            "BranchOut scene lane_width, route_length, and desired_speed must be positive"
869                .to_string(),
870        ));
871    }
872    if scene.lane_count_each_side < 0 {
873        return Err(RoboticsError::InvalidParameter(
874            "BranchOut lane_count_each_side must be non-negative".to_string(),
875        ));
876    }
877    for obstacle in &scene.obstacles {
878        if !obstacle.x.is_finite() || !obstacle.y.is_finite() {
879            return Err(RoboticsError::InvalidParameter(
880                "BranchOut obstacle coordinates must be finite".to_string(),
881            ));
882        }
883        if obstacle.radius <= 0.0 || !obstacle.radius.is_finite() {
884            return Err(RoboticsError::InvalidParameter(
885                "BranchOut obstacle radius must be positive".to_string(),
886            ));
887        }
888    }
889    Ok(())
890}
891
892fn validate_plan(plan: &BranchOutPlan2D) -> RoboticsResult<()> {
893    if plan.trajectories.is_empty() {
894        return Err(RoboticsError::InvalidParameter(
895            "BranchOut plan must contain trajectories".to_string(),
896        ));
897    }
898    for trajectory in &plan.trajectories {
899        validate_poses(&trajectory.poses)?;
900        if trajectory.probability < 0.0
901            || !trajectory.probability.is_finite()
902            || !trajectory.cost.is_finite()
903        {
904            return Err(RoboticsError::InvalidParameter(
905                "BranchOut trajectory probability and cost must be finite".to_string(),
906            ));
907        }
908    }
909    Ok(())
910}
911
912fn validate_poses(poses: &[BranchOutPose2D]) -> RoboticsResult<()> {
913    if poses.is_empty() {
914        return Err(RoboticsError::InvalidParameter(
915            "BranchOut trajectory poses must be non-empty".to_string(),
916        ));
917    }
918    for &pose in poses {
919        validate_pose(pose)?;
920    }
921    Ok(())
922}
923
924fn validate_pose(pose: BranchOutPose2D) -> RoboticsResult<()> {
925    if !pose.x.is_finite() || !pose.y.is_finite() || !pose.speed.is_finite() || pose.speed < 0.0 {
926        return Err(RoboticsError::InvalidParameter(
927            "BranchOut pose must be finite with non-negative speed".to_string(),
928        ));
929    }
930    Ok(())
931}
932
// Unit tests for the BranchOut-lite planner. They cover open-loop multimodal
// output, the multimodal coverage metrics, closed-loop simulation, and the
// `time_to_collision` helper.
#[cfg(test)]
mod tests {
    use super::*;

    // A single plan on `simple_overtake` with the default config checks three things.
    // It contains four trajectories, presumably one per default mode (TODO confirm).
    // Its mixture probabilities sum to 1.
    // Its best trajectory is more likely than a uniform 1/4 split would be.
    #[test]
    fn branchout_emits_multiple_modes_with_normalized_probabilities() {
        let planner = BranchOutPlanner2D::new(BranchOutPlannerConfig2D::default()).unwrap();
        let scene = BranchOutDrivingScene2D::simple_overtake();
        let plan = planner.plan(&scene).unwrap();

        assert_eq!(plan.trajectories.len(), 4);
        assert!((plan.probability_sum() - 1.0).abs() < 1e-9);
        assert!(plan.best().unwrap().probability > 0.25);
    }

    // The left and right lane-change trajectories must finish on opposite
    // sides of the starting lane: positive y for left, negative y for right.
    #[test]
    fn lane_change_modes_end_in_distinct_lanes() {
        let planner = BranchOutPlanner2D::new(BranchOutPlannerConfig2D::default()).unwrap();
        let scene = BranchOutDrivingScene2D::simple_overtake();
        let plan = planner.plan(&scene).unwrap();

        let left = plan
            .trajectories
            .iter()
            .find(|trajectory| trajectory.mode == BranchOutDecisionMode2D::LaneChangeLeft)
            .unwrap();
        let right = plan
            .trajectories
            .iter()
            .find(|trajectory| trajectory.mode == BranchOutDecisionMode2D::LaneChangeRight)
            .unwrap();

        assert!(left.final_pose().y > 0.9);
        assert!(right.final_pose().y < -0.9);
    }

    // Every non-keep-lane trajectory from the plan is reused as a ground truth.
    // The plan therefore contains each ground truth exactly, so the minimum
    // Fréchet distance to it should be near zero. The modes should also be
    // spread apart at their endpoints.
    #[test]
    fn multimodal_metrics_reward_coverage() {
        let planner = BranchOutPlanner2D::new(BranchOutPlannerConfig2D::default()).unwrap();
        let scene = BranchOutDrivingScene2D::simple_overtake();
        let plan = planner.plan(&scene).unwrap();
        let ground_truths = plan
            .trajectories
            .iter()
            .filter(|trajectory| trajectory.mode != BranchOutDecisionMode2D::KeepLane)
            .map(|trajectory| trajectory.poses.clone())
            .collect::<Vec<_>>();
        let metrics = planner.evaluate_multimodal(&plan, &ground_truths).unwrap();

        assert_eq!(metrics.mode_count, 4);
        assert!(metrics.mean_pairwise_final_distance > 0.5);
        assert!(metrics.min_ground_truth_frechet < 0.1);
        assert!(metrics.negative_log_likelihood.is_finite());
        assert!(metrics.speed_jsd >= 0.0);
    }

    // With every obstacle stationary (one zero velocity per obstacle),
    // `wide_overtake` leaves room to pass. The closed loop should reach the
    // goal with no collisions. The path gets one more pose than there are
    // steps, for the start pose.
    #[test]
    fn closed_loop_overtake_completes_route_without_collision() {
        let planner = BranchOutPlanner2D::new(BranchOutPlannerConfig2D::default()).unwrap();
        let scene = BranchOutDrivingScene2D::wide_overtake();
        let velocities = vec![(0.0, 0.0); scene.obstacles.len()];
        let metrics = planner
            .simulate_closed_loop(&scene, &velocities, BranchOutClosedLoopConfig2D::default())
            .unwrap();

        assert_eq!(metrics.executed_path.len(), metrics.steps + 1);
        assert_eq!(metrics.mode_sequence.len(), metrics.steps);
        assert_eq!(metrics.collision_steps, 0);
        assert_eq!(metrics.no_collision_rate, 1.0);
        assert!(metrics.min_clearance > 0.0);
        assert!(metrics.reached_goal);
        assert!(metrics.route_completion >= 0.95);
        assert!(metrics.min_time_to_collision > 0.0);
        assert!(metrics.mean_comfort_cost.is_finite());
    }

    // `forced_yield` has a single stationary obstacle, matching the one-element
    // velocity slice. The planner must pick Yield at every step and stay
    // collision-free.
    #[test]
    fn closed_loop_yields_safely_when_blocked() {
        let planner = BranchOutPlanner2D::new(BranchOutPlannerConfig2D::default()).unwrap();
        let scene = BranchOutDrivingScene2D::forced_yield();
        let metrics = planner
            .simulate_closed_loop(
                &scene,
                &[(0.0, 0.0)],
                BranchOutClosedLoopConfig2D::default(),
            )
            .unwrap();

        // No room to pass: the ego stops behind the obstacle, so it never
        // reaches the goal but stays collision-free with positive clearance.
        assert_eq!(metrics.collision_steps, 0);
        assert!(!metrics.reached_goal);
        assert!(metrics.route_completion < 0.6);
        assert!(metrics.min_clearance > 0.0);
        assert!(metrics
            .mode_sequence
            .iter()
            .all(|&mode| mode == BranchOutDecisionMode2D::Yield));
    }

    // Two identical closed-loop runs must produce identical paths, mode
    // sequences, and collision counts. The planner is meant to have no
    // hidden randomness.
    #[test]
    fn closed_loop_is_deterministic() {
        let planner = BranchOutPlanner2D::new(BranchOutPlannerConfig2D::default()).unwrap();
        let scene = BranchOutDrivingScene2D::simple_overtake();
        let velocities = vec![(0.0, 0.0); scene.obstacles.len()];
        let config = BranchOutClosedLoopConfig2D::default();
        let first = planner
            .simulate_closed_loop(&scene, &velocities, config)
            .unwrap();
        let second = planner
            .simulate_closed_loop(&scene, &velocities, config)
            .unwrap();

        assert_eq!(first.executed_path, second.executed_path);
        assert_eq!(first.mode_sequence, second.mode_sequence);
        assert_eq!(first.collision_steps, second.collision_steps);
    }

    // Compares two runs with the same single in-lane blocker: one stationary,
    // one moving toward the ego at -1.6 along x. The moving run's minimum
    // time-to-collision must not exceed the stationary run's.
    #[test]
    fn oncoming_obstacle_lowers_time_to_collision() {
        let planner = BranchOutPlanner2D::new(BranchOutPlannerConfig2D::default()).unwrap();
        // A blocker sits far ahead in the ego lane and drives toward the ego.
        let mut scene = BranchOutDrivingScene2D::simple_overtake();
        scene.obstacles = vec![BranchOutObstacle2D::new(8.5, 0.0, 0.42)];
        let static_metrics = planner
            .simulate_closed_loop(
                &scene,
                &[(0.0, 0.0)],
                BranchOutClosedLoopConfig2D::default(),
            )
            .unwrap();
        let oncoming_metrics = planner
            .simulate_closed_loop(
                &scene,
                &[(-1.6, 0.0)],
                BranchOutClosedLoopConfig2D::default(),
            )
            .unwrap();

        // An approaching obstacle must not raise the minimum time-to-collision.
        assert!(oncoming_metrics.min_time_to_collision <= static_metrics.min_time_to_collision);
        assert!(oncoming_metrics.min_time_to_collision.is_finite());
    }

    // An empty velocity slice must be rejected for a scene that has
    // obstacles. This presumes `simple_overtake` contains at least one
    // obstacle (TODO confirm).
    #[test]
    fn closed_loop_rejects_mismatched_velocities() {
        let planner = BranchOutPlanner2D::new(BranchOutPlannerConfig2D::default()).unwrap();
        let scene = BranchOutDrivingScene2D::simple_overtake();
        assert!(planner
            .simulate_closed_loop(&scene, &[], BranchOutClosedLoopConfig2D::default())
            .is_err());
    }

    // Direct checks of the constant-velocity `time_to_collision` helper.
    // Arguments appear to be (ego position, ego velocity, obstacle position,
    // obstacle velocity, combined radius) — NOTE(review): confirm against
    // the helper's definition.
    #[test]
    fn time_to_collision_detects_closing_and_separating() {
        // Closing head-on along x: surface gap 4, closing speed 2 -> 2.0 s.
        let closing = time_to_collision((0.0, 0.0), (1.0, 0.0), (5.0, 0.0), (-1.0, 0.0), 1.0);
        assert!((closing - 2.0).abs() < 1e-9);
        // Separating: never collide.
        let separating = time_to_collision((0.0, 0.0), (-1.0, 0.0), (5.0, 0.0), (1.0, 0.0), 1.0);
        assert!(separating.is_infinite());
        // Already overlapping -> 0.
        let overlapping = time_to_collision((0.0, 0.0), (0.0, 0.0), (0.5, 0.0), (0.0, 0.0), 1.0);
        assert_eq!(overlapping, 0.0);
    }
}