diff --git a/crates/wavekat-turn/Cargo.toml b/crates/wavekat-turn/Cargo.toml index 9700852..8f2bc82 100644 --- a/crates/wavekat-turn/Cargo.toml +++ b/crates/wavekat-turn/Cargo.toml @@ -11,10 +11,11 @@ readme = "../../README.md" keywords = ["turn-detection", "voice", "audio", "telephony", "wavekat"] categories = ["multimedia::audio"] exclude = ["CHANGELOG.md"] +build = "build.rs" [features] default = [] -pipecat = ["dep:ort", "dep:ndarray"] +pipecat = ["dep:ort", "dep:ndarray", "dep:realfft", "dep:ureq"] livekit = ["dep:ort", "dep:ndarray"] [dependencies] @@ -22,8 +23,12 @@ wavekat-core = "0.0.2" thiserror = "2" # ONNX backends (optional) -ort = { version = "2.0.0-rc.12", optional = true } -ndarray = { version = "0.16", optional = true } +ort = { version = "2.0.0-rc.12", optional = true, features = ["ndarray"] } +ndarray = { version = "0.17", optional = true } +realfft = { version = "3", optional = true } + +[build-dependencies] +ureq = { version = "3", optional = true } [dev-dependencies] hound = "3.5" diff --git a/crates/wavekat-turn/build.rs b/crates/wavekat-turn/build.rs new file mode 100644 index 0000000..4bb440a --- /dev/null +++ b/crates/wavekat-turn/build.rs @@ -0,0 +1,101 @@ +//! Build script for wavekat-turn. +//! +//! Downloads ONNX model(s) at build time when the corresponding feature flag +//! is enabled. Models are written to `OUT_DIR` and embedded via `include_bytes!`. +//! +//! # Environment variables +//! +//! | Variable | Effect | +//! |----------------------------------|--------------------------------------------| +//! | `PIPECAT_SMARTTURN_MODEL_PATH` | Use this local file instead of downloading | +//! | `PIPECAT_SMARTTURN_MODEL_URL` | Override the download URL | +//! 
| `DOCS_RS` | Write zero-byte placeholder (no network) | + +#[allow(unused_imports)] +use std::env; +#[allow(unused_imports)] +use std::fs; +#[allow(unused_imports)] +use std::path::Path; + +fn main() { + // docs.rs builds with --network none; write empty placeholders so + // include_bytes! compiles without downloading anything. + if env::var("DOCS_RS").is_err() { + #[cfg(feature = "pipecat")] + setup_pipecat_model(); + } else { + #[cfg(feature = "pipecat")] + { + let out_dir = env::var("OUT_DIR").expect("OUT_DIR not set"); + let model_path = Path::new(&out_dir).join("smart-turn-v3.2-cpu.onnx"); + if !model_path.exists() { + fs::write(&model_path, b"").expect("failed to write placeholder model"); + } + } + } +} + +#[cfg(feature = "pipecat")] +fn setup_pipecat_model() { + const DEFAULT_MODEL_URL: &str = + "https://huggingface.co/pipecat-ai/smart-turn-v3/resolve/main/smart-turn-v3.2-cpu.onnx"; + const MODEL_FILE: &str = "smart-turn-v3.2-cpu.onnx"; + // Bump this string when updating the default model URL so cached builds + // re-download the new version. 
+ const MODEL_VERSION: &str = "v3.2-cpu"; + + println!("cargo:rerun-if-env-changed=PIPECAT_SMARTTURN_MODEL_PATH"); + println!("cargo:rerun-if-env-changed=PIPECAT_SMARTTURN_MODEL_URL"); + + let out_dir = env::var("OUT_DIR").expect("OUT_DIR not set"); + let model_dest = Path::new(&out_dir).join(MODEL_FILE); + let version_marker = Path::new(&out_dir).join("smart-turn.version"); + + // Option 1: caller provides a local model file + if let Ok(local_path) = env::var("PIPECAT_SMARTTURN_MODEL_PATH") { + let local_path = Path::new(&local_path); + if !local_path.exists() { + panic!( + "PIPECAT_SMARTTURN_MODEL_PATH points to a non-existent file: {}", + local_path.display() + ); + } + println!( + "cargo:warning=Using local Pipecat Smart Turn model: {}", + local_path.display() + ); + fs::copy(local_path, &model_dest).expect("failed to copy local model file"); + println!("cargo:rerun-if-changed={}", local_path.display()); + return; + } + + // Skip download if a matching version is already cached + let cached_version = fs::read_to_string(&version_marker).unwrap_or_default(); + if model_dest.exists() && cached_version.trim() == MODEL_VERSION { + return; + } + + // Option 2: download (caller may override the URL) + let url = + env::var("PIPECAT_SMARTTURN_MODEL_URL").unwrap_or_else(|_| DEFAULT_MODEL_URL.to_string()); + + println!("cargo:warning=Downloading Pipecat Smart Turn model ({MODEL_VERSION}) from {url}"); + + let response = ureq::get(&url) + .call() + .unwrap_or_else(|e| panic!("failed to download Pipecat Smart Turn model from {url}: {e}")); + + let bytes = response + .into_body() + .read_to_vec() + .expect("failed to read model bytes"); + + fs::write(&model_dest, &bytes).expect("failed to write model file"); + fs::write(&version_marker, MODEL_VERSION).expect("failed to write version marker"); + + println!( + "cargo:warning=Pipecat Smart Turn model ({MODEL_VERSION}) downloaded to {}", + model_dest.display() + ); +} diff --git a/crates/wavekat-turn/src/audio/pipecat.rs 
b/crates/wavekat-turn/src/audio/pipecat.rs index 2c9e168..4554b1d 100644 --- a/crates/wavekat-turn/src/audio/pipecat.rs +++ b/crates/wavekat-turn/src/audio/pipecat.rs @@ -4,38 +4,526 @@ //! Expects 16 kHz f32 PCM input. Telephony audio at 8 kHz must be //! upsampled before feeding to this detector. //! -//! - Model size: ~8 MB (int8 quantized ONNX) -//! - Inference: ~12 ms on CPU +//! # Model +//! +//! - Source: +//! - File: `smart-turn-v3.2-cpu.onnx` (int8 quantized, ~8 MB) //! - License: BSD 2-Clause +//! +//! # Tensor specification +//! +//! | Role | Name | Shape | Dtype | +//! |--------|------------------|----------------|---------| +//! | Input | `input_features` | `[B, 80, 800]` | float32 | +//! | Output | `logits` | `[B, 1]` | float32 | +//! +//! Despite the name, `logits` is a **sigmoid probability** P(turn complete) +//! in [0, 1] — the sigmoid is fused into the model before ONNX export. +//! Threshold: `probability > 0.5` → `TurnState::Finished`. +//! +//! # Mel-feature specification +//! +//! The model was trained with HuggingFace `WhisperFeatureExtractor(chunk_length=8)`: +//! +//! | Parameter | Value | +//! |---------------|--------------------------------| +//! | Sample rate | 16 000 Hz | +//! | n_fft | 400 samples (25 ms) | +//! | hop_length | 160 samples (10 ms) | +//! | n_mels | 80 | +//! | Freq range | 0 – 8 000 Hz | +//! | Mel scale | Slaney (NOT HTK) | +//! | Window | Hann (periodic, size 400) | +//! | Pre-emphasis | None | +//! | Log | log10 with ε = 1e-10 | +//! | Normalization | clamp(max − 8), (x + 4) / 4 | +//! +//! # Audio buffer +//! +//! - Exactly **8 seconds = 128 000 samples** at 16 kHz. +//! - Shorter input: **front-padded** with zeros (audio is at the end). +//! - Longer input: the **last** 8 s is used (oldest samples discarded). 
+ +use std::collections::VecDeque; +use std::path::Path; +use std::sync::Arc; +use std::time::Instant; + +use ndarray::{s, Array2, Array3}; +use ort::{inputs, value::Tensor}; +use realfft::num_complex::Complex; +use realfft::{RealFftPlanner, RealToComplex}; + +use crate::onnx; +use crate::{AudioFrame, AudioTurnDetector, StageTiming, TurnError, TurnPrediction, TurnState}; + +// --------------------------------------------------------------------------- +// Constants +// --------------------------------------------------------------------------- + +/// Sample rate the model expects. +const SAMPLE_RATE: u32 = 16_000; +/// FFT window size in samples (25 ms at 16 kHz). +const N_FFT: usize = 400; +/// STFT hop length in samples (10 ms at 16 kHz). +const HOP_LENGTH: usize = 160; +/// Number of mel filterbank bins. +const N_MELS: usize = 80; +/// Number of STFT frames the model expects (8 s × 100 fps). +const N_FRAMES: usize = 800; +/// FFT frequency bins: N_FFT/2 + 1. +const N_FREQS: usize = N_FFT / 2 + 1; // 201 +/// Ring buffer capacity: 8 s × 16 kHz. +const RING_CAPACITY: usize = 8 * SAMPLE_RATE as usize; // 128 000 + +/// Embedded ONNX model bytes, downloaded by build.rs at compile time. +const MODEL_BYTES: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/smart-turn-v3.2-cpu.onnx")); + +// --------------------------------------------------------------------------- +// Mel feature extractor +// --------------------------------------------------------------------------- + +/// Pre-computed Whisper-style log-mel feature extractor. +/// +/// All expensive setup (filterbank, window, FFT plan) happens once in [`new`]. +/// [`MelExtractor::extract`] is then called per inference. +struct MelExtractor { + /// Slaney-normalised mel filterbank: shape [N_MELS, N_FREQS]. + mel_filters: Array2, + /// Periodic Hann window of length N_FFT. + hann_window: Vec, + /// Reusable forward real FFT plan. + fft: Arc>, + /// Reusable scratch buffer for the FFT. 
+ fft_scratch: Vec>, + /// Reusable output spectrum buffer (N_FREQS complex values). + spectrum_buf: Vec>, + /// Cached power spectrogram [N_FREQS × (N_FRAMES+1)] from the previous call. + /// Enables incremental STFT: only new frames are recomputed. + cached_power_spec: Option>, + /// Cached mel spectrogram [N_MELS × N_FRAMES] from the previous call. + /// Enables incremental mel filterbank: only new columns are recomputed. + cached_mel_spec: Option>, +} + +impl MelExtractor { + fn new() -> Self { + let mel_filters = build_mel_filters( + SAMPLE_RATE as usize, + N_FFT, + N_MELS, + 0.0, + SAMPLE_RATE as f32 / 2.0, + ); + let hann_window = periodic_hann(N_FFT); + + let mut planner = RealFftPlanner::::new(); + let fft = planner.plan_fft_forward(N_FFT); + let fft_scratch = fft.make_scratch_vec(); + let spectrum_buf = fft.make_output_vec(); + + Self { + mel_filters, + hann_window, + fft, + fft_scratch, + spectrum_buf, + cached_power_spec: None, + cached_mel_spec: None, + } + } + + /// Compute a [N_MELS × N_FRAMES] log-mel spectrogram from exactly + /// `RING_CAPACITY` samples of 16 kHz mono audio. + /// + /// `shift_frames` is how many STFT frames worth of new audio were added + /// since the last call. When a valid cache exists and `shift_frames` is + /// in range, only the last `shift_frames` columns of the power spectrogram + /// are recomputed; the rest are copied from the shifted cache. + fn extract(&mut self, audio: &[f32], shift_frames: usize) -> Array2 { + debug_assert_eq!(audio.len(), RING_CAPACITY); + + // ---- Center-pad: N_FFT/2 zeros on each side → 128 400 samples ---- + // This replicates librosa/PyTorch `center=True` STFT behaviour, which + // gives exactly N_FRAMES + 1 = 801 frames; we discard the last one. 
+ let pad = N_FFT / 2; + let mut padded = vec![0.0f32; pad + audio.len() + pad]; + padded[pad..pad + audio.len()].copy_from_slice(audio); + + // n_total = (128 400 − 400) / 160 + 1 = 801 + let n_total_frames = (padded.len() - N_FFT) / HOP_LENGTH + 1; + + // ---- Incremental STFT ---- + // If we have a cached power spec and shift_frames < n_total_frames, + // reuse the unchanged frames by shifting the cache left and only + // computing the `shift_frames` new columns at the end. + let first_new_frame = match &self.cached_power_spec { + Some(cached) if shift_frames > 0 && shift_frames < n_total_frames => { + let kept = n_total_frames - shift_frames; + let mut power_spec = Array2::::zeros((N_FREQS, n_total_frames)); + power_spec + .slice_mut(s![.., ..kept]) + .assign(&cached.slice(s![.., shift_frames..])); + self.cached_power_spec = Some(power_spec); + kept // only compute frames [kept..n_total_frames] + } + _ => { + self.cached_power_spec = Some(Array2::::zeros((N_FREQS, n_total_frames))); + 0 // cold start: compute all frames + } + }; + + let power_spec = self.cached_power_spec.as_mut().unwrap(); + let mut frame_buf = vec![0.0f32; N_FFT]; + + for frame_idx in first_new_frame..n_total_frames { + let start = frame_idx * HOP_LENGTH; + // Apply periodic Hann window + for (i, (&s, &w)) in padded[start..start + N_FFT] + .iter() + .zip(self.hann_window.iter()) + .enumerate() + { + frame_buf[i] = s * w; + } + + self.fft + .process_with_scratch( + &mut frame_buf, + &mut self.spectrum_buf, + &mut self.fft_scratch, + ) + .expect("FFT failed: internal buffer size mismatch"); + + for (k, c) in self.spectrum_buf.iter().enumerate() { + power_spec[[k, frame_idx]] = c.re * c.re + c.im * c.im; + } + } + + // Take first N_FRAMES columns (drop the trailing frame) + let power_spec_view = power_spec.slice(s![.., ..N_FRAMES]); + + // ---- Incremental mel filterbank: [N_MELS, N_FREQS] × [N_FREQS, shift_frames] ---- + // Reuse the cached mel columns for the unchanged frames; only multiply + 
// the new power-spectrum columns against the filterbank. + let mel_spec = match &self.cached_mel_spec { + Some(cached) if shift_frames > 0 && shift_frames <= N_FRAMES => { + let kept = N_FRAMES - shift_frames; + let mut ms = Array2::::zeros((N_MELS, N_FRAMES)); + // Shift old columns left + ms.slice_mut(s![.., ..kept]) + .assign(&cached.slice(s![.., shift_frames..])); + // Apply filterbank only to the new power-spectrum columns + let new_power = power_spec_view.slice(s![.., kept..]); + ms.slice_mut(s![.., kept..]) + .assign(&self.mel_filters.dot(&new_power)); + ms + } + _ => self.mel_filters.dot(&power_spec_view), + }; + self.cached_mel_spec = Some(mel_spec.clone()); + + // ---- Log10 with floor at 1e-10 ---- + let mut log_mel = mel_spec.mapv(|x| x.max(1e-10_f32).log10()); + + // ---- Dynamic range compression and normalization ---- + // Matches WhisperFeatureExtractor: clamp to [max−8, ∞], then (x+4)/4 + let max_val = log_mel.iter().cloned().fold(f32::NEG_INFINITY, f32::max); + log_mel.mapv_inplace(|x| (x.max(max_val - 8.0) + 4.0) / 4.0); + + log_mel + } + + /// Invalidate all caches (call on reset). + fn invalidate_cache(&mut self) { + self.cached_power_spec = None; + self.cached_mel_spec = None; + } +} + +// --------------------------------------------------------------------------- +// Mel filterbank construction — Slaney scale, slaney norm +// --------------------------------------------------------------------------- + +/// Convert Hz to mel (Slaney/librosa scale, NOT HTK). +fn hz_to_mel(hz: f32) -> f32 { + const F_SP: f32 = 200.0 / 3.0; // linear region slope (Hz per mel) + const MIN_LOG_HZ: f32 = 1000.0; + const MIN_LOG_MEL: f32 = MIN_LOG_HZ / F_SP; // = 15.0 + // logstep = ln(6.4) / 27 (≈ 0.068752) + let logstep = (6.4_f32).ln() / 27.0; + if hz >= MIN_LOG_HZ { + MIN_LOG_MEL + (hz / MIN_LOG_HZ).ln() / logstep + } else { + hz / F_SP + } +} + +/// Convert mel back to Hz (Slaney scale). 
+fn mel_to_hz(mel: f32) -> f32 { + const F_SP: f32 = 200.0 / 3.0; + const MIN_LOG_HZ: f32 = 1000.0; + const MIN_LOG_MEL: f32 = MIN_LOG_HZ / F_SP; + let logstep = (6.4_f32).ln() / 27.0; + if mel >= MIN_LOG_MEL { + MIN_LOG_HZ * ((mel - MIN_LOG_MEL) * logstep).exp() + } else { + mel * F_SP + } +} + +/// Build a Slaney-normalised mel filterbank of shape [n_mels, n_freqs]. +/// +/// Matches `librosa.filters.mel(sr, n_fft, n_mels, fmin, fmax, +/// norm="slaney", dtype=float32)` which is what HuggingFace's +/// `WhisperFeatureExtractor` uses internally. +fn build_mel_filters( + sr: usize, + n_fft: usize, + n_mels: usize, + f_min: f32, + f_max: f32, +) -> Array2 { + let n_freqs = n_fft / 2 + 1; + + // FFT frequency bins: 0, sr/n_fft, 2·sr/n_fft, … + let fft_freqs: Vec = (0..n_freqs) + .map(|i| i as f32 * sr as f32 / n_fft as f32) + .collect(); + + // n_mels + 2 equally-spaced mel points (edge + n_mels centres + edge) + let mel_min = hz_to_mel(f_min); + let mel_max = hz_to_mel(f_max); + let mel_pts: Vec = (0..=(n_mels + 1)) + .map(|i| mel_min + (mel_max - mel_min) * i as f32 / (n_mels + 1) as f32) + .collect(); + let hz_pts: Vec = mel_pts.iter().map(|&m| mel_to_hz(m)).collect(); + + // Build triangular filters with Slaney normalisation + let mut filters = Array2::::zeros((n_mels, n_freqs)); + for m in 0..n_mels { + let f_left = hz_pts[m]; + let f_center = hz_pts[m + 1]; + let f_right = hz_pts[m + 2]; + // Slaney norm: 2 / (right_hz − left_hz) + let enorm = 2.0 / (f_right - f_left); + + for (k, &f) in fft_freqs.iter().enumerate() { + let w = if f >= f_left && f <= f_center { + (f - f_left) / (f_center - f_left) + } else if f > f_center && f <= f_right { + (f_right - f) / (f_right - f_center) + } else { + 0.0 + }; + filters[[m, k]] = w * enorm; + } + } + filters +} + +// --------------------------------------------------------------------------- +// Hann window +// --------------------------------------------------------------------------- + +/// Periodic Hann window of 
length `n`, matching `torch.hann_window(n, periodic=True)`. +/// +/// Formula: `w[k] = 0.5 · (1 − cos(2π·k / n))` for k in 0..n. +/// This differs from the symmetric variant (which divides by n−1). +fn periodic_hann(n: usize) -> Vec { + use std::f32::consts::PI; + (0..n) + .map(|k| 0.5 * (1.0 - (2.0 * PI * k as f32 / n as f32).cos())) + .collect() +} + +// --------------------------------------------------------------------------- +// Audio preparation +// --------------------------------------------------------------------------- + +/// Pad or truncate `samples` to exactly `RING_CAPACITY` samples. +/// +/// - Longer: keep the **last** 8 s (discard oldest). +/// - Shorter: **front-pad** with zeros so audio is right-aligned. +fn prepare_audio(samples: &[f32]) -> Vec { + match samples.len().cmp(&RING_CAPACITY) { + std::cmp::Ordering::Equal => samples.to_vec(), + std::cmp::Ordering::Greater => samples[samples.len() - RING_CAPACITY..].to_vec(), + std::cmp::Ordering::Less => { + let mut out = vec![0.0f32; RING_CAPACITY - samples.len()]; + out.extend_from_slice(samples); + out + } + } +} -use crate::{AudioFrame, AudioTurnDetector, TurnError, TurnPrediction}; +// --------------------------------------------------------------------------- +// PipecatSmartTurn +// --------------------------------------------------------------------------- /// Pipecat Smart Turn v3 detector. /// -/// Buffers up to 8 seconds of audio internally. When [`predict`](AudioTurnDetector::predict) -/// is called, it takes the last 8s (zero-padded at front if shorter), -/// extracts Whisper log-mel features, and runs ONNX inference. +/// Buffers up to 8 seconds of audio internally. Call [`push_audio`] with +/// every incoming 16 kHz frame, then call [`predict`] when the VAD fires +/// end-of-speech to get a [`TurnPrediction`]. 
+///
+/// # Usage with VAD
+///
+/// ```no_run
+/// # #[cfg(feature = "pipecat")]
+/// # {
+/// use wavekat_turn::audio::PipecatSmartTurn;
+/// use wavekat_turn::AudioTurnDetector;
+///
+/// let mut detector = PipecatSmartTurn::new().unwrap();
+/// // ... feed frames via push_audio ...
+/// let prediction = detector.predict().unwrap();
+/// println!("{:?} ({:.2})", prediction.state, prediction.confidence);
+/// # }
+/// ```
+///
+/// [`push_audio`]: AudioTurnDetector::push_audio
+/// [`predict`]: AudioTurnDetector::predict
 pub struct PipecatSmartTurn {
-    // TODO: ONNX session + audio ring buffer + state
+    session: ort::session::Session,
+    ring_buffer: VecDeque<f32>,
+    mel: MelExtractor,
+    /// Counts samples pushed since the last `predict()` call.
+    /// Used to compute `shift_frames` for incremental STFT.
+    samples_since_predict: usize,
 }
+
+// PipecatSmartTurn is Send + Sync purely via auto-traits: ort 2.x documents
+// `Session` as Send + Sync and realfft's plan trait requires Send + Sync, so
+// no `unsafe impl` is needed. This probe fails to compile if that changes.
+fn _assert_send_sync(x: PipecatSmartTurn) -> impl Send + Sync { x }
+
 impl PipecatSmartTurn {
-    /// Create a new Smart Turn detector, loading the ONNX model.
+    /// Load the Smart Turn v3.2 model embedded at compile time.
     pub fn new() -> Result<Self, TurnError> {
-        todo!("load Smart Turn v3 ONNX model")
+        let session = onnx::session_from_memory(MODEL_BYTES)?;
+        Ok(Self::build(session))
+    }
+
+    /// Load a model from a custom path on disk.
+    ///
+    /// Useful for CI environments that supply the model file separately, or
+    /// for evaluating fine-tuned variants without recompiling.
+ pub fn from_file(path: impl AsRef) -> Result { + let session = onnx::session_from_file(path)?; + Ok(Self::build(session)) + } + + fn build(session: ort::session::Session) -> Self { + Self { + session, + ring_buffer: VecDeque::with_capacity(RING_CAPACITY), + mel: MelExtractor::new(), + samples_since_predict: 0, + } } } impl AudioTurnDetector for PipecatSmartTurn { - fn push_audio(&mut self, _frame: &AudioFrame) { - todo!("append to ring buffer") + /// Append audio to the internal ring buffer. + /// + /// Frames with a sample rate other than 16 kHz are silently dropped. + /// The ring buffer holds at most 8 s; older samples are evicted. + fn push_audio(&mut self, frame: &AudioFrame) { + if frame.sample_rate() != SAMPLE_RATE { + return; + } + let samples = frame.samples(); + // Evict oldest samples to make room + let overflow = (self.ring_buffer.len() + samples.len()).saturating_sub(RING_CAPACITY); + if overflow > 0 { + self.ring_buffer.drain(..overflow); + } + self.ring_buffer.extend(samples.iter().copied()); + self.samples_since_predict += samples.len(); } + /// Run inference on the buffered audio. + /// + /// Takes a snapshot of the ring buffer, pads/truncates to 8 s, extracts + /// Whisper log-mel features, and runs ONNX inference. 
fn predict(&mut self) -> Result { - todo!("truncate/pad to 8s, extract mel features, run ONNX inference") + let t_start = Instant::now(); + + // Stage 1: Snapshot the ring buffer and prepare exactly 128 000 samples + let shift_frames = self.samples_since_predict / HOP_LENGTH; + self.samples_since_predict = 0; + + let buffered: Vec = self.ring_buffer.iter().copied().collect(); + let audio = prepare_audio(&buffered); + let t_after_audio_prep = Instant::now(); + + // Stage 2: Extract [N_MELS × N_FRAMES] log-mel features (incremental) + let mel_spec = self.mel.extract(&audio, shift_frames); + let t_after_mel = Instant::now(); + + // Stage 3: Reshape to [1, N_MELS, N_FRAMES] and run ONNX inference + let (raw, _) = mel_spec.into_raw_vec_and_offset(); + let input_array = Array3::from_shape_vec((1, N_MELS, N_FRAMES), raw) + .expect("internal: mel output has wrong element count"); + + let input_tensor = Tensor::from_array(input_array) + .map_err(|e| TurnError::BackendError(format!("failed to create input tensor: {e}")))?; + + let outputs = self + .session + .run(inputs!["input_features" => input_tensor]) + .map_err(|e| TurnError::BackendError(format!("inference failed: {e}")))?; + let t_after_onnx = Instant::now(); + + // Extract sigmoid probability from the "logits" output + let output = outputs + .get("logits") + .ok_or_else(|| TurnError::BackendError("missing 'logits' output tensor".into()))?; + let (_, data): (_, &[f32]) = output + .try_extract_tensor() + .map_err(|e| TurnError::BackendError(format!("failed to extract logits: {e}")))?; + let probability = *data + .first() + .ok_or_else(|| TurnError::BackendError("logits tensor is empty".into()))?; + + let latency_ms = t_start.elapsed().as_millis() as u64; + + let us = |a: Instant, b: Instant| (b - a).as_secs_f64() * 1_000_000.0; + let stage_times = vec![ + StageTiming { + name: "audio_prep", + us: us(t_start, t_after_audio_prep), + }, + StageTiming { + name: "mel", + us: us(t_after_audio_prep, t_after_mel), + }, + 
StageTiming { + name: "onnx", + us: us(t_after_mel, t_after_onnx), + }, + ]; + + // probability = P(turn complete); > 0.5 means the speaker has finished + let (state, confidence) = if probability > 0.5 { + (TurnState::Finished, probability) + } else { + (TurnState::Unfinished, 1.0 - probability) + }; + + Ok(TurnPrediction { + state, + confidence, + latency_ms, + stage_times, + }) } + /// Clear the ring buffer. Call at the start of each new speech turn. fn reset(&mut self) { - todo!("clear ring buffer and internal state") + self.ring_buffer.clear(); + self.samples_since_predict = 0; + self.mel.invalidate_cache(); } } diff --git a/crates/wavekat-turn/src/lib.rs b/crates/wavekat-turn/src/lib.rs index 75e22c9..98ca819 100644 --- a/crates/wavekat-turn/src/lib.rs +++ b/crates/wavekat-turn/src/lib.rs @@ -18,6 +18,9 @@ pub mod error; +#[cfg(any(feature = "pipecat", feature = "livekit"))] +pub(crate) mod onnx; + #[cfg(feature = "pipecat")] pub mod audio; @@ -38,12 +41,23 @@ pub enum TurnState { Wait, } +/// Per-stage timing entry. +#[derive(Debug, Clone)] +pub struct StageTiming { + /// Stage name (e.g. "audio_prep", "mel", "onnx"). + pub name: &'static str, + /// Time in microseconds for this stage. + pub us: f64, +} + /// A turn detection prediction with confidence and timing metadata. #[derive(Debug, Clone)] pub struct TurnPrediction { pub state: TurnState, pub confidence: f32, pub latency_ms: u64, + /// Per-stage timing breakdown in pipeline order. + pub stage_times: Vec, } /// A single turn in the conversation, for context-aware text detectors. diff --git a/crates/wavekat-turn/src/onnx.rs b/crates/wavekat-turn/src/onnx.rs new file mode 100644 index 0000000..0387fcf --- /dev/null +++ b/crates/wavekat-turn/src/onnx.rs @@ -0,0 +1,41 @@ +//! Shared helpers for ONNX-based turn detection backends. + +use crate::error::TurnError; +use ort::session::Session; + +/// Create an ONNX Runtime session from a model file on disk. 
+pub(crate) fn session_from_file>(path: P) -> Result { + Session::builder() + .map_err(|e| TurnError::BackendError(format!("failed to create session builder: {e}")))? + .with_intra_threads(1) + .map_err(|e| TurnError::BackendError(format!("failed to set intra threads: {e}")))? + .commit_from_file(path) + .map_err(|e| TurnError::BackendError(format!("failed to load ONNX model: {e}"))) +} + +/// Create an ONNX Runtime session from model bytes in memory. +pub(crate) fn session_from_memory(model_bytes: &[u8]) -> Result { + Session::builder() + .map_err(|e| TurnError::BackendError(format!("failed to create session builder: {e}")))? + .with_intra_threads(1) + .map_err(|e| TurnError::BackendError(format!("failed to set intra threads: {e}")))? + .commit_from_memory(model_bytes) + .map_err(|e| TurnError::BackendError(format!("failed to load ONNX model: {e}"))) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn session_from_file_nonexistent() { + let result = session_from_file("/nonexistent/path/to/model.onnx"); + assert!(matches!(result, Err(TurnError::BackendError(_)))); + } + + #[test] + fn session_from_memory_invalid_bytes() { + let result = session_from_memory(b"not a valid onnx model"); + assert!(matches!(result, Err(TurnError::BackendError(_)))); + } +} diff --git a/crates/wavekat-turn/tests/pipecat.rs b/crates/wavekat-turn/tests/pipecat.rs new file mode 100644 index 0000000..0804308 --- /dev/null +++ b/crates/wavekat-turn/tests/pipecat.rs @@ -0,0 +1,167 @@ +//! Integration tests for the Pipecat Smart Turn v3 backend. +//! +//! Run with: `cargo test --features pipecat` +//! 
Run RTF test with: `cargo test --features pipecat --release` + +#![cfg(feature = "pipecat")] + +use wavekat_turn::audio::PipecatSmartTurn; +use wavekat_turn::{AudioFrame, AudioTurnDetector, TurnPrediction}; + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +/// Create an AudioFrame of silence (zeros) at 16 kHz. +fn silence(num_samples: usize) -> AudioFrame<'static> { + let samples = vec![0.0f32; num_samples]; + AudioFrame::new(samples.as_slice(), 16_000).into_owned() +} + +/// Push `duration_secs` of silence in 160-sample chunks (10 ms each). +fn push_silence(detector: &mut PipecatSmartTurn, duration_secs: f32) { + let total = (duration_secs * 16_000.0) as usize; + let chunk = 160; + let mut pushed = 0; + while pushed < total { + let n = chunk.min(total - pushed); + detector.push_audio(&silence(n)); + pushed += n; + } +} + +fn valid_prediction(pred: &TurnPrediction) { + assert!( + pred.confidence >= 0.0 && pred.confidence <= 1.0, + "confidence out of range: {}", + pred.confidence + ); +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[test] +fn test_new_loads_model() { + PipecatSmartTurn::new().expect("PipecatSmartTurn::new() should succeed"); +} + +#[test] +fn test_from_file_loads_model() { + let tmp = std::env::temp_dir().join("wavekat_turn_test"); + std::fs::create_dir_all(&tmp).unwrap(); + let path = tmp.join("smart-turn-test.onnx"); + + let model_bytes = include_bytes!(concat!(env!("OUT_DIR"), "/smart-turn-v3.2-cpu.onnx")); + std::fs::write(&path, model_bytes).unwrap(); + + PipecatSmartTurn::from_file(&path).expect("from_file should succeed with a valid model"); + + let _ = std::fs::remove_file(&path); +} + +#[test] +fn test_predict_returns_valid_output() { + let mut d = PipecatSmartTurn::new().unwrap(); + 
push_silence(&mut d, 2.0); + let pred = d.predict().unwrap(); + valid_prediction(&pred); +} + +#[test] +fn test_predict_with_empty_buffer() { + // Empty buffer is front-padded to 8 s of zeros; inference must succeed. + let mut d = PipecatSmartTurn::new().unwrap(); + let pred = d.predict().unwrap(); + valid_prediction(&pred); +} + +#[test] +fn test_push_audio_wrong_sample_rate_is_ignored() { + let mut d = PipecatSmartTurn::new().unwrap(); + let bad = AudioFrame::new(vec![0.5f32; 160].as_slice(), 8_000).into_owned(); + d.push_audio(&bad); + // Frame should have been dropped; predict must still succeed. + let pred = d.predict().unwrap(); + valid_prediction(&pred); +} + +#[test] +fn test_reset_clears_buffer() { + let mut d = PipecatSmartTurn::new().unwrap(); + push_silence(&mut d, 4.0); + d.reset(); + // After reset the buffer is empty; should behave identically to a fresh instance. + let fresh = PipecatSmartTurn::new().unwrap().predict().unwrap(); + let after_reset = d.predict().unwrap(); + assert_eq!( + after_reset.state, fresh.state, + "state after reset should match a fresh instance" + ); + assert!( + (after_reset.confidence - fresh.confidence).abs() < 1e-5, + "confidence after reset should match a fresh instance" + ); +} + +#[test] +fn test_ring_buffer_caps_at_8_seconds() { + let mut d = PipecatSmartTurn::new().unwrap(); + push_silence(&mut d, 10.0); // 10 s > 8 s capacity; must not panic + valid_prediction(&d.predict().unwrap()); +} + +#[test] +fn test_multiple_predicts_are_deterministic() { + let mut d = PipecatSmartTurn::new().unwrap(); + push_silence(&mut d, 2.0); + let p1 = d.predict().unwrap(); + let p2 = d.predict().unwrap(); + assert_eq!( + p1.state, p2.state, + "repeated predict should give same state" + ); + assert!( + (p1.confidence - p2.confidence).abs() < 1e-5, + "repeated predict should give same confidence" + ); +} + +/// RTF target: < 50 ms. Only enforced in release builds because the debug +/// binary is ~10× slower. 
+#[test] +#[cfg(not(debug_assertions))] +fn test_latency_under_50ms() { + let mut d = PipecatSmartTurn::new().unwrap(); + push_silence(&mut d, 2.0); + let pred = d.predict().unwrap(); + assert!( + pred.latency_ms < 50, + "inference too slow: {} ms (limit: 50 ms)", + pred.latency_ms + ); +} + +#[test] +fn test_from_file_invalid_path_returns_error() { + let result = PipecatSmartTurn::from_file("/nonexistent/path/model.onnx"); + assert!( + result.is_err(), + "from_file with invalid path should return an error" + ); +} + +/// Smoke test: latency is measured and non-zero (always runs, including debug). +#[test] +fn test_latency_is_measured() { + let mut d = PipecatSmartTurn::new().unwrap(); + push_silence(&mut d, 2.0); + let pred = d.predict().unwrap(); + // latency_ms == 0 would mean the timer wasn't working + assert!( + pred.latency_ms < 60_000, + "latency suspiciously large: {} ms", + pred.latency_ms + ); +} diff --git a/docs/plan-accuracy.md b/docs/plan-accuracy.md new file mode 100644 index 0000000..4d2d047 --- /dev/null +++ b/docs/plan-accuracy.md @@ -0,0 +1,167 @@ +# Plan: Cross-Validate Rust Implementation Against Python Reference + +**Status:** Not started +**Date:** 2026-03-28 + +--- + +## Goal + +Verify that our Rust mel preprocessing and ONNX inference pipeline produces probabilities that +match the original Pipecat Python implementation within a tight tolerance. This catches any +silent preprocessing mismatch (wrong mel scale, wrong window, wrong padding, wrong normalization) +that unit tests cannot detect because they only check output shape and range. + +--- + +## Why this matters + +The Rust implementation re-implements `WhisperFeatureExtractor` from scratch. Any divergence +in the mel filterbank, Hann window, STFT center-padding, or log-normalization will silently +shift probabilities. A 5% probability shift near the 0.5 threshold flips a turn decision. 
+The only way to be confident the two pipelines agree is to feed them identical audio and +compare outputs numerically. + +--- + +## Decisions made + +- **Python script generates reference data once.** Output is committed as a JSON fixture + (`tests/fixtures/reference.json`) so the Rust accuracy test has no Python runtime dependency. +- **Tolerance: ±0.02 probability.** The model uses float32 throughout; numerical differences + from the Rust FFT vs NumPy FFT should be well under this. If a case fails, it signals a + real preprocessing bug, not floating-point noise. +- **Three fixture audio clips.** Enough to cover the key behavioral regions without large + binary assets. Total fixture size should stay under ~500 KB. + +--- + +## Fixture audio clips + +| File | Content | Expected region | +|------|---------|-----------------| +| `tests/fixtures/silence_2s.wav` | 2 s of zeros at 16 kHz | Low P(complete) — no speech | +| `tests/fixtures/speech_finished.wav` | Real or synthetic utterance that ends cleanly | High P(complete) | +| `tests/fixtures/speech_mid.wav` | Real or synthetic utterance cut mid-word | Low P(complete) | + +WAV format: 16 kHz, mono, 16-bit PCM (hound-compatible). + +For `speech_finished.wav` and `speech_mid.wav`, use short clips (1–3 s) from a freely +licensed speech corpus, or generate synthetic speech with a TTS tool. Commit the WAVs +directly — they are small enough. + +--- + +## Phase 1 — Python reference script + +Create `scripts/gen_reference.py`. + +**Dependencies** (not added to the crate — Python only): +``` +pip install pipecat-ai transformers onnxruntime numpy soundfile +``` + +**What it does:** +1. Downloads `smart-turn-v3.2-cpu.onnx` if not already present (same URL as build.rs) +2. 
For each WAV in `tests/fixtures/`:
+   - Loads audio via `soundfile` as float32 at 16 kHz
+   - Runs `WhisperFeatureExtractor(chunk_length=8)` to get `input_features`
+   - Runs `ort.InferenceSession` on the ONNX model
+   - Records `{ "file": "...", "probability": <float> }`
+3. Writes `tests/fixtures/reference.json`
+
+**Re-run when:**
+- A fixture WAV changes
+- The model version changes (bump `MODEL_VERSION` in `build.rs` at the same time)
+
+---
+
+## Phase 2 — WAV fixtures
+
+Generate or source the three WAV clips and commit them to `tests/fixtures/`.
+
+For `silence_2s.wav`:
+```python
+import numpy as np, soundfile as sf
+sf.write("tests/fixtures/silence_2s.wav", np.zeros(32000, dtype=np.float32), 16000)
+```
+
+For speech clips, options in order of preference:
+1. Record 2–3 s clips specifically for this test
+2. Use a clip from [CMU Arctic](http://www.festvox.org/cmu_arctic/) or
+   [LJ Speech](https://keithito.com/LJ-Speech-Dataset/) (both public domain / CC0)
+3. Generate with `piper` TTS (Apache 2.0)
+
+---
+
+## Phase 3 — Run the Python script and commit reference.json
+
+```bash
+python scripts/gen_reference.py
+```
+
+Inspect the output — confirm the probabilities make sense (silence ≈ low, finished ≈ high).
+Commit `tests/fixtures/reference.json` alongside the WAV files.
+
+`reference.json` format:
+```json
+[
+  { "file": "silence_2s.wav", "probability": 0.03 },
+  { "file": "speech_finished.wav", "probability": 0.91 },
+  { "file": "speech_mid.wav", "probability": 0.08 }
+]
+```
+
+---
+
+## Phase 4 — Rust accuracy test
+
+Add `tests/accuracy.rs` (under `#[cfg(feature = "pipecat")]`).
+
+**What it does:**
+1. Reads `tests/fixtures/reference.json` at test time
+2. For each entry, loads the corresponding WAV with `hound`
+3. Pushes all audio frames through `PipecatSmartTurn`
+4. Calls `predict()` and reads the raw probability
+5. 
Asserts `|rust_prob - python_prob| <= TOLERANCE` (0.02) + +```rust +const TOLERANCE: f32 = 0.02; +``` + +**Getting the raw probability out of `TurnPrediction`:** +`TurnPrediction.confidence` is already the raw sigmoid value (we set it to `probability` for +`Finished` and `1.0 - probability` for `Unfinished`). To recover the original probability: + +```rust +let raw_prob = match pred.state { + TurnState::Finished => pred.confidence, + TurnState::Unfinished => 1.0 - pred.confidence, + TurnState::Wait => unreachable!(), +}; +``` + +Alternatively, expose `raw_probability: f32` directly on `TurnPrediction` — see open questions. + +**Test names:** +- `test_accuracy_silence` +- `test_accuracy_speech_finished` +- `test_accuracy_speech_mid` + +--- + +## Open questions + +1. **Expose raw probability on `TurnPrediction`?** + Currently the struct only has `confidence` which loses the original P(complete) for + `Unfinished` cases. Options: + - Add `raw_probability: f32` field to `TurnPrediction` (cleaner, but changes the public API) + - Reconstruct from `(state, confidence)` in the test (works, but fragile) + Resolve before starting Phase 4. + +2. **Speech fixture source.** Decide on LJ Speech clips or recorded clips before Phase 2. + LJ Speech is easiest (download a sentence, trim to 2–3 s). Record the chosen file name + and source URL in a comment in `scripts/gen_reference.py`. + +3. **CI integration.** The accuracy test needs the WAV fixtures and `reference.json` committed + to the repo. Confirm the total asset size is acceptable before merging. diff --git a/docs/plan-backends.md b/docs/plan-backends.md new file mode 100644 index 0000000..ec17c97 --- /dev/null +++ b/docs/plan-backends.md @@ -0,0 +1,159 @@ +# Plan: Implement Turn Detection Backends + +**Status:** Phase 1–4 complete +**Date:** 2026-03-28 + +--- + +## Decisions made + +- **Pipecat first.** Implement `PipecatSmartTurn` (audio, ~8 MB) before `LiveKitEou` (text, + ~400 MB). Smaller model, faster to iterate. 
+- **Follow wavekat-vad pattern.** Build-time model download via `build.rs` + `include_bytes!()`, + same env-var overrides (`*_MODEL_PATH`, `*_MODEL_URL`). +- **Turn logic stays here.** The lab (`wavekat-lab`) calls these backends as a library consumer. + No turn logic lives in the lab. +- **Model loading strategy by size.** + - **< ~30 MB → embed** with `include_bytes!()`. Binary size is acceptable; zero runtime setup. + Pipecat (8 MB) uses this path. + - **≥ ~30 MB → runtime load** from disk. Embedding would bloat the binary unacceptably. + Future large-model backends must use this path (see "Out of scope" section). +- **`from_file()` constructor on all backends.** Even embedded-model backends expose a + `from_file(path)` constructor so users can substitute custom or fine-tuned weights, and to + establish the pattern that future large-model backends will use as their primary constructor. + +--- + +## Current state + +`PipecatSmartTurn` is fully implemented and all integration tests pass. +`LiveKitEou` remains a stub (out of scope for this branch). 
+
+```
+src/
+├── lib.rs      — traits: AudioTurnDetector, TextTurnDetector, TurnPrediction, TurnState
+├── error.rs    — TurnError: BackendError, InvalidInput, ModelNotLoaded
+├── onnx.rs     — shared session_from_file / session_from_memory helpers
+├── audio/
+│   ├── mod.rs
+│   └── pipecat.rs — PipecatSmartTurn (complete)
+└── text/
+    ├── mod.rs
+    └── livekit.rs — LiveKitEou (stub, out of scope)
+build.rs        — downloads smart-turn-v3.2-cpu.onnx at build time
+tests/
+└── pipecat.rs  — 11 integration tests (all pass)
+```
+
+---
+
+## Trait API (stable, do not change)
+
+```rust
+pub trait AudioTurnDetector: Send + Sync {
+    fn push_audio(&mut self, frame: &AudioFrame); // 16 kHz mono f32
+    fn predict(&mut self) -> Result<TurnPrediction, TurnError>;
+    fn reset(&mut self);
+}
+
+pub trait TextTurnDetector: Send + Sync {
+    fn predict_text(&mut self, transcript: &str, context: &[ConversationTurn])
+        -> Result<TurnPrediction, TurnError>;
+    fn reset(&mut self);
+}
+```
+
+`TurnPrediction` — `{ state: TurnState, confidence: f32, latency_ms: u64 }`
+`TurnState` — `Finished | Unfinished | Wait`
+
+---
+
+## Phase 1 — Research ✅
+
+**Done.** Findings pinned in `src/audio/pipecat.rs` module-level comments.
+ +| Item | Finding | +|------|---------| +| Model URL | `https://huggingface.co/pipecat-ai/smart-turn-v3/resolve/main/smart-turn-v3.2-cpu.onnx` | +| Input tensor | `input_features`, shape `[B, 80, 800]`, float32 | +| Output tensor | `logits`, shape `[B, 1]`, float32 — sigmoid P(turn complete), NOT raw logits | +| Mel scale | **Slaney** (NOT HTK); `norm="slaney"` | +| n_fft / hop | 400 / 160 samples (25 ms / 10 ms at 16 kHz) | +| Mel bins | 80; frequency range 0–8 000 Hz | +| Window | Periodic Hann (`torch.hann_window(400, periodic=True)`) | +| Pre-emphasis | None | +| Log norm | `log10`, clamp `[max−8, ∞]`, then `(x + 4) / 4` | +| Audio buffer | 8 s = 128 000 samples; front-pad shorter, keep last 8 s for longer | +| License | BSD 2-Clause | + +--- + +## Phase 2 — Build system ✅ + +**Done.** + +- `build.rs` downloads `smart-turn-v3.2-cpu.onnx` to `OUT_DIR` with version-based caching +- Env-var overrides: `PIPECAT_SMARTTURN_MODEL_PATH`, `PIPECAT_SMARTTURN_MODEL_URL` +- Docs.rs guard writes a zero-byte placeholder when `DOCS_RS=1` +- `Cargo.toml`: `build = "build.rs"`, `ureq` as optional build-dep activated by `pipecat` feature + +Note: SHA-256 verification was omitted in favour of version-based caching (same as wavekat-vad). + +--- + +## Phase 3 — PipecatSmartTurn implementation ✅ + +**Done.** `src/audio/pipecat.rs` and `src/onnx.rs` written and compiling. 
+
+Key implementation decisions:
+- `MelExtractor` precomputes the Slaney filterbank matrix and Hann window once at construction;
+  reuses FFT plan and scratch buffers across calls
+- Center-pad (`N_FFT/2` zeros each side) replicates librosa `center=True` STFT, producing
+  exactly 800 frames from 128 000 samples
+- `push_audio` silently drops frames with wrong sample rate (no return value in trait)
+- `ndarray = "0.17"` required to match `ort`'s ndarray feature version
+
+---
+
+## Phase 4 — Tests ✅
+
+**Done.** `tests/pipecat.rs` with 11 integration tests, all passing:
+
+| Test | What it checks |
+|------|---------------|
+| `test_new_loads_model` | `new()` succeeds |
+| `test_from_file_loads_model` | `from_file()` succeeds with a valid path |
+| `test_predict_returns_valid_output` | confidence ∈ [0, 1] |
+| `test_predict_with_empty_buffer` | empty buffer inference succeeds |
+| `test_push_audio_wrong_sample_rate_is_ignored` | 8 kHz frame is dropped |
+| `test_reset_clears_buffer` | state after reset matches fresh instance |
+| `test_ring_buffer_caps_at_8_seconds` | 10 s of audio doesn't panic |
+| `test_multiple_predicts_are_deterministic` | same buffer → same output |
+| `test_latency_under_50ms` | RTF < 50 ms (release builds only) |
+| `test_from_file_invalid_path_returns_error` | `from_file()` with a bad path returns an error |
+| `test_latency_is_measured` | latency field is within a sane bound |
+
+Note: the current tests do not cross-validate against the Python reference implementation.
+That is tracked in **[`plan-accuracy.md`](plan-accuracy.md)**.
+
+---
+
+## Open questions
+
+All research questions from Phases 1–3 are resolved. No blocking open questions remain
+for this branch.
+
+---
+
+## Out of scope — LiveKitEou
+
+> **Not part of this branch.** Will be implemented in a dedicated feature branch.
+ +Key notes to carry forward: + +- Model: distilled Qwen2.5-0.5B ONNX (~400 MB), LiveKit Model License +- Input: tokenized transcript + conversation context +- At 400 MB, `include_bytes!()` is not viable — needs a runtime-load strategy (build.rs + downloads to a user cache dir, binary loads from disk via `from_file()`) +- The `from_file()` constructor established on `PipecatSmartTurn` in Phase 3 gives LiveKit + the same public API shape to follow +- Open questions for that branch: model URL, `tokenizers` crate acceptability, exact input + format, CI/CD cache-dir strategy