Skip to main content

rust_robotics_localization/
ekf.rs

//! Extended Kalman Filter (EKF) localization
//!
//! Estimates a planar robot state `[x, y, yaw, v]` with an Extended Kalman Filter,
//! combining a nonlinear motion model driven by `[v, yaw_rate]` controls with direct
//! (linear) `[x, y]` position measurements.
5
6use alloc::format;
7use alloc::string::ToString;
8use nalgebra::{Matrix2, Matrix2x4, Matrix4, Vector2, Vector4};
9#[cfg(not(feature = "std"))]
10#[allow(unused_imports)]
11// f64 math via libm on no_std targets; on std hosts the inherent methods win
12use num_traits::Float;
13use rust_robotics_core::{
14    ControlInput, Point2D, RoboticsError, RoboticsResult, State2D, StateEstimator,
15};
16
17/// State representation for EKF (x, y, yaw, velocity)
18pub type EKFState = Vector4<f64>;
19
20/// Measurement representation (x, y position)
21pub type EKFMeasurement = Vector2<f64>;
22
23/// Control input (velocity, yaw rate)
24pub type EKFControl = Vector2<f64>;
25
26/// Configuration for EKF
27#[derive(Debug, Clone)]
28pub struct EKFConfig {
29    /// Process noise covariance matrix
30    pub q: Matrix4<f64>,
31    /// Measurement noise covariance matrix
32    pub r: Matrix2<f64>,
33}
34
35impl Default for EKFConfig {
36    fn default() -> Self {
37        let mut q = Matrix4::<f64>::identity();
38        q[(0, 0)] = 0.1_f64.powi(2);
39        q[(1, 1)] = 0.1_f64.powi(2);
40        q[(2, 2)] = (1.0_f64.to_radians()).powi(2);
41        q[(3, 3)] = 0.1_f64.powi(2);
42
43        let r = Matrix2::<f64>::identity();
44
45        Self { q, r }
46    }
47}
48
49impl EKFConfig {
50    pub fn validate(&self) -> RoboticsResult<()> {
51        if self.q.iter().any(|value| !value.is_finite()) {
52            return Err(RoboticsError::InvalidParameter(
53                "EKF process noise matrix must contain only finite values".to_string(),
54            ));
55        }
56        if self.r.iter().any(|value| !value.is_finite()) {
57            return Err(RoboticsError::InvalidParameter(
58                "EKF measurement noise matrix must contain only finite values".to_string(),
59            ));
60        }
61        for i in 0..4 {
62            if self.q[(i, i)] < 0.0 {
63                return Err(RoboticsError::InvalidParameter(
64                    "EKF process noise diagonal entries must be non-negative".to_string(),
65                ));
66            }
67        }
68        for i in 0..2 {
69            if self.r[(i, i)] < 0.0 {
70                return Err(RoboticsError::InvalidParameter(
71                    "EKF measurement noise diagonal entries must be non-negative".to_string(),
72                ));
73            }
74        }
75
76        Ok(())
77    }
78}
79
80/// Extended Kalman Filter for robot localization
81pub struct EKFLocalizer {
82    /// Current state estimate [x, y, yaw, v]
83    state: EKFState,
84    /// State covariance matrix
85    covariance: Matrix4<f64>,
86    /// Configuration
87    config: EKFConfig,
88}
89
90impl EKFLocalizer {
91    /// Create a new EKF localizer
92    pub fn new(config: EKFConfig) -> Self {
93        Self::try_new(config).expect(
94            "invalid EKF configuration: noise matrices must be finite and have non-negative diagonals",
95        )
96    }
97
98    /// Create a new validated EKF localizer
99    pub fn try_new(config: EKFConfig) -> RoboticsResult<Self> {
100        config.validate()?;
101        Ok(EKFLocalizer {
102            state: EKFState::zeros(),
103            covariance: Matrix4::identity(),
104            config,
105        })
106    }
107
108    /// Create with default configuration
109    pub fn with_defaults() -> Self {
110        Self::new(EKFConfig::default())
111    }
112
113    /// Create with initial state
114    pub fn with_initial_state(initial_state: EKFState, config: EKFConfig) -> Self {
115        Self::try_with_initial_state(initial_state, config)
116            .expect("invalid EKF initialization: state must be finite and config must be valid")
117    }
118
119    /// Create with validated initial state
120    pub fn try_with_initial_state(
121        initial_state: EKFState,
122        config: EKFConfig,
123    ) -> RoboticsResult<Self> {
124        config.validate()?;
125        Self::validate_state_vector(&initial_state)?;
126        Ok(EKFLocalizer {
127            state: initial_state,
128            covariance: Matrix4::identity(),
129            config,
130        })
131    }
132
133    /// Create with common State2D type
134    pub fn with_initial_state_2d(
135        initial_state: State2D,
136        config: EKFConfig,
137    ) -> RoboticsResult<Self> {
138        Self::try_with_initial_state(initial_state.to_vector(), config)
139    }
140
141    /// Get reference to state covariance
142    pub fn get_covariance_matrix(&self) -> &Matrix4<f64> {
143        &self.covariance
144    }
145
146    /// Get current estimate as State2D
147    pub fn state_2d(&self) -> State2D {
148        State2D::new(self.state[0], self.state[1], self.state[2], self.state[3])
149    }
150
151    /// Set process noise covariance
152    pub fn set_process_noise(&mut self, q: Matrix4<f64>) {
153        self.try_set_process_noise(q)
154            .expect("invalid EKF process noise matrix")
155    }
156
157    /// Set process noise covariance with validation
158    pub fn try_set_process_noise(&mut self, q: Matrix4<f64>) -> RoboticsResult<()> {
159        if q.iter().any(|value| !value.is_finite()) {
160            return Err(RoboticsError::InvalidParameter(
161                "EKF process noise matrix must contain only finite values".to_string(),
162            ));
163        }
164        for i in 0..4 {
165            if q[(i, i)] < 0.0 {
166                return Err(RoboticsError::InvalidParameter(
167                    "EKF process noise diagonal entries must be non-negative".to_string(),
168                ));
169            }
170        }
171
172        self.config.q = q;
173        Ok(())
174    }
175
176    /// Set measurement noise covariance
177    pub fn set_measurement_noise(&mut self, r: Matrix2<f64>) {
178        self.try_set_measurement_noise(r)
179            .expect("invalid EKF measurement noise matrix")
180    }
181
182    /// Set measurement noise covariance with validation
183    pub fn try_set_measurement_noise(&mut self, r: Matrix2<f64>) -> RoboticsResult<()> {
184        if r.iter().any(|value| !value.is_finite()) {
185            return Err(RoboticsError::InvalidParameter(
186                "EKF measurement noise matrix must contain only finite values".to_string(),
187            ));
188        }
189        for i in 0..2 {
190            if r[(i, i)] < 0.0 {
191                return Err(RoboticsError::InvalidParameter(
192                    "EKF measurement noise diagonal entries must be non-negative".to_string(),
193                ));
194            }
195        }
196
197        self.config.r = r;
198        Ok(())
199    }
200
201    /// Motion model: predict state based on control input
202    fn motion_model(x: &EKFState, u: &EKFControl, dt: f64) -> EKFState {
203        let yaw = x[2];
204        EKFState::new(
205            x[0] + dt * u[0] * yaw.cos(),
206            x[1] + dt * u[0] * yaw.sin(),
207            x[2] + dt * u[1],
208            u[0],
209        )
210    }
211
212    /// Jacobian of motion model with respect to state
213    fn jacobian_f(x: &EKFState, u: &EKFControl, dt: f64) -> Matrix4<f64> {
214        let yaw = x[2];
215        let v = u[0];
216        Matrix4::new(
217            1.,
218            0.,
219            -dt * v * yaw.sin(),
220            0.,
221            0.,
222            1.,
223            dt * v * yaw.cos(),
224            0.,
225            0.,
226            0.,
227            1.,
228            0.,
229            0.,
230            0.,
231            0.,
232            0.,
233        )
234    }
235
236    /// Observation model: predict measurement from state
237    fn observation_model(x: &EKFState) -> EKFMeasurement {
238        let h = Matrix2x4::new(1., 0., 0., 0., 0., 1., 0., 0.);
239        h * x
240    }
241
242    /// Jacobian of observation model
243    fn jacobian_h() -> Matrix2x4<f64> {
244        Matrix2x4::new(1., 0., 0., 0., 0., 1., 0., 0.)
245    }
246
247    /// Full EKF estimation step (predict + update)
248    pub fn estimate(
249        &mut self,
250        measurement: &EKFMeasurement,
251        control: &EKFControl,
252        dt: f64,
253    ) -> Result<&EKFState, RoboticsError> {
254        Self::validate_measurement_vector(measurement)?;
255        Self::validate_control_vector(control)?;
256        Self::validate_dt(dt)?;
257
258        // Predict
259        let x_pred = Self::motion_model(&self.state, control, dt);
260        let j_f = Self::jacobian_f(&x_pred, control, dt);
261        let p_pred = j_f * self.covariance * j_f.transpose() + self.config.q;
262
263        // Update
264        let j_h = Self::jacobian_h();
265        let z_pred = Self::observation_model(&x_pred);
266        let y = measurement - z_pred;
267        let s = j_h * p_pred * j_h.transpose() + self.config.r;
268
269        let s_inv = s.try_inverse().ok_or_else(|| {
270            RoboticsError::NumericalError("Failed to invert S matrix".to_string())
271        })?;
272
273        let k = p_pred * j_h.transpose() * s_inv;
274        self.state = x_pred + k * y;
275        self.covariance = (Matrix4::identity() - k * j_h) * p_pred;
276
277        Ok(&self.state)
278    }
279
280    /// EKF estimate step using common crate types
281    pub fn estimate_state(
282        &mut self,
283        measurement: Point2D,
284        control: ControlInput,
285        dt: f64,
286    ) -> RoboticsResult<State2D> {
287        self.estimate(&measurement.to_vector(), &control.to_vector(), dt)?;
288        Ok(self.state_2d())
289    }
290
291    /// Legacy interface for EKF estimation (standalone function style)
292    pub fn ekf_estimation(
293        x_est: EKFState,
294        p_est: Matrix4<f64>,
295        z: EKFMeasurement,
296        u: EKFControl,
297        q: Matrix4<f64>,
298        r: Matrix2<f64>,
299        dt: f64,
300    ) -> (EKFState, Matrix4<f64>) {
301        let x_pred = Self::motion_model(&x_est, &u, dt);
302        let j_f = Self::jacobian_f(&x_pred, &u, dt);
303        let p_pred = j_f * p_est * j_f.transpose() + q;
304
305        let j_h = Self::jacobian_h();
306        let z_pred = Self::observation_model(&x_pred);
307        let y = z - z_pred;
308        let s = j_h * p_pred * j_h.transpose() + r;
309        let k = p_pred * j_h.transpose() * s.try_inverse().unwrap();
310        let new_x_est = x_pred + k * y;
311        let new_p_est = (Matrix4::identity() - k * j_h) * p_pred;
312
313        (new_x_est, new_p_est)
314    }
315
316    fn validate_state_vector(state: &EKFState) -> RoboticsResult<()> {
317        if state.iter().any(|value| !value.is_finite()) {
318            return Err(RoboticsError::InvalidParameter(
319                "EKF state must contain only finite values".to_string(),
320            ));
321        }
322
323        Ok(())
324    }
325
326    fn validate_measurement_vector(measurement: &EKFMeasurement) -> RoboticsResult<()> {
327        if measurement.iter().any(|value| !value.is_finite()) {
328            return Err(RoboticsError::InvalidParameter(
329                "EKF measurement must contain only finite values".to_string(),
330            ));
331        }
332
333        Ok(())
334    }
335
336    fn validate_control_vector(control: &EKFControl) -> RoboticsResult<()> {
337        if control.iter().any(|value| !value.is_finite()) {
338            return Err(RoboticsError::InvalidParameter(
339                "EKF control input must contain only finite values".to_string(),
340            ));
341        }
342
343        Ok(())
344    }
345
346    fn validate_dt(dt: f64) -> RoboticsResult<()> {
347        if !dt.is_finite() || dt <= 0.0 {
348            return Err(RoboticsError::InvalidParameter(format!(
349                "EKF dt must be positive and finite, got {}",
350                dt
351            )));
352        }
353
354        Ok(())
355    }
356}
357
358impl StateEstimator for EKFLocalizer {
359    type State = EKFState;
360    type Measurement = EKFMeasurement;
361    type Control = EKFControl;
362
363    fn predict(&mut self, control: &Self::Control, dt: f64) {
364        let x_pred = Self::motion_model(&self.state, control, dt);
365        let j_f = Self::jacobian_f(&x_pred, control, dt);
366        self.covariance = j_f * self.covariance * j_f.transpose() + self.config.q;
367        self.state = x_pred;
368    }
369
370    fn update(&mut self, measurement: &Self::Measurement) {
371        let j_h = Self::jacobian_h();
372        let z_pred = Self::observation_model(&self.state);
373        let y = measurement - z_pred;
374        let s = j_h * self.covariance * j_h.transpose() + self.config.r;
375
376        if let Some(s_inv) = s.try_inverse() {
377            let k = self.covariance * j_h.transpose() * s_inv;
378            self.state += k * y;
379            self.covariance = (Matrix4::identity() - k * j_h) * self.covariance;
380        }
381    }
382
383    fn get_state(&self) -> &Self::State {
384        &self.state
385    }
386
387    fn get_covariance(&self) -> Option<&nalgebra::DMatrix<f64>> {
388        // Note: We use fixed-size matrix internally, so we don't implement this
389        None
390    }
391}
392
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ekf_creation() {
        let ekf = EKFLocalizer::with_defaults();
        let state = ekf.get_state();
        assert_eq!(state[0], 0.0);
        assert_eq!(state[1], 0.0);
    }

    #[test]
    fn test_ekf_predict() {
        let mut ekf = EKFLocalizer::with_defaults();
        let control = EKFControl::new(1.0, 0.0); // move forward at 1 m/s
        ekf.predict(&control, 0.1);
        let state = ekf.get_state();
        // Should move in x direction (yaw is 0)
        assert!(state[0] > 0.0);
        assert!(state[1].abs() < 0.001);
    }

    #[test]
    fn test_ekf_update() {
        let mut ekf = EKFLocalizer::with_defaults();
        let measurement = EKFMeasurement::new(1.0, 1.0);
        ekf.update(&measurement);
        let state = ekf.get_state();
        // State should move towards measurement
        assert!(state[0] > 0.0);
        assert!(state[1] > 0.0);
    }

    #[test]
    fn test_ekf_estimate() {
        let mut ekf = EKFLocalizer::with_defaults();
        let control = EKFControl::new(1.0, 0.1);
        let measurement = EKFMeasurement::new(0.1, 0.01);

        let result = ekf.estimate(&measurement, &control, 0.1);
        assert!(result.is_ok());
    }

    #[test]
    fn test_ekf_legacy_interface() {
        let x_est = EKFState::zeros();
        let p_est = Matrix4::identity();
        let z = EKFMeasurement::new(0.1, 0.0);
        let u = EKFControl::new(1.0, 0.0);
        let q = EKFConfig::default().q;
        let r = EKFConfig::default().r;

        let (new_x, new_p) = EKFLocalizer::ekf_estimation(x_est, p_est, z, u, q, r, 0.1);

        assert!(new_x[0] > 0.0);
        assert!(new_p[(0, 0)] > 0.0);
    }

    #[test]
    fn test_ekf_try_new_rejects_invalid_config() {
        let mut config = EKFConfig::default();
        config.q[(0, 0)] = -1.0;

        let err = match EKFLocalizer::try_new(config) {
            Ok(_) => panic!("expected invalid config to be rejected"),
            Err(err) => err,
        };
        assert!(matches!(err, RoboticsError::InvalidParameter(_)));
    }

    #[test]
    fn test_ekf_with_initial_state_2d() {
        let ekf = EKFLocalizer::with_initial_state_2d(
            State2D::new(1.0, 2.0, 0.3, 0.5),
            EKFConfig::default(),
        )
        .unwrap();

        let state = ekf.state_2d();
        assert_eq!(state.x, 1.0);
        assert_eq!(state.y, 2.0);
        assert_eq!(state.yaw, 0.3);
        assert_eq!(state.v, 0.5);
    }

    #[test]
    fn test_ekf_estimate_state_with_common_types() {
        let mut ekf = EKFLocalizer::with_defaults();
        let state = ekf
            .estimate_state(Point2D::new(0.1, 0.0), ControlInput::new(1.0, 0.0), 0.1)
            .unwrap();

        assert!(state.x > 0.0);
    }

    #[test]
    fn test_ekf_velocity_tracks_control_without_accumulating() {
        let mut ekf = EKFLocalizer::with_defaults();

        for step in 1..=20 {
            let state = ekf
                .estimate_state(
                    Point2D::new(step as f64 * 0.05, 0.0),
                    ControlInput::new(0.5, 0.0),
                    0.1,
                )
                .unwrap();

            assert!((state.v - 0.5).abs() < 1e-9);
        }
    }

    #[test]
    fn test_ekf_estimate_rejects_invalid_dt() {
        let mut ekf = EKFLocalizer::with_defaults();
        let err = match ekf.estimate(&EKFMeasurement::new(0.0, 0.0), &EKFControl::zeros(), 0.0) {
            Ok(_) => panic!("expected invalid dt to be rejected"),
            Err(err) => err,
        };

        assert!(matches!(err, RoboticsError::InvalidParameter(_)));
    }

    /// `jacobian_f` must be the derivative of `motion_model`; checked with central
    /// finite differences at a point with non-zero yaw and yaw rate.
    #[test]
    fn test_ekf_jacobian_f_matches_finite_differences() {
        let x = EKFState::new(1.0, 2.0, 0.7, 0.3);
        let u = EKFControl::new(1.5, 0.4);
        let dt = 0.1;
        let eps = 1e-6;
        let j_f = EKFLocalizer::jacobian_f(&x, &u, dt);

        for col in 0..4 {
            let mut delta = EKFState::zeros();
            delta[col] = eps;
            let forward = EKFLocalizer::motion_model(&(x + delta), &u, dt);
            let backward = EKFLocalizer::motion_model(&(x - delta), &u, dt);
            let numeric = (forward - backward) / (2.0 * eps);
            for row in 0..4 {
                assert!(
                    (numeric[row] - j_f[(row, col)]).abs() < 1e-6,
                    "df[{}]/dx[{}]: numeric {} vs analytic {}",
                    row,
                    col,
                    numeric[row],
                    j_f[(row, col)]
                );
            }
        }
    }

    #[test]
    fn test_ekf_try_set_process_noise_rejects_invalid_and_keeps_previous() {
        let mut ekf = EKFLocalizer::with_defaults();
        let previous_q = ekf.config.q;

        let mut non_finite = Matrix4::<f64>::identity();
        non_finite[(1, 2)] = f64::NAN;
        assert!(matches!(
            ekf.try_set_process_noise(non_finite),
            Err(RoboticsError::InvalidParameter(_))
        ));

        let mut negative = Matrix4::<f64>::identity();
        negative[(3, 3)] = -0.5;
        assert!(matches!(
            ekf.try_set_process_noise(negative),
            Err(RoboticsError::InvalidParameter(_))
        ));

        assert_eq!(ekf.config.q, previous_q);
    }

    #[test]
    fn test_ekf_try_set_measurement_noise_rejects_invalid_and_keeps_previous() {
        let mut ekf = EKFLocalizer::with_defaults();
        let previous_r = ekf.config.r;

        let mut non_finite = Matrix2::<f64>::identity();
        non_finite[(0, 1)] = f64::INFINITY;
        assert!(matches!(
            ekf.try_set_measurement_noise(non_finite),
            Err(RoboticsError::InvalidParameter(_))
        ));

        let mut negative = Matrix2::<f64>::identity();
        negative[(0, 0)] = -1.0;
        assert!(matches!(
            ekf.try_set_measurement_noise(negative),
            Err(RoboticsError::InvalidParameter(_))
        ));

        assert_eq!(ekf.config.r, previous_r);
    }

    #[test]
    fn test_ekf_estimate_rejects_non_finite_inputs_without_mutation() {
        let mut ekf = EKFLocalizer::with_defaults();

        assert!(matches!(
            ekf.estimate(
                &EKFMeasurement::new(f64::NAN, 0.0),
                &EKFControl::zeros(),
                0.1
            ),
            Err(RoboticsError::InvalidParameter(_))
        ));
        assert!(matches!(
            ekf.estimate(
                &EKFMeasurement::zeros(),
                &EKFControl::new(f64::INFINITY, 0.0),
                0.1
            ),
            Err(RoboticsError::InvalidParameter(_))
        ));

        assert_eq!(*ekf.get_state(), EKFState::zeros());
        assert_eq!(*ekf.get_covariance_matrix(), Matrix4::<f64>::identity());
    }

    #[test]
    fn test_ekf_try_with_initial_state_rejects_non_finite_state() {
        let result = EKFLocalizer::try_with_initial_state(
            EKFState::new(0.0, f64::NAN, 0.0, 0.0),
            EKFConfig::default(),
        );

        assert!(matches!(result, Err(RoboticsError::InvalidParameter(_))));
    }
}