Skip to main content

spatialrust_features/
orient.rs

//! Consistent normal orientation via minimum-spanning-tree propagation.
//!
//! Normal *estimation* recovers each normal only up to sign, so a surface comes
//! out with normals pointing randomly inward/outward. This module propagates a
//! single consistent orientation across a k-nearest-neighbor graph: starting
//! from a seed in each connected component (its highest remaining point, with
//! the normal oriented upward along +Z), it walks a minimum spanning tree (edges
//! weighted so that near-parallel neighbors are visited first, à la Hoppe) and
//! flips each normal to agree with the one it was reached from.
9
10use std::cmp::Ordering;
11use std::collections::BinaryHeap;
12
13use spatialrust_core::{
14    FieldSemantic, HasNormals3, HasPositions3, PointBuffer, PointBufferSet, PointCloud,
15    SpatialError, SpatialResult,
16};
17use spatialrust_math::Vec3;
18use spatialrust_search::{KdTree, NearestNeighborIndex};
19
/// Settings for [`orient_normals_consistent`].
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct NormalOrientationConfig {
    /// How many nearest neighbors each point links to in the propagation graph.
    pub k_neighbors: usize,
}

impl Default for NormalOrientationConfig {
    /// Links each point to 15 neighbors.
    fn default() -> Self {
        Self::new(15)
    }
}

impl NormalOrientationConfig {
    /// Builds a config that links each point to `k_neighbors` neighbors.
    ///
    /// Zero is accepted here but rejected by [`orient_normals_consistent`].
    #[must_use]
    pub const fn new(k_neighbors: usize) -> Self {
        Self { k_neighbors }
    }
}
40
/// An MST candidate edge, ordered so the binary heap pops the smallest weight.
///
/// `parent` and `node` are point indices; `parent` is already oriented and
/// `node` is the point that would be reached through this edge.
struct Edge {
    weight: f32,
    parent: u32,
    node: u32,
}

impl PartialEq for Edge {
    fn eq(&self, other: &Self) -> bool {
        // Derive equality from `cmp` so `Eq` agrees with `Ord`: plain `f32 ==`
        // disagrees with `total_cmp` for NaN (never equal vs. Equal) and for
        // `0.0` vs `-0.0` (equal vs. Less).
        self.cmp(other) == Ordering::Equal
    }
}
impl Eq for Edge {}
impl PartialOrd for Edge {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for Edge {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reverse so the max-heap behaves as a min-heap on weight. `total_cmp`
        // gives a total order even when a weight is NaN.
        other.weight.total_cmp(&self.weight)
    }
}
65
66/// Re-orients a cloud's normals so neighboring normals agree in sign.
67///
68/// The cloud must already carry normals. The seed of each connected component is
69/// oriented to point along `+Z` (upward); apply a viewpoint convention
70/// afterwards (e.g. `orient_normal_towards_viewpoint`) if one is needed instead.
71pub fn orient_normals_consistent(
72    input: &PointCloud,
73    config: NormalOrientationConfig,
74) -> SpatialResult<PointCloud> {
75    if config.k_neighbors == 0 {
76        return Err(SpatialError::InvalidArgument(
77            "k_neighbors must be greater than zero".to_owned(),
78        ));
79    }
80    let len = input.len();
81    if len == 0 {
82        return Ok(input.clone());
83    }
84
85    let (x, y, z) = input.positions3()?;
86    let (nx, ny, nz) = input.normals3()?;
87    let mut normals: Vec<Vec3<f32>> = (0..len).map(|i| Vec3::new(nx[i], ny[i], nz[i])).collect();
88
89    let tree = KdTree::from_slices(x, y, z);
90    let neighbors_of = |i: usize| tree.nearest_k(x[i], y[i], z[i], config.k_neighbors + 1);
91
92    let mut visited = vec![false; len];
93    let mut heap: BinaryHeap<Edge> = BinaryHeap::new();
94
95    // Process seeds in descending height so each component starts from a point
96    // whose "up" orientation is meaningful.
97    let mut order: Vec<usize> = (0..len).collect();
98    order.sort_by(|&a, &b| z[b].total_cmp(&z[a]));
99
100    for &seed in &order {
101        if visited[seed] {
102            continue;
103        }
104        // Orient the seed upward.
105        if normals[seed].z < 0.0 {
106            normals[seed] = flip(normals[seed]);
107        }
108        visited[seed] = true;
109        push_edges(&mut heap, seed, &neighbors_of(seed), &visited, &normals);
110
111        while let Some(edge) = heap.pop() {
112            let node = edge.node as usize;
113            if visited[node] {
114                continue;
115            }
116            visited[node] = true;
117            // Flip the new normal to agree with the one we reached it from.
118            if normals[node].dot(normals[edge.parent as usize]) < 0.0 {
119                normals[node] = flip(normals[node]);
120            }
121            push_edges(&mut heap, node, &neighbors_of(node), &visited, &normals);
122        }
123    }
124
125    build_output(input, &normals)
126}
127
128fn push_edges(
129    heap: &mut BinaryHeap<Edge>,
130    parent: usize,
131    neighbors: &[spatialrust_search::Neighbor],
132    visited: &[bool],
133    normals: &[Vec3<f32>],
134) {
135    for neighbor in neighbors {
136        let node = neighbor.index;
137        if node == parent || visited[node] {
138            continue;
139        }
140        // Near-parallel normals get low weight, so they propagate first.
141        let weight = 1.0 - normals[parent].dot(normals[node]).abs();
142        heap.push(Edge { weight, parent: parent as u32, node: node as u32 });
143    }
144}
145
146fn flip(v: Vec3<f32>) -> Vec3<f32> {
147    Vec3::new(-v.x, -v.y, -v.z)
148}
149
150/// Rebuilds the cloud, replacing only the normal columns.
151fn build_output(input: &PointCloud, normals: &[Vec3<f32>]) -> SpatialResult<PointCloud> {
152    let mut buffers = PointBufferSet::new();
153    for field in input.schema().fields() {
154        let buffer = match field.semantic {
155            FieldSemantic::NormalX => PointBuffer::from_f32(normals.iter().map(|n| n.x).collect()),
156            FieldSemantic::NormalY => PointBuffer::from_f32(normals.iter().map(|n| n.y).collect()),
157            FieldSemantic::NormalZ => PointBuffer::from_f32(normals.iter().map(|n| n.z).collect()),
158            _ => clone_buffer(input.field(&field.name)?),
159        };
160        buffers.insert(field.name.clone(), buffer);
161    }
162    PointCloud::try_from_parts(input.schema().clone(), buffers, input.metadata().clone())
163}
164
165fn clone_buffer(buffer: &PointBuffer) -> PointBuffer {
166    match buffer {
167        PointBuffer::F32(v) => PointBuffer::from_f32(v.clone()),
168        PointBuffer::F64(v) => PointBuffer::F64(v.clone()),
169        PointBuffer::U8(v) => PointBuffer::U8(v.clone()),
170        PointBuffer::U16(v) => PointBuffer::U16(v.clone()),
171        PointBuffer::U32(v) => PointBuffer::U32(v.clone()),
172        PointBuffer::I32(v) => PointBuffer::I32(v.clone()),
173    }
174}
175
#[cfg(test)]
mod tests {
    use super::{orient_normals_consistent, NormalOrientationConfig};
    use spatialrust_core::{
        DType, FieldSemantic, HasNormals3, PointCloudBuilder, PointField, PointSchema,
    };

    /// Schema with `f32` positions followed by `f32` normals, shared by every test.
    fn schema() -> PointSchema {
        PointSchema::new()
            .with_field(PointField::scalar("x", FieldSemantic::PositionX, DType::F32))
            .with_field(PointField::scalar("y", FieldSemantic::PositionY, DType::F32))
            .with_field(PointField::scalar("z", FieldSemantic::PositionZ, DType::F32))
            .with_field(PointField::scalar("normal_x", FieldSemantic::NormalX, DType::F32))
            .with_field(PointField::scalar("normal_y", FieldSemantic::NormalY, DType::F32))
            .with_field(PointField::scalar("normal_z", FieldSemantic::NormalZ, DType::F32))
    }

    #[test]
    fn flips_inconsistent_normals_on_a_plane() {
        // A flat grid whose true normal is +Z, but every other point's normal is
        // flipped to -Z. After orientation they should all agree.
        let mut builder = PointCloudBuilder::new(schema());
        let mut flipped = 0;
        for i in 0..8 {
            for j in 0..8 {
                let nz = if (i + j) % 2 == 0 { 1.0 } else { -1.0 };
                if nz < 0.0 {
                    flipped += 1;
                }
                builder.push_point([i as f32, j as f32, 0.0, 0.0, 0.0, nz]).unwrap();
            }
        }
        // Sanity-check the fixture: it must actually contain inconsistent normals.
        assert!(flipped > 0);
        let cloud = builder.build().unwrap();

        let oriented = orient_normals_consistent(&cloud, NormalOrientationConfig::new(8)).unwrap();
        let (_, _, onz) = oriented.normals3().unwrap();
        // All normals should now point the same way (+Z, since the seed is up).
        assert!(onz.iter().all(|&v| v > 0.5), "normals not consistently +Z");
    }

    #[test]
    fn flips_a_downward_seed_upward() {
        // Every normal starts at -Z, so the seed itself must be flipped before
        // propagation; the rest of the grid then follows it.
        let mut builder = PointCloudBuilder::new(schema());
        for i in 0..4 {
            for j in 0..4 {
                builder.push_point([i as f32, j as f32, 0.0, 0.0, 0.0, -1.0]).unwrap();
            }
        }
        let cloud = builder.build().unwrap();

        let oriented = orient_normals_consistent(&cloud, NormalOrientationConfig::new(8)).unwrap();
        let (_, _, onz) = oriented.normals3().unwrap();
        assert!(onz.iter().all(|&v| v > 0.5), "downward seed was not flipped to +Z");
    }

    #[test]
    fn orients_an_isolated_point_upward() {
        // A lone point is its own component: it has no neighbors besides itself,
        // so only the seed rule applies.
        let mut builder = PointCloudBuilder::new(schema());
        builder.push_point([0.0, 0.0, 0.0, 0.0, 0.0, -1.0]).unwrap();
        let cloud = builder.build().unwrap();

        let oriented =
            orient_normals_consistent(&cloud, NormalOrientationConfig::default()).unwrap();
        let (_, _, onz) = oriented.normals3().unwrap();
        assert_eq!(onz.len(), 1);
        assert!(onz[0] > 0.5, "isolated seed was not flipped to +Z");
    }

    #[test]
    fn rejects_zero_neighbors() {
        let mut builder = PointCloudBuilder::new(schema());
        builder.push_point([0.0, 0.0, 0.0, 0.0, 0.0, 1.0]).unwrap();
        let cloud = builder.build().unwrap();
        assert!(orient_normals_consistent(&cloud, NormalOrientationConfig::new(0)).is_err());
    }
}