From 757a0fcb02319b60d5fb5ae58c8d2999972c3a48 Mon Sep 17 00:00:00 2001 From: stakeswky Date: Wed, 25 Feb 2026 07:49:19 +0800 Subject: [PATCH 1/2] fix: preserve function call IDs across SSE streaming partial/final events In SSE streaming mode, _finalize_model_response_event creates a brand-new Event object for every LlmResponse chunk (partial and final). Because each new Event's function calls start with an empty .id, populate_client_function_call_id generates a fresh adk-{uuid} every time. This means the partial event and the final event for the same function call end up with different IDs, breaking LongRunningFunctionTool / HITL workflows that match responses by ID. Fix: add an optional function_call_ids dict parameter to _finalize_model_response_event. The dict maps (function_name, index) to the ID that was assigned the first time that function call was seen. Before populate_client_function_call_id runs, any previously stored ID is restored onto the function call so the guard 'if not function_call.id' keeps it. After population, newly generated IDs are written back into the dict. _run_one_step_async creates one such dict per LLM call and threads it through _postprocess_async for the lifetime of the streaming sequence, so all partial and final events share the same stable IDs. Fixes #4609 --- .../adk/flows/llm_flows/base_llm_flow.py | 34 +++- .../test_streaming_function_call_id.py | 186 ++++++++++++++++++ 2 files changed, 218 insertions(+), 2 deletions(-) create mode 100644 tests/unittests/flows/llm_flows/test_streaming_function_call_id.py diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 424bb580e1..803b8abb27 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -81,6 +81,7 @@ def _finalize_model_response_event( llm_request: LlmRequest, llm_response: LlmResponse, model_response_event: Event, + function_call_ids: Optional[dict[tuple[str, int], str]] = None, ) -> Event: """Finalize and build the model response event from LLM response. @@ -91,6 +92,11 @@ def _finalize_model_response_event( llm_request: The original LLM request. llm_response: The LLM response from the model. model_response_event: The base event to populate. + function_call_ids: Optional mutable dict mapping (function_name, index) to + previously assigned client function call IDs. Used during SSE streaming + to ensure partial and final events for the same function call share the + same ID. When provided, newly generated IDs are stored back into this + dict for reuse by subsequent events in the same streaming sequence. Returns: The finalized Event with LLM response data merged in. @@ -103,7 +109,23 @@ def _finalize_model_response_event( if finalized_event.content: function_calls = finalized_event.get_function_calls() if function_calls: + # Restore previously assigned IDs before populating new ones so that + # partial and final events in an SSE stream share the same IDs. + if function_call_ids is not None: + for i, fc in enumerate(function_calls): + key = (fc.name, i) + if key in function_call_ids: + fc.id = function_call_ids[key] + functions.populate_client_function_call_id(finalized_event) + + # Persist any newly generated IDs for subsequent events. + if function_call_ids is not None: + for i, fc in enumerate(function_calls): + key = (fc.name, i) + if fc.id and key not in function_call_ids: + function_call_ids[key] = fc.id + finalized_event.long_running_tool_ids = ( functions.get_long_running_function_calls( function_calls, llm_request.tools_dict @@ -821,6 +843,9 @@ async def _run_one_step_async( author=invocation_context.agent.name, branch=invocation_context.branch, ) + # Track function call IDs across partial/final events in SSE streaming + # so that the same function call keeps the same client-generated ID. + function_call_ids: dict[tuple[str, int], str] = {} async with Aclosing( self._call_llm_async( invocation_context, llm_request, model_response_event @@ -834,6 +859,7 @@ async def _run_one_step_async( llm_request, llm_response, model_response_event, + function_call_ids, ) ) as agen: async for event in agen: @@ -880,6 +906,7 @@ async def _postprocess_async( llm_request: LlmRequest, llm_response: LlmResponse, model_response_event: Event, + function_call_ids: Optional[dict[tuple[str, int], str]] = None, ) -> AsyncGenerator[Event, None]: """Postprocess after calling the LLM. @@ -888,6 +915,8 @@ async def _postprocess_async( llm_request: The original LLM request. llm_response: The LLM response from the LLM call. model_response_event: A mutable event for the LLM response. + function_call_ids: Optional mutable dict for preserving function call IDs + across partial and final events in an SSE streaming sequence. Yields: A generator of events. @@ -911,7 +940,7 @@ async def _postprocess_async( # Builds the event. model_response_event = self._finalize_model_response_event( - llm_request, llm_response, model_response_event + llm_request, llm_response, model_response_event, function_call_ids ) yield model_response_event @@ -1191,9 +1220,10 @@ def _finalize_model_response_event( llm_request: LlmRequest, llm_response: LlmResponse, model_response_event: Event, + function_call_ids: Optional[dict[tuple[str, int], str]] = None, ) -> Event: return _finalize_model_response_event( - llm_request, llm_response, model_response_event + llm_request, llm_response, model_response_event, function_call_ids ) async def _resolve_toolset_auth( diff --git a/tests/unittests/flows/llm_flows/test_streaming_function_call_id.py b/tests/unittests/flows/llm_flows/test_streaming_function_call_id.py new file mode 100644 index 0000000000..849553848d --- /dev/null +++ b/tests/unittests/flows/llm_flows/test_streaming_function_call_id.py @@ -0,0 +1,186 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests that SSE streaming preserves function call IDs across partial/final events. + +Regression test for https://github.com/google/adk-python/issues/4609 +""" + +from google.adk.events.event import Event +from google.adk.flows.llm_flows.base_llm_flow import _finalize_model_response_event +from google.adk.models.llm_request import LlmRequest +from google.adk.models.llm_response import LlmResponse +from google.genai import types +import pytest + + +def _make_base_event() -> Event: + return Event( + id=Event.new_id(), + invocation_id="test-inv", + author="test-agent", + ) + + +def _make_llm_response(*, partial: bool, fc_name: str = "get_weather") -> LlmResponse: + return LlmResponse( + content=types.Content( + role="model", + parts=[ + types.Part( + function_call=types.FunctionCall( + name=fc_name, + args={"location": "NYC"}, + ) + ) + ], + ), + partial=partial, + ) + + +def _make_llm_request() -> LlmRequest: + req = LlmRequest() + req.tools_dict = {} + return req + + +class TestStreamingFunctionCallIdConsistency: + """Ensure partial and final events share the same function call ID.""" + + def test_partial_and_final_share_same_id(self): + """The core regression: partial event ID must equal final event ID.""" + llm_request = _make_llm_request() + function_call_ids: dict[tuple[str, int], str] = {} + + # Simulate partial event + partial_event = _finalize_model_response_event( + llm_request, + _make_llm_response(partial=True), + _make_base_event(), + function_call_ids, + ) + partial_fc_id = partial_event.get_function_calls()[0].id + assert partial_fc_id is not None + assert partial_fc_id.startswith("adk-") + + # Simulate final event (new Event object, same streaming sequence) + final_event = _finalize_model_response_event( + llm_request, + _make_llm_response(partial=False), + _make_base_event(), + function_call_ids, + ) + final_fc_id = final_event.get_function_calls()[0].id + + assert final_fc_id == partial_fc_id + + def test_without_function_call_ids_dict_generates_different_ids(self): + """Without the fix dict, each event gets a fresh ID (old behaviour).""" + llm_request = _make_llm_request() + + partial_event = _finalize_model_response_event( + llm_request, + _make_llm_response(partial=True), + _make_base_event(), + ) + final_event = _finalize_model_response_event( + llm_request, + _make_llm_response(partial=False), + _make_base_event(), + ) + + # Without the dict, IDs differ (demonstrating the old bug) + assert ( + partial_event.get_function_calls()[0].id + != final_event.get_function_calls()[0].id + ) + + def test_multiple_function_calls_preserve_ids(self): + """Each function call in a multi-call response keeps its own stable ID.""" + llm_request = _make_llm_request() + function_call_ids: dict[tuple[str, int], str] = {} + + def make_multi_fc_response(partial: bool) -> LlmResponse: + return LlmResponse( + content=types.Content( + role="model", + parts=[ + types.Part( + function_call=types.FunctionCall( + name="get_weather", + args={"location": "NYC"}, + ) + ), + types.Part( + function_call=types.FunctionCall( + name="get_time", + args={"timezone": "EST"}, + ) + ), + ], + ), + partial=partial, + ) + + partial_event = _finalize_model_response_event( + llm_request, + make_multi_fc_response(partial=True), + _make_base_event(), + function_call_ids, + ) + partial_ids = [fc.id for fc in partial_event.get_function_calls()] + + final_event = _finalize_model_response_event( + llm_request, + make_multi_fc_response(partial=False), + _make_base_event(), + function_call_ids, + ) + final_ids = [fc.id for fc in final_event.get_function_calls()] + + assert partial_ids == final_ids + # The two function calls should have different IDs from each other + assert partial_ids[0] != partial_ids[1] + + def test_server_provided_id_is_preserved(self): + """If the server already provides an ID, it should not be overwritten.""" + llm_request = _make_llm_request() + function_call_ids: dict[tuple[str, int], str] = {} + + server_id = "server-provided-id-123" + response = LlmResponse( + content=types.Content( + role="model", + parts=[ + types.Part( + function_call=types.FunctionCall( + id=server_id, + name="get_weather", + args={"location": "NYC"}, + ) + ) + ], + ), + partial=False, + ) + + event = _finalize_model_response_event( + llm_request, + response, + _make_base_event(), + function_call_ids, + ) + + assert event.get_function_calls()[0].id == server_id From 94ba5946142ae50a150f296379f0a7139c705d56 Mon Sep 17 00:00:00 2001 From: stakeswky Date: Wed, 25 Feb 2026 08:36:50 +0800 Subject: [PATCH 2/2] refactor: use dict.setdefault() for ID persistence per review suggestion --- src/google/adk/flows/llm_flows/base_llm_flow.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/google/adk/flows/llm_flows/base_llm_flow.py b/src/google/adk/flows/llm_flows/base_llm_flow.py index 803b8abb27..ac114dd209 100644 --- a/src/google/adk/flows/llm_flows/base_llm_flow.py +++ b/src/google/adk/flows/llm_flows/base_llm_flow.py @@ -123,8 +123,8 @@ def _finalize_model_response_event( if function_call_ids is not None: for i, fc in enumerate(function_calls): key = (fc.name, i) - if fc.id and key not in function_call_ids: - function_call_ids[key] = fc.id + if fc.id: + function_call_ids.setdefault(key, fc.id) finalized_event.long_running_tool_ids = ( functions.get_long_running_function_calls(