Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions app/features/agents/agents/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import structlog
from pydantic_ai import ModelRetry
from pydantic_ai.models import Model
from pydantic_ai.models.fallback import FallbackModel
from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.providers.ollama import OllamaProvider

Expand Down Expand Up @@ -120,6 +121,56 @@ def get_fallback_model() -> str:
return settings.agent_fallback_model


def build_agent_model_with_fallback() -> Model | str:
    """Return the agent ``model`` argument, adding a fallback when one is usable.

    The primary identifier is validated and built first. When a distinct
    fallback identifier is configured and its provider key validates, both
    models are wrapped in a :class:`FallbackModel` (primary first, fallback
    second). A provider error on the primary — HTTP 5xx, rate limit, timeout,
    i.e. any ``pydantic_ai.exceptions.ModelAPIError``, which is
    ``FallbackModel``'s default ``fallback_on`` — is then retried against
    ``agent_fallback_model`` instead of surfacing as a hard error
    (e.g. a Gemini ``503 UNAVAILABLE``).

    The bare primary is returned, with no wrapper, when:

    - the fallback identifier is empty or identical to the primary; or
    - validating the fallback provider's API key raises ``ValueError`` — a
      keyless fallback would only move the failure, so a warning is logged
      and the agent runs primary-only.

    Returns:
        ``FallbackModel(primary, fallback)`` when a usable fallback is
        configured, otherwise the value :func:`build_agent_model` produced
        for the primary identifier.

    Raises:
        ValueError: Propagated from :func:`validate_api_key_for_model` when
            the primary provider's API key is not configured (fail-fast — an
            agent with no usable primary cannot run).
    """
    primary_name = get_model_identifier()
    # Deliberately unguarded: a missing primary key must abort construction.
    validate_api_key_for_model(primary_name)
    primary_model = build_agent_model(primary_name)

    fallback_name = get_fallback_model()
    has_distinct_fallback = bool(fallback_name) and fallback_name != primary_name
    if not has_distinct_fallback:
        return primary_model

    try:
        validate_api_key_for_model(fallback_name)
    except ValueError:
        # Degrade to primary-only rather than failing agent construction.
        logger.warning(
            "agents.fallback_disabled",
            reason="missing_api_key",
            primary=primary_name,
            fallback=fallback_name,
        )
        return primary_model

    fallback_model = build_agent_model(fallback_name)
    logger.info(
        "agents.fallback_enabled",
        primary=primary_name,
        fallback=fallback_name,
    )
    # Argument order matters: the first model is tried first.
    return FallbackModel(primary_model, fallback_model)


def get_agent_retries() -> int:
"""Get the configured retry budget for agent tool calls and output validation.

Expand Down
10 changes: 4 additions & 6 deletions app/features/agents/agents/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,11 @@
SAFETY_INSTRUCTIONS,
SYSTEM_PROMPT_HEADER,
TOOL_USAGE_INSTRUCTIONS,
build_agent_model,
build_agent_model_with_fallback,
get_agent_retries,
get_model_identifier,
get_model_settings,
recoverable,
requires_approval,
validate_api_key_for_model,
)
from app.features.agents.deps import AgentDeps
from app.features.agents.schemas import ExperimentReport
Expand Down Expand Up @@ -83,9 +81,9 @@ def create_experiment_agent() -> Agent[AgentDeps, ExperimentReport]:
Returns:
Configured Agent instance with tools registered.
"""
identifier = get_model_identifier()
validate_api_key_for_model(identifier) # Fail-fast validation
model = build_agent_model(identifier) # str for cloud, Model object for ollama
# Primary model, wrapped in a FallbackModel so a transient provider error
# (HTTP 5xx, rate limit) on the primary transparently retries the fallback.
model = build_agent_model_with_fallback()

retries = get_agent_retries()
agent: Agent[AgentDeps, ExperimentReport] = Agent(
Expand Down
10 changes: 4 additions & 6 deletions app/features/agents/agents/rag_assistant.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,10 @@
from app.features.agents.agents.base import (
SAFETY_INSTRUCTIONS,
SYSTEM_PROMPT_HEADER,
build_agent_model,
build_agent_model_with_fallback,
get_agent_retries,
get_model_identifier,
get_model_settings,
recoverable,
validate_api_key_for_model,
)
from app.features.agents.deps import AgentDeps
from app.features.agents.schemas import RAGAnswer
Expand Down Expand Up @@ -78,9 +76,9 @@ def create_rag_assistant_agent() -> Agent[AgentDeps, RAGAnswer]:
Returns:
Configured Agent instance with tools registered.
"""
identifier = get_model_identifier()
validate_api_key_for_model(identifier) # Fail-fast validation
model = build_agent_model(identifier) # str for cloud, Model object for ollama
# Primary model, wrapped in a FallbackModel so a transient provider error
# (HTTP 5xx, rate limit) on the primary transparently retries the fallback.
model = build_agent_model_with_fallback()

retries = get_agent_retries()
agent: Agent[AgentDeps, RAGAnswer] = Agent(
Expand Down
62 changes: 62 additions & 0 deletions app/features/agents/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,15 @@
import pytest
from pydantic_ai import ModelRetry
from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart
from pydantic_ai.models.fallback import FallbackModel
from pydantic_ai.models.function import AgentInfo, FunctionModel
from pydantic_ai.models.openai import OpenAIChatModel

from app.core.config import get_settings
from app.features.agents.agents.base import (
TOOL_USAGE_INSTRUCTIONS,
build_agent_model,
build_agent_model_with_fallback,
get_agent_retries,
recoverable,
validate_api_key_for_model,
Expand Down Expand Up @@ -62,6 +64,66 @@ def test_validate_api_key_for_model_ollama_skips_key_check():
validate_api_key_for_model("ollama:llama3.1")


def test_build_agent_model_with_fallback_wraps_primary_and_fallback():
"""A distinct, key-backed fallback yields a FallbackModel wired primary-then-fallback.

Asserts the *order* via the public ``FallbackModel.models`` list — ``models[0]``
must be the primary (``agent_default_model``) and ``models[1]`` the fallback
(``agent_fallback_model``) — so a swap or misconfiguration is caught, not just
the wrapper type.
"""
settings = get_settings()
settings.agent_default_model = "anthropic:claude-sonnet-4-5"
settings.agent_fallback_model = "openai:gpt-4o"
settings.anthropic_api_key = "test-anthropic-key"
settings.openai_api_key = "test-openai-key"

model = build_agent_model_with_fallback()

Comment on lines +67 to +82
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Strengthen the FallbackModel test by asserting that the primary and fallback are wired as expected, not just that a FallbackModel instance is returned.

This test could still pass if the primary and fallback models were swapped or misconfigured. To make it more robust, assert the internal configuration of the FallbackModel via its public API so that the primary maps to settings.agent_default_model and the fallback to settings.agent_fallback_model. If that isn’t possible, add a behavior-focused test using stub models that confirms the primary is tried first and the fallback is used when the primary raises ModelAPIError.

Suggested implementation:

def test_build_agent_model_with_fallback_wraps_primary_and_fallback():
    """A distinct, key-backed fallback yields a FallbackModel wired primary-then-fallback.

    Checks the wrapper type *and* the order of the wrapped models, so a swapped
    or misconfigured primary/fallback is caught rather than silently passing.
    """
    settings = get_settings()
    settings.agent_default_model = "anthropic:claude-sonnet-4-5"
    settings.agent_fallback_model = "openai:gpt-4o"
    settings.anthropic_api_key = "test-anthropic-key"
    settings.openai_api_key = "test-openai-key"

    model = build_agent_model_with_fallback()

    # Still ensure the wrapper type is correct
    assert isinstance(model, FallbackModel)

    # FallbackModel exposes its members through the public ``models`` list
    # (there are no ``primary`` / ``fallback`` attributes). Each member reports
    # its provider via ``.system`` and its bare name via ``.model_name``, so
    # recombine them into the ``provider:model`` identifier that was configured.
    wired = [f"{m.system}:{m.model_name}" for m in model.models]
    assert wired == [settings.agent_default_model, settings.agent_fallback_model]

To make this test pass robustly, ensure that:

  1. FallbackModel exposes the wrapped models via public attributes or properties named primary and fallback. If different names are used, update the test accordingly.
  2. The underlying model instances on primary and fallback expose a public identifier attribute such as model, model_name, or model_id that matches the configured "provider:model" string (e.g., "anthropic:claude-sonnet-4-5" and "openai:gpt-4o"). If the public API uses a different attribute or a method (e.g., primary.name or primary.info.model), adjust the primary_id / fallback_id extraction to match that API.
  3. If no such public attributes exist, consider adding a small, public, read-only property on the model classes (or on FallbackModel itself) that exposes the configured model string so tests can assert wiring without relying on internals.

assert isinstance(model, FallbackModel)
Comment on lines +81 to +83
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

suggestion (testing): Add a test for the fail-fast behavior when the primary model API key is missing in build_agent_model_with_fallback.

The current tests cover a missing fallback key and the primary==fallback case. Since build_agent_model_with_fallback() also calls validate_api_key_for_model(primary_id), please add a test that sets agent_default_model to a non-Ollama provider, clears its API key, and asserts that build_agent_model_with_fallback() raises ValueError. This will lock in the fail-fast behavior and protect against future removal of the primary key validation.

Suggested change
model = build_agent_model_with_fallback()
assert isinstance(model, FallbackModel)
model = build_agent_model_with_fallback()
assert isinstance(model, FallbackModel)
def test_build_agent_model_with_fallback_raises_when_primary_api_key_missing():
    """A non-Ollama primary whose API key is empty must raise ``ValueError``."""
    cfg = get_settings()
    # Primary provider (Anthropic) gets an empty key; the fallback provider
    # (OpenAI) is fully configured, so only the primary check can fail.
    cfg.agent_default_model = "anthropic:claude-sonnet-4-5"
    cfg.agent_fallback_model = "openai:gpt-4o"
    cfg.anthropic_api_key = ""
    cfg.openai_api_key = "test-openai-key"

    with pytest.raises(ValueError):
        build_agent_model_with_fallback()

# Each member model exposes its provider via `.system` and name via
# `.model_name`; recombine to the `provider:model` identifier we configured.
wired = [f"{m.system}:{m.model_name}" for m in model.models]
assert wired == [settings.agent_default_model, settings.agent_fallback_model]


def test_build_agent_model_with_fallback_raises_when_primary_api_key_missing():
    """An empty key for a non-Ollama primary aborts the build with ``ValueError``.

    The fallback provider is fully configured here, so the error can only come
    from the primary-key check that runs before any wrapping.
    """
    cfg = get_settings()
    cfg.agent_default_model = "anthropic:claude-sonnet-4-5"
    cfg.agent_fallback_model = "openai:gpt-4o"
    cfg.anthropic_api_key = ""  # the missing primary key under test
    cfg.openai_api_key = "test-openai-key"

    expected_message = "Anthropic API key not configured"
    with pytest.raises(ValueError, match=expected_message):
        build_agent_model_with_fallback()


def test_build_agent_model_with_fallback_primary_only_when_fallback_key_missing():
    """An empty fallback-provider key degrades to the primary model alone.

    The builder must not raise in this case; it returns the primary model
    argument unwrapped.
    """
    cfg = get_settings()
    cfg.agent_default_model = "anthropic:claude-sonnet-4-5"
    cfg.agent_fallback_model = "openai:gpt-4o"
    cfg.anthropic_api_key = "test-anthropic-key"
    cfg.openai_api_key = ""  # fallback provider has no key

    built = build_agent_model_with_fallback()

    # No FallbackModel wrapper: the result equals the primary identifier.
    assert built == "anthropic:claude-sonnet-4-5"


def test_build_agent_model_with_fallback_primary_only_when_fallback_equals_primary():
    """Configuring the fallback as the primary itself yields no wrapper.

    Retrying the same model adds no resilience, so the builder is expected to
    return the primary model argument alone.
    """
    cfg = get_settings()
    # Same identifier on both settings — the case under test.
    cfg.agent_default_model = "anthropic:claude-sonnet-4-5"
    cfg.agent_fallback_model = "anthropic:claude-sonnet-4-5"
    cfg.anthropic_api_key = "test-anthropic-key"

    built = build_agent_model_with_fallback()

    assert built == "anthropic:claude-sonnet-4-5"


def test_prompts_only_reference_registered_tool_names() -> None:
"""Every `tool_*` name in the agent prompts must be an actually-registered tool.

Expand Down