1use alloc::format;
7use alloc::string::ToString;
8use nalgebra::{Matrix2, Matrix2x4, Matrix4, Vector2, Vector4};
9#[cfg(not(feature = "std"))]
10#[allow(unused_imports)]
11use num_traits::Float;
13use rust_robotics_core::{
14 ControlInput, Point2D, RoboticsError, RoboticsResult, State2D, StateEstimator,
15};
16
/// Filter state vector `[x, y, yaw, v]` (position, heading, forward speed).
pub type EKFState = Vector4<f64>;

/// Measurement vector: the observed position `[x, y]`.
pub type EKFMeasurement = Vector2<f64>;

/// Control input `[v, yaw_rate]`: commanded speed and heading rate.
pub type EKFControl = Vector2<f64>;
25
/// Noise parameters of the [`EKFLocalizer`].
#[derive(Debug, Clone)]
pub struct EKFConfig {
    /// Process noise covariance Q (4×4, state order `[x, y, yaw, v]`), added to
    /// the propagated covariance on every prediction step.
    pub q: Matrix4<f64>,
    /// Measurement noise covariance R (2×2, order `[x, y]`), added to the
    /// innovation covariance on every correction step.
    pub r: Matrix2<f64>,
}
34
35impl Default for EKFConfig {
36 fn default() -> Self {
37 let mut q = Matrix4::<f64>::identity();
38 q[(0, 0)] = 0.1_f64.powi(2);
39 q[(1, 1)] = 0.1_f64.powi(2);
40 q[(2, 2)] = (1.0_f64.to_radians()).powi(2);
41 q[(3, 3)] = 0.1_f64.powi(2);
42
43 let r = Matrix2::<f64>::identity();
44
45 Self { q, r }
46 }
47}
48
49impl EKFConfig {
50 pub fn validate(&self) -> RoboticsResult<()> {
51 if self.q.iter().any(|value| !value.is_finite()) {
52 return Err(RoboticsError::InvalidParameter(
53 "EKF process noise matrix must contain only finite values".to_string(),
54 ));
55 }
56 if self.r.iter().any(|value| !value.is_finite()) {
57 return Err(RoboticsError::InvalidParameter(
58 "EKF measurement noise matrix must contain only finite values".to_string(),
59 ));
60 }
61 for i in 0..4 {
62 if self.q[(i, i)] < 0.0 {
63 return Err(RoboticsError::InvalidParameter(
64 "EKF process noise diagonal entries must be non-negative".to_string(),
65 ));
66 }
67 }
68 for i in 0..2 {
69 if self.r[(i, i)] < 0.0 {
70 return Err(RoboticsError::InvalidParameter(
71 "EKF measurement noise diagonal entries must be non-negative".to_string(),
72 ));
73 }
74 }
75
76 Ok(())
77 }
78}
79
/// Extended Kalman filter that tracks `[x, y, yaw, v]` from control inputs
/// `[v, yaw_rate]` and position measurements `[x, y]`.
pub struct EKFLocalizer {
    /// Current state estimate `[x, y, yaw, v]`.
    state: EKFState,
    /// Current 4×4 estimate covariance.
    covariance: Matrix4<f64>,
    /// Noise parameters; checked by every `try_*` constructor and noise setter.
    config: EKFConfig,
}
89
90impl EKFLocalizer {
91 pub fn new(config: EKFConfig) -> Self {
93 Self::try_new(config).expect(
94 "invalid EKF configuration: noise matrices must be finite and have non-negative diagonals",
95 )
96 }
97
98 pub fn try_new(config: EKFConfig) -> RoboticsResult<Self> {
100 config.validate()?;
101 Ok(EKFLocalizer {
102 state: EKFState::zeros(),
103 covariance: Matrix4::identity(),
104 config,
105 })
106 }
107
108 pub fn with_defaults() -> Self {
110 Self::new(EKFConfig::default())
111 }
112
113 pub fn with_initial_state(initial_state: EKFState, config: EKFConfig) -> Self {
115 Self::try_with_initial_state(initial_state, config)
116 .expect("invalid EKF initialization: state must be finite and config must be valid")
117 }
118
119 pub fn try_with_initial_state(
121 initial_state: EKFState,
122 config: EKFConfig,
123 ) -> RoboticsResult<Self> {
124 config.validate()?;
125 Self::validate_state_vector(&initial_state)?;
126 Ok(EKFLocalizer {
127 state: initial_state,
128 covariance: Matrix4::identity(),
129 config,
130 })
131 }
132
133 pub fn with_initial_state_2d(
135 initial_state: State2D,
136 config: EKFConfig,
137 ) -> RoboticsResult<Self> {
138 Self::try_with_initial_state(initial_state.to_vector(), config)
139 }
140
141 pub fn get_covariance_matrix(&self) -> &Matrix4<f64> {
143 &self.covariance
144 }
145
146 pub fn state_2d(&self) -> State2D {
148 State2D::new(self.state[0], self.state[1], self.state[2], self.state[3])
149 }
150
151 pub fn set_process_noise(&mut self, q: Matrix4<f64>) {
153 self.try_set_process_noise(q)
154 .expect("invalid EKF process noise matrix")
155 }
156
157 pub fn try_set_process_noise(&mut self, q: Matrix4<f64>) -> RoboticsResult<()> {
159 if q.iter().any(|value| !value.is_finite()) {
160 return Err(RoboticsError::InvalidParameter(
161 "EKF process noise matrix must contain only finite values".to_string(),
162 ));
163 }
164 for i in 0..4 {
165 if q[(i, i)] < 0.0 {
166 return Err(RoboticsError::InvalidParameter(
167 "EKF process noise diagonal entries must be non-negative".to_string(),
168 ));
169 }
170 }
171
172 self.config.q = q;
173 Ok(())
174 }
175
176 pub fn set_measurement_noise(&mut self, r: Matrix2<f64>) {
178 self.try_set_measurement_noise(r)
179 .expect("invalid EKF measurement noise matrix")
180 }
181
182 pub fn try_set_measurement_noise(&mut self, r: Matrix2<f64>) -> RoboticsResult<()> {
184 if r.iter().any(|value| !value.is_finite()) {
185 return Err(RoboticsError::InvalidParameter(
186 "EKF measurement noise matrix must contain only finite values".to_string(),
187 ));
188 }
189 for i in 0..2 {
190 if r[(i, i)] < 0.0 {
191 return Err(RoboticsError::InvalidParameter(
192 "EKF measurement noise diagonal entries must be non-negative".to_string(),
193 ));
194 }
195 }
196
197 self.config.r = r;
198 Ok(())
199 }
200
201 fn motion_model(x: &EKFState, u: &EKFControl, dt: f64) -> EKFState {
203 let yaw = x[2];
204 EKFState::new(
205 x[0] + dt * u[0] * yaw.cos(),
206 x[1] + dt * u[0] * yaw.sin(),
207 x[2] + dt * u[1],
208 u[0],
209 )
210 }
211
212 fn jacobian_f(x: &EKFState, u: &EKFControl, dt: f64) -> Matrix4<f64> {
214 let yaw = x[2];
215 let v = u[0];
216 Matrix4::new(
217 1.,
218 0.,
219 -dt * v * yaw.sin(),
220 0.,
221 0.,
222 1.,
223 dt * v * yaw.cos(),
224 0.,
225 0.,
226 0.,
227 1.,
228 0.,
229 0.,
230 0.,
231 0.,
232 0.,
233 )
234 }
235
236 fn observation_model(x: &EKFState) -> EKFMeasurement {
238 let h = Matrix2x4::new(1., 0., 0., 0., 0., 1., 0., 0.);
239 h * x
240 }
241
242 fn jacobian_h() -> Matrix2x4<f64> {
244 Matrix2x4::new(1., 0., 0., 0., 0., 1., 0., 0.)
245 }
246
247 pub fn estimate(
249 &mut self,
250 measurement: &EKFMeasurement,
251 control: &EKFControl,
252 dt: f64,
253 ) -> Result<&EKFState, RoboticsError> {
254 Self::validate_measurement_vector(measurement)?;
255 Self::validate_control_vector(control)?;
256 Self::validate_dt(dt)?;
257
258 let x_pred = Self::motion_model(&self.state, control, dt);
260 let j_f = Self::jacobian_f(&x_pred, control, dt);
261 let p_pred = j_f * self.covariance * j_f.transpose() + self.config.q;
262
263 let j_h = Self::jacobian_h();
265 let z_pred = Self::observation_model(&x_pred);
266 let y = measurement - z_pred;
267 let s = j_h * p_pred * j_h.transpose() + self.config.r;
268
269 let s_inv = s.try_inverse().ok_or_else(|| {
270 RoboticsError::NumericalError("Failed to invert S matrix".to_string())
271 })?;
272
273 let k = p_pred * j_h.transpose() * s_inv;
274 self.state = x_pred + k * y;
275 self.covariance = (Matrix4::identity() - k * j_h) * p_pred;
276
277 Ok(&self.state)
278 }
279
280 pub fn estimate_state(
282 &mut self,
283 measurement: Point2D,
284 control: ControlInput,
285 dt: f64,
286 ) -> RoboticsResult<State2D> {
287 self.estimate(&measurement.to_vector(), &control.to_vector(), dt)?;
288 Ok(self.state_2d())
289 }
290
291 pub fn ekf_estimation(
293 x_est: EKFState,
294 p_est: Matrix4<f64>,
295 z: EKFMeasurement,
296 u: EKFControl,
297 q: Matrix4<f64>,
298 r: Matrix2<f64>,
299 dt: f64,
300 ) -> (EKFState, Matrix4<f64>) {
301 let x_pred = Self::motion_model(&x_est, &u, dt);
302 let j_f = Self::jacobian_f(&x_pred, &u, dt);
303 let p_pred = j_f * p_est * j_f.transpose() + q;
304
305 let j_h = Self::jacobian_h();
306 let z_pred = Self::observation_model(&x_pred);
307 let y = z - z_pred;
308 let s = j_h * p_pred * j_h.transpose() + r;
309 let k = p_pred * j_h.transpose() * s.try_inverse().unwrap();
310 let new_x_est = x_pred + k * y;
311 let new_p_est = (Matrix4::identity() - k * j_h) * p_pred;
312
313 (new_x_est, new_p_est)
314 }
315
316 fn validate_state_vector(state: &EKFState) -> RoboticsResult<()> {
317 if state.iter().any(|value| !value.is_finite()) {
318 return Err(RoboticsError::InvalidParameter(
319 "EKF state must contain only finite values".to_string(),
320 ));
321 }
322
323 Ok(())
324 }
325
326 fn validate_measurement_vector(measurement: &EKFMeasurement) -> RoboticsResult<()> {
327 if measurement.iter().any(|value| !value.is_finite()) {
328 return Err(RoboticsError::InvalidParameter(
329 "EKF measurement must contain only finite values".to_string(),
330 ));
331 }
332
333 Ok(())
334 }
335
336 fn validate_control_vector(control: &EKFControl) -> RoboticsResult<()> {
337 if control.iter().any(|value| !value.is_finite()) {
338 return Err(RoboticsError::InvalidParameter(
339 "EKF control input must contain only finite values".to_string(),
340 ));
341 }
342
343 Ok(())
344 }
345
346 fn validate_dt(dt: f64) -> RoboticsResult<()> {
347 if !dt.is_finite() || dt <= 0.0 {
348 return Err(RoboticsError::InvalidParameter(format!(
349 "EKF dt must be positive and finite, got {}",
350 dt
351 )));
352 }
353
354 Ok(())
355 }
356}
357
358impl StateEstimator for EKFLocalizer {
359 type State = EKFState;
360 type Measurement = EKFMeasurement;
361 type Control = EKFControl;
362
363 fn predict(&mut self, control: &Self::Control, dt: f64) {
364 let x_pred = Self::motion_model(&self.state, control, dt);
365 let j_f = Self::jacobian_f(&x_pred, control, dt);
366 self.covariance = j_f * self.covariance * j_f.transpose() + self.config.q;
367 self.state = x_pred;
368 }
369
370 fn update(&mut self, measurement: &Self::Measurement) {
371 let j_h = Self::jacobian_h();
372 let z_pred = Self::observation_model(&self.state);
373 let y = measurement - z_pred;
374 let s = j_h * self.covariance * j_h.transpose() + self.config.r;
375
376 if let Some(s_inv) = s.try_inverse() {
377 let k = self.covariance * j_h.transpose() * s_inv;
378 self.state += k * y;
379 self.covariance = (Matrix4::identity() - k * j_h) * self.covariance;
380 }
381 }
382
383 fn get_state(&self) -> &Self::State {
384 &self.state
385 }
386
387 fn get_covariance(&self) -> Option<&nalgebra::DMatrix<f64>> {
388 None
390 }
391}
392
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ekf_creation() {
        let localizer = EKFLocalizer::with_defaults();
        let x = localizer.get_state();

        assert_eq!(x[0], 0.0);
        assert_eq!(x[1], 0.0);
    }

    #[test]
    fn test_ekf_predict() {
        let mut localizer = EKFLocalizer::with_defaults();
        localizer.predict(&EKFControl::new(1.0, 0.0), 0.1);
        let x = localizer.get_state();

        // Driving straight ahead from the origin moves forward only.
        assert!(x[0] > 0.0);
        assert!(x[1].abs() < 0.001);
    }

    #[test]
    fn test_ekf_update() {
        let mut localizer = EKFLocalizer::with_defaults();
        localizer.update(&EKFMeasurement::new(1.0, 1.0));
        let x = localizer.get_state();

        // A measurement in the positive quadrant pulls the estimate toward it.
        assert!(x[0] > 0.0);
        assert!(x[1] > 0.0);
    }

    #[test]
    fn test_ekf_estimate() {
        let mut localizer = EKFLocalizer::with_defaults();
        let u = EKFControl::new(1.0, 0.1);
        let z = EKFMeasurement::new(0.1, 0.01);

        assert!(localizer.estimate(&z, &u, 0.1).is_ok());
    }

    #[test]
    fn test_ekf_legacy_interface() {
        let defaults = EKFConfig::default();

        let (x_new, p_new) = EKFLocalizer::ekf_estimation(
            EKFState::zeros(),
            Matrix4::identity(),
            EKFMeasurement::new(0.1, 0.0),
            EKFControl::new(1.0, 0.0),
            defaults.q,
            defaults.r,
            0.1,
        );

        assert!(x_new[0] > 0.0);
        assert!(p_new[(0, 0)] > 0.0);
    }

    #[test]
    fn test_ekf_try_new_rejects_invalid_config() {
        let mut bad_config = EKFConfig::default();
        bad_config.q[(0, 0)] = -1.0;

        let rejection = EKFLocalizer::try_new(bad_config)
            .err()
            .expect("expected invalid config to be rejected");

        assert!(matches!(rejection, RoboticsError::InvalidParameter(_)));
    }

    #[test]
    fn test_ekf_with_initial_state_2d() {
        let start = State2D::new(1.0, 2.0, 0.3, 0.5);
        let localizer = EKFLocalizer::with_initial_state_2d(start, EKFConfig::default()).unwrap();

        let current = localizer.state_2d();
        assert_eq!(current.x, 1.0);
        assert_eq!(current.y, 2.0);
        assert_eq!(current.yaw, 0.3);
        assert_eq!(current.v, 0.5);
    }

    #[test]
    fn test_ekf_estimate_state_with_common_types() {
        let mut localizer = EKFLocalizer::with_defaults();
        let z = Point2D::new(0.1, 0.0);
        let u = ControlInput::new(1.0, 0.0);

        let current = localizer.estimate_state(z, u, 0.1).unwrap();

        assert!(current.x > 0.0);
    }

    #[test]
    fn test_ekf_velocity_tracks_control_without_accumulating() {
        let mut localizer = EKFLocalizer::with_defaults();
        let commanded_speed = 0.5;

        for k in (1..=20).map(f64::from) {
            let current = localizer
                .estimate_state(
                    Point2D::new(k * 0.05, 0.0),
                    ControlInput::new(commanded_speed, 0.0),
                    0.1,
                )
                .unwrap();

            assert!((current.v - commanded_speed).abs() < 1e-9);
        }
    }

    #[test]
    fn test_ekf_estimate_rejects_invalid_dt() {
        let mut localizer = EKFLocalizer::with_defaults();

        let outcome =
            localizer.estimate(&EKFMeasurement::new(0.0, 0.0), &EKFControl::zeros(), 0.0);
        let rejection = outcome.err().expect("expected invalid dt to be rejected");

        assert!(matches!(rejection, RoboticsError::InvalidParameter(_)));
    }
}
516}