From 3d0f06112fcbe3b479733ff364226fe801db0b73 Mon Sep 17 00:00:00 2001 From: Pratyush Sharma <56130065+pratyush618@users.noreply.github.com> Date: Fri, 8 May 2026 06:39:36 +0530 Subject: [PATCH] test(workflows): hoist worker setup into conftest fixture --- tests/python/conftest.py | 32 +++++- tests/python/test_workflows_caching.py | 68 +++-------- tests/python/test_workflows_conditions.py | 114 +++++++------------ tests/python/test_workflows_fan_out.py | 87 +++++--------- tests/python/test_workflows_gates.py | 55 +++------ tests/python/test_workflows_linear.py | 61 ++++------ tests/python/test_workflows_subworkflow.py | 50 +++----- tests/python/test_workflows_visualization.py | 22 +--- 8 files changed, 171 insertions(+), 318 deletions(-) diff --git a/tests/python/conftest.py b/tests/python/conftest.py index 2d82fd7..556e658 100644 --- a/tests/python/conftest.py +++ b/tests/python/conftest.py @@ -3,13 +3,20 @@ import os import sys import threading -from collections.abc import Generator +from collections.abc import Callable, Generator +from contextlib import AbstractContextManager, contextmanager from pathlib import Path import pytest from taskito import Queue +# Public type alias used by workflow test files for the ``workflow_worker`` +# fixture parameter (mypy requires annotated test parameters under +# ``disallow_untyped_defs``). Importing this from conftest keeps the type +# definition in one place. +WorkflowWorkerFactory = Callable[[], AbstractContextManager[threading.Thread]] + @pytest.fixture def queue(tmp_path: Path) -> Queue: @@ -28,6 +35,29 @@ def run_worker(queue: Queue) -> Generator[threading.Thread]: thread.join(timeout=5) +@pytest.fixture +def workflow_worker(queue: Queue) -> WorkflowWorkerFactory: + """Context-manager factory that starts and stops a worker thread. + + Workflow tests typically run several short worker sessions per test + (start, submit workflow, wait, stop — repeated). Returning a context + manager from one fixture replaces the per-file ``_start_worker`` / + ``_stop_worker`` helpers without changing test semantics. + """ + + @contextmanager + def _ctx() -> Generator[threading.Thread]: + thread = threading.Thread(target=queue.run_worker, daemon=True) + thread.start() + try: + yield thread + finally: + queue._inner.request_shutdown() + thread.join(timeout=5) + + return _ctx + + _PYTEST_EXIT_STATUS: int = 0 diff --git a/tests/python/test_workflows_caching.py b/tests/python/test_workflows_caching.py index e434112..cfba4c8 100644 --- a/tests/python/test_workflows_caching.py +++ b/tests/python/test_workflows_caching.py @@ -3,23 +3,16 @@ from __future__ import annotations import threading +from collections.abc import Callable +from contextlib import AbstractContextManager from taskito import Queue from taskito.workflows import NodeStatus, Workflow, WorkflowState +WorkflowWorkerFactory = Callable[[], AbstractContextManager[threading.Thread]] -def _start_worker(queue: Queue) -> threading.Thread: - thread = threading.Thread(target=queue.run_worker, daemon=True) - thread.start() - return thread - -def _stop_worker(queue: Queue, thread: threading.Thread) -> None: - queue._inner.request_shutdown() - thread.join(timeout=5) - - -def test_result_hash_stored(queue: Queue) -> None: +def test_result_hash_stored(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: """Completed nodes have a non-None result_hash.""" @queue.task() @@ -29,12 +22,9 @@ def ok_task() -> str: wf = Workflow(name="hash_stored") wf.step("a", ok_task) - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) run.wait(timeout=15) - finally: - _stop_worker(queue, worker) # Check the node data via base_run_node_data nodes = queue._inner.get_base_run_node_data(run.id) @@ -46,7 +36,7 @@ def ok_task() -> str: # In practice it's usually populated. -def test_incremental_skips_completed(queue: Queue) -> None: +def test_incremental_skips_completed(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: """Incremental run marks base-completed nodes as CACHE_HIT.""" executed: list[str] = [] @@ -66,23 +56,17 @@ def step_b() -> str: wf.step("b", step_b, after="a") # First run: everything executes. - worker = _start_worker(queue) - try: + with workflow_worker(): run1 = queue.submit_workflow(wf) run1.wait(timeout=15) - finally: - _stop_worker(queue, worker) assert run1.status().state == WorkflowState.COMPLETED executed.clear() # Second run: incremental. - worker = _start_worker(queue) - try: + with workflow_worker(): run2 = queue.submit_workflow(wf, incremental=True, base_run=run1.id) run2.wait(timeout=15) - finally: - _stop_worker(queue, worker) final = run2.status() assert final.state == WorkflowState.COMPLETED @@ -94,7 +78,7 @@ def step_b() -> str: assert executed == [] # nothing re-executed -def test_incremental_reruns_failed(queue: Queue) -> None: +def test_incremental_reruns_failed(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: """Failed nodes in the base run get re-executed.""" call_count = {"n": 0} @@ -110,22 +94,16 @@ def flaky() -> str: wf.step("a", flaky) # First run: fails. - worker = _start_worker(queue) - try: + with workflow_worker(): run1 = queue.submit_workflow(wf) run1.wait(timeout=15) - finally: - _stop_worker(queue, worker) assert run1.status().state == WorkflowState.FAILED # Second run incremental: failed node re-executes. - worker = _start_worker(queue) - try: + with workflow_worker(): run2 = queue.submit_workflow(wf, incremental=True, base_run=run1.id) run2.wait(timeout=15) - finally: - _stop_worker(queue, worker) assert run2.status().state == WorkflowState.COMPLETED assert call_count["n"] == 2 @@ -157,7 +135,7 @@ def test_dirty_propagation(queue: Queue) -> None: assert not cached -def test_cache_hit_is_terminal(queue: Queue) -> None: +def test_cache_hit_is_terminal(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: """CACHE_HIT nodes are terminal and don't block the workflow.""" @queue.task() @@ -169,26 +147,20 @@ def ok_task() -> str: wf.step("b", ok_task, after="a") # First run. - worker = _start_worker(queue) - try: + with workflow_worker(): run1 = queue.submit_workflow(wf) run1.wait(timeout=15) - finally: - _stop_worker(queue, worker) # Second incremental run. - worker = _start_worker(queue) - try: + with workflow_worker(): run2 = queue.submit_workflow(wf, incremental=True, base_run=run1.id) final = run2.wait(timeout=15) - finally: - _stop_worker(queue, worker) # Workflow should complete (CACHE_HIT is terminal). assert final.state == WorkflowState.COMPLETED -def test_full_refresh_ignores_cache(queue: Queue) -> None: +def test_full_refresh_ignores_cache(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: """incremental=False always re-runs everything.""" executed: list[str] = [] @@ -202,22 +174,16 @@ def step_a() -> str: wf.step("a", step_a) # First run. - worker = _start_worker(queue) - try: + with workflow_worker(): run1 = queue.submit_workflow(wf) run1.wait(timeout=15) - finally: - _stop_worker(queue, worker) executed.clear() # Second run without incremental — should re-execute. - worker = _start_worker(queue) - try: + with workflow_worker(): run2 = queue.submit_workflow(wf) run2.wait(timeout=15) - finally: - _stop_worker(queue, worker) assert executed == ["a"] diff --git a/tests/python/test_workflows_conditions.py b/tests/python/test_workflows_conditions.py index 96f5b95..6c48468 100644 --- a/tests/python/test_workflows_conditions.py +++ b/tests/python/test_workflows_conditions.py @@ -3,23 +3,16 @@ from __future__ import annotations import threading +from collections.abc import Callable +from contextlib import AbstractContextManager from taskito import Queue from taskito.workflows import NodeStatus, Workflow, WorkflowContext, WorkflowState +WorkflowWorkerFactory = Callable[[], AbstractContextManager[threading.Thread]] -def _start_worker(queue: Queue) -> threading.Thread: - thread = threading.Thread(target=queue.run_worker, daemon=True) - thread.start() - return thread - -def _stop_worker(queue: Queue, thread: threading.Thread) -> None: - queue._inner.request_shutdown() - thread.join(timeout=5) - - -def test_on_failure_step_runs(queue: Queue) -> None: +def test_on_failure_step_runs(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: """A step with condition='on_failure' runs when predecessor fails.""" @queue.task(max_retries=0) @@ -37,19 +30,18 @@ def cleanup() -> str: wf.step("a", fail_task) wf.step("b", cleanup, after="a", condition="on_failure") - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) final = run.wait(timeout=20) - finally: - _stop_worker(queue, worker) assert final.nodes["a"].status == NodeStatus.FAILED assert final.nodes["b"].status == NodeStatus.COMPLETED assert collected == ["cleanup ran"] -def test_on_failure_step_skipped_on_success(queue: Queue) -> None: +def test_on_failure_step_skipped_on_success( + queue: Queue, workflow_worker: WorkflowWorkerFactory +) -> None: """A step with condition='on_failure' is SKIPPED when predecessor succeeds.""" @queue.task() @@ -64,18 +56,15 @@ def rollback() -> str: wf.step("a", ok_task) wf.step("b", rollback, after="a", condition="on_failure") - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) final = run.wait(timeout=20) - finally: - _stop_worker(queue, worker) assert final.nodes["a"].status == NodeStatus.COMPLETED assert final.nodes["b"].status == NodeStatus.SKIPPED -def test_always_step_runs_on_success(queue: Queue) -> None: +def test_always_step_runs_on_success(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: """A step with condition='always' runs when predecessor succeeds.""" collected: list[str] = [] @@ -93,18 +82,15 @@ def always_task() -> str: wf.step("a", ok_task) wf.step("b", always_task, after="a", condition="always") - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) final = run.wait(timeout=20) - finally: - _stop_worker(queue, worker) assert final.nodes["b"].status == NodeStatus.COMPLETED assert collected == ["always ran"] -def test_always_step_runs_on_failure(queue: Queue) -> None: +def test_always_step_runs_on_failure(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: """A step with condition='always' runs even when predecessor fails.""" collected: list[str] = [] @@ -122,19 +108,16 @@ def always_task() -> str: wf.step("a", fail_task) wf.step("b", always_task, after="a", condition="always") - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) final = run.wait(timeout=20) - finally: - _stop_worker(queue, worker) assert final.nodes["a"].status == NodeStatus.FAILED assert final.nodes["b"].status == NodeStatus.COMPLETED assert collected == ["always ran"] -def test_on_success_default(queue: Queue) -> None: +def test_on_success_default(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: """Default condition (on_success) skips the step when predecessor fails.""" @queue.task(max_retries=0) @@ -149,19 +132,18 @@ def next_task() -> str: wf.step("a", fail_task) wf.step("b", next_task, after="a") - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) final = run.wait(timeout=20) - finally: - _stop_worker(queue, worker) assert final.state == WorkflowState.FAILED assert final.nodes["a"].status == NodeStatus.FAILED assert final.nodes["b"].status == NodeStatus.SKIPPED -def test_continue_mode_independent_branches(queue: Queue) -> None: +def test_continue_mode_independent_branches( + queue: Queue, workflow_worker: WorkflowWorkerFactory +) -> None: """on_failure='continue' lets independent branches keep running.""" order: list[str] = [] @@ -194,12 +176,9 @@ def after_ok() -> str: wf.step("after_fail", after_fail, after="fail_branch") wf.step("after_ok", after_ok, after="ok_branch") - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) final = run.wait(timeout=20) - finally: - _stop_worker(queue, worker) # fail_branch failed → after_fail skipped (condition=on_success, pred failed) # ok_branch succeeded → after_ok ran @@ -211,7 +190,9 @@ def after_ok() -> str: assert "after_ok" in order -def test_continue_mode_skips_downstream(queue: Queue) -> None: +def test_continue_mode_skips_downstream( + queue: Queue, workflow_worker: WorkflowWorkerFactory +) -> None: """In continue mode, failure skips on_success downstream in the chain.""" @queue.task(max_retries=0) @@ -228,12 +209,9 @@ def ok_task() -> str: wf.step("c", ok_task, after="b") wf.step("d", ok_task, after="c") - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) final = run.wait(timeout=20) - finally: - _stop_worker(queue, worker) assert final.state == WorkflowState.FAILED assert final.nodes["a"].status == NodeStatus.COMPLETED @@ -242,7 +220,7 @@ def ok_task() -> str: assert final.nodes["d"].status == NodeStatus.SKIPPED -def test_callable_condition_true(queue: Queue) -> None: +def test_callable_condition_true(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: """A callable condition that returns True lets the step run.""" @queue.task() @@ -260,18 +238,15 @@ def guarded() -> str: wf.step("a", ok_task) wf.step("b", guarded, after="a", condition=lambda ctx: True) - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) final = run.wait(timeout=20) - finally: - _stop_worker(queue, worker) assert final.nodes["b"].status == NodeStatus.COMPLETED assert collected == ["ran"] -def test_callable_condition_false(queue: Queue) -> None: +def test_callable_condition_false(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: """A callable condition that returns False skips the step.""" @queue.task() @@ -286,17 +261,14 @@ def guarded() -> str: wf.step("a", ok_task) wf.step("b", guarded, after="a", condition=lambda ctx: False) - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) final = run.wait(timeout=20) - finally: - _stop_worker(queue, worker) assert final.nodes["b"].status == NodeStatus.SKIPPED -def test_callable_accesses_results(queue: Queue) -> None: +def test_callable_accesses_results(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: """A callable condition can access predecessor results via ctx.results.""" @queue.task() @@ -322,18 +294,15 @@ def high_score(ctx: WorkflowContext) -> bool: wf.step("validate", score_task) wf.step("deploy", deploy, after="validate", condition=high_score) - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) final = run.wait(timeout=20) - finally: - _stop_worker(queue, worker) assert final.nodes["deploy"].status == NodeStatus.COMPLETED assert collected == ["deployed"] -def test_fail_fast_backward_compat(queue: Queue) -> None: +def test_fail_fast_backward_compat(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: """Phase 2 regression: fail_fast (default) cascades all pending nodes.""" @queue.task(max_retries=0) @@ -349,12 +318,9 @@ def ok_task() -> str: wf.step("b", ok_task, after="a") wf.step("c", ok_task, after="b") - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) final = run.wait(timeout=15) - finally: - _stop_worker(queue, worker) assert final.state == WorkflowState.FAILED assert final.nodes["a"].status == NodeStatus.FAILED @@ -362,7 +328,9 @@ def ok_task() -> str: assert final.nodes["c"].status == NodeStatus.SKIPPED -def test_skip_propagation_respects_always(queue: Queue) -> None: +def test_skip_propagation_respects_always( + queue: Queue, workflow_worker: WorkflowWorkerFactory +) -> None: """A→B→C: A fails, B(on_success) skipped, C(always) still runs.""" @queue.task(max_retries=0) @@ -385,12 +353,9 @@ def always_task() -> str: wf.step("b", ok_task, after="a") wf.step("c", always_task, after="b", condition="always") - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) final = run.wait(timeout=20) - finally: - _stop_worker(queue, worker) assert final.nodes["a"].status == NodeStatus.FAILED assert final.nodes["b"].status == NodeStatus.SKIPPED @@ -398,7 +363,9 @@ def always_task() -> str: assert collected == ["always ran"] -def test_fan_out_with_on_failure_downstream(queue: Queue) -> None: +def test_fan_out_with_on_failure_downstream( + queue: Queue, workflow_worker: WorkflowWorkerFactory +) -> None: """Fan-out child fails, downstream on_failure step runs.""" @queue.task() @@ -428,12 +395,9 @@ def on_error() -> str: wf.step("collect", aggregate, after="process", fan_in="all") wf.step("handle_error", on_error, after="process", condition="on_failure") - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) final = run.wait(timeout=20) - finally: - _stop_worker(queue, worker) assert final.nodes["process"].status == NodeStatus.FAILED assert final.nodes["collect"].status == NodeStatus.SKIPPED diff --git a/tests/python/test_workflows_fan_out.py b/tests/python/test_workflows_fan_out.py index 8455af7..c7f92e6 100644 --- a/tests/python/test_workflows_fan_out.py +++ b/tests/python/test_workflows_fan_out.py @@ -4,6 +4,8 @@ import threading import time +from collections.abc import Callable +from contextlib import AbstractContextManager from typing import Any import pytest @@ -11,19 +13,10 @@ from taskito import Queue from taskito.workflows import NodeStatus, Workflow, WorkflowState +WorkflowWorkerFactory = Callable[[], AbstractContextManager[threading.Thread]] -def _start_worker(queue: Queue) -> threading.Thread: - thread = threading.Thread(target=queue.run_worker, daemon=True) - thread.start() - return thread - -def _stop_worker(queue: Queue, thread: threading.Thread) -> None: - queue._inner.request_shutdown() - thread.join(timeout=5) - - -def test_fan_out_each(queue: Queue) -> None: +def test_fan_out_each(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: """fan_out='each' splits a list into N parallel jobs, fan_in='all' collects.""" @queue.task() @@ -46,18 +39,15 @@ def aggregate(results: list[int]) -> str: wf.step("process", double, after="fetch", fan_out="each") wf.step("collect", aggregate, after="process", fan_in="all") - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) final = run.wait(timeout=20) - finally: - _stop_worker(queue, worker) assert final.state == WorkflowState.COMPLETED assert sorted(collected) == [20, 40, 60] -def test_fan_out_empty_list(queue: Queue) -> None: +def test_fan_out_empty_list(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: """Fan-out over an empty list → fan-in receives [].""" @queue.task() @@ -80,18 +70,15 @@ def aggregate(results: list) -> str: wf.step("process", process, after="fetch", fan_out="each") wf.step("collect", aggregate, after="process", fan_in="all") - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) final = run.wait(timeout=20) - finally: - _stop_worker(queue, worker) assert final.state == WorkflowState.COMPLETED assert collected == [] -def test_fan_out_single_item(queue: Queue) -> None: +def test_fan_out_single_item(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: """Fan-out with a single-element list → 1 child → fan-in gets [result].""" @queue.task() @@ -114,18 +101,15 @@ def aggregate(results: list[int]) -> str: wf.step("process", add_one, after="fetch", fan_out="each") wf.step("collect", aggregate, after="process", fan_in="all") - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) final = run.wait(timeout=20) - finally: - _stop_worker(queue, worker) assert final.state == WorkflowState.COMPLETED assert collected == [43] -def test_fan_out_with_downstream(queue: Queue) -> None: +def test_fan_out_with_downstream(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: """Full pipeline: source → fan_out → fan_in → downstream static step.""" order: list[str] = [] @@ -156,12 +140,9 @@ def report() -> str: wf.step("agg", aggregate, after="process", fan_in="all") wf.step("report", report, after="agg") - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) final = run.wait(timeout=20) - finally: - _stop_worker(queue, worker) assert final.state == WorkflowState.COMPLETED assert "source" in order @@ -172,7 +153,7 @@ def report() -> str: assert order.index("aggregate") < order.index("report") -def test_fan_out_child_failure(queue: Queue) -> None: +def test_fan_out_child_failure(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: """A failing fan-out child triggers fail-fast, workflow fails.""" @queue.task() @@ -194,19 +175,16 @@ def aggregate(results: list[int]) -> str: wf.step("process", maybe_fail, after="fetch", fan_out="each") wf.step("collect", aggregate, after="process", fan_in="all") - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) final = run.wait(timeout=20) - finally: - _stop_worker(queue, worker) assert final.state == WorkflowState.FAILED # The aggregate node should be skipped assert final.nodes["collect"].status == NodeStatus.SKIPPED -def test_fan_out_source_failure(queue: Queue) -> None: +def test_fan_out_source_failure(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: """If the fan-out source step fails, all deferred nodes are SKIPPED.""" @queue.task(max_retries=0) @@ -226,12 +204,9 @@ def aggregate(results: list[int]) -> str: wf.step("process", process, after="fetch", fan_out="each") wf.step("collect", aggregate, after="process", fan_in="all") - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) final = run.wait(timeout=20) - finally: - _stop_worker(queue, worker) assert final.state == WorkflowState.FAILED assert final.nodes["fetch"].status == NodeStatus.FAILED @@ -239,7 +214,7 @@ def aggregate(results: list[int]) -> str: assert final.nodes["collect"].status == NodeStatus.SKIPPED -def test_fan_out_cancellation(queue: Queue) -> None: +def test_fan_out_cancellation(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: """Cancelling a workflow mid-fan-out skips pending children.""" @queue.task() @@ -260,20 +235,19 @@ def aggregate(results: list[int]) -> str: wf.step("process", slow_process, after="fetch", fan_out="each") wf.step("collect", aggregate, after="process", fan_in="all") - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) # Wait for source to complete and fan-out to expand time.sleep(2) run.cancel() snapshot = run.status() - finally: - _stop_worker(queue, worker) assert snapshot.state == WorkflowState.CANCELLED -def test_fan_out_status_shows_children(queue: Queue) -> None: +def test_fan_out_status_shows_children( + queue: Queue, workflow_worker: WorkflowWorkerFactory +) -> None: """status() returns child node snapshots like process[0], process[1].""" @queue.task() @@ -293,12 +267,9 @@ def aggregate(results: list[int]) -> str: wf.step("process", process, after="fetch", fan_out="each") wf.step("collect", aggregate, after="process", fan_in="all") - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) final = run.wait(timeout=20) - finally: - _stop_worker(queue, worker) assert final.state == WorkflowState.COMPLETED # Children should appear in the node map @@ -310,7 +281,9 @@ def aggregate(results: list[int]) -> str: assert final.nodes[f"process[{i}]"].job_id is not None -def test_fan_out_preserves_result_order(queue: Queue) -> None: +def test_fan_out_preserves_result_order( + queue: Queue, workflow_worker: WorkflowWorkerFactory +) -> None: """Fan-in results maintain the order of child indices.""" @queue.task() @@ -333,12 +306,9 @@ def aggregate(results: list[str]) -> str: wf.step("process", identity, after="fetch", fan_out="each") wf.step("collect", aggregate, after="process", fan_in="all") - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) final = run.wait(timeout=20) - finally: - _stop_worker(queue, worker) assert final.state == WorkflowState.COMPLETED # Results should be in child index order (same as input order) @@ -356,7 +326,7 @@ class _FakeTask: wf.step("bad[0]", _FakeTask()) -def test_linear_workflow_still_works(queue: Queue) -> None: +def test_linear_workflow_still_works(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: """Phase 2 regression: a linear workflow without fan-out still works.""" order: list[str] = [] @@ -381,12 +351,9 @@ def step_c() -> str: wf.step("b", step_b, after="a") wf.step("c", step_c, after="b") - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) final = run.wait(timeout=15) - finally: - _stop_worker(queue, worker) assert final.state == WorkflowState.COMPLETED assert order == ["a", "b", "c"] diff --git a/tests/python/test_workflows_gates.py b/tests/python/test_workflows_gates.py index c1074fa..bcbf41b 100644 --- a/tests/python/test_workflows_gates.py +++ b/tests/python/test_workflows_gates.py @@ -4,23 +4,16 @@ import threading import time +from collections.abc import Callable +from contextlib import AbstractContextManager from taskito import Queue from taskito.workflows import NodeStatus, Workflow, WorkflowState +WorkflowWorkerFactory = Callable[[], AbstractContextManager[threading.Thread]] -def _start_worker(queue: Queue) -> threading.Thread: - thread = threading.Thread(target=queue.run_worker, daemon=True) - thread.start() - return thread - -def _stop_worker(queue: Queue, thread: threading.Thread) -> None: - queue._inner.request_shutdown() - thread.join(timeout=5) - - -def test_gate_pauses_workflow(queue: Queue) -> None: +def test_gate_pauses_workflow(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: """A gate node enters WAITING_APPROVAL and blocks downstream.""" @queue.task() @@ -32,13 +25,10 @@ def ok_task() -> str: wf.gate("approve", after="a") wf.step("b", ok_task, after="approve") - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) time.sleep(3) # Let "a" complete snapshot = run.status() - finally: - _stop_worker(queue, worker) assert snapshot.state == WorkflowState.RUNNING assert snapshot.nodes["a"].status == NodeStatus.COMPLETED @@ -46,7 +36,7 @@ def ok_task() -> str: assert snapshot.nodes["b"].status == NodeStatus.PENDING -def test_approve_gate_resumes(queue: Queue) -> None: +def test_approve_gate_resumes(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: """Approving a gate lets downstream steps run to completion.""" collected: list[str] = [] @@ -61,14 +51,11 @@ def ok_task() -> str: wf.gate("approve", after="a") wf.step("b", ok_task, after="approve") - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) time.sleep(2) # Let "a" complete and gate enter WAITING_APPROVAL queue.approve_gate(run.id, "approve") final = run.wait(timeout=15) - finally: - _stop_worker(queue, worker) assert final.state == WorkflowState.COMPLETED assert final.nodes["approve"].status == NodeStatus.COMPLETED @@ -76,7 +63,7 @@ def ok_task() -> str: assert len(collected) == 2 # "a" and "b" -def test_reject_gate_fails(queue: Queue) -> None: +def test_reject_gate_fails(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: """Rejecting a gate fails it and skips downstream.""" @queue.task() @@ -88,21 +75,18 @@ def ok_task() -> str: wf.gate("approve", after="a") wf.step("b", ok_task, after="approve") - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) time.sleep(2) queue.reject_gate(run.id, "approve", error="not approved") final = run.wait(timeout=15) - finally: - _stop_worker(queue, worker) assert final.state == WorkflowState.FAILED assert final.nodes["approve"].status == NodeStatus.FAILED assert final.nodes["b"].status == NodeStatus.SKIPPED -def test_gate_timeout_reject(queue: Queue) -> None: +def test_gate_timeout_reject(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: """Gate with timeout and on_timeout='reject' auto-rejects.""" @queue.task() @@ -114,12 +98,9 @@ def ok_task() -> str: wf.gate("approve", after="a", timeout=1.0, on_timeout="reject") wf.step("b", ok_task, after="approve") - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) final = run.wait(timeout=15) - finally: - _stop_worker(queue, worker) assert final.state == WorkflowState.FAILED assert final.nodes["approve"].status == NodeStatus.FAILED @@ -127,7 +108,7 @@ def ok_task() -> str: assert "timeout" in (final.nodes["approve"].error or "").lower() -def test_gate_timeout_approve(queue: Queue) -> None: +def test_gate_timeout_approve(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: """Gate with on_timeout='approve' auto-approves and continues.""" collected: list[str] = [] @@ -142,19 +123,16 @@ def ok_task() -> str: wf.gate("approve", after="a", timeout=1.0, on_timeout="approve") wf.step("b", ok_task, after="approve") - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) final = run.wait(timeout=15) - finally: - _stop_worker(queue, worker) assert final.state == WorkflowState.COMPLETED assert final.nodes["approve"].status == NodeStatus.COMPLETED assert final.nodes["b"].status == NodeStatus.COMPLETED -def test_gate_with_condition(queue: Queue) -> None: +def test_gate_with_condition(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: """A gate with condition='on_success' respects predecessor state.""" @queue.task(max_retries=0) @@ -170,12 +148,9 @@ def ok_task() -> str: wf.gate("approve", after="a", condition="on_success") wf.step("b", ok_task, after="approve") - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) final = run.wait(timeout=15) - finally: - _stop_worker(queue, worker) assert final.state == WorkflowState.FAILED # Gate should be skipped because predecessor failed (condition=on_success) diff --git a/tests/python/test_workflows_linear.py b/tests/python/test_workflows_linear.py index dddca4d..27716a0 100644 --- a/tests/python/test_workflows_linear.py +++ b/tests/python/test_workflows_linear.py @@ -4,6 +4,8 @@ import threading import time +from collections.abc import Callable +from contextlib import AbstractContextManager import pytest @@ -11,19 +13,10 @@ from taskito.workflows import NodeStatus, Workflow, WorkflowState from taskito.workflows.run import WorkflowTimeoutError +WorkflowWorkerFactory = Callable[[], AbstractContextManager[threading.Thread]] -def _start_worker(queue: Queue) -> threading.Thread: - thread = threading.Thread(target=queue.run_worker, daemon=True) - thread.start() - return thread - -def _stop_worker(queue: Queue, thread: threading.Thread) -> None: - queue._inner.request_shutdown() - thread.join(timeout=5) - - -def test_linear_three_step_workflow(queue: Queue) -> None: +def test_linear_three_step_workflow(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: """A→B→C runs in order and the workflow reaches COMPLETED.""" order: list[str] = [] @@ -48,12 +41,9 @@ def step_c() -> str: wf.step("b", step_b, after="a") wf.step("c", step_c, after="b") - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) final = run.wait(timeout=15) - finally: - _stop_worker(queue, worker) assert final.state == WorkflowState.COMPLETED assert order == ["a", "b", "c"] @@ -61,7 +51,9 @@ def step_c() -> str: assert set(final.nodes.keys()) == {"a", "b", "c"} -def test_workflow_with_args_and_kwargs(queue: Queue) -> None: +def test_workflow_with_args_and_kwargs( + queue: Queue, workflow_worker: WorkflowWorkerFactory +) -> None: """Step args and kwargs round-trip through the queue serializer.""" received: list[tuple] = [] @@ -75,19 +67,18 @@ def collect(x: int, y: int, *, label: str) -> int: wf.step("first", collect, args=(2, 3), kwargs={"label": "a"}) wf.step("second", collect, args=(10, 20), kwargs={"label": "b"}, after="first") - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) final = run.wait(timeout=15) - finally: - _stop_worker(queue, worker) assert final.state == WorkflowState.COMPLETED assert (2, 3, "a") in received assert (10, 20, "b") in received -def test_workflow_decorator_registration(queue: Queue) -> None: +def test_workflow_decorator_registration( + queue: Queue, workflow_worker: WorkflowWorkerFactory +) -> None: """@queue.workflow() stores a proxy that can build and submit.""" @queue.task() @@ -105,12 +96,9 @@ def build() -> Workflow: assert built.name == "nightly" assert built.step_names == ["x"] - worker = _start_worker(queue) - try: + with workflow_worker(): run = build.submit() final = run.wait(timeout=10) - finally: - _stop_worker(queue, worker) assert final.state == WorkflowState.COMPLETED @@ -170,7 +158,7 @@ def noop() -> None: assert node.status == NodeStatus.SKIPPED -def test_workflow_failing_step(queue: Queue) -> None: +def test_workflow_failing_step(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: """A failing step fails the workflow and skips downstream steps.""" @queue.task(max_retries=0) @@ -186,12 +174,9 @@ def boom() -> str: wf.step("b", boom, after="a") wf.step("c", good, after="b") - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) final = run.wait(timeout=15) - finally: - _stop_worker(queue, worker) assert final.state == WorkflowState.FAILED assert final.nodes["a"].status == NodeStatus.COMPLETED @@ -281,7 +266,9 @@ def noop() -> None: # (verified indirectly: second submit succeeds without duplicate-key errors) -def test_workflow_emits_completed_event(queue: Queue) -> None: +def test_workflow_emits_completed_event( + queue: Queue, workflow_worker: WorkflowWorkerFactory +) -> None: """WORKFLOW_COMPLETED event fires on successful run completion.""" from taskito.events import EventType @@ -301,14 +288,11 @@ def listener(_event_type: EventType, payload: dict) -> None: wf = Workflow(name="event_pipe") wf.step("x", noop) - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) run.wait(timeout=10) # Give event bus thread a moment to dispatch event_received.wait(timeout=5) - finally: - _stop_worker(queue, worker) assert any(e.get("run_id") == run.id for e in events) assert any(e.get("state") == "completed" for e in events) @@ -331,7 +315,7 @@ def noop() -> None: @pytest.mark.asyncio -async def test_workflow_async_step(queue: Queue) -> None: +async def test_workflow_async_step(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: """An async @queue.task() step works inside a workflow.""" @queue.task() @@ -346,8 +330,7 @@ def sync_step() -> str: wf.step("a", async_step) wf.step("b", sync_step, after="a") - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) # Poll instead of blocking wait to play nicely with asyncio deadline = time.monotonic() + 15 @@ -355,8 +338,6 @@ def sync_step() -> str: while not final.state.is_terminal() and time.monotonic() < deadline: await _async_sleep(0.1) final = run.status() - finally: - _stop_worker(queue, worker) assert final.state == WorkflowState.COMPLETED diff --git a/tests/python/test_workflows_subworkflow.py b/tests/python/test_workflows_subworkflow.py index 4d6b2fc..7629a38 100644 --- a/tests/python/test_workflows_subworkflow.py +++ b/tests/python/test_workflows_subworkflow.py @@ -3,23 +3,16 @@ from __future__ import annotations import threading +from collections.abc import Callable +from contextlib import AbstractContextManager from taskito import Queue from taskito.workflows import NodeStatus, Workflow, WorkflowState +WorkflowWorkerFactory = Callable[[], AbstractContextManager[threading.Thread]] -def _start_worker(queue: Queue) -> threading.Thread: - thread = threading.Thread(target=queue.run_worker, daemon=True) - thread.start() - return thread - -def _stop_worker(queue: Queue, thread: threading.Thread) -> None: - queue._inner.request_shutdown() - thread.join(timeout=5) - - -def test_sub_workflow_executes(queue: Queue) -> None: +def test_sub_workflow_executes(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: """A sub-workflow step runs the child workflow, parent continues after.""" order: list[str] = [] @@ -50,12 +43,9 @@ def etl_pipeline() -> Workflow: wf.step("etl", etl_pipeline.as_step()) wf.step("report", report, after="etl") - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) final = run.wait(timeout=20) - finally: - _stop_worker(queue, worker) assert final.state == WorkflowState.COMPLETED assert "extract" in order @@ -64,7 +54,7 @@ def etl_pipeline() -> Workflow: assert order.index("load") < order.index("report") -def test_sub_workflow_failure(queue: Queue) -> None: +def test_sub_workflow_failure(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: """A failing sub-workflow fails the parent node.""" @queue.task(max_retries=0) @@ -85,19 +75,18 @@ def failing_sub() -> Workflow: wf.step("sub", failing_sub.as_step()) wf.step("after", ok_task, after="sub") - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) final = run.wait(timeout=20) - finally: - _stop_worker(queue, worker) assert final.state == WorkflowState.FAILED assert final.nodes["sub"].status == NodeStatus.FAILED assert final.nodes["after"].status == NodeStatus.SKIPPED -def test_sub_workflow_compile_failure_marks_parent_failed(queue: Queue) -> None: +def test_sub_workflow_compile_failure_marks_parent_failed( + queue: Queue, workflow_worker: WorkflowWorkerFactory +) -> None: """Regression: a factory that raises during `build()` must not leave the parent node Skipped forever — it must be marked Failed so the outer run can finalize. Before the fix, the tracker called `skip_workflow_node` @@ -116,12 +105,9 @@ def broken_sub() -> Workflow: wf.step("sub", broken_sub.as_step()) wf.step("after", downstream, after="sub") - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) final = run.wait(timeout=15) - finally: - _stop_worker(queue, worker) assert final.state == WorkflowState.FAILED, ( f"outer run must finalize as FAILED, got {final.state}" @@ -133,7 +119,7 @@ def broken_sub() -> Workflow: assert final.nodes["after"].status == NodeStatus.SKIPPED -def test_cancel_parent_cascades(queue: Queue) -> None: +def test_cancel_parent_cascades(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: """Cancelling a parent workflow cancels the child sub-workflow too.""" import time @@ -152,19 +138,16 @@ def slow_sub() -> Workflow: wf = Workflow(name="parent_cancel") wf.step("sub", slow_sub.as_step()) - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) time.sleep(2) # Let sub-workflow submit run.cancel() snapshot = run.status() - finally: - _stop_worker(queue, worker) assert snapshot.state == WorkflowState.CANCELLED -def test_parallel_sub_workflows(queue: Queue) -> None: +def test_parallel_sub_workflows(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: """Two sub-workflows can run concurrently.""" order: list[str] = [] @@ -201,12 +184,9 @@ def sub_b() -> Workflow: wf.step("sb", sub_b.as_step()) wf.step("reconcile", reconcile, after=["sa", "sb"]) - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) final = run.wait(timeout=20) - finally: - _stop_worker(queue, worker) assert final.state == WorkflowState.COMPLETED assert "a" in order diff --git a/tests/python/test_workflows_visualization.py b/tests/python/test_workflows_visualization.py index f9186c1..ebce1c3 100644 --- a/tests/python/test_workflows_visualization.py +++ b/tests/python/test_workflows_visualization.py @@ -3,26 +3,19 @@ from __future__ import annotations import threading +from collections.abc import Callable +from contextlib import AbstractContextManager from taskito import Queue from taskito.workflows import Workflow +WorkflowWorkerFactory = Callable[[], AbstractContextManager[threading.Thread]] + class _FakeTask: _task_name = "fake" -def _start_worker(queue: Queue) -> threading.Thread: - thread = threading.Thread(target=queue.run_worker, daemon=True) - thread.start() - return thread - - -def _stop_worker(queue: Queue, thread: threading.Thread) -> None: - queue._inner.request_shutdown() - thread.join(timeout=5) - - def test_mermaid_linear() -> None: """Linear DAG renders correct Mermaid graph.""" wf = Workflow(name="linear") @@ -80,7 +73,7 @@ def test_dot_linear() -> None: assert "a -> b" in output -def test_visualize_live_run(queue: Queue) -> None: +def test_visualize_live_run(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: """WorkflowRun.visualize() shows live statuses.""" @queue.task() @@ -91,13 +84,10 @@ def ok_task() -> str: wf.step("a", ok_task) wf.step("b", ok_task, after="a") - worker = _start_worker(queue) - try: + with workflow_worker(): run = queue.submit_workflow(wf) final = run.wait(timeout=15) output = run.visualize("mermaid") - finally: - _stop_worker(queue, worker) assert final.state.value == "completed" assert "graph LR" in output