Skip to content

Commit 3336a78

Browse files
authored
Inject sampled traceparent on workflow execute calls (#459)
* Inject sampled traceparent on workflow execute calls Without a traceparent header, external workers inherit the API's HTTP span context via Temporal. In production the API's spans are frequently unsampled (ParentBasedTraceIdRatio), so workers produce no-op spans and traces never reach the collector — /trace/otel returns WF_1500. Adds a BeforeRequestHook that fires on any /execute path and injects a sampled W3C traceparent: forwarding the active OTEL span if it is already sampled, otherwise generating a fresh sampled one. An explicitly set traceparent header is never overwritten. * Add tests for TraceparentInjectionHook Covers: no-op on non-execute paths, sampled header injection, explicit header preservation, OTEL context propagation, fallback for unsampled or absent spans, and uniqueness of generated IDs. * Address review comments - Shorten TraceparentInjectionHook docstring to one line - Remove module docstring from test file - Drop low-ROI uniqueness test * Fix lint errors in test file Remove unused opentelemetry.trace import (ruff F401) and add isinstance assertions so pyright can narrow Union[Request, Exception] before accessing .headers. * Use operation ID instead of URL path to identify execute calls Matching on request.url.path.endswith("/execute") would affect any future endpoint that happens to share that suffix. Keying on the operation ID is explicit and safe. * Remove unused _EXECUTE_OPERATION_IDS import
1 parent 6087465 commit 3336a78

3 files changed

Lines changed: 152 additions & 0 deletions

File tree

src/mistralai/client/_hooks/registration.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from .custom_user_agent import CustomUserAgentHook
22
from .deprecation_warning import DeprecationWarningHook
3+
from .traceparent import TraceparentInjectionHook
34
from .tracing import TracingHook
45
from .types import Hooks
56

@@ -16,6 +17,7 @@ def init_hooks(hooks: Hooks):
1617
"""
1718
tracing_hook = TracingHook()
1819
hooks.register_before_request_hook(CustomUserAgentHook())
20+
hooks.register_before_request_hook(TraceparentInjectionHook())
1921
hooks.register_after_success_hook(DeprecationWarningHook())
2022
hooks.register_after_success_hook(tracing_hook)
2123
hooks.register_before_request_hook(tracing_hook)
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import random
2+
from typing import Dict, Union
3+
4+
import httpx
5+
from opentelemetry.propagate import inject
6+
7+
from .types import BeforeRequestContext, BeforeRequestHook
8+
9+
10+
_EXECUTE_OPERATION_IDS = {
11+
"execute_workflow_v1_workflows__workflow_identifier__execute_post",
12+
"execute_workflow_registration_v1_workflows_registrations__workflow_registration_id__execute_post",
13+
}
14+
15+
16+
class TraceparentInjectionHook(BeforeRequestHook):
17+
"""Inject a sampled traceparent on /execute requests so worker traces are always recorded."""
18+
19+
def before_request(
20+
self, hook_ctx: BeforeRequestContext, request: httpx.Request
21+
) -> Union[httpx.Request, Exception]:
22+
if hook_ctx.operation_id not in _EXECUTE_OPERATION_IDS:
23+
return request
24+
25+
# Don't overwrite an explicitly provided traceparent.
26+
if "traceparent" in request.headers:
27+
return request
28+
29+
carrier: Dict[str, str] = {}
30+
inject(carrier)
31+
traceparent = carrier.get("traceparent", "")
32+
if not traceparent.endswith("-01"):
33+
trace_id = random.getrandbits(128)
34+
span_id = random.getrandbits(64)
35+
traceparent = f"00-{trace_id:032x}-{span_id:016x}-01"
36+
37+
request.headers["traceparent"] = traceparent
38+
return request
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
import re
2+
import unittest
3+
from unittest.mock import MagicMock
4+
5+
import httpx
6+
from opentelemetry.sdk.trace import TracerProvider
7+
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
8+
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
9+
from opentelemetry.sdk.trace.sampling import ALWAYS_OFF
10+
11+
from mistralai.client._hooks.traceparent import TraceparentInjectionHook
12+
from mistralai.client._hooks.types import BeforeRequestContext, HookContext
13+
14+
15+
TRACEPARENT_RE = re.compile(r"^00-[0-9a-f]{32}-[0-9a-f]{16}-01$")
16+
17+
_EXECUTE_OP_ID = "execute_workflow_v1_workflows__workflow_identifier__execute_post"
18+
_EXECUTE_REG_OP_ID = "execute_workflow_registration_v1_workflows_registrations__workflow_registration_id__execute_post"
19+
_OTHER_OP_ID = "list_executions_v1_workflows__workflow_identifier__executions_get"
20+
21+
22+
def _make_hook_ctx(operation_id: str = _EXECUTE_OP_ID) -> BeforeRequestContext:
23+
ctx = HookContext(
24+
config=MagicMock(),
25+
base_url="https://api.mistral.ai",
26+
operation_id=operation_id,
27+
oauth2_scopes=None,
28+
security_source=None,
29+
)
30+
return BeforeRequestContext(ctx)
31+
32+
33+
def _make_request(path: str, traceparent: str | None = None) -> httpx.Request:
34+
headers = {}
35+
if traceparent is not None:
36+
headers["traceparent"] = traceparent
37+
return httpx.Request("POST", f"https://api.mistral.ai{path}", headers=headers)
38+
39+
40+
class TestTraceparentInjectionHook(unittest.TestCase):
41+
def setUp(self):
42+
self.hook = TraceparentInjectionHook()
43+
44+
# --- non-execute operations must NOT be touched ---
45+
46+
def test_other_operation_is_unchanged(self):
47+
req = _make_request("/v1/workflows/my-wf/executions")
48+
result = self.hook.before_request(_make_hook_ctx(_OTHER_OP_ID), req)
49+
assert isinstance(result, httpx.Request)
50+
self.assertNotIn("traceparent", result.headers)
51+
52+
# --- execute operations: header injected ---
53+
54+
def test_execute_gets_sampled_traceparent(self):
55+
req = _make_request("/v1/workflows/my-wf/execute")
56+
result = self.hook.before_request(_make_hook_ctx(_EXECUTE_OP_ID), req)
57+
assert isinstance(result, httpx.Request)
58+
self.assertIn("traceparent", result.headers)
59+
self.assertRegex(result.headers["traceparent"], TRACEPARENT_RE)
60+
61+
def test_execute_registration_gets_sampled_traceparent(self):
62+
req = _make_request("/v1/workflows/registrations/reg-id/execute")
63+
result = self.hook.before_request(_make_hook_ctx(_EXECUTE_REG_OP_ID), req)
64+
assert isinstance(result, httpx.Request)
65+
self.assertIn("traceparent", result.headers)
66+
self.assertRegex(result.headers["traceparent"], TRACEPARENT_RE)
67+
68+
def test_explicit_traceparent_is_not_overwritten(self):
69+
explicit = "00-aabbccddeeff00112233445566778899-0102030405060708-01"
70+
req = _make_request("/v1/workflows/my-wf/execute", traceparent=explicit)
71+
result = self.hook.before_request(_make_hook_ctx(_EXECUTE_OP_ID), req)
72+
assert isinstance(result, httpx.Request)
73+
self.assertEqual(result.headers["traceparent"], explicit)
74+
75+
# --- OTEL context propagation ---
76+
77+
def test_propagates_sampled_active_span(self):
78+
exporter = InMemorySpanExporter()
79+
provider = TracerProvider()
80+
provider.add_span_processor(SimpleSpanProcessor(exporter))
81+
tracer = provider.get_tracer("test")
82+
83+
with tracer.start_as_current_span("parent") as span:
84+
req = _make_request("/v1/workflows/my-wf/execute")
85+
result = self.hook.before_request(_make_hook_ctx(_EXECUTE_OP_ID), req)
86+
87+
assert isinstance(result, httpx.Request)
88+
injected = result.headers["traceparent"]
89+
self.assertTrue(injected.endswith("-01"))
90+
trace_id_hex = f"{span.get_span_context().trace_id:032x}"
91+
self.assertIn(trace_id_hex, injected)
92+
93+
def test_generates_fresh_traceparent_when_no_active_span(self):
94+
req = _make_request("/v1/workflows/my-wf/execute")
95+
result = self.hook.before_request(_make_hook_ctx(_EXECUTE_OP_ID), req)
96+
assert isinstance(result, httpx.Request)
97+
self.assertRegex(result.headers["traceparent"], TRACEPARENT_RE)
98+
99+
def test_generates_fresh_traceparent_when_span_is_unsampled(self):
100+
provider = TracerProvider(sampler=ALWAYS_OFF)
101+
tracer = provider.get_tracer("test")
102+
103+
with tracer.start_as_current_span("unsampled-parent"):
104+
req = _make_request("/v1/workflows/my-wf/execute")
105+
result = self.hook.before_request(_make_hook_ctx(_EXECUTE_OP_ID), req)
106+
107+
assert isinstance(result, httpx.Request)
108+
injected = result.headers["traceparent"]
109+
self.assertTrue(injected.endswith("-01"), f"Expected sampled, got: {injected}")
110+
111+
if __name__ == "__main__":
112+
unittest.main()

0 commit comments

Comments
 (0)