From e8458de827da1272c7e7dd6e9939ac1f4def8a75 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Fri, 8 May 2026 08:59:49 -0700 Subject: [PATCH 01/46] Bump ruff to 0.15 and reformat Also bump `[tool.ruff] target-version` from py39 to py310 to match `requires-python`; the old setting caused 0.15 to reject `match` statements in the codebase. --- pyproject.toml | 4 +- scripts/gen_bridge_client.py | 2 +- scripts/gen_protos.py | 12 +- .../worker/workflow_sandbox/_importer.py | 2 +- tests/conftest.py | 18 +- .../aws/lambda_worker/test_lambda_worker.py | 12 +- tests/contrib/aws/s3driver/test_s3driver.py | 18 +- .../langgraph/test_continue_as_new_cached.py | 18 +- .../contrib/langgraph/test_e2e_functional.py | 24 +-- tests/contrib/langsmith/test_integration.py | 72 +++---- .../openai_agents/test_openai_tracing.py | 180 +++++++++--------- .../opentelemetry/test_opentelemetry.py | 30 +-- .../test_opentelemetry_plugin.py | 30 +-- .../workflow_streams/test_workflow_streams.py | 12 +- tests/nexus/test_workflow_caller.py | 24 +-- .../test_workflow_caller_error_chains.py | 12 +- tests/nexus/test_workflow_run_operation.py | 6 +- tests/test_activity.py | 6 +- tests/test_extstore.py | 6 +- tests/test_plugins.py | 26 +-- tests/test_runtime.py | 12 +- tests/test_serialization_context.py | 2 +- tests/test_service.py | 6 +- tests/test_workflow.py | 6 +- tests/worker/test_command_aware_visitor.py | 12 +- uv.lock | 46 ++--- 26 files changed, 300 insertions(+), 298 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d85badb64..b74d2a7ad 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,7 +64,7 @@ dev = [ "pytest~=9.0", "pytest-asyncio>=0.21,<0.22", "pytest-timeout~=2.2", - "ruff>=0.5.0,<0.6", + "ruff>=0.15.12,<0.16", "toml>=0.10.2,<0.11", "twine>=4.0.1,<5", "maturin>=1.8.2", @@ -239,7 +239,7 @@ exclude = [ ] [tool.ruff] -target-version = "py39" +target-version = "py310" [build-system] requires = ["maturin>=1.0,<2.0"] diff --git a/scripts/gen_bridge_client.py b/scripts/gen_bridge_client.py index 07706f1f3..f06dd29e6 100644 --- a/scripts/gen_bridge_client.py +++ b/scripts/gen_bridge_client.py @@ -42,7 +42,7 @@ def generate_python_services( ''') def service_name(s): - return f"import {sanitize_proto_name(s.full_name)[:-len(s.name)-1]}" + return f"import {sanitize_proto_name(s.full_name)[: -len(s.name) - 1]}" service_imports = [ service_name(service_descriptor) diff --git a/scripts/gen_protos.py b/scripts/gen_protos.py index 0047952dc..e2be3975b 100644 --- a/scripts/gen_protos.py +++ b/scripts/gen_protos.py @@ -153,12 +153,12 @@ def check_proto_toolchain_versions(): _, _, proto_version = line.partition("==") elif line.startswith("grpcio-tools"): _, _, grpcio_tools_version = line.partition("==") - assert proto_version.startswith( - "3." - ), f"expected 3.x protobuf, found {proto_version}" - assert grpcio_tools_version.startswith( - "1.48." - ), f"expected 1.48.x grpcio-tools, found {grpcio_tools_version}" + assert proto_version.startswith("3."), ( + f"expected 3.x protobuf, found {proto_version}" + ) + assert grpcio_tools_version.startswith("1.48."), ( + f"expected 1.48.x grpcio-tools, found {grpcio_tools_version}" + ) def generate_protos(output_dir: Path): diff --git a/temporalio/worker/workflow_sandbox/_importer.py b/temporalio/worker/workflow_sandbox/_importer.py index 010c0c082..42f0e06b2 100644 --- a/temporalio/worker/workflow_sandbox/_importer.py +++ b/temporalio/worker/workflow_sandbox/_importer.py @@ -558,7 +558,7 @@ def _calc___package__(globals: Mapping[str, object]) -> str: if package is not None: if spec is not None and package != spec.parent: warnings.warn( - "__package__ != __spec__.parent " f"({package!r} != {spec.parent!r})", + f"__package__ != __spec__.parent ({package!r} != {spec.parent!r})", DeprecationWarning, stacklevel=3, ) diff --git a/tests/conftest.py b/tests/conftest.py index 303af2e3b..48d5f0669 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -19,25 +19,25 @@ # If there is an integration test environment variable set, we must remove the # first path from the sys.path so we can import the wheel instead if os.getenv("TEMPORAL_INTEGRATION_TEST"): - assert ( - sys.path[0] == os.getcwd() - ), "Expected first sys.path to be the current working dir" + assert sys.path[0] == os.getcwd(), ( + "Expected first sys.path to be the current working dir" + ) sys.path.pop(0) # Import temporalio and confirm it is prefixed with virtual env import temporalio - assert temporalio.__file__.startswith( - sys.prefix - ), f"Expected {temporalio.__file__} to be in {sys.prefix}" + assert temporalio.__file__.startswith(sys.prefix), ( + f"Expected {temporalio.__file__} to be in {sys.prefix}" + ) # Unless specifically overridden, we expect tests to run under protobuf 4.x/5.x lib import google.protobuf protobuf_version = google.protobuf.__version__ if os.getenv("TEMPORAL_TEST_PROTO3"): - assert protobuf_version.startswith( - "3." - ), f"Expected protobuf 3.x, got {protobuf_version}" + assert protobuf_version.startswith("3."), ( + f"Expected protobuf 3.x, got {protobuf_version}" + ) else: assert ( protobuf_version.startswith("4.") diff --git a/tests/contrib/aws/lambda_worker/test_lambda_worker.py b/tests/contrib/aws/lambda_worker/test_lambda_worker.py index 178e078ac..cda1cd12f 100644 --- a/tests/contrib/aws/lambda_worker/test_lambda_worker.py +++ b/tests/contrib/aws/lambda_worker/test_lambda_worker.py @@ -247,11 +247,13 @@ def fake_create_worker(_client: Any, **kwargs: Any) -> Any: load_config=lambda: ClientConfigProfile(), getenv={"TEMPORAL_TASK_QUEUE": "test-queue"}.get, # type: ignore[arg-type] extract_lambda_ctx=lambda ctx: ( - ctx.aws_request_id, - ctx.invoked_function_arn, - ) - if hasattr(ctx, "aws_request_id") - else None, + ( + ctx.aws_request_id, + ctx.invoked_function_arn, + ) + if hasattr(ctx, "aws_request_id") + else None + ), ) diff --git a/tests/contrib/aws/s3driver/test_s3driver.py b/tests/contrib/aws/s3driver/test_s3driver.py index 64e0d53ab..19b3419f8 100644 --- a/tests/contrib/aws/s3driver/test_s3driver.py +++ b/tests/contrib/aws/s3driver/test_s3driver.py @@ -430,9 +430,9 @@ async def test_skips_upload_when_key_exists( assert counting_driver_client.put_object_count == 1 await driver.store(make_store_context(), [payload]) - assert ( - counting_driver_client.put_object_count == 1 - ), "put_object should not be called for an existing key" + assert counting_driver_client.put_object_count == 1, ( + "put_object should not be called for an existing key" + ) async def test_skips_upload_preserves_data( self, driver_client: S3StorageDriverClient @@ -812,9 +812,9 @@ async def test_store_cancels_remaining_on_failure( assert isinstance(exc_info.value.__cause__, ConnectionError) assert str(exc_info.value.__cause__) == "S3 connection lost" - assert ( - len(faulty_client.cancelled) == 2 - ), "Expected 2 remaining tasks to be cancelled" + assert len(faulty_client.cancelled) == 2, ( + "Expected 2 remaining tasks to be cancelled" + ) async def test_retrieve_cancels_remaining_on_failure( self, driver_client: S3StorageDriverClient @@ -838,9 +838,9 @@ async def test_retrieve_cancels_remaining_on_failure( assert isinstance(exc_info.value.__cause__, ConnectionError) assert str(exc_info.value.__cause__) == "S3 connection lost" - assert ( - len(faulty_client.cancelled) == 2 - ), "Expected 2 remaining tasks to be cancelled" + assert len(faulty_client.cancelled) == 2, ( + "Expected 2 remaining tasks to be cancelled" + ) # --------------------------------------------------------------------------- diff --git a/tests/contrib/langgraph/test_continue_as_new_cached.py b/tests/contrib/langgraph/test_continue_as_new_cached.py index b19620999..af41f384f 100644 --- a/tests/contrib/langgraph/test_continue_as_new_cached.py +++ b/tests/contrib/langgraph/test_continue_as_new_cached.py @@ -120,12 +120,12 @@ async def test_graph_continue_as_new_cached(client: Client): assert result == {"value": 260} # Each node should execute exactly once — phases 2 and 3 use cached results. - assert ( - _execution_counts.get("multiply", 0) == 1 - ), f"multiply executed {_execution_counts.get('multiply', 0)} times, expected 1" - assert ( - _execution_counts.get("add", 0) == 1 - ), f"add executed {_execution_counts.get('add', 0)} times, expected 1" - assert ( - _execution_counts.get("double", 0) == 1 - ), f"double executed {_execution_counts.get('double', 0)} times, expected 1" + assert _execution_counts.get("multiply", 0) == 1, ( + f"multiply executed {_execution_counts.get('multiply', 0)} times, expected 1" + ) + assert _execution_counts.get("add", 0) == 1, ( + f"add executed {_execution_counts.get('add', 0)} times, expected 1" + ) + assert _execution_counts.get("double", 0) == 1, ( + f"double executed {_execution_counts.get('double', 0)} times, expected 1" + ) diff --git a/tests/contrib/langgraph/test_e2e_functional.py b/tests/contrib/langgraph/test_e2e_functional.py index 7f4ffab88..d10efb483 100644 --- a/tests/contrib/langgraph/test_e2e_functional.py +++ b/tests/contrib/langgraph/test_e2e_functional.py @@ -219,15 +219,15 @@ async def test_continue_as_new_with_checkpoint(self, client: Client) -> None: assert result["result"] == 260 counts = get_task_execution_counts() - assert ( - counts.get("task_a", 0) == 1 - ), f"task_a executed {counts.get('task_a', 0)} times, expected 1" - assert ( - counts.get("task_b", 0) == 1 - ), f"task_b executed {counts.get('task_b', 0)} times, expected 1" - assert ( - counts.get("task_c", 0) == 1 - ), f"task_c executed {counts.get('task_c', 0)} times, expected 1" + assert counts.get("task_a", 0) == 1, ( + f"task_a executed {counts.get('task_a', 0)} times, expected 1" + ) + assert counts.get("task_b", 0) == 1, ( + f"task_b executed {counts.get('task_b', 0)} times, expected 1" + ) + assert counts.get("task_c", 0) == 1, ( + f"task_c executed {counts.get('task_c', 0)} times, expected 1" + ) class TestFunctionalAPIPartialExecution: @@ -266,9 +266,9 @@ async def test_partial_execution_five_tasks(self, client: Client) -> None: counts = get_task_execution_counts() for i in range(1, 6): - assert ( - counts.get(f"step_{i}", 0) == 1 - ), f"step_{i} executed {counts.get(f'step_{i}', 0)} times, expected 1" + assert counts.get(f"step_{i}", 0) == 1, ( + f"step_{i} executed {counts.get(f'step_{i}', 0)} times, expected 1" + ) class TestFunctionalAPIInterruptV2: diff --git a/tests/contrib/langsmith/test_integration.py b/tests/contrib/langsmith/test_integration.py index 78d48c71e..a89d1ea4a 100644 --- a/tests/contrib/langsmith/test_integration.py +++ b/tests/contrib/langsmith/test_integration.py @@ -367,27 +367,27 @@ async def test_workflow_activity_trace_hierarchy( " RunActivity:simple_activity", " simple_activity", ] - assert ( - hierarchy == expected - ), f"Hierarchy mismatch.\nExpected:\n{expected}\nActual:\n{hierarchy}" + assert hierarchy == expected, ( + f"Hierarchy mismatch.\nExpected:\n{expected}\nActual:\n{hierarchy}" + ) # Verify run_type: RunActivity is "tool", others are "chain" for run in collector.runs: if run.name == "RunActivity:simple_activity": - assert ( - run.run_type == "tool" - ), f"Expected RunActivity run_type='tool', got '{run.run_type}'" + assert run.run_type == "tool", ( + f"Expected RunActivity run_type='tool', got '{run.run_type}'" + ) else: - assert ( - run.run_type == "chain" - ), f"Expected {run.name} run_type='chain', got '{run.run_type}'" + assert run.run_type == "chain", ( + f"Expected {run.name} run_type='chain', got '{run.run_type}'" + ) # Verify successful runs have outputs == {"status": "ok"} for run in collector.runs: if ":" in run.name: # Interceptor runs use "Type:Name" format - assert run.outputs == { - "status": "ok" - }, f"Expected {run.name} outputs={{'status': 'ok'}}, got {run.outputs}" + assert run.outputs == {"status": "ok"}, ( + f"Expected {run.name} outputs={{'status': 'ok'}}, got {run.outputs}" + ) # --------------------------------------------------------------------------- @@ -475,9 +475,9 @@ async def test_activity_failure_marked( " RunActivity:failing_activity", " failing_activity", ] - assert ( - hierarchy == expected - ), f"Hierarchy mismatch.\nExpected:\n{expected}\nActual:\n{hierarchy}" + assert hierarchy == expected, ( + f"Hierarchy mismatch.\nExpected:\n{expected}\nActual:\n{hierarchy}" + ) # Verify the RunActivity run has an error activity_runs = [ r for r in collector.runs if r.name == "RunActivity:failing_activity" @@ -514,9 +514,9 @@ async def test_workflow_failure_marked( "StartWorkflow:FailingWorkflow", "RunWorkflow:FailingWorkflow", ] - assert ( - hierarchy == expected - ), f"Hierarchy mismatch.\nExpected:\n{expected}\nActual:\n{hierarchy}" + assert hierarchy == expected, ( + f"Hierarchy mismatch.\nExpected:\n{expected}\nActual:\n{hierarchy}" + ) # Verify the RunWorkflow run has an error wf_runs = [r for r in collector.runs if r.name == "RunWorkflow:FailingWorkflow"] assert len(wf_runs) == 1 @@ -555,9 +555,9 @@ async def test_benign_error_not_marked( " RunActivity:benign_failing_activity", " benign_failing_activity", ] - assert ( - hierarchy == expected - ), f"Hierarchy mismatch.\nExpected:\n{expected}\nActual:\n{hierarchy}" + assert hierarchy == expected, ( + f"Hierarchy mismatch.\nExpected:\n{expected}\nActual:\n{hierarchy}" + ) # The RunActivity run for benign error should NOT have error set activity_runs = [ r for r in collector.runs if r.name == "RunActivity:benign_failing_activity" @@ -988,15 +988,15 @@ async def test_factory_traceable_no_external_context( " outer_chain", " inner_llm_call", ] - assert ( - hierarchy == expected - ), f"Hierarchy mismatch.\nExpected:\n{expected}\nActual:\n{hierarchy}" + assert hierarchy == expected, ( + f"Hierarchy mismatch.\nExpected:\n{expected}\nActual:\n{hierarchy}" + ) # Verify no duplicate run IDs (replay safety with max_cached_workflows=0) run_ids = [r.id for r in collector.runs] - assert len(run_ids) == len( - set(run_ids) - ), f"Duplicate run IDs found (replay issue): {run_ids}" + assert len(run_ids) == len(set(run_ids)), ( + f"Duplicate run IDs found (replay issue): {run_ids}" + ) async def test_factory_passes_project_name_to_children( self, @@ -1081,15 +1081,15 @@ async def test_mixed_sync_async_traceable_with_temporal_runs( " outer_chain", " inner_llm_call", ] - assert ( - hierarchy == expected - ), f"Hierarchy mismatch.\nExpected:\n{expected}\nActual:\n{hierarchy}" + assert hierarchy == expected, ( + f"Hierarchy mismatch.\nExpected:\n{expected}\nActual:\n{hierarchy}" + ) # Verify no duplicate run IDs (replay safety with max_cached_workflows=0) run_ids = [r.id for r in collector.runs] - assert len(run_ids) == len( - set(run_ids) - ), f"Duplicate run IDs found (replay issue): {run_ids}" + assert len(run_ids) == len(set(run_ids)), ( + f"Duplicate run IDs found (replay issue): {run_ids}" + ) # --- Nexus service with direct @traceable call in handler --- @@ -1190,9 +1190,9 @@ async def test_nexus_direct_traceable_without_temporal_runs( "nexus_direct_traceable", " inner_llm_call", ] - assert ( - hierarchy == expected - ), f"Hierarchy mismatch.\nExpected:\n{expected}\nActual:\n{hierarchy}" + assert hierarchy == expected, ( + f"Hierarchy mismatch.\nExpected:\n{expected}\nActual:\n{hierarchy}" + ) # --------------------------------------------------------------------------- diff --git a/tests/contrib/openai_agents/test_openai_tracing.py b/tests/contrib/openai_agents/test_openai_tracing.py index 7613ae49e..facc3212b 100644 --- a/tests/contrib/openai_agents/test_openai_tracing.py +++ b/tests/contrib/openai_agents/test_openai_tracing.py @@ -369,20 +369,20 @@ async def ready() -> bool: assert workflow_span is not None, "Workflow span should exist" # Verify parenting: External trace should be root, workflow span should be child of external trace - assert ( - external_span.parent is None - ), "External trace should have no parent (be root)" + assert external_span.parent is None, ( + "External trace should have no parent (be root)" + ) assert workflow_span.parent is not None, "Workflow span should have a parent" assert external_span.context is not None, "External span should have context" - assert ( - workflow_span.parent.span_id == external_span.context.span_id - ), "Workflow span should be child of external trace" + assert workflow_span.parent.span_id == external_span.context.span_id, ( + "Workflow span should be child of external trace" + ) # Verify all spans have unique IDs span_ids = [span.context.span_id for span in spans if span.context] - assert len(span_ids) == len( - set(span_ids) - ), f"All spans should have unique IDs, got: {span_ids}" + assert len(span_ids) == len(set(span_ids)), ( + f"All spans should have unique IDs, got: {span_ids}" + ) async def test_external_trace_and_span_to_workflow_spans( @@ -462,27 +462,27 @@ async def ready() -> bool: assert workflow_span is not None, "Workflow span should exist" # Verify parenting: External span should be child of trace, workflow span should be child of external span - assert ( - external_trace_span.parent is None - ), "External trace should have no parent (be root)" + assert external_trace_span.parent is None, ( + "External trace should have no parent (be root)" + ) assert external_span.parent is not None, "External span should have a parent" - assert ( - external_trace_span.context is not None - ), "External trace span should have context" - assert ( - external_span.parent.span_id == external_trace_span.context.span_id - ), "External span should be child of external trace" + assert external_trace_span.context is not None, ( + "External trace span should have context" + ) + assert external_span.parent.span_id == external_trace_span.context.span_id, ( + "External span should be child of external trace" + ) assert workflow_span.parent is not None, "Workflow span should have a parent" assert external_span.context is not None, "External span should have context" - assert ( - workflow_span.parent.span_id == external_span.context.span_id - ), "Workflow span should be child of external span" + assert workflow_span.parent.span_id == external_span.context.span_id, ( + "Workflow span should be child of external span" + ) # Verify all spans have unique IDs span_ids = [span.context.span_id for span in spans if span.context] - assert len(span_ids) == len( - set(span_ids) - ), f"All spans should have unique IDs, got: {span_ids}" + assert len(span_ids) == len(set(span_ids)), ( + f"All spans should have unique IDs, got: {span_ids}" + ) async def test_workflow_only_trace_to_spans( @@ -556,16 +556,16 @@ async def ready() -> bool: assert workflow_span is not None, "Workflow span should exist" # Verify parenting: Workflow trace should be root, workflow span should be child of workflow trace - assert ( - workflow_trace_span.parent is None - ), "Workflow trace should have no parent (be root)" + assert workflow_trace_span.parent is None, ( + "Workflow trace should have no parent (be root)" + ) assert workflow_span.parent is not None, "Workflow span should have a parent" - assert ( - workflow_trace_span.context is not None - ), "Workflow trace span should have context" - assert ( - workflow_span.parent.span_id == workflow_trace_span.context.span_id - ), "Workflow span should be child of workflow trace" + assert workflow_trace_span.context is not None, ( + "Workflow trace span should have context" + ) + assert workflow_span.parent.span_id == workflow_trace_span.context.span_id, ( + "Workflow span should be child of workflow trace" + ) @workflow.defn @@ -611,14 +611,14 @@ async def test_custom_span_without_trace_context( if "Should not appear" in span.name or "Neither should this" in span.name ] - assert ( - len(custom_spans) == 0 - ), f"Expected no custom spans without trace context, but found: {[s.name for s in custom_spans]}" + assert len(custom_spans) == 0, ( + f"Expected no custom spans without trace context, but found: {[s.name for s in custom_spans]}" + ) # Should have no spans at all since no trace was started and spans should be dropped - assert ( - len(spans) == 0 - ), f"Expected no spans without trace context, but found: {[s.name for s in spans]}" + assert len(spans) == 0, ( + f"Expected no spans without trace context, but found: {[s.name for s in spans]}" + ) async def test_otel_tracing_in_runner( @@ -696,35 +696,35 @@ async def test_otel_tracing_in_runner( span_ids = {span.context.span_id for span in spans if span.context} for span in spans: if span.parent: - assert ( - span.parent.span_id in span_ids - ), f"Span '{span.name}' has invalid parent reference - parent span doesn't exist" + assert span.parent.span_id in span_ids, ( + f"Span '{span.name}' has invalid parent reference - parent span doesn't exist" + ) # Validate logical parent-child relationships match user code structure workflow_trace_spans = [span for span in spans if "Research workflow" in span.name] - assert ( - len(workflow_trace_spans) == 1 - ), f"Expected exactly one 'Research workflow' trace, got {len(workflow_trace_spans)}" + assert len(workflow_trace_spans) == 1, ( + f"Expected exactly one 'Research workflow' trace, got {len(workflow_trace_spans)}" + ) workflow_span = workflow_trace_spans[0] assert workflow_span.context is not None # Research manager should be child of workflow trace research_span = research_manager_spans[0] assert research_span.context is not None - assert ( - research_span.parent is not None - ), "Research manager span should have a parent" - assert ( - research_span.parent.span_id == workflow_span.context.span_id - ), "Expected 'Research manager' to be child of 'Research workflow' trace" + assert research_span.parent is not None, ( + "Research manager span should have a parent" + ) + assert research_span.parent.span_id == workflow_span.context.span_id, ( + "Expected 'Research manager' to be child of 'Research workflow' trace" + ) # Search the web should be child of research manager search_span = search_web_spans[0] assert search_span.context is not None assert search_span.parent is not None, "Search the web span should have a parent" - assert ( - search_span.parent.span_id == research_span.context.span_id - ), "Expected 'Search the web' to be child of 'Research manager' span" + assert search_span.parent.span_id == research_span.context.span_id, ( + "Expected 'Search the web' to be child of 'Research manager' span" + ) # All search agent spans should be descendants of "Search the web" # (the SDK now inserts a "task" span between "Search the web" and the agent) @@ -741,12 +741,12 @@ def is_descendant_of(child: ReadableSpan, ancestor_span_id: int) -> bool: return False for search_agent_span in search_agent_spans: - assert ( - search_agent_span.parent is not None - ), f"Search agent span '{search_agent_span.name}' should have a parent" - assert is_descendant_of( - search_agent_span, search_span.context.span_id - ), f"Expected all 'Search agent' spans to be descendants of 'Search the web' span" + assert search_agent_span.parent is not None, ( + f"Search agent span '{search_agent_span.name}' should have a parent" + ) + assert is_descendant_of(search_agent_span, search_span.context.span_id), ( + f"Expected all 'Search agent' spans to be descendants of 'Search the web' span" + ) # PlannerAgent and WriterAgent should be descendants of research manager planner_spans = [span for span in spans if "PlannerAgent" in span.name] @@ -754,15 +754,15 @@ def is_descendant_of(child: ReadableSpan, ancestor_span_id: int) -> bool: for planner_span in planner_spans: assert planner_span.parent is not None, "PlannerAgent span should have a parent" - assert is_descendant_of( - planner_span, research_span.context.span_id - ), "Expected 'PlannerAgent' to be descendant of 'Research manager' span" + assert is_descendant_of(planner_span, research_span.context.span_id), ( + "Expected 'PlannerAgent' to be descendant of 'Research manager' span" + ) for writer_span in writer_spans: assert writer_span.parent is not None, "WriterAgent span should have a parent" - assert is_descendant_of( - writer_span, research_span.context.span_id - ), "Expected 'WriterAgent' to be descendant of 'Research manager' span" + assert is_descendant_of(writer_span, research_span.context.span_id), ( + "Expected 'WriterAgent' to be descendant of 'Research manager' span" + ) @workflow.defn @@ -879,32 +879,32 @@ async def ready() -> bool: assert direct_otel_span is not None, "Direct OTEL span should exist" # Verify parenting chain: Client SDK trace -> Workflow SDK span -> Direct OTEL span - assert ( - client_sdk_trace_span.parent is None - ), "Client SDK trace should have no parent (be root)" + assert client_sdk_trace_span.parent is None, ( + "Client SDK trace should have no parent (be root)" + ) - assert ( - workflow_sdk_span.parent is not None - ), "Workflow SDK span should have a parent" - assert ( - client_sdk_trace_span.context is not None - ), "Client SDK trace span should have context" - assert ( - workflow_sdk_span.parent.span_id == client_sdk_trace_span.context.span_id - ), "Workflow SDK span should be child of Client SDK trace" + assert workflow_sdk_span.parent is not None, ( + "Workflow SDK span should have a parent" + ) + assert client_sdk_trace_span.context is not None, ( + "Client SDK trace span should have context" + ) + assert workflow_sdk_span.parent.span_id == client_sdk_trace_span.context.span_id, ( + "Workflow SDK span should be child of Client SDK trace" + ) assert direct_otel_span.parent is not None, "Direct OTEL span should have a parent" - assert ( - workflow_sdk_span.context is not None - ), "Workflow SDK span should have context" - assert ( - direct_otel_span.parent.span_id == workflow_sdk_span.context.span_id - ), "Direct OTEL span should be child of Workflow SDK span" + assert workflow_sdk_span.context is not None, ( + "Workflow SDK span should have context" + ) + assert direct_otel_span.parent.span_id == workflow_sdk_span.context.span_id, ( + "Direct OTEL span should be child of Workflow SDK span" + ) # Verify all spans belong to the same trace - assert ( - workflow_sdk_span.context is not None - ), "Workflow SDK span should have context" + assert workflow_sdk_span.context is not None, ( + "Workflow SDK span should have context" + ) assert direct_otel_span.context is not None, "Direct OTEL span should have context" assert ( client_sdk_trace_span.context.trace_id @@ -914,6 +914,6 @@ async def ready() -> bool: # Verify all spans have unique IDs span_ids = [span.context.span_id for span in spans if span.context] - assert len(span_ids) == len( - set(span_ids) - ), f"All spans should have unique IDs, got: {span_ids}" + assert len(span_ids) == len(set(span_ids)), ( + f"All spans should have unique IDs, got: {span_ids}" + ) diff --git a/tests/contrib/opentelemetry/test_opentelemetry.py b/tests/contrib/opentelemetry/test_opentelemetry.py index 94bb3fda5..71e2fa41d 100644 --- a/tests/contrib/opentelemetry/test_opentelemetry.py +++ b/tests/contrib/opentelemetry/test_opentelemetry.py @@ -720,12 +720,12 @@ async def test_opentelemetry_baggage_propagation_basic(client_with_tracing: Clie task_queue=task_queue, ) - assert ( - result["user_id"] == "test-user-123" - ), "user.id baggage should propagate to activity" - assert ( - result["tenant_id"] == "some-corp" - ), "tenant.id baggage should propagate to activity" + assert result["user_id"] == "test-user-123", ( + "user.id baggage should propagate to activity" + ) + assert result["tenant_id"] == "some-corp", ( + "tenant.id baggage should propagate to activity" + ) @activity.defn @@ -886,15 +886,15 @@ def tracked_detach(token): # type:ignore[reportMissingParameterType] id=f"workflow_{uuid.uuid4()}", task_queue=task_queue, ) - assert ( - not expect_failure - ), "This test should have raised an exception" + assert not expect_failure, ( + "This test should have raised an exception" + ) except Exception: assert expect_failure, "This test is not expeced to raise" - assert ( - attach_count == detach_count - ), f"Context leak detected: {attach_count} attaches vs {detach_count} detaches. " + assert attach_count == detach_count, ( + f"Context leak detected: {attach_count} attaches vs {detach_count} detaches. " + ) assert attach_count > 0, "Expected at least one context attach/detach" finally: @@ -1030,6 +1030,6 @@ def otel_context_error(record: logging.LogRecord) -> bool: and "Failed to detach context" in record.message ) - assert ( - capturer.find(otel_context_error) is None - ), "Detach from context message should not be logged" + assert capturer.find(otel_context_error) is None, ( + "Detach from context message should not be logged" + ) diff --git a/tests/contrib/opentelemetry/test_opentelemetry_plugin.py b/tests/contrib/opentelemetry/test_opentelemetry_plugin.py index 06ad330ed..3fd50e89b 100644 --- a/tests/contrib/opentelemetry/test_opentelemetry_plugin.py +++ b/tests/contrib/opentelemetry/test_opentelemetry_plugin.py @@ -123,9 +123,9 @@ async def test_otel_tracing_basic(client: Client, reset_otel_tracer_provider: An # Verify the span hierarchy matches expectations actual_hierarchy = dump_spans(spans, with_attributes=False) - assert ( - actual_hierarchy == expected_hierarchy - ), f"Span hierarchy mismatch.\nExpected:\n{expected_hierarchy}\nActual:\n{actual_hierarchy}" + assert actual_hierarchy == expected_hierarchy, ( + f"Span hierarchy mismatch.\nExpected:\n{expected_hierarchy}\nActual:\n{actual_hierarchy}" + ) @workflow.defn @@ -382,9 +382,9 @@ async def test_opentelemetry_comprehensive_tracing( # Verify the span hierarchy matches expectations actual_hierarchy = dump_spans(spans, with_attributes=False) - assert ( - actual_hierarchy == expected_hierarchy - ), f"Span hierarchy mismatch.\nExpected:\n{expected_hierarchy}\nActual:\n{actual_hierarchy}" + assert actual_hierarchy == expected_hierarchy, ( + f"Span hierarchy mismatch.\nExpected:\n{expected_hierarchy}\nActual:\n{actual_hierarchy}" + ) async def test_otel_tracing_with_added_spans( @@ -439,9 +439,9 @@ async def test_otel_tracing_with_added_spans( # Verify the span hierarchy matches expectations actual_hierarchy = dump_spans(spans, with_attributes=False) - assert ( - actual_hierarchy == expected_hierarchy - ), f"Span hierarchy mismatch.\nExpected:\n{expected_hierarchy}\nActual:\n{actual_hierarchy}" + assert actual_hierarchy == expected_hierarchy, ( + f"Span hierarchy mismatch.\nExpected:\n{expected_hierarchy}\nActual:\n{actual_hierarchy}" + ) task_fail_once_workflow_has_failed = False @@ -507,9 +507,9 @@ async def test_otel_tracing_workflow_task_failure( ] actual_hierarchy = dump_spans(spans, with_attributes=False) - assert ( - actual_hierarchy == expected_hierarchy - ), f"Span hierarchy mismatch.\nExpected:\n{expected_hierarchy}\nActual:\n{actual_hierarchy}" + assert actual_hierarchy == expected_hierarchy, ( + f"Span hierarchy mismatch.\nExpected:\n{expected_hierarchy}\nActual:\n{actual_hierarchy}" + ) @workflow.defn @@ -562,9 +562,9 @@ async def test_otel_tracing_workflow_failure( ] actual_hierarchy = dump_spans(spans, with_attributes=False) - assert ( - actual_hierarchy == expected_hierarchy - ), f"Span hierarchy mismatch.\nExpected:\n{expected_hierarchy}\nActual:\n{actual_hierarchy}" + assert actual_hierarchy == expected_hierarchy, ( + f"Span hierarchy mismatch.\nExpected:\n{expected_hierarchy}\nActual:\n{actual_hierarchy}" + ) async def test_otel_standalone_activity_tracing( diff --git a/tests/contrib/workflow_streams/test_workflow_streams.py b/tests/contrib/workflow_streams/test_workflow_streams.py index 203e1313d..7353cbdd5 100644 --- a/tests/contrib/workflow_streams/test_workflow_streams.py +++ b/tests/contrib/workflow_streams/test_workflow_streams.py @@ -2133,9 +2133,9 @@ async def standalone_publish_to_broker(input: StandalonePublishInput) -> None: ``WorkflowStreamClient.create``. ``from_within_activity`` is not usable here because the activity has no parent workflow. """ - assert ( - activity.info().workflow_id is None - ), "test bug: this activity should be standalone" + assert activity.info().workflow_id is None, ( + "test bug: this activity should be standalone" + ) client = WorkflowStreamClient.create( client=activity.client(), workflow_id=input.broker_workflow_id, @@ -2149,9 +2149,9 @@ async def standalone_publish_to_broker(input: StandalonePublishInput) -> None: @activity.defn(name="standalone_subscribe_to_broker") async def standalone_subscribe_to_broker(input: CrossWorkflowInput) -> list[str]: - assert ( - activity.info().workflow_id is None - ), "test bug: this activity should be standalone" + assert activity.info().workflow_id is None, ( + "test bug: this activity should be standalone" + ) client = WorkflowStreamClient.create( client=activity.client(), workflow_id=input.broker_workflow_id, diff --git a/tests/nexus/test_workflow_caller.py b/tests/nexus/test_workflow_caller.py index 38f51cd63..df6ace9fa 100644 --- a/tests/nexus/test_workflow_caller.py +++ b/tests/nexus/test_workflow_caller.py @@ -440,13 +440,13 @@ async def run( self._nexus_operation_start_resolved = True if not input.op_input.response_type.exception_in_operation_start: if isinstance(input.op_input.response_type, SyncResponse): - assert ( - op_handle.operation_token is None - ), "operation_token should be absent after a sync response" + assert op_handle.operation_token is None, ( + "operation_token should be absent after a sync response" + ) else: - assert ( - op_handle.operation_token - ), "operation_token should be present after an async response" + assert op_handle.operation_token, ( + "operation_token should be present after an async response" + ) if request_cancel: # Even for SyncResponse, the op_handle future is not done at this point; that @@ -2302,16 +2302,16 @@ async def test_request_deadline_is_accessible_in_operation( assert len(service_handler.start_deadlines_received) == 1 deadline = service_handler.start_deadlines_received[0] - assert ( - deadline is not None - ), "request_deadline should be set in StartOperationContext" + assert deadline is not None, ( + "request_deadline should be set in StartOperationContext" + ) assert deadline.tzinfo is timezone.utc, "request_deadline should be in utc" await asyncio.wait_for(service_handler.cancel_received.wait(), 1) assert len(service_handler.cancel_deadlines_received) == 1 deadline = service_handler.cancel_deadlines_received[0] - assert ( - deadline is not None - ), "request_deadline should be set in CancelOperationContext" + assert deadline is not None, ( + "request_deadline should be set in CancelOperationContext" + ) assert deadline.tzinfo is timezone.utc, "request_deadline should be in utc" diff --git a/tests/nexus/test_workflow_caller_error_chains.py b/tests/nexus/test_workflow_caller_error_chains.py index 28831e476..9ff84f405 100644 --- a/tests/nexus/test_workflow_caller_error_chains.py +++ b/tests/nexus/test_workflow_caller_error_chains.py @@ -574,15 +574,15 @@ def _validate_exception_chain( # Check remaining expected errors are all optional while expected_idx < len(expected_chain): expected = expected_chain[expected_idx] - assert ( - expected.optional - ), f"Required expected error not found in chain: {expected}" + assert expected.optional, ( + f"Required expected error not found in chain: {expected}" + ) expected_idx += 1 # Check no remaining actual errors - assert actual_idx == len( - actual_chain - ), f"Unexpected errors in chain: {actual_chain[actual_idx:]}" + assert actual_idx == len(actual_chain), ( + f"Unexpected errors in chain: {actual_chain[actual_idx:]}" + ) @workflow.defn(sandboxed=False) diff --git a/tests/nexus/test_workflow_run_operation.py b/tests/nexus/test_workflow_run_operation.py index 489353165..3ba9545fc 100644 --- a/tests/nexus/test_workflow_run_operation.py +++ b/tests/nexus/test_workflow_run_operation.py @@ -182,7 +182,7 @@ async def test_request_deadline_is_accessible_in_workflow_run_operation( assert len(service_handler.start_deadlines_received) == 1 deadline = service_handler.start_deadlines_received[0] - assert ( - deadline is not None - ), "request_deadline should be set in WorkflowRunOperationContext" + assert deadline is not None, ( + "request_deadline should be set in WorkflowRunOperationContext" + ) assert deadline.tzinfo is timezone.utc, "request_deadline should be in utc" diff --git a/tests/test_activity.py b/tests/test_activity.py index 8ed4729ec..172040257 100644 --- a/tests/test_activity.py +++ b/tests/test_activity.py @@ -867,9 +867,9 @@ async def test_id_conflict_policy_fail(client: Client, env: WorkflowEnvironment) id_conflict_policy=ActivityIDConflictPolicy.FAIL, ) assert err.value.activity_id == activity_id - assert "Activity" in str( - err.value - ), f"Expected 'Activity' in error message, got: {err.value}" + assert "Activity" in str(err.value), ( + f"Expected 'Activity' in error message, got: {err.value}" + ) async def test_id_conflict_policy_use_existing( diff --git a/tests/test_extstore.py b/tests/test_extstore.py index 196632042..9a058c582 100644 --- a/tests/test_extstore.py +++ b/tests/test_extstore.py @@ -162,9 +162,9 @@ async def test_extstore_composite_conditional(self): options = ExternalStorage( drivers=[hot_driver, cold_driver], - driver_selector=lambda context, payload: hot_driver - if payload.ByteSize() < 500 - else cold_driver, + driver_selector=lambda context, payload: ( + hot_driver if payload.ByteSize() < 500 else cold_driver + ), payload_size_threshold=100, ) converter = DataConverter(external_storage=options) diff --git a/tests/test_plugins.py b/tests/test_plugins.py index e54f8065f..e8823af27 100644 --- a/tests/test_plugins.py +++ b/tests/test_plugins.py @@ -6,8 +6,8 @@ from typing import cast import pytest - import temporalio.bridge.temporal_sdk_bridge + import temporalio.client import temporalio.converter import temporalio.worker @@ -463,9 +463,9 @@ async def test_simple_plugin_worker_interceptor_only_used_on_worker( # The interceptor should NOT have been used for client interception # since the plugin was not added to the client - assert ( - not interceptor.client_intercepted - ), "Client interceptor should not have been used" + assert not interceptor.client_intercepted, ( + "Client interceptor should not have been used" + ) # The interceptor SHOULD have been used for worker interception # even though it was specified in interceptors @@ -527,9 +527,9 @@ async def test_simple_plugin_interceptor_duplication_when_used_on_client_and_wor assert result == "Hello, test!" # The workflow interceptor should only be called ONCE, not twice - assert ( - interceptor.call_count["execute_workflow"] == 1 - ), f"Expected execute_workflow to be called once, but was called {interceptor.call_count['execute_workflow']} times. This indicates interceptor duplication in execution." + assert interceptor.call_count["execute_workflow"] == 1, ( + f"Expected execute_workflow to be called once, but was called {interceptor.call_count['execute_workflow']} times. This indicates interceptor duplication in execution." + ) async def test_simple_plugin_no_duplication_when_interceptor_in_both_client_and_worker_params( @@ -571,9 +571,9 @@ async def test_simple_plugin_no_duplication_when_interceptor_in_both_client_and_ assert result == "Hello, test!" # The workflow interceptor should only be called ONCE, not twice - assert ( - interceptor.call_count["execute_workflow"] == 1 - ), f"Expected execute_workflow to be called once, but was called {interceptor.call_count['execute_workflow']} times. This indicates interceptor duplication in execution." + assert interceptor.call_count["execute_workflow"] == 1, ( + f"Expected execute_workflow to be called once, but was called {interceptor.call_count['execute_workflow']} times. This indicates interceptor duplication in execution." + ) async def test_simple_plugin_no_duplication_in_interceptor_chain( @@ -612,6 +612,6 @@ async def test_simple_plugin_no_duplication_in_interceptor_chain( assert result == "Hello, test!" # The workflow interceptor should only be called ONCE, not twice - assert ( - interceptor.call_count["execute_workflow"] == 1 - ), f"Expected execute_workflow to be called once, but was called {interceptor.call_count['execute_workflow']} times. This indicates interceptor duplication in the chain." + assert interceptor.call_count["execute_workflow"] == 1, ( + f"Expected execute_workflow to be called once, but was called {interceptor.call_count['execute_workflow']} times. This indicates interceptor duplication in the chain." + ) diff --git a/tests/test_runtime.py b/tests/test_runtime.py index 609df3ff8..c29961c52 100644 --- a/tests/test_runtime.py +++ b/tests/test_runtime.py @@ -255,15 +255,15 @@ async def check_metrics() -> None: metrics_output = f.read().decode("utf-8") for key, buckets in histogram_overrides.items(): - assert ( - key in metrics_output - ), f"Missing {key} in full output: {metrics_output}" + assert key in metrics_output, ( + f"Missing {key} in full output: {metrics_output}" + ) for bucket in buckets: # expect to have {key}_bucket and le={bucket} in the same line with arbitrary strings between them regex = re.compile(f'{key}_bucket.*le="{bucket}"') - assert regex.search( - metrics_output - ), f"Missing bucket for {key} in full output: {metrics_output}" + assert regex.search(metrics_output), ( + f"Missing bucket for {key} in full output: {metrics_output}" + ) # Wait for metrics to appear and match the expected buckets await assert_eventually(check_metrics) diff --git a/tests/test_serialization_context.py b/tests/test_serialization_context.py index 8e8fcf048..0fde2aa96 100644 --- a/tests/test_serialization_context.py +++ b/tests/test_serialization_context.py @@ -1600,7 +1600,7 @@ def __init__(self): @workflow.run async def run(self, _data: str) -> str: await workflow.wait_condition( - lambda: (self.received_signal and self.received_update) + lambda: self.received_signal and self.received_update ) # Run them in parallel to check that data converter operations do not mix up contexts when # there are multiple concurrent payload types. diff --git a/tests/test_service.py b/tests/test_service.py index 9fdcd9fc7..11f993844 100644 --- a/tests/test_service.py +++ b/tests/test_service.py @@ -249,9 +249,9 @@ async def test_method( try: await rpc_call(request, timeout=timedelta(milliseconds=1)) except ValueError as err: - assert ( - "Unknown RPC call" not in str(err) - ), f"Unexpected unknown-RPC error for {target_service_name}.{method_name}: {err}" + assert "Unknown RPC call" not in str(err), ( + f"Unexpected unknown-RPC error for {target_service_name}.{method_name}: {err}" + ) except temporalio.service.RPCError: pass diff --git a/tests/test_workflow.py b/tests/test_workflow.py index 8bd06ad9b..5618c34f5 100644 --- a/tests/test_workflow.py +++ b/tests/test_workflow.py @@ -446,9 +446,9 @@ def test_parameters_identical_up_to_naming(): for f1, f2 in itertools.combinations(fns, 2): name1, name2 = f1.__name__, f2.__name__ expect_equal = name1[0] == name2[0] - assert ( - workflow._parameters_identical_up_to_naming(f1, f2) == (expect_equal) - ), f"expected {name1} and {name2} parameters{' ' if expect_equal else ' not '}to compare equal" + assert workflow._parameters_identical_up_to_naming(f1, f2) == (expect_equal), ( + f"expected {name1} and {name2} parameters{' ' if expect_equal else ' not '}to compare equal" + ) @workflow.defn diff --git a/tests/worker/test_command_aware_visitor.py b/tests/worker/test_command_aware_visitor.py index b8488689e..f354c8614 100644 --- a/tests/worker/test_command_aware_visitor.py +++ b/tests/worker/test_command_aware_visitor.py @@ -65,13 +65,13 @@ def test_command_aware_visitor_has_methods_for_all_seq_protos_with_payloads(): # Sanity check: we should have fewer overrides than total protos with seq # (because some don't have payloads) - assert len(commands_with_payloads) < len( - command_protos - ), "Should have some commands without payloads" + assert len(commands_with_payloads) < len(command_protos), ( + "Should have some commands without payloads" + ) # All activation jobs except FireTimer have payloads - assert ( - len(jobs_with_payloads) == len(job_protos) - 1 - ), "Should have exactly one activation job without payloads (FireTimer)" + assert len(jobs_with_payloads) == len(job_protos) - 1, ( + "Should have exactly one activation job without payloads (FireTimer)" + ) def _get_workflow_command_protos_with_seq() -> Iterator[type[Any]]: diff --git a/uv.lock b/uv.lock index c78ab196c..a0e471bbe 100644 --- a/uv.lock +++ b/uv.lock @@ -9,7 +9,7 @@ resolution-markers = [ ] [options] -exclude-newer = "2026-04-23T17:46:27.746666Z" +exclude-newer = "2026-05-01T16:01:05.963689Z" exclude-newer-span = "P1W" [options.exclude-newer-package] @@ -4931,27 +4931,27 @@ wheels = [ [[package]] name = "ruff" -version = "0.5.7" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/bf/2b/69e5e412f9d390adbdbcbf4f64d6914fa61b44b08839a6584655014fc524/ruff-0.5.7.tar.gz", hash = "sha256:8dfc0a458797f5d9fb622dd0efc52d796f23f0a1493a9527f4e49a550ae9a7e5", size = 2449817, upload-time = "2024-08-08T15:43:07.467Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/6b/eb/06e06aaf96af30a68e83b357b037008c54a2ddcbad4f989535007c700394/ruff-0.5.7-py3-none-linux_armv6l.whl", hash = "sha256:548992d342fc404ee2e15a242cdbea4f8e39a52f2e7752d0e4cbe88d2d2f416a", size = 9570571, upload-time = "2024-08-08T15:41:56.537Z" }, - { url = "https://files.pythonhosted.org/packages/a4/10/1be32aeaab8728f78f673e7a47dd813222364479b2d6573dbcf0085e83ea/ruff-0.5.7-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:00cc8872331055ee017c4f1071a8a31ca0809ccc0657da1d154a1d2abac5c0be", size = 8685138, upload-time = "2024-08-08T15:42:02.833Z" }, - { url = "https://files.pythonhosted.org/packages/3d/1d/c218ce83beb4394ba04d05e9aa2ae6ce9fba8405688fe878b0fdb40ce855/ruff-0.5.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:eaf3d86a1fdac1aec8a3417a63587d93f906c678bb9ed0b796da7b59c1114a1e", size = 8266785, upload-time = "2024-08-08T15:42:08.321Z" }, - { url = "https://files.pythonhosted.org/packages/26/79/7f49509bd844476235b40425756def366b227a9714191c91f02fb2178635/ruff-0.5.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a01c34400097b06cf8a6e61b35d6d456d5bd1ae6961542de18ec81eaf33b4cb8", size = 9983964, upload-time = "2024-08-08T15:42:12.419Z" }, - { url = "https://files.pythonhosted.org/packages/bf/b1/939836b70bf9fcd5e5cd3ea67fdb8abb9eac7631351d32f26544034a35e4/ruff-0.5.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fcc8054f1a717e2213500edaddcf1dbb0abad40d98e1bd9d0ad364f75c763eea", size = 9359490, upload-time = "2024-08-08T15:42:16.713Z" }, - { url = "https://files.pythonhosted.org/packages/32/7d/b3db19207de105daad0c8b704b2c6f2a011f9c07017bd58d8d6e7b8eba19/ruff-0.5.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7f70284e73f36558ef51602254451e50dd6cc479f8b6f8413a95fcb5db4a55fc", size = 10170833, upload-time = "2024-08-08T15:42:20.54Z" }, - { url = "https://files.pythonhosted.org/packages/a2/45/eae9da55f3357a1ac04220230b8b07800bf516e6dd7e1ad20a2ff3b03b1b/ruff-0.5.7-py3-none-manylinux_2_17_ppc64.manylinux2014_ppc64.whl", hash = "sha256:a78ad870ae3c460394fc95437d43deb5c04b5c29297815a2a1de028903f19692", size = 10896360, upload-time = "2024-08-08T15:42:25.2Z" }, - { url = "https://files.pythonhosted.org/packages/99/67/4388b36d145675f4c51ebec561fcd4298a0e2550c81e629116f83ce45a39/ruff-0.5.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9ccd078c66a8e419475174bfe60a69adb36ce04f8d4e91b006f1329d5cd44bcf", size = 10477094, upload-time = "2024-08-08T15:42:29.553Z" }, - { url = "https://files.pythonhosted.org/packages/e1/9c/f5e6ed1751dc187a4ecf19a4970dd30a521c0ee66b7941c16e292a4043fb/ruff-0.5.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7e31c9bad4ebf8fdb77b59cae75814440731060a09a0e0077d559a556453acbb", size = 11480896, upload-time = "2024-08-08T15:42:33.772Z" }, - { url = "https://files.pythonhosted.org/packages/c8/3b/2b683be597bbd02046678fc3fc1c199c641512b20212073b58f173822bb3/ruff-0.5.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d796327eed8e168164346b769dd9a27a70e0298d667b4ecee6877ce8095ec8e", size = 10179702, upload-time = "2024-08-08T15:42:38.038Z" }, - { url = "https://files.pythonhosted.org/packages/f1/38/c2d94054dc4b3d1ea4c2ba3439b2a7095f08d1c8184bc41e6abe2a688be7/ruff-0.5.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:4a09ea2c3f7778cc635e7f6edf57d566a8ee8f485f3c4454db7771efb692c499", size = 9982855, upload-time = "2024-08-08T15:42:42.031Z" }, - { url = "https://files.pythonhosted.org/packages/7d/e7/1433db2da505ffa8912dcf5b28a8743012ee780cbc20ad0bf114787385d9/ruff-0.5.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:a36d8dcf55b3a3bc353270d544fb170d75d2dff41eba5df57b4e0b67a95bb64e", size = 9433156, upload-time = "2024-08-08T15:42:45.339Z" }, - { url = "https://files.pythonhosted.org/packages/e0/36/4fa43250e67741edeea3d366f59a1dc993d4d89ad493a36cbaa9889895f2/ruff-0.5.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:9369c218f789eefbd1b8d82a8cf25017b523ac47d96b2f531eba73770971c9e5", size = 9782971, upload-time = "2024-08-08T15:42:49.354Z" }, - { url = "https://files.pythonhosted.org/packages/80/0e/8c276103d518e5cf9202f70630aaa494abf6fc71c04d87c08b6d3cd07a4b/ruff-0.5.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:b88ca3db7eb377eb24fb7c82840546fb7acef75af4a74bd36e9ceb37a890257e", size = 10247775, upload-time = "2024-08-08T15:42:53.294Z" }, - { url = "https://files.pythonhosted.org/packages/cb/b9/673096d61276f39291b729dddde23c831a5833d98048349835782688a0ec/ruff-0.5.7-py3-none-win32.whl", hash = "sha256:33d61fc0e902198a3e55719f4be6b375b28f860b09c281e4bdbf783c0566576a", size = 7841772, upload-time = "2024-08-08T15:42:57.488Z" }, - { url = "https://files.pythonhosted.org/packages/67/1c/4520c98bfc06b9c73cd1457686d4d3935d40046b1ddea08403e5a6deff51/ruff-0.5.7-py3-none-win_amd64.whl", hash = "sha256:083bbcbe6fadb93cd86709037acc510f86eed5a314203079df174c40bbbca6b3", size = 8699779, upload-time = "2024-08-08T15:43:00.429Z" }, - { url = "https://files.pythonhosted.org/packages/38/23/b3763a237d2523d40a31fe2d1a301191fe392dd48d3014977d079cf8c0bd/ruff-0.5.7-py3-none-win_arm64.whl", hash = "sha256:2dca26154ff9571995107221d0aeaad0e75a77b5a682d6236cf89a58c70b76f4", size = 8091891, upload-time = "2024-08-08T15:43:04.162Z" }, +version = "0.15.12" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/99/43/3291f1cc9106f4c63bdce7a8d0df5047fe8422a75b091c16b5e9355e0b11/ruff-0.15.12.tar.gz", hash = "sha256:ecea26adb26b4232c0c2ca19ccbc0083a68344180bba2a600605538ce51a40a6", size = 4643852, upload-time = "2026-04-24T18:17:14.305Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c3/6e/e78ffb61d4686f3d96ba3df2c801161843746dcbcbb17a1e927d4829312b/ruff-0.15.12-py3-none-linux_armv6l.whl", hash = "sha256:f86f176e188e94d6bdbc09f09bfd9dc729059ad93d0e7390b5a73efe19f8861c", size = 10640713, upload-time = "2026-04-24T18:17:22.841Z" }, + { url = "https://files.pythonhosted.org/packages/ae/08/a317bc231fb9e7b93e4ef3089501e51922ff88d6936ce5cf870c4fe55419/ruff-0.15.12-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:e3bcd123364c3770b8e1b7baaf343cc99a35f197c5c6e8af79015c666c423a6c", size = 11069267, upload-time = "2026-04-24T18:17:30.105Z" }, + { url = "https://files.pythonhosted.org/packages/aa/a4/f828e9718d3dce1f5f11c39c4f65afd32783c8b2aebb2e3d259e492c47bd/ruff-0.15.12-py3-none-macosx_11_0_arm64.whl", hash = "sha256:fe87510d000220aa1ed530d4448a7c696a0cae1213e5ec30e5874287b66557b5", size = 10397182, upload-time = "2026-04-24T18:17:07.177Z" }, + { url = "https://files.pythonhosted.org/packages/71/e0/3310fc6d1b5e1fdea22bf3b1b807c7e187b581021b0d7d4514cccdb5fb71/ruff-0.15.12-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:84a1630093121375a3e2a95b4a6dc7b59e2b4ee76216e32d81aae550a832d002", size = 10758012, upload-time = "2026-04-24T18:16:55.759Z" }, + { url = "https://files.pythonhosted.org/packages/11/c1/a606911aee04c324ddaa883ae418f3569792fd3c4a10c50e0dd0a2311e1e/ruff-0.15.12-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:fb129f40f114f089ebe0ca56c0d251cf2061b17651d464bb6478dc01e69f11f5", size = 10447479, upload-time = "2026-04-24T18:16:51.677Z" }, + { url = "https://files.pythonhosted.org/packages/9d/68/4201e8444f0894f21ab4aeeaee68aa4f10b51613514a20d80bd628d57e88/ruff-0.15.12-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b0c862b172d695db7598426b8af465e7e9ac00a3ea2a3630ee67eb82e366aaa6", size = 11234040, upload-time = "2026-04-24T18:17:16.529Z" }, + { url = "https://files.pythonhosted.org/packages/34/ff/8a6d6cf4ccc23fd67060874e832c18919d1557a0611ebef03fdb01fff11e/ruff-0.15.12-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2849ea9f3484c3aca43a82f484210370319e7170df4dfe4843395ddf6c57bc33", size = 12087377, upload-time = "2026-04-24T18:17:04.944Z" }, + { url = "https://files.pythonhosted.org/packages/85/f6/c669cf73f5152f623d34e69866a46d5e6185816b19fcd5b6dd8a2d299922/ruff-0.15.12-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9e77c7e51c07fe396826d5969a5b846d9cd4c402535835fb6e21ce8b28fef847", size = 11367784, upload-time = "2026-04-24T18:17:25.409Z" }, + { url = "https://files.pythonhosted.org/packages/e8/39/c61d193b8a1daaa8977f7dea9e8d8ba866e02ea7b65d32f6861693aa4c12/ruff-0.15.12-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:83b2f4f2f3b1026b5fb449b467d9264bf22067b600f7b6f41fc5958909f449d0", size = 11344088, upload-time = "2026-04-24T18:17:12.258Z" }, + { url = "https://files.pythonhosted.org/packages/c2/8d/49afab3645e31e12c590acb6d3b5b69d7aab5b81926dbaf7461f9441f37a/ruff-0.15.12-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:9ba3b8f1afd7e2e43d8943e55f249e13f9682fde09711644a6e7290eb4f3e339", size = 11271770, upload-time = "2026-04-24T18:17:02.457Z" }, + { url = "https://files.pythonhosted.org/packages/46/06/33f41fe94403e2b755481cdfb9b7ef3e4e0ed031c4581124658d935d52b4/ruff-0.15.12-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:e852ba9fdc890655e1d78f2df1499efbe0e54126bd405362154a75e2bde159c5", size = 10719355, upload-time = "2026-04-24T18:17:27.648Z" }, + { url = "https://files.pythonhosted.org/packages/0d/59/18aa4e014debbf559670e4048e39260a85c7fcee84acfd761ac01e7b8d35/ruff-0.15.12-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:dd8aed930da53780d22fc70bdf84452c843cf64f8cb4eb38984319c24c5cd5fd", size = 10462758, upload-time = "2026-04-24T18:17:32.347Z" }, + { url = "https://files.pythonhosted.org/packages/25/e7/cc9f16fd0f3b5fddcbd7ec3d6ae30c8f3fde1047f32a4093a98d633c6570/ruff-0.15.12-py3-none-musllinux_1_2_i686.whl", hash = "sha256:01da3988d225628b709493d7dc67c3b9b12c0210016b08690ef9bd27970b262b", size = 10953498, upload-time = "2026-04-24T18:17:20.674Z" }, + { url = "https://files.pythonhosted.org/packages/72/7a/a9ba7f98c7a575978698f4230c5e8cc54bbc761af34f560818f933dafa0c/ruff-0.15.12-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:9cae0f92bd5700d1213188b31cd3bdd2b315361296d10b96b8e2337d3d11f53e", size = 11447765, upload-time = "2026-04-24T18:17:09.755Z" }, + { url = "https://files.pythonhosted.org/packages/ea/f9/0ae446942c846b8266059ad8a30702a35afae55f5cdc54c5adf8d7afdc27/ruff-0.15.12-py3-none-win32.whl", hash = "sha256:d0185894e038d7043ba8fd6aee7499ece6462dc0ea9f1e260c7451807c714c20", size = 10657277, upload-time = "2026-04-24T18:17:18.591Z" }, + { url = "https://files.pythonhosted.org/packages/33/f1/9614e03e1cdcbf9437570b5400ced8a720b5db22b28d8e0f1bda429f660d/ruff-0.15.12-py3-none-win_amd64.whl", hash = "sha256:c87a162d61ab3adca47c03f7f717c68672edec7d1b5499e652331780fe74950d", size = 11837758, upload-time = "2026-04-24T18:17:00.113Z" }, + { url = "https://files.pythonhosted.org/packages/c0/98/6beb4b351e472e5f4c4613f7c35a5290b8be2497e183825310c4c3a3984b/ruff-0.15.12-py3-none-win_arm64.whl", hash = "sha256:a538f7a82d061cee7be55542aca1d86d1393d55d81d4fcc314370f4340930d4f", size = 11120821, upload-time = "2026-04-24T18:16:57.979Z" }, ] [[package]] @@ -5291,7 +5291,7 @@ dev = [ { name = "pytest-rerunfailures", specifier = ">=16.1" }, { name = "pytest-timeout", specifier = "~=2.2" }, { name = "pytest-xdist", specifier = ">=3.6,<4" }, - { name = "ruff", specifier = ">=0.5.0,<0.6" }, + { name = "ruff", specifier = ">=0.15.12,<0.16" }, { name = "setuptools", specifier = "<82" }, { name = "toml", specifier = ">=0.10.2,<0.11" }, { name = "twine", specifier = ">=4.0.1,<5" }, From cc4319756074d5dc8e7fa1321c8d2f591ef061da Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Tue, 12 May 2026 12:46:07 -0700 Subject: [PATCH 02/46] contrib/strands: add Strands Agents plugin --- pyproject.toml | 5 + temporalio/contrib/strands/README.md | 8 + temporalio/contrib/strands/__init__.py | 5 + temporalio/contrib/strands/_plugin.py | 178 +++++++++++ tests/contrib/strands/mock_model.py | 60 ++++ .../contrib/strands/test_structured_output.py | 66 +++++ tests/contrib/strands/test_tool.py | 90 ++++++ uv.lock | 279 +++++++++++++++++- 8 files changed, 686 insertions(+), 5 deletions(-) create mode 100644 temporalio/contrib/strands/README.md create mode 100644 temporalio/contrib/strands/__init__.py create mode 100644 temporalio/contrib/strands/_plugin.py create mode 100644 tests/contrib/strands/mock_model.py create mode 100644 tests/contrib/strands/test_structured_output.py create mode 100644 tests/contrib/strands/test_tool.py diff --git a/pyproject.toml b/pyproject.toml index b74d2a7ad..50216130c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -43,6 +43,9 @@ aioboto3 = [ "aioboto3>=10.4.0", "types-aioboto3[s3]>=10.4.0", ] +strands = [ + "strands-agents>=1.38.0", +] [project.urls] Homepage = "https://github.com/temporalio/sdk-python" @@ -87,6 +90,8 @@ dev = [ "opentelemetry-semantic-conventions>=0.40b0,<1", "opentelemetry-sdk-extension-aws>=2.0.0,<3", "async-timeout>=4.0,<6; python_version < '3.11'", + "strands-agents>=1.38.0", + "strands-agents-tools>=0.5.2", ] [tool.poe.tasks] diff --git a/temporalio/contrib/strands/README.md b/temporalio/contrib/strands/README.md new file mode 100644 index 000000000..c195a08ce --- /dev/null +++ b/temporalio/contrib/strands/README.md @@ -0,0 +1,8 @@ +# AWS Strands Agents + +# Unsupported features +* File-based tool lookup + +# Migration Guide +* Mark non-deterministic tools with `as_activity()` and include in `StrandsAgent(activity_tools=[...])` +* Use `agent.invoke_async(message)` instead of `agent(message)` which spawns a thread diff --git a/temporalio/contrib/strands/__init__.py b/temporalio/contrib/strands/__init__.py new file mode 100644 index 000000000..ab7f96000 --- /dev/null +++ b/temporalio/contrib/strands/__init__.py @@ -0,0 +1,5 @@ +"""Temporal integration for the Strands Agents SDK.""" + +from ._plugin import StrandsPlugin, as_activity + +__all__ = ["StrandsPlugin", "as_activity"] diff --git a/temporalio/contrib/strands/_plugin.py b/temporalio/contrib/strands/_plugin.py new file mode 100644 index 000000000..aafaaadb7 --- /dev/null +++ b/temporalio/contrib/strands/_plugin.py @@ -0,0 +1,178 @@ +import asyncio +import functools +import inspect +from dataclasses import replace +from datetime import timedelta +from types import ModuleType +from typing import Any, Callable + +from strands import tool as strands_tool +from strands.tools import ToolProvider +from strands.tools.decorator import DecoratedFunctionTool + +from temporalio import activity, workflow +from temporalio.common import Priority, RetryPolicy +from temporalio.contrib.pydantic import pydantic_data_converter +from temporalio.converter import DataConverter, DefaultPayloadConverter +from temporalio.plugin import SimplePlugin +from temporalio.worker import WorkflowRunner +from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner +from temporalio.workflow import ActivityCancellationType, VersioningIntent + + +class StrandsPlugin(SimplePlugin): + """Temporal Worker plugin for the Strands Agents SDK. + + Marks ``strands`` as a sandbox passthrough module and registers any + nondeterministic tools as Temporal activities so they can be invoked + from inside workflows. + """ + + def __init__( + self, + activity_tools: list[ToolProvider | Any] | None = None, + ): + """Initialize the plugin. + + Args: + activity_tools: Tools to register as Temporal activities. Accepts + the same forms as ``strands.Agent(tools=...)``: plain + functions, ``@tool``-decorated functions, or imported + ``strands_tools`` submodules. + """ + super().__init__( + "aws.StrandsPlugin", + activities=[_build_activity(tool) for tool in activity_tools or []], + workflow_runner=_workflow_runner, + data_converter=_data_converter, + ) + + +def as_activity( + tool: ToolProvider | Any, + *, + task_queue: str | None = None, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + retry_policy: RetryPolicy | None = None, + cancellation_type: ActivityCancellationType = ActivityCancellationType.TRY_CANCEL, + activity_id: str | None = None, + versioning_intent: VersioningIntent | None = None, + summary: str | None = None, + priority: Priority = Priority.default, +) -> DecoratedFunctionTool: + """Wrap a Strands tool to dispatch through a Temporal activity. + + Accepts the same forms as ``strands.Agent(tools=...)``. The returned + tool has the same name and schema as the original, but its body calls + ``workflow.execute_activity`` instead of running the function inline. + The activity itself must be registered on the worker via + ``StrandsPlugin(activity_tools=[fn])``. + + All keyword arguments are forwarded to ``workflow.execute_activity``; + refer to its documentation for details. + """ + fn = _unwrap_tool(tool) + sig = inspect.signature(fn) + activity_name = fn.__name__ + options: dict[str, Any] = { + "task_queue": task_queue, + "schedule_to_close_timeout": schedule_to_close_timeout, + "schedule_to_start_timeout": schedule_to_start_timeout, + "start_to_close_timeout": start_to_close_timeout, + "heartbeat_timeout": heartbeat_timeout, + "retry_policy": retry_policy, + "cancellation_type": cancellation_type, + "activity_id": activity_id, + "versioning_intent": versioning_intent, + "summary": summary, + "priority": priority, + } + + @functools.wraps(fn) + async def proxy(*args: Any, **kwargs: Any) -> Any: + bound = sig.bind(*args, **kwargs) + bound.apply_defaults() + positional = list(bound.arguments.values()) + if not positional: + return await workflow.execute_activity(activity_name, **options) + if len(positional) == 1: + return await workflow.execute_activity( + activity_name, positional[0], **options + ) + return await workflow.execute_activity( + activity_name, args=positional, **options + ) + + proxy.__signature__ = sig # type: ignore[attr-defined] + return strands_tool(proxy) + + +def _build_activity(tool: ToolProvider | Any) -> Callable: + fn = _unwrap_tool(tool) + return activity.defn(name=fn.__name__)(_ensure_async(fn)) + + +def _unwrap_tool(tool: ToolProvider | Any) -> Callable: + if isinstance(tool, ModuleType): + name = tool.__name__.rsplit(".", 1)[-1] + fn = getattr(tool, name) + else: + fn = tool + if isinstance(fn, DecoratedFunctionTool): + return fn._tool_func + if not callable(fn): + raise TypeError(f"Cannot wrap {fn!r} as a Temporal activity") + return fn + + +def _ensure_async(fn: Callable) -> Callable: + """Return an async-compatible version of ``fn``. + + Temporal's Worker rejects sync activity functions unless an + ``activity_executor`` is configured. Most ``strands_tools`` functions + (e.g. ``current_time``, ``calculator``) are sync, so we wrap them in an + ``asyncio.to_thread`` call instead of requiring the user to wire up an + executor on the Worker. + """ + if inspect.iscoroutinefunction(fn): + return fn + + @functools.wraps(fn) + async def wrapper(*args: Any, **kwargs: Any) -> Any: + return await asyncio.to_thread(fn, *args, **kwargs) + + return wrapper + + +def _workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: + """Add ``strands`` and ``strands_tools`` to the sandbox passthrough list.""" + if not runner: + raise ValueError("No WorkflowRunner provided to the Strands plugin.") + if isinstance(runner, SandboxedWorkflowRunner): + return replace( + runner, + restrictions=runner.restrictions.with_passthrough_modules( + "strands", + "strands_tools", + ), + ) + return runner + + +def _data_converter(converter: DataConverter | None) -> DataConverter: + """Default to ``pydantic_data_converter`` when the user hasn't overridden. + + Strands optionally surfaces pydantic ``BaseModel`` values (e.g. + ``result.structured_output``) that the default Temporal converter can't + serialize. ``pydantic_data_converter`` is a strict superset of the + default, so this is safe even for workflows that never touch pydantic. + """ + if ( + converter is None + or converter.payload_converter_class is DefaultPayloadConverter + ): + return pydantic_data_converter + return converter diff --git a/tests/contrib/strands/mock_model.py b/tests/contrib/strands/mock_model.py new file mode 100644 index 000000000..5cbb0f89f --- /dev/null +++ b/tests/contrib/strands/mock_model.py @@ -0,0 +1,60 @@ +from __future__ import annotations + +import json +from collections.abc import AsyncIterable +from typing import Any + +from strands.models import Model +from strands.types.streaming import StreamEvent + + +class MockModel(Model): + """Scripted Strands ``Model`` for tests. + + Each entry in ``responses`` is consumed by one ``stream()`` call. A ``str`` + yields a text turn; a ``dict`` of ``{name, input}`` yields a tool-use turn. + """ + + def __init__(self, responses: list[str | dict[str, Any]]) -> None: + self._responses = list(responses) + self._tool_call_index = 0 + + def update_config(self, **_model_config: Any) -> None: + return None + + def get_config(self) -> dict[str, Any]: + return {} + + def structured_output(self, *_args: Any, **_kwargs: Any): + raise NotImplementedError + + async def stream(self, *_args: Any, **_kwargs: Any) -> AsyncIterable[StreamEvent]: + if not self._responses: + raise AssertionError("MockModel script exhausted") + response = self._responses.pop(0) + + yield {"messageStart": {"role": "assistant"}} + + if isinstance(response, str): + yield {"contentBlockDelta": {"delta": {"text": response}}} + yield {"contentBlockStop": {}} + yield {"messageStop": {"stopReason": "end_turn"}} + else: + self._tool_call_index += 1 + yield { + "contentBlockStart": { + "start": { + "toolUse": { + "name": response["name"], + "toolUseId": f"mock-tool-{self._tool_call_index}", + }, + }, + }, + } + yield { + "contentBlockDelta": { + "delta": {"toolUse": {"input": json.dumps(response["input"])}}, + }, + } + yield {"contentBlockStop": {}} + yield {"messageStop": {"stopReason": "tool_use"}} diff --git a/tests/contrib/strands/test_structured_output.py b/tests/contrib/strands/test_structured_output.py new file mode 100644 index 000000000..29663e7dc --- /dev/null +++ b/tests/contrib/strands/test_structured_output.py @@ -0,0 +1,66 @@ +from uuid import uuid4 + +from pydantic import BaseModel, Field +from strands import Agent + +from temporalio import workflow +from temporalio.client import Client +from temporalio.contrib.strands import StrandsPlugin +from temporalio.worker import Replayer, Worker +from tests.contrib.strands.mock_model import MockModel + + +class PersonInfo(BaseModel): + name: str = Field(description="Name of the person") + age: int = Field(description="Age of the person") + occupation: str = Field(description="Occupation of the person") + + +@workflow.defn +class StructuredOutputWorkflow: + def __init__(self) -> None: + model = MockModel( + [ + { + "name": "PersonInfo", + "input": { + "name": "John Smith", + "age": 30, + "occupation": "software engineer", + }, + }, + ] + ) + self.agent = Agent(model=model, structured_output_model=PersonInfo) + + @workflow.run + async def run(self, prompt: str) -> PersonInfo: + result = await self.agent.invoke_async(prompt) + assert isinstance(result.structured_output, PersonInfo) + return result.structured_output + + +async def test_structured_output(client: Client): + task_queue = "test_structured_output" + plugin = StrandsPlugin() + + async with Worker( + client, + task_queue=task_queue, + workflows=[StructuredOutputWorkflow], + plugins=[plugin], + ): + handle = await client.start_workflow( + StructuredOutputWorkflow.run, + "John Smith is a 30 year-old software engineer", + id=f"test_structured_output_{uuid4()}", + task_queue=task_queue, + ) + assert await handle.result() == PersonInfo( + name="John Smith", age=30, occupation="software engineer" + ) + + await Replayer( + workflows=[StructuredOutputWorkflow], + plugins=[plugin], + ).replay_workflow(await handle.fetch_history()) diff --git a/tests/contrib/strands/test_tool.py b/tests/contrib/strands/test_tool.py new file mode 100644 index 000000000..28b37f215 --- /dev/null +++ b/tests/contrib/strands/test_tool.py @@ -0,0 +1,90 @@ +from datetime import timedelta +from uuid import uuid4 + +from strands import Agent, tool +from strands_tools import calculator, current_time + +from temporalio import workflow +from temporalio.api.enums.v1 import EventType +from temporalio.client import Client, WorkflowHistory +from temporalio.contrib.strands import StrandsPlugin, as_activity +from temporalio.worker import Replayer, Worker +from tests.contrib.strands.mock_model import MockModel + + +@tool +def letter_counter(word: str, letter: str) -> int: + return word.lower().count(letter.lower()) + + +@workflow.defn +class ToolWorkflow: + def __init__(self) -> None: + model = MockModel( + [ + {"name": "current_time", "input": {}}, + { + "name": "calculator", + "input": {"expression": "3111696 / 74088"}, + }, + { + "name": "letter_counter", + "input": {"word": "strawberry", "letter": "R"}, + }, + "Done!", + ] + ) + self.agent = Agent( + model=model, + tools=[ + calculator, + as_activity( + current_time, + start_to_close_timeout=timedelta(seconds=15), + ), + letter_counter, + ], + ) + + @workflow.run + async def run(self, prompt: str) -> str: + result = await self.agent.invoke_async(prompt) + return str(result) + + +async def test_tool(client: Client): + task_queue = "test_tool" + plugin = StrandsPlugin(activity_tools=[current_time]) + + async with Worker( + client, + task_queue=task_queue, + workflows=[ToolWorkflow], + plugins=[plugin], + ): + handle = await client.start_workflow( + ToolWorkflow.run, + "I have 4 requests:\n" + "1. What is the time right now?\n" + "2. Calculate 3111696 / 74088\n" + '3. Tell me how many letter R\'s are in the word "strawberry" 🍓', + id=f"test_tool_{uuid4()}", + task_queue=task_queue, + ) + assert await handle.result() == "Done!\n" + + history = await handle.fetch_history() + assert get_activities(history) == ["current_time"] + + await Replayer( + workflows=[ToolWorkflow], + plugins=[plugin], + ).replay_workflow(history) + + +def get_activities(history: WorkflowHistory) -> list[str]: + return [ + event.activity_task_scheduled_event_attributes.activity_type.name + for event in history.events + if event.event_type == EventType.EVENT_TYPE_ACTIVITY_TASK_SCHEDULED + ] diff --git a/uv.lock b/uv.lock index a0e471bbe..e051b4bd5 100644 --- a/uv.lock +++ b/uv.lock @@ -9,7 +9,7 @@ resolution-markers = [ ] [options] -exclude-newer = "2026-05-01T16:01:05.963689Z" +exclude-newer = "0001-01-01T00:00:00Z" # This has no effect and is included for backwards compatibility when using relative exclude-newer values. exclude-newer-span = "P1W" [options.exclude-newer-package] @@ -316,6 +316,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/02/ff/1175b0b7371e46244032d43a56862d0af455823b5280a50c63d99cc50f18/automat-25.4.16-py3-none-any.whl", hash = "sha256:04e9bce696a8d5671ee698005af6e5a9fa15354140a87f4870744604dcdd3ba1", size = 42842, upload-time = "2025-04-16T20:12:14.447Z" }, ] +[[package]] +name = "aws-requests-auth" +version = "0.4.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "requests" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/54/b2/455c0bfcbd772dafd4c9e93c4b713e36790abf9ccbca9b8e661968b29798/aws-requests-auth-0.4.3.tar.gz", hash = "sha256:33593372018b960a31dbbe236f89421678b885c35f0b6a7abfae35bb77e069b2", size = 10096, upload-time = "2020-05-27T23:10:34.742Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/af/11/5dc8be418e1d54bed15eaf3a7461797e5ebb9e6a34869ad750561f35fa5b/aws_requests_auth-0.4.3-py2.py3-none-any.whl", hash = "sha256:646bc37d62140ea1c709d20148f5d43197e6bd2d63909eb36fa4bb2345759977", size = 6838, upload-time = "2020-05-27T23:10:33.658Z" }, +] + [[package]] name = "aws-sam-translator" version = "1.106.0" @@ -374,6 +386,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f4/be/6985abb1011fda8a523cfe21ed9629e397d6e06fb5bae99750402b25c95b/bashlex-0.18-py2.py3-none-any.whl", hash = "sha256:91d73a23a3e51711919c1c899083890cdecffc91d8c088942725ac13e9dcfffa", size = 69539, upload-time = "2023-01-18T15:21:24.167Z" }, ] +[[package]] +name = "beautifulsoup4" +version = "4.14.3" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "soupsieve" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c3/b0/1c6a16426d389813b48d95e26898aff79abbde42ad353958ad95cc8c9b21/beautifulsoup4-4.14.3.tar.gz", hash = "sha256:6292b1c5186d356bba669ef9f7f051757099565ad9ada5dd630bd9de5fa7fb86", size = 627737, upload-time = "2025-11-30T15:08:26.084Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1a/39/47f9197bdd44df24d67ac8893641e16f386c984a0619ef2ee4c51fbbc019/beautifulsoup4-4.14.3-py3-none-any.whl", hash = "sha256:0918bfe44902e6ad8d57732ba310582e98da931428d231a5ecb9e7c703a735bb", size = 107721, upload-time = "2025-11-30T15:08:24.087Z" }, +] + [[package]] name = "blinker" version = "1.9.0" @@ -924,6 +949,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/99/c7/d1ec24fb280caa5a79b6b950db565dab30210a66259d17d5bb2b3a9f878d/dependency_groups-1.3.1-py3-none-any.whl", hash = "sha256:51aeaa0dfad72430fcfb7bcdbefbd75f3792e5919563077f30bc0d73f4493030", size = 8664, upload-time = "2025-05-02T00:34:27.085Z" }, ] +[[package]] +name = "dill" +version = "0.4.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/81/e1/56027a71e31b02ddc53c7d65b01e68edf64dea2932122fe7746a516f75d5/dill-0.4.1.tar.gz", hash = "sha256:423092df4182177d4d8ba8290c8a5b640c66ab35ec7da59ccfa00f6fa3eea5fa", size = 187315, upload-time = "2026-01-19T02:36:56.85Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/1e/77/dc8c558f7593132cf8fefec57c4f60c83b16941c574ac5f619abb3ae7933/dill-0.4.1-py3-none-any.whl", hash = "sha256:1e1ce33e978ae97fcfcff5638477032b801c46c7c65cf717f95fbc2248f79a9d", size = 120019, upload-time = "2026-01-19T02:36:55.663Z" }, +] + [[package]] name = "distro" version = "1.9.0" @@ -2678,6 +2712,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/94/54/e7d793b573f298e1c9013b8c4dade17d481164aa517d1d7148619c2cedbf/markdown_it_py-4.0.0-py3-none-any.whl", hash = "sha256:87327c59b172c5011896038353a81343b6754500a08cd7a4973bb48c6d578147", size = 87321, upload-time = "2025-08-11T12:57:51.923Z" }, ] +[[package]] +name = "markdownify" +version = "1.2.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "beautifulsoup4" }, + { name = "six" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3f/bc/c8c8eea5335341306b0fa7e1cb33c5e1c8d24ef70ddd684da65f41c49c92/markdownify-1.2.2.tar.gz", hash = "sha256:b274f1b5943180b031b699b199cbaeb1e2ac938b75851849a31fd0c3d6603d09", size = 18816, upload-time = "2025-11-16T19:21:18.565Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/43/ce/f1e3e9d959db134cedf06825fae8d5b294bd368aacdd0831a3975b7c4d55/markdownify-1.2.2-py3-none-any.whl", hash = "sha256:3f02d3cc52714084d6e589f70397b6fc9f2f3a8531481bf35e8cc39f975e186a", size = 15724, upload-time = "2025-11-16T19:21:17.622Z" }, +] + [[package]] name = "markupsafe" version = "3.0.3" @@ -3616,6 +3663,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/10/f5/7a40ff3f62bfe715dad2f633d7f1174ba1a7dd74254c15b2558b3401262a/opentelemetry_instrumentation-0.59b0-py3-none-any.whl", hash = "sha256:44082cc8fe56b0186e87ee8f7c17c327c4c2ce93bdbe86496e600985d74368ee", size = 33020, upload-time = "2025-10-16T08:38:31.463Z" }, ] +[[package]] +name = "opentelemetry-instrumentation-threading" +version = "0.59b0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "opentelemetry-api" }, + { name = "opentelemetry-instrumentation" }, + { name = "wrapt" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/82/7a/84e97d8992808197006e607ae410c2219bdbbc23d1289ba0c244d3220741/opentelemetry_instrumentation_threading-0.59b0.tar.gz", hash = "sha256:ce5658730b697dcbc0e0d6d13643a69fd8aeb1b32fa8db3bade8ce114c7975f3", size = 8770, upload-time = "2025-10-16T08:40:03.587Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/b8/50/32d29076aaa1c91983cdd3ca8c6bb4d344830cd7d87a7c0fdc2d98c58509/opentelemetry_instrumentation_threading-0.59b0-py3-none-any.whl", hash = "sha256:76da2fc01fe1dccebff6581080cff9e42ac7b27cc61eb563f3c4435c727e8eca", size = 9313, upload-time = "2025-10-16T08:39:15.876Z" }, +] + [[package]] name = "opentelemetry-proto" version = "1.38.0" @@ -3846,6 +3907,104 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/fa/c9/8eed0486f074e9f1ca7f8ce5ad663e65f12fdab344028d658fa1b03d35e0/pathspec-1.1.0-py3-none-any.whl", hash = "sha256:574b128f7456bd899045ccd142dd446af7e6cfd0072d63ad73fbc55fbb4aaa42", size = 56264, upload-time = "2026-04-23T01:46:20.606Z" }, ] +[[package]] +name = "pillow" +version = "12.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8c/21/c2bcdd5906101a30244eaffc1b6e6ce71a31bd0742a01eb89e660ebfac2d/pillow-12.2.0.tar.gz", hash = "sha256:a830b1a40919539d07806aa58e1b114df53ddd43213d9c8b75847eee6c0182b5", size = 46987819, upload-time = "2026-04-01T14:46:17.687Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3a/aa/d0b28e1c811cd4d5f5c2bfe2e022292bd255ae5744a3b9ac7d6c8f72dd75/pillow-12.2.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:a4e8f36e677d3336f35089648c8955c51c6d386a13cf6ee9c189c5f5bd713a9f", size = 5354355, upload-time = "2026-04-01T14:42:15.402Z" }, + { url = "https://files.pythonhosted.org/packages/27/8e/1d5b39b8ae2bd7650d0c7b6abb9602d16043ead9ebbfef4bc4047454da2a/pillow-12.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2e589959f10d9824d39b350472b92f0ce3b443c0a3442ebf41c40cb8361c5b97", size = 4695871, upload-time = "2026-04-01T14:42:18.234Z" }, + { url = "https://files.pythonhosted.org/packages/f0/c5/dcb7a6ca6b7d3be41a76958e90018d56c8462166b3ef223150360850c8da/pillow-12.2.0-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:a52edc8bfff4429aaabdf4d9ee0daadbbf8562364f940937b941f87a4290f5ff", size = 6269734, upload-time = "2026-04-01T14:42:20.608Z" }, + { url = "https://files.pythonhosted.org/packages/ea/f1/aa1bb13b2f4eba914e9637893c73f2af8e48d7d4023b9d3750d4c5eb2d0c/pillow-12.2.0-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:975385f4776fafde056abb318f612ef6285b10a1f12b8570f3647ad0d74b48ec", size = 8076080, upload-time = "2026-04-01T14:42:23.095Z" }, + { url = "https://files.pythonhosted.org/packages/a1/2a/8c79d6a53169937784604a8ae8d77e45888c41537f7f6f65ed1f407fe66d/pillow-12.2.0-cp310-cp310-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:bd9c0c7a0c681a347b3194c500cb1e6ca9cab053ea4d82a5cf45b6b754560136", size = 6382236, upload-time = "2026-04-01T14:42:25.82Z" }, + { url = "https://files.pythonhosted.org/packages/b5/42/bbcb6051030e1e421d103ce7a8ecadf837aa2f39b8f82ef1a8d37c3d4ebc/pillow-12.2.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:88d387ff40b3ff7c274947ed3125dedf5262ec6919d83946753b5f3d7c67ea4c", size = 7070220, upload-time = "2026-04-01T14:42:28.68Z" }, + { url = "https://files.pythonhosted.org/packages/3f/e1/c2a7d6dd8cfa6b231227da096fd2d58754bab3603b9d73bf609d3c18b64f/pillow-12.2.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:51c4167c34b0d8ba05b547a3bb23578d0ba17b80a5593f93bd8ecb123dd336a3", size = 6493124, upload-time = "2026-04-01T14:42:31.579Z" }, + { url = "https://files.pythonhosted.org/packages/5f/41/7c8617da5d32e1d2f026e509484fdb6f3ad7efaef1749a0c1928adbb099e/pillow-12.2.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:34c0d99ecccea270c04882cb3b86e7b57296079c9a4aff88cb3b33563d95afaa", size = 7194324, upload-time = "2026-04-01T14:42:34.615Z" }, + { url = "https://files.pythonhosted.org/packages/2d/de/a777627e19fd6d62f84070ee1521adde5eeda4855b5cf60fe0b149118bca/pillow-12.2.0-cp310-cp310-win32.whl", hash = "sha256:b85f66ae9eb53e860a873b858b789217ba505e5e405a24b85c0464822fe88032", size = 6376363, upload-time = "2026-04-01T14:42:37.19Z" }, + { url = "https://files.pythonhosted.org/packages/e7/34/fc4cb5204896465842767b96d250c08410f01f2f28afc43b257de842eed5/pillow-12.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:673aa32138f3e7531ccdbca7b3901dba9b70940a19ccecc6a37c77d5fdeb05b5", size = 7083523, upload-time = "2026-04-01T14:42:39.62Z" }, + { url = "https://files.pythonhosted.org/packages/2d/a0/32852d36bc7709f14dc3f64f929a275e958ad8c19a6deba9610d458e28b3/pillow-12.2.0-cp310-cp310-win_arm64.whl", hash = "sha256:3e080565d8d7c671db5802eedfb438e5565ffa40115216eabb8cd52d0ecce024", size = 2463318, upload-time = "2026-04-01T14:42:42.063Z" }, + { url = "https://files.pythonhosted.org/packages/68/e1/748f5663efe6edcfc4e74b2b93edfb9b8b99b67f21a854c3ae416500a2d9/pillow-12.2.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:8be29e59487a79f173507c30ddf57e733a357f67881430449bb32614075a40ab", size = 5354347, upload-time = "2026-04-01T14:42:44.255Z" }, + { url = "https://files.pythonhosted.org/packages/47/a1/d5ff69e747374c33a3b53b9f98cca7889fce1fd03d79cdc4e1bccc6c5a87/pillow-12.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:71cde9a1e1551df7d34a25462fc60325e8a11a82cc2e2f54578e5e9a1e153d65", size = 4695873, upload-time = "2026-04-01T14:42:46.452Z" }, + { url = "https://files.pythonhosted.org/packages/df/21/e3fbdf54408a973c7f7f89a23b2cb97a7ef30c61ab4142af31eee6aebc88/pillow-12.2.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f490f9368b6fc026f021db16d7ec2fbf7d89e2edb42e8ec09d2c60505f5729c7", size = 6280168, upload-time = "2026-04-01T14:42:49.228Z" }, + { url = "https://files.pythonhosted.org/packages/d3/f1/00b7278c7dd52b17ad4329153748f87b6756ec195ff786c2bdf12518337d/pillow-12.2.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:8bd7903a5f2a4545f6fd5935c90058b89d30045568985a71c79f5fd6edf9b91e", size = 8088188, upload-time = "2026-04-01T14:42:51.735Z" }, + { url = "https://files.pythonhosted.org/packages/ad/cf/220a5994ef1b10e70e85748b75649d77d506499352be135a4989c957b701/pillow-12.2.0-cp311-cp311-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:3997232e10d2920a68d25191392e3a4487d8183039e1c74c2297f00ed1c50705", size = 6394401, upload-time = "2026-04-01T14:42:54.343Z" }, + { url = "https://files.pythonhosted.org/packages/e9/bd/e51a61b1054f09437acfbc2ff9106c30d1eb76bc1453d428399946781253/pillow-12.2.0-cp311-cp311-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e74473c875d78b8e9d5da2a70f7099549f9eb37ded4e2f6a463e60125bccd176", size = 7079655, upload-time = "2026-04-01T14:42:56.954Z" }, + { url = "https://files.pythonhosted.org/packages/6b/3d/45132c57d5fb4b5744567c3817026480ac7fc3ce5d4c47902bc0e7f6f853/pillow-12.2.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:56a3f9c60a13133a98ecff6197af34d7824de9b7b38c3654861a725c970c197b", size = 6503105, upload-time = "2026-04-01T14:42:59.847Z" }, + { url = "https://files.pythonhosted.org/packages/7d/2e/9df2fc1e82097b1df3dce58dc43286aa01068e918c07574711fcc53e6fb4/pillow-12.2.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:90e6f81de50ad6b534cab6e5aef77ff6e37722b2f5d908686f4a5c9eba17a909", size = 7203402, upload-time = "2026-04-01T14:43:02.664Z" }, + { url = "https://files.pythonhosted.org/packages/bd/2e/2941e42858ebb67e50ae741473de81c2984e6eff7b397017623c676e2e8d/pillow-12.2.0-cp311-cp311-win32.whl", hash = "sha256:8c984051042858021a54926eb597d6ee3012393ce9c181814115df4c60b9a808", size = 6378149, upload-time = "2026-04-01T14:43:05.274Z" }, + { url = "https://files.pythonhosted.org/packages/69/42/836b6f3cd7f3e5fa10a1f1a5420447c17966044c8fbf589cc0452d5502db/pillow-12.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:6e6b2a0c538fc200b38ff9eb6628228b77908c319a005815f2dde585a0664b60", size = 7082626, upload-time = "2026-04-01T14:43:08.557Z" }, + { url = "https://files.pythonhosted.org/packages/c2/88/549194b5d6f1f494b485e493edc6693c0a16f4ada488e5bd974ed1f42fad/pillow-12.2.0-cp311-cp311-win_arm64.whl", hash = "sha256:9a8a34cc89c67a65ea7437ce257cea81a9dad65b29805f3ecee8c8fe8ff25ffe", size = 2463531, upload-time = "2026-04-01T14:43:10.743Z" }, + { url = "https://files.pythonhosted.org/packages/58/be/7482c8a5ebebbc6470b3eb791812fff7d5e0216c2be3827b30b8bb6603ed/pillow-12.2.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:2d192a155bbcec180f8564f693e6fd9bccff5a7af9b32e2e4bf8c9c69dbad6b5", size = 5308279, upload-time = "2026-04-01T14:43:13.246Z" }, + { url = "https://files.pythonhosted.org/packages/d8/95/0a351b9289c2b5cbde0bacd4a83ebc44023e835490a727b2a3bd60ddc0f4/pillow-12.2.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f3f40b3c5a968281fd507d519e444c35f0ff171237f4fdde090dd60699458421", size = 4695490, upload-time = "2026-04-01T14:43:15.584Z" }, + { url = "https://files.pythonhosted.org/packages/de/af/4e8e6869cbed569d43c416fad3dc4ecb944cb5d9492defaed89ddd6fe871/pillow-12.2.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:03e7e372d5240cc23e9f07deca4d775c0817bffc641b01e9c3af208dbd300987", size = 6284462, upload-time = "2026-04-01T14:43:18.268Z" }, + { url = "https://files.pythonhosted.org/packages/e9/9e/c05e19657fd57841e476be1ab46c4d501bffbadbafdc31a6d665f8b737b6/pillow-12.2.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:b86024e52a1b269467a802258c25521e6d742349d760728092e1bc2d135b4d76", size = 8094744, upload-time = "2026-04-01T14:43:20.716Z" }, + { url = "https://files.pythonhosted.org/packages/2b/54/1789c455ed10176066b6e7e6da1b01e50e36f94ba584dc68d9eebfe9156d/pillow-12.2.0-cp312-cp312-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:7371b48c4fa448d20d2714c9a1f775a81155050d383333e0a6c15b1123dda005", size = 6398371, upload-time = "2026-04-01T14:43:23.443Z" }, + { url = "https://files.pythonhosted.org/packages/43/e3/fdc657359e919462369869f1c9f0e973f353f9a9ee295a39b1fea8ee1a77/pillow-12.2.0-cp312-cp312-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:62f5409336adb0663b7caa0da5c7d9e7bdbaae9ce761d34669420c2a801b2780", size = 7087215, upload-time = "2026-04-01T14:43:26.758Z" }, + { url = "https://files.pythonhosted.org/packages/8b/f8/2f6825e441d5b1959d2ca5adec984210f1ec086435b0ed5f52c19b3b8a6e/pillow-12.2.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:01afa7cf67f74f09523699b4e88c73fb55c13346d212a59a2db1f86b0a63e8c5", size = 6509783, upload-time = "2026-04-01T14:43:29.56Z" }, + { url = "https://files.pythonhosted.org/packages/67/f9/029a27095ad20f854f9dba026b3ea6428548316e057e6fc3545409e86651/pillow-12.2.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:fc3d34d4a8fbec3e88a79b92e5465e0f9b842b628675850d860b8bd300b159f5", size = 7212112, upload-time = "2026-04-01T14:43:32.091Z" }, + { url = "https://files.pythonhosted.org/packages/be/42/025cfe05d1be22dbfdb4f264fe9de1ccda83f66e4fc3aac94748e784af04/pillow-12.2.0-cp312-cp312-win32.whl", hash = "sha256:58f62cc0f00fd29e64b29f4fd923ffdb3859c9f9e6105bfc37ba1d08994e8940", size = 6378489, upload-time = "2026-04-01T14:43:34.601Z" }, + { url = "https://files.pythonhosted.org/packages/5d/7b/25a221d2c761c6a8ae21bfa3874988ff2583e19cf8a27bf2fee358df7942/pillow-12.2.0-cp312-cp312-win_amd64.whl", hash = "sha256:7f84204dee22a783350679a0333981df803dac21a0190d706a50475e361c93f5", size = 7084129, upload-time = "2026-04-01T14:43:37.213Z" }, + { url = "https://files.pythonhosted.org/packages/10/e1/542a474affab20fd4a0f1836cb234e8493519da6b76899e30bcc5d990b8b/pillow-12.2.0-cp312-cp312-win_arm64.whl", hash = "sha256:af73337013e0b3b46f175e79492d96845b16126ddf79c438d7ea7ff27783a414", size = 2463612, upload-time = "2026-04-01T14:43:39.421Z" }, + { url = "https://files.pythonhosted.org/packages/4a/01/53d10cf0dbad820a8db274d259a37ba50b88b24768ddccec07355382d5ad/pillow-12.2.0-cp313-cp313-ios_13_0_arm64_iphoneos.whl", hash = "sha256:8297651f5b5679c19968abefd6bb84d95fe30ef712eb1b2d9b2d31ca61267f4c", size = 4100837, upload-time = "2026-04-01T14:43:41.506Z" }, + { url = "https://files.pythonhosted.org/packages/0f/98/f3a6657ecb698c937f6c76ee564882945f29b79bad496abcba0e84659ec5/pillow-12.2.0-cp313-cp313-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:50d8520da2a6ce0af445fa6d648c4273c3eeefbc32d7ce049f22e8b5c3daecc2", size = 4176528, upload-time = "2026-04-01T14:43:43.773Z" }, + { url = "https://files.pythonhosted.org/packages/69/bc/8986948f05e3ea490b8442ea1c1d4d990b24a7e43d8a51b2c7d8b1dced36/pillow-12.2.0-cp313-cp313-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:766cef22385fa1091258ad7e6216792b156dc16d8d3fa607e7545b2b72061f1c", size = 3640401, upload-time = "2026-04-01T14:43:45.87Z" }, + { url = "https://files.pythonhosted.org/packages/34/46/6c717baadcd62bc8ed51d238d521ab651eaa74838291bda1f86fe1f864c9/pillow-12.2.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:5d2fd0fa6b5d9d1de415060363433f28da8b1526c1c129020435e186794b3795", size = 5308094, upload-time = "2026-04-01T14:43:48.438Z" }, + { url = "https://files.pythonhosted.org/packages/71/43/905a14a8b17fdb1ccb58d282454490662d2cb89a6bfec26af6d3520da5ec/pillow-12.2.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:56b25336f502b6ed02e889f4ece894a72612fe885889a6e8c4c80239ff6e5f5f", size = 4695402, upload-time = "2026-04-01T14:43:51.292Z" }, + { url = "https://files.pythonhosted.org/packages/73/dd/42107efcb777b16fa0393317eac58f5b5cf30e8392e266e76e51cff28c3d/pillow-12.2.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f1c943e96e85df3d3478f7b691f229887e143f81fedab9b20205349ab04d73ed", size = 6280005, upload-time = "2026-04-01T14:43:54.242Z" }, + { url = "https://files.pythonhosted.org/packages/a8/68/b93e09e5e8549019e61acf49f65b1a8530765a7f812c77a7461bca7e4494/pillow-12.2.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:03f6fab9219220f041c74aeaa2939ff0062bd5c364ba9ce037197f4c6d498cd9", size = 8090669, upload-time = "2026-04-01T14:43:57.335Z" }, + { url = "https://files.pythonhosted.org/packages/4b/6e/3ccb54ce8ec4ddd1accd2d89004308b7b0b21c4ac3d20fa70af4760a4330/pillow-12.2.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5cdfebd752ec52bf5bb4e35d9c64b40826bc5b40a13df7c3cda20a2c03a0f5ed", size = 6395194, upload-time = "2026-04-01T14:43:59.864Z" }, + { url = "https://files.pythonhosted.org/packages/67/ee/21d4e8536afd1a328f01b359b4d3997b291ffd35a237c877b331c1c3b71c/pillow-12.2.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:eedf4b74eda2b5a4b2b2fb4c006d6295df3bf29e459e198c90ea48e130dc75c3", size = 7082423, upload-time = "2026-04-01T14:44:02.74Z" }, + { url = "https://files.pythonhosted.org/packages/78/5f/e9f86ab0146464e8c133fe85df987ed9e77e08b29d8d35f9f9f4d6f917ba/pillow-12.2.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:00a2865911330191c0b818c59103b58a5e697cae67042366970a6b6f1b20b7f9", size = 6505667, upload-time = "2026-04-01T14:44:05.381Z" }, + { url = "https://files.pythonhosted.org/packages/ed/1e/409007f56a2fdce61584fd3acbc2bbc259857d555196cedcadc68c015c82/pillow-12.2.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:1e1757442ed87f4912397c6d35a0db6a7b52592156014706f17658ff58bbf795", size = 7208580, upload-time = "2026-04-01T14:44:08.39Z" }, + { url = "https://files.pythonhosted.org/packages/23/c4/7349421080b12fb35414607b8871e9534546c128a11965fd4a7002ccfbee/pillow-12.2.0-cp313-cp313-win32.whl", hash = "sha256:144748b3af2d1b358d41286056d0003f47cb339b8c43a9ea42f5fea4d8c66b6e", size = 6375896, upload-time = "2026-04-01T14:44:11.197Z" }, + { url = "https://files.pythonhosted.org/packages/3f/82/8a3739a5e470b3c6cbb1d21d315800d8e16bff503d1f16b03a4ec3212786/pillow-12.2.0-cp313-cp313-win_amd64.whl", hash = "sha256:390ede346628ccc626e5730107cde16c42d3836b89662a115a921f28440e6a3b", size = 7081266, upload-time = "2026-04-01T14:44:13.947Z" }, + { url = "https://files.pythonhosted.org/packages/c3/25/f968f618a062574294592f668218f8af564830ccebdd1fa6200f598e65c5/pillow-12.2.0-cp313-cp313-win_arm64.whl", hash = "sha256:8023abc91fba39036dbce14a7d6535632f99c0b857807cbbbf21ecc9f4717f06", size = 2463508, upload-time = "2026-04-01T14:44:16.312Z" }, + { url = "https://files.pythonhosted.org/packages/4d/a4/b342930964e3cb4dce5038ae34b0eab4653334995336cd486c5a8c25a00c/pillow-12.2.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:042db20a421b9bafecc4b84a8b6e444686bd9d836c7fd24542db3e7df7baad9b", size = 5309927, upload-time = "2026-04-01T14:44:18.89Z" }, + { url = "https://files.pythonhosted.org/packages/9f/de/23198e0a65a9cf06123f5435a5d95cea62a635697f8f03d134d3f3a96151/pillow-12.2.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:dd025009355c926a84a612fecf58bb315a3f6814b17ead51a8e48d3823d9087f", size = 4698624, upload-time = "2026-04-01T14:44:21.115Z" }, + { url = "https://files.pythonhosted.org/packages/01/a6/1265e977f17d93ea37aa28aa81bad4fa597933879fac2520d24e021c8da3/pillow-12.2.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:88ddbc66737e277852913bd1e07c150cc7bb124539f94c4e2df5344494e0a612", size = 6321252, upload-time = "2026-04-01T14:44:23.663Z" }, + { url = "https://files.pythonhosted.org/packages/3c/83/5982eb4a285967baa70340320be9f88e57665a387e3a53a7f0db8231a0cd/pillow-12.2.0-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d362d1878f00c142b7e1a16e6e5e780f02be8195123f164edf7eddd911eefe7c", size = 8126550, upload-time = "2026-04-01T14:44:26.772Z" }, + { url = "https://files.pythonhosted.org/packages/4e/48/6ffc514adce69f6050d0753b1a18fd920fce8cac87620d5a31231b04bfc5/pillow-12.2.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2c727a6d53cb0018aadd8018c2b938376af27914a68a492f59dfcaca650d5eea", size = 6433114, upload-time = "2026-04-01T14:44:29.615Z" }, + { url = "https://files.pythonhosted.org/packages/36/a3/f9a77144231fb8d40ee27107b4463e205fa4677e2ca2548e14da5cf18dce/pillow-12.2.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:efd8c21c98c5cc60653bcb311bef2ce0401642b7ce9d09e03a7da87c878289d4", size = 7115667, upload-time = "2026-04-01T14:44:32.773Z" }, + { url = "https://files.pythonhosted.org/packages/c1/fc/ac4ee3041e7d5a565e1c4fd72a113f03b6394cc72ab7089d27608f8aaccb/pillow-12.2.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:9f08483a632889536b8139663db60f6724bfcb443c96f1b18855860d7d5c0fd4", size = 6538966, upload-time = "2026-04-01T14:44:35.252Z" }, + { url = "https://files.pythonhosted.org/packages/c0/a8/27fb307055087f3668f6d0a8ccb636e7431d56ed0750e07a60547b1e083e/pillow-12.2.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:dac8d77255a37e81a2efcbd1fc05f1c15ee82200e6c240d7e127e25e365c39ea", size = 7238241, upload-time = "2026-04-01T14:44:37.875Z" }, + { url = "https://files.pythonhosted.org/packages/ad/4b/926ab182c07fccae9fcb120043464e1ff1564775ec8864f21a0ebce6ac25/pillow-12.2.0-cp313-cp313t-win32.whl", hash = "sha256:ee3120ae9dff32f121610bb08e4313be87e03efeadfc6c0d18f89127e24d0c24", size = 6379592, upload-time = "2026-04-01T14:44:40.336Z" }, + { url = "https://files.pythonhosted.org/packages/c2/c4/f9e476451a098181b30050cc4c9a3556b64c02cf6497ea421ac047e89e4b/pillow-12.2.0-cp313-cp313t-win_amd64.whl", hash = "sha256:325ca0528c6788d2a6c3d40e3568639398137346c3d6e66bb61db96b96511c98", size = 7085542, upload-time = "2026-04-01T14:44:43.251Z" }, + { url = "https://files.pythonhosted.org/packages/00/a4/285f12aeacbe2d6dc36c407dfbbe9e96d4a80b0fb710a337f6d2ad978c75/pillow-12.2.0-cp313-cp313t-win_arm64.whl", hash = "sha256:2e5a76d03a6c6dcef67edabda7a52494afa4035021a79c8558e14af25313d453", size = 2465765, upload-time = "2026-04-01T14:44:45.996Z" }, + { url = "https://files.pythonhosted.org/packages/bf/98/4595daa2365416a86cb0d495248a393dfc84e96d62ad080c8546256cb9c0/pillow-12.2.0-cp314-cp314-ios_13_0_arm64_iphoneos.whl", hash = "sha256:3adc9215e8be0448ed6e814966ecf3d9952f0ea40eb14e89a102b87f450660d8", size = 4100848, upload-time = "2026-04-01T14:44:48.48Z" }, + { url = "https://files.pythonhosted.org/packages/0b/79/40184d464cf89f6663e18dfcf7ca21aae2491fff1a16127681bf1fa9b8cf/pillow-12.2.0-cp314-cp314-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:6a9adfc6d24b10f89588096364cc726174118c62130c817c2837c60cf08a392b", size = 4176515, upload-time = "2026-04-01T14:44:51.353Z" }, + { url = "https://files.pythonhosted.org/packages/b0/63/703f86fd4c422a9cf722833670f4f71418fb116b2853ff7da722ea43f184/pillow-12.2.0-cp314-cp314-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:6a6e67ea2e6feda684ed370f9a1c52e7a243631c025ba42149a2cc5934dec295", size = 3640159, upload-time = "2026-04-01T14:44:53.588Z" }, + { url = "https://files.pythonhosted.org/packages/71/e0/fb22f797187d0be2270f83500aab851536101b254bfa1eae10795709d283/pillow-12.2.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:2bb4a8d594eacdfc59d9e5ad972aa8afdd48d584ffd5f13a937a664c3e7db0ed", size = 5312185, upload-time = "2026-04-01T14:44:56.039Z" }, + { url = "https://files.pythonhosted.org/packages/ba/8c/1a9e46228571de18f8e28f16fabdfc20212a5d019f3e3303452b3f0a580d/pillow-12.2.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:80b2da48193b2f33ed0c32c38140f9d3186583ce7d516526d462645fd98660ae", size = 4695386, upload-time = "2026-04-01T14:44:58.663Z" }, + { url = "https://files.pythonhosted.org/packages/70/62/98f6b7f0c88b9addd0e87c217ded307b36be024d4ff8869a812b241d1345/pillow-12.2.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:22db17c68434de69d8ecfc2fe821569195c0c373b25cccb9cbdacf2c6e53c601", size = 6280384, upload-time = "2026-04-01T14:45:01.5Z" }, + { url = "https://files.pythonhosted.org/packages/5e/03/688747d2e91cfbe0e64f316cd2e8005698f76ada3130d0194664174fa5de/pillow-12.2.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7b14cc0106cd9aecda615dd6903840a058b4700fcb817687d0ee4fc8b6e389be", size = 8091599, upload-time = "2026-04-01T14:45:04.5Z" }, + { url = "https://files.pythonhosted.org/packages/f6/35/577e22b936fcdd66537329b33af0b4ccfefaeabd8aec04b266528cddb33c/pillow-12.2.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8cbeb542b2ebc6fcdacabf8aca8c1a97c9b3ad3927d46b8723f9d4f033288a0f", size = 6396021, upload-time = "2026-04-01T14:45:07.117Z" }, + { url = "https://files.pythonhosted.org/packages/11/8d/d2532ad2a603ca2b93ad9f5135732124e57811d0168155852f37fbce2458/pillow-12.2.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4bfd07bc812fbd20395212969e41931001fd59eb55a60658b0e5710872e95286", size = 7083360, upload-time = "2026-04-01T14:45:09.763Z" }, + { url = "https://files.pythonhosted.org/packages/5e/26/d325f9f56c7e039034897e7380e9cc202b1e368bfd04d4cbe6a441f02885/pillow-12.2.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:9aba9a17b623ef750a4d11b742cbafffeb48a869821252b30ee21b5e91392c50", size = 6507628, upload-time = "2026-04-01T14:45:12.378Z" }, + { url = "https://files.pythonhosted.org/packages/5f/f7/769d5632ffb0988f1c5e7660b3e731e30f7f8ec4318e94d0a5d674eb65a4/pillow-12.2.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:deede7c263feb25dba4e82ea23058a235dcc2fe1f6021025dc71f2b618e26104", size = 7209321, upload-time = "2026-04-01T14:45:15.122Z" }, + { url = "https://files.pythonhosted.org/packages/6a/7a/c253e3c645cd47f1aceea6a8bacdba9991bf45bb7dfe927f7c893e89c93c/pillow-12.2.0-cp314-cp314-win32.whl", hash = "sha256:632ff19b2778e43162304d50da0181ce24ac5bb8180122cbe1bf4673428328c7", size = 6479723, upload-time = "2026-04-01T14:45:17.797Z" }, + { url = "https://files.pythonhosted.org/packages/cd/8b/601e6566b957ca50e28725cb6c355c59c2c8609751efbecd980db44e0349/pillow-12.2.0-cp314-cp314-win_amd64.whl", hash = "sha256:4e6c62e9d237e9b65fac06857d511e90d8461a32adcc1b9065ea0c0fa3a28150", size = 7217400, upload-time = "2026-04-01T14:45:20.529Z" }, + { url = "https://files.pythonhosted.org/packages/d6/94/220e46c73065c3e2951bb91c11a1fb636c8c9ad427ac3ce7d7f3359b9b2f/pillow-12.2.0-cp314-cp314-win_arm64.whl", hash = "sha256:b1c1fbd8a5a1af3412a0810d060a78b5136ec0836c8a4ef9aa11807f2a22f4e1", size = 2554835, upload-time = "2026-04-01T14:45:23.162Z" }, + { url = "https://files.pythonhosted.org/packages/b6/ab/1b426a3974cb0e7da5c29ccff4807871d48110933a57207b5a676cccc155/pillow-12.2.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:57850958fe9c751670e49b2cecf6294acc99e562531f4bd317fa5ddee2068463", size = 5314225, upload-time = "2026-04-01T14:45:25.637Z" }, + { url = "https://files.pythonhosted.org/packages/19/1e/dce46f371be2438eecfee2a1960ee2a243bbe5e961890146d2dee1ff0f12/pillow-12.2.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:d5d38f1411c0ed9f97bcb49b7bd59b6b7c314e0e27420e34d99d844b9ce3b6f3", size = 4698541, upload-time = "2026-04-01T14:45:28.355Z" }, + { url = "https://files.pythonhosted.org/packages/55/c3/7fbecf70adb3a0c33b77a300dc52e424dc22ad8cdc06557a2e49523b703d/pillow-12.2.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5c0a9f29ca8e79f09de89293f82fc9b0270bb4af1d58bc98f540cc4aedf03166", size = 6322251, upload-time = "2026-04-01T14:45:30.924Z" }, + { url = "https://files.pythonhosted.org/packages/1c/3c/7fbc17cfb7e4fe0ef1642e0abc17fc6c94c9f7a16be41498e12e2ba60408/pillow-12.2.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1610dd6c61621ae1cf811bef44d77e149ce3f7b95afe66a4512f8c59f25d9ebe", size = 8127807, upload-time = "2026-04-01T14:45:33.908Z" }, + { url = "https://files.pythonhosted.org/packages/ff/c3/a8ae14d6defd2e448493ff512fae903b1e9bd40b72efb6ec55ce0048c8ce/pillow-12.2.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0a34329707af4f73cf1782a36cd2289c0368880654a2c11f027bcee9052d35dd", size = 6433935, upload-time = "2026-04-01T14:45:36.623Z" }, + { url = "https://files.pythonhosted.org/packages/6e/32/2880fb3a074847ac159d8f902cb43278a61e85f681661e7419e6596803ed/pillow-12.2.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8e9c4f5b3c546fa3458a29ab22646c1c6c787ea8f5ef51300e5a60300736905e", size = 7116720, upload-time = "2026-04-01T14:45:39.258Z" }, + { url = "https://files.pythonhosted.org/packages/46/87/495cc9c30e0129501643f24d320076f4cc54f718341df18cc70ec94c44e1/pillow-12.2.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:fb043ee2f06b41473269765c2feae53fc2e2fbf96e5e22ca94fb5ad677856f06", size = 6540498, upload-time = "2026-04-01T14:45:41.879Z" }, + { url = "https://files.pythonhosted.org/packages/18/53/773f5edca692009d883a72211b60fdaf8871cbef075eaa9d577f0a2f989e/pillow-12.2.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:f278f034eb75b4e8a13a54a876cc4a5ab39173d2cdd93a638e1b467fc545ac43", size = 7239413, upload-time = "2026-04-01T14:45:44.705Z" }, + { url = "https://files.pythonhosted.org/packages/c9/e4/4b64a97d71b2a83158134abbb2f5bd3f8a2ea691361282f010998f339ec7/pillow-12.2.0-cp314-cp314t-win32.whl", hash = "sha256:6bb77b2dcb06b20f9f4b4a8454caa581cd4dd0643a08bacf821216a16d9c8354", size = 6482084, upload-time = "2026-04-01T14:45:47.568Z" }, + { url = "https://files.pythonhosted.org/packages/ba/13/306d275efd3a3453f72114b7431c877d10b1154014c1ebbedd067770d629/pillow-12.2.0-cp314-cp314t-win_amd64.whl", hash = "sha256:6562ace0d3fb5f20ed7290f1f929cae41b25ae29528f2af1722966a0a02e2aa1", size = 7225152, upload-time = "2026-04-01T14:45:50.032Z" }, + { url = "https://files.pythonhosted.org/packages/ff/6e/cf826fae916b8658848d7b9f38d88da6396895c676e8086fc0988073aaf8/pillow-12.2.0-cp314-cp314t-win_arm64.whl", hash = "sha256:aa88ccfe4e32d362816319ed727a004423aab09c5cea43c01a4b435643fa34eb", size = 2556579, upload-time = "2026-04-01T14:45:52.529Z" }, + { url = "https://files.pythonhosted.org/packages/4e/b7/2437044fb910f499610356d1352e3423753c98e34f915252aafecc64889f/pillow-12.2.0-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:0538bd5e05efec03ae613fd89c4ce0368ecd2ba239cc25b9f9be7ed426b0af1f", size = 5273969, upload-time = "2026-04-01T14:45:55.538Z" }, + { url = "https://files.pythonhosted.org/packages/f6/f4/8316e31de11b780f4ac08ef3654a75555e624a98db1056ecb2122d008d5a/pillow-12.2.0-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:394167b21da716608eac917c60aa9b969421b5dcbbe02ae7f013e7b85811c69d", size = 4659674, upload-time = "2026-04-01T14:45:58.093Z" }, + { url = "https://files.pythonhosted.org/packages/d4/37/664fca7201f8bb2aa1d20e2c3d5564a62e6ae5111741966c8319ca802361/pillow-12.2.0-pp311-pypy311_pp73-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5d04bfa02cc2d23b497d1e90a0f927070043f6cbf303e738300532379a4b4e0f", size = 5288479, upload-time = "2026-04-01T14:46:01.141Z" }, + { url = "https://files.pythonhosted.org/packages/49/62/5b0ed78fce87346be7a5cfcfaaad91f6a1f98c26f86bdbafa2066c647ef6/pillow-12.2.0-pp311-pypy311_pp73-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:0c838a5125cee37e68edec915651521191cef1e6aa336b855f495766e77a366e", size = 7032230, upload-time = "2026-04-01T14:46:03.874Z" }, + { url = "https://files.pythonhosted.org/packages/c3/28/ec0fc38107fc32536908034e990c47914c57cd7c5a3ece4d8d8f7ffd7e27/pillow-12.2.0-pp311-pypy311_pp73-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:4a6c9fa44005fa37a91ebfc95d081e8079757d2e904b27103f4f5fa6f0bf78c0", size = 5355404, upload-time = "2026-04-01T14:46:06.33Z" }, + { url = "https://files.pythonhosted.org/packages/5e/8b/51b0eddcfa2180d60e41f06bd6d0a62202b20b59c68f5a132e615b75aecf/pillow-12.2.0-pp311-pypy311_pp73-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:25373b66e0dd5905ed63fa3cae13c82fbddf3079f2c8bf15c6fb6a35586324c1", size = 6002215, upload-time = "2026-04-01T14:46:08.83Z" }, + { url = "https://files.pythonhosted.org/packages/bc/60/5382c03e1970de634027cee8e1b7d39776b778b81812aaf45b694dfe9e28/pillow-12.2.0-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:bfa9c230d2fe991bed5318a5f119bd6780cda2915cca595393649fc118ab895e", size = 7080946, upload-time = "2026-04-01T14:46:11.734Z" }, +] + [[package]] name = "pkginfo" version = "1.12.1.2" @@ -3873,6 +4032,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/54/20/4d324d65cc6d9205fabedc306948156824eb9f0ee1633355a8f7ec5c66bf/pluggy-1.6.0-py3-none-any.whl", hash = "sha256:e920276dd6813095e9377c0bc5566d94c932c33b27a3e3945d8389c374dd4746", size = 20538, upload-time = "2025-05-15T12:30:06.134Z" }, ] +[[package]] +name = "prompt-toolkit" +version = "3.0.52" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "wcwidth" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/a1/96/06e01a7b38dce6fe1db213e061a4602dd6032a8a97ef6c1a862537732421/prompt_toolkit-3.0.52.tar.gz", hash = "sha256:28cde192929c8e7321de85de1ddbe736f1375148b02f2e17edd840042b1be855", size = 434198, upload-time = "2025-08-27T15:24:02.057Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/84/03/0d3ce49e2505ae70cf43bc5bb3033955d2fc9f932163e84dc0779cc47f48/prompt_toolkit-3.0.52-py3-none-any.whl", hash = "sha256:9aac639a3bbd33284347de5ad8d68ecc044b91a762dc39b7c21095fcd6a19955", size = 391431, upload-time = "2025-08-27T15:23:59.498Z" }, +] + [[package]] name = "propcache" version = "0.4.1" @@ -4796,15 +4967,15 @@ wheels = [ [[package]] name = "rich" -version = "15.0.0" +version = "14.3.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "markdown-it-py" }, { name = "pygments" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c0/8f/0722ca900cc807c13a6a0c696dacf35430f72e0ec571c4275d2371fca3e9/rich-15.0.0.tar.gz", hash = "sha256:edd07a4824c6b40189fb7ac9bc4c52536e9780fbbfbddf6f1e2502c31b068c36", size = 230680, upload-time = "2026-04-12T08:24:00.75Z" } +sdist = { url = "https://files.pythonhosted.org/packages/e9/67/cae617f1351490c25a4b8ac3b8b63a4dda609295d8222bad12242dfdc629/rich-14.3.4.tar.gz", hash = "sha256:817e02727f2b25b40ef56f5aa2217f400c8489f79ca8f46ea2b70dd5e14558a9", size = 230524, upload-time = "2026-04-11T02:57:45.419Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/82/3b/64d4899d73f91ba49a8c18a8ff3f0ea8f1c1d75481760df8c68ef5235bf5/rich-15.0.0-py3-none-any.whl", hash = "sha256:33bd4ef74232fb73fe9279a257718407f169c09b78a87ad3d296f548e27de0bb", size = 310654, upload-time = "2026-04-12T08:24:02.83Z" }, + { url = "https://files.pythonhosted.org/packages/b3/76/6d163cfac87b632216f71879e6b2cf17163f773ff59c00b5ff4900a80fa3/rich-14.3.4-py3-none-any.whl", hash = "sha256:07e7adb4690f68864777b1450859253bed81a99a31ac321ac1817b2313558952", size = 310480, upload-time = "2026-04-11T02:57:47.484Z" }, ] [[package]] @@ -5006,6 +5177,27 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" }, ] +[[package]] +name = "slack-bolt" +version = "1.28.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "slack-sdk" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/5c/97/a62dde97e84027b252807f2044bed2edcda2d063a5cb0c535fb2be8d9b5d/slack_bolt-1.28.0.tar.gz", hash = "sha256:bfe367d867e8fb157a057248ebd4ac2d7f43acac6d0700fa31381db1e10f3b0f", size = 130768, upload-time = "2026-04-06T23:24:59.936Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/81/a9/697b6a92c728f09d5ef6b8e83dc6c8a87bc6d59499b2933ed067f11b7e30/slack_bolt-1.28.0-py2.py3-none-any.whl", hash = "sha256:738d1ca5e7c7039b6e18103d29267ced6e18c2517053eff18991fdd593acce5c", size = 234819, upload-time = "2026-04-06T23:24:58.278Z" }, +] + +[[package]] +name = "slack-sdk" +version = "3.41.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/22/35/fc009118a13187dd9731657c60138e5a7c2dea88681a7f04dc406af5da7d/slack_sdk-3.41.0.tar.gz", hash = "sha256:eb61eb12a65bebeca9cb5d36b3f799e836ed2be21b456d15df2627cfe34076ca", size = 250568, upload-time = "2026-03-12T16:10:11.381Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a1/df/2e4be347ff98281b505cc0ccf141408cdd25eb5ca9f3830deb361b2472d3/slack_sdk-3.41.0-py2.py3-none-any.whl", hash = "sha256:bb18dcdfff1413ec448e759cf807ec3324090993d8ab9111c74081623b692a89", size = 313885, upload-time = "2026-03-12T16:10:09.811Z" }, +] + [[package]] name = "sniffio" version = "1.3.1" @@ -5024,6 +5216,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c8/78/3565d011c61f5a43488987ee32b6f3f656e7f107ac2782dd57bdd7d91d9a/snowballstemmer-3.0.1-py3-none-any.whl", hash = "sha256:6cd7b3897da8d6c9ffb968a6781fa6532dce9c3618a4b127d920dab764a19064", size = 103274, upload-time = "2025-05-09T16:34:50.371Z" }, ] +[[package]] +name = "soupsieve" +version = "2.8.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/7b/ae/2d9c981590ed9999a0d91755b47fc74f74de286b0f5cee14c9269041e6c4/soupsieve-2.8.3.tar.gz", hash = "sha256:3267f1eeea4251fb42728b6dfb746edc9acaffc4a45b27e19450b676586e8349", size = 118627, upload-time = "2026-01-20T04:27:02.457Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/46/2c/1462b1d0a634697ae9e55b3cecdcb64788e8b7d63f54d923fcd0bb140aed/soupsieve-2.8.3-py3-none-any.whl", hash = "sha256:ed64f2ba4eebeab06cc4962affce381647455978ffc1e36bb79a545b91f45a95", size = 37016, upload-time = "2026-01-20T04:27:01.012Z" }, +] + [[package]] name = "sqlalchemy" version = "2.0.49" @@ -5133,6 +5334,57 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/81/0d/13d1d239a25cbfb19e740db83143e95c772a1fe10202dda4b76792b114dd/starlette-0.52.1-py3-none-any.whl", hash = "sha256:0029d43eb3d273bc4f83a08720b4912ea4b071087a3b48db01b7c839f7954d74", size = 74272, upload-time = "2026-01-18T13:34:09.188Z" }, ] +[[package]] +name = "strands-agents" +version = "1.38.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "boto3" }, + { name = "botocore" }, + { name = "docstring-parser" }, + { name = "jsonschema" }, + { name = "mcp" }, + { name = "opentelemetry-api" }, + { name = "opentelemetry-instrumentation-threading" }, + { name = "opentelemetry-sdk" }, + { name = "pydantic" }, + { name = "pyyaml" }, + { name = "typing-extensions" }, + { name = "watchdog" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/11/89/3e722f4b5bd913531bc32a23bf88aaa77a434774f294bba5bfa88690ec46/strands_agents-1.38.0.tar.gz", hash = "sha256:02a68ec321ad457f9137dfd6a99cf72cf0e86081fee35de85fbe29b9ac0af2b2", size = 858950, upload-time = "2026-04-30T16:57:43.244Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/cc/06/de8d8ab14a2e92dcb0fa82db0a4cb102418a1eda139412bbe5b5725e28df/strands_agents-1.38.0-py3-none-any.whl", hash = "sha256:9dc3de17e25d70e367d37f9151f2a4c7b3ac8fc9f6237e9e1f34d00bfbfd001b", size = 422354, upload-time = "2026-04-30T16:57:41.094Z" }, +] + +[[package]] +name = "strands-agents-tools" +version = "0.5.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "aiohttp" }, + { name = "aws-requests-auth" }, + { name = "botocore" }, + { name = "dill" }, + { name = "markdownify" }, + { name = "pillow" }, + { name = "prompt-toolkit" }, + { name = "pyjwt" }, + { name = "requests" }, + { name = "rich" }, + { name = "slack-bolt" }, + { name = "strands-agents" }, + { name = "sympy" }, + { name = "tenacity" }, + { name = "typing-extensions" }, + { name = "tzdata", marker = "sys_platform == 'win32'" }, + { name = "watchdog" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/63/32/710a49ffd32b0a232ec1731620ee6105c045e9a77ecee1f3ecaa1a80a6cd/strands_agents_tools-0.5.2.tar.gz", hash = "sha256:96763c8ae75933c5dd327cca87561f573aed720c9c0f3d17fd20835910d11381", size = 483164, upload-time = "2026-04-30T17:08:13.151Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/59/ef/fe73b6d25d095784d2e1f6f33419265e796143100fb2f32a6e86f8ae68af/strands_agents_tools-0.5.2-py3-none-any.whl", hash = "sha256:8f85e4cb28d9411e62e1f159aa7e300d3a0f4b1d2b878a7cdfd5d746d9333343", size = 316178, upload-time = "2026-04-30T17:08:11.416Z" }, +] + [[package]] name = "sympy" version = "1.14.0" @@ -5192,6 +5444,9 @@ opentelemetry = [ pydantic = [ { name = "pydantic" }, ] +strands = [ + { name = "strands-agents" }, +] [package.dev-dependencies] dev = [ @@ -5228,6 +5483,8 @@ dev = [ { name = "pytest-xdist" }, { name = "ruff" }, { name = "setuptools" }, + { name = "strands-agents" }, + { name = "strands-agents-tools" }, { name = "toml" }, { name = "twine" }, ] @@ -5252,11 +5509,12 @@ requires-dist = [ { name = "protobuf", specifier = ">=3.20,<7.0.0" }, { name = "pydantic", marker = "extra == 'pydantic'", specifier = ">=2.0.0,<3" }, { name = "python-dateutil", marker = "python_full_version < '3.11'", specifier = ">=2.8.2,<3" }, + { name = "strands-agents", marker = "extra == 'strands'", specifier = ">=1.38.0" }, { name = "types-aioboto3", extras = ["s3"], marker = "extra == 'aioboto3'", specifier = ">=10.4.0" }, { name = "types-protobuf", specifier = ">=3.20,<7.0.0" }, { name = "typing-extensions", specifier = ">=4.2.0,<5" }, ] -provides-extras = ["grpc", "opentelemetry", "pydantic", "openai-agents", "google-adk", "langgraph", "langsmith", "lambda-worker-otel", "aioboto3"] +provides-extras = ["grpc", "opentelemetry", "pydantic", "openai-agents", "google-adk", "langgraph", "langsmith", "lambda-worker-otel", "aioboto3", "strands"] [package.metadata.requires-dev] dev = [ @@ -5293,6 +5551,8 @@ dev = [ { name = "pytest-xdist", specifier = ">=3.6,<4" }, { name = "ruff", specifier = ">=0.15.12,<0.16" }, { name = "setuptools", specifier = "<82" }, + { name = "strands-agents", specifier = ">=1.38.0" }, + { name = "strands-agents-tools", specifier = ">=0.5.2" }, { name = "toml", specifier = ">=0.10.2,<0.11" }, { name = "twine", specifier = ">=4.0.1,<5" }, ] @@ -5744,6 +6004,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/33/e8/e40370e6d74ddba47f002a32919d91310d6074130fe4e17dabcafc15cbf1/watchdog-6.0.0-py3-none-win_ia64.whl", hash = "sha256:a1914259fa9e1454315171103c6a30961236f508b9b623eae470268bbcc6a22f", size = 79067, upload-time = "2024-11-01T14:07:11.845Z" }, ] +[[package]] +name = "wcwidth" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2c/ee/afaf0f85a9a18fe47a67f1e4422ed6cf1fe642f0ae0a2f81166231303c52/wcwidth-0.7.0.tar.gz", hash = "sha256:90e3a7ea092341c44b99562e75d09e4d5160fe7a3974c6fb842a101a95e7eed0", size = 182132, upload-time = "2026-05-02T16:04:12.653Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/41/52/e465037f5375f43533d1a80b6923955201596a99142ed524d77b571a1418/wcwidth-0.7.0-py3-none-any.whl", hash = "sha256:5d69154c429a82910e241c738cd0e2976fac8a2dd47a1a805f4afed1c0f136f2", size = 110825, upload-time = "2026-05-02T16:04:11.033Z" }, +] + [[package]] name = "websockets" version = "15.0.1" From 91ac3fef8d64fdbad0d39c17c4e3c437a5cfc8b1 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Tue, 12 May 2026 14:37:19 -0700 Subject: [PATCH 03/46] contrib/strands: split activity helper into _TemporalActivityTool --- temporalio/contrib/strands/README.md | 9 +- temporalio/contrib/strands/__init__.py | 4 +- temporalio/contrib/strands/_plugin.py | 118 +++--------------- .../strands/_temporal_activity_tool.py | 76 +++++++++++ tests/contrib/strands/test_tool.py | 19 ++- 5 files changed, 114 insertions(+), 112 deletions(-) create mode 100644 temporalio/contrib/strands/_temporal_activity_tool.py diff --git a/temporalio/contrib/strands/README.md b/temporalio/contrib/strands/README.md index c195a08ce..315e79c3d 100644 --- a/temporalio/contrib/strands/README.md +++ b/temporalio/contrib/strands/README.md @@ -4,5 +4,12 @@ * File-based tool lookup # Migration Guide -* Mark non-deterministic tools with `as_activity()` and include in `StrandsAgent(activity_tools=[...])` * Use `agent.invoke_async(message)` instead of `agent(message)` which spawns a thread +* Decorate non-deterministic tools with `@activity.defn` and register them via `Worker(activities=[...])`. Wrap them in the agent with `activity_as_tool()`. For tools from `strands_tools` (or any `@tool`-decorated input), write a thin async wrapper that calls the tool, e.g.: + ```python + from strands_tools.current_time import current_time + + @activity.defn(name="current_time") + async def current_time_activity() -> str: + return current_time() + ``` diff --git a/temporalio/contrib/strands/__init__.py b/temporalio/contrib/strands/__init__.py index ab7f96000..e52adb9c3 100644 --- a/temporalio/contrib/strands/__init__.py +++ b/temporalio/contrib/strands/__init__.py @@ -1,5 +1,5 @@ """Temporal integration for the Strands Agents SDK.""" -from ._plugin import StrandsPlugin, as_activity +from ._plugin import StrandsPlugin, activity_as_tool -__all__ = ["StrandsPlugin", "as_activity"] +__all__ = ["StrandsPlugin", "activity_as_tool"] diff --git a/temporalio/contrib/strands/_plugin.py b/temporalio/contrib/strands/_plugin.py index aafaaadb7..916d7611e 100644 --- a/temporalio/contrib/strands/_plugin.py +++ b/temporalio/contrib/strands/_plugin.py @@ -1,16 +1,10 @@ -import asyncio -import functools -import inspect +from collections.abc import Callable from dataclasses import replace from datetime import timedelta -from types import ModuleType -from typing import Any, Callable +from typing import Any -from strands import tool as strands_tool -from strands.tools import ToolProvider -from strands.tools.decorator import DecoratedFunctionTool +from strands.types.tools import AgentTool -from temporalio import activity, workflow from temporalio.common import Priority, RetryPolicy from temporalio.contrib.pydantic import pydantic_data_converter from temporalio.converter import DataConverter, DefaultPayloadConverter @@ -19,37 +13,26 @@ from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner from temporalio.workflow import ActivityCancellationType, VersioningIntent +from ._temporal_activity_tool import _TemporalActivityTool + class StrandsPlugin(SimplePlugin): """Temporal Worker plugin for the Strands Agents SDK. - Marks ``strands`` as a sandbox passthrough module and registers any - nondeterministic tools as Temporal activities so they can be invoked - from inside workflows. + Configures sandbox passthrough for ``strands`` and ``strands_tools`` and + swaps in ``pydantic_data_converter`` so structured outputs serialize. """ - def __init__( - self, - activity_tools: list[ToolProvider | Any] | None = None, - ): - """Initialize the plugin. - - Args: - activity_tools: Tools to register as Temporal activities. Accepts - the same forms as ``strands.Agent(tools=...)``: plain - functions, ``@tool``-decorated functions, or imported - ``strands_tools`` submodules. - """ + def __init__(self) -> None: super().__init__( "aws.StrandsPlugin", - activities=[_build_activity(tool) for tool in activity_tools or []], workflow_runner=_workflow_runner, data_converter=_data_converter, ) -def as_activity( - tool: ToolProvider | Any, +def activity_as_tool( + activity_fn: Callable, *, task_queue: str | None = None, schedule_to_close_timeout: timedelta | None = None, @@ -62,21 +45,12 @@ def as_activity( versioning_intent: VersioningIntent | None = None, summary: str | None = None, priority: Priority = Priority.default, -) -> DecoratedFunctionTool: - """Wrap a Strands tool to dispatch through a Temporal activity. +) -> AgentTool: + """Wrap a Temporal activity as a Strands tool. - Accepts the same forms as ``strands.Agent(tools=...)``. The returned - tool has the same name and schema as the original, but its body calls - ``workflow.execute_activity`` instead of running the function inline. - The activity itself must be registered on the worker via - ``StrandsPlugin(activity_tools=[fn])``. - - All keyword arguments are forwarded to ``workflow.execute_activity``; - refer to its documentation for details. + ``activity_fn`` must be decorated by ``@activity.defn``. All keyword + arguments are forwarded to ``workflow.execute_activity``. """ - fn = _unwrap_tool(tool) - sig = inspect.signature(fn) - activity_name = fn.__name__ options: dict[str, Any] = { "task_queue": task_queue, "schedule_to_close_timeout": schedule_to_close_timeout, @@ -90,65 +64,10 @@ def as_activity( "summary": summary, "priority": priority, } - - @functools.wraps(fn) - async def proxy(*args: Any, **kwargs: Any) -> Any: - bound = sig.bind(*args, **kwargs) - bound.apply_defaults() - positional = list(bound.arguments.values()) - if not positional: - return await workflow.execute_activity(activity_name, **options) - if len(positional) == 1: - return await workflow.execute_activity( - activity_name, positional[0], **options - ) - return await workflow.execute_activity( - activity_name, args=positional, **options - ) - - proxy.__signature__ = sig # type: ignore[attr-defined] - return strands_tool(proxy) - - -def _build_activity(tool: ToolProvider | Any) -> Callable: - fn = _unwrap_tool(tool) - return activity.defn(name=fn.__name__)(_ensure_async(fn)) - - -def _unwrap_tool(tool: ToolProvider | Any) -> Callable: - if isinstance(tool, ModuleType): - name = tool.__name__.rsplit(".", 1)[-1] - fn = getattr(tool, name) - else: - fn = tool - if isinstance(fn, DecoratedFunctionTool): - return fn._tool_func - if not callable(fn): - raise TypeError(f"Cannot wrap {fn!r} as a Temporal activity") - return fn - - -def _ensure_async(fn: Callable) -> Callable: - """Return an async-compatible version of ``fn``. - - Temporal's Worker rejects sync activity functions unless an - ``activity_executor`` is configured. Most ``strands_tools`` functions - (e.g. ``current_time``, ``calculator``) are sync, so we wrap them in an - ``asyncio.to_thread`` call instead of requiring the user to wire up an - executor on the Worker. - """ - if inspect.iscoroutinefunction(fn): - return fn - - @functools.wraps(fn) - async def wrapper(*args: Any, **kwargs: Any) -> Any: - return await asyncio.to_thread(fn, *args, **kwargs) - - return wrapper + return _TemporalActivityTool(activity_fn, options) def _workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: - """Add ``strands`` and ``strands_tools`` to the sandbox passthrough list.""" if not runner: raise ValueError("No WorkflowRunner provided to the Strands plugin.") if isinstance(runner, SandboxedWorkflowRunner): @@ -163,13 +82,6 @@ def _workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: def _data_converter(converter: DataConverter | None) -> DataConverter: - """Default to ``pydantic_data_converter`` when the user hasn't overridden. - - Strands optionally surfaces pydantic ``BaseModel`` values (e.g. - ``result.structured_output``) that the default Temporal converter can't - serialize. ``pydantic_data_converter`` is a strict superset of the - default, so this is safe even for workflows that never touch pydantic. - """ if ( converter is None or converter.payload_converter_class is DefaultPayloadConverter diff --git a/temporalio/contrib/strands/_temporal_activity_tool.py b/temporalio/contrib/strands/_temporal_activity_tool.py new file mode 100644 index 000000000..1d2ac3ffe --- /dev/null +++ b/temporalio/contrib/strands/_temporal_activity_tool.py @@ -0,0 +1,76 @@ +import inspect +import json +from collections.abc import Callable +from typing import Any + +from strands.tools.decorator import FunctionToolMetadata +from strands.types._events import ToolResultEvent +from strands.types.tools import AgentTool, ToolGenerator, ToolResult, ToolSpec, ToolUse + +from temporalio import activity, workflow + + +class _TemporalActivityTool(AgentTool): + """Strands ``AgentTool`` whose body dispatches a Temporal activity.""" + + def __init__(self, activity_fn: Callable, options: dict[str, Any]) -> None: + super().__init__() + defn = activity._Definition.from_callable(activity_fn) + if not defn or not defn.name: + raise ValueError("activity_fn must be decorated with @activity.defn") + self._activity_name = defn.name + self._options = options + self._signature = inspect.signature(activity_fn) + spec = FunctionToolMetadata(activity_fn).extract_metadata() + spec["name"] = self._activity_name + self._spec: ToolSpec = spec + + @property + def tool_name(self) -> str: + return self._activity_name + + @property + def tool_spec(self) -> ToolSpec: + return self._spec + + @property + def tool_type(self) -> str: + return "temporal_activity" + + async def stream( + self, + tool_use: ToolUse, + invocation_state: dict[str, Any], + **kwargs: Any, + ) -> ToolGenerator: + bound = self._signature.bind(**tool_use["input"]) + bound.apply_defaults() + positional = list(bound.arguments.values()) + if not positional: + result = await workflow.execute_activity( + self._activity_name, **self._options + ) + elif len(positional) == 1: + result = await workflow.execute_activity( + self._activity_name, positional[0], **self._options + ) + else: + result = await workflow.execute_activity( + self._activity_name, args=positional, **self._options + ) + yield ToolResultEvent( + ToolResult( + toolUseId=tool_use["toolUseId"], + status="success", + content=[{"text": _to_text(result)}], + ) + ) + + +def _to_text(result: Any) -> str: + if isinstance(result, str): + return result + try: + return json.dumps(result) + except (TypeError, ValueError): + return str(result) diff --git a/tests/contrib/strands/test_tool.py b/tests/contrib/strands/test_tool.py index 28b37f215..922c75601 100644 --- a/tests/contrib/strands/test_tool.py +++ b/tests/contrib/strands/test_tool.py @@ -2,12 +2,13 @@ from uuid import uuid4 from strands import Agent, tool -from strands_tools import calculator, current_time +from strands_tools import calculator +from strands_tools.current_time import current_time -from temporalio import workflow +from temporalio import activity, workflow from temporalio.api.enums.v1 import EventType from temporalio.client import Client, WorkflowHistory -from temporalio.contrib.strands import StrandsPlugin, as_activity +from temporalio.contrib.strands import StrandsPlugin, activity_as_tool from temporalio.worker import Replayer, Worker from tests.contrib.strands.mock_model import MockModel @@ -17,6 +18,11 @@ def letter_counter(word: str, letter: str) -> int: return word.lower().count(letter.lower()) +@activity.defn(name="current_time") +async def current_time_activity() -> str: + return current_time() + + @workflow.defn class ToolWorkflow: def __init__(self) -> None: @@ -38,8 +44,8 @@ def __init__(self) -> None: model=model, tools=[ calculator, - as_activity( - current_time, + activity_as_tool( + current_time_activity, start_to_close_timeout=timedelta(seconds=15), ), letter_counter, @@ -54,12 +60,13 @@ async def run(self, prompt: str) -> str: async def test_tool(client: Client): task_queue = "test_tool" - plugin = StrandsPlugin(activity_tools=[current_time]) + plugin = StrandsPlugin() async with Worker( client, task_queue=task_queue, workflows=[ToolWorkflow], + activities=[current_time_activity], plugins=[plugin], ): handle = await client.start_workflow( From 1c7ad3ae6fa31f893fe8284f05832ac174a32fa2 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Tue, 12 May 2026 14:43:32 -0700 Subject: [PATCH 04/46] contrib/strands: import strands_tools at module level in wrapper --- temporalio/contrib/strands/README.md | 7 ++++--- tests/contrib/strands/test_tool.py | 5 ++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/temporalio/contrib/strands/README.md b/temporalio/contrib/strands/README.md index 315e79c3d..8db0f718a 100644 --- a/temporalio/contrib/strands/README.md +++ b/temporalio/contrib/strands/README.md @@ -5,11 +5,12 @@ # Migration Guide * Use `agent.invoke_async(message)` instead of `agent(message)` which spawns a thread -* Decorate non-deterministic tools with `@activity.defn` and register them via `Worker(activities=[...])`. Wrap them in the agent with `activity_as_tool()`. For tools from `strands_tools` (or any `@tool`-decorated input), write a thin async wrapper that calls the tool, e.g.: +* Decorate non-deterministic tools with `@activity.defn` and register them via `Worker(activities=[...])`. Wrap them in the agent with `activity_as_tool()`. +* For tools imported from `strands_tools`, write a thin async wrapper that calls the tool, e.g.: ```python - from strands_tools.current_time import current_time + from strands_tools import current_time @activity.defn(name="current_time") async def current_time_activity() -> str: - return current_time() + return current_time.current_time() ``` diff --git a/tests/contrib/strands/test_tool.py b/tests/contrib/strands/test_tool.py index 922c75601..e7bd3233b 100644 --- a/tests/contrib/strands/test_tool.py +++ b/tests/contrib/strands/test_tool.py @@ -2,8 +2,7 @@ from uuid import uuid4 from strands import Agent, tool -from strands_tools import calculator -from strands_tools.current_time import current_time +from strands_tools import calculator, current_time from temporalio import activity, workflow from temporalio.api.enums.v1 import EventType @@ -20,7 +19,7 @@ def letter_counter(word: str, letter: str) -> int: @activity.defn(name="current_time") async def current_time_activity() -> str: - return current_time() + return current_time.current_time() @workflow.defn From e8e7810a5a9be52d8424a06b673d2279a8e6300f Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Tue, 12 May 2026 16:16:04 -0700 Subject: [PATCH 05/46] contrib/strands: run models as activities via TemporalModel --- temporalio/contrib/strands/README.md | 14 ++ temporalio/contrib/strands/__init__.py | 3 +- temporalio/contrib/strands/_model.py | 169 ++++++++++++++++++ temporalio/contrib/strands/_plugin.py | 14 +- tests/contrib/strands/common.py | 10 ++ tests/contrib/strands/test_model.py | 51 ++++++ tests/contrib/strands/test_model_streaming.py | 81 +++++++++ tests/contrib/strands/test_tool.py | 12 +- 8 files changed, 342 insertions(+), 12 deletions(-) create mode 100644 temporalio/contrib/strands/_model.py create mode 100644 tests/contrib/strands/common.py create mode 100644 tests/contrib/strands/test_model.py create mode 100644 tests/contrib/strands/test_model_streaming.py diff --git a/temporalio/contrib/strands/README.md b/temporalio/contrib/strands/README.md index 8db0f718a..c7e861cd2 100644 --- a/temporalio/contrib/strands/README.md +++ b/temporalio/contrib/strands/README.md @@ -14,3 +14,17 @@ async def current_time_activity() -> str: return current_time.current_time() ``` +* Wrap the model with `TemporalModel` so the LLM call runs as a durable activity. Pass the real model to `StrandsPlugin` on the worker: + ```python + # workflow + from temporalio.contrib.strands import TemporalModel + + agent = Agent(model=TemporalModel(start_to_close_timeout=timedelta(seconds=60))) + + # worker + from strands.models.bedrock import BedrockModel + from temporalio.contrib.strands import StrandsPlugin + + Worker(..., plugins=[StrandsPlugin(model=BedrockModel(model_id="claude-3-5-sonnet"))]) + ``` + To stream chunks to external consumers, pass `streaming_topic="..."` to `TemporalModel` and host a `WorkflowStream` on the workflow. diff --git a/temporalio/contrib/strands/__init__.py b/temporalio/contrib/strands/__init__.py index e52adb9c3..f03045a88 100644 --- a/temporalio/contrib/strands/__init__.py +++ b/temporalio/contrib/strands/__init__.py @@ -1,5 +1,6 @@ """Temporal integration for the Strands Agents SDK.""" +from ._model import TemporalModel from ._plugin import StrandsPlugin, activity_as_tool -__all__ = ["StrandsPlugin", "activity_as_tool"] +__all__ = ["StrandsPlugin", "TemporalModel", "activity_as_tool"] diff --git a/temporalio/contrib/strands/_model.py b/temporalio/contrib/strands/_model.py new file mode 100644 index 000000000..aca057b78 --- /dev/null +++ b/temporalio/contrib/strands/_model.py @@ -0,0 +1,169 @@ +from __future__ import annotations + +from collections.abc import AsyncIterable +from dataclasses import dataclass +from datetime import timedelta +from typing import Any + +from strands.models import Model +from strands.types.content import Messages, SystemContentBlock +from strands.types.streaming import StreamEvent +from strands.types.tools import ToolChoice, ToolSpec + +from temporalio import activity, workflow +from temporalio.common import Priority, RetryPolicy +from temporalio.contrib.workflow_streams import WorkflowStreamClient +from temporalio.workflow import ActivityCancellationType, VersioningIntent + + +@dataclass +class _InvokeModelInput: + messages: Messages + tool_specs: list[ToolSpec] | None = None + system_prompt: str | None = None + tool_choice: ToolChoice | None = None + system_prompt_content: list[SystemContentBlock] | None = None + + +@dataclass +class _StreamingInvokeModelInput(_InvokeModelInput): + streaming_topic: str = "" + streaming_batch_interval_seconds: float = 0.1 + + +class _ModelActivity: + """Holds the user-supplied model and exposes the model activities.""" + + def __init__(self, model: Model) -> None: + self._model = model + + @activity.defn(name="invoke_strands_model") + async def invoke_model(self, input: _InvokeModelInput) -> list[StreamEvent]: + return [event async for event in _stream(self._model, input)] + + @activity.defn(name="invoke_strands_model_streaming") + async def invoke_model_streaming( + self, input: _StreamingInvokeModelInput + ) -> list[StreamEvent]: + events: list[StreamEvent] = [] + stream = WorkflowStreamClient.from_within_activity( + batch_interval=timedelta(seconds=input.streaming_batch_interval_seconds), + ) + topic = stream.topic(input.streaming_topic) + async with stream: + async for event in _stream(self._model, input): + activity.heartbeat() + events.append(event) + topic.publish(event) + return events + + +def _stream(model: Model, input: _InvokeModelInput) -> AsyncIterable[StreamEvent]: + return model.stream( + input.messages, + input.tool_specs, + input.system_prompt, + tool_choice=input.tool_choice, + system_prompt_content=input.system_prompt_content, + ) + + +class TemporalModel(Model): + """Strands :class:`Model` whose ``stream()`` runs as a Temporal activity. + + Construct inside a workflow and pass to ``Agent(model=...)``. The concrete + model is supplied worker-side via the ``model`` argument to + :class:`StrandsPlugin`. + + When ``streaming_topic`` is set, each ``StreamEvent`` is also published to + the named topic on the workflow's + :class:`temporalio.contrib.workflow_streams.WorkflowStream` for external + consumers (UIs, tracing). The workflow must host a ``WorkflowStream`` to + receive the publishes; otherwise the signals are unhandled and dropped. + """ + + def __init__( + self, + *, + task_queue: str | None = None, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + retry_policy: RetryPolicy | None = None, + cancellation_type: ActivityCancellationType = ActivityCancellationType.TRY_CANCEL, + activity_id: str | None = None, + versioning_intent: VersioningIntent | None = None, + summary: str | None = None, + priority: Priority = Priority.default, + streaming_topic: str | None = None, + streaming_batch_interval: timedelta = timedelta(milliseconds=100), + ) -> None: + self._streaming_topic = streaming_topic + self._streaming_batch_interval = streaming_batch_interval + self._options: dict[str, Any] = { + "task_queue": task_queue, + "schedule_to_close_timeout": schedule_to_close_timeout, + "schedule_to_start_timeout": schedule_to_start_timeout, + "start_to_close_timeout": start_to_close_timeout, + "heartbeat_timeout": heartbeat_timeout, + "retry_policy": retry_policy, + "cancellation_type": cancellation_type, + "activity_id": activity_id, + "versioning_intent": versioning_intent, + "summary": summary, + "priority": priority, + } + + def update_config(self, **_model_config: Any) -> None: + return None + + def get_config(self) -> dict[str, Any]: + return {} + + def structured_output(self, *_args: Any, **_kwargs: Any) -> Any: + raise NotImplementedError( + "TemporalModel.structured_output is not supported. Use " + "Agent(structured_output_model=...) which routes structured output " + "through stream() via the structured_output_tool." + ) + + async def stream( + self, + messages: Messages, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, + *, + tool_choice: ToolChoice | None = None, + system_prompt_content: list[SystemContentBlock] | None = None, + invocation_state: dict[str, Any] | None = None, + **kwargs: Any, + ) -> AsyncIterable[StreamEvent]: + if self._streaming_topic is not None: + events = await workflow.execute_activity_method( + _ModelActivity.invoke_model_streaming, + _StreamingInvokeModelInput( + messages=messages, + tool_specs=tool_specs, + system_prompt=system_prompt, + tool_choice=tool_choice, + system_prompt_content=system_prompt_content, + streaming_topic=self._streaming_topic, + streaming_batch_interval_seconds=self._streaming_batch_interval.total_seconds(), + ), + **self._options, + ) + else: + events = await workflow.execute_activity_method( + _ModelActivity.invoke_model, + _InvokeModelInput( + messages=messages, + tool_specs=tool_specs, + system_prompt=system_prompt, + tool_choice=tool_choice, + system_prompt_content=system_prompt_content, + ), + **self._options, + ) + for event in events: + yield event diff --git a/temporalio/contrib/strands/_plugin.py b/temporalio/contrib/strands/_plugin.py index 916d7611e..ef548ec87 100644 --- a/temporalio/contrib/strands/_plugin.py +++ b/temporalio/contrib/strands/_plugin.py @@ -3,6 +3,7 @@ from datetime import timedelta from typing import Any +from strands.models import Model from strands.types.tools import AgentTool from temporalio.common import Priority, RetryPolicy @@ -13,6 +14,7 @@ from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner from temporalio.workflow import ActivityCancellationType, VersioningIntent +from ._model import _ModelActivity from ._temporal_activity_tool import _TemporalActivityTool @@ -21,13 +23,23 @@ class StrandsPlugin(SimplePlugin): Configures sandbox passthrough for ``strands`` and ``strands_tools`` and swaps in ``pydantic_data_converter`` so structured outputs serialize. + + When ``model`` is supplied, registers the activities that back + :class:`temporalio.contrib.strands.TemporalModel` so the model's + ``stream()`` call runs durably as a Temporal activity. The ``model`` + instance lives on the worker and is reused across activity invocations. """ - def __init__(self) -> None: + def __init__(self, model: Model | None = None) -> None: + activities: list[Callable] | None = None + if model is not None: + ma = _ModelActivity(model) + activities = [ma.invoke_model, ma.invoke_model_streaming] super().__init__( "aws.StrandsPlugin", workflow_runner=_workflow_runner, data_converter=_data_converter, + activities=activities, ) diff --git a/tests/contrib/strands/common.py b/tests/contrib/strands/common.py new file mode 100644 index 000000000..5ece12a1e --- /dev/null +++ b/tests/contrib/strands/common.py @@ -0,0 +1,10 @@ +from temporalio.api.enums.v1 import EventType +from temporalio.client import WorkflowHistory + + +def get_activities(history: WorkflowHistory) -> list[str]: + return [ + event.activity_task_scheduled_event_attributes.activity_type.name + for event in history.events + if event.event_type == EventType.EVENT_TYPE_ACTIVITY_TASK_SCHEDULED + ] diff --git a/tests/contrib/strands/test_model.py b/tests/contrib/strands/test_model.py new file mode 100644 index 000000000..6f57a5627 --- /dev/null +++ b/tests/contrib/strands/test_model.py @@ -0,0 +1,51 @@ +from datetime import timedelta +from uuid import uuid4 + +from strands import Agent + +from temporalio import workflow +from temporalio.client import Client +from temporalio.contrib.strands import StrandsPlugin, TemporalModel +from temporalio.worker import Replayer, Worker +from tests.contrib.strands.common import get_activities +from tests.contrib.strands.mock_model import MockModel + + +@workflow.defn +class ModelWorkflow: + def __init__(self) -> None: + self.agent = Agent( + model=TemporalModel(start_to_close_timeout=timedelta(seconds=15)), + ) + + @workflow.run + async def run(self, prompt: str) -> str: + result = await self.agent.invoke_async(prompt) + return str(result) + + +async def test_model(client: Client): + task_queue = "test_model" + plugin = StrandsPlugin(model=MockModel(["Done!"])) + + async with Worker( + client, + task_queue=task_queue, + workflows=[ModelWorkflow], + plugins=[plugin], + ): + handle = await client.start_workflow( + ModelWorkflow.run, + "Hello", + id=f"test_model_{uuid4()}", + task_queue=task_queue, + ) + assert await handle.result() == "Done!\n" + + history = await handle.fetch_history() + assert get_activities(history) == ["invoke_strands_model"] + + await Replayer( + workflows=[ModelWorkflow], + plugins=[plugin], + ).replay_workflow(history) diff --git a/tests/contrib/strands/test_model_streaming.py b/tests/contrib/strands/test_model_streaming.py new file mode 100644 index 000000000..f93f944a7 --- /dev/null +++ b/tests/contrib/strands/test_model_streaming.py @@ -0,0 +1,81 @@ +import asyncio +from datetime import timedelta +from uuid import uuid4 + +from strands import Agent +from strands.types.streaming import StreamEvent + +from temporalio import workflow +from temporalio.client import Client +from temporalio.contrib.strands import StrandsPlugin, TemporalModel +from temporalio.contrib.workflow_streams import WorkflowStream, WorkflowStreamClient +from temporalio.worker import Replayer, Worker +from tests.contrib.strands.common import get_activities +from tests.contrib.strands.mock_model import MockModel + + +@workflow.defn +class StreamingModelWorkflow: + def __init__(self) -> None: + self.stream = WorkflowStream() + self.agent = Agent( + model=TemporalModel( + start_to_close_timeout=timedelta(seconds=15), + streaming_topic="events", + ), + ) + + @workflow.run + async def run(self, prompt: str) -> str: + result = await self.agent.invoke_async(prompt) + return str(result) + + +async def test_model_streaming(client: Client): + task_queue = "test_model_streaming" + plugin = StrandsPlugin(model=MockModel(["Done!"])) + workflow_id = f"test_model_streaming_{uuid4()}" + + async with Worker( + client, + task_queue=task_queue, + workflows=[StreamingModelWorkflow], + plugins=[plugin], + ): + handle = await client.start_workflow( + StreamingModelWorkflow.run, + "Hello", + id=workflow_id, + task_queue=task_queue, + ) + + stream = WorkflowStreamClient.create(client, workflow_id) + events: list[StreamEvent] = [] + + async def collect() -> None: + async for item in stream.subscribe( + ["events"], + from_offset=0, + result_type=StreamEvent, + poll_cooldown=timedelta(milliseconds=50), + ): + events.append(item.data) + if len(events) >= 4: + break + + collect_task = asyncio.create_task(collect()) + assert await handle.result() == "Done!\n" + await asyncio.wait_for(collect_task, timeout=10.0) + + history = await handle.fetch_history() + assert get_activities(history) == ["invoke_strands_model_streaming"] + + # The "Done!" response from MockModel produces four StreamEvents: + # messageStart, contentBlockDelta(text), contentBlockStop, messageStop. + assert any("messageStart" in e for e in events) + assert any("messageStop" in e for e in events) + + await Replayer( + workflows=[StreamingModelWorkflow], + plugins=[plugin], + ).replay_workflow(history) diff --git a/tests/contrib/strands/test_tool.py b/tests/contrib/strands/test_tool.py index e7bd3233b..716b65b20 100644 --- a/tests/contrib/strands/test_tool.py +++ b/tests/contrib/strands/test_tool.py @@ -5,10 +5,10 @@ from strands_tools import calculator, current_time from temporalio import activity, workflow -from temporalio.api.enums.v1 import EventType -from temporalio.client import Client, WorkflowHistory +from temporalio.client import Client from temporalio.contrib.strands import StrandsPlugin, activity_as_tool from temporalio.worker import Replayer, Worker +from tests.contrib.strands.common import get_activities from tests.contrib.strands.mock_model import MockModel @@ -86,11 +86,3 @@ async def test_tool(client: Client): workflows=[ToolWorkflow], plugins=[plugin], ).replay_workflow(history) - - -def get_activities(history: WorkflowHistory) -> list[str]: - return [ - event.activity_task_scheduled_event_attributes.activity_type.name - for event in history.events - if event.event_type == EventType.EVENT_TYPE_ACTIVITY_TASK_SCHEDULED - ] From b828b316a5f5ea8b3562c6b8665c68b7b4f137f3 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Wed, 13 May 2026 09:12:48 -0700 Subject: [PATCH 06/46] contrib/strands: patch Agent.__init__ to route models through activities --- temporalio/contrib/strands/README.md | 13 ++- temporalio/contrib/strands/__init__.py | 3 +- temporalio/contrib/strands/_model.py | 106 +---------------- temporalio/contrib/strands/_patch.py | 109 ++++++++++++++++++ temporalio/contrib/strands/_plugin.py | 57 ++++++++- tests/contrib/strands/test_model.py | 11 +- tests/contrib/strands/test_model_streaming.py | 17 ++- tests/contrib/strands/test_tool.py | 43 ++++--- 8 files changed, 208 insertions(+), 151 deletions(-) create mode 100644 temporalio/contrib/strands/_patch.py diff --git a/temporalio/contrib/strands/README.md b/temporalio/contrib/strands/README.md index c7e861cd2..e14268ff4 100644 --- a/temporalio/contrib/strands/README.md +++ b/temporalio/contrib/strands/README.md @@ -14,17 +14,18 @@ async def current_time_activity() -> str: return current_time.current_time() ``` -* Wrap the model with `TemporalModel` so the LLM call runs as a durable activity. Pass the real model to `StrandsPlugin` on the worker: +* Pass the real model to `StrandsPlugin` on the worker; the plugin patches `Agent.__init__` so any `Agent(...)` constructed inside a workflow has its model replaced with a stub that routes `stream()` through the registered activity. Workflow code stays vanilla Strands: ```python # workflow - from temporalio.contrib.strands import TemporalModel - - agent = Agent(model=TemporalModel(start_to_close_timeout=timedelta(seconds=60))) + agent = Agent() # model arg is ignored in workflow context # worker from strands.models.bedrock import BedrockModel from temporalio.contrib.strands import StrandsPlugin - Worker(..., plugins=[StrandsPlugin(model=BedrockModel(model_id="claude-3-5-sonnet"))]) + Worker(..., plugins=[StrandsPlugin( + model=BedrockModel(model_id="claude-3-5-sonnet"), + start_to_close_timeout=timedelta(seconds=60), + )]) ``` - To stream chunks to external consumers, pass `streaming_topic="..."` to `TemporalModel` and host a `WorkflowStream` on the workflow. + To stream chunks to external consumers, pass `streaming_topic="..."` to `StrandsPlugin` and host a `WorkflowStream` on the workflow. diff --git a/temporalio/contrib/strands/__init__.py b/temporalio/contrib/strands/__init__.py index f03045a88..e52adb9c3 100644 --- a/temporalio/contrib/strands/__init__.py +++ b/temporalio/contrib/strands/__init__.py @@ -1,6 +1,5 @@ """Temporal integration for the Strands Agents SDK.""" -from ._model import TemporalModel from ._plugin import StrandsPlugin, activity_as_tool -__all__ = ["StrandsPlugin", "TemporalModel", "activity_as_tool"] +__all__ = ["StrandsPlugin", "activity_as_tool"] diff --git a/temporalio/contrib/strands/_model.py b/temporalio/contrib/strands/_model.py index aca057b78..e45372d42 100644 --- a/temporalio/contrib/strands/_model.py +++ b/temporalio/contrib/strands/_model.py @@ -3,17 +3,14 @@ from collections.abc import AsyncIterable from dataclasses import dataclass from datetime import timedelta -from typing import Any from strands.models import Model from strands.types.content import Messages, SystemContentBlock from strands.types.streaming import StreamEvent from strands.types.tools import ToolChoice, ToolSpec -from temporalio import activity, workflow -from temporalio.common import Priority, RetryPolicy +from temporalio import activity from temporalio.contrib.workflow_streams import WorkflowStreamClient -from temporalio.workflow import ActivityCancellationType, VersioningIntent @dataclass @@ -66,104 +63,3 @@ def _stream(model: Model, input: _InvokeModelInput) -> AsyncIterable[StreamEvent tool_choice=input.tool_choice, system_prompt_content=input.system_prompt_content, ) - - -class TemporalModel(Model): - """Strands :class:`Model` whose ``stream()`` runs as a Temporal activity. - - Construct inside a workflow and pass to ``Agent(model=...)``. The concrete - model is supplied worker-side via the ``model`` argument to - :class:`StrandsPlugin`. - - When ``streaming_topic`` is set, each ``StreamEvent`` is also published to - the named topic on the workflow's - :class:`temporalio.contrib.workflow_streams.WorkflowStream` for external - consumers (UIs, tracing). The workflow must host a ``WorkflowStream`` to - receive the publishes; otherwise the signals are unhandled and dropped. - """ - - def __init__( - self, - *, - task_queue: str | None = None, - schedule_to_close_timeout: timedelta | None = None, - schedule_to_start_timeout: timedelta | None = None, - start_to_close_timeout: timedelta | None = None, - heartbeat_timeout: timedelta | None = None, - retry_policy: RetryPolicy | None = None, - cancellation_type: ActivityCancellationType = ActivityCancellationType.TRY_CANCEL, - activity_id: str | None = None, - versioning_intent: VersioningIntent | None = None, - summary: str | None = None, - priority: Priority = Priority.default, - streaming_topic: str | None = None, - streaming_batch_interval: timedelta = timedelta(milliseconds=100), - ) -> None: - self._streaming_topic = streaming_topic - self._streaming_batch_interval = streaming_batch_interval - self._options: dict[str, Any] = { - "task_queue": task_queue, - "schedule_to_close_timeout": schedule_to_close_timeout, - "schedule_to_start_timeout": schedule_to_start_timeout, - "start_to_close_timeout": start_to_close_timeout, - "heartbeat_timeout": heartbeat_timeout, - "retry_policy": retry_policy, - "cancellation_type": cancellation_type, - "activity_id": activity_id, - "versioning_intent": versioning_intent, - "summary": summary, - "priority": priority, - } - - def update_config(self, **_model_config: Any) -> None: - return None - - def get_config(self) -> dict[str, Any]: - return {} - - def structured_output(self, *_args: Any, **_kwargs: Any) -> Any: - raise NotImplementedError( - "TemporalModel.structured_output is not supported. Use " - "Agent(structured_output_model=...) which routes structured output " - "through stream() via the structured_output_tool." - ) - - async def stream( - self, - messages: Messages, - tool_specs: list[ToolSpec] | None = None, - system_prompt: str | None = None, - *, - tool_choice: ToolChoice | None = None, - system_prompt_content: list[SystemContentBlock] | None = None, - invocation_state: dict[str, Any] | None = None, - **kwargs: Any, - ) -> AsyncIterable[StreamEvent]: - if self._streaming_topic is not None: - events = await workflow.execute_activity_method( - _ModelActivity.invoke_model_streaming, - _StreamingInvokeModelInput( - messages=messages, - tool_specs=tool_specs, - system_prompt=system_prompt, - tool_choice=tool_choice, - system_prompt_content=system_prompt_content, - streaming_topic=self._streaming_topic, - streaming_batch_interval_seconds=self._streaming_batch_interval.total_seconds(), - ), - **self._options, - ) - else: - events = await workflow.execute_activity_method( - _ModelActivity.invoke_model, - _InvokeModelInput( - messages=messages, - tool_specs=tool_specs, - system_prompt=system_prompt, - tool_choice=tool_choice, - system_prompt_content=system_prompt_content, - ), - **self._options, - ) - for event in events: - yield event diff --git a/temporalio/contrib/strands/_patch.py b/temporalio/contrib/strands/_patch.py new file mode 100644 index 000000000..f5931aca3 --- /dev/null +++ b/temporalio/contrib/strands/_patch.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +from collections.abc import AsyncIterable +from datetime import timedelta +from typing import Any + +from strands.agent.agent import Agent +from strands.models import Model +from strands.types.content import Messages, SystemContentBlock +from strands.types.streaming import StreamEvent +from strands.types.tools import ToolChoice, ToolSpec + +from temporalio import workflow + +from ._model import ( + _InvokeModelInput, + _ModelActivity, + _StreamingInvokeModelInput, +) + +_original_agent_init = Agent.__init__ +_options: dict[str, Any] = {} +_streaming_topic: str | None = None +_streaming_batch_interval: timedelta = timedelta(milliseconds=100) + + +class _ActivityDispatchModel(Model): + """Stub installed by the patch; routes ``stream()`` through an activity.""" + + def update_config(self, **_model_config: Any) -> None: + return None + + def get_config(self) -> dict[str, Any]: + return {} + + def structured_output(self, *_args: Any, **_kwargs: Any) -> Any: + raise NotImplementedError( + "Strands Agent.structured_output_async is not supported in workflow " + "context. Use Agent(structured_output_model=...) which routes through " + "stream() via the structured_output_tool." + ) + + async def stream( + self, + messages: Messages, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, + *, + tool_choice: ToolChoice | None = None, + system_prompt_content: list[SystemContentBlock] | None = None, + invocation_state: dict[str, Any] | None = None, + **kwargs: Any, + ) -> AsyncIterable[StreamEvent]: + if _streaming_topic is not None: + events = await workflow.execute_activity_method( + _ModelActivity.invoke_model_streaming, + _StreamingInvokeModelInput( + messages=messages, + tool_specs=tool_specs, + system_prompt=system_prompt, + tool_choice=tool_choice, + system_prompt_content=system_prompt_content, + streaming_topic=_streaming_topic, + streaming_batch_interval_seconds=_streaming_batch_interval.total_seconds(), + ), + **_options, + ) + else: + events = await workflow.execute_activity_method( + _ModelActivity.invoke_model, + _InvokeModelInput( + messages=messages, + tool_specs=tool_specs, + system_prompt=system_prompt, + tool_choice=tool_choice, + system_prompt_content=system_prompt_content, + ), + **_options, + ) + for event in events: + yield event + + +def _patched_agent_init(self: Agent, model: Any = None, *args: Any, **kwargs: Any) -> None: + if workflow.in_workflow(): + if model is not None: + raise ValueError( + "Agent(model=...) must not be set inside a workflow. " + "Pass the model to StrandsPlugin(model=...) instead so it " + "runs as a Temporal activity." + ) + model = _ActivityDispatchModel() + _original_agent_init(self, model, *args, **kwargs) + + +def install_patch( + options: dict[str, Any], + streaming_topic: str | None, + streaming_batch_interval: timedelta, +) -> None: + global _options, _streaming_topic, _streaming_batch_interval + _options = options + _streaming_topic = streaming_topic + _streaming_batch_interval = streaming_batch_interval + setattr(Agent, "__init__", _patched_agent_init) + + +def uninstall_patch() -> None: + setattr(Agent, "__init__", _original_agent_init) diff --git a/temporalio/contrib/strands/_plugin.py b/temporalio/contrib/strands/_plugin.py index ef548ec87..f97cce057 100644 --- a/temporalio/contrib/strands/_plugin.py +++ b/temporalio/contrib/strands/_plugin.py @@ -1,4 +1,5 @@ -from collections.abc import Callable +from collections.abc import AsyncIterator, Callable +from contextlib import asynccontextmanager from dataclasses import replace from datetime import timedelta from typing import Any @@ -15,6 +16,7 @@ from temporalio.workflow import ActivityCancellationType, VersioningIntent from ._model import _ModelActivity +from ._patch import install_patch, uninstall_patch from ._temporal_activity_tool import _TemporalActivityTool @@ -24,22 +26,65 @@ class StrandsPlugin(SimplePlugin): Configures sandbox passthrough for ``strands`` and ``strands_tools`` and swaps in ``pydantic_data_converter`` so structured outputs serialize. - When ``model`` is supplied, registers the activities that back - :class:`temporalio.contrib.strands.TemporalModel` so the model's - ``stream()`` call runs durably as a Temporal activity. The ``model`` - instance lives on the worker and is reused across activity invocations. + When ``model`` is supplied, registers the model activities and patches + ``strands.agent.agent.Agent.__init__`` so any ``Agent(...)`` constructed + inside a workflow gets its ``model`` replaced with a stub that routes + ``stream()`` through the registered activity. The ``model`` instance lives + on the worker and is reused across activity invocations. Activity options + (``start_to_close_timeout``, ``retry_policy``, etc.) flow from the plugin + to the dispatched activity. """ - def __init__(self, model: Model | None = None) -> None: + def __init__( + self, + *, + model: Model | None = None, + task_queue: str | None = None, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + retry_policy: RetryPolicy | None = None, + cancellation_type: ActivityCancellationType = ActivityCancellationType.TRY_CANCEL, + versioning_intent: VersioningIntent | None = None, + summary: str | None = None, + priority: Priority = Priority.default, + streaming_topic: str | None = None, + streaming_batch_interval: timedelta = timedelta(milliseconds=100), + ) -> None: activities: list[Callable] | None = None if model is not None: ma = _ModelActivity(model) activities = [ma.invoke_model, ma.invoke_model_streaming] + options: dict[str, Any] = { + "task_queue": task_queue, + "schedule_to_close_timeout": schedule_to_close_timeout, + "schedule_to_start_timeout": schedule_to_start_timeout, + "start_to_close_timeout": start_to_close_timeout, + "heartbeat_timeout": heartbeat_timeout, + "retry_policy": retry_policy, + "cancellation_type": cancellation_type, + "versioning_intent": versioning_intent, + "summary": summary, + "priority": priority, + } + + @asynccontextmanager + async def run_context() -> AsyncIterator[None]: + if model is not None: + install_patch(options, streaming_topic, streaming_batch_interval) + try: + yield + finally: + if model is not None: + uninstall_patch() + super().__init__( "aws.StrandsPlugin", workflow_runner=_workflow_runner, data_converter=_data_converter, activities=activities, + run_context=run_context, ) diff --git a/tests/contrib/strands/test_model.py b/tests/contrib/strands/test_model.py index 6f57a5627..f4c50c6b7 100644 --- a/tests/contrib/strands/test_model.py +++ b/tests/contrib/strands/test_model.py @@ -5,7 +5,7 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.strands import StrandsPlugin, TemporalModel +from temporalio.contrib.strands import StrandsPlugin from temporalio.worker import Replayer, Worker from tests.contrib.strands.common import get_activities from tests.contrib.strands.mock_model import MockModel @@ -14,9 +14,7 @@ @workflow.defn class ModelWorkflow: def __init__(self) -> None: - self.agent = Agent( - model=TemporalModel(start_to_close_timeout=timedelta(seconds=15)), - ) + self.agent = Agent() @workflow.run async def run(self, prompt: str) -> str: @@ -26,7 +24,10 @@ async def run(self, prompt: str) -> str: async def test_model(client: Client): task_queue = "test_model" - plugin = StrandsPlugin(model=MockModel(["Done!"])) + plugin = StrandsPlugin( + model=MockModel(["Done!"]), + start_to_close_timeout=timedelta(seconds=15), + ) async with Worker( client, diff --git a/tests/contrib/strands/test_model_streaming.py b/tests/contrib/strands/test_model_streaming.py index f93f944a7..ab60df735 100644 --- a/tests/contrib/strands/test_model_streaming.py +++ b/tests/contrib/strands/test_model_streaming.py @@ -7,7 +7,7 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.strands import StrandsPlugin, TemporalModel +from temporalio.contrib.strands import StrandsPlugin from temporalio.contrib.workflow_streams import WorkflowStream, WorkflowStreamClient from temporalio.worker import Replayer, Worker from tests.contrib.strands.common import get_activities @@ -18,12 +18,7 @@ class StreamingModelWorkflow: def __init__(self) -> None: self.stream = WorkflowStream() - self.agent = Agent( - model=TemporalModel( - start_to_close_timeout=timedelta(seconds=15), - streaming_topic="events", - ), - ) + self.agent = Agent() @workflow.run async def run(self, prompt: str) -> str: @@ -33,7 +28,11 @@ async def run(self, prompt: str) -> str: async def test_model_streaming(client: Client): task_queue = "test_model_streaming" - plugin = StrandsPlugin(model=MockModel(["Done!"])) + plugin = StrandsPlugin( + model=MockModel(["Done!"]), + start_to_close_timeout=timedelta(seconds=15), + streaming_topic="events", + ) workflow_id = f"test_model_streaming_{uuid4()}" async with Worker( @@ -70,8 +69,6 @@ async def collect() -> None: history = await handle.fetch_history() assert get_activities(history) == ["invoke_strands_model_streaming"] - # The "Done!" response from MockModel produces four StreamEvents: - # messageStart, contentBlockDelta(text), contentBlockStop, messageStop. assert any("messageStart" in e for e in events) assert any("messageStop" in e for e in events) diff --git a/tests/contrib/strands/test_tool.py b/tests/contrib/strands/test_tool.py index 716b65b20..44f01a3cc 100644 --- a/tests/contrib/strands/test_tool.py +++ b/tests/contrib/strands/test_tool.py @@ -25,22 +25,7 @@ async def current_time_activity() -> str: @workflow.defn class ToolWorkflow: def __init__(self) -> None: - model = MockModel( - [ - {"name": "current_time", "input": {}}, - { - "name": "calculator", - "input": {"expression": "3111696 / 74088"}, - }, - { - "name": "letter_counter", - "input": {"word": "strawberry", "letter": "R"}, - }, - "Done!", - ] - ) self.agent = Agent( - model=model, tools=[ calculator, activity_as_tool( @@ -59,7 +44,23 @@ async def run(self, prompt: str) -> str: async def test_tool(client: Client): task_queue = "test_tool" - plugin = StrandsPlugin() + plugin = StrandsPlugin( + model=MockModel( + [ + {"name": "current_time", "input": {}}, + { + "name": "calculator", + "input": {"expression": "3111696 / 74088"}, + }, + { + "name": "letter_counter", + "input": {"word": "strawberry", "letter": "R"}, + }, + "Done!", + ] + ), + start_to_close_timeout=timedelta(seconds=15), + ) async with Worker( client, @@ -80,7 +81,15 @@ async def test_tool(client: Client): assert await handle.result() == "Done!\n" history = await handle.fetch_history() - assert get_activities(history) == ["current_time"] + assert get_activities(history) == [ + "invoke_strands_model", + "current_time", + "invoke_strands_model", + # calculator (in-workflow) + "invoke_strands_model", + # letter_counter (in-workflow) + "invoke_strands_model", + ] await Replayer( workflows=[ToolWorkflow], From acd767d25e943bf0758faea44ab3973937347b61 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Wed, 13 May 2026 13:02:32 -0700 Subject: [PATCH 07/46] contrib/strands: explicit TemporalModel/TemporalMCPClient, drop monkey-patch, add MCP --- temporalio/contrib/strands/README.md | 46 ++++-- temporalio/contrib/strands/__init__.py | 9 +- .../strands/{_model.py => _model_activity.py} | 4 +- temporalio/contrib/strands/_patch.py | 109 -------------- temporalio/contrib/strands/_plugin.py | 84 +++++------ .../strands/_temporal_activity_tool.py | 2 +- .../contrib/strands/_temporal_mcp_client.py | 138 ++++++++++++++++++ .../contrib/strands/_temporal_mcp_tool.py | 60 ++++++++ temporalio/contrib/strands/_temporal_model.py | 126 ++++++++++++++++ tests/contrib/strands/echo_mcp_server.py | 13 ++ tests/contrib/strands/test_mcp.py | 81 ++++++++++ tests/contrib/strands/test_model.py | 14 +- tests/contrib/strands/test_model_streaming.py | 16 +- tests/contrib/strands/test_tool.py | 34 ++--- 14 files changed, 535 insertions(+), 201 deletions(-) rename temporalio/contrib/strands/{_model.py => _model_activity.py} (97%) delete mode 100644 temporalio/contrib/strands/_patch.py create mode 100644 temporalio/contrib/strands/_temporal_mcp_client.py create mode 100644 temporalio/contrib/strands/_temporal_mcp_tool.py create mode 100644 temporalio/contrib/strands/_temporal_model.py create mode 100644 tests/contrib/strands/echo_mcp_server.py create mode 100644 tests/contrib/strands/test_mcp.py diff --git a/temporalio/contrib/strands/README.md b/temporalio/contrib/strands/README.md index e14268ff4..ff70f71d4 100644 --- a/temporalio/contrib/strands/README.md +++ b/temporalio/contrib/strands/README.md @@ -14,18 +14,46 @@ async def current_time_activity() -> str: return current_time.current_time() ``` -* Pass the real model to `StrandsPlugin` on the worker; the plugin patches `Agent.__init__` so any `Agent(...)` constructed inside a workflow has its model replaced with a stub that routes `stream()` through the registered activity. Workflow code stays vanilla Strands: +* Construct a `TemporalModel` once with a `model_factory` lambda that returns the real model. Pass it both to the workflow's `Agent(model=...)` and to `StrandsPlugin(model=...)`; the plugin calls the factory at worker startup, and the workflow's `stream()` calls dispatch as Temporal activities. ```python + from strands.models.bedrock import BedrockModel + from temporalio.contrib.strands import StrandsPlugin, TemporalModel + + MODEL = TemporalModel( + model_factory=lambda: BedrockModel(model_id="claude-3-5-sonnet"), + start_to_close_timeout=timedelta(seconds=60), + ) + # workflow - agent = Agent() # model arg is ignored in workflow context + @workflow.defn + class MyWorkflow: + def __init__(self) -> None: + self.agent = Agent(model=MODEL) # worker - from strands.models.bedrock import BedrockModel - from temporalio.contrib.strands import StrandsPlugin + Worker(..., plugins=[StrandsPlugin(model=MODEL)]) + ``` + To stream chunks to external consumers, pass `streaming_topic="..."` to `TemporalModel` and host a `WorkflowStream` on the workflow. +* For MCP servers, construct `TemporalMCPClient` once at module level and reference the same instance from both the plugin (registers the per-server `{server}-call-tool` activity and connects at worker startup to discover tools) and the workflow's `Agent(tools=[...])`: + ```python + from mcp import StdioServerParameters, stdio_client + from temporalio.contrib.strands import StrandsPlugin, TemporalMCPClient, TemporalModel - Worker(..., plugins=[StrandsPlugin( - model=BedrockModel(model_id="claude-3-5-sonnet"), - start_to_close_timeout=timedelta(seconds=60), - )]) + ECHO = TemporalMCPClient( + server="echo", + transport_factory=lambda: stdio_client( + StdioServerParameters(command="...", args=[...]), + ), + start_to_close_timeout=timedelta(seconds=30), + ) + + # workflow + @workflow.defn + class MyWorkflow: + def __init__(self) -> None: + self.agent = Agent(model=MODEL, tools=[ECHO]) + + # worker + Worker(..., plugins=[StrandsPlugin(model=MODEL, mcp_clients=[ECHO])]) ``` - To stream chunks to external consumers, pass `streaming_topic="..."` to `StrandsPlugin` and host a `WorkflowStream` on the workflow. + The plugin connects to the MCP server once at worker startup to enumerate tools. The schema is frozen for the worker's lifetime; restart workers to pick up MCP-server changes. If the MCP server is unavailable at startup, the worker fails to start. diff --git a/temporalio/contrib/strands/__init__.py b/temporalio/contrib/strands/__init__.py index e52adb9c3..55a0e0bbb 100644 --- a/temporalio/contrib/strands/__init__.py +++ b/temporalio/contrib/strands/__init__.py @@ -1,5 +1,12 @@ """Temporal integration for the Strands Agents SDK.""" from ._plugin import StrandsPlugin, activity_as_tool +from ._temporal_mcp_client import TemporalMCPClient +from ._temporal_model import TemporalModel -__all__ = ["StrandsPlugin", "activity_as_tool"] +__all__ = [ + "StrandsPlugin", + "TemporalMCPClient", + "TemporalModel", + "activity_as_tool", +] diff --git a/temporalio/contrib/strands/_model.py b/temporalio/contrib/strands/_model_activity.py similarity index 97% rename from temporalio/contrib/strands/_model.py rename to temporalio/contrib/strands/_model_activity.py index e45372d42..9214c4285 100644 --- a/temporalio/contrib/strands/_model.py +++ b/temporalio/contrib/strands/_model_activity.py @@ -1,5 +1,3 @@ -from __future__ import annotations - from collections.abc import AsyncIterable from dataclasses import dataclass from datetime import timedelta @@ -28,7 +26,7 @@ class _StreamingInvokeModelInput(_InvokeModelInput): streaming_batch_interval_seconds: float = 0.1 -class _ModelActivity: +class ModelActivity: """Holds the user-supplied model and exposes the model activities.""" def __init__(self, model: Model) -> None: diff --git a/temporalio/contrib/strands/_patch.py b/temporalio/contrib/strands/_patch.py deleted file mode 100644 index f5931aca3..000000000 --- a/temporalio/contrib/strands/_patch.py +++ /dev/null @@ -1,109 +0,0 @@ -from __future__ import annotations - -from collections.abc import AsyncIterable -from datetime import timedelta -from typing import Any - -from strands.agent.agent import Agent -from strands.models import Model -from strands.types.content import Messages, SystemContentBlock -from strands.types.streaming import StreamEvent -from strands.types.tools import ToolChoice, ToolSpec - -from temporalio import workflow - -from ._model import ( - _InvokeModelInput, - _ModelActivity, - _StreamingInvokeModelInput, -) - -_original_agent_init = Agent.__init__ -_options: dict[str, Any] = {} -_streaming_topic: str | None = None -_streaming_batch_interval: timedelta = timedelta(milliseconds=100) - - -class _ActivityDispatchModel(Model): - """Stub installed by the patch; routes ``stream()`` through an activity.""" - - def update_config(self, **_model_config: Any) -> None: - return None - - def get_config(self) -> dict[str, Any]: - return {} - - def structured_output(self, *_args: Any, **_kwargs: Any) -> Any: - raise NotImplementedError( - "Strands Agent.structured_output_async is not supported in workflow " - "context. Use Agent(structured_output_model=...) which routes through " - "stream() via the structured_output_tool." - ) - - async def stream( - self, - messages: Messages, - tool_specs: list[ToolSpec] | None = None, - system_prompt: str | None = None, - *, - tool_choice: ToolChoice | None = None, - system_prompt_content: list[SystemContentBlock] | None = None, - invocation_state: dict[str, Any] | None = None, - **kwargs: Any, - ) -> AsyncIterable[StreamEvent]: - if _streaming_topic is not None: - events = await workflow.execute_activity_method( - _ModelActivity.invoke_model_streaming, - _StreamingInvokeModelInput( - messages=messages, - tool_specs=tool_specs, - system_prompt=system_prompt, - tool_choice=tool_choice, - system_prompt_content=system_prompt_content, - streaming_topic=_streaming_topic, - streaming_batch_interval_seconds=_streaming_batch_interval.total_seconds(), - ), - **_options, - ) - else: - events = await workflow.execute_activity_method( - _ModelActivity.invoke_model, - _InvokeModelInput( - messages=messages, - tool_specs=tool_specs, - system_prompt=system_prompt, - tool_choice=tool_choice, - system_prompt_content=system_prompt_content, - ), - **_options, - ) - for event in events: - yield event - - -def _patched_agent_init(self: Agent, model: Any = None, *args: Any, **kwargs: Any) -> None: - if workflow.in_workflow(): - if model is not None: - raise ValueError( - "Agent(model=...) must not be set inside a workflow. " - "Pass the model to StrandsPlugin(model=...) instead so it " - "runs as a Temporal activity." - ) - model = _ActivityDispatchModel() - _original_agent_init(self, model, *args, **kwargs) - - -def install_patch( - options: dict[str, Any], - streaming_topic: str | None, - streaming_batch_interval: timedelta, -) -> None: - global _options, _streaming_topic, _streaming_batch_interval - _options = options - _streaming_topic = streaming_topic - _streaming_batch_interval = streaming_batch_interval - setattr(Agent, "__init__", _patched_agent_init) - - -def uninstall_patch() -> None: - setattr(Agent, "__init__", _original_agent_init) diff --git a/temporalio/contrib/strands/_plugin.py b/temporalio/contrib/strands/_plugin.py index f97cce057..43e206046 100644 --- a/temporalio/contrib/strands/_plugin.py +++ b/temporalio/contrib/strands/_plugin.py @@ -4,7 +4,6 @@ from datetime import timedelta from typing import Any -from strands.models import Model from strands.types.tools import AgentTool from temporalio.common import Priority, RetryPolicy @@ -15,75 +14,64 @@ from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner from temporalio.workflow import ActivityCancellationType, VersioningIntent -from ._model import _ModelActivity -from ._patch import install_patch, uninstall_patch -from ._temporal_activity_tool import _TemporalActivityTool +from ._temporal_activity_tool import TemporalActivityTool +from ._temporal_mcp_client import TemporalMCPClient +from ._temporal_model import TemporalModel class StrandsPlugin(SimplePlugin): """Temporal Worker plugin for the Strands Agents SDK. - Configures sandbox passthrough for ``strands`` and ``strands_tools`` and - swaps in ``pydantic_data_converter`` so structured outputs serialize. + Configures sandbox passthrough for ``strands``, ``strands_tools``, ``mcp``, + and ``temporalio.contrib.strands`` (so the MCP tool cache is visible to + workflow code), and swaps in ``pydantic_data_converter`` so structured + outputs serialize. - When ``model`` is supplied, registers the model activities and patches - ``strands.agent.agent.Agent.__init__`` so any ``Agent(...)`` constructed - inside a workflow gets its ``model`` replaced with a stub that routes - ``stream()`` through the registered activity. The ``model`` instance lives - on the worker and is reused across activity invocations. Activity options - (``start_to_close_timeout``, ``retry_policy``, etc.) flow from the plugin - to the dispatched activity. + When ``model`` is supplied, calls its ``model_factory`` once on the worker + to construct the real model, then registers the model invocation activities + against it. The same :class:`TemporalModel` is also passed to + ``Agent(model=...)`` inside the workflow. + + When ``mcp_clients`` is supplied, registers per-server ``{server}-call-tool`` + activities and, at worker startup, connects to each MCP server and caches + its tool list. Workflow-side ``TemporalMCPClient.load_tools()`` reads from + the cache. The plugin raises if any two clients share the same ``server``. """ def __init__( self, *, - model: Model | None = None, - task_queue: str | None = None, - schedule_to_close_timeout: timedelta | None = None, - schedule_to_start_timeout: timedelta | None = None, - start_to_close_timeout: timedelta | None = None, - heartbeat_timeout: timedelta | None = None, - retry_policy: RetryPolicy | None = None, - cancellation_type: ActivityCancellationType = ActivityCancellationType.TRY_CANCEL, - versioning_intent: VersioningIntent | None = None, - summary: str | None = None, - priority: Priority = Priority.default, - streaming_topic: str | None = None, - streaming_batch_interval: timedelta = timedelta(milliseconds=100), + model: TemporalModel | None = None, + mcp_clients: list[TemporalMCPClient] = [], ) -> None: - activities: list[Callable] | None = None + activities: list[Callable] = [] if model is not None: - ma = _ModelActivity(model) - activities = [ma.invoke_model, ma.invoke_model_streaming] - options: dict[str, Any] = { - "task_queue": task_queue, - "schedule_to_close_timeout": schedule_to_close_timeout, - "schedule_to_start_timeout": schedule_to_start_timeout, - "start_to_close_timeout": start_to_close_timeout, - "heartbeat_timeout": heartbeat_timeout, - "retry_policy": retry_policy, - "cancellation_type": cancellation_type, - "versioning_intent": versioning_intent, - "summary": summary, - "priority": priority, - } + ma = model._build_activity() + activities.extend([ma.invoke_model, ma.invoke_model_streaming]) + + names = [c.server for c in mcp_clients] + if len(names) != len(set(names)): + raise ValueError( + "Duplicate MCP server names in mcp_clients; each must be unique." + ) + for c in mcp_clients: + activities.extend(c._get_activities()) @asynccontextmanager async def run_context() -> AsyncIterator[None]: - if model is not None: - install_patch(options, streaming_topic, streaming_batch_interval) + for c in mcp_clients: + await c._populate_cache() try: yield finally: - if model is not None: - uninstall_patch() + for c in mcp_clients: + c._clear_cache() super().__init__( "aws.StrandsPlugin", workflow_runner=_workflow_runner, data_converter=_data_converter, - activities=activities, + activities=activities or None, run_context=run_context, ) @@ -121,7 +109,7 @@ def activity_as_tool( "summary": summary, "priority": priority, } - return _TemporalActivityTool(activity_fn, options) + return TemporalActivityTool(activity_fn, options) def _workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: @@ -133,6 +121,8 @@ def _workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: restrictions=runner.restrictions.with_passthrough_modules( "strands", "strands_tools", + "mcp", + "temporalio.contrib.strands", ), ) return runner diff --git a/temporalio/contrib/strands/_temporal_activity_tool.py b/temporalio/contrib/strands/_temporal_activity_tool.py index 1d2ac3ffe..86843ebb1 100644 --- a/temporalio/contrib/strands/_temporal_activity_tool.py +++ b/temporalio/contrib/strands/_temporal_activity_tool.py @@ -10,7 +10,7 @@ from temporalio import activity, workflow -class _TemporalActivityTool(AgentTool): +class TemporalActivityTool(AgentTool): """Strands ``AgentTool`` whose body dispatches a Temporal activity.""" def __init__(self, activity_fn: Callable, options: dict[str, Any]) -> None: diff --git a/temporalio/contrib/strands/_temporal_mcp_client.py b/temporalio/contrib/strands/_temporal_mcp_client.py new file mode 100644 index 000000000..423260b48 --- /dev/null +++ b/temporalio/contrib/strands/_temporal_mcp_client.py @@ -0,0 +1,138 @@ +from collections.abc import Callable, Sequence +from dataclasses import dataclass, field +from datetime import timedelta +from typing import Any + +from strands.tools.mcp.mcp_agent_tool import MCPAgentTool +from strands.tools.mcp.mcp_client import MCPClient +from strands.tools.mcp.mcp_types import MCPToolResult, MCPTransport +from strands.tools.tool_provider import ToolProvider +from strands.types.tools import AgentTool + +from temporalio import activity +from temporalio.common import Priority, RetryPolicy +from temporalio.workflow import ActivityCancellationType, VersioningIntent + + +@dataclass +class _MCPToolInfo: + name: str + description: str + input_schema: dict[str, Any] + output_schema: dict[str, Any] | None = None + + +@dataclass +class _CallToolArgs: + tool_name: str + arguments: dict[str, Any] = field(default_factory=dict) + tool_use_id: str = "" + + +# Server name -> cached tool list. Populated by TemporalMCPClient._populate_cache +# at worker startup and read by TemporalMCPClient.load_tools() inside the +# workflow sandbox. ``StrandsPlugin`` adds ``temporalio.contrib.strands`` to +# sandbox passthrough so this dict is shared between worker process and +# workflow execution. +_TOOL_CACHE: dict[str, list[_MCPToolInfo]] = {} + + +class TemporalMCPClient(ToolProvider): + """An MCP server reference for use in both worker and workflow contexts. + + Construct once at module level. Pass to ``StrandsPlugin(mcp_clients=[...])`` + on the worker (which registers the ``{server}-call-tool`` activity and runs + ``list_tools`` at worker startup), and to ``Agent(tools=[...])`` inside the + workflow (which adds the discovered tools to the agent's registry). + + Construction does no I/O. The actual MCP connection happens worker-side at + plugin startup; each tool call later runs as a Temporal activity. + """ + + def __init__( + self, + server: str, + transport_factory: Callable[[], MCPTransport], + *, + task_queue: str | None = None, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + retry_policy: RetryPolicy | None = None, + cancellation_type: ActivityCancellationType = ActivityCancellationType.TRY_CANCEL, + versioning_intent: VersioningIntent | None = None, + summary: str | None = None, + priority: Priority = Priority.default, + ) -> None: + self._server = server + self._transport_factory = transport_factory + self._options: dict[str, Any] = { + "task_queue": task_queue, + "schedule_to_close_timeout": schedule_to_close_timeout, + "schedule_to_start_timeout": schedule_to_start_timeout, + "start_to_close_timeout": start_to_close_timeout, + "heartbeat_timeout": heartbeat_timeout, + "retry_policy": retry_policy, + "cancellation_type": cancellation_type, + "versioning_intent": versioning_intent, + "summary": summary, + "priority": priority, + } + + @property + def server(self) -> str: + return self._server + + async def load_tools(self, **_kwargs: Any) -> Sequence[AgentTool]: + from ._temporal_mcp_tool import TemporalMCPTool + + infos = _TOOL_CACHE.get(self._server, []) + return [TemporalMCPTool(self._server, info, self._options) for info in infos] + + def add_consumer(self, consumer_id: Any, **_kwargs: Any) -> None: + return None + + def remove_consumer(self, consumer_id: Any, **_kwargs: Any) -> None: + return None + + async def _populate_cache(self) -> None: + """Connect to the MCP server, list tools, fill ``_TOOL_CACHE``.""" + client = MCPClient(self._transport_factory) + try: + infos: list[_MCPToolInfo] = [] + for tool in await client.load_tools(): + if not isinstance(tool, MCPAgentTool): + continue + infos.append( + _MCPToolInfo( + name=tool.mcp_tool.name, + description=tool.mcp_tool.description or "", + input_schema=tool.mcp_tool.inputSchema, + output_schema=tool.mcp_tool.outputSchema, + ) + ) + _TOOL_CACHE[self._server] = infos + finally: + client.stop(None, None, None) + + def _clear_cache(self) -> None: + _TOOL_CACHE.pop(self._server, None) + + def _get_activities(self) -> Sequence[Callable]: + transport_factory = self._transport_factory + + @activity.defn(name=f"{self._server}-call-tool") + async def call_tool(args: _CallToolArgs) -> MCPToolResult: + client = MCPClient(transport_factory) + client.start() + try: + return await client.call_tool_async( + tool_use_id=args.tool_use_id, + name=args.tool_name, + arguments=args.arguments, + ) + finally: + client.stop(None, None, None) + + return [call_tool] diff --git a/temporalio/contrib/strands/_temporal_mcp_tool.py b/temporalio/contrib/strands/_temporal_mcp_tool.py new file mode 100644 index 000000000..620c5925d --- /dev/null +++ b/temporalio/contrib/strands/_temporal_mcp_tool.py @@ -0,0 +1,60 @@ +from typing import Any + +from strands.types._events import ToolResultEvent +from strands.types.tools import AgentTool, ToolGenerator, ToolResult, ToolSpec, ToolUse + +from temporalio import workflow + +from ._temporal_mcp_client import _CallToolArgs, _MCPToolInfo + + +class TemporalMCPTool(AgentTool): + """Workflow-side stub for a single MCP tool; dispatches to an activity.""" + + def __init__( + self, + server: str, + info: _MCPToolInfo, + options: dict[str, Any], + ) -> None: + super().__init__() + self._server = server + self._info = info + self._options = options + + @property + def tool_name(self) -> str: + return self._info.name + + @property + def tool_spec(self) -> ToolSpec: + spec: ToolSpec = { + "name": self._info.name, + "description": self._info.description + or f"Tool which performs {self._info.name}", + "inputSchema": {"json": self._info.input_schema}, + } + if self._info.output_schema: + spec["outputSchema"] = {"json": self._info.output_schema} + return spec + + @property + def tool_type(self) -> str: + return "temporal_mcp" + + async def stream( + self, + tool_use: ToolUse, + invocation_state: dict[str, Any], + **kwargs: Any, + ) -> ToolGenerator: + result: ToolResult = await workflow.execute_activity( + f"{self._server}-call-tool", + _CallToolArgs( + tool_name=self._info.name, + arguments=tool_use["input"], + tool_use_id=tool_use["toolUseId"], + ), + **self._options, + ) + yield ToolResultEvent(result) diff --git a/temporalio/contrib/strands/_temporal_model.py b/temporalio/contrib/strands/_temporal_model.py new file mode 100644 index 000000000..2a111d564 --- /dev/null +++ b/temporalio/contrib/strands/_temporal_model.py @@ -0,0 +1,126 @@ +from collections.abc import AsyncIterable, Callable +from datetime import timedelta +from typing import Any + +from strands.models import Model +from strands.types.content import Messages, SystemContentBlock +from strands.types.streaming import StreamEvent +from strands.types.tools import ToolChoice, ToolSpec + +from temporalio import workflow +from temporalio.common import Priority, RetryPolicy +from temporalio.workflow import ActivityCancellationType, VersioningIntent + +from ._model_activity import ( + ModelActivity, + _InvokeModelInput, + _StreamingInvokeModelInput, +) + + +class TemporalModel(Model): + """A Strands :class:`Model` that runs ``stream()`` as a Temporal activity. + + ``model_factory`` is called once on the worker (when the plugin is + constructed) to produce the real model used inside the activity. + Construction of this :class:`TemporalModel` itself does no I/O, so it is + safe to instantiate at module level — the lambda is just stored. + + Pass the same instance to ``StrandsPlugin(model=...)`` (so the plugin can + register the model's activities) and to ``Agent(model=...)`` inside the + workflow (so the agent dispatches through that activity). + + When ``streaming_topic`` is set, each ``StreamEvent`` is also published to + the named topic on the workflow's + :class:`temporalio.contrib.workflow_streams.WorkflowStream` for external + consumers. + """ + + def __init__( + self, + *, + model_factory: Callable[[], Model], + task_queue: str | None = None, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + retry_policy: RetryPolicy | None = None, + cancellation_type: ActivityCancellationType = ActivityCancellationType.TRY_CANCEL, + versioning_intent: VersioningIntent | None = None, + summary: str | None = None, + priority: Priority = Priority.default, + streaming_topic: str | None = None, + streaming_batch_interval: timedelta = timedelta(milliseconds=100), + ) -> None: + self._model_factory = model_factory + self._streaming_topic = streaming_topic + self._streaming_batch_interval = streaming_batch_interval + self._options: dict[str, Any] = { + "task_queue": task_queue, + "schedule_to_close_timeout": schedule_to_close_timeout, + "schedule_to_start_timeout": schedule_to_start_timeout, + "start_to_close_timeout": start_to_close_timeout, + "heartbeat_timeout": heartbeat_timeout, + "retry_policy": retry_policy, + "cancellation_type": cancellation_type, + "versioning_intent": versioning_intent, + "summary": summary, + "priority": priority, + } + + def _build_activity(self) -> ModelActivity: + return ModelActivity(self._model_factory()) + + def update_config(self, **_model_config: Any) -> None: + return None + + def get_config(self) -> dict[str, Any]: + return {} + + def structured_output(self, *_args: Any, **_kwargs: Any) -> Any: + raise NotImplementedError( + "TemporalModel.structured_output is not supported. Use " + "Agent(structured_output_model=...) which routes structured output " + "through stream() via the structured_output_tool." + ) + + async def stream( + self, + messages: Messages, + tool_specs: list[ToolSpec] | None = None, + system_prompt: str | None = None, + *, + tool_choice: ToolChoice | None = None, + system_prompt_content: list[SystemContentBlock] | None = None, + invocation_state: dict[str, Any] | None = None, + **kwargs: Any, + ) -> AsyncIterable[StreamEvent]: + if self._streaming_topic is not None: + events = await workflow.execute_activity_method( + ModelActivity.invoke_model_streaming, + _StreamingInvokeModelInput( + messages=messages, + tool_specs=tool_specs, + system_prompt=system_prompt, + tool_choice=tool_choice, + system_prompt_content=system_prompt_content, + streaming_topic=self._streaming_topic, + streaming_batch_interval_seconds=self._streaming_batch_interval.total_seconds(), + ), + **self._options, + ) + else: + events = await workflow.execute_activity_method( + ModelActivity.invoke_model, + _InvokeModelInput( + messages=messages, + tool_specs=tool_specs, + system_prompt=system_prompt, + tool_choice=tool_choice, + system_prompt_content=system_prompt_content, + ), + **self._options, + ) + for event in events: + yield event diff --git a/tests/contrib/strands/echo_mcp_server.py b/tests/contrib/strands/echo_mcp_server.py new file mode 100644 index 000000000..9f70075ac --- /dev/null +++ b/tests/contrib/strands/echo_mcp_server.py @@ -0,0 +1,13 @@ +from mcp.server.fastmcp import FastMCP + +mcp = FastMCP("echo-server") + + +@mcp.tool() +def echo(message: str) -> str: + """Return the input message unchanged.""" + return message + + +if __name__ == "__main__": + mcp.run() diff --git a/tests/contrib/strands/test_mcp.py b/tests/contrib/strands/test_mcp.py new file mode 100644 index 000000000..ad1401fa0 --- /dev/null +++ b/tests/contrib/strands/test_mcp.py @@ -0,0 +1,81 @@ +import sys +from datetime import timedelta +from pathlib import Path +from uuid import uuid4 + +from mcp import StdioServerParameters, stdio_client +from strands import Agent + +from temporalio import workflow +from temporalio.client import Client +from temporalio.contrib.strands import ( + StrandsPlugin, + TemporalMCPClient, + TemporalModel, +) +from temporalio.worker import Replayer, Worker +from tests.contrib.strands.common import get_activities +from tests.contrib.strands.mock_model import MockModel + +ECHO = TemporalMCPClient( + server="echo", + transport_factory=lambda: stdio_client( + StdioServerParameters( + command=sys.executable, + args=[str(Path(__file__).parent / "echo_mcp_server.py")], + ) + ), + start_to_close_timeout=timedelta(seconds=30), +) + +MODEL = TemporalModel( + model_factory=lambda: MockModel( + [ + {"name": "echo", "input": {"message": "hello"}}, + "Done!", + ] + ), + start_to_close_timeout=timedelta(seconds=30), +) + + +@workflow.defn +class MCPWorkflow: + def __init__(self) -> None: + self.agent = Agent(model=MODEL, tools=[ECHO]) + + @workflow.run + async def run(self, prompt: str) -> str: + result = await self.agent.invoke_async(prompt) + return str(result) + + +async def test_mcp(client: Client): + task_queue = "test_mcp" + plugin = StrandsPlugin(model=MODEL, mcp_clients=[ECHO]) + + async with Worker( + client, + task_queue=task_queue, + workflows=[MCPWorkflow], + plugins=[plugin], + ): + handle = await client.start_workflow( + MCPWorkflow.run, + "echo hello", + id=f"test_mcp_{uuid4()}", + task_queue=task_queue, + ) + assert await handle.result() == "Done!\n" + + history = await handle.fetch_history() + assert get_activities(history) == [ + "invoke_strands_model", + "echo-call-tool", + "invoke_strands_model", + ] + + await Replayer( + workflows=[MCPWorkflow], + plugins=[plugin], + ).replay_workflow(history) diff --git a/tests/contrib/strands/test_model.py b/tests/contrib/strands/test_model.py index f4c50c6b7..e2caa29f9 100644 --- a/tests/contrib/strands/test_model.py +++ b/tests/contrib/strands/test_model.py @@ -5,16 +5,21 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.strands import StrandsPlugin +from temporalio.contrib.strands import StrandsPlugin, TemporalModel from temporalio.worker import Replayer, Worker from tests.contrib.strands.common import get_activities from tests.contrib.strands.mock_model import MockModel +MODEL = TemporalModel( + model_factory=lambda: MockModel(["Done!"]), + start_to_close_timeout=timedelta(seconds=15), +) + @workflow.defn class ModelWorkflow: def __init__(self) -> None: - self.agent = Agent() + self.agent = Agent(model=MODEL) @workflow.run async def run(self, prompt: str) -> str: @@ -24,10 +29,7 @@ async def run(self, prompt: str) -> str: async def test_model(client: Client): task_queue = "test_model" - plugin = StrandsPlugin( - model=MockModel(["Done!"]), - start_to_close_timeout=timedelta(seconds=15), - ) + plugin = StrandsPlugin(model=MODEL) async with Worker( client, diff --git a/tests/contrib/strands/test_model_streaming.py b/tests/contrib/strands/test_model_streaming.py index ab60df735..76951447b 100644 --- a/tests/contrib/strands/test_model_streaming.py +++ b/tests/contrib/strands/test_model_streaming.py @@ -7,18 +7,24 @@ from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.strands import StrandsPlugin +from temporalio.contrib.strands import StrandsPlugin, TemporalModel from temporalio.contrib.workflow_streams import WorkflowStream, WorkflowStreamClient from temporalio.worker import Replayer, Worker from tests.contrib.strands.common import get_activities from tests.contrib.strands.mock_model import MockModel +MODEL = TemporalModel( + model_factory=lambda: MockModel(["Done!"]), + start_to_close_timeout=timedelta(seconds=15), + streaming_topic="events", +) + @workflow.defn class StreamingModelWorkflow: def __init__(self) -> None: self.stream = WorkflowStream() - self.agent = Agent() + self.agent = Agent(model=MODEL) @workflow.run async def run(self, prompt: str) -> str: @@ -28,11 +34,7 @@ async def run(self, prompt: str) -> str: async def test_model_streaming(client: Client): task_queue = "test_model_streaming" - plugin = StrandsPlugin( - model=MockModel(["Done!"]), - start_to_close_timeout=timedelta(seconds=15), - streaming_topic="events", - ) + plugin = StrandsPlugin(model=MODEL) workflow_id = f"test_model_streaming_{uuid4()}" async with Worker( diff --git a/tests/contrib/strands/test_tool.py b/tests/contrib/strands/test_tool.py index 44f01a3cc..4211057a7 100644 --- a/tests/contrib/strands/test_tool.py +++ b/tests/contrib/strands/test_tool.py @@ -6,7 +6,7 @@ from temporalio import activity, workflow from temporalio.client import Client -from temporalio.contrib.strands import StrandsPlugin, activity_as_tool +from temporalio.contrib.strands import StrandsPlugin, TemporalModel, activity_as_tool from temporalio.worker import Replayer, Worker from tests.contrib.strands.common import get_activities from tests.contrib.strands.mock_model import MockModel @@ -22,10 +22,24 @@ async def current_time_activity() -> str: return current_time.current_time() +MODEL = TemporalModel( + model_factory=lambda: MockModel( + [ + {"name": "current_time", "input": {}}, + {"name": "calculator", "input": {"expression": "3111696 / 74088"}}, + {"name": "letter_counter", "input": {"word": "strawberry", "letter": "R"}}, + "Done!", + ] + ), + start_to_close_timeout=timedelta(seconds=15), +) + + @workflow.defn class ToolWorkflow: def __init__(self) -> None: self.agent = Agent( + model=MODEL, tools=[ calculator, activity_as_tool( @@ -44,23 +58,7 @@ async def run(self, prompt: str) -> str: async def test_tool(client: Client): task_queue = "test_tool" - plugin = StrandsPlugin( - model=MockModel( - [ - {"name": "current_time", "input": {}}, - { - "name": "calculator", - "input": {"expression": "3111696 / 74088"}, - }, - { - "name": "letter_counter", - "input": {"word": "strawberry", "letter": "R"}, - }, - "Done!", - ] - ), - start_to_close_timeout=timedelta(seconds=15), - ) + plugin = StrandsPlugin(model=MODEL) async with Worker( client, From 160bed052055599c1b3cb3e6334c11ab5b896ab1 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Wed, 13 May 2026 13:51:25 -0700 Subject: [PATCH 08/46] contrib/strands: rewrite README, default TemporalModel to BedrockModel Restructure the README into Quickstart + per-feature sections (Model, Structured Output, Streaming, Tools, MCP), add an experimental warning, installation instructions, and a link to strandsagents.com. Also default `TemporalModel.model_factory` to `BedrockModel`, matching the Strands `Agent` default, so the common case doesn't need a factory lambda. --- temporalio/contrib/strands/README.md | 240 +++++++++++++----- temporalio/contrib/strands/_temporal_model.py | 6 +- 2 files changed, 187 insertions(+), 59 deletions(-) diff --git a/temporalio/contrib/strands/README.md b/temporalio/contrib/strands/README.md index ff70f71d4..31a4179d9 100644 --- a/temporalio/contrib/strands/README.md +++ b/temporalio/contrib/strands/README.md @@ -1,59 +1,185 @@ # AWS Strands Agents -# Unsupported features -* File-based tool lookup - -# Migration Guide -* Use `agent.invoke_async(message)` instead of `agent(message)` which spawns a thread -* Decorate non-deterministic tools with `@activity.defn` and register them via `Worker(activities=[...])`. Wrap them in the agent with `activity_as_tool()`. -* For tools imported from `strands_tools`, write a thin async wrapper that calls the tool, e.g.: - ```python - from strands_tools import current_time - - @activity.defn(name="current_time") - async def current_time_activity() -> str: - return current_time.current_time() - ``` -* Construct a `TemporalModel` once with a `model_factory` lambda that returns the real model. Pass it both to the workflow's `Agent(model=...)` and to `StrandsPlugin(model=...)`; the plugin calls the factory at worker startup, and the workflow's `stream()` calls dispatch as Temporal activities. - ```python - from strands.models.bedrock import BedrockModel - from temporalio.contrib.strands import StrandsPlugin, TemporalModel - - MODEL = TemporalModel( - model_factory=lambda: BedrockModel(model_id="claude-3-5-sonnet"), - start_to_close_timeout=timedelta(seconds=60), - ) - - # workflow - @workflow.defn - class MyWorkflow: - def __init__(self) -> None: - self.agent = Agent(model=MODEL) - - # worker - Worker(..., plugins=[StrandsPlugin(model=MODEL)]) - ``` - To stream chunks to external consumers, pass `streaming_topic="..."` to `TemporalModel` and host a `WorkflowStream` on the workflow. -* For MCP servers, construct `TemporalMCPClient` once at module level and reference the same instance from both the plugin (registers the per-server `{server}-call-tool` activity and connects at worker startup to discover tools) and the workflow's `Agent(tools=[...])`: - ```python - from mcp import StdioServerParameters, stdio_client - from temporalio.contrib.strands import StrandsPlugin, TemporalMCPClient, TemporalModel - - ECHO = TemporalMCPClient( - server="echo", - transport_factory=lambda: stdio_client( - StdioServerParameters(command="...", args=[...]), - ), - start_to_close_timeout=timedelta(seconds=30), - ) - - # workflow - @workflow.defn - class MyWorkflow: - def __init__(self) -> None: - self.agent = Agent(model=MODEL, tools=[ECHO]) - - # worker - Worker(..., plugins=[StrandsPlugin(model=MODEL, mcp_clients=[ECHO])]) - ``` - The plugin connects to the MCP server once at worker startup to enumerate tools. The schema is frozen for the worker's lifetime; restart workers to pick up MCP-server changes. If the MCP server is unavailable at startup, the worker fails to start. +⚠️ **This package is currently at an experimental release stage.** ⚠️ + +This Temporal [Plugin](https://docs.temporal.io/develop/plugins-guide) allows you to run [Strands Agents](https://strandsagents.com/) inside Temporal Workflows, routing model invocations, tool calls, and MCP tool calls through Temporal Activities for durable execution, automatic retries, and timeouts. + +## Installation + +```sh +uv add temporalio[strands] +``` + +## Quickstart + +`workflow.py` defines the workflow and runs the worker: + +```python +import asyncio +from datetime import timedelta + +from strands import Agent + +from temporalio import workflow +from temporalio.client import Client +from temporalio.contrib.strands import StrandsPlugin, TemporalModel +from temporalio.worker import Worker + +MODEL = TemporalModel(start_to_close_timeout=timedelta(seconds=60)) + + +@workflow.defn +class MyWorkflow: + def __init__(self) -> None: + self.agent = Agent(model=MODEL) + + @workflow.run + async def run(self, prompt: str) -> str: + result = await self.agent.invoke_async(prompt) + return str(result) + + +async def main() -> None: + client = await Client.connect("localhost:7233") + worker = Worker( + client, + task_queue="strands", + workflows=[MyWorkflow], + plugins=[StrandsPlugin(model=MODEL)], + ) + await worker.run() + + +if __name__ == "__main__": + asyncio.run(main()) +``` + +`client.py` starts the workflow: + +```python +import asyncio + +from temporalio.client import Client + +from workflow import MyWorkflow + + +async def main() -> None: + client = await Client.connect("localhost:7233") + result = await client.execute_workflow( + MyWorkflow.run, + "Hello", + id="strands-quickstart", + task_queue="strands", + ) + print(result) + + +if __name__ == "__main__": + asyncio.run(main()) +``` + +Note: Use `agent.invoke_async(message)` instead of `agent(message)`. The synchronous form spawns a worker thread, which the workflow sandbox blocks. + +## Model + +`TemporalModel` defaults to `BedrockModel()`. To use a different model (or a different `BedrockModel` configuration), pass a `model_factory` lambda. The plugin calls it once at worker startup. + +```python +from strands.models.anthropic import AnthropicModel + +MODEL = TemporalModel( + model_factory=lambda: AnthropicModel(client_args={"api_key": "..."}), + start_to_close_timeout=timedelta(seconds=60), +) +``` + +## Structured Output + +Pass a Pydantic model to `Agent(structured_output_model=...)`. Strands routes the call through `stream()` as a synthetic tool, so it dispatches via the model activity like any other invocation. The result is available as `result.structured_output` and can be returned directly from the workflow — `StrandsPlugin` defaults to [`pydantic_data_converter`](../pydantic), so Pydantic types serialize across the activity and workflow boundary. + +```python +from pydantic import BaseModel + +class PersonInfo(BaseModel): + name: str + age: int + +@workflow.defn +class MyWorkflow: + def __init__(self) -> None: + self.agent = Agent(model=MODEL, structured_output_model=PersonInfo) + + @workflow.run + async def run(self, prompt: str) -> PersonInfo: + result = await self.agent.invoke_async(prompt) + return result.structured_output +``` + +`TemporalModel.structured_output()` called directly is not supported — always go through `Agent(structured_output_model=...)`. + +## Streaming + +To forward model chunks to external consumers, pass `streaming_topic="..."` to `TemporalModel` and host a `WorkflowStream` on the workflow. Each `StreamEvent` is published on the named topic from inside the model activity; subscribers read via `WorkflowStreamClient`. Chunks are batched on `streaming_batch_interval` (default 100ms). + +```python +MODEL = TemporalModel(streaming_topic="events") + +# workflow +@workflow.defn +class MyWorkflow: + def __init__(self) -> None: + self.stream = WorkflowStream() + self.agent = Agent(model=MODEL) + +# client +async for item in WorkflowStreamClient.create(client, workflow_id).subscribe( + ["events"], result_type=StreamEvent, +): + print(item.data) +``` + +## Tools + +Decorate non-deterministic tools with `@activity.defn`, or if you're importing tools from `strands_tools`, wrap them in a thin async function. Then, register the activity on the worker via `Worker(activities=[...])` and pass it to the agent with `activity_as_tool(activity, **options)` along with any activity options (e.g. `start_to_close_timeout`): + +```python +from strands_tools import current_time + +@activity.defn +async def fetch_user(user_id: str) -> dict: + ... + +@activity.defn(name="current_time") +async def current_time_activity() -> str: + return current_time.current_time() + +agent = Agent(tools=[ + activity_as_tool(fetch_user, start_to_close_timeout=timedelta(seconds=30)), +]) +``` + +## MCP + +Construct `TemporalMCPClient` once at module level and reference the same instance from both the plugin (which registers a per-server `{server}-call-tool` activity and connects at worker startup to discover tools) and `Agent(tools=[...])`: + +```python +from mcp import StdioServerParameters, stdio_client +from temporalio.contrib.strands import TemporalMCPClient + +ECHO = TemporalMCPClient( + server="echo", + transport_factory=lambda: stdio_client( + StdioServerParameters(command="...", args=[...]), + ), + start_to_close_timeout=timedelta(seconds=30), +) + +# workflow +agent = Agent(tools=[ECHO]) + +# worker +Worker(..., plugins=[StrandsPlugin(mcp_clients=[ECHO])]) +``` + +The plugin connects to the MCP server once at worker startup to enumerate tools. The schema is frozen for the worker's lifetime; restart workers to pick up MCP-server changes. If the MCP server is unavailable at startup, the worker fails to start. diff --git a/temporalio/contrib/strands/_temporal_model.py b/temporalio/contrib/strands/_temporal_model.py index 2a111d564..827ba0f23 100644 --- a/temporalio/contrib/strands/_temporal_model.py +++ b/temporalio/contrib/strands/_temporal_model.py @@ -3,6 +3,7 @@ from typing import Any from strands.models import Model +from strands.models.bedrock import BedrockModel from strands.types.content import Messages, SystemContentBlock from strands.types.streaming import StreamEvent from strands.types.tools import ToolChoice, ToolSpec @@ -22,7 +23,8 @@ class TemporalModel(Model): """A Strands :class:`Model` that runs ``stream()`` as a Temporal activity. ``model_factory`` is called once on the worker (when the plugin is - constructed) to produce the real model used inside the activity. + constructed) to produce the real model used inside the activity. Defaults + to :class:`strands.models.bedrock.BedrockModel`, matching Strands' default. Construction of this :class:`TemporalModel` itself does no I/O, so it is safe to instantiate at module level — the lambda is just stored. @@ -39,7 +41,7 @@ class TemporalModel(Model): def __init__( self, *, - model_factory: Callable[[], Model], + model_factory: Callable[[], Model] = BedrockModel, task_queue: str | None = None, schedule_to_close_timeout: timedelta | None = None, schedule_to_start_timeout: timedelta | None = None, From d759fa675d2dd464a0a0db6badc4ff7a326fbfc7 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Wed, 13 May 2026 15:51:13 -0700 Subject: [PATCH 09/46] contrib/strands: show worker activity registration in Tools snippet --- temporalio/contrib/strands/README.md | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/temporalio/contrib/strands/README.md b/temporalio/contrib/strands/README.md index 31a4179d9..d91f4a33c 100644 --- a/temporalio/contrib/strands/README.md +++ b/temporalio/contrib/strands/README.md @@ -149,14 +149,23 @@ from strands_tools import current_time @activity.defn async def fetch_user(user_id: str) -> dict: ... - + @activity.defn(name="current_time") async def current_time_activity() -> str: return current_time.current_time() +# workflow agent = Agent(tools=[ activity_as_tool(fetch_user, start_to_close_timeout=timedelta(seconds=30)), + activity_as_tool(current_time_activity, start_to_close_timeout=timedelta(seconds=15)), ]) + +# worker +Worker( + ..., + activities=[fetch_user, current_time_activity], + plugins=[StrandsPlugin(model=MODEL)], +) ``` ## MCP From f29728ed4dc4b5262b2076a2a9e0260ff412eeaf Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Wed, 13 May 2026 17:49:01 -0700 Subject: [PATCH 10/46] contrib/strands: add activity_as_hook helper, document Hooks Adds activity_as_hook(activity_fn, *, extract, **options): wraps a Temporal activity as a Strands HookCallback so I/O-doing hook callbacks (audit logs, metrics) dispatch off the workflow. Co-locates with activity_as_tool in a new _workflow.py. --- temporalio/contrib/strands/README.md | 43 ++++++++++ temporalio/contrib/strands/__init__.py | 4 +- temporalio/contrib/strands/_plugin.py | 43 ---------- temporalio/contrib/strands/_workflow.py | 103 +++++++++++++++++++++++ tests/contrib/strands/test_hooks.py | 105 ++++++++++++++++++++++++ 5 files changed, 254 insertions(+), 44 deletions(-) create mode 100644 temporalio/contrib/strands/_workflow.py create mode 100644 tests/contrib/strands/test_hooks.py diff --git a/temporalio/contrib/strands/README.md b/temporalio/contrib/strands/README.md index d91f4a33c..d42e36902 100644 --- a/temporalio/contrib/strands/README.md +++ b/temporalio/contrib/strands/README.md @@ -168,6 +168,49 @@ Worker( ) ``` +## Hooks + +Strands' [hook system](https://strandsagents.com/) (`strands.hooks`) lets you subscribe callbacks to events in the agent lifecycle — invocation start/end, model call before/after, tool call before/after, message added. The native `Agent(hooks=[MyHookProvider()])` API works as-is: every single-agent hook event fires in workflow context, so deterministic callbacks just work. + +```python +from strands.hooks import HookProvider, HookRegistry +from strands.hooks.events import AfterToolCallEvent + +class AuditHook(HookProvider): + def register_hooks(self, registry: HookRegistry) -> None: + registry.add_callback(AfterToolCallEvent, self._on_tool_call) + + def _on_tool_call(self, event: AfterToolCallEvent) -> None: + # Pure local state - deterministic across replay. + workflow.logger.info(f"tool {event.tool_use['name']} finished") + +agent = Agent(hooks=[AuditHook()]) +``` + +Callbacks run in workflow context, so they must be deterministic: no `time.time()`, `uuid.uuid4()`, or I/O — same rules as workflow code. For callbacks that need I/O (audit logging, metrics, alerting), use `activity_as_hook()` to dispatch the work as a Temporal activity: + +```python +from temporalio.contrib.strands import activity_as_hook + +@activity.defn +async def persist_tool_call(tool_name: str) -> None: + # I/O safely in an activity. + ... + +class AuditHook(HookProvider): + def register_hooks(self, registry: HookRegistry) -> None: + registry.add_callback( + AfterToolCallEvent, + activity_as_hook( + persist_tool_call, + extract=lambda event: event.tool_use["name"], + start_to_close_timeout=timedelta(seconds=10), + ), + ) +``` + +`extract` pulls a serializable value from the event (the activity input). Events hold references to the `Agent`, `AgentTool` instances, etc., none of which cross the activity boundary. Multi-agent hook events (`graph` / `swarm` / A2A) aren't supported yet — they require multi-agent support, which is still in progress. + ## MCP Construct `TemporalMCPClient` once at module level and reference the same instance from both the plugin (which registers a per-server `{server}-call-tool` activity and connects at worker startup to discover tools) and `Agent(tools=[...])`: diff --git a/temporalio/contrib/strands/__init__.py b/temporalio/contrib/strands/__init__.py index 55a0e0bbb..3b10318fd 100644 --- a/temporalio/contrib/strands/__init__.py +++ b/temporalio/contrib/strands/__init__.py @@ -1,12 +1,14 @@ """Temporal integration for the Strands Agents SDK.""" -from ._plugin import StrandsPlugin, activity_as_tool +from ._plugin import StrandsPlugin from ._temporal_mcp_client import TemporalMCPClient from ._temporal_model import TemporalModel +from ._workflow import activity_as_hook, activity_as_tool __all__ = [ "StrandsPlugin", "TemporalMCPClient", "TemporalModel", + "activity_as_hook", "activity_as_tool", ] diff --git a/temporalio/contrib/strands/_plugin.py b/temporalio/contrib/strands/_plugin.py index 43e206046..3378a7141 100644 --- a/temporalio/contrib/strands/_plugin.py +++ b/temporalio/contrib/strands/_plugin.py @@ -1,20 +1,13 @@ from collections.abc import AsyncIterator, Callable from contextlib import asynccontextmanager from dataclasses import replace -from datetime import timedelta -from typing import Any -from strands.types.tools import AgentTool - -from temporalio.common import Priority, RetryPolicy from temporalio.contrib.pydantic import pydantic_data_converter from temporalio.converter import DataConverter, DefaultPayloadConverter from temporalio.plugin import SimplePlugin from temporalio.worker import WorkflowRunner from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner -from temporalio.workflow import ActivityCancellationType, VersioningIntent -from ._temporal_activity_tool import TemporalActivityTool from ._temporal_mcp_client import TemporalMCPClient from ._temporal_model import TemporalModel @@ -76,42 +69,6 @@ async def run_context() -> AsyncIterator[None]: ) -def activity_as_tool( - activity_fn: Callable, - *, - task_queue: str | None = None, - schedule_to_close_timeout: timedelta | None = None, - schedule_to_start_timeout: timedelta | None = None, - start_to_close_timeout: timedelta | None = None, - heartbeat_timeout: timedelta | None = None, - retry_policy: RetryPolicy | None = None, - cancellation_type: ActivityCancellationType = ActivityCancellationType.TRY_CANCEL, - activity_id: str | None = None, - versioning_intent: VersioningIntent | None = None, - summary: str | None = None, - priority: Priority = Priority.default, -) -> AgentTool: - """Wrap a Temporal activity as a Strands tool. - - ``activity_fn`` must be decorated by ``@activity.defn``. All keyword - arguments are forwarded to ``workflow.execute_activity``. - """ - options: dict[str, Any] = { - "task_queue": task_queue, - "schedule_to_close_timeout": schedule_to_close_timeout, - "schedule_to_start_timeout": schedule_to_start_timeout, - "start_to_close_timeout": start_to_close_timeout, - "heartbeat_timeout": heartbeat_timeout, - "retry_policy": retry_policy, - "cancellation_type": cancellation_type, - "activity_id": activity_id, - "versioning_intent": versioning_intent, - "summary": summary, - "priority": priority, - } - return TemporalActivityTool(activity_fn, options) - - def _workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: if not runner: raise ValueError("No WorkflowRunner provided to the Strands plugin.") diff --git a/temporalio/contrib/strands/_workflow.py b/temporalio/contrib/strands/_workflow.py new file mode 100644 index 000000000..b67ca2f1c --- /dev/null +++ b/temporalio/contrib/strands/_workflow.py @@ -0,0 +1,103 @@ +"""Helpers for wiring Temporal activities into Strands' agent and hook surfaces. + +Both ``activity_as_tool`` and ``activity_as_hook`` produce workflow-side objects +that dispatch user activities via :func:`temporalio.workflow.execute_activity`, +so the I/O actually happens off the workflow. +""" + +from collections.abc import Callable +from datetime import timedelta +from typing import Any, TypeVar + +from strands.hooks.registry import BaseHookEvent, HookCallback +from strands.types.tools import AgentTool + +from temporalio import workflow +from temporalio.common import Priority, RetryPolicy +from temporalio.workflow import ActivityCancellationType, VersioningIntent + +from ._temporal_activity_tool import TemporalActivityTool + + +def activity_as_tool( + activity_fn: Callable, + *, + task_queue: str | None = None, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + retry_policy: RetryPolicy | None = None, + cancellation_type: ActivityCancellationType = ActivityCancellationType.TRY_CANCEL, + activity_id: str | None = None, + versioning_intent: VersioningIntent | None = None, + summary: str | None = None, + priority: Priority = Priority.default, +) -> AgentTool: + """Wrap a Temporal activity as a Strands tool. + + ``activity_fn`` must be decorated by ``@activity.defn``. All keyword + arguments are forwarded to ``workflow.execute_activity``. + """ + options: dict[str, Any] = { + "task_queue": task_queue, + "schedule_to_close_timeout": schedule_to_close_timeout, + "schedule_to_start_timeout": schedule_to_start_timeout, + "start_to_close_timeout": start_to_close_timeout, + "heartbeat_timeout": heartbeat_timeout, + "retry_policy": retry_policy, + "cancellation_type": cancellation_type, + "activity_id": activity_id, + "versioning_intent": versioning_intent, + "summary": summary, + "priority": priority, + } + return TemporalActivityTool(activity_fn, options) + + +TEvent = TypeVar("TEvent", bound=BaseHookEvent) + + +def activity_as_hook( + activity_fn: Callable, + *, + extract: Callable[[TEvent], Any], + task_queue: str | None = None, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + retry_policy: RetryPolicy | None = None, + cancellation_type: ActivityCancellationType = ActivityCancellationType.TRY_CANCEL, + activity_id: str | None = None, + versioning_intent: VersioningIntent | None = None, + summary: str | None = None, + priority: Priority = Priority.default, +) -> HookCallback[TEvent]: + """Wrap a Temporal activity as a Strands hook callback. + + The returned coroutine, when registered with ``HookRegistry.add_callback``, + dispatches ``activity_fn`` as a Temporal activity each time the associated + event fires. ``extract`` is called with the event to produce a + serializable activity input — events themselves are not serializable, since + they hold references to the ``Agent`` and other workflow-bound objects. + All other keyword arguments are forwarded to ``workflow.execute_activity``. + """ + options: dict[str, Any] = { + "task_queue": task_queue, + "schedule_to_close_timeout": schedule_to_close_timeout, + "schedule_to_start_timeout": schedule_to_start_timeout, + "start_to_close_timeout": start_to_close_timeout, + "heartbeat_timeout": heartbeat_timeout, + "retry_policy": retry_policy, + "cancellation_type": cancellation_type, + "activity_id": activity_id, + "versioning_intent": versioning_intent, + "summary": summary, + "priority": priority, + } + + async def callback(event: TEvent) -> None: + await workflow.execute_activity(activity_fn, extract(event), **options) + + return callback diff --git a/tests/contrib/strands/test_hooks.py b/tests/contrib/strands/test_hooks.py new file mode 100644 index 000000000..814af5519 --- /dev/null +++ b/tests/contrib/strands/test_hooks.py @@ -0,0 +1,105 @@ +from datetime import timedelta +from uuid import uuid4 + +from strands import Agent, tool +from strands.hooks import HookProvider, HookRegistry +from strands.hooks.events import AfterToolCallEvent + +from temporalio import activity, workflow +from temporalio.client import Client +from temporalio.contrib.strands import ( + StrandsPlugin, + TemporalModel, + activity_as_hook, +) +from temporalio.worker import Replayer, Worker +from tests.contrib.strands.common import get_activities +from tests.contrib.strands.mock_model import MockModel + +# Module-level sink: written by the audit activity, read in assertions. +# Activity bodies run in worker context, not the sandbox, so a plain list is fine. +_AUDIT_LOG: list[str] = [] + + +@activity.defn +async def audit_tool(tool_name: str) -> None: + _AUDIT_LOG.append(tool_name) + + +@tool +def echo(text: str) -> str: + return text + + +class AuditHook(HookProvider): + def __init__(self) -> None: + self.fired_events: list[str] = [] + + def register_hooks(self, registry: HookRegistry, **kwargs: object) -> None: + registry.add_callback(AfterToolCallEvent, self._sync_log) + registry.add_callback( + AfterToolCallEvent, + activity_as_hook( + audit_tool, + extract=lambda event: event.tool_use["name"], + start_to_close_timeout=timedelta(seconds=10), + ), + ) + + def _sync_log(self, event: AfterToolCallEvent) -> None: + # Deterministic in-workflow mutation: appends to per-workflow state. + self.fired_events.append(event.tool_use["name"]) + + +MODEL = TemporalModel( + model_factory=lambda: MockModel( + [ + {"name": "echo", "input": {"text": "hi"}}, + "Done!", + ] + ), + start_to_close_timeout=timedelta(seconds=15), +) + + +@workflow.defn +class HooksWorkflow: + def __init__(self) -> None: + self.hook = AuditHook() + self.agent = Agent(model=MODEL, tools=[echo], hooks=[self.hook]) + + @workflow.run + async def run(self, prompt: str) -> list[str]: + await self.agent.invoke_async(prompt) + return self.hook.fired_events + + +async def test_hooks(client: Client): + _AUDIT_LOG.clear() + task_queue = "test_hooks" + plugin = StrandsPlugin(model=MODEL) + + async with Worker( + client, + task_queue=task_queue, + workflows=[HooksWorkflow], + activities=[audit_tool], + plugins=[plugin], + ): + handle = await client.start_workflow( + HooksWorkflow.run, + "Say hi", + id=f"test_hooks_{uuid4()}", + task_queue=task_queue, + ) + assert await handle.result() == ["echo"] + + assert _AUDIT_LOG == ["echo"] + + history = await handle.fetch_history() + assert "audit_tool" in get_activities(history) + + await Replayer( + workflows=[HooksWorkflow], + plugins=[plugin], + ).replay_workflow(history) From 58cba6dafd9de622eab6790917f822dde3707d27 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Wed, 13 May 2026 18:09:09 -0700 Subject: [PATCH 11/46] contrib/strands: rename activity_as_hook extract to activity_input --- temporalio/contrib/strands/README.md | 6 +++--- temporalio/contrib/strands/_workflow.py | 6 +++--- tests/contrib/strands/test_hooks.py | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/temporalio/contrib/strands/README.md b/temporalio/contrib/strands/README.md index d42e36902..3a2bb8f9f 100644 --- a/temporalio/contrib/strands/README.md +++ b/temporalio/contrib/strands/README.md @@ -1,4 +1,4 @@ -# AWS Strands Agents +# Strands Agents ⚠️ **This package is currently at an experimental release stage.** ⚠️ @@ -203,13 +203,13 @@ class AuditHook(HookProvider): AfterToolCallEvent, activity_as_hook( persist_tool_call, - extract=lambda event: event.tool_use["name"], + activity_input=lambda event: event.tool_use["name"], start_to_close_timeout=timedelta(seconds=10), ), ) ``` -`extract` pulls a serializable value from the event (the activity input). Events hold references to the `Agent`, `AgentTool` instances, etc., none of which cross the activity boundary. Multi-agent hook events (`graph` / `swarm` / A2A) aren't supported yet — they require multi-agent support, which is still in progress. +`activity_input` extracts serializable values from the event to pass as the activity's input. Use a dataclass or Pydantic model for multiple values. This is needed because events hold references to the `Agent`, `AgentTool` instances, etc., none of which cross the activity boundary. ## MCP diff --git a/temporalio/contrib/strands/_workflow.py b/temporalio/contrib/strands/_workflow.py index b67ca2f1c..e6be1a9df 100644 --- a/temporalio/contrib/strands/_workflow.py +++ b/temporalio/contrib/strands/_workflow.py @@ -61,7 +61,7 @@ def activity_as_tool( def activity_as_hook( activity_fn: Callable, *, - extract: Callable[[TEvent], Any], + activity_input: Callable[[TEvent], Any], task_queue: str | None = None, schedule_to_close_timeout: timedelta | None = None, schedule_to_start_timeout: timedelta | None = None, @@ -78,7 +78,7 @@ def activity_as_hook( The returned coroutine, when registered with ``HookRegistry.add_callback``, dispatches ``activity_fn`` as a Temporal activity each time the associated - event fires. ``extract`` is called with the event to produce a + event fires. ``activity_input`` is called with the event to produce a serializable activity input — events themselves are not serializable, since they hold references to the ``Agent`` and other workflow-bound objects. All other keyword arguments are forwarded to ``workflow.execute_activity``. @@ -98,6 +98,6 @@ def activity_as_hook( } async def callback(event: TEvent) -> None: - await workflow.execute_activity(activity_fn, extract(event), **options) + await workflow.execute_activity(activity_fn, activity_input(event), **options) return callback diff --git a/tests/contrib/strands/test_hooks.py b/tests/contrib/strands/test_hooks.py index 814af5519..76f123155 100644 --- a/tests/contrib/strands/test_hooks.py +++ b/tests/contrib/strands/test_hooks.py @@ -41,7 +41,7 @@ def register_hooks(self, registry: HookRegistry, **kwargs: object) -> None: AfterToolCallEvent, activity_as_hook( audit_tool, - extract=lambda event: event.tool_use["name"], + activity_input=lambda event: event.tool_use["name"], start_to_close_timeout=timedelta(seconds=10), ), ) From 12cc311c0fa105cb09c6311a3c2b8869343e5635 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Thu, 14 May 2026 10:15:25 -0700 Subject: [PATCH 12/46] contrib/strands: document HITL interrupts, add integration test --- temporalio/contrib/strands/README.md | 39 +++++++++ tests/contrib/strands/test_interrupt.py | 103 ++++++++++++++++++++++++ 2 files changed, 142 insertions(+) create mode 100644 tests/contrib/strands/test_interrupt.py diff --git a/temporalio/contrib/strands/README.md b/temporalio/contrib/strands/README.md index 3a2bb8f9f..130a762cb 100644 --- a/temporalio/contrib/strands/README.md +++ b/temporalio/contrib/strands/README.md @@ -211,6 +211,45 @@ class AuditHook(HookProvider): `activity_input` extracts serializable values from the event to pass as the activity's input. Use a dataclass or Pydantic model for multiple values. This is needed because events hold references to the `Agent`, `AgentTool` instances, etc., none of which cross the activity boundary. +## Human-in-the-loop interrupts + +A hook on an interruptible event (e.g. `BeforeToolCallEvent`) can pause the agent by calling `event.interrupt(name, reason=...)`. When this fires, `agent.invoke_async()` returns `AgentResult(stop_reason="interrupt", interrupts=[...])` instead of raising. Pair this with a signal handler that supplies responses, then resume by calling `agent.invoke_async(responses)`: + +```python +from strands.hooks import HookProvider, HookRegistry +from strands.hooks.events import BeforeToolCallEvent + +class ApprovalHook(HookProvider): + def register_hooks(self, registry: HookRegistry) -> None: + registry.add_callback(BeforeToolCallEvent, self._gate) + + def _gate(self, event: BeforeToolCallEvent) -> None: + if event.interrupt("approval", reason="confirm delete") != "approve": + event.cancel_tool = "denied" + +@workflow.defn +class MyWorkflow: + def __init__(self) -> None: + self.agent = Agent(model=MODEL, tools=[delete_thing], hooks=[ApprovalHook()]) + self._approval: str | None = None + + @workflow.signal + def approve(self, response: str) -> None: + self._approval = response + + @workflow.run + async def run(self, prompt: str) -> str: + result = await self.agent.invoke_async(prompt) + if result.stop_reason == "interrupt": + await workflow.wait_condition(lambda: self._approval is not None) + result = await self.agent.invoke_async([ + {"interruptResponse": {"interruptId": result.interrupts[0].id, "response": self._approval}} + ]) + return str(result) +``` + +Interrupt hooks must be deterministic: branch on the activity result and call `event.interrupt(...)` on the workflow side. Tools wrapped via `activity_as_tool` cannot raise interrupts — the activity body has no `Agent` reference — so hooks are the interrupt surface for this plugin. + ## MCP Construct `TemporalMCPClient` once at module level and reference the same instance from both the plugin (which registers a per-server `{server}-call-tool` activity and connects at worker startup to discover tools) and `Agent(tools=[...])`: diff --git a/tests/contrib/strands/test_interrupt.py b/tests/contrib/strands/test_interrupt.py new file mode 100644 index 000000000..a25cab8a5 --- /dev/null +++ b/tests/contrib/strands/test_interrupt.py @@ -0,0 +1,103 @@ +from datetime import timedelta +from uuid import uuid4 + +from strands import Agent, tool +from strands.hooks import HookProvider, HookRegistry +from strands.hooks.events import BeforeToolCallEvent +from strands.types.interrupt import InterruptResponseContent + +from temporalio import workflow +from temporalio.client import Client +from temporalio.contrib.strands import StrandsPlugin, TemporalModel +from temporalio.worker import Replayer, Worker +from tests.contrib.strands.common import get_activities +from tests.contrib.strands.mock_model import MockModel + + +@tool +def delete_thing(name: str) -> str: + return f"deleted {name}" + + +class ApprovalHook(HookProvider): + def register_hooks(self, registry: HookRegistry, **kwargs: object) -> None: + registry.add_callback(BeforeToolCallEvent, self._gate) + + def _gate(self, event: BeforeToolCallEvent) -> None: + if event.tool_use["name"] != "delete_thing": + return + approval = event.interrupt( + "approval", + reason=f"approve delete of {event.tool_use['input']['name']}?", + ) + if approval != "approve": + event.cancel_tool = "denied" + + +MODEL = TemporalModel( + model_factory=lambda: MockModel( + [ + {"name": "delete_thing", "input": {"name": "foo"}}, + "Done!", + ] + ), + start_to_close_timeout=timedelta(seconds=15), +) + + +@workflow.defn +class InterruptWorkflow: + def __init__(self) -> None: + self.agent = Agent( + model=MODEL, tools=[delete_thing], hooks=[ApprovalHook()] + ) + self._approval: str | None = None + + @workflow.signal + def approve(self, response: str) -> None: + self._approval = response + + @workflow.run + async def run(self, prompt: str) -> str: + result = await self.agent.invoke_async(prompt) + while result.stop_reason == "interrupt": + await workflow.wait_condition(lambda: self._approval is not None) + response = self._approval + self._approval = None + responses: list[InterruptResponseContent] = [ + {"interruptResponse": {"interruptId": i.id, "response": response}} + for i in (result.interrupts or []) + ] + result = await self.agent.invoke_async(responses) + return str(result) + + +async def test_interrupt(client: Client): + task_queue = "test_interrupt" + plugin = StrandsPlugin(model=MODEL) + + async with Worker( + client, + task_queue=task_queue, + workflows=[InterruptWorkflow], + plugins=[plugin], + ): + handle = await client.start_workflow( + InterruptWorkflow.run, + "delete foo", + id=f"test_interrupt_{uuid4()}", + task_queue=task_queue, + ) + await handle.signal(InterruptWorkflow.approve, "approve") + assert await handle.result() == "Done!\n" + + history = await handle.fetch_history() + assert get_activities(history) == [ + "invoke_strands_model", + "invoke_strands_model", + ] + + await Replayer( + workflows=[InterruptWorkflow], + plugins=[plugin], + ).replay_workflow(history) From 9d6e68af617c6ec366c723f6c12787c4e9b703dc Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Thu, 14 May 2026 12:32:24 -0700 Subject: [PATCH 13/46] contrib/strands: document continue-as-new for long chats --- temporalio/contrib/strands/README.md | 38 ++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/temporalio/contrib/strands/README.md b/temporalio/contrib/strands/README.md index 130a762cb..305ed2155 100644 --- a/temporalio/contrib/strands/README.md +++ b/temporalio/contrib/strands/README.md @@ -250,6 +250,44 @@ class MyWorkflow: Interrupt hooks must be deterministic: branch on the activity result and call `event.interrupt(...)` on the workflow side. Tools wrapped via `activity_as_tool` cannot raise interrupts — the activity body has no `Agent` reference — so hooks are the interrupt surface for this plugin. +## Continue-as-new + +A chat-style workflow accumulates history with every turn and will eventually hit Temporal's per-workflow history limit. `workflow.info().is_continue_as_new_suggested()` flips true once the server decides history has grown large enough; check it after each turn and hand off to a fresh run, carrying `agent.messages` as input: + +```python +from dataclasses import dataclass, field +from strands.types.content import Messages + +@dataclass +class ChatInput: + messages: Messages = field(default_factory=list) + +@workflow.defn +class ChatWorkflow: + def __init__(self) -> None: + self._pending: list[str] = [] + self._done = False + + @workflow.signal + def user_says(self, prompt: str) -> None: + self._pending.append(prompt) + + @workflow.signal + def end_chat(self) -> None: + self._done = True + + @workflow.run + async def run(self, input: ChatInput) -> None: + agent = Agent(model=MODEL, messages=list(input.messages)) + while True: + await workflow.wait_condition(lambda: self._pending or self._done) + if self._done: + return + await agent.invoke_async(self._pending.pop(0)) + if workflow.info().is_continue_as_new_suggested(): + workflow.continue_as_new(ChatInput(messages=agent.messages)) +``` + ## MCP Construct `TemporalMCPClient` once at module level and reference the same instance from both the plugin (which registers a per-server `{server}-call-tool` activity and connects at worker startup to discover tools) and `Agent(tools=[...])`: From 570260665904c05770fecac666bfa5014e5be91c Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Thu, 14 May 2026 14:08:34 -0700 Subject: [PATCH 14/46] contrib/strands: document OpenTelemetryPlugin compatibility --- temporalio/contrib/strands/README.md | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/temporalio/contrib/strands/README.md b/temporalio/contrib/strands/README.md index 305ed2155..3c8f19c50 100644 --- a/temporalio/contrib/strands/README.md +++ b/temporalio/contrib/strands/README.md @@ -312,3 +312,23 @@ Worker(..., plugins=[StrandsPlugin(mcp_clients=[ECHO])]) ``` The plugin connects to the MCP server once at worker startup to enumerate tools. The schema is frozen for the worker's lifetime; restart workers to pick up MCP-server changes. If the MCP server is unavailable at startup, the worker fails to start. + +## Observability + +`StrandsPlugin` composes cleanly with [`OpenTelemetryPlugin`](../opentelemetry) — add both to the worker to get OTel spans around the model, tool, and MCP activities the plugin schedules, plus any spans Strands itself emits inside `invoke_async`: + +```python +import opentelemetry.trace +from temporalio.contrib.opentelemetry import OpenTelemetryPlugin, create_tracer_provider + +opentelemetry.trace.set_tracer_provider(create_tracer_provider()) + +Worker( + client, + task_queue="strands", + workflows=[MyWorkflow], + plugins=[StrandsPlugin(model=MODEL), OpenTelemetryPlugin()], +) +``` + +Set the tracer provider before connecting the client. See the [OpenTelemetry plugin README](../opentelemetry) for exporter setup. From bba6b6776f84302e10151cb7e0b62bdd206bfdd1 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Thu, 14 May 2026 16:24:37 -0700 Subject: [PATCH 15/46] contrib/strands: fix lint warnings and add missing docstrings --- temporalio/contrib/strands/_model_activity.py | 3 +++ temporalio/contrib/strands/_plugin.py | 5 +++-- temporalio/contrib/strands/_temporal_activity_tool.py | 5 +++++ temporalio/contrib/strands/_temporal_mcp_client.py | 5 +++++ temporalio/contrib/strands/_temporal_mcp_tool.py | 5 +++++ temporalio/contrib/strands/_temporal_model.py | 5 +++++ tests/contrib/strands/test_interrupt.py | 4 +--- tests/contrib/strands/test_tool.py | 5 ++++- 8 files changed, 31 insertions(+), 6 deletions(-) diff --git a/temporalio/contrib/strands/_model_activity.py b/temporalio/contrib/strands/_model_activity.py index 9214c4285..10b056991 100644 --- a/temporalio/contrib/strands/_model_activity.py +++ b/temporalio/contrib/strands/_model_activity.py @@ -30,16 +30,19 @@ class ModelActivity: """Holds the user-supplied model and exposes the model activities.""" def __init__(self, model: Model) -> None: + """Store the model that activities will invoke.""" self._model = model @activity.defn(name="invoke_strands_model") async def invoke_model(self, input: _InvokeModelInput) -> list[StreamEvent]: + """Run the model and return its stream events as a list.""" return [event async for event in _stream(self._model, input)] @activity.defn(name="invoke_strands_model_streaming") async def invoke_model_streaming( self, input: _StreamingInvokeModelInput ) -> list[StreamEvent]: + """Run the model and publish each stream event to a WorkflowStream.""" events: list[StreamEvent] = [] stream = WorkflowStreamClient.from_within_activity( batch_interval=timedelta(seconds=input.streaming_batch_interval_seconds), diff --git a/temporalio/contrib/strands/_plugin.py b/temporalio/contrib/strands/_plugin.py index 3378a7141..054c557d0 100644 --- a/temporalio/contrib/strands/_plugin.py +++ b/temporalio/contrib/strands/_plugin.py @@ -1,4 +1,4 @@ -from collections.abc import AsyncIterator, Callable +from collections.abc import AsyncGenerator, Callable from contextlib import asynccontextmanager from dataclasses import replace @@ -37,6 +37,7 @@ def __init__( model: TemporalModel | None = None, mcp_clients: list[TemporalMCPClient] = [], ) -> None: + """Build the plugin from an optional model and MCP client list.""" activities: list[Callable] = [] if model is not None: ma = model._build_activity() @@ -51,7 +52,7 @@ def __init__( activities.extend(c._get_activities()) @asynccontextmanager - async def run_context() -> AsyncIterator[None]: + async def run_context() -> AsyncGenerator[None, None]: for c in mcp_clients: await c._populate_cache() try: diff --git a/temporalio/contrib/strands/_temporal_activity_tool.py b/temporalio/contrib/strands/_temporal_activity_tool.py index 86843ebb1..c753e0e60 100644 --- a/temporalio/contrib/strands/_temporal_activity_tool.py +++ b/temporalio/contrib/strands/_temporal_activity_tool.py @@ -14,6 +14,7 @@ class TemporalActivityTool(AgentTool): """Strands ``AgentTool`` whose body dispatches a Temporal activity.""" def __init__(self, activity_fn: Callable, options: dict[str, Any]) -> None: + """Capture the target activity and the options to invoke it with.""" super().__init__() defn = activity._Definition.from_callable(activity_fn) if not defn or not defn.name: @@ -27,14 +28,17 @@ def __init__(self, activity_fn: Callable, options: dict[str, Any]) -> None: @property def tool_name(self) -> str: + """Name of the underlying Temporal activity.""" return self._activity_name @property def tool_spec(self) -> ToolSpec: + """Strands ToolSpec derived from the activity's signature.""" return self._spec @property def tool_type(self) -> str: + """Tool kind identifier used by Strands.""" return "temporal_activity" async def stream( @@ -43,6 +47,7 @@ async def stream( invocation_state: dict[str, Any], **kwargs: Any, ) -> ToolGenerator: + """Execute the tool by dispatching to the bound Temporal activity.""" bound = self._signature.bind(**tool_use["input"]) bound.apply_defaults() positional = list(bound.arguments.values()) diff --git a/temporalio/contrib/strands/_temporal_mcp_client.py b/temporalio/contrib/strands/_temporal_mcp_client.py index 423260b48..baaf09a72 100644 --- a/temporalio/contrib/strands/_temporal_mcp_client.py +++ b/temporalio/contrib/strands/_temporal_mcp_client.py @@ -65,6 +65,7 @@ def __init__( summary: str | None = None, priority: Priority = Priority.default, ) -> None: + """Configure the server name, transport factory, and activity options.""" self._server = server self._transport_factory = transport_factory self._options: dict[str, Any] = { @@ -82,18 +83,22 @@ def __init__( @property def server(self) -> str: + """MCP server name used as the activity prefix.""" return self._server async def load_tools(self, **_kwargs: Any) -> Sequence[AgentTool]: + """Return TemporalMCPTool wrappers for tools cached at worker startup.""" from ._temporal_mcp_tool import TemporalMCPTool infos = _TOOL_CACHE.get(self._server, []) return [TemporalMCPTool(self._server, info, self._options) for info in infos] def add_consumer(self, consumer_id: Any, **_kwargs: Any) -> None: + """No-op; consumer tracking is handled by the underlying MCP client.""" return None def remove_consumer(self, consumer_id: Any, **_kwargs: Any) -> None: + """No-op; consumer tracking is handled by the underlying MCP client.""" return None async def _populate_cache(self) -> None: diff --git a/temporalio/contrib/strands/_temporal_mcp_tool.py b/temporalio/contrib/strands/_temporal_mcp_tool.py index 620c5925d..885b1a7e2 100644 --- a/temporalio/contrib/strands/_temporal_mcp_tool.py +++ b/temporalio/contrib/strands/_temporal_mcp_tool.py @@ -17,6 +17,7 @@ def __init__( info: _MCPToolInfo, options: dict[str, Any], ) -> None: + """Bind this tool to a server, its cached info, and activity options.""" super().__init__() self._server = server self._info = info @@ -24,10 +25,12 @@ def __init__( @property def tool_name(self) -> str: + """Name of the underlying MCP tool.""" return self._info.name @property def tool_spec(self) -> ToolSpec: + """Strands ToolSpec built from the cached MCP tool info.""" spec: ToolSpec = { "name": self._info.name, "description": self._info.description @@ -40,6 +43,7 @@ def tool_spec(self) -> ToolSpec: @property def tool_type(self) -> str: + """Tool kind identifier used by Strands.""" return "temporal_mcp" async def stream( @@ -48,6 +52,7 @@ async def stream( invocation_state: dict[str, Any], **kwargs: Any, ) -> ToolGenerator: + """Execute the tool by dispatching to the per-server call-tool activity.""" result: ToolResult = await workflow.execute_activity( f"{self._server}-call-tool", _CallToolArgs( diff --git a/temporalio/contrib/strands/_temporal_model.py b/temporalio/contrib/strands/_temporal_model.py index 827ba0f23..f547a266d 100644 --- a/temporalio/contrib/strands/_temporal_model.py +++ b/temporalio/contrib/strands/_temporal_model.py @@ -55,6 +55,7 @@ def __init__( streaming_topic: str | None = None, streaming_batch_interval: timedelta = timedelta(milliseconds=100), ) -> None: + """Configure the model factory, activity options, and streaming settings.""" self._model_factory = model_factory self._streaming_topic = streaming_topic self._streaming_batch_interval = streaming_batch_interval @@ -75,12 +76,15 @@ def _build_activity(self) -> ModelActivity: return ModelActivity(self._model_factory()) def update_config(self, **_model_config: Any) -> None: + """No-op; the real model is configured worker-side via ``model_factory``.""" return None def get_config(self) -> dict[str, Any]: + """Return an empty config; configuration lives on the worker-side model.""" return {} def structured_output(self, *_args: Any, **_kwargs: Any) -> Any: + """Not supported; use ``Agent(structured_output_model=...)`` instead.""" raise NotImplementedError( "TemporalModel.structured_output is not supported. Use " "Agent(structured_output_model=...) which routes structured output " @@ -98,6 +102,7 @@ async def stream( invocation_state: dict[str, Any] | None = None, **kwargs: Any, ) -> AsyncIterable[StreamEvent]: + """Run the model via the registered Temporal activity and yield events.""" if self._streaming_topic is not None: events = await workflow.execute_activity_method( ModelActivity.invoke_model_streaming, diff --git a/tests/contrib/strands/test_interrupt.py b/tests/contrib/strands/test_interrupt.py index a25cab8a5..c357b5556 100644 --- a/tests/contrib/strands/test_interrupt.py +++ b/tests/contrib/strands/test_interrupt.py @@ -48,9 +48,7 @@ def _gate(self, event: BeforeToolCallEvent) -> None: @workflow.defn class InterruptWorkflow: def __init__(self) -> None: - self.agent = Agent( - model=MODEL, tools=[delete_thing], hooks=[ApprovalHook()] - ) + self.agent = Agent(model=MODEL, tools=[delete_thing], hooks=[ApprovalHook()]) self._approval: str | None = None @workflow.signal diff --git a/tests/contrib/strands/test_tool.py b/tests/contrib/strands/test_tool.py index 4211057a7..e5911c4a6 100644 --- a/tests/contrib/strands/test_tool.py +++ b/tests/contrib/strands/test_tool.py @@ -2,7 +2,10 @@ from uuid import uuid4 from strands import Agent, tool -from strands_tools import calculator, current_time +from strands_tools import ( # pyright: ignore[reportMissingTypeStubs] + calculator, + current_time, +) from temporalio import activity, workflow from temporalio.client import Client From 81870afa28c5eafb462bc9f938724605893d6061 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Fri, 15 May 2026 09:30:27 -0700 Subject: [PATCH 16/46] disable tiktoken and other sandbox warnings --- temporalio/contrib/strands/_plugin.py | 12 ++++++++++++ tests/contrib/strands/test_structured_output.py | 5 ++++- 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/temporalio/contrib/strands/_plugin.py b/temporalio/contrib/strands/_plugin.py index 054c557d0..3a6d83eb8 100644 --- a/temporalio/contrib/strands/_plugin.py +++ b/temporalio/contrib/strands/_plugin.py @@ -2,6 +2,8 @@ from contextlib import asynccontextmanager from dataclasses import replace +import strands.models.model as _strands_model + from temporalio.contrib.pydantic import pydantic_data_converter from temporalio.converter import DataConverter, DefaultPayloadConverter from temporalio.plugin import SimplePlugin @@ -11,6 +13,10 @@ from ._temporal_mcp_client import TemporalMCPClient from ._temporal_model import TemporalModel +# Force Strands' base Model.count_tokens to skip tiktoken (non-deterministic) +# and use its chars-per-token heuristic (deterministic). +setattr(_strands_model, "_get_encoding", lambda: None) + class StrandsPlugin(SimplePlugin): """Temporal Worker plugin for the Strands Agents SDK. @@ -81,6 +87,12 @@ def _workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: "strands_tools", "mcp", "temporalio.contrib.strands", + # The SDK's default passthrough already includes ``pydantic`` because + # it lazy-imports inside some classes; extend that to its compiled + # validation core and its ``Annotated`` helper for the same reason. + "pydantic", + "pydantic_core", + "annotated_types", ), ) return runner diff --git a/tests/contrib/strands/test_structured_output.py b/tests/contrib/strands/test_structured_output.py index 29663e7dc..4d7dd885b 100644 --- a/tests/contrib/strands/test_structured_output.py +++ b/tests/contrib/strands/test_structured_output.py @@ -44,11 +44,14 @@ async def test_structured_output(client: Client): task_queue = "test_structured_output" plugin = StrandsPlugin() + config = client.config() + config["plugins"] = [*config["plugins"], plugin] + client = Client(**config) + async with Worker( client, task_queue=task_queue, workflows=[StructuredOutputWorkflow], - plugins=[plugin], ): handle = await client.start_workflow( StructuredOutputWorkflow.run, From 5be969f0fd02bad4fb5cd8e6cc29fcd16a18d4ae Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Fri, 15 May 2026 10:34:00 -0700 Subject: [PATCH 17/46] contrib/strands: disable Strands retries, route via Temporal RetryPolicy Patch Agent.__init__ at import time to force retry_strategy=None and raise ValueError when a strategy is supplied, so retries happen at the Temporal activity layer (RetryPolicy on activity options) rather than blocking inside the activity body. Documents the behavior in README. --- temporalio/contrib/strands/README.md | 17 ++++++++++++++++- temporalio/contrib/strands/_plugin.py | 24 ++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) diff --git a/temporalio/contrib/strands/README.md b/temporalio/contrib/strands/README.md index 3c8f19c50..14020cfd0 100644 --- a/temporalio/contrib/strands/README.md +++ b/temporalio/contrib/strands/README.md @@ -2,7 +2,7 @@ ⚠️ **This package is currently at an experimental release stage.** ⚠️ -This Temporal [Plugin](https://docs.temporal.io/develop/plugins-guide) allows you to run [Strands Agents](https://strandsagents.com/) inside Temporal Workflows, routing model invocations, tool calls, and MCP tool calls through Temporal Activities for durable execution, automatic retries, and timeouts. +This Temporal [Plugin](https://docs.temporal.io/develop/plugins-guide) allows you to run [Strands Agents](https://strandsagents.com/) inside Temporal Workflows, routing model invocations, tool calls, and MCP tool calls through Temporal Activities for durable execution, Temporal-managed retries, and timeouts. ## Installation @@ -94,6 +94,21 @@ MODEL = TemporalModel( ) ``` +## Retries + +The plugin disables Strands' built-in `ModelRetryStrategy` so retries are handled exclusively by Temporal. Configure retries via `RetryPolicy` on the activity options accepted by `TemporalModel`, `activity_as_tool`, `activity_as_hook`, and `TemporalMCPClient`: + +```python +from temporalio.common import RetryPolicy + +MODEL = TemporalModel( + start_to_close_timeout=timedelta(seconds=60), + retry_policy=RetryPolicy(maximum_attempts=3), +) +``` + +Passing `retry_strategy=...` to `Agent(...)` raises `ValueError`; remove the argument (or pass `retry_strategy=None`) and put the retry config on the activity options instead. + ## Structured Output Pass a Pydantic model to `Agent(structured_output_model=...)`. Strands routes the call through `stream()` as a synthetic tool, so it dispatches via the model activity like any other invocation. The result is available as `result.structured_output` and can be returned directly from the workflow — `StrandsPlugin` defaults to [`pydantic_data_converter`](../pydantic), so Pydantic types serialize across the activity and workflow boundary. diff --git a/temporalio/contrib/strands/_plugin.py b/temporalio/contrib/strands/_plugin.py index 3a6d83eb8..a745a0ef0 100644 --- a/temporalio/contrib/strands/_plugin.py +++ b/temporalio/contrib/strands/_plugin.py @@ -1,7 +1,9 @@ from collections.abc import AsyncGenerator, Callable from contextlib import asynccontextmanager from dataclasses import replace +from typing import Any +import strands.agent.agent as _strands_agent import strands.models.model as _strands_model from temporalio.contrib.pydantic import pydantic_data_converter @@ -17,6 +19,28 @@ # and use its chars-per-token heuristic (deterministic). setattr(_strands_model, "_get_encoding", lambda: None) +# Temporal handles retries via RetryPolicy on activity options. Disable +# Strands' in-activity ModelRetryStrategy (default max_attempts=6) so retries +# aren't duplicated, and fail fast if the user tries to configure one. +_original_agent_init = _strands_agent.Agent.__init__ +_RETRY_STRATEGY_NOT_PASSED: Any = object() + + +def _patched_agent_init(self: Any, *args: Any, **kwargs: Any) -> None: + retry_strategy = kwargs.get("retry_strategy", _RETRY_STRATEGY_NOT_PASSED) + if retry_strategy is not _RETRY_STRATEGY_NOT_PASSED and retry_strategy is not None: + raise ValueError( + "StrandsPlugin disables Strands retries; configure retries via " + "RetryPolicy on the activity options passed to TemporalModel, " + "activity_as_tool, activity_as_hook, or TemporalMCPClient. " + "Remove retry_strategy from Agent(...) or pass retry_strategy=None." + ) + kwargs["retry_strategy"] = None + _original_agent_init(self, *args, **kwargs) + + +setattr(_strands_agent.Agent, "__init__", _patched_agent_init) + class StrandsPlugin(SimplePlugin): """Temporal Worker plugin for the Strands Agents SDK. From 09ebad768e1002cd276ce559314a3eea10c81f8b Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Fri, 15 May 2026 14:53:50 -0700 Subject: [PATCH 18/46] contrib/strands: move activity_as_tool/hook under workflow.* submodule --- temporalio/contrib/strands/README.md | 15 ++++++++------- temporalio/contrib/strands/__init__.py | 5 ++--- temporalio/contrib/strands/_plugin.py | 2 +- .../contrib/strands/{_workflow.py => workflow.py} | 0 tests/contrib/strands/test_hooks.py | 7 ++----- tests/contrib/strands/test_tool.py | 3 ++- 6 files changed, 15 insertions(+), 17 deletions(-) rename temporalio/contrib/strands/{_workflow.py => workflow.py} (100%) diff --git a/temporalio/contrib/strands/README.md b/temporalio/contrib/strands/README.md index 14020cfd0..4d2918d2d 100644 --- a/temporalio/contrib/strands/README.md +++ b/temporalio/contrib/strands/README.md @@ -96,7 +96,7 @@ MODEL = TemporalModel( ## Retries -The plugin disables Strands' built-in `ModelRetryStrategy` so retries are handled exclusively by Temporal. Configure retries via `RetryPolicy` on the activity options accepted by `TemporalModel`, `activity_as_tool`, `activity_as_hook`, and `TemporalMCPClient`: +The plugin disables Strands' built-in `ModelRetryStrategy` so retries are handled exclusively by Temporal. Configure retries via `RetryPolicy` on the activity options accepted by `TemporalModel`, `workflow.activity_as_tool`, `workflow.activity_as_hook`, and `TemporalMCPClient`: ```python from temporalio.common import RetryPolicy @@ -156,10 +156,11 @@ async for item in WorkflowStreamClient.create(client, workflow_id).subscribe( ## Tools -Decorate non-deterministic tools with `@activity.defn`, or if you're importing tools from `strands_tools`, wrap them in a thin async function. Then, register the activity on the worker via `Worker(activities=[...])` and pass it to the agent with `activity_as_tool(activity, **options)` along with any activity options (e.g. `start_to_close_timeout`): +Decorate non-deterministic tools with `@activity.defn`, or if you're importing tools from `strands_tools`, wrap them in a thin async function. Then, register the activity on the worker via `Worker(activities=[...])` and pass it to the agent with `workflow.activity_as_tool(activity, **options)` along with any activity options (e.g. `start_to_close_timeout`): ```python from strands_tools import current_time +from temporalio.contrib.strands import workflow as strands_workflow @activity.defn async def fetch_user(user_id: str) -> dict: @@ -171,8 +172,8 @@ async def current_time_activity() -> str: # workflow agent = Agent(tools=[ - activity_as_tool(fetch_user, start_to_close_timeout=timedelta(seconds=30)), - activity_as_tool(current_time_activity, start_to_close_timeout=timedelta(seconds=15)), + strands_workflow.activity_as_tool(fetch_user, start_to_close_timeout=timedelta(seconds=30)), + strands_workflow.activity_as_tool(current_time_activity, start_to_close_timeout=timedelta(seconds=15)), ]) # worker @@ -202,10 +203,10 @@ class AuditHook(HookProvider): agent = Agent(hooks=[AuditHook()]) ``` -Callbacks run in workflow context, so they must be deterministic: no `time.time()`, `uuid.uuid4()`, or I/O — same rules as workflow code. For callbacks that need I/O (audit logging, metrics, alerting), use `activity_as_hook()` to dispatch the work as a Temporal activity: +Callbacks run in workflow context, so they must be deterministic: no `time.time()`, `uuid.uuid4()`, or I/O — same rules as workflow code. For callbacks that need I/O (audit logging, metrics, alerting), use `workflow.activity_as_hook()` to dispatch the work as a Temporal activity: ```python -from temporalio.contrib.strands import activity_as_hook +from temporalio.contrib.strands.workflow import activity_as_hook @activity.defn async def persist_tool_call(tool_name: str) -> None: @@ -263,7 +264,7 @@ class MyWorkflow: return str(result) ``` -Interrupt hooks must be deterministic: branch on the activity result and call `event.interrupt(...)` on the workflow side. Tools wrapped via `activity_as_tool` cannot raise interrupts — the activity body has no `Agent` reference — so hooks are the interrupt surface for this plugin. +Interrupt hooks must be deterministic: branch on the activity result and call `event.interrupt(...)` on the workflow side. Tools wrapped via `workflow.activity_as_tool` cannot raise interrupts — the activity body has no `Agent` reference — so hooks are the interrupt surface for this plugin. ## Continue-as-new diff --git a/temporalio/contrib/strands/__init__.py b/temporalio/contrib/strands/__init__.py index 3b10318fd..857cf46a0 100644 --- a/temporalio/contrib/strands/__init__.py +++ b/temporalio/contrib/strands/__init__.py @@ -1,14 +1,13 @@ """Temporal integration for the Strands Agents SDK.""" +from . import workflow from ._plugin import StrandsPlugin from ._temporal_mcp_client import TemporalMCPClient from ._temporal_model import TemporalModel -from ._workflow import activity_as_hook, activity_as_tool __all__ = [ "StrandsPlugin", "TemporalMCPClient", "TemporalModel", - "activity_as_hook", - "activity_as_tool", + "workflow", ] diff --git a/temporalio/contrib/strands/_plugin.py b/temporalio/contrib/strands/_plugin.py index a745a0ef0..d50c9393b 100644 --- a/temporalio/contrib/strands/_plugin.py +++ b/temporalio/contrib/strands/_plugin.py @@ -32,7 +32,7 @@ def _patched_agent_init(self: Any, *args: Any, **kwargs: Any) -> None: raise ValueError( "StrandsPlugin disables Strands retries; configure retries via " "RetryPolicy on the activity options passed to TemporalModel, " - "activity_as_tool, activity_as_hook, or TemporalMCPClient. " + "workflow.activity_as_tool, workflow.activity_as_hook, or TemporalMCPClient. " "Remove retry_strategy from Agent(...) or pass retry_strategy=None." ) kwargs["retry_strategy"] = None diff --git a/temporalio/contrib/strands/_workflow.py b/temporalio/contrib/strands/workflow.py similarity index 100% rename from temporalio/contrib/strands/_workflow.py rename to temporalio/contrib/strands/workflow.py diff --git a/tests/contrib/strands/test_hooks.py b/tests/contrib/strands/test_hooks.py index 76f123155..3511cb3d5 100644 --- a/tests/contrib/strands/test_hooks.py +++ b/tests/contrib/strands/test_hooks.py @@ -7,11 +7,8 @@ from temporalio import activity, workflow from temporalio.client import Client -from temporalio.contrib.strands import ( - StrandsPlugin, - TemporalModel, - activity_as_hook, -) +from temporalio.contrib.strands import StrandsPlugin, TemporalModel +from temporalio.contrib.strands.workflow import activity_as_hook from temporalio.worker import Replayer, Worker from tests.contrib.strands.common import get_activities from tests.contrib.strands.mock_model import MockModel diff --git a/tests/contrib/strands/test_tool.py b/tests/contrib/strands/test_tool.py index e5911c4a6..efca8d038 100644 --- a/tests/contrib/strands/test_tool.py +++ b/tests/contrib/strands/test_tool.py @@ -9,7 +9,8 @@ from temporalio import activity, workflow from temporalio.client import Client -from temporalio.contrib.strands import StrandsPlugin, TemporalModel, activity_as_tool +from temporalio.contrib.strands import StrandsPlugin, TemporalModel +from temporalio.contrib.strands.workflow import activity_as_tool from temporalio.worker import Replayer, Worker from tests.contrib.strands.common import get_activities from tests.contrib.strands.mock_model import MockModel From d9ffbd18398dc2973d8b8b373c259b0e9516e9ae Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Mon, 18 May 2026 09:52:13 -0700 Subject: [PATCH 19/46] contrib/strands: disable Agent.take_snapshot/load_snapshot --- temporalio/contrib/strands/README.md | 4 ++++ temporalio/contrib/strands/_plugin.py | 17 +++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/temporalio/contrib/strands/README.md b/temporalio/contrib/strands/README.md index 4d2918d2d..9a4f1fd7b 100644 --- a/temporalio/contrib/strands/README.md +++ b/temporalio/contrib/strands/README.md @@ -109,6 +109,10 @@ MODEL = TemporalModel( Passing `retry_strategy=...` to `Agent(...)` raises `ValueError`; remove the argument (or pass `retry_strategy=None`) and put the retry config on the activity options instead. +## Snapshots + +The plugin disables `Agent.take_snapshot()` and `Agent.load_snapshot()`. Temporal's event history already persists workflow state durably at a finer granularity than Strands snapshots, so calling either inside a workflow is redundant. Both methods raise `NotImplementedError`. + ## Structured Output Pass a Pydantic model to `Agent(structured_output_model=...)`. Strands routes the call through `stream()` as a synthetic tool, so it dispatches via the model activity like any other invocation. The result is available as `result.structured_output` and can be returned directly from the workflow — `StrandsPlugin` defaults to [`pydantic_data_converter`](../pydantic), so Pydantic types serialize across the activity and workflow boundary. diff --git a/temporalio/contrib/strands/_plugin.py b/temporalio/contrib/strands/_plugin.py index d50c9393b..869a86c9f 100644 --- a/temporalio/contrib/strands/_plugin.py +++ b/temporalio/contrib/strands/_plugin.py @@ -42,6 +42,23 @@ def _patched_agent_init(self: Any, *args: Any, **kwargs: Any) -> None: setattr(_strands_agent.Agent, "__init__", _patched_agent_init) +# Temporal workflows already persist agent state durably via the event history at +# a finer granularity than Strands snapshots, so calling either method inside a +# workflow is redundant; fail loudly to steer users to Temporal's durability. +def _snapshots_disabled(*args: Any, **kwargs: Any) -> Any: + del args, kwargs + raise NotImplementedError( + "StrandsPlugin disables Agent.take_snapshot()/load_snapshot(). " + "Temporal workflows already persist agent state durably via the event " + "history at a finer granularity than Strands snapshots. Remove the " + "snapshot call and rely on Temporal's durable execution instead." + ) + + +setattr(_strands_agent.Agent, "take_snapshot", _snapshots_disabled) +setattr(_strands_agent.Agent, "load_snapshot", _snapshots_disabled) + + class StrandsPlugin(SimplePlugin): """Temporal Worker plugin for the Strands Agents SDK. From 28a6e8f31477c34f0e16ae911be2cdf7b30950a7 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Mon, 18 May 2026 12:06:05 -0700 Subject: [PATCH 20/46] contrib/strands: add CODEOWNERS entries --- .github/CODEOWNERS | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index c718b40c0..9345ed54e 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -13,6 +13,8 @@ /temporalio/contrib/google_adk_agents/ @temporalio/ai-sdk @temporalio/sdk /temporalio/contrib/langsmith/ @temporalio/ai-sdk @temporalio/sdk /temporalio/contrib/openai_agents/ @temporalio/ai-sdk @temporalio/sdk +/temporalio/contrib/strands/ @temporalio/ai-sdk @temporalio/sdk /tests/contrib/google_adk_agents/ @temporalio/ai-sdk @temporalio/sdk /tests/contrib/langsmith/ @temporalio/ai-sdk @temporalio/sdk /tests/contrib/openai_agents/ @temporalio/ai-sdk @temporalio/sdk +/tests/contrib/strands/ @temporalio/ai-sdk @temporalio/sdk From d6fdf3ecb08c2a26f0e6cc29ca040f2efa83db4a Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Mon, 18 May 2026 13:15:12 -0700 Subject: [PATCH 21/46] contrib/strands: type _InvokeModelInput fields as Any Python < 3.11's get_type_hints leaks NotRequired[...] through TypedDict fields, which the default JSON converter can't deserialize. Strands Message and ToolSpec use NotRequired, so loosen the activity input fields to Any; values pass through unchanged to Model.stream. --- temporalio/contrib/strands/_model_activity.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/temporalio/contrib/strands/_model_activity.py b/temporalio/contrib/strands/_model_activity.py index 10b056991..986334e45 100644 --- a/temporalio/contrib/strands/_model_activity.py +++ b/temporalio/contrib/strands/_model_activity.py @@ -1,23 +1,26 @@ from collections.abc import AsyncIterable from dataclasses import dataclass from datetime import timedelta +from typing import Any from strands.models import Model -from strands.types.content import Messages, SystemContentBlock from strands.types.streaming import StreamEvent -from strands.types.tools import ToolChoice, ToolSpec from temporalio import activity from temporalio.contrib.workflow_streams import WorkflowStreamClient +# Fields are typed as Any because strands TypedDicts (Message, ToolSpec) use +# NotRequired, which Python < 3.11's get_type_hints leaks through unchanged +# and the default JSON converter then fails to deserialize. Values flow +# through unchanged to ``Model.stream`` which accepts the raw dicts. @dataclass class _InvokeModelInput: - messages: Messages - tool_specs: list[ToolSpec] | None = None + messages: Any + tool_specs: Any = None system_prompt: str | None = None - tool_choice: ToolChoice | None = None - system_prompt_content: list[SystemContentBlock] | None = None + tool_choice: Any = None + system_prompt_content: Any = None @dataclass From 45c6c5ab7500472d5ddb484088d2be1ee6ee87eb Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Mon, 18 May 2026 14:09:24 -0700 Subject: [PATCH 22/46] contrib/common: extract _heartbeat_decorator for cross-plugin use --- .github/CODEOWNERS | 5 +++-- temporalio/contrib/common/__init__.py | 0 .../_heartbeat_decorator.py | 14 ++++++-------- .../openai_agents/_invoke_model_activity.py | 2 +- 4 files changed, 10 insertions(+), 11 deletions(-) create mode 100644 temporalio/contrib/common/__init__.py rename temporalio/contrib/{openai_agents => common}/_heartbeat_decorator.py (64%) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 9345ed54e..6c4d36b99 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -6,10 +6,11 @@ # Below are owners for modules in the temporalio/contrib/ -# and tests/contrib/ directories that are owned by teams -# other than the SDK team. For each one, we add the owning team, +# and tests/contrib/ directories that are owned by teams +# other than the SDK team. For each one, we add the owning team, # as well as @temporalio/sdk, so the SDK team can continue to # manage repo-wide concerns. +/temporalio/contrib/common/ @temporalio/ai-sdk @temporalio/sdk /temporalio/contrib/google_adk_agents/ @temporalio/ai-sdk @temporalio/sdk /temporalio/contrib/langsmith/ @temporalio/ai-sdk @temporalio/sdk /temporalio/contrib/openai_agents/ @temporalio/ai-sdk @temporalio/sdk diff --git a/temporalio/contrib/common/__init__.py b/temporalio/contrib/common/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/temporalio/contrib/openai_agents/_heartbeat_decorator.py b/temporalio/contrib/common/_heartbeat_decorator.py similarity index 64% rename from temporalio/contrib/openai_agents/_heartbeat_decorator.py rename to temporalio/contrib/common/_heartbeat_decorator.py index 4baff6706..085a5d3e5 100644 --- a/temporalio/contrib/openai_agents/_heartbeat_decorator.py +++ b/temporalio/contrib/common/_heartbeat_decorator.py @@ -8,23 +8,22 @@ F = TypeVar("F", bound=Callable[..., Awaitable[Any]]) -def _auto_heartbeater(fn: F) -> F: # type:ignore[reportUnusedClass] - # Propagate type hints from the original callable. +def _auto_heartbeater(fn: F) -> F: + """Decorator that heartbeats at half the activity's heartbeat timeout.""" + @wraps(fn) async def wrapper(*args: Any, **kwargs: Any) -> Any: heartbeat_timeout = activity.info().heartbeat_timeout heartbeat_task = None if heartbeat_timeout: - # Heartbeat twice as often as the timeout heartbeat_task = asyncio.create_task( - heartbeat_every(heartbeat_timeout.total_seconds() / 2) + _heartbeat_every(heartbeat_timeout.total_seconds() / 2) ) try: return await fn(*args, **kwargs) finally: if heartbeat_task: heartbeat_task.cancel() - # Wait for heartbeat cancellation to complete try: await heartbeat_task except asyncio.CancelledError: @@ -33,8 +32,7 @@ async def wrapper(*args: Any, **kwargs: Any) -> Any: return cast(F, wrapper) -async def heartbeat_every(delay: float, *details: Any) -> None: - """Heartbeat every so often while not cancelled""" +async def _heartbeat_every(delay: float) -> None: while True: await asyncio.sleep(delay) - activity.heartbeat(*details) + activity.heartbeat() diff --git a/temporalio/contrib/openai_agents/_invoke_model_activity.py b/temporalio/contrib/openai_agents/_invoke_model_activity.py index 1aa836eee..d38ffc500 100644 --- a/temporalio/contrib/openai_agents/_invoke_model_activity.py +++ b/temporalio/contrib/openai_agents/_invoke_model_activity.py @@ -43,7 +43,7 @@ from typing_extensions import Required, TypedDict from temporalio import activity -from temporalio.contrib.openai_agents._heartbeat_decorator import _auto_heartbeater +from temporalio.contrib.common._heartbeat_decorator import _auto_heartbeater from temporalio.contrib.workflow_streams import WorkflowStreamClient from temporalio.exceptions import ApplicationError From 8ab32a7ac4881c46a1590ec330023d2cc2eb8fce Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Mon, 18 May 2026 14:09:33 -0700 Subject: [PATCH 23/46] contrib/strands: auto-heartbeat model activities, loosen input types Switch invoke_model/_streaming to the _auto_heartbeater pattern (matches openai_agents) so the heartbeat clock doesn't depend on event cadence and the non-streaming activity is covered too. Type _InvokeModelInput fields as Any: strands Message/ToolSpec use NotRequired, which Python < 3.11's get_type_hints leaks through and the default JSON converter then fails to deserialize. Values pass through unchanged to Model.stream. Drop the explicit activity name= override so the activities use their function names (invoke_model, invoke_model_streaming), matching the naming convention used by other contrib model activities. --- temporalio/contrib/strands/_model_activity.py | 8 +++++--- tests/contrib/strands/test_interrupt.py | 4 ++-- tests/contrib/strands/test_mcp.py | 4 ++-- tests/contrib/strands/test_model.py | 2 +- tests/contrib/strands/test_model_streaming.py | 2 +- tests/contrib/strands/test_tool.py | 8 ++++---- 6 files changed, 15 insertions(+), 13 deletions(-) diff --git a/temporalio/contrib/strands/_model_activity.py b/temporalio/contrib/strands/_model_activity.py index 986334e45..c33266f49 100644 --- a/temporalio/contrib/strands/_model_activity.py +++ b/temporalio/contrib/strands/_model_activity.py @@ -7,6 +7,7 @@ from strands.types.streaming import StreamEvent from temporalio import activity +from temporalio.contrib.common._heartbeat_decorator import _auto_heartbeater from temporalio.contrib.workflow_streams import WorkflowStreamClient @@ -36,12 +37,14 @@ def __init__(self, model: Model) -> None: """Store the model that activities will invoke.""" self._model = model - @activity.defn(name="invoke_strands_model") + @activity.defn + @_auto_heartbeater async def invoke_model(self, input: _InvokeModelInput) -> list[StreamEvent]: """Run the model and return its stream events as a list.""" return [event async for event in _stream(self._model, input)] - @activity.defn(name="invoke_strands_model_streaming") + @activity.defn + @_auto_heartbeater async def invoke_model_streaming( self, input: _StreamingInvokeModelInput ) -> list[StreamEvent]: @@ -53,7 +56,6 @@ async def invoke_model_streaming( topic = stream.topic(input.streaming_topic) async with stream: async for event in _stream(self._model, input): - activity.heartbeat() events.append(event) topic.publish(event) return events diff --git a/tests/contrib/strands/test_interrupt.py b/tests/contrib/strands/test_interrupt.py index c357b5556..06b4f529e 100644 --- a/tests/contrib/strands/test_interrupt.py +++ b/tests/contrib/strands/test_interrupt.py @@ -91,8 +91,8 @@ async def test_interrupt(client: Client): history = await handle.fetch_history() assert get_activities(history) == [ - "invoke_strands_model", - "invoke_strands_model", + "invoke_model", + "invoke_model", ] await Replayer( diff --git a/tests/contrib/strands/test_mcp.py b/tests/contrib/strands/test_mcp.py index ad1401fa0..a0ae7b5ee 100644 --- a/tests/contrib/strands/test_mcp.py +++ b/tests/contrib/strands/test_mcp.py @@ -70,9 +70,9 @@ async def test_mcp(client: Client): history = await handle.fetch_history() assert get_activities(history) == [ - "invoke_strands_model", + "invoke_model", "echo-call-tool", - "invoke_strands_model", + "invoke_model", ] await Replayer( diff --git a/tests/contrib/strands/test_model.py b/tests/contrib/strands/test_model.py index e2caa29f9..1a734c619 100644 --- a/tests/contrib/strands/test_model.py +++ b/tests/contrib/strands/test_model.py @@ -46,7 +46,7 @@ async def test_model(client: Client): assert await handle.result() == "Done!\n" history = await handle.fetch_history() - assert get_activities(history) == ["invoke_strands_model"] + assert get_activities(history) == ["invoke_model"] await Replayer( workflows=[ModelWorkflow], diff --git a/tests/contrib/strands/test_model_streaming.py b/tests/contrib/strands/test_model_streaming.py index 76951447b..e35e30fe1 100644 --- a/tests/contrib/strands/test_model_streaming.py +++ b/tests/contrib/strands/test_model_streaming.py @@ -69,7 +69,7 @@ async def collect() -> None: await asyncio.wait_for(collect_task, timeout=10.0) history = await handle.fetch_history() - assert get_activities(history) == ["invoke_strands_model_streaming"] + assert get_activities(history) == ["invoke_model_streaming"] assert any("messageStart" in e for e in events) assert any("messageStop" in e for e in events) diff --git a/tests/contrib/strands/test_tool.py b/tests/contrib/strands/test_tool.py index efca8d038..8d01ad580 100644 --- a/tests/contrib/strands/test_tool.py +++ b/tests/contrib/strands/test_tool.py @@ -84,13 +84,13 @@ async def test_tool(client: Client): history = await handle.fetch_history() assert get_activities(history) == [ - "invoke_strands_model", + "invoke_model", "current_time", - "invoke_strands_model", + "invoke_model", # calculator (in-workflow) - "invoke_strands_model", + "invoke_model", # letter_counter (in-workflow) - "invoke_strands_model", + "invoke_model", ] await Replayer( From 36ec764d8a878f6c535efcda132d773ad52345c2 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Mon, 18 May 2026 14:25:41 -0700 Subject: [PATCH 24/46] contrib/strands: clarify tiktoken comment --- temporalio/contrib/strands/_plugin.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/temporalio/contrib/strands/_plugin.py b/temporalio/contrib/strands/_plugin.py index 869a86c9f..c67b52825 100644 --- a/temporalio/contrib/strands/_plugin.py +++ b/temporalio/contrib/strands/_plugin.py @@ -15,8 +15,8 @@ from ._temporal_mcp_client import TemporalMCPClient from ._temporal_model import TemporalModel -# Force Strands' base Model.count_tokens to skip tiktoken (non-deterministic) -# and use its chars-per-token heuristic (deterministic). +# Force Strands' base Model.count_tokens to avoid tiktoken, which lazily downloads +# an encoding file. Use the default chars-per-token heuristic instead (deterministic). setattr(_strands_model, "_get_encoding", lambda: None) # Temporal handles retries via RetryPolicy on activity options. Disable From d8c52c1c700465544f9fca2de0791048f799962d Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Mon, 18 May 2026 14:38:24 -0700 Subject: [PATCH 25/46] contrib/common: drop leading underscore from auto_heartbeater Cross-module consumers can't be seen by basedpyright when the function is underscore-prefixed, producing a false unused-function warning. Promote the name and add a package docstring. --- temporalio/contrib/common/__init__.py | 1 + temporalio/contrib/common/_heartbeat_decorator.py | 2 +- .../contrib/openai_agents/_invoke_model_activity.py | 8 ++++---- temporalio/contrib/strands/_model_activity.py | 6 +++--- 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/temporalio/contrib/common/__init__.py b/temporalio/contrib/common/__init__.py index e69de29bb..a58b7968f 100644 --- a/temporalio/contrib/common/__init__.py +++ b/temporalio/contrib/common/__init__.py @@ -0,0 +1 @@ +"""Shared utilities for temporalio.contrib plugins.""" diff --git a/temporalio/contrib/common/_heartbeat_decorator.py b/temporalio/contrib/common/_heartbeat_decorator.py index 085a5d3e5..7c5b9193d 100644 --- a/temporalio/contrib/common/_heartbeat_decorator.py +++ b/temporalio/contrib/common/_heartbeat_decorator.py @@ -8,7 +8,7 @@ F = TypeVar("F", bound=Callable[..., Awaitable[Any]]) -def _auto_heartbeater(fn: F) -> F: +def auto_heartbeater(fn: F) -> F: """Decorator that heartbeats at half the activity's heartbeat timeout.""" @wraps(fn) diff --git a/temporalio/contrib/openai_agents/_invoke_model_activity.py b/temporalio/contrib/openai_agents/_invoke_model_activity.py index d38ffc500..1155d5180 100644 --- a/temporalio/contrib/openai_agents/_invoke_model_activity.py +++ b/temporalio/contrib/openai_agents/_invoke_model_activity.py @@ -43,7 +43,7 @@ from typing_extensions import Required, TypedDict from temporalio import activity -from temporalio.contrib.common._heartbeat_decorator import _auto_heartbeater +from temporalio.contrib.common._heartbeat_decorator import auto_heartbeater from temporalio.contrib.workflow_streams import WorkflowStreamClient from temporalio.exceptions import ApplicationError @@ -314,7 +314,7 @@ def __init__(self, model_provider: ModelProvider | None = None): ) @activity.defn - @_auto_heartbeater + @auto_heartbeater async def invoke_model_activity(self, input: ActivityModelInput) -> ModelResponse: """Activity that invokes a model with the given input.""" model = self._model_provider.get_model(input.get("model_name")) @@ -337,7 +337,7 @@ async def invoke_model_activity(self, input: ActivityModelInput) -> ModelRespons _raise_for_openai_status(e) @activity.defn - @_auto_heartbeater + @auto_heartbeater async def invoke_model_activity_streaming( self, input: StreamingActivityModelInput ) -> list[TResponseStreamEvent]: @@ -357,7 +357,7 @@ async def invoke_model_activity_streaming( ``streaming_topic`` so external consumers (UIs, tracing, etc.) can observe events as they arrive. - Heartbeats run on a background task via ``_auto_heartbeater`` so + Heartbeats run on a background task via ``auto_heartbeater`` so long initial-token latency or long pauses between chunks do not trip ``heartbeat_timeout``. """ diff --git a/temporalio/contrib/strands/_model_activity.py b/temporalio/contrib/strands/_model_activity.py index c33266f49..7bb4e8338 100644 --- a/temporalio/contrib/strands/_model_activity.py +++ b/temporalio/contrib/strands/_model_activity.py @@ -7,7 +7,7 @@ from strands.types.streaming import StreamEvent from temporalio import activity -from temporalio.contrib.common._heartbeat_decorator import _auto_heartbeater +from temporalio.contrib.common._heartbeat_decorator import auto_heartbeater from temporalio.contrib.workflow_streams import WorkflowStreamClient @@ -38,13 +38,13 @@ def __init__(self, model: Model) -> None: self._model = model @activity.defn - @_auto_heartbeater + @auto_heartbeater async def invoke_model(self, input: _InvokeModelInput) -> list[StreamEvent]: """Run the model and return its stream events as a list.""" return [event async for event in _stream(self._model, input)] @activity.defn - @_auto_heartbeater + @auto_heartbeater async def invoke_model_streaming( self, input: _StreamingInvokeModelInput ) -> list[StreamEvent]: From 4c8681d3b4416a4931c4e4e8a039ff5a736cbaa2 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Mon, 18 May 2026 14:48:06 -0700 Subject: [PATCH 26/46] contrib/strands: drop redundant passthrough entries `pydantic` and `temporalio.contrib.strands` are already covered by the SDK's default passthrough (via `pydantic` and `temporalio` in `passthrough_modules_with_temporal`). Update the stale comment in _temporal_mcp_client.py that explained the redundant entry. --- temporalio/contrib/strands/_plugin.py | 7 ++----- temporalio/contrib/strands/_temporal_mcp_client.py | 5 ++--- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/temporalio/contrib/strands/_plugin.py b/temporalio/contrib/strands/_plugin.py index c67b52825..e70172e98 100644 --- a/temporalio/contrib/strands/_plugin.py +++ b/temporalio/contrib/strands/_plugin.py @@ -127,11 +127,8 @@ def _workflow_runner(runner: WorkflowRunner | None) -> WorkflowRunner: "strands", "strands_tools", "mcp", - "temporalio.contrib.strands", - # The SDK's default passthrough already includes ``pydantic`` because - # it lazy-imports inside some classes; extend that to its compiled - # validation core and its ``Annotated`` helper for the same reason. - "pydantic", + # ``pydantic`` is already in the SDK default passthrough; extend it + # to its compiled validation core and ``Annotated`` helper. "pydantic_core", "annotated_types", ), diff --git a/temporalio/contrib/strands/_temporal_mcp_client.py b/temporalio/contrib/strands/_temporal_mcp_client.py index baaf09a72..0272fd09e 100644 --- a/temporalio/contrib/strands/_temporal_mcp_client.py +++ b/temporalio/contrib/strands/_temporal_mcp_client.py @@ -31,9 +31,8 @@ class _CallToolArgs: # Server name -> cached tool list. Populated by TemporalMCPClient._populate_cache # at worker startup and read by TemporalMCPClient.load_tools() inside the -# workflow sandbox. ``StrandsPlugin`` adds ``temporalio.contrib.strands`` to -# sandbox passthrough so this dict is shared between worker process and -# workflow execution. +# workflow sandbox. ``temporalio`` is in the SDK's default sandbox passthrough, +# so this dict is shared between worker process and workflow execution. _TOOL_CACHE: dict[str, list[_MCPToolInfo]] = {} From 06b60123c8e9c8bf5ad20779c07b03f65a2271c4 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Tue, 19 May 2026 12:17:43 -0700 Subject: [PATCH 27/46] contrib/strands: support multiple models and MCP servers per worker Replace the singular model= / mcp_clients=[...] plugin args with name-keyed dicts: StrandsPlugin(models={name: factory}, mcp_clients={name: transport}). TemporalModel and TemporalMCPClient become pure workflow-side handles that reference the worker registration by name and carry only per-call activity options. A single pair of model activities now dispatches to any number of backing models by resolving model_name from the activity input. --- temporalio/contrib/strands/README.md | 88 +++++++++------ temporalio/contrib/strands/_model_activity.py | 32 ++++-- temporalio/contrib/strands/_plugin.py | 55 ++++----- .../contrib/strands/_temporal_mcp_client.py | 104 +++++++++--------- temporalio/contrib/strands/_temporal_model.py | 31 +++--- tests/contrib/strands/test_hooks.py | 28 ++--- tests/contrib/strands/test_interrupt.py | 28 ++--- tests/contrib/strands/test_mcp.py | 51 +++++---- tests/contrib/strands/test_model.py | 14 +-- tests/contrib/strands/test_model_streaming.py | 16 ++- tests/contrib/strands/test_tool.py | 38 ++++--- 11 files changed, 269 insertions(+), 216 deletions(-) diff --git a/temporalio/contrib/strands/README.md b/temporalio/contrib/strands/README.md index 9a4f1fd7b..edf16d6aa 100644 --- a/temporalio/contrib/strands/README.md +++ b/temporalio/contrib/strands/README.md @@ -19,19 +19,20 @@ import asyncio from datetime import timedelta from strands import Agent +from strands.models.bedrock import BedrockModel from temporalio import workflow from temporalio.client import Client from temporalio.contrib.strands import StrandsPlugin, TemporalModel from temporalio.worker import Worker -MODEL = TemporalModel(start_to_close_timeout=timedelta(seconds=60)) - - @workflow.defn class MyWorkflow: def __init__(self) -> None: - self.agent = Agent(model=MODEL) + model = TemporalModel( + model_name="bedrock", start_to_close_timeout=timedelta(seconds=60) + ) + self.agent = Agent(model=model) @workflow.run async def run(self, prompt: str) -> str: @@ -45,7 +46,7 @@ async def main() -> None: client, task_queue="strands", workflows=[MyWorkflow], - plugins=[StrandsPlugin(model=MODEL)], + plugins=[StrandsPlugin(models={"bedrock": lambda: BedrockModel()})], ) await worker.run() @@ -81,19 +82,34 @@ if __name__ == "__main__": Note: Use `agent.invoke_async(message)` instead of `agent(message)`. The synchronous form spawns a worker thread, which the workflow sandbox blocks. -## Model +## Models -`TemporalModel` defaults to `BedrockModel()`. To use a different model (or a different `BedrockModel` configuration), pass a `model_factory` lambda. The plugin calls it once at worker startup. +`StrandsPlugin(models=...)` takes a mapping of `name → factory`. Each factory is called lazily on first use (on the worker, outside the workflow sandbox) and the constructed model is cached for the worker's lifetime. Each `TemporalModel(model_name=...)` selects which factory to invoke. ```python from strands.models.anthropic import AnthropicModel +from strands.models.bedrock import BedrockModel -MODEL = TemporalModel( - model_factory=lambda: AnthropicModel(client_args={"api_key": "..."}), - start_to_close_timeout=timedelta(seconds=60), -) +MODELS = { + "claude": lambda: AnthropicModel(client_args={"api_key": "..."}), + "bedrock": lambda: BedrockModel(), +} + +# workflow +@workflow.defn +class MultiModelWorkflow: + def __init__(self) -> None: + claude = TemporalModel(model_name="claude", start_to_close_timeout=timedelta(seconds=60)) + bedrock = TemporalModel(model_name="bedrock", start_to_close_timeout=timedelta(seconds=60)) + self.agent_a = Agent(model=claude) + self.agent_b = Agent(model=bedrock) + +# worker +Worker(..., plugins=[StrandsPlugin(models=MODELS)]) ``` +`TemporalModel` is the per-call handle: each instance carries its own activity options (timeouts, retry policy, task queue, streaming topic) but dispatches to the shared model activity, which resolves `model_name` against the registered factories at runtime. A `model_name` not present in `models` raises `ValueError` inside the activity. + ## Retries The plugin disables Strands' built-in `ModelRetryStrategy` so retries are handled exclusively by Temporal. Configure retries via `RetryPolicy` on the activity options accepted by `TemporalModel`, `workflow.activity_as_tool`, `workflow.activity_as_hook`, and `TemporalMCPClient`: @@ -101,7 +117,8 @@ The plugin disables Strands' built-in `ModelRetryStrategy` so retries are handle ```python from temporalio.common import RetryPolicy -MODEL = TemporalModel( +TemporalModel( + model_name="bedrock", start_to_close_timeout=timedelta(seconds=60), retry_policy=RetryPolicy(maximum_attempts=3), ) @@ -127,7 +144,8 @@ class PersonInfo(BaseModel): @workflow.defn class MyWorkflow: def __init__(self) -> None: - self.agent = Agent(model=MODEL, structured_output_model=PersonInfo) + model = TemporalModel(model_name="bedrock", start_to_close_timeout=timedelta(seconds=60)) + self.agent = Agent(model=model, structured_output_model=PersonInfo) @workflow.run async def run(self, prompt: str) -> PersonInfo: @@ -142,14 +160,13 @@ class MyWorkflow: To forward model chunks to external consumers, pass `streaming_topic="..."` to `TemporalModel` and host a `WorkflowStream` on the workflow. Each `StreamEvent` is published on the named topic from inside the model activity; subscribers read via `WorkflowStreamClient`. Chunks are batched on `streaming_batch_interval` (default 100ms). ```python -MODEL = TemporalModel(streaming_topic="events") - # workflow @workflow.defn class MyWorkflow: def __init__(self) -> None: self.stream = WorkflowStream() - self.agent = Agent(model=MODEL) + model = TemporalModel(model_name="bedrock", streaming_topic="events") + self.agent = Agent(model=model) # client async for item in WorkflowStreamClient.create(client, workflow_id).subscribe( @@ -184,7 +201,7 @@ agent = Agent(tools=[ Worker( ..., activities=[fetch_user, current_time_activity], - plugins=[StrandsPlugin(model=MODEL)], + plugins=[StrandsPlugin(models=MODELS)], ) ``` @@ -250,7 +267,8 @@ class ApprovalHook(HookProvider): @workflow.defn class MyWorkflow: def __init__(self) -> None: - self.agent = Agent(model=MODEL, tools=[delete_thing], hooks=[ApprovalHook()]) + model = TemporalModel(model_name="bedrock", start_to_close_timeout=timedelta(seconds=60)) + self.agent = Agent(model=model, tools=[delete_thing], hooks=[ApprovalHook()]) self._approval: str | None = None @workflow.signal @@ -285,6 +303,7 @@ class ChatInput: @workflow.defn class ChatWorkflow: def __init__(self) -> None: + self._model = TemporalModel(model_name="bedrock", start_to_close_timeout=timedelta(seconds=60)) self._pending: list[str] = [] self._done = False @@ -298,7 +317,7 @@ class ChatWorkflow: @workflow.run async def run(self, input: ChatInput) -> None: - agent = Agent(model=MODEL, messages=list(input.messages)) + agent = Agent(model=self._model, messages=list(input.messages)) while True: await workflow.wait_condition(lambda: self._pending or self._done) if self._done: @@ -310,28 +329,33 @@ class ChatWorkflow: ## MCP -Construct `TemporalMCPClient` once at module level and reference the same instance from both the plugin (which registers a per-server `{server}-call-tool` activity and connects at worker startup to discover tools) and `Agent(tools=[...])`: +`StrandsPlugin(mcp_clients=...)` takes a mapping of `name → transport factory`, mirroring the `models=` pattern. The plugin registers a per-server `{name}-call-tool` activity and connects at worker startup to enumerate tools. Workflow-side, `TemporalMCPClient(server="name")` is a pure handle: it references the server by name and carries the per-call activity options. ```python from mcp import StdioServerParameters, stdio_client from temporalio.contrib.strands import TemporalMCPClient -ECHO = TemporalMCPClient( - server="echo", - transport_factory=lambda: stdio_client( - StdioServerParameters(command="...", args=[...]), - ), - start_to_close_timeout=timedelta(seconds=30), -) - # workflow -agent = Agent(tools=[ECHO]) +@workflow.defn +class MyWorkflow: + def __init__(self) -> None: + echo = TemporalMCPClient(server="echo", start_to_close_timeout=timedelta(seconds=30)) + self.agent = Agent(tools=[echo]) # worker -Worker(..., plugins=[StrandsPlugin(mcp_clients=[ECHO])]) +Worker( + ..., + plugins=[StrandsPlugin( + mcp_clients={ + "echo": lambda: stdio_client( + StdioServerParameters(command="...", args=[...]), + ), + }, + )], +) ``` -The plugin connects to the MCP server once at worker startup to enumerate tools. The schema is frozen for the worker's lifetime; restart workers to pick up MCP-server changes. If the MCP server is unavailable at startup, the worker fails to start. +The plugin connects to each MCP server once at worker startup to enumerate tools. The schema is frozen for the worker's lifetime; restart workers to pick up MCP-server changes. If a server is unavailable at startup, the worker fails to start. ## Observability @@ -347,7 +371,7 @@ Worker( client, task_queue="strands", workflows=[MyWorkflow], - plugins=[StrandsPlugin(model=MODEL), OpenTelemetryPlugin()], + plugins=[StrandsPlugin(models=MODELS), OpenTelemetryPlugin()], ) ``` diff --git a/temporalio/contrib/strands/_model_activity.py b/temporalio/contrib/strands/_model_activity.py index 7bb4e8338..f0626dce9 100644 --- a/temporalio/contrib/strands/_model_activity.py +++ b/temporalio/contrib/strands/_model_activity.py @@ -1,4 +1,4 @@ -from collections.abc import AsyncIterable +from collections.abc import AsyncIterable, Callable from dataclasses import dataclass from datetime import timedelta from typing import Any @@ -17,6 +17,7 @@ # through unchanged to ``Model.stream`` which accepts the raw dicts. @dataclass class _InvokeModelInput: + model_name: str messages: Any tool_specs: Any = None system_prompt: str | None = None @@ -31,31 +32,44 @@ class _StreamingInvokeModelInput(_InvokeModelInput): class ModelActivity: - """Holds the user-supplied model and exposes the model activities.""" + """Holds the registered model factories and exposes the model activities.""" - def __init__(self, model: Model) -> None: - """Store the model that activities will invoke.""" - self._model = model + def __init__(self, factories: dict[str, Callable[[], Model]]) -> None: + """Store the factories; models are constructed lazily on first use.""" + self._factories = factories + self._models: dict[str, Model] = {} + + def _get_model(self, name: str) -> Model: + if name not in self._models: + if name not in self._factories: + raise ValueError( + f"Unknown model name {name!r}. " + f"Known: {sorted(self._factories)}" + ) + self._models[name] = self._factories[name]() + return self._models[name] @activity.defn @auto_heartbeater async def invoke_model(self, input: _InvokeModelInput) -> list[StreamEvent]: - """Run the model and return its stream events as a list.""" - return [event async for event in _stream(self._model, input)] + """Run the named model and return its stream events as a list.""" + model = self._get_model(input.model_name) + return [event async for event in _stream(model, input)] @activity.defn @auto_heartbeater async def invoke_model_streaming( self, input: _StreamingInvokeModelInput ) -> list[StreamEvent]: - """Run the model and publish each stream event to a WorkflowStream.""" + """Run the named model and publish each stream event to a WorkflowStream.""" + model = self._get_model(input.model_name) events: list[StreamEvent] = [] stream = WorkflowStreamClient.from_within_activity( batch_interval=timedelta(seconds=input.streaming_batch_interval_seconds), ) topic = stream.topic(input.streaming_topic) async with stream: - async for event in _stream(self._model, input): + async for event in _stream(model, input): events.append(event) topic.publish(event) return events diff --git a/temporalio/contrib/strands/_plugin.py b/temporalio/contrib/strands/_plugin.py index e70172e98..83d421b2f 100644 --- a/temporalio/contrib/strands/_plugin.py +++ b/temporalio/contrib/strands/_plugin.py @@ -5,6 +5,8 @@ import strands.agent.agent as _strands_agent import strands.models.model as _strands_model +from strands.models import Model +from strands.tools.mcp.mcp_types import MCPTransport from temporalio.contrib.pydantic import pydantic_data_converter from temporalio.converter import DataConverter, DefaultPayloadConverter @@ -12,8 +14,12 @@ from temporalio.worker import WorkflowRunner from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner -from ._temporal_mcp_client import TemporalMCPClient -from ._temporal_model import TemporalModel +from ._model_activity import ModelActivity +from ._temporal_mcp_client import ( + _build_call_tool_activity, + _clear_cache, + _populate_cache, +) # Force Strands' base Model.count_tokens to avoid tiktoken, which lazily downloads # an encoding file. Use the default chars-per-token heuristic instead (deterministic). @@ -67,46 +73,43 @@ class StrandsPlugin(SimplePlugin): workflow code), and swaps in ``pydantic_data_converter`` so structured outputs serialize. - When ``model`` is supplied, calls its ``model_factory`` once on the worker - to construct the real model, then registers the model invocation activities - against it. The same :class:`TemporalModel` is also passed to - ``Agent(model=...)`` inside the workflow. + When ``models`` is supplied, registers a single pair of model invocation + activities; each call carries the chosen ``model_name`` in its input and + the worker resolves it against the factories. Factories are called lazily + on first use, then cached for the worker's lifetime. Use the same name in + ``TemporalModel(model_name=...)`` inside the workflow. - When ``mcp_clients`` is supplied, registers per-server ``{server}-call-tool`` - activities and, at worker startup, connects to each MCP server and caches - its tool list. Workflow-side ``TemporalMCPClient.load_tools()`` reads from - the cache. The plugin raises if any two clients share the same ``server``. + When ``mcp_clients`` is supplied, registers a per-server + ``{server}-call-tool`` activity for each entry and, at worker startup, + connects to each MCP server to cache its tool list. Workflow-side + ``TemporalMCPClient(server="...").load_tools()`` reads from the cache. """ def __init__( self, *, - model: TemporalModel | None = None, - mcp_clients: list[TemporalMCPClient] = [], + models: dict[str, Callable[[], Model]] | None = None, + mcp_clients: dict[str, Callable[[], MCPTransport]] | None = None, ) -> None: - """Build the plugin from an optional model and MCP client list.""" + """Build the plugin from optional model and MCP transport factories.""" activities: list[Callable] = [] - if model is not None: - ma = model._build_activity() + if models: + ma = ModelActivity(models) activities.extend([ma.invoke_model, ma.invoke_model_streaming]) - names = [c.server for c in mcp_clients] - if len(names) != len(set(names)): - raise ValueError( - "Duplicate MCP server names in mcp_clients; each must be unique." - ) - for c in mcp_clients: - activities.extend(c._get_activities()) + mcp_clients = mcp_clients or {} + for server, transport_factory in mcp_clients.items(): + activities.append(_build_call_tool_activity(server, transport_factory)) @asynccontextmanager async def run_context() -> AsyncGenerator[None, None]: - for c in mcp_clients: - await c._populate_cache() + for server, transport_factory in mcp_clients.items(): + await _populate_cache(server, transport_factory) try: yield finally: - for c in mcp_clients: - c._clear_cache() + for server in mcp_clients: + _clear_cache(server) super().__init__( "aws.StrandsPlugin", diff --git a/temporalio/contrib/strands/_temporal_mcp_client.py b/temporalio/contrib/strands/_temporal_mcp_client.py index 0272fd09e..0ec37200f 100644 --- a/temporalio/contrib/strands/_temporal_mcp_client.py +++ b/temporalio/contrib/strands/_temporal_mcp_client.py @@ -29,29 +29,29 @@ class _CallToolArgs: tool_use_id: str = "" -# Server name -> cached tool list. Populated by TemporalMCPClient._populate_cache -# at worker startup and read by TemporalMCPClient.load_tools() inside the -# workflow sandbox. ``temporalio`` is in the SDK's default sandbox passthrough, -# so this dict is shared between worker process and workflow execution. +# Server name -> cached tool list. Populated by ``_populate_cache`` at worker +# startup and read by ``TemporalMCPClient.load_tools()`` inside the workflow +# sandbox. ``temporalio`` is in the SDK's default sandbox passthrough, so this +# dict is shared between worker process and workflow execution. _TOOL_CACHE: dict[str, list[_MCPToolInfo]] = {} class TemporalMCPClient(ToolProvider): - """An MCP server reference for use in both worker and workflow contexts. + """Workflow-side handle to an MCP server registered on the worker. - Construct once at module level. Pass to ``StrandsPlugin(mcp_clients=[...])`` - on the worker (which registers the ``{server}-call-tool`` activity and runs - ``list_tools`` at worker startup), and to ``Agent(tools=[...])`` inside the - workflow (which adds the discovered tools to the agent's registry). + The transport factory and tool discovery live worker-side via + ``StrandsPlugin(mcp_clients={"server": lambda: ...})``. This handle only + carries the server name (which selects the registered factory) and the + per-call activity options. - Construction does no I/O. The actual MCP connection happens worker-side at - plugin startup; each tool call later runs as a Temporal activity. + Construct once at module level and pass to ``Agent(tools=[...])`` inside + the workflow. Multiple handles may reference the same server name with + different activity options. """ def __init__( self, server: str, - transport_factory: Callable[[], MCPTransport], *, task_queue: str | None = None, schedule_to_close_timeout: timedelta | None = None, @@ -64,9 +64,8 @@ def __init__( summary: str | None = None, priority: Priority = Priority.default, ) -> None: - """Configure the server name, transport factory, and activity options.""" + """Configure the server name and activity options.""" self._server = server - self._transport_factory = transport_factory self._options: dict[str, Any] = { "task_queue": task_queue, "schedule_to_close_timeout": schedule_to_close_timeout, @@ -100,43 +99,48 @@ def remove_consumer(self, consumer_id: Any, **_kwargs: Any) -> None: """No-op; consumer tracking is handled by the underlying MCP client.""" return None - async def _populate_cache(self) -> None: - """Connect to the MCP server, list tools, fill ``_TOOL_CACHE``.""" - client = MCPClient(self._transport_factory) - try: - infos: list[_MCPToolInfo] = [] - for tool in await client.load_tools(): - if not isinstance(tool, MCPAgentTool): - continue - infos.append( - _MCPToolInfo( - name=tool.mcp_tool.name, - description=tool.mcp_tool.description or "", - input_schema=tool.mcp_tool.inputSchema, - output_schema=tool.mcp_tool.outputSchema, - ) + +async def _populate_cache( + server: str, transport_factory: Callable[[], MCPTransport] +) -> None: + """Connect to the MCP server, list tools, fill ``_TOOL_CACHE``.""" + client = MCPClient(transport_factory) + try: + infos: list[_MCPToolInfo] = [] + for tool in await client.load_tools(): + if not isinstance(tool, MCPAgentTool): + continue + infos.append( + _MCPToolInfo( + name=tool.mcp_tool.name, + description=tool.mcp_tool.description or "", + input_schema=tool.mcp_tool.inputSchema, + output_schema=tool.mcp_tool.outputSchema, ) - _TOOL_CACHE[self._server] = infos + ) + _TOOL_CACHE[server] = infos + finally: + client.stop(None, None, None) + + +def _clear_cache(server: str) -> None: + _TOOL_CACHE.pop(server, None) + + +def _build_call_tool_activity( + server: str, transport_factory: Callable[[], MCPTransport] +) -> Callable: + @activity.defn(name=f"{server}-call-tool") + async def call_tool(args: _CallToolArgs) -> MCPToolResult: + client = MCPClient(transport_factory) + client.start() + try: + return await client.call_tool_async( + tool_use_id=args.tool_use_id, + name=args.tool_name, + arguments=args.arguments, + ) finally: client.stop(None, None, None) - def _clear_cache(self) -> None: - _TOOL_CACHE.pop(self._server, None) - - def _get_activities(self) -> Sequence[Callable]: - transport_factory = self._transport_factory - - @activity.defn(name=f"{self._server}-call-tool") - async def call_tool(args: _CallToolArgs) -> MCPToolResult: - client = MCPClient(transport_factory) - client.start() - try: - return await client.call_tool_async( - tool_use_id=args.tool_use_id, - name=args.tool_name, - arguments=args.arguments, - ) - finally: - client.stop(None, None, None) - - return [call_tool] + return call_tool diff --git a/temporalio/contrib/strands/_temporal_model.py b/temporalio/contrib/strands/_temporal_model.py index f547a266d..f77e68ad8 100644 --- a/temporalio/contrib/strands/_temporal_model.py +++ b/temporalio/contrib/strands/_temporal_model.py @@ -1,9 +1,8 @@ -from collections.abc import AsyncIterable, Callable +from collections.abc import AsyncIterable from datetime import timedelta from typing import Any from strands.models import Model -from strands.models.bedrock import BedrockModel from strands.types.content import Messages, SystemContentBlock from strands.types.streaming import StreamEvent from strands.types.tools import ToolChoice, ToolSpec @@ -22,15 +21,14 @@ class TemporalModel(Model): """A Strands :class:`Model` that runs ``stream()`` as a Temporal activity. - ``model_factory`` is called once on the worker (when the plugin is - constructed) to produce the real model used inside the activity. Defaults - to :class:`strands.models.bedrock.BedrockModel`, matching Strands' default. - Construction of this :class:`TemporalModel` itself does no I/O, so it is - safe to instantiate at module level — the lambda is just stored. + ``model_name`` selects which factory the plugin will invoke worker-side; it + must match a key in ``StrandsPlugin(models={...})``. Construction of this + :class:`TemporalModel` itself does no I/O, so it is safe to instantiate at + module level. - Pass the same instance to ``StrandsPlugin(model=...)`` (so the plugin can - register the model's activities) and to ``Agent(model=...)`` inside the - workflow (so the agent dispatches through that activity). + Pass this instance to ``Agent(model=...)`` inside the workflow; each call + dispatches through the registered model activity with ``model_name`` in + the input, and the worker resolves it against the plugin's factories. When ``streaming_topic`` is set, each ``StreamEvent`` is also published to the named topic on the workflow's @@ -40,8 +38,8 @@ class TemporalModel(Model): def __init__( self, + model_name: str, *, - model_factory: Callable[[], Model] = BedrockModel, task_queue: str | None = None, schedule_to_close_timeout: timedelta | None = None, schedule_to_start_timeout: timedelta | None = None, @@ -55,8 +53,8 @@ def __init__( streaming_topic: str | None = None, streaming_batch_interval: timedelta = timedelta(milliseconds=100), ) -> None: - """Configure the model factory, activity options, and streaming settings.""" - self._model_factory = model_factory + """Configure the model name, activity options, and streaming settings.""" + self._model_name = model_name self._streaming_topic = streaming_topic self._streaming_batch_interval = streaming_batch_interval self._options: dict[str, Any] = { @@ -72,11 +70,8 @@ def __init__( "priority": priority, } - def _build_activity(self) -> ModelActivity: - return ModelActivity(self._model_factory()) - def update_config(self, **_model_config: Any) -> None: - """No-op; the real model is configured worker-side via ``model_factory``.""" + """No-op; the real model is configured worker-side via the plugin's factories.""" return None def get_config(self) -> dict[str, Any]: @@ -107,6 +102,7 @@ async def stream( events = await workflow.execute_activity_method( ModelActivity.invoke_model_streaming, _StreamingInvokeModelInput( + model_name=self._model_name, messages=messages, tool_specs=tool_specs, system_prompt=system_prompt, @@ -121,6 +117,7 @@ async def stream( events = await workflow.execute_activity_method( ModelActivity.invoke_model, _InvokeModelInput( + model_name=self._model_name, messages=messages, tool_specs=tool_specs, system_prompt=system_prompt, diff --git a/tests/contrib/strands/test_hooks.py b/tests/contrib/strands/test_hooks.py index 3511cb3d5..12a197ebc 100644 --- a/tests/contrib/strands/test_hooks.py +++ b/tests/contrib/strands/test_hooks.py @@ -48,22 +48,15 @@ def _sync_log(self, event: AfterToolCallEvent) -> None: self.fired_events.append(event.tool_use["name"]) -MODEL = TemporalModel( - model_factory=lambda: MockModel( - [ - {"name": "echo", "input": {"text": "hi"}}, - "Done!", - ] - ), - start_to_close_timeout=timedelta(seconds=15), -) - - @workflow.defn class HooksWorkflow: def __init__(self) -> None: + model = TemporalModel( + model_name="mock", + start_to_close_timeout=timedelta(seconds=15), + ) self.hook = AuditHook() - self.agent = Agent(model=MODEL, tools=[echo], hooks=[self.hook]) + self.agent = Agent(model=model, tools=[echo], hooks=[self.hook]) @workflow.run async def run(self, prompt: str) -> list[str]: @@ -74,7 +67,16 @@ async def run(self, prompt: str) -> list[str]: async def test_hooks(client: Client): _AUDIT_LOG.clear() task_queue = "test_hooks" - plugin = StrandsPlugin(model=MODEL) + plugin = StrandsPlugin( + models={ + "mock": lambda: MockModel( + [ + {"name": "echo", "input": {"text": "hi"}}, + "Done!", + ] + ) + } + ) async with Worker( client, diff --git a/tests/contrib/strands/test_interrupt.py b/tests/contrib/strands/test_interrupt.py index 06b4f529e..b98e701b7 100644 --- a/tests/contrib/strands/test_interrupt.py +++ b/tests/contrib/strands/test_interrupt.py @@ -34,21 +34,14 @@ def _gate(self, event: BeforeToolCallEvent) -> None: event.cancel_tool = "denied" -MODEL = TemporalModel( - model_factory=lambda: MockModel( - [ - {"name": "delete_thing", "input": {"name": "foo"}}, - "Done!", - ] - ), - start_to_close_timeout=timedelta(seconds=15), -) - - @workflow.defn class InterruptWorkflow: def __init__(self) -> None: - self.agent = Agent(model=MODEL, tools=[delete_thing], hooks=[ApprovalHook()]) + model = TemporalModel( + model_name="mock", + start_to_close_timeout=timedelta(seconds=15), + ) + self.agent = Agent(model=model, tools=[delete_thing], hooks=[ApprovalHook()]) self._approval: str | None = None @workflow.signal @@ -72,7 +65,16 @@ async def run(self, prompt: str) -> str: async def test_interrupt(client: Client): task_queue = "test_interrupt" - plugin = StrandsPlugin(model=MODEL) + plugin = StrandsPlugin( + models={ + "mock": lambda: MockModel( + [ + {"name": "delete_thing", "input": {"name": "foo"}}, + "Done!", + ] + ) + } + ) async with Worker( client, diff --git a/tests/contrib/strands/test_mcp.py b/tests/contrib/strands/test_mcp.py index a0ae7b5ee..bfa3e1cff 100644 --- a/tests/contrib/strands/test_mcp.py +++ b/tests/contrib/strands/test_mcp.py @@ -17,32 +17,18 @@ from tests.contrib.strands.common import get_activities from tests.contrib.strands.mock_model import MockModel -ECHO = TemporalMCPClient( - server="echo", - transport_factory=lambda: stdio_client( - StdioServerParameters( - command=sys.executable, - args=[str(Path(__file__).parent / "echo_mcp_server.py")], - ) - ), - start_to_close_timeout=timedelta(seconds=30), -) - -MODEL = TemporalModel( - model_factory=lambda: MockModel( - [ - {"name": "echo", "input": {"message": "hello"}}, - "Done!", - ] - ), - start_to_close_timeout=timedelta(seconds=30), -) - - @workflow.defn class MCPWorkflow: def __init__(self) -> None: - self.agent = Agent(model=MODEL, tools=[ECHO]) + model = TemporalModel( + model_name="mock", + start_to_close_timeout=timedelta(seconds=30), + ) + echo = TemporalMCPClient( + server="echo", + start_to_close_timeout=timedelta(seconds=30), + ) + self.agent = Agent(model=model, tools=[echo]) @workflow.run async def run(self, prompt: str) -> str: @@ -52,7 +38,24 @@ async def run(self, prompt: str) -> str: async def test_mcp(client: Client): task_queue = "test_mcp" - plugin = StrandsPlugin(model=MODEL, mcp_clients=[ECHO]) + plugin = StrandsPlugin( + models={ + "mock": lambda: MockModel( + [ + {"name": "echo", "input": {"message": "hello"}}, + "Done!", + ] + ) + }, + mcp_clients={ + "echo": lambda: stdio_client( + StdioServerParameters( + command=sys.executable, + args=[str(Path(__file__).parent / "echo_mcp_server.py")], + ) + ), + }, + ) async with Worker( client, diff --git a/tests/contrib/strands/test_model.py b/tests/contrib/strands/test_model.py index 1a734c619..9ade66f0e 100644 --- a/tests/contrib/strands/test_model.py +++ b/tests/contrib/strands/test_model.py @@ -10,16 +10,14 @@ from tests.contrib.strands.common import get_activities from tests.contrib.strands.mock_model import MockModel -MODEL = TemporalModel( - model_factory=lambda: MockModel(["Done!"]), - start_to_close_timeout=timedelta(seconds=15), -) - - @workflow.defn class ModelWorkflow: def __init__(self) -> None: - self.agent = Agent(model=MODEL) + model = TemporalModel( + model_name="mock", + start_to_close_timeout=timedelta(seconds=15), + ) + self.agent = Agent(model=model) @workflow.run async def run(self, prompt: str) -> str: @@ -29,7 +27,7 @@ async def run(self, prompt: str) -> str: async def test_model(client: Client): task_queue = "test_model" - plugin = StrandsPlugin(model=MODEL) + plugin = StrandsPlugin(models={"mock": lambda: MockModel(["Done!"])}) async with Worker( client, diff --git a/tests/contrib/strands/test_model_streaming.py b/tests/contrib/strands/test_model_streaming.py index e35e30fe1..4b531e8cc 100644 --- a/tests/contrib/strands/test_model_streaming.py +++ b/tests/contrib/strands/test_model_streaming.py @@ -13,18 +13,16 @@ from tests.contrib.strands.common import get_activities from tests.contrib.strands.mock_model import MockModel -MODEL = TemporalModel( - model_factory=lambda: MockModel(["Done!"]), - start_to_close_timeout=timedelta(seconds=15), - streaming_topic="events", -) - - @workflow.defn class StreamingModelWorkflow: def __init__(self) -> None: self.stream = WorkflowStream() - self.agent = Agent(model=MODEL) + model = TemporalModel( + model_name="mock", + start_to_close_timeout=timedelta(seconds=15), + streaming_topic="events", + ) + self.agent = Agent(model=model) @workflow.run async def run(self, prompt: str) -> str: @@ -34,7 +32,7 @@ async def run(self, prompt: str) -> str: async def test_model_streaming(client: Client): task_queue = "test_model_streaming" - plugin = StrandsPlugin(model=MODEL) + plugin = StrandsPlugin(models={"mock": lambda: MockModel(["Done!"])}) workflow_id = f"test_model_streaming_{uuid4()}" async with Worker( diff --git a/tests/contrib/strands/test_tool.py b/tests/contrib/strands/test_tool.py index 8d01ad580..04ad95e65 100644 --- a/tests/contrib/strands/test_tool.py +++ b/tests/contrib/strands/test_tool.py @@ -26,24 +26,15 @@ async def current_time_activity() -> str: return current_time.current_time() -MODEL = TemporalModel( - model_factory=lambda: MockModel( - [ - {"name": "current_time", "input": {}}, - {"name": "calculator", "input": {"expression": "3111696 / 74088"}}, - {"name": "letter_counter", "input": {"word": "strawberry", "letter": "R"}}, - "Done!", - ] - ), - start_to_close_timeout=timedelta(seconds=15), -) - - @workflow.defn class ToolWorkflow: def __init__(self) -> None: + model = TemporalModel( + model_name="mock", + start_to_close_timeout=timedelta(seconds=15), + ) self.agent = Agent( - model=MODEL, + model=model, tools=[ calculator, activity_as_tool( @@ -62,7 +53,24 @@ async def run(self, prompt: str) -> str: async def test_tool(client: Client): task_queue = "test_tool" - plugin = StrandsPlugin(model=MODEL) + plugin = StrandsPlugin( + models={ + "mock": lambda: MockModel( + [ + {"name": "current_time", "input": {}}, + { + "name": "calculator", + "input": {"expression": "3111696 / 74088"}, + }, + { + "name": "letter_counter", + "input": {"word": "strawberry", "letter": "R"}, + }, + "Done!", + ] + ) + } + ) async with Worker( client, From f078cd1cc67e74e157e9385e7b4a43cc544b2417 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Tue, 19 May 2026 13:08:23 -0700 Subject: [PATCH 28/46] contrib/strands: introduce TemporalAgent, drop Agent monkey-patches TemporalAgent(Agent) is the primary user-facing class: it takes model="name" to select a factory registered with StrandsPlugin(models=...), accepts the per-call activity options, and forwards all other kwargs to Strands' Agent. Construction-time validation of retry_strategy and overrides of take_snapshot/load_snapshot replace the previous Agent.__init__ and snapshot monkey-patches in StrandsPlugin. TemporalModel is no longer exported; it remains as internal plumbing for TemporalAgent. --- temporalio/contrib/strands/README.md | 84 ++++++++++-------- temporalio/contrib/strands/__init__.py | 4 +- temporalio/contrib/strands/_plugin.py | 43 +--------- temporalio/contrib/strands/_temporal_agent.py | 85 +++++++++++++++++++ tests/contrib/strands/test_hooks.py | 13 +-- tests/contrib/strands/test_interrupt.py | 11 +-- tests/contrib/strands/test_mcp.py | 14 +-- tests/contrib/strands/test_model.py | 10 +-- tests/contrib/strands/test_model_streaming.py | 9 +- tests/contrib/strands/test_tool.py | 11 +-- 10 files changed, 168 insertions(+), 116 deletions(-) create mode 100644 temporalio/contrib/strands/_temporal_agent.py diff --git a/temporalio/contrib/strands/README.md b/temporalio/contrib/strands/README.md index edf16d6aa..98b88a1c3 100644 --- a/temporalio/contrib/strands/README.md +++ b/temporalio/contrib/strands/README.md @@ -18,21 +18,20 @@ uv add temporalio[strands] import asyncio from datetime import timedelta -from strands import Agent from strands.models.bedrock import BedrockModel from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.strands import StrandsPlugin, TemporalModel +from temporalio.contrib.strands import StrandsPlugin, TemporalAgent from temporalio.worker import Worker + @workflow.defn class MyWorkflow: def __init__(self) -> None: - model = TemporalModel( - model_name="bedrock", start_to_close_timeout=timedelta(seconds=60) + self.agent = TemporalAgent( + model="bedrock", start_to_close_timeout=timedelta(seconds=60) ) - self.agent = Agent(model=model) @workflow.run async def run(self, prompt: str) -> str: @@ -84,7 +83,7 @@ Note: Use `agent.invoke_async(message)` instead of `agent(message)`. The synchro ## Models -`StrandsPlugin(models=...)` takes a mapping of `name → factory`. Each factory is called lazily on first use (on the worker, outside the workflow sandbox) and the constructed model is cached for the worker's lifetime. Each `TemporalModel(model_name=...)` selects which factory to invoke. +`StrandsPlugin(models=...)` takes a mapping of `name → factory`. Each factory is called lazily on first use (on the worker, outside the workflow sandbox) and the constructed model is cached for the worker's lifetime. `TemporalAgent(model="name", ...)` selects which factory to invoke and carries the activity options for that agent's model calls. ```python from strands.models.anthropic import AnthropicModel @@ -99,40 +98,38 @@ MODELS = { @workflow.defn class MultiModelWorkflow: def __init__(self) -> None: - claude = TemporalModel(model_name="claude", start_to_close_timeout=timedelta(seconds=60)) - bedrock = TemporalModel(model_name="bedrock", start_to_close_timeout=timedelta(seconds=60)) - self.agent_a = Agent(model=claude) - self.agent_b = Agent(model=bedrock) + self.agent_a = TemporalAgent(model="claude", start_to_close_timeout=timedelta(seconds=60)) + self.agent_b = TemporalAgent(model="bedrock", start_to_close_timeout=timedelta(seconds=60)) # worker Worker(..., plugins=[StrandsPlugin(models=MODELS)]) ``` -`TemporalModel` is the per-call handle: each instance carries its own activity options (timeouts, retry policy, task queue, streaming topic) but dispatches to the shared model activity, which resolves `model_name` against the registered factories at runtime. A `model_name` not present in `models` raises `ValueError` inside the activity. +Each `TemporalAgent` carries its own activity options (timeouts, retry policy, task queue, streaming topic) and dispatches to the shared model activity, which resolves the model name against the registered factories at runtime. A name not present in `models` raises `ValueError` inside the activity. ## Retries -The plugin disables Strands' built-in `ModelRetryStrategy` so retries are handled exclusively by Temporal. Configure retries via `RetryPolicy` on the activity options accepted by `TemporalModel`, `workflow.activity_as_tool`, `workflow.activity_as_hook`, and `TemporalMCPClient`: +`TemporalAgent` disables Strands' built-in `ModelRetryStrategy` so retries are handled exclusively by Temporal. Configure retries via `retry_policy` on `TemporalAgent`, and on the activity options accepted by `workflow.activity_as_tool`, `workflow.activity_as_hook`, and `TemporalMCPClient`: ```python from temporalio.common import RetryPolicy -TemporalModel( - model_name="bedrock", +TemporalAgent( + model="bedrock", start_to_close_timeout=timedelta(seconds=60), retry_policy=RetryPolicy(maximum_attempts=3), ) ``` -Passing `retry_strategy=...` to `Agent(...)` raises `ValueError`; remove the argument (or pass `retry_strategy=None`) and put the retry config on the activity options instead. +Passing `retry_strategy=...` to `TemporalAgent(...)` raises `ValueError`; remove the argument (or pass `retry_strategy=None`) and put the retry config on the activity options instead. ## Snapshots -The plugin disables `Agent.take_snapshot()` and `Agent.load_snapshot()`. Temporal's event history already persists workflow state durably at a finer granularity than Strands snapshots, so calling either inside a workflow is redundant. Both methods raise `NotImplementedError`. +`TemporalAgent.take_snapshot()` and `TemporalAgent.load_snapshot()` raise `NotImplementedError`. Temporal's event history already persists workflow state durably at a finer granularity than Strands snapshots, so calling either inside a workflow is redundant. ## Structured Output -Pass a Pydantic model to `Agent(structured_output_model=...)`. Strands routes the call through `stream()` as a synthetic tool, so it dispatches via the model activity like any other invocation. The result is available as `result.structured_output` and can be returned directly from the workflow — `StrandsPlugin` defaults to [`pydantic_data_converter`](../pydantic), so Pydantic types serialize across the activity and workflow boundary. +Pass a Pydantic model as `structured_output_model=`. Strands routes the call through `stream()` as a synthetic tool, so it dispatches via the model activity like any other invocation. The result is available as `result.structured_output` and can be returned directly from the workflow — `StrandsPlugin` defaults to [`pydantic_data_converter`](../pydantic), so Pydantic types serialize across the activity and workflow boundary. ```python from pydantic import BaseModel @@ -144,8 +141,11 @@ class PersonInfo(BaseModel): @workflow.defn class MyWorkflow: def __init__(self) -> None: - model = TemporalModel(model_name="bedrock", start_to_close_timeout=timedelta(seconds=60)) - self.agent = Agent(model=model, structured_output_model=PersonInfo) + self.agent = TemporalAgent( + model="bedrock", + start_to_close_timeout=timedelta(seconds=60), + structured_output_model=PersonInfo, + ) @workflow.run async def run(self, prompt: str) -> PersonInfo: @@ -153,11 +153,9 @@ class MyWorkflow: return result.structured_output ``` -`TemporalModel.structured_output()` called directly is not supported — always go through `Agent(structured_output_model=...)`. - ## Streaming -To forward model chunks to external consumers, pass `streaming_topic="..."` to `TemporalModel` and host a `WorkflowStream` on the workflow. Each `StreamEvent` is published on the named topic from inside the model activity; subscribers read via `WorkflowStreamClient`. Chunks are batched on `streaming_batch_interval` (default 100ms). +To forward model chunks to external consumers, pass `streaming_topic="..."` to `TemporalAgent` and host a `WorkflowStream` on the workflow. Each `StreamEvent` is published on the named topic from inside the model activity; subscribers read via `WorkflowStreamClient`. Chunks are batched on `streaming_batch_interval` (default 100ms). ```python # workflow @@ -165,8 +163,7 @@ To forward model chunks to external consumers, pass `streaming_topic="..."` to ` class MyWorkflow: def __init__(self) -> None: self.stream = WorkflowStream() - model = TemporalModel(model_name="bedrock", streaming_topic="events") - self.agent = Agent(model=model) + self.agent = TemporalAgent(model="bedrock", streaming_topic="events") # client async for item in WorkflowStreamClient.create(client, workflow_id).subscribe( @@ -192,10 +189,14 @@ async def current_time_activity() -> str: return current_time.current_time() # workflow -agent = Agent(tools=[ - strands_workflow.activity_as_tool(fetch_user, start_to_close_timeout=timedelta(seconds=30)), - strands_workflow.activity_as_tool(current_time_activity, start_to_close_timeout=timedelta(seconds=15)), -]) +agent = TemporalAgent( + model="bedrock", + start_to_close_timeout=timedelta(seconds=60), + tools=[ + strands_workflow.activity_as_tool(fetch_user, start_to_close_timeout=timedelta(seconds=30)), + strands_workflow.activity_as_tool(current_time_activity, start_to_close_timeout=timedelta(seconds=15)), + ], +) # worker Worker( @@ -207,7 +208,7 @@ Worker( ## Hooks -Strands' [hook system](https://strandsagents.com/) (`strands.hooks`) lets you subscribe callbacks to events in the agent lifecycle — invocation start/end, model call before/after, tool call before/after, message added. The native `Agent(hooks=[MyHookProvider()])` API works as-is: every single-agent hook event fires in workflow context, so deterministic callbacks just work. +Strands' [hook system](https://strandsagents.com/) (`strands.hooks`) lets you subscribe callbacks to events in the agent lifecycle — invocation start/end, model call before/after, tool call before/after, message added. Pass `hooks=[MyHookProvider()]` to `TemporalAgent`: every single-agent hook event fires in workflow context, so deterministic callbacks just work. ```python from strands.hooks import HookProvider, HookRegistry @@ -221,7 +222,7 @@ class AuditHook(HookProvider): # Pure local state - deterministic across replay. workflow.logger.info(f"tool {event.tool_use['name']} finished") -agent = Agent(hooks=[AuditHook()]) +agent = TemporalAgent(model="bedrock", start_to_close_timeout=..., hooks=[AuditHook()]) ``` Callbacks run in workflow context, so they must be deterministic: no `time.time()`, `uuid.uuid4()`, or I/O — same rules as workflow code. For callbacks that need I/O (audit logging, metrics, alerting), use `workflow.activity_as_hook()` to dispatch the work as a Temporal activity: @@ -267,8 +268,12 @@ class ApprovalHook(HookProvider): @workflow.defn class MyWorkflow: def __init__(self) -> None: - model = TemporalModel(model_name="bedrock", start_to_close_timeout=timedelta(seconds=60)) - self.agent = Agent(model=model, tools=[delete_thing], hooks=[ApprovalHook()]) + self.agent = TemporalAgent( + model="bedrock", + start_to_close_timeout=timedelta(seconds=60), + tools=[delete_thing], + hooks=[ApprovalHook()], + ) self._approval: str | None = None @workflow.signal @@ -286,7 +291,7 @@ class MyWorkflow: return str(result) ``` -Interrupt hooks must be deterministic: branch on the activity result and call `event.interrupt(...)` on the workflow side. Tools wrapped via `workflow.activity_as_tool` cannot raise interrupts — the activity body has no `Agent` reference — so hooks are the interrupt surface for this plugin. +Interrupt hooks must be deterministic: branch on the activity result and call `event.interrupt(...)` on the workflow side. Tools wrapped via `workflow.activity_as_tool` cannot raise interrupts — the activity body has no agent reference — so hooks are the interrupt surface for this plugin. ## Continue-as-new @@ -303,7 +308,6 @@ class ChatInput: @workflow.defn class ChatWorkflow: def __init__(self) -> None: - self._model = TemporalModel(model_name="bedrock", start_to_close_timeout=timedelta(seconds=60)) self._pending: list[str] = [] self._done = False @@ -317,7 +321,11 @@ class ChatWorkflow: @workflow.run async def run(self, input: ChatInput) -> None: - agent = Agent(model=self._model, messages=list(input.messages)) + agent = TemporalAgent( + model="bedrock", + start_to_close_timeout=timedelta(seconds=60), + messages=list(input.messages), + ) while True: await workflow.wait_condition(lambda: self._pending or self._done) if self._done: @@ -340,7 +348,11 @@ from temporalio.contrib.strands import TemporalMCPClient class MyWorkflow: def __init__(self) -> None: echo = TemporalMCPClient(server="echo", start_to_close_timeout=timedelta(seconds=30)) - self.agent = Agent(tools=[echo]) + self.agent = TemporalAgent( + model="bedrock", + start_to_close_timeout=timedelta(seconds=60), + tools=[echo], + ) # worker Worker( diff --git a/temporalio/contrib/strands/__init__.py b/temporalio/contrib/strands/__init__.py index 857cf46a0..39a8e7401 100644 --- a/temporalio/contrib/strands/__init__.py +++ b/temporalio/contrib/strands/__init__.py @@ -2,12 +2,12 @@ from . import workflow from ._plugin import StrandsPlugin +from ._temporal_agent import TemporalAgent from ._temporal_mcp_client import TemporalMCPClient -from ._temporal_model import TemporalModel __all__ = [ "StrandsPlugin", + "TemporalAgent", "TemporalMCPClient", - "TemporalModel", "workflow", ] diff --git a/temporalio/contrib/strands/_plugin.py b/temporalio/contrib/strands/_plugin.py index 83d421b2f..b4adc1a95 100644 --- a/temporalio/contrib/strands/_plugin.py +++ b/temporalio/contrib/strands/_plugin.py @@ -1,9 +1,7 @@ from collections.abc import AsyncGenerator, Callable from contextlib import asynccontextmanager from dataclasses import replace -from typing import Any -import strands.agent.agent as _strands_agent import strands.models.model as _strands_model from strands.models import Model from strands.tools.mcp.mcp_types import MCPTransport @@ -25,45 +23,6 @@ # an encoding file. Use the default chars-per-token heuristic instead (deterministic). setattr(_strands_model, "_get_encoding", lambda: None) -# Temporal handles retries via RetryPolicy on activity options. Disable -# Strands' in-activity ModelRetryStrategy (default max_attempts=6) so retries -# aren't duplicated, and fail fast if the user tries to configure one. -_original_agent_init = _strands_agent.Agent.__init__ -_RETRY_STRATEGY_NOT_PASSED: Any = object() - - -def _patched_agent_init(self: Any, *args: Any, **kwargs: Any) -> None: - retry_strategy = kwargs.get("retry_strategy", _RETRY_STRATEGY_NOT_PASSED) - if retry_strategy is not _RETRY_STRATEGY_NOT_PASSED and retry_strategy is not None: - raise ValueError( - "StrandsPlugin disables Strands retries; configure retries via " - "RetryPolicy on the activity options passed to TemporalModel, " - "workflow.activity_as_tool, workflow.activity_as_hook, or TemporalMCPClient. " - "Remove retry_strategy from Agent(...) or pass retry_strategy=None." - ) - kwargs["retry_strategy"] = None - _original_agent_init(self, *args, **kwargs) - - -setattr(_strands_agent.Agent, "__init__", _patched_agent_init) - - -# Temporal workflows already persist agent state durably via the event history at -# a finer granularity than Strands snapshots, so calling either method inside a -# workflow is redundant; fail loudly to steer users to Temporal's durability. -def _snapshots_disabled(*args: Any, **kwargs: Any) -> Any: - del args, kwargs - raise NotImplementedError( - "StrandsPlugin disables Agent.take_snapshot()/load_snapshot(). " - "Temporal workflows already persist agent state durably via the event " - "history at a finer granularity than Strands snapshots. Remove the " - "snapshot call and rely on Temporal's durable execution instead." - ) - - -setattr(_strands_agent.Agent, "take_snapshot", _snapshots_disabled) -setattr(_strands_agent.Agent, "load_snapshot", _snapshots_disabled) - class StrandsPlugin(SimplePlugin): """Temporal Worker plugin for the Strands Agents SDK. @@ -77,7 +36,7 @@ class StrandsPlugin(SimplePlugin): activities; each call carries the chosen ``model_name`` in its input and the worker resolves it against the factories. Factories are called lazily on first use, then cached for the worker's lifetime. Use the same name in - ``TemporalModel(model_name=...)`` inside the workflow. + ``TemporalAgent(model=...)`` inside the workflow. When ``mcp_clients`` is supplied, registers a per-server ``{server}-call-tool`` activity for each entry and, at worker startup, diff --git a/temporalio/contrib/strands/_temporal_agent.py b/temporalio/contrib/strands/_temporal_agent.py new file mode 100644 index 000000000..989ae0cce --- /dev/null +++ b/temporalio/contrib/strands/_temporal_agent.py @@ -0,0 +1,85 @@ +from datetime import timedelta +from typing import Any + +from strands import Agent + +from temporalio.common import Priority, RetryPolicy +from temporalio.workflow import ActivityCancellationType, VersioningIntent + +from ._temporal_model import TemporalModel + +_SNAPSHOT_DISABLED = ( + "TemporalAgent disables take_snapshot()/load_snapshot(). Temporal " + "workflows already persist agent state durably via the event history at " + "a finer granularity than Strands snapshots. Remove the snapshot call " + "and rely on Temporal's durable execution instead." +) + + +class TemporalAgent(Agent): + """A Strands :class:`Agent` that routes model calls through a Temporal activity. + + ``model`` is the name of a factory registered in + ``StrandsPlugin(models={...})``. The activity options apply to every model + invocation this agent makes. All other keyword arguments are forwarded to + Strands' :class:`Agent` (``tools``, ``hooks``, ``system_prompt``, + ``structured_output_model``, ``messages``, etc.). + + Strands' ``retry_strategy`` is disabled; configure retries via + ``retry_policy`` here and on the activity options accepted by + ``activity_as_tool``, ``activity_as_hook``, and ``TemporalMCPClient``. + """ + + def __init__( + self, + *, + model: str, + task_queue: str | None = None, + schedule_to_close_timeout: timedelta | None = None, + schedule_to_start_timeout: timedelta | None = None, + start_to_close_timeout: timedelta | None = None, + heartbeat_timeout: timedelta | None = None, + retry_policy: RetryPolicy | None = None, + cancellation_type: ActivityCancellationType = ActivityCancellationType.TRY_CANCEL, + versioning_intent: VersioningIntent | None = None, + summary: str | None = None, + priority: Priority = Priority.default, + streaming_topic: str | None = None, + streaming_batch_interval: timedelta = timedelta(milliseconds=100), + **agent_kwargs: Any, + ) -> None: + """Build a TemporalAgent from a registered model name and activity options.""" + if agent_kwargs.get("retry_strategy") is not None: + raise ValueError( + "TemporalAgent disables Strands retries; configure retries via " + "retry_policy on TemporalAgent and on the activity options " + "passed to workflow.activity_as_tool, workflow.activity_as_hook, " + "or TemporalMCPClient. Remove retry_strategy from " + "TemporalAgent(...) or pass retry_strategy=None." + ) + agent_kwargs["retry_strategy"] = None + + temporal_model = TemporalModel( + model_name=model, + task_queue=task_queue, + schedule_to_close_timeout=schedule_to_close_timeout, + schedule_to_start_timeout=schedule_to_start_timeout, + start_to_close_timeout=start_to_close_timeout, + heartbeat_timeout=heartbeat_timeout, + retry_policy=retry_policy, + cancellation_type=cancellation_type, + versioning_intent=versioning_intent, + summary=summary, + priority=priority, + streaming_topic=streaming_topic, + streaming_batch_interval=streaming_batch_interval, + ) + super().__init__(model=temporal_model, **agent_kwargs) + + def take_snapshot(self, *_args: Any, **_kwargs: Any) -> Any: + """Disabled; Temporal's event history is the source of truth.""" + raise NotImplementedError(_SNAPSHOT_DISABLED) + + def load_snapshot(self, *_args: Any, **_kwargs: Any) -> Any: + """Disabled; Temporal's event history is the source of truth.""" + raise NotImplementedError(_SNAPSHOT_DISABLED) diff --git a/tests/contrib/strands/test_hooks.py b/tests/contrib/strands/test_hooks.py index 12a197ebc..1fc4ff282 100644 --- a/tests/contrib/strands/test_hooks.py +++ b/tests/contrib/strands/test_hooks.py @@ -1,13 +1,13 @@ from datetime import timedelta from uuid import uuid4 -from strands import Agent, tool +from strands import tool from strands.hooks import HookProvider, HookRegistry from strands.hooks.events import AfterToolCallEvent from temporalio import activity, workflow from temporalio.client import Client -from temporalio.contrib.strands import StrandsPlugin, TemporalModel +from temporalio.contrib.strands import StrandsPlugin, TemporalAgent from temporalio.contrib.strands.workflow import activity_as_hook from temporalio.worker import Replayer, Worker from tests.contrib.strands.common import get_activities @@ -51,12 +51,13 @@ def _sync_log(self, event: AfterToolCallEvent) -> None: @workflow.defn class HooksWorkflow: def __init__(self) -> None: - model = TemporalModel( - model_name="mock", + self.hook = AuditHook() + self.agent = TemporalAgent( + model="mock", start_to_close_timeout=timedelta(seconds=15), + tools=[echo], + hooks=[self.hook], ) - self.hook = AuditHook() - self.agent = Agent(model=model, tools=[echo], hooks=[self.hook]) @workflow.run async def run(self, prompt: str) -> list[str]: diff --git a/tests/contrib/strands/test_interrupt.py b/tests/contrib/strands/test_interrupt.py index b98e701b7..8a04d6c5e 100644 --- a/tests/contrib/strands/test_interrupt.py +++ b/tests/contrib/strands/test_interrupt.py @@ -1,14 +1,14 @@ from datetime import timedelta from uuid import uuid4 -from strands import Agent, tool +from strands import tool from strands.hooks import HookProvider, HookRegistry from strands.hooks.events import BeforeToolCallEvent from strands.types.interrupt import InterruptResponseContent from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.strands import StrandsPlugin, TemporalModel +from temporalio.contrib.strands import StrandsPlugin, TemporalAgent from temporalio.worker import Replayer, Worker from tests.contrib.strands.common import get_activities from tests.contrib.strands.mock_model import MockModel @@ -37,11 +37,12 @@ def _gate(self, event: BeforeToolCallEvent) -> None: @workflow.defn class InterruptWorkflow: def __init__(self) -> None: - model = TemporalModel( - model_name="mock", + self.agent = TemporalAgent( + model="mock", start_to_close_timeout=timedelta(seconds=15), + tools=[delete_thing], + hooks=[ApprovalHook()], ) - self.agent = Agent(model=model, tools=[delete_thing], hooks=[ApprovalHook()]) self._approval: str | None = None @workflow.signal diff --git a/tests/contrib/strands/test_mcp.py b/tests/contrib/strands/test_mcp.py index bfa3e1cff..98af52ba8 100644 --- a/tests/contrib/strands/test_mcp.py +++ b/tests/contrib/strands/test_mcp.py @@ -4,31 +4,31 @@ from uuid import uuid4 from mcp import StdioServerParameters, stdio_client -from strands import Agent from temporalio import workflow from temporalio.client import Client from temporalio.contrib.strands import ( StrandsPlugin, + TemporalAgent, TemporalMCPClient, - TemporalModel, ) from temporalio.worker import Replayer, Worker from tests.contrib.strands.common import get_activities from tests.contrib.strands.mock_model import MockModel + @workflow.defn class MCPWorkflow: def __init__(self) -> None: - model = TemporalModel( - model_name="mock", - start_to_close_timeout=timedelta(seconds=30), - ) echo = TemporalMCPClient( server="echo", start_to_close_timeout=timedelta(seconds=30), ) - self.agent = Agent(model=model, tools=[echo]) + self.agent = TemporalAgent( + model="mock", + start_to_close_timeout=timedelta(seconds=30), + tools=[echo], + ) @workflow.run async def run(self, prompt: str) -> str: diff --git a/tests/contrib/strands/test_model.py b/tests/contrib/strands/test_model.py index 9ade66f0e..d467638d0 100644 --- a/tests/contrib/strands/test_model.py +++ b/tests/contrib/strands/test_model.py @@ -1,23 +1,21 @@ from datetime import timedelta from uuid import uuid4 -from strands import Agent - from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.strands import StrandsPlugin, TemporalModel +from temporalio.contrib.strands import StrandsPlugin, TemporalAgent from temporalio.worker import Replayer, Worker from tests.contrib.strands.common import get_activities from tests.contrib.strands.mock_model import MockModel + @workflow.defn class ModelWorkflow: def __init__(self) -> None: - model = TemporalModel( - model_name="mock", + self.agent = TemporalAgent( + model="mock", start_to_close_timeout=timedelta(seconds=15), ) - self.agent = Agent(model=model) @workflow.run async def run(self, prompt: str) -> str: diff --git a/tests/contrib/strands/test_model_streaming.py b/tests/contrib/strands/test_model_streaming.py index 4b531e8cc..d711e5e97 100644 --- a/tests/contrib/strands/test_model_streaming.py +++ b/tests/contrib/strands/test_model_streaming.py @@ -2,27 +2,26 @@ from datetime import timedelta from uuid import uuid4 -from strands import Agent from strands.types.streaming import StreamEvent from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.strands import StrandsPlugin, TemporalModel +from temporalio.contrib.strands import StrandsPlugin, TemporalAgent from temporalio.contrib.workflow_streams import WorkflowStream, WorkflowStreamClient from temporalio.worker import Replayer, Worker from tests.contrib.strands.common import get_activities from tests.contrib.strands.mock_model import MockModel + @workflow.defn class StreamingModelWorkflow: def __init__(self) -> None: self.stream = WorkflowStream() - model = TemporalModel( - model_name="mock", + self.agent = TemporalAgent( + model="mock", start_to_close_timeout=timedelta(seconds=15), streaming_topic="events", ) - self.agent = Agent(model=model) @workflow.run async def run(self, prompt: str) -> str: diff --git a/tests/contrib/strands/test_tool.py b/tests/contrib/strands/test_tool.py index 04ad95e65..904f3eb88 100644 --- a/tests/contrib/strands/test_tool.py +++ b/tests/contrib/strands/test_tool.py @@ -1,7 +1,7 @@ from datetime import timedelta from uuid import uuid4 -from strands import Agent, tool +from strands import tool from strands_tools import ( # pyright: ignore[reportMissingTypeStubs] calculator, current_time, @@ -9,7 +9,7 @@ from temporalio import activity, workflow from temporalio.client import Client -from temporalio.contrib.strands import StrandsPlugin, TemporalModel +from temporalio.contrib.strands import StrandsPlugin, TemporalAgent from temporalio.contrib.strands.workflow import activity_as_tool from temporalio.worker import Replayer, Worker from tests.contrib.strands.common import get_activities @@ -29,12 +29,9 @@ async def current_time_activity() -> str: @workflow.defn class ToolWorkflow: def __init__(self) -> None: - model = TemporalModel( - model_name="mock", + self.agent = TemporalAgent( + model="mock", start_to_close_timeout=timedelta(seconds=15), - ) - self.agent = Agent( - model=model, tools=[ calculator, activity_as_tool( From 34b4017d9475de98a1a81474225a83d898dfdbfb Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Tue, 19 May 2026 14:11:32 -0700 Subject: [PATCH 29/46] contrib/strands: register OpenTelemetryPlugin on the client in the README Per the OpenTelemetry plugin's own guidance, plugins register on the client so workers built from that client pick them up automatically. Update the Observability section accordingly, plus minor wording polish in the Models and Structured Output sections. --- temporalio/contrib/strands/README.md | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/temporalio/contrib/strands/README.md b/temporalio/contrib/strands/README.md index 98b88a1c3..60aa14197 100644 --- a/temporalio/contrib/strands/README.md +++ b/temporalio/contrib/strands/README.md @@ -89,20 +89,24 @@ Note: Use `agent.invoke_async(message)` instead of `agent(message)`. The synchro from strands.models.anthropic import AnthropicModel from strands.models.bedrock import BedrockModel -MODELS = { - "claude": lambda: AnthropicModel(client_args={"api_key": "..."}), - "bedrock": lambda: BedrockModel(), -} - # workflow @workflow.defn class MultiModelWorkflow: def __init__(self) -> None: - self.agent_a = TemporalAgent(model="claude", start_to_close_timeout=timedelta(seconds=60)) - self.agent_b = TemporalAgent(model="bedrock", start_to_close_timeout=timedelta(seconds=60)) + self.agent_a = TemporalAgent( + model="claude", + start_to_close_timeout=timedelta(seconds=60), + ) + self.agent_b = TemporalAgent( + model="bedrock", + start_to_close_timeout=timedelta(seconds=60), + ) # worker -Worker(..., plugins=[StrandsPlugin(models=MODELS)]) +Worker(..., plugins=[StrandsPlugin(models={ + "claude": lambda: AnthropicModel(client_args={"api_key": "..."}), + "bedrock": lambda: BedrockModel(), +})]) ``` Each `TemporalAgent` carries its own activity options (timeouts, retry policy, task queue, streaming topic) and dispatches to the shared model activity, which resolves the model name against the registered factories at runtime. A name not present in `models` raises `ValueError` inside the activity. @@ -129,7 +133,7 @@ Passing `retry_strategy=...` to `TemporalAgent(...)` raises `ValueError`; remove ## Structured Output -Pass a Pydantic model as `structured_output_model=`. Strands routes the call through `stream()` as a synthetic tool, so it dispatches via the model activity like any other invocation. The result is available as `result.structured_output` and can be returned directly from the workflow — `StrandsPlugin` defaults to [`pydantic_data_converter`](../pydantic), so Pydantic types serialize across the activity and workflow boundary. +Like Strands `Agent`, `TemporalAgent` supports structured output with `structured_output_model`. The plugin defaults to [`pydantic_data_converter`](../pydantic), so Pydantic types easily serialize across the activity and workflow boundary. ```python from pydantic import BaseModel @@ -371,7 +375,7 @@ The plugin connects to each MCP server once at worker startup to enumerate tools ## Observability -`StrandsPlugin` composes cleanly with [`OpenTelemetryPlugin`](../opentelemetry) — add both to the worker to get OTel spans around the model, tool, and MCP activities the plugin schedules, plus any spans Strands itself emits inside `invoke_async`: +`StrandsPlugin` composes cleanly with [`OpenTelemetryPlugin`](../opentelemetry). Register `OpenTelemetryPlugin` on the client (workers built from that client pick it up automatically) and `StrandsPlugin` on the worker. You'll get OTel spans around the model, tool, and MCP activities the plugin schedules, plus any spans Strands itself emits inside `invoke_async`: ```python import opentelemetry.trace @@ -379,11 +383,13 @@ from temporalio.contrib.opentelemetry import OpenTelemetryPlugin, create_tracer_ opentelemetry.trace.set_tracer_provider(create_tracer_provider()) +client = await Client.connect("localhost:7233", plugins=[OpenTelemetryPlugin()]) + Worker( client, task_queue="strands", workflows=[MyWorkflow], - plugins=[StrandsPlugin(models=MODELS), OpenTelemetryPlugin()], + plugins=[StrandsPlugin(models=MODELS)], ) ``` From 95159e2f4458d8f1ad72d49b7587b0c12a3e879f Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Tue, 19 May 2026 14:17:44 -0700 Subject: [PATCH 30/46] contrib/strands: set max_cached_workflows=0 in tests Force every workflow task to replay from full history so the strands tests double as a continuous determinism check on the plugin and TemporalAgent. All 7 tests pass under the stricter setting. Also trims a redundant paragraph from StrandsPlugin's docstring. --- temporalio/contrib/strands/_plugin.py | 5 ----- tests/contrib/strands/test_hooks.py | 1 + tests/contrib/strands/test_interrupt.py | 1 + tests/contrib/strands/test_mcp.py | 1 + tests/contrib/strands/test_model.py | 1 + tests/contrib/strands/test_model_streaming.py | 1 + tests/contrib/strands/test_structured_output.py | 1 + tests/contrib/strands/test_tool.py | 1 + 8 files changed, 7 insertions(+), 5 deletions(-) diff --git a/temporalio/contrib/strands/_plugin.py b/temporalio/contrib/strands/_plugin.py index b4adc1a95..23b7abba4 100644 --- a/temporalio/contrib/strands/_plugin.py +++ b/temporalio/contrib/strands/_plugin.py @@ -27,11 +27,6 @@ class StrandsPlugin(SimplePlugin): """Temporal Worker plugin for the Strands Agents SDK. - Configures sandbox passthrough for ``strands``, ``strands_tools``, ``mcp``, - and ``temporalio.contrib.strands`` (so the MCP tool cache is visible to - workflow code), and swaps in ``pydantic_data_converter`` so structured - outputs serialize. - When ``models`` is supplied, registers a single pair of model invocation activities; each call carries the chosen ``model_name`` in its input and the worker resolves it against the factories. Factories are called lazily diff --git a/tests/contrib/strands/test_hooks.py b/tests/contrib/strands/test_hooks.py index 1fc4ff282..19976cb44 100644 --- a/tests/contrib/strands/test_hooks.py +++ b/tests/contrib/strands/test_hooks.py @@ -85,6 +85,7 @@ async def test_hooks(client: Client): workflows=[HooksWorkflow], activities=[audit_tool], plugins=[plugin], + max_cached_workflows=0, ): handle = await client.start_workflow( HooksWorkflow.run, diff --git a/tests/contrib/strands/test_interrupt.py b/tests/contrib/strands/test_interrupt.py index 8a04d6c5e..64f72bc07 100644 --- a/tests/contrib/strands/test_interrupt.py +++ b/tests/contrib/strands/test_interrupt.py @@ -82,6 +82,7 @@ async def test_interrupt(client: Client): task_queue=task_queue, workflows=[InterruptWorkflow], plugins=[plugin], + max_cached_workflows=0, ): handle = await client.start_workflow( InterruptWorkflow.run, diff --git a/tests/contrib/strands/test_mcp.py b/tests/contrib/strands/test_mcp.py index 98af52ba8..d03a3d7b4 100644 --- a/tests/contrib/strands/test_mcp.py +++ b/tests/contrib/strands/test_mcp.py @@ -62,6 +62,7 @@ async def test_mcp(client: Client): task_queue=task_queue, workflows=[MCPWorkflow], plugins=[plugin], + max_cached_workflows=0, ): handle = await client.start_workflow( MCPWorkflow.run, diff --git a/tests/contrib/strands/test_model.py b/tests/contrib/strands/test_model.py index d467638d0..68d578e5c 100644 --- a/tests/contrib/strands/test_model.py +++ b/tests/contrib/strands/test_model.py @@ -32,6 +32,7 @@ async def test_model(client: Client): task_queue=task_queue, workflows=[ModelWorkflow], plugins=[plugin], + max_cached_workflows=0, ): handle = await client.start_workflow( ModelWorkflow.run, diff --git a/tests/contrib/strands/test_model_streaming.py b/tests/contrib/strands/test_model_streaming.py index d711e5e97..41f3ff2f1 100644 --- a/tests/contrib/strands/test_model_streaming.py +++ b/tests/contrib/strands/test_model_streaming.py @@ -39,6 +39,7 @@ async def test_model_streaming(client: Client): task_queue=task_queue, workflows=[StreamingModelWorkflow], plugins=[plugin], + max_cached_workflows=0, ): handle = await client.start_workflow( StreamingModelWorkflow.run, diff --git a/tests/contrib/strands/test_structured_output.py b/tests/contrib/strands/test_structured_output.py index 4d7dd885b..8018a9d16 100644 --- a/tests/contrib/strands/test_structured_output.py +++ b/tests/contrib/strands/test_structured_output.py @@ -52,6 +52,7 @@ async def test_structured_output(client: Client): client, task_queue=task_queue, workflows=[StructuredOutputWorkflow], + max_cached_workflows=0, ): handle = await client.start_workflow( StructuredOutputWorkflow.run, diff --git a/tests/contrib/strands/test_tool.py b/tests/contrib/strands/test_tool.py index 904f3eb88..e89f6f9b7 100644 --- a/tests/contrib/strands/test_tool.py +++ b/tests/contrib/strands/test_tool.py @@ -75,6 +75,7 @@ async def test_tool(client: Client): workflows=[ToolWorkflow], activities=[current_time_activity], plugins=[plugin], + max_cached_workflows=0, ): handle = await client.start_workflow( ToolWorkflow.run, From 71c4f68e4893c9b356d3a7a4861d34ab2d192e3f Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Tue, 19 May 2026 14:22:12 -0700 Subject: [PATCH 31/46] contrib/strands: appease poe lint MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Drop the leading underscore from populate_cache / clear_cache / build_call_tool_activity in _temporal_mcp_client.py — basedpyright flagged them as unused because it doesn't follow cross-module imports for underscore-prefixed names. Add docstrings since pydocstyle now treats them as public. Also pick up a one-line ruff format fix in _model_activity.py. --- temporalio/contrib/strands/_model_activity.py | 3 +-- temporalio/contrib/strands/_plugin.py | 12 ++++++------ temporalio/contrib/strands/_temporal_mcp_client.py | 9 ++++++--- 3 files changed, 13 insertions(+), 11 deletions(-) diff --git a/temporalio/contrib/strands/_model_activity.py b/temporalio/contrib/strands/_model_activity.py index f0626dce9..87009f0f7 100644 --- a/temporalio/contrib/strands/_model_activity.py +++ b/temporalio/contrib/strands/_model_activity.py @@ -43,8 +43,7 @@ def _get_model(self, name: str) -> Model: if name not in self._models: if name not in self._factories: raise ValueError( - f"Unknown model name {name!r}. " - f"Known: {sorted(self._factories)}" + f"Unknown model name {name!r}. Known: {sorted(self._factories)}" ) self._models[name] = self._factories[name]() return self._models[name] diff --git a/temporalio/contrib/strands/_plugin.py b/temporalio/contrib/strands/_plugin.py index 23b7abba4..f1de2710c 100644 --- a/temporalio/contrib/strands/_plugin.py +++ b/temporalio/contrib/strands/_plugin.py @@ -14,9 +14,9 @@ from ._model_activity import ModelActivity from ._temporal_mcp_client import ( - _build_call_tool_activity, - _clear_cache, - _populate_cache, + build_call_tool_activity, + clear_cache, + populate_cache, ) # Force Strands' base Model.count_tokens to avoid tiktoken, which lazily downloads @@ -53,17 +53,17 @@ def __init__( mcp_clients = mcp_clients or {} for server, transport_factory in mcp_clients.items(): - activities.append(_build_call_tool_activity(server, transport_factory)) + activities.append(build_call_tool_activity(server, transport_factory)) @asynccontextmanager async def run_context() -> AsyncGenerator[None, None]: for server, transport_factory in mcp_clients.items(): - await _populate_cache(server, transport_factory) + await populate_cache(server, transport_factory) try: yield finally: for server in mcp_clients: - _clear_cache(server) + clear_cache(server) super().__init__( "aws.StrandsPlugin", diff --git a/temporalio/contrib/strands/_temporal_mcp_client.py b/temporalio/contrib/strands/_temporal_mcp_client.py index 0ec37200f..542782594 100644 --- a/temporalio/contrib/strands/_temporal_mcp_client.py +++ b/temporalio/contrib/strands/_temporal_mcp_client.py @@ -100,7 +100,7 @@ def remove_consumer(self, consumer_id: Any, **_kwargs: Any) -> None: return None -async def _populate_cache( +async def populate_cache( server: str, transport_factory: Callable[[], MCPTransport] ) -> None: """Connect to the MCP server, list tools, fill ``_TOOL_CACHE``.""" @@ -123,13 +123,16 @@ async def _populate_cache( client.stop(None, None, None) -def _clear_cache(server: str) -> None: +def clear_cache(server: str) -> None: + """Drop the cached tool list for ``server``.""" _TOOL_CACHE.pop(server, None) -def _build_call_tool_activity( +def build_call_tool_activity( server: str, transport_factory: Callable[[], MCPTransport] ) -> Callable: + """Return the per-server ``{server}-call-tool`` activity for registration.""" + @activity.defn(name=f"{server}-call-tool") async def call_tool(args: _CallToolArgs) -> MCPToolResult: client = MCPClient(transport_factory) From 300fdc5acf3d390134ee866c9d2b98f141c1fa4d Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Tue, 19 May 2026 15:21:00 -0700 Subject: [PATCH 32/46] contrib/strands: propagate InterruptException across activity boundary Install a failure converter on the plugin's data converter that translates strands InterruptException into an ApplicationError carrying the Interrupt payload in details. TemporalActivityTool.stream() catches the matching ApplicationError, reconstructs the Interrupt, and yields ToolInterruptEvent so AgentResult.interrupts is populated just like the in-workflow case. The path requires StrandsPlugin on the client (not just the worker), since _ActivityWorker reads the data converter from client_config. README HITL section is restructured to cover both hook-based and tool-body surfaces, with a note on the client-attachment requirement. New test_interrupt_exception.py exercises both surfaces end-to-end with signal-driven resume. --- temporalio/contrib/strands/README.md | 53 ++++- .../contrib/strands/_failure_converter.py | 40 ++++ temporalio/contrib/strands/_plugin.py | 6 +- .../strands/_temporal_activity_tool.py | 40 ++-- .../strands/test_interrupt_exception.py | 191 ++++++++++++++++++ 5 files changed, 314 insertions(+), 16 deletions(-) create mode 100644 temporalio/contrib/strands/_failure_converter.py create mode 100644 tests/contrib/strands/test_interrupt_exception.py diff --git a/temporalio/contrib/strands/README.md b/temporalio/contrib/strands/README.md index 60aa14197..460a43363 100644 --- a/temporalio/contrib/strands/README.md +++ b/temporalio/contrib/strands/README.md @@ -255,7 +255,11 @@ class AuditHook(HookProvider): ## Human-in-the-loop interrupts -A hook on an interruptible event (e.g. `BeforeToolCallEvent`) can pause the agent by calling `event.interrupt(name, reason=...)`. When this fires, `agent.invoke_async()` returns `AgentResult(stop_reason="interrupt", interrupts=[...])` instead of raising. Pair this with a signal handler that supplies responses, then resume by calling `agent.invoke_async(responses)`: +Strands offers two HITL surfaces; both work with the plugin. In each case, `agent.invoke_async()` returns `AgentResult(stop_reason="interrupt", interrupts=[...])` instead of raising. Pair this with a signal handler that supplies responses, then resume by calling `agent.invoke_async(responses)`. + +### Hook-based interrupts + +A hook on an interruptible event (e.g. `BeforeToolCallEvent`) can pause the agent by calling `event.interrupt(name, reason=...)`. The hook runs in workflow context, so it must be deterministic — no I/O. ```python from strands.hooks import HookProvider, HookRegistry @@ -295,7 +299,52 @@ class MyWorkflow: return str(result) ``` -Interrupt hooks must be deterministic: branch on the activity result and call `event.interrupt(...)` on the workflow side. Tools wrapped via `workflow.activity_as_tool` cannot raise interrupts — the activity body has no agent reference — so hooks are the interrupt surface for this plugin. +### Tool-body interrupts + +A `@strands.tool` function can raise `InterruptException(Interrupt(...))` directly. The agent stops with the interrupt, the workflow handles the resume the same way as for hooks. + +```python +from strands import tool +from strands.interrupt import Interrupt, InterruptException + +@tool +def delete_thing(name: str) -> str: + raise InterruptException( + Interrupt(id=f"delete:{name}", name="approval", reason=f"delete {name}?") + ) +``` + +The same works from an `activity_as_tool`-wrapped activity. The plugin's failure converter preserves the `Interrupt` payload across the activity boundary, so `AgentResult.interrupts` is populated just like the in-workflow case: + +```python +from strands.interrupt import Interrupt, InterruptException +from temporalio.contrib.strands.workflow import activity_as_tool + +@activity.defn +async def delete_thing(name: str) -> str: + if not await policy.is_authorized(name): + raise InterruptException( + Interrupt(id=f"delete:{name}", name="approval", reason=f"delete {name}?") + ) + await storage.delete(name) + return f"deleted {name}" + +@workflow.defn +class MyWorkflow: + def __init__(self) -> None: + self.agent = TemporalAgent( + model="bedrock", + start_to_close_timeout=timedelta(seconds=60), + tools=[activity_as_tool(delete_thing, start_to_close_timeout=timedelta(seconds=10))], + ) +``` + +This relies on the plugin's failure converter, which is installed via the client's data converter. **Attach `StrandsPlugin` to the client** (not just the worker) for activity-tool interrupts to work — workers built from that client pick up the plugin automatically. + +```python +client = await Client.connect("localhost:7233", plugins=[StrandsPlugin(models=MODELS)]) +Worker(client, task_queue="strands", workflows=[MyWorkflow], activities=[delete_thing]) +``` ## Continue-as-new diff --git a/temporalio/contrib/strands/_failure_converter.py b/temporalio/contrib/strands/_failure_converter.py new file mode 100644 index 000000000..407316723 --- /dev/null +++ b/temporalio/contrib/strands/_failure_converter.py @@ -0,0 +1,40 @@ +"""Failure converter that preserves Strands ``InterruptException`` payloads.""" + +from strands.interrupt import InterruptException + +import temporalio.api.failure.v1 +from temporalio.converter import DefaultFailureConverter, PayloadConverter +from temporalio.exceptions import ApplicationError + +# Activity-side: when a Strands ``InterruptException`` would otherwise be +# serialized by the default converter, the ``Interrupt`` payload on +# ``exc.interrupt`` is dropped (it lives on the instance, not in the +# serialized ApplicationError). We translate to a typed ApplicationError so +# the interrupt data survives the activity boundary and the workflow side +# can rebuild a real ``Interrupt``. +STRANDS_INTERRUPT_TYPE = "StrandsInterrupt" + + +class StrandsFailureConverter(DefaultFailureConverter): + """Failure converter that preserves Strands ``InterruptException`` payloads.""" + + def to_failure( + self, + exception: BaseException, + payload_converter: PayloadConverter, + failure: temporalio.api.failure.v1.Failure, + ) -> None: + """Translate ``InterruptException`` to a typed ``ApplicationError``.""" + if isinstance(exception, InterruptException): + super().to_failure( + ApplicationError( + f"interrupt:{exception.interrupt.name}", + exception.interrupt.to_dict(), + type=STRANDS_INTERRUPT_TYPE, + non_retryable=True, + ), + payload_converter, + failure, + ) + return + super().to_failure(exception, payload_converter, failure) diff --git a/temporalio/contrib/strands/_plugin.py b/temporalio/contrib/strands/_plugin.py index f1de2710c..63ceb01ed 100644 --- a/temporalio/contrib/strands/_plugin.py +++ b/temporalio/contrib/strands/_plugin.py @@ -12,6 +12,7 @@ from temporalio.worker import WorkflowRunner from temporalio.worker.workflow_sandbox import SandboxedWorkflowRunner +from ._failure_converter import StrandsFailureConverter from ._model_activity import ModelActivity from ._temporal_mcp_client import ( build_call_tool_activity, @@ -98,5 +99,8 @@ def _data_converter(converter: DataConverter | None) -> DataConverter: converter is None or converter.payload_converter_class is DefaultPayloadConverter ): - return pydantic_data_converter + return replace( + pydantic_data_converter, + failure_converter_class=StrandsFailureConverter, + ) return converter diff --git a/temporalio/contrib/strands/_temporal_activity_tool.py b/temporalio/contrib/strands/_temporal_activity_tool.py index c753e0e60..bb838834e 100644 --- a/temporalio/contrib/strands/_temporal_activity_tool.py +++ b/temporalio/contrib/strands/_temporal_activity_tool.py @@ -3,11 +3,15 @@ from collections.abc import Callable from typing import Any +from strands.interrupt import Interrupt from strands.tools.decorator import FunctionToolMetadata -from strands.types._events import ToolResultEvent +from strands.types._events import ToolInterruptEvent, ToolResultEvent from strands.types.tools import AgentTool, ToolGenerator, ToolResult, ToolSpec, ToolUse from temporalio import activity, workflow +from temporalio.exceptions import ActivityError, ApplicationError + +from ._failure_converter import STRANDS_INTERRUPT_TYPE class TemporalActivityTool(AgentTool): @@ -51,18 +55,28 @@ async def stream( bound = self._signature.bind(**tool_use["input"]) bound.apply_defaults() positional = list(bound.arguments.values()) - if not positional: - result = await workflow.execute_activity( - self._activity_name, **self._options - ) - elif len(positional) == 1: - result = await workflow.execute_activity( - self._activity_name, positional[0], **self._options - ) - else: - result = await workflow.execute_activity( - self._activity_name, args=positional, **self._options - ) + try: + if not positional: + result = await workflow.execute_activity( + self._activity_name, **self._options + ) + elif len(positional) == 1: + result = await workflow.execute_activity( + self._activity_name, positional[0], **self._options + ) + else: + result = await workflow.execute_activity( + self._activity_name, args=positional, **self._options + ) + except ActivityError as e: + cause = e.__cause__ + if ( + isinstance(cause, ApplicationError) + and cause.type == STRANDS_INTERRUPT_TYPE + ): + yield ToolInterruptEvent(tool_use, [Interrupt(**cause.details[0])]) + return + raise yield ToolResultEvent( ToolResult( toolUseId=tool_use["toolUseId"], diff --git a/tests/contrib/strands/test_interrupt_exception.py b/tests/contrib/strands/test_interrupt_exception.py new file mode 100644 index 000000000..ed858b32b --- /dev/null +++ b/tests/contrib/strands/test_interrupt_exception.py @@ -0,0 +1,191 @@ +from datetime import timedelta +from uuid import uuid4 + +from strands import tool +from strands.interrupt import Interrupt, InterruptException +from strands.types.interrupt import InterruptResponseContent + +from temporalio import activity, workflow +from temporalio.client import Client +from temporalio.contrib.strands import StrandsPlugin, TemporalAgent +from temporalio.contrib.strands.workflow import activity_as_tool +from temporalio.worker import Replayer, Worker +from tests.contrib.strands.common import get_activities +from tests.contrib.strands.mock_model import MockModel + + +@tool +def in_workflow_delete(name: str) -> str: + raise InterruptException( + Interrupt(id=f"delete:{name}", name="approval", reason=f"delete {name}?") + ) + + +# Counts attempts so the activity raises on the first invocation and succeeds on +# the second — modeling a real "approval flipped an external flag" check. +_activity_delete_calls = 0 + + +@activity.defn +async def activity_delete(name: str) -> str: + global _activity_delete_calls + _activity_delete_calls += 1 + if _activity_delete_calls == 1: + raise InterruptException( + Interrupt(id=f"delete:{name}", name="approval", reason=f"delete {name}?") + ) + return f"deleted {name}" + + +@workflow.defn +class InWorkflowToolInterruptWorkflow: + def __init__(self) -> None: + self.agent = TemporalAgent( + model="mock", + start_to_close_timeout=timedelta(seconds=15), + tools=[in_workflow_delete], + ) + self._approval: str | None = None + + @workflow.signal + def approve(self, response: str) -> None: + self._approval = response + + @workflow.run + async def run(self, prompt: str) -> str: + result = await self.agent.invoke_async(prompt) + while result.stop_reason == "interrupt": + await workflow.wait_condition(lambda: self._approval is not None) + response, self._approval = self._approval, None + responses: list[InterruptResponseContent] = [ + {"interruptResponse": {"interruptId": i.id, "response": response}} + for i in (result.interrupts or []) + ] + result = await self.agent.invoke_async(responses) + return str(result) + + +@workflow.defn +class ActivityToolInterruptWorkflow: + def __init__(self) -> None: + self.agent = TemporalAgent( + model="mock", + start_to_close_timeout=timedelta(seconds=15), + tools=[ + activity_as_tool( + activity_delete, + start_to_close_timeout=timedelta(seconds=15), + ) + ], + ) + self._approval: str | None = None + + @workflow.signal + def approve(self, response: str) -> None: + self._approval = response + + @workflow.run + async def run(self, prompt: str) -> str: + result = await self.agent.invoke_async(prompt) + while result.stop_reason == "interrupt": + await workflow.wait_condition(lambda: self._approval is not None) + response, self._approval = self._approval, None + responses: list[InterruptResponseContent] = [ + {"interruptResponse": {"interruptId": i.id, "response": response}} + for i in (result.interrupts or []) + ] + result = await self.agent.invoke_async(responses) + return str(result) + + +async def test_in_workflow_tool_interrupt(client: Client): + task_queue = "test_in_workflow_tool_interrupt" + plugin = StrandsPlugin( + models={ + "mock": lambda: MockModel( + [ + {"name": "in_workflow_delete", "input": {"name": "foo"}}, + "Done!", + ] + ) + } + ) + + async with Worker( + client, + task_queue=task_queue, + workflows=[InWorkflowToolInterruptWorkflow], + plugins=[plugin], + max_cached_workflows=0, + ): + handle = await client.start_workflow( + InWorkflowToolInterruptWorkflow.run, + "delete foo", + id=f"test_in_workflow_tool_interrupt_{uuid4()}", + task_queue=task_queue, + ) + await handle.signal(InWorkflowToolInterruptWorkflow.approve, "approve") + assert await handle.result() == "Done!\n" + + history = await handle.fetch_history() + # No activity call for the in-workflow @tool — only model calls. + assert get_activities(history) == ["invoke_model", "invoke_model"] + + await Replayer( + workflows=[InWorkflowToolInterruptWorkflow], + plugins=[plugin], + ).replay_workflow(history) + + +async def test_activity_tool_interrupt(client: Client): + global _activity_delete_calls + _activity_delete_calls = 0 + task_queue = "test_activity_tool_interrupt" + plugin = StrandsPlugin( + models={ + "mock": lambda: MockModel( + [ + {"name": "activity_delete", "input": {"name": "foo"}}, + "Done!", + ] + ) + } + ) + + # Activity-side InterruptException relies on the failure converter installed + # via the data converter, which _ActivityWorker reads from the client config. + # Re-create the client with the plugin attached so that converter takes effect. + config = client.config() + config["plugins"] = [*config["plugins"], plugin] + client = Client(**config) + + async with Worker( + client, + task_queue=task_queue, + workflows=[ActivityToolInterruptWorkflow], + activities=[activity_delete], + max_cached_workflows=0, + ): + handle = await client.start_workflow( + ActivityToolInterruptWorkflow.run, + "delete foo", + id=f"test_activity_tool_interrupt_{uuid4()}", + task_queue=task_queue, + ) + await handle.signal(ActivityToolInterruptWorkflow.approve, "approve") + assert await handle.result() == "Done!\n" + + history = await handle.fetch_history() + # activity_delete appears twice: once for the call that raised + # InterruptException, once for the resume call that returned successfully. + assert get_activities(history) == [ + "invoke_model", + "activity_delete", + "activity_delete", + "invoke_model", + ] + + await Replayer( + workflows=[ActivityToolInterruptWorkflow], + plugins=[plugin], + ).replay_workflow(history) From e6f7b023ad2bfa3d765d2c160dc8b7dd113069f8 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Tue, 19 May 2026 15:48:43 -0700 Subject: [PATCH 33/46] contrib/strands: forward invocation_state; default StrandsPlugin to BedrockModel Forward agent invocation_state across the model activity boundary so the worker-side model receives it via model.stream(invocation_state=...). Entries that aren't JSON-serializable are dropped before dispatch with a debug log naming the dropped keys. Make model selection optional. StrandsPlugin() with no args registers a single BedrockModel() factory under the name "bedrock" (matching Strands' own implicit default in agent.py:221), and TemporalAgent() with no model resolves to the sole registered factory at activity time. Multi-model setups continue to require an explicit model= on TemporalAgent. README quickstart shrinks accordingly: no BedrockModel import, no models= argument on StrandsPlugin, no model= on TemporalAgent. Model= remains in the multi-model example where it's load-bearing. --- temporalio/contrib/strands/README.md | 23 ++---- temporalio/contrib/strands/_model_activity.py | 17 +++- temporalio/contrib/strands/_plugin.py | 9 +- temporalio/contrib/strands/_temporal_agent.py | 2 +- temporalio/contrib/strands/_temporal_model.py | 24 +++++- .../contrib/strands/test_invocation_state.py | 82 +++++++++++++++++++ 6 files changed, 136 insertions(+), 21 deletions(-) create mode 100644 tests/contrib/strands/test_invocation_state.py diff --git a/temporalio/contrib/strands/README.md b/temporalio/contrib/strands/README.md index 460a43363..deb3c9a5c 100644 --- a/temporalio/contrib/strands/README.md +++ b/temporalio/contrib/strands/README.md @@ -18,8 +18,6 @@ uv add temporalio[strands] import asyncio from datetime import timedelta -from strands.models.bedrock import BedrockModel - from temporalio import workflow from temporalio.client import Client from temporalio.contrib.strands import StrandsPlugin, TemporalAgent @@ -29,9 +27,7 @@ from temporalio.worker import Worker @workflow.defn class MyWorkflow: def __init__(self) -> None: - self.agent = TemporalAgent( - model="bedrock", start_to_close_timeout=timedelta(seconds=60) - ) + self.agent = TemporalAgent(start_to_close_timeout=timedelta(seconds=60)) @workflow.run async def run(self, prompt: str) -> str: @@ -45,7 +41,7 @@ async def main() -> None: client, task_queue="strands", workflows=[MyWorkflow], - plugins=[StrandsPlugin(models={"bedrock": lambda: BedrockModel()})], + plugins=[StrandsPlugin()], ) await worker.run() @@ -85,6 +81,8 @@ Note: Use `agent.invoke_async(message)` instead of `agent(message)`. The synchro `StrandsPlugin(models=...)` takes a mapping of `name → factory`. Each factory is called lazily on first use (on the worker, outside the workflow sandbox) and the constructed model is cached for the worker's lifetime. `TemporalAgent(model="name", ...)` selects which factory to invoke and carries the activity options for that agent's model calls. +If `models` is omitted, the plugin registers a single `BedrockModel()` factory under the name `"bedrock"`, matching Strands' own implicit default. That's why the quickstart can drop `models=` entirely. + ```python from strands.models.anthropic import AnthropicModel from strands.models.bedrock import BedrockModel @@ -111,6 +109,8 @@ Worker(..., plugins=[StrandsPlugin(models={ Each `TemporalAgent` carries its own activity options (timeouts, retry policy, task queue, streaming topic) and dispatches to the shared model activity, which resolves the model name against the registered factories at runtime. A name not present in `models` raises `ValueError` inside the activity. +If `models` has a single entry, `TemporalAgent` may be constructed without an explicit `model=` and the sole factory is used automatically. + ## Retries `TemporalAgent` disables Strands' built-in `ModelRetryStrategy` so retries are handled exclusively by Temporal. Configure retries via `retry_policy` on `TemporalAgent`, and on the activity options accepted by `workflow.activity_as_tool`, `workflow.activity_as_hook`, and `TemporalMCPClient`: @@ -119,7 +119,6 @@ Each `TemporalAgent` carries its own activity options (timeouts, retry policy, t from temporalio.common import RetryPolicy TemporalAgent( - model="bedrock", start_to_close_timeout=timedelta(seconds=60), retry_policy=RetryPolicy(maximum_attempts=3), ) @@ -146,7 +145,6 @@ class PersonInfo(BaseModel): class MyWorkflow: def __init__(self) -> None: self.agent = TemporalAgent( - model="bedrock", start_to_close_timeout=timedelta(seconds=60), structured_output_model=PersonInfo, ) @@ -167,7 +165,7 @@ To forward model chunks to external consumers, pass `streaming_topic="..."` to ` class MyWorkflow: def __init__(self) -> None: self.stream = WorkflowStream() - self.agent = TemporalAgent(model="bedrock", streaming_topic="events") + self.agent = TemporalAgent(streaming_topic="events") # client async for item in WorkflowStreamClient.create(client, workflow_id).subscribe( @@ -194,7 +192,6 @@ async def current_time_activity() -> str: # workflow agent = TemporalAgent( - model="bedrock", start_to_close_timeout=timedelta(seconds=60), tools=[ strands_workflow.activity_as_tool(fetch_user, start_to_close_timeout=timedelta(seconds=30)), @@ -226,7 +223,7 @@ class AuditHook(HookProvider): # Pure local state - deterministic across replay. workflow.logger.info(f"tool {event.tool_use['name']} finished") -agent = TemporalAgent(model="bedrock", start_to_close_timeout=..., hooks=[AuditHook()]) +agent = TemporalAgent(start_to_close_timeout=..., hooks=[AuditHook()]) ``` Callbacks run in workflow context, so they must be deterministic: no `time.time()`, `uuid.uuid4()`, or I/O — same rules as workflow code. For callbacks that need I/O (audit logging, metrics, alerting), use `workflow.activity_as_hook()` to dispatch the work as a Temporal activity: @@ -277,7 +274,6 @@ class ApprovalHook(HookProvider): class MyWorkflow: def __init__(self) -> None: self.agent = TemporalAgent( - model="bedrock", start_to_close_timeout=timedelta(seconds=60), tools=[delete_thing], hooks=[ApprovalHook()], @@ -333,7 +329,6 @@ async def delete_thing(name: str) -> str: class MyWorkflow: def __init__(self) -> None: self.agent = TemporalAgent( - model="bedrock", start_to_close_timeout=timedelta(seconds=60), tools=[activity_as_tool(delete_thing, start_to_close_timeout=timedelta(seconds=10))], ) @@ -375,7 +370,6 @@ class ChatWorkflow: @workflow.run async def run(self, input: ChatInput) -> None: agent = TemporalAgent( - model="bedrock", start_to_close_timeout=timedelta(seconds=60), messages=list(input.messages), ) @@ -402,7 +396,6 @@ class MyWorkflow: def __init__(self) -> None: echo = TemporalMCPClient(server="echo", start_to_close_timeout=timedelta(seconds=30)) self.agent = TemporalAgent( - model="bedrock", start_to_close_timeout=timedelta(seconds=60), tools=[echo], ) diff --git a/temporalio/contrib/strands/_model_activity.py b/temporalio/contrib/strands/_model_activity.py index 87009f0f7..d19756ec5 100644 --- a/temporalio/contrib/strands/_model_activity.py +++ b/temporalio/contrib/strands/_model_activity.py @@ -1,5 +1,5 @@ from collections.abc import AsyncIterable, Callable -from dataclasses import dataclass +from dataclasses import dataclass, field from datetime import timedelta from typing import Any @@ -17,8 +17,9 @@ # through unchanged to ``Model.stream`` which accepts the raw dicts. @dataclass class _InvokeModelInput: - model_name: str + model_name: str | None messages: Any + invocation_state: dict[str, Any] = field(default_factory=dict) tool_specs: Any = None system_prompt: str | None = None tool_choice: Any = None @@ -39,7 +40,16 @@ def __init__(self, factories: dict[str, Callable[[], Model]]) -> None: self._factories = factories self._models: dict[str, Model] = {} - def _get_model(self, name: str) -> Model: + def _get_model(self, name: str | None) -> Model: + if name is None: + if len(self._factories) != 1: + raise ValueError( + f"TemporalAgent constructed without an explicit `model`, " + f"but the plugin has {len(self._factories)} models registered. " + f"Pass model='...' to disambiguate. " + f"Known: {sorted(self._factories)}" + ) + name = next(iter(self._factories)) if name not in self._models: if name not in self._factories: raise ValueError( @@ -81,4 +91,5 @@ def _stream(model: Model, input: _InvokeModelInput) -> AsyncIterable[StreamEvent input.system_prompt, tool_choice=input.tool_choice, system_prompt_content=input.system_prompt_content, + invocation_state=input.invocation_state, ) diff --git a/temporalio/contrib/strands/_plugin.py b/temporalio/contrib/strands/_plugin.py index 63ceb01ed..94c887c93 100644 --- a/temporalio/contrib/strands/_plugin.py +++ b/temporalio/contrib/strands/_plugin.py @@ -4,6 +4,7 @@ import strands.models.model as _strands_model from strands.models import Model +from strands.models.bedrock import BedrockModel from strands.tools.mcp.mcp_types import MCPTransport from temporalio.contrib.pydantic import pydantic_data_converter @@ -46,7 +47,13 @@ def __init__( models: dict[str, Callable[[], Model]] | None = None, mcp_clients: dict[str, Callable[[], MCPTransport]] | None = None, ) -> None: - """Build the plugin from optional model and MCP transport factories.""" + """Build the plugin from optional model and MCP transport factories. + + If ``models`` is omitted, registers a single ``BedrockModel()`` factory + under the name ``"bedrock"``, matching Strands' own implicit default. + """ + if models is None: + models = {"bedrock": lambda: BedrockModel()} activities: list[Callable] = [] if models: ma = ModelActivity(models) diff --git a/temporalio/contrib/strands/_temporal_agent.py b/temporalio/contrib/strands/_temporal_agent.py index 989ae0cce..01a1dd225 100644 --- a/temporalio/contrib/strands/_temporal_agent.py +++ b/temporalio/contrib/strands/_temporal_agent.py @@ -33,7 +33,7 @@ class TemporalAgent(Agent): def __init__( self, *, - model: str, + model: str | None = None, task_queue: str | None = None, schedule_to_close_timeout: timedelta | None = None, schedule_to_start_timeout: timedelta | None = None, diff --git a/temporalio/contrib/strands/_temporal_model.py b/temporalio/contrib/strands/_temporal_model.py index f77e68ad8..8874dbc9a 100644 --- a/temporalio/contrib/strands/_temporal_model.py +++ b/temporalio/contrib/strands/_temporal_model.py @@ -1,3 +1,4 @@ +import json from collections.abc import AsyncIterable from datetime import timedelta from typing import Any @@ -18,6 +19,24 @@ ) +def _filter_serializable(state: dict[str, Any]) -> dict[str, Any]: + """Keep invocation_state entries that JSON-serialize; drop the rest with a debug log.""" + clean: dict[str, Any] = {} + dropped: list[str] = [] + for key, value in state.items(): + try: + json.dumps(value) + except (TypeError, ValueError): + dropped.append(key) + continue + clean[key] = value + if dropped: + workflow.logger.debug( + f"Dropping non-serializable invocation_state keys: {dropped}" + ) + return clean + + class TemporalModel(Model): """A Strands :class:`Model` that runs ``stream()`` as a Temporal activity. @@ -38,7 +57,7 @@ class TemporalModel(Model): def __init__( self, - model_name: str, + model_name: str | None = None, *, task_queue: str | None = None, schedule_to_close_timeout: timedelta | None = None, @@ -98,12 +117,14 @@ async def stream( **kwargs: Any, ) -> AsyncIterable[StreamEvent]: """Run the model via the registered Temporal activity and yield events.""" + clean_state = _filter_serializable(invocation_state) if invocation_state else {} if self._streaming_topic is not None: events = await workflow.execute_activity_method( ModelActivity.invoke_model_streaming, _StreamingInvokeModelInput( model_name=self._model_name, messages=messages, + invocation_state=clean_state, tool_specs=tool_specs, system_prompt=system_prompt, tool_choice=tool_choice, @@ -119,6 +140,7 @@ async def stream( _InvokeModelInput( model_name=self._model_name, messages=messages, + invocation_state=clean_state, tool_specs=tool_specs, system_prompt=system_prompt, tool_choice=tool_choice, diff --git a/tests/contrib/strands/test_invocation_state.py b/tests/contrib/strands/test_invocation_state.py new file mode 100644 index 000000000..01fd4e004 --- /dev/null +++ b/tests/contrib/strands/test_invocation_state.py @@ -0,0 +1,82 @@ +from collections.abc import AsyncIterable +from datetime import timedelta +from typing import Any +from uuid import uuid4 + +from strands.models import Model +from strands.types.streaming import StreamEvent + +from temporalio import workflow +from temporalio.client import Client +from temporalio.contrib.strands import StrandsPlugin, TemporalAgent +from temporalio.worker import Worker + +# Worker-side sink: the recording model writes the invocation_state it +# received here so the test body can inspect it after the workflow completes. +_RECEIVED: list[dict[str, Any]] = [] + + +class _RecordingModel(Model): + def update_config(self, **_model_config: Any) -> None: + return None + + def get_config(self) -> dict[str, Any]: + return {} + + def structured_output(self, *_args: Any, **_kwargs: Any) -> Any: + raise NotImplementedError + + async def stream( + self, + *_args: Any, + invocation_state: dict[str, Any] | None = None, + **_kwargs: Any, + ) -> AsyncIterable[StreamEvent]: + _RECEIVED.append(invocation_state or {}) + yield {"messageStart": {"role": "assistant"}} + yield {"contentBlockDelta": {"delta": {"text": "ok"}}} + yield {"contentBlockStop": {}} + yield {"messageStop": {"stopReason": "end_turn"}} + + +@workflow.defn +class _InvocationStateWorkflow: + def __init__(self) -> None: + self.agent = TemporalAgent( + model="recording", + start_to_close_timeout=timedelta(seconds=15), + ) + + @workflow.run + async def run(self, prompt: str) -> str: + result = await self.agent.invoke_async( + prompt, + invocation_state={"user_key": "user_value", "non_json": object()}, + ) + return str(result) + + +async def test_invocation_state_round_trip(client: Client): + _RECEIVED.clear() + plugin = StrandsPlugin(models={"recording": lambda: _RecordingModel()}) + + async with Worker( + client, + task_queue="test_invocation_state", + workflows=[_InvocationStateWorkflow], + plugins=[plugin], + max_cached_workflows=0, + ): + await client.execute_workflow( + _InvocationStateWorkflow.run, + "hi", + id=f"test_invocation_state_{uuid4()}", + task_queue="test_invocation_state", + ) + + # The serializable key crosses the activity boundary; the non-serializable + # one is dropped before dispatch (with a debug log). + assert _RECEIVED, "model.stream() was not called" + received = _RECEIVED[0] + assert received.get("user_key") == "user_value" + assert "non_json" not in received From efc6fe7c49e47e401115cfebf114c06ede95348b Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Tue, 19 May 2026 15:59:24 -0700 Subject: [PATCH 34/46] contrib/strands: gate implicit model resolution behind the plugin default Drop the single-entry guess for TemporalAgent(model=None). Implicit resolution is now valid only when StrandsPlugin auto-registers its own BedrockModel default; any user-supplied models= forces every TemporalAgent to pass model= explicitly. Track the gate via a default_name field on ModelActivity that the plugin sets only on the auto-registered path. --- temporalio/contrib/strands/README.md | 6 +---- temporalio/contrib/strands/_model_activity.py | 25 +++++++++++++------ temporalio/contrib/strands/_plugin.py | 4 ++- 3 files changed, 22 insertions(+), 13 deletions(-) diff --git a/temporalio/contrib/strands/README.md b/temporalio/contrib/strands/README.md index deb3c9a5c..b34564261 100644 --- a/temporalio/contrib/strands/README.md +++ b/temporalio/contrib/strands/README.md @@ -79,9 +79,7 @@ Note: Use `agent.invoke_async(message)` instead of `agent(message)`. The synchro ## Models -`StrandsPlugin(models=...)` takes a mapping of `name → factory`. Each factory is called lazily on first use (on the worker, outside the workflow sandbox) and the constructed model is cached for the worker's lifetime. `TemporalAgent(model="name", ...)` selects which factory to invoke and carries the activity options for that agent's model calls. - -If `models` is omitted, the plugin registers a single `BedrockModel()` factory under the name `"bedrock"`, matching Strands' own implicit default. That's why the quickstart can drop `models=` entirely. +`StrandsPlugin(models=...)` takes a mapping of `name → factory`. Each factory is called lazily on first use (on the worker, outside the workflow sandbox) and the constructed model is cached for the worker's lifetime. `TemporalAgent(model="name", ...)` selects which factory to invoke and carries the activity options for that agent's model calls. If `models` is omitted, the plugin registers a single `BedrockModel()` factory under the name `"bedrock"`, matching Strands' own implicit default. ```python from strands.models.anthropic import AnthropicModel @@ -109,8 +107,6 @@ Worker(..., plugins=[StrandsPlugin(models={ Each `TemporalAgent` carries its own activity options (timeouts, retry policy, task queue, streaming topic) and dispatches to the shared model activity, which resolves the model name against the registered factories at runtime. A name not present in `models` raises `ValueError` inside the activity. -If `models` has a single entry, `TemporalAgent` may be constructed without an explicit `model=` and the sole factory is used automatically. - ## Retries `TemporalAgent` disables Strands' built-in `ModelRetryStrategy` so retries are handled exclusively by Temporal. Configure retries via `retry_policy` on `TemporalAgent`, and on the activity options accepted by `workflow.activity_as_tool`, `workflow.activity_as_hook`, and `TemporalMCPClient`: diff --git a/temporalio/contrib/strands/_model_activity.py b/temporalio/contrib/strands/_model_activity.py index d19756ec5..8dd4aedff 100644 --- a/temporalio/contrib/strands/_model_activity.py +++ b/temporalio/contrib/strands/_model_activity.py @@ -35,21 +35,32 @@ class _StreamingInvokeModelInput(_InvokeModelInput): class ModelActivity: """Holds the registered model factories and exposes the model activities.""" - def __init__(self, factories: dict[str, Callable[[], Model]]) -> None: - """Store the factories; models are constructed lazily on first use.""" + def __init__( + self, + factories: dict[str, Callable[[], Model]], + *, + default_name: str | None = None, + ) -> None: + """Store the factories; models are constructed lazily on first use. + + ``default_name`` is set only by the plugin's own auto-registered + ``BedrockModel`` default. User-supplied ``models`` leave it ``None``, + which forces every ``TemporalAgent`` to specify ``model=`` explicitly. + """ self._factories = factories + self._default_name = default_name self._models: dict[str, Model] = {} def _get_model(self, name: str | None) -> Model: if name is None: - if len(self._factories) != 1: + if self._default_name is None: raise ValueError( - f"TemporalAgent constructed without an explicit `model`, " - f"but the plugin has {len(self._factories)} models registered. " - f"Pass model='...' to disambiguate. " + f"TemporalAgent was constructed without an explicit `model`, " + f"but the plugin was configured with user-supplied `models=`. " + f"Pass model='...' to TemporalAgent. " f"Known: {sorted(self._factories)}" ) - name = next(iter(self._factories)) + name = self._default_name if name not in self._models: if name not in self._factories: raise ValueError( diff --git a/temporalio/contrib/strands/_plugin.py b/temporalio/contrib/strands/_plugin.py index 94c887c93..f3cc762ce 100644 --- a/temporalio/contrib/strands/_plugin.py +++ b/temporalio/contrib/strands/_plugin.py @@ -52,11 +52,13 @@ def __init__( If ``models`` is omitted, registers a single ``BedrockModel()`` factory under the name ``"bedrock"``, matching Strands' own implicit default. """ + default_name: str | None = None if models is None: models = {"bedrock": lambda: BedrockModel()} + default_name = "bedrock" activities: list[Callable] = [] if models: - ma = ModelActivity(models) + ma = ModelActivity(models, default_name=default_name) activities.extend([ma.invoke_model, ma.invoke_model_streaming]) mcp_clients = mcp_clients or {} From c6502658e89050bdd6225083f9d5a03397823fe9 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Tue, 19 May 2026 16:08:32 -0700 Subject: [PATCH 35/46] contrib/strands: mark terminal Strands exceptions non-retryable Extend StrandsFailureConverter.to_failure to translate Strands' terminal model/session exceptions into ApplicationError(non_retryable=True, type=...): MaxTokensReachedException, ContextWindowOverflowException, StructuredOutputException, SessionException. These deterministic failures won't succeed on retry, so the typed annotation stops Temporal's retry policy from churning on them. ModelThrottledException stays retryable. --- .../contrib/strands/_failure_converter.py | 35 +++++++++++++++++-- 1 file changed, 32 insertions(+), 3 deletions(-) diff --git a/temporalio/contrib/strands/_failure_converter.py b/temporalio/contrib/strands/_failure_converter.py index 407316723..d173091e7 100644 --- a/temporalio/contrib/strands/_failure_converter.py +++ b/temporalio/contrib/strands/_failure_converter.py @@ -1,6 +1,12 @@ -"""Failure converter that preserves Strands ``InterruptException`` payloads.""" +"""Failure converter for Strands-specific exceptions.""" from strands.interrupt import InterruptException +from strands.types.exceptions import ( + ContextWindowOverflowException, + MaxTokensReachedException, + SessionException, + StructuredOutputException, +) import temporalio.api.failure.v1 from temporalio.converter import DefaultFailureConverter, PayloadConverter @@ -14,9 +20,21 @@ # can rebuild a real ``Interrupt``. STRANDS_INTERRUPT_TYPE = "StrandsInterrupt" +# Strands' model/session exceptions that are deterministic failures (token +# limits, context overflow, structured-output validation, session I/O). They +# won't succeed on retry, so they cross the boundary as non-retryable typed +# ApplicationErrors. TemporalAgent.invoke_async rewraps these as +# StrandsWorkflowError on the workflow side so users can `except` cleanly. +_TERMINAL_EXCEPTIONS: tuple[type[BaseException], ...] = ( + MaxTokensReachedException, + ContextWindowOverflowException, + StructuredOutputException, + SessionException, +) + class StrandsFailureConverter(DefaultFailureConverter): - """Failure converter that preserves Strands ``InterruptException`` payloads.""" + """Failure converter that preserves Strands exception payloads and retryability.""" def to_failure( self, @@ -24,7 +42,7 @@ def to_failure( payload_converter: PayloadConverter, failure: temporalio.api.failure.v1.Failure, ) -> None: - """Translate ``InterruptException`` to a typed ``ApplicationError``.""" + """Translate Strands exceptions to typed ``ApplicationError``s.""" if isinstance(exception, InterruptException): super().to_failure( ApplicationError( @@ -37,4 +55,15 @@ def to_failure( failure, ) return + if isinstance(exception, _TERMINAL_EXCEPTIONS): + super().to_failure( + ApplicationError( + str(exception), + type=type(exception).__name__, + non_retryable=True, + ), + payload_converter, + failure, + ) + return super().to_failure(exception, payload_converter, failure) From 15a9ab3a32549667344816a90dc3c38e98fa748d Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Tue, 19 May 2026 21:59:07 -0700 Subject: [PATCH 36/46] contrib/strands: accept MCPClient factories in mcp_clients --- temporalio/contrib/strands/README.md | 11 +++++++---- temporalio/contrib/strands/_plugin.py | 12 ++++++------ temporalio/contrib/strands/_temporal_mcp_client.py | 10 +++++----- tests/contrib/strands/test_mcp.py | 11 +++++++---- 4 files changed, 25 insertions(+), 19 deletions(-) diff --git a/temporalio/contrib/strands/README.md b/temporalio/contrib/strands/README.md index b34564261..2b913bcba 100644 --- a/temporalio/contrib/strands/README.md +++ b/temporalio/contrib/strands/README.md @@ -380,10 +380,11 @@ class ChatWorkflow: ## MCP -`StrandsPlugin(mcp_clients=...)` takes a mapping of `name → transport factory`, mirroring the `models=` pattern. The plugin registers a per-server `{name}-call-tool` activity and connects at worker startup to enumerate tools. Workflow-side, `TemporalMCPClient(server="name")` is a pure handle: it references the server by name and carries the per-call activity options. +`StrandsPlugin(mcp_clients=...)` takes a mapping of `name → MCPClient factory`, mirroring the `models=` pattern. The plugin registers a per-server `{name}-call-tool` activity and connects at worker startup to enumerate tools. Workflow-side, `TemporalMCPClient(server="name")` is a pure handle: it references the server by name and carries the per-call activity options. ```python from mcp import StdioServerParameters, stdio_client +from strands.tools.mcp.mcp_client import MCPClient from temporalio.contrib.strands import TemporalMCPClient # workflow @@ -401,15 +402,17 @@ Worker( ..., plugins=[StrandsPlugin( mcp_clients={ - "echo": lambda: stdio_client( - StdioServerParameters(command="...", args=[...]), + "echo": lambda: MCPClient( + lambda: stdio_client( + StdioServerParameters(command="...", args=[...]), + ), ), }, )], ) ``` -The plugin connects to each MCP server once at worker startup to enumerate tools. The schema is frozen for the worker's lifetime; restart workers to pick up MCP-server changes. If a server is unavailable at startup, the worker fails to start. +Each factory returns a fully configured `MCPClient`, so you can pass options like `tool_filters`, `prefix`, `elicitation_callback`, or `tasks_config` to it. The plugin connects to each MCP server once at worker startup to enumerate tools. The schema is frozen for the worker's lifetime; restart workers to pick up MCP-server changes. If a server is unavailable at startup, the worker fails to start. ## Observability diff --git a/temporalio/contrib/strands/_plugin.py b/temporalio/contrib/strands/_plugin.py index f3cc762ce..7bb57d3a4 100644 --- a/temporalio/contrib/strands/_plugin.py +++ b/temporalio/contrib/strands/_plugin.py @@ -5,7 +5,7 @@ import strands.models.model as _strands_model from strands.models import Model from strands.models.bedrock import BedrockModel -from strands.tools.mcp.mcp_types import MCPTransport +from strands.tools.mcp.mcp_client import MCPClient from temporalio.contrib.pydantic import pydantic_data_converter from temporalio.converter import DataConverter, DefaultPayloadConverter @@ -45,7 +45,7 @@ def __init__( self, *, models: dict[str, Callable[[], Model]] | None = None, - mcp_clients: dict[str, Callable[[], MCPTransport]] | None = None, + mcp_clients: dict[str, Callable[[], MCPClient]] | None = None, ) -> None: """Build the plugin from optional model and MCP transport factories. @@ -62,13 +62,13 @@ def __init__( activities.extend([ma.invoke_model, ma.invoke_model_streaming]) mcp_clients = mcp_clients or {} - for server, transport_factory in mcp_clients.items(): - activities.append(build_call_tool_activity(server, transport_factory)) + for server, client_factory in mcp_clients.items(): + activities.append(build_call_tool_activity(server, client_factory)) @asynccontextmanager async def run_context() -> AsyncGenerator[None, None]: - for server, transport_factory in mcp_clients.items(): - await populate_cache(server, transport_factory) + for server, client_factory in mcp_clients.items(): + await populate_cache(server, client_factory) try: yield finally: diff --git a/temporalio/contrib/strands/_temporal_mcp_client.py b/temporalio/contrib/strands/_temporal_mcp_client.py index 542782594..4c61f39f5 100644 --- a/temporalio/contrib/strands/_temporal_mcp_client.py +++ b/temporalio/contrib/strands/_temporal_mcp_client.py @@ -5,7 +5,7 @@ from strands.tools.mcp.mcp_agent_tool import MCPAgentTool from strands.tools.mcp.mcp_client import MCPClient -from strands.tools.mcp.mcp_types import MCPToolResult, MCPTransport +from strands.tools.mcp.mcp_types import MCPToolResult from strands.tools.tool_provider import ToolProvider from strands.types.tools import AgentTool @@ -101,10 +101,10 @@ def remove_consumer(self, consumer_id: Any, **_kwargs: Any) -> None: async def populate_cache( - server: str, transport_factory: Callable[[], MCPTransport] + server: str, client_factory: Callable[[], MCPClient] ) -> None: """Connect to the MCP server, list tools, fill ``_TOOL_CACHE``.""" - client = MCPClient(transport_factory) + client = client_factory() try: infos: list[_MCPToolInfo] = [] for tool in await client.load_tools(): @@ -129,13 +129,13 @@ def clear_cache(server: str) -> None: def build_call_tool_activity( - server: str, transport_factory: Callable[[], MCPTransport] + server: str, client_factory: Callable[[], MCPClient] ) -> Callable: """Return the per-server ``{server}-call-tool`` activity for registration.""" @activity.defn(name=f"{server}-call-tool") async def call_tool(args: _CallToolArgs) -> MCPToolResult: - client = MCPClient(transport_factory) + client = client_factory() client.start() try: return await client.call_tool_async( diff --git a/tests/contrib/strands/test_mcp.py b/tests/contrib/strands/test_mcp.py index d03a3d7b4..ab8fffcbc 100644 --- a/tests/contrib/strands/test_mcp.py +++ b/tests/contrib/strands/test_mcp.py @@ -4,6 +4,7 @@ from uuid import uuid4 from mcp import StdioServerParameters, stdio_client +from strands.tools.mcp.mcp_client import MCPClient from temporalio import workflow from temporalio.client import Client @@ -48,10 +49,12 @@ async def test_mcp(client: Client): ) }, mcp_clients={ - "echo": lambda: stdio_client( - StdioServerParameters( - command=sys.executable, - args=[str(Path(__file__).parent / "echo_mcp_server.py")], + "echo": lambda: MCPClient( + lambda: stdio_client( + StdioServerParameters( + command=sys.executable, + args=[str(Path(__file__).parent / "echo_mcp_server.py")], + ) ) ), }, From ff60a7909474c1849baa92cccd664f9f815afc53 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Wed, 20 May 2026 09:00:13 -0700 Subject: [PATCH 37/46] contrib/strands: drop _get_encoding monkey patch strands-agents 1.39.0 removed _get_encoding and routes count_tokens straight to the chars-per-token heuristic, so the patch is a no-op. Bump the floor pin to 1.39.0 to keep that assumption true. --- pyproject.toml | 4 ++-- temporalio/contrib/strands/_plugin.py | 5 ----- uv.lock | 4 ++-- 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 580c7611b..ceac8840e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ aioboto3 = [ "types-aioboto3[s3]>=10.4.0", ] strands = [ - "strands-agents>=1.38.0", + "strands-agents>=1.39.0", ] [project.urls] @@ -90,7 +90,7 @@ dev = [ "opentelemetry-semantic-conventions>=0.40b0,<1", "opentelemetry-sdk-extension-aws>=2.0.0,<3", "async-timeout>=4.0,<6; python_version < '3.11'", - "strands-agents>=1.38.0", + "strands-agents>=1.39.0", "strands-agents-tools>=0.5.2", ] diff --git a/temporalio/contrib/strands/_plugin.py b/temporalio/contrib/strands/_plugin.py index 7bb57d3a4..918883e07 100644 --- a/temporalio/contrib/strands/_plugin.py +++ b/temporalio/contrib/strands/_plugin.py @@ -2,7 +2,6 @@ from contextlib import asynccontextmanager from dataclasses import replace -import strands.models.model as _strands_model from strands.models import Model from strands.models.bedrock import BedrockModel from strands.tools.mcp.mcp_client import MCPClient @@ -21,10 +20,6 @@ populate_cache, ) -# Force Strands' base Model.count_tokens to avoid tiktoken, which lazily downloads -# an encoding file. Use the default chars-per-token heuristic instead (deterministic). -setattr(_strands_model, "_get_encoding", lambda: None) - class StrandsPlugin(SimplePlugin): """Temporal Worker plugin for the Strands Agents SDK. diff --git a/uv.lock b/uv.lock index eb1e42bc4..91d131594 100644 --- a/uv.lock +++ b/uv.lock @@ -5510,7 +5510,7 @@ requires-dist = [ { name = "protobuf", specifier = ">=3.20,<7.0.0" }, { name = "pydantic", marker = "extra == 'pydantic'", specifier = ">=2.0.0,<3" }, { name = "python-dateutil", marker = "python_full_version < '3.11'", specifier = ">=2.8.2,<3" }, - { name = "strands-agents", marker = "extra == 'strands'", specifier = ">=1.38.0" }, + { name = "strands-agents", marker = "extra == 'strands'", specifier = ">=1.39.0" }, { name = "types-aioboto3", extras = ["s3"], marker = "extra == 'aioboto3'", specifier = ">=10.4.0" }, { name = "types-protobuf", specifier = ">=3.20,<7.0.0" }, { name = "typing-extensions", specifier = ">=4.2.0,<5" }, @@ -5552,7 +5552,7 @@ dev = [ { name = "pytest-xdist", specifier = ">=3.6,<4" }, { name = "ruff", specifier = ">=0.15.12,<0.16" }, { name = "setuptools", specifier = "<82" }, - { name = "strands-agents", specifier = ">=1.38.0" }, + { name = "strands-agents", specifier = ">=1.39.0" }, { name = "strands-agents-tools", specifier = ">=0.5.2" }, { name = "toml", specifier = ">=0.10.2,<0.11" }, { name = "twine", specifier = ">=4.0.1,<5" }, From 67ecb642d788b0dbac787228cfe415ff273c8295 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Wed, 20 May 2026 09:14:00 -0700 Subject: [PATCH 38/46] contrib/strands: swap current_time for shell in demos and tests --- temporalio/contrib/strands/README.md | 12 ++++++------ tests/contrib/strands/test_tool.py | 20 ++++++++++---------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/temporalio/contrib/strands/README.md b/temporalio/contrib/strands/README.md index 2b913bcba..4007ad2a9 100644 --- a/temporalio/contrib/strands/README.md +++ b/temporalio/contrib/strands/README.md @@ -175,30 +175,30 @@ async for item in WorkflowStreamClient.create(client, workflow_id).subscribe( Decorate non-deterministic tools with `@activity.defn`, or if you're importing tools from `strands_tools`, wrap them in a thin async function. Then, register the activity on the worker via `Worker(activities=[...])` and pass it to the agent with `workflow.activity_as_tool(activity, **options)` along with any activity options (e.g. `start_to_close_timeout`): ```python -from strands_tools import current_time +from strands_tools import shell from temporalio.contrib.strands import workflow as strands_workflow @activity.defn async def fetch_user(user_id: str) -> dict: ... -@activity.defn(name="current_time") -async def current_time_activity() -> str: - return current_time.current_time() +@activity.defn(name="shell") +async def shell_activity(command: str) -> dict: + return shell.shell(command=command, non_interactive=True) # workflow agent = TemporalAgent( start_to_close_timeout=timedelta(seconds=60), tools=[ strands_workflow.activity_as_tool(fetch_user, start_to_close_timeout=timedelta(seconds=30)), - strands_workflow.activity_as_tool(current_time_activity, start_to_close_timeout=timedelta(seconds=15)), + strands_workflow.activity_as_tool(shell_activity, start_to_close_timeout=timedelta(seconds=15)), ], ) # worker Worker( ..., - activities=[fetch_user, current_time_activity], + activities=[fetch_user, shell_activity], plugins=[StrandsPlugin(models=MODELS)], ) ``` diff --git a/tests/contrib/strands/test_tool.py b/tests/contrib/strands/test_tool.py index e89f6f9b7..7548dd3de 100644 --- a/tests/contrib/strands/test_tool.py +++ b/tests/contrib/strands/test_tool.py @@ -4,7 +4,7 @@ from strands import tool from strands_tools import ( # pyright: ignore[reportMissingTypeStubs] calculator, - current_time, + shell, ) from temporalio import activity, workflow @@ -21,9 +21,9 @@ def letter_counter(word: str, letter: str) -> int: return word.lower().count(letter.lower()) -@activity.defn(name="current_time") -async def current_time_activity() -> str: - return current_time.current_time() +@activity.defn(name="shell") +async def shell_activity(command: str) -> dict: + return shell.shell(command=command, non_interactive=True) @workflow.defn @@ -35,7 +35,7 @@ def __init__(self) -> None: tools=[ calculator, activity_as_tool( - current_time_activity, + shell_activity, start_to_close_timeout=timedelta(seconds=15), ), letter_counter, @@ -54,7 +54,7 @@ async def test_tool(client: Client): models={ "mock": lambda: MockModel( [ - {"name": "current_time", "input": {}}, + {"name": "shell", "input": {"command": "echo hello"}}, { "name": "calculator", "input": {"expression": "3111696 / 74088"}, @@ -73,14 +73,14 @@ async def test_tool(client: Client): client, task_queue=task_queue, workflows=[ToolWorkflow], - activities=[current_time_activity], + activities=[shell_activity], plugins=[plugin], max_cached_workflows=0, ): handle = await client.start_workflow( ToolWorkflow.run, - "I have 4 requests:\n" - "1. What is the time right now?\n" + "I have 3 requests:\n" + "1. Run `echo hello` in a shell\n" "2. Calculate 3111696 / 74088\n" '3. Tell me how many letter R\'s are in the word "strawberry" 🍓', id=f"test_tool_{uuid4()}", @@ -91,7 +91,7 @@ async def test_tool(client: Client): history = await handle.fetch_history() assert get_activities(history) == [ "invoke_model", - "current_time", + "shell", "invoke_model", # calculator (in-workflow) "invoke_model", From 17cc14f35bf8468fa3b4b31940a80826942ead3f Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Wed, 20 May 2026 09:14:00 -0700 Subject: [PATCH 39/46] contrib/strands: ruff format populate_cache signature --- temporalio/contrib/strands/_temporal_mcp_client.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/temporalio/contrib/strands/_temporal_mcp_client.py b/temporalio/contrib/strands/_temporal_mcp_client.py index 4c61f39f5..110b121ed 100644 --- a/temporalio/contrib/strands/_temporal_mcp_client.py +++ b/temporalio/contrib/strands/_temporal_mcp_client.py @@ -100,9 +100,7 @@ def remove_consumer(self, consumer_id: Any, **_kwargs: Any) -> None: return None -async def populate_cache( - server: str, client_factory: Callable[[], MCPClient] -) -> None: +async def populate_cache(server: str, client_factory: Callable[[], MCPClient]) -> None: """Connect to the MCP server, list tools, fill ``_TOOL_CACHE``.""" client = client_factory() try: From bb8d7c3f01e9a5bf0332212097eab8377defbc2d Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Wed, 20 May 2026 16:27:12 -0700 Subject: [PATCH 40/46] contrib/strands: switch test_structured_output to TemporalAgent Also refresh stale Agent(...) references in _temporal_mcp_client and _temporal_model docstrings to point at TemporalAgent(...). --- .../contrib/strands/_temporal_mcp_client.py | 6 +-- temporalio/contrib/strands/_temporal_model.py | 10 ++--- .../contrib/strands/test_structured_output.py | 42 ++++++++++--------- 3 files changed, 29 insertions(+), 29 deletions(-) diff --git a/temporalio/contrib/strands/_temporal_mcp_client.py b/temporalio/contrib/strands/_temporal_mcp_client.py index 110b121ed..1dcad70a2 100644 --- a/temporalio/contrib/strands/_temporal_mcp_client.py +++ b/temporalio/contrib/strands/_temporal_mcp_client.py @@ -44,9 +44,9 @@ class TemporalMCPClient(ToolProvider): carries the server name (which selects the registered factory) and the per-call activity options. - Construct once at module level and pass to ``Agent(tools=[...])`` inside - the workflow. Multiple handles may reference the same server name with - different activity options. + Construct once at module level and pass to ``TemporalAgent(tools=[...])`` + inside the workflow. Multiple handles may reference the same server name + with different activity options. """ def __init__( diff --git a/temporalio/contrib/strands/_temporal_model.py b/temporalio/contrib/strands/_temporal_model.py index 8874dbc9a..e9e85b1ee 100644 --- a/temporalio/contrib/strands/_temporal_model.py +++ b/temporalio/contrib/strands/_temporal_model.py @@ -45,10 +45,6 @@ class TemporalModel(Model): :class:`TemporalModel` itself does no I/O, so it is safe to instantiate at module level. - Pass this instance to ``Agent(model=...)`` inside the workflow; each call - dispatches through the registered model activity with ``model_name`` in - the input, and the worker resolves it against the plugin's factories. - When ``streaming_topic`` is set, each ``StreamEvent`` is also published to the named topic on the workflow's :class:`temporalio.contrib.workflow_streams.WorkflowStream` for external @@ -98,11 +94,11 @@ def get_config(self) -> dict[str, Any]: return {} def structured_output(self, *_args: Any, **_kwargs: Any) -> Any: - """Not supported; use ``Agent(structured_output_model=...)`` instead.""" + """Not supported; use ``TemporalAgent(structured_output_model=...)`` instead.""" raise NotImplementedError( "TemporalModel.structured_output is not supported. Use " - "Agent(structured_output_model=...) which routes structured output " - "through stream() via the structured_output_tool." + "TemporalAgent(structured_output_model=...) which routes structured " + "output through stream() via the structured_output_tool." ) async def stream( diff --git a/tests/contrib/strands/test_structured_output.py b/tests/contrib/strands/test_structured_output.py index 8018a9d16..18c77c553 100644 --- a/tests/contrib/strands/test_structured_output.py +++ b/tests/contrib/strands/test_structured_output.py @@ -1,11 +1,11 @@ +from datetime import timedelta from uuid import uuid4 from pydantic import BaseModel, Field -from strands import Agent from temporalio import workflow from temporalio.client import Client -from temporalio.contrib.strands import StrandsPlugin +from temporalio.contrib.strands import StrandsPlugin, TemporalAgent from temporalio.worker import Replayer, Worker from tests.contrib.strands.mock_model import MockModel @@ -19,19 +19,11 @@ class PersonInfo(BaseModel): @workflow.defn class StructuredOutputWorkflow: def __init__(self) -> None: - model = MockModel( - [ - { - "name": "PersonInfo", - "input": { - "name": "John Smith", - "age": 30, - "occupation": "software engineer", - }, - }, - ] + self.agent = TemporalAgent( + model="mock", + start_to_close_timeout=timedelta(seconds=15), + structured_output_model=PersonInfo, ) - self.agent = Agent(model=model, structured_output_model=PersonInfo) @workflow.run async def run(self, prompt: str) -> PersonInfo: @@ -42,16 +34,28 @@ async def run(self, prompt: str) -> PersonInfo: async def test_structured_output(client: Client): task_queue = "test_structured_output" - plugin = StrandsPlugin() - - config = client.config() - config["plugins"] = [*config["plugins"], plugin] - client = Client(**config) + plugin = StrandsPlugin( + models={ + "mock": lambda: MockModel( + [ + { + "name": "PersonInfo", + "input": { + "name": "John Smith", + "age": 30, + "occupation": "software engineer", + }, + }, + ] + ) + } + ) async with Worker( client, task_queue=task_queue, workflows=[StructuredOutputWorkflow], + plugins=[plugin], max_cached_workflows=0, ): handle = await client.start_workflow( From 2b54db975290b9dc8d01599d19dc5ff591f2dd0a Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Thu, 21 May 2026 13:35:51 -0700 Subject: [PATCH 41/46] contrib/strands: rename optional dependency to strands-agents --- pyproject.toml | 2 +- temporalio/contrib/strands/README.md | 2 +- uv.lock | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 3911aa054..b8cff8a31 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ lambda-worker-otel = [ "opentelemetry-sdk-extension-aws>=2.0.0,<3", ] aioboto3 = ["aioboto3>=10.4.0", "types-aioboto3[s3]>=10.4.0"] -strands = ["strands-agents>=1.39.0"] +strands-agents = ["strands-agents>=1.39.0"] [project.urls] Homepage = "https://github.com/temporalio/sdk-python" diff --git a/temporalio/contrib/strands/README.md b/temporalio/contrib/strands/README.md index 4007ad2a9..0758b93b5 100644 --- a/temporalio/contrib/strands/README.md +++ b/temporalio/contrib/strands/README.md @@ -7,7 +7,7 @@ This Temporal [Plugin](https://docs.temporal.io/develop/plugins-guide) allows yo ## Installation ```sh -uv add temporalio[strands] +uv add temporalio[strands-agents] ``` ## Quickstart diff --git a/uv.lock b/uv.lock index 520488cdf..976749795 100644 --- a/uv.lock +++ b/uv.lock @@ -5453,7 +5453,7 @@ opentelemetry = [ pydantic = [ { name = "pydantic" }, ] -strands = [ +strands-agents = [ { name = "strands-agents" }, ] @@ -5519,12 +5519,12 @@ requires-dist = [ { name = "protobuf", specifier = ">=3.20,<7.0.0" }, { name = "pydantic", marker = "extra == 'pydantic'", specifier = ">=2.0.0,<3" }, { name = "python-dateutil", marker = "python_full_version < '3.11'", specifier = ">=2.8.2,<3" }, - { name = "strands-agents", marker = "extra == 'strands'", specifier = ">=1.39.0" }, + { name = "strands-agents", marker = "extra == 'strands-agents'", specifier = ">=1.39.0" }, { name = "types-aioboto3", extras = ["s3"], marker = "extra == 'aioboto3'", specifier = ">=10.4.0" }, { name = "types-protobuf", specifier = ">=3.20,<7.0.0" }, { name = "typing-extensions", specifier = ">=4.2.0,<5" }, ] -provides-extras = ["grpc", "opentelemetry", "pydantic", "openai-agents", "google-adk", "langgraph", "langsmith", "lambda-worker-otel", "aioboto3", "strands"] +provides-extras = ["grpc", "opentelemetry", "pydantic", "openai-agents", "google-adk", "langgraph", "langsmith", "lambda-worker-otel", "aioboto3", "strands-agents"] [package.metadata.requires-dev] dev = [ From e083d8e1efc81faba9535e5620286003648da19c Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Thu, 21 May 2026 16:08:50 -0700 Subject: [PATCH 42/46] contrib/strands: fix pydoctor docstring errors `warnings-as-errors = true` was failing CI's gen-docs step on invalid RST inline literals and unresolvable cross-references to the optional strands package. --- temporalio/contrib/strands/_failure_converter.py | 2 +- temporalio/contrib/strands/_temporal_agent.py | 4 ++-- temporalio/contrib/strands/_temporal_model.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/temporalio/contrib/strands/_failure_converter.py b/temporalio/contrib/strands/_failure_converter.py index d173091e7..c387f47f4 100644 --- a/temporalio/contrib/strands/_failure_converter.py +++ b/temporalio/contrib/strands/_failure_converter.py @@ -42,7 +42,7 @@ def to_failure( payload_converter: PayloadConverter, failure: temporalio.api.failure.v1.Failure, ) -> None: - """Translate Strands exceptions to typed ``ApplicationError``s.""" + """Translate Strands exceptions to typed ApplicationErrors.""" if isinstance(exception, InterruptException): super().to_failure( ApplicationError( diff --git a/temporalio/contrib/strands/_temporal_agent.py b/temporalio/contrib/strands/_temporal_agent.py index 01a1dd225..9bc1beb31 100644 --- a/temporalio/contrib/strands/_temporal_agent.py +++ b/temporalio/contrib/strands/_temporal_agent.py @@ -17,12 +17,12 @@ class TemporalAgent(Agent): - """A Strands :class:`Agent` that routes model calls through a Temporal activity. + """A Strands ``Agent`` that routes model calls through a Temporal activity. ``model`` is the name of a factory registered in ``StrandsPlugin(models={...})``. The activity options apply to every model invocation this agent makes. All other keyword arguments are forwarded to - Strands' :class:`Agent` (``tools``, ``hooks``, ``system_prompt``, + Strands' ``Agent`` (``tools``, ``hooks``, ``system_prompt``, ``structured_output_model``, ``messages``, etc.). Strands' ``retry_strategy`` is disabled; configure retries via diff --git a/temporalio/contrib/strands/_temporal_model.py b/temporalio/contrib/strands/_temporal_model.py index e9e85b1ee..29e5c63a2 100644 --- a/temporalio/contrib/strands/_temporal_model.py +++ b/temporalio/contrib/strands/_temporal_model.py @@ -38,11 +38,11 @@ def _filter_serializable(state: dict[str, Any]) -> dict[str, Any]: class TemporalModel(Model): - """A Strands :class:`Model` that runs ``stream()`` as a Temporal activity. + """A Strands ``Model`` that runs ``stream()`` as a Temporal activity. ``model_name`` selects which factory the plugin will invoke worker-side; it must match a key in ``StrandsPlugin(models={...})``. Construction of this - :class:`TemporalModel` itself does no I/O, so it is safe to instantiate at + ``TemporalModel`` itself does no I/O, so it is safe to instantiate at module level. When ``streaming_topic`` is set, each ``StreamEvent`` is also published to From cc724b293da2ee3ab9c7b77c5f35601ca91b5ef8 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Thu, 21 May 2026 16:31:37 -0700 Subject: [PATCH 43/46] tests: fix Windows test collection failures * test_type_errors.py: open test files with encoding="utf-8" so the rglob scan doesn't choke on UTF-8 characters (e.g. the strawberry emoji in test_tool.py) when the host's default codec is cp1252. * test_tool.py: skip the module on Windows; strands_tools.shell pulls in pty -> tty -> termios at import time, which is Unix-only. --- tests/contrib/strands/test_tool.py | 9 +++++++++ tests/test_type_errors.py | 4 ++-- 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/tests/contrib/strands/test_tool.py b/tests/contrib/strands/test_tool.py index 7548dd3de..e8dc6f318 100644 --- a/tests/contrib/strands/test_tool.py +++ b/tests/contrib/strands/test_tool.py @@ -1,6 +1,15 @@ +import sys from datetime import timedelta from uuid import uuid4 +import pytest + +if sys.platform == "win32": + pytest.skip( + "strands_tools.shell uses Unix-only pty/termios", + allow_module_level=True, + ) + from strands import tool from strands_tools import ( # pyright: ignore[reportMissingTypeStubs] calculator, diff --git a/tests/test_type_errors.py b/tests/test_type_errors.py index d8e6e2afb..2b70d7f63 100644 --- a/tests/test_type_errors.py +++ b/tests/test_type_errors.py @@ -86,7 +86,7 @@ def _test_type_errors( def _has_type_error_assertions(test_file: Path) -> bool: """Check if a file contains any type error assertions.""" - with open(test_file) as f: + with open(test_file, encoding="utf-8") as f: return any(re.search(r"# assert-type-error-\w+:", line) for line in f) @@ -94,7 +94,7 @@ def _get_expected_errors(test_file: Path, type_checker: str) -> dict[int, str]: """Parse expected type errors from comments in a file for the specified type checker.""" expected_errors = {} - with open(test_file) as f: + with open(test_file, encoding="utf-8") as f: lines = zip(itertools.count(1), f) for line_num, line in lines: if match := re.search( From ca173ee0da99982290cd9e489f45c27630129510 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Thu, 21 May 2026 16:45:41 -0700 Subject: [PATCH 44/46] tests/contrib/strands: swap shell tool for file_read strands_tools.shell imports pty/tty/termios at module load, which is Unix-only and broke Windows test collection. file_read on a tmp_path fixture is also non-deterministic (depends on filesystem state), has no in-workflow equivalent, and imports cleanly on every platform. --- tests/contrib/strands/test_tool.py | 44 ++++++++++++++++-------------- 1 file changed, 24 insertions(+), 20 deletions(-) diff --git a/tests/contrib/strands/test_tool.py b/tests/contrib/strands/test_tool.py index e8dc6f318..39985e2df 100644 --- a/tests/contrib/strands/test_tool.py +++ b/tests/contrib/strands/test_tool.py @@ -1,19 +1,11 @@ -import sys from datetime import timedelta +from pathlib import Path from uuid import uuid4 -import pytest - -if sys.platform == "win32": - pytest.skip( - "strands_tools.shell uses Unix-only pty/termios", - allow_module_level=True, - ) - from strands import tool from strands_tools import ( # pyright: ignore[reportMissingTypeStubs] calculator, - shell, + file_read, ) from temporalio import activity, workflow @@ -30,9 +22,18 @@ def letter_counter(word: str, letter: str) -> int: return word.lower().count(letter.lower()) -@activity.defn(name="shell") -async def shell_activity(command: str) -> dict: - return shell.shell(command=command, non_interactive=True) +@activity.defn(name="read_file") +async def read_file_activity(path: str) -> str: + result = file_read.file_read( + { + "toolUseId": "read_file", + "name": "file_read", + "input": {"path": path, "mode": "view"}, + } + ) + text = result["content"][0].get("text") + assert text is not None + return text @workflow.defn @@ -44,7 +45,7 @@ def __init__(self) -> None: tools=[ calculator, activity_as_tool( - shell_activity, + read_file_activity, start_to_close_timeout=timedelta(seconds=15), ), letter_counter, @@ -57,13 +58,16 @@ async def run(self, prompt: str) -> str: return str(result) -async def test_tool(client: Client): +async def test_tool(client: Client, tmp_path: Path): task_queue = "test_tool" + fixture = tmp_path / "greeting.txt" + fixture.write_text("hello\n") + plugin = StrandsPlugin( models={ "mock": lambda: MockModel( [ - {"name": "shell", "input": {"command": "echo hello"}}, + {"name": "read_file", "input": {"path": str(fixture)}}, { "name": "calculator", "input": {"expression": "3111696 / 74088"}, @@ -82,16 +86,16 @@ async def test_tool(client: Client): client, task_queue=task_queue, workflows=[ToolWorkflow], - activities=[shell_activity], + activities=[read_file_activity], plugins=[plugin], max_cached_workflows=0, ): handle = await client.start_workflow( ToolWorkflow.run, "I have 3 requests:\n" - "1. Run `echo hello` in a shell\n" + f"1. Read the file at {fixture}\n" "2. Calculate 3111696 / 74088\n" - '3. Tell me how many letter R\'s are in the word "strawberry" 🍓', + '3. Tell me how many letter R\'s are in the word "strawberry"', id=f"test_tool_{uuid4()}", task_queue=task_queue, ) @@ -100,7 +104,7 @@ async def test_tool(client: Client): history = await handle.fetch_history() assert get_activities(history) == [ "invoke_model", - "shell", + "read_file", "invoke_model", # calculator (in-workflow) "invoke_model", From c84320c66a55b9e3ad5284342426f302046c2272 Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Thu, 21 May 2026 22:14:15 -0700 Subject: [PATCH 45/46] Fix Strands MCP client on Python 3.10 --- .../contrib/strands/_temporal_mcp_client.py | 82 +++++++++++++------ 1 file changed, 59 insertions(+), 23 deletions(-) diff --git a/temporalio/contrib/strands/_temporal_mcp_client.py b/temporalio/contrib/strands/_temporal_mcp_client.py index 1dcad70a2..aa8e25c15 100644 --- a/temporalio/contrib/strands/_temporal_mcp_client.py +++ b/temporalio/contrib/strands/_temporal_mcp_client.py @@ -3,6 +3,8 @@ from datetime import timedelta from typing import Any +from mcp import ClientSession +from mcp.types import PaginatedRequestParams, Tool from strands.tools.mcp.mcp_agent_tool import MCPAgentTool from strands.tools.mcp.mcp_client import MCPClient from strands.tools.mcp.mcp_types import MCPToolResult @@ -100,25 +102,56 @@ def remove_consumer(self, consumer_id: Any, **_kwargs: Any) -> None: return None +# Use MCP sessions directly instead of MCPClient's background-thread helpers. +# Those helpers route calls through cross-loop futures that are unreliable on +# Python 3.10 when invoked from Temporal's async worker/activity event loops. +async def _list_mcp_tools(client: MCPClient) -> Sequence[Tool]: + async with client._transport_callable() as (read_stream, write_stream, *_): + async with ClientSession( + read_stream, + write_stream, + elicitation_callback=client._elicitation_callback, + ) as session: + await session.initialize() + tools: list[Tool] = [] + pagination_token = None + while True: + page = await session.list_tools( + params=PaginatedRequestParams(cursor=pagination_token) + if pagination_token is not None + else None + ) + tools.extend(page.tools) + pagination_token = page.nextCursor + if pagination_token is None: + return tools + + +def _agent_tool_for_filtering(client: MCPClient, tool: Tool) -> MCPAgentTool: + if client._prefix: + return MCPAgentTool(tool, client, name_override=f"{client._prefix}_{tool.name}") + return MCPAgentTool(tool, client) + + async def populate_cache(server: str, client_factory: Callable[[], MCPClient]) -> None: """Connect to the MCP server, list tools, fill ``_TOOL_CACHE``.""" client = client_factory() - try: - infos: list[_MCPToolInfo] = [] - for tool in await client.load_tools(): - if not isinstance(tool, MCPAgentTool): - continue - infos.append( - _MCPToolInfo( - name=tool.mcp_tool.name, - description=tool.mcp_tool.description or "", - input_schema=tool.mcp_tool.inputSchema, - output_schema=tool.mcp_tool.outputSchema, - ) + infos: list[_MCPToolInfo] = [] + for tool in await _list_mcp_tools(client): + if not client._should_include_tool_with_filters( + _agent_tool_for_filtering(client, tool), + client._tool_filters, + ): + continue + infos.append( + _MCPToolInfo( + name=tool.name, + description=tool.description or "", + input_schema=tool.inputSchema, + output_schema=tool.outputSchema, ) - _TOOL_CACHE[server] = infos - finally: - client.stop(None, None, None) + ) + _TOOL_CACHE[server] = infos def clear_cache(server: str) -> None: @@ -134,14 +167,17 @@ def build_call_tool_activity( @activity.defn(name=f"{server}-call-tool") async def call_tool(args: _CallToolArgs) -> MCPToolResult: client = client_factory() - client.start() try: - return await client.call_tool_async( - tool_use_id=args.tool_use_id, - name=args.tool_name, - arguments=args.arguments, - ) - finally: - client.stop(None, None, None) + async with client._transport_callable() as (read_stream, write_stream, *_): + async with ClientSession( + read_stream, + write_stream, + elicitation_callback=client._elicitation_callback, + ) as session: + await session.initialize() + result = await session.call_tool(args.tool_name, args.arguments) + return client._handle_tool_result(args.tool_use_id, result) + except Exception as err: + return client._handle_tool_execution_error(args.tool_use_id, err) return call_tool From a4636e46efd5535d1e9cf71f6e1db5d2ba7da05d Mon Sep 17 00:00:00 2001 From: Brian Strauch Date: Fri, 22 May 2026 11:39:23 -0700 Subject: [PATCH 46/46] contrib: inline _heartbeat_decorator into strands and openai_agents Removes the temporalio/contrib/common shared module by duplicating _heartbeat_decorator.py into each plugin and updating the two import sites. --- temporalio/contrib/common/__init__.py | 1 - .../_heartbeat_decorator.py | 0 .../openai_agents/_invoke_model_activity.py | 2 +- .../contrib/strands/_heartbeat_decorator.py | 38 +++++++++++++++++++ temporalio/contrib/strands/_model_activity.py | 2 +- 5 files changed, 40 insertions(+), 3 deletions(-) delete mode 100644 temporalio/contrib/common/__init__.py rename temporalio/contrib/{common => openai_agents}/_heartbeat_decorator.py (100%) create mode 100644 temporalio/contrib/strands/_heartbeat_decorator.py diff --git a/temporalio/contrib/common/__init__.py b/temporalio/contrib/common/__init__.py deleted file mode 100644 index a58b7968f..000000000 --- a/temporalio/contrib/common/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Shared utilities for temporalio.contrib plugins.""" diff --git a/temporalio/contrib/common/_heartbeat_decorator.py b/temporalio/contrib/openai_agents/_heartbeat_decorator.py similarity index 100% rename from temporalio/contrib/common/_heartbeat_decorator.py rename to temporalio/contrib/openai_agents/_heartbeat_decorator.py diff --git a/temporalio/contrib/openai_agents/_invoke_model_activity.py b/temporalio/contrib/openai_agents/_invoke_model_activity.py index 1155d5180..a43f9aeaf 100644 --- a/temporalio/contrib/openai_agents/_invoke_model_activity.py +++ b/temporalio/contrib/openai_agents/_invoke_model_activity.py @@ -43,7 +43,7 @@ from typing_extensions import Required, TypedDict from temporalio import activity -from temporalio.contrib.common._heartbeat_decorator import auto_heartbeater +from temporalio.contrib.openai_agents._heartbeat_decorator import auto_heartbeater from temporalio.contrib.workflow_streams import WorkflowStreamClient from temporalio.exceptions import ApplicationError diff --git a/temporalio/contrib/strands/_heartbeat_decorator.py b/temporalio/contrib/strands/_heartbeat_decorator.py new file mode 100644 index 000000000..7c5b9193d --- /dev/null +++ b/temporalio/contrib/strands/_heartbeat_decorator.py @@ -0,0 +1,38 @@ +import asyncio +from collections.abc import Awaitable, Callable +from functools import wraps +from typing import Any, TypeVar, cast + +from temporalio import activity + +F = TypeVar("F", bound=Callable[..., Awaitable[Any]]) + + +def auto_heartbeater(fn: F) -> F: + """Decorator that heartbeats at half the activity's heartbeat timeout.""" + + @wraps(fn) + async def wrapper(*args: Any, **kwargs: Any) -> Any: + heartbeat_timeout = activity.info().heartbeat_timeout + heartbeat_task = None + if heartbeat_timeout: + heartbeat_task = asyncio.create_task( + _heartbeat_every(heartbeat_timeout.total_seconds() / 2) + ) + try: + return await fn(*args, **kwargs) + finally: + if heartbeat_task: + heartbeat_task.cancel() + try: + await heartbeat_task + except asyncio.CancelledError: + pass + + return cast(F, wrapper) + + +async def _heartbeat_every(delay: float) -> None: + while True: + await asyncio.sleep(delay) + activity.heartbeat() diff --git a/temporalio/contrib/strands/_model_activity.py b/temporalio/contrib/strands/_model_activity.py index 8dd4aedff..fba30658d 100644 --- a/temporalio/contrib/strands/_model_activity.py +++ b/temporalio/contrib/strands/_model_activity.py @@ -7,7 +7,7 @@ from strands.types.streaming import StreamEvent from temporalio import activity -from temporalio.contrib.common._heartbeat_decorator import auto_heartbeater +from temporalio.contrib.strands._heartbeat_decorator import auto_heartbeater from temporalio.contrib.workflow_streams import WorkflowStreamClient