diff --git a/.env.example b/.env.example index 1c6fb544..0545e3b1 100644 --- a/.env.example +++ b/.env.example @@ -64,6 +64,10 @@ RAG_HNSW_EF_CONSTRUCTION=64 # - openai: GPT models (gpt-4o, gpt-4o-mini, etc.) # - google-gla: Gemini models via Google AI Studio (gemini-2-5-flash, gemini-3-flash, gemini-3-pro) # - google-vertex: Gemini models via Vertex AI (gemini-*) [requires GCP auth] +# - ollama: local models via Ollama's OpenAI-compatible endpoint (no API key) +# e.g. AGENT_DEFAULT_MODEL=ollama:llama3.1 (requires `ollama serve` + `ollama pull llama3.1`) +# Runtime-editable: the /admin "AI Models" tab persists overrides in the +# app_config table and applies them live — no .env edit or restart needed. AGENT_DEFAULT_MODEL=anthropic:claude-sonnet-4-5 AGENT_FALLBACK_MODEL=openai:gpt-4o diff --git a/PRPs/PRP-18-ai-model-admin-console.md b/PRPs/PRP-18-ai-model-admin-console.md new file mode 100644 index 00000000..9fed6b8f --- /dev/null +++ b/PRPs/PRP-18-ai-model-admin-console.md @@ -0,0 +1,674 @@ +name: "PRP-18 — AI Model Admin Console (runtime-editable, Ollama-capable)" +description: | + Add an "AI Models" admin surface that lets an operator view and **change** + the agent LLM model, the RAG embedding model, and provider API keys at + runtime — including running the chat agent fully local via Ollama. Changes + persist in a new `app_config` DB table and take effect live (no restart) by + mutating the in-process `Settings` singleton and invalidating the agent / + embedding caches. + +## Purpose + +One-pass implementation of a full-stack feature: a new `config` vertical slice +(backend) + an "AI Models" tab on the existing `/admin` page (frontend). The PRP +carries every file path, code snippet, library URL, and gotcha needed. + +## Core Principles + +1. **Context is King** — all referenced files/snippets are in this document. +2. **Validation Loops** — executable gates in the Validation section. +3. **Follow CLAUDE.md** — vertical slices, RFC 7807, Pydantic v2, strict typing, + `type(scope): description (#issue)` commits, branch off `dev`. +4. **Progressive Success** — backend slice → cache wiring → Ollama agent → routes + → frontend. Each step has a gate. + +--- + +## Goal + +Ship `http://localhost:5173/admin` → **AI Models** tab that can: + +- Display the **effective** AI-model configuration (agent LLM + RAG embeddings). +- **Edit and persist** those values; edits apply live without a backend restart. +- Add **Ollama as a first-class agent LLM provider** (`ollama:`), not just + a RAG embedding option. +- Show **provider connectivity** (Ollama reachable + its local model list; cloud + key presence) and let the operator **set/replace API keys** from the UI. + +End state: an operator can switch the chat agent from `anthropic:claude-sonnet-4-5` +to `ollama:llama3.1`, hit Save, open `/chat`, and the next message runs locally — +no process restart, no `.env` edit. + +## Why + +- **User value** — the system currently requires hand-editing `.env` + restarting + uvicorn to change a model. An operator-facing console removes that friction. +- **Portfolio value** — "swap any AI model, including a fully-local Ollama path, + from a dashboard" is a strong demo of the agentic + RAG layers. +- **Integration** — extends the existing `/admin` page (RAG / Aliases / Seeder + tabs) and the existing Ollama embedding support (`app/features/rag/embeddings.py`). + +## What + +### User-visible behavior + +- `/admin` gains a 4th tab **AI Models** with four cards: Agent LLM, RAG + Embeddings, API Keys, Provider Health. +- Editing a field + Save persists it and applies it immediately. +- The Agent LLM provider dropdown includes `ollama`; when `ollama` is picked the + model field is a dropdown populated from the host's pulled Ollama models. +- Provider Health shows Ollama reachability + local models, and cloud key presence. + +### Technical requirements + +- New `config` vertical slice: `app/features/config/{models,schemas,service,routes,tests}.py`. +- New `app_config` table (Alembic migration) — key/value override store. +- A startup hook applies persisted overrides onto the `Settings` singleton. +- A save path validates → upserts DB → mutates `Settings` → resets agent + + embedding caches. +- `ollama` added to the agent provider allow-list; PydanticAI agent built via an + Ollama-aware model factory. +- All errors RFC 7807; request bodies Pydantic v2; `mypy --strict` + `pyright --strict` clean. + +### Success Criteria + +- [ ] `GET /config/ai` returns the effective AI config; API keys are **masked**, never raw. +- [ ] `PATCH /config/ai` persists changes to `app_config`, mutates `Settings`, resets caches. +- [ ] `GET /config/providers/health` reports Ollama + cloud provider status. +- [ ] `GET /config/ollama/models` lists the host's pulled Ollama models. +- [ ] Setting `agent_default_model` to `ollama:` makes `/chat` run via Ollama with **no restart**. +- [ ] Persisted overrides survive a backend restart (re-applied on startup). +- [ ] `/admin` → AI Models tab edits and saves all four cards. +- [ ] All validation gates green (ruff, mypy, pyright, pytest unit + integration, frontend tsc/lint/test, `alembic upgrade head`). + +## All Needed Context + +### Documentation & References + +```yaml +- url: https://ai.pydantic.dev/models/openai/#ollama + why: Canonical pattern for running a PydanticAI Agent against Ollama via the + OpenAI-compatible endpoint. Shows OpenAIChatModel + OllamaProvider usage. + critical: Ollama's OpenAI-compatible base is "/v1". The agent + model object is passed to Agent(model=...) instead of a "provider:model" string. + +- url: https://ai.pydantic.dev/api/providers/ + why: OllamaProvider constructor signature (base_url kwarg). Verify with context7 + (resolve-library-id "pydantic-ai" → query-docs) if the import path differs. + +- url: https://ai.pydantic.dev/api/models/openai/ + why: OpenAIChatModel signature. In pydantic-ai 1.96 the class is OpenAIChatModel + (OpenAIModel is a deprecated alias). Import: from pydantic_ai.models.openai import OpenAIChatModel + +- url: https://github.com/ollama/ollama/blob/main/docs/api.md#list-local-models + why: GET /api/tags response shape — {"models":[{"name","model","size","details":{...}}]}. + Used for both the health check and the model picker. + +- url: https://github.com/ollama/ollama/blob/main/docs/openai.md + why: Confirms Ollama's /v1/embeddings + /v1/chat/completions OpenAI-compat surface. + +- url: https://tanstack.com/query/latest/docs/framework/react/guides/mutations + why: useMutation + queryClient.invalidateQueries pattern (mirror use-seeder.ts). + +- file: app/features/seeder/routes.py + why: THE slice to mirror — a control-plane slice with no DB models of its own, + APIRouter(prefix=..., tags=...), HTTPException for RFC 7807, get_db dependency. + +- file: app/features/seeder/schemas.py + why: Pydantic v2 request/response schema patterns, field_validator usage. + +- file: app/features/rag/embeddings.py + why: Ollama HTTP client pattern (httpx.AsyncClient base_url + timeout), the + module-global _embedding_provider singleton + get_embedding_service(). + +- file: app/features/agents/agents/base.py + why: get_model_identifier / get_model_settings / validate_api_key_for_model — + the exact functions to extend for Ollama. + +- file: app/features/agents/agents/experiment.py + why: create_experiment_agent() builds Agent(model=...); _experiment_agent global + singleton + get_experiment_agent(). rag_assistant.py mirrors this. + +- file: app/core/config.py + why: Settings model + the agent_default_model/agent_fallback_model field_validator + whose provider allow-list must gain "ollama". + +- file: alembic/versions/d6e0f2g3h456_create_agent_session_table.py + why: Migration template — revision id style, upgrade/downgrade, JSONB column. + +- file: frontend/src/pages/admin.tsx + why: The existing Tabs structure to extend (add a 4th TabsTrigger/TabsContent). + +- file: frontend/src/hooks/use-seeder.ts + why: TanStack Query hook pattern to mirror exactly (useQuery/useMutation/invalidate). + +- file: frontend/src/lib/api.ts + why: api() fetch wrapper + ApiError — how every endpoint is called. + +- file: frontend/src/lib/constants.ts + why: ROUTES / NAV_ITEMS — no new route needed (/admin exists), no edit expected here. + +- file: app/features/seeder/tests/test_routes.py + why: Route-test pattern — TestClient, patch get_settings/get_db, patch service fns. + +- file: app/features/rag/tests/test_embeddings.py + why: How to unit-test httpx/Ollama calls with patched get_settings (no live calls). + +- file: app/core/tests/test_strict_mode_policy.py + why: AST invariant — a ConfigDict(strict=True) request model may not have a bare + date/datetime/time/UUID/Decimal field. Our config schemas use only + str/int/float/bool, so they pass — keep it that way. +``` + +### Current Codebase tree (relevant subset) + +```bash +app/ + core/ + config.py # Settings + get_settings() @lru_cache ← EDIT (allow-list, validator extraction) + database.py # get_db / engine / async_session + problem_details.py # RFC 7807 envelope (HTTPException auto-converts) + main.py # router wiring + lifespan ← EDIT (wire config router + startup hook) + features/ + agents/ + agents/base.py # model helpers ← EDIT (build_agent_model, ollama in validate) + agents/experiment.py # _experiment_agent singleton ← EDIT (use build_agent_model + reset hook) + agents/rag_assistant.py# _rag_assistant_agent singleton ← EDIT (same) + service.py + rag/ + embeddings.py # _embedding_provider singleton ← EDIT (add reset_embedding_service()) + seeder/ # control-plane slice — MIRROR THIS +alembic/versions/ # 8 migrations ← ADD one +frontend/src/ + pages/admin.tsx # Tabs(rag|aliases|seeder) ← EDIT (+models tab) + hooks/use-seeder.ts # hook pattern to mirror + lib/api.ts + types/api.ts # ← EDIT (+config types) + components/ # demo/ chat/ ... (no admin/ dir yet) +``` + +### Desired Codebase tree (files added) + +```bash +app/features/config/ + __init__.py + models.py # AppConfig ORM (app_config table) + schemas.py # AIModelConfig, AIModelConfigUpdate, ProviderHealth, OllamaModel, ApiKeyStatus + service.py # load/save overrides, apply to Settings, cache resets, connectivity tests + routes.py # APIRouter(prefix="/config", tags=["config"]) + tests/ + __init__.py + conftest.py + test_schemas.py + test_service.py + test_routes.py +alembic/versions/_create_app_config_table.py +frontend/src/ + hooks/use-config.ts # useAIConfig, useUpdateAIConfig, useProviderHealth, useOllamaModels + components/admin/ai-models-panel.tsx # the "AI Models" tab UI +``` + +### Known Gotchas & Library Quirks + +```python +# CRITICAL: get_settings() is @lru_cache'd — ONE Settings object lives for the +# process. Mutating its attributes (settings.agent_default_model = "x") is the +# intended override mechanism here. BaseSettings is NOT frozen and +# validate_assignment is False — assignment will NOT re-run field validators, +# so the config SERVICE must validate BEFORE setattr. + +# CRITICAL: Agents are module-global singletons: +# experiment.py:_experiment_agent, rag_assistant.py:_rag_assistant_agent. +# embeddings.py:_embedding_provider is likewise a module global. +# After any model/key change you MUST null these out or the change is invisible. +# Add reset functions; the config service calls them on every successful save. + +# CRITICAL: app/core/config.py has a field_validator on agent_default_model / +# agent_fallback_model with valid_providers = ["anthropic","openai", +# "google-gla","google-vertex"]. Ollama as an AGENT provider REQUIRES adding +# "ollama" to that list, else Settings() / setattr-validation rejects it. + +# CRITICAL: PydanticAI cloud providers accept a plain "provider:model" STRING as +# Agent(model=...). Ollama does NOT — you must pass a model OBJECT: +# OpenAIChatModel(model_name, provider=OllamaProvider(base_url=f"{url}/v1")). +# So introduce build_agent_model(identifier) -> str | Model. + +# GOTCHA: validate_api_key_for_model() exports keys to os.environ ONLY IF the var +# is absent ("if 'OPENAI_API_KEY' not in os.environ"). When the operator +# REPLACES a key via the UI, the save path must overwrite os.environ[...] +# unconditionally, or PydanticAI keeps using the stale key. +# Ollama needs NO key — skip validate_api_key_for_model for provider == "ollama". + +# GOTCHA: Ollama's OpenAI-compatible endpoints live under /v1 +# (/v1/chat/completions, /v1/embeddings). The model-list endpoint is the NATIVE +# /api/tags (NOT under /v1). ollama_base_url default is http://localhost:11434. + +# GOTCHA: rag_embedding_dimension is load-bearing — rag_chunk.embedding is a fixed +# pgvector dimension (see migration c5d9e1f2g345). Changing provider/model/ +# dimension while rag_chunk rows exist breaks retrieval. The PATCH handler MUST +# guard: if the change alters rag_embedding_dimension AND chunks exist, return +# 409 application/problem+json unless an explicit force=true is passed. + +# GOTCHA: NEVER return raw API keys in any GET response and NEVER log key values +# (.claude/rules/security-patterns.md). Return ApiKeyStatus { is_set: bool, +# masked: "sk-ant-…" + last 4 }. Storing keys in app_config is the documented +# tradeoff the operator chose (Q4) — keep them out of logs and out of GETs. + +# GOTCHA: Pydantic v2 strict mode — request bodies KEEP ConfigDict(strict=True). +# All config fields are str/int/float/bool (JSON-native) so NO Field(strict=False) +# override is needed and test_strict_mode_policy.py stays green. Do not add a +# date/UUID/Decimal field to a strict request model. + +# GOTCHA: The startup hook reads app_config via the DB. On a brand-new DB the +# table may not exist yet — wrap the load in try/except, log a warning, and let +# the app boot on env defaults. Never let a missing table crash startup. +``` + +## Implementation Blueprint + +### Data models and structure + +```python +# app/features/config/models.py — mirror agents/models.py ORM style (Mapped[], mapped_column) +from datetime import datetime +from sqlalchemy import String, func +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.orm import Mapped, mapped_column +from app.core.database import Base # confirm Base import path from agents/models.py + +class AppConfig(Base): + """Key/value override store for runtime-editable settings.""" + __tablename__ = "app_config" + key: Mapped[str] = mapped_column(String(100), primary_key=True) + value: Mapped[dict] = mapped_column(JSONB, nullable=False) # {"v": } + updated_at: Mapped[datetime] = mapped_column( + server_default=func.now(), onupdate=func.now() + ) + +# app/features/config/schemas.py — Pydantic v2, ConfigDict(strict=True) on request body +ALLOWED_OVERRIDE_KEYS = { + "agent_default_model", "agent_fallback_model", "agent_temperature", + "agent_max_tokens", "agent_thinking_budget", + "rag_embedding_provider", "rag_embedding_model", "rag_embedding_dimension", + "ollama_base_url", "ollama_embedding_model", + "openai_api_key", "anthropic_api_key", "google_api_key", # secret keys +} +SECRET_KEYS = {"openai_api_key", "anthropic_api_key", "google_api_key"} + +class ApiKeyStatus(BaseModel): # response only + provider: str # "openai" | "anthropic" | "google" + is_set: bool + masked: str | None # e.g. "sk-ant-…3f9a" (never the full value) + +class AIModelConfig(BaseModel): # GET /config/ai response + agent_default_model: str + agent_fallback_model: str + agent_temperature: float + agent_max_tokens: int + agent_thinking_budget: int | None + rag_embedding_provider: str + rag_embedding_model: str + rag_embedding_dimension: int + ollama_base_url: str + ollama_embedding_model: str + api_keys: list[ApiKeyStatus] + overridden_keys: list[str] # which keys currently come from app_config (not env) + +class AIModelConfigUpdate(BaseModel): # PATCH /config/ai request + model_config = ConfigDict(strict=True) # repo policy; all fields JSON-native + agent_default_model: str | None = None + agent_fallback_model: str | None = None + agent_temperature: float | None = Field(default=None, ge=0.0, le=2.0) + agent_max_tokens: int | None = Field(default=None, ge=1) + agent_thinking_budget: int | None = None + rag_embedding_provider: Literal["openai", "ollama"] | None = None + rag_embedding_model: str | None = None + rag_embedding_dimension: int | None = Field(default=None, ge=1) + ollama_base_url: str | None = None + ollama_embedding_model: str | None = None + openai_api_key: str | None = None + anthropic_api_key: str | None = None + google_api_key: str | None = None + force: bool = False # bypass the dimension-change guard + +class OllamaModel(BaseModel): + name: str + size_bytes: int | None = None + family: str | None = None + +class ProviderHealth(BaseModel): + provider: str # "ollama" | "openai" | "anthropic" | "google" + reachable: bool + detail: str + models: list[str] = [] # populated for ollama +``` + +### List of tasks (in order) + +```yaml +Task 0 — Issue + branch: + - Create a GitHub issue: feat(api,ui): AI model admin console with Ollama support. + - git switch -c feat/config-ai-model-admin-console (off up-to-date dev) + - Every commit references that issue number. + +Task 1 — MODIFY app/core/config.py: + - EXTRACT the body of the agent_default_model/agent_fallback_model validator into + a module-level function `validate_model_identifier(v: str) -> str` so the + config service can reuse it. The @field_validator delegates to it. + - ADD "ollama" to valid_providers list. + - KEEP the existing behavior identical for cloud providers. + +Task 2 — MODIFY .env.example: + - Document that AGENT_DEFAULT_MODEL now also accepts ollama: + (e.g. AGENT_DEFAULT_MODEL=ollama:llama3.1). Add a commented example. + - No new env vars are required (app_config is the override store). + +Task 3 — CREATE alembic/versions/_create_app_config_table.py: + - MIRROR alembic/versions/d6e0f2g3h456_create_agent_session_table.py. + - upgrade(): create table app_config(key VARCHAR(100) PK, value JSONB NOT NULL, + updated_at TIMESTAMP server_default now()). + - downgrade(): drop table app_config. + - down_revision = current head (run `uv run alembic heads` to find it). + +Task 4 — CREATE app/features/config/__init__.py + models.py: + - AppConfig ORM as in the blueprint. Confirm Base import path against + app/features/agents/models.py. + +Task 5 — CREATE app/features/config/schemas.py: + - All schemas from the blueprint. ConfigDict(strict=True) on AIModelConfigUpdate. + +Task 6 — MODIFY app/features/rag/embeddings.py: + - ADD `def reset_embedding_service() -> None:` that sets the module global + `_embedding_provider = None` (so the next get_embedding_service() rebuilds). + +Task 7 — MODIFY app/features/agents/agents/experiment.py + rag_assistant.py: + - ADD `def reset_experiment_agent()` / `reset_rag_assistant_agent()` that null + the module-global singleton. + - CHANGE `model = get_model_identifier()` → `model = build_agent_model(get_model_identifier())`. + +Task 8 — MODIFY app/features/agents/agents/base.py: + - ADD build_agent_model(identifier: str) -> "str | Model": + * provider, name = identifier.split(":", 1) + * if provider == "ollama": return OpenAIChatModel(name, + provider=OllamaProvider(base_url=settings.ollama_base_url.rstrip("/") + "/v1")) + * else: return identifier (string, unchanged cloud path) + - MODIFY validate_api_key_for_model: if provider == "ollama" → return early + (no key needed). + - ADD reset_agent_caches() convenience that calls reset_experiment_agent() + + reset_rag_assistant_agent() (import locally to avoid cycles). + - Imports: from pydantic_ai.models.openai import OpenAIChatModel + from pydantic_ai.providers.ollama import OllamaProvider + (both verified available in pydantic-ai 1.96.0; if an import fails, confirm + the path via context7 query-docs for "pydantic-ai"). + +Task 9 — CREATE app/features/config/service.py: + - get_effective_config(db) -> AIModelConfig: read Settings + which keys are in + app_config (overridden_keys); mask secrets. + - apply_overrides_on_startup(db) -> None: load all app_config rows, setattr onto + get_settings(), export secret keys to os.environ. Try/except-safe. + - update_config(db, payload: AIModelConfigUpdate) -> AIModelConfig: + 1. Validate model identifiers via validate_model_identifier (incl. ollama). + 2. If rag_embedding_dimension changes AND rag_chunk rows exist AND not force + → raise HTTPException(409, "embedding dimension change requires re-index"). + 3. Upsert each provided non-None field into app_config (ON CONFLICT DO UPDATE). + 4. setattr onto get_settings() singleton; for secret keys also set + os.environ[KEY] unconditionally. + 5. reset_agent_caches() + reset_embedding_service(). + 6. Return get_effective_config(db). + - get_provider_health() -> list[ProviderHealth]: ollama → httpx GET /api/tags + (reachable + model names); openai/anthropic/google → key-presence (+ optional + cheap GET /v1/models behind a short timeout, swallow errors → reachable=False). + - list_ollama_models() -> list[OllamaModel]: httpx GET {ollama_base_url}/api/tags, + parse {"models":[...]}. Raise HTTPException(502) on connection failure. + - NEVER log key values; log key NAMES + bool only. + +Task 10 — CREATE app/features/config/routes.py: + - router = APIRouter(prefix="/config", tags=["config"]) (mirror seeder/routes.py) + - GET /config/ai -> AIModelConfig + - PATCH /config/ai -> AIModelConfig (body AIModelConfigUpdate) + - GET /config/providers/health -> list[ProviderHealth] + - GET /config/ollama/models -> list[OllamaModel] + - db: AsyncSession = Depends(get_db). HTTPException for errors (RFC 7807 auto). + +Task 11 — MODIFY app/main.py: + - import config router; app.include_router(config_router). + - In lifespan() AFTER configure_logging(), open an async session and call + config.service.apply_overrides_on_startup(db) inside try/except (warn-and-continue). + +Task 12 — CREATE app/features/config/tests/ (conftest, test_schemas, test_service, test_routes): + - Mirror seeder/tests/test_routes.py + rag/tests/test_embeddings.py patterns. + - Unit: mask logic, validate_model_identifier accepts "ollama:llama3.1" rejects + "ollama:" and "bad:x"; build_agent_model returns OpenAIChatModel for ollama, + str for cloud; update_config resets caches (assert globals nulled); httpx + Ollama /api/tags mocked. Mark DB-touching tests @pytest.mark.integration. + +Task 13 — CREATE frontend/src/types/api.ts additions: + - AIModelConfig, AIModelConfigUpdate, ProviderHealth, OllamaModel, ApiKeyStatus + mirroring the Pydantic schemas. + +Task 14 — CREATE frontend/src/hooks/use-config.ts: + - useAIConfig() -> useQuery(['config','ai'], () => api('/config/ai')) + - useProviderHealth() -> useQuery(['config','health'], ...) + - useOllamaModels() -> useQuery(['config','ollama-models'], ...) enabled on demand + - useUpdateAIConfig() -> useMutation(PATCH /config/ai) onSuccess invalidate + ['config',*]. MIRROR use-seeder.ts exactly. + +Task 15 — CREATE frontend/src/components/admin/ai-models-panel.tsx: + - with 4 cards (Agent LLM, RAG Embeddings, API Keys, Provider + Health). Use existing shadcn ui: Card, Select, Input, Button, Badge, Slider + via Input type=range (matches SeederPanel), toast on save. Provider select + includes ollama; when provider==ollama, model field is a Select fed by + useOllamaModels(). API key inputs are type="password"; show ApiKeyStatus badge. + - Follow webapp-testing / ui-design rule: verify in a real browser before + declaring done (see Validation Level 4). + +Task 16 — MODIFY frontend/src/pages/admin.tsx: + - Add a 4th (icon: Bot or Cpu from lucide-react) + and . + +Task 17 — MODIFY docs: + - docs/_base/API_CONTRACTS.md: add the 4 /config/* endpoints. + - README.md: one line on the AI Models admin tab + Ollama-as-agent option. + - docs/rag-ollama-setup.md: note Ollama can now also back the chat agent. +``` + +### Per-task pseudocode (critical details only) + +```python +# Task 8 — base.py +from pydantic_ai.models import Model +from pydantic_ai.models.openai import OpenAIChatModel +from pydantic_ai.providers.ollama import OllamaProvider + +def build_agent_model(identifier: str) -> str | Model: + # PATTERN: cloud providers keep the plain-string path (unchanged behavior) + provider = identifier.split(":", 1)[0] + if provider != "ollama": + return identifier + settings = get_settings() + _, model_name = identifier.split(":", 1) + # CRITICAL: Ollama OpenAI-compat base ends in /v1 + base = settings.ollama_base_url.rstrip("/") + "/v1" + return OpenAIChatModel(model_name, provider=OllamaProvider(base_url=base)) + +# Task 9 — service.update_config (the load-bearing path) +async def update_config(db: AsyncSession, payload: AIModelConfigUpdate) -> AIModelConfig: + settings = get_settings() + changes = payload.model_dump(exclude_none=True, exclude={"force"}) + + # 1. validate model identifiers (reuses config.validate_model_identifier — ollama OK) + for k in ("agent_default_model", "agent_fallback_model"): + if k in changes: + validate_model_identifier(changes[k]) # raises ValueError → 422 via handler + + # 2. dimension-change guard (GOTCHA above) + if "rag_embedding_dimension" in changes and not payload.force: + if changes["rag_embedding_dimension"] != settings.rag_embedding_dimension: + chunk_count = await _count_rag_chunks(db) + if chunk_count > 0: + raise HTTPException(409, detail=( + f"Changing embedding dimension with {chunk_count} indexed " + "chunks breaks retrieval. Delete RAG sources first or pass force=true.")) + + # 3. persist (ON CONFLICT DO UPDATE — SQLAlchemy pg insert, parameter-bound) + for key, val in changes.items(): + await _upsert_app_config(db, key, val) + await db.commit() + + # 4. apply to the live Settings singleton (+ os.environ for secrets) + for key, val in changes.items(): + setattr(settings, key, val) + if key in SECRET_KEYS: + os.environ[_ENV_NAME[key]] = val # unconditional overwrite + + # 5. CRITICAL: invalidate caches so the change is visible immediately + reset_agent_caches() # nulls _experiment_agent + _rag_assistant_agent + reset_embedding_service() # nulls _embedding_provider + + logger.info("config.updated", keys=sorted(changes), secrets=sorted( + k for k in changes if k in SECRET_KEYS)) # NAMES only, never values + return await get_effective_config(db) +``` + +```typescript +// Task 14 — use-config.ts (mirror use-seeder.ts) +export function useAIConfig() { + return useQuery({ queryKey: ['config', 'ai'], queryFn: () => api('/config/ai') }) +} +export function useUpdateAIConfig() { + const qc = useQueryClient() + return useMutation({ + mutationFn: (body: AIModelConfigUpdate) => + api('/config/ai', { method: 'PATCH', body }), + onSuccess: () => void qc.invalidateQueries({ queryKey: ['config'] }), + }) +} +``` + +### Integration Points + +```yaml +DATABASE: + - migration: "create table app_config(key PK, value JSONB, updated_at)" + +CONFIG (app/core/config.py): + - extract validate_model_identifier() module-level fn; add "ollama" to providers. + +ROUTES (app/main.py): + - from app.features.config.routes import router as config_router + - app.include_router(config_router) + +STARTUP (app/main.py lifespan): + - try: async with async_session() as db: await apply_overrides_on_startup(db) + except Exception: logger.warning("config.overrides_skipped", ...) + +CACHE RESETS: + - rag/embeddings.py:reset_embedding_service() + - agents/agents/experiment.py:reset_experiment_agent() + - agents/agents/rag_assistant.py:reset_rag_assistant_agent() + +FRONTEND: + - admin.tsx: +1 Tab → AIModelsPanel + - no ROUTES/NAV_ITEMS change (/admin already exists) +``` + +## Validation Loop + +### Level 1: Syntax & Style + +```bash +uv run ruff check . --fix +uv run ruff format . +uv run mypy app/ +uv run pyright app/ +# Expected: zero errors. Both type checkers gate merge. +``` + +### Level 2: Unit Tests + +```bash +uv run pytest -v -m "not integration" \ + app/features/config/tests/ \ + app/core/tests/test_strict_mode_policy.py \ + app/features/agents/tests/ +``` + +Required cases (mirror seeder/rag test style): +- `test_validate_model_identifier_accepts_ollama` — `"ollama:llama3.1"` passes. +- `test_validate_model_identifier_rejects_blank_ollama` — `"ollama:"` raises. +- `test_build_agent_model_ollama_returns_model_object` — type is OpenAIChatModel. +- `test_build_agent_model_cloud_returns_string` — `"anthropic:…"` unchanged. +- `test_get_effective_config_masks_secrets` — no raw key in the response. +- `test_update_config_resets_caches` — globals nulled after update. +- `test_update_config_dimension_guard` — 409 when chunks exist and dimension changes. +- `test_list_ollama_models_parses_tags` — httpx `/api/tags` response mocked. + +### Level 3: Integration + Migration + +```bash +docker compose up -d +uv run alembic upgrade head # app_config table created cleanly +uv run pytest -v -m integration app/features/config/tests/ + +# Manual API smoke: +uv run uvicorn app.main:app --port 8123 & +curl -s localhost:8123/config/ai | python -m json.tool # masked keys +curl -s localhost:8123/config/providers/health | python -m json.tool +curl -s -X PATCH localhost:8123/config/ai \ + -H 'Content-Type: application/json' \ + -d '{"agent_temperature": 0.3}' # 200, value applied +# Ollama path (requires `ollama serve` + a pulled model): +curl -s localhost:8123/config/ollama/models | python -m json.tool +``` + +### Level 4: Frontend + Browser Dogfood (ui-design rule — mandatory) + +```bash +cd frontend && pnpm tsc --noEmit && pnpm lint && pnpm test --run +``` + +Then exercise the running UI via the **webapp-testing** / **agent-browser** skill: +navigate to `http://localhost:5173/admin`, open the **AI Models** tab, change the +agent temperature, Save, confirm the toast + that `GET /config/ai` reflects it. +Type-check passing ≠ UI works. + +## Final Validation Checklist + +- [ ] `uv run ruff check . && uv run ruff format --check .` clean +- [ ] `uv run mypy app/ && uv run pyright app/` clean +- [ ] `uv run pytest -v -m "not integration"` green +- [ ] `docker compose up -d && uv run alembic upgrade head` applies `app_config` +- [ ] `uv run pytest -v -m integration` green +- [ ] `cd frontend && pnpm tsc --noEmit && pnpm lint && pnpm test --run` green +- [ ] `GET /config/ai` returns masked keys (no raw secret anywhere) +- [ ] `PATCH /config/ai` persists + applies live; survives a uvicorn restart +- [ ] Agent set to `ollama:` answers in `/chat` with no restart +- [ ] AI Models tab verified in a real browser (webapp-testing / agent-browser) +- [ ] `docs/_base/API_CONTRACTS.md` + `README.md` + `.env.example` updated +- [ ] Commits: `feat(api,ui): … (#)`; branch `feat/config-ai-model-admin-console` + +## Anti-Patterns to Avoid + +- ❌ Don't write a parallel config system — extend `Settings` + the `app_config` table. +- ❌ Don't return or log raw API keys — masked/presence only. +- ❌ Don't forget the cache resets — a saved change that doesn't take effect is the + #1 failure mode of this feature. +- ❌ Don't pass an `ollama:` string to `Agent(model=...)` — build a model object. +- ❌ Don't let a missing `app_config` table crash startup — warn and continue. +- ❌ Don't change `rag_embedding_dimension` silently when chunks exist. +- ❌ Don't hand-roll the UI — use shadcn components + the ui-design skills, and + dogfood in a browser. +- ❌ Don't weaken `test_strict_mode_policy.py` — keep config request fields JSON-native. + +--- + +## Confidence Score: 8/10 + +High context density: every file, snippet, and gotcha for a one-pass build is +here, and the slice-to-mirror (`seeder`) is an exact structural match. The two +points of residual risk: (1) the precise PydanticAI 1.96 import/constructor for +`OpenAIChatModel` + `OllamaProvider` (mitigated — both imports were verified +available; context7 fallback noted), and (2) the live-mutation of the cached +`Settings` singleton + cache-reset wiring, which is novel for this codebase +(mitigated with explicit reset functions and a dedicated unit test). Deduct two +points for those; everything else is a well-trodden path. diff --git a/README.md b/README.md index ff7b7c46..3dbd92dc 100644 --- a/README.md +++ b/README.md @@ -12,6 +12,7 @@ Portfolio-grade end-to-end retail demand forecasting system. - **RAG Knowledge Base**: Postgres pgvector embeddings + evidence-grounded answers with citations - **Agentic Layer**: PydanticAI agents for autonomous experimentation and evidence-grounded Q&A with human-in-the-loop approval - **Data Seeder (The Forge)**: Reproducible synthetic data generator with realistic time-series patterns, scenario presets, and retail effects +- **AI Models Console**: `/admin` → AI Models tab — swap the agent LLM (incl. fully-local Ollama), the RAG embedding model, and provider API keys at runtime; changes apply live with no restart ## Quick Start diff --git a/alembic/env.py b/alembic/env.py index 38f68518..6abccc57 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -13,6 +13,7 @@ # Import all models for Alembic autogenerate detection from app.features.agents import models as agents_models # noqa: F401 +from app.features.config import models as config_models # noqa: F401 from app.features.data_platform import models as data_platform_models # noqa: F401 from app.features.jobs import models as jobs_models # noqa: F401 from app.features.rag import models as rag_models # noqa: F401 diff --git a/alembic/versions/378c112e4b32_create_app_config_table.py b/alembic/versions/378c112e4b32_create_app_config_table.py new file mode 100644 index 00000000..fce01a1b --- /dev/null +++ b/alembic/versions/378c112e4b32_create_app_config_table.py @@ -0,0 +1,46 @@ +"""create app_config table + +Revision ID: 378c112e4b32 +Revises: a8b9c0d1e234 +Create Date: 2026-05-18 12:38:56.878929 + +""" + +from __future__ import annotations + +from collections.abc import Sequence + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision: str = "378c112e4b32" +down_revision: str | None = "a8b9c0d1e234" +branch_labels: str | Sequence[str] | None = None +depends_on: str | Sequence[str] | None = None + + +def upgrade() -> None: + """Apply migration - create app_config key/value override store.""" + op.create_table( + "app_config", + sa.Column("key", sa.String(length=100), nullable=False), + sa.Column( + "value", + postgresql.JSONB(astext_type=sa.Text()), + nullable=False, + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=False, + ), + sa.PrimaryKeyConstraint("key"), + ) + + +def downgrade() -> None: + """Revert migration - drop app_config table.""" + op.drop_table("app_config") diff --git a/app/core/config.py b/app/core/config.py index 1370fdb5..cdbe5cf1 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -6,6 +6,58 @@ from pydantic import field_validator from pydantic_settings import BaseSettings, SettingsConfigDict +# Valid agent LLM provider prefixes for a "provider:model-name" identifier. +# "ollama" runs the agent fully local via Ollama's OpenAI-compatible endpoint. +VALID_MODEL_PROVIDERS: tuple[str, ...] = ( + "anthropic", + "openai", + "google-gla", + "google-vertex", + "ollama", +) + + +def validate_model_identifier(v: str) -> str: + """Validate an agent model identifier of the form ``provider:model-name``. + + Shared by the ``Settings`` field validators and the runtime config service + (``app/features/config``) so a UI-driven model change is checked the same + way an env-var-driven one is. + + Args: + v: Model identifier string (e.g. ``anthropic:claude-sonnet-4-5``, + ``ollama:llama3.1``). + + Returns: + The validated model identifier, unchanged. + + Raises: + ValueError: If the format is invalid, the model name is blank, or the + provider is not in :data:`VALID_MODEL_PROVIDERS`. + """ + if ":" not in v: + raise ValueError( + f"Invalid model identifier '{v}'. " + "Expected format: 'provider:model-name' " + "(e.g., 'anthropic:claude-sonnet-4-5', 'ollama:llama3.1')" + ) + provider, model_name = v.split(":", 1) + + # Validate model name is non-empty and not just whitespace + if not model_name or not model_name.strip(): + raise ValueError( + f"Invalid model identifier '{v}'. " + "Model name after ':' cannot be empty or blank. " + "Expected format: 'provider:model-name' " + "(e.g., 'anthropic:claude-sonnet-4-5', 'ollama:llama3.1')" + ) + + if provider not in VALID_MODEL_PROVIDERS: + raise ValueError( + f"Unknown provider '{provider}'. Valid providers: {list(VALID_MODEL_PROVIDERS)}" + ) + return v + class Settings(BaseSettings): """Application settings loaded from environment variables.""" @@ -130,39 +182,9 @@ class Settings(BaseSettings): @field_validator("agent_default_model", "agent_fallback_model") @classmethod - def validate_model_identifier(cls, v: str) -> str: - """Validate model identifier format (provider:model-name). - - Args: - v: Model identifier string. - - Returns: - Validated model identifier. - - Raises: - ValueError: If format is invalid or model name is missing. - """ - if ":" not in v: - raise ValueError( - f"Invalid model identifier '{v}'. " - "Expected format: 'provider:model-name' " - "(e.g., 'anthropic:claude-sonnet-4-5', 'google-gla:gemini-3-flash')" - ) - provider, model_name = v.split(":", 1) - - # Validate model name is non-empty and not just whitespace - if not model_name or not model_name.strip(): - raise ValueError( - f"Invalid model identifier '{v}'. " - "Model name after ':' cannot be empty or blank. " - "Expected format: 'provider:model-name' " - "(e.g., 'anthropic:claude-sonnet-4-5', 'google-gla:gemini-3-flash')" - ) - - valid_providers = ["anthropic", "openai", "google-gla", "google-vertex"] - if provider not in valid_providers: - raise ValueError(f"Unknown provider '{provider}'. Valid providers: {valid_providers}") - return v + def _validate_agent_model(cls, v: str) -> str: + """Validate agent model identifiers via :func:`validate_model_identifier`.""" + return validate_model_identifier(v) @property def is_development(self) -> bool: diff --git a/app/features/agents/agents/base.py b/app/features/agents/agents/base.py index bfc4fbcf..abade6c2 100644 --- a/app/features/agents/agents/base.py +++ b/app/features/agents/agents/base.py @@ -5,16 +5,101 @@ from __future__ import annotations +import functools +import inspect import os +from collections.abc import Awaitable, Callable from typing import Any import structlog +from pydantic_ai import ModelRetry +from pydantic_ai.models import Model +from pydantic_ai.models.openai import OpenAIChatModel +from pydantic_ai.providers.ollama import OllamaProvider from app.core.config import get_settings logger = structlog.get_logger() +def recoverable[**P, ToolReturnT]( + func: Callable[P, Awaitable[ToolReturnT]], +) -> Callable[P, Awaitable[ToolReturnT]]: + """Wrap an async agent tool so an expected ``ValueError`` becomes a ``ModelRetry``. + + Input-driven failures (no data for a store, an unknown run id, a malformed + date) should let the model correct its arguments on the next turn instead of + crashing the whole run (issue #176). Other exception types still propagate + as genuine errors. + + Args: + func: The async tool function to wrap. + + Returns: + The wrapped tool function, signature preserved for PydanticAI schema + extraction. + + Raises: + TypeError: If ``func`` is not a coroutine function. The wrapper + ``await``s ``func``, so wrapping a sync callable would only fail + (with an opaque "not awaitable" error) when the tool is first + called — this guard surfaces the mistake at decoration time. + """ + if not inspect.iscoroutinefunction(func): + raise TypeError( + f"@recoverable wraps async tool functions only; " + f"{getattr(func, '__qualname__', func)!r} is not a coroutine function." + ) + + @functools.wraps(func) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> ToolReturnT: + try: + return await func(*args, **kwargs) + except ValueError as exc: + raise ModelRetry(str(exc)) from exc + + return wrapper + + +def build_agent_model(identifier: str) -> str | Model: + """Build the PydanticAI ``model`` argument for an agent identifier. + + Cloud providers accept a plain ``provider:model-name`` string. Ollama does + not — it needs an :class:`OpenAIChatModel` bound to an :class:`OllamaProvider` + pointed at the host's OpenAI-compatible ``/v1`` endpoint. + + Args: + identifier: Model identifier (e.g. ``anthropic:claude-sonnet-4-5``, + ``ollama:llama3.1``). + + Returns: + The identifier string unchanged for cloud providers, or a configured + :class:`OpenAIChatModel` for the ``ollama`` provider. + """ + provider = identifier.split(":", 1)[0] + if provider != "ollama": + return identifier + + settings = get_settings() + model_name = identifier.split(":", 1)[1] + # CRITICAL: Ollama's OpenAI-compatible base ends in /v1. + base_url = settings.ollama_base_url.rstrip("/") + "/v1" + return OpenAIChatModel(model_name, provider=OllamaProvider(base_url=base_url)) + + +def reset_agent_caches() -> None: + """Drop the cached agent singletons so the next build picks up new config. + + Called by the config service after a successful model/key change. Imports + are local to avoid an import cycle (the agent modules import from here). + """ + from app.features.agents.agents.experiment import reset_experiment_agent + from app.features.agents.agents.rag_assistant import reset_rag_assistant_agent + + reset_experiment_agent() + reset_rag_assistant_agent() + + def get_model_identifier() -> str: """Get the configured model identifier for agents. @@ -35,6 +120,19 @@ def get_fallback_model() -> str: return settings.agent_fallback_model +def get_agent_retries() -> int: + """Get the configured retry budget for agent tool calls and output validation. + + PydanticAI defaults to 1 retry; without this the configured + ``agent_retry_attempts`` setting is silently ignored. + + Returns: + Number of retry attempts for tool calls and structured-output validation. + """ + settings = get_settings() + return settings.agent_retry_attempts + + def get_model_settings() -> dict[str, Any]: """Get model settings from configuration for PydanticAI Agent. @@ -68,6 +166,11 @@ def validate_api_key_for_model(model: str) -> None: settings = get_settings() provider = model.split(":")[0] + if provider == "ollama": + # Local Ollama runs without an API key — nothing to validate or export. + logger.debug("agents.api_key_validated", provider=provider, model=model) + return + if provider == "anthropic": if not settings.anthropic_api_key: raise ValueError( @@ -123,13 +226,13 @@ def requires_approval(action_name: str) -> bool: """ TOOL_USAGE_INSTRUCTIONS = """ -TOOL USAGE: -- Use list_runs to find existing experiments -- Use run_backtest to evaluate model performance -- Use compare_runs to analyze differences between runs -- Use create_alias to deploy successful models (requires approval) -- Use archive_run to clean up old experiments (requires approval) -- Use retrieve_context to find documentation +TOOL USAGE (call tools by these EXACT names): +- Use tool_list_runs to find existing experiments +- Use tool_run_backtest to evaluate model performance +- Use tool_compare_backtest_results to compare two backtest results +- Use tool_compare_runs to analyze differences between registered runs +- Use tool_create_alias to deploy successful models (requires approval) +- Use tool_archive_run to clean up old experiments (requires approval) """ SAFETY_INSTRUCTIONS = """ diff --git a/app/features/agents/agents/experiment.py b/app/features/agents/agents/experiment.py index 38d1c6f6..22311139 100644 --- a/app/features/agents/agents/experiment.py +++ b/app/features/agents/agents/experiment.py @@ -13,14 +13,17 @@ from typing import Any, Literal import structlog -from pydantic_ai import Agent, RunContext +from pydantic_ai import Agent, PromptedOutput, RunContext from app.features.agents.agents.base import ( SAFETY_INSTRUCTIONS, SYSTEM_PROMPT_HEADER, TOOL_USAGE_INSTRUCTIONS, + build_agent_model, + get_agent_retries, get_model_identifier, get_model_settings, + recoverable, requires_approval, validate_api_key_for_model, ) @@ -53,12 +56,18 @@ WORKFLOW: 1. Parse the objective to understand what the user wants -2. Check existing runs with list_runs to avoid duplicates -3. Run backtests for candidate models -4. Compare results using compare_backtest_results +2. Check existing runs with tool_list_runs to avoid duplicates +3. Run backtests for candidate models with tool_run_backtest +4. Compare results using tool_compare_backtest_results 5. Formulate recommendation with clear metrics 6. If auto_deploy requested and model beats baselines, propose deployment +CONVERSATIONAL BEHAVIOR: +- If the user greets you or has not yet described a concrete forecasting + objective, reply conversationally in the `summary` field and ask what they + would like to experiment on. Do NOT call any tools until you have a specific + objective (a store and product plus a date range, or an explicit request). + {TOOL_USAGE_INSTRUCTIONS} {SAFETY_INSTRUCTIONS} @@ -74,19 +83,29 @@ def create_experiment_agent() -> Agent[AgentDeps, ExperimentReport]: Returns: Configured Agent instance with tools registered. """ - model = get_model_identifier() - validate_api_key_for_model(model) # Fail-fast validation + identifier = get_model_identifier() + validate_api_key_for_model(identifier) # Fail-fast validation + model = build_agent_model(identifier) # str for cloud, Model object for ollama + retries = get_agent_retries() agent: Agent[AgentDeps, ExperimentReport] = Agent( model=model, deps_type=AgentDeps, - output_type=ExperimentReport, + # PromptedOutput puts the JSON schema in the prompt and parses the + # model's text reply, instead of the default ToolOutput mode which + # weaker/local models fail to satisfy (issue #173). + output_type=PromptedOutput(ExperimentReport), system_prompt=EXPERIMENT_SYSTEM_PROMPT, + # Apply the configured agent_retry_attempts. Without this PydanticAI + # defaults to 1, and weaker models fail structured-output validation. + output_retries=retries, + tool_retries=retries, **get_model_settings(), ) # Register tools with the agent @agent.tool + @recoverable async def tool_list_runs( ctx: RunContext[AgentDeps], page: int = 1, @@ -128,6 +147,7 @@ async def tool_list_runs( ) @agent.tool + @recoverable async def tool_get_run( ctx: RunContext[AgentDeps], run_id: str, @@ -149,6 +169,7 @@ async def tool_get_run( return await get_run(db=ctx.deps.db, run_id=run_id) @agent.tool + @recoverable async def tool_run_backtest( ctx: RunContext[AgentDeps], store_id: int, @@ -210,23 +231,37 @@ async def tool_run_backtest( @agent.tool_plain def tool_compare_backtest_results( - result_a: dict[str, Any], - result_b: dict[str, Any], + result_a: dict[str, Any] | None = None, + result_b: dict[str, Any] | None = None, ) -> dict[str, Any]: """Compare two backtest results. - Use this to analyze which model performs better. + Use this to analyze which model performs better. Both arguments must be + full backtest-result dicts as returned by tool_run_backtest. Args: - result_a: First backtest result. - result_b: Second backtest result. + result_a: First backtest result (from tool_run_backtest). + result_b: Second backtest result (from tool_run_backtest). Returns: - Comparison with metric differences and recommendation. + Comparison with metric differences and recommendation, or an + informative error dict if either result is missing. """ + # Tolerate missing/empty args: return a self-correcting hint instead of + # failing schema validation, which would burn the tool's retry budget + # and crash the whole run with UnexpectedModelBehavior. + if not result_a or not result_b: + return { + "error": "compare_backtest_results needs two backtest results.", + "hint": ( + "Call tool_run_backtest twice first, then pass both result " + "dicts as result_a and result_b." + ), + } return compare_backtest_results(result_a, result_b) @agent.tool + @recoverable async def tool_compare_runs( ctx: RunContext[AgentDeps], run_id_a: str, @@ -255,6 +290,7 @@ async def tool_compare_runs( ) @agent.tool + @recoverable async def tool_create_alias( ctx: RunContext[AgentDeps], alias_name: str, @@ -303,6 +339,7 @@ async def tool_create_alias( ) @agent.tool + @recoverable async def tool_archive_run( ctx: RunContext[AgentDeps], run_id: str, @@ -351,3 +388,13 @@ def get_experiment_agent() -> Agent[AgentDeps, ExperimentReport]: if _experiment_agent is None: _experiment_agent = create_experiment_agent() return _experiment_agent + + +def reset_experiment_agent() -> None: + """Drop the cached experiment agent so the next get_* call rebuilds it. + + Used after a runtime model/key change so the new configuration takes + effect without a process restart. + """ + global _experiment_agent + _experiment_agent = None diff --git a/app/features/agents/agents/rag_assistant.py b/app/features/agents/agents/rag_assistant.py index 97348434..c2288409 100644 --- a/app/features/agents/agents/rag_assistant.py +++ b/app/features/agents/agents/rag_assistant.py @@ -12,14 +12,17 @@ from typing import Any import structlog -from pydantic_ai import Agent, RunContext +from pydantic_ai import Agent, PromptedOutput, RunContext from app.core.config import get_settings from app.features.agents.agents.base import ( SAFETY_INSTRUCTIONS, SYSTEM_PROMPT_HEADER, + build_agent_model, + get_agent_retries, get_model_identifier, get_model_settings, + recoverable, validate_api_key_for_model, ) from app.features.agents.deps import AgentDeps @@ -75,14 +78,23 @@ def create_rag_assistant_agent() -> Agent[AgentDeps, RAGAnswer]: Returns: Configured Agent instance with tools registered. """ - model = get_model_identifier() - validate_api_key_for_model(model) # Fail-fast validation + identifier = get_model_identifier() + validate_api_key_for_model(identifier) # Fail-fast validation + model = build_agent_model(identifier) # str for cloud, Model object for ollama + retries = get_agent_retries() agent: Agent[AgentDeps, RAGAnswer] = Agent( model=model, deps_type=AgentDeps, - output_type=RAGAnswer, + # PromptedOutput puts the JSON schema in the prompt and parses the + # model's text reply, instead of the default ToolOutput mode which + # weaker/local models fail to satisfy (issue #173). + output_type=PromptedOutput(RAGAnswer), system_prompt=RAG_SYSTEM_PROMPT, + # Apply the configured agent_retry_attempts. Without this PydanticAI + # defaults to 1, and weaker models fail structured-output validation. + output_retries=retries, + tool_retries=retries, **get_model_settings(), ) @@ -92,6 +104,7 @@ def create_rag_assistant_agent() -> Agent[AgentDeps, RAGAnswer]: # Register tools with the agent @agent.tool + @recoverable async def tool_retrieve_context( ctx: RunContext[AgentDeps], query: str, @@ -176,6 +189,7 @@ def tool_check_evidence( ) @agent.tool + @recoverable async def tool_list_sources( ctx: RunContext[AgentDeps], ) -> dict[str, Any]: @@ -209,3 +223,13 @@ def get_rag_assistant_agent() -> Agent[AgentDeps, RAGAnswer]: if _rag_assistant_agent is None: _rag_assistant_agent = create_rag_assistant_agent() return _rag_assistant_agent + + +def reset_rag_assistant_agent() -> None: + """Drop the cached RAG assistant agent so the next get_* call rebuilds it. + + Used after a runtime model/key change so the new configuration takes + effect without a process restart. + """ + global _rag_assistant_agent + _rag_assistant_agent = None diff --git a/app/features/agents/service.py b/app/features/agents/service.py index c510befc..751c984c 100644 --- a/app/features/agents/service.py +++ b/app/features/agents/service.py @@ -15,12 +15,14 @@ import asyncio import uuid from collections.abc import AsyncIterator +from contextlib import AbstractContextManager from datetime import UTC, datetime, timedelta from typing import Any, Literal, cast import structlog from pydantic_ai import Agent -from pydantic_ai.messages import ModelMessage +from pydantic_ai.exceptions import UnexpectedModelBehavior +from pydantic_ai.messages import ModelMessage, ModelMessagesTypeAdapter from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession @@ -57,6 +59,22 @@ class NoApprovalPendingError(ValueError): pass +def _sequential_tool_execution() -> AbstractContextManager[None]: + """Run an agent turn's tool calls one at a time, never concurrently. + + Every tool in a run shares the single ``AgentDeps.db`` ``AsyncSession``, + and SQLAlchemy forbids concurrent operations on one session. PydanticAI's + default parallel tool execution therefore raises ``InvalidRequestError`` + whenever a model emits more than one DB-touching tool call in a turn + (issue #172). + + Both :meth:`AgentService.chat` and :meth:`AgentService.stream_chat` wrap + their agent run in this context, so the execution-mode policy lives in + exactly one place. + """ + return Agent.parallel_tool_call_execution_mode("sequential") + + class AgentService: """Service for managing agent sessions and interactions. @@ -238,7 +256,7 @@ async def chat( ) # Run agent with message history - message_history = self._deserialize_messages(session.message_history) + message_history = self._deserialize_messages(session.message_history, session_id) logger.info( "agents.chat_started", @@ -249,18 +267,39 @@ async def chat( ) try: - result = await asyncio.wait_for( - agent.run( - message, - deps=deps, - message_history=message_history, - ), - timeout=self.settings.agent_timeout_seconds, - ) + with _sequential_tool_execution(): + result = await asyncio.wait_for( + agent.run( + message, + deps=deps, + message_history=message_history, + ), + timeout=self.settings.agent_timeout_seconds, + ) except TimeoutError as e: raise TimeoutError( f"Agent response timed out after {self.settings.agent_timeout_seconds} seconds" ) from e + except UnexpectedModelBehavior as e: + # The model misbehaved (e.g. a tool call exceeded its retry budget). + # This is recoverable from the user's perspective — surface a clean + # message instead of leaking the raw PydanticAI exception string. + logger.warning( + "agents.chat_model_misbehavior", + session_id=session_id, + error=str(e), + error_type=type(e).__name__, + ) + session.last_activity = datetime.now(UTC) + await db.flush() + return ChatResponse( + session_id=session_id, + message=( + "I couldn't complete that request — the model produced an " + "invalid tool call. Please try rephrasing, or give me a " + "specific forecasting objective to work on." + ), + ) # Extract tool calls from result tool_calls: list[ToolCallResult] = [] @@ -409,7 +448,7 @@ async def stream_chat( request_id=request_id, ) - message_history = self._deserialize_messages(session.message_history) + message_history = self._deserialize_messages(session.message_history, session_id) logger.info( "agents.stream_chat_started", @@ -419,148 +458,178 @@ async def stream_chat( # Stream the response try: - async with asyncio.timeout(self.settings.agent_timeout_seconds): - async with agent.run_stream( - message, - deps=deps, - message_history=message_history, - ) as result: - try: - async for text in result.stream_text(): - yield StreamEvent( - event_type="text_delta", - data={"delta": text}, - timestamp=datetime.now(UTC), + with _sequential_tool_execution(): + async with asyncio.timeout(self.settings.agent_timeout_seconds): + async with agent.run_stream( + message, + deps=deps, + message_history=message_history, + ) as result: + try: + async for text in result.stream_text(): + yield StreamEvent( + event_type="text_delta", + data={"delta": text}, + timestamp=datetime.now(UTC), + ) + except Exception as e: + # Structured output agents (output_type=...) cannot stream raw text deltas. + # In that case we skip delta streaming and only emit the final complete event. + logger.info( + "agents.stream_chat_text_delta_unavailable", + session_id=session_id, + error=str(e), + error_type=type(e).__name__, ) - except Exception as e: - # Structured output agents (output_type=...) cannot stream raw text deltas. - # In that case we skip delta streaming and only emit the final complete event. - logger.info( - "agents.stream_chat_text_delta_unavailable", - session_id=session_id, - error=str(e), - error_type=type(e).__name__, + + # Get final result and update session + # NOTE: PydanticAI v1.48 exposes get_output() on StreamedRunResult. + final_result: Any = await result.get_output() + usage = result.usage() + + session.message_history = self._serialize_messages(result.all_messages()) + session.total_tokens_used += usage.total_tokens or 0 + session.tool_calls_count += deps.tool_call_count + session.last_activity = datetime.now(UTC) + session.expires_at = session.last_activity + timedelta( + minutes=self.settings.agent_session_ttl_minutes ) - # Get final result and update session - # NOTE: PydanticAI v1.48 exposes get_output() on StreamedRunResult. - final_result: Any = await result.get_output() - usage = result.usage() - - session.message_history = self._serialize_messages(result.all_messages()) - session.total_tokens_used += usage.total_tokens or 0 - session.tool_calls_count += deps.tool_call_count - session.last_activity = datetime.now(UTC) - session.expires_at = session.last_activity + timedelta( - minutes=self.settings.agent_session_ttl_minutes - ) - - await db.flush() - - # Check for pending approval actions (mirror chat() logic) - pending_action = None - pending_approval = False - stream_now = datetime.now(UTC) - - # Check for pending_action in result data (primary trigger) - if hasattr(final_result, "pending_action") and final_result.pending_action: - pending_approval = True - pending_action_data = final_result.pending_action - # Extract action details - support both dict and object with attributes - if isinstance(pending_action_data, dict): - action_type = pending_action_data.get("action_type", "unknown") - arguments = pending_action_data.get("arguments", {}) - description = pending_action_data.get( - "description", f"Agent requested approval for {action_type}" - ) - else: - action_type = getattr(pending_action_data, "action_type", "unknown") - arguments = getattr(pending_action_data, "arguments", {}) - description = getattr( - pending_action_data, - "description", - f"Agent requested approval for {action_type}", + await db.flush() + + # Check for pending approval actions (mirror chat() logic) + pending_action = None + pending_approval = False + stream_now = datetime.now(UTC) + + # Check for pending_action in result data (primary trigger) + if hasattr(final_result, "pending_action") and final_result.pending_action: + pending_approval = True + pending_action_data = final_result.pending_action + # Extract action details - support both dict and object with attributes + if isinstance(pending_action_data, dict): + action_type = pending_action_data.get("action_type", "unknown") + arguments = pending_action_data.get("arguments", {}) + description = pending_action_data.get( + "description", f"Agent requested approval for {action_type}" + ) + else: + action_type = getattr(pending_action_data, "action_type", "unknown") + arguments = getattr(pending_action_data, "arguments", {}) + description = getattr( + pending_action_data, + "description", + f"Agent requested approval for {action_type}", + ) + + session.pending_action = { + "action_id": uuid.uuid4().hex[:16], + "action_type": action_type, + "description": description, + "arguments": arguments, + "created_at": stream_now.isoformat(), + "expires_at": ( + stream_now + + timedelta( + minutes=self.settings.agent_approval_timeout_minutes + ) + ).isoformat(), + } + session.status = SessionStatus.AWAITING_APPROVAL.value + pending_action = self._format_pending_action(session.pending_action) + # Fallback: check approval_required flag (legacy trigger) + elif ( + hasattr(final_result, "approval_required") + and final_result.approval_required + ): + pending_approval = True + session.pending_action = { + "action_id": uuid.uuid4().hex[:16], + "action_type": "unknown", + "description": "Agent requested approval for an action", + "arguments": {}, + "created_at": stream_now.isoformat(), + "expires_at": ( + stream_now + + timedelta( + minutes=self.settings.agent_approval_timeout_minutes + ) + ).isoformat(), + } + session.status = SessionStatus.AWAITING_APPROVAL.value + pending_action = self._format_pending_action(session.pending_action) + + await db.flush() + + # If approval is required, emit approval_required event + if pending_approval and pending_action: + yield StreamEvent( + event_type="approval_required", + data={ + "action": pending_action, + "message": "Human approval required before proceeding.", + }, + timestamp=stream_now, ) - session.pending_action = { - "action_id": uuid.uuid4().hex[:16], - "action_type": action_type, - "description": description, - "arguments": arguments, - "created_at": stream_now.isoformat(), - "expires_at": ( - stream_now - + timedelta(minutes=self.settings.agent_approval_timeout_minutes) - ).isoformat(), - } - session.status = SessionStatus.AWAITING_APPROVAL.value - pending_action = self._format_pending_action(session.pending_action) - # Fallback: check approval_required flag (legacy trigger) - elif ( - hasattr(final_result, "approval_required") - and final_result.approval_required - ): - pending_approval = True - session.pending_action = { - "action_id": uuid.uuid4().hex[:16], - "action_type": "unknown", - "description": "Agent requested approval for an action", - "arguments": {}, - "created_at": stream_now.isoformat(), - "expires_at": ( - stream_now - + timedelta(minutes=self.settings.agent_approval_timeout_minutes) - ).isoformat(), - } - session.status = SessionStatus.AWAITING_APPROVAL.value - pending_action = self._format_pending_action(session.pending_action) - - await db.flush() - - # If approval is required, emit approval_required event - if pending_approval and pending_action: + # Yield completion event + response_message: str = "No response generated." + if final_result: + if hasattr(final_result, "answer") and final_result.answer: + response_message = str(final_result.answer) + elif hasattr(final_result, "summary") and final_result.summary: + response_message = str(final_result.summary) + elif ( + hasattr(final_result, "recommendations") + and final_result.recommendations + ): + recommendations = final_result.recommendations + if isinstance(recommendations, list) and recommendations: + response_message = "\n".join( + str(item) for item in recommendations + ) + else: + response_message = str(final_result) + else: + response_message = str(final_result) + yield StreamEvent( - event_type="approval_required", + event_type="complete", data={ - "action": pending_action, - "message": "Human approval required before proceeding.", + "message": response_message, + "tokens_used": usage.total_tokens or 0, + "tool_calls_count": deps.tool_call_count, + "pending_approval": pending_approval, }, - timestamp=stream_now, + timestamp=datetime.now(UTC), ) - - # Yield completion event - response_message: str = "No response generated." - if final_result: - if hasattr(final_result, "answer") and final_result.answer: - response_message = str(final_result.answer) - elif hasattr(final_result, "summary") and final_result.summary: - response_message = str(final_result.summary) - elif ( - hasattr(final_result, "recommendations") - and final_result.recommendations - ): - recommendations = final_result.recommendations - if isinstance(recommendations, list) and recommendations: - response_message = "\n".join(str(item) for item in recommendations) - else: - response_message = str(final_result) - else: - response_message = str(final_result) - - yield StreamEvent( - event_type="complete", - data={ - "message": response_message, - "tokens_used": usage.total_tokens or 0, - "tool_calls_count": deps.tool_call_count, - "pending_approval": pending_approval, - }, - timestamp=datetime.now(UTC), - ) except TimeoutError as e: raise TimeoutError( f"Agent response timed out after {self.settings.agent_timeout_seconds} seconds" ) from e + except UnexpectedModelBehavior as e: + # The model misbehaved (e.g. a tool call exceeded its retry budget). + # Emit a clean, recoverable `error` event rather than letting the raw + # PydanticAI exception bubble to the WebSocket handler. + logger.warning( + "agents.stream_chat_model_misbehavior", + session_id=session_id, + error=str(e), + error_type=type(e).__name__, + ) + yield StreamEvent( + event_type="error", + data={ + "error": ( + "The assistant produced an invalid tool call and couldn't " + "complete the request. Please try rephrasing your message." + ), + "error_type": "model_behavior_error", + "recoverable": True, + }, + timestamp=datetime.now(UTC), + ) + return logger.info( "agents.stream_chat_completed", @@ -701,75 +770,60 @@ def _serialize_messages( self, messages: list[ModelMessage], ) -> list[dict[str, Any]]: - """Serialize PydanticAI messages for storage. + """Serialize PydanticAI messages to JSON-safe dicts for JSONB storage. - PydanticAI messages (ModelRequest, ModelResponse) are dataclasses, - so we use dataclasses.asdict() for serialization. + Uses PydanticAI's own ``ModelMessagesTypeAdapter`` so the output can be + round-tripped back into real ``ModelMessage`` objects by + :meth:`_deserialize_messages`. Args: - messages: List of ModelMessage objects. + messages: List of ModelMessage objects (e.g. ``result.all_messages()``). Returns: - List of serializable dictionaries. + List of JSON-serializable dictionaries. """ - import dataclasses - from datetime import datetime - - def json_safe(obj: object) -> object: - """Convert non-JSON-serializable objects to JSON-safe types.""" - if isinstance(obj, datetime): - return obj.isoformat() - if isinstance(obj, dict): - return {k: json_safe(v) for k, v in obj.items()} - if isinstance(obj, list): - return [json_safe(item) for item in obj] - # Primitive JSON types pass through - if isinstance(obj, (str, int, float, bool, type(None))): - return obj - # Fallback: convert unknown types to string representation - return str(obj) - - serialized: list[dict[str, Any]] = [] - for msg in messages: - if dataclasses.is_dataclass(msg) and not isinstance(msg, type): - # Convert dataclass to dict, handling nested types - try: - msg_dict = dataclasses.asdict(msg) - # Convert datetime objects to ISO strings - # Cast is safe: json_safe preserves dict structure - msg_dict = cast(dict[str, Any], json_safe(msg_dict)) - # Add kind discriminator for deserialization - if hasattr(msg, "kind"): - msg_dict["kind"] = msg.kind - serialized.append(msg_dict) - except (TypeError, ValueError): - # Fallback for types that can't be converted - serialized.append({"type": type(msg).__name__, "data": str(msg)}) - else: - # Fallback for non-dataclass types - serialized.append({"type": type(msg).__name__, "data": str(msg)}) - return serialized + return cast( + list[dict[str, Any]], + ModelMessagesTypeAdapter.dump_python(messages, mode="json"), + ) def _deserialize_messages( self, data: list[dict[str, Any]], + session_id: str, ) -> list[ModelMessage]: - """Deserialize messages from storage. + """Reconstruct PydanticAI ModelMessage objects from stored dicts. + + PydanticAI's ``run()`` / ``run_stream()`` require ``message_history`` to + be real ``ModelMessage`` instances — passing raw dicts fails when the + framework accesses fields such as ``conversation_id``. Args: data: List of serialized message dictionaries. + session_id: Owning session, logged so a failure can be correlated + with the specific stored record. Returns: - List of ModelMessage objects. - - Note: - PydanticAI handles message reconstruction internally. - We return the raw data for now - the agent.run() method - accepts message history in various formats. + List of ModelMessage objects. Returns an empty list if the stored + data cannot be validated (e.g. it predates this serialization + format) — a lost history is recoverable; a crash is not. """ - # PydanticAI's run() method can accept message history as dicts - # Cast to list[ModelMessage] for type checking - return data # type: ignore[return-value] + if not data: + return [] + try: + return ModelMessagesTypeAdapter.validate_python(data) + except Exception: + # Degrade to an empty history on ANY deserialization failure, not + # just ValidationError: a malformed stored record (wrong shape, + # type errors) must never crash an otherwise-valid agent run. + # exc_info preserves the full exception type, message, and traceback. + logger.warning( + "agents.message_history_deserialize_failed", + session_id=session_id, + message_count=len(data), + exc_info=True, + ) + return [] def _format_pending_action( self, diff --git a/app/features/agents/tests/test_base.py b/app/features/agents/tests/test_base.py new file mode 100644 index 00000000..83193aed --- /dev/null +++ b/app/features/agents/tests/test_base.py @@ -0,0 +1,262 @@ +"""Unit tests for agent base helpers (Ollama-aware model factory).""" + +import re +from collections.abc import Iterator +from typing import Any, cast +from unittest.mock import AsyncMock + +import pytest +from pydantic_ai import ModelRetry +from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart +from pydantic_ai.models.function import AgentInfo, FunctionModel +from pydantic_ai.models.openai import OpenAIChatModel + +from app.core.config import get_settings +from app.features.agents.agents.base import ( + TOOL_USAGE_INSTRUCTIONS, + build_agent_model, + get_agent_retries, + recoverable, + validate_api_key_for_model, +) +from app.features.agents.agents.experiment import ( + EXPERIMENT_SYSTEM_PROMPT, + create_experiment_agent, +) +from app.features.agents.agents.rag_assistant import create_rag_assistant_agent +from app.features.agents.deps import AgentDeps +from app.features.agents.schemas import ExperimentReport, RAGAnswer + + +@pytest.fixture(autouse=True) +def _reset_settings() -> Iterator[None]: + """Reset the settings cache so key mutations do not leak across tests.""" + get_settings.cache_clear() + yield + get_settings.cache_clear() + + +def test_build_agent_model_cloud_returns_string(): + """A cloud identifier is returned unchanged (plain-string Agent path).""" + assert build_agent_model("anthropic:claude-sonnet-4-5") == "anthropic:claude-sonnet-4-5" + + +def test_build_agent_model_openai_returns_string(): + """An openai identifier is also returned unchanged.""" + assert build_agent_model("openai:gpt-4o") == "openai:gpt-4o" + + +def test_build_agent_model_ollama_returns_model_object(): + """An ollama identifier becomes a configured OpenAIChatModel object.""" + model = build_agent_model("ollama:llama3.1") + assert isinstance(model, OpenAIChatModel) + + +def test_validate_api_key_for_model_ollama_skips_key_check(): + """The ollama provider needs no API key — validation must not raise.""" + settings = get_settings() + settings.anthropic_api_key = "" + settings.openai_api_key = "" + settings.google_api_key = "" + # Should return without raising even though no cloud key is configured. + validate_api_key_for_model("ollama:llama3.1") + + +def test_prompts_only_reference_registered_tool_names() -> None: + """Every `tool_*` name in the agent prompts must be an actually-registered tool. + + Regression for issue #175: the prompts named tools as `run_backtest`, + `list_runs`, … but the registered tools are `tool_`-prefixed, so weaker + models called unknown tool names. This test is the single source of truth + for that invariant — the registered set is read off the built agent (not a + hardcoded list), so drift in either direction (a renamed tool or an edited + prompt) fails CI. + """ + settings = get_settings() + settings.agent_default_model = "ollama:llama3.1" + agent = create_experiment_agent() + + captured: dict[str, set[str]] = {} + + def respond(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + captured["registered"] = {tool.name for tool in info.function_tools} + # End the run immediately with a PromptedOutput-parseable text reply. + return ModelResponse(parts=[TextPart(content='{"summary": "noop"}')]) + + agent.run_sync( + "noop", + model=FunctionModel(respond), + deps=AgentDeps(db=AsyncMock(), session_id="test-tool-names"), + ) + registered = captured["registered"] + + # Tool names the prompts instruct the model to call. EXPERIMENT_SYSTEM_PROMPT + # already embeds TOOL_USAGE_INSTRUCTIONS; both are scanned to stay correct + # even if that embedding changes. + prompt_text = TOOL_USAGE_INSTRUCTIONS + EXPERIMENT_SYSTEM_PROMPT + referenced = set(re.findall(r"\btool_[a-z_]+\b", prompt_text)) + + assert referenced, "expected the prompts to name at least one tool" + unknown = referenced - registered + assert not unknown, f"prompts reference unregistered tools: {sorted(unknown)}" + + +async def test_recoverable_converts_valueerror_to_model_retry(): + """A ValueError from a tool becomes a ModelRetry the model can recover from (#176).""" + + @recoverable + async def tool() -> str: + raise ValueError("No data found for store=1") + + with pytest.raises(ModelRetry, match="No data found for store=1"): + await tool() + + +async def test_recoverable_passes_through_other_exceptions(): + """Non-ValueError exceptions are genuine bugs — they must still propagate.""" + + @recoverable + async def tool() -> str: + raise RuntimeError("a real bug") + + with pytest.raises(RuntimeError, match="a real bug"): + await tool() + + +async def test_recoverable_returns_value_on_success(): + """The decorator is transparent when the tool succeeds.""" + + @recoverable + async def tool() -> str: + return "ok" + + assert await tool() == "ok" + + +def test_recoverable_rejects_sync_function() -> None: + """@recoverable is async-only — applying it to a sync function fails fast. + + Without the guard a sync function would be wrapped and then ``await``ed, + surfacing a confusing ``TypeError: ... is not awaitable`` only at call + time. The decorator rejects it at decoration time instead. + """ + + def sync_tool() -> str: + return "nope" + + # recoverable is async-only by type; cast bypasses the static check so the + # runtime guard itself can be exercised. + with pytest.raises(TypeError, match="async tool functions only"): + recoverable(cast(Any, sync_tool)) + + +def test_get_agent_retries_returns_configured_value(): + """get_agent_retries reflects the agent_retry_attempts setting.""" + settings = get_settings() + settings.agent_retry_attempts = 5 + assert get_agent_retries() == 5 + + +def test_experiment_agent_applies_retry_attempts(): + """The experiment agent is built with the configured retry budget. + + Regression for issue #170: agent_retry_attempts was never passed to + Agent(), so PydanticAI silently used its default of 1. + """ + settings = get_settings() + settings.agent_default_model = "ollama:llama3.1" + settings.agent_retry_attempts = 4 + + agent = create_experiment_agent() + + assert agent._max_output_retries == 4 + assert agent._max_tool_retries == 4 + + +def test_rag_assistant_agent_applies_retry_attempts(): + """The RAG assistant agent is built with the configured retry budget.""" + settings = get_settings() + settings.agent_default_model = "ollama:llama3.1" + settings.agent_retry_attempts = 4 + + agent = create_rag_assistant_agent() + + assert agent._max_output_retries == 4 + assert agent._max_tool_retries == 4 + + +def test_experiment_agent_uses_prompted_output() -> None: + """The experiment agent runs in PromptedOutput mode, not default ToolOutput. + + Regression for issue #173: weaker/local models answer in prose and cannot + satisfy the tool-call output contract that the default ToolOutput mode + requires. PromptedOutput puts the JSON schema in the prompt and parses the + model's text reply instead. + + This is asserted behaviorally via the public ``FunctionModel`` test double: + PromptedOutput mode registers no ``final_result`` output tool, and a + plain-text JSON reply is still parsed into a valid ``ExperimentReport``. + """ + settings = get_settings() + settings.agent_default_model = "ollama:llama3.1" + agent = create_experiment_agent() + + report_json = ExperimentReport( + run_id="run-1", + status="success", + summary="seasonal_naive wins", + metrics={"mae": 8.9}, + recommendations=["deploy seasonal_naive"], + ).model_dump_json() + + captured: dict[str, list[str]] = {} + + def respond(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + captured["output_tools"] = [tool.name for tool in info.output_tools] + return ModelResponse(parts=[TextPart(content=report_json)]) + + result = agent.run_sync( + "Run an experiment", + model=FunctionModel(respond), + deps=AgentDeps(db=AsyncMock(), session_id="test-prompted-output"), + ) + + # PromptedOutput mode registers no structured-output tool... + assert captured["output_tools"] == [] + # ...and the plain-text JSON reply is parsed into the structured type. + assert isinstance(result.output, ExperimentReport) + assert result.output.summary == "seasonal_naive wins" + + +def test_rag_assistant_agent_uses_prompted_output() -> None: + """The RAG assistant agent runs in PromptedOutput mode (issue #173). + + Mirrors test_experiment_agent_uses_prompted_output: no ``final_result`` + output tool is registered, and a plain-text JSON reply is parsed into a + valid ``RAGAnswer``. + """ + settings = get_settings() + settings.agent_default_model = "ollama:llama3.1" + agent = create_rag_assistant_agent() + + answer_json = RAGAnswer( + answer="The forecast API supports naive and seasonal_naive models.", + confidence="high", + sources=[{"source_path": "docs/api.md", "relevance": 0.9}], + ).model_dump_json() + + captured: dict[str, list[str]] = {} + + def respond(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + captured["output_tools"] = [tool.name for tool in info.output_tools] + return ModelResponse(parts=[TextPart(content=answer_json)]) + + result = agent.run_sync( + "What models does the forecast API support?", + model=FunctionModel(respond), + deps=AgentDeps(db=AsyncMock(), session_id="test-prompted-output"), + ) + + assert captured["output_tools"] == [] + assert isinstance(result.output, RAGAnswer) + assert result.output.confidence == "high" diff --git a/app/features/agents/tests/test_service.py b/app/features/agents/tests/test_service.py index 4b831c46..08064495 100644 --- a/app/features/agents/tests/test_service.py +++ b/app/features/agents/tests/test_service.py @@ -1,10 +1,21 @@ """Unit tests for agent service.""" +import json +from collections.abc import AsyncIterator from datetime import UTC, datetime, timedelta from typing import Any from unittest.mock import AsyncMock, MagicMock, patch import pytest +from pydantic_ai import Agent +from pydantic_ai.exceptions import UnexpectedModelBehavior +from pydantic_ai.messages import ( + ModelMessage, + ModelRequest, + ModelResponse, + TextPart, + UserPromptPart, +) from app.features.agents.deps import AgentDeps from app.features.agents.models import AgentSession, AgentType, SessionStatus @@ -287,6 +298,192 @@ async def test_chat_success( assert response.tokens_used == 100 mock_agent.run.assert_called_once() + @pytest.mark.asyncio + async def test_chat_model_misbehavior_returns_friendly_message( + self, + sample_active_session: AgentSession, + ) -> None: + """A misbehaving model should yield a clean message, not crash. + + Regression for issue #164: a tool call exceeding its retry budget + raised PydanticAI's `UnexpectedModelBehavior`, whose raw string + ("Tool '...' exceeded max retries count of 1") leaked to the user. + """ + service = AgentService() + mock_db = AsyncMock() + + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = sample_active_session + mock_db.execute.return_value = mock_result + + mock_agent = MagicMock() + mock_agent.run = AsyncMock( + side_effect=UnexpectedModelBehavior( + "Tool 'tool_compare_backtest_results' exceeded max retries count of 1" + ) + ) + + with patch.object(service, "_get_agent", return_value=mock_agent): + response = await service.chat( + db=mock_db, + session_id=sample_active_session.session_id, + message="Hello", + ) + + assert response.session_id == sample_active_session.session_id + assert response.pending_approval is False + assert "invalid tool call" in response.message + assert "exceeded max retries" not in response.message + + @pytest.mark.asyncio + async def test_chat_runs_tools_sequentially( + self, + sample_active_session: AgentSession, + sample_experiment_report: ExperimentReport, + ) -> None: + """chat() must run the agent under sequential tool execution. + + Regression for issue #172: every tool shares the single AgentDeps.db + AsyncSession, so concurrent tool calls raised SQLAlchemy's + InvalidRequestError. The service must enter PydanticAI's public + ``Agent.parallel_tool_call_execution_mode("sequential")`` context + around the agent run. + """ + service = AgentService() + mock_db = AsyncMock() + + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = sample_active_session + mock_db.execute.return_value = mock_result + + run_result = MagicMock() + run_result.output = sample_experiment_report + usage = MagicMock() + usage.total_tokens = 1 + run_result.usage.return_value = usage + run_result.all_messages.return_value = [] + + mock_agent = MagicMock() + mock_agent.run = AsyncMock(return_value=run_result) + + with ( + patch.object(service, "_get_agent", return_value=mock_agent), + patch.object(Agent, "parallel_tool_call_execution_mode") as mock_mode, + ): + await service.chat( + db=mock_db, + session_id=sample_active_session.session_id, + message="Run a backtest", + ) + + mock_mode.assert_called_once_with("sequential") + + +class TestAgentServiceStreamChat: + """Tests for streaming chat functionality.""" + + @pytest.mark.asyncio + async def test_stream_chat_model_misbehavior_yields_error_event( + self, + sample_active_session: AgentSession, + ) -> None: + """A misbehaving model should yield a recoverable `error` event, not crash. + + Regression for issue #164: `UnexpectedModelBehavior` raised inside + `agent.run_stream` bubbled to the WebSocket handler, which echoed the + raw exception string to the client. + """ + service = AgentService() + mock_db = AsyncMock() + + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = sample_active_session + mock_db.execute.return_value = mock_result + + class _RaisingStream: + """Async context manager that fails on entry like a misbehaving run.""" + + async def __aenter__(self) -> Any: + raise UnexpectedModelBehavior( + "Tool 'tool_compare_backtest_results' exceeded max retries count of 1" + ) + + async def __aexit__(self, *exc: object) -> bool: + return False + + mock_agent = MagicMock() + mock_agent.run_stream = MagicMock(return_value=_RaisingStream()) + + with patch.object(service, "_get_agent", return_value=mock_agent): + events = [ + event + async for event in service.stream_chat( + db=mock_db, + session_id=sample_active_session.session_id, + message="Hello", + ) + ] + + assert len(events) == 1 + assert events[0].event_type == "error" + assert events[0].data["recoverable"] is True + assert events[0].data["error_type"] == "model_behavior_error" + assert "exceeded max retries" not in events[0].data["error"] + + @pytest.mark.asyncio + async def test_stream_chat_runs_tools_sequentially( + self, + sample_active_session: AgentSession, + ) -> None: + """stream_chat() must also run the agent under sequential tool execution. + + Mirrors test_chat_runs_tools_sequentially for the streaming path so a + future change to only one code path cannot silently reintroduce the + concurrent-session bug from issue #172. + """ + service = AgentService() + mock_db = AsyncMock() + + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = sample_active_session + mock_db.execute.return_value = mock_result + + class _StubStream: + """Minimal async-context-manager stand-in for agent.run_stream(...).""" + + async def __aenter__(self) -> MagicMock: + stream = MagicMock() + + async def _stream_text() -> AsyncIterator[str]: + yield "hello" + + stream.stream_text = _stream_text + stream.get_output = AsyncMock(return_value=None) + usage = MagicMock() + usage.total_tokens = 1 + stream.usage.return_value = usage + stream.all_messages.return_value = [] + return stream + + async def __aexit__(self, *exc: object) -> bool: + return False + + mock_agent = MagicMock() + mock_agent.run_stream = MagicMock(return_value=_StubStream()) + + with ( + patch.object(service, "_get_agent", return_value=mock_agent), + patch.object(Agent, "parallel_tool_call_execution_mode") as mock_mode, + ): + async for _event in service.stream_chat( + db=mock_db, + session_id=sample_active_session.session_id, + message="Run a backtest", + ): + pass + + mock_mode.assert_called_once_with("sequential") + class TestAgentServiceApproval: """Tests for approval workflow.""" @@ -467,16 +664,49 @@ def test_serialize_empty_messages(self) -> None: def test_deserialize_empty_messages(self) -> None: """Should handle empty message data.""" service = AgentService() - result = service._deserialize_messages([]) + result = service._deserialize_messages([], "test-session") + assert result == [] + + def test_serialize_deserialize_roundtrip(self) -> None: + """Messages should round-trip back into real ModelMessage objects. + + Regression for issue #166: _deserialize_messages used to return raw + dicts, which crashed PydanticAI 1.96 when it accessed `conversation_id` + on the history items. + """ + service = AgentService() + messages: list[ModelMessage] = [ + ModelRequest(parts=[UserPromptPart(content="run a backtest")]), + ModelResponse(parts=[TextPart(content="done")]), + ] + + serialized = service._serialize_messages(messages) + # Serialized form must survive a JSONB write (pure JSON types only). + json.dumps(serialized) + + restored = service._deserialize_messages(serialized, "test-session") + + assert [type(m).__name__ for m in restored] == ["ModelRequest", "ModelResponse"] + # The attribute whose absence on a dict caused the original crash. + assert restored[0].conversation_id is None + + def test_deserialize_legacy_format_returns_empty(self) -> None: + """Unparseable (pre-#166) stored history degrades to empty, not a crash.""" + service = AgentService() + legacy: list[dict[str, Any]] = [{"type": "ModelRequest", "data": ""}] + result = service._deserialize_messages(legacy, "test-session") assert result == [] - def test_deserialize_returns_raw_data(self) -> None: - """Should return raw data for PydanticAI compatibility.""" + def test_deserialize_non_validation_error_returns_empty(self) -> None: + """Any deserialization failure degrades to empty, not only ValidationError.""" service = AgentService() data: list[dict[str, Any]] = [{"kind": "request", "parts": []}] - result = service._deserialize_messages(data) - # _deserialize_messages returns raw dicts for PydanticAI - assert len(result) == 1 + with patch( + "app.features.agents.service.ModelMessagesTypeAdapter.validate_python", + side_effect=TypeError("unexpected adapter failure"), + ): + result = service._deserialize_messages(data, "test-session") + assert result == [] class TestAgentServicePendingActionFormat: diff --git a/app/features/config/__init__.py b/app/features/config/__init__.py new file mode 100644 index 00000000..9f2c08a8 --- /dev/null +++ b/app/features/config/__init__.py @@ -0,0 +1,6 @@ +"""Runtime-editable application configuration slice. + +Exposes the ``app_config`` key/value override store, the ``/config`` REST +surface, and the service that applies persisted overrides onto the live +``Settings`` singleton (agent LLM model, RAG embedding model, provider keys). +""" diff --git a/app/features/config/models.py b/app/features/config/models.py new file mode 100644 index 00000000..451847a6 --- /dev/null +++ b/app/features/config/models.py @@ -0,0 +1,38 @@ +"""ORM model for the runtime configuration override store. + +The ``app_config`` table is a small key/value store. Each row overrides one +``Settings`` field; the value is wrapped as ``{"v": }`` so the JSONB +column holds a consistent object shape regardless of the scalar type. +""" + +from __future__ import annotations + +import datetime +from typing import Any + +from sqlalchemy import DateTime, String, func +from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy.orm import Mapped, mapped_column + +from app.core.database import Base + + +class AppConfig(Base): + """Key/value override store for runtime-editable settings. + + Attributes: + key: The ``Settings`` field name being overridden (primary key). + value: JSONB envelope ``{"v": }`` carrying the override value. + updated_at: Timestamp of the last write. + """ + + __tablename__ = "app_config" + + key: Mapped[str] = mapped_column(String(100), primary_key=True) + value: Mapped[dict[str, Any]] = mapped_column(JSONB, nullable=False) + updated_at: Mapped[datetime.datetime] = mapped_column( + DateTime(timezone=True), + server_default=func.now(), + onupdate=func.now(), + nullable=False, + ) diff --git a/app/features/config/routes.py b/app/features/config/routes.py new file mode 100644 index 00000000..a94d1dbc --- /dev/null +++ b/app/features/config/routes.py @@ -0,0 +1,84 @@ +"""FastAPI routes for runtime AI-model configuration. + +Provides the ``/config`` surface backing the dashboard "AI Models" admin tab: +read the effective config, edit it (applied live), and probe provider health. +""" + +from fastapi import APIRouter, Depends +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.database import get_db +from app.core.logging import get_logger +from app.features.config import schemas, service + +router = APIRouter(prefix="/config", tags=["config"]) +logger = get_logger(__name__) + + +@router.get( + "/ai", + response_model=schemas.AIModelConfig, + summary="Get effective AI-model configuration", + description=( + "Return the live agent LLM + RAG embedding configuration. API keys are " + "masked — the raw value is never returned." + ), +) +async def get_ai_config( + db: AsyncSession = Depends(get_db), +) -> schemas.AIModelConfig: + """Return the effective AI-model configuration with masked secrets.""" + return await service.get_effective_config(db) + + +@router.patch( + "/ai", + response_model=schemas.AIModelConfig, + summary="Update AI-model configuration", + description=( + "Persist and immediately apply changes to the agent LLM, RAG embedding " + "model, or provider API keys. Changes take effect with no restart. " + "Returns 409 if an embedding-dimension change would break indexed RAG " + "chunks (resend with force=true to override)." + ), +) +async def update_ai_config( + payload: schemas.AIModelConfigUpdate, + db: AsyncSession = Depends(get_db), +) -> schemas.AIModelConfig: + """Persist + apply an AI-model configuration change. + + Raises: + HTTPException: 400 (no fields), 409 (dimension-change guard), or 422 + (invalid model identifier — raised at the schema boundary). + """ + return await service.update_config(db, payload) + + +@router.get( + "/providers/health", + response_model=list[schemas.ProviderHealth], + summary="Check AI provider connectivity", + description=( + "Report connectivity for each provider: Ollama is probed live, cloud " + "providers report API-key presence." + ), +) +async def get_providers_health() -> list[schemas.ProviderHealth]: + """Return connectivity status for every AI provider.""" + return await service.get_provider_health() + + +@router.get( + "/ollama/models", + response_model=list[schemas.OllamaModel], + summary="List local Ollama models", + description="List the models pulled on the configured Ollama host.", +) +async def get_ollama_models() -> list[schemas.OllamaModel]: + """List the Ollama host's pulled models. + + Raises: + HTTPException: 502 if the Ollama host is unreachable. + """ + return await service.list_ollama_models() diff --git a/app/features/config/schemas.py b/app/features/config/schemas.py new file mode 100644 index 00000000..cca1ae0d --- /dev/null +++ b/app/features/config/schemas.py @@ -0,0 +1,138 @@ +"""Pydantic schemas for the runtime configuration slice. + +Request bodies keep ``ConfigDict(strict=True)`` per repo policy +(``.claude/rules/security-patterns.md``). Every field here is a JSON-native +scalar (str/int/float/bool), so no ``Field(strict=False)`` override is needed +and ``app/core/tests/test_strict_mode_policy.py`` stays green. +""" + +from __future__ import annotations + +from typing import Literal + +from pydantic import BaseModel, ConfigDict, Field, field_validator + +from app.core.config import validate_model_identifier + +# ``Settings`` fields the operator may override at runtime via PATCH /config/ai. +ALLOWED_OVERRIDE_KEYS: frozenset[str] = frozenset( + { + "agent_default_model", + "agent_fallback_model", + "agent_temperature", + "agent_max_tokens", + "agent_thinking_budget", + "rag_embedding_provider", + "rag_embedding_model", + "rag_embedding_dimension", + "ollama_base_url", + "ollama_embedding_model", + "openai_api_key", + "anthropic_api_key", + "google_api_key", + } +) + +# Subset of ALLOWED_OVERRIDE_KEYS holding secrets: masked in responses, +# mirrored into os.environ on save, and NEVER logged or returned raw. +SECRET_KEYS: frozenset[str] = frozenset( + { + "openai_api_key", + "anthropic_api_key", + "google_api_key", + } +) + +# Maps a secret Settings field to its environment-variable name. +SECRET_ENV_NAMES: dict[str, str] = { + "openai_api_key": "OPENAI_API_KEY", + "anthropic_api_key": "ANTHROPIC_API_KEY", + "google_api_key": "GOOGLE_API_KEY", +} + + +class ApiKeyStatus(BaseModel): + """Presence + masked preview of one provider API key (response only).""" + + provider: str = Field(description="Provider name: 'openai' | 'anthropic' | 'google'") + is_set: bool = Field(description="True when a non-empty key is configured") + masked: str | None = Field( + default=None, + description="Masked preview (e.g. 'sk-ant…3f9a'); None when no key is set", + ) + + +class AIModelConfig(BaseModel): + """Effective AI-model configuration — GET /config/ai response.""" + + agent_default_model: str = Field(description="Active agent LLM identifier") + agent_fallback_model: str = Field(description="Fallback agent LLM identifier") + agent_temperature: float = Field(description="Agent sampling temperature") + agent_max_tokens: int = Field(description="Agent response token cap") + agent_thinking_budget: int | None = Field( + description="Extended-reasoning token budget (Gemini 2.5+); None disables it" + ) + rag_embedding_provider: str = Field(description="RAG embedding provider: 'openai' | 'ollama'") + rag_embedding_model: str = Field(description="OpenAI embedding model name") + rag_embedding_dimension: int = Field(description="Embedding vector dimension") + ollama_base_url: str = Field(description="Ollama server base URL") + ollama_embedding_model: str = Field(description="Ollama embedding model name") + api_keys: list[ApiKeyStatus] = Field(description="Per-provider key presence (masked)") + overridden_keys: list[str] = Field( + description="Keys currently sourced from app_config rather than the environment" + ) + + +class AIModelConfigUpdate(BaseModel): + """Partial update for the AI-model configuration — PATCH /config/ai body. + + All fields are optional; only the non-null ones are persisted and applied. + """ + + model_config = ConfigDict(strict=True) + + agent_default_model: str | None = None + agent_fallback_model: str | None = None + agent_temperature: float | None = Field(default=None, ge=0.0, le=2.0) + agent_max_tokens: int | None = Field(default=None, ge=1) + agent_thinking_budget: int | None = Field(default=None, ge=1) + rag_embedding_provider: Literal["openai", "ollama"] | None = None + rag_embedding_model: str | None = None + rag_embedding_dimension: int | None = Field(default=None, ge=1) + ollama_base_url: str | None = None + ollama_embedding_model: str | None = None + openai_api_key: str | None = None + anthropic_api_key: str | None = None + google_api_key: str | None = None + force: bool = Field( + default=False, + description="Bypass the embedding-dimension-change guard (re-index required)", + ) + + @field_validator("agent_default_model", "agent_fallback_model") + @classmethod + def _check_model_identifier(cls, v: str | None) -> str | None: + """Validate agent model identifiers (incl. the ``ollama:`` provider).""" + if v is None: + return None + return validate_model_identifier(v) + + +class OllamaModel(BaseModel): + """One model pulled on the Ollama host (from GET /api/tags).""" + + name: str = Field(description="Model name, e.g. 'llama3.1:latest'") + size_bytes: int | None = Field(default=None, description="On-disk size in bytes") + family: str | None = Field(default=None, description="Model family, e.g. 'llama'") + + +class ProviderHealth(BaseModel): + """Connectivity status for one AI provider — GET /config/providers/health.""" + + provider: str = Field(description="'ollama' | 'openai' | 'anthropic' | 'google'") + reachable: bool = Field(description="True when the provider is usable") + detail: str = Field(description="Human-readable status detail") + models: list[str] = Field( + default_factory=list, + description="Local model names (populated for the 'ollama' provider)", + ) diff --git a/app/features/config/service.py b/app/features/config/service.py new file mode 100644 index 00000000..a9c5da6e --- /dev/null +++ b/app/features/config/service.py @@ -0,0 +1,333 @@ +"""Service layer for the runtime configuration slice. + +Responsibilities: +- Read the *effective* AI-model configuration (Settings + which keys come from + the ``app_config`` override store), with secrets masked. +- Persist operator edits to ``app_config``, apply them onto the cached + ``Settings`` singleton, and invalidate the agent / embedding caches so the + change takes effect live (no process restart). +- Re-apply persisted overrides on startup. +- Probe provider connectivity (Ollama reachability + cloud key presence). + +CRITICAL: API key values are NEVER returned in a GET response and NEVER logged +(only key names + booleans). See ``.claude/rules/security-patterns.md``. +""" + +from __future__ import annotations + +import os +from typing import Any + +import httpx +from fastapi import HTTPException, status +from sqlalchemy import func, select +from sqlalchemy.dialects.postgresql import insert as pg_insert +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.config import get_settings +from app.core.logging import get_logger +from app.features.agents.agents.base import reset_agent_caches +from app.features.config.models import AppConfig +from app.features.config.schemas import ( + ALLOWED_OVERRIDE_KEYS, + SECRET_ENV_NAMES, + SECRET_KEYS, + AIModelConfig, + AIModelConfigUpdate, + ApiKeyStatus, + OllamaModel, + ProviderHealth, +) +from app.features.rag.embeddings import reset_embedding_service +from app.features.rag.models import DocumentChunk + +logger = get_logger(__name__) + +# Scalar types an ``app_config`` override value may hold. +OverrideValue = str | int | float | bool + + +# ============================================================================= +# Helpers +# ============================================================================= + + +def _mask_secret(value: str) -> str | None: + """Return a masked preview of a secret, or None when it is unset. + + Shows a short prefix and the last 4 characters so an operator can confirm + *which* key is configured without the value ever leaving the server intact. + """ + if not value: + return None + if len(value) <= 11: + return f"…{value[-4:]}" + return f"{value[:7]}…{value[-4:]}" + + +def _key_status(provider: str, value: str) -> ApiKeyStatus: + """Build an :class:`ApiKeyStatus` for one provider key.""" + return ApiKeyStatus(provider=provider, is_set=bool(value), masked=_mask_secret(value)) + + +async def _load_overrides(db: AsyncSession) -> dict[str, Any]: + """Load all persisted overrides as a ``{key: scalar}`` mapping.""" + result = await db.execute(select(AppConfig)) + return {row.key: row.value.get("v") for row in result.scalars().all()} + + +async def _count_rag_chunks(db: AsyncSession) -> int: + """Count indexed RAG chunks (used by the embedding-dimension-change guard).""" + result = await db.execute(select(func.count()).select_from(DocumentChunk)) + return int(result.scalar_one()) + + +async def _upsert_app_config(db: AsyncSession, key: str, value: OverrideValue) -> None: + """Insert-or-update one override row (parameter-bound ``ON CONFLICT``).""" + stmt = pg_insert(AppConfig).values(key=key, value={"v": value}) + stmt = stmt.on_conflict_do_update( + index_elements=["key"], + set_={"value": stmt.excluded.value, "updated_at": func.now()}, + ) + await db.execute(stmt) + + +async def _fetch_ollama_models() -> list[OllamaModel]: + """Query the Ollama host's native ``/api/tags`` endpoint. + + Returns: + The host's pulled models. + + Raises: + httpx.HTTPError: If the host is unreachable or returns an error. + """ + settings = get_settings() + base_url = settings.ollama_base_url.rstrip("/") + async with httpx.AsyncClient(timeout=httpx.Timeout(10.0, connect=5.0)) as client: + response = await client.get(f"{base_url}/api/tags") + response.raise_for_status() + data = response.json() + + models: list[OllamaModel] = [] + for entry in data.get("models", []): + details = entry.get("details") or {} + models.append( + OllamaModel( + name=entry.get("name") or entry.get("model") or "unknown", + size_bytes=entry.get("size"), + family=details.get("family"), + ) + ) + return models + + +# ============================================================================= +# Public service functions +# ============================================================================= + + +async def get_effective_config(db: AsyncSession) -> AIModelConfig: + """Return the effective AI-model configuration with secrets masked. + + Args: + db: Async database session. + + Returns: + The live configuration plus which keys are sourced from ``app_config``. + """ + settings = get_settings() + overrides = await _load_overrides(db) + + return AIModelConfig( + agent_default_model=settings.agent_default_model, + agent_fallback_model=settings.agent_fallback_model, + agent_temperature=settings.agent_temperature, + agent_max_tokens=settings.agent_max_tokens, + agent_thinking_budget=settings.agent_thinking_budget, + rag_embedding_provider=settings.rag_embedding_provider, + rag_embedding_model=settings.rag_embedding_model, + rag_embedding_dimension=settings.rag_embedding_dimension, + ollama_base_url=settings.ollama_base_url, + ollama_embedding_model=settings.ollama_embedding_model, + api_keys=[ + _key_status("openai", settings.openai_api_key), + _key_status("anthropic", settings.anthropic_api_key), + _key_status("google", settings.google_api_key), + ], + overridden_keys=sorted(k for k in overrides if k in ALLOWED_OVERRIDE_KEYS), + ) + + +async def apply_overrides_on_startup(db: AsyncSession) -> None: + """Re-apply persisted overrides onto the ``Settings`` singleton at startup. + + Safe by design: if the ``app_config`` table does not exist yet (brand-new + database) the load is caught and the app boots on environment defaults. + + Args: + db: Async database session. + """ + try: + overrides = await _load_overrides(db) + except Exception as exc: # never let config crash startup + logger.warning( + "config.overrides_load_failed", + error=str(exc), + error_type=type(exc).__name__, + ) + return + + if not overrides: + return + + settings = get_settings() + applied: list[str] = [] + for key, value in overrides.items(): + if key not in ALLOWED_OVERRIDE_KEYS: + continue + setattr(settings, key, value) + if key in SECRET_KEYS and isinstance(value, str): + os.environ[SECRET_ENV_NAMES[key]] = value + applied.append(key) + + # Drop cached singletons so the first agent / embedding build uses overrides. + reset_agent_caches() + reset_embedding_service() + logger.info("config.overrides_applied", keys=sorted(applied)) + + +async def update_config(db: AsyncSession, payload: AIModelConfigUpdate) -> AIModelConfig: + """Persist + apply an AI-model configuration change. + + Validates (model identifiers are checked at the schema boundary), guards + against breaking RAG retrieval, persists to ``app_config``, mutates the + live ``Settings`` singleton, and invalidates the agent / embedding caches. + + Args: + db: Async database session. + payload: The partial update; only non-null fields are applied. + + Returns: + The effective configuration after the change. + + Raises: + HTTPException: 400 if no fields were supplied; 409 if the embedding + dimension change would break existing RAG chunks (without force). + """ + settings = get_settings() + changes: dict[str, Any] = payload.model_dump(exclude_none=True, exclude={"force"}) + + if not changes: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="No configuration fields provided to update.", + ) + + # Guard: changing the embedding dimension orphans every existing chunk. + new_dim = changes.get("rag_embedding_dimension") + if new_dim is not None and new_dim != settings.rag_embedding_dimension and not payload.force: + chunk_count = await _count_rag_chunks(db) + if chunk_count > 0: + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail=( + f"Changing the embedding dimension with {chunk_count} indexed " + "RAG chunk(s) breaks retrieval. Delete the RAG sources first, " + "or resend the request with force=true." + ), + ) + + # Persist every override (parameter-bound upsert). + for key, value in changes.items(): + await _upsert_app_config(db, key, value) + await db.commit() + + # Apply onto the live (cached) Settings singleton; mirror secrets to env. + for key, value in changes.items(): + setattr(settings, key, value) + if key in SECRET_KEYS and isinstance(value, str): + # Unconditional overwrite — a replaced key must evict the stale one. + os.environ[SECRET_ENV_NAMES[key]] = value + + # CRITICAL: invalidate caches so the change is visible on the next request. + reset_agent_caches() + reset_embedding_service() + + logger.info( + "config.updated", + keys=sorted(changes), + secrets=sorted(k for k in changes if k in SECRET_KEYS), + ) + return await get_effective_config(db) + + +async def get_provider_health() -> list[ProviderHealth]: + """Report connectivity for every AI provider. + + Ollama is probed live (``/api/tags``); cloud providers report API-key + presence (a cheap, offline proxy for usability). + + Returns: + One :class:`ProviderHealth` per provider. + """ + settings = get_settings() + health: list[ProviderHealth] = [] + + # Ollama — live probe. + try: + models = await _fetch_ollama_models() + health.append( + ProviderHealth( + provider="ollama", + reachable=True, + detail=(f"Reachable at {settings.ollama_base_url} ({len(models)} model(s) pulled)"), + models=[m.name for m in models], + ) + ) + except httpx.HTTPError as exc: + health.append( + ProviderHealth( + provider="ollama", + reachable=False, + detail=f"Not reachable at {settings.ollama_base_url}: {exc}", + ) + ) + + # Cloud providers — key presence. + for provider, key in ( + ("openai", settings.openai_api_key), + ("anthropic", settings.anthropic_api_key), + ("google", settings.google_api_key), + ): + is_set = bool(key) + health.append( + ProviderHealth( + provider=provider, + reachable=is_set, + detail="API key configured" if is_set else "API key not set", + ) + ) + + return health + + +async def list_ollama_models() -> list[OllamaModel]: + """List the Ollama host's pulled models. + + Returns: + The host's pulled models. + + Raises: + HTTPException: 502 if the Ollama host is unreachable. + """ + try: + return await _fetch_ollama_models() + except httpx.HTTPError as exc: + settings = get_settings() + raise HTTPException( + status_code=status.HTTP_502_BAD_GATEWAY, + detail=( + f"Could not reach Ollama at {settings.ollama_base_url}: {exc}. " + "Ensure 'ollama serve' is running." + ), + ) from exc diff --git a/app/features/config/tests/__init__.py b/app/features/config/tests/__init__.py new file mode 100644 index 00000000..f80f6582 --- /dev/null +++ b/app/features/config/tests/__init__.py @@ -0,0 +1 @@ +"""Tests for the runtime configuration slice.""" diff --git a/app/features/config/tests/conftest.py b/app/features/config/tests/conftest.py new file mode 100644 index 00000000..557ba461 --- /dev/null +++ b/app/features/config/tests/conftest.py @@ -0,0 +1,50 @@ +"""Test fixtures for the config feature.""" + +from collections.abc import AsyncGenerator, Iterator + +import pytest +from sqlalchemy import delete +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from app.core.config import get_settings +from app.features.config.models import AppConfig + + +@pytest.fixture(autouse=True) +def reset_caches() -> Iterator[None]: + """Reset the settings cache + agent/embedding singletons between tests. + + Config tests mutate the cached ``Settings`` singleton and the agent / + embedding module globals; isolating them keeps tests order-independent. + """ + from app.features.agents.agents import experiment, rag_assistant + from app.features.rag import embeddings + + get_settings.cache_clear() + yield + get_settings.cache_clear() + experiment._experiment_agent = None + rag_assistant._rag_assistant_agent = None + embeddings._embedding_provider = None + + +@pytest.fixture +async def db_session() -> AsyncGenerator[AsyncSession, None]: + """Async DB session for integration tests; wipes app_config on teardown. + + Requires PostgreSQL to be running (docker-compose up -d) and migrations + applied. + """ + settings = get_settings() + engine = create_async_engine(settings.database_url, echo=False) + async_session_maker = async_sessionmaker(engine, class_=AsyncSession, expire_on_commit=False) + + async with async_session_maker() as session: + try: + yield session + finally: + await session.rollback() + await session.execute(delete(AppConfig)) + await session.commit() + + await engine.dispose() diff --git a/app/features/config/tests/test_routes.py b/app/features/config/tests/test_routes.py new file mode 100644 index 00000000..2f064afe --- /dev/null +++ b/app/features/config/tests/test_routes.py @@ -0,0 +1,156 @@ +"""Unit tests for config slice routes. + +The DB dependency is overridden with a stub session; the service layer is +patched so routes are exercised in isolation. +""" + +from collections.abc import AsyncGenerator, Iterator +from unittest.mock import AsyncMock, patch + +import pytest +from fastapi import HTTPException +from fastapi.testclient import TestClient + +from app.core.database import get_db +from app.features.config import schemas +from app.main import app + + +def _sample_config( + agent_default_model: str = "anthropic:claude-sonnet-4-5", + agent_temperature: float = 0.1, +) -> schemas.AIModelConfig: + """Build a representative AIModelConfig response.""" + return schemas.AIModelConfig( + agent_default_model=agent_default_model, + agent_fallback_model="openai:gpt-4o", + agent_temperature=agent_temperature, + agent_max_tokens=4096, + agent_thinking_budget=None, + rag_embedding_provider="openai", + rag_embedding_model="text-embedding-3-small", + rag_embedding_dimension=1536, + ollama_base_url="http://localhost:11434", + ollama_embedding_model="nomic-embed-text", + api_keys=[schemas.ApiKeyStatus(provider="anthropic", is_set=True, masked="sk-ant-…1234")], + overridden_keys=[], + ) + + +@pytest.fixture +def client() -> Iterator[TestClient]: + """Test client with the DB dependency stubbed (no lifespan, no real DB).""" + + async def _override_get_db() -> AsyncGenerator[AsyncMock, None]: + yield AsyncMock() + + app.dependency_overrides[get_db] = _override_get_db + yield TestClient(app) + app.dependency_overrides.clear() + + +class TestGetAIConfig: + """Tests for GET /config/ai.""" + + def test_returns_effective_config(self, client): + """The endpoint returns the effective config with masked keys.""" + with patch( + "app.features.config.routes.service.get_effective_config", + new=AsyncMock(return_value=_sample_config()), + ): + response = client.get("/config/ai") + + assert response.status_code == 200 + data = response.json() + assert data["agent_default_model"] == "anthropic:claude-sonnet-4-5" + assert data["api_keys"][0]["masked"] == "sk-ant-…1234" + + +class TestUpdateAIConfig: + """Tests for PATCH /config/ai.""" + + def test_patch_applies_change(self, client): + """A valid update returns the new effective config.""" + with patch( + "app.features.config.routes.service.update_config", + new=AsyncMock(return_value=_sample_config(agent_temperature=0.3)), + ): + response = client.patch("/config/ai", json={"agent_temperature": 0.3}) + + assert response.status_code == 200 + assert response.json()["agent_temperature"] == 0.3 + + def test_patch_accepts_ollama_model(self, client): + """An ollama agent model passes schema validation and reaches the service.""" + with patch( + "app.features.config.routes.service.update_config", + new=AsyncMock(return_value=_sample_config(agent_default_model="ollama:llama3.1")), + ): + response = client.patch("/config/ai", json={"agent_default_model": "ollama:llama3.1"}) + + assert response.status_code == 200 + assert response.json()["agent_default_model"] == "ollama:llama3.1" + + def test_patch_rejects_invalid_model(self, client): + """An invalid model identifier is rejected at the schema boundary (422).""" + response = client.patch("/config/ai", json={"agent_default_model": "nope"}) + assert response.status_code == 422 + + def test_patch_surfaces_dimension_conflict(self, client): + """A 409 from the dimension guard propagates to the caller.""" + with patch( + "app.features.config.routes.service.update_config", + new=AsyncMock(side_effect=HTTPException(status_code=409, detail="dimension change")), + ): + response = client.patch("/config/ai", json={"rag_embedding_dimension": 768}) + + assert response.status_code == 409 + + +class TestProviderHealthRoute: + """Tests for GET /config/providers/health.""" + + def test_returns_health(self, client): + """The endpoint returns one entry per provider.""" + with patch( + "app.features.config.routes.service.get_provider_health", + new=AsyncMock( + return_value=[ + schemas.ProviderHealth( + provider="ollama", + reachable=True, + detail="ok", + models=["llama3.1"], + ), + ] + ), + ): + response = client.get("/config/providers/health") + + assert response.status_code == 200 + assert response.json()[0]["provider"] == "ollama" + + +class TestOllamaModelsRoute: + """Tests for GET /config/ollama/models.""" + + def test_returns_models(self, client): + """The endpoint returns the host's pulled models.""" + with patch( + "app.features.config.routes.service.list_ollama_models", + new=AsyncMock(return_value=[schemas.OllamaModel(name="llama3.1:latest")]), + ): + response = client.get("/config/ollama/models") + + assert response.status_code == 200 + assert response.json()[0]["name"] == "llama3.1:latest" + + def test_unreachable_returns_502(self, client): + """An unreachable Ollama host surfaces as a 502.""" + with patch( + "app.features.config.routes.service.list_ollama_models", + new=AsyncMock(side_effect=HTTPException(status_code=502, detail="unreachable")), + ): + response = client.get("/config/ollama/models") + + assert response.status_code == 502 diff --git a/app/features/config/tests/test_schemas.py b/app/features/config/tests/test_schemas.py new file mode 100644 index 00000000..c11232df --- /dev/null +++ b/app/features/config/tests/test_schemas.py @@ -0,0 +1,133 @@ +"""Unit tests for config slice schemas.""" + +import pytest +from pydantic import ValidationError + +from app.core.config import validate_model_identifier +from app.features.config.schemas import ( + AIModelConfig, + AIModelConfigUpdate, + ApiKeyStatus, + OllamaModel, + ProviderHealth, +) + + +class TestValidateModelIdentifier: + """Tests for the shared validate_model_identifier helper.""" + + def test_validate_model_identifier_accepts_ollama(self): + """An 'ollama:' identifier is accepted.""" + assert validate_model_identifier("ollama:llama3.1") == "ollama:llama3.1" + + def test_validate_model_identifier_accepts_cloud(self): + """Cloud-provider identifiers keep working unchanged.""" + assert ( + validate_model_identifier("anthropic:claude-sonnet-4-5") + == "anthropic:claude-sonnet-4-5" + ) + + def test_validate_model_identifier_rejects_blank_ollama(self): + """'ollama:' with no model name is rejected.""" + with pytest.raises(ValueError, match="empty or blank"): + validate_model_identifier("ollama:") + + def test_validate_model_identifier_rejects_missing_colon(self): + """An identifier without a provider prefix is rejected.""" + with pytest.raises(ValueError, match="provider:model-name"): + validate_model_identifier("llama3.1") + + def test_validate_model_identifier_rejects_unknown_provider(self): + """An unknown provider is rejected.""" + with pytest.raises(ValueError, match="Unknown provider"): + validate_model_identifier("pinecone:model") + + +class TestAIModelConfigUpdate: + """Tests for the PATCH /config/ai request body.""" + + def test_accepts_ollama_model(self): + """The update body accepts an ollama agent model.""" + upd = AIModelConfigUpdate(agent_default_model="ollama:llama3.1") + assert upd.agent_default_model == "ollama:llama3.1" + + def test_rejects_blank_ollama_model(self): + """A blank ollama model identifier fails validation.""" + with pytest.raises(ValidationError): + AIModelConfigUpdate(agent_default_model="ollama:") + + def test_rejects_unknown_provider(self): + """An unknown provider on the fallback model fails validation.""" + with pytest.raises(ValidationError): + AIModelConfigUpdate(agent_fallback_model="bad:model") + + def test_force_defaults_false(self): + """The dimension-guard bypass flag defaults to False.""" + assert AIModelConfigUpdate().force is False + + def test_temperature_range_enforced(self): + """Temperature above the allowed range fails validation.""" + with pytest.raises(ValidationError): + AIModelConfigUpdate(agent_temperature=3.0) + + def test_embedding_provider_literal_enforced(self): + """Only 'openai' and 'ollama' are valid embedding providers.""" + with pytest.raises(ValidationError): + AIModelConfigUpdate.model_validate({"rag_embedding_provider": "pinecone"}) + + def test_all_fields_optional(self): + """An empty update body is structurally valid (rejected later in service).""" + upd = AIModelConfigUpdate() + assert upd.agent_default_model is None + assert upd.rag_embedding_dimension is None + + def test_model_validate_json_path(self): + """Exercise FastAPI's validate_python path on the strict request body. + + Mirrors the strict-mode policy requirement: every strict request body + gets at least one model_validate({...}) case. + """ + upd = AIModelConfigUpdate.model_validate( + {"agent_temperature": 0.5, "agent_default_model": "ollama:llama3.1"} + ) + assert upd.agent_temperature == 0.5 + assert upd.agent_default_model == "ollama:llama3.1" + + +class TestResponseSchemas: + """Tests for the response-only schemas.""" + + def test_ai_model_config_constructs(self): + """AIModelConfig accepts a full effective-config payload.""" + cfg = AIModelConfig( + agent_default_model="anthropic:claude-sonnet-4-5", + agent_fallback_model="openai:gpt-4o", + agent_temperature=0.1, + agent_max_tokens=4096, + agent_thinking_budget=None, + rag_embedding_provider="openai", + rag_embedding_model="text-embedding-3-small", + rag_embedding_dimension=1536, + ollama_base_url="http://localhost:11434", + ollama_embedding_model="nomic-embed-text", + api_keys=[], + overridden_keys=[], + ) + assert cfg.agent_thinking_budget is None + + def test_api_key_status(self): + """ApiKeyStatus carries presence + a masked preview.""" + status = ApiKeyStatus(provider="anthropic", is_set=True, masked="sk-ant-…3f9a") + assert status.is_set is True + assert status.masked == "sk-ant-…3f9a" + + def test_ollama_model_optional_fields(self): + """OllamaModel size/family default to None.""" + model = OllamaModel(name="llama3.1:latest") + assert model.size_bytes is None + assert model.family is None + + def test_provider_health_defaults_models_empty(self): + """ProviderHealth.models defaults to an empty list.""" + health = ProviderHealth(provider="openai", reachable=True, detail="ok") + assert health.models == [] diff --git a/app/features/config/tests/test_service.py b/app/features/config/tests/test_service.py new file mode 100644 index 00000000..edaf5866 --- /dev/null +++ b/app/features/config/tests/test_service.py @@ -0,0 +1,280 @@ +"""Tests for the config service layer. + +Unit tests mock the DB session and httpx; integration tests (marked +``integration``) run against a real Postgres via the ``db_session`` fixture. +""" + +from collections.abc import Iterator +from contextlib import contextmanager +from typing import Any +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest +from fastapi import HTTPException + +from app.core.config import get_settings +from app.features.config import service +from app.features.config.schemas import AIModelConfigUpdate, OllamaModel + + +def _mock_db(chunk_count: int = 0, override_rows: list[Any] | None = None) -> MagicMock: + """Build an AsyncSession mock covering execute()/commit() for the service.""" + result = MagicMock() + result.scalars.return_value.all.return_value = override_rows or [] + result.scalar_one.return_value = chunk_count + db = MagicMock() + db.execute = AsyncMock(return_value=result) + db.commit = AsyncMock() + return db + + +@contextmanager +def _patch_ollama_get( + json_payload: dict[str, Any] | None = None, + side_effect: Exception | None = None, +) -> Iterator[None]: + """Patch httpx.AsyncClient so /api/tags calls are served from a fixture.""" + mock_response = MagicMock() + mock_response.json.return_value = json_payload or {} + mock_response.raise_for_status = MagicMock() + + mock_client = MagicMock() + if side_effect is not None: + mock_client.get = AsyncMock(side_effect=side_effect) + else: + mock_client.get = AsyncMock(return_value=mock_response) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + + with patch("app.features.config.service.httpx.AsyncClient", return_value=mock_client): + yield + + +# ============================================================================= +# Unit tests — masking +# ============================================================================= + + +class TestMaskSecret: + """Tests for the secret-masking helper.""" + + def test_empty_value_returns_none(self): + """An unset key masks to None.""" + assert service._mask_secret("") is None + + def test_short_value_masked(self): + """A short key shows only the trailing characters.""" + assert service._mask_secret("abcd1234") == "…1234" + + def test_long_value_never_leaks_raw(self): + """A long key is masked and never contains its raw value.""" + raw = "sk-ant-api03-supersecretvalue9999" + masked = service._mask_secret(raw) + assert masked is not None + assert raw not in masked + assert masked.endswith("9999") + + +# ============================================================================= +# Unit tests — get_effective_config +# ============================================================================= + + +class TestGetEffectiveConfig: + """Tests for get_effective_config.""" + + @pytest.mark.asyncio + async def test_get_effective_config_masks_secrets(self): + """API keys appear masked — the raw value never reaches the response.""" + settings = get_settings() + settings.anthropic_api_key = "sk-ant-supersecretvalue-0001" + + config = await service.get_effective_config(_mock_db()) + + anthropic = next(k for k in config.api_keys if k.provider == "anthropic") + assert anthropic.is_set is True + assert anthropic.masked is not None + assert "supersecretvalue" not in config.model_dump_json() + + +# ============================================================================= +# Unit tests — update_config +# ============================================================================= + + +class TestUpdateConfig: + """Tests for update_config.""" + + @pytest.mark.asyncio + async def test_update_config_resets_caches(self): + """A successful update nulls the agent + embedding singletons.""" + from app.features.agents.agents import experiment, rag_assistant + from app.features.rag import embeddings + + # Seed non-None sentinels so we can prove they were cleared. + experiment._experiment_agent = MagicMock() + rag_assistant._rag_assistant_agent = MagicMock() + embeddings._embedding_provider = MagicMock() + + db = _mock_db() + result = await service.update_config(db, AIModelConfigUpdate(agent_temperature=0.42)) + + assert experiment._experiment_agent is None + assert rag_assistant._rag_assistant_agent is None + assert embeddings._embedding_provider is None + assert result.agent_temperature == 0.42 + db.commit.assert_awaited() + + @pytest.mark.asyncio + async def test_update_config_empty_payload_rejected(self): + """An update with no fields is a 400.""" + with pytest.raises(HTTPException) as exc: + await service.update_config(_mock_db(), AIModelConfigUpdate()) + assert exc.value.status_code == 400 + + @pytest.mark.asyncio + async def test_update_config_dimension_guard(self): + """Changing the embedding dimension with chunks present is a 409.""" + db = _mock_db(chunk_count=5) + with pytest.raises(HTTPException) as exc: + await service.update_config(db, AIModelConfigUpdate(rag_embedding_dimension=768)) + assert exc.value.status_code == 409 + + @pytest.mark.asyncio + async def test_update_config_dimension_guard_bypassed_by_force(self): + """force=true allows the embedding-dimension change despite chunks.""" + db = _mock_db(chunk_count=5) + result = await service.update_config( + db, AIModelConfigUpdate(rag_embedding_dimension=768, force=True) + ) + assert result.rag_embedding_dimension == 768 + + @pytest.mark.asyncio + async def test_update_config_secret_mirrored_to_environment(self): + """Saving an API key mirrors it into os.environ unconditionally.""" + import os + + db = _mock_db() + await service.update_config(db, AIModelConfigUpdate(openai_api_key="sk-new-openai-key-123")) + assert os.environ["OPENAI_API_KEY"] == "sk-new-openai-key-123" + + +# ============================================================================= +# Unit tests — provider health + ollama models +# ============================================================================= + + +class TestProviderHealth: + """Tests for get_provider_health.""" + + @pytest.mark.asyncio + async def test_health_reports_ollama_reachable(self): + """A reachable Ollama host reports its pulled models.""" + with patch( + "app.features.config.service._fetch_ollama_models", + new=AsyncMock(return_value=[OllamaModel(name="llama3.1:latest")]), + ): + health = await service.get_provider_health() + ollama = next(h for h in health if h.provider == "ollama") + assert ollama.reachable is True + assert "llama3.1:latest" in ollama.models + + @pytest.mark.asyncio + async def test_health_reports_ollama_unreachable(self): + """An unreachable Ollama host reports reachable=False.""" + with patch( + "app.features.config.service._fetch_ollama_models", + new=AsyncMock(side_effect=httpx.ConnectError("refused")), + ): + health = await service.get_provider_health() + ollama = next(h for h in health if h.provider == "ollama") + assert ollama.reachable is False + + @pytest.mark.asyncio + async def test_health_reports_cloud_key_presence(self): + """Cloud providers report reachable iff a key is configured.""" + settings = get_settings() + settings.openai_api_key = "sk-test-openai" + settings.anthropic_api_key = "" + + with patch( + "app.features.config.service._fetch_ollama_models", + new=AsyncMock(side_effect=httpx.ConnectError("x")), + ): + health = await service.get_provider_health() + + openai = next(h for h in health if h.provider == "openai") + anthropic = next(h for h in health if h.provider == "anthropic") + assert openai.reachable is True + assert anthropic.reachable is False + + +class TestListOllamaModels: + """Tests for list_ollama_models.""" + + @pytest.mark.asyncio + async def test_list_ollama_models_parses_tags(self): + """A /api/tags response is parsed into OllamaModel objects.""" + payload = { + "models": [ + { + "name": "llama3.1:latest", + "model": "llama3.1", + "size": 4661224676, + "details": {"family": "llama"}, + }, + {"name": "nomic-embed-text:latest", "size": 274302450, "details": {}}, + ] + } + with _patch_ollama_get(json_payload=payload): + models = await service.list_ollama_models() + + assert [m.name for m in models] == [ + "llama3.1:latest", + "nomic-embed-text:latest", + ] + assert models[0].family == "llama" + assert models[0].size_bytes == 4661224676 + + @pytest.mark.asyncio + async def test_list_ollama_models_unreachable_raises_502(self): + """An unreachable Ollama host surfaces as a 502.""" + with _patch_ollama_get(side_effect=httpx.ConnectError("refused")): + with pytest.raises(HTTPException) as exc: + await service.list_ollama_models() + assert exc.value.status_code == 502 + + +# ============================================================================= +# Integration tests — real Postgres round-trips +# ============================================================================= + + +@pytest.mark.integration +class TestConfigServiceIntegration: + """Integration tests requiring a real database.""" + + @pytest.mark.asyncio + async def test_update_and_read_round_trip(self, db_session): + """An override persists and is reported by get_effective_config.""" + await service.update_config(db_session, AIModelConfigUpdate(agent_temperature=0.77)) + config = await service.get_effective_config(db_session) + assert config.agent_temperature == 0.77 + assert "agent_temperature" in config.overridden_keys + + @pytest.mark.asyncio + async def test_apply_overrides_on_startup_reapplies(self, db_session): + """Persisted overrides are re-applied onto a fresh Settings singleton.""" + await service.update_config(db_session, AIModelConfigUpdate(agent_max_tokens=2048)) + # Simulate a process restart. + get_settings.cache_clear() + await service.apply_overrides_on_startup(db_session) + assert get_settings().agent_max_tokens == 2048 + + @pytest.mark.asyncio + async def test_apply_overrides_on_startup_empty_is_noop(self, db_session): + """With no persisted overrides, startup application is a clean no-op.""" + await service.apply_overrides_on_startup(db_session) + config = await service.get_effective_config(db_session) + assert config.overridden_keys == [] diff --git a/app/main.py b/app/main.py index 9cca6275..f473127d 100644 --- a/app/main.py +++ b/app/main.py @@ -7,6 +7,7 @@ from fastapi.middleware.cors import CORSMiddleware from app.core.config import get_settings +from app.core.database import get_session_maker from app.core.exceptions import register_exception_handlers from app.core.health import router as health_router from app.core.logging import configure_logging, get_logger @@ -15,6 +16,8 @@ from app.features.agents.websocket import router as agents_ws_router from app.features.analytics.routes import router as analytics_router from app.features.backtesting.routes import router as backtesting_router +from app.features.config.routes import router as config_router +from app.features.config.service import apply_overrides_on_startup from app.features.demo.routes import router as demo_router from app.features.dimensions.routes import router as dimensions_router from app.features.featuresets.routes import router as featuresets_router @@ -49,6 +52,19 @@ async def lifespan(_app: FastAPI) -> AsyncGenerator[None, None]: debug=settings.debug, ) + # Re-apply persisted runtime config overrides onto the Settings singleton. + # Warn-and-continue: a missing app_config table must never block startup. + try: + session_maker = get_session_maker() + async with session_maker() as db: + await apply_overrides_on_startup(db) + except Exception as exc: # config must never block startup + logger.warning( + "config.overrides_skipped", + error=str(exc), + error_type=type(exc).__name__, + ) + yield # Shutdown @@ -126,6 +142,7 @@ def create_app() -> FastAPI: app.include_router(agents_ws_router) app.include_router(seeder_router) app.include_router(demo_router) + app.include_router(config_router) return app diff --git a/docs/_base/API_CONTRACTS.md b/docs/_base/API_CONTRACTS.md index ec9c8907..b8eea632 100644 --- a/docs/_base/API_CONTRACTS.md +++ b/docs/_base/API_CONTRACTS.md @@ -47,6 +47,10 @@ All endpoints serve JSON; error responses use `application/problem+json` (RFC 78 | seeder | (see `app/features/seeder/routes.py`) | `/seeder/*` | Trigger scenarios, status, customization | | demo | POST | `/demo/run` | Run the end-to-end demo pipeline in-process; returns a `DemoRunResult`. `409 application/problem+json` if a run is already active | | demo | WS | `/demo/stream` | Stream one `StepEvent` per pipeline step for the live Showcase page | +| config | GET | `/config/ai` | Effective AI-model config (agent LLM + RAG embeddings); API keys masked, never raw | +| config | PATCH | `/config/ai` | Persist + apply AI-model changes live (no restart). `409` if an embedding-dimension change would orphan indexed RAG chunks (resend with `force=true`) | +| config | GET | `/config/providers/health` | Per-provider connectivity — Ollama probed live, cloud providers by API-key presence | +| config | GET | `/config/ollama/models` | Models pulled on the configured Ollama host. `502` if the host is unreachable | ## WebSocket Events (`/agents/stream`) diff --git a/docs/rag-ollama-setup.md b/docs/rag-ollama-setup.md index e6a10840..e889df07 100644 --- a/docs/rag-ollama-setup.md +++ b/docs/rag-ollama-setup.md @@ -32,6 +32,11 @@ **HU:** Ez az útmutató dokumentálja, hogyan kell konfigurálni a ForecastLabAI RAG (Retrieval-Augmented Generation) rendszert, hogy Ollama-t használjon embedding provider-ként az OpenAI helyett. Ez lehetővé teszi a teljesen lokális/LAN-alapú embedding generálást külső API függőségek nélkül. +> **Note / Megjegyzés:** Ollama can now also back the **chat agent**, not just RAG +> embeddings. Set `AGENT_DEFAULT_MODEL=ollama:` (e.g. `ollama:llama3.1`), +> or switch it live from the `/admin` → **AI Models** tab — no `.env` edit or +> restart needed. See `PRPs/PRP-18-ai-model-admin-console.md`. + ### Architecture / Architektúra ``` diff --git a/frontend/src/components/admin/ai-models-panel.tsx b/frontend/src/components/admin/ai-models-panel.tsx new file mode 100644 index 00000000..910bcf2c --- /dev/null +++ b/frontend/src/components/admin/ai-models-panel.tsx @@ -0,0 +1,444 @@ +import { useState } from 'react' +import { Loader2, RefreshCw, Save, Cpu, Database, KeyRound, Activity } from 'lucide-react' +import { + useAIConfig, + useProviderHealth, + useOllamaModels, + useUpdateAIConfig, +} from '@/hooks/use-config' +import { ErrorDisplay } from '@/components/common/error-display' +import { LoadingState } from '@/components/common/loading-state' +import { Button } from '@/components/ui/button' +import { Card, CardContent, CardDescription, CardHeader, CardTitle } from '@/components/ui/card' +import { Input } from '@/components/ui/input' +import { + Select, + SelectContent, + SelectItem, + SelectTrigger, + SelectValue, +} from '@/components/ui/select' +import { Checkbox } from '@/components/ui/checkbox' +import { Badge } from '@/components/ui/badge' +import { toast } from 'sonner' +import type { AIModelConfig, AIModelConfigUpdate } from '@/types/api' + +const AGENT_PROVIDERS = ['anthropic', 'openai', 'google-gla', 'google-vertex', 'ollama'] as const +const EMBEDDING_PROVIDERS = ['openai', 'ollama'] as const + +interface FormState { + agentProvider: string + agentModel: string + agentFallback: string + agentTemperature: string + agentMaxTokens: string + agentThinkingBudget: string + ragProvider: string + ragModel: string + ragDimension: string + ollamaBaseUrl: string + ollamaEmbeddingModel: string +} + +function deriveForm(cfg: AIModelConfig): FormState { + const [agentProvider, ...rest] = cfg.agent_default_model.split(':') + return { + agentProvider, + agentModel: rest.join(':'), + agentFallback: cfg.agent_fallback_model, + agentTemperature: String(cfg.agent_temperature), + agentMaxTokens: String(cfg.agent_max_tokens), + agentThinkingBudget: cfg.agent_thinking_budget == null ? '' : String(cfg.agent_thinking_budget), + ragProvider: cfg.rag_embedding_provider, + ragModel: cfg.rag_embedding_model, + ragDimension: String(cfg.rag_embedding_dimension), + ollamaBaseUrl: cfg.ollama_base_url, + ollamaEmbeddingModel: cfg.ollama_embedding_model, + } +} + +function Field({ label, children }: { label: string; children: React.ReactNode }) { + return ( +
+ + {children} +
+ ) +} + +export function AIModelsPanel() { + const { data: config, isLoading, error, refetch } = useAIConfig() + const updateConfig = useUpdateAIConfig() + + // `form` holds operator edits; it is null until the first edit. The + // displayed form is `form ?? deriveForm(config)` — no state-seeding effect. + const [form, setForm] = useState(null) + const [keys, setKeys] = useState({ openai: '', anthropic: '', google: '' }) + const [forceDimension, setForceDimension] = useState(false) + + const agentProvider = + form?.agentProvider ?? config?.agent_default_model.split(':')[0] ?? '' + const ollamaModels = useOllamaModels(agentProvider === 'ollama') + + if (error) return + if (isLoading || !config) return + + const f = form ?? deriveForm(config) + const update = (patch: Partial) => + setForm((prev) => ({ ...(prev ?? deriveForm(config)), ...patch })) + + const save = async (body: AIModelConfigUpdate, label: string) => { + try { + await updateConfig.mutateAsync(body) + toast.success(`${label} saved — applied live, no restart needed`) + } catch (err) { + toast.error(err instanceof Error ? err.message : `${label} update failed`) + } + } + + const saveAgent = () => + save( + { + agent_default_model: `${f.agentProvider}:${f.agentModel.trim()}`, + agent_fallback_model: f.agentFallback.trim(), + agent_temperature: Number(f.agentTemperature), + agent_max_tokens: Number(f.agentMaxTokens), + ...(f.agentThinkingBudget.trim() + ? { agent_thinking_budget: Number(f.agentThinkingBudget) } + : {}), + }, + 'Agent LLM' + ) + + const saveEmbeddings = () => + save( + { + rag_embedding_provider: f.ragProvider as 'openai' | 'ollama', + rag_embedding_model: f.ragModel.trim(), + rag_embedding_dimension: Number(f.ragDimension), + ollama_base_url: f.ollamaBaseUrl.trim(), + ollama_embedding_model: f.ollamaEmbeddingModel.trim(), + force: forceDimension, + }, + 'RAG embeddings' + ) + + const saveKeys = async () => { + const body: AIModelConfigUpdate = {} + if (keys.openai.trim()) body.openai_api_key = keys.openai.trim() + if (keys.anthropic.trim()) body.anthropic_api_key = keys.anthropic.trim() + if (keys.google.trim()) body.google_api_key = keys.google.trim() + if (Object.keys(body).length === 0) { + toast.warning('Enter at least one API key to save') + return + } + await save(body, 'API keys') + setKeys({ openai: '', anthropic: '', google: '' }) + } + + const busy = updateConfig.isPending + + return ( +
+ {/* Agent LLM */} + + + + Agent LLM + + + The model backing the chat agent. Pick ollama to run fully local. + + + +
+ + + + + + {f.agentProvider === 'ollama' ? ( + + ) : ( + update({ agentModel: e.target.value })} + placeholder="claude-sonnet-4-5" + /> + )} + + + + update({ agentFallback: e.target.value })} + placeholder="openai:gpt-4o" + /> + + + + update({ agentTemperature: e.target.value })} + /> + + + + update({ agentMaxTokens: e.target.value })} + /> + + + + update({ agentThinkingBudget: e.target.value })} + /> + +
+
+ +
+
+
+ + {/* RAG Embeddings */} + + + + RAG Embeddings + + + Embedding provider for the knowledge base. Changing the dimension with + indexed chunks requires a re-index. + + + +
+ + + + + + update({ ragModel: e.target.value })} + placeholder="text-embedding-3-small" + /> + + + + update({ ragDimension: e.target.value })} + /> + + + + update({ ollamaEmbeddingModel: e.target.value })} + placeholder="nomic-embed-text" + /> + + + + update({ ollamaBaseUrl: e.target.value })} + placeholder="http://localhost:11434" + /> + +
+
+ + +
+
+
+ + {/* API Keys */} + + + + Provider API Keys + + + Set or replace cloud provider keys. Stored values are never displayed — + only a masked preview. Ollama needs no key. + + + + {(['openai', 'anthropic', 'google'] as const).map((provider) => { + const status = config.api_keys.find((k) => k.provider === provider) + return ( + +
+ + setKeys((k) => ({ ...k, [provider]: e.target.value })) + } + /> + + {status?.is_set ? 'Set' : 'Unset'} + +
+
+ ) + })} +
+ +
+
+
+ + {/* Provider Health */} + +
+ ) +} + +function ProviderHealthCard() { + const { data: health, isLoading, error, refetch, isFetching } = useProviderHealth() + + return ( + + +
+ + Provider Health + + + Ollama is probed live; cloud providers report API-key presence. + +
+ +
+ + {error ? ( + + ) : isLoading ? ( + + ) : ( +
+ {health?.map((h) => ( +
+
+

{h.provider}

+

+ {h.detail} + {h.models.length > 0 && ` • models: ${h.models.join(', ')}`} +

+
+ + {h.reachable ? 'Reachable' : 'Unreachable'} + +
+ ))} +
+ )} +
+
+ ) +} diff --git a/frontend/src/components/admin/index.ts b/frontend/src/components/admin/index.ts new file mode 100644 index 00000000..b1afa9ef --- /dev/null +++ b/frontend/src/components/admin/index.ts @@ -0,0 +1 @@ +export { AIModelsPanel } from './ai-models-panel' diff --git a/frontend/src/hooks/use-config.ts b/frontend/src/hooks/use-config.ts new file mode 100644 index 00000000..2d2a32df --- /dev/null +++ b/frontend/src/hooks/use-config.ts @@ -0,0 +1,48 @@ +import { useQuery, useMutation, useQueryClient } from '@tanstack/react-query' +import { api } from '@/lib/api' +import type { + AIModelConfig, + AIModelConfigUpdate, + OllamaModel, + ProviderHealth, +} from '@/types/api' + +// Query: effective AI-model configuration (agent LLM + RAG embeddings). +export function useAIConfig() { + return useQuery({ + queryKey: ['config', 'ai'], + queryFn: () => api('/config/ai'), + }) +} + +// Query: per-provider connectivity (Ollama probed live, cloud keys by presence). +export function useProviderHealth() { + return useQuery({ + queryKey: ['config', 'health'], + queryFn: () => api('/config/providers/health'), + }) +} + +// Query: models pulled on the Ollama host. Opt-in via `enabled` so it only +// runs when the operator actually needs the picker (e.g. provider === ollama). +export function useOllamaModels(enabled: boolean) { + return useQuery({ + queryKey: ['config', 'ollama-models'], + queryFn: () => api('/config/ollama/models'), + enabled, + retry: false, + }) +} + +// Mutation: persist + apply an AI-model configuration change (no restart). +export function useUpdateAIConfig() { + const queryClient = useQueryClient() + return useMutation({ + mutationFn: (body: AIModelConfigUpdate) => + api('/config/ai', { method: 'PATCH', body }), + onSuccess: () => { + // Refresh every config-derived view (effective config + provider health). + void queryClient.invalidateQueries({ queryKey: ['config'] }) + }, + }) +} diff --git a/frontend/src/pages/admin.tsx b/frontend/src/pages/admin.tsx index e3087f57..b933238c 100644 --- a/frontend/src/pages/admin.tsx +++ b/frontend/src/pages/admin.tsx @@ -17,6 +17,7 @@ import { Warehouse, History, Percent, + Bot, } from 'lucide-react' import { useRagSources, useDeleteRagSource, useIndexDocument } from '@/hooks/use-rag-sources' import { useAliases, useDeleteAlias, useCreateAlias } from '@/hooks/use-runs' @@ -63,6 +64,7 @@ import { import { Tabs, TabsContent, TabsList, TabsTrigger } from '@/components/ui/tabs' import { Badge } from '@/components/ui/badge' import { Skeleton } from '@/components/ui/skeleton' +import { AIModelsPanel } from '@/components/admin' import { toast } from 'sonner' import type { ScenarioInfo, VerifyCheck, VerifyCheckStatus } from '@/types/api' @@ -85,6 +87,10 @@ export default function AdminPage() { Data Seeder
+ + + AI Models + @@ -98,6 +104,10 @@ export default function AdminPage() { + + + + ) diff --git a/frontend/src/types/api.ts b/frontend/src/types/api.ts index 55859a9c..c356c6ab 100644 --- a/frontend/src/types/api.ts +++ b/frontend/src/types/api.ts @@ -356,3 +356,61 @@ export interface DemoRunResult { alias: string | null wall_clock_s: number } + +// === AI Model Configuration (/config) === + +// Presence + masked preview of one provider API key (never the raw value). +export interface ApiKeyStatus { + provider: string + is_set: boolean + masked: string | null +} + +// Effective AI-model configuration — GET /config/ai response. +export interface AIModelConfig { + agent_default_model: string + agent_fallback_model: string + agent_temperature: number + agent_max_tokens: number + agent_thinking_budget: number | null + rag_embedding_provider: string + rag_embedding_model: string + rag_embedding_dimension: number + ollama_base_url: string + ollama_embedding_model: string + api_keys: ApiKeyStatus[] + overridden_keys: string[] +} + +// Partial update for PATCH /config/ai — only non-null fields are applied. +export interface AIModelConfigUpdate { + agent_default_model?: string + agent_fallback_model?: string + agent_temperature?: number + agent_max_tokens?: number + agent_thinking_budget?: number | null + rag_embedding_provider?: 'openai' | 'ollama' + rag_embedding_model?: string + rag_embedding_dimension?: number + ollama_base_url?: string + ollama_embedding_model?: string + openai_api_key?: string + anthropic_api_key?: string + google_api_key?: string + force?: boolean +} + +// One model pulled on the Ollama host. +export interface OllamaModel { + name: string + size_bytes: number | null + family: string | null +} + +// Connectivity status for one AI provider — GET /config/providers/health. +export interface ProviderHealth { + provider: string + reachable: boolean + detail: string + models: string[] +}