1use spatialrust_core::{
8 HasPositions3, PointBuffer, PointBufferSet, PointCloud, SpatialError, SpatialResult,
9};
10
11use crate::filter::PointCloudFilter;
12
/// Axis-aligned bounding box in 3-D, described by its lower and upper corners.
///
/// Both corners are inclusive: a point lying exactly on a face counts as inside.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Aabb {
    /// Lower corner, `[x, y, z]`.
    pub min: [f32; 3],
    /// Upper corner, `[x, y, z]`.
    pub max: [f32; 3],
}

impl Aabb {
    /// Builds a box from its lower (`min`) and upper (`max`) corners.
    ///
    /// No validation happens here; a box with `min > max` on some axis simply
    /// contains no points.
    #[must_use]
    pub const fn new(min: [f32; 3], max: [f32; 3]) -> Self {
        Self { min, max }
    }

    /// Returns `true` when `point` lies within `[min, max]` on every axis.
    ///
    /// Any comparison involving NaN is false, so a point with a NaN coordinate
    /// is never contained.
    #[must_use]
    pub fn contains(&self, point: [f32; 3]) -> bool {
        point
            .iter()
            .zip(self.min.iter().zip(self.max.iter()))
            .all(|(coord, (lo, hi))| lo <= coord && coord <= hi)
    }
}
35
36#[derive(Clone, Copy, Debug, PartialEq)]
38pub struct CropBox {
39 bounds: Aabb,
40 invert: bool,
41}
42
43impl CropBox {
44 #[must_use]
46 pub const fn new(bounds: Aabb) -> Self {
47 Self { bounds, invert: false }
48 }
49
50 #[must_use]
52 pub const fn inverted(bounds: Aabb) -> Self {
53 Self { bounds, invert: true }
54 }
55
56 pub fn keep_mask(&self, input: &PointCloud) -> SpatialResult<Vec<bool>> {
58 if self.bounds.min.iter().zip(self.bounds.max).any(|(lo, hi)| *lo > hi) {
59 return Err(SpatialError::InvalidArgument(
60 "crop box min must not exceed max on any axis".to_owned(),
61 ));
62 }
63 let (x, y, z) = input.positions3()?;
64 Ok((0..input.len())
65 .map(|i| {
66 let inside = self.bounds.contains([x[i], y[i], z[i]]);
67 inside ^ self.invert
68 })
69 .collect())
70 }
71}
72
73impl PointCloudFilter for CropBox {
74 fn name(&self) -> &'static str {
75 "CropBox"
76 }
77
78 fn filter(&self, input: &PointCloud) -> SpatialResult<PointCloud> {
79 let mask = self.keep_mask(input)?;
80 gather_mask(input, &mask)
81 }
82}
83
84#[derive(Clone, Debug, PartialEq)]
86pub struct PassThrough {
87 field: String,
88 min: f32,
89 max: f32,
90 invert: bool,
91}
92
93impl PassThrough {
94 #[must_use]
96 pub fn new(field: impl Into<String>, min: f32, max: f32) -> Self {
97 Self { field: field.into(), min, max, invert: false }
98 }
99
100 #[must_use]
102 pub fn inverted(field: impl Into<String>, min: f32, max: f32) -> Self {
103 Self { field: field.into(), min, max, invert: true }
104 }
105
106 pub fn keep_mask(&self, input: &PointCloud) -> SpatialResult<Vec<bool>> {
108 if self.min > self.max {
109 return Err(SpatialError::InvalidArgument(
110 "pass-through min must not exceed max".to_owned(),
111 ));
112 }
113 let values = input.field(&self.field)?.as_f32()?;
114 Ok(values
115 .iter()
116 .map(|&v| {
117 let inside = v >= self.min && v <= self.max;
118 inside ^ self.invert
119 })
120 .collect())
121 }
122}
123
124impl PointCloudFilter for PassThrough {
125 fn name(&self) -> &'static str {
126 "PassThrough"
127 }
128
129 fn filter(&self, input: &PointCloud) -> SpatialResult<PointCloud> {
130 let mask = self.keep_mask(input)?;
131 gather_mask(input, &mask)
132 }
133}
134
135fn gather_mask(input: &PointCloud, mask: &[bool]) -> SpatialResult<PointCloud> {
137 let indices: Vec<usize> =
138 mask.iter().enumerate().filter_map(|(i, &keep)| keep.then_some(i)).collect();
139
140 let mut buffers = PointBufferSet::new();
141 for field in input.schema().fields() {
142 let source = input.field(&field.name)?;
143 buffers.insert(field.name.clone(), gather_buffer(source, &indices));
144 }
145 PointCloud::try_from_parts(input.schema().clone(), buffers, input.metadata().clone())
146}
147
148fn gather_buffer(source: &PointBuffer, indices: &[usize]) -> PointBuffer {
149 match source {
150 PointBuffer::F32(v) => PointBuffer::from_f32(indices.iter().map(|&i| v[i]).collect()),
151 PointBuffer::F64(v) => PointBuffer::F64(indices.iter().map(|&i| v[i]).collect()),
152 PointBuffer::U8(v) => PointBuffer::U8(indices.iter().map(|&i| v[i]).collect()),
153 PointBuffer::U16(v) => PointBuffer::U16(indices.iter().map(|&i| v[i]).collect()),
154 PointBuffer::U32(v) => PointBuffer::U32(indices.iter().map(|&i| v[i]).collect()),
155 PointBuffer::I32(v) => PointBuffer::I32(indices.iter().map(|&i| v[i]).collect()),
156 }
157}
158
#[cfg(test)]
mod tests {
    use super::*;
    use spatialrust_core::{DType, FieldSemantic, PointField, PointSchema};

    /// Builds a cloud with only `f32` x/y/z position fields from `points`.
    fn cloud_from_xyz(points: &[[f32; 3]]) -> PointCloud {
        // Extracts one coordinate axis of `points` as an f32 buffer.
        let axis = |a: usize| PointBuffer::from_f32(points.iter().map(|p| p[a]).collect());

        let schema = PointSchema::new()
            .with_field(PointField::scalar("x", FieldSemantic::PositionX, DType::F32))
            .with_field(PointField::scalar("y", FieldSemantic::PositionY, DType::F32))
            .with_field(PointField::scalar("z", FieldSemantic::PositionZ, DType::F32));

        let mut buffers = PointBufferSet::new();
        for (name, buffer) in [("x", axis(0)), ("y", axis(1)), ("z", axis(2))] {
            buffers.insert(name.to_owned(), buffer);
        }
        PointCloud::try_from_parts(schema, buffers, Default::default()).unwrap()
    }

    /// 5x5 grid of points at integer (x, y) in `0..5`, all on the z = 0 plane.
    fn grid() -> PointCloud {
        let points: Vec<[f32; 3]> = (0..5)
            .flat_map(|ix| (0..5).map(move |iy| [ix as f32, iy as f32, 0.0]))
            .collect();
        cloud_from_xyz(&points)
    }

    #[test]
    fn crop_box_keeps_inside() {
        let cloud = grid();
        let filter = CropBox::new(Aabb::new([1.0, 1.0, -1.0], [3.0, 3.0, 1.0]));
        // x and y each in {1, 2, 3}: 3 * 3 points, faces inclusive.
        let out = filter.filter(&cloud).unwrap();
        assert_eq!(out.len(), 9);
    }

    #[test]
    fn crop_box_inverted_drops_inside() {
        let cloud = grid();
        let filter = CropBox::inverted(Aabb::new([1.0, 1.0, -1.0], [3.0, 3.0, 1.0]));
        let out = filter.filter(&cloud).unwrap();
        assert_eq!(out.len(), cloud.len() - 9);
    }

    #[test]
    fn pass_through_on_position_field() {
        let cloud = grid();
        // x in {2, 3, 4} with any of the 5 y values: 3 * 5 points.
        let filter = PassThrough::new("x", 2.0, 4.0);
        let out = filter.filter(&cloud).unwrap();
        assert_eq!(out.len(), 15);
    }

    #[test]
    fn errors_on_inverted_range_and_missing_field() {
        let cloud = grid();
        let reversed_box = CropBox::new(Aabb::new([3.0, 0.0, 0.0], [1.0, 1.0, 1.0]));
        assert!(reversed_box.keep_mask(&cloud).is_err());
        assert!(PassThrough::new("x", 5.0, 1.0).keep_mask(&cloud).is_err());
        assert!(PassThrough::new("intensity", 0.0, 1.0).keep_mask(&cloud).is_err());
    }
}