diff --git a/examples/adk_streaming_thinking_usage.py b/examples/adk_streaming_thinking_usage.py new file mode 100644 index 00000000..975d39ec --- /dev/null +++ b/examples/adk_streaming_thinking_usage.py @@ -0,0 +1,186 @@ +"""ADK adapter example usage: streaming, thinking / reasoning, multi-provider. + +Demonstrates the PortkeyAdk adapter across multiple providers (OpenAI, +Anthropic, Vertex/Gemini) in four modes: + + 1. Non-streaming (basic) + 2. Non-streaming with thinking / reasoning + 3. Streaming (basic) + 4. Streaming with thinking / reasoning + +Requires a valid Portkey API key. + +Usage: + PORTKEY_API_KEY= python examples/adk_streaming_thinking_usage.py +""" + +from __future__ import annotations + +import asyncio +import os +import sys +import traceback +from typing import Any + +_REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) +if _REPO_ROOT not in sys.path: + sys.path.insert(0, _REPO_ROOT) + +from portkey_ai.integrations.adk import PortkeyAdk # noqa: E402 + +try: + from google.adk.models.llm_request import LlmRequest # type: ignore + from google.genai import types as genai_types # type: ignore +except ImportError: + print("google-adk is required: pip install 'portkey-ai[adk]'") + sys.exit(1) + + +PORTKEY_API_KEY = os.environ.get("PORTKEY_API_KEY", "") + +# (display_name, model_slug, supports_thinking) +MODELS: list[tuple[str, str, bool]] = [ + ("OpenAI GPT-5.3", "@openai/gpt-5.3-chat-latest", True), + ("Anthropic Claude Sonnet 4.6", "@anthropic/claude-sonnet-4-6", True), + ("Vertex Gemini 3 Flash", "@vertex-ai/gemini-3-flash-preview", True), +] + +SIMPLE_PROMPT = "What is 2 + 2? Reply with just the number." +REASONING_PROMPT = ( + "A farmer has 17 sheep. All but 9 run away. How many sheep does the " + "farmer have left? Think step by step, then give just the number." +) + + +def _build_request( + model: str, + prompt: str, + *, + enable_thinking: bool = False, + thinking_budget: int = 4096, +) -> LlmRequest: + """Build an LlmRequest with optional thinking config.""" + kwargs: dict[str, Any] = { + "model": model, + "contents": [ + genai_types.Content( + role="user", + parts=[genai_types.Part.from_text(text=prompt)], + ) + ], + } + if enable_thinking: + kwargs["config"] = genai_types.GenerateContentConfig( + thinking_config=genai_types.ThinkingConfig( + include_thoughts=True, + thinking_budget=thinking_budget, + ), + ) + return LlmRequest(**kwargs) + + +def _print_parts(parts: Any, indent: str = " ") -> None: + """Print the parts of an LlmResponse in a readable format.""" + if not parts: + print(f"{indent}(no parts)") + return + for i, p in enumerate(parts): + thought = getattr(p, "thought", False) + text = getattr(p, "text", None) + fc = getattr(p, "function_call", None) + thought_sig = getattr(p, "thought_signature", None) + + if thought and text: + sig = f" (signature: {thought_sig!r})" if thought_sig else "" + print(f"{indent}[thought]{sig}") + for line in text.splitlines(): + print(f"{indent} {line}") + elif text: + print(f"{indent}[text]") + for line in text.splitlines(): + print(f"{indent} {line}") + elif fc: + print( + f"{indent}[function_call] {getattr(fc, 'name', '?')}" + f"({getattr(fc, 'args', {})})" + ) + else: + print(f"{indent}[unknown part] {type(p).__name__}") + + +async def run_non_streaming( + llm: PortkeyAdk, prompt: str, *, enable_thinking: bool = False +) -> None: + req = _build_request(llm.model, prompt, enable_thinking=enable_thinking) + async for resp in llm.generate_content_async(req, stream=False): + parts = getattr(resp.content, "parts", None) if resp.content else None + _print_parts(parts) + + +async def run_streaming( + llm: PortkeyAdk, prompt: str, *, enable_thinking: bool = False +) -> None: + req = _build_request(llm.model, prompt, enable_thinking=enable_thinking) + partial_count = 0 + async for resp in llm.generate_content_async(req, stream=True): + parts = getattr(resp.content, "parts", None) if resp.content else None + if getattr(resp, "partial", False): + partial_count += 1 + for p in parts or []: + text = getattr(p, "text", None) + thought = getattr(p, "thought", False) + if thought and text: + print(f" [thought delta] {text[:120]}", end="") + elif text: + print(text, end="") + else: + print() + print(f" --- final response ({partial_count} streaming chunks) ---") + _print_parts(parts) + + +async def main() -> None: + if not PORTKEY_API_KEY: + print("Set PORTKEY_API_KEY environment variable.") + sys.exit(1) + + for display_name, model_slug, supports_thinking in MODELS: + print(f"\n{'=' * 60}") + print(f" {display_name} ({model_slug})") + print(f"{'=' * 60}") + + llm = PortkeyAdk(model=model_slug, api_key=PORTKEY_API_KEY) + + print(f"\n Non-streaming | prompt: {SIMPLE_PROMPT!r}") + try: + await run_non_streaming(llm, SIMPLE_PROMPT) + except Exception: + traceback.print_exc() + + if supports_thinking: + print(f"\n Non-streaming + thinking | prompt: {REASONING_PROMPT!r}") + try: + await run_non_streaming(llm, REASONING_PROMPT, enable_thinking=True) + except Exception: + traceback.print_exc() + + print(f"\n Streaming | prompt: {SIMPLE_PROMPT!r}") + print(" ", end="") + try: + await run_streaming(llm, SIMPLE_PROMPT) + except Exception: + traceback.print_exc() + + if supports_thinking: + print(f"\n Streaming + thinking | prompt: {REASONING_PROMPT!r}") + print(" ", end="") + try: + await run_streaming(llm, REASONING_PROMPT, enable_thinking=True) + except Exception: + traceback.print_exc() + + print() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/examples/hello_world_portkey_adk.py b/examples/hello_world_portkey_adk.py index 1ebe6fea..56104553 100644 --- a/examples/hello_world_portkey_adk.py +++ b/examples/hello_world_portkey_adk.py @@ -42,10 +42,12 @@ async def main() -> None: req = build_request(model) final_text: List[str] = [] async for resp in llm.generate_content_async(req, stream=False): - if resp.content and getattr(resp.content, "parts", None): - for p in resp.content.parts: - if getattr(p, "text", None): - final_text.append(p.text) + parts = getattr(resp.content, "parts", None) if resp.content else None + if parts: + for p in parts: + text = getattr(p, "text", None) + if text: + final_text.append(text) print("".join(final_text)) diff --git a/portkey_ai/integrations/adk.py b/portkey_ai/integrations/adk.py index fdf9fa22..466fa160 100644 --- a/portkey_ai/integrations/adk.py +++ b/portkey_ai/integrations/adk.py @@ -9,220 +9,78 @@ pip install 'portkey-ai[adk]' Design: -- Keep this adapter tiny. Heavy lifting (OpenAI-compatible streaming, etc.) +- Keep this adapter tiny. Heavy lifting (Responses API transport, etc.) is already provided by Portkey's SDK. We only: - - Map ADK `google.genai.types.Content`/tools -> OpenAI-compatible request. - - Translate OpenAI-style stream chunks -> ADK `LlmResponse` objects. + - Map ADK `google.genai.types.Content`/tools -> Responses API input. + - Translate Responses output and streaming events -> ADK `LlmResponse` objects. - Expose a class compatible with ADK `BaseLlm` interface. """ from __future__ import annotations -from typing import ( - Any, - AsyncGenerator, - AsyncIterator, - Optional, - TYPE_CHECKING, - Iterable, - Union, - Tuple, - Generator, - Literal, - cast, -) -import json +from typing import TYPE_CHECKING, Any, AsyncGenerator, Iterable, Optional import base64 -import logging +import json from portkey_ai import AsyncPortkey -logger = logging.getLogger("portkey_ai.integrations.adk") - -if TYPE_CHECKING: # Only for static typing, never imported at runtime automatically - from google.adk.models.base_llm import BaseLlm as _AdkBaseLlm # type: ignore - from google.adk.models.llm_request import LlmRequest # type: ignore - from google.adk.models.llm_response import LlmResponse # type: ignore +if TYPE_CHECKING: + from google.adk.models.llm_request import LlmRequest # type: ignore # no py.typed + from google.adk.models.llm_response import LlmResponse # type: ignore[import-untyped] # no py.typed try: - # Attempt runtime import so we can subclass the ADK BaseLlm and construct ADK types. - from google.adk.models.base_llm import BaseLlm as _AdkBaseLlm # type: ignore + from google.adk.models.base_llm import BaseLlm as _AdkBaseLlm # type: ignore # no py.typed _HAS_ADK = True -except Exception: # pragma: no cover - when ADK not installed +except Exception: # pragma: no cover _HAS_ADK = False - class _AdkBaseLlm: # type: ignore - """Fallback to allow import of this module without ADK installed.""" - + class _AdkBaseLlm: # type: ignore[no-redef] # stub used when google-adk is absent pass -# ------------------------------- helpers ------------------------------------ - - -class _FunctionChunk: - def __init__( - self, - id: Optional[str], - name: Optional[str], - args: Optional[str], - index: Optional[int] = 0, - ) -> None: - self.id = id - self.name = name - self.args = args - self.index = index or 0 - - -class _TextChunk: - def __init__(self, text: str) -> None: - self.text = text - - -class _UsageMetadataChunk: - def __init__( - self, prompt_tokens: int, completion_tokens: int, total_tokens: int - ) -> None: - self.prompt_tokens = prompt_tokens - self.completion_tokens = completion_tokens - self.total_tokens = total_tokens - - def _safe_json_serialize(obj: Any) -> str: try: return json.dumps(obj, ensure_ascii=False) - except (TypeError, OverflowError): + except (TypeError, ValueError, OverflowError): return str(obj) -def _to_portkey_role(role: Optional[str]) -> Literal["user", "assistant"]: - if role in ["model", "assistant"]: +def _to_input_role(role: Optional[str], system_role: str) -> str: + if role in ("model", "assistant"): return "assistant" + if role in ("system", "developer"): + return system_role if system_role in ("system", "developer") else "developer" return "user" -def _get_content(parts: Iterable[Any]) -> Union[list[dict], str]: - """Convert ADK parts to Portkey/OpenAI-compatible content. +def _normalize_thought_signature(value: Any) -> Optional[str]: + """Normalize thought_signature to ``str``. - Note: we import google.genai.types lazily to avoid runtime import when ADK isn't installed. + Gemini returns raw ``bytes``; Responses API returns a base-64 ``str``. """ - content_objects: list[dict] = [] - # We treat `parts` as an iterable of objects with attributes: text, inline_data(data,mime_type) - for part in parts: - text = getattr(part, "text", None) - inline_data = getattr(part, "inline_data", None) - if text: - # Return simple string when it's a single text part - if isinstance(parts, list) and len(parts) == 1: - return text - content_objects.append({"type": "text", "text": text}) - elif ( - inline_data - and getattr(inline_data, "data", None) - and getattr(inline_data, "mime_type", None) - ): - b64 = base64.b64encode(inline_data.data).decode("utf-8") - data_uri = f"data:{inline_data.mime_type};base64,{b64}" - if inline_data.mime_type.startswith("image"): - content_objects.append( - {"type": "image_url", "image_url": {"url": data_uri}} - ) - elif inline_data.mime_type.startswith("video"): - content_objects.append( - {"type": "video_url", "video_url": {"url": data_uri}} - ) - elif inline_data.mime_type.startswith("audio"): - content_objects.append( - {"type": "audio_url", "audio_url": {"url": data_uri}} - ) - elif inline_data.mime_type == "application/pdf": - content_objects.append( - { - "type": "file", - "file": { - "file_data": data_uri, - "format": inline_data.mime_type, - }, - } - ) - else: - raise ValueError("Portkey(ADK) does not support this content part.") - return content_objects - - -def _content_to_message_param(content: Any) -> Union[dict, list[dict]]: - """Convert ADK `types.Content` to OpenAI-compatible message dict(s).""" - tool_messages: list[dict] = [] - for part in getattr(content, "parts", []) or []: - function_response = getattr(part, "function_response", None) - if function_response: - tool_messages.append( - { - "role": "tool", - "tool_call_id": getattr(function_response, "id", None), - "content": _safe_json_serialize( - getattr(function_response, "response", None) - ), - } - ) - if tool_messages: - return tool_messages if len(tool_messages) > 1 else tool_messages[0] - - role = _to_portkey_role(getattr(content, "role", None)) - message_content = _get_content(getattr(content, "parts", []) or []) or None - - if role == "user": - return {"role": "user", "content": message_content} + if value is None: + return None + if isinstance(value, bytes): + return base64.b64encode(value).decode("utf-8") + if isinstance(value, str): + return value + return None - # assistant/model - tool_calls: list[dict] = [] - content_present = False - for part in getattr(content, "parts", []) or []: - function_call = getattr(part, "function_call", None) - if function_call: - tool_calls.append( - { - "type": "function", - "id": getattr(function_call, "id", None), - "function": { - "name": getattr(function_call, "name", None), - "arguments": _safe_json_serialize( - getattr(function_call, "args", None) - ), - }, - } - ) - elif getattr(part, "text", None) or getattr(part, "inline_data", None): - content_present = True - final_content = message_content if content_present else None - if ( - isinstance(final_content, list) - and final_content - and final_content[0].get("type") == "text" - ): - # Some providers require a plain string when only a single text block - final_content = final_content[0].get("text", "") - - msg: dict[str, Any] = {"role": role, "content": final_content} - if tool_calls: - msg["tool_calls"] = tool_calls - return msg - - -def _schema_to_dict(schema: Any) -> dict: +def _schema_to_dict(schema: Any) -> dict[str, Any]: """Recursively convert ADK Schema to a plain JSON schema dict.""" schema_dict = schema.model_dump(exclude_none=True) if "type" in schema_dict: - t = schema_dict["type"] - schema_dict["type"] = (t.value if hasattr(t, "value") else t).lower() + schema_type = schema_dict["type"] + schema_dict["type"] = ( + schema_type.value if hasattr(schema_type, "value") else schema_type + ).lower() if "items" in schema_dict: items = schema_dict["items"] if isinstance(items, dict): - # Rebuild a Schema object then recurse try: - from google.genai import types as genai_types # type: ignore + from google.genai import types as genai_types # type: ignore[import-untyped] # no py.typed schema_dict["items"] = _schema_to_dict( genai_types.Schema.model_validate(items) @@ -246,8 +104,147 @@ def _schema_to_dict(schema: Any) -> dict: return schema_dict -def _function_declaration_to_tool_param(function_declaration: Any) -> dict: - """Convert ADK FunctionDeclaration to OpenAI tool param dict.""" +def _build_input_content(parts: Iterable[Any]) -> list[dict[str, Any]] | str: + content_objects: list[dict[str, Any]] = [] + for part in parts: + text = getattr(part, "text", None) + inline_data = getattr(part, "inline_data", None) + if text: + if isinstance(parts, list) and len(parts) == 1: + return text + content_objects.append({"type": "input_text", "text": text}) + continue + if not ( + inline_data + and getattr(inline_data, "data", None) + and getattr(inline_data, "mime_type", None) + ): + continue + + b64 = base64.b64encode(inline_data.data).decode("utf-8") + mime_type = inline_data.mime_type + data_uri = f"data:{mime_type};base64,{b64}" + if mime_type.startswith("image"): + content_objects.append( + {"type": "input_image", "image_url": data_uri, "detail": "auto"} + ) + elif mime_type.startswith("audio"): + audio_format = mime_type.split("/")[-1].lower() + if audio_format not in ("mp3", "wav"): + raise ValueError( + "Portkey(ADK) Responses adapter only supports mp3/wav audio inputs." + ) + content_objects.append( + { + "type": "input_audio", + "input_audio": {"data": b64, "format": audio_format}, + } + ) + elif mime_type.startswith("video"): + raise ValueError( + "Portkey(ADK) Responses adapter does not support video input." + ) + else: + content_objects.append( + { + "type": "input_file", + "file_data": data_uri, + "filename": "attachment", + } + ) + return content_objects + + +def _content_to_input_items(content: Any, system_role: str) -> list[dict[str, Any]]: + items: list[dict[str, Any]] = [] + + tool_outputs = [] + for part in getattr(content, "parts", None) or []: + function_response = getattr(part, "function_response", None) + if not function_response: + continue + call_id = getattr(function_response, "id", None) + if not call_id: + raise ValueError( + "FunctionResponse is missing 'id' — cannot build call_id " + "for function_call_output item in Responses API request." + ) + tool_outputs.append( + { + "type": "function_call_output", + "call_id": call_id, + "output": _safe_json_serialize( + getattr(function_response, "response", None) + ), + } + ) + if tool_outputs: + return tool_outputs + + role = _to_input_role(getattr(content, "role", None), system_role) + input_parts = getattr(content, "parts", None) or [] + + text_or_media_present = False + for part in input_parts: + function_call = getattr(part, "function_call", None) + if function_call: + fc_name = getattr(function_call, "name", None) + fc_id = getattr(function_call, "id", None) + if not fc_name: + raise ValueError( + "FunctionCall is missing 'name' — cannot build " + "function_call item for Responses API request." + ) + if not fc_id: + raise ValueError( + "FunctionCall is missing 'id' — cannot build " + "call_id for function_call item in Responses API request." + ) + items.append( + { + "type": "function_call", + "id": fc_id, + "call_id": fc_id, + "name": fc_name, + "arguments": _safe_json_serialize( + getattr(function_call, "args", None) + ), + } + ) + continue + + text = getattr(part, "text", None) + inline_data = getattr(part, "inline_data", None) + thought = getattr(part, "thought", False) + thought_signature = _normalize_thought_signature( + getattr(part, "thought_signature", None) + ) + + if thought and text: + reasoning_item: dict[str, Any] = { + "id": thought_signature or f"reasoning_{len(items)}", + "type": "reasoning", + "summary": [], + "content": [{"type": "reasoning_text", "text": text}], + } + if thought_signature: + reasoning_item["encrypted_content"] = thought_signature + items.append(reasoning_item) + continue + + if text or inline_data: + text_or_media_present = True + + if text_or_media_present: + message_parts_or_text = _build_input_content(input_parts) + items.append( + {"type": "message", "role": role, "content": message_parts_or_text} + ) + + return items + + +def _function_declaration_to_tool_param(function_declaration: Any) -> dict[str, Any]: name = getattr(function_declaration, "name", None) assert name @@ -257,167 +254,37 @@ def _function_declaration_to_tool_param(function_declaration: Any) -> dict: for key, value in params.properties.items(): properties[key] = _schema_to_dict(value) - tool = { + tool: dict[str, Any] = { "type": "function", - "function": { - "name": name, - "description": getattr(function_declaration, "description", "") or "", - "parameters": { - "type": "object", - "properties": properties, - }, - }, + "name": name, + "description": getattr(function_declaration, "description", "") or "", + "parameters": {"type": "object", "properties": properties}, + "strict": True, } - if params and getattr(params, "required", None): - # Help mypy understand nested dict mutation - params_dict = cast(dict, tool["function"]["parameters"]) # type: ignore[index] - params_dict["required"] = params.required - + tool["parameters"]["required"] = params.required return tool -def _model_response_to_chunk( - response: Any, -) -> Generator[ - Tuple[ - Optional[Union[_TextChunk, _FunctionChunk, _UsageMetadataChunk]], Optional[str] - ], - None, - None, +def _get_response_inputs( + llm_request: "LlmRequest", system_role: str = "developer" +) -> tuple[ + list[dict[str, Any]], + Optional[list[dict[str, Any]]], + Optional[dict[str, Any]], + Optional[str], ]: - """Convert Portkey ChatCompletion response/chunk to ADK-friendly chunks.""" - message = None - finish_reason = None - if getattr(response, "choices", None): - choice0 = response.choices[0] - finish_reason = getattr(choice0, "finish_reason", None) - if getattr(choice0, "delta", None): - message = choice0.delta - elif getattr(choice0, "message", None): - message = choice0.message - - if message: - if getattr(message, "content", None): - yield _TextChunk(text=message.content), finish_reason - - tool_calls = getattr(message, "tool_calls", None) - if tool_calls: - for tool_call in tool_calls: - if getattr(tool_call, "type", None) == "function": - yield _FunctionChunk( - id=getattr(tool_call, "id", None), - name=getattr( - getattr(tool_call, "function", None), "name", None - ), - args=getattr( - getattr(tool_call, "function", None), "arguments", None - ), - index=getattr(tool_call, "index", 0), - ), finish_reason - - if finish_reason and not ( - (getattr(message, "content", None)) - or (getattr(message, "tool_calls", None)) - ): - yield None, finish_reason - - if not message: - yield None, None - - usage = getattr(response, "usage", None) - if usage: - yield _UsageMetadataChunk( - prompt_tokens=getattr(usage, "prompt_tokens", 0), - completion_tokens=getattr(usage, "completion_tokens", 0), - total_tokens=getattr(usage, "total_tokens", 0), - ), None - - -def _message_to_generate_content_response(message: Any, is_partial: bool = False) -> "LlmResponse": # type: ignore[name-defined] - """Convert a Portkey-style message object to ADK LlmResponse.""" - from google.genai import types as genai_types # type: ignore - from google.adk.models.llm_response import LlmResponse # type: ignore - - parts: list[Any] = [] - if getattr(message, "content", None): - parts.append(genai_types.Part.from_text(text=message.content)) - - if getattr(message, "tool_calls", None): - for tool_call in message.tool_calls: - if getattr(tool_call, "type", None) == "function": - part = genai_types.Part.from_function_call( - name=getattr(getattr(tool_call, "function", None), "name", None), # type: ignore[arg-type] - args=json.loads( - getattr(getattr(tool_call, "function", None), "arguments", "{}") - or "{}" - ), - ) - # Attach tool_call id if present - try: - part.function_call.id = getattr(tool_call, "id", None) # type: ignore[union-attr] - except Exception: - pass - parts.append(part) - - return LlmResponse( - content=genai_types.Content(role="model", parts=parts), partial=is_partial - ) - - -def _model_response_to_generate_content_response(response: Any) -> "LlmResponse": # type: ignore[name-defined] - from google.genai import types as genai_types # type: ignore - - message = None - if getattr(response, "choices", None): - message = response.choices[0].message - - if not message: - raise ValueError("No message in response") - - llm_response = _message_to_generate_content_response(message) - usage = getattr(response, "usage", None) - if usage: - llm_response.usage_metadata = genai_types.GenerateContentResponseUsageMetadata( - prompt_token_count=getattr(usage, "prompt_tokens", 0), - candidates_token_count=getattr(usage, "completion_tokens", 0), - total_token_count=getattr(usage, "total_tokens", 0), - ) - return llm_response - + input_items: list[dict[str, Any]] = [] + for content in getattr(llm_request, "contents", None) or []: + input_items.extend(_content_to_input_items(content, system_role)) -def _get_completion_inputs( - llm_request: "LlmRequest", system_role: str = "developer" -) -> tuple[list[dict], Optional[list[dict]], Optional[dict]]: # type: ignore[name-defined] - """Convert ADK LlmRequest into OpenAI-compatible inputs for Portkey. - - Args: - llm_request: The ADK request object. - system_role: Which role to use for the system instruction. One of - "developer" (default, aligned with ADK/LiteLLM) or "system" - (for providers that strictly expect a system role). - """ - # 1) Messages - messages: list[dict] = [] - for content in getattr(llm_request, "contents", []) or []: - msg_or_list = _content_to_message_param(content) - if isinstance(msg_or_list, list): - messages.extend(msg_or_list) - elif msg_or_list: - messages.append(msg_or_list) - - # Insert system/developer instruction config = getattr(llm_request, "config", None) system_instruction = getattr(config, "system_instruction", None) if config else None - if system_instruction: - role = system_role if system_role in ("developer", "system") else "developer" - messages.insert(0, {"role": role, "content": system_instruction}) + instructions = str(system_instruction) if system_instruction else None - # 2) Tools - tools: Optional[list[dict]] = None + tools: Optional[list[dict[str, Any]]] = None if config and getattr(config, "tools", None): - # Avoid indexing a Collection directly; use next(iter(...)) for mypy compatibility - tool0 = next(iter(getattr(config, "tools", []) or []), None) + tool0 = next(iter(getattr(config, "tools", None) or []), None) function_declarations = ( getattr(tool0, "function_declarations", None) if tool0 else None ) @@ -426,196 +293,290 @@ def _get_completion_inputs( _function_declaration_to_tool_param(fd) for fd in function_declarations ] - # 3) Response format (convert ADK schema to OpenAI json_schema where possible) - response_format: Optional[dict] = None + text_config: Optional[dict[str, Any]] = None response_schema = getattr(config, "response_schema", None) if config else None if response_schema: try: json_schema = _schema_to_dict(response_schema) - response_format = { - "type": "json_schema", - "json_schema": {"name": "adk_response", "schema": json_schema}, + text_config = { + "format": { + "type": "json_schema", + "name": "adk_response", + "schema": json_schema, + "strict": True, + } } except Exception: - # Best effort: ignore if schema cannot be converted - response_format = None + text_config = None + + return input_items, tools, text_config, instructions + + +def _get_reasoning_config(llm_request: "LlmRequest") -> Optional[dict[str, Any]]: + """Map ADK thinking_config to Portkey Responses ``reasoning`` dict. + + The Portkey Responses API accepts ``reasoning.effort`` as a string + ("low" | "medium" | "high") for **all** providers — Portkey translates + this to each provider's native reasoning mechanism server-side (e.g. + Anthropic's ``budget_tokens``, Google's ``thinking_budget``). + + ADK exposes an integer ``thinking_budget`` (token count), so we map it + to the closest effort level. The thresholds below are best-effort + heuristics; some models may not support every effort level (e.g. + GPT-5.3 only supports "medium"). In those cases the provider returns + a clear error indicating the supported values. + """ + config = getattr(llm_request, "config", None) + thinking_config = getattr(config, "thinking_config", None) if config else None + if not thinking_config or not getattr(thinking_config, "include_thoughts", None): + return None + + budget = getattr(thinking_config, "thinking_budget", None) + effort = "medium" + if budget is not None: + if budget <= 1024: + effort = "low" + elif budget <= 4096: + effort = "medium" + else: + effort = "high" + return {"effort": effort, "summary": "auto"} + + +def _usage_to_metadata(usage: Any) -> Any: + from google.genai import types as genai_types # type: ignore[import-untyped] # no py.typed + + if not usage: + return None + return genai_types.GenerateContentResponseUsageMetadata( + prompt_token_count=getattr(usage, "input_tokens", 0), + candidates_token_count=getattr(usage, "output_tokens", 0), + total_token_count=getattr(usage, "total_tokens", 0), + ) - return messages, tools, response_format +def _build_text_part( + text: str, thought: bool = False, thought_signature: Optional[str] = None +) -> Any: + from google.genai import types as genai_types # type: ignore[import-untyped] # no py.typed -# ----------------------------- main adapter --------------------------------- + part = genai_types.Part.from_text(text=text) + if thought: + part.thought = True + if thought_signature: + part.thought_signature = thought_signature # type: ignore[assignment] # Part.thought_signature typed as bytes but Responses API uses str + return part + + +def _function_call_to_part(item: Any) -> Any: + from google.genai import types as genai_types # type: ignore[import-untyped] # no py.typed + + args_str = getattr(item, "arguments", "{}") or "{}" + name = getattr(item, "name", None) or "function_call" + try: + parsed_args = json.loads(args_str) + except json.JSONDecodeError: + parsed_args = {} + part = genai_types.Part.from_function_call( + name=name, + args=parsed_args, + ) + function_call = getattr(part, "function_call", None) + if function_call is not None: + function_call.id = getattr(item, "call_id", None) or getattr(item, "id", None) + return part -class PortkeyAdk(_AdkBaseLlm): # type: ignore[misc] +def _response_output_to_parts(output: Iterable[Any]) -> list[Any]: + parts: list[Any] = [] + for item in output: + item_type = getattr(item, "type", None) + if item_type == "reasoning": + thought_signature = _normalize_thought_signature( + getattr(item, "encrypted_content", None) + ) + for content in getattr(item, "content", None) or []: + if getattr(content, "type", None) == "reasoning_text": + parts.append( + _build_text_part( + getattr(content, "text", ""), + thought=True, + thought_signature=thought_signature, + ) + ) + if not getattr(item, "content", None): + for summary in getattr(item, "summary", None) or []: + parts.append( + _build_text_part(getattr(summary, "text", ""), thought=True) + ) + elif item_type == "message": + for content in getattr(item, "content", None) or []: + content_type = getattr(content, "type", None) + if content_type == "output_text": + parts.append(_build_text_part(getattr(content, "text", ""))) + elif content_type == "refusal": + parts.append(_build_text_part(getattr(content, "refusal", ""))) + elif item_type == "function_call": + parts.append(_function_call_to_part(item)) + return parts + + +def _response_to_llm_response(response: Any, is_partial: bool = False) -> "LlmResponse": + from google.adk.models.llm_response import LlmResponse # type: ignore[import-untyped] # no py.typed + from google.genai import types as genai_types # type: ignore[import-untyped] # no py.typed + + llm_response = LlmResponse( + content=genai_types.Content( + role="model", + parts=_response_output_to_parts(getattr(response, "output", None) or []), + ), + partial=is_partial, + ) + usage_metadata = _usage_to_metadata(getattr(response, "usage", None)) + if usage_metadata is not None: + llm_response.usage_metadata = usage_metadata + return llm_response + + +def _parts_to_llm_response( + parts: list[Any], is_partial: bool = False, usage: Any = None +) -> "LlmResponse": + from google.adk.models.llm_response import LlmResponse # type: ignore[import-untyped] # no py.typed + from google.genai import types as genai_types # type: ignore[import-untyped] # no py.typed + + llm_response = LlmResponse( + content=genai_types.Content(role="model", parts=parts), + partial=is_partial, + ) + usage_metadata = _usage_to_metadata(usage) + if usage_metadata is not None: + llm_response.usage_metadata = usage_metadata + return llm_response + + +class PortkeyAdk(_AdkBaseLlm): # type: ignore[misc] # _AdkBaseLlm may be a stub when google-adk is absent """ADK `BaseLlm` adapter backed by Portkey Async client.""" - def __init__(self, model: str, api_key: Optional[str] = None, **kwargs: Any) -> None: # type: ignore[override] + def __init__( + self, model: str, api_key: Optional[str] = None, **kwargs: Any + ) -> None: # type: ignore[override] # BaseLlm.__init__ has a different signature if not _HAS_ADK: raise ImportError( "google-adk is not installed. Install with: pip install 'portkey-ai[adk]'" ) - # Initialize ADK BaseLlm (Pydantic BaseModel) so `model` is recorded - # Extract system instruction role preference before BaseLlm init - sys_role = str(kwargs.pop("system_role", "developer")).lower() - self._system_role = ( - sys_role if sys_role in ("developer", "system") else "developer" - ) - super().__init__(model=model, **{k: v for k, v in kwargs.items() if k != "model"}) # type: ignore[misc] + sys_role = str(kwargs.pop("system_role", "developer")).lower() - # Set up Portkey client + # Extract Portkey client kwargs before passing the rest to BaseLlm. client_args: dict[str, Any] = {} if api_key: client_args["api_key"] = api_key - # Common options: virtual key / base_url / provider + Authorization, etc. - if "virtual_key" in kwargs: - client_args["virtual_key"] = kwargs.pop("virtual_key") - if "base_url" in kwargs: - client_args["base_url"] = kwargs.pop("base_url") - if "config" in kwargs: - client_args["config"] = kwargs.pop("config") - if "provider" in kwargs: - client_args["provider"] = kwargs.pop("provider") - if "Authorization" in kwargs: - client_args["Authorization"] = kwargs.pop("Authorization") - - self._client = AsyncPortkey(**client_args) # type: ignore[arg-type] - - # Remaining args are passed through to completion calls (temperature, top_p, etc.) - self._additional_args: dict[str, Any] = dict(kwargs) - # Guard against reserved keys managed by us - self._additional_args.pop("messages", None) - self._additional_args.pop("tools", None) - self._additional_args.pop("stream", None) - - async def generate_content_async(self, llm_request: "LlmRequest", stream: bool = False) -> AsyncGenerator["LlmResponse", None]: # type: ignore[override,name-defined] - """Generate ADK LlmResponse objects using Portkey Chat Completions.""" - # Use ADK BaseLlm helper to ensure a user message exists so model can respond - self._maybe_append_user_content(llm_request) - - messages, tools, response_format = _get_completion_inputs( - llm_request, getattr(self, "_system_role", "developer") + _CLIENT_KEYS = ( + "virtual_key", + "base_url", + "config", + "provider", + "Authorization", + ) + for key in _CLIENT_KEYS: + if key in kwargs: + client_args[key] = kwargs.pop(key) + client_args["strict_open_ai_compliance"] = kwargs.pop( + "strict_open_ai_compliance", False ) - completion_args: dict[str, Any] = { - "model": getattr(self, "model", None), - "messages": messages, - "tools": tools, - # Only include response_format if we successfully converted it - **({"response_format": response_format} if response_format else {}), + base_kwargs = {k: v for k, v in kwargs.items() if k != "model"} + try: + super().__init__(model=model, **base_kwargs) # type: ignore[misc,call-arg] # BaseLlm may be the stub; real BaseLlm signature varies by adk version + except TypeError: + super().__init__() # type: ignore[misc,call-arg] # fallback for older BaseLlm that takes no args + self.model: str = model # type: ignore[assignment] # BaseLlm declares model as a Pydantic field with different type + + # Must be set AFTER super().__init__() -- Pydantic resets __dict__. + self._system_role: str = ( + sys_role if sys_role in ("developer", "system") else "developer" + ) + + self._client = AsyncPortkey(**client_args) # type: ignore[arg-type] # client_args values are Any; AsyncPortkey expects specific types + + _RESERVED_ARG_KEYS = {"input", "tools", "stream", "text", "reasoning"} + self._additional_args: dict[str, Any] = { + k: v for k, v in kwargs.items() if k not in _RESERVED_ARG_KEYS } - completion_args.update(self._additional_args) - if tools and "tool_choice" not in completion_args: - # Encourage tool use when functions are provided, mirroring Strands behavior - completion_args["tool_choice"] = "auto" - - if stream: - # Aggregate streaming text and tool calls to yield ADK LlmResponse objects - text_accum = "" - function_calls: dict[int, dict[str, Any]] = {} - fallback_index = 0 - usage_metadata = None - aggregated_llm_response = None - aggregated_llm_response_with_tool_call = None - - # Await the creation to obtain an async iterator for streaming - stream_obj = await self._client.chat.completions.create(stream=True, **completion_args) # type: ignore[arg-type] - stream_iter = cast(AsyncIterator[Any], stream_obj) - async for part in stream_iter: - for chunk, finish_reason in _model_response_to_chunk(part): - if isinstance(chunk, _FunctionChunk): - idx = chunk.index or fallback_index - if idx not in function_calls: - function_calls[idx] = {"name": "", "args": "", "id": None} - if chunk.name: - function_calls[idx]["name"] += chunk.name - if chunk.args: - function_calls[idx]["args"] += chunk.args - # If args parses as JSON, advance fallback index (handles providers that omit indices) - try: - json.loads(function_calls[idx]["args"]) - fallback_index += 1 - except json.JSONDecodeError: - pass - function_calls[idx]["id"] = ( - chunk.id or function_calls[idx]["id"] or str(idx) - ) - elif isinstance(chunk, _TextChunk): - text_accum += chunk.text - # Yield partials for better interactivity - yield _message_to_generate_content_response( - type( - "Msg", (), {"content": chunk.text, "tool_calls": None} - )(), - is_partial=True, - ) - elif isinstance(chunk, _UsageMetadataChunk): - from google.genai import types as genai_types # type: ignore - - usage_metadata = ( - genai_types.GenerateContentResponseUsageMetadata( - prompt_token_count=chunk.prompt_tokens, - candidates_token_count=chunk.completion_tokens, - total_token_count=chunk.total_tokens, - ) - ) - if (finish_reason in ("tool_calls", "stop")) and function_calls: - # Flush tool calls as a single LlmResponse - tool_calls = [] - for idx, data in function_calls.items(): - if data.get("id"): - tool_calls.append( - type( - "ToolCall", - (), - { - "type": "function", - "id": data["id"], - "function": type( - "Function", - (), - { - "name": data["name"], - "arguments": data["args"], - }, - )(), - "index": idx, - }, - )() - ) - aggregated_llm_response_with_tool_call = ( - _message_to_generate_content_response( - type( - "Msg", (), {"content": "", "tool_calls": tool_calls} - )() - ) - ) - function_calls.clear() - elif finish_reason == "stop" and text_accum: - aggregated_llm_response = _message_to_generate_content_response( - type( - "Msg", (), {"content": text_accum, "tool_calls": None} - )() - ) - text_accum = "" - - # End of stream: yield aggregated responses (attach usage if available) - if aggregated_llm_response: - if usage_metadata is not None: - aggregated_llm_response.usage_metadata = usage_metadata - usage_metadata = None - yield aggregated_llm_response - - if aggregated_llm_response_with_tool_call: - if usage_metadata is not None: - aggregated_llm_response_with_tool_call.usage_metadata = ( - usage_metadata + async def generate_content_async( + self, llm_request: "LlmRequest", stream: bool = False + ) -> AsyncGenerator["LlmResponse", None]: # type: ignore[override,name-defined] # return type differs from BaseLlm; LlmResponse is a forward ref + maybe_append_user_content = getattr(self, "_maybe_append_user_content", None) + if callable(maybe_append_user_content): + maybe_append_user_content(llm_request) + + input_items, tools, text_config, instructions = _get_response_inputs( + llm_request, self._system_role + ) + reasoning = _get_reasoning_config(llm_request) + + response_args: dict[str, Any] = { + "model": self.model, + "input": input_items, + **({"tools": tools} if tools else {}), + **({"text": text_config} if text_config else {}), + **({"instructions": instructions} if instructions else {}), + } + if reasoning: + response_args["reasoning"] = reasoning + response_args.setdefault("include", ["reasoning.encrypted_content"]) + + response_args.update(self._additional_args) + + # Ensure adapter-required include entries aren't dropped by user overrides. + if reasoning: + include = response_args.get("include") or [] + if not isinstance(include, list): + include = [include] + if "reasoning.encrypted_content" not in include: + include.append("reasoning.encrypted_content") + response_args["include"] = include + if tools and "tool_choice" not in response_args: + response_args["tool_choice"] = "auto" + + if not stream: + response = await self._client.responses.create(**response_args) + yield _response_to_llm_response(response) + return + + stream_response = await self._client.responses.create( + stream=True, **response_args + ) + final_response: Any = None + + async for event in stream_response: # type: ignore[union-attr] # create() returns union of Response|AsyncStream; we know it's AsyncStream here + event_type: str | None = getattr(event, "type", None) + + if event_type == "response.reasoning_text.delta": + delta: str = getattr(event, "delta", "") + if delta: + yield _parts_to_llm_response( + [_build_text_part(delta, thought=True)], + is_partial=True, ) - yield aggregated_llm_response_with_tool_call - else: - response = await self._client.chat.completions.create(**completion_args) - yield _model_response_to_generate_content_response(response) + + elif event_type == "response.output_text.delta": + delta = getattr(event, "delta", "") + if delta: + yield _parts_to_llm_response( + [_build_text_part(delta)], + is_partial=True, + ) + + # Function-call deltas are not yielded as partials; the full + # function_call item is emitted via the response.completed event. + + elif event_type == "response.completed": + final_response = getattr(event, "response", None) + + if final_response is not None: + yield _response_to_llm_response(final_response) __all__ = ["PortkeyAdk"] diff --git a/tests/integrations/test_adk_adapter.py b/tests/integrations/test_adk_adapter.py index f244392b..df826434 100644 --- a/tests/integrations/test_adk_adapter.py +++ b/tests/integrations/test_adk_adapter.py @@ -1,48 +1,144 @@ -import asyncio from typing import Any, AsyncIterator, Optional import pytest -# Skip these tests unless google-adk is installed pytest.importorskip("google.adk", reason="google-adk extra not installed") from google.adk.models.llm_request import LlmRequest # type: ignore from google.genai import types as genai_types # type: ignore -from portkey_ai.integrations.adk import PortkeyAdk +from portkey_ai.integrations.adk import ( + PortkeyAdk, + _normalize_thought_signature, + _get_reasoning_config, +) -class _FakeDelta: - def __init__(self, content: Optional[str] = None): - self.content = content +class _FakeReasoningContent: + def __init__(self, text: str): + self.type = "reasoning_text" + self.text = text -class _FakeMessage: - def __init__(self, content: Optional[str] = None): - self.content = content - self.tool_calls = None +class _FakeReasoningSummary: + def __init__(self, text: str): + self.type = "summary_text" + self.text = text -class _FakeChoice: +class _FakeReasoningItem: def __init__( self, - delta: Optional[_FakeDelta] = None, - message: Optional[_FakeMessage] = None, - finish_reason: Optional[str] = None, + text: Optional[str] = None, + encrypted_content: Optional[str] = None, + summary_text: Optional[str] = None, + item_id: str = "rs_1", ): - self.delta = delta - self.message = message - self.finish_reason = finish_reason + self.type = "reasoning" + self.id = item_id + self.encrypted_content = encrypted_content + self.content = [_FakeReasoningContent(text)] if text is not None else None + self.summary = [_FakeReasoningSummary(summary_text)] if summary_text else [] + + +class _FakeOutputText: + def __init__(self, text: str): + self.type = "output_text" + self.text = text + + +class _FakeRefusal: + def __init__(self, refusal: str): + self.type = "refusal" + self.refusal = refusal + + +class _FakeMessage: + def __init__(self, text: Optional[str] = None, refusal: Optional[str] = None): + self.type = "message" + self.id = "msg_1" + self.role = "assistant" + self.status = "completed" + self.content = [] + if text is not None: + self.content.append(_FakeOutputText(text)) + if refusal is not None: + self.content.append(_FakeRefusal(refusal)) + + +class _FakeFunctionCall: + def __init__(self, name: str, arguments: str, call_id: str = "call_1"): + self.type = "function_call" + self.id = call_id + self.call_id = call_id + self.name = name + self.arguments = arguments + self.status = "completed" + + +class _FakeUsage: + def __init__( + self, input_tokens: int = 1, output_tokens: int = 2, total_tokens: int = 3 + ): + self.input_tokens = input_tokens + self.output_tokens = output_tokens + self.total_tokens = total_tokens class _FakeResponse: - def __init__(self, message_text: str): - self.choices = [ - _FakeChoice(message=_FakeMessage(message_text), finish_reason="stop") - ] - self.usage = type( - "Usage", (), {"prompt_tokens": 1, "completion_tokens": 2, "total_tokens": 3} - )() + def __init__(self, output: list[Any], usage: Optional[_FakeUsage] = None): + self.output = output + self.usage = usage or _FakeUsage() + + +class _FakeReasoningDeltaEvent: + def __init__(self, delta: str, item_id: str = "rs_1"): + self.type = "response.reasoning_text.delta" + self.delta = delta + self.item_id = item_id + + +class _FakeTextDeltaEvent: + def __init__(self, delta: str, item_id: str = "msg_1", content_index: int = 0): + self.type = "response.output_text.delta" + self.delta = delta + self.item_id = item_id + self.content_index = content_index + + +class _FakeFunctionArgsDeltaEvent: + def __init__(self, delta: str, item_id: str = "call_1"): + self.type = "response.function_call_arguments.delta" + self.delta = delta + self.item_id = item_id + + +class _FakeFunctionArgsDoneEvent: + def __init__(self, arguments: str, name: str, item_id: str = "call_1"): + self.type = "response.function_call_arguments.done" + self.arguments = arguments + self.name = name + self.item_id = item_id + + +class _FakeOutputItemAddedEvent: + def __init__(self, item: Any, output_index: int = 0): + self.type = "response.output_item.added" + self.item = item + self.output_index = output_index + + +class _FakeOutputItemDoneEvent: + def __init__(self, item: Any, output_index: int = 0): + self.type = "response.output_item.done" + self.item = item + self.output_index = output_index + + +class _FakeCompletedEvent: + def __init__(self, response: _FakeResponse): + self.type = "response.completed" + self.response = response def _build_request(model: str, text: str = "Hello") -> LlmRequest: @@ -62,60 +158,223 @@ async def test_non_streaming_simple(monkeypatch: pytest.MonkeyPatch) -> None: llm = PortkeyAdk(model="@openai/gpt-4o-mini", api_key="test") async def fake_create(**kwargs: Any) -> _FakeResponse: - assert not kwargs.get("stream"), "Non-streaming path should not request stream" - return _FakeResponse(message_text="Hello world!") + assert kwargs["input"][0]["content"] == "test" + return _FakeResponse(output=[_FakeMessage(text="Hello world!")]) - # Patch the underlying client create method - monkeypatch.setattr(llm._client.chat.completions, "create", fake_create) # type: ignore[attr-defined] + monkeypatch.setattr(llm._client.responses, "create", fake_create) # type: ignore[attr-defined] req = _build_request(model="@openai/gpt-4o-mini", text="test") outputs: list[str] = [] async for resp in llm.generate_content_async(req, stream=False): assert not getattr(resp, "partial", False) - assert resp.content and resp.content.parts - for p in resp.content.parts: - if getattr(p, "text", None): - outputs.append(p.text) - assert "".join(outputs).strip() == "Hello world!" + parts = getattr(getattr(resp, "content", None), "parts", []) or [] + for p in parts: + text = getattr(p, "text", None) + if text: + outputs.append(text) + assert "".join(outputs) == "Hello world!" + + +@pytest.mark.asyncio +async def test_non_streaming_with_reasoning_and_function_call( + monkeypatch: pytest.MonkeyPatch, +) -> None: + llm = PortkeyAdk(model="@openai/gpt-4o-mini", api_key="test") + + async def fake_create(**kwargs: Any) -> _FakeResponse: + assert kwargs["reasoning"]["summary"] == "auto" + assert kwargs["include"] == ["reasoning.encrypted_content"] + return _FakeResponse( + output=[ + _FakeReasoningItem( + text="Let me think.", + encrypted_content="sig123", + ), + _FakeMessage(text="Final answer."), + _FakeFunctionCall("lookup_weather", '{"city":"SF"}'), + ] + ) + + monkeypatch.setattr(llm._client.responses, "create", fake_create) # type: ignore[attr-defined] + + req = LlmRequest( + model="@openai/gpt-4o-mini", + contents=[ + genai_types.Content( + role="user", + parts=[genai_types.Part.from_text(text="test")], + ) + ], + config=genai_types.GenerateContentConfig( + thinking_config=genai_types.ThinkingConfig(include_thoughts=True) + ), + ) + + thoughts: list[str] = [] + thought_signatures: list[str] = [] + texts: list[str] = [] + function_names: list[str] = [] + async for resp in llm.generate_content_async(req, stream=False): + parts = getattr(getattr(resp, "content", None), "parts", []) or [] + for p in parts: + text = getattr(p, "text", None) + thought_signature = getattr(p, "thought_signature", None) + function_call = getattr(p, "function_call", None) + if getattr(p, "thought", False) and text: + thoughts.append(text) + if isinstance(thought_signature, str): + thought_signatures.append(thought_signature) + elif function_call is not None and getattr(function_call, "name", None): + function_names.append(function_call.name) + elif text: + texts.append(text) + + assert "".join(thoughts) == "Let me think." + assert thought_signatures == ["sig123"] + assert "".join(texts) == "Final answer." + assert function_names == ["lookup_weather"] @pytest.mark.asyncio -async def test_streaming_accumulates_and_final(monkeypatch: pytest.MonkeyPatch) -> None: +async def test_streaming_yields_partials_and_final( + monkeypatch: pytest.MonkeyPatch, +) -> None: llm = PortkeyAdk(model="@openai/gpt-4o-mini", api_key="test") - part1 = type("Chunk", (), {"choices": [_FakeChoice(delta=_FakeDelta("Hello "))]})() - part2 = type("Chunk", (), {"choices": [_FakeChoice(delta=_FakeDelta("world!"))]})() - part3 = type( - "Chunk", - (), - {"choices": [_FakeChoice(message=_FakeMessage(None), finish_reason="stop")]}, - )() + final_response = _FakeResponse( + output=[ + _FakeReasoningItem(text="Thinking...", item_id="rs_1"), + _FakeMessage(text="Answer: 42"), + ] + ) async def fake_stream_gen() -> AsyncIterator[Any]: - yield part1 - yield part2 - yield part3 + yield _FakeOutputItemAddedEvent(_FakeReasoningItem(text=None, item_id="rs_1")) + yield _FakeReasoningDeltaEvent("Thinking...") + yield _FakeTextDeltaEvent("Answer: ") + yield _FakeTextDeltaEvent("42") + yield _FakeCompletedEvent(final_response) - async def fake_create(**kwargs: Any) -> AsyncIterator[Any] | _FakeResponse: - if kwargs.get("stream"): - return fake_stream_gen() - return _FakeResponse(message_text="unused") + async def fake_create(**kwargs: Any) -> AsyncIterator[Any]: + assert kwargs["stream"] is True + return fake_stream_gen() - monkeypatch.setattr(llm._client.chat.completions, "create", fake_create) # type: ignore[attr-defined] + monkeypatch.setattr(llm._client.responses, "create", fake_create) # type: ignore[attr-defined] req = _build_request(model="@openai/gpt-4o-mini", text="test") - partial_text = [] - final_text = [] + partial_thoughts: list[str] = [] + partial_text: list[str] = [] + final_thoughts: list[str] = [] + final_text: list[str] = [] async for resp in llm.generate_content_async(req, stream=True): - assert resp.content and resp.content.parts - text_parts = [p.text for p in resp.content.parts if getattr(p, "text", None)] - if getattr(resp, "partial", False): - partial_text.extend(text_parts) - else: - final_text.extend(text_parts) - - # Partial updates should reflect the stream pieces - assert "".join(partial_text) == "Hello world!" - # Final message should be aggregated once - assert "".join(final_text) == "Hello world!" + parts = getattr(getattr(resp, "content", None), "parts", []) or [] + for p in parts: + text = getattr(p, "text", None) + if getattr(p, "thought", False) and text: + if getattr(resp, "partial", False): + partial_thoughts.append(text) + else: + final_thoughts.append(text) + elif text: + if getattr(resp, "partial", False): + partial_text.append(text) + else: + final_text.append(text) + + assert "".join(partial_thoughts) == "Thinking..." + assert "".join(partial_text) == "Answer: 42" + assert "".join(final_thoughts) == "Thinking..." + assert "".join(final_text) == "Answer: 42" + + +@pytest.mark.asyncio +async def test_streaming_function_call_yields_tool_response( + monkeypatch: pytest.MonkeyPatch, +) -> None: + llm = PortkeyAdk(model="@openai/gpt-4o-mini", api_key="test") + function_item = _FakeFunctionCall("lookup_weather", '{"city":"SF"}') + final_response = _FakeResponse(output=[function_item]) + + async def fake_stream_gen() -> AsyncIterator[Any]: + yield _FakeOutputItemAddedEvent(function_item) + yield _FakeFunctionArgsDeltaEvent('{"city":') + yield _FakeFunctionArgsDoneEvent('{"city":"SF"}', "lookup_weather") + yield _FakeOutputItemDoneEvent(function_item) + yield _FakeCompletedEvent(final_response) + + async def fake_create(**kwargs: Any) -> AsyncIterator[Any]: + return fake_stream_gen() + + monkeypatch.setattr(llm._client.responses, "create", fake_create) # type: ignore[attr-defined] + + req = _build_request(model="@openai/gpt-4o-mini", text="test") + function_names: list[str] = [] + async for resp in llm.generate_content_async(req, stream=True): + parts = getattr(getattr(resp, "content", None), "parts", []) or [] + for p in parts: + function_call = getattr(p, "function_call", None) + if function_call is not None and getattr(function_call, "name", None): + function_names.append(function_call.name) + + assert function_names == ["lookup_weather"] + + +def test_normalize_thought_signature_string() -> None: + assert _normalize_thought_signature("signature_string") == "signature_string" + + +def test_normalize_thought_signature_bytes() -> None: + assert _normalize_thought_signature(b"binary_sig") == "YmluYXJ5X3NpZw==" + + +def test_normalize_thought_signature_none() -> None: + assert _normalize_thought_signature(None) is None + + +def test_get_reasoning_config_with_budget() -> None: + req = LlmRequest( + model="test", + contents=[], + config=genai_types.GenerateContentConfig( + thinking_config=genai_types.ThinkingConfig( + include_thoughts=True, + thinking_budget=2048, + ), + ), + ) + result = _get_reasoning_config(req) + assert result is not None + assert result["effort"] == "medium" + assert result["summary"] == "auto" + + +def test_get_reasoning_config_without_budget() -> None: + req = LlmRequest( + model="test", + contents=[], + config=genai_types.GenerateContentConfig( + thinking_config=genai_types.ThinkingConfig( + include_thoughts=True, + ), + ), + ) + result = _get_reasoning_config(req) + assert result is not None + assert result["effort"] == "medium" + + +def test_get_reasoning_config_none_when_not_configured() -> None: + req = LlmRequest(model="test", contents=[]) + assert _get_reasoning_config(req) is None + + +def test_get_reasoning_config_none_when_empty() -> None: + req = LlmRequest( + model="test", + contents=[], + config=genai_types.GenerateContentConfig( + thinking_config=genai_types.ThinkingConfig(), + ), + ) + assert _get_reasoning_config(req) is None