diff --git a/app/features/agents/agents/experiment.py b/app/features/agents/agents/experiment.py index 19c5310a..47d594fd 100644 --- a/app/features/agents/agents/experiment.py +++ b/app/features/agents/agents/experiment.py @@ -13,7 +13,7 @@ from typing import Any, Literal import structlog -from pydantic_ai import Agent, RunContext +from pydantic_ai import Agent, PromptedOutput, RunContext from app.features.agents.agents.base import ( SAFETY_INSTRUCTIONS, @@ -90,7 +90,10 @@ def create_experiment_agent() -> Agent[AgentDeps, ExperimentReport]: agent: Agent[AgentDeps, ExperimentReport] = Agent( model=model, deps_type=AgentDeps, - output_type=ExperimentReport, + # PromptedOutput puts the JSON schema in the prompt and parses the + # model's text reply, instead of the default ToolOutput mode which + # weaker/local models fail to satisfy (issue #173). + output_type=PromptedOutput(ExperimentReport), system_prompt=EXPERIMENT_SYSTEM_PROMPT, # Apply the configured agent_retry_attempts. Without this PydanticAI # defaults to 1, and weaker models fail structured-output validation. diff --git a/app/features/agents/agents/rag_assistant.py b/app/features/agents/agents/rag_assistant.py index b935a856..5e448510 100644 --- a/app/features/agents/agents/rag_assistant.py +++ b/app/features/agents/agents/rag_assistant.py @@ -12,7 +12,7 @@ from typing import Any import structlog -from pydantic_ai import Agent, RunContext +from pydantic_ai import Agent, PromptedOutput, RunContext from app.core.config import get_settings from app.features.agents.agents.base import ( @@ -85,7 +85,10 @@ def create_rag_assistant_agent() -> Agent[AgentDeps, RAGAnswer]: agent: Agent[AgentDeps, RAGAnswer] = Agent( model=model, deps_type=AgentDeps, - output_type=RAGAnswer, + # PromptedOutput puts the JSON schema in the prompt and parses the + # model's text reply, instead of the default ToolOutput mode which + # weaker/local models fail to satisfy (issue #173). + output_type=PromptedOutput(RAGAnswer), system_prompt=RAG_SYSTEM_PROMPT, # Apply the configured agent_retry_attempts. Without this PydanticAI # defaults to 1, and weaker models fail structured-output validation. diff --git a/app/features/agents/service.py b/app/features/agents/service.py index 45d7f805..751c984c 100644 --- a/app/features/agents/service.py +++ b/app/features/agents/service.py @@ -15,6 +15,7 @@ import asyncio import uuid from collections.abc import AsyncIterator +from contextlib import AbstractContextManager from datetime import UTC, datetime, timedelta from typing import Any, Literal, cast @@ -58,6 +59,22 @@ class NoApprovalPendingError(ValueError): pass +def _sequential_tool_execution() -> AbstractContextManager[None]: + """Run an agent turn's tool calls one at a time, never concurrently. + + Every tool in a run shares the single ``AgentDeps.db`` ``AsyncSession``, + and SQLAlchemy forbids concurrent operations on one session. PydanticAI's + default parallel tool execution therefore raises ``InvalidRequestError`` + whenever a model emits more than one DB-touching tool call in a turn + (issue #172). + + Both :meth:`AgentService.chat` and :meth:`AgentService.stream_chat` wrap + their agent run in this context, so the execution-mode policy lives in + exactly one place. + """ + return Agent.parallel_tool_call_execution_mode("sequential") + + class AgentService: """Service for managing agent sessions and interactions. @@ -250,14 +267,15 @@ async def chat( ) try: - result = await asyncio.wait_for( - agent.run( - message, - deps=deps, - message_history=message_history, - ), - timeout=self.settings.agent_timeout_seconds, - ) + with _sequential_tool_execution(): + result = await asyncio.wait_for( + agent.run( + message, + deps=deps, + message_history=message_history, + ), + timeout=self.settings.agent_timeout_seconds, + ) except TimeoutError as e: raise TimeoutError( f"Agent response timed out after {self.settings.agent_timeout_seconds} seconds" @@ -440,144 +458,151 @@ async def stream_chat( # Stream the response try: - async with asyncio.timeout(self.settings.agent_timeout_seconds): - async with agent.run_stream( - message, - deps=deps, - message_history=message_history, - ) as result: - try: - async for text in result.stream_text(): - yield StreamEvent( - event_type="text_delta", - data={"delta": text}, - timestamp=datetime.now(UTC), + with _sequential_tool_execution(): + async with asyncio.timeout(self.settings.agent_timeout_seconds): + async with agent.run_stream( + message, + deps=deps, + message_history=message_history, + ) as result: + try: + async for text in result.stream_text(): + yield StreamEvent( + event_type="text_delta", + data={"delta": text}, + timestamp=datetime.now(UTC), + ) + except Exception as e: + # Structured output agents (output_type=...) cannot stream raw text deltas. + # In that case we skip delta streaming and only emit the final complete event. + logger.info( + "agents.stream_chat_text_delta_unavailable", + session_id=session_id, + error=str(e), + error_type=type(e).__name__, ) - except Exception as e: - # Structured output agents (output_type=...) cannot stream raw text deltas. - # In that case we skip delta streaming and only emit the final complete event. - logger.info( - "agents.stream_chat_text_delta_unavailable", - session_id=session_id, - error=str(e), - error_type=type(e).__name__, + + # Get final result and update session + # NOTE: PydanticAI v1.48 exposes get_output() on StreamedRunResult. + final_result: Any = await result.get_output() + usage = result.usage() + + session.message_history = self._serialize_messages(result.all_messages()) + session.total_tokens_used += usage.total_tokens or 0 + session.tool_calls_count += deps.tool_call_count + session.last_activity = datetime.now(UTC) + session.expires_at = session.last_activity + timedelta( + minutes=self.settings.agent_session_ttl_minutes ) - # Get final result and update session - # NOTE: PydanticAI v1.48 exposes get_output() on StreamedRunResult. - final_result: Any = await result.get_output() - usage = result.usage() - - session.message_history = self._serialize_messages(result.all_messages()) - session.total_tokens_used += usage.total_tokens or 0 - session.tool_calls_count += deps.tool_call_count - session.last_activity = datetime.now(UTC) - session.expires_at = session.last_activity + timedelta( - minutes=self.settings.agent_session_ttl_minutes - ) - - await db.flush() - - # Check for pending approval actions (mirror chat() logic) - pending_action = None - pending_approval = False - stream_now = datetime.now(UTC) - - # Check for pending_action in result data (primary trigger) - if hasattr(final_result, "pending_action") and final_result.pending_action: - pending_approval = True - pending_action_data = final_result.pending_action - # Extract action details - support both dict and object with attributes - if isinstance(pending_action_data, dict): - action_type = pending_action_data.get("action_type", "unknown") - arguments = pending_action_data.get("arguments", {}) - description = pending_action_data.get( - "description", f"Agent requested approval for {action_type}" - ) - else: - action_type = getattr(pending_action_data, "action_type", "unknown") - arguments = getattr(pending_action_data, "arguments", {}) - description = getattr( - pending_action_data, - "description", - f"Agent requested approval for {action_type}", + await db.flush() + + # Check for pending approval actions (mirror chat() logic) + pending_action = None + pending_approval = False + stream_now = datetime.now(UTC) + + # Check for pending_action in result data (primary trigger) + if hasattr(final_result, "pending_action") and final_result.pending_action: + pending_approval = True + pending_action_data = final_result.pending_action + # Extract action details - support both dict and object with attributes + if isinstance(pending_action_data, dict): + action_type = pending_action_data.get("action_type", "unknown") + arguments = pending_action_data.get("arguments", {}) + description = pending_action_data.get( + "description", f"Agent requested approval for {action_type}" + ) + else: + action_type = getattr(pending_action_data, "action_type", "unknown") + arguments = getattr(pending_action_data, "arguments", {}) + description = getattr( + pending_action_data, + "description", + f"Agent requested approval for {action_type}", + ) + + session.pending_action = { + "action_id": uuid.uuid4().hex[:16], + "action_type": action_type, + "description": description, + "arguments": arguments, + "created_at": stream_now.isoformat(), + "expires_at": ( + stream_now + + timedelta( + minutes=self.settings.agent_approval_timeout_minutes + ) + ).isoformat(), + } + session.status = SessionStatus.AWAITING_APPROVAL.value + pending_action = self._format_pending_action(session.pending_action) + # Fallback: check approval_required flag (legacy trigger) + elif ( + hasattr(final_result, "approval_required") + and final_result.approval_required + ): + pending_approval = True + session.pending_action = { + "action_id": uuid.uuid4().hex[:16], + "action_type": "unknown", + "description": "Agent requested approval for an action", + "arguments": {}, + "created_at": stream_now.isoformat(), + "expires_at": ( + stream_now + + timedelta( + minutes=self.settings.agent_approval_timeout_minutes + ) + ).isoformat(), + } + session.status = SessionStatus.AWAITING_APPROVAL.value + pending_action = self._format_pending_action(session.pending_action) + + await db.flush() + + # If approval is required, emit approval_required event + if pending_approval and pending_action: + yield StreamEvent( + event_type="approval_required", + data={ + "action": pending_action, + "message": "Human approval required before proceeding.", + }, + timestamp=stream_now, ) - session.pending_action = { - "action_id": uuid.uuid4().hex[:16], - "action_type": action_type, - "description": description, - "arguments": arguments, - "created_at": stream_now.isoformat(), - "expires_at": ( - stream_now - + timedelta(minutes=self.settings.agent_approval_timeout_minutes) - ).isoformat(), - } - session.status = SessionStatus.AWAITING_APPROVAL.value - pending_action = self._format_pending_action(session.pending_action) - # Fallback: check approval_required flag (legacy trigger) - elif ( - hasattr(final_result, "approval_required") - and final_result.approval_required - ): - pending_approval = True - session.pending_action = { - "action_id": uuid.uuid4().hex[:16], - "action_type": "unknown", - "description": "Agent requested approval for an action", - "arguments": {}, - "created_at": stream_now.isoformat(), - "expires_at": ( - stream_now - + timedelta(minutes=self.settings.agent_approval_timeout_minutes) - ).isoformat(), - } - session.status = SessionStatus.AWAITING_APPROVAL.value - pending_action = self._format_pending_action(session.pending_action) - - await db.flush() - - # If approval is required, emit approval_required event - if pending_approval and pending_action: + # Yield completion event + response_message: str = "No response generated." + if final_result: + if hasattr(final_result, "answer") and final_result.answer: + response_message = str(final_result.answer) + elif hasattr(final_result, "summary") and final_result.summary: + response_message = str(final_result.summary) + elif ( + hasattr(final_result, "recommendations") + and final_result.recommendations + ): + recommendations = final_result.recommendations + if isinstance(recommendations, list) and recommendations: + response_message = "\n".join( + str(item) for item in recommendations + ) + else: + response_message = str(final_result) + else: + response_message = str(final_result) + yield StreamEvent( - event_type="approval_required", + event_type="complete", data={ - "action": pending_action, - "message": "Human approval required before proceeding.", + "message": response_message, + "tokens_used": usage.total_tokens or 0, + "tool_calls_count": deps.tool_call_count, + "pending_approval": pending_approval, }, - timestamp=stream_now, + timestamp=datetime.now(UTC), ) - - # Yield completion event - response_message: str = "No response generated." - if final_result: - if hasattr(final_result, "answer") and final_result.answer: - response_message = str(final_result.answer) - elif hasattr(final_result, "summary") and final_result.summary: - response_message = str(final_result.summary) - elif ( - hasattr(final_result, "recommendations") - and final_result.recommendations - ): - recommendations = final_result.recommendations - if isinstance(recommendations, list) and recommendations: - response_message = "\n".join(str(item) for item in recommendations) - else: - response_message = str(final_result) - else: - response_message = str(final_result) - - yield StreamEvent( - event_type="complete", - data={ - "message": response_message, - "tokens_used": usage.total_tokens or 0, - "tool_calls_count": deps.tool_call_count, - "pending_approval": pending_approval, - }, - timestamp=datetime.now(UTC), - ) except TimeoutError as e: raise TimeoutError( f"Agent response timed out after {self.settings.agent_timeout_seconds} seconds" diff --git a/app/features/agents/tests/test_base.py b/app/features/agents/tests/test_base.py index 9652d3b6..4415eb11 100644 --- a/app/features/agents/tests/test_base.py +++ b/app/features/agents/tests/test_base.py @@ -1,8 +1,11 @@ """Unit tests for agent base helpers (Ollama-aware model factory).""" from collections.abc import Iterator +from unittest.mock import AsyncMock import pytest +from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart +from pydantic_ai.models.function import AgentInfo, FunctionModel from pydantic_ai.models.openai import OpenAIChatModel from app.core.config import get_settings @@ -13,6 +16,8 @@ ) from app.features.agents.agents.experiment import create_experiment_agent from app.features.agents.agents.rag_assistant import create_rag_assistant_agent +from app.features.agents.deps import AgentDeps +from app.features.agents.schemas import ExperimentReport, RAGAnswer @pytest.fixture(autouse=True) @@ -82,3 +87,80 @@ def test_rag_assistant_agent_applies_retry_attempts(): assert agent._max_output_retries == 4 assert agent._max_tool_retries == 4 + + +def test_experiment_agent_uses_prompted_output() -> None: + """The experiment agent runs in PromptedOutput mode, not default ToolOutput. + + Regression for issue #173: weaker/local models answer in prose and cannot + satisfy the tool-call output contract that the default ToolOutput mode + requires. PromptedOutput puts the JSON schema in the prompt and parses the + model's text reply instead. + + This is asserted behaviorally via the public ``FunctionModel`` test double: + PromptedOutput mode registers no ``final_result`` output tool, and a + plain-text JSON reply is still parsed into a valid ``ExperimentReport``. + """ + settings = get_settings() + settings.agent_default_model = "ollama:llama3.1" + agent = create_experiment_agent() + + report_json = ExperimentReport( + run_id="run-1", + status="success", + summary="seasonal_naive wins", + metrics={"mae": 8.9}, + recommendations=["deploy seasonal_naive"], + ).model_dump_json() + + captured: dict[str, list[str]] = {} + + def respond(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + captured["output_tools"] = [tool.name for tool in info.output_tools] + return ModelResponse(parts=[TextPart(content=report_json)]) + + result = agent.run_sync( + "Run an experiment", + model=FunctionModel(respond), + deps=AgentDeps(db=AsyncMock(), session_id="test-prompted-output"), + ) + + # PromptedOutput mode registers no structured-output tool... + assert captured["output_tools"] == [] + # ...and the plain-text JSON reply is parsed into the structured type. + assert isinstance(result.output, ExperimentReport) + assert result.output.summary == "seasonal_naive wins" + + +def test_rag_assistant_agent_uses_prompted_output() -> None: + """The RAG assistant agent runs in PromptedOutput mode (issue #173). + + Mirrors test_experiment_agent_uses_prompted_output: no ``final_result`` + output tool is registered, and a plain-text JSON reply is parsed into a + valid ``RAGAnswer``. + """ + settings = get_settings() + settings.agent_default_model = "ollama:llama3.1" + agent = create_rag_assistant_agent() + + answer_json = RAGAnswer( + answer="The forecast API supports naive and seasonal_naive models.", + confidence="high", + sources=[{"source_path": "docs/api.md", "relevance": 0.9}], + ).model_dump_json() + + captured: dict[str, list[str]] = {} + + def respond(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + captured["output_tools"] = [tool.name for tool in info.output_tools] + return ModelResponse(parts=[TextPart(content=answer_json)]) + + result = agent.run_sync( + "What models does the forecast API support?", + model=FunctionModel(respond), + deps=AgentDeps(db=AsyncMock(), session_id="test-prompted-output"), + ) + + assert captured["output_tools"] == [] + assert isinstance(result.output, RAGAnswer) + assert result.output.confidence == "high" diff --git a/app/features/agents/tests/test_service.py b/app/features/agents/tests/test_service.py index 92888355..08064495 100644 --- a/app/features/agents/tests/test_service.py +++ b/app/features/agents/tests/test_service.py @@ -1,11 +1,13 @@ """Unit tests for agent service.""" import json +from collections.abc import AsyncIterator from datetime import UTC, datetime, timedelta from typing import Any from unittest.mock import AsyncMock, MagicMock, patch import pytest +from pydantic_ai import Agent from pydantic_ai.exceptions import UnexpectedModelBehavior from pydantic_ai.messages import ( ModelMessage, @@ -333,6 +335,49 @@ async def test_chat_model_misbehavior_returns_friendly_message( assert "invalid tool call" in response.message assert "exceeded max retries" not in response.message + @pytest.mark.asyncio + async def test_chat_runs_tools_sequentially( + self, + sample_active_session: AgentSession, + sample_experiment_report: ExperimentReport, + ) -> None: + """chat() must run the agent under sequential tool execution. + + Regression for issue #172: every tool shares the single AgentDeps.db + AsyncSession, so concurrent tool calls raised SQLAlchemy's + InvalidRequestError. The service must enter PydanticAI's public + ``Agent.parallel_tool_call_execution_mode("sequential")`` context + around the agent run. + """ + service = AgentService() + mock_db = AsyncMock() + + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = sample_active_session + mock_db.execute.return_value = mock_result + + run_result = MagicMock() + run_result.output = sample_experiment_report + usage = MagicMock() + usage.total_tokens = 1 + run_result.usage.return_value = usage + run_result.all_messages.return_value = [] + + mock_agent = MagicMock() + mock_agent.run = AsyncMock(return_value=run_result) + + with ( + patch.object(service, "_get_agent", return_value=mock_agent), + patch.object(Agent, "parallel_tool_call_execution_mode") as mock_mode, + ): + await service.chat( + db=mock_db, + session_id=sample_active_session.session_id, + message="Run a backtest", + ) + + mock_mode.assert_called_once_with("sequential") + class TestAgentServiceStreamChat: """Tests for streaming chat functionality.""" @@ -385,6 +430,60 @@ async def __aexit__(self, *exc: object) -> bool: assert events[0].data["error_type"] == "model_behavior_error" assert "exceeded max retries" not in events[0].data["error"] + @pytest.mark.asyncio + async def test_stream_chat_runs_tools_sequentially( + self, + sample_active_session: AgentSession, + ) -> None: + """stream_chat() must also run the agent under sequential tool execution. + + Mirrors test_chat_runs_tools_sequentially for the streaming path so a + future change to only one code path cannot silently reintroduce the + concurrent-session bug from issue #172. + """ + service = AgentService() + mock_db = AsyncMock() + + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = sample_active_session + mock_db.execute.return_value = mock_result + + class _StubStream: + """Minimal async-context-manager stand-in for agent.run_stream(...).""" + + async def __aenter__(self) -> MagicMock: + stream = MagicMock() + + async def _stream_text() -> AsyncIterator[str]: + yield "hello" + + stream.stream_text = _stream_text + stream.get_output = AsyncMock(return_value=None) + usage = MagicMock() + usage.total_tokens = 1 + stream.usage.return_value = usage + stream.all_messages.return_value = [] + return stream + + async def __aexit__(self, *exc: object) -> bool: + return False + + mock_agent = MagicMock() + mock_agent.run_stream = MagicMock(return_value=_StubStream()) + + with ( + patch.object(service, "_get_agent", return_value=mock_agent), + patch.object(Agent, "parallel_tool_call_execution_mode") as mock_mode, + ): + async for _event in service.stream_chat( + db=mock_db, + session_id=sample_active_session.session_id, + message="Run a backtest", + ): + pass + + mock_mode.assert_called_once_with("sequential") + class TestAgentServiceApproval: """Tests for approval workflow."""