Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,24 @@

from __future__ import annotations

import contextvars
import importlib
import json
from types import SimpleNamespace
from typing import Any

# Per-call state stored in a ContextVar so that concurrent agent invocations
# (e.g. TUI multi-tab, parallel subagents) do not share / overwrite each
# other's react-step / token / response metadata. ``state(instance)`` keeps
# its previous signature for backward compatibility; the ``instance``
# argument is intentionally unused now.
_HERMES_STATE: contextvars.ContextVar["dict[str, Any] | None"] = (
contextvars.ContextVar(
"opentelemetry_hermes_state",
default=None,
)
)

from opentelemetry.semconv._incubating.attributes import (
gen_ai_attributes as GenAIAttributes,
)
Expand Down Expand Up @@ -750,7 +763,18 @@ def apply_skill_attributes(


def state(instance: Any) -> dict[str, Any]:
current = getattr(instance, "_otel_hermes_state", None)
"""Return the per-call telemetry state for the current execution context.

The state dict is stored in ``_HERMES_STATE`` (a ``ContextVar``) instead
of being attached to ``instance``, so that concurrent invocations on the
same agent instance (e.g. TUI multi-tab, parallel subagents, asyncio
``gather``) get isolated state and do not corrupt each other's react-step
span / token counters / response metadata.

The ``instance`` parameter is kept for backward compatibility with all
existing call sites; it is intentionally unused.
"""
current = _HERMES_STATE.get()
if current is None:
current = {
"handler": None,
Expand All @@ -767,12 +791,18 @@ def state(instance: Any) -> dict[str, Any]:
"first_token_monotonic_s": None,
"active_llm_depth": 0,
}
setattr(instance, "_otel_hermes_state", current)
_HERMES_STATE.set(current)
return current


def clear_state(instance: Any) -> None:
    """Drop the per-call state for the current execution context.

    ``instance`` is kept for backward compatibility; the state is now
    isolated per ``ContextVar`` and clearing it only affects the current
    asyncio task / thread, never sibling concurrent invocations.
    """
    # Resetting the ContextVar (instead of mutating the instance, as the
    # pre-fix ``setattr`` implementation did) keeps the clear scoped to
    # this execution context only.
    _HERMES_STATE.set(None)


def start_step(handler, instance: Any, finish_step) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1612,3 +1612,165 @@ def _streaming_api_response(api_kwargs, *, on_first_delta=None):
if on_first_delta is not None:
on_first_delta()
return _response(content="1 2 3", finish_reason="stop")


def test_state_is_isolated_across_concurrent_invocations(
    instrumentation_module,
    tracer_provider,
    meter_provider,
    span_exporter,
):
    """Concurrent calls on the same AIAgent must produce isolated react-step
    and token state.

    Regression test for the silent trace corruption that occurred when
    ``helpers.state(instance)`` stored its per-call state on the instance
    itself: two concurrent invocations on a shared agent overwrote each
    other's ``current_step_invocation`` / token counters, leaking step spans
    and mixing token usage across traces. The fix moves the state into a
    ``ContextVar`` so each thread / asyncio Task sees an independent dict.
    """
    # Local import: threading is only needed by this one test.
    import threading

    # ``_runtime`` / ``_FakeAgent`` are module-level test fixtures defined
    # elsewhere in this file; presumably ``_runtime`` wires the wrappers
    # under test to the given providers — confirm against their definitions.
    runtime = _runtime(instrumentation_module, tracer_provider, meter_provider)
    shared_agent = _FakeAgent(session_id="session-shared")
    barrier = threading.Barrier(2)
    # Worker-thread exceptions do not propagate to the main thread, so they
    # are collected here and asserted on after the joins.
    errors: list[BaseException] = []

    def run_one_conversation(thread_id: int) -> None:
        # One full simulated agent conversation (LLM -> tool -> LLM),
        # executed concurrently in two threads against the shared agent.
        try:
            tool_call = _tool_call(call_id=f"call-t{thread_id}")

            def wrapped_run(user_message):
                # First LLM round: returns a tool_call -> opens a step span
                # that must be closed by the tool batch wrapper.
                runtime.llm_wrapper(
                    lambda api_kwargs: _response(
                        content=None,
                        finish_reason="tool_calls",
                        tool_calls=[tool_call],
                        response_id=f"resp-t{thread_id}-1",
                        prompt_tokens=11,
                        completion_tokens=3,
                    ),
                    shared_agent,
                    (
                        {
                            "model": shared_agent.model,
                            "messages": [
                                {"role": "user", "content": user_message}
                            ],
                        },
                    ),
                    {},
                )
                # Synchronise both threads after they have opened their
                # first step span. With the buggy implementation, the second
                # thread's ``state(instance)`` call would clobber the first
                # thread's ``current_step_invocation`` here, leaking a span.
                barrier.wait(timeout=5)
                runtime.tool_batch_wrapper(
                    lambda: runtime.tool_wrapper(
                        lambda *args, **kwargs: f"tool_ok_t{thread_id}",
                        shared_agent,
                        (
                            "read_file",
                            {"path": f"/tmp/t{thread_id}.txt"},
                            f"task-t{thread_id}",
                            tool_call.id,
                        ),
                        {},
                    ),
                    shared_agent,
                    (),
                    {},
                )
                # Second LLM round: closes the second step with finish=stop.
                runtime.llm_wrapper(
                    lambda api_kwargs: _response(
                        content=f"final-t{thread_id}",
                        finish_reason="stop",
                        response_id=f"resp-t{thread_id}-2",
                        prompt_tokens=13,
                        completion_tokens=2,
                    ),
                    shared_agent,
                    (
                        {
                            "model": shared_agent.model,
                            "messages": [
                                {"role": "user", "content": user_message},
                                {
                                    "role": "tool",
                                    "content": f"tool_ok_t{thread_id}",
                                },
                            ],
                        },
                    ),
                    {},
                )
                return {"final_response": f"final-t{thread_id}"}

            runtime.run_wrapper(
                wrapped_run, shared_agent, (f"hello-from-t{thread_id}",), {}
            )
        except BaseException as exc: # pragma: no cover - defensive
            errors.append(exc)

    threads = [
        threading.Thread(target=run_one_conversation, args=(i,))
        for i in (1, 2)
    ]
    for thread in threads:
        thread.start()
    for thread in threads:
        # NOTE(review): ``join(timeout=10)`` returns silently on timeout; a
        # hung worker would only surface via the span-count asserts below.
        thread.join(timeout=10)

    assert not errors, f"worker thread raised: {errors}"

    agent_spans = _spans_by_kind(span_exporter, "AGENT")
    step_spans = _spans_by_kind(span_exporter, "STEP")
    llm_spans = _spans_by_kind(span_exporter, "LLM")
    tool_spans = _spans_by_kind(span_exporter, "TOOL")

    # Two concurrent agent invocations must produce two complete trees.
    assert len(agent_spans) == 2
    # Without the fix, one of the four step spans is leaked / never closed
    # because state["current_step_invocation"] is overwritten by the sibling
    # thread before ``finish_step`` gets to close it.
    assert len(step_spans) == 4, (
        "expected 4 react step spans (2 per concurrent call), got "
        f"{len(step_spans)} - state(instance) is not concurrency-safe"
    )
    assert len(llm_spans) == 4
    assert len(tool_spans) == 2

    # Each agent span must root its own complete sub-tree: 2 steps,
    # 2 LLM spans, 1 tool span. Without the fix, spans from sibling
    # invocations are mixed under the wrong agent root.
    for agent_span in agent_spans:
        descendant_steps = [
            span
            for span in step_spans
            if span.parent
            and span.parent.span_id == agent_span.context.span_id
        ]
        assert len(descendant_steps) == 2, (
            "each agent span should own exactly 2 react step spans, "
            f"got {len(descendant_steps)}"
        )
        descendant_step_ids = {
            span.context.span_id for span in descendant_steps
        }
        # Parent-link check: LLM and tool spans must hang off *this*
        # agent's step spans, not a sibling invocation's.
        descendant_llms = [
            span
            for span in llm_spans
            if span.parent and span.parent.span_id in descendant_step_ids
        ]
        descendant_tools = [
            span
            for span in tool_spans
            if span.parent and span.parent.span_id in descendant_step_ids
        ]
        assert len(descendant_llms) == 2
        assert len(descendant_tools) == 1
Loading