diff --git a/app/features/agents/agents/base.py b/app/features/agents/agents/base.py index eef89e07..abade6c2 100644 --- a/app/features/agents/agents/base.py +++ b/app/features/agents/agents/base.py @@ -5,10 +5,14 @@ from __future__ import annotations +import functools +import inspect import os +from collections.abc import Awaitable, Callable from typing import Any import structlog +from pydantic_ai import ModelRetry from pydantic_ai.models import Model from pydantic_ai.models.openai import OpenAIChatModel from pydantic_ai.providers.ollama import OllamaProvider @@ -18,6 +22,45 @@ logger = structlog.get_logger() +def recoverable[**P, ToolReturnT]( + func: Callable[P, Awaitable[ToolReturnT]], +) -> Callable[P, Awaitable[ToolReturnT]]: + """Wrap an async agent tool so an expected ``ValueError`` becomes a ``ModelRetry``. + + Input-driven failures (no data for a store, an unknown run id, a malformed + date) should let the model correct its arguments on the next turn instead of + crashing the whole run (issue #176). Other exception types still propagate + as genuine errors. + + Args: + func: The async tool function to wrap. + + Returns: + The wrapped tool function, signature preserved for PydanticAI schema + extraction. + + Raises: + TypeError: If ``func`` is not a coroutine function. The wrapper + ``await``s ``func``, so wrapping a sync callable would only fail + (with an opaque "not awaitable" error) when the tool is first + called — this guard surfaces the mistake at decoration time. + """ + if not inspect.iscoroutinefunction(func): + raise TypeError( + f"@recoverable wraps async tool functions only; " + f"{getattr(func, '__qualname__', func)!r} is not a coroutine function." + ) + + @functools.wraps(func) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> ToolReturnT: + try: + return await func(*args, **kwargs) + except ValueError as exc: + raise ModelRetry(str(exc)) from exc + + return wrapper + + def build_agent_model(identifier: str) -> str | Model: """Build the PydanticAI ``model`` argument for an agent identifier. @@ -183,13 +226,13 @@ def requires_approval(action_name: str) -> bool: """ TOOL_USAGE_INSTRUCTIONS = """ -TOOL USAGE: -- Use list_runs to find existing experiments -- Use run_backtest to evaluate model performance -- Use compare_runs to analyze differences between runs -- Use create_alias to deploy successful models (requires approval) -- Use archive_run to clean up old experiments (requires approval) -- Use retrieve_context to find documentation +TOOL USAGE (call tools by these EXACT names): +- Use tool_list_runs to find existing experiments +- Use tool_run_backtest to evaluate model performance +- Use tool_compare_backtest_results to compare two backtest results +- Use tool_compare_runs to analyze differences between registered runs +- Use tool_create_alias to deploy successful models (requires approval) +- Use tool_archive_run to clean up old experiments (requires approval) """ SAFETY_INSTRUCTIONS = """ diff --git a/app/features/agents/agents/experiment.py b/app/features/agents/agents/experiment.py index 47d594fd..22311139 100644 --- a/app/features/agents/agents/experiment.py +++ b/app/features/agents/agents/experiment.py @@ -23,6 +23,7 @@ get_agent_retries, get_model_identifier, get_model_settings, + recoverable, requires_approval, validate_api_key_for_model, ) @@ -55,9 +56,9 @@ WORKFLOW: 1. Parse the objective to understand what the user wants -2. Check existing runs with list_runs to avoid duplicates -3. Run backtests for candidate models -4. Compare results using compare_backtest_results +2. Check existing runs with tool_list_runs to avoid duplicates +3. Run backtests for candidate models with tool_run_backtest +4. Compare results using tool_compare_backtest_results 5. Formulate recommendation with clear metrics 6. If auto_deploy requested and model beats baselines, propose deployment @@ -104,6 +105,7 @@ def create_experiment_agent() -> Agent[AgentDeps, ExperimentReport]: # Register tools with the agent @agent.tool + @recoverable async def tool_list_runs( ctx: RunContext[AgentDeps], page: int = 1, @@ -145,6 +147,7 @@ async def tool_list_runs( ) @agent.tool + @recoverable async def tool_get_run( ctx: RunContext[AgentDeps], run_id: str, @@ -166,6 +169,7 @@ async def tool_get_run( return await get_run(db=ctx.deps.db, run_id=run_id) @agent.tool + @recoverable async def tool_run_backtest( ctx: RunContext[AgentDeps], store_id: int, @@ -257,6 +261,7 @@ def tool_compare_backtest_results( return compare_backtest_results(result_a, result_b) @agent.tool + @recoverable async def tool_compare_runs( ctx: RunContext[AgentDeps], run_id_a: str, @@ -285,6 +290,7 @@ async def tool_compare_runs( ) @agent.tool + @recoverable async def tool_create_alias( ctx: RunContext[AgentDeps], alias_name: str, @@ -333,6 +339,7 @@ async def tool_create_alias( ) @agent.tool + @recoverable async def tool_archive_run( ctx: RunContext[AgentDeps], run_id: str, diff --git a/app/features/agents/agents/rag_assistant.py b/app/features/agents/agents/rag_assistant.py index 5e448510..c2288409 100644 --- a/app/features/agents/agents/rag_assistant.py +++ b/app/features/agents/agents/rag_assistant.py @@ -22,6 +22,7 @@ get_agent_retries, get_model_identifier, get_model_settings, + recoverable, validate_api_key_for_model, ) from app.features.agents.deps import AgentDeps @@ -103,6 +104,7 @@ def create_rag_assistant_agent() -> Agent[AgentDeps, RAGAnswer]: # Register tools with the agent @agent.tool + @recoverable async def tool_retrieve_context( ctx: RunContext[AgentDeps], query: str, @@ -187,6 +189,7 @@ def tool_check_evidence( ) @agent.tool + @recoverable async def tool_list_sources( ctx: RunContext[AgentDeps], ) -> dict[str, Any]: diff --git a/app/features/agents/tests/test_base.py b/app/features/agents/tests/test_base.py index 4415eb11..83193aed 100644 --- a/app/features/agents/tests/test_base.py +++ b/app/features/agents/tests/test_base.py @@ -1,20 +1,28 @@ """Unit tests for agent base helpers (Ollama-aware model factory).""" +import re from collections.abc import Iterator +from typing import Any, cast from unittest.mock import AsyncMock import pytest +from pydantic_ai import ModelRetry from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart from pydantic_ai.models.function import AgentInfo, FunctionModel from pydantic_ai.models.openai import OpenAIChatModel from app.core.config import get_settings from app.features.agents.agents.base import ( + TOOL_USAGE_INSTRUCTIONS, build_agent_model, get_agent_retries, + recoverable, validate_api_key_for_model, ) -from app.features.agents.agents.experiment import create_experiment_agent +from app.features.agents.agents.experiment import ( + EXPERIMENT_SYSTEM_PROMPT, + create_experiment_agent, +) from app.features.agents.agents.rag_assistant import create_rag_assistant_agent from app.features.agents.deps import AgentDeps from app.features.agents.schemas import ExperimentReport, RAGAnswer @@ -54,6 +62,94 @@ def test_validate_api_key_for_model_ollama_skips_key_check(): validate_api_key_for_model("ollama:llama3.1") +def test_prompts_only_reference_registered_tool_names() -> None: + """Every `tool_*` name in the agent prompts must be an actually-registered tool. + + Regression for issue #175: the prompts named tools as `run_backtest`, + `list_runs`, … but the registered tools are `tool_`-prefixed, so weaker + models called unknown tool names. This test is the single source of truth + for that invariant — the registered set is read off the built agent (not a + hardcoded list), so drift in either direction (a renamed tool or an edited + prompt) fails CI. + """ + settings = get_settings() + settings.agent_default_model = "ollama:llama3.1" + agent = create_experiment_agent() + + captured: dict[str, set[str]] = {} + + def respond(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + captured["registered"] = {tool.name for tool in info.function_tools} + # End the run immediately with a PromptedOutput-parseable text reply. + return ModelResponse(parts=[TextPart(content='{"summary": "noop"}')]) + + agent.run_sync( + "noop", + model=FunctionModel(respond), + deps=AgentDeps(db=AsyncMock(), session_id="test-tool-names"), + ) + registered = captured["registered"] + + # Tool names the prompts instruct the model to call. EXPERIMENT_SYSTEM_PROMPT + # already embeds TOOL_USAGE_INSTRUCTIONS; both are scanned to stay correct + # even if that embedding changes. + prompt_text = TOOL_USAGE_INSTRUCTIONS + EXPERIMENT_SYSTEM_PROMPT + referenced = set(re.findall(r"\btool_[a-z_]+\b", prompt_text)) + + assert referenced, "expected the prompts to name at least one tool" + unknown = referenced - registered + assert not unknown, f"prompts reference unregistered tools: {sorted(unknown)}" + + +async def test_recoverable_converts_valueerror_to_model_retry(): + """A ValueError from a tool becomes a ModelRetry the model can recover from (#176).""" + + @recoverable + async def tool() -> str: + raise ValueError("No data found for store=1") + + with pytest.raises(ModelRetry, match="No data found for store=1"): + await tool() + + +async def test_recoverable_passes_through_other_exceptions(): + """Non-ValueError exceptions are genuine bugs — they must still propagate.""" + + @recoverable + async def tool() -> str: + raise RuntimeError("a real bug") + + with pytest.raises(RuntimeError, match="a real bug"): + await tool() + + +async def test_recoverable_returns_value_on_success(): + """The decorator is transparent when the tool succeeds.""" + + @recoverable + async def tool() -> str: + return "ok" + + assert await tool() == "ok" + + +def test_recoverable_rejects_sync_function() -> None: + """@recoverable is async-only — applying it to a sync function fails fast. + + Without the guard a sync function would be wrapped and then ``await``ed, + surfacing a confusing ``TypeError: ... is not awaitable`` only at call + time. The decorator rejects it at decoration time instead. + """ + + def sync_tool() -> str: + return "nope" + + # recoverable is async-only by type; cast bypasses the static check so the + # runtime guard itself can be exercised. + with pytest.raises(TypeError, match="async tool functions only"): + recoverable(cast(Any, sync_tool)) + + def test_get_agent_retries_returns_configured_value(): """get_agent_retries reflects the agent_retry_attempts setting.""" settings = get_settings()