Skip to content

Commit 0b86253

Browse files
xumaple and claude committed
Fix context propagation bugs and remove handler suppression
Bug 1: Replace stale _current_run snapshot with ambient context in outbound interceptor. Add _get_current_run_for_propagation() helper that filters _ContextBridgeRunTree from ambient context. Outbound methods now read get_current_run_tree() for @traceable nesting instead of a frozen reference from workflow entry. Bug 2: Add tracing_context() to Nexus inbound interceptor for both execute_nexus_operation_start and execute_nexus_operation_cancel, matching the activity inbound pattern. Ensures @traceable functions in Nexus handlers have a LangSmith client even with add_temporal_runs=False. Remove handler suppression (is_handler check, _workflow_is_active flag) to align with OTel interceptor which creates spans for all handlers unconditionally. Add dump_traces() to test infrastructure for per-root-trace assertions. Restructure comprehensive tests so user_pipeline only wraps start_workflow, with polling/signals/queries as independent root traces. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent d89ae66 commit 0b86253

5 files changed

Lines changed: 302 additions & 244 deletions

File tree

temporalio/contrib/langsmith/_interceptor.py

Lines changed: 67 additions & 64 deletions
Original file line number | Diff line number | Diff line change
@@ -123,6 +123,18 @@ def _extract_nexus_context(
123123
return _ReplaySafeRunTree(run, executor=executor) if run else None
124124

125125

126+
def _get_current_run_for_propagation() -> RunTree | None:
127+
"""Get the current ambient run for context propagation.
128+
129+
Filters out _ContextBridgeRunTree, which is internal scaffolding
130+
that should never be serialized into headers or used as parent runs.
131+
"""
132+
run = get_current_run_tree()
133+
if isinstance(run, _ContextBridgeRunTree):
134+
return None
135+
return run
136+
137+
126138
# ---------------------------------------------------------------------------
127139
# Sandbox safety: patch @traceable's aio_to_thread
128140
# ---------------------------------------------------------------------------
@@ -742,46 +754,26 @@ class _LangSmithWorkflowInboundInterceptor(
742754
"""Instruments workflow execution with LangSmith runs."""
743755

744756
_config: ClassVar[LangSmithInterceptor]
745-
_current_run: _ReplaySafeRunTree | None = None
746757

747758
def init(self, outbound: temporalio.worker.WorkflowOutboundInterceptor) -> None:
748-
super().init(
749-
_LangSmithWorkflowOutboundInterceptor(outbound, self._config, self)
750-
)
759+
super().init(_LangSmithWorkflowOutboundInterceptor(outbound, self._config))
751760

752761
@contextmanager
753762
def _workflow_maybe_run(
754763
self,
755764
name: str,
756765
headers: Mapping[str, Payload] | None = None,
757-
*,
758-
is_handler: bool = False,
759766
) -> Iterator[_ReplaySafeRunTree | None]:
760767
"""Workflow-specific run creation with metadata.
761768
762-
Extracts parent from headers (if provided) and stores the run (or parent
763-
fallback) as ``_current_run`` so the outbound interceptor can propagate
764-
context even when ``add_temporal_runs=False``.
765-
766-
Always sets up ``tracing_context`` so ``@traceable`` functions called
767-
from workflow code can discover the parent and LangSmith client,
768-
independent of the ``add_temporal_runs`` toggle.
769-
770-
When ``is_handler`` is True and no LangSmith context is found in
771-
headers, skips trace creation if a workflow run is already active
772-
(``_current_run`` is set). This suppresses orphan traces from
773-
uninstrumented client operations (e.g. query polling) while still
774-
allowing handler traces when invoked with propagated context.
769+
Extracts parent from headers (if provided) and sets up
770+
``tracing_context`` so ``@traceable`` functions called from workflow
771+
code can discover the parent and LangSmith client, independent of the
772+
``add_temporal_runs`` toggle.
775773
"""
776774
parent = _extract_context(headers, self._config._executor) if headers else None
777775
if parent is not None:
778776
parent.ls_client = self._config._client
779-
# Handler from an uninstrumented client during workflow execution:
780-
# no LangSmith headers but _current_run is set. Skip trace creation
781-
# to avoid orphan/duplicate handler traces (e.g. query polling).
782-
if is_handler and parent is None and self._current_run is not None:
783-
yield None
784-
return
785777
info = temporalio.workflow.info()
786778
extra_metadata = {
787779
"temporalWorkflowID": info.workflow_id,
@@ -815,12 +807,7 @@ def _workflow_maybe_run(
815807
parent=parent,
816808
extra_metadata=extra_metadata,
817809
) as run:
818-
prev_run = self._current_run
819-
self._current_run = run or parent
820-
try:
821-
yield run
822-
finally:
823-
self._current_run = prev_run
810+
yield run
824811

825812
async def execute_workflow(
826813
self, input: temporalio.worker.ExecuteWorkflowInput
@@ -833,31 +820,23 @@ async def execute_workflow(
833820
return await super().execute_workflow(input)
834821

835822
async def handle_signal(self, input: temporalio.worker.HandleSignalInput) -> None:
836-
with self._workflow_maybe_run(
837-
f"HandleSignal:{input.signal}", input.headers, is_handler=True
838-
):
823+
with self._workflow_maybe_run(f"HandleSignal:{input.signal}", input.headers):
839824
return await super().handle_signal(input)
840825

841826
async def handle_query(self, input: temporalio.worker.HandleQueryInput) -> Any:
842-
with self._workflow_maybe_run(
843-
f"HandleQuery:{input.query}", input.headers, is_handler=True
844-
):
827+
with self._workflow_maybe_run(f"HandleQuery:{input.query}", input.headers):
845828
return await super().handle_query(input)
846829

847830
def handle_update_validator(
848831
self, input: temporalio.worker.HandleUpdateInput
849832
) -> None:
850-
with self._workflow_maybe_run(
851-
f"ValidateUpdate:{input.update}", input.headers, is_handler=True
852-
):
833+
with self._workflow_maybe_run(f"ValidateUpdate:{input.update}", input.headers):
853834
return super().handle_update_validator(input)
854835

855836
async def handle_update_handler(
856837
self, input: temporalio.worker.HandleUpdateInput
857838
) -> Any:
858-
with self._workflow_maybe_run(
859-
f"HandleUpdate:{input.update}", input.headers, is_handler=True
860-
):
839+
with self._workflow_maybe_run(f"HandleUpdate:{input.update}", input.headers):
861840
return await super().handle_update_handler(input)
862841

863842

@@ -875,19 +854,22 @@ def __init__(
875854
self,
876855
next: temporalio.worker.WorkflowOutboundInterceptor,
877856
config: LangSmithInterceptor,
878-
inbound: _LangSmithWorkflowInboundInterceptor,
879857
) -> None:
880858
super().__init__(next)
881859
self._config = config
882-
self._inbound = inbound
883860

884861
@contextmanager
885862
def _traced_outbound(
886863
self, name: str, input: _InputWithHeaders
887864
) -> Iterator[_ReplaySafeRunTree | None]:
888-
"""Outbound workflow run creation with context injection into input.headers."""
889-
with self._config.maybe_run(name, parent=self._inbound._current_run) as run:
890-
context_source = run or self._inbound._current_run
865+
"""Outbound workflow run creation with context injection into input.headers.
866+
867+
Uses ambient context (``get_current_run_tree()``) instead of a cached
868+
snapshot, so ``@traceable`` step functions that wrap outbound calls
869+
correctly parent the outbound run under themselves.
870+
"""
871+
with self._config.maybe_run(name) as run:
872+
context_source = run or _get_current_run_for_propagation()
891873
if context_source:
892874
input.headers = _inject_context(input.headers, context_source)
893875
yield run
@@ -923,8 +905,8 @@ async def signal_external_workflow(
923905
return await super().signal_external_workflow(input)
924906

925907
def continue_as_new(self, input: temporalio.worker.ContinueAsNewInput) -> NoReturn:
926-
# No trace created, but inject context from inbound's current run
927-
current_run = getattr(self._inbound, "_current_run", None)
908+
# No trace created, but inject context from ambient run
909+
current_run = _get_current_run_for_propagation()
928910
if current_run:
929911
input.headers = _inject_context(input.headers, current_run)
930912
super().continue_as_new(input)
@@ -934,9 +916,8 @@ async def start_nexus_operation(
934916
) -> temporalio.workflow.NexusOperationHandle[Any]:
935917
with self._config.maybe_run(
936918
f"StartNexusOperation:{input.service}/{input.operation_name}",
937-
parent=self._inbound._current_run,
938919
) as run:
939-
context_source = run or self._inbound._current_run
920+
context_source = run or _get_current_run_for_propagation()
940921
if context_source:
941922
input.headers = _inject_nexus_context(
942923
input.headers or {}, context_source
@@ -969,20 +950,42 @@ async def execute_nexus_operation_start(
969950
| nexusrpc.handler.StartOperationResultAsync
970951
):
971952
parent = _extract_nexus_context(input.ctx.headers, self._config._executor)
972-
with self._config.maybe_run(
973-
f"RunStartNexusOperationHandler:{input.ctx.service}/{input.ctx.operation}",
974-
run_type="tool",
975-
parent=parent,
976-
):
977-
return await self.next.execute_nexus_operation_start(input)
953+
if parent is not None and hasattr(parent, "ls_client"):
954+
parent.ls_client = self._config._client
955+
ctx_kwargs: dict[str, Any] = {
956+
"client": self._config._client,
957+
"enabled": True,
958+
}
959+
if self._config._project_name:
960+
ctx_kwargs["project_name"] = self._config._project_name
961+
if parent:
962+
ctx_kwargs["parent"] = parent
963+
with tracing_context(**ctx_kwargs):
964+
with self._config.maybe_run(
965+
f"RunStartNexusOperationHandler:{input.ctx.service}/{input.ctx.operation}",
966+
run_type="tool",
967+
parent=parent,
968+
):
969+
return await self.next.execute_nexus_operation_start(input)
978970

979971
async def execute_nexus_operation_cancel(
980972
self, input: temporalio.worker.ExecuteNexusOperationCancelInput
981973
) -> None:
982974
parent = _extract_nexus_context(input.ctx.headers, self._config._executor)
983-
with self._config.maybe_run(
984-
f"RunCancelNexusOperationHandler:{input.ctx.service}/{input.ctx.operation}",
985-
run_type="tool",
986-
parent=parent,
987-
):
988-
return await self.next.execute_nexus_operation_cancel(input)
975+
if parent is not None and hasattr(parent, "ls_client"):
976+
parent.ls_client = self._config._client
977+
ctx_kwargs: dict[str, Any] = {
978+
"client": self._config._client,
979+
"enabled": True,
980+
}
981+
if self._config._project_name:
982+
ctx_kwargs["project_name"] = self._config._project_name
983+
if parent:
984+
ctx_kwargs["parent"] = parent
985+
with tracing_context(**ctx_kwargs):
986+
with self._config.maybe_run(
987+
f"RunCancelNexusOperationHandler:{input.ctx.service}/{input.ctx.operation}",
988+
run_type="tool",
989+
parent=parent,
990+
):
991+
return await self.next.execute_nexus_operation_cancel(input)

tests/contrib/langsmith/conftest.py

Lines changed: 24 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -58,24 +58,18 @@ def clear(self) -> None:
5858
self._by_id.clear()
5959

6060

61-
def dump_runs(collector: InMemoryRunCollector) -> list[str]:
62-
"""Reconstruct parent-child hierarchy from collected runs.
61+
def dump_traces(collector: InMemoryRunCollector) -> list[list[str]]:
62+
"""Reconstruct parent-child hierarchy grouped by root trace.
6363
64-
Returns a list of indented strings, e.g.:
65-
["StartWorkflow:MyWf", " RunWorkflow:MyWf", " StartActivity:do_thing"]
64+
Returns a list of traces, where each trace is a list of indented
65+
strings (same format as dump_runs). Each trace starts from a
66+
different root run.
6667
"""
6768
runs = collector.runs
6869
children: dict[str | None, list[_RunRecord]] = {}
6970
for r in runs:
7071
children.setdefault(r.parent_run_id, []).append(r)
7172

72-
result: list[str] = []
73-
74-
def _walk(parent_id: str | None, depth: int) -> None:
75-
for child in children.get(parent_id, []):
76-
result.append(" " * depth + child.name)
77-
_walk(child.id, depth + 1)
78-
7973
# Strict: reject dangling parent references
8074
known_ids = {r.id for r in runs}
8175
for r in runs:
@@ -84,10 +78,26 @@ def _walk(parent_id: str | None, depth: int) -> None:
8478
f"Run {r.name!r} (id={r.id}) has parent_run_id={r.parent_run_id} "
8579
f"which is not in the collected runs — dangling parent reference"
8680
)
87-
# Only walk true roots (parent_run_id is None)
88-
_walk(None, 0)
8981

90-
return result
82+
traces: list[list[str]] = []
83+
for root in children.get(None, []):
84+
trace: list[str] = []
85+
86+
def _walk(parent_id: str | None, depth: int) -> None:
87+
for child in children.get(parent_id, []):
88+
trace.append(" " * depth + child.name)
89+
_walk(child.id, depth + 1)
90+
91+
trace.append(root.name)
92+
_walk(root.id, 1)
93+
traces.append(trace)
94+
95+
return traces
96+
97+
98+
def dump_runs(collector: InMemoryRunCollector) -> list[str]:
99+
"""Flat list of all runs across all traces."""
100+
return [run for trace in dump_traces(collector) for run in trace]
91101

92102

93103
def make_mock_ls_client(collector: InMemoryRunCollector) -> MagicMock:

0 commit comments

Comments
 (0)