Skip to main content

spatialrust_segmentation/
plane.rs

1use spatialrust_core::{HasPositions3, PointCloud, SpatialError, SpatialResult};
2use spatialrust_math::{symmetric_eigen3, Mat3, Vec3};
3
4use crate::cloud::extract_mask;
5use crate::segmenter::PointCloudSegmenter;
6
7/// Plane model in Hessian form: `normal ยท p + d = 0` with unit normal.
8#[derive(Clone, Copy, Debug, PartialEq)]
9pub struct PlaneModel {
10    /// Unit-length plane normal.
11    pub normal: Vec3<f32>,
12    /// Plane offset term.
13    pub d: f32,
14}
15
16impl PlaneModel {
17    /// Returns the signed distance from a point to the plane.
18    #[must_use]
19    pub fn signed_distance(&self, point: Vec3<f32>) -> f32 {
20        self.normal.dot(point) + self.d
21    }
22
23    /// Returns the absolute distance from a point to the plane.
24    #[must_use]
25    pub fn distance(&self, point: Vec3<f32>) -> f32 {
26        self.signed_distance(point).abs()
27    }
28
29    /// Returns the absolute distance from XYZ coordinates to the plane.
30    #[must_use]
31    pub fn distance_xyz(&self, x: f32, y: f32, z: f32) -> f32 {
32        (self.normal.x * x + self.normal.y * y + self.normal.z * z + self.d).abs()
33    }
34}
35
/// Configuration for RANSAC plane segmentation.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct RansacPlaneConfig {
    /// Maximum distance from the plane for inlier classification (inclusive).
    pub distance_threshold: f32,
    /// Number of RANSAC iterations (sampled plane hypotheses) to run.
    pub max_iterations: usize,
    /// Minimum number of inliers required to accept a model.
    pub min_inliers: usize,
    /// Seed for the deterministic sampler; equal seeds give reproducible results.
    pub seed: u64,
}

impl Default for RansacPlaneConfig {
    /// Threshold `0.01`, `1_000` iterations, `3` minimum inliers, seed `42`.
    fn default() -> Self {
        Self::DEFAULT
    }
}

impl RansacPlaneConfig {
    /// Single source of truth for the default values.
    ///
    /// It is a `const` (rather than only `Default::default`) so `const fn` constructors can
    /// build on it.
    const DEFAULT: Self =
        Self { distance_threshold: 0.01, max_iterations: 1_000, min_inliers: 3, seed: 42 };

    /// Creates a config with the given distance threshold and default values for every other
    /// field.
    #[must_use]
    pub const fn with_distance_threshold(distance_threshold: f32) -> Self {
        Self { distance_threshold, ..Self::DEFAULT }
    }
}
62
63/// Result of RANSAC plane segmentation.
64#[derive(Clone, Debug, PartialEq)]
65pub struct RansacPlaneSegmentation {
66    /// Fitted plane model refined from inliers.
67    pub model: PlaneModel,
68    /// Points classified as inliers.
69    pub inliers: PointCloud,
70    /// Points classified as outliers.
71    pub outliers: PointCloud,
72    /// Number of inlier points.
73    pub inlier_count: usize,
74}
75
76/// RANSAC-based dominant plane segmenter.
77#[derive(Clone, Copy, Debug, PartialEq)]
78pub struct RansacPlaneSegmenter {
79    config: RansacPlaneConfig,
80}
81
82impl RansacPlaneSegmenter {
83    /// Creates a segmenter from config.
84    #[must_use]
85    pub const fn new(config: RansacPlaneConfig) -> Self {
86        Self { config }
87    }
88
89    /// Returns the segmenter config.
90    #[must_use]
91    pub const fn config(&self) -> RansacPlaneConfig {
92        self.config
93    }
94
95    /// Segments the dominant plane and returns inlier/outlier clouds.
96    pub fn segment(&self, input: &PointCloud) -> SpatialResult<RansacPlaneSegmentation> {
97        if input.is_empty() {
98            return Err(SpatialError::InvalidArgument(
99                "cannot segment plane from empty point cloud".to_owned(),
100            ));
101        }
102
103        let (x, y, z) = input.positions3()?;
104        let len = input.len();
105        if len < 3 {
106            return Err(SpatialError::InvalidArgument(
107                "plane segmentation requires at least three points".to_owned(),
108            ));
109        }
110
111        let mut rng = Rng::new(self.config.seed);
112        let mut best_inliers = Vec::new();
113        let mut best_model = None;
114
115        for _ in 0..self.config.max_iterations {
116            let Some(sample) = sample_indices(&mut rng, len) else {
117                continue;
118            };
119            let Some(candidate) = plane_from_indices(x, y, z, sample) else {
120                continue;
121            };
122
123            let inliers = collect_inliers(x, y, z, &candidate, self.config.distance_threshold);
124            if inliers.len() > best_inliers.len() {
125                best_inliers = inliers;
126                best_model = Some(candidate);
127            }
128        }
129
130        if best_inliers.len() < self.config.min_inliers {
131            return Err(SpatialError::InvalidArgument(format!(
132                "RANSAC found only {} inliers, minimum is {}",
133                best_inliers.len(),
134                self.config.min_inliers
135            )));
136        }
137
138        let model =
139            refine_plane_from_inliers(x, y, z, &best_inliers).or(best_model).ok_or_else(|| {
140                SpatialError::InvalidArgument("failed to refine plane model".to_owned())
141            })?;
142
143        let mut inlier_mask = vec![false; len];
144        for index in &best_inliers {
145            inlier_mask[*index] = true;
146        }
147        let mut outlier_mask = inlier_mask.clone();
148        for selected in &mut outlier_mask {
149            *selected = !*selected;
150        }
151
152        let inliers = extract_mask(input, &inlier_mask)?;
153        let outliers = extract_mask(input, &outlier_mask)?;
154
155        Ok(RansacPlaneSegmentation { inlier_count: best_inliers.len(), model, inliers, outliers })
156    }
157
158    /// Returns only the outlier cloud after removing the dominant plane.
159    pub fn extract_outliers(&self, input: &PointCloud) -> SpatialResult<PointCloud> {
160        self.segment(input).map(|result| result.outliers)
161    }
162}
163
164impl PointCloudSegmenter for RansacPlaneSegmenter {
165    fn name(&self) -> &'static str {
166        "RansacPlaneSegmenter"
167    }
168}
169
/// Minimal deterministic pseudo-random generator used for RANSAC sampling.
///
/// A 64-bit linear congruential generator (increment 1). It is not cryptographically secure; it
/// only needs to be cheap and reproducible for a given seed.
struct Rng {
    state: u64,
}

impl Rng {
    /// Creates a generator. A zero seed is mapped to 1, so seeds 0 and 1 give the same stream.
    fn new(seed: u64) -> Self {
        Self { state: seed.max(1) }
    }

    /// Advances the generator and returns the new 64-bit state.
    fn next_u64(&mut self) -> u64 {
        self.state = self.state.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
        self.state
    }

    /// Returns a value in `0..upper`.
    ///
    /// `upper` should be non-zero; for `upper == 0` this returns 0.
    fn next_usize(&mut self, upper: usize) -> usize {
        // Use the high, well-mixed 32 bits of the LCG (its low bits have short periods) and map
        // them into `0..upper` with a multiply-shift. The product is formed in `u128` because
        // `high * upper` overflows `u64` once `upper` exceeds 2^32. For larger bounds only 2^32
        // distinct values are reachable, which is still adequate for RANSAC.
        let high = self.next_u64() >> 32;
        let scaled = (u128::from(high) * upper as u128) >> 32;
        // `high < 2^32` implies `scaled < upper`, so narrowing back to usize cannot truncate.
        scaled as usize
    }
}
192
193fn sample_indices(rng: &mut Rng, len: usize) -> Option<[usize; 3]> {
194    if len < 3 {
195        return None;
196    }
197
198    let mut indices = [0usize; 3];
199    indices[0] = rng.next_usize(len);
200    indices[1] = rng.next_usize(len);
201    while indices[1] == indices[0] {
202        indices[1] = rng.next_usize(len);
203    }
204    indices[2] = rng.next_usize(len);
205    while indices[2] == indices[0] || indices[2] == indices[1] {
206        indices[2] = rng.next_usize(len);
207    }
208    Some(indices)
209}
210
211fn plane_from_indices(x: &[f32], y: &[f32], z: &[f32], indices: [usize; 3]) -> Option<PlaneModel> {
212    let points = [
213        Vec3::new(x[indices[0]], y[indices[0]], z[indices[0]]),
214        Vec3::new(x[indices[1]], y[indices[1]], z[indices[1]]),
215        Vec3::new(x[indices[2]], y[indices[2]], z[indices[2]]),
216    ];
217    plane_from_points(points[0], points[1], points[2])
218}
219
220fn plane_from_points(p0: Vec3<f32>, p1: Vec3<f32>, p2: Vec3<f32>) -> Option<PlaneModel> {
221    let v1 = p1 - p0;
222    let v2 = p2 - p0;
223    let mut normal = v1.cross(v2);
224    if normal.length_squared() < 1e-12 {
225        return None;
226    }
227    normal = normal.normalize();
228    let d = -normal.dot(p0);
229    Some(PlaneModel { normal, d })
230}
231
232fn collect_inliers(
233    x: &[f32],
234    y: &[f32],
235    z: &[f32],
236    model: &PlaneModel,
237    threshold: f32,
238) -> Vec<usize> {
239    x.iter()
240        .enumerate()
241        .filter_map(|(index, &px)| {
242            (model.distance_xyz(px, y[index], z[index]) <= threshold).then_some(index)
243        })
244        .collect()
245}
246
247fn refine_plane_from_inliers(
248    x: &[f32],
249    y: &[f32],
250    z: &[f32],
251    inliers: &[usize],
252) -> Option<PlaneModel> {
253    if inliers.len() < 3 {
254        return None;
255    }
256
257    let count = inliers.len() as f64;
258    let mut mean_x = 0.0_f64;
259    let mut mean_y = 0.0_f64;
260    let mut mean_z = 0.0_f64;
261    for &index in inliers {
262        mean_x += f64::from(x[index]);
263        mean_y += f64::from(y[index]);
264        mean_z += f64::from(z[index]);
265    }
266    mean_x /= count;
267    mean_y /= count;
268    mean_z /= count;
269
270    let mut c00 = 0.0_f64;
271    let mut c11 = 0.0_f64;
272    let mut c22 = 0.0_f64;
273    let mut c01 = 0.0_f64;
274    let mut c02 = 0.0_f64;
275    let mut c12 = 0.0_f64;
276    for &index in inliers {
277        let dx = f64::from(x[index]) - mean_x;
278        let dy = f64::from(y[index]) - mean_y;
279        let dz = f64::from(z[index]) - mean_z;
280        c00 += dx * dx;
281        c11 += dy * dy;
282        c22 += dz * dz;
283        c01 += dx * dy;
284        c02 += dx * dz;
285        c12 += dy * dz;
286    }
287    let inv = 1.0 / count;
288    let covariance = Mat3::<f64>::from_rows(
289        [c00 * inv, c01 * inv, c02 * inv],
290        [c01 * inv, c11 * inv, c12 * inv],
291        [c02 * inv, c12 * inv, c22 * inv],
292    );
293
294    let eigen = symmetric_eigen3(covariance);
295    let normal = Vec3::new(
296        eigen.eigenvectors.m[0][0] as f32,
297        eigen.eigenvectors.m[1][0] as f32,
298        eigen.eigenvectors.m[2][0] as f32,
299    )
300    .normalize();
301    let centroid = Vec3::new(mean_x as f32, mean_y as f32, mean_z as f32);
302    let d = -normal.dot(centroid);
303    Some(PlaneModel { normal, d })
304}
305
#[cfg(test)]
mod tests {
    use super::{PlaneModel, RansacPlaneConfig, RansacPlaneSegmenter};
    use spatialrust_core::{HasPositions3, PointCloudBuilder, StandardSchemas};
    use spatialrust_math::Vec3;

    /// A 10 x 10 grid on the `z = 0` plane plus two points well above it.
    fn plane_with_outliers() -> spatialrust_core::PointCloud {
        let mut builder = PointCloudBuilder::new(StandardSchemas::point_xyz());
        for x in 0..10 {
            for y in 0..10 {
                builder.push_point([x as f32, y as f32, 0.0]).unwrap();
            }
        }
        builder.push_point([0.0, 0.0, 5.0]).unwrap();
        builder.push_point([1.0, 1.0, 5.0]).unwrap();
        builder.build().unwrap()
    }

    /// Config shared by the segmentation tests: tight threshold and a fixed seed.
    fn test_config() -> RansacPlaneConfig {
        RansacPlaneConfig { distance_threshold: 0.05, max_iterations: 500, min_inliers: 50, seed: 7 }
    }

    #[test]
    fn segments_dominant_plane() {
        let input = plane_with_outliers();
        let result = RansacPlaneSegmenter::new(test_config()).segment(&input).unwrap();
        assert_eq!(result.inlier_count, 100);
        assert_eq!(result.outliers.len(), 2);
        assert!(result.model.normal.z.abs() > 0.9);
    }

    #[test]
    fn inliers_and_outliers_partition_input() {
        let input = plane_with_outliers();
        let result = RansacPlaneSegmenter::new(test_config()).segment(&input).unwrap();
        assert_eq!(result.inliers.len(), result.inlier_count);
        assert_eq!(result.inliers.len() + result.outliers.len(), input.len());
        let (_, _, z) = result.inliers.positions3().unwrap();
        assert!(z.iter().all(|value| value.abs() < 1e-6));
    }

    #[test]
    fn plane_distance_matches_point() {
        let model = PlaneModel { normal: Vec3::new(0.0, 0.0, 1.0), d: 0.0 };
        assert!((model.distance(Vec3::new(0.0, 0.0, 1.0)) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn extract_outliers_removes_plane() {
        let input = plane_with_outliers();
        let outliers = RansacPlaneSegmenter::new(test_config()).extract_outliers(&input).unwrap();
        let (_, _, z) = outliers.positions3().unwrap();
        assert!(z.iter().all(|value| *value > 1.0));
    }

    #[test]
    fn rejects_clouds_with_fewer_than_three_points() {
        let mut builder = PointCloudBuilder::new(StandardSchemas::point_xyz());
        builder.push_point([0.0, 0.0, 0.0]).unwrap();
        builder.push_point([1.0, 0.0, 0.0]).unwrap();
        let input = builder.build().unwrap();
        assert!(RansacPlaneSegmenter::new(test_config()).segment(&input).is_err());
    }

    #[test]
    fn rejects_nan_distance_threshold() {
        let input = plane_with_outliers();
        let config = RansacPlaneConfig { distance_threshold: f32::NAN, ..test_config() };
        assert!(RansacPlaneSegmenter::new(config).segment(&input).is_err());
    }

    #[test]
    fn default_config_matches_threshold_constructor() {
        assert_eq!(RansacPlaneConfig::default(), RansacPlaneConfig::with_distance_threshold(0.01));
    }
}