Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion tests/python/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,20 @@
import os
import sys
import threading
from collections.abc import Generator
from collections.abc import Callable, Generator
from contextlib import AbstractContextManager, contextmanager
from pathlib import Path

import pytest

from taskito import Queue

# Public type alias used by workflow test files for the ``workflow_worker``
# fixture parameter: a zero-argument factory returning a context manager that
# yields the running worker ``threading.Thread``. mypy requires annotated test
# parameters under ``disallow_untyped_defs``; importing this alias from
# conftest keeps the type definition in one place.
WorkflowWorkerFactory = Callable[[], AbstractContextManager[threading.Thread]]


@pytest.fixture
def queue(tmp_path: Path) -> Queue:
Expand All @@ -28,6 +35,29 @@ def run_worker(queue: Queue) -> Generator[threading.Thread]:
thread.join(timeout=5)


@pytest.fixture
def workflow_worker(queue: Queue) -> WorkflowWorkerFactory:
    """Provide a context-manager factory for short-lived worker sessions.

    Workflow tests typically run several brief worker sessions per test
    (start, submit workflow, wait, stop — repeated).  Handing out a context
    manager from a single fixture replaces the per-file ``_start_worker`` /
    ``_stop_worker`` helpers without changing test semantics.

    Each entered context starts a daemon worker thread; on exit it asks the
    queue to shut down and joins the thread (5 s timeout) so a stuck worker
    cannot hang the test run indefinitely.
    """

    @contextmanager
    def _worker_session() -> Generator[threading.Thread]:
        worker = threading.Thread(target=queue.run_worker, daemon=True)
        worker.start()
        try:
            yield worker
        finally:
            # Always request shutdown, even if the test body raised.
            queue._inner.request_shutdown()
            worker.join(timeout=5)

    return _worker_session


_PYTEST_EXIT_STATUS: int = 0


Expand Down
68 changes: 17 additions & 51 deletions tests/python/test_workflows_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,16 @@
from __future__ import annotations

import threading
from collections.abc import Callable
from contextlib import AbstractContextManager

from taskito import Queue
from taskito.workflows import NodeStatus, Workflow, WorkflowState

WorkflowWorkerFactory = Callable[[], AbstractContextManager[threading.Thread]]

def _start_worker(queue: Queue) -> threading.Thread:
thread = threading.Thread(target=queue.run_worker, daemon=True)
thread.start()
return thread


def _stop_worker(queue: Queue, thread: threading.Thread) -> None:
queue._inner.request_shutdown()
thread.join(timeout=5)


def test_result_hash_stored(queue: Queue) -> None:
def test_result_hash_stored(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
"""Completed nodes have a non-None result_hash."""

@queue.task()
Expand All @@ -29,12 +22,9 @@ def ok_task() -> str:
wf = Workflow(name="hash_stored")
wf.step("a", ok_task)

worker = _start_worker(queue)
try:
with workflow_worker():
run = queue.submit_workflow(wf)
run.wait(timeout=15)
finally:
_stop_worker(queue, worker)

# Check the node data via base_run_node_data
nodes = queue._inner.get_base_run_node_data(run.id)
Expand All @@ -46,7 +36,7 @@ def ok_task() -> str:
# In practice it's usually populated.


def test_incremental_skips_completed(queue: Queue) -> None:
def test_incremental_skips_completed(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
"""Incremental run marks base-completed nodes as CACHE_HIT."""

executed: list[str] = []
Expand All @@ -66,23 +56,17 @@ def step_b() -> str:
wf.step("b", step_b, after="a")

# First run: everything executes.
worker = _start_worker(queue)
try:
with workflow_worker():
run1 = queue.submit_workflow(wf)
run1.wait(timeout=15)
finally:
_stop_worker(queue, worker)

assert run1.status().state == WorkflowState.COMPLETED
executed.clear()

# Second run: incremental.
worker = _start_worker(queue)
try:
with workflow_worker():
run2 = queue.submit_workflow(wf, incremental=True, base_run=run1.id)
run2.wait(timeout=15)
finally:
_stop_worker(queue, worker)

final = run2.status()
assert final.state == WorkflowState.COMPLETED
Expand All @@ -94,7 +78,7 @@ def step_b() -> str:
assert executed == [] # nothing re-executed


def test_incremental_reruns_failed(queue: Queue) -> None:
def test_incremental_reruns_failed(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
"""Failed nodes in the base run get re-executed."""

call_count = {"n": 0}
Expand All @@ -110,22 +94,16 @@ def flaky() -> str:
wf.step("a", flaky)

# First run: fails.
worker = _start_worker(queue)
try:
with workflow_worker():
run1 = queue.submit_workflow(wf)
run1.wait(timeout=15)
finally:
_stop_worker(queue, worker)

assert run1.status().state == WorkflowState.FAILED

# Second run incremental: failed node re-executes.
worker = _start_worker(queue)
try:
with workflow_worker():
run2 = queue.submit_workflow(wf, incremental=True, base_run=run1.id)
run2.wait(timeout=15)
finally:
_stop_worker(queue, worker)

assert run2.status().state == WorkflowState.COMPLETED
assert call_count["n"] == 2
Expand Down Expand Up @@ -157,7 +135,7 @@ def test_dirty_propagation(queue: Queue) -> None:
assert not cached


def test_cache_hit_is_terminal(queue: Queue) -> None:
def test_cache_hit_is_terminal(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
"""CACHE_HIT nodes are terminal and don't block the workflow."""

@queue.task()
Expand All @@ -169,26 +147,20 @@ def ok_task() -> str:
wf.step("b", ok_task, after="a")

# First run.
worker = _start_worker(queue)
try:
with workflow_worker():
run1 = queue.submit_workflow(wf)
run1.wait(timeout=15)
finally:
_stop_worker(queue, worker)

# Second incremental run.
worker = _start_worker(queue)
try:
with workflow_worker():
run2 = queue.submit_workflow(wf, incremental=True, base_run=run1.id)
final = run2.wait(timeout=15)
finally:
_stop_worker(queue, worker)

# Workflow should complete (CACHE_HIT is terminal).
assert final.state == WorkflowState.COMPLETED


def test_full_refresh_ignores_cache(queue: Queue) -> None:
def test_full_refresh_ignores_cache(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
"""incremental=False always re-runs everything."""

executed: list[str] = []
Expand All @@ -202,22 +174,16 @@ def step_a() -> str:
wf.step("a", step_a)

# First run.
worker = _start_worker(queue)
try:
with workflow_worker():
run1 = queue.submit_workflow(wf)
run1.wait(timeout=15)
finally:
_stop_worker(queue, worker)

executed.clear()

# Second run without incremental — should re-execute.
worker = _start_worker(queue)
try:
with workflow_worker():
run2 = queue.submit_workflow(wf)
run2.wait(timeout=15)
finally:
_stop_worker(queue, worker)

assert executed == ["a"]

Expand Down
Loading
Loading