1use spatialrust_core::{HasNormals3, HasPositions3, PointCloud, SpatialError, SpatialResult};
9use spatialrust_math::{solve_linear_system, LeastSquaresResult, Vec3};
10
11use crate::cloud::extract_mask;
12use crate::segmenter::PointCloudSegmenter;
13
/// Tuning parameters shared by the RANSAC primitive segmenters.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct RansacPrimitiveConfig {
    /// Largest point-to-surface distance at which a point is still counted as an inlier.
    pub distance_threshold: f32,
    /// Number of RANSAC iterations (candidate samples) to run.
    pub max_iterations: usize,
    /// Minimum inlier count the best candidate must reach for a result to be reported.
    pub min_inliers: usize,
    /// Candidates whose radius is below this value are discarded.
    pub min_radius: f32,
    /// Candidates whose radius is above this value are discarded.
    pub max_radius: f32,
    /// Seed for the deterministic index sampler.
    pub seed: u64,
}

impl Default for RansacPrimitiveConfig {
    /// Returns a 0.02 distance threshold, 1000 iterations, at least 10 inliers,
    /// an unbounded radius range and seed 42.
    fn default() -> Self {
        // No radius restriction unless the caller narrows the range.
        let (min_radius, max_radius) = (0.0, f32::INFINITY);
        RansacPrimitiveConfig {
            seed: 42,
            min_inliers: 10,
            max_iterations: 1_000,
            distance_threshold: 0.02,
            min_radius,
            max_radius,
        }
    }
}
43
44#[derive(Clone, Copy, Debug, PartialEq)]
46pub struct SphereModel {
47 pub center: Vec3<f32>,
49 pub radius: f32,
51}
52
53impl SphereModel {
54 #[must_use]
56 pub fn distance(&self, point: Vec3<f32>) -> f32 {
57 ((point - self.center).length() - self.radius).abs()
58 }
59}
60
61#[derive(Clone, Copy, Debug, PartialEq)]
63pub struct CylinderModel {
64 pub axis_point: Vec3<f32>,
66 pub axis_direction: Vec3<f32>,
68 pub radius: f32,
70}
71
72impl CylinderModel {
73 #[must_use]
75 pub fn axis_distance(&self, point: Vec3<f32>) -> f32 {
76 let v = point - self.axis_point;
77 let along = scale(self.axis_direction, v.dot(self.axis_direction));
78 (v - along).length()
79 }
80
81 #[must_use]
83 pub fn distance(&self, point: Vec3<f32>) -> f32 {
84 (self.axis_distance(point) - self.radius).abs()
85 }
86}
87
88#[derive(Clone, Debug, PartialEq)]
90pub struct PrimitiveSegmentation<M> {
91 pub model: M,
93 pub inliers: PointCloud,
95 pub outliers: PointCloud,
97 pub inlier_count: usize,
99}
100
101#[derive(Clone, Copy, Debug, PartialEq)]
103pub struct RansacSphereSegmenter {
104 config: RansacPrimitiveConfig,
105}
106
107impl RansacSphereSegmenter {
108 #[must_use]
110 pub const fn new(config: RansacPrimitiveConfig) -> Self {
111 Self { config }
112 }
113
114 #[must_use]
116 pub const fn config(&self) -> RansacPrimitiveConfig {
117 self.config
118 }
119
120 pub fn segment(&self, input: &PointCloud) -> SpatialResult<PrimitiveSegmentation<SphereModel>> {
122 let (x, y, z) = input.positions3()?;
123 let len = input.len();
124 if len < 4 {
125 return Err(SpatialError::InvalidArgument(
126 "sphere fitting requires at least four points".to_owned(),
127 ));
128 }
129
130 let mut rng = Rng::new(self.config.seed);
131 let mut best_inliers: Vec<usize> = Vec::new();
132 let mut best_model = None;
133
134 for _ in 0..self.config.max_iterations {
135 let Some(sample) = sample_distinct::<4>(&mut rng, len) else {
136 continue;
137 };
138 let Some(model) = sphere_from_points(x, y, z, sample) else {
139 continue;
140 };
141 if model.radius < self.config.min_radius || model.radius > self.config.max_radius {
142 continue;
143 }
144 let inliers = collect_inliers(len, self.config.distance_threshold, |i| {
145 model.distance(Vec3::new(x[i], y[i], z[i]))
146 });
147 if inliers.len() > best_inliers.len() {
148 best_inliers = inliers;
149 best_model = Some(model);
150 }
151 }
152
153 finalize(input, best_model, &best_inliers, self.config.min_inliers)
154 }
155}
156
157impl PointCloudSegmenter for RansacSphereSegmenter {
158 fn name(&self) -> &'static str {
159 "RansacSphereSegmenter"
160 }
161}
162
163#[derive(Clone, Copy, Debug, PartialEq)]
165pub struct RansacCylinderSegmenter {
166 config: RansacPrimitiveConfig,
167}
168
169impl RansacCylinderSegmenter {
170 #[must_use]
172 pub const fn new(config: RansacPrimitiveConfig) -> Self {
173 Self { config }
174 }
175
176 #[must_use]
178 pub const fn config(&self) -> RansacPrimitiveConfig {
179 self.config
180 }
181
182 pub fn segment(
184 &self,
185 input: &PointCloud,
186 ) -> SpatialResult<PrimitiveSegmentation<CylinderModel>> {
187 let (x, y, z) = input.positions3()?;
188 let (nx, ny, nz) = input.normals3()?;
189 let len = input.len();
190 if len < 2 {
191 return Err(SpatialError::InvalidArgument(
192 "cylinder fitting requires at least two points".to_owned(),
193 ));
194 }
195
196 let mut rng = Rng::new(self.config.seed);
197 let mut best_inliers: Vec<usize> = Vec::new();
198 let mut best_model = None;
199
200 for _ in 0..self.config.max_iterations {
201 let Some(sample) = sample_distinct::<2>(&mut rng, len) else {
202 continue;
203 };
204 let Some(model) = cylinder_from_points(x, y, z, nx, ny, nz, sample) else {
205 continue;
206 };
207 if model.radius < self.config.min_radius || model.radius > self.config.max_radius {
208 continue;
209 }
210 let inliers = collect_inliers(len, self.config.distance_threshold, |i| {
211 model.distance(Vec3::new(x[i], y[i], z[i]))
212 });
213 if inliers.len() > best_inliers.len() {
214 best_inliers = inliers;
215 best_model = Some(model);
216 }
217 }
218
219 finalize(input, best_model, &best_inliers, self.config.min_inliers)
220 }
221}
222
223impl PointCloudSegmenter for RansacCylinderSegmenter {
224 fn name(&self) -> &'static str {
225 "RansacCylinderSegmenter"
226 }
227}
228
229fn finalize<M>(
231 input: &PointCloud,
232 best_model: Option<M>,
233 best_inliers: &[usize],
234 min_inliers: usize,
235) -> SpatialResult<PrimitiveSegmentation<M>> {
236 if best_inliers.len() < min_inliers || best_model.is_none() {
237 return Err(SpatialError::InvalidArgument(format!(
238 "RANSAC found only {} inliers, minimum is {min_inliers}",
239 best_inliers.len()
240 )));
241 }
242 let model = best_model.expect("checked above");
243
244 let mut inlier_mask = vec![false; input.len()];
245 for &index in best_inliers {
246 inlier_mask[index] = true;
247 }
248 let outlier_mask: Vec<bool> = inlier_mask.iter().map(|&keep| !keep).collect();
249
250 Ok(PrimitiveSegmentation {
251 model,
252 inliers: extract_mask(input, &inlier_mask)?,
253 outliers: extract_mask(input, &outlier_mask)?,
254 inlier_count: best_inliers.len(),
255 })
256}
257
/// Returns, in ascending order, every index in `0..len` whose `distance(index)` is at
/// most `threshold`.
///
/// Indices whose distance is NaN are never included, because the comparison is false.
fn collect_inliers(len: usize, threshold: f32, distance: impl Fn(usize) -> f32) -> Vec<usize> {
    let mut hits = Vec::new();
    for index in 0..len {
        if distance(index) <= threshold {
            hits.push(index);
        }
    }
    hits
}
261
262fn sphere_from_points(x: &[f32], y: &[f32], z: &[f32], idx: [usize; 4]) -> Option<SphereModel> {
265 let p: Vec<[f64; 3]> =
266 idx.iter().map(|&i| [f64::from(x[i]), f64::from(y[i]), f64::from(z[i])]).collect();
267 let sq = |q: [f64; 3]| q[0] * q[0] + q[1] * q[1] + q[2] * q[2];
268
269 let mut a = Vec::with_capacity(3);
270 let mut b = Vec::with_capacity(3);
271 for row in 1..4 {
272 a.push(vec![
273 2.0 * (p[row][0] - p[0][0]),
274 2.0 * (p[row][1] - p[0][1]),
275 2.0 * (p[row][2] - p[0][2]),
276 ]);
277 b.push(sq(p[row]) - sq(p[0]));
278 }
279
280 let center = match solve_linear_system(a, b) {
281 LeastSquaresResult::Solved(c) => c,
282 LeastSquaresResult::Singular => return None,
283 };
284 let center = Vec3::new(center[0] as f32, center[1] as f32, center[2] as f32);
285 let radius = (center - Vec3::new(x[idx[0]], y[idx[0]], z[idx[0]])).length();
286 if !radius.is_finite() {
287 return None;
288 }
289 Some(SphereModel { center, radius })
290}
291
292#[allow(clippy::too_many_arguments)]
297fn cylinder_from_points(
298 x: &[f32],
299 y: &[f32],
300 z: &[f32],
301 nx: &[f32],
302 ny: &[f32],
303 nz: &[f32],
304 idx: [usize; 2],
305) -> Option<CylinderModel> {
306 let (i0, i1) = (idx[0], idx[1]);
307 let p0 = Vec3::new(x[i0], y[i0], z[i0]);
308 let p1 = Vec3::new(x[i1], y[i1], z[i1]);
309 let n0 = Vec3::new(nx[i0], ny[i0], nz[i0]);
310 let n1 = Vec3::new(nx[i1], ny[i1], nz[i1]);
311
312 let axis = n0.cross(n1);
313 if axis.length_squared() < 1e-10 {
314 return None; }
316 let axis = axis.normalize();
317
318 let helper =
320 if axis.x.abs() < 0.9 { Vec3::new(1.0, 0.0, 0.0) } else { Vec3::new(0.0, 1.0, 0.0) };
321 let u = axis.cross(helper).normalize();
322 let w = axis.cross(u);
323
324 let proj = |v: Vec3<f32>| (v.dot(u), v.dot(w));
325 let (p0u, p0w) = proj(p0);
326 let (p1u, p1w) = proj(p1);
327 let (mut n0u, mut n0w) = proj(n0);
328 let (mut n1u, mut n1w) = proj(n1);
329 let l0 = (n0u * n0u + n0w * n0w).sqrt();
330 let l1 = (n1u * n1u + n1w * n1w).sqrt();
331 if l0 < 1e-6 || l1 < 1e-6 {
332 return None; }
334 n0u /= l0;
335 n0w /= l0;
336 n1u /= l1;
337 n1w /= l1;
338
339 let det = n0u * (-n1w) - (-n1u) * n0w;
342 if det.abs() < 1e-9 {
343 return None;
344 }
345 let rhs = (p1u - p0u, p1w - p0w);
346 let t0 = (rhs.0 * (-n1w) - (-n1u) * rhs.1) / det;
347
348 let center_u = p0u + t0 * n0u;
349 let center_w = p0w + t0 * n0w;
350 let radius = ((center_u - p0u).powi(2) + (center_w - p0w).powi(2)).sqrt();
351 if !radius.is_finite() {
352 return None;
353 }
354
355 let axis_point = scale(u, center_u) + scale(w, center_w);
356 Some(CylinderModel { axis_point, axis_direction: axis, radius })
357}
358
359fn scale(v: Vec3<f32>, s: f32) -> Vec3<f32> {
361 Vec3::new(v.x * s, v.y * s, v.z * s)
362}
363
/// Small deterministic pseudo-random generator: a 64-bit linear congruential generator
/// whose output is taken from the upper 32 bits of the state.
struct Rng {
    state: u64,
}

impl Rng {
    /// Creates a generator from `seed`; a seed of 0 is replaced by 1.
    fn new(seed: u64) -> Self {
        Self { state: seed.max(1) }
    }

    /// Advances the generator and returns a value in `0..upper` (0 when `upper` is 0).
    ///
    /// The upper 32 state bits are scaled into the requested range with a
    /// multiply-and-shift. The product is formed in `u128` because a 32-bit value times
    /// a `usize` bound of 2^32 or more does not fit in `u64`: that would panic in debug
    /// builds and silently wrap in release builds. For bounds below 2^32 the result is
    /// bit-for-bit the same as a `u64` product.
    fn next_usize(&mut self, upper: usize) -> usize {
        self.state = self.state.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
        let high_bits = u128::from(self.state >> 32);
        // The shifted product is strictly below `upper`, so narrowing cannot truncate.
        ((high_bits * upper as u128) >> 32) as usize
    }
}
380
381fn sample_distinct<const N: usize>(rng: &mut Rng, len: usize) -> Option<[usize; N]> {
383 if len < N {
384 return None;
385 }
386 let mut out = [0usize; N];
387 let mut filled = 0;
388 let mut attempts = 0;
389 while filled < N && attempts < N * 16 {
390 let candidate = rng.next_usize(len);
391 if !out[..filled].contains(&candidate) {
392 out[filled] = candidate;
393 filled += 1;
394 }
395 attempts += 1;
396 }
397 (filled == N).then_some(out)
398}
399
#[cfg(test)]
mod tests {
    use super::*;
    use spatialrust_core::{DType, FieldSemantic, PointCloudBuilder, PointField, PointSchema};
    use std::f32::consts::PI;

    /// Builds a position-only cloud from `points`, keeping their order.
    fn xyz_cloud(points: &[Vec3<f32>]) -> PointCloud {
        let schema = PointSchema::new()
            .with_field(PointField::scalar("x", FieldSemantic::PositionX, DType::F32))
            .with_field(PointField::scalar("y", FieldSemantic::PositionY, DType::F32))
            .with_field(PointField::scalar("z", FieldSemantic::PositionZ, DType::F32));
        let mut builder = PointCloudBuilder::new(schema);
        for point in points {
            builder.push_point([point.x, point.y, point.z]).unwrap();
        }
        builder.build().unwrap()
    }

    /// Builds a cloud with positions and normals from `(position, normal)` pairs,
    /// keeping their order.
    fn xyz_normal_cloud(points: &[(Vec3<f32>, Vec3<f32>)]) -> PointCloud {
        let schema = PointSchema::new()
            .with_field(PointField::scalar("x", FieldSemantic::PositionX, DType::F32))
            .with_field(PointField::scalar("y", FieldSemantic::PositionY, DType::F32))
            .with_field(PointField::scalar("z", FieldSemantic::PositionZ, DType::F32))
            .with_field(PointField::scalar("normal_x", FieldSemantic::NormalX, DType::F32))
            .with_field(PointField::scalar("normal_y", FieldSemantic::NormalY, DType::F32))
            .with_field(PointField::scalar("normal_z", FieldSemantic::NormalZ, DType::F32));
        let mut builder = PointCloudBuilder::new(schema);
        for (position, normal) in points {
            builder
                .push_point([position.x, position.y, position.z, normal.x, normal.y, normal.z])
                .unwrap();
        }
        builder.build().unwrap()
    }

    /// Offsets from the center for a `rings` x `sectors` polar/azimuth grid on a sphere of
    /// the given `radius`, generated ring by ring. The polar angle spans `0..=PI`
    /// inclusive, so the first and last ring collapse onto the poles.
    fn sphere_offsets(radius: f32, rings: usize, sectors: usize) -> Vec<Vec3<f32>> {
        let mut offsets = Vec::with_capacity(rings * sectors);
        for ring in 0..rings {
            for sector in 0..sectors {
                let theta = PI * ring as f32 / (rings - 1) as f32;
                let phi = 2.0 * PI * sector as f32 / sectors as f32;
                offsets.push(Vec3::new(
                    radius * theta.sin() * phi.cos(),
                    radius * theta.sin() * phi.sin(),
                    radius * theta.cos(),
                ));
            }
        }
        offsets
    }

    #[test]
    fn fits_sphere_with_outliers() {
        let center: Vec3<f32> = Vec3::new(1.0, 2.0, 3.0);
        let radius = 0.5_f32;

        let mut pts: Vec<Vec3<f32>> =
            sphere_offsets(radius, 12, 12).into_iter().map(|offset| center + offset).collect();
        // Two points far off the surface that must end up as the only outliers.
        pts.push(center + Vec3::new(3.0, 0.0, 0.0));
        pts.push(center + Vec3::new(0.0, 3.0, 0.0));

        let cloud = xyz_cloud(&pts);
        let config = RansacPrimitiveConfig {
            distance_threshold: 0.02,
            max_iterations: 800,
            min_inliers: 50,
            seed: 3,
            ..RansacPrimitiveConfig::default()
        };
        let result = RansacSphereSegmenter::new(config).segment(&cloud).unwrap();

        assert!((result.model.radius - radius).abs() < 0.02);
        assert!((result.model.center - center).length() < 0.02);
        assert_eq!(result.outliers.len(), 2);
    }

    #[test]
    fn fits_sphere_among_many_distractors() {
        let center: Vec3<f32> = Vec3::new(0.0, 0.0, 0.0);
        let radius = 0.4_f32;
        let mut pts = sphere_offsets(radius, 16, 16);

        // Deterministic pseudo-random distractors: x and y in [-2, 2), z in [2, 6).
        let mut lcg_state = 1234_u64;
        let mut unit_random = || {
            lcg_state = lcg_state.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
            (lcg_state >> 40) as f32 / (1u64 << 24) as f32
        };
        for _ in 0..256 {
            pts.push(Vec3::new(
                unit_random() * 4.0 - 2.0,
                unit_random() * 4.0 - 2.0,
                unit_random() * 4.0 + 2.0,
            ));
        }

        let cloud = xyz_cloud(&pts);
        let config = RansacPrimitiveConfig {
            distance_threshold: 0.02,
            max_iterations: 2000,
            min_inliers: 100,
            seed: 11,
            ..RansacPrimitiveConfig::default()
        };
        let result = RansacSphereSegmenter::new(config).segment(&cloud).unwrap();

        assert!((result.model.radius - radius).abs() < 0.03, "radius {}", result.model.radius);
        assert!((result.model.center - center).length() < 0.03);
    }

    #[test]
    fn fits_cylinder_axis_and_radius() {
        let radius = 0.4_f32;

        // A z-aligned cylinder: 20 height levels of 24 points each, with outward normals.
        let mut samples = Vec::new();
        for level in 0..20 {
            for sector in 0..24 {
                let h = level as f32 * 0.1;
                let phi = 2.0 * PI * sector as f32 / 24.0;
                let outward = Vec3::new(phi.cos(), phi.sin(), 0.0);
                samples.push((Vec3::new(radius * phi.cos(), radius * phi.sin(), h), outward));
            }
        }

        let cloud = xyz_normal_cloud(&samples);
        let config = RansacPrimitiveConfig {
            distance_threshold: 0.02,
            max_iterations: 800,
            min_inliers: 100,
            seed: 5,
            ..RansacPrimitiveConfig::default()
        };
        let result = RansacCylinderSegmenter::new(config).segment(&cloud).unwrap();

        assert!((result.model.radius - radius).abs() < 0.03, "radius {}", result.model.radius);
        assert!(result.model.axis_direction.z.abs() > 0.98);
    }

    #[test]
    fn sphere_rejects_too_few_points() {
        let cloud = xyz_cloud(&[Vec3::new(0.0, 0.0, 0.0), Vec3::new(1.0, 0.0, 0.0)]);
        let segmenter = RansacSphereSegmenter::new(RansacPrimitiveConfig::default());
        assert!(segmenter.segment(&cloud).is_err());
    }
}