1use std::cmp::Ordering;
11use std::collections::BinaryHeap;
12
13use spatialrust_core::{
14 FieldSemantic, HasNormals3, HasPositions3, PointBuffer, PointBufferSet, PointCloud,
15 SpatialError, SpatialResult,
16};
17use spatialrust_math::Vec3;
18use spatialrust_search::{KdTree, NearestNeighborIndex};
19
/// Parameters for [`orient_normals_consistent`].
///
/// The config is a single integer, so it is `Copy` and also `Eq`/`Hash`. This lets it be
/// used as a map or set key, e.g. when caching results per configuration.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct NormalOrientationConfig {
    /// Number of nearest neighbours linked to each point when propagating orientation.
    /// Must be greater than zero; `orient_normals_consistent` rejects `0`.
    pub k_neighbors: usize,
}

impl Default for NormalOrientationConfig {
    /// Uses 15 neighbours per point.
    fn default() -> Self {
        Self { k_neighbors: 15 }
    }
}

impl NormalOrientationConfig {
    /// Creates a configuration that links each point to `k_neighbors` neighbours.
    ///
    /// No validation happens here; a value of zero is reported as an error later by
    /// `orient_normals_consistent`.
    #[must_use]
    pub const fn new(k_neighbors: usize) -> Self {
        Self { k_neighbors }
    }
}
40
/// A candidate edge in the orientation-propagation graph.
///
/// The ordering is reversed so that `BinaryHeap` (a max-heap) pops the edge with the
/// *smallest* weight first. Equality is derived from the same total order as `Ord`:
/// `f32::total_cmp` on the weight, then `node`, then `parent`. This keeps `Eq` and `Ord`
/// consistent even for NaN or signed-zero weights. Comparing weights with `f32 ==` would
/// make a NaN edge unequal to itself while `cmp` reports `Equal`.
struct Edge {
    /// Propagation cost; lower values are expanded first.
    weight: f32,
    /// Already-visited point whose orientation is propagated along this edge.
    parent: u32,
    /// Point whose orientation is fixed when this edge is popped.
    node: u32,
}

impl PartialEq for Edge {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other).is_eq()
    }
}

impl Eq for Edge {}

impl PartialOrd for Edge {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Edge {
    fn cmp(&self, other: &Self) -> Ordering {
        // All comparisons are reversed (`other` vs `self`) to turn the max-heap into a
        // min-heap. The index tie-break pops equal-weight edges in ascending
        // (node, parent) order, independent of the order they were inserted.
        other
            .weight
            .total_cmp(&self.weight)
            .then_with(|| other.node.cmp(&self.node))
            .then_with(|| other.parent.cmp(&self.parent))
    }
}
65
66pub fn orient_normals_consistent(
72 input: &PointCloud,
73 config: NormalOrientationConfig,
74) -> SpatialResult<PointCloud> {
75 if config.k_neighbors == 0 {
76 return Err(SpatialError::InvalidArgument(
77 "k_neighbors must be greater than zero".to_owned(),
78 ));
79 }
80 let len = input.len();
81 if len == 0 {
82 return Ok(input.clone());
83 }
84
85 let (x, y, z) = input.positions3()?;
86 let (nx, ny, nz) = input.normals3()?;
87 let mut normals: Vec<Vec3<f32>> = (0..len).map(|i| Vec3::new(nx[i], ny[i], nz[i])).collect();
88
89 let tree = KdTree::from_slices(x, y, z);
90 let neighbors_of = |i: usize| tree.nearest_k(x[i], y[i], z[i], config.k_neighbors + 1);
91
92 let mut visited = vec![false; len];
93 let mut heap: BinaryHeap<Edge> = BinaryHeap::new();
94
95 let mut order: Vec<usize> = (0..len).collect();
98 order.sort_by(|&a, &b| z[b].total_cmp(&z[a]));
99
100 for &seed in &order {
101 if visited[seed] {
102 continue;
103 }
104 if normals[seed].z < 0.0 {
106 normals[seed] = flip(normals[seed]);
107 }
108 visited[seed] = true;
109 push_edges(&mut heap, seed, &neighbors_of(seed), &visited, &normals);
110
111 while let Some(edge) = heap.pop() {
112 let node = edge.node as usize;
113 if visited[node] {
114 continue;
115 }
116 visited[node] = true;
117 if normals[node].dot(normals[edge.parent as usize]) < 0.0 {
119 normals[node] = flip(normals[node]);
120 }
121 push_edges(&mut heap, node, &neighbors_of(node), &visited, &normals);
122 }
123 }
124
125 build_output(input, &normals)
126}
127
128fn push_edges(
129 heap: &mut BinaryHeap<Edge>,
130 parent: usize,
131 neighbors: &[spatialrust_search::Neighbor],
132 visited: &[bool],
133 normals: &[Vec3<f32>],
134) {
135 for neighbor in neighbors {
136 let node = neighbor.index;
137 if node == parent || visited[node] {
138 continue;
139 }
140 let weight = 1.0 - normals[parent].dot(normals[node]).abs();
142 heap.push(Edge { weight, parent: parent as u32, node: node as u32 });
143 }
144}
145
146fn flip(v: Vec3<f32>) -> Vec3<f32> {
147 Vec3::new(-v.x, -v.y, -v.z)
148}
149
150fn build_output(input: &PointCloud, normals: &[Vec3<f32>]) -> SpatialResult<PointCloud> {
152 let mut buffers = PointBufferSet::new();
153 for field in input.schema().fields() {
154 let buffer = match field.semantic {
155 FieldSemantic::NormalX => PointBuffer::from_f32(normals.iter().map(|n| n.x).collect()),
156 FieldSemantic::NormalY => PointBuffer::from_f32(normals.iter().map(|n| n.y).collect()),
157 FieldSemantic::NormalZ => PointBuffer::from_f32(normals.iter().map(|n| n.z).collect()),
158 _ => clone_buffer(input.field(&field.name)?),
159 };
160 buffers.insert(field.name.clone(), buffer);
161 }
162 PointCloud::try_from_parts(input.schema().clone(), buffers, input.metadata().clone())
163}
164
165fn clone_buffer(buffer: &PointBuffer) -> PointBuffer {
166 match buffer {
167 PointBuffer::F32(v) => PointBuffer::from_f32(v.clone()),
168 PointBuffer::F64(v) => PointBuffer::F64(v.clone()),
169 PointBuffer::U8(v) => PointBuffer::U8(v.clone()),
170 PointBuffer::U16(v) => PointBuffer::U16(v.clone()),
171 PointBuffer::U32(v) => PointBuffer::U32(v.clone()),
172 PointBuffer::I32(v) => PointBuffer::I32(v.clone()),
173 }
174}
175
#[cfg(test)]
mod tests {
    use super::{orient_normals_consistent, NormalOrientationConfig};
    use spatialrust_core::{
        DType, FieldSemantic, HasNormals3, PointCloudBuilder, PointField, PointSchema,
    };

    /// Six `f32` scalar fields: position x/y/z followed by normal x/y/z.
    fn schema() -> PointSchema {
        PointSchema::new()
            .with_field(PointField::scalar("x", FieldSemantic::PositionX, DType::F32))
            .with_field(PointField::scalar("y", FieldSemantic::PositionY, DType::F32))
            .with_field(PointField::scalar("z", FieldSemantic::PositionZ, DType::F32))
            .with_field(PointField::scalar("normal_x", FieldSemantic::NormalX, DType::F32))
            .with_field(PointField::scalar("normal_y", FieldSemantic::NormalY, DType::F32))
            .with_field(PointField::scalar("normal_z", FieldSemantic::NormalZ, DType::F32))
    }

    #[test]
    fn flips_inconsistent_normals_on_a_plane() {
        // 8x8 grid on the z = 0 plane with a checkerboard of +Z / -Z normals.
        let mut builder = PointCloudBuilder::new(schema());
        let mut downward = 0usize;
        for (i, j) in (0..8).flat_map(|i| (0..8).map(move |j| (i, j))) {
            let nz = if (i + j) % 2 == 0 { 1.0 } else { -1.0 };
            downward += usize::from(nz < 0.0);
            builder.push_point([i as f32, j as f32, 0.0, 0.0, 0.0, nz]).unwrap();
        }
        assert!(downward > 0);
        let cloud = builder.build().unwrap();

        let config = NormalOrientationConfig::new(8);
        let oriented = orient_normals_consistent(&cloud, config).unwrap();
        let (_, _, onz) = oriented.normals3().unwrap();
        assert!(onz.iter().all(|&v| v > 0.5), "normals not consistently +Z");
    }

    #[test]
    fn rejects_zero_neighbors() {
        let mut builder = PointCloudBuilder::new(schema());
        builder.push_point([0.0, 0.0, 0.0, 0.0, 0.0, 1.0]).unwrap();
        let cloud = builder.build().unwrap();

        let result = orient_normals_consistent(&cloud, NormalOrientationConfig::new(0));
        assert!(result.is_err());
    }
}