From 9ac86eac750fa335424cead720952a21717432ad Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Fri, 20 Feb 2026 23:27:32 +0530 Subject: [PATCH 01/15] LLM Chain: Add foundation for chain execution with database schema --- .../versions/048_create_llm_chain_table.py | 181 +++++ backend/app/api/main.py | 2 + backend/app/api/routes/llm_chain.py | 62 ++ backend/app/crud/llm.py | 2 + backend/app/crud/llm_chain.py | 151 ++++ backend/app/models/__init__.py | 3 + backend/app/models/job.py | 1 + backend/app/models/llm/__init__.py | 9 + backend/app/models/llm/request.py | 231 ++++++ backend/app/models/llm/response.py | 39 + backend/app/services/llm/chain/__init__.py | 0 backend/app/services/llm/chain/chain.py | 221 ++++++ backend/app/services/llm/chain/executor.py | 197 +++++ backend/app/services/llm/chain/types.py | 18 + backend/app/services/llm/jobs.py | 700 ++++++++++++------ 15 files changed, 1570 insertions(+), 247 deletions(-) create mode 100644 backend/app/alembic/versions/048_create_llm_chain_table.py create mode 100644 backend/app/api/routes/llm_chain.py create mode 100644 backend/app/crud/llm_chain.py create mode 100644 backend/app/services/llm/chain/__init__.py create mode 100644 backend/app/services/llm/chain/chain.py create mode 100644 backend/app/services/llm/chain/executor.py create mode 100644 backend/app/services/llm/chain/types.py diff --git a/backend/app/alembic/versions/048_create_llm_chain_table.py b/backend/app/alembic/versions/048_create_llm_chain_table.py new file mode 100644 index 000000000..ac49eb0ec --- /dev/null +++ b/backend/app/alembic/versions/048_create_llm_chain_table.py @@ -0,0 +1,181 @@ +"""Create llm_chain table + +Revision ID: 048 +Revises: 047 +Create Date: 2026-02-20 00:00:00.000000 + +""" + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects.postgresql import JSONB + +revision = "048" +down_revision = "047" +branch_labels = None +depends_on = None + + +def 
upgrade() -> None: + # 1. Create llm_chain table + op.create_table( + "llm_chain", + sa.Column( + "id", + sa.Uuid(), + nullable=False, + comment="Unique identifier for the LLM chain record", + ), + sa.Column( + "job_id", + sa.Uuid(), + nullable=False, + comment="Reference to the parent job (status tracked in job table)", + ), + sa.Column( + "project_id", + sa.Integer(), + nullable=False, + comment="Reference to the project this LLM call belongs to", + ), + sa.Column( + "organization_id", + sa.Integer(), + nullable=False, + comment="Reference to the organization this LLM call belongs to", + ), + sa.Column( + "status", + sa.String(), + nullable=False, + server_default="pending", + comment="Chain execution status (pending, running, failed, completed)", + ), + sa.Column( + "error", + sa.Text(), + nullable=True, + comment="Error message if the chain execution failed", + ), + sa.Column( + "block_sequences", + JSONB(), + nullable=True, + comment="Ordered list of llm_call UUIDs as blocks complete", + ), + sa.Column( + "total_blocks", + sa.Integer(), + nullable=False, + comment="Total number of blocks to execute", + ), + sa.Column( + "number_of_blocks_processed", + sa.Integer(), + nullable=False, + server_default="0", + comment="Number of blocks processed so far (used for tracking progress)", + ), + sa.Column( + "input", + sa.String(), + nullable=False, + comment="First block user's input - text string, binary data, or file path for multimodal", + ), + sa.Column( + "output", + JSONB(), + nullable=True, + comment="Last block's final output (set on chain completion)", + ), + sa.Column( + "configs", + JSONB(), + nullable=True, + comment="Ordered list of block configs as submitted in the request", + ), + sa.Column( + "total_usage", + JSONB(), + nullable=True, + comment="Aggregated token usage: {input_tokens, output_tokens, total_tokens}", + ), + sa.Column( + "metadata", + JSONB(), + nullable=True, + comment="Future-proof extensibility catch-all", + ), + sa.Column( + 
"started_at", + sa.DateTime(), + nullable=True, + comment="Timestamp when chain execution started", + ), + sa.Column( + "completed_at", + sa.DateTime(), + nullable=True, + comment="Timestamp when chain execution completed", + ), + sa.Column( + "created_at", + sa.DateTime(), + nullable=False, + comment="Timestamp when the chain record was created", + ), + sa.Column( + "updated_at", + sa.DateTime(), + nullable=False, + comment="Timestamp when the chain record was last updated", + ), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["job_id"], ["job.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint(["project_id"], ["project.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint( + ["organization_id"], ["organization.id"], ondelete="CASCADE" + ), + ) + + op.create_index( + "idx_llm_chain_job_id", + "llm_chain", + ["job_id"], + ) + + # 2. Add chain_id FK column to llm_call table + op.add_column( + "llm_call", + sa.Column( + "chain_id", + sa.Uuid(), + nullable=True, + comment="Reference to the parent chain (NULL for standalone /llm/call requests)", + ), + ) + op.create_foreign_key( + "fk_llm_call_chain_id", + "llm_call", + "llm_chain", + ["chain_id"], + ["id"], + ondelete="SET NULL", + ) + op.create_index( + "idx_llm_call_chain_id", + "llm_call", + ["chain_id"], + postgresql_where=sa.text("chain_id IS NOT NULL"), + ) + + op.execute("ALTER TYPE jobtype ADD VALUE IF NOT EXISTS 'LLM_CHAIN'") + + +def downgrade() -> None: + op.drop_index("idx_llm_call_chain_id", table_name="llm_call") + op.drop_constraint("fk_llm_call_chain_id", "llm_call", type_="foreignkey") + op.drop_column("llm_call", "chain_id") + + op.drop_index("idx_llm_chain_job_id", table_name="llm_chain") + op.drop_table("llm_chain") diff --git a/backend/app/api/main.py b/backend/app/api/main.py index ed58e57f2..5ab1cbd9e 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -10,6 +10,7 @@ login, languages, llm, + llm_chain, organization, openai_conversation, project, @@ -41,6 +42,7 @@ 
api_router.include_router(evaluations.router) api_router.include_router(languages.router) api_router.include_router(llm.router) +api_router.include_router(llm_chain.router) api_router.include_router(login.router) api_router.include_router(onboarding.router) api_router.include_router(openai_conversation.router) diff --git a/backend/app/api/routes/llm_chain.py b/backend/app/api/routes/llm_chain.py new file mode 100644 index 000000000..0634c2038 --- /dev/null +++ b/backend/app/api/routes/llm_chain.py @@ -0,0 +1,62 @@ +import logging + +from fastapi import APIRouter, Depends +from app.api.deps import AuthContextDep, SessionDep +from app.api.permissions import Permission, require_permission +from app.models import LLMChainRequest, LLMChainResponse, Message +from app.services.llm.jobs import start_chain_job +from app.utils import APIResponse, validate_callback_url, load_description + +logger = logging.getLogger(__name__) + +router = APIRouter(tags=["LLM Chain"]) +llm_callback_router = APIRouter() + + +@llm_callback_router.post( + "{$callback_url}", + name="llm_chain_callback", +) +def llm_callback_notification(body: APIResponse[LLMChainResponse]): + """ + Callback endpoint specification for LLM chain completion. + + The callback will receive: + - On success: APIResponse with success=True and data containing LLMChainResponse + - On failure: APIResponse with success=False and error message + - metadata field will always be included if provided in the request + """ + ... + + +@router.post( + "/llm/chain", + description=load_description("llm/llm_call.md"), + response_model=APIResponse[Message], + callbacks=llm_callback_router.routes, + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], +) +def llm_chain( + _current_user: AuthContextDep, _session: SessionDep, request: LLMChainRequest +): + """ + Endpoint to initiate an LLM chain as a background job. 
+ """ + project_id = _current_user.project_.id + organization_id = _current_user.organization_.id + + if request.callback_url: + validate_callback_url(str(request.callback_url)) + + start_chain_job( + db=_session, + request=request, + project_id=project_id, + organization_id=organization_id, + ) + + return APIResponse.success_response( + data=Message( + message="Your response is being generated and will be delivered via callback." + ), + ) diff --git a/backend/app/crud/llm.py b/backend/app/crud/llm.py index b5c23cd6e..32f8ca46f 100644 --- a/backend/app/crud/llm.py +++ b/backend/app/crud/llm.py @@ -53,6 +53,7 @@ def create_llm_call( *, request: LLMCallRequest, job_id: UUID, + chain_id: UUID | None = None, project_id: int, organization_id: int, resolved_config: ConfigBlob, @@ -120,6 +121,7 @@ def create_llm_call( job_id=job_id, project_id=project_id, organization_id=organization_id, + chain_id=chain_id, input=serialize_input(request.query.input), input_type=input_type, output_type=output_type, diff --git a/backend/app/crud/llm_chain.py b/backend/app/crud/llm_chain.py new file mode 100644 index 000000000..77ab70987 --- /dev/null +++ b/backend/app/crud/llm_chain.py @@ -0,0 +1,151 @@ +import logging +from typing import Any +from uuid import UUID + +from sqlmodel import Session + +from app.core.util import now +from app.models.llm.request import ChainStatus, LlmChain + +logger = logging.getLogger(__name__) + + +def create_llm_chain( + session: Session, + *, + job_id: UUID, + project_id: int, + organization_id: int, + total_blocks: int, + input: str, + configs: list[dict[str, Any]], +) -> LlmChain: + """Create a new LLM chain record. 
+ Args: + session: Database session + job_id: Reference to the parent job + project_id: Reference to the project + organization_id: Reference to the organization + total_blocks: Total number of blocks to execute + input: Serialized input string (via serialize_input) + configs: Ordered list of block configs as submitted + + Returns: + LlmChain: The created chain record + """ + db_llm_chain = LlmChain( + job_id=job_id, + project_id=project_id, + organization_id=organization_id, + status=ChainStatus.PENDING, + total_blocks=total_blocks, + number_of_blocks_processed=0, + input=input, + configs=configs, + block_sequences=[], + ) + + session.add(db_llm_chain) + session.commit() + session.refresh(db_llm_chain) + + logger.info( + f"[create_llm_chain] Created LLM chain id={db_llm_chain.id}, " + f"job_id={job_id}, total_blocks={total_blocks}" + ) + + return db_llm_chain + + +def update_llm_chain_status( + session: Session, + *, + chain_id: UUID, + status: ChainStatus, + output: dict[str, Any] | None = None, + total_usage: dict[str, Any] | None = None, + error: str | None = None, +) -> LlmChain: + """Update chain record status and related fields. 
+ Args: + session: Database session + chain_id: The chain record ID + status: New chain status + output: Last block's output dict (only for COMPLETED) + total_usage: Aggregated token usage across all blocks (for COMPLETED/FAILED) + error: Error message (only for FAILED) + + Returns: + LlmChain: The updated chain record + """ + db_chain = session.get(LlmChain, chain_id) + if not db_chain: + raise ValueError(f"LLM chain not found with id={chain_id}") + + db_chain.status = status + db_chain.updated_at = now() + + if status == ChainStatus.RUNNING: + db_chain.started_at = now() + + if status == ChainStatus.FAILED: + db_chain.error = error + db_chain.total_usage = total_usage + db_chain.completed_at = now() + + if status == ChainStatus.COMPLETED: + db_chain.output = output + db_chain.total_usage = total_usage + db_chain.completed_at = now() + + session.add(db_chain) + session.commit() + session.refresh(db_chain) + + logger.info( + f"[update_llm_chain_status] Chain {chain_id} → {status.value} | " + f"has_output={output is not None}, " + f"blocks={db_chain.number_of_blocks_processed}/{db_chain.total_blocks}, " + f"error={error}" + ) + return db_chain + + +def update_llm_chain_block_completed( + session: Session, + *, + chain_id: UUID, + llm_call_id: UUID, +) -> LlmChain: + """Update chain progress after a block completes. 
+ Args: + session: Database session + chain_id: The chain record ID + llm_call_id: The llm_call record ID for the completed block + + Returns: + LlmChain: The updated chain record + """ + db_chain = session.get(LlmChain, chain_id) + if not db_chain: + raise ValueError(f"LLM chain not found with id={chain_id}") + + # Append to block_sequences + sequences = list(db_chain.block_sequences or []) + sequences.append(str(llm_call_id)) + db_chain.block_sequences = sequences + + # Increment progress + db_chain.number_of_blocks_processed = len(sequences) + db_chain.updated_at = now() + + session.add(db_chain) + session.commit() + session.refresh(db_chain) + + logger.info( + f"[update_llm_chain_block_completed] Chain {chain_id} | " + f"block={db_chain.number_of_blocks_processed}/{db_chain.total_blocks}, " + f"llm_call_id={llm_call_id}" + ) + return db_chain diff --git a/backend/app/models/__init__.py b/backend/app/models/__init__.py index 2c28d7b4f..c76a02579 100644 --- a/backend/app/models/__init__.py +++ b/backend/app/models/__init__.py @@ -111,6 +111,9 @@ LLMCallRequest, LLMCallResponse, LlmCall, + LLMChainRequest, + LLMChainResponse, + LlmChain, ) from .message import Message diff --git a/backend/app/models/job.py b/backend/app/models/job.py index b6a1a5ae7..3b20249f5 100644 --- a/backend/app/models/job.py +++ b/backend/app/models/job.py @@ -17,6 +17,7 @@ class JobStatus(str, Enum): class JobType(str, Enum): RESPONSE = "RESPONSE" LLM_API = "LLM_API" + LLM_CHAIN = "LLM_CHAIN" class Job(SQLModel, table=True): diff --git a/backend/app/models/llm/__init__.py b/backend/app/models/llm/__init__.py index b183543c4..9bcf3a035 100644 --- a/backend/app/models/llm/__init__.py +++ b/backend/app/models/llm/__init__.py @@ -9,6 +9,13 @@ LlmCall, AudioContent, TextContent, + TextInput, + AudioInput, + PromptTemplate, + ChainBlock, + ChainStatus, + LLMChainRequest, + LlmChain, ) from app.models.llm.response import ( LLMCallResponse, @@ -17,4 +24,6 @@ Usage, TextOutput, AudioOutput, + 
LLMChainResponse, + IntermediateChainResponse, ) diff --git a/backend/app/models/llm/request.py b/backend/app/models/llm/request.py index b90fb6229..d6abd7d8d 100644 --- a/backend/app/models/llm/request.py +++ b/backend/app/models/llm/request.py @@ -1,3 +1,4 @@ +from enum import Enum from typing import Annotated, Any, Literal, Union from uuid import UUID, uuid4 @@ -214,11 +215,21 @@ class Validator(SQLModel): validator_config_id: UUID +class PromptTemplate(SQLModel): + template: str = Field(..., description="Template string with {{input}} placeholder") + + class ConfigBlob(SQLModel): """Raw JSON blob of config.""" completion: CompletionConfig = Field(..., description="Completion configuration") + # used for llm-chain to provide prompt interpolation + prompt_template: PromptTemplate | None = Field( + default=None, + description="Prompt template with {{input}} placeholder to wrap around the user input", + ) + input_guardrails: list[Validator] | None = Field( default=None, description="Guardrails applied to validate/sanitize the input before the LLM call", @@ -384,6 +395,16 @@ class LlmCall(SQLModel, table=True): }, ) + chain_id: UUID | None = Field( + default=None, + foreign_key="llm_chain.id", + nullable=True, + ondelete="SET NULL", + sa_column_kwargs={ + "comment": "Reference to the parent chain (NULL for standalone llm_call requests)" + }, + ) + # Request fields input: str = Field( ..., @@ -496,3 +517,213 @@ class LlmCall(SQLModel, table=True): nullable=True, sa_column_kwargs={"comment": "Timestamp when the record was soft-deleted"}, ) + + +class ChainBlock(SQLModel): + """A single block in an LLM chain execution.""" + + config: LLMCallConfig = Field( + ..., description="LLM call configuration (stored id+version OR ad-hoc blob)" + ) + + include_provider_raw_response: bool = Field( + default=False, + description="Whether to include the raw LLM provider response in the output for this block", + ) + + intermediate_callback: bool = Field( + default=False, + 
description="Whether to send intermediate callback after this block completes", + ) + + +class LLMChainRequest(SQLModel): + """ + API request for an LLM chain execution. + + Orchestrates multiple LLM calls sequentially where each block's output + becomes the next block's input. + """ + + query: QueryParams = Field( + ..., description="Initial query input for the first block in the chain" + ) + + blocks: list[ChainBlock] = Field( + ..., min_length=1, description="Ordered list of blocks to execute sequentially" + ) + + callback_url: HttpUrl | None = Field( + default=None, description="Webhook URL for async response delivery" + ) + + request_metadata: dict[str, Any] | None = Field( + default=None, + description=( + "Client-provided metadata passed through unchanged in the response. " + "Use this to correlate responses with requests or track request state. " + "The exact dictionary provided here will be returned in the response metadata field." + ), + ) + + +class ChainStatus(str, Enum): + """Status of an LLM chain execution.""" + + PENDING = "pending" + RUNNING = "running" + FAILED = "failed" + COMPLETED = "completed" + + +class LlmChain(SQLModel, table=True): + """ + Database model for tracking LLM chain execution + + it manages and orchestrates sequential llm_call executions. 
+ """ + + __tablename__ = "llm_chain" + __table_args__ = ( + Index( + "idx_llm_chain_job_id", + "job_id", + ), + ) + + id: UUID = Field( + default_factory=uuid4, + primary_key=True, + sa_column_kwargs={"comment": "Unique identifier for the LLM chain record"}, + ) + + job_id: UUID = Field( + foreign_key="job.id", + nullable=False, + ondelete="CASCADE", + sa_column_kwargs={ + "comment": "Reference to the parent job (status tracked in job table)" + }, + ) + + project_id: int = Field( + foreign_key="project.id", + nullable=False, + ondelete="CASCADE", + sa_column_kwargs={ + "comment": "Reference to the project this LLM call belongs to" + }, + ) + + organization_id: int = Field( + foreign_key="organization.id", + nullable=False, + ondelete="CASCADE", + sa_column_kwargs={ + "comment": "Reference to the organization this LLM call belongs to" + }, + ) + + status: ChainStatus = Field( + default=ChainStatus.PENDING, + sa_column_kwargs={ + "comment": "Chain execution status (pending, running, failed, completed)" + }, + ) + + error: str | None = Field( + default=None, + nullable=True, + sa_column_kwargs={"comment": "Error message if the chain execution failed"}, + ) + + block_sequences: list[str] | None = Field( + default_factory=list, + sa_column=sa.Column( + JSONB, + nullable=True, + comment="Ordered list of llm_call UUIDs as blocks complete", + ), + ) + + total_blocks: int = Field( + ..., sa_column_kwargs={"comment": "Total number of blocks to execute"} + ) + + number_of_blocks_processed: int = Field( + default=0, + sa_column_kwargs={ + "comment": "Number of blocks processed so far (used for tracking progress)" + }, + ) + + # Request fields + input: str = Field( + ..., + sa_column_kwargs={ + "comment": "First block user's input - text string, binary data, or file path for multimodal" + }, + ) + + output: dict[str, Any] | None = Field( + default=None, + sa_column=sa.Column( + JSONB, + nullable=True, + comment="Last block's final output (set on chain completion)", + ), + ) + 
+ configs: list[dict[str, Any]] | None = Field( + default=None, + sa_column=sa.Column( + JSONB, + nullable=True, + comment="Ordered list of block configs as submitted in the request", + ), + ) + + total_usage: dict[str, Any] | None = Field( + default=None, + sa_column=sa.Column( + JSONB, + nullable=True, + comment="Aggregated token usage: {input_tokens, output_tokens, total_tokens}", + ), + ) + + metadata_: dict[str, Any] | None = Field( + default=None, + sa_column=sa.Column( + "metadata", + JSONB, + nullable=True, + comment="Future-proof extensibility catch-all", + ), + ) + + started_at: datetime | None = Field( + default=None, + nullable=True, + sa_column_kwargs={"comment": "Timestamp when chain execution started"}, + ) + + completed_at: datetime | None = Field( + default=None, + nullable=True, + sa_column_kwargs={"comment": "Timestamp when chain execution completed"}, + ) + + created_at: datetime = Field( + default_factory=now, + nullable=False, + sa_column_kwargs={"comment": "Timestamp when the chain record was created"}, + ) + + updated_at: datetime = Field( + default_factory=now, + nullable=False, + sa_column_kwargs={ + "comment": "Timestamp when the chain record was last updated" + }, + ) diff --git a/backend/app/models/llm/response.py b/backend/app/models/llm/response.py index 7b13e301c..1ae7619f6 100644 --- a/backend/app/models/llm/response.py +++ b/backend/app/models/llm/response.py @@ -62,3 +62,42 @@ class LLMCallResponse(SQLModel): default=None, description="Unmodified raw response from the LLM provider.", ) + + +class LLMChainResponse(SQLModel): + """Response schema for an LLM chain execution.""" + + response: LLMResponse = Field( + ..., description="LLM response from the final step of the chain execution." 
+ ) + usage: Usage = Field( + ..., + description="Aggregate token usage and cost for the entire chain execution.", + ) + provider_raw_response: dict[str, object] | None = Field( + default=None, + description="Raw provider response from the last block (if requested)", + ) + + +class IntermediateChainResponse(SQLModel): + """ + Intermediate callback response from the intermediate blocks + from the llm chain execution. (if configured) + + Flattend structure matching LLMCallResponse keys for consistency + """ + + type: Literal["intermediate"] = "intermediate" + block_index: int = Field(..., description="Current block position") + total_blocks: int = Field(..., description="Total number of blocks in the chain") + response: LLMResponse = Field( + ..., description="LLM Response from the current block" + ) + usage: Usage = Field( + ..., description="Token usage and cost information from the current block" + ) + provider_raw_response: dict[str, object] | None = Field( + default=None, + description="Unmodified raw response from the LLM provider from the current block", + ) diff --git a/backend/app/services/llm/chain/__init__.py b/backend/app/services/llm/chain/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/backend/app/services/llm/chain/chain.py b/backend/app/services/llm/chain/chain.py new file mode 100644 index 000000000..390247d8d --- /dev/null +++ b/backend/app/services/llm/chain/chain.py @@ -0,0 +1,221 @@ +import logging +from dataclasses import dataclass, field +from typing import Any +from uuid import UUID + +from sqlmodel import Session + +from app.core.db import engine +from app.crud.llm_chain import update_llm_chain_block_completed +from app.models.llm.request import ( + LLMCallConfig, + QueryParams, + TextInput, + TextContent, + AudioInput, +) +from app.models.llm.response import ( + IntermediateChainResponse, + TextOutput, + AudioOutput, + Usage, +) +from app.services.llm.chain.types import BlockResult +from app.services.llm.jobs import 
execute_llm_call +from app.utils import APIResponse, send_callback + + +logger = logging.getLogger(__name__) + + +@dataclass +class ChainContext: + """Shared state passed to all blocks. Accumulates responses.""" + + job_id: UUID + chain_id: UUID + project_id: int + organization_id: int + callback_url: str + total_blocks: int + + langfuse_credentials: dict[str, Any] | None = None + request_metadata: dict | None = None + intermediate_callback_flags: list[bool] = field(default_factory=list) + aggregated_usage: Usage = field( + default_factory=lambda: Usage( + input_tokens=0, + output_tokens=0, + total_tokens=0, + ) + ) + + def on_block_completed(self, block_index: int, result: BlockResult) -> None: + """Called after each block completes. Updates chain state in DB and sends intermediate callback.""" + + if result.usage: + self.aggregated_usage.input_tokens += result.usage.input_tokens + self.aggregated_usage.output_tokens += result.usage.output_tokens + self.aggregated_usage.total_tokens += result.usage.total_tokens + + if result.success and result.llm_call_id: + with Session(engine) as session: + update_llm_chain_block_completed( + session, + chain_id=self.chain_id, + llm_call_id=result.llm_call_id, + ) + + if ( + block_index < len(self.intermediate_callback_flags) + and self.intermediate_callback_flags[block_index] + and self.callback_url + ): + self._send_intermediate_callback(block_index, result) + + def _send_intermediate_callback( + self, block_index: int, result: BlockResult + ) -> None: + """Send intermediate callback for a completed block.""" + try: + intermediate = IntermediateChainResponse( + block_index=block_index + 1, + total_blocks=self.total_blocks, + response=result.response.response, + usage=result.usage, + provider_raw_response=result.response.provider_raw_response, + ) + callback_data = APIResponse.success_response( + data=intermediate, + metadata=self.request_metadata, + ) + send_callback( + callback_url=self.callback_url, + 
data=callback_data.model_dump(), + ) + logger.info( + f"[ChainContext] Sent intermediate callback | " + f"block={block_index + 1}/{self.total_blocks}, job_id={self.job_id}" + ) + except Exception as e: + logger.warning( + f"[ChainContext] Failed to send intermediate callback: {e} | " + f"block={block_index + 1}/{self.total_blocks}, job_id={self.job_id}" + ) + + +def result_to_query(result: BlockResult) -> QueryParams: + """Convert a block's output into the next block's QueryParams. + + Text output → TextInput query + Audio output → AudioInput query + """ + output = result.response.response.output + + if isinstance(output, TextOutput): + return QueryParams( + input=TextInput(content=TextContent(value=output.content.value)) + ) + elif isinstance(output, AudioOutput): + return QueryParams(input=AudioInput(content=output.content)) + else: + raise ValueError(f"Cannot chain output type: {output.type}") + + +class ChainBlock: + """A single node in the linked chain. + + Wraps execute_block() with linking capability. + Each block knows its next block and forwards output to it. + """ + + def __init__( + self, + *, + config: LLMCallConfig, + index: int, + context: ChainContext, + include_provider_raw_response: bool = False, + ): + self._config = config + self._index = index + self._context = context + self._include_provider_raw_response = include_provider_raw_response + self._next: ChainBlock | None = None + + def link(self, next_block: "ChainBlock") -> "ChainBlock": + """Link to the next block in the chain.""" + self._next = next_block + return next_block + + def execute(self, query: QueryParams) -> BlockResult: + """Execute this block, then flow to next. + + No loop. Each block calls the next via the linked reference. + Data flows through the chain like a linked list traversal. 
+ """ + logger.info( + f"[ChainBlock.execute] Executing block {self._index} | " + f"job_id={self._context.job_id}" + ) + + result = execute_llm_call( + config=self._config, + query=query, + job_id=self._context.job_id, + project_id=self._context.project_id, + organization_id=self._context.organization_id, + request_metadata=self._context.request_metadata, + langfuse_credentials=self._context.langfuse_credentials, + include_provider_raw_response=self._include_provider_raw_response, + chain_id=self._context.chain_id, + ) + + self._context.on_block_completed(self._index, result) + + if not result.success: + logger.error( + f"[ChainBlock.execute] Block {self._index} failed: {result.error} | " + f"job_id={self._context.job_id}" + ) + return result + + if self._next: + next_query = result_to_query(result) + return self._next.execute(next_query) + + logger.info( + f"[ChainBlock.execute] Block {self._index} is the last block | " + f"job_id={self._context.job_id}" + ) + return result + + +class LLMChain: + """Links ChainBlocks together into a sequential chain. + + Construction builds the linked structure. + Execution pushes input into the head — it flows through to the tail. + """ + + def __init__(self, blocks: list[ChainBlock]): + self._head: ChainBlock | None = None + self._tail: ChainBlock | None = None + self._link_blocks(blocks) + + def _link_blocks(self, blocks: list[ChainBlock]) -> None: + """Link all blocks in sequence.""" + if not blocks: + return + self._head = blocks[0] + self._tail = blocks[-1] + prev = blocks[0] + for curr in blocks[1:]: + prev.link(curr) + prev = curr + + def execute(self, query: QueryParams) -> BlockResult: + """Push input into the chain head. 
It flows through to the tail.""" + if not self._head: + return BlockResult(error="Chain has no blocks") + return self._head.execute(query) diff --git a/backend/app/services/llm/chain/executor.py b/backend/app/services/llm/chain/executor.py new file mode 100644 index 000000000..78808d84c --- /dev/null +++ b/backend/app/services/llm/chain/executor.py @@ -0,0 +1,197 @@ +import logging + +from sqlmodel import Session + +from app.core.db import engine +from app.crud.config import ConfigVersionCrud +from app.crud.jobs import JobCrud +from app.crud.llm_chain import update_llm_chain_status +from app.models import JobStatus, JobUpdate +from app.models.llm.request import ( + ChainStatus, + ConfigBlob, + LLMChainRequest, +) +from app.models.llm.response import LLMChainResponse +from app.services.llm.chain.chain import ChainContext, LLMChain +from app.services.llm.chain.types import BlockResult +from app.services.llm.jobs import ( + apply_input_guardrails, + apply_output_guardrails, + resolve_config_blob, +) +from app.utils import APIResponse, send_callback + +logger = logging.getLogger(__name__) + + +class ChainExecutor: + """Manage the lifecycle of an LLM chain execution.""" + + def __init__( + self, + *, + chain: LLMChain, + context: ChainContext, + request: LLMChainRequest, + ): + self._chain = chain + self._context = context + self._request = request + + def run(self) -> dict: + """Execute the full chain lifecycle. 
Returns serialized APIResponse.""" + try: + self._setup() + + first_config_blob, resolve_error = self._resolve_block_config_blob(0) + if resolve_error: + return self._handle_error(resolve_error) + + query, error = apply_input_guardrails( + config_blob=first_config_blob, + query=self._request.query, + job_id=self._context.job_id, + project_id=self._context.project_id, + organization_id=self._context.organization_id, + ) + if error: + return self._handle_error(error) + + result = self._chain.execute(query) + + if result.success: + last_config_blob, resolve_error = self._resolve_block_config_blob( + len(self._request.blocks) - 1 + ) + if resolve_error: + return self._handle_error(resolve_error) + + result, error = apply_output_guardrails( + config_blob=last_config_blob, + result=result, + job_id=self._context.job_id, + project_id=self._context.project_id, + organization_id=self._context.organization_id, + ) + if error: + return self._handle_error(error) + + return self._teardown(result) + + except Exception as e: + return self._handle_unexpected_error(e) + + def _resolve_block_config_blob( + self, block_index: int + ) -> tuple[ConfigBlob | None, str | None]: + """Resolve a block's config to its ConfigBlob. 
+ + Uses is_stored_config property (same pattern as execute_job in jobs.py): + - Stored config (is_stored_config=True): fetch from DB via resolve_config_blob() + - Ad-hoc config (blob provided): return blob directly + + Returns: + (config_blob, error): ConfigBlob on success, or error string on failure + """ + block = self._request.blocks[block_index] + config = block.config + + if not config.is_stored_config: + return config.blob, None + + with Session(engine) as session: + config_crud = ConfigVersionCrud( + session=session, + project_id=self._context.project_id, + config_id=config.id, + ) + config_blob, error = resolve_config_blob(config_crud, config) + if error: + return None, error + return config_blob, None + + def _setup(self) -> None: + with Session(engine) as session: + JobCrud(session).update( + job_id=self._context.job_id, + job_update=JobUpdate(status=JobStatus.PROCESSING), + ) + + update_llm_chain_status( + session=session, + chain_id=self._context.chain_id, + status=ChainStatus.RUNNING, + ) + + def _teardown(self, result: BlockResult) -> dict: + """Finalize chain record, send callback, and update job status.""" + + with Session(engine) as session: + if result.success: + final = LLMChainResponse( + response=result.response.response, + usage=result.usage, + provider_raw_response=result.response.provider_raw_response, + ) + callback_response = APIResponse.success_response( + data=final, metadata=self._request.request_metadata + ) + if self._request.callback_url: + send_callback( + callback_url=str(self._request.callback_url), + data=callback_response.model_dump(), + ) + JobCrud(session).update( + job_id=self._context.job_id, + job_update=JobUpdate(status=JobStatus.SUCCESS), + ) + update_llm_chain_status( + session=session, + chain_id=self._context.chain_id, + status=ChainStatus.COMPLETED, + output=result.response.response.output.model_dump(), + total_usage=self._context.aggregated_usage.model_dump(), + ) + return callback_response.model_dump() + else: + 
return self._handle_error(result.error) + + def _handle_error(self, error: str) -> dict: + callback_response = APIResponse.failure_response( + error=error or "Unknown error occurred", + metadata=self._request.request_metadata, + ) + logger.error( + f"[ChainExecutor] Chain execution failed | " + f"chain_id={self._context.chain_id}, job_id={self._context.job_id}, error={error}" + ) + + with Session(engine) as session: + if self._request.callback_url: + send_callback( + callback_url=str(self._request.callback_url), + data=callback_response.model_dump(), + ) + + update_llm_chain_status( + session, + chain_id=self._context.chain_id, + status=ChainStatus.FAILED, + output=None, + total_usage=self._context.aggregated_usage.model_dump(), + error=error, + ) + JobCrud(session).update( + job_id=self._context.job_id, + job_update=JobUpdate(status=JobStatus.FAILED, error_message=error), + ) + return callback_response.model_dump() + + def _handle_unexpected_error(self, e: Exception) -> dict: + logger.error( + f"[ChainExecutor.run] Unexpected error: {e} | " + f"job_id={self._context.job_id}", + exc_info=True, + ) + return self._handle_error("Unexpected error occurred") diff --git a/backend/app/services/llm/chain/types.py b/backend/app/services/llm/chain/types.py new file mode 100644 index 000000000..69ab3d02f --- /dev/null +++ b/backend/app/services/llm/chain/types.py @@ -0,0 +1,18 @@ +from dataclasses import dataclass +from uuid import UUID + +from app.models.llm.response import LLMCallResponse, Usage + + +@dataclass +class BlockResult: + """Result of a single block/LLM call execution.""" + + response: LLMCallResponse | None = None + llm_call_id: UUID | None = None + usage: Usage | None = None + error: str | None = None + + @property + def success(self) -> bool: + return self.error is None and self.response is not None diff --git a/backend/app/services/llm/jobs.py b/backend/app/services/llm/jobs.py index c6997a084..cd71e5bfa 100644 --- a/backend/app/services/llm/jobs.py +++ 
b/backend/app/services/llm/jobs.py @@ -11,23 +11,26 @@ from app.crud.config import ConfigVersionCrud from app.crud.credentials import get_provider_credential from app.crud.jobs import JobCrud -from app.crud.llm import create_llm_call, update_llm_call_response -from app.models import JobStatus, JobType, JobUpdate, LLMCallRequest, Job +from app.crud.llm import create_llm_call, serialize_input, update_llm_call_response +from app.crud.llm_chain import create_llm_chain, update_llm_chain_status +from app.models import JobStatus, JobType, JobUpdate, LLMCallRequest, LLMChainRequest from app.models.llm.request import ( + ChainStatus, ConfigBlob, - LLMCallConfig, KaapiCompletionConfig, + LLMCallConfig, + QueryParams, TextInput, ) from app.models.llm.response import TextOutput +from app.services.llm.chain.types import BlockResult from app.services.llm.guardrails import ( list_validators_config, run_guardrails_validation, ) -from app.services.llm.providers.registry import get_llm_provider +from app.services.llm.input_resolver import cleanup_temp_file, resolve_input from app.services.llm.mappers import transform_kaapi_config_to_native -from app.services.llm.input_resolver import resolve_input, cleanup_temp_file - +from app.services.llm.providers.registry import get_llm_provider from app.utils import APIResponse, send_callback logger = logging.getLogger(__name__) @@ -75,6 +78,49 @@ def start_job( return job.id +def start_chain_job( + db: Session, request: LLMChainRequest, project_id: int, organization_id: int +) -> UUID: + """Create an LLM Chain job and schedule Celery task.""" + trace_id = correlation_id.get() or "N/A" + job_crud = JobCrud(session=db) + job = job_crud.create(job_type=JobType.LLM_CHAIN, trace_id=trace_id) + + # Explicitly flush to ensure job is persisted before Celery task starts + db.flush() + db.commit() + + logger.info( + f"[start_chain_job] Created job | job_id={job.id}, status={job.status}, project_id={project_id}" + ) + + try: + task_id = 
start_high_priority_job( + function_path="app.services.llm.jobs.execute_chain_job", + project_id=project_id, + job_id=str(job.id), + trace_id=trace_id, + request_data=request.model_dump(mode="json"), + organization_id=organization_id, + ) + except Exception as e: + logger.error( + f"[start_chain_job] Error starting Celery task: {str(e)} | job_id={job.id}, project_id={project_id}", + exc_info=True, + ) + job_update = JobUpdate(status=JobStatus.FAILED, error_message=str(e)) + job_crud.update(job_id=job.id, job_update=job_update) + raise HTTPException( + status_code=500, + detail="Internal server error while executing LLM chain job", + ) + + logger.info( + f"[start_chain_job] Job scheduled for LLM chain job | job_id={job.id}, project_id={project_id}, task_id={task_id}" + ) + return job.id + + def handle_job_error( job_id: UUID, callback_url: str | None, @@ -136,226 +182,225 @@ def resolve_config_blob( return None, "Unexpected error occurred while parsing stored configuration" -def execute_job( - request_data: dict, +def apply_input_guardrails( + *, + config_blob: ConfigBlob | None, + query: QueryParams, + job_id: UUID, project_id: int, organization_id: int, - job_id: str, - task_id: str, - task_instance, -) -> dict: - """Celery task to process an LLM request asynchronously. +) -> tuple[QueryParams, str | None]: + """Apply input guardrails from a config_blob. Shared with llm-call and llm-chain.""" + if not config_blob or not config_blob.input_guardrails: + return query, None + + if not isinstance(query.input, TextInput): + logger.info( + f"[apply_input_guardrails] Skipping for non-text input. 
" + f"job_id={job_id}, " + f"input_type={getattr(query.input, 'type', type(query.input).__name__)}" + ) + return query, None - Returns: - dict: Serialized APIResponse[LLMCallResponse] on success, APIResponse[None] on failure - """ + input_guardrails, _ = list_validators_config( + organization_id=organization_id, + project_id=project_id, + input_validator_configs=config_blob.input_guardrails, + output_validator_configs=None, + ) - request = LLMCallRequest(**request_data) - job_id: UUID = UUID(job_id) + if not input_guardrails: + return query, None - config = request.config - callback_response = None - config_blob: ConfigBlob | None = None - input_guardrails: list[dict] = [] - output_guardrails: list[dict] = [] - llm_call_id: UUID | None = None # Track the LLM call record + safe = run_guardrails_validation( + query.input.content.value, + input_guardrails, + job_id, + project_id, + organization_id, + suppress_pass_logs=True, + ) logger.info( - f"[execute_job] Starting LLM job execution | job_id={job_id}, task_id={task_id}, " + f"[apply_input_guardrails] Validation result | success={safe['success']}, job_id={job_id}" ) - try: - with Session(engine) as session: - # Update job status to PROCESSING - job_crud = JobCrud(session=session) - logger.info(f"[execute_job] Attempting to fetch job | job_id={job_id}") - job = session.get(Job, job_id) - if not job: - # Log all jobs to see what's in the database - from sqlmodel import select - - all_jobs = session.exec( - select(Job).order_by(Job.created_at.desc()).limit(5) - ).all() - logger.error( - f"[execute_job] Job not found! 
| job_id={job_id} | " - f"Recent jobs in DB: {[(j.id, j.status) for j in all_jobs]}" - ) - else: - logger.info( - f"[execute_job] Found job | job_id={job_id}, status={job.status}" - ) + if safe.get("bypassed"): + logger.info( + f"[apply_input_guardrails] Guardrails bypassed (service unavailable) | job_id={job_id}" + ) + return query, None - job_crud.update( - job_id=job_id, job_update=JobUpdate(status=JobStatus.PROCESSING) - ) + if safe["success"]: + query.input.content.value = safe["data"]["safe_text"] + return query, None - # if stored config, fetch blob from DB - if config.is_stored_config: - config_crud = ConfigVersionCrud( - session=session, project_id=project_id, config_id=config.id - ) + return query, safe["error"] - # blob is dynamic, need to resolve to ConfigBlob format - config_blob, error = resolve_config_blob(config_crud, config) - if error: - callback_response = APIResponse.failure_response( - error=error, - metadata=request.request_metadata, - ) - return handle_job_error( - job_id, request.callback_url, callback_response - ) +def apply_output_guardrails( + *, + config_blob: ConfigBlob | None, + result: BlockResult, + job_id: UUID, + project_id: int, + organization_id: int, +) -> tuple[BlockResult, str | None]: + """Apply output guardrails from a config_blob. Shared by /llm/call and /llm/chain. - else: - config_blob = config.blob + Returns (modified_result, None) on success, or (result, error_string) on failure. + """ + if not config_blob or not config_blob.output_guardrails: + return result, None + + if not isinstance(result.response.response.output, TextOutput): + logger.info( + f"[apply_output_guardrails] Skipping for non-text output. 
" + f"job_id={job_id}, " + f"output_type={getattr(result.response.response.output, 'type', type(result.response.response.output).__name__)}" + ) + return result, None - if config_blob is not None: - if config_blob.input_guardrails or config_blob.output_guardrails: - input_guardrails, output_guardrails = list_validators_config( - organization_id=organization_id, - project_id=project_id, - input_validator_configs=config_blob.input_guardrails, - output_validator_configs=config_blob.output_guardrails, - ) + _, output_guardrails = list_validators_config( + organization_id=organization_id, + project_id=project_id, + input_validator_configs=None, + output_validator_configs=config_blob.output_guardrails, + ) - if input_guardrails: - if not isinstance(request.query.input, TextInput): - logger.info( - "[execute_job] Skipping input guardrails for non-text input. " - f"job_id={job_id}, input_type={getattr(request.query.input, 'type', type(request.query.input).__name__)}" - ) - else: - safe_input = run_guardrails_validation( - request.query.input.content.value, - input_guardrails, - job_id, - project_id, - organization_id, - suppress_pass_logs=True, - ) + if not output_guardrails: + return result, None + + output_text = result.response.response.output.content.value + safe = run_guardrails_validation( + output_text, + output_guardrails, + job_id, + project_id, + organization_id, + suppress_pass_logs=True, + ) - logger.info( - f"[execute_job] Input guardrail validation | success={safe_input['success']}." 
- ) + logger.info( + f"[apply_output_guardrails] Validation result | success={safe['success']}, job_id={job_id}" + ) - if safe_input.get("bypassed"): - logger.info( - "[execute_job] Guardrails bypassed (service unavailable)" - ) + if safe.get("bypassed"): + logger.info( + f"[apply_output_guardrails] Guardrails bypassed (service unavailable) | job_id={job_id}" + ) + return result, None - elif safe_input["success"]: - request.query.input.content.value = safe_input["data"][ - "safe_text" - ] - else: - # Update the text value with error message - request.query.input.content.value = safe_input["error"] - - callback_response = APIResponse.failure_response( - error=safe_input["error"], - metadata=request.request_metadata, - ) - return handle_job_error( - job_id, request.callback_url, callback_response - ) - user_sent_config_provider = "" + if safe["success"]: + result.response.response.output.content.value = safe["data"]["safe_text"] + if safe["data"].get("rephrase_needed"): + return result, result.response.response.output.content.value + return result, None - try: - # Transform Kaapi config to native config if needed (before getting provider) - completion_config = config_blob.completion + return result, safe["error"] - original_provider = ( - config_blob.completion.provider - ) # openai, google or prefixed - if isinstance(completion_config, KaapiCompletionConfig): - completion_config, warnings = transform_kaapi_config_to_native( - completion_config - ) +def execute_llm_call( + *, + config: LLMCallConfig, + query: QueryParams, + job_id: UUID, + project_id: int, + organization_id: int, + request_metadata: dict | None, + langfuse_credentials: dict | None, + include_provider_raw_response: bool = False, + chain_id: UUID | None = None, +) -> BlockResult: + """Execute a single LLM call. Shared by /llm/call and /llm/chain. + + Returns BlockResult with response + usage on success, or error on failure. 
+ """ - if request.request_metadata is None: - request.request_metadata = {} - request.request_metadata.setdefault("warnings", []).extend(warnings) - else: - pass - except Exception as e: - callback_response = APIResponse.failure_response( - error=f"Error processing configuration: {str(e)}", - metadata=request.request_metadata, + config_blob: ConfigBlob | None = None + llm_call_id: UUID | None = None + + try: + with Session(engine) as session: + if config.is_stored_config: + config_crud = ConfigVersionCrud( + session=session, project_id=project_id, config_id=config.id ) - return handle_job_error(job_id, request.callback_url, callback_response) + config_blob, error = resolve_config_blob(config_crud, config) + if error: + return BlockResult(error=error) + else: + config_blob = config.blob - # Create LLM call record before execution - try: - # Rebuild ConfigBlob with transformed native config - resolved_config_blob = ConfigBlob( - completion=completion_config, - input_guardrails=config_blob.input_guardrails, - output_guardrails=config_blob.output_guardrails, + if config_blob.prompt_template and isinstance(query.input, TextInput): + template = config_blob.prompt_template.template + interpolated = template.replace("{{input}}", query.input.content.value) + query.input.content.value = interpolated + + completion_config = config_blob.completion + original_provider = completion_config.provider + + if isinstance(completion_config, KaapiCompletionConfig): + completion_config, warnings = transform_kaapi_config_to_native( + completion_config ) + if request_metadata is None: + request_metadata = {} + request_metadata.setdefault("warnings", []).extend(warnings) + + resolved_config_blob = ConfigBlob( + completion=completion_config, + prompt_template=config_blob.prompt_template, + input_guardrails=config_blob.input_guardrails, + output_guardrails=config_blob.output_guardrails, + ) + try: + temp_request = LLMCallRequest( + query=query, + config=config, + 
request_metadata=request_metadata, + ) llm_call = create_llm_call( session, - request=request, + request=temp_request, job_id=job_id, project_id=project_id, organization_id=organization_id, resolved_config=resolved_config_blob, original_provider=original_provider, + chain_id=chain_id, ) llm_call_id = llm_call.id logger.info( - f"[execute_job] Created LLM call record | llm_call_id={llm_call_id}, job_id={job_id}" + f"[execute_llm_call] Created LLM call record | " + f"llm_call_id={llm_call_id}, job_id={job_id}" ) except Exception as e: logger.error( - f"[execute_job] Failed to create LLM call record: {str(e)} | job_id={job_id}", + f"[execute_llm_call] Failed to create LLM call record: {e} | job_id={job_id}", exc_info=True, ) - callback_response = APIResponse.failure_response( - error=f"Failed to create LLM call record: {str(e)}", - metadata=request.request_metadata, - ) - return handle_job_error(job_id, request.callback_url, callback_response) + return BlockResult(error=f"Failed to create LLM call record: {str(e)}") try: provider_instance = get_llm_provider( session=session, - provider_type=completion_config.provider, # Now always native provider type i.e openai-native, google-native regardless + provider_type=completion_config.provider, project_id=project_id, organization_id=organization_id, ) except ValueError as ve: - callback_response = APIResponse.failure_response( - error=str(ve), - metadata=request.request_metadata, - ) - return handle_job_error(job_id, request.callback_url, callback_response) - - langfuse_credentials = get_provider_credential( - session=session, - org_id=organization_id, - project_id=project_id, - provider="langfuse", - ) + return BlockResult(error=str(ve), llm_call_id=llm_call_id) - # Extract conversation_id for langfuse session grouping conversation_id = None - if request.query.conversation and request.query.conversation.id: - conversation_id = request.query.conversation.id + if query.conversation and query.conversation.id: + conversation_id 
= query.conversation.id - # Resolve input (handles text, audio_base64, audio_url) - resolved_input, resolve_error = resolve_input(request.query.input) + resolved_input, resolve_error = resolve_input(query.input) if resolve_error: - callback_response = APIResponse.failure_response( - error=resolve_error, - metadata=request.request_metadata, - ) - return handle_job_error(job_id, request.callback_url, callback_response) + return BlockResult(error=resolve_error, llm_call_id=llm_call_id) - # Apply Langfuse observability decorator to provider execute method decorated_execute = observe_llm_execution( credentials=langfuse_credentials, session_id=conversation_id, @@ -364,80 +409,16 @@ def execute_job( try: response, error = decorated_execute( completion_config=completion_config, - query=request.query, + query=query, resolved_input=resolved_input, - include_provider_raw_response=request.include_provider_raw_response, + include_provider_raw_response=include_provider_raw_response, ) finally: - # Clean up temp files for audio inputs - if resolved_input and resolved_input != request.query.input: + if resolved_input and resolved_input != query.input: cleanup_temp_file(resolved_input) if response: - if output_guardrails: - if not isinstance(response.response.output, TextOutput): - logger.info( - "[execute_job] Skipping output guardrails for non-text output. " - f"job_id={job_id}, output_type={getattr(response.response.output, 'type', type(response.response.output).__name__)}" - ) - else: - output_text = response.response.output.content.value - safe_output = run_guardrails_validation( - output_text, - output_guardrails, - job_id, - project_id, - organization_id, - suppress_pass_logs=True, - ) - - logger.info( - f"[execute_job] Output guardrail validation | success={safe_output['success']}." 
- ) - - if safe_output.get("bypassed"): - logger.info( - "[execute_job] Guardrails bypassed (service unavailable)" - ) - - elif safe_output["success"]: - response.response.output.content.value = safe_output["data"][ - "safe_text" - ] - - if safe_output["data"]["rephrase_needed"] == True: - callback_response = APIResponse.failure_response( - error=request.query.input, - metadata=request.request_metadata, - ) - return handle_job_error( - job_id, request.callback_url, callback_response - ) - - else: - response.response.output.content.value = safe_output["error"] - - callback_response = APIResponse.failure_response( - error=safe_output["error"], - metadata=request.request_metadata, - ) - return handle_job_error( - job_id, request.callback_url, callback_response - ) - - callback_response = APIResponse.success_response( - data=response, metadata=request.request_metadata - ) - if request.callback_url: - send_callback( - callback_url=request.callback_url, - data=callback_response.model_dump(), - ) - with Session(engine) as session: - job_crud = JobCrud(session=session) - - # Update LLM call record with response data if llm_call_id: try: update_llm_call_response( @@ -448,34 +429,154 @@ def execute_job( usage=response.usage.model_dump(), conversation_id=response.response.conversation_id, ) - logger.info( - f"[execute_job] Updated LLM call record | llm_call_id={llm_call_id}" - ) except Exception as e: logger.error( - f"[execute_job] Failed to update LLM call record: {str(e)} | llm_call_id={llm_call_id}", + f"[execute_llm_call] Failed to update LLM call record: {e} | " + f"llm_call_id={llm_call_id}", exc_info=True, ) - # Don't fail the job if updating the record fails - job_crud.update( + return BlockResult( + response=response, + llm_call_id=llm_call_id, + usage=response.usage, + ) + + return BlockResult( + error=error or "Unknown error occurred", + llm_call_id=llm_call_id, + ) + + except Exception as e: + logger.error( + f"[execute_llm_call] Unexpected error: {e} | 
job_id={job_id}", + exc_info=True, + ) + return BlockResult( + error="Unexpected error occurred", + llm_call_id=llm_call_id, + ) + + +def execute_job( + request_data: dict, + project_id: int, + organization_id: int, + job_id: str, + task_id: str, + task_instance, +) -> dict: + """Celery task to process an LLM request asynchronously. + + Uses centralized functions: apply_input_guardrails, apply_output_guardrails, execute_llm_call. + """ + request = LLMCallRequest(**request_data) + job_id: UUID = UUID(job_id) + config = request.config + config_blob: ConfigBlob | None = None + + logger.info( + f"[execute_job] Starting LLM job execution | job_id={job_id}, task_id={task_id}" + ) + + try: + with Session(engine) as session: + job_crud = JobCrud(session=session) + job_crud.update( + job_id=job_id, job_update=JobUpdate(status=JobStatus.PROCESSING) + ) + + if config.is_stored_config: + config_crud = ConfigVersionCrud( + session=session, project_id=project_id, config_id=config.id + ) + config_blob, error = resolve_config_blob(config_crud, config) + if error: + callback_response = APIResponse.failure_response( + error=error, + metadata=request.request_metadata, + ) + return handle_job_error( + job_id, request.callback_url, callback_response + ) + else: + config_blob = config.blob + + request.query, input_error = apply_input_guardrails( + config_blob=config_blob, + query=request.query, + job_id=job_id, + project_id=project_id, + organization_id=organization_id, + ) + if input_error: + callback_response = APIResponse.failure_response( + error=input_error, + metadata=request.request_metadata, + ) + return handle_job_error(job_id, request.callback_url, callback_response) + + langfuse_credentials = get_provider_credential( + session=session, + org_id=organization_id, + project_id=project_id, + provider="langfuse", + ) + + result = execute_llm_call( + config=request.config, + query=request.query, + job_id=job_id, + project_id=project_id, + organization_id=organization_id, + 
request_metadata=request.request_metadata, + langfuse_credentials=langfuse_credentials, + include_provider_raw_response=request.include_provider_raw_response, + ) + + if result.success: + result, output_error = apply_output_guardrails( + config_blob=config_blob, + result=result, + job_id=job_id, + project_id=project_id, + organization_id=organization_id, + ) + if output_error: + callback_response = APIResponse.failure_response( + error=output_error, + metadata=request.request_metadata, + ) + return handle_job_error(job_id, request.callback_url, callback_response) + + callback_response = APIResponse.success_response( + data=result.response, metadata=request.request_metadata + ) + if request.callback_url: + send_callback( + callback_url=request.callback_url, + data=callback_response.model_dump(), + ) + + with Session(engine) as session: + JobCrud(session=session).update( job_id=job_id, job_update=JobUpdate(status=JobStatus.SUCCESS) ) logger.info( f"[execute_job] Successfully completed LLM job | job_id={job_id}, " - f"provider_response_id={response.response.provider_response_id}, tokens={response.usage.total_tokens}" + f"tokens={result.usage.total_tokens}" ) return callback_response.model_dump() callback_response = APIResponse.failure_response( - error=error or "Unknown error occurred", + error=result.error or "Unknown error occurred", metadata=request.request_metadata, ) return handle_job_error(job_id, request.callback_url, callback_response) except Exception as e: callback_response = APIResponse.failure_response( - error=f"Unexpected error occurred", + error="Unexpected error occurred", metadata=request.request_metadata, ) logger.error( @@ -483,3 +584,108 @@ def execute_job( exc_info=True, ) return handle_job_error(job_id, request.callback_url, callback_response) + + +def execute_chain_job( + request_data: dict, + project_id: int, + organization_id: int, + job_id: str, + task_id: str, + task_instance, +) -> dict: + """Celery task entry point for LLM chain 
execution.""" + # imports to avoid circular dependency: + from app.services.llm.chain.chain import ChainBlock, ChainContext, LLMChain + from app.services.llm.chain.executor import ChainExecutor + + request = LLMChainRequest(**request_data) + job_uuid = UUID(job_id) + chain_uuid = None + + logger.info( + f"[execute_chain_job] Starting chain execution | " + f"job_id={job_uuid}, total_blocks={len(request.blocks)}" + ) + + try: + with Session(engine) as session: + chain_record = create_llm_chain( + session, + job_id=job_uuid, + project_id=project_id, + organization_id=organization_id, + total_blocks=len(request.blocks), + input=serialize_input(request.query.input), + configs=[block.model_dump(mode="json") for block in request.blocks], + ) + chain_uuid = chain_record.id + + logger.info( + f"[execute_chain_job] Created chain record | " + f"chain_id={chain_uuid}, job_id={job_uuid}" + ) + + langfuse_credentials = get_provider_credential( + session=session, + org_id=organization_id, + project_id=project_id, + provider="langfuse", + ) + + context = ChainContext( + job_id=job_uuid, + chain_id=chain_uuid, + project_id=project_id, + organization_id=organization_id, + langfuse_credentials=langfuse_credentials, + request_metadata=request.request_metadata, + total_blocks=len(request.blocks), + callback_url=str(request.callback_url) if request.callback_url else None, + intermediate_callback_flags=[ + block.intermediate_callback for block in request.blocks + ], + ) + + blocks = [ + ChainBlock( + config=block.config, + index=i, + context=context, + include_provider_raw_response=block.include_provider_raw_response, + ) + for i, block in enumerate(request.blocks) + ] + + chain = LLMChain(blocks) + + executor = ChainExecutor(chain=chain, context=context, request=request) + return executor.run() + + except Exception as e: + logger.error( + f"[execute_chain_job] Failed: {e} | job_id={job_uuid}", + exc_info=True, + ) + + if chain_uuid: + try: + with Session(engine) as session: + 
update_llm_chain_status( + session, + chain_id=chain_uuid, + status=ChainStatus.FAILED, + error=str(e), + ) + except Exception: + logger.error( + f"[execute_chain_job] Failed to update chain status: {e} | " + f"chain_id={chain_uuid}", + exc_info=True, + ) + + callback_response = APIResponse.failure_response( + error="Unexpected error occurred", + metadata=request.request_metadata, + ) + return handle_job_error(job_uuid, request.callback_url, callback_response) From 6451bb03c45017aa71c78a389384adff84e427fa Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Sat, 21 Feb 2026 15:54:59 +0530 Subject: [PATCH 02/15] LLM Chain: Add documentation and update endpoint description for chain execution --- backend/app/api/docs/llm/llm_chain.md | 60 +++++++++++++++++++++++++++ backend/app/api/routes/llm_chain.py | 2 +- backend/app/services/llm/jobs.py | 11 +++-- 3 files changed, 68 insertions(+), 5 deletions(-) create mode 100644 backend/app/api/docs/llm/llm_chain.md diff --git a/backend/app/api/docs/llm/llm_chain.md b/backend/app/api/docs/llm/llm_chain.md new file mode 100644 index 000000000..d6c17893c --- /dev/null +++ b/backend/app/api/docs/llm/llm_chain.md @@ -0,0 +1,60 @@ +Execute a chain of LLM calls sequentially, where each block's output becomes the next block's input. + +This endpoint initiates an asynchronous LLM chain job. The request is queued +for processing, and results are delivered via the callback URL when complete. 
+ +### Key Parameters + +**`query`** (required) - Initial query input for the first block in the chain: +- `input` (required, string, min 1 char): User question/prompt/query +- `conversation` (optional, object): Conversation configuration + - `id` (optional, string): Existing conversation ID to continue + - `auto_create` (optional, boolean, default false): Create new conversation if no ID provided + - **Note**: Cannot specify both `id` and `auto_create=true` + + +**`blocks`** (required, array, min 1 block) - Ordered list of blocks to execute sequentially. Each block contains: + +- `config` (required) - Configuration for this block's LLM call (just choose one mode): + + - **Mode 1: Stored Configuration** + - `id` (UUID): Configuration ID + - `version` (integer >= 1): Version number + - **Both required together** + - **Note**: When using stored configuration, do not include the `blob` field in the request body + + - **Mode 2: Ad-hoc Configuration** + - `blob` (object): Complete configuration object + - `completion` (required, object): Completion configuration + - `provider` (required, string): Provider type - either `"openai"` (Kaapi abstraction) or `"openai-native"` (pass-through) + - `params` (required, object): Parameters structure depends on provider type (see schema for detailed structure) + - `prompt_template` (optional, object): Template for text interpolation + - `template` (required, string): Template string with `{{input}}` placeholder — replaced with the block's input before execution + - **Note** + - When using ad-hoc configuration, do not include `id` and `version` fields + - When using the Kaapi abstraction, parameters that are not supported by the selected provider or model are automatically suppressed. If any parameters are ignored, a list of warnings is included in the metadata.warnings. 
+ - **Recommendation**: Use stored configs (Mode 1) for production; use ad-hoc configs only for testing/validation + - **Schema**: Check the API schema or examples below for the complete parameter structure for each provider type + +- `include_provider_raw_response` (optional, boolean, default false): + - When true, includes the unmodified raw response from the LLM provider for this block + +- `intermediate_callback` (optional, boolean, default false): + - When true, sends an intermediate callback after this block completes with the block's response, usage, and position in the chain + +**`callback_url`** (optional, HTTPS URL): +- Webhook endpoint to receive the final response and intermediate callbacks +- Must be a valid HTTPS URL +- If not provided, response is only accessible through job status + +**`request_metadata`** (optional, object): +- Custom JSON metadata +- Passed through unchanged in the response + +### Note +- Input guardrails from the first block's config are applied before chain execution starts +- Output guardrails from the last block's config are applied after all blocks complete +- If any block fails, the chain stops immediately — no subsequent blocks are executed +- `warnings` list is automatically added in response metadata when using Kaapi configs if any parameters are suppressed or adjusted (e.g., temperature on reasoning models) + +--- diff --git a/backend/app/api/routes/llm_chain.py b/backend/app/api/routes/llm_chain.py index 0634c2038..92a3cdb4d 100644 --- a/backend/app/api/routes/llm_chain.py +++ b/backend/app/api/routes/llm_chain.py @@ -31,7 +31,7 @@ def llm_callback_notification(body: APIResponse[LLMChainResponse]): @router.post( "/llm/chain", - description=load_description("llm/llm_call.md"), + description=load_description("llm/llm_chain.md"), response_model=APIResponse[Message], callbacks=llm_callback_router.routes, dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], diff --git a/backend/app/services/llm/jobs.py 
b/backend/app/services/llm/jobs.py index cd71e5bfa..a68ab6426 100644 --- a/backend/app/services/llm/jobs.py +++ b/backend/app/services/llm/jobs.py @@ -293,8 +293,6 @@ def apply_output_guardrails( if safe["success"]: result.response.response.output.content.value = safe["data"]["safe_text"] - if safe["data"].get("rephrase_needed"): - return result, result.response.response.output.content.value return result, None return result, safe["error"] @@ -468,7 +466,8 @@ def execute_job( ) -> dict: """Celery task to process an LLM request asynchronously. - Uses centralized functions: apply_input_guardrails, apply_output_guardrails, execute_llm_call. + Returns: + dict: Serialized APIResponse[LLMCallResponse] on success, APIResponse[None] on failure """ request = LLMCallRequest(**request_data) job_id: UUID = UUID(job_id) @@ -594,7 +593,11 @@ def execute_chain_job( task_id: str, task_instance, ) -> dict: - """Celery task entry point for LLM chain execution.""" + """Celery task to process an LLM Chain request asynchronously. 
+ + Returns: + dict: Serialized APIResponse[LLMChainResponse] on success, APIResponse[None] on failure + """ # imports to avoid circular dependency: from app.services.llm.chain.chain import ChainBlock, ChainContext, LLMChain from app.services.llm.chain.executor import ChainExecutor From c9f94e257552c3513d17c2f15ca880a9b4dd60e0 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Sat, 21 Feb 2026 16:09:58 +0530 Subject: [PATCH 03/15] LLM Chain: Move guardrails into execute_llm_call for per-block support and eliminate code duplication --- backend/app/services/llm/chain/executor.py | 69 +-------------------- backend/app/services/llm/jobs.py | 71 +++++++--------------- 2 files changed, 24 insertions(+), 116 deletions(-) diff --git a/backend/app/services/llm/chain/executor.py b/backend/app/services/llm/chain/executor.py index 78808d84c..25208aad6 100644 --- a/backend/app/services/llm/chain/executor.py +++ b/backend/app/services/llm/chain/executor.py @@ -3,23 +3,16 @@ from sqlmodel import Session from app.core.db import engine -from app.crud.config import ConfigVersionCrud from app.crud.jobs import JobCrud from app.crud.llm_chain import update_llm_chain_status from app.models import JobStatus, JobUpdate from app.models.llm.request import ( ChainStatus, - ConfigBlob, LLMChainRequest, ) from app.models.llm.response import LLMChainResponse from app.services.llm.chain.chain import ChainContext, LLMChain from app.services.llm.chain.types import BlockResult -from app.services.llm.jobs import ( - apply_input_guardrails, - apply_output_guardrails, - resolve_config_blob, -) from app.utils import APIResponse, send_callback logger = logging.getLogger(__name__) @@ -44,73 +37,13 @@ def run(self) -> dict: try: self._setup() - first_config_blob, resolve_error = self._resolve_block_config_blob(0) - if resolve_error: - return self._handle_error(resolve_error) - - query, error = apply_input_guardrails( - config_blob=first_config_blob, - 
query=self._request.query, - job_id=self._context.job_id, - project_id=self._context.project_id, - organization_id=self._context.organization_id, - ) - if error: - return self._handle_error(error) - - result = self._chain.execute(query) - - if result.success: - last_config_blob, resolve_error = self._resolve_block_config_blob( - len(self._request.blocks) - 1 - ) - if resolve_error: - return self._handle_error(resolve_error) - - result, error = apply_output_guardrails( - config_blob=last_config_blob, - result=result, - job_id=self._context.job_id, - project_id=self._context.project_id, - organization_id=self._context.organization_id, - ) - if error: - return self._handle_error(error) + result = self._chain.execute(self._request.query) return self._teardown(result) except Exception as e: return self._handle_unexpected_error(e) - def _resolve_block_config_blob( - self, block_index: int - ) -> tuple[ConfigBlob | None, str | None]: - """Resolve a block's config to its ConfigBlob. - - Uses is_stored_config property (same pattern as execute_job in jobs.py): - - Stored config (is_stored_config=True): fetch from DB via resolve_config_blob() - - Ad-hoc config (blob provided): return blob directly - - Returns: - (config_blob, error): ConfigBlob on success, or error string on failure - """ - block = self._request.blocks[block_index] - config = block.config - - if not config.is_stored_config: - return config.blob, None - - with Session(engine) as session: - config_crud = ConfigVersionCrud( - session=session, - project_id=self._context.project_id, - config_id=config.id, - ) - config_blob, error = resolve_config_blob(config_crud, config) - if error: - return None, error - return config_blob, None - def _setup(self) -> None: with Session(engine) as session: JobCrud(session).update( diff --git a/backend/app/services/llm/jobs.py b/backend/app/services/llm/jobs.py index a68ab6426..196bdd60b 100644 --- a/backend/app/services/llm/jobs.py +++ b/backend/app/services/llm/jobs.py @@ -335,6 
+335,16 @@ def execute_llm_call( interpolated = template.replace("{{input}}", query.input.content.value) query.input.content.value = interpolated + query, input_error = apply_input_guardrails( + config_blob=config_blob, + query=query, + job_id=job_id, + project_id=project_id, + organization_id=organization_id, + ) + if input_error: + return BlockResult(error=input_error) + completion_config = config_blob.completion original_provider = completion_config.provider @@ -433,13 +443,24 @@ def execute_llm_call( f"llm_call_id={llm_call_id}", exc_info=True, ) - - return BlockResult( + result = BlockResult( response=response, llm_call_id=llm_call_id, usage=response.usage, ) + result, output_error = apply_output_guardrails( + config_blob=config_blob, + result=result, + job_id=job_id, + project_id=project_id, + organization_id=organization_id, + ) + if output_error: + return BlockResult(error=output_error, llm_call_id=llm_call_id) + + return result + return BlockResult( error=error or "Unknown error occurred", llm_call_id=llm_call_id, @@ -471,8 +492,6 @@ def execute_job( """ request = LLMCallRequest(**request_data) job_id: UUID = UUID(job_id) - config = request.config - config_blob: ConfigBlob | None = None logger.info( f"[execute_job] Starting LLM job execution | job_id={job_id}, task_id={task_id}" @@ -485,36 +504,6 @@ def execute_job( job_id=job_id, job_update=JobUpdate(status=JobStatus.PROCESSING) ) - if config.is_stored_config: - config_crud = ConfigVersionCrud( - session=session, project_id=project_id, config_id=config.id - ) - config_blob, error = resolve_config_blob(config_crud, config) - if error: - callback_response = APIResponse.failure_response( - error=error, - metadata=request.request_metadata, - ) - return handle_job_error( - job_id, request.callback_url, callback_response - ) - else: - config_blob = config.blob - - request.query, input_error = apply_input_guardrails( - config_blob=config_blob, - query=request.query, - job_id=job_id, - project_id=project_id, - 
organization_id=organization_id, - ) - if input_error: - callback_response = APIResponse.failure_response( - error=input_error, - metadata=request.request_metadata, - ) - return handle_job_error(job_id, request.callback_url, callback_response) - langfuse_credentials = get_provider_credential( session=session, org_id=organization_id, @@ -534,20 +523,6 @@ def execute_job( ) if result.success: - result, output_error = apply_output_guardrails( - config_blob=config_blob, - result=result, - job_id=job_id, - project_id=project_id, - organization_id=organization_id, - ) - if output_error: - callback_response = APIResponse.failure_response( - error=output_error, - metadata=request.request_metadata, - ) - return handle_job_error(job_id, request.callback_url, callback_response) - callback_response = APIResponse.success_response( data=result.response, metadata=request.request_metadata ) From baaac95920072e6b1b8a761827104705510ce034 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Sun, 1 Mar 2026 20:29:49 +0530 Subject: [PATCH 04/15] prettify format --- backend/app/services/llm/jobs.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/backend/app/services/llm/jobs.py b/backend/app/services/llm/jobs.py index 6888fe99a..4142f690d 100644 --- a/backend/app/services/llm/jobs.py +++ b/backend/app/services/llm/jobs.py @@ -493,10 +493,7 @@ def execute_llm_call( include_provider_raw_response=include_provider_raw_response, ) except ValueError as ve: - return BlockResult( - error=str(ve), - llm_call_id=llm_call_id - ) + return BlockResult(error=str(ve), llm_call_id=llm_call_id) if response: with Session(engine) as session: From 5177bfb6b0100ee4638ea99e152a082093c93fe4 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Sun, 1 Mar 2026 22:24:05 +0530 Subject: [PATCH 05/15] refactor: update STTLLMParams to allow optional instructions and improve callback logic in 
ChainContext --- backend/app/models/llm/request.py | 2 +- backend/app/services/llm/chain/chain.py | 1 + backend/app/services/llm/jobs.py | 7 ++----- 3 files changed, 4 insertions(+), 6 deletions(-) diff --git a/backend/app/models/llm/request.py b/backend/app/models/llm/request.py index 58c4529a6..ed7f08b55 100644 --- a/backend/app/models/llm/request.py +++ b/backend/app/models/llm/request.py @@ -44,7 +44,7 @@ class TextLLMParams(SQLModel): class STTLLMParams(SQLModel): model: str - instructions: str + instructions: str | None = None input_language: str | None = None output_language: str | None = None response_format: Literal["text"] | None = Field( diff --git a/backend/app/services/llm/chain/chain.py b/backend/app/services/llm/chain/chain.py index 390247d8d..98457fff6 100644 --- a/backend/app/services/llm/chain/chain.py +++ b/backend/app/services/llm/chain/chain.py @@ -70,6 +70,7 @@ def on_block_completed(self, block_index: int, result: BlockResult) -> None: block_index < len(self.intermediate_callback_flags) and self.intermediate_callback_flags[block_index] and self.callback_url + and block_index < self.total_blocks - 1 ): self._send_intermediate_callback(block_index, result) diff --git a/backend/app/services/llm/jobs.py b/backend/app/services/llm/jobs.py index 4142f690d..3efdbbc27 100644 --- a/backend/app/services/llm/jobs.py +++ b/backend/app/services/llm/jobs.py @@ -474,10 +474,7 @@ def execute_llm_call( if query.conversation and query.conversation.id: conversation_id = query.conversation.id - resolved_input, resolve_error = resolve_input(query.input) - if resolve_error: - return BlockResult(error=resolve_error, llm_call_id=llm_call_id) - + # Apply Langfuse observability decorator to provider execute method decorated_execute = observe_llm_execution( credentials=langfuse_credentials, session_id=conversation_id, @@ -485,7 +482,7 @@ def execute_llm_call( # Resolve input and execute LLM (context manager handles cleanup) try: - with resolved_input_context(query) as 
resolved_input: + with resolved_input_context(query.input) as resolved_input: response, error = decorated_execute( completion_config=completion_config, query=query, From 2fb81b1b49d84941d3d4d2b94e4c4abfb2eb2308 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Sun, 1 Mar 2026 22:36:32 +0530 Subject: [PATCH 06/15] feat: add metadata to BlockResult and update job execution to use result metadata --- backend/app/services/llm/chain/types.py | 1 + backend/app/services/llm/jobs.py | 3 ++- backend/app/tests/services/llm/test_jobs.py | 27 ++++++++++++--------- 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/backend/app/services/llm/chain/types.py b/backend/app/services/llm/chain/types.py index 69ab3d02f..7fa0f39d8 100644 --- a/backend/app/services/llm/chain/types.py +++ b/backend/app/services/llm/chain/types.py @@ -12,6 +12,7 @@ class BlockResult: llm_call_id: UUID | None = None usage: Usage | None = None error: str | None = None + metadata: dict | None = None @property def success(self) -> bool: diff --git a/backend/app/services/llm/jobs.py b/backend/app/services/llm/jobs.py index 3efdbbc27..62d4325cc 100644 --- a/backend/app/services/llm/jobs.py +++ b/backend/app/services/llm/jobs.py @@ -514,6 +514,7 @@ def execute_llm_call( response=response, llm_call_id=llm_call_id, usage=response.usage, + metadata=request_metadata, ) result, output_error = apply_output_guardrails( @@ -592,7 +593,7 @@ def execute_job( if result.success: callback_response = APIResponse.success_response( - data=result.response, metadata=request.request_metadata + data=result.response, metadata=result.metadata ) if callback_url_str: send_callback( diff --git a/backend/app/tests/services/llm/test_jobs.py b/backend/app/tests/services/llm/test_jobs.py index 60456e00b..f448be9b2 100644 --- a/backend/app/tests/services/llm/test_jobs.py +++ b/backend/app/tests/services/llm/test_jobs.py @@ -367,7 +367,7 @@ def test_exception_during_execution( 
result = self._execute_job(job_for_execution, db, request_data) assert not result["success"] - assert "Unexpected error during LLM execution" in result["error"] + assert "Unexpected error occurred" in result["error"] def test_exception_during_provider_retrieval( self, db, job_env, job_for_execution, request_data @@ -1108,16 +1108,21 @@ def test_execute_job_fetches_validator_configs_from_blob_refs( result = self._execute_job(job_for_execution, db, request_data) assert result["success"] - mock_fetch_configs.assert_called_once() - _, kwargs = mock_fetch_configs.call_args - input_validator_configs = kwargs["input_validator_configs"] - output_validator_configs = kwargs["output_validator_configs"] - assert [v.validator_config_id for v in input_validator_configs] == [ - UUID(VALIDATOR_CONFIG_ID_1) - ] - assert [v.validator_config_id for v in output_validator_configs] == [ - UUID(VALIDATOR_CONFIG_ID_2) - ] + assert mock_fetch_configs.call_count == 2 + + # First call: input guardrails + _, input_kwargs = mock_fetch_configs.call_args_list[0] + assert [ + v.validator_config_id for v in input_kwargs["input_validator_configs"] + ] == [UUID(VALIDATOR_CONFIG_ID_1)] + assert input_kwargs["output_validator_configs"] is None + + # Second call: output guardrails + _, output_kwargs = mock_fetch_configs.call_args_list[1] + assert output_kwargs["input_validator_configs"] is None + assert [ + v.validator_config_id for v in output_kwargs["output_validator_configs"] + ] == [UUID(VALIDATOR_CONFIG_ID_2)] def test_execute_job_continues_when_no_validator_configs_resolved( self, db, job_env, job_for_execution From 113488a6bd4739c9ecbae28e3eb41a9ce40e7dce Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Mon, 2 Mar 2026 10:43:21 +0530 Subject: [PATCH 07/15] feat: add tests for LLM chain execution and job handling --- backend/app/tests/crud/test_llm_chain.py | 153 ++++++++ backend/app/tests/services/llm/test_chain.py | 356 ++++++++++++++++++ 
.../tests/services/llm/test_chain_executor.py | 215 +++++++++++ backend/app/tests/services/llm/test_jobs.py | 206 +++++++++- 4 files changed, 929 insertions(+), 1 deletion(-) create mode 100644 backend/app/tests/crud/test_llm_chain.py create mode 100644 backend/app/tests/services/llm/test_chain.py create mode 100644 backend/app/tests/services/llm/test_chain_executor.py diff --git a/backend/app/tests/crud/test_llm_chain.py b/backend/app/tests/crud/test_llm_chain.py new file mode 100644 index 000000000..84324f86c --- /dev/null +++ b/backend/app/tests/crud/test_llm_chain.py @@ -0,0 +1,153 @@ +import pytest +from uuid import uuid4 + +from sqlmodel import Session + +from app.crud import JobCrud +from app.crud.llm_chain import ( + create_llm_chain, + update_llm_chain_status, + update_llm_chain_block_completed, +) +from app.models import JobType +from app.models.llm.request import ChainStatus +from app.tests.utils.utils import get_project + + +class TestCreateLlmChain: + def test_creates_chain_record(self, db: Session): + project = get_project(db) + job = JobCrud(session=db).create( + job_type=JobType.LLM_CHAIN, trace_id="test-trace" + ) + db.commit() + + chain = create_llm_chain( + db, + job_id=job.id, + project_id=project.id, + organization_id=project.organization_id, + total_blocks=3, + input="Test input", + configs=[{"completion": {"provider": "openai-native"}}], + ) + + assert chain.id is not None + assert chain.job_id == job.id + assert chain.project_id == project.id + assert chain.status == ChainStatus.PENDING + assert chain.total_blocks == 3 + assert chain.number_of_blocks_processed == 0 + assert chain.input == "Test input" + assert chain.block_sequences == [] + + +class TestUpdateLlmChainStatus: + @pytest.fixture + def chain(self, db: Session): + project = get_project(db) + job = JobCrud(session=db).create( + job_type=JobType.LLM_CHAIN, trace_id="test-trace" + ) + db.commit() + chain = create_llm_chain( + db, + job_id=job.id, + project_id=project.id, + 
organization_id=project.organization_id, + total_blocks=2, + input="hello", + configs=[], + ) + return chain + + def test_update_to_running(self, db: Session, chain): + updated = update_llm_chain_status( + db, chain_id=chain.id, status=ChainStatus.RUNNING + ) + + assert updated.status == ChainStatus.RUNNING + assert updated.started_at is not None + + def test_update_to_completed(self, db: Session, chain): + output = {"type": "text", "content": {"value": "result"}} + usage = {"input_tokens": 10, "output_tokens": 20, "total_tokens": 30} + + updated = update_llm_chain_status( + db, + chain_id=chain.id, + status=ChainStatus.COMPLETED, + output=output, + total_usage=usage, + ) + + assert updated.status == ChainStatus.COMPLETED + assert updated.output == output + assert updated.total_usage == usage + assert updated.completed_at is not None + + def test_update_to_failed(self, db: Session, chain): + usage = {"input_tokens": 5, "output_tokens": 0, "total_tokens": 5} + + updated = update_llm_chain_status( + db, + chain_id=chain.id, + status=ChainStatus.FAILED, + error="Provider timeout", + total_usage=usage, + ) + + assert updated.status == ChainStatus.FAILED + assert updated.error == "Provider timeout" + assert updated.total_usage == usage + assert updated.completed_at is not None + + def test_raises_for_missing_chain(self, db: Session): + with pytest.raises(ValueError, match="LLM chain not found"): + update_llm_chain_status(db, chain_id=uuid4(), status=ChainStatus.RUNNING) + + +class TestUpdateLlmChainBlockCompleted: + @pytest.fixture + def chain(self, db: Session): + project = get_project(db) + job = JobCrud(session=db).create( + job_type=JobType.LLM_CHAIN, trace_id="test-trace" + ) + db.commit() + chain = create_llm_chain( + db, + job_id=job.id, + project_id=project.id, + organization_id=project.organization_id, + total_blocks=3, + input="hello", + configs=[], + ) + return chain + + def test_appends_llm_call_id(self, db: Session, chain): + call_id = uuid4() + + updated = 
update_llm_chain_block_completed( + db, chain_id=chain.id, llm_call_id=call_id + ) + + assert str(call_id) in updated.block_sequences + assert updated.number_of_blocks_processed == 1 + + def test_appends_multiple_blocks(self, db: Session, chain): + call_id_1 = uuid4() + call_id_2 = uuid4() + + update_llm_chain_block_completed(db, chain_id=chain.id, llm_call_id=call_id_1) + updated = update_llm_chain_block_completed( + db, chain_id=chain.id, llm_call_id=call_id_2 + ) + + assert len(updated.block_sequences) == 2 + assert updated.number_of_blocks_processed == 2 + + def test_raises_for_missing_chain(self, db: Session): + with pytest.raises(ValueError, match="LLM chain not found"): + update_llm_chain_block_completed(db, chain_id=uuid4(), llm_call_id=uuid4()) diff --git a/backend/app/tests/services/llm/test_chain.py b/backend/app/tests/services/llm/test_chain.py new file mode 100644 index 000000000..d93380d84 --- /dev/null +++ b/backend/app/tests/services/llm/test_chain.py @@ -0,0 +1,356 @@ +from unittest.mock import patch, MagicMock +from uuid import uuid4 + +import pytest + +from app.models.llm.request import ( + LLMCallConfig, + ConfigBlob, + NativeCompletionConfig, + QueryParams, + TextInput, + TextContent, + AudioInput, +) +from app.models.llm.response import ( + LLMCallResponse, + LLMResponse, + Usage, + TextOutput, + TextContent as ResponseTextContent, + AudioOutput, + AudioContent, +) +from app.services.llm.chain.chain import ( + ChainBlock, + ChainContext, + LLMChain, + result_to_query, +) +from app.services.llm.chain.types import BlockResult + + +@pytest.fixture +def context(): + return ChainContext( + job_id=uuid4(), + chain_id=uuid4(), + project_id=1, + organization_id=1, + callback_url="https://example.com/callback", + total_blocks=3, + intermediate_callback_flags=[True, True, False], + ) + + +@pytest.fixture +def text_response(): + return LLMCallResponse( + response=LLMResponse( + provider_response_id="resp-1", + conversation_id=None, + model="gpt-4", + 
provider="openai", + output=TextOutput(content=ResponseTextContent(value="Hello world")), + ), + usage=Usage(input_tokens=10, output_tokens=20, total_tokens=30), + provider_raw_response=None, + ) + + +@pytest.fixture +def audio_response(): + return LLMCallResponse( + response=LLMResponse( + provider_response_id="resp-2", + conversation_id=None, + model="gemini", + provider="google", + output=AudioOutput( + content=AudioContent( + format="base64", + value="audio-data-base64", + mime_type="audio/wav", + ) + ), + ), + usage=Usage(input_tokens=5, output_tokens=15, total_tokens=20), + provider_raw_response=None, + ) + + +def make_config(): + return LLMCallConfig( + blob=ConfigBlob( + completion=NativeCompletionConfig( + provider="openai-native", + type="text", + params={"model": "gpt-4"}, + ) + ) + ) + + +class TestResultToQuery: + def test_text_output_to_query(self, text_response): + result = BlockResult(response=text_response, usage=text_response.usage) + + query = result_to_query(result) + + assert isinstance(query.input, TextInput) + assert query.input.content.value == "Hello world" + + def test_audio_output_to_query(self, audio_response): + result = BlockResult(response=audio_response, usage=audio_response.usage) + + query = result_to_query(result) + + assert isinstance(query.input, AudioInput) + assert query.input.content.value == "audio-data-base64" + + def test_unsupported_output_type_raises(self): + mock_response = MagicMock() + mock_response.response.output.type = "unknown" + mock_response.response.output.__class__ = type("Unknown", (), {}) + result = BlockResult(response=mock_response, usage=MagicMock()) + + with pytest.raises(ValueError, match="Cannot chain output type"): + result_to_query(result) + + +class TestChainContext: + def test_aggregates_usage(self, context): + usage = Usage(input_tokens=10, output_tokens=20, total_tokens=30) + result = BlockResult( + response=MagicMock(), llm_call_id=uuid4(), usage=usage, error=None + ) + + with 
patch("app.services.llm.chain.chain.Session"): + context.on_block_completed(0, result) + + assert context.aggregated_usage.input_tokens == 10 + assert context.aggregated_usage.output_tokens == 20 + assert context.aggregated_usage.total_tokens == 30 + + def test_aggregates_usage_across_blocks(self, context): + usage1 = Usage(input_tokens=10, output_tokens=20, total_tokens=30) + usage2 = Usage(input_tokens=5, output_tokens=15, total_tokens=20) + + result1 = BlockResult( + response=MagicMock(), llm_call_id=uuid4(), usage=usage1, error=None + ) + result2 = BlockResult( + response=MagicMock(), llm_call_id=uuid4(), usage=usage2, error=None + ) + + with patch("app.services.llm.chain.chain.Session"): + context.on_block_completed(0, result1) + context.on_block_completed(1, result2) + + assert context.aggregated_usage.input_tokens == 15 + assert context.aggregated_usage.total_tokens == 50 + + def test_updates_db_on_success(self, context): + llm_call_id = uuid4() + result = BlockResult( + response=MagicMock(), llm_call_id=llm_call_id, usage=MagicMock(), error=None + ) + + with patch("app.services.llm.chain.chain.Session") as mock_session, patch( + "app.services.llm.chain.chain.update_llm_chain_block_completed" + ) as mock_update: + mock_session.return_value.__enter__.return_value = MagicMock() + context.on_block_completed(0, result) + + mock_update.assert_called_once_with( + mock_session.return_value.__enter__.return_value, + chain_id=context.chain_id, + llm_call_id=llm_call_id, + ) + + def test_sends_intermediate_callback(self, context, text_response): + result = BlockResult( + response=text_response, + llm_call_id=uuid4(), + usage=text_response.usage, + error=None, + ) + + with ( + patch("app.services.llm.chain.chain.Session") as mock_session, + patch("app.services.llm.chain.chain.update_llm_chain_block_completed"), + patch("app.services.llm.chain.chain.send_callback") as mock_callback, + ): + mock_session.return_value.__enter__.return_value = MagicMock() + 
context.on_block_completed(0, result) + + mock_callback.assert_called_once() + call_kwargs = mock_callback.call_args[1] + assert call_kwargs["callback_url"] == "https://example.com/callback" + + def test_skips_intermediate_callback_for_last_block(self, context, text_response): + result = BlockResult( + response=text_response, + llm_call_id=uuid4(), + usage=text_response.usage, + error=None, + ) + + with ( + patch("app.services.llm.chain.chain.Session") as mock_session, + patch("app.services.llm.chain.chain.update_llm_chain_block_completed"), + patch("app.services.llm.chain.chain.send_callback") as mock_callback, + ): + mock_session.return_value.__enter__.return_value = MagicMock() + # Block index 2 = last block (total_blocks=3) + context.on_block_completed(2, result) + + mock_callback.assert_not_called() + + def test_skips_intermediate_callback_when_flag_false(self, context, text_response): + context.intermediate_callback_flags = [False, True, False] + result = BlockResult( + response=text_response, + llm_call_id=uuid4(), + usage=text_response.usage, + error=None, + ) + + with ( + patch("app.services.llm.chain.chain.Session") as mock_session, + patch("app.services.llm.chain.chain.update_llm_chain_block_completed"), + patch("app.services.llm.chain.chain.send_callback") as mock_callback, + ): + mock_session.return_value.__enter__.return_value = MagicMock() + context.on_block_completed(0, result) + + mock_callback.assert_not_called() + + def test_skips_db_update_on_error(self, context): + result = BlockResult(error="Block failed", usage=MagicMock()) + + with patch( + "app.services.llm.chain.chain.update_llm_chain_block_completed" + ) as mock_update: + context.on_block_completed(0, result) + mock_update.assert_not_called() + + def test_intermediate_callback_exception_is_swallowed(self, context, text_response): + result = BlockResult( + response=text_response, + llm_call_id=uuid4(), + usage=text_response.usage, + error=None, + ) + + with ( + 
patch("app.services.llm.chain.chain.Session") as mock_session, + patch("app.services.llm.chain.chain.update_llm_chain_block_completed"), + patch( + "app.services.llm.chain.chain.send_callback", + side_effect=Exception("Connection error"), + ), + ): + mock_session.return_value.__enter__.return_value = MagicMock() + # Should not raise + context.on_block_completed(0, result) + + +class TestChainBlock: + def test_execute_single_block(self, context, text_response): + query = QueryParams(input="test input") + config = make_config() + block = ChainBlock(config=config, index=0, context=context) + + with patch( + "app.services.llm.chain.chain.execute_llm_call" + ) as mock_execute, patch.object(context, "on_block_completed"): + mock_execute.return_value = BlockResult( + response=text_response, usage=text_response.usage + ) + + result = block.execute(query) + + assert result.success + mock_execute.assert_called_once() + + def test_execute_chains_to_next_block(self, context, text_response): + query = QueryParams(input="test input") + config = make_config() + block1 = ChainBlock(config=config, index=0, context=context) + block2 = ChainBlock(config=config, index=1, context=context) + block1.link(block2) + + with patch( + "app.services.llm.chain.chain.execute_llm_call" + ) as mock_execute, patch.object(context, "on_block_completed"): + mock_execute.return_value = BlockResult( + response=text_response, usage=text_response.usage + ) + + result = block1.execute(query) + + assert mock_execute.call_count == 2 + + def test_execute_stops_on_failure(self, context): + query = QueryParams(input="test input") + config = make_config() + block1 = ChainBlock(config=config, index=0, context=context) + block2 = ChainBlock(config=config, index=1, context=context) + block1.link(block2) + + with patch( + "app.services.llm.chain.chain.execute_llm_call" + ) as mock_execute, patch.object(context, "on_block_completed"): + mock_execute.return_value = BlockResult(error="Provider error") + + result = 
block1.execute(query) + + assert not result.success + assert result.error == "Provider error" + mock_execute.assert_called_once() + + +class TestLLMChain: + def test_execute_empty_chain(self): + chain = LLMChain([]) + query = QueryParams(input="test") + + result = chain.execute(query) + + assert not result.success + assert result.error == "Chain has no blocks" + + def test_execute_single_block_chain(self, context, text_response): + config = make_config() + block = ChainBlock(config=config, index=0, context=context) + chain = LLMChain([block]) + + with patch( + "app.services.llm.chain.chain.execute_llm_call" + ) as mock_execute, patch.object(context, "on_block_completed"): + mock_execute.return_value = BlockResult( + response=text_response, usage=text_response.usage + ) + + result = chain.execute(QueryParams(input="hello")) + + assert result.success + mock_execute.assert_called_once() + + def test_execute_multi_block_chain(self, context, text_response): + config = make_config() + blocks = [ChainBlock(config=config, index=i, context=context) for i in range(3)] + chain = LLMChain(blocks) + + with patch( + "app.services.llm.chain.chain.execute_llm_call" + ) as mock_execute, patch.object(context, "on_block_completed"): + mock_execute.return_value = BlockResult( + response=text_response, usage=text_response.usage + ) + + result = chain.execute(QueryParams(input="hello")) + + assert result.success + assert mock_execute.call_count == 3 diff --git a/backend/app/tests/services/llm/test_chain_executor.py b/backend/app/tests/services/llm/test_chain_executor.py new file mode 100644 index 000000000..e8fdc31a9 --- /dev/null +++ b/backend/app/tests/services/llm/test_chain_executor.py @@ -0,0 +1,215 @@ +from unittest.mock import patch, MagicMock +from uuid import uuid4 + +import pytest + +from app.models.llm.request import ( + LLMChainRequest, + LLMCallConfig, + ConfigBlob, + NativeCompletionConfig, + QueryParams, + ChainStatus, +) +from app.models.llm.request import ChainBlock as 
ChainBlockModel +from app.models.llm.response import ( + LLMCallResponse, + LLMResponse, + Usage, + TextOutput, + TextContent, +) +from app.models import JobStatus +from app.services.llm.chain.chain import ChainBlock, ChainContext, LLMChain +from app.services.llm.chain.executor import ChainExecutor +from app.services.llm.chain.types import BlockResult + + +@pytest.fixture +def context(): + return ChainContext( + job_id=uuid4(), + chain_id=uuid4(), + project_id=1, + organization_id=1, + callback_url="https://example.com/callback", + total_blocks=1, + ) + + +@pytest.fixture +def request_obj(): + return LLMChainRequest( + query=QueryParams(input="hello"), + blocks=[ + ChainBlockModel( + config=LLMCallConfig( + blob=ConfigBlob( + completion=NativeCompletionConfig( + provider="openai-native", + type="text", + params={"model": "gpt-4"}, + ) + ) + ) + ) + ], + callback_url="https://example.com/callback", + ) + + +@pytest.fixture +def text_response(): + return LLMCallResponse( + response=LLMResponse( + provider_response_id="resp-1", + conversation_id=None, + model="gpt-4", + provider="openai", + output=TextOutput(content=TextContent(value="Response text")), + ), + usage=Usage(input_tokens=10, output_tokens=20, total_tokens=30), + provider_raw_response=None, + ) + + +@pytest.fixture +def success_result(text_response): + return BlockResult( + response=text_response, + llm_call_id=uuid4(), + usage=text_response.usage, + ) + + +@pytest.fixture +def failure_result(): + return BlockResult(error="Provider failed") + + +class TestChainExecutor: + def _make_executor(self, context, request_obj, chain_result): + mock_chain = MagicMock(spec=LLMChain) + mock_chain.execute.return_value = chain_result + return ChainExecutor(chain=mock_chain, context=context, request=request_obj) + + def test_run_success_with_callback(self, context, request_obj, success_result): + executor = self._make_executor(context, request_obj, success_result) + + with ( + 
patch("app.services.llm.chain.executor.Session") as mock_session, + patch("app.services.llm.chain.executor.send_callback") as mock_callback, + patch( + "app.services.llm.chain.executor.update_llm_chain_status" + ) as mock_chain_status, + ): + mock_session.return_value.__enter__.return_value = MagicMock() + + result = executor.run() + + assert result["success"] is True + mock_callback.assert_called_once() + # Verify chain status updated to COMPLETED + completed_call = [ + c + for c in mock_chain_status.call_args_list + if c[1].get("status") == ChainStatus.COMPLETED + ] + assert len(completed_call) == 1 + + def test_run_success_without_callback(self, context, request_obj, success_result): + request_obj.callback_url = None + context.callback_url = None + executor = self._make_executor(context, request_obj, success_result) + + with ( + patch("app.services.llm.chain.executor.Session") as mock_session, + patch("app.services.llm.chain.executor.send_callback") as mock_callback, + patch("app.services.llm.chain.executor.update_llm_chain_status"), + ): + mock_session.return_value.__enter__.return_value = MagicMock() + + result = executor.run() + + assert result["success"] is True + mock_callback.assert_not_called() + + def test_run_failure_updates_status(self, context, request_obj, failure_result): + executor = self._make_executor(context, request_obj, failure_result) + + with ( + patch("app.services.llm.chain.executor.Session") as mock_session, + patch("app.services.llm.chain.executor.send_callback"), + patch( + "app.services.llm.chain.executor.update_llm_chain_status" + ) as mock_chain_status, + ): + mock_session.return_value.__enter__.return_value = MagicMock() + + result = executor.run() + + assert result["success"] is False + assert result["error"] == "Provider failed" + # Verify chain status updated to FAILED + failed_call = [ + c + for c in mock_chain_status.call_args_list + if c[1].get("status") == ChainStatus.FAILED + ] + assert len(failed_call) == 1 + + def 
test_run_failure_sends_callback(self, context, request_obj, failure_result): + executor = self._make_executor(context, request_obj, failure_result) + + with ( + patch("app.services.llm.chain.executor.Session") as mock_session, + patch("app.services.llm.chain.executor.send_callback") as mock_callback, + patch("app.services.llm.chain.executor.update_llm_chain_status"), + ): + mock_session.return_value.__enter__.return_value = MagicMock() + + result = executor.run() + + mock_callback.assert_called_once() + + def test_run_unexpected_exception_handled(self, context, request_obj): + mock_chain = MagicMock(spec=LLMChain) + mock_chain.execute.side_effect = RuntimeError("Something broke") + executor = ChainExecutor(chain=mock_chain, context=context, request=request_obj) + + with ( + patch("app.services.llm.chain.executor.Session") as mock_session, + patch("app.services.llm.chain.executor.send_callback"), + patch("app.services.llm.chain.executor.update_llm_chain_status"), + ): + mock_session.return_value.__enter__.return_value = MagicMock() + + result = executor.run() + + assert result["success"] is False + assert "Unexpected error occurred" in result["error"] + + def test_setup_updates_job_and_chain_status( + self, context, request_obj, success_result + ): + executor = self._make_executor(context, request_obj, success_result) + + with ( + patch("app.services.llm.chain.executor.Session") as mock_session, + patch("app.services.llm.chain.executor.send_callback"), + patch( + "app.services.llm.chain.executor.update_llm_chain_status" + ) as mock_chain_status, + patch("app.services.llm.chain.executor.JobCrud") as mock_job_crud, + ): + mock_session.return_value.__enter__.return_value = MagicMock() + + executor.run() + + # _setup should set chain to RUNNING + running_calls = [ + c + for c in mock_chain_status.call_args_list + if c[1].get("status") == ChainStatus.RUNNING + ] + assert len(running_calls) == 1 diff --git a/backend/app/tests/services/llm/test_jobs.py 
b/backend/app/tests/services/llm/test_jobs.py index f448be9b2..8cef08e96 100644 --- a/backend/app/tests/services/llm/test_jobs.py +++ b/backend/app/tests/services/llm/test_jobs.py @@ -23,11 +23,14 @@ # KaapiLLMParams, KaapiCompletionConfig, ) -from app.models.llm.request import ConfigBlob, LLMCallConfig +from app.models.llm.request import ConfigBlob, LLMCallConfig, LLMChainRequest +from app.models.llm.request import ChainBlock as ChainBlockModel from app.services.llm.jobs import ( start_job, + start_chain_job, handle_job_error, execute_job, + execute_chain_job, resolve_config_blob, ) from app.tests.utils.utils import get_project @@ -1161,6 +1164,207 @@ def test_execute_job_continues_when_no_validator_configs_resolved( mock_guardrails.assert_not_called() +class TestStartChainJob: + """Test cases for the start_chain_job function.""" + + @pytest.fixture + def chain_request(self): + return LLMChainRequest( + query=QueryParams(input="Test query"), + blocks=[ + ChainBlockModel( + config=LLMCallConfig( + blob=ConfigBlob( + completion=NativeCompletionConfig( + provider="openai-native", + type="text", + params={"model": "gpt-4"}, + ) + ) + ) + ) + ], + ) + + def test_start_chain_job_success(self, db: Session, chain_request): + project = get_project(db) + + with ( + patch("app.services.llm.jobs.start_high_priority_job") as mock_schedule, + patch("app.services.llm.jobs.JobCrud") as mock_job_crud_class, + ): + mock_schedule.return_value = "fake-task-id" + mock_job = MagicMock() + mock_job.id = uuid4() + mock_job.job_type = JobType.LLM_CHAIN + mock_job.status = JobStatus.PENDING + mock_job_crud_class.return_value.create.return_value = mock_job + + job_id = start_chain_job( + db, chain_request, project.id, project.organization_id + ) + + assert job_id == mock_job.id + mock_schedule.assert_called_once() + _, kwargs = mock_schedule.call_args + assert kwargs["function_path"] == "app.services.llm.jobs.execute_chain_job" + + def test_start_chain_job_celery_failure(self, db: Session, 
chain_request): + project = get_project(db) + + with ( + patch("app.services.llm.jobs.start_high_priority_job") as mock_schedule, + patch("app.services.llm.jobs.JobCrud") as mock_job_crud_class, + ): + mock_schedule.side_effect = Exception("Celery connection failed") + mock_job = MagicMock() + mock_job.id = uuid4() + mock_job_crud_class.return_value.create.return_value = mock_job + + with pytest.raises(HTTPException) as exc_info: + start_chain_job(db, chain_request, project.id, project.organization_id) + + assert exc_info.value.status_code == 500 + assert "Internal server error while executing LLM chain job" in str( + exc_info.value.detail + ) + + +class TestExecuteChainJob: + """Test suite for execute_chain_job.""" + + @pytest.fixture + def chain_request_data(self): + return { + "query": {"input": "Test query"}, + "blocks": [ + { + "config": { + "blob": { + "completion": { + "provider": "openai-native", + "type": "text", + "params": {"model": "gpt-4"}, + } + } + }, + } + ], + } + + @pytest.fixture + def mock_llm_response(self): + return LLMCallResponse( + response=LLMResponse( + provider_response_id="resp-123", + conversation_id=None, + model="gpt-4", + provider="openai", + output=TextOutput(content=TextContent(value="Test response")), + ), + usage=Usage(input_tokens=10, output_tokens=20, total_tokens=30), + provider_raw_response=None, + ) + + def _execute_chain_job(self, request_data): + return execute_chain_job( + request_data=request_data, + project_id=1, + organization_id=1, + job_id=str(uuid4()), + task_id="task-123", + task_instance=None, + ) + + def test_success_flow(self, chain_request_data, mock_llm_response): + from app.services.llm.chain.types import BlockResult + + with ( + patch("app.services.llm.jobs.Session") as mock_session, + patch("app.services.llm.jobs.create_llm_chain") as mock_create_chain, + patch("app.services.llm.jobs.get_provider_credential") as mock_creds, + patch("app.services.llm.chain.executor.Session") as mock_executor_session, + 
patch("app.services.llm.chain.executor.send_callback"), + patch("app.services.llm.chain.executor.update_llm_chain_status"), + patch("app.services.llm.chain.chain.execute_llm_call") as mock_execute_llm, + patch("app.services.llm.chain.chain.Session"), + ): + mock_session.return_value.__enter__.return_value = MagicMock() + mock_session.return_value.__exit__.return_value = None + mock_executor_session.return_value.__enter__.return_value = MagicMock() + mock_executor_session.return_value.__exit__.return_value = None + + mock_chain_record = MagicMock() + mock_chain_record.id = uuid4() + mock_create_chain.return_value = mock_chain_record + mock_creds.return_value = None + + mock_execute_llm.return_value = BlockResult( + response=mock_llm_response, + llm_call_id=uuid4(), + usage=mock_llm_response.usage, + ) + + result = self._execute_chain_job(chain_request_data) + + assert result["success"] is True + + def test_exception_during_chain_creation(self, chain_request_data): + with ( + patch("app.services.llm.jobs.Session") as mock_session, + patch( + "app.services.llm.jobs.create_llm_chain", + side_effect=Exception("DB error"), + ), + patch("app.services.llm.jobs.handle_job_error") as mock_handle_error, + ): + mock_session.return_value.__enter__.return_value = MagicMock() + mock_session.return_value.__exit__.return_value = None + mock_handle_error.return_value = { + "success": False, + "error": "Unexpected error occurred", + } + + result = self._execute_chain_job(chain_request_data) + + assert result["success"] is False + + def test_chain_status_updated_to_failed_on_error(self, chain_request_data): + chain_id = uuid4() + + with ( + patch("app.services.llm.jobs.Session") as mock_session, + patch("app.services.llm.jobs.create_llm_chain") as mock_create_chain, + patch("app.services.llm.jobs.get_provider_credential") as mock_creds, + patch( + "app.services.llm.jobs.update_llm_chain_status" + ) as mock_update_status, + patch("app.services.llm.jobs.handle_job_error") as 
mock_handle_error, + patch( + "app.services.llm.chain.chain.LLMChain", + side_effect=Exception("Chain init error"), + ), + ): + mock_session.return_value.__enter__.return_value = MagicMock() + mock_session.return_value.__exit__.return_value = None + + mock_chain_record = MagicMock() + mock_chain_record.id = chain_id + mock_create_chain.return_value = mock_chain_record + mock_creds.return_value = None + mock_handle_error.return_value = { + "success": False, + "error": "Unexpected error occurred", + } + + result = self._execute_chain_job(chain_request_data) + + mock_update_status.assert_called_once() + _, kwargs = mock_update_status.call_args + assert kwargs["chain_id"] == chain_id + assert kwargs["status"].value == "failed" + + class TestResolveConfigBlob: """Test suite for resolve_config_blob function.""" From 64214656c9c60e32de2da4252283d83c1d7d2380 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Mon, 2 Mar 2026 21:36:37 +0530 Subject: [PATCH 08/15] fix: correct variable name from job_id to job_uuid in execute_job function --- backend/app/services/llm/jobs.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/backend/app/services/llm/jobs.py b/backend/app/services/llm/jobs.py index b7a2f0226..221b8b32d 100644 --- a/backend/app/services/llm/jobs.py +++ b/backend/app/services/llm/jobs.py @@ -588,7 +588,7 @@ def execute_job( result = execute_llm_call( config=request.config, query=request.query, - job_id=job_id, + job_id=job_uuid, project_id=project_id, organization_id=organization_id, request_metadata=request.request_metadata, @@ -608,7 +608,7 @@ def execute_job( with Session(engine) as session: JobCrud(session=session).update( - job_id=job_id, job_update=JobUpdate(status=JobStatus.SUCCESS) + job_id=job_uuid, job_update=JobUpdate(status=JobStatus.SUCCESS) ) logger.info( f"[execute_job] Successfully completed LLM job | job_id={job_id}, " @@ -631,7 +631,7 @@ def execute_job( f"[execute_job] 
Unexpected error: {str(e)} | job_id={job_uuid}, task_id={task_id}", exc_info=True, ) - return handle_job_error(job_id, request.callback_url, callback_response) + return handle_job_error(job_uuid, callback_url_str, callback_response) def execute_chain_job( From 19d6f5888c77ed1969dfe70ea98626c00fecdf71 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Thu, 5 Mar 2026 15:18:53 +0530 Subject: [PATCH 09/15] refactor: streamline LLM chain execution and enhance callback handling --- backend/app/api/docs/llm/llm_chain.md | 2 - backend/app/models/llm/request.py | 14 +- backend/app/services/llm/chain/chain.py | 155 +++---------- backend/app/services/llm/chain/executor.py | 64 +++++- backend/app/services/llm/jobs.py | 67 +----- backend/app/tests/services/llm/test_chain.py | 217 ++++-------------- .../tests/services/llm/test_chain_executor.py | 166 ++++++++++++++ backend/app/tests/services/llm/test_jobs.py | 2 +- 8 files changed, 316 insertions(+), 371 deletions(-) diff --git a/backend/app/api/docs/llm/llm_chain.md b/backend/app/api/docs/llm/llm_chain.md index d6c17893c..1d17f24bf 100644 --- a/backend/app/api/docs/llm/llm_chain.md +++ b/backend/app/api/docs/llm/llm_chain.md @@ -52,8 +52,6 @@ for processing, and results are delivered via the callback URL when complete. 
- Passed through unchanged in the response ### Note -- Input guardrails from the first block's config are applied before chain execution starts -- Output guardrails from the last block's config are applied after all blocks complete - If any block fails, the chain stops immediately — no subsequent blocks are executed - `warnings` list is automatically added in response metadata when using Kaapi configs if any parameters are suppressed or adjusted (e.g., temperature on reasoning models) diff --git a/backend/app/models/llm/request.py b/backend/app/models/llm/request.py index eaa834440..cc8b11e81 100644 --- a/backend/app/models/llm/request.py +++ b/backend/app/models/llm/request.py @@ -1,19 +1,13 @@ +from datetime import datetime from enum import Enum from typing import Annotated, Any, Literal, Union - from uuid import UUID, uuid4 -from sqlmodel import Field, SQLModel -from pydantic import Discriminator, model_validator, HttpUrl -from datetime import datetime -from app.core.util import now import sqlalchemy as sa -from typing import Annotated, Any, List, Literal, Union -from uuid import UUID, uuid4 -from pydantic import model_validator, HttpUrl -from datetime import datetime +from pydantic import HttpUrl, model_validator from sqlalchemy.dialects.postgresql import JSONB -from sqlmodel import Field, SQLModel, Index, text +from sqlmodel import Field, Index, SQLModel, text + from app.core.util import now diff --git a/backend/app/services/llm/chain/chain.py b/backend/app/services/llm/chain/chain.py index 98457fff6..ad0503675 100644 --- a/backend/app/services/llm/chain/chain.py +++ b/backend/app/services/llm/chain/chain.py @@ -1,12 +1,8 @@ import logging from dataclasses import dataclass, field -from typing import Any +from typing import Any, Callable from uuid import UUID -from sqlmodel import Session - -from app.core.db import engine -from app.crud.llm_chain import update_llm_chain_block_completed from app.models.llm.request import ( LLMCallConfig, QueryParams, @@ -15,14 
+11,12 @@ AudioInput, ) from app.models.llm.response import ( - IntermediateChainResponse, TextOutput, AudioOutput, Usage, ) from app.services.llm.chain.types import BlockResult from app.services.llm.jobs import execute_llm_call -from app.utils import APIResponse, send_callback logger = logging.getLogger(__name__) @@ -30,13 +24,13 @@ @dataclass class ChainContext: - """Shared state passed to all blocks. Accumulates responses.""" + """Shared state for chain execution.""" job_id: UUID chain_id: UUID project_id: int organization_id: int - callback_url: str + callback_url: str | None total_blocks: int langfuse_credentials: dict[str, Any] | None = None @@ -50,60 +44,6 @@ class ChainContext: ) ) - def on_block_completed(self, block_index: int, result: BlockResult) -> None: - """Called after each block completes. Updates chain state in DB and sends intermediate callback.""" - - if result.usage: - self.aggregated_usage.input_tokens += result.usage.input_tokens - self.aggregated_usage.output_tokens += result.usage.output_tokens - self.aggregated_usage.total_tokens += result.usage.total_tokens - - if result.success and result.llm_call_id: - with Session(engine) as session: - update_llm_chain_block_completed( - session, - chain_id=self.chain_id, - llm_call_id=result.llm_call_id, - ) - - if ( - block_index < len(self.intermediate_callback_flags) - and self.intermediate_callback_flags[block_index] - and self.callback_url - and block_index < self.total_blocks - 1 - ): - self._send_intermediate_callback(block_index, result) - - def _send_intermediate_callback( - self, block_index: int, result: BlockResult - ) -> None: - """Send intermediate callback for a completed block.""" - try: - intermediate = IntermediateChainResponse( - block_index=block_index + 1, - total_blocks=self.total_blocks, - response=result.response.response, - usage=result.usage, - provider_raw_response=result.response.provider_raw_response, - ) - callback_data = APIResponse.success_response( - data=intermediate, 
- metadata=self.request_metadata, - ) - send_callback( - callback_url=self.callback_url, - data=callback_data.model_dump(), - ) - logger.info( - f"[ChainContext] Sent intermediate callback | " - f"block={block_index + 1}/{self.total_blocks}, job_id={self.job_id}" - ) - except Exception as e: - logger.warning( - f"[ChainContext] Failed to send intermediate callback: {e} | " - f"block={block_index + 1}/{self.total_blocks}, job_id={self.job_id}" - ) - def result_to_query(result: BlockResult) -> QueryParams: """Convert a block's output into the next block's QueryParams. @@ -124,11 +64,7 @@ def result_to_query(result: BlockResult) -> QueryParams: class ChainBlock: - """A single node in the linked chain. - - Wraps execute_block() with linking capability. - Each block knows its next block and forwards output to it. - """ + """A single block in the chain. Only responsible for executing itself.""" def __init__( self, @@ -142,25 +78,15 @@ def __init__( self._index = index self._context = context self._include_provider_raw_response = include_provider_raw_response - self._next: ChainBlock | None = None - - def link(self, next_block: "ChainBlock") -> "ChainBlock": - """Link to the next block in the chain.""" - self._next = next_block - return next_block def execute(self, query: QueryParams) -> BlockResult: - """Execute this block, then flow to next. - - No loop. Each block calls the next via the linked reference. - Data flows through the chain like a linked list traversal. 
- """ + """Execute this block and return the result.""" logger.info( f"[ChainBlock.execute] Executing block {self._index} | " f"job_id={self._context.job_id}" ) - result = execute_llm_call( + return execute_llm_call( config=self._config, query=query, job_id=self._context.job_id, @@ -172,51 +98,40 @@ def execute(self, query: QueryParams) -> BlockResult: chain_id=self._context.chain_id, ) - self._context.on_block_completed(self._index, result) - if not result.success: - logger.error( - f"[ChainBlock.execute] Block {self._index} failed: {result.error} | " - f"job_id={self._context.job_id}" - ) - return result +class LLMChain: + """Orchestrates sequential execution of ChainBlocks.""" - if self._next: - next_query = result_to_query(result) - return self._next.execute(next_query) + def __init__(self, blocks: list[ChainBlock], context: ChainContext): + self._blocks = blocks + self._context = context - logger.info( - f"[ChainBlock.execute] Block {self._index} is the last block | " - f"job_id={self._context.job_id}" - ) - return result + def execute( + self, + query: QueryParams, + on_block_completed: Callable[[int, BlockResult], None] | None = None, + ) -> BlockResult: + """Execute blocks sequentially, passing output of each to the next.""" + if not self._blocks: + return BlockResult(error="Chain has no blocks") + current_query = query + result: BlockResult | None = None -class LLMChain: - """Links ChainBlocks together into a sequential chain. + for block in self._blocks: + result = block.execute(current_query) - Construction builds the linked structure. - Execution pushes input into the head — it flows through to the tail. 
- """ + if on_block_completed: + on_block_completed(block._index, result) - def __init__(self, blocks: list[ChainBlock]): - self._head: ChainBlock | None = None - self._tail: ChainBlock | None = None - self._link_blocks(blocks) - - def _link_blocks(self, blocks: list[ChainBlock]) -> None: - """Link all blocks in sequence.""" - if not blocks: - return - self._head = blocks[0] - self._tail = blocks[-1] - prev = blocks[0] - for curr in blocks[1:]: - prev.link(curr) - prev = curr + if not result.success: + logger.error( + f"[LLMChain.execute] Block {block._index} failed: {result.error} | " + f"job_id={self._context.job_id}" + ) + return result - def execute(self, query: QueryParams) -> BlockResult: - """Push input into the chain head. It flows through to the tail.""" - if not self._head: - return BlockResult(error="Chain has no blocks") - return self._head.execute(query) + if block is not self._blocks[-1]: + current_query = result_to_query(result) + + return result diff --git a/backend/app/services/llm/chain/executor.py b/backend/app/services/llm/chain/executor.py index 25208aad6..27ab8de86 100644 --- a/backend/app/services/llm/chain/executor.py +++ b/backend/app/services/llm/chain/executor.py @@ -4,13 +4,13 @@ from app.core.db import engine from app.crud.jobs import JobCrud -from app.crud.llm_chain import update_llm_chain_status +from app.crud.llm_chain import update_llm_chain_block_completed, update_llm_chain_status from app.models import JobStatus, JobUpdate from app.models.llm.request import ( ChainStatus, LLMChainRequest, ) -from app.models.llm.response import LLMChainResponse +from app.models.llm.response import IntermediateChainResponse, LLMChainResponse from app.services.llm.chain.chain import ChainContext, LLMChain from app.services.llm.chain.types import BlockResult from app.utils import APIResponse, send_callback @@ -37,7 +37,10 @@ def run(self) -> dict: try: self._setup() - result = self._chain.execute(self._request.query) + result = self._chain.execute( + 
self._request.query, + on_block_completed=self._on_block_completed, + ) return self._teardown(result) @@ -96,7 +99,7 @@ def _handle_error(self, error: str) -> dict: metadata=self._request.request_metadata, ) logger.error( - f"[ChainExecutor] Chain execution failed | " + f"[_handle_error] Chain execution failed | " f"chain_id={self._context.chain_id}, job_id={self._context.job_id}, error={error}" ) @@ -121,6 +124,59 @@ def _handle_error(self, error: str) -> dict: ) return callback_response.model_dump() + def _on_block_completed(self, block_index: int, result: BlockResult) -> None: + """Handle side effects after each block completes.""" + if result.usage: + self._context.aggregated_usage.input_tokens += result.usage.input_tokens + self._context.aggregated_usage.output_tokens += result.usage.output_tokens + self._context.aggregated_usage.total_tokens += result.usage.total_tokens + + if result.success and result.llm_call_id: + with Session(engine) as session: + update_llm_chain_block_completed( + session, + chain_id=self._context.chain_id, + llm_call_id=result.llm_call_id, + ) + + if ( + block_index < len(self._context.intermediate_callback_flags) + and self._context.intermediate_callback_flags[block_index] + and self._request.callback_url + and block_index < self._context.total_blocks - 1 + ): + self._send_intermediate_callback(block_index, result) + + def _send_intermediate_callback( + self, block_index: int, result: BlockResult + ) -> None: + """Send intermediate callback for a completed block.""" + try: + intermediate = IntermediateChainResponse( + block_index=block_index + 1, + total_blocks=self._context.total_blocks, + response=result.response.response, + usage=result.usage, + provider_raw_response=result.response.provider_raw_response, + ) + callback_data = APIResponse.success_response( + data=intermediate, + metadata=self._context.request_metadata, + ) + send_callback( + callback_url=str(self._request.callback_url), + data=callback_data.model_dump(), + ) + 
logger.info( + f"[_send_intermediate_callback] Sent intermediate callback | " + f"block={block_index + 1}/{self._context.total_blocks}, job_id={self._context.job_id}" + ) + except Exception as e: + logger.warning( + f"[_send_intermediate_callback] Failed to send intermediate callback: {e} | " + f"block={block_index + 1}/{self._context.total_blocks}, job_id={self._context.job_id}" + ) + def _handle_unexpected_error(self, e: Exception) -> dict: logger.error( f"[ChainExecutor.run] Unexpected error: {e} | " diff --git a/backend/app/services/llm/jobs.py b/backend/app/services/llm/jobs.py index 221b8b32d..2a5f7dee2 100644 --- a/backend/app/services/llm/jobs.py +++ b/backend/app/services/llm/jobs.py @@ -2,6 +2,7 @@ from contextlib import contextmanager from typing import Any from uuid import UUID + from asgi_correlation_id import correlation_id from fastapi import HTTPException from sqlmodel import Session @@ -16,15 +17,15 @@ from app.crud.llm_chain import create_llm_chain, update_llm_chain_status from app.models import JobStatus, JobType, JobUpdate, LLMCallRequest, LLMChainRequest from app.models.llm.request import ( + AudioInput, ChainStatus, ConfigBlob, + ImageInput, KaapiCompletionConfig, LLMCallConfig, + PDFInput, QueryParams, TextInput, - AudioInput, - ImageInput, - PDFInput, ) from app.models.llm.response import TextOutput from app.services.llm.chain.types import BlockResult @@ -34,7 +35,7 @@ ) from app.services.llm.mappers import transform_kaapi_config_to_native from app.services.llm.providers.registry import get_llm_provider -from app.utils import APIResponse, send_callback, resolve_input, cleanup_temp_file +from app.utils import APIResponse, cleanup_temp_file, resolve_input, send_callback logger = logging.getLogger(__name__) @@ -172,55 +173,6 @@ def resolved_input_context( cleanup_temp_file(resolved_input) -def validate_text_with_guardrails( - text: str, - guardrails: list[dict[str, Any]], - job_id: UUID, - project_id: int, - organization_id: int, - 
guardrail_type: str, # "input" or "output" -) -> tuple[str | None, str | None]: - """Validate text against guardrails. - - Returns: - (validated_text, error_message) - - If successful: (modified_text, None) - - If failed: (None, error_message) - - If bypassed: (original_text, None) - """ - safe_result = run_guardrails_validation( - text, - guardrails, - job_id, - project_id, - organization_id, - suppress_pass_logs=True, - ) - - logger.info( - f"[validate_text_with_guardrails] {guardrail_type.capitalize()} guardrail validation | " - f"success={safe_result['success']}, job_id={job_id}" - ) - - if safe_result.get("bypassed"): - logger.info( - f"[validate_text_with_guardrails] Guardrails bypassed (service unavailable) | " - f"job_id={job_id}" - ) - return text, None - - if safe_result["success"]: - validated_text = safe_result["data"]["safe_text"] - - # Special case for output guardrails: check if rephrase is needed - if guardrail_type == "output" and safe_result["data"].get("rephrase_needed"): - return None, "Output requires rephrasing" - - return validated_text, None - - return None, safe_result["error"] - - def resolve_config_blob( config_crud: ConfigVersionCrud, config: LLMCallConfig ) -> tuple[ConfigBlob | None, str | None]: @@ -438,14 +390,14 @@ def execute_llm_call( ) try: - temp_request = LLMCallRequest( + llm_call_request = LLMCallRequest( query=query, config=config, request_metadata=request_metadata, ) llm_call = create_llm_call( session, - request=temp_request, + request=llm_call_request, job_id=job_id, project_id=project_id, organization_id=organization_id, @@ -653,6 +605,7 @@ def execute_chain_job( request = LLMChainRequest(**request_data) job_uuid = UUID(job_id) + callback_url_str = str(request.callback_url) if request.callback_url else None chain_uuid = None logger.info( @@ -709,7 +662,7 @@ def execute_chain_job( for i, block in enumerate(request.blocks) ] - chain = LLMChain(blocks) + chain = LLMChain(blocks, context) executor = 
ChainExecutor(chain=chain, context=context, request=request) return executor.run() @@ -740,4 +693,4 @@ def execute_chain_job( error="Unexpected error occurred", metadata=request.request_metadata, ) - return handle_job_error(job_uuid, request.callback_url, callback_response) + return handle_job_error(job_uuid, callback_url_str, callback_response) diff --git a/backend/app/tests/services/llm/test_chain.py b/backend/app/tests/services/llm/test_chain.py index d93380d84..5b5cfed3f 100644 --- a/backend/app/tests/services/llm/test_chain.py +++ b/backend/app/tests/services/llm/test_chain.py @@ -118,153 +118,13 @@ def test_unsupported_output_type_raises(self): result_to_query(result) -class TestChainContext: - def test_aggregates_usage(self, context): - usage = Usage(input_tokens=10, output_tokens=20, total_tokens=30) - result = BlockResult( - response=MagicMock(), llm_call_id=uuid4(), usage=usage, error=None - ) - - with patch("app.services.llm.chain.chain.Session"): - context.on_block_completed(0, result) - - assert context.aggregated_usage.input_tokens == 10 - assert context.aggregated_usage.output_tokens == 20 - assert context.aggregated_usage.total_tokens == 30 - - def test_aggregates_usage_across_blocks(self, context): - usage1 = Usage(input_tokens=10, output_tokens=20, total_tokens=30) - usage2 = Usage(input_tokens=5, output_tokens=15, total_tokens=20) - - result1 = BlockResult( - response=MagicMock(), llm_call_id=uuid4(), usage=usage1, error=None - ) - result2 = BlockResult( - response=MagicMock(), llm_call_id=uuid4(), usage=usage2, error=None - ) - - with patch("app.services.llm.chain.chain.Session"): - context.on_block_completed(0, result1) - context.on_block_completed(1, result2) - - assert context.aggregated_usage.input_tokens == 15 - assert context.aggregated_usage.total_tokens == 50 - - def test_updates_db_on_success(self, context): - llm_call_id = uuid4() - result = BlockResult( - response=MagicMock(), llm_call_id=llm_call_id, usage=MagicMock(), error=None - ) 
- - with patch("app.services.llm.chain.chain.Session") as mock_session, patch( - "app.services.llm.chain.chain.update_llm_chain_block_completed" - ) as mock_update: - mock_session.return_value.__enter__.return_value = MagicMock() - context.on_block_completed(0, result) - - mock_update.assert_called_once_with( - mock_session.return_value.__enter__.return_value, - chain_id=context.chain_id, - llm_call_id=llm_call_id, - ) - - def test_sends_intermediate_callback(self, context, text_response): - result = BlockResult( - response=text_response, - llm_call_id=uuid4(), - usage=text_response.usage, - error=None, - ) - - with ( - patch("app.services.llm.chain.chain.Session") as mock_session, - patch("app.services.llm.chain.chain.update_llm_chain_block_completed"), - patch("app.services.llm.chain.chain.send_callback") as mock_callback, - ): - mock_session.return_value.__enter__.return_value = MagicMock() - context.on_block_completed(0, result) - - mock_callback.assert_called_once() - call_kwargs = mock_callback.call_args[1] - assert call_kwargs["callback_url"] == "https://example.com/callback" - - def test_skips_intermediate_callback_for_last_block(self, context, text_response): - result = BlockResult( - response=text_response, - llm_call_id=uuid4(), - usage=text_response.usage, - error=None, - ) - - with ( - patch("app.services.llm.chain.chain.Session") as mock_session, - patch("app.services.llm.chain.chain.update_llm_chain_block_completed"), - patch("app.services.llm.chain.chain.send_callback") as mock_callback, - ): - mock_session.return_value.__enter__.return_value = MagicMock() - # Block index 2 = last block (total_blocks=3) - context.on_block_completed(2, result) - - mock_callback.assert_not_called() - - def test_skips_intermediate_callback_when_flag_false(self, context, text_response): - context.intermediate_callback_flags = [False, True, False] - result = BlockResult( - response=text_response, - llm_call_id=uuid4(), - usage=text_response.usage, - error=None, - ) - - 
with ( - patch("app.services.llm.chain.chain.Session") as mock_session, - patch("app.services.llm.chain.chain.update_llm_chain_block_completed"), - patch("app.services.llm.chain.chain.send_callback") as mock_callback, - ): - mock_session.return_value.__enter__.return_value = MagicMock() - context.on_block_completed(0, result) - - mock_callback.assert_not_called() - - def test_skips_db_update_on_error(self, context): - result = BlockResult(error="Block failed", usage=MagicMock()) - - with patch( - "app.services.llm.chain.chain.update_llm_chain_block_completed" - ) as mock_update: - context.on_block_completed(0, result) - mock_update.assert_not_called() - - def test_intermediate_callback_exception_is_swallowed(self, context, text_response): - result = BlockResult( - response=text_response, - llm_call_id=uuid4(), - usage=text_response.usage, - error=None, - ) - - with ( - patch("app.services.llm.chain.chain.Session") as mock_session, - patch("app.services.llm.chain.chain.update_llm_chain_block_completed"), - patch( - "app.services.llm.chain.chain.send_callback", - side_effect=Exception("Connection error"), - ), - ): - mock_session.return_value.__enter__.return_value = MagicMock() - # Should not raise - context.on_block_completed(0, result) - - class TestChainBlock: def test_execute_single_block(self, context, text_response): query = QueryParams(input="test input") config = make_config() block = ChainBlock(config=config, index=0, context=context) - with patch( - "app.services.llm.chain.chain.execute_llm_call" - ) as mock_execute, patch.object(context, "on_block_completed"): + with patch("app.services.llm.chain.chain.execute_llm_call") as mock_execute: mock_execute.return_value = BlockResult( response=text_response, usage=text_response.usage ) @@ -274,37 +134,15 @@ def test_execute_single_block(self, context, text_response): assert result.success mock_execute.assert_called_once() - def test_execute_chains_to_next_block(self, context, text_response): - query = 
QueryParams(input="test input") - config = make_config() - block1 = ChainBlock(config=config, index=0, context=context) - block2 = ChainBlock(config=config, index=1, context=context) - block1.link(block2) - - with patch( - "app.services.llm.chain.chain.execute_llm_call" - ) as mock_execute, patch.object(context, "on_block_completed"): - mock_execute.return_value = BlockResult( - response=text_response, usage=text_response.usage - ) - - result = block1.execute(query) - - assert mock_execute.call_count == 2 - - def test_execute_stops_on_failure(self, context): + def test_execute_returns_failure(self, context): query = QueryParams(input="test input") config = make_config() - block1 = ChainBlock(config=config, index=0, context=context) - block2 = ChainBlock(config=config, index=1, context=context) - block1.link(block2) + block = ChainBlock(config=config, index=0, context=context) - with patch( - "app.services.llm.chain.chain.execute_llm_call" - ) as mock_execute, patch.object(context, "on_block_completed"): + with patch("app.services.llm.chain.chain.execute_llm_call") as mock_execute: mock_execute.return_value = BlockResult(error="Provider error") - result = block1.execute(query) + result = block.execute(query) assert not result.success assert result.error == "Provider error" @@ -312,8 +150,8 @@ def test_execute_stops_on_failure(self, context): class TestLLMChain: - def test_execute_empty_chain(self): - chain = LLMChain([]) + def test_execute_empty_chain(self, context): + chain = LLMChain([], context) query = QueryParams(input="test") result = chain.execute(query) @@ -324,11 +162,9 @@ def test_execute_empty_chain(self): def test_execute_single_block_chain(self, context, text_response): config = make_config() block = ChainBlock(config=config, index=0, context=context) - chain = LLMChain([block]) + chain = LLMChain([block], context) - with patch( - "app.services.llm.chain.chain.execute_llm_call" - ) as mock_execute, patch.object(context, "on_block_completed"): + with 
patch("app.services.llm.chain.chain.execute_llm_call") as mock_execute: mock_execute.return_value = BlockResult( response=text_response, usage=text_response.usage ) @@ -341,11 +177,9 @@ def test_execute_single_block_chain(self, context, text_response): def test_execute_multi_block_chain(self, context, text_response): config = make_config() blocks = [ChainBlock(config=config, index=i, context=context) for i in range(3)] - chain = LLMChain(blocks) + chain = LLMChain(blocks, context) - with patch( - "app.services.llm.chain.chain.execute_llm_call" - ) as mock_execute, patch.object(context, "on_block_completed"): + with patch("app.services.llm.chain.chain.execute_llm_call") as mock_execute: mock_execute.return_value = BlockResult( response=text_response, usage=text_response.usage ) @@ -354,3 +188,32 @@ def test_execute_multi_block_chain(self, context, text_response): assert result.success assert mock_execute.call_count == 3 + + def test_execute_stops_on_failure(self, context, text_response): + config = make_config() + blocks = [ChainBlock(config=config, index=i, context=context) for i in range(3)] + chain = LLMChain(blocks, context) + + with patch("app.services.llm.chain.chain.execute_llm_call") as mock_execute: + mock_execute.return_value = BlockResult(error="Provider error") + + result = chain.execute(QueryParams(input="hello")) + + assert not result.success + assert result.error == "Provider error" + mock_execute.assert_called_once() + + def test_execute_calls_on_block_completed(self, context, text_response): + config = make_config() + blocks = [ChainBlock(config=config, index=i, context=context) for i in range(2)] + chain = LLMChain(blocks, context) + callback = MagicMock() + + with patch("app.services.llm.chain.chain.execute_llm_call") as mock_execute: + mock_execute.return_value = BlockResult( + response=text_response, usage=text_response.usage + ) + + chain.execute(QueryParams(input="hello"), on_block_completed=callback) + + assert callback.call_count == 2 diff 
--git a/backend/app/tests/services/llm/test_chain_executor.py b/backend/app/tests/services/llm/test_chain_executor.py index e8fdc31a9..6564ebafb 100644 --- a/backend/app/tests/services/llm/test_chain_executor.py +++ b/backend/app/tests/services/llm/test_chain_executor.py @@ -213,3 +213,169 @@ def test_setup_updates_job_and_chain_status( if c[1].get("status") == ChainStatus.RUNNING ] assert len(running_calls) == 1 + + +class TestOnBlockCompleted: + def _make_executor(self, context, request_obj): + mock_chain = MagicMock(spec=LLMChain) + return ChainExecutor(chain=mock_chain, context=context, request=request_obj) + + def test_aggregates_usage(self, context, request_obj): + executor = self._make_executor(context, request_obj) + usage = Usage(input_tokens=10, output_tokens=20, total_tokens=30) + result = BlockResult( + response=MagicMock(), llm_call_id=uuid4(), usage=usage, error=None + ) + + with patch("app.services.llm.chain.executor.Session"): + executor._on_block_completed(0, result) + + assert context.aggregated_usage.input_tokens == 10 + assert context.aggregated_usage.output_tokens == 20 + assert context.aggregated_usage.total_tokens == 30 + + def test_aggregates_usage_across_blocks(self, context, request_obj): + executor = self._make_executor(context, request_obj) + result1 = BlockResult( + response=MagicMock(), + llm_call_id=uuid4(), + usage=Usage(input_tokens=10, output_tokens=20, total_tokens=30), + error=None, + ) + result2 = BlockResult( + response=MagicMock(), + llm_call_id=uuid4(), + usage=Usage(input_tokens=5, output_tokens=15, total_tokens=20), + error=None, + ) + + with patch("app.services.llm.chain.executor.Session"): + executor._on_block_completed(0, result1) + executor._on_block_completed(1, result2) + + assert context.aggregated_usage.input_tokens == 15 + assert context.aggregated_usage.total_tokens == 50 + + def test_updates_db_on_success(self, context, request_obj): + executor = self._make_executor(context, request_obj) + llm_call_id = uuid4() + 
result = BlockResult( + response=MagicMock(), llm_call_id=llm_call_id, usage=MagicMock(), error=None + ) + + with ( + patch("app.services.llm.chain.executor.Session") as mock_session, + patch( + "app.services.llm.chain.executor.update_llm_chain_block_completed" + ) as mock_update, + ): + mock_session.return_value.__enter__.return_value = MagicMock() + executor._on_block_completed(0, result) + + mock_update.assert_called_once_with( + mock_session.return_value.__enter__.return_value, + chain_id=context.chain_id, + llm_call_id=llm_call_id, + ) + + def test_skips_db_update_on_error(self, context, request_obj): + executor = self._make_executor(context, request_obj) + result = BlockResult(error="Block failed", usage=MagicMock()) + + with patch( + "app.services.llm.chain.executor.update_llm_chain_block_completed" + ) as mock_update: + executor._on_block_completed(0, result) + mock_update.assert_not_called() + + def test_sends_intermediate_callback(self, context, request_obj, text_response): + context.total_blocks = 3 + context.intermediate_callback_flags = [True, True, False] + executor = self._make_executor(context, request_obj) + result = BlockResult( + response=text_response, + llm_call_id=uuid4(), + usage=text_response.usage, + error=None, + ) + + with ( + patch("app.services.llm.chain.executor.Session") as mock_session, + patch("app.services.llm.chain.executor.update_llm_chain_block_completed"), + patch("app.services.llm.chain.executor.send_callback") as mock_callback, + ): + mock_session.return_value.__enter__.return_value = MagicMock() + executor._on_block_completed(0, result) + + mock_callback.assert_called_once() + + def test_skips_intermediate_callback_for_last_block( + self, context, request_obj, text_response + ): + context.total_blocks = 3 + context.intermediate_callback_flags = [True, True, False] + executor = self._make_executor(context, request_obj) + result = BlockResult( + response=text_response, + llm_call_id=uuid4(), + usage=text_response.usage, + 
error=None, + ) + + with ( + patch("app.services.llm.chain.executor.Session") as mock_session, + patch("app.services.llm.chain.executor.update_llm_chain_block_completed"), + patch("app.services.llm.chain.executor.send_callback") as mock_callback, + ): + mock_session.return_value.__enter__.return_value = MagicMock() + executor._on_block_completed(2, result) + + mock_callback.assert_not_called() + + def test_skips_intermediate_callback_when_flag_false( + self, context, request_obj, text_response + ): + context.total_blocks = 3 + context.intermediate_callback_flags = [False, True, False] + executor = self._make_executor(context, request_obj) + result = BlockResult( + response=text_response, + llm_call_id=uuid4(), + usage=text_response.usage, + error=None, + ) + + with ( + patch("app.services.llm.chain.executor.Session") as mock_session, + patch("app.services.llm.chain.executor.update_llm_chain_block_completed"), + patch("app.services.llm.chain.executor.send_callback") as mock_callback, + ): + mock_session.return_value.__enter__.return_value = MagicMock() + executor._on_block_completed(0, result) + + mock_callback.assert_not_called() + + def test_intermediate_callback_exception_is_swallowed( + self, context, request_obj, text_response + ): + context.total_blocks = 3 + context.intermediate_callback_flags = [True, True, False] + executor = self._make_executor(context, request_obj) + result = BlockResult( + response=text_response, + llm_call_id=uuid4(), + usage=text_response.usage, + error=None, + ) + + with ( + patch("app.services.llm.chain.executor.Session") as mock_session, + patch("app.services.llm.chain.executor.update_llm_chain_block_completed"), + patch( + "app.services.llm.chain.executor.send_callback", + side_effect=Exception("Connection error"), + ), + ): + mock_session.return_value.__enter__.return_value = MagicMock() + # Should not raise + executor._on_block_completed(0, result) diff --git a/backend/app/tests/services/llm/test_jobs.py 
b/backend/app/tests/services/llm/test_jobs.py index 8cef08e96..cc67a7d6e 100644 --- a/backend/app/tests/services/llm/test_jobs.py +++ b/backend/app/tests/services/llm/test_jobs.py @@ -1287,7 +1287,7 @@ def test_success_flow(self, chain_request_data, mock_llm_response): patch("app.services.llm.chain.executor.send_callback"), patch("app.services.llm.chain.executor.update_llm_chain_status"), patch("app.services.llm.chain.chain.execute_llm_call") as mock_execute_llm, - patch("app.services.llm.chain.chain.Session"), + patch("app.services.llm.chain.executor.update_llm_chain_block_completed"), ): mock_session.return_value.__enter__.return_value = MagicMock() mock_session.return_value.__exit__.return_value = None From 9cc5cf8af2dd17467568f8ad5e1d74e25edee5db Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Fri, 6 Mar 2026 09:02:16 +0530 Subject: [PATCH 10/15] docs: enhance llm_chain.md with detailed input specifications and guardrails --- backend/app/api/docs/llm/llm_chain.md | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/backend/app/api/docs/llm/llm_chain.md b/backend/app/api/docs/llm/llm_chain.md index 1d17f24bf..0f38cc658 100644 --- a/backend/app/api/docs/llm/llm_chain.md +++ b/backend/app/api/docs/llm/llm_chain.md @@ -6,7 +6,7 @@ for processing, and results are delivered via the callback URL when complete. 
### Key Parameters **`query`** (required) - Initial query input for the first block in the chain: -- `input` (required, string, min 1 char): User question/prompt/query +- `input` (required): User question/prompt/query — accepts a plain string, a structured input object (`text`, `audio`, `image`, `pdf`), or a list of structured inputs - `conversation` (optional, object): Conversation configuration - `id` (optional, string): Existing conversation ID to continue - `auto_create` (optional, boolean, default false): Create new conversation if no ID provided @@ -26,8 +26,11 @@ for processing, and results are delivered via the callback URL when complete. - **Mode 2: Ad-hoc Configuration** - `blob` (object): Complete configuration object - `completion` (required, object): Completion configuration - - `provider` (required, string): Provider type - either `"openai"` (Kaapi abstraction) or `"openai-native"` (pass-through) - - `params` (required, object): Parameters structure depends on provider type (see schema for detailed structure) + - `provider` (required, string): Kaapi providers (`openai`, `google`, `sarvamai`) — params are validated and mapped internally. 
Native providers (`openai-native`, `google-native`, `sarvamai-native`) — params are passed through as-is + - `type` (required, string): Completion type — `"text"`, `"stt"`, or `"tts"` + - `params` (required, object): Parameters structure depends on provider and type (see schema for detailed structure) + - `input_guardrails` (optional, array): Guardrails applied to validate/sanitize input before the LLM call + - `output_guardrails` (optional, array): Guardrails applied to validate/sanitize output after the LLM call - `prompt_template` (optional, object): Template for text interpolation - `template` (required, string): Template string with `{{input}}` placeholder — replaced with the block's input before execution - **Note** From f7797d1d5785909c8876e586a59c7b17f41b1857 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Fri, 6 Mar 2026 11:55:34 +0530 Subject: [PATCH 11/15] refactor: remove unused timestamps from LlmChain model and update related tests --- .../alembic/versions/048_create_llm_chain_table.py | 14 +------------- backend/app/crud/llm_chain.py | 5 ----- backend/app/models/llm/request.py | 14 +------------- backend/app/tests/crud/test_llm_chain.py | 3 --- 4 files changed, 2 insertions(+), 34 deletions(-) diff --git a/backend/app/alembic/versions/048_create_llm_chain_table.py b/backend/app/alembic/versions/048_create_llm_chain_table.py index ac49eb0ec..ad498d465 100644 --- a/backend/app/alembic/versions/048_create_llm_chain_table.py +++ b/backend/app/alembic/versions/048_create_llm_chain_table.py @@ -107,19 +107,7 @@ def upgrade() -> None: comment="Future-proof extensibility catch-all", ), sa.Column( - "started_at", - sa.DateTime(), - nullable=True, - comment="Timestamp when chain execution started", - ), - sa.Column( - "completed_at", - sa.DateTime(), - nullable=True, - comment="Timestamp when chain execution completed", - ), - sa.Column( - "created_at", + "inserted_at", sa.DateTime(), nullable=False, 
comment="Timestamp when the chain record was created", diff --git a/backend/app/crud/llm_chain.py b/backend/app/crud/llm_chain.py index 77ab70987..010d8abbd 100644 --- a/backend/app/crud/llm_chain.py +++ b/backend/app/crud/llm_chain.py @@ -85,18 +85,13 @@ def update_llm_chain_status( db_chain.status = status db_chain.updated_at = now() - if status == ChainStatus.RUNNING: - db_chain.started_at = now() - if status == ChainStatus.FAILED: db_chain.error = error db_chain.total_usage = total_usage - db_chain.completed_at = now() if status == ChainStatus.COMPLETED: db_chain.output = output db_chain.total_usage = total_usage - db_chain.completed_at = now() session.add(db_chain) session.commit() diff --git a/backend/app/models/llm/request.py b/backend/app/models/llm/request.py index 1760e569f..0a8c33818 100644 --- a/backend/app/models/llm/request.py +++ b/backend/app/models/llm/request.py @@ -744,19 +744,7 @@ class LlmChain(SQLModel, table=True): ), ) - started_at: datetime | None = Field( - default=None, - nullable=True, - sa_column_kwargs={"comment": "Timestamp when chain execution started"}, - ) - - completed_at: datetime | None = Field( - default=None, - nullable=True, - sa_column_kwargs={"comment": "Timestamp when chain execution completed"}, - ) - - created_at: datetime = Field( + inserted_at: datetime = Field( default_factory=now, nullable=False, sa_column_kwargs={"comment": "Timestamp when the chain record was created"}, diff --git a/backend/app/tests/crud/test_llm_chain.py b/backend/app/tests/crud/test_llm_chain.py index 84324f86c..dfeceeee4 100644 --- a/backend/app/tests/crud/test_llm_chain.py +++ b/backend/app/tests/crud/test_llm_chain.py @@ -67,7 +67,6 @@ def test_update_to_running(self, db: Session, chain): ) assert updated.status == ChainStatus.RUNNING - assert updated.started_at is not None def test_update_to_completed(self, db: Session, chain): output = {"type": "text", "content": {"value": "result"}} @@ -84,7 +83,6 @@ def test_update_to_completed(self, db: 
Session, chain): assert updated.status == ChainStatus.COMPLETED assert updated.output == output assert updated.total_usage == usage - assert updated.completed_at is not None def test_update_to_failed(self, db: Session, chain): usage = {"input_tokens": 5, "output_tokens": 0, "total_tokens": 5} @@ -100,7 +98,6 @@ def test_update_to_failed(self, db: Session, chain): assert updated.status == ChainStatus.FAILED assert updated.error == "Provider timeout" assert updated.total_usage == usage - assert updated.completed_at is not None def test_raises_for_missing_chain(self, db: Session): with pytest.raises(ValueError, match="LLM chain not found"): From 5b9a4e9e9be30067eb2db618a37abc838d71c403 Mon Sep 17 00:00:00 2001 From: Prajna1999 Date: Thu, 5 Mar 2026 22:46:54 +0530 Subject: [PATCH 12/15] feat: basic speech-to-speech impl on top of llm_chain --- backend/app/api/docs/llm/speech_to_speech.md | 228 +++++++++ backend/app/api/main.py | 2 + backend/app/api/routes/llm_speech.py | 141 ++++++ backend/app/api/routes/llm_speech_examples.md | 470 ++++++++++++++++++ backend/app/models/llm/request.py | 89 ++++ backend/app/services/llm/chain/utils.py | 197 ++++++++ 6 files changed, 1127 insertions(+) create mode 100644 backend/app/api/docs/llm/speech_to_speech.md create mode 100644 backend/app/api/routes/llm_speech.py create mode 100644 backend/app/api/routes/llm_speech_examples.md create mode 100644 backend/app/services/llm/chain/utils.py diff --git a/backend/app/api/docs/llm/speech_to_speech.md b/backend/app/api/docs/llm/speech_to_speech.md new file mode 100644 index 000000000..bad465b66 --- /dev/null +++ b/backend/app/api/docs/llm/speech_to_speech.md @@ -0,0 +1,228 @@ +# Speech-to-Speech (STS) with RAG + +Execute a complete speech-to-speech workflow with knowledge base retrieval. 
+ +## Endpoint + +``` +POST /llm/sts +``` + +## Flow + +``` +Voice Input → STT (auto language) → RAG (Knowledge Base) → TTS → Voice Output +``` + +## Input + +- **Voice note**: WhatsApp-compatible audio format (required) +- **Knowledge base IDs**: One or more knowledge bases for RAG (required) +- **Languages**: Input and output languages (optional; input defaults to auto-detect, output defaults to the detected input language) +- **Models**: STT, LLM, and TTS model selection (optional, defaults to Sarvam) + +## Output + +You will receive **3 callbacks** to your webhook URL: + +1. **STT Callback** (Intermediate): Transcribed text from audio +2. **LLM Callback** (Intermediate): RAG-enhanced response text +3. **TTS Callback** (Final): Audio output + response text + +Each callback includes: +- Output from that step +- Token usage +- Latency information (check timestamps) + +## Supported Languages + +### Primary Indian Languages +- English, Hindi, Hinglish (code-switching) +- Bengali, Kannada, Malayalam, Marathi +- Odia, Punjabi, Tamil, Telugu, Gujarati + +### Additional Languages (Sarvam Saaras V3) +- Assamese, Urdu, Nepali +- Konkani, Kashmiri, Sindhi +- Sanskrit, Santali, Manipuri +- Bodo, Maithili, Dogri + +**Total: 25 languages** with automatic language detection + +## Available Models + +### STT (Speech-to-Text) +- `saaras:v3` - Sarvam Saaras V3 (**default**, fast, auto language detection, optimized for Indian languages) +- `gemini-2.5-pro` - Google Gemini 2.5 Pro + +**Note:** Sarvam STT uses automatic language detection. No need to specify input language. 
+ +### LLM (RAG) +- `gpt-4o` - OpenAI GPT-4o (**default**, best quality) +- `gpt-4o-mini` - OpenAI GPT-4o Mini (faster, lower cost) + +### TTS (Text-to-Speech) +- `bulbul-v3` - Sarvam Bulbul V3 (**default**, natural Indian voices, MP3 output) +- `gemini-2.5-pro-preview-tts` - Google Gemini 2.5 Pro (OGG OPUS output) + +## Edge Cases & Error Handling + +### Empty STT Output +If speech-to-text returns empty/blank: +- Chain fails immediately +- Error message: "STT returned no transcription" +- No subsequent blocks are executed + +### Audio Size Limit +WhatsApp limit: 16MB +- TTS providers may fail if output exceeds limit +- Error is caught and reported in callback +- Consider using shorter responses or compression + +### Invalid Audio Format +If input audio format is unsupported: +- STT provider fails with format error +- Error reported in callback +- Supported: MP3, WAV, OGG, OPUS, M4A + +### Provider Failures +Each block has independent error handling: +- STT fails → Chain stops, STT error reported +- LLM fails → Chain stops, RAG error reported +- TTS fails → Chain stops, TTS error reported + +## Example Request + +```bash +curl -X POST https://api.kaapi.ai/llm/sts \ + -H "Authorization: Bearer YOUR_API_KEY" \ + -H "Content-Type: application/json" \ + -d @- < 16MB: TTS provider will fail (caught and reported) + - Invalid audio format: STT provider will fail (caught and reported) + """ + project_id = _current_user.project_.id + organization_id = _current_user.organization_.id + + # Validate callback URL + if request.callback_url: + validate_callback_url(str(request.callback_url)) + + # Validate and determine languages + if request.input_language and request.input_language != "auto": + if request.input_language not in LANGUAGE_CODES: + from fastapi import HTTPException + + raise HTTPException( + status_code=400, + detail=f"Unsupported input language: {request.input_language}. 
Supported: {', '.join(LANGUAGE_CODES.keys())}", + ) + + if request.output_language and request.output_language not in LANGUAGE_CODES: + from fastapi import HTTPException + + raise HTTPException( + status_code=400, + detail=f"Unsupported output language: {request.output_language}. Supported: {', '.join(LANGUAGE_CODES.keys())}", + ) + + input_lang_code = get_language_code(request.input_language) + output_lang_code = get_language_code( + request.output_language, default=request.input_language or "auto" + ) + + logger.info( + f"[speech_to_speech] Starting STS chain | " + f"project_id={project_id}, " + f"input_lang={input_lang_code}, " + f"output_lang={output_lang_code}, " + f"stt_model={request.stt_model.value}, " + f"llm_model={request.llm_model.value}, " + f"tts_model={request.tts_model.value}" + ) + + # Build 3-block chain: STT → RAG → TTS + blocks = [ + build_stt_block(request.stt_model, input_lang_code), + build_rag_block(request.llm_model, request.knowledge_base_ids), + build_tts_block(request.tts_model, output_lang_code), + ] + + # Add metadata to track STS-specific info + metadata = request.request_metadata or {} + metadata.update( + { + "speech_to_speech": True, + "input_language": input_lang_code, + "output_language": output_lang_code, + "stt_model": request.stt_model.value, + "llm_model": request.llm_model.value, + "tts_model": request.tts_model.value, + } + ) + + # Create chain request + chain_request = LLMChainRequest( + query=QueryParams(input=request.audio), + blocks=blocks, + callback_url=request.callback_url, + request_metadata=metadata, + ) + + # Start async chain job + start_chain_job( + db=_session, + request=chain_request, + project_id=project_id, + organization_id=organization_id, + ) + + return APIResponse.success_response( + data=Message( + message=( + "Speech-to-speech processing initiated. " + "You will receive intermediate callbacks for STT and LLM outputs, " + "followed by the final callback with audio and text." 
+ ) + ) + ) diff --git a/backend/app/api/routes/llm_speech_examples.md b/backend/app/api/routes/llm_speech_examples.md new file mode 100644 index 000000000..43e578dc2 --- /dev/null +++ b/backend/app/api/routes/llm_speech_examples.md @@ -0,0 +1,470 @@ +# Speech-to-Speech (STS) API Examples + +## Endpoint + +``` +POST /llm/sts +``` + +## Quick Start + +### Minimal Request (All Defaults) +```bash +curl -X POST https://api.kaapi.ai/llm/sts \ + -H "Authorization: Bearer YOUR_API_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "audio": { + "type": "audio", + "content": { + "format": "base64", + "value": "BASE64_AUDIO_DATA", + "mime_type": "audio/ogg" + } + }, + "knowledge_base_ids": ["kb_abc123"], + "callback_url": "https://your-app.com/webhook" + }' +``` + +**Defaults Used:** +- Input language: Auto-detect (Sarvam STT) +- Output language: Hindi (same as detected input) +- STT: Sarvam Saaras V3 (auto language detection) +- LLM: OpenAI GPT-4o +- TTS: Sarvam Bulbul V3 + +--- + +## Full Configuration Example + +### Request with All Options +```json +{ + "audio": { + "type": "audio", + "content": { + "format": "base64", + "value": "UklGRiQAAABXQVZFZm10...", + "mime_type": "audio/ogg" + } + }, + "knowledge_base_ids": ["kb_customer_support", "kb_product_info"], + "input_language": "hindi", + "output_language": "english", + "stt_model": "saaras:v3", + "llm_model": "gpt-4o", + "tts_model": "bulbul:v3", + "callback_url": "https://api.yourapp.com/webhooks/speech-response", + "request_metadata": { + "user_id": "user_123", + "session_id": "session_456", + "source": "whatsapp" + } +} +``` + +### Response (Immediate) +```json +{ + "success": true, + "data": { + "message": "Speech-to-speech processing initiated. You will receive intermediate callbacks for STT and LLM outputs, followed by the final callback with audio and text." + } +} +``` + +--- + +## Callback Sequence + +You'll receive **3 callbacks** to your webhook URL: + +### 1. 
STT Callback (Intermediate) +Sent after audio transcription completes. + +```json +{ + "success": true, + "data": { + "block_index": 1, + "total_blocks": 3, + "response": { + "provider_response_id": "stt_xyz789", + "provider": "google-native", + "model": "gemini-2.5-pro", + "output": { + "type": "text", + "content": { + "value": "मेरा अकाउंट बैलेंस क्या है?" + } + } + }, + "usage": { + "input_tokens": 0, + "output_tokens": 8, + "total_tokens": 8 + } + }, + "metadata": { + "speech_to_speech": true, + "input_language": "hi-IN", + "output_language": "en-IN", + "stt_model": "gemini-2.5-pro", + "llm_model": "gpt-4o", + "tts_model": "bulbul-v3", + "user_id": "user_123", + "session_id": "session_456", + "source": "whatsapp" + } +} +``` + +**Latency Calculation:** +``` +STT_latency = callback_1_timestamp - request_timestamp +``` + +--- + +### 2. LLM/RAG Callback (Intermediate) +Sent after knowledge base retrieval and response generation. + +```json +{ + "success": true, + "data": { + "block_index": 2, + "total_blocks": 3, + "response": { + "provider_response_id": "chatcmpl_abc123", + "conversation_id": null, + "provider": "openai", + "model": "gpt-4o", + "output": { + "type": "text", + "content": { + "value": "Your current account balance is ₹5,000. You have 3 transactions in the last month." + } + } + }, + "usage": { + "input_tokens": 250, + "output_tokens": 22, + "total_tokens": 272 + } + }, + "metadata": { + "speech_to_speech": true, + "user_id": "user_123", + "session_id": "session_456", + "source": "whatsapp" + } +} +``` + +**Latency Calculation:** +``` +LLM_latency = callback_2_timestamp - callback_1_timestamp +``` + +--- + +### 3. TTS Callback (Final) +Sent after text-to-speech conversion completes. This is your final output. 
+ +```json +{ + "success": true, + "data": { + "response": { + "provider_response_id": "tts_def456", + "conversation_id": null, + "provider": "sarvamai-native", + "model": "bulbul:v3", + "output": { + "type": "audio", + "content": { + "format": "base64", + "value": "T2dnUwACAAAAAAAAAAAEBQ...", + "mime_type": "audio/ogg" + } + } + }, + "usage": { + "input_tokens": 22, + "output_tokens": 0, + "total_tokens": 22 + } + }, + "metadata": { + "speech_to_speech": true, + "output_language": "en-IN", + "user_id": "user_123", + "session_id": "session_456", + "source": "whatsapp" + } +} +``` + +**Latency Calculation:** +``` +TTS_latency = callback_3_timestamp - callback_2_timestamp +Total_latency = callback_3_timestamp - request_timestamp +``` + +--- + +## Error Handling Examples + +### Empty STT Output +If the audio contains no speech or is unintelligible: + +```json +{ + "success": false, + "error": "STT returned no transcription. The audio may be empty or unintelligible.", + "metadata": { + "speech_to_speech": true, + "user_id": "user_123" + } +} +``` + +### Invalid Audio Format +If the audio format is not supported: + +```json +{ + "success": false, + "error": "SarvamAI STT transcription failed: Invalid audio format. Supported formats: mp3, wav, ogg, opus, m4a", + "metadata": { + "speech_to_speech": true, + "user_id": "user_123" + } +} +``` + +### Audio Size Exceeds Limit +If TTS generates audio > 16MB (rare): + +```json +{ + "success": false, + "error": "TTS audio output exceeds WhatsApp size limit (16MB). 
Try reducing response length.", + "metadata": { + "speech_to_speech": true, + "user_id": "user_123" + } +} +``` + +### Knowledge Base Not Found +If specified knowledge base doesn't exist: + +```json +{ + "success": false, + "error": "Knowledge base 'kb_invalid' not found or not accessible.", + "metadata": { + "speech_to_speech": true, + "user_id": "user_123" + } +} +``` + +--- + +## Language-Specific Examples + +### English → English +```json +{ + "audio": {...}, + "knowledge_base_ids": ["kb_123"], + "input_language": "english", + "output_language": "english", + "callback_url": "..." +} +``` + +### Hindi → English (Translation) +```json +{ + "audio": {...}, + "knowledge_base_ids": ["kb_123"], + "input_language": "hindi", + "output_language": "english", + "callback_url": "..." +} +``` + +### Hinglish (Code-Switching) +```json +{ + "audio": {...}, + "knowledge_base_ids": ["kb_123"], + "input_language": "hinglish", + "output_language": "hinglish", + "callback_url": "..." +} +``` +**Note:** Hinglish is treated as Hindi for model selection. + +### Regional Indian Languages +```json +{ + "audio": {...}, + "knowledge_base_ids": ["kb_123"], + "input_language": "auto", // Auto-detect + "output_language": "odia", // Odia, Bengali, Punjabi, etc. + "callback_url": "..." 
+} +``` + +**Supported Regional Languages:** +- Bengali, Malayalam, Punjabi, Odia +- Assamese, Urdu, Nepali +- Konkani, Kashmiri, Sindhi, Sanskrit +- Santali, Manipuri, Bodo, Maithili, Dogri + +--- + +## Model Selection Guide + +### For Indian Languages (Recommended - Default) +```json +{ + "stt_model": "saaras:v3", + "llm_model": "gpt-4o", + "tts_model": "bulbul-v3" +} +``` +**Benefits:** +- Auto language detection (no need to specify language) +- Fastest processing +- Best accent handling for Indian languages +- Natural voice quality +- MP3 output (WhatsApp compatible) + +### For Maximum Accuracy +```json +{ + "stt_model": "gemini-2.5-pro", + "llm_model": "gpt-4o", + "tts_model": "gemini-2.5-pro-preview-tts" +} +``` +**Benefits:** Highest accuracy, best for complex queries, OGG OPUS output + +### For Cost Optimization +```json +{ + "stt_model": "saaras:v3", + "llm_model": "gpt-4o-mini", + "tts_model": "bulbul-v3" +} +``` +**Benefits:** Lower cost, still good quality, faster response + +--- + +## Integration Patterns + +### WhatsApp Bot Integration +```python +import base64 +import requests + +def handle_whatsapp_voice_message(audio_url, user_id): + # Download audio from WhatsApp + audio_response = requests.get(audio_url) + audio_base64 = base64.b64encode(audio_response.content).decode() + + # Send to Kaapi STS + response = requests.post( + "https://api.kaapi.ai/llm/sts", + headers={"Authorization": f"Bearer {API_KEY}"}, + json={ + "audio": { + "type": "audio", + "content": { + "format": "base64", + "value": audio_base64, + "mime_type": "audio/ogg" + } + }, + "knowledge_base_ids": ["kb_customer_support"], + "callback_url": f"https://yourapp.com/webhook?user={user_id}", + "request_metadata": {"user_id": user_id} + } + ) + + return response.json() + +def handle_s2s_callback(callback_data): + """Handle the final TTS callback.""" + if not callback_data["success"]: + # Handle error + return + + # Extract final audio + audio_base64 = 
callback_data["data"]["response"]["output"]["content"]["value"] + audio_bytes = base64.b64decode(audio_base64) + + # Send back to WhatsApp user + send_whatsapp_voice(audio_bytes, user_id) +``` + +--- + +## Performance Benchmarks + +**Typical Latencies** (with Sarvam models, Hindi): +- STT: 1-2 seconds +- RAG: 2-4 seconds +- TTS: 1-2 seconds +- **Total: 4-8 seconds** + +**With Gemini models**: +- STT: 2-3 seconds +- RAG: 2-4 seconds +- TTS: 2-3 seconds +- **Total: 6-10 seconds** + +--- + +## Testing Tips + +1. **Test with Silent Audio**: Verify error handling for empty STT +2. **Test Different Formats**: OGG, MP3, WAV, M4A +3. **Test Language Mixing**: Hinglish, code-switching +4. **Test Long Audio**: >1 minute clips +5. **Load Test**: Multiple concurrent requests +6. **Monitor Latencies**: Track each block's timing +7. **Validate Audio Output**: Ensure < 16MB for WhatsApp + +--- + +## Troubleshooting + +### High Latency +- Check knowledge base size (larger = slower retrieval) +- Consider using faster models (gemini-flash, gpt-4o-mini) +- Verify callback URL response time + +### Poor Transcription Quality +- Ensure audio quality is good (no background noise) +- Try different STT models +- Check if language setting matches audio + +### Unnatural TTS Voice +- Try different TTS models +- Sarvam Bulbul is best for Indian accents +- Gemini is good for neutral accents + +### Callback Not Received +- Verify callback URL is publicly accessible +- Check for HTTPS (required) +- Ensure webhook can handle POST requests +- Check firewall settings diff --git a/backend/app/models/llm/request.py b/backend/app/models/llm/request.py index 0a8c33818..ced2a60a1 100644 --- a/backend/app/models/llm/request.py +++ b/backend/app/models/llm/request.py @@ -11,6 +11,28 @@ from app.core.util import now +# Speech-to-Speech Model Enums +class STTModel(str, Enum): + """Supported STT models for speech-to-speech.""" + + GEMINI_PRO = "gemini-2.5-pro" + SARVAM = "saaras:v3" + + +class TTSModel(str, 
Enum): + """Supported TTS models for speech-to-speech.""" + + GEMINI_PRO = "gemini-2.5-pro-preview-tts" + SARVAM = "bulbul:v3" + + +class LLMModel(str, Enum): + """Supported LLM models for RAG in speech-to-speech.""" + + GPT4O = "gpt-4o" + GPT4O_MINI = "gpt-4o-mini" + + class TextLLMParams(SQLModel): model: str instructions: str | None = Field( @@ -757,3 +779,70 @@ class LlmChain(SQLModel, table=True): "comment": "Timestamp when the chain record was last updated" }, ) + + +class SpeechToSpeechRequest(SQLModel): + """ + API request for speech-to-speech (STS) with RAG. + + Convenience endpoint that orchestrates a 3-block chain: + STT → RAG → TTS + + Input: Audio + Output: Audio + Text (via callback) + """ + + audio: AudioInput = Field( + ..., description="Voice note input (WhatsApp compatible format)" + ) + knowledge_base_ids: list[str] = Field( + ..., min_length=1, description="Knowledge base IDs for RAG retrieval" + ) + + # Optional language config + input_language: str | None = Field( + "auto", + description=( + "Input language for STT (auto-detect by default). " + "Supported: auto, english, hindi, hinglish, bengali, kannada, malayalam, marathi, " + "odia, punjabi, tamil, telugu, gujarati, assamese, urdu, nepali, konkani, kashmiri, " + "sindhi, sanskrit, santali, manipuri, bodo, maithili, dogri" + ), + ) + output_language: str | None = Field( + None, + description=( + "Output language for TTS (defaults to input_language if not specified). " + "Same language options as input_language." 
+ ), + ) + + # Optional model overrides + stt_model: STTModel = Field( + STTModel.SARVAM, description="STT model (default: Sarvam Saaras V3)" + ) + tts_model: TTSModel = Field( + TTSModel.SARVAM, description="TTS model (default: Sarvam Bulbul V3)" + ) + llm_model: LLMModel = Field( + LLMModel.GPT4O, description="LLM model for RAG (default: GPT-4o)" + ) + + # Callback and metadata + callback_url: HttpUrl | None = Field( + None, description="Webhook URL for async response delivery" + ) + request_metadata: dict[str, Any] | None = Field( + None, description="Client-provided metadata" + ) + + @model_validator(mode="after") + def validate_languages(self): + """Validate language fields.""" + # Validation happens in the endpoint using LANGUAGE_CODES from utils + # This is just to ensure the fields are lowercase if provided + if self.input_language and self.input_language != "auto": + self.input_language = self.input_language.lower() + if self.output_language: + self.output_language = self.output_language.lower() + return self diff --git a/backend/app/services/llm/chain/utils.py b/backend/app/services/llm/chain/utils.py new file mode 100644 index 000000000..376879fd8 --- /dev/null +++ b/backend/app/services/llm/chain/utils.py @@ -0,0 +1,197 @@ +"""Utility functions for LLM chain operations, including speech-to-speech helpers.""" + +from typing import Any, Literal + +from app.models.llm.request import ( + ChainBlock, + ConfigBlob, + KaapiCompletionConfig, + LLMCallConfig, + LLMModel, + NativeCompletionConfig, + STTModel, + TextLLMParams, + TTSModel, +) + + +# Supported languages for speech-to-speech (BCP-47 language codes) +LANGUAGE_CODES = { + # Auto-detect + "auto": "unknown", # Sarvam auto-detection + # Primary Indian languages + "english": "en-IN", + "hindi": "hi-IN", + "hinglish": "hi-IN", # Code-switching, treat as Hindi + "bengali": "bn-IN", + "kannada": "kn-IN", + "malayalam": "ml-IN", + "marathi": "mr-IN", + "odia": "od-IN", + "punjabi": "pa-IN", + "tamil": "ta-IN", 
+ "telugu": "te-IN", + "gujarati": "gu-IN", + # Additional languages (saaras:v3) + "assamese": "as-IN", + "urdu": "ur-IN", + "nepali": "ne-IN", + "konkani": "kok-IN", + "kashmiri": "ks-IN", + "sindhi": "sd-IN", + "sanskrit": "sa-IN", + "santali": "sat-IN", + "manipuri": "mni-IN", + "bodo": "brx-IN", + "maithili": "mai-IN", + "dogri": "doi-IN", +} + + +def get_language_code(language: str | None, default: str = "auto") -> str: + """Convert language name to BCP-47 language code. + + Args: + language: Language name (e.g., "hindi", "english", "auto") + default: Default language if not specified (default: "auto") + + Returns: + BCP-47 language code (e.g., "hi-IN", "en-IN", "unknown" for auto-detect) + """ + lang = (language or default).lower() + return LANGUAGE_CODES.get(lang, LANGUAGE_CODES["auto"]) + + +def build_stt_block(model: STTModel, language_code: str) -> ChainBlock: + """Build STT (Speech-to-Text) block configuration. + + Args: + model: STT model enum + language_code: ISO language code (e.g., "hi-IN") + + Returns: + ChainBlock configured for STT + """ + # Map model to provider and actual model name + model_configs: dict[ + STTModel, + tuple[Literal["sarvamai-native", "google-native", "openai-native"], str], + ] = { + STTModel.SARVAM: ("sarvamai-native", "saaras:v3"), + STTModel.GEMINI_PRO: ("google-native", "gemini-2.5-pro"), + } + + provider, model_name = model_configs[model] + + # Build native config (provider-specific params) + params: dict[str, Any] = { + "model": model_name, + } + + # Add provider-specific parameters + if provider == "sarvamai-native": + # Use "unknown" for automatic language detection, or specific BCP-47 code + params["language_code"] = ( + language_code if language_code != "unknown" else "unknown" + ) + params["mode"] = "transcription" + elif provider == "google-native": + # Google requires specific language code, fallback to en-IN if unknown + params["language_code"] = ( + language_code if language_code != "unknown" else "en-IN" + ) + + 
return ChainBlock( + config=LLMCallConfig( + blob=ConfigBlob( + completion=NativeCompletionConfig( + provider=provider, + type="stt", + params=params, + ) + ) + ), + intermediate_callback=True, # Send STT result to user + include_provider_raw_response=False, + ) + + +def build_rag_block(model: LLMModel, knowledge_base_ids: list[str]) -> ChainBlock: + """Build RAG (Retrieval-Augmented Generation) block configuration. + + Args: + model: LLM model enum + knowledge_base_ids: List of knowledge base IDs for retrieval + + Returns: + ChainBlock configured for RAG + """ + return ChainBlock( + config=LLMCallConfig( + blob=ConfigBlob( + completion=KaapiCompletionConfig( + provider="openai", + type="text", + params=TextLLMParams( + model=model.value, + knowledge_base_ids=knowledge_base_ids, + temperature=0.1, + instructions="Answer the user's question using the provided knowledge base. Be concise and accurate.", + ).model_dump(exclude_none=True), + ) + ) + ), + intermediate_callback=True, # Send LLM result to user + include_provider_raw_response=False, + ) + + +def build_tts_block(model: TTSModel, language_code: str) -> ChainBlock: + """Build TTS (Text-to-Speech) block configuration. 
+ + Args: + model: TTS model enum + language_code: ISO language code (e.g., "hi-IN") + + Returns: + ChainBlock configured for TTS + """ + # Map model to provider and actual model name + voice + model_configs: dict[ + TTSModel, + tuple[Literal["sarvamai-native", "google-native", "openai-native"], str, str], + ] = { + TTSModel.SARVAM: ("sarvamai-native", "bulbul:v3", "simran"), + TTSModel.GEMINI_PRO: ("google-native", "gemini-2.5-pro", "default"), + } + + provider, model_name, voice = model_configs[model] + + # Build native config + params: dict[str, Any] = { + "model": model_name, + "voice": voice, + } + + # Add provider-specific parameters + if provider == "sarvamai-native": + params["target_language_code"] = language_code + params["speaker"] = voice + params["output_audio_codec"] = "mp3" # WhatsApp compatible + elif provider == "google-native": + params["language_code"] = language_code + params["audio_encoding"] = "OGG_OPUS" # WhatsApp compatible + + return ChainBlock( + config=LLMCallConfig( + blob=ConfigBlob( + completion=NativeCompletionConfig( + provider=provider, + type="tts", + params=params, + ) + ) + ), + intermediate_callback=False, # Final result only + include_provider_raw_response=False, + ) From c1807df44cd2344ed88cd594b19cdef01aa2dbb6 Mon Sep 17 00:00:00 2001 From: Prajna1999 Date: Fri, 6 Mar 2026 00:01:44 +0530 Subject: [PATCH 13/15] feat: add s2s blocks --- backend/app/api/docs/llm/speech_to_speech.md | 4 +- backend/app/api/routes/llm_speech.py | 2 +- backend/app/api/routes/llm_speech_examples.md | 14 +- backend/app/models/llm/request.py | 2 +- backend/app/services/llm/chain/utils.py | 6 +- backend/test_sts_debug.py | 193 ++++++++++++++++++ 6 files changed, 207 insertions(+), 14 deletions(-) create mode 100644 backend/test_sts_debug.py diff --git a/backend/app/api/docs/llm/speech_to_speech.md b/backend/app/api/docs/llm/speech_to_speech.md index bad465b66..e4ad03e6f 100644 --- a/backend/app/api/docs/llm/speech_to_speech.md +++ 
b/backend/app/api/docs/llm/speech_to_speech.md @@ -62,7 +62,7 @@ Each callback includes: - `gpt-4o-mini` - OpenAI GPT-4o Mini (faster, lower cost) ### TTS (Text-to-Speech) -- `bulbul-v3` - Sarvam Bulbul V3 (**default**, natural Indian voices, MP3 output) +- `bulbul:v3` - Sarvam Bulbul V3 (**default**, natural Indian voices, MP3 output) - `gemini-2.5-pro-preview-tts` - Google Gemini 2.5 Pro (OGG OPUS output) ## Edge Cases & Error Handling @@ -99,7 +99,7 @@ curl -X POST https://api.kaapi.ai/llm/sts \ -H "Content-Type: application/json" \ -d @- < ChainBlock: params["language_code"] = ( language_code if language_code != "unknown" else "unknown" ) - params["mode"] = "transcription" + params["mode"] = "transcribe" elif provider == "google-native": # Google requires specific language code, fallback to en-IN if unknown params["language_code"] = ( @@ -146,7 +146,7 @@ def build_rag_block(model: LLMModel, knowledge_base_ids: list[str]) -> ChainBloc ) -def build_tts_block(model: TTSModel, language_code: str) -> ChainBlock: +def build_tts_block(model: TTSModel, language_code: str = "en-IN") -> ChainBlock: """Build TTS (Text-to-Speech) block configuration. 
Args: @@ -175,7 +175,7 @@ def build_tts_block(model: TTSModel, language_code: str) -> ChainBlock: # Add provider-specific parameters if provider == "sarvamai-native": - params["target_language_code"] = language_code + params["target_language_code"] = "en-IN" params["speaker"] = voice params["output_audio_codec"] = "mp3" # WhatsApp compatible elif provider == "google-native": diff --git a/backend/test_sts_debug.py b/backend/test_sts_debug.py new file mode 100644 index 000000000..f6dd92c10 --- /dev/null +++ b/backend/test_sts_debug.py @@ -0,0 +1,193 @@ +"""Debug script for STS endpoint and chain job execution.""" + +import logging +import sys +from sqlmodel import Session + +# Setup logging +logging.basicConfig( + level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) +logger = logging.getLogger(__name__) + + +def test_chain_job_creation(): + """Test if chain job can be created and queued.""" + from app.core.db import engine + from app.models.llm.request import ( + LLMChainRequest, + QueryParams, + AudioInput, + AudioContent, + ChainBlock, + LLMCallConfig, + ConfigBlob, + NativeCompletionConfig, + ) + from app.services.llm.jobs import start_chain_job + + print("\n" + "=" * 80) + print("STEP 1: Creating test chain request") + print("=" * 80) + + # Create a minimal valid chain request + test_request = LLMChainRequest( + query=QueryParams( + input=AudioInput( + type="audio", + content=AudioContent( + format="base64", + value="dGVzdF9hdWRpbw==", # base64 encoded "test_audio" + mime_type="audio/ogg", + ), + ) + ), + blocks=[ + ChainBlock( + config=LLMCallConfig( + blob=ConfigBlob( + completion=NativeCompletionConfig( + provider="sarvamai-native", + type="stt", + params={ + "model": "saarika:v1", + "language_code": "unknown", + "mode": "transcription", + }, + ) + ) + ), + intermediate_callback=True, + ) + ], + ) + + print(f"✅ Test request created with {len(test_request.blocks)} block(s)") + + print("\n" + "=" * 80) + print("STEP 2: Attempting 
to start chain job") + print("=" * 80) + + try: + with Session(engine) as session: + job_id = start_chain_job( + db=session, + request=test_request, + project_id=1, # Use test project ID + organization_id=1, # Use test org ID + ) + print(f"✅ Chain job created successfully!") + print(f" Job ID: {job_id}") + print(f" Check your Celery worker logs for task execution") + return job_id + except Exception as e: + print(f"❌ Failed to create chain job: {e}") + import traceback + + traceback.print_exc() + return None + + +def check_celery_connection(): + """Check if Celery is running and can receive tasks.""" + print("\n" + "=" * 80) + print("STEP 3: Checking Celery connection") + print("=" * 80) + + try: + from app.celery.celery_app import celery_app + + # Check if broker is reachable + inspector = celery_app.control.inspect() + active_workers = inspector.active() + + if active_workers: + print(f"✅ Celery workers are running:") + for worker_name, tasks in active_workers.items(): + print(f" - {worker_name}: {len(tasks)} active tasks") + else: + print("⚠️ No active Celery workers found!") + print(" Make sure to start the Celery worker with:") + print(" celery -A app.celery.celery_app worker --loglevel=info") + + # Check registered tasks + registered = inspector.registered() + if registered: + print(f"\n✅ Registered tasks:") + for worker_name, tasks in registered.items(): + print(f" Worker: {worker_name}") + for task in sorted(tasks): + if "high_priority" in task or "chain" in task.lower(): + print(f" - {task}") + + except Exception as e: + print(f"❌ Failed to check Celery: {e}") + import traceback + + traceback.print_exc() + + +def check_function_import(): + """Verify execute_chain_job can be imported.""" + print("\n" + "=" * 80) + print("STEP 4: Verifying execute_chain_job import") + print("=" * 80) + + try: + from app.services.llm.jobs import execute_chain_job + + print(f"✅ execute_chain_job is importable") + print(f" Parameters: 
{execute_chain_job.__code__.co_varnames[:6]}") + + # Try dynamic import (same way Celery does it) + import importlib + + module = importlib.import_module("app.services.llm.jobs") + func = getattr(module, "execute_chain_job") + print(f"✅ Dynamic import successful (same as Celery)") + + except Exception as e: + print(f"❌ Import failed: {e}") + import traceback + + traceback.print_exc() + + +if __name__ == "__main__": + print("\n" + "=" * 80) + print("STS ENDPOINT DEBUG SCRIPT") + print("=" * 80) + + check_function_import() + check_celery_connection() + job_id = test_chain_job_creation() + + if job_id: + print("\n" + "=" * 80) + print("DEBUGGING SUMMARY") + print("=" * 80) + print(f"✅ Chain job was queued successfully: {job_id}") + print(f"\nNext steps:") + print(f"1. Check your Celery worker logs for:") + print( + f" - Task app.celery.tasks.job_execution.execute_high_priority_task received" + ) + print(f" - Executing high_priority job {job_id}") + print(f" - Function path: app.services.llm.jobs.execute_chain_job") + print(f"\n2. If you don't see the task in worker logs:") + print(f" - Verify Celery broker (RabbitMQ/Redis) is running") + print(f" - Check broker connection in Celery worker startup logs") + print(f" - Restart Celery worker") + print(f"\n3. 
If task starts but fails:") + print(f" - Look for error in Celery worker logs") + print( + f" - Check database for job status: SELECT * FROM job WHERE id = '{job_id}';" + ) + else: + print("\n" + "=" * 80) + print("DEBUGGING SUMMARY") + print("=" * 80) + print(f"❌ Failed to queue chain job") + print(f" Check the error messages above for details") + + print("=" * 80 + "\n") From 56920b1bd1f63a52ee4089ca2cf478eeda80eb1d Mon Sep 17 00:00:00 2001 From: Prajna1999 Date: Mon, 9 Mar 2026 14:32:47 +0530 Subject: [PATCH 14/15] feat: detected lang in the webhook rsponse, context passing across links and BCP-47 code normalization --- backend/app/api/main.py | 5 +- backend/app/api/routes/llm_speech_examples.md | 470 -------------- .../api/routes/{llm_speech.py => llm_sts.py} | 52 +- backend/app/models/llm/request.py | 34 +- backend/app/services/llm/chain/chain.py | 30 +- backend/app/services/llm/chain/utils.py | 88 ++- backend/app/services/llm/jobs.py | 26 + backend/app/services/llm/providers/sai.py | 6 +- backend/app/tests/services/llm/test_sts.py | 572 ++++++++++++++++++ 9 files changed, 722 insertions(+), 561 deletions(-) delete mode 100644 backend/app/api/routes/llm_speech_examples.md rename backend/app/api/routes/{llm_speech.py => llm_sts.py} (72%) create mode 100644 backend/app/tests/services/llm/test_sts.py diff --git a/backend/app/api/main.py b/backend/app/api/main.py index deab6823e..858148064 100644 --- a/backend/app/api/main.py +++ b/backend/app/api/main.py @@ -7,11 +7,12 @@ config, doc_transformation_job, documents, + llm_sts, login, languages, llm, llm_chain, - llm_speech, + llm_sts, organization, openai_conversation, project, @@ -44,7 +45,7 @@ api_router.include_router(languages.router) api_router.include_router(llm.router) api_router.include_router(llm_chain.router) -api_router.include_router(llm_speech.router) +api_router.include_router(llm_sts.router) api_router.include_router(login.router) api_router.include_router(onboarding.router) 
api_router.include_router(openai_conversation.router) diff --git a/backend/app/api/routes/llm_speech_examples.md b/backend/app/api/routes/llm_speech_examples.md deleted file mode 100644 index e6d11e40f..000000000 --- a/backend/app/api/routes/llm_speech_examples.md +++ /dev/null @@ -1,470 +0,0 @@ -# Speech-to-Speech (STS) API Examples - -## Endpoint - -``` -POST /llm/sts -``` - -## Quick Start - -### Minimal Request (All Defaults) -```bash -curl -X POST https://api.kaapi.ai/llm/sts \ - -H "Authorization: Bearer YOUR_API_KEY" \ - -H "Content-Type: application/json" \ - -d '{ - "query": { - "type": "audio", - "content": { - "format": "base64", - "value": "BASE64_AUDIO_DATA", - "mime_type": "audio/ogg" - } - }, - "knowledge_base_ids": ["kb_abc123"], - "callback_url": "https://your-app.com/webhook" - }' -``` - -**Defaults Used:** -- Input language: Auto-detect (Sarvam STT) -- Output language: Hindi (same as detected input) -- STT: Sarvam Saaras V3 (auto language detection) -- LLM: OpenAI GPT-4o -- TTS: Sarvam Bulbul V3 - ---- - -## Full Configuration Example - -### Request with All Options -```json -{ - "query": { - "type": "audio", - "content": { - "format": "base64", - "value": "UklGRiQAAABXQVZFZm10...", - "mime_type": "audio/ogg" - } - }, - "knowledge_base_ids": ["kb_customer_support", "kb_product_info"], - "input_language": "hindi", - "output_language": "english", - "stt_model": "saaras:v3", - "llm_model": "gpt-4o", - "tts_model": "bulbul:v3", - "callback_url": "https://api.yourapp.com/webhooks/speech-response", - "request_metadata": { - "user_id": "user_123", - "session_id": "session_456", - "source": "whatsapp" - } -} -``` - -### Response (Immediate) -```json -{ - "success": true, - "data": { - "message": "Speech-to-speech processing initiated. You will receive intermediate callbacks for STT and LLM outputs, followed by the final callback with audio and text." 
- } -} -``` - ---- - -## Callback Sequence - -You'll receive **3 callbacks** to your webhook URL: - -### 1. STT Callback (Intermediate) -Sent after audio transcription completes. - -```json -{ - "success": true, - "data": { - "block_index": 1, - "total_blocks": 3, - "response": { - "provider_response_id": "stt_xyz789", - "provider": "google-native", - "model": "gemini-2.5-pro", - "output": { - "type": "text", - "content": { - "value": "मेरा अकाउंट बैलेंस क्या है?" - } - } - }, - "usage": { - "input_tokens": 0, - "output_tokens": 8, - "total_tokens": 8 - } - }, - "metadata": { - "speech_to_speech": true, - "input_language": "hi-IN", - "output_language": "en-IN", - "stt_model": "gemini-2.5-pro", - "llm_model": "gpt-4o", - "tts_model": "bulbul-v3", - "user_id": "user_123", - "session_id": "session_456", - "source": "whatsapp" - } -} -``` - -**Latency Calculation:** -``` -STT_latency = callback_1_timestamp - request_timestamp -``` - ---- - -### 2. LLM/RAG Callback (Intermediate) -Sent after knowledge base retrieval and response generation. - -```json -{ - "success": true, - "data": { - "block_index": 2, - "total_blocks": 3, - "response": { - "provider_response_id": "chatcmpl_abc123", - "conversation_id": null, - "provider": "openai", - "model": "gpt-4o", - "output": { - "type": "text", - "content": { - "value": "Your current account balance is ₹5,000. You have 3 transactions in the last month." - } - } - }, - "usage": { - "input_tokens": 250, - "output_tokens": 22, - "total_tokens": 272 - } - }, - "metadata": { - "speech_to_speech": true, - "user_id": "user_123", - "session_id": "session_456", - "source": "whatsapp" - } -} -``` - -**Latency Calculation:** -``` -LLM_latency = callback_2_timestamp - callback_1_timestamp -``` - ---- - -### 3. TTS Callback (Final) -Sent after text-to-speech conversion completes. This is your final output. 
- -```json -{ - "success": true, - "data": { - "response": { - "provider_response_id": "tts_def456", - "conversation_id": null, - "provider": "sarvamai-native", - "model": "bulbul:v1", - "output": { - "type": "audio", - "content": { - "format": "base64", - "value": "T2dnUwACAAAAAAAAAAAEBQ...", - "mime_type": "audio/ogg" - } - } - }, - "usage": { - "input_tokens": 22, - "output_tokens": 0, - "total_tokens": 22 - } - }, - "metadata": { - "speech_to_speech": true, - "output_language": "en-IN", - "user_id": "user_123", - "session_id": "session_456", - "source": "whatsapp" - } -} -``` - -**Latency Calculation:** -``` -TTS_latency = callback_3_timestamp - callback_2_timestamp -Total_latency = callback_3_timestamp - request_timestamp -``` - ---- - -## Error Handling Examples - -### Empty STT Output -If the audio contains no speech or is unintelligible: - -```json -{ - "success": false, - "error": "STT returned no transcription. The audio may be empty or unintelligible.", - "metadata": { - "speech_to_speech": true, - "user_id": "user_123" - } -} -``` - -### Invalid Audio Format -If the audio format is not supported: - -```json -{ - "success": false, - "error": "SarvamAI STT transcription failed: Invalid audio format. Supported formats: mp3, wav, ogg, opus, m4a", - "metadata": { - "speech_to_speech": true, - "user_id": "user_123" - } -} -``` - -### Audio Size Exceeds Limit -If TTS generates audio > 16MB (rare): - -```json -{ - "success": false, - "error": "TTS audio output exceeds WhatsApp size limit (16MB). 
Try reducing response length.", - "metadata": { - "speech_to_speech": true, - "user_id": "user_123" - } -} -``` - -### Knowledge Base Not Found -If specified knowledge base doesn't exist: - -```json -{ - "success": false, - "error": "Knowledge base 'kb_invalid' not found or not accessible.", - "metadata": { - "speech_to_speech": true, - "user_id": "user_123" - } -} -``` - ---- - -## Language-Specific Examples - -### English → English -```json -{ - "query": {...}, - "knowledge_base_ids": ["kb_123"], - "input_language": "english", - "output_language": "english", - "callback_url": "..." -} -``` - -### Hindi → English (Translation) -```json -{ - "query": {...}, - "knowledge_base_ids": ["kb_123"], - "input_language": "hindi", - "output_language": "english", - "callback_url": "..." -} -``` - -### Hinglish (Code-Switching) -```json -{ - "query": {...}, - "knowledge_base_ids": ["kb_123"], - "input_language": "hinglish", - "output_language": "hinglish", - "callback_url": "..." -} -``` -**Note:** Hinglish is treated as Hindi for model selection. - -### Regional Indian Languages -```json -{ - "query": {...}, - "knowledge_base_ids": ["kb_123"], - "input_language": "auto", // Auto-detect - "output_language": "odia", // Odia, Bengali, Punjabi, etc. - "callback_url": "..." 
-} -``` - -**Supported Regional Languages:** -- Bengali, Malayalam, Punjabi, Odia -- Assamese, Urdu, Nepali -- Konkani, Kashmiri, Sindhi, Sanskrit -- Santali, Manipuri, Bodo, Maithili, Dogri - ---- - -## Model Selection Guide - -### For Indian Languages (Recommended - Default) -```json -{ - "stt_model": "saaras:v3", - "llm_model": "gpt-4o", - "tts_model": "bulbul-v3" -} -``` -**Benefits:** -- Auto language detection (no need to specify language) -- Fastest processing -- Best accent handling for Indian languages -- Natural voice quality -- MP3 output (WhatsApp compatible) - -### For Maximum Accuracy -```json -{ - "stt_model": "gemini-2.5-pro", - "llm_model": "gpt-4o", - "tts_model": "gemini-2.5-pro-preview-tts" -} -``` -**Benefits:** Highest accuracy, best for complex queries, OGG OPUS output - -### For Cost Optimization -```json -{ - "stt_model": "saaras:v3", - "llm_model": "gpt-4o-mini", - "tts_model": "bulbul-v3" -} -``` -**Benefits:** Lower cost, still good quality, faster response - ---- - -## Integration Patterns - -### WhatsApp Bot Integration -```python -import base64 -import requests - -def handle_whatsapp_voice_message(audio_url, user_id): - # Download audio from WhatsApp - audio_response = requests.get(audio_url) - audio_base64 = base64.b64encode(audio_response.content).decode() - - # Send to Kaapi STS - response = requests.post( - "https://api.kaapi.ai/llm/sts", - headers={"Authorization": f"Bearer {API_KEY}"}, - json={ - "query": { - "type": "audio", - "content": { - "format": "base64", - "value": audio_base64, - "mime_type": "audio/ogg" - } - }, - "knowledge_base_ids": ["kb_customer_support"], - "callback_url": f"https://yourapp.com/webhook?user={user_id}", - "request_metadata": {"user_id": user_id} - } - ) - - return response.json() - -def handle_s2s_callback(callback_data): - """Handle the final TTS callback.""" - if not callback_data["success"]: - # Handle error - return - - # Extract final audio - audio_base64 = 
callback_data["data"]["response"]["output"]["content"]["value"] - audio_bytes = base64.b64decode(audio_base64) - - # Send back to WhatsApp user - send_whatsapp_voice(audio_bytes, user_id) -``` - ---- - -## Performance Benchmarks - -**Typical Latencies** (with Sarvam models, Hindi): -- STT: 1-2 seconds -- RAG: 2-4 seconds -- TTS: 1-2 seconds -- **Total: 4-8 seconds** - -**With Gemini models**: -- STT: 2-3 seconds -- RAG: 2-4 seconds -- TTS: 2-3 seconds -- **Total: 6-10 seconds** - ---- - -## Testing Tips - -1. **Test with Silent Audio**: Verify error handling for empty STT -2. **Test Different Formats**: OGG, MP3, WAV, M4A -3. **Test Language Mixing**: Hinglish, code-switching -4. **Test Long Audio**: >1 minute clips -5. **Load Test**: Multiple concurrent requests -6. **Monitor Latencies**: Track each block's timing -7. **Validate Audio Output**: Ensure < 16MB for WhatsApp - ---- - -## Troubleshooting - -### High Latency -- Check knowledge base size (larger = slower retrieval) -- Consider using faster models (gemini-flash, gpt-4o-mini) -- Verify callback URL response time - -### Poor Transcription Quality -- Ensure audio quality is good (no background noise) -- Try different STT models -- Check if language setting matches audio - -### Unnatural TTS Voice -- Try different TTS models -- Sarvam Bulbul is best for Indian accents -- Gemini is good for neutral accents - -### Callback Not Received -- Verify callback URL is publicly accessible -- Check for HTTPS (required) -- Ensure webhook can handle POST requests -- Check firewall settings diff --git a/backend/app/api/routes/llm_speech.py b/backend/app/api/routes/llm_sts.py similarity index 72% rename from backend/app/api/routes/llm_speech.py rename to backend/app/api/routes/llm_sts.py index e360eb4bd..6d1808c09 100644 --- a/backend/app/api/routes/llm_speech.py +++ b/backend/app/api/routes/llm_sts.py @@ -2,7 +2,7 @@ import logging -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, 
HTTPException from app.api.deps import AuthContextDep, SessionDep from app.api.permissions import Permission, require_permission @@ -13,11 +13,10 @@ SpeechToSpeechRequest, ) from app.services.llm.chain.utils import ( - LANGUAGE_CODES, + SUPPORTED_LANGUAGE_CODES, build_rag_block, build_stt_block, build_tts_block, - get_language_code, ) from app.services.llm.jobs import start_chain_job from app.utils import APIResponse, load_description, validate_callback_url @@ -61,28 +60,36 @@ def speech_to_speech( if request.callback_url: validate_callback_url(str(request.callback_url)) - # Validate and determine languages - if request.input_language and request.input_language != "auto": - if request.input_language not in LANGUAGE_CODES: - from fastapi import HTTPException - - raise HTTPException( - status_code=400, - detail=f"Unsupported input language: {request.input_language}. Supported: {', '.join(LANGUAGE_CODES.keys())}", - ) - - if request.output_language and request.output_language not in LANGUAGE_CODES: - from fastapi import HTTPException + # Validate BCP-47 language codes + if ( + request.input_language + and request.input_language not in SUPPORTED_LANGUAGE_CODES + ): + return APIResponse.failure_response( + error=f"Unsupported input language code: {request.input_language}. Supported: {', '.join(sorted(SUPPORTED_LANGUAGE_CODES))}", + metadata={"status_code": 400}, + ) - raise HTTPException( - status_code=400, - detail=f"Unsupported output language: {request.output_language}. Supported: {', '.join(LANGUAGE_CODES.keys())}", + if ( + request.output_language + and request.output_language not in SUPPORTED_LANGUAGE_CODES + ): + return APIResponse.failure_response( + error=f"Unsupported output language code: {request.output_language}. 
Supported: {', '.join(sorted(SUPPORTED_LANGUAGE_CODES))}", + metadata={"status_code": 400}, ) - input_lang_code = get_language_code(request.input_language) - output_lang_code = get_language_code( - request.output_language, default=request.input_language or "auto" - ) + # Determine language codes (already BCP-47, no conversion needed) + input_lang_code = request.input_language or "auto" + + # If output_language not set, default to input_language + # If input is "auto", use "{{detected}}" marker to signal TTS to use detected language + if request.output_language: + output_lang_code = request.output_language + elif input_lang_code == "auto": + output_lang_code = "{{detected}}" # Marker to use detected language from STT + else: + output_lang_code = input_lang_code logger.info( f"[speech_to_speech] Starting STS chain | " @@ -101,7 +108,6 @@ def speech_to_speech( build_tts_block(request.tts_model, output_lang_code), ] - # Add metadata to track STS-specific info metadata = request.request_metadata or {} metadata.update( { diff --git a/backend/app/models/llm/request.py b/backend/app/models/llm/request.py index 4ea288fe3..a9e59e6fb 100644 --- a/backend/app/models/llm/request.py +++ b/backend/app/models/llm/request.py @@ -93,6 +93,9 @@ class TTSLLMParams(SQLModel): class TextContent(SQLModel): format: Literal["text"] = "text" value: str = Field(..., description="Text content") + language_code: str | None = Field( + None, description="Optional detected language code in STT 'auto' mode" + ) class AudioContent(SQLModel): @@ -799,21 +802,21 @@ class SpeechToSpeechRequest(SQLModel): ..., min_length=1, description="Knowledge base IDs for RAG retrieval" ) - # Optional language config + # Optional language config (BCP-47 codes) input_language: str | None = Field( "auto", description=( - "Input language for STT (auto-detect by default). 
" - "Supported: auto, english, hindi, hinglish, bengali, kannada, malayalam, marathi, " - "odia, punjabi, tamil, telugu, gujarati, assamese, urdu, nepali, konkani, kashmiri, " - "sindhi, sanskrit, santali, manipuri, bodo, maithili, dogri" + "BCP-47 language code for STT input (auto-detect by default). " + "Supported codes: 'auto', 'en-IN', 'hi-IN', 'bn-IN', 'kn-IN', 'ml-IN', 'mr-IN', 'od-IN', " + "'pa-IN', 'ta-IN', 'te-IN', 'gu-IN', 'as-IN', 'ur-IN', 'ne-IN', 'kok-IN', 'ks-IN', " + "'sd-IN', 'sa-IN', 'sat-IN', 'mni-IN', 'brx-IN', 'mai-IN', 'doi-IN'" ), ) output_language: str | None = Field( None, description=( - "Output language for TTS (defaults to input_language if not specified). " - "Same language options as input_language." + "BCP-47 language code for TTS output (defaults to input_language if not specified). " + "Supported codes: same as input_language (except 'auto')." ), ) @@ -838,11 +841,18 @@ class SpeechToSpeechRequest(SQLModel): @model_validator(mode="after") def validate_languages(self): - """Validate language fields.""" - # Validation happens in the endpoint using LANGUAGE_CODES from utils - # This is just to ensure the fields are lowercase if provided + """Normalize BCP-47 language codes to standard format (e.g., 'hi-in' -> 'hi-IN').""" + # Normalize input_language if self.input_language and self.input_language != "auto": - self.input_language = self.input_language.lower() + # Normalize BCP-47: lowercase language, uppercase region (e.g., "hi-IN") + parts = self.input_language.split("-") + if len(parts) == 2: + self.input_language = f"{parts[0].lower()}-{parts[1].upper()}" + + # Normalize output_language if self.output_language: - self.output_language = self.output_language.lower() + parts = self.output_language.split("-") + if len(parts) == 2: + self.output_language = f"{parts[0].lower()}-{parts[1].upper()}" + return self diff --git a/backend/app/services/llm/chain/chain.py b/backend/app/services/llm/chain/chain.py index ad0503675..f0bac46f9 100644 
--- a/backend/app/services/llm/chain/chain.py +++ b/backend/app/services/llm/chain/chain.py @@ -36,6 +36,9 @@ class ChainContext: langfuse_credentials: dict[str, Any] | None = None request_metadata: dict | None = None intermediate_callback_flags: list[bool] = field(default_factory=list) + detected_language: str | None = ( + None # Stores language detected by STT for use by TTS + ) aggregated_usage: Usage = field( default_factory=lambda: Usage( input_tokens=0, @@ -45,17 +48,37 @@ class ChainContext: ) -def result_to_query(result: BlockResult) -> QueryParams: +def result_to_query( + result: BlockResult, context: ChainContext | None = None +) -> QueryParams: """Convert a block's output into the next block's QueryParams. Text output → TextInput query Audio output → AudioInput query + + Also preserves language_code from STT output for use by downstream TTS blocks. """ output = result.response.response.output if isinstance(output, TextOutput): + # Preserve language_code if present (from STT auto-detection) + language_code = ( + output.content.language_code + if hasattr(output.content, "language_code") + else None + ) + + # Store detected language in context for TTS to use + if context and language_code: + context.detected_language = language_code + logger.info(f"[result_to_query] Detected language: {language_code}") + return QueryParams( - input=TextInput(content=TextContent(value=output.content.value)) + input=TextInput( + content=TextContent( + value=output.content.value, language_code=language_code + ) + ) ) elif isinstance(output, AudioOutput): return QueryParams(input=AudioInput(content=output.content)) @@ -96,6 +119,7 @@ def execute(self, query: QueryParams) -> BlockResult: langfuse_credentials=self._context.langfuse_credentials, include_provider_raw_response=self._include_provider_raw_response, chain_id=self._context.chain_id, + detected_language=self._context.detected_language, ) @@ -132,6 +156,6 @@ def execute( return result if block is not self._blocks[-1]: - 
current_query = result_to_query(result) + current_query = result_to_query(result, self._context) return result diff --git a/backend/app/services/llm/chain/utils.py b/backend/app/services/llm/chain/utils.py index 37bf0942d..223530a6d 100644 --- a/backend/app/services/llm/chain/utils.py +++ b/backend/app/services/llm/chain/utils.py @@ -15,59 +15,46 @@ ) -# Supported languages for speech-to-speech (BCP-47 language codes) -LANGUAGE_CODES = { +# Supported BCP-47 language codes for speech-to-speech +# These are the valid values that can be used directly in API requests +SUPPORTED_LANGUAGE_CODES = { # Auto-detect - "auto": "unknown", # Sarvam auto-detection - # Primary Indian languages - "english": "en-IN", - "hindi": "hi-IN", - "hinglish": "hi-IN", # Code-switching, treat as Hindi - "bengali": "bn-IN", - "kannada": "kn-IN", - "malayalam": "ml-IN", - "marathi": "mr-IN", - "odia": "od-IN", - "punjabi": "pa-IN", - "tamil": "ta-IN", - "telugu": "te-IN", - "gujarati": "gu-IN", + "auto", # Auto-detection (maps to "unknown" for Sarvam) + "unknown", # Explicit unknown for Sarvam + # Primary Indian languages (BCP-47 codes) + "en-IN", # English + "hi-IN", # Hindi (also used for Hinglish/code-switching) + "bn-IN", # Bengali + "kn-IN", # Kannada + "ml-IN", # Malayalam + "mr-IN", # Marathi + "od-IN", # Odia + "pa-IN", # Punjabi + "ta-IN", # Tamil + "te-IN", # Telugu + "gu-IN", # Gujarati # Additional languages (saaras:v3) - "assamese": "as-IN", - "urdu": "ur-IN", - "nepali": "ne-IN", - "konkani": "kok-IN", - "kashmiri": "ks-IN", - "sindhi": "sd-IN", - "sanskrit": "sa-IN", - "santali": "sat-IN", - "manipuri": "mni-IN", - "bodo": "brx-IN", - "maithili": "mai-IN", - "dogri": "doi-IN", + "as-IN", # Assamese + "ur-IN", # Urdu + "ne-IN", # Nepali + "kok-IN", # Konkani + "ks-IN", # Kashmiri + "sd-IN", # Sindhi + "sa-IN", # Sanskrit + "sat-IN", # Santali + "mni-IN", # Manipuri + "brx-IN", # Bodo + "mai-IN", # Maithili + "doi-IN", # Dogri } -def get_language_code(language: str | None, 
default: str = "auto") -> str: - """Convert language name to BCP-47 language code. - - Args: - language: Language name (e.g., "hindi", "english", "auto") - default: Default language if not specified (default: "auto") - - Returns: - BCP-47 language code (e.g., "hi-IN", "en-IN", "unknown" for auto-detect) - """ - lang = (language or default).lower() - return LANGUAGE_CODES.get(lang, LANGUAGE_CODES["auto"]) - - def build_stt_block(model: STTModel, language_code: str) -> ChainBlock: """Build STT (Speech-to-Text) block configuration. Args: model: STT model enum - language_code: ISO language code (e.g., "hi-IN") + language_code: BCP-47 language code (e.g., "hi-IN", "en-IN") or "auto" for auto-detection Returns: ChainBlock configured for STT @@ -90,15 +77,15 @@ def build_stt_block(model: STTModel, language_code: str) -> ChainBlock: # Add provider-specific parameters if provider == "sarvamai-native": - # Use "unknown" for automatic language detection, or specific BCP-47 code + # Map "auto" to "unknown" for Sarvam auto-detection params["language_code"] = ( - language_code if language_code != "unknown" else "unknown" + "unknown" if language_code == "auto" else language_code ) params["mode"] = "transcribe" elif provider == "google-native": - # Google requires specific language code, fallback to en-IN if unknown + # Google requires specific language code, fallback to en-IN if auto/unknown params["language_code"] = ( - language_code if language_code != "unknown" else "en-IN" + "en-IN" if language_code in ("auto", "unknown") else language_code ) return ChainBlock( @@ -151,7 +138,7 @@ def build_tts_block(model: TTSModel, language_code: str = "en-IN") -> ChainBlock Args: model: TTS model enum - language_code: ISO language code (e.g., "hi-IN") + language_code: ISO language code (e.g., "hi-IN"), or "{{detected}}" to use language detected by STT Returns: ChainBlock configured for TTS @@ -175,9 +162,10 @@ def build_tts_block(model: TTSModel, language_code: str = "en-IN") -> ChainBlock 
# Add provider-specific parameters if provider == "sarvamai-native": - params["target_language_code"] = "en-IN" + # Use language_code (can be "{{detected}}" marker or actual code) + params["target_language_code"] = language_code params["speaker"] = voice - params["output_audio_codec"] = "mp3" # WhatsApp compatible + params["output_audio_codec"] = "opus" # WhatsApp compatible elif provider == "google-native": params["language_code"] = language_code params["audio_encoding"] = "OGG_OPUS" # WhatsApp compatible diff --git a/backend/app/services/llm/jobs.py b/backend/app/services/llm/jobs.py index 2a5f7dee2..621c8d912 100644 --- a/backend/app/services/llm/jobs.py +++ b/backend/app/services/llm/jobs.py @@ -23,6 +23,7 @@ ImageInput, KaapiCompletionConfig, LLMCallConfig, + NativeCompletionConfig, PDFInput, QueryParams, TextInput, @@ -335,10 +336,14 @@ def execute_llm_call( langfuse_credentials: dict | None, include_provider_raw_response: bool = False, chain_id: UUID | None = None, + detected_language: str | None = None, ) -> BlockResult: """Execute a single LLM call. Shared by /llm/call and /llm/chain. Returns BlockResult with response + usage on success, or error on failure. 
+ + Args: + detected_language: Language code detected by STT (used to replace {{detected}} marker in TTS) """ config_blob: ConfigBlob | None = None @@ -382,6 +387,27 @@ def execute_llm_call( request_metadata = {} request_metadata.setdefault("warnings", []).extend(warnings) + # Replace {{detected}} marker in TTS configs with actual detected language + if ( + isinstance(completion_config, NativeCompletionConfig) + and completion_config.type == "tts" + ): + params = completion_config.params + # Replace {{detected}} marker in any language-related params + for key in ["target_language_code", "language_code"]: + if key in params and params[key] == "{{detected}}": + if detected_language: + params[key] = detected_language + logger.info( + f"[execute_llm_call] Using detected language for TTS: {detected_language} | job_id={job_id}" + ) + else: + # Fallback to English if no language was detected + params[key] = "en-IN" + logger.warning( + f"[execute_llm_call] No language detected, falling back to en-IN for TTS | job_id={job_id}" + ) + resolved_config_blob = ConfigBlob( completion=completion_config, prompt_template=config_blob.prompt_template, diff --git a/backend/app/services/llm/providers/sai.py b/backend/app/services/llm/providers/sai.py index c2984e6aa..4f4170d7b 100644 --- a/backend/app/services/llm/providers/sai.py +++ b/backend/app/services/llm/providers/sai.py @@ -111,7 +111,10 @@ def _execute_stt( provider=provider_name, model=model, output=TextOutput( - content=TextContent(value=sarvam_response.transcript) + content=TextContent( + value=sarvam_response.transcript, + language_code=sarvam_response.language_code, + ) ), ), usage=Usage( @@ -184,6 +187,7 @@ def _execute_tts( target_language_code=target_language_code, model=model, speaker=speaker, + speech_sample_rate=16000, output_audio_codec=output_audio_codec, ) diff --git a/backend/app/tests/services/llm/test_sts.py b/backend/app/tests/services/llm/test_sts.py new file mode 100644 index 000000000..7822bdc00 --- 
/dev/null +++ b/backend/app/tests/services/llm/test_sts.py @@ -0,0 +1,572 @@ +""" +Test cases for Speech-to-Speech (STS) functionality. + +Tests cover: +1. Language detection and propagation through STT → RAG → TTS chain +2. BCP-47 language code validation +3. Real-world use cases (auto-detection, explicit languages, cross-language) +""" + +from unittest.mock import patch, MagicMock +from uuid import uuid4 + +import pytest +from fastapi.testclient import TestClient + +from app.models.llm.request import ( + AudioContent, + AudioInput, + LLMModel, + STTModel, + SpeechToSpeechRequest, + TTSModel, +) +from app.models.llm.response import ( + LLMCallResponse, + LLMResponse, + TextOutput, + TextContent as ResponseTextContent, + Usage, +) +from app.services.llm.chain.chain import ChainContext, result_to_query +from app.services.llm.chain.types import BlockResult +from app.services.llm.chain.utils import ( + SUPPORTED_LANGUAGE_CODES, + build_stt_block, + build_tts_block, +) + + +# ============================================================================ +# Unit Tests: Language Detection Flow +# ============================================================================ + + +class TestLanguageDetectionFlow: + """Test language detection and propagation through the chain.""" + + def test_result_to_query_preserves_language_code(self): + """STT output with language_code should be preserved when converting to next block's input.""" + # Simulate STT response with detected Hindi + stt_response = LLMCallResponse( + response=LLMResponse( + provider_response_id="stt-resp-1", + conversation_id=None, + model="saaras:v3", + provider="sarvamai-native", + output=TextOutput( + content=ResponseTextContent( + value="नमस्ते, आप कैसे हैं?", language_code="hi-IN" + ) + ), + ), + usage=Usage(input_tokens=0, output_tokens=10, total_tokens=10), + ) + + result = BlockResult(response=stt_response, usage=stt_response.usage) + context = ChainContext( + job_id=uuid4(), + chain_id=uuid4(), + 
project_id=1, + organization_id=1, + callback_url=None, + total_blocks=3, + ) + + # Convert STT output to RAG input + query = result_to_query(result, context) + + # Language code should be preserved + assert query.input.content.language_code == "hi-IN" + assert query.input.content.value == "नमस्ते, आप कैसे हैं?" + + # Context should store detected language for TTS + assert context.detected_language == "hi-IN" + + def test_result_to_query_without_language_code(self): + """RAG output without language_code should not break the chain.""" + # Simulate RAG response (no language_code) + rag_response = LLMCallResponse( + response=LLMResponse( + provider_response_id="rag-resp-1", + conversation_id=None, + model="gpt-4o", + provider="openai", + output=TextOutput( + content=ResponseTextContent( + value="The capital of India is New Delhi." + ) + ), + ), + usage=Usage(input_tokens=50, output_tokens=12, total_tokens=62), + ) + + result = BlockResult(response=rag_response, usage=rag_response.usage) + context = ChainContext( + job_id=uuid4(), + chain_id=uuid4(), + project_id=1, + organization_id=1, + callback_url=None, + total_blocks=3, + detected_language="hi-IN", # From previous STT block + ) + + # Convert RAG output to TTS input + query = result_to_query(result, context) + + # Should work fine even without language_code + assert query.input.content.value == "The capital of India is New Delhi." 
+ # Context should retain previously detected language + assert context.detected_language == "hi-IN" + + def test_detected_marker_replacement(self): + """{{detected}} marker in TTS should be replaced with actual detected language.""" + from app.services.llm.jobs import execute_llm_call + from app.models.llm.request import ( + LLMCallConfig, + ConfigBlob, + NativeCompletionConfig, + QueryParams, + ) + + config = LLMCallConfig( + blob=ConfigBlob( + completion=NativeCompletionConfig( + provider="sarvamai-native", + type="tts", + params={ + "model": "bulbul:v3", + "voice": "simran", + "target_language_code": "{{detected}}", # Marker to be replaced + "speaker": "simran", + "output_audio_codec": "opus", + }, + ) + ) + ) + + with patch("app.services.llm.jobs.get_llm_provider") as mock_provider, patch( + "app.services.llm.jobs.Session" + ): + mock_provider_instance = MagicMock() + mock_provider.return_value = mock_provider_instance + mock_provider_instance.execute.return_value = (None, "test error") + + # Call with detected_language + execute_llm_call( + config=config, + query=QueryParams(input="Test text"), + job_id=uuid4(), + project_id=1, + organization_id=1, + request_metadata=None, + langfuse_credentials=None, + detected_language="ta-IN", # Detected Tamil + ) + + # Verify {{detected}} was replaced with ta-IN + # The marker replacement happens in execute_llm_call before provider.execute is called + # So we check the modified config params + call_args = mock_provider_instance.execute.call_args + # execute is called with (completion_config, query, resolved_input, include_provider_raw_response) + if call_args: + completion_config = ( + call_args[1]["completion_config"] + if len(call_args) > 1 and "completion_config" in call_args[1] + else call_args[0][0] + ) + assert completion_config.params["target_language_code"] == "ta-IN" + + +# ============================================================================ +# Unit Tests: Block Building +# 
============================================================================ + + +class TestSTSBlockBuilding: + """Test STT and TTS block configuration.""" + + def test_build_stt_block_with_auto(self): + """Auto language should map to 'unknown' for Sarvam.""" + block = build_stt_block(STTModel.SARVAM, "auto") + + params = block.config.blob.completion.params + assert params["language_code"] == "unknown" + assert params["model"] == "saaras:v3" + assert params["mode"] == "transcribe" + + def test_build_stt_block_with_specific_language(self): + """Specific BCP-47 code should be used as-is.""" + block = build_stt_block(STTModel.SARVAM, "hi-IN") + + params = block.config.blob.completion.params + assert params["language_code"] == "hi-IN" + + def test_build_tts_block_with_detected_marker(self): + """TTS should accept {{detected}} marker for dynamic language.""" + block = build_tts_block(TTSModel.SARVAM, "{{detected}}") + + params = block.config.blob.completion.params + assert params["target_language_code"] == "{{detected}}" + assert params["model"] == "bulbul:v3" + + def test_build_tts_block_with_specific_language(self): + """TTS should accept specific BCP-47 codes.""" + block = build_tts_block(TTSModel.SARVAM, "ta-IN") + + params = block.config.blob.completion.params + assert params["target_language_code"] == "ta-IN" + + +# ============================================================================ +# Integration Tests: Speech-to-Speech Endpoint +# ============================================================================ + + +@pytest.fixture +def mock_audio_input(): + """Sample audio input (base64 encoded).""" + return AudioInput( + type="audio", + content=AudioContent( + format="base64", + value="SUQzBAAAAAAAI1RTU0UAAAAPAAADTGF2ZjU4Ljc2LjEwMAAAAAAAAAAAAAAA//...", + mime_type="audio/ogg", + ), + ) + + +@pytest.fixture +def knowledge_base_ids(): + """Sample knowledge base IDs.""" + return ["kb-india-facts", "kb-general-knowledge"] + + +class TestSpeechToSpeechEndpoint: 
+ """Test the /llm/sts endpoint with realistic scenarios.""" + + def test_sts_auto_detection_hindi_to_hindi( + self, + client: TestClient, + user_api_key_header: dict[str, str], + mock_audio_input, + knowledge_base_ids, + ): + """ + Real-world scenario: User sends Hindi voice note, expects Hindi response. + Most common use case - auto-detect input, same language output. + """ + with patch("app.api.routes.llm_sts.start_chain_job") as mock_start_job: + payload = SpeechToSpeechRequest( + query=mock_audio_input, + knowledge_base_ids=knowledge_base_ids, + input_language="auto", # Auto-detect + output_language=None, # Should default to detected language + stt_model=STTModel.SARVAM, + tts_model=TTSModel.SARVAM, + llm_model=LLMModel.GPT4O, + callback_url="https://example.com/callback", + ) + + response = client.post( + "api/v1/llm/sts", + json=payload.model_dump(mode="json"), + headers=user_api_key_header, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + assert "Speech-to-speech processing initiated" in data["data"]["message"] + + # Verify job was started + mock_start_job.assert_called_once() + + def test_sts_explicit_tamil_to_tamil( + self, + client: TestClient, + user_api_key_header: dict[str, str], + mock_audio_input, + knowledge_base_ids, + ): + """ + Scenario: Tamil user explicitly sets language to avoid auto-detection. + Use case: Better accuracy when language is known. 
+ """ + with patch("app.api.routes.llm_sts.start_chain_job") as mock_start_job: + payload = SpeechToSpeechRequest( + query=mock_audio_input, + knowledge_base_ids=knowledge_base_ids, + input_language="ta-IN", + output_language="ta-IN", + stt_model=STTModel.SARVAM, + tts_model=TTSModel.SARVAM, + llm_model=LLMModel.GPT4O_MINI, + callback_url="https://example.com/callback", + ) + + response = client.post( + "api/v1/llm/sts", + json=payload.model_dump(mode="json"), + headers=user_api_key_header, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + mock_start_job.assert_called_once() + + def test_sts_cross_language_hindi_to_english( + self, + client: TestClient, + user_api_key_header: dict[str, str], + mock_audio_input, + knowledge_base_ids, + ): + """ + Scenario: User speaks Hindi but wants response in English. + Use case: Language learning, multilingual support. + """ + with patch("app.api.routes.llm_sts.start_chain_job") as mock_start_job: + payload = SpeechToSpeechRequest( + query=mock_audio_input, + knowledge_base_ids=knowledge_base_ids, + input_language="hi-IN", + output_language="en-IN", # Respond in English + stt_model=STTModel.SARVAM, + tts_model=TTSModel.SARVAM, + llm_model=LLMModel.GPT4O, + callback_url="https://example.com/callback", + ) + + response = client.post( + "api/v1/llm/sts", + json=payload.model_dump(mode="json"), + headers=user_api_key_header, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + mock_start_job.assert_called_once() + + def test_sts_invalid_input_language_code( + self, + client: TestClient, + user_api_key_header: dict[str, str], + mock_audio_input, + knowledge_base_ids, + ): + """ + Error case: User provides invalid BCP-47 code. + Should reject with clear error message. 
+ """ + payload = SpeechToSpeechRequest( + query=mock_audio_input, + knowledge_base_ids=knowledge_base_ids, + input_language="hindi", # Invalid - should be 'hi-IN' + output_language="en-IN", + stt_model=STTModel.SARVAM, + tts_model=TTSModel.SARVAM, + llm_model=LLMModel.GPT4O, + ) + + response = client.post( + "api/v1/llm/sts", + json=payload.model_dump(mode="json"), + headers=user_api_key_header, + ) + + assert response.status_code == 200 # API returns 200 with error in body + data = response.json() + assert data["success"] is False + assert "Unsupported input language code" in data["error"] + assert "hindi" in data["error"] + + def test_sts_invalid_output_language_code( + self, + client: TestClient, + user_api_key_header: dict[str, str], + mock_audio_input, + knowledge_base_ids, + ): + """ + Error case: Invalid output language code. + """ + payload = SpeechToSpeechRequest( + query=mock_audio_input, + knowledge_base_ids=knowledge_base_ids, + input_language="hi-IN", + output_language="french", # Invalid - should be BCP-47 + stt_model=STTModel.SARVAM, + tts_model=TTSModel.SARVAM, + llm_model=LLMModel.GPT4O, + ) + + response = client.post( + "api/v1/llm/sts", + json=payload.model_dump(mode="json"), + headers=user_api_key_header, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is False + assert "Unsupported output language code" in data["error"] + + def test_sts_case_insensitive_language_codes( + self, + client: TestClient, + user_api_key_header: dict[str, str], + mock_audio_input, + knowledge_base_ids, + ): + """ + User-friendly case: BCP-47 codes should be case-insensitive. + 'hi-in' should be normalized to 'hi-IN'. 
+ """ + with patch("app.api.routes.llm_sts.start_chain_job") as mock_start_job: + payload = SpeechToSpeechRequest( + query=mock_audio_input, + knowledge_base_ids=knowledge_base_ids, + input_language="hi-in", # Lowercase + output_language="en-in", # Lowercase + stt_model=STTModel.SARVAM, + tts_model=TTSModel.SARVAM, + llm_model=LLMModel.GPT4O, + ) + + response = client.post( + "api/v1/llm/sts", + json=payload.model_dump(mode="json"), + headers=user_api_key_header, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + mock_start_job.assert_called_once() + + def test_sts_regional_languages( + self, + client: TestClient, + user_api_key_header: dict[str, str], + mock_audio_input, + knowledge_base_ids, + ): + """ + Test support for regional Indian languages. + Scenario: Malayalam speaker from Kerala. + """ + with patch("app.api.routes.llm_sts.start_chain_job") as mock_start_job: + payload = SpeechToSpeechRequest( + query=mock_audio_input, + knowledge_base_ids=knowledge_base_ids, + input_language="ml-IN", # Malayalam + output_language="ml-IN", + stt_model=STTModel.SARVAM, + tts_model=TTSModel.SARVAM, + llm_model=LLMModel.GPT4O_MINI, + ) + + response = client.post( + "api/v1/llm/sts", + json=payload.model_dump(mode="json"), + headers=user_api_key_header, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + mock_start_job.assert_called_once() + + def test_sts_without_callback_url( + self, + client: TestClient, + user_api_key_header: dict[str, str], + mock_audio_input, + knowledge_base_ids, + ): + """ + Callback URL is optional - job should still start. 
+ """ + with patch("app.api.routes.llm_sts.start_chain_job") as mock_start_job: + payload = SpeechToSpeechRequest( + query=mock_audio_input, + knowledge_base_ids=knowledge_base_ids, + input_language="auto", + stt_model=STTModel.SARVAM, + tts_model=TTSModel.SARVAM, + llm_model=LLMModel.GPT4O, + # No callback_url + ) + + response = client.post( + "api/v1/llm/sts", + json=payload.model_dump(mode="json"), + headers=user_api_key_header, + ) + + assert response.status_code == 200 + data = response.json() + assert data["success"] is True + mock_start_job.assert_called_once() + + +# ============================================================================ +# Unit Tests: Language Code Validation +# ============================================================================ + + +class TestLanguageCodeSupport: + """Verify all supported BCP-47 codes are valid.""" + + def test_all_supported_codes_are_valid(self): + """All codes in SUPPORTED_LANGUAGE_CODES should be valid BCP-47 format.""" + valid_codes = { + "auto", + "unknown", + "en-IN", + "hi-IN", + "bn-IN", + "kn-IN", + "ml-IN", + "mr-IN", + "od-IN", + "pa-IN", + "ta-IN", + "te-IN", + "gu-IN", + "as-IN", + "ur-IN", + "ne-IN", + "kok-IN", + "ks-IN", + "sd-IN", + "sa-IN", + "sat-IN", + "mni-IN", + "brx-IN", + "mai-IN", + "doi-IN", + } + + assert SUPPORTED_LANGUAGE_CODES == valid_codes + + def test_major_indian_languages_supported(self): + """Verify major Indian languages are supported.""" + major_languages = { + "hi-IN", # Hindi + "bn-IN", # Bengali + "te-IN", # Telugu + "mr-IN", # Marathi + "ta-IN", # Tamil + "ur-IN", # Urdu + "gu-IN", # Gujarati + "kn-IN", # Kannada + "ml-IN", # Malayalam + "pa-IN", # Punjabi + } + + assert major_languages.issubset(SUPPORTED_LANGUAGE_CODES) From 96ea78e6749b03397ecdba85355ab361e487cf78 Mon Sep 17 00:00:00 2001 From: Prajna1999 Date: Mon, 9 Mar 2026 18:45:03 +0530 Subject: [PATCH 15/15] chore: docs --- backend/app/api/routes/llm_sts.py | 7 ++----- 1 file changed, 2 insertions(+), 5 
deletions(-) diff --git a/backend/app/api/routes/llm_sts.py b/backend/app/api/routes/llm_sts.py index 6d1808c09..d379140eb 100644 --- a/backend/app/api/routes/llm_sts.py +++ b/backend/app/api/routes/llm_sts.py @@ -46,12 +46,9 @@ def speech_to_speech( 3. TTS (Text-to-Speech) - Converts response back to audio Input: Voice note (WhatsApp compatible) - Output: Voice note + text (via callback) + Output 1: Voice note + Output 2: text (via intermediate callback) - Edge cases: - - Empty STT output: Chain fails with clear error - - Audio > 16MB: TTS provider will fail (caught and reported) - - Invalid audio format: STT provider will fail (caught and reported) """ project_id = _current_user.project_.id organization_id = _current_user.organization_.id