diff --git a/sdk/rt/speechmatics/rt/_async_client.py b/sdk/rt/speechmatics/rt/_async_client.py index 50618666..e12e6a4e 100644 --- a/sdk/rt/speechmatics/rt/_async_client.py +++ b/sdk/rt/speechmatics/rt/_async_client.py @@ -23,6 +23,7 @@ _UNSET = object() + class AsyncClient(_BaseClient): """ Asynchronous client for Speechmatics real-time audio transcription. @@ -195,7 +196,7 @@ async def force_end_of_utterance(self, *, timestamp: Optional[float] | object = ... await client.force_end_of_utterance() """ - message: dict[str,Any] = {"message": ClientMessageType.FORCE_END_OF_UTTERANCE} + message: dict[str, Any] = {"message": ClientMessageType.FORCE_END_OF_UTTERANCE} if timestamp is _UNSET: # default: auto-set from audio_seconds_sent diff --git a/sdk/voice/pyproject.toml b/sdk/voice/pyproject.toml index 9006bd1f..c67e5dc1 100644 --- a/sdk/voice/pyproject.toml +++ b/sdk/voice/pyproject.toml @@ -11,7 +11,7 @@ authors = [{ name = "Speechmatics", email = "support@speechmatics.com" }] license = "MIT" requires-python = ">=3.9" dependencies = [ - "speechmatics-rt>=0.5.3", + "speechmatics-rt>=1.0.0", "pydantic>=2.10.6,<3", "numpy>=1.26.4,<3" ] diff --git a/sdk/voice/speechmatics/voice/_client.py b/sdk/voice/speechmatics/voice/_client.py index c0988dd3..f250daa1 100644 --- a/sdk/voice/speechmatics/voice/_client.py +++ b/sdk/voice/speechmatics/voice/_client.py @@ -176,6 +176,10 @@ def __init__( preset_config = VoiceAgentConfigPreset.load(preset) config = VoiceAgentConfigPreset._merge_configs(preset_config, config) + # Validate the final config (deferred to allow overlay/preset merging first) + if config is not None: + config.validate_config() + # Process the config self._config, self._transcription_config, self._audio_format = self._prepare_config(config) @@ -310,24 +314,26 @@ def __init__( self._turn_handler: TurnTaskProcessor = TurnTaskProcessor(name="turn_handler", done_callback=self.finalize) self._eot_calculation_task: Optional[asyncio.Task] = None - # Uses fixed 
EndOfUtterance message from STT - self._uses_fixed_eou: bool = ( - self._eou_mode == EndOfUtteranceMode.FIXED - and not self._silero_detector - and not self._config.end_of_turn_config.use_forced_eou - ) - - # Uses ForceEndOfUtterance message - self._uses_forced_eou: bool = not self._uses_fixed_eou + # Forced end of utterance handling + # FEOU is not used in FIXED mode, unless VAD has been enabled. It can / should + # also be disabled during testing when not connected to an endpoint, as the + # waiting for FEOU response will block the test. + self._use_forced_eou: bool = self._eou_mode is not EndOfUtteranceMode.FIXED or self._uses_silero_vad self._forced_eou_active: bool = False self._last_forced_eou_latency: float = 0.0 - # Emit EOT prediction (uses _uses_forced_eou) - self._uses_eot_prediction: bool = self._eou_mode not in [ + # Emit EOT prediction + # EOT predictions are only relevant when not using the FIXED or EXTERNAL modes, + # as these use different triggers to finalize the turn. 
+ self._emit_eot_predictions: bool = self._eou_mode not in [ EndOfUtteranceMode.FIXED, EndOfUtteranceMode.EXTERNAL, ] + # Time slip for Forced End Of Utterance + self._feou_chunk_s: float = 0.360 + self._feou_padding_s: float = 0.0 + # ------------------------------------- # Diarization / Speakers # ------------------------------------- @@ -360,8 +366,8 @@ def __init__( AudioEncoding.PCM_S16LE: 2, }.get(self._audio_format.encoding, 1) - # Default audio buffer - if not self._config.audio_buffer_length and (self._uses_smart_turn or self._uses_silero_vad): + # Default audio buffer (used when Silero VAD is enabled and with Smart Turn) + if not self._config.audio_buffer_length and self._uses_silero_vad: self._config.audio_buffer_length = 15.0 # Audio buffer @@ -447,9 +453,7 @@ def _prepare_config( ) # Fixed end of Utterance - if bool( - config.end_of_utterance_mode == EndOfUtteranceMode.FIXED and not config.end_of_turn_config.use_forced_eou - ): + if config.end_of_utterance_mode == EndOfUtteranceMode.FIXED: transcription_config.conversation_config = ConversationConfig( end_of_utterance_silence_trigger=config.end_of_utterance_silence_trigger, ) @@ -659,8 +663,14 @@ async def send_audio(self, payload: bytes) -> None: return # Process with Silero VAD - if self._silero_detector: - asyncio.create_task(self._silero_detector.process_audio(payload)) + if self._uses_silero_vad and self._silero_detector is not None: + asyncio.create_task( + self._silero_detector.process_audio( + payload, + sample_rate=self._audio_sample_rate, + sample_width=self._audio_sample_width, + ) + ) # Add to audio buffer (use put_bytes to handle variable chunk sizes) if self._config.audio_buffer_length > 0: @@ -738,7 +748,7 @@ async def emit() -> None: """Wait for EndOfUtterance if needed, then emit segments.""" # Forced end of utterance message (only when no speaker is detected) - if self._config.end_of_turn_config.use_forced_eou: + if self._use_forced_eou: await self._await_forced_eou() # Check if the 
turn has changed @@ -749,7 +759,7 @@ async def emit() -> None: self._stt_message_queue.put_nowait(lambda: self._emit_segments(finalize=True, is_eou=True)) # Call async task (only if not already waiting for forced EOU) - if not (self._config.end_of_turn_config.use_forced_eou and self._forced_eou_active): + if not self._forced_eou_active: asyncio.create_task(emit()) # ============================================================================ @@ -788,8 +798,8 @@ def _evt_on_final_transcript(message: dict[str, Any]) -> None: return self._stt_message_queue.put_nowait(lambda: self._handle_transcript(message, is_final=True)) - # End of Utterance (FIXED mode only) - if self._uses_fixed_eou: + # End of Utterance - only when not using ForceEndOfUtterance messages + if not self._use_forced_eou: @self.on(ServerMessageType.END_OF_UTTERANCE) # type: ignore[misc] def _evt_on_end_of_utterance(message: dict[str, Any]) -> None: @@ -1066,7 +1076,7 @@ async def _add_speech_fragments(self, message: dict[str, Any], is_final: bool = # Metadata metadata = message.get("metadata", {}) - payload_end_time = metadata.get("end_time", 0) + payload_end_time = self._calc_adjusted_time(metadata.get("end_time", 0)) # Iterate over the results in the payload for result in message.get("results", []): @@ -1075,8 +1085,8 @@ async def _add_speech_fragments(self, message: dict[str, Any], is_final: bool = # Create the new fragment fragment = SpeechFragment( idx=self._next_fragment_id(), - start_time=result.get("start_time", 0), - end_time=result.get("end_time", 0), + start_time=self._calc_adjusted_time(result.get("start_time", 0)), + end_time=self._calc_adjusted_time(result.get("end_time", 0)), language=alt.get("language", "en"), direction=alt.get("direction", "ltr"), type_=result.get("type", "word"), @@ -1121,7 +1131,7 @@ async def _add_speech_fragments(self, message: dict[str, Any], is_final: bool = self._last_fragment_end_time = max(self._last_fragment_end_time, fragment.end_time) # Evaluate for VAD 
(only done on partials) - await self._vad_evaluation(fragments, is_final=is_final) + await self._speaker_start_stop_evaluation(fragments, is_final=is_final) # Fragments to retain retained_fragments = [ @@ -1205,18 +1215,8 @@ async def _process_speech_fragments(self, change_filter: Optional[list[Annotatio if change_filter and not changes.any(*change_filter): return - # Skip re-evaluation if transcripts are older than smart turn cutoff - if self._smart_turn_pending_cutoff is not None and self._current_view: - latest_end_time = max( - (f.end_time for f in self._current_view.fragments if f.end_time is not None), default=0.0 - ) - - # If all fragments end before or at the cutoff, skip re-evaluation - if latest_end_time <= self._smart_turn_pending_cutoff: - return - # Turn prediction - if self._uses_eot_prediction and self._uses_forced_eou and not self._forced_eou_active: + if self._emit_eot_predictions and not self._forced_eou_active and self._use_forced_eou: async def fn() -> None: ttl = await self._calculate_finalize_delay() @@ -1518,7 +1518,7 @@ async def _calculate_finalize_delay( annotation = annotation or AnnotationResult() # VAD enabled - if self._silero_detector: + if self._uses_silero_vad: annotation.add(AnnotationFlags.VAD_ACTIVE) else: annotation.add(AnnotationFlags.VAD_INACTIVE) @@ -1526,6 +1526,12 @@ async def _calculate_finalize_delay( # Smart Turn enabled if self._smart_turn_detector: annotation.add(AnnotationFlags.SMART_TURN_ACTIVE) + # If Smart Turn hasn't returned a result yet but is enabled, add NO_SIGNAL annotation. + # This covers the case where the TTL fires before VAD triggers Smart Turn inference. 
+ if not annotation.has(AnnotationFlags.SMART_TURN_TRUE) and not annotation.has( + AnnotationFlags.SMART_TURN_FALSE + ): + annotation.add(AnnotationFlags.SMART_TURN_NO_SIGNAL) else: annotation.add(AnnotationFlags.SMART_TURN_INACTIVE) @@ -1551,8 +1557,7 @@ async def _calculate_finalize_delay( delay = round(self._config.end_of_utterance_silence_trigger * multiplier, 3) # Trim off the most recent forced EOU delay if we're in forced EOU mode - if self._uses_forced_eou: - delay -= self._last_forced_eou_latency + delay -= self._last_forced_eou_latency # Clamp to max delay and adjust for TTFB clamped_delay = min(delay, self._config.end_of_utterance_max_delay) @@ -1586,7 +1591,10 @@ async def _eot_prediction( # Wait for Smart Turn result if self._smart_turn_detector and end_time is not None: result = await self._smart_turn_prediction(end_time, self._config.language, speaker=speaker) - if result.prediction: + if result.error: + # No valid prediction — SMART_TURN_NO_SIGNAL will be applied by _calculate_finalize_delay + pass + elif result.prediction: annotation.add(AnnotationFlags.SMART_TURN_TRUE) else: annotation.add(AnnotationFlags.SMART_TURN_FALSE) @@ -1676,9 +1684,6 @@ async def _await_forced_eou(self, timeout: float = 1.0) -> None: # Add listener self.once(AgentServerMessageType.END_OF_UTTERANCE, lambda message: eou_received.set()) - # Trigger EOU message - self._emit_diagnostic_message("ForceEndOfUtterance sent - waiting for EndOfUtterance") - # Wait for EOU try: # Track the start time @@ -1686,7 +1691,11 @@ async def _await_forced_eou(self, timeout: float = 1.0) -> None: self._forced_eou_active = True # Send the force EOU and wait for the response - await self.force_end_of_utterance() + timestamp = await self._calc_force_end_of_utterance() + await self.force_end_of_utterance(timestamp=timestamp) + self._emit_diagnostic_message(f"ForceEndOfUtterance sent - waiting for EndOfUtterance ({timestamp=})") + + # Wait for the response await asyncio.wait_for(eou_received.wait(), 
timeout=timeout) # Record the latency @@ -1698,11 +1707,43 @@ async def _await_forced_eou(self, timeout: float = 1.0) -> None: finally: self._forced_eou_active = False + async def _calc_force_end_of_utterance(self) -> float: + """Force the end of the current utterance.""" + + # Seconds sent + timestamp: float = float(self._audio_bytes_sent) / ( + self._audio_format.sample_rate * self._audio_format.bytes_per_sample + ) + + # Add padding for transcriber chunk size + padding: float = self._feou_chunk_s - (timestamp % self._feou_chunk_s) + self._feou_padding_s += padding + + # Send silence + padding_silence = b"\x00" * int(padding * self._audio_format.sample_rate * self._audio_format.bytes_per_sample) + await self.send_audio(padding_silence) + + # Return the time + return timestamp + padding + + def _calc_adjusted_time(self, timestamp: float) -> float: + """Calculate the adjusted timestamp. + + As forced end of utterance is used, the time needs to get padded to fill the chunk size processed + by the engine. This needs to be kept track of and removed from timestamps that are returned within + the conversation. + """ + + if not self._use_forced_eou: + return timestamp + + return round(timestamp - self._feou_padding_s, 4) + # ============================================================================ # VAD (VOICE ACTIVITY DETECTION) / SPEAKER DETECTION # ============================================================================ - async def _vad_evaluation(self, fragments: list[SpeechFragment], is_final: bool) -> None: + async def _speaker_start_stop_evaluation(self, fragments: list[SpeechFragment], is_final: bool) -> None: """Emit a VAD event. 
This will emit `SPEAKER_STARTED` and `SPEAKER_ENDED` events to the client and is @@ -1850,18 +1891,20 @@ def _handle_silero_vad_result(self, result: SileroVADResult) -> None: annotation.add(AnnotationFlags.VAD_STARTED) # If speech has ended, we need to predict the end of turn - if result.speech_ended and self._uses_eot_prediction: + if self._emit_eot_predictions and result.speech_ended: """VAD-based end of turn prediction.""" # Set cutoff to prevent late transcripts from cancelling finalization self._smart_turn_pending_cutoff = event_time + # Async callback async def fn() -> None: ttl = await self._eot_prediction( end_time=event_time, speaker=self._current_speaker, annotation=annotation ) self._turn_handler.update_timer(ttl) + # Call the eot calculation asynchronously self._run_background_eot_calculation(fn, "silero_vad") async def _handle_speaker_started(self, speaker: Optional[str], event_time: float) -> None: @@ -1878,8 +1921,7 @@ async def _handle_speaker_started(self, speaker: Optional[str], event_time: floa await self._emit_start_of_turn(event_time) # Update the turn handler - if self._uses_forced_eou: - self._turn_handler.reset() + self._turn_handler.reset() # Emit the event self._emit_message( @@ -1902,7 +1944,7 @@ async def _handle_speaker_stopped(self, speaker: Optional[str], event_time: floa self._last_speak_end_latency = self._total_time - event_time # Turn prediction - if self._uses_eot_prediction and not self._forced_eou_active: + if self._emit_eot_predictions and not self._forced_eou_active: async def fn() -> None: ttl = await self._eot_prediction(event_time, speaker) diff --git a/sdk/voice/speechmatics/voice/_models.py b/sdk/voice/speechmatics/voice/_models.py index b4a432c2..c58a7ca6 100644 --- a/sdk/voice/speechmatics/voice/_models.py +++ b/sdk/voice/speechmatics/voice/_models.py @@ -13,7 +13,6 @@ from pydantic import BaseModel as PydanticBaseModel from pydantic import ConfigDict from pydantic import Field -from pydantic import model_validator 
from typing_extensions import Self from speechmatics.rt import AudioEncoding @@ -261,6 +260,7 @@ class AnnotationFlags(str, Enum): SMART_TURN_INACTIVE = "smart_turn_inactive" SMART_TURN_TRUE = "smart_turn_true" SMART_TURN_FALSE = "smart_turn_false" + SMART_TURN_NO_SIGNAL = "smart_turn_no_signal" # ============================================================================== @@ -410,35 +410,57 @@ class EndOfTurnConfig(BaseModel): base_multiplier: Base multiplier for end of turn delay. min_end_of_turn_delay: Minimum end of turn delay. penalties: List of end of turn penalty items. - use_forced_eou: Whether to use forced end of utterance detection. + use_forced_eou: Whether to use forced end of utterance detection. (SHOULD ONLY EVER BE TRUE) """ base_multiplier: float = 1.0 min_end_of_turn_delay: float = 0.01 penalties: list[EndOfTurnPenaltyItem] = Field( default_factory=lambda: [ - # Increase delay + # + # Speaker rate increases expected TTL EndOfTurnPenaltyItem(penalty=3.0, annotation=[AnnotationFlags.VERY_SLOW_SPEAKER]), EndOfTurnPenaltyItem(penalty=2.0, annotation=[AnnotationFlags.SLOW_SPEAKER]), + # + # High / low rate of disfluencies EndOfTurnPenaltyItem(penalty=2.5, annotation=[AnnotationFlags.ENDS_WITH_DISFLUENCY]), EndOfTurnPenaltyItem(penalty=1.1, annotation=[AnnotationFlags.HAS_DISFLUENCY]), + # + # We do NOT have an end of sentence character EndOfTurnPenaltyItem( penalty=2.0, annotation=[AnnotationFlags.ENDS_WITH_EOS], is_not=True, ), - # Decrease delay + # + # We have finals and end of sentence EndOfTurnPenaltyItem( penalty=0.5, annotation=[AnnotationFlags.ENDS_WITH_FINAL, AnnotationFlags.ENDS_WITH_EOS] ), - # Smart Turn + VAD - EndOfTurnPenaltyItem(penalty=0.2, annotation=[AnnotationFlags.SMART_TURN_TRUE]), + # + # Smart Turn - when false, wait longer to prevent premature end of turn EndOfTurnPenaltyItem( - penalty=0.2, annotation=[AnnotationFlags.VAD_STOPPED, AnnotationFlags.SMART_TURN_INACTIVE] + penalty=0.2, 
annotation=[AnnotationFlags.SMART_TURN_TRUE, AnnotationFlags.SMART_TURN_ACTIVE] + ), + EndOfTurnPenaltyItem( + penalty=2.0, annotation=[AnnotationFlags.SMART_TURN_FALSE, AnnotationFlags.SMART_TURN_ACTIVE] + ), + EndOfTurnPenaltyItem( + penalty=1.5, annotation=[AnnotationFlags.SMART_TURN_NO_SIGNAL, AnnotationFlags.SMART_TURN_ACTIVE] + ), + # + # VAD - only applied when smart turn is not in use and on the speaker stopping + EndOfTurnPenaltyItem( + penalty=0.2, + annotation=[ + AnnotationFlags.VAD_STOPPED, + AnnotationFlags.VAD_ACTIVE, + AnnotationFlags.SMART_TURN_INACTIVE, + ], ), ] ) - use_forced_eou: bool = False + use_forced_eou: bool = True class VoiceActivityConfig(BaseModel): @@ -711,10 +733,16 @@ class VoiceAgentConfig(BaseModel): audio_encoding: AudioEncoding = AudioEncoding.PCM_S16LE chunk_size: int = 160 - # Validation - @model_validator(mode="after") # type: ignore[misc] - def validate_config(self) -> Self: - """Validate the configuration.""" + def validate_config(self) -> None: + """Validate the configuration. + + Cross-field validation is deferred to this method so that configs can be + constructed as overlays (e.g. for presets) without triggering validation + on intermediate states. Call this once the final config is ready. + + Raises: + ValueError: If any validation errors are found. 
+ """ # Validation errors errors: list[str] = [] @@ -723,12 +751,6 @@ def validate_config(self) -> Self: if self.end_of_utterance_mode == EndOfUtteranceMode.EXTERNAL and self.smart_turn_config: errors.append("EXTERNAL mode cannot be used in conjunction with SmartTurnConfig") - # Cannot have FIXED and forced end of utterance enabled without VAD being enabled - if (self.end_of_utterance_mode == EndOfUtteranceMode.FIXED and self.end_of_turn_config.use_forced_eou) and not ( - self.vad_config and self.vad_config.enabled - ): - errors.append("FIXED mode cannot be used in conjunction with forced end of utterance without VAD enabled") - # Cannot use VAD with external end of utterance mode if self.end_of_utterance_mode == EndOfUtteranceMode.EXTERNAL and (self.vad_config and self.vad_config.enabled): errors.append("EXTERNAL mode cannot be used in conjunction with VAD being enabled") @@ -751,13 +773,14 @@ def validate_config(self) -> Self: if self.sample_rate not in [8000, 16000]: errors.append("sample_rate must be 8000 or 16000") + # Check that forced end of utterance is set to True + if not self.end_of_turn_config.use_forced_eou: + errors.append("EndOfTurnConfig.use_forced_eou cannot be False") + # Raise error if any validation errors if errors: raise ValueError(f"{len(errors)} config error(s): {'; '.join(errors)}") - # Return validated config - return self - # ============================================================================== # SESSION & INFO MODELS diff --git a/sdk/voice/speechmatics/voice/_presets.py b/sdk/voice/speechmatics/voice/_presets.py index 2bcb092f..b21cdba3 100644 --- a/sdk/voice/speechmatics/voice/_presets.py +++ b/sdk/voice/speechmatics/voice/_presets.py @@ -57,6 +57,7 @@ def FIXED(overlay: Optional[VoiceAgentConfig] = None) -> VoiceAgentConfig: # no end_of_utterance_silence_trigger=0.5, end_of_utterance_mode=EndOfUtteranceMode.FIXED, speech_segment_config=SpeechSegmentConfig(emit_sentences=False), + 
end_of_turn_config=EndOfTurnConfig(penalties=[]), ), overlay, ) @@ -82,7 +83,6 @@ def ADAPTIVE(overlay: Optional[VoiceAgentConfig] = None) -> VoiceAgentConfig: # end_of_utterance_mode=EndOfUtteranceMode.ADAPTIVE, speech_segment_config=SpeechSegmentConfig(emit_sentences=False), vad_config=VoiceActivityConfig(enabled=True), - end_of_turn_config=EndOfTurnConfig(use_forced_eou=True), ), overlay, ) @@ -114,7 +114,6 @@ def SMART_TURN(overlay: Optional[VoiceAgentConfig] = None) -> VoiceAgentConfig: enabled=True, ), vad_config=VoiceActivityConfig(enabled=True), - end_of_turn_config=EndOfTurnConfig(use_forced_eou=True), ), overlay, ) @@ -175,7 +174,6 @@ def EXTERNAL(overlay: Optional[VoiceAgentConfig] = None) -> VoiceAgentConfig: # max_delay=2.0, end_of_utterance_mode=EndOfUtteranceMode.EXTERNAL, speech_segment_config=SpeechSegmentConfig(emit_sentences=False), - end_of_turn_config=EndOfTurnConfig(use_forced_eou=True), ), overlay, ) @@ -232,4 +230,10 @@ def _merge_configs(base: VoiceAgentConfig, overlay: Optional[VoiceAgentConfig]) **base.model_dump(exclude_unset=True, exclude_none=True), **overlay.model_dump(exclude_unset=True, exclude_none=True), } - return VoiceAgentConfig.from_dict(merged_dict) + config = VoiceAgentConfig.from_dict(merged_dict) + + # Validate the merged config + config.validate_config() + + # Return the merged config + return config diff --git a/sdk/voice/speechmatics/voice/_smart_turn.py b/sdk/voice/speechmatics/voice/_smart_turn.py index 9ce44a03..529f4653 100644 --- a/sdk/voice/speechmatics/voice/_smart_turn.py +++ b/sdk/voice/speechmatics/voice/_smart_turn.py @@ -196,13 +196,21 @@ async def predict( # Convert int16 to float32 in range [-1, 1] (same as reference implementation) float32_array: np.ndarray = int16_array.astype(np.float32) / 32768.0 + # Whisper's feature extractor requires 16kHz audio. Resample if needed. 
+ target_rate = 16000 + if sample_rate != target_rate: + float32_array = self._resample(float32_array, sample_rate, target_rate) + + # After resampling, max_samples is relative to 16kHz + max_samples_16k = 8 * target_rate + # Process audio using Whisper's feature extractor inputs = self.feature_extractor( float32_array, - sampling_rate=sample_rate, + sampling_rate=target_rate, return_tensors="np", padding="max_length", - max_length=max_samples, + max_length=max_samples_16k, truncation=True, do_normalize=True, ) @@ -230,6 +238,44 @@ async def predict( processing_time=round(float((end_time - start_time).total_seconds()), 3), ) + @staticmethod + def _resample(audio: np.ndarray, orig_rate: int, target_rate: int) -> np.ndarray: + """Resample audio using FFT-based method (zero-pad in frequency domain). + + This produces higher quality resampling than linear interpolation by + preserving the original spectral content without aliasing artifacts. + + Args: + audio: Float32 numpy array of audio samples. + orig_rate: Original sample rate. + target_rate: Target sample rate. + + Returns: + Resampled float32 numpy array. 
+ """ + if orig_rate == target_rate: + return audio + + n_orig = len(audio) + n_target = int(n_orig * target_rate / orig_rate) + + # FFT of original signal + fft = np.fft.rfft(audio) + + # Create zero-padded FFT array for target length + n_fft_target = n_target // 2 + 1 + new_fft = np.zeros(n_fft_target, dtype=complex) + + # Copy original frequency bins (preserves spectral content) + copy_len = min(len(fft), n_fft_target) + new_fft[:copy_len] = fft[:copy_len] + + # Inverse FFT at target length, scale to preserve amplitude + resampled = np.fft.irfft(new_fft, n=n_target) + resampled *= n_target / n_orig + + return resampled.astype(np.float32) + @staticmethod def truncate_audio_to_last_n_seconds( audio_array: np.ndarray, n_seconds: float = 8.0, sample_rate: int = 16000 diff --git a/sdk/voice/speechmatics/voice/_vad.py b/sdk/voice/speechmatics/voice/_vad.py index e5a7b1e8..a65502d5 100644 --- a/sdk/voice/speechmatics/voice/_vad.py +++ b/sdk/voice/speechmatics/voice/_vad.py @@ -38,12 +38,16 @@ # Hint for when dependencies are not available SILERO_INSTALL_HINT = "Silero VAD unavailable. Install `speechmatics-voice[smart]` to enable VAD." -# Silero VAD constants -SILERO_SAMPLE_RATE = 16000 -SILERO_CHUNK_SIZE = 512 # Silero expects 512 samples at 16kHz (32ms chunks) -SILERO_CONTEXT_SIZE = 64 # Silero uses 64-sample context +# Silero VAD supported sample rates (see https://github.com/snakers4/silero-vad) +SILERO_SUPPORTED_SAMPLE_RATES = [8000, 16000] + +# Chunk and context sizes differ by sample rate. 
+# Both result in ~32ms chunks: 512/16000 = 256/8000 = 0.032s +SILERO_CHUNK_SIZES = {16000: 512, 8000: 256} +SILERO_CONTEXT_SIZES = {16000: 64, 8000: 32} + MODEL_RESET_STATES_TIME = 5.0 # Reset state every 5 seconds -SILERO_CHUNK_DURATION_MS = (SILERO_CHUNK_SIZE / SILERO_SAMPLE_RATE) * 1000 # 32ms per chunk +SILERO_CHUNK_DURATION_MS = 32.0 # Both sample rates produce 32ms chunks class SileroVADResult(BaseModel): @@ -70,7 +74,7 @@ class SileroVAD: """Silero Voice Activity Detector. Uses Silero's opensource VAD model for detecting speech vs silence. - Processes audio in 512-sample chunks at 16kHz. + Supports 8kHz (256-sample chunks) and 16kHz (512-sample chunks). Further information at https://github.com/snakers4/silero-vad """ @@ -172,56 +176,72 @@ def build_session(self, onnx_path: str) -> ort.InferenceSession: # Return the new session return ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"], sess_options=so) - def _init_states(self) -> None: - """Initialize or reset internal VAD states.""" + def _init_states(self, sample_rate: int = 16000) -> None: + """Initialize or reset internal VAD states. + + Args: + sample_rate: Audio sample rate, used to determine context size. + """ + context_size = SILERO_CONTEXT_SIZES.get(sample_rate, 64) self._state = np.zeros((2, 1, 128), dtype=np.float32) - self._context = np.zeros((1, SILERO_CONTEXT_SIZE), dtype=np.float32) + self._context = np.zeros((1, context_size), dtype=np.float32) + self._last_sr: int = sample_rate self._last_reset_time = time.time() - def _maybe_reset_states(self) -> None: + def _maybe_reset_states(self, sample_rate: int) -> None: """Reset ONNX model states periodically to prevent drift. + Also resets if the sample rate changes between calls. + Note: Does NOT reset prediction window or speech state tracking. 
""" - if (time.time() - self._last_reset_time) >= MODEL_RESET_STATES_TIME: - self._state = np.zeros((2, 1, 128), dtype=np.float32) - self._context = np.zeros((1, SILERO_CONTEXT_SIZE), dtype=np.float32) - self._last_reset_time = time.time() + # Reset if sample rate changed (context size depends on it) + sr_changed = hasattr(self, "_last_sr") and self._last_sr != sample_rate + time_expired = (time.time() - self._last_reset_time) >= MODEL_RESET_STATES_TIME + + if sr_changed or time_expired: + self._init_states(sample_rate) + + def process_chunk(self, chunk_f32: np.ndarray, sample_rate: int = 16000) -> float: + """Process a single audio chunk and return speech probability. - def process_chunk(self, chunk_f32: np.ndarray) -> float: - """Process a single 512-sample chunk and return speech probability. + Chunk size depends on sample rate: 512 samples at 16kHz, 256 at 8kHz. Args: - chunk_f32: Float32 numpy array of exactly 512 samples. + chunk_f32: Float32 numpy array of audio samples. + sample_rate: Sample rate of the audio (8000 or 16000). Returns: Speech probability (0.0-1.0). Raises: - ValueError: If chunk is not exactly 512 samples. + ValueError: If chunk size doesn't match expected size for sample rate. 
""" - # Ensure shape (1, 512) + # Expected sizes depend on sample rate (512 @ 16kHz, 256 @ 8kHz) + expected_chunk_size = SILERO_CHUNK_SIZES.get(sample_rate, 512) + context_size = SILERO_CONTEXT_SIZES.get(sample_rate, 64) + x = np.reshape(chunk_f32, (1, -1)) - if x.shape[1] != SILERO_CHUNK_SIZE: - raise ValueError(f"Expected {SILERO_CHUNK_SIZE} samples, got {x.shape[1]}") + if x.shape[1] != expected_chunk_size: + raise ValueError(f"Expected {expected_chunk_size} samples for {sample_rate}Hz, got {x.shape[1]}") - # Concatenate with context (previous 64 samples) + # Concatenate with context (previous N samples, where N depends on sample rate) if self._context is not None: x = np.concatenate((self._context, x), axis=1) - # Run ONNX inference + # Run ONNX inference — pass actual sample rate so the model uses correct internal params ort_inputs = { "input": x.astype(np.float32), "state": self._state, - "sr": np.array(SILERO_SAMPLE_RATE, dtype=np.int64), + "sr": np.array(sample_rate, dtype=np.int64), } out, self._state = self.session.run(None, ort_inputs) - # Update context (keep last 64 samples) - self._context = x[:, -SILERO_CONTEXT_SIZE:] + # Update context (keep last N samples for next chunk) + self._context = x[:, -context_size:] # Maybe reset states periodically - self._maybe_reset_states() + self._maybe_reset_states(sample_rate) # Return probability (out shape is (1, 1)) return float(out[0][0]) @@ -229,12 +249,13 @@ def process_chunk(self, chunk_f32: np.ndarray) -> float: async def process_audio(self, audio_bytes: bytes, sample_rate: int = 16000, sample_width: int = 2) -> None: """Process incoming audio bytes and invoke callback on state changes. - This method buffers incomplete chunks and processes all complete 512-sample chunks. + This method buffers incomplete chunks and processes all complete chunks. + Chunk size depends on sample rate: 512 samples at 16kHz, 256 at 8kHz. The callback is invoked only once at the end if the VAD state changed during processing. 
Args: audio_bytes: Raw audio bytes (int16 PCM). - sample_rate: Sample rate of the audio (must be 16000). + sample_rate: Sample rate of the audio (8000 or 16000). sample_width: Sample width in bytes (2 for int16). """ @@ -242,15 +263,17 @@ async def process_audio(self, audio_bytes: bytes, sample_rate: int = 16000, samp logger.error("SileroVAD is not initialized") return - if sample_rate != SILERO_SAMPLE_RATE: - logger.error(f"Sample rate must be {SILERO_SAMPLE_RATE}Hz, got {sample_rate}Hz") + # Silero VAD only supports 8kHz and 16kHz natively + if sample_rate not in SILERO_SUPPORTED_SAMPLE_RATES: + logger.error(f"Sample rate must be one of {SILERO_SUPPORTED_SAMPLE_RATES}Hz, got {sample_rate}Hz") return # Add new bytes to buffer self._audio_buffer += audio_bytes - # Calculate bytes per chunk (512 samples * 2 bytes for int16) - bytes_per_chunk = SILERO_CHUNK_SIZE * sample_width + # Chunk size depends on sample rate (512 @ 16kHz, 256 @ 8kHz) + chunk_samples = SILERO_CHUNK_SIZES[sample_rate] + bytes_per_chunk = chunk_samples * sample_width # Process all complete chunks in buffer while len(self._audio_buffer) >= bytes_per_chunk: @@ -266,8 +289,8 @@ async def process_audio(self, audio_bytes: bytes, sample_rate: int = 16000, samp float32_array: np.ndarray = int16_array.astype(np.float32) / 32768.0 try: - # Process the chunk and add probability to rolling window - probability = self.process_chunk(float32_array) + # Process the chunk with the correct sample rate + probability = self.process_chunk(float32_array, sample_rate=sample_rate) self._prediction_window.append(probability) except Exception as e: @@ -307,10 +330,26 @@ async def process_audio(self, audio_bytes: bytes, sample_rate: int = 16000, samp # Update state after emitting self._last_is_speech = is_speech - def reset(self) -> None: - """Reset the VAD state and clear audio buffer.""" + @property + def is_speech_likely(self) -> bool: + """Quick check if the most recent raw prediction suggests speech. 
+ + Unlike _last_is_speech which uses a smoothed rolling average (slower to + react), this checks the latest chunk prediction directly — giving faster + speech-onset detection at the cost of more false positives. + """ + if not self._prediction_window: + return self._last_is_speech + return float(self._prediction_window[-1]) >= self._threshold + + def reset(self, sample_rate: int = 16000) -> None: + """Reset the VAD state and clear audio buffer. + + Args: + sample_rate: Sample rate to reinitialise context size for. + """ if self._is_initialized: - self._init_states() + self._init_states(sample_rate) self._audio_buffer = b"" self._prediction_window.clear() self._last_is_speech = False diff --git a/tests/voice/_utils.py b/tests/voice/_utils.py index 8308e905..ad49128a 100644 --- a/tests/voice/_utils.py +++ b/tests/voice/_utils.py @@ -18,7 +18,7 @@ async def get_client( api_key: Optional[str] = None, url: Optional[str] = None, - app: Optional[str] = None, + app: str = "sdk-test", config: Optional[VoiceAgentConfig] = None, connect: bool = True, ) -> VoiceAgentClient: diff --git a/tests/voice/assets/audio_10_16kHz.wav b/tests/voice/assets/audio_10_16kHz.wav new file mode 100644 index 00000000..a6fe0267 Binary files /dev/null and b/tests/voice/assets/audio_10_16kHz.wav differ diff --git a/tests/voice/test_05_utterance.py b/tests/voice/test_05_utterance.py index 9c3c6604..d184ac9a 100644 --- a/tests/voice/test_05_utterance.py +++ b/tests/voice/test_05_utterance.py @@ -10,7 +10,6 @@ from _utils import log_client_messages from speechmatics.voice import AgentServerMessageType -from speechmatics.voice import EndOfTurnConfig from speechmatics.voice import EndOfUtteranceMode from speechmatics.voice import SpeechSegmentConfig from speechmatics.voice import VoiceAgentConfig @@ -232,11 +231,13 @@ async def test_external_vad(): config=VoiceAgentConfig( end_of_utterance_silence_trigger=adaptive_timeout, end_of_utterance_mode=EndOfUtteranceMode.EXTERNAL, - 
end_of_turn_config=EndOfTurnConfig(use_forced_eou=False), ), ) assert client is not None + # Set FEOU to disabled for offline tests + client._use_forced_eou = False + # Start the queue client._start_stt_queue() @@ -335,7 +336,6 @@ async def test_end_of_utterance_adaptive_vad(): end_of_utterance_silence_trigger=adaptive_timeout, end_of_utterance_mode=EndOfUtteranceMode.ADAPTIVE, speech_segment_config=SpeechSegmentConfig(emit_sentences=False), - end_of_turn_config=EndOfTurnConfig(use_forced_eou=False), ), ) assert client is not None @@ -344,6 +344,9 @@ async def test_end_of_utterance_adaptive_vad(): if SHOW_LOG: log_client_messages(client) + # Set FEOU to disabled for offline tests + client._use_forced_eou = False + # Start the queue client._start_stt_queue() diff --git a/tests/voice/test_07_languages.py b/tests/voice/test_07_languages.py index c83428d5..3dc15f0a 100644 --- a/tests/voice/test_07_languages.py +++ b/tests/voice/test_07_languages.py @@ -14,7 +14,6 @@ from speechmatics.voice import AdditionalVocabEntry from speechmatics.voice import AgentServerMessageType -from speechmatics.voice import EndOfTurnConfig from speechmatics.voice import EndOfUtteranceMode from speechmatics.voice import SpeechSegmentConfig from speechmatics.voice import VoiceAgentConfig @@ -25,7 +24,7 @@ # Constants API_KEY = os.getenv("SPEECHMATICS_API_KEY") -URL = "wss://eu2.rt.speechmatics.com/v2" +URL = os.getenv("SPEECHMATICS_RT_URL", "wss://eu2.rt.speechmatics.com/v2") SHOW_LOG = os.getenv("SPEECHMATICS_SHOW_LOG", "0").lower() in ["1", "true"] @@ -113,22 +112,24 @@ async def test_transcribe_languages(sample: AudioSample): if not API_KEY: pytest.skip("Valid API key required for test") + # Config + config = VoiceAgentConfig( + max_delay=1.2, + end_of_utterance_mode=EndOfUtteranceMode.FIXED, + end_of_utterance_silence_trigger=1.2, + language=sample.language, + additional_vocab=[AdditionalVocabEntry(content=vocab) for vocab in sample.vocab], + speech_segment_config=SpeechSegmentConfig( + 
emit_sentences=False, + ), + ) + # Client client = await get_client( api_key=API_KEY, url=URL, connect=False, - config=VoiceAgentConfig( - max_delay=1.2, - end_of_utterance_mode=EndOfUtteranceMode.FIXED, - end_of_utterance_silence_trigger=1.2, - language=sample.language, - additional_vocab=[AdditionalVocabEntry(content=vocab) for vocab in sample.vocab], - end_of_turn_config=EndOfTurnConfig(use_forced_eou=False), - speech_segment_config=SpeechSegmentConfig( - emit_sentences=False, - ), - ), + config=config, ) assert client is not None @@ -188,6 +189,10 @@ def log_segment(message): # Extract the last message assert last_message.get("message") == AgentServerMessageType.ADD_SEGMENT + # Close session + await client.disconnect() + assert not client._is_connected + # Check the segment assert len(segments) >= 1 seg0 = segments[0] @@ -216,7 +221,3 @@ def log_segment(message): print(f"Transcribed: [{str_transcribed}]") print(f"CER: {str_cer}") raise AssertionError("Transcription does not match original") - - # Close session - await client.disconnect() - assert not client._is_connected diff --git a/tests/voice/test_08_multiple_speakers.py b/tests/voice/test_08_multiple_speakers.py index fa662aa5..73a7d299 100644 --- a/tests/voice/test_08_multiple_speakers.py +++ b/tests/voice/test_08_multiple_speakers.py @@ -24,6 +24,7 @@ # Constants API_KEY = os.getenv("SPEECHMATICS_API_KEY") +URL = os.getenv("SPEECHMATICS_RT_URL", "wss://eu2.rt.speechmatics.com/v2") SHOW_LOG = os.getenv("SPEECHMATICS_SHOW_LOG", "0").lower() in ["1", "true"] @@ -116,11 +117,17 @@ async def test_multiple_speakers(sample: SpeakerTest): # Client client = await get_client( + url=URL, api_key=API_KEY, connect=False, config=config, ) + # Debug + if SHOW_LOG: + print(config.to_json(exclude_none=True, exclude_defaults=True, exclude_unset=True, indent=2)) + print(json.dumps(client._transcription_config.to_dict(), indent=2)) + # Create an event to track when the callback is called messages: list[str] = [] bytes_sent: 
int = 0 @@ -148,19 +155,35 @@ def log_final_segment(message): segments: list[SpeakerSegment] = message["segments"] final_segments.extend(segments) + # Log end of turn + def log_end_of_turn(message): + final_segments.extend([{"speaker_id": "--", "text": "_TURN_"}]) + # Add listeners client.once(AgentServerMessageType.RECOGNITION_STARTED, log_message) client.once(AgentServerMessageType.INFO, log_message) client.on(AgentServerMessageType.WARNING, log_message) client.on(AgentServerMessageType.ERROR, log_message) + client.on(AgentServerMessageType.DIAGNOSTICS, log_message) + + # Transcript + client.on(AgentServerMessageType.ADD_PARTIAL_TRANSCRIPT, log_message) + client.on(AgentServerMessageType.ADD_TRANSCRIPT, log_message) client.on(AgentServerMessageType.ADD_PARTIAL_SEGMENT, log_message) client.on(AgentServerMessageType.ADD_SEGMENT, log_message) + + # Turn events + client.on(AgentServerMessageType.VAD_STATUS, log_message) client.on(AgentServerMessageType.SPEAKER_STARTED, log_message) client.on(AgentServerMessageType.SPEAKER_ENDED, log_message) + client.on(AgentServerMessageType.START_OF_TURN, log_message) client.on(AgentServerMessageType.END_OF_TURN, log_message) + client.on(AgentServerMessageType.END_OF_TURN_PREDICTION, log_message) + client.on(AgentServerMessageType.END_OF_UTTERANCE, log_message) - # Log ADD_SEGMENT + # Log ADD_SEGMENT + END_OF_TURN client.on(AgentServerMessageType.ADD_SEGMENT, log_final_segment) + client.on(AgentServerMessageType.END_OF_TURN, log_end_of_turn) # HEADER if SHOW_LOG: @@ -187,22 +210,44 @@ def log_final_segment(message): progress_callback=log_bytes_sent, ) + # Close session + await client.disconnect() + # FOOTER if SHOW_LOG: print("---") print() print() + # Print all final_segments + if SHOW_LOG: + print("Final segments:") + for idx, segment in enumerate(final_segments): + print(f"{idx}: [{segment.get('speaker_id')}] {segment.get('text')}") + print() + + # Accumulate errors + errors: list[str] = [] + + # Check number of final segments + 
if len(final_segments) < len(sample.segment_regex): + errors.append(f"Expected at least {len(sample.segment_regex)} segments, got {len(final_segments)}") + # Check final segments against regex + if SHOW_LOG: + print("Checking final segments against regex:") for idx, _test in enumerate(sample.segment_regex): + text = final_segments[idx].get("text") if idx < len(final_segments) else None + match = text and re.search(_test, text, flags=re.IGNORECASE | re.MULTILINE) if SHOW_LOG: - print(f"`{_test}` -> `{final_segments[idx].get('text')}`") - assert re.search(_test, final_segments[idx].get("text"), flags=re.IGNORECASE | re.MULTILINE) + print(f'{idx}: {"✅" if match else "❌"} - `{_test}` -> `{text}`') + if not match: + errors.append(f"Segment {idx}: expected /{_test}/ but got '{text}'") # Check only speakers present speakers = [segment.get("speaker_id") for segment in final_segments] - assert set(speakers) == set(sample.speakers_present) + if set(speakers) != set(sample.speakers_present): + errors.append(f"Speakers: expected {set(sample.speakers_present)} but got {set(speakers)}") - # Close session - await client.disconnect() - assert not client._is_connected + # Report all errors + assert not errors, "\n".join(errors) diff --git a/tests/voice/test_09_speaker_id.py b/tests/voice/test_09_speaker_id.py index 6e8dc0bc..a8bb7ccc 100644 --- a/tests/voice/test_09_speaker_id.py +++ b/tests/voice/test_09_speaker_id.py @@ -11,7 +11,6 @@ from speechmatics.rt import ClientMessageType from speechmatics.voice import AdditionalVocabEntry from speechmatics.voice import AgentServerMessageType -from speechmatics.voice import EndOfTurnConfig from speechmatics.voice import EndOfUtteranceMode from speechmatics.voice import SpeakerIdentifier from speechmatics.voice import SpeechSegmentConfig @@ -23,7 +22,7 @@ # Constants API_KEY = os.getenv("SPEECHMATICS_API_KEY") -URL: Optional[str] = "wss://eu2.rt.speechmatics.com/v2" +URL = os.getenv("SPEECHMATICS_RT_URL", 
"wss://eu2.rt.speechmatics.com/v2") SHOW_LOG = os.getenv("SPEECHMATICS_SHOW_LOG", "0").lower() in ["1", "true"] # List of know speakers during tests @@ -59,7 +58,6 @@ async def test_extract_speaker_ids(): additional_vocab=[ AdditionalVocabEntry(content="GeoRouter"), ], - end_of_turn_config=EndOfTurnConfig(use_forced_eou=False), ), ) @@ -192,7 +190,6 @@ async def test_known_speakers(): additional_vocab=[ AdditionalVocabEntry(content="GeoRouter"), ], - end_of_turn_config=EndOfTurnConfig(use_forced_eou=False), ), ) @@ -227,9 +224,6 @@ def log_final_segment(message): speakers = [segment.get("speaker_id") for segment in final_segments] assert set(speakers) == set({"Assistant", "John Doe"}) - # Should be 5 segments - assert len(final_segments) == 5 - # Close session await client.disconnect() assert not client._is_connected @@ -270,7 +264,6 @@ async def test_ignoring_assistant(): additional_vocab=[ AdditionalVocabEntry(content="GeoRouter"), ], - end_of_turn_config=EndOfTurnConfig(use_forced_eou=False), ), ) diff --git a/tests/voice/test_11_audio_buffer.py b/tests/voice/test_11_audio_buffer.py index a10834e9..6472859c 100644 --- a/tests/voice/test_11_audio_buffer.py +++ b/tests/voice/test_11_audio_buffer.py @@ -14,7 +14,6 @@ from speechmatics.voice import AdditionalVocabEntry from speechmatics.voice import AgentServerMessageType -from speechmatics.voice import EndOfTurnConfig from speechmatics.voice import EndOfUtteranceMode from speechmatics.voice import SmartTurnConfig from speechmatics.voice import VoiceAgentConfig @@ -263,7 +262,6 @@ async def save_slice( AdditionalVocabEntry(content="Speechmatics", sounds_like=["speech matics"]), ], smart_turn_config=SmartTurnConfig(enabled=True), - end_of_turn_config=EndOfTurnConfig(use_forced_eou=False), ), ) @@ -369,7 +367,6 @@ async def save_slice( AdditionalVocabEntry(content="Speechmatics", sounds_like=["speech matics"]), ], smart_turn_config=SmartTurnConfig(enabled=True), - 
end_of_turn_config=EndOfTurnConfig(use_forced_eou=False), ), ) diff --git a/tests/voice/test_17_eou_feou.py b/tests/voice/test_17_eou_feou.py index f78c6abe..fb554c95 100644 --- a/tests/voice/test_17_eou_feou.py +++ b/tests/voice/test_17_eou_feou.py @@ -48,41 +48,41 @@ class TranscriptionTests(BaseModel): SAMPLES: TranscriptionTests = TranscriptionTests.from_dict( { "samples": [ - # { - # "id": "07b", - # "path": "./assets/audio_07b_16kHz.wav", - # "sample_rate": 16000, - # "language": "en", - # "segments": [ - # {"text": "Hello.", "start_time": 1.05, "end_time": 1.67}, - # {"text": "Tomorrow.", "start_time": 3.5, "end_time": 4.1}, - # {"text": "Wednesday.", "start_time": 6.05, "end_time": 6.73}, - # {"text": "Of course. That's fine.", "start_time": 8.8, "end_time": 9.96}, - # {"text": "Behind.", "start_time": 12.03, "end_time": 12.73}, - # {"text": "In front.", "start_time": 14.84, "end_time": 15.52}, - # {"text": "Do you think so?", "start_time": 17.54, "end_time": 18.32}, - # {"text": "Brilliant.", "start_time": 20.55, "end_time": 21.08}, - # {"text": "Banana.", "start_time": 22.98, "end_time": 23.53}, - # {"text": "When?", "start_time": 25.49, "end_time": 25.96}, - # {"text": "Today.", "start_time": 27.66, "end_time": 28.15}, - # {"text": "This morning.", "start_time": 29.91, "end_time": 30.47}, - # {"text": "Goodbye.", "start_time": 32.21, "end_time": 32.68}, - # ], - # }, - # { - # "id": "08", - # "path": "./assets/audio_08_16kHz.wav", - # "sample_rate": 16000, - # "language": "en", - # "segments": [ - # {"text": "Hello.", "start_time": 0.4, "end_time": 0.75}, - # {"text": "Goodbye.", "start_time": 2.12, "end_time": 2.5}, - # {"text": "Banana.", "start_time": 3.84, "end_time": 4.27}, - # {"text": "Breakaway.", "start_time": 5.62, "end_time": 6.42}, - # {"text": "Before.", "start_time": 7.76, "end_time": 8.16}, - # {"text": "After.", "start_time": 9.56, "end_time": 10.05}, - # ], - # }, + { + "id": "07b", + "path": "./assets/audio_07b_16kHz.wav", + 
"sample_rate": 16000, + "language": "en", + "segments": [ + {"text": "Hello.", "start_time": 1.05, "end_time": 1.67}, + {"text": "Tomorrow.", "start_time": 3.5, "end_time": 4.1}, + {"text": "Wednesday.", "start_time": 6.05, "end_time": 6.73}, + {"text": "Of course. That's fine.", "start_time": 8.8, "end_time": 9.96}, + {"text": "Behind.", "start_time": 12.03, "end_time": 12.73}, + {"text": "In front.", "start_time": 14.84, "end_time": 15.52}, + {"text": "Do you think so?", "start_time": 17.54, "end_time": 18.32}, + {"text": "Brilliant.", "start_time": 20.55, "end_time": 21.08}, + {"text": "Banana.", "start_time": 22.98, "end_time": 23.53}, + {"text": "When?", "start_time": 25.49, "end_time": 25.96}, + {"text": "Today.", "start_time": 27.66, "end_time": 28.15}, + {"text": "This morning.", "start_time": 29.91, "end_time": 30.47}, + {"text": "Goodbye.", "start_time": 32.21, "end_time": 32.68}, + ], + }, + { + "id": "08", + "path": "./assets/audio_08_16kHz.wav", + "sample_rate": 16000, + "language": "en", + "segments": [ + {"text": "Hello.", "start_time": 0.4, "end_time": 0.75}, + {"text": "Goodbye.", "start_time": 2.12, "end_time": 2.5}, + {"text": "Banana.", "start_time": 3.84, "end_time": 4.27}, + {"text": "Breakaway.", "start_time": 5.62, "end_time": 6.42}, + {"text": "Before.", "start_time": 7.76, "end_time": 8.16}, + {"text": "After.", "start_time": 9.56, "end_time": 10.05}, + ], + }, { "id": "09", "path": "./assets/audio_09_16kHz.wav", @@ -97,12 +97,12 @@ class TranscriptionTests(BaseModel): ) # VAD delay -VAD_DELAY_S: list[float] = [0.18, 0.22] +VAD_DELAY_S: list[float] = [0.18] # , 0.22] # Endpoints ENDPOINTS: list[str] = [ - # "wss://eu-west-2-research.speechmatics.cloud/v2", - "wss://eu.rt.speechmatics.com/v2", + "wss://preview.rt.speechmatics.com/v2", + # "wss://eu.rt.speechmatics.com/v2", # "wss://us.rt.speechmatics.com/v2", ] @@ -177,6 +177,11 @@ async def run_test(endpoint: str, sample: TranscriptionTest, config: VoiceAgentC # Start time start_time = 
datetime.datetime.now()
+
+    # Zero time
+    def zero_time(message):
+        nonlocal start_time  # start_time is a local of run_test; `global` would not rebind it
+        start_time = datetime.datetime.now()
+
     # Finalized segment
     def add_segments(message):
         segments = message["segments"]
@@ -213,6 +218,13 @@ def log_message(message):
             log = json.dumps({"ts": round(ts, 3), "payload": message})
             print(log)
 
+    # Custom listeners
+    client.on(AgentServerMessageType.RECOGNITION_STARTED, zero_time)
+    client.on(AgentServerMessageType.END_OF_TURN, eot_detected)
+    client.on(AgentServerMessageType.ADD_SEGMENT, add_segments)
+    client.on(AgentServerMessageType.ADD_PARTIAL_TRANSCRIPT, rx_partial)
+    client.on(AgentServerMessageType.ADD_TRANSCRIPT, rx_partial)
+
     # Add listeners
     if SHOW_LOG:
         message_types = [m for m in AgentServerMessageType if m != AgentServerMessageType.AUDIO_ADDED]
@@ -220,12 +232,6 @@ def log_message(message):
         for message_type in message_types:
             client.on(message_type, log_message)
 
-    # Custom listeners
-    client.on(AgentServerMessageType.END_OF_TURN, eot_detected)
-    client.on(AgentServerMessageType.ADD_SEGMENT, add_segments)
-    client.on(AgentServerMessageType.ADD_PARTIAL_TRANSCRIPT, rx_partial)
-    client.on(AgentServerMessageType.ADD_TRANSCRIPT, rx_partial)
-
     # HEADER
     if SHOW_LOG:
         print()
@@ -326,7 +332,9 @@ def log_message(message):
     # Calculate the CER
     cer = TextUtils.cer(normalized_expected, normalized_received)
 
-    print(f"[{idx}] `{normalized_expected}` -> `{normalized_received}` (CER: {cer:.1%})")
+    # Debug metrics
+    if SHOW_LOG:
+        print(f"[{idx}] `{normalized_expected}` -> `{normalized_received}` (CER: {cer:.1%})")
 
     # Check CER
     if cer > CER_THRESHOLD:
diff --git a/tests/voice/test_18_feou_timestamp.py b/tests/voice/test_18_feou_timestamp.py
new file mode 100644
index 00000000..39d85bfe
--- /dev/null
+++ b/tests/voice/test_18_feou_timestamp.py
@@ -0,0 +1,73 @@
+import os
+
+import pytest
+from _utils import get_client
+from _utils import send_silence
+
+from speechmatics.rt import AudioEncoding
+from speechmatics.voice import VoiceAgentConfig
+
+# 
Constants
+API_KEY = os.getenv("SPEECHMATICS_API_KEY")
+
+# Skip in CI and when no API key is provided (list form: a plain reassignment would drop the first mark)
+pytestmark = [pytest.mark.skipif(os.getenv("CI") == "true", reason="Skipping in CI")]
+pytestmark.append(pytest.mark.skipif(API_KEY is None, reason="Skipping when no API key is provided"))
+
+# How much silence to send (seconds)
+SILENCE_DURATION = 3.0
+
+# Tolerance for the timestamp check (0.00 = exact match; widen if float accounting makes this flaky)
+TOLERANCE = 0.00
+
+# Audio format configurations to test: (encoding, chunk_size, bytes_per_sample)
+AUDIO_FORMATS = [
+    pytest.param(AudioEncoding.PCM_S16LE, 160, 2, id="s16-chunk160"),
+    pytest.param(AudioEncoding.PCM_S16LE, 320, 2, id="s16-chunk320"),
+    pytest.param(AudioEncoding.PCM_F32LE, 160, 4, id="f32-chunk160"),
+    pytest.param(AudioEncoding.PCM_F32LE, 320, 4, id="f32-chunk320"),
+]
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize("encoding,chunk_size,sample_size", AUDIO_FORMATS)
+async def test_feou_timestamp(encoding: AudioEncoding, chunk_size: int, sample_size: int):
+    """Test that audio_seconds_sent correctly computes elapsed audio time.
+
+    Sends 3 seconds of silence (zero bytes) with different audio encodings
+    and chunk sizes, then verifies that audio_seconds_sent returns the
+    correct duration. 
+ """ + + # Create and connect client + config = VoiceAgentConfig(audio_encoding=encoding, chunk_size=chunk_size) + client = await get_client( + api_key=API_KEY, + connect=False, + config=config, + ) + + try: + await client.connect() + except Exception: + pytest.skip("Failed to connect to server") + + assert client._is_connected + + # Send 3 seconds of silence + await send_silence( + client, + duration=SILENCE_DURATION, + chunk_size=chunk_size, + sample_size=sample_size, + ) + + # Check the computed audio seconds + actual_seconds = client.audio_seconds_sent + assert ( + abs(actual_seconds - SILENCE_DURATION) <= TOLERANCE + ), f"Expected ~{SILENCE_DURATION}s but got {actual_seconds:.4f}s" + + # Clean up + await client.disconnect() + assert not client._is_connected diff --git a/tests/voice/test_19_no_feou_fix.py b/tests/voice/test_19_no_feou_fix.py new file mode 100644 index 00000000..ad903865 --- /dev/null +++ b/tests/voice/test_19_no_feou_fix.py @@ -0,0 +1,139 @@ +import json +import os +import shutil +import time + +import pytest +from _utils import get_client +from _utils import send_audio_file + +from speechmatics.voice import AgentServerMessageType +from speechmatics.voice import EndOfTurnConfig +from speechmatics.voice import EndOfUtteranceMode +from speechmatics.voice import SmartTurnConfig +from speechmatics.voice import VoiceActivityConfig +from speechmatics.voice import VoiceAgentConfig + +# Skip for CI testing +pytestmark = pytest.mark.skipif(os.getenv("CI") == "true", reason="Skipping transcription tests in CI") + +# Constants +API_KEY = os.getenv("SPEECHMATICS_API_KEY") +SHOW_LOG = os.getenv("SPEECHMATICS_SHOW_LOG", "0").lower() in ["1", "true"] + + +@pytest.mark.asyncio +async def test_no_feou_fix(): + """Test for when FEOU is disabled.""" + + # API key + if not API_KEY: + pytest.skip("Valid API key required for test") + + # Config + config = VoiceAgentConfig( + language="en", + end_of_utterance_mode=EndOfUtteranceMode.ADAPTIVE, + 
end_of_utterance_silence_trigger=0.5, + smart_turn_config=SmartTurnConfig(enabled=True, smart_turn_threshold=0.80), + vad_config=VoiceActivityConfig(enabled=True), + end_of_turn_config=EndOfTurnConfig(base_multiplier=1.0), + ) + + # Debug config + print( + config.to_json( + indent=2, + exclude_none=True, + exclude_defaults=True, + exclude_unset=True, + ) + ) + + # Client + client = await get_client( + api_key=API_KEY, + connect=False, + config=config, + ) + + # Disable FEOU + client._use_forced_eou = False + + # Add listeners + messages = [message for message in AgentServerMessageType if message != AgentServerMessageType.AUDIO_ADDED] + + # Colors for messages + colors = { + "StartOfTurn": "\033[94m", # Blue + "EndOfTurn": "\033[92m", # Green + "AddSegment": "\033[93m", # Yellow + "AddPartialSegment": "\033[38;5;208m", # Orange + "SpeakerStarted": "\033[96m", # Cyan + "SpeakerEnded": "\033[95m", # Magenta + "VadStatus": "\033[91m", # Red + } + + # Callback for each message + term_width = shutil.get_terminal_size().columns + log_start_time = time.monotonic() + + def log_message(message): + """Log a message with color and formatting.""" + + # Elapsed time in seconds (right-aligned, capacity for 100s) + elapsed = time.monotonic() - log_start_time + timestamp = f"{elapsed:>7.3f}" + + # Extract message type and remaining payload (drop noisy keys) + msg_type = message.get("message", "") + rest = {k: v for k, v in message.items() if k not in ("message", "format")} + + # Color based on message type (default: dark gray) + color = colors.get(msg_type, "\033[90m") + reset = "\033[0m" + + # Format: timestamp - fixed-width type label + JSON payload + label = f"{msg_type:<20}" + payload = json.dumps(rest, default=str) + visible = f"{timestamp} - {label} - {payload}" + + # Truncate to terminal width to prevent wrapping + if len(visible) > term_width: + visible = visible[: term_width - 1] + "…" + + # Print with color + print(f"{color}{visible}{reset}") + + # Add listeners + for 
message_type in messages: + client.on(message_type, log_message) + + # Load the audio file `./assets/audio_01_16kHz.wav` + # audio_file = "../../tmp/feou/recording-appointment.wav" + audio_file = "./assets/audio_10_16kHz.wav" + + # HEADER + if SHOW_LOG: + print() + print() + print("---") + + # Connect + await client.connect() + + # Check we are connected + assert client._is_connected + + # Individual payloads + await send_audio_file(client, audio_file) + + # Close session + await client.disconnect() + assert not client._is_connected + + # FOOTER + if SHOW_LOG: + print("---") + print() + print() diff --git a/tests/voice/test_20_stt.py b/tests/voice/test_20_stt.py new file mode 100644 index 00000000..adebddab --- /dev/null +++ b/tests/voice/test_20_stt.py @@ -0,0 +1,129 @@ +"""Raw RT SDK transcription test using 8kHz audio. + +Sends audio_02_8kHz.wav directly via the RT AsyncClient and logs all +server messages (except AUDIO_ADDED) to help debug sample-rate issues. +""" + +import asyncio +import datetime +import json +import os +import time + +import aiofiles +import pytest + +from speechmatics.rt import AsyncClient +from speechmatics.rt import AudioEncoding +from speechmatics.rt import AudioFormat +from speechmatics.rt import ConversationConfig +from speechmatics.rt import ServerMessageType +from speechmatics.rt import SpeakerDiarizationConfig +from speechmatics.rt import TranscriptionConfig + +# Skip for CI testing +pytestmark = pytest.mark.skipif(os.getenv("CI") == "true", reason="Skipping STT tests in CI") + +# Constants +API_KEY = os.getenv("SPEECHMATICS_API_KEY") +URL = os.getenv("SPEECHMATICS_RT_URL", "wss://eu2.rt.speechmatics.com/v2") +AUDIO_FILE = "./assets/audio_02c_8kHz.wav" +SAMPLE_RATE = 8000 +CHUNK_SIZE = 160 + + +@pytest.mark.asyncio +async def test_rt_transcription_8khz(): + """Transcribe 8kHz audio using the RT SDK directly. + + Logs all server messages (except AUDIO_ADDED) to stdout for debugging. 
+ """ + + if not API_KEY: + pytest.skip("Valid API key required for test") + + # Resolve audio file path + audio_path = os.path.join(os.path.dirname(__file__), AUDIO_FILE) + assert os.path.exists(audio_path), f"Audio file not found: {audio_path}" + + # RT client + client = AsyncClient(api_key=API_KEY, url=URL) + + # Logging + start_time = datetime.datetime.now() + messages: list[dict] = [] + + def log_message(message): + ts = (datetime.datetime.now() - start_time).total_seconds() + entry = {"ts": round(ts, 3), "payload": message} + messages.append(entry) + print(json.dumps(entry)) + + # Register listeners for all message types except AUDIO_ADDED + for msg_type in ServerMessageType: + if msg_type != ServerMessageType.AUDIO_ADDED: + client.on(msg_type, log_message) + + # Audio format + audio_format = AudioFormat( + encoding=AudioEncoding.PCM_S16LE, + sample_rate=SAMPLE_RATE, + chunk_size=CHUNK_SIZE, + ) + + # Transcription config + transcription_config = TranscriptionConfig( + language="en", + operating_point="enhanced", + diarization="speaker", + additional_vocab=[{"content": "GeoRouter"}], + enable_entities=False, + audio_filtering_config={"volume_threshold": 0.0}, + max_delay=2.0, + max_delay_mode="flexible", + enable_partials=True, + speaker_diarization_config=SpeakerDiarizationConfig( + speaker_sensitivity=0.5, + prefer_current_speaker=False, + ), + conversation_config=ConversationConfig( + end_of_utterance_silence_trigger=0.25, + ), + ) + + # Debug + print(json.dumps(transcription_config.to_dict(), indent=2)) + print(json.dumps(audio_format.to_dict(), indent=2)) + + # Start session + await client.start_session( + audio_format=audio_format, + transcription_config=transcription_config, + ) + + # Send audio in real-time + delay = CHUNK_SIZE / SAMPLE_RATE / 2 # 2 bytes per sample (int16) + async with aiofiles.open(audio_path, "rb") as f: + await f.seek(44) # skip WAV header + next_time = time.perf_counter() + delay + while True: + chunk = await f.read(CHUNK_SIZE) + 
if not chunk: + break + await client.send_audio(chunk) + sleep_time = next_time - time.perf_counter() + if sleep_time > 0: + await asyncio.sleep(sleep_time) + next_time += delay + + # Stop session and wait for EndOfTranscript + await client.stop_session() + + # Basic assertions + assert len(messages) > 0, "No messages received from server" + + # Check we got at least one final transcript + finals = [m for m in messages if m["payload"].get("message") == "AddTranscript"] + assert len(finals) > 0, "No final transcripts received" + + print(f"\n--- Summary: {len(messages)} messages, {len(finals)} final transcripts ---") diff --git a/tests/voice/test_21_stt_raw.py b/tests/voice/test_21_stt_raw.py new file mode 100644 index 00000000..cba07969 --- /dev/null +++ b/tests/voice/test_21_stt_raw.py @@ -0,0 +1,183 @@ +"""Raw WebSocket transcription test using 8kHz audio. + +Bypasses the RT SDK entirely and uses raw websockets to connect to the +Speechmatics endpoint. This isolates whether issues are server-side or +SDK-related. +""" + +import asyncio +import datetime +import json +import os +import time +import wave + +import pytest +import websockets + +# Skip for CI testing +pytestmark = pytest.mark.skipif(os.getenv("CI") == "true", reason="Skipping raw STT tests in CI") + +# Constants +API_KEY = os.getenv("SPEECHMATICS_API_KEY") +URL = os.getenv("SPEECHMATICS_RT_URL", "wss://eu2.rt.speechmatics.com/v2") +AUDIO_FILE = os.getenv("AUDIO_FILE", "./assets/audio_01_16kHz.wav") +CHUNK_SIZE = 160 +RECV_TIMEOUT = 5.0 + + +@pytest.mark.asyncio +async def test_raw_ws_transcription_8khz(): + """Transcribe 8kHz audio using raw WebSocket (no SDK). + + Logs all server messages (except AudioAdded) to stdout for debugging. 
+ """ + + if not API_KEY: + pytest.skip("Valid API key required for test") + + # Resolve audio file path + audio_path = os.path.join(os.path.dirname(__file__), AUDIO_FILE) + assert os.path.exists(audio_path), f"Audio file not found: {audio_path}" + + # Load audio from WAV file + with wave.open(audio_path, "rb") as wf: + sample_rate = wf.getframerate() + sample_width = wf.getsampwidth() + n_channels = wf.getnchannels() + audio_data = wf.readframes(wf.getnframes()) + + # Only mono audio is supported for this test + if n_channels != 1: + pytest.skip(f"Skipping: expected mono audio, got {n_channels} channels") + + # Only 16-bit PCM is supported + if sample_width != 2: + pytest.skip(f"Skipping: expected 16-bit audio, got {sample_width * 8}-bit") + + # Debug + print(f"processing audio file: {audio_path}") + print(f" -> WAV: {sample_rate}Hz, {sample_width * 8}-bit, {n_channels}ch, {len(audio_data)} bytes") + + # Build the StartRecognition message sent over the WebSocket. + # This is the exact JSON the server expects as the first message. 
+ start_recognition = { + "message": "StartRecognition", + "audio_format": { + "type": "raw", + "encoding": "pcm_s16le", + "sample_rate": sample_rate, + }, + "transcription_config": { + "language": "en", + "operating_point": "standard", + "diarization": "speaker", + "additional_vocab": [{"content": "GeoRouter"}], + "enable_entities": False, + "audio_filtering_config": {"volume_threshold": 0.0}, + "max_delay": 2.0, + "max_delay_mode": "flexible", + "enable_partials": True, + "speaker_diarization_config": { + "speaker_sensitivity": 0.5, + "prefer_current_speaker": False, + }, + "conversation_config": { + "end_of_utterance_silence_trigger": 0.25, + }, + }, + } + + # Log the config for debugging + print(json.dumps(start_recognition, indent=2)) + + # Track wall-clock time for log timestamps + start_time = datetime.datetime.now() + messages: list[dict] = [] + + # Log messages from WebSocket + def log_message(msg: dict): + """Append message to buffer and print with wall-clock offset.""" + ts = (datetime.datetime.now() - start_time).total_seconds() + entry = {"ts": round(ts, 3), "payload": msg} + messages.append(entry) + print(json.dumps(entry)) + + # Connect via raw WebSocket with Bearer token auth + async with websockets.connect( + URL, + additional_headers={"Authorization": f"Bearer {API_KEY}"}, + ) as ws: + + # First message must be StartRecognition + await ws.send(json.dumps(start_recognition)) + + # The server may send Info messages before RecognitionStarted, + # so we loop until we see the expected handshake response + while True: + raw = await asyncio.wait_for(ws.recv(), timeout=RECV_TIMEOUT) + if isinstance(raw, str): + msg = json.loads(raw) + log_message(msg) + if msg.get("message") == "RecognitionStarted": + break + + # Signalled when the server sends EndOfTranscript + eot_event = asyncio.Event() + + async def rx(): + """Background task to receive and log server messages.""" + while True: + try: + raw = await asyncio.wait_for(ws.recv(), timeout=RECV_TIMEOUT) + 
if isinstance(raw, str):
+                    msg = json.loads(raw)
+                    # AudioAdded is high-frequency and not useful for debugging
+                    if msg.get("message") != "AudioAdded":
+                        log_message(msg)
+                    if msg.get("message") == "EndOfTranscript":
+                        eot_event.set()
+                        break
+            except (asyncio.TimeoutError, websockets.ConnectionClosed):
+                break
+
+        # RX task
+        rx_task = asyncio.create_task(rx())
+
+        # Stream audio from the in-memory buffer at real-time pace.
+        # delay = duration of one chunk in seconds (bytes / rate / width)
+        delay = CHUNK_SIZE / sample_rate / sample_width
+        offset = 0
+        next_time = time.perf_counter() + delay
+        while offset < len(audio_data):
+            chunk = audio_data[offset : offset + CHUNK_SIZE]
+            offset += CHUNK_SIZE
+            await ws.send(chunk)
+            # Pace sending to match real-time playback speed
+            sleep_time = next_time - time.perf_counter()
+            if sleep_time > 0:
+                await asyncio.sleep(sleep_time)
+            next_time += delay
+
+        # Signal end of audio; the protocol expects last_seq_no = number of audio chunks sent
+        await ws.send(json.dumps({"message": "EndOfStream", "last_seq_no": (len(audio_data) + CHUNK_SIZE - 1) // CHUNK_SIZE}))
+
+        # Wait for the server to finish processing and send EndOfTranscript
+        await asyncio.wait_for(eot_event.wait(), timeout=RECV_TIMEOUT)
+
+        # Clean up the background receiver
+        rx_task.cancel()
+        try:
+            await rx_task
+        except asyncio.CancelledError:
+            pass
+
+    # Verify we received messages
+    assert len(messages) > 0, "No messages received from server"
+
+    # Verify at least one final transcript was produced
+    finals = [m for m in messages if m["payload"].get("message") == "AddTranscript"]
+    assert len(finals) > 0, "No final transcripts received"
+
+    # Final report
+    print(f"\n--- Summary: {len(messages)} messages, {len(finals)} final transcripts ---")