Skip to main content

rust_robotics_planning/state_lattice/
state_lattice_planner.rs

1//! State Lattice Planner
2//!
3//! Implements state lattice planning for path planning.
4//! Uses a model predictive trajectory generator to create smooth paths.
5//!
6//! Based on:
7//! - PythonRobotics State Lattice Planner by Atsushi Sakai
8//! - "State Space Sampling of Feasible Motions for High-Performance Mobile Robot Navigation"
9
10use rust_robotics_core::{Obstacles, Path2D, Point2D, RoboticsError};
11
12use super::motion_model::{normalize_angle, MotionModel, MotionModelConfig};
13use super::trajectory_generator::{
14    LookupTable, TargetState, TrajectoryGenerator, TrajectoryGeneratorConfig, TrajectoryParams,
15};
16use nalgebra::Vector3;
17
/// Configuration for [`StateLattice`].
///
/// Bundles the motion-model and trajectory-generator settings with the
/// parameters of the three target-state samplers: uniform polar,
/// goal-biased polar, and lane sampling.
#[derive(Debug, Clone)]
pub struct StateLatticeConfig {
    /// Motion model configuration
    pub motion_config: MotionModelConfig,
    /// Trajectory generator configuration
    pub trajectory_config: TrajectoryGeneratorConfig,

    // Polar sampling parameters (shared by uniform and biased polar sampling;
    // `nxy` and `d` are also used by lane sampling).
    /// Number of position samples: polar angles for polar sampling, lateral
    /// offsets for lane sampling
    pub nxy: usize,
    /// Number of heading samples generated per polar position
    pub nh: usize,
    /// Sampling distance \[m\]: radius of the polar sampling arc and the
    /// longitudinal (x) position of lane samples
    pub d: f64,
    /// Minimum polar angle of the sampling arc \[rad\]
    pub a_min: f64,
    /// Maximum polar angle of the sampling arc \[rad\]
    pub a_max: f64,
    /// Minimum heading offset, added to each sample's polar angle \[rad\]
    pub p_min: f64,
    /// Maximum heading offset, added to each sample's polar angle \[rad\]
    pub p_max: f64,

    // Lane sampling parameters
    /// Lateral (y) position of the lane center, in the same units as `d`
    pub lane_center: f64,
    /// Lane heading; also used as the yaw of every lane sample \[rad\]
    pub lane_heading: f64,
    /// Lane width \[m\]
    pub lane_width: f64,
    /// Vehicle width \[m\]; lane samples are spread over
    /// `lane_width - vehicle_width`
    pub vehicle_width: f64,

    // Biased sampling parameters
    /// Number of angular bins for biased polar sampling; `ns - 1` candidate
    /// angles are scored against the goal direction
    pub ns: usize,
}
56
57impl Default for StateLatticeConfig {
58    fn default() -> Self {
59        Self {
60            motion_config: MotionModelConfig::default(),
61            trajectory_config: TrajectoryGeneratorConfig::default(),
62            nxy: 5,
63            nh: 3,
64            d: 20.0,
65            a_min: -45.0_f64.to_radians(),
66            a_max: 45.0_f64.to_radians(),
67            p_min: -45.0_f64.to_radians(),
68            p_max: 45.0_f64.to_radians(),
69            lane_center: 0.0,
70            lane_heading: 0.0,
71            lane_width: 3.0,
72            vehicle_width: 1.0,
73            ns: 10,
74        }
75    }
76}
77
78/// Target state for planning
79#[derive(Debug, Clone, Copy)]
80pub struct TargetPose {
81    pub x: f64,
82    pub y: f64,
83    pub yaw: f64,
84}
85
86impl TargetPose {
87    pub fn new(x: f64, y: f64, yaw: f64) -> Self {
88        Self { x, y, yaw }
89    }
90
91    pub fn to_target_state(&self) -> TargetState {
92        Vector3::new(self.x, self.y, self.yaw)
93    }
94}
95
96/// Generated trajectory
97#[derive(Debug, Clone)]
98pub struct Trajectory {
99    pub x: Vec<f64>,
100    pub y: Vec<f64>,
101    pub yaw: Vec<f64>,
102    pub params: TrajectoryParams,
103}
104
105impl Trajectory {
106    pub fn to_path(&self) -> Path2D {
107        let points: Vec<Point2D> = self
108            .x
109            .iter()
110            .zip(self.y.iter())
111            .map(|(&x, &y)| Point2D::new(x, y))
112            .collect();
113        Path2D::from_points(points)
114    }
115
116    pub fn len(&self) -> usize {
117        self.x.len()
118    }
119
120    pub fn is_empty(&self) -> bool {
121        self.x.is_empty()
122    }
123}
124
/// State Lattice Planner.
///
/// Samples terminal target poses (polar or lane based). It then connects the
/// origin to each target with a trajectory optimized by the trajectory
/// generator, seeded from the nearest lookup-table entry.
pub struct StateLattice {
    /// Sampling and model parameters.
    config: StateLatticeConfig,
    /// Optimizer that fits trajectory parameters to each target state.
    trajectory_generator: TrajectoryGenerator,
    /// Table searched with `find_nearest` for initial `km`/`kf` guesses.
    lookup_table: LookupTable,
}
131
132impl StateLattice {
133    pub fn new(config: StateLatticeConfig) -> Self {
134        let motion_model = MotionModel::new(config.motion_config.clone());
135        let trajectory_generator =
136            TrajectoryGenerator::new(motion_model, config.trajectory_config.clone());
137        let lookup_table = LookupTable::generate_default();
138
139        Self {
140            config,
141            trajectory_generator,
142            lookup_table,
143        }
144    }
145
146    pub fn with_defaults() -> Self {
147        Self::new(StateLatticeConfig::default())
148    }
149
150    /// Set lookup table from CSV data
151    pub fn set_lookup_table_from_csv(&mut self, csv_data: &str) {
152        self.lookup_table = LookupTable::from_csv(csv_data);
153    }
154
155    /// Set initial curvature
156    pub fn set_initial_curvature(&mut self, k0: f64) {
157        self.trajectory_generator.set_k0(k0);
158    }
159
160    /// Get configuration
161    pub fn config(&self) -> &StateLatticeConfig {
162        &self.config
163    }
164
165    // ========================================================================
166    // State Sampling Methods
167    // ========================================================================
168
169    /// Calculate uniform polar states
170    pub fn calc_uniform_polar_states(&self) -> Vec<TargetPose> {
171        let config = &self.config;
172
173        let angle_samples: Vec<f64> = (0..config.nxy)
174            .map(|i| {
175                if config.nxy > 1 {
176                    i as f64 / (config.nxy - 1) as f64
177                } else {
178                    0.5
179                }
180            })
181            .collect();
182
183        self.sample_states(&angle_samples)
184    }
185
186    /// Calculate biased polar states toward a goal
187    pub fn calc_biased_polar_states(&self, goal_angle: f64) -> Vec<TargetPose> {
188        let config = &self.config;
189
190        if config.nxy == 0 {
191            return Vec::new();
192        }
193
194        if config.ns <= 1 || config.nxy == 1 {
195            return self.sample_states(&[0.5]);
196        }
197
198        let asi: Vec<f64> = (0..config.ns - 1)
199            .map(|i| {
200                config.a_min + (config.a_max - config.a_min) * i as f64 / (config.ns - 1) as f64
201            })
202            .collect();
203
204        let cnav: Vec<f64> = asi
205            .iter()
206            .map(|&angle| std::f64::consts::PI - (angle - goal_angle).abs())
207            .collect();
208        let cnav_sum: f64 = cnav.iter().sum();
209        let cnav_max = cnav.iter().copied().fold(f64::NEG_INFINITY, f64::max);
210
211        let normalized: Vec<f64> = cnav
212            .iter()
213            .map(|&value| (cnav_max - value) / (cnav_max * config.ns as f64 - cnav_sum))
214            .collect();
215
216        let mut cumulative = Vec::with_capacity(normalized.len());
217        let mut running_sum = 0.0;
218        for value in normalized {
219            running_sum += value;
220            cumulative.push(running_sum);
221        }
222
223        let mut angle_samples = Vec::with_capacity(config.nxy);
224        let mut li = 0usize;
225        for i in 0..config.nxy {
226            let threshold = i as f64 / (config.nxy - 1) as f64;
227            for (ii, &sample) in cumulative.iter().enumerate().take(config.ns - 1).skip(li) {
228                if ii as f64 / config.ns as f64 >= threshold {
229                    angle_samples.push(sample);
230                    li = ii.saturating_sub(1);
231                    break;
232                }
233            }
234        }
235
236        self.sample_states(&angle_samples)
237    }
238
239    /// Calculate lane states for structured driving
240    pub fn calc_lane_states(&self) -> Vec<TargetPose> {
241        let config = &self.config;
242
243        let nxy = config.nxy;
244        let d = config.d;
245        let l_center = config.lane_center;
246        let l_heading = config.lane_heading;
247        let l_width = config.lane_width;
248        let v_width = config.vehicle_width;
249
250        let mut states = Vec::new();
251
252        for i in 0..nxy {
253            let delta = if nxy > 1 {
254                -0.5 * (l_width - v_width) + (l_width - v_width) * i as f64 / (nxy - 1) as f64
255            } else {
256                0.0
257            };
258            let x = d - delta * l_heading.sin();
259            let y = l_center + delta * l_heading.cos();
260            states.push(TargetPose::new(x, y, l_heading));
261        }
262
263        states
264    }
265
266    /// Sample states from angle samples
267    fn sample_states(&self, angle_samples: &[f64]) -> Vec<TargetPose> {
268        let config = &self.config;
269        let mut states = Vec::new();
270
271        for &sample in angle_samples {
272            let angle = config.a_min + (config.a_max - config.a_min) * sample;
273            let x = config.d * angle.cos();
274            let y = config.d * angle.sin();
275
276            for j in 0..config.nh {
277                let yaw = if config.nh == 1 {
278                    (config.p_max - config.p_min) / 2.0 + angle
279                } else {
280                    config.p_min
281                        + (config.p_max - config.p_min) * j as f64 / (config.nh - 1) as f64
282                        + angle
283                };
284
285                states.push(TargetPose::new(x, y, normalize_angle(yaw)));
286            }
287        }
288
289        states
290    }
291
292    // ========================================================================
293    // Path Generation Methods
294    // ========================================================================
295
296    /// Generate paths to target states
297    pub fn generate_paths(&self, targets: &[TargetPose]) -> Vec<Trajectory> {
298        let mut paths = Vec::new();
299
300        for target in targets {
301            if let Some(trajectory) = self.generate_path_to_target(target) {
302                paths.push(trajectory);
303            }
304        }
305
306        paths
307    }
308
309    /// Generate a single path to a target
310    fn generate_path_to_target(&self, target: &TargetPose) -> Option<Trajectory> {
311        let target_state = target.to_target_state();
312        let dist = target.x.hypot(target.y);
313
314        let init_params = if let Some(nearest) = self.lookup_table.find_nearest(&target_state) {
315            Vector3::new(dist, nearest.km, nearest.kf)
316        } else {
317            Vector3::new(dist.max(1.0), 0.0, 0.0)
318        };
319
320        let result = self
321            .trajectory_generator
322            .generate_optimized(&target_state, &init_params)?;
323
324        Some(Trajectory {
325            x: result.0,
326            y: result.1,
327            yaw: result.2,
328            params: result.3,
329        })
330    }
331
332    /// Generate paths using uniform polar sampling
333    pub fn plan_uniform_polar(&self) -> Vec<Trajectory> {
334        let targets = self.calc_uniform_polar_states();
335        self.generate_paths(&targets)
336    }
337
338    /// Generate paths using biased polar sampling
339    pub fn plan_biased_polar(&self, goal_angle: f64) -> Vec<Trajectory> {
340        let targets = self.calc_biased_polar_states(goal_angle);
341        self.generate_paths(&targets)
342    }
343
344    /// Generate paths using lane sampling
345    pub fn plan_lane_states(&self) -> Vec<Trajectory> {
346        let targets = self.calc_lane_states();
347        self.generate_paths(&targets)
348    }
349
350    // ========================================================================
351    // Full Planning Interface
352    // ========================================================================
353
354    /// Plan a path from start to goal using state lattice
355    pub fn plan(&self, start: Point2D, goal: Point2D) -> Result<Path2D, RoboticsError> {
356        let dx = goal.x - start.x;
357        let dy = goal.y - start.y;
358        let distance = (dx * dx + dy * dy).sqrt();
359        let goal_angle = dy.atan2(dx);
360
361        let paths = self.plan_biased_polar(goal_angle);
362
363        if paths.is_empty() {
364            return Err(RoboticsError::PlanningError(
365                "No valid paths found".to_string(),
366            ));
367        }
368
369        let best_path = paths
370            .iter()
371            .min_by(|a, b| {
372                let a_final_x = a.x.last().unwrap_or(&0.0);
373                let a_final_y = a.y.last().unwrap_or(&0.0);
374                let b_final_x = b.x.last().unwrap_or(&0.0);
375                let b_final_y = b.y.last().unwrap_or(&0.0);
376
377                let scale = distance / self.config.d;
378                let goal_x = dx;
379                let goal_y = dy;
380
381                let a_cost =
382                    (a_final_x * scale - goal_x).powi(2) + (a_final_y * scale - goal_y).powi(2);
383                let b_cost =
384                    (b_final_x * scale - goal_x).powi(2) + (b_final_y * scale - goal_y).powi(2);
385
386                a_cost
387                    .partial_cmp(&b_cost)
388                    .unwrap_or(std::cmp::Ordering::Equal)
389            })
390            .unwrap();
391
392        let scale = distance / self.config.d;
393        let cos_angle = goal_angle.cos();
394        let sin_angle = goal_angle.sin();
395
396        let points: Vec<Point2D> = best_path
397            .x
398            .iter()
399            .zip(best_path.y.iter())
400            .map(|(&lx, &ly)| {
401                let sx = lx * scale;
402                let sy = ly * scale;
403                let wx = start.x + sx * cos_angle - sy * sin_angle;
404                let wy = start.y + sx * sin_angle + sy * cos_angle;
405                Point2D::new(wx, wy)
406            })
407            .collect();
408
409        Ok(Path2D::from_points(points))
410    }
411
412    // ========================================================================
413    // Obstacle-Aware Planning
414    // ========================================================================
415
416    /// Plan the best collision-free trajectory considering obstacles.
417    ///
418    /// Generates candidate trajectories using biased polar sampling toward the
419    /// goal, checks each against obstacles, scores them, and returns the best
420    /// collision-free trajectory.
421    ///
422    /// All coordinates are in the ego frame (vehicle at origin heading along +x).
423    /// Transform obstacles into the ego frame before calling.
424    pub fn plan_with_obstacles(
425        &self,
426        goal: Point2D,
427        obstacles: &Obstacles,
428        robot_radius: f64,
429    ) -> Result<ObstacleAwarePlanResult, RoboticsError> {
430        let goal_angle = goal.y.atan2(goal.x);
431
432        let targets = self.calc_biased_polar_states(goal_angle);
433        if targets.is_empty() {
434            return Err(RoboticsError::PlanningError(
435                "no target states sampled".to_string(),
436            ));
437        }
438
439        let paths = self.generate_paths(&targets);
440        if paths.is_empty() {
441            return Err(RoboticsError::PlanningError(
442                "no valid trajectories generated".to_string(),
443            ));
444        }
445
446        let mut valid_indices: Vec<usize> = Vec::new();
447        for (idx, traj) in paths.iter().enumerate() {
448            if !check_trajectory_collision(traj, obstacles, robot_radius) {
449                valid_indices.push(idx);
450            }
451        }
452
453        if valid_indices.is_empty() {
454            return Err(RoboticsError::PlanningError(
455                "all trajectories collide with obstacles".to_string(),
456            ));
457        }
458
459        let goal_vec = Vector3::new(goal.x, goal.y, goal_angle);
460        let mut best_idx = valid_indices[0];
461        let mut best_cost = f64::MAX;
462        for &idx in &valid_indices {
463            let cost = trajectory_cost(&paths[idx], &goal_vec, obstacles);
464            if cost < best_cost {
465                best_cost = cost;
466                best_idx = idx;
467            }
468        }
469
470        Ok(ObstacleAwarePlanResult {
471            best: paths[best_idx].clone(),
472            candidates: paths,
473            valid_indices,
474        })
475    }
476
477    /// Plan from a world-frame pose with obstacles, returning a world-frame result.
478    ///
479    /// Transforms obstacles into the ego frame, plans, then transforms the
480    /// result back to world frame.
481    pub fn plan_from_pose_with_obstacles(
482        &self,
483        pose_x: f64,
484        pose_y: f64,
485        pose_yaw: f64,
486        goal: Point2D,
487        obstacles: &Obstacles,
488        robot_radius: f64,
489    ) -> Result<ObstacleAwarePlanResult, RoboticsError> {
490        let cos_yaw = pose_yaw.cos();
491        let sin_yaw = pose_yaw.sin();
492
493        // Transform goal to ego frame.
494        let dx = goal.x - pose_x;
495        let dy = goal.y - pose_y;
496        let ego_goal = Point2D::new(cos_yaw * dx + sin_yaw * dy, -sin_yaw * dx + cos_yaw * dy);
497
498        // Transform obstacles to ego frame.
499        let ego_obs = Obstacles::from_points(
500            obstacles
501                .points
502                .iter()
503                .map(|o| {
504                    let odx = o.x - pose_x;
505                    let ody = o.y - pose_y;
506                    Point2D::new(
507                        cos_yaw * odx + sin_yaw * ody,
508                        -sin_yaw * odx + cos_yaw * ody,
509                    )
510                })
511                .collect(),
512        );
513
514        let mut result = self.plan_with_obstacles(ego_goal, &ego_obs, robot_radius)?;
515
516        // Transform results back to world frame.
517        for traj in &mut result.candidates {
518            transform_trajectory_to_world(traj, pose_x, pose_y, pose_yaw);
519        }
520        transform_trajectory_to_world(&mut result.best, pose_x, pose_y, pose_yaw);
521
522        Ok(result)
523    }
524}
525
/// Result of obstacle-aware planning.
///
/// All trajectories share one frame: the ego frame when produced by
/// `StateLattice::plan_with_obstacles`, and the world frame when produced by
/// `StateLattice::plan_from_pose_with_obstacles`.
#[derive(Debug, Clone)]
pub struct ObstacleAwarePlanResult {
    /// The lowest-cost collision-free trajectory (a copy of one of the
    /// candidates listed in `valid_indices`).
    pub best: Trajectory,
    /// All candidate trajectories (including those that collide).
    pub candidates: Vec<Trajectory>,
    /// Indices into `candidates` that are collision-free, in ascending order.
    pub valid_indices: Vec<usize>,
}
536
537/// Check if any point on a trajectory is within `radius` of any obstacle.
538fn check_trajectory_collision(traj: &Trajectory, obstacles: &Obstacles, radius: f64) -> bool {
539    let r_sq = radius * radius;
540    for (&tx, &ty) in traj.x.iter().zip(traj.y.iter()) {
541        for obs in &obstacles.points {
542            let dx = tx - obs.x;
543            let dy = ty - obs.y;
544            if dx * dx + dy * dy <= r_sq {
545                return true;
546            }
547        }
548    }
549    false
550}
551
552/// Score a trajectory: lower is better.
553///
554/// Combines goal-deviation cost and obstacle-proximity cost.
555fn trajectory_cost(traj: &Trajectory, goal: &Vector3<f64>, obstacles: &Obstacles) -> f64 {
556    if traj.is_empty() {
557        return f64::MAX;
558    }
559
560    let last_x = *traj.x.last().unwrap();
561    let last_y = *traj.y.last().unwrap();
562    let last_yaw = *traj.yaw.last().unwrap();
563
564    // Goal deviation.
565    let dx = last_x - goal[0];
566    let dy = last_y - goal[1];
567    let dyaw = normalize_angle(last_yaw - goal[2]);
568    let goal_cost = (dx * dx + dy * dy).sqrt() + dyaw.abs();
569
570    // Inverse minimum distance to nearest obstacle (clearance incentive).
571    let mut min_dist_sq = f64::MAX;
572    if !obstacles.is_empty() {
573        for (&tx, &ty) in traj.x.iter().zip(traj.y.iter()) {
574            for obs in &obstacles.points {
575                let d2 = (tx - obs.x).powi(2) + (ty - obs.y).powi(2);
576                if d2 < min_dist_sq {
577                    min_dist_sq = d2;
578                }
579            }
580        }
581    }
582    let obstacle_cost = if min_dist_sq < f64::MAX && min_dist_sq > 1e-12 {
583        1.0 / min_dist_sq.sqrt()
584    } else {
585        0.0
586    };
587
588    // Lateral offset cost.
589    let lateral_cost = last_y.abs();
590
591    goal_cost + 5.0 * obstacle_cost + 0.5 * lateral_cost
592}
593
594/// Transform a trajectory from ego frame to world frame.
595fn transform_trajectory_to_world(traj: &mut Trajectory, ox: f64, oy: f64, oyaw: f64) {
596    let cos_yaw = oyaw.cos();
597    let sin_yaw = oyaw.sin();
598    for ((&mut ref mut x, &mut ref mut y), yaw) in traj
599        .x
600        .iter_mut()
601        .zip(traj.y.iter_mut())
602        .zip(traj.yaw.iter_mut())
603    {
604        let lx = *x;
605        let ly = *y;
606        *x = cos_yaw * lx - sin_yaw * ly + ox;
607        *y = sin_yaw * lx + cos_yaw * ly + oy;
608        *yaw = normalize_angle(*yaw + oyaw);
609    }
610}
611
612#[cfg(test)]
613mod tests {
614    use super::*;
615    use std::f64::consts::{FRAC_PI_2, FRAC_PI_8};
616
617    type TrajectoryTerminal = (f64, f64, f64, f64, f64, f64);
618
619    fn assert_pose_close(actual: &TargetPose, expected: (f64, f64, f64)) {
620        assert!((actual.x - expected.0).abs() < 1e-9);
621        assert!((actual.y - expected.1).abs() < 1e-9);
622        assert!((actual.yaw - expected.2).abs() < 1e-9);
623    }
624
625    fn assert_trajectory_close_with_tolerance(
626        actual: &Trajectory,
627        expected: TrajectoryTerminal,
628        tolerance: TrajectoryTerminal,
629    ) {
630        let observed = (
631            *actual.x.last().unwrap(),
632            *actual.y.last().unwrap(),
633            *actual.yaw.last().unwrap(),
634            actual.params[0],
635            actual.params[1],
636            actual.params[2],
637        );
638        assert!(
639            (observed.0 - expected.0).abs() < tolerance.0,
640            "x mismatch: observed={observed:?}, expected={expected:?}"
641        );
642        assert!(
643            (observed.1 - expected.1).abs() < tolerance.1,
644            "y mismatch: observed={observed:?}, expected={expected:?}"
645        );
646        assert!(
647            (observed.2 - expected.2).abs() < tolerance.2,
648            "yaw mismatch: observed={observed:?}, expected={expected:?}"
649        );
650        assert!(
651            (observed.3 - expected.3).abs() < tolerance.3,
652            "s mismatch: observed={observed:?}, expected={expected:?}"
653        );
654        assert!(
655            (observed.4 - expected.4).abs() < tolerance.4,
656            "km mismatch: observed={observed:?}, expected={expected:?}"
657        );
658        assert!(
659            (observed.5 - expected.5).abs() < tolerance.5,
660            "kf mismatch: observed={observed:?}, expected={expected:?}"
661        );
662    }
663
664    fn parse_reference_trajectories(csv: &str) -> Vec<TrajectoryTerminal> {
665        csv.lines()
666            .skip(1)
667            .filter(|line| !line.trim().is_empty())
668            .map(|line| {
669                let values: Vec<f64> = line
670                    .split(',')
671                    .map(|value| value.parse::<f64>().unwrap())
672                    .collect();
673                assert_eq!(values.len(), 6);
674                (
675                    values[0], values[1], values[2], values[3], values[4], values[5],
676                )
677            })
678            .collect()
679    }
680
681    fn assert_trajectory_table_close(
682        actual: &[Trajectory],
683        expected: &[TrajectoryTerminal],
684        tolerance: TrajectoryTerminal,
685    ) {
686        assert_eq!(actual.len(), expected.len());
687        for (actual, expected) in actual.iter().zip(expected.iter().copied()) {
688            assert_trajectory_close_with_tolerance(actual, expected, tolerance);
689        }
690    }
691
692    #[test]
693    fn test_state_lattice_creation() {
694        let planner = StateLattice::with_defaults();
695        assert!(planner.config.nxy > 0);
696    }
697
698    #[test]
699    fn test_uniform_polar_states() {
700        let planner = StateLattice::with_defaults();
701        let states = planner.calc_uniform_polar_states();
702        assert!(!states.is_empty());
703    }
704
705    #[test]
706    fn test_uniform_polar_states_match_upstream_reference() {
707        let planner = StateLattice::new(StateLatticeConfig {
708            nxy: 5,
709            nh: 3,
710            d: 20.0,
711            a_min: -45.0_f64.to_radians(),
712            a_max: 45.0_f64.to_radians(),
713            p_min: -45.0_f64.to_radians(),
714            p_max: 45.0_f64.to_radians(),
715            ..Default::default()
716        });
717        let states = planner.calc_uniform_polar_states();
718
719        assert_eq!(states.len(), 15);
720        assert_pose_close(
721            &states[0],
722            (14.142135623730951, -14.14213562373095, -FRAC_PI_2),
723        );
724        assert_pose_close(
725            &states[5],
726            (18.477590650225736, -7.653668647301796, FRAC_PI_8),
727        );
728        assert_pose_close(
729            states.last().unwrap(),
730            (14.142135623730951, 14.14213562373095, FRAC_PI_2),
731        );
732    }
733
734    #[test]
735    fn test_biased_polar_states() {
736        let planner = StateLattice::with_defaults();
737        let states = planner.calc_biased_polar_states(0.0);
738        assert!(!states.is_empty());
739    }
740
741    #[test]
742    fn test_biased_polar_states_match_upstream_reference_window() {
743        let planner = StateLattice::new(StateLatticeConfig {
744            nxy: 30,
745            nh: 2,
746            d: 20.0,
747            a_min: -45.0_f64.to_radians(),
748            a_max: 45.0_f64.to_radians(),
749            p_min: -20.0_f64.to_radians(),
750            p_max: 20.0_f64.to_radians(),
751            ns: 100,
752            ..Default::default()
753        });
754        let states = planner.calc_biased_polar_states(0.0);
755
756        assert_eq!(states.len(), 58);
757        assert_pose_close(
758            &states[0],
759            (14.554768777999886, -13.717095385647784, -1.1048434557770916),
760        );
761        assert_pose_close(
762            &states[5],
763            (
764                16.88791878728749,
765                -10.714392144867987,
766                -0.216_293_881_998_737_4,
767            ),
768        );
769        assert_pose_close(
770            states.last().unwrap(),
771            (16.077756451864964, 11.89561883529035, 0.9860589731081659),
772        );
773    }
774
775    #[test]
776    fn test_lane_states() {
777        let planner = StateLattice::with_defaults();
778        let states = planner.calc_lane_states();
779        assert!(!states.is_empty());
780    }
781
782    #[test]
783    fn test_lane_states_match_upstream_reference() {
784        let planner = StateLattice::new(StateLatticeConfig {
785            lane_center: 10.0,
786            lane_heading: 0.0,
787            lane_width: 3.0,
788            vehicle_width: 1.0,
789            d: 10.0,
790            nxy: 5,
791            ..Default::default()
792        });
793        let states = planner.calc_lane_states();
794
795        assert_eq!(states.len(), 5);
796        assert_pose_close(&states[0], (10.0, 9.0, 0.0));
797        assert_pose_close(&states[2], (10.0, 10.0, 0.0));
798        assert_pose_close(states.last().unwrap(), (10.0, 11.0, 0.0));
799    }
800
801    #[test]
802    fn test_generate_paths() {
803        let planner = StateLattice::with_defaults();
804        let targets = vec![
805            TargetPose::new(10.0, 0.0, 0.0),
806            TargetPose::new(10.0, 2.0, 0.2),
807        ];
808        let paths = planner.generate_paths(&targets);
809        assert!(!paths.is_empty() || targets.is_empty());
810    }
811
812    #[test]
813    fn test_generate_lane_paths_match_upstream_full_example() {
814        let planner = StateLattice::new(StateLatticeConfig {
815            lane_center: 10.0,
816            lane_heading: 0.0,
817            lane_width: 3.0,
818            vehicle_width: 1.0,
819            d: 10.0,
820            nxy: 5,
821            ..Default::default()
822        });
823        let targets = planner.calc_lane_states();
824        let paths = planner.generate_paths(&targets);
825        let expected =
826            parse_reference_trajectories(include_str!("testdata/lane_state_sampling_test1.csv"));
827
828        assert_eq!(targets.len(), expected.len());
829        assert_trajectory_table_close(&paths, &expected, (5e-3, 5e-3, 5e-4, 5e-3, 5e-4, 5e-4));
830    }
831
832    #[test]
833    fn test_generate_uniform_paths_match_upstream_full_example() {
834        let planner = StateLattice::new(StateLatticeConfig {
835            nxy: 5,
836            nh: 3,
837            d: 20.0,
838            a_min: -45.0_f64.to_radians(),
839            a_max: 45.0_f64.to_radians(),
840            p_min: -45.0_f64.to_radians(),
841            p_max: 45.0_f64.to_radians(),
842            ..Default::default()
843        });
844        let targets = planner.calc_uniform_polar_states();
845        let paths = planner.generate_paths(&targets);
846        let expected = parse_reference_trajectories(include_str!(
847            "testdata/uniform_terminal_state_sampling_test1.csv"
848        ));
849
850        assert_eq!(targets.len(), expected.len());
851        assert_trajectory_table_close(
852            &paths,
853            &expected,
854            (1.5e-1, 1.5e-1, 5e-2, 2e-1, 1e-2, 1.5e-2),
855        );
856    }
857
858    #[test]
859    fn test_generate_uniform2_paths_match_upstream_full_example() {
860        let mut planner = StateLattice::new(StateLatticeConfig {
861            nxy: 6,
862            nh: 3,
863            d: 20.0,
864            a_min: 10.0_f64.to_radians(),
865            a_max: 45.0_f64.to_radians(),
866            p_min: -20.0_f64.to_radians(),
867            p_max: 20.0_f64.to_radians(),
868            ..Default::default()
869        });
870        planner.set_initial_curvature(0.1);
871        let targets = planner.calc_uniform_polar_states();
872        let paths = planner.generate_paths(&targets);
873        let expected = parse_reference_trajectories(include_str!(
874            "testdata/uniform_terminal_state_sampling_test2.csv"
875        ));
876
877        assert_eq!(targets.len(), expected.len());
878        assert_trajectory_table_close(&paths, &expected, (1e-1, 1e-1, 2e-2, 2e-1, 1e-2, 1e-2));
879    }
880
881    #[test]
882    fn test_generate_biased_paths_match_upstream_full_example() {
883        let planner = StateLattice::new(StateLatticeConfig {
884            nxy: 30,
885            nh: 2,
886            d: 20.0,
887            a_min: -45.0_f64.to_radians(),
888            a_max: 45.0_f64.to_radians(),
889            p_min: -20.0_f64.to_radians(),
890            p_max: 20.0_f64.to_radians(),
891            ns: 100,
892            ..Default::default()
893        });
894        let targets = planner.calc_biased_polar_states(0.0);
895        let paths = planner.generate_paths(&targets);
896        let expected = parse_reference_trajectories(include_str!(
897            "testdata/biased_terminal_state_sampling_test1.csv"
898        ));
899
900        assert_eq!(targets.len(), expected.len());
901        assert_trajectory_table_close(&paths, &expected, (1.5e-1, 1.5e-1, 3e-2, 2e-1, 1e-2, 1e-2));
902    }
903
/// Biased polar sampling with a 30-degree bias angle and a single heading sample must
/// reproduce the upstream reference table in `testdata/biased_terminal_state_sampling_test2.csv`.
#[test]
fn test_generate_biased2_paths_match_upstream_full_example() {
    let config = StateLatticeConfig {
        nxy: 30,
        nh: 1,
        d: 20.0,
        a_min: 0.0,
        a_max: 45.0_f64.to_radians(),
        p_min: -20.0_f64.to_radians(),
        p_max: 20.0_f64.to_radians(),
        ns: 100,
        ..Default::default()
    };
    let planner = StateLattice::new(config);

    let bias_angle = 30.0_f64.to_radians();
    let sampled_targets = planner.calc_biased_polar_states(bias_angle);
    let generated = planner.generate_paths(&sampled_targets);
    let reference = parse_reference_trajectories(include_str!(
        "testdata/biased_terminal_state_sampling_test2.csv"
    ));

    // One reference row is expected per sampled terminal state.
    assert_eq!(sampled_targets.len(), reference.len());

    // Per-column tolerances, passed positionally to `assert_trajectory_table_close`.
    let tolerances = (1e-1, 1e-1, 2e-2, 2e-1, 1e-2, 1e-2);
    assert_trajectory_table_close(&generated, &reference, tolerances);
}
926
/// Smoke test: uniform polar planning with the default configuration runs to completion
/// without panicking. The result itself is intentionally not inspected.
#[test]
fn test_plan_uniform_polar() {
    let default_planner = StateLattice::with_defaults();
    let _uniform_result = default_planner.plan_uniform_polar();
}
932
/// `TargetPose::to_target_state` exposes x, y and yaw, in that order, at indices 0, 1 and 2.
#[test]
fn test_target_pose() {
    let (px, py, pyaw) = (10.0, 5.0, 0.5);
    let pose = TargetPose::new(px, py, pyaw);

    let converted = pose.to_target_state();

    assert_eq!(converted[0], px);
    assert_eq!(converted[1], py);
    assert_eq!(converted[2], pyaw);
}
941
/// `Trajectory::to_path` yields one path point per trajectory sample, carrying the
/// sample's (x, y) coordinates.
#[test]
fn test_trajectory_to_path() {
    let traj = Trajectory {
        x: vec![0.0, 1.0, 2.0],
        y: vec![0.0, 0.5, 1.0],
        yaw: vec![0.0, 0.1, 0.2],
        params: Vector3::new(2.0, 0.0, 0.0),
    };
    // Same (x, y) pairs as the trajectory above, in order.
    let expected = [(0.0, 0.0), (1.0, 0.5), (2.0, 1.0)];

    let path = traj.to_path();
    assert_eq!(path.len(), expected.len());

    // A length check alone would accept a path with wrong or reordered coordinates.
    for (i, (point, &(ex, ey))) in path.points.iter().zip(expected.iter()).enumerate() {
        assert!((point.x - ex).abs() < 1e-12, "points[{}].x = {}", i, point.x);
        assert!((point.y - ey).abs() < 1e-12, "points[{}].y = {}", i, point.y);
    }
}
954
/// Planning straight ahead from the origin to (20, 0) succeeds, and the resulting path
/// begins near the start point.
#[test]
fn test_plan_straight() {
    let planner = StateLattice::with_defaults();
    let start = Point2D::new(0.0, 0.0);
    let goal = Point2D::new(20.0, 0.0);

    // Unwrap instead of `if let Ok(..)`: a planning failure must fail the test,
    // not skip every assertion and pass silently.
    let path = planner
        .plan(start, goal)
        .expect("planning straight ahead to (20, 0) should succeed");

    assert!(!path.is_empty());
    let first = &path.points[0];
    assert!((first.x - start.x).abs() < 1.0, "first.x = {}", first.x);
    assert!((first.y - start.y).abs() < 1.0, "first.y = {}", first.y);
}
967
968    // ====================================================================
969    // Obstacle-aware planning tests
970    // ====================================================================
971
/// With an empty obstacle set, obstacle-aware planning succeeds and keeps every candidate.
#[test]
fn test_plan_with_obstacles_no_obstacles() {
    let planner = StateLattice::with_defaults();
    let goal = Point2D::new(20.0, 0.0);
    let obstacles = Obstacles::new();

    // `expect` reports the planner's actual error on failure; the previous
    // `assert!(is_ok())` + `unwrap()` pair only reported "assertion failed".
    let pr = planner
        .plan_with_obstacles(goal, &obstacles, 1.0)
        .expect("planning should succeed with no obstacles");

    assert!(!pr.best.is_empty());
    assert!(!pr.valid_indices.is_empty());
    // Nothing to collide with, so no candidate may be filtered out.
    assert_eq!(pr.valid_indices.len(), pr.candidates.len());
}
985
/// An obstacle on the straight-ahead path filters out some candidates, and the selected
/// best trajectory does not collide with it.
#[test]
fn test_plan_with_obstacles_avoids_blocked_path() {
    let planner = StateLattice::with_defaults();
    let goal = Point2D::new(20.0, 0.0);
    let robot_radius = 1.0;
    // Place an obstacle directly on the straight-ahead path.
    let obstacles = Obstacles::from_points(vec![Point2D::new(10.0, 0.0)]);

    let pr = planner
        .plan_with_obstacles(goal, &obstacles, robot_radius)
        .expect("planning should succeed when only part of the lattice is blocked");

    assert!(
        pr.valid_indices.len() < pr.candidates.len(),
        "some trajectories should have been filtered out"
    );
    // Planning starts at the origin with zero heading, so `pr.best` is already in the
    // obstacles' frame and can be checked directly.
    assert!(
        !check_trajectory_collision(&pr.best, &obstacles, robot_radius),
        "best trajectory should not pass through the obstacle at (10, 0)"
    );
}
1004
/// A dense obstacle wall across the whole sampling fan makes planning fail.
#[test]
fn test_plan_with_obstacles_all_blocked() {
    let planner = StateLattice::new(StateLatticeConfig {
        nxy: 5,
        nh: 3,
        d: 20.0,
        ..Default::default()
    });

    // Wall at x = 10: one obstacle every metre from y = -30 to y = 30 inclusive.
    let wall: Vec<Point2D> = (-30..=30)
        .map(|row| Point2D::new(10.0, row as f64))
        .collect();
    let obstacles = Obstacles::from_points(wall);

    let outcome = planner.plan_with_obstacles(Point2D::new(20.0, 0.0), &obstacles, 3.0);
    assert!(outcome.is_err(), "all trajectories should be blocked");
}
1023
/// `check_trajectory_collision` flags an obstacle sitting on the trajectory and ignores
/// one far off to the side.
#[test]
fn test_collision_check() {
    // Straight line along the x-axis, samples at x = 0, 1, ..., 5.
    let straight = Trajectory {
        x: vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0],
        y: vec![0.0; 6],
        yaw: vec![0.0; 6],
        params: Vector3::new(5.0, 0.0, 0.0),
    };
    let radius = 0.6;

    // On the line, 0.5 m from the nearest samples (x = 2 and x = 3).
    let on_path = Obstacles::from_points(vec![Point2D::new(2.5, 0.0)]);
    assert!(check_trajectory_collision(&straight, &on_path, radius));

    // 10 m off to the side of the line.
    let off_path = Obstacles::from_points(vec![Point2D::new(2.5, 10.0)]);
    assert!(!check_trajectory_collision(&straight, &off_path, radius));
}
1041
/// Planning from a non-origin start pose succeeds, and the best trajectory's end point
/// advances beyond the start in x or y.
#[test]
fn test_plan_from_pose_with_obstacles() {
    let planner = StateLattice::with_defaults();
    let obstacles = Obstacles::new();
    let goal = Point2D::new(30.0, 10.0);
    let (start_x, start_y, start_yaw) = (10.0, 5.0, 0.3);

    // `expect` surfaces the planner's error instead of a bare "assertion failed".
    let pr = planner
        .plan_from_pose_with_obstacles(start_x, start_y, start_yaw, goal, &obstacles, 1.0)
        .expect("planning from a pose with no obstacles should succeed");

    let ep_x = pr.best.x.last().expect("best trajectory should have x samples");
    let ep_y = pr.best.y.last().expect("best trajectory should have y samples");
    // The trajectory should advance from the start pose toward the goal.
    assert!(*ep_x > start_x || *ep_y > start_y, "ep = ({}, {})", ep_x, ep_y);
}
1056
/// `transform_trajectory_to_world` rotates every ego-frame sample by the pose yaw, then
/// translates it by the pose position.
#[test]
fn test_transform_trajectory_to_world() {
    let mut traj = Trajectory {
        x: vec![1.0, 2.0],
        y: vec![0.0, 0.0],
        yaw: vec![0.0, 0.0],
        params: Vector3::new(2.0, 0.0, 0.0),
    };

    let yaw = std::f64::consts::FRAC_PI_2;
    transform_trajectory_to_world(&mut traj, 5.0, 3.0, yaw);

    // A 90-degree rotation maps ego (x, 0) to (0, x); translating by (5, 3) then gives:
    //   ego (1, 0) -> world (5, 4)
    //   ego (2, 0) -> world (5, 5)
    // Every sample is checked, so a transform that touches only index 0 is caught.
    let expected = [(5.0, 4.0), (5.0, 5.0)];
    assert_eq!(traj.x.len(), expected.len());
    assert_eq!(traj.y.len(), expected.len());
    for (i, &(wx, wy)) in expected.iter().enumerate() {
        assert!((traj.x[i] - wx).abs() < 1e-9, "x[{}] = {}", i, traj.x[i]);
        assert!((traj.y[i] - wy).abs() < 1e-9, "y[{}] = {}", i, traj.y[i]);
    }
}
1073}