// spatialrust_math/linalg.rs

/// Result of a small dense linear solve.
///
/// Despite its name, this is the result type of [`solve_linear_system`], which performs an
/// exact square solve rather than a least-squares fit; the name is kept for compatibility.
#[derive(Clone, Debug, PartialEq)]
pub enum LeastSquaresResult {
    /// Unique solution vector `x`, with one entry per unknown.
    Solved(Vec<f64>),
    /// The system is singular or ill-conditioned, or the input was rejected (empty,
    /// non-square, shape mismatch with `b`, or containing NaN/infinite entries).
    Singular,
}

/// A pivot whose magnitude is at most this fraction of the largest absolute entry of `A`
/// is treated as zero.
const RELATIVE_PIVOT_TOLERANCE: f64 = 1e-12;

/// Solves `A x = b` for square `n x n` systems using Gaussian elimination with partial pivoting.
///
/// `a` is given as `n` rows of length `n`, and `b` has length `n`. Both are consumed and
/// used as scratch space.
///
/// Singularity is detected with a tolerance relative to the largest absolute entry of `a`.
/// Rescaling the whole system therefore does not change the verdict; an absolute cutoff
/// would reject well-conditioned systems with tiny entries, such as `1e-13 * I`.
///
/// Returns [`LeastSquaresResult::Singular`] when `n == 0`, when the shapes of `a` and `b`
/// disagree, when any entry of `a` or `b` is NaN or infinite, or when a pivot falls at or
/// below the relative tolerance. Otherwise returns [`LeastSquaresResult::Solved`].
#[must_use]
pub fn solve_linear_system(mut a: Vec<Vec<f64>>, mut b: Vec<f64>) -> LeastSquaresResult {
    let n = b.len();
    if n == 0 || a.len() != n || a.iter().any(|row| row.len() != n) {
        return LeastSquaresResult::Singular;
    }
    // NaN compares false against any threshold. Without this check a NaN pivot would slip
    // past the singularity test and the caller would receive `Solved` full of NaNs.
    if !a.iter().flatten().chain(&b).all(|v| v.is_finite()) {
        return LeastSquaresResult::Singular;
    }

    let scale = a.iter().flatten().fold(0.0_f64, |m, v| m.max(v.abs()));
    let tolerance = RELATIVE_PIVOT_TOLERANCE * scale;

    for col in 0..n {
        // Partial pivoting: take the first row (from `col` down) with the largest
        // |a[row][col]|.
        let pivot_row = ((col + 1)..n).fold(col, |best, row| {
            if a[row][col].abs() > a[best][col].abs() {
                row
            } else {
                best
            }
        });
        // `<=` (not `<`) so an all-zero matrix, where tolerance == 0, is still rejected.
        if a[pivot_row][col].abs() <= tolerance {
            return LeastSquaresResult::Singular;
        }
        if pivot_row != col {
            a.swap(pivot_row, col);
            b.swap(pivot_row, col);
        }

        // Split the rows so the pivot row can be read while the rows below it are updated.
        let (upper, lower) = a.split_at_mut(col + 1);
        let pivot = &upper[col];
        let pivot_rhs = b[col];
        for (row, rhs) in lower.iter_mut().zip(&mut b[col + 1..]) {
            let factor = row[col] / pivot[col];
            for (dst, &src) in row[col..].iter_mut().zip(&pivot[col..]) {
                *dst -= factor * src;
            }
            *rhs -= factor * pivot_rhs;
        }
    }

    // Back substitution on the upper-triangular result. Subtracting left to right keeps the
    // same rounding order as an explicit accumulator loop.
    let mut x = vec![0.0; n];
    for row in (0..n).rev() {
        let sum = a[row][row + 1..]
            .iter()
            .zip(&x[row + 1..])
            .fold(b[row], |acc, (coef, xi)| acc - coef * xi);
        x[row] = sum / a[row][row];
    }

    LeastSquaresResult::Solved(x)
}
54
#[cfg(test)]
mod tests {
    use super::{solve_linear_system, LeastSquaresResult};

    /// Unwraps a `Solved` result, failing the test if the solver reported `Singular`.
    fn expect_solved(result: LeastSquaresResult) -> Vec<f64> {
        match result {
            LeastSquaresResult::Solved(x) => x,
            LeastSquaresResult::Singular => panic!("expected unique solution"),
        }
    }

    #[test]
    fn solves_3x3_system() {
        let a = vec![vec![3.0, 2.0, -1.0], vec![2.0, -2.0, 4.0], vec![-1.0, 0.5, -1.0]];
        let b = vec![1.0, -2.0, 0.0];
        let x = expect_solved(solve_linear_system(a, b));
        assert_eq!(x.len(), 3);
        for (got, want) in x.iter().zip([1.0, -2.0, -2.0]) {
            assert!((got - want).abs() < 1e-9, "got {got}, want {want}");
        }
    }

    #[test]
    fn rank_deficient_system_is_singular() {
        // The second row is exactly twice the first, so elimination yields a zero pivot.
        let a = vec![vec![1.0, 2.0], vec![2.0, 4.0]];
        assert_eq!(solve_linear_system(a, vec![1.0, 2.0]), LeastSquaresResult::Singular);
    }

    #[test]
    fn malformed_shapes_are_rejected() {
        // Empty system.
        assert_eq!(solve_linear_system(Vec::new(), Vec::new()), LeastSquaresResult::Singular);
        // Rows longer than the number of equations.
        assert_eq!(
            solve_linear_system(vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]], vec![1.0, 1.0]),
            LeastSquaresResult::Singular
        );
        // Fewer rows than right-hand-side entries.
        assert_eq!(
            solve_linear_system(vec![vec![1.0]], vec![1.0, 2.0]),
            LeastSquaresResult::Singular
        );
    }
}