diff --git a/instrumentation-loongsuite/loongsuite-instrumentation-widesearch/README.md b/instrumentation-loongsuite/loongsuite-instrumentation-widesearch/README.md new file mode 100644 index 000000000..4b4aac443 --- /dev/null +++ b/instrumentation-loongsuite/loongsuite-instrumentation-widesearch/README.md @@ -0,0 +1,17 @@ +# LoongSuite WideSearch Instrumentation + +OpenTelemetry instrumentation for the [WideSearch](https://github.com/ByteDance-Seed/WideSearch) multi-agent search framework. + +## Installation + +```bash +pip install loongsuite-instrumentation-widesearch +``` + +## Usage + +```python +from opentelemetry.instrumentation.widesearch import WideSearchInstrumentor + +WideSearchInstrumentor().instrument() +``` diff --git a/instrumentation-loongsuite/loongsuite-instrumentation-widesearch/pyproject.toml b/instrumentation-loongsuite/loongsuite-instrumentation-widesearch/pyproject.toml new file mode 100644 index 000000000..9a819d25a --- /dev/null +++ b/instrumentation-loongsuite/loongsuite-instrumentation-widesearch/pyproject.toml @@ -0,0 +1,57 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "loongsuite-instrumentation-widesearch" +dynamic = ["version"] +description = "LoongSuite WideSearch Instrumentation" +readme = "README.md" +license = "Apache-2.0" +requires-python = ">=3.11" +authors = [ + { name = "LoongSuite Python Agent Authors", email = "caishipeng.csp@alibaba-inc.com" }, + { name = "OpenTelemetry Authors", email = "cncf-opentelemetry-contributors@lists.cncf.io" }, +] +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", +] +dependencies = [ + "opentelemetry-api ~= 1.37", + 
"opentelemetry-instrumentation >= 0.58b0", + "opentelemetry-semantic-conventions >= 0.58b0", + "opentelemetry-util-genai", + "wrapt >= 1.17.3, < 2.0.0", +] + +[project.optional-dependencies] +instruments = [ + "widesearch >= 0.1.0", +] +test = [ + "pytest ~= 8.0", + "pytest-cov ~= 4.1.0", +] + +[project.entry-points.opentelemetry_instrumentor] +widesearch = "opentelemetry.instrumentation.widesearch:WideSearchInstrumentor" + +[project.urls] +Homepage = "https://github.com/alibaba/loongsuite-python-agent/tree/main/instrumentation-loongsuite/loongsuite-instrumentation-widesearch" +Repository = "https://github.com/alibaba/loongsuite-python-agent" + +[tool.hatch.version] +path = "src/opentelemetry/instrumentation/widesearch/version.py" + +[tool.hatch.build.targets.sdist] +include = ["/src", "/tests"] + +[tool.hatch.build.targets.wheel] +packages = ["src/opentelemetry"] diff --git a/instrumentation-loongsuite/loongsuite-instrumentation-widesearch/src/opentelemetry/instrumentation/widesearch/__init__.py b/instrumentation-loongsuite/loongsuite-instrumentation-widesearch/src/opentelemetry/instrumentation/widesearch/__init__.py new file mode 100644 index 000000000..9c441d18f --- /dev/null +++ b/instrumentation-loongsuite/loongsuite-instrumentation-widesearch/src/opentelemetry/instrumentation/widesearch/__init__.py @@ -0,0 +1,164 @@ +""" +WideSearch instrumentation supporting `widesearch >= 0.1.0`. + +Usage +----- +.. 
code:: python + + from opentelemetry.instrumentation.widesearch import WideSearchInstrumentor + + WideSearchInstrumentor().instrument() + +API +--- +""" + +from __future__ import annotations + +import logging +from typing import Any, Collection + +from wrapt import wrap_function_wrapper + +from opentelemetry.instrumentation.instrumentor import BaseInstrumentor +from opentelemetry.instrumentation.utils import unwrap +from opentelemetry.instrumentation.widesearch.package import _instruments +from opentelemetry.instrumentation.widesearch.patch import ( + wrap_create_sub_agents_factory, + wrap_invoke_tool_call, + wrap_run_single_query, + wrap_runner_run, + wrap_runner_step, +) +from opentelemetry.instrumentation.widesearch.version import __version__ +from opentelemetry.util.genai.extended_handler import ExtendedTelemetryHandler + +logger = logging.getLogger(__name__) + +_RUN_MODULE = "src.agent.run" +_MULTI_AGENT_MODULE = "src.agent.multi_agent_tools" + +__all__ = ["WideSearchInstrumentor", "__version__"] + + +class WideSearchInstrumentor(BaseInstrumentor): + """OpenTelemetry instrumentor for WideSearch framework. 
+ + Instruments the following components: + - run_single_query(): ENTRY span + - Runner.run(): AGENT span (async generator) + - Runner._step(): STEP span + - Runner._invoke_tool_call(): TOOL spans + - create_sub_agents_wrap(): TASK span + """ + + def __init__(self): + super().__init__() + self._handler = None + + def instrumentation_dependencies(self) -> Collection[str]: + return _instruments + + def _instrument(self, **kwargs: Any) -> None: + tracer_provider = kwargs.get("tracer_provider") + meter_provider = kwargs.get("meter_provider") + logger_provider = kwargs.get("logger_provider") + + self._handler = ExtendedTelemetryHandler( + tracer_provider=tracer_provider, + meter_provider=meter_provider, + logger_provider=logger_provider, + ) + + # H1: ENTRY span + try: + wrap_function_wrapper( + module=_RUN_MODULE, + name="run_single_query", + wrapper=lambda w, i, a, k: wrap_run_single_query( + w, i, a, k, handler=self._handler + ), + ) + logger.debug("Instrumented run_single_query") + except Exception as e: + logger.warning(f"Failed to instrument run_single_query: {e}") + + # H2: AGENT span + try: + wrap_function_wrapper( + module=_RUN_MODULE, + name="Runner.run", + wrapper=lambda w, i, a, k: wrap_runner_run( + w, i, a, k, handler=self._handler + ), + ) + logger.debug("Instrumented Runner.run") + except Exception as e: + logger.warning(f"Failed to instrument Runner.run: {e}") + + # H3: STEP span + try: + wrap_function_wrapper( + module=_RUN_MODULE, + name="Runner._step", + wrapper=lambda w, i, a, k: wrap_runner_step( + w, i, a, k, handler=self._handler + ), + ) + logger.debug("Instrumented Runner._step") + except Exception as e: + logger.warning(f"Failed to instrument Runner._step: {e}") + + # H4: TOOL spans + try: + wrap_function_wrapper( + module=_RUN_MODULE, + name="Runner._invoke_tool_call", + wrapper=lambda w, i, a, k: wrap_invoke_tool_call( + w, i, a, k, handler=self._handler + ), + ) + logger.debug("Instrumented Runner._invoke_tool_call") + except Exception as 
e: + logger.warning( + f"Failed to instrument Runner._invoke_tool_call: {e}" + ) + + # H5: TASK span (wrap factory) + try: + wrap_function_wrapper( + module=_MULTI_AGENT_MODULE, + name="create_sub_agents_wrap", + wrapper=lambda w, i, a, k: wrap_create_sub_agents_factory( + w, i, a, k, handler=self._handler + ), + ) + logger.debug("Instrumented create_sub_agents_wrap") + except Exception as e: + logger.warning( + f"Failed to instrument create_sub_agents_wrap: {e}" + ) + + def _uninstrument(self, **kwargs: Any) -> None: + try: + import src.agent.run # noqa: PLC0415 + + unwrap(src.agent.run, "run_single_query") + unwrap(src.agent.run.Runner, "run") + unwrap(src.agent.run.Runner, "_step") + unwrap(src.agent.run.Runner, "_invoke_tool_call") + logger.debug("Uninstrumented src.agent.run") + except Exception as e: + logger.warning(f"Failed to uninstrument src.agent.run: {e}") + + try: + import src.agent.multi_agent_tools # noqa: PLC0415 + + unwrap(src.agent.multi_agent_tools, "create_sub_agents_wrap") + logger.debug("Uninstrumented src.agent.multi_agent_tools") + except Exception as e: + logger.warning( + f"Failed to uninstrument src.agent.multi_agent_tools: {e}" + ) + + self._handler = None diff --git a/instrumentation-loongsuite/loongsuite-instrumentation-widesearch/src/opentelemetry/instrumentation/widesearch/package.py b/instrumentation-loongsuite/loongsuite-instrumentation-widesearch/src/opentelemetry/instrumentation/widesearch/package.py new file mode 100644 index 000000000..bd0572292 --- /dev/null +++ b/instrumentation-loongsuite/loongsuite-instrumentation-widesearch/src/opentelemetry/instrumentation/widesearch/package.py @@ -0,0 +1,2 @@ +_instruments = ("widesearch >= 0.1.0",) +_supports_metrics = False diff --git a/instrumentation-loongsuite/loongsuite-instrumentation-widesearch/src/opentelemetry/instrumentation/widesearch/patch.py b/instrumentation-loongsuite/loongsuite-instrumentation-widesearch/src/opentelemetry/instrumentation/widesearch/patch.py new file mode 
100644 index 000000000..32ac6287b --- /dev/null +++ b/instrumentation-loongsuite/loongsuite-instrumentation-widesearch/src/opentelemetry/instrumentation/widesearch/patch.py @@ -0,0 +1,338 @@ +"""Patch functions for WideSearch instrumentation. + +Wraps key WideSearch methods to generate OpenTelemetry spans: +- run_single_query -> ENTRY span +- Runner.run -> AGENT span (async generator) +- Runner._step -> STEP span +- Runner._invoke_tool_call -> TOOL spans (one per tool_call) +- create_sub_agents_wrap -> TASK span (on returned closure) +""" + +from __future__ import annotations + +import asyncio +import json +import logging +from contextvars import ContextVar + +from opentelemetry.trace import SpanKind, StatusCode +from opentelemetry.trace.status import Status +from opentelemetry.util.genai.extended_handler import ExtendedTelemetryHandler +from opentelemetry.util.genai.extended_types import ReactStepInvocation +from opentelemetry.util.genai.types import Error + +from .utils import ( + _create_agent_invocation, + _create_entry_invocation, + _create_tool_invocation, + _extract_output_messages, + _step_to_output_messages, +) + +logger = logging.getLogger(__name__) + +_step_counter: ContextVar[int] = ContextVar("ws_step_counter", default=0) +_in_run_single_query: ContextVar[bool] = ContextVar("ws_in_rsq", default=False) + + +async def wrap_run_single_query( + wrapped, instance, args, kwargs, *, handler: ExtendedTelemetryHandler +): + """H1: ENTRY span for run_single_query.""" + if _in_run_single_query.get(): + return await wrapped(*args, **kwargs) + token = _in_run_single_query.set(True) + + query = args[0] if args else kwargs.get("query", "") + try: + invocation = _create_entry_invocation(query) + except Exception as e: + logger.debug(f"Failed to create entry invocation: {e}") + _in_run_single_query.reset(token) + return await wrapped(*args, **kwargs) + + handler.start_entry(invocation) + + try: + result = await wrapped(*args, **kwargs) + invocation.output_messages = 
_extract_output_messages(result) + handler.stop_entry(invocation) + return result + except Exception as e: + handler.fail_entry(invocation, Error(message=str(e), type=type(e))) + raise + finally: + _in_run_single_query.reset(token) + + +async def wrap_runner_run( + wrapped, instance, args, kwargs, *, handler: ExtendedTelemetryHandler +): + """H2: AGENT span for Runner.run (async generator).""" + starting_agent = args[0] if args else kwargs.get("starting_agent") + user_input = args[1] if len(args) > 1 else kwargs.get("user_input", "") + + try: + invocation = _create_agent_invocation(starting_agent, user_input) + except Exception as e: + logger.debug(f"Failed to create agent invocation: {e}") + async for step in wrapped(*args, **kwargs): + yield step + return + + counter_token = _step_counter.set(0) + handler.start_invoke_agent(invocation) + + try: + last_step = None + async for step in wrapped(*args, **kwargs): + last_step = step + yield step + + if last_step: + invocation.output_messages = _step_to_output_messages(last_step) + handler.stop_invoke_agent(invocation) + except GeneratorExit as e: + handler.fail_invoke_agent( + invocation, Error(message="GeneratorExit", type=GeneratorExit) + ) + raise + except Exception as e: + handler.fail_invoke_agent( + invocation, Error(message=str(e), type=type(e)) + ) + raise + finally: + _step_counter.reset(counter_token) + + +async def wrap_runner_step( + wrapped, instance, args, kwargs, *, handler: ExtendedTelemetryHandler +): + """H3: STEP span for Runner._step.""" + step_num = _step_counter.get() + 1 + _step_counter.set(step_num) + + invocation = ReactStepInvocation(round=step_num) + invocation.attributes["gen_ai.framework"] = "widesearch" + + try: + handler.start_react_step(invocation) + except Exception as e: + logger.debug(f"Failed to start react step: {e}") + return await wrapped(*args, **kwargs) + + try: + result = await wrapped(*args, **kwargs) + + from src.agent.memory import ActionStep, ActionStepError, StepStatus + + 
if isinstance(result, ActionStepError): + invocation.finish_reason = "error" + handler.fail_react_step( + invocation, + Error(message=result.message, type=type(result)), + ) + else: + if result.step_status == StepStatus.FINISHED: + invocation.finish_reason = "finished" + elif result.error_marker is not None: + invocation.finish_reason = "error" + else: + invocation.finish_reason = "continue" + handler.stop_react_step(invocation) + + return result + except Exception as e: + invocation.finish_reason = "error" + handler.fail_react_step( + invocation, Error(message=str(e), type=type(e)) + ) + raise + + +async def wrap_invoke_tool_call( + wrapped, instance, args, kwargs, *, handler: ExtendedTelemetryHandler +): + """H4: TOOL span for each tool_call inside Runner._invoke_tool_call.""" + agent = args[0] if args else kwargs.get("agent") + model_response = args[1] if len(args) > 1 else kwargs.get("model_response") + + if not model_response.outputs: + return await wrapped(*args, **kwargs) + + resp = model_response.outputs[0] + if not resp.tool_calls: + return await wrapped(*args, **kwargs) + + from src.agent.schema import ErrorMarker, ToolCallResult + + async def _call_with_span(tool_call): + try: + invocation = _create_tool_invocation(tool_call, agent) + except Exception as e: + logger.debug(f"Failed to create tool invocation: {e}") + return await _call_original(tool_call, agent) + + handler.start_execute_tool(invocation) + + tool_name = tool_call.tool_name + tool = agent.get_tool_by_name(tool_name) + if tool is None: + invocation.tool_call_result = f"Tool {tool_name} not found" + handler.fail_execute_tool( + invocation, + Error( + message=f"Tool {tool_name} not found", + type=ValueError, + ), + ) + return ToolCallResult( + tool_call_id=tool_call.tool_call_id, + error_marker=ErrorMarker(message=f"Tool {tool_name} not found"), + ) + + arguments = tool_call.arguments + if isinstance(arguments, str): + try: + arguments = json.loads(arguments) + except json.JSONDecodeError: + 
arguments = {} + + try: + response = await tool(**arguments) + except Exception as e: + invocation.tool_call_result = str(e) + handler.fail_execute_tool( + invocation, Error(message=str(e), type=type(e)) + ) + return ToolCallResult( + tool_call_id=tool_call.tool_call_id, + error_marker=ErrorMarker(message=str(e)), + ) + + error_marker = ( + ErrorMarker(message=response.error) if response.error else None + ) + system_error_marker = ( + ErrorMarker(message=response.system_error) + if response.system_error + else None + ) + + result_content = response.data + invocation.tool_call_result = ( + str(result_content) if result_content else None + ) + + if error_marker or system_error_marker: + msg = (error_marker or system_error_marker)["message"] + handler.fail_execute_tool( + invocation, Error(message=msg, type=RuntimeError) + ) + else: + handler.stop_execute_tool(invocation) + + return ToolCallResult( + tool_call_id=tool_call.tool_call_id, + content=result_content, + error_marker=error_marker, + system_error_marker=system_error_marker, + extra=response.extra if response.extra else {}, + ) + + async def _call_original(tool_call, agent): + """Fallback: execute tool without span.""" + tool_name = tool_call.tool_name + tool = agent.get_tool_by_name(tool_name) + if tool is None: + return ToolCallResult( + tool_call_id=tool_call.tool_call_id, + error_marker=ErrorMarker(message=f"Tool {tool_name} not found"), + ) + arguments = tool_call.arguments + if isinstance(arguments, str): + try: + arguments = json.loads(arguments) + except json.JSONDecodeError: + arguments = {} + try: + response = await tool(**arguments) + except Exception as e: + return ToolCallResult( + tool_call_id=tool_call.tool_call_id, + error_marker=ErrorMarker(message=str(e)), + ) + return ToolCallResult( + tool_call_id=tool_call.tool_call_id, + content=response.data, + error_marker=( + ErrorMarker(message=response.error) if response.error else None + ), + system_error_marker=( + 
ErrorMarker(message=response.system_error) + if response.system_error + else None + ), + extra=response.extra if response.extra else {}, + ) + + tasks = [_call_with_span(tc) for tc in resp.tool_calls] + results = await asyncio.gather(*tasks) + return [r for r in results if r is not None] + + +def wrap_create_sub_agents_factory( + wrapped, instance, args, kwargs, *, handler: ExtendedTelemetryHandler +): + """H5: TASK span wrapping the closure returned by create_sub_agents_wrap.""" + original_closure = wrapped(*args, **kwargs) + + async def closure_with_task_span(sub_agents): + tracer = handler._tracer + span_name = "run_task create_sub_agents" + + with tracer.start_as_current_span( + name=span_name, + kind=SpanKind.INTERNAL, + ) as span: + span.set_attribute("gen_ai.span.kind", "TASK") + span.set_attribute("gen_ai.operation.name", "run_task") + span.set_attribute("gen_ai.framework", "widesearch") + + try: + safe_input = json.dumps( + [ + { + "index": sa.get("index"), + "prompt": sa.get("prompt", "")[:200], + } + for sa in sub_agents + ], + ensure_ascii=False, + ) + span.set_attribute("input.value", safe_input) + except Exception: + pass + + try: + result = await original_closure(sub_agents) + + if result and hasattr(result, "data") and result.data: + output_str = ( + result.data + if isinstance(result.data, str) + else json.dumps(result.data, ensure_ascii=False) + ) + if len(output_str) > 4096: + output_str = output_str[:4096] + "...(truncated)" + span.set_attribute("output.value", output_str) + + span.set_status(Status(StatusCode.OK)) + return result + except Exception as e: + span.record_exception(e) + span.set_status(Status(StatusCode.ERROR, str(e))) + raise + + return closure_with_task_span diff --git a/instrumentation-loongsuite/loongsuite-instrumentation-widesearch/src/opentelemetry/instrumentation/widesearch/utils.py b/instrumentation-loongsuite/loongsuite-instrumentation-widesearch/src/opentelemetry/instrumentation/widesearch/utils.py new file mode 100644 
index 000000000..0a8f751f7 --- /dev/null +++ b/instrumentation-loongsuite/loongsuite-instrumentation-widesearch/src/opentelemetry/instrumentation/widesearch/utils.py @@ -0,0 +1,155 @@ +"""Utility functions for WideSearch instrumentation.""" + +from __future__ import annotations + +import json +import logging +from typing import Any, List, Optional + +from opentelemetry.util.genai.extended_types import ( + EntryInvocation, + ExecuteToolInvocation, + InvokeAgentInvocation, + ReactStepInvocation, +) +from opentelemetry.util.genai.types import ( + FunctionToolDefinition, + InputMessage, + OutputMessage, + Text, +) + +logger = logging.getLogger(__name__) + + +_FRAMEWORK = "widesearch" + + +def _create_entry_invocation(query: str) -> EntryInvocation: + invocation = EntryInvocation() + invocation.input_messages = [ + InputMessage(role="user", parts=[Text(content=query)]) + ] + invocation.attributes["gen_ai.framework"] = _FRAMEWORK + return invocation + + +def _create_agent_invocation( + agent: Any, user_input: str +) -> InvokeAgentInvocation: + agent_name = getattr(agent, "name", None) or "widesearch-agent" + + request_model = None + model_config_name = getattr(agent, "model_config_name", None) + if model_config_name: + try: + from src.utils.config import model_config + + request_model = model_config.get(model_config_name, {}).get( + "model_name" + ) + except Exception: + pass + request_model = request_model or model_config_name + + instructions = getattr(agent, "instructions", None) or "" + + invocation = InvokeAgentInvocation( + provider="widesearch", + agent_name=agent_name, + agent_description=instructions[:200] if instructions else "", + request_model=request_model, + input_messages=[ + InputMessage(role="user", parts=[Text(content=user_input)]) + ], + ) + invocation.attributes["gen_ai.framework"] = _FRAMEWORK + + if instructions: + invocation.system_instruction = [Text(content=instructions)] + + tools_desc = getattr(agent, "tools_desc", None) + if tools_desc: + 
invocation.tool_definitions = _convert_tools_desc(tools_desc) + + return invocation + + +def _create_tool_invocation( + tool_call: Any, agent: Any +) -> ExecuteToolInvocation: + args = tool_call.arguments + if isinstance(args, str): + try: + args = json.loads(args) + except (json.JSONDecodeError, ValueError): + args = {"raw": args} + + description = None + if hasattr(agent, "tools_desc"): + for td in agent.tools_desc: + func = td.get("function", {}) + if func.get("name") == tool_call.tool_name: + description = func.get("description") + break + + invocation = ExecuteToolInvocation( + tool_name=tool_call.tool_name, + tool_call_id=getattr(tool_call, "tool_call_id", None), + tool_call_arguments=args, + tool_description=description, + tool_type="function", + ) + invocation.attributes["gen_ai.framework"] = _FRAMEWORK + return invocation + + +def _extract_output_messages(messages: Any) -> List[OutputMessage]: + """Extract output messages from run_single_query return value.""" + if not messages: + return [] + last_msg = messages[-1] + content = "" + if isinstance(last_msg, dict): + c = last_msg.get("content", {}) + if isinstance(c, dict): + content = c.get("content", "") + elif isinstance(c, str): + content = c + return [ + OutputMessage( + role="assistant", + parts=[Text(content=content)], + finish_reason="stop", + ) + ] + + +def _step_to_output_messages(step: Any) -> List[OutputMessage]: + """Extract output messages from an ActionStep.""" + content = getattr(step, "content", None) or "" + return [ + OutputMessage( + role="assistant", + parts=[Text(content=content)], + finish_reason="stop", + ) + ] + + +def _convert_tools_desc( + tools_desc: List[dict], +) -> Optional[List[FunctionToolDefinition]]: + """Convert WideSearch tools_desc to FunctionToolDefinition list.""" + result = [] + for td in tools_desc: + if td.get("type") == "function": + func = td.get("function", {}) + result.append( + FunctionToolDefinition( + name=func.get("name", ""), + 
description=func.get("description"), + parameters=func.get("parameters"), + ) + ) + return result if result else None diff --git a/instrumentation-loongsuite/loongsuite-instrumentation-widesearch/src/opentelemetry/instrumentation/widesearch/version.py b/instrumentation-loongsuite/loongsuite-instrumentation-widesearch/src/opentelemetry/instrumentation/widesearch/version.py new file mode 100644 index 000000000..26056b5d8 --- /dev/null +++ b/instrumentation-loongsuite/loongsuite-instrumentation-widesearch/src/opentelemetry/instrumentation/widesearch/version.py @@ -0,0 +1 @@ +__version__ = "0.5.0.dev" diff --git a/instrumentation-loongsuite/loongsuite-instrumentation-widesearch/tests/__init__.py b/instrumentation-loongsuite/loongsuite-instrumentation-widesearch/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/instrumentation-loongsuite/loongsuite-instrumentation-widesearch/tests/conftest.py b/instrumentation-loongsuite/loongsuite-instrumentation-widesearch/tests/conftest.py new file mode 100644 index 000000000..fa827987c --- /dev/null +++ b/instrumentation-loongsuite/loongsuite-instrumentation-widesearch/tests/conftest.py @@ -0,0 +1,386 @@ +"""Test configuration for WideSearch instrumentation tests. + +Injects lightweight stub modules for `src.agent.*` into sys.modules +so that wrap_function_wrapper can find them without installing WideSearch. 
+""" + +from __future__ import annotations + +import os +import sys +import types +from dataclasses import dataclass, field +from enum import Enum +from typing import Any, Callable, List, Literal + +import pytest + +# --------------------------------------------------------------------------- +# Stub modules for WideSearch (src.agent.*) +# --------------------------------------------------------------------------- + + +class StepStatus(str, Enum): + USER = "USER" + FINISHED = "FINISHED" + CONTINUE = "CONTINUE" + ERROR = "ERROR" + + +@dataclass +class ActionStepError: + message: str + source: Literal["llm"] = "llm" + + +@dataclass +class ToolCall: + tool_name: str + arguments: Any + tool_call_id: str + + +@dataclass +class ErrorMarker: + message: str + + def __getitem__(self, key): + if key == "message": + return self.message + raise KeyError(key) + + +@dataclass +class ToolCallResult: + tool_call_id: str + content: str | None = None + error_marker: Any = None + system_error_marker: Any = None + extra: dict = field(default_factory=dict) + + +@dataclass +class LLMOutputItem: + role: str = "assistant" + content: str | None = None + reasoning_content: str | None = None + signature: str | None = None + tool_calls: list = field(default_factory=list) + + +@dataclass +class ModelResponse: + outputs: list = field(default_factory=list) + session_id: str | None = None + error_marker: Any = None + + +@dataclass +class ActionStep: + step_status: StepStatus = StepStatus.CONTINUE + content: str | None = None + reasoning_content: str | None = None + signature: str | None = None + tool_calls: list = field(default_factory=list) + tool_call_results: list = field(default_factory=list) + error_marker: Any = None + + +@dataclass +class UserInputStep: + user_input: str + step_status: StepStatus = StepStatus.USER + + +@dataclass +class MemoryTurn: + steps: list = field(default_factory=list) + + @property + def step_number(self): + return sum(1 for s in self.steps if isinstance(s, 
ActionStep)) + + def is_finished(self) -> bool: + if not self.steps: + return False + return self.steps[-1].step_status == StepStatus.FINISHED + + +@dataclass +class MemoryAgent: + system_instructions: str | None = None + turns: list = field(default_factory=list) + + def insert_user_input(self, user_input: str): + turn = MemoryTurn() + turn.steps.append(UserInputStep(user_input=user_input)) + self.turns.append(turn) + return turn + + def insert_action_step(self, action_step): + last_turn = self.turns[-1] + last_turn.steps.append(action_step) + return last_turn + + def to_message(self, **kwargs): + return [] + + +@dataclass +class InternalResponse: + data: Any = None + error: str | None = None + system_error: str | None = None + extra: dict | None = None + + +@dataclass +class Agent: + name: str = "test-agent" + instructions: str | None = "You are a helpful agent." + tools: dict = field(default_factory=dict) + tools_desc: list = field(default_factory=list) + model_config_name: str = "gpt-4o" + + def get_tool_by_name(self, tool_name: str): + return self.tools.get(tool_name) + + +DEFAULT_MAX_STEPS = 50 +DEFAULT_MAX_ERROR_COUNT = 3 + + +class Runner: + _step_override = None # Set to a callable to override _step behavior + + @classmethod + async def run( + cls, + starting_agent, + user_input: str, + memory=None, + *, + max_steps: int = DEFAULT_MAX_STEPS, + llm_error_strategy: str = "retry", + ): + if memory is None: + memory = MemoryAgent( + system_instructions=starting_agent.instructions + ) + last_turn = memory.insert_user_input(user_input) + step_result = await cls._step(agent=starting_agent, memory=memory) + if not isinstance(step_result, ActionStepError): + yield step_result + + @classmethod + async def _step(cls, *, agent, memory) -> ActionStep | ActionStepError: + if cls._step_override is not None: + return await cls._step_override(agent=agent, memory=memory) + return ActionStep(step_status=StepStatus.FINISHED, content="Done") + + @classmethod + async def 
_invoke_tool_call( + cls, agent, model_response + ) -> list: + return [] + + +async def run_single_query( + query: str, + agent_name: str = "", + model_config_name: str = "", + tools: dict = None, + tools_desc: list = None, + system_prompt: str = "", +): + agent = Agent( + name=agent_name, + tools=tools or {}, + tools_desc=tools_desc or [], + model_config_name=model_config_name, + ) + memory = MemoryAgent(system_instructions=system_prompt) + + # Mirrors real implementation: calls Runner.run as async generator + async for step in Runner.run(agent, query, memory): + pass + + last_content = "final answer" + if memory.turns: + last_turn = memory.turns[-1] + for s in reversed(last_turn.steps): + if isinstance(s, ActionStep) and s.content: + last_content = s.content + break + + return [ + {"role": "user", "content": query}, + {"role": "assistant", "content": {"content": last_content}}, + ] + + +def _default_tools(): + return {} + + +def get_system_prompt(language="zh"): + return "You are a helpful assistant." + + +def create_sub_agents_wrap( + agent_name, model_config_name, tools, tools_desc, system_prompt +): + async def create_sub_agents(sub_agents: list) -> InternalResponse: + import json + + results = [] + for sa in sub_agents: + results.append( + {"index": sa.get("index"), "prompt": sa.get("prompt", ""), "response": "sub result"} + ) + return InternalResponse( + data=json.dumps(results, ensure_ascii=False) + ) + + return create_sub_agents + + +def _inject_stub_modules(): + """Inject stub modules into sys.modules so that wrapt can resolve them.""" + # Create module hierarchy: src -> src.agent -> src.agent.run, etc. 
+ src_mod = types.ModuleType("src") + src_agent_mod = types.ModuleType("src.agent") + src_agent_run_mod = types.ModuleType("src.agent.run") + src_agent_multi_agent_tools_mod = types.ModuleType("src.agent.multi_agent_tools") + src_agent_memory_mod = types.ModuleType("src.agent.memory") + src_agent_schema_mod = types.ModuleType("src.agent.schema") + src_agent_tools_mod = types.ModuleType("src.agent.tools") + src_agent_prompt_mod = types.ModuleType("src.agent.prompt") + src_utils_mod = types.ModuleType("src.utils") + src_utils_config_mod = types.ModuleType("src.utils.config") + + # Populate src.agent.run + src_agent_run_mod.Runner = Runner + src_agent_run_mod.run_single_query = run_single_query + src_agent_run_mod.run_turn = None + src_agent_run_mod.extract_messages_from_memory = None + + # Populate src.agent.multi_agent_tools + src_agent_multi_agent_tools_mod.create_sub_agents_wrap = create_sub_agents_wrap + + # Populate src.agent.memory + src_agent_memory_mod.ActionStep = ActionStep + src_agent_memory_mod.ActionStepError = ActionStepError + src_agent_memory_mod.MemoryAgent = MemoryAgent + src_agent_memory_mod.StepStatus = StepStatus + src_agent_memory_mod.UserInputStep = UserInputStep + + # Populate src.agent.schema + src_agent_schema_mod.ToolCall = ToolCall + src_agent_schema_mod.ToolCallResult = ToolCallResult + src_agent_schema_mod.ModelResponse = ModelResponse + src_agent_schema_mod.ErrorMarker = ErrorMarker + src_agent_schema_mod.LLMOutputItem = LLMOutputItem + + # Populate src.agent.tools + src_agent_tools_mod.InternalResponse = InternalResponse + src_agent_tools_mod._default_tools = {} + + # Populate src.agent.prompt + src_agent_prompt_mod.get_system_prompt = get_system_prompt + + # Populate src.agent.agent + src_agent_agent_mod = types.ModuleType("src.agent.agent") + src_agent_agent_mod.Agent = Agent + src_agent_agent_mod.DEFAULT_MAX_STEPS = DEFAULT_MAX_STEPS + src_agent_agent_mod.DEFAULT_MAX_ERROR_COUNT = DEFAULT_MAX_ERROR_COUNT + + # Populate 
src.utils.config + src_utils_config_mod.model_config = { + "gpt-4o": {"model_name": "gpt-4o-2024-05-13"}, + } + + # Wire up parent references + src_mod.agent = src_agent_mod + src_mod.utils = src_utils_mod + src_agent_mod.run = src_agent_run_mod + src_agent_mod.multi_agent_tools = src_agent_multi_agent_tools_mod + src_agent_mod.memory = src_agent_memory_mod + src_agent_mod.schema = src_agent_schema_mod + src_agent_mod.tools = src_agent_tools_mod + src_agent_mod.prompt = src_agent_prompt_mod + src_agent_mod.agent = src_agent_agent_mod + + # Register in sys.modules + sys.modules["src"] = src_mod + sys.modules["src.agent"] = src_agent_mod + sys.modules["src.agent.run"] = src_agent_run_mod + sys.modules["src.agent.multi_agent_tools"] = src_agent_multi_agent_tools_mod + sys.modules["src.agent.memory"] = src_agent_memory_mod + sys.modules["src.agent.schema"] = src_agent_schema_mod + sys.modules["src.agent.tools"] = src_agent_tools_mod + sys.modules["src.agent.prompt"] = src_agent_prompt_mod + sys.modules["src.agent.agent"] = src_agent_agent_mod + sys.modules["src.utils"] = src_utils_mod + sys.modules["src.utils.config"] = src_utils_config_mod + + +# Inject stubs before any test imports the instrumentation module +_inject_stub_modules() + + +# --------------------------------------------------------------------------- +# OTel test fixtures +# --------------------------------------------------------------------------- + + +def pytest_configure(config: pytest.Config): + os.environ["OTEL_SEMCONV_STABILITY_OPT_IN"] = "gen_ai_latest_experimental" + os.environ["OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT"] = "span_only" + + +from opentelemetry.instrumentation.widesearch import WideSearchInstrumentor +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import SimpleSpanProcessor +from opentelemetry.sdk.trace.export.in_memory_span_exporter import ( + InMemorySpanExporter, +) +from opentelemetry.sdk.metrics import MeterProvider +from 
@pytest.fixture(scope="function", name="span_exporter")
def fixture_span_exporter():
    """Fresh in-memory span exporter per test; inspect via get_finished_spans()."""
    exporter = InMemorySpanExporter()
    yield exporter


@pytest.fixture(scope="function", name="metric_reader")
def fixture_metric_reader():
    """Fresh in-memory metric reader per test."""
    reader = InMemoryMetricReader()
    yield reader


@pytest.fixture(scope="function", name="tracer_provider")
def fixture_tracer_provider(span_exporter):
    """TracerProvider wired to the in-memory exporter.

    Uses SimpleSpanProcessor so spans are exported synchronously on end,
    making them immediately visible to test assertions.
    """
    provider = TracerProvider()
    provider.add_span_processor(SimpleSpanProcessor(span_exporter))
    return provider


@pytest.fixture(scope="function", name="meter_provider")
def fixture_meter_provider(metric_reader):
    """MeterProvider backed by the in-memory metric reader."""
    meter_provider = MeterProvider(metric_readers=[metric_reader])
    return meter_provider


@pytest.fixture(scope="function")
def instrument(tracer_provider, meter_provider):
    """Instrument WideSearch for one test, uninstrumenting on teardown.

    skip_dep_check=True because the tests run against the stub ``src.*``
    modules injected by this conftest, not a real widesearch install.
    """
    instrumentor = WideSearchInstrumentor()
    instrumentor.instrument(
        tracer_provider=tracer_provider,
        meter_provider=meter_provider,
        skip_dep_check=True,
    )
    yield instrumentor
    instrumentor.uninstrument()
+ +Covers: +- Instrumentor lifecycle (instrument/uninstrument idempotency) +- 5 span types: ENTRY, AGENT, STEP, TOOL, TASK +- Parent-child relationships +- Key attributes +- Error paths +""" + +from __future__ import annotations + +import asyncio +import json +import sys +from dataclasses import field +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from opentelemetry.trace import StatusCode + +from .conftest import ( + ActionStep, + ActionStepError, + Agent, + ErrorMarker, + InternalResponse, + LLMOutputItem, + MemoryAgent, + ModelResponse, + Runner, + StepStatus, + ToolCall, + ToolCallResult, +) + + +def _run_async(coro): + """Helper to run async coroutines in tests.""" + loop = asyncio.new_event_loop() + try: + return loop.run_until_complete(coro) + finally: + loop.close() + + +def _run_async_gen(async_gen): + """Helper to consume an async generator.""" + async def _consume(): + results = [] + async for item in async_gen: + results.append(item) + return results + loop = asyncio.new_event_loop() + try: + return loop.run_until_complete(_consume()) + finally: + loop.close() + + +# --------------------------------------------------------------------------- +# Instrumentor Lifecycle Tests +# --------------------------------------------------------------------------- + + +class TestInstrumentorLifecycle: + def test_instrument_and_uninstrument(self, tracer_provider, meter_provider): + from opentelemetry.instrumentation.widesearch import WideSearchInstrumentor + + instrumentor = WideSearchInstrumentor() + instrumentor.instrument( + tracer_provider=tracer_provider, + meter_provider=meter_provider, + skip_dep_check=True, + ) + assert instrumentor._handler is not None + instrumentor.uninstrument() + assert instrumentor._handler is None + + def test_double_instrument_uninstrument(self, tracer_provider, meter_provider): + from opentelemetry.instrumentation.widesearch import WideSearchInstrumentor + + instrumentor = WideSearchInstrumentor() + 
class TestInstrumentorLifecycle:
    """Instrument/uninstrument lifecycle of WideSearchInstrumentor."""

    def test_instrument_and_uninstrument(self, tracer_provider, meter_provider):
        """One instrument/uninstrument cycle creates then clears the handler."""
        from opentelemetry.instrumentation.widesearch import WideSearchInstrumentor

        inst = WideSearchInstrumentor()
        inst.instrument(
            tracer_provider=tracer_provider,
            meter_provider=meter_provider,
            skip_dep_check=True,
        )
        assert inst._handler is not None
        inst.uninstrument()
        assert inst._handler is None

    def test_double_instrument_uninstrument(self, tracer_provider, meter_provider):
        """A second instrumentor works after the first one was torn down."""
        from opentelemetry.instrumentation.widesearch import WideSearchInstrumentor

        first = WideSearchInstrumentor()
        first.instrument(
            tracer_provider=tracer_provider,
            meter_provider=meter_provider,
            skip_dep_check=True,
        )
        first.uninstrument()

        second = WideSearchInstrumentor()
        second.instrument(
            tracer_provider=tracer_provider,
            meter_provider=meter_provider,
            skip_dep_check=True,
        )
        assert second._handler is not None
        second.uninstrument()

    def test_instrumentation_dependencies(self):
        """Declared dependency tuple pins widesearch >= 0.1.0."""
        from opentelemetry.instrumentation.widesearch import WideSearchInstrumentor

        deps = WideSearchInstrumentor().instrumentation_dependencies()
        assert ("widesearch >= 0.1.0",) == deps
class TestEntrySpan:
    """ENTRY spans emitted around run_single_query (hook H1)."""

    def test_entry_span_created(self, span_exporter, instrument):
        """run_single_query should produce an ENTRY span."""
        from src.agent.run import run_single_query

        _run_async(run_single_query("What is AI?", agent_name="searcher"))

        spans = span_exporter.get_finished_spans()
        # ENTRY spans carry a fixed name regardless of the query/agent.
        entry_spans = [
            s for s in spans if s.name == "enter_ai_application_system"
        ]
        assert len(entry_spans) == 1

        entry = entry_spans[0]
        attrs = dict(entry.attributes)
        assert attrs.get("gen_ai.span.kind") == "ENTRY"
        assert attrs.get("gen_ai.operation.name") == "enter"
        assert attrs.get("gen_ai.framework") == "widesearch"

    def test_entry_span_error(self, span_exporter, instrument):
        """ENTRY span should record ERROR on exception."""
        from src.agent.run import Runner, run_single_query

        async def failing_step(*, agent, memory):
            raise RuntimeError("LLM connection failed")

        # _step_override is a hook on the stub Runner (see conftest);
        # presumably it replaces Runner._step for the duration of the test.
        Runner._step_override = failing_step

        try:
            with pytest.raises(RuntimeError, match="LLM connection failed"):
                _run_async(run_single_query("test"))
        finally:
            # Always clear the hook so later tests see default stepping.
            Runner._step_override = None

        spans = span_exporter.get_finished_spans()
        entry_spans = [
            s for s in spans if s.name == "enter_ai_application_system"
        ]
        assert len(entry_spans) == 1
        # The exception must be reflected in the span status, not swallowed.
        assert entry_spans[0].status.status_code == StatusCode.ERROR
class TestAgentSpan:
    """AGENT spans emitted around Runner.run (hook H2)."""

    def test_agent_span_created(self, span_exporter, instrument):
        """Runner.run should produce an AGENT span."""
        from src.agent.run import Runner

        agent = Agent(name="search-agent", model_config_name="gpt-4o")

        async def _run():
            # Runner.run is an async generator of steps; consume it fully
            # so the AGENT span is closed before asserting.
            results = []
            async for step in Runner.run(agent, "Hello"):
                results.append(step)
            return results

        _run_async(_run())

        spans = span_exporter.get_finished_spans()
        # Span name contains the operation; match by substring since it may
        # also carry the agent name.
        agent_spans = [
            s for s in spans if "invoke_agent" in s.name
        ]
        assert len(agent_spans) == 1

        span = agent_spans[0]
        attrs = dict(span.attributes)
        assert attrs.get("gen_ai.span.kind") == "AGENT"
        assert attrs.get("gen_ai.operation.name") == "invoke_agent"
        assert attrs.get("gen_ai.agent.name") == "search-agent"
        assert attrs.get("gen_ai.framework") == "widesearch"

    def test_agent_span_is_child_of_entry(self, span_exporter, instrument):
        """AGENT span should be a child of ENTRY span."""
        from src.agent.run import run_single_query

        _run_async(run_single_query("test query", agent_name="test"))

        spans = span_exporter.get_finished_spans()
        entry_spans = [
            s for s in spans if s.name == "enter_ai_application_system"
        ]
        agent_spans = [s for s in spans if "invoke_agent" in s.name]

        assert len(entry_spans) == 1
        assert len(agent_spans) == 1

        entry = entry_spans[0]
        agent = agent_spans[0]
        # Parent linkage is checked by span id, not by object identity.
        assert agent.parent.span_id == entry.context.span_id

    def test_agent_span_error(self, span_exporter, instrument):
        """AGENT span should record ERROR when _step raises."""
        from src.agent.run import Runner

        async def failing_step(*, agent, memory):
            raise ValueError("Step failure")

        Runner._step_override = failing_step
        agent = Agent(name="fail-agent")

        async def _run():
            async for _ in Runner.run(agent, "Hello"):
                pass

        try:
            with pytest.raises(ValueError):
                _run_async(_run())
        finally:
            # Restore default stepping for subsequent tests.
            Runner._step_override = None

        spans = span_exporter.get_finished_spans()
        agent_spans = [s for s in spans if "invoke_agent" in s.name]
        assert len(agent_spans) == 1
        assert agent_spans[0].status.status_code == StatusCode.ERROR
class TestStepSpan:
    """STEP spans emitted around Runner._step (hook H3)."""

    def test_step_span_created(self, span_exporter, instrument):
        """Runner._step should produce a STEP span."""
        from src.agent.run import Runner

        agent = Agent(name="stepper")

        async def _run():
            async for _ in Runner.run(agent, "test"):
                pass

        _run_async(_run())

        spans = span_exporter.get_finished_spans()
        step_spans = [s for s in spans if s.name == "react step"]
        # The agent may take one or more react rounds; at least one step.
        assert len(step_spans) >= 1

        step = step_spans[0]
        attrs = dict(step.attributes)
        assert attrs.get("gen_ai.span.kind") == "STEP"
        assert attrs.get("gen_ai.operation.name") == "react"
        # Rounds are 1-based.
        assert attrs.get("gen_ai.react.round") == 1

    def test_step_span_is_child_of_agent(self, span_exporter, instrument):
        """STEP span should be child of AGENT span."""
        from src.agent.run import Runner

        agent = Agent(name="stepper")

        async def _run():
            async for _ in Runner.run(agent, "test"):
                pass

        _run_async(_run())

        spans = span_exporter.get_finished_spans()
        agent_spans = [s for s in spans if "invoke_agent" in s.name]
        step_spans = [s for s in spans if s.name == "react step"]

        assert len(agent_spans) == 1
        assert len(step_spans) >= 1

        agent_span = agent_spans[0]
        step_span = step_spans[0]
        assert step_span.parent.span_id == agent_span.context.span_id

    def test_step_span_finish_reason_finished(self, span_exporter, instrument):
        """STEP span should have finish_reason='finished' when step finishes."""
        from src.agent.run import Runner

        agent = Agent(name="stepper")

        async def _run():
            async for _ in Runner.run(agent, "test"):
                pass

        _run_async(_run())

        spans = span_exporter.get_finished_spans()
        step_spans = [s for s in spans if s.name == "react step"]
        assert len(step_spans) >= 1
        attrs = dict(step_spans[0].attributes)
        assert attrs.get("gen_ai.react.finish_reason") == "finished"

    def test_step_span_error_on_action_step_error(
        self, span_exporter, instrument
    ):
        """STEP span should record ERROR when _step returns ActionStepError."""
        from src.agent.run import Runner

        # An ActionStepError is *returned*, not raised: the error path is
        # signalled through the step result rather than an exception.
        async def error_step(*, agent, memory):
            return ActionStepError(message="LLM timeout")

        Runner._step_override = error_step
        agent = Agent(name="error-agent")

        try:
            async def _run():
                async for _ in Runner.run(agent, "test"):
                    pass

            _run_async(_run())
        finally:
            Runner._step_override = None

        spans = span_exporter.get_finished_spans()
        step_spans = [s for s in spans if s.name == "react step"]
        assert len(step_spans) >= 1
        assert step_spans[0].status.status_code == StatusCode.ERROR
        attrs = dict(step_spans[0].attributes)
        assert attrs.get("gen_ai.react.finish_reason") == "error"
class TestToolSpan:
    """TOOL spans emitted around Runner._invoke_tool_call (hook H4)."""

    def test_tool_span_created(self, span_exporter, instrument):
        """_invoke_tool_call should produce TOOL spans."""
        from src.agent.run import Runner

        async def mock_tool(**kwargs):
            return InternalResponse(data="search results")

        agent = Agent(
            name="tool-agent",
            tools={"search_global": mock_tool},
            # OpenAI-style tool descriptor; only the name is load-bearing here.
            tools_desc=[
                {
                    "type": "function",
                    "function": {
                        "name": "search_global",
                        "description": "Search the web",
                        "parameters": {},
                    },
                }
            ],
        )

        tc = ToolCall(
            tool_name="search_global",
            arguments='{"q": "AI"}',
            tool_call_id="call_123",
        )
        model_resp = ModelResponse(
            outputs=[LLMOutputItem(tool_calls=[tc])]
        )

        _run_async(Runner._invoke_tool_call(agent, model_resp))

        spans = span_exporter.get_finished_spans()
        tool_spans = [s for s in spans if "execute_tool" in s.name]
        assert len(tool_spans) == 1

        span = tool_spans[0]
        attrs = dict(span.attributes)
        assert attrs.get("gen_ai.span.kind") == "TOOL"
        assert attrs.get("gen_ai.operation.name") == "execute_tool"
        assert attrs.get("gen_ai.tool.name") == "search_global"
        assert attrs.get("gen_ai.tool.call.id") == "call_123"
        assert attrs.get("gen_ai.framework") == "widesearch"

    def test_tool_span_records_arguments_and_result(
        self, span_exporter, instrument
    ):
        """TOOL span should record arguments and result."""
        from src.agent.run import Runner

        async def mock_tool(q=""):
            return InternalResponse(data=f"results for: {q}")

        agent = Agent(
            name="tool-agent",
            tools={"search_global": mock_tool},
        )

        tc = ToolCall(
            tool_name="search_global",
            arguments=json.dumps({"q": "OpenTelemetry"}),
            tool_call_id="call_456",
        )
        model_resp = ModelResponse(
            outputs=[LLMOutputItem(tool_calls=[tc])]
        )

        results = _run_async(Runner._invoke_tool_call(agent, model_resp))
        # The JSON arguments were parsed and forwarded as kwargs.
        assert len(results) == 1
        assert results[0].content == "results for: OpenTelemetry"

        spans = span_exporter.get_finished_spans()
        tool_spans = [s for s in spans if "execute_tool" in s.name]
        assert len(tool_spans) == 1
        attrs = dict(tool_spans[0].attributes)
        # Only presence is asserted; exact serialization is an implementation
        # detail of the handler.
        assert "gen_ai.tool.call.arguments" in attrs
        assert "gen_ai.tool.call.result" in attrs

    def test_tool_span_error_on_missing_tool(self, span_exporter, instrument):
        """TOOL span should record ERROR when tool not found."""
        from src.agent.run import Runner

        agent = Agent(name="tool-agent", tools={})

        tc = ToolCall(
            tool_name="nonexistent_tool",
            arguments="{}",
            tool_call_id="call_789",
        )
        model_resp = ModelResponse(
            outputs=[LLMOutputItem(tool_calls=[tc])]
        )

        results = _run_async(Runner._invoke_tool_call(agent, model_resp))
        # Missing tools do not raise; the failure surfaces as an error marker
        # on the tool-call result.
        assert len(results) == 1
        assert results[0].error_marker is not None

        spans = span_exporter.get_finished_spans()
        tool_spans = [s for s in spans if "execute_tool" in s.name]
        assert len(tool_spans) == 1
        assert tool_spans[0].status.status_code == StatusCode.ERROR

    def test_tool_span_error_on_exception(self, span_exporter, instrument):
        """TOOL span should record ERROR when tool raises exception."""
        from src.agent.run import Runner

        async def failing_tool(**kwargs):
            raise ConnectionError("Network error")

        agent = Agent(
            name="tool-agent",
            tools={"flaky_tool": failing_tool},
        )

        tc = ToolCall(
            tool_name="flaky_tool",
            arguments="{}",
            tool_call_id="call_err",
        )
        model_resp = ModelResponse(
            outputs=[LLMOutputItem(tool_calls=[tc])]
        )

        results = _run_async(Runner._invoke_tool_call(agent, model_resp))
        # The exception is captured into the result's error marker rather
        # than propagating out of _invoke_tool_call.
        assert len(results) == 1
        assert results[0].error_marker is not None
        assert "Network error" in results[0].error_marker.message

        spans = span_exporter.get_finished_spans()
        tool_spans = [s for s in spans if "execute_tool" in s.name]
        assert len(tool_spans) == 1
        assert tool_spans[0].status.status_code == StatusCode.ERROR

    def test_multiple_tool_spans(self, span_exporter, instrument):
        """Multiple tool_calls should produce multiple TOOL spans."""
        from src.agent.run import Runner

        async def mock_search(**kwargs):
            return InternalResponse(data="search result")

        async def mock_browse(**kwargs):
            return InternalResponse(data="page content")

        agent = Agent(
            name="multi-tool",
            tools={
                "search_global": mock_search,
                "text_browser_view": mock_browse,
            },
        )

        tc1 = ToolCall(
            tool_name="search_global",
            arguments='{"q": "test"}',
            tool_call_id="call_1",
        )
        tc2 = ToolCall(
            tool_name="text_browser_view",
            arguments='{"url": "http://example.com"}',
            tool_call_id="call_2",
        )
        # Both calls arrive in a single model response output item.
        model_resp = ModelResponse(
            outputs=[LLMOutputItem(tool_calls=[tc1, tc2])]
        )

        results = _run_async(Runner._invoke_tool_call(agent, model_resp))
        assert len(results) == 2

        spans = span_exporter.get_finished_spans()
        tool_spans = [s for s in spans if "execute_tool" in s.name]
        assert len(tool_spans) == 2
class TestTaskSpan:
    """TASK spans emitted around the create_sub_agents closure (hook H5)."""

    def test_task_span_created(self, span_exporter, instrument):
        """create_sub_agents closure should produce a TASK span."""
        from src.agent.multi_agent_tools import create_sub_agents_wrap

        closure = create_sub_agents_wrap(
            "main-agent", "gpt-4o", {}, [], "system prompt"
        )

        sub_agents = [
            {"index": 0, "prompt": "Search for X"},
            {"index": 1, "prompt": "Search for Y"},
        ]

        result = _run_async(closure(sub_agents))
        assert result is not None

        spans = span_exporter.get_finished_spans()
        task_spans = [
            s for s in spans if s.name == "run_task create_sub_agents"
        ]
        assert len(task_spans) == 1

        span = task_spans[0]
        attrs = dict(span.attributes)
        assert attrs.get("gen_ai.span.kind") == "TASK"
        assert attrs.get("gen_ai.operation.name") == "run_task"
        assert attrs.get("gen_ai.framework") == "widesearch"
        assert "input.value" in attrs

    def test_task_span_records_output(self, span_exporter, instrument):
        """TASK span should record output.value."""
        from src.agent.multi_agent_tools import create_sub_agents_wrap

        closure = create_sub_agents_wrap(
            "agent", "gpt-4o", {}, [], "prompt"
        )

        sub_agents = [{"index": 0, "prompt": "find info"}]
        _run_async(closure(sub_agents))

        spans = span_exporter.get_finished_spans()
        task_spans = [
            s for s in spans if s.name == "run_task create_sub_agents"
        ]
        assert len(task_spans) == 1
        attrs = dict(task_spans[0].attributes)
        assert "output.value" in attrs

    def test_task_span_error(self):
        """TASK span should record ERROR when the wrapped closure raises.

        Tests the wrap function directly against a private tracer pipeline.
        The previous version monkeypatched
        ``src.agent.multi_agent_tools.create_sub_agents_wrap`` at module
        level and never restored it (leaking the patch into later tests),
        then carried a dead re-instrumentation block that passed
        ``span_exporter._tracer_provider`` (an attribute InMemorySpanExporter
        does not have) and double-uninstrumented the shared fixture. All of
        that was unused by the final assertions and is removed here.
        """
        from opentelemetry.instrumentation.widesearch.patch import (
            wrap_create_sub_agents_factory,
        )
        from opentelemetry.util.genai.extended_handler import (
            ExtendedTelemetryHandler,
        )
        from opentelemetry.sdk.trace import TracerProvider
        from opentelemetry.sdk.trace.export import SimpleSpanProcessor
        from opentelemetry.sdk.trace.export.in_memory_span_exporter import (
            InMemorySpanExporter,
        )

        # Private exporter/provider so this test does not depend on (or
        # disturb) the shared `instrument` fixture.
        exporter = InMemorySpanExporter()
        tp = TracerProvider()
        tp.add_span_processor(SimpleSpanProcessor(exporter))
        handler = ExtendedTelemetryHandler(tracer_provider=tp)

        def failing_factory(*args, **kwargs):
            async def failing_closure(sub_agents):
                raise RuntimeError("Boom")

            return failing_closure

        wrapped_factory = wrap_create_sub_agents_factory(
            failing_factory, None, (), {}, handler=handler
        )

        # The exception must propagate to the caller...
        with pytest.raises(RuntimeError, match="Boom"):
            _run_async(wrapped_factory([{"index": 0, "prompt": "x"}]))

        # ...and still be recorded on the TASK span.
        spans = exporter.get_finished_spans()
        task_spans = [
            s for s in spans if s.name == "run_task create_sub_agents"
        ]
        assert len(task_spans) == 1
        assert task_spans[0].status.status_code == StatusCode.ERROR
class TestParentChildRelationships:
    """Cross-span parenting: ENTRY > AGENT > STEP > TOOL."""

    def test_full_hierarchy_entry_agent_step(self, span_exporter, instrument):
        """Full call through run_single_query should produce ENTRY > AGENT > STEP."""
        from src.agent.run import run_single_query

        _run_async(run_single_query("hierarchy test", agent_name="root"))

        spans = span_exporter.get_finished_spans()
        entry_spans = [
            s for s in spans if s.name == "enter_ai_application_system"
        ]
        agent_spans = [s for s in spans if "invoke_agent" in s.name]
        step_spans = [s for s in spans if s.name == "react step"]

        assert len(entry_spans) == 1
        assert len(agent_spans) == 1
        assert len(step_spans) >= 1

        entry = entry_spans[0]
        agent = agent_spans[0]
        step = step_spans[0]

        # AGENT is child of ENTRY
        assert agent.parent.span_id == entry.context.span_id
        # STEP is child of AGENT
        assert step.parent.span_id == agent.context.span_id

    def test_tool_span_is_child_of_step(self, span_exporter, instrument):
        """TOOL span should be child of the STEP span when invoked during a step."""
        from src.agent.run import Runner

        async def mock_tool(**kwargs):
            return InternalResponse(data="result")

        agent = Agent(
            name="hierarchy-agent",
            tools={"my_tool": mock_tool},
        )

        # Custom step that makes exactly one tool call and then finishes,
        # so the TOOL span is opened while the STEP span is still current.
        async def custom_step(*, agent, memory):
            tc = ToolCall(
                tool_name="my_tool",
                arguments="{}",
                tool_call_id="tc_hier",
            )
            model_resp = ModelResponse(
                outputs=[LLMOutputItem(tool_calls=[tc])]
            )
            await Runner._invoke_tool_call(agent, model_resp)
            return ActionStep(step_status=StepStatus.FINISHED, content="done")

        Runner._step_override = custom_step

        try:
            async def _run():
                async for _ in Runner.run(agent, "test"):
                    pass

            _run_async(_run())
        finally:
            # Restore default stepping for subsequent tests.
            Runner._step_override = None

        spans = span_exporter.get_finished_spans()
        step_spans = [s for s in spans if s.name == "react step"]
        tool_spans = [s for s in spans if "execute_tool" in s.name]

        assert len(step_spans) >= 1
        assert len(tool_spans) >= 1

        step_span = step_spans[0]
        tool_span = tool_spans[0]
        assert tool_span.parent.span_id == step_span.context.span_id