diff --git a/src/openbench/cli/commands/evaluate.py b/src/openbench/cli/commands/evaluate.py index 1b4f5f7..7470c20 100644 --- a/src/openbench/cli/commands/evaluate.py +++ b/src/openbench/cli/commands/evaluate.py @@ -176,6 +176,7 @@ def run_alias_mode( wandb_tags: list[str] | None, use_keywords: bool | None, force_language: bool, + pipeline_config: list[str] | None, verbose: bool, ) -> BenchmarkResult: """Run evaluation using pipeline and dataset aliases.""" @@ -208,6 +209,16 @@ def run_alias_mode( if verbose: typer.echo("✅ Force language: enabled") + # Handle generic pipeline config overrides (key=value pairs) + if pipeline_config: + for item in pipeline_config: + if "=" not in item: + raise typer.BadParameter(f"Invalid --pipeline-config format: '{item}'. Expected key=value") + key, value = item.split("=", 1) + pipeline_config_override[key] = value + if verbose: + typer.echo(f"Config override: {key}={value}") + pipeline = PipelineRegistry.create_pipeline(pipeline_name, config=pipeline_config_override) ######### Build Benchmark Config ######### @@ -345,6 +356,12 @@ def evaluate( "--force-language", help="Force language hinting for compatible pipelines", ), + pipeline_config: list[str] | None = typer.Option( + None, + "--pipeline-config", + "-pc", + help="Override pipeline config values as key=value pairs (e.g. --pipeline-config speaker=serena)", + ), verbose: bool = typer.Option(False, "--verbose", "-v", help="Enable verbose output"), ) -> None: """Run evaluation benchmarks. @@ -406,6 +423,7 @@ def evaluate( wandb_tags=wandb_tags, use_keywords=use_keywords, force_language=force_language, + pipeline_config=pipeline_config, verbose=verbose, ) display_result(result) diff --git a/src/openbench/dataset/__init__.py b/src/openbench/dataset/__init__.py index 8c90fae..540942e 100644 --- a/src/openbench/dataset/__init__.py +++ b/src/openbench/dataset/__init__.py @@ -6,6 +6,7 @@ from .dataset_diarization import DiarizationDataset, DiarizationSample from .dataset_orchestration import OrchestrationDataset, OrchestrationSample from .dataset_registry import DatasetRegistry +from .dataset_speech_generation import SpeechGenerationDataset, SpeechGenerationSample from .dataset_streaming_transcription import StreamingDataset, StreamingSample from .dataset_transcription import TranscriptionDataset, TranscriptionSample @@ -24,11 +25,13 @@ "TranscriptionDataset", "StreamingDataset", "OrchestrationDataset", + "SpeechGenerationDataset", # Sample types "DiarizationSample", "TranscriptionSample", "StreamingSample", "OrchestrationSample", + "SpeechGenerationSample", # Registry "DatasetRegistry", ] diff --git a/src/openbench/dataset/dataset_aliases.py b/src/openbench/dataset/dataset_aliases.py index c37af42..788c568 100644 --- a/src/openbench/dataset/dataset_aliases.py +++ b/src/openbench/dataset/dataset_aliases.py @@ -554,6 +554,20 @@ def register_dataset_aliases() -> None: description="Common Voice dataset for transcription evaluation with up to 400 samples per language this subset contains only russian", ) + ########## SPEECH GENERATION ########## + + DatasetRegistry.register_alias( + "customer-service-tts-prompts-vocalized", + DatasetConfig( + dataset_id="argmaxinc/customer-service-tts-prompts-vocalized", + split="validation", + ), + supported_pipeline_types={ + PipelineType.SPEECH_GENERATION, + }, + description="Customer service TTS prompts with vocalized audio for speech generation evaluation.", + ) + ########## STREAMING TRANSCRIPTION ########## DatasetRegistry.register_alias( @@ -646,6 +660,7 @@ def 
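For illustration, a self-contained sketch of the key=value semantics the new `--pipeline-config` flag implements (a standalone restatement of the parsing in `run_alias_mode` above; note values stay strings, and any type coercion is left to the pipeline's config model):

```python
# Hypothetical invocation (the entry-point name is an assumption):
#   openbench evaluate ... --pipeline-config speaker=serena -pc seed=42

def parse_pipeline_config(items: list[str]) -> dict[str, str]:
    """Parse repeated key=value CLI arguments into an override dict."""
    overrides: dict[str, str] = {}
    for item in items:
        if "=" not in item:
            raise ValueError(f"Invalid --pipeline-config format: '{item}'. Expected key=value")
        # Split only on the first '=', so values may themselves contain '='.
        key, value = item.split("=", 1)
        overrides[key] = value
    return overrides

assert parse_pipeline_config(["speaker=serena", "seed=42"]) == {"speaker": "serena", "seed": "42"}
```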
register_dataset_aliases() -> None: PipelineType.DIARIZATION, PipelineType.STREAMING_TRANSCRIPTION, PipelineType.ORCHESTRATION, + PipelineType.SPEECH_GENERATION, }, description="Local dataset for testing. To use this dataset you need to set the `LOCAL_DATASET_PATH` and `LOCAL_DATASET_SPLIT` environment variables.", ) diff --git a/src/openbench/dataset/dataset_registry.py b/src/openbench/dataset/dataset_registry.py index f73ae17..1b40e3e 100644 --- a/src/openbench/dataset/dataset_registry.py +++ b/src/openbench/dataset/dataset_registry.py @@ -8,6 +8,7 @@ from .dataset_base import BaseDataset, DatasetConfig from .dataset_diarization import DiarizationDataset from .dataset_orchestration import OrchestrationDataset +from .dataset_speech_generation import SpeechGenerationDataset from .dataset_streaming_transcription import StreamingDataset from .dataset_transcription import TranscriptionDataset @@ -139,3 +140,4 @@ def has_alias(cls, alias: str) -> bool: DatasetRegistry.register(PipelineType.ORCHESTRATION, OrchestrationDataset) DatasetRegistry.register(PipelineType.STREAMING_TRANSCRIPTION, StreamingDataset) DatasetRegistry.register(PipelineType.TRANSCRIPTION, TranscriptionDataset) +DatasetRegistry.register(PipelineType.SPEECH_GENERATION, SpeechGenerationDataset) diff --git a/src/openbench/dataset/dataset_speech_generation.py b/src/openbench/dataset/dataset_speech_generation.py new file mode 100644 index 0000000..7341424 --- /dev/null +++ b/src/openbench/dataset/dataset_speech_generation.py @@ -0,0 +1,102 @@ +# For licensing see accompanying LICENSE.md file. +# Copyright (C) 2025 Argmax, Inc. All Rights Reserved. + +import numpy as np +from pydantic import Field +from typing_extensions import TypedDict + +from ..pipeline_prediction import Transcript +from .dataset_base import BaseDataset, BaseSample + + +class SpeechGenerationExtraInfo(TypedDict, total=False): + """Extra info for speech generation samples.""" + + language: str + dialogue: list[dict] + + +class SpeechGenerationRow(TypedDict): + """Expected row structure for speech generation. + + Requires 'text' (the prompt string). No audio needed. + """ + + text: str + + +class SpeechGenerationSample(BaseSample[Transcript, SpeechGenerationExtraInfo]): + """Sample for speech generation tasks. + + The reference Transcript is created from the text + prompt. The pipeline generates audio from this text + and transcribes it to compute WER against reference. + """ + + generated_audio_duration: float | None = Field( + default=None, + description=("Duration (seconds) of the TTS-generated audio. Set by the pipeline after generation."), + ) + + def get_audio_duration(self) -> float: + """Return generated audio duration if available. + + Falls back to the dummy waveform calculation + when the pipeline hasn't set the real duration yet. + """ + if self.generated_audio_duration is not None: + return self.generated_audio_duration + return super().get_audio_duration() + + @property + def text(self) -> str: + """The original text prompt.""" + return self.reference.get_transcript_string() + + +class SpeechGenerationDataset(BaseDataset[SpeechGenerationSample]): + """Dataset for speech generation pipelines. + + Expects column: 'text' (the prompt string). + No audio column is required since audio is generated + by the pipeline itself. + """ + + _expected_columns = ["text"] + _sample_class = SpeechGenerationSample + + def _extract_audio_info(self, row: dict) -> tuple[str, np.ndarray, int]: + """Override to provide dummy audio info. 
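To make the dataset flow concrete, a small sketch of how a prompt row becomes a WER reference, mirroring `prepare_sample` further below (import path taken from this diff's package layout):

```python
from openbench.pipeline_prediction import Transcript

# A minimal speech-generation row: only 'text' is required.
row = {"text": "thank you for calling how may I help you"}

# The reference Transcript is just the whitespace-split prompt.
reference = Transcript.from_words_info(words=row["text"].split())

# A pipeline later synthesizes audio for row["text"], transcribes it,
# and WER is computed between that transcription and this reference.
```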
+ + Speech generation datasets don't have input audio. + We provide a placeholder waveform so the framework + sample structure is satisfied. The pipeline ignores + the waveform entirely. + """ + audio_name = f"sample_{row['idx']}" + # Use audio_name from the row if available + if "audio_name" in row and row["audio_name"]: + audio_name = str(row["audio_name"]) + dummy_waveform = np.zeros(1, dtype=np.float32) + dummy_sample_rate = 16000 + return audio_name, dummy_waveform, dummy_sample_rate + + def prepare_sample(self, row: SpeechGenerationRow) -> tuple[Transcript, SpeechGenerationExtraInfo]: + """Prepare reference from dataset row. + + Splits text prompt into words to create the + reference Transcript. + """ + text = row["text"] + words = text.split() + reference = Transcript.from_words_info( + words=words, + ) + + extra_info: SpeechGenerationExtraInfo = {} + if "language" in row: + extra_info["language"] = row["language"] + if "dialogue" in row and row["dialogue"]: + extra_info["dialogue"] = row["dialogue"] + + return reference, extra_info diff --git a/src/openbench/metric/keyword_boosting_metrics/boosting_metrics.py b/src/openbench/metric/keyword_boosting_metrics/boosting_metrics.py index 6ed5798..84af57e 100644 --- a/src/openbench/metric/keyword_boosting_metrics/boosting_metrics.py +++ b/src/openbench/metric/keyword_boosting_metrics/boosting_metrics.py @@ -26,6 +26,9 @@ def compute_keyword_stats( ) -> dict[str, Any]: """Compute keyword statistics between reference and hypothesis.""" + if not dictionary: + return {"true_positives": 0, "ground_truth": 0, "false_positives": 0, "keyword_stats": {}} + # Convert transcripts to text ref_text = reference.get_transcript_string() hyp_text = hypothesis.get_transcript_string() diff --git a/src/openbench/metric/word_error_metrics/word_error_metrics.py b/src/openbench/metric/word_error_metrics/word_error_metrics.py index 4cef2a1..24655d5 100644 --- a/src/openbench/metric/word_error_metrics/word_error_metrics.py +++ b/src/openbench/metric/word_error_metrics/word_error_metrics.py @@ -223,6 +223,7 @@ def compute_metric(self, detail: Details) -> float: PipelineType.TRANSCRIPTION, PipelineType.ORCHESTRATION, PipelineType.STREAMING_TRANSCRIPTION, + PipelineType.SPEECH_GENERATION, ), MetricOptions.WER, ) diff --git a/src/openbench/pipeline/__init__.py b/src/openbench/pipeline/__init__.py index 598328b..268d0d2 100644 --- a/src/openbench/pipeline/__init__.py +++ b/src/openbench/pipeline/__init__.py @@ -6,6 +6,7 @@ from .diarization import * from .orchestration import * from .pipeline_registry import PipelineRegistry +from .speech_generation import * from .streaming_transcription import * from .transcription import * diff --git a/src/openbench/pipeline/pipeline_aliases.py b/src/openbench/pipeline/pipeline_aliases.py index 97287b0..6ad66f5 100644 --- a/src/openbench/pipeline/pipeline_aliases.py +++ b/src/openbench/pipeline/pipeline_aliases.py @@ -25,6 +25,14 @@ WhisperXPipeline, ) from .pipeline_registry import PipelineRegistry +from .speech_generation import ( + CartesiaSpeechGenerationPipeline, + ElevenLabsDialogueGenerationPipeline, + ElevenLabsSpeechGenerationPipeline, + GeminiSpeechGenerationPipeline, + OpenAISpeechGenerationPipeline, + WhisperKitSpeechGenerationPipeline, +) from .streaming_transcription import ( AssemblyAIStreamingPipeline, DeepgramStreamingPipeline, @@ -642,6 +650,131 @@ def register_pipeline_aliases() -> None: description="PyannoteAI transcription pipeline (ignores speaker attribution). 
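One illustration of the early-exit guard added to `compute_keyword_stats` above: an empty keyword dictionary now short-circuits to zeroed stats instead of paying for transcript conversion and alignment. A sketch, assuming the parameters are exposed under the names used in the function body:

```python
from openbench.metric.keyword_boosting_metrics.boosting_metrics import compute_keyword_stats
from openbench.pipeline_prediction import Transcript

ref = Transcript.from_words_info(words="please call the doctor".split())
hyp = Transcript.from_words_info(words="please call a doctor".split())

# No keywords to boost -> all counters zero, no per-keyword entries.
stats = compute_keyword_stats(reference=ref, hypothesis=hyp, dictionary=[])
assert stats == {"true_positives": 0, "ground_truth": 0, "false_positives": 0, "keyword_stats": {}}
```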
Uses the precision-2 model with Nvidia Parakeet STT. Requires `PYANNOTE_TOKEN` env var from https://www.pyannote.ai/.", ) + ################# SPEECH GENERATION PIPELINES ################# + + PipelineRegistry.register_alias( + "whisperkit-speech-generation", + WhisperKitSpeechGenerationPipeline, + default_config={ + "out_dir": "./speech_generation_results", + "cli_path": os.getenv("WHISPERKIT_CLI_PATH"), + "speaker": "aiden", + "language": "english", + "seed": 10, + "temperature": 0.9, + "top_k": 50, + "max_new_tokens": 245, + "transcription_cli_path": os.getenv("WHISPERKITPRO_CLI_PATH"), + "transcription_repo_id": "argmaxinc/parakeetkit-pro", + "transcription_model_variant": "nvidia_parakeet-v2_476MB", + }, + description="WhisperKit speech generation pipeline. Generates audio from text prompts using whisperkit-cli TTS, " + "then transcribes the generated audio to compute WER against the original prompt. " + "Requires `WHISPERKIT_CLI_PATH` env var pointing to the whisperkit-cli binary.", + ) + + PipelineRegistry.register_alias( + "cartesia-speech-generation", + CartesiaSpeechGenerationPipeline, + default_config={ + "out_dir": "./speech_generation_results", + "model_id": "sonic-3", + "voice_id": "e07c00bc-4134-4eae-9ea4-1a55fb45746b", + "container": "wav", + "encoding": "pcm_f32le", + "sample_rate": 44100, + "transcription_cli_path": os.getenv("WHISPERKITPRO_CLI_PATH"), + "transcription_repo_id": "argmaxinc/parakeetkit-pro", + "transcription_model_variant": "nvidia_parakeet-v2_476MB", + "keep_generated_audio": False, + }, + description="Cartesia speech generation pipeline. Generates audio from text prompts using Cartesia TTS API, " + "then transcribes the generated audio to compute WER against the original prompt. " + "Requires `CARTESIA_API_KEY` and `WHISPERKITPRO_CLI_PATH` env vars.", + ) + + PipelineRegistry.register_alias( + "elevenlabs-dialogue-generation", + ElevenLabsDialogueGenerationPipeline, + default_config={ + "out_dir": "./speech_generation_results", + "model_id": "eleven_v3", + "speaker_voice_map": { + "doctor": "9BWtsMINqrJLrRacOk9x", + "patient": "IKne3meq5aSn9XLyUdCD", + "assistant": "pFZP5JQG7iQjIQuC4Bku", + }, + "default_voice_id": "9BWtsMINqrJLrRacOk9x", + "max_chars_per_chunk": 4500, + "chunk_silence_duration": 0.75, + "transcription_cli_path": os.getenv("WHISPERKITPRO_CLI_PATH"), + "transcription_repo_id": "argmaxinc/parakeetkit-pro", + "transcription_model_variant": "nvidia_parakeet-v2_476MB", + "keep_generated_audio": False, + }, + description="ElevenLabs dialogue generation pipeline. Generates multi-speaker conversational audio " + "from dialogue turns using ElevenLabs text_to_dialogue API, then transcribes the generated " + "audio to compute WER against the original dialogue text. " + "Requires `ELEVENLABS_API_KEY` and `WHISPERKITPRO_CLI_PATH` env vars.", + ) + + PipelineRegistry.register_alias( + "elevenlabs-speech-generation", + ElevenLabsSpeechGenerationPipeline, + default_config={ + "out_dir": "./speech_generation_results", + "voice_id": "JBFqnCBsd6RMkjVDRZzb", + "model_id": "eleven_multilingual_v2", + "output_format": "mp3_44100_128", + "transcription_cli_path": os.getenv("WHISPERKITPRO_CLI_PATH"), + "transcription_repo_id": "argmaxinc/parakeetkit-pro", + "transcription_model_variant": "nvidia_parakeet-v2_476MB", + "keep_generated_audio": False, + }, + description="ElevenLabs speech generation pipeline. Generates audio from text prompts using ElevenLabs TTS API, " + "then transcribes the generated audio to compute WER against the original prompt. 
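The `default_config` blocks above are starting points; a run can layer overrides on top of them. A sketch using `create_pipeline` as it is called from `run_alias_mode` in evaluate.py (the exact defaults-vs-overrides merge semantics are assumed):

```python
from openbench.pipeline import PipelineRegistry

# Override a couple of the Cartesia alias defaults for this run.
pipeline = PipelineRegistry.create_pipeline(
    "cartesia-speech-generation",
    config={"sample_rate": 24000, "keep_generated_audio": True},
)
# Equivalent from the CLI via the new flag:
#   ... --pipeline-config sample_rate=24000 -pc keep_generated_audio=true
```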
" + "Requires `ELEVENLABS_API_KEY` and `WHISPERKITPRO_CLI_PATH` env vars.", + ) + + PipelineRegistry.register_alias( + "gemini-speech-generation", + GeminiSpeechGenerationPipeline, + default_config={ + "out_dir": "./speech_generation_results", + "voice_name": "Charon", + "language_code": "en-US", + "model_name": "gemini-2.5-pro-tts", + "audio_encoding": "MP3", + "transcription_cli_path": os.getenv("WHISPERKITPRO_CLI_PATH"), + "transcription_repo_id": "argmaxinc/parakeetkit-pro", + "transcription_model_variant": "nvidia_parakeet-v2_476MB", + "keep_generated_audio": False, + }, + description="Google Gemini speech generation pipeline. Generates audio from text prompts using Google Cloud TTS, " + "then transcribes the generated audio to compute WER against the original prompt. " + "Requires Google Cloud credentials and `WHISPERKITPRO_CLI_PATH` env var.", + ) + + PipelineRegistry.register_alias( + "openai-speech-generation", + OpenAISpeechGenerationPipeline, + default_config={ + "out_dir": "./speech_generation_results", + "model": "gpt-4o-mini-tts", + "voice": "coral", + "response_format": "wav", + "speed": 1.0, + "transcription_cli_path": os.getenv("WHISPERKITPRO_CLI_PATH"), + "transcription_repo_id": "argmaxinc/parakeetkit-pro", + "transcription_model_variant": "nvidia_parakeet-v2_476MB", + "keep_generated_audio": False, + }, + description="OpenAI speech generation pipeline. Generates audio from text prompts using OpenAI TTS API, " + "then transcribes the generated audio to compute WER against the original prompt. " + "Requires `OPENAI_API_KEY` and `WHISPERKITPRO_CLI_PATH` env vars.", + ) + + ################# STREAMING TRANSCRIPTION PIPELINES ################# PipelineRegistry.register_alias( diff --git a/src/openbench/pipeline/speech_generation/__init__.py b/src/openbench/pipeline/speech_generation/__init__.py new file mode 100644 index 0000000..23f77f9 --- /dev/null +++ b/src/openbench/pipeline/speech_generation/__init__.py @@ -0,0 +1,42 @@ +# For licensing see accompanying LICENSE.md file. +# Copyright (C) 2025 Argmax, Inc. All Rights Reserved. 
+ +from .common import SpeechGenerationConfig, SpeechGenerationOutput +from .speech_generation_cartesia import ( + CartesiaSpeechGenerationConfig, + CartesiaSpeechGenerationPipeline, +) +from .speech_generation_elevenlabs import ( + ElevenLabsSpeechGenerationConfig, + ElevenLabsSpeechGenerationPipeline, +) +from .speech_generation_elevenlabs_dialogue import ( + ElevenLabsDialogueGenerationConfig, + ElevenLabsDialogueGenerationPipeline, +) +from .speech_generation_gemini import ( + GeminiSpeechGenerationConfig, + GeminiSpeechGenerationPipeline, +) +from .speech_generation_openai import ( + OpenAISpeechGenerationConfig, + OpenAISpeechGenerationPipeline, +) +from .speech_generation_wkp import WhisperKitSpeechGenerationPipeline + + +__all__ = [ + "CartesiaSpeechGenerationConfig", + "CartesiaSpeechGenerationPipeline", + "ElevenLabsDialogueGenerationConfig", + "ElevenLabsDialogueGenerationPipeline", + "ElevenLabsSpeechGenerationConfig", + "ElevenLabsSpeechGenerationPipeline", + "GeminiSpeechGenerationConfig", + "GeminiSpeechGenerationPipeline", + "OpenAISpeechGenerationConfig", + "OpenAISpeechGenerationPipeline", + "SpeechGenerationConfig", + "SpeechGenerationOutput", + "WhisperKitSpeechGenerationPipeline", +] diff --git a/src/openbench/pipeline/speech_generation/common.py b/src/openbench/pipeline/speech_generation/common.py new file mode 100644 index 0000000..fd8b7b4 --- /dev/null +++ b/src/openbench/pipeline/speech_generation/common.py @@ -0,0 +1,95 @@ +# For licensing see accompanying LICENSE.md file. +# Copyright (C) 2025 Argmax, Inc. All Rights Reserved. + +from pydantic import Field + +from ...pipeline_prediction import Transcript +from ..base import PipelineConfig, PipelineOutput + + +class SpeechGenerationConfig(PipelineConfig): + """Base config for speech generation pipelines.""" + + cli_path: str = Field( + ..., + description=("Path to the whisperkit-cli binary (used for TTS generation)."), + ) + + # TTS parameters + speaker: str = Field( + default="aiden", + description="Speaker voice for TTS generation.", + ) + language: str = Field( + default="english", + description="Language for TTS generation.", + ) + seed: int | None = Field( + default=None, + description="Random seed for reproducible output.", + ) + temperature: float = Field( + default=0.9, + description="Sampling temperature for TTS.", + ) + top_k: int = Field( + default=50, + description="Top-k sampling for TTS.", + ) + max_new_tokens: int = Field( + default=245, + description="Max RVQ frames to generate.", + ) + models_path: str | None = Field( + default=None, + description="Local model directory for TTS.", + ) + model_repo: str | None = Field( + default=None, + description="HF repo for TTS model download.", + ) + version_dir: str | None = Field( + default=None, + description="TTS model version directory.", + ) + tokenizer: str | None = Field( + default=None, + description="HF tokenizer repo or local path.", + ) + + # Transcription parameters + transcription_cli_path: str | None = Field( + default=None, + description=("Path to CLI for transcription. Defaults to cli_path if not set."), + ) + transcription_repo_id: str | None = Field( + default=None, + description=("HuggingFace repo ID for transcription model (e.g. argmaxinc/parakeetkit-pro)."), + ) + transcription_model_variant: str | None = Field( + default=None, + description=("Model variant folder within the repo (e.g. nvidia_parakeet-v2_476MB)."), + ) + transcription_model_path: str | None = Field( + default=None, + description=("Local path to ASR model dir. 
Overrides repo_id/model_variant."), + ) + transcription_word_timestamps: bool = Field( + default=True, + description="Include word timestamps.", + ) + transcription_chunking_strategy: str = Field( + default="vad", + description="Chunking strategy (none or vad).", + ) + + +class SpeechGenerationOutput(PipelineOutput[Transcript]): + """Output for speech generation pipelines. + + The prediction is a Transcript of the generated audio + (obtained by transcribing the TTS output). WER is + computed against the original text prompt. + """ + + pass diff --git a/src/openbench/pipeline/speech_generation/speech_generation_cartesia.py b/src/openbench/pipeline/speech_generation/speech_generation_cartesia.py new file mode 100644 index 0000000..dcf11f9 --- /dev/null +++ b/src/openbench/pipeline/speech_generation/speech_generation_cartesia.py @@ -0,0 +1,345 @@ +# For licensing see accompanying LICENSE.md file. +# Copyright (C) 2025 Argmax, Inc. All Rights Reserved. + +""" +Speech generation pipeline using Cartesia TTS API. + +Generates TTS audio from text prompts via Cartesia, +then transcribes the generated audio back to text using +WhisperKitPro (Parakeet) for WER evaluation against the +original prompt. +""" + +import os +import time +from pathlib import Path +from typing import Callable + +from argmaxtools.utils import get_logger +from pydantic import BaseModel, Field + +from ...dataset.dataset_base import BaseSample +from ...dataset.dataset_speech_generation import ( + SpeechGenerationSample, +) +from ...engine.whisperkitpro_engine import ( + WhisperKitPro, + WhisperKitProConfig, + WhisperKitProInput, +) +from ...pipeline_prediction import Transcript +from ..base import ( + Pipeline, + PipelineConfig, + PipelineOutput, + PipelineType, + register_pipeline, +) +from .common import SpeechGenerationOutput + +logger = get_logger(__name__) + +TEMP_TTS_AUDIO_DIR = Path("./temp_tts_audio") + + +class CartesiaSpeechGenerationConfig(PipelineConfig): + """Config for the Cartesia speech generation pipeline.""" + + # Cartesia TTS parameters + api_key: str | None = Field( + default=None, + description=( + "Cartesia API key. Falls back to " + "CARTESIA_API_KEY env var." + ), + ) + model_id: str = Field( + default="sonic-3", + description="Cartesia TTS model ID.", + ) + voice_id: str = Field( + default="e07c00bc-4134-4eae-9ea4-1a55fb45746b", + description="Cartesia voice ID.", + ) + container: str = Field( + default="wav", + description="Audio container format (wav, raw).", + ) + encoding: str = Field( + default="pcm_f32le", + description=( + "Audio encoding " + "(pcm_f32le, pcm_s16le, pcm_mulaw, pcm_alaw)." + ), + ) + sample_rate: int = Field( + default=44100, + description="Audio sample rate in Hz.", + ) + + # Transcription parameters (WhisperKitPro / Parakeet) + transcription_cli_path: str = Field( + ..., + description=( + "Path to the whisperkit-cli binary " + "used for transcription." + ), + ) + transcription_repo_id: str | None = Field( + default=None, + description=( + "HuggingFace repo ID for transcription " + "model (e.g. argmaxinc/parakeetkit-pro)." + ), + ) + transcription_model_variant: str | None = Field( + default=None, + description=( + "Model variant folder within the repo " + "(e.g. nvidia_parakeet-v2_476MB)." + ), + ) + transcription_model_path: str | None = Field( + default=None, + description=( + "Local path to ASR model dir. " + "Overrides repo_id/model_variant." 
+ ), + ) + transcription_word_timestamps: bool = Field( + default=True, + description="Include word timestamps.", + ) + transcription_chunking_strategy: str = Field( + default="vad", + description="Chunking strategy (none or vad).", + ) + + keep_generated_audio: bool = Field( + default=False, + description=( + "If True, keep the generated TTS audio " + "files instead of deleting them." + ), + ) + + +class CartesiaSpeechGenerationInput(BaseModel): + """Input for the Cartesia speech generation pipeline.""" + + text: str = Field( + ..., + description="Text prompt to generate speech from.", + ) + audio_name: str = Field( + ..., + description=( + "Unique identifier for this sample " + "(used for temp file naming)." + ), + ) + + +@register_pipeline +class CartesiaSpeechGenerationPipeline(Pipeline): + """Speech generation pipeline using Cartesia TTS API. + + This pipeline: + 1. Generates audio from text via Cartesia TTS API + 2. Transcribes audio via WhisperKitPro engine (Parakeet) + 3. Returns transcription as Transcript for WER eval + 4. Cleans up temporary audio and report files + """ + + _config_class = CartesiaSpeechGenerationConfig + pipeline_type = PipelineType.SPEECH_GENERATION + + def build_pipeline( + self, + ) -> Callable[[CartesiaSpeechGenerationInput], Transcript]: + config = self.config + pipeline_ref = self + + transcription_engine = self._build_transcription_engine() + + api_key = config.api_key or os.getenv("CARTESIA_API_KEY") + if not api_key: + raise ValueError( + "Cartesia API key must be provided " + "via config or CARTESIA_API_KEY env var." + ) + + from cartesia import Cartesia + + client = Cartesia(api_key=api_key) + + def generate_and_transcribe( + input: CartesiaSpeechGenerationInput, + ) -> Transcript: + TEMP_TTS_AUDIO_DIR.mkdir(parents=True, exist_ok=True) + + audio_path = ( + TEMP_TTS_AUDIO_DIR + / f"{input.audio_name}.{config.container}" + ) + + # -- Step 1: Generate audio via Cartesia API -- + response = client.tts.generate( + model_id=config.model_id, + output_format={ + "container": config.container, + "encoding": config.encoding, + "sample_rate": config.sample_rate, + }, + transcript=input.text, + voice={ + "mode": "id", + "id": config.voice_id, + }, + ) + response.write_to_file(str(audio_path)) + + if ( + not audio_path.exists() + or audio_path.stat().st_size == 0 + ): + raise RuntimeError( + "Cartesia TTS failed: audio file " + f"missing or empty at {audio_path}" + ) + + logger.info( + f"Generated Cartesia TTS audio: {audio_path}" + ) + + # -- Step 2: Read audio duration -- + try: + import soundfile as sf + + info = sf.info(str(audio_path)) + pipeline_ref._last_generated_duration = ( + info.duration + ) + except Exception as e: + logger.warning( + f"Audio duration read failed: {e}" + ) + pipeline_ref._last_generated_duration = None + + # -- Step 3: Transcribe via WhisperKitPro -- + engine_input = WhisperKitProInput( + audio_path=audio_path, + keep_audio=config.keep_generated_audio, + ) + engine_output = transcription_engine(engine_input) + + # -- Step 4: Parse transcription report -- + json_path = engine_output.json_report_path + if json_path.exists(): + import json + + with json_path.open("r") as f: + data = json.load(f) + all_words, all_starts, all_ends = [], [], [] + for seg in data.get("segments", []): + for w in seg.get("words", []): + all_words.append(w["word"]) + if "start" in w: + all_starts.append(w["start"]) + if "end" in w: + all_ends.append(w["end"]) + transcript = Transcript.from_words_info( + words=all_words, + start=( + all_starts if all_starts 
else None + ), + end=all_ends if all_ends else None, + ) + json_path.unlink(missing_ok=True) + srt_path = engine_output.srt_report_path + if srt_path: + srt_path.unlink(missing_ok=True) + else: + raise RuntimeError( + "Transcription report not found " + f"at {json_path}" + ) + + text_preview = ( + transcript.get_transcript_string()[:100] + ) + logger.info(f"Transcription: {text_preview}...") + + return transcript + + return generate_and_transcribe + + def _build_transcription_engine(self) -> WhisperKitPro: + """Create WhisperKitPro engine for transcription.""" + config = self.config + + import coremltools as ct + + compute = ct.ComputeUnit.CPU_AND_NE + engine_config = WhisperKitProConfig( + repo_id=config.transcription_repo_id, + model_variant=config.transcription_model_variant, + model_dir=config.transcription_model_path, + word_timestamps=( + config.transcription_word_timestamps + ), + chunking_strategy=( + config.transcription_chunking_strategy + ), + audio_encoder_compute_units=compute, + text_decoder_compute_units=compute, + ) + + return WhisperKitPro( + cli_path=config.transcription_cli_path, + transcription_config=engine_config, + ) + + def __call__( + self, input_sample: BaseSample + ) -> PipelineOutput: + """Run pipeline and set generated audio duration.""" + self._last_generated_duration: float | None = None + parsed_input = self.parse_input(input_sample) + start_time = time.perf_counter() + output = self.pipeline(parsed_input) + end_time = time.perf_counter() + prediction_time = end_time - start_time + parsed_output = self.parse_output(output) + if parsed_output.prediction_time is None: + parsed_output.prediction_time = prediction_time + + dur = self._last_generated_duration + logger.debug(f"Generated audio duration: {dur}s") + is_sg = isinstance( + input_sample, SpeechGenerationSample + ) + if is_sg and dur is not None: + input_sample.generated_audio_duration = dur + dur_val = input_sample.generated_audio_duration + logger.debug( + f"Set sample duration to {dur_val}s" + ) + + return parsed_output + + def parse_input( + self, input_sample: SpeechGenerationSample + ) -> CartesiaSpeechGenerationInput: + """Extract text prompt from the sample.""" + text = input_sample.reference.get_transcript_string() + return CartesiaSpeechGenerationInput( + text=text, + audio_name=input_sample.audio_name, + ) + + def parse_output( + self, output: Transcript + ) -> SpeechGenerationOutput: + """Wrap transcription into output.""" + return SpeechGenerationOutput(prediction=output) diff --git a/src/openbench/pipeline/speech_generation/speech_generation_elevenlabs.py b/src/openbench/pipeline/speech_generation/speech_generation_elevenlabs.py new file mode 100644 index 0000000..1e1af94 --- /dev/null +++ b/src/openbench/pipeline/speech_generation/speech_generation_elevenlabs.py @@ -0,0 +1,303 @@ +# For licensing see accompanying LICENSE.md file. +# Copyright (C) 2025 Argmax, Inc. All Rights Reserved. + +""" +Speech generation pipeline using ElevenLabs TTS API. + +Generates TTS audio from text prompts via ElevenLabs, +then transcribes the generated audio back to text using +WhisperKitPro (Parakeet) for WER evaluation against the +original prompt. 
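The Cartesia pipeline above and the ElevenLabs pipeline that follows share the same generate-then-transcribe shape. A condensed usage sketch with names from these modules (constructor signature and sample wiring are assumptions):

```python
from openbench.pipeline.speech_generation import (
    CartesiaSpeechGenerationConfig,
    CartesiaSpeechGenerationPipeline,
)

config = CartesiaSpeechGenerationConfig(
    transcription_cli_path="/path/to/whisperkit-cli",       # placeholder path
    transcription_repo_id="argmaxinc/parakeetkit-pro",
    transcription_model_variant="nvidia_parakeet-v2_476MB",
)
pipeline = CartesiaSpeechGenerationPipeline(config)  # constructor signature assumed

# `sample` is a SpeechGenerationSample; CARTESIA_API_KEY is read from the env.
output = pipeline(sample)
print(output.prediction.get_transcript_string())  # hypothesis text scored by WER
```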
+""" + +import os +import time +from pathlib import Path +from typing import Callable + +from argmaxtools.utils import get_logger +from pydantic import BaseModel, Field + +from ...dataset.dataset_base import BaseSample +from ...dataset.dataset_speech_generation import SpeechGenerationSample +from ...engine.whisperkitpro_engine import ( + WhisperKitPro, + WhisperKitProConfig, + WhisperKitProInput, +) +from ...pipeline_prediction import Transcript +from ..base import ( + Pipeline, + PipelineConfig, + PipelineOutput, + PipelineType, + register_pipeline, +) +from .common import SpeechGenerationOutput + +logger = get_logger(__name__) + +TEMP_TTS_AUDIO_DIR = Path("./temp_tts_audio") + + +class ElevenLabsSpeechGenerationConfig(PipelineConfig): + """Config for the ElevenLabs speech generation pipeline.""" + + # ElevenLabs TTS parameters + api_key: str | None = Field( + default=None, + description=( + "ElevenLabs API key. Falls back to " + "ELEVENLABS_API_KEY env var." + ), + ) + voice_id: str = Field( + default="JBFqnCBsd6RMkjVDRZzb", + description="ElevenLabs voice ID.", + ) + model_id: str = Field( + default="eleven_v3", #"eleven_multilingual_v2", + description="ElevenLabs model ID.", + ) + output_format: str = Field( + default="mp3_44100_128", + description="ElevenLabs output audio format.", + ) + + # Transcription parameters (WhisperKitPro / Parakeet) + transcription_cli_path: str = Field( + ..., + description=( + "Path to the whisperkit-cli binary " + "used for transcription." + ), + ) + transcription_repo_id: str | None = Field( + default=None, + description=( + "HuggingFace repo ID for transcription " + "model (e.g. argmaxinc/parakeetkit-pro)." + ), + ) + transcription_model_variant: str | None = Field( + default=None, + description=( + "Model variant folder within the repo " + "(e.g. nvidia_parakeet-v2_476MB)." + ), + ) + transcription_model_path: str | None = Field( + default=None, + description=( + "Local path to ASR model dir. " + "Overrides repo_id/model_variant." + ), + ) + transcription_word_timestamps: bool = Field( + default=True, + description="Include word timestamps.", + ) + transcription_chunking_strategy: str = Field( + default="vad", + description="Chunking strategy (none or vad).", + ) + + keep_generated_audio: bool = Field( + default=False, + description=( + "If True, keep the generated TTS audio " + "files instead of deleting them." + ), + ) + + +class ElevenLabsSpeechGenerationInput(BaseModel): + """Input for the ElevenLabs speech generation pipeline.""" + + text: str = Field( + ..., + description="Text prompt to generate speech from.", + ) + audio_name: str = Field( + ..., + description=( + "Unique identifier for this sample " + "(used for temp file naming)." + ), + ) + + +@register_pipeline +class ElevenLabsSpeechGenerationPipeline(Pipeline): + """Speech generation pipeline using ElevenLabs TTS API. + + This pipeline: + 1. Generates audio from text via ElevenLabs text-to-speech API + 2. Transcribes audio via WhisperKitPro engine (Parakeet) + 3. Returns transcription as Transcript for WER eval + 4. 
Cleans up temporary audio and report files + """ + + _config_class = ElevenLabsSpeechGenerationConfig + pipeline_type = PipelineType.SPEECH_GENERATION + + def build_pipeline( + self, + ) -> Callable[[ElevenLabsSpeechGenerationInput], Transcript]: + config = self.config + pipeline_ref = self + + transcription_engine = self._build_transcription_engine() + + api_key = config.api_key or os.getenv("ELEVENLABS_API_KEY") + if not api_key: + raise ValueError( + "ElevenLabs API key must be provided " + "via config or ELEVENLABS_API_KEY env var." + ) + + from elevenlabs.client import ElevenLabs + + client = ElevenLabs(api_key=api_key) + + def generate_and_transcribe( + input: ElevenLabsSpeechGenerationInput, + ) -> Transcript: + TEMP_TTS_AUDIO_DIR.mkdir(parents=True, exist_ok=True) + + ext = config.output_format.split("_")[0] + audio_path = TEMP_TTS_AUDIO_DIR / f"{input.audio_name}.{ext}" + + # -- Step 1: Generate audio via ElevenLabs API -- + audio_iter = client.text_to_speech.convert( + text=input.text, + voice_id=config.voice_id, + model_id=config.model_id, + output_format=config.output_format, + ) + + with open(audio_path, "wb") as f: + for chunk in audio_iter: + f.write(chunk) + + if not audio_path.exists() or audio_path.stat().st_size == 0: + raise RuntimeError( + "ElevenLabs TTS failed: audio file " + f"missing or empty at {audio_path}" + ) + + logger.info( + f"Generated ElevenLabs TTS audio: {audio_path}" + ) + + # -- Step 2: Read audio duration -- + try: + import soundfile as sf + + info = sf.info(str(audio_path)) + pipeline_ref._last_generated_duration = info.duration + except Exception as e: + logger.warning(f"Audio duration read failed: {e}") + pipeline_ref._last_generated_duration = None + + # -- Step 3: Transcribe via WhisperKitPro (Parakeet) -- + engine_input = WhisperKitProInput( + audio_path=audio_path, + keep_audio=config.keep_generated_audio, + ) + engine_output = transcription_engine(engine_input) + + # -- Step 4: Parse transcription report -- + json_path = engine_output.json_report_path + if json_path.exists(): + import json + + with json_path.open("r") as f: + data = json.load(f) + all_words, all_starts, all_ends = [], [], [] + for seg in data.get("segments", []): + for w in seg.get("words", []): + all_words.append(w["word"]) + if "start" in w: + all_starts.append(w["start"]) + if "end" in w: + all_ends.append(w["end"]) + transcript = Transcript.from_words_info( + words=all_words, + start=all_starts if all_starts else None, + end=all_ends if all_ends else None, + ) + json_path.unlink(missing_ok=True) + srt_path = engine_output.srt_report_path + if srt_path: + srt_path.unlink(missing_ok=True) + else: + raise RuntimeError( + "Transcription report not found " + f"at {json_path}" + ) + + text_preview = transcript.get_transcript_string()[:100] + logger.info(f"Transcription: {text_preview}...") + + return transcript + + return generate_and_transcribe + + def _build_transcription_engine(self) -> WhisperKitPro: + """Create WhisperKitPro engine for transcription (Parakeet).""" + config = self.config + + import coremltools as ct + + compute = ct.ComputeUnit.CPU_AND_NE + engine_config = WhisperKitProConfig( + repo_id=config.transcription_repo_id, + model_variant=config.transcription_model_variant, + model_dir=config.transcription_model_path, + word_timestamps=config.transcription_word_timestamps, + chunking_strategy=config.transcription_chunking_strategy, + audio_encoder_compute_units=compute, + text_decoder_compute_units=compute, + ) + + return WhisperKitPro( + 
cli_path=config.transcription_cli_path, + transcription_config=engine_config, + ) + + def __call__(self, input_sample: BaseSample) -> PipelineOutput: + """Run pipeline and set generated audio duration.""" + self._last_generated_duration: float | None = None + parsed_input = self.parse_input(input_sample) + start_time = time.perf_counter() + output = self.pipeline(parsed_input) + end_time = time.perf_counter() + prediction_time = end_time - start_time + parsed_output = self.parse_output(output) + if parsed_output.prediction_time is None: + parsed_output.prediction_time = prediction_time + + dur = self._last_generated_duration + logger.debug(f"Generated audio duration: {dur}s") + is_sg = isinstance(input_sample, SpeechGenerationSample) + if is_sg and dur is not None: + input_sample.generated_audio_duration = dur + dur_val = input_sample.generated_audio_duration + logger.debug(f"Set sample duration to {dur_val}s") + + return parsed_output + + def parse_input( + self, input_sample: SpeechGenerationSample + ) -> ElevenLabsSpeechGenerationInput: + """Extract text prompt from the sample.""" + text = input_sample.reference.get_transcript_string() + return ElevenLabsSpeechGenerationInput( + text=text, + audio_name=input_sample.audio_name, + ) + + def parse_output(self, output: Transcript) -> SpeechGenerationOutput: + """Wrap transcription into output.""" + return SpeechGenerationOutput(prediction=output) diff --git a/src/openbench/pipeline/speech_generation/speech_generation_elevenlabs_dialogue.py b/src/openbench/pipeline/speech_generation/speech_generation_elevenlabs_dialogue.py new file mode 100644 index 0000000..dd4104a --- /dev/null +++ b/src/openbench/pipeline/speech_generation/speech_generation_elevenlabs_dialogue.py @@ -0,0 +1,448 @@ +# For licensing see accompanying LICENSE.md file. +# Copyright (C) 2025 Argmax, Inc. All Rights Reserved. + +""" +Speech generation pipeline using ElevenLabs text-to-dialogue API. + +Generates multi-speaker conversational audio from dialogue turns +via ElevenLabs, then transcribes the generated audio back to text +using WhisperKitPro (Parakeet) for WER evaluation against the +original dialogue text. + +Long dialogues that exceed the API character limit are automatically +split into chunks, generated separately, and stitched together with +configurable silence gaps between chunks. +""" + +import io +import os +import time +from pathlib import Path +from typing import Callable + +import numpy as np +import soundfile as sf +from argmaxtools.utils import get_logger +from pydantic import BaseModel, Field + +from ...dataset.dataset_base import BaseSample +from ...dataset.dataset_speech_generation import SpeechGenerationSample +from ...engine.whisperkitpro_engine import ( + WhisperKitPro, + WhisperKitProConfig, + WhisperKitProInput, +) +from ...pipeline_prediction import Transcript +from ..base import ( + Pipeline, + PipelineConfig, + PipelineOutput, + PipelineType, + register_pipeline, +) +from .common import SpeechGenerationOutput + +logger = get_logger(__name__) + +TEMP_TTS_AUDIO_DIR = Path("./temp_tts_audio") + +DEFAULT_SPEAKER_VOICE_MAP = { + "doctor": "9BWtsMINqrJLrRacOk9x", + "patient": "IKne3meq5aSn9XLyUdCD", + "assistant": "pFZP5JQG7iQjIQuC4Bku", +} + +MAX_CHARS_PER_CHUNK = 4500 + + +def _chunk_dialogue_turns( + turns: list[dict], + speaker_voice_map: dict[str, str], + default_voice_id: str, + max_chars: int = MAX_CHARS_PER_CHUNK, +) -> list[list[dict]]: + """Split dialogue turns into chunks that fit under the char limit. 
+ + Each chunk is a list of ElevenLabs input dicts ({text, voice_id}). + Splits on turn boundaries so no individual turn is broken. + """ + chunks: list[list[dict]] = [] + current_chunk: list[dict] = [] + current_chars = 0 + + for turn in turns: + speaker = turn.get("speaker", "") + voice_id = speaker_voice_map.get(speaker, default_voice_id) + entry = {"text": turn["text"], "voice_id": voice_id} + turn_chars = len(turn["text"]) + + if current_chars + turn_chars > max_chars and current_chunk: + chunks.append(current_chunk) + current_chunk = [] + current_chars = 0 + + current_chunk.append(entry) + current_chars += turn_chars + + if current_chunk: + chunks.append(current_chunk) + + return chunks + + +def _stitch_audio_files( + chunk_paths: list[Path], + output_path: Path, + silence_duration: float = 0.75, +) -> None: + """Concatenate audio files with silence gaps between them. + + Decodes each chunk, inserts silence, and writes as WAV. + """ + from pydub import AudioSegment + + combined = AudioSegment.empty() + for i, path in enumerate(chunk_paths): + segment = AudioSegment.from_file(str(path)) + if i > 0: + silence_ms = int(silence_duration * 1000) + combined += AudioSegment.silent( + duration=silence_ms, + frame_rate=segment.frame_rate, + ) + combined += segment + + combined.export(str(output_path), format="wav") + + +class ElevenLabsDialogueGenerationConfig(PipelineConfig): + """Config for the ElevenLabs dialogue generation pipeline.""" + + api_key: str | None = Field( + default=None, + description=( + "ElevenLabs API key. Falls back to " + "ELEVENLABS_API_KEY env var." + ), + ) + model_id: str = Field( + default="eleven_v3", + description="ElevenLabs model ID for dialogue.", + ) + speaker_voice_map: dict[str, str] = Field( + default_factory=lambda: dict(DEFAULT_SPEAKER_VOICE_MAP), + description=( + "Mapping of speaker names to ElevenLabs voice IDs. " + "Speakers not in this map use the default_voice_id." + ), + ) + default_voice_id: str = Field( + default="9BWtsMINqrJLrRacOk9x", + description="Fallback voice ID for unmapped speakers.", + ) + max_chars_per_chunk: int = Field( + default=MAX_CHARS_PER_CHUNK, + description=( + "Max characters per API call. Dialogues exceeding " + "this are split into multiple chunks." + ), + ) + chunk_silence_duration: float = Field( + default=0.75, + description=( + "Silence duration (seconds) inserted between " + "stitched chunks. Range 0.5-1.0 recommended." + ), + ) + + # Transcription parameters (WhisperKitPro / Parakeet) + transcription_cli_path: str = Field( + ..., + description=( + "Path to the whisperkit-cli binary " + "used for transcription." + ), + ) + transcription_repo_id: str | None = Field( + default=None, + description=( + "HuggingFace repo ID for transcription " + "model (e.g. argmaxinc/parakeetkit-pro)." + ), + ) + transcription_model_variant: str | None = Field( + default=None, + description=( + "Model variant folder within the repo " + "(e.g. nvidia_parakeet-v2_476MB)." + ), + ) + transcription_model_path: str | None = Field( + default=None, + description=( + "Local path to ASR model dir. " + "Overrides repo_id/model_variant." + ), + ) + transcription_word_timestamps: bool = Field( + default=True, + description="Include word timestamps.", + ) + transcription_chunking_strategy: str = Field( + default="vad", + description="Chunking strategy (none or vad).", + ) + + keep_generated_audio: bool = Field( + default=False, + description=( + "If True, keep the generated TTS audio " + "files instead of deleting them." 
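A worked example of the turn-boundary chunking implemented above, with a toy character limit (the real default is 4500):

```python
turns = [
    {"speaker": "doctor", "text": "How are you feeling today?"},    # 26 chars
    {"speaker": "patient", "text": "Much better than last week."},  # 27 chars
    {"speaker": "doctor", "text": "Glad to hear it."},              # 16 chars
]
voice_map = {"doctor": "voiceA", "patient": "voiceB"}  # placeholder voice IDs

chunks = _chunk_dialogue_turns(turns, voice_map, "voiceA", max_chars=60)
# -> [[turn 1, turn 2], [turn 3]]: the first two turns total 53 chars and fit
#    under the 60-char limit; adding turn 3 would reach 69, so it opens a new
#    chunk. Turns are never split mid-text.
```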
+ ), + ) + + +class ElevenLabsDialogueGenerationInput(BaseModel): + """Input for the ElevenLabs dialogue generation pipeline.""" + + text: str = Field( + ..., + description="Full concatenated dialogue text (for reference).", + ) + dialogue: list[dict] = Field( + ..., + description="List of dialogue turns with speaker and text.", + ) + audio_name: str = Field( + ..., + description=( + "Unique identifier for this sample " + "(used for temp file naming)." + ), + ) + + +@register_pipeline +class ElevenLabsDialogueGenerationPipeline(Pipeline): + """Speech generation pipeline using ElevenLabs text-to-dialogue API. + + This pipeline: + 1. Chunks dialogue turns to fit API character limits + 2. Generates audio per chunk via ElevenLabs text_to_dialogue + 3. Saves chunk audio files under temp_tts_audio/chunks/ + 4. Stitches chunks with silence gaps into final audio + 5. Transcribes audio via WhisperKitPro engine (Parakeet) + 6. Returns transcription as Transcript for WER eval + """ + + _config_class = ElevenLabsDialogueGenerationConfig + pipeline_type = PipelineType.SPEECH_GENERATION + + def build_pipeline( + self, + ) -> Callable[[ElevenLabsDialogueGenerationInput], Transcript]: + config = self.config + pipeline_ref = self + + transcription_engine = self._build_transcription_engine() + + api_key = config.api_key or os.getenv("ELEVENLABS_API_KEY") + if not api_key: + raise ValueError( + "ElevenLabs API key must be provided " + "via config or ELEVENLABS_API_KEY env var." + ) + + from elevenlabs.client import ElevenLabs + + client = ElevenLabs(api_key=api_key) + + def _generate_chunk( + chunk_inputs: list[dict], + chunk_path: Path, + ) -> Path: + """Generate audio for a single chunk of dialogue turns.""" + audio_iter = client.text_to_dialogue.convert( + inputs=chunk_inputs, + ) + with open(chunk_path, "wb") as f: + for data in audio_iter: + f.write(data) + + if not chunk_path.exists() or chunk_path.stat().st_size == 0: + raise RuntimeError( + "ElevenLabs dialogue TTS failed: " + f"chunk empty at {chunk_path}" + ) + return chunk_path + + def generate_and_transcribe( + input: ElevenLabsDialogueGenerationInput, + ) -> Transcript: + TEMP_TTS_AUDIO_DIR.mkdir(parents=True, exist_ok=True) + chunks_dir = TEMP_TTS_AUDIO_DIR / "chunks" + chunks_dir.mkdir(parents=True, exist_ok=True) + + # -- Step 1: Chunk dialogue turns -- + chunks = _chunk_dialogue_turns( + input.dialogue, + config.speaker_voice_map, + config.default_voice_id, + max_chars=config.max_chars_per_chunk, + ) + + total_turns = sum(len(c) for c in chunks) + logger.info( + f"Generating dialogue for {input.audio_name}: " + f"{total_turns} turns in {len(chunks)} chunk(s)" + ) + + # -- Step 2: Generate audio per chunk -- + chunk_paths: list[Path] = [] + for i, chunk_inputs in enumerate(chunks): + chunk_chars = sum(len(e["text"]) for e in chunk_inputs) + chunk_path = ( + chunks_dir + / f"{input.audio_name}_chunk_{i}.mp3" + ) + logger.info( + f" Chunk {i}: {len(chunk_inputs)} turns, " + f"{chunk_chars} chars -> {chunk_path.name}" + ) + _generate_chunk(chunk_inputs, chunk_path) + chunk_paths.append(chunk_path) + + # -- Step 3: Stitch or use single chunk -- + if len(chunk_paths) == 1: + audio_path = ( + TEMP_TTS_AUDIO_DIR / f"{input.audio_name}.mp3" + ) + chunk_paths[0].rename(audio_path) + else: + audio_path = ( + TEMP_TTS_AUDIO_DIR / f"{input.audio_name}.wav" + ) + _stitch_audio_files( + chunk_paths, + audio_path, + silence_duration=config.chunk_silence_duration, + ) + logger.info( + f"Stitched {len(chunk_paths)} chunks -> " + f"{audio_path.name}" + ) + + # 
-- Step 4: Read audio duration -- + try: + info = sf.info(str(audio_path)) + pipeline_ref._last_generated_duration = info.duration + except Exception as e: + logger.warning(f"Audio duration read failed: {e}") + pipeline_ref._last_generated_duration = None + + # -- Step 5: Transcribe via WhisperKitPro (Parakeet) -- + engine_input = WhisperKitProInput( + audio_path=audio_path, + keep_audio=config.keep_generated_audio, + ) + engine_output = transcription_engine(engine_input) + + # -- Step 6: Parse transcription report -- + json_path = engine_output.json_report_path + if json_path.exists(): + import json + + with json_path.open("r") as f: + data = json.load(f) + all_words, all_starts, all_ends = [], [], [] + for seg in data.get("segments", []): + for w in seg.get("words", []): + all_words.append(w["word"]) + if "start" in w: + all_starts.append(w["start"]) + if "end" in w: + all_ends.append(w["end"]) + transcript = Transcript.from_words_info( + words=all_words, + start=all_starts if all_starts else None, + end=all_ends if all_ends else None, + ) + else: + raise RuntimeError( + "Transcription report not found " + f"at {json_path}" + ) + + text_preview = transcript.get_transcript_string()[:100] + logger.info(f"Transcription: {text_preview}...") + + return transcript + + return generate_and_transcribe + + def _build_transcription_engine(self) -> WhisperKitPro: + """Create WhisperKitPro engine for transcription (Parakeet).""" + config = self.config + + import coremltools as ct + + compute = ct.ComputeUnit.CPU_AND_NE + engine_config = WhisperKitProConfig( + repo_id=config.transcription_repo_id, + model_variant=config.transcription_model_variant, + model_dir=config.transcription_model_path, + word_timestamps=config.transcription_word_timestamps, + chunking_strategy=config.transcription_chunking_strategy, + audio_encoder_compute_units=compute, + text_decoder_compute_units=compute, + ) + + return WhisperKitPro( + cli_path=config.transcription_cli_path, + transcription_config=engine_config, + ) + + def __call__(self, input_sample: BaseSample) -> PipelineOutput: + """Run pipeline and set generated audio duration.""" + self._last_generated_duration: float | None = None + parsed_input = self.parse_input(input_sample) + start_time = time.perf_counter() + output = self.pipeline(parsed_input) + end_time = time.perf_counter() + prediction_time = end_time - start_time + parsed_output = self.parse_output(output) + if parsed_output.prediction_time is None: + parsed_output.prediction_time = prediction_time + + dur = self._last_generated_duration + logger.debug(f"Generated audio duration: {dur}s") + is_sg = isinstance(input_sample, SpeechGenerationSample) + if is_sg and dur is not None: + input_sample.generated_audio_duration = dur + dur_val = input_sample.generated_audio_duration + logger.debug(f"Set sample duration to {dur_val}s") + + return parsed_output + + def parse_input( + self, input_sample: SpeechGenerationSample + ) -> ElevenLabsDialogueGenerationInput: + """Extract dialogue and text from the sample.""" + text = input_sample.reference.get_transcript_string() + dialogue = input_sample.extra_info.get("dialogue", []) + + if not dialogue: + raise ValueError( + f"Sample {input_sample.audio_name} has no dialogue data. " + "This pipeline requires a dialogue dataset." 
+ ) + + return ElevenLabsDialogueGenerationInput( + text=text, + dialogue=dialogue, + audio_name=input_sample.audio_name, + ) + + def parse_output(self, output: Transcript) -> SpeechGenerationOutput: + """Wrap transcription into output.""" + return SpeechGenerationOutput(prediction=output) diff --git a/src/openbench/pipeline/speech_generation/speech_generation_gemini.py b/src/openbench/pipeline/speech_generation/speech_generation_gemini.py new file mode 100644 index 0000000..f456c10 --- /dev/null +++ b/src/openbench/pipeline/speech_generation/speech_generation_gemini.py @@ -0,0 +1,369 @@ +# For licensing see accompanying LICENSE.md file. +# Copyright (C) 2025 Argmax, Inc. All Rights Reserved. + +""" +Speech generation pipeline using Google Gemini TTS API. + +Generates TTS audio from text prompts via Google Cloud +Text-to-Speech, then transcribes the generated audio back +to text using WhisperKitPro (Parakeet) for WER evaluation +against the original prompt. +""" + +import time +from pathlib import Path +from typing import Callable + +from argmaxtools.utils import get_logger +from pydantic import BaseModel, Field + +from ...dataset.dataset_base import BaseSample +from ...dataset.dataset_speech_generation import ( + SpeechGenerationSample, +) +from ...engine.whisperkitpro_engine import ( + WhisperKitPro, + WhisperKitProConfig, + WhisperKitProInput, +) +from ...pipeline_prediction import Transcript +from ..base import ( + Pipeline, + PipelineConfig, + PipelineOutput, + PipelineType, + register_pipeline, +) +from .common import SpeechGenerationOutput + +logger = get_logger(__name__) + +TEMP_TTS_AUDIO_DIR = Path("./temp_tts_audio") + + +class GeminiSpeechGenerationConfig(PipelineConfig): + """Config for the Gemini speech generation pipeline.""" + + # Google Cloud TTS parameters + project_id: str | None = Field( + default=None, + description=( + "Google Cloud project ID. Falls back to " + "GOOGLE_CLOUD_PROJECT env var." + ), + ) + voice_name: str = Field( + default="Charon", + description="Google Cloud TTS voice name.", + ) + language_code: str = Field( + default="en-US", + description="BCP-47 language code.", + ) + model_name: str = Field( + default="gemini-2.5-pro-tts", + description="Google Cloud TTS model name.", + ) + prompt: str | None = Field( + default=None, + description=( + "Styling instructions for how to " + "synthesize the speech." + ), + ) + audio_encoding: str = Field( + default="MP3", + description=( + "Audio encoding format " + "(MP3, LINEAR16, OGG_OPUS, MULAW, ALAW)." + ), + ) + + # Transcription parameters (WhisperKitPro / Parakeet) + transcription_cli_path: str = Field( + ..., + description=( + "Path to the whisperkit-cli binary " + "used for transcription." + ), + ) + transcription_repo_id: str | None = Field( + default=None, + description=( + "HuggingFace repo ID for transcription " + "model (e.g. argmaxinc/parakeetkit-pro)." + ), + ) + transcription_model_variant: str | None = Field( + default=None, + description=( + "Model variant folder within the repo " + "(e.g. nvidia_parakeet-v2_476MB)." + ), + ) + transcription_model_path: str | None = Field( + default=None, + description=( + "Local path to ASR model dir. " + "Overrides repo_id/model_variant." 
+ ), + ) + transcription_word_timestamps: bool = Field( + default=True, + description="Include word timestamps.", + ) + transcription_chunking_strategy: str = Field( + default="vad", + description="Chunking strategy (none or vad).", + ) + + keep_generated_audio: bool = Field( + default=False, + description=( + "If True, keep the generated TTS audio " + "files instead of deleting them." + ), + ) + + +ENCODING_TO_EXT = { + "MP3": "mp3", + "LINEAR16": "wav", + "OGG_OPUS": "ogg", + "MULAW": "wav", + "ALAW": "wav", +} + + +class GeminiSpeechGenerationInput(BaseModel): + """Input for the Gemini speech generation pipeline.""" + + text: str = Field( + ..., + description="Text prompt to generate speech from.", + ) + audio_name: str = Field( + ..., + description=( + "Unique identifier for this sample " + "(used for temp file naming)." + ), + ) + + +@register_pipeline +class GeminiSpeechGenerationPipeline(Pipeline): + """Speech generation pipeline using Google Gemini TTS. + + This pipeline: + 1. Generates audio from text via Google Cloud TTS + 2. Transcribes audio via WhisperKitPro engine (Parakeet) + 3. Returns transcription as Transcript for WER eval + 4. Cleans up temporary audio and report files + """ + + _config_class = GeminiSpeechGenerationConfig + pipeline_type = PipelineType.SPEECH_GENERATION + + def build_pipeline( + self, + ) -> Callable[[GeminiSpeechGenerationInput], Transcript]: + config = self.config + pipeline_ref = self + + transcription_engine = self._build_transcription_engine() + + from google.cloud import texttospeech + + client = texttospeech.TextToSpeechClient() + + encoding_enum = getattr( + texttospeech.AudioEncoding, + config.audio_encoding, + ) + + def generate_and_transcribe( + input: GeminiSpeechGenerationInput, + ) -> Transcript: + TEMP_TTS_AUDIO_DIR.mkdir(parents=True, exist_ok=True) + + ext = ENCODING_TO_EXT.get( + config.audio_encoding, "mp3" + ) + audio_path = ( + TEMP_TTS_AUDIO_DIR + / f"{input.audio_name}.{ext}" + ) + + # -- Step 1: Generate audio via Google Cloud TTS -- + synth_kwargs = {"text": input.text} + if config.prompt is not None: + synth_kwargs["prompt"] = config.prompt + + synthesis_input = texttospeech.SynthesisInput( + **synth_kwargs + ) + + voice = texttospeech.VoiceSelectionParams( + language_code=config.language_code, + name=config.voice_name, + model_name=config.model_name, + ) + + audio_config = texttospeech.AudioConfig( + audio_encoding=encoding_enum, + ) + + response = client.synthesize_speech( + input=synthesis_input, + voice=voice, + audio_config=audio_config, + ) + + with open(audio_path, "wb") as f: + f.write(response.audio_content) + + if ( + not audio_path.exists() + or audio_path.stat().st_size == 0 + ): + raise RuntimeError( + "Gemini TTS failed: audio file " + f"missing or empty at {audio_path}" + ) + + logger.info( + f"Generated Gemini TTS audio: {audio_path}" + ) + + # -- Step 2: Read audio duration -- + try: + import soundfile as sf + + info = sf.info(str(audio_path)) + pipeline_ref._last_generated_duration = ( + info.duration + ) + except Exception as e: + logger.warning( + f"Audio duration read failed: {e}" + ) + pipeline_ref._last_generated_duration = None + + # -- Step 3: Transcribe via WhisperKitPro -- + engine_input = WhisperKitProInput( + audio_path=audio_path, + keep_audio=config.keep_generated_audio, + ) + engine_output = transcription_engine(engine_input) + + # -- Step 4: Parse transcription report -- + json_path = engine_output.json_report_path + if json_path.exists(): + import json + + with json_path.open("r") as f: + data 
= json.load(f) + all_words, all_starts, all_ends = [], [], [] + for seg in data.get("segments", []): + for w in seg.get("words", []): + all_words.append(w["word"]) + if "start" in w: + all_starts.append(w["start"]) + if "end" in w: + all_ends.append(w["end"]) + transcript = Transcript.from_words_info( + words=all_words, + start=( + all_starts if all_starts else None + ), + end=all_ends if all_ends else None, + ) + json_path.unlink(missing_ok=True) + srt_path = engine_output.srt_report_path + if srt_path: + srt_path.unlink(missing_ok=True) + else: + raise RuntimeError( + "Transcription report not found " + f"at {json_path}" + ) + + text_preview = ( + transcript.get_transcript_string()[:100] + ) + logger.info(f"Transcription: {text_preview}...") + + return transcript + + return generate_and_transcribe + + def _build_transcription_engine(self) -> WhisperKitPro: + """Create WhisperKitPro engine for transcription.""" + config = self.config + + import coremltools as ct + + compute = ct.ComputeUnit.CPU_AND_NE + engine_config = WhisperKitProConfig( + repo_id=config.transcription_repo_id, + model_variant=config.transcription_model_variant, + model_dir=config.transcription_model_path, + word_timestamps=( + config.transcription_word_timestamps + ), + chunking_strategy=( + config.transcription_chunking_strategy + ), + audio_encoder_compute_units=compute, + text_decoder_compute_units=compute, + ) + + return WhisperKitPro( + cli_path=config.transcription_cli_path, + transcription_config=engine_config, + ) + + def __call__( + self, input_sample: BaseSample + ) -> PipelineOutput: + """Run pipeline and set generated audio duration.""" + self._last_generated_duration: float | None = None + parsed_input = self.parse_input(input_sample) + start_time = time.perf_counter() + output = self.pipeline(parsed_input) + end_time = time.perf_counter() + prediction_time = end_time - start_time + parsed_output = self.parse_output(output) + if parsed_output.prediction_time is None: + parsed_output.prediction_time = prediction_time + + dur = self._last_generated_duration + logger.debug(f"Generated audio duration: {dur}s") + is_sg = isinstance( + input_sample, SpeechGenerationSample + ) + if is_sg and dur is not None: + input_sample.generated_audio_duration = dur + dur_val = input_sample.generated_audio_duration + logger.debug( + f"Set sample duration to {dur_val}s" + ) + + return parsed_output + + def parse_input( + self, input_sample: SpeechGenerationSample + ) -> GeminiSpeechGenerationInput: + """Extract text prompt from the sample.""" + text = input_sample.reference.get_transcript_string() + return GeminiSpeechGenerationInput( + text=text, + audio_name=input_sample.audio_name, + ) + + def parse_output( + self, output: Transcript + ) -> SpeechGenerationOutput: + """Wrap transcription into output.""" + return SpeechGenerationOutput(prediction=output) diff --git a/src/openbench/pipeline/speech_generation/speech_generation_openai.py b/src/openbench/pipeline/speech_generation/speech_generation_openai.py new file mode 100644 index 0000000..4a2b08a --- /dev/null +++ b/src/openbench/pipeline/speech_generation/speech_generation_openai.py @@ -0,0 +1,345 @@ +# For licensing see accompanying LICENSE.md file. +# Copyright (C) 2025 Argmax, Inc. All Rights Reserved. + +""" +Speech generation pipeline using OpenAI TTS API. + +Generates TTS audio from text prompts via OpenAI, +then transcribes the generated audio back to text using +WhisperKitPro (Parakeet) for WER evaluation against the +original prompt. 
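+
+A minimal, illustrative configuration sketch (the
+whisperkit-cli path is a placeholder; the model and
+voice shown are this config's defaults):
+
+    config = OpenAISpeechGenerationConfig(
+        transcription_cli_path="/path/to/whisperkit-cli",
+        model="gpt-4o-mini-tts",
+        voice="coral",
+    )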
+""" + +import os +import time +from pathlib import Path +from typing import Callable + +from argmaxtools.utils import get_logger +from pydantic import BaseModel, Field + +from ...dataset.dataset_base import BaseSample +from ...dataset.dataset_speech_generation import ( + SpeechGenerationSample, +) +from ...engine.whisperkitpro_engine import ( + WhisperKitPro, + WhisperKitProConfig, + WhisperKitProInput, +) +from ...pipeline_prediction import Transcript +from ..base import ( + Pipeline, + PipelineConfig, + PipelineOutput, + PipelineType, + register_pipeline, +) +from .common import SpeechGenerationOutput + +logger = get_logger(__name__) + +TEMP_TTS_AUDIO_DIR = Path("./temp_tts_audio") + + +class OpenAISpeechGenerationConfig(PipelineConfig): + """Config for the OpenAI speech generation pipeline.""" + + # OpenAI TTS parameters + api_key: str | None = Field( + default=None, + description=( + "OpenAI API key. Falls back to " + "OPENAI_API_KEY env var." + ), + ) + model: str = Field( + default="gpt-4o-mini-tts", + description="OpenAI TTS model.", + ) + voice: str = Field( + default="coral", + description="OpenAI TTS voice.", + ) + instructions: str | None = Field( + default=None, + description="Voice style instructions.", + ) + response_format: str = Field( + default="wav", + description=( + "Audio output format " + "(wav, mp3, flac, opus, aac, pcm)." + ), + ) + speed: float = Field( + default=1.0, + description="Speech speed (0.25 to 4.0).", + ) + + # Transcription parameters (WhisperKitPro / Parakeet) + transcription_cli_path: str = Field( + ..., + description=( + "Path to the whisperkit-cli binary " + "used for transcription." + ), + ) + transcription_repo_id: str | None = Field( + default=None, + description=( + "HuggingFace repo ID for transcription " + "model (e.g. argmaxinc/parakeetkit-pro)." + ), + ) + transcription_model_variant: str | None = Field( + default=None, + description=( + "Model variant folder within the repo " + "(e.g. nvidia_parakeet-v2_476MB)." + ), + ) + transcription_model_path: str | None = Field( + default=None, + description=( + "Local path to ASR model dir. " + "Overrides repo_id/model_variant." + ), + ) + transcription_word_timestamps: bool = Field( + default=True, + description="Include word timestamps.", + ) + transcription_chunking_strategy: str = Field( + default="vad", + description="Chunking strategy (none or vad).", + ) + + keep_generated_audio: bool = Field( + default=False, + description=( + "If True, keep the generated TTS audio " + "files instead of deleting them." + ), + ) + + +class OpenAISpeechGenerationInput(BaseModel): + """Input for the OpenAI speech generation pipeline.""" + + text: str = Field( + ..., + description="Text prompt to generate speech from.", + ) + audio_name: str = Field( + ..., + description=( + "Unique identifier for this sample " + "(used for temp file naming)." + ), + ) + + +@register_pipeline +class OpenAISpeechGenerationPipeline(Pipeline): + """Speech generation pipeline using OpenAI TTS API. + + This pipeline: + 1. Generates audio from text via OpenAI text-to-speech + 2. Transcribes audio via WhisperKitPro engine (Parakeet) + 3. Returns transcription as Transcript for WER eval + 4. 
Cleans up temporary audio and report files + """ + + _config_class = OpenAISpeechGenerationConfig + pipeline_type = PipelineType.SPEECH_GENERATION + + def build_pipeline( + self, + ) -> Callable[[OpenAISpeechGenerationInput], Transcript]: + config = self.config + pipeline_ref = self + + transcription_engine = self._build_transcription_engine() + + api_key = config.api_key or os.getenv("OPENAI_API_KEY") + if not api_key: + raise ValueError( + "OpenAI API key must be provided " + "via config or OPENAI_API_KEY env var." + ) + + from openai import OpenAI + + client = OpenAI(api_key=api_key) + + def generate_and_transcribe( + input: OpenAISpeechGenerationInput, + ) -> Transcript: + TEMP_TTS_AUDIO_DIR.mkdir(parents=True, exist_ok=True) + + ext = config.response_format + if ext == "pcm": + ext = "wav" + audio_path = ( + TEMP_TTS_AUDIO_DIR / f"{input.audio_name}.{ext}" + ) + + # -- Step 1: Generate audio via OpenAI TTS -- + kwargs = { + "model": config.model, + "voice": config.voice, + "input": input.text, + "response_format": config.response_format, + "speed": config.speed, + } + if config.instructions is not None: + kwargs["instructions"] = config.instructions + + response = client.audio.speech.create(**kwargs) + response.stream_to_file(str(audio_path)) + + if ( + not audio_path.exists() + or audio_path.stat().st_size == 0 + ): + raise RuntimeError( + "OpenAI TTS failed: audio file " + f"missing or empty at {audio_path}" + ) + + logger.info( + f"Generated OpenAI TTS audio: {audio_path}" + ) + + # -- Step 2: Read audio duration -- + try: + import soundfile as sf + + info = sf.info(str(audio_path)) + pipeline_ref._last_generated_duration = ( + info.duration + ) + except Exception as e: + logger.warning( + f"Audio duration read failed: {e}" + ) + pipeline_ref._last_generated_duration = None + + # -- Step 3: Transcribe via WhisperKitPro -- + engine_input = WhisperKitProInput( + audio_path=audio_path, + keep_audio=config.keep_generated_audio, + ) + engine_output = transcription_engine(engine_input) + + # -- Step 4: Parse transcription report -- + json_path = engine_output.json_report_path + if json_path.exists(): + import json + + with json_path.open("r") as f: + data = json.load(f) + all_words, all_starts, all_ends = [], [], [] + for seg in data.get("segments", []): + for w in seg.get("words", []): + all_words.append(w["word"]) + if "start" in w: + all_starts.append(w["start"]) + if "end" in w: + all_ends.append(w["end"]) + transcript = Transcript.from_words_info( + words=all_words, + start=( + all_starts if all_starts else None + ), + end=all_ends if all_ends else None, + ) + json_path.unlink(missing_ok=True) + srt_path = engine_output.srt_report_path + if srt_path: + srt_path.unlink(missing_ok=True) + else: + raise RuntimeError( + "Transcription report not found " + f"at {json_path}" + ) + + text_preview = ( + transcript.get_transcript_string()[:100] + ) + logger.info(f"Transcription: {text_preview}...") + + return transcript + + return generate_and_transcribe + + def _build_transcription_engine(self) -> WhisperKitPro: + """Create WhisperKitPro engine for transcription.""" + config = self.config + + import coremltools as ct + + compute = ct.ComputeUnit.CPU_AND_NE + engine_config = WhisperKitProConfig( + repo_id=config.transcription_repo_id, + model_variant=config.transcription_model_variant, + model_dir=config.transcription_model_path, + word_timestamps=( + config.transcription_word_timestamps + ), + chunking_strategy=( + config.transcription_chunking_strategy + ), + 
audio_encoder_compute_units=compute, + text_decoder_compute_units=compute, + ) + + return WhisperKitPro( + cli_path=config.transcription_cli_path, + transcription_config=engine_config, + ) + + def __call__( + self, input_sample: BaseSample + ) -> PipelineOutput: + """Run pipeline and set generated audio duration.""" + self._last_generated_duration: float | None = None + parsed_input = self.parse_input(input_sample) + start_time = time.perf_counter() + output = self.pipeline(parsed_input) + end_time = time.perf_counter() + prediction_time = end_time - start_time + parsed_output = self.parse_output(output) + if parsed_output.prediction_time is None: + parsed_output.prediction_time = prediction_time + + dur = self._last_generated_duration + logger.debug(f"Generated audio duration: {dur}s") + is_sg = isinstance( + input_sample, SpeechGenerationSample + ) + if is_sg and dur is not None: + input_sample.generated_audio_duration = dur + dur_val = input_sample.generated_audio_duration + logger.debug( + f"Set sample duration to {dur_val}s" + ) + + return parsed_output + + def parse_input( + self, input_sample: SpeechGenerationSample + ) -> OpenAISpeechGenerationInput: + """Extract text prompt from the sample.""" + text = input_sample.reference.get_transcript_string() + return OpenAISpeechGenerationInput( + text=text, + audio_name=input_sample.audio_name, + ) + + def parse_output( + self, output: Transcript + ) -> SpeechGenerationOutput: + """Wrap transcription into output.""" + return SpeechGenerationOutput(prediction=output) diff --git a/src/openbench/pipeline/speech_generation/speech_generation_wkp.py b/src/openbench/pipeline/speech_generation/speech_generation_wkp.py new file mode 100644 index 0000000..9c93440 --- /dev/null +++ b/src/openbench/pipeline/speech_generation/speech_generation_wkp.py @@ -0,0 +1,257 @@ +# For licensing see accompanying LICENSE.md file. +# Copyright (C) 2025 Argmax, Inc. All Rights Reserved. + +""" +Speech generation pipeline using WhisperKit CLI. + +Generates TTS audio from text prompts, then transcribes +the generated audio back to text for WER evaluation +against the original prompt. +""" + +import json +import subprocess +import time +from pathlib import Path +from typing import Callable + +from argmaxtools.utils import get_logger +from pydantic import BaseModel, Field + +from ...dataset.dataset_base import BaseSample +from ...dataset.dataset_speech_generation import ( + SpeechGenerationSample, +) +from ...engine.whisperkitpro_engine import ( + WhisperKitPro, + WhisperKitProConfig, + WhisperKitProInput, +) +from ...pipeline_prediction import Transcript +from ..base import ( + Pipeline, + PipelineOutput, + PipelineType, + register_pipeline, +) +from .common import SpeechGenerationConfig, SpeechGenerationOutput + + +logger = get_logger(__name__) + +TEMP_TTS_AUDIO_DIR = Path("./temp_tts_audio") + + +class SpeechGenerationInput(BaseModel): + """Input for the speech generation pipeline.""" + + text: str = Field( + ..., + description="Text prompt to generate speech from.", + ) + audio_name: str = Field( + ..., + description=("Unique identifier for this sample (used for temp file naming)."), + ) + + +@register_pipeline +class WhisperKitSpeechGenerationPipeline(Pipeline): + """Speech generation pipeline using WhisperKit CLI. + + This pipeline: + 1. Generates audio from text via whisperkit-cli tts + 2. Transcribes audio via WhisperKitPro engine + 3. Returns transcription as Transcript for WER eval + 4. 
Cleans up temporary audio and report files + """ + + _config_class = SpeechGenerationConfig + pipeline_type = PipelineType.SPEECH_GENERATION + + def build_pipeline( + self, + ) -> Callable[[SpeechGenerationInput], Transcript]: + config = self.config + pipeline_ref = self + + # Build the WhisperKitPro engine for transcription + # (downloads model once, reuses for all samples) + transcription_engine = self._build_transcription_engine() + + def generate_and_transcribe( + input: SpeechGenerationInput, + ) -> Transcript: + TEMP_TTS_AUDIO_DIR.mkdir(parents=True, exist_ok=True) + + audio_path = TEMP_TTS_AUDIO_DIR / f"{input.audio_name}.wav" + + # -- Step 1: Generate audio via TTS -- + tts_cmd = [ + config.cli_path, + "tts", + "--text", + input.text, + "--speaker", + config.speaker, + "--language", + config.language, + "--output-path", + str(audio_path), + "--temperature", + str(config.temperature), + "--top-k", + str(config.top_k), + "--max-new-tokens", + str(config.max_new_tokens), + ] + + if config.seed is not None: + tts_cmd.extend(["--seed", str(config.seed)]) + if config.models_path is not None: + tts_cmd.extend(["--models-path", config.models_path]) + if config.model_repo is not None: + tts_cmd.extend(["--model-repo", config.model_repo]) + if config.version_dir is not None: + tts_cmd.extend(["--version-dir", config.version_dir]) + if config.tokenizer is not None: + tts_cmd.extend(["--tokenizer", config.tokenizer]) + + logger.debug(f"Running TTS: {' '.join(tts_cmd)}") + + tts_result = subprocess.run(tts_cmd, capture_output=True, text=True) + + if tts_result.returncode != 0: + raise RuntimeError( + "whisperkit-cli tts failed " + f"(exit {tts_result.returncode}):\n" + f" stdout: " + f"{tts_result.stdout[:500]}\n" + f" stderr: " + f"{tts_result.stderr[:500]}" + ) + + if not audio_path.exists(): + raise RuntimeError(f"TTS completed but audio file not found at {audio_path}") + + logger.info(f"Generated TTS audio: {audio_path}") + + # -- Step 2: Read audio duration before + # transcription (engine may delete file) -- + try: + import soundfile as sf + + info = sf.info(str(audio_path)) + pipeline_ref._last_generated_duration = info.duration + except Exception as e: + logger.warning(f"WAV duration read failed: {e}") + pipeline_ref._last_generated_duration = None + + # -- Step 3: Transcribe via WhisperKitPro -- + engine_input = WhisperKitProInput( + audio_path=audio_path, + keep_audio=False, + ) + engine_output = transcription_engine(engine_input) + + # -- Step 4: Parse transcription report -- + json_path = engine_output.json_report_path + if json_path.exists(): + with json_path.open("r") as f: + data = json.load(f) + all_words, all_starts, all_ends = ( + [], + [], + [], + ) + for seg in data.get("segments", []): + for w in seg.get("words", []): + all_words.append(w["word"]) + if "start" in w: + all_starts.append(w["start"]) + if "end" in w: + all_ends.append(w["end"]) + transcript = Transcript.from_words_info( + words=all_words, + start=(all_starts if all_starts else None), + end=(all_ends if all_ends else None), + ) + # Clean up report files + json_path.unlink(missing_ok=True) + srt_path = engine_output.srt_report_path + if srt_path: + srt_path.unlink(missing_ok=True) + else: + raise RuntimeError(f"Transcription report not found at {json_path}") + + logger.info("Transcription: " + transcript.get_transcript_string()[:100] + "...") + + return transcript + + return generate_and_transcribe + + def _build_transcription_engine(self) -> WhisperKitPro: + """Create WhisperKitPro engine for transcription. 
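+
+ Both the audio encoder and the text decoder are
+ pinned to CPU_AND_NE compute units below.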
+
+ Uses the same engine as the dedicated
+ WhisperKitPro transcription pipelines; the
+ engine handles model download, caching, and
+ CLI argument construction.
+ """
+ config = self.config
+ cli_path = config.transcription_cli_path or config.cli_path
+
+ import coremltools as ct
+
+ engine_config = WhisperKitProConfig(
+ repo_id=config.transcription_repo_id,
+ model_variant=config.transcription_model_variant,
+ model_dir=config.transcription_model_path,
+ word_timestamps=config.transcription_word_timestamps,
+ chunking_strategy=config.transcription_chunking_strategy,
+ audio_encoder_compute_units=ct.ComputeUnit.CPU_AND_NE,
+ text_decoder_compute_units=ct.ComputeUnit.CPU_AND_NE,
+ )
+
+ return WhisperKitPro(
+ cli_path=cli_path,
+ transcription_config=engine_config,
+ )
+
+ def __call__(self, input_sample: BaseSample) -> PipelineOutput:
+ """Run pipeline and set generated audio duration.
+
+ Overrides base __call__ to propagate the real
+ TTS audio duration back onto the sample so the
+ runner reports accurate audio_duration and
+ speed_factor values.
+ """
+ self._last_generated_duration: float | None = None
+ parsed_input = self.parse_input(input_sample)
+ start_time = time.perf_counter()
+ output = self.pipeline(parsed_input)
+ end_time = time.perf_counter()
+ prediction_time = end_time - start_time
+ parsed_output = self.parse_output(output)
+ if parsed_output.prediction_time is None:
+ parsed_output.prediction_time = prediction_time
+
+ # Propagate generated audio duration to sample
+ dur = self._last_generated_duration
+ logger.debug(f"Generated audio duration: {dur}s")
+ if isinstance(input_sample, SpeechGenerationSample) and dur is not None:
+ input_sample.generated_audio_duration = dur
+ logger.debug(f"Set sample duration to {input_sample.generated_audio_duration}s")
+
+ return parsed_output
+
+ def parse_input(self, input_sample: SpeechGenerationSample) -> SpeechGenerationInput:
+ """Extract text prompt from the sample."""
+ text = input_sample.reference.get_transcript_string()
+ return SpeechGenerationInput(
+ text=text,
+ audio_name=input_sample.audio_name,
+ )
+
+ def parse_output(self, output: Transcript) -> SpeechGenerationOutput:
+ """Wrap transcription into output."""
+ return SpeechGenerationOutput(prediction=output)
diff --git a/src/openbench/runner/benchmark.py b/src/openbench/runner/benchmark.py
index 439c15c..b195862 100644
--- a/src/openbench/runner/benchmark.py
+++ b/src/openbench/runner/benchmark.py
@@ -33,6 +33,7 @@
 PipelineType.TRANSCRIPTION: TranscriptionSampleResult,
 PipelineType.ORCHESTRATION: TranscriptionSampleResult,
 PipelineType.STREAMING_TRANSCRIPTION: TranscriptionSampleResult,
+ PipelineType.SPEECH_GENERATION: TranscriptionSampleResult,
 }
@@ -64,6 +65,7 @@ def __init__(self, config: BenchmarkConfig, pipelines: list[Pipeline]):
 PipelineType.TRANSCRIPTION: TranscriptionWandbLogger,
 PipelineType.ORCHESTRATION: TranscriptionWandbLogger,
 PipelineType.STREAMING_TRANSCRIPTION: TranscriptionWandbLogger,
+ PipelineType.SPEECH_GENERATION: TranscriptionWandbLogger,
 }

 def _get_metrics(self, pipeline: Pipeline) -> dict[str, BaseMetric]:
diff --git a/src/openbench/types.py b/src/openbench/types.py
index b4bbeaa..b82df22 100644
--- a/src/openbench/types.py
+++ b/src/openbench/types.py
@@ -12,6 +12,7 @@ class PipelineType(Enum):
 TRANSCRIPTION = "transcription"
 ORCHESTRATION = "orchestration"
 STREAMING_TRANSCRIPTION = "streaming_transcription"
+ SPEECH_GENERATION = "speech_generation"

 # All prediction classes that we output should conform to this
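
Note: the Step 4 JSON-report parsing is duplicated verbatim across all three
pipeline files above. A distilled, standalone version of that shared logic is
sketched below for reference (illustrative only, not part of the diff; the
helper name is hypothetical, but the report structure and the
Transcript.from_words_info call mirror the code above):

    import json
    from pathlib import Path

    from openbench.pipeline_prediction import Transcript

    def parse_wkp_report(json_path: Path) -> Transcript:
        """Flatten a WhisperKitPro JSON report into a Transcript."""
        with json_path.open("r") as f:
            data = json.load(f)
        all_words, all_starts, all_ends = [], [], []
        for seg in data.get("segments", []):
            for w in seg.get("words", []):
                all_words.append(w["word"])
                if "start" in w:
                    all_starts.append(w["start"])
                if "end" in w:
                    all_ends.append(w["end"])
        # Timestamps are optional in the report; pass None when absent.
        return Transcript.from_words_info(
            words=all_words,
            start=all_starts if all_starts else None,
            end=all_ends if all_ends else None,
        )

Extracting a helper like this into speech_generation/common.py would remove the
triplication and keep report-format handling in one place.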