Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions src/google/adk/flows/llm_flows/base_llm_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def _finalize_model_response_event(
llm_request: LlmRequest,
llm_response: LlmResponse,
model_response_event: Event,
function_call_ids: Optional[dict[tuple[str, int], str]] = None,
) -> Event:
"""Finalize and build the model response event from LLM response.

Expand All @@ -91,6 +92,11 @@ def _finalize_model_response_event(
llm_request: The original LLM request.
llm_response: The LLM response from the model.
model_response_event: The base event to populate.
function_call_ids: Optional mutable dict mapping (function_name, index) to
previously assigned client function call IDs. Used during SSE streaming
to ensure partial and final events for the same function call share the
same ID. When provided, newly generated IDs are stored back into this
dict for reuse by subsequent events in the same streaming sequence.

Returns:
The finalized Event with LLM response data merged in.
Expand All @@ -103,7 +109,23 @@ def _finalize_model_response_event(
if finalized_event.content:
function_calls = finalized_event.get_function_calls()
if function_calls:
# Restore previously assigned IDs before populating new ones so that
# partial and final events in an SSE stream share the same IDs.
if function_call_ids is not None:
for i, fc in enumerate(function_calls):
key = (fc.name, i)
if key in function_call_ids:
fc.id = function_call_ids[key]

functions.populate_client_function_call_id(finalized_event)

# Persist any newly generated IDs for subsequent events.
if function_call_ids is not None:
for i, fc in enumerate(function_calls):
key = (fc.name, i)
if fc.id:
function_call_ids.setdefault(key, fc.id)

finalized_event.long_running_tool_ids = (
functions.get_long_running_function_calls(
function_calls, llm_request.tools_dict
Expand Down Expand Up @@ -821,6 +843,9 @@ async def _run_one_step_async(
author=invocation_context.agent.name,
branch=invocation_context.branch,
)
# Track function call IDs across partial/final events in SSE streaming
# so that the same function call keeps the same client-generated ID.
function_call_ids: dict[tuple[str, int], str] = {}
async with Aclosing(
self._call_llm_async(
invocation_context, llm_request, model_response_event
Expand All @@ -834,6 +859,7 @@ async def _run_one_step_async(
llm_request,
llm_response,
model_response_event,
function_call_ids,
)
) as agen:
async for event in agen:
Expand Down Expand Up @@ -880,6 +906,7 @@ async def _postprocess_async(
llm_request: LlmRequest,
llm_response: LlmResponse,
model_response_event: Event,
function_call_ids: Optional[dict[tuple[str, int], str]] = None,
) -> AsyncGenerator[Event, None]:
"""Postprocess after calling the LLM.

Expand All @@ -888,6 +915,8 @@ async def _postprocess_async(
llm_request: The original LLM request.
llm_response: The LLM response from the LLM call.
model_response_event: A mutable event for the LLM response.
function_call_ids: Optional mutable dict for preserving function call IDs
across partial and final events in an SSE streaming sequence.

Yields:
A generator of events.
Expand All @@ -911,7 +940,7 @@ async def _postprocess_async(

# Builds the event.
model_response_event = self._finalize_model_response_event(
llm_request, llm_response, model_response_event
llm_request, llm_response, model_response_event, function_call_ids
)
yield model_response_event

Expand Down Expand Up @@ -1191,9 +1220,10 @@ def _finalize_model_response_event(
llm_request: LlmRequest,
llm_response: LlmResponse,
model_response_event: Event,
function_call_ids: Optional[dict[tuple[str, int], str]] = None,
) -> Event:
return _finalize_model_response_event(
llm_request, llm_response, model_response_event
llm_request, llm_response, model_response_event, function_call_ids
)

async def _resolve_toolset_auth(
Expand Down
186 changes: 186 additions & 0 deletions tests/unittests/flows/llm_flows/test_streaming_function_call_id.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests that SSE streaming preserves function call IDs across partial/final events.

Regression test for https://github.com/google/adk-python/issues/4609
"""

from google.adk.events.event import Event
from google.adk.flows.llm_flows.base_llm_flow import _finalize_model_response_event
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.genai import types
import pytest


def _make_base_event() -> Event:
    """Build a fresh, minimal Event shell to hand to the finalizer."""
    event_kwargs = {
        "id": Event.new_id(),
        "invocation_id": "test-inv",
        "author": "test-agent",
    }
    return Event(**event_kwargs)


def _make_llm_response(*, partial: bool, fc_name: str = "get_weather") -> LlmResponse:
    """Build an LlmResponse whose content holds a single function call part.

    Args:
        partial: Whether this response represents a partial (streaming) chunk.
        fc_name: Name of the function call embedded in the response.
    """
    call = types.FunctionCall(name=fc_name, args={"location": "NYC"})
    part = types.Part(function_call=call)
    content = types.Content(role="model", parts=[part])
    return LlmResponse(content=content, partial=partial)


def _make_llm_request() -> LlmRequest:
    """Build a minimal LlmRequest with an empty tools registry."""
    request = LlmRequest()
    request.tools_dict = {}
    return request


class TestStreamingFunctionCallIdConsistency:
    """Ensure partial and final events share the same function call ID."""

    def test_partial_and_final_share_same_id(self):
        """The core regression: partial event ID must equal final event ID."""
        request = _make_llm_request()
        id_registry: dict[tuple[str, int], str] = {}

        # First, the partial (streaming) chunk of the sequence.
        partial_event = _finalize_model_response_event(
            request,
            _make_llm_response(partial=True),
            _make_base_event(),
            id_registry,
        )
        first_id = partial_event.get_function_calls()[0].id
        assert first_id is not None
        assert first_id.startswith("adk-")

        # Then the final event of the same streaming sequence (a distinct
        # Event object, but sharing the ID registry).
        final_event = _finalize_model_response_event(
            request,
            _make_llm_response(partial=False),
            _make_base_event(),
            id_registry,
        )
        final_id = final_event.get_function_calls()[0].id

        assert final_id == first_id

    def test_without_function_call_ids_dict_generates_different_ids(self):
        """Without the fix dict, each event gets a fresh ID (old behaviour)."""
        request = _make_llm_request()

        # Finalize a partial and then a final event with no shared registry.
        events = [
            _finalize_model_response_event(
                request,
                _make_llm_response(partial=flag),
                _make_base_event(),
            )
            for flag in (True, False)
        ]
        observed_ids = [event.get_function_calls()[0].id for event in events]

        # No shared dict means IDs regenerate per event (the old bug).
        assert observed_ids[0] != observed_ids[1]

    def test_multiple_function_calls_preserve_ids(self):
        """Each function call in a multi-call response keeps its own stable ID."""
        request = _make_llm_request()
        id_registry: dict[tuple[str, int], str] = {}

        def build_multi_fc_response(partial: bool) -> LlmResponse:
            calls = [
                types.FunctionCall(name="get_weather", args={"location": "NYC"}),
                types.FunctionCall(name="get_time", args={"timezone": "EST"}),
            ]
            return LlmResponse(
                content=types.Content(
                    role="model",
                    parts=[types.Part(function_call=call) for call in calls],
                ),
                partial=partial,
            )

        def collect_ids(partial: bool) -> list:
            event = _finalize_model_response_event(
                request,
                build_multi_fc_response(partial),
                _make_base_event(),
                id_registry,
            )
            return [fc.id for fc in event.get_function_calls()]

        partial_ids = collect_ids(True)
        final_ids = collect_ids(False)

        assert partial_ids == final_ids
        # Distinct calls within one response must not share an ID.
        assert partial_ids[0] != partial_ids[1]

    def test_server_provided_id_is_preserved(self):
        """If the server already provides an ID, it should not be overwritten."""
        request = _make_llm_request()
        id_registry: dict[tuple[str, int], str] = {}

        server_id = "server-provided-id-123"
        call = types.FunctionCall(
            id=server_id,
            name="get_weather",
            args={"location": "NYC"},
        )
        response = LlmResponse(
            content=types.Content(
                role="model",
                parts=[types.Part(function_call=call)],
            ),
            partial=False,
        )

        event = _finalize_model_response_event(
            request,
            response,
            _make_base_event(),
            id_registry,
        )

        assert event.get_function_calls()[0].id == server_id