Skip to main content

rust_robotics_core/
experiments.rs

//! Shared experiment-contract helpers for cross-package exploratory workflows.
//!
//! This module is intentionally small. It carries only the reusable pieces
//! that have already proven stable across multiple experiment packages:
//! - variant descriptors
//! - sampling plans
//! - source/extensibility metrics
//! - generic reference annotation against a chosen baseline

10use std::collections::HashMap;
11use std::fs;
12use std::hash::Hash;
13use std::path::Path;
14
/// Static description of one experiment variant taking part in a comparison.
#[derive(Clone, Copy, Debug)]
pub struct VariantDescriptor {
    /// Identifier of the variant; this is what `annotate_against_reference`
    /// matches its `reference_id` argument against.
    pub id: &'static str,
    /// Free-form label for the variant's design approach.
    pub design_style: &'static str,
    /// Path to the variant's source file, as supplied by the caller.
    pub source_path: &'static str,
    /// Number of tunable knobs the variant exposes.
    pub knob_count: usize,
    /// Flag recording whether the variant reports dispersion; its exact meaning
    /// is defined by the experiment package that sets it.
    pub reports_dispersion: bool,
}
23
/// Sampling configuration for an experiment: an initial set of slots plus an
/// optional escalation stage.
///
/// This module only stores the plan. Nothing here reads the escalation fields,
/// so how they are applied is up to the consuming experiment package.
#[derive(Clone, Debug)]
pub struct ExperimentSamplingPlan {
    /// Slots sampled up front.
    pub initial_slots: Vec<usize>,
    /// Extra slots associated with the escalation stage.
    pub escalation_slots: Vec<usize>,
    /// Escalation trigger keyed on a split vote (interpreted by the consumer).
    pub escalate_if_vote_split: bool,
    /// Escalation trigger keyed on a ratio-margin threshold; `None` means the
    /// trigger is not configured.
    pub escalate_if_ratio_margin_below: Option<f64>,
}
31
32impl ExperimentSamplingPlan {
33    pub fn static_slots(slots: Vec<usize>) -> Self {
34        Self {
35            initial_slots: slots,
36            escalation_slots: Vec::new(),
37            escalate_if_vote_split: false,
38            escalate_if_ratio_margin_below: None,
39        }
40    }
41}
42
/// Rough, line-based size and complexity metrics for one source file, as
/// produced by `read_source_metrics`.
#[derive(Clone, Copy, Debug)]
pub struct SourceMetrics {
    /// Non-blank lines that were not classified as comments.
    pub code_lines: usize,
    /// Lines classified as comments.
    pub comment_lines: usize,
    /// Code lines that look like they open a branch or loop (`if`, `for`,
    /// `while`, `match`); each line is counted at most once.
    pub branch_keywords: usize,
}
49
/// Extensibility summary for a variant, combining observation coverage with
/// descriptor-level facts.
#[derive(Clone, Copy, Debug)]
pub struct ExtensibilityMetrics {
    /// Intended to hold the mean observation coverage, e.g. as computed by
    /// `average_coverage_ratio`.
    pub average_coverage_ratio: f64,
    /// Number of tunable knobs; presumably mirrors the descriptor's
    /// `knob_count` (confirm with callers).
    pub knob_count: usize,
    /// Presumably mirrors the descriptor's `reports_dispersion` (confirm with
    /// callers).
    pub reports_dispersion: bool,
}
56
57#[derive(Debug, Clone)]
58pub struct ExperimentVariantReport<T> {
59    pub descriptor: VariantDescriptor,
60    pub evaluation_runtime_ms: f64,
61    pub observations: Vec<T>,
62    pub source_metrics: SourceMetrics,
63    pub extensibility_metrics: ExtensibilityMetrics,
64    pub agreement_vs_reference: Option<f64>,
65    pub mean_ratio_error_vs_reference: Option<f64>,
66}
67
/// Contract an observation type must satisfy to be compared across variants.
pub trait ExperimentObservation {
    /// Key identifying "the same case" across variants; observations whose
    /// keys are equal are compared against each other.
    type Key: Eq + Hash + Clone;

    /// Returns the key used to pair this observation with another variant's.
    fn comparison_key(&self) -> Self::Key;

    /// Returns the label of the winner for this case; compared for equality
    /// against the reference's label.
    fn winner_label(&self) -> &'static str;

    /// Returns the ratio whose absolute difference from the reference's ratio
    /// is averaged during annotation.
    fn ratio_value(&self) -> f64;

    /// Returns the coverage value averaged by `average_coverage_ratio`.
    fn coverage_ratio(&self) -> f64;
}
76
77pub fn average_coverage_ratio<T: ExperimentObservation>(observations: &[T]) -> f64 {
78    if observations.is_empty() {
79        return 0.0;
80    }
81
82    observations
83        .iter()
84        .map(ExperimentObservation::coverage_ratio)
85        .sum::<f64>()
86        / observations.len() as f64
87}
88
89pub fn annotate_against_reference<T>(reports: &mut [ExperimentVariantReport<T>], reference_id: &str)
90where
91    T: ExperimentObservation,
92{
93    let Some(reference) = reports
94        .iter()
95        .find(|report| report.descriptor.id == reference_id)
96    else {
97        return;
98    };
99
100    let reference_lookup: HashMap<T::Key, (&'static str, f64)> = reference
101        .observations
102        .iter()
103        .map(|observation| {
104            (
105                observation.comparison_key(),
106                (observation.winner_label(), observation.ratio_value()),
107            )
108        })
109        .collect();
110
111    for report in reports.iter_mut() {
112        let mut compared = 0usize;
113        let mut agreements = 0usize;
114        let mut total_ratio_error = 0.0f64;
115        for observation in &report.observations {
116            if let Some((reference_winner, reference_ratio)) =
117                reference_lookup.get(&observation.comparison_key())
118            {
119                compared += 1;
120                if observation.winner_label() == *reference_winner {
121                    agreements += 1;
122                }
123                total_ratio_error += (observation.ratio_value() - *reference_ratio).abs();
124            }
125        }
126
127        if compared > 0 {
128            report.agreement_vs_reference = Some(agreements as f64 / compared as f64);
129            report.mean_ratio_error_vs_reference = Some(total_ratio_error / compared as f64);
130        }
131    }
132}
133
134pub fn read_source_metrics(path: &Path) -> std::io::Result<SourceMetrics> {
135    let content = fs::read_to_string(path)?;
136    let mut code_lines = 0usize;
137    let mut comment_lines = 0usize;
138    let mut branch_keywords = 0usize;
139
140    for line in content.lines() {
141        let trimmed = line.trim();
142        if trimmed.is_empty() {
143            continue;
144        }
145        if trimmed.starts_with("//") {
146            comment_lines += 1;
147            continue;
148        }
149        code_lines += 1;
150        if trimmed.starts_with("if ")
151            || trimmed.starts_with("for ")
152            || trimmed.starts_with("while ")
153            || trimmed.contains(" match ")
154            || trimmed.starts_with("match ")
155        {
156            branch_keywords += 1;
157        }
158    }
159
160    Ok(SourceMetrics {
161        code_lines,
162        comment_lines,
163        branch_keywords,
164    })
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal observation used to exercise the shared experiment contract.
    #[derive(Debug, Clone)]
    struct StubObservation {
        family: &'static str,
        bucket: u32,
        winner: &'static str,
        ratio: f64,
        coverage: f64,
    }

    impl ExperimentObservation for StubObservation {
        type Key = (&'static str, u32);

        fn comparison_key(&self) -> Self::Key {
            (self.family, self.bucket)
        }

        fn winner_label(&self) -> &'static str {
            self.winner
        }

        fn ratio_value(&self) -> f64 {
            self.ratio
        }

        fn coverage_ratio(&self) -> f64 {
            self.coverage
        }
    }

    /// Builds a `StubObservation` in the shared `"case"` family.
    fn stub(bucket: u32, winner: &'static str, ratio: f64, coverage: f64) -> StubObservation {
        StubObservation {
            family: "case",
            bucket,
            winner,
            ratio,
            coverage,
        }
    }

    /// Builds a report holding one observation, with placeholder source metrics.
    ///
    /// The observation's coverage is reused as the extensibility summary's
    /// average, and `reports_dispersion` is set on both the descriptor and the
    /// extensibility metrics.
    fn single_observation_report(
        id: &'static str,
        design_style: &'static str,
        runtime_ms: f64,
        reports_dispersion: bool,
        observation: StubObservation,
    ) -> ExperimentVariantReport<StubObservation> {
        let coverage = observation.coverage;
        ExperimentVariantReport {
            descriptor: VariantDescriptor {
                id,
                design_style,
                source_path: "ignored",
                knob_count: 0,
                reports_dispersion,
            },
            evaluation_runtime_ms: runtime_ms,
            observations: vec![observation],
            source_metrics: SourceMetrics {
                code_lines: 1,
                comment_lines: 0,
                branch_keywords: 0,
            },
            extensibility_metrics: ExtensibilityMetrics {
                average_coverage_ratio: coverage,
                knob_count: 0,
                reports_dispersion,
            },
            agreement_vs_reference: None,
            mean_ratio_error_vs_reference: None,
        }
    }

    #[test]
    fn annotate_against_reference_marks_agreement_and_ratio_error() {
        let mut reports = vec![
            single_observation_report("full-bucket", "reference", 10.0, true, stub(10, "A", 1.2, 1.0)),
            single_observation_report("candidate", "candidate", 5.0, false, stub(10, "A", 1.1, 0.3)),
        ];

        annotate_against_reference(&mut reports, "full-bucket");

        let candidate = &reports[1];
        assert_eq!(candidate.agreement_vs_reference, Some(1.0));
        let ratio_error = candidate
            .mean_ratio_error_vs_reference
            .expect("ratio error should be annotated");
        assert!((ratio_error - 0.1).abs() < 1e-9);
    }

    #[test]
    fn average_coverage_ratio_uses_observation_contract() {
        let observations = [stub(10, "A", 1.0, 0.1), stub(20, "B", 0.9, 0.5)];

        assert!((average_coverage_ratio(&observations) - 0.3).abs() < 1e-9);
    }
}