From a741ced4d3f4a2e2ef746b71a7992704bec7073d Mon Sep 17 00:00:00 2001 From: David Ahmann Date: Sun, 22 Feb 2026 11:35:52 -0500 Subject: [PATCH 1/2] flows: persist stop-streaming cancellation state atomically (#4588) --- src/google/adk/flows/llm_flows/functions.py | 12 +++ .../test_functions_stop_streaming_state.py | 89 +++++++++++++++++++ 2 files changed, 101 insertions(+) create mode 100644 tests/unittests/flows/llm_flows/test_functions_stop_streaming_state.py diff --git a/src/google/adk/flows/llm_flows/functions.py b/src/google/adk/flows/llm_flows/functions.py index 4d045face6..582abb423a 100644 --- a/src/google/adk/flows/llm_flows/functions.py +++ b/src/google/adk/flows/llm_flows/functions.py @@ -55,6 +55,7 @@ REQUEST_EUC_FUNCTION_CALL_NAME = 'adk_request_credential' REQUEST_CONFIRMATION_FUNCTION_CALL_NAME = 'adk_request_confirmation' REQUEST_INPUT_FUNCTION_CALL_NAME = 'adk_request_input' +LONG_RUNNING_CANCELLATION_STATE_KEY = '_adk_long_running_tool_cancellations' logger = logging.getLogger('google_adk.' + __name__) @@ -801,6 +802,14 @@ async def _process_function_live_helper( invocation_context, streaming_lock: asyncio.Lock, ): + def _record_cancellation_state(function_name: str, status: str) -> None: + previous = tool_context.state.get(LONG_RUNNING_CANCELLATION_STATE_KEY) + if not isinstance(previous, dict): + previous = {} + updated = dict(previous) + updated[function_name] = status + tool_context.state[LONG_RUNNING_CANCELLATION_STATE_KEY] = updated + function_response = None # Check if this is a stop_streaming function call if ( @@ -840,6 +849,7 @@ async def _process_function_live_helper( function_response = { 'status': f'The task is not cancelled yet for {function_name}.' } + _record_cancellation_state(function_name, 'pending') if not function_response: # Clean up the reference under lock async with streaming_lock: @@ -855,10 +865,12 @@ async def _process_function_live_helper( function_response = { 'status': f'Successfully stopped streaming function {function_name}' } + _record_cancellation_state(function_name, 'cancelled') else: function_response = { 'status': f'No active streaming function named {function_name} found' } + _record_cancellation_state(function_name, 'not_found') elif hasattr(tool, 'func') and inspect.isasyncgenfunction(tool.func): # for streaming tool use case # we require the function to be an async generator function diff --git a/tests/unittests/flows/llm_flows/test_functions_stop_streaming_state.py b/tests/unittests/flows/llm_flows/test_functions_stop_streaming_state.py new file mode 100644 index 0000000000..87de8d5753 --- /dev/null +++ b/tests/unittests/flows/llm_flows/test_functions_stop_streaming_state.py @@ -0,0 +1,89 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import asyncio +from types import SimpleNamespace + +from google.adk.agents.active_streaming_tool import ActiveStreamingTool +from google.adk.flows.llm_flows import functions +from google.genai import types +import pytest + + +async def _infinite_stream() -> None: + while True: + await asyncio.sleep(0.1) + + +@pytest.mark.asyncio +async def test_stop_streaming_persists_cancelled_state_atomically(): + task = asyncio.create_task(_infinite_stream()) + invocation_context = SimpleNamespace( + active_streaming_tools={ + 'monitor_stock_price': ActiveStreamingTool(task=task) + } + ) + tool_context = SimpleNamespace(state={}) + streaming_lock = asyncio.Lock() + + function_response = await functions._process_function_live_helper( + tool=SimpleNamespace(name='stop_streaming'), + tool_context=tool_context, + function_call=types.FunctionCall( + name='stop_streaming', + args={'function_name': 'monitor_stock_price'}, + ), + function_args={'function_name': 'monitor_stock_price'}, + invocation_context=invocation_context, + streaming_lock=streaming_lock, + ) + + assert function_response == { + 'status': 'Successfully stopped streaming function monitor_stock_price' + } + assert ( + tool_context.state[functions.LONG_RUNNING_CANCELLATION_STATE_KEY][ + 'monitor_stock_price' + ] + == 'cancelled' + ) + + +@pytest.mark.asyncio +async def test_stop_streaming_persists_not_found_state(): + invocation_context = SimpleNamespace(active_streaming_tools={}) + tool_context = SimpleNamespace(state={}) + streaming_lock = asyncio.Lock() + + function_response = await functions._process_function_live_helper( + tool=SimpleNamespace(name='stop_streaming'), + tool_context=tool_context, + function_call=types.FunctionCall( + name='stop_streaming', + args={'function_name': 'missing_stream'}, + ), + function_args={'function_name': 'missing_stream'}, + invocation_context=invocation_context, + streaming_lock=streaming_lock, + ) + + assert function_response == { + 'status': 'No active streaming function named missing_stream found' + } + assert ( + tool_context.state[functions.LONG_RUNNING_CANCELLATION_STATE_KEY][ + 'missing_stream' + ] + == 'not_found' + ) From aa77920506c87bde6778364264d1918dd5167f24 Mon Sep 17 00:00:00 2001 From: David Ahmann Date: Sun, 22 Feb 2026 11:38:43 -0500 Subject: [PATCH 2/2] tests: cover pending cancellation state persistence (#4588) --- .../test_functions_stop_streaming_state.py | 48 +++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/tests/unittests/flows/llm_flows/test_functions_stop_streaming_state.py b/tests/unittests/flows/llm_flows/test_functions_stop_streaming_state.py index 87de8d5753..37dd2619c8 100644 --- a/tests/unittests/flows/llm_flows/test_functions_stop_streaming_state.py +++ b/tests/unittests/flows/llm_flows/test_functions_stop_streaming_state.py @@ -87,3 +87,51 @@ async def test_stop_streaming_persists_not_found_state(): ] == 'not_found' ) + + +@pytest.mark.asyncio +async def test_stop_streaming_persists_pending_state_on_timeout(monkeypatch): + async def _slow_cancel() -> None: + try: + while True: + await asyncio.sleep(0.1) + except asyncio.CancelledError: + await asyncio.sleep(2.0) + raise + + task = asyncio.create_task(_slow_cancel()) + invocation_context = SimpleNamespace( + active_streaming_tools={'slow_stream': ActiveStreamingTool(task=task)} + ) + tool_context = SimpleNamespace(state={}) + streaming_lock = asyncio.Lock() + + async def _fake_wait_for(awaitable, timeout): + del awaitable, timeout + raise asyncio.TimeoutError + + monkeypatch.setattr(asyncio, 'wait_for', _fake_wait_for) + + function_response = await functions._process_function_live_helper( + tool=SimpleNamespace(name='stop_streaming'), + tool_context=tool_context, + function_call=types.FunctionCall( + name='stop_streaming', + args={'function_name': 'slow_stream'}, + ), + function_args={'function_name': 'slow_stream'}, + invocation_context=invocation_context, + streaming_lock=streaming_lock, + ) + + assert function_response == { + 'status': 'The task is not cancelled yet for slow_stream.' + } + assert ( + tool_context.state[functions.LONG_RUNNING_CANCELLATION_STATE_KEY][ + 'slow_stream' + ] + == 'pending' + ) + + task.cancel()