Merged
7 changes: 6 additions & 1 deletion Makefile
@@ -1,4 +1,4 @@
.PHONY: help check test fmt lint doc ci accuracy mel
.PHONY: help check test fmt lint doc ci accuracy mel example-controller

help:
@echo "Available targets:"
@@ -10,6 +10,7 @@ help:
@echo " lint Run clippy with warnings as errors"
@echo " doc Build and open docs in browser"
@echo " ci Run all CI checks locally (fmt, clippy, test, doc, features)"
@echo " example-controller Run TurnController example"

# Check workspace compiles
check:
@@ -27,6 +28,10 @@ accuracy:
mel:
cargo test --features pipecat -- mel_report --ignored --nocapture

# Run TurnController example
example-controller:
cargo run --features pipecat --example controller

# Format code
fmt:
cargo fmt --all
29 changes: 22 additions & 7 deletions README.md
@@ -12,7 +12,7 @@ models behind common Rust traits. Same pattern as
[wavekat-vad](https://github.com/wavekat/wavekat-vad).

> [!WARNING]
> Early development. Trait API is defined; backend implementations are stubs pending ONNX model integration.
> Early development. API may change between minor versions.

## Backends

@@ -27,25 +27,34 @@ models behind common Rust traits. Same pattern as
cargo add wavekat-turn --features pipecat
```

Use the audio-based detector:
Use `TurnController` to wrap any detector with automatic state tracking:

```rust
use wavekat_turn::{AudioTurnDetector, TurnState};
use wavekat_turn::{TurnController, TurnState};
use wavekat_turn::audio::PipecatSmartTurn;

let mut detector = PipecatSmartTurn::new()?;
let detector = PipecatSmartTurn::new()?;
let mut ctrl = TurnController::new(detector);

// Feed 16 kHz f32 PCM frames after VAD detects silence
let prediction = detector.predict_audio(&audio_frames)?;
// Feed audio continuously
ctrl.push_audio(&audio_frame);

// VAD speech start — soft reset (keeps buffer if turn was unfinished)
ctrl.reset_if_finished();

// VAD speech end — predict
let prediction = ctrl.predict()?;
match prediction.state {
TurnState::Finished => { /* user is done, send to LLM */ }
TurnState::Unfinished => { /* keep listening */ }
TurnState::Wait => { /* user asked AI to hold */ }
}

// After assistant finishes responding — hard reset
ctrl.reset();
```

Or the text-based detector:
Or the text-based detector directly:

```rust
use wavekat_turn::{TextTurnDetector, TurnState};
@@ -57,13 +66,19 @@ let prediction = detector.predict_text("I was wondering if", &context)?;
assert_eq!(prediction.state, TurnState::Unfinished);
```

See [`examples/controller.rs`](crates/wavekat-turn/examples/controller.rs) for a
full walkthrough with real audio.

## Architecture

Two trait families cover the two input modalities:

- **`AudioTurnDetector`** -- operates on raw audio frames (no ASR needed)
- **`TextTurnDetector`** -- operates on ASR transcript text with optional conversation context

`TurnController` wraps any `AudioTurnDetector` and adds orchestration helpers
like soft-reset (preserves buffer when the user pauses mid-sentence).
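The soft-reset rule is small enough to sketch in isolation. The block below is a standalone model of just that decision, not the crate's actual types:

```rust
// Minimal model of the soft-reset decision on VAD speech-start.
// Illustrative only; the real TurnController tracks this state internally.
enum TurnState {
    Finished,
    Unfinished,
    Wait,
}

/// Returns true when the audio buffer should be cleared: only an
/// Unfinished previous turn keeps its accumulated audio.
fn should_reset(last_state: Option<TurnState>) -> bool {
    !matches!(last_state, Some(TurnState::Unfinished))
}

fn main() {
    assert!(should_reset(None)); // no prediction yet → reset
    assert!(should_reset(Some(TurnState::Finished))); // turn done → reset
    assert!(!should_reset(Some(TurnState::Unfinished))); // mid-sentence → keep buffer
    assert!(should_reset(Some(TurnState::Wait)));
    println!("soft-reset logic ok");
}
```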

```
wavekat-vad --> "is someone speaking?"
wavekat-turn --> "are they done speaking?"
4 changes: 4 additions & 0 deletions crates/wavekat-turn/Cargo.toml
@@ -36,6 +36,10 @@ ndarray-npy = "0.10"
serde = { version = "1", features = ["derive"] }
serde_json = "1"

[[example]]
name = "controller"
required-features = ["pipecat"]

[package.metadata.docs.rs]
all-features = true
rustdoc-args = ["--cfg", "docsrs"]
100 changes: 100 additions & 0 deletions crates/wavekat-turn/examples/controller.rs
@@ -0,0 +1,100 @@
//! Example: using TurnController for VAD-driven turn detection.
//!
//! Run with: `cargo run --features pipecat --example controller`
//!
//! Demonstrates the soft-reset flow using real WAV fixtures:
//!
//! 1. User speaks mid-sentence (speech_mid.wav) → Unfinished
//! 2. User continues speaking — soft reset keeps the buffer intact
//! 3. User finishes the sentence (speech_finished.wav) → Finished
//! 4. After assistant responds, hard reset starts a fresh turn

use std::path::Path;

use wavekat_turn::audio::PipecatSmartTurn;
use wavekat_turn::{AudioFrame, TurnController};

fn load_wav(path: &Path) -> Vec<f32> {
let mut reader = hound::WavReader::open(path)
.unwrap_or_else(|e| panic!("failed to open {}: {}", path.display(), e));
let spec = reader.spec();
match spec.sample_format {
hound::SampleFormat::Int => reader
.samples::<i16>()
.map(|s| s.unwrap() as f32 / 32768.0)
.collect(),
hound::SampleFormat::Float => reader.samples::<f32>().map(|s| s.unwrap()).collect(),
}
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
let fixtures = Path::new(env!("CARGO_MANIFEST_DIR"))
.parent()
.unwrap()
.parent()
.unwrap()
.join("tests/fixtures");

let speech_mid = load_wav(&fixtures.join("speech_mid.wav"));
let speech_finished = load_wav(&fixtures.join("speech_finished.wav"));

let detector = PipecatSmartTurn::new()?;
let mut ctrl = TurnController::new(detector);

// --- Speech A: user says something mid-sentence ---
println!(">> VAD: speech started");
ctrl.reset_if_finished(); // first speech → resets

println!(">> Pushing speech_mid.wav (cut mid-sentence)");
ctrl.push_audio(&AudioFrame::new(&speech_mid[..], 16_000));

println!(">> VAD: speech ended");
let result_a = ctrl.predict()?;
println!(
" predict → {:?} (confidence: {:.3})",
result_a.state, result_a.confidence
);

// --- Speech B: user continues speaking ---
println!("\n>> VAD: speech started again");
let did_reset = ctrl.reset_if_finished();
println!(
" reset_if_finished → {}",
if did_reset {
"reset (turn was finished)"
} else {
"skipped (turn unfinished, keeping buffer)"
}
);

println!(">> Pushing speech_finished.wav (complete sentence)");
ctrl.push_audio(&AudioFrame::new(&speech_finished[..], 16_000));

println!(">> VAD: speech ended");
let result_b = ctrl.predict()?;
println!(
" predict → {:?} (confidence: {:.3}, ran on A+B combined)",
result_b.state, result_b.confidence
);

// --- New turn: after assistant responds ---
println!("\n>> Assistant finished responding");
ctrl.reset(); // hard reset for next turn
println!(" hard reset, last_state: {:?}", ctrl.last_state());

// --- Speech C: fresh turn ---
println!("\n>> VAD: speech started (new turn)");
ctrl.reset_if_finished(); // last_state is None → resets

println!(">> Pushing speech_finished.wav");
ctrl.push_audio(&AudioFrame::new(&speech_finished[..], 16_000));

println!(">> VAD: speech ended");
let result_c = ctrl.predict()?;
println!(
" predict → {:?} (confidence: {:.3})",
result_c.state, result_c.confidence
);

Ok(())
}
100 changes: 100 additions & 0 deletions crates/wavekat-turn/src/controller.rs
@@ -0,0 +1,100 @@
use crate::{AudioFrame, AudioTurnDetector, TurnError, TurnPrediction, TurnState};

/// Orchestration wrapper around any [`AudioTurnDetector`].
///
/// Tracks prediction state across calls and provides convenience methods
/// like [`reset_if_finished`](TurnController::reset_if_finished) for
/// correct VAD integration without manual state bookkeeping.
///
/// # Usage
///
/// ```ignore
/// let detector = PipecatSmartTurn::new()?;
/// let mut ctrl = TurnController::new(detector);
///
/// // Audio arrives continuously
/// ctrl.push_audio(&frame);
///
/// // VAD speech start — soft reset (keeps buffer if turn was unfinished)
/// ctrl.reset_if_finished();
///
/// // VAD speech end — predict
/// let result = ctrl.predict()?;
/// ```
///
/// See [`reset_if_finished`](TurnController::reset_if_finished) for details
/// on when to use soft vs hard reset.
pub struct TurnController<T: AudioTurnDetector> {
inner: T,
last_state: Option<TurnState>,
}

impl<T: AudioTurnDetector> TurnController<T> {
/// Create a new controller wrapping the given detector.
pub fn new(inner: T) -> Self {
Self {
inner,
last_state: None,
}
}

/// Feed audio into the detector.
pub fn push_audio(&mut self, frame: &AudioFrame) {
self.inner.push_audio(frame);
}

/// Run prediction on buffered audio.
///
/// Tracks the result state internally for [`reset_if_finished`](Self::reset_if_finished).
pub fn predict(&mut self) -> Result<TurnPrediction, TurnError> {
let result = self.inner.predict()?;
self.last_state = Some(result.state);
Ok(result)
}

/// Hard reset — always clears the buffer.
///
/// Use when you know a new turn is starting (e.g. after the assistant
/// finishes responding).
pub fn reset(&mut self) {
self.inner.reset();
self.last_state = None;
}

/// Soft reset — clears the buffer only if the last prediction was
/// [`Finished`](TurnState::Finished) or no prediction has been made
/// since the last reset.
///
/// Returns `true` if a reset occurred, `false` if skipped.
///
/// Call this on VAD speech-start when you don't know whether the user
/// is continuing the same turn or starting a new one. If the previous
/// prediction was [`Unfinished`](TurnState::Unfinished), the buffer is
/// preserved so the next [`predict`](Self::predict) runs on the full
/// accumulated audio.
pub fn reset_if_finished(&mut self) -> bool {
match self.last_state {
Some(TurnState::Unfinished) => false,
_ => {
self.reset();
true
}
}
}

/// Returns the state from the last [`predict`](Self::predict) call,
/// or `None` if no prediction has been made since the last reset.
pub fn last_state(&self) -> Option<TurnState> {
self.last_state
}

/// Returns a mutable reference to the inner detector.
pub fn inner_mut(&mut self) -> &mut T {
&mut self.inner
}

/// Unwrap the controller, returning the inner detector.
pub fn into_inner(self) -> T {
self.inner
}
}
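The controller's state machine can be exercised end-to-end with a mock detector. The block below is a self-contained sketch using simplified stand-in types (the real trait works with `AudioFrame` and returns `Result<TurnPrediction, TurnError>`):

```rust
// Self-contained sketch of TurnController's state machine with a mock
// detector. Types are simplified stand-ins, not the crate's real API.
#[derive(Clone, Copy, Debug, PartialEq)]
enum TurnState {
    Finished,
    Unfinished,
}

struct MockDetector {
    buffered_samples: usize,
    next_state: TurnState, // what the "model" will report next
}

struct Controller {
    inner: MockDetector,
    last_state: Option<TurnState>,
}

impl Controller {
    fn new(inner: MockDetector) -> Self {
        Self { inner, last_state: None }
    }

    fn push_audio(&mut self, samples: usize) {
        self.inner.buffered_samples += samples;
    }

    fn predict(&mut self) -> TurnState {
        let state = self.inner.next_state;
        self.last_state = Some(state);
        state // the buffer is intentionally NOT cleared here
    }

    fn reset(&mut self) {
        self.inner.buffered_samples = 0;
        self.last_state = None;
    }

    fn reset_if_finished(&mut self) -> bool {
        match self.last_state {
            Some(TurnState::Unfinished) => false,
            _ => {
                self.reset();
                true
            }
        }
    }
}

fn main() {
    let mut ctrl = Controller::new(MockDetector {
        buffered_samples: 0,
        next_state: TurnState::Unfinished,
    });

    ctrl.reset_if_finished(); // first speech → resets
    ctrl.push_audio(1600);
    assert_eq!(ctrl.predict(), TurnState::Unfinished);

    // Speech resumes: soft reset is skipped, buffer survives.
    assert!(!ctrl.reset_if_finished());
    ctrl.push_audio(1600);
    assert_eq!(ctrl.inner.buffered_samples, 3200); // A + B combined

    ctrl.inner.next_state = TurnState::Finished;
    assert_eq!(ctrl.predict(), TurnState::Finished);

    // Next speech-start after a finished turn → hard reset.
    assert!(ctrl.reset_if_finished());
    assert_eq!(ctrl.inner.buffered_samples, 0);
    println!("controller state machine ok");
}
```

This mirrors the Speech A / Speech B flow in `examples/controller.rs`: the second prediction runs on both utterances combined because the intervening soft reset was skipped.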
35 changes: 30 additions & 5 deletions crates/wavekat-turn/src/lib.rs
@@ -9,13 +9,18 @@
//! - [`AudioTurnDetector`] — operates on raw audio frames (e.g. Pipecat Smart Turn)
//! - [`TextTurnDetector`] — operates on ASR transcript text (e.g. LiveKit EOU)
//!
//! For most use cases, wrap a detector in [`TurnController`] to get
//! automatic state tracking and soft-reset logic for VAD integration.
//! See [`controller`] for details.
//!
//! # Feature flags
//!
//! | Feature | Backend | Input |
//! |---------|---------|-------|
//! | `pipecat` | Pipecat Smart Turn v3 (ONNX) | Audio (16 kHz) |
//! | `livekit` | LiveKit Turn Detector (ONNX) | Text |

pub mod controller;
pub mod error;

#[cfg(any(feature = "pipecat", feature = "livekit"))]
@@ -27,6 +32,7 @@ pub mod audio;
#[cfg(feature = "livekit")]
pub mod text;

pub use controller::TurnController;
pub use error::TurnError;
pub use wavekat_core::AudioFrame;

@@ -77,11 +83,23 @@ pub enum Role {
/// Turn detector that operates on raw audio.
///
/// Implementations buffer audio internally and run prediction on demand.
/// The typical flow with VAD:
///
/// **Most users should wrap this in [`TurnController`]** rather than calling
/// these methods directly. The controller tracks prediction state and provides
/// [`reset_if_finished`](TurnController::reset_if_finished) for correct
/// multi-utterance handling.
///
/// # Direct usage (advanced)
///
/// If you need full control over reset logic:
///
/// 1. **Every audio chunk** → [`push_audio`](AudioTurnDetector::push_audio)
/// 2. **VAD fires "speech started"** → [`reset`](AudioTurnDetector::reset)
/// 3. **VAD fires "speech stopped"** → [`predict`](AudioTurnDetector::predict)
/// 2. **VAD fires "speech stopped"** → [`predict`](AudioTurnDetector::predict)
/// 3. **New turn begins** → [`reset`](AudioTurnDetector::reset)
///
/// Note: calling `reset` unconditionally on every VAD speech-start will discard
/// audio context when the user pauses mid-sentence. See [`TurnController`] for
/// the recommended approach.
pub trait AudioTurnDetector: Send + Sync {
/// Feed audio into the internal buffer.
///
@@ -90,10 +108,17 @@ pub trait AudioTurnDetector: Send + Sync {

/// Run prediction on buffered audio.
///
/// Call when VAD detects end of speech.
/// Call when VAD detects end of speech. The buffer is **not** cleared
/// after prediction — call [`reset`](AudioTurnDetector::reset) explicitly
/// when starting a new turn.
fn predict(&mut self) -> Result<TurnPrediction, TurnError>;

/// Clear the internal buffer. Call when a new speech turn begins.
/// Unconditionally clear the internal buffer.
///
/// Use when you are certain a new turn is starting (e.g. after the
/// assistant finishes responding). For VAD speech-start events where
/// the user may be continuing, prefer
/// [`TurnController::reset_if_finished`].
fn reset(&mut self);
}
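The contract documented above (the buffer survives `predict`; only `reset` clears it) can be illustrated with a toy implementation. The signatures below are simplified stand-ins, not the crate's real `AudioFrame`/`TurnPrediction`/`TurnError` types, and the "model" is just a length threshold:

```rust
// Toy detector following the trait's documented contract.
// Simplified signatures; illustrative only.
#[derive(Debug, PartialEq)]
enum TurnState {
    Finished,
    Unfinished,
}

trait AudioTurnDetector {
    fn push_audio(&mut self, samples: &[f32]);
    fn predict(&mut self) -> TurnState;
    fn reset(&mut self);
}

/// Pretends a turn is finished once at least one second of 16 kHz audio
/// has been buffered — a stand-in for a real model.
struct ToyDetector {
    buffer: Vec<f32>,
}

impl AudioTurnDetector for ToyDetector {
    fn push_audio(&mut self, samples: &[f32]) {
        self.buffer.extend_from_slice(samples);
    }

    fn predict(&mut self) -> TurnState {
        // Per the trait docs: the buffer is NOT cleared after prediction.
        if self.buffer.len() >= 16_000 {
            TurnState::Finished
        } else {
            TurnState::Unfinished
        }
    }

    fn reset(&mut self) {
        self.buffer.clear();
    }
}

fn main() {
    let mut det = ToyDetector { buffer: Vec::new() };
    det.push_audio(&[0.0; 8_000]);
    assert_eq!(det.predict(), TurnState::Unfinished);
    det.push_audio(&[0.0; 8_000]); // buffer survived the first predict
    assert_eq!(det.predict(), TurnState::Finished);
    det.reset();
    assert_eq!(det.predict(), TurnState::Unfinished);
    println!("toy detector ok");
}
```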
