// spatialrust_math/mat.rs

1use crate::{Scalar, Vec3};
2
3/// 3x3 matrix stored in row-major order.
4#[derive(Clone, Copy, Debug, PartialEq)]
5#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
6pub struct Mat3<T: Scalar> {
7    /// Row-major matrix elements.
8    pub m: [[T; 3]; 3],
9}
10
11/// 4x4 matrix stored in row-major order.
12#[derive(Clone, Copy, Debug, PartialEq)]
13#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
14pub struct Mat4<T: Scalar> {
15    /// Row-major matrix elements.
16    pub m: [[T; 4]; 4],
17}
18
19impl<T: Scalar> Mat3<T> {
20    /// Creates a matrix from row vectors.
21    #[must_use]
22    pub const fn from_rows(row0: [T; 3], row1: [T; 3], row2: [T; 3]) -> Self {
23        Self { m: [row0, row1, row2] }
24    }
25}
26
27impl Mat3<f32> {
28    /// Identity matrix for `f32`.
29    #[must_use]
30    pub fn identity() -> Self {
31        Self::from_rows([1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0])
32    }
33
34    /// Transposed matrix.
35    #[must_use]
36    pub fn transpose(self) -> Self {
37        Self::from_rows(
38            [self.m[0][0], self.m[1][0], self.m[2][0]],
39            [self.m[0][1], self.m[1][1], self.m[2][1]],
40            [self.m[0][2], self.m[1][2], self.m[2][2]],
41        )
42    }
43
44    /// Matrix-vector multiplication.
45    #[must_use]
46    pub fn mul_vec3(self, v: Vec3<f32>) -> Vec3<f32> {
47        Vec3::new(
48            self.m[0][0] * v.x + self.m[0][1] * v.y + self.m[0][2] * v.z,
49            self.m[1][0] * v.x + self.m[1][1] * v.y + self.m[1][2] * v.z,
50            self.m[2][0] * v.x + self.m[2][1] * v.y + self.m[2][2] * v.z,
51        )
52    }
53
54    /// Matrix multiplication.
55    #[must_use]
56    pub fn mul_mat3(self, other: Self) -> Self {
57        Self::from_rows(
58            [
59                self.m[0][0] * other.m[0][0]
60                    + self.m[0][1] * other.m[1][0]
61                    + self.m[0][2] * other.m[2][0],
62                self.m[0][0] * other.m[0][1]
63                    + self.m[0][1] * other.m[1][1]
64                    + self.m[0][2] * other.m[2][1],
65                self.m[0][0] * other.m[0][2]
66                    + self.m[0][1] * other.m[1][2]
67                    + self.m[0][2] * other.m[2][2],
68            ],
69            [
70                self.m[1][0] * other.m[0][0]
71                    + self.m[1][1] * other.m[1][0]
72                    + self.m[1][2] * other.m[2][0],
73                self.m[1][0] * other.m[0][1]
74                    + self.m[1][1] * other.m[1][1]
75                    + self.m[1][2] * other.m[2][1],
76                self.m[1][0] * other.m[0][2]
77                    + self.m[1][1] * other.m[1][2]
78                    + self.m[1][2] * other.m[2][2],
79            ],
80            [
81                self.m[2][0] * other.m[0][0]
82                    + self.m[2][1] * other.m[1][0]
83                    + self.m[2][2] * other.m[2][0],
84                self.m[2][0] * other.m[0][1]
85                    + self.m[2][1] * other.m[1][1]
86                    + self.m[2][2] * other.m[2][1],
87                self.m[2][0] * other.m[0][2]
88                    + self.m[2][1] * other.m[1][2]
89                    + self.m[2][2] * other.m[2][2],
90            ],
91        )
92    }
93}
94
95impl Mat3<f64> {
96    /// Identity matrix for `f64`.
97    #[must_use]
98    pub fn identity() -> Self {
99        Self::from_rows([1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0])
100    }
101
102    /// Transposed matrix.
103    #[must_use]
104    pub fn transpose(self) -> Self {
105        Self::from_rows(
106            [self.m[0][0], self.m[1][0], self.m[2][0]],
107            [self.m[0][1], self.m[1][1], self.m[2][1]],
108            [self.m[0][2], self.m[1][2], self.m[2][2]],
109        )
110    }
111
112    /// Matrix-vector multiplication.
113    #[must_use]
114    pub fn mul_vec3(self, v: Vec3<f64>) -> Vec3<f64> {
115        Vec3::new(
116            self.m[0][0] * v.x + self.m[0][1] * v.y + self.m[0][2] * v.z,
117            self.m[1][0] * v.x + self.m[1][1] * v.y + self.m[1][2] * v.z,
118            self.m[2][0] * v.x + self.m[2][1] * v.y + self.m[2][2] * v.z,
119        )
120    }
121}
122
123impl<T: Scalar> Mat4<T> {
124    /// Creates a matrix from row vectors.
125    #[must_use]
126    pub const fn from_rows(row0: [T; 4], row1: [T; 4], row2: [T; 4], row3: [T; 4]) -> Self {
127        Self { m: [row0, row1, row2, row3] }
128    }
129}
130
131impl Mat4<f32> {
132    /// Identity matrix for `f32`.
133    #[must_use]
134    pub fn identity() -> Self {
135        Self::from_rows(
136            [1.0, 0.0, 0.0, 0.0],
137            [0.0, 1.0, 0.0, 0.0],
138            [0.0, 0.0, 1.0, 0.0],
139            [0.0, 0.0, 0.0, 1.0],
140        )
141    }
142
143    /// Homogeneous point transform.
144    #[must_use]
145    pub fn transform_point(self, point: Vec3<f32>) -> Vec3<f32> {
146        let x =
147            self.m[0][0] * point.x + self.m[0][1] * point.y + self.m[0][2] * point.z + self.m[0][3];
148        let y =
149            self.m[1][0] * point.x + self.m[1][1] * point.y + self.m[1][2] * point.z + self.m[1][3];
150        let z =
151            self.m[2][0] * point.x + self.m[2][1] * point.y + self.m[2][2] * point.z + self.m[2][3];
152        let w =
153            self.m[3][0] * point.x + self.m[3][1] * point.y + self.m[3][2] * point.z + self.m[3][3];
154        if w == 0.0 {
155            return Vec3::new(x, y, z);
156        }
157        Vec3::new(x / w, y / w, z / w)
158    }
159
160    /// Homogeneous vector transform (ignores translation).
161    #[must_use]
162    pub fn transform_vector(self, vector: Vec3<f32>) -> Vec3<f32> {
163        Vec3::new(
164            self.m[0][0] * vector.x + self.m[0][1] * vector.y + self.m[0][2] * vector.z,
165            self.m[1][0] * vector.x + self.m[1][1] * vector.y + self.m[1][2] * vector.z,
166            self.m[2][0] * vector.x + self.m[2][1] * vector.y + self.m[2][2] * vector.z,
167        )
168    }
169
170    /// Builds a rigid transform matrix from rotation and translation.
171    #[must_use]
172    pub fn from_rotation_translation(rotation: Mat3<f32>, translation: Vec3<f32>) -> Self {
173        Self::from_rows(
174            [rotation.m[0][0], rotation.m[0][1], rotation.m[0][2], translation.x],
175            [rotation.m[1][0], rotation.m[1][1], rotation.m[1][2], translation.y],
176            [rotation.m[2][0], rotation.m[2][1], rotation.m[2][2], translation.z],
177            [0.0, 0.0, 0.0, 1.0],
178        )
179    }
180}
181
182impl Mat4<f64> {
183    /// Identity matrix for `f64`.
184    #[must_use]
185    pub fn identity() -> Self {
186        Self::from_rows(
187            [1.0, 0.0, 0.0, 0.0],
188            [0.0, 1.0, 0.0, 0.0],
189            [0.0, 0.0, 1.0, 0.0],
190            [0.0, 0.0, 0.0, 1.0],
191        )
192    }
193
194    /// Homogeneous point transform.
195    #[must_use]
196    pub fn transform_point(self, point: Vec3<f64>) -> Vec3<f64> {
197        let x =
198            self.m[0][0] * point.x + self.m[0][1] * point.y + self.m[0][2] * point.z + self.m[0][3];
199        let y =
200            self.m[1][0] * point.x + self.m[1][1] * point.y + self.m[1][2] * point.z + self.m[1][3];
201        let z =
202            self.m[2][0] * point.x + self.m[2][1] * point.y + self.m[2][2] * point.z + self.m[2][3];
203        let w =
204            self.m[3][0] * point.x + self.m[3][1] * point.y + self.m[3][2] * point.z + self.m[3][3];
205        if w == 0.0 {
206            return Vec3::new(x, y, z);
207        }
208        Vec3::new(x / w, y / w, z / w)
209    }
210
211    /// Builds a rigid transform matrix from rotation and translation.
212    #[must_use]
213    pub fn from_rotation_translation(rotation: Mat3<f64>, translation: Vec3<f64>) -> Self {
214        Self::from_rows(
215            [rotation.m[0][0], rotation.m[0][1], rotation.m[0][2], translation.x],
216            [rotation.m[1][0], rotation.m[1][1], rotation.m[1][2], translation.y],
217            [rotation.m[2][0], rotation.m[2][1], rotation.m[2][2], translation.z],
218            [0.0, 0.0, 0.0, 1.0],
219        )
220    }
221}
222
#[cfg(test)]
mod tests {
    use super::{Mat3, Mat4, Vec3};

    #[test]
    fn mat3_mul_vec3() {
        // This matrix maps the +X axis onto -Z and leaves Y untouched.
        let rot_y: Mat3<f32> = Mat3::from_rows([0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]);
        let v = Vec3::new(1.0_f32, 0.0, 0.0);
        let out = rot_y.mul_vec3(v);
        assert!((out.x - 0.0).abs() < 1e-6);
        assert!((out.y - 0.0).abs() < 1e-6);
        assert!((out.z - (-1.0)).abs() < 1e-6);
    }

    #[test]
    fn mat3_transpose_and_mul_mat3() {
        let a: Mat3<f32> = Mat3::from_rows([1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]);
        // A permutation matrix makes operand order observable:
        // `a * p` permutes columns of `a`, `p * a` permutes its rows.
        let p: Mat3<f32> = Mat3::from_rows([0.0, 1.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]);
        assert_eq!(a.transpose().m, [[1.0, 4.0, 7.0], [2.0, 5.0, 8.0], [3.0, 6.0, 9.0]]);
        assert_eq!(a.transpose().transpose(), a);
        assert_eq!(a.mul_mat3(Mat3::<f32>::identity()), a);
        // Small integers are exact in f32, so exact comparison is safe here.
        assert_eq!(a.mul_mat3(p).m, [[3.0, 1.0, 2.0], [6.0, 4.0, 5.0], [9.0, 7.0, 8.0]]);
        assert_eq!(p.mul_mat3(a).m, [[4.0, 5.0, 6.0], [7.0, 8.0, 9.0], [1.0, 2.0, 3.0]]);
    }

    #[test]
    fn mat4_transform_point() {
        let transform = Mat4::<f32>::from_rotation_translation(
            Mat3::<f32>::identity(),
            Vec3::new(1.0, 2.0, 3.0),
        );
        let p = Vec3::new(0.0_f32, 0.0, 0.0);
        let out = transform.transform_point(p);
        assert!((out.x - 1.0).abs() < 1e-6);
        assert!((out.y - 2.0).abs() < 1e-6);
        assert!((out.z - 3.0).abs() < 1e-6);
    }

    #[test]
    fn mat4_transform_vector_ignores_translation() {
        let transform = Mat4::<f32>::from_rotation_translation(
            Mat3::<f32>::identity(),
            Vec3::new(1.0, 2.0, 3.0),
        );
        let out = transform.transform_vector(Vec3::new(1.0_f32, 0.0, 0.0));
        assert!((out.x - 1.0).abs() < 1e-6);
        assert!(out.y.abs() < 1e-6);
        assert!(out.z.abs() < 1e-6);
    }

    #[test]
    fn mat4_transform_point_divides_by_w() {
        let m = Mat4::<f32>::from_rows(
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 2.0],
        );
        let out = m.transform_point(Vec3::new(2.0_f32, 4.0, 6.0));
        assert!((out.x - 1.0).abs() < 1e-6);
        assert!((out.y - 2.0).abs() < 1e-6);
        assert!((out.z - 3.0).abs() < 1e-6);
    }

    #[test]
    fn mat4_transform_point_zero_w_skips_divide() {
        // With an all-zero bottom row, w == 0 and the raw coordinates come back.
        let m = Mat4::<f32>::from_rows(
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 1.0, 0.0],
            [0.0, 0.0, 0.0, 0.0],
        );
        let out = m.transform_point(Vec3::new(2.0_f32, 4.0, 6.0));
        assert!((out.x - 2.0).abs() < 1e-6);
        assert!((out.y - 4.0).abs() < 1e-6);
        assert!((out.z - 6.0).abs() < 1e-6);
    }

    #[test]
    fn mat4_f64_roundtrip() {
        let m = Mat4::<f64>::identity();
        let p = Vec3::new(1.0_f64, 2.0, 3.0);
        let out = m.transform_point(p);
        assert!((out.x - 1.0).abs() < 1e-12);
        assert!((out.y - 2.0).abs() < 1e-12);
        assert!((out.z - 3.0).abs() < 1e-12);
    }
}