1use crate::{Scalar, Vec3};
2
3#[derive(Clone, Copy, Debug, PartialEq)]
5#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
6pub struct Mat3<T: Scalar> {
7 pub m: [[T; 3]; 3],
9}
10
11#[derive(Clone, Copy, Debug, PartialEq)]
13#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
14pub struct Mat4<T: Scalar> {
15 pub m: [[T; 4]; 4],
17}
18
19impl<T: Scalar> Mat3<T> {
20 #[must_use]
22 pub const fn from_rows(row0: [T; 3], row1: [T; 3], row2: [T; 3]) -> Self {
23 Self { m: [row0, row1, row2] }
24 }
25}
26
27impl Mat3<f32> {
28 #[must_use]
30 pub fn identity() -> Self {
31 Self::from_rows([1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0])
32 }
33
34 #[must_use]
36 pub fn transpose(self) -> Self {
37 Self::from_rows(
38 [self.m[0][0], self.m[1][0], self.m[2][0]],
39 [self.m[0][1], self.m[1][1], self.m[2][1]],
40 [self.m[0][2], self.m[1][2], self.m[2][2]],
41 )
42 }
43
44 #[must_use]
46 pub fn mul_vec3(self, v: Vec3<f32>) -> Vec3<f32> {
47 Vec3::new(
48 self.m[0][0] * v.x + self.m[0][1] * v.y + self.m[0][2] * v.z,
49 self.m[1][0] * v.x + self.m[1][1] * v.y + self.m[1][2] * v.z,
50 self.m[2][0] * v.x + self.m[2][1] * v.y + self.m[2][2] * v.z,
51 )
52 }
53
54 #[must_use]
56 pub fn mul_mat3(self, other: Self) -> Self {
57 Self::from_rows(
58 [
59 self.m[0][0] * other.m[0][0]
60 + self.m[0][1] * other.m[1][0]
61 + self.m[0][2] * other.m[2][0],
62 self.m[0][0] * other.m[0][1]
63 + self.m[0][1] * other.m[1][1]
64 + self.m[0][2] * other.m[2][1],
65 self.m[0][0] * other.m[0][2]
66 + self.m[0][1] * other.m[1][2]
67 + self.m[0][2] * other.m[2][2],
68 ],
69 [
70 self.m[1][0] * other.m[0][0]
71 + self.m[1][1] * other.m[1][0]
72 + self.m[1][2] * other.m[2][0],
73 self.m[1][0] * other.m[0][1]
74 + self.m[1][1] * other.m[1][1]
75 + self.m[1][2] * other.m[2][1],
76 self.m[1][0] * other.m[0][2]
77 + self.m[1][1] * other.m[1][2]
78 + self.m[1][2] * other.m[2][2],
79 ],
80 [
81 self.m[2][0] * other.m[0][0]
82 + self.m[2][1] * other.m[1][0]
83 + self.m[2][2] * other.m[2][0],
84 self.m[2][0] * other.m[0][1]
85 + self.m[2][1] * other.m[1][1]
86 + self.m[2][2] * other.m[2][1],
87 self.m[2][0] * other.m[0][2]
88 + self.m[2][1] * other.m[1][2]
89 + self.m[2][2] * other.m[2][2],
90 ],
91 )
92 }
93}
94
95impl Mat3<f64> {
96 #[must_use]
98 pub fn identity() -> Self {
99 Self::from_rows([1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0])
100 }
101
102 #[must_use]
104 pub fn transpose(self) -> Self {
105 Self::from_rows(
106 [self.m[0][0], self.m[1][0], self.m[2][0]],
107 [self.m[0][1], self.m[1][1], self.m[2][1]],
108 [self.m[0][2], self.m[1][2], self.m[2][2]],
109 )
110 }
111
112 #[must_use]
114 pub fn mul_vec3(self, v: Vec3<f64>) -> Vec3<f64> {
115 Vec3::new(
116 self.m[0][0] * v.x + self.m[0][1] * v.y + self.m[0][2] * v.z,
117 self.m[1][0] * v.x + self.m[1][1] * v.y + self.m[1][2] * v.z,
118 self.m[2][0] * v.x + self.m[2][1] * v.y + self.m[2][2] * v.z,
119 )
120 }
121}
122
123impl<T: Scalar> Mat4<T> {
124 #[must_use]
126 pub const fn from_rows(row0: [T; 4], row1: [T; 4], row2: [T; 4], row3: [T; 4]) -> Self {
127 Self { m: [row0, row1, row2, row3] }
128 }
129}
130
131impl Mat4<f32> {
132 #[must_use]
134 pub fn identity() -> Self {
135 Self::from_rows(
136 [1.0, 0.0, 0.0, 0.0],
137 [0.0, 1.0, 0.0, 0.0],
138 [0.0, 0.0, 1.0, 0.0],
139 [0.0, 0.0, 0.0, 1.0],
140 )
141 }
142
143 #[must_use]
145 pub fn transform_point(self, point: Vec3<f32>) -> Vec3<f32> {
146 let x =
147 self.m[0][0] * point.x + self.m[0][1] * point.y + self.m[0][2] * point.z + self.m[0][3];
148 let y =
149 self.m[1][0] * point.x + self.m[1][1] * point.y + self.m[1][2] * point.z + self.m[1][3];
150 let z =
151 self.m[2][0] * point.x + self.m[2][1] * point.y + self.m[2][2] * point.z + self.m[2][3];
152 let w =
153 self.m[3][0] * point.x + self.m[3][1] * point.y + self.m[3][2] * point.z + self.m[3][3];
154 if w == 0.0 {
155 return Vec3::new(x, y, z);
156 }
157 Vec3::new(x / w, y / w, z / w)
158 }
159
160 #[must_use]
162 pub fn transform_vector(self, vector: Vec3<f32>) -> Vec3<f32> {
163 Vec3::new(
164 self.m[0][0] * vector.x + self.m[0][1] * vector.y + self.m[0][2] * vector.z,
165 self.m[1][0] * vector.x + self.m[1][1] * vector.y + self.m[1][2] * vector.z,
166 self.m[2][0] * vector.x + self.m[2][1] * vector.y + self.m[2][2] * vector.z,
167 )
168 }
169
170 #[must_use]
172 pub fn from_rotation_translation(rotation: Mat3<f32>, translation: Vec3<f32>) -> Self {
173 Self::from_rows(
174 [rotation.m[0][0], rotation.m[0][1], rotation.m[0][2], translation.x],
175 [rotation.m[1][0], rotation.m[1][1], rotation.m[1][2], translation.y],
176 [rotation.m[2][0], rotation.m[2][1], rotation.m[2][2], translation.z],
177 [0.0, 0.0, 0.0, 1.0],
178 )
179 }
180}
181
182impl Mat4<f64> {
183 #[must_use]
185 pub fn identity() -> Self {
186 Self::from_rows(
187 [1.0, 0.0, 0.0, 0.0],
188 [0.0, 1.0, 0.0, 0.0],
189 [0.0, 0.0, 1.0, 0.0],
190 [0.0, 0.0, 0.0, 1.0],
191 )
192 }
193
194 #[must_use]
196 pub fn transform_point(self, point: Vec3<f64>) -> Vec3<f64> {
197 let x =
198 self.m[0][0] * point.x + self.m[0][1] * point.y + self.m[0][2] * point.z + self.m[0][3];
199 let y =
200 self.m[1][0] * point.x + self.m[1][1] * point.y + self.m[1][2] * point.z + self.m[1][3];
201 let z =
202 self.m[2][0] * point.x + self.m[2][1] * point.y + self.m[2][2] * point.z + self.m[2][3];
203 let w =
204 self.m[3][0] * point.x + self.m[3][1] * point.y + self.m[3][2] * point.z + self.m[3][3];
205 if w == 0.0 {
206 return Vec3::new(x, y, z);
207 }
208 Vec3::new(x / w, y / w, z / w)
209 }
210
211 #[must_use]
213 pub fn from_rotation_translation(rotation: Mat3<f64>, translation: Vec3<f64>) -> Self {
214 Self::from_rows(
215 [rotation.m[0][0], rotation.m[0][1], rotation.m[0][2], translation.x],
216 [rotation.m[1][0], rotation.m[1][1], rotation.m[1][2], translation.y],
217 [rotation.m[2][0], rotation.m[2][1], rotation.m[2][2], translation.z],
218 [0.0, 0.0, 0.0, 1.0],
219 )
220 }
221}
222
#[cfg(test)]
mod tests {
    use super::{Mat3, Mat4, Vec3};

    /// Matrix shared by the `Mat3` tests; it maps `(1, 0, 0)` to `(0, 0, -1)`.
    fn rot_y() -> Mat3<f32> {
        Mat3::from_rows([0.0, 0.0, 1.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0])
    }

    #[test]
    fn mat3_mul_vec3() {
        let v = Vec3::new(1.0_f32, 0.0, 0.0);
        let out = rot_y().mul_vec3(v);
        // Check every component so a wrong row cannot go unnoticed.
        assert!((out.x - 0.0).abs() < 1e-6);
        assert!((out.y - 0.0).abs() < 1e-6);
        assert!((out.z - (-1.0)).abs() < 1e-6);
    }

    #[test]
    fn mat3_transpose_swaps_rows_and_columns() {
        let m: Mat3<f32> = Mat3::from_rows([1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0]);
        let t = m.transpose();
        assert_eq!(t.m, [[1.0, 4.0, 7.0], [2.0, 5.0, 8.0], [3.0, 6.0, 9.0]]);
        assert_eq!(t.transpose(), m);
    }

    #[test]
    fn mat3_mul_mat3() {
        // All entries are 0 or ±1, so the products are exact.
        let product = rot_y().mul_mat3(rot_y().transpose());
        assert_eq!(product, Mat3::<f32>::identity());
        assert_eq!(rot_y().mul_mat3(Mat3::<f32>::identity()), rot_y());
    }

    #[test]
    fn mat4_transform_point() {
        let transform = Mat4::<f32>::from_rotation_translation(
            Mat3::<f32>::identity(),
            Vec3::new(1.0, 2.0, 3.0),
        );
        let p = Vec3::new(0.0_f32, 0.0, 0.0);
        let out = transform.transform_point(p);
        assert!((out.x - 1.0).abs() < 1e-6);
        assert!((out.y - 2.0).abs() < 1e-6);
        assert!((out.z - 3.0).abs() < 1e-6);
    }

    #[test]
    fn mat4_transform_point_divides_by_w() {
        let mut transform = Mat4::<f32>::identity();
        transform.m[3][3] = 2.0;
        let out = transform.transform_point(Vec3::new(2.0, 4.0, 6.0));
        assert!((out.x - 1.0).abs() < 1e-6);
        assert!((out.y - 2.0).abs() < 1e-6);
        assert!((out.z - 3.0).abs() < 1e-6);

        // With w == 0 the components come back undivided rather than as inf/NaN.
        transform.m[3][3] = 0.0;
        let out = transform.transform_point(Vec3::new(2.0, 4.0, 6.0));
        assert!((out.x - 2.0).abs() < 1e-6);
        assert!((out.y - 4.0).abs() < 1e-6);
        assert!((out.z - 6.0).abs() < 1e-6);
    }

    #[test]
    fn mat4_transform_vector_ignores_translation() {
        let transform = Mat4::<f32>::from_rotation_translation(
            Mat3::<f32>::identity(),
            Vec3::new(1.0, 2.0, 3.0),
        );
        let out = transform.transform_vector(Vec3::new(4.0, 5.0, 6.0));
        assert!((out.x - 4.0).abs() < 1e-6);
        assert!((out.y - 5.0).abs() < 1e-6);
        assert!((out.z - 6.0).abs() < 1e-6);
    }

    #[test]
    fn mat4_f64_roundtrip() {
        let m = Mat4::<f64>::identity();
        let p = Vec3::new(1.0_f64, 2.0, 3.0);
        let out = m.transform_point(p);
        assert!((out.x - 1.0).abs() < 1e-12);
        assert!((out.y - 2.0).abs() < 1e-12);
        assert!((out.z - 3.0).abs() < 1e-12);
    }
}