Skip to main content

spatialrust_filtering/
crop.rs

//! Geometric crop and field-range filters.
//!
//! These are the cheap, ubiquitous preprocessing steps: keep (or drop) points
//! inside an axis-aligned box, or keep (or drop) points whose value in some field
//! falls in a range (height slices, intensity thresholds, time windows).
6
7use spatialrust_core::{
8    HasPositions3, PointBuffer, PointBufferSet, PointCloud, SpatialError, SpatialResult,
9};
10
11use crate::filter::PointCloudFilter;
12
/// Axis-aligned bounding box consumed by [`CropBox`].
///
/// Both corners are inclusive: a point lying exactly on a face, edge, or
/// corner of the box counts as inside.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Aabb {
    /// Lower corner `(x, y, z)`; inclusive.
    pub min: [f32; 3],
    /// Upper corner `(x, y, z)`; inclusive.
    pub max: [f32; 3],
}

impl Aabb {
    /// Builds a box from its lower (`min`) and upper (`max`) corners.
    ///
    /// No validation is performed here; `min <= max` is not checked.
    #[must_use]
    pub const fn new(min: [f32; 3], max: [f32; 3]) -> Self {
        Aabb { min, max }
    }

    /// Reports whether `point` falls inside the box, boundaries included.
    ///
    /// Every axis must satisfy `min <= coordinate <= max`. A NaN coordinate
    /// fails both comparisons, so such a point is reported as outside.
    #[must_use]
    pub fn contains(&self, point: [f32; 3]) -> bool {
        point
            .iter()
            .zip(self.min.iter().zip(self.max.iter()))
            .all(|(&p, (&lo, &hi))| lo <= p && p <= hi)
    }
}
35
36/// Keeps (or, when `invert` is set, drops) points inside an axis-aligned box.
37#[derive(Clone, Copy, Debug, PartialEq)]
38pub struct CropBox {
39    bounds: Aabb,
40    invert: bool,
41}
42
43impl CropBox {
44    /// Keeps points inside `bounds`.
45    #[must_use]
46    pub const fn new(bounds: Aabb) -> Self {
47        Self { bounds, invert: false }
48    }
49
50    /// Drops points inside `bounds` (keeps everything outside).
51    #[must_use]
52    pub const fn inverted(bounds: Aabb) -> Self {
53        Self { bounds, invert: true }
54    }
55
56    /// Computes the keep mask for `input`.
57    pub fn keep_mask(&self, input: &PointCloud) -> SpatialResult<Vec<bool>> {
58        if self.bounds.min.iter().zip(self.bounds.max).any(|(lo, hi)| *lo > hi) {
59            return Err(SpatialError::InvalidArgument(
60                "crop box min must not exceed max on any axis".to_owned(),
61            ));
62        }
63        let (x, y, z) = input.positions3()?;
64        Ok((0..input.len())
65            .map(|i| {
66                let inside = self.bounds.contains([x[i], y[i], z[i]]);
67                inside ^ self.invert
68            })
69            .collect())
70    }
71}
72
73impl PointCloudFilter for CropBox {
74    fn name(&self) -> &'static str {
75        "CropBox"
76    }
77
78    fn filter(&self, input: &PointCloud) -> SpatialResult<PointCloud> {
79        let mask = self.keep_mask(input)?;
80        gather_mask(input, &mask)
81    }
82}
83
84/// Keeps (or drops) points whose value in a named field falls within a range.
85#[derive(Clone, Debug, PartialEq)]
86pub struct PassThrough {
87    field: String,
88    min: f32,
89    max: f32,
90    invert: bool,
91}
92
93impl PassThrough {
94    /// Keeps points whose `field` value is within `[min, max]` (inclusive).
95    #[must_use]
96    pub fn new(field: impl Into<String>, min: f32, max: f32) -> Self {
97        Self { field: field.into(), min, max, invert: false }
98    }
99
100    /// Drops points whose `field` value is within `[min, max]`.
101    #[must_use]
102    pub fn inverted(field: impl Into<String>, min: f32, max: f32) -> Self {
103        Self { field: field.into(), min, max, invert: true }
104    }
105
106    /// Computes the keep mask for `input`.
107    pub fn keep_mask(&self, input: &PointCloud) -> SpatialResult<Vec<bool>> {
108        if self.min > self.max {
109            return Err(SpatialError::InvalidArgument(
110                "pass-through min must not exceed max".to_owned(),
111            ));
112        }
113        let values = input.field(&self.field)?.as_f32()?;
114        Ok(values
115            .iter()
116            .map(|&v| {
117                let inside = v >= self.min && v <= self.max;
118                inside ^ self.invert
119            })
120            .collect())
121    }
122}
123
124impl PointCloudFilter for PassThrough {
125    fn name(&self) -> &'static str {
126        "PassThrough"
127    }
128
129    fn filter(&self, input: &PointCloud) -> SpatialResult<PointCloud> {
130        let mask = self.keep_mask(input)?;
131        gather_mask(input, &mask)
132    }
133}
134
135/// Builds a new cloud from the points where `mask` is true, preserving schema.
136fn gather_mask(input: &PointCloud, mask: &[bool]) -> SpatialResult<PointCloud> {
137    let indices: Vec<usize> =
138        mask.iter().enumerate().filter_map(|(i, &keep)| keep.then_some(i)).collect();
139
140    let mut buffers = PointBufferSet::new();
141    for field in input.schema().fields() {
142        let source = input.field(&field.name)?;
143        buffers.insert(field.name.clone(), gather_buffer(source, &indices));
144    }
145    PointCloud::try_from_parts(input.schema().clone(), buffers, input.metadata().clone())
146}
147
148fn gather_buffer(source: &PointBuffer, indices: &[usize]) -> PointBuffer {
149    match source {
150        PointBuffer::F32(v) => PointBuffer::from_f32(indices.iter().map(|&i| v[i]).collect()),
151        PointBuffer::F64(v) => PointBuffer::F64(indices.iter().map(|&i| v[i]).collect()),
152        PointBuffer::U8(v) => PointBuffer::U8(indices.iter().map(|&i| v[i]).collect()),
153        PointBuffer::U16(v) => PointBuffer::U16(indices.iter().map(|&i| v[i]).collect()),
154        PointBuffer::U32(v) => PointBuffer::U32(indices.iter().map(|&i| v[i]).collect()),
155        PointBuffer::I32(v) => PointBuffer::I32(indices.iter().map(|&i| v[i]).collect()),
156    }
157}
158
#[cfg(test)]
mod tests {
    use super::*;
    use spatialrust_core::{DType, FieldSemantic, PointField, PointSchema};

    /// Builds a cloud with `f32` `x`/`y`/`z` position fields from `points`.
    fn cloud_from_xyz(points: &[[f32; 3]]) -> PointCloud {
        let schema = PointSchema::new()
            .with_field(PointField::scalar("x", FieldSemantic::PositionX, DType::F32))
            .with_field(PointField::scalar("y", FieldSemantic::PositionY, DType::F32))
            .with_field(PointField::scalar("z", FieldSemantic::PositionZ, DType::F32));
        let mut buffers = PointBufferSet::new();
        buffers
            .insert("x".to_owned(), PointBuffer::from_f32(points.iter().map(|p| p[0]).collect()));
        buffers
            .insert("y".to_owned(), PointBuffer::from_f32(points.iter().map(|p| p[1]).collect()));
        buffers
            .insert("z".to_owned(), PointBuffer::from_f32(points.iter().map(|p| p[2]).collect()));
        PointCloud::try_from_parts(schema, buffers, Default::default()).unwrap()
    }

    /// 5x5 grid on the z = 0 plane with integer x, y in 0..5 (25 points).
    fn grid() -> PointCloud {
        let mut pts = Vec::new();
        for ix in 0..5 {
            for iy in 0..5 {
                pts.push([ix as f32, iy as f32, 0.0]);
            }
        }
        cloud_from_xyz(&pts)
    }

    #[test]
    fn crop_box_keeps_inside() {
        let cloud = grid();
        let filter = CropBox::new(Aabb::new([1.0, 1.0, -1.0], [3.0, 3.0, 1.0]));
        let out = filter.filter(&cloud).unwrap();
        // x,y each in {1,2,3} -> 3x3 = 9 points.
        assert_eq!(out.len(), 9);
    }

    #[test]
    fn crop_box_inverted_drops_inside() {
        let cloud = grid();
        let filter = CropBox::inverted(Aabb::new([1.0, 1.0, -1.0], [3.0, 3.0, 1.0]));
        let out = filter.filter(&cloud).unwrap();
        assert_eq!(out.len(), cloud.len() - 9);
    }

    #[test]
    fn crop_box_bounds_are_inclusive() {
        let cloud = cloud_from_xyz(&[
            [0.0, 0.0, 0.0], // exactly on the min corner
            [1.0, 1.0, 1.0], // exactly on the max corner
            [1.5, 0.5, 0.5], // outside on x only
        ]);
        let mask = CropBox::new(Aabb::new([0.0; 3], [1.0; 3])).keep_mask(&cloud).unwrap();
        assert_eq!(mask, vec![true, true, false]);
    }

    #[test]
    fn crop_box_preserves_kept_positions_in_order() {
        let cloud = cloud_from_xyz(&[[0.25, 0.5, 0.75], [5.0, 5.0, 5.0], [1.0, 0.0, 0.5]]);
        let out = CropBox::new(Aabb::new([0.0; 3], [1.0; 3])).filter(&cloud).unwrap();
        assert_eq!(out.len(), 2);
        let (x, y, z) = out.positions3().unwrap();
        assert_eq!([x[0], y[0], z[0]], [0.25, 0.5, 0.75]);
        assert_eq!([x[1], y[1], z[1]], [1.0, 0.0, 0.5]);
    }

    #[test]
    fn pass_through_on_position_field() {
        let cloud = grid();
        // Keep points with x in [2, 4] -> x in {2,3,4} -> 3 columns * 5 rows = 15.
        let filter = PassThrough::new("x", 2.0, 4.0);
        let out = filter.filter(&cloud).unwrap();
        assert_eq!(out.len(), 15);
    }

    #[test]
    fn pass_through_inverted_drops_range() {
        let cloud = grid();
        // Drop x in {2,3,4} -> x in {0,1} remain -> 2 columns * 5 rows = 10.
        let out = PassThrough::inverted("x", 2.0, 4.0).filter(&cloud).unwrap();
        assert_eq!(out.len(), 10);
    }

    #[test]
    fn pass_through_mask_matches_field_values() {
        let cloud = cloud_from_xyz(&[[0.0, 0.0, -1.0], [0.0, 0.0, 0.5], [0.0, 0.0, 1.0]]);
        let mask = PassThrough::new("z", 0.0, 1.0).keep_mask(&cloud).unwrap();
        // The upper bound is inclusive, so z = 1.0 is kept.
        assert_eq!(mask, vec![false, true, true]);
    }

    #[test]
    fn errors_on_inverted_range_and_missing_field() {
        let cloud = grid();
        assert!(CropBox::new(Aabb::new([3.0, 0.0, 0.0], [1.0, 1.0, 1.0]))
            .keep_mask(&cloud)
            .is_err());
        assert!(PassThrough::new("x", 5.0, 1.0).keep_mask(&cloud).is_err());
        assert!(PassThrough::new("intensity", 0.0, 1.0).keep_mask(&cloud).is_err());
    }
}