Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions src/google/adk/tools/function_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(
func: Callable[..., Any],
*,
require_confirmation: Union[bool, Callable[..., bool]] = False,
is_high_risk: bool = False,
):
"""Initializes the FunctionTool. Extracts metadata from a callable object.

Expand All @@ -56,6 +57,9 @@ def __init__(
a callable that takes the function's arguments and returns a boolean. If
the callable returns True, the tool will require confirmation from the
user.
is_high_risk: Whether the tool performs high-impact operations. High-risk
tools fail closed unless an explicit confirmation policy resolves to
`True`.
"""
name = ''
doc = ''
Expand All @@ -82,6 +86,7 @@ def __init__(
self.func = func
self._ignore_params = ['tool_context', 'input_stream']
self._require_confirmation = require_confirmation
self._is_high_risk = is_high_risk

@override
def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
Expand Down Expand Up @@ -192,6 +197,15 @@ async def run_async(
else:
require_confirmation = bool(self._require_confirmation)

if self._is_high_risk and not require_confirmation:
return {
'error': (
'This high-risk tool requires an explicit confirmation policy.'
' Set require_confirmation=True or provide a callable policy'
' that returns True.'
)
Comment on lines +202 to +206
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This error message is duplicated in src/google/adk/tools/mcp_tool/mcp_tool.py and also hardcoded in the corresponding test files (tests/unittests/tools/test_function_tool.py and tests/unittests/tools/mcp_tool/test_mcp_tool.py). To improve maintainability and avoid inconsistencies, it's best to define this string as a constant in a shared module (e.g., a new src/google/adk/tools/constants.py) and import it in all places where it's used. This ensures that any future changes to the message only need to be made in one place.

}

if require_confirmation:
if not tool_context.tool_confirmation:
args_to_show = args_to_call.copy()
Expand Down
13 changes: 13 additions & 0 deletions src/google/adk/tools/mcp_tool/mcp_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,7 @@ def __init__(
auth_scheme: Optional[AuthScheme] = None,
auth_credential: Optional[AuthCredential] = None,
require_confirmation: Union[bool, Callable[..., bool]] = False,
is_high_risk: bool = False,
header_provider: Optional[
Callable[[ReadonlyContext], Dict[str, str]]
] = None,
Expand All @@ -151,6 +152,8 @@ def __init__(
or a callable that takes the function's arguments and returns a
boolean. If the callable returns True, the tool will require
confirmation from the user.
is_high_risk: Whether this tool is high-risk. High-risk tools fail
closed unless an explicit confirmation policy resolves to `True`.
header_provider: Optional function to provide dynamic headers.
progress_callback: Optional callback to receive progress notifications
from MCP server during long-running tool execution. Can be either:
Expand Down Expand Up @@ -178,6 +181,7 @@ def __init__(
self._mcp_tool = mcp_tool
self._mcp_session_manager = mcp_session_manager
self._require_confirmation = require_confirmation
self._is_high_risk = is_high_risk
self._header_provider = header_provider
self._progress_callback = progress_callback

Expand Down Expand Up @@ -262,6 +266,15 @@ async def run_async(
else:
require_confirmation = bool(self._require_confirmation)

if self._is_high_risk and not require_confirmation:
return {
"error": (
"This high-risk tool requires an explicit confirmation policy."
" Set require_confirmation=True or provide a callable policy"
" that returns True."
)
}

if require_confirmation:
if not tool_context.tool_confirmation:
args_to_show = args.copy()
Expand Down
6 changes: 6 additions & 0 deletions src/google/adk/tools/mcp_tool/mcp_toolset.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def __init__(
auth_scheme: Optional[AuthScheme] = None,
auth_credential: Optional[AuthCredential] = None,
require_confirmation: Union[bool, Callable[..., bool]] = False,
is_high_risk: bool = False,
header_provider: Optional[
Callable[[ReadonlyContext], Dict[str, str]]
] = None,
Expand Down Expand Up @@ -136,6 +137,9 @@ def __init__(
auth_credential: The auth credential of the tool for tool calling
require_confirmation: Whether tools in this toolset require confirmation.
Can be a single boolean or a callable to apply to all tools.
is_high_risk: Whether tools from this toolset are high-risk. High-risk
tools fail closed unless an explicit confirmation policy resolves to
`True`.
header_provider: A callable that takes a ReadonlyContext and returns a
dictionary of headers to be used for the MCP session.
progress_callback: Optional callback to receive progress notifications
Expand Down Expand Up @@ -170,6 +174,7 @@ def __init__(
self._auth_scheme = auth_scheme
self._auth_credential = auth_credential
self._require_confirmation = require_confirmation
self._is_high_risk = is_high_risk
# Store auth config as instance variable so ADK can populate
# exchanged_auth_credential in-place before calling get_tools()
self._auth_config: Optional[AuthConfig] = (
Expand Down Expand Up @@ -316,6 +321,7 @@ async def get_tools(
auth_scheme=self._auth_scheme,
auth_credential=self._auth_credential,
require_confirmation=self._require_confirmation,
is_high_risk=self._is_high_risk,
header_provider=self._header_provider,
progress_callback=self._progress_callback
if hasattr(self, "_progress_callback")
Expand Down
24 changes: 24 additions & 0 deletions tests/unittests/tools/mcp_tool/test_mcp_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,6 +685,30 @@ async def test_run_async_require_confirmation_true_no_confirmation(self):
}
tool_context.request_confirmation.assert_called_once()

@pytest.mark.asyncio
async def test_run_async_high_risk_without_confirmation_policy_fails_closed(
self,
):
"""Test that high-risk MCP tools fail closed without explicit policy."""
tool = MCPTool(
mcp_tool=self.mock_mcp_tool,
mcp_session_manager=self.mock_session_manager,
is_high_risk=True,
)
tool_context = Mock(spec=ToolContext)
tool_context.tool_confirmation = None
args = {"param1": "test_value"}

result = await tool.run_async(args=args, tool_context=tool_context)

assert result == {
"error": (
"This high-risk tool requires an explicit confirmation policy. Set"
" require_confirmation=True or provide a callable policy that"
" returns True."
)
}

@pytest.mark.asyncio
async def test_run_async_require_confirmation_true_rejected(self):
"""Test require_confirmation=True with rejection in context."""
Expand Down
26 changes: 26 additions & 0 deletions tests/unittests/tools/test_function_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -417,6 +417,32 @@ def sample_func(arg1: str):
assert result == {"received_arg": "hello"}


@pytest.mark.asyncio
async def test_run_async_high_risk_without_confirmation_policy_fails_closed():
"""Test that high-risk tools fail closed without explicit confirmation policy."""

def sample_func(arg1: str):
return {"received_arg": arg1}

tool = FunctionTool(sample_func, is_high_risk=True)
mock_invocation_context = MagicMock(spec=InvocationContext)
mock_invocation_context.session = MagicMock(spec=Session)
mock_invocation_context.session.state = MagicMock()
tool_context_mock = ToolContext(invocation_context=mock_invocation_context)

result = await tool.run_async(
args={"arg1": "hello"},
tool_context=tool_context_mock,
)
assert result == {
"error": (
"This high-risk tool requires an explicit confirmation policy. Set"
" require_confirmation=True or provide a callable policy that returns"
" True."
)
}


@pytest.mark.asyncio
async def test_run_async_parameter_filtering(mock_tool_context):
"""Test that parameter filtering works correctly for functions with explicit parameters."""
Expand Down