Skip to main content

spatialrust_segmentation/
cloud.rs

1use spatialrust_core::{
2    PointBuffer, PointBufferSet, PointCloud, PointField, SpatialError, SpatialResult,
3};
4
5/// Extracts a sub-cloud containing only the selected point indices.
6pub fn extract_indices(input: &PointCloud, indices: &[usize]) -> SpatialResult<PointCloud> {
7    if indices.is_empty() {
8        let mut buffers = PointBufferSet::new();
9        for field in input.schema().fields() {
10            buffers.insert(field.name.clone(), PointBuffer::with_capacity(field.dtype, 0));
11        }
12        return PointCloud::try_from_parts(
13            input.schema().clone(),
14            buffers,
15            input.metadata().clone(),
16        );
17    }
18
19    let mut buffers = PointBufferSet::new();
20    for field in input.schema().fields() {
21        let source = input.field(&field.name)?;
22        buffers.insert(field.name.clone(), gather_buffer(source, indices)?);
23    }
24
25    PointCloud::try_from_parts(input.schema().clone(), buffers, input.metadata().clone())
26}
27
28/// Extracts points where `mask[index]` is true.
29pub fn extract_mask(input: &PointCloud, mask: &[bool]) -> SpatialResult<PointCloud> {
30    if mask.len() != input.len() {
31        return Err(SpatialError::InvalidArgument(format!(
32            "mask length {} does not match point count {}",
33            mask.len(),
34            input.len()
35        )));
36    }
37
38    let indices: Vec<usize> =
39        mask.iter().enumerate().filter_map(|(index, selected)| selected.then_some(index)).collect();
40    extract_indices(input, &indices)
41}
42
43/// Adds or replaces a per-point label field on a point cloud copy.
44pub fn with_labels(
45    input: &PointCloud,
46    field_name: &str,
47    labels: Vec<i32>,
48) -> SpatialResult<PointCloud> {
49    if labels.len() != input.len() {
50        return Err(SpatialError::InvalidArgument(format!(
51            "label length {} does not match point count {}",
52            labels.len(),
53            input.len()
54        )));
55    }
56
57    let mut schema = input.schema().clone();
58    if schema.find_semantic(spatialrust_core::FieldSemantic::Label).is_none() {
59        schema = schema.with_field(PointField::scalar(
60            field_name,
61            spatialrust_core::FieldSemantic::Label,
62            spatialrust_core::DType::I32,
63        ));
64    }
65
66    let mut buffers = PointBufferSet::new();
67    for field in input.schema().fields() {
68        let source = input.field(&field.name)?;
69        buffers.insert(field.name.clone(), clone_buffer(source)?);
70    }
71    buffers.insert(field_name.to_owned(), PointBuffer::I32(labels));
72
73    PointCloud::try_from_parts(schema, buffers, input.metadata().clone())
74}
75
76fn gather_buffer(source: &PointBuffer, indices: &[usize]) -> SpatialResult<PointBuffer> {
77    Ok(match source {
78        PointBuffer::F32(values) => {
79            PointBuffer::from_f32(indices.iter().map(|&index| values[index]).collect())
80        }
81        PointBuffer::F64(values) => {
82            PointBuffer::F64(indices.iter().map(|&index| values[index]).collect())
83        }
84        PointBuffer::U8(values) => {
85            PointBuffer::U8(indices.iter().map(|&index| values[index]).collect())
86        }
87        PointBuffer::U16(values) => {
88            PointBuffer::U16(indices.iter().map(|&index| values[index]).collect())
89        }
90        PointBuffer::U32(values) => {
91            PointBuffer::U32(indices.iter().map(|&index| values[index]).collect())
92        }
93        PointBuffer::I32(values) => {
94            PointBuffer::I32(indices.iter().map(|&index| values[index]).collect())
95        }
96    })
97}
98
99fn clone_buffer(buffer: &PointBuffer) -> SpatialResult<PointBuffer> {
100    Ok(match buffer {
101        PointBuffer::F32(values) => PointBuffer::from_f32(values.clone()),
102        PointBuffer::F64(values) => PointBuffer::F64(values.clone()),
103        PointBuffer::U8(values) => PointBuffer::U8(values.clone()),
104        PointBuffer::U16(values) => PointBuffer::U16(values.clone()),
105        PointBuffer::U32(values) => PointBuffer::U32(values.clone()),
106        PointBuffer::I32(values) => PointBuffer::I32(values.clone()),
107    })
108}