// spatialrust_math/linalg.rs

/// Outcome of [`solve_linear_system`].
///
/// Despite the name, the solver that produces this value performs an exact
/// solve of a square `n × n` system; it does not compute a least-squares fit.
#[derive(Clone, Debug, PartialEq)]
pub enum LeastSquaresResult {
    /// The unique solution vector, one entry per unknown.
    Solved(Vec<f64>),
    /// No unique solution could be computed. This covers a (numerically)
    /// singular matrix as well as malformed input such as mismatched
    /// dimensions.
    Singular,
}
9
10#[must_use]
12pub fn solve_linear_system(mut a: Vec<Vec<f64>>, mut b: Vec<f64>) -> LeastSquaresResult {
13 let n = b.len();
14 if n == 0 || a.len() != n || a.iter().any(|row| row.len() != n) {
15 return LeastSquaresResult::Singular;
16 }
17
18 for col in 0..n {
19 let mut pivot_row = col;
20 for row in (col + 1)..n {
21 if a[row][col].abs() > a[pivot_row][col].abs() {
22 pivot_row = row;
23 }
24 }
25 if a[pivot_row][col].abs() < 1e-12 {
26 return LeastSquaresResult::Singular;
27 }
28 if pivot_row != col {
29 a.swap(pivot_row, col);
30 b.swap(pivot_row, col);
31 }
32
33 for row in (col + 1)..n {
34 let factor = a[row][col] / a[col][col];
35 #[allow(clippy::needless_range_loop)]
36 for k in col..n {
37 a[row][k] -= factor * a[col][k];
38 }
39 b[row] -= factor * b[col];
40 }
41 }
42
43 let mut x = vec![0.0; n];
44 for row in (0..n).rev() {
45 let mut sum = b[row];
46 for col in (row + 1)..n {
47 sum -= a[row][col] * x[col];
48 }
49 x[row] = sum / a[row][row];
50 }
51
52 LeastSquaresResult::Solved(x)
53}
54
#[cfg(test)]
mod tests {
    use super::{solve_linear_system, LeastSquaresResult};

    /// Unwraps a `Solved` result, failing the test on `Singular`.
    fn expect_solved(result: LeastSquaresResult) -> Vec<f64> {
        match result {
            LeastSquaresResult::Solved(x) => x,
            LeastSquaresResult::Singular => panic!("expected unique solution"),
        }
    }

    #[test]
    fn solves_3x3_system() {
        let a = vec![vec![3.0, 2.0, -1.0], vec![2.0, -2.0, 4.0], vec![-1.0, 0.5, -1.0]];
        let b = vec![1.0, -2.0, 0.0];
        let x = expect_solved(solve_linear_system(a, b));
        assert_eq!(x.len(), 3);
        for (got, want) in x.iter().zip([1.0, -2.0, -2.0]) {
            assert!((got - want).abs() < 1e-9, "got {got}, want {want}");
        }
    }

    #[test]
    fn pivots_past_zero_on_diagonal() {
        // The leading diagonal entry is zero, so a row swap is required.
        let a = vec![vec![0.0, 1.0], vec![1.0, 0.0]];
        let x = expect_solved(solve_linear_system(a, vec![2.0, 3.0]));
        assert_eq!(x, vec![3.0, 2.0]);
    }

    #[test]
    fn reports_singular_matrix() {
        // The second row is twice the first.
        let a = vec![vec![1.0, 2.0], vec![2.0, 4.0]];
        assert_eq!(solve_linear_system(a, vec![1.0, 2.0]), LeastSquaresResult::Singular);
    }

    #[test]
    fn rejects_malformed_dimensions() {
        assert_eq!(solve_linear_system(vec![], vec![]), LeastSquaresResult::Singular);
        assert_eq!(
            solve_linear_system(vec![vec![1.0, 2.0]], vec![1.0]),
            LeastSquaresResult::Singular
        );
        assert_eq!(
            solve_linear_system(vec![vec![1.0]], vec![1.0, 2.0]),
            LeastSquaresResult::Singular
        );
    }
}