Skip to main content

spatialrust_segmentation/
primitives.rs

1//! RANSAC fitting of sphere and cylinder primitives.
2//!
3//! These complement the plane segmenter for detecting man-made shapes — pipes,
4//! tanks, poles, balls. Sphere fitting needs only positions; cylinder fitting
5//! also needs per-point normals (the axis is recovered from two surface
6//! normals), so the input cloud must carry normals.
7
8use spatialrust_core::{HasNormals3, HasPositions3, PointCloud, SpatialError, SpatialResult};
9use spatialrust_math::{solve_linear_system, LeastSquaresResult, Vec3};
10
11use crate::cloud::extract_mask;
12use crate::segmenter::PointCloudSegmenter;
13
/// Shared RANSAC controls for sphere and cylinder fitting.
///
/// Every hypothesis is first checked against the radius bounds and only then
/// scored by counting inliers, so the bounds double as a cheap filter for
/// degenerate samples.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct RansacPrimitiveConfig {
    /// Maximum absolute distance from a point to the model surface for the
    /// point to count as an inlier (same units as the point positions).
    pub distance_threshold: f32,
    /// Number of random hypotheses to draw and evaluate.
    pub max_iterations: usize,
    /// Minimum number of inliers the best model must reach; otherwise the
    /// segmentation fails with an error.
    pub min_inliers: usize,
    /// Smallest acceptable radius; hypotheses with a smaller radius are
    /// skipped. Use it to reject tiny spurious fits. Note that near-flat
    /// (nearly coplanar) samples produce *large* radii, so those are filtered
    /// by `max_radius`, not by this bound.
    pub min_radius: f32,
    /// Largest acceptable radius; hypotheses with a larger radius are skipped.
    /// This is the bound that rejects near-flat fits. Unbounded by default.
    pub max_radius: f32,
    /// Seed for the deterministic sampler: the same seed on the same input
    /// reproduces the same result. A seed of `0` is treated as `1`.
    pub seed: u64,
}

impl Default for RansacPrimitiveConfig {
    /// 2 cm inlier band, 1000 iterations, at least 10 inliers, no radius
    /// limits, seed 42.
    fn default() -> Self {
        Self {
            distance_threshold: 0.02,
            max_iterations: 1_000,
            min_inliers: 10,
            min_radius: 0.0,
            max_radius: f32::INFINITY,
            seed: 42,
        }
    }
}
43
44/// Sphere model: all surface points are `radius` from `center`.
45#[derive(Clone, Copy, Debug, PartialEq)]
46pub struct SphereModel {
47    /// Sphere center.
48    pub center: Vec3<f32>,
49    /// Sphere radius.
50    pub radius: f32,
51}
52
53impl SphereModel {
54    /// Absolute distance from `point` to the sphere surface.
55    #[must_use]
56    pub fn distance(&self, point: Vec3<f32>) -> f32 {
57        ((point - self.center).length() - self.radius).abs()
58    }
59}
60
61/// Cylinder model: all surface points are `radius` from the axis line.
62#[derive(Clone, Copy, Debug, PartialEq)]
63pub struct CylinderModel {
64    /// A point lying on the cylinder axis.
65    pub axis_point: Vec3<f32>,
66    /// Unit-length axis direction.
67    pub axis_direction: Vec3<f32>,
68    /// Cylinder radius.
69    pub radius: f32,
70}
71
72impl CylinderModel {
73    /// Perpendicular distance from `point` to the axis line.
74    #[must_use]
75    pub fn axis_distance(&self, point: Vec3<f32>) -> f32 {
76        let v = point - self.axis_point;
77        let along = scale(self.axis_direction, v.dot(self.axis_direction));
78        (v - along).length()
79    }
80
81    /// Absolute distance from `point` to the cylinder surface.
82    #[must_use]
83    pub fn distance(&self, point: Vec3<f32>) -> f32 {
84        (self.axis_distance(point) - self.radius).abs()
85    }
86}
87
88/// Result of fitting a primitive, partitioning the cloud into inliers/outliers.
89#[derive(Clone, Debug, PartialEq)]
90pub struct PrimitiveSegmentation<M> {
91    /// Fitted model.
92    pub model: M,
93    /// Points classified as inliers.
94    pub inliers: PointCloud,
95    /// Points classified as outliers.
96    pub outliers: PointCloud,
97    /// Number of inlier points.
98    pub inlier_count: usize,
99}
100
101/// RANSAC sphere segmenter.
102#[derive(Clone, Copy, Debug, PartialEq)]
103pub struct RansacSphereSegmenter {
104    config: RansacPrimitiveConfig,
105}
106
107impl RansacSphereSegmenter {
108    /// Creates a segmenter from config.
109    #[must_use]
110    pub const fn new(config: RansacPrimitiveConfig) -> Self {
111        Self { config }
112    }
113
114    /// Returns the segmenter config.
115    #[must_use]
116    pub const fn config(&self) -> RansacPrimitiveConfig {
117        self.config
118    }
119
120    /// Fits the dominant sphere and partitions the cloud.
121    pub fn segment(&self, input: &PointCloud) -> SpatialResult<PrimitiveSegmentation<SphereModel>> {
122        let (x, y, z) = input.positions3()?;
123        let len = input.len();
124        if len < 4 {
125            return Err(SpatialError::InvalidArgument(
126                "sphere fitting requires at least four points".to_owned(),
127            ));
128        }
129
130        let mut rng = Rng::new(self.config.seed);
131        let mut best_inliers: Vec<usize> = Vec::new();
132        let mut best_model = None;
133
134        for _ in 0..self.config.max_iterations {
135            let Some(sample) = sample_distinct::<4>(&mut rng, len) else {
136                continue;
137            };
138            let Some(model) = sphere_from_points(x, y, z, sample) else {
139                continue;
140            };
141            if model.radius < self.config.min_radius || model.radius > self.config.max_radius {
142                continue;
143            }
144            let inliers = collect_inliers(len, self.config.distance_threshold, |i| {
145                model.distance(Vec3::new(x[i], y[i], z[i]))
146            });
147            if inliers.len() > best_inliers.len() {
148                best_inliers = inliers;
149                best_model = Some(model);
150            }
151        }
152
153        finalize(input, best_model, &best_inliers, self.config.min_inliers)
154    }
155}
156
157impl PointCloudSegmenter for RansacSphereSegmenter {
158    fn name(&self) -> &'static str {
159        "RansacSphereSegmenter"
160    }
161}
162
163/// RANSAC cylinder segmenter. The input cloud must carry normals.
164#[derive(Clone, Copy, Debug, PartialEq)]
165pub struct RansacCylinderSegmenter {
166    config: RansacPrimitiveConfig,
167}
168
169impl RansacCylinderSegmenter {
170    /// Creates a segmenter from config.
171    #[must_use]
172    pub const fn new(config: RansacPrimitiveConfig) -> Self {
173        Self { config }
174    }
175
176    /// Returns the segmenter config.
177    #[must_use]
178    pub const fn config(&self) -> RansacPrimitiveConfig {
179        self.config
180    }
181
182    /// Fits the dominant cylinder and partitions the cloud.
183    pub fn segment(
184        &self,
185        input: &PointCloud,
186    ) -> SpatialResult<PrimitiveSegmentation<CylinderModel>> {
187        let (x, y, z) = input.positions3()?;
188        let (nx, ny, nz) = input.normals3()?;
189        let len = input.len();
190        if len < 2 {
191            return Err(SpatialError::InvalidArgument(
192                "cylinder fitting requires at least two points".to_owned(),
193            ));
194        }
195
196        let mut rng = Rng::new(self.config.seed);
197        let mut best_inliers: Vec<usize> = Vec::new();
198        let mut best_model = None;
199
200        for _ in 0..self.config.max_iterations {
201            let Some(sample) = sample_distinct::<2>(&mut rng, len) else {
202                continue;
203            };
204            let Some(model) = cylinder_from_points(x, y, z, nx, ny, nz, sample) else {
205                continue;
206            };
207            if model.radius < self.config.min_radius || model.radius > self.config.max_radius {
208                continue;
209            }
210            let inliers = collect_inliers(len, self.config.distance_threshold, |i| {
211                model.distance(Vec3::new(x[i], y[i], z[i]))
212            });
213            if inliers.len() > best_inliers.len() {
214                best_inliers = inliers;
215                best_model = Some(model);
216            }
217        }
218
219        finalize(input, best_model, &best_inliers, self.config.min_inliers)
220    }
221}
222
223impl PointCloudSegmenter for RansacCylinderSegmenter {
224    fn name(&self) -> &'static str {
225        "RansacCylinderSegmenter"
226    }
227}
228
229/// Builds the inlier/outlier partition once the best model is known.
230fn finalize<M>(
231    input: &PointCloud,
232    best_model: Option<M>,
233    best_inliers: &[usize],
234    min_inliers: usize,
235) -> SpatialResult<PrimitiveSegmentation<M>> {
236    if best_inliers.len() < min_inliers || best_model.is_none() {
237        return Err(SpatialError::InvalidArgument(format!(
238            "RANSAC found only {} inliers, minimum is {min_inliers}",
239            best_inliers.len()
240        )));
241    }
242    let model = best_model.expect("checked above");
243
244    let mut inlier_mask = vec![false; input.len()];
245    for &index in best_inliers {
246        inlier_mask[index] = true;
247    }
248    let outlier_mask: Vec<bool> = inlier_mask.iter().map(|&keep| !keep).collect();
249
250    Ok(PrimitiveSegmentation {
251        model,
252        inliers: extract_mask(input, &inlier_mask)?,
253        outliers: extract_mask(input, &outlier_mask)?,
254        inlier_count: best_inliers.len(),
255    })
256}
257
/// Returns, in ascending order, every index in `0..len` whose `distance` is at
/// most `threshold`. A NaN distance never qualifies.
fn collect_inliers(len: usize, threshold: f32, distance: impl Fn(usize) -> f32) -> Vec<usize> {
    let mut hits = Vec::new();
    for index in 0..len {
        if distance(index) <= threshold {
            hits.push(index);
        }
    }
    hits
}
261
262/// Solves for the sphere through four points (subtract one equation from the
263/// others to linearize, then solve the 3×3 system for the center).
264fn sphere_from_points(x: &[f32], y: &[f32], z: &[f32], idx: [usize; 4]) -> Option<SphereModel> {
265    let p: Vec<[f64; 3]> =
266        idx.iter().map(|&i| [f64::from(x[i]), f64::from(y[i]), f64::from(z[i])]).collect();
267    let sq = |q: [f64; 3]| q[0] * q[0] + q[1] * q[1] + q[2] * q[2];
268
269    let mut a = Vec::with_capacity(3);
270    let mut b = Vec::with_capacity(3);
271    for row in 1..4 {
272        a.push(vec![
273            2.0 * (p[row][0] - p[0][0]),
274            2.0 * (p[row][1] - p[0][1]),
275            2.0 * (p[row][2] - p[0][2]),
276        ]);
277        b.push(sq(p[row]) - sq(p[0]));
278    }
279
280    let center = match solve_linear_system(a, b) {
281        LeastSquaresResult::Solved(c) => c,
282        LeastSquaresResult::Singular => return None,
283    };
284    let center = Vec3::new(center[0] as f32, center[1] as f32, center[2] as f32);
285    let radius = (center - Vec3::new(x[idx[0]], y[idx[0]], z[idx[0]])).length();
286    if !radius.is_finite() {
287        return None;
288    }
289    Some(SphereModel { center, radius })
290}
291
292/// Recovers a cylinder from two points with surface normals: the axis is the
293/// cross product of the normals, and projecting into the plane perpendicular to
294/// the axis turns the problem into fitting a circle through two points whose
295/// (projected) normals point at the center.
296#[allow(clippy::too_many_arguments)]
297fn cylinder_from_points(
298    x: &[f32],
299    y: &[f32],
300    z: &[f32],
301    nx: &[f32],
302    ny: &[f32],
303    nz: &[f32],
304    idx: [usize; 2],
305) -> Option<CylinderModel> {
306    let (i0, i1) = (idx[0], idx[1]);
307    let p0 = Vec3::new(x[i0], y[i0], z[i0]);
308    let p1 = Vec3::new(x[i1], y[i1], z[i1]);
309    let n0 = Vec3::new(nx[i0], ny[i0], nz[i0]);
310    let n1 = Vec3::new(nx[i1], ny[i1], nz[i1]);
311
312    let axis = n0.cross(n1);
313    if axis.length_squared() < 1e-10 {
314        return None; // normals parallel: axis undefined.
315    }
316    let axis = axis.normalize();
317
318    // Orthonormal basis (u, w) spanning the plane perpendicular to the axis.
319    let helper =
320        if axis.x.abs() < 0.9 { Vec3::new(1.0, 0.0, 0.0) } else { Vec3::new(0.0, 1.0, 0.0) };
321    let u = axis.cross(helper).normalize();
322    let w = axis.cross(u);
323
324    let proj = |v: Vec3<f32>| (v.dot(u), v.dot(w));
325    let (p0u, p0w) = proj(p0);
326    let (p1u, p1w) = proj(p1);
327    let (mut n0u, mut n0w) = proj(n0);
328    let (mut n1u, mut n1w) = proj(n1);
329    let l0 = (n0u * n0u + n0w * n0w).sqrt();
330    let l1 = (n1u * n1u + n1w * n1w).sqrt();
331    if l0 < 1e-6 || l1 < 1e-6 {
332        return None; // a normal nearly parallel to the axis projects to ~0.
333    }
334    n0u /= l0;
335    n0w /= l0;
336    n1u /= l1;
337    n1w /= l1;
338
339    // Intersect the two in-plane normal lines: P0 + t0 N0 = P1 + t1 N1.
340    // [N0.u, -N1.u; N0.w, -N1.w] [t0; t1] = [P1.u - P0.u; P1.w - P0.w].
341    let det = n0u * (-n1w) - (-n1u) * n0w;
342    if det.abs() < 1e-9 {
343        return None;
344    }
345    let rhs = (p1u - p0u, p1w - p0w);
346    let t0 = (rhs.0 * (-n1w) - (-n1u) * rhs.1) / det;
347
348    let center_u = p0u + t0 * n0u;
349    let center_w = p0w + t0 * n0w;
350    let radius = ((center_u - p0u).powi(2) + (center_w - p0w).powi(2)).sqrt();
351    if !radius.is_finite() {
352        return None;
353    }
354
355    let axis_point = scale(u, center_u) + scale(w, center_w);
356    Some(CylinderModel { axis_point, axis_direction: axis, radius })
357}
358
359/// Scales a vector by a scalar (`Vec3` has no scalar-multiply operator).
360fn scale(v: Vec3<f32>, s: f32) -> Vec3<f32> {
361    Vec3::new(v.x * s, v.y * s, v.z * s)
362}
363
/// Minimal deterministic PRNG (a 64-bit linear congruential generator) used
/// for RANSAC sampling, so results are reproducible for a given seed.
struct Rng {
    state: u64,
}

impl Rng {
    /// Creates a generator from `seed`. A seed of `0` is mapped to `1`, so
    /// seeds `0` and `1` produce the same sequence.
    fn new(seed: u64) -> Self {
        Self { state: seed.max(1) }
    }

    /// Advances the generator and returns a value in `0..upper` (or `0` when
    /// `upper` is `0`).
    fn next_usize(&mut self, upper: usize) -> usize {
        self.state = self.state.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
        // Map the high, well-mixed 32 bits of the LCG into `0..upper` with a
        // multiply-shift. The low bits have a short period, which would bias
        // naive `% upper` sampling. The product is computed in u128: in u64 it
        // overflows once `upper` exceeds 2^32 (a panic in debug builds). For
        // `upper <= 2^32` the result is bit-identical to the u64 computation.
        let high = u128::from(self.state >> 32);
        let scaled = (high * upper as u128) >> 32;
        // `high < 2^32`, so `scaled < upper` and the cast cannot truncate.
        scaled as usize
    }
}

/// Draws `N` distinct indices from `0..len` by rejection sampling.
///
/// Returns `None` when `len < N`, or when `N * 16` draws were not enough to
/// collect `N` distinct values (only plausible when `len` is close to `N`).
/// RANSAC callers treat `None` as a wasted iteration.
fn sample_distinct<const N: usize>(rng: &mut Rng, len: usize) -> Option<[usize; N]> {
    if len < N {
        return None;
    }
    let mut out = [0usize; N];
    let mut filled = 0;
    let mut attempts = 0;
    while filled < N && attempts < N * 16 {
        let candidate = rng.next_usize(len);
        // Linear scan is fine: N is a tiny compile-time constant (2 or 4 here).
        if !out[..filled].contains(&candidate) {
            out[filled] = candidate;
            filled += 1;
        }
        attempts += 1;
    }
    (filled == N).then_some(out)
}
399
#[cfg(test)]
mod tests {
    use super::*;
    use spatialrust_core::{DType, FieldSemantic, PointCloudBuilder, PointField, PointSchema};
    use std::f32::consts::PI;

    /// Builds a position-only cloud from `points`.
    fn xyz_cloud(points: &[Vec3<f32>]) -> PointCloud {
        let mut builder = PointCloudBuilder::new(
            PointSchema::new()
                .with_field(PointField::scalar("x", FieldSemantic::PositionX, DType::F32))
                .with_field(PointField::scalar("y", FieldSemantic::PositionY, DType::F32))
                .with_field(PointField::scalar("z", FieldSemantic::PositionZ, DType::F32)),
        );
        for p in points {
            builder.push_point([p.x, p.y, p.z]).unwrap();
        }
        builder.build().unwrap()
    }

    /// Builds a cloud with positions and normals from `(position, normal)` pairs.
    fn xyz_normal_cloud(points: &[(Vec3<f32>, Vec3<f32>)]) -> PointCloud {
        let mut builder = PointCloudBuilder::new(
            PointSchema::new()
                .with_field(PointField::scalar("x", FieldSemantic::PositionX, DType::F32))
                .with_field(PointField::scalar("y", FieldSemantic::PositionY, DType::F32))
                .with_field(PointField::scalar("z", FieldSemantic::PositionZ, DType::F32))
                .with_field(PointField::scalar("normal_x", FieldSemantic::NormalX, DType::F32))
                .with_field(PointField::scalar("normal_y", FieldSemantic::NormalY, DType::F32))
                .with_field(PointField::scalar("normal_z", FieldSemantic::NormalZ, DType::F32)),
        );
        for (p, n) in points {
            builder.push_point([p.x, p.y, p.z, n.x, n.y, n.z]).unwrap();
        }
        builder.build().unwrap()
    }

    /// Samples a latitude/longitude grid on a sphere. Rings are offset by half
    /// a step so no point lands on a pole, which keeps every point distinct.
    fn sphere_grid(
        center: Vec3<f32>,
        radius: f32,
        rings: usize,
        segments: usize,
    ) -> Vec<Vec3<f32>> {
        let mut pts = Vec::with_capacity(rings * segments);
        for i in 0..rings {
            let theta = PI * (i as f32 + 0.5) / rings as f32;
            for j in 0..segments {
                let phi = 2.0 * PI * j as f32 / segments as f32;
                pts.push(
                    center
                        + Vec3::new(
                            radius * theta.sin() * phi.cos(),
                            radius * theta.sin() * phi.sin(),
                            radius * theta.cos(),
                        ),
                );
            }
        }
        pts
    }

    #[test]
    fn fits_sphere_with_outliers() {
        let center = Vec3::new(1.0, 2.0, 3.0);
        let radius = 0.5_f32;
        let mut pts = Vec::new();
        for i in 0..12 {
            for j in 0..12 {
                let theta = PI * i as f32 / 11.0;
                let phi = 2.0 * PI * j as f32 / 12.0;
                pts.push(
                    center
                        + Vec3::new(
                            radius * theta.sin() * phi.cos(),
                            radius * theta.sin() * phi.sin(),
                            radius * theta.cos(),
                        ),
                );
            }
        }
        // Scatter some outliers far from the surface.
        pts.push(center + Vec3::new(3.0, 0.0, 0.0));
        pts.push(center + Vec3::new(0.0, 3.0, 0.0));

        let cloud = xyz_cloud(&pts);
        let seg = RansacSphereSegmenter::new(RansacPrimitiveConfig {
            distance_threshold: 0.02,
            max_iterations: 800,
            min_inliers: 50,
            seed: 3,
            ..RansacPrimitiveConfig::default()
        });
        let result = seg.segment(&cloud).unwrap();
        assert!((result.model.radius - radius).abs() < 0.02);
        assert!((result.model.center - center).length() < 0.02);
        assert_eq!(result.outliers.len(), 2);
    }

    #[test]
    fn fits_sphere_among_many_distractors() {
        // A sphere plus a cloud of scattered random distractors of comparable
        // size. Random points fit no sphere, so the only high-inlier model is the
        // true sphere, and finding it requires reliably sampling 4 sphere points
        // among the ~50% noise -- a stress test for the RANSAC sampler.
        let center = Vec3::new(0.0, 0.0, 0.0);
        let radius = 0.4_f32;
        let mut pts = Vec::new();
        for i in 0..16 {
            for j in 0..16 {
                let theta = PI * i as f32 / 15.0;
                let phi = 2.0 * PI * j as f32 / 16.0;
                pts.push(Vec3::new(
                    radius * theta.sin() * phi.cos(),
                    radius * theta.sin() * phi.sin(),
                    radius * theta.cos(),
                ));
            }
        }
        // ~256 scattered distractors in a box well away from the sphere surface.
        let mut s = 1234_u64;
        let mut rand = || {
            s = s.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
            (s >> 40) as f32 / (1u64 << 24) as f32 // in [0, 1)
        };
        for _ in 0..256 {
            pts.push(Vec3::new(rand() * 4.0 - 2.0, rand() * 4.0 - 2.0, rand() * 4.0 + 2.0));
        }

        let cloud = xyz_cloud(&pts);
        let seg = RansacSphereSegmenter::new(RansacPrimitiveConfig {
            distance_threshold: 0.02,
            max_iterations: 2000,
            min_inliers: 100,
            seed: 11,
            ..RansacPrimitiveConfig::default()
        });
        let result = seg.segment(&cloud).unwrap();
        assert!((result.model.radius - radius).abs() < 0.03, "radius {}", result.model.radius);
        assert!((result.model.center - center).length() < 0.03);
    }

    #[test]
    fn fits_cylinder_axis_and_radius() {
        let radius = 0.4_f32;
        // Cylinder along +z through the origin.
        let mut samples = Vec::new();
        for i in 0..20 {
            for j in 0..24 {
                let h = i as f32 * 0.1;
                let phi = 2.0 * PI * j as f32 / 24.0;
                let dir = Vec3::new(phi.cos(), phi.sin(), 0.0);
                samples.push((Vec3::new(radius * phi.cos(), radius * phi.sin(), h), dir));
            }
        }
        let cloud = xyz_normal_cloud(&samples);
        let seg = RansacCylinderSegmenter::new(RansacPrimitiveConfig {
            distance_threshold: 0.02,
            max_iterations: 800,
            min_inliers: 100,
            seed: 5,
            ..RansacPrimitiveConfig::default()
        });
        let result = seg.segment(&cloud).unwrap();
        assert!((result.model.radius - radius).abs() < 0.03, "radius {}", result.model.radius);
        // Axis should be (anti)parallel to +z.
        assert!(result.model.axis_direction.z.abs() > 0.98);
    }

    #[test]
    fn fits_offset_cylinder_axis_position() {
        // Cylinder along +x whose axis passes through (0, 1, -2). This checks
        // that the fitted axis is placed correctly, not just oriented correctly.
        let radius = 0.3_f32;
        let (axis_y, axis_z) = (1.0_f32, -2.0_f32);
        let mut samples = Vec::new();
        for i in 0..20 {
            for j in 0..24 {
                let x = i as f32 * 0.1;
                let phi = 2.0 * PI * j as f32 / 24.0;
                let dir = Vec3::new(0.0, phi.cos(), phi.sin());
                let pos = Vec3::new(x, axis_y + radius * phi.cos(), axis_z + radius * phi.sin());
                samples.push((pos, dir));
            }
        }
        let cloud = xyz_normal_cloud(&samples);
        let seg = RansacCylinderSegmenter::new(RansacPrimitiveConfig {
            distance_threshold: 0.02,
            max_iterations: 800,
            min_inliers: 100,
            seed: 7,
            ..RansacPrimitiveConfig::default()
        });
        let result = seg.segment(&cloud).unwrap();
        assert!((result.model.radius - radius).abs() < 0.03, "radius {}", result.model.radius);
        assert!(result.model.axis_direction.x.abs() > 0.98);
        for x in [0.0_f32, 1.0, 5.0] {
            let on_axis = Vec3::new(x, axis_y, axis_z);
            assert!(result.model.axis_distance(on_axis) < 0.03, "axis misplaced at x = {x}");
        }
        assert_eq!(result.outliers.len(), 0);
    }

    #[test]
    fn sphere_rejects_too_few_points() {
        let cloud = xyz_cloud(&[Vec3::new(0.0, 0.0, 0.0), Vec3::new(1.0, 0.0, 0.0)]);
        assert!(RansacSphereSegmenter::new(RansacPrimitiveConfig::default())
            .segment(&cloud)
            .is_err());
    }

    #[test]
    fn sphere_fails_when_min_inliers_unreachable() {
        let cloud = xyz_cloud(&sphere_grid(Vec3::new(0.0, 0.0, 0.0), 0.5, 4, 6));
        let seg = RansacSphereSegmenter::new(RansacPrimitiveConfig {
            min_inliers: 1_000,
            ..RansacPrimitiveConfig::default()
        });
        assert!(seg.segment(&cloud).is_err());
    }

    #[test]
    fn sphere_radius_bounds_reject_true_sphere() {
        // Every non-degenerate 4-point sample reproduces the radius-0.5 sphere,
        // so a 0.1 upper bound leaves no acceptable model.
        let cloud = xyz_cloud(&sphere_grid(Vec3::new(0.0, 0.0, 0.0), 0.5, 4, 6));
        let seg = RansacSphereSegmenter::new(RansacPrimitiveConfig {
            max_radius: 0.1,
            max_iterations: 200,
            ..RansacPrimitiveConfig::default()
        });
        assert!(seg.segment(&cloud).is_err());
    }

    #[test]
    fn cylinder_rejects_too_few_points() {
        let cloud = xyz_normal_cloud(&[(Vec3::new(0.4, 0.0, 0.0), Vec3::new(1.0, 0.0, 0.0))]);
        assert!(RansacCylinderSegmenter::new(RansacPrimitiveConfig::default())
            .segment(&cloud)
            .is_err());
    }
}