From fba8180a0533467fe615aa4e343a50c00567357c Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 22 Jan 2026 22:32:00 +0000 Subject: [PATCH 1/2] google-adk: Add Google ADK integration for Databricks AI features This adds a new integration package `databricks-google-adk` that provides Databricks AI tools for Google Agent Development Kit (ADK) agents: - VectorSearchRetrieverTool: Search Databricks Vector Search indexes - GenieTool: Query Databricks Genie AI/BI spaces - DatabricksToolset: Bundle multiple Databricks tools together The integration follows the existing patterns from other framework integrations (langchain, llamaindex, openai) and reuses the core VectorSearchRetrieverToolMixin for consistent behavior. --- integrations/google-adk/README.md | 237 +++++++++++++++++ integrations/google-adk/pyproject.toml | 64 +++++ .../src/databricks_google_adk/__init__.py | 44 ++++ .../src/databricks_google_adk/genie.py | 244 ++++++++++++++++++ .../src/databricks_google_adk/toolset.py | 182 +++++++++++++ .../vector_search_retriever_tool.py | 217 ++++++++++++++++ integrations/google-adk/tests/__init__.py | 0 .../google-adk/tests/unit_tests/__init__.py | 0 .../google-adk/tests/unit_tests/test_genie.py | 151 +++++++++++ .../tests/unit_tests/test_toolset.py | 149 +++++++++++ .../test_vector_search_retriever_tool.py | 204 +++++++++++++++ 11 files changed, 1492 insertions(+) create mode 100644 integrations/google-adk/README.md create mode 100644 integrations/google-adk/pyproject.toml create mode 100644 integrations/google-adk/src/databricks_google_adk/__init__.py create mode 100644 integrations/google-adk/src/databricks_google_adk/genie.py create mode 100644 integrations/google-adk/src/databricks_google_adk/toolset.py create mode 100644 integrations/google-adk/src/databricks_google_adk/vector_search_retriever_tool.py create mode 100644 integrations/google-adk/tests/__init__.py create mode 100644 integrations/google-adk/tests/unit_tests/__init__.py create mode 100644 
integrations/google-adk/tests/unit_tests/test_genie.py create mode 100644 integrations/google-adk/tests/unit_tests/test_toolset.py create mode 100644 integrations/google-adk/tests/unit_tests/test_vector_search_retriever_tool.py diff --git a/integrations/google-adk/README.md b/integrations/google-adk/README.md new file mode 100644 index 000000000..c607afdbc --- /dev/null +++ b/integrations/google-adk/README.md @@ -0,0 +1,237 @@ +# Databricks AI Bridge for Google ADK + +This package provides Databricks AI integration for [Google Agent Development Kit (ADK)](https://github.com/google/adk-python), enabling you to use Databricks Vector Search and Genie in your ADK agents. + +## Installation + +```bash +pip install databricks-google-adk +``` + +## Features + +- **VectorSearchRetrieverTool**: Search Databricks Vector Search indexes from ADK agents +- **GenieTool**: Query Databricks Genie AI/BI spaces using natural language +- **DatabricksToolset**: Bundle multiple Databricks tools together for easy agent configuration + +## Quick Start + +### Vector Search Tool + +Use Databricks Vector Search in your ADK agent: + +```python +from databricks_google_adk import VectorSearchRetrieverTool +from google.adk.agents import Agent +from google.adk.runners import Runner + +# Create the Vector Search tool +vector_search = VectorSearchRetrieverTool( + index_name="catalog.schema.my_index", + num_results=5, +) + +# Create an ADK agent with the tool +agent = Agent( + name="search_assistant", + model="gemini-2.0-flash", + instruction="You are a helpful assistant that searches documents to answer questions.", + tools=[vector_search.as_tool()], +) + +# Run the agent +runner = Runner(agent=agent, app_name="search_app") +``` + +### Genie Tool + +Query Databricks Genie for data insights: + +```python +from databricks_google_adk import GenieTool +from google.adk.agents import Agent + +# Create the Genie tool +genie = GenieTool( + space_id="your-genie-space-id", + tool_description="Ask questions 
about sales data", +) + +# Create an ADK agent with the tool +agent = Agent( + name="data_analyst", + model="gemini-2.0-flash", + instruction="You are a data analyst. Use the genie tool to answer questions about data.", + tools=[genie.as_tool()], +) +``` + +### Using DatabricksToolset + +Bundle multiple Databricks tools together: + +```python +from databricks_google_adk import DatabricksToolset +from google.adk.agents import Agent + +# Create a toolset with multiple tools +toolset = DatabricksToolset( + vector_search_indexes=[ + "catalog.schema.products_index", + "catalog.schema.docs_index", + ], + genie_space_ids=["genie-space-123"], +) + +# Or build incrementally with method chaining +toolset = ( + DatabricksToolset() + .add_vector_search_tool( + index_name="catalog.schema.my_index", + tool_name="search_products", + tool_description="Search product catalog", + ) + .add_genie_tool( + space_id="genie-space-123", + tool_name="ask_sales_data", + tool_description="Query sales data", + ) +) + +# Use with an ADK agent +agent = Agent( + name="data_assistant", + model="gemini-2.0-flash", + instruction="You help users find products and analyze sales data.", + tools=[toolset], +) +``` + +## Advanced Usage + +### Self-Managed Embeddings + +For Vector Search indexes with self-managed embeddings, provide an embedding function: + +```python +from databricks_google_adk import VectorSearchRetrieverTool + +def my_embedding_fn(text: str) -> list[float]: + # Your embedding logic here + return embeddings + +vector_search = VectorSearchRetrieverTool( + index_name="catalog.schema.self_managed_index", + text_column="content", + embedding_fn=my_embedding_fn, +) +``` + +### Dynamic Filters + +Enable LLM-generated filters for more flexible querying: + +```python +vector_search = VectorSearchRetrieverTool( + index_name="catalog.schema.my_index", + dynamic_filter=True, # Allows LLM to generate filters +) +``` + +### Multi-turn Conversations with Genie + +The GenieTool maintains conversation 
state for follow-up questions: + +```python +from databricks_google_adk import GenieTool + +genie = GenieTool(space_id="your-space-id") + +# First question +result1 = genie.ask("What were total sales last month?") + +# Follow-up question (uses same conversation) +result2 = genie.ask("Break that down by region") + +# Start a new conversation +result3 = genie.ask("Show me top customers", new_conversation=True) + +# Reset conversation state +genie.reset_conversation() +``` + +### Custom Authentication + +Pass a custom WorkspaceClient for authentication: + +```python +from databricks.sdk import WorkspaceClient +from databricks_google_adk import VectorSearchRetrieverTool, GenieTool + +# Create a workspace client with custom configuration +client = WorkspaceClient( + host="https://your-workspace.databricks.com", + token="your-token", +) + +# Use with Vector Search +vector_search = VectorSearchRetrieverTool( + index_name="catalog.schema.my_index", + workspace_client=client, +) + +# Use with Genie +genie = GenieTool( + space_id="your-space-id", + client=client, +) +``` + +## API Reference + +### VectorSearchRetrieverTool + +| Parameter | Type | Description | +|-----------|------|-------------| +| `index_name` | `str` | Name of the Vector Search index (format: `catalog.schema.index`) | +| `num_results` | `int` | Number of results to return (default: 5) | +| `columns` | `list[str]` | Columns to return in results | +| `filters` | `dict` | Static filters to apply to all searches | +| `dynamic_filter` | `bool` | Enable LLM-generated filters (default: False) | +| `tool_name` | `str` | Custom name for the tool | +| `tool_description` | `str` | Custom description for the tool | +| `text_column` | `str` | Text column name (required for self-managed embeddings) | +| `embedding_fn` | `Callable` | Embedding function (required for self-managed embeddings) | +| `workspace_client` | `WorkspaceClient` | Custom Databricks client | + +### GenieTool + +| Parameter | Type | Description | 
+|-----------|------|-------------| +| `space_id` | `str` | Genie space ID | +| `tool_name` | `str` | Custom name for the tool (default: "ask_genie") | +| `tool_description` | `str` | Custom description for the tool | +| `client` | `WorkspaceClient` | Custom Databricks client | +| `return_pandas` | `bool` | Return results as pandas DataFrames (default: False) | + +### DatabricksToolset + +| Parameter | Type | Description | +|-----------|------|-------------| +| `vector_search_indexes` | `list[str]` | List of Vector Search index names | +| `genie_space_ids` | `list[str]` | List of Genie space IDs | +| `workspace_client` | `WorkspaceClient` | Custom Databricks client | +| `embedding_fn` | `Callable` | Embedding function for self-managed indexes | +| `tool_filter` | `list[str]` | Filter tools by name | +| `tool_name_prefix` | `str` | Prefix to add to all tool names | + +## Requirements + +- Python >= 3.10 +- google-adk >= 1.0.0 +- databricks-ai-bridge >= 0.4.0 +- databricks-vectorsearch >= 0.40 + +## License + +Apache-2.0 diff --git a/integrations/google-adk/pyproject.toml b/integrations/google-adk/pyproject.toml new file mode 100644 index 000000000..501404918 --- /dev/null +++ b/integrations/google-adk/pyproject.toml @@ -0,0 +1,64 @@ +[project] +name = "databricks-google-adk" +version = "0.1.0.dev0" +description = "Databricks AI support for Google Agent Development Kit (ADK)" +authors = [ + { name="Databricks", email="feedback@databricks.com" }, +] +readme = "README.md" +license = { text="Apache-2.0" } +requires-python = ">=3.10" +dependencies = [ + "databricks-vectorsearch>=0.40", + "databricks-ai-bridge>=0.4.0", + "google-adk>=1.0.0", + "pydantic>=2.10.0", +] + +[dependency-groups] +dev = [ + "typing_extensions>=4.15.0", + "databricks-sdk>=0.34.0", + "ruff==0.14.10", + "ty>=0.0.11", + { include-group = "tests" }, +] + +tests = [ + "pytest>=9.0.0", + "pytest-timeout>=2.3.1", + "pytest-asyncio>=0.24.0", +] + +[build-system] +requires = ["hatchling"] +build-backend = 
"hatchling.build" + +[tool.uv.sources] +databricks-ai-bridge = { path = "../../", editable = true } + +[tool.hatch.build] +include = [ + "src/databricks_google_adk/*" +] + +[tool.hatch.build.targets.wheel] +packages = ["src/databricks_google_adk"] + +[tool.ruff] +include = ["pyproject.toml", "src/**/*.py", "tests/**/*.py"] +extend = "../../pyproject.toml" + +[tool.pytest.ini_options] +filterwarnings = [ + "ignore::Warning", + "default::Warning:databricks_google_adk", + "default::Warning:tests", +] +asyncio_mode = "auto" + +[tool.ty.environment] +root = ["./src", "./tests"] + +[tool.ty.src] +include = ["./src", "./tests"] diff --git a/integrations/google-adk/src/databricks_google_adk/__init__.py b/integrations/google-adk/src/databricks_google_adk/__init__.py new file mode 100644 index 000000000..61a4b748d --- /dev/null +++ b/integrations/google-adk/src/databricks_google_adk/__init__.py @@ -0,0 +1,44 @@ +""" +Databricks AI support for Google Agent Development Kit (ADK). + +This package provides tools and utilities for integrating Databricks AI features +with Google ADK agents. 
from typing import Optional

from databricks.sdk import WorkspaceClient
from databricks_ai_bridge.genie import Genie
from google.adk.tools import FunctionTool


def _format_genie_response(response) -> dict:
    """Convert a GenieResponse into the plain dict returned by the ADK tools.

    Shared by :func:`create_genie_tool` and :class:`GenieTool` so the two
    entry points always produce the same payload shape and cannot drift
    apart.

    Args:
        response: A ``GenieResponse`` from ``Genie.ask_question``.

    Returns:
        A dict with ``result``, ``query``, ``description`` and
        ``conversation_id`` keys; missing values become empty strings.
    """
    result = response.result
    if hasattr(result, "to_markdown"):
        # return_pandas=True yields a pandas DataFrame; render it as markdown
        # so the LLM always receives a string.
        result = result.to_markdown(index=False)
    return {
        "result": result,
        "query": response.query or "",
        "description": response.description or "",
        "conversation_id": response.conversation_id or "",
    }


def create_genie_tool(
    space_id: str,
    tool_name: str = "ask_genie",
    tool_description: Optional[str] = None,
    client: Optional[WorkspaceClient] = None,
    return_pandas: bool = False,
) -> FunctionTool:
    """
    Create a Google ADK tool that queries a Databricks Genie space.

    Genie is Databricks' AI/BI assistant that can answer natural language
    questions about your data by generating and executing SQL queries.

    Args:
        space_id: The ID of the Genie space to query.
        tool_name: Name for the tool (default: "ask_genie").
        tool_description: Custom description for the tool. If not provided,
            uses the Genie space's description.
        client: Optional WorkspaceClient instance for authentication.
        return_pandas: Whether to return results as pandas DataFrames
            (if False, returns markdown strings).

    Returns:
        A FunctionTool that can be used with Google ADK agents.

    Example:
        ```python
        from databricks_google_adk import create_genie_tool
        from google.adk.agents import Agent

        genie_tool = create_genie_tool(
            space_id="your-genie-space-id",
            tool_description="Ask questions about sales data",
        )
        agent = Agent(
            name="data_analyst",
            model="gemini-2.0-flash",
            instruction="You are a data analyst. Use the genie tool to answer data questions.",
            tools=[genie_tool],
        )
        ```
    """
    import mlflow

    genie = Genie(
        space_id=space_id,
        client=client,
        return_pandas=return_pandas,
    )

    # Prefer the caller's description, then the space's own, then a generic one.
    description = tool_description or genie.description or (
        "Ask questions about data in natural language. "
        "This tool queries a Databricks Genie space that can generate and execute SQL queries."
    )

    # Closure-held conversation state enables multi-turn follow-up questions.
    conversation_state = {"conversation_id": None}

    def ask_genie(question: str, new_conversation: bool = False) -> dict:
        """
        Ask a question to the Databricks Genie AI/BI assistant.

        Args:
            question: The natural language question to ask about the data.
            new_conversation: If True, starts a new conversation. If False,
                continues the previous conversation if one exists.

        Returns:
            A dictionary containing:
            - result: The query result (markdown table or text)
            - query: The generated SQL query (if applicable)
            - description: Explanation of the query logic
            - conversation_id: ID to continue this conversation
        """
        with mlflow.start_span(name="ask_genie", span_type="TOOL"):
            conv_id = None if new_conversation else conversation_state.get("conversation_id")
            response = genie.ask_question(question, conversation_id=conv_id)
            if response.conversation_id:
                conversation_state["conversation_id"] = response.conversation_id
            return _format_genie_response(response)

    # FunctionTool derives the tool's name and description from the function
    # metadata, so rewrite them to the caller-facing values.
    ask_genie.__name__ = tool_name
    ask_genie.__doc__ = description

    return FunctionTool(ask_genie)


class GenieTool:
    """
    A wrapper class for Databricks Genie that provides a Google ADK compatible tool.

    This class maintains conversation state and provides both synchronous and
    tool-based access to Genie.

    Example:
        ```python
        from databricks_google_adk import GenieTool
        from google.adk.agents import Agent

        genie = GenieTool(space_id="your-genie-space-id")

        agent = Agent(
            name="data_analyst",
            model="gemini-2.0-flash",
            instruction="You are a data analyst.",
            tools=[genie.as_tool()],
        )

        # Or call directly
        result = genie.ask("What were total sales last month?")
        ```
    """

    def __init__(
        self,
        space_id: str,
        tool_name: str = "ask_genie",
        tool_description: Optional[str] = None,
        client: Optional[WorkspaceClient] = None,
        return_pandas: bool = False,
    ):
        """
        Initialize the GenieTool.

        Args:
            space_id: The ID of the Genie space to query.
            tool_name: Name for the tool (default: "ask_genie").
            tool_description: Custom description for the tool.
            client: Optional WorkspaceClient instance.
            return_pandas: Whether to return results as pandas DataFrames.
        """
        self.space_id = space_id
        self.tool_name = tool_name
        self._client = client
        self._return_pandas = return_pandas
        self._genie = Genie(
            space_id=space_id,
            client=client,
            return_pandas=return_pandas,
        )
        self._tool_description = tool_description or self._genie.description
        self._conversation_id: Optional[str] = None
        self._adk_tool: Optional[FunctionTool] = None

    @property
    def description(self) -> str:
        """The Genie space description (empty string if none)."""
        return self._genie.description or ""

    @property
    def conversation_id(self) -> Optional[str]:
        """ID of the current conversation, or None before the first ask."""
        return self._conversation_id

    def reset_conversation(self) -> None:
        """Forget the current conversation so the next ask starts fresh."""
        self._conversation_id = None

    def ask(self, question: str, new_conversation: bool = False) -> dict:
        """
        Ask a question to Genie.

        Args:
            question: The natural language question to ask.
            new_conversation: If True, starts a new conversation.

        Returns:
            A dictionary with result, query, description, and conversation_id.
        """
        import mlflow

        with mlflow.start_span(name="GenieTool.ask", span_type="TOOL"):
            conv_id = None if new_conversation else self._conversation_id
            response = self._genie.ask_question(question, conversation_id=conv_id)
            if response.conversation_id:
                self._conversation_id = response.conversation_id
            return _format_genie_response(response)

    def as_tool(self) -> FunctionTool:
        """
        Convert this GenieTool to a Google ADK FunctionTool (cached).

        Returns:
            A FunctionTool that can be used with Google ADK agents.
        """
        if self._adk_tool is not None:
            return self._adk_tool

        # Closure over self so the tool shares this instance's conversation state.
        def ask_genie(question: str, new_conversation: bool = False) -> dict:
            """Ask a question to the Databricks Genie AI/BI assistant."""
            return self.ask(question, new_conversation)

        ask_genie.__name__ = self.tool_name
        ask_genie.__doc__ = self._tool_description or (
            "Ask questions about data in natural language. "
            "This tool queries a Databricks Genie space."
        )

        self._adk_tool = FunctionTool(ask_genie)
        return self._adk_tool
from typing import Callable, Optional, Union

from databricks.sdk import WorkspaceClient
from google.adk.tools import BaseTool, FunctionTool
from google.adk.tools.base_toolset import BaseToolset, ToolPredicate

from databricks_google_adk.genie import GenieTool
from databricks_google_adk.vector_search_retriever_tool import VectorSearchRetrieverTool


class DatabricksToolset(BaseToolset):
    """
    A Google ADK toolset that bundles Databricks AI tools together.

    This toolset provides convenient access to Databricks Vector Search and Genie
    tools for use with Google ADK agents.

    Example:
        ```python
        from databricks_google_adk import DatabricksToolset
        from google.adk.agents import Agent

        toolset = DatabricksToolset(
            vector_search_indexes=["catalog.schema.my_index"],
            genie_space_ids=["genie-space-123"],
        )

        agent = Agent(
            name="data_assistant",
            model="gemini-2.0-flash",
            instruction="You help users find and analyze data.",
            tools=[toolset],
        )
        ```
    """

    def __init__(
        self,
        *,
        vector_search_indexes: Optional[list[str]] = None,
        genie_space_ids: Optional[list[str]] = None,
        workspace_client: Optional[WorkspaceClient] = None,
        embedding_fn: Optional[Callable[[str], list[float]]] = None,
        tool_filter: Optional[Union[ToolPredicate, list[str]]] = None,
        tool_name_prefix: Optional[str] = None,
    ):
        """
        Initialize the DatabricksToolset.

        Args:
            vector_search_indexes: List of Vector Search index names to include.
                Each index name should be in the format "catalog.schema.index".
            genie_space_ids: List of Genie space IDs to include.
            workspace_client: Optional WorkspaceClient for authentication.
            embedding_fn: Optional embedding function for self-managed embeddings
                in Vector Search indexes.
            tool_filter: Optional filter to select specific tools by name or predicate.
            tool_name_prefix: Optional prefix to add to all tool names.
        """
        super().__init__(tool_filter=tool_filter, tool_name_prefix=tool_name_prefix)

        self._workspace_client = workspace_client
        self._embedding_fn = embedding_fn
        self._tools: list[BaseTool] = []

        # Eagerly build one retriever tool per configured index.
        for index_name in vector_search_indexes or []:
            vs_tool = VectorSearchRetrieverTool(
                index_name=index_name,
                workspace_client=workspace_client,
                embedding_fn=embedding_fn,
            )
            self._tools.append(vs_tool.as_tool())

        # One Genie tool per space; the space id is folded into the tool name
        # so multiple spaces produce distinctly named tools.
        for space_id in genie_space_ids or []:
            genie_tool = GenieTool(
                space_id=space_id,
                tool_name=f"genie_{space_id.replace('-', '_')}",
                client=workspace_client,
            )
            self._tools.append(genie_tool.as_tool())

    async def get_tools(self, readonly_context=None) -> list[BaseTool]:
        """
        Return the tools in this toolset, honoring ``tool_filter``.

        Args:
            readonly_context: Optional ADK context, forwarded to the base
                class's tool-selection logic.

        Returns:
            List of BaseTool instances.
        """
        if self.tool_filter is not None:
            # BUGFIX: BaseToolset._is_tool_selected expects the tool object and
            # the readonly context, not the tool's name. Passing a plain string
            # broke predicate filters (which receive a BaseTool) and list-based
            # filters (which read tool.name).
            return [
                tool
                for tool in self._tools
                if self._is_tool_selected(tool, readonly_context)
            ]
        return self._tools

    def add_vector_search_tool(
        self,
        index_name: str,
        tool_name: Optional[str] = None,
        tool_description: Optional[str] = None,
        num_results: int = 5,
        **kwargs,
    ) -> "DatabricksToolset":
        """
        Add a Vector Search tool to the toolset.

        Args:
            index_name: The name of the Vector Search index.
            tool_name: Optional custom name for the tool.
            tool_description: Optional custom description.
            num_results: Number of results to return (default: 5).
            **kwargs: Additional arguments passed to VectorSearchRetrieverTool.

        Returns:
            Self for method chaining.
        """
        vs_tool = VectorSearchRetrieverTool(
            index_name=index_name,
            tool_name=tool_name,
            tool_description=tool_description,
            num_results=num_results,
            workspace_client=self._workspace_client,
            embedding_fn=self._embedding_fn,
            **kwargs,
        )
        self._tools.append(vs_tool.as_tool())
        return self

    def add_genie_tool(
        self,
        space_id: str,
        tool_name: Optional[str] = None,
        tool_description: Optional[str] = None,
        **kwargs,
    ) -> "DatabricksToolset":
        """
        Add a Genie tool to the toolset.

        Args:
            space_id: The ID of the Genie space.
            tool_name: Optional custom name for the tool.
            tool_description: Optional custom description.
            **kwargs: Additional arguments passed to GenieTool.

        Returns:
            Self for method chaining.
        """
        genie = GenieTool(
            space_id=space_id,
            tool_name=tool_name or f"genie_{space_id.replace('-', '_')}",
            tool_description=tool_description,
            client=self._workspace_client,
            **kwargs,
        )
        self._tools.append(genie.as_tool())
        return self

    def add_custom_tool(self, tool: Union[FunctionTool, BaseTool]) -> "DatabricksToolset":
        """
        Add a custom tool to the toolset.

        Args:
            tool: A FunctionTool or BaseTool instance.

        Returns:
            Self for method chaining.
        """
        self._tools.append(tool)
        return self

    async def close(self) -> None:
        """Clean up resources (no persistent resources are held)."""
        pass
+ """ + self._tools.append(tool) + return self + + async def close(self) -> None: + """Clean up resources.""" + # No persistent resources to clean up + pass diff --git a/integrations/google-adk/src/databricks_google_adk/vector_search_retriever_tool.py b/integrations/google-adk/src/databricks_google_adk/vector_search_retriever_tool.py new file mode 100644 index 000000000..57086aba7 --- /dev/null +++ b/integrations/google-adk/src/databricks_google_adk/vector_search_retriever_tool.py @@ -0,0 +1,217 @@ +import inspect +from typing import Any, Callable + +from databricks_ai_bridge.utils.vector_search import ( + IndexDetails, + RetrieverSchema, + parse_vector_search_response, + validate_and_get_return_columns, + validate_and_get_text_column, +) +from databricks_ai_bridge.vector_search_retriever_tool import ( + FilterItem, + VectorSearchRetrieverToolInput, + VectorSearchRetrieverToolMixin, + vector_search_retriever_tool_trace, +) +from google.adk.tools import FunctionTool +from pydantic import Field, PrivateAttr + + +class VectorSearchRetrieverTool(VectorSearchRetrieverToolMixin): + """ + Databricks Vector Search retriever tool for Google ADK. + + This tool allows Google ADK agents to search and retrieve documents from + Databricks Vector Search indexes. + + Example: + ```python + from databricks_google_adk import VectorSearchRetrieverTool + from google.adk.agents import Agent + + # Create the tool + vector_search_tool = VectorSearchRetrieverTool( + index_name="catalog.schema.my_index", + num_results=5, + ) + + # Use with an ADK agent + agent = Agent( + name="search_assistant", + model="gemini-2.0-flash", + instruction="You are a helpful assistant that searches documents.", + tools=[vector_search_tool.as_tool()], + ) + ``` + """ + + text_column: str | None = Field( + None, + description="The name of the text column to use for the embeddings. 
" + "Required for direct-access index or delta-sync index with " + "self-managed embeddings.", + ) + embedding_fn: Callable[[str], list[float]] | None = Field( + None, + description="Embedding function for self-managed embeddings. " + "Should accept a string and return a list of floats.", + ) + + _index = PrivateAttr() + _index_details = PrivateAttr() + _retriever_schema = PrivateAttr() + _adk_tool = PrivateAttr(default=None) + + def model_post_init(self, __context: Any) -> None: + """Initialize the vector search client and index after model creation.""" + from databricks.vector_search.client import VectorSearchClient + from databricks.vector_search.utils import CredentialStrategy + + credential_strategy = None + if ( + self.workspace_client is not None + and self.workspace_client.config.auth_type == "model_serving_user_credentials" + ): + credential_strategy = CredentialStrategy.MODEL_SERVING_USER_CREDENTIALS + + self._index = VectorSearchClient( + disable_notice=True, credential_strategy=credential_strategy + ).get_index(index_name=self.index_name) + self._index_details = IndexDetails(self._index) + + # Validate columns + self.text_column = validate_and_get_text_column(self.text_column, self._index_details) + self.columns = validate_and_get_return_columns( + self.columns or [], + self.text_column, + self._index_details, + self.doc_uri, + self.primary_key, + ) + self._retriever_schema = RetrieverSchema( + text_column=self.text_column, + doc_uri=self.doc_uri, + primary_key=self.primary_key, + other_columns=self.columns, + ) + + def _get_query_text_vector(self, query: str) -> tuple[str | None, list[float] | None]: + """Get the query text and vector based on the index configuration.""" + if self._index_details.is_databricks_managed_embeddings(): + if self.embedding_fn: + raise ValueError( + f"The index '{self._index_details.name}' uses Databricks-managed embeddings. " + "Do not pass the `embedding_fn` parameter when executing retriever calls." 
+ ) + return query, None + + if not self.embedding_fn: + raise ValueError( + "The embedding_fn is required for non-Databricks-managed " + "embeddings Vector Search indexes in order to generate embeddings for retrieval queries." + ) + + text = query if self.query_type and self.query_type.upper() == "HYBRID" else None + vector = self.embedding_fn(query) + if ( + index_embedding_dimension := self._index_details.embedding_vector_column.get( + "embedding_dimension" + ) + ) and len(vector) != index_embedding_dimension: + raise ValueError( + f"Expected embedding dimension {index_embedding_dimension} but got {len(vector)}" + ) + return text, vector + + @vector_search_retriever_tool_trace + def _search( + self, query: str, filters: list[FilterItem] | None = None, **kwargs: Any + ) -> list[dict[str, Any]]: + """ + Execute a similarity search against the vector index. + + Args: + query: The search query string. + filters: Optional list of filters to apply to the search. + **kwargs: Additional keyword arguments passed to the search. + + Returns: + A list of dictionaries containing the search results. 
+ """ + query_text, query_vector = self._get_query_text_vector(query) + + # Since LLM can generate either a dict or FilterItem, convert to dict always + filters_dict = {dict(item)["key"]: dict(item)["value"] for item in (filters or [])} + combined_filters = {**filters_dict, **(self.filters or {})} + + signature = inspect.signature(self._index.similarity_search) + kwargs = {**kwargs, **(self.model_extra or {})} + kwargs = {k: v for k, v in kwargs.items() if k in signature.parameters} + + # Allow kwargs to override the default values upon invocation + num_results = kwargs.pop("num_results", self.num_results) + query_type = kwargs.pop("query_type", self.query_type) + reranker = kwargs.pop("reranker", self.reranker) + + # Ensure that we don't have duplicate keys + kwargs.update( + { + "query_text": query_text, + "query_vector": query_vector, + "columns": self.columns, + "filters": combined_filters, + "num_results": num_results, + "query_type": query_type, + "reranker": reranker, + } + ) + search_resp = self._index.similarity_search(**kwargs) + return parse_vector_search_response( + search_resp, + retriever_schema=self._retriever_schema, + include_score=self.include_score or False, + ) + + def as_tool(self) -> FunctionTool: + """ + Convert this retriever to a Google ADK FunctionTool. + + Returns: + A FunctionTool that can be used with Google ADK agents. 
+ """ + if self._adk_tool is not None: + return self._adk_tool + + tool_name = self._get_tool_name() + tool_description = self.tool_description or self._get_default_tool_description( + self._index_details + ) + + if self.dynamic_filter: + # Create a function with filter parameter for LLM-generated filters + + def search_with_filters( + query: str, filters: list[dict[str, Any]] | None = None + ) -> list[dict[str, Any]]: + """Search the vector index with optional filters.""" + filter_items = None + if filters: + filter_items = [FilterItem(**f) for f in filters] + return self._search(query, filter_items) + + search_with_filters.__name__ = tool_name + search_with_filters.__doc__ = tool_description + self._adk_tool = FunctionTool(search_with_filters) + else: + # Create a simple function without filter parameter + + def search(query: str) -> list[dict[str, Any]]: + """Search the vector index.""" + return self._search(query) + + search.__name__ = tool_name + search.__doc__ = tool_description + self._adk_tool = FunctionTool(search) + + return self._adk_tool diff --git a/integrations/google-adk/tests/__init__.py b/integrations/google-adk/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/integrations/google-adk/tests/unit_tests/__init__.py b/integrations/google-adk/tests/unit_tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/integrations/google-adk/tests/unit_tests/test_genie.py b/integrations/google-adk/tests/unit_tests/test_genie.py new file mode 100644 index 000000000..79beddd11 --- /dev/null +++ b/integrations/google-adk/tests/unit_tests/test_genie.py @@ -0,0 +1,151 @@ +from unittest.mock import MagicMock, patch + +import pytest +from databricks_ai_bridge.genie import GenieResponse +from google.adk.tools import FunctionTool + +from databricks_google_adk import GenieTool, create_genie_tool + + +@pytest.fixture +def mock_genie(): + """Mock the Genie class.""" + with patch("databricks_google_adk.genie.Genie") as 
mock:
+        mock_instance = MagicMock()
+        mock_instance.description = "Test Genie Space"
+        mock_instance.ask_question.return_value = GenieResponse(
+            result="| Column1 | Column2 |\n|---------|---------|\n| Value1 | Value2 |",
+            query="SELECT * FROM table",
+            description="Query description",
+            conversation_id="conv-123",
+        )
+        mock.return_value = mock_instance
+        yield mock_instance
+
+
+class TestCreateGenieTool:
+    """Tests for create_genie_tool function."""
+
+    def test_create_genie_tool_returns_function_tool(self, mock_genie):
+        """Test that create_genie_tool returns a FunctionTool."""
+        tool = create_genie_tool(space_id="test-space")
+        assert isinstance(tool, FunctionTool)
+
+    def test_create_genie_tool_custom_name(self, mock_genie):
+        """Test that custom tool name is respected."""
+        tool = create_genie_tool(space_id="test-space", tool_name="my_genie")
+        assert tool.func.__name__ == "my_genie"
+
+    def test_create_genie_tool_custom_description(self, mock_genie):
+        """Test that custom description is respected."""
+        custom_desc = "Custom description"
+        tool = create_genie_tool(space_id="test-space", tool_description=custom_desc)
+        assert tool.func.__doc__ == custom_desc
+
+    def test_create_genie_tool_uses_space_description(self, mock_genie):
+        """Test that space description is used when no custom description provided."""
+        tool = create_genie_tool(space_id="test-space")
+        # Should use the mock's description
+        assert tool.func.__doc__ == "Test Genie Space"
+
+
+class TestGenieTool:
+    """Tests for GenieTool class."""
+
+    def test_genie_tool_init(self, mock_genie):
+        """Test GenieTool initialization."""
+        genie = GenieTool(space_id="test-space")
+        assert genie.space_id == "test-space"
+        assert genie.conversation_id is None
+
+    def test_genie_tool_description_property(self, mock_genie):
+        """Test description property."""
+        genie = GenieTool(space_id="test-space")
+        assert genie.description == "Test Genie Space"
+
+    def test_genie_tool_as_tool(self, mock_genie):
+        """Test
as_tool method returns FunctionTool.""" + genie = GenieTool(space_id="test-space") + tool = genie.as_tool() + assert isinstance(tool, FunctionTool) + + def test_genie_tool_as_tool_caching(self, mock_genie): + """Test that as_tool returns the same instance on repeated calls.""" + genie = GenieTool(space_id="test-space") + tool1 = genie.as_tool() + tool2 = genie.as_tool() + assert tool1 is tool2 + + def test_genie_tool_ask(self, mock_genie): + """Test ask method returns expected format.""" + genie = GenieTool(space_id="test-space") + result = genie.ask("What is the total?") + + assert "result" in result + assert "query" in result + assert "description" in result + assert "conversation_id" in result + assert result["conversation_id"] == "conv-123" + + def test_genie_tool_ask_updates_conversation_id(self, mock_genie): + """Test that ask updates conversation_id.""" + genie = GenieTool(space_id="test-space") + assert genie.conversation_id is None + + genie.ask("First question") + assert genie.conversation_id == "conv-123" + + def test_genie_tool_reset_conversation(self, mock_genie): + """Test reset_conversation clears conversation_id.""" + genie = GenieTool(space_id="test-space") + genie.ask("First question") + assert genie.conversation_id is not None + + genie.reset_conversation() + assert genie.conversation_id is None + + def test_genie_tool_new_conversation_flag(self, mock_genie): + """Test that new_conversation=True starts fresh conversation.""" + genie = GenieTool(space_id="test-space") + genie._conversation_id = "old-conv" + + # With new_conversation=True, should pass None to ask_question + genie.ask("New question", new_conversation=True) + + mock_genie.ask_question.assert_called_with("New question", conversation_id=None) + + def test_genie_tool_continues_conversation(self, mock_genie): + """Test that conversation continues by default.""" + genie = GenieTool(space_id="test-space") + genie._conversation_id = "existing-conv" + + genie.ask("Follow-up question", 
new_conversation=False) + + mock_genie.ask_question.assert_called_with( + "Follow-up question", conversation_id="existing-conv" + ) + + def test_genie_tool_custom_tool_name(self, mock_genie): + """Test custom tool name.""" + genie = GenieTool(space_id="test-space", tool_name="custom_genie") + tool = genie.as_tool() + assert tool.func.__name__ == "custom_genie" + + def test_genie_tool_handles_dataframe_result(self, mock_genie): + """Test that DataFrame results are converted to markdown.""" + import pandas as pd + + mock_genie.ask_question.return_value = GenieResponse( + result=pd.DataFrame({"col1": [1, 2], "col2": ["a", "b"]}), + query="SELECT * FROM table", + description="Query description", + conversation_id="conv-123", + ) + + genie = GenieTool(space_id="test-space") + result = genie.ask("Query with DataFrame") + + # Result should be markdown string, not DataFrame + assert isinstance(result["result"], str) + assert "col1" in result["result"] + assert "col2" in result["result"] diff --git a/integrations/google-adk/tests/unit_tests/test_toolset.py b/integrations/google-adk/tests/unit_tests/test_toolset.py new file mode 100644 index 000000000..1b9dcab5d --- /dev/null +++ b/integrations/google-adk/tests/unit_tests/test_toolset.py @@ -0,0 +1,149 @@ +from unittest.mock import MagicMock, patch + +import pytest +from databricks_ai_bridge.genie import GenieResponse +from databricks_ai_bridge.test_utils.vector_search import ( + DELTA_SYNC_INDEX, + mock_vs_client, # noqa: F401 + mock_workspace_client, # noqa: F401 +) +from google.adk.tools import BaseTool, FunctionTool + +from databricks_google_adk import DatabricksToolset + + +@pytest.fixture +def mock_genie(): + """Mock the Genie class for toolset tests.""" + with patch("databricks_google_adk.genie.Genie") as mock: + mock_instance = MagicMock() + mock_instance.description = "Test Genie Space" + mock_instance.ask_question.return_value = GenieResponse( + result="Test result", + query="SELECT * FROM table", + 
description="Query description", + conversation_id="conv-123", + ) + mock.return_value = mock_instance + yield mock_instance + + +class TestDatabricksToolset: + """Tests for DatabricksToolset class.""" + + def test_empty_toolset(self, mock_genie): + """Test creating an empty toolset.""" + toolset = DatabricksToolset() + assert toolset._tools == [] + + def test_toolset_with_vector_search(self, mock_genie): + """Test creating a toolset with Vector Search indexes.""" + toolset = DatabricksToolset( + vector_search_indexes=[DELTA_SYNC_INDEX], + ) + assert len(toolset._tools) == 1 + assert isinstance(toolset._tools[0], FunctionTool) + + def test_toolset_with_genie(self, mock_genie): + """Test creating a toolset with Genie spaces.""" + toolset = DatabricksToolset( + genie_space_ids=["test-space-1"], + ) + assert len(toolset._tools) == 1 + assert isinstance(toolset._tools[0], FunctionTool) + + def test_toolset_with_multiple_tools(self, mock_genie): + """Test creating a toolset with multiple tools.""" + toolset = DatabricksToolset( + vector_search_indexes=[DELTA_SYNC_INDEX], + genie_space_ids=["test-space-1", "test-space-2"], + ) + assert len(toolset._tools) == 3 # 1 VS + 2 Genie + + @pytest.mark.asyncio + async def test_get_tools(self, mock_genie): + """Test get_tools returns all tools.""" + toolset = DatabricksToolset( + vector_search_indexes=[DELTA_SYNC_INDEX], + genie_space_ids=["test-space-1"], + ) + tools = await toolset.get_tools() + assert len(tools) == 2 + + def test_add_vector_search_tool(self, mock_genie): + """Test add_vector_search_tool method.""" + toolset = DatabricksToolset() + result = toolset.add_vector_search_tool( + index_name=DELTA_SYNC_INDEX, + tool_name="custom_search", + ) + + # Should return self for chaining + assert result is toolset + assert len(toolset._tools) == 1 + + def test_add_genie_tool(self, mock_genie): + """Test add_genie_tool method.""" + toolset = DatabricksToolset() + result = toolset.add_genie_tool( + space_id="test-space", + 
tool_name="custom_genie", + ) + + # Should return self for chaining + assert result is toolset + assert len(toolset._tools) == 1 + + def test_add_custom_tool(self, mock_genie): + """Test add_custom_tool method.""" + + def custom_func(x: str) -> str: + return x + + custom_tool = FunctionTool(custom_func) + + toolset = DatabricksToolset() + result = toolset.add_custom_tool(custom_tool) + + assert result is toolset + assert len(toolset._tools) == 1 + assert toolset._tools[0] is custom_tool + + def test_method_chaining(self, mock_genie): + """Test that builder methods can be chained.""" + toolset = ( + DatabricksToolset() + .add_vector_search_tool(index_name=DELTA_SYNC_INDEX) + .add_genie_tool(space_id="test-space") + ) + + assert len(toolset._tools) == 2 + + @pytest.mark.asyncio + async def test_get_tools_with_filter(self, mock_genie): + """Test get_tools with tool_filter.""" + toolset = DatabricksToolset( + vector_search_indexes=[DELTA_SYNC_INDEX], + genie_space_ids=["test-space"], + tool_filter=["genie_test_space"], # Only include genie tool + ) + + tools = await toolset.get_tools() + # Only the genie tool should be returned (filtered by name) + assert len(tools) == 1 + + @pytest.mark.asyncio + async def test_close(self, mock_genie): + """Test close method doesn't raise.""" + toolset = DatabricksToolset() + await toolset.close() # Should not raise + + def test_genie_tool_name_formatting(self, mock_genie): + """Test that Genie tool names handle dashes correctly.""" + toolset = DatabricksToolset( + genie_space_ids=["space-with-dashes"], + ) + + # The tool name should have dashes replaced with underscores + tool = toolset._tools[0] + assert tool.func.__name__ == "genie_space_with_dashes" diff --git a/integrations/google-adk/tests/unit_tests/test_vector_search_retriever_tool.py b/integrations/google-adk/tests/unit_tests/test_vector_search_retriever_tool.py new file mode 100644 index 000000000..82cd82595 --- /dev/null +++ 
b/integrations/google-adk/tests/unit_tests/test_vector_search_retriever_tool.py @@ -0,0 +1,204 @@ +from typing import Any, Dict, List, Optional +from unittest.mock import MagicMock, create_autospec, patch + +import pytest +from databricks.vector_search.client import VectorSearchIndex +from databricks_ai_bridge.test_utils.vector_search import ( + ALL_INDEX_NAMES, + DEFAULT_VECTOR_DIMENSION, + DELTA_SYNC_INDEX, + EXAMPLE_SEARCH_RESPONSE, + mock_vs_client, # noqa: F401 + mock_workspace_client, # noqa: F401 +) +from databricks_ai_bridge.vector_search_retriever_tool import FilterItem +from google.adk.tools import FunctionTool + +from databricks_google_adk import VectorSearchRetrieverTool + + +def fake_embedding_fn(text: str) -> list[float]: + """Fake embedding function for testing.""" + return [1.0] * (DEFAULT_VECTOR_DIMENSION - 1) + [0.0] + + +def init_vector_search_tool( + index_name: str, + columns: Optional[List[str]] = None, + tool_name: Optional[str] = None, + tool_description: Optional[str] = None, + embedding_fn=None, + text_column: Optional[str] = None, + filters: Optional[Dict[str, Any]] = None, + **kwargs: Any, +) -> VectorSearchRetrieverTool: + kwargs.update( + { + "index_name": index_name, + "columns": columns, + "tool_name": tool_name, + "tool_description": tool_description, + "embedding_fn": embedding_fn, + "text_column": text_column, + "filters": filters, + } + ) + if index_name != DELTA_SYNC_INDEX: + kwargs.update( + { + "embedding_fn": fake_embedding_fn, + "text_column": "text", + } + ) + return VectorSearchRetrieverTool(**kwargs) + + +@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) +def test_init(index_name: str) -> None: + """Test that VectorSearchRetrieverTool initializes correctly.""" + vector_search_tool = init_vector_search_tool(index_name) + assert isinstance(vector_search_tool, VectorSearchRetrieverTool) + + +@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) +def test_as_tool_returns_function_tool(index_name: str) -> None: + 
"""Test that as_tool() returns a FunctionTool.""" + vector_search_tool = init_vector_search_tool(index_name) + adk_tool = vector_search_tool.as_tool() + assert isinstance(adk_tool, FunctionTool) + + +@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) +def test_tool_name_generation(index_name: str) -> None: + """Test that tool names are generated correctly.""" + vector_search_tool = init_vector_search_tool(index_name) + expected_name = index_name.replace(".", "__") + assert vector_search_tool._get_tool_name() == expected_name + + +@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) +def test_custom_tool_name(index_name: str) -> None: + """Test that custom tool names are respected.""" + custom_name = "my_custom_tool" + vector_search_tool = init_vector_search_tool(index_name, tool_name=custom_name) + assert vector_search_tool._get_tool_name() == custom_name + + +@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) +def test_tool_description_generation(index_name: str) -> None: + """Test that tool descriptions are generated correctly.""" + vector_search_tool = init_vector_search_tool(index_name) + adk_tool = vector_search_tool.as_tool() + # The function's docstring becomes the tool description + assert adk_tool.func.__doc__ is not None + assert "vector search" in adk_tool.func.__doc__.lower() or "search" in adk_tool.func.__doc__.lower() + + +@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) +def test_custom_tool_description(index_name: str) -> None: + """Test that custom tool descriptions are respected.""" + custom_description = "My custom tool description" + vector_search_tool = init_vector_search_tool(index_name, tool_description=custom_description) + adk_tool = vector_search_tool.as_tool() + assert adk_tool.func.__doc__ == custom_description + + +@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) +def test_search_execution(index_name: str) -> None: + """Test that search can be executed.""" + vector_search_tool = 
init_vector_search_tool(index_name) + results = vector_search_tool._search("test query") + assert isinstance(results, list) + + +@pytest.mark.parametrize("index_name", ALL_INDEX_NAMES) +def test_as_tool_caching(index_name: str) -> None: + """Test that as_tool() returns the same instance on repeated calls.""" + vector_search_tool = init_vector_search_tool(index_name) + tool1 = vector_search_tool.as_tool() + tool2 = vector_search_tool.as_tool() + assert tool1 is tool2 + + +def test_filters_are_passed_through() -> None: + """Test that filters are correctly passed to the search.""" + vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX) + vector_search_tool._index = create_autospec(VectorSearchIndex, instance=True) + vector_search_tool._index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE + + vector_search_tool._search( + query="what cities are in Germany", + filters=[FilterItem(key="country", value="Germany")], + ) + vector_search_tool._index.similarity_search.assert_called_once() + call_kwargs = vector_search_tool._index.similarity_search.call_args.kwargs + assert call_kwargs["filters"] == {"country": "Germany"} + + +def test_filters_are_combined() -> None: + """Test that runtime filters are combined with predefined filters.""" + vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX, filters={"city LIKE": "Berlin"}) + vector_search_tool._index = create_autospec(VectorSearchIndex, instance=True) + vector_search_tool._index.similarity_search.return_value = EXAMPLE_SEARCH_RESPONSE + + vector_search_tool._search( + query="what cities are in Germany", + filters=[FilterItem(key="country", value="Germany")], + ) + call_kwargs = vector_search_tool._index.similarity_search.call_args.kwargs + assert call_kwargs["filters"] == {"city LIKE": "Berlin", "country": "Germany"} + + +def test_dynamic_filter_creates_function_with_filters_param() -> None: + """Test that dynamic_filter=True creates a tool with filters parameter.""" + vector_search_tool = 
init_vector_search_tool(DELTA_SYNC_INDEX, dynamic_filter=True) + adk_tool = vector_search_tool.as_tool() + + # Check that the function accepts filters parameter + import inspect + sig = inspect.signature(adk_tool.func) + assert "filters" in sig.parameters + + +def test_static_filter_creates_function_without_filters_param() -> None: + """Test that dynamic_filter=False creates a tool without filters parameter.""" + vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX, dynamic_filter=False) + adk_tool = vector_search_tool.as_tool() + + # Check that the function does not accept filters parameter + import inspect + sig = inspect.signature(adk_tool.func) + assert "filters" not in sig.parameters + + +def test_num_results_configuration() -> None: + """Test that num_results is configurable.""" + vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX, num_results=10) + assert vector_search_tool.num_results == 10 + + +def test_query_type_configuration() -> None: + """Test that query_type is configurable.""" + vector_search_tool = init_vector_search_tool(DELTA_SYNC_INDEX, query_type="HYBRID") + assert vector_search_tool.query_type == "HYBRID" + + +def test_embedding_fn_required_for_self_managed() -> None: + """Test that embedding_fn is required for self-managed embeddings indexes.""" + # For non-delta-sync indexes, embedding_fn is required + with pytest.raises(ValueError, match="embedding_fn is required"): + tool = VectorSearchRetrieverTool( + index_name="test.direct_access.index", + text_column="text", + # No embedding_fn provided + ) + tool._search("test query") + + +def test_embedding_fn_not_allowed_for_databricks_managed() -> None: + """Test that embedding_fn is not allowed for Databricks-managed embeddings.""" + tool = init_vector_search_tool(DELTA_SYNC_INDEX) + tool.embedding_fn = fake_embedding_fn # Try to set embedding_fn + + with pytest.raises(ValueError, match="Databricks-managed embeddings"): + tool._search("test query") From 
ebae4cb6f79371cb5c7b97a8cf16b6481fdec76d Mon Sep 17 00:00:00 2001 From: Claude Date: Fri, 23 Jan 2026 01:34:49 +0000 Subject: [PATCH 2/2] google-adk: Add MCP toolset and Agent Engine deployment helpers This adds two major features to the Google ADK integration: 1. DatabricksMcpToolset - Connect to Databricks MCP servers: - UC Functions MCP via for_uc_functions() - Vector Search MCP via for_vector_search() - Genie MCP via for_genie() 2. Vertex AI Agent Engine deployment: - DatabricksAgentEngineApp for wrapping agents - deploy_to_agent_engine() one-step deployment - create_agent_engine_config() for custom configs - Automatic handling of Databricks credentials via Secret Manager --- integrations/google-adk/README.md | 145 ++++++++ integrations/google-adk/pyproject.toml | 6 + .../src/databricks_google_adk/__init__.py | 28 ++ .../src/databricks_google_adk/deployment.py | 340 ++++++++++++++++++ .../src/databricks_google_adk/mcp.py | 329 +++++++++++++++++ .../tests/unit_tests/test_deployment.py | 193 ++++++++++ .../google-adk/tests/unit_tests/test_mcp.py | 190 ++++++++++ 7 files changed, 1231 insertions(+) create mode 100644 integrations/google-adk/src/databricks_google_adk/deployment.py create mode 100644 integrations/google-adk/src/databricks_google_adk/mcp.py create mode 100644 integrations/google-adk/tests/unit_tests/test_deployment.py create mode 100644 integrations/google-adk/tests/unit_tests/test_mcp.py diff --git a/integrations/google-adk/README.md b/integrations/google-adk/README.md index c607afdbc..311dda269 100644 --- a/integrations/google-adk/README.md +++ b/integrations/google-adk/README.md @@ -10,9 +10,17 @@ pip install databricks-google-adk ## Features +**Tools:** - **VectorSearchRetrieverTool**: Search Databricks Vector Search indexes from ADK agents - **GenieTool**: Query Databricks Genie AI/BI spaces using natural language + +**Toolsets:** - **DatabricksToolset**: Bundle multiple Databricks tools together for easy agent configuration +- 
**DatabricksMcpToolset**: Connect to Databricks MCP servers (UC Functions, Vector Search, Genie) + +**Deployment:** +- **DatabricksAgentEngineApp**: Deploy Databricks-powered agents to Vertex AI Agent Engine +- **deploy_to_agent_engine**: One-step deployment helper function ## Quick Start @@ -107,6 +115,111 @@ agent = Agent( ) ``` +### Databricks MCP Toolset + +Connect to Databricks MCP servers for UC Functions, Vector Search, or Genie: + +```python +from databricks_google_adk import DatabricksMcpToolset +from google.adk.agents import Agent + +# Connect to UC Functions MCP server +toolset = DatabricksMcpToolset.for_uc_functions( + catalog="my_catalog", + schema="my_schema", +) + +# Or connect to Vector Search MCP +toolset = DatabricksMcpToolset.for_vector_search( + catalog="my_catalog", + schema="my_schema", +) + +# Or connect to Genie MCP +toolset = DatabricksMcpToolset.for_genie( + space_id="my-genie-space-id", +) + +# Use with an ADK agent +agent = Agent( + name="function_caller", + model="gemini-2.0-flash", + instruction="You help users by calling Databricks functions.", + tools=[toolset], +) +``` + +## Deployment to Vertex AI Agent Engine + +Deploy your Databricks-powered ADK agents to Google Cloud's Vertex AI Agent Engine. 
+ +### Installation + +```bash +pip install databricks-google-adk[deployment] +``` + +### Quick Deployment + +```python +from databricks_google_adk import VectorSearchRetrieverTool +from databricks_google_adk.deployment import deploy_to_agent_engine +from google.adk.agents import Agent + +# Create your agent +vector_search = VectorSearchRetrieverTool(index_name="catalog.schema.index") +agent = Agent( + name="search_agent", + model="gemini-2.0-flash", + instruction="You help users search documents.", + tools=[vector_search.as_tool()], +) + +# Deploy to Agent Engine +remote_agent = deploy_to_agent_engine( + agent=agent, + project="my-gcp-project", + location="us-central1", + staging_bucket="gs://my-staging-bucket", + databricks_host="https://my-workspace.databricks.com", + databricks_token_secret="projects/my-project/secrets/databricks-token/versions/latest", +) + +print(f"Deployed: {remote_agent.resource_name}") +``` + +### Step-by-Step Deployment + +For more control, use the `DatabricksAgentEngineApp` class: + +```python +from databricks_google_adk import DatabricksAgentEngineApp +import vertexai + +# Create the deployable app +app = DatabricksAgentEngineApp(agent=agent) + +# Test locally first +async for event in app.test_locally("Search for AI documents"): + print(event) + +# Initialize Vertex AI +client = vertexai.Client(project="my-project", location="us-central1") + +# Get deployment config +config = app.get_deployment_config( + staging_bucket="gs://my-bucket", + databricks_host="https://workspace.databricks.com", + databricks_token_secret="projects/my-project/secrets/db-token/versions/latest", +) + +# Deploy +remote_agent = client.agent_engines.create( + agent=app.adk_app, + config=config, +) +``` + ## Advanced Usage ### Self-Managed Embeddings @@ -225,12 +338,44 @@ genie = GenieTool( | `tool_filter` | `list[str]` | Filter tools by name | | `tool_name_prefix` | `str` | Prefix to add to all tool names | +### DatabricksMcpToolset + +| Parameter | Type | 
Description | +|-----------|------|-------------| +| `server_url` | `str` | URL of the Databricks MCP server | +| `workspace_client` | `WorkspaceClient` | Custom Databricks client | +| `tool_filter` | `list[str]` | Filter tools by name | +| `tool_name_prefix` | `str` | Prefix to add to all tool names | + +Factory methods: +- `DatabricksMcpToolset.for_uc_functions(catalog, schema)` - UC Functions MCP +- `DatabricksMcpToolset.for_vector_search(catalog, schema)` - Vector Search MCP +- `DatabricksMcpToolset.for_genie(space_id)` - Genie MCP + +### deploy_to_agent_engine + +| Parameter | Type | Description | +|-----------|------|-------------| +| `agent` | `Agent` | The ADK Agent to deploy | +| `project` | `str` | Google Cloud project ID | +| `location` | `str` | Google Cloud region | +| `staging_bucket` | `str` | GCS bucket for staging (format: `gs://bucket`) | +| `databricks_host` | `str` | Databricks workspace URL | +| `databricks_token_secret` | `str` | Secret Manager secret for Databricks token | +| `display_name` | `str` | Display name for the deployed agent | +| `description` | `str` | Description for the deployed agent | + ## Requirements +Core: - Python >= 3.10 - google-adk >= 1.0.0 - databricks-ai-bridge >= 0.4.0 - databricks-vectorsearch >= 0.40 +- databricks-mcp >= 0.5.0 + +For deployment to Agent Engine: +- google-cloud-aiplatform[agent_engines,adk] >= 1.112 ## License diff --git a/integrations/google-adk/pyproject.toml b/integrations/google-adk/pyproject.toml index 501404918..d0d8f26f2 100644 --- a/integrations/google-adk/pyproject.toml +++ b/integrations/google-adk/pyproject.toml @@ -11,10 +11,16 @@ requires-python = ">=3.10" dependencies = [ "databricks-vectorsearch>=0.40", "databricks-ai-bridge>=0.4.0", + "databricks-mcp>=0.5.0", "google-adk>=1.0.0", "pydantic>=2.10.0", ] +[project.optional-dependencies] +deployment = [ + "google-cloud-aiplatform[agent_engines,adk]>=1.112", +] + [dependency-groups] dev = [ "typing_extensions>=4.15.0", diff --git 
a/integrations/google-adk/src/databricks_google_adk/__init__.py b/integrations/google-adk/src/databricks_google_adk/__init__.py index 61a4b748d..0276248e5 100644 --- a/integrations/google-adk/src/databricks_google_adk/__init__.py +++ b/integrations/google-adk/src/databricks_google_adk/__init__.py @@ -6,10 +6,19 @@ Available classes and functions: +Tools: - :class:`VectorSearchRetrieverTool` - Search Databricks Vector Search indexes - :class:`GenieTool` - Query Databricks Genie AI/BI spaces - :func:`create_genie_tool` - Factory function to create Genie tools + +Toolsets: - :class:`DatabricksToolset` - Bundle multiple Databricks tools together +- :class:`DatabricksMcpToolset` - Connect to Databricks MCP servers + +Deployment: +- :class:`DatabricksAgentEngineApp` - Deploy agents to Vertex AI Agent Engine +- :func:`deploy_to_agent_engine` - One-step deployment helper +- :func:`create_agent_engine_config` - Create deployment configuration Example: ```python @@ -32,13 +41,32 @@ ``` """ +from databricks_google_adk.deployment import ( + DatabricksAgentEngineApp, + create_agent_engine_config, + deploy_to_agent_engine, + get_databricks_requirements, +) from databricks_google_adk.genie import GenieTool, create_genie_tool +from databricks_google_adk.mcp import ( + DatabricksMcpToolset, + create_databricks_mcp_toolset, +) from databricks_google_adk.toolset import DatabricksToolset from databricks_google_adk.vector_search_retriever_tool import VectorSearchRetrieverTool __all__ = [ + # Tools "VectorSearchRetrieverTool", "GenieTool", "create_genie_tool", + # Toolsets "DatabricksToolset", + "DatabricksMcpToolset", + "create_databricks_mcp_toolset", + # Deployment + "DatabricksAgentEngineApp", + "deploy_to_agent_engine", + "create_agent_engine_config", + "get_databricks_requirements", ] diff --git a/integrations/google-adk/src/databricks_google_adk/deployment.py b/integrations/google-adk/src/databricks_google_adk/deployment.py new file mode 100644 index 000000000..24749c40f --- 
/dev/null +++ b/integrations/google-adk/src/databricks_google_adk/deployment.py @@ -0,0 +1,340 @@ +""" +Vertex AI Agent Engine deployment helpers for Databricks-powered ADK agents. + +This module provides utilities to deploy ADK agents that use Databricks tools +to Google Cloud's Vertex AI Agent Engine. +""" + +from typing import Any, Optional + +from google.adk.agents import Agent + + +def get_databricks_requirements() -> list[str]: + """ + Get the pip requirements needed for Databricks tools in Agent Engine. + + Returns: + List of pip package requirements. + + Example: + ```python + from databricks_google_adk.deployment import get_databricks_requirements + + requirements = get_databricks_requirements() + # ['databricks-google-adk', 'databricks-sdk', ...] + ``` + """ + return [ + "databricks-google-adk", + "databricks-sdk", + "databricks-vectorsearch", + "databricks-ai-bridge", + "databricks-mcp", + ] + + +def create_agent_engine_config( + staging_bucket: str, + requirements: Optional[list[str]] = None, + env_vars: Optional[dict[str, str]] = None, + databricks_host: Optional[str] = None, + databricks_token_secret: Optional[str] = None, + **kwargs, +) -> dict[str, Any]: + """ + Create configuration for deploying to Vertex AI Agent Engine. + + Args: + staging_bucket: GCS bucket for staging (format: "gs://bucket-name"). + requirements: Additional pip requirements. Databricks requirements + are automatically included. + env_vars: Environment variables to set in the deployed agent. + databricks_host: Databricks workspace URL. If provided, will be set + as DATABRICKS_HOST environment variable. + databricks_token_secret: Secret Manager secret name for Databricks token. + Format: "projects/PROJECT/secrets/SECRET/versions/VERSION". + If provided, will be configured for DATABRICKS_TOKEN. + **kwargs: Additional configuration options passed to Agent Engine. + + Returns: + Configuration dictionary for agent_engines.create(). 
+ + Example: + ```python + from databricks_google_adk.deployment import create_agent_engine_config + + config = create_agent_engine_config( + staging_bucket="gs://my-bucket", + databricks_host="https://my-workspace.databricks.com", + databricks_token_secret="projects/my-project/secrets/databricks-token/versions/latest", + ) + + # Use with Vertex AI + remote_agent = client.agent_engines.create( + agent=app, + config=config, + ) + ``` + """ + # Combine requirements + all_requirements = get_databricks_requirements() + if requirements: + all_requirements.extend(requirements) + # Remove duplicates while preserving order + all_requirements = list(dict.fromkeys(all_requirements)) + + # Build environment variables + all_env_vars = env_vars.copy() if env_vars else {} + if databricks_host: + all_env_vars["DATABRICKS_HOST"] = databricks_host + + config = { + "requirements": all_requirements, + "staging_bucket": staging_bucket, + **kwargs, + } + + if all_env_vars: + config["env_vars"] = all_env_vars + + # Handle secret for Databricks token + if databricks_token_secret: + config["secrets"] = config.get("secrets", {}) + config["secrets"]["DATABRICKS_TOKEN"] = databricks_token_secret + + return config + + +class DatabricksAgentEngineApp: + """ + A wrapper for deploying Databricks-powered ADK agents to Vertex AI Agent Engine. + + This class provides a simplified interface for creating and deploying + ADK agents that use Databricks tools to Vertex AI Agent Engine. 
+ + Example: + ```python + from databricks_google_adk import VectorSearchRetrieverTool, DatabricksAgentEngineApp + from google.adk.agents import Agent + + # Create an agent with Databricks tools + vector_search = VectorSearchRetrieverTool(index_name="catalog.schema.index") + agent = Agent( + name="search_agent", + model="gemini-2.0-flash", + tools=[vector_search.as_tool()], + ) + + # Create the deployable app + app = DatabricksAgentEngineApp(agent=agent) + + # Deploy to Agent Engine + import vertexai + client = vertexai.Client(project="my-project", location="us-central1") + + remote_agent = client.agent_engines.create( + agent=app.adk_app, + config=app.get_deployment_config( + staging_bucket="gs://my-bucket", + databricks_host="https://my-workspace.databricks.com", + databricks_token_secret="projects/my-project/secrets/db-token/versions/latest", + ), + ) + ``` + """ + + def __init__( + self, + agent: Agent, + additional_requirements: Optional[list[str]] = None, + ): + """ + Initialize the DatabricksAgentEngineApp. + + Args: + agent: The ADK Agent to deploy. + additional_requirements: Additional pip requirements beyond + the standard Databricks packages. + """ + self._agent = agent + self._additional_requirements = additional_requirements or [] + self._adk_app = None + + @property + def agent(self) -> Agent: + """Get the underlying ADK agent.""" + return self._agent + + @property + def adk_app(self): + """ + Get the AdkApp for deployment. + + Returns: + An AdkApp instance wrapping the agent. + + Note: + Requires vertexai package: pip install google-cloud-aiplatform[agent_engines,adk] + """ + if self._adk_app is None: + try: + from vertexai.agent_engines import AdkApp + except ImportError: + raise ImportError( + "vertexai package is required for Agent Engine deployment. 
" + "Install with: pip install google-cloud-aiplatform[agent_engines,adk]" + ) + self._adk_app = AdkApp(agent=self._agent) + return self._adk_app + + def get_deployment_config( + self, + staging_bucket: str, + databricks_host: Optional[str] = None, + databricks_token_secret: Optional[str] = None, + env_vars: Optional[dict[str, str]] = None, + **kwargs, + ) -> dict[str, Any]: + """ + Get the deployment configuration for Agent Engine. + + Args: + staging_bucket: GCS bucket for staging (format: "gs://bucket-name"). + databricks_host: Databricks workspace URL. + databricks_token_secret: Secret Manager secret for Databricks token. + env_vars: Additional environment variables. + **kwargs: Additional configuration options. + + Returns: + Configuration dictionary for agent_engines.create(). + """ + return create_agent_engine_config( + staging_bucket=staging_bucket, + requirements=self._additional_requirements, + env_vars=env_vars, + databricks_host=databricks_host, + databricks_token_secret=databricks_token_secret, + **kwargs, + ) + + async def test_locally( + self, + message: str, + user_id: str = "test-user", + ): + """ + Test the agent locally before deployment. + + Args: + message: The message to send to the agent. + user_id: User ID for the session. + + Yields: + Events from the agent's response stream. + + Example: + ```python + app = DatabricksAgentEngineApp(agent=my_agent) + + async for event in app.test_locally("What documents match 'AI'?"): + print(event) + ``` + """ + async for event in self.adk_app.async_stream_query( + user_id=user_id, + message=message, + ): + yield event + + +def deploy_to_agent_engine( + agent: Agent, + project: str, + location: str, + staging_bucket: str, + databricks_host: Optional[str] = None, + databricks_token_secret: Optional[str] = None, + display_name: Optional[str] = None, + description: Optional[str] = None, + **kwargs, +): + """ + Deploy an ADK agent with Databricks tools to Vertex AI Agent Engine. 
+ + This is a convenience function that handles the full deployment flow. + + Args: + agent: The ADK Agent to deploy. + project: Google Cloud project ID. + location: Google Cloud region (e.g., "us-central1"). + staging_bucket: GCS bucket for staging (format: "gs://bucket-name"). + databricks_host: Databricks workspace URL. + databricks_token_secret: Secret Manager secret for Databricks token. + display_name: Display name for the deployed agent. + description: Description for the deployed agent. + **kwargs: Additional configuration options. + + Returns: + The deployed remote agent resource. + + Example: + ```python + from databricks_google_adk import VectorSearchRetrieverTool + from databricks_google_adk.deployment import deploy_to_agent_engine + from google.adk.agents import Agent + + # Create agent + vector_search = VectorSearchRetrieverTool(index_name="catalog.schema.index") + agent = Agent( + name="search_agent", + model="gemini-2.0-flash", + tools=[vector_search.as_tool()], + ) + + # Deploy + remote_agent = deploy_to_agent_engine( + agent=agent, + project="my-gcp-project", + location="us-central1", + staging_bucket="gs://my-staging-bucket", + databricks_host="https://my-workspace.databricks.com", + databricks_token_secret="projects/my-project/secrets/db-token/versions/latest", + ) + + print(f"Deployed agent: {remote_agent.resource_name}") + ``` + + Note: + Requires vertexai package: pip install google-cloud-aiplatform[agent_engines,adk] + """ + try: + import vertexai + from vertexai.agent_engines import AdkApp + except ImportError: + raise ImportError( + "vertexai package is required for Agent Engine deployment. 
" + "Install with: pip install google-cloud-aiplatform[agent_engines,adk]" + ) + + # Initialize Vertex AI client + client = vertexai.Client(project=project, location=location) + + # Create the app and config + app = DatabricksAgentEngineApp(agent=agent) + config = app.get_deployment_config( + staging_bucket=staging_bucket, + databricks_host=databricks_host, + databricks_token_secret=databricks_token_secret, + **kwargs, + ) + + # Add display name and description if provided + create_kwargs = {"agent": app.adk_app, "config": config} + if display_name: + create_kwargs["display_name"] = display_name + if description: + create_kwargs["description"] = description + + # Deploy + return client.agent_engines.create(**create_kwargs) diff --git a/integrations/google-adk/src/databricks_google_adk/mcp.py b/integrations/google-adk/src/databricks_google_adk/mcp.py new file mode 100644 index 000000000..b521db0e8 --- /dev/null +++ b/integrations/google-adk/src/databricks_google_adk/mcp.py @@ -0,0 +1,329 @@ +""" +Databricks MCP integration for Google ADK. + +This module provides a toolset that connects Databricks MCP servers to Google ADK agents. +""" + +from typing import Any, Optional, Union + +from databricks.sdk import WorkspaceClient +from databricks_mcp import DatabricksMCPClient +from google.adk.tools import BaseTool, FunctionTool +from google.adk.tools.base_toolset import BaseToolset, ToolPredicate + + +class DatabricksMcpToolset(BaseToolset): + """ + A Google ADK toolset that connects to Databricks MCP servers. + + This toolset wraps the DatabricksMCPClient to expose Databricks MCP tools + (UC Functions, Vector Search, Genie) to Google ADK agents. 
Supported Databricks MCP server types:
    - UC Functions: `/api/2.0/mcp/functions/{catalog}/{schema}`
    - Vector Search: `/api/2.0/mcp/vector-search/{catalog}/{schema}`
    - Genie: `/api/2.0/mcp/genie/{space_id}`
+ + Args: + catalog: The catalog name. + schema: The schema name. + workspace_client: Optional WorkspaceClient for authentication. + workspace_url: Optional workspace URL. If not provided, will be + inferred from workspace_client. + **kwargs: Additional arguments passed to DatabricksMcpToolset. + + Returns: + A DatabricksMcpToolset configured for UC functions. + + Example: + ```python + toolset = DatabricksMcpToolset.for_uc_functions( + catalog="my_catalog", + schema="my_schema", + ) + ``` + """ + client = workspace_client or WorkspaceClient() + base_url = workspace_url or client.config.host + server_url = f"{base_url}/api/2.0/mcp/functions/{catalog}/{schema}" + return cls(server_url=server_url, workspace_client=client, **kwargs) + + @classmethod + def for_vector_search( + cls, + catalog: str, + schema: str, + workspace_client: Optional[WorkspaceClient] = None, + workspace_url: Optional[str] = None, + **kwargs, + ) -> "DatabricksMcpToolset": + """ + Create a toolset for Vector Search. + + Args: + catalog: The catalog name. + schema: The schema name. + workspace_client: Optional WorkspaceClient for authentication. + workspace_url: Optional workspace URL. + **kwargs: Additional arguments passed to DatabricksMcpToolset. + + Returns: + A DatabricksMcpToolset configured for Vector Search. + + Example: + ```python + toolset = DatabricksMcpToolset.for_vector_search( + catalog="my_catalog", + schema="my_schema", + ) + ``` + """ + client = workspace_client or WorkspaceClient() + base_url = workspace_url or client.config.host + server_url = f"{base_url}/api/2.0/mcp/vector-search/{catalog}/{schema}" + return cls(server_url=server_url, workspace_client=client, **kwargs) + + @classmethod + def for_genie( + cls, + space_id: str, + workspace_client: Optional[WorkspaceClient] = None, + workspace_url: Optional[str] = None, + **kwargs, + ) -> "DatabricksMcpToolset": + """ + Create a toolset for Genie. + + Args: + space_id: The Genie space ID. 
+ workspace_client: Optional WorkspaceClient for authentication. + workspace_url: Optional workspace URL. + **kwargs: Additional arguments passed to DatabricksMcpToolset. + + Returns: + A DatabricksMcpToolset configured for Genie. + + Example: + ```python + toolset = DatabricksMcpToolset.for_genie( + space_id="my-genie-space-id", + ) + ``` + """ + client = workspace_client or WorkspaceClient() + base_url = workspace_url or client.config.host + server_url = f"{base_url}/api/2.0/mcp/genie/{space_id}" + return cls(server_url=server_url, workspace_client=client, **kwargs) + + def _load_tools(self) -> list[BaseTool]: + """Load tools from the MCP server.""" + mcp_tools = self._mcp_client.list_tools() + adk_tools = [] + + for mcp_tool in mcp_tools: + # Create a closure to capture the tool name + def make_tool_fn(tool_name: str, tool_desc: str, input_schema: dict): + def tool_fn(**kwargs) -> dict[str, Any]: + """Execute the MCP tool.""" + result = self._mcp_client.call_tool(tool_name, kwargs) + # Extract content from CallToolResult + if hasattr(result, "content"): + # MCP returns content as a list of content items + contents = [] + for item in result.content: + if hasattr(item, "text"): + contents.append(item.text) + elif hasattr(item, "data"): + contents.append(item.data) + else: + contents.append(str(item)) + return {"result": "\n".join(contents) if contents else "Success"} + return {"result": str(result)} + + tool_fn.__name__ = tool_name.replace(".", "__") + tool_fn.__doc__ = tool_desc + return tool_fn + + tool_name = mcp_tool.name + tool_description = mcp_tool.description or f"Call {tool_name}" + input_schema = mcp_tool.inputSchema if hasattr(mcp_tool, "inputSchema") else {} + + fn = make_tool_fn(tool_name, tool_description, input_schema) + adk_tools.append(FunctionTool(fn)) + + return adk_tools + + async def get_tools(self, readonly_context=None) -> list[BaseTool]: + """ + Return all tools from the Databricks MCP server. 
+ + Args: + readonly_context: Optional context for filtering tools. + + Returns: + List of BaseTool instances. + """ + if self._tools is None: + self._tools = self._load_tools() + + # Apply filtering if tool_filter is set + if self.tool_filter is not None: + return [ + tool + for tool in self._tools + if self._is_tool_selected( + tool.func.__name__ if hasattr(tool, "func") else str(tool) + ) + ] + return self._tools + + def get_databricks_resources(self) -> list: + """ + Get Databricks resources for MLflow model logging. + + This is useful when deploying agents that use Databricks MCP tools + to ensure proper authorization in Model Serving. + + Returns: + List of Databricks resource objects for MLflow. + """ + return self._mcp_client.get_databricks_resources() + + async def close(self) -> None: + """Clean up resources.""" + # DatabricksMCPClient doesn't have a close method currently + pass + + +def create_databricks_mcp_toolset( + server_type: str, + *, + catalog: Optional[str] = None, + schema: Optional[str] = None, + space_id: Optional[str] = None, + workspace_client: Optional[WorkspaceClient] = None, + **kwargs, +) -> DatabricksMcpToolset: + """ + Factory function to create a DatabricksMcpToolset. + + Args: + server_type: Type of MCP server: "uc_functions", "vector_search", or "genie". + catalog: Catalog name (required for uc_functions and vector_search). + schema: Schema name (required for uc_functions and vector_search). + space_id: Genie space ID (required for genie). + workspace_client: Optional WorkspaceClient for authentication. + **kwargs: Additional arguments passed to DatabricksMcpToolset. + + Returns: + A configured DatabricksMcpToolset. 
+ + Example: + ```python + # For UC Functions + toolset = create_databricks_mcp_toolset( + "uc_functions", + catalog="my_catalog", + schema="my_schema", + ) + + # For Genie + toolset = create_databricks_mcp_toolset( + "genie", + space_id="my-genie-space-id", + ) + ``` + """ + if server_type == "uc_functions": + if not catalog or not schema: + raise ValueError("catalog and schema are required for uc_functions") + return DatabricksMcpToolset.for_uc_functions( + catalog=catalog, + schema=schema, + workspace_client=workspace_client, + **kwargs, + ) + elif server_type == "vector_search": + if not catalog or not schema: + raise ValueError("catalog and schema are required for vector_search") + return DatabricksMcpToolset.for_vector_search( + catalog=catalog, + schema=schema, + workspace_client=workspace_client, + **kwargs, + ) + elif server_type == "genie": + if not space_id: + raise ValueError("space_id is required for genie") + return DatabricksMcpToolset.for_genie( + space_id=space_id, + workspace_client=workspace_client, + **kwargs, + ) + else: + raise ValueError( + f"Unknown server_type: {server_type}. 
" + "Must be one of: uc_functions, vector_search, genie" + ) diff --git a/integrations/google-adk/tests/unit_tests/test_deployment.py b/integrations/google-adk/tests/unit_tests/test_deployment.py new file mode 100644 index 000000000..fe4d15cec --- /dev/null +++ b/integrations/google-adk/tests/unit_tests/test_deployment.py @@ -0,0 +1,193 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from databricks_google_adk.deployment import ( + DatabricksAgentEngineApp, + create_agent_engine_config, + get_databricks_requirements, +) + + +class TestGetDatabricksRequirements: + """Tests for get_databricks_requirements function.""" + + def test_returns_list(self): + """Test that function returns a list.""" + result = get_databricks_requirements() + assert isinstance(result, list) + + def test_includes_core_packages(self): + """Test that core packages are included.""" + result = get_databricks_requirements() + assert "databricks-google-adk" in result + assert "databricks-sdk" in result + assert "databricks-ai-bridge" in result + assert "databricks-mcp" in result + + +class TestCreateAgentEngineConfig: + """Tests for create_agent_engine_config function.""" + + def test_basic_config(self): + """Test basic configuration creation.""" + config = create_agent_engine_config( + staging_bucket="gs://my-bucket", + ) + + assert config["staging_bucket"] == "gs://my-bucket" + assert "requirements" in config + assert "databricks-google-adk" in config["requirements"] + + def test_with_databricks_host(self): + """Test configuration with Databricks host.""" + config = create_agent_engine_config( + staging_bucket="gs://my-bucket", + databricks_host="https://my-workspace.databricks.com", + ) + + assert config["env_vars"]["DATABRICKS_HOST"] == "https://my-workspace.databricks.com" + + def test_with_databricks_token_secret(self): + """Test configuration with Databricks token secret.""" + secret = "projects/my-project/secrets/db-token/versions/latest" + config = 
create_agent_engine_config( + staging_bucket="gs://my-bucket", + databricks_token_secret=secret, + ) + + assert config["secrets"]["DATABRICKS_TOKEN"] == secret + + def test_with_additional_requirements(self): + """Test configuration with additional requirements.""" + config = create_agent_engine_config( + staging_bucket="gs://my-bucket", + requirements=["pandas", "numpy"], + ) + + assert "pandas" in config["requirements"] + assert "numpy" in config["requirements"] + # Core packages should still be included + assert "databricks-google-adk" in config["requirements"] + + def test_with_env_vars(self): + """Test configuration with custom environment variables.""" + config = create_agent_engine_config( + staging_bucket="gs://my-bucket", + env_vars={"MY_VAR": "my_value"}, + ) + + assert config["env_vars"]["MY_VAR"] == "my_value" + + def test_env_vars_combined_with_databricks_host(self): + """Test that env_vars are combined with databricks_host.""" + config = create_agent_engine_config( + staging_bucket="gs://my-bucket", + databricks_host="https://workspace.databricks.com", + env_vars={"OTHER_VAR": "other_value"}, + ) + + assert config["env_vars"]["DATABRICKS_HOST"] == "https://workspace.databricks.com" + assert config["env_vars"]["OTHER_VAR"] == "other_value" + + def test_additional_kwargs(self): + """Test that additional kwargs are passed through.""" + config = create_agent_engine_config( + staging_bucket="gs://my-bucket", + custom_option="custom_value", + ) + + assert config["custom_option"] == "custom_value" + + def test_requirements_deduplication(self): + """Test that duplicate requirements are removed.""" + config = create_agent_engine_config( + staging_bucket="gs://my-bucket", + requirements=["databricks-google-adk", "pandas"], # databricks-google-adk is duplicate + ) + + # Count occurrences of databricks-google-adk + count = config["requirements"].count("databricks-google-adk") + assert count == 1 + + +class TestDatabricksAgentEngineApp: + """Tests for 
DatabricksAgentEngineApp class.""" + + @pytest.fixture + def mock_agent(self): + """Create a mock ADK agent.""" + agent = MagicMock() + agent.name = "test_agent" + return agent + + def test_init(self, mock_agent): + """Test DatabricksAgentEngineApp initialization.""" + app = DatabricksAgentEngineApp(agent=mock_agent) + + assert app.agent is mock_agent + assert app._additional_requirements == [] + + def test_init_with_requirements(self, mock_agent): + """Test initialization with additional requirements.""" + app = DatabricksAgentEngineApp( + agent=mock_agent, + additional_requirements=["pandas", "numpy"], + ) + + assert app._additional_requirements == ["pandas", "numpy"] + + def test_agent_property(self, mock_agent): + """Test agent property.""" + app = DatabricksAgentEngineApp(agent=mock_agent) + assert app.agent is mock_agent + + def test_adk_app_requires_vertexai(self, mock_agent): + """Test that adk_app raises ImportError without vertexai.""" + app = DatabricksAgentEngineApp(agent=mock_agent) + + with patch.dict("sys.modules", {"vertexai": None, "vertexai.agent_engines": None}): + # Force reimport to trigger ImportError + with pytest.raises(ImportError, match="vertexai package is required"): + # Clear cached app + app._adk_app = None + _ = app.adk_app + + def test_adk_app_with_mock_vertexai(self, mock_agent): + """Test adk_app with mocked vertexai.""" + mock_adk_app = MagicMock() + + with patch("databricks_google_adk.deployment.AdkApp", return_value=mock_adk_app): + # Need to patch the import + with patch.dict("sys.modules", {"vertexai": MagicMock(), "vertexai.agent_engines": MagicMock()}): + app = DatabricksAgentEngineApp(agent=mock_agent) + # Can't actually test this without vertexai installed + # Just verify the app object exists + assert app._adk_app is None + + def test_get_deployment_config(self, mock_agent): + """Test get_deployment_config method.""" + app = DatabricksAgentEngineApp( + agent=mock_agent, + additional_requirements=["extra-package"], + ) + + 
config = app.get_deployment_config( + staging_bucket="gs://bucket", + databricks_host="https://workspace.databricks.com", + ) + + assert config["staging_bucket"] == "gs://bucket" + assert "extra-package" in config["requirements"] + assert config["env_vars"]["DATABRICKS_HOST"] == "https://workspace.databricks.com" + + def test_get_deployment_config_with_secret(self, mock_agent): + """Test get_deployment_config with token secret.""" + app = DatabricksAgentEngineApp(agent=mock_agent) + + config = app.get_deployment_config( + staging_bucket="gs://bucket", + databricks_token_secret="projects/p/secrets/s/versions/v", + ) + + assert config["secrets"]["DATABRICKS_TOKEN"] == "projects/p/secrets/s/versions/v" diff --git a/integrations/google-adk/tests/unit_tests/test_mcp.py b/integrations/google-adk/tests/unit_tests/test_mcp.py new file mode 100644 index 000000000..c5f4b26bf --- /dev/null +++ b/integrations/google-adk/tests/unit_tests/test_mcp.py @@ -0,0 +1,190 @@ +from unittest.mock import MagicMock, patch + +import pytest + +from databricks_google_adk import DatabricksMcpToolset, create_databricks_mcp_toolset + + +@pytest.fixture +def mock_mcp_client(): + """Mock the DatabricksMCPClient.""" + with patch("databricks_google_adk.mcp.DatabricksMCPClient") as mock: + mock_instance = MagicMock() + + # Create mock tools + mock_tool1 = MagicMock() + mock_tool1.name = "tool1" + mock_tool1.description = "First tool" + mock_tool1.inputSchema = {"type": "object", "properties": {}} + + mock_tool2 = MagicMock() + mock_tool2.name = "tool2" + mock_tool2.description = "Second tool" + mock_tool2.inputSchema = {"type": "object", "properties": {}} + + mock_instance.list_tools.return_value = [mock_tool1, mock_tool2] + + # Mock call_tool result + mock_result = MagicMock() + mock_content = MagicMock() + mock_content.text = "Tool result" + mock_result.content = [mock_content] + mock_instance.call_tool.return_value = mock_result + + mock_instance.get_databricks_resources.return_value = [] + + 
mock.return_value = mock_instance + yield mock_instance + + +@pytest.fixture +def mock_workspace_client(): + """Mock the WorkspaceClient.""" + with patch("databricks_google_adk.mcp.WorkspaceClient") as mock: + mock_instance = MagicMock() + mock_instance.config.host = "https://test-workspace.databricks.com" + mock.return_value = mock_instance + yield mock_instance + + +class TestDatabricksMcpToolset: + """Tests for DatabricksMcpToolset class.""" + + def test_init(self, mock_mcp_client): + """Test DatabricksMcpToolset initialization.""" + toolset = DatabricksMcpToolset( + server_url="https://workspace.databricks.com/api/2.0/mcp/functions/cat/schema" + ) + assert toolset._server_url == "https://workspace.databricks.com/api/2.0/mcp/functions/cat/schema" + + @pytest.mark.asyncio + async def test_get_tools(self, mock_mcp_client): + """Test get_tools returns tools from MCP server.""" + toolset = DatabricksMcpToolset( + server_url="https://workspace.databricks.com/api/2.0/mcp/functions/cat/schema" + ) + tools = await toolset.get_tools() + + assert len(tools) == 2 + mock_mcp_client.list_tools.assert_called_once() + + @pytest.mark.asyncio + async def test_get_tools_caching(self, mock_mcp_client): + """Test that tools are cached after first load.""" + toolset = DatabricksMcpToolset( + server_url="https://workspace.databricks.com/api/2.0/mcp/functions/cat/schema" + ) + + await toolset.get_tools() + await toolset.get_tools() + + # list_tools should only be called once due to caching + assert mock_mcp_client.list_tools.call_count == 1 + + @pytest.mark.asyncio + async def test_get_tools_with_filter(self, mock_mcp_client): + """Test get_tools with tool_filter.""" + toolset = DatabricksMcpToolset( + server_url="https://workspace.databricks.com/api/2.0/mcp/functions/cat/schema", + tool_filter=["tool1"], + ) + tools = await toolset.get_tools() + + # Only tool1 should be returned + assert len(tools) == 1 + + def test_get_databricks_resources(self, mock_mcp_client): + """Test 
get_databricks_resources delegates to MCP client.""" + toolset = DatabricksMcpToolset( + server_url="https://workspace.databricks.com/api/2.0/mcp/functions/cat/schema" + ) + resources = toolset.get_databricks_resources() + + mock_mcp_client.get_databricks_resources.assert_called_once() + assert resources == [] + + @pytest.mark.asyncio + async def test_close(self, mock_mcp_client): + """Test close method doesn't raise.""" + toolset = DatabricksMcpToolset( + server_url="https://workspace.databricks.com/api/2.0/mcp/functions/cat/schema" + ) + await toolset.close() # Should not raise + + +class TestDatabricksMcpToolsetFactories: + """Tests for DatabricksMcpToolset factory methods.""" + + def test_for_uc_functions(self, mock_mcp_client, mock_workspace_client): + """Test for_uc_functions factory method.""" + toolset = DatabricksMcpToolset.for_uc_functions( + catalog="my_catalog", + schema="my_schema", + ) + + expected_url = "https://test-workspace.databricks.com/api/2.0/mcp/functions/my_catalog/my_schema" + assert toolset._server_url == expected_url + + def test_for_vector_search(self, mock_mcp_client, mock_workspace_client): + """Test for_vector_search factory method.""" + toolset = DatabricksMcpToolset.for_vector_search( + catalog="my_catalog", + schema="my_schema", + ) + + expected_url = "https://test-workspace.databricks.com/api/2.0/mcp/vector-search/my_catalog/my_schema" + assert toolset._server_url == expected_url + + def test_for_genie(self, mock_mcp_client, mock_workspace_client): + """Test for_genie factory method.""" + toolset = DatabricksMcpToolset.for_genie( + space_id="my-genie-space", + ) + + expected_url = "https://test-workspace.databricks.com/api/2.0/mcp/genie/my-genie-space" + assert toolset._server_url == expected_url + + +class TestCreateDatabricksMcpToolset: + """Tests for create_databricks_mcp_toolset factory function.""" + + def test_create_uc_functions(self, mock_mcp_client, mock_workspace_client): + """Test creating UC functions toolset.""" + 
toolset = create_databricks_mcp_toolset( + "uc_functions", + catalog="cat", + schema="sch", + ) + assert "mcp/functions/cat/sch" in toolset._server_url + + def test_create_vector_search(self, mock_mcp_client, mock_workspace_client): + """Test creating Vector Search toolset.""" + toolset = create_databricks_mcp_toolset( + "vector_search", + catalog="cat", + schema="sch", + ) + assert "mcp/vector-search/cat/sch" in toolset._server_url + + def test_create_genie(self, mock_mcp_client, mock_workspace_client): + """Test creating Genie toolset.""" + toolset = create_databricks_mcp_toolset( + "genie", + space_id="space-123", + ) + assert "mcp/genie/space-123" in toolset._server_url + + def test_create_uc_functions_missing_params(self, mock_mcp_client): + """Test that missing catalog/schema raises ValueError.""" + with pytest.raises(ValueError, match="catalog and schema are required"): + create_databricks_mcp_toolset("uc_functions") + + def test_create_genie_missing_space_id(self, mock_mcp_client): + """Test that missing space_id raises ValueError.""" + with pytest.raises(ValueError, match="space_id is required"): + create_databricks_mcp_toolset("genie") + + def test_create_unknown_type(self, mock_mcp_client): + """Test that unknown server_type raises ValueError.""" + with pytest.raises(ValueError, match="Unknown server_type"): + create_databricks_mcp_toolset("unknown_type")