diff --git a/integrations/openai/pyproject.toml b/integrations/openai/pyproject.toml index b3eea7f83..52893cc14 100644 --- a/integrations/openai/pyproject.toml +++ b/integrations/openai/pyproject.toml @@ -20,6 +20,11 @@ dependencies = [ "openai-agents>=0.5.0" ] +[project.optional-dependencies] +memory = [ + "databricks-ai-bridge[memory]>=0.12.0", +] + [dependency-groups] dev = [ "typing_extensions>=4.15.0", diff --git a/integrations/openai/src/databricks_openai/agents/session.py b/integrations/openai/src/databricks_openai/agents/session.py new file mode 100644 index 000000000..3c6e7549b --- /dev/null +++ b/integrations/openai/src/databricks_openai/agents/session.py @@ -0,0 +1,647 @@ +from __future__ import annotations + +import asyncio +import json +import logging +from datetime import datetime, timezone +from threading import Lock +from typing import Any, Dict, Optional, Tuple, Union, cast +from uuid import UUID + +from databricks.sdk import WorkspaceClient + +try: + from agents.items import TResponseInputItem + from agents.memory.session import SessionABC + from databricks_ai_bridge.lakebase import AsyncLakebasePool, LakebasePool + from psycopg import sql + from psycopg.sql import Composed +except ImportError as e: + raise ImportError( + "MemorySession requires databricks-openai[memory]. " + "Please install with: pip install databricks-openai[memory]" + ) from e + +logger = logging.getLogger(__name__) + +# Module-level pool cache: instance_name -> LakebasePool +_pool_cache: Dict[str, LakebasePool] = {} +_pool_cache_lock = Lock() + +# Module-level async pool cache: instance_name -> AsyncLakebasePool +_async_pool_cache: Dict[str, AsyncLakebasePool] = {} +_async_pool_cache_lock = asyncio.Lock() + + +def _get_or_create_pool( + instance_name: str, + workspace_client: Optional[WorkspaceClient] = None, + **pool_kwargs, +) -> LakebasePool: + """Get cached pool or create new one for this instance.""" + cache_key = instance_name + + with _pool_cache_lock: + if cache_key not in _pool_cache: + logger.info(f"Creating new LakebasePool for {cache_key}") + _pool_cache[cache_key] = LakebasePool( + instance_name=instance_name, + workspace_client=workspace_client, + **pool_kwargs, + ) + return _pool_cache[cache_key] + + +async def _get_or_create_async_pool( + instance_name: str, + workspace_client: Optional[WorkspaceClient] = None, + **pool_kwargs, +) -> AsyncLakebasePool: + """Get cached async pool or create new one for this instance.""" + cache_key = instance_name + + async with _async_pool_cache_lock: + if cache_key not in _async_pool_cache: + logger.info(f"Creating new AsyncLakebasePool for {cache_key}") + pool = AsyncLakebasePool( + instance_name=instance_name, + workspace_client=workspace_client, + **pool_kwargs, + ) + await pool.open() + _async_pool_cache[cache_key] = pool + return _async_pool_cache[cache_key] + + +class _MemorySessionBase(SessionABC): + """ + Base class with shared SQL, configuration, and helper methods for memory sessions. + + Subclasses implement sync or async pool initialization and database operations. + """ + + # Table names + SESSIONS_TABLE = "agent_sessions" + MESSAGES_TABLE = "agent_messages" + + CREATE_SESSIONS_TABLE_SQL = """ + CREATE TABLE IF NOT EXISTS {sessions_table} ( + session_id UUID PRIMARY KEY, + created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP + ); + """ + + CREATE_MESSAGES_TABLE_SQL = """ + CREATE TABLE IF NOT EXISTS {messages_table} ( + id BIGSERIAL PRIMARY KEY, + session_id UUID NOT NULL REFERENCES {sessions_table}(session_id) ON DELETE CASCADE, + message_data JSONB NOT NULL, + created_at TIMESTAMPTZ DEFAULT CURRENT_TIMESTAMP + ); + CREATE INDEX IF NOT EXISTS {idx_session_id} + ON {messages_table}(session_id); + CREATE INDEX IF NOT EXISTS {idx_session_order} + ON {messages_table}(session_id, id); + """ + + session_id: UUID + + def __init__( + self, + session_id: UUID, + *, + sessions_table: str = SESSIONS_TABLE, + messages_table: str = MESSAGES_TABLE, + ) -> None: + """ + Initialize base session attributes. + + Args: + session_id: UUID identifier for this conversation session. + sessions_table: Name of the sessions table. Defaults to "agent_sessions". + messages_table: Name of the messages table. Defaults to "agent_messages". + """ + self.session_id = session_id + self.sessions_table = sessions_table + self.messages_table = messages_table + + # --- SQL Building Helpers --- + + def _build_create_sessions_sql(self) -> Composed: + """Build SQL to create the sessions table.""" + return sql.SQL(self.CREATE_SESSIONS_TABLE_SQL).format( + sessions_table=sql.Identifier(self.sessions_table) + ) + + def _build_create_messages_sql(self) -> Composed: + """Build SQL to create the messages table.""" + return sql.SQL(self.CREATE_MESSAGES_TABLE_SQL).format( + sessions_table=sql.Identifier(self.sessions_table), + messages_table=sql.Identifier(self.messages_table), + idx_session_id=sql.Identifier(f"idx_{self.messages_table}_session_id"), + idx_session_order=sql.Identifier(f"idx_{self.messages_table}_session_order"), + ) + + def _build_ensure_session_sql(self) -> Composed: + """Build SQL to insert session record if not exists.""" + return sql.SQL( + """ + INSERT INTO {} (session_id, created_at, updated_at) + VALUES (%s, %s, %s) + ON CONFLICT (session_id) DO NOTHING + """ + ).format(sql.Identifier(self.sessions_table)) + + def _build_get_items_query( + self, limit: int | None + ) -> Tuple[Composed, Tuple[UUID, ...] | Tuple[UUID, int]]: + """Build SQL query and params to get items.""" + if limit is not None: + query = sql.SQL( + """ + SELECT message_data FROM ( + SELECT message_data, id + FROM {} + WHERE session_id = %s + ORDER BY id DESC + LIMIT %s + ) sub + ORDER BY id ASC + """ + ).format(sql.Identifier(self.messages_table)) + params: Tuple[UUID, ...] | Tuple[UUID, int] = (self.session_id, limit) + else: + query = sql.SQL( + """ + SELECT message_data + FROM {} + WHERE session_id = %s + ORDER BY id ASC + """ + ).format(sql.Identifier(self.messages_table)) + params = (self.session_id,) + return query, params + + def _build_add_items_sql(self) -> Composed: + """Build SQL to insert message items.""" + return sql.SQL( + """ + INSERT INTO {} (session_id, message_data) + VALUES (%s, %s) + """ + ).format(sql.Identifier(self.messages_table)) + + def _build_update_session_timestamp_sql(self) -> Composed: + """Build SQL to update session timestamp.""" + return sql.SQL( + """ + UPDATE {} + SET updated_at = CURRENT_TIMESTAMP + WHERE session_id = %s + """ + ).format(sql.Identifier(self.sessions_table)) + + def _build_update_session_timestamp_with_value_sql(self) -> Composed: + """Build SQL to update session timestamp with explicit value.""" + return sql.SQL( + """ + UPDATE {} + SET updated_at = %s + WHERE session_id = %s + """ + ).format(sql.Identifier(self.sessions_table)) + + def _build_pop_item_sql(self) -> Composed: + """Build SQL to delete and return most recent item.""" + messages_table_id = sql.Identifier(self.messages_table) + return sql.SQL( + """ + DELETE FROM {messages_table} + WHERE id = ( + SELECT id + FROM {messages_table} + WHERE session_id = %s + ORDER BY id DESC + LIMIT 1 + ) + RETURNING message_data + """ + ).format(messages_table=messages_table_id) + + def _build_clear_session_sql(self) -> Composed: + """Build SQL to delete all messages for session.""" + return sql.SQL("DELETE FROM {} WHERE session_id = %s").format( + sql.Identifier(self.messages_table) + ) + + def _prepare_items_for_insert(self, items: list[TResponseInputItem]) -> list[Tuple[UUID, str]]: + """Prepare items for database insertion.""" + return [(self.session_id, json.dumps(item)) for item in items] + + def _parse_message_data(self, message_data: Union[str, dict[str, Any]]) -> TResponseInputItem: + """Parse message_data from database (may be JSON string or dict).""" + if isinstance(message_data, str): + return cast(TResponseInputItem, json.loads(message_data)) + return cast(TResponseInputItem, message_data) + + def _parse_rows_to_items(self, rows: list) -> list[TResponseInputItem]: + """Parse database rows to list of items.""" + return [self._parse_message_data(row["message_data"]) for row in rows] + + +class MemorySession(_MemorySessionBase): + """ + OpenAI Agents SDK Session implementation using Lakebase for persistent storage. + + This class follows the Session protocol for conversation memory, + storing session data in two tables: + - agent_sessions: Tracks session metadata (session_id, created_at, updated_at) + - agent_messages: Stores conversation items (id, session_id, message_data, created_at) + + SessionABC: https://openai.github.io/openai-agents-python/ref/memory/session/#agents.memory.session.SessionABC + + Example: + ```python + from uuid import UUID + from databricks_openai.agents.session import MemorySession + from agents import Agent, Runner + + async def run_agent(thread_id: UUID | None, message: str): + # Use uuid7 for time-ordered UUIDs (better for database indexing) + session_id = thread_id + session = MemorySession( + session_id=session_id, + instance_name="my-lakebase-instance" + ) + agent = Agent(name="Assistant") + return await Runner.run(agent, message, session=session) + ``` + """ + + def __init__( + self, + session_id: UUID, + *, + instance_name: str, + workspace_client: Optional[WorkspaceClient] = None, + sessions_table: str = _MemorySessionBase.SESSIONS_TABLE, + messages_table: str = _MemorySessionBase.MESSAGES_TABLE, + **pool_kwargs, + ) -> None: + """ + Initialize a MemorySession. + + On first initialization for a given Lakebase instance, this will automatically + create the required tables if they don't exist. + + Args: + session_id: UUID identifier for this conversation session. + instance_name: Name of the Lakebase instance. + workspace_client: Optional WorkspaceClient for authentication. + sessions_table: Name of the sessions table. Defaults to "agent_sessions". + messages_table: Name of the messages table. Defaults to "agent_messages". + **pool_kwargs: Additional arguments passed to LakebasePool. + """ + super().__init__( + session_id=session_id, + sessions_table=sessions_table, + messages_table=messages_table, + ) + + self._pool = _get_or_create_pool( + instance_name=instance_name, + workspace_client=workspace_client, + **pool_kwargs, + ) + + if not self._tables_exist(): + self._create_tables() + + self._ensure_session() + + def _tables_exist(self) -> bool: + """Check if both session tables already exist.""" + with self._pool.connection() as conn: + result = conn.execute( + """ + SELECT COUNT(*) as cnt FROM information_schema.tables + WHERE table_schema = current_schema() + AND table_name IN (%s, %s) + """, + (self.sessions_table, self.messages_table), + ) + row = result.fetchone() + return row["cnt"] == 2 + + def _create_tables(self) -> None: + """Create the required tables.""" + with self._pool.connection() as conn: + conn.execute(self._build_create_sessions_sql()) + conn.execute(self._build_create_messages_sql()) + logger.info(f"Created tables {self.sessions_table}, {self.messages_table}") + + def _ensure_session(self) -> None: + """Ensure the session record exists in agent_sessions table.""" + now = datetime.now(timezone.utc) + with self._pool.connection() as conn: + conn.execute( + self._build_ensure_session_sql(), + (self.session_id, now, now), + ) + logger.debug(f"Ensured session {self.session_id} exists") + + async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + """ + Retrieve the conversation history for this session. + + Args: + limit: Maximum number of items to retrieve. If None, retrieves all items. + When specified, returns the latest N items in chronological order. + + Returns: + List of input items representing the conversation history. + """ + query, params = self._build_get_items_query(limit) + with self._pool.connection() as conn: + result = conn.execute(query, params) + rows = result.fetchall() + return self._parse_rows_to_items(rows) + + async def add_items(self, items: list[TResponseInputItem]) -> None: + """ + Add new items to the conversation history. + + Args: + items: List of input items to add to the history. + """ + if not items: + return + + with self._pool.connection() as conn: + with conn.cursor() as cur: + cur.executemany( + self._build_add_items_sql(), + self._prepare_items_for_insert(items), + ) + conn.execute( + self._build_update_session_timestamp_sql(), + (self.session_id,), + ) + logger.debug(f"Added {len(items)} items to session {self.session_id}") + + async def pop_item(self) -> TResponseInputItem | None: + """ + Remove and return the most recent item from the session. + + Returns: + The most recent item if it exists, None if the session is empty. + """ + with self._pool.connection() as conn: + result = conn.execute( + self._build_pop_item_sql(), + (self.session_id,), + ) + row = result.fetchone() + + if row: + now = datetime.now(timezone.utc) + conn.execute( + self._build_update_session_timestamp_with_value_sql(), + (now, self.session_id), + ) + + if row: + logger.debug(f"Popped item from session {self.session_id}") + return self._parse_message_data(row["message_data"]) + return None + + async def clear_session(self) -> None: + """Clear all items for this session.""" + with self._pool.connection() as conn: + result = conn.execute( + self._build_clear_session_sql(), + (self.session_id,), + ) + count = result.rowcount + + now = datetime.now(timezone.utc) + conn.execute( + self._build_update_session_timestamp_with_value_sql(), + (now, self.session_id), + ) + logger.info(f"Cleared {count} items from session {self.session_id}") + + +class AsyncMemorySession(_MemorySessionBase): + """ + OpenAI Agents SDK Session implementation using Lakebase for persistent storage (async version). + + This class follows the Session protocol for conversation memory, + storing session data in two tables: + - agent_sessions: Tracks session metadata (session_id, created_at, updated_at) + - agent_messages: Stores conversation items (id, session_id, message_data, created_at) + + SessionABC: https://openai.github.io/openai-agents-python/ref/memory/session/#agents.memory.session.SessionABC + + Example: + ```python + from uuid import UUID + from databricks_openai.agents.session import AsyncMemorySession + from agents import Agent, Runner + + async def run_agent(thread_id: UUID | None, message: str): + session_id = thread_id + session = AsyncMemorySession( + session_id=session_id, + instance_name="my-lakebase-instance" + ) + agent = Agent(name="Assistant") + return await Runner.run(agent, message, session=session) + ``` + """ + + def __init__( + self, + session_id: UUID, + *, + instance_name: str, + workspace_client: Optional[WorkspaceClient] = None, + sessions_table: str = _MemorySessionBase.SESSIONS_TABLE, + messages_table: str = _MemorySessionBase.MESSAGES_TABLE, + **pool_kwargs, + ) -> None: + """ + Initialize an AsyncMemorySession. + + Note: The async pool and tables are initialized lazily on first use. + + Args: + session_id: UUID identifier for this conversation session. + instance_name: Name of the Lakebase instance. + workspace_client: Optional WorkspaceClient for authentication. + sessions_table: Name of the sessions table. Defaults to "agent_sessions". + messages_table: Name of the messages table. Defaults to "agent_messages". + **pool_kwargs: Additional arguments passed to AsyncLakebasePool. + """ + super().__init__( + session_id=session_id, + sessions_table=sessions_table, + messages_table=messages_table, + ) + + self._instance_name = instance_name + self._workspace_client = workspace_client + self._pool_kwargs = pool_kwargs + + self._pool: Optional[AsyncLakebasePool] = None + self._initialized = False + self._init_lock = asyncio.Lock() + + async def _ensure_initialized(self) -> None: + """Ensure the pool is created and tables exist (lazy initialization).""" + if self._initialized: + return + + async with self._init_lock: + if self._initialized: + return + + self._pool = await _get_or_create_async_pool( + instance_name=self._instance_name, + workspace_client=self._workspace_client, + **self._pool_kwargs, + ) + + if not await self._tables_exist(): + await self._create_tables() + + await self._ensure_session() + self._initialized = True + + async def _tables_exist(self) -> bool: + """Check if both session tables already exist.""" + assert self._pool is not None + async with self._pool.connection() as conn: + result = await conn.execute( + """ + SELECT COUNT(*) as cnt FROM information_schema.tables + WHERE table_schema = current_schema() + AND table_name IN (%s, %s) + """, + (self.sessions_table, self.messages_table), + ) + row = await result.fetchone() + return row["cnt"] == 2 + + async def _create_tables(self) -> None: + """Create the required tables.""" + assert self._pool is not None + async with self._pool.connection() as conn: + await conn.execute(self._build_create_sessions_sql()) + await conn.execute(self._build_create_messages_sql()) + logger.info(f"Created tables {self.sessions_table}, {self.messages_table}") + + async def _ensure_session(self) -> None: + """Ensure the session record exists in agent_sessions table.""" + assert self._pool is not None + now = datetime.now(timezone.utc) + async with self._pool.connection() as conn: + await conn.execute( + self._build_ensure_session_sql(), + (self.session_id, now, now), + ) + logger.debug(f"Ensured session {self.session_id} exists") + + async def get_items(self, limit: int | None = None) -> list[TResponseInputItem]: + """ + Retrieve the conversation history for this session. + + Args: + limit: Maximum number of items to retrieve. If None, retrieves all items. + When specified, returns the latest N items in chronological order. + + Returns: + List of input items representing the conversation history. + """ + await self._ensure_initialized() + assert self._pool is not None + + query, params = self._build_get_items_query(limit) + async with self._pool.connection() as conn: + result = await conn.execute(query, params) + rows = await result.fetchall() + return self._parse_rows_to_items(rows) + + async def add_items(self, items: list[TResponseInputItem]) -> None: + """ + Add new items to the conversation history. + + Args: + items: List of input items to add to the history. + """ + if not items: + return + + await self._ensure_initialized() + assert self._pool is not None + + async with self._pool.connection() as conn: + async with conn.cursor() as cur: + await cur.executemany( + self._build_add_items_sql(), + self._prepare_items_for_insert(items), + ) + await conn.execute( + self._build_update_session_timestamp_sql(), + (self.session_id,), + ) + logger.debug(f"Added {len(items)} items to session {self.session_id}") + + async def pop_item(self) -> TResponseInputItem | None: + """ + Remove and return the most recent item from the session. + + Returns: + The most recent item if it exists, None if the session is empty. + """ + await self._ensure_initialized() + assert self._pool is not None + + async with self._pool.connection() as conn: + result = await conn.execute( + self._build_pop_item_sql(), + (self.session_id,), + ) + row = await result.fetchone() + + if row: + now = datetime.now(timezone.utc) + await conn.execute( + self._build_update_session_timestamp_with_value_sql(), + (now, self.session_id), + ) + + if row: + logger.debug(f"Popped item from session {self.session_id}") + return self._parse_message_data(row["message_data"]) + return None + + async def clear_session(self) -> None: + """Clear all items for this session.""" + await self._ensure_initialized() + assert self._pool is not None + + async with self._pool.connection() as conn: + result = await conn.execute( + self._build_clear_session_sql(), + (self.session_id,), + ) + count = result.rowcount + + now = datetime.now(timezone.utc) + await conn.execute( + self._build_update_session_timestamp_with_value_sql(), + (now, self.session_id), + ) + logger.info(f"Cleared {count} items from session {self.session_id}") diff --git a/integrations/openai/src/databricks_openai/utils/clients.py b/integrations/openai/src/databricks_openai/utils/clients.py index 99bb9d08c..fca933824 100644 --- a/integrations/openai/src/databricks_openai/utils/clients.py +++ b/integrations/openai/src/databricks_openai/utils/clients.py @@ -18,14 +18,15 @@ def auth_flow(self, request: Request) -> Generator[Request, Response, None]: yield request -def _strip_strict_from_tools(tools: list | None) -> list | None: +def _strip_strict_from_tools(tools) -> list | None: """Remove 'strict' field from tool function definitions. Databricks model endpoints (except GPT) don't support the 'strict' field in tool schemas, but openai-agents SDK v0.6.4+ includes it. """ - if tools is None: - return None + # Handle None or OpenAI's Omit sentinel (non-iterable placeholder) + if tools is None or not isinstance(tools, list): + return tools for tool in tools: if isinstance(tool, dict) and "function" in tool: tool.get("function", {}).pop("strict", None) diff --git a/integrations/openai/tests/integration_tests/test_memory_session.py b/integrations/openai/tests/integration_tests/test_memory_session.py new file mode 100644 index 000000000..183266d5c --- /dev/null +++ b/integrations/openai/tests/integration_tests/test_memory_session.py @@ -0,0 +1,413 @@ +"""Integration tests for MemorySession and AsyncMemorySession. + +These tests require: +1. A Lakebase instance to be available +2. Valid Databricks authentication (DATABRICKS_HOST + DATABRICKS_TOKEN as env variables) + +Set the environment variable: + LAKEBASE_INSTANCE_NAME: Name or hostname of the Lakebase instance + +Example: + LAKEBASE_INSTANCE_NAME=lakebase pytest tests/integration_tests/test_memory_session.py -v +""" + +from __future__ import annotations + +import asyncio +import os +import uuid +from typing import Any, cast + +import pytest + +# Skip all tests if LAKEBASE_INSTANCE_NAME is not set +pytestmark = pytest.mark.skipif( + not os.environ.get("LAKEBASE_INSTANCE_NAME"), + reason="LAKEBASE_INSTANCE_NAME environment variable not set", +) + + +def get_instance_name() -> str: + """Get the Lakebase instance name from environment.""" + return os.environ["LAKEBASE_INSTANCE_NAME"] + + +def get_unique_table_names() -> tuple[str, str]: + """Generate unique table names for test isolation.""" + suffix = uuid.uuid4().hex[:8] + return f"test_sessions_{suffix}", f"test_messages_{suffix}" + + +@pytest.fixture(scope="session", autouse=True) +def cleanup_pool_cache(): + """Session-scoped fixture to close cached pools after all tests complete.""" + yield + + # Close sync pool cache + from databricks_openai.agents import session as session_module + + for pool in session_module._pool_cache.values(): + try: + pool.close() + except Exception: + pass + session_module._pool_cache.clear() + + # Close async pool cache - need to handle event loop carefully + for pool in list(session_module._async_pool_cache.values()): + try: + # Access the underlying pool and close it synchronously if possible + # The pool's _pool attribute is the actual AsyncConnectionPool + if hasattr(pool, "_pool") and pool._pool is not None: + # Use wait=False to avoid blocking on workers + pool._pool.close(timeout=0) + except Exception: + pass + session_module._async_pool_cache.clear() + + +@pytest.fixture +def cleanup_tables(): + """Fixture to track and clean up test tables after tests.""" + tables_to_cleanup: list[tuple[str, str]] = [] + + yield tables_to_cleanup + + # Cleanup after test + if tables_to_cleanup: + from databricks_ai_bridge.lakebase import LakebasePool + + pool = LakebasePool(instance_name=get_instance_name()) + with pool.connection() as conn: + for sessions_table, messages_table in tables_to_cleanup: + # Drop messages first (foreign key constraint) + conn.execute(f"DROP TABLE IF EXISTS {messages_table}") + conn.execute(f"DROP TABLE IF EXISTS {sessions_table}") + + +# ============================================================================= +# Sync MemorySession Tests +# ============================================================================= + + +def test_memory_session_crud_operations(cleanup_tables): + """ + Comprehensive CRUD test for sync MemorySession. + + Tests the full lifecycle: + - clear_session() on fresh session + - get_items() returns empty list for new session + - add_items() stores messages + - get_items() retrieves stored messages + - get_items(limit=N) returns latest N items in order + - pop_item() removes and returns most recent item + - clear_session() removes all items + """ + from databricks_openai.agents.session import MemorySession + + sessions_table, messages_table = get_unique_table_names() + cleanup_tables.append((sessions_table, messages_table)) + + session_id = uuid.uuid4() + session = MemorySession( + session_id=session_id, + instance_name=get_instance_name(), + sessions_table=sessions_table, + messages_table=messages_table, + ) + + # Clear any existing data (should be no-op for new session) + asyncio.run(session.clear_session()) + + # Test get_items on empty session + items = cast(list[Any], asyncio.run(session.get_items())) + assert items == [], f"Expected empty list, got {items}" + + # Test add_items + test_items: list[Any] = [ + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I'm doing well, thank you!"}, + ] + asyncio.run(session.add_items(test_items)) + + # Test get_items returns what we added + items = cast(list[Any], asyncio.run(session.get_items())) + assert len(items) == 2, f"Expected 2 items, got {len(items)}" + assert items[0]["role"] == "user" + assert items[0]["content"] == "Hello, how are you?" + assert items[1]["role"] == "assistant" + assert items[1]["content"] == "I'm doing well, thank you!" + + # Test get_items with limit - should return latest N items in chronological order + items = cast(list[Any], asyncio.run(session.get_items(limit=1))) + assert len(items) == 1, f"Expected 1 item with limit, got {len(items)}" + assert items[0]["role"] == "assistant" # Latest item + + # Test pop_item - removes and returns the last item + popped = cast(Any, asyncio.run(session.pop_item())) + assert popped is not None + assert popped["role"] == "assistant" # Should be the last item + + # Verify only 1 item remains + items = cast(list[Any], asyncio.run(session.get_items())) + assert len(items) == 1, f"Expected 1 item after pop, got {len(items)}" + assert items[0]["role"] == "user" + + # Test clear_session + asyncio.run(session.clear_session()) + items = cast(list[Any], asyncio.run(session.get_items())) + assert items == [], f"Expected empty after clear, got {items}" + + +def test_memory_session_multiple_sessions_isolated(cleanup_tables): + """Test that different session_ids have isolated data.""" + from databricks_openai.agents.session import MemorySession + + sessions_table, messages_table = get_unique_table_names() + cleanup_tables.append((sessions_table, messages_table)) + + session_id_1 = uuid.uuid4() + session_id_2 = uuid.uuid4() + + session_1 = MemorySession( + session_id=session_id_1, + instance_name=get_instance_name(), + sessions_table=sessions_table, + messages_table=messages_table, + ) + + session_2 = MemorySession( + session_id=session_id_2, + instance_name=get_instance_name(), + sessions_table=sessions_table, + messages_table=messages_table, + ) + + # Add different items to each session + items_1_data: list[Any] = [{"role": "user", "content": "Session 1 message"}] + items_2_data: list[Any] = [{"role": "user", "content": "Session 2 message"}] + asyncio.run(session_1.add_items(items_1_data)) + asyncio.run(session_2.add_items(items_2_data)) + + # Verify isolation + items_1 = cast(list[Any], asyncio.run(session_1.get_items())) + items_2 = cast(list[Any], asyncio.run(session_2.get_items())) + + assert len(items_1) == 1 + assert len(items_2) == 1 + assert items_1[0]["content"] == "Session 1 message" + assert items_2[0]["content"] == "Session 2 message" + + # Clear one session shouldn't affect the other + asyncio.run(session_1.clear_session()) + items_1 = cast(list[Any], asyncio.run(session_1.get_items())) + items_2 = cast(list[Any], asyncio.run(session_2.get_items())) + assert len(items_1) == 0 + assert len(items_2) == 1 + + +def test_memory_session_pop_empty_returns_none(cleanup_tables): + """Test that pop_item returns None on empty session.""" + from databricks_openai.agents.session import MemorySession + + sessions_table, messages_table = get_unique_table_names() + cleanup_tables.append((sessions_table, messages_table)) + + session = MemorySession( + session_id=uuid.uuid4(), + instance_name=get_instance_name(), + sessions_table=sessions_table, + messages_table=messages_table, + ) + + # Pop on empty session should return None + popped = asyncio.run(session.pop_item()) + assert popped is None + + +def test_memory_session_add_empty_items_noop(cleanup_tables): + """Test that add_items with empty list is a no-op.""" + from databricks_openai.agents.session import MemorySession + + sessions_table, messages_table = get_unique_table_names() + cleanup_tables.append((sessions_table, messages_table)) + + session = MemorySession( + session_id=uuid.uuid4(), + instance_name=get_instance_name(), + sessions_table=sessions_table, + messages_table=messages_table, + ) + + # Add empty list - should not raise + asyncio.run(session.add_items([])) + + # Session should still be empty + items = cast(list[Any], asyncio.run(session.get_items())) + assert items == [] + + +# ============================================================================= +# Async AsyncMemorySession Tests +# ============================================================================= + + +@pytest.mark.asyncio +async def test_async_memory_session_crud_operations(cleanup_tables): + """ + Comprehensive CRUD test for AsyncMemorySession. + + Tests the full lifecycle: + - clear_session() on fresh session + - get_items() returns empty list for new session + - add_items() stores messages + - get_items() retrieves stored messages + - get_items(limit=N) returns latest N items in order + - pop_item() removes and returns most recent item + - clear_session() removes all items + """ + from databricks_openai.agents.session import AsyncMemorySession + + sessions_table, messages_table = get_unique_table_names() + cleanup_tables.append((sessions_table, messages_table)) + + session_id = uuid.uuid4() + session = AsyncMemorySession( + session_id=session_id, + instance_name=get_instance_name(), + sessions_table=sessions_table, + messages_table=messages_table, + ) + + # Clear any existing data (should be no-op for new session) + await session.clear_session() + + # Test get_items on empty session + items = cast(list[Any], await session.get_items()) + assert items == [], f"Expected empty list, got {items}" + + # Test add_items + test_items: list[Any] = [ + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I'm doing well, thank you!"}, + ] + await session.add_items(test_items) + + # Test get_items returns what we added + items = cast(list[Any], await session.get_items()) + assert len(items) == 2, f"Expected 2 items, got {len(items)}" + assert items[0]["role"] == "user" + assert items[0]["content"] == "Hello, how are you?" + assert items[1]["role"] == "assistant" + assert items[1]["content"] == "I'm doing well, thank you!" + + # Test get_items with limit - should return latest N items in chronological order + items = cast(list[Any], await session.get_items(limit=1)) + assert len(items) == 1, f"Expected 1 item with limit, got {len(items)}" + assert items[0]["role"] == "assistant" # Latest item + + # Test pop_item - removes and returns the last item + popped = cast(Any, await session.pop_item()) + assert popped is not None + assert popped["role"] == "assistant" # Should be the last item + + # Verify only 1 item remains + items = cast(list[Any], await session.get_items()) + assert len(items) == 1, f"Expected 1 item after pop, got {len(items)}" + assert items[0]["role"] == "user" + + # Test clear_session + await session.clear_session() + items = cast(list[Any], await session.get_items()) + assert items == [], f"Expected empty after clear, got {items}" + + +@pytest.mark.asyncio +async def test_async_memory_session_multiple_sessions_isolated(cleanup_tables): + """Test that different session_ids have isolated data (async version).""" + from databricks_openai.agents.session import AsyncMemorySession + + sessions_table, messages_table = get_unique_table_names() + cleanup_tables.append((sessions_table, messages_table)) + + session_id_1 = uuid.uuid4() + session_id_2 = uuid.uuid4() + + session_1 = AsyncMemorySession( + session_id=session_id_1, + instance_name=get_instance_name(), + sessions_table=sessions_table, + messages_table=messages_table, + ) + + session_2 = AsyncMemorySession( + session_id=session_id_2, + instance_name=get_instance_name(), + sessions_table=sessions_table, + messages_table=messages_table, + ) + + # Add different items to each session + items_1_data: list[Any] = [{"role": "user", "content": "Async Session 1 message"}] + items_2_data: list[Any] = [{"role": "user", "content": "Async Session 2 message"}] + await session_1.add_items(items_1_data) + await session_2.add_items(items_2_data) + + # Verify isolation + items_1 = cast(list[Any], await session_1.get_items()) + items_2 = cast(list[Any], await session_2.get_items()) + + assert len(items_1) == 1 + assert len(items_2) == 1 + assert items_1[0]["content"] == "Async Session 1 message" + assert items_2[0]["content"] == "Async Session 2 message" + + # Clear one session shouldn't affect the other + await session_1.clear_session() + items_1 = cast(list[Any], await session_1.get_items()) + items_2 = cast(list[Any], await session_2.get_items()) + assert len(items_1) == 0 + assert len(items_2) == 1 + + +@pytest.mark.asyncio +async def test_async_memory_session_pop_empty_returns_none(cleanup_tables): + """Test that pop_item returns None on empty session (async version).""" + from databricks_openai.agents.session import AsyncMemorySession + + sessions_table, messages_table = get_unique_table_names() + cleanup_tables.append((sessions_table, messages_table)) + + session = AsyncMemorySession( + session_id=uuid.uuid4(), + instance_name=get_instance_name(), + sessions_table=sessions_table, + messages_table=messages_table, + ) + + # Pop on empty session should return None + popped = await session.pop_item() + assert popped is None + + +@pytest.mark.asyncio +async def test_async_memory_session_add_empty_items_noop(cleanup_tables): + """Test that add_items with empty list is a no-op (async version).""" + from databricks_openai.agents.session import AsyncMemorySession + + sessions_table, messages_table = get_unique_table_names() + cleanup_tables.append((sessions_table, messages_table)) + + session = AsyncMemorySession( + session_id=uuid.uuid4(), + instance_name=get_instance_name(), + sessions_table=sessions_table, + messages_table=messages_table, + ) + + # Add empty list - should not raise + await session.add_items([]) + + # Session should still be empty + items = cast(list[Any], await session.get_items()) + assert items == [] diff --git a/integrations/openai/tests/unit_tests/test_clients.py b/integrations/openai/tests/unit_tests/test_clients.py index df44f276e..6f30d5b6a 100644 --- a/integrations/openai/tests/unit_tests/test_clients.py +++ b/integrations/openai/tests/unit_tests/test_clients.py @@ -126,6 +126,22 @@ def test_strip_strict_from_tools_handles_none(self): assert _strip_strict_from_tools(None) is None + def test_strip_strict_from_tools_handles_non_list(self): + """Test that non-list values (like OpenAI's Omit sentinel) are passed through.""" + from databricks_openai.utils.clients import _strip_strict_from_tools + + # Simulate OpenAI's Omit sentinel (a non-iterable placeholder) + class Omit: + pass + + omit_sentinel = Omit() + result = _strip_strict_from_tools(omit_sentinel) + assert result is omit_sentinel # Should return unchanged + + # Also test with other non-list types + assert _strip_strict_from_tools("not a list") == "not a list" + assert _strip_strict_from_tools(123) == 123 + def test_strip_strict_from_tools_handles_empty_list(self): from databricks_openai.utils.clients import _strip_strict_from_tools diff --git a/integrations/openai/tests/unit_tests/test_session.py b/integrations/openai/tests/unit_tests/test_session.py new file mode 100644 index 000000000..dede5e8fd --- /dev/null +++ b/integrations/openai/tests/unit_tests/test_session.py @@ -0,0 +1,1040 @@ +from __future__ import annotations + +import json +from typing import TYPE_CHECKING, Any, cast +from unittest.mock import MagicMock +from uuid import UUID + +import pytest + +pytest.importorskip("psycopg") +pytest.importorskip("psycopg_pool") +pytest.importorskip("agents.memory.session") + +from databricks_ai_bridge import lakebase + +if TYPE_CHECKING: + from psycopg import sql +else: + from psycopg import sql + +from databricks_openai.agents.session import ( + AsyncMemorySession, + MemorySession, + _async_pool_cache, + _pool_cache, +) + +# Use UUID (V7) for performance +TEST_SESSION_ID = UUID("12345678-1234-5678-1234-567812345678") +TEST_SESSION_ID_2 = UUID("22345678-1234-5678-1234-567812345678") +TEST_SESSION_ID_3 = UUID("32345678-1234-5678-1234-567812345678") + + +def query_to_string(query): + """Convert a query (string or sql.Composed) to a string for testing.""" + if isinstance(query, str): + return query + if isinstance(query, (sql.Composed, sql.SQL, sql.Identifier)): + return query.as_string(None) + return str(query) + + +class MockCursor: + """Mock cursor for executemany operations.""" + + def __init__(self): + self.executed_queries = [] + + def executemany(self, query, params): + self.executed_queries.append((query, params)) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + pass + + +class MockAsyncCursor: + """Mock async cursor for executemany operations.""" + + def __init__(self): + self.executed_queries = [] + + async def executemany(self, query, params): + self.executed_queries.append((query, params)) + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + pass + + +class MockResult: + """Mock result object for database queries.""" + + def __init__(self, rows=None, rowcount=0): + self._rows = rows or [] + self.rowcount = rowcount + self._index = 0 + + def fetchall(self): + return self._rows + + def fetchone(self): + if self._index < len(self._rows): + row = self._rows[self._index] + self._index += 1 + return row + return None + + +class MockAsyncResult: + """Mock async result object for database queries.""" + + def __init__(self, rows=None, rowcount=0): + self._rows = rows or [] + self.rowcount = rowcount + self._index = 0 + + async def fetchall(self): + return self._rows + + async def fetchone(self): + if self._index < len(self._rows): + row = self._rows[self._index] + self._index += 1 + return row + return None + + +class MockConnection: + """Mock database connection.""" + + def __init__(self): + self.executed_queries = [] + self._cursor = MockCursor() + self._next_result = MockResult() + self._results_queue = [] + + def execute(self, query, params=None): + self.executed_queries.append((query, params)) + if self._results_queue: + return self._results_queue.pop(0) + return self._next_result + + def cursor(self): + return self._cursor + + def set_next_result(self, result): + self._next_result = result + + def queue_result(self, result): + self._results_queue.append(result) + + +class MockAsyncConnection: + """Mock async database connection.""" + + def __init__(self): + self.executed_queries = [] + self._cursor = MockAsyncCursor() + self._next_result = MockAsyncResult() + self._results_queue = [] + + async def execute(self, query, params=None): + self.executed_queries.append((query, params)) + if self._results_queue: + return self._results_queue.pop(0) + return self._next_result + + def cursor(self): + return self._cursor + + def set_next_result(self, result): + self._next_result = result + + def queue_result(self, result): + self._results_queue.append(result) + + +class MockConnectionPool: + """Mock connection pool for testing.""" + + def __init__(self, connection_value=None): + self.connection_value = connection_value or MockConnection() + self.conninfo = "" + + def __call__(self, *, conninfo, connection_class=None, **kwargs): + self.conninfo = conninfo + return self + + def connection(self): + class _Ctx: + def __init__(self, outer): + self.outer = outer + + def __enter__(self): + return self.outer.connection_value + + def __exit__(self, exc_type, exc, tb): + pass + + return _Ctx(self) + + +class MockAsyncConnectionPool: + """Mock async connection pool for testing.""" + + def __init__(self, connection_value=None): + self.connection_value = connection_value or MockAsyncConnection() + self.conninfo = "" + self._opened = False + self._closed = False + + def __call__(self, *, conninfo, connection_class=None, **kwargs): + self.conninfo = conninfo + return self + + async def open(self): + self._opened = True + + async def close(self): + self._closed = True + + def connection(self): + class _AsyncCtx: + def __init__(self, outer): + self.outer = outer + + async def __aenter__(self): + return self.outer.connection_value + + async def __aexit__(self, exc_type, exc, tb): + pass + + return _AsyncCtx(self) + + +@pytest.fixture(autouse=True) +def clear_pool_cache(): + """Clear the pool cache before each test.""" + _pool_cache.clear() + _async_pool_cache.clear() + yield + _pool_cache.clear() + _async_pool_cache.clear() + + +@pytest.fixture +def mock_workspace(): + """Create a mock workspace client.""" + workspace = MagicMock() + workspace.database.generate_database_credential.return_value = MagicMock(token="stub-token") + workspace.database.get_database_instance.return_value.read_write_dns = "db-host" + workspace.current_service_principal.me.side_effect = RuntimeError("no sp") + workspace.current_user.me.return_value = MagicMock(user_name="test@databricks.com") + return workspace + + +@pytest.fixture +def mock_connection(): + """Create a mock connection.""" + return MockConnection() + + +@pytest.fixture +def mock_async_connection(): + """Create a mock async connection.""" + return MockAsyncConnection() + + +@pytest.fixture +def mock_pool(mock_connection): + """Create a mock connection pool.""" + return MockConnectionPool(connection_value=mock_connection) + + +@pytest.fixture +def mock_async_pool(mock_async_connection): + """Create a mock async connection pool.""" + return MockAsyncConnectionPool(connection_value=mock_async_connection) + + +# ============================================================================= +# MemorySession Tests (Sync) +# ============================================================================= + + +def test_session_configures_lakebase(monkeypatch, mock_workspace, mock_pool, mock_connection): + """Test that MemorySession correctly configures the Lakebase pool.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + # Mock: tables already exist + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) + mock_connection.queue_result(MockResult()) # INSERT session + + session = MemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-lakebase-instance", + workspace_client=mock_workspace, + ) + + assert ( + mock_pool.conninfo + == "dbname=databricks_postgres user=test@databricks.com host=db-host port=5432 sslmode=require" + ) + assert session.session_id == TEST_SESSION_ID + assert session.sessions_table == "agent_sessions" + assert session.messages_table == "agent_messages" + + +def test_session_creates_tables_on_init_when_not_exist( + monkeypatch, mock_workspace, mock_pool, mock_connection +): + """Test that MemorySession creates tables when they don't exist.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + # Mock: tables don't exist (count=0) + mock_connection.queue_result(MockResult(rows=[{"cnt": 0}])) + # CREATE sessions table + mock_connection.queue_result(MockResult()) + # CREATE messages table + mock_connection.queue_result(MockResult()) + # INSERT session + mock_connection.queue_result(MockResult()) + + MemorySession( + session_id=TEST_SESSION_ID_2, + instance_name="test-lakebase-instance", + workspace_client=mock_workspace, + ) + + # Should have executed CREATE TABLE statements + queries = [query_to_string(q) for q, _ in mock_connection.executed_queries] + create_sessions_found = any( + "CREATE TABLE IF NOT EXISTS" in q and "agent_sessions" in q for q in queries + ) + create_messages_found = any( + "CREATE TABLE IF NOT EXISTS" in q and "agent_messages" in q for q in queries + ) + + assert create_sessions_found, "Should create sessions table" + assert create_messages_found, "Should create messages table" + + +def test_session_skips_table_creation_when_tables_exist( + monkeypatch, mock_workspace, mock_pool, mock_connection +): + """Test that MemorySession skips table creation when tables already exist.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + # Mock: both tables exist (count=2) + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) + # INSERT session (no CREATE TABLE calls) + mock_connection.queue_result(MockResult()) + + MemorySession( + session_id=TEST_SESSION_ID_3, + instance_name="test-lakebase-instance", + workspace_client=mock_workspace, + ) + + # Should NOT have executed CREATE TABLE statements + queries = [query_to_string(q) for q, _ in mock_connection.executed_queries] + create_table_found = any("CREATE TABLE" in q for q in queries) + + assert not create_table_found, "Should not create tables when they already exist" + + +def test_session_ensures_session_record(monkeypatch, mock_workspace, mock_pool, mock_connection): + """Test that MemorySession ensures the session record exists.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + # Mock: tables already exist + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) + mock_connection.queue_result(MockResult()) # INSERT session + + MemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-lakebase-instance", + workspace_client=mock_workspace, + ) + + # Find the INSERT INTO agent_sessions query + insert_queries = [ + (q, p) + for q, p in mock_connection.executed_queries + if "INSERT INTO" in query_to_string(q) and "agent_sessions" in query_to_string(q) + ] + + assert len(insert_queries) > 0, "Should insert session record" + query, params = insert_queries[0] + assert params[0] == TEST_SESSION_ID, "Should use correct session_id" + + +@pytest.mark.asyncio +async def test_get_items_empty_session(monkeypatch, mock_workspace, mock_pool, mock_connection): + """Test get_items returns empty list for new session.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + # Mock: tables exist, then INSERT session, then SELECT messages + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) # tables exist check + mock_connection.queue_result(MockResult()) # INSERT session + mock_connection.queue_result(MockResult(rows=[])) # SELECT messages + + session = MemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + items = await session.get_items() + assert items == [] + + +@pytest.mark.asyncio +async def test_get_items_returns_parsed_json( + monkeypatch, mock_workspace, mock_pool, mock_connection +): + """Test get_items correctly parses JSON data.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + test_messages = [ + {"message_data": json.dumps({"role": "user", "content": "Hello"})}, + {"message_data": json.dumps({"role": "assistant", "content": "Hi there!"})}, + ] + + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) # tables exist check + mock_connection.queue_result(MockResult()) # INSERT session + mock_connection.queue_result(MockResult(rows=test_messages)) # SELECT messages + + session = MemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + items = cast(list[dict[str, Any]], await session.get_items()) + + assert len(items) == 2 + assert items[0]["role"] == "user" + assert items[0]["content"] == "Hello" + assert items[1]["role"] == "assistant" + assert items[1]["content"] == "Hi there!" + + +@pytest.mark.asyncio +async def test_get_items_with_limit(monkeypatch, mock_workspace, mock_pool, mock_connection): + """Test get_items respects limit parameter.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) # tables exist check + mock_connection.queue_result(MockResult()) # INSERT session + mock_connection.queue_result( + MockResult(rows=[{"message_data": json.dumps({"role": "user", "content": "Latest"})}]) + ) + + session = MemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + items = await session.get_items(limit=1) + + assert len(items) == 1 + + # Verify the query used LIMIT + select_queries = [ + query_to_string(q) + for q, p in mock_connection.executed_queries + if "SELECT message_data" in query_to_string(q) + ] + assert any("LIMIT" in q for q in select_queries), "Should use LIMIT in query" + + +@pytest.mark.asyncio +async def test_add_items(monkeypatch, mock_workspace, mock_pool, mock_connection): + """Test add_items inserts messages correctly.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + # Mock: tables exist + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) + mock_connection.queue_result(MockResult()) # INSERT session + + session = MemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + test_items: list[dict[str, Any]] = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi!"}, + ] + + await session.add_items(cast(Any, test_items)) + + # Check that executemany was called on cursor + assert len(mock_connection._cursor.executed_queries) > 0 + query, params = mock_connection._cursor.executed_queries[-1] + query_str = query_to_string(query) + assert "INSERT INTO" in query_str and "agent_messages" in query_str + assert len(params) == 2 # Two items + + +@pytest.mark.asyncio +async def test_add_items_empty_list(monkeypatch, mock_workspace, mock_pool, mock_connection): + """Test add_items handles empty list gracefully.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + # Mock: tables exist + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) + mock_connection.queue_result(MockResult()) # INSERT session + + session = MemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + initial_query_count = len(mock_connection.executed_queries) + + await session.add_items([]) + + # Should not execute any additional queries for empty list + # (only the queries from init should be present) + assert len(mock_connection.executed_queries) == initial_query_count + + +@pytest.mark.asyncio +async def test_pop_item_returns_last_item(monkeypatch, mock_workspace, mock_pool, mock_connection): + """Test pop_item removes and returns the most recent item.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) # tables exist check + mock_connection.queue_result(MockResult()) # INSERT session + # DELETE RETURNING result + mock_connection.queue_result( + MockResult( + rows=[{"message_data": json.dumps({"role": "assistant", "content": "Last msg"})}] + ) + ) + mock_connection.queue_result(MockResult()) # UPDATE session timestamp + + session = MemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + popped = cast(dict[str, Any], await session.pop_item()) + + assert popped is not None + assert popped["role"] == "assistant" + assert popped["content"] == "Last msg" + + +@pytest.mark.asyncio +async def test_pop_item_returns_none_when_empty( + monkeypatch, mock_workspace, mock_pool, mock_connection +): + """Test pop_item returns None for empty session.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) # tables exist check + mock_connection.queue_result(MockResult()) # INSERT session + mock_connection.queue_result(MockResult(rows=[])) # DELETE RETURNING - empty + + session = MemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + popped = await session.pop_item() + + assert popped is None + + +@pytest.mark.asyncio +async def test_clear_session(monkeypatch, mock_workspace, mock_pool, mock_connection): + """Test clear_session deletes all messages for the session.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) # tables exist check + mock_connection.queue_result(MockResult()) # INSERT session + mock_connection.queue_result(MockResult(rowcount=5)) # DELETE messages + mock_connection.queue_result(MockResult()) # UPDATE session timestamp + + session = MemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + await session.clear_session() + + # Find the DELETE query + delete_queries = [ + (q, p) + for q, p in mock_connection.executed_queries + if "DELETE FROM" in query_to_string(q) + and "agent_messages" in query_to_string(q) + and "WHERE session_id" in query_to_string(q) + ] + + assert len(delete_queries) > 0, "Should execute DELETE query" + query, params = delete_queries[0] + assert params == (TEST_SESSION_ID,), "Should use correct session_id" + + +def test_custom_table_names(monkeypatch, mock_workspace, mock_pool, mock_connection): + """Test that custom table names are used correctly.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + # Mock: tables don't exist (custom names), so they will be created + mock_connection.queue_result(MockResult(rows=[{"cnt": 0}])) # tables don't exist + mock_connection.queue_result(MockResult()) # CREATE sessions + mock_connection.queue_result(MockResult()) # CREATE messages + mock_connection.queue_result(MockResult()) # INSERT session + + session = MemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + sessions_table="custom_sessions", + messages_table="custom_messages", + ) + + assert session.sessions_table == "custom_sessions" + assert session.messages_table == "custom_messages" + + # Check that CREATE TABLE uses custom names + queries = [query_to_string(q) for q, _ in mock_connection.executed_queries] + assert any("custom_sessions" in q for q in queries) + assert any("custom_messages" in q for q in queries) + + +def test_pool_caching(monkeypatch, mock_workspace, mock_pool, mock_connection): + """Test that pools are cached and reused.""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + # Mock for both session creations + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) # tables exist + mock_connection.queue_result(MockResult()) # INSERT session 1 + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) # tables exist + mock_connection.queue_result(MockResult()) # INSERT session 2 + + session1 = MemorySession( + session_id=TEST_SESSION_ID, + instance_name="shared-instance", + workspace_client=mock_workspace, + ) + + session2 = MemorySession( + session_id=TEST_SESSION_ID_2, + instance_name="shared-instance", + workspace_client=mock_workspace, + ) + + # Both sessions should share the same pool + assert session1._pool is session2._pool + + +@pytest.mark.asyncio +async def test_get_items_handles_dict_message_data( + monkeypatch, mock_workspace, mock_pool, mock_connection +): + """Test get_items handles message_data that's already a dict (not JSON string).""" + monkeypatch.setattr(lakebase, "ConnectionPool", mock_pool) + + # Some database drivers return JSONB as dict directly + test_messages = [ + {"message_data": {"role": "user", "content": "Already parsed"}}, + ] + + mock_connection.queue_result(MockResult(rows=[{"cnt": 2}])) # tables exist check + mock_connection.queue_result(MockResult()) # INSERT session + mock_connection.queue_result(MockResult(rows=test_messages)) + + session = MemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + items = cast(list[dict[str, Any]], await session.get_items()) + + assert len(items) == 1 + assert items[0]["role"] == "user" + assert items[0]["content"] == "Already parsed" + + +# ============================================================================= +# AsyncMemorySession Tests +# ============================================================================= + + +@pytest.mark.asyncio +async def test_async_session_lazy_initialization( + monkeypatch, mock_workspace, mock_async_pool, mock_async_connection +): + """Test that AsyncMemorySession initializes lazily on first use.""" + monkeypatch.setattr(lakebase, "AsyncConnectionPool", mock_async_pool) + + # Create session - should NOT trigger any DB operations yet + session = AsyncMemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + # No queries should have been executed yet + assert len(mock_async_connection.executed_queries) == 0 + assert not session._initialized + + # Mock: tables exist, then INSERT session, then SELECT messages + mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 2}])) # tables exist check + mock_async_connection.queue_result(MockAsyncResult()) # INSERT session + mock_async_connection.queue_result(MockAsyncResult(rows=[])) # SELECT messages + + # First use triggers initialization + await session.get_items() + + # Now it should be initialized + assert session._initialized + assert mock_async_pool._opened + + +@pytest.mark.asyncio +async def test_async_session_configures_lakebase( + monkeypatch, mock_workspace, mock_async_pool, mock_async_connection +): + """Test that AsyncMemorySession correctly configures the Lakebase pool.""" + monkeypatch.setattr(lakebase, "AsyncConnectionPool", mock_async_pool) + + session = AsyncMemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-lakebase-instance", + workspace_client=mock_workspace, + ) + + # Mock: tables already exist + mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 2}])) + mock_async_connection.queue_result(MockAsyncResult()) # INSERT session + mock_async_connection.queue_result(MockAsyncResult(rows=[])) # SELECT for get_items + + # Trigger initialization + await session.get_items() + + assert ( + mock_async_pool.conninfo + == "dbname=databricks_postgres user=test@databricks.com host=db-host port=5432 sslmode=require" + ) + assert session.session_id == TEST_SESSION_ID + assert session.sessions_table == "agent_sessions" + assert session.messages_table == "agent_messages" + + +@pytest.mark.asyncio +async def test_async_session_creates_tables_when_not_exist( + monkeypatch, mock_workspace, mock_async_pool, mock_async_connection +): + """Test that AsyncMemorySession creates tables when they don't exist.""" + monkeypatch.setattr(lakebase, "AsyncConnectionPool", mock_async_pool) + + session = AsyncMemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + # Mock: tables don't exist (count=0) + mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 0}])) + mock_async_connection.queue_result(MockAsyncResult()) # CREATE sessions table + mock_async_connection.queue_result(MockAsyncResult()) # CREATE messages table + mock_async_connection.queue_result(MockAsyncResult()) # INSERT session + mock_async_connection.queue_result(MockAsyncResult(rows=[])) # SELECT for get_items + + await session.get_items() + + # Should have executed CREATE TABLE statements + queries = [query_to_string(q) for q, _ in mock_async_connection.executed_queries] + create_sessions_found = any( + "CREATE TABLE IF NOT EXISTS" in q and "agent_sessions" in q for q in queries + ) + create_messages_found = any( + "CREATE TABLE IF NOT EXISTS" in q and "agent_messages" in q for q in queries + ) + + assert create_sessions_found, "Should create sessions table" + assert create_messages_found, "Should create messages table" + + +@pytest.mark.asyncio +async def test_async_session_skips_table_creation_when_tables_exist( + monkeypatch, mock_workspace, mock_async_pool, mock_async_connection +): + """Test that AsyncMemorySession skips table creation when tables already exist.""" + monkeypatch.setattr(lakebase, "AsyncConnectionPool", mock_async_pool) + + session = AsyncMemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + # Mock: both tables exist (count=2) + mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 2}])) + mock_async_connection.queue_result(MockAsyncResult()) # INSERT session + mock_async_connection.queue_result(MockAsyncResult(rows=[])) # SELECT for get_items + + await session.get_items() + + # Should NOT have executed CREATE TABLE statements + queries = [query_to_string(q) for q, _ in mock_async_connection.executed_queries] + create_table_found = any("CREATE TABLE" in q for q in queries) + + assert not create_table_found, "Should not create tables when they already exist" + + +@pytest.mark.asyncio +async def test_async_get_items_empty_session( + monkeypatch, mock_workspace, mock_async_pool, mock_async_connection +): + """Test async get_items returns empty list for new session.""" + monkeypatch.setattr(lakebase, "AsyncConnectionPool", mock_async_pool) + + session = AsyncMemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 2}])) # tables exist check + mock_async_connection.queue_result(MockAsyncResult()) # INSERT session + mock_async_connection.queue_result(MockAsyncResult(rows=[])) # SELECT messages + + items = await session.get_items() + assert items == [] + + +@pytest.mark.asyncio +async def test_async_get_items_returns_parsed_json( + monkeypatch, mock_workspace, mock_async_pool, mock_async_connection +): + """Test async get_items correctly parses JSON data.""" + monkeypatch.setattr(lakebase, "AsyncConnectionPool", mock_async_pool) + + test_messages = [ + {"message_data": json.dumps({"role": "user", "content": "Hello"})}, + {"message_data": json.dumps({"role": "assistant", "content": "Hi there!"})}, + ] + + session = AsyncMemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 2}])) # tables exist check + mock_async_connection.queue_result(MockAsyncResult()) # INSERT session + mock_async_connection.queue_result(MockAsyncResult(rows=test_messages)) # SELECT messages + + items = cast(list[dict[str, Any]], await session.get_items()) + + assert len(items) == 2 + assert items[0]["role"] == "user" + assert items[0]["content"] == "Hello" + assert items[1]["role"] == "assistant" + assert items[1]["content"] == "Hi there!" + + +@pytest.mark.asyncio +async def test_async_add_items(monkeypatch, mock_workspace, mock_async_pool, mock_async_connection): + """Test async add_items inserts messages correctly.""" + monkeypatch.setattr(lakebase, "AsyncConnectionPool", mock_async_pool) + + session = AsyncMemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + # Mock: tables exist + mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 2}])) + mock_async_connection.queue_result(MockAsyncResult()) # INSERT session + + test_items: list[dict[str, Any]] = [ + {"role": "user", "content": "Hello"}, + {"role": "assistant", "content": "Hi!"}, + ] + + await session.add_items(cast(Any, test_items)) + + # Check that executemany was called on cursor + assert len(mock_async_connection._cursor.executed_queries) > 0 + query, params = mock_async_connection._cursor.executed_queries[-1] + query_str = query_to_string(query) + assert "INSERT INTO" in query_str and "agent_messages" in query_str + assert len(params) == 2 # Two items + + +@pytest.mark.asyncio +async def test_async_pop_item_returns_last_item( + monkeypatch, mock_workspace, mock_async_pool, mock_async_connection +): + """Test async pop_item removes and returns the most recent item.""" + monkeypatch.setattr(lakebase, "AsyncConnectionPool", mock_async_pool) + + session = AsyncMemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 2}])) # tables exist check + mock_async_connection.queue_result(MockAsyncResult()) # INSERT session + # DELETE RETURNING result + mock_async_connection.queue_result( + MockAsyncResult( + rows=[{"message_data": json.dumps({"role": "assistant", "content": "Last msg"})}] + ) + ) + mock_async_connection.queue_result(MockAsyncResult()) # UPDATE session timestamp + + popped = cast(dict[str, Any], await session.pop_item()) + + assert popped is not None + assert popped["role"] == "assistant" + assert popped["content"] == "Last msg" + + +@pytest.mark.asyncio +async def test_async_pop_item_returns_none_when_empty( + monkeypatch, mock_workspace, mock_async_pool, mock_async_connection +): + """Test async pop_item returns None for empty session.""" + monkeypatch.setattr(lakebase, "AsyncConnectionPool", mock_async_pool) + + session = AsyncMemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 2}])) # tables exist check + mock_async_connection.queue_result(MockAsyncResult()) # INSERT session + mock_async_connection.queue_result(MockAsyncResult(rows=[])) # DELETE RETURNING - empty + + popped = await session.pop_item() + + assert popped is None + + +@pytest.mark.asyncio +async def test_async_clear_session( + monkeypatch, mock_workspace, mock_async_pool, mock_async_connection +): + """Test async clear_session deletes all messages for the session.""" + monkeypatch.setattr(lakebase, "AsyncConnectionPool", mock_async_pool) + + session = AsyncMemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + ) + + mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 2}])) # tables exist check + mock_async_connection.queue_result(MockAsyncResult()) # INSERT session + mock_async_connection.queue_result(MockAsyncResult(rowcount=5)) # DELETE messages + mock_async_connection.queue_result(MockAsyncResult()) # UPDATE session timestamp + + await session.clear_session() + + # Find the DELETE query + delete_queries = [ + (q, p) + for q, p in mock_async_connection.executed_queries + if "DELETE FROM" in query_to_string(q) + and "agent_messages" in query_to_string(q) + and "WHERE session_id" in query_to_string(q) + ] + + assert len(delete_queries) > 0, "Should execute DELETE query" + query, params = delete_queries[0] + assert params == (TEST_SESSION_ID,), "Should use correct session_id" + + +@pytest.mark.asyncio +async def test_async_custom_table_names( + monkeypatch, mock_workspace, mock_async_pool, mock_async_connection +): + """Test that async session uses custom table names correctly.""" + monkeypatch.setattr(lakebase, "AsyncConnectionPool", mock_async_pool) + + session = AsyncMemorySession( + session_id=TEST_SESSION_ID, + instance_name="test-instance", + workspace_client=mock_workspace, + sessions_table="custom_sessions", + messages_table="custom_messages", + ) + + # Mock: tables don't exist (custom names), so they will be created + mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 0}])) # tables don't exist + mock_async_connection.queue_result(MockAsyncResult()) # CREATE sessions + mock_async_connection.queue_result(MockAsyncResult()) # CREATE messages + mock_async_connection.queue_result(MockAsyncResult()) # INSERT session + mock_async_connection.queue_result(MockAsyncResult(rows=[])) # SELECT for get_items + + await session.get_items() + + assert session.sessions_table == "custom_sessions" + assert session.messages_table == "custom_messages" + + # Check that CREATE TABLE uses custom names + queries = [query_to_string(q) for q, _ in mock_async_connection.executed_queries] + assert any("custom_sessions" in q for q in queries) + assert any("custom_messages" in q for q in queries) + + +@pytest.mark.asyncio +async def test_async_pool_caching( + monkeypatch, mock_workspace, mock_async_pool, mock_async_connection +): + """Test that async pools are cached and reused.""" + monkeypatch.setattr(lakebase, "AsyncConnectionPool", mock_async_pool) + + session1 = AsyncMemorySession( + session_id=TEST_SESSION_ID, + instance_name="shared-async-instance", + workspace_client=mock_workspace, + ) + + session2 = AsyncMemorySession( + session_id=TEST_SESSION_ID_2, + instance_name="shared-async-instance", + workspace_client=mock_workspace, + ) + + # Mock for initialization of first session + mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 2}])) # tables exist + mock_async_connection.queue_result(MockAsyncResult()) # INSERT session 1 + mock_async_connection.queue_result(MockAsyncResult(rows=[])) # SELECT + + await session1.get_items() + + # Mock for second session (tables already exist check, but pool should be reused) + mock_async_connection.queue_result(MockAsyncResult(rows=[{"cnt": 2}])) # tables exist + mock_async_connection.queue_result(MockAsyncResult()) # INSERT session 2 + mock_async_connection.queue_result(MockAsyncResult(rows=[])) # SELECT + + await session2.get_items() + + # Both sessions should share the same pool + assert session1._pool is session2._pool