spatialrust_segmentation/
cloud.rs1use spatialrust_core::{
2 PointBuffer, PointBufferSet, PointCloud, PointField, SpatialError, SpatialResult,
3};
4
5pub fn extract_indices(input: &PointCloud, indices: &[usize]) -> SpatialResult<PointCloud> {
7 if indices.is_empty() {
8 let mut buffers = PointBufferSet::new();
9 for field in input.schema().fields() {
10 buffers.insert(field.name.clone(), PointBuffer::with_capacity(field.dtype, 0));
11 }
12 return PointCloud::try_from_parts(
13 input.schema().clone(),
14 buffers,
15 input.metadata().clone(),
16 );
17 }
18
19 let mut buffers = PointBufferSet::new();
20 for field in input.schema().fields() {
21 let source = input.field(&field.name)?;
22 buffers.insert(field.name.clone(), gather_buffer(source, indices)?);
23 }
24
25 PointCloud::try_from_parts(input.schema().clone(), buffers, input.metadata().clone())
26}
27
28pub fn extract_mask(input: &PointCloud, mask: &[bool]) -> SpatialResult<PointCloud> {
30 if mask.len() != input.len() {
31 return Err(SpatialError::InvalidArgument(format!(
32 "mask length {} does not match point count {}",
33 mask.len(),
34 input.len()
35 )));
36 }
37
38 let indices: Vec<usize> =
39 mask.iter().enumerate().filter_map(|(index, selected)| selected.then_some(index)).collect();
40 extract_indices(input, &indices)
41}
42
43pub fn with_labels(
45 input: &PointCloud,
46 field_name: &str,
47 labels: Vec<i32>,
48) -> SpatialResult<PointCloud> {
49 if labels.len() != input.len() {
50 return Err(SpatialError::InvalidArgument(format!(
51 "label length {} does not match point count {}",
52 labels.len(),
53 input.len()
54 )));
55 }
56
57 let mut schema = input.schema().clone();
58 if schema.find_semantic(spatialrust_core::FieldSemantic::Label).is_none() {
59 schema = schema.with_field(PointField::scalar(
60 field_name,
61 spatialrust_core::FieldSemantic::Label,
62 spatialrust_core::DType::I32,
63 ));
64 }
65
66 let mut buffers = PointBufferSet::new();
67 for field in input.schema().fields() {
68 let source = input.field(&field.name)?;
69 buffers.insert(field.name.clone(), clone_buffer(source)?);
70 }
71 buffers.insert(field_name.to_owned(), PointBuffer::I32(labels));
72
73 PointCloud::try_from_parts(schema, buffers, input.metadata().clone())
74}
75
76fn gather_buffer(source: &PointBuffer, indices: &[usize]) -> SpatialResult<PointBuffer> {
77 Ok(match source {
78 PointBuffer::F32(values) => {
79 PointBuffer::from_f32(indices.iter().map(|&index| values[index]).collect())
80 }
81 PointBuffer::F64(values) => {
82 PointBuffer::F64(indices.iter().map(|&index| values[index]).collect())
83 }
84 PointBuffer::U8(values) => {
85 PointBuffer::U8(indices.iter().map(|&index| values[index]).collect())
86 }
87 PointBuffer::U16(values) => {
88 PointBuffer::U16(indices.iter().map(|&index| values[index]).collect())
89 }
90 PointBuffer::U32(values) => {
91 PointBuffer::U32(indices.iter().map(|&index| values[index]).collect())
92 }
93 PointBuffer::I32(values) => {
94 PointBuffer::I32(indices.iter().map(|&index| values[index]).collect())
95 }
96 })
97}
98
99fn clone_buffer(buffer: &PointBuffer) -> SpatialResult<PointBuffer> {
100 Ok(match buffer {
101 PointBuffer::F32(values) => PointBuffer::from_f32(values.clone()),
102 PointBuffer::F64(values) => PointBuffer::F64(values.clone()),
103 PointBuffer::U8(values) => PointBuffer::U8(values.clone()),
104 PointBuffer::U16(values) => PointBuffer::U16(values.clone()),
105 PointBuffer::U32(values) => PointBuffer::U32(values.clone()),
106 PointBuffer::I32(values) => PointBuffer::I32(values.clone()),
107 })
108}