Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 44 additions & 18 deletions src/capture.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,22 @@ struct AudioOutputHandler {
source_type: AudioSourceType,
}

/// Converts a rational timestamp of `value / timescale` seconds to whole nanoseconds.
///
/// Returns `0` when `timescale` is not positive or `value` is negative.
/// Otherwise returns `floor(value * 1e9 / timescale)`, saturating at
/// `u64::MAX` when the result does not fit in a `u64`.
fn cm_time_to_ns(value: i64, timescale: i32) -> u64 {
    const NANOS_PER_SEC: u128 = 1_000_000_000;

    if timescale <= 0 || value < 0 {
        return 0;
    }

    // Both inputs are non-negative past the guard, so `unsigned_abs` is a
    // lossless sign-free widening step.
    let ticks = u128::from(value.unsigned_abs());
    let ticks_per_sec = u128::from(timescale.unsigned_abs());

    // The product is below 2^63 * 10^9, which fits comfortably in 128 bits, so
    // no overflow check is needed here. A divide-first u64 formulation still
    // overflows when `value / timescale` exceeds `u64::MAX / 1e9`
    // (for example `i64::MAX` ticks at a timescale of 48_000).
    let nanos = ticks * NANOS_PER_SEC / ticks_per_sec;

    // Clamp instead of wrapping or panicking when the result exceeds u64.
    nanos.min(u128::from(u64::MAX)) as u64
}

impl SCStreamOutputTrait for AudioOutputHandler {
fn did_output_sample_buffer(&self, sample_buffer: CMSampleBuffer, of_type: SCStreamOutputType) {
let is_target = match (self.source_type, of_type) {
Expand All @@ -31,8 +47,8 @@ impl SCStreamOutputTrait for AudioOutputHandler {

if num_buffers == 0 { return; }

// Get first buffer to check channels
let first_buffer = audio_data.get(0).unwrap();
// Get first buffer to check channels; skip malformed empty payloads.
let Some(first_buffer) = audio_data.get(0) else { return; };
let channels_per_buffer = first_buffer.number_channels as usize;

if num_buffers == 1 {
Expand All @@ -59,22 +75,7 @@ impl SCStreamOutputTrait for AudioOutputHandler {

if !interleaved_samples.is_empty() {
let pts = sample_buffer.presentation_timestamp();
// Convert CMTime to nanoseconds (value / timescale * 1e9)
// Use saturating arithmetic to prevent overflow
let timestamp = if pts.timescale > 0 && pts.value >= 0 {
let value = pts.value as u64;
let timescale = pts.timescale as u64;
// Check for potential overflow before multiplication
if value <= u64::MAX / 1_000_000_000 {
(value * 1_000_000_000) / timescale
} else {
// For very large values, do division first to prevent overflow
value / timescale * 1_000_000_000 +
(value % timescale * 1_000_000_000) / timescale
}
} else {
0
};
let timestamp = cm_time_to_ns(pts.value, pts.timescale);

let packet = AudioFrame {
source: self.source_type,
Expand Down Expand Up @@ -154,3 +155,28 @@ pub fn spawn_capture_engine(
stream.start_capture().map_err(|e| anyhow!("Failed to start capture: {}", e))?;
Ok(stream)
}

#[cfg(test)]
mod tests {
    use super::cm_time_to_ns;

    /// Exact conversions for `1 / 1` and for `480 / 48_000`.
    #[test]
    fn cm_time_to_ns_converts_basic_values() {
        let expectations: [(i64, i32, u64); 2] =
            [(1, 1, 1_000_000_000), (480, 48_000, 10_000_000)];
        for &(value, timescale, nanos) in expectations.iter() {
            assert_eq!(cm_time_to_ns(value, timescale), nanos);
        }
    }

    /// A negative value or a non-positive timescale yields zero.
    #[test]
    fn cm_time_to_ns_rejects_invalid_input() {
        let invalid: [(i64, i32); 3] = [(-1, 1), (1, 0), (1, -1)];
        for &(value, timescale) in invalid.iter() {
            assert_eq!(cm_time_to_ns(value, timescale), 0);
        }
    }

    /// The largest `i64` value must convert to a non-zero result.
    #[test]
    fn cm_time_to_ns_handles_large_values_without_overflow() {
        let nanos = cm_time_to_ns(i64::MAX, 48_000);
        assert!(nanos > 0);
    }
}
2 changes: 1 addition & 1 deletion src/config.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use pyo3::prelude::*;

#[pyclass]
#[pyclass(from_py_object)]
#[derive(Clone, Debug)]
pub struct AudioProcessingConfig {
#[pyo3(get, set)]
Expand Down
98 changes: 98 additions & 0 deletions src/modular_pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ use numpy::ToPyArray;
use std::time::Instant;

/// Modular pipeline that processes audio through a chain of processors
#[allow(dead_code)]
pub struct ModularPipeline {
rx: Receiver<AudioFrame>,
stop_rx: Receiver<()>,
Expand All @@ -20,6 +21,7 @@ pub struct ModularPipeline {
stats: RuntimeStatsHandle,
}

#[allow(dead_code)]
impl ModularPipeline {
fn select_total_pipeline_delay(input_timestamp: u64, output_timestamp: u64, processing_delay: u64) -> u64 {
let timestamp_delay = output_timestamp.saturating_sub(input_timestamp);
Expand Down Expand Up @@ -399,6 +401,43 @@ mod tests {
fn reset(&mut self) {}
}

/// Test processor that forwards every frame unchanged and hands out exactly
/// one extra frame from `drain_ready`.
struct PassAndDrainOnce {
    /// Becomes `true` once the extra frame has been emitted.
    drained: bool,
}

impl AudioProcessor for PassAndDrainOnce {
    /// Passes the incoming frame straight through.
    fn process(&mut self, frame: AudioFrame) -> anyhow::Result<Option<AudioFrame>> {
        Ok(Some(frame))
    }

    /// Emits the fixed extra frame on the first call and `None` afterwards.
    fn drain_ready(&mut self) -> anyhow::Result<Option<AudioFrame>> {
        // Mark as drained unconditionally; the previous flag decides the outcome.
        let already_drained = std::mem::replace(&mut self.drained, true);
        if already_drained {
            return Ok(None);
        }

        let extra = AudioFrame {
            source: AudioSourceType::Microphone,
            samples: vec![0.7; 4],
            sample_rate: 48_000,
            channels: 1,
            timestamp: 123,
        };
        Ok(Some(extra))
    }

    /// Nothing is buffered, so there is nothing to flush.
    fn flush(&mut self) -> Vec<AudioFrame> {
        Vec::new()
    }

    fn reset(&mut self) {}
}

/// Test processor whose `process` and `drain_ready` always fail.
struct ErrorProcessor;

impl AudioProcessor for ErrorProcessor {
    /// Always fails with the message "boom".
    fn process(&mut self, _frame: AudioFrame) -> anyhow::Result<Option<AudioFrame>> {
        anyhow::bail!("boom")
    }

    /// Always fails with the message "drain boom".
    fn drain_ready(&mut self) -> anyhow::Result<Option<AudioFrame>> {
        anyhow::bail!("drain boom")
    }

    /// Returns no frames.
    fn flush(&mut self) -> Vec<AudioFrame> {
        Vec::new()
    }

    fn reset(&mut self) {}
}

#[test]
fn total_pipeline_delay_prefers_timestamp_when_plausible() {
let d = ModularPipeline::select_total_pipeline_delay(1_000, 2_000, 777);
Expand Down Expand Up @@ -437,4 +476,63 @@ mod tests {
let out = ModularPipeline::process_through_processors_static(&mut processors, input, &stats);
assert!(out.is_empty());
}

/// With one pass-through processor that drains once, the static pipeline is
/// expected to return two frames, the second carrying timestamp 123.
#[test]
fn static_pipeline_collects_drain_ready_frames() {
    let mut chain: Vec<Box<dyn AudioProcessor>> = Vec::new();
    chain.push(Box::new(PassAndDrainOnce { drained: false }));
    let stats = RuntimeStatsHandle::new();
    let frame = AudioFrame {
        source: AudioSourceType::Microphone,
        samples: vec![0.1; 4],
        sample_rate: 48_000,
        channels: 1,
        timestamp: 0,
    };

    let produced = ModularPipeline::process_through_processors_static(&mut chain, frame, &stats);

    assert_eq!(produced.len(), 2);
    assert_eq!(produced[0].samples.len(), 4);
    assert_eq!(produced[1].timestamp, 123);
}

/// With a processor that fails in both `process` and `drain_ready`, the
/// pipeline is expected to return nothing and count one error of each kind.
#[test]
fn static_pipeline_counts_process_and_drain_errors() {
    let mut chain: Vec<Box<dyn AudioProcessor>> = Vec::new();
    chain.push(Box::new(ErrorProcessor));
    let stats = RuntimeStatsHandle::new();
    let frame = AudioFrame {
        source: AudioSourceType::Microphone,
        samples: vec![0.1; 4],
        sample_rate: 48_000,
        channels: 1,
        timestamp: 0,
    };

    let produced = ModularPipeline::process_through_processors_static(&mut chain, frame, &stats);
    assert!(produced.is_empty());

    let counters = stats.snapshot();
    assert_eq!(counters.processor_errors, 1);
    assert_eq!(counters.processor_drain_errors, 1);
}

/// With two chained pass-through processors, the output must be non-empty and
/// must still contain a frame with the input's timestamp (42).
#[test]
fn static_pipeline_processes_multiple_processors_in_order() {
    let mut chain: Vec<Box<dyn AudioProcessor>> = Vec::new();
    for _ in 0..2 {
        chain.push(Box::new(PassAndDrainOnce { drained: false }));
    }
    let stats = RuntimeStatsHandle::new();
    let frame = AudioFrame {
        source: AudioSourceType::Microphone,
        samples: vec![0.2; 8],
        sample_rate: 48_000,
        channels: 1,
        timestamp: 42,
    };

    let produced = ModularPipeline::process_through_processors_static(&mut chain, frame, &stats);

    assert!(!produced.is_empty());
    // At least one frame should survive both processors.
    let input_survived = produced.iter().map(|f| f.timestamp).any(|ts| ts == 42);
    assert!(input_survived);
}
}
61 changes: 57 additions & 4 deletions src/processors/aec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ impl AecProcessor {

pub fn new(config: AudioProcessingConfig, stats: RuntimeStatsHandle) -> Self {
let apm = if config.enable_aec {
Some(Self::create_apm(&config))
Self::create_apm(&config)
} else {
None
};
Expand Down Expand Up @@ -265,11 +265,17 @@ impl AecProcessor {
apm_config
}

fn create_apm(config: &AudioProcessingConfig) -> Processor {
let apm = Processor::new(48_000).expect("Failed to create WebRTC Processor for AEC");
fn create_apm(config: &AudioProcessingConfig) -> Option<Processor> {
let apm = match Processor::new(48_000) {
Ok(apm) => apm,
Err(err) => {
eprintln!("Warning: failed to create AEC processor: {}", err);
return None;
}
};
let delay_ms = config.aec_stream_delay_ms.max(0);
apm.set_config(Self::build_apm_config(delay_ms));
apm
Some(apm)
}
}

Expand Down Expand Up @@ -375,4 +381,51 @@ mod tests {
let out = aec.process(frame(AudioSourceType::System, 0.5)).unwrap();
assert!(out.is_none());
}

/// After eight identical `tune_delay_on_the_fly(4.0, Some(10))` calls the tuner
/// must be frozen, the applied delay must equal the best delay, and exactly
/// one freeze event must be recorded.
#[test]
fn tuner_freezes_after_stable_high_erle() {
    let stats_handle = RuntimeStatsHandle::new();
    let mut processor = AecProcessor::new(config(false, true), stats_handle.clone());

    (0..8).for_each(|_| {
        // The per-call result is irrelevant here; only the final state matters.
        let _ = processor.tune_delay_on_the_fly(4.0, Some(10));
    });

    assert!(processor.tuner_frozen);
    assert_eq!(processor.applied_delay_ms, processor.tuner_best_delay_ms);

    let snapshot = stats_handle.snapshot();
    assert_eq!(snapshot.aec_tuner.freeze_events, 1);
}

/// Starting at an applied delay of 40 with a recorded best of 20, one call to
/// `tune_delay_on_the_fly(0.0, Some(10))` must report a change, restore the
/// applied delay to 20, and record one rollback event.
#[test]
fn tuner_rolls_back_to_best_when_quality_drops() {
    let stats_handle = RuntimeStatsHandle::new();
    let mut processor = AecProcessor::new(config(false, true), stats_handle.clone());

    // Seed a state where the current delay differs from the best known one.
    processor.applied_delay_ms = 40;
    processor.tuner_best_delay_ms = 20;
    processor.tuner_best_erle = Some(5.0);
    processor.tuner_erle_ema = Some(5.0);

    assert!(processor.tune_delay_on_the_fly(0.0, Some(10)));
    assert_eq!(processor.applied_delay_ms, 20);

    let snapshot = stats_handle.snapshot();
    assert_eq!(snapshot.aec_tuner.rollback_events, 1);
}

/// With an applied delay of 98, a step of 8, and a maximum of 100, one tuning
/// call must report a change and leave the applied delay at exactly 100.
#[test]
fn tuner_clamps_delay_to_max_bound() {
    let stats_handle = RuntimeStatsHandle::new();
    let mut processor = AecProcessor::new(config(false, true), stats_handle);

    // Seed a state where one upward step would exceed the configured maximum.
    processor.tuner_max_delay_ms = 100;
    processor.tuner_best_delay_ms = 95;
    processor.applied_delay_ms = 98;
    processor.tuner_step_ms = 8;
    processor.tuner_direction = 1;
    processor.tuner_best_erle = Some(1.0);
    processor.tuner_erle_ema = Some(1.0);

    assert!(processor.tune_delay_on_the_fly(1.0, Some(10)));
    assert_eq!(processor.applied_delay_ms, 100);
    assert!(processor.applied_delay_ms <= processor.tuner_max_delay_ms);
}
}
16 changes: 11 additions & 5 deletions src/processors/noise_suppression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ pub struct NoiseSuppressionProcessor {
impl NoiseSuppressionProcessor {
pub fn new(config: AudioProcessingConfig) -> Self {
let apm = if config.enable_ns {
Some(Self::create_ns_apm(&config))
Self::create_ns_apm(&config)
} else {
None
};
Expand All @@ -25,8 +25,14 @@ impl NoiseSuppressionProcessor {
}
}

fn create_ns_apm(_config: &AudioProcessingConfig) -> Processor {
let apm = Processor::new(48_000).expect("Failed to create WebRTC Processor for Noise Suppression");
fn create_ns_apm(_config: &AudioProcessingConfig) -> Option<Processor> {
let apm = match Processor::new(48_000) {
Ok(apm) => apm,
Err(err) => {
eprintln!("Warning: failed to create NS processor: {}", err);
return None;
}
};

let mut apm_config = Config::default();

Expand All @@ -44,7 +50,7 @@ impl NoiseSuppressionProcessor {
apm_config.gain_controller = None;

apm.set_config(apm_config);
apm
Some(apm)
}
}

Expand Down Expand Up @@ -88,7 +94,7 @@ impl AudioProcessor for NoiseSuppressionProcessor {
fn reset(&mut self) {
// Reset NS processor state if needed
if self.config.enable_ns {
self.apm = Some(Self::create_ns_apm(&self.config));
self.apm = Self::create_ns_apm(&self.config);
} else {
self.apm = None;
}
Expand Down
Loading
Loading