diff --git a/src/google/adk/tools/function_tool.py b/src/google/adk/tools/function_tool.py index 6b8496dc58..6c0657c743 100644 --- a/src/google/adk/tools/function_tool.py +++ b/src/google/adk/tools/function_tool.py @@ -47,6 +47,7 @@ def __init__( func: Callable[..., Any], *, require_confirmation: Union[bool, Callable[..., bool]] = False, + is_high_risk: bool = False, ): """Initializes the FunctionTool. Extracts metadata from a callable object. @@ -56,6 +57,9 @@ def __init__( a callable that takes the function's arguments and returns a boolean. If the callable returns True, the tool will require confirmation from the user. + is_high_risk: Whether the tool performs high-impact operations. High-risk + tools fail closed unless an explicit confirmation policy resolves to + `True`. """ name = '' doc = '' @@ -82,6 +86,7 @@ def __init__( self.func = func self._ignore_params = ['tool_context', 'input_stream'] self._require_confirmation = require_confirmation + self._is_high_risk = is_high_risk @override def _get_declaration(self) -> Optional[types.FunctionDeclaration]: @@ -192,6 +197,15 @@ async def run_async( else: require_confirmation = bool(self._require_confirmation) + if self._is_high_risk and not require_confirmation: + return { + 'error': ( + 'This high-risk tool requires an explicit confirmation policy.' + ' Set require_confirmation=True or provide a callable policy' + ' that returns True.' + ) + } + if require_confirmation: if not tool_context.tool_confirmation: args_to_show = args_to_call.copy() diff --git a/src/google/adk/tools/mcp_tool/mcp_tool.py b/src/google/adk/tools/mcp_tool/mcp_tool.py index f31768a01e..213b900a50 100644 --- a/src/google/adk/tools/mcp_tool/mcp_tool.py +++ b/src/google/adk/tools/mcp_tool/mcp_tool.py @@ -130,6 +130,7 @@ def __init__( auth_scheme: Optional[AuthScheme] = None, auth_credential: Optional[AuthCredential] = None, require_confirmation: Union[bool, Callable[..., bool]] = False, + is_high_risk: bool = False, header_provider: Optional[ Callable[[ReadonlyContext], Dict[str, str]] ] = None, @@ -151,6 +152,8 @@ def __init__( or a callable that takes the function's arguments and returns a boolean. If the callable returns True, the tool will require confirmation from the user. + is_high_risk: Whether this tool is high-risk. High-risk tools fail + closed unless an explicit confirmation policy resolves to `True`. header_provider: Optional function to provide dynamic headers. progress_callback: Optional callback to receive progress notifications from MCP server during long-running tool execution. Can be either: @@ -178,6 +181,7 @@ def __init__( self._mcp_tool = mcp_tool self._mcp_session_manager = mcp_session_manager self._require_confirmation = require_confirmation + self._is_high_risk = is_high_risk self._header_provider = header_provider self._progress_callback = progress_callback @@ -262,6 +266,15 @@ async def run_async( else: require_confirmation = bool(self._require_confirmation) + if self._is_high_risk and not require_confirmation: + return { + "error": ( + "This high-risk tool requires an explicit confirmation policy." + " Set require_confirmation=True or provide a callable policy" + " that returns True." + ) + } + if require_confirmation: if not tool_context.tool_confirmation: args_to_show = args.copy() diff --git a/src/google/adk/tools/mcp_tool/mcp_toolset.py b/src/google/adk/tools/mcp_tool/mcp_toolset.py index fb4e992dfd..8390d53251 100644 --- a/src/google/adk/tools/mcp_tool/mcp_toolset.py +++ b/src/google/adk/tools/mcp_tool/mcp_toolset.py @@ -107,6 +107,7 @@ def __init__( auth_scheme: Optional[AuthScheme] = None, auth_credential: Optional[AuthCredential] = None, require_confirmation: Union[bool, Callable[..., bool]] = False, + is_high_risk: bool = False, header_provider: Optional[ Callable[[ReadonlyContext], Dict[str, str]] ] = None, @@ -136,6 +137,9 @@ def __init__( auth_credential: The auth credential of the tool for tool calling require_confirmation: Whether tools in this toolset require confirmation. Can be a single boolean or a callable to apply to all tools. + is_high_risk: Whether tools from this toolset are high-risk. High-risk + tools fail closed unless an explicit confirmation policy resolves to + `True`. header_provider: A callable that takes a ReadonlyContext and returns a dictionary of headers to be used for the MCP session. progress_callback: Optional callback to receive progress notifications @@ -170,6 +174,7 @@ def __init__( self._auth_scheme = auth_scheme self._auth_credential = auth_credential self._require_confirmation = require_confirmation + self._is_high_risk = is_high_risk # Store auth config as instance variable so ADK can populate # exchanged_auth_credential in-place before calling get_tools() self._auth_config: Optional[AuthConfig] = ( @@ -316,6 +321,7 @@ async def get_tools( auth_scheme=self._auth_scheme, auth_credential=self._auth_credential, require_confirmation=self._require_confirmation, + is_high_risk=self._is_high_risk, header_provider=self._header_provider, progress_callback=self._progress_callback if hasattr(self, "_progress_callback") diff --git a/tests/unittests/tools/mcp_tool/test_mcp_tool.py b/tests/unittests/tools/mcp_tool/test_mcp_tool.py index f38a8bbc7a..bb2ce535cc 100644 --- a/tests/unittests/tools/mcp_tool/test_mcp_tool.py +++ b/tests/unittests/tools/mcp_tool/test_mcp_tool.py @@ -685,6 +685,30 @@ async def test_run_async_require_confirmation_true_no_confirmation(self): } tool_context.request_confirmation.assert_called_once() + @pytest.mark.asyncio + async def test_run_async_high_risk_without_confirmation_policy_fails_closed( + self, + ): + """Test that high-risk MCP tools fail closed without explicit policy.""" + tool = MCPTool( + mcp_tool=self.mock_mcp_tool, + mcp_session_manager=self.mock_session_manager, + is_high_risk=True, + ) + tool_context = Mock(spec=ToolContext) + tool_context.tool_confirmation = None + args = {"param1": "test_value"} + + result = await tool.run_async(args=args, tool_context=tool_context) + + assert result == { + "error": ( + "This high-risk tool requires an explicit confirmation policy. Set" + " require_confirmation=True or provide a callable policy that" + " returns True." + ) + } + @pytest.mark.asyncio async def test_run_async_require_confirmation_true_rejected(self): """Test require_confirmation=True with rejection in context.""" diff --git a/tests/unittests/tools/test_function_tool.py b/tests/unittests/tools/test_function_tool.py index 9b1d1abd11..58a5870517 100644 --- a/tests/unittests/tools/test_function_tool.py +++ b/tests/unittests/tools/test_function_tool.py @@ -417,6 +417,32 @@ def sample_func(arg1: str): assert result == {"received_arg": "hello"} +@pytest.mark.asyncio +async def test_run_async_high_risk_without_confirmation_policy_fails_closed(): + """Test that high-risk tools fail closed without explicit confirmation policy.""" + + def sample_func(arg1: str): + return {"received_arg": arg1} + + tool = FunctionTool(sample_func, is_high_risk=True) + mock_invocation_context = MagicMock(spec=InvocationContext) + mock_invocation_context.session = MagicMock(spec=Session) + mock_invocation_context.session.state = MagicMock() + tool_context_mock = ToolContext(invocation_context=mock_invocation_context) + + result = await tool.run_async( + args={"arg1": "hello"}, + tool_context=tool_context_mock, + ) + assert result == { + "error": ( + "This high-risk tool requires an explicit confirmation policy. Set" + " require_confirmation=True or provide a callable policy that returns" + " True." + ) + } + + @pytest.mark.asyncio async def test_run_async_parameter_filtering(mock_tool_context): """Test that parameter filtering works correctly for functions with explicit parameters."""