spatialrust_math/
covariance.rs1use crate::{Mat3, Vec3};
2
3#[derive(Clone, Debug, PartialEq)]
5pub struct CovarianceAccumulator3 {
6 count: u64,
7 sum: [f64; 3],
8 sum_sq: [f64; 6],
9}
10
11impl Default for CovarianceAccumulator3 {
12 fn default() -> Self {
13 Self::new()
14 }
15}
16
17impl CovarianceAccumulator3 {
18 #[must_use]
20 pub fn new() -> Self {
21 Self { count: 0, sum: [0.0; 3], sum_sq: [0.0; 6] }
22 }
23
24 pub fn push(&mut self, point: Vec3<f32>) {
26 self.count += 1;
27 self.sum[0] += f64::from(point.x);
28 self.sum[1] += f64::from(point.y);
29 self.sum[2] += f64::from(point.z);
30 self.sum_sq[0] += f64::from(point.x * point.x);
31 self.sum_sq[1] += f64::from(point.y * point.y);
32 self.sum_sq[2] += f64::from(point.z * point.z);
33 self.sum_sq[3] += f64::from(point.x * point.y);
34 self.sum_sq[4] += f64::from(point.x * point.z);
35 self.sum_sq[5] += f64::from(point.y * point.z);
36 }
37
38 #[must_use]
40 pub const fn count(&self) -> u64 {
41 self.count
42 }
43
44 #[must_use]
46 pub fn mean(&self) -> Option<Vec3<f64>> {
47 if self.count == 0 {
48 return None;
49 }
50 let n = self.count as f64;
51 Some(Vec3::new(self.sum[0] / n, self.sum[1] / n, self.sum[2] / n))
52 }
53
54 #[must_use]
56 pub fn covariance(&self) -> Option<Mat3<f64>> {
57 if self.count < 2 {
58 return None;
59 }
60 let n = self.count as f64;
61 let mean = self.mean()?;
62 let inv = 1.0 / (n - 1.0);
63
64 let c00 = inv * (self.sum_sq[0] - n * mean.x * mean.x);
65 let c11 = inv * (self.sum_sq[1] - n * mean.y * mean.y);
66 let c22 = inv * (self.sum_sq[2] - n * mean.z * mean.z);
67 let c01 = inv * (self.sum_sq[3] - n * mean.x * mean.y);
68 let c02 = inv * (self.sum_sq[4] - n * mean.x * mean.z);
69 let c12 = inv * (self.sum_sq[5] - n * mean.y * mean.z);
70
71 Some(Mat3::from_rows([c00, c01, c02], [c01, c11, c12], [c02, c12, c22]))
72 }
73}
74
#[cfg(test)]
mod tests {
    use super::CovarianceAccumulator3;
    use crate::{tolerance::approx_eq_f64, Vec3};

    /// Collinear points along the x axis: only the x variance is non-zero.
    #[test]
    fn covariance_of_axis_points() {
        let mut acc = CovarianceAccumulator3::new();
        acc.push(Vec3::new(0.0, 0.0, 0.0));
        acc.push(Vec3::new(1.0, 0.0, 0.0));
        acc.push(Vec3::new(2.0, 0.0, 0.0));
        let cov = acc.covariance().unwrap();
        assert!(approx_eq_f64(cov.m[0][0], 1.0, 1e-6));
        assert!(approx_eq_f64(cov.m[1][1], 0.0, 1e-6));
        assert!(approx_eq_f64(cov.m[2][2], 0.0, 1e-6));
        assert!(approx_eq_f64(cov.m[0][1], 0.0, 1e-6));
        assert!(approx_eq_f64(cov.m[0][2], 0.0, 1e-6));
        assert!(approx_eq_f64(cov.m[1][2], 0.0, 1e-6));
    }

    /// An empty accumulator has no mean; fewer than two points give no sample covariance.
    #[test]
    fn too_few_points_yield_none() {
        let mut acc = CovarianceAccumulator3::new();
        assert_eq!(acc.count(), 0);
        assert!(acc.mean().is_none());
        assert!(acc.covariance().is_none());

        acc.push(Vec3::new(1.0, 2.0, 3.0));
        assert_eq!(acc.count(), 1);
        let mean = acc.mean().unwrap();
        assert!(approx_eq_f64(mean.x, 1.0, 1e-6));
        assert!(approx_eq_f64(mean.y, 2.0, 1e-6));
        assert!(approx_eq_f64(mean.z, 3.0, 1e-6));
        assert!(acc.covariance().is_none());
    }

    /// Anti-correlated x/y with constant z exercises the off-diagonal terms and symmetry.
    #[test]
    fn covariance_off_diagonal_terms() {
        let mut acc = CovarianceAccumulator3::new();
        acc.push(Vec3::new(0.0, 2.0, 5.0));
        acc.push(Vec3::new(1.0, 1.0, 5.0));
        acc.push(Vec3::new(2.0, 0.0, 5.0));
        let cov = acc.covariance().unwrap();
        assert!(approx_eq_f64(cov.m[0][0], 1.0, 1e-6));
        assert!(approx_eq_f64(cov.m[1][1], 1.0, 1e-6));
        assert!(approx_eq_f64(cov.m[2][2], 0.0, 1e-6));
        assert!(approx_eq_f64(cov.m[0][1], -1.0, 1e-6));
        assert!(approx_eq_f64(cov.m[0][2], 0.0, 1e-6));
        assert!(approx_eq_f64(cov.m[1][2], 0.0, 1e-6));
        // The matrix must be symmetric.
        assert!(approx_eq_f64(cov.m[0][1], cov.m[1][0], 1e-12));
        assert!(approx_eq_f64(cov.m[0][2], cov.m[2][0], 1e-12));
        assert!(approx_eq_f64(cov.m[1][2], cov.m[2][1], 1e-12));
    }
}