From 74097cfc7831f28af7c3272d7389cb62d3290756 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 17 Feb 2026 09:02:51 -0800 Subject: [PATCH 01/49] Unreviewed agent output: make chat and embed interfaces provider-agnostic using pydantic_ai --- src/typeagent/aitools/embeddings.py | 60 +++++++++++++++++++ src/typeagent/aitools/utils.py | 60 +++++++++++++++++++ src/typeagent/aitools/vectorbase.py | 17 ++++-- src/typeagent/knowpro/convknowledge.py | 53 +--------------- src/typeagent/knowpro/convsettings.py | 6 +- src/typeagent/mcp/server.py | 2 +- src/typeagent/storage/sqlite/provider.py | 4 +- tests/conftest.py | 12 ++-- tests/test_conversation_metadata.py | 26 ++++---- tests/test_demo.py | 4 +- .../test_message_text_index_serialization.py | 6 +- tests/test_podcasts.py | 4 +- tests/test_reltermsindex.py | 4 +- tests/test_semrefindex.py | 6 +- tests/test_sqlite_indexes.py | 4 +- tests/test_sqlitestore.py | 4 +- tests/test_storage_providers_unified.py | 8 +-- tests/test_transcripts.py | 6 +- tests/test_vectorbase.py | 8 +-- tools/ingest_vtt.py | 4 +- tools/query.py | 2 +- 21 files changed, 193 insertions(+), 107 deletions(-) diff --git a/src/typeagent/aitools/embeddings.py b/src/typeagent/aitools/embeddings.py index 819993f4..c3403af2 100644 --- a/src/typeagent/aitools/embeddings.py +++ b/src/typeagent/aitools/embeddings.py @@ -3,6 +3,7 @@ import asyncio import os +from typing import Protocol, runtime_checkable import numpy as np from numpy.typing import NDArray @@ -20,6 +21,39 @@ type NormalizedEmbeddings = NDArray[np.float32] # An array of embeddings +@runtime_checkable +class IEmbeddingModel(Protocol): + """Provider-agnostic interface for embedding models. + + Implement this protocol to add support for a new embedding provider + (e.g. Anthropic, Gemini, local models). The existing AsyncEmbeddingModel + implements it for OpenAI and Azure OpenAI. + """ + + model_name: str + embedding_size: int + + def add_embedding(self, key: str, embedding: NormalizedEmbedding) -> None: + """Cache an already-computed embedding under the given key.""" + ... + + async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding: + """Compute a single embedding without caching.""" + ... + + async def get_embeddings_nocache(self, input: list[str]) -> NormalizedEmbeddings: + """Compute embeddings for a batch of strings without caching.""" + ... + + async def get_embedding(self, key: str) -> NormalizedEmbedding: + """Retrieve a single embedding, using cache if available.""" + ... + + async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings: + """Retrieve embeddings for multiple keys, using cache if available.""" + ... + + DEFAULT_MODEL_NAME = "text-embedding-ada-002" DEFAULT_EMBEDDING_SIZE = 1536 # Default embedding size (required for ada-002) DEFAULT_ENVVAR = "AZURE_OPENAI_ENDPOINT_EMBEDDING" # We support OpenAI and Azure OpenAI @@ -311,3 +345,29 @@ async def truncate_input(self, input: str) -> tuple[str, int]: return self.encoding.decode(truncated_tokens), self.max_chunk_size else: return input, len(tokens) + + +def create_embedding_model( + embedding_size: int | None = None, + model_name: str | None = None, + **kwargs, +) -> IEmbeddingModel: + """Create an embedding model using OpenAI/Azure OpenAI. + + This is the default factory. To use a different provider, create an + instance of a class that implements ``IEmbeddingModel`` and pass it + directly to ``TextEmbeddingIndexSettings`` or ``ConversationSettings``. 
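+
+    A minimal usage sketch (assumes OpenAI or Azure OpenAI credentials are
+    already configured in the environment)::
+
+        model = create_embedding_model(model_name="text-embedding-ada-002")
+        settings = TextEmbeddingIndexSettings(embedding_model=model)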
+ + Args: + embedding_size: Requested embedding dimensionality (provider-specific). + model_name: Model identifier (e.g. "text-embedding-ada-002"). + **kwargs: Extra keyword arguments forwarded to ``AsyncEmbeddingModel``. + + Returns: + An ``IEmbeddingModel`` instance backed by OpenAI / Azure OpenAI. + """ + return AsyncEmbeddingModel( + embedding_size=embedding_size, + model_name=model_name, + **kwargs, + ) diff --git a/src/typeagent/aitools/utils.py b/src/typeagent/aitools/utils.py index b9150334..2459473c 100644 --- a/src/typeagent/aitools/utils.py +++ b/src/typeagent/aitools/utils.py @@ -3,6 +3,7 @@ """Utilities that are hard to fit in any specific module.""" +import asyncio from contextlib import contextmanager import difflib import os @@ -16,6 +17,8 @@ import typechat +from .auth import AzureTokenProvider, get_shared_token_provider + @contextmanager def timelog(label: str, verbose: bool = True): @@ -87,6 +90,63 @@ def create_translator[T]( return translator +# TODO: Make these parameters that can be configured (e.g. from command line). +DEFAULT_MAX_RETRY_ATTEMPTS = 0 +DEFAULT_TIMEOUT_SECONDS = 25 + + +class ModelWrapper(typechat.TypeChatLanguageModel): + """Wraps a TypeChat model to handle Azure token refresh.""" + + def __init__( + self, + base_model: typechat.TypeChatLanguageModel, + token_provider: AzureTokenProvider, + ): + self.base_model = base_model + self.token_provider = token_provider + + async def complete( + self, prompt: str | list[typechat.PromptSection] + ) -> typechat.Result[str]: + if self.token_provider.needs_refresh(): + loop = asyncio.get_running_loop() + api_key = await loop.run_in_executor( + None, self.token_provider.refresh_token + ) + env: dict[str, str | None] = dict(os.environ) + key_name = "AZURE_OPENAI_API_KEY" + env[key_name] = api_key + self.base_model = typechat.create_language_model(env) + self.base_model.timeout_seconds = DEFAULT_TIMEOUT_SECONDS + return await self.base_model.complete(prompt) + + +def create_typechat_model() -> typechat.TypeChatLanguageModel: + """Create a TypeChat language model using OpenAI or Azure OpenAI. + + Reads ``OPENAI_API_KEY``, ``AZURE_OPENAI_API_KEY`` and related env vars. + Handles Azure ``identity`` token provider for Microsoft internal usage. + + To use a different provider (e.g. Anthropic, Gemini), implement + ``typechat.TypeChatLanguageModel`` directly and pass it to + ``KnowledgeExtractor`` or ``create_translator()``. 
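+
+    A sketch of such a custom model (``MyChatModel`` and ``call_my_provider``
+    are hypothetical)::
+
+        class MyChatModel(typechat.TypeChatLanguageModel):
+            async def complete(self, prompt):
+                text = await call_my_provider(prompt)  # hypothetical helper
+                return typechat.Success(text)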
+ """ + env: dict[str, str | None] = dict(os.environ) + key_name = "AZURE_OPENAI_API_KEY" + key = env.get(key_name) + shared_token_provider: AzureTokenProvider | None = None + if key is not None and key.lower() == "identity": + shared_token_provider = get_shared_token_provider() + env[key_name] = shared_token_provider.get_token() + model = typechat.create_language_model(env) + model.timeout_seconds = DEFAULT_TIMEOUT_SECONDS + model.max_retry_attempts = DEFAULT_MAX_RETRY_ATTEMPTS + if shared_token_provider is not None: + model = ModelWrapper(model, shared_token_provider) + return model + + # Vibe-coded by o4-mini-high def list_diff(label_a, a, label_b, b, max_items): """Print colorized diff between two sorted list of numbers.""" diff --git a/src/typeagent/aitools/vectorbase.py b/src/typeagent/aitools/vectorbase.py index 3bbc5729..0a52f838 100644 --- a/src/typeagent/aitools/vectorbase.py +++ b/src/typeagent/aitools/vectorbase.py @@ -6,9 +6,14 @@ import numpy as np -from openai import DEFAULT_MAX_RETRIES +from .embeddings import ( + create_embedding_model, + IEmbeddingModel, + NormalizedEmbedding, + NormalizedEmbeddings, +) -from .embeddings import AsyncEmbeddingModel, NormalizedEmbedding, NormalizedEmbeddings +DEFAULT_MAX_RETRIES = 2 @dataclass @@ -19,7 +24,7 @@ class ScoredInt: @dataclass class TextEmbeddingIndexSettings: - embedding_model: AsyncEmbeddingModel + embedding_model: IEmbeddingModel embedding_size: int # Set to embedding_model.embedding_size min_score: float # Between 0.0 and 1.0 max_matches: int | None # >= 1; None means no limit @@ -28,7 +33,7 @@ class TextEmbeddingIndexSettings: def __init__( self, - embedding_model: AsyncEmbeddingModel | None = None, + embedding_model: IEmbeddingModel | None = None, embedding_size: int | None = None, min_score: float | None = None, max_matches: int | None = None, @@ -41,7 +46,7 @@ def __init__( self.max_retries = ( max_retries if max_retries is not None else DEFAULT_MAX_RETRIES ) - self.embedding_model = embedding_model or AsyncEmbeddingModel( + self.embedding_model = embedding_model or create_embedding_model( embedding_size, max_retries=self.max_retries ) self.embedding_size = self.embedding_model.embedding_size @@ -53,7 +58,7 @@ def __init__( class VectorBase: settings: TextEmbeddingIndexSettings _vectors: NormalizedEmbeddings - _model: AsyncEmbeddingModel + _model: IEmbeddingModel _embedding_size: int def __init__(self, settings: TextEmbeddingIndexSettings): diff --git a/src/typeagent/knowpro/convknowledge.py b/src/typeagent/knowpro/convknowledge.py index 4bea97d4..53a10b25 100644 --- a/src/typeagent/knowpro/convknowledge.py +++ b/src/typeagent/knowpro/convknowledge.py @@ -1,62 +1,15 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import asyncio from dataclasses import dataclass, field -import os import typechat from . import kplib -from ..aitools import auth +from ..aitools.utils import create_typechat_model # Re-export for backward compat -# TODO: Move ModelWrapper and create_typechat_model() to aitools package. - - -# TODO: Make these parameters that can be configured (e.g. from command line). 
-DEFAULT_MAX_RETRY_ATTEMPTS = 0 -DEFAULT_TIMEOUT_SECONDS = 25 - - -class ModelWrapper(typechat.TypeChatLanguageModel): - def __init__( - self, - base_model: typechat.TypeChatLanguageModel, - token_provider: auth.AzureTokenProvider, - ): - self.base_model = base_model - self.token_provider = token_provider - - async def complete( - self, prompt: str | list[typechat.PromptSection] - ) -> typechat.Result[str]: - if self.token_provider.needs_refresh(): - loop = asyncio.get_running_loop() - api_key = await loop.run_in_executor( - None, self.token_provider.refresh_token - ) - env: dict[str, str | None] = dict(os.environ) - key_name = "AZURE_OPENAI_API_KEY" - env[key_name] = api_key - self.base_model = typechat.create_language_model(env) - self.base_model.timeout_seconds = DEFAULT_TIMEOUT_SECONDS - return await self.base_model.complete(prompt) - - -def create_typechat_model() -> typechat.TypeChatLanguageModel: - env: dict[str, str | None] = dict(os.environ) - key_name = "AZURE_OPENAI_API_KEY" - key = env.get(key_name) - shared_token_provider: auth.AzureTokenProvider | None = None - if key is not None and key.lower() == "identity": - shared_token_provider = auth.get_shared_token_provider() - env[key_name] = shared_token_provider.get_token() - model = typechat.create_language_model(env) - model.timeout_seconds = DEFAULT_TIMEOUT_SECONDS - model.max_retry_attempts = DEFAULT_MAX_RETRY_ATTEMPTS - if shared_token_provider is not None: - model = ModelWrapper(model, shared_token_provider) - return model +# Re-export: callers may still do ``convknowledge.create_typechat_model()``. +__all__ = ["create_typechat_model", "KnowledgeExtractor"] @dataclass diff --git a/src/typeagent/knowpro/convsettings.py b/src/typeagent/knowpro/convsettings.py index 627546ed..7e25cd93 100644 --- a/src/typeagent/knowpro/convsettings.py +++ b/src/typeagent/knowpro/convsettings.py @@ -5,7 +5,7 @@ from dataclasses import dataclass -from ..aitools.embeddings import AsyncEmbeddingModel +from ..aitools.embeddings import create_embedding_model, IEmbeddingModel from ..aitools.vectorbase import TextEmbeddingIndexSettings from .interfaces import IKnowledgeExtractor, IStorageProvider @@ -38,11 +38,11 @@ class ConversationSettings: def __init__( self, - model: AsyncEmbeddingModel | None = None, + model: IEmbeddingModel | None = None, storage_provider: IStorageProvider | None = None, ): # All settings share the same model, so they share the embedding cache. 
- model = model or AsyncEmbeddingModel() + model = model or create_embedding_model() self.embedding_model = model min_score = 0.85 self.related_term_index_settings = RelatedTermIndexSettings( diff --git a/src/typeagent/mcp/server.py b/src/typeagent/mcp/server.py index 19919c92..dcd4a3cf 100644 --- a/src/typeagent/mcp/server.py +++ b/src/typeagent/mcp/server.py @@ -102,7 +102,7 @@ class ProcessingContext: query_context: query.QueryEvalContext[ podcast.PodcastMessage, TermToSemanticRefIndex ] - embedding_model: embeddings.AsyncEmbeddingModel + embedding_model: embeddings.IEmbeddingModel query_translator: typechat.TypeChatJsonTranslator[SearchQuery] answer_translator: typechat.TypeChatJsonTranslator[AnswerResponse] diff --git a/src/typeagent/storage/sqlite/provider.py b/src/typeagent/storage/sqlite/provider.py index 8fae1b2b..975d6a70 100644 --- a/src/typeagent/storage/sqlite/provider.py +++ b/src/typeagent/storage/sqlite/provider.py @@ -6,7 +6,7 @@ from datetime import datetime, timezone import sqlite3 -from ...aitools.embeddings import AsyncEmbeddingModel +from ...aitools.embeddings import create_embedding_model from ...aitools.vectorbase import TextEmbeddingIndexSettings from ...knowpro import interfaces from ...knowpro.convsettings import MessageTextIndexSettings, RelatedTermIndexSettings @@ -125,7 +125,7 @@ def _resolve_embedding_settings( if provided_message_settings is None: if stored_size is not None or stored_name is not None: - embedding_model = AsyncEmbeddingModel( + embedding_model = create_embedding_model( embedding_size=stored_size, model_name=stored_name, ) diff --git a/tests/conftest.py b/tests/conftest.py index 40aee885..7f7ce210 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -15,7 +15,11 @@ from openai.types.embedding import Embedding import tiktoken -from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME +from typeagent.aitools.embeddings import ( + AsyncEmbeddingModel, + IEmbeddingModel, + TEST_MODEL_NAME, +) from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings from typeagent.knowpro.convsettings import ( ConversationSettings, @@ -90,7 +94,7 @@ def really_needs_auth() -> None: @pytest.fixture(scope="session") -def embedding_model() -> AsyncEmbeddingModel: +def embedding_model() -> IEmbeddingModel: """Fixture to create a test embedding model with small embedding size for faster tests.""" return AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) @@ -130,7 +134,7 @@ def temp_db_path() -> Iterator[str]: @pytest.fixture def memory_storage( - embedding_model: AsyncEmbeddingModel, + embedding_model: IEmbeddingModel, ) -> MemoryStorageProvider: """Create a memory storage provider with settings.""" embedding_settings = TextEmbeddingIndexSettings(embedding_model=embedding_model) @@ -188,7 +192,7 @@ def get_text_location(self) -> TextLocation: @pytest_asyncio.fixture async def sqlite_storage( - temp_db_path: str, embedding_model: AsyncEmbeddingModel + temp_db_path: str, embedding_model: IEmbeddingModel ) -> AsyncGenerator[SqliteStorageProvider[FakeMessage], None]: """Create a SqliteStorageProvider for testing.""" embedding_settings = TextEmbeddingIndexSettings(embedding_model) diff --git a/tests/test_conversation_metadata.py b/tests/test_conversation_metadata.py index eadc125a..69d80dab 100644 --- a/tests/test_conversation_metadata.py +++ b/tests/test_conversation_metadata.py @@ -16,7 +16,11 @@ from pydantic.dataclasses import dataclass -from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME +from 
typeagent.aitools.embeddings import ( + AsyncEmbeddingModel, + IEmbeddingModel, + TEST_MODEL_NAME, +) from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings from typeagent.knowpro.convsettings import ( ConversationSettings, @@ -54,7 +58,7 @@ def get_knowledge(self) -> KnowledgeResponse: @pytest_asyncio.fixture async def storage_provider( - temp_db_path: str, embedding_model: AsyncEmbeddingModel + temp_db_path: str, embedding_model: IEmbeddingModel ) -> AsyncGenerator[SqliteStorageProvider[DummyMessage], None]: """Create a SqliteStorageProvider for testing conversation metadata.""" embedding_settings = TextEmbeddingIndexSettings(embedding_model) @@ -270,7 +274,7 @@ def test_get_db_version_with_metadata( @pytest.mark.asyncio async def test_multiple_conversations_different_dbs( - self, embedding_model: AsyncEmbeddingModel + self, embedding_model: IEmbeddingModel ): """Test multiple conversations in different database files.""" embedding_settings = TextEmbeddingIndexSettings(embedding_model) @@ -343,7 +347,7 @@ async def test_multiple_conversations_different_dbs( @pytest.mark.asyncio async def test_conversation_metadata_single_per_db( - self, temp_db_path: str, embedding_model: AsyncEmbeddingModel + self, temp_db_path: str, embedding_model: IEmbeddingModel ): """Test that only one conversation metadata can exist per database.""" embedding_settings = TextEmbeddingIndexSettings(embedding_model) @@ -418,7 +422,7 @@ async def test_conversation_metadata_with_special_characters( @pytest.mark.asyncio async def test_conversation_metadata_persistence( - self, temp_db_path: str, embedding_model: AsyncEmbeddingModel + self, temp_db_path: str, embedding_model: IEmbeddingModel ): """Test that conversation metadata persists across provider instances.""" embedding_settings = TextEmbeddingIndexSettings(embedding_model) @@ -486,7 +490,7 @@ async def test_empty_string_timestamps( @pytest.mark.asyncio async def test_very_long_name_tag( - self, temp_db_path: str, embedding_model: AsyncEmbeddingModel + self, temp_db_path: str, embedding_model: IEmbeddingModel ): """Test conversation metadata with very long name_tag.""" embedding_settings = TextEmbeddingIndexSettings(embedding_model) @@ -517,7 +521,7 @@ async def test_very_long_name_tag( @pytest.mark.asyncio async def test_unicode_name_tag( - self, temp_db_path: str, embedding_model: AsyncEmbeddingModel + self, temp_db_path: str, embedding_model: IEmbeddingModel ): """Test conversation metadata with Unicode name_tag.""" embedding_settings = TextEmbeddingIndexSettings(embedding_model) @@ -548,7 +552,7 @@ async def test_unicode_name_tag( @pytest.mark.asyncio async def test_conversation_metadata_shared_access( - self, temp_db_path: str, embedding_model: AsyncEmbeddingModel + self, temp_db_path: str, embedding_model: IEmbeddingModel ): """Test shared access to metadata using the same database file.""" embedding_settings = TextEmbeddingIndexSettings(embedding_model) @@ -599,7 +603,7 @@ async def test_conversation_metadata_shared_access( @pytest.mark.asyncio async def test_embedding_metadata_mismatch_raises( - self, temp_db_path: str, embedding_model: AsyncEmbeddingModel + self, temp_db_path: str, embedding_model: IEmbeddingModel ): """Ensure a mismatch between stored metadata and provided settings raises.""" embedding_settings = TextEmbeddingIndexSettings(embedding_model) @@ -643,7 +647,7 @@ async def test_embedding_metadata_mismatch_raises( @pytest.mark.asyncio async def test_embedding_model_mismatch_raises( - self, temp_db_path: str, embedding_model: 
AsyncEmbeddingModel + self, temp_db_path: str, embedding_model: IEmbeddingModel ): """Ensure providing a different embedding model name raises.""" embedding_settings = TextEmbeddingIndexSettings(embedding_model) @@ -683,7 +687,7 @@ async def test_embedding_model_mismatch_raises( @pytest.mark.asyncio async def test_updated_at_changes_on_add_messages( - self, temp_db_path: str, embedding_model: AsyncEmbeddingModel + self, temp_db_path: str, embedding_model: IEmbeddingModel ): """Test that updated_at timestamp is updated when messages are added.""" embedding_settings = TextEmbeddingIndexSettings(embedding_model) diff --git a/tests/test_demo.py b/tests/test_demo.py index 599f006e..39f2f061 100644 --- a/tests/test_demo.py +++ b/tests/test_demo.py @@ -6,7 +6,7 @@ import textwrap import time -from typeagent.aitools.embeddings import AsyncEmbeddingModel +from typeagent.aitools.embeddings import IEmbeddingModel from typeagent.knowpro.convsettings import ConversationSettings from typeagent.knowpro.interfaces import ScoredSemanticRefOrdinal from typeagent.podcasts import podcast @@ -33,7 +33,7 @@ async def main(filename_prefix: str): settings = ConversationSettings() model = settings.embedding_model assert model is not None - assert isinstance(model, AsyncEmbeddingModel), f"model is {model!r}" + assert isinstance(model, IEmbeddingModel), f"model is {model!r}" assert settings.thread_settings.embedding_model is model assert ( settings.message_text_index_settings.embedding_index_settings.embedding_model diff --git a/tests/test_message_text_index_serialization.py b/tests/test_message_text_index_serialization.py index 9b14fbe2..4504bf42 100644 --- a/tests/test_message_text_index_serialization.py +++ b/tests/test_message_text_index_serialization.py @@ -8,7 +8,7 @@ import numpy as np import pytest -from typeagent.aitools.embeddings import AsyncEmbeddingModel +from typeagent.aitools.embeddings import IEmbeddingModel from typeagent.knowpro.convsettings import ( MessageTextIndexSettings, TextEmbeddingIndexSettings, @@ -44,7 +44,7 @@ def sqlite_db(self) -> sqlite3.Connection: async def test_message_text_index_serialize_not_empty( self, sqlite_db: sqlite3.Connection, - embedding_model: AsyncEmbeddingModel, + embedding_model: IEmbeddingModel, needs_auth: None, ): """Test that MessageTextIndex serialization produces non-empty data when populated.""" @@ -111,7 +111,7 @@ async def test_message_text_index_serialize_not_empty( async def test_message_text_index_deserialize_restores_data( self, sqlite_db: sqlite3.Connection, - embedding_model: AsyncEmbeddingModel, + embedding_model: IEmbeddingModel, needs_auth: None, ): """Test that MessageTextIndex deserialization actually restores data.""" diff --git a/tests/test_podcasts.py b/tests/test_podcasts.py index 6f901a75..97be7c3e 100644 --- a/tests/test_podcasts.py +++ b/tests/test_podcasts.py @@ -6,7 +6,7 @@ import pytest -from typeagent.aitools.embeddings import AsyncEmbeddingModel +from typeagent.aitools.embeddings import IEmbeddingModel from typeagent.knowpro.convsettings import ConversationSettings from typeagent.knowpro.interfaces import Datetime from typeagent.knowpro.serialization import DATA_FILE_SUFFIX, EMBEDDING_FILE_SUFFIX @@ -18,7 +18,7 @@ @pytest.mark.asyncio async def test_ingest_podcast( - really_needs_auth: None, temp_dir: str, embedding_model: AsyncEmbeddingModel + really_needs_auth: None, temp_dir: str, embedding_model: IEmbeddingModel ): # Import the podcast settings = ConversationSettings(embedding_model) diff --git a/tests/test_reltermsindex.py 
b/tests/test_reltermsindex.py index 57afa913..47d21d59 100644 --- a/tests/test_reltermsindex.py +++ b/tests/test_reltermsindex.py @@ -8,7 +8,7 @@ import pytest_asyncio # TypeAgent imports -from typeagent.aitools.embeddings import AsyncEmbeddingModel +from typeagent.aitools.embeddings import IEmbeddingModel from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings from typeagent.knowpro.convsettings import ( MessageTextIndexSettings, @@ -30,7 +30,7 @@ @pytest_asyncio.fixture(params=["memory", "sqlite"]) async def related_terms_index( request: pytest.FixtureRequest, - embedding_model: AsyncEmbeddingModel, + embedding_model: IEmbeddingModel, temp_db_path: str, ) -> AsyncGenerator[ITermToRelatedTermsIndex, None]: class DummyTestMessage(IMessage): diff --git a/tests/test_semrefindex.py b/tests/test_semrefindex.py index 5fbff6cc..20dad6be 100644 --- a/tests/test_semrefindex.py +++ b/tests/test_semrefindex.py @@ -8,7 +8,7 @@ import pytest_asyncio # TypeAgent imports -from typeagent.aitools.embeddings import AsyncEmbeddingModel +from typeagent.aitools.embeddings import IEmbeddingModel from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings from typeagent.knowpro.convsettings import ( MessageTextIndexSettings, @@ -37,7 +37,7 @@ @pytest_asyncio.fixture(params=["memory", "sqlite"]) async def semantic_ref_index( request: pytest.FixtureRequest, - embedding_model: AsyncEmbeddingModel, + embedding_model: IEmbeddingModel, temp_db_path: str, ) -> AsyncGenerator[ITermToSemanticRefIndex, None]: """Unified fixture to create a semantic ref index for both memory and SQLite providers.""" @@ -97,7 +97,7 @@ def get_knowledge(self): @pytest_asyncio.fixture(params=["memory", "sqlite"]) async def semantic_ref_setup( request: pytest.FixtureRequest, - embedding_model: AsyncEmbeddingModel, + embedding_model: IEmbeddingModel, temp_db_path: str, ) -> AsyncGenerator[Dict[str, ITermToSemanticRefIndex | ISemanticRefCollection], None]: """Unified fixture that provides both semantic ref index and collection for testing helper functions.""" diff --git a/tests/test_sqlite_indexes.py b/tests/test_sqlite_indexes.py index 8639dab3..825f57d8 100644 --- a/tests/test_sqlite_indexes.py +++ b/tests/test_sqlite_indexes.py @@ -10,7 +10,7 @@ import pytest -from typeagent.aitools.embeddings import AsyncEmbeddingModel +from typeagent.aitools.embeddings import IEmbeddingModel from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings from typeagent.knowpro import interfaces from typeagent.knowpro.convsettings import MessageTextIndexSettings @@ -35,7 +35,7 @@ @pytest.fixture def embedding_settings( - embedding_model: AsyncEmbeddingModel, + embedding_model: IEmbeddingModel, ) -> TextEmbeddingIndexSettings: """Create TextEmbeddingIndexSettings for testing.""" return TextEmbeddingIndexSettings(embedding_model) diff --git a/tests/test_sqlitestore.py b/tests/test_sqlitestore.py index 7bd1f98d..ad784221 100644 --- a/tests/test_sqlitestore.py +++ b/tests/test_sqlitestore.py @@ -9,7 +9,7 @@ from pydantic.dataclasses import dataclass -from typeagent.aitools.embeddings import AsyncEmbeddingModel +from typeagent.aitools.embeddings import IEmbeddingModel from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings from typeagent.knowpro.convsettings import ( MessageTextIndexSettings, @@ -39,7 +39,7 @@ def get_knowledge(self) -> KnowledgeResponse: @pytest_asyncio.fixture async def dummy_sqlite_storage_provider( - temp_db_path: str, embedding_model: AsyncEmbeddingModel + temp_db_path: str, embedding_model: 
IEmbeddingModel ) -> AsyncGenerator[SqliteStorageProvider[DummyMessage], None]: """Create a SqliteStorageProvider for testing.""" embedding_settings = TextEmbeddingIndexSettings(embedding_model) diff --git a/tests/test_storage_providers_unified.py b/tests/test_storage_providers_unified.py index 67ae9d7c..14829f03 100644 --- a/tests/test_storage_providers_unified.py +++ b/tests/test_storage_providers_unified.py @@ -16,7 +16,7 @@ from pydantic.dataclasses import dataclass -from typeagent.aitools.embeddings import AsyncEmbeddingModel +from typeagent.aitools.embeddings import IEmbeddingModel from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings from typeagent.knowpro import kplib from typeagent.knowpro.convsettings import ( @@ -52,7 +52,7 @@ def get_knowledge(self) -> KnowledgeResponse: @pytest_asyncio.fixture(params=["memory", "sqlite"]) async def storage_provider_type( request: pytest.FixtureRequest, - embedding_model: AsyncEmbeddingModel, + embedding_model: IEmbeddingModel, temp_db_path: str, ) -> AsyncGenerator[tuple[IStorageProvider, str], None]: """Parameterized fixture that provides both memory and sqlite storage providers.""" @@ -328,7 +328,7 @@ async def test_conversation_threads_interface_parity( # Cross-provider validation tests @pytest.mark.asyncio async def test_cross_provider_message_collection_equivalence( - embedding_model: AsyncEmbeddingModel, temp_db_path: str, needs_auth: None + embedding_model: IEmbeddingModel, temp_db_path: str, needs_auth: None ): """Test that both providers handle message collections equivalently.""" # Create both providers with identical settings @@ -586,7 +586,7 @@ async def test_timestamp_index_with_data( @pytest.mark.asyncio async def test_storage_provider_independence( - embedding_model: AsyncEmbeddingModel, temp_db_path: str, needs_auth: None + embedding_model: IEmbeddingModel, temp_db_path: str, needs_auth: None ): """Test that different storage provider instances work independently.""" # Create settings shared between providers diff --git a/tests/test_transcripts.py b/tests/test_transcripts.py index 9d930034..0d0bfb57 100644 --- a/tests/test_transcripts.py +++ b/tests/test_transcripts.py @@ -6,7 +6,7 @@ import pytest -from typeagent.aitools.embeddings import AsyncEmbeddingModel +from typeagent.aitools.embeddings import AsyncEmbeddingModel, IEmbeddingModel from typeagent.knowpro.convsettings import ConversationSettings from typeagent.knowpro.universal_message import format_timestamp_utc, UNIX_EPOCH from typeagent.transcripts.transcript import ( @@ -88,7 +88,7 @@ def test_get_transcript_info(): @pytest.fixture def conversation_settings( - needs_auth: None, embedding_model: AsyncEmbeddingModel + needs_auth: None, embedding_model: IEmbeddingModel ) -> ConversationSettings: """Create conversation settings for testing.""" return ConversationSettings(embedding_model) @@ -242,7 +242,7 @@ async def test_transcript_creation(): @pytest.mark.asyncio async def test_transcript_knowledge_extraction_slow( - really_needs_auth: None, embedding_model: AsyncEmbeddingModel + really_needs_auth: None, embedding_model: IEmbeddingModel ): """ Test that knowledge extraction works during transcript ingestion. 
diff --git a/tests/test_vectorbase.py b/tests/test_vectorbase.py index 62abd392..be8d2c07 100644 --- a/tests/test_vectorbase.py +++ b/tests/test_vectorbase.py @@ -61,8 +61,8 @@ def test_add_embeddings(vector_base: VectorBase, sample_embeddings: Samples): assert len(bulk_vector_base) == len(vector_base) np.testing.assert_array_equal(bulk_vector_base.serialize(), vector_base.serialize()) - sequential_cache = vector_base._model._embedding_cache - bulk_cache = bulk_vector_base._model._embedding_cache + sequential_cache = vector_base._model._embedding_cache # type: ignore[attr-defined] + bulk_cache = bulk_vector_base._model._embedding_cache # type: ignore[attr-defined] assert set(sequential_cache.keys()) == set(bulk_cache.keys()) for key in keys: np.testing.assert_array_equal(bulk_cache[key], sequential_cache[key]) @@ -85,7 +85,7 @@ async def test_add_key_no_cache(vector_base: VectorBase, sample_embeddings: Samp assert len(vector_base) == len(sample_embeddings) assert ( - vector_base._model._embedding_cache == {} + vector_base._model._embedding_cache == {} # type: ignore[attr-defined] ), "Cache should remain empty when cache=False" @@ -106,7 +106,7 @@ async def test_add_keys_no_cache(vector_base: VectorBase, sample_embeddings: Sam assert len(vector_base) == len(sample_embeddings) assert ( - vector_base._model._embedding_cache == {} + vector_base._model._embedding_cache == {} # type: ignore[attr-defined] ), "Cache should remain empty when cache=False" diff --git a/tools/ingest_vtt.py b/tools/ingest_vtt.py index ad567756..8bcef6d7 100644 --- a/tools/ingest_vtt.py +++ b/tools/ingest_vtt.py @@ -24,7 +24,7 @@ from dotenv import load_dotenv import webvtt -from typeagent.aitools.embeddings import AsyncEmbeddingModel +from typeagent.aitools.embeddings import create_embedding_model from typeagent.knowpro.convsettings import ConversationSettings from typeagent.knowpro.interfaces import ConversationMetadata from typeagent.knowpro.universal_message import format_timestamp_utc, UNIX_EPOCH @@ -203,7 +203,7 @@ async def ingest_vtt_files( if verbose: print("Setting up conversation settings...") try: - embedding_model = AsyncEmbeddingModel(model_name=embedding_name) + embedding_model = create_embedding_model(model_name=embedding_name) settings = ConversationSettings(embedding_model) # Create metadata with the conversation name diff --git a/tools/query.py b/tools/query.py index 3d28a89e..24b1f4c9 100644 --- a/tools/query.py +++ b/tools/query.py @@ -150,7 +150,7 @@ class ProcessingContext: debug2: typing.Literal["none", "diff", "full", "skip"] debug3: typing.Literal["none", "diff", "full", "nice"] debug4: typing.Literal["none", "diff", "full", "nice"] - embedding_model: embeddings.AsyncEmbeddingModel + embedding_model: embeddings.IEmbeddingModel query_translator: typechat.TypeChatJsonTranslator[search_query_schema.SearchQuery] answer_translator: typechat.TypeChatJsonTranslator[ answer_response_schema.AnswerResponse From d59d7b6909484b9d8ac3b99cf455c6cf9bdc394c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 17 Feb 2026 10:05:17 -0800 Subject: [PATCH 02/49] Agent step 2 -- unreviewed --- src/typeagent/aitools/embeddings.py | 9 +- src/typeagent/aitools/model_registry.py | 214 ++++++++++++++++++++++++ src/typeagent/aitools/utils.py | 9 +- tests/test_model_registry.py | 152 +++++++++++++++++ 4 files changed, 377 insertions(+), 7 deletions(-) create mode 100644 src/typeagent/aitools/model_registry.py create mode 100644 tests/test_model_registry.py diff --git a/src/typeagent/aitools/embeddings.py 
b/src/typeagent/aitools/embeddings.py index c3403af2..ce8ea062 100644 --- a/src/typeagent/aitools/embeddings.py +++ b/src/typeagent/aitools/embeddings.py @@ -94,6 +94,7 @@ def __init__( model_name: str | None = None, endpoint_envvar: str | None = None, max_retries: int = DEFAULT_MAX_RETRIES, + use_azure: bool | None = None, ): if model_name is None: model_name = DEFAULT_MODEL_NAME @@ -122,8 +123,12 @@ def __init__( openai_api_key = os.getenv("OPENAI_API_KEY") azure_api_key = os.getenv("AZURE_OPENAI_API_KEY") - # Prefer OpenAI if both are set, use Azure if only Azure is set - self.use_azure = bool(azure_api_key) and not bool(openai_api_key) + # Determine provider: explicit use_azure overrides auto-detection. + if use_azure is not None: + self.use_azure = use_azure + else: + # Prefer OpenAI if both are set, use Azure if only Azure is set + self.use_azure = bool(azure_api_key) and not bool(openai_api_key) if endpoint_envvar is None: # Check if OpenAI credentials are available, prefer OpenAI over Azure diff --git a/src/typeagent/aitools/model_registry.py b/src/typeagent/aitools/model_registry.py new file mode 100644 index 00000000..7cc03c4b --- /dev/null +++ b/src/typeagent/aitools/model_registry.py @@ -0,0 +1,214 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +"""Provider-agnostic model configuration. + +Create chat and embedding models from ``provider/model`` spec strings:: + + from typeagent.aitools.model_registry import configure_models + + chat, embedder = configure_models( + "openai/gpt-4o", + "openai/text-embedding-3-small", + ) + +Supported built-in providers +----------------------------- + +* ``openai/`` — requires ``OPENAI_API_KEY`` env var. +* ``azure/`` — requires ``AZURE_OPENAI_API_KEY`` (and + ``AZURE_OPENAI_ENDPOINT``) env vars. For Azure, the *model* part of the + spec is the **deployment name**. + +Extending with new providers +---------------------------- + +Implement ``typechat.TypeChatLanguageModel`` for chat, or +``IEmbeddingModel`` for embeddings, then register a factory:: + + from typeagent.aitools.model_registry import ( + register_chat_provider, + register_embedding_provider, + ) + + register_chat_provider("anthropic", my_anthropic_chat_factory) + register_embedding_provider("gemini", my_gemini_embedding_factory) + +Each factory is a callable ``(model_name: str) -> Model``. +""" + +from collections.abc import Callable +import os + +import typechat + +from .embeddings import AsyncEmbeddingModel, IEmbeddingModel + +# --------------------------------------------------------------------------- +# Spec parsing +# --------------------------------------------------------------------------- + +type ChatModelFactory = Callable[[str], typechat.TypeChatLanguageModel] +type EmbeddingModelFactory = Callable[[str], IEmbeddingModel] + + +def _parse_model_spec(spec: str) -> tuple[str, str]: + """Parse ``'provider/model'`` into ``(provider, model_name)``. + + Raises ``ValueError`` on malformed specs. + """ + parts = spec.split("/", 1) + if len(parts) != 2 or not parts[0] or not parts[1]: + raise ValueError( + f"Invalid model spec {spec!r}. " + f"Expected 'provider/model', e.g. 'openai/gpt-4o'." 
+ ) + return parts[0], parts[1] + + +# --------------------------------------------------------------------------- +# Chat model registry +# --------------------------------------------------------------------------- + +_chat_providers: dict[str, ChatModelFactory] = {} + + +def register_chat_provider(provider: str, factory: ChatModelFactory) -> None: + """Register a factory that creates chat models for *provider*.""" + _chat_providers[provider] = factory + + +def _openai_chat(model_name: str) -> typechat.TypeChatLanguageModel: + env: dict[str, str | None] = dict(os.environ) + if not env.get("OPENAI_API_KEY"): + raise RuntimeError("OPENAI_API_KEY required for openai/ chat models.") + env["OPENAI_MODEL"] = model_name + # Force the OpenAI path even when Azure env vars are also present. + env.pop("AZURE_OPENAI_API_KEY", None) + return typechat.create_language_model(env) + + +def _azure_chat(model_name: str) -> typechat.TypeChatLanguageModel: + from .auth import AzureTokenProvider, get_shared_token_provider + from .utils import DEFAULT_MAX_RETRY_ATTEMPTS, DEFAULT_TIMEOUT_SECONDS, ModelWrapper + + env: dict[str, str | None] = dict(os.environ) + key = env.get("AZURE_OPENAI_API_KEY") + if not key: + raise RuntimeError("AZURE_OPENAI_API_KEY required for azure/ chat models.") + env["OPENAI_MODEL"] = model_name + # Force the Azure path even when OPENAI_API_KEY is also present. + env.pop("OPENAI_API_KEY", None) + + shared_token_provider: AzureTokenProvider | None = None + if isinstance(key, str) and key.lower() == "identity": + shared_token_provider = get_shared_token_provider() + env["AZURE_OPENAI_API_KEY"] = shared_token_provider.get_token() + + model = typechat.create_language_model(env) + model.timeout_seconds = DEFAULT_TIMEOUT_SECONDS + model.max_retry_attempts = DEFAULT_MAX_RETRY_ATTEMPTS + if shared_token_provider is not None: + model = ModelWrapper(model, shared_token_provider) + return model + + +register_chat_provider("openai", _openai_chat) +register_chat_provider("azure", _azure_chat) + + +# --------------------------------------------------------------------------- +# Embedding model registry +# --------------------------------------------------------------------------- + +_embedding_providers: dict[str, EmbeddingModelFactory] = {} + + +def register_embedding_provider(provider: str, factory: EmbeddingModelFactory) -> None: + """Register a factory that creates embedding models for *provider*.""" + _embedding_providers[provider] = factory + + +def _openai_embedding(model_name: str) -> IEmbeddingModel: + return AsyncEmbeddingModel(model_name=model_name, use_azure=False) + + +def _azure_embedding(model_name: str) -> IEmbeddingModel: + return AsyncEmbeddingModel(model_name=model_name, use_azure=True) + + +register_embedding_provider("openai", _openai_embedding) +register_embedding_provider("azure", _azure_embedding) + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def create_chat_model( + model_spec: str, +) -> typechat.TypeChatLanguageModel: + """Create a chat model from a ``provider/model`` spec. + + Examples:: + + model = create_chat_model("openai/gpt-4o") + model = create_chat_model("azure/my-gpt4o-deployment") + + For Azure, *model* is the **deployment name**, not the underlying + model name. 
+ """ + provider, model_name = _parse_model_spec(model_spec) + factory = _chat_providers.get(provider) + if factory is None: + avail = ", ".join(sorted(_chat_providers)) or "(none)" + raise ValueError( + f"Unknown chat provider {provider!r}. " + f"Available: {avail}. " + f"Use register_chat_provider() to add support." + ) + return factory(model_name) + + +def create_embedding_model( + model_spec: str, +) -> IEmbeddingModel: + """Create an embedding model from a ``provider/model`` spec. + + Examples:: + + model = create_embedding_model("openai/text-embedding-3-small") + model = create_embedding_model("azure/text-embedding-3-small") + """ + provider, model_name = _parse_model_spec(model_spec) + factory = _embedding_providers.get(provider) + if factory is None: + avail = ", ".join(sorted(_embedding_providers)) or "(none)" + raise ValueError( + f"Unknown embedding provider {provider!r}. " + f"Available: {avail}. " + f"Use register_embedding_provider() to add support." + ) + return factory(model_name) + + +def configure_models( + chat_model_spec: str, + embedding_model_spec: str, +) -> tuple[typechat.TypeChatLanguageModel, IEmbeddingModel]: + """Configure both a chat model and an embedding model at once. + + Example:: + + chat, embedder = configure_models( + "openai/gpt-4o", + "openai/text-embedding-3-small", + ) + + settings = ConversationSettings(model=embedder) + extractor = KnowledgeExtractor(model=chat) + """ + return create_chat_model(chat_model_spec), create_embedding_model( + embedding_model_spec + ) diff --git a/src/typeagent/aitools/utils.py b/src/typeagent/aitools/utils.py index 2459473c..b475c728 100644 --- a/src/typeagent/aitools/utils.py +++ b/src/typeagent/aitools/utils.py @@ -125,12 +125,11 @@ async def complete( def create_typechat_model() -> typechat.TypeChatLanguageModel: """Create a TypeChat language model using OpenAI or Azure OpenAI. - Reads ``OPENAI_API_KEY``, ``AZURE_OPENAI_API_KEY`` and related env vars. - Handles Azure ``identity`` token provider for Microsoft internal usage. + Auto-detects the provider from ``OPENAI_API_KEY`` / ``AZURE_OPENAI_API_KEY`` + environment variables. - To use a different provider (e.g. Anthropic, Gemini), implement - ``typechat.TypeChatLanguageModel`` directly and pass it to - ``KnowledgeExtractor`` or ``create_translator()``. + For explicit provider selection, use :func:`model_registry.create_chat_model` + with a spec string like ``"openai/gpt-4o"`` or ``"azure/my-deployment"``. """ env: dict[str, str | None] = dict(os.environ) key_name = "AZURE_OPENAI_API_KEY" diff --git a/tests/test_model_registry.py b/tests/test_model_registry.py new file mode 100644 index 00000000..11b4f71e --- /dev/null +++ b/tests/test_model_registry.py @@ -0,0 +1,152 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +import pytest + +import typechat + +from typeagent.aitools.embeddings import ( + AsyncEmbeddingModel, + IEmbeddingModel, + TEST_MODEL_NAME, +) +from typeagent.aitools.model_registry import ( + _chat_providers, + _embedding_providers, + _parse_model_spec, + configure_models, + create_chat_model, + create_embedding_model, + register_chat_provider, + register_embedding_provider, +) + +# --------------------------------------------------------------------------- +# Spec parsing +# --------------------------------------------------------------------------- + + +def test_parse_valid_specs() -> None: + assert _parse_model_spec("openai/gpt-4o") == ("openai", "gpt-4o") + assert _parse_model_spec("azure/my-deployment") == ("azure", "my-deployment") + assert _parse_model_spec("anthropic/claude-3.5-sonnet") == ( + "anthropic", + "claude-3.5-sonnet", + ) + + +def test_parse_spec_preserves_slashes() -> None: + """Only the first '/' is a separator; the rest belong to the model name.""" + assert _parse_model_spec("provider/model/variant") == ( + "provider", + "model/variant", + ) + + +def test_parse_invalid_specs() -> None: + with pytest.raises(ValueError, match="Invalid model spec"): + _parse_model_spec("noslash") + with pytest.raises(ValueError, match="Invalid model spec"): + _parse_model_spec("/model") + with pytest.raises(ValueError, match="Invalid model spec"): + _parse_model_spec("provider/") + with pytest.raises(ValueError, match="Invalid model spec"): + _parse_model_spec("") + + +# --------------------------------------------------------------------------- +# Built-in registration +# --------------------------------------------------------------------------- + + +def test_builtin_providers_registered() -> None: + assert "openai" in _chat_providers + assert "azure" in _chat_providers + assert "openai" in _embedding_providers + assert "azure" in _embedding_providers + + +# --------------------------------------------------------------------------- +# Unknown provider errors +# --------------------------------------------------------------------------- + + +def test_unknown_chat_provider() -> None: + with pytest.raises(ValueError, match="Unknown chat provider"): + create_chat_model("magical/unicorn") + + +def test_unknown_embedding_provider() -> None: + with pytest.raises(ValueError, match="Unknown embedding provider"): + create_embedding_model("magical/unicorn") + + +# --------------------------------------------------------------------------- +# Custom provider registration +# --------------------------------------------------------------------------- + + +class FakeChatModel(typechat.TypeChatLanguageModel): + """Minimal chat model for registry tests.""" + + async def complete( + self, prompt: str | list[typechat.PromptSection] + ) -> typechat.Result[str]: + return typechat.Success("fake") + + +def test_register_and_use_custom_chat_provider() -> None: + instance = FakeChatModel() + register_chat_provider("_test_chat", lambda name: instance) + try: + result = create_chat_model("_test_chat/any-model") + assert result is instance + finally: + _chat_providers.pop("_test_chat", None) + + +def test_register_and_use_custom_embedding_provider() -> None: + instance = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + register_embedding_provider("_test_embed", lambda name: instance) + try: + result = create_embedding_model("_test_embed/any-model") + assert result is instance + assert isinstance(result, IEmbeddingModel) + finally: + _embedding_providers.pop("_test_embed", None) + + +def 
test_model_name_forwarded_to_factory() -> None: + """The model portion of the spec is passed to the factory.""" + received: list[str] = [] + + def capture_factory(model_name: str) -> IEmbeddingModel: + received.append(model_name) + return AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + + register_embedding_provider("_test_fwd", capture_factory) + try: + create_embedding_model("_test_fwd/text-embedding-3-small") + assert received == ["text-embedding-3-small"] + finally: + _embedding_providers.pop("_test_fwd", None) + + +# --------------------------------------------------------------------------- +# configure_models +# --------------------------------------------------------------------------- + + +def test_configure_models() -> None: + chat_instance = FakeChatModel() + embed_instance = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + + register_chat_provider("_test_cm", lambda name: chat_instance) + register_embedding_provider("_test_cm", lambda name: embed_instance) + try: + chat, embedder = configure_models("_test_cm/chat", "_test_cm/embed") + assert chat is chat_instance + assert embedder is embed_instance + finally: + _chat_providers.pop("_test_cm", None) + _embedding_providers.pop("_test_cm", None) From ff733f50b2d5b41f3457d6f34d0ed8d66b7183a6 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 17 Feb 2026 10:08:06 -0800 Subject: [PATCH 03/49] Agent step 3 -- unreviewed -- use Pydantic's model registry --- src/typeagent/aitools/model_registry.py | 294 +++++++++++++----------- tests/test_model_registry.py | 271 ++++++++++++++-------- 2 files changed, 334 insertions(+), 231 deletions(-) diff --git a/src/typeagent/aitools/model_registry.py b/src/typeagent/aitools/model_registry.py index 7cc03c4b..3b02e461 100644 --- a/src/typeagent/aitools/model_registry.py +++ b/src/typeagent/aitools/model_registry.py @@ -1,144 +1,165 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -"""Provider-agnostic model configuration. +"""Provider-agnostic model configuration backed by pydantic_ai. -Create chat and embedding models from ``provider/model`` spec strings:: +Create chat and embedding models from ``provider:model`` spec strings:: from typeagent.aitools.model_registry import configure_models chat, embedder = configure_models( - "openai/gpt-4o", - "openai/text-embedding-3-small", + "openai:gpt-4o", + "openai:text-embedding-3-small", ) -Supported built-in providers ------------------------------ +The spec format is ``provider:model``, matching pydantic_ai conventions. +Provider wiring (API keys, endpoints, etc.) is handled by pydantic_ai's +model registry, which supports 25+ providers including ``openai``, +``azure``, ``anthropic``, ``google``, ``bedrock``, ``groq``, ``mistral``, +``ollama``, ``cohere``, and many more. -* ``openai/`` — requires ``OPENAI_API_KEY`` env var. -* ``azure/`` — requires ``AZURE_OPENAI_API_KEY`` (and - ``AZURE_OPENAI_ENDPOINT``) env vars. For Azure, the *model* part of the - spec is the **deployment name**. - -Extending with new providers ----------------------------- - -Implement ``typechat.TypeChatLanguageModel`` for chat, or -``IEmbeddingModel`` for embeddings, then register a factory:: - - from typeagent.aitools.model_registry import ( - register_chat_provider, - register_embedding_provider, - ) - - register_chat_provider("anthropic", my_anthropic_chat_factory) - register_embedding_provider("gemini", my_gemini_embedding_factory) - -Each factory is a callable ``(model_name: str) -> Model``. 
+See https://ai.pydantic.dev/models/ for all supported providers and their +required environment variables. """ -from collections.abc import Callable -import os - +import numpy as np +from numpy.typing import NDArray + +from pydantic_ai import Embedder as _PydanticAIEmbedder +from pydantic_ai.messages import ( + ModelMessage, + ModelRequest, + SystemPromptPart, + TextPart, + UserPromptPart, +) +from pydantic_ai.models import infer_model, Model, ModelRequestParameters import typechat -from .embeddings import AsyncEmbeddingModel, IEmbeddingModel +from .embeddings import IEmbeddingModel, NormalizedEmbedding, NormalizedEmbeddings # --------------------------------------------------------------------------- -# Spec parsing +# Known embedding sizes for common models # --------------------------------------------------------------------------- -type ChatModelFactory = Callable[[str], typechat.TypeChatLanguageModel] -type EmbeddingModelFactory = Callable[[str], IEmbeddingModel] - - -def _parse_model_spec(spec: str) -> tuple[str, str]: - """Parse ``'provider/model'`` into ``(provider, model_name)``. - - Raises ``ValueError`` on malformed specs. - """ - parts = spec.split("/", 1) - if len(parts) != 2 or not parts[0] or not parts[1]: - raise ValueError( - f"Invalid model spec {spec!r}. " - f"Expected 'provider/model', e.g. 'openai/gpt-4o'." - ) - return parts[0], parts[1] +_KNOWN_EMBEDDING_SIZES: dict[str, int] = { + "text-embedding-ada-002": 1536, + "text-embedding-3-small": 1536, + "text-embedding-3-large": 3072, + "embed-english-v3.0": 1024, + "embed-multilingual-v3.0": 1024, + "embed-english-light-v3.0": 384, + "embed-multilingual-light-v3.0": 384, + "text-embedding-004": 768, + "embedding-001": 768, +} # --------------------------------------------------------------------------- -# Chat model registry +# Chat model adapter # --------------------------------------------------------------------------- -_chat_providers: dict[str, ChatModelFactory] = {} - - -def register_chat_provider(provider: str, factory: ChatModelFactory) -> None: - """Register a factory that creates chat models for *provider*.""" - _chat_providers[provider] = factory - - -def _openai_chat(model_name: str) -> typechat.TypeChatLanguageModel: - env: dict[str, str | None] = dict(os.environ) - if not env.get("OPENAI_API_KEY"): - raise RuntimeError("OPENAI_API_KEY required for openai/ chat models.") - env["OPENAI_MODEL"] = model_name - # Force the OpenAI path even when Azure env vars are also present. - env.pop("AZURE_OPENAI_API_KEY", None) - return typechat.create_language_model(env) +class PydanticAIChatModel: + """Adapter from :class:`pydantic_ai.models.Model` to TypeChat's + :class:`~typechat.TypeChatLanguageModel`. -def _azure_chat(model_name: str) -> typechat.TypeChatLanguageModel: - from .auth import AzureTokenProvider, get_shared_token_provider - from .utils import DEFAULT_MAX_RETRY_ATTEMPTS, DEFAULT_TIMEOUT_SECONDS, ModelWrapper - - env: dict[str, str | None] = dict(os.environ) - key = env.get("AZURE_OPENAI_API_KEY") - if not key: - raise RuntimeError("AZURE_OPENAI_API_KEY required for azure/ chat models.") - env["OPENAI_MODEL"] = model_name - # Force the Azure path even when OPENAI_API_KEY is also present. - env.pop("OPENAI_API_KEY", None) + This lets any pydantic_ai chat model (OpenAI, Anthropic, Google, …) be + used wherever TypeChat expects a ``TypeChatLanguageModel``. 
+ """ - shared_token_provider: AzureTokenProvider | None = None - if isinstance(key, str) and key.lower() == "identity": - shared_token_provider = get_shared_token_provider() - env["AZURE_OPENAI_API_KEY"] = shared_token_provider.get_token() + def __init__(self, model: Model) -> None: + self._model = model - model = typechat.create_language_model(env) - model.timeout_seconds = DEFAULT_TIMEOUT_SECONDS - model.max_retry_attempts = DEFAULT_MAX_RETRY_ATTEMPTS - if shared_token_provider is not None: - model = ModelWrapper(model, shared_token_provider) - return model + async def complete( + self, prompt: str | list[typechat.PromptSection] + ) -> typechat.Result[str]: + parts: list[SystemPromptPart | UserPromptPart] = [] + if isinstance(prompt, str): + parts.append(UserPromptPart(content=prompt)) + else: + for section in prompt: + if section["role"] == "system": + parts.append(SystemPromptPart(content=section["content"])) + else: + parts.append(UserPromptPart(content=section["content"])) + messages: list[ModelMessage] = [ModelRequest(parts=parts)] + params = ModelRequestParameters() -register_chat_provider("openai", _openai_chat) -register_chat_provider("azure", _azure_chat) + response = await self._model.request(messages, None, params) + text_parts = [p.content for p in response.parts if isinstance(p, TextPart)] + if text_parts: + return typechat.Success("".join(text_parts)) + return typechat.Failure("No text content in model response") # --------------------------------------------------------------------------- -# Embedding model registry +# Embedding model adapter # --------------------------------------------------------------------------- -_embedding_providers: dict[str, EmbeddingModelFactory] = {} - - -def register_embedding_provider(provider: str, factory: EmbeddingModelFactory) -> None: - """Register a factory that creates embedding models for *provider*.""" - _embedding_providers[provider] = factory +class PydanticAIEmbeddingModel: + """Adapter from :class:`pydantic_ai.Embedder` to :class:`IEmbeddingModel`. -def _openai_embedding(model_name: str) -> IEmbeddingModel: - return AsyncEmbeddingModel(model_name=model_name, use_azure=False) - - -def _azure_embedding(model_name: str) -> IEmbeddingModel: - return AsyncEmbeddingModel(model_name=model_name, use_azure=True) - + This lets any pydantic_ai embedding provider (OpenAI, Cohere, Google, …) + be used wherever the codebase expects an ``IEmbeddingModel``, including + :class:`~typeagent.aitools.vectorbase.VectorBase` and + :class:`~typeagent.knowpro.convsettings.ConversationSettings`. 
+ """ -register_embedding_provider("openai", _openai_embedding) -register_embedding_provider("azure", _azure_embedding) + model_name: str + embedding_size: int + + def __init__( + self, + embedder: _PydanticAIEmbedder, + model_name: str, + embedding_size: int, + ) -> None: + self._embedder = embedder + self.model_name = model_name + self.embedding_size = embedding_size + self._cache: dict[str, NormalizedEmbedding] = {} + + def add_embedding(self, key: str, embedding: NormalizedEmbedding) -> None: + self._cache[key] = embedding + + async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding: + result = await self._embedder.embed([input], input_type="document") + embedding: NDArray[np.float32] = np.array( + result.embeddings[0], dtype=np.float32 + ) + norm = float(np.linalg.norm(embedding)) + if norm > 0: + embedding = (embedding / norm).astype(np.float32) + return embedding + + async def get_embeddings_nocache(self, input: list[str]) -> NormalizedEmbeddings: + if not input: + return np.empty((0, self.embedding_size), dtype=np.float32) + result = await self._embedder.embed(input, input_type="document") + embeddings: NDArray[np.float32] = np.array(result.embeddings, dtype=np.float32) + norms = np.linalg.norm(embeddings, axis=1, keepdims=True).astype(np.float32) + norms = np.where(norms > 0, norms, np.float32(1.0)) + embeddings = (embeddings / norms).astype(np.float32) + return embeddings + + async def get_embedding(self, key: str) -> NormalizedEmbedding: + cached = self._cache.get(key) + if cached is not None: + return cached + embedding = await self.get_embedding_nocache(key) + self._cache[key] = embedding + return embedding + + async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings: + missing_keys = [k for k in keys if k not in self._cache] + if missing_keys: + fresh = await self.get_embeddings_nocache(missing_keys) + for i, k in enumerate(missing_keys): + self._cache[k] = fresh[i] + return np.array([self._cache[k] for k in keys], dtype=np.float32) # --------------------------------------------------------------------------- @@ -149,66 +170,71 @@ def _azure_embedding(model_name: str) -> IEmbeddingModel: def create_chat_model( model_spec: str, ) -> typechat.TypeChatLanguageModel: - """Create a chat model from a ``provider/model`` spec. + """Create a chat model from a ``provider:model`` spec. - Examples:: + Delegates to :func:`pydantic_ai.models.infer_model` for provider wiring. - model = create_chat_model("openai/gpt-4o") - model = create_chat_model("azure/my-gpt4o-deployment") + Examples:: - For Azure, *model* is the **deployment name**, not the underlying - model name. + model = create_chat_model("openai:gpt-4o") + model = create_chat_model("anthropic:claude-sonnet-4-20250514") + model = create_chat_model("google:gemini-2.0-flash") """ - provider, model_name = _parse_model_spec(model_spec) - factory = _chat_providers.get(provider) - if factory is None: - avail = ", ".join(sorted(_chat_providers)) or "(none)" - raise ValueError( - f"Unknown chat provider {provider!r}. " - f"Available: {avail}. " - f"Use register_chat_provider() to add support." - ) - return factory(model_name) + model = infer_model(model_spec) + return PydanticAIChatModel(model) def create_embedding_model( model_spec: str, + *, + embedding_size: int | None = None, ) -> IEmbeddingModel: - """Create an embedding model from a ``provider/model`` spec. + """Create an embedding model from a ``provider:model`` spec. + + Delegates to :class:`pydantic_ai.Embedder` for provider wiring. 
+ + If *embedding_size* is not given, it is looked up in a table of common + models. For unknown models, pass it explicitly. Examples:: - model = create_embedding_model("openai/text-embedding-3-small") - model = create_embedding_model("azure/text-embedding-3-small") + model = create_embedding_model("openai:text-embedding-3-small") + model = create_embedding_model("cohere:embed-english-v3.0") + model = create_embedding_model("google:text-embedding-004") """ - provider, model_name = _parse_model_spec(model_spec) - factory = _embedding_providers.get(provider) - if factory is None: - avail = ", ".join(sorted(_embedding_providers)) or "(none)" + model_name = model_spec.split(":")[-1] if ":" in model_spec else model_spec + if embedding_size is None: + embedding_size = _KNOWN_EMBEDDING_SIZES.get(model_name) + if embedding_size is None: raise ValueError( - f"Unknown embedding provider {provider!r}. " - f"Available: {avail}. " - f"Use register_embedding_provider() to add support." + f"Unknown embedding size for model {model_name!r}. " + f"Pass embedding_size= explicitly." ) - return factory(model_name) + embedder = _PydanticAIEmbedder(model_spec) + return PydanticAIEmbeddingModel(embedder, model_name, embedding_size) def configure_models( chat_model_spec: str, embedding_model_spec: str, + *, + embedding_size: int | None = None, ) -> tuple[typechat.TypeChatLanguageModel, IEmbeddingModel]: """Configure both a chat model and an embedding model at once. + Delegates to pydantic_ai's model registry for provider wiring. + Example:: chat, embedder = configure_models( - "openai/gpt-4o", - "openai/text-embedding-3-small", + "openai:gpt-4o", + "openai:text-embedding-3-small", ) settings = ConversationSettings(model=embedder) extractor = KnowledgeExtractor(model=chat) """ - return create_chat_model(chat_model_spec), create_embedding_model( - embedding_model_spec + return ( + create_chat_model(chat_model_spec), + create_embedding_model(embedding_model_spec, embedding_size=embedding_size), ) diff --git a/tests/test_model_registry.py b/tests/test_model_registry.py index 11b4f71e..d18e071d 100644 --- a/tests/test_model_registry.py +++ b/tests/test_model_registry.py @@ -1,135 +1,218 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
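+"""Tests for the pydantic_ai-backed chat and embedding adapters."""
+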
+import numpy as np import pytest import typechat -from typeagent.aitools.embeddings import ( - AsyncEmbeddingModel, - IEmbeddingModel, - TEST_MODEL_NAME, -) +from typeagent.aitools.embeddings import IEmbeddingModel, NormalizedEmbedding from typeagent.aitools.model_registry import ( - _chat_providers, - _embedding_providers, - _parse_model_spec, + _KNOWN_EMBEDDING_SIZES, configure_models, create_chat_model, create_embedding_model, - register_chat_provider, - register_embedding_provider, + PydanticAIChatModel, + PydanticAIEmbeddingModel, ) # --------------------------------------------------------------------------- -# Spec parsing +# Spec format # --------------------------------------------------------------------------- -def test_parse_valid_specs() -> None: - assert _parse_model_spec("openai/gpt-4o") == ("openai", "gpt-4o") - assert _parse_model_spec("azure/my-deployment") == ("azure", "my-deployment") - assert _parse_model_spec("anthropic/claude-3.5-sonnet") == ( - "anthropic", - "claude-3.5-sonnet", - ) +def test_spec_uses_colon_separator() -> None: + """Specs use ``provider:model`` format matching pydantic_ai conventions.""" + with pytest.raises(Exception): + # A nonsense provider should fail + create_chat_model("nonexistent_provider_xyz:fake-model") -def test_parse_spec_preserves_slashes() -> None: - """Only the first '/' is a separator; the rest belong to the model name.""" - assert _parse_model_spec("provider/model/variant") == ( - "provider", - "model/variant", - ) +# --------------------------------------------------------------------------- +# Known embedding sizes +# --------------------------------------------------------------------------- -def test_parse_invalid_specs() -> None: - with pytest.raises(ValueError, match="Invalid model spec"): - _parse_model_spec("noslash") - with pytest.raises(ValueError, match="Invalid model spec"): - _parse_model_spec("/model") - with pytest.raises(ValueError, match="Invalid model spec"): - _parse_model_spec("provider/") - with pytest.raises(ValueError, match="Invalid model spec"): - _parse_model_spec("") +def test_known_embedding_sizes() -> None: + assert _KNOWN_EMBEDDING_SIZES["text-embedding-3-small"] == 1536 + assert _KNOWN_EMBEDDING_SIZES["text-embedding-3-large"] == 3072 + assert _KNOWN_EMBEDDING_SIZES["text-embedding-ada-002"] == 1536 -# --------------------------------------------------------------------------- -# Built-in registration -# --------------------------------------------------------------------------- +def test_unknown_embedding_size_raises() -> None: + with pytest.raises(ValueError, match="Unknown embedding size"): + create_embedding_model("openai:completely-unknown-model-xyz") -def test_builtin_providers_registered() -> None: - assert "openai" in _chat_providers - assert "azure" in _chat_providers - assert "openai" in _embedding_providers - assert "azure" in _embedding_providers +def test_explicit_embedding_size() -> None: + """Passing embedding_size= bypasses the lookup table.""" + # This should not raise even though the model name is unknown + model = create_embedding_model( + "openai:completely-unknown-model-xyz", embedding_size=42 + ) + assert model.embedding_size == 42 # --------------------------------------------------------------------------- -# Unknown provider errors +# PydanticAIChatModel adapter # --------------------------------------------------------------------------- -def test_unknown_chat_provider() -> None: - with pytest.raises(ValueError, match="Unknown chat provider"): - 
create_chat_model("magical/unicorn") +@pytest.mark.asyncio +async def test_chat_adapter_complete() -> None: + """PydanticAIChatModel wraps a pydantic_ai Model.""" + from unittest.mock import AsyncMock + + from pydantic_ai.messages import ModelResponse, TextPart + from pydantic_ai.models import Model + + mock_model = AsyncMock(spec=Model) + mock_model.request.return_value = ModelResponse(parts=[TextPart(content="hello")]) + + adapter = PydanticAIChatModel(mock_model) + result = await adapter.complete("test prompt") + assert isinstance(result, typechat.Success) + assert result.value == "hello" + + +@pytest.mark.asyncio +async def test_chat_adapter_prompt_sections() -> None: + """PydanticAIChatModel handles list[PromptSection] prompts.""" + from unittest.mock import AsyncMock + + from pydantic_ai.messages import ModelResponse, TextPart + from pydantic_ai.models import Model + + mock_model = AsyncMock(spec=Model) + mock_model.request.return_value = ModelResponse( + parts=[TextPart(content="response")] + ) + + adapter = PydanticAIChatModel(mock_model) + sections: list[typechat.PromptSection] = [ + {"role": "system", "content": "You are helpful."}, + {"role": "user", "content": "Hello"}, + ] + result = await adapter.complete(sections) + assert isinstance(result, typechat.Success) + assert result.value == "response" + # Verify the request was called with proper message structure + call_args = mock_model.request.call_args + messages = call_args[0][0] + assert len(messages) == 1 + request = messages[0] + from pydantic_ai.messages import SystemPromptPart, UserPromptPart -def test_unknown_embedding_provider() -> None: - with pytest.raises(ValueError, match="Unknown embedding provider"): - create_embedding_model("magical/unicorn") + assert isinstance(request.parts[0], SystemPromptPart) + assert isinstance(request.parts[1], UserPromptPart) # --------------------------------------------------------------------------- -# Custom provider registration +# PydanticAIEmbeddingModel adapter # --------------------------------------------------------------------------- -class FakeChatModel(typechat.TypeChatLanguageModel): - """Minimal chat model for registry tests.""" +@pytest.mark.asyncio +async def test_embedding_adapter_single() -> None: + """PydanticAIEmbeddingModel computes a single normalized embedding.""" + from unittest.mock import AsyncMock - async def complete( - self, prompt: str | list[typechat.PromptSection] - ) -> typechat.Result[str]: - return typechat.Success("fake") + from pydantic_ai import Embedder + from pydantic_ai.embeddings import EmbeddingResult + mock_embedder = AsyncMock(spec=Embedder) + raw_vec = [3.0, 4.0, 0.0] + mock_embedder.embed.return_value = EmbeddingResult( + embeddings=[raw_vec], + inputs=["test"], + input_type="document", + model_name="test-model", + provider_name="test", + ) + + adapter = PydanticAIEmbeddingModel(mock_embedder, "test-model", 3) + result = await adapter.get_embedding_nocache("test") + assert result.shape == (3,) + norm = float(np.linalg.norm(result)) + assert abs(norm - 1.0) < 1e-6 + + +@pytest.mark.asyncio +async def test_embedding_adapter_batch() -> None: + """PydanticAIEmbeddingModel computes batch embeddings.""" + from unittest.mock import AsyncMock + + from pydantic_ai import Embedder + from pydantic_ai.embeddings import EmbeddingResult + + mock_embedder = AsyncMock(spec=Embedder) + mock_embedder.embed.return_value = EmbeddingResult( + embeddings=[[1.0, 0.0], [0.0, 1.0]], + inputs=["a", "b"], + input_type="document", + model_name="test-model", + 
provider_name="test", + ) + + adapter = PydanticAIEmbeddingModel(mock_embedder, "test-model", 2) + result = await adapter.get_embeddings_nocache(["a", "b"]) + assert result.shape == (2, 2) -def test_register_and_use_custom_chat_provider() -> None: - instance = FakeChatModel() - register_chat_provider("_test_chat", lambda name: instance) - try: - result = create_chat_model("_test_chat/any-model") - assert result is instance - finally: - _chat_providers.pop("_test_chat", None) +@pytest.mark.asyncio +async def test_embedding_adapter_caching() -> None: + """Caching avoids re-computing embeddings.""" + from unittest.mock import AsyncMock -def test_register_and_use_custom_embedding_provider() -> None: - instance = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) - register_embedding_provider("_test_embed", lambda name: instance) - try: - result = create_embedding_model("_test_embed/any-model") - assert result is instance - assert isinstance(result, IEmbeddingModel) - finally: - _embedding_providers.pop("_test_embed", None) + from pydantic_ai import Embedder + from pydantic_ai.embeddings import EmbeddingResult + mock_embedder = AsyncMock(spec=Embedder) + mock_embedder.embed.return_value = EmbeddingResult( + embeddings=[[1.0, 0.0, 0.0]], + inputs=["cached"], + input_type="document", + model_name="test-model", + provider_name="test", + ) + + adapter = PydanticAIEmbeddingModel(mock_embedder, "test-model", 3) + first = await adapter.get_embedding("cached") + second = await adapter.get_embedding("cached") + np.testing.assert_array_equal(first, second) + # embed() should only be called once + assert mock_embedder.embed.call_count == 1 + + +@pytest.mark.asyncio +async def test_embedding_adapter_add_embedding() -> None: + """add_embedding() populates the cache.""" + from unittest.mock import AsyncMock -def test_model_name_forwarded_to_factory() -> None: - """The model portion of the spec is passed to the factory.""" - received: list[str] = [] + from pydantic_ai import Embedder - def capture_factory(model_name: str) -> IEmbeddingModel: - received.append(model_name) - return AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + mock_embedder = AsyncMock(spec=Embedder) + adapter = PydanticAIEmbeddingModel(mock_embedder, "test-model", 3) + vec: NormalizedEmbedding = np.array([1.0, 0.0, 0.0], dtype=np.float32) + adapter.add_embedding("key", vec) + result = await adapter.get_embedding("key") + np.testing.assert_array_equal(result, vec) + # No embed() call needed + mock_embedder.embed.assert_not_called() - register_embedding_provider("_test_fwd", capture_factory) - try: - create_embedding_model("_test_fwd/text-embedding-3-small") - assert received == ["text-embedding-3-small"] - finally: - _embedding_providers.pop("_test_fwd", None) + +@pytest.mark.asyncio +async def test_embedding_adapter_empty_batch() -> None: + """Empty batch returns empty array.""" + from unittest.mock import AsyncMock + + from pydantic_ai import Embedder + + mock_embedder = AsyncMock(spec=Embedder) + adapter = PydanticAIEmbeddingModel(mock_embedder, "test-model", 4) + result = await adapter.get_embeddings_nocache([]) + assert result.shape == (0, 4) # --------------------------------------------------------------------------- @@ -137,16 +220,10 @@ def capture_factory(model_name: str) -> IEmbeddingModel: # --------------------------------------------------------------------------- -def test_configure_models() -> None: - chat_instance = FakeChatModel() - embed_instance = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) - - 
register_chat_provider("_test_cm", lambda name: chat_instance) - register_embedding_provider("_test_cm", lambda name: embed_instance) - try: - chat, embedder = configure_models("_test_cm/chat", "_test_cm/embed") - assert chat is chat_instance - assert embedder is embed_instance - finally: - _chat_providers.pop("_test_cm", None) - _embedding_providers.pop("_test_cm", None) +def test_configure_models_returns_correct_types() -> None: + """configure_models creates both adapters.""" + chat, embedder = configure_models("openai:gpt-4o", "openai:text-embedding-3-small") + assert isinstance(chat, PydanticAIChatModel) + assert isinstance(embedder, PydanticAIEmbeddingModel) + assert isinstance(embedder, IEmbeddingModel) + assert embedder.embedding_size == 1536 From 1015737b8eb6242609cc2f06350b3834d45e6617 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 18 Feb 2026 08:21:32 -0800 Subject: [PATCH 04/49] Don't hardcode an incomplete table of embedding sizes --- src/typeagent/aitools/model_registry.py | 56 +++++++++------------- tests/test_model_registry.py | 62 ++++++++++++++++++------- 2 files changed, 67 insertions(+), 51 deletions(-) diff --git a/src/typeagent/aitools/model_registry.py b/src/typeagent/aitools/model_registry.py index 3b02e461..b83db9a7 100644 --- a/src/typeagent/aitools/model_registry.py +++ b/src/typeagent/aitools/model_registry.py @@ -38,29 +38,12 @@ from .embeddings import IEmbeddingModel, NormalizedEmbedding, NormalizedEmbeddings -# --------------------------------------------------------------------------- -# Known embedding sizes for common models -# --------------------------------------------------------------------------- - -_KNOWN_EMBEDDING_SIZES: dict[str, int] = { - "text-embedding-ada-002": 1536, - "text-embedding-3-small": 1536, - "text-embedding-3-large": 3072, - "embed-english-v3.0": 1024, - "embed-multilingual-v3.0": 1024, - "embed-english-light-v3.0": 384, - "embed-multilingual-light-v3.0": 384, - "text-embedding-004": 768, - "embedding-001": 768, -} - - # --------------------------------------------------------------------------- # Chat model adapter # --------------------------------------------------------------------------- -class PydanticAIChatModel: +class PydanticAIChatModel(typechat.TypeChatLanguageModel): """Adapter from :class:`pydantic_ai.models.Model` to TypeChat's :class:`~typechat.TypeChatLanguageModel`. @@ -99,13 +82,16 @@ async def complete( # --------------------------------------------------------------------------- -class PydanticAIEmbeddingModel: +class PydanticAIEmbeddingModel(IEmbeddingModel): """Adapter from :class:`pydantic_ai.Embedder` to :class:`IEmbeddingModel`. This lets any pydantic_ai embedding provider (OpenAI, Cohere, Google, …) be used wherever the codebase expects an ``IEmbeddingModel``, including :class:`~typeagent.aitools.vectorbase.VectorBase` and :class:`~typeagent.knowpro.convsettings.ConversationSettings`. + + If *embedding_size* is not given, it is probed automatically by making a + single embedding call. 
""" model_name: str @@ -115,7 +101,7 @@ def __init__( self, embedder: _PydanticAIEmbedder, model_name: str, - embedding_size: int, + embedding_size: int = 0, ) -> None: self._embedder = embedder self.model_name = model_name @@ -125,11 +111,18 @@ def __init__( def add_embedding(self, key: str, embedding: NormalizedEmbedding) -> None: self._cache[key] = embedding + async def _probe_embedding_size(self) -> None: + """Discover embedding_size by making a single API call.""" + result = await self._embedder.embed(["probe"], input_type="document") + self.embedding_size = len(result.embeddings[0]) + async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding: result = await self._embedder.embed([input], input_type="document") embedding: NDArray[np.float32] = np.array( result.embeddings[0], dtype=np.float32 ) + if self.embedding_size == 0: + self.embedding_size = len(embedding) norm = float(np.linalg.norm(embedding)) if norm > 0: embedding = (embedding / norm).astype(np.float32) @@ -137,9 +130,13 @@ async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding: async def get_embeddings_nocache(self, input: list[str]) -> NormalizedEmbeddings: if not input: + if self.embedding_size == 0: + await self._probe_embedding_size() return np.empty((0, self.embedding_size), dtype=np.float32) result = await self._embedder.embed(input, input_type="document") embeddings: NDArray[np.float32] = np.array(result.embeddings, dtype=np.float32) + if self.embedding_size == 0: + self.embedding_size = embeddings.shape[1] norms = np.linalg.norm(embeddings, axis=1, keepdims=True).astype(np.float32) norms = np.where(norms > 0, norms, np.float32(1.0)) embeddings = (embeddings / norms).astype(np.float32) @@ -187,14 +184,14 @@ def create_chat_model( def create_embedding_model( model_spec: str, *, - embedding_size: int | None = None, -) -> IEmbeddingModel: + embedding_size: int = 0, +) -> PydanticAIEmbeddingModel: """Create an embedding model from a ``provider:model`` spec. Delegates to :class:`pydantic_ai.Embedder` for provider wiring. - If *embedding_size* is not given, it is looked up in a table of common - models. For unknown models, pass it explicitly. + If *embedding_size* is not given, it will be probed automatically + on the first embedding call. Examples:: @@ -203,13 +200,6 @@ def create_embedding_model( model = create_embedding_model("google:text-embedding-004") """ model_name = model_spec.split(":")[-1] if ":" in model_spec else model_spec - if embedding_size is None: - embedding_size = _KNOWN_EMBEDDING_SIZES.get(model_name) - if embedding_size is None: - raise ValueError( - f"Unknown embedding size for model {model_name!r}. " - f"Pass embedding_size= explicitly." - ) embedder = _PydanticAIEmbedder(model_spec) return PydanticAIEmbeddingModel(embedder, model_name, embedding_size) @@ -218,8 +208,8 @@ def configure_models( chat_model_spec: str, embedding_model_spec: str, *, - embedding_size: int | None = None, -) -> tuple[typechat.TypeChatLanguageModel, IEmbeddingModel]: + embedding_size: int = 0, +) -> tuple[PydanticAIChatModel, PydanticAIEmbeddingModel]: """Configure both a chat model and an embedding model at once. Delegates to pydantic_ai's model registry for provider wiring. 
diff --git a/tests/test_model_registry.py b/tests/test_model_registry.py index d18e071d..d6bcf5ae 100644 --- a/tests/test_model_registry.py +++ b/tests/test_model_registry.py @@ -8,7 +8,6 @@ from typeagent.aitools.embeddings import IEmbeddingModel, NormalizedEmbedding from typeagent.aitools.model_registry import ( - _KNOWN_EMBEDDING_SIZES, configure_models, create_chat_model, create_embedding_model, @@ -29,35 +28,34 @@ def test_spec_uses_colon_separator() -> None: # --------------------------------------------------------------------------- -# Known embedding sizes +# Embedding size # --------------------------------------------------------------------------- -def test_known_embedding_sizes() -> None: - assert _KNOWN_EMBEDDING_SIZES["text-embedding-3-small"] == 1536 - assert _KNOWN_EMBEDDING_SIZES["text-embedding-3-large"] == 3072 - assert _KNOWN_EMBEDDING_SIZES["text-embedding-ada-002"] == 1536 - - -def test_unknown_embedding_size_raises() -> None: - with pytest.raises(ValueError, match="Unknown embedding size"): - create_embedding_model("openai:completely-unknown-model-xyz") - - def test_explicit_embedding_size() -> None: - """Passing embedding_size= bypasses the lookup table.""" - # This should not raise even though the model name is unknown + """Passing embedding_size= sets it immediately.""" model = create_embedding_model( - "openai:completely-unknown-model-xyz", embedding_size=42 + "openai:text-embedding-3-small", embedding_size=42 ) assert model.embedding_size == 42 +def test_default_embedding_size_is_zero() -> None: + """Without embedding_size=, it defaults to 0 (probed on first call).""" + model = create_embedding_model("openai:text-embedding-3-small") + assert model.embedding_size == 0 + + # --------------------------------------------------------------------------- # PydanticAIChatModel adapter # --------------------------------------------------------------------------- +def test_chat_model_is_typechat_model() -> None: + """PydanticAIChatModel inherits from TypeChatLanguageModel.""" + assert issubclass(PydanticAIChatModel, typechat.TypeChatLanguageModel) + + @pytest.mark.asyncio async def test_chat_adapter_complete() -> None: """PydanticAIChatModel wraps a pydantic_ai Model.""" @@ -113,6 +111,11 @@ async def test_chat_adapter_prompt_sections() -> None: # --------------------------------------------------------------------------- +def test_embedding_model_is_iembedding_model() -> None: + """PydanticAIEmbeddingModel inherits from IEmbeddingModel.""" + assert issubclass(PydanticAIEmbeddingModel, IEmbeddingModel) + + @pytest.mark.asyncio async def test_embedding_adapter_single() -> None: """PydanticAIEmbeddingModel computes a single normalized embedding.""" @@ -138,6 +141,29 @@ async def test_embedding_adapter_single() -> None: assert abs(norm - 1.0) < 1e-6 +@pytest.mark.asyncio +async def test_embedding_adapter_probes_size() -> None: + """embedding_size is discovered from the first embedding call.""" + from unittest.mock import AsyncMock + + from pydantic_ai import Embedder + from pydantic_ai.embeddings import EmbeddingResult + + mock_embedder = AsyncMock(spec=Embedder) + mock_embedder.embed.return_value = EmbeddingResult( + embeddings=[[1.0, 0.0, 0.0]], + inputs=["probe"], + input_type="document", + model_name="test-model", + provider_name="test", + ) + + adapter = PydanticAIEmbeddingModel(mock_embedder, "test-model") + assert adapter.embedding_size == 0 + await adapter.get_embedding_nocache("probe") + assert adapter.embedding_size == 3 + + @pytest.mark.asyncio async def 
test_embedding_adapter_batch() -> None: """PydanticAIEmbeddingModel computes batch embeddings.""" @@ -204,7 +230,7 @@ async def test_embedding_adapter_add_embedding() -> None: @pytest.mark.asyncio async def test_embedding_adapter_empty_batch() -> None: - """Empty batch returns empty array.""" + """Empty batch returns empty array with known size.""" from unittest.mock import AsyncMock from pydantic_ai import Embedder @@ -226,4 +252,4 @@ def test_configure_models_returns_correct_types() -> None: assert isinstance(chat, PydanticAIChatModel) assert isinstance(embedder, PydanticAIEmbeddingModel) assert isinstance(embedder, IEmbeddingModel) - assert embedder.embedding_size == 1536 + assert isinstance(chat, typechat.TypeChatLanguageModel) From 4bd1387fd884831b4b2dfd0a363ac3471d23cf01 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 18 Feb 2026 08:36:03 -0800 Subject: [PATCH 05/49] Fix test failures --- src/typeagent/aitools/model_registry.py | 2 +- tests/test_model_registry.py | 11 ++++------- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/typeagent/aitools/model_registry.py b/src/typeagent/aitools/model_registry.py index b83db9a7..fcb55480 100644 --- a/src/typeagent/aitools/model_registry.py +++ b/src/typeagent/aitools/model_registry.py @@ -166,7 +166,7 @@ async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings: def create_chat_model( model_spec: str, -) -> typechat.TypeChatLanguageModel: +) -> PydanticAIChatModel: """Create a chat model from a ``provider:model`` spec. Delegates to :func:`pydantic_ai.models.infer_model` for provider wiring. diff --git a/tests/test_model_registry.py b/tests/test_model_registry.py index d6bcf5ae..73217f7a 100644 --- a/tests/test_model_registry.py +++ b/tests/test_model_registry.py @@ -34,9 +34,7 @@ def test_spec_uses_colon_separator() -> None: def test_explicit_embedding_size() -> None: """Passing embedding_size= sets it immediately.""" - model = create_embedding_model( - "openai:text-embedding-3-small", embedding_size=42 - ) + model = create_embedding_model("openai:text-embedding-3-small", embedding_size=42) assert model.embedding_size == 42 @@ -53,7 +51,7 @@ def test_default_embedding_size_is_zero() -> None: def test_chat_model_is_typechat_model() -> None: """PydanticAIChatModel inherits from TypeChatLanguageModel.""" - assert issubclass(PydanticAIChatModel, typechat.TypeChatLanguageModel) + assert typechat.TypeChatLanguageModel in PydanticAIChatModel.__mro__ @pytest.mark.asyncio @@ -113,7 +111,7 @@ async def test_chat_adapter_prompt_sections() -> None: def test_embedding_model_is_iembedding_model() -> None: """PydanticAIEmbeddingModel inherits from IEmbeddingModel.""" - assert issubclass(PydanticAIEmbeddingModel, IEmbeddingModel) + assert IEmbeddingModel in PydanticAIEmbeddingModel.__mro__ @pytest.mark.asyncio @@ -251,5 +249,4 @@ def test_configure_models_returns_correct_types() -> None: chat, embedder = configure_models("openai:gpt-4o", "openai:text-embedding-3-small") assert isinstance(chat, PydanticAIChatModel) assert isinstance(embedder, PydanticAIEmbeddingModel) - assert isinstance(embedder, IEmbeddingModel) - assert isinstance(chat, typechat.TypeChatLanguageModel) + assert typechat.TypeChatLanguageModel in type(chat).__mro__ From 60aa4031902276fa9678bd109debd94e03b6818c Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 18 Feb 2026 08:41:11 -0800 Subject: [PATCH 06/49] Rename model_registry -> model_adapters --- .../aitools/{model_registry.py => model_adapters.py} | 2 +- src/typeagent/aitools/utils.py | 
4 ++-- tests/{test_model_registry.py => test_model_adapters.py} | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) rename src/typeagent/aitools/{model_registry.py => model_adapters.py} (99%) rename tests/{test_model_registry.py => test_model_adapters.py} (99%) diff --git a/src/typeagent/aitools/model_registry.py b/src/typeagent/aitools/model_adapters.py similarity index 99% rename from src/typeagent/aitools/model_registry.py rename to src/typeagent/aitools/model_adapters.py index fcb55480..6d8f06a9 100644 --- a/src/typeagent/aitools/model_registry.py +++ b/src/typeagent/aitools/model_adapters.py @@ -5,7 +5,7 @@ Create chat and embedding models from ``provider:model`` spec strings:: - from typeagent.aitools.model_registry import configure_models + from typeagent.aitools.model_adapters import configure_models chat, embedder = configure_models( "openai:gpt-4o", diff --git a/src/typeagent/aitools/utils.py b/src/typeagent/aitools/utils.py index b475c728..9d57b854 100644 --- a/src/typeagent/aitools/utils.py +++ b/src/typeagent/aitools/utils.py @@ -128,8 +128,8 @@ def create_typechat_model() -> typechat.TypeChatLanguageModel: Auto-detects the provider from ``OPENAI_API_KEY`` / ``AZURE_OPENAI_API_KEY`` environment variables. - For explicit provider selection, use :func:`model_registry.create_chat_model` - with a spec string like ``"openai/gpt-4o"`` or ``"azure/my-deployment"``. + For explicit provider selection, use :func:`model_adapters.create_chat_model` + with a spec string like ``"openai:gpt-4o"`` or ``"azure:my-deployment"``. """ env: dict[str, str | None] = dict(os.environ) key_name = "AZURE_OPENAI_API_KEY" diff --git a/tests/test_model_registry.py b/tests/test_model_adapters.py similarity index 99% rename from tests/test_model_registry.py rename to tests/test_model_adapters.py index 73217f7a..33bd78c7 100644 --- a/tests/test_model_registry.py +++ b/tests/test_model_adapters.py @@ -7,7 +7,7 @@ import typechat from typeagent.aitools.embeddings import IEmbeddingModel, NormalizedEmbedding -from typeagent.aitools.model_registry import ( +from typeagent.aitools.model_adapters import ( configure_models, create_chat_model, create_embedding_model, From 067f3b92fd2c0e7ef8bd8aaee701498da58de265 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 24 Feb 2026 13:25:38 -0800 Subject: [PATCH 07/49] Move pydantic-ai to main deps --- pyproject.toml | 2 +- uv.lock | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 848a497d..94aca756 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ "numpy>=2.2.6", "openai>=1.81.0", "pydantic>=2.11.4", + "pydantic-ai-slim[openai]>=1.39.0", "pyreadline3>=3.5.4 ; sys_platform == 'win32'", "python-dotenv>=1.1.0", "tiktoken>=0.12.0", @@ -87,7 +88,6 @@ dev = [ "isort>=7.0.0", "logfire>=4.1.0", # So 'make check' passes "opentelemetry-instrumentation-httpx>=0.57b0", - "pydantic-ai-slim[openai]>=1.39.0", "pyright>=1.1.408", # 407 has a regression "pytest>=8.3.5", "pytest-asyncio>=0.26.0", diff --git a/uv.lock b/uv.lock index ce23abdc..a901acd1 100644 --- a/uv.lock +++ b/uv.lock @@ -1821,6 +1821,7 @@ dependencies = [ { name = "numpy" }, { name = "openai" }, { name = "pydantic" }, + { name = "pydantic-ai-slim", extra = ["openai"] }, { name = "pyreadline3", marker = "sys_platform == 'win32'" }, { name = "python-dotenv" }, { name = "tiktoken" }, @@ -1845,7 +1846,6 @@ dev = [ { name = "isort" }, { name = "logfire" }, { name = "opentelemetry-instrumentation-httpx" }, - { name = 
"pydantic-ai-slim", extra = ["openai"] }, { name = "pyright" }, { name = "pytest" }, { name = "pytest-asyncio" }, @@ -1863,6 +1863,7 @@ requires-dist = [ { name = "openai", specifier = ">=1.81.0" }, { name = "opentelemetry-instrumentation-httpx", marker = "extra == 'logfire'", specifier = ">=0.57b0" }, { name = "pydantic", specifier = ">=2.11.4" }, + { name = "pydantic-ai-slim", extras = ["openai"], specifier = ">=1.39.0" }, { name = "pyreadline3", marker = "sys_platform == 'win32'", specifier = ">=3.5.4" }, { name = "python-dotenv", specifier = ">=1.1.0" }, { name = "tiktoken", specifier = ">=0.12.0" }, @@ -1882,7 +1883,6 @@ dev = [ { name = "isort", specifier = ">=7.0.0" }, { name = "logfire", specifier = ">=4.1.0" }, { name = "opentelemetry-instrumentation-httpx", specifier = ">=0.57b0" }, - { name = "pydantic-ai-slim", extras = ["openai"], specifier = ">=1.39.0" }, { name = "pyright", specifier = ">=1.1.408" }, { name = "pytest", specifier = ">=8.3.5" }, { name = "pytest-asyncio", specifier = ">=0.26.0" }, From 17b959fbffea9fdd9d679dc489e9b822993ed9b2 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 24 Feb 2026 16:12:24 -0800 Subject: [PATCH 08/49] Remove obsolete create_embedding_model -- wasn't easy --- src/typeagent/aitools/embeddings.py | 26 --------------------- src/typeagent/aitools/model_adapters.py | 8 ++++++- src/typeagent/aitools/vectorbase.py | 29 ++++++++++++++++++++---- src/typeagent/knowpro/convsettings.py | 3 ++- src/typeagent/knowpro/fuzzyindex.py | 6 ++--- src/typeagent/knowpro/serialization.py | 15 ++++++++++++ src/typeagent/podcasts/podcast.py | 15 +++++++----- src/typeagent/storage/sqlite/provider.py | 9 +++++--- src/typeagent/transcripts/transcript.py | 15 +++++++----- tests/test_serialization.py | 2 +- tools/ingest_vtt.py | 7 ++++-- 11 files changed, 82 insertions(+), 53 deletions(-) diff --git a/src/typeagent/aitools/embeddings.py b/src/typeagent/aitools/embeddings.py index ce8ea062..5aacaafa 100644 --- a/src/typeagent/aitools/embeddings.py +++ b/src/typeagent/aitools/embeddings.py @@ -350,29 +350,3 @@ async def truncate_input(self, input: str) -> tuple[str, int]: return self.encoding.decode(truncated_tokens), self.max_chunk_size else: return input, len(tokens) - - -def create_embedding_model( - embedding_size: int | None = None, - model_name: str | None = None, - **kwargs, -) -> IEmbeddingModel: - """Create an embedding model using OpenAI/Azure OpenAI. - - This is the default factory. To use a different provider, create an - instance of a class that implements ``IEmbeddingModel`` and pass it - directly to ``TextEmbeddingIndexSettings`` or ``ConversationSettings``. - - Args: - embedding_size: Requested embedding dimensionality (provider-specific). - model_name: Model identifier (e.g. "text-embedding-ada-002"). - **kwargs: Extra keyword arguments forwarded to ``AsyncEmbeddingModel``. - - Returns: - An ``IEmbeddingModel`` instance backed by OpenAI / Azure OpenAI. 
- """ - return AsyncEmbeddingModel( - embedding_size=embedding_size, - model_name=model_name, - **kwargs, - ) diff --git a/src/typeagent/aitools/model_adapters.py b/src/typeagent/aitools/model_adapters.py index 6d8f06a9..c558111b 100644 --- a/src/typeagent/aitools/model_adapters.py +++ b/src/typeagent/aitools/model_adapters.py @@ -181,8 +181,11 @@ def create_chat_model( return PydanticAIChatModel(model) +DEFAULT_EMBEDDING_SPEC = "openai:text-embedding-3-small" + + def create_embedding_model( - model_spec: str, + model_spec: str | None = None, *, embedding_size: int = 0, ) -> PydanticAIEmbeddingModel: @@ -190,6 +193,7 @@ def create_embedding_model( Delegates to :class:`pydantic_ai.Embedder` for provider wiring. + If *model_spec* is ``None``, :data:`DEFAULT_EMBEDDING_SPEC` is used. If *embedding_size* is not given, it will be probed automatically on the first embedding call. @@ -199,6 +203,8 @@ def create_embedding_model( model = create_embedding_model("cohere:embed-english-v3.0") model = create_embedding_model("google:text-embedding-004") """ + if model_spec is None: + model_spec = DEFAULT_EMBEDDING_SPEC model_name = model_spec.split(":")[-1] if ":" in model_spec else model_spec embedder = _PydanticAIEmbedder(model_spec) return PydanticAIEmbeddingModel(embedder, model_name, embedding_size) diff --git a/src/typeagent/aitools/vectorbase.py b/src/typeagent/aitools/vectorbase.py index 0a52f838..df93997f 100644 --- a/src/typeagent/aitools/vectorbase.py +++ b/src/typeagent/aitools/vectorbase.py @@ -7,11 +7,11 @@ import numpy as np from .embeddings import ( - create_embedding_model, IEmbeddingModel, NormalizedEmbedding, NormalizedEmbeddings, ) +from .model_adapters import create_embedding_model DEFAULT_MAX_RETRIES = 2 @@ -47,7 +47,7 @@ def __init__( max_retries if max_retries is not None else DEFAULT_MAX_RETRIES ) self.embedding_model = embedding_model or create_embedding_model( - embedding_size, max_retries=self.max_retries + embedding_size=embedding_size or 0, ) self.embedding_size = self.embedding_model.embedding_size assert ( @@ -93,6 +93,9 @@ def add_embedding( ) -> None: if isinstance(embedding, list): embedding = np.array(embedding, dtype=np.float32) + if self._embedding_size == 0: + self._set_embedding_size(len(embedding)) + self._vectors.shape = (0, self._embedding_size) embeddings = embedding.reshape(1, -1) # Make it 2D: 1xN self._vectors = np.append(self._vectors, embeddings, axis=0) if key is not None: @@ -102,6 +105,9 @@ def add_embeddings( self, keys: None | list[str], embeddings: NormalizedEmbeddings ) -> None: assert embeddings.ndim == 2 + if self._embedding_size == 0: + self._set_embedding_size(embeddings.shape[1]) + self._vectors.shape = (0, self._embedding_size) assert embeddings.shape[1] == self._embedding_size self._vectors = np.concatenate((self._vectors, embeddings), axis=0) if keys is not None: @@ -165,9 +171,17 @@ async def fuzzy_lookup( embedding, max_hits=max_hits, min_score=min_score, predicate=predicate ) + def _set_embedding_size(self, size: int) -> None: + """Adopt *size* when it was not known at construction time.""" + assert size > 0 + self._embedding_size = size + self._model.embedding_size = size + self.settings.embedding_size = size + def clear(self) -> None: self._vectors = np.array([], dtype=np.float32) - self._vectors.shape = (0, self._embedding_size) + if self._embedding_size > 0: + self._vectors.shape = (0, self._embedding_size) def get_embedding_at(self, pos: int) -> NormalizedEmbedding: if 0 <= pos < len(self._vectors): @@ -180,13 +194,20 @@ def 
serialize_embedding_at(self, pos: int) -> NormalizedEmbedding | None: return self._vectors[pos] if 0 <= pos < len(self._vectors) else None def serialize(self) -> NormalizedEmbeddings: - assert self._vectors.shape == (len(self._vectors), self._embedding_size) + if self._embedding_size > 0: + assert self._vectors.shape == (len(self._vectors), self._embedding_size) return self._vectors # TODO: Should we make a copy? def deserialize(self, data: NormalizedEmbeddings | None) -> None: if data is None: self.clear() return + if self._embedding_size == 0: + if data.ndim < 2 or data.shape[0] == 0: + # Empty data — can't determine size; just clear. + self.clear() + return + self._set_embedding_size(data.shape[1]) assert data.shape == (len(data), self._embedding_size), [ data.shape, self._embedding_size, diff --git a/src/typeagent/knowpro/convsettings.py b/src/typeagent/knowpro/convsettings.py index 7e25cd93..9dbf1214 100644 --- a/src/typeagent/knowpro/convsettings.py +++ b/src/typeagent/knowpro/convsettings.py @@ -5,7 +5,8 @@ from dataclasses import dataclass -from ..aitools.embeddings import create_embedding_model, IEmbeddingModel +from ..aitools.embeddings import IEmbeddingModel +from ..aitools.model_adapters import create_embedding_model from ..aitools.vectorbase import TextEmbeddingIndexSettings from .interfaces import IKnowledgeExtractor, IStorageProvider diff --git a/src/typeagent/knowpro/fuzzyindex.py b/src/typeagent/knowpro/fuzzyindex.py index 6ace1b34..97138e6c 100644 --- a/src/typeagent/knowpro/fuzzyindex.py +++ b/src/typeagent/knowpro/fuzzyindex.py @@ -137,7 +137,7 @@ def deserialize(self, embeddings: NormalizedEmbedding) -> None: assert embeddings.dtype == np.float32, embeddings.dtype assert embeddings.ndim == 2, embeddings.shape assert ( - embeddings.shape[1] == self._vector_base._embedding_size + self._vector_base._embedding_size == 0 + or embeddings.shape[1] == self._vector_base._embedding_size ), embeddings.shape - self.clear() - self.push(embeddings) + self._vector_base.deserialize(embeddings) diff --git a/src/typeagent/knowpro/serialization.py b/src/typeagent/knowpro/serialization.py index 1e48b68c..cbbe7b71 100644 --- a/src/typeagent/knowpro/serialization.py +++ b/src/typeagent/knowpro/serialization.py @@ -46,9 +46,14 @@ def create_file_header() -> FileHeader: return FileHeader(version="0.1") +class ModelMetadata(TypedDict): + embeddingSize: int + + class EmbeddingFileHeader(TypedDict): relatedCount: NotRequired[int | None] messageCount: NotRequired[int | None] + modelMetadata: NotRequired[ModelMetadata | None] class EmbeddingData(TypedDict): @@ -104,6 +109,7 @@ def to_conversation_file_data[TMessageData]( embedding_file_header = EmbeddingFileHeader() embeddings_list: list[NormalizedEmbeddings] = [] + embedding_size = 0 related_terms_index_data = conversation_data.get("relatedTermsIndexData") if related_terms_index_data is not None: @@ -114,6 +120,8 @@ def to_conversation_file_data[TMessageData]( embeddings_list.append(embeddings) text_embedding_data["embeddings"] = None embedding_file_header["relatedCount"] = len(embeddings) + if embedding_size == 0 and embeddings.ndim == 2: + embedding_size = embeddings.shape[1] message_index_data = conversation_data.get("messageIndexData") if message_index_data is not None: @@ -124,6 +132,13 @@ def to_conversation_file_data[TMessageData]( embeddings_list.append(embeddings) text_embedding_data["embeddings"] = None embedding_file_header["messageCount"] = len(embeddings) + if embedding_size == 0 and embeddings.ndim == 2: + embedding_size = 
embeddings.shape[1] + + if embedding_size > 0: + embedding_file_header["modelMetadata"] = ModelMetadata( + embeddingSize=embedding_size + ) binary_data = ConversationBinaryData(embeddingsList=embeddings_list) json_data = ConversationJsonData( diff --git a/src/typeagent/podcasts/podcast.py b/src/typeagent/podcasts/podcast.py index 3ed2639a..5376d20e 100644 --- a/src/typeagent/podcasts/podcast.py +++ b/src/typeagent/podcasts/podcast.py @@ -143,13 +143,19 @@ async def deserialize( @staticmethod def _read_conversation_data_from_file( - filename_prefix: str, embedding_size: int + filename_prefix: str, ) -> ConversationDataWithIndexes[Any]: """Read podcast conversation data from files. No exceptions are caught; they just bubble out.""" with open(filename_prefix + "_data.json", "r", encoding="utf-8") as f: json_data: serialization.ConversationJsonData[PodcastMessageData] = ( json.load(f) ) + embedding_file_header = json_data.get("embeddingFileHeader") + embedding_size = 0 + if embedding_file_header: + model_metadata = embedding_file_header.get("modelMetadata") + if model_metadata: + embedding_size = model_metadata.get("embeddingSize", 0) embeddings_list: list[NormalizedEmbeddings] | None = None if embedding_size: with open(filename_prefix + "_embeddings.bin", "rb") as f: @@ -159,7 +165,7 @@ def _read_conversation_data_from_file( embeddings_list = [embeddings] else: print( - "Warning: not reading embeddings file because size is {embedding_size}" + f"Warning: not reading embeddings file because size is {embedding_size}" ) embeddings_list = None file_data = serialization.ConversationFileData( @@ -178,10 +184,7 @@ async def read_from_file( settings: ConversationSettings, dbname: str | None = None, ) -> "Podcast": - embedding_size = settings.embedding_model.embedding_size - data = Podcast._read_conversation_data_from_file( - filename_prefix, embedding_size - ) + data = Podcast._read_conversation_data_from_file(filename_prefix) provider = await settings.get_storage_provider() msgs = await provider.get_message_collection() diff --git a/src/typeagent/storage/sqlite/provider.py b/src/typeagent/storage/sqlite/provider.py index 975d6a70..515c9cf0 100644 --- a/src/typeagent/storage/sqlite/provider.py +++ b/src/typeagent/storage/sqlite/provider.py @@ -6,7 +6,7 @@ from datetime import datetime, timezone import sqlite3 -from ...aitools.embeddings import create_embedding_model +from ...aitools.model_adapters import create_embedding_model from ...aitools.vectorbase import TextEmbeddingIndexSettings from ...knowpro import interfaces from ...knowpro.convsettings import MessageTextIndexSettings, RelatedTermIndexSettings @@ -125,9 +125,12 @@ def _resolve_embedding_settings( if provided_message_settings is None: if stored_size is not None or stored_name is not None: + spec = stored_name or "" + if spec and ":" not in spec: + spec = f"openai:{spec}" embedding_model = create_embedding_model( - embedding_size=stored_size, - model_name=stored_name, + spec, + embedding_size=stored_size or 0, ) base_embedding_settings = TextEmbeddingIndexSettings( embedding_model=embedding_model, diff --git a/src/typeagent/transcripts/transcript.py b/src/typeagent/transcripts/transcript.py index 494166ba..5033e293 100644 --- a/src/typeagent/transcripts/transcript.py +++ b/src/typeagent/transcripts/transcript.py @@ -143,13 +143,19 @@ async def deserialize( @staticmethod def _read_conversation_data_from_file( - filename_prefix: str, embedding_size: int + filename_prefix: str, ) -> ConversationDataWithIndexes[Any]: """Read transcript 
conversation data from files. No exceptions are caught; they just bubble out.""" with open(filename_prefix + "_data.json", "r", encoding="utf-8") as f: json_data: serialization.ConversationJsonData[TranscriptMessageData] = ( json.load(f) ) + embedding_file_header = json_data.get("embeddingFileHeader") + embedding_size = 0 + if embedding_file_header: + model_metadata = embedding_file_header.get("modelMetadata") + if model_metadata: + embedding_size = model_metadata.get("embeddingSize", 0) embeddings_list: list[NormalizedEmbeddings] | None = None if embedding_size: with open(filename_prefix + "_embeddings.bin", "rb") as f: @@ -159,7 +165,7 @@ def _read_conversation_data_from_file( embeddings_list = [embeddings] else: print( - "Warning: not reading embeddings file because size is {embedding_size}" + f"Warning: not reading embeddings file because size is {embedding_size}" ) embeddings_list = None file_data = serialization.ConversationFileData( @@ -178,10 +184,7 @@ async def read_from_file( settings: ConversationSettings, dbname: str | None = None, ) -> "Transcript": - embedding_size = settings.embedding_model.embedding_size - data = Transcript._read_conversation_data_from_file( - filename_prefix, embedding_size - ) + data = Transcript._read_conversation_data_from_file(filename_prefix) provider = await settings.get_storage_provider() msgs = await provider.get_message_collection() diff --git a/tests/test_serialization.py b/tests/test_serialization.py index 92aa71d9..adb46dd3 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -113,7 +113,7 @@ def test_write_and_read_conversation_data( # Read back the data read_data = Podcast._read_conversation_data_from_file( - str(filename), embedding_size=2 + str(filename), ) assert read_data is not None assert read_data.get("relatedTermsIndexData") is not None diff --git a/tools/ingest_vtt.py b/tools/ingest_vtt.py index 8bcef6d7..7fbf38fc 100644 --- a/tools/ingest_vtt.py +++ b/tools/ingest_vtt.py @@ -24,7 +24,7 @@ from dotenv import load_dotenv import webvtt -from typeagent.aitools.embeddings import create_embedding_model +from typeagent.aitools.model_adapters import create_embedding_model from typeagent.knowpro.convsettings import ConversationSettings from typeagent.knowpro.interfaces import ConversationMetadata from typeagent.knowpro.universal_message import format_timestamp_utc, UNIX_EPOCH @@ -203,7 +203,10 @@ async def ingest_vtt_files( if verbose: print("Setting up conversation settings...") try: - embedding_model = create_embedding_model(model_name=embedding_name) + spec = embedding_name + if spec and ":" not in spec: + spec = f"openai:{spec}" + embedding_model = create_embedding_model(spec) settings = ConversationSettings(embedding_model) # Create metadata with the conversation name From 6f1286ffab1a07f31d072d0782ffc66878b92196 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 24 Feb 2026 19:01:06 -0800 Subject: [PATCH 09/49] Fix test_configure_models_returns_correct_types --- tests/test_model_adapters.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_model_adapters.py b/tests/test_model_adapters.py index 33bd78c7..f4c910bf 100644 --- a/tests/test_model_adapters.py +++ b/tests/test_model_adapters.py @@ -244,8 +244,9 @@ async def test_embedding_adapter_empty_batch() -> None: # --------------------------------------------------------------------------- -def test_configure_models_returns_correct_types() -> None: +def test_configure_models_returns_correct_types(monkeypatch: 
pytest.MonkeyPatch) -> None: """configure_models creates both adapters.""" + monkeypatch.setenv("OPENAI_API_KEY", "test-key") chat, embedder = configure_models("openai:gpt-4o", "openai:text-embedding-3-small") assert isinstance(chat, PydanticAIChatModel) assert isinstance(embedder, PydanticAIEmbeddingModel) From 83d6f0a0b6b9dae4b0947cc417633a3a5dee0f4d Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Tue, 24 Feb 2026 19:15:32 -0800 Subject: [PATCH 10/49] Fall back on Azure for OpenAI models if only Azure key is present --- src/typeagent/aitools/model_adapters.py | 62 +++++++++++++++++++++++-- tests/test_model_adapters.py | 4 +- 2 files changed, 62 insertions(+), 4 deletions(-) diff --git a/src/typeagent/aitools/model_adapters.py b/src/typeagent/aitools/model_adapters.py index c558111b..fea2ed57 100644 --- a/src/typeagent/aitools/model_adapters.py +++ b/src/typeagent/aitools/model_adapters.py @@ -18,10 +18,16 @@ ``azure``, ``anthropic``, ``google``, ``bedrock``, ``groq``, ``mistral``, ``ollama``, ``cohere``, and many more. +When a spec uses ``openai:`` as the provider and ``OPENAI_API_KEY`` is not +set, but ``AZURE_OPENAI_API_KEY`` is available, the provider is +automatically switched to Azure OpenAI. + See https://ai.pydantic.dev/models/ for all supported providers and their required environment variables. """ +import os + import numpy as np from numpy.typing import NDArray @@ -159,6 +165,36 @@ async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings: return np.array([self._cache[k] for k in keys], dtype=np.float32) +# --------------------------------------------------------------------------- +# Provider auto-detection +# --------------------------------------------------------------------------- + + +def _needs_azure_fallback(provider: str) -> bool: + """Return True if *provider* is ``openai`` but only Azure credentials exist.""" + return ( + provider == "openai" + and not os.getenv("OPENAI_API_KEY") + and bool(os.getenv("AZURE_OPENAI_API_KEY")) + ) + + +def _make_azure_provider(): + """Create a :class:`pydantic_ai.providers.azure.AzureProvider`.""" + from pydantic_ai.providers.azure import AzureProvider + + from .utils import get_azure_api_key, parse_azure_endpoint + + raw_key = os.environ["AZURE_OPENAI_API_KEY"] + api_key = get_azure_api_key(raw_key) + azure_endpoint, api_version = parse_azure_endpoint("AZURE_OPENAI_ENDPOINT") + return AzureProvider( + azure_endpoint=azure_endpoint, + api_version=api_version, + api_key=api_key, + ) + + # --------------------------------------------------------------------------- # Public API # --------------------------------------------------------------------------- @@ -170,6 +206,8 @@ def create_chat_model( """Create a chat model from a ``provider:model`` spec. Delegates to :func:`pydantic_ai.models.infer_model` for provider wiring. + If the spec uses ``openai:`` and ``OPENAI_API_KEY`` is not set but + ``AZURE_OPENAI_API_KEY`` is, Azure OpenAI is used automatically. 
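+    With that fallback, the model part of the spec names the Azure
+    **deployment**, which may differ from the underlying model name.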
Examples:: @@ -177,7 +215,13 @@ def create_chat_model( model = create_chat_model("anthropic:claude-sonnet-4-20250514") model = create_chat_model("google:gemini-2.0-flash") """ - model = infer_model(model_spec) + provider, _, model_name = model_spec.partition(":") + if _needs_azure_fallback(provider): + from pydantic_ai.models.openai import OpenAIChatModel + + model = OpenAIChatModel(model_name, provider=_make_azure_provider()) + else: + model = infer_model(model_spec) return PydanticAIChatModel(model) @@ -192,6 +236,8 @@ def create_embedding_model( """Create an embedding model from a ``provider:model`` spec. Delegates to :class:`pydantic_ai.Embedder` for provider wiring. + If the spec uses ``openai:`` and ``OPENAI_API_KEY`` is not set but + ``AZURE_OPENAI_API_KEY`` is, Azure OpenAI is used automatically. If *model_spec* is ``None``, :data:`DEFAULT_EMBEDDING_SPEC` is used. If *embedding_size* is not given, it will be probed automatically @@ -205,8 +251,18 @@ def create_embedding_model( """ if model_spec is None: model_spec = DEFAULT_EMBEDDING_SPEC - model_name = model_spec.split(":")[-1] if ":" in model_spec else model_spec - embedder = _PydanticAIEmbedder(model_spec) + provider, _, model_name = model_spec.partition(":") + if not model_name: + model_name = provider # No colon in spec + if _needs_azure_fallback(provider): + from pydantic_ai.embeddings.openai import OpenAIEmbeddingModel + + embedding_model = OpenAIEmbeddingModel( + model_name, provider=_make_azure_provider() + ) + embedder = _PydanticAIEmbedder(embedding_model) + else: + embedder = _PydanticAIEmbedder(model_spec) return PydanticAIEmbeddingModel(embedder, model_name, embedding_size) diff --git a/tests/test_model_adapters.py b/tests/test_model_adapters.py index f4c910bf..998ddd97 100644 --- a/tests/test_model_adapters.py +++ b/tests/test_model_adapters.py @@ -244,7 +244,9 @@ async def test_embedding_adapter_empty_batch() -> None: # --------------------------------------------------------------------------- -def test_configure_models_returns_correct_types(monkeypatch: pytest.MonkeyPatch) -> None: +def test_configure_models_returns_correct_types( + monkeypatch: pytest.MonkeyPatch, +) -> None: """configure_models creates both adapters.""" monkeypatch.setenv("OPENAI_API_KEY", "test-key") chat, embedder = configure_models("openai:gpt-4o", "openai:text-embedding-3-small") From 2659f3015675808ed876adbbba7365fa34881efd Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 25 Feb 2026 09:47:51 -0800 Subject: [PATCH 11/49] Use embed_documents() instead of embed(input_type=["document"]) --- src/typeagent/aitools/model_adapters.py | 6 +++--- tests/test_model_adapters.py | 16 ++++++++-------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/src/typeagent/aitools/model_adapters.py b/src/typeagent/aitools/model_adapters.py index fea2ed57..472eae8e 100644 --- a/src/typeagent/aitools/model_adapters.py +++ b/src/typeagent/aitools/model_adapters.py @@ -119,11 +119,11 @@ def add_embedding(self, key: str, embedding: NormalizedEmbedding) -> None: async def _probe_embedding_size(self) -> None: """Discover embedding_size by making a single API call.""" - result = await self._embedder.embed(["probe"], input_type="document") + result = await self._embedder.embed_documents(["probe"]) self.embedding_size = len(result.embeddings[0]) async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding: - result = await self._embedder.embed([input], input_type="document") + result = await self._embedder.embed_documents([input]) 
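+        # embed_documents() is the typed equivalent of the earlier
+        # embed(..., input_type="document") call.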
embedding: NDArray[np.float32] = np.array( result.embeddings[0], dtype=np.float32 ) @@ -139,7 +139,7 @@ async def get_embeddings_nocache(self, input: list[str]) -> NormalizedEmbeddings if self.embedding_size == 0: await self._probe_embedding_size() return np.empty((0, self.embedding_size), dtype=np.float32) - result = await self._embedder.embed(input, input_type="document") + result = await self._embedder.embed_documents(input) embeddings: NDArray[np.float32] = np.array(result.embeddings, dtype=np.float32) if self.embedding_size == 0: self.embedding_size = embeddings.shape[1] diff --git a/tests/test_model_adapters.py b/tests/test_model_adapters.py index 998ddd97..5e773430 100644 --- a/tests/test_model_adapters.py +++ b/tests/test_model_adapters.py @@ -124,7 +124,7 @@ async def test_embedding_adapter_single() -> None: mock_embedder = AsyncMock(spec=Embedder) raw_vec = [3.0, 4.0, 0.0] - mock_embedder.embed.return_value = EmbeddingResult( + mock_embedder.embed_documents.return_value = EmbeddingResult( embeddings=[raw_vec], inputs=["test"], input_type="document", @@ -148,7 +148,7 @@ async def test_embedding_adapter_probes_size() -> None: from pydantic_ai.embeddings import EmbeddingResult mock_embedder = AsyncMock(spec=Embedder) - mock_embedder.embed.return_value = EmbeddingResult( + mock_embedder.embed_documents.return_value = EmbeddingResult( embeddings=[[1.0, 0.0, 0.0]], inputs=["probe"], input_type="document", @@ -171,7 +171,7 @@ async def test_embedding_adapter_batch() -> None: from pydantic_ai.embeddings import EmbeddingResult mock_embedder = AsyncMock(spec=Embedder) - mock_embedder.embed.return_value = EmbeddingResult( + mock_embedder.embed_documents.return_value = EmbeddingResult( embeddings=[[1.0, 0.0], [0.0, 1.0]], inputs=["a", "b"], input_type="document", @@ -193,7 +193,7 @@ async def test_embedding_adapter_caching() -> None: from pydantic_ai.embeddings import EmbeddingResult mock_embedder = AsyncMock(spec=Embedder) - mock_embedder.embed.return_value = EmbeddingResult( + mock_embedder.embed_documents.return_value = EmbeddingResult( embeddings=[[1.0, 0.0, 0.0]], inputs=["cached"], input_type="document", @@ -205,8 +205,8 @@ async def test_embedding_adapter_caching() -> None: first = await adapter.get_embedding("cached") second = await adapter.get_embedding("cached") np.testing.assert_array_equal(first, second) - # embed() should only be called once - assert mock_embedder.embed.call_count == 1 + # embed_documents() should only be called once + assert mock_embedder.embed_documents.call_count == 1 @pytest.mark.asyncio @@ -222,8 +222,8 @@ async def test_embedding_adapter_add_embedding() -> None: adapter.add_embedding("key", vec) result = await adapter.get_embedding("key") np.testing.assert_array_equal(result, vec) - # No embed() call needed - mock_embedder.embed.assert_not_called() + # No embed_documents() call needed + mock_embedder.embed_documents.assert_not_called() @pytest.mark.asyncio From 2b8735b9c9e82d6eaac12ba206e0e0213c904727 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 25 Feb 2026 09:59:24 -0800 Subject: [PATCH 12/49] Fix the mcp test. 
We now do the right thing with azure endpoint env vars (I hope) --- src/typeagent/aitools/model_adapters.py | 40 ++++++++++++++++++++----- 1 file changed, 33 insertions(+), 7 deletions(-) diff --git a/src/typeagent/aitools/model_adapters.py b/src/typeagent/aitools/model_adapters.py index 472eae8e..ab376034 100644 --- a/src/typeagent/aitools/model_adapters.py +++ b/src/typeagent/aitools/model_adapters.py @@ -179,20 +179,32 @@ def _needs_azure_fallback(provider: str) -> bool: ) -def _make_azure_provider(): - """Create a :class:`pydantic_ai.providers.azure.AzureProvider`.""" +def _make_azure_provider( + endpoint_envvar: str = "AZURE_OPENAI_ENDPOINT", + api_key_envvar: str = "AZURE_OPENAI_API_KEY", +): + """Create a :class:`pydantic_ai.providers.azure.AzureProvider`. + + Constructs an ``AsyncAzureOpenAI`` client from the given environment + variables and wraps it in an ``AzureProvider``. The endpoint env-var + may contain a full Azure deployment URL (including path and + ``api-version`` query parameter) — the same format used throughout + this codebase. + """ + from openai import AsyncAzureOpenAI from pydantic_ai.providers.azure import AzureProvider from .utils import get_azure_api_key, parse_azure_endpoint - raw_key = os.environ["AZURE_OPENAI_API_KEY"] + raw_key = os.environ[api_key_envvar] api_key = get_azure_api_key(raw_key) - azure_endpoint, api_version = parse_azure_endpoint("AZURE_OPENAI_ENDPOINT") - return AzureProvider( + azure_endpoint, api_version = parse_azure_endpoint(endpoint_envvar) + client = AsyncAzureOpenAI( azure_endpoint=azure_endpoint, api_version=api_version, api_key=api_key, ) + return AzureProvider(openai_client=client) # --------------------------------------------------------------------------- @@ -257,9 +269,23 @@ def create_embedding_model( if _needs_azure_fallback(provider): from pydantic_ai.embeddings.openai import OpenAIEmbeddingModel - embedding_model = OpenAIEmbeddingModel( - model_name, provider=_make_azure_provider() + from .embeddings import model_to_embedding_size_and_envvar + + # Look up model-specific Azure endpoint, falling back to the generic one. + _, suggested_envvar = model_to_embedding_size_and_envvar.get( + model_name, (None, None) ) + if suggested_envvar and os.getenv(suggested_envvar): + endpoint_envvar = suggested_envvar + else: + endpoint_envvar = "AZURE_OPENAI_ENDPOINT_EMBEDDING" + # Allow a model-specific API key, falling back to the generic one. 
+ api_key_envvar = "AZURE_OPENAI_API_KEY_EMBEDDING" + if not os.getenv(api_key_envvar): + api_key_envvar = "AZURE_OPENAI_API_KEY" + + azure_provider = _make_azure_provider(endpoint_envvar, api_key_envvar) + embedding_model = OpenAIEmbeddingModel(model_name, provider=azure_provider) embedder = _PydanticAIEmbedder(embedding_model) else: embedder = _PydanticAIEmbedder(model_spec) From 05183d956a06ea5314e11d32144dbd3360df3625 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 25 Feb 2026 11:31:13 -0800 Subject: [PATCH 13/49] Remove AsyncEmbeddingModel; migrate all tests to PydanticAIEmbeddingModel --- src/typeagent/aitools/embeddings.py | 299 +------------------ src/typeagent/aitools/model_adapters.py | 89 ++++++ tests/conftest.py | 93 +----- tests/test_add_messages_with_indexing.py | 8 +- tests/test_conversation_metadata.py | 12 +- tests/test_embedding_consistency.py | 14 +- tests/test_embeddings.py | 251 +++------------- tests/test_factory.py | 8 +- tests/test_incremental_index.py | 12 +- tests/test_message_text_index_population.py | 4 +- tests/test_messageindex.py | 16 +- tests/test_podcast_incremental.py | 6 +- tests/test_property_index_population.py | 17 +- tests/test_query_method.py | 6 +- tests/test_related_terms_fast.py | 4 +- tests/test_related_terms_index_population.py | 4 +- tests/test_secindex.py | 8 +- tests/test_secindex_storage_integration.py | 4 +- tests/test_transcripts.py | 7 +- tests/test_vectorbase.py | 15 +- 20 files changed, 200 insertions(+), 677 deletions(-) diff --git a/src/typeagent/aitools/embeddings.py b/src/typeagent/aitools/embeddings.py index 5aacaafa..f56db25d 100644 --- a/src/typeagent/aitools/embeddings.py +++ b/src/typeagent/aitools/embeddings.py @@ -1,22 +1,11 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. -import asyncio -import os from typing import Protocol, runtime_checkable import numpy as np from numpy.typing import NDArray -from openai import AsyncAzureOpenAI, AsyncOpenAI, DEFAULT_MAX_RETRIES, OpenAIError -from openai.types import Embedding -import tiktoken -from tiktoken import model as tiktoken_model -from tiktoken.core import Encoding - -from .auth import AzureTokenProvider, get_shared_token_provider -from .utils import timelog - type NormalizedEmbedding = NDArray[np.float32] # A single embedding type NormalizedEmbeddings = NDArray[np.float32] # An array of embeddings @@ -26,8 +15,8 @@ class IEmbeddingModel(Protocol): """Provider-agnostic interface for embedding models. Implement this protocol to add support for a new embedding provider - (e.g. Anthropic, Gemini, local models). The existing AsyncEmbeddingModel - implements it for OpenAI and Azure OpenAI. + (e.g. Anthropic, Gemini, local models). The production implementation + is :class:`~typeagent.aitools.model_adapters.PydanticAIEmbeddingModel`. 
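+
+    The protocol is ``runtime_checkable``, so conformance can be checked
+    structurally at runtime. A minimal sketch, using the offline test model
+    from ``model_adapters``::
+
+        from typeagent.aitools.model_adapters import create_test_embedding_model
+
+        model = create_test_embedding_model()  # deterministic, no network
+        assert isinstance(model, IEmbeddingModel)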
""" model_name: str @@ -58,11 +47,6 @@ async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings: DEFAULT_EMBEDDING_SIZE = 1536 # Default embedding size (required for ada-002) DEFAULT_ENVVAR = "AZURE_OPENAI_ENDPOINT_EMBEDDING" # We support OpenAI and Azure OpenAI TEST_MODEL_NAME = "test" -MAX_BATCH_SIZE = 2048 -MAX_TOKEN_SIZE = 4096 -MAX_TOKENS_PER_BATCH = 300_000 -MAX_CHAR_SIZE = MAX_TOKEN_SIZE * 3 -MAX_CHARS_PER_BATCH = MAX_TOKENS_PER_BATCH * 3 model_to_embedding_size_and_envvar: dict[str, tuple[int | None, str]] = { DEFAULT_MODEL_NAME: (DEFAULT_EMBEDDING_SIZE, DEFAULT_ENVVAR), @@ -71,282 +55,3 @@ async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings: # For testing only, not a real model (insert real embeddings above) TEST_MODEL_NAME: (3, "SIR_NOT_APPEARING_IN_THIS_FILM"), } - - -class AsyncEmbeddingModel: - model_name: str - embedding_size: int - endpoint_envvar: str - use_azure: bool - azure_token_provider: AzureTokenProvider | None - async_client: AsyncOpenAI | None - azure_endpoint: str - azure_api_version: str - encoding: Encoding | None - max_chunk_size: int - max_size_per_batch: int - - _embedding_cache: dict[str, NormalizedEmbedding] - - def __init__( - self, - embedding_size: int | None = None, - model_name: str | None = None, - endpoint_envvar: str | None = None, - max_retries: int = DEFAULT_MAX_RETRIES, - use_azure: bool | None = None, - ): - if model_name is None: - model_name = DEFAULT_MODEL_NAME - self.model_name = model_name - - suggested_embedding_size, suggested_endpoint_envvar = ( - model_to_embedding_size_and_envvar.get(model_name, (None, None)) - ) - - if embedding_size is None: - if suggested_embedding_size is not None: - embedding_size = suggested_embedding_size - else: - embedding_size = DEFAULT_EMBEDDING_SIZE - self.embedding_size = embedding_size - - if ( - model_name == DEFAULT_MODEL_NAME - and embedding_size != DEFAULT_EMBEDDING_SIZE - ): - raise ValueError( - f"Cannot customize embedding_size for default model {DEFAULT_MODEL_NAME}" - ) - - # Read API keys once - openai_api_key = os.getenv("OPENAI_API_KEY") - azure_api_key = os.getenv("AZURE_OPENAI_API_KEY") - - # Determine provider: explicit use_azure overrides auto-detection. 
- if use_azure is not None: - self.use_azure = use_azure - else: - # Prefer OpenAI if both are set, use Azure if only Azure is set - self.use_azure = bool(azure_api_key) and not bool(openai_api_key) - - if endpoint_envvar is None: - # Check if OpenAI credentials are available, prefer OpenAI over Azure - if openai_api_key: - endpoint_envvar = "OPENAI_BASE_URL" # Use OpenAI - elif suggested_endpoint_envvar is not None: - endpoint_envvar = suggested_endpoint_envvar - else: - endpoint_envvar = DEFAULT_ENVVAR - - self.endpoint_envvar = endpoint_envvar - self.azure_token_provider = None - - if self.model_name == TEST_MODEL_NAME: - self.async_client = None - elif self.use_azure: - if not azure_api_key: - raise ValueError("AZURE_OPENAI_API_KEY not found in environment.") - with timelog("Using Azure OpenAI"): - self._setup_azure(azure_api_key) - else: - if not openai_api_key: - raise ValueError("OPENAI_API_KEY not found in environment.") - endpoint = os.getenv(self.endpoint_envvar) - with timelog("Using OpenAI"): - self.async_client = AsyncOpenAI( - base_url=endpoint, api_key=openai_api_key, max_retries=max_retries - ) - - if self.model_name in tiktoken_model.MODEL_TO_ENCODING: - encoding_name = tiktoken.encoding_name_for_model(self.model_name) - self.encoding = tiktoken.get_encoding(encoding_name) - self.max_chunk_size = MAX_TOKEN_SIZE - self.max_size_per_batch = MAX_TOKENS_PER_BATCH - else: - self.encoding = None - self.max_chunk_size = MAX_CHAR_SIZE - self.max_size_per_batch = MAX_CHARS_PER_BATCH - - self._embedding_cache = {} - - def _setup_azure(self, azure_api_key: str) -> None: - from .utils import get_azure_api_key, parse_azure_endpoint - - azure_api_key = get_azure_api_key(azure_api_key) - self.azure_endpoint, self.azure_api_version = parse_azure_endpoint( - self.endpoint_envvar - ) - - if azure_api_key != os.getenv("AZURE_OPENAI_API_KEY"): - # If we got a token from identity, store the provider for refresh - self.azure_token_provider = get_shared_token_provider() - - self.async_client = AsyncAzureOpenAI( - api_version=self.azure_api_version, - azure_endpoint=self.azure_endpoint, - api_key=azure_api_key, - ) - - async def refresh_auth(self): - """Update client when using a token provider and it's nearly expired.""" - # refresh_token is synchronous and slow -- run it in a separate thread - assert self.azure_token_provider - refresh_token = self.azure_token_provider.refresh_token - loop = asyncio.get_running_loop() - azure_api_key = await loop.run_in_executor(None, refresh_token) - assert self.azure_api_version - assert self.azure_endpoint - self.async_client = AsyncAzureOpenAI( - api_version=self.azure_api_version, - azure_endpoint=self.azure_endpoint, - api_key=azure_api_key, - ) - - def add_embedding(self, key: str, embedding: NormalizedEmbedding) -> None: - existing = self._embedding_cache.get(key) - if existing is not None: - assert np.array_equal(existing, embedding) - else: - self._embedding_cache[key] = embedding - - async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding: - embeddings = await self.get_embeddings_nocache([input]) - return embeddings[0] - - async def get_embeddings_nocache(self, input: list[str]) -> NormalizedEmbeddings: - if not input: - empty = np.array([], dtype=np.float32) - empty.shape = (0, self.embedding_size) - return empty - if self.azure_token_provider and self.azure_token_provider.needs_refresh(): - await self.refresh_auth() - extra_args = {} - if self.model_name != DEFAULT_MODEL_NAME: - extra_args["dimensions"] = self.embedding_size - if 
self.async_client is None: - # Compute a random embedding for testing purposes. - - def hashish(s: str) -> int: - # Primitive deterministic hash function (hash() varies per run) - h = 0 - for ch in s: - h = (h * 31 + ord(ch)) & 0xFFFFFFFF - return h - - prime = 1961 - fake_data: list[NormalizedEmbedding] = [] - for item in input: - if not item: - raise OpenAIError - length = len(item) - floats = [] - for i in range(self.embedding_size): - cut = i % length - scrambled = item[cut:] + item[:cut] - hashed = hashish(scrambled) - reduced = (hashed % prime) / prime - floats.append(reduced) - array = np.array(floats, dtype=np.float64) - normalized = array / np.sqrt(np.dot(array, array)) - dot = np.dot(normalized, normalized) - assert ( - abs(dot - 1.0) < 1e-15 - ), f"Embedding {normalized} is not normalized: {dot}" - fake_data.append(normalized) - assert len(fake_data) == len(input), (len(fake_data), "!=", len(input)) - result = np.array(fake_data, dtype=np.float32) - return result - else: - batches: list[list[str]] = [] - batch: list[str] = [] - batch_sum: int = 0 - for sentence in input: - truncated_input, truncated_input_size = await self.truncate_input( - sentence - ) - if ( - len(batch) >= MAX_BATCH_SIZE - or batch_sum + truncated_input_size > self.max_size_per_batch - ): - batches.append(batch) - batch = [] - batch_sum = 0 - batch.append(truncated_input) - batch_sum += truncated_input_size - if batch: - batches.append(batch) - - data: list[Embedding] = [] - for batch in batches: - embeddings_data = ( - await self.async_client.embeddings.create( - input=batch, - model=self.model_name, - encoding_format="float", - **extra_args, - ) - ).data - data.extend(embeddings_data) - - assert len(data) == len(input), (len(data), "!=", len(input)) - return np.array([d.embedding for d in data], dtype=np.float32) - - async def get_embedding(self, key: str) -> NormalizedEmbedding: - """Retrieve an embedding, using the cache.""" - if key in self._embedding_cache: - return self._embedding_cache[key] - embedding = await self.get_embedding_nocache(key) - self._embedding_cache[key] = embedding - return embedding - - async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings: - """Retrieve embeddings for multiple keys, using the cache.""" - embeddings: list[NormalizedEmbedding | None] = [] - missing_keys: list[str] = [] - - # Collect cached embeddings and identify missing keys - for key in keys: - if key in self._embedding_cache: - embeddings.append(self._embedding_cache[key]) - else: - embeddings.append(None) # Placeholder for missing keys - missing_keys.append(key) - - # Retrieve embeddings for missing keys - if missing_keys: - new_embeddings = await self.get_embeddings_nocache(missing_keys) - for key, embedding in zip(missing_keys, new_embeddings): - self._embedding_cache[key] = embedding - - # Replace placeholders with retrieved embeddings - for i, key in enumerate(keys): - if embeddings[i] is None: - embeddings[i] = self._embedding_cache[key] - return np.array(embeddings, dtype=np.float32).reshape( - (len(keys), self.embedding_size) - ) - - async def truncate_input(self, input: str) -> tuple[str, int]: - """Truncate input strings to fit within model limits. - - args: - input: The input string to truncate. - - returns: - A tuple of (truncated string, size after truncation). 
- """ - if self.encoding is None: - # Non-token-aware truncation - if len(input) > self.max_chunk_size: - return input[: self.max_chunk_size], self.max_chunk_size - else: - return input, len(input) - else: - # Token-aware truncation - tokens = self.encoding.encode(input) - if len(tokens) > self.max_chunk_size: - truncated_tokens = tokens[: self.max_chunk_size] - return self.encoding.decode(truncated_tokens), self.max_chunk_size - else: - return input, len(tokens) diff --git a/src/typeagent/aitools/model_adapters.py b/src/typeagent/aitools/model_adapters.py index ab376034..4acc0ff8 100644 --- a/src/typeagent/aitools/model_adapters.py +++ b/src/typeagent/aitools/model_adapters.py @@ -26,12 +26,16 @@ required environment variables. """ +from collections.abc import Sequence import os import numpy as np from numpy.typing import NDArray from pydantic_ai import Embedder as _PydanticAIEmbedder +from pydantic_ai.embeddings.base import EmbeddingModel as _PydanticAIEmbeddingModelBase +from pydantic_ai.embeddings.result import EmbeddingResult, EmbedInputType +from pydantic_ai.embeddings.settings import EmbeddingSettings from pydantic_ai.messages import ( ModelMessage, ModelRequest, @@ -157,6 +161,10 @@ async def get_embedding(self, key: str) -> NormalizedEmbedding: return embedding async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings: + if not keys: + if self.embedding_size == 0: + await self._probe_embedding_size() + return np.empty((0, self.embedding_size), dtype=np.float32) missing_keys = [k for k in keys if k not in self._cache] if missing_keys: fresh = await self.get_embeddings_nocache(missing_keys) @@ -292,6 +300,87 @@ def create_embedding_model( return PydanticAIEmbeddingModel(embedder, model_name, embedding_size) +# --------------------------------------------------------------------------- +# Test helpers +# --------------------------------------------------------------------------- + + +def _hashish(s: str) -> int: + """Deterministic hash function for fake embeddings (hash() varies per run).""" + h = 0 + for ch in s: + h = (h * 31 + ord(ch)) & 0xFFFFFFFF + return h + + +def _compute_fake_embeddings( + input_texts: list[str], embedding_size: int +) -> list[list[float]]: + """Generate deterministic fake embeddings for testing (unnormalized). + + Raises :class:`ValueError` on empty input strings. + """ + prime = 1961 + result: list[list[float]] = [] + for item in input_texts: + if not item: + raise ValueError("Empty input text") + length = len(item) + floats: list[float] = [] + for i in range(embedding_size): + cut = i % length + scrambled = item[cut:] + item[:cut] + hashed = _hashish(scrambled) + reduced = (hashed % prime) / prime + floats.append(reduced) + result.append(floats) + return result + + +class _FakePydanticAIEmbeddingModel(_PydanticAIEmbeddingModelBase): + """A pydantic_ai :class:`EmbeddingModel` that returns deterministic fake + embeddings. 
Used only for testing — no network calls are made.""" + + def __init__(self, embedding_size: int = 3) -> None: + super().__init__() + self._embedding_size = embedding_size + + @property + def model_name(self) -> str: + return "test" + + @property + def system(self) -> str: + return "test" + + async def embed( + self, + inputs: str | Sequence[str], + *, + input_type: EmbedInputType, + settings: EmbeddingSettings | None = None, + ) -> EmbeddingResult: + inputs_list, settings = self.prepare_embed(inputs, settings) + embeddings = _compute_fake_embeddings(inputs_list, self._embedding_size) + return EmbeddingResult( + embeddings=embeddings, + inputs=inputs_list, + input_type=input_type, + model_name="test", + provider_name="test", + ) + + +def create_test_embedding_model( + embedding_size: int = 3, +) -> PydanticAIEmbeddingModel: + """Create a :class:`PydanticAIEmbeddingModel` with deterministic fake + embeddings for testing. No API keys or network access required.""" + fake_model = _FakePydanticAIEmbeddingModel(embedding_size) + embedder = _PydanticAIEmbedder(fake_model) + return PydanticAIEmbeddingModel(embedder, "test", embedding_size) + + def configure_models( chat_model_spec: str, embedding_model_spec: str, diff --git a/tests/conftest.py b/tests/conftest.py index 7f7ce210..c4de6d47 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,15 +11,8 @@ import pytest import pytest_asyncio -from openai.types.create_embedding_response import CreateEmbeddingResponse, Usage -from openai.types.embedding import Embedding -import tiktoken - -from typeagent.aitools.embeddings import ( - AsyncEmbeddingModel, - IEmbeddingModel, - TEST_MODEL_NAME, -) +from typeagent.aitools.embeddings import IEmbeddingModel +from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings from typeagent.knowpro.convsettings import ( ConversationSettings, @@ -96,7 +89,7 @@ def really_needs_auth() -> None: @pytest.fixture(scope="session") def embedding_model() -> IEmbeddingModel: """Fixture to create a test embedding model with small embedding size for faster tests.""" - return AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + return create_test_embedding_model() @pytest.fixture(scope="session") @@ -303,7 +296,7 @@ def __init__( self._has_secondary_indexes = has_secondary_indexes else: # Create test model for settings - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() self.settings = ConversationSettings(test_model, storage_provider) self._needs_async_init = False self._storage_provider = storage_provider @@ -323,7 +316,7 @@ def __init__( async def ensure_initialized(self): """Ensure async initialization is complete.""" if self._needs_async_init: - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() self.settings = ConversationSettings(test_model) storage_provider = await self.settings.get_storage_provider() self._storage_provider = storage_provider @@ -355,79 +348,3 @@ async def fake_conversation_with_storage( ) -> FakeConversation: """Fixture to create a FakeConversation instance with storage provider.""" return FakeConversation(storage_provider=memory_storage) - - -class FakeEmbeddings: - - def __init__( - self, - max_batch_size: int = 2048, - max_chunk_size: int = 4096, - max_elements_per_batch: int = 300_000, - use_tiktoken: bool = False, - ): - self.model_name = "text-embedding-ada-002" - self.call_count = 0 - self.max_batch_size 
= max_batch_size - self.max_chunk_size = max_chunk_size - self.max_elements_per_batch = max_elements_per_batch - self.use_tiktoken = use_tiktoken - - def reset_counter(self): - self.call_count = 0 - - async def create(self, **kwargs): - self.call_count += 1 - input = kwargs["input"] - len_input = len(input) - if len_input > self.max_batch_size: - raise ValueError("Embedding model received batch larger 2048") - dimensions = 1536 - if "dimensions" in kwargs: - dimensions = kwargs["dimensions"] - - embedding_result = [] - total_elements = 0 - for index in range(len_input): - entity = input[index] - if self.use_tiktoken: - enc_name = tiktoken.encoding_name_for_model(self.model_name) - enc = tiktoken.get_encoding(enc_name) - entity = enc.encode(entity) - total_elements += len(entity) - if len(entity) > self.max_chunk_size: - raise ValueError( - f"Chunk size {len(entity)} larger than max size {self.max_chunk_size}" - ) - value = index % 2 - embedding_result.append( - Embedding( - embedding=[value] * dimensions, index=index, object="embedding" - ) - ) - - if total_elements > self.max_elements_per_batch: - raise ValueError( - f"Batch size {total_elements} larger than max tokens/chars per batch {self.max_elements_per_batch}" - ) - - response = CreateEmbeddingResponse( - data=embedding_result, - model="test_model", - object="list", - usage=Usage(prompt_tokens=0, total_tokens=0), - ) - - return response - - -@pytest.fixture -def fake_embeddings() -> FakeEmbeddings: - """Fixture to create a FaceEmbedding instance""" - return FakeEmbeddings(max_batch_size=2048, max_chunk_size=4096 * 3) - - -@pytest.fixture -def fake_embeddings_tiktoken() -> FakeEmbeddings: - """Fixture to create a FaceEmbedding instance""" - return FakeEmbeddings(max_batch_size=2048, max_chunk_size=4096, use_tiktoken=True) diff --git a/tests/test_add_messages_with_indexing.py b/tests/test_add_messages_with_indexing.py index 4f00cfb1..d3df2c4d 100644 --- a/tests/test_add_messages_with_indexing.py +++ b/tests/test_add_messages_with_indexing.py @@ -8,7 +8,7 @@ import pytest -from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME +from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.knowpro.convsettings import ConversationSettings from typeagent.storage.sqlite.provider import SqliteStorageProvider from typeagent.transcripts.transcript import ( @@ -24,7 +24,7 @@ async def test_add_messages_with_indexing_basic(): with tempfile.TemporaryDirectory() as tmpdir: db_path = os.path.join(tmpdir, "test.db") - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() settings = ConversationSettings(model=test_model) settings.semantic_ref_index_settings.auto_extract_knowledge = False @@ -65,7 +65,7 @@ async def test_add_messages_with_indexing_batched(): with tempfile.TemporaryDirectory() as tmpdir: db_path = os.path.join(tmpdir, "test.db") - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() settings = ConversationSettings(model=test_model) settings.semantic_ref_index_settings.auto_extract_knowledge = False @@ -122,7 +122,7 @@ async def test_transaction_rollback_on_error(): with tempfile.TemporaryDirectory() as tmpdir: db_path = os.path.join(tmpdir, "test.db") - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() settings = ConversationSettings(model=test_model) settings.semantic_ref_index_settings.auto_extract_knowledge = False diff --git 
a/tests/test_conversation_metadata.py b/tests/test_conversation_metadata.py index 69d80dab..887c50b2 100644 --- a/tests/test_conversation_metadata.py +++ b/tests/test_conversation_metadata.py @@ -16,11 +16,8 @@ from pydantic.dataclasses import dataclass -from typeagent.aitools.embeddings import ( - AsyncEmbeddingModel, - IEmbeddingModel, - TEST_MODEL_NAME, -) +from typeagent.aitools.embeddings import IEmbeddingModel +from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings from typeagent.knowpro.convsettings import ( ConversationSettings, @@ -80,7 +77,7 @@ async def storage_provider_memory() -> ( AsyncGenerator[SqliteStorageProvider[DummyMessage], None] ): """Create an in-memory SqliteStorageProvider for testing conversation metadata.""" - embedding_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + embedding_model = create_test_embedding_model() embedding_settings = TextEmbeddingIndexSettings(embedding_model) message_text_settings = MessageTextIndexSettings(embedding_settings) related_terms_settings = RelatedTermIndexSettings(embedding_settings) @@ -624,9 +621,8 @@ async def test_embedding_metadata_mismatch_raises( provider.db.commit() await provider.close() - mismatched_model = AsyncEmbeddingModel( + mismatched_model = create_test_embedding_model( embedding_size=embedding_settings.embedding_size + 1, - model_name=embedding_model.model_name, ) mismatched_settings = TextEmbeddingIndexSettings( embedding_model=mismatched_model, diff --git a/tests/test_embedding_consistency.py b/tests/test_embedding_consistency.py index 906c2b52..f032c856 100644 --- a/tests/test_embedding_consistency.py +++ b/tests/test_embedding_consistency.py @@ -9,7 +9,7 @@ import pytest from typeagent import create_conversation -from typeagent.aitools.embeddings import AsyncEmbeddingModel +from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.knowpro.convsettings import ConversationSettings from typeagent.storage.sqlite import SqliteStorageProvider from typeagent.transcripts.transcript import TranscriptMessage, TranscriptMessageMeta @@ -25,7 +25,7 @@ async def test_embedding_size_mismatch_in_message_index(): try: # Create a conversation with test model (embedding size 3) settings1 = ConversationSettings( - model=AsyncEmbeddingModel(embedding_size=3, model_name="test") + model=create_test_embedding_model(embedding_size=3) ) # Disable LLM knowledge extraction to avoid API key requirement settings1.semantic_ref_index_settings.auto_extract_knowledge = False @@ -46,7 +46,7 @@ async def test_embedding_size_mismatch_in_message_index(): # Now try to open the same database with a different embedding size # This should raise an error settings2 = ConversationSettings( - model=AsyncEmbeddingModel(embedding_size=5, model_name="test") + model=create_test_embedding_model(embedding_size=5) ) with pytest.raises(ValueError, match="embedding_size"): @@ -74,7 +74,7 @@ async def test_embedding_size_mismatch_in_related_terms(): try: # Create a conversation with default embedding size settings1 = ConversationSettings( - model=AsyncEmbeddingModel(embedding_size=3, model_name="test") + model=create_test_embedding_model(embedding_size=3) ) # Disable LLM knowledge extraction to avoid API key requirement settings1.semantic_ref_index_settings.auto_extract_knowledge = False @@ -95,7 +95,7 @@ async def test_embedding_size_mismatch_in_related_terms(): # Now try to open the same database with a different embedding size # This 
should raise an error settings2 = ConversationSettings( - model=AsyncEmbeddingModel(embedding_size=5, model_name="test") + model=create_test_embedding_model(embedding_size=5) ) with pytest.raises(ValueError, match="embedding_size"): @@ -123,7 +123,7 @@ async def test_empty_db_no_error(): try: # Create an empty database settings1 = ConversationSettings( - model=AsyncEmbeddingModel(embedding_size=3, model_name="test") + model=create_test_embedding_model(embedding_size=3) ) # Disable LLM knowledge extraction to avoid API key requirement settings1.semantic_ref_index_settings.auto_extract_knowledge = False @@ -134,7 +134,7 @@ async def test_empty_db_no_error(): # Open with different embedding size should work since DB is empty settings2 = ConversationSettings( - model=AsyncEmbeddingModel(embedding_size=5, model_name="test") + model=create_test_embedding_model(embedding_size=5) ) provider = SqliteStorageProvider( db_path=db_path, diff --git a/tests/test_embeddings.py b/tests/test_embeddings.py index ac766f1b..2ae09845 100644 --- a/tests/test_embeddings.py +++ b/tests/test_embeddings.py @@ -3,23 +3,18 @@ import numpy as np import pytest -from pytest import MonkeyPatch from pytest_mock import MockerFixture -import openai - -from typeagent.aitools.embeddings import AsyncEmbeddingModel +from typeagent.aitools.embeddings import IEmbeddingModel +from typeagent.aitools.model_adapters import PydanticAIEmbeddingModel from conftest import ( embedding_model, # type: ignore # Magic, prevents side effects of mocking ) -from conftest import ( - FakeEmbeddings, -) @pytest.mark.asyncio -async def test_get_embedding_nocache(embedding_model: AsyncEmbeddingModel): +async def test_get_embedding_nocache(embedding_model: PydanticAIEmbeddingModel): """Test retrieving an embedding without using the cache.""" input_text = "Hello, world" embedding = await embedding_model.get_embedding_nocache(input_text) @@ -30,7 +25,7 @@ async def test_get_embedding_nocache(embedding_model: AsyncEmbeddingModel): @pytest.mark.asyncio -async def test_get_embeddings_nocache(embedding_model: AsyncEmbeddingModel): +async def test_get_embeddings_nocache(embedding_model: PydanticAIEmbeddingModel): """Test retrieving multiple embeddings without using the cache.""" inputs = ["Hello, world", "Foo bar baz"] embeddings = await embedding_model.get_embeddings_nocache(inputs) @@ -42,14 +37,14 @@ async def test_get_embeddings_nocache(embedding_model: AsyncEmbeddingModel): @pytest.mark.asyncio async def test_get_embedding_with_cache( - embedding_model: AsyncEmbeddingModel, mocker: MockerFixture + embedding_model: PydanticAIEmbeddingModel, mocker: MockerFixture ): """Test retrieving an embedding with caching.""" input_text = "Hello, world" # First call should populate the cache embedding1 = await embedding_model.get_embedding(input_text) - assert input_text in embedding_model._embedding_cache + assert input_text in embedding_model._cache # Mock the nocache method to ensure it's not called mock_get_embedding_nocache = mocker.patch.object( @@ -66,7 +61,7 @@ async def test_get_embedding_with_cache( @pytest.mark.asyncio async def test_get_embeddings_with_cache( - embedding_model: AsyncEmbeddingModel, mocker: MockerFixture + embedding_model: PydanticAIEmbeddingModel, mocker: MockerFixture ): """Test retrieving multiple embeddings with caching.""" inputs = ["Hello, world", "Foo bar baz"] @@ -74,7 +69,7 @@ async def test_get_embeddings_with_cache( # First call should populate the cache embeddings1 = await embedding_model.get_embeddings(inputs) for input_text in 
inputs: - assert input_text in embedding_model._embedding_cache + assert input_text in embedding_model._cache # Mock the nocache method to ensure it's not called mock_get_embeddings_nocache = mocker.patch.object( @@ -90,9 +85,9 @@ async def test_get_embeddings_with_cache( @pytest.mark.asyncio -async def test_get_embeddings_empty_input(embedding_model: AsyncEmbeddingModel): +async def test_get_embeddings_empty_input(embedding_model: PydanticAIEmbeddingModel): """Test retrieving embeddings for an empty input list.""" - inputs = [] + inputs: list[str] = [] embeddings = await embedding_model.get_embeddings(inputs) assert isinstance(embeddings, np.ndarray) @@ -101,222 +96,60 @@ async def test_get_embeddings_empty_input(embedding_model: AsyncEmbeddingModel): @pytest.mark.asyncio -async def test_add_embedding_to_cache(embedding_model: AsyncEmbeddingModel): +async def test_add_embedding_to_cache(embedding_model: PydanticAIEmbeddingModel): """Test adding an embedding to the cache.""" key = "test_key" embedding = np.array([0.1, 0.2, 0.3], dtype=np.float32) embedding_model.add_embedding(key, embedding) - assert key in embedding_model._embedding_cache - assert np.array_equal(embedding_model._embedding_cache[key], embedding) + assert key in embedding_model._cache + assert np.array_equal(embedding_model._cache[key], embedding) @pytest.mark.asyncio -async def test_get_embedding_nocache_empty_input(embedding_model: AsyncEmbeddingModel): +async def test_get_embedding_nocache_empty_input( + embedding_model: PydanticAIEmbeddingModel, +): """Test retrieving an embedding with no cache for an empty input.""" - with pytest.raises(openai.OpenAIError): + with pytest.raises(ValueError, match="Empty input text"): await embedding_model.get_embedding_nocache("") @pytest.mark.asyncio -async def test_refresh_auth( - embedding_model: AsyncEmbeddingModel, mocker: MockerFixture -): - """Test refreshing authentication when using Azure.""" - # Note that pyright doesn't understand mocking, hence the `# type: ignore` below - mocker.patch.object(embedding_model, "azure_token_provider", autospec=True) - mocker.patch.object(embedding_model, "_setup_azure", autospec=True) - - embedding_model.azure_token_provider.needs_refresh.return_value = True # type: ignore - embedding_model.azure_token_provider.refresh_token.return_value = "new_token" # type: ignore - embedding_model.azure_api_version = "2023-05-15" - embedding_model.azure_endpoint = "https://example.azure.com" - - await embedding_model.refresh_auth() +async def test_embeddings_are_normalized(embedding_model: PydanticAIEmbeddingModel): + """Test that returned embeddings are unit-normalized.""" + inputs = ["Hello, world", "Foo bar baz", "Testing normalization"] + embeddings = await embedding_model.get_embeddings_nocache(inputs) - embedding_model.azure_token_provider.refresh_token.assert_called_once() # type: ignore - assert embedding_model.async_client is not None + for i in range(len(inputs)): + norm = float(np.linalg.norm(embeddings[i])) + assert abs(norm - 1.0) < 1e-6, f"Embedding {i} not normalized: norm={norm}" @pytest.mark.asyncio -async def test_set_endpoint(monkeypatch: MonkeyPatch): - """Test creating of model with custom endpoint.""" - - monkeypatch.setenv("AZURE_OPENAI_API_KEY", "does-not-matter") - monkeypatch.delenv("OPENAI_API_KEY", raising=False) # Ensure Azure path is used - - # Default - monkeypatch.setenv( - "AZURE_OPENAI_ENDPOINT_EMBEDDING", - "http://localhost:7997?api-version=2024-06-01", - ) - embedding_model = AsyncEmbeddingModel() - assert 
embedding_model.embedding_size == 1536 - assert embedding_model.model_name == "text-embedding-ada-002" - assert embedding_model.endpoint_envvar == "AZURE_OPENAI_ENDPOINT_EMBEDDING" - - # 3-large - monkeypatch.setenv( - "AZURE_OPENAI_ENDPOINT_EMBEDDING_3_LARGE", - "http://localhost:7997?api-version=2024-06-01", - ) - embedding_model = AsyncEmbeddingModel(model_name="text-embedding-3-large") - assert embedding_model.embedding_size == 3072 - assert embedding_model.model_name == "text-embedding-3-large" - assert embedding_model.endpoint_envvar == "AZURE_OPENAI_ENDPOINT_EMBEDDING_3_LARGE" - - # 3-small - monkeypatch.setenv( - "AZURE_OPENAI_ENDPOINT_EMBEDDING_3_SMALL", - "http://localhost:7998?api-version=2024-06-01", - ) - embedding_model = AsyncEmbeddingModel(model_name="text-embedding-3-small") - assert embedding_model.embedding_size == 1536 - assert embedding_model.model_name == "text-embedding-3-small" - assert embedding_model.endpoint_envvar == "AZURE_OPENAI_ENDPOINT_EMBEDDING_3_SMALL" - - # Fully custom with OpenAI - monkeypatch.setenv("OPENAI_API_KEY", "does-not-matter") - monkeypatch.setenv("INFINITY_EMBEDDING_URL", "http://localhost:7997") - embedding_model = AsyncEmbeddingModel( - 1024, "custom_model", endpoint_envvar="INFINITY_EMBEDDING_URL" - ) - assert embedding_model.embedding_size == 1024 - assert embedding_model.model_name == "custom_model" - # NOTE: checking openai.AsyncOpenAI internals - assert embedding_model.async_client is not None - assert embedding_model.async_client.base_url == "http://localhost:7997" - assert embedding_model.async_client.api_key == "does-not-matter" - assert embedding_model.endpoint_envvar == "INFINITY_EMBEDDING_URL" - - # Customized 3-small with Azure (endpoint_envvar must contain "AZURE") - monkeypatch.delenv("OPENAI_API_KEY") # Force Azure path - monkeypatch.setenv( - "AZURE_ALTERNATE_ENDPOINT", - "http://localhost:7999?api-version=2024-06-01", - ) - embedding_model = AsyncEmbeddingModel( - 2000, "text-embedding-3-small", endpoint_envvar="AZURE_ALTERNATE_ENDPOINT" - ) - assert embedding_model.embedding_size == 2000 - assert embedding_model.model_name == "text-embedding-3-small" - assert embedding_model.endpoint_envvar == "AZURE_ALTERNATE_ENDPOINT" - - # Allow explicitly setting default embedding size - AsyncEmbeddingModel(1536) - - # Can't customize embedding_size for default model - with pytest.raises(ValueError): - AsyncEmbeddingModel(1024) - - # Not even when default model name specified explicitly - with pytest.raises(ValueError): - AsyncEmbeddingModel(1024, "text-embedding-ada-002") +async def test_embeddings_are_deterministic( + embedding_model: PydanticAIEmbeddingModel, +): + """Test that the same input always produces the same embedding.""" + input_text = "Deterministic test" + e1 = await embedding_model.get_embedding_nocache(input_text) + e2 = await embedding_model.get_embedding_nocache(input_text) + assert np.array_equal(e1, e2) @pytest.mark.asyncio -async def test_embeddings_batching_tiktoken( - fake_embeddings_tiktoken: FakeEmbeddings, monkeypatch: MonkeyPatch +async def test_different_inputs_produce_different_embeddings( + embedding_model: PydanticAIEmbeddingModel, ): - monkeypatch.setenv("OPENAI_API_KEY", "test_key") - - embedding_model = AsyncEmbeddingModel() - assert embedding_model.max_chunk_size == 4096 - - embedding_model.async_client.embeddings = fake_embeddings_tiktoken # type: ignore - - # Check max batch size - inputs = ["a"] * 2049 - embeddings = await embedding_model.get_embeddings(inputs) - assert len(embeddings) == 2049 - 
assert fake_embeddings_tiktoken.call_count == 2 - - # Check max token size - inputs = ["Very long input longer than 4096 tokens will be truncated" * 500] - embeddings = await embedding_model.get_embeddings(inputs) - assert len(embeddings) == 1 - - fake_embeddings_tiktoken.reset_counter() - - TEST_MAX_TOKEN_SIZE = 10 - TEST_MAX_TOKENS_PER_BATCH = 20 - embedding_model.max_chunk_size = TEST_MAX_TOKEN_SIZE - embedding_model.max_size_per_batch = TEST_MAX_TOKENS_PER_BATCH - fake_embeddings_tiktoken.max_elements_per_batch = TEST_MAX_TOKENS_PER_BATCH - - assert embedding_model.encoding is not None - - token = [500] * 20 # --> 20 tokens - input = [embedding_model.encoding.decode(token)] * 4 - embeddings = await embedding_model.get_embeddings_nocache(input) # type: ignore - - # each input gets truncated to 10 tokens, so 4 inputs fit in 2 batches of 20 tokens - assert fake_embeddings_tiktoken.call_count == 2 - assert len(embeddings) == 4 - - fake_embeddings_tiktoken.reset_counter() - - TEST_MAX_TOKEN_SIZE = 7 - embedding_model.max_chunk_size = TEST_MAX_TOKEN_SIZE - - token = [500] * 20 # --> 20 tokens - input = [embedding_model.encoding.decode(token)] * 5 - embeddings = await embedding_model.get_embeddings_nocache(input) # type: ignore - - # each input gets truncated to 7 tokens, so each batch can hold 2 inputs (14 tokens) - # 5 inputs require 3 batches - assert fake_embeddings_tiktoken.call_count == 3 - assert len(embeddings) == 5 + """Test that different inputs produce different embeddings.""" + e1 = await embedding_model.get_embedding_nocache("Hello") + e2 = await embedding_model.get_embedding_nocache("World") + assert not np.array_equal(e1, e2) @pytest.mark.asyncio -async def test_embeddings_batching( - fake_embeddings: FakeEmbeddings, monkeypatch: MonkeyPatch +async def test_implements_iembedding_model( + embedding_model: PydanticAIEmbeddingModel, ): - monkeypatch.setenv("OPENAI_API_KEY", "test_key") - - embedding_model = AsyncEmbeddingModel(1024, "custom_model") - embedding_model.async_client.embeddings = fake_embeddings # type: ignore - - # Check max batch size - inputs = ["a"] * 2049 - embeddings = await embedding_model.get_embeddings(inputs) - assert len(embeddings) == 2049 - assert fake_embeddings.call_count == 2 - - TEST_MAX_CHAR_SIZE = 10 - TEST_MAX_CHARS_PER_BATCH = 20 - embedding_model.max_chunk_size = TEST_MAX_CHAR_SIZE - embedding_model.max_size_per_batch = TEST_MAX_CHARS_PER_BATCH - fake_embeddings.max_elements_per_batch = TEST_MAX_CHARS_PER_BATCH - - # Check max token size - inputs = ["a" * TEST_MAX_CHAR_SIZE] - embeddings = await embedding_model.get_embeddings_nocache(inputs) - assert len(embeddings) == 1 - assert np.all(embeddings[0] == 0) - - fake_embeddings.reset_counter() - - # Check one over max token size - inputs = ["a" * (TEST_MAX_CHAR_SIZE + 1)] - embeddings = await embedding_model.get_embeddings_nocache(inputs) - assert len(embeddings) == 1 - assert fake_embeddings.call_count == 1 - - fake_embeddings.reset_counter() - - # Check input as large as max_size_per_batch - inputs = ["a" * 10, "a" * 5, "a" * 5] - embeddings = await embedding_model.get_embeddings_nocache(inputs) # type: ignore - assert fake_embeddings.call_count == 1 - assert len(embeddings) == 3 - - fake_embeddings.reset_counter() - - # Check input larger than max_size_per_batch - # max chars per batch is 20, so 10*10 chars requires 5 batches - inputs = ["a" * 10] * 10 - embeddings = await embedding_model.get_embeddings_nocache(inputs) # type: ignore - assert fake_embeddings.call_count == 5 - assert 
len(embeddings) == 10 + """Test that PydanticAIEmbeddingModel satisfies the IEmbeddingModel protocol.""" + assert isinstance(embedding_model, IEmbeddingModel) diff --git a/tests/test_factory.py b/tests/test_factory.py index 0f62220f..44c45e5f 100644 --- a/tests/test_factory.py +++ b/tests/test_factory.py @@ -6,7 +6,7 @@ import pytest from typeagent import create_conversation -from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME +from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.knowpro.convsettings import ConversationSettings from typeagent.transcripts.transcript import TranscriptMessage, TranscriptMessageMeta @@ -15,7 +15,7 @@ async def test_create_conversation_minimal(): """Test creating a conversation with minimal parameters.""" # Create empty conversation with test model - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() settings = ConversationSettings(model=test_model) conversation = await create_conversation( None, @@ -36,7 +36,7 @@ async def test_create_conversation_minimal(): @pytest.mark.asyncio async def test_create_conversation_with_tags(): """Test creating a conversation with tags.""" - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() settings = ConversationSettings(model=test_model) conversation = await create_conversation( None, @@ -54,7 +54,7 @@ async def test_create_conversation_with_tags(): async def test_create_conversation_and_add_messages(really_needs_auth): """Test the complete workflow: create conversation and add messages.""" # 1. Create empty conversation - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() settings = ConversationSettings(model=test_model) conversation = await create_conversation( None, diff --git a/tests/test_incremental_index.py b/tests/test_incremental_index.py index 12f706a6..ced12f19 100644 --- a/tests/test_incremental_index.py +++ b/tests/test_incremental_index.py @@ -8,7 +8,7 @@ import pytest -from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME +from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.knowpro.convsettings import ConversationSettings from typeagent.storage.sqlite.provider import SqliteStorageProvider from typeagent.transcripts.transcript import ( @@ -30,7 +30,7 @@ async def test_incremental_index_building(): db_path = os.path.join(tmpdir, "test.db") # Create settings with test model (no API keys needed) - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() settings = ConversationSettings(model=test_model) settings.semantic_ref_index_settings.auto_extract_knowledge = False @@ -74,7 +74,7 @@ async def test_incremental_index_building(): # Second ingestion - add more messages and rebuild index print("\n=== Second ingestion ===") - test_model2 = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model2 = create_test_embedding_model() settings2 = ConversationSettings(model=test_model2) settings2.semantic_ref_index_settings.auto_extract_knowledge = False storage2 = SqliteStorageProvider( @@ -136,7 +136,7 @@ async def test_incremental_index_with_vtt_files(): db_path = os.path.join(tmpdir, "test.db") # Create settings with test model (no API keys needed) - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() settings = 
ConversationSettings(model=test_model) settings.semantic_ref_index_settings.auto_extract_knowledge = False @@ -161,9 +161,7 @@ async def test_incremental_index_with_vtt_files(): # Second VTT file ingestion into same database print("\n=== Import second VTT file ===") - settings2 = ConversationSettings( - model=AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) - ) + settings2 = ConversationSettings(model=create_test_embedding_model()) settings2.semantic_ref_index_settings.auto_extract_knowledge = False # Ingest the second transcript diff --git a/tests/test_message_text_index_population.py b/tests/test_message_text_index_population.py index 384aaa9c..13d53c00 100644 --- a/tests/test_message_text_index_population.py +++ b/tests/test_message_text_index_population.py @@ -10,7 +10,7 @@ from dotenv import load_dotenv import pytest -from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME +from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings from typeagent.knowpro.convsettings import ( MessageTextIndexSettings, @@ -30,7 +30,7 @@ async def test_message_text_index_population_from_database(): try: # Use the test model that's already configured in the system - embedding_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + embedding_model = create_test_embedding_model() embedding_settings = TextEmbeddingIndexSettings(embedding_model) message_text_settings = MessageTextIndexSettings(embedding_settings) related_terms_settings = RelatedTermIndexSettings(embedding_settings) diff --git a/tests/test_messageindex.py b/tests/test_messageindex.py index b4e40dd6..ec80498f 100644 --- a/tests/test_messageindex.py +++ b/tests/test_messageindex.py @@ -42,10 +42,10 @@ def message_text_index( mock_text_location_index: MagicMock, ) -> IMessageTextEmbeddingIndex: """Fixture to create a MessageTextIndex instance with a mocked TextToTextLocationIndex.""" - from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME + from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() embedding_settings = TextEmbeddingIndexSettings(test_model) settings = MessageTextIndexSettings(embedding_settings) index = MessageTextIndex(settings) @@ -55,10 +55,10 @@ def message_text_index( def test_message_text_index_init(needs_auth: None): """Test initialization of MessageTextIndex.""" - from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME + from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() embedding_settings = TextEmbeddingIndexSettings(test_model) settings = MessageTextIndexSettings(embedding_settings) index = MessageTextIndex(settings) @@ -147,11 +147,11 @@ async def test_generate_embedding(needs_auth: None): """Test generating an embedding for a message without mocking.""" import numpy as np - from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME + from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings # Create real MessageTextIndex with test model - test_model = 
AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() embedding_settings = TextEmbeddingIndexSettings(test_model) settings = MessageTextIndexSettings(embedding_settings) index = MessageTextIndex(settings) @@ -205,14 +205,14 @@ async def test_build_message_index(needs_auth: None): ] # Create storage provider asynchronously - from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME + from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings from typeagent.knowpro.convsettings import ( MessageTextIndexSettings, RelatedTermIndexSettings, ) - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() embedding_settings = TextEmbeddingIndexSettings(test_model) message_text_settings = MessageTextIndexSettings(embedding_settings) related_terms_settings = RelatedTermIndexSettings(embedding_settings) diff --git a/tests/test_podcast_incremental.py b/tests/test_podcast_incremental.py index 92d5ad32..4b1732d6 100644 --- a/tests/test_podcast_incremental.py +++ b/tests/test_podcast_incremental.py @@ -8,7 +8,7 @@ import pytest -from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME +from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.knowpro.convsettings import ConversationSettings from typeagent.podcasts.podcast import Podcast, PodcastMessage, PodcastMessageMeta from typeagent.storage.sqlite.provider import SqliteStorageProvider @@ -20,7 +20,7 @@ async def test_podcast_add_messages_with_indexing(): with tempfile.TemporaryDirectory() as tmpdir: db_path = os.path.join(tmpdir, "test.db") - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() settings = ConversationSettings(model=test_model) settings.semantic_ref_index_settings.auto_extract_knowledge = False @@ -57,7 +57,7 @@ async def test_podcast_add_messages_batched(): with tempfile.TemporaryDirectory() as tmpdir: db_path = os.path.join(tmpdir, "test.db") - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() settings = ConversationSettings(model=test_model) settings.semantic_ref_index_settings.auto_extract_knowledge = False diff --git a/tests/test_property_index_population.py b/tests/test_property_index_population.py index 8b751bb7..f6cc3edb 100644 --- a/tests/test_property_index_population.py +++ b/tests/test_property_index_population.py @@ -8,10 +8,9 @@ import tempfile from dotenv import load_dotenv -import numpy as np import pytest -from typeagent.aitools.embeddings import AsyncEmbeddingModel +from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings from typeagent.knowpro import kplib from typeagent.knowpro.convsettings import ( @@ -23,16 +22,6 @@ from typeagent.storage import SqliteStorageProvider -class MockEmbeddingModel(AsyncEmbeddingModel): - def __init__(self): - super().__init__(embedding_size=3, model_name="test") - - async def get_embeddings(self, keys: list[str]) -> np.ndarray: - result = np.random.rand(len(keys), 3).astype(np.float32) - norms = np.linalg.norm(result, axis=1, keepdims=True) - return result / norms - - @pytest.mark.asyncio async def test_property_index_population_from_database(really_needs_auth): """Test that property index is correctly populated when reopening a database.""" @@ -42,7 +31,7 @@ 
async def test_property_index_population_from_database(really_needs_auth): temp_db_file.close() try: - embedding_model = MockEmbeddingModel() + embedding_model = create_test_embedding_model() embedding_settings = TextEmbeddingIndexSettings(embedding_model) message_text_settings = MessageTextIndexSettings(embedding_settings) related_terms_settings = RelatedTermIndexSettings(embedding_settings) @@ -98,7 +87,7 @@ async def test_property_index_population_from_database(really_needs_auth): # Reopen database and verify property index # Use the same embedding settings to avoid dimension mismatch - embedding_model2 = MockEmbeddingModel() + embedding_model2 = create_test_embedding_model() embedding_settings2 = TextEmbeddingIndexSettings(embedding_model2) message_text_settings2 = MessageTextIndexSettings(embedding_settings2) related_terms_settings2 = RelatedTermIndexSettings(embedding_settings2) diff --git a/tests/test_query_method.py b/tests/test_query_method.py index ac605823..bcbf2e00 100644 --- a/tests/test_query_method.py +++ b/tests/test_query_method.py @@ -6,7 +6,7 @@ import pytest from typeagent import create_conversation -from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME +from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.knowpro.convsettings import ConversationSettings from typeagent.transcripts.transcript import TranscriptMessage, TranscriptMessageMeta @@ -15,7 +15,7 @@ async def test_query_method_basic(really_needs_auth: None): """Test the basic query method workflow.""" # Create a conversation with some test data - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() settings = ConversationSettings(model=test_model) conversation = await create_conversation( None, @@ -60,7 +60,7 @@ async def test_query_method_basic(really_needs_auth: None): @pytest.mark.asyncio async def test_query_method_empty_conversation(really_needs_auth: None): """Test query method on an empty conversation.""" - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() settings = ConversationSettings(model=test_model) conversation = await create_conversation( None, diff --git a/tests/test_related_terms_fast.py b/tests/test_related_terms_fast.py index 3919666f..fbcf60c5 100644 --- a/tests/test_related_terms_fast.py +++ b/tests/test_related_terms_fast.py @@ -9,7 +9,7 @@ import pytest -from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME +from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.knowpro.convsettings import ConversationSettings from typeagent.knowpro.interfaces import SemanticRef, TextLocation, TextRange from typeagent.knowpro.kplib import ConcreteEntity @@ -26,7 +26,7 @@ async def test_related_terms_index_minimal(): try: # Create minimal test data with test embedding model (no API keys needed) - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() settings = ConversationSettings(model=test_model) # Use a simple storage provider without AI embeddings diff --git a/tests/test_related_terms_index_population.py b/tests/test_related_terms_index_population.py index 9d16936f..9de6f015 100644 --- a/tests/test_related_terms_index_population.py +++ b/tests/test_related_terms_index_population.py @@ -10,7 +10,7 @@ from dotenv import load_dotenv import pytest -from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME +from 
typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings from typeagent.knowpro import kplib from typeagent.knowpro.convsettings import ( @@ -32,7 +32,7 @@ async def test_related_terms_index_population_from_database(really_needs_auth): try: # Use the test model that's already configured in the system - embedding_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + embedding_model = create_test_embedding_model() embedding_settings = TextEmbeddingIndexSettings(embedding_model) message_text_settings = MessageTextIndexSettings(embedding_settings) related_terms_settings = RelatedTermIndexSettings(embedding_settings) diff --git a/tests/test_secindex.py b/tests/test_secindex.py index a9008aa3..39665b05 100644 --- a/tests/test_secindex.py +++ b/tests/test_secindex.py @@ -3,7 +3,7 @@ import pytest -from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME +from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings from typeagent.knowpro.convsettings import ( ConversationSettings, @@ -29,9 +29,9 @@ def simple_conversation() -> FakeConversation: @pytest.fixture def conversation_settings(needs_auth: None) -> ConversationSettings: - from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME + from typeagent.aitools.model_adapters import create_test_embedding_model - model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + model = create_test_embedding_model() return ConversationSettings(model) @@ -41,7 +41,7 @@ def test_conversation_secondary_indexes_initialization( """Test initialization of ConversationSecondaryIndexes.""" storage_provider = memory_storage # Create proper settings for testing - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() embedding_settings = TextEmbeddingIndexSettings(test_model) settings = RelatedTermIndexSettings(embedding_settings) indexes = ConversationSecondaryIndexes(storage_provider, settings) diff --git a/tests/test_secindex_storage_integration.py b/tests/test_secindex_storage_integration.py index a050771b..15738bb6 100644 --- a/tests/test_secindex_storage_integration.py +++ b/tests/test_secindex_storage_integration.py @@ -4,7 +4,7 @@ # Test that ConversationSecondaryIndexes now uses storage provider properly import pytest -from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME +from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings from typeagent.knowpro.convsettings import RelatedTermIndexSettings from typeagent.knowpro.secindex import ConversationSecondaryIndexes @@ -19,7 +19,7 @@ async def test_secondary_indexes_use_storage_provider( storage_provider = memory_storage # Create test settings - test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + test_model = create_test_embedding_model() embedding_settings = TextEmbeddingIndexSettings(test_model) related_terms_settings = RelatedTermIndexSettings(embedding_settings) diff --git a/tests/test_transcripts.py b/tests/test_transcripts.py index 0d0bfb57..0286f5bc 100644 --- a/tests/test_transcripts.py +++ b/tests/test_transcripts.py @@ -6,7 +6,8 @@ import pytest -from typeagent.aitools.embeddings import AsyncEmbeddingModel, IEmbeddingModel +from typeagent.aitools.embeddings import IEmbeddingModel +from typeagent.aitools.model_adapters import 
create_test_embedding_model from typeagent.knowpro.convsettings import ConversationSettings from typeagent.knowpro.universal_message import format_timestamp_utc, UNIX_EPOCH from typeagent.transcripts.transcript import ( @@ -224,10 +225,8 @@ def test_transcript_message_creation(): @pytest.mark.asyncio async def test_transcript_creation(): """Test creating an empty transcript.""" - from typeagent.aitools.embeddings import TEST_MODEL_NAME - # Create a minimal transcript for testing structure - embedding_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + embedding_model = create_test_embedding_model() settings = ConversationSettings(embedding_model) transcript = await Transcript.create( diff --git a/tests/test_vectorbase.py b/tests/test_vectorbase.py index be8d2c07..4a92b23e 100644 --- a/tests/test_vectorbase.py +++ b/tests/test_vectorbase.py @@ -5,10 +5,9 @@ import pytest from typeagent.aitools.embeddings import ( - AsyncEmbeddingModel, NormalizedEmbedding, - TEST_MODEL_NAME, ) +from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings, VectorBase @@ -19,9 +18,7 @@ def vector_base() -> VectorBase: def make_vector_base() -> VectorBase: - settings = TextEmbeddingIndexSettings( - AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) - ) + settings = TextEmbeddingIndexSettings(create_test_embedding_model()) return VectorBase(settings) @@ -61,8 +58,8 @@ def test_add_embeddings(vector_base: VectorBase, sample_embeddings: Samples): assert len(bulk_vector_base) == len(vector_base) np.testing.assert_array_equal(bulk_vector_base.serialize(), vector_base.serialize()) - sequential_cache = vector_base._model._embedding_cache # type: ignore[attr-defined] - bulk_cache = bulk_vector_base._model._embedding_cache # type: ignore[attr-defined] + sequential_cache = vector_base._model._cache # type: ignore[attr-defined] + bulk_cache = bulk_vector_base._model._cache # type: ignore[attr-defined] assert set(sequential_cache.keys()) == set(bulk_cache.keys()) for key in keys: np.testing.assert_array_equal(bulk_cache[key], sequential_cache[key]) @@ -85,7 +82,7 @@ async def test_add_key_no_cache(vector_base: VectorBase, sample_embeddings: Samp assert len(vector_base) == len(sample_embeddings) assert ( - vector_base._model._embedding_cache == {} # type: ignore[attr-defined] + vector_base._model._cache == {} # type: ignore[attr-defined] ), "Cache should remain empty when cache=False" @@ -106,7 +103,7 @@ async def test_add_keys_no_cache(vector_base: VectorBase, sample_embeddings: Sam assert len(vector_base) == len(sample_embeddings) assert ( - vector_base._model._embedding_cache == {} # type: ignore[attr-defined] + vector_base._model._cache == {} # type: ignore[attr-defined] ), "Cache should remain empty when cache=False" From dec2e6fe5c2c3f313a35f9b54ce0a0ae63cdeb0b Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 25 Feb 2026 11:46:53 -0800 Subject: [PATCH 14/49] Move in-function imports to module level in tests/ --- tests/test_mcp_server.py | 22 +++-------- tests/test_messageindex.py | 3 +- tests/test_model_adapters.py | 51 ++++++------------------- tests/test_sqlitestore.py | 3 +- tests/test_storage_providers_unified.py | 5 +-- tests/test_transcripts.py | 5 +-- tests/test_utils.py | 7 ++-- 7 files changed, 25 insertions(+), 71 deletions(-) diff --git a/tests/test_mcp_server.py b/tests/test_mcp_server.py index 03fd0e69..24933ca2 100644 --- a/tests/test_mcp_server.py +++ b/tests/test_mcp_server.py @@ -3,16 +3,21 @@ """End-to-end 
tests for the MCP server.""" +import json import os import sys from typing import Any import pytest -from mcp import StdioServerParameters +from mcp import ClientSession, StdioServerParameters from mcp.client.session import ClientSession as ClientSessionType +from mcp.client.stdio import stdio_client from mcp.shared.context import RequestContext from mcp.types import CreateMessageRequestParams, CreateMessageResult, TextContent +from openai.types.chat import ChatCompletionMessageParam + +from typeagent.aitools.utils import create_async_openai_client from conftest import EPISODE_53_INDEX @@ -38,11 +43,6 @@ async def sampling_callback( params: CreateMessageRequestParams, ) -> CreateMessageResult: """Sampling callback that uses OpenAI to generate responses.""" - # Use OpenAI to generate a response - from openai.types.chat import ChatCompletionMessageParam - - from typeagent.aitools.utils import create_async_openai_client - client = create_async_openai_client() # Convert MCP SamplingMessage to OpenAI format @@ -91,9 +91,6 @@ async def test_mcp_server_query_conversation_slow( really_needs_auth, server_params: StdioServerParameters ): """Test the query_conversation tool end-to-end using MCP client.""" - from mcp import ClientSession - from mcp.client.stdio import stdio_client - # Pass through environment variables needed for authentication # otherwise this test will fail in the CI on Windows only if not (server_params.env) is None: @@ -135,8 +132,6 @@ async def test_mcp_server_query_conversation_slow( response_text = content_item.text # Parse response (it should be JSON with success, answer, time_used) - import json - try: response_data = json.loads(response_text) except json.JSONDecodeError as e: @@ -158,9 +153,6 @@ async def test_mcp_server_query_conversation_slow( @pytest.mark.asyncio async def test_mcp_server_empty_question(server_params: StdioServerParameters): """Test the query_conversation tool with an empty question.""" - from mcp import ClientSession - from mcp.client.stdio import stdio_client - # Create client session and connect to server async with stdio_client(server_params) as (read, write): async with ClientSession( @@ -183,8 +175,6 @@ async def test_mcp_server_empty_question(server_params: StdioServerParameters): assert isinstance(content_item, TextContent) response_text = content_item.text - import json - response_data = json.loads(response_text) assert response_data["success"] is False assert "No question provided" in response_data["answer"] diff --git a/tests/test_messageindex.py b/tests/test_messageindex.py index ec80498f..f91ac933 100644 --- a/tests/test_messageindex.py +++ b/tests/test_messageindex.py @@ -4,6 +4,7 @@ from typing import cast from unittest.mock import AsyncMock, MagicMock +import numpy as np import pytest from typeagent.knowpro.convsettings import MessageTextIndexSettings @@ -145,8 +146,6 @@ async def test_lookup_messages_in_subset( @pytest.mark.asyncio async def test_generate_embedding(needs_auth: None): """Test generating an embedding for a message without mocking.""" - import numpy as np - from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings diff --git a/tests/test_model_adapters.py b/tests/test_model_adapters.py index 5e773430..dbff97cd 100644 --- a/tests/test_model_adapters.py +++ b/tests/test_model_adapters.py @@ -1,9 +1,20 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
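The end-to-end pattern the MCP tests above exercise can be sketched in isolation. This is a minimal client, assuming a server exposing the same ``query_conversation`` tool and JSON response shape shown in the tests::

    import json

    from mcp import ClientSession, StdioServerParameters
    from mcp.client.stdio import stdio_client
    from mcp.types import TextContent

    async def ask(params: StdioServerParameters, question: str) -> dict:
        # Connect over stdio, initialize the session, then call the tool.
        async with stdio_client(params) as (read, write):
            async with ClientSession(read, write) as session:
                await session.initialize()
                result = await session.call_tool(
                    "query_conversation", {"question": question}
                )
                item = result.content[0]
                assert isinstance(item, TextContent)
                # The server replies with JSON: {"success": ..., "answer": ...}
                return json.loads(item.text)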
+from unittest.mock import AsyncMock + import numpy as np import pytest +from pydantic_ai import Embedder +from pydantic_ai.embeddings import EmbeddingResult +from pydantic_ai.messages import ( + ModelResponse, + SystemPromptPart, + TextPart, + UserPromptPart, +) +from pydantic_ai.models import Model import typechat from typeagent.aitools.embeddings import IEmbeddingModel, NormalizedEmbedding @@ -57,11 +68,6 @@ def test_chat_model_is_typechat_model() -> None: @pytest.mark.asyncio async def test_chat_adapter_complete() -> None: """PydanticAIChatModel wraps a pydantic_ai Model.""" - from unittest.mock import AsyncMock - - from pydantic_ai.messages import ModelResponse, TextPart - from pydantic_ai.models import Model - mock_model = AsyncMock(spec=Model) mock_model.request.return_value = ModelResponse(parts=[TextPart(content="hello")]) @@ -74,11 +80,6 @@ async def test_chat_adapter_complete() -> None: @pytest.mark.asyncio async def test_chat_adapter_prompt_sections() -> None: """PydanticAIChatModel handles list[PromptSection] prompts.""" - from unittest.mock import AsyncMock - - from pydantic_ai.messages import ModelResponse, TextPart - from pydantic_ai.models import Model - mock_model = AsyncMock(spec=Model) mock_model.request.return_value = ModelResponse( parts=[TextPart(content="response")] @@ -98,8 +99,6 @@ async def test_chat_adapter_prompt_sections() -> None: messages = call_args[0][0] assert len(messages) == 1 request = messages[0] - from pydantic_ai.messages import SystemPromptPart, UserPromptPart - assert isinstance(request.parts[0], SystemPromptPart) assert isinstance(request.parts[1], UserPromptPart) @@ -117,11 +116,6 @@ def test_embedding_model_is_iembedding_model() -> None: @pytest.mark.asyncio async def test_embedding_adapter_single() -> None: """PydanticAIEmbeddingModel computes a single normalized embedding.""" - from unittest.mock import AsyncMock - - from pydantic_ai import Embedder - from pydantic_ai.embeddings import EmbeddingResult - mock_embedder = AsyncMock(spec=Embedder) raw_vec = [3.0, 4.0, 0.0] mock_embedder.embed_documents.return_value = EmbeddingResult( @@ -142,11 +136,6 @@ async def test_embedding_adapter_single() -> None: @pytest.mark.asyncio async def test_embedding_adapter_probes_size() -> None: """embedding_size is discovered from the first embedding call.""" - from unittest.mock import AsyncMock - - from pydantic_ai import Embedder - from pydantic_ai.embeddings import EmbeddingResult - mock_embedder = AsyncMock(spec=Embedder) mock_embedder.embed_documents.return_value = EmbeddingResult( embeddings=[[1.0, 0.0, 0.0]], @@ -165,11 +154,6 @@ async def test_embedding_adapter_probes_size() -> None: @pytest.mark.asyncio async def test_embedding_adapter_batch() -> None: """PydanticAIEmbeddingModel computes batch embeddings.""" - from unittest.mock import AsyncMock - - from pydantic_ai import Embedder - from pydantic_ai.embeddings import EmbeddingResult - mock_embedder = AsyncMock(spec=Embedder) mock_embedder.embed_documents.return_value = EmbeddingResult( embeddings=[[1.0, 0.0], [0.0, 1.0]], @@ -187,11 +171,6 @@ async def test_embedding_adapter_batch() -> None: @pytest.mark.asyncio async def test_embedding_adapter_caching() -> None: """Caching avoids re-computing embeddings.""" - from unittest.mock import AsyncMock - - from pydantic_ai import Embedder - from pydantic_ai.embeddings import EmbeddingResult - mock_embedder = AsyncMock(spec=Embedder) mock_embedder.embed_documents.return_value = EmbeddingResult( embeddings=[[1.0, 0.0, 0.0]], @@ -212,10 +191,6 @@ async def 
test_embedding_adapter_caching() -> None: @pytest.mark.asyncio async def test_embedding_adapter_add_embedding() -> None: """add_embedding() populates the cache.""" - from unittest.mock import AsyncMock - - from pydantic_ai import Embedder - mock_embedder = AsyncMock(spec=Embedder) adapter = PydanticAIEmbeddingModel(mock_embedder, "test-model", 3) vec: NormalizedEmbedding = np.array([1.0, 0.0, 0.0], dtype=np.float32) @@ -229,10 +204,6 @@ async def test_embedding_adapter_add_embedding() -> None: @pytest.mark.asyncio async def test_embedding_adapter_empty_batch() -> None: """Empty batch returns empty array with known size.""" - from unittest.mock import AsyncMock - - from pydantic_ai import Embedder - mock_embedder = AsyncMock(spec=Embedder) adapter = PydanticAIEmbeddingModel(mock_embedder, "test-model", 4) result = await adapter.get_embeddings_nocache([]) diff --git a/tests/test_sqlitestore.py b/tests/test_sqlitestore.py index ad784221..704ab0bb 100644 --- a/tests/test_sqlitestore.py +++ b/tests/test_sqlitestore.py @@ -3,6 +3,7 @@ from collections.abc import AsyncGenerator from dataclasses import field +from datetime import datetime import pytest import pytest_asyncio @@ -128,8 +129,6 @@ async def test_sqlite_timestamp_index( dummy_sqlite_storage_provider: SqliteStorageProvider[DummyMessage], ): """Test SqliteTimestampToTextRangeIndex functionality.""" - from datetime import datetime - from typeagent.knowpro.interfaces import DateRange # Set up database with some messages diff --git a/tests/test_storage_providers_unified.py b/tests/test_storage_providers_unified.py index 14829f03..d0ecb9c5 100644 --- a/tests/test_storage_providers_unified.py +++ b/tests/test_storage_providers_unified.py @@ -9,6 +9,8 @@ """ from dataclasses import field +import os +import tempfile from typing import assert_never, AsyncGenerator import pytest @@ -605,9 +607,6 @@ async def test_storage_provider_independence( ) # Create two sqlite providers (with different temp files) - import os - import tempfile - temp_file1 = tempfile.NamedTemporaryFile(suffix=".sqlite", delete=False) temp_path1 = temp_file1.name temp_file1.close() diff --git a/tests/test_transcripts.py b/tests/test_transcripts.py index 0286f5bc..9d98ae88 100644 --- a/tests/test_transcripts.py +++ b/tests/test_transcripts.py @@ -5,6 +5,7 @@ import os import pytest +import webvtt from typeagent.aitools.embeddings import IEmbeddingModel from typeagent.aitools.model_adapters import create_test_embedding_model @@ -102,8 +103,6 @@ def conversation_settings( @pytest.mark.asyncio async def test_ingest_vtt_transcript(conversation_settings: ConversationSettings): """Test importing a VTT file into a Transcript object.""" - import webvtt - from typeagent.storage.memory.collections import ( MemoryMessageCollection, MemorySemanticRefCollection, @@ -253,8 +252,6 @@ async def test_transcript_knowledge_extraction_slow( 4. 
Verifies both mechanical extraction (entities/actions from metadata) and LLM extraction (topics from content) work correctly """ - import webvtt - from typeagent.storage.memory.collections import ( MemoryMessageCollection, MemorySemanticRefCollection, diff --git a/tests/test_utils.py b/tests/test_utils.py index cb95c93a..ceea367d 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -7,6 +7,9 @@ from dotenv import load_dotenv +import pydantic.dataclasses +import typechat + import typeagent.aitools.utils as utils @@ -37,14 +40,10 @@ def test_load_dotenv(really_needs_auth): def test_create_translator(): - import typechat - class DummyModel(typechat.TypeChatLanguageModel): async def complete(self, *args, **kwargs) -> typechat.Result: return typechat.Failure("dummy response") - import pydantic.dataclasses - @pydantic.dataclasses.dataclass class DummySchema: pass From 909247d7c86da51019ec87d2def57188022f2844 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 25 Feb 2026 12:01:56 -0800 Subject: [PATCH 15/49] Don't re-export create_typechat_model from convknowledge.py --- src/typeagent/emails/email_memory.py | 3 +-- src/typeagent/knowpro/conversation_base.py | 4 ++-- src/typeagent/knowpro/convknowledge.py | 5 +---- src/typeagent/knowpro/knowledge.py | 3 ++- tools/query.py | 3 +-- 5 files changed, 7 insertions(+), 11 deletions(-) diff --git a/src/typeagent/emails/email_memory.py b/src/typeagent/emails/email_memory.py index d6cf06cb..3da149df 100644 --- a/src/typeagent/emails/email_memory.py +++ b/src/typeagent/emails/email_memory.py @@ -12,7 +12,6 @@ from ..knowpro import ( answer_response_schema, answers, - convknowledge, search_query_schema, searchlang, ) @@ -24,7 +23,7 @@ class EmailMemorySettings: def __init__(self, conversation_settings: ConversationSettings) -> None: - self.language_model = convknowledge.create_typechat_model() + self.language_model = utils.create_typechat_model() self.query_translator = utils.create_translator( self.language_model, search_query_schema.SearchQuery ) diff --git a/src/typeagent/knowpro/conversation_base.py b/src/typeagent/knowpro/conversation_base.py index 732e239a..74d34b19 100644 --- a/src/typeagent/knowpro/conversation_base.py +++ b/src/typeagent/knowpro/conversation_base.py @@ -352,12 +352,12 @@ async def query( """ # Create translators lazily (once per conversation instance) if self._query_translator is None: - model = convknowledge.create_typechat_model() + model = utils.create_typechat_model() self._query_translator = utils.create_translator( model, search_query_schema.SearchQuery ) if self._answer_translator is None: - model = convknowledge.create_typechat_model() + model = utils.create_typechat_model() self._answer_translator = utils.create_translator( model, answer_response_schema.AnswerResponse ) diff --git a/src/typeagent/knowpro/convknowledge.py b/src/typeagent/knowpro/convknowledge.py index 53a10b25..6f9000d3 100644 --- a/src/typeagent/knowpro/convknowledge.py +++ b/src/typeagent/knowpro/convknowledge.py @@ -6,10 +6,7 @@ import typechat from . import kplib -from ..aitools.utils import create_typechat_model # Re-export for backward compat - -# Re-export: callers may still do ``convknowledge.create_typechat_model()``. 
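In outline, the call-site migration this patch performs (the re-export above is deleted, so imports move to the defining module)::

    # Before: via the convknowledge re-export being removed here.
    from typeagent.knowpro import convknowledge
    model = convknowledge.create_typechat_model()

    # After: import from the module that actually defines it.
    from typeagent.aitools import utils
    model = utils.create_typechat_model()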
-__all__ = ["create_typechat_model", "KnowledgeExtractor"] +from ..aitools.utils import create_typechat_model @dataclass diff --git a/src/typeagent/knowpro/knowledge.py b/src/typeagent/knowpro/knowledge.py index bda7397f..48d5b8fa 100644 --- a/src/typeagent/knowpro/knowledge.py +++ b/src/typeagent/knowpro/knowledge.py @@ -8,6 +8,7 @@ from typechat import Result, TypeChatLanguageModel from . import convknowledge, kplib +from ..aitools import utils from .interfaces import IKnowledgeExtractor @@ -15,7 +16,7 @@ def create_knowledge_extractor( chat_model: TypeChatLanguageModel | None = None, ) -> convknowledge.KnowledgeExtractor: """Create a knowledge extractor using the given Chat Model.""" - chat_model = chat_model or convknowledge.create_typechat_model() + chat_model = chat_model or utils.create_typechat_model() extractor = convknowledge.KnowledgeExtractor( chat_model, max_chars_per_chunk=4096, merge_action_knowledge=False ) diff --git a/tools/query.py b/tools/query.py index 24b1f4c9..c9817803 100644 --- a/tools/query.py +++ b/tools/query.py @@ -36,7 +36,6 @@ from typeagent.knowpro import ( answer_response_schema, answers, - convknowledge, kplib, query, search, @@ -576,7 +575,7 @@ async def main(): "Error: non-empty --search-results required for batch mode." ) - model = convknowledge.create_typechat_model() + model = utils.create_typechat_model() query_translator = utils.create_translator(model, search_query_schema.SearchQuery) if args.alt_schema: if args.verbose: From 8807cc57e417f64983944cded8cac391adaaebaa Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 25 Feb 2026 12:12:19 -0800 Subject: [PATCH 16/49] Remove redundant tests that Chat/Embedding models subclass their protocols --- tests/test_model_adapters.py | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/tests/test_model_adapters.py b/tests/test_model_adapters.py index dbff97cd..c7ccf8f9 100644 --- a/tests/test_model_adapters.py +++ b/tests/test_model_adapters.py @@ -17,7 +17,7 @@ from pydantic_ai.models import Model import typechat -from typeagent.aitools.embeddings import IEmbeddingModel, NormalizedEmbedding +from typeagent.aitools.embeddings import NormalizedEmbedding from typeagent.aitools.model_adapters import ( configure_models, create_chat_model, @@ -60,11 +60,6 @@ def test_default_embedding_size_is_zero() -> None: # --------------------------------------------------------------------------- -def test_chat_model_is_typechat_model() -> None: - """PydanticAIChatModel inherits from TypeChatLanguageModel.""" - assert typechat.TypeChatLanguageModel in PydanticAIChatModel.__mro__ - - @pytest.mark.asyncio async def test_chat_adapter_complete() -> None: """PydanticAIChatModel wraps a pydantic_ai Model.""" @@ -108,11 +103,6 @@ async def test_chat_adapter_prompt_sections() -> None: # --------------------------------------------------------------------------- -def test_embedding_model_is_iembedding_model() -> None: - """PydanticAIEmbeddingModel inherits from IEmbeddingModel.""" - assert IEmbeddingModel in PydanticAIEmbeddingModel.__mro__ - - @pytest.mark.asyncio async def test_embedding_adapter_single() -> None: """PydanticAIEmbeddingModel computes a single normalized embedding.""" @@ -223,4 +213,3 @@ def test_configure_models_returns_correct_types( chat, embedder = configure_models("openai:gpt-4o", "openai:text-embedding-3-small") assert isinstance(chat, PydanticAIChatModel) assert isinstance(embedder, PydanticAIEmbeddingModel) - assert typechat.TypeChatLanguageModel in type(chat).__mro__ From 
68d3082faafaa388bf2df2da3c47ddc1ee5fa730 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 25 Feb 2026 12:28:45 -0800 Subject: [PATCH 17/49] Avoid type-ignore in favor of isinstance --- tests/test_vectorbase.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/tests/test_vectorbase.py b/tests/test_vectorbase.py index 4a92b23e..04c53c36 100644 --- a/tests/test_vectorbase.py +++ b/tests/test_vectorbase.py @@ -7,7 +7,10 @@ from typeagent.aitools.embeddings import ( NormalizedEmbedding, ) -from typeagent.aitools.model_adapters import create_test_embedding_model +from typeagent.aitools.model_adapters import ( + create_test_embedding_model, + PydanticAIEmbeddingModel, +) from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings, VectorBase @@ -58,8 +61,10 @@ def test_add_embeddings(vector_base: VectorBase, sample_embeddings: Samples): assert len(bulk_vector_base) == len(vector_base) np.testing.assert_array_equal(bulk_vector_base.serialize(), vector_base.serialize()) - sequential_cache = vector_base._model._cache # type: ignore[attr-defined] - bulk_cache = bulk_vector_base._model._cache # type: ignore[attr-defined] + assert isinstance(vector_base._model, PydanticAIEmbeddingModel) + assert isinstance(bulk_vector_base._model, PydanticAIEmbeddingModel) + sequential_cache = vector_base._model._cache + bulk_cache = bulk_vector_base._model._cache assert set(sequential_cache.keys()) == set(bulk_cache.keys()) for key in keys: np.testing.assert_array_equal(bulk_cache[key], sequential_cache[key]) @@ -81,9 +86,8 @@ async def test_add_key_no_cache(vector_base: VectorBase, sample_embeddings: Samp await vector_base.add_key(key, cache=False) assert len(vector_base) == len(sample_embeddings) - assert ( - vector_base._model._cache == {} # type: ignore[attr-defined] - ), "Cache should remain empty when cache=False" + assert isinstance(vector_base._model, PydanticAIEmbeddingModel) + assert vector_base._model._cache == {}, "Cache should remain empty when cache=False" @pytest.mark.asyncio @@ -102,9 +106,8 @@ async def test_add_keys_no_cache(vector_base: VectorBase, sample_embeddings: Sam await vector_base.add_keys(keys, cache=False) assert len(vector_base) == len(sample_embeddings) - assert ( - vector_base._model._cache == {} # type: ignore[attr-defined] - ), "Cache should remain empty when cache=False" + assert isinstance(vector_base._model, PydanticAIEmbeddingModel) + assert vector_base._model._cache == {}, "Cache should remain empty when cache=False" @pytest.mark.asyncio From 3697f89a043ddb5bd28b04d44435cf0513ad9179 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 25 Feb 2026 15:39:37 -0800 Subject: [PATCH 18/49] Remove ModelWrapper, create_typechat_model; use create_chat_model everywhere --- AGENTS.md | 3 ++ src/typeagent/aitools/model_adapters.py | 43 +++++++++++++--- src/typeagent/aitools/utils.py | 59 ---------------------- src/typeagent/emails/email_memory.py | 4 +- src/typeagent/knowpro/conversation_base.py | 6 +-- src/typeagent/knowpro/convknowledge.py | 4 +- src/typeagent/knowpro/knowledge.py | 4 +- tools/query.py | 4 +- 8 files changed, 49 insertions(+), 78 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index f08200b2..dafd9da6 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -93,3 +93,6 @@ please follow these guidelines: * **Code Validation**: Don't use `py_compile` for syntax checking. Use `pyright` or `make check` instead for proper type checking and validation. 
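The ``isinstance``-narrowing pattern adopted in ``tests/test_vectorbase.py`` above, shown in isolation (names as they exist at this point in the series)::

    # The assert narrows the static type, so private-attribute access
    # type-checks without a ``# type: ignore[attr-defined]`` comment.
    model = vector_base._model
    assert isinstance(model, PydanticAIEmbeddingModel)
    cache = model._cache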
+ +* **Deprecations**: Don't deprecate things -- just delete them and fix the usage sites. + Don't create backward compatibility APIs or exports or whatever. Fix the usage sites. diff --git a/src/typeagent/aitools/model_adapters.py b/src/typeagent/aitools/model_adapters.py index 4acc0ff8..7f73d2a5 100644 --- a/src/typeagent/aitools/model_adapters.py +++ b/src/typeagent/aitools/model_adapters.py @@ -198,20 +198,34 @@ def _make_azure_provider( may contain a full Azure deployment URL (including path and ``api-version`` query parameter) — the same format used throughout this codebase. + + When ``AZURE_OPENAI_API_KEY`` is set to ``"identity"``, the client + uses Azure Managed Identity via a token provider callback, which + refreshes tokens automatically before each request. """ from openai import AsyncAzureOpenAI from pydantic_ai.providers.azure import AzureProvider - from .utils import get_azure_api_key, parse_azure_endpoint + from .utils import parse_azure_endpoint raw_key = os.environ[api_key_envvar] - api_key = get_azure_api_key(raw_key) azure_endpoint, api_version = parse_azure_endpoint(endpoint_envvar) - client = AsyncAzureOpenAI( - azure_endpoint=azure_endpoint, - api_version=api_version, - api_key=api_key, - ) + + if raw_key.lower() == "identity": + from .auth import get_shared_token_provider + + token_provider = get_shared_token_provider() + client = AsyncAzureOpenAI( + azure_endpoint=azure_endpoint, + api_version=api_version, + azure_ad_token_provider=token_provider.get_token, + ) + else: + client = AsyncAzureOpenAI( + azure_endpoint=azure_endpoint, + api_version=api_version, + api_key=raw_key, + ) return AzureProvider(openai_client=client) @@ -220,8 +234,11 @@ def _make_azure_provider( # --------------------------------------------------------------------------- +DEFAULT_CHAT_SPEC = "openai:gpt-4o" + + def create_chat_model( - model_spec: str, + model_spec: str | None = None, ) -> PydanticAIChatModel: """Create a chat model from a ``provider:model`` spec. @@ -229,12 +246,22 @@ def create_chat_model( If the spec uses ``openai:`` and ``OPENAI_API_KEY`` is not set but ``AZURE_OPENAI_API_KEY`` is, Azure OpenAI is used automatically. + If *model_spec* is ``None``, it is constructed from the ``OPENAI_MODEL`` + environment variable (falling back to :data:`DEFAULT_CHAT_SPEC`). + Examples:: + model = create_chat_model() # uses OPENAI_MODEL or gpt-4o model = create_chat_model("openai:gpt-4o") model = create_chat_model("anthropic:claude-sonnet-4-20250514") model = create_chat_model("google:gemini-2.0-flash") """ + if model_spec is None: + openai_model = os.getenv("OPENAI_MODEL") + if openai_model: + model_spec = f"openai:{openai_model}" + else: + model_spec = DEFAULT_CHAT_SPEC provider, _, model_name = model_spec.partition(":") if _needs_azure_fallback(provider): from pydantic_ai.models.openai import OpenAIChatModel diff --git a/src/typeagent/aitools/utils.py b/src/typeagent/aitools/utils.py index 9d57b854..b9150334 100644 --- a/src/typeagent/aitools/utils.py +++ b/src/typeagent/aitools/utils.py @@ -3,7 +3,6 @@ """Utilities that are hard to fit in any specific module.""" -import asyncio from contextlib import contextmanager import difflib import os @@ -17,8 +16,6 @@ import typechat -from .auth import AzureTokenProvider, get_shared_token_provider - @contextmanager def timelog(label: str, verbose: bool = True): @@ -90,62 +87,6 @@ def create_translator[T]( return translator -# TODO: Make these parameters that can be configured (e.g. from command line). 
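A sketch of how the new default-spec resolution in ``create_chat_model`` behaves (assumes provider credentials are configured; the ``OPENAI_MODEL`` value is hypothetical)::

    import os

    from typeagent.aitools.model_adapters import create_chat_model

    os.environ["OPENAI_MODEL"] = "gpt-4o-mini"
    chat = create_chat_model()                  # resolves to "openai:gpt-4o-mini"
    chat = create_chat_model("openai:gpt-4o")   # an explicit spec always wins

    os.environ.pop("OPENAI_MODEL")
    chat = create_chat_model()                  # falls back to DEFAULT_CHAT_SPEC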
-DEFAULT_MAX_RETRY_ATTEMPTS = 0 -DEFAULT_TIMEOUT_SECONDS = 25 - - -class ModelWrapper(typechat.TypeChatLanguageModel): - """Wraps a TypeChat model to handle Azure token refresh.""" - - def __init__( - self, - base_model: typechat.TypeChatLanguageModel, - token_provider: AzureTokenProvider, - ): - self.base_model = base_model - self.token_provider = token_provider - - async def complete( - self, prompt: str | list[typechat.PromptSection] - ) -> typechat.Result[str]: - if self.token_provider.needs_refresh(): - loop = asyncio.get_running_loop() - api_key = await loop.run_in_executor( - None, self.token_provider.refresh_token - ) - env: dict[str, str | None] = dict(os.environ) - key_name = "AZURE_OPENAI_API_KEY" - env[key_name] = api_key - self.base_model = typechat.create_language_model(env) - self.base_model.timeout_seconds = DEFAULT_TIMEOUT_SECONDS - return await self.base_model.complete(prompt) - - -def create_typechat_model() -> typechat.TypeChatLanguageModel: - """Create a TypeChat language model using OpenAI or Azure OpenAI. - - Auto-detects the provider from ``OPENAI_API_KEY`` / ``AZURE_OPENAI_API_KEY`` - environment variables. - - For explicit provider selection, use :func:`model_adapters.create_chat_model` - with a spec string like ``"openai:gpt-4o"`` or ``"azure:my-deployment"``. - """ - env: dict[str, str | None] = dict(os.environ) - key_name = "AZURE_OPENAI_API_KEY" - key = env.get(key_name) - shared_token_provider: AzureTokenProvider | None = None - if key is not None and key.lower() == "identity": - shared_token_provider = get_shared_token_provider() - env[key_name] = shared_token_provider.get_token() - model = typechat.create_language_model(env) - model.timeout_seconds = DEFAULT_TIMEOUT_SECONDS - model.max_retry_attempts = DEFAULT_MAX_RETRY_ATTEMPTS - if shared_token_provider is not None: - model = ModelWrapper(model, shared_token_provider) - return model - - # Vibe-coded by o4-mini-high def list_diff(label_a, a, label_b, b, max_items): """Print colorized diff between two sorted list of numbers.""" diff --git a/src/typeagent/emails/email_memory.py b/src/typeagent/emails/email_memory.py index 3da149df..6dd50cc4 100644 --- a/src/typeagent/emails/email_memory.py +++ b/src/typeagent/emails/email_memory.py @@ -8,7 +8,7 @@ import typechat -from ..aitools import utils +from ..aitools import model_adapters, utils from ..knowpro import ( answer_response_schema, answers, @@ -23,7 +23,7 @@ class EmailMemorySettings: def __init__(self, conversation_settings: ConversationSettings) -> None: - self.language_model = utils.create_typechat_model() + self.language_model = model_adapters.create_chat_model() self.query_translator = utils.create_translator( self.language_model, search_query_schema.SearchQuery ) diff --git a/src/typeagent/knowpro/conversation_base.py b/src/typeagent/knowpro/conversation_base.py index 74d34b19..07ea1553 100644 --- a/src/typeagent/knowpro/conversation_base.py +++ b/src/typeagent/knowpro/conversation_base.py @@ -18,7 +18,7 @@ searchlang, secindex, ) -from ..aitools import utils +from ..aitools import model_adapters, utils from ..storage.memory import semrefindex from .convsettings import ConversationSettings from .interfaces import ( @@ -352,12 +352,12 @@ async def query( """ # Create translators lazily (once per conversation instance) if self._query_translator is None: - model = utils.create_typechat_model() + model = model_adapters.create_chat_model() self._query_translator = utils.create_translator( model, search_query_schema.SearchQuery ) if self._answer_translator 
is None: - model = utils.create_typechat_model() + model = model_adapters.create_chat_model() self._answer_translator = utils.create_translator( model, answer_response_schema.AnswerResponse ) diff --git a/src/typeagent/knowpro/convknowledge.py b/src/typeagent/knowpro/convknowledge.py index 6f9000d3..fe1d5f5c 100644 --- a/src/typeagent/knowpro/convknowledge.py +++ b/src/typeagent/knowpro/convknowledge.py @@ -6,12 +6,12 @@ import typechat from . import kplib -from ..aitools.utils import create_typechat_model +from ..aitools.model_adapters import create_chat_model @dataclass class KnowledgeExtractor: - model: typechat.TypeChatLanguageModel = field(default_factory=create_typechat_model) + model: typechat.TypeChatLanguageModel = field(default_factory=create_chat_model) max_chars_per_chunk: int = 2048 merge_action_knowledge: bool = ( False # TODO: Implement merge_action_knowledge_into_response diff --git a/src/typeagent/knowpro/knowledge.py b/src/typeagent/knowpro/knowledge.py index 48d5b8fa..dbcfe206 100644 --- a/src/typeagent/knowpro/knowledge.py +++ b/src/typeagent/knowpro/knowledge.py @@ -8,7 +8,7 @@ from typechat import Result, TypeChatLanguageModel from . import convknowledge, kplib -from ..aitools import utils +from ..aitools import model_adapters from .interfaces import IKnowledgeExtractor @@ -16,7 +16,7 @@ def create_knowledge_extractor( chat_model: TypeChatLanguageModel | None = None, ) -> convknowledge.KnowledgeExtractor: """Create a knowledge extractor using the given Chat Model.""" - chat_model = chat_model or utils.create_typechat_model() + chat_model = chat_model or model_adapters.create_chat_model() extractor = convknowledge.KnowledgeExtractor( chat_model, max_chars_per_chunk=4096, merge_action_knowledge=False ) diff --git a/tools/query.py b/tools/query.py index c9817803..b3a35091 100644 --- a/tools/query.py +++ b/tools/query.py @@ -32,7 +32,7 @@ import typechat -from typeagent.aitools import embeddings, utils +from typeagent.aitools import embeddings, model_adapters, utils from typeagent.knowpro import ( answer_response_schema, answers, @@ -575,7 +575,7 @@ async def main(): "Error: non-empty --search-results required for batch mode." ) - model = utils.create_typechat_model() + model = model_adapters.create_chat_model() query_translator = utils.create_translator(model, search_query_schema.SearchQuery) if args.alt_schema: if args.verbose: From 087b7a3f6727d0be8f411ab35458360e1cedcfab Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 25 Feb 2026 16:30:06 -0800 Subject: [PATCH 19/49] Split up *EmbeddingModel into IEmbedder and CachingEmbeddingModel --- src/typeagent/aitools/embeddings.py | 90 +++++++++++++++++++++++-- src/typeagent/aitools/model_adapters.py | 61 +++++++---------- src/typeagent/aitools/vectorbase.py | 1 - tests/test_embeddings.py | 35 +++++----- tests/test_model_adapters.py | 29 ++++---- tests/test_vectorbase.py | 10 +-- 6 files changed, 146 insertions(+), 80 deletions(-) diff --git a/src/typeagent/aitools/embeddings.py b/src/typeagent/aitools/embeddings.py index f56db25d..e0e04b7a 100644 --- a/src/typeagent/aitools/embeddings.py +++ b/src/typeagent/aitools/embeddings.py @@ -11,16 +11,46 @@ @runtime_checkable -class IEmbeddingModel(Protocol): - """Provider-agnostic interface for embedding models. +class IEmbedder(Protocol): + """Minimal provider interface for embedding models. Implement this protocol to add support for a new embedding provider - (e.g. Anthropic, Gemini, local models). 
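Since providers now implement only the minimal ``IEmbedder`` surface, a custom provider can be very small. A toy sketch (hash-based vectors, for illustration only, not a real model; caching is layered on via ``CachingEmbeddingModel`` rather than reimplemented per provider)::

    import hashlib

    import numpy as np

    from typeagent.aitools.embeddings import (
        CachingEmbeddingModel,
        NormalizedEmbedding,
        NormalizedEmbeddings,
    )

    class ToyEmbedder:
        """Deterministic fake embedder satisfying IEmbedder."""

        model_name = "toy"
        embedding_size = 4

        async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding:
            # Derive 4 pseudo-random components from a digest, then normalize.
            digest = hashlib.sha256(input.encode()).digest()
            vec = np.frombuffer(digest[:4], dtype=np.uint8).astype(np.float32) + 1.0
            return vec / np.linalg.norm(vec)

        async def get_embeddings_nocache(
            self, input: list[str]
        ) -> NormalizedEmbeddings:
            embs = [await self.get_embedding_nocache(s) for s in input]
            return np.array(embs, dtype=np.float32)

    model = CachingEmbeddingModel(ToyEmbedder())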
The production implementation - is :class:`~typeagent.aitools.model_adapters.PydanticAIEmbeddingModel`. + (e.g. Anthropic, Gemini, local models). Only raw embedding computation + is required; caching is handled by :class:`CachingEmbeddingModel`. + + The production implementation is + :class:`~typeagent.aitools.model_adapters.PydanticAIEmbedder`. + """ + + @property + def model_name(self) -> str: ... + + @property + def embedding_size(self) -> int: ... + + async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding: + """Compute a single embedding without caching.""" + ... + + async def get_embeddings_nocache(self, input: list[str]) -> NormalizedEmbeddings: + """Compute embeddings for a batch of strings without caching.""" + ... + + +@runtime_checkable +class IEmbeddingModel(Protocol): + """Consumer-facing interface for embedding models with caching. + + This extends the provider interface (:class:`IEmbedder`) with caching + methods. Use :class:`CachingEmbeddingModel` to wrap an :class:`IEmbedder` + and get a ready-to-use ``IEmbeddingModel``. """ - model_name: str - embedding_size: int + @property + def model_name(self) -> str: ... + + @property + def embedding_size(self) -> int: ... def add_embedding(self, key: str, embedding: NormalizedEmbedding) -> None: """Cache an already-computed embedding under the given key.""" @@ -43,6 +73,54 @@ async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings: ... +class CachingEmbeddingModel: + """Wraps an :class:`IEmbedder` with an in-memory embedding cache. + + This shared base class implements the caching logic once, so individual + embedding providers only need to implement the minimal :class:`IEmbedder` + protocol (``get_embedding_nocache`` / ``get_embeddings_nocache``). + """ + + def __init__(self, embedder: IEmbedder) -> None: + self._embedder = embedder + self._cache: dict[str, NormalizedEmbedding] = {} + + @property + def model_name(self) -> str: + return self._embedder.model_name + + @property + def embedding_size(self) -> int: + return self._embedder.embedding_size + + def add_embedding(self, key: str, embedding: NormalizedEmbedding) -> None: + self._cache[key] = embedding + + async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding: + return await self._embedder.get_embedding_nocache(input) + + async def get_embeddings_nocache(self, input: list[str]) -> NormalizedEmbeddings: + return await self._embedder.get_embeddings_nocache(input) + + async def get_embedding(self, key: str) -> NormalizedEmbedding: + cached = self._cache.get(key) + if cached is not None: + return cached + embedding = await self._embedder.get_embedding_nocache(key) + self._cache[key] = embedding + return embedding + + async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings: + if not keys: + return await self._embedder.get_embeddings_nocache([]) + missing_keys = [k for k in keys if k not in self._cache] + if missing_keys: + fresh = await self._embedder.get_embeddings_nocache(missing_keys) + for i, k in enumerate(missing_keys): + self._cache[k] = fresh[i] + return np.array([self._cache[k] for k in keys], dtype=np.float32) + + DEFAULT_MODEL_NAME = "text-embedding-ada-002" DEFAULT_EMBEDDING_SIZE = 1536 # Default embedding size (required for ada-002) DEFAULT_ENVVAR = "AZURE_OPENAI_ENDPOINT_EMBEDDING" # We support OpenAI and Azure OpenAI diff --git a/src/typeagent/aitools/model_adapters.py b/src/typeagent/aitools/model_adapters.py index 7f73d2a5..6ccf8292 100644 --- a/src/typeagent/aitools/model_adapters.py +++ 
b/src/typeagent/aitools/model_adapters.py @@ -46,7 +46,11 @@ from pydantic_ai.models import infer_model, Model, ModelRequestParameters import typechat -from .embeddings import IEmbeddingModel, NormalizedEmbedding, NormalizedEmbeddings +from .embeddings import ( + CachingEmbeddingModel, + NormalizedEmbedding, + NormalizedEmbeddings, +) # --------------------------------------------------------------------------- # Chat model adapter @@ -92,13 +96,13 @@ async def complete( # --------------------------------------------------------------------------- -class PydanticAIEmbeddingModel(IEmbeddingModel): - """Adapter from :class:`pydantic_ai.Embedder` to :class:`IEmbeddingModel`. +class PydanticAIEmbedder: + """Adapter from :class:`pydantic_ai.Embedder` to :class:`IEmbedder`. This lets any pydantic_ai embedding provider (OpenAI, Cohere, Google, …) - be used wherever the codebase expects an ``IEmbeddingModel``, including - :class:`~typeagent.aitools.vectorbase.VectorBase` and - :class:`~typeagent.knowpro.convsettings.ConversationSettings`. + be used wherever the codebase expects an ``IEmbedder``. Wrap in + :class:`~typeagent.aitools.embeddings.CachingEmbeddingModel` to get a + ready-to-use ``IEmbeddingModel`` with caching. If *embedding_size* is not given, it is probed automatically by making a single embedding call. @@ -116,10 +120,6 @@ def __init__( self._embedder = embedder self.model_name = model_name self.embedding_size = embedding_size - self._cache: dict[str, NormalizedEmbedding] = {} - - def add_embedding(self, key: str, embedding: NormalizedEmbedding) -> None: - self._cache[key] = embedding async def _probe_embedding_size(self) -> None: """Discover embedding_size by making a single API call.""" @@ -152,26 +152,6 @@ async def get_embeddings_nocache(self, input: list[str]) -> NormalizedEmbeddings embeddings = (embeddings / norms).astype(np.float32) return embeddings - async def get_embedding(self, key: str) -> NormalizedEmbedding: - cached = self._cache.get(key) - if cached is not None: - return cached - embedding = await self.get_embedding_nocache(key) - self._cache[key] = embedding - return embedding - - async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings: - if not keys: - if self.embedding_size == 0: - await self._probe_embedding_size() - return np.empty((0, self.embedding_size), dtype=np.float32) - missing_keys = [k for k in keys if k not in self._cache] - if missing_keys: - fresh = await self.get_embeddings_nocache(missing_keys) - for i, k in enumerate(missing_keys): - self._cache[k] = fresh[i] - return np.array([self._cache[k] for k in keys], dtype=np.float32) - # --------------------------------------------------------------------------- # Provider auto-detection @@ -279,7 +259,7 @@ def create_embedding_model( model_spec: str | None = None, *, embedding_size: int = 0, -) -> PydanticAIEmbeddingModel: +) -> CachingEmbeddingModel: """Create an embedding model from a ``provider:model`` spec. Delegates to :class:`pydantic_ai.Embedder` for provider wiring. @@ -290,6 +270,9 @@ def create_embedding_model( If *embedding_size* is not given, it will be probed automatically on the first embedding call. + Returns a :class:`~typeagent.aitools.embeddings.CachingEmbeddingModel` + wrapping a :class:`PydanticAIEmbedder`. 
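A quick illustration of the returned composition (credentials assumed configured; the point is which calls reach the provider)::

    model = create_embedding_model("openai:text-embedding-3-small")

    e = await model.get_embedding("hello")                   # one provider call
    batch = await model.get_embeddings(["hello", "world"])   # only "world" is computed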
+ Examples:: model = create_embedding_model("openai:text-embedding-3-small") @@ -324,7 +307,9 @@ def create_embedding_model( embedder = _PydanticAIEmbedder(embedding_model) else: embedder = _PydanticAIEmbedder(model_spec) - return PydanticAIEmbeddingModel(embedder, model_name, embedding_size) + return CachingEmbeddingModel( + PydanticAIEmbedder(embedder, model_name, embedding_size) + ) # --------------------------------------------------------------------------- @@ -400,12 +385,14 @@ async def embed( def create_test_embedding_model( embedding_size: int = 3, -) -> PydanticAIEmbeddingModel: - """Create a :class:`PydanticAIEmbeddingModel` with deterministic fake +) -> CachingEmbeddingModel: + """Create a :class:`CachingEmbeddingModel` with deterministic fake embeddings for testing. No API keys or network access required.""" fake_model = _FakePydanticAIEmbeddingModel(embedding_size) - embedder = _PydanticAIEmbedder(fake_model) - return PydanticAIEmbeddingModel(embedder, "test", embedding_size) + pydantic_embedder = _PydanticAIEmbedder(fake_model) + return CachingEmbeddingModel( + PydanticAIEmbedder(pydantic_embedder, "test", embedding_size) + ) def configure_models( @@ -413,7 +400,7 @@ def configure_models( embedding_model_spec: str, *, embedding_size: int = 0, -) -> tuple[PydanticAIChatModel, PydanticAIEmbeddingModel]: +) -> tuple[PydanticAIChatModel, CachingEmbeddingModel]: """Configure both a chat model and an embedding model at once. Delegates to pydantic_ai's model registry for provider wiring. diff --git a/src/typeagent/aitools/vectorbase.py b/src/typeagent/aitools/vectorbase.py index df93997f..076a3476 100644 --- a/src/typeagent/aitools/vectorbase.py +++ b/src/typeagent/aitools/vectorbase.py @@ -175,7 +175,6 @@ def _set_embedding_size(self, size: int) -> None: """Adopt *size* when it was not known at construction time.""" assert size > 0 self._embedding_size = size - self._model.embedding_size = size self.settings.embedding_size = size def clear(self) -> None: diff --git a/tests/test_embeddings.py b/tests/test_embeddings.py index 2ae09845..94f17e7a 100644 --- a/tests/test_embeddings.py +++ b/tests/test_embeddings.py @@ -5,8 +5,7 @@ import pytest from pytest_mock import MockerFixture -from typeagent.aitools.embeddings import IEmbeddingModel -from typeagent.aitools.model_adapters import PydanticAIEmbeddingModel +from typeagent.aitools.embeddings import CachingEmbeddingModel, IEmbeddingModel from conftest import ( embedding_model, # type: ignore # Magic, prevents side effects of mocking @@ -14,7 +13,7 @@ @pytest.mark.asyncio -async def test_get_embedding_nocache(embedding_model: PydanticAIEmbeddingModel): +async def test_get_embedding_nocache(embedding_model: CachingEmbeddingModel): """Test retrieving an embedding without using the cache.""" input_text = "Hello, world" embedding = await embedding_model.get_embedding_nocache(input_text) @@ -25,7 +24,7 @@ async def test_get_embedding_nocache(embedding_model: PydanticAIEmbeddingModel): @pytest.mark.asyncio -async def test_get_embeddings_nocache(embedding_model: PydanticAIEmbeddingModel): +async def test_get_embeddings_nocache(embedding_model: CachingEmbeddingModel): """Test retrieving multiple embeddings without using the cache.""" inputs = ["Hello, world", "Foo bar baz"] embeddings = await embedding_model.get_embeddings_nocache(inputs) @@ -37,7 +36,7 @@ async def test_get_embeddings_nocache(embedding_model: PydanticAIEmbeddingModel) @pytest.mark.asyncio async def test_get_embedding_with_cache( - embedding_model: PydanticAIEmbeddingModel, 
mocker: MockerFixture + embedding_model: CachingEmbeddingModel, mocker: MockerFixture ): """Test retrieving an embedding with caching.""" input_text = "Hello, world" @@ -46,9 +45,9 @@ async def test_get_embedding_with_cache( embedding1 = await embedding_model.get_embedding(input_text) assert input_text in embedding_model._cache - # Mock the nocache method to ensure it's not called + # Mock the nocache method on the underlying embedder to ensure it's not called mock_get_embedding_nocache = mocker.patch.object( - embedding_model, "get_embedding_nocache", autospec=True + embedding_model._embedder, "get_embedding_nocache", autospec=True ) # Second call should retrieve from the cache @@ -61,7 +60,7 @@ async def test_get_embedding_with_cache( @pytest.mark.asyncio async def test_get_embeddings_with_cache( - embedding_model: PydanticAIEmbeddingModel, mocker: MockerFixture + embedding_model: CachingEmbeddingModel, mocker: MockerFixture ): """Test retrieving multiple embeddings with caching.""" inputs = ["Hello, world", "Foo bar baz"] @@ -71,9 +70,9 @@ async def test_get_embeddings_with_cache( for input_text in inputs: assert input_text in embedding_model._cache - # Mock the nocache method to ensure it's not called + # Mock the nocache method on the underlying embedder to ensure it's not called mock_get_embeddings_nocache = mocker.patch.object( - embedding_model, "get_embeddings_nocache", autospec=True + embedding_model._embedder, "get_embeddings_nocache", autospec=True ) # Second call should retrieve from the cache @@ -85,7 +84,7 @@ async def test_get_embeddings_with_cache( @pytest.mark.asyncio -async def test_get_embeddings_empty_input(embedding_model: PydanticAIEmbeddingModel): +async def test_get_embeddings_empty_input(embedding_model: CachingEmbeddingModel): """Test retrieving embeddings for an empty input list.""" inputs: list[str] = [] embeddings = await embedding_model.get_embeddings(inputs) @@ -96,7 +95,7 @@ async def test_get_embeddings_empty_input(embedding_model: PydanticAIEmbeddingMo @pytest.mark.asyncio -async def test_add_embedding_to_cache(embedding_model: PydanticAIEmbeddingModel): +async def test_add_embedding_to_cache(embedding_model: CachingEmbeddingModel): """Test adding an embedding to the cache.""" key = "test_key" embedding = np.array([0.1, 0.2, 0.3], dtype=np.float32) @@ -108,7 +107,7 @@ async def test_add_embedding_to_cache(embedding_model: PydanticAIEmbeddingModel) @pytest.mark.asyncio async def test_get_embedding_nocache_empty_input( - embedding_model: PydanticAIEmbeddingModel, + embedding_model: CachingEmbeddingModel, ): """Test retrieving an embedding with no cache for an empty input.""" with pytest.raises(ValueError, match="Empty input text"): @@ -116,7 +115,7 @@ async def test_get_embedding_nocache_empty_input( @pytest.mark.asyncio -async def test_embeddings_are_normalized(embedding_model: PydanticAIEmbeddingModel): +async def test_embeddings_are_normalized(embedding_model: CachingEmbeddingModel): """Test that returned embeddings are unit-normalized.""" inputs = ["Hello, world", "Foo bar baz", "Testing normalization"] embeddings = await embedding_model.get_embeddings_nocache(inputs) @@ -128,7 +127,7 @@ async def test_embeddings_are_normalized(embedding_model: PydanticAIEmbeddingMod @pytest.mark.asyncio async def test_embeddings_are_deterministic( - embedding_model: PydanticAIEmbeddingModel, + embedding_model: CachingEmbeddingModel, ): """Test that the same input always produces the same embedding.""" input_text = "Deterministic test" @@ -139,7 +138,7 @@ async def 
test_embeddings_are_deterministic( @pytest.mark.asyncio async def test_different_inputs_produce_different_embeddings( - embedding_model: PydanticAIEmbeddingModel, + embedding_model: CachingEmbeddingModel, ): """Test that different inputs produce different embeddings.""" e1 = await embedding_model.get_embedding_nocache("Hello") @@ -149,7 +148,7 @@ async def test_different_inputs_produce_different_embeddings( @pytest.mark.asyncio async def test_implements_iembedding_model( - embedding_model: PydanticAIEmbeddingModel, + embedding_model: CachingEmbeddingModel, ): - """Test that PydanticAIEmbeddingModel satisfies the IEmbeddingModel protocol.""" + """Test that CachingEmbeddingModel satisfies the IEmbeddingModel protocol.""" assert isinstance(embedding_model, IEmbeddingModel) diff --git a/tests/test_model_adapters.py b/tests/test_model_adapters.py index c7ccf8f9..f089c4ba 100644 --- a/tests/test_model_adapters.py +++ b/tests/test_model_adapters.py @@ -17,13 +17,13 @@ from pydantic_ai.models import Model import typechat -from typeagent.aitools.embeddings import NormalizedEmbedding +from typeagent.aitools.embeddings import CachingEmbeddingModel, NormalizedEmbedding from typeagent.aitools.model_adapters import ( configure_models, create_chat_model, create_embedding_model, PydanticAIChatModel, - PydanticAIEmbeddingModel, + PydanticAIEmbedder, ) # --------------------------------------------------------------------------- @@ -99,13 +99,13 @@ async def test_chat_adapter_prompt_sections() -> None: # --------------------------------------------------------------------------- -# PydanticAIEmbeddingModel adapter +# PydanticAIEmbedder adapter # --------------------------------------------------------------------------- @pytest.mark.asyncio async def test_embedding_adapter_single() -> None: - """PydanticAIEmbeddingModel computes a single normalized embedding.""" + """PydanticAIEmbedder computes a single normalized embedding.""" mock_embedder = AsyncMock(spec=Embedder) raw_vec = [3.0, 4.0, 0.0] mock_embedder.embed_documents.return_value = EmbeddingResult( @@ -116,7 +116,7 @@ async def test_embedding_adapter_single() -> None: provider_name="test", ) - adapter = PydanticAIEmbeddingModel(mock_embedder, "test-model", 3) + adapter = PydanticAIEmbedder(mock_embedder, "test-model", 3) result = await adapter.get_embedding_nocache("test") assert result.shape == (3,) norm = float(np.linalg.norm(result)) @@ -135,7 +135,7 @@ async def test_embedding_adapter_probes_size() -> None: provider_name="test", ) - adapter = PydanticAIEmbeddingModel(mock_embedder, "test-model") + adapter = PydanticAIEmbedder(mock_embedder, "test-model") assert adapter.embedding_size == 0 await adapter.get_embedding_nocache("probe") assert adapter.embedding_size == 3 @@ -143,7 +143,7 @@ async def test_embedding_adapter_probes_size() -> None: @pytest.mark.asyncio async def test_embedding_adapter_batch() -> None: - """PydanticAIEmbeddingModel computes batch embeddings.""" + """PydanticAIEmbedder computes batch embeddings.""" mock_embedder = AsyncMock(spec=Embedder) mock_embedder.embed_documents.return_value = EmbeddingResult( embeddings=[[1.0, 0.0], [0.0, 1.0]], @@ -153,14 +153,14 @@ async def test_embedding_adapter_batch() -> None: provider_name="test", ) - adapter = PydanticAIEmbeddingModel(mock_embedder, "test-model", 2) + adapter = PydanticAIEmbedder(mock_embedder, "test-model", 2) result = await adapter.get_embeddings_nocache(["a", "b"]) assert result.shape == (2, 2) @pytest.mark.asyncio async def test_embedding_adapter_caching() -> None: - 
"""Caching avoids re-computing embeddings.""" + """CachingEmbeddingModel avoids re-computing embeddings.""" mock_embedder = AsyncMock(spec=Embedder) mock_embedder.embed_documents.return_value = EmbeddingResult( embeddings=[[1.0, 0.0, 0.0]], @@ -170,7 +170,8 @@ async def test_embedding_adapter_caching() -> None: provider_name="test", ) - adapter = PydanticAIEmbeddingModel(mock_embedder, "test-model", 3) + embedder = PydanticAIEmbedder(mock_embedder, "test-model", 3) + adapter = CachingEmbeddingModel(embedder) first = await adapter.get_embedding("cached") second = await adapter.get_embedding("cached") np.testing.assert_array_equal(first, second) @@ -182,7 +183,8 @@ async def test_embedding_adapter_caching() -> None: async def test_embedding_adapter_add_embedding() -> None: """add_embedding() populates the cache.""" mock_embedder = AsyncMock(spec=Embedder) - adapter = PydanticAIEmbeddingModel(mock_embedder, "test-model", 3) + embedder = PydanticAIEmbedder(mock_embedder, "test-model", 3) + adapter = CachingEmbeddingModel(embedder) vec: NormalizedEmbedding = np.array([1.0, 0.0, 0.0], dtype=np.float32) adapter.add_embedding("key", vec) result = await adapter.get_embedding("key") @@ -195,7 +197,8 @@ async def test_embedding_adapter_add_embedding() -> None: async def test_embedding_adapter_empty_batch() -> None: """Empty batch returns empty array with known size.""" mock_embedder = AsyncMock(spec=Embedder) - adapter = PydanticAIEmbeddingModel(mock_embedder, "test-model", 4) + embedder = PydanticAIEmbedder(mock_embedder, "test-model", 4) + adapter = CachingEmbeddingModel(embedder) result = await adapter.get_embeddings_nocache([]) assert result.shape == (0, 4) @@ -212,4 +215,4 @@ def test_configure_models_returns_correct_types( monkeypatch.setenv("OPENAI_API_KEY", "test-key") chat, embedder = configure_models("openai:gpt-4o", "openai:text-embedding-3-small") assert isinstance(chat, PydanticAIChatModel) - assert isinstance(embedder, PydanticAIEmbeddingModel) + assert isinstance(embedder, CachingEmbeddingModel) diff --git a/tests/test_vectorbase.py b/tests/test_vectorbase.py index 04c53c36..416ed3eb 100644 --- a/tests/test_vectorbase.py +++ b/tests/test_vectorbase.py @@ -5,11 +5,11 @@ import pytest from typeagent.aitools.embeddings import ( + CachingEmbeddingModel, NormalizedEmbedding, ) from typeagent.aitools.model_adapters import ( create_test_embedding_model, - PydanticAIEmbeddingModel, ) from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings, VectorBase @@ -61,8 +61,8 @@ def test_add_embeddings(vector_base: VectorBase, sample_embeddings: Samples): assert len(bulk_vector_base) == len(vector_base) np.testing.assert_array_equal(bulk_vector_base.serialize(), vector_base.serialize()) - assert isinstance(vector_base._model, PydanticAIEmbeddingModel) - assert isinstance(bulk_vector_base._model, PydanticAIEmbeddingModel) + assert isinstance(vector_base._model, CachingEmbeddingModel) + assert isinstance(bulk_vector_base._model, CachingEmbeddingModel) sequential_cache = vector_base._model._cache bulk_cache = bulk_vector_base._model._cache assert set(sequential_cache.keys()) == set(bulk_cache.keys()) @@ -86,7 +86,7 @@ async def test_add_key_no_cache(vector_base: VectorBase, sample_embeddings: Samp await vector_base.add_key(key, cache=False) assert len(vector_base) == len(sample_embeddings) - assert isinstance(vector_base._model, PydanticAIEmbeddingModel) + assert isinstance(vector_base._model, CachingEmbeddingModel) assert vector_base._model._cache == {}, "Cache should remain empty when 
cache=False" @@ -106,7 +106,7 @@ async def test_add_keys_no_cache(vector_base: VectorBase, sample_embeddings: Sam await vector_base.add_keys(keys, cache=False) assert len(vector_base) == len(sample_embeddings) - assert isinstance(vector_base._model, PydanticAIEmbeddingModel) + assert isinstance(vector_base._model, CachingEmbeddingModel) assert vector_base._model._cache == {}, "Cache should remain empty when cache=False" From 43894bd7adaae5a980a15ffa831bb88415aeafea Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 25 Feb 2026 16:56:33 -0800 Subject: [PATCH 20/49] Remove max_retries everywhere -- this is now under Pydantic control --- src/typeagent/aitools/vectorbase.py | 7 ------- src/typeagent/knowpro/knowledge.py | 12 +++--------- tests/test_knowledge.py | 6 +++--- 3 files changed, 6 insertions(+), 19 deletions(-) diff --git a/src/typeagent/aitools/vectorbase.py b/src/typeagent/aitools/vectorbase.py index 076a3476..107c7a21 100644 --- a/src/typeagent/aitools/vectorbase.py +++ b/src/typeagent/aitools/vectorbase.py @@ -13,8 +13,6 @@ ) from .model_adapters import create_embedding_model -DEFAULT_MAX_RETRIES = 2 - @dataclass class ScoredInt: @@ -29,7 +27,6 @@ class TextEmbeddingIndexSettings: min_score: float # Between 0.0 and 1.0 max_matches: int | None # >= 1; None means no limit batch_size: int # >= 1 - max_retries: int def __init__( self, @@ -38,14 +35,10 @@ def __init__( min_score: float | None = None, max_matches: int | None = None, batch_size: int | None = None, - max_retries: int | None = None, ): self.min_score = min_score if min_score is not None else 0.85 self.max_matches = max_matches if max_matches and max_matches >= 1 else None self.batch_size = batch_size if batch_size and batch_size >= 1 else 8 - self.max_retries = ( - max_retries if max_retries is not None else DEFAULT_MAX_RETRIES - ) self.embedding_model = embedding_model or create_embedding_model( embedding_size=embedding_size or 0, ) diff --git a/src/typeagent/knowpro/knowledge.py b/src/typeagent/knowpro/knowledge.py index dbcfe206..60ce302d 100644 --- a/src/typeagent/knowpro/knowledge.py +++ b/src/typeagent/knowpro/knowledge.py @@ -26,7 +26,6 @@ def create_knowledge_extractor( async def extract_knowledge_from_text( knowledge_extractor: IKnowledgeExtractor, text: str, - max_retries: int, ) -> Result[kplib.KnowledgeResponse]: """Extract knowledge from a single text input with retries.""" # TODO: Add a retry mechanism to handle transient errors. 
@@ -37,13 +36,10 @@ async def batch_worker( q: asyncio.Queue[tuple[int, str] | None], knowledge_extractor: IKnowledgeExtractor, results: dict[int, Result[kplib.KnowledgeResponse]], - max_retries: int, ) -> None: while item := await q.get(): index, text = item - result = await extract_knowledge_from_text( - knowledge_extractor, text, max_retries - ) + result = await extract_knowledge_from_text(knowledge_extractor, text) results[index] = result @@ -51,7 +47,6 @@ async def extract_knowledge_from_text_batch( knowledge_extractor: IKnowledgeExtractor, text_batch: list[str], concurrency: int = 2, - max_retries: int = 3, ) -> list[Result[kplib.KnowledgeResponse]]: """Extract knowledge from a batch of text inputs concurrently.""" if not text_batch: @@ -64,7 +59,7 @@ async def extract_knowledge_from_text_batch( async with asyncio.TaskGroup() as tg: for _ in range(concurrency): - tg.create_task(batch_worker(q, knowledge_extractor, results, max_retries)) + tg.create_task(batch_worker(q, knowledge_extractor, results)) for index, text in enumerate(text_batch): await q.put((index, text)) @@ -203,7 +198,6 @@ async def extract_knowledge_for_text_batch_q( knowledge_extractor: convknowledge.KnowledgeExtractor, text_batch: list[str], concurrency: int = 2, - max_retries: int = 3, ) -> list[Result[kplib.KnowledgeResponse]]: """Extract knowledge for a batch of text inputs using a task queue.""" raise NotImplementedError("TODO") @@ -212,7 +206,7 @@ async def extract_knowledge_for_text_batch_q( # await run_in_batches( # task_batch, - # lambda text: extract_knowledge_from_text(knowledge_extractor, text, max_retries), + # lambda text: extract_knowledge_from_text(knowledge_extractor, text), # concurrency, # ) diff --git a/tests/test_knowledge.py b/tests/test_knowledge.py index e20ff1f2..d4f46fd1 100644 --- a/tests/test_knowledge.py +++ b/tests/test_knowledge.py @@ -44,12 +44,12 @@ async def test_extract_knowledge_from_text( mock_knowledge_extractor: convknowledge.KnowledgeExtractor, ): """Test extracting knowledge from a single text input.""" - result = await extract_knowledge_from_text(mock_knowledge_extractor, "test text", 3) + result = await extract_knowledge_from_text(mock_knowledge_extractor, "test text") assert isinstance(result, Success) assert result.value.topics[0] == "test text" failure_result = await extract_knowledge_from_text( - mock_knowledge_extractor, "error", 3 + mock_knowledge_extractor, "error" ) assert isinstance(failure_result, Failure) assert failure_result.message == "Extraction failed" @@ -62,7 +62,7 @@ async def test_extract_knowledge_from_text_batch( """Test extracting knowledge from a batch of text inputs.""" text_batch = ["text 1", "text 2", "error"] results = await extract_knowledge_from_text_batch( - mock_knowledge_extractor, text_batch, 2, 3 + mock_knowledge_extractor, text_batch, 2 ) assert len(results) == 3 From 910a99b94f1a78f5e8ceb5e92b69265f364e9567 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 25 Feb 2026 19:40:57 -0800 Subject: [PATCH 21/49] Remove embedding_size argument everywhere. 
Handle it internally --- src/typeagent/aitools/embeddings.py | 28 ++-- src/typeagent/aitools/model_adapters.py | 41 +---- src/typeagent/aitools/vectorbase.py | 37 +++-- src/typeagent/knowpro/interfaces_storage.py | 1 - src/typeagent/storage/sqlite/messageindex.py | 3 + src/typeagent/storage/sqlite/provider.py | 104 ++++--------- tests/test_conversation_metadata.py | 48 ------ tests/test_embedding_consistency.py | 148 ++++++++++++------- tests/test_embeddings.py | 13 +- tests/test_messageindex.py | 2 +- tests/test_model_adapters.py | 53 ++----- tests/test_vectorbase.py | 25 ++++ 12 files changed, 205 insertions(+), 298 deletions(-) diff --git a/src/typeagent/aitools/embeddings.py b/src/typeagent/aitools/embeddings.py index e0e04b7a..5d0e172c 100644 --- a/src/typeagent/aitools/embeddings.py +++ b/src/typeagent/aitools/embeddings.py @@ -25,15 +25,15 @@ class IEmbedder(Protocol): @property def model_name(self) -> str: ... - @property - def embedding_size(self) -> int: ... - async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding: """Compute a single embedding without caching.""" ... async def get_embeddings_nocache(self, input: list[str]) -> NormalizedEmbeddings: - """Compute embeddings for a batch of strings without caching.""" + """Compute embeddings for a batch of strings without caching. + + Raises :class:`ValueError` if *input* is empty. + """ ... @@ -49,9 +49,6 @@ class IEmbeddingModel(Protocol): @property def model_name(self) -> str: ... - @property - def embedding_size(self) -> int: ... - def add_embedding(self, key: str, embedding: NormalizedEmbedding) -> None: """Cache an already-computed embedding under the given key.""" ... @@ -89,10 +86,6 @@ def __init__(self, embedder: IEmbedder) -> None: def model_name(self) -> str: return self._embedder.model_name - @property - def embedding_size(self) -> int: - return self._embedder.embedding_size - def add_embedding(self, key: str, embedding: NormalizedEmbedding) -> None: self._cache[key] = embedding @@ -112,7 +105,7 @@ async def get_embedding(self, key: str) -> NormalizedEmbedding: async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings: if not keys: - return await self._embedder.get_embeddings_nocache([]) + raise ValueError("Cannot embed an empty list") missing_keys = [k for k in keys if k not in self._cache] if missing_keys: fresh = await self._embedder.get_embeddings_nocache(missing_keys) @@ -122,14 +115,11 @@ async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings: DEFAULT_MODEL_NAME = "text-embedding-ada-002" -DEFAULT_EMBEDDING_SIZE = 1536 # Default embedding size (required for ada-002) DEFAULT_ENVVAR = "AZURE_OPENAI_ENDPOINT_EMBEDDING" # We support OpenAI and Azure OpenAI TEST_MODEL_NAME = "test" -model_to_embedding_size_and_envvar: dict[str, tuple[int | None, str]] = { - DEFAULT_MODEL_NAME: (DEFAULT_EMBEDDING_SIZE, DEFAULT_ENVVAR), - "text-embedding-3-small": (1536, "AZURE_OPENAI_ENDPOINT_EMBEDDING_3_SMALL"), - "text-embedding-3-large": (3072, "AZURE_OPENAI_ENDPOINT_EMBEDDING_3_LARGE"), - # For testing only, not a real model (insert real embeddings above) - TEST_MODEL_NAME: (3, "SIR_NOT_APPEARING_IN_THIS_FILM"), +model_to_envvar: dict[str, str] = { + DEFAULT_MODEL_NAME: DEFAULT_ENVVAR, + "text-embedding-3-small": "AZURE_OPENAI_ENDPOINT_EMBEDDING_3_SMALL", + "text-embedding-3-large": "AZURE_OPENAI_ENDPOINT_EMBEDDING_3_LARGE", } diff --git a/src/typeagent/aitools/model_adapters.py b/src/typeagent/aitools/model_adapters.py index 6ccf8292..86dca394 100644 --- 
a/src/typeagent/aitools/model_adapters.py +++ b/src/typeagent/aitools/model_adapters.py @@ -103,36 +103,23 @@ class PydanticAIEmbedder: be used wherever the codebase expects an ``IEmbedder``. Wrap in :class:`~typeagent.aitools.embeddings.CachingEmbeddingModel` to get a ready-to-use ``IEmbeddingModel`` with caching. - - If *embedding_size* is not given, it is probed automatically by making a - single embedding call. """ model_name: str - embedding_size: int def __init__( self, embedder: _PydanticAIEmbedder, model_name: str, - embedding_size: int = 0, ) -> None: self._embedder = embedder self.model_name = model_name - self.embedding_size = embedding_size - - async def _probe_embedding_size(self) -> None: - """Discover embedding_size by making a single API call.""" - result = await self._embedder.embed_documents(["probe"]) - self.embedding_size = len(result.embeddings[0]) async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding: result = await self._embedder.embed_documents([input]) embedding: NDArray[np.float32] = np.array( result.embeddings[0], dtype=np.float32 ) - if self.embedding_size == 0: - self.embedding_size = len(embedding) norm = float(np.linalg.norm(embedding)) if norm > 0: embedding = (embedding / norm).astype(np.float32) @@ -140,13 +127,9 @@ async def get_embedding_nocache(self, input: str) -> NormalizedEmbedding: async def get_embeddings_nocache(self, input: list[str]) -> NormalizedEmbeddings: if not input: - if self.embedding_size == 0: - await self._probe_embedding_size() - return np.empty((0, self.embedding_size), dtype=np.float32) + raise ValueError("Cannot embed an empty list") result = await self._embedder.embed_documents(input) embeddings: NDArray[np.float32] = np.array(result.embeddings, dtype=np.float32) - if self.embedding_size == 0: - self.embedding_size = embeddings.shape[1] norms = np.linalg.norm(embeddings, axis=1, keepdims=True).astype(np.float32) norms = np.where(norms > 0, norms, np.float32(1.0)) embeddings = (embeddings / norms).astype(np.float32) @@ -257,8 +240,6 @@ def create_chat_model( def create_embedding_model( model_spec: str | None = None, - *, - embedding_size: int = 0, ) -> CachingEmbeddingModel: """Create an embedding model from a ``provider:model`` spec. @@ -267,8 +248,6 @@ def create_embedding_model( ``AZURE_OPENAI_API_KEY`` is, Azure OpenAI is used automatically. If *model_spec* is ``None``, :data:`DEFAULT_EMBEDDING_SPEC` is used. - If *embedding_size* is not given, it will be probed automatically - on the first embedding call. Returns a :class:`~typeagent.aitools.embeddings.CachingEmbeddingModel` wrapping a :class:`PydanticAIEmbedder`. @@ -287,12 +266,10 @@ def create_embedding_model( if _needs_azure_fallback(provider): from pydantic_ai.embeddings.openai import OpenAIEmbeddingModel - from .embeddings import model_to_embedding_size_and_envvar + from .embeddings import model_to_envvar # Look up model-specific Azure endpoint, falling back to the generic one. 
- _, suggested_envvar = model_to_embedding_size_and_envvar.get( - model_name, (None, None) - ) + suggested_envvar = model_to_envvar.get(model_name) if suggested_envvar and os.getenv(suggested_envvar): endpoint_envvar = suggested_envvar else: @@ -307,9 +284,7 @@ def create_embedding_model( embedder = _PydanticAIEmbedder(embedding_model) else: embedder = _PydanticAIEmbedder(model_spec) - return CachingEmbeddingModel( - PydanticAIEmbedder(embedder, model_name, embedding_size) - ) + return CachingEmbeddingModel(PydanticAIEmbedder(embedder, model_name)) # --------------------------------------------------------------------------- @@ -390,16 +365,12 @@ def create_test_embedding_model( embeddings for testing. No API keys or network access required.""" fake_model = _FakePydanticAIEmbeddingModel(embedding_size) pydantic_embedder = _PydanticAIEmbedder(fake_model) - return CachingEmbeddingModel( - PydanticAIEmbedder(pydantic_embedder, "test", embedding_size) - ) + return CachingEmbeddingModel(PydanticAIEmbedder(pydantic_embedder, "test")) def configure_models( chat_model_spec: str, embedding_model_spec: str, - *, - embedding_size: int = 0, ) -> tuple[PydanticAIChatModel, CachingEmbeddingModel]: """Configure both a chat model and an embedding model at once. @@ -417,5 +388,5 @@ def configure_models( """ return ( create_chat_model(chat_model_spec), - create_embedding_model(embedding_model_spec, embedding_size=embedding_size), + create_embedding_model(embedding_model_spec), ) diff --git a/src/typeagent/aitools/vectorbase.py b/src/typeagent/aitools/vectorbase.py index 107c7a21..46c4dbae 100644 --- a/src/typeagent/aitools/vectorbase.py +++ b/src/typeagent/aitools/vectorbase.py @@ -23,7 +23,6 @@ class ScoredInt: @dataclass class TextEmbeddingIndexSettings: embedding_model: IEmbeddingModel - embedding_size: int # Set to embedding_model.embedding_size min_score: float # Between 0.0 and 1.0 max_matches: int | None # >= 1; None means no limit batch_size: int # >= 1 @@ -31,7 +30,6 @@ class TextEmbeddingIndexSettings: def __init__( self, embedding_model: IEmbeddingModel | None = None, - embedding_size: int | None = None, min_score: float | None = None, max_matches: int | None = None, batch_size: int | None = None, @@ -39,13 +37,7 @@ def __init__( self.min_score = min_score if min_score is not None else 0.85 self.max_matches = max_matches if max_matches and max_matches >= 1 else None self.batch_size = batch_size if batch_size and batch_size >= 1 else 8 - self.embedding_model = embedding_model or create_embedding_model( - embedding_size=embedding_size or 0, - ) - self.embedding_size = self.embedding_model.embedding_size - assert ( - embedding_size is None or self.embedding_size == embedding_size - ), f"Given embedding size {embedding_size} doesn't match model's embedding size {self.embedding_size}" + self.embedding_model = embedding_model or create_embedding_model() class VectorBase: @@ -57,7 +49,7 @@ class VectorBase: def __init__(self, settings: TextEmbeddingIndexSettings): self.settings = settings self._model = settings.embedding_model - self._embedding_size = self._model.embedding_size + self._embedding_size = 0 self.clear() async def get_embedding(self, key: str, cache: bool = True) -> NormalizedEmbedding: @@ -89,6 +81,11 @@ def add_embedding( if self._embedding_size == 0: self._set_embedding_size(len(embedding)) self._vectors.shape = (0, self._embedding_size) + if len(embedding) != self._embedding_size: + raise ValueError( + f"Embedding size mismatch: expected {self._embedding_size}, " + f"got 
{len(embedding)}" + ) embeddings = embedding.reshape(1, -1) # Make it 2D: 1xN self._vectors = np.append(self._vectors, embeddings, axis=0) if key is not None: @@ -97,23 +94,30 @@ def add_embedding( def add_embeddings( self, keys: None | list[str], embeddings: NormalizedEmbeddings ) -> None: - assert embeddings.ndim == 2 + if embeddings.ndim != 2: + raise ValueError(f"Expected 2D embeddings array, got {embeddings.ndim}D") if self._embedding_size == 0: self._set_embedding_size(embeddings.shape[1]) self._vectors.shape = (0, self._embedding_size) - assert embeddings.shape[1] == self._embedding_size + if embeddings.shape[1] != self._embedding_size: + raise ValueError( + f"Embedding size mismatch: expected {self._embedding_size}, " + f"got {embeddings.shape[1]}" + ) self._vectors = np.concatenate((self._vectors, embeddings), axis=0) if keys is not None: for key, embedding in zip(keys, embeddings): self._model.add_embedding(key, embedding) async def add_key(self, key: str, cache: bool = True) -> None: - embeddings = (await self.get_embedding(key, cache=cache)).reshape(1, -1) - self._vectors = np.append(self._vectors, embeddings, axis=0) + embedding = await self.get_embedding(key, cache=cache) + self.add_embedding(key if cache else None, embedding) async def add_keys(self, keys: list[str], cache: bool = True) -> None: + if not keys: + return embeddings = await self.get_embeddings(keys, cache=cache) - self._vectors = np.concatenate((self._vectors, embeddings), axis=0) + self.add_embeddings(keys if cache else None, embeddings) def fuzzy_lookup_embedding( self, @@ -126,6 +130,8 @@ def fuzzy_lookup_embedding( max_hits = 10 if min_score is None: min_score = 0.0 + if len(self._vectors) == 0: + return [] # This line does most of the work: scores: Iterable[float] = np.dot(self._vectors, embedding) scored_ordinals = [ @@ -168,7 +174,6 @@ def _set_embedding_size(self, size: int) -> None: """Adopt *size* when it was not known at construction time.""" assert size > 0 self._embedding_size = size - self.settings.embedding_size = size def clear(self) -> None: self._vectors = np.array([], dtype=np.float32) diff --git a/src/typeagent/knowpro/interfaces_storage.py b/src/typeagent/knowpro/interfaces_storage.py index 877e3ba9..a82fe7ad 100644 --- a/src/typeagent/knowpro/interfaces_storage.py +++ b/src/typeagent/knowpro/interfaces_storage.py @@ -52,7 +52,6 @@ class ConversationMetadata: schema_version: int | None = None created_at: Datetime | None = None updated_at: Datetime | None = None - embedding_size: int | None = None embedding_model: str | None = None tags: list[str] | None = None extra: dict[str, str] | None = None diff --git a/src/typeagent/storage/sqlite/messageindex.py b/src/typeagent/storage/sqlite/messageindex.py index 877cbd6d..d48a9761 100644 --- a/src/typeagent/storage/sqlite/messageindex.py +++ b/src/typeagent/storage/sqlite/messageindex.py @@ -63,6 +63,9 @@ async def add_messages_starting_at( for chunk_ord, chunk in enumerate(message.text_chunks): chunks_to_embed.append((msg_ord, chunk_ord, chunk)) + if not chunks_to_embed: + return + embeddings = await self._vectorbase.get_embeddings( [chunk for _, _, chunk in chunks_to_embed], cache=False ) diff --git a/src/typeagent/storage/sqlite/provider.py b/src/typeagent/storage/sqlite/provider.py index 515c9cf0..3d5a3185 100644 --- a/src/typeagent/storage/sqlite/provider.py +++ b/src/typeagent/storage/sqlite/provider.py @@ -31,7 +31,7 @@ class SqliteStorageProvider[TMessage: interfaces.IMessage]( """SQLite-backed storage provider implementation. 
This provider performs consistency checks on database initialization to ensure - that existing embeddings match the configured embedding_size. If a mismatch is + that existing embeddings match the configured embedding model. If a mismatch is detected, a ValueError is raised with a descriptive error message. """ @@ -119,22 +119,16 @@ def _resolve_embedding_settings( provided_related_settings: RelatedTermIndexSettings | None, ) -> tuple[MessageTextIndexSettings, RelatedTermIndexSettings]: metadata_exists = self._conversation_metadata_exists() - stored_size_str = self._get_single_metadata_value("embedding_size") stored_name = self._get_single_metadata_value("embedding_name") - stored_size = int(stored_size_str) if stored_size_str else None if provided_message_settings is None: - if stored_size is not None or stored_name is not None: - spec = stored_name or "" + if stored_name is not None: + spec = stored_name if spec and ":" not in spec: spec = f"openai:{spec}" - embedding_model = create_embedding_model( - spec, - embedding_size=stored_size or 0, - ) + embedding_model = create_embedding_model(spec) base_embedding_settings = TextEmbeddingIndexSettings( embedding_model=embedding_model, - embedding_size=stored_size, ) else: base_embedding_settings = TextEmbeddingIndexSettings() @@ -142,13 +136,7 @@ def _resolve_embedding_settings( else: message_settings = provided_message_settings base_embedding_settings = message_settings.embedding_index_settings - provided_size = base_embedding_settings.embedding_size provided_name = base_embedding_settings.embedding_model.model_name - if stored_size is not None and stored_size != provided_size: - raise ValueError( - f"Conversation metadata embedding_size " - f"({stored_size}) does not match provided embedding size ({provided_size})." 
- ) if stored_name is not None and stored_name != provided_name: raise ValueError( f"Conversation metadata embedding_model " @@ -160,12 +148,7 @@ def _resolve_embedding_settings( else: related_settings = provided_related_settings related_embedding_settings = related_settings.embedding_index_settings - related_size = related_embedding_settings.embedding_size related_name = related_embedding_settings.embedding_model.model_name - if related_size != base_embedding_settings.embedding_size: - raise ValueError( - "Related term index embedding_size does not match message text index embedding_size" - ) if related_name != base_embedding_settings.embedding_model.model_name: raise ValueError( "Related term index embedding_model does not match message text index embedding_model" @@ -173,17 +156,9 @@ def _resolve_embedding_settings( if related_settings.embedding_index_settings is not base_embedding_settings: related_settings.embedding_index_settings = base_embedding_settings - actual_size = base_embedding_settings.embedding_size actual_name = base_embedding_settings.embedding_model.model_name if self._metadata is not None: - if self._metadata.embedding_size is None: - self._metadata.embedding_size = actual_size - elif self._metadata.embedding_size != actual_size: - raise ValueError( - "Conversation metadata embedding_size does not match provider settings" - ) - if self._metadata.embedding_model is None: self._metadata.embedding_model = actual_name elif self._metadata.embedding_model != actual_name: @@ -193,8 +168,6 @@ def _resolve_embedding_settings( if metadata_exists: metadata_updates: dict[str, str] = {} - if stored_size is None: - metadata_updates["embedding_size"] = str(actual_size) if stored_name is None: metadata_updates["embedding_name"] = actual_name if metadata_updates: @@ -203,51 +176,47 @@ def _resolve_embedding_settings( return message_settings, related_settings def _check_embedding_consistency(self) -> None: - """Check that existing embeddings in the database match the expected embedding size. + """Check that existing embeddings in the database are consistent. - This method is called during initialization to ensure that embeddings stored in the - database match the embedding_size specified in ConversationSettings. This prevents - runtime errors when trying to use embeddings of incompatible sizes. + This method is called during initialization to ensure that embeddings + stored in the message text index and related terms index have the same + size. This prevents runtime errors when trying to use embeddings of + incompatible sizes. Raises: - ValueError: If embeddings in the database don't match the expected size. + ValueError: If embeddings in the database have inconsistent sizes. """ from .schema import deserialize_embedding cursor = self.db.cursor() - expected_size = ( - self.message_text_index_settings.embedding_index_settings.embedding_size - ) - # Check message text index embeddings + # Get size from message text index embeddings + message_size: int | None = None cursor.execute("SELECT embedding FROM MessageTextIndex LIMIT 1") row = cursor.fetchone() if row and row[0]: embedding = deserialize_embedding(row[0]) - actual_size = len(embedding) - if actual_size != expected_size: - raise ValueError( - f"Message text index embedding size mismatch: " - f"database contains embeddings of size {actual_size}, " - f"but ConversationSettings specifies embedding_size={expected_size}. " - f"The database was likely created with a different embedding model. 
" - f"Please use the same embedding model or create a new database." - ) + message_size = len(embedding) - # Check related terms fuzzy index embeddings + # Get size from related terms fuzzy index embeddings + related_size: int | None = None cursor.execute("SELECT term_embedding FROM RelatedTermsFuzzy LIMIT 1") row = cursor.fetchone() if row and row[0]: embedding = deserialize_embedding(row[0]) - actual_size = len(embedding) - if actual_size != expected_size: - raise ValueError( - f"Related terms index embedding size mismatch: " - f"database contains embeddings of size {actual_size}, " - f"but ConversationSettings specifies embedding_size={expected_size}. " - f"The database was likely created with a different embedding model. " - f"Please use the same embedding model or create a new database." - ) + related_size = len(embedding) + + if ( + message_size is not None + and related_size is not None + and message_size != related_size + ): + raise ValueError( + f"Embedding size mismatch: " + f"message text index has size {message_size}, " + f"but related terms index has size {related_size}. " + f"The database may be corrupted." + ) def _init_conversation_metadata_if_needed(self) -> None: """Initialize conversation metadata if the database is new (empty metadata table). @@ -276,18 +245,10 @@ def _init_conversation_metadata_if_needed(self) -> None: tags = None extras = {} - actual_embedding_size = ( - self.message_text_index_settings.embedding_index_settings.embedding_size - ) actual_embedding_name = ( self.message_text_index_settings.embedding_index_settings.embedding_model.model_name ) - metadata_embedding_size = ( - self._metadata.embedding_size - if self._metadata and self._metadata.embedding_size is not None - else actual_embedding_size - ) metadata_embedding_name = ( self._metadata.embedding_model if self._metadata and self._metadata.embedding_model is not None @@ -309,7 +270,6 @@ def _init_conversation_metadata_if_needed(self) -> None: created_at=format_timestamp_utc(current_time), updated_at=format_timestamp_utc(current_time), tag=tags, # None or list of tags - embedding_size=str(metadata_embedding_size), embedding_name=metadata_embedding_name, **extras, ) @@ -516,9 +476,6 @@ def parse_datetime(value_str: str) -> datetime: updated_at_str = get_single("updated_at") updated_at = parse_datetime(updated_at_str) if updated_at_str else None - embedding_size_str = get_single("embedding_size") - embedding_size = int(embedding_size_str) if embedding_size_str else None - embedding_model = get_single("embedding_name") # Handle tags (multiple values allowed, None if key doesn't exist) @@ -545,7 +502,6 @@ def parse_datetime(value_str: str) -> datetime: schema_version=schema_version, created_at=created_at, updated_at=updated_at, - embedding_size=embedding_size, embedding_model=embedding_model, tags=tags, extra=extra if extra else None, @@ -592,9 +548,6 @@ async def update_conversation_timestamps( # Insert default values if no metadata exists name_tag = self._metadata.name_tag if self._metadata else "conversation" schema_version = str(CONVERSATION_SCHEMA_VERSION) - actual_embedding_size = ( - self.message_text_index_settings.embedding_index_settings.embedding_size - ) actual_embedding_name = ( self.message_text_index_settings.embedding_index_settings.embedding_model.model_name ) @@ -602,7 +555,6 @@ async def update_conversation_timestamps( metadata_kwds: dict[str, str | None] = { "name_tag": name_tag or "conversation", "schema_version": schema_version, - "embedding_size": str(actual_embedding_size), 
"embedding_name": actual_embedding_name, } if created_at is not None: diff --git a/tests/test_conversation_metadata.py b/tests/test_conversation_metadata.py index 887c50b2..37a194a2 100644 --- a/tests/test_conversation_metadata.py +++ b/tests/test_conversation_metadata.py @@ -106,7 +106,6 @@ async def test_get_conversation_metadata_nonexistent( assert metadata.schema_version is None assert metadata.created_at is None assert metadata.updated_at is None - assert metadata.embedding_size is None assert metadata.embedding_model is None assert metadata.tags is None assert metadata.extra is None @@ -131,9 +130,7 @@ async def test_update_conversation_timestamps_new( assert metadata.created_at == created_at assert metadata.updated_at == updated_at settings = storage_provider.message_text_index_settings.embedding_index_settings - expected_size = settings.embedding_size expected_model = settings.embedding_model.model_name - assert metadata.embedding_size == expected_size assert metadata.embedding_model == expected_model assert metadata.tags is None assert metadata.extra is None @@ -458,9 +455,7 @@ async def test_conversation_metadata_persistence( assert metadata.name_tag == "conversation_persistent_test" assert metadata.created_at == created_at assert metadata.updated_at == updated_at - expected_size = embedding_settings.embedding_size expected_model = embedding_settings.embedding_model.model_name - assert metadata.embedding_size == expected_size assert metadata.embedding_model == expected_model finally: await provider2.close() @@ -598,49 +593,6 @@ async def test_conversation_metadata_shared_access( await provider1.close() await provider2.close() - @pytest.mark.asyncio - async def test_embedding_metadata_mismatch_raises( - self, temp_db_path: str, embedding_model: IEmbeddingModel - ): - """Ensure a mismatch between stored metadata and provided settings raises.""" - embedding_settings = TextEmbeddingIndexSettings(embedding_model) - message_text_settings = MessageTextIndexSettings(embedding_settings) - related_terms_settings = RelatedTermIndexSettings(embedding_settings) - - provider = SqliteStorageProvider( - db_path=temp_db_path, - message_type=DummyMessage, - message_text_index_settings=message_text_settings, - related_term_index_settings=related_terms_settings, - ) - - await provider.update_conversation_timestamps( - created_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc), - updated_at=datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc), - ) - provider.db.commit() - await provider.close() - - mismatched_model = create_test_embedding_model( - embedding_size=embedding_settings.embedding_size + 1, - ) - mismatched_settings = TextEmbeddingIndexSettings( - embedding_model=mismatched_model, - embedding_size=mismatched_model.embedding_size, - ) - - with pytest.raises(ValueError, match="embedding_size"): - SqliteStorageProvider( - db_path=temp_db_path, - message_type=DummyMessage, - message_text_index_settings=MessageTextIndexSettings( - mismatched_settings - ), - related_term_index_settings=RelatedTermIndexSettings( - mismatched_settings - ), - ) - @pytest.mark.asyncio async def test_embedding_model_mismatch_raises( self, temp_db_path: str, embedding_model: IEmbeddingModel diff --git a/tests/test_embedding_consistency.py b/tests/test_embedding_consistency.py index f032c856..619c9210 100644 --- a/tests/test_embedding_consistency.py +++ b/tests/test_embedding_consistency.py @@ -1,39 +1,38 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. 
-"""Test embedding consistency checks between database and settings.""" +"""Test embedding consistency checks between database indexes.""" import os +import sqlite3 import tempfile +import numpy as np import pytest from typeagent import create_conversation from typeagent.aitools.model_adapters import create_test_embedding_model from typeagent.knowpro.convsettings import ConversationSettings from typeagent.storage.sqlite import SqliteStorageProvider +from typeagent.storage.sqlite.schema import serialize_embedding from typeagent.transcripts.transcript import TranscriptMessage, TranscriptMessageMeta @pytest.mark.asyncio -async def test_embedding_size_mismatch_in_message_index(): - """Test that opening a DB with mismatched embedding size raises an error.""" - # Create a temporary database file +async def test_same_embedding_size_no_error(): + """Test that opening a DB with the same model works fine.""" with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: db_path = tmp.name try: - # Create a conversation with test model (embedding size 3) settings1 = ConversationSettings( model=create_test_embedding_model(embedding_size=3) ) - # Disable LLM knowledge extraction to avoid API key requirement settings1.semantic_ref_index_settings.auto_extract_knowledge = False conv1 = await create_conversation( db_path, TranscriptMessage, settings=settings1 ) - # Add some messages to populate the index messages = [ TranscriptMessage( text_chunks=["Hello world"], @@ -43,108 +42,151 @@ async def test_embedding_size_mismatch_in_message_index(): await conv1.add_messages_with_indexing(messages) await conv1.storage_provider.close() - # Now try to open the same database with a different embedding size - # This should raise an error + # Reopen with same settings — should work settings2 = ConversationSettings( - model=create_test_embedding_model(embedding_size=5) + model=create_test_embedding_model(embedding_size=3) ) - - with pytest.raises(ValueError, match="embedding_size"): - provider = SqliteStorageProvider( - db_path=db_path, - message_type=TranscriptMessage, - message_text_index_settings=settings2.message_text_index_settings, - related_term_index_settings=settings2.related_term_index_settings, - ) - await provider.close() + provider = SqliteStorageProvider( + db_path=db_path, + message_type=TranscriptMessage, + message_text_index_settings=settings2.message_text_index_settings, + related_term_index_settings=settings2.related_term_index_settings, + ) + await provider.close() finally: - # Clean up the temporary database if os.path.exists(db_path): os.unlink(db_path) @pytest.mark.asyncio -async def test_embedding_size_mismatch_in_related_terms(): - """Test that opening a DB with mismatched embedding size in related terms raises an error.""" - # Create a temporary database file +async def test_empty_db_no_error(): + """Test that opening an empty DB doesn't raise an error regardless of embedding size.""" with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: db_path = tmp.name try: - # Create a conversation with default embedding size settings1 = ConversationSettings( model=create_test_embedding_model(embedding_size=3) ) - # Disable LLM knowledge extraction to avoid API key requirement settings1.semantic_ref_index_settings.auto_extract_knowledge = False conv1 = await create_conversation( db_path, TranscriptMessage, settings=settings1 ) - - # Add some messages to populate the related terms index - messages = [ - TranscriptMessage( - text_chunks=["Apple is a fruit"], - 
metadata=TranscriptMessageMeta(speaker="Alice"), - ) - ] - await conv1.add_messages_with_indexing(messages) await conv1.storage_provider.close() - # Now try to open the same database with a different embedding size - # This should raise an error + # Open with different embedding size should work since DB is empty settings2 = ConversationSettings( model=create_test_embedding_model(embedding_size=5) ) + provider = SqliteStorageProvider( + db_path=db_path, + message_type=TranscriptMessage, + message_text_index_settings=settings2.message_text_index_settings, + related_term_index_settings=settings2.related_term_index_settings, + ) + await provider.close() + + finally: + if os.path.exists(db_path): + os.unlink(db_path) + + +@pytest.mark.asyncio +async def test_embedding_size_mismatch_raises(): + """Test that mismatched embedding sizes between indexes raises ValueError.""" + with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: + db_path = tmp.name - with pytest.raises(ValueError, match="embedding_size"): - provider = SqliteStorageProvider( + try: + # Create a conversation so that the schema is set up + settings = ConversationSettings( + model=create_test_embedding_model(embedding_size=3) + ) + settings.semantic_ref_index_settings.auto_extract_knowledge = False + conv = await create_conversation(db_path, TranscriptMessage, settings=settings) + await conv.storage_provider.close() + + # Manually insert embeddings of different sizes into the two tables + conn = sqlite3.connect(db_path) + msg_emb = serialize_embedding(np.array([0.1, 0.2, 0.3], dtype=np.float32)) + term_emb = serialize_embedding( + np.array([0.1, 0.2, 0.3, 0.4, 0.5], dtype=np.float32) + ) + conn.execute( + "INSERT INTO MessageTextIndex " + "(msg_id, chunk_ordinal, embedding, index_position) " + "VALUES (0, 0, ?, 0)", + (msg_emb,), + ) + conn.execute( + "INSERT INTO RelatedTermsFuzzy (term, term_embedding) VALUES (?, ?)", + ("hello", term_emb), + ) + conn.commit() + conn.close() + + # Reopening should detect the mismatch + settings2 = ConversationSettings( + model=create_test_embedding_model(embedding_size=3) + ) + with pytest.raises(ValueError, match="Embedding size mismatch"): + SqliteStorageProvider( db_path=db_path, message_type=TranscriptMessage, message_text_index_settings=settings2.message_text_index_settings, related_term_index_settings=settings2.related_term_index_settings, ) - await provider.close() finally: - # Clean up the temporary database if os.path.exists(db_path): os.unlink(db_path) @pytest.mark.asyncio -async def test_empty_db_no_error(): - """Test that opening an empty DB doesn't raise an error regardless of embedding size.""" - # Create a temporary database file +async def test_adding_mismatched_embeddings_raises(): + """Test that adding messages with a different embedding size raises ValueError.""" with tempfile.NamedTemporaryFile(suffix=".db", delete=False) as tmp: db_path = tmp.name try: - # Create an empty database + # Create and populate with size-3 embeddings settings1 = ConversationSettings( model=create_test_embedding_model(embedding_size=3) ) - # Disable LLM knowledge extraction to avoid API key requirement settings1.semantic_ref_index_settings.auto_extract_knowledge = False conv1 = await create_conversation( db_path, TranscriptMessage, settings=settings1 ) + await conv1.add_messages_with_indexing( + [ + TranscriptMessage( + text_chunks=["Hello world"], + metadata=TranscriptMessageMeta(speaker="Alice"), + ) + ] + ) await conv1.storage_provider.close() - # Open with different embedding size 
should work since DB is empty + # Reopen with size-5 embeddings and try to add more messages settings2 = ConversationSettings( model=create_test_embedding_model(embedding_size=5) ) - provider = SqliteStorageProvider( - db_path=db_path, - message_type=TranscriptMessage, - message_text_index_settings=settings2.message_text_index_settings, - related_term_index_settings=settings2.related_term_index_settings, + settings2.semantic_ref_index_settings.auto_extract_knowledge = False + conv2 = await create_conversation( + db_path, TranscriptMessage, settings=settings2 ) - await provider.close() + with pytest.raises(ValueError, match="Embedding size mismatch"): + await conv2.add_messages_with_indexing( + [ + TranscriptMessage( + text_chunks=["Goodbye world"], + metadata=TranscriptMessageMeta(speaker="Bob"), + ) + ] + ) + await conv2.storage_provider.close() finally: - # Clean up the temporary database if os.path.exists(db_path): os.unlink(db_path) diff --git a/tests/test_embeddings.py b/tests/test_embeddings.py index 94f17e7a..24a4ff69 100644 --- a/tests/test_embeddings.py +++ b/tests/test_embeddings.py @@ -19,7 +19,6 @@ async def test_get_embedding_nocache(embedding_model: CachingEmbeddingModel): embedding = await embedding_model.get_embedding_nocache(input_text) assert isinstance(embedding, np.ndarray) - assert embedding.shape == (embedding_model.embedding_size,) assert embedding.dtype == np.float32 @@ -30,7 +29,7 @@ async def test_get_embeddings_nocache(embedding_model: CachingEmbeddingModel): embeddings = await embedding_model.get_embeddings_nocache(inputs) assert isinstance(embeddings, np.ndarray) - assert embeddings.shape == (len(inputs), embedding_model.embedding_size) + assert embeddings.shape[0] == len(inputs) assert embeddings.dtype == np.float32 @@ -85,13 +84,9 @@ async def test_get_embeddings_with_cache( @pytest.mark.asyncio async def test_get_embeddings_empty_input(embedding_model: CachingEmbeddingModel): - """Test retrieving embeddings for an empty input list.""" - inputs: list[str] = [] - embeddings = await embedding_model.get_embeddings(inputs) - - assert isinstance(embeddings, np.ndarray) - assert embeddings.shape == (0, embedding_model.embedding_size) - assert embeddings.dtype == np.float32 + """Test retrieving embeddings for an empty input list raises ValueError.""" + with pytest.raises(ValueError, match="Cannot embed an empty list"): + await embedding_model.get_embeddings([]) @pytest.mark.asyncio diff --git a/tests/test_messageindex.py b/tests/test_messageindex.py index f91ac933..7e00cc45 100644 --- a/tests/test_messageindex.py +++ b/tests/test_messageindex.py @@ -158,7 +158,7 @@ async def test_generate_embedding(needs_auth: None): embedding = await index.generate_embedding("test message") assert embedding is not None - assert len(embedding) == test_model.embedding_size # 3 for test model + assert len(embedding) == 3 # test model uses embedding size 3 dot = float(np.dot(embedding, embedding)) assert abs(dot - 1.0) < 1e-6, f"Embedding not normalized: {dot}" diff --git a/tests/test_model_adapters.py b/tests/test_model_adapters.py index f089c4ba..11907bd9 100644 --- a/tests/test_model_adapters.py +++ b/tests/test_model_adapters.py @@ -21,7 +21,6 @@ from typeagent.aitools.model_adapters import ( configure_models, create_chat_model, - create_embedding_model, PydanticAIChatModel, PydanticAIEmbedder, ) @@ -38,23 +37,6 @@ def test_spec_uses_colon_separator() -> None: create_chat_model("nonexistent_provider_xyz:fake-model") -# 
---------------------------------------------------------------------------
-# Embedding size
-# ---------------------------------------------------------------------------
-
-
-def test_explicit_embedding_size() -> None:
-    """Passing embedding_size= sets it immediately."""
-    model = create_embedding_model("openai:text-embedding-3-small", embedding_size=42)
-    assert model.embedding_size == 42
-
-
-def test_default_embedding_size_is_zero() -> None:
-    """Without embedding_size=, it defaults to 0 (probed on first call)."""
-    model = create_embedding_model("openai:text-embedding-3-small")
-    assert model.embedding_size == 0
-
-
 # ---------------------------------------------------------------------------
 # PydanticAIChatModel adapter
 # ---------------------------------------------------------------------------
@@ -116,7 +98,7 @@ async def test_embedding_adapter_single() -> None:
         provider_name="test",
     )
 
-    adapter = PydanticAIEmbedder(mock_embedder, "test-model", 3)
+    adapter = PydanticAIEmbedder(mock_embedder, "test-model")
     result = await adapter.get_embedding_nocache("test")
     assert result.shape == (3,)
     norm = float(np.linalg.norm(result))
@@ -124,21 +106,12 @@ async def test_embedding_adapter_single() -> None:
 
 
 @pytest.mark.asyncio
-async def test_embedding_adapter_probes_size() -> None:
-    """embedding_size is discovered from the first embedding call."""
+async def test_embedding_adapter_empty_batch_raises() -> None:
+    """Empty batch raises ValueError."""
     mock_embedder = AsyncMock(spec=Embedder)
-    mock_embedder.embed_documents.return_value = EmbeddingResult(
-        embeddings=[[1.0, 0.0, 0.0]],
-        inputs=["probe"],
-        input_type="document",
-        model_name="test-model",
-        provider_name="test",
-    )
-    adapter = PydanticAIEmbedder(mock_embedder, "test-model")
-    assert adapter.embedding_size == 0
-    await adapter.get_embedding_nocache("probe")
-    assert adapter.embedding_size == 3
+    adapter = PydanticAIEmbedder(mock_embedder, "test-model")
+    with pytest.raises(ValueError, match="Cannot embed an empty list"):
+        await adapter.get_embeddings_nocache([])
 
 
 @pytest.mark.asyncio
@@ -153,7 +126,7 @@ async def test_embedding_adapter_batch() -> None:
         provider_name="test",
     )
 
-    adapter = PydanticAIEmbedder(mock_embedder, "test-model", 2)
+    adapter = PydanticAIEmbedder(mock_embedder, "test-model")
    result = await adapter.get_embeddings_nocache(["a", "b"])
     assert result.shape == (2, 2)
 
@@ -170,7 +143,7 @@ async def test_embedding_adapter_caching() -> None:
         provider_name="test",
     )
 
-    embedder = PydanticAIEmbedder(mock_embedder, "test-model", 3)
+    embedder = PydanticAIEmbedder(mock_embedder, "test-model")
     adapter = CachingEmbeddingModel(embedder)
     first = await adapter.get_embedding("cached")
     second = await adapter.get_embedding("cached")
@@ -183,7 +156,7 @@ async def test_embedding_adapter_caching() -> None:
 async def test_embedding_adapter_add_embedding() -> None:
     """add_embedding() populates the cache."""
     mock_embedder = AsyncMock(spec=Embedder)
-    embedder = PydanticAIEmbedder(mock_embedder, "test-model", 3)
+    embedder = PydanticAIEmbedder(mock_embedder, "test-model")
     adapter = CachingEmbeddingModel(embedder)
     vec: NormalizedEmbedding = np.array([1.0, 0.0, 0.0], dtype=np.float32)
     adapter.add_embedding("key", vec)
@@ -194,13 +167,13 @@ async def test_embedding_adapter_add_embedding() -> None:
 
 
 @pytest.mark.asyncio
-async def test_embedding_adapter_empty_batch() -> None:
-    """Empty batch returns empty array with known size."""
+async def test_caching_model_empty_batch_raises() -> None:
+    """Empty batch via CachingEmbeddingModel raises ValueError."""
     mock_embedder = 
AsyncMock(spec=Embedder) - embedder = PydanticAIEmbedder(mock_embedder, "test-model", 4) + embedder = PydanticAIEmbedder(mock_embedder, "test-model") adapter = CachingEmbeddingModel(embedder) - result = await adapter.get_embeddings_nocache([]) - assert result.shape == (0, 4) + with pytest.raises(ValueError, match="Cannot embed an empty list"): + await adapter.get_embeddings([]) # --------------------------------------------------------------------------- diff --git a/tests/test_vectorbase.py b/tests/test_vectorbase.py index 416ed3eb..81ccecc6 100644 --- a/tests/test_vectorbase.py +++ b/tests/test_vectorbase.py @@ -195,3 +195,28 @@ def test_fuzzy_lookup_embedding_in_subset( # Empty subset returns empty list result = vector_base.fuzzy_lookup_embedding_in_subset(query, []) assert result == [] + + +def test_add_embedding_size_mismatch(vector_base: VectorBase) -> None: + """Adding an embedding of wrong size raises ValueError.""" + emb3 = np.array([0.1, 0.2, 0.3], dtype=np.float32) + emb5 = np.array([0.1, 0.2, 0.3, 0.4, 0.5], dtype=np.float32) + vector_base.add_embedding(None, emb3) + with pytest.raises(ValueError, match="Embedding size mismatch"): + vector_base.add_embedding(None, emb5) + + +def test_add_embeddings_size_mismatch(vector_base: VectorBase) -> None: + """Adding a batch of embeddings of wrong size raises ValueError.""" + batch3 = np.array([[0.1, 0.2, 0.3]], dtype=np.float32) + batch5 = np.array([[0.1, 0.2, 0.3, 0.4, 0.5]], dtype=np.float32) + vector_base.add_embeddings(None, batch3) + with pytest.raises(ValueError, match="Embedding size mismatch"): + vector_base.add_embeddings(None, batch5) + + +def test_add_embeddings_wrong_ndim(vector_base: VectorBase) -> None: + """Adding a 1D array via add_embeddings raises ValueError.""" + emb1d = np.array([0.1, 0.2, 0.3], dtype=np.float32) + with pytest.raises(ValueError, match="Expected 2D"): + vector_base.add_embeddings(None, emb1d) From 091bd58564629d5346230af97aaf5e9e778a0c22 Mon Sep 17 00:00:00 2001 From: Guido van Rossum Date: Wed, 25 Feb 2026 19:52:27 -0800 Subject: [PATCH 22/49] Change default embedding back to ada-002 for backwards compatibility --- src/typeagent/aitools/embeddings.py | 4 +--- src/typeagent/aitools/model_adapters.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/typeagent/aitools/embeddings.py b/src/typeagent/aitools/embeddings.py index 5d0e172c..8b579df2 100644 --- a/src/typeagent/aitools/embeddings.py +++ b/src/typeagent/aitools/embeddings.py @@ -114,12 +114,10 @@ async def get_embeddings(self, keys: list[str]) -> NormalizedEmbeddings: return np.array([self._cache[k] for k in keys], dtype=np.float32) -DEFAULT_MODEL_NAME = "text-embedding-ada-002" -DEFAULT_ENVVAR = "AZURE_OPENAI_ENDPOINT_EMBEDDING" # We support OpenAI and Azure OpenAI TEST_MODEL_NAME = "test" model_to_envvar: dict[str, str] = { - DEFAULT_MODEL_NAME: DEFAULT_ENVVAR, + "text-embedding-ada-002": "AZURE_OPENAI_ENDPOINT_EMBEDDING", "text-embedding-3-small": "AZURE_OPENAI_ENDPOINT_EMBEDDING_3_SMALL", "text-embedding-3-large": "AZURE_OPENAI_ENDPOINT_EMBEDDING_3_LARGE", } diff --git a/src/typeagent/aitools/model_adapters.py b/src/typeagent/aitools/model_adapters.py index 86dca394..164cfdca 100644 --- a/src/typeagent/aitools/model_adapters.py +++ b/src/typeagent/aitools/model_adapters.py @@ -235,7 +235,7 @@ def create_chat_model( return PydanticAIChatModel(model) -DEFAULT_EMBEDDING_SPEC = "openai:text-embedding-3-small" +DEFAULT_EMBEDDING_SPEC = "openai:text-embedding-ada-002" def create_embedding_model( From 
7c1c5272e58069cf35964544eb6d40846659447c Mon Sep 17 00:00:00 2001
From: Guido van Rossum
Date: Wed, 25 Feb 2026 20:57:01 -0800
Subject: [PATCH 23/49] Add OPENAI_EMBEDDING_MODEL envvar to set the text
 embedding model (e.g. text-embedding-3-small)

---
 src/typeagent/aitools/model_adapters.py | 16 ++++++++++++++--
 1 file changed, 14 insertions(+), 2 deletions(-)

diff --git a/src/typeagent/aitools/model_adapters.py b/src/typeagent/aitools/model_adapters.py
index 164cfdca..34d5ac84 100644
--- a/src/typeagent/aitools/model_adapters.py
+++ b/src/typeagent/aitools/model_adapters.py
@@ -229,6 +229,11 @@ def create_chat_model(
     if _needs_azure_fallback(provider):
         from pydantic_ai.models.openai import OpenAIChatModel
 
+        if os.getenv("OPENAI_MODEL"):
+            print(
+                f"OPENAI_MODEL={os.getenv('OPENAI_MODEL')!r} ignored; "
+                f"Azure deployment is determined by AZURE_OPENAI_ENDPOINT"
+            )
         model = OpenAIChatModel(model_name, provider=_make_azure_provider())
     else:
         model = infer_model(model_spec)
@@ -247,19 +252,26 @@ def create_embedding_model(
 
     If the spec uses ``openai:`` and ``OPENAI_API_KEY`` is not set but
     ``AZURE_OPENAI_API_KEY`` is, Azure OpenAI is used automatically.
-    If *model_spec* is ``None``, :data:`DEFAULT_EMBEDDING_SPEC` is used.
+    If *model_spec* is ``None``, it is constructed from the
+    ``OPENAI_EMBEDDING_MODEL`` environment variable (falling back to
+    :data:`DEFAULT_EMBEDDING_SPEC`).
 
     Returns a :class:`~typeagent.aitools.embeddings.CachingEmbeddingModel`
     wrapping a :class:`PydanticAIEmbedder`.
 
     Examples::
 
+        model = create_embedding_model()  # uses OPENAI_EMBEDDING_MODEL or ada-002
         model = create_embedding_model("openai:text-embedding-3-small")
         model = create_embedding_model("cohere:embed-english-v3.0")
         model = create_embedding_model("google:text-embedding-004")
     """
     if model_spec is None:
-        model_spec = DEFAULT_EMBEDDING_SPEC
+        openai_embedding_model = os.getenv("OPENAI_EMBEDDING_MODEL")
+        if openai_embedding_model:
+            model_spec = f"openai:{openai_embedding_model}"
+        else:
+            model_spec = DEFAULT_EMBEDDING_SPEC
     provider, _, model_name = model_spec.partition(":")
     if not model_name:
         model_name = provider  # No colon in spec

From 400f86932ce1dbd037fef73b5bdae3c915efbbc3 Mon Sep 17 00:00:00 2001
From: Bernhard Merkle
Date: Fri, 27 Feb 2026 16:17:24 +0100
Subject: [PATCH 24/49] speed optimization: convert ordinals_of_subset to a
 set before the fuzzy lookup

Total complexity drops from O(n*k) to O(n+k): the set is built once in
O(k), after which each of the n membership checks in the filter lambda
costs O(1) instead of O(k).
---
 src/typeagent/aitools/vectorbase.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/src/typeagent/aitools/vectorbase.py b/src/typeagent/aitools/vectorbase.py
index 46c4dbae..63e2e77a 100644
--- a/src/typeagent/aitools/vectorbase.py
+++ b/src/typeagent/aitools/vectorbase.py
@@ -150,8 +150,9 @@ def fuzzy_lookup_embedding_in_subset(
         max_hits: int | None = None,
         min_score: float | None = None,
     ) -> list[ScoredInt]:
+        ordinals_set = set(ordinals_of_subset)
         return self.fuzzy_lookup_embedding(
-            embedding, max_hits, min_score, lambda i: i in ordinals_of_subset
+            embedding, max_hits, min_score, lambda i: i in ordinals_set
         )
 
     async def fuzzy_lookup(

From 1cf46b45dd1dafbcd5b11ccca6d91ccd0ed39058 Mon Sep 17 00:00:00 2001
From: Bernhard Merkle
Date: Fri, 27 Feb 2026 16:17:24 +0100
Subject: [PATCH 25/49] parse_azure_endpoint: the regex separator class [?,]
 missed '&', so it didn't match &api-version=... in multi-parameter URLs.
 Fixed to [?&,].
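
A quick sanity check of the old and new character classes (the URL below
is illustrative, not taken from any real deployment):

    import re

    url = (
        "https://myhost.openai.azure.com/openai/deployments/gpt-4"
        "?foo=bar&api-version=2024-06-01"
    )
    old = re.search(r"[?,]api-version=([\d-]+(?:preview)?)", url)
    new = re.search(r"[?&,]api-version=([\d-]+(?:preview)?)", url)
    assert old is None          # '&' separator was not in the old class
    assert new is not None and new.group(1) == "2024-06-01"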
---
 src/typeagent/aitools/utils.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/typeagent/aitools/utils.py b/src/typeagent/aitools/utils.py
index b9150334..e556633b 100644
--- a/src/typeagent/aitools/utils.py
+++ b/src/typeagent/aitools/utils.py
@@ -191,7 +191,7 @@ def parse_azure_endpoint(
     if not azure_endpoint:
         raise RuntimeError(f"Environment variable {endpoint_envvar} not found")
 
-    m = re.search(r"[?,]api-version=([\d-]+(?:preview)?)", azure_endpoint)
+    m = re.search(r"[?&,]api-version=([\d-]+(?:preview)?)", azure_endpoint)
     if not m:
         raise RuntimeError(
             f"{endpoint_envvar}={azure_endpoint} doesn't contain valid api-version field"

From aab0eee0a3beb66bd9b6a58f7828173f3242036f Mon Sep 17 00:00:00 2001
From: Bernhard Merkle
Date: Fri, 27 Feb 2026 16:17:24 +0100
Subject: [PATCH 26/49] added test for parsing azure endpoint urls

---
 tests/test_utils.py | 50 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 50 insertions(+)

diff --git a/tests/test_utils.py b/tests/test_utils.py
index ceea367d..6ac075c9 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -6,6 +6,7 @@
 import os
 
 from dotenv import load_dotenv
+import pytest
 import pydantic.dataclasses
 import typechat
 
@@ -51,3 +52,52 @@ class DummySchema:
     # This will raise if the environment or typechat is not set up correctly
     translator = utils.create_translator(DummyModel(), DummySchema)
     assert hasattr(translator, "model")
+
+
+class TestParseAzureEndpoint:
+    """Tests for parse_azure_endpoint regex matching."""
+
+    def test_api_version_after_question_mark(
+        self, monkeypatch: pytest.MonkeyPatch
+    ) -> None:
+        """api-version as the first (and only) query parameter."""
+        monkeypatch.setenv(
+            "TEST_ENDPOINT",
+            "https://myhost.openai.azure.com/openai/deployments/gpt-4?api-version=2025-01-01-preview",
+        )
+        endpoint, version = utils.parse_azure_endpoint("TEST_ENDPOINT")
+        assert version == "2025-01-01-preview"
+        assert endpoint.startswith("https://")
+
+    def test_api_version_after_ampersand(self, monkeypatch: pytest.MonkeyPatch) -> None:
+        """api-version preceded by & (not the first query parameter)."""
+        monkeypatch.setenv(
+            "TEST_ENDPOINT",
+            "https://myhost.openai.azure.com/openai/deployments/gpt-4?foo=bar&api-version=2025-01-01-preview",
+        )
+        endpoint, version = utils.parse_azure_endpoint("TEST_ENDPOINT")
+        assert version == "2025-01-01-preview"
+
+    def test_api_version_after_comma(self, monkeypatch: pytest.MonkeyPatch) -> None:
+        """api-version preceded by comma (alternate separator)."""
+        monkeypatch.setenv(
+            "TEST_ENDPOINT",
+            "https://myhost.openai.azure.com/openai/deployments/gpt-4?foo=bar,api-version=2024-06-01",
+        )
+        endpoint, version = utils.parse_azure_endpoint("TEST_ENDPOINT")
+        assert version == "2024-06-01"
+
+    def test_missing_env_var_raises(self, monkeypatch: pytest.MonkeyPatch) -> None:
+        """RuntimeError when the environment variable is not set."""
+        monkeypatch.delenv("NONEXISTENT_ENDPOINT", raising=False)
+        with pytest.raises(RuntimeError, match="not found"):
+            utils.parse_azure_endpoint("NONEXISTENT_ENDPOINT")
+
+    def test_no_api_version_raises(self, monkeypatch: pytest.MonkeyPatch) -> None:
+        """RuntimeError when the endpoint has no api-version field."""
+        monkeypatch.setenv(
+            "TEST_ENDPOINT",
+            "https://myhost.openai.azure.com/openai/deployments/gpt-4",
+        )
+        with pytest.raises(RuntimeError, match="doesn't contain valid api-version"):
+            utils.parse_azure_endpoint("TEST_ENDPOINT")

From 4e9127a42452106d1d23be6a4d9d1612943f388d Mon Sep 17 00:00:00 2001
From: Bernhard Merkle
Date: Fri, 27 Feb 2026 16:17:24 +0100
Subject: [PATCH 27/49] fixed parse_azure_endpoint tests (use _ for the
 unused endpoint value)

---
 tests/test_utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/test_utils.py b/tests/test_utils.py
index 6ac075c9..bc4ec425 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -75,7 +75,7 @@ def test_api_version_after_ampersand(self, monkeypatch: pytest.MonkeyPatch) -> None:
             "TEST_ENDPOINT",
             "https://myhost.openai.azure.com/openai/deployments/gpt-4?foo=bar&api-version=2025-01-01-preview",
         )
-        endpoint, version = utils.parse_azure_endpoint("TEST_ENDPOINT")
+        _, version = utils.parse_azure_endpoint("TEST_ENDPOINT")
         assert version == "2025-01-01-preview"
 
     def test_api_version_after_comma(self, monkeypatch: pytest.MonkeyPatch) -> None:
@@ -84,7 +84,7 @@ def test_api_version_after_comma(self, monkeypatch: pytest.MonkeyPatch) -> None:
             "TEST_ENDPOINT",
             "https://myhost.openai.azure.com/openai/deployments/gpt-4?foo=bar,api-version=2024-06-01",
        )
-        endpoint, version = utils.parse_azure_endpoint("TEST_ENDPOINT")
+        _, version = utils.parse_azure_endpoint("TEST_ENDPOINT")
         assert version == "2024-06-01"

From f6ae405cd1fe386e2660cdf0c58181c6cc90876d Mon Sep 17 00:00:00 2001
From: Bernhard Merkle
Date: Fri, 27 Feb 2026 16:17:24 +0100
Subject: =?UTF-8?q?1st=20change:=20There=20was=20a=20mutable?=
 =?UTF-8?q?=20default=20argument=20in=20collections.py=20=E2=80=94=20Seman?=
 =?UTF-8?q?ticRefAccumulator.=5F=5Finit=5F=5F=20used=20set()=20as=20defaul?=
 =?UTF-8?q?t=20parameter,=20shared=20across=20all=20instances.?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The code has been rewritten to still track which search terms produced
hits in the accumulator, but in a way that avoids mutable default
arguments. Now:

- __init__ always creates a fresh set(); the parameter is gone
- a with_term_matches factory creates a new accumulator with a copy of
  the term matches (making the copy-vs-share choice explicit)
- group_matches_by_type uses a copy instead of sharing the same set
  object
- get_matches_in_scope and WhereSemanticRefExpr use the factory

2nd change: MatchAccumulator.add in collections.py used hit_count=1 for
new non-exact matches, which inflated hit counts for related-only
matches. Fixed to hit_count=0.

Also added comments to the else branches and additional tests.
---
 src/typeagent/knowpro/collections.py | 42 +++++++++----
 tests/test_collections.py            | 90 +++++++++++++++++++++++++++-
 2 files changed, 120 insertions(+), 12 deletions(-)

diff --git a/src/typeagent/knowpro/collections.py b/src/typeagent/knowpro/collections.py
index 1e5205f3..c73e83a1 100644
--- a/src/typeagent/knowpro/collections.py
+++ b/src/typeagent/knowpro/collections.py
@@ -77,10 +77,12 @@ def add(self, value: T, score: float, is_exact_match: bool = True) -> None:
                 existing_match.hit_count += 1
                 existing_match.score += score
             else:
+                # Related (non-exact) match: only accumulate related counters.
                 existing_match.related_hit_count += 1
                 existing_match.related_score += score
         else:
             if is_exact_match:
+                # New exact match: starts with hit_count=1 and the given score.
                 self.set_match(
                     Match(
                         value,
@@ -91,10 +93,14 @@ def add(self, value: T, score: float, is_exact_match: bool = True) -> None:
                     )
                 )
             else:
+                # New related-only match: hit_count stays 0 because
+                # only exact matches count as direct hits. 
This matters + # for select_with_hit_count / _matches_with_min_hit_count + # which filter on hit_count to weed out noise. self.set_match( Match( value, - hit_count=1, + hit_count=0, score=0.0, related_hit_count=1, related_score=score, @@ -250,9 +256,25 @@ def smooth_match_score[T](match: Match[T]) -> None: class SemanticRefAccumulator(MatchAccumulator[SemanticRefOrdinal]): - def __init__(self, search_term_matches: set[str] = set()): + """Accumulates scored semantic reference matches. + + ``search_term_matches`` tracks which search terms produced hits (provenance). + Use ``with_term_matches`` to create a derived accumulator that inherits a + *copy* of the parent's provenance. + """ + + def __init__(self) -> None: super().__init__() - self.search_term_matches = search_term_matches + self.search_term_matches: set[str] = set() + + @classmethod + def with_term_matches( + cls, source: "SemanticRefAccumulator" + ) -> "SemanticRefAccumulator": + """Create a new accumulator inheriting a copy of *source*'s term-match provenance.""" + acc = cls() + acc.search_term_matches = set(source.search_term_matches) + return acc def add_term_matches( self, @@ -330,8 +352,7 @@ async def group_matches_by_type( semantic_ref = await semantic_refs.get_item(match.value) group = groups.get(semantic_ref.knowledge.knowledge_type) if group is None: - group = SemanticRefAccumulator() - group.search_term_matches = self.search_term_matches + group = SemanticRefAccumulator.with_term_matches(self) groups[semantic_ref.knowledge.knowledge_type] = group group.set_match(match) return groups @@ -341,7 +362,7 @@ async def get_matches_in_scope( semantic_refs: ISemanticRefCollection, ranges_in_scope: "TextRangesInScope", ) -> "SemanticRefAccumulator": - accumulator = SemanticRefAccumulator(self.search_term_matches) + accumulator = SemanticRefAccumulator.with_term_matches(self) for match in self: if ranges_in_scope.is_range_in_scope( (await semantic_refs.get_item(match.value)).range @@ -516,12 +537,13 @@ def add_ranges(self, text_ranges: "list[TextRange] | TextRangeCollection") -> No def is_in_range(self, inner_range: TextRange) -> bool: if len(self._ranges) == 0: return False - i = bisect.bisect_left(self._ranges, inner_range) - for outer_range in self._ranges[i:]: - if outer_range.start > inner_range.start: - break + for outer_range in self._ranges: if inner_range in outer_range: return True + # Since ranges are sorted by start, once we pass inner_range's start + # no further range can contain it. 
+ if outer_range.start > inner_range.start: + break return False diff --git a/tests/test_collections.py b/tests/test_collections.py index f35c3fe9..9f913d0d 100644 --- a/tests/test_collections.py +++ b/tests/test_collections.py @@ -112,8 +112,10 @@ def test_text_range_collection_add_and_check(): assert collection.is_in_range(range1) is True assert collection.is_in_range(range2) is True - assert collection.is_in_range(range3) is False - assert collection.is_in_range(range4) is False + assert ( + collection.is_in_range(range3) is True + ) # range3 [5,10) is inside range1 [0,10) + assert collection.is_in_range(range4) is False # range4 [5,25) spans across ranges assert collection.is_in_range(range5) is False @@ -406,6 +408,90 @@ def test_match_accumulator_select_top_n_scoring(): assert matches[1].value == "medium" +def test_match_accumulator_add_non_exact_match(): + """Non-exact (related) matches must start with hit_count=0.""" + accumulator = MatchAccumulator[str]() + accumulator.add("related_term", score=0.7, is_exact_match=False) + + match = accumulator.get_match("related_term") + assert match is not None + assert match.hit_count == 0 + assert match.score == 0.0 + assert match.related_hit_count == 1 + assert match.related_score == 0.7 + + +def test_match_accumulator_non_exact_filtered_by_min_hit_count(): + """Related-only matches should be excluded by min_hit_count=1 filter.""" + accumulator = MatchAccumulator[str]() + accumulator.add("exact_term", score=1.0, is_exact_match=True) + accumulator.add("related_term", score=0.9, is_exact_match=False) + + matches = list(accumulator._matches_with_min_hit_count(min_hit_count=1)) # type: ignore + assert len(matches) == 1 + assert matches[0].value == "exact_term" + + +def test_match_accumulator_related_then_exact_same_value(): + """Adding a related match then an exact match for the same value.""" + accumulator = MatchAccumulator[str]() + accumulator.add("term", score=0.5, is_exact_match=False) + accumulator.add("term", score=1.0, is_exact_match=True) + + match = accumulator.get_match("term") + assert match is not None + assert match.hit_count == 1 + assert match.score == 1.0 + assert match.related_hit_count == 1 + assert match.related_score == 0.5 + + +def test_match_accumulator_exact_then_related_same_value(): + """Adding an exact match then a related match for the same value.""" + accumulator = MatchAccumulator[str]() + accumulator.add("term", score=1.0, is_exact_match=True) + accumulator.add("term", score=0.3, is_exact_match=False) + + match = accumulator.get_match("term") + assert match is not None + assert match.hit_count == 1 + assert match.score == 1.0 + assert match.related_hit_count == 1 + assert match.related_score == 0.3 + + +def test_match_accumulator_multiple_related_accumulate(): + """Multiple related matches for the same value accumulate correctly.""" + accumulator = MatchAccumulator[str]() + accumulator.add("term", score=0.4, is_exact_match=False) + accumulator.add("term", score=0.6, is_exact_match=False) + + match = accumulator.get_match("term") + assert match is not None + assert match.hit_count == 0 + assert match.score == 0.0 + assert match.related_hit_count == 2 + assert match.related_score == pytest.approx(1.0) + + +def test_match_accumulator_total_score_includes_related(): + """calculate_total_score adds smoothed related score to the main score.""" + accumulator = MatchAccumulator[str]() + accumulator.add("exact_only", score=2.0, is_exact_match=True) + accumulator.add("mixed", score=1.0, is_exact_match=True) + 
accumulator.add("mixed", score=0.5, is_exact_match=False) + + accumulator.calculate_total_score() + + exact_only = accumulator.get_match("exact_only") + mixed = accumulator.get_match("mixed") + assert exact_only is not None + assert mixed is not None + # "mixed" should have a higher score than its raw 1.0 + # because the related_score of 0.5 is added (smoothed). + assert mixed.score > 1.0 + + def test_get_smooth_score(): """Test calculating smooth scores.""" assert get_smooth_score(10.0, 1) == 10.0 # Single hit count, no smoothing From 9c7d37ab348ab07e6d802022ad445789dea21946 Mon Sep 17 00:00:00 2001 From: Bernhard Merkle Date: Fri, 27 Feb 2026 16:17:24 +0100 Subject: [PATCH 29/49] =?UTF-8?q?The=20original=20=5F=5Frepr=5F=5F=20used?= =?UTF-8?q?=20dir(self),=20which=20returns=20all=20attributes=20on=20the?= =?UTF-8?q?=20object=20=E2=80=94=20including=20inherited=20methods,=20dund?= =?UTF-8?q?er=20methods=20(=5F=5Finit=5F=5F,=20=5F=5Feq=5F=5F,=20=E2=80=A6?= =?UTF-8?q?),=20and=20class-level=20descriptors.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The fix switches to vars(self), which returns only the instance's __dict__ — i.e. the actual dataclass field values. added testcases removed unused method from class --- src/typeagent/knowpro/searchlang.py | 25 ++------- tests/test_searchlang.py | 79 +++++++++++++++++++++++++++++ 2 files changed, 84 insertions(+), 20 deletions(-) create mode 100644 tests/test_searchlang.py diff --git a/src/typeagent/knowpro/searchlang.py b/src/typeagent/knowpro/searchlang.py index dbb8092a..54b7eb7b 100644 --- a/src/typeagent/knowpro/searchlang.py +++ b/src/typeagent/knowpro/searchlang.py @@ -83,11 +83,9 @@ class LanguageSearchOptions(SearchOptions): def __repr__(self): parts = [] - for key in dir(self): - if not key.startswith("_"): - value = getattr(self, key) - if value is not None: - parts.append(f"{key}={value!r}") + for key, value in vars(self).items(): + if not key.startswith("_") and value is not None: + parts.append(f"{key}={value!r}") return f"{self.__class__.__name__}({', '.join(parts)})" @@ -371,6 +369,8 @@ def compile_action_term_as_search_terms( self.compile_entity_terms_as_search_terms( action_term.additional_entities, action_group ) + if use_or_max and action_group.terms: + term_group.terms.append(action_group) return term_group def compile_search_terms( @@ -609,21 +609,6 @@ def add_entity_name_to_group( exact_match_value, ) - def add_search_term_to_groupadd_entity_name_to_group( - self, - entity_term: EntityTerm, - property_name: PropertyNames, - term_group: SearchTermGroup, - exact_match_value: bool = False, - ) -> None: - if not entity_term.is_name_pronoun: - self.add_property_term_to_group( - property_name.value, - entity_term.name, - term_group, - exact_match_value, - ) - def add_property_term_to_group( self, property_name: str, diff --git a/tests/test_searchlang.py b/tests/test_searchlang.py new file mode 100644 index 00000000..46cdae75 --- /dev/null +++ b/tests/test_searchlang.py @@ -0,0 +1,79 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. 
+ +from typeagent.knowpro.search import SearchOptions +from typeagent.knowpro.searchlang import ( + LanguageQueryCompileOptions, + LanguageSearchOptions, +) + + +class TestSearchOptionsRepr: + """Tests for the custom __repr__ on SearchOptions and LanguageSearchOptions.""" + + def test_all_defaults_shows_non_none_fields(self) -> None: + """Default fields that are not None (like exact_match=False) appear.""" + opts = SearchOptions() + r = repr(opts) + assert r.startswith("SearchOptions(") + # exact_match defaults to False, which is not None, so it shows up: + assert "exact_match=False" in r + # None-valued fields are omitted: + assert "max_knowledge_matches" not in r + + def test_non_none_fields_shown(self) -> None: + opts = SearchOptions(max_knowledge_matches=10, threshold_score=0.5) + r = repr(opts) + assert "max_knowledge_matches=10" in r + assert "threshold_score=0.5" in r + # Fields left at None are omitted: + assert "max_message_matches" not in r + assert "max_chars_in_budget" not in r + + def test_false_field_shown(self) -> None: + """False is not None, so it should appear.""" + opts = SearchOptions(exact_match=False) + assert "exact_match=False" in repr(opts) + + def test_true_field_shown(self) -> None: + opts = SearchOptions(exact_match=True) + assert "exact_match=True" in repr(opts) + + +class TestLanguageSearchOptionsRepr: + """Tests for LanguageSearchOptions.__repr__ (subclass of SearchOptions).""" + + def test_all_defaults_shows_class_name(self) -> None: + opts = LanguageSearchOptions() + r = repr(opts) + # Subclass name, not parent name: + assert r.startswith("LanguageSearchOptions(") + + def test_inherited_and_own_fields(self) -> None: + opts = LanguageSearchOptions( + max_knowledge_matches=5, + compile_options=LanguageQueryCompileOptions(exact_scope=True), + ) + r = repr(opts) + assert "LanguageSearchOptions(" in r + assert "max_knowledge_matches=5" in r + assert "compile_options=" in r + assert "exact_scope=True" in r + + def test_none_fields_omitted(self) -> None: + opts = LanguageSearchOptions() + r = repr(opts) + assert "compile_options" not in r + assert "model_instructions" not in r + assert "max_knowledge_matches" not in r + + def test_no_private_fields(self) -> None: + """Fields starting with _ should never appear in repr.""" + opts = LanguageSearchOptions(max_knowledge_matches=3) + r = repr(opts) + # No key=value pair where the key starts with underscore: + inside = r.split("(", 1)[1].rstrip(")") + for part in inside.split(", "): + if "=" in part: + key = part.split("=", 1)[0] + assert not key.startswith("_"), f"private field {key!r} in repr" From 6e6972ef2dc4ccc224d1907fed4dafbf6611e3a3 Mon Sep 17 00:00:00 2001 From: Bernhard Merkle Date: Fri, 27 Feb 2026 16:17:24 +0100 Subject: [PATCH 30/49] =?UTF-8?q?facets=5Fto=5Fmerged=5Ffacets=20used=20st?= =?UTF-8?q?r(facet)=20instead=20of=20str(facet.value)=20=E2=80=94=20produc?= =?UTF-8?q?ed=20dataclass=20repr=20strings=20instead=20of=20actual=20facet?= =?UTF-8?q?=20values.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Off-by-one in get_enclosing_date_range_for_text_range — used exclusive end ordinal directly, potentially indexing past the last message. Fixed to use message_ordinal - 1. 
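For the facet change, a minimal before/after sketch (toy values; assumes Facet is the kplib dataclass with name and value fields, as exercised by the new tests):

    f = Facet(name="age", value=30.0)
    str(f)        # repr-style text that mentions "Facet", not the bare value
    str(f.value)  # "30.0", the value that should actually be merged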
added testcases --- src/typeagent/knowpro/answers.py | 4 +- tests/test_answers.py | 180 +++++++++++++++++++++++++++++++ 2 files changed, 182 insertions(+), 2 deletions(-) create mode 100644 tests/test_answers.py diff --git a/src/typeagent/knowpro/answers.py b/src/typeagent/knowpro/answers.py index ae6fad98..c6832efb 100644 --- a/src/typeagent/knowpro/answers.py +++ b/src/typeagent/knowpro/answers.py @@ -405,7 +405,7 @@ async def get_enclosing_date_range_for_text_range( if not start_timestamp: return None end_timestamp = ( - (await messages.get_item(range.end.message_ordinal)).timestamp + (await messages.get_item(range.end.message_ordinal - 1)).timestamp if range.end else None ) @@ -535,7 +535,7 @@ def facets_to_merged_facets(facets: list[Facet]) -> MergedFacets: merged_facets: MergedFacets = {} for facet in facets: name = facet.name.lower() - value = str(facet).lower() + value = str(facet.value).lower() merged_facets.setdefault(name, []).append(value) return merged_facets diff --git a/tests/test_answers.py b/tests/test_answers.py new file mode 100644 index 00000000..2669ea78 --- /dev/null +++ b/tests/test_answers.py @@ -0,0 +1,180 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from collections.abc import AsyncGenerator + +import pytest +import pytest_asyncio + +from typeagent.knowpro.answers import ( + facets_to_merged_facets, + get_enclosing_date_range_for_text_range, + get_enclosing_text_range, + merged_facets_to_facets, + text_range_from_message_range, +) +from typeagent.knowpro.interfaces import TextLocation, TextRange +from typeagent.knowpro.kplib import Facet + +from conftest import FakeMessage, FakeMessageCollection + +# --------------------------------------------------------------------------- +# Change 1: facets_to_merged_facets uses str(facet.value), not str(facet) +# --------------------------------------------------------------------------- + + +class TestFacetsToMergedFacets: + """Verify that facet *values* (not the whole Facet object) are stringified.""" + + def test_string_value(self) -> None: + facets = [Facet(name="colour", value="red")] + merged = facets_to_merged_facets(facets) + assert merged == {"colour": ["red"]} + + def test_numeric_value(self) -> None: + facets = [Facet(name="age", value=30.0)] + merged = facets_to_merged_facets(facets) + # Should be "30.0", NOT "Facet('age', 30.0)" + assert merged == {"age": ["30.0"]} + assert "Facet" not in merged["age"][0] + + def test_bool_value(self) -> None: + facets = [Facet(name="active", value=True)] + merged = facets_to_merged_facets(facets) + assert merged == {"active": ["true"]} + + def test_multiple_facets_same_name(self) -> None: + facets = [ + Facet(name="tag", value="a"), + Facet(name="tag", value="b"), + ] + merged = facets_to_merged_facets(facets) + assert merged == {"tag": ["a", "b"]} + + def test_lowercases_names_and_values(self) -> None: + facets = [Facet(name="Colour", value="RED")] + merged = facets_to_merged_facets(facets) + assert "colour" in merged + assert merged["colour"] == ["red"] + + def test_roundtrip_through_merged(self) -> None: + """facets_to_merged_facets -> merged_facets_to_facets preserves semantics.""" + original = [ + Facet(name="colour", value="red"), + Facet(name="colour", value="blue"), + Facet(name="size", value="large"), + ] + merged = facets_to_merged_facets(original) + restored = merged_facets_to_facets(merged) + restored_by_name = {f.name: f.value for f in restored} + assert restored_by_name["colour"] == "red; blue" + assert restored_by_name["size"] 
== "large" + + +# --------------------------------------------------------------------------- +# Change 2: get_enclosing_date_range_for_text_range uses ordinal-1 for end +# --------------------------------------------------------------------------- + + +class TestGetEnclosingDateRangeForTextRange: + """Verify the off-by-one fix: end is exclusive, so we subtract 1.""" + + @pytest_asyncio.fixture() + async def messages(self) -> AsyncGenerator[FakeMessageCollection, None]: + """Three messages with ordinals 0, 1, 2 and timestamps derived from them.""" + coll = FakeMessageCollection() + for i in range(3): + msg = FakeMessage("text", message_ordinal=i) + await coll.append(msg) + yield coll + + @pytest.mark.asyncio + async def test_single_message_range(self, messages: FakeMessageCollection) -> None: + """Point range (end=None) should use only the start message's timestamp.""" + tr = TextRange(start=TextLocation(1)) + dr = await get_enclosing_date_range_for_text_range(messages, tr) + assert dr is not None + assert dr.start.hour == 1 + assert dr.end is None + + @pytest.mark.asyncio + async def test_multi_message_range_uses_last_included( + self, messages: FakeMessageCollection + ) -> None: + """Range [0, 2) should use message 1 for end (ordinal 2-1=1), not message 2.""" + tr = TextRange( + start=TextLocation(0), + end=TextLocation(2), # exclusive end + ) + dr = await get_enclosing_date_range_for_text_range(messages, tr) + assert dr is not None + assert dr.start.hour == 0 + # End timestamp comes from message ordinal 1 (= 2-1), NOT ordinal 2: + assert dr.end is not None + assert dr.end.hour == 1 + + @pytest.mark.asyncio + async def test_adjacent_messages(self, messages: FakeMessageCollection) -> None: + """Range [1, 2) covers only message 1.""" + tr = TextRange( + start=TextLocation(1), + end=TextLocation(2), + ) + dr = await get_enclosing_date_range_for_text_range(messages, tr) + assert dr is not None + assert dr.start.hour == 1 + assert dr.end is not None + assert dr.end.hour == 1 # same message: end-1 == start + + @pytest.mark.asyncio + async def test_no_timestamp_returns_none(self) -> None: + """If start message has no timestamp, return None.""" + coll = FakeMessageCollection() + msg = FakeMessage("text") # no message_ordinal → no timestamp + await coll.append(msg) + tr = TextRange(start=TextLocation(0)) + dr = await get_enclosing_date_range_for_text_range(coll, tr) + assert dr is None + + +# --------------------------------------------------------------------------- +# Helper functions (also exercised for completeness) +# --------------------------------------------------------------------------- + + +class TestGetEnclosingTextRange: + def test_single_ordinal(self) -> None: + tr = get_enclosing_text_range([5]) + assert tr is not None + assert tr.start.message_ordinal == 5 + assert tr.end is None # point range + + def test_multiple_ordinals(self) -> None: + tr = get_enclosing_text_range([3, 1, 7]) + assert tr is not None + assert tr.start.message_ordinal == 1 + assert tr.end is not None + assert tr.end.message_ordinal == 7 + + def test_empty_ordinals(self) -> None: + tr = get_enclosing_text_range([]) + assert tr is None + + +class TestTextRangeFromMessageRange: + def test_point(self) -> None: + tr = text_range_from_message_range(3, 3) + assert tr is not None + assert tr.start.message_ordinal == 3 + assert tr.end is None + + def test_range(self) -> None: + tr = text_range_from_message_range(2, 5) + assert tr is not None + assert tr.start.message_ordinal == 2 + assert tr.end is not None + assert 
tr.end.message_ordinal == 5 + + def test_invalid_raises(self) -> None: + with pytest.raises(ValueError, match="Expect message ordinal range"): + text_range_from_message_range(5, 2) From ccb96f25a684b41fb80b6928a211439bc387b8dd Mon Sep 17 00:00:00 2001 From: Bernhard Merkle Date: Fri, 27 Feb 2026 16:17:24 +0100 Subject: [PATCH 31/49] Remove duplicate __all__ in interfaces_search.py --- src/typeagent/knowpro/interfaces_search.py | 17 +++-------------- 1 file changed, 3 insertions(+), 14 deletions(-) diff --git a/src/typeagent/knowpro/interfaces_search.py b/src/typeagent/knowpro/interfaces_search.py index 4ff20c7f..79539e58 100644 --- a/src/typeagent/knowpro/interfaces_search.py +++ b/src/typeagent/knowpro/interfaces_search.py @@ -18,16 +18,15 @@ ) __all__ = [ - "SearchTerm", "KnowledgePropertyName", "PropertySearchTerm", + "SearchSelectExpr", + "SearchTerm", "SearchTermGroup", "SearchTermGroupTypes", + "SemanticRefSearchResult", "WhenFilter", - "SearchSelectExpr", ] - - @dataclass class SearchTerm: """Represents a term being searched for. @@ -144,13 +143,3 @@ class SemanticRefSearchResult: semantic_ref_matches: list[ScoredSemanticRefOrdinal] -__all__ = [ - "KnowledgePropertyName", - "PropertySearchTerm", - "SearchSelectExpr", - "SearchTerm", - "SearchTermGroup", - "SearchTermGroupTypes", - "SemanticRefSearchResult", - "WhenFilter", -] From b4d4a4036fcecdf2eceb9b929fb42286afe023b2 Mon Sep 17 00:00:00 2001 From: Bernhard Merkle Date: Fri, 27 Feb 2026 16:17:25 +0100 Subject: [PATCH 32/49] when cur_chunk was empty (i.e. at the very start, or right after yielding a full chunk), the merged result would begin with a spurious separator like "\n\n". The fix adds a guard. Also added testcases. --- src/typeagent/emails/email_import.py | 3 +- tests/test_email_import.py | 102 +++++++++++++++++++++++++++ 2 files changed, 104 insertions(+), 1 deletion(-) create mode 100644 tests/test_email_import.py diff --git a/src/typeagent/emails/email_import.py b/src/typeagent/emails/email_import.py index ce61d6bb..88f13b94 100644 --- a/src/typeagent/emails/email_import.py +++ b/src/typeagent/emails/email_import.py @@ -263,7 +263,8 @@ def _merge_chunks( yield cur_chunk cur_chunk = new_chunk else: - cur_chunk += separator + if cur_chunk: + cur_chunk += separator cur_chunk += new_chunk if (len(cur_chunk)) > 0: diff --git a/tests/test_email_import.py b/tests/test_email_import.py new file mode 100644 index 00000000..371136bc --- /dev/null +++ b/tests/test_email_import.py @@ -0,0 +1,102 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +from typeagent.emails.email_import import ( + _merge_chunks, + _split_into_paragraphs, + _text_to_chunks, +) + + +class TestMergeChunks: + """Tests for _merge_chunks, specifically the separator-on-empty-chunk fix.""" + + def test_no_leading_separator(self) -> None: + """First chunk must NOT start with the separator.""" + result = list(_merge_chunks(["hello", "world"], "\n\n", 100)) + assert len(result) == 1 + assert result[0] == "hello\n\nworld" + assert not result[0].startswith("\n") + + def test_no_leading_separator_after_yield(self) -> None: + """After yielding a full chunk, the next chunk must not start with separator.""" + # Each piece is 5 chars; max_chunk_length=8 forces a split after each.
+ pieces = ["aaaaa", "bbbbb", "ccccc"] + result = list(_merge_chunks(pieces, "--", 8)) + for chunk in result: + assert not chunk.startswith("--"), f"chunk {chunk!r} starts with separator" + + def test_single_chunk(self) -> None: + result = list(_merge_chunks(["only"], "\n\n", 100)) + assert result == ["only"] + + def test_empty_input(self) -> None: + result = list(_merge_chunks([], "\n\n", 100)) + assert result == [] + + def test_exact_fit(self) -> None: + """Two chunks that fit exactly within max_chunk_length.""" + # "ab" + "\n\n" + "cd" = 6 chars + result = list(_merge_chunks(["ab", "cd"], "\n\n", 6)) + assert result == ["ab\n\ncd"] + + def test_overflow_splits(self) -> None: + """Chunks that don't fit together should be yielded separately.""" + # "ab" + "\n\n" + "cd" = 6 chars, max is 5 -> must split + result = list(_merge_chunks(["ab", "cd"], "\n\n", 5)) + assert result == ["ab", "cd"] + + def test_truncation_of_oversized_chunk(self) -> None: + """A single chunk longer than max_chunk_length is truncated.""" + result = list(_merge_chunks(["abcdefghij"], "\n\n", 5)) + assert result == ["abcde"] + + def test_multiple_merges_and_splits(self) -> None: + pieces = ["aa", "bb", "cccccc", "dd"] + # "aa" + "--" + "bb" = 6, fits in 8 + # "cccccc" alone = 6, can't merge with previous (6+2+6=14>8), yield "aa--bb" + # "cccccc" + "--" + "dd" = 10 > 8, yield "cccccc" + # "dd" yielded at end + result = list(_merge_chunks(pieces, "--", 8)) + assert result == ["aa--bb", "cccccc", "dd"] + + +class TestSplitIntoParagraphs: + def test_basic_split(self) -> None: + text = "para1\n\npara2\n\npara3" + assert _split_into_paragraphs(text) == ["para1", "para2", "para3"] + + def test_multiple_newlines(self) -> None: + text = "a\n\n\n\nb" + assert _split_into_paragraphs(text) == ["a", "b"] + + def test_no_split(self) -> None: + assert _split_into_paragraphs("single paragraph") == ["single paragraph"] + + def test_leading_trailing_newlines(self) -> None: + text = "\n\nfoo\n\n" + result = _split_into_paragraphs(text) + assert "foo" in result + assert "" not in result + + +class TestTextToChunks: + def test_short_text_single_chunk(self) -> None: + result = _text_to_chunks("short text", max_chunk_length=100) + assert result == ["short text"] + + def test_long_text_splits(self) -> None: + text = "para one\n\npara two\n\npara three" + result = _text_to_chunks(text, max_chunk_length=15) + assert len(result) > 1 + for chunk in result: + assert not chunk.startswith("\n"), f"chunk {chunk!r} has leading newline" + + def test_no_leading_separator_in_any_chunk(self) -> None: + """Regression: no chunk should start with the paragraph separator.""" + text = "A" * 50 + "\n\n" + "B" * 50 + "\n\n" + "C" * 50 + result = _text_to_chunks(text, max_chunk_length=60) + for chunk in result: + assert not chunk.startswith( + "\n\n" + ), f"chunk {chunk!r} has leading separator" From b52910bad83b69943b9ba085af69ab082a0aac13 Mon Sep 17 00:00:00 2001 From: Bernhard Merkle Date: Fri, 27 Feb 2026 16:17:25 +0100 Subject: [PATCH 33/49] format --- src/typeagent/knowpro/interfaces_search.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/typeagent/knowpro/interfaces_search.py b/src/typeagent/knowpro/interfaces_search.py index 79539e58..c3727d21 100644 --- a/src/typeagent/knowpro/interfaces_search.py +++ b/src/typeagent/knowpro/interfaces_search.py @@ -27,6 +27,8 @@ "SemanticRefSearchResult", "WhenFilter", ] + + @dataclass class SearchTerm: """Represents a term being searched for. 
@@ -141,5 +143,3 @@ class SemanticRefSearchResult: term_matches: set[str] semantic_ref_matches: list[ScoredSemanticRefOrdinal] - - From b7fd4a9b5e09c3de5fcf601ecdf1d833dd530f48 Mon Sep 17 00:00:00 2001 From: Bernhard Merkle Date: Fri, 27 Feb 2026 16:17:25 +0100 Subject: [PATCH 34/49] removed bare import coverage at the top level. The coverage package is a dev/test dependency. MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Also fixed the prompt-section role mapping: MCPTypeChatModel.complete handled the "user" role correctly but defaulted everything else to "assistant" — including "system" prompts. MCP's SamplingMessage only accepts "user" or "assistant" roles. A "system" prompt section from TypeChat would silently be sent as "assistant", which changes LLM behavior. added testcases --- src/typeagent/mcp/server.py | 14 ++- tests/test_mcp_server_unit.py | 167 ++++++++++++++++++++++++++++++++++ 2 files changed, 179 insertions(+), 2 deletions(-) create mode 100644 tests/test_mcp_server_unit.py diff --git a/src/typeagent/mcp/server.py b/src/typeagent/mcp/server.py index dcd4a3cf..8fb4d03e 100644 --- a/src/typeagent/mcp/server.py +++ b/src/typeagent/mcp/server.py @@ -9,7 +9,10 @@ import time from typing import Any -import coverage +try: + import coverage +except ImportError: + coverage = None # type: ignore[assignment] from dotenv import load_dotenv from mcp.server.fastmcp import Context, FastMCP @@ -18,7 +21,8 @@ import typechat # Enable coverage.py before local imports (a no-op unless COVERAGE_PROCESS_START is set). -coverage.process_startup() +if coverage is not None: + coverage.process_startup() from typeagent.aitools import embeddings, utils from typeagent.knowpro import answers, query, searchlang @@ -246,6 +250,12 @@ async def query_conversation( return QuestionResponse( success=True, answer=combined_answer.answer or "", time_used=dt ) + case _: + return QuestionResponse( + success=False, + answer=f"Unexpected answer type: {combined_answer.type}", + time_used=dt, + ) # Run the MCP server diff --git a/tests/test_mcp_server_unit.py b/tests/test_mcp_server_unit.py new file mode 100644 index 00000000..cfc0b918 --- /dev/null +++ b/tests/test_mcp_server_unit.py @@ -0,0 +1,167 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License.
+# --------------------------------------------------------------------------- + + +def test_server_module_imports() -> None: + """Importing the server module should not raise even without coverage.""" + import typeagent.mcp.server as mod + + assert hasattr(mod, "mcp") # The FastMCP instance exists + + +# --------------------------------------------------------------------------- +# Change 2: PromptSection role mapping ("system" → "assistant") +# --------------------------------------------------------------------------- + + +class TestMCPTypeChatModelRoleMapping: + """Verify that PromptSection roles are mapped correctly to MCP roles.""" + + @staticmethod + def _make_model() -> tuple[MCPTypeChatModel, AsyncMock]: + session = AsyncMock() + # create_message returns a result with TextContent + session.create_message.return_value = AsyncMock( + content=TextContent(type="text", text="response") + ) + model = MCPTypeChatModel(session) + return model, session + + @pytest.mark.asyncio + async def test_string_prompt_becomes_user_message(self) -> None: + model, session = self._make_model() + await model.complete("hello") + + call_args = session.create_message.call_args + messages: list[SamplingMessage] = call_args.kwargs["messages"] + assert len(messages) == 1 + assert messages[0].role == "user" + assert isinstance(messages[0].content, TextContent) + assert messages[0].content.text == "hello" + + @pytest.mark.asyncio + async def test_user_role_preserved(self) -> None: + model, session = self._make_model() + sections: list[typechat.PromptSection] = [ + {"role": "user", "content": "question"}, + ] + await model.complete(sections) + + messages: list[SamplingMessage] = session.create_message.call_args.kwargs[ + "messages" + ] + assert messages[0].role == "user" + + @pytest.mark.asyncio + async def test_assistant_role_preserved(self) -> None: + model, session = self._make_model() + sections: list[typechat.PromptSection] = [ + {"role": "assistant", "content": "context"}, + ] + await model.complete(sections) + + messages: list[SamplingMessage] = session.create_message.call_args.kwargs[ + "messages" + ] + assert messages[0].role == "assistant" + + @pytest.mark.asyncio + async def test_system_role_mapped_to_assistant(self) -> None: + """System role doesn't exist in MCP SamplingMessage; it must be mapped.""" + model, session = self._make_model() + sections: list[typechat.PromptSection] = [ + {"role": "system", "content": "instructions"}, + {"role": "user", "content": "question"}, + ] + await model.complete(sections) + + messages: list[SamplingMessage] = session.create_message.call_args.kwargs[ + "messages" + ] + assert messages[0].role == "assistant" # "system" → "assistant" + assert messages[1].role == "user" + + @pytest.mark.asyncio + async def test_mixed_roles_order(self) -> None: + model, session = self._make_model() + sections: list[typechat.PromptSection] = [ + {"role": "system", "content": "sys"}, + {"role": "user", "content": "usr"}, + {"role": "assistant", "content": "asst"}, + ] + await model.complete(sections) + + messages: list[SamplingMessage] = session.create_message.call_args.kwargs[ + "messages" + ] + assert [m.role for m in messages] == ["assistant", "user", "assistant"] + + @pytest.mark.asyncio + async def test_exception_returns_failure(self) -> None: + model, session = self._make_model() + session.create_message.side_effect = RuntimeError("boom") + result = await model.complete("test") + assert isinstance(result, typechat.Failure) + assert "boom" in result.message + + @pytest.mark.asyncio + 
async def test_text_content_returns_success(self) -> None: + model, _ = self._make_model() + result = await model.complete("test") + assert isinstance(result, typechat.Success) + assert result.value == "response" + + @pytest.mark.asyncio + async def test_list_content_returns_joined(self) -> None: + model, session = self._make_model() + session.create_message.return_value = AsyncMock( + content=[ + TextContent(type="text", text="part1"), + TextContent(type="text", text="part2"), + ] + ) + result = await model.complete("test") + assert isinstance(result, typechat.Success) + assert result.value == "part1\npart2" + + +# --------------------------------------------------------------------------- +# Change 3: match statement default case in query_conversation +# --------------------------------------------------------------------------- + + +class TestQuestionResponseMatchDefault: + """The match on combined_answer.type must handle unexpected types.""" + + def test_known_types(self) -> None: + """QuestionResponse can represent success and failure.""" + ok = QuestionResponse(success=True, answer="yes", time_used=42) + assert ok.success is True + fail = QuestionResponse(success=False, answer="no", time_used=0) + assert fail.success is False + + def test_answer_type_coverage(self) -> None: + """AnswerResponse.type should only be 'Answered' or 'NoAnswer'.""" + from typeagent.knowpro.answer_response_schema import AnswerResponse + + answered = AnswerResponse(type="Answered", answer="yes") + assert answered.type == "Answered" + no_answer = AnswerResponse(type="NoAnswer", why_no_answer="dunno") + assert no_answer.type == "NoAnswer" From 0f413fa2df59e766a9576688f64d35641f18a4a4 Mon Sep 17 00:00:00 2001 From: Bernhard Merkle Date: Fri, 27 Feb 2026 16:17:25 +0100 Subject: [PATCH 35/49] =?UTF-8?q?dir(self)=20lists=20all=20attributes=20on?= =?UTF-8?q?=20the=20object=20including=20inherited=20methods,=20dunder=20meth?= =?UTF-8?q?ods=20(=5F=5Finit=5F=5F,=20=5F=5Feq=5F=5F,=20=5F=5Fhash=5F=5F,?= =?UTF-8?q?=20=E2=80=A6),=20and=20descriptors.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The fix switches to vars(self), which returns only the instance's __dict__ — the actual dataclass field values. Combined with the not key.startswith("_") and is not None filters, the repr is now clean.
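A toy illustration of the difference (hypothetical class, not the actual SearchOptions):

    import dataclasses

    @dataclasses.dataclass
    class Opts:
        limit: int | None = None

        def helper(self):  # a method, not a field
            pass

    o = Opts(limit=3)
    print(sorted(k for k in dir(o) if not k.startswith("_")))  # ['helper', 'limit']
    print(vars(o))  # {'limit': 3}: instance fields only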
added testcases --- src/typeagent/knowpro/search.py | 8 +++----- tests/test_searchlang.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 34 insertions(+), 5 deletions(-) diff --git a/src/typeagent/knowpro/search.py b/src/typeagent/knowpro/search.py index b7c37ac7..9641590e 100644 --- a/src/typeagent/knowpro/search.py +++ b/src/typeagent/knowpro/search.py @@ -90,11 +90,9 @@ class SearchOptions: def __repr__(self): parts = [] - for key in dir(self): - if not key.startswith("_"): - value = getattr(self, key) - if value is not None: - parts.append(f"{key}={value!r}") + for key, value in vars(self).items(): + if not key.startswith("_") and value is not None: + parts.append(f"{key}={value!r}") return f"{self.__class__.__name__}({', '.join(parts)})" diff --git a/tests/test_searchlang.py b/tests/test_searchlang.py index 46cdae75..15c80aaf 100644 --- a/tests/test_searchlang.py +++ b/tests/test_searchlang.py @@ -39,6 +39,37 @@ def test_true_field_shown(self) -> None: opts = SearchOptions(exact_match=True) assert "exact_match=True" in repr(opts) + def test_all_fields_set(self) -> None: + """When every field is non-None, all appear in repr.""" + opts = SearchOptions( + max_knowledge_matches=10, + exact_match=True, + max_message_matches=20, + max_chars_in_budget=5000, + threshold_score=0.75, + ) + r = repr(opts) + assert "max_knowledge_matches=10" in r + assert "exact_match=True" in r + assert "max_message_matches=20" in r + assert "max_chars_in_budget=5000" in r + assert "threshold_score=0.75" in r + + def test_zero_values_shown(self) -> None: + """Zero is not None, so numeric zeros should appear.""" + opts = SearchOptions(max_knowledge_matches=0, threshold_score=0.0) + r = repr(opts) + assert "max_knowledge_matches=0" in r + assert "threshold_score=0.0" in r + + def test_no_dunder_or_method_names(self) -> None: + """The repr must not contain dunder names or method objects.""" + opts = SearchOptions(max_knowledge_matches=5) + r = repr(opts) + assert "__init__" not in r + assert "__eq__" not in r + assert "bound method" not in r + class TestLanguageSearchOptionsRepr: """Tests for LanguageSearchOptions.__repr__ (subclass of SearchOptions).""" From 996ebf0459ef443f65dbed699bcbaccf125d66ca Mon Sep 17 00:00:00 2001 From: Bernhard Merkle Date: Fri, 27 Feb 2026 16:17:25 +0100 Subject: [PATCH 36/49] get_text_range_for_date_range: for objects that don't have .ordinal, the code would crash with AttributeError. The fix introduces a local ordinal = 0 counter incremented on each iteration, which works regardless of the message implementation. Also added a message.timestamp null guard. A related fix for the SemanticRefAccumulator changes: WhereSemanticRefExpr now uses the with_term_matches classmethod factory, which copies the search-term provenance set, so mutations to the filtered accumulator's set don't affect the source.
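A sketch of the aliasing hazard the factory avoids (plain sets standing in for the accumulators):

    source_matches = {"term_a"}
    shared = source_matches        # old behavior: both names alias one set
    copied = set(source_matches)   # with_term_matches behavior: independent copy
    shared.add("x")                # leaks into source_matches
    copied.add("y")                # leaves source_matches untouched
    assert "x" in source_matches and "y" not in source_matches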
added testcases --- src/typeagent/knowpro/query.py | 13 +++- tests/test_query.py | 130 +++++++++++++++++++++++++++++++++ 2 files changed, 139 insertions(+), 4 deletions(-) diff --git a/src/typeagent/knowpro/query.py b/src/typeagent/knowpro/query.py index 5151054c..dafeef84 100644 --- a/src/typeagent/knowpro/query.py +++ b/src/typeagent/knowpro/query.py @@ -101,15 +101,20 @@ async def get_text_range_for_date_range( messages = conversation.messages range_start_ordinal: MessageOrdinal = -1 range_end_ordinal = range_start_ordinal + ordinal = 0 async for message in messages: - if Datetime.fromisoformat(message.timestamp) in date_range: + if ( + message.timestamp + and Datetime.fromisoformat(message.timestamp) in date_range + ): if range_start_ordinal < 0: - range_start_ordinal = message.ordinal - range_end_ordinal = message.ordinal + range_start_ordinal = ordinal + range_end_ordinal = ordinal else: if range_start_ordinal >= 0: # We have a range, so break. break + ordinal += 1 if range_start_ordinal >= 0: return TextRange( start=TextLocation(range_start_ordinal), @@ -696,7 +701,7 @@ class WhereSemanticRefExpr(QueryOpExpr[SemanticRefAccumulator]): async def eval(self, context: QueryEvalContext) -> SemanticRefAccumulator: accumulator = await self.source_expr.eval(context) - filtered = SemanticRefAccumulator(accumulator.search_term_matches) + filtered = SemanticRefAccumulator.with_term_matches(accumulator) # Filter matches asynchronously filtered_matches = [] diff --git a/tests/test_query.py b/tests/test_query.py index 5a58df3d..e65eea6c 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -680,3 +680,133 @@ async def test_lookup_knowledge_type(): assert {r.semantic_ref_ordinal for r in result} == {0, 2} # Should return empty list if no match assert await lookup_knowledge_type(collection, "action") == [] + + +# --------------------------------------------------------------------------- +# Change 1: get_text_range_for_date_range uses manual ordinal counter +# (not message.ordinal) and guards against None timestamp. 
+# --------------------------------------------------------------------------- + + +class TestGetTextRangeForDateRange: + """Tests for the ordinal counter fix and timestamp None guard.""" + + @pytest.mark.asyncio + async def test_messages_without_ordinal_attribute(self) -> None: + """Messages that lack .ordinal should still work (manual counter).""" + + class BareMessage(FakeMessage): + """A message subclass that explicitly lacks .ordinal.""" + + def __init__(self, ts: str) -> None: + super().__init__("text") + self.timestamp = ts + if hasattr(self, "ordinal"): + del self.ordinal + + conv = FakeConversation( + messages=[ + BareMessage("2020-01-01T01:00:00"), + BareMessage("2020-01-01T02:00:00"), + ], + ) + date_range = DateRange( + start=Datetime(2020, 1, 1, 0, 0, 0), + end=Datetime(2020, 1, 2, 0, 0, 0), + ) + result = await get_text_range_for_date_range(conv, date_range) + assert result is not None + assert result.start.message_ordinal == 0 + assert result.end is not None + assert result.end.message_ordinal == 2 # exclusive end + + @pytest.mark.asyncio + async def test_none_timestamp_skipped(self) -> None: + """Messages with None timestamp should be skipped, not crash.""" + conv = FakeConversation( + messages=[ + FakeMessage("no timestamp"), # timestamp=None + FakeMessage("has timestamp", message_ordinal=1), + ], + ) + date_range = DateRange( + start=Datetime(2020, 1, 1, 0, 0, 0), + end=Datetime(2020, 1, 2, 0, 0, 0), + ) + result = await get_text_range_for_date_range(conv, date_range) + # Only message at ordinal 1 matches: + assert result is not None + assert result.start.message_ordinal == 1 + assert result.end is not None + assert result.end.message_ordinal == 2 + + @pytest.mark.asyncio + async def test_all_none_timestamps_returns_none(self) -> None: + """If every message has None timestamp, result should be None.""" + conv = FakeConversation( + messages=[FakeMessage("a"), FakeMessage("b")], + ) + date_range = DateRange( + start=Datetime(2020, 1, 1, 0, 0, 0), + end=Datetime(2020, 1, 2, 0, 0, 0), + ) + assert await get_text_range_for_date_range(conv, date_range) is None + + @pytest.mark.asyncio + async def test_single_message_in_range(self) -> None: + conv = FakeConversation( + messages=[FakeMessage("msg", message_ordinal=0)], + ) + date_range = DateRange( + start=Datetime(2020, 1, 1, 0, 0, 0), + end=Datetime(2020, 1, 2, 0, 0, 0), + ) + result = await get_text_range_for_date_range(conv, date_range) + assert result is not None + assert result.start.message_ordinal == 0 + assert result.end is not None + assert result.end.message_ordinal == 1 + + +# --------------------------------------------------------------------------- +# Change 3: WhereSemanticRefExpr uses with_term_matches (provenance copy) +# --------------------------------------------------------------------------- + + +class TestWhereSemanticRefExprProvenance: + """Verify that WhereSemanticRefExpr copies (not shares) search_term_matches.""" + + @pytest.mark.asyncio + async def test_filtered_accumulator_has_copied_provenance( + self, searchable_conversation: FakeConversation + ) -> None: + """The filtered accumulator's search_term_matches is a copy.""" + from typeagent.knowpro.query import WhereSemanticRefExpr + + # Build a source accumulator with some provenance + src = SemanticRefAccumulator() + src.search_term_matches.add("term_a") + src.add_term_matches( + Term("test"), + [ScoredSemanticRefOrdinal(0, 1.0)], + is_exact_match=True, + weight=1.0, + ) + + # Create a trivial source expression that returns the above accumulator + 
class ConstExpr(QueryOpExpr[SemanticRefAccumulator]): + async def eval(self, context: QueryEvalContext) -> SemanticRefAccumulator: + return src + + # WhereSemanticRefExpr with no predicates (all matches pass) + expr = WhereSemanticRefExpr( + source_expr=ConstExpr(), + predicates=[], + ) + ctx = QueryEvalContext(searchable_conversation) + filtered = await expr.eval(ctx) + + # Provenance was copied, not shared: + assert "term_a" in filtered.search_term_matches + filtered.search_term_matches.add("new_term") + assert "new_term" not in src.search_term_matches From f0749e1e82e5163f8daf71a34a94d6f7c96f0b6f Mon Sep 17 00:00:00 2001 From: Bernhard Merkle Date: Sun, 1 Mar 2026 16:02:40 +0100 Subject: [PATCH 37/49] Clean up import statements in test_utils.py Removed redundant import statements for pydantic.dataclasses and typechat. --- tests/test_utils.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index f8f77d3e..bc4ec425 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -11,9 +11,6 @@ import pydantic.dataclasses import typechat -import pydantic.dataclasses -import typechat - import typeagent.aitools.utils as utils From 4611d844fac6f6453bd3774fca6b6fc3e31136f8 Mon Sep 17 00:00:00 2001 From: Bernhard Merkle Date: Sun, 1 Mar 2026 14:48:54 +0100 Subject: [PATCH 38/49] removed invalid , (comma) in URL specification and updated test --- src/typeagent/aitools/utils.py | 2 +- tests/test_utils.py | 9 --------- 2 files changed, 1 insertion(+), 10 deletions(-) diff --git a/src/typeagent/aitools/utils.py b/src/typeagent/aitools/utils.py index e556633b..adba401a 100644 --- a/src/typeagent/aitools/utils.py +++ b/src/typeagent/aitools/utils.py @@ -191,7 +191,7 @@ def parse_azure_endpoint( if not azure_endpoint: raise RuntimeError(f"Environment variable {endpoint_envvar} not found") - m = re.search(r"[?&,]api-version=([\d-]+(?:preview)?)", azure_endpoint) + m = re.search(r"[?&]api-version=([\d-]+(?:preview)?)", azure_endpoint) if not m: raise RuntimeError( f"{endpoint_envvar}={azure_endpoint} doesn't contain valid api-version field" diff --git a/tests/test_utils.py b/tests/test_utils.py index bc4ec425..5966af61 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -78,15 +78,6 @@ def test_api_version_after_ampersand(self, monkeypatch: pytest.MonkeyPatch) -> N _, version = utils.parse_azure_endpoint("TEST_ENDPOINT") assert version == "2025-01-01-preview" - def test_api_version_after_comma(self, monkeypatch: pytest.MonkeyPatch) -> None: - """api-version preceded by comma (alternate separator).""" - monkeypatch.setenv( - "TEST_ENDPOINT", - "https://myhost.openai.azure.com/openai/deployments/gpt-4?foo=bar,api-version=2024-06-01", - ) - _, version = utils.parse_azure_endpoint("TEST_ENDPOINT") - assert version == "2024-06-01" - def test_missing_env_var_raises(self, monkeypatch: pytest.MonkeyPatch) -> None: """RuntimeError when the environment variable is not set.""" monkeypatch.delenv("NONEXISTENT_ENDPOINT", raising=False) From 77d06be840890dde08b2f81ca3727b44b781ecf9 Mon Sep 17 00:00:00 2001 From: Bernhard Merkle Date: Sun, 1 Mar 2026 14:59:42 +0100 Subject: [PATCH 39/49] - fixed implementation to match the corrected half-open interval semantics - Removed the incorrect - 1 from range.end.message_ordinal. Now fetches the timestamp of the message at the exclusive end ordinal (the first message after the range). Added a bounds check: if the ordinal is past the end of the collection, end is left as None. 
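The half-open semantics in a nutshell (toy data; three messages with ordinals 0..2):

    timestamps = ["t0", "t1", "t2"]  # timestamp per message ordinal
    start, end = 0, 2                # TextRange [0, 2) covers messages 0 and 1
    # The end timestamp comes from the first message after the range, if any:
    end_ts = timestamps[end] if end < len(timestamps) else None  # "t2"
    # With end == 3 the range reaches past the last message, so end_ts stays None.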
- fixed test cases - added test for the edge case where the range end is past the last message --- src/typeagent/knowpro/answers.py | 11 ++++++----- tests/test_answers.py | 24 ++++++++++++++++++------ 2 files changed, 24 insertions(+), 11 deletions(-) diff --git a/src/typeagent/knowpro/answers.py b/src/typeagent/knowpro/answers.py index c6832efb..d1b87a35 100644 --- a/src/typeagent/knowpro/answers.py +++ b/src/typeagent/knowpro/answers.py @@ -404,11 +404,12 @@ async def get_enclosing_date_range_for_text_range( start_timestamp = (await messages.get_item(range.start.message_ordinal)).timestamp if not start_timestamp: return None - end_timestamp = ( - (await messages.get_item(range.end.message_ordinal - 1)).timestamp - if range.end - else None - ) + end_timestamp: str | None = None + if range.end: + end_ordinal = range.end.message_ordinal + if end_ordinal < await messages.size(): + end_timestamp = (await messages.get_item(end_ordinal)).timestamp + # else: range extends to the end of the conversation; leave as None. return DateRange( start=Datetime.fromisoformat(start_timestamp), end=Datetime.fromisoformat(end_timestamp) if end_timestamp else None, diff --git a/tests/test_answers.py b/tests/test_answers.py index 2669ea78..a8fc0ce7 100644 --- a/tests/test_answers.py +++ b/tests/test_answers.py @@ -98,10 +98,10 @@ async def test_single_message_range(self, messages: FakeMessageCollection) -> No assert dr.end is None @pytest.mark.asyncio - async def test_multi_message_range_uses_last_included( + async def test_multi_message_range_uses_exclusive_end( self, messages: FakeMessageCollection ) -> None: - """Range [0, 2) should use message 1 for end (ordinal 2-1=1), not message 2.""" + """Range [0, 2) should use message 2 (the exclusive end) for end timestamp.""" tr = TextRange( start=TextLocation(0), end=TextLocation(2), # exclusive end @@ -109,13 +109,13 @@ async def test_multi_message_range_uses_last_included( dr = await get_enclosing_date_range_for_text_range(messages, tr) assert dr is not None assert dr.start.hour == 0 - # End timestamp comes from message ordinal 1 (= 2-1), NOT ordinal 2: + # End timestamp comes from the message at the exclusive end ordinal: assert dr.end is not None - assert dr.end.hour == 1 + assert dr.end.hour == 2 @pytest.mark.asyncio async def test_adjacent_messages(self, messages: FakeMessageCollection) -> None: - """Range [1, 2) covers only message 1.""" + """Range [1, 2) covers only message 1; end timestamp is message 2.""" tr = TextRange( start=TextLocation(1), end=TextLocation(2), @@ -124,7 +124,19 @@ async def test_adjacent_messages(self, messages: FakeMessageCollection) -> None: assert dr is not None assert dr.start.hour == 1 assert dr.end is not None - assert dr.end.hour == 1 # same message: end-1 == start + assert dr.end.hour == 2 # exclusive end: timestamp of the next message + + @pytest.mark.asyncio + async def test_end_past_last_message(self, messages: FakeMessageCollection) -> None: + """If the exclusive end ordinal is past the last message, end is None.""" + tr = TextRange( + start=TextLocation(0), + end=TextLocation(3), # messages only have ordinals 0, 1, 2 + ) + dr = await get_enclosing_date_range_for_text_range(messages, tr) + assert dr is not None + assert dr.start.hour == 0 + assert dr.end is None @pytest.mark.asyncio async def test_no_timestamp_returns_none(self) -> None: From 3a69313f8bb8a6b12df3504c9361534450519312 Mon Sep 17 00:00:00 2001 From: Bernhard Merkle Date: Sun, 1 Mar 2026 15:16:28 +0100 Subject: [PATCH 40/49] - used aenumerate in the for loop - 
as Python 3.12 does not have aenumerate in its stdlib, added a small helper to utils.py --- src/typeagent/knowpro/query.py | 5 ++--- src/typeagent/knowpro/utils.py | 9 +++++++++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/src/typeagent/knowpro/query.py b/src/typeagent/knowpro/query.py index dafeef84..d3a4640c 100644 --- a/src/typeagent/knowpro/query.py +++ b/src/typeagent/knowpro/query.py @@ -45,6 +45,7 @@ Thread, ) from .kplib import ConcreteEntity +from .utils import aenumerate # TODO: Move to compilelib.py type BooleanOp = Literal["and", "or", "or_max"] @@ -101,8 +102,7 @@ async def get_text_range_for_date_range( messages = conversation.messages range_start_ordinal: MessageOrdinal = -1 range_end_ordinal = range_start_ordinal - ordinal = 0 - async for message in messages: + async for ordinal, message in aenumerate(messages): if ( message.timestamp and Datetime.fromisoformat(message.timestamp) in date_range ): @@ -114,7 +114,6 @@ if range_start_ordinal >= 0: # We have a range, so break. break - ordinal += 1 if range_start_ordinal >= 0: return TextRange( start=TextLocation(range_start_ordinal), diff --git a/src/typeagent/knowpro/utils.py b/src/typeagent/knowpro/utils.py index 298c09db..92eedacc 100644 --- a/src/typeagent/knowpro/utils.py +++ b/src/typeagent/knowpro/utils.py @@ -3,9 +3,18 @@ """Utility functions for the knowpro package.""" +from collections.abc import AsyncIterable + from .interfaces import MessageOrdinal, TextLocation, TextRange +async def aenumerate[T](aiterable: AsyncIterable[T], start: int = 0): + i = start + async for item in aiterable: + yield i, item + i += 1 + + def text_range_from_message_chunk( message_ordinal: MessageOrdinal, chunk_ordinal: int = 0, From 70bc27dc4dac3840976d6f9d4b69b94892c9f924 Mon Sep 17 00:00:00 2001 From: Bernhard Merkle Date: Sun, 1 Mar 2026 15:23:43 +0100 Subject: [PATCH 41/49] Merged all content from test_mcp_server_unit.py into test_mcp_server.py --- tests/test_mcp_server.py | 169 +++++++++++++++++++++++++++++++++- tests/test_mcp_server_unit.py | 167 --------------------------------- 2 files changed, 167 insertions(+), 169 deletions(-) delete mode 100644 tests/test_mcp_server_unit.py diff --git a/tests/test_mcp_server.py b/tests/test_mcp_server.py index 24933ca2..d70c2275 100644 --- a/tests/test_mcp_server.py +++ b/tests/test_mcp_server.py @@ -1,12 +1,13 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License.
-"""End-to-end tests for the MCP server.""" +"""End-to-end and unit tests for the MCP server.""" import json import os import sys from typing import Any +from unittest.mock import AsyncMock import pytest @@ -14,10 +15,17 @@ from mcp.client.session import ClientSession as ClientSessionType from mcp.client.stdio import stdio_client from mcp.shared.context import RequestContext -from mcp.types import CreateMessageRequestParams, CreateMessageResult, TextContent +from mcp.types import ( + CreateMessageRequestParams, + CreateMessageResult, + SamplingMessage, + TextContent, +) from openai.types.chat import ChatCompletionMessageParam +import typechat from typeagent.aitools.utils import create_async_openai_client +from typeagent.mcp.server import MCPTypeChatModel, QuestionResponse from conftest import EPISODE_53_INDEX @@ -178,3 +186,160 @@ async def test_mcp_server_empty_question(server_params: StdioServerParameters): response_data = json.loads(response_text) assert response_data["success"] is False assert "No question provided" in response_data["answer"] + + +# --------------------------------------------------------------------------- +# Unit tests (formerly in test_mcp_server_unit.py) +# --------------------------------------------------------------------------- + +# Coverage import guard — tested implicitly (the module loads at all +# without `coverage` installed). We just verify the guard didn't break the +# import. + + +def test_server_module_imports() -> None: + """Importing the server module should not raise even without coverage.""" + import typeagent.mcp.server as mod + + assert hasattr(mod, "mcp") # The FastMCP instance exists + + +# --------------------------------------------------------------------------- +# PromptSection role mapping ("system" → "assistant") +# --------------------------------------------------------------------------- + + +class TestMCPTypeChatModelRoleMapping: + """Verify that PromptSection roles are mapped correctly to MCP roles.""" + + @staticmethod + def _make_model() -> tuple[MCPTypeChatModel, AsyncMock]: + session = AsyncMock() + # create_message returns a result with TextContent + session.create_message.return_value = AsyncMock( + content=TextContent(type="text", text="response") + ) + model = MCPTypeChatModel(session) + return model, session + + @pytest.mark.asyncio + async def test_string_prompt_becomes_user_message(self) -> None: + model, session = self._make_model() + await model.complete("hello") + + call_args = session.create_message.call_args + messages: list[SamplingMessage] = call_args.kwargs["messages"] + assert len(messages) == 1 + assert messages[0].role == "user" + assert isinstance(messages[0].content, TextContent) + assert messages[0].content.text == "hello" + + @pytest.mark.asyncio + async def test_user_role_preserved(self) -> None: + model, session = self._make_model() + sections: list[typechat.PromptSection] = [ + {"role": "user", "content": "question"}, + ] + await model.complete(sections) + + messages: list[SamplingMessage] = session.create_message.call_args.kwargs[ + "messages" + ] + assert messages[0].role == "user" + + @pytest.mark.asyncio + async def test_assistant_role_preserved(self) -> None: + model, session = self._make_model() + sections: list[typechat.PromptSection] = [ + {"role": "assistant", "content": "context"}, + ] + await model.complete(sections) + + messages: list[SamplingMessage] = session.create_message.call_args.kwargs[ + "messages" + ] + assert messages[0].role == "assistant" + + @pytest.mark.asyncio + async def 
test_system_role_mapped_to_assistant(self) -> None: + """System role doesn't exist in MCP SamplingMessage; it must be mapped.""" + model, session = self._make_model() + sections: list[typechat.PromptSection] = [ + {"role": "system", "content": "instructions"}, + {"role": "user", "content": "question"}, + ] + await model.complete(sections) + + messages: list[SamplingMessage] = session.create_message.call_args.kwargs[ + "messages" + ] + assert messages[0].role == "assistant" # "system" → "assistant" + assert messages[1].role == "user" + + @pytest.mark.asyncio + async def test_mixed_roles_order(self) -> None: + model, session = self._make_model() + sections: list[typechat.PromptSection] = [ + {"role": "system", "content": "sys"}, + {"role": "user", "content": "usr"}, + {"role": "assistant", "content": "asst"}, + ] + await model.complete(sections) + + messages: list[SamplingMessage] = session.create_message.call_args.kwargs[ + "messages" + ] + assert [m.role for m in messages] == ["assistant", "user", "assistant"] + + @pytest.mark.asyncio + async def test_exception_returns_failure(self) -> None: + model, session = self._make_model() + session.create_message.side_effect = RuntimeError("boom") + result = await model.complete("test") + assert isinstance(result, typechat.Failure) + assert "boom" in result.message + + @pytest.mark.asyncio + async def test_text_content_returns_success(self) -> None: + model, _ = self._make_model() + result = await model.complete("test") + assert isinstance(result, typechat.Success) + assert result.value == "response" + + @pytest.mark.asyncio + async def test_list_content_returns_joined(self) -> None: + model, session = self._make_model() + session.create_message.return_value = AsyncMock( + content=[ + TextContent(type="text", text="part1"), + TextContent(type="text", text="part2"), + ] + ) + result = await model.complete("test") + assert isinstance(result, typechat.Success) + assert result.value == "part1\npart2" + + +# --------------------------------------------------------------------------- +# match statement default case in query_conversation +# --------------------------------------------------------------------------- + + +class TestQuestionResponseMatchDefault: + """The match on combined_answer.type must handle unexpected types.""" + + def test_known_types(self) -> None: + """QuestionResponse can represent success and failure.""" + ok = QuestionResponse(success=True, answer="yes", time_used=42) + assert ok.success is True + fail = QuestionResponse(success=False, answer="no", time_used=0) + assert fail.success is False + + def test_answer_type_coverage(self) -> None: + """AnswerResponse.type should only be 'Answered' or 'NoAnswer'.""" + from typeagent.knowpro.answer_response_schema import AnswerResponse + + answered = AnswerResponse(type="Answered", answer="yes") + assert answered.type == "Answered" + no_answer = AnswerResponse(type="NoAnswer", why_no_answer="dunno") + assert no_answer.type == "NoAnswer" diff --git a/tests/test_mcp_server_unit.py b/tests/test_mcp_server_unit.py deleted file mode 100644 index cfc0b918..00000000 --- a/tests/test_mcp_server_unit.py +++ /dev/null @@ -1,167 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. 
- -"""Unit tests for server.py changes (coverage guard, role mapping, match default).""" - -from unittest.mock import AsyncMock - -import pytest - -from mcp.types import SamplingMessage, TextContent -import typechat - -from typeagent.mcp.server import MCPTypeChatModel, QuestionResponse - -# --------------------------------------------------------------------------- -# Change 1: coverage import guard — tested implicitly (the module loads at all -# without `coverage` installed). We just verify the guard didn't break the -# import. -# --------------------------------------------------------------------------- - - -def test_server_module_imports() -> None: - """Importing the server module should not raise even without coverage.""" - import typeagent.mcp.server as mod - - assert hasattr(mod, "mcp") # The FastMCP instance exists - - -# --------------------------------------------------------------------------- -# Change 2: PromptSection role mapping ("system" → "assistant") -# --------------------------------------------------------------------------- - - -class TestMCPTypeChatModelRoleMapping: - """Verify that PromptSection roles are mapped correctly to MCP roles.""" - - @staticmethod - def _make_model() -> tuple[MCPTypeChatModel, AsyncMock]: - session = AsyncMock() - # create_message returns a result with TextContent - session.create_message.return_value = AsyncMock( - content=TextContent(type="text", text="response") - ) - model = MCPTypeChatModel(session) - return model, session - - @pytest.mark.asyncio - async def test_string_prompt_becomes_user_message(self) -> None: - model, session = self._make_model() - await model.complete("hello") - - call_args = session.create_message.call_args - messages: list[SamplingMessage] = call_args.kwargs["messages"] - assert len(messages) == 1 - assert messages[0].role == "user" - assert isinstance(messages[0].content, TextContent) - assert messages[0].content.text == "hello" - - @pytest.mark.asyncio - async def test_user_role_preserved(self) -> None: - model, session = self._make_model() - sections: list[typechat.PromptSection] = [ - {"role": "user", "content": "question"}, - ] - await model.complete(sections) - - messages: list[SamplingMessage] = session.create_message.call_args.kwargs[ - "messages" - ] - assert messages[0].role == "user" - - @pytest.mark.asyncio - async def test_assistant_role_preserved(self) -> None: - model, session = self._make_model() - sections: list[typechat.PromptSection] = [ - {"role": "assistant", "content": "context"}, - ] - await model.complete(sections) - - messages: list[SamplingMessage] = session.create_message.call_args.kwargs[ - "messages" - ] - assert messages[0].role == "assistant" - - @pytest.mark.asyncio - async def test_system_role_mapped_to_assistant(self) -> None: - """System role doesn't exist in MCP SamplingMessage; it must be mapped.""" - model, session = self._make_model() - sections: list[typechat.PromptSection] = [ - {"role": "system", "content": "instructions"}, - {"role": "user", "content": "question"}, - ] - await model.complete(sections) - - messages: list[SamplingMessage] = session.create_message.call_args.kwargs[ - "messages" - ] - assert messages[0].role == "assistant" # "system" → "assistant" - assert messages[1].role == "user" - - @pytest.mark.asyncio - async def test_mixed_roles_order(self) -> None: - model, session = self._make_model() - sections: list[typechat.PromptSection] = [ - {"role": "system", "content": "sys"}, - {"role": "user", "content": "usr"}, - {"role": "assistant", "content": "asst"}, 
-        ]
-        await model.complete(sections)
-
-        messages: list[SamplingMessage] = session.create_message.call_args.kwargs[
-            "messages"
-        ]
-        assert [m.role for m in messages] == ["assistant", "user", "assistant"]
-
-    @pytest.mark.asyncio
-    async def test_exception_returns_failure(self) -> None:
-        model, session = self._make_model()
-        session.create_message.side_effect = RuntimeError("boom")
-        result = await model.complete("test")
-        assert isinstance(result, typechat.Failure)
-        assert "boom" in result.message
-
-    @pytest.mark.asyncio
-    async def test_text_content_returns_success(self) -> None:
-        model, _ = self._make_model()
-        result = await model.complete("test")
-        assert isinstance(result, typechat.Success)
-        assert result.value == "response"
-
-    @pytest.mark.asyncio
-    async def test_list_content_returns_joined(self) -> None:
-        model, session = self._make_model()
-        session.create_message.return_value = AsyncMock(
-            content=[
-                TextContent(type="text", text="part1"),
-                TextContent(type="text", text="part2"),
-            ]
-        )
-        result = await model.complete("test")
-        assert isinstance(result, typechat.Success)
-        assert result.value == "part1\npart2"
-
-
-# ---------------------------------------------------------------------------
-# Change 3: match statement default case in query_conversation
-# ---------------------------------------------------------------------------
-
-
-class TestQuestionResponseMatchDefault:
-    """The match on combined_answer.type must handle unexpected types."""
-
-    def test_known_types(self) -> None:
-        """QuestionResponse can represent success and failure."""
-        ok = QuestionResponse(success=True, answer="yes", time_used=42)
-        assert ok.success is True
-        fail = QuestionResponse(success=False, answer="no", time_used=0)
-        assert fail.success is False
-
-    def test_answer_type_coverage(self) -> None:
-        """AnswerResponse.type should only be 'Answered' or 'NoAnswer'."""
-        from typeagent.knowpro.answer_response_schema import AnswerResponse
-
-        answered = AnswerResponse(type="Answered", answer="yes")
-        assert answered.type == "Answered"
-        no_answer = AnswerResponse(type="NoAnswer", why_no_answer="dunno")
-        assert no_answer.type == "NoAnswer"

From d4014cf8fce5a0923ff9bc76c65f387e44ea7f82 Mon Sep 17 00:00:00 2001
From: Bernhard Merkle
Date: Sun, 1 Mar 2026 15:29:10 +0100
Subject: [PATCH 42/49] explained the additional if guard

---
 src/typeagent/knowpro/searchlang.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/typeagent/knowpro/searchlang.py b/src/typeagent/knowpro/searchlang.py
index 54b7eb7b..e2e990de 100644
--- a/src/typeagent/knowpro/searchlang.py
+++ b/src/typeagent/knowpro/searchlang.py
@@ -369,6 +369,7 @@ def compile_action_term_as_search_terms(
         self.compile_entity_terms_as_search_terms(
             action_term.additional_entities, action_group
         )
+        # Only append the nested or_max wrapper when one was created (use_or_max) and it is non-empty.
if use_or_max and action_group.terms: term_group.terms.append(action_group) return term_group From c351f1690d71295489e5940dbe60908ab696dee8 Mon Sep 17 00:00:00 2001 From: Bernhard Merkle Date: Sun, 1 Mar 2026 15:41:26 +0100 Subject: [PATCH 43/49] renamed is_in_range() to contains_range() --- src/typeagent/knowpro/collections.py | 4 ++-- tests/test_collections.py | 12 +++++++----- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/src/typeagent/knowpro/collections.py b/src/typeagent/knowpro/collections.py index c73e83a1..059bc3bf 100644 --- a/src/typeagent/knowpro/collections.py +++ b/src/typeagent/knowpro/collections.py @@ -534,7 +534,7 @@ def add_ranges(self, text_ranges: "list[TextRange] | TextRangeCollection") -> No for text_range in text_ranges._ranges: self.add_range(text_range) - def is_in_range(self, inner_range: TextRange) -> bool: + def contains_range(self, inner_range: TextRange) -> bool: if len(self._ranges) == 0: return False for outer_range in self._ranges: @@ -566,7 +566,7 @@ def is_range_in_scope(self, inner_range: TextRange) -> bool: # We have a very simple impl: we don't intersect/union ranges yet. # Instead, we ensure that the inner range is not rejected by any outer ranges. for outer_ranges in self.text_ranges: - if not outer_ranges.is_in_range(inner_range): + if not outer_ranges.contains_range(inner_range): return False return True diff --git a/tests/test_collections.py b/tests/test_collections.py index 9f913d0d..ca6e14ea 100644 --- a/tests/test_collections.py +++ b/tests/test_collections.py @@ -110,13 +110,15 @@ def test_text_range_collection_add_and_check(): assert len(collection) == 2 - assert collection.is_in_range(range1) is True - assert collection.is_in_range(range2) is True + assert collection.contains_range(range1) is True + assert collection.contains_range(range2) is True assert ( - collection.is_in_range(range3) is True + collection.contains_range(range3) is True ) # range3 [5,10) is inside range1 [0,10) - assert collection.is_in_range(range4) is False # range4 [5,25) spans across ranges - assert collection.is_in_range(range5) is False + assert ( + collection.contains_range(range4) is False + ) # range4 [5,25) spans across ranges + assert collection.contains_range(range5) is False def test_text_ranges_in_scope(): From 5300f099ddd97c9c576232375fe8068c3d0ad123 Mon Sep 17 00:00:00 2001 From: Bernhard Merkle Date: Sun, 1 Mar 2026 21:00:44 +0100 Subject: [PATCH 44/49] changed method to clone() --- src/typeagent/knowpro/collections.py | 17 +++++++---------- src/typeagent/knowpro/query.py | 2 +- tests/test_query.py | 2 +- 3 files changed, 9 insertions(+), 12 deletions(-) diff --git a/src/typeagent/knowpro/collections.py b/src/typeagent/knowpro/collections.py index 059bc3bf..1c729697 100644 --- a/src/typeagent/knowpro/collections.py +++ b/src/typeagent/knowpro/collections.py @@ -259,7 +259,7 @@ class SemanticRefAccumulator(MatchAccumulator[SemanticRefOrdinal]): """Accumulates scored semantic reference matches. ``search_term_matches`` tracks which search terms produced hits (provenance). - Use ``with_term_matches`` to create a derived accumulator that inherits a + Use ``clone`` to create a derived accumulator that inherits a *copy* of the parent's provenance. 
""" @@ -267,13 +267,10 @@ def __init__(self) -> None: super().__init__() self.search_term_matches: set[str] = set() - @classmethod - def with_term_matches( - cls, source: "SemanticRefAccumulator" - ) -> "SemanticRefAccumulator": - """Create a new accumulator inheriting a copy of *source*'s term-match provenance.""" - acc = cls() - acc.search_term_matches = set(source.search_term_matches) + def clone(self) -> "SemanticRefAccumulator": + """Create a new empty accumulator inheriting a copy of this one's term-match provenance.""" + acc = self.__class__() + acc.search_term_matches = set(self.search_term_matches) return acc def add_term_matches( @@ -352,7 +349,7 @@ async def group_matches_by_type( semantic_ref = await semantic_refs.get_item(match.value) group = groups.get(semantic_ref.knowledge.knowledge_type) if group is None: - group = SemanticRefAccumulator.with_term_matches(self) + group = self.clone() groups[semantic_ref.knowledge.knowledge_type] = group group.set_match(match) return groups @@ -362,7 +359,7 @@ async def get_matches_in_scope( semantic_refs: ISemanticRefCollection, ranges_in_scope: "TextRangesInScope", ) -> "SemanticRefAccumulator": - accumulator = SemanticRefAccumulator.with_term_matches(self) + accumulator = self.clone() for match in self: if ranges_in_scope.is_range_in_scope( (await semantic_refs.get_item(match.value)).range diff --git a/src/typeagent/knowpro/query.py b/src/typeagent/knowpro/query.py index d3a4640c..b0967a6e 100644 --- a/src/typeagent/knowpro/query.py +++ b/src/typeagent/knowpro/query.py @@ -700,7 +700,7 @@ class WhereSemanticRefExpr(QueryOpExpr[SemanticRefAccumulator]): async def eval(self, context: QueryEvalContext) -> SemanticRefAccumulator: accumulator = await self.source_expr.eval(context) - filtered = SemanticRefAccumulator.with_term_matches(accumulator) + filtered = accumulator.clone() # Filter matches asynchronously filtered_matches = [] diff --git a/tests/test_query.py b/tests/test_query.py index e65eea6c..c736d655 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -769,7 +769,7 @@ async def test_single_message_in_range(self) -> None: # --------------------------------------------------------------------------- -# Change 3: WhereSemanticRefExpr uses with_term_matches (provenance copy) +# Change 3: WhereSemanticRefExpr uses clone() (provenance copy) # --------------------------------------------------------------------------- From bd0d3298d69f64c72bc9dbd197b2cad91cd77113 Mon Sep 17 00:00:00 2001 From: Bernhard Merkle Date: Mon, 2 Mar 2026 23:36:53 +0100 Subject: [PATCH 45/49] restored commit 582175ce42dd487e779f7da0eee85cb0e2fbe6de --- src/typeagent/knowpro/collections.py | 43 ++++++++-------------------- 1 file changed, 12 insertions(+), 31 deletions(-) diff --git a/src/typeagent/knowpro/collections.py b/src/typeagent/knowpro/collections.py index 1c729697..1e5205f3 100644 --- a/src/typeagent/knowpro/collections.py +++ b/src/typeagent/knowpro/collections.py @@ -77,12 +77,10 @@ def add(self, value: T, score: float, is_exact_match: bool = True) -> None: existing_match.hit_count += 1 existing_match.score += score else: - # Related (non-exact) match: only accumulate related counters. existing_match.related_hit_count += 1 existing_match.related_score += score else: if is_exact_match: - # New exact match: starts with hit_count=1 and the given score. 
self.set_match( Match( value, @@ -93,14 +91,10 @@ def add(self, value: T, score: float, is_exact_match: bool = True) -> None: ) ) else: - # New related-only match: hit_count stays 0 because - # only exact matches count as direct hits. This matters - # for select_with_hit_count / _matches_with_min_hit_count - # which filter on hit_count to weed out noise. self.set_match( Match( value, - hit_count=0, + hit_count=1, score=0.0, related_hit_count=1, related_score=score, @@ -256,22 +250,9 @@ def smooth_match_score[T](match: Match[T]) -> None: class SemanticRefAccumulator(MatchAccumulator[SemanticRefOrdinal]): - """Accumulates scored semantic reference matches. - - ``search_term_matches`` tracks which search terms produced hits (provenance). - Use ``clone`` to create a derived accumulator that inherits a - *copy* of the parent's provenance. - """ - - def __init__(self) -> None: + def __init__(self, search_term_matches: set[str] = set()): super().__init__() - self.search_term_matches: set[str] = set() - - def clone(self) -> "SemanticRefAccumulator": - """Create a new empty accumulator inheriting a copy of this one's term-match provenance.""" - acc = self.__class__() - acc.search_term_matches = set(self.search_term_matches) - return acc + self.search_term_matches = search_term_matches def add_term_matches( self, @@ -349,7 +330,8 @@ async def group_matches_by_type( semantic_ref = await semantic_refs.get_item(match.value) group = groups.get(semantic_ref.knowledge.knowledge_type) if group is None: - group = self.clone() + group = SemanticRefAccumulator() + group.search_term_matches = self.search_term_matches groups[semantic_ref.knowledge.knowledge_type] = group group.set_match(match) return groups @@ -359,7 +341,7 @@ async def get_matches_in_scope( semantic_refs: ISemanticRefCollection, ranges_in_scope: "TextRangesInScope", ) -> "SemanticRefAccumulator": - accumulator = self.clone() + accumulator = SemanticRefAccumulator(self.search_term_matches) for match in self: if ranges_in_scope.is_range_in_scope( (await semantic_refs.get_item(match.value)).range @@ -531,16 +513,15 @@ def add_ranges(self, text_ranges: "list[TextRange] | TextRangeCollection") -> No for text_range in text_ranges._ranges: self.add_range(text_range) - def contains_range(self, inner_range: TextRange) -> bool: + def is_in_range(self, inner_range: TextRange) -> bool: if len(self._ranges) == 0: return False - for outer_range in self._ranges: - if inner_range in outer_range: - return True - # Since ranges are sorted by start, once we pass inner_range's start - # no further range can contain it. + i = bisect.bisect_left(self._ranges, inner_range) + for outer_range in self._ranges[i:]: if outer_range.start > inner_range.start: break + if inner_range in outer_range: + return True return False @@ -563,7 +544,7 @@ def is_range_in_scope(self, inner_range: TextRange) -> bool: # We have a very simple impl: we don't intersect/union ranges yet. # Instead, we ensure that the inner range is not rejected by any outer ranges. 
         for outer_ranges in self.text_ranges:
-            if not outer_ranges.is_in_range(inner_range):
+            if not outer_ranges.contains_range(inner_range):
                 return False
         return True

From b32e8a2b251b2b908a8801356e767b5f0af57cf9 Mon Sep 17 00:00:00 2001
From: Bernhard Merkle
Date: Tue, 3 Mar 2026 00:01:11 +0100
Subject: [PATCH 46/49] - each instance gets its own set - clone() method added - fix for hit_count = 0 applied again - renamed is_in_range to contains_range

---
 src/typeagent/knowpro/collections.py | 29 +++++++++++++++++++---------
 1 file changed, 20 insertions(+), 9 deletions(-)

diff --git a/src/typeagent/knowpro/collections.py b/src/typeagent/knowpro/collections.py
index 1e5205f3..97ccf384 100644
--- a/src/typeagent/knowpro/collections.py
+++ b/src/typeagent/knowpro/collections.py
@@ -91,10 +91,14 @@ def add(self, value: T, score: float, is_exact_match: bool = True) -> None:
                 )
             )
         else:
+            # New related-only match: hit_count stays 0 because
+            # only exact matches count as direct hits. This matters
+            # for select_with_hit_count / _matches_with_min_hit_count
+            # which filter on hit_count to weed out noise.
             self.set_match(
                 Match(
                     value,
-                    hit_count=1,
+                    hit_count=0,
                     score=0.0,
                     related_hit_count=1,
                     related_score=score,
@@ -250,9 +254,17 @@ def smooth_match_score[T](match: Match[T]) -> None:


 class SemanticRefAccumulator(MatchAccumulator[SemanticRefOrdinal]):
-    def __init__(self, search_term_matches: set[str] = set()):
+    def __init__(self, search_term_matches: set[str] | None = None):
         super().__init__()
-        self.search_term_matches = search_term_matches
+        self.search_term_matches = (
+            search_term_matches if search_term_matches is not None else set()
+        )
+
+    def clone(self) -> "SemanticRefAccumulator":
+        """Create a new empty accumulator inheriting a copy of this one's term-match provenance."""
+        acc = self.__class__()
+        acc.search_term_matches = set(self.search_term_matches)
+        return acc

     def add_term_matches(
         self,
@@ -513,11 +525,10 @@ def add_ranges(self, text_ranges: "list[TextRange] | TextRangeCollection") -> No
         for text_range in text_ranges._ranges:
             self.add_range(text_range)

-    def is_in_range(self, inner_range: TextRange) -> bool:
-        if len(self._ranges) == 0:
-            return False
-        i = bisect.bisect_left(self._ranges, inner_range)
-        for outer_range in self._ranges[i:]:
+    def contains_range(self, inner_range: TextRange) -> bool:
+        # Since ranges are sorted by start, once we pass inner_range's start
+        # no further range can contain it.
+        for outer_range in self._ranges:
             if outer_range.start > inner_range.start:
                 break
             if inner_range in outer_range:
                 return True
         return False
@@ -544,7 +555,7 @@ def is_range_in_scope(self, inner_range: TextRange) -> bool:
         # We have a very simple impl: we don't intersect/union ranges yet.
         # Instead, we ensure that the inner range is not rejected by any outer ranges.
for outer_ranges in self.text_ranges: - if not outer_ranges.is_in_range(inner_range): + if not outer_ranges.contains_range(inner_range): return False return True From c602cf7fed3e0a73f19db1e7bf3f8884f71a94a4 Mon Sep 17 00:00:00 2001 From: Bernhard Merkle Date: Tue, 3 Mar 2026 00:09:41 +0100 Subject: [PATCH 47/49] removed comments --- tests/test_query.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/tests/test_query.py b/tests/test_query.py index c736d655..4546e646 100644 --- a/tests/test_query.py +++ b/tests/test_query.py @@ -682,12 +682,6 @@ async def test_lookup_knowledge_type(): assert await lookup_knowledge_type(collection, "action") == [] -# --------------------------------------------------------------------------- -# Change 1: get_text_range_for_date_range uses manual ordinal counter -# (not message.ordinal) and guards against None timestamp. -# --------------------------------------------------------------------------- - - class TestGetTextRangeForDateRange: """Tests for the ordinal counter fix and timestamp None guard.""" @@ -768,11 +762,6 @@ async def test_single_message_in_range(self) -> None: assert result.end.message_ordinal == 1 -# --------------------------------------------------------------------------- -# Change 3: WhereSemanticRefExpr uses clone() (provenance copy) -# --------------------------------------------------------------------------- - - class TestWhereSemanticRefExprProvenance: """Verify that WhereSemanticRefExpr copies (not shares) search_term_matches.""" From 0bf1904a313863092d6b083e85de8c3ad15736b0 Mon Sep 17 00:00:00 2001 From: Bernhard Merkle Date: Tue, 3 Mar 2026 08:10:03 +0100 Subject: [PATCH 48/49] - small refactoring - removed clone --- src/typeagent/knowpro/collections.py | 9 +-------- src/typeagent/knowpro/query.py | 2 +- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/src/typeagent/knowpro/collections.py b/src/typeagent/knowpro/collections.py index 97ccf384..a2716577 100644 --- a/src/typeagent/knowpro/collections.py +++ b/src/typeagent/knowpro/collections.py @@ -260,12 +260,6 @@ def __init__(self, search_term_matches: set[str] | None = None): search_term_matches if search_term_matches is not None else set() ) - def clone(self) -> "SemanticRefAccumulator": - """Create a new empty accumulator inheriting a copy of this one's term-match provenance.""" - acc = self.__class__() - acc.search_term_matches = set(self.search_term_matches) - return acc - def add_term_matches( self, search_term: Term, @@ -342,8 +336,7 @@ async def group_matches_by_type( semantic_ref = await semantic_refs.get_item(match.value) group = groups.get(semantic_ref.knowledge.knowledge_type) if group is None: - group = SemanticRefAccumulator() - group.search_term_matches = self.search_term_matches + group = SemanticRefAccumulator(self.search_term_matches) groups[semantic_ref.knowledge.knowledge_type] = group group.set_match(match) return groups diff --git a/src/typeagent/knowpro/query.py b/src/typeagent/knowpro/query.py index b0967a6e..c30fb531 100644 --- a/src/typeagent/knowpro/query.py +++ b/src/typeagent/knowpro/query.py @@ -700,7 +700,7 @@ class WhereSemanticRefExpr(QueryOpExpr[SemanticRefAccumulator]): async def eval(self, context: QueryEvalContext) -> SemanticRefAccumulator: accumulator = await self.source_expr.eval(context) - filtered = accumulator.clone() + filtered = SemanticRefAccumulator(self.search_term_matches) # Filter matches asynchronously filtered_matches = [] From 93828915cdbdd011297a2764d54d22d2fcd1c14f Mon Sep 17 00:00:00 2001 From: 
Bernhard Merkle
Date: Tue, 3 Mar 2026 14:24:56 +0100
Subject: [PATCH 49/49] fix filtered to copy the source accumulator's provenance

---
 src/typeagent/knowpro/query.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/typeagent/knowpro/query.py b/src/typeagent/knowpro/query.py
index c30fb531..0bedf958 100644
--- a/src/typeagent/knowpro/query.py
+++ b/src/typeagent/knowpro/query.py
@@ -700,7 +700,7 @@ class WhereSemanticRefExpr(QueryOpExpr[SemanticRefAccumulator]):

     async def eval(self, context: QueryEvalContext) -> SemanticRefAccumulator:
         accumulator = await self.source_expr.eval(context)
-        filtered = SemanticRefAccumulator(self.search_term_matches)
+        filtered = SemanticRefAccumulator(set(accumulator.search_term_matches))

         # Filter matches asynchronously
         filtered_matches = []
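
A closing note on the SemanticRefAccumulator churn in PATCH 44 through PATCH 49: it comes down to two Python aliasing pitfalls. A mutable default argument is evaluated once, when the def statement runs, so `__init__(self, search_term_matches: set[str] = set())` shares a single set across every accumulator constructed without an argument (the PATCH 46 fix defaults to None and allocates per instance). Even with that fix, passing an existing set to the constructor aliases it rather than copying it, which is why PATCH 49 wraps the argument in set(...). A minimal self-contained sketch follows; the two classes are illustrative stand-ins reduced to just the provenance set, not code from any commit above:

class SharedDefaultAccumulator:
    # BUG (pre-PATCH 46): the default set() is created once, when the def
    # statement executes, and is then shared by every instance constructed
    # without an explicit argument.
    def __init__(self, search_term_matches: set[str] = set()):
        self.search_term_matches = search_term_matches


class PerInstanceAccumulator:
    # PATCH 46 fix: default to None and allocate a fresh set per instance.
    def __init__(self, search_term_matches: set[str] | None = None):
        self.search_term_matches = (
            search_term_matches if search_term_matches is not None else set()
        )


a, b = SharedDefaultAccumulator(), SharedDefaultAccumulator()
a.search_term_matches.add("novel")
assert "novel" in b.search_term_matches  # b silently observes a's mutation

c, d = PerInstanceAccumulator(), PerInstanceAccumulator()
c.search_term_matches.add("novel")
assert "novel" not in d.search_term_matches  # instances are independent

# Even with the fix, passing an existing set aliases it; PATCH 49 copies
# explicitly so the derived accumulator owns its provenance.
source = PerInstanceAccumulator({"seed"})
aliased = PerInstanceAccumulator(source.search_term_matches)  # shares the set
derived = PerInstanceAccumulator(set(source.search_term_matches))  # owns a copy
derived.search_term_matches.add("extra")
aliased.search_term_matches.add("oops")
assert "extra" not in source.search_term_matches  # the copy left source intact
assert "oops" in source.search_term_matches  # the alias mutated source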