From 0e5b93d240dd89cbf1cfc77e7e80ca44cf57107b Mon Sep 17 00:00:00 2001 From: Nisha Balaji Date: Sat, 21 Mar 2026 01:43:46 -0700 Subject: [PATCH 1/3] init changes --- .../agent_server/agent.py | 120 +++++++----------- .../agent_server/utils_agent_memory.py | 47 +++++++ agent-openai-agents-sdk/agent_server/agent.py | 110 +++++++--------- 3 files changed, 139 insertions(+), 138 deletions(-) create mode 100644 agent-langgraph-long-term-memory/agent_server/utils_agent_memory.py diff --git a/agent-langgraph-long-term-memory/agent_server/agent.py b/agent-langgraph-long-term-memory/agent_server/agent.py index 6ccf7e85..9a8146ea 100644 --- a/agent-langgraph-long-term-memory/agent_server/agent.py +++ b/agent-langgraph-long-term-memory/agent_server/agent.py @@ -1,6 +1,5 @@ import logging import os -from datetime import datetime from typing import Any, AsyncGenerator, Optional import mlflow @@ -13,7 +12,6 @@ ) from fastapi import HTTPException from langchain.agents import create_agent -from langchain_core.tools import tool from langgraph.store.base import BaseStore from mlflow.genai.agent_server import invoke, stream from mlflow.types.responses import ( @@ -23,84 +21,54 @@ to_chat_completions_input, ) -from agent_server.utils import ( - get_databricks_host_from_env, - get_session_id, - get_user_workspace_client, - process_agent_astream_events, -) from agent_server.utils_memory import ( get_lakebase_access_error_message, get_user_id, memory_tools, resolve_lakebase_instance_name, ) +from agent_server.utils_agent_memory import agent_memory_tools, read_agent_instructions +from agent_server.utils import ( + get_databricks_host_from_env, + get_user_workspace_client, + process_agent_astream_events, +) logger = logging.getLogger(__name__) -logging.getLogger("mlflow.utils.autologging_utils").setLevel(logging.ERROR) -logging.getLogger("LiteLLM").setLevel(logging.WARNING) mlflow.langchain.autolog() sp_workspace_client = WorkspaceClient() - -@tool -def get_current_time() -> str: - """Get the current date and time.""" - return datetime.now().isoformat() - - ############################################ # Configuration ############################################ LLM_ENDPOINT_NAME = "databricks-claude-sonnet-4-5" -_LAKEBASE_INSTANCE_NAME_RAW = os.getenv("LAKEBASE_INSTANCE_NAME") or None +_LAKEBASE_INSTANCE_NAME_RAW = os.getenv("LAKEBASE_INSTANCE_NAME", "") EMBEDDING_ENDPOINT = "databricks-gte-large-en" EMBEDDING_DIMS = 1024 -LAKEBASE_AUTOSCALING_PROJECT = os.getenv("LAKEBASE_AUTOSCALING_PROJECT") or None -LAKEBASE_AUTOSCALING_BRANCH = os.getenv("LAKEBASE_AUTOSCALING_BRANCH") or None +UC_VOLUME = os.getenv("UC_VOLUME", "") -############################################ - -_has_autoscaling = LAKEBASE_AUTOSCALING_PROJECT and LAKEBASE_AUTOSCALING_BRANCH -if not _LAKEBASE_INSTANCE_NAME_RAW and not _has_autoscaling: +if not _LAKEBASE_INSTANCE_NAME_RAW: raise ValueError( - "Lakebase configuration is required but not set. " - "Please set one of the following in your environment:\n" - " Option 1 (provisioned): LAKEBASE_INSTANCE_NAME=\n" - " Option 2 (autoscaling): LAKEBASE_AUTOSCALING_PROJECT= and LAKEBASE_AUTOSCALING_BRANCH=\n" + "LAKEBASE_INSTANCE_NAME environment variable is required but not set. " + "Please set it in your environment:\n" + " LAKEBASE_INSTANCE_NAME=\n" ) # Resolve hostname to instance name if needed (if given hostname of lakebase instead of name) -LAKEBASE_INSTANCE_NAME = resolve_lakebase_instance_name(_LAKEBASE_INSTANCE_NAME_RAW) if _LAKEBASE_INSTANCE_NAME_RAW else None +LAKEBASE_INSTANCE_NAME = resolve_lakebase_instance_name(_LAKEBASE_INSTANCE_NAME_RAW) SYSTEM_PROMPT = """You are a helpful assistant. Use the available tools to answer questions. -You have access to memory tools that allow you to remember information about users: +## User Memory (per-user, private) - Use get_user_memory to search for previously saved information about the user - Use save_user_memory to remember important facts, preferences, or details the user shares - Use delete_user_memory to forget specific information when asked -Always check for relevant memories at the start of a conversation to provide personalized responses. - -## When to save memories - -**Always save** when the user explicitly asks you to remember something. Trigger phrases include: -"remember that…", "store this", "add to memory", "note that…", "from now on…" - -**Proactively save** when the user shares information that is likely to remain true for months or years \ -and would meaningfully improve future responses. This includes: -- Preferences (e.g., language, framework, formatting style) -- Role, responsibilities, or expertise -- Ongoing projects or long-term goals -- Recurring constraints (e.g., accessibility needs, dietary restrictions) - -## When NOT to save memories +## Agent Memory (shared across all users) +- Use save_agent_instruction to save learnings that apply to ALL users: team preferences, process rules, best practices +- Use get_agent_instructions to read the current shared instructions -- Temporary or short-lived facts (e.g., "I'm tired today") -- Trivial or one-off details (e.g., what they ate for lunch, a single troubleshooting step) -- Highly sensitive personal information (health conditions, political affiliation, sexual orientation, \ -religion, criminal history) — unless the user explicitly asks you to store it -- Information that could feel intrusive or overly personal to store""" +Always check for relevant user memories at the start of a conversation.""" def init_mcp_client(workspace_client: WorkspaceClient) -> DatabricksMultiServerMCPClient: @@ -110,49 +78,44 @@ def init_mcp_client(workspace_client: WorkspaceClient) -> DatabricksMultiServerM DatabricksMCPServer( name="system-ai", url=f"{host_name}/api/2.0/mcp/functions/system/ai", - workspace_client=workspace_client, ), ] ) -async def init_agent(store: BaseStore, workspace_client: Optional[WorkspaceClient] = None): - tools = [get_current_time] + memory_tools() - # To use MCP server tools instead, replace the line above with: - # mcp_client = init_mcp_client(workspace_client or sp_workspace_client) - # try: - # tools.extend(await mcp_client.get_tools()) - # except Exception: - # logger.warning("Failed to fetch MCP tools. Continuing without MCP tools.", exc_info=True) +async def init_agent(store: BaseStore, system_prompt: str, workspace_client: Optional[WorkspaceClient] = None): + ws = workspace_client or sp_workspace_client + mcp_client = init_mcp_client(ws) + tools = await mcp_client.get_tools() + memory_tools() + if UC_VOLUME: + tools += agent_memory_tools(ws, UC_VOLUME) return create_agent( model=ChatDatabricks(endpoint=LLM_ENDPOINT_NAME), tools=tools, - system_prompt=SYSTEM_PROMPT, + system_prompt=system_prompt, store=store, ) @invoke() -async def invoke_handler(request: ResponsesAgentRequest) -> ResponsesAgentResponse: +async def non_streaming(request: ResponsesAgentRequest) -> ResponsesAgentResponse: + user_id = get_user_id(request) + outputs = [ event.item - async for event in stream_handler(request) + async for event in streaming(request) if event.type == "response.output_item.done" ] - user_id = get_user_id(request) custom_outputs = {"user_id": user_id} if user_id else None return ResponsesAgentResponse(output=outputs, custom_outputs=custom_outputs) @stream() -async def stream_handler( +async def streaming( request: ResponsesAgentRequest, ) -> AsyncGenerator[ResponsesAgentStreamEvent, None]: - if session_id := get_session_id(request): - mlflow.update_current_trace(metadata={"mlflow.trace.session": session_id}) - user_id = get_user_id(request) if not user_id: @@ -165,8 +128,6 @@ async def stream_handler( try: async with AsyncDatabricksStore( instance_name=LAKEBASE_INSTANCE_NAME, - project=LAKEBASE_AUTOSCALING_PROJECT, - branch=LAKEBASE_AUTOSCALING_BRANCH, embedding_endpoint=EMBEDDING_ENDPOINT, embedding_dims=EMBEDDING_DIMS, ) as store: @@ -175,9 +136,14 @@ async def stream_handler( if user_id: config["configurable"]["user_id"] = user_id - # By default, uses service principal credentials (sp_workspace_client). - # For on-behalf-of user authentication, use get_user_workspace_client() instead. - agent = await init_agent(workspace_client=sp_workspace_client, store=store) + full_prompt = SYSTEM_PROMPT + if UC_VOLUME: + instructions = read_agent_instructions(sp_workspace_client, UC_VOLUME) + if instructions.strip(): + full_prompt += f"\n\n## Current Agent Instructions\n{instructions}" + + agent = await init_agent(store=store, system_prompt=full_prompt) + async for event in process_agent_astream_events( agent.astream(messages, config, stream_mode=["updates", "messages"]) ): @@ -185,10 +151,10 @@ async def stream_handler( except Exception as e: error_msg = str(e).lower() # Check for Lakebase access/connection errors - if any(keyword in error_msg for keyword in ["lakebase", "pg_hba", "postgres", "database instance"]): + if any( + keyword in error_msg + for keyword in ["permission"] + ): logger.error(f"Lakebase access error: {e}") - lakebase_desc = LAKEBASE_INSTANCE_NAME or f"{LAKEBASE_AUTOSCALING_PROJECT}/{LAKEBASE_AUTOSCALING_BRANCH}" - raise HTTPException( - status_code=503, detail=get_lakebase_access_error_message(lakebase_desc) - ) from e + raise HTTPException(status_code=503, detail=get_lakebase_access_error_message(LAKEBASE_INSTANCE_NAME)) from e raise diff --git a/agent-langgraph-long-term-memory/agent_server/utils_agent_memory.py b/agent-langgraph-long-term-memory/agent_server/utils_agent_memory.py new file mode 100644 index 00000000..0c607a70 --- /dev/null +++ b/agent-langgraph-long-term-memory/agent_server/utils_agent_memory.py @@ -0,0 +1,47 @@ +import logging +from io import BytesIO + +from databricks.sdk import WorkspaceClient +from langchain_core.runnables import RunnableConfig +from langchain_core.tools import tool + +logger = logging.getLogger(__name__) + +MAX_INSTRUCTION_LINES = 50 + + +def _volume_base_path(volume: str) -> str: + return f"/Volumes/{volume.replace('.', '/')}" + + +def read_agent_instructions(w: WorkspaceClient, volume: str) -> str: + """Read instructions.md from a UC Volume. Returns empty string if not found.""" + path = f"{_volume_base_path(volume)}/instructions.md" + try: + resp = w.files.download(path) + return resp.contents.read().decode("utf-8") + except Exception: + return "" + + +def agent_memory_tools(workspace_client: WorkspaceClient, volume: str): + @tool + def save_agent_instruction(instruction: str, config: RunnableConfig) -> str: + """Save a new instruction to the shared agent memory. Use for learnings that + apply to ALL users: team preferences, process rules, best practices.""" + current = read_agent_instructions(workspace_client, volume) + lines = [l for l in current.strip().split("\n") if l.strip()] if current.strip() else [] + if sum(1 for l in lines if l.startswith("- ")) >= MAX_INSTRUCTION_LINES: + return f"Cannot save — already at {MAX_INSTRUCTION_LINES} instructions." + lines.append(f"- {instruction}") + path = f"{_volume_base_path(volume)}/instructions.md" + workspace_client.files.upload(path, BytesIO(("\n".join(lines) + "\n").encode("utf-8")), overwrite=True) + return f"Saved agent instruction: {instruction}" + + @tool + def get_agent_instructions(config: RunnableConfig) -> str: + """Read the current shared agent instructions.""" + content = read_agent_instructions(workspace_client, volume) + return content if content.strip() else "No agent instructions saved yet." + + return [save_agent_instruction, get_agent_instructions] diff --git a/agent-openai-agents-sdk/agent_server/agent.py b/agent-openai-agents-sdk/agent_server/agent.py index b336ffdc..88a56892 100644 --- a/agent-openai-agents-sdk/agent_server/agent.py +++ b/agent-openai-agents-sdk/agent_server/agent.py @@ -1,10 +1,7 @@ -import logging -from datetime import datetime from typing import AsyncGenerator -import litellm import mlflow -from agents import Agent, Runner, function_tool, set_default_openai_api, set_default_openai_client +from agents import Agent, Runner, set_default_openai_api, set_default_openai_client from agents.tracing import set_trace_processors from databricks.sdk import WorkspaceClient from databricks_openai import AsyncDatabricksOpenAI @@ -17,84 +14,75 @@ ) from agent_server.utils import ( - build_mcp_url, - get_session_id, get_user_workspace_client, - process_agent_stream_events, + process_agent_stream_events, build_mcp_url, ) -logger = logging.getLogger(__name__) - # NOTE: this will work for all databricks models OTHER than GPT-OSS, which uses a slightly different API set_default_openai_client(AsyncDatabricksOpenAI()) set_default_openai_api("chat_completions") set_trace_processors([]) # only use mlflow for trace processing mlflow.openai.autolog() -logging.getLogger("mlflow.utils.autologging_utils").setLevel(logging.ERROR) -litellm.suppress_debug_info = True -@function_tool -def get_current_time() -> str: - """Get the current date and time.""" - return datetime.now().isoformat() +MEMORY_MCP_HOST = "https://eng-ml-agent-platform.staging.cloud.databricks.com" +memory_ws_client = WorkspaceClient(host=MEMORY_MCP_HOST, profile="agent-platform") -async def init_mcp_server(workspace_client: WorkspaceClient): +async def init_memory_mcp_server(): return McpServer( - url=build_mcp_url("/api/2.0/mcp/functions/system/ai", workspace_client=workspace_client), - name="system.ai UC function MCP server", - workspace_client=workspace_client, + url=f"{MEMORY_MCP_HOST}/api/2.0/mcp/sql", + name="memory-mcp", + workspace_client=memory_ws_client, + params={ + "headers": {"x-databricks-traffic-id": "testenv://liteswap/jenny_memory"}, + }, ) -def create_agent(mcp_servers: list[McpServer] | None = None) -> Agent: +MEMORY_STORE = "test-embed" + +SYSTEM_PROMPT = f"""You are a helpful assistant with long-term memory. + +## Important: Check agent memory before every response +Before responding to the user, ALWAYS call search_memory with memory_store="{MEMORY_STORE}", scope="agent", query="response preferences and procedures" to load any shared instructions that affect how you should respond. + +Also call search_memory with scope="user" to check for personal context about the current user. + +## Memory Tools +- write_memory: Save info. scope="user" for personal facts, scope="agent" for shared rules/procedures. +- search_memory: Search past memories. scope="user"/"agent"/"both". +- Always use memory_store="{MEMORY_STORE}" for all memory operations.""" + + +def create_agent(mcp_server: McpServer) -> Agent: return Agent( - name="Agent", - instructions="You are a helpful assistant.", - model="databricks-gpt-5-2", - tools=[get_current_time], - mcp_servers=mcp_servers or [], + name="Memory agent", + instructions=SYSTEM_PROMPT, + model="databricks-claude-sonnet-4-5", + mcp_servers=[mcp_server], ) @invoke() -async def invoke_handler(request: ResponsesAgentRequest) -> ResponsesAgentResponse: - if session_id := get_session_id(request): - mlflow.update_current_trace(metadata={"mlflow.trace.session": session_id}) - # To use MCP server tools, wrap the code below with this async context manager. - # By default, uses service principal credentials via WorkspaceClient(). - # For on-behalf-of user authentication, use get_user_workspace_client() instead. - # try: - # async with await init_mcp_server(WorkspaceClient()) as mcp_server: - # agent = create_agent(mcp_servers=[mcp_server]) - # except Exception: - # logger.warning("MCP server unavailable. Continuing without MCP tools.", exc_info=True) - # agent = create_agent() - agent = create_agent() - messages = [i.model_dump() for i in request.input] - result = await Runner.run(agent, messages) - return ResponsesAgentResponse(output=[item.to_input_item() for item in result.new_items]) +async def invoke(request: ResponsesAgentRequest) -> ResponsesAgentResponse: + # Optionally use the user's workspace client for on-behalf-of authentication + # user_workspace_client = get_user_workspace_client() + async with await init_memory_mcp_server() as mcp_server: + agent = create_agent(mcp_server) + messages = [i.model_dump() for i in request.input] + result = await Runner.run(agent, messages) + return ResponsesAgentResponse(output=[item.to_input_item() for item in result.new_items]) @stream() -async def stream_handler( - request: ResponsesAgentRequest, -) -> AsyncGenerator[ResponsesAgentStreamEvent, None]: - if session_id := get_session_id(request): - mlflow.update_current_trace(metadata={"mlflow.trace.session": session_id}) - # To use MCP server tools, wrap the code below with this async context manager. - # By default, uses service principal credentials via WorkspaceClient(). - # For on-behalf-of user authentication, use get_user_workspace_client() instead. - # try: - # async with await init_mcp_server(WorkspaceClient()) as mcp_server: - # agent = create_agent(mcp_servers=[mcp_server]) - # except Exception: - # logger.warning("MCP server unavailable. Continuing without MCP tools.", exc_info=True) - # agent = create_agent() - agent = create_agent() - messages = [i.model_dump() for i in request.input] - result = Runner.run_streamed(agent, input=messages) - - async for event in process_agent_stream_events(result.stream_events()): - yield event +async def stream(request: ResponsesAgentRequest) -> AsyncGenerator[ResponsesAgentStreamEvent, None]: + # Optionally use the user's workspace client for on-behalf-of authentication + # user_workspace_client = get_user_workspace_client() + async with await init_memory_mcp_server() as mcp_server: + agent = create_agent(mcp_server) + messages = [i.model_dump() for i in request.input] + result = Runner.run_streamed(agent, input=messages) + + async for event in process_agent_stream_events(result.stream_events()): + yield event From a68daa308f70ecd5fd7fd9d2724ac8f39dc6d2a3 Mon Sep 17 00:00:00 2001 From: Nisha Balaji Date: Sat, 21 Mar 2026 02:02:49 -0700 Subject: [PATCH 2/3] rebase --- .../agent_server/agent.py | 115 +++++++++++++----- agent-openai-agents-sdk/agent_server/agent.py | 73 ++++++----- 2 files changed, 126 insertions(+), 62 deletions(-) diff --git a/agent-langgraph-long-term-memory/agent_server/agent.py b/agent-langgraph-long-term-memory/agent_server/agent.py index 9a8146ea..faed7923 100644 --- a/agent-langgraph-long-term-memory/agent_server/agent.py +++ b/agent-langgraph-long-term-memory/agent_server/agent.py @@ -1,5 +1,6 @@ import logging import os +from datetime import datetime from typing import Any, AsyncGenerator, Optional import mlflow @@ -12,6 +13,7 @@ ) from fastapi import HTTPException from langchain.agents import create_agent +from langchain_core.tools import tool from langgraph.store.base import BaseStore from mlflow.genai.agent_server import invoke, stream from mlflow.types.responses import ( @@ -21,54 +23,90 @@ to_chat_completions_input, ) +from agent_server.utils import ( + get_databricks_host_from_env, + get_session_id, + get_user_workspace_client, + process_agent_astream_events, +) +from agent_server.utils_agent_memory import agent_memory_tools, read_agent_instructions from agent_server.utils_memory import ( get_lakebase_access_error_message, get_user_id, memory_tools, resolve_lakebase_instance_name, ) -from agent_server.utils_agent_memory import agent_memory_tools, read_agent_instructions -from agent_server.utils import ( - get_databricks_host_from_env, - get_user_workspace_client, - process_agent_astream_events, -) logger = logging.getLogger(__name__) +logging.getLogger("mlflow.utils.autologging_utils").setLevel(logging.ERROR) +logging.getLogger("LiteLLM").setLevel(logging.WARNING) mlflow.langchain.autolog() sp_workspace_client = WorkspaceClient() + +@tool +def get_current_time() -> str: + """Get the current date and time.""" + return datetime.now().isoformat() + + ############################################ # Configuration ############################################ LLM_ENDPOINT_NAME = "databricks-claude-sonnet-4-5" -_LAKEBASE_INSTANCE_NAME_RAW = os.getenv("LAKEBASE_INSTANCE_NAME", "") +_LAKEBASE_INSTANCE_NAME_RAW = os.getenv("LAKEBASE_INSTANCE_NAME") or None EMBEDDING_ENDPOINT = "databricks-gte-large-en" EMBEDDING_DIMS = 1024 UC_VOLUME = os.getenv("UC_VOLUME", "") +LAKEBASE_AUTOSCALING_PROJECT = os.getenv("LAKEBASE_AUTOSCALING_PROJECT") or None +LAKEBASE_AUTOSCALING_BRANCH = os.getenv("LAKEBASE_AUTOSCALING_BRANCH") or None -if not _LAKEBASE_INSTANCE_NAME_RAW: +############################################ + +_has_autoscaling = LAKEBASE_AUTOSCALING_PROJECT and LAKEBASE_AUTOSCALING_BRANCH +if not _LAKEBASE_INSTANCE_NAME_RAW and not _has_autoscaling: raise ValueError( - "LAKEBASE_INSTANCE_NAME environment variable is required but not set. " - "Please set it in your environment:\n" - " LAKEBASE_INSTANCE_NAME=\n" + "Lakebase configuration is required but not set. " + "Please set one of the following in your environment:\n" + " Option 1 (provisioned): LAKEBASE_INSTANCE_NAME=\n" + " Option 2 (autoscaling): LAKEBASE_AUTOSCALING_PROJECT= and LAKEBASE_AUTOSCALING_BRANCH=\n" ) # Resolve hostname to instance name if needed (if given hostname of lakebase instead of name) -LAKEBASE_INSTANCE_NAME = resolve_lakebase_instance_name(_LAKEBASE_INSTANCE_NAME_RAW) +LAKEBASE_INSTANCE_NAME = resolve_lakebase_instance_name(_LAKEBASE_INSTANCE_NAME_RAW) if _LAKEBASE_INSTANCE_NAME_RAW else None SYSTEM_PROMPT = """You are a helpful assistant. Use the available tools to answer questions. -## User Memory (per-user, private) +You have access to memory tools that allow you to remember information about users: - Use get_user_memory to search for previously saved information about the user - Use save_user_memory to remember important facts, preferences, or details the user shares - Use delete_user_memory to forget specific information when asked +Always check for relevant memories at the start of a conversation to provide personalized responses. + +## When to save memories + +**Always save** when the user explicitly asks you to remember something. Trigger phrases include: +"remember that…", "store this", "add to memory", "note that…", "from now on…" + +**Proactively save** when the user shares information that is likely to remain true for months or years \ +and would meaningfully improve future responses. This includes: +- Preferences (e.g., language, framework, formatting style) +- Role, responsibilities, or expertise +- Ongoing projects or long-term goals +- Recurring constraints (e.g., accessibility needs, dietary restrictions) + +## When NOT to save memories + +- Temporary or short-lived facts (e.g., "I'm tired today") +- Trivial or one-off details (e.g., what they ate for lunch, a single troubleshooting step) +- Highly sensitive personal information (health conditions, political affiliation, sexual orientation, \ +religion, criminal history) — unless the user explicitly asks you to store it +- Information that could feel intrusive or overly personal to store + ## Agent Memory (shared across all users) - Use save_agent_instruction to save learnings that apply to ALL users: team preferences, process rules, best practices -- Use get_agent_instructions to read the current shared instructions - -Always check for relevant user memories at the start of a conversation.""" +- Use get_agent_instructions to read the current shared instructions""" def init_mcp_client(workspace_client: WorkspaceClient) -> DatabricksMultiServerMCPClient: @@ -78,17 +116,22 @@ def init_mcp_client(workspace_client: WorkspaceClient) -> DatabricksMultiServerM DatabricksMCPServer( name="system-ai", url=f"{host_name}/api/2.0/mcp/functions/system/ai", + workspace_client=workspace_client, ), ] ) -async def init_agent(store: BaseStore, system_prompt: str, workspace_client: Optional[WorkspaceClient] = None): - ws = workspace_client or sp_workspace_client - mcp_client = init_mcp_client(ws) - tools = await mcp_client.get_tools() + memory_tools() +async def init_agent(store: BaseStore, workspace_client: Optional[WorkspaceClient] = None, system_prompt: str = SYSTEM_PROMPT): + tools = [get_current_time] + memory_tools() if UC_VOLUME: - tools += agent_memory_tools(ws, UC_VOLUME) + tools += agent_memory_tools(workspace_client or sp_workspace_client, UC_VOLUME) + # To use MCP server tools instead, replace the line above with: + # mcp_client = init_mcp_client(workspace_client or sp_workspace_client) + # try: + # tools.extend(await mcp_client.get_tools()) + # except Exception: + # logger.warning("Failed to fetch MCP tools. Continuing without MCP tools.", exc_info=True) return create_agent( model=ChatDatabricks(endpoint=LLM_ENDPOINT_NAME), @@ -99,23 +142,25 @@ async def init_agent(store: BaseStore, system_prompt: str, workspace_client: Opt @invoke() -async def non_streaming(request: ResponsesAgentRequest) -> ResponsesAgentResponse: - user_id = get_user_id(request) - +async def invoke_handler(request: ResponsesAgentRequest) -> ResponsesAgentResponse: outputs = [ event.item - async for event in streaming(request) + async for event in stream_handler(request) if event.type == "response.output_item.done" ] + user_id = get_user_id(request) custom_outputs = {"user_id": user_id} if user_id else None return ResponsesAgentResponse(output=outputs, custom_outputs=custom_outputs) @stream() -async def streaming( +async def stream_handler( request: ResponsesAgentRequest, ) -> AsyncGenerator[ResponsesAgentStreamEvent, None]: + if session_id := get_session_id(request): + mlflow.update_current_trace(metadata={"mlflow.trace.session": session_id}) + user_id = get_user_id(request) if not user_id: @@ -128,6 +173,8 @@ async def streaming( try: async with AsyncDatabricksStore( instance_name=LAKEBASE_INSTANCE_NAME, + project=LAKEBASE_AUTOSCALING_PROJECT, + branch=LAKEBASE_AUTOSCALING_BRANCH, embedding_endpoint=EMBEDDING_ENDPOINT, embedding_dims=EMBEDDING_DIMS, ) as store: @@ -136,14 +183,16 @@ async def streaming( if user_id: config["configurable"]["user_id"] = user_id + # Inject agent-scoped instructions from UC Volume into system prompt full_prompt = SYSTEM_PROMPT if UC_VOLUME: instructions = read_agent_instructions(sp_workspace_client, UC_VOLUME) if instructions.strip(): full_prompt += f"\n\n## Current Agent Instructions\n{instructions}" - agent = await init_agent(store=store, system_prompt=full_prompt) - + # By default, uses service principal credentials (sp_workspace_client). + # For on-behalf-of user authentication, use get_user_workspace_client() instead. + agent = await init_agent(workspace_client=sp_workspace_client, store=store, system_prompt=full_prompt) async for event in process_agent_astream_events( agent.astream(messages, config, stream_mode=["updates", "messages"]) ): @@ -151,10 +200,10 @@ async def streaming( except Exception as e: error_msg = str(e).lower() # Check for Lakebase access/connection errors - if any( - keyword in error_msg - for keyword in ["permission"] - ): + if any(keyword in error_msg for keyword in ["lakebase", "pg_hba", "postgres", "database instance"]): logger.error(f"Lakebase access error: {e}") - raise HTTPException(status_code=503, detail=get_lakebase_access_error_message(LAKEBASE_INSTANCE_NAME)) from e + lakebase_desc = LAKEBASE_INSTANCE_NAME or f"{LAKEBASE_AUTOSCALING_PROJECT}/{LAKEBASE_AUTOSCALING_BRANCH}" + raise HTTPException( + status_code=503, detail=get_lakebase_access_error_message(lakebase_desc) + ) from e raise diff --git a/agent-openai-agents-sdk/agent_server/agent.py b/agent-openai-agents-sdk/agent_server/agent.py index 88a56892..0e32ee26 100644 --- a/agent-openai-agents-sdk/agent_server/agent.py +++ b/agent-openai-agents-sdk/agent_server/agent.py @@ -1,7 +1,10 @@ +import logging +from datetime import datetime from typing import AsyncGenerator +import litellm import mlflow -from agents import Agent, Runner, set_default_openai_api, set_default_openai_client +from agents import Agent, Runner, function_tool, set_default_openai_api, set_default_openai_client from agents.tracing import set_trace_processors from databricks.sdk import WorkspaceClient from databricks_openai import AsyncDatabricksOpenAI @@ -14,35 +17,34 @@ ) from agent_server.utils import ( + build_mcp_url, + get_session_id, get_user_workspace_client, - process_agent_stream_events, build_mcp_url, + process_agent_stream_events, ) +logger = logging.getLogger(__name__) + # NOTE: this will work for all databricks models OTHER than GPT-OSS, which uses a slightly different API set_default_openai_client(AsyncDatabricksOpenAI()) set_default_openai_api("chat_completions") set_trace_processors([]) # only use mlflow for trace processing mlflow.openai.autolog() +logging.getLogger("mlflow.utils.autologging_utils").setLevel(logging.ERROR) +litellm.suppress_debug_info = True -MEMORY_MCP_HOST = "https://eng-ml-agent-platform.staging.cloud.databricks.com" -memory_ws_client = WorkspaceClient(host=MEMORY_MCP_HOST, profile="agent-platform") - - -async def init_memory_mcp_server(): - return McpServer( - url=f"{MEMORY_MCP_HOST}/api/2.0/mcp/sql", - name="memory-mcp", - workspace_client=memory_ws_client, - params={ - "headers": {"x-databricks-traffic-id": "testenv://liteswap/jenny_memory"}, - }, - ) +@function_tool +def get_current_time() -> str: + """Get the current date and time.""" + return datetime.now().isoformat() +MEMORY_MCP_HOST = "https://eng-ml-agent-platform.staging.cloud.databricks.com" +memory_ws_client = WorkspaceClient(host=MEMORY_MCP_HOST, profile="agent-platform") MEMORY_STORE = "test-embed" -SYSTEM_PROMPT = f"""You are a helpful assistant with long-term memory. +MEMORY_SYSTEM_PROMPT = f"""You are a helpful assistant with long-term memory. ## Important: Check agent memory before every response Before responding to the user, ALWAYS call search_memory with memory_store="{MEMORY_STORE}", scope="agent", query="response preferences and procedures" to load any shared instructions that affect how you should respond. @@ -55,32 +57,45 @@ async def init_memory_mcp_server(): - Always use memory_store="{MEMORY_STORE}" for all memory operations.""" -def create_agent(mcp_server: McpServer) -> Agent: +async def init_mcp_server(workspace_client: WorkspaceClient = None): + return McpServer( + url=f"{MEMORY_MCP_HOST}/api/2.0/mcp/sql", + name="memory-mcp", + workspace_client=memory_ws_client, + params={ + "headers": {"x-databricks-traffic-id": "testenv://liteswap/jenny_memory"}, + }, + ) + + +def create_agent(mcp_servers: list[McpServer] | None = None) -> Agent: return Agent( name="Memory agent", - instructions=SYSTEM_PROMPT, + instructions=MEMORY_SYSTEM_PROMPT, model="databricks-claude-sonnet-4-5", - mcp_servers=[mcp_server], + mcp_servers=mcp_servers or [], ) @invoke() -async def invoke(request: ResponsesAgentRequest) -> ResponsesAgentResponse: - # Optionally use the user's workspace client for on-behalf-of authentication - # user_workspace_client = get_user_workspace_client() - async with await init_memory_mcp_server() as mcp_server: - agent = create_agent(mcp_server) +async def invoke_handler(request: ResponsesAgentRequest) -> ResponsesAgentResponse: + if session_id := get_session_id(request): + mlflow.update_current_trace(metadata={"mlflow.trace.session": session_id}) + async with await init_mcp_server() as mcp_server: + agent = create_agent(mcp_servers=[mcp_server]) messages = [i.model_dump() for i in request.input] result = await Runner.run(agent, messages) return ResponsesAgentResponse(output=[item.to_input_item() for item in result.new_items]) @stream() -async def stream(request: ResponsesAgentRequest) -> AsyncGenerator[ResponsesAgentStreamEvent, None]: - # Optionally use the user's workspace client for on-behalf-of authentication - # user_workspace_client = get_user_workspace_client() - async with await init_memory_mcp_server() as mcp_server: - agent = create_agent(mcp_server) +async def stream_handler( + request: ResponsesAgentRequest, +) -> AsyncGenerator[ResponsesAgentStreamEvent, None]: + if session_id := get_session_id(request): + mlflow.update_current_trace(metadata={"mlflow.trace.session": session_id}) + async with await init_mcp_server() as mcp_server: + agent = create_agent(mcp_servers=[mcp_server]) messages = [i.model_dump() for i in request.input] result = Runner.run_streamed(agent, input=messages) From dec632dc9917e27a2b2797b87430dbba32c2ed6e Mon Sep 17 00:00:00 2001 From: Nisha Balaji Date: Sun, 22 Mar 2026 16:47:17 -0700 Subject: [PATCH 3/3] github --- agent-openai-agents-sdk/agent_server/agent.py | 47 ++++++++++++------- 1 file changed, 29 insertions(+), 18 deletions(-) diff --git a/agent-openai-agents-sdk/agent_server/agent.py b/agent-openai-agents-sdk/agent_server/agent.py index 0e32ee26..2116617e 100644 --- a/agent-openai-agents-sdk/agent_server/agent.py +++ b/agent-openai-agents-sdk/agent_server/agent.py @@ -4,7 +4,7 @@ import litellm import mlflow -from agents import Agent, Runner, function_tool, set_default_openai_api, set_default_openai_client +from agents import Agent, ModelSettings, Runner, function_tool, set_default_openai_api, set_default_openai_client from agents.tracing import set_trace_processors from databricks.sdk import WorkspaceClient from databricks_openai import AsyncDatabricksOpenAI @@ -44,36 +44,45 @@ def get_current_time() -> str: memory_ws_client = WorkspaceClient(host=MEMORY_MCP_HOST, profile="agent-platform") MEMORY_STORE = "test-embed" -MEMORY_SYSTEM_PROMPT = f"""You are a helpful assistant with long-term memory. +MEMORY_SYSTEM_PROMPT = f"""You are a helpful assistant with long-term memory. You proactively remember things about users. -## Important: Check agent memory before every response -Before responding to the user, ALWAYS call search_memory with memory_store="{MEMORY_STORE}", scope="agent", query="response preferences and procedures" to load any shared instructions that affect how you should respond. +Always use memory_store="{MEMORY_STORE}" for all memory operations. -Also call search_memory with scope="user" to check for personal context about the current user. +## Before every response +1. Call search_memory scope="agent", query="response preferences and procedures" to load shared instructions. +2. Call search_memory scope="user" to check for personal context about the current user. -## Memory Tools -- write_memory: Save info. scope="user" for personal facts, scope="agent" for shared rules/procedures. -- search_memory: Search past memories. scope="user"/"agent"/"both". -- Always use memory_store="{MEMORY_STORE}" for all memory operations.""" +## Saving memories +Proactively save anything the user shares about themselves (location, role, preferences, interests, etc.) using write_memory. Use scope="user" for personal facts, scope="agent" for shared rules that apply to all users. +## Conversation history +Refer to the current chat history for questions about this session. Only search memory for info from previous sessions.""" -async def init_mcp_server(workspace_client: WorkspaceClient = None): - return McpServer( + +async def init_mcp_servers(): + memory = McpServer( url=f"{MEMORY_MCP_HOST}/api/2.0/mcp/sql", name="memory-mcp", workspace_client=memory_ws_client, params={ - "headers": {"x-databricks-traffic-id": "testenv://liteswap/jenny_memory"}, + "headers": {"x-databricks-traffic-id": "testenv://liteswap/jennymemorysa"}, }, ) + github = McpServer( + url=f"{MEMORY_MCP_HOST}/api/2.0/mcp/external/github_demo", + name="github-mcp", + workspace_client=memory_ws_client, + ) + return memory, github def create_agent(mcp_servers: list[McpServer] | None = None) -> Agent: return Agent( - name="Memory agent", + name="Code review agent", instructions=MEMORY_SYSTEM_PROMPT, - model="databricks-claude-sonnet-4-5", + model="databricks-gpt-5-2", mcp_servers=mcp_servers or [], + model_settings=ModelSettings(parallel_tool_calls=False), ) @@ -81,8 +90,9 @@ def create_agent(mcp_servers: list[McpServer] | None = None) -> Agent: async def invoke_handler(request: ResponsesAgentRequest) -> ResponsesAgentResponse: if session_id := get_session_id(request): mlflow.update_current_trace(metadata={"mlflow.trace.session": session_id}) - async with await init_mcp_server() as mcp_server: - agent = create_agent(mcp_servers=[mcp_server]) + memory_srv, github_srv = await init_mcp_servers() + async with memory_srv as mem, github_srv as gh: + agent = create_agent(mcp_servers=[mem, gh]) messages = [i.model_dump() for i in request.input] result = await Runner.run(agent, messages) return ResponsesAgentResponse(output=[item.to_input_item() for item in result.new_items]) @@ -94,8 +104,9 @@ async def stream_handler( ) -> AsyncGenerator[ResponsesAgentStreamEvent, None]: if session_id := get_session_id(request): mlflow.update_current_trace(metadata={"mlflow.trace.session": session_id}) - async with await init_mcp_server() as mcp_server: - agent = create_agent(mcp_servers=[mcp_server]) + memory_srv, github_srv = await init_mcp_servers() + async with memory_srv as mem, github_srv as gh: + agent = create_agent(mcp_servers=[mem, gh]) messages = [i.model_dump() for i in request.input] result = Runner.run_streamed(agent, input=messages)