1use alloc::string::ToString;
6use alloc::vec;
7use nalgebra::{DMatrix, DVector, Matrix2, Matrix4, Vector2, Vector4};
8#[cfg(not(feature = "std"))]
9#[allow(unused_imports)]
10use num_traits::Float;
12use rust_robotics_core::{RoboticsError, RoboticsResult, State2D, StateEstimator};
13
14pub type SRUKFState = Vector4<f64>;
16pub type SRUKFMeasurement = Vector2<f64>;
18pub type SRUKFControl = Vector2<f64>;
20
/// Length of the state vector `[x, y, yaw, v]`.
const STATE_DIM: usize = 4;
/// Length of the measurement vector `[x, y]`.
const MEAS_DIM: usize = 2;
/// Number of sigma points in the unscented transform (`2n + 1`).
const SIGMA_COUNT: usize = STATE_DIM * 2 + 1;
/// Small tolerance used for degeneracy checks and as Cholesky jitter.
const NUMERICAL_EPS: f64 = 1.0e-12;
25
26#[derive(Debug, Clone)]
28pub struct SRUKFConfig {
29 pub q: Matrix4<f64>,
31 pub r: Matrix2<f64>,
33 pub alpha: f64,
35 pub beta: f64,
37 pub kappa: f64,
39}
40
41impl Default for SRUKFConfig {
42 fn default() -> Self {
43 let mut q = Matrix4::<f64>::identity();
44 q[(0, 0)] = 0.1_f64.powi(2);
45 q[(1, 1)] = (1.0_f64.to_radians()).powi(2);
46 q[(2, 2)] = 0.1_f64.powi(2);
47 q[(3, 3)] = 0.1_f64.powi(2);
48
49 Self {
50 q,
51 r: Matrix2::identity(),
52 alpha: 1e-3,
53 beta: 2.0,
54 kappa: 0.0,
55 }
56 }
57}
58
59impl SRUKFConfig {
60 pub fn validate(&self) -> RoboticsResult<()> {
61 if self.q.iter().any(|value| !value.is_finite()) {
62 return Err(RoboticsError::InvalidParameter(
63 "SR-UKF process noise matrix must contain only finite values".to_string(),
64 ));
65 }
66 if self.r.iter().any(|value| !value.is_finite()) {
67 return Err(RoboticsError::InvalidParameter(
68 "SR-UKF measurement noise matrix must contain only finite values".to_string(),
69 ));
70 }
71 for i in 0..STATE_DIM {
72 if self.q[(i, i)] < 0.0 {
73 return Err(RoboticsError::InvalidParameter(
74 "SR-UKF process noise diagonal entries must be non-negative".to_string(),
75 ));
76 }
77 }
78 for i in 0..MEAS_DIM {
79 if self.r[(i, i)] < 0.0 {
80 return Err(RoboticsError::InvalidParameter(
81 "SR-UKF measurement noise diagonal entries must be non-negative".to_string(),
82 ));
83 }
84 }
85 if !self.alpha.is_finite() || self.alpha <= 0.0 {
86 return Err(RoboticsError::InvalidParameter(
87 "SR-UKF alpha must be positive and finite".to_string(),
88 ));
89 }
90 if !self.beta.is_finite() || self.beta < 0.0 {
91 return Err(RoboticsError::InvalidParameter(
92 "SR-UKF beta must be non-negative and finite".to_string(),
93 ));
94 }
95 if !self.kappa.is_finite() {
96 return Err(RoboticsError::InvalidParameter(
97 "SR-UKF kappa must be finite".to_string(),
98 ));
99 }
100
101 let lambda = self.alpha.powi(2) * (STATE_DIM as f64 + self.kappa) - STATE_DIM as f64;
102 let scale = lambda + STATE_DIM as f64;
103 if !scale.is_finite() || scale <= 0.0 {
104 return Err(RoboticsError::InvalidParameter(
105 "SR-UKF scaling parameters produce a non-positive sigma spread".to_string(),
106 ));
107 }
108
109 Ok(())
110 }
111}
112
113pub struct SRUKFLocalizer {
115 state: SRUKFState,
117 sqrt_covariance: Matrix4<f64>,
119 config: SRUKFConfig,
121 wm: DVector<f64>,
122 wc: DVector<f64>,
123 gamma: f64,
124 covariance_dyn: DMatrix<f64>,
125}
126
127impl SRUKFLocalizer {
128 pub fn new(config: SRUKFConfig) -> Self {
130 Self::try_new(config).expect(
131 "invalid SR-UKF configuration: noise matrices and sigma point parameters must be valid",
132 )
133 }
134
135 pub fn try_new(config: SRUKFConfig) -> RoboticsResult<Self> {
137 config.validate()?;
138 let (wm, wc, gamma) = Self::compute_weights(&config)?;
139 let sqrt_covariance = Matrix4::identity();
140 let covariance_dyn = DMatrix::identity(STATE_DIM, STATE_DIM);
141
142 Ok(Self {
143 state: SRUKFState::zeros(),
144 sqrt_covariance,
145 config,
146 wm,
147 wc,
148 gamma,
149 covariance_dyn,
150 })
151 }
152
153 pub fn with_initial_state(state: SRUKFState, config: SRUKFConfig) -> Self {
155 Self::try_with_initial_state(state, config)
156 .expect("invalid SR-UKF initialization: state must be finite and config must be valid")
157 }
158
159 fn try_with_initial_state(state: SRUKFState, config: SRUKFConfig) -> RoboticsResult<Self> {
160 Self::validate_state(&state)?;
161 let mut localizer = Self::try_new(config)?;
162 localizer.state = state;
163 localizer.refresh_covariance_cache();
164 Ok(localizer)
165 }
166
167 pub fn state_2d(&self) -> State2D {
169 State2D::new(self.state[0], self.state[1], self.state[2], self.state[3])
170 }
171
172 fn motion_model(x: &SRUKFState, u: &SRUKFControl, dt: f64) -> SRUKFState {
174 let yaw = x[2];
175 let f = Matrix4::new(
176 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1.,
177 );
178 let b = nalgebra::Matrix4x2::new(dt * yaw.cos(), 0., dt * yaw.sin(), 0., 0., dt, 1., 0.);
179 f * x + b * u
180 }
181
182 fn observation_model(x: &SRUKFState) -> SRUKFMeasurement {
184 let h = nalgebra::Matrix2x4::new(1., 0., 0., 0., 0., 1., 0., 0.);
185 h * x
186 }
187
188 fn try_predict(&mut self, control: &SRUKFControl, dt: f64) -> RoboticsResult<()> {
189 Self::validate_control(control)?;
190 Self::validate_dt(dt)?;
191
192 let sigma = self.generate_sigma_points();
193 let sigma_pred = core::array::from_fn(|i| Self::motion_model(&sigma[i], control, dt));
194 let x_pred = self.weighted_state_mean(&sigma_pred);
195
196 let sqrt_q = Self::cholesky_lower_4(&self.config.q)?;
197 let mut cols = DMatrix::zeros(STATE_DIM, 2 * STATE_DIM + STATE_DIM);
198 for (i, sigma_point) in sigma_pred.iter().enumerate().skip(1) {
199 let dx = *sigma_point - x_pred;
200 let scaled = dx * self.wc[i].sqrt();
201 cols.set_column(i - 1, &scaled);
202 }
203 for j in 0..STATE_DIM {
204 cols.set_column(2 * STATE_DIM + j, &sqrt_q.column(j).into_owned());
205 }
206
207 let mut sqrt_pred = Self::qr_lower_root_4(&cols)?;
208 let diff0 = sigma_pred[0] - x_pred;
209 let w0 = self.wc[0];
210 let scaled0 = diff0 * w0.abs().sqrt();
211 sqrt_pred = Self::chol_rank1_4(sqrt_pred, scaled0, w0.signum())?;
212
213 self.state = x_pred;
214 self.sqrt_covariance = sqrt_pred;
215 self.refresh_covariance_cache();
216 Ok(())
217 }
218
219 fn try_update(&mut self, measurement: &SRUKFMeasurement) -> RoboticsResult<()> {
220 Self::validate_measurement(measurement)?;
221
222 let sigma = self.generate_sigma_points();
223 let z_sigma = core::array::from_fn(|i| Self::observation_model(&sigma[i]));
224 let z_pred = self.weighted_measurement_mean(&z_sigma);
225
226 let sqrt_r = Self::cholesky_lower_2(&self.config.r)?;
227 let mut cols = DMatrix::zeros(MEAS_DIM, 2 * STATE_DIM + MEAS_DIM);
228 for (i, sigma_point) in z_sigma.iter().enumerate().skip(1) {
229 let dz = *sigma_point - z_pred;
230 let scaled = dz * self.wc[i].sqrt();
231 cols.set_column(i - 1, &scaled);
232 }
233 for j in 0..MEAS_DIM {
234 cols.set_column(2 * STATE_DIM + j, &sqrt_r.column(j).into_owned());
235 }
236
237 let mut sqrt_innovation = Self::qr_lower_root_2(&cols)?;
238 let dz0 = z_sigma[0] - z_pred;
239 let w0 = self.wc[0];
240 let scaled0 = dz0 * w0.abs().sqrt();
241 sqrt_innovation = Self::chol_rank1_2(sqrt_innovation, scaled0, w0.signum())?;
242
243 let mut pxz = DMatrix::zeros(STATE_DIM, MEAS_DIM);
244 for i in 0..SIGMA_COUNT {
245 let dx = sigma[i] - self.state;
246 let dz = z_sigma[i] - z_pred;
247 pxz += self.wc[i] * dx * dz.transpose();
248 }
249
250 let s_cov = sqrt_innovation * sqrt_innovation.transpose();
251 let s_inv = s_cov.try_inverse().ok_or_else(|| {
252 RoboticsError::NumericalError(
253 "Failed to invert SR-UKF innovation covariance".to_string(),
254 )
255 })?;
256 let k = &pxz * s_inv;
257
258 let innovation =
259 DVector::from_vec(vec![measurement[0] - z_pred[0], measurement[1] - z_pred[1]]);
260 let x_update = DVector::from_vec(vec![
261 self.state[0],
262 self.state[1],
263 self.state[2],
264 self.state[3],
265 ]) + &k * innovation;
266 self.state = SRUKFState::new(x_update[0], x_update[1], x_update[2], x_update[3]);
267
268 let u = &k * sqrt_innovation;
269 let mut sqrt_post = self.sqrt_covariance;
270 for j in 0..MEAS_DIM {
271 let downdate_col = Vector4::new(u[(0, j)], u[(1, j)], u[(2, j)], u[(3, j)]);
272 sqrt_post = Self::chol_rank1_4(sqrt_post, downdate_col, -1.0)?;
273 }
274 self.sqrt_covariance = sqrt_post;
275 self.refresh_covariance_cache();
276 Ok(())
277 }
278
279 fn compute_weights(config: &SRUKFConfig) -> RoboticsResult<(DVector<f64>, DVector<f64>, f64)> {
280 let lambda = config.alpha.powi(2) * (STATE_DIM as f64 + config.kappa) - STATE_DIM as f64;
281 let scale = lambda + STATE_DIM as f64;
282 if !scale.is_finite() || scale <= 0.0 {
283 return Err(RoboticsError::InvalidParameter(
284 "SR-UKF scaling parameters produce a non-positive sigma spread".to_string(),
285 ));
286 }
287
288 let mut wm = vec![0.0; SIGMA_COUNT];
289 let mut wc = vec![0.0; SIGMA_COUNT];
290 wm[0] = lambda / scale;
291 wc[0] = wm[0] + (1.0 - config.alpha.powi(2) + config.beta);
292 let tail_weight = 1.0 / (2.0 * scale);
293 for i in 1..SIGMA_COUNT {
294 wm[i] = tail_weight;
295 wc[i] = tail_weight;
296 }
297
298 let gamma = scale.sqrt();
299 if !gamma.is_finite() {
300 return Err(RoboticsError::InvalidParameter(
301 "SR-UKF gamma must be finite".to_string(),
302 ));
303 }
304
305 Ok((DVector::from_vec(wm), DVector::from_vec(wc), gamma))
306 }
307
308 fn generate_sigma_points(&self) -> [SRUKFState; SIGMA_COUNT] {
309 let mut sigma = [self.state; SIGMA_COUNT];
310 sigma[0] = self.state;
311 for i in 0..STATE_DIM {
312 let offset = self.sqrt_covariance.column(i).into_owned() * self.gamma;
313 sigma[i + 1] = self.state + offset;
314 sigma[i + 1 + STATE_DIM] = self.state - offset;
315 }
316 sigma
317 }
318
319 fn weighted_state_mean(&self, sigma: &[SRUKFState; SIGMA_COUNT]) -> SRUKFState {
320 let mut mean = SRUKFState::zeros();
321 for (i, point) in sigma.iter().enumerate() {
322 mean += *point * self.wm[i];
323 }
324 mean
325 }
326
327 fn weighted_measurement_mean(
328 &self,
329 sigma: &[SRUKFMeasurement; SIGMA_COUNT],
330 ) -> SRUKFMeasurement {
331 let mut mean = SRUKFMeasurement::zeros();
332 for (i, point) in sigma.iter().enumerate() {
333 mean += *point * self.wm[i];
334 }
335 mean
336 }
337
338 fn qr_lower_root_4(cols: &DMatrix<f64>) -> RoboticsResult<Matrix4<f64>> {
339 let qr = cols.transpose().qr();
340 let r = qr.r();
341 if r.nrows() < STATE_DIM || r.ncols() < STATE_DIM {
342 return Err(RoboticsError::NumericalError(
343 "QR decomposition returned unexpected dimensions for state root".to_string(),
344 ));
345 }
346
347 let mut lower = Matrix4::zeros();
348 for i in 0..STATE_DIM {
349 for j in 0..=i {
350 lower[(i, j)] = r[(j, i)];
351 }
352 }
353 Self::enforce_positive_diagonal_4(lower)
354 }
355
356 fn qr_lower_root_2(cols: &DMatrix<f64>) -> RoboticsResult<Matrix2<f64>> {
357 let qr = cols.transpose().qr();
358 let r = qr.r();
359 if r.nrows() < MEAS_DIM || r.ncols() < MEAS_DIM {
360 return Err(RoboticsError::NumericalError(
361 "QR decomposition returned unexpected dimensions for innovation root".to_string(),
362 ));
363 }
364
365 let mut lower = Matrix2::zeros();
366 for i in 0..MEAS_DIM {
367 for j in 0..=i {
368 lower[(i, j)] = r[(j, i)];
369 }
370 }
371 Self::enforce_positive_diagonal_2(lower)
372 }
373
374 fn enforce_positive_diagonal_4(mut lower: Matrix4<f64>) -> RoboticsResult<Matrix4<f64>> {
375 for i in 0..STATE_DIM {
376 if !lower[(i, i)].is_finite() || lower[(i, i)].abs() <= NUMERICAL_EPS {
377 return Err(RoboticsError::NumericalError(
378 "State Cholesky factor has a near-zero diagonal".to_string(),
379 ));
380 }
381 if lower[(i, i)] < 0.0 {
382 for row in i..STATE_DIM {
383 lower[(row, i)] = -lower[(row, i)];
384 }
385 }
386 }
387 Ok(lower)
388 }
389
390 fn enforce_positive_diagonal_2(mut lower: Matrix2<f64>) -> RoboticsResult<Matrix2<f64>> {
391 for i in 0..MEAS_DIM {
392 if !lower[(i, i)].is_finite() || lower[(i, i)].abs() <= NUMERICAL_EPS {
393 return Err(RoboticsError::NumericalError(
394 "Innovation Cholesky factor has a near-zero diagonal".to_string(),
395 ));
396 }
397 if lower[(i, i)] < 0.0 {
398 for row in i..MEAS_DIM {
399 lower[(row, i)] = -lower[(row, i)];
400 }
401 }
402 }
403 Ok(lower)
404 }
405
406 fn cholesky_lower_4(cov: &Matrix4<f64>) -> RoboticsResult<Matrix4<f64>> {
407 if let Some(chol) = cov.cholesky() {
408 return Ok(chol.l());
409 }
410 let jittered = *cov + Matrix4::identity() * NUMERICAL_EPS;
411 jittered.cholesky().map(|chol| chol.l()).ok_or_else(|| {
412 RoboticsError::NumericalError(
413 "Failed to compute Cholesky factor for SR-UKF process covariance".to_string(),
414 )
415 })
416 }
417
418 fn cholesky_lower_2(cov: &Matrix2<f64>) -> RoboticsResult<Matrix2<f64>> {
419 if let Some(chol) = cov.cholesky() {
420 return Ok(chol.l());
421 }
422 let jittered = *cov + Matrix2::identity() * NUMERICAL_EPS;
423 jittered.cholesky().map(|chol| chol.l()).ok_or_else(|| {
424 RoboticsError::NumericalError(
425 "Failed to compute Cholesky factor for SR-UKF measurement covariance".to_string(),
426 )
427 })
428 }
429
430 fn chol_rank1_4(
431 mut lower: Matrix4<f64>,
432 mut vector: Vector4<f64>,
433 sign: f64,
434 ) -> RoboticsResult<Matrix4<f64>> {
435 if sign != 1.0 && sign != -1.0 {
436 return Err(RoboticsError::InvalidParameter(
437 "Cholesky rank-1 update sign must be +1 or -1".to_string(),
438 ));
439 }
440 for k in 0..STATE_DIM {
441 let lkk = lower[(k, k)];
442 let xk = vector[k];
443 let radicand = lkk * lkk + sign * xk * xk;
444 if !radicand.is_finite() || radicand <= NUMERICAL_EPS {
445 return Err(RoboticsError::NumericalError(
446 "Cholesky rank-1 update/downdate became non-positive definite".to_string(),
447 ));
448 }
449 let r = radicand.sqrt();
450 let c = r / lkk;
451 let s = xk / lkk;
452 lower[(k, k)] = r;
453
454 for j in (k + 1)..STATE_DIM {
455 let updated = (lower[(j, k)] + sign * s * vector[j]) / c;
456 lower[(j, k)] = updated;
457 vector[j] = c * vector[j] - s * updated;
458 }
459 }
460 Ok(lower)
461 }
462
463 fn chol_rank1_2(
464 mut lower: Matrix2<f64>,
465 mut vector: Vector2<f64>,
466 sign: f64,
467 ) -> RoboticsResult<Matrix2<f64>> {
468 if sign != 1.0 && sign != -1.0 {
469 return Err(RoboticsError::InvalidParameter(
470 "Cholesky rank-1 update sign must be +1 or -1".to_string(),
471 ));
472 }
473 for k in 0..MEAS_DIM {
474 let lkk = lower[(k, k)];
475 let xk = vector[k];
476 let radicand = lkk * lkk + sign * xk * xk;
477 if !radicand.is_finite() || radicand <= NUMERICAL_EPS {
478 return Err(RoboticsError::NumericalError(
479 "Innovation Cholesky rank-1 update/downdate became non-positive definite"
480 .to_string(),
481 ));
482 }
483 let r = radicand.sqrt();
484 let c = r / lkk;
485 let s = xk / lkk;
486 lower[(k, k)] = r;
487
488 for j in (k + 1)..MEAS_DIM {
489 let updated = (lower[(j, k)] + sign * s * vector[j]) / c;
490 lower[(j, k)] = updated;
491 vector[j] = c * vector[j] - s * updated;
492 }
493 }
494 Ok(lower)
495 }
496
497 fn refresh_covariance_cache(&mut self) {
498 let covariance = self.sqrt_covariance * self.sqrt_covariance.transpose();
499 self.covariance_dyn = DMatrix::from_fn(STATE_DIM, STATE_DIM, |i, j| covariance[(i, j)]);
500 }
501
502 fn validate_state(state: &SRUKFState) -> RoboticsResult<()> {
503 if state.iter().any(|value| !value.is_finite()) {
504 return Err(RoboticsError::InvalidParameter(
505 "SR-UKF state must contain only finite values".to_string(),
506 ));
507 }
508 Ok(())
509 }
510
511 fn validate_control(control: &SRUKFControl) -> RoboticsResult<()> {
512 if control.iter().any(|value| !value.is_finite()) {
513 return Err(RoboticsError::InvalidParameter(
514 "SR-UKF control input must contain only finite values".to_string(),
515 ));
516 }
517 Ok(())
518 }
519
520 fn validate_measurement(measurement: &SRUKFMeasurement) -> RoboticsResult<()> {
521 if measurement.iter().any(|value| !value.is_finite()) {
522 return Err(RoboticsError::InvalidParameter(
523 "SR-UKF measurement must contain only finite values".to_string(),
524 ));
525 }
526 Ok(())
527 }
528
529 fn validate_dt(dt: f64) -> RoboticsResult<()> {
530 if !dt.is_finite() || dt <= 0.0 {
531 return Err(RoboticsError::InvalidParameter(
532 "SR-UKF dt must be positive and finite".to_string(),
533 ));
534 }
535 Ok(())
536 }
537}
538
539impl StateEstimator for SRUKFLocalizer {
540 type State = SRUKFState;
541 type Measurement = SRUKFMeasurement;
542 type Control = SRUKFControl;
543
544 fn predict(&mut self, control: &Self::Control, dt: f64) {
545 let _ = self.try_predict(control, dt);
546 }
547
548 fn update(&mut self, measurement: &Self::Measurement) {
549 let _ = self.try_update(measurement);
550 }
551
552 fn get_state(&self) -> &Self::State {
553 &self.state
554 }
555
556 fn get_covariance(&self) -> Option<&DMatrix<f64>> {
557 Some(&self.covariance_dyn)
558 }
559}
560
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a filter with the default tuning that starts from `state`.
    fn localizer_from(state: SRUKFState) -> SRUKFLocalizer {
        SRUKFLocalizer::with_initial_state(state, SRUKFConfig::default())
    }

    #[test]
    fn test_square_root_ukf_creation() {
        let filter = SRUKFLocalizer::new(SRUKFConfig::default());
        assert_eq!(filter.get_state(), &SRUKFState::zeros());
        assert!(filter.sqrt_covariance[(0, 0)] > 0.0);
    }

    #[test]
    fn test_square_root_ukf_predict_moves_state() {
        let mut filter = SRUKFLocalizer::new(SRUKFConfig::default());
        let forward = SRUKFControl::new(1.0, 0.0);
        filter.predict(&forward, 0.1);
        assert!(filter.get_state()[0] > 0.0);
    }

    #[test]
    fn test_square_root_ukf_update_toward_measurement() {
        let mut filter = localizer_from(SRUKFState::new(5.0, 5.0, 0.0, 1.0));
        filter.update(&SRUKFMeasurement::new(0.0, 0.0));
        let estimate = filter.get_state();
        assert!(estimate[0] < 5.0);
        assert!(estimate[1] < 5.0);
    }

    #[test]
    fn test_square_root_ukf_state_2d_helper() {
        let pose = localizer_from(SRUKFState::new(1.0, 2.0, 0.3, 0.4)).state_2d();
        assert_eq!((pose.x, pose.y, pose.yaw, pose.v), (1.0, 2.0, 0.3, 0.4));
    }

    #[test]
    fn test_square_root_ukf_try_new_rejects_invalid() {
        let config = SRUKFConfig {
            alpha: 0.0,
            ..SRUKFConfig::default()
        };
        let outcome = SRUKFLocalizer::try_new(config);
        assert!(
            matches!(outcome, Err(RoboticsError::InvalidParameter(_))),
            "expected invalid configuration to fail"
        );
    }
}