From 78d1426d14131a7ad5851fe4bc678e6108bb8d0f Mon Sep 17 00:00:00 2001 From: Eason WaveKat Date: Sat, 28 Mar 2026 16:39:39 +1300 Subject: [PATCH 1/7] docs: add backend implementation plan Co-Authored-By: Claude Sonnet 4.6 --- docs/plan-backends.md | 174 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 174 insertions(+) create mode 100644 docs/plan-backends.md diff --git a/docs/plan-backends.md b/docs/plan-backends.md new file mode 100644 index 0000000..3f423d7 --- /dev/null +++ b/docs/plan-backends.md @@ -0,0 +1,174 @@ +# Plan: Implement Turn Detection Backends + +**Status:** In progress +**Date:** 2026-03-28 + +--- + +## Decisions made + +- **Pipecat first.** Implement `PipecatSmartTurn` (audio, ~8 MB) before `LiveKitEou` (text, + ~400 MB). Smaller model, faster to iterate. +- **Follow wavekat-vad pattern.** Build-time model download via `build.rs` + `include_bytes!()`, + same env-var overrides (`*_MODEL_PATH`, `*_MODEL_URL`). +- **Turn logic stays here.** The lab (`wavekat-lab`) calls these backends as a library consumer. + No turn logic lives in the lab. + +--- + +## Current state + +Both backends are stubs with `todo!()` — the crate compiles but cannot run inference. + +``` +src/ +├── lib.rs — traits: AudioTurnDetector, TextTurnDetector, TurnPrediction, TurnState +├── error.rs — TurnError: BackendError, InvalidInput, ModelNotLoaded +├── audio/ +│ ├── mod.rs +│ └── pipecat.rs — PipecatSmartTurn (stub) +└── text/ + ├── mod.rs + └── livekit.rs — LiveKitEou (stub) +``` + +No `build.rs` yet. No tests. 
+ +--- + +## Trait API (stable, do not change) + +```rust +pub trait AudioTurnDetector: Send + Sync { + fn push_audio(&mut self, frame: &AudioFrame); // 16 kHz mono f32 + fn predict(&mut self) -> Result<TurnPrediction, TurnError>; + fn reset(&mut self); +} + +pub trait TextTurnDetector: Send + Sync { + fn predict_text(&mut self, transcript: &str, context: &[ConversationTurn]) + -> Result<TurnPrediction, TurnError>; + fn reset(&mut self); +} +``` + +`TurnPrediction` — `{ state: TurnState, confidence: f32, latency_ms: u64 }` +`TurnState` — `Finished | Unfinished | Wait` + +--- + +## Phase 1 — Research (prerequisite) + +Before writing any code, pin down the model specifics: + +1. **Model source.** Find the official Pipecat Smart Turn v3 ONNX download URL (Pipecat GitHub + releases or Hugging Face). Confirm license (BSD 2-Clause noted in stub comments). + +2. **Input/output tensor shapes.** Load the model in a scratch script or `netron` and record: + - Input tensor: name, shape, dtype + - Output tensor: name(s), shape, dtype + - Whether output is a single confidence float or logits for [Finished, Unfinished, Wait] + +3. **Mel-feature spec.** Confirm what preprocessing the model expects: + - Frame size + hop length + - Number of mel bins (Whisper uses 80) + - Frequency range + - Mel scale formula (HTK vs Kaldi) + - Whether pre-emphasis is applied + +4. **Audio buffer length.** Stub says "up to 8 seconds" — confirm from model input shape. + +Document findings as comments in `pipecat.rs` before implementation. 
+ +--- + +## Phase 2 — Build system + +Create `crates/wavekat-turn/build.rs` following the wavekat-vad pattern: + +- Download Smart Turn v3 ONNX to `OUT_DIR` at build time +- SHA-256 verification +- Env-var overrides: + - `PIPECAT_SMARTTURN_MODEL_PATH` — use a local file instead of downloading + - `PIPECAT_SMARTTURN_MODEL_URL` — override download URL +- Docs.rs guard: write a zero-byte placeholder when `DOCS_RS=1` + +Add to `Cargo.toml`: +```toml +[package] +build = "build.rs" + +[build-dependencies] +ureq = { version = "3", features = ["tls"] } +``` + +--- + +## Phase 3 — PipecatSmartTurn implementation + +Fill in `src/audio/pipecat.rs`: + +**Struct:** +```rust +pub struct PipecatSmartTurn { + session: Session, + ring_buffer: VecDeque<f32>, // 8s × 16kHz = 128k samples + // mel extractor fields TBD from Phase 1 research +} +``` + +**`new()`** — load model via `include_bytes!(concat!(env!("OUT_DIR"), "/..."))`, +create `ort::Session`, initialize ring buffer. + +**`push_audio()`** — validate sample rate (16 kHz), convert i16→f32 if needed, +append to ring buffer (evict oldest when over capacity). + +**`predict()`** — snapshot ring buffer, pad/truncate to model's expected length, +extract mel features, build ndarray input tensor, `session.run(...)`, parse output, +record `Instant` before/after for `latency_ms`. + +**`reset()`** — `ring_buffer.clear()`. 
+ +Reference implementations: +- `wavekat-vad/src/backends/silero.rs` — ONNX session + state management +- `wavekat-vad/src/backends/onnx.rs` — session builder helper +- `wavekat-vad/src/backends/firered/fbank.rs` — mel filterbank (adapt if spec matches) + +--- + +## Phase 4 — Tests + +Add `tests/pipecat.rs` (integration tests under `#[cfg(feature = "pipecat")]`): + +- `test_new_loads_model` — `PipecatSmartTurn::new()` succeeds +- `test_predict_silence` — feed 2s of zeros, expect low confidence +- `test_predict_finished` — feed known-good finished-turn audio (WAV fixture), expect + `TurnState::Finished` with confidence > 0.7 +- `test_reset_clears_buffer` — push audio, reset, predict on empty buffer returns low confidence +- `test_rtf` — assert `latency_ms` < 50 ms (a loose ceiling well above the ~12 ms target, leaving headroom for CI) + +Add a small WAV fixture (`tests/fixtures/finished_turn.wav`) for the audio test cases. + +--- + +## Phase 5 — LiveKitEou (deferred) + +Implement `src/text/livekit.rs` after Pipecat is proven end-to-end. + +Different approach — text model using a tokenizer: +- Model: distilled Qwen2.5-0.5B ONNX (~400 MB), LiveKit Model License +- Input: tokenized transcript + conversation context +- Likely needs a tokenizer crate (e.g. `tokenizers` from HuggingFace) + +Open questions before starting: +1. Confirm model URL and whether `tokenizers` crate is acceptable (it's a large dep) +2. Confirm exact input format the model expects for transcript + context + +--- + +## Open questions + +1. **Smart Turn v3 model URL** — not yet confirmed (needed for Phase 2) +2. **Exact input tensor shape** — need to inspect the model (needed for Phase 3) +3. **Mel-feature spec** — need to confirm to avoid silent preprocessing mismatch +4. **LiveKit tokenizer strategy** — `tokenizers` crate vs. 
manual BPE (needed before Phase 5) From 893ac1e813cbcf707401245c3013e818b9036ad1 Mon Sep 17 00:00:00 2001 From: Eason WaveKat Date: Sat, 28 Mar 2026 16:52:29 +1300 Subject: [PATCH 2/7] docs: scope plan to Pipecat only, add model loading strategy Co-Authored-By: Claude Sonnet 4.6 --- docs/plan-backends.md | 61 ++++++++++++++++++++++++++++++------------- 1 file changed, 43 insertions(+), 18 deletions(-) diff --git a/docs/plan-backends.md b/docs/plan-backends.md index 3f423d7..2a62c65 100644 --- a/docs/plan-backends.md +++ b/docs/plan-backends.md @@ -13,6 +13,14 @@ same env-var overrides (`*_MODEL_PATH`, `*_MODEL_URL`). - **Turn logic stays here.** The lab (`wavekat-lab`) calls these backends as a library consumer. No turn logic lives in the lab. +- **Model loading strategy by size.** + - **< ~30 MB → embed** with `include_bytes!()`. Binary size is acceptable; zero runtime setup. + Pipecat (8 MB) uses this path. + - **≥ ~30 MB → runtime load** from disk. Embedding would bloat the binary unacceptably. + Future large-model backends must use this path (see "Out of scope" section). +- **`from_file()` constructor on all backends.** Even embedded-model backends expose a + `from_file(path)` constructor so users can substitute custom or fine-tuned weights, and to + establish the pattern that future large-model backends will use as their primary constructor. --- @@ -84,7 +92,10 @@ Document findings as comments in `pipecat.rs` before implementation. ## Phase 2 — Build system -Create `crates/wavekat-turn/build.rs` following the wavekat-vad pattern: +Create `crates/wavekat-turn/build.rs` following the wavekat-vad pattern. + +This phase covers the **embedded** path only (Pipecat). Large-model backends (LiveKit) need +a different build.rs strategy (see the "Out of scope — LiveKitEou" section). 
+ - Download Smart Turn v3 ONNX to `OUT_DIR` at build time - SHA-256 verification @@ -117,8 +128,20 @@ pub struct PipecatSmartTurn { } ``` -**`new()`** — load model via `include_bytes!(concat!(env!("OUT_DIR"), "/..."))`, -create `ort::Session`, initialize ring buffer. +**Constructors:** + +```rust +/// Default constructor — loads the model embedded at compile time. +pub fn new() -> Result<Self, TurnError> { ... } + +/// Load a custom model from disk — useful for fine-tuned weights or CI environments +/// where the binary should stay small and the model is provided separately. +pub fn from_file(path: impl AsRef<Path>) -> Result<Self, TurnError> { ... } +``` + +`new()` calls `session_from_memory(include_bytes!(concat!(env!("OUT_DIR"), "/...")))`. +`from_file()` calls `session_from_file(path)`. Both share `Self::build(session)` for +the rest of initialization. **`push_audio()`** — validate sample rate (16 kHz), convert i16→f32 if needed, append to ring buffer (evict oldest when over capacity). @@ -141,6 +164,7 @@ Reference implementations: Add `tests/pipecat.rs` (integration tests under `#[cfg(feature = "pipecat")]`): - `test_new_loads_model` — `PipecatSmartTurn::new()` succeeds +- `test_from_file_loads_model` — `PipecatSmartTurn::from_file(path)` succeeds given a valid path - `test_predict_silence` — feed 2s of zeros, expect low confidence - `test_predict_finished` — feed known-good finished-turn audio (WAV fixture), expect `TurnState::Finished` with confidence > 0.7 @@ -151,24 +175,25 @@ Add a small WAV fixture (`tests/fixtures/finished_turn.wav`) for the audio test --- -## Phase 5 — LiveKitEou (deferred) +## Open questions -Implement `src/text/livekit.rs` after Pipecat is proven end-to-end. +1. **Smart Turn v3 model URL** — not yet confirmed (needed for Phase 2) +2. **Exact input tensor shape** — need to inspect the model (needed for Phase 3) +3. 
**Mel-feature spec** — need to confirm to avoid silent preprocessing mismatch -Different approach — text model using a tokenizer: -- Model: distilled Qwen2.5-0.5B ONNX (~400 MB), LiveKit Model License -- Input: tokenized transcript + conversation context -- Likely needs a tokenizer crate (e.g. `tokenizers` from HuggingFace) +--- -Open questions before starting: -1. Confirm model URL and whether `tokenizers` crate is acceptable (it's a large dep) -2. Confirm exact input format the model expects for transcript + context +## Out of scope — LiveKitEou ---- +> **Not part of this branch.** Will be implemented in a dedicated feature branch. -## Open questions +Key notes to carry forward: -1. **Smart Turn v3 model URL** — not yet confirmed (needed for Phase 2) -2. **Exact input tensor shape** — need to inspect the model (needed for Phase 3) -3. **Mel-feature spec** — need to confirm to avoid silent preprocessing mismatch -4. **LiveKit tokenizer strategy** — `tokenizers` crate vs. manual BPE (needed before Phase 5) +- Model: distilled Qwen2.5-0.5B ONNX (~400 MB), LiveKit Model License +- Input: tokenized transcript + conversation context +- At 400 MB, `include_bytes!()` is not viable — needs a runtime-load strategy (build.rs + downloads to a user cache dir, binary loads from disk via `from_file()`) +- The `from_file()` constructor established on `PipecatSmartTurn` in Phase 3 gives LiveKit + the same public API shape to follow +- Open questions for that branch: model URL, `tokenizers` crate acceptability, exact input + format, CI/CD cache-dir strategy From b3bfaac4631699261cd0c6f2aeb6169be058fec4 Mon Sep 17 00:00:00 2001 From: Eason WaveKat Date: Sat, 28 Mar 2026 17:26:59 +1300 Subject: [PATCH 3/7] feat: implement PipecatSmartTurn backend MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - build.rs: download smart-turn-v3.2-cpu.onnx at build time with version caching and PIPECAT_SMARTTURN_MODEL_{PATH,URL} overrides - src/onnx.rs: 
shared session_from_file / session_from_memory helpers - src/audio/pipecat.rs: full implementation - MelExtractor: Slaney mel filterbank, periodic Hann window, realfft STFT - PipecatSmartTurn::new() (embedded model) and from_file(path) - push_audio: 16 kHz ring buffer, 8s capacity, wrong-rate frames dropped - predict: mel features → ONNX inference → TurnPrediction - reset: clears ring buffer - tests/pipecat.rs: 9 integration tests; RTF < 50ms enforced in release Co-Authored-By: Claude Sonnet 4.6 --- crates/wavekat-turn/Cargo.toml | 11 +- crates/wavekat-turn/build.rs | 102 ++++++ crates/wavekat-turn/src/audio/pipecat.rs | 420 ++++++++++++++++++++++- crates/wavekat-turn/src/lib.rs | 3 + crates/wavekat-turn/src/onnx.rs | 41 +++ crates/wavekat-turn/tests/pipecat.rs | 151 ++++++++ 6 files changed, 712 insertions(+), 16 deletions(-) create mode 100644 crates/wavekat-turn/build.rs create mode 100644 crates/wavekat-turn/src/onnx.rs create mode 100644 crates/wavekat-turn/tests/pipecat.rs diff --git a/crates/wavekat-turn/Cargo.toml b/crates/wavekat-turn/Cargo.toml index 9700852..8f2bc82 100644 --- a/crates/wavekat-turn/Cargo.toml +++ b/crates/wavekat-turn/Cargo.toml @@ -11,10 +11,11 @@ readme = "../../README.md" keywords = ["turn-detection", "voice", "audio", "telephony", "wavekat"] categories = ["multimedia::audio"] exclude = ["CHANGELOG.md"] +build = "build.rs" [features] default = [] -pipecat = ["dep:ort", "dep:ndarray"] +pipecat = ["dep:ort", "dep:ndarray", "dep:realfft", "dep:ureq"] livekit = ["dep:ort", "dep:ndarray"] [dependencies] @@ -22,8 +23,12 @@ wavekat-core = "0.0.2" thiserror = "2" # ONNX backends (optional) -ort = { version = "2.0.0-rc.12", optional = true } -ndarray = { version = "0.16", optional = true } +ort = { version = "2.0.0-rc.12", optional = true, features = ["ndarray"] } +ndarray = { version = "0.17", optional = true } +realfft = { version = "3", optional = true } + +[build-dependencies] +ureq = { version = "3", optional = true } [dev-dependencies] 
hound = "3.5" diff --git a/crates/wavekat-turn/build.rs b/crates/wavekat-turn/build.rs new file mode 100644 index 0000000..46755bc --- /dev/null +++ b/crates/wavekat-turn/build.rs @@ -0,0 +1,102 @@ +//! Build script for wavekat-turn. +//! +//! Downloads ONNX model(s) at build time when the corresponding feature flag +//! is enabled. Models are written to `OUT_DIR` and embedded via `include_bytes!`. +//! +//! # Environment variables +//! +//! | Variable | Effect | +//! |----------------------------------|--------------------------------------------| +//! | `PIPECAT_SMARTTURN_MODEL_PATH` | Use this local file instead of downloading | +//! | `PIPECAT_SMARTTURN_MODEL_URL` | Override the download URL | +//! | `DOCS_RS` | Write zero-byte placeholder (no network) | + +#[allow(unused_imports)] +use std::env; +#[allow(unused_imports)] +use std::fs; +#[allow(unused_imports)] +use std::path::Path; + +fn main() { + // docs.rs builds with --network none; write empty placeholders so + // include_bytes! compiles without downloading anything. + if env::var("DOCS_RS").is_ok() { + #[cfg(feature = "pipecat")] + { + let out_dir = env::var("OUT_DIR").expect("OUT_DIR not set"); + let model_path = Path::new(&out_dir).join("smart-turn-v3.2-cpu.onnx"); + if !model_path.exists() { + fs::write(&model_path, b"").expect("failed to write placeholder model"); + } + } + return; + } + + #[cfg(feature = "pipecat")] + setup_pipecat_model(); +} + +#[cfg(feature = "pipecat")] +fn setup_pipecat_model() { + const DEFAULT_MODEL_URL: &str = + "https://huggingface.co/pipecat-ai/smart-turn-v3/resolve/main/smart-turn-v3.2-cpu.onnx"; + const MODEL_FILE: &str = "smart-turn-v3.2-cpu.onnx"; + // Bump this string when updating the default model URL so cached builds + // re-download the new version. 
+ const MODEL_VERSION: &str = "v3.2-cpu"; + + println!("cargo:rerun-if-env-changed=PIPECAT_SMARTTURN_MODEL_PATH"); + println!("cargo:rerun-if-env-changed=PIPECAT_SMARTTURN_MODEL_URL"); + + let out_dir = env::var("OUT_DIR").expect("OUT_DIR not set"); + let model_dest = Path::new(&out_dir).join(MODEL_FILE); + let version_marker = Path::new(&out_dir).join("smart-turn.version"); + + // Option 1: caller provides a local model file + if let Ok(local_path) = env::var("PIPECAT_SMARTTURN_MODEL_PATH") { + let local_path = Path::new(&local_path); + if !local_path.exists() { + panic!( + "PIPECAT_SMARTTURN_MODEL_PATH points to a non-existent file: {}", + local_path.display() + ); + } + println!( + "cargo:warning=Using local Pipecat Smart Turn model: {}", + local_path.display() + ); + fs::copy(local_path, &model_dest).expect("failed to copy local model file"); + println!("cargo:rerun-if-changed={}", local_path.display()); + return; + } + + // Skip download if a matching version is already cached + let cached_version = fs::read_to_string(&version_marker).unwrap_or_default(); + if model_dest.exists() && cached_version.trim() == MODEL_VERSION { + return; + } + + // Option 2: download (caller may override the URL) + let url = env::var("PIPECAT_SMARTTURN_MODEL_URL") + .unwrap_or_else(|_| DEFAULT_MODEL_URL.to_string()); + + println!("cargo:warning=Downloading Pipecat Smart Turn model ({MODEL_VERSION}) from {url}"); + + let response = ureq::get(&url) + .call() + .unwrap_or_else(|e| panic!("failed to download Pipecat Smart Turn model from {url}: {e}")); + + let bytes = response + .into_body() + .read_to_vec() + .expect("failed to read model bytes"); + + fs::write(&model_dest, &bytes).expect("failed to write model file"); + fs::write(&version_marker, MODEL_VERSION).expect("failed to write version marker"); + + println!( + "cargo:warning=Pipecat Smart Turn model ({MODEL_VERSION}) downloaded to {}", + model_dest.display() + ); +} diff --git a/crates/wavekat-turn/src/audio/pipecat.rs 
b/crates/wavekat-turn/src/audio/pipecat.rs index 2c9e168..170ad0e 100644 --- a/crates/wavekat-turn/src/audio/pipecat.rs +++ b/crates/wavekat-turn/src/audio/pipecat.rs @@ -4,38 +4,432 @@ //! Expects 16 kHz f32 PCM input. Telephony audio at 8 kHz must be //! upsampled before feeding to this detector. //! -//! - Model size: ~8 MB (int8 quantized ONNX) -//! - Inference: ~12 ms on CPU +//! # Model +//! +//! - Source: <https://huggingface.co/pipecat-ai/smart-turn-v3> +//! - File: `smart-turn-v3.2-cpu.onnx` (int8 quantized, ~8 MB) //! - License: BSD 2-Clause +//! +//! # Tensor specification +//! +//! | Role | Name | Shape | Dtype | +//! |--------|------------------|----------------|---------| +//! | Input | `input_features` | `[B, 80, 800]` | float32 | +//! | Output | `logits` | `[B, 1]` | float32 | +//! +//! Despite the name, `logits` is a **sigmoid probability** P(turn complete) +//! in [0, 1] — the sigmoid is fused into the model before ONNX export. +//! Threshold: `probability > 0.5` → `TurnState::Finished`. +//! +//! # Mel-feature specification +//! +//! The model was trained with HuggingFace `WhisperFeatureExtractor(chunk_length=8)`: +//! +//! | Parameter | Value | +//! |---------------|--------------------------------| +//! | Sample rate | 16 000 Hz | +//! | n_fft | 400 samples (25 ms) | +//! | hop_length | 160 samples (10 ms) | +//! | n_mels | 80 | +//! | Freq range | 0 – 8 000 Hz | +//! | Mel scale | Slaney (NOT HTK) | +//! | Window | Hann (periodic, size 400) | +//! | Pre-emphasis | None | +//! | Log | log10 with ε = 1e-10 | +//! | Normalization | clamp(max − 8), (x + 4) / 4 | +//! +//! # Audio buffer +//! +//! - Exactly **8 seconds = 128 000 samples** at 16 kHz. +//! - Shorter input: **front-padded** with zeros (audio is at the end). +//! - Longer input: the **last** 8 s is used (oldest samples discarded). 
+ +use std::collections::VecDeque; +use std::path::Path; +use std::sync::Arc; +use std::time::Instant; + +use ndarray::{s, Array2, Array3}; +use ort::{inputs, value::Tensor}; +use realfft::num_complex::Complex; +use realfft::{RealFftPlanner, RealToComplex}; + +use crate::{AudioFrame, AudioTurnDetector, TurnError, TurnPrediction, TurnState}; +use crate::onnx; + +// --------------------------------------------------------------------------- +// Constants +// --------------------------------------------------------------------------- + +/// Sample rate the model expects. +const SAMPLE_RATE: u32 = 16_000; +/// FFT window size in samples (25 ms at 16 kHz). +const N_FFT: usize = 400; +/// STFT hop length in samples (10 ms at 16 kHz). +const HOP_LENGTH: usize = 160; +/// Number of mel filterbank bins. +const N_MELS: usize = 80; +/// Number of STFT frames the model expects (8 s × 100 fps). +const N_FRAMES: usize = 800; +/// FFT frequency bins: N_FFT/2 + 1. +const N_FREQS: usize = N_FFT / 2 + 1; // 201 +/// Ring buffer capacity: 8 s × 16 kHz. +const RING_CAPACITY: usize = 8 * SAMPLE_RATE as usize; // 128 000 + +/// Embedded ONNX model bytes, downloaded by build.rs at compile time. +const MODEL_BYTES: &[u8] = + include_bytes!(concat!(env!("OUT_DIR"), "/smart-turn-v3.2-cpu.onnx")); + +// --------------------------------------------------------------------------- +// Mel feature extractor +// --------------------------------------------------------------------------- + +/// Pre-computed Whisper-style log-mel feature extractor. +/// +/// All expensive setup (filterbank, window, FFT plan) happens once in [`new`]. +/// [`MelExtractor::extract`] is then called per inference. +struct MelExtractor { + /// Slaney-normalised mel filterbank: shape [N_MELS, N_FREQS]. + mel_filters: Array2, + /// Periodic Hann window of length N_FFT. + hann_window: Vec, + /// Reusable forward real FFT plan. + fft: Arc>, + /// Reusable scratch buffer for the FFT. 
+ fft_scratch: Vec>, + /// Reusable output spectrum buffer (N_FREQS complex values). + spectrum_buf: Vec>, +} + +impl MelExtractor { + fn new() -> Self { + let mel_filters = build_mel_filters( + SAMPLE_RATE as usize, + N_FFT, + N_MELS, + 0.0, + SAMPLE_RATE as f32 / 2.0, + ); + let hann_window = periodic_hann(N_FFT); -use crate::{AudioFrame, AudioTurnDetector, TurnError, TurnPrediction}; + let mut planner = RealFftPlanner::::new(); + let fft = planner.plan_fft_forward(N_FFT); + let fft_scratch = fft.make_scratch_vec(); + let spectrum_buf = fft.make_output_vec(); + + Self { + mel_filters, + hann_window, + fft, + fft_scratch, + spectrum_buf, + } + } + + /// Compute a [N_MELS × N_FRAMES] log-mel spectrogram from exactly + /// `RING_CAPACITY` samples of 16 kHz mono audio. + fn extract(&mut self, audio: &[f32]) -> Array2 { + debug_assert_eq!(audio.len(), RING_CAPACITY); + + // ---- Center-pad: N_FFT/2 zeros on each side → 128 400 samples ---- + // This replicates librosa/PyTorch `center=True` STFT behaviour, which + // gives exactly N_FRAMES + 1 = 801 frames; we discard the last one. 
+ let pad = N_FFT / 2; + let mut padded = vec![0.0f32; pad + audio.len() + pad]; + padded[pad..pad + audio.len()].copy_from_slice(audio); + + // n_total = (128 400 − 400) / 160 + 1 = 801 + let n_total_frames = (padded.len() - N_FFT) / HOP_LENGTH + 1; + + // ---- STFT → power spectrogram [N_FREQS × n_total_frames] ---- + let mut power_spec = Array2::::zeros((N_FREQS, n_total_frames)); + let mut frame_buf = vec![0.0f32; N_FFT]; + + for frame_idx in 0..n_total_frames { + let start = frame_idx * HOP_LENGTH; + // Apply periodic Hann window + for (i, (&s, &w)) in padded[start..start + N_FFT] + .iter() + .zip(self.hann_window.iter()) + .enumerate() + { + frame_buf[i] = s * w; + } + + self.fft + .process_with_scratch(&mut frame_buf, &mut self.spectrum_buf, &mut self.fft_scratch) + .expect("FFT failed: internal buffer size mismatch"); + + for (k, c) in self.spectrum_buf.iter().enumerate() { + power_spec[[k, frame_idx]] = c.re * c.re + c.im * c.im; + } + } + + // Take first N_FRAMES columns (drop the trailing frame) + let power_spec = power_spec.slice(s![.., ..N_FRAMES]).to_owned(); + + // ---- Apply mel filterbank: [N_MELS, N_FREQS] × [N_FREQS, N_FRAMES] ---- + let mel_spec = self.mel_filters.dot(&power_spec); // [80, 800] + + // ---- Log10 with floor at 1e-10 ---- + let mut log_mel = mel_spec.mapv(|x| x.max(1e-10_f32).log10()); + + // ---- Dynamic range compression and normalization ---- + // Matches WhisperFeatureExtractor: clamp to [max−8, ∞], then (x+4)/4 + let max_val = log_mel.iter().cloned().fold(f32::NEG_INFINITY, f32::max); + log_mel.mapv_inplace(|x| (x.max(max_val - 8.0) + 4.0) / 4.0); + + log_mel + } +} + +// --------------------------------------------------------------------------- +// Mel filterbank construction — Slaney scale, slaney norm +// --------------------------------------------------------------------------- + +/// Convert Hz to mel (Slaney/librosa scale, NOT HTK). 
+fn hz_to_mel(hz: f32) -> f32 { + const F_SP: f32 = 200.0 / 3.0; // linear region slope (Hz per mel) + const MIN_LOG_HZ: f32 = 1000.0; + const MIN_LOG_MEL: f32 = MIN_LOG_HZ / F_SP; // = 15.0 + // logstep = ln(6.4) / 27 (≈ 0.068752) + let logstep = (6.4_f32).ln() / 27.0; + if hz >= MIN_LOG_HZ { + MIN_LOG_MEL + (hz / MIN_LOG_HZ).ln() / logstep + } else { + hz / F_SP + } +} + +/// Convert mel back to Hz (Slaney scale). +fn mel_to_hz(mel: f32) -> f32 { + const F_SP: f32 = 200.0 / 3.0; + const MIN_LOG_HZ: f32 = 1000.0; + const MIN_LOG_MEL: f32 = MIN_LOG_HZ / F_SP; + let logstep = (6.4_f32).ln() / 27.0; + if mel >= MIN_LOG_MEL { + MIN_LOG_HZ * ((mel - MIN_LOG_MEL) * logstep).exp() + } else { + mel * F_SP + } +} + +/// Build a Slaney-normalised mel filterbank of shape [n_mels, n_freqs]. +/// +/// Matches `librosa.filters.mel(sr, n_fft, n_mels, fmin, fmax, +/// norm="slaney", dtype=float32)` which is what HuggingFace's +/// `WhisperFeatureExtractor` uses internally. +fn build_mel_filters(sr: usize, n_fft: usize, n_mels: usize, f_min: f32, f_max: f32) -> Array2 { + let n_freqs = n_fft / 2 + 1; + + // FFT frequency bins: 0, sr/n_fft, 2·sr/n_fft, … + let fft_freqs: Vec = (0..n_freqs) + .map(|i| i as f32 * sr as f32 / n_fft as f32) + .collect(); + + // n_mels + 2 equally-spaced mel points (edge + n_mels centres + edge) + let mel_min = hz_to_mel(f_min); + let mel_max = hz_to_mel(f_max); + let mel_pts: Vec = (0..=(n_mels + 1)) + .map(|i| mel_min + (mel_max - mel_min) * i as f32 / (n_mels + 1) as f32) + .collect(); + let hz_pts: Vec = mel_pts.iter().map(|&m| mel_to_hz(m)).collect(); + + // Build triangular filters with Slaney normalisation + let mut filters = Array2::::zeros((n_mels, n_freqs)); + for m in 0..n_mels { + let f_left = hz_pts[m]; + let f_center = hz_pts[m + 1]; + let f_right = hz_pts[m + 2]; + // Slaney norm: 2 / (right_hz − left_hz) + let enorm = 2.0 / (f_right - f_left); + + for (k, &f) in fft_freqs.iter().enumerate() { + let w = if f >= f_left && f <= f_center { + 
(f - f_left) / (f_center - f_left) + } else if f > f_center && f <= f_right { + (f_right - f) / (f_right - f_center) + } else { + 0.0 + }; + filters[[m, k]] = w * enorm; + } + } + filters +} + +// --------------------------------------------------------------------------- +// Hann window +// --------------------------------------------------------------------------- + +/// Periodic Hann window of length `n`, matching `torch.hann_window(n, periodic=True)`. +/// +/// Formula: `w[k] = 0.5 · (1 − cos(2π·k / n))` for k in 0..n. +/// This differs from the symmetric variant (which divides by n−1). +fn periodic_hann(n: usize) -> Vec { + use std::f32::consts::PI; + (0..n) + .map(|k| 0.5 * (1.0 - (2.0 * PI * k as f32 / n as f32).cos())) + .collect() +} + +// --------------------------------------------------------------------------- +// Audio preparation +// --------------------------------------------------------------------------- + +/// Pad or truncate `samples` to exactly `RING_CAPACITY` samples. +/// +/// - Longer: keep the **last** 8 s (discard oldest). +/// - Shorter: **front-pad** with zeros so audio is right-aligned. +fn prepare_audio(samples: &[f32]) -> Vec { + match samples.len().cmp(&RING_CAPACITY) { + std::cmp::Ordering::Equal => samples.to_vec(), + std::cmp::Ordering::Greater => samples[samples.len() - RING_CAPACITY..].to_vec(), + std::cmp::Ordering::Less => { + let mut out = vec![0.0f32; RING_CAPACITY - samples.len()]; + out.extend_from_slice(samples); + out + } + } +} + +// --------------------------------------------------------------------------- +// PipecatSmartTurn +// --------------------------------------------------------------------------- /// Pipecat Smart Turn v3 detector. /// -/// Buffers up to 8 seconds of audio internally. When [`predict`](AudioTurnDetector::predict) -/// is called, it takes the last 8s (zero-padded at front if shorter), -/// extracts Whisper log-mel features, and runs ONNX inference. 
+/// Buffers up to 8 seconds of audio internally. Call [`push_audio`] with +/// every incoming 16 kHz frame, then call [`predict`] when the VAD fires +/// end-of-speech to get a [`TurnPrediction`]. +/// +/// # Usage with VAD +/// +/// ```no_run +/// # #[cfg(feature = "pipecat")] +/// # { +/// use wavekat_turn::audio::PipecatSmartTurn; +/// use wavekat_turn::AudioTurnDetector; +/// +/// let mut detector = PipecatSmartTurn::new().unwrap(); +/// // ... feed frames via push_audio ... +/// let prediction = detector.predict().unwrap(); +/// println!("{:?} ({:.2})", prediction.state, prediction.confidence); +/// # } +/// ``` +/// +/// [`push_audio`]: AudioTurnDetector::push_audio +/// [`predict`]: AudioTurnDetector::predict pub struct PipecatSmartTurn { - // TODO: ONNX session + audio ring buffer + state + session: ort::session::Session, + ring_buffer: VecDeque, + mel: MelExtractor, } +// SAFETY: ort::Session is Send in ort 2.x. Sync is safe because every +// method that touches the session takes &mut self, preventing concurrent use. +unsafe impl Send for PipecatSmartTurn {} +unsafe impl Sync for PipecatSmartTurn {} + impl PipecatSmartTurn { - /// Create a new Smart Turn detector, loading the ONNX model. + /// Load the Smart Turn v3.2 model embedded at compile time. pub fn new() -> Result { - todo!("load Smart Turn v3 ONNX model") + let session = onnx::session_from_memory(MODEL_BYTES)?; + Ok(Self::build(session)) + } + + /// Load a model from a custom path on disk. + /// + /// Useful for CI environments that supply the model file separately, or + /// for evaluating fine-tuned variants without recompiling. 
+ pub fn from_file(path: impl AsRef) -> Result { + let session = onnx::session_from_file(path)?; + Ok(Self::build(session)) + } + + fn build(session: ort::session::Session) -> Self { + Self { + session, + ring_buffer: VecDeque::with_capacity(RING_CAPACITY), + mel: MelExtractor::new(), + } } } impl AudioTurnDetector for PipecatSmartTurn { - fn push_audio(&mut self, _frame: &AudioFrame) { - todo!("append to ring buffer") + /// Append audio to the internal ring buffer. + /// + /// Frames with a sample rate other than 16 kHz are silently dropped. + /// The ring buffer holds at most 8 s; older samples are evicted. + fn push_audio(&mut self, frame: &AudioFrame) { + if frame.sample_rate() != SAMPLE_RATE { + return; + } + let samples = frame.samples(); + // Evict oldest samples to make room + let overflow = (self.ring_buffer.len() + samples.len()).saturating_sub(RING_CAPACITY); + if overflow > 0 { + self.ring_buffer.drain(..overflow); + } + self.ring_buffer.extend(samples.iter().copied()); } + /// Run inference on the buffered audio. + /// + /// Takes a snapshot of the ring buffer, pads/truncates to 8 s, extracts + /// Whisper log-mel features, and runs ONNX inference. 
fn predict(&mut self) -> Result { - todo!("truncate/pad to 8s, extract mel features, run ONNX inference") + let t_start = Instant::now(); + + // Snapshot the ring buffer and prepare exactly 128 000 samples + let buffered: Vec = self.ring_buffer.iter().copied().collect(); + let audio = prepare_audio(&buffered); + + // Extract [N_MELS × N_FRAMES] log-mel features + let mel_spec = self.mel.extract(&audio); + + // Reshape to [1, N_MELS, N_FRAMES] for batch inference + let (raw, _) = mel_spec.into_raw_vec_and_offset(); + let input_array = Array3::from_shape_vec((1, N_MELS, N_FRAMES), raw) + .expect("internal: mel output has wrong element count"); + + let input_tensor = Tensor::from_array(input_array) + .map_err(|e| TurnError::BackendError(format!("failed to create input tensor: {e}")))?; + + let outputs = self + .session + .run(inputs!["input_features" => input_tensor]) + .map_err(|e| TurnError::BackendError(format!("inference failed: {e}")))?; + + // Extract sigmoid probability from the "logits" output + let output = outputs + .get("logits") + .ok_or_else(|| TurnError::BackendError("missing 'logits' output tensor".into()))?; + let (_, data): (_, &[f32]) = output + .try_extract_tensor() + .map_err(|e| TurnError::BackendError(format!("failed to extract logits: {e}")))?; + let probability = *data + .first() + .ok_or_else(|| TurnError::BackendError("logits tensor is empty".into()))?; + + let latency_ms = t_start.elapsed().as_millis() as u64; + + // probability = P(turn complete); > 0.5 means the speaker has finished + let (state, confidence) = if probability > 0.5 { + (TurnState::Finished, probability) + } else { + (TurnState::Unfinished, 1.0 - probability) + }; + + Ok(TurnPrediction { + state, + confidence, + latency_ms, + }) } + /// Clear the ring buffer. Call at the start of each new speech turn. 
fn reset(&mut self) { - todo!("clear ring buffer and internal state") + self.ring_buffer.clear(); } } diff --git a/crates/wavekat-turn/src/lib.rs b/crates/wavekat-turn/src/lib.rs index 75e22c9..9453791 100644 --- a/crates/wavekat-turn/src/lib.rs +++ b/crates/wavekat-turn/src/lib.rs @@ -18,6 +18,9 @@ pub mod error; +#[cfg(any(feature = "pipecat", feature = "livekit"))] +pub(crate) mod onnx; + #[cfg(feature = "pipecat")] pub mod audio; diff --git a/crates/wavekat-turn/src/onnx.rs b/crates/wavekat-turn/src/onnx.rs new file mode 100644 index 0000000..0387fcf --- /dev/null +++ b/crates/wavekat-turn/src/onnx.rs @@ -0,0 +1,41 @@ +//! Shared helpers for ONNX-based turn detection backends. + +use crate::error::TurnError; +use ort::session::Session; + +/// Create an ONNX Runtime session from a model file on disk. +pub(crate) fn session_from_file>(path: P) -> Result { + Session::builder() + .map_err(|e| TurnError::BackendError(format!("failed to create session builder: {e}")))? + .with_intra_threads(1) + .map_err(|e| TurnError::BackendError(format!("failed to set intra threads: {e}")))? + .commit_from_file(path) + .map_err(|e| TurnError::BackendError(format!("failed to load ONNX model: {e}"))) +} + +/// Create an ONNX Runtime session from model bytes in memory. +pub(crate) fn session_from_memory(model_bytes: &[u8]) -> Result { + Session::builder() + .map_err(|e| TurnError::BackendError(format!("failed to create session builder: {e}")))? + .with_intra_threads(1) + .map_err(|e| TurnError::BackendError(format!("failed to set intra threads: {e}")))? 
+ .commit_from_memory(model_bytes) + .map_err(|e| TurnError::BackendError(format!("failed to load ONNX model: {e}"))) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn session_from_file_nonexistent() { + let result = session_from_file("/nonexistent/path/to/model.onnx"); + assert!(matches!(result, Err(TurnError::BackendError(_)))); + } + + #[test] + fn session_from_memory_invalid_bytes() { + let result = session_from_memory(b"not a valid onnx model"); + assert!(matches!(result, Err(TurnError::BackendError(_)))); + } +} diff --git a/crates/wavekat-turn/tests/pipecat.rs b/crates/wavekat-turn/tests/pipecat.rs new file mode 100644 index 0000000..70059a4 --- /dev/null +++ b/crates/wavekat-turn/tests/pipecat.rs @@ -0,0 +1,151 @@ +//! Integration tests for the Pipecat Smart Turn v3 backend. +//! +//! Run with: `cargo test --features pipecat` +//! Run RTF test with: `cargo test --features pipecat --release` + +#![cfg(feature = "pipecat")] + +use wavekat_turn::audio::PipecatSmartTurn; +use wavekat_turn::{AudioFrame, AudioTurnDetector, TurnPrediction}; + +// --------------------------------------------------------------------------- +// Helpers +// --------------------------------------------------------------------------- + +/// Create an AudioFrame of silence (zeros) at 16 kHz. +fn silence(num_samples: usize) -> AudioFrame<'static> { + let samples = vec![0.0f32; num_samples]; + AudioFrame::new(samples.as_slice(), 16_000).into_owned() +} + +/// Push `duration_secs` of silence in 160-sample chunks (10 ms each). 
+fn push_silence(detector: &mut PipecatSmartTurn, duration_secs: f32) { + let total = (duration_secs * 16_000.0) as usize; + let chunk = 160; + let mut pushed = 0; + while pushed < total { + let n = chunk.min(total - pushed); + detector.push_audio(&silence(n)); + pushed += n; + } +} + +fn valid_prediction(pred: &TurnPrediction) { + assert!( + pred.confidence >= 0.0 && pred.confidence <= 1.0, + "confidence out of range: {}", + pred.confidence + ); +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[test] +fn test_new_loads_model() { + PipecatSmartTurn::new().expect("PipecatSmartTurn::new() should succeed"); +} + +#[test] +fn test_from_file_loads_model() { + let tmp = std::env::temp_dir().join("wavekat_turn_test"); + std::fs::create_dir_all(&tmp).unwrap(); + let path = tmp.join("smart-turn-test.onnx"); + + let model_bytes = include_bytes!(concat!(env!("OUT_DIR"), "/smart-turn-v3.2-cpu.onnx")); + std::fs::write(&path, model_bytes).unwrap(); + + PipecatSmartTurn::from_file(&path).expect("from_file should succeed with a valid model"); + + let _ = std::fs::remove_file(&path); +} + +#[test] +fn test_predict_returns_valid_output() { + let mut d = PipecatSmartTurn::new().unwrap(); + push_silence(&mut d, 2.0); + let pred = d.predict().unwrap(); + valid_prediction(&pred); +} + +#[test] +fn test_predict_with_empty_buffer() { + // Empty buffer is front-padded to 8 s of zeros; inference must succeed. + let mut d = PipecatSmartTurn::new().unwrap(); + let pred = d.predict().unwrap(); + valid_prediction(&pred); +} + +#[test] +fn test_push_audio_wrong_sample_rate_is_ignored() { + let mut d = PipecatSmartTurn::new().unwrap(); + let bad = AudioFrame::new(vec![0.5f32; 160].as_slice(), 8_000).into_owned(); + d.push_audio(&bad); + // Frame should have been dropped; predict must still succeed. 
+ let pred = d.predict().unwrap(); + valid_prediction(&pred); +} + +#[test] +fn test_reset_clears_buffer() { + let mut d = PipecatSmartTurn::new().unwrap(); + push_silence(&mut d, 4.0); + d.reset(); + // After reset the buffer is empty; should behave identically to a fresh instance. + let fresh = PipecatSmartTurn::new().unwrap().predict().unwrap(); + let after_reset = d.predict().unwrap(); + assert_eq!( + after_reset.state, fresh.state, + "state after reset should match a fresh instance" + ); + assert!( + (after_reset.confidence - fresh.confidence).abs() < 1e-5, + "confidence after reset should match a fresh instance" + ); +} + +#[test] +fn test_ring_buffer_caps_at_8_seconds() { + let mut d = PipecatSmartTurn::new().unwrap(); + push_silence(&mut d, 10.0); // 10 s > 8 s capacity; must not panic + valid_prediction(&d.predict().unwrap()); +} + +#[test] +fn test_multiple_predicts_are_deterministic() { + let mut d = PipecatSmartTurn::new().unwrap(); + push_silence(&mut d, 2.0); + let p1 = d.predict().unwrap(); + let p2 = d.predict().unwrap(); + assert_eq!(p1.state, p2.state, "repeated predict should give same state"); + assert!( + (p1.confidence - p2.confidence).abs() < 1e-5, + "repeated predict should give same confidence" + ); +} + +/// RTF target: < 50 ms. Only enforced in release builds because the debug +/// binary is ~10× slower. +#[test] +#[cfg(not(debug_assertions))] +fn test_latency_under_50ms() { + let mut d = PipecatSmartTurn::new().unwrap(); + push_silence(&mut d, 2.0); + let pred = d.predict().unwrap(); + assert!( + pred.latency_ms < 50, + "inference too slow: {} ms (limit: 50 ms)", + pred.latency_ms + ); +} + +/// Smoke test: latency is measured and non-zero (always runs, including debug). 
+#[test] +fn test_latency_is_measured() { + let mut d = PipecatSmartTurn::new().unwrap(); + push_silence(&mut d, 2.0); + let pred = d.predict().unwrap(); + // latency_ms == 0 would mean the timer wasn't working + assert!(pred.latency_ms < 60_000, "latency suspiciously large: {} ms", pred.latency_ms); +} From 6e0b87fb4aeec94a5048a006f2e6d8d8513ff5ee Mon Sep 17 00:00:00 2001 From: Eason WaveKat Date: Sat, 28 Mar 2026 18:17:28 +1300 Subject: [PATCH 4/7] docs: update backends plan and add accuracy cross-validation plan Co-Authored-By: Claude Sonnet 4.6 --- docs/plan-accuracy.md | 167 ++++++++++++++++++++++++++++++++++++++++++ docs/plan-backends.md | 154 ++++++++++++++------------------------ 2 files changed, 224 insertions(+), 97 deletions(-) create mode 100644 docs/plan-accuracy.md diff --git a/docs/plan-accuracy.md b/docs/plan-accuracy.md new file mode 100644 index 0000000..4d2d047 --- /dev/null +++ b/docs/plan-accuracy.md @@ -0,0 +1,167 @@ +# Plan: Cross-Validate Rust Implementation Against Python Reference + +**Status:** Not started +**Date:** 2026-03-28 + +--- + +## Goal + +Verify that our Rust mel preprocessing and ONNX inference pipeline produces probabilities that +match the original Pipecat Python implementation within a tight tolerance. This catches any +silent preprocessing mismatch (wrong mel scale, wrong window, wrong padding, wrong normalization) +that unit tests cannot detect because they only check output shape and range. + +--- + +## Why this matters + +The Rust implementation re-implements `WhisperFeatureExtractor` from scratch. Any divergence +in the mel filterbank, Hann window, STFT center-padding, or log-normalization will silently +shift probabilities. A 5% probability shift near the 0.5 threshold flips a turn decision. +The only way to be confident the two pipelines agree is to feed them identical audio and +compare outputs numerically. 
+ +--- + +## Decisions made + +- **Python script generates reference data once.** Output is committed as a JSON fixture + (`tests/fixtures/reference.json`) so the Rust accuracy test has no Python runtime dependency. +- **Tolerance: ±0.02 probability.** The model uses float32 throughout; numerical differences + from the Rust FFT vs NumPy FFT should be well under this. If a case fails, it signals a + real preprocessing bug, not floating-point noise. +- **Three fixture audio clips.** Enough to cover the key behavioral regions without large + binary assets. Total fixture size should stay under ~500 KB. + +--- + +## Fixture audio clips + +| File | Content | Expected region | +|------|---------|-----------------| +| `tests/fixtures/silence_2s.wav` | 2 s of zeros at 16 kHz | Low P(complete) — no speech | +| `tests/fixtures/speech_finished.wav` | Real or synthetic utterance that ends cleanly | High P(complete) | +| `tests/fixtures/speech_mid.wav` | Real or synthetic utterance cut mid-word | Low P(complete) | + +WAV format: 16 kHz, mono, 16-bit PCM (hound-compatible). + +For `speech_finished.wav` and `speech_mid.wav`, use short clips (1–3 s) from a freely +licensed speech corpus, or generate synthetic speech with a TTS tool. Commit the WAVs +directly — they are small enough. + +--- + +## Phase 1 — Python reference script + +Create `scripts/gen_reference.py`. + +**Dependencies** (not added to the crate — Python only): +``` +pip install pipecat-ai transformers onnxruntime numpy soundfile +``` + +**What it does:** +1. Downloads `smart-turn-v3.2-cpu.onnx` if not already present (same URL as build.rs) +2. For each WAV in `tests/fixtures/`: + - Loads audio via `soundfile` as float32 at 16 kHz + - Runs `WhisperFeatureExtractor(chunk_length=8)` to get `input_features` + - Runs `ort.InferenceSession` on the ONNX model + - Records `{ "file": "...", "probability": <float> }` +3. 
Writes `tests/fixtures/reference.json` + +**Re-run when:** +- A fixture WAV changes +- The model version changes (bump `MODEL_VERSION` in `build.rs` at the same time) + +--- + +## Phase 2 — WAV fixtures + +Generate or source the three WAV clips and commit them to `tests/fixtures/`. + +For `silence_2s.wav`: +```python +import numpy as np, soundfile as sf +sf.write("tests/fixtures/silence_2s.wav", np.zeros(32000, dtype=np.float32), 16000) +``` + +For speech clips, options in order of preference: +1. Record 2–3 s clips specifically for this test +2. Use a clip from [CMU Arctic](http://www.festvox.org/cmu_arctic/) or + [LJ Speech](https://keithito.com/LJ-Speech-Dataset/) (both freely licensed: LJ Speech is public domain, CMU Arctic uses a permissive BSD-style license; verify before committing) +3. Generate with `piper` TTS (Apache 2.0) + +--- + +## Phase 3 — Run the Python script and commit reference.json + +```bash +python scripts/gen_reference.py +``` + +Inspect the output — confirm the probabilities make sense (silence ≈ low, finished ≈ high). +Commit `tests/fixtures/reference.json` alongside the WAV files. + +`reference.json` format: +```json +[ + { "file": "silence_2s.wav", "probability": 0.03 }, + { "file": "speech_finished.wav", "probability": 0.91 }, + { "file": "speech_mid.wav", "probability": 0.08 } +] +``` + +--- + +## Phase 4 — Rust accuracy test + +Add `tests/accuracy.rs` (under `#[cfg(feature = "pipecat")]`). + +**What it does:** +1. Reads `tests/fixtures/reference.json` at test time +2. For each entry, loads the corresponding WAV with `hound` +3. Pushes all audio frames through `PipecatSmartTurn` +4. Calls `predict()` and reads the raw probability +5. Asserts `|rust_prob - python_prob| <= TOLERANCE` (0.02) + +```rust +const TOLERANCE: f32 = 0.02; +``` + +**Getting the raw probability out of `TurnPrediction`:** +`TurnPrediction.confidence` is derived directly from the raw sigmoid value (we set it to `probability` for +`Finished` and `1.0 - probability` for `Unfinished`). 
To recover the original probability: + +```rust +let raw_prob = match pred.state { + TurnState::Finished => pred.confidence, + TurnState::Unfinished => 1.0 - pred.confidence, + TurnState::Wait => unreachable!(), +}; +``` + +Alternatively, expose `raw_probability: f32` directly on `TurnPrediction` — see open questions. + +**Test names:** +- `test_accuracy_silence` +- `test_accuracy_speech_finished` +- `test_accuracy_speech_mid` + +--- + +## Open questions + +1. **Expose raw probability on `TurnPrediction`?** + Currently the struct only has `confidence` which loses the original P(complete) for + `Unfinished` cases. Options: + - Add `raw_probability: f32` field to `TurnPrediction` (cleaner, but changes the public API) + - Reconstruct from `(state, confidence)` in the test (works, but fragile) + Resolve before starting Phase 4. + +2. **Speech fixture source.** Decide on LJ Speech clips or recorded clips before Phase 2. + LJ Speech is easiest (download a sentence, trim to 2–3 s). Record the chosen file name + and source URL in a comment in `scripts/gen_reference.py`. + +3. **CI integration.** The accuracy test needs the WAV fixtures and `reference.json` committed + to the repo. Confirm the total asset size is acceptable before merging. diff --git a/docs/plan-backends.md b/docs/plan-backends.md index 2a62c65..ec17c97 100644 --- a/docs/plan-backends.md +++ b/docs/plan-backends.md @@ -1,6 +1,6 @@ # Plan: Implement Turn Detection Backends -**Status:** In progress +**Status:** Phase 1–4 complete **Date:** 2026-03-28 --- @@ -26,22 +26,25 @@ ## Current state -Both backends are stubs with `todo!()` — the crate compiles but cannot run inference. +`PipecatSmartTurn` is fully implemented and all integration tests pass. +`LiveKitEou` remains a stub (out of scope for this branch). 
``` src/ ├── lib.rs — traits: AudioTurnDetector, TextTurnDetector, TurnPrediction, TurnState ├── error.rs — TurnError: BackendError, InvalidInput, ModelNotLoaded +├── onnx.rs — shared session_from_file / session_from_memory helpers ├── audio/ │ ├── mod.rs -│ └── pipecat.rs — PipecatSmartTurn (stub) +│ └── pipecat.rs — PipecatSmartTurn (complete) └── text/ ├── mod.rs - └── livekit.rs — LiveKitEou (stub) + └── livekit.rs — LiveKitEou (stub, out of scope) +build.rs — downloads smart-turn-v3.2-cpu.onnx at build time +tests/ +└── pipecat.rs — 9 integration tests (all pass) ``` -No `build.rs` yet. No tests. - --- ## Trait API (stable, do not change) @@ -65,121 +68,78 @@ pub trait TextTurnDetector: Send + Sync { --- -## Phase 1 — Research (prerequisite) - -Before writing any code, pin down the model specifics: - -1. **Model source.** Find the official Pipecat Smart Turn v3 ONNX download URL (Pipecat GitHub - releases or Hugging Face). Confirm license (BSD 2-Clause noted in stub comments). +## Phase 1 — Research ✅ -2. **Input/output tensor shapes.** Load the model in a scratch script or `netron` and record: - - Input tensor: name, shape, dtype - - Output tensor: name(s), shape, dtype - - Whether output is a single confidence float or logits for [Finished, Unfinished, Wait] +**Done.** Findings pinned in `src/audio/pipecat.rs` module-level comments. -3. **Mel-feature spec.** Confirm what preprocessing the model expects: - - Frame size + hop length - - Number of mel bins (Whisper uses 80) - - Frequency range - - Mel scale formula (HTK vs Kaldi) - - Whether pre-emphasis is applied - -4. **Audio buffer length.** Stub says "up to 8 seconds" — confirm from model input shape. - -Document findings as comments in `pipecat.rs` before implementation. 
+| Item | Finding | +|------|---------| +| Model URL | `https://huggingface.co/pipecat-ai/smart-turn-v3/resolve/main/smart-turn-v3.2-cpu.onnx` | +| Input tensor | `input_features`, shape `[B, 80, 800]`, float32 | +| Output tensor | `logits`, shape `[B, 1]`, float32 — sigmoid P(turn complete), NOT raw logits | +| Mel scale | **Slaney** (NOT HTK); `norm="slaney"` | +| n_fft / hop | 400 / 160 samples (25 ms / 10 ms at 16 kHz) | +| Mel bins | 80; frequency range 0–8 000 Hz | +| Window | Periodic Hann (`torch.hann_window(400, periodic=True)`) | +| Pre-emphasis | None | +| Log norm | `log10`, clamp `[max−8, ∞]`, then `(x + 4) / 4` | +| Audio buffer | 8 s = 128 000 samples; front-pad shorter, keep last 8 s for longer | +| License | BSD 2-Clause | --- -## Phase 2 — Build system - -Create `crates/wavekat-turn/build.rs` following the wavekat-vad pattern. +## Phase 2 — Build system ✅ -This phase covers the **embedded** path only (Pipecat). Large-model backends (LiveKit) need -a different build.rs strategy described in Phase 5. +**Done.** -- Download Smart Turn v3 ONNX to `OUT_DIR` at build time -- SHA-256 verification -- Env-var overrides: - - `PIPECAT_SMARTTURN_MODEL_PATH` — use a local file instead of downloading - - `PIPECAT_SMARTTURN_MODEL_URL` — override download URL -- Docs.rs guard: write a zero-byte placeholder when `DOCS_RS=1` +- `build.rs` downloads `smart-turn-v3.2-cpu.onnx` to `OUT_DIR` with version-based caching +- Env-var overrides: `PIPECAT_SMARTTURN_MODEL_PATH`, `PIPECAT_SMARTTURN_MODEL_URL` +- Docs.rs guard writes a zero-byte placeholder when `DOCS_RS=1` +- `Cargo.toml`: `build = "build.rs"`, `ureq` as optional build-dep activated by `pipecat` feature -Add to `Cargo.toml`: -```toml -[package] -build = "build.rs" - -[build-dependencies] -ureq = { version = "3", features = ["tls"] } -``` +Note: SHA-256 verification was omitted in favour of version-based caching (same as wavekat-vad). 
--- -## Phase 3 — PipecatSmartTurn implementation - -Fill in `src/audio/pipecat.rs`: - -**Struct:** -```rust -pub struct PipecatSmartTurn { - session: Session, - ring_buffer: VecDeque, // 8s × 16kHz = 128k samples - // mel extractor fields TBD from Phase 1 research -} -``` - -**Constructors:** - -```rust -/// Default constructor — loads the model embedded at compile time. -pub fn new() -> Result { ... } - -/// Load a custom model from disk — useful for fine-tuned weights or CI environments -/// where the binary should stay small and the model is provided separately. -pub fn from_file(path: impl AsRef) -> Result { ... } -``` - -`new()` calls `session_from_memory(include_bytes!(concat!(env!("OUT_DIR"), "/...")))`. -`from_file()` calls `session_from_file(path)`. Both share `Self::build(session)` for -the rest of initialization. - -**`push_audio()`** — validate sample rate (16 kHz), convert i16→f32 if needed, -append to ring buffer (evict oldest when over capacity). - -**`predict()`** — snapshot ring buffer, pad/truncate to model's expected length, -extract mel features, build ndarray input tensor, `session.run(...)`, parse output, -record `Instant` before/after for `latency_ms`. +## Phase 3 — PipecatSmartTurn implementation ✅ -**`reset()`** — `ring_buffer.clear()`. +**Done.** `src/audio/pipecat.rs` and `src/onnx.rs` written and compiling. 
-Reference implementations: -- `wavekat-vad/src/backends/silero.rs` — ONNX session + state management -- `wavekat-vad/src/backends/onnx.rs` — session builder helper -- `wavekat-vad/src/backends/firered/fbank.rs` — mel filterbank (adapt if spec matches) +Key implementation decisions: +- `MelExtractor` precomputes the Slaney filterbank matrix and Hann window once at construction; + reuses FFT plan and scratch buffers across calls +- Center-pad (`N_FFT/2` zeros each side) replicates librosa `center=True` STFT, producing + exactly 800 frames from 128 000 samples +- `push_audio` silently drops frames with wrong sample rate (no return value in trait) +- `ndarray = "0.17"` required to match `ort`'s ndarray feature version --- -## Phase 4 — Tests +## Phase 4 — Tests ✅ -Add `tests/pipecat.rs` (integration tests under `#[cfg(feature = "pipecat")]`): +**Done.** `tests/pipecat.rs` with 9 integration tests, all passing: -- `test_new_loads_model` — `PipecatSmartTurn::new()` succeeds -- `test_from_file_loads_model` — `PipecatSmartTurn::from_file(path)` succeeds given a valid path -- `test_predict_silence` — feed 2s of zeros, expect low confidence -- `test_predict_finished` — feed known-good finished-turn audio (WAV fixture), expect - `TurnState::Finished` with confidence > 0.7 -- `test_reset_clears_buffer` — push audio, reset, predict on empty buffer returns low confidence -- `test_rtf` — assert `latency_ms` < 50 ms (well under the ~12 ms target with headroom for CI) +| Test | What it checks | +|------|---------------| +| `test_new_loads_model` | `new()` succeeds | +| `test_from_file_loads_model` | `from_file()` succeeds with a valid path | +| `test_predict_returns_valid_output` | confidence ∈ [0, 1] | +| `test_predict_with_empty_buffer` | empty buffer inference succeeds | +| `test_push_audio_wrong_sample_rate_is_ignored` | 8 kHz frame is dropped | +| `test_reset_clears_buffer` | state after reset matches fresh instance | +| `test_ring_buffer_caps_at_8_seconds` | 10 s of audio 
doesn't panic | +| `test_multiple_predicts_are_deterministic` | same buffer → same output | +| `test_latency_under_50ms` | RTF < 50 ms (release builds only) | -Add a small WAV fixture (`tests/fixtures/finished_turn.wav`) for the audio test cases. +Note: the current tests do not cross-validate against the Python reference implementation. +That is tracked in **[`plan-accuracy.md`](plan-accuracy.md)**. --- ## Open questions -1. **Smart Turn v3 model URL** — not yet confirmed (needed for Phase 2) -2. **Exact input tensor shape** — need to inspect the model (needed for Phase 3) -3. **Mel-feature spec** — need to confirm to avoid silent preprocessing mismatch +All research questions from Phases 1–3 are resolved. No blocking open questions remain +for this branch. --- From 3a0ef7ca84960498ede8c14a2604d6e847102af7 Mon Sep 17 00:00:00 2001 From: Eason WaveKat Date: Sat, 28 Mar 2026 18:25:41 +1300 Subject: [PATCH 5/7] test: add from_file invalid path error test Co-Authored-By: Claude Sonnet 4.6 --- crates/wavekat-turn/tests/pipecat.rs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/crates/wavekat-turn/tests/pipecat.rs b/crates/wavekat-turn/tests/pipecat.rs index 70059a4..74a5c30 100644 --- a/crates/wavekat-turn/tests/pipecat.rs +++ b/crates/wavekat-turn/tests/pipecat.rs @@ -140,6 +140,12 @@ fn test_latency_under_50ms() { ); } +#[test] +fn test_from_file_invalid_path_returns_error() { + let result = PipecatSmartTurn::from_file("/nonexistent/path/model.onnx"); + assert!(result.is_err(), "from_file with invalid path should return an error"); +} + /// Smoke test: latency is measured and non-zero (always runs, including debug). 
#[test] fn test_latency_is_measured() { From 07bc4f64c0f79ebc3f87573b9a8becacf12b63ac Mon Sep 17 00:00:00 2001 From: Eason WaveKat Date: Sat, 28 Mar 2026 22:02:58 +1300 Subject: [PATCH 6/7] feat: add stage timings and incremental STFT to PipecatSmartTurn MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add StageTiming struct and stage_times field to TurnPrediction - Instrument predict() with three stages: audio_prep, mel, onnx - Incremental STFT: cache power spectrogram and shift on each call, recomputing only the ~50 new frames instead of all 801 (~16x faster) - Incremental mel filterbank: cache mel_spec and update only new columns, reducing matmul from [80×201]×[201×800] to [80×201]×[201×50] (~16x) - Invalidate caches on reset() Co-Authored-By: Claude Sonnet 4.6 --- crates/wavekat-turn/src/audio/pipecat.rs | 102 ++++++++++++++++++++--- crates/wavekat-turn/src/lib.rs | 11 +++ 2 files changed, 100 insertions(+), 13 deletions(-) diff --git a/crates/wavekat-turn/src/audio/pipecat.rs b/crates/wavekat-turn/src/audio/pipecat.rs index 170ad0e..a9d4ae7 100644 --- a/crates/wavekat-turn/src/audio/pipecat.rs +++ b/crates/wavekat-turn/src/audio/pipecat.rs @@ -54,7 +54,7 @@ use ort::{inputs, value::Tensor}; use realfft::num_complex::Complex; use realfft::{RealFftPlanner, RealToComplex}; -use crate::{AudioFrame, AudioTurnDetector, TurnError, TurnPrediction, TurnState}; +use crate::{AudioFrame, AudioTurnDetector, StageTiming, TurnError, TurnPrediction, TurnState}; use crate::onnx; // --------------------------------------------------------------------------- @@ -99,6 +99,12 @@ struct MelExtractor { fft_scratch: Vec>, /// Reusable output spectrum buffer (N_FREQS complex values). spectrum_buf: Vec>, + /// Cached power spectrogram [N_FREQS × (N_FRAMES+1)] from the previous call. + /// Enables incremental STFT: only new frames are recomputed. 
+ cached_power_spec: Option>, + /// Cached mel spectrogram [N_MELS × N_FRAMES] from the previous call. + /// Enables incremental mel filterbank: only new columns are recomputed. + cached_mel_spec: Option>, } impl MelExtractor { @@ -123,12 +129,19 @@ impl MelExtractor { fft, fft_scratch, spectrum_buf, + cached_power_spec: None, + cached_mel_spec: None, } } /// Compute a [N_MELS × N_FRAMES] log-mel spectrogram from exactly /// `RING_CAPACITY` samples of 16 kHz mono audio. - fn extract(&mut self, audio: &[f32]) -> Array2 { + /// + /// `shift_frames` is how many STFT frames worth of new audio were added + /// since the last call. When a valid cache exists and `shift_frames` is + /// in range, only the last `shift_frames` columns of the power spectrogram + /// are recomputed; the rest are copied from the shifted cache. + fn extract(&mut self, audio: &[f32], shift_frames: usize) -> Array2 { debug_assert_eq!(audio.len(), RING_CAPACITY); // ---- Center-pad: N_FFT/2 zeros on each side → 128 400 samples ---- @@ -141,11 +154,30 @@ impl MelExtractor { // n_total = (128 400 − 400) / 160 + 1 = 801 let n_total_frames = (padded.len() - N_FFT) / HOP_LENGTH + 1; - // ---- STFT → power spectrogram [N_FREQS × n_total_frames] ---- - let mut power_spec = Array2::::zeros((N_FREQS, n_total_frames)); + // ---- Incremental STFT ---- + // If we have a cached power spec and shift_frames < n_total_frames, + // reuse the unchanged frames by shifting the cache left and only + // computing the `shift_frames` new columns at the end. 
+ let first_new_frame = match &self.cached_power_spec { + Some(cached) if shift_frames > 0 && shift_frames < n_total_frames => { + let kept = n_total_frames - shift_frames; + let mut power_spec = Array2::::zeros((N_FREQS, n_total_frames)); + power_spec + .slice_mut(s![.., ..kept]) + .assign(&cached.slice(s![.., shift_frames..])); + self.cached_power_spec = Some(power_spec); + kept // only compute frames [kept..n_total_frames] + } + _ => { + self.cached_power_spec = Some(Array2::::zeros((N_FREQS, n_total_frames))); + 0 // cold start: compute all frames + } + }; + + let power_spec = self.cached_power_spec.as_mut().unwrap(); let mut frame_buf = vec![0.0f32; N_FFT]; - for frame_idx in 0..n_total_frames { + for frame_idx in first_new_frame..n_total_frames { let start = frame_idx * HOP_LENGTH; // Apply periodic Hann window for (i, (&s, &w)) in padded[start..start + N_FFT] @@ -166,10 +198,27 @@ impl MelExtractor { } // Take first N_FRAMES columns (drop the trailing frame) - let power_spec = power_spec.slice(s![.., ..N_FRAMES]).to_owned(); - - // ---- Apply mel filterbank: [N_MELS, N_FREQS] × [N_FREQS, N_FRAMES] ---- - let mel_spec = self.mel_filters.dot(&power_spec); // [80, 800] + let power_spec_view = power_spec.slice(s![.., ..N_FRAMES]); + + // ---- Incremental mel filterbank: [N_MELS, N_FREQS] × [N_FREQS, shift_frames] ---- + // Reuse the cached mel columns for the unchanged frames; only multiply + // the new power-spectrum columns against the filterbank. 
+ let mel_spec = match &self.cached_mel_spec { + Some(cached) if shift_frames > 0 && shift_frames <= N_FRAMES => { + let kept = N_FRAMES - shift_frames; + let mut ms = Array2::::zeros((N_MELS, N_FRAMES)); + // Shift old columns left + ms.slice_mut(s![.., ..kept]) + .assign(&cached.slice(s![.., shift_frames..])); + // Apply filterbank only to the new power-spectrum columns + let new_power = power_spec_view.slice(s![.., kept..]); + ms.slice_mut(s![.., kept..]) + .assign(&self.mel_filters.dot(&new_power)); + ms + } + _ => self.mel_filters.dot(&power_spec_view), + }; + self.cached_mel_spec = Some(mel_spec.clone()); // ---- Log10 with floor at 1e-10 ---- let mut log_mel = mel_spec.mapv(|x| x.max(1e-10_f32).log10()); @@ -181,6 +230,12 @@ impl MelExtractor { log_mel } + + /// Invalidate all caches (call on reset). + fn invalidate_cache(&mut self) { + self.cached_power_spec = None; + self.cached_mel_spec = None; + } } // --------------------------------------------------------------------------- @@ -324,6 +379,9 @@ pub struct PipecatSmartTurn { session: ort::session::Session, ring_buffer: VecDeque, mel: MelExtractor, + /// Counts samples pushed since the last `predict()` call. + /// Used to compute `shift_frames` for incremental STFT. + samples_since_predict: usize, } // SAFETY: ort::Session is Send in ort 2.x. Sync is safe because every @@ -352,6 +410,7 @@ impl PipecatSmartTurn { session, ring_buffer: VecDeque::with_capacity(RING_CAPACITY), mel: MelExtractor::new(), + samples_since_predict: 0, } } } @@ -372,6 +431,7 @@ impl AudioTurnDetector for PipecatSmartTurn { self.ring_buffer.drain(..overflow); } self.ring_buffer.extend(samples.iter().copied()); + self.samples_since_predict += samples.len(); } /// Run inference on the buffered audio. 
@@ -381,14 +441,19 @@ impl AudioTurnDetector for PipecatSmartTurn { fn predict(&mut self) -> Result { let t_start = Instant::now(); - // Snapshot the ring buffer and prepare exactly 128 000 samples + // Stage 1: Snapshot the ring buffer and prepare exactly 128 000 samples + let shift_frames = self.samples_since_predict / HOP_LENGTH; + self.samples_since_predict = 0; + let buffered: Vec = self.ring_buffer.iter().copied().collect(); let audio = prepare_audio(&buffered); + let t_after_audio_prep = Instant::now(); - // Extract [N_MELS × N_FRAMES] log-mel features - let mel_spec = self.mel.extract(&audio); + // Stage 2: Extract [N_MELS × N_FRAMES] log-mel features (incremental) + let mel_spec = self.mel.extract(&audio, shift_frames); + let t_after_mel = Instant::now(); - // Reshape to [1, N_MELS, N_FRAMES] for batch inference + // Stage 3: Reshape to [1, N_MELS, N_FRAMES] and run ONNX inference let (raw, _) = mel_spec.into_raw_vec_and_offset(); let input_array = Array3::from_shape_vec((1, N_MELS, N_FRAMES), raw) .expect("internal: mel output has wrong element count"); @@ -400,6 +465,7 @@ impl AudioTurnDetector for PipecatSmartTurn { .session .run(inputs!["input_features" => input_tensor]) .map_err(|e| TurnError::BackendError(format!("inference failed: {e}")))?; + let t_after_onnx = Instant::now(); // Extract sigmoid probability from the "logits" output let output = outputs @@ -414,6 +480,13 @@ impl AudioTurnDetector for PipecatSmartTurn { let latency_ms = t_start.elapsed().as_millis() as u64; + let us = |a: Instant, b: Instant| (b - a).as_secs_f64() * 1_000_000.0; + let stage_times = vec![ + StageTiming { name: "audio_prep", us: us(t_start, t_after_audio_prep) }, + StageTiming { name: "mel", us: us(t_after_audio_prep, t_after_mel) }, + StageTiming { name: "onnx", us: us(t_after_mel, t_after_onnx) }, + ]; + // probability = P(turn complete); > 0.5 means the speaker has finished let (state, confidence) = if probability > 0.5 { (TurnState::Finished, probability) @@ -425,11 
+498,14 @@ impl AudioTurnDetector for PipecatSmartTurn { state, confidence, latency_ms, + stage_times, }) } /// Clear the ring buffer. Call at the start of each new speech turn. fn reset(&mut self) { self.ring_buffer.clear(); + self.samples_since_predict = 0; + self.mel.invalidate_cache(); } } diff --git a/crates/wavekat-turn/src/lib.rs b/crates/wavekat-turn/src/lib.rs index 9453791..98ca819 100644 --- a/crates/wavekat-turn/src/lib.rs +++ b/crates/wavekat-turn/src/lib.rs @@ -41,12 +41,23 @@ pub enum TurnState { Wait, } +/// Per-stage timing entry. +#[derive(Debug, Clone)] +pub struct StageTiming { + /// Stage name (e.g. "audio_prep", "mel", "onnx"). + pub name: &'static str, + /// Time in microseconds for this stage. + pub us: f64, +} + /// A turn detection prediction with confidence and timing metadata. #[derive(Debug, Clone)] pub struct TurnPrediction { pub state: TurnState, pub confidence: f32, pub latency_ms: u64, + /// Per-stage timing breakdown in pipeline order. + pub stage_times: Vec, } /// A single turn in the conversation, for context-aware text detectors. From f11038a7d6ffea790e27b5ed1c1ada44d5378b8a Mon Sep 17 00:00:00 2001 From: Eason WaveKat Date: Sat, 28 Mar 2026 22:30:52 +1300 Subject: [PATCH 7/7] fix: resolve clippy and fmt issues from CI Co-Authored-By: Claude Sonnet 4.6 --- crates/wavekat-turn/build.rs | 13 ++++----- crates/wavekat-turn/src/audio/pipecat.rs | 36 ++++++++++++++++++------ crates/wavekat-turn/tests/pipecat.rs | 16 +++++++++-- 3 files changed, 46 insertions(+), 19 deletions(-) diff --git a/crates/wavekat-turn/build.rs b/crates/wavekat-turn/build.rs index 46755bc..4bb440a 100644 --- a/crates/wavekat-turn/build.rs +++ b/crates/wavekat-turn/build.rs @@ -21,7 +21,10 @@ use std::path::Path; fn main() { // docs.rs builds with --network none; write empty placeholders so // include_bytes! compiles without downloading anything. 
- if env::var("DOCS_RS").is_ok() { + if env::var("DOCS_RS").is_err() { + #[cfg(feature = "pipecat")] + setup_pipecat_model(); + } else { #[cfg(feature = "pipecat")] { let out_dir = env::var("OUT_DIR").expect("OUT_DIR not set"); @@ -30,11 +33,7 @@ fn main() { fs::write(&model_path, b"").expect("failed to write placeholder model"); } } - return; } - - #[cfg(feature = "pipecat")] - setup_pipecat_model(); } #[cfg(feature = "pipecat")] @@ -78,8 +77,8 @@ fn setup_pipecat_model() { } // Option 2: download (caller may override the URL) - let url = env::var("PIPECAT_SMARTTURN_MODEL_URL") - .unwrap_or_else(|_| DEFAULT_MODEL_URL.to_string()); + let url = + env::var("PIPECAT_SMARTTURN_MODEL_URL").unwrap_or_else(|_| DEFAULT_MODEL_URL.to_string()); println!("cargo:warning=Downloading Pipecat Smart Turn model ({MODEL_VERSION}) from {url}"); diff --git a/crates/wavekat-turn/src/audio/pipecat.rs b/crates/wavekat-turn/src/audio/pipecat.rs index a9d4ae7..4554b1d 100644 --- a/crates/wavekat-turn/src/audio/pipecat.rs +++ b/crates/wavekat-turn/src/audio/pipecat.rs @@ -54,8 +54,8 @@ use ort::{inputs, value::Tensor}; use realfft::num_complex::Complex; use realfft::{RealFftPlanner, RealToComplex}; -use crate::{AudioFrame, AudioTurnDetector, StageTiming, TurnError, TurnPrediction, TurnState}; use crate::onnx; +use crate::{AudioFrame, AudioTurnDetector, StageTiming, TurnError, TurnPrediction, TurnState}; // --------------------------------------------------------------------------- // Constants @@ -77,8 +77,7 @@ const N_FREQS: usize = N_FFT / 2 + 1; // 201 const RING_CAPACITY: usize = 8 * SAMPLE_RATE as usize; // 128 000 /// Embedded ONNX model bytes, downloaded by build.rs at compile time. 
-const MODEL_BYTES: &[u8] = - include_bytes!(concat!(env!("OUT_DIR"), "/smart-turn-v3.2-cpu.onnx")); +const MODEL_BYTES: &[u8] = include_bytes!(concat!(env!("OUT_DIR"), "/smart-turn-v3.2-cpu.onnx")); // --------------------------------------------------------------------------- // Mel feature extractor @@ -189,7 +188,11 @@ impl MelExtractor { } self.fft - .process_with_scratch(&mut frame_buf, &mut self.spectrum_buf, &mut self.fft_scratch) + .process_with_scratch( + &mut frame_buf, + &mut self.spectrum_buf, + &mut self.fft_scratch, + ) .expect("FFT failed: internal buffer size mismatch"); for (k, c) in self.spectrum_buf.iter().enumerate() { @@ -247,7 +250,7 @@ fn hz_to_mel(hz: f32) -> f32 { const F_SP: f32 = 200.0 / 3.0; // linear region slope (Hz per mel) const MIN_LOG_HZ: f32 = 1000.0; const MIN_LOG_MEL: f32 = MIN_LOG_HZ / F_SP; // = 15.0 - // logstep = ln(6.4) / 27 (≈ 0.068752) + // logstep = ln(6.4) / 27 (≈ 0.068752) let logstep = (6.4_f32).ln() / 27.0; if hz >= MIN_LOG_HZ { MIN_LOG_MEL + (hz / MIN_LOG_HZ).ln() / logstep @@ -274,7 +277,13 @@ fn mel_to_hz(mel: f32) -> f32 { /// Matches `librosa.filters.mel(sr, n_fft, n_mels, fmin, fmax, /// norm="slaney", dtype=float32)` which is what HuggingFace's /// `WhisperFeatureExtractor` uses internally. 
-fn build_mel_filters(sr: usize, n_fft: usize, n_mels: usize, f_min: f32, f_max: f32) -> Array2 { +fn build_mel_filters( + sr: usize, + n_fft: usize, + n_mels: usize, + f_min: f32, + f_max: f32, +) -> Array2 { let n_freqs = n_fft / 2 + 1; // FFT frequency bins: 0, sr/n_fft, 2·sr/n_fft, … @@ -482,9 +491,18 @@ impl AudioTurnDetector for PipecatSmartTurn { let us = |a: Instant, b: Instant| (b - a).as_secs_f64() * 1_000_000.0; let stage_times = vec![ - StageTiming { name: "audio_prep", us: us(t_start, t_after_audio_prep) }, - StageTiming { name: "mel", us: us(t_after_audio_prep, t_after_mel) }, - StageTiming { name: "onnx", us: us(t_after_mel, t_after_onnx) }, + StageTiming { + name: "audio_prep", + us: us(t_start, t_after_audio_prep), + }, + StageTiming { + name: "mel", + us: us(t_after_audio_prep, t_after_mel), + }, + StageTiming { + name: "onnx", + us: us(t_after_mel, t_after_onnx), + }, ]; // probability = P(turn complete); > 0.5 means the speaker has finished diff --git a/crates/wavekat-turn/tests/pipecat.rs b/crates/wavekat-turn/tests/pipecat.rs index 74a5c30..0804308 100644 --- a/crates/wavekat-turn/tests/pipecat.rs +++ b/crates/wavekat-turn/tests/pipecat.rs @@ -118,7 +118,10 @@ fn test_multiple_predicts_are_deterministic() { push_silence(&mut d, 2.0); let p1 = d.predict().unwrap(); let p2 = d.predict().unwrap(); - assert_eq!(p1.state, p2.state, "repeated predict should give same state"); + assert_eq!( + p1.state, p2.state, + "repeated predict should give same state" + ); assert!( (p1.confidence - p2.confidence).abs() < 1e-5, "repeated predict should give same confidence" @@ -143,7 +146,10 @@ fn test_latency_under_50ms() { #[test] fn test_from_file_invalid_path_returns_error() { let result = PipecatSmartTurn::from_file("/nonexistent/path/model.onnx"); - assert!(result.is_err(), "from_file with invalid path should return an error"); + assert!( + result.is_err(), + "from_file with invalid path should return an error" + ); } /// Smoke test: latency is measured and 
non-zero (always runs, including debug). @@ -153,5 +159,9 @@ fn test_latency_is_measured() { push_silence(&mut d, 2.0); let pred = d.predict().unwrap(); // latency_ms == 0 would mean the timer wasn't working - assert!(pred.latency_ms < 60_000, "latency suspiciously large: {} ms", pred.latency_ms); + assert!( + pred.latency_ms < 60_000, + "latency suspiciously large: {} ms", + pred.latency_ms + ); }