diff --git a/backend/app/api/docs/llm/speech_to_speech.md b/backend/app/api/docs/llm/speech_to_speech.md new file mode 100644 index 000000000..e4ad03e6f --- /dev/null +++ b/backend/app/api/docs/llm/speech_to_speech.md @@ -0,0 +1,228 @@ +# Speech-to-Speech (STS) with RAG + +Execute a complete speech-to-speech workflow with knowledge base retrieval. + +## Endpoint + +``` +POST /llm/sts +``` + +## Flow + +``` +Voice Input → STT (auto language) → RAG (Knowledge Base) → TTS → Voice Output +``` + +## Input + +- **Voice note**: WhatsApp-compatible audio format (required) +- **Knowledge base IDs**: One or more knowledge bases for RAG (required) +- **Languages**: Input and output languages (optional, defaults to Hindi) +- **Models**: STT, LLM, and TTS model selection (optional, defaults to Sarvam) + +## Output + +You will receive **3 callbacks** to your webhook URL: + +1. **STT Callback** (Intermediate): Transcribed text from audio +2. **LLM Callback** (Intermediate): RAG-enhanced response text +3. **TTS Callback** (Final): Audio output + response text + +Each callback includes: +- Output from that step +- Token usage +- Latency information (check timestamps) + +## Supported Languages + +### Primary Indian Languages +- English, Hindi, Hinglish (code-switching) +- Bengali, Kannada, Malayalam, Marathi +- Odia, Punjabi, Tamil, Telugu, Gujarati + +### Additional Languages (Sarvam Saaras V3) +- Assamese, Urdu, Nepali +- Konkani, Kashmiri, Sindhi +- Sanskrit, Santali, Manipuri +- Bodo, Maithili, Dogri + +**Total: 24 languages** with automatic language detection + +## Available Models + +### STT (Speech-to-Text) +- `saaras:v3` - Sarvam Saaras V3 (**default**, fast, auto language detection, optimized for Indian languages) +- `gemini-2.5-pro` - Google Gemini 2.5 Pro + +**Note:** Sarvam STT detects the input language automatically when the input language is `auto`, so there is no need to specify it. With `gemini-2.5-pro`, `auto` falls back to `en-IN`, so pass an explicit language code for non-English audio. 
+ +### LLM (RAG) +- `gpt-4o` - OpenAI GPT-4o (**default**, best quality) +- `gpt-4o-mini` - OpenAI GPT-4o Mini (faster, lower cost) + +### TTS (Text-to-Speech) +- `bulbul:v3` - Sarvam Bulbul V3 (**default**, natural Indian voices, Opus output, WhatsApp-compatible) +- `gemini-2.5-pro-preview-tts` - Google Gemini 2.5 Pro (OGG OPUS output) + +## Edge Cases & Error Handling + +### Empty STT Output +If speech-to-text returns empty/blank: +- Chain fails immediately +- Error message: "STT returned no transcription" +- No subsequent blocks are executed + +### Audio Size Limit +WhatsApp limit: 16MB +- TTS providers may fail if output exceeds limit +- Error is caught and reported in callback +- Consider using shorter responses or compression + +### Invalid Audio Format +If input audio format is unsupported: +- STT provider fails with format error +- Error reported in callback +- Supported: MP3, WAV, OGG, OPUS, M4A + +### Provider Failures +Each block has independent error handling: +- STT fails → Chain stops, STT error reported +- LLM fails → Chain stops, RAG error reported +- TTS fails → Chain stops, TTS error reported + +## Example Request + +```bash +curl -X POST https://api.kaapi.ai/llm/sts \ + -H "Authorization: Bearer YOUR_API_KEY" \ + -H "Content-Type: application/json" \ + -d @- < 'hi-IN').""" + # Normalize input_language + if self.input_language and self.input_language != "auto": + # Normalize BCP-47: lowercase language, uppercase region (e.g., "hi-IN") + parts = self.input_language.split("-") + if len(parts) == 2: + self.input_language = f"{parts[0].lower()}-{parts[1].upper()}" + + # Normalize output_language + if self.output_language: + parts = self.output_language.split("-") + if len(parts) == 2: + self.output_language = f"{parts[0].lower()}-{parts[1].upper()}" + + return self diff --git a/backend/app/services/llm/chain/chain.py b/backend/app/services/llm/chain/chain.py index ad0503675..f0bac46f9 100644 --- a/backend/app/services/llm/chain/chain.py +++ 
b/backend/app/services/llm/chain/chain.py @@ -36,6 +36,9 @@ class ChainContext: langfuse_credentials: dict[str, Any] | None = None request_metadata: dict | None = None intermediate_callback_flags: list[bool] = field(default_factory=list) + detected_language: str | None = ( + None # Stores language detected by STT for use by TTS + ) aggregated_usage: Usage = field( default_factory=lambda: Usage( input_tokens=0, @@ -45,17 +48,37 @@ class ChainContext: ) -def result_to_query(result: BlockResult) -> QueryParams: +def result_to_query( + result: BlockResult, context: ChainContext | None = None +) -> QueryParams: """Convert a block's output into the next block's QueryParams. Text output → TextInput query Audio output → AudioInput query + + Also preserves language_code from STT output for use by downstream TTS blocks. """ output = result.response.response.output if isinstance(output, TextOutput): + # Preserve language_code if present (from STT auto-detection) + language_code = ( + output.content.language_code + if hasattr(output.content, "language_code") + else None + ) + + # Store detected language in context for TTS to use + if context and language_code: + context.detected_language = language_code + logger.info(f"[result_to_query] Detected language: {language_code}") + return QueryParams( - input=TextInput(content=TextContent(value=output.content.value)) + input=TextInput( + content=TextContent( + value=output.content.value, language_code=language_code + ) + ) ) elif isinstance(output, AudioOutput): return QueryParams(input=AudioInput(content=output.content)) @@ -96,6 +119,7 @@ def execute(self, query: QueryParams) -> BlockResult: langfuse_credentials=self._context.langfuse_credentials, include_provider_raw_response=self._include_provider_raw_response, chain_id=self._context.chain_id, + detected_language=self._context.detected_language, ) @@ -132,6 +156,6 @@ def execute( return result if block is not self._blocks[-1]: - current_query = result_to_query(result) + 
current_query = result_to_query(result, self._context) return result diff --git a/backend/app/services/llm/chain/utils.py b/backend/app/services/llm/chain/utils.py new file mode 100644 index 000000000..223530a6d --- /dev/null +++ b/backend/app/services/llm/chain/utils.py @@ -0,0 +1,185 @@ +"""Utility functions for LLM chain operations, including speech-to-speech helpers.""" + +from typing import Any, Literal + +from app.models.llm.request import ( + ChainBlock, + ConfigBlob, + KaapiCompletionConfig, + LLMCallConfig, + LLMModel, + NativeCompletionConfig, + STTModel, + TextLLMParams, + TTSModel, +) + + +# Supported BCP-47 language codes for speech-to-speech +# These are the valid values that can be used directly in API requests +SUPPORTED_LANGUAGE_CODES = { + # Auto-detect + "auto", # Auto-detection (maps to "unknown" for Sarvam) + "unknown", # Explicit unknown for Sarvam + # Primary Indian languages (BCP-47 codes) + "en-IN", # English + "hi-IN", # Hindi (also used for Hinglish/code-switching) + "bn-IN", # Bengali + "kn-IN", # Kannada + "ml-IN", # Malayalam + "mr-IN", # Marathi + "od-IN", # Odia + "pa-IN", # Punjabi + "ta-IN", # Tamil + "te-IN", # Telugu + "gu-IN", # Gujarati + # Additional languages (saaras:v3) + "as-IN", # Assamese + "ur-IN", # Urdu + "ne-IN", # Nepali + "kok-IN", # Konkani + "ks-IN", # Kashmiri + "sd-IN", # Sindhi + "sa-IN", # Sanskrit + "sat-IN", # Santali + "mni-IN", # Manipuri + "brx-IN", # Bodo + "mai-IN", # Maithili + "doi-IN", # Dogri +} + + +def build_stt_block(model: STTModel, language_code: str) -> ChainBlock: + """Build STT (Speech-to-Text) block configuration. 
+ + Args: + model: STT model enum + language_code: BCP-47 language code (e.g., "hi-IN", "en-IN") or "auto" for auto-detection + + Returns: + ChainBlock configured for STT + """ + # Map model to provider and actual model name + model_configs: dict[ + STTModel, + tuple[Literal["sarvamai-native", "google-native", "openai-native"], str], + ] = { + STTModel.SARVAM: ("sarvamai-native", "saaras:v3"), + STTModel.GEMINI_PRO: ("google-native", "gemini-2.5-pro"), + } + + provider, model_name = model_configs[model] + + # Build native config (provider-specific params) + params: dict[str, Any] = { + "model": model_name, + } + + # Add provider-specific parameters + if provider == "sarvamai-native": + # Map "auto" to "unknown" for Sarvam auto-detection + params["language_code"] = ( + "unknown" if language_code == "auto" else language_code + ) + params["mode"] = "transcribe" + elif provider == "google-native": + # Google requires specific language code, fallback to en-IN if auto/unknown + params["language_code"] = ( + "en-IN" if language_code in ("auto", "unknown") else language_code + ) + + return ChainBlock( + config=LLMCallConfig( + blob=ConfigBlob( + completion=NativeCompletionConfig( + provider=provider, + type="stt", + params=params, + ) + ) + ), + intermediate_callback=True, # Send STT result to user + include_provider_raw_response=False, + ) + + +def build_rag_block(model: LLMModel, knowledge_base_ids: list[str]) -> ChainBlock: + """Build RAG (Retrieval-Augmented Generation) block configuration. + + Args: + model: LLM model enum + knowledge_base_ids: List of knowledge base IDs for retrieval + + Returns: + ChainBlock configured for RAG + """ + return ChainBlock( + config=LLMCallConfig( + blob=ConfigBlob( + completion=KaapiCompletionConfig( + provider="openai", + type="text", + params=TextLLMParams( + model=model.value, + knowledge_base_ids=knowledge_base_ids, + temperature=0.1, + instructions="Answer the user's question using the provided knowledge base. 
Be concise and accurate.", + ).model_dump(exclude_none=True), + ) + ) + ), + intermediate_callback=True, # Send LLM result to user + include_provider_raw_response=False, + ) + + +def build_tts_block(model: TTSModel, language_code: str = "en-IN") -> ChainBlock: + """Build TTS (Text-to-Speech) block configuration. + + Args: + model: TTS model enum + language_code: ISO language code (e.g., "hi-IN"), or "{{detected}}" to use language detected by STT + + Returns: + ChainBlock configured for TTS + """ + # Map model to provider and actual model name + voice + model_configs: dict[ + TTSModel, + tuple[Literal["sarvamai-native", "google-native", "openai-native"], str, str], + ] = { + TTSModel.SARVAM: ("sarvamai-native", "bulbul:v3", "simran"), + TTSModel.GEMINI_PRO: ("google-native", "gemini-2.5-pro", "default"), + } + + provider, model_name, voice = model_configs[model] + + # Build native config + params: dict[str, Any] = { + "model": model_name, + "voice": voice, + } + + # Add provider-specific parameters + if provider == "sarvamai-native": + # Use language_code (can be "{{detected}}" marker or actual code) + params["target_language_code"] = language_code + params["speaker"] = voice + params["output_audio_codec"] = "opus" # WhatsApp compatible + elif provider == "google-native": + params["language_code"] = language_code + params["audio_encoding"] = "OGG_OPUS" # WhatsApp compatible + + return ChainBlock( + config=LLMCallConfig( + blob=ConfigBlob( + completion=NativeCompletionConfig( + provider=provider, + type="tts", + params=params, + ) + ) + ), + intermediate_callback=False, # Final result only + include_provider_raw_response=False, + ) diff --git a/backend/app/services/llm/jobs.py b/backend/app/services/llm/jobs.py index 2a5f7dee2..621c8d912 100644 --- a/backend/app/services/llm/jobs.py +++ b/backend/app/services/llm/jobs.py @@ -23,6 +23,7 @@ ImageInput, KaapiCompletionConfig, LLMCallConfig, + NativeCompletionConfig, PDFInput, QueryParams, TextInput, @@ -335,10 +336,14 
@@ def execute_llm_call( langfuse_credentials: dict | None, include_provider_raw_response: bool = False, chain_id: UUID | None = None, + detected_language: str | None = None, ) -> BlockResult: """Execute a single LLM call. Shared by /llm/call and /llm/chain. Returns BlockResult with response + usage on success, or error on failure. + + Args: + detected_language: Language code detected by STT (used to replace {{detected}} marker in TTS) """ config_blob: ConfigBlob | None = None @@ -382,6 +387,27 @@ def execute_llm_call( request_metadata = {} request_metadata.setdefault("warnings", []).extend(warnings) + # Replace {{detected}} marker in TTS configs with actual detected language + if ( + isinstance(completion_config, NativeCompletionConfig) + and completion_config.type == "tts" + ): + params = completion_config.params + # Replace {{detected}} marker in any language-related params + for key in ["target_language_code", "language_code"]: + if key in params and params[key] == "{{detected}}": + if detected_language: + params[key] = detected_language + logger.info( + f"[execute_llm_call] Using detected language for TTS: {detected_language} | job_id={job_id}" + ) + else: + # Fallback to English if no language was detected + params[key] = "en-IN" + logger.warning( + f"[execute_llm_call] No language detected, falling back to en-IN for TTS | job_id={job_id}" + ) + resolved_config_blob = ConfigBlob( completion=completion_config, prompt_template=config_blob.prompt_template, diff --git a/backend/app/services/llm/providers/sai.py b/backend/app/services/llm/providers/sai.py index c2984e6aa..4f4170d7b 100644 --- a/backend/app/services/llm/providers/sai.py +++ b/backend/app/services/llm/providers/sai.py @@ -111,7 +111,10 @@ def _execute_stt( provider=provider_name, model=model, output=TextOutput( - content=TextContent(value=sarvam_response.transcript) + content=TextContent( + value=sarvam_response.transcript, + language_code=sarvam_response.language_code, + ) ), ), usage=Usage( @@ 
-184,6 +187,7 @@ def _execute_tts( target_language_code=target_language_code, model=model, speaker=speaker, + speech_sample_rate=16000, output_audio_codec=output_audio_codec, ) diff --git a/backend/app/tests/services/llm/test_sts.py b/backend/app/tests/services/llm/test_sts.py new file mode 100644 index 000000000..7822bdc00 --- /dev/null +++ b/backend/app/tests/services/llm/test_sts.py @@ -0,0 +1,572 @@ +""" +Test cases for Speech-to-Speech (STS) functionality. + +Tests cover: +1. Language detection and propagation through STT → RAG → TTS chain +2. BCP-47 language code validation +3. Real-world use cases (auto-detection, explicit languages, cross-language) +""" + +from unittest.mock import patch, MagicMock +from uuid import uuid4 + +import pytest +from fastapi.testclient import TestClient + +from app.models.llm.request import ( + AudioContent, + AudioInput, + LLMModel, + STTModel, + SpeechToSpeechRequest, + TTSModel, +) +from app.models.llm.response import ( + LLMCallResponse, + LLMResponse, + TextOutput, + TextContent as ResponseTextContent, + Usage, +) +from app.services.llm.chain.chain import ChainContext, result_to_query +from app.services.llm.chain.types import BlockResult +from app.services.llm.chain.utils import ( + SUPPORTED_LANGUAGE_CODES, + build_stt_block, + build_tts_block, +) + + +# ============================================================================ +# Unit Tests: Language Detection Flow +# ============================================================================ + + +class TestLanguageDetectionFlow: + """Test language detection and propagation through the chain.""" + + def test_result_to_query_preserves_language_code(self): + """STT output with language_code should be preserved when converting to next block's input.""" + # Simulate STT response with detected Hindi + stt_response = LLMCallResponse( + response=LLMResponse( + provider_response_id="stt-resp-1", + conversation_id=None, + model="saaras:v3", + provider="sarvamai-native", + 
output=TextOutput( + content=ResponseTextContent( + value="नमस्ते, आप कैसे हैं?", language_code="hi-IN" + ) + ), + ), + usage=Usage(input_tokens=0, output_tokens=10, total_tokens=10), + ) + + result = BlockResult(response=stt_response, usage=stt_response.usage) + context = ChainContext( + job_id=uuid4(), + chain_id=uuid4(), + project_id=1, + organization_id=1, + callback_url=None, + total_blocks=3, + ) + + # Convert STT output to RAG input + query = result_to_query(result, context) + + # Language code should be preserved + assert query.input.content.language_code == "hi-IN" + assert query.input.content.value == "नमस्ते, आप कैसे हैं?" + + # Context should store detected language for TTS + assert context.detected_language == "hi-IN" + + def test_result_to_query_without_language_code(self): + """RAG output without language_code should not break the chain.""" + # Simulate RAG response (no language_code) + rag_response = LLMCallResponse( + response=LLMResponse( + provider_response_id="rag-resp-1", + conversation_id=None, + model="gpt-4o", + provider="openai", + output=TextOutput( + content=ResponseTextContent( + value="The capital of India is New Delhi." + ) + ), + ), + usage=Usage(input_tokens=50, output_tokens=12, total_tokens=62), + ) + + result = BlockResult(response=rag_response, usage=rag_response.usage) + context = ChainContext( + job_id=uuid4(), + chain_id=uuid4(), + project_id=1, + organization_id=1, + callback_url=None, + total_blocks=3, + detected_language="hi-IN", # From previous STT block + ) + + # Convert RAG output to TTS input + query = result_to_query(result, context) + + # Should work fine even without language_code + assert query.input.content.value == "The capital of India is New Delhi." 
+ # Context should retain previously detected language + assert context.detected_language == "hi-IN" + + def test_detected_marker_replacement(self): + """{{detected}} marker in TTS should be replaced with actual detected language.""" + from app.services.llm.jobs import execute_llm_call + from app.models.llm.request import ( + LLMCallConfig, + ConfigBlob, + NativeCompletionConfig, + QueryParams, + ) + + config = LLMCallConfig( + blob=ConfigBlob( + completion=NativeCompletionConfig( + provider="sarvamai-native", + type="tts", + params={ + "model": "bulbul:v3", + "voice": "simran", + "target_language_code": "{{detected}}", # Marker to be replaced + "speaker": "simran", + "output_audio_codec": "opus", + }, + ) + ) + ) + + with patch("app.services.llm.jobs.get_llm_provider") as mock_provider, patch( + "app.services.llm.jobs.Session" + ): + mock_provider_instance = MagicMock() + mock_provider.return_value = mock_provider_instance + mock_provider_instance.execute.return_value = (None, "test error") + + # Call with detected_language + execute_llm_call( + config=config, + query=QueryParams(input="Test text"), + job_id=uuid4(), + project_id=1, + organization_id=1, + request_metadata=None, + langfuse_credentials=None, + detected_language="ta-IN", # Detected Tamil + ) + + # Verify {{detected}} was replaced with ta-IN + # The marker replacement happens in execute_llm_call before provider.execute is called + # So we check the modified config params + call_args = mock_provider_instance.execute.call_args + # execute is called with (completion_config, query, resolved_input, include_provider_raw_response) + if call_args: + completion_config = ( + call_args[1]["completion_config"] + if len(call_args) > 1 and "completion_config" in call_args[1] + else call_args[0][0] + ) + assert completion_config.params["target_language_code"] == "ta-IN" + + +# ============================================================================ +# Unit Tests: Block Building +# 
============================================================================ + + +class TestSTSBlockBuilding: + """Test STT and TTS block configuration.""" + + def test_build_stt_block_with_auto(self): + """Auto language should map to 'unknown' for Sarvam.""" + block = build_stt_block(STTModel.SARVAM, "auto") + + params = block.config.blob.completion.params + assert params["language_code"] == "unknown" + assert params["model"] == "saaras:v3" + assert params["mode"] == "transcribe" + + def test_build_stt_block_with_specific_language(self): + """Specific BCP-47 code should be used as-is.""" + block = build_stt_block(STTModel.SARVAM, "hi-IN") + + params = block.config.blob.completion.params + assert params["language_code"] == "hi-IN" + + def test_build_tts_block_with_detected_marker(self): + """TTS should accept {{detected}} marker for dynamic language.""" + block = build_tts_block(TTSModel.SARVAM, "{{detected}}") + + params = block.config.blob.completion.params + assert params["target_language_code"] == "{{detected}}" + assert params["model"] == "bulbul:v3" + + def test_build_tts_block_with_specific_language(self): + """TTS should accept specific BCP-47 codes.""" + block = build_tts_block(TTSModel.SARVAM, "ta-IN") + + params = block.config.blob.completion.params + assert params["target_language_code"] == "ta-IN" + + +# ============================================================================ +# Integration Tests: Speech-to-Speech Endpoint +# ============================================================================ + + +@pytest.fixture +def mock_audio_input(): + """Sample audio input (base64 encoded).""" + return AudioInput( + type="audio", + content=AudioContent( + format="base64", + value="SUQzBAAAAAAAI1RTU0UAAAAPAAADTGF2ZjU4Ljc2LjEwMAAAAAAAAAAAAAAA//...", + mime_type="audio/ogg", + ), + ) + + +@pytest.fixture +def knowledge_base_ids(): + """Sample knowledge base IDs.""" + return ["kb-india-facts", "kb-general-knowledge"] + + +class TestSpeechToSpeechEndpoint: 
+ """Test the /llm/sts endpoint with realistic scenarios.""" + + def test_sts_auto_detection_hindi_to_hindi( + self, + client: TestClient, + user_api_key_header: dict[str, str], + mock_audio_input, + knowledge_base_ids, + ): + """ + Real-world scenario: User sends Hindi voice note, expects Hindi response. + Most common use case - auto-detect input, same language output. + """ + with patch("app.api.routes.llm_sts.start_chain_job") as mock_start_job: + payload = SpeechToSpeechRequest( + query=mock_audio_input, + knowledge_base_ids=knowledge_base_ids, + input_language="auto", # Auto-detect + output_language=None, # Should default to detected language + stt_model=STTModel.SARVAM, + tts_model=TTSModel.SARVAM, + llm_model=LLMModel.GPT4O, + callback_url="https://example.com/callback", + ) + + response = client.post( + "api/v1/llm/sts", + json=payload.model_dump(mode="json"), + headers=user_api_key_header, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert "Speech-to-speech processing initiated" in data["data"]["message"] + + # Verify job was started + mock_start_job.assert_called_once() + + def test_sts_explicit_tamil_to_tamil( + self, + client: TestClient, + user_api_key_header: dict[str, str], + mock_audio_input, + knowledge_base_ids, + ): + """ + Scenario: Tamil user explicitly sets language to avoid auto-detection. + Use case: Better accuracy when language is known. 
+ """ + with patch("app.api.routes.llm_sts.start_chain_job") as mock_start_job: + payload = SpeechToSpeechRequest( + query=mock_audio_input, + knowledge_base_ids=knowledge_base_ids, + input_language="ta-IN", + output_language="ta-IN", + stt_model=STTModel.SARVAM, + tts_model=TTSModel.SARVAM, + llm_model=LLMModel.GPT4O_MINI, + callback_url="https://example.com/callback", + ) + + response = client.post( + "api/v1/llm/sts", + json=payload.model_dump(mode="json"), + headers=user_api_key_header, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + mock_start_job.assert_called_once() + + def test_sts_cross_language_hindi_to_english( + self, + client: TestClient, + user_api_key_header: dict[str, str], + mock_audio_input, + knowledge_base_ids, + ): + """ + Scenario: User speaks Hindi but wants response in English. + Use case: Language learning, multilingual support. + """ + with patch("app.api.routes.llm_sts.start_chain_job") as mock_start_job: + payload = SpeechToSpeechRequest( + query=mock_audio_input, + knowledge_base_ids=knowledge_base_ids, + input_language="hi-IN", + output_language="en-IN", # Respond in English + stt_model=STTModel.SARVAM, + tts_model=TTSModel.SARVAM, + llm_model=LLMModel.GPT4O, + callback_url="https://example.com/callback", + ) + + response = client.post( + "api/v1/llm/sts", + json=payload.model_dump(mode="json"), + headers=user_api_key_header, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + mock_start_job.assert_called_once() + + def test_sts_invalid_input_language_code( + self, + client: TestClient, + user_api_key_header: dict[str, str], + mock_audio_input, + knowledge_base_ids, + ): + """ + Error case: User provides invalid BCP-47 code. + Should reject with clear error message. 
+ """ + payload = SpeechToSpeechRequest( + query=mock_audio_input, + knowledge_base_ids=knowledge_base_ids, + input_language="hindi", # Invalid - should be 'hi-IN' + output_language="en-IN", + stt_model=STTModel.SARVAM, + tts_model=TTSModel.SARVAM, + llm_model=LLMModel.GPT4O, + ) + + response = client.post( + "api/v1/llm/sts", + json=payload.model_dump(mode="json"), + headers=user_api_key_header, + ) + + assert response.status_code == 200 # API returns 200 with error in body + data = response.json() + assert data["success"] is False + assert "Unsupported input language code" in data["error"] + assert "hindi" in data["error"] + + def test_sts_invalid_output_language_code( + self, + client: TestClient, + user_api_key_header: dict[str, str], + mock_audio_input, + knowledge_base_ids, + ): + """ + Error case: Invalid output language code. + """ + payload = SpeechToSpeechRequest( + query=mock_audio_input, + knowledge_base_ids=knowledge_base_ids, + input_language="hi-IN", + output_language="french", # Invalid - should be BCP-47 + stt_model=STTModel.SARVAM, + tts_model=TTSModel.SARVAM, + llm_model=LLMModel.GPT4O, + ) + + response = client.post( + "api/v1/llm/sts", + json=payload.model_dump(mode="json"), + headers=user_api_key_header, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is False + assert "Unsupported output language code" in data["error"] + + def test_sts_case_insensitive_language_codes( + self, + client: TestClient, + user_api_key_header: dict[str, str], + mock_audio_input, + knowledge_base_ids, + ): + """ + User-friendly case: BCP-47 codes should be case-insensitive. + 'hi-in' should be normalized to 'hi-IN'. 
+ """ + with patch("app.api.routes.llm_sts.start_chain_job") as mock_start_job: + payload = SpeechToSpeechRequest( + query=mock_audio_input, + knowledge_base_ids=knowledge_base_ids, + input_language="hi-in", # Lowercase + output_language="en-in", # Lowercase + stt_model=STTModel.SARVAM, + tts_model=TTSModel.SARVAM, + llm_model=LLMModel.GPT4O, + ) + + response = client.post( + "api/v1/llm/sts", + json=payload.model_dump(mode="json"), + headers=user_api_key_header, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + mock_start_job.assert_called_once() + + def test_sts_regional_languages( + self, + client: TestClient, + user_api_key_header: dict[str, str], + mock_audio_input, + knowledge_base_ids, + ): + """ + Test support for regional Indian languages. + Scenario: Malayalam speaker from Kerala. + """ + with patch("app.api.routes.llm_sts.start_chain_job") as mock_start_job: + payload = SpeechToSpeechRequest( + query=mock_audio_input, + knowledge_base_ids=knowledge_base_ids, + input_language="ml-IN", # Malayalam + output_language="ml-IN", + stt_model=STTModel.SARVAM, + tts_model=TTSModel.SARVAM, + llm_model=LLMModel.GPT4O_MINI, + ) + + response = client.post( + "api/v1/llm/sts", + json=payload.model_dump(mode="json"), + headers=user_api_key_header, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + mock_start_job.assert_called_once() + + def test_sts_without_callback_url( + self, + client: TestClient, + user_api_key_header: dict[str, str], + mock_audio_input, + knowledge_base_ids, + ): + """ + Callback URL is optional - job should still start. 
+ """ + with patch("app.api.routes.llm_sts.start_chain_job") as mock_start_job: + payload = SpeechToSpeechRequest( + query=mock_audio_input, + knowledge_base_ids=knowledge_base_ids, + input_language="auto", + stt_model=STTModel.SARVAM, + tts_model=TTSModel.SARVAM, + llm_model=LLMModel.GPT4O, + # No callback_url + ) + + response = client.post( + "api/v1/llm/sts", + json=payload.model_dump(mode="json"), + headers=user_api_key_header, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + mock_start_job.assert_called_once() + + +# ============================================================================ +# Unit Tests: Language Code Validation +# ============================================================================ + + +class TestLanguageCodeSupport: + """Verify all supported BCP-47 codes are valid.""" + + def test_all_supported_codes_are_valid(self): + """All codes in SUPPORTED_LANGUAGE_CODES should be valid BCP-47 format.""" + valid_codes = { + "auto", + "unknown", + "en-IN", + "hi-IN", + "bn-IN", + "kn-IN", + "ml-IN", + "mr-IN", + "od-IN", + "pa-IN", + "ta-IN", + "te-IN", + "gu-IN", + "as-IN", + "ur-IN", + "ne-IN", + "kok-IN", + "ks-IN", + "sd-IN", + "sa-IN", + "sat-IN", + "mni-IN", + "brx-IN", + "mai-IN", + "doi-IN", + } + + assert SUPPORTED_LANGUAGE_CODES == valid_codes + + def test_major_indian_languages_supported(self): + """Verify major Indian languages are supported.""" + major_languages = { + "hi-IN", # Hindi + "bn-IN", # Bengali + "te-IN", # Telugu + "mr-IN", # Marathi + "ta-IN", # Tamil + "ur-IN", # Urdu + "gu-IN", # Gujarati + "kn-IN", # Kannada + "ml-IN", # Malayalam + "pa-IN", # Punjabi + } + + assert major_languages.issubset(SUPPORTED_LANGUAGE_CODES) diff --git a/backend/test_sts_debug.py b/backend/test_sts_debug.py new file mode 100644 index 000000000..f6dd92c10 --- /dev/null +++ b/backend/test_sts_debug.py @@ -0,0 +1,193 @@ +"""Debug script for STS endpoint and chain job execution.""" + +import 
logging +import sys +from sqlmodel import Session + +# Setup logging +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + + +def test_chain_job_creation(): + """Test if chain job can be created and queued.""" + from app.core.db import engine + from app.models.llm.request import ( + LLMChainRequest, + QueryParams, + AudioInput, + AudioContent, + ChainBlock, + LLMCallConfig, + ConfigBlob, + NativeCompletionConfig, + ) + from app.services.llm.jobs import start_chain_job + + print("\n" + "=" * 80) + print("STEP 1: Creating test chain request") + print("=" * 80) + + # Create a minimal valid chain request + test_request = LLMChainRequest( + query=QueryParams( + input=AudioInput( + type="audio", + content=AudioContent( + format="base64", + value="dGVzdF9hdWRpbw==", # base64 encoded "test_audio" + mime_type="audio/ogg", + ), + ) + ), + blocks=[ + ChainBlock( + config=LLMCallConfig( + blob=ConfigBlob( + completion=NativeCompletionConfig( + provider="sarvamai-native", + type="stt", + params={ + "model": "saarika:v1", + "language_code": "unknown", + "mode": "transcription", + }, + ) + ) + ), + intermediate_callback=True, + ) + ], + ) + + print(f"✅ Test request created with {len(test_request.blocks)} block(s)") + + print("\n" + "=" * 80) + print("STEP 2: Attempting to start chain job") + print("=" * 80) + + try: + with Session(engine) as session: + job_id = start_chain_job( + db=session, + request=test_request, + project_id=1, # Use test project ID + organization_id=1, # Use test org ID + ) + print(f"✅ Chain job created successfully!") + print(f" Job ID: {job_id}") + print(f" Check your Celery worker logs for task execution") + return job_id + except Exception as e: + print(f"❌ Failed to create chain job: {e}") + import traceback + + traceback.print_exc() + return None + + +def check_celery_connection(): + """Check if Celery is running and can receive tasks.""" + print("\n" + "=" * 
80) + print("STEP 3: Checking Celery connection") + print("=" * 80) + + try: + from app.celery.celery_app import celery_app + + # Check if broker is reachable + inspector = celery_app.control.inspect() + active_workers = inspector.active() + + if active_workers: + print(f"✅ Celery workers are running:") + for worker_name, tasks in active_workers.items(): + print(f" - {worker_name}: {len(tasks)} active tasks") + else: + print("⚠️ No active Celery workers found!") + print(" Make sure to start the Celery worker with:") + print(" celery -A app.celery.celery_app worker --loglevel=info") + + # Check registered tasks + registered = inspector.registered() + if registered: + print(f"\n✅ Registered tasks:") + for worker_name, tasks in registered.items(): + print(f" Worker: {worker_name}") + for task in sorted(tasks): + if "high_priority" in task or "chain" in task.lower(): + print(f" - {task}") + + except Exception as e: + print(f"❌ Failed to check Celery: {e}") + import traceback + + traceback.print_exc() + + +def check_function_import(): + """Verify execute_chain_job can be imported.""" + print("\n" + "=" * 80) + print("STEP 4: Verifying execute_chain_job import") + print("=" * 80) + + try: + from app.services.llm.jobs import execute_chain_job + + print(f"✅ execute_chain_job is importable") + print(f" Parameters: {execute_chain_job.__code__.co_varnames[:6]}") + + # Try dynamic import (same way Celery does it) + import importlib + + module = importlib.import_module("app.services.llm.jobs") + func = getattr(module, "execute_chain_job") + print(f"✅ Dynamic import successful (same as Celery)") + + except Exception as e: + print(f"❌ Import failed: {e}") + import traceback + + traceback.print_exc() + + +if __name__ == "__main__": + print("\n" + "=" * 80) + print("STS ENDPOINT DEBUG SCRIPT") + print("=" * 80) + + check_function_import() + check_celery_connection() + job_id = test_chain_job_creation() + + if job_id: + print("\n" + "=" * 80) + print("DEBUGGING SUMMARY") + print("=" 
* 80) + print(f"✅ Chain job was queued successfully: {job_id}") + print(f"\nNext steps:") + print(f"1. Check your Celery worker logs for:") + print( + f" - Task app.celery.tasks.job_execution.execute_high_priority_task received" + ) + print(f" - Executing high_priority job {job_id}") + print(f" - Function path: app.services.llm.jobs.execute_chain_job") + print(f"\n2. If you don't see the task in worker logs:") + print(f" - Verify Celery broker (RabbitMQ/Redis) is running") + print(f" - Check broker connection in Celery worker startup logs") + print(f" - Restart Celery worker") + print(f"\n3. If task starts but fails:") + print(f" - Look for error in Celery worker logs") + print( + f" - Check database for job status: SELECT * FROM job WHERE id = '{job_id}';" + ) + else: + print("\n" + "=" * 80) + print("DEBUGGING SUMMARY") + print("=" * 80) + print(f"❌ Failed to queue chain job") + print(f" Check the error messages above for details") + + print("=" * 80 + "\n")