1use rust_robotics_core::{RoboticsError, RoboticsResult};
9
/// Small positive floor: keeps the trajectory NLL away from `ln(0)` and treats a
/// squared relative speed at or below it as "not closing" in time-to-collision.
const EPS: f64 = 1.0e-9;
11
/// Planar state used for the ego start pose and every trajectory sample.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BranchOutPose2D {
    /// Longitudinal position; route completion is measured as `x / route_length`.
    pub x: f64,
    /// Lateral position; lane `k` is centred at `k * lane_width`.
    pub y: f64,
    /// Forward speed; pose validation requires it to be finite and non-negative.
    pub speed: f64,
}

impl BranchOutPose2D {
    /// Builds a pose from its components (no validation is performed here).
    pub fn new(x: f64, y: f64, speed: f64) -> Self {
        BranchOutPose2D { x, y, speed }
    }
}
25
/// Circular obstacle described by its centre and radius.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BranchOutObstacle2D {
    /// Longitudinal position of the centre.
    pub x: f64,
    /// Lateral position of the centre.
    pub y: f64,
    /// Collision radius; scene validation requires it to be finite and positive.
    pub radius: f64,
}

impl BranchOutObstacle2D {
    /// Builds an obstacle from its centre and radius (no validation here).
    pub fn new(x: f64, y: f64, radius: f64) -> Self {
        BranchOutObstacle2D { x, y, radius }
    }
}
38}
39
/// High-level decision that each planner branch commits to for its rollout.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BranchOutDecisionMode2D {
    /// Stay in the current lane at the scene's desired speed.
    KeepLane,
    /// Stay in the current lane and slow to stop short of the nearest obstacle ahead.
    Yield,
    /// Move one lane toward positive `y`, clamped to the outermost lane.
    LaneChangeLeft,
    /// Move one lane toward negative `y`, clamped to the outermost lane.
    LaneChangeRight,
}

impl BranchOutDecisionMode2D {
    /// Kebab-case name of the mode, e.g. `"lane-change-left"`.
    pub fn label(self) -> &'static str {
        use BranchOutDecisionMode2D::{KeepLane, LaneChangeLeft, LaneChangeRight, Yield};
        match self {
            KeepLane => "keep-lane",
            Yield => "yield",
            LaneChangeLeft => "lane-change-left",
            LaneChangeRight => "lane-change-right",
        }
    }
}
58}
59
60#[derive(Debug, Clone, PartialEq)]
62pub struct BranchOutDrivingScene2D {
63 pub start: BranchOutPose2D,
64 pub lane_width: f64,
65 pub lane_count_each_side: i32,
66 pub route_length: f64,
67 pub desired_speed: f64,
68 pub obstacles: Vec<BranchOutObstacle2D>,
69}
70
71impl BranchOutDrivingScene2D {
72 pub fn simple_overtake() -> Self {
73 Self {
74 start: BranchOutPose2D::new(0.0, 0.0, 2.2),
75 lane_width: 1.2,
76 lane_count_each_side: 1,
77 route_length: 9.0,
78 desired_speed: 2.2,
79 obstacles: vec![BranchOutObstacle2D::new(4.1, 0.0, 0.42)],
80 }
81 }
82
83 pub fn wide_overtake() -> Self {
86 Self {
87 start: BranchOutPose2D::new(0.0, 0.0, 2.2),
88 lane_width: 1.6,
89 lane_count_each_side: 1,
90 route_length: 9.0,
91 desired_speed: 2.2,
92 obstacles: vec![BranchOutObstacle2D::new(4.1, 0.0, 0.42)],
93 }
94 }
95
96 pub fn forced_yield() -> Self {
99 Self {
100 start: BranchOutPose2D::new(0.0, 0.0, 2.2),
101 lane_width: 1.2,
102 lane_count_each_side: 0,
103 route_length: 9.0,
104 desired_speed: 2.2,
105 obstacles: vec![BranchOutObstacle2D::new(4.1, 0.0, 0.42)],
106 }
107 }
108
109 pub fn lane_center(&self, lane_index: i32) -> f64 {
110 lane_index as f64 * self.lane_width
111 }
112
113 pub fn nearest_lane_index(&self, y: f64) -> i32 {
114 (y / self.lane_width).round().clamp(
115 -(self.lane_count_each_side as f64),
116 self.lane_count_each_side as f64,
117 ) as i32
118 }
119}
120
121#[derive(Debug, Clone, PartialEq)]
123pub struct BranchOutPlannerConfig2D {
124 pub horizon_steps: usize,
125 pub dt: f64,
126 pub ego_radius: f64,
127 pub probability_temperature: f64,
128 pub progress_weight: f64,
129 pub collision_weight: f64,
130 pub lane_weight: f64,
131 pub comfort_weight: f64,
132 pub route_weight: f64,
133 pub modes: Vec<BranchOutDecisionMode2D>,
134}
135
136impl Default for BranchOutPlannerConfig2D {
137 fn default() -> Self {
138 Self {
139 horizon_steps: 28,
140 dt: 0.12,
141 ego_radius: 0.32,
142 probability_temperature: 4.0,
143 progress_weight: 1.4,
144 collision_weight: 80.0,
145 lane_weight: 12.0,
146 comfort_weight: 0.35,
147 route_weight: 0.12,
148 modes: vec![
149 BranchOutDecisionMode2D::KeepLane,
150 BranchOutDecisionMode2D::Yield,
151 BranchOutDecisionMode2D::LaneChangeLeft,
152 BranchOutDecisionMode2D::LaneChangeRight,
153 ],
154 }
155 }
156}
157
158#[derive(Debug, Clone, PartialEq)]
160pub struct BranchOutTrajectory2D {
161 pub mode: BranchOutDecisionMode2D,
162 pub probability: f64,
163 pub cost: f64,
164 pub poses: Vec<BranchOutPose2D>,
165 pub collision_risk: f64,
166 pub comfort_cost: f64,
167 pub route_completion: f64,
168}
169
170impl BranchOutTrajectory2D {
171 pub fn final_pose(&self) -> BranchOutPose2D {
172 *self
173 .poses
174 .last()
175 .expect("validated BranchOut trajectory is non-empty")
176 }
177}
178
179#[derive(Debug, Clone, PartialEq)]
181pub struct BranchOutPlan2D {
182 pub trajectories: Vec<BranchOutTrajectory2D>,
183}
184
185impl BranchOutPlan2D {
186 pub fn best(&self) -> Option<&BranchOutTrajectory2D> {
187 self.trajectories
188 .iter()
189 .max_by(|a, b| a.probability.total_cmp(&b.probability))
190 }
191
192 pub fn probability_sum(&self) -> f64 {
193 self.trajectories
194 .iter()
195 .map(|trajectory| trajectory.probability)
196 .sum()
197 }
198}
199
/// Summary statistics comparing a plan's trajectory mixture with one or more
/// ground-truth trajectories (produced by `BranchOutPlanner2D::evaluate_multimodal`).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BranchOutMultimodalMetrics2D {
    /// Number of trajectories in the evaluated plan.
    pub mode_count: usize,
    /// Mean Euclidean distance between final positions over all unordered
    /// trajectory pairs (0 when the plan has fewer than two trajectories).
    pub mean_pairwise_final_distance: f64,
    /// Mean discrete Fréchet distance over all unordered trajectory pairs.
    pub mean_pairwise_frechet: f64,
    /// For each ground truth, the smallest Fréchet distance to any predicted
    /// trajectory, averaged over ground truths (a mean of minima, not a global min).
    pub min_ground_truth_frechet: f64,
    /// Mean negative log-likelihood of ground-truth final positions under a
    /// probability-weighted isotropic Gaussian (sigma 0.75) centred on the
    /// predicted final positions; the likelihood is floored at `EPS`.
    pub negative_log_likelihood: f64,
    /// Jensen-Shannon divergence (natural log) between the probability-weighted
    /// predicted speed histogram and the ground-truth speed histogram
    /// (8 bins over `[0, 4]`).
    pub speed_jsd: f64,
    /// Probability-weighted mean of the trajectories' `route_completion`.
    pub expected_route_completion: f64,
}
211
/// Parameters of `BranchOutPlanner2D::simulate_closed_loop`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BranchOutClosedLoopConfig2D {
    /// Number of simulation steps (must be non-zero).
    pub steps: usize,
    /// Steps whose minimum time-to-collision falls below this are counted as risky.
    pub ttc_threshold: f64,
    /// Route-completion fraction in `[0, 1]` required to count the goal as reached.
    pub goal_completion: f64,
    /// Upper bound on lateral speed while moving toward the target lane.
    pub max_lateral_speed: f64,
}

impl Default for BranchOutClosedLoopConfig2D {
    /// 40 steps, 1.5 s TTC threshold, 95 % completion goal, 0.9 lateral speed cap.
    fn default() -> Self {
        Self {
            max_lateral_speed: 0.9,
            goal_completion: 0.95,
            ttc_threshold: 1.5,
            steps: 40,
        }
    }
}
235
/// Outcome of a closed-loop simulation run by `BranchOutPlanner2D::simulate_closed_loop`.
#[derive(Debug, Clone, PartialEq)]
pub struct BranchOutClosedLoopMetrics2D {
    /// Number of simulated steps.
    pub steps: usize,
    /// Final ego `x` divided by the route length, clamped to `[0, 1]`.
    pub route_completion: f64,
    /// `route_completion >= goal_completion` and no step was in collision.
    pub reached_goal: bool,
    /// Number of steps whose minimum clearance was negative (overlap).
    pub collision_steps: usize,
    /// Fraction of steps without collision.
    pub no_collision_rate: f64,
    /// Smallest centre distance minus combined radii seen over the run;
    /// `f64::INFINITY` when the scene has no obstacles.
    pub min_clearance: f64,
    /// Comfort cost of the executed path, averaged per 3-pose window.
    pub mean_comfort_cost: f64,
    /// Smallest constant-velocity time-to-collision seen over the run;
    /// `f64::INFINITY` if the ego never closed on any obstacle.
    pub min_time_to_collision: f64,
    /// Number of steps whose minimum time-to-collision was below the threshold.
    pub risky_ttc_steps: usize,
    /// Start pose followed by one pose per step (`steps + 1` entries).
    pub executed_path: Vec<BranchOutPose2D>,
    /// Mode selected at each step (`steps` entries).
    pub mode_sequence: Vec<BranchOutDecisionMode2D>,
}
253
254#[derive(Debug, Clone, PartialEq)]
256pub struct BranchOutPlanner2D {
257 config: BranchOutPlannerConfig2D,
258}
259
260impl BranchOutPlanner2D {
261 pub fn new(config: BranchOutPlannerConfig2D) -> RoboticsResult<Self> {
262 validate_config(&config)?;
263 Ok(Self { config })
264 }
265
266 pub fn config(&self) -> &BranchOutPlannerConfig2D {
267 &self.config
268 }
269
270 pub fn plan(&self, scene: &BranchOutDrivingScene2D) -> RoboticsResult<BranchOutPlan2D> {
271 validate_scene(scene)?;
272 let mut trajectories = self
273 .config
274 .modes
275 .iter()
276 .map(|&mode| self.rollout_mode(scene, mode))
277 .collect::<RoboticsResult<Vec<_>>>()?;
278 assign_mixture_probabilities(&mut trajectories, self.config.probability_temperature)?;
279 Ok(BranchOutPlan2D { trajectories })
280 }
281
282 pub fn evaluate_multimodal(
283 &self,
284 plan: &BranchOutPlan2D,
285 ground_truths: &[Vec<BranchOutPose2D>],
286 ) -> RoboticsResult<BranchOutMultimodalMetrics2D> {
287 validate_plan(plan)?;
288 if ground_truths.is_empty() {
289 return Err(RoboticsError::InvalidParameter(
290 "BranchOut ground_truths must be non-empty".to_string(),
291 ));
292 }
293 for ground_truth in ground_truths {
294 validate_poses(ground_truth)?;
295 }
296
297 let mean_pairwise_final_distance = mean_pairwise_final_distance(&plan.trajectories);
298 let mean_pairwise_frechet = mean_pairwise_frechet(&plan.trajectories);
299 let min_ground_truth_frechet = ground_truths
300 .iter()
301 .map(|gt| {
302 plan.trajectories
303 .iter()
304 .map(|trajectory| discrete_frechet(&trajectory.poses, gt))
305 .fold(f64::INFINITY, f64::min)
306 })
307 .sum::<f64>()
308 / ground_truths.len() as f64;
309 let negative_log_likelihood = trajectory_set_nll(&plan.trajectories, ground_truths, 0.75);
310 let speed_jsd = speed_jsd(&plan.trajectories, ground_truths, 8, 4.0);
311 let expected_route_completion = plan
312 .trajectories
313 .iter()
314 .map(|trajectory| trajectory.probability * trajectory.route_completion)
315 .sum();
316
317 Ok(BranchOutMultimodalMetrics2D {
318 mode_count: plan.trajectories.len(),
319 mean_pairwise_final_distance,
320 mean_pairwise_frechet,
321 min_ground_truth_frechet,
322 negative_log_likelihood,
323 speed_jsd,
324 expected_route_completion,
325 })
326 }
327
328 pub fn simulate_closed_loop(
336 &self,
337 scene: &BranchOutDrivingScene2D,
338 obstacle_velocities: &[(f64, f64)],
339 config: BranchOutClosedLoopConfig2D,
340 ) -> RoboticsResult<BranchOutClosedLoopMetrics2D> {
341 validate_scene(scene)?;
342 if obstacle_velocities.len() != scene.obstacles.len() {
343 return Err(RoboticsError::InvalidParameter(
344 "BranchOut obstacle_velocities length must match scene.obstacles".to_string(),
345 ));
346 }
347 for &(vx, vy) in obstacle_velocities {
348 if !vx.is_finite() || !vy.is_finite() {
349 return Err(RoboticsError::InvalidParameter(
350 "BranchOut obstacle velocity must be finite".to_string(),
351 ));
352 }
353 }
354 if config.steps == 0
355 || !config.ttc_threshold.is_finite()
356 || config.ttc_threshold <= 0.0
357 || !config.goal_completion.is_finite()
358 || !(0.0..=1.0).contains(&config.goal_completion)
359 || !config.max_lateral_speed.is_finite()
360 || config.max_lateral_speed <= 0.0
361 {
362 return Err(RoboticsError::InvalidParameter(
363 "BranchOut closed-loop config steps/ttc_threshold/goal_completion/max_lateral_speed are invalid"
364 .to_string(),
365 ));
366 }
367
368 let dt = self.config.dt;
369 let mut ego = scene.start;
370 let mut obstacles = scene.obstacles.clone();
371 let mut executed_path = vec![ego];
372 let mut mode_sequence = Vec::with_capacity(config.steps);
373
374 let mut collision_steps = 0;
375 let mut min_clearance = f64::INFINITY;
376 let mut min_time_to_collision = f64::INFINITY;
377 let mut risky_ttc_steps = 0;
378
379 for _ in 0..config.steps {
380 let current_scene = BranchOutDrivingScene2D {
381 start: ego,
382 obstacles: obstacles.clone(),
383 ..scene.clone()
384 };
385 let plan = self.plan(¤t_scene)?;
386 let mode = plan
387 .best()
388 .expect("validated plan always has at least one trajectory")
389 .mode;
390
391 let start_lane = current_scene.nearest_lane_index(ego.y);
397 let target_y =
398 current_scene.lane_center(mode_target_lane(¤t_scene, start_lane, mode));
399 let desired_speed = match mode {
400 BranchOutDecisionMode2D::Yield => yield_speed(¤t_scene, ego.x),
401 _ => current_scene.desired_speed,
402 };
403 let max_lateral_step = config.max_lateral_speed * dt;
404 let mut next = ego;
405 next.speed += 0.35 * (desired_speed - ego.speed);
406 next.x += next.speed * dt;
407 next.y += (target_y - ego.y).clamp(-max_lateral_step, max_lateral_step);
408 let ego_velocity = ((next.x - ego.x) / dt, (next.y - ego.y) / dt);
409
410 for (obstacle, &(vx, vy)) in obstacles.iter_mut().zip(obstacle_velocities) {
413 obstacle.x += vx * dt;
414 obstacle.y += vy * dt;
415 }
416
417 let mut step_min_ttc = f64::INFINITY;
418 let mut step_min_clearance = f64::INFINITY;
419 for (obstacle, &velocity) in obstacles.iter().zip(obstacle_velocities) {
420 let radius_sum = obstacle.radius + self.config.ego_radius;
421 let clearance =
422 point_distance((next.x, next.y), (obstacle.x, obstacle.y)) - radius_sum;
423 step_min_clearance = step_min_clearance.min(clearance);
424 let ttc = time_to_collision(
425 (next.x, next.y),
426 ego_velocity,
427 (obstacle.x, obstacle.y),
428 velocity,
429 radius_sum,
430 );
431 step_min_ttc = step_min_ttc.min(ttc);
432 }
433 min_clearance = min_clearance.min(step_min_clearance);
434 if step_min_clearance < 0.0 {
435 collision_steps += 1;
436 }
437 min_time_to_collision = min_time_to_collision.min(step_min_ttc);
438 if step_min_ttc < config.ttc_threshold {
439 risky_ttc_steps += 1;
440 }
441
442 ego = next;
443 executed_path.push(ego);
444 mode_sequence.push(mode);
445 }
446
447 if obstacles.is_empty() {
448 min_clearance = f64::INFINITY;
449 }
450 let steps = config.steps;
451 let route_completion = (ego.x / scene.route_length).clamp(0.0, 1.0);
452 let mean_comfort_cost = closed_loop_comfort(&executed_path, dt);
453
454 Ok(BranchOutClosedLoopMetrics2D {
455 steps,
456 route_completion,
457 reached_goal: route_completion >= config.goal_completion && collision_steps == 0,
458 collision_steps,
459 no_collision_rate: (steps - collision_steps) as f64 / steps as f64,
460 min_clearance,
461 mean_comfort_cost,
462 min_time_to_collision,
463 risky_ttc_steps,
464 executed_path,
465 mode_sequence,
466 })
467 }
468
469 fn rollout_mode(
470 &self,
471 scene: &BranchOutDrivingScene2D,
472 mode: BranchOutDecisionMode2D,
473 ) -> RoboticsResult<BranchOutTrajectory2D> {
474 let start_lane = scene.nearest_lane_index(scene.start.y);
475 let target_lane = mode_target_lane(scene, start_lane, mode);
476 let target_y = scene.lane_center(target_lane);
477 let mut poses = Vec::with_capacity(self.config.horizon_steps + 1);
478 let mut pose = scene.start;
479 poses.push(pose);
480
481 for step in 1..=self.config.horizon_steps {
482 let phase = step as f64 / self.config.horizon_steps as f64;
483 let smooth = smoothstep(phase);
484 let desired_speed = match mode {
485 BranchOutDecisionMode2D::Yield => yield_speed(scene, pose.x),
486 _ => scene.desired_speed,
487 };
488 pose.speed += 0.35 * (desired_speed - pose.speed);
489 pose.x += pose.speed * self.config.dt;
490 pose.y = scene.start.y + (target_y - scene.start.y) * smooth;
491 poses.push(pose);
492 }
493
494 let (collision_risk, lane_penalty, comfort_cost) = self.cost_terms(scene, &poses);
495 let final_pose = *poses
496 .last()
497 .expect("BranchOut rollout always contains start and horizon poses");
498 let progress_error = (scene.route_length - final_pose.x).max(0.0);
499 let route_completion = (final_pose.x / scene.route_length).clamp(0.0, 1.0);
500 let target_route_y = scene.lane_center(start_lane);
501 let route_deviation = (final_pose.y - target_route_y).abs();
502 let cost = self.config.progress_weight * progress_error
503 + self.config.collision_weight * collision_risk
504 + self.config.lane_weight * lane_penalty
505 + self.config.comfort_weight * comfort_cost
506 + self.config.route_weight * route_deviation;
507
508 Ok(BranchOutTrajectory2D {
509 mode,
510 probability: 0.0,
511 cost,
512 poses,
513 collision_risk,
514 comfort_cost,
515 route_completion,
516 })
517 }
518
519 fn cost_terms(
520 &self,
521 scene: &BranchOutDrivingScene2D,
522 poses: &[BranchOutPose2D],
523 ) -> (f64, f64, f64) {
524 let mut collision_risk = 0.0;
525 let mut lane_penalty = 0.0;
526 let mut comfort_cost = 0.0;
527 let road_half_width = (scene.lane_count_each_side as f64 + 0.5) * scene.lane_width;
528
529 for pose in poses {
530 for obstacle in &scene.obstacles {
531 let clearance = point_distance((pose.x, pose.y), (obstacle.x, obstacle.y))
532 - obstacle.radius
533 - self.config.ego_radius;
534 if clearance < 0.0 {
535 collision_risk += (1.0 - clearance).powi(2);
536 } else {
537 collision_risk += 0.03 / (clearance + 0.3);
538 }
539 }
540 if pose.y.abs() > road_half_width {
541 lane_penalty += (pose.y.abs() - road_half_width).powi(2);
542 }
543 }
544
545 for window in poses.windows(3) {
546 let ay0 = window[1].y - window[0].y;
547 let ay1 = window[2].y - window[1].y;
548 comfort_cost += (ay1 - ay0).powi(2) / (self.config.dt * self.config.dt);
549 comfort_cost += (window[2].speed - window[1].speed).powi(2);
550 }
551
552 let norm = poses.len() as f64;
553 (
554 collision_risk / norm,
555 lane_penalty / norm,
556 comfort_cost / norm,
557 )
558 }
559}
560
561fn assign_mixture_probabilities(
562 trajectories: &mut [BranchOutTrajectory2D],
563 temperature: f64,
564) -> RoboticsResult<()> {
565 let min_cost = trajectories
566 .iter()
567 .map(|trajectory| trajectory.cost)
568 .fold(f64::INFINITY, f64::min);
569 let mut weight_sum = 0.0;
570 for trajectory in trajectories.iter_mut() {
571 trajectory.probability = (-(trajectory.cost - min_cost) / temperature).exp();
572 weight_sum += trajectory.probability;
573 }
574 if weight_sum <= 0.0 || !weight_sum.is_finite() {
575 return Err(RoboticsError::PlanningError(
576 "BranchOut mixture weights collapsed".to_string(),
577 ));
578 }
579 for trajectory in trajectories {
580 trajectory.probability /= weight_sum;
581 }
582 Ok(())
583}
584
585fn mode_target_lane(
586 scene: &BranchOutDrivingScene2D,
587 start_lane: i32,
588 mode: BranchOutDecisionMode2D,
589) -> i32 {
590 match mode {
591 BranchOutDecisionMode2D::KeepLane | BranchOutDecisionMode2D::Yield => start_lane,
592 BranchOutDecisionMode2D::LaneChangeLeft => (start_lane + 1).min(scene.lane_count_each_side),
593 BranchOutDecisionMode2D::LaneChangeRight => {
594 (start_lane - 1).max(-scene.lane_count_each_side)
595 }
596 }
597}
598
599fn nearest_obstacle_x(scene: &BranchOutDrivingScene2D) -> Option<f64> {
600 scene
601 .obstacles
602 .iter()
603 .filter(|obstacle| obstacle.x >= scene.start.x)
604 .map(|obstacle| obstacle.x)
605 .min_by(f64::total_cmp)
606}
607
608fn yield_speed(scene: &BranchOutDrivingScene2D, ego_x: f64) -> f64 {
609 let Some(obstacle_x) = nearest_obstacle_x(scene) else {
610 return scene.desired_speed;
611 };
612 let stop_x = obstacle_x - 1.25;
613 if ego_x >= stop_x {
614 0.0
615 } else {
616 let distance_to_stop = (stop_x - ego_x).max(0.0);
617 (0.75 * scene.desired_speed).min(distance_to_stop)
618 }
619}
620
621fn mean_pairwise_final_distance(trajectories: &[BranchOutTrajectory2D]) -> f64 {
622 let mut sum = 0.0;
623 let mut count = 0;
624 for i in 0..trajectories.len() {
625 for j in i + 1..trajectories.len() {
626 let a = trajectories[i].final_pose();
627 let b = trajectories[j].final_pose();
628 sum += point_distance((a.x, a.y), (b.x, b.y));
629 count += 1;
630 }
631 }
632 if count == 0 {
633 0.0
634 } else {
635 sum / count as f64
636 }
637}
638
639fn mean_pairwise_frechet(trajectories: &[BranchOutTrajectory2D]) -> f64 {
640 let mut sum = 0.0;
641 let mut count = 0;
642 for i in 0..trajectories.len() {
643 for j in i + 1..trajectories.len() {
644 sum += discrete_frechet(&trajectories[i].poses, &trajectories[j].poses);
645 count += 1;
646 }
647 }
648 if count == 0 {
649 0.0
650 } else {
651 sum / count as f64
652 }
653}
654
655fn trajectory_set_nll(
656 trajectories: &[BranchOutTrajectory2D],
657 ground_truths: &[Vec<BranchOutPose2D>],
658 sigma: f64,
659) -> f64 {
660 let variance = sigma * sigma;
661 let normalizer = 2.0 * std::f64::consts::PI * variance;
662 let mut nll = 0.0;
663 for ground_truth in ground_truths {
664 let gt_final = *ground_truth
665 .last()
666 .expect("validated ground-truth trajectory is non-empty");
667 let likelihood = trajectories
668 .iter()
669 .map(|trajectory| {
670 let final_pose = trajectory.final_pose();
671 let distance_sq =
672 squared_distance((final_pose.x, final_pose.y), (gt_final.x, gt_final.y));
673 trajectory.probability * (-0.5 * distance_sq / variance).exp() / normalizer
674 })
675 .sum::<f64>()
676 .max(EPS);
677 nll -= likelihood.ln();
678 }
679 nll / ground_truths.len() as f64
680}
681
682fn speed_jsd(
683 trajectories: &[BranchOutTrajectory2D],
684 ground_truths: &[Vec<BranchOutPose2D>],
685 bins: usize,
686 max_speed: f64,
687) -> f64 {
688 let mut predicted = vec![0.0; bins];
689 for trajectory in trajectories {
690 for pose in &trajectory.poses {
691 predicted[speed_bin(pose.speed, bins, max_speed)] += trajectory.probability;
692 }
693 }
694 normalize_distribution(&mut predicted);
695
696 let mut truth = vec![0.0; bins];
697 for trajectory in ground_truths {
698 for pose in trajectory {
699 truth[speed_bin(pose.speed, bins, max_speed)] += 1.0;
700 }
701 }
702 normalize_distribution(&mut truth);
703
704 let mixture = predicted
705 .iter()
706 .zip(&truth)
707 .map(|(p, q)| 0.5 * (p + q))
708 .collect::<Vec<_>>();
709 0.5 * kl_divergence(&predicted, &mixture) + 0.5 * kl_divergence(&truth, &mixture)
710}
711
/// Index of the equal-width histogram bin over `[0, max_speed]` containing
/// `speed`. Out-of-range speeds are clamped into the first or last bin.
fn speed_bin(speed: f64, bins: usize, max_speed: f64) -> usize {
    let fraction = speed.clamp(0.0, max_speed) / max_speed;
    let raw_bin = (fraction * bins as f64).floor();
    let last_bin = (bins - 1) as f64;
    raw_bin.min(last_bin) as usize
}
717
/// Scales `values` in place so they sum to one. Leaves them untouched when the
/// sum is not strictly positive (all-zero, empty, or NaN input).
fn normalize_distribution(values: &mut [f64]) {
    let total: f64 = values.iter().sum();
    if total <= 0.0 || total.is_nan() {
        return;
    }
    values.iter_mut().for_each(|value| *value /= total);
}
726
/// Kullback-Leibler divergence `sum p * ln(p / q)` in nats. Entries where either
/// `p` or `q` is non-positive are skipped.
fn kl_divergence(p: &[f64], q: &[f64]) -> f64 {
    p.iter()
        .zip(q)
        .filter_map(|(&pi, &qi)| (pi > 0.0 && qi > 0.0).then(|| pi * (pi / qi).ln()))
        .sum()
}
734
735fn discrete_frechet(a: &[BranchOutPose2D], b: &[BranchOutPose2D]) -> f64 {
736 let mut ca = vec![vec![0.0; b.len()]; a.len()];
737 for i in 0..a.len() {
738 for j in 0..b.len() {
739 let distance = point_distance((a[i].x, a[i].y), (b[j].x, b[j].y));
740 ca[i][j] = if i == 0 && j == 0 {
741 distance
742 } else if i == 0 {
743 ca[i][j - 1].max(distance)
744 } else if j == 0 {
745 ca[i - 1][j].max(distance)
746 } else {
747 ca[i - 1][j]
748 .min(ca[i - 1][j - 1])
749 .min(ca[i][j - 1])
750 .max(distance)
751 };
752 }
753 }
754 ca[a.len() - 1][b.len() - 1]
755}
756
757fn time_to_collision(
761 ego: (f64, f64),
762 ego_velocity: (f64, f64),
763 obstacle: (f64, f64),
764 obstacle_velocity: (f64, f64),
765 radius_sum: f64,
766) -> f64 {
767 let px = obstacle.0 - ego.0;
768 let py = obstacle.1 - ego.1;
769 let vx = obstacle_velocity.0 - ego_velocity.0;
770 let vy = obstacle_velocity.1 - ego_velocity.1;
771 let distance_sq = px * px + py * py;
772 let radius_sq = radius_sum * radius_sum;
773 if distance_sq <= radius_sq {
774 return 0.0;
775 }
776 let a = vx * vx + vy * vy;
777 if a <= EPS {
778 return f64::INFINITY;
779 }
780 let b = 2.0 * (px * vx + py * vy);
781 let c = distance_sq - radius_sq;
782 let discriminant = b * b - 4.0 * a * c;
783 if discriminant < 0.0 {
784 return f64::INFINITY;
785 }
786 let root = (-b - discriminant.sqrt()) / (2.0 * a);
787 if root >= 0.0 {
788 root
789 } else {
790 f64::INFINITY
791 }
792}
793
794fn closed_loop_comfort(path: &[BranchOutPose2D], dt: f64) -> f64 {
797 if path.len() < 3 {
798 return 0.0;
799 }
800 let mut comfort = 0.0;
801 for window in path.windows(3) {
802 let ay0 = window[1].y - window[0].y;
803 let ay1 = window[2].y - window[1].y;
804 comfort += (ay1 - ay0).powi(2) / (dt * dt);
805 comfort += (window[2].speed - window[1].speed).powi(2);
806 }
807 comfort / (path.len() - 2) as f64
808}
809
/// Cubic ease `3u^2 - 2u^3` with `u = clamp(t, 0, 1)`: 0 at `t <= 0`, 1 at `t >= 1`,
/// with zero slope at both ends.
fn smoothstep(t: f64) -> f64 {
    let u = f64::clamp(t, 0.0, 1.0);
    u * u * (3.0 - 2.0 * u)
}
814
/// Squared Euclidean distance between two planar points.
fn squared_distance(a: (f64, f64), b: (f64, f64)) -> f64 {
    let (ax, ay) = a;
    let (bx, by) = b;
    let (dx, dy) = (ax - bx, ay - by);
    dx * dx + dy * dy
}

/// Euclidean distance between two planar points.
fn point_distance(a: (f64, f64), b: (f64, f64)) -> f64 {
    f64::sqrt(squared_distance(a, b))
}
824
825fn validate_config(config: &BranchOutPlannerConfig2D) -> RoboticsResult<()> {
826 if config.horizon_steps == 0 || config.modes.is_empty() {
827 return Err(RoboticsError::InvalidParameter(
828 "BranchOut horizon_steps and modes must be non-empty".to_string(),
829 ));
830 }
831 for (label, value) in [
832 ("BranchOut dt", config.dt),
833 ("BranchOut ego_radius", config.ego_radius),
834 (
835 "BranchOut probability_temperature",
836 config.probability_temperature,
837 ),
838 ("BranchOut progress_weight", config.progress_weight),
839 ("BranchOut collision_weight", config.collision_weight),
840 ("BranchOut lane_weight", config.lane_weight),
841 ("BranchOut comfort_weight", config.comfort_weight),
842 ("BranchOut route_weight", config.route_weight),
843 ] {
844 if value < 0.0 || !value.is_finite() {
845 return Err(RoboticsError::InvalidParameter(format!(
846 "{label} must be finite and non-negative"
847 )));
848 }
849 }
850 if config.dt <= 0.0 || config.ego_radius <= 0.0 || config.probability_temperature <= 0.0 {
851 return Err(RoboticsError::InvalidParameter(
852 "BranchOut dt, ego_radius, and probability_temperature must be positive".to_string(),
853 ));
854 }
855 Ok(())
856}
857
858fn validate_scene(scene: &BranchOutDrivingScene2D) -> RoboticsResult<()> {
859 validate_pose(scene.start)?;
860 if scene.lane_width <= 0.0
861 || !scene.lane_width.is_finite()
862 || scene.route_length <= 0.0
863 || !scene.route_length.is_finite()
864 || scene.desired_speed <= 0.0
865 || !scene.desired_speed.is_finite()
866 {
867 return Err(RoboticsError::InvalidParameter(
868 "BranchOut scene lane_width, route_length, and desired_speed must be positive"
869 .to_string(),
870 ));
871 }
872 if scene.lane_count_each_side < 0 {
873 return Err(RoboticsError::InvalidParameter(
874 "BranchOut lane_count_each_side must be non-negative".to_string(),
875 ));
876 }
877 for obstacle in &scene.obstacles {
878 if !obstacle.x.is_finite() || !obstacle.y.is_finite() {
879 return Err(RoboticsError::InvalidParameter(
880 "BranchOut obstacle coordinates must be finite".to_string(),
881 ));
882 }
883 if obstacle.radius <= 0.0 || !obstacle.radius.is_finite() {
884 return Err(RoboticsError::InvalidParameter(
885 "BranchOut obstacle radius must be positive".to_string(),
886 ));
887 }
888 }
889 Ok(())
890}
891
892fn validate_plan(plan: &BranchOutPlan2D) -> RoboticsResult<()> {
893 if plan.trajectories.is_empty() {
894 return Err(RoboticsError::InvalidParameter(
895 "BranchOut plan must contain trajectories".to_string(),
896 ));
897 }
898 for trajectory in &plan.trajectories {
899 validate_poses(&trajectory.poses)?;
900 if trajectory.probability < 0.0
901 || !trajectory.probability.is_finite()
902 || !trajectory.cost.is_finite()
903 {
904 return Err(RoboticsError::InvalidParameter(
905 "BranchOut trajectory probability and cost must be finite".to_string(),
906 ));
907 }
908 }
909 Ok(())
910}
911
912fn validate_poses(poses: &[BranchOutPose2D]) -> RoboticsResult<()> {
913 if poses.is_empty() {
914 return Err(RoboticsError::InvalidParameter(
915 "BranchOut trajectory poses must be non-empty".to_string(),
916 ));
917 }
918 for &pose in poses {
919 validate_pose(pose)?;
920 }
921 Ok(())
922}
923
924fn validate_pose(pose: BranchOutPose2D) -> RoboticsResult<()> {
925 if !pose.x.is_finite() || !pose.y.is_finite() || !pose.speed.is_finite() || pose.speed < 0.0 {
926 return Err(RoboticsError::InvalidParameter(
927 "BranchOut pose must be finite with non-negative speed".to_string(),
928 ));
929 }
930 Ok(())
931}
932
// Unit tests for the BranchOut planner: mixture shape, mode geometry,
// multimodal metrics, closed-loop behaviour, and the TTC helper.
#[cfg(test)]
mod tests {
    use super::*;

    // The default config has four modes, so the plan has four trajectories whose
    // probabilities sum to one; the best one is above a uniform 0.25 share.
    #[test]
    fn branchout_emits_multiple_modes_with_normalized_probabilities() {
        let planner = BranchOutPlanner2D::new(BranchOutPlannerConfig2D::default()).unwrap();
        let scene = BranchOutDrivingScene2D::simple_overtake();
        let plan = planner.plan(&scene).unwrap();

        assert_eq!(plan.trajectories.len(), 4);
        assert!((plan.probability_sum() - 1.0).abs() < 1e-9);
        assert!(plan.best().unwrap().probability > 0.25);
    }

    // simple_overtake has lane width 1.2, and smoothstep reaches 1 at the horizon
    // end, so the lane changes finish near y = +1.2 and y = -1.2.
    #[test]
    fn lane_change_modes_end_in_distinct_lanes() {
        let planner = BranchOutPlanner2D::new(BranchOutPlannerConfig2D::default()).unwrap();
        let scene = BranchOutDrivingScene2D::simple_overtake();
        let plan = planner.plan(&scene).unwrap();

        let left = plan
            .trajectories
            .iter()
            .find(|trajectory| trajectory.mode == BranchOutDecisionMode2D::LaneChangeLeft)
            .unwrap();
        let right = plan
            .trajectories
            .iter()
            .find(|trajectory| trajectory.mode == BranchOutDecisionMode2D::LaneChangeRight)
            .unwrap();

        assert!(left.final_pose().y > 0.9);
        assert!(right.final_pose().y < -0.9);
    }

    // The ground truths are copies of three predicted trajectories, so each has
    // an exact match in the plan and the per-ground-truth minimum Fréchet is ~0.
    #[test]
    fn multimodal_metrics_reward_coverage() {
        let planner = BranchOutPlanner2D::new(BranchOutPlannerConfig2D::default()).unwrap();
        let scene = BranchOutDrivingScene2D::simple_overtake();
        let plan = planner.plan(&scene).unwrap();
        let ground_truths = plan
            .trajectories
            .iter()
            .filter(|trajectory| trajectory.mode != BranchOutDecisionMode2D::KeepLane)
            .map(|trajectory| trajectory.poses.clone())
            .collect::<Vec<_>>();
        let metrics = planner.evaluate_multimodal(&plan, &ground_truths).unwrap();

        assert_eq!(metrics.mode_count, 4);
        assert!(metrics.mean_pairwise_final_distance > 0.5);
        assert!(metrics.min_ground_truth_frechet < 0.1);
        assert!(metrics.negative_log_likelihood.is_finite());
        assert!(metrics.speed_jsd >= 0.0);
    }

    // With wider lanes and a static obstacle, the closed loop overtakes without
    // contact and reaches the 95 % completion goal.
    #[test]
    fn closed_loop_overtake_completes_route_without_collision() {
        let planner = BranchOutPlanner2D::new(BranchOutPlannerConfig2D::default()).unwrap();
        let scene = BranchOutDrivingScene2D::wide_overtake();
        let velocities = vec![(0.0, 0.0); scene.obstacles.len()];
        let metrics = planner
            .simulate_closed_loop(&scene, &velocities, BranchOutClosedLoopConfig2D::default())
            .unwrap();

        assert_eq!(metrics.executed_path.len(), metrics.steps + 1);
        assert_eq!(metrics.mode_sequence.len(), metrics.steps);
        assert_eq!(metrics.collision_steps, 0);
        assert_eq!(metrics.no_collision_rate, 1.0);
        assert!(metrics.min_clearance > 0.0);
        assert!(metrics.reached_goal);
        assert!(metrics.route_completion >= 0.95);
        assert!(metrics.min_time_to_collision > 0.0);
        assert!(metrics.mean_comfort_cost.is_finite());
    }

    // forced_yield has no side lanes (the lane-change modes collapse onto the
    // start lane), so the planner should pick Yield every step and stop short of
    // the obstacle.
    #[test]
    fn closed_loop_yields_safely_when_blocked() {
        let planner = BranchOutPlanner2D::new(BranchOutPlannerConfig2D::default()).unwrap();
        let scene = BranchOutDrivingScene2D::forced_yield();
        let metrics = planner
            .simulate_closed_loop(
                &scene,
                &[(0.0, 0.0)],
                BranchOutClosedLoopConfig2D::default(),
            )
            .unwrap();

        assert_eq!(metrics.collision_steps, 0);
        assert!(!metrics.reached_goal);
        assert!(metrics.route_completion < 0.6);
        assert!(metrics.min_clearance > 0.0);
        assert!(metrics
            .mode_sequence
            .iter()
            .all(|&mode| mode == BranchOutDecisionMode2D::Yield));
    }

    // The simulation uses no randomness, so repeated runs must match exactly.
    #[test]
    fn closed_loop_is_deterministic() {
        let planner = BranchOutPlanner2D::new(BranchOutPlannerConfig2D::default()).unwrap();
        let scene = BranchOutDrivingScene2D::simple_overtake();
        let velocities = vec![(0.0, 0.0); scene.obstacles.len()];
        let config = BranchOutClosedLoopConfig2D::default();
        let first = planner
            .simulate_closed_loop(&scene, &velocities, config)
            .unwrap();
        let second = planner
            .simulate_closed_loop(&scene, &velocities, config)
            .unwrap();

        assert_eq!(first.executed_path, second.executed_path);
        assert_eq!(first.mode_sequence, second.mode_sequence);
        assert_eq!(first.collision_steps, second.collision_steps);
    }

    // An obstacle driving toward the ego (negative vx) should produce a finite
    // TTC no larger than the static case.
    #[test]
    fn oncoming_obstacle_lowers_time_to_collision() {
        let planner = BranchOutPlanner2D::new(BranchOutPlannerConfig2D::default()).unwrap();
        let mut scene = BranchOutDrivingScene2D::simple_overtake();
        scene.obstacles = vec![BranchOutObstacle2D::new(8.5, 0.0, 0.42)];
        let static_metrics = planner
            .simulate_closed_loop(
                &scene,
                &[(0.0, 0.0)],
                BranchOutClosedLoopConfig2D::default(),
            )
            .unwrap();
        let oncoming_metrics = planner
            .simulate_closed_loop(
                &scene,
                &[(-1.6, 0.0)],
                BranchOutClosedLoopConfig2D::default(),
            )
            .unwrap();

        assert!(oncoming_metrics.min_time_to_collision <= static_metrics.min_time_to_collision);
        assert!(oncoming_metrics.min_time_to_collision.is_finite());
    }

    // One velocity per obstacle is required; an empty list must be rejected.
    #[test]
    fn closed_loop_rejects_mismatched_velocities() {
        let planner = BranchOutPlanner2D::new(BranchOutPlannerConfig2D::default()).unwrap();
        let scene = BranchOutDrivingScene2D::simple_overtake();
        assert!(planner
            .simulate_closed_loop(&scene, &[], BranchOutClosedLoopConfig2D::default())
            .is_err());
    }

    // Closing: a gap of 5 shrinks at speed 2 until it reaches radius 1, giving
    // (5 - 1) / 2 = 2 s. Separating: infinite. Already overlapping: 0.
    #[test]
    fn time_to_collision_detects_closing_and_separating() {
        let closing = time_to_collision((0.0, 0.0), (1.0, 0.0), (5.0, 0.0), (-1.0, 0.0), 1.0);
        assert!((closing - 2.0).abs() < 1e-9);
        let separating = time_to_collision((0.0, 0.0), (-1.0, 0.0), (5.0, 0.0), (1.0, 0.0), 1.0);
        assert!(separating.is_infinite());
        let overlapping = time_to_collision((0.0, 0.0), (0.0, 0.0), (0.5, 0.0), (0.0, 0.0), 1.0);
        assert_eq!(overlapping, 0.0);
    }
}