Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/google/adk/flows/llm_flows/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
REQUEST_EUC_FUNCTION_CALL_NAME = 'adk_request_credential'
REQUEST_CONFIRMATION_FUNCTION_CALL_NAME = 'adk_request_confirmation'
REQUEST_INPUT_FUNCTION_CALL_NAME = 'adk_request_input'
LONG_RUNNING_CANCELLATION_STATE_KEY = '_adk_long_running_tool_cancellations'

logger = logging.getLogger('google_adk.' + __name__)

Expand Down Expand Up @@ -801,6 +802,14 @@ async def _process_function_live_helper(
invocation_context,
streaming_lock: asyncio.Lock,
):
def _record_cancellation_state(function_name: str, status: str) -> None:
previous = tool_context.state.get(LONG_RUNNING_CANCELLATION_STATE_KEY)
if not isinstance(previous, dict):
previous = {}
updated = dict(previous)
updated[function_name] = status
tool_context.state[LONG_RUNNING_CANCELLATION_STATE_KEY] = updated
Comment on lines +806 to +811
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This helper function can be refactored for better clarity. Using .copy() is more idiomatic for creating a shallow copy of a dictionary than using the dict() constructor. I've also renamed previous to cancellations to make the variable's purpose clearer.

Suggested change
previous = tool_context.state.get(LONG_RUNNING_CANCELLATION_STATE_KEY)
if not isinstance(previous, dict):
previous = {}
updated = dict(previous)
updated[function_name] = status
tool_context.state[LONG_RUNNING_CANCELLATION_STATE_KEY] = updated
cancellations = tool_context.state.get(LONG_RUNNING_CANCELLATION_STATE_KEY)
if not isinstance(cancellations, dict):
cancellations = {}
updated = cancellations.copy()
updated[function_name] = status
tool_context.state[LONG_RUNNING_CANCELLATION_STATE_KEY] = updated


function_response = None
# Check if this is a stop_streaming function call
if (
Expand Down Expand Up @@ -840,6 +849,7 @@ async def _process_function_live_helper(
function_response = {
'status': f'The task is not cancelled yet for {function_name}.'
}
_record_cancellation_state(function_name, 'pending')
if not function_response:
# Clean up the reference under lock
async with streaming_lock:
Expand All @@ -855,10 +865,12 @@ async def _process_function_live_helper(
function_response = {
'status': f'Successfully stopped streaming function {function_name}'
}
_record_cancellation_state(function_name, 'cancelled')
else:
function_response = {
'status': f'No active streaming function named {function_name} found'
}
_record_cancellation_state(function_name, 'not_found')
elif hasattr(tool, 'func') and inspect.isasyncgenfunction(tool.func):
# for streaming tool use case
# we require the function to be an async generator function
Expand Down
137 changes: 137 additions & 0 deletions tests/unittests/flows/llm_flows/test_functions_stop_streaming_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from types import SimpleNamespace

from google.adk.agents.active_streaming_tool import ActiveStreamingTool
from google.adk.flows.llm_flows import functions
from google.genai import types
import pytest


async def _infinite_stream() -> None:
while True:
await asyncio.sleep(0.1)


@pytest.mark.asyncio
async def test_stop_streaming_persists_cancelled_state_atomically():
task = asyncio.create_task(_infinite_stream())
invocation_context = SimpleNamespace(
active_streaming_tools={
'monitor_stock_price': ActiveStreamingTool(task=task)
}
)
tool_context = SimpleNamespace(state={})
streaming_lock = asyncio.Lock()

function_response = await functions._process_function_live_helper(
tool=SimpleNamespace(name='stop_streaming'),
tool_context=tool_context,
function_call=types.FunctionCall(
name='stop_streaming',
args={'function_name': 'monitor_stock_price'},
),
function_args={'function_name': 'monitor_stock_price'},
invocation_context=invocation_context,
streaming_lock=streaming_lock,
)

assert function_response == {
'status': 'Successfully stopped streaming function monitor_stock_price'
}
assert (
tool_context.state[functions.LONG_RUNNING_CANCELLATION_STATE_KEY][
'monitor_stock_price'
]
== 'cancelled'
)


@pytest.mark.asyncio
async def test_stop_streaming_persists_not_found_state():
invocation_context = SimpleNamespace(active_streaming_tools={})
tool_context = SimpleNamespace(state={})
streaming_lock = asyncio.Lock()

function_response = await functions._process_function_live_helper(
tool=SimpleNamespace(name='stop_streaming'),
tool_context=tool_context,
function_call=types.FunctionCall(
name='stop_streaming',
args={'function_name': 'missing_stream'},
),
function_args={'function_name': 'missing_stream'},
invocation_context=invocation_context,
streaming_lock=streaming_lock,
)

assert function_response == {
'status': 'No active streaming function named missing_stream found'
}
assert (
tool_context.state[functions.LONG_RUNNING_CANCELLATION_STATE_KEY][
'missing_stream'
]
== 'not_found'
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The new tests cover the 'cancelled' and 'not_found' states, which is great. However, there's a third state, 'pending', that is recorded when a task cancellation times out. It would be beneficial to add a test case for this scenario to ensure complete coverage of the new state persistence logic.

Here is a suggested test case:

@pytest.mark.asyncio
async def test_stop_streaming_persists_pending_state_on_timeout():
    async def slow_cancel_task():
        try:
            while True:
                await asyncio.sleep(0.1)
        except asyncio.CancelledError:
            await asyncio.sleep(2) # Simulate slow cleanup
            raise

    task = asyncio.create_task(slow_cancel_task())
    await asyncio.sleep(0.01) # Give the task a moment to start

    invocation_context = SimpleNamespace(
        active_streaming_tools={
            'slow_tool': ActiveStreamingTool(task=task)
        }
    )
    tool_context = SimpleNamespace(state={})
    streaming_lock = asyncio.Lock()

    function_response = await functions._process_function_live_helper(
        tool=SimpleNamespace(name='stop_streaming'),
        tool_context=tool_context,
        function_call=types.FunctionCall(
            name='stop_streaming',
            args={'function_name': 'slow_tool'},
        ),
        function_args={'function_name': 'slow_tool'},
        invocation_context=invocation_context,
        streaming_lock=streaming_lock,
    )

    assert function_response == {
        'status': 'The task is not cancelled yet for slow_tool.'
    }
    assert (
        tool_context.state[functions.LONG_RUNNING_CANCELLATION_STATE_KEY][
            'slow_tool'
        ]
        == 'pending'
    )
    # Clean up the task to avoid it running forever
    with pytest.raises(asyncio.CancelledError):
        await task



@pytest.mark.asyncio
async def test_stop_streaming_persists_pending_state_on_timeout(monkeypatch):
async def _slow_cancel() -> None:
try:
while True:
await asyncio.sleep(0.1)
except asyncio.CancelledError:
await asyncio.sleep(2.0)
raise

task = asyncio.create_task(_slow_cancel())
invocation_context = SimpleNamespace(
active_streaming_tools={'slow_stream': ActiveStreamingTool(task=task)}
)
tool_context = SimpleNamespace(state={})
streaming_lock = asyncio.Lock()

async def _fake_wait_for(awaitable, timeout):
del awaitable, timeout
raise asyncio.TimeoutError

monkeypatch.setattr(asyncio, 'wait_for', _fake_wait_for)

function_response = await functions._process_function_live_helper(
tool=SimpleNamespace(name='stop_streaming'),
tool_context=tool_context,
function_call=types.FunctionCall(
name='stop_streaming',
args={'function_name': 'slow_stream'},
),
function_args={'function_name': 'slow_stream'},
invocation_context=invocation_context,
streaming_lock=streaming_lock,
)

assert function_response == {
'status': 'The task is not cancelled yet for slow_stream.'
}
assert (
tool_context.state[functions.LONG_RUNNING_CANCELLATION_STATE_KEY][
'slow_stream'
]
== 'pending'
)

task.cancel()