diff --git a/src/capture.rs b/src/capture.rs index f025c9a..445558f 100644 --- a/src/capture.rs +++ b/src/capture.rs @@ -16,6 +16,22 @@ struct AudioOutputHandler { source_type: AudioSourceType, } +/// Converts a CMTime (`value / timescale` seconds) to nanoseconds. +/// Returns 0 for negative or invalid input and saturates at `u64::MAX`. +fn cm_time_to_ns(value: i64, timescale: i32) -> u64 { + if timescale <= 0 || value < 0 { + return 0; + } + + // Widen to u128: i64::MAX * 1e9 < 2^93, so the product cannot overflow. + let value = value as u128; + let timescale = timescale as u128; + let ns = value * 1_000_000_000 / timescale; + + // Saturate instead of overflowing when the result does not fit in u64. + ns.min(u64::MAX as u128) as u64 +} + impl SCStreamOutputTrait for AudioOutputHandler { fn did_output_sample_buffer(&self, sample_buffer: CMSampleBuffer, of_type: SCStreamOutputType) { let is_target = match (self.source_type, of_type) { @@ -31,8 +47,8 @@ impl SCStreamOutputTrait for AudioOutputHandler { if num_buffers == 0 { return; } - // Get first buffer to check channels - let first_buffer = audio_data.get(0).unwrap(); + // Get first buffer to check channels; skip malformed empty payloads. 
+ let Some(first_buffer) = audio_data.get(0) else { return; }; let channels_per_buffer = first_buffer.number_channels as usize; if num_buffers == 1 { @@ -59,22 +75,7 @@ impl SCStreamOutputTrait for AudioOutputHandler { if !interleaved_samples.is_empty() { let pts = sample_buffer.presentation_timestamp(); - // Convert CMTime to nanoseconds (value / timescale * 1e9) - // Use saturating arithmetic to prevent overflow - let timestamp = if pts.timescale > 0 && pts.value >= 0 { - let value = pts.value as u64; - let timescale = pts.timescale as u64; - // Check for potential overflow before multiplication - if value <= u64::MAX / 1_000_000_000 { - (value * 1_000_000_000) / timescale - } else { - // For very large values, do division first to prevent overflow - value / timescale * 1_000_000_000 + - (value % timescale * 1_000_000_000) / timescale - } - } else { - 0 - }; + let timestamp = cm_time_to_ns(pts.value, pts.timescale); let packet = AudioFrame { source: self.source_type, @@ -154,3 +155,28 @@ pub fn spawn_capture_engine( stream.start_capture().map_err(|e| anyhow!("Failed to start capture: {}", e))?; Ok(stream) } + +#[cfg(test)] +mod tests { + use super::cm_time_to_ns; + + #[test] + fn cm_time_to_ns_converts_basic_values() { + assert_eq!(cm_time_to_ns(1, 1), 1_000_000_000); + assert_eq!(cm_time_to_ns(480, 48_000), 10_000_000); + } + + #[test] + fn cm_time_to_ns_rejects_invalid_input() { + assert_eq!(cm_time_to_ns(-1, 1), 0); + assert_eq!(cm_time_to_ns(1, 0), 0); + assert_eq!(cm_time_to_ns(1, -1), 0); + } + + #[test] + fn cm_time_to_ns_handles_large_values_without_overflow() { + let large = i64::MAX; + let out = cm_time_to_ns(large, 48_000); + assert!(out > 0); + } +} diff --git a/src/config.rs b/src/config.rs index 24adda1..1db0b9a 100644 --- a/src/config.rs +++ b/src/config.rs @@ -1,6 +1,6 @@ use pyo3::prelude::*; -#[pyclass] +#[pyclass(from_py_object)] #[derive(Clone, Debug)] pub struct AudioProcessingConfig { #[pyo3(get, set)] diff --git a/src/modular_pipeline.rs 
b/src/modular_pipeline.rs index 03f1f5c..4741428 100644 --- a/src/modular_pipeline.rs +++ b/src/modular_pipeline.rs @@ -10,6 +10,7 @@ use numpy::ToPyArray; use std::time::Instant; /// Modular pipeline that processes audio through a chain of processors +#[allow(dead_code)] pub struct ModularPipeline { rx: Receiver, stop_rx: Receiver<()>, @@ -20,6 +21,7 @@ pub struct ModularPipeline { stats: RuntimeStatsHandle, } +#[allow(dead_code)] impl ModularPipeline { fn select_total_pipeline_delay(input_timestamp: u64, output_timestamp: u64, processing_delay: u64) -> u64 { let timestamp_delay = output_timestamp.saturating_sub(input_timestamp); @@ -399,6 +401,43 @@ mod tests { fn reset(&mut self) {} } + struct PassAndDrainOnce { + drained: bool, + } + impl AudioProcessor for PassAndDrainOnce { + fn process(&mut self, frame: AudioFrame) -> anyhow::Result> { + Ok(Some(frame)) + } + fn drain_ready(&mut self) -> anyhow::Result> { + if self.drained { + Ok(None) + } else { + self.drained = true; + Ok(Some(AudioFrame { + source: AudioSourceType::Microphone, + samples: vec![0.7; 4], + sample_rate: 48_000, + channels: 1, + timestamp: 123, + })) + } + } + fn flush(&mut self) -> Vec { Vec::new() } + fn reset(&mut self) {} + } + + struct ErrorProcessor; + impl AudioProcessor for ErrorProcessor { + fn process(&mut self, _frame: AudioFrame) -> anyhow::Result> { + Err(anyhow::anyhow!("boom")) + } + fn drain_ready(&mut self) -> anyhow::Result> { + Err(anyhow::anyhow!("drain boom")) + } + fn flush(&mut self) -> Vec { Vec::new() } + fn reset(&mut self) {} + } + #[test] fn total_pipeline_delay_prefers_timestamp_when_plausible() { let d = ModularPipeline::select_total_pipeline_delay(1_000, 2_000, 777); @@ -437,4 +476,63 @@ mod tests { let out = ModularPipeline::process_through_processors_static(&mut processors, input, &stats); assert!(out.is_empty()); } + + #[test] + fn static_pipeline_collects_drain_ready_frames() { + let mut processors: Vec> = + vec![Box::new(PassAndDrainOnce { drained: false 
})]; + let stats = RuntimeStatsHandle::new(); + let input = AudioFrame { + source: AudioSourceType::Microphone, + samples: vec![0.1; 4], + sample_rate: 48_000, + channels: 1, + timestamp: 0, + }; + + let out = ModularPipeline::process_through_processors_static(&mut processors, input, &stats); + assert_eq!(out.len(), 2); + assert_eq!(out[0].samples.len(), 4); + assert_eq!(out[1].timestamp, 123); + } + + #[test] + fn static_pipeline_counts_process_and_drain_errors() { + let mut processors: Vec> = vec![Box::new(ErrorProcessor)]; + let stats = RuntimeStatsHandle::new(); + let input = AudioFrame { + source: AudioSourceType::Microphone, + samples: vec![0.1; 4], + sample_rate: 48_000, + channels: 1, + timestamp: 0, + }; + + let out = ModularPipeline::process_through_processors_static(&mut processors, input, &stats); + assert!(out.is_empty()); + let snap = stats.snapshot(); + assert_eq!(snap.processor_errors, 1); + assert_eq!(snap.processor_drain_errors, 1); + } + + #[test] + fn static_pipeline_processes_multiple_processors_in_order() { + let mut processors: Vec> = vec![ + Box::new(PassAndDrainOnce { drained: false }), + Box::new(PassAndDrainOnce { drained: false }), + ]; + let stats = RuntimeStatsHandle::new(); + let input = AudioFrame { + source: AudioSourceType::Microphone, + samples: vec![0.2; 8], + sample_rate: 48_000, + channels: 1, + timestamp: 42, + }; + + let out = ModularPipeline::process_through_processors_static(&mut processors, input, &stats); + assert!(!out.is_empty()); + // At least one frame should survive both processors. 
+ assert!(out.iter().any(|f| f.timestamp == 42)); + } } diff --git a/src/processors/aec.rs b/src/processors/aec.rs index 044c7b7..f1e580a 100644 --- a/src/processors/aec.rs +++ b/src/processors/aec.rs @@ -39,7 +39,7 @@ impl AecProcessor { pub fn new(config: AudioProcessingConfig, stats: RuntimeStatsHandle) -> Self { let apm = if config.enable_aec { - Some(Self::create_apm(&config)) + Self::create_apm(&config) } else { None }; @@ -265,11 +265,17 @@ impl AecProcessor { apm_config } - fn create_apm(config: &AudioProcessingConfig) -> Processor { - let apm = Processor::new(48_000).expect("Failed to create WebRTC Processor for AEC"); + fn create_apm(config: &AudioProcessingConfig) -> Option { + let apm = match Processor::new(48_000) { + Ok(apm) => apm, + Err(err) => { + eprintln!("Warning: failed to create AEC processor: {}", err); + return None; + } + }; let delay_ms = config.aec_stream_delay_ms.max(0); apm.set_config(Self::build_apm_config(delay_ms)); - apm + Some(apm) } } @@ -375,4 +381,51 @@ mod tests { let out = aec.process(frame(AudioSourceType::System, 0.5)).unwrap(); assert!(out.is_none()); } + + #[test] + fn tuner_freezes_after_stable_high_erle() { + let stats = RuntimeStatsHandle::new(); + let mut aec = AecProcessor::new(config(false, true), stats.clone()); + + for _ in 0..8 { + let _ = aec.tune_delay_on_the_fly(4.0, Some(10)); + } + + assert!(aec.tuner_frozen); + assert_eq!(aec.applied_delay_ms, aec.tuner_best_delay_ms); + assert_eq!(stats.snapshot().aec_tuner.freeze_events, 1); + } + + #[test] + fn tuner_rolls_back_to_best_when_quality_drops() { + let stats = RuntimeStatsHandle::new(); + let mut aec = AecProcessor::new(config(false, true), stats.clone()); + aec.applied_delay_ms = 40; + aec.tuner_best_delay_ms = 20; + aec.tuner_best_erle = Some(5.0); + aec.tuner_erle_ema = Some(5.0); + + let tuned = aec.tune_delay_on_the_fly(0.0, Some(10)); + assert!(tuned); + assert_eq!(aec.applied_delay_ms, 20); + assert_eq!(stats.snapshot().aec_tuner.rollback_events, 1); + 
} + + #[test] + fn tuner_clamps_delay_to_max_bound() { + let stats = RuntimeStatsHandle::new(); + let mut aec = AecProcessor::new(config(false, true), stats); + aec.tuner_max_delay_ms = 100; + aec.tuner_best_delay_ms = 95; + aec.applied_delay_ms = 98; + aec.tuner_step_ms = 8; + aec.tuner_direction = 1; + aec.tuner_best_erle = Some(1.0); + aec.tuner_erle_ema = Some(1.0); + + let tuned = aec.tune_delay_on_the_fly(1.0, Some(10)); + assert!(tuned); + assert_eq!(aec.applied_delay_ms, 100); + assert!(aec.applied_delay_ms <= aec.tuner_max_delay_ms); + } } diff --git a/src/processors/noise_suppression.rs b/src/processors/noise_suppression.rs index f8d270f..3e8f1e1 100644 --- a/src/processors/noise_suppression.rs +++ b/src/processors/noise_suppression.rs @@ -14,7 +14,7 @@ pub struct NoiseSuppressionProcessor { impl NoiseSuppressionProcessor { pub fn new(config: AudioProcessingConfig) -> Self { let apm = if config.enable_ns { - Some(Self::create_ns_apm(&config)) + Self::create_ns_apm(&config) } else { None }; @@ -25,8 +25,14 @@ impl NoiseSuppressionProcessor { } } - fn create_ns_apm(_config: &AudioProcessingConfig) -> Processor { - let apm = Processor::new(48_000).expect("Failed to create WebRTC Processor for Noise Suppression"); + fn create_ns_apm(_config: &AudioProcessingConfig) -> Option { + let apm = match Processor::new(48_000) { + Ok(apm) => apm, + Err(err) => { + eprintln!("Warning: failed to create NS processor: {}", err); + return None; + } + }; let mut apm_config = Config::default(); @@ -44,7 +50,7 @@ impl NoiseSuppressionProcessor { apm_config.gain_controller = None; apm.set_config(apm_config); - apm + Some(apm) } } @@ -88,7 +94,7 @@ impl AudioProcessor for NoiseSuppressionProcessor { fn reset(&mut self) { // Reset NS processor state if needed if self.config.enable_ns { - self.apm = Some(Self::create_ns_apm(&self.config)); + self.apm = Self::create_ns_apm(&self.config); } else { self.apm = None; } diff --git a/src/processors/resample.rs 
b/src/processors/resample.rs index 1de3f60..1ed4d98 100644 --- a/src/processors/resample.rs +++ b/src/processors/resample.rs @@ -26,6 +26,7 @@ impl<'a> Adapter<'a, f32> for PlanarBuffer<'a> { struct StreamState { resampler: Option>, + source_rate: u32, input_buffer: VecDeque, output_queue: VecDeque, current_timestamp: u64, @@ -36,20 +37,30 @@ struct StreamState { impl StreamState { fn new(source_rate: u32, target_rate: u32, target_channels: u16, chunk_size: usize) -> Self { let resampler = if source_rate != target_rate { - Some(Fft::::new( + match Fft::::new( source_rate as usize, target_rate as usize, chunk_size, 1, target_channels as usize, FixedSync::Input, - ).expect("Failed to create resampler")) + ) { + Ok(resampler) => Some(resampler), + Err(err) => { + eprintln!( + "Warning: failed to create resampler {}->{}Hz: {}. Falling back to passthrough.", + source_rate, target_rate, err + ); + None + } + } } else { None }; Self { resampler, + source_rate, input_buffer: VecDeque::with_capacity(chunk_size * 4), output_queue: VecDeque::with_capacity(chunk_size * 4), current_timestamp: 0, @@ -158,12 +169,18 @@ impl ResampleProcessor { if channels == 1 { for _ in 0..chunk_size { - planar_data[0].push(state.input_buffer.pop_front().unwrap()); + let Some(sample) = state.input_buffer.pop_front() else { + return results; + }; + planar_data[0].push(sample); } } else { for _ in 0..chunk_size { for channel_buf in &mut planar_data { - channel_buf.push(state.input_buffer.pop_front().unwrap()); + let Some(sample) = state.input_buffer.pop_front() else { + return results; + }; + channel_buf.push(sample); } } } @@ -178,12 +195,20 @@ impl ResampleProcessor { let mut samples = Vec::new(); if channels == 1 { for i in 0..output.frames() { - samples.push(output.read_sample(0, i).unwrap()); + if let Some(sample) = output.read_sample(0, i) { + samples.push(sample); + } else { + return results; + } } } else { for i in 0..output.frames() { for ch in 0..channels { - 
samples.push(output.read_sample(ch, i).unwrap()); + if let Some(sample) = output.read_sample(ch, i) { + samples.push(sample); + } else { + return results; + } } } } @@ -214,7 +239,12 @@ impl ResampleProcessor { results.push(AudioFrame { source, samples, - sample_rate: target_rate, + // Frames stay at the source rate only in the passthrough fallback (no resampler). + sample_rate: if state.resampler.is_some() { + target_rate + } else { + state.source_rate + }, channels: target_channels, timestamp: frame_ts, }); diff --git a/src/stats.rs b/src/stats.rs index f59827d..f3e0057 100644 --- a/src/stats.rs +++ b/src/stats.rs @@ -103,6 +103,7 @@ impl RuntimeStatsHandle { } } + #[allow(dead_code)] pub fn reset(&self) { if let Ok(mut stats) = self.inner.lock() { *stats = RuntimeStats::default();