From a221edf0d29829456d03f0bd8f2548309bdf8215 Mon Sep 17 00:00:00 2001
From: Pratyush Sharma <56130065+pratyush618@users.noreply.github.com>
Date: Fri, 8 May 2026 11:01:21 +0530
Subject: [PATCH] test: drop duplicate top-level test files
Squash-merge of #150 added the subdir copies without removing the top-level originals, leaving every test file present in two places. Removes the 50 stale top-level files and the duplicate prefork_apps/ directory.
---
tests/prefork_apps/__init__.py | 0
tests/prefork_apps/cancel_app.py | 33 --
tests/prefork_apps/timeout_app.py | 41 --
tests/test_basic.py | 140 ------
tests/test_batch.py | 146 ------
tests/test_cancel.py | 62 ---
tests/test_chain.py | 82 ----
tests/test_cli.py | 41 --
tests/test_context.py | 65 ---
tests/test_contrib.py | 325 -------------
tests/test_customizability.py | 297 ------------
tests/test_dashboard.py | 403 ---------------
tests/test_dashboard_settings.py | 199 --------
tests/test_dashboard_static.py | 142 ------
tests/test_dependencies.py | 117 -----
tests/test_dlq.py | 48 --
tests/test_events.py | 103 ----
tests/test_fastapi.py | 190 --------
tests/test_hooks.py | 83 ----
tests/test_idempotent.py | 123 -----
tests/test_interception.py | 446 -----------------
tests/test_keda.py | 153 ------
tests/test_namespace.py | 118 -----
tests/test_native_async.py | 494 -------------------
tests/test_observability.py | 309 ------------
tests/test_periodic.py | 59 ---
tests/test_prefork.py | 323 ------------
tests/test_priority.py | 50 --
tests/test_progress.py | 44 --
tests/test_proxies.py | 429 ----------------
tests/test_rate_limit.py | 37 --
tests/test_resource_system_full.py | 675 --------------------------
tests/test_resources.py | 367 --------------
tests/test_result_race.py | 62 ---
tests/test_retry.py | 87 ----
tests/test_retry_history.py | 58 ---
tests/test_run_maybe_async.py | 59 ---
tests/test_serializers.py | 129 -----
tests/test_shutdown.py | 62 ---
tests/test_streaming.py | 154 ------
tests/test_unique.py | 65 ---
tests/test_webhooks.py | 138 ------
tests/test_worker.py | 114 -----
tests/test_worker_resources.py | 181 -------
tests/test_workflows_analysis.py | 137 ------
tests/test_workflows_caching.py | 215 --------
tests/test_workflows_conditions.py | 405 ----------------
tests/test_workflows_cron.py | 30 --
tests/test_workflows_fan_out.py | 359 --------------
tests/test_workflows_gates.py | 158 ------
tests/test_workflows_linear.py | 348 -------------
tests/test_workflows_subworkflow.py | 194 --------
tests/test_workflows_visualization.py | 96 ----
53 files changed, 9195 deletions(-)
delete mode 100644 tests/prefork_apps/__init__.py
delete mode 100644 tests/prefork_apps/cancel_app.py
delete mode 100644 tests/prefork_apps/timeout_app.py
delete mode 100644 tests/test_basic.py
delete mode 100644 tests/test_batch.py
delete mode 100644 tests/test_cancel.py
delete mode 100644 tests/test_chain.py
delete mode 100644 tests/test_cli.py
delete mode 100644 tests/test_context.py
delete mode 100644 tests/test_contrib.py
delete mode 100644 tests/test_customizability.py
delete mode 100644 tests/test_dashboard.py
delete mode 100644 tests/test_dashboard_settings.py
delete mode 100644 tests/test_dashboard_static.py
delete mode 100644 tests/test_dependencies.py
delete mode 100644 tests/test_dlq.py
delete mode 100644 tests/test_events.py
delete mode 100644 tests/test_fastapi.py
delete mode 100644 tests/test_hooks.py
delete mode 100644 tests/test_idempotent.py
delete mode 100644 tests/test_interception.py
delete mode 100644 tests/test_keda.py
delete mode 100644 tests/test_namespace.py
delete mode 100644 tests/test_native_async.py
delete mode 100644 tests/test_observability.py
delete mode 100644 tests/test_periodic.py
delete mode 100644 tests/test_prefork.py
delete mode 100644 tests/test_priority.py
delete mode 100644 tests/test_progress.py
delete mode 100644 tests/test_proxies.py
delete mode 100644 tests/test_rate_limit.py
delete mode 100644 tests/test_resource_system_full.py
delete mode 100644 tests/test_resources.py
delete mode 100644 tests/test_result_race.py
delete mode 100644 tests/test_retry.py
delete mode 100644 tests/test_retry_history.py
delete mode 100644 tests/test_run_maybe_async.py
delete mode 100644 tests/test_serializers.py
delete mode 100644 tests/test_shutdown.py
delete mode 100644 tests/test_streaming.py
delete mode 100644 tests/test_unique.py
delete mode 100644 tests/test_webhooks.py
delete mode 100644 tests/test_worker.py
delete mode 100644 tests/test_worker_resources.py
delete mode 100644 tests/test_workflows_analysis.py
delete mode 100644 tests/test_workflows_caching.py
delete mode 100644 tests/test_workflows_conditions.py
delete mode 100644 tests/test_workflows_cron.py
delete mode 100644 tests/test_workflows_fan_out.py
delete mode 100644 tests/test_workflows_gates.py
delete mode 100644 tests/test_workflows_linear.py
delete mode 100644 tests/test_workflows_subworkflow.py
delete mode 100644 tests/test_workflows_visualization.py
diff --git a/tests/prefork_apps/__init__.py b/tests/prefork_apps/__init__.py
deleted file mode 100644
index e69de29..0000000
diff --git a/tests/prefork_apps/cancel_app.py b/tests/prefork_apps/cancel_app.py
deleted file mode 100644
index 20bf770..0000000
--- a/tests/prefork_apps/cancel_app.py
+++ /dev/null
@@ -1,33 +0,0 @@
-"""Module-level Queue + tasks used by the prefork cancel regression tests.
-
-The Queue inside this module must be importable both in the parent test
-process and inside each prefork child interpreter. The DB path comes from
-``TASKITO_CANCEL_TEST_DB`` so each test run can use its own tmp file while
-still letting the parent and child build identical Queue instances from
-the same module path.
-"""
-
-from __future__ import annotations
-
-import os
-import time
-
-from taskito import Queue
-from taskito.context import current_job
-
-queue = Queue(db_path=os.environ.get("TASKITO_CANCEL_TEST_DB", "/tmp/taskito-cancel.db"))
-
-
-@queue.task(timeout=30, max_retries=0)
-def cooperative_loop(max_iters: int = 600) -> int:
- """Loop calling ``check_cancelled()`` so cancel can stop the task quickly."""
- for _ in range(max_iters):
- current_job.check_cancelled()
- time.sleep(0.05)
- return max_iters
-
-
-@queue.task(max_retries=0)
-def quick(x: int) -> int:
- """Returns immediately — used to verify the child still serves jobs after a cancel."""
- return x * 2
diff --git a/tests/prefork_apps/timeout_app.py b/tests/prefork_apps/timeout_app.py
deleted file mode 100644
index b546e34..0000000
--- a/tests/prefork_apps/timeout_app.py
+++ /dev/null
@@ -1,41 +0,0 @@
-"""Module-level Queue + tasks used by the prefork timeout regression tests.
-
-Prefork children import the app module independently, so the task registry
-must live at a module path resolvable inside the child interpreter — that's
-why this module exists as a sibling of ``test_prefork.py`` rather than being
-defined inline in the test.
-
-The DB path comes from ``TASKITO_TIMEOUT_TEST_DB`` so each test run can use a
-unique tmp file while still letting the parent and child build identical
-Queue instances from this same module.
-"""
-
-from __future__ import annotations
-
-import os
-import time
-
-from taskito import Queue
-
-queue = Queue(db_path=os.environ.get("TASKITO_TIMEOUT_TEST_DB", "/tmp/taskito-timeout.db"))
-
-
-@queue.task(timeout=2, max_retries=0)
-def hang() -> None:
- """Spin forever — used to trigger the watchdog's SIGKILL path."""
- while True:
- pass
-
-
-@queue.task()
-def quick(x: int) -> int:
- """Returns immediately — used to verify timeout=0 (no timeout) is unaffected."""
- return x * 2
-
-
-@queue.task(timeout=2, max_retries=0)
-def sleep_then_finish(seconds: float) -> str:
- """Sleeps for `seconds`, then finishes — used to verify the watchdog only
- fires when the deadline is actually exceeded."""
- time.sleep(seconds)
- return "done"
diff --git a/tests/test_basic.py b/tests/test_basic.py
deleted file mode 100644
index 9a42053..0000000
--- a/tests/test_basic.py
+++ /dev/null
@@ -1,140 +0,0 @@
-"""Basic tests for taskito — enqueue, dequeue, result retrieval."""
-
-import threading
-
-from taskito import Queue
-
-
-def test_task_registration(queue: Queue) -> None:
- """Tasks can be registered with the decorator."""
-
- @queue.task()
- def add(a: int, b: int) -> int:
- return a + b
-
- assert add.name.endswith("add")
- assert add.name in queue._task_registry
-
-
-def test_enqueue_returns_job_result(queue: Queue) -> None:
- """Enqueueing a task returns a JobResult handle."""
-
- @queue.task()
- def noop() -> None:
- pass
-
- job = noop.delay()
- assert job.id is not None
- assert len(job.id) > 0
-
-
-def test_task_direct_call(queue: Queue) -> None:
- """Decorated tasks can still be called directly."""
-
- @queue.task()
- def multiply(a: int, b: int) -> int:
- return a * b
-
- assert multiply(3, 4) == 12
-
-
-def test_apply_async_with_delay(queue: Queue) -> None:
- """apply_async accepts a delay parameter."""
-
- @queue.task()
- def slow_task() -> None:
- pass
-
- job = slow_task.apply_async(delay=60)
- assert job.id is not None
-
-
-def test_apply_async_with_overrides(queue: Queue) -> None:
- """apply_async can override default task settings."""
-
- @queue.task(priority=1, queue="default")
- def configurable_task(x: int) -> int:
- return x
-
- job = configurable_task.apply_async(
- args=(42,),
- priority=10,
- queue="urgent",
- max_retries=5,
- timeout=600,
- )
- assert job.id is not None
-
-
-def test_queue_stats(queue: Queue) -> None:
- """stats() returns counts by status."""
-
- @queue.task()
- def sample_task() -> None:
- pass
-
- sample_task.delay()
- sample_task.delay()
-
- stats = queue.stats()
- assert stats["pending"] == 2
- assert stats["running"] == 0
-
-
-def test_worker_executes_task(queue: Queue) -> None:
- """Worker processes tasks and stores results."""
-
- @queue.task()
- def add(a: int, b: int) -> int:
- return a + b
-
- job = add.delay(2, 3)
-
- # Run worker in a background thread
- worker_thread = threading.Thread(
- target=queue.run_worker,
- daemon=True,
- )
- worker_thread.start()
-
- # Wait for result
- result = job.result(timeout=10)
- assert result == 5
-
-
-def test_worker_handles_kwargs(queue: Queue) -> None:
- """Worker correctly passes keyword arguments."""
-
- @queue.task()
- def greet(name: str, greeting: str = "Hello") -> str:
- return f"{greeting}, {name}!"
-
- job = greet.delay("World", greeting="Hi")
-
- worker_thread = threading.Thread(
- target=queue.run_worker,
- daemon=True,
- )
- worker_thread.start()
-
- result = job.result(timeout=10)
- assert result == "Hi, World!"
-
-
-def test_worker_none_result(queue: Queue) -> None:
- """Tasks returning None work correctly."""
-
- @queue.task()
- def void_task() -> None:
- pass
-
- job = void_task.delay()
-
- worker_thread = threading.Thread(
- target=queue.run_worker,
- daemon=True,
- )
- worker_thread.start()
-
- result = job.result(timeout=10)
- assert result is None
diff --git a/tests/test_batch.py b/tests/test_batch.py
deleted file mode 100644
index 70cec6c..0000000
--- a/tests/test_batch.py
+++ /dev/null
@@ -1,146 +0,0 @@
-"""Tests for batch enqueue (enqueue_many / task.map)."""
-
-import threading
-from typing import Any
-
-from taskito import Queue
-from taskito.middleware import TaskMiddleware
-
-
-def test_enqueue_many(queue: Queue) -> None:
- """enqueue_many enqueues all items in a single batch."""
-
- @queue.task()
- def double(x: int) -> int:
- return x * 2
-
- jobs = queue.enqueue_many(
- task_name=double.name,
- args_list=[(i,) for i in range(10)],
- )
- assert len(jobs) == 10
-
- stats = queue.stats()
- assert stats["pending"] == 10
-
-
-def test_task_map(queue: Queue) -> None:
- """TaskWrapper.map() enqueues and returns results."""
-
- @queue.task()
- def add(a: int, b: int) -> int:
- return a + b
-
- jobs = add.map([(1, 2), (3, 4), (5, 6)])
- assert len(jobs) == 3
-
- worker_thread = threading.Thread(target=queue.run_worker, daemon=True)
- worker_thread.start()
-
- results = [j.result(timeout=10) for j in jobs]
- assert sorted(results) == [3, 7, 11]
-
-
-def test_batch_stats(queue: Queue) -> None:
- """Batch enqueue of 50 items shows correct pending count."""
-
- @queue.task()
- def noop() -> None:
- pass
-
- queue.enqueue_many(
- task_name=noop.name,
- args_list=[() for _ in range(50)],
- )
-
- stats = queue.stats()
- assert stats["pending"] == 50
-
-
-def test_enqueue_many_invokes_on_enqueue_per_job(tmp_path: Any) -> None:
- """`on_enqueue` middleware must receive each job's own args/kwargs.
-
- Regression: the previous implementation always passed `args_list[0]`
- and a fresh empty options dict to every middleware call, so middleware
- could not distinguish jobs in the batch.
- """
-
- class RecordingMiddleware(TaskMiddleware):
- def __init__(self) -> None:
- self.calls: list[tuple[tuple, dict]] = []
-
- def on_enqueue(self, task_name: str, args: tuple, kwargs: dict, options: dict) -> None:
- self.calls.append((args, dict(kwargs)))
-
- mw = RecordingMiddleware()
- q = Queue(db_path=str(tmp_path / "test.db"), middleware=[mw])
-
- @q.task()
- def add(a: int, b: int) -> int:
- return a + b
-
- q.enqueue_many(
- task_name=add.name,
- args_list=[(1, 2), (3, 4), (5, 6)],
- kwargs_list=[{"trace": "alpha"}, {"trace": "beta"}, {"trace": "gamma"}],
- )
-
- assert mw.calls == [
- ((1, 2), {"trace": "alpha"}),
- ((3, 4), {"trace": "beta"}),
- ((5, 6), {"trace": "gamma"}),
- ]
-
-
-def test_enqueue_many_applies_option_mutations(tmp_path: Any) -> None:
- """Mutations to the options dict inside on_enqueue must propagate to the
- enqueued jobs — matching the documented behaviour of single-enqueue.
- Regression: the previous implementation discarded mutations because the
- hook ran *after* `enqueue_batch` and against a fresh empty dict.
- """
-
- class PerJobBoosterMiddleware(TaskMiddleware):
- def on_enqueue(self, task_name: str, args: tuple, kwargs: dict, options: dict) -> None:
- # Bump priority based on the first argument so each job sees a
- # distinct mutation.
- options["priority"] = int(args[0]) * 10
-
- q = Queue(db_path=str(tmp_path / "test.db"), middleware=[PerJobBoosterMiddleware()])
-
- @q.task()
- def task_one(n: int) -> int:
- return n
-
- results = q.enqueue_many(task_name=task_one.name, args_list=[(1,), (2,), (3,)])
-
- priorities = [q.get_job(r.id).to_dict()["priority"] for r in results] # type: ignore[union-attr]
- assert priorities == [10, 20, 30]
-
-
-def test_enqueue_many_logs_middleware_exceptions(tmp_path: Any, caplog: Any) -> None:
- """Middleware exceptions must be logged, not silently swallowed.
-
- Regression: the previous implementation used a bare `except: pass`,
- making misbehaving middleware effectively invisible.
- """
- import logging
-
- class ExplodingMiddleware(TaskMiddleware):
- def on_enqueue(self, task_name: str, args: tuple, kwargs: dict, options: dict) -> None:
- raise RuntimeError("middleware boom")
-
- q = Queue(db_path=str(tmp_path / "test.db"), middleware=[ExplodingMiddleware()])
-
- @q.task()
- def my_task(n: int) -> int:
- return n
-
- with caplog.at_level(logging.ERROR, logger="taskito.app"):
- results = q.enqueue_many(task_name=my_task.name, args_list=[(1,), (2,)])
-
- # Jobs are still enqueued — middleware errors must not block enqueue
- assert len(results) == 2
-
- # And the error was surfaced via the logger
- assert any("middleware on_enqueue() error" in rec.message for rec in caplog.records)
- assert any("middleware boom" in (rec.exc_text or "") for rec in caplog.records)
diff --git a/tests/test_cancel.py b/tests/test_cancel.py
deleted file mode 100644
index 42c09e3..0000000
--- a/tests/test_cancel.py
+++ /dev/null
@@ -1,62 +0,0 @@
-"""Tests for job cancellation."""
-
-from __future__ import annotations
-
-import threading
-
-from taskito import Queue
-
-
-def test_cancel_pending_job(queue: Queue) -> None:
- """A pending job can be cancelled."""
-
- @queue.task()
- def slow_task() -> str:
- return "done"
-
- job = slow_task.delay()
- assert queue.cancel_job(job.id) is True
-
- refreshed = queue.get_job(job.id)
- assert refreshed is not None
- assert refreshed.status == "cancelled"
-
-
-def test_cancel_nonexistent_job(queue: Queue) -> None:
- """Cancelling a nonexistent job returns False."""
-
- @queue.task()
- def dummy() -> None:
- pass
-
- assert queue.cancel_job("nonexistent-id") is False
-
-
-def test_cancel_completed_job(queue: Queue) -> None:
- """Cancelling a completed job returns False (only pending can be cancelled)."""
-
- @queue.task()
- def quick_task() -> int:
- return 42
-
- job = quick_task.delay()
-
- worker_thread = threading.Thread(target=queue.run_worker, daemon=True)
- worker_thread.start()
-
- job.result(timeout=10)
- assert queue.cancel_job(job.id) is False
-
-
-def test_cancelled_in_stats(queue: Queue) -> None:
- """Cancelled jobs show up in stats."""
-
- @queue.task()
- def task_a() -> None:
- pass
-
- job = task_a.delay()
- queue.cancel_job(job.id)
-
- stats = queue.stats()
- assert stats["cancelled"] == 1
diff --git a/tests/test_chain.py b/tests/test_chain.py
deleted file mode 100644
index a99a862..0000000
--- a/tests/test_chain.py
+++ /dev/null
@@ -1,82 +0,0 @@
-"""Tests for task chaining (chain, group, chord)."""
-
-from __future__ import annotations
-
-import threading
-from pathlib import Path
-
-import pytest
-
-from taskito import Queue, chain, chord, group
-
-
-@pytest.fixture
-def queue(tmp_path: Path) -> Queue:
- db_path = str(tmp_path / "test_chain.db")
- q = Queue(db_path=db_path, workers=4)
-
- # Start worker in background
- t = threading.Thread(target=q.run_worker, daemon=True)
- t.start()
-
- return q
-
-
-def test_chain_executes_in_order(queue: Queue) -> None:
- """chain pipes results through signatures in order."""
-
- @queue.task()
- def add(a: int, b: int) -> int:
- return a + b
-
- @queue.task()
- def double(x: int) -> int:
- return x * 2
-
- result = chain(add.s(2, 3), double.s())
- last_job = result.apply(queue)
- assert last_job.result(timeout=30) == 10 # (2+3) * 2 = 10
-
-
-def test_chain_with_immutable(queue: Queue) -> None:
- """si() signatures ignore previous results."""
-
- @queue.task()
- def add(a: int, b: int) -> int:
- return a + b
-
- @queue.task()
- def constant() -> int:
- return 99
-
- result = chain(add.s(1, 2), constant.si())
- last_job = result.apply(queue)
- assert last_job.result(timeout=30) == 99
-
-
-def test_group_parallel(queue: Queue) -> None:
- """group enqueues tasks in parallel."""
-
- @queue.task()
- def square(x: int) -> int:
- return x * x
-
- jobs = group(square.s(2), square.s(3), square.s(4)).apply(queue)
- results = [j.result(timeout=30) for j in jobs]
- assert sorted(results) == [4, 9, 16]
-
-
-def test_chord_callback(queue: Queue) -> None:
- """chord runs group, then callback with collected results."""
-
- @queue.task()
- def add(a: int, b: int) -> int:
- return a + b
-
- @queue.task()
- def total(results: list[int]) -> int:
- return sum(results)
-
- grp = group(add.s(1, 2), add.s(3, 4), add.s(5, 6))
- result_job = chord(grp, total.s()).apply(queue)
- assert result_job.result(timeout=30) == 21 # 3 + 7 + 11 = 21
diff --git a/tests/test_cli.py b/tests/test_cli.py
deleted file mode 100644
index 174e079..0000000
--- a/tests/test_cli.py
+++ /dev/null
@@ -1,41 +0,0 @@
-"""Tests for CLI info command."""
-
-from pathlib import Path
-
-import pytest
-
-from taskito.cli import _load_queue, _print_stats
-
-
-def test_load_queue_invalid_format() -> None:
- """_load_queue rejects paths without a colon."""
- with pytest.raises(SystemExit):
- _load_queue("no_colon_here")
-
-
-def test_load_queue_missing_module() -> None:
- """_load_queue exits on missing module."""
- with pytest.raises(SystemExit):
- _load_queue("nonexistent.module:queue")
-
-
-def test_print_stats_format(capsys: pytest.CaptureFixture[str], tmp_path: Path) -> None:
- """_print_stats prints a formatted stats table."""
- from taskito import Queue
-
- db_path = str(tmp_path / "test_cli_stats.db")
- queue = Queue(db_path=db_path)
-
- @queue.task()
- def noop() -> None:
- pass
-
- noop.delay()
- noop.delay()
-
- _print_stats(queue)
- output = capsys.readouterr().out
-
- assert "taskito queue statistics" in output
- assert "pending" in output
- assert "total" in output
diff --git a/tests/test_context.py b/tests/test_context.py
deleted file mode 100644
index b76f284..0000000
--- a/tests/test_context.py
+++ /dev/null
@@ -1,65 +0,0 @@
-"""Tests for job context — current_job inside running tasks."""
-
-import threading
-from pathlib import Path
-from typing import Any
-
-import pytest
-
-from taskito import Queue
-from taskito.context import current_job
-
-
-@pytest.fixture
-def queue(tmp_path: Path) -> Queue:
- db_path = str(tmp_path / "test_context.db")
- return Queue(db_path=db_path, workers=1)
-
-
-def test_current_job_raises_outside_task() -> None:
- """current_job properties raise RuntimeError outside a task."""
- with pytest.raises(RuntimeError, match="No active job context"):
- _ = current_job.id
-
-
-def test_current_job_id_available_in_task(queue: Queue) -> None:
- """current_job.id is accessible inside a running task."""
- captured: dict[str, Any] = {}
-
- @queue.task()
- def capture_context() -> str:
- captured["id"] = current_job.id
- captured["task_name"] = current_job.task_name
- captured["retry_count"] = current_job.retry_count
- captured["queue_name"] = current_job.queue_name
- return "ok"
-
- job = capture_context.delay()
-
- worker_thread = threading.Thread(target=queue.run_worker, daemon=True)
- worker_thread.start()
-
- result = job.result(timeout=10)
- assert result == "ok"
- assert captured["id"] == job.id
- assert captured["task_name"].endswith("capture_context")
- assert captured["retry_count"] == 0
- assert captured["queue_name"] == "default"
-
-
-def test_current_job_update_progress(queue: Queue) -> None:
- """current_job.update_progress() works inside a running task."""
-
- @queue.task()
- def task_with_progress() -> str:
- current_job.update_progress(50)
- current_job.update_progress(100)
- return "done"
-
- job = task_with_progress.delay()
-
- worker_thread = threading.Thread(target=queue.run_worker, daemon=True)
- worker_thread.start()
-
- result = job.result(timeout=10)
- assert result == "done"
diff --git a/tests/test_contrib.py b/tests/test_contrib.py
deleted file mode 100644
index 5df8773..0000000
--- a/tests/test_contrib.py
+++ /dev/null
@@ -1,325 +0,0 @@
-"""Tests for contrib middleware modules using mocks (no hard dependencies)."""
-
-from __future__ import annotations
-
-import types
-from typing import Any
-from unittest.mock import MagicMock, patch
-
-# ── Helpers ──────────────────────────────────────────────────────────
-
-
-def _make_ctx(
- job_id: str = "job-1",
- task_name: str = "my_task",
- queue_name: str = "default",
- retry_count: int = 0,
-) -> MagicMock:
- ctx = MagicMock()
- ctx.id = job_id
- ctx.task_name = task_name
- ctx.queue_name = queue_name
- ctx.retry_count = retry_count
- return ctx
-
-
-# ── OpenTelemetry ────────────────────────────────────────────────────
-
-
-class TestOpenTelemetryMiddleware:
- def test_before_starts_span(self) -> None:
- otel = _try_import_otel()
- if otel is None:
- return
-
- mock_tracer = MagicMock()
- mock_span = MagicMock()
- mock_tracer.start_span.return_value = mock_span
-
- mw = otel.OpenTelemetryMiddleware.__new__(otel.OpenTelemetryMiddleware)
- import threading
-
- mw._tracer = mock_tracer
- mw._span_name_fn = None
- mw._attr_prefix = "taskito"
- mw._extra_attributes_fn = None
- mw._task_filter = None
- mw._spans = {}
- mw._lock = threading.Lock()
-
- ctx = _make_ctx()
- mw.before(ctx)
-
- mock_tracer.start_span.assert_called_once()
- assert "job-1" in mw._spans
-
- def test_after_ends_span_success(self) -> None:
- otel = _try_import_otel()
- if otel is None:
- return
-
- mock_span = MagicMock()
- mw = otel.OpenTelemetryMiddleware.__new__(otel.OpenTelemetryMiddleware)
- import threading
-
- mw._tracer = MagicMock()
- mw._span_name_fn = None
- mw._attr_prefix = "taskito"
- mw._extra_attributes_fn = None
- mw._task_filter = None
- mw._spans = {"job-1": mock_span}
- mw._lock = threading.Lock()
-
- ctx = _make_ctx()
- mw.after(ctx, result="ok", error=None)
-
- mock_span.set_status.assert_called_once()
- mock_span.end.assert_called_once()
- mock_span.record_exception.assert_not_called()
-
- def test_after_records_exception_on_error(self) -> None:
- otel = _try_import_otel()
- if otel is None:
- return
-
- mock_span = MagicMock()
- mw = otel.OpenTelemetryMiddleware.__new__(otel.OpenTelemetryMiddleware)
- import threading
-
- mw._tracer = MagicMock()
- mw._span_name_fn = None
- mw._attr_prefix = "taskito"
- mw._extra_attributes_fn = None
- mw._task_filter = None
- mw._spans = {"job-1": mock_span}
- mw._lock = threading.Lock()
-
- ctx = _make_ctx()
- exc = ValueError("boom")
- mw.after(ctx, result=None, error=exc)
-
- mock_span.record_exception.assert_called_once_with(exc)
- mock_span.end.assert_called_once()
-
-
-def _try_import_otel() -> types.ModuleType | None:
- """Import otel module with mocked opentelemetry if not installed."""
- try:
- import sys
-
- # Provide mock opentelemetry modules if not installed
- mock_trace = MagicMock()
- mock_trace.StatusCode.OK = "OK"
- mock_trace.StatusCode.ERROR = "ERROR"
-
- with patch.dict(
- sys.modules,
- {
- "opentelemetry": MagicMock(),
- "opentelemetry.trace": mock_trace,
- },
- ):
- if "taskito.contrib.otel" in sys.modules:
- del sys.modules["taskito.contrib.otel"]
- from taskito.contrib import otel
-
- # Patch module-level references
- otel.trace = mock_trace
- otel.StatusCode = mock_trace.StatusCode
- return otel
- except Exception:
- return None
-
-
-# ── Sentry ───────────────────────────────────────────────────────────
-
-
-class TestSentryMiddleware:
- def test_before_pushes_scope(self) -> None:
- sentry_mod = _try_import_sentry()
- if sentry_mod is None:
- return
-
- mock_sdk = sentry_mod.sentry_sdk
- ctx = _make_ctx()
-
- mw = sentry_mod.SentryMiddleware.__new__(sentry_mod.SentryMiddleware)
- mw._tag_prefix = "taskito"
- mw._transaction_name_fn = None
- mw._task_filter = None
- mw._extra_tags_fn = None
- mw.before(ctx)
-
- mock_sdk.push_scope.assert_called_once()
-
- def test_after_pops_scope(self) -> None:
- sentry_mod = _try_import_sentry()
- if sentry_mod is None:
- return
-
- mock_sdk = sentry_mod.sentry_sdk
- ctx = _make_ctx()
-
- mw = sentry_mod.SentryMiddleware.__new__(sentry_mod.SentryMiddleware)
- mw._tag_prefix = "taskito"
- mw._transaction_name_fn = None
- mw._task_filter = None
- mw._extra_tags_fn = None
- mw.after(ctx, result="ok", error=None)
-
- mock_sdk.pop_scope_unsafe.assert_called_once()
-
- def test_after_captures_exception_on_error(self) -> None:
- sentry_mod = _try_import_sentry()
- if sentry_mod is None:
- return
-
- mock_sdk = sentry_mod.sentry_sdk
- ctx = _make_ctx()
- exc = RuntimeError("oops")
-
- mw = sentry_mod.SentryMiddleware.__new__(sentry_mod.SentryMiddleware)
- mw._tag_prefix = "taskito"
- mw._transaction_name_fn = None
- mw._task_filter = None
- mw._extra_tags_fn = None
- mw.after(ctx, result=None, error=exc)
-
- mock_sdk.capture_exception.assert_called_once_with(exc)
-
-
-def _try_import_sentry() -> types.ModuleType | None:
- try:
- import sys
-
- mock_sdk = MagicMock()
- with patch.dict(sys.modules, {"sentry_sdk": mock_sdk}):
- if "taskito.contrib.sentry" in sys.modules:
- del sys.modules["taskito.contrib.sentry"]
- from taskito.contrib import sentry
-
- sentry.sentry_sdk = mock_sdk
- return sentry
- except Exception:
- return None
-
-
-# ── Prometheus ───────────────────────────────────────────────────────
-
-
-def _make_mock_metrics() -> dict[str, Any]:
- """Create a mock metrics dict matching the instance-based store format."""
- return {
- "jobs_total": MagicMock(),
- "job_duration": MagicMock(),
- "active_workers": MagicMock(),
- "retries_total": MagicMock(),
- "queue_depth": MagicMock(),
- "dlq_size": MagicMock(),
- "worker_utilization": MagicMock(),
- "resource_health": MagicMock(),
- "resource_recreations": MagicMock(),
- "resource_init_duration": MagicMock(),
- "proxy_reconstruct_duration": MagicMock(),
- "proxy_reconstruct_total": MagicMock(),
- "proxy_reconstruct_errors": MagicMock(),
- "intercept_duration": MagicMock(),
- "intercept_strategy_total": MagicMock(),
- "pool_size": MagicMock(),
- "pool_active": MagicMock(),
- "pool_idle": MagicMock(),
- "pool_timeouts": MagicMock(),
- }
-
-
-class TestPrometheusMiddleware:
- def test_before_increments_active_workers(self) -> None:
- prom = _try_import_prometheus()
- if prom is None:
- return
-
- import threading
-
- metrics = _make_mock_metrics()
- mw = prom.PrometheusMiddleware.__new__(prom.PrometheusMiddleware)
- mw._metrics = metrics
- mw._extra_labels_fn = None
- mw._task_filter = None
- mw._start_times = {}
- mw._lock = threading.Lock()
-
- ctx = _make_ctx()
- mw.before(ctx)
-
- metrics["active_workers"].inc.assert_called()
-
- def test_after_tracks_counter_and_histogram(self) -> None:
- prom = _try_import_prometheus()
- if prom is None:
- return
-
- import threading
-
- metrics = _make_mock_metrics()
- mw = prom.PrometheusMiddleware.__new__(prom.PrometheusMiddleware)
- mw._metrics = metrics
- mw._extra_labels_fn = None
- mw._task_filter = None
- mw._start_times = {"job-1": 0.0}
- mw._lock = threading.Lock()
-
- ctx = _make_ctx()
- mw.after(ctx, result="ok", error=None)
-
- metrics["active_workers"].dec.assert_called()
- metrics["jobs_total"].labels.assert_called_with(task="my_task", status="completed")
- metrics["jobs_total"].labels().inc.assert_called()
-
- def test_after_tracks_failure(self) -> None:
- prom = _try_import_prometheus()
- if prom is None:
- return
-
- import threading
-
- metrics = _make_mock_metrics()
- mw = prom.PrometheusMiddleware.__new__(prom.PrometheusMiddleware)
- mw._metrics = metrics
- mw._extra_labels_fn = None
- mw._task_filter = None
- mw._start_times = {"job-1": 0.0}
- mw._lock = threading.Lock()
-
- ctx = _make_ctx()
- exc = ValueError("fail")
- mw.after(ctx, result=None, error=exc)
-
- metrics["jobs_total"].labels.assert_called_with(task="my_task", status="failed")
-
-
-def _try_import_prometheus() -> types.ModuleType | None:
- try:
- import sys
-
- mock_counter = MagicMock()
- mock_gauge = MagicMock()
- mock_histogram = MagicMock()
-
- with patch.dict(
- sys.modules,
- {
- "prometheus_client": MagicMock(
- Counter=mock_counter,
- Gauge=mock_gauge,
- Histogram=mock_histogram,
- start_http_server=MagicMock(),
- ),
- },
- ):
- if "taskito.contrib.prometheus" in sys.modules:
- del sys.modules["taskito.contrib.prometheus"]
- from taskito.contrib import prometheus
-
- return prometheus
- except Exception:
- return None
diff --git a/tests/test_customizability.py b/tests/test_customizability.py
deleted file mode 100644
index b33f3c0..0000000
--- a/tests/test_customizability.py
+++ /dev/null
@@ -1,297 +0,0 @@
-"""Tests for customizability configuration options."""
-
-from __future__ import annotations
-
-from typing import Any
-from unittest.mock import MagicMock
-
-from taskito.app import Queue
-from taskito.events import EventType
-from taskito.middleware import TaskMiddleware
-from taskito.webhooks import WebhookManager
-
-# ── Middleware Hooks ──────────────────────────────────────────────────
-
-
-class RecordingMiddleware(TaskMiddleware):
- """Middleware that records all hook calls for testing."""
-
- def __init__(self) -> None:
- self.calls: list[tuple[str, Any]] = []
-
- def before(self, ctx: Any) -> None:
- self.calls.append(("before", ctx.task_name))
-
- def after(self, ctx: Any, result: Any, error: Any) -> None:
- self.calls.append(("after", ctx.task_name))
-
- def on_enqueue(self, task_name: str, args: tuple, kwargs: dict, options: dict) -> None:
- self.calls.append(("on_enqueue", task_name))
-
- def on_dead_letter(self, ctx: Any, error: Exception) -> None:
- self.calls.append(("on_dead_letter", ctx.task_name))
-
- def on_timeout(self, ctx: Any) -> None:
- self.calls.append(("on_timeout", ctx.task_name))
-
- def on_cancel(self, ctx: Any) -> None:
- self.calls.append(("on_cancel", ctx.task_name))
-
-
-class TestMiddlewareHooks:
- def test_on_enqueue_called(self, tmp_path: Any) -> None:
- mw = RecordingMiddleware()
- q = Queue(db_path=str(tmp_path / "test.db"), middleware=[mw])
-
- @q.task()
- def my_task() -> None:
- pass
-
- my_task.delay()
- assert ("on_enqueue", my_task.name) in mw.calls
-
- def test_on_enqueue_can_mutate_options(self, tmp_path: Any) -> None:
- """on_enqueue can modify the options dict to change enqueue params."""
-
- class PriorityBoostMiddleware(TaskMiddleware):
- def on_enqueue(self, task_name: str, args: tuple, kwargs: dict, options: dict) -> None:
- options["priority"] = 99
-
- mw = PriorityBoostMiddleware()
- q = Queue(db_path=str(tmp_path / "test.db"), middleware=[mw])
-
- @q.task()
- def my_task() -> None:
- pass
-
- result = my_task.delay()
- job = q.get_job(result.id)
- assert job is not None
- assert job.to_dict()["priority"] == 99
-
- def test_default_hooks_are_noop(self) -> None:
- """Base TaskMiddleware hooks should not raise."""
- mw = TaskMiddleware()
- mw.on_enqueue("test", (), {}, {})
- mw.on_dead_letter(MagicMock(), Exception("test"))
- mw.on_timeout(MagicMock())
- mw.on_cancel(MagicMock())
-
-
-# ── Event System ─────────────────────────────────────────────────────
-
-
-class TestEventSystem:
- def test_new_event_types_exist(self) -> None:
- assert EventType.WORKER_STARTED.value == "worker.started"
- assert EventType.WORKER_STOPPED.value == "worker.stopped"
- assert EventType.QUEUE_PAUSED.value == "queue.paused"
- assert EventType.QUEUE_RESUMED.value == "queue.resumed"
-
- def test_event_workers_param(self, tmp_path: Any) -> None:
- q = Queue(db_path=str(tmp_path / "test.db"), event_workers=2)
- assert q._event_bus._executor._max_workers == 2
-
- def test_on_event_public_api(self, tmp_path: Any, poll_until: Any) -> None:
- q = Queue(db_path=str(tmp_path / "test.db"))
- received: list[Any] = []
-
- def callback(event_type: EventType, payload: dict) -> None:
- received.append((event_type, payload))
-
- q.on_event(EventType.JOB_ENQUEUED, callback)
-
- @q.task()
- def my_task() -> None:
- pass
-
- my_task.delay()
- poll_until(lambda: len(received) >= 1, message="JOB_ENQUEUED event not delivered")
- assert len(received) == 1
- assert received[0][0] == EventType.JOB_ENQUEUED
-
-
-# ── Webhook Configuration ────────────────────────────────────────────
-
-
-class TestWebhookConfig:
- def test_add_webhook_with_retry_params(self) -> None:
- mgr = WebhookManager()
- mgr.add_webhook(
- "https://example.com/hook",
- max_retries=5,
- timeout=30.0,
- retry_backoff=3.0,
- )
- wh = mgr._webhooks[0]
- assert wh["max_retries"] == 5
- assert wh["timeout"] == 30.0
- assert wh["retry_backoff"] == 3.0
-
- def test_add_webhook_defaults(self) -> None:
- mgr = WebhookManager()
- mgr.add_webhook("https://example.com/hook")
- wh = mgr._webhooks[0]
- assert wh["max_retries"] == 3
- assert wh["timeout"] == 10.0
- assert wh["retry_backoff"] == 2.0
-
- def test_queue_add_webhook_passes_params(self, tmp_path: Any) -> None:
- q = Queue(db_path=str(tmp_path / "test.db"))
- q.add_webhook(
- "https://example.com/hook",
- max_retries=1,
- timeout=5.0,
- retry_backoff=1.5,
- )
- wh = q._webhook_manager._webhooks[0]
- assert wh["max_retries"] == 1
- assert wh["timeout"] == 5.0
-
-
-# ── Queue Configuration ──────────────────────────────────────────────
-
-
-class TestQueueConfig:
- def test_scheduler_timing_params(self, tmp_path: Any) -> None:
- q = Queue(
- db_path=str(tmp_path / "test.db"),
- scheduler_poll_interval_ms=100,
- scheduler_reap_interval=50,
- scheduler_cleanup_interval=600,
- )
- # These are passed to the Rust side — just verify they don't error
- assert q._inner is not None
-
- def test_scheduler_timing_defaults(self, tmp_path: Any) -> None:
- q = Queue(db_path=str(tmp_path / "test.db"))
- # Defaults should work fine
- assert q._inner is not None
-
-
-# ── Per-Task Configuration ───────────────────────────────────────────
-
-
-class TestPerTaskConfig:
- def test_max_retry_delay_param(self, tmp_path: Any) -> None:
- q = Queue(db_path=str(tmp_path / "test.db"))
-
- @q.task(max_retry_delay=60)
- def my_task() -> None:
- pass
-
- config = q._task_configs[-1]
- assert config.max_retry_delay == 60
-
- def test_max_retry_delay_default(self, tmp_path: Any) -> None:
- q = Queue(db_path=str(tmp_path / "test.db"))
-
- @q.task()
- def my_task() -> None:
- pass
-
- config = q._task_configs[-1]
- assert config.max_retry_delay is None
-
- def test_max_concurrent_param(self, tmp_path: Any) -> None:
- q = Queue(db_path=str(tmp_path / "test.db"))
-
- @q.task(max_concurrent=5)
- def my_task() -> None:
- pass
-
- config = q._task_configs[-1]
- assert config.max_concurrent == 5
-
- def test_max_concurrent_default(self, tmp_path: Any) -> None:
- q = Queue(db_path=str(tmp_path / "test.db"))
-
- @q.task()
- def my_task() -> None:
- pass
-
- config = q._task_configs[-1]
- assert config.max_concurrent is None
-
-
-# ── Per-Task Serializer ──────────────────────────────────────────────
-
-
-class TestPerTaskSerializer:
- def test_task_level_serializer_used_for_enqueue(self, tmp_path: Any) -> None:
- """Per-task serializer is used instead of queue-level serializer."""
- mock_serializer = MagicMock()
- mock_serializer.dumps.return_value = b"\x80\x04\x95"
-
- q = Queue(db_path=str(tmp_path / "test.db"))
-
- @q.task(serializer=mock_serializer)
- def my_task(x: int) -> None:
- pass
-
- my_task.delay(42)
- mock_serializer.dumps.assert_called_once()
-
- def test_queue_serializer_used_when_no_task_serializer(self, tmp_path: Any) -> None:
- q = Queue(db_path=str(tmp_path / "test.db"))
-
- @q.task()
- def my_task() -> None:
- pass
-
- # Should use the default CloudpickleSerializer without error
- my_task.delay()
- assert my_task.name not in q._task_serializers
-
-
-# ── Flask CLI ─────────────────────────────────────────────────────────
-
-
-# ── Queue-Level Limits ────────────────────────────────────────────────
-
-
-class TestQueueLevelLimits:
- def test_set_queue_rate_limit(self, tmp_path: Any) -> None:
- q = Queue(db_path=str(tmp_path / "test.db"))
- q.set_queue_rate_limit("default", "100/m")
- assert q._queue_configs["default"]["rate_limit"] == "100/m"
-
- def test_set_queue_concurrency(self, tmp_path: Any) -> None:
- q = Queue(db_path=str(tmp_path / "test.db"))
- q.set_queue_concurrency("default", 10)
- assert q._queue_configs["default"]["max_concurrent"] == 10
-
- def test_set_both_on_same_queue(self, tmp_path: Any) -> None:
- q = Queue(db_path=str(tmp_path / "test.db"))
- q.set_queue_rate_limit("emails", "50/m")
- q.set_queue_concurrency("emails", 5)
- assert q._queue_configs["emails"]["rate_limit"] == "50/m"
- assert q._queue_configs["emails"]["max_concurrent"] == 5
-
- def test_queue_configs_serialized_to_json(self, tmp_path: Any) -> None:
- import json
-
- q = Queue(db_path=str(tmp_path / "test.db"))
- q.set_queue_rate_limit("default", "10/s")
- q.set_queue_concurrency("default", 3)
- serialized = json.dumps(q._queue_configs)
- parsed = json.loads(serialized)
- assert parsed["default"]["rate_limit"] == "10/s"
- assert parsed["default"]["max_concurrent"] == 3
-
-
-# ── Flask CLI ─────────────────────────────────────────────────────────
-
-
-class TestFlaskConfig:
- def test_cli_group_param(self) -> None:
- from taskito.contrib.flask import Taskito
-
- ext = Taskito(cli_group="jobs")
- assert ext._cli_group == "jobs"
-
- def test_cli_group_default(self) -> None:
- from taskito.contrib.flask import Taskito
-
- ext = Taskito()
- assert ext._cli_group == "taskito"
diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py
deleted file mode 100644
index dc281e1..0000000
--- a/tests/test_dashboard.py
+++ /dev/null
@@ -1,403 +0,0 @@
-"""Tests for dashboard API endpoints and list_jobs/to_dict."""
-
-import json
-import threading
-import urllib.error
-import urllib.request
-from collections.abc import Generator
-from pathlib import Path
-from typing import Any
-
-import pytest
-
-from taskito import Queue
-
-
-@pytest.fixture
-def queue(tmp_path: Path) -> Queue:
- """Create a fresh queue with some test data pre-registered."""
- db_path = str(tmp_path / "test_dashboard.db")
- q = Queue(db_path=db_path, workers=2)
-
- @q.task(queue="default")
- def task_a(x: int) -> int:
- return x * 2
-
- @q.task(queue="email")
- def task_b(x: int) -> int:
- return x + 1
-
- return q
-
-
-@pytest.fixture
-def populated_queue(queue: Queue) -> tuple[Queue, list[Any]]:
- """Queue with several jobs enqueued."""
- task_a_name: str = ""
- task_b_name: str = ""
- for name, _fn in queue._task_registry.items():
- if "task_a" in name:
- task_a_name = name
- elif "task_b" in name:
- task_b_name = name
-
- jobs: list[Any] = []
- for i in range(5):
- jobs.append(queue.enqueue(task_a_name, args=(i,)))
- for i in range(3):
- jobs.append(queue.enqueue(task_b_name, args=(i,), queue="email"))
- return queue, jobs
-
-
-# ── list_jobs tests ──────────────────────────────────────
-
-
-def test_list_jobs_returns_all(populated_queue: tuple[Queue, list[Any]]) -> None:
- """list_jobs() with no filters returns all jobs."""
- queue, _ = populated_queue
- result = queue.list_jobs()
- assert len(result) == 8
-
-
-def test_list_jobs_filter_by_queue(populated_queue: tuple[Queue, list[Any]]) -> None:
- """list_jobs() can filter by queue name."""
- queue, _ = populated_queue
- result = queue.list_jobs(queue="email")
- assert len(result) == 3
- for j in result:
- d = j.to_dict()
- assert d["queue"] == "email"
-
-
-def test_list_jobs_filter_by_status(populated_queue: tuple[Queue, list[Any]]) -> None:
- """list_jobs() can filter by status."""
- queue, _ = populated_queue
- result = queue.list_jobs(status="pending")
- assert len(result) == 8 # all are pending
-
- result = queue.list_jobs(status="running")
- assert len(result) == 0
-
-
-def test_list_jobs_filter_by_task_name(populated_queue: tuple[Queue, list[Any]]) -> None:
- """list_jobs() can filter by task name."""
- queue, _ = populated_queue
- # Find the task_a name
- task_a_name = None
- for name in queue._task_registry:
- if "task_a" in name:
- task_a_name = name
- break
-
- result = queue.list_jobs(task_name=task_a_name)
- assert len(result) == 5
-
-
-def test_list_jobs_pagination(populated_queue: tuple[Queue, list[Any]]) -> None:
- """list_jobs() respects limit and offset."""
- queue, _ = populated_queue
- page1 = queue.list_jobs(limit=3, offset=0)
- page2 = queue.list_jobs(limit=3, offset=3)
-
- assert len(page1) == 3
- assert len(page2) == 3
-
- ids1 = {j.id for j in page1}
- ids2 = {j.id for j in page2}
- assert ids1.isdisjoint(ids2)
-
-
-def test_list_jobs_invalid_status(queue: Queue) -> None:
- """list_jobs() raises on invalid status string."""
- with pytest.raises(ValueError):
- queue.list_jobs(status="bogus")
-
-
-# ── to_dict tests ────────────────────────────────────────
-
-
-def test_to_dict_fields(queue: Queue) -> None:
- """to_dict() returns all expected fields."""
-
- @queue.task()
- def dummy() -> None:
- pass
-
- job = dummy.delay()
- d = job.to_dict()
-
- expected_keys = {
- "id",
- "queue",
- "task_name",
- "status",
- "priority",
- "progress",
- "retry_count",
- "max_retries",
- "created_at",
- "scheduled_at",
- "started_at",
- "completed_at",
- "error",
- "timeout_ms",
- "unique_key",
- "metadata",
- "namespace",
- }
- assert set(d.keys()) == expected_keys
- assert d["status"] == "pending"
- assert d["id"] == job.id
-
-
-def test_to_dict_is_json_serializable(queue: Queue) -> None:
- """to_dict() output can be serialized to JSON."""
-
- @queue.task()
- def dummy() -> None:
- pass
-
- job = dummy.delay()
- d = job.to_dict()
- serialized = json.dumps(d)
- assert isinstance(serialized, str)
-
-
-# ── Dashboard HTTP tests ─────────────────────────────────
-
-
-def _start_dashboard(queue: Queue, *, static_assets: Any = None) -> tuple[str, Any]:
- """Boot a dashboard server bound to a random port.
-
- Returns the (url, server) pair so callers can shut it down explicitly.
- Optionally takes a ``StaticAssets`` instance — production code uses
- the package-bundled default; tests inject their own.
- """
- from http.server import ThreadingHTTPServer
-
- from taskito.dashboard import _make_handler
-
- handler = _make_handler(queue, static_assets=static_assets)
- server = ThreadingHTTPServer(("127.0.0.1", 0), handler)
- port = server.server_address[1]
- thread = threading.Thread(target=server.serve_forever, daemon=True)
- thread.start()
- return f"http://127.0.0.1:{port}", server
-
-
-@pytest.fixture
-def dashboard_server(
- populated_queue: tuple[Queue, list[Any]],
-) -> Generator[tuple[str, Queue, list[Any]]]:
- """Start a dashboard server on a random port."""
- queue, jobs = populated_queue
- url, server = _start_dashboard(queue)
- try:
- yield url, queue, jobs
- finally:
- server.shutdown()
-
-
-def _get(url: str) -> Any:
- """GET request and parse JSON."""
- with urllib.request.urlopen(url) as resp:
- return json.loads(resp.read())
-
-
-def _post(url: str) -> Any:
- """POST request and parse JSON."""
- req = urllib.request.Request(url, method="POST", data=b"")
- with urllib.request.urlopen(req) as resp:
- return json.loads(resp.read())
-
-
-def test_api_stats(dashboard_server: tuple[str, Queue, list[Any]]) -> None:
- """GET /api/stats returns valid stats dict."""
- base, _, __ = dashboard_server
- data = _get(f"{base}/api/stats")
- assert "pending" in data
- assert data["pending"] == 8
-
-
-def test_api_jobs_list(dashboard_server: tuple[str, Queue, list[Any]]) -> None:
- """GET /api/jobs returns job list."""
- base, _, __ = dashboard_server
- data = _get(f"{base}/api/jobs")
- assert isinstance(data, list)
- assert len(data) == 8
-
-
-def test_api_jobs_filter_status(dashboard_server: tuple[str, Queue, list[Any]]) -> None:
- """GET /api/jobs?status=pending filters correctly."""
- base, _, __ = dashboard_server
- data = _get(f"{base}/api/jobs?status=pending")
- assert len(data) == 8
-
- data = _get(f"{base}/api/jobs?status=running")
- assert len(data) == 0
-
-
-def test_api_jobs_filter_queue(dashboard_server: tuple[str, Queue, list[Any]]) -> None:
- """GET /api/jobs?queue=email filters correctly."""
- base, _, __ = dashboard_server
- data = _get(f"{base}/api/jobs?queue=email")
- assert len(data) == 3
-
-
-def test_api_jobs_pagination(dashboard_server: tuple[str, Queue, list[Any]]) -> None:
- """GET /api/jobs?limit=3&offset=0 paginates."""
- base, _, __ = dashboard_server
- data = _get(f"{base}/api/jobs?limit=3&offset=0")
- assert len(data) == 3
-
-
-def test_api_job_detail(dashboard_server: tuple[str, Queue, list[Any]]) -> None:
- """GET /api/jobs/{id} returns job dict."""
- base, _, jobs = dashboard_server
- job_id = jobs[0].id
- data = _get(f"{base}/api/jobs/{job_id}")
- assert data["id"] == job_id
- assert "status" in data
-
-
-def test_api_job_not_found(dashboard_server: tuple[str, Queue, list[Any]]) -> None:
- """GET /api/jobs/nonexistent returns 404."""
- base, _, __ = dashboard_server
- try:
- _get(f"{base}/api/jobs/nonexistent-id")
- raise AssertionError("Expected 404")
- except urllib.error.HTTPError as e:
- assert e.code == 404
-
-
-def test_api_cancel_job(dashboard_server: tuple[str, Queue, list[Any]]) -> None:
- """POST /api/jobs/{id}/cancel cancels a pending job."""
- base, _, jobs = dashboard_server
- job_id = jobs[0].id
- data = _post(f"{base}/api/jobs/{job_id}/cancel")
- assert data["cancelled"] is True
-
-
-def test_api_dead_letters_empty(dashboard_server: tuple[str, Queue, list[Any]]) -> None:
- """GET /api/dead-letters returns empty list initially."""
- base, _, __ = dashboard_server
- data = _get(f"{base}/api/dead-letters")
- assert data == []
-
-
-def test_spa_html_served(
- populated_queue: tuple[Queue, list[Any]],
- tmp_path: Path,
-) -> None:
- """GET / serves the SPA index.html when assets are bundled.
-
- Tests inject a ``StaticAssets`` rooted at a tmp directory so the test
- is self-contained — no dependency on a prior frontend build.
- """
- from taskito.dashboard import StaticAssets
-
- tmp_path.joinpath("index.html").write_text(
- '
',
- encoding="utf-8",
- )
- queue, _ = populated_queue
- url, server = _start_dashboard(queue, static_assets=StaticAssets(tmp_path))
- try:
- with urllib.request.urlopen(url) as resp:
- assert resp.status == 200
- assert resp.headers.get("Content-Type", "").startswith("text/html")
- html = resp.read().decode()
- assert "" in html.lower()
- assert 'id="app"' in html
- finally:
- server.shutdown()
-
-
-def test_spa_assets_resolved_under_root(
- populated_queue: tuple[Queue, list[Any]],
- tmp_path: Path,
-) -> None:
- """Hashed asset paths resolve under the bundle root and get long
- immutable cache headers."""
- from taskito.dashboard import StaticAssets
-
- tmp_path.joinpath("index.html").write_text("", encoding="utf-8")
- assets_dir = tmp_path / "assets"
- assets_dir.mkdir()
- assets_dir.joinpath("index-abc.js").write_text("export {};", encoding="utf-8")
-
- queue, _ = populated_queue
- url, server = _start_dashboard(queue, static_assets=StaticAssets(tmp_path))
- try:
- with urllib.request.urlopen(f"{url}/assets/index-abc.js") as resp:
- assert resp.status == 200
- assert resp.headers.get("Content-Type", "").startswith("application/javascript")
- assert "immutable" in resp.headers.get("Cache-Control", "")
- assert resp.read().decode() == "export {};"
- finally:
- server.shutdown()
-
-
-def test_spa_unknown_route_falls_back_to_index(
- populated_queue: tuple[Queue, list[Any]],
- tmp_path: Path,
-) -> None:
- """Client-side routes (anything that's not /assets/* or a real file)
- resolve to index.html so deep links keep working."""
- from taskito.dashboard import StaticAssets
-
- tmp_path.joinpath("index.html").write_text(
- '',
- encoding="utf-8",
- )
- queue, _ = populated_queue
- url, server = _start_dashboard(queue, static_assets=StaticAssets(tmp_path))
- try:
- with urllib.request.urlopen(f"{url}/jobs/some-id") as resp:
- assert resp.status == 200
- assert 'id="app"' in resp.read().decode()
- finally:
- server.shutdown()
-
-
-def test_spa_missing_asset_under_assets_returns_404(
- populated_queue: tuple[Queue, list[Any]],
- tmp_path: Path,
-) -> None:
- """A miss inside ``/assets/`` returns 404 — never the index fallback,
- so a stale page can't confuse the browser into running old chunks."""
- from taskito.dashboard import StaticAssets
-
- tmp_path.joinpath("index.html").write_text("", encoding="utf-8")
- queue, _ = populated_queue
- url, server = _start_dashboard(queue, static_assets=StaticAssets(tmp_path))
- try:
- try:
- urllib.request.urlopen(f"{url}/assets/missing.js")
- pytest.fail("Expected 404")
- except urllib.error.HTTPError as exc:
- assert exc.code == 404
- finally:
- server.shutdown()
-
-
-def test_spa_missing_assets_returns_503(
- populated_queue: tuple[Queue, list[Any]],
-) -> None:
- """When the frontend build wasn't run, the dashboard returns 503 with
- actionable rebuild instructions rather than silently 404-ing."""
- from taskito.dashboard import StaticAssets
-
- queue, _ = populated_queue
- url, server = _start_dashboard(queue, static_assets=StaticAssets(None))
- try:
- try:
- urllib.request.urlopen(url)
- pytest.fail("Expected 503")
- except urllib.error.HTTPError as exc:
- assert exc.code == 503
- body = exc.read().decode()
- assert "not bundled" in body.lower()
- assert "pnpm" in body.lower()
- finally:
- server.shutdown()
diff --git a/tests/test_dashboard_settings.py b/tests/test_dashboard_settings.py
deleted file mode 100644
index ee9888d..0000000
--- a/tests/test_dashboard_settings.py
+++ /dev/null
@@ -1,199 +0,0 @@
-"""Tests for the dashboard settings key/value store.
-
-Covers:
-- ``Queue.get_setting`` / ``set_setting`` / ``delete_setting`` / ``list_settings``
-- HTTP endpoints under ``/api/settings``
-"""
-
-from __future__ import annotations
-
-import json
-import threading
-import urllib.error
-import urllib.request
-from collections.abc import Generator
-from pathlib import Path
-from typing import Any
-
-import pytest
-
-from taskito import Queue
-
-
-@pytest.fixture
-def queue(tmp_path: Path) -> Queue:
- return Queue(db_path=str(tmp_path / "settings.db"))
-
-
-def _put(url: str, body: dict) -> Any:
- req = urllib.request.Request(
- url,
- method="PUT",
- data=json.dumps(body).encode(),
- headers={"Content-Type": "application/json"},
- )
- with urllib.request.urlopen(req) as resp:
- return json.loads(resp.read())
-
-
-def _delete(url: str) -> Any:
- req = urllib.request.Request(url, method="DELETE")
- with urllib.request.urlopen(req) as resp:
- return json.loads(resp.read())
-
-
-def _get(url: str) -> Any:
- with urllib.request.urlopen(url) as resp:
- return json.loads(resp.read())
-
-
-# ── Python API ──────────────────────────────────────────
-
-
-def test_get_setting_returns_none_when_unset(queue: Queue) -> None:
- assert queue.get_setting("missing") is None
-
-
-def test_set_and_get_setting(queue: Queue) -> None:
- queue.set_setting("dashboard.title", "My Queue")
- assert queue.get_setting("dashboard.title") == "My Queue"
-
-
-def test_set_setting_overwrites(queue: Queue) -> None:
- queue.set_setting("k", "v1")
- queue.set_setting("k", "v2")
- assert queue.get_setting("k") == "v2"
-
-
-def test_delete_setting(queue: Queue) -> None:
- queue.set_setting("k", "v")
- assert queue.delete_setting("k") is True
- assert queue.get_setting("k") is None
- # Delete on missing key is a no-op returning False.
- assert queue.delete_setting("k") is False
-
-
-def test_list_settings_returns_all(queue: Queue) -> None:
- queue.set_setting("a", "1")
- queue.set_setting("b", "2")
- snapshot = queue.list_settings()
- assert snapshot == {"a": "1", "b": "2"}
-
-
-def test_setting_preserves_unicode(queue: Queue) -> None:
- queue.set_setting("greeting", "안녕하세요 🌏")
- assert queue.get_setting("greeting") == "안녕하세요 🌏"
-
-
-def test_setting_preserves_json(queue: Queue) -> None:
- payload = json.dumps({"label": "Grafana", "url": "https://example/dash"})
- queue.set_setting("dashboard.links.0", payload)
- assert json.loads(queue.get_setting("dashboard.links.0") or "") == {
- "label": "Grafana",
- "url": "https://example/dash",
- }
-
-
-# ── HTTP endpoints ──────────────────────────────────────
-
-
-@pytest.fixture
-def dashboard_server(queue: Queue) -> Generator[tuple[str, Queue]]:
- from http.server import ThreadingHTTPServer
-
- from taskito.dashboard import _make_handler
-
- handler = _make_handler(queue)
- server = ThreadingHTTPServer(("127.0.0.1", 0), handler)
- port = server.server_address[1]
- thread = threading.Thread(target=server.serve_forever, daemon=True)
- thread.start()
- try:
- yield f"http://127.0.0.1:{port}", queue
- finally:
- server.shutdown()
-
-
-def test_get_settings_returns_empty_dict(dashboard_server: tuple[str, Queue]) -> None:
- base, _ = dashboard_server
- assert _get(f"{base}/api/settings") == {}
-
-
-def test_put_then_get_setting(dashboard_server: tuple[str, Queue]) -> None:
- base, _ = dashboard_server
- _put(f"{base}/api/settings/dashboard.title", {"value": "My Queue"})
-
- data = _get(f"{base}/api/settings/dashboard.title")
- assert data == {"key": "dashboard.title", "value": "My Queue"}
-
- snapshot = _get(f"{base}/api/settings")
- assert snapshot == {"dashboard.title": "My Queue"}
-
-
-def test_put_setting_with_json_value(dashboard_server: tuple[str, Queue]) -> None:
- """Non-string ``value`` is JSON-encoded before persistence."""
- base, queue = dashboard_server
- payload = [
- {"label": "Grafana", "url": "https://grafana.example/d/abc"},
- {"label": "Sentry", "url": "https://sentry.example/issues"},
- ]
- _put(f"{base}/api/settings/dashboard.external_links", {"value": payload})
-
- stored = queue.get_setting("dashboard.external_links")
- assert stored is not None
- assert json.loads(stored) == payload
-
-
-def test_get_unknown_setting_returns_404(dashboard_server: tuple[str, Queue]) -> None:
- base, _ = dashboard_server
- with pytest.raises(urllib.error.HTTPError) as exc_info:
- _get(f"{base}/api/settings/missing.key")
- assert exc_info.value.code == 404
-
-
-def test_put_setting_with_missing_value_field_returns_400(
- dashboard_server: tuple[str, Queue],
-) -> None:
- base, _ = dashboard_server
- with pytest.raises(urllib.error.HTTPError) as exc_info:
- _put(f"{base}/api/settings/k", {"not_value": 1})
- assert exc_info.value.code == 400
-
-
-def test_put_setting_rejects_invalid_json_body(dashboard_server: tuple[str, Queue]) -> None:
- base, _ = dashboard_server
- req = urllib.request.Request(
- f"{base}/api/settings/k",
- method="PUT",
- data=b"{not json",
- headers={"Content-Type": "application/json"},
- )
- with pytest.raises(urllib.error.HTTPError) as exc_info:
- urllib.request.urlopen(req)
- assert exc_info.value.code == 400
-
-
-def test_delete_setting_returns_true_when_exists(
- dashboard_server: tuple[str, Queue],
-) -> None:
- base, queue = dashboard_server
- queue.set_setting("k", "v")
- assert _delete(f"{base}/api/settings/k") == {"deleted": True}
- assert queue.get_setting("k") is None
-
-
-def test_delete_missing_setting_returns_false(
- dashboard_server: tuple[str, Queue],
-) -> None:
- base, _ = dashboard_server
- assert _delete(f"{base}/api/settings/missing") == {"deleted": False}
-
-
-def test_settings_persist_across_queue_instances(tmp_path: Path) -> None:
- """A fresh Queue instance pointed at the same DB sees prior writes."""
- db = str(tmp_path / "persist.db")
- q1 = Queue(db_path=db)
- q1.set_setting("k", "v")
-
- q2 = Queue(db_path=db)
- assert q2.get_setting("k") == "v"
diff --git a/tests/test_dashboard_static.py b/tests/test_dashboard_static.py
deleted file mode 100644
index 4fc61c2..0000000
--- a/tests/test_dashboard_static.py
+++ /dev/null
@@ -1,142 +0,0 @@
-"""Tests for dashboard static asset resolution and Content-Type mapping."""
-
-from __future__ import annotations
-
-from pathlib import Path
-
-import pytest
-
-from taskito.dashboard import _content_type_for, _resolve_static_node
-
-
-@pytest.fixture
-def static_root(tmp_path: Path) -> Path:
- """Layout mirroring a built Vite SPA tree."""
- (tmp_path / "index.html").write_text("")
- assets = tmp_path / "assets"
- assets.mkdir()
- (assets / "index-abc.js").write_text("// js")
- (assets / "index-abc.css").write_text("/* css */")
- (assets / "nested").mkdir()
- (assets / "nested" / "deep.png").write_bytes(b"\x89PNG")
- return tmp_path
-
-
-# ── _resolve_static_node ────────────────────────────────────────────
-
-
-def test_resolve_index_html(static_root: Path) -> None:
- node = _resolve_static_node(static_root, "/index.html")
- assert node is not None
- assert node.read_text() == ""
-
-
-def test_resolve_hashed_asset(static_root: Path) -> None:
- node = _resolve_static_node(static_root, "/assets/index-abc.js")
- assert node is not None
- assert node.read_text() == "// js"
-
-
-def test_resolve_nested_asset(static_root: Path) -> None:
- node = _resolve_static_node(static_root, "/assets/nested/deep.png")
- assert node is not None
- assert node.read_bytes() == b"\x89PNG"
-
-
-def test_resolve_missing_file_returns_none(static_root: Path) -> None:
- assert _resolve_static_node(static_root, "/assets/missing.js") is None
-
-
-def test_resolve_directory_returns_none(static_root: Path) -> None:
- # A directory matches joinpath but is_file() is False
- assert _resolve_static_node(static_root, "/assets") is None
-
-
-def test_resolve_empty_path_returns_none(static_root: Path) -> None:
- assert _resolve_static_node(static_root, "") is None
- assert _resolve_static_node(static_root, "/") is None
-
-
-def test_resolve_rejects_parent_traversal(static_root: Path) -> None:
- assert _resolve_static_node(static_root, "/../secret") is None
- assert _resolve_static_node(static_root, "/assets/../../secret") is None
-
-
-def test_resolve_rejects_current_directory(static_root: Path) -> None:
- assert _resolve_static_node(static_root, "/./index.html") is None
-
-
-def test_resolve_rejects_null_byte(static_root: Path) -> None:
- assert _resolve_static_node(static_root, "/index.html\x00.png") is None
-
-
-def test_resolve_rejects_backslash(static_root: Path) -> None:
- # Windows-style separators should be rejected to avoid ambiguity
- assert _resolve_static_node(static_root, "/assets\\index-abc.js") is None
-
-
-def test_resolve_rejects_double_slash(static_root: Path) -> None:
- # Empty segments from double slashes are rejected
- assert _resolve_static_node(static_root, "/assets//index-abc.js") is None
-
-
-# ── _content_type_for ──────────────────────────────────────────────
-
-
-@pytest.mark.parametrize(
- ("path", "expected"),
- [
- ("/index.html", "text/html; charset=utf-8"),
- ("/assets/index-abc.js", "application/javascript; charset=utf-8"),
- ("/assets/index-abc.mjs", "application/javascript; charset=utf-8"),
- ("/assets/index-abc.css", "text/css; charset=utf-8"),
- ("/icon.svg", "image/svg+xml"),
- ("/icon.png", "image/png"),
- ("/favicon.ico", "image/x-icon"),
- ("/fonts/inter.woff2", "font/woff2"),
- ("/fonts/inter.woff", "font/woff"),
- ("/app.webmanifest", "application/manifest+json"),
- ("/data.json", "application/json; charset=utf-8"),
- ("/unknown.bin", "application/octet-stream"),
- ("/no-extension", "application/octet-stream"),
- ],
-)
-def test_content_type_for(path: str, expected: str) -> None:
- assert _content_type_for(path) == expected
-
-
-def test_content_type_case_insensitive() -> None:
- # Uppercase extensions should still match
- assert _content_type_for("/IMAGE.PNG") == "image/png"
- assert _content_type_for("/script.JS") == "application/javascript; charset=utf-8"
-
-
-# ── _safe_path (log injection guard) ─────────────────────────────────
-
-
-def test_safe_path_strips_crlf() -> None:
- from taskito.dashboard.server import _safe_path
-
- assert _safe_path("/api/jobs\r\nFAKE LOG ENTRY") == "/api/jobsFAKE LOG ENTRY"
-
-
-def test_safe_path_strips_null_byte() -> None:
- from taskito.dashboard.server import _safe_path
-
- assert _safe_path("/api/jobs\x00admin") == "/api/jobsadmin"
-
-
-def test_safe_path_strips_all_control_chars_except_tab() -> None:
- from taskito.dashboard.server import _safe_path
-
- raw = "/api\x01\x02\x1f\x7fpath\twith-tab"
- assert _safe_path(raw) == "/apipath\twith-tab"
-
-
-def test_safe_path_truncates_long_input() -> None:
- from taskito.dashboard.server import _safe_path
-
- raw = "/api/" + "x" * 1000
- out = _safe_path(raw)
- assert len(out) == 256
- assert out.startswith("/api/x")
diff --git a/tests/test_dependencies.py b/tests/test_dependencies.py
deleted file mode 100644
index d5fa75d..0000000
--- a/tests/test_dependencies.py
+++ /dev/null
@@ -1,117 +0,0 @@
-"""Tests for task dependency feature."""
-
-import threading
-
-import pytest
-
-from taskito import Queue
-
-
-def test_enqueue_with_depends_on(queue: Queue) -> None:
- """Jobs can declare dependencies on other jobs."""
-
- @queue.task()
- def step(x: int) -> int:
- return x
-
- job_a = step.delay(1)
- job_b = step.apply_async(args=(2,), depends_on=job_a.id)
-
- assert job_b.dependencies == [job_a.id]
- assert job_a.dependents == [job_b.id]
-
-
-def test_enqueue_with_multiple_deps(queue: Queue) -> None:
- """Jobs can depend on multiple other jobs."""
-
- @queue.task()
- def step(x: int) -> int:
- return x
-
- job_a = step.delay(1)
- job_b = step.delay(2)
- job_c = step.apply_async(args=(3,), depends_on=[job_a.id, job_b.id])
-
- deps = set(job_c.dependencies)
- assert deps == {job_a.id, job_b.id}
-
-
-def test_depends_on_string_coercion(queue: Queue) -> None:
- """depends_on accepts a single string ID."""
-
- @queue.task()
- def step(x: int) -> int:
- return x
-
- job_a = step.delay(1)
- # Pass as single string, not list
- job_b = queue.enqueue(
- task_name=step.name,
- args=(2,),
- depends_on=job_a.id,
- )
-
- assert job_b.dependencies == [job_a.id]
-
-
-def test_dependency_blocks_execution(queue: Queue) -> None:
- """Dependent job waits until dependency completes."""
-
- @queue.task()
- def step(x: int) -> int:
- return x * 10
-
- job_a = step.delay(1)
- job_b = step.apply_async(args=(2,), depends_on=job_a.id)
-
- worker_thread = threading.Thread(target=queue.run_worker, daemon=True)
- worker_thread.start()
-
- # Both should complete — job_a first, then job_b
- result_a = job_a.result(timeout=10)
- result_b = job_b.result(timeout=10)
-
- assert result_a == 10
- assert result_b == 20
-
-
-def test_cascade_cancel_on_job_cancel(queue: Queue) -> None:
- """Cancelling a job cascades to its dependents."""
-
- @queue.task()
- def step(x: int) -> int:
- return x
-
- job_a = step.delay(1)
- job_b = step.apply_async(args=(2,), depends_on=job_a.id)
- job_c = step.apply_async(args=(3,), depends_on=job_b.id)
-
- queue.cancel_job(job_a.id)
-
- job_b.refresh()
- job_c.refresh()
- assert job_b.status == "cancelled"
- assert job_c.status == "cancelled"
-
-
-def test_no_dependencies_property_when_none(queue: Queue) -> None:
- """Jobs without dependencies return empty list."""
-
- @queue.task()
- def step(x: int) -> int:
- return x
-
- job = step.delay(1)
- assert job.dependencies == []
- assert job.dependents == []
-
-
-def test_enqueue_rejects_missing_dependency(queue: Queue) -> None:
- """Enqueuing with a nonexistent dependency raises an error."""
-
- @queue.task()
- def step(x: int) -> int:
- return x
-
- with pytest.raises(RuntimeError):
- step.apply_async(args=(1,), depends_on="nonexistent-id")
diff --git a/tests/test_dlq.py b/tests/test_dlq.py
deleted file mode 100644
index 14b484d..0000000
--- a/tests/test_dlq.py
+++ /dev/null
@@ -1,48 +0,0 @@
-"""Tests for dead letter queue management."""
-
-import threading
-from typing import Any
-
-from taskito import Queue
-
-PollUntil = Any # the conftest fixture's runtime type
-
-
-def test_dead_letters_empty(queue: Queue) -> None:
- """Empty DLQ returns empty list."""
- dead = queue.dead_letters()
- assert dead == []
-
-
-def test_purge_dead(queue: Queue, poll_until: PollUntil) -> None:
- """Purging removes old dead letter entries."""
-
- @queue.task(max_retries=0, retry_backoff=0.1)
- def instant_fail() -> None:
- raise RuntimeError("fail")
-
- instant_fail.delay()
-
- worker_thread = threading.Thread(
- target=queue.run_worker,
- daemon=True,
- )
- worker_thread.start()
-
- poll_until(
- lambda: len(queue.dead_letters()) >= 1,
- timeout=10,
- message="failed job did not reach DLQ",
- )
- dead = queue.dead_letters()
- assert len(dead) >= 1
-
- # Purge entries older than 0 seconds (purge everything)
- purged = queue.purge_dead(older_than=0)
- # Note: purge uses "older_than" seconds ago, so older_than=0 means
- # cutoff is now, which purges everything before now
- # The entry was just created so it might not be purged with 0
- # Use a large value instead
- queue.dead_letters()
- # Just verify the API works without error
- assert isinstance(purged, int)
diff --git a/tests/test_events.py b/tests/test_events.py
deleted file mode 100644
index 6d6e5c6..0000000
--- a/tests/test_events.py
+++ /dev/null
@@ -1,103 +0,0 @@
-"""Tests for EventBus event dispatch."""
-
-import time
-from typing import Any
-
-from taskito.events import EventBus, EventType
-
-PollUntil = Any # the conftest fixture's runtime type
-
-
-def test_callback_receives_event(poll_until: PollUntil) -> None:
- """Registered callbacks receive emitted events."""
- received: list[tuple[EventType, dict[str, Any]]] = []
- bus = EventBus()
- bus.on(EventType.JOB_COMPLETED, lambda et, p: received.append((et, p)))
-
- bus.emit(EventType.JOB_COMPLETED, {"job_id": "123"})
- poll_until(lambda: len(received) >= 1, message="event was not delivered")
-
- assert len(received) == 1
- assert received[0][0] == EventType.JOB_COMPLETED
- assert received[0][1]["job_id"] == "123"
-
-
-def test_multiple_callbacks(poll_until: PollUntil) -> None:
- """Multiple callbacks for the same event type all fire."""
- counts = {"a": 0, "b": 0}
- bus = EventBus()
- bus.on(EventType.JOB_FAILED, lambda et, p: counts.__setitem__("a", counts["a"] + 1))
- bus.on(EventType.JOB_FAILED, lambda et, p: counts.__setitem__("b", counts["b"] + 1))
-
- bus.emit(EventType.JOB_FAILED, {"error": "boom"})
- poll_until(
- lambda: counts["a"] == 1 and counts["b"] == 1,
- message="not all callbacks fired",
- )
-
- assert counts["a"] == 1
- assert counts["b"] == 1
-
-
-def test_event_filtering() -> None:
- """Callbacks only fire for their registered event type."""
- received: list[str] = []
- bus = EventBus()
- bus.on(EventType.JOB_COMPLETED, lambda et, p: received.append("completed"))
-
- bus.emit(EventType.JOB_FAILED, {"error": "boom"})
- # Brief settle so a (would-be incorrect) cross-type dispatch could land.
- time.sleep(0.2)
-
- assert received == []
-
-
-def test_exception_in_callback_does_not_crash(poll_until: PollUntil) -> None:
- """A raising callback doesn't prevent other events from processing."""
- results: list[str] = []
- bus = EventBus()
-
- def bad_callback(et: EventType, p: dict[str, Any]) -> None:
- raise RuntimeError("callback error")
-
- def good_callback(et: EventType, p: dict[str, Any]) -> None:
- results.append("ok")
-
- bus.on(EventType.JOB_ENQUEUED, bad_callback)
- bus.on(EventType.JOB_ENQUEUED, good_callback)
-
- bus.emit(EventType.JOB_ENQUEUED, {})
- poll_until(lambda: results == ["ok"], message="good_callback did not run")
-
- assert results == ["ok"]
-
-
-def test_emit_with_no_listeners() -> None:
- """Emitting an event with no listeners doesn't raise."""
- bus = EventBus()
- bus.emit(EventType.JOB_DEAD, {"job_id": "456"})
-
-
-def test_all_event_types_exist() -> None:
- """All expected event types are defined."""
- expected = {
- "job.enqueued",
- "job.completed",
- "job.failed",
- "job.retrying",
- "job.dead",
- "job.cancelled",
- "worker.started",
- "worker.stopped",
- "worker.online",
- "worker.offline",
- "worker.unhealthy",
- "queue.paused",
- "queue.resumed",
- "workflow.submitted",
- "workflow.completed",
- "workflow.failed",
- "workflow.cancelled",
- "workflow.gate_reached",
- }
- assert {e.value for e in EventType} == expected
diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py
deleted file mode 100644
index fdbcc1f..0000000
--- a/tests/test_fastapi.py
+++ /dev/null
@@ -1,190 +0,0 @@
-"""Tests for FastAPI integration."""
-
-import threading
-from typing import Any
-
-import pytest
-
-# Skip entire module if fastapi is not installed
-fastapi = pytest.importorskip("fastapi")
-httpx = pytest.importorskip("httpx")
-
-from fastapi import FastAPI # noqa: E402
-from fastapi.testclient import TestClient # noqa: E402
-
-from taskito import Queue # noqa: E402
-from taskito.contrib.fastapi import TaskitoRouter # noqa: E402
-
-
-@pytest.fixture
-def app(queue: Queue) -> FastAPI:
- """Create a FastAPI app with TaskitoRouter."""
- app = FastAPI()
- app.include_router(TaskitoRouter(queue), prefix="/tasks")
- return app
-
-
-@pytest.fixture
-def client(app: FastAPI) -> TestClient:
- """Create a TestClient."""
- return TestClient(app)
-
-
-@pytest.fixture
-def populated(queue: Queue, client: TestClient) -> tuple[Queue, TestClient, list[Any], Any]:
- """Queue with a task and some jobs."""
-
- @queue.task()
- def add(a: int, b: int) -> int:
- return a + b
-
- jobs = [add.delay(i, i + 1) for i in range(5)]
- return queue, client, jobs, add
-
-
-# ── Stats ────────────────────────────────────────────────
-
-
-def test_stats(populated: tuple[Queue, TestClient, list[Any], Any]) -> None:
- _queue, client, _jobs, _add = populated
- resp = client.get("/tasks/stats")
- assert resp.status_code == 200
- data = resp.json()
- assert data["pending"] == 5
- assert data["running"] == 0
-
-
-# ── Job detail ───────────────────────────────────────────
-
-
-def test_get_job(populated: tuple[Queue, TestClient, list[Any], Any]) -> None:
- _queue, client, jobs, _add = populated
- job_id = jobs[0].id
- resp = client.get(f"/tasks/jobs/{job_id}")
- assert resp.status_code == 200
- data = resp.json()
- assert data["id"] == job_id
- assert data["status"] == "pending"
-
-
-def test_get_job_not_found(client: TestClient) -> None:
- resp = client.get("/tasks/jobs/nonexistent")
- assert resp.status_code == 404
-
-
-# ── Job errors ───────────────────────────────────────────
-
-
-def test_get_job_errors_empty(populated: tuple[Queue, TestClient, list[Any], Any]) -> None:
- _queue, client, jobs, _add = populated
- job_id = jobs[0].id
- resp = client.get(f"/tasks/jobs/{job_id}/errors")
- assert resp.status_code == 200
- assert resp.json() == []
-
-
-# ── Job result ───────────────────────────────────────────
-
-
-def test_get_job_result_pending(populated: tuple[Queue, TestClient, list[Any], Any]) -> None:
- _queue, client, jobs, _add = populated
- job_id = jobs[0].id
- resp = client.get(f"/tasks/jobs/{job_id}/result")
- assert resp.status_code == 200
- data = resp.json()
- assert data["status"] == "pending"
- assert data["result"] is None
-
-
-def test_get_job_result_completed(populated: tuple[Queue, TestClient, list[Any], Any]) -> None:
- queue, client, jobs, _add = populated
-
- worker = threading.Thread(target=queue.run_worker, daemon=True)
- worker.start()
-
- # Wait for first job to complete
- jobs[0].result(timeout=10)
-
- resp = client.get(f"/tasks/jobs/{jobs[0].id}/result")
- assert resp.status_code == 200
- data = resp.json()
- assert data["status"] == "complete"
- assert data["result"] == 1 # add(0, 1) = 1
-
-
-# ── Cancel ───────────────────────────────────────────────
-
-
-def test_cancel_job(populated: tuple[Queue, TestClient, list[Any], Any]) -> None:
- _queue, client, jobs, _add = populated
- job_id = jobs[0].id
- resp = client.post(f"/tasks/jobs/{job_id}/cancel")
- assert resp.status_code == 200
- assert resp.json()["cancelled"] is True
-
- # Cancel again — should be false
- resp = client.post(f"/tasks/jobs/{job_id}/cancel")
- assert resp.json()["cancelled"] is False
-
-
-# ── Dead letters ─────────────────────────────────────────
-
-
-def test_dead_letters_empty(client: TestClient) -> None:
- resp = client.get("/tasks/dead-letters")
- assert resp.status_code == 200
- assert resp.json() == []
-
-
-# ── Progress SSE ─────────────────────────────────────────
-
-
-def test_progress_stream(populated: tuple[Queue, TestClient, list[Any], Any]) -> None:
- queue, client, jobs, _add = populated
-
- # Start worker so the job completes
- worker = threading.Thread(target=queue.run_worker, daemon=True)
- worker.start()
-
- job_id = jobs[0].id
- # Wait for job to finish first
- jobs[0].result(timeout=10)
-
- with client.stream("GET", f"/tasks/jobs/{job_id}/progress") as resp:
- assert resp.status_code == 200
- lines: list[str] = []
- for line in resp.iter_lines():
- if line.startswith("data:"):
- lines.append(line)
- break # Just check first event
-
- assert len(lines) >= 1
- # The first (and only) event should show complete status
- import json
-
- data = json.loads(lines[0].replace("data: ", ""))
- assert data["status"] == "complete"
-
-
-def test_progress_stream_not_found(client: TestClient) -> None:
- resp = client.get("/tasks/jobs/nonexistent/progress")
- assert resp.status_code == 404
-
-
-# ── Router config ────────────────────────────────────────
-
-
-def test_router_custom_tags(queue: Queue) -> None:
- """TaskitoRouter accepts standard APIRouter kwargs."""
- router = TaskitoRouter(queue, tags=["my-tasks"])
- assert "my-tasks" in router.tags
-
-
-def test_router_custom_prefix(queue: Queue) -> None:
- """Router can be mounted with a custom prefix."""
- app = FastAPI()
- app.include_router(TaskitoRouter(queue), prefix="/api/v1/queue")
- client = TestClient(app)
-
- resp = client.get("/api/v1/queue/stats")
- assert resp.status_code == 200
diff --git a/tests/test_hooks.py b/tests/test_hooks.py
deleted file mode 100644
index e96d4f5..0000000
--- a/tests/test_hooks.py
+++ /dev/null
@@ -1,83 +0,0 @@
-"""Tests for the hooks/middleware system."""
-
-from __future__ import annotations
-
-import threading
-from typing import Any
-
-from taskito import Queue
-
-
-def test_before_and_after_hooks(queue: Queue) -> None:
- """before_task and after_task hooks fire around task execution."""
- events: list[tuple[Any, ...]] = []
-
- @queue.before_task
- def on_before(task_name: str, args: tuple, kwargs: dict) -> None:
- events.append(("before", task_name))
-
- @queue.after_task
- def on_after(task_name: str, args: tuple, kwargs: dict, result: Any, error: Any) -> None:
- events.append(("after", task_name, result, error))
-
- @queue.task()
- def add(a: int, b: int) -> int:
- return a + b
-
- job = add.delay(1, 2)
-
- worker_thread = threading.Thread(target=queue.run_worker, daemon=True)
- worker_thread.start()
-
- result = job.result(timeout=10)
- assert result == 3
-
- # Verify hooks fired
- assert any(e[0] == "before" for e in events)
- assert any(e[0] == "after" and len(e) > 2 and e[2] == 3 and e[3] is None for e in events)
-
-
-def test_on_success_hook(queue: Queue) -> None:
- """on_success hook fires when task succeeds."""
- success_results: list[Any] = []
-
- @queue.on_success
- def on_success(task_name: str, args: tuple, kwargs: dict, result: Any) -> None:
- success_results.append(result)
-
- @queue.task()
- def multiply(a: int, b: int) -> int:
- return a * b
-
- job = multiply.delay(3, 4)
-
- worker_thread = threading.Thread(target=queue.run_worker, daemon=True)
- worker_thread.start()
-
- result = job.result(timeout=10)
- assert result == 12
- assert 12 in success_results
-
-
-def test_on_failure_hook(queue: Queue) -> None:
- """on_failure hook fires when task raises."""
- failure_errors: list[str] = []
-
- @queue.on_failure
- def on_failure(task_name: str, args: tuple, kwargs: dict, error: Exception) -> None:
- failure_errors.append(str(error))
-
- @queue.task(max_retries=1, retry_backoff=0.1)
- def always_fails() -> None:
- raise ValueError("boom")
-
- always_fails.delay()
-
- worker_thread = threading.Thread(target=queue.run_worker, daemon=True)
- worker_thread.start()
-
- import time
-
- time.sleep(3)
-
- assert any("boom" in e for e in failure_errors)
diff --git a/tests/test_idempotent.py b/tests/test_idempotent.py
deleted file mode 100644
index fefe3ef..0000000
--- a/tests/test_idempotent.py
+++ /dev/null
@@ -1,123 +0,0 @@
-"""Tests for auto-derived idempotency on @queue.task(idempotent=True)."""
-
-from __future__ import annotations
-
-import threading
-
-from taskito import Queue
-
-
-def test_idempotent_task_dedupes_by_args(queue: Queue) -> None:
- """Two enqueues with identical args under idempotent=True share a job."""
-
- @queue.task(idempotent=True)
- def charge(customer_id: int, amount: int) -> int:
- return amount
-
- job1 = charge.delay(42, 1000)
- job2 = charge.delay(42, 1000)
-
- assert job1.id == job2.id
-
-
-def test_idempotent_task_distinct_args_distinct_jobs(queue: Queue) -> None:
- """Different arguments produce different auto-keys → different jobs."""
-
- @queue.task(idempotent=True)
- def charge(customer_id: int, amount: int) -> int:
- return amount
-
- job_a = charge.delay(1, 100)
- job_b = charge.delay(2, 100)
- job_c = charge.delay(1, 200)
-
- assert len({job_a.id, job_b.id, job_c.id}) == 3
-
-
-def test_idempotent_per_call_key_overrides_auto(queue: Queue) -> None:
- """An explicit idempotency_key wins over the auto-derived key."""
-
- @queue.task(idempotent=True)
- def process(data: str) -> str:
- return data
-
- auto_job = process.delay("payload")
- explicit_job = process.apply_async(args=("payload",), idempotency_key="custom-key")
-
- # Auto-key and explicit-key collide on the same args only by coincidence —
- # since the explicit key is "custom-key", they should be different jobs.
- assert auto_job.id != explicit_job.id
-
-
-def test_idempotent_per_call_disable_creates_new_job(queue: Queue) -> None:
- """idempotent=False on apply_async overrides the per-task default."""
-
- @queue.task(idempotent=True)
- def process(data: str) -> str:
- return data
-
- auto_job = process.delay("same-args")
- forced_new = process.apply_async(args=("same-args",), idempotent=False)
-
- assert auto_job.id != forced_new.id
-
-
-def test_non_idempotent_task_allows_duplicates(queue: Queue) -> None:
- """Without idempotent=True, identical calls produce distinct jobs."""
-
- @queue.task()
- def process(data: str) -> str:
- return data
-
- job1 = process.delay("payload")
- job2 = process.delay("payload")
-
- assert job1.id != job2.id
-
-
-def test_idempotent_clears_after_completion(queue: Queue) -> None:
- """After an idempotent job finishes, the next call creates a new job."""
-
- @queue.task(idempotent=True)
- def fast() -> str:
- return "done"
-
- job1 = fast.delay()
-
- worker = threading.Thread(target=queue.run_worker, daemon=True)
- worker.start()
-
- job1.result(timeout=10)
-
- job2 = fast.delay()
- assert job2.id != job1.id
-
-
-def test_idempotent_via_enqueue_kwarg(queue: Queue) -> None:
- """Per-call idempotent=True works without a registered task default."""
-
- @queue.task()
- def process(data: str) -> str:
- return data
-
- job1 = process.apply_async(args=("x",), idempotent=True)
- job2 = process.apply_async(args=("x",), idempotent=True)
- job3 = process.apply_async(args=("y",), idempotent=True)
-
- assert job1.id == job2.id
- assert job1.id != job3.id
-
-
-def test_idempotent_unique_key_takes_precedence(queue: Queue) -> None:
- """An explicit unique_key beats both auto-derivation and idempotency_key."""
-
- @queue.task(idempotent=True)
- def process(data: str) -> str:
- return data
-
- explicit = process.apply_async(args=("payload",), unique_key="explicit-uk")
- auto = process.delay("payload")
-
- # The auto key is "auto:" and the explicit one is "explicit-uk",
- # so they must differ.
- assert explicit.id != auto.id
diff --git a/tests/test_interception.py b/tests/test_interception.py
deleted file mode 100644
index 4f7db15..0000000
--- a/tests/test_interception.py
+++ /dev/null
@@ -1,446 +0,0 @@
-"""Tests for the argument interception system (Layer 1)."""
-
-from __future__ import annotations
-
-import datetime
-import decimal
-import enum
-import pathlib
-import re
-import socket
-import threading
-import uuid
-from dataclasses import dataclass
-from typing import Any
-
-import pytest
-
-from taskito import Queue
-from taskito.interception import (
- ArgumentInterceptor,
- InterceptionError,
- InterceptionReport,
-)
-from taskito.interception.built_in import build_default_registry
-from taskito.interception.converters import reconstruct_converted
-from taskito.interception.reconstruct import reconstruct_args
-from taskito.interception.strategy import Strategy
-
-# -- Fixtures --
-
-
-@pytest.fixture
-def registry() -> Any:
- return build_default_registry()
-
-
-@pytest.fixture
-def strict(registry: Any) -> ArgumentInterceptor:
- return ArgumentInterceptor(registry, mode="strict")
-
-
-@pytest.fixture
-def lenient(registry: Any) -> ArgumentInterceptor:
- return ArgumentInterceptor(registry, mode="lenient")
-
-
-# -- PASS strategy --
-
-
-class TestPassStrategy:
- def test_int_passes_through(self, strict: ArgumentInterceptor) -> None:
- args, _kw = strict.intercept((42,), {})
- assert args == (42,)
-
- def test_str_passes_through(self, strict: ArgumentInterceptor) -> None:
- args, _kw = strict.intercept(("hello",), {})
- assert args == ("hello",)
-
- def test_float_passes_through(self, strict: ArgumentInterceptor) -> None:
- args, _kw = strict.intercept((3.14,), {})
- assert args == (3.14,)
-
- def test_bool_passes_through(self, strict: ArgumentInterceptor) -> None:
- args, _kw = strict.intercept((True, False), {})
- assert args == (True, False)
-
- def test_none_passes_through(self, strict: ArgumentInterceptor) -> None:
- args, _kw = strict.intercept((None,), {})
- assert args == (None,)
-
- def test_bytes_passes_through(self, strict: ArgumentInterceptor) -> None:
- args, _kw = strict.intercept((b"data",), {})
- assert args == (b"data",)
-
- def test_mixed_primitives(self, strict: ArgumentInterceptor) -> None:
- args, kwargs = strict.intercept(
- (1, "two", 3.0, True, None, b"six"),
- {"key": "val"},
- )
- assert args == (1, "two", 3.0, True, None, b"six")
- assert kwargs == {"key": "val"}
-
-
-# -- CONVERT strategy --
-
-
-class TestConvertStrategy:
- def test_uuid_round_trip(self, strict: ArgumentInterceptor) -> None:
- original = uuid.UUID("12345678-1234-5678-1234-567812345678")
- args, _kw = strict.intercept((original,), {})
- assert args[0]["__taskito_convert__"] is True
- assert args[0]["type_key"] == "uuid"
- # Reconstruct
- restored = reconstruct_converted(args[0])
- assert restored == original
-
- def test_datetime_round_trip(self, strict: ArgumentInterceptor) -> None:
- original = datetime.datetime(2025, 3, 10, 12, 0, 0)
- args, _ = strict.intercept((original,), {})
- assert args[0]["type_key"] == "datetime"
- restored = reconstruct_converted(args[0])
- assert restored == original
-
- def test_date_round_trip(self, strict: ArgumentInterceptor) -> None:
- original = datetime.date(2025, 3, 10)
- args, _ = strict.intercept((original,), {})
- assert args[0]["type_key"] == "date"
- restored = reconstruct_converted(args[0])
- assert restored == original
-
- def test_time_round_trip(self, strict: ArgumentInterceptor) -> None:
- original = datetime.time(14, 30, 0)
- args, _ = strict.intercept((original,), {})
- assert args[0]["type_key"] == "time"
- restored = reconstruct_converted(args[0])
- assert restored == original
-
- def test_timedelta_round_trip(self, strict: ArgumentInterceptor) -> None:
- original = datetime.timedelta(hours=1, minutes=30)
- args, _ = strict.intercept((original,), {})
- assert args[0]["type_key"] == "timedelta"
- restored = reconstruct_converted(args[0])
- assert restored == original
-
- def test_decimal_round_trip(self, strict: ArgumentInterceptor) -> None:
- original = decimal.Decimal("3.14159")
- args, _ = strict.intercept((original,), {})
- assert args[0]["type_key"] == "decimal"
- restored = reconstruct_converted(args[0])
- assert restored == original
-
- def test_path_round_trip(self, strict: ArgumentInterceptor) -> None:
- original = pathlib.Path("/tmp/data.csv")
- args, _ = strict.intercept((original,), {})
- assert args[0]["type_key"] == "path"
- restored = reconstruct_converted(args[0])
- assert restored == original
-
- def test_pattern_round_trip(self, strict: ArgumentInterceptor) -> None:
- original = re.compile(r"\d+", re.IGNORECASE)
- args, _ = strict.intercept((original,), {})
- assert args[0]["type_key"] == "pattern"
- restored = reconstruct_converted(args[0])
- assert restored.pattern == original.pattern
- assert restored.flags == original.flags
-
- def test_enum_converts(self, strict: ArgumentInterceptor) -> None:
- class Color(enum.Enum):
- RED = "red"
- GREEN = "green"
-
- args, _ = strict.intercept((Color.RED,), {})
- assert args[0]["type_key"] == "enum"
- assert args[0]["value"] == "red"
-
- def test_dataclass_converts(self, strict: ArgumentInterceptor) -> None:
- @dataclass
- class Point:
- x: int
- y: int
-
- original = Point(x=1, y=2)
- args, _ = strict.intercept((original,), {})
- assert args[0]["__taskito_convert__"] is True
- assert args[0]["type_key"] == "dataclass"
- assert args[0]["value"] == {"x": 1, "y": 2}
-
- def test_datetime_before_date(self, strict: ArgumentInterceptor) -> None:
- """datetime is a subclass of date — datetime must match first."""
- dt = datetime.datetime(2025, 1, 1, 12, 0, 0)
- args, _ = strict.intercept((dt,), {})
- assert args[0]["type_key"] == "datetime"
-
-
-# -- REDIRECT strategy --
-
-
-class TestRedirectStrategy:
- def test_redirect_produces_marker(self, strict: ArgumentInterceptor) -> None:
- """Test that redirect types produce markers (if sqlalchemy is installed)."""
- try:
- from sqlalchemy.orm import Session # type: ignore[import-not-found] # noqa: F401
-
- sqlalchemy_available = True
- except ImportError:
- sqlalchemy_available = False
-
- if not sqlalchemy_available:
- pytest.skip("sqlalchemy not installed")
-
-
-# -- REJECT strategy --
-
-
-class TestRejectStrategy:
- def test_threading_lock_rejected(self, strict: ArgumentInterceptor) -> None:
- lock = threading.Lock()
- with pytest.raises(InterceptionError) as exc_info:
- strict.intercept((lock,), {})
- assert len(exc_info.value.failures) == 1
- assert "args[0]" in exc_info.value.failures[0].path
- assert "lock" in exc_info.value.failures[0].type_name.lower()
-
- def test_socket_rejected(self, strict: ArgumentInterceptor) -> None:
- sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
- try:
- with pytest.raises(InterceptionError):
- strict.intercept((sock,), {})
- finally:
- sock.close()
-
- def test_generator_rejected(self, strict: ArgumentInterceptor) -> None:
- gen = (x for x in range(10))
- with pytest.raises(InterceptionError):
- strict.intercept((gen,), {})
-
- def test_reject_error_includes_path(self, strict: ArgumentInterceptor) -> None:
- lock = threading.Lock()
- with pytest.raises(InterceptionError) as exc_info:
- strict.intercept((), {"session": lock})
- assert "kwargs.session" in exc_info.value.failures[0].path
-
- def test_reject_error_has_suggestions(self, strict: ArgumentInterceptor) -> None:
- lock = threading.Lock()
- with pytest.raises(InterceptionError) as exc_info:
- strict.intercept((lock,), {})
- assert len(exc_info.value.failures[0].suggestions) > 0
-
- def test_multiple_rejections_collected(self, strict: ArgumentInterceptor) -> None:
- lock = threading.Lock()
- event = threading.Event()
- with pytest.raises(InterceptionError) as exc_info:
- strict.intercept((lock, event), {})
- assert len(exc_info.value.failures) == 2
-
-
-# -- Lenient mode --
-
-
-class TestLenientMode:
- def test_rejected_arg_dropped(self, lenient: ArgumentInterceptor) -> None:
- lock = threading.Lock()
- args, _kw = lenient.intercept((42, lock), {})
- assert args[0] == 42
- assert args[1] is None # dropped to None
-
- def test_rejected_kwarg_dropped(self, lenient: ArgumentInterceptor) -> None:
- lock = threading.Lock()
- _args, kwargs = lenient.intercept((), {"x": 1, "lock": lock})
- assert kwargs == {"x": 1}
-
-
-# -- Off mode --
-
-
-class TestOffMode:
- def test_passthrough_no_interception(self, registry: Any) -> None:
- interceptor = ArgumentInterceptor(registry, mode="off")
- lock = threading.Lock()
- args, _kw = interceptor.intercept((lock,), {})
- assert args[0] is lock
-
-
-# -- Recursive walking --
-
-
-class TestRecursiveWalking:
- def test_nested_uuid_in_dict(self, strict: ArgumentInterceptor) -> None:
- uid = uuid.uuid4()
- _, kwargs = strict.intercept((), {"config": {"user_id": uid}})
- assert kwargs["config"]["user_id"]["__taskito_convert__"] is True
-
- def test_uuid_in_list(self, strict: ArgumentInterceptor) -> None:
- uid = uuid.uuid4()
- args, _ = strict.intercept(([uid],), {})
- assert args[0][0]["__taskito_convert__"] is True
-
- def test_depth_limit(self, registry: Any) -> None:
- interceptor = ArgumentInterceptor(registry, mode="strict", max_depth=2)
- uid = uuid.uuid4()
- # Depth 3 — beyond limit, should pass through
- deep = {"a": {"b": {"c": uid}}}
- _, kwargs = interceptor.intercept((), {"x": deep})
- # uid at depth 3 should pass through as-is (beyond max_depth=2)
- assert kwargs["x"]["a"]["b"]["c"] is uid
-
- def test_circular_reference_handled(self, strict: ArgumentInterceptor) -> None:
- d: dict[str, Any] = {"value": 42}
- d["self"] = d # circular!
- args, _ = strict.intercept((d,), {})
- # Should not infinite loop — circular ref is detected and passed through
- assert args[0]["value"] == 42
-
-
-# -- Full round trip (intercept + reconstruct) --
-
-
-class TestRoundTrip:
- def test_uuid_full_round_trip(self, strict: ArgumentInterceptor) -> None:
- uid = uuid.UUID("aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee")
- intercepted_args, intercepted_kwargs = strict.intercept((uid,), {"id": uid})
- args, kwargs, _redirects = reconstruct_args(intercepted_args, intercepted_kwargs)
- assert args[0] == uid
- assert kwargs["id"] == uid
-
- def test_mixed_types_round_trip(self, strict: ArgumentInterceptor) -> None:
- dt = datetime.datetime(2025, 6, 15, 10, 30)
- path = pathlib.Path("/data/file.txt")
- intercepted_args, intercepted_kwargs = strict.intercept((42, "hello", dt), {"path": path})
- args, kwargs, _redirects = reconstruct_args(intercepted_args, intercepted_kwargs)
- assert args[0] == 42
- assert args[1] == "hello"
- assert args[2] == dt
- assert kwargs["path"] == path
-
- def test_nested_convert_round_trip(self, strict: ArgumentInterceptor) -> None:
- uid = uuid.uuid4()
- intercepted_args, _ = strict.intercept(({"ids": [uid]},), {})
- args, _, _ = reconstruct_args(intercepted_args, {})
- assert args[0]["ids"][0] == uid
-
-
-# -- Analyze / Report --
-
-
-class TestAnalyze:
- def test_analyze_returns_report(self, strict: ArgumentInterceptor) -> None:
- report = strict.analyze((42, "hello"), {"uid": uuid.uuid4()})
- assert isinstance(report, InterceptionReport)
- assert len(report.entries) == 3
-
- def test_analyze_shows_strategies(self, strict: ArgumentInterceptor) -> None:
- uid = uuid.uuid4()
- report = strict.analyze((42, uid), {})
- strategies = [e.strategy for e in report.entries]
- assert Strategy.PASS in strategies
- assert Strategy.CONVERT in strategies
-
- def test_analyze_on_off_mode_returns_empty(self, registry: Any) -> None:
- interceptor = ArgumentInterceptor(registry, mode="off")
- report = interceptor.analyze((42,), {})
- assert len(report.entries) == 0
-
- def test_report_str_format(self, strict: ArgumentInterceptor) -> None:
- report = strict.analyze((42,), {})
- text = str(report)
- assert "Argument Analysis:" in text
-
-
-# -- Queue integration --
-
-
-class TestQueueIntegration:
- def test_queue_default_interception_off(self, tmp_path: Any) -> None:
- q = Queue(db_path=str(tmp_path / "test.db"))
- assert q._interceptor is None
-
- def test_queue_strict_mode(self, tmp_path: Any) -> None:
- q = Queue(db_path=str(tmp_path / "test.db"), interception="strict")
- assert q._interceptor is not None
- assert q._interceptor.mode == "strict"
-
- def test_queue_enqueue_with_interception(self, tmp_path: Any) -> None:
- q = Queue(db_path=str(tmp_path / "test.db"), interception="strict")
-
- @q.task()
- def add(a: int, b: int) -> int:
- return a + b
-
- # Simple args should work fine
- result = add.delay(1, 2)
- assert result.id is not None
-
- def test_queue_enqueue_rejects_lock(self, tmp_path: Any) -> None:
- q = Queue(db_path=str(tmp_path / "test.db"), interception="strict")
-
- @q.task()
- def bad_task(lock: Any) -> None:
- pass
-
- with pytest.raises(InterceptionError):
- bad_task.delay(threading.Lock())
-
- def test_task_analyze(self, tmp_path: Any) -> None:
- q = Queue(db_path=str(tmp_path / "test.db"), interception="strict")
-
- @q.task()
- def my_task(user_id: int, created_at: datetime.datetime) -> None:
- pass
-
- report = my_task.analyze(42, datetime.datetime.now())
- assert len(report.entries) == 2
-
- def test_task_analyze_off_mode(self, tmp_path: Any) -> None:
- q = Queue(db_path=str(tmp_path / "test.db"))
-
- @q.task()
- def my_task(x: int) -> None:
- pass
-
- report = my_task.analyze(42)
- assert len(report.entries) == 0
-
-
-# -- Custom type registration --
-
-
-class TestCustomRegistration:
- def test_register_custom_reject(self, registry: Any, strict: ArgumentInterceptor) -> None:
- class MyLock:
- pass
-
- registry.register(
- MyLock,
- Strategy.REJECT,
- priority=40,
- reject_reason="Use distributed locking instead.",
- reject_suggestions=["Use queue.lock()"],
- )
- with pytest.raises(InterceptionError) as exc_info:
- strict.intercept((MyLock(),), {})
- assert "distributed locking" in str(exc_info.value)
-
- def test_register_custom_convert(self, registry: Any, strict: ArgumentInterceptor) -> None:
- class Money:
- def __init__(self, amount: int, currency: str) -> None:
- self.amount = amount
- self.currency = currency
-
- def convert_money(obj: Money) -> dict[str, Any]:
- return {
- "__taskito_convert__": True,
- "type_key": "money",
- "value": {"amount": obj.amount, "currency": obj.currency},
- }
-
- registry.register(
- Money,
- Strategy.CONVERT,
- priority=15,
- converter=convert_money,
- type_key="money",
- )
- args, _ = strict.intercept((Money(100, "USD"),), {})
- assert args[0]["__taskito_convert__"] is True
- assert args[0]["value"]["amount"] == 100
diff --git a/tests/test_keda.py b/tests/test_keda.py
deleted file mode 100644
index a298e34..0000000
--- a/tests/test_keda.py
+++ /dev/null
@@ -1,153 +0,0 @@
-"""Tests for KEDA-compatible /api/scaler endpoint contract."""
-
-import json
-import threading
-import urllib.request
-from collections.abc import Generator
-from http.server import ThreadingHTTPServer
-from pathlib import Path
-
-import pytest
-
-from taskito import Queue
-from taskito.dashboard import build_scaler_response
-
-
-@pytest.fixture()
-def empty_queue(tmp_path: Path) -> Queue:
- """Queue with no pending jobs."""
- return Queue(db_path=str(tmp_path / "keda.db"), workers=4)
-
-
-@pytest.fixture()
-def populated_queue(tmp_path: Path) -> Queue:
- """Queue with pending jobs across two queues."""
- q = Queue(db_path=str(tmp_path / "keda.db"), workers=4)
-
- @q.task(queue="emails")
- def send_email(to: str) -> None:
- pass
-
- @q.task(queue="reports")
- def generate_report(name: str) -> None:
- pass
-
- for i in range(5):
- send_email.delay(f"user{i}@example.com")
- for i in range(3):
- generate_report.delay(f"report_{i}")
-
- return q
-
-
-@pytest.fixture()
-def scaler_server(empty_queue: Queue) -> Generator[str]:
- """Start a scaler HTTP server on a random port, yield the base URL."""
- from taskito.scaler import _make_scaler_handler
-
- handler = _make_scaler_handler(empty_queue, target_queue_depth=10)
- server = ThreadingHTTPServer(("127.0.0.1", 0), handler)
- port = server.server_address[1]
- thread = threading.Thread(target=server.serve_forever, daemon=True)
- thread.start()
- yield f"http://127.0.0.1:{port}"
- server.shutdown()
-
-
-# ─── Unit tests: build_scaler_response() ───
-
-
-class TestScalerResponseShape:
- def test_required_fields(self, empty_queue: Queue) -> None:
- resp = build_scaler_response(empty_queue)
- assert "metricName" in resp
- assert "metricValue" in resp
- assert "isActive" in resp
- assert "liveWorkers" in resp
- assert "totalCapacity" in resp
- assert "targetQueueDepth" in resp
-
- def test_metric_value_is_pending_count(self, populated_queue: Queue) -> None:
- resp = build_scaler_response(populated_queue)
- assert resp["metricValue"] == 8 # 5 emails + 3 reports
-
- def test_is_active_false_when_empty(self, empty_queue: Queue) -> None:
- resp = build_scaler_response(empty_queue)
- assert resp["isActive"] is False
- assert resp["metricValue"] == 0
-
- def test_is_active_true_when_pending(self, populated_queue: Queue) -> None:
- resp = build_scaler_response(populated_queue)
- assert resp["isActive"] is True
-
- def test_per_queue_filter(self, populated_queue: Queue) -> None:
- resp = build_scaler_response(populated_queue, queue_name="emails")
- assert resp["metricValue"] == 5
- assert resp["isActive"] is True
-
- def test_metric_name_namespaced_for_queue(self, populated_queue: Queue) -> None:
- resp = build_scaler_response(populated_queue, queue_name="reports")
- assert resp["metricName"] == "taskito_queue_depth_reports"
-
- def test_default_metric_name(self, empty_queue: Queue) -> None:
- resp = build_scaler_response(empty_queue)
- assert resp["metricName"] == "taskito_queue_depth"
-
- def test_worker_utilization_present(self, empty_queue: Queue) -> None:
- resp = build_scaler_response(empty_queue)
- # workers=4, 0 running → utilization = 0.0
- assert "workerUtilization" in resp
- assert resp["workerUtilization"] == 0.0
-
- def test_worker_utilization_is_zero_when_idle(self, empty_queue: Queue) -> None:
- resp = build_scaler_response(empty_queue)
- assert resp["workerUtilization"] == 0.0
-
- def test_per_queue_stats_present(self, populated_queue: Queue) -> None:
- resp = build_scaler_response(populated_queue)
- assert "perQueue" in resp
- assert "emails" in resp["perQueue"]
- assert "reports" in resp["perQueue"]
- assert resp["perQueue"]["emails"]["pending"] == 5
- assert resp["perQueue"]["reports"]["pending"] == 3
-
- def test_target_queue_depth_default(self, empty_queue: Queue) -> None:
- resp = build_scaler_response(empty_queue)
- assert resp["targetQueueDepth"] == 10
-
- def test_target_queue_depth_custom(self, empty_queue: Queue) -> None:
- resp = build_scaler_response(empty_queue, target_queue_depth=25)
- assert resp["targetQueueDepth"] == 25
-
- def test_live_workers_zero_before_start(self, empty_queue: Queue) -> None:
- resp = build_scaler_response(empty_queue)
- assert resp["liveWorkers"] == 0
-
- def test_total_capacity_matches_config(self, empty_queue: Queue) -> None:
- resp = build_scaler_response(empty_queue)
- assert resp["totalCapacity"] == 4
-
-
-# ─── Integration tests: HTTP server ───
-
-
-class TestScalerHTTP:
- def test_scaler_endpoint_returns_json(self, scaler_server: str) -> None:
- resp = urllib.request.urlopen(f"{scaler_server}/api/scaler")
- assert resp.status == 200
- data = json.loads(resp.read())
- assert "metricName" in data
- assert "metricValue" in data
-
- def test_health_endpoint(self, scaler_server: str) -> None:
- resp = urllib.request.urlopen(f"{scaler_server}/health")
- assert resp.status == 200
- data = json.loads(resp.read())
- assert data["status"] == "ok"
-
- def test_unknown_path_returns_404(self, scaler_server: str) -> None:
- try:
- urllib.request.urlopen(f"{scaler_server}/unknown")
- pytest.fail("Expected HTTPError")
- except urllib.error.HTTPError as e:
- assert e.code == 404
diff --git a/tests/test_namespace.py b/tests/test_namespace.py
deleted file mode 100644
index 704b343..0000000
--- a/tests/test_namespace.py
+++ /dev/null
@@ -1,118 +0,0 @@
-"""Tests for namespace-based routing and isolation."""
-
-import threading
-from pathlib import Path
-
-from taskito import Queue
-
-
-def test_namespace_enqueue_sets_namespace(tmp_path: Path) -> None:
- """Jobs enqueued on a namespaced Queue carry the namespace."""
- queue = Queue(db_path=str(tmp_path / "test.db"), namespace="team-a")
-
- @queue.task()
- def add(x: int, y: int) -> int:
- return x + y
-
- job = add.delay(1, 2)
- py_job = queue._inner.get_job(job.id)
- assert py_job is not None
- assert py_job.namespace == "team-a"
-
-
-def test_no_namespace_jobs_have_none(tmp_path: Path) -> None:
- """Jobs enqueued without a namespace have namespace=None."""
- queue = Queue(db_path=str(tmp_path / "test.db"))
-
- @queue.task()
- def noop() -> None:
- pass
-
- job = noop.delay()
- py_job = queue._inner.get_job(job.id)
- assert py_job is not None
- assert py_job.namespace is None
-
-
-def test_namespace_isolation_worker(tmp_path: Path) -> None:
- """A namespaced worker only processes jobs from its namespace."""
- db = str(tmp_path / "test.db")
-
- # Create two queues sharing the same DB but different namespaces
- q_a = Queue(db_path=db, namespace="team-a")
- q_b = Queue(db_path=db, namespace="team-b")
-
- results: list[str] = []
-
- @q_a.task()
- def task_a() -> str:
- results.append("a")
- return "a"
-
- @q_b.task()
- def task_b() -> str:
- results.append("b")
- return "b"
-
- # Enqueue one job on each namespace
- job_a = task_a.delay()
- job_b = task_b.delay()
-
- # Run worker for team-a only
- worker = threading.Thread(target=q_a.run_worker, daemon=True)
- worker.start()
-
- # Wait for team-a's job to complete
- job_a.result(timeout=10)
-
- # team-b's job should still be pending
- job_b.refresh()
- assert job_b.status == "pending"
-
- # Shut down team-a worker (daemon thread exits on its own)
- q_a._inner.request_shutdown()
-
-
-def test_namespace_list_jobs_scoped(tmp_path: Path) -> None:
- """list_jobs defaults to the queue's namespace."""
- db = str(tmp_path / "test.db")
- q_a = Queue(db_path=db, namespace="ns-a")
- q_b = Queue(db_path=db, namespace="ns-b")
-
- @q_a.task()
- def task_x() -> None:
- pass
-
- @q_b.task()
- def task_y() -> None:
- pass
-
- task_x.delay()
- task_x.delay()
- task_y.delay()
-
- # Each queue sees only its own jobs by default
- assert len(q_a.list_jobs()) == 2
- assert len(q_b.list_jobs()) == 1
-
- # Passing namespace=None shows all
- assert len(q_a.list_jobs(namespace=None)) == 3
-
-
-def test_namespace_preserved_in_job_result(tmp_path: Path) -> None:
- """JobResult.to_dict() includes the namespace."""
- queue = Queue(db_path=str(tmp_path / "test.db"), namespace="my-ns")
-
- @queue.task()
- def greet(name: str) -> str:
- return f"hi {name}"
-
- job = greet.delay("world")
-
- worker = threading.Thread(target=queue.run_worker, daemon=True)
- worker.start()
-
- result = job.result(timeout=10)
- assert result == "hi world"
-
- queue._inner.request_shutdown()
diff --git a/tests/test_native_async.py b/tests/test_native_async.py
deleted file mode 100644
index f6fa7d1..0000000
--- a/tests/test_native_async.py
+++ /dev/null
@@ -1,494 +0,0 @@
-"""Tests for native async task support."""
-
-from __future__ import annotations
-
-import asyncio
-import threading
-from pathlib import Path
-from typing import Any
-from unittest.mock import MagicMock
-
-from taskito import Queue, TaskCancelledError, current_job
-from taskito.async_support.context import (
- clear_async_context,
- get_async_context,
- set_async_context,
-)
-from taskito.middleware import TaskMiddleware
-
-PollUntil = Any # the conftest fixture's runtime type
-
-# ── Async detection ──────────────────────────────────────────────
-
-
-def test_async_task_detected(tmp_path: Path) -> None:
- """_taskito_is_async is True for async functions."""
- queue = Queue(db_path=str(tmp_path / "test.db"))
-
- @queue.task()
- async def my_async_task() -> None:
- pass
-
- assert my_async_task._taskito_is_async is True
- assert hasattr(my_async_task, "_taskito_async_fn")
-
-
-def test_sync_task_not_async(tmp_path: Path) -> None:
- """_taskito_is_async is False for sync functions."""
- queue = Queue(db_path=str(tmp_path / "test.db"))
-
- @queue.task()
- def my_sync_task() -> None:
- pass
-
- assert my_sync_task._taskito_is_async is False
- assert my_sync_task._taskito_async_fn is None
-
-
-# ── Async context (contextvars) ──────────────────────────────────
-
-
-def test_async_context_var() -> None:
- """set/get/clear async context via contextvars."""
- token = set_async_context("job-1", "my_task", 0, "default")
- ctx = get_async_context()
- assert ctx is not None
- assert ctx.job_id == "job-1"
- assert ctx.task_name == "my_task"
- assert ctx.retry_count == 0
- assert ctx.queue_name == "default"
- clear_async_context(token)
- assert get_async_context() is None
-
-
-def test_async_context_isolated_between_tasks() -> None:
- """Each async task gets its own contextvar context (no cross-contamination)."""
- results: list[str | None] = []
-
- async def coro(job_id: str) -> None:
- token = set_async_context(job_id, "task", 0, "q")
- await asyncio.sleep(0.01)
- ctx = get_async_context()
- results.append(ctx.job_id if ctx else None)
- clear_async_context(token)
-
- async def run_both() -> None:
- await asyncio.gather(coro("a"), coro("b"))
-
- asyncio.run(run_both())
-
- assert sorted(r for r in results if r is not None) == ["a", "b"]
-
-
-def test_sync_context_unchanged(tmp_path: Path) -> None:
- """current_job still works via threading.local for sync tasks."""
- from taskito.context import _clear_context, _set_context
-
- _set_context("sync-job", "sync_task", 2, "high")
- assert current_job.id == "sync-job"
- assert current_job.task_name == "sync_task"
- assert current_job.retry_count == 2
- assert current_job.queue_name == "high"
- _clear_context()
-
-
-def test_async_context_fallback_to_sync() -> None:
- """_require_context falls back to threading.local when no async context."""
- from taskito.context import _clear_context, _set_context
-
- # No async context set
- assert get_async_context() is None
- # Set sync context
- _set_context("sync-id", "t", 0, "q")
- assert current_job.id == "sync-id"
- _clear_context()
-
-
-def test_async_context_preferred_over_sync() -> None:
- """When both async and sync contexts exist, async wins."""
- from taskito.context import _clear_context, _set_context
-
- _set_context("sync-id", "t", 0, "q")
- token = set_async_context("async-id", "t", 0, "q")
- assert current_job.id == "async-id"
- clear_async_context(token)
- assert current_job.id == "sync-id"
- _clear_context()
-
-
-# ── AsyncTaskExecutor unit tests ─────────────────────────────────
-
-
-def test_async_executor_lifecycle() -> None:
- """Start/stop executor without errors."""
- from taskito.async_support.executor import AsyncTaskExecutor
-
- sender = MagicMock()
- registry: dict[str, Any] = {}
- queue_ref = MagicMock()
- executor = AsyncTaskExecutor(sender, registry, queue_ref, max_concurrency=10)
- executor.start()
- assert executor._loop is not None
- assert executor._loop.is_running()
- executor.stop()
-
-
-def test_async_executor_submit_and_execute(poll_until: PollUntil) -> None:
- """Basic async task produces correct result via executor."""
- import cloudpickle
-
- from taskito.async_support.executor import AsyncTaskExecutor
-
- sender = MagicMock()
-
- async def my_task(x: int, y: int) -> int:
- return x + y
-
- # Build a minimal wrapper that the executor expects
- class FakeWrapper:
- _taskito_async_fn = staticmethod(my_task)
-
- registry: dict[str, Any] = {"test_mod.my_task": FakeWrapper()}
-
- queue_ref = MagicMock()
- queue_ref._interceptor = None
- queue_ref._proxy_registry = None
- queue_ref._test_mode_active = False
- queue_ref._resource_runtime = None
- queue_ref._task_inject_map = {}
- queue_ref._task_retry_filters = {}
- queue_ref._get_middleware_chain.return_value = []
- queue_ref._proxy_metrics = None
-
- executor = AsyncTaskExecutor(sender, registry, queue_ref, max_concurrency=10)
- executor.start()
-
- payload = cloudpickle.dumps(((2, 3), {}))
- executor.submit_job("job-1", "test_mod.my_task", payload, 0, 3, "default")
- poll_until(lambda: sender.report_success.called, message="job-1 result not reported")
- executor.stop()
-
- sender.report_success.assert_called_once()
- call_args = sender.report_success.call_args
- assert call_args[0][0] == "job-1"
- assert call_args[0][1] == "test_mod.my_task"
- result = cloudpickle.loads(call_args[0][2])
- assert result == 5
-
-
-def test_async_exception_reported(poll_until: PollUntil) -> None:
- """Exception in async task → failure result with traceback."""
- import cloudpickle
-
- from taskito.async_support.executor import AsyncTaskExecutor
-
- sender = MagicMock()
-
- async def failing_task() -> None:
- raise ValueError("boom")
-
- class FakeWrapper:
- _taskito_async_fn = staticmethod(failing_task)
-
- registry: dict[str, Any] = {"mod.failing_task": FakeWrapper()}
-
- queue_ref = MagicMock()
- queue_ref._interceptor = None
- queue_ref._proxy_registry = None
- queue_ref._test_mode_active = False
- queue_ref._resource_runtime = None
- queue_ref._task_inject_map = {}
- queue_ref._task_retry_filters = {}
- queue_ref._get_middleware_chain.return_value = []
- queue_ref._proxy_metrics = None
-
- executor = AsyncTaskExecutor(sender, registry, queue_ref, max_concurrency=10)
- executor.start()
-
- payload = cloudpickle.dumps(((), {}))
- executor.submit_job("job-2", "mod.failing_task", payload, 0, 3, "default")
- poll_until(lambda: sender.report_failure.called, message="job-2 failure not reported")
- executor.stop()
-
- sender.report_failure.assert_called_once()
- call_args = sender.report_failure.call_args
- assert call_args[0][0] == "job-2"
- assert "boom" in call_args[0][2]
- assert call_args[0][6] is True # should_retry
-
-
-def test_async_cancellation(poll_until: PollUntil) -> None:
- """TaskCancelledError → cancelled result."""
- import cloudpickle
-
- from taskito.async_support.executor import AsyncTaskExecutor
-
- sender = MagicMock()
-
- async def cancelling_task() -> None:
- raise TaskCancelledError("cancelled")
-
- class FakeWrapper:
- _taskito_async_fn = staticmethod(cancelling_task)
-
- registry: dict[str, Any] = {"mod.cancelling_task": FakeWrapper()}
-
- queue_ref = MagicMock()
- queue_ref._interceptor = None
- queue_ref._proxy_registry = None
- queue_ref._test_mode_active = False
- queue_ref._resource_runtime = None
- queue_ref._task_inject_map = {}
- queue_ref._task_retry_filters = {}
- queue_ref._get_middleware_chain.return_value = []
- queue_ref._proxy_metrics = None
-
- executor = AsyncTaskExecutor(sender, registry, queue_ref, max_concurrency=10)
- executor.start()
-
- payload = cloudpickle.dumps(((), {}))
- executor.submit_job("job-3", "mod.cancelling_task", payload, 0, 3, "default")
- poll_until(lambda: sender.report_cancelled.called, message="job-3 cancellation not reported")
- executor.stop()
-
- sender.report_cancelled.assert_called_once()
- assert sender.report_cancelled.call_args[0][0] == "job-3"
-
-
-def test_async_retry_filter(poll_until: PollUntil) -> None:
- """Failed async task respects retry_on filter."""
- import cloudpickle
-
- from taskito.async_support.executor import AsyncTaskExecutor
-
- sender = MagicMock()
-
- async def flaky_task() -> None:
- raise TypeError("wrong type")
-
- class FakeWrapper:
- _taskito_async_fn = staticmethod(flaky_task)
-
- registry: dict[str, Any] = {"mod.flaky_task": FakeWrapper()}
-
- queue_ref = MagicMock()
- queue_ref._interceptor = None
- queue_ref._proxy_registry = None
- queue_ref._test_mode_active = False
- queue_ref._resource_runtime = None
- queue_ref._task_inject_map = {}
- # Only retry on ValueError, not TypeError
- queue_ref._task_retry_filters = {
- "mod.flaky_task": {"retry_on": [ValueError], "dont_retry_on": []},
- }
- queue_ref._get_middleware_chain.return_value = []
- queue_ref._proxy_metrics = None
-
- executor = AsyncTaskExecutor(sender, registry, queue_ref, max_concurrency=10)
- executor.start()
-
- payload = cloudpickle.dumps(((), {}))
- executor.submit_job("job-4", "mod.flaky_task", payload, 0, 3, "default")
- poll_until(lambda: sender.report_failure.called, message="job-4 failure not reported")
- executor.stop()
-
- sender.report_failure.assert_called_once()
- assert sender.report_failure.call_args[0][6] is False # should_retry = False
-
-
-def test_async_concurrency_limit(poll_until: PollUntil) -> None:
- """Semaphore bounds concurrent async tasks."""
- import cloudpickle
-
- from taskito.async_support.executor import AsyncTaskExecutor
-
- sender = MagicMock()
- max_concurrent = 0
- current = 0
- lock = threading.Lock()
-
- async def slow_task() -> None:
- nonlocal max_concurrent, current
- with lock:
- current += 1
- max_concurrent = max(max_concurrent, current)
- await asyncio.sleep(0.1)
- with lock:
- current -= 1
-
- class FakeWrapper:
- _taskito_async_fn = staticmethod(slow_task)
-
- registry: dict[str, Any] = {"mod.slow_task": FakeWrapper()}
-
- queue_ref = MagicMock()
- queue_ref._interceptor = None
- queue_ref._proxy_registry = None
- queue_ref._test_mode_active = False
- queue_ref._resource_runtime = None
- queue_ref._task_inject_map = {}
- queue_ref._task_retry_filters = {}
- queue_ref._get_middleware_chain.return_value = []
- queue_ref._proxy_metrics = None
-
- # Set concurrency to 2
- executor = AsyncTaskExecutor(sender, registry, queue_ref, max_concurrency=2)
- executor.start()
-
- payload = cloudpickle.dumps(((), {}))
- for i in range(5):
- executor.submit_job(f"job-{i}", "mod.slow_task", payload, 0, 3, "default")
-
- poll_until(
- lambda: sender.report_success.call_count >= 5,
- timeout=10,
- message="not all 5 slow_task jobs reported success",
- )
- executor.stop()
-
- assert max_concurrent <= 2
- assert sender.report_success.call_count == 5
-
-
-def test_async_middleware_hooks(poll_until: PollUntil) -> None:
- """Middleware before/after called for async tasks."""
- import cloudpickle
-
- from taskito.async_support.executor import AsyncTaskExecutor
-
- before_called: list[str] = []
- after_called: list[str] = []
-
- class TestMiddleware(TaskMiddleware):
- def before(self, job_context: Any) -> None:
- before_called.append(job_context.id)
-
- def after(self, job_context: Any, result: Any, error: Any) -> None:
- after_called.append(job_context.id)
-
- sender = MagicMock()
-
- async def simple_task() -> int:
- return 42
-
- class FakeWrapper:
- _taskito_async_fn = staticmethod(simple_task)
-
- registry: dict[str, Any] = {"mod.simple_task": FakeWrapper()}
-
- queue_ref = MagicMock()
- queue_ref._interceptor = None
- queue_ref._proxy_registry = None
- queue_ref._test_mode_active = False
- queue_ref._resource_runtime = None
- queue_ref._task_inject_map = {}
- queue_ref._task_retry_filters = {}
- queue_ref._get_middleware_chain.return_value = [TestMiddleware()]
- queue_ref._proxy_metrics = None
-
- executor = AsyncTaskExecutor(sender, registry, queue_ref, max_concurrency=10)
- executor.start()
-
- payload = cloudpickle.dumps(((), {}))
- executor.submit_job("mw-job", "mod.simple_task", payload, 0, 3, "default")
- poll_until(lambda: "mw-job" in after_called, message="middleware after hook not called")
- executor.stop()
-
- assert "mw-job" in before_called
- assert "mw-job" in after_called
-
-
-def test_async_task_with_injection(poll_until: PollUntil) -> None:
- """inject=["db"] works for async tasks via executor."""
- import cloudpickle
-
- from taskito.async_support.executor import AsyncTaskExecutor
-
- sender = MagicMock()
-
- async def db_task(db: Any = None) -> str:
- return f"got-{db}"
-
- class FakeWrapper:
- _taskito_async_fn = staticmethod(db_task)
-
- registry: dict[str, Any] = {"mod.db_task": FakeWrapper()}
-
- fake_db = "fake-conn"
-
- queue_ref = MagicMock()
- queue_ref._interceptor = None
- queue_ref._proxy_registry = None
- queue_ref._test_mode_active = False
- queue_ref._task_inject_map = {"mod.db_task": ["db"]}
- queue_ref._task_retry_filters = {}
- queue_ref._get_middleware_chain.return_value = []
- queue_ref._proxy_metrics = None
-
- # Mock resource runtime
- runtime = MagicMock()
- runtime.acquire_for_task.return_value = (fake_db, None)
- queue_ref._resource_runtime = runtime
-
- executor = AsyncTaskExecutor(sender, registry, queue_ref, max_concurrency=10)
- executor.start()
-
- payload = cloudpickle.dumps(((), {}))
- executor.submit_job("inj-job", "mod.db_task", payload, 0, 3, "default")
- poll_until(lambda: sender.report_success.called, message="inj-job result not reported")
- executor.stop()
-
- sender.report_success.assert_called_once()
- result = cloudpickle.loads(sender.report_success.call_args[0][2])
- assert result == "got-fake-conn"
-
-
-def test_async_context_available_inside_task(poll_until: PollUntil) -> None:
- """current_job.id works inside an async task via contextvars."""
- import cloudpickle
-
- from taskito.async_support.executor import AsyncTaskExecutor
-
- sender = MagicMock()
- captured_id: list[str] = []
-
- async def ctx_task() -> str:
- captured_id.append(current_job.id)
- return "ok"
-
- class FakeWrapper:
- _taskito_async_fn = staticmethod(ctx_task)
-
- registry: dict[str, Any] = {"mod.ctx_task": FakeWrapper()}
-
- queue_ref = MagicMock()
- queue_ref._interceptor = None
- queue_ref._proxy_registry = None
- queue_ref._test_mode_active = False
- queue_ref._resource_runtime = None
- queue_ref._task_inject_map = {}
- queue_ref._task_retry_filters = {}
- queue_ref._get_middleware_chain.return_value = []
- queue_ref._proxy_metrics = None
-
- executor = AsyncTaskExecutor(sender, registry, queue_ref, max_concurrency=10)
- executor.start()
-
- payload = cloudpickle.dumps(((), {}))
- executor.submit_job("ctx-job", "mod.ctx_task", payload, 0, 3, "default")
- poll_until(lambda: captured_id == ["ctx-job"], message="ctx-job context not captured")
- executor.stop()
-
- assert captured_id == ["ctx-job"]
-
-
-def test_async_concurrency_parameter(tmp_path: Path) -> None:
- """Queue accepts async_concurrency parameter."""
- queue = Queue(db_path=str(tmp_path / "test.db"), async_concurrency=50)
- assert queue._async_concurrency == 50
-
-
-def test_async_concurrency_default(tmp_path: Path) -> None:
- """Default async_concurrency is 100."""
- queue = Queue(db_path=str(tmp_path / "test.db"))
- assert queue._async_concurrency == 100
diff --git a/tests/test_observability.py b/tests/test_observability.py
deleted file mode 100644
index 7e0f17a..0000000
--- a/tests/test_observability.py
+++ /dev/null
@@ -1,309 +0,0 @@
-"""Tests for resource observability (Phase 6 — status API, CLI, health checks)."""
-
-from __future__ import annotations
-
-import time
-from typing import Any
-
-import pytest
-
-from taskito import Queue
-from taskito.resources import ResourceDefinition, ResourceRuntime
-
-# ---------------------------------------------------------------------------
-# ResourceRuntime.status()
-# ---------------------------------------------------------------------------
-
-
-class TestResourceRuntimeStatus:
- def test_status_returns_all_resources(self) -> None:
- """status() returns an entry for every initialized resource."""
- defs = {
- "config": ResourceDefinition(name="config", factory=lambda: {}),
- "db": ResourceDefinition(
- name="db", factory=lambda config: "conn", depends_on=["config"]
- ),
- }
- rt = ResourceRuntime(defs)
- rt.initialize()
-
- entries = rt.status()
- assert len(entries) == 2
- names = [e["name"] for e in entries]
- assert "config" in names
- assert "db" in names
- rt.teardown()
-
- def test_status_healthy_resource(self) -> None:
- """A successfully initialized resource has health='healthy'."""
- defs = {"svc": ResourceDefinition(name="svc", factory=lambda: "ok")}
- rt = ResourceRuntime(defs)
- rt.initialize()
-
- entry = rt.status()[0]
- assert entry["health"] == "healthy"
- assert entry["name"] == "svc"
- assert entry["scope"] == "worker"
- assert entry["recreations"] == 0
- assert entry["init_duration_ms"] >= 0
- assert entry["depends_on"] == []
- rt.teardown()
-
- def test_status_unhealthy_resource(self) -> None:
- """An unhealthy resource is reported as such."""
- defs = {"svc": ResourceDefinition(name="svc", factory=lambda: "ok")}
- rt = ResourceRuntime(defs)
- rt.initialize()
- rt._unhealthy.add("svc")
-
- entry = rt.status()[0]
- assert entry["health"] == "unhealthy"
- rt.teardown()
-
- def test_status_tracks_init_duration(self) -> None:
- """init_duration_ms is populated after initialize."""
-
- def slow_factory() -> str:
- time.sleep(0.05)
- return "result"
-
- defs = {"slow": ResourceDefinition(name="slow", factory=slow_factory)}
- rt = ResourceRuntime(defs)
- rt.initialize()
-
- entry = rt.status()[0]
- assert entry["init_duration_ms"] >= 40 # at least ~50ms minus timing jitter
- rt.teardown()
-
- def test_status_tracks_recreations(self) -> None:
- """recreation count is incremented on successful recreate."""
- call_count = 0
-
- def make_svc() -> str:
- nonlocal call_count
- call_count += 1
- return f"v{call_count}"
-
- defs = {
- "svc": ResourceDefinition(
- name="svc",
- factory=make_svc,
- health_check=lambda inst: True,
- health_check_interval=1.0,
- ),
- }
- rt = ResourceRuntime(defs)
- rt.initialize()
- assert rt._recreation_count.get("svc", 0) == 0
-
- # Manually trigger recreation
- rt.recreate("svc")
- assert rt._recreation_count["svc"] == 1
-
- entry = rt.status()[0]
- assert entry["recreations"] == 1
- rt.teardown()
-
- def test_status_includes_depends_on(self) -> None:
- """depends_on list is included in status output."""
- defs = {
- "config": ResourceDefinition(name="config", factory=lambda: {}),
- "db": ResourceDefinition(
- name="db", factory=lambda config: "conn", depends_on=["config"]
- ),
- }
- rt = ResourceRuntime(defs)
- rt.initialize()
-
- entries = {e["name"]: e for e in rt.status()}
- assert entries["config"]["depends_on"] == []
- assert entries["db"]["depends_on"] == ["config"]
- rt.teardown()
-
- def test_from_test_overrides_status(self) -> None:
- """Test-override runtime reports healthy status."""
- rt = ResourceRuntime.from_test_overrides({"db": "mock"})
- entries = rt.status()
- assert len(entries) == 1
- assert entries[0]["name"] == "db"
- assert entries[0]["health"] == "healthy"
-
-
-# ---------------------------------------------------------------------------
-# Queue.resource_status()
-# ---------------------------------------------------------------------------
-
-
-class TestQueueResourceStatus:
- def test_resource_status_with_runtime(self, tmp_path: Any) -> None:
- """resource_status() delegates to runtime when initialized."""
- queue = Queue(db_path=str(tmp_path / "q.db"))
-
- @queue.worker_resource("db")
- def create_db() -> str:
- return "conn"
-
- # Manually initialize runtime
- rt = ResourceRuntime(queue._resource_definitions)
- rt.initialize()
- queue._resource_runtime = rt
-
- status = queue.resource_status()
- assert len(status) == 1
- assert status[0]["name"] == "db"
- assert status[0]["health"] == "healthy"
- rt.teardown()
-
- def test_resource_status_without_runtime(self, tmp_path: Any) -> None:
- """resource_status() returns definitions with not_initialized health."""
- queue = Queue(db_path=str(tmp_path / "q.db"))
-
- @queue.worker_resource("db")
- def create_db() -> str:
- return "conn"
-
- status = queue.resource_status()
- assert len(status) == 1
- assert status[0]["name"] == "db"
- assert status[0]["health"] == "not_initialized"
-
- def test_resource_status_empty(self, tmp_path: Any) -> None:
- """resource_status() returns empty list with no resources."""
- queue = Queue(db_path=str(tmp_path / "q.db"))
- assert queue.resource_status() == []
-
-
-# ---------------------------------------------------------------------------
-# Health check integration
-# ---------------------------------------------------------------------------
-
-
-class TestHealthCheckIntegration:
- def test_readiness_reports_healthy_resources(self, tmp_path: Any) -> None:
- """check_readiness includes resource status when all healthy."""
- from taskito.health import check_readiness
-
- queue = Queue(db_path=str(tmp_path / "q.db"))
-
- @queue.worker_resource("db")
- def create_db() -> str:
- return "conn"
-
- rt = ResourceRuntime(queue._resource_definitions)
- rt.initialize()
- queue._resource_runtime = rt
-
- result = check_readiness(queue)
- assert "resources" in result["checks"]
- res_check = result["checks"]["resources"]
- assert res_check["status"] == "ok"
- assert res_check["count"] == 1
- assert res_check["unhealthy"] == []
- rt.teardown()
-
- def test_readiness_reports_unhealthy_resources(self, tmp_path: Any) -> None:
- """check_readiness marks status as degraded for unhealthy resources."""
- from taskito.health import check_readiness
-
- queue = Queue(db_path=str(tmp_path / "q.db"))
-
- @queue.worker_resource("db")
- def create_db() -> str:
- return "conn"
-
- rt = ResourceRuntime(queue._resource_definitions)
- rt.initialize()
- rt._unhealthy.add("db")
- queue._resource_runtime = rt
-
- result = check_readiness(queue)
- assert result["status"] == "degraded"
- res_check = result["checks"]["resources"]
- assert res_check["status"] == "degraded"
- assert "db" in res_check["unhealthy"]
- rt.teardown()
-
- def test_readiness_no_resources(self, tmp_path: Any) -> None:
- """check_readiness works fine without any resources."""
- from taskito.health import check_readiness
-
- queue = Queue(db_path=str(tmp_path / "q.db"))
- result = check_readiness(queue)
- # No resources section when empty
- assert (
- "resources" not in result["checks"]
- or result["checks"].get("resources", {}).get("count", 0) == 0
- )
-
- def test_health_check_always_ok(self) -> None:
- """check_health is always ok regardless of resources."""
- from taskito.health import check_health
-
- assert check_health() == {"status": "ok"}
-
-
-# ---------------------------------------------------------------------------
-# CLI resources subcommand
-# ---------------------------------------------------------------------------
-
-
-class TestCLIResources:
- def test_run_resources_no_resources(self, tmp_path: Any) -> None:
- """resource_status returns empty list when no resources registered."""
- queue = Queue(db_path=str(tmp_path / "q.db"))
- assert queue.resource_status() == []
-
- def test_resource_status_table_format(
- self, tmp_path: Any, capsys: pytest.CaptureFixture[str]
- ) -> None:
- """Verify table output format from CLI helper."""
- queue = Queue(db_path=str(tmp_path / "q.db"))
-
- @queue.worker_resource("config")
- def create_config() -> dict[str, str]:
- return {}
-
- @queue.worker_resource("db", depends_on=["config"])
- def create_db(config: Any) -> str:
- return "conn"
-
- rt = ResourceRuntime(queue._resource_definitions)
- rt.initialize()
- queue._resource_runtime = rt
-
- # Simulate what run_resources prints
- resources = queue.resource_status()
- assert len(resources) == 2
- names = {r["name"] for r in resources}
- assert names == {"config", "db"}
-
- # Verify structure matches CLI expectations
- for r in resources:
- assert "name" in r
- assert "scope" in r
- assert "health" in r
- assert "init_duration_ms" in r
- assert "recreations" in r
- assert "depends_on" in r
-
- rt.teardown()
-
-
-# ---------------------------------------------------------------------------
-# Prometheus resource metrics
-# ---------------------------------------------------------------------------
-
-
-class TestPrometheusResourceMetrics:
- def test_prometheus_middleware_has_resource_metrics(self) -> None:
- """Verify resource metric singletons are initialized."""
- pytest.importorskip("prometheus_client")
- from taskito.contrib.prometheus import _init_metrics # type: ignore[attr-defined]
-
- _init_metrics()
-
- from taskito.contrib import prometheus as pmod
-
- assert pmod._resource_health is not None # type: ignore[attr-defined]
- assert pmod._resource_recreations is not None # type: ignore[attr-defined]
- assert pmod._resource_init_duration is not None # type: ignore[attr-defined]
diff --git a/tests/test_periodic.py b/tests/test_periodic.py
deleted file mode 100644
index 3cabf77..0000000
--- a/tests/test_periodic.py
+++ /dev/null
@@ -1,59 +0,0 @@
-"""Tests for periodic (cron-scheduled) tasks."""
-
-import threading
-from pathlib import Path
-from typing import Any
-
-import pytest
-
-from taskito import Queue
-
-
-@pytest.fixture
-def queue(tmp_path: Path) -> Queue:
- db_path = str(tmp_path / "test_periodic.db")
- return Queue(db_path=db_path, workers=1)
-
-
-def test_periodic_task_registration(queue: Queue) -> None:
- """Periodic tasks are registered as both regular tasks and periodic configs."""
-
- @queue.periodic(cron="0 * * * * *")
- def every_minute() -> str:
- return "tick"
-
- assert every_minute.name.endswith("every_minute")
- assert every_minute.name in queue._task_registry
- assert len(queue._periodic_configs) == 1
- assert queue._periodic_configs[0]["cron_expr"] == "0 * * * * *"
-
-
-def test_periodic_task_direct_call(queue: Queue) -> None:
- """Periodic tasks can still be called directly like regular tasks."""
-
- @queue.periodic(cron="0 * * * * *")
- def add(a: int, b: int) -> int:
- return a + b
-
- assert add(3, 4) == 7
-
-
-def test_periodic_task_triggers(queue: Queue, poll_until: Any) -> None:
- """Periodic task gets enqueued by the scheduler when due."""
- results: list[int] = []
-
- @queue.periodic(cron="* * * * * *") # every second
- def frequent_task() -> str:
- results.append(1)
- return "done"
-
- worker_thread = threading.Thread(target=queue.run_worker, daemon=True)
- worker_thread.start()
-
- poll_until(
- lambda: queue.stats()["completed"] >= 1,
- timeout=15,
- message="periodic task never triggered",
- )
-
- assert queue.stats()["completed"] >= 1
diff --git a/tests/test_prefork.py b/tests/test_prefork.py
deleted file mode 100644
index fa1a366..0000000
--- a/tests/test_prefork.py
+++ /dev/null
@@ -1,323 +0,0 @@
-"""Tests for the prefork (multi-process) worker pool."""
-
-from __future__ import annotations
-
-import contextlib
-import importlib
-import os
-import sys
-import threading
-import time
-from collections.abc import Iterator
-from pathlib import Path
-from typing import Any
-
-import pytest
-
-from taskito import Queue
-from taskito.context import JobContext
-from taskito.middleware import TaskMiddleware
-
-
-@pytest.mark.skipif(
- sys.platform == "win32",
- reason="prefork is rejected on Windows before app= is checked",
-)
-def test_prefork_requires_app_path(tmp_path: Path) -> None:
- """pool='prefork' without app= raises ValueError."""
- queue = Queue(db_path=str(tmp_path / "test.db"))
-
- @queue.task()
- def noop() -> None:
- pass
-
- with pytest.raises(ValueError, match="app= is required"):
- queue.run_worker(pool="prefork")
-
-
-def test_prefork_rejected_on_windows(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
- """pool='prefork' on Windows fails fast with NotImplementedError."""
- monkeypatch.setattr(sys, "platform", "win32")
- queue = Queue(db_path=str(tmp_path / "test.db"))
- with pytest.raises(NotImplementedError, match="not supported on Windows"):
- queue.run_worker(pool="prefork", app="x:y")
-
-
-def test_prefork_basic_execution(tmp_path: Path) -> None:
- """A task enqueued and processed by a prefork worker returns the correct result.
-
- NOTE: Prefork children import the app module independently, so the task name
- must resolve to the same module path in both parent and child. Tasks defined
- inside test functions can't be imported by children — use module-level tasks.
- This test is currently skipped; see test_prefork_module_level_queue for the
- working version.
- """
- pytest.skip("Tasks defined inside functions can't be imported by prefork children")
-
-
-def test_prefork_thread_pool_unchanged(tmp_path: Path) -> None:
- """pool='thread' (default) still works normally."""
- q = Queue(db_path=str(tmp_path / "test.db"))
-
- @q.task()
- def multiply(x: int, y: int) -> int:
- return x * y
-
- job = multiply.delay(3, 7)
-
- worker = threading.Thread(target=q.run_worker, daemon=True)
- worker.start()
-
- result = job.result(timeout=10)
- assert result == 21
-
- q._inner.request_shutdown()
-
-
-# ---------------------------------------------------------------------------
-# Per-job timeout enforcement (issue #81)
-# ---------------------------------------------------------------------------
-
-# The prefork pool is Unix-oriented: child processes communicate over anonymous
-# stdio pipes, which on Windows have different blocking semantics that make
-# parent-side reader threads hang after `TerminateProcess`. Per-job timeout
-# behaviour itself is identical, but the surrounding pool plumbing isn't
-# Windows-ready, so these end-to-end tests are skipped there.
-prefork_unix_only = pytest.mark.skipif(
- sys.platform == "win32",
- reason="prefork pool is Unix-only — child stdio pipe semantics differ on Windows",
-)
-
-PREFORK_APP_PATH = "prefork_apps.timeout_app:queue"
-PREFORK_APP_DIR = str(Path(__file__).parent)
-
-
-@pytest.fixture
-def timeout_app(tmp_path: Path) -> Iterator[object]:
- """Set up the module-level timeout-test app with a per-test DB path.
-
- The Queue inside ``prefork_apps.timeout_app`` is constructed at import time
- from ``$TASKITO_TIMEOUT_TEST_DB``, and the prefork child re-imports the
- same module fresh in its own interpreter — so the env var must be set in
- the parent process before that import happens, and propagates to the
- child via inherited env.
- """
- db_path = str(tmp_path / "timeout.db")
- prev_db = os.environ.get("TASKITO_TIMEOUT_TEST_DB")
- prev_pythonpath = os.environ.get("PYTHONPATH")
-
- os.environ["TASKITO_TIMEOUT_TEST_DB"] = db_path
- # Make `prefork_apps.timeout_app` importable in both parent and (inherited)
- # child without depending on pytest's rootdir manipulation.
- os.environ["PYTHONPATH"] = (
- f"{PREFORK_APP_DIR}{os.pathsep}{prev_pythonpath}" if prev_pythonpath else PREFORK_APP_DIR
- )
- if PREFORK_APP_DIR not in sys.path:
- sys.path.insert(0, PREFORK_APP_DIR)
-
- # Force a fresh module import so the Queue picks up the per-test DB path.
- sys.modules.pop("prefork_apps.timeout_app", None)
- sys.modules.pop("prefork_apps", None)
- module = importlib.import_module("prefork_apps.timeout_app")
-
- try:
- yield module
- finally:
- with contextlib.suppress(Exception):
- module.queue._inner.request_shutdown()
- if prev_db is None:
- os.environ.pop("TASKITO_TIMEOUT_TEST_DB", None)
- else:
- os.environ["TASKITO_TIMEOUT_TEST_DB"] = prev_db
- if prev_pythonpath is None:
- os.environ.pop("PYTHONPATH", None)
- else:
- os.environ["PYTHONPATH"] = prev_pythonpath
-
-
-def _start_prefork_worker(queue: Queue) -> threading.Thread:
- """Start a prefork worker for ``queue`` in a daemon thread."""
- thread = threading.Thread(
- target=queue.run_worker,
- kwargs={"pool": "prefork", "app": PREFORK_APP_PATH},
- daemon=True,
- )
- thread.start()
- return thread
-
-
-def _wait_for_terminal(job: Any, timeout: float) -> str:
- """Poll a JobResult.refresh() until the status is terminal or `timeout` elapses."""
- deadline = time.monotonic() + timeout
- while time.monotonic() < deadline:
- job.refresh()
- status: str = job.status
- if status in {"complete", "failed", "dead", "cancelled"}:
- return status
- time.sleep(0.1)
- job.refresh()
- final_status: str = job.status
- return final_status
-
-
-@prefork_unix_only
-def test_prefork_kills_hung_task(timeout_app: object) -> None:
- """A task that hangs past its `timeout=` is SIGKILLed by the watchdog and
- reported as a timeout failure within the timeout + watchdog tick budget."""
- timeouts_seen: list[str] = []
-
- class TimeoutSpy(TaskMiddleware):
- def on_timeout(self, ctx: JobContext) -> None:
- timeouts_seen.append(ctx.id)
-
- queue: Queue = timeout_app.queue # type: ignore[attr-defined]
- queue._global_middleware.append(TimeoutSpy())
-
- started = time.monotonic()
- job = timeout_app.hang.delay() # type: ignore[attr-defined]
- _start_prefork_worker(queue)
-
- # timeout=2s, watchdog tick=250ms → kill within ~2.25s; allow generous
- # headroom for child spawn and CI noise.
- status = _wait_for_terminal(job, timeout=15)
- elapsed = time.monotonic() - started
-
- assert status == "dead", f"expected 'dead', got {status!r} (error={job.error!r})"
- assert "timed out" in (job.error or "").lower()
- assert elapsed < 12, f"hung task took {elapsed:.1f}s to be killed (expected < 12s)"
- assert job.id in timeouts_seen, "on_timeout middleware did not fire"
-
-
-@prefork_unix_only
-def test_prefork_no_timeout_unaffected(timeout_app: object) -> None:
- """A task with no timeout (timeout=0) runs to completion — the watchdog
- must not kill jobs that have no deadline configured."""
- queue: Queue = timeout_app.queue # type: ignore[attr-defined]
-
- job = timeout_app.quick.delay(21) # type: ignore[attr-defined]
- _start_prefork_worker(queue)
-
- result = job.result(timeout=15)
- assert result == 42
-
-
-@prefork_unix_only
-def test_prefork_finishes_before_deadline(timeout_app: object) -> None:
- """A task that completes well before its deadline returns normally — the
- watchdog only fires when the deadline is actually crossed."""
- queue: Queue = timeout_app.queue # type: ignore[attr-defined]
-
- # timeout=2s, sleep 0.5s — should finish cleanly.
- job = timeout_app.sleep_then_finish.delay(0.5) # type: ignore[attr-defined]
- _start_prefork_worker(queue)
-
- result = job.result(timeout=15)
- assert result == "done"
-
-
-# ---------------------------------------------------------------------------
-# Cooperative cancellation propagation (issue #82)
-# ---------------------------------------------------------------------------
-
-CANCEL_APP_PATH = "prefork_apps.cancel_app:queue"
-
-
-@pytest.fixture
-def cancel_app(tmp_path: Path) -> Iterator[object]:
- """Set up the module-level cancel-test app with a per-test DB path.
-
- Mirrors ``timeout_app`` — must set the env var before the import so the
- parent's Queue construction and the child's re-import see the same DB.
- """
- db_path = str(tmp_path / "cancel.db")
- prev_db = os.environ.get("TASKITO_CANCEL_TEST_DB")
- prev_pythonpath = os.environ.get("PYTHONPATH")
-
- os.environ["TASKITO_CANCEL_TEST_DB"] = db_path
- os.environ["PYTHONPATH"] = (
- f"{PREFORK_APP_DIR}{os.pathsep}{prev_pythonpath}" if prev_pythonpath else PREFORK_APP_DIR
- )
- if PREFORK_APP_DIR not in sys.path:
- sys.path.insert(0, PREFORK_APP_DIR)
-
- sys.modules.pop("prefork_apps.cancel_app", None)
- sys.modules.pop("prefork_apps", None)
- module = importlib.import_module("prefork_apps.cancel_app")
-
- try:
- yield module
- finally:
- with contextlib.suppress(Exception):
- module.queue._inner.request_shutdown()
- if prev_db is None:
- os.environ.pop("TASKITO_CANCEL_TEST_DB", None)
- else:
- os.environ["TASKITO_CANCEL_TEST_DB"] = prev_db
- if prev_pythonpath is None:
- os.environ.pop("PYTHONPATH", None)
- else:
- os.environ["PYTHONPATH"] = prev_pythonpath
-
-
-def _start_cancel_worker(queue: Queue) -> threading.Thread:
- thread = threading.Thread(
- target=queue.run_worker,
- kwargs={"pool": "prefork", "app": CANCEL_APP_PATH},
- daemon=True,
- )
- thread.start()
- return thread
-
-
-@prefork_unix_only
-def test_prefork_cancel_running_job_stops_quickly(cancel_app: object, poll_until: Any) -> None:
- """``cancel_running_job`` propagates to the prefork child and stops a
- cooperative task within a small budget — the regression test for #82."""
- queue: Queue = cancel_app.queue # type: ignore[attr-defined]
-
- cancels_seen: list[str] = []
-
- class CancelSpy(TaskMiddleware):
- def on_cancel(self, ctx: JobContext) -> None:
- cancels_seen.append(ctx.id)
-
- queue._global_middleware.append(CancelSpy())
-
- job = cancel_app.cooperative_loop.delay(600) # type: ignore[attr-defined]
- _start_cancel_worker(queue)
-
- # Wait until the job is actually running on a child before cancelling.
- def _running() -> bool:
- job.refresh()
- return bool(job.status == "running")
-
- poll_until(_running, timeout=10, message="job never reached running state")
-
- assert queue.cancel_running_job(job.id) is True
-
- status = _wait_for_terminal(job, timeout=10)
- assert status == "cancelled", f"expected 'cancelled', got {status!r} (error={job.error!r})"
- assert job.id in cancels_seen, "on_cancel middleware did not fire"
-
-
-@prefork_unix_only
-def test_prefork_cancel_does_not_kill_child(cancel_app: object, poll_until: Any) -> None:
- """A cancel must stop the running task without killing the child — the
- next job dispatched to the same pool should still complete normally."""
- queue: Queue = cancel_app.queue # type: ignore[attr-defined]
-
- long_job = cancel_app.cooperative_loop.delay(600) # type: ignore[attr-defined]
- _start_cancel_worker(queue)
-
- def _running() -> bool:
- long_job.refresh()
- return bool(long_job.status == "running")
-
- poll_until(_running, timeout=10, message="long_job never reached running state")
- assert queue.cancel_running_job(long_job.id) is True
- status = _wait_for_terminal(long_job, timeout=10)
- assert status == "cancelled"
-
- follow_up = cancel_app.quick.delay(21) # type: ignore[attr-defined]
- result = follow_up.result(timeout=15)
- assert result == 42
diff --git a/tests/test_priority.py b/tests/test_priority.py
deleted file mode 100644
index 13c8034..0000000
--- a/tests/test_priority.py
+++ /dev/null
@@ -1,50 +0,0 @@
-"""Tests for priority scheduling."""
-
-import threading
-from pathlib import Path
-from typing import Any
-
-import pytest
-
-from taskito import Queue
-
-
-@pytest.fixture
-def queue(tmp_path: Path) -> Queue:
- db_path = str(tmp_path / "test_priority.db")
- return Queue(db_path=db_path, workers=1) # 1 worker for ordering
-
-
-def test_priority_ordering(queue: Queue, poll_until: Any) -> None:
- """Higher priority jobs should be processed first."""
- results: list[str] = []
-
- @queue.task()
- def record_task(label: str) -> str:
- results.append(label)
- return label
-
- # Enqueue low priority first, then high priority. All three jobs are
- # synchronously visible after `apply_async` returns, so no pre-worker
- # delay is needed.
- record_task.apply_async(args=("low",), priority=1)
- record_task.apply_async(args=("medium",), priority=5)
- record_task.apply_async(args=("high",), priority=10)
-
- worker_thread = threading.Thread(
- target=queue.run_worker,
- daemon=True,
- )
- worker_thread.start()
-
- poll_until(
- lambda: len(results) >= 3,
- timeout=10,
- message="not all priority jobs completed",
- )
-
- # High priority should have been processed first
- assert len(results) == 3
- assert results[0] == "high"
- assert results[1] == "medium"
- assert results[2] == "low"
diff --git a/tests/test_progress.py b/tests/test_progress.py
deleted file mode 100644
index 58cd7c6..0000000
--- a/tests/test_progress.py
+++ /dev/null
@@ -1,44 +0,0 @@
-"""Tests for progress tracking."""
-
-from __future__ import annotations
-
-import time
-
-from taskito import Queue
-
-
-def test_update_progress(queue: Queue) -> None:
- """Progress can be updated and read back."""
-
- @queue.task()
- def slow_task() -> str:
-
- time.sleep(0.5)
- return "done"
-
- job = slow_task.delay()
-
- # Update progress directly via the queue API
- queue.update_progress(job.id, 50)
-
- refreshed = queue.get_job(job.id)
- assert refreshed is not None
- assert refreshed.progress == 50
-
- queue.update_progress(job.id, 100)
- refreshed = queue.get_job(job.id)
- assert refreshed is not None
- assert refreshed.progress == 100
-
-
-def test_progress_starts_none(queue: Queue) -> None:
- """Progress is None by default."""
-
- @queue.task()
- def task_a() -> int:
- return 1
-
- job = task_a.delay()
- refreshed = queue.get_job(job.id)
- assert refreshed is not None
- assert refreshed.progress is None
diff --git a/tests/test_proxies.py b/tests/test_proxies.py
deleted file mode 100644
index b6c960f..0000000
--- a/tests/test_proxies.py
+++ /dev/null
@@ -1,429 +0,0 @@
-"""Tests for the proxy system (Layer 3 — transparent reconstruction)."""
-
-from __future__ import annotations
-
-import logging
-from typing import Any
-
-import pytest
-
-from taskito import ProxyReconstructionError, Queue
-from taskito.proxies import ProxyRegistry, cleanup_proxies, reconstruct_proxies
-from taskito.proxies.built_in import register_builtin_handlers
-from taskito.proxies.handlers.file import FileHandler
-from taskito.proxies.handlers.logger import LoggerHandler
-
-# ---------------------------------------------------------------------------
-# FileHandler
-# ---------------------------------------------------------------------------
-
-
-class TestFileHandler:
- def test_detect_open_file(self, tmp_path: Any) -> None:
- f = open(tmp_path / "test.txt", "w") # noqa: SIM115
- try:
- handler = FileHandler()
- assert handler.detect(f) is True
- finally:
- f.close()
-
- def test_detect_closed_file(self, tmp_path: Any) -> None:
- f = open(tmp_path / "test.txt", "w") # noqa: SIM115
- f.close()
- handler = FileHandler()
- assert handler.detect(f) is False
-
- def test_detect_stdin(self) -> None:
- import sys
-
- handler = FileHandler()
- assert handler.detect(sys.stdin) is False
-
- def test_deconstruct_text_file(self, tmp_path: Any) -> None:
- path = tmp_path / "data.txt"
- path.write_text("hello world")
- with open(path) as f:
- handler = FileHandler()
- recipe = handler.deconstruct(f)
- assert recipe["path"] == str(path)
- assert recipe["mode"] == "r"
- assert recipe["encoding"] is not None
- assert recipe["position"] == 0
-
- def test_deconstruct_binary_file(self, tmp_path: Any) -> None:
- path = tmp_path / "data.bin"
- path.write_bytes(b"\x00\x01\x02")
- with open(path, "rb") as f:
- handler = FileHandler()
- recipe = handler.deconstruct(f)
- assert recipe["mode"] == "rb"
- assert recipe["encoding"] is None
-
- def test_reconstruct_text_file(self, tmp_path: Any) -> None:
- path = tmp_path / "data.txt"
- path.write_text("hello world")
- handler = FileHandler()
- recipe = {"path": str(path), "mode": "r", "encoding": "utf-8", "position": 0}
- f = handler.reconstruct(recipe, version=1)
- try:
- assert f.read() == "hello world"
- finally:
- f.close()
-
- def test_reconstruct_at_position(self, tmp_path: Any) -> None:
- path = tmp_path / "data.txt"
- path.write_text("hello world")
- handler = FileHandler()
- recipe = {"path": str(path), "mode": "r", "encoding": "utf-8", "position": 6}
- f = handler.reconstruct(recipe, version=1)
- try:
- assert f.read() == "world"
- finally:
- f.close()
-
- def test_cleanup_closes_file(self, tmp_path: Any) -> None:
- path = tmp_path / "data.txt"
- path.write_text("test")
- f = open(path) # noqa: SIM115
- handler = FileHandler()
- handler.cleanup(f)
- assert f.closed
-
- def test_cleanup_already_closed(self, tmp_path: Any) -> None:
- path = tmp_path / "data.txt"
- path.write_text("test")
- f = open(path) # noqa: SIM115
- f.close()
- handler = FileHandler()
- handler.cleanup(f) # no error
-
-
-# ---------------------------------------------------------------------------
-# LoggerHandler
-# ---------------------------------------------------------------------------
-
-
-class TestLoggerHandler:
- def test_detect_logger(self) -> None:
- handler = LoggerHandler()
- lgr = logging.getLogger("test.proxies.detect")
- assert handler.detect(lgr) is True
-
- def test_detect_non_logger(self) -> None:
- handler = LoggerHandler()
- assert handler.detect("not a logger") is False
-
- def test_round_trip(self) -> None:
- handler = LoggerHandler()
- lgr = logging.getLogger("test.proxies.roundtrip")
- lgr.setLevel(logging.WARNING)
- recipe = handler.deconstruct(lgr)
- reconstructed = handler.reconstruct(recipe, version=1)
- assert reconstructed.name == "test.proxies.roundtrip"
- assert reconstructed.level == logging.WARNING
-
- def test_cleanup_noop(self) -> None:
- handler = LoggerHandler()
- lgr = logging.getLogger("test.proxies.cleanup")
- handler.cleanup(lgr) # no error
-
-
-# ---------------------------------------------------------------------------
-# ProxyRegistry
-# ---------------------------------------------------------------------------
-
-
-class TestProxyRegistry:
- def test_register_and_get(self) -> None:
- reg = ProxyRegistry()
- handler = FileHandler()
- reg.register(handler)
- assert reg.get("file") is handler
-
- def test_get_missing(self) -> None:
- reg = ProxyRegistry()
- assert reg.get("nonexistent") is None
-
- def test_find_handler(self, tmp_path: Any) -> None:
- reg = ProxyRegistry()
- register_builtin_handlers(reg)
- f = open(tmp_path / "test.txt", "w") # noqa: SIM115
- try:
- found = reg.find_handler(f)
- assert found is not None
- assert found.name == "file"
- finally:
- f.close()
-
- def test_find_handler_no_match(self) -> None:
- reg = ProxyRegistry()
- register_builtin_handlers(reg)
- assert reg.find_handler(42) is None
-
-
-# ---------------------------------------------------------------------------
-# Reconstruction
-# ---------------------------------------------------------------------------
-
-
-class TestReconstruct:
- def test_proxy_marker_reconstructed(self, tmp_path: Any) -> None:
- path = tmp_path / "data.txt"
- path.write_text("content")
- reg = ProxyRegistry()
- reg.register(FileHandler())
-
- marker = {
- "__taskito_proxy__": True,
- "handler": "file",
- "version": 1,
- "identity": "id-1",
- "recipe": {
- "path": str(path),
- "mode": "r",
- "encoding": "utf-8",
- "position": 0,
- },
- }
- args, _kwargs, cleanup_list = reconstruct_proxies((marker,), {}, reg)
- try:
- assert hasattr(args[0], "read")
- assert args[0].read() == "content"
- assert len(cleanup_list) == 1
- finally:
- cleanup_proxies(cleanup_list)
-
- def test_cleanup_list_populated(self, tmp_path: Any) -> None:
- path = tmp_path / "data.txt"
- path.write_text("test")
- reg = ProxyRegistry()
- reg.register(FileHandler())
-
- marker = {
- "__taskito_proxy__": True,
- "handler": "file",
- "version": 1,
- "identity": "id-2",
- "recipe": {
- "path": str(path),
- "mode": "r",
- "encoding": "utf-8",
- "position": 0,
- },
- }
- _, _, cleanup_list = reconstruct_proxies((marker,), {}, reg)
- assert len(cleanup_list) == 1
- handler, obj = cleanup_list[0]
- assert handler.name == "file"
- assert not obj.closed
- cleanup_proxies(cleanup_list)
- assert obj.closed
-
- def test_cleanup_runs_lifo(self, tmp_path: Any) -> None:
- """Cleanup runs in reverse reconstruction order."""
- p1 = tmp_path / "a.txt"
- p1.write_text("a")
- p2 = tmp_path / "b.txt"
- p2.write_text("b")
- reg = ProxyRegistry()
- reg.register(FileHandler())
-
- markers = [
- {
- "__taskito_proxy__": True,
- "handler": "file",
- "version": 1,
- "identity": f"id-{i}",
- "recipe": {
- "path": str(p),
- "mode": "r",
- "encoding": "utf-8",
- "position": 0,
- },
- }
- for i, p in enumerate([p1, p2])
- ]
- args, _, cleanup_list = reconstruct_proxies(tuple(markers), {}, reg)
- assert not args[0].closed
- assert not args[1].closed
- cleanup_proxies(cleanup_list)
- assert args[0].closed
- assert args[1].closed
-
- def test_cleanup_catches_errors(self) -> None:
- """Cleanup errors are logged, not raised."""
-
- class BadHandler:
- name = "bad"
- version = 1
- handled_types: tuple[type, ...] = ()
-
- def detect(self, obj: Any) -> bool:
- return False
-
- def deconstruct(self, obj: Any) -> dict[str, Any]:
- return {}
-
- def reconstruct(self, recipe: dict[str, Any], version: int) -> Any:
- return "obj"
-
- def cleanup(self, obj: Any) -> None:
- raise RuntimeError("cleanup boom")
-
- cleanup_list: list[tuple[Any, Any]] = [(BadHandler(), "obj")]
- cleanup_proxies(cleanup_list) # should not raise
-
- def test_missing_handler_raises(self) -> None:
- reg = ProxyRegistry()
- marker = {
- "__taskito_proxy__": True,
- "handler": "nonexistent",
- "version": 1,
- "recipe": {},
- }
- with pytest.raises(ProxyReconstructionError, match="No proxy handler"):
- reconstruct_proxies((marker,), {}, reg)
-
- def test_no_markers_passthrough(self) -> None:
- """Args without markers pass through unchanged."""
- reg = ProxyRegistry()
- register_builtin_handlers(reg)
- args, kwargs, cleanup = reconstruct_proxies((1, "hello", [3, 4]), {"key": "val"}, reg)
- assert args == (1, "hello", [3, 4])
- assert kwargs == {"key": "val"}
- assert cleanup == []
-
-
-# ---------------------------------------------------------------------------
-# Identity tracking
-# ---------------------------------------------------------------------------
-
-
-class TestIdentityTracking:
- def test_same_object_deduped(self, tmp_path: Any) -> None:
- """Same file handle passed twice produces one marker and one ref."""
- path = tmp_path / "data.txt"
- path.write_text("test")
-
- queue = Queue(db_path=str(tmp_path / "q.db"), interception="strict")
- f = open(path) # noqa: SIM115
- try:
- assert queue._interceptor is not None
- walker = queue._interceptor._walker
- args, _kw, _res = walker.walk((f, f), {})
-
- # First should be a full proxy marker
- assert args[0].get("__taskito_proxy__") is True
- # Second should be a reference to the first
- assert "__taskito_ref__" in args[1]
- assert args[1]["__taskito_ref__"] == args[0]["identity"]
- finally:
- f.close()
-
- def test_identity_reconstructed_once(self, tmp_path: Any) -> None:
- """Reconstruction creates one object; both positions share it."""
- path = tmp_path / "data.txt"
- path.write_text("shared")
- reg = ProxyRegistry()
- reg.register(FileHandler())
-
- identity = "shared-id"
- marker = {
- "__taskito_proxy__": True,
- "handler": "file",
- "version": 1,
- "identity": identity,
- "recipe": {
- "path": str(path),
- "mode": "r",
- "encoding": "utf-8",
- "position": 0,
- },
- }
- ref = {"__taskito_ref__": identity}
-
- args, _, cleanup = reconstruct_proxies((marker, ref), {}, reg)
- try:
- assert args[0] is args[1] # same object
- assert len(cleanup) == 1 # only one cleanup entry
- finally:
- cleanup_proxies(cleanup)
-
- def test_different_objects_separate(self, tmp_path: Any) -> None:
- """Two different file handles get separate recipes."""
- p1 = tmp_path / "a.txt"
- p1.write_text("a")
- p2 = tmp_path / "b.txt"
- p2.write_text("b")
-
- queue = Queue(db_path=str(tmp_path / "q.db"), interception="strict")
- f1 = open(p1) # noqa: SIM115
- f2 = open(p2) # noqa: SIM115
- try:
- assert queue._interceptor is not None
- walker = queue._interceptor._walker
- args, _, _ = walker.walk((f1, f2), {})
- assert args[0].get("__taskito_proxy__") is True
- assert args[1].get("__taskito_proxy__") is True
- assert args[0]["identity"] != args[1]["identity"]
- finally:
- f1.close()
- f2.close()
-
-
-# ---------------------------------------------------------------------------
-# Proxy in nested structures
-# ---------------------------------------------------------------------------
-
-
-def test_proxy_in_nested_dict(tmp_path: Any) -> None:
- """File inside a dict is proxied."""
- path = tmp_path / "nested.txt"
- path.write_text("nested content")
-
- queue = Queue(db_path=str(tmp_path / "q.db"), interception="strict")
- f = open(path) # noqa: SIM115
- try:
- assert queue._interceptor is not None
- walker = queue._interceptor._walker
- _, kwargs, _ = walker.walk((), {"config": {"file": f}})
- inner = kwargs["config"]["file"]
- assert inner.get("__taskito_proxy__") is True
- assert inner["handler"] == "file"
- finally:
- f.close()
-
-
-# ---------------------------------------------------------------------------
-# End-to-end with Queue test mode
-# ---------------------------------------------------------------------------
-
-
-def test_proxy_roundtrip_in_test_mode(queue: Queue) -> None:
- """In test mode, original objects pass through (no serialization)."""
- captured: list[Any] = []
-
- @queue.task()
- def use_logger(lgr: Any) -> None:
- captured.append(lgr)
-
- lgr = logging.getLogger("test.proxy.e2e")
- with queue.test_mode() as results:
- use_logger.delay(lgr)
-
- assert len(results) == 1
- assert results[0].succeeded
- assert captured[0] is lgr
-
-
-def test_logger_proxy_marker_production(tmp_path: Any) -> None:
- """Logger produces a proxy marker when interception is on."""
- queue = Queue(db_path=str(tmp_path / "q.db"), interception="strict")
- lgr = logging.getLogger("test.proxy.marker")
-
- assert queue._interceptor is not None
- walker = queue._interceptor._walker
- args, _, _ = walker.walk((lgr,), {})
- assert args[0].get("__taskito_proxy__") is True
- assert args[0]["handler"] == "logger"
- assert args[0]["recipe"]["name"] == "test.proxy.marker"
diff --git a/tests/test_rate_limit.py b/tests/test_rate_limit.py
deleted file mode 100644
index 6a03dd3..0000000
--- a/tests/test_rate_limit.py
+++ /dev/null
@@ -1,37 +0,0 @@
-"""Tests for rate limiting."""
-
-import threading
-import time
-
-from taskito import Queue
-
-
-def test_rate_limit_throttles(queue: Queue) -> None:
- """Rate-limited tasks should be throttled."""
- timestamps: list[float] = []
-
- @queue.task(rate_limit="2/s")
- def rate_limited_task(n: int) -> int:
- timestamps.append(time.monotonic())
- return n
-
- # Enqueue 4 tasks (at 2/s, should take ~2s)
- for i in range(4):
- rate_limited_task.delay(i)
-
- worker_thread = threading.Thread(
- target=queue.run_worker,
- daemon=True,
- )
- worker_thread.start()
-
- # Wait for all tasks
- time.sleep(5)
-
- # Should have all 4 results
- assert len(timestamps) == 4
-
- # The time span should be >= 1s (since 4 tasks at 2/s = 2s minimum)
- if len(timestamps) >= 2:
- span = timestamps[-1] - timestamps[0]
- assert span >= 0.5 # Allow some slack
diff --git a/tests/test_resource_system_full.py b/tests/test_resource_system_full.py
deleted file mode 100644
index 084339d..0000000
--- a/tests/test_resource_system_full.py
+++ /dev/null
@@ -1,675 +0,0 @@
-"""Tests for the full resource system — all phases A through M."""
-
-import collections
-import contextvars
-import tempfile
-import threading
-from typing import NamedTuple
-from unittest.mock import MagicMock
-
-import pytest
-
-from taskito import Inject, MockResource, Queue
-from taskito.exceptions import (
- ProxyReconstructionError,
- ResourceError,
- ResourceUnavailableError,
-)
-from taskito.inject import _InjectAlias
-from taskito.interception.built_in import build_default_registry
-from taskito.interception.converters import (
- convert_named_tuple,
- convert_ordered_dict,
- reconstruct_converted,
- reconstruct_named_tuple,
- reconstruct_ordered_dict,
-)
-from taskito.interception.metrics import InterceptionMetrics
-from taskito.interception.walker import ArgumentWalker
-from taskito.proxies.metrics import ProxyMetrics
-from taskito.proxies.no_proxy import NoProxy as NoProxyClass
-from taskito.proxies.schema import FieldSpec, validate_recipe
-from taskito.proxies.signing import sign_recipe, verify_recipe
-from taskito.resources.definition import ResourceDefinition, ResourceScope
-from taskito.resources.frozen import FrozenResource
-from taskito.resources.pool import PoolConfig, ResourcePool
-from taskito.resources.runtime import ResourceRuntime
-from taskito.resources.thread_local import ThreadLocalStore
-
-# ─── Phase A — NamedTuple / OrderedDict / lambda / tempfile ───
-
-
-class Point(NamedTuple):
- x: int
- y: int
-
-
-class TestNamedTupleConvert:
- def test_round_trip(self) -> None:
- p = Point(3, 7)
- data = convert_named_tuple(p)
- assert data["__taskito_convert__"] is True
- assert data["type_key"] == "named_tuple"
- result = reconstruct_named_tuple(data)
- assert result == p
- assert isinstance(result, Point)
-
- def test_via_reconstruct_dispatch(self) -> None:
- data = convert_named_tuple(Point(1, 2))
- result = reconstruct_converted(data)
- assert result == Point(1, 2)
-
-
-class TestOrderedDictConvert:
- def test_round_trip(self) -> None:
- od = collections.OrderedDict([("b", 2), ("a", 1)])
- data = convert_ordered_dict(od)
- assert data["__taskito_convert__"] is True
- assert data["type_key"] == "ordered_dict"
- result = reconstruct_ordered_dict(data)
- assert result == od
- assert isinstance(result, collections.OrderedDict)
- assert list(result.keys()) == ["b", "a"] # order preserved
-
-
-class TestWalkerPhaseA:
- def _make_walker(self) -> ArgumentWalker:
- reg = build_default_registry()
- return ArgumentWalker(reg, max_depth=10)
-
- def test_named_tuple_auto_detected(self) -> None:
- walker = self._make_walker()
- args = (Point(10, 20),)
- new_args, _, _result = walker.walk(args, {})
- assert new_args[0]["__taskito_convert__"] is True
- assert new_args[0]["type_key"] == "named_tuple"
-
- def test_lambda_rejected(self) -> None:
- walker = self._make_walker()
- fn = lambda x: x + 1 # noqa: E731
- args = (fn,)
- _, _, result = walker.walk(args, {})
- assert len(result.failures) == 1
- assert "lambda" in result.failures[0].type_name.lower()
-
- def test_tempfile_rejected(self) -> None:
- walker = self._make_walker()
- with tempfile.NamedTemporaryFile() as f:
- _, _, result = walker.walk((f,), {})
- assert len(result.failures) == 1
- reason = result.failures[0].reason.lower()
- assert "tempfile" in reason or "temporary" in reason
-
- def test_ordered_dict_converted(self) -> None:
- walker = self._make_walker()
- od = collections.OrderedDict([("x", 1)])
- new_args, _, _ = walker.walk((od,), {})
- assert new_args[0]["__taskito_convert__"] is True
- assert new_args[0]["type_key"] == "ordered_dict"
-
- def test_contextvars_context_rejected(self) -> None:
- reg = build_default_registry()
- walker = ArgumentWalker(reg, max_depth=10)
- ctx = contextvars.copy_context()
- _, _, result = walker.walk((ctx,), {})
- assert len(result.failures) == 1
- assert "context" in result.failures[0].reason.lower()
-
-
-class TestRegisterType:
- def test_register_redirect(self, tmp_path: object) -> None:
- q = Queue(
- db_path=":memory:",
- interception="strict",
- )
-
- class MyDB:
- pass
-
- q.register_type(MyDB, "redirect", resource="db")
- # Verify it's registered
- assert q._interceptor is not None
- entry = q._interceptor._registry.resolve(MyDB())
- assert entry is not None
-
- def test_requires_interception(self) -> None:
- q = Queue(db_path=":memory:", interception="off")
- with pytest.raises(RuntimeError, match="Interception is disabled"):
- q.register_type(int, "pass")
-
-
-# ─── Phase B — Proxy signing / schema / NoProxy ───
-
-
-class TestRecipeSigning:
- def test_valid_signature(self) -> None:
- recipe = {"path": "/tmp/f", "mode": "r"}
- sig = sign_recipe("file", 1, recipe, "secret")
- assert verify_recipe("file", 1, recipe, sig, "secret")
-
- def test_invalid_signature_raises(self) -> None:
- recipe = {"path": "/tmp/f", "mode": "r"}
- with pytest.raises(ProxyReconstructionError, match="checksum mismatch"):
- verify_recipe("file", 1, recipe, "bad-checksum", "secret")
-
- def test_deterministic(self) -> None:
- recipe = {"b": 2, "a": 1}
- s1 = sign_recipe("h", 1, recipe, "key")
- s2 = sign_recipe("h", 1, recipe, "key")
- assert s1 == s2
-
-
-class TestSchemaValidation:
- def test_valid_recipe(self) -> None:
- schema = {"path": FieldSpec(str), "mode": FieldSpec(str)}
- validate_recipe("test", {"path": "/f", "mode": "r"}, schema)
-
- def test_extra_key_raises(self) -> None:
- schema = {"path": FieldSpec(str)}
- with pytest.raises(ProxyReconstructionError, match="unexpected keys"):
- validate_recipe("test", {"path": "/f", "extra": 1}, schema)
-
- def test_missing_required_raises(self) -> None:
- schema = {"path": FieldSpec(str), "mode": FieldSpec(str)}
- with pytest.raises(ProxyReconstructionError, match="missing required"):
- validate_recipe("test", {"path": "/f"}, schema)
-
- def test_optional_field_ok(self) -> None:
- schema = {"path": FieldSpec(str), "enc": FieldSpec(str, required=False)}
- validate_recipe("test", {"path": "/f"}, schema)
-
- def test_wrong_type_raises(self) -> None:
- schema = {"count": FieldSpec(int)}
- with pytest.raises(ProxyReconstructionError, match="expected"):
- validate_recipe("test", {"count": "not-an-int"}, schema)
-
-
-class TestNoProxy:
- def test_unwrap_in_walker(self) -> None:
- reg = build_default_registry()
- walker = ArgumentWalker(reg, max_depth=10)
- sentinel = object()
- wrapped = NoProxyClass(sentinel)
- new_args, _, _ = walker.walk((wrapped,), {})
- assert new_args[0] is sentinel
-
-
-# ─── Phase D — Resource scopes ───
-
-
-class TestResourceScopes:
- def test_all_scopes_exist(self) -> None:
- assert ResourceScope.WORKER.value == "worker"
- assert ResourceScope.TASK.value == "task"
- assert ResourceScope.THREAD.value == "thread"
- assert ResourceScope.REQUEST.value == "request"
-
- def test_pool_config(self) -> None:
- cfg = PoolConfig(pool_size=5, pool_min=2)
- assert cfg.pool_size == 5
- assert cfg.pool_min == 2
-
- def test_resource_pool_acquire_release(self) -> None:
- created: list[dict[str, int]] = []
-
- def factory() -> dict[str, int]:
- d: dict[str, int] = {"id": len(created)}
- created.append(d)
- return d
-
- pool = ResourcePool(
- "test",
- factory,
- teardown=None,
- config=PoolConfig(pool_size=2, acquire_timeout=1.0),
- )
- inst = pool.acquire()
- assert inst["id"] == 0
- pool.release(inst)
- # Re-acquire should return same instance
- inst2 = pool.acquire()
- assert inst2 is inst
- pool.shutdown()
-
- def test_pool_exhaustion_raises(self) -> None:
- pool = ResourcePool(
- "test",
- lambda: {},
- teardown=None,
- config=PoolConfig(pool_size=1, acquire_timeout=0.1),
- )
- pool.acquire()
- with pytest.raises(ResourceUnavailableError, match="timed out"):
- pool.acquire()
- pool.shutdown()
-
- def test_pool_stats(self) -> None:
- pool = ResourcePool(
- "test",
- lambda: {},
- teardown=None,
- config=PoolConfig(pool_size=3),
- )
- inst = pool.acquire()
- s = pool.stats()
- assert s["active"] == 1
- assert s["size"] == 3
- pool.release(inst)
- s2 = pool.stats()
- assert s2["active"] == 0
- assert s2["idle"] == 1
- pool.shutdown()
-
- def test_pool_factory_failure_does_not_underflow_active_count(self) -> None:
- """Active count must remain at zero (not negative) when the factory raises.
-
- Regression: previously the increment ran before the factory call and
- was decremented in the except branch. Any future re-ordering or
- intervening release() risked underflowing the counter, surfacing as a
- negative `active` in `stats()`.
- """
-
- def boom() -> None:
- raise RuntimeError("factory blew up")
-
- pool = ResourcePool(
- "test",
- boom,
- teardown=None,
- config=PoolConfig(pool_size=2, acquire_timeout=0.1),
- )
-
- for _ in range(3):
- with pytest.raises(RuntimeError, match="factory blew up"):
- pool.acquire()
-
- s = pool.stats()
- assert s["active"] == 0
- assert s["total_acquisitions"] == 0 # Failed attempts don't count
- # Pool capacity must not leak: a fresh acquire after failures still works.
- # (Use a successful factory now to confirm the semaphore wasn't consumed.)
- ok_pool = ResourcePool(
- "test2",
- lambda: {"ok": True},
- teardown=None,
- config=PoolConfig(pool_size=1, acquire_timeout=0.1),
- )
- for _ in range(3):
- inst = ok_pool.acquire()
- ok_pool.release(inst)
- assert ok_pool.stats()["active"] == 0
- ok_pool.shutdown()
- pool.shutdown()
-
- def test_thread_local_store(self) -> None:
- counter = {"n": 0}
-
- def factory() -> int:
- counter["n"] += 1
- return counter["n"]
-
- store = ThreadLocalStore("test", factory, teardown=None)
- v1 = store.get_or_create()
- v2 = store.get_or_create()
- assert v1 == v2 # Same thread → same instance
-
- results: list[int] = []
-
- def worker() -> None:
- results.append(store.get_or_create())
-
- t = threading.Thread(target=worker)
- t.start()
- t.join()
- assert results[0] != v1 # Different thread → different instance
- store.teardown_all()
-
-
-class TestFrozenResource:
- def test_read_allowed(self) -> None:
- class Config:
- db_url = "postgres://localhost"
-
- frozen = FrozenResource(Config(), "config")
- assert frozen.db_url == "postgres://localhost"
-
- def test_write_raises(self) -> None:
- class Config:
- db_url = "x"
-
- frozen = FrozenResource(Config(), "config")
- with pytest.raises(ResourceError, match="read-only"):
- frozen.db_url = "y"
-
- def test_delete_raises(self) -> None:
- class Config:
- db_url = "x"
-
- frozen = FrozenResource(Config(), "config")
- with pytest.raises(ResourceError, match="read-only"):
- del frozen.db_url
-
-
-class TestRuntimeScopeAware:
- def test_worker_scope_resolve(self) -> None:
- defn = ResourceDefinition(
- name="config",
- factory=lambda: {"key": "value"},
- scope=ResourceScope.WORKER,
- )
- rt = ResourceRuntime({"config": defn})
- rt.initialize()
- result = rt.resolve("config")
- assert result == {"key": "value"}
- rt.teardown()
-
- def test_acquire_for_task_worker(self) -> None:
- defn = ResourceDefinition(
- name="config",
- factory=lambda: 42,
- scope=ResourceScope.WORKER,
- )
- rt = ResourceRuntime({"config": defn})
- rt.initialize()
- inst, release = rt.acquire_for_task("config")
- assert inst == 42
- assert release is None
- rt.teardown()
-
- def test_acquire_for_task_request_scope(self) -> None:
- counter = {"n": 0}
-
- def factory() -> int:
- counter["n"] += 1
- return counter["n"]
-
- defn = ResourceDefinition(
- name="req",
- factory=factory,
- scope=ResourceScope.REQUEST,
- )
- rt = ResourceRuntime({"req": defn})
- rt.initialize()
- inst1, release1 = rt.acquire_for_task("req")
- inst2, release2 = rt.acquire_for_task("req")
- assert inst1 != inst2 # Fresh each time
- assert release1 is not None
- assert release2 is not None
- release1()
- release2()
- rt.teardown()
-
- def test_frozen_resource_in_runtime(self) -> None:
- defn = ResourceDefinition(
- name="cfg",
- factory=lambda: MagicMock(value=10),
- scope=ResourceScope.WORKER,
- frozen=True,
- )
- rt = ResourceRuntime({"cfg": defn})
- rt.initialize()
- inst = rt.resolve("cfg")
- assert inst.value == 10
- with pytest.raises(ResourceError, match="read-only"):
- inst.new_attr = "x"
- rt.teardown()
-
- def test_reload_reloadable(self) -> None:
- counter = {"n": 0}
-
- def factory() -> int:
- counter["n"] += 1
- return counter["n"]
-
- defn = ResourceDefinition(
- name="svc",
- factory=factory,
- scope=ResourceScope.WORKER,
- reloadable=True,
- )
- rt = ResourceRuntime({"svc": defn})
- rt.initialize()
- assert rt.resolve("svc") == 1
- results = rt.reload()
- assert results["svc"] is True
- assert rt.resolve("svc") == 2
- rt.teardown()
-
- def test_reload_skips_non_reloadable(self) -> None:
- defn = ResourceDefinition(
- name="fixed",
- factory=lambda: 1,
- scope=ResourceScope.WORKER,
- reloadable=False,
- )
- rt = ResourceRuntime({"fixed": defn})
- rt.initialize()
- results = rt.reload()
- assert "fixed" not in results
- rt.teardown()
-
- def test_status_includes_pool(self) -> None:
- defn = ResourceDefinition(
- name="db",
- factory=lambda: {},
- scope=ResourceScope.TASK,
- pool_size=5,
- )
- rt = ResourceRuntime({"db": defn})
- rt.initialize()
- status = rt.status()
- assert len(status) == 1
- assert status[0]["name"] == "db"
- assert "pool" in status[0]
- assert status[0]["pool"]["size"] == 5
- rt.teardown()
-
-
-# ─── Phase E — Inject annotation ───
-
-
-class TestInjectAnnotation:
- def test_inject_alias_created(self) -> None:
- alias = Inject["db"] # type: ignore[type-arg,name-defined]
- assert isinstance(alias, _InjectAlias)
- assert alias.resource_name == "db"
-
- def test_inject_alias_equality(self) -> None:
- assert Inject["db"] == Inject["db"]
- assert Inject["db"] != Inject["redis"]
-
- def test_inject_annotation_detected_in_task(self) -> None:
- q = Queue(db_path=":memory:")
-
- @q.task()
- def my_task(x: int, db: Inject["db"]) -> None: # type: ignore[type-arg,name-defined] # noqa: F821
- pass
-
- assert "db" in q._task_inject_map.get(my_task.name, [])
-
- def test_inject_annotation_merged_with_explicit(self) -> None:
- q = Queue(db_path=":memory:")
-
- @q.task(inject=["redis"])
- def my_task(x: int, db: Inject["db"]) -> None: # type: ignore[type-arg,name-defined] # noqa: F821
- pass
-
- injects = q._task_inject_map.get(my_task.name, [])
- assert "redis" in injects
- assert "db" in injects
-
-
-# ─── Phase F — TOML config ───
-
-
-class TestTomlConfig:
- def test_load_resources(self, tmp_path: object) -> None:
- import pathlib
-
- path = pathlib.Path(str(tmp_path)) / "resources.toml"
- path.write_text('[resources.config]\nfactory = "builtins:dict"\nscope = "worker"\n')
- from taskito.resources.toml_config import load_resources_from_toml
-
- defs = load_resources_from_toml(str(path))
- assert len(defs) == 1
- assert defs[0].name == "config"
- assert defs[0].scope == ResourceScope.WORKER
-
- def test_missing_factory_raises(self, tmp_path: object) -> None:
- import pathlib
-
- path = pathlib.Path(str(tmp_path)) / "bad.toml"
- path.write_text('[resources.db]\nscope = "worker"\n')
- from taskito.resources.toml_config import load_resources_from_toml
-
- with pytest.raises(ValueError, match="missing required 'factory'"):
- load_resources_from_toml(str(path))
-
- def test_queue_load_resources(self, tmp_path: object) -> None:
- import pathlib
-
- path = pathlib.Path(str(tmp_path)) / "res.toml"
- path.write_text('[resources.cfg]\nfactory = "builtins:dict"\n')
- q = Queue(db_path=":memory:")
- q.load_resources(str(path))
- assert "cfg" in q._resource_definitions
-
-
-# ─── Phase H — Proxy metrics ───
-
-
-class TestProxyMetrics:
- def test_record_and_retrieve(self) -> None:
- m = ProxyMetrics()
- m.record_reconstruction("file", 15.5)
- m.record_reconstruction("file", 20.0)
- m.record_error("file")
- result = m.to_list()
- assert len(result) == 1
- assert result[0]["handler"] == "file"
- assert result[0]["total_reconstructions"] == 2
- assert result[0]["total_errors"] == 1
- assert result[0]["avg_duration_ms"] == 17.75
-
- def test_queue_proxy_stats(self) -> None:
- q = Queue(db_path=":memory:")
- stats = q.proxy_stats()
- assert isinstance(stats, list)
-
-
-# ─── Phase I — Interception metrics ───
-
-
-class TestInterceptionMetrics:
- def test_record_and_retrieve(self) -> None:
- m = InterceptionMetrics()
- m.record(5.0, {"pass": 3, "convert": 1}, max_depth=2)
- d = m.to_dict()
- assert d["total_intercepts"] == 1
- assert d["strategy_counts"]["pass"] == 3
- assert d["max_depth_reached"] == 2
-
- def test_queue_interception_stats(self) -> None:
- q = Queue(db_path=":memory:", interception="strict")
- stats = q.interception_stats()
- assert "total_intercepts" in stats
-
- def test_walker_tracks_strategy_counts(self) -> None:
- reg = build_default_registry()
- walker = ArgumentWalker(reg, max_depth=10)
- _, _, result = walker.walk((42, "hello"), {})
- assert result.strategy_counts.get("pass", 0) >= 2
-
-
-# ─── Phase L — MockResource ───
-
-
-class TestMockResource:
- def test_return_value(self) -> None:
- mock = MockResource("db", return_value="fake-db")
- assert mock.get() == "fake-db"
-
- def test_wraps(self) -> None:
- real = {"conn": True}
- mock = MockResource("db", wraps=real)
- assert mock.get() is real
-
- def test_track_calls(self) -> None:
- mock = MockResource("db", return_value="x", track_calls=True)
- mock.get()
- mock.get()
- assert mock.call_count == 2
-
- def test_mock_in_test_mode(self) -> None:
- q = Queue(db_path=":memory:")
- mock_db = MockResource("db", return_value="mock-db", track_calls=True)
-
- @q.task(inject=["db"])
- def use_db(db: object = None) -> object:
- return db
-
- with q.test_mode(resources={"db": mock_db}) as results:
- use_db.delay()
-
- assert len(results) == 1
- assert results[0].return_value == "mock-db"
- assert mock_db.call_count == 1 # .get() called once during setup
-
-
-# ─── Phase M — Test-mode proxy passthrough ───
-
-
-class TestTestModePassthrough:
- def test_test_mode_sets_flag(self) -> None:
- q = Queue(db_path=":memory:")
- assert q._test_mode_active is False
- with q.test_mode():
- assert q._test_mode_active is True
- assert q._test_mode_active is False
-
- def test_interception_skipped_in_test_mode(self) -> None:
- q = Queue(db_path=":memory:", interception="strict")
-
- @q.task()
- def identity(x: object) -> object:
- return x
-
- sentinel = object()
- with q.test_mode() as results:
- identity.delay(sentinel)
- # In test mode, the sentinel passes through without interception
- assert len(results) == 1
-
-
-# ─── Integration: resource injection end-to-end ───
-
-
-class TestResourceInjectionE2E:
- def test_inject_in_test_mode(self) -> None:
- q = Queue(db_path=":memory:")
-
- @q.worker_resource("db")
- def create_db() -> str:
- return "real-db-connection"
-
- @q.task(inject=["db"])
- def process(order_id: int, db: object = None) -> str:
- return f"processed {order_id} with {db}"
-
- with q.test_mode(resources={"db": "mock-db"}) as results:
- process.delay(42)
-
- assert len(results) == 1
- assert results[0].return_value == "processed 42 with mock-db"
-
- def test_inject_annotation_in_test_mode(self) -> None:
- q = Queue(db_path=":memory:")
-
- @q.task()
- def process(order_id: int, db: Inject["db"] = None) -> str: # type: ignore[type-arg,name-defined,assignment] # noqa: F821
- return f"{order_id}:{db}"
-
- with q.test_mode(resources={"db": "injected"}) as results:
- process.delay(1)
-
- assert results[0].return_value == "1:injected"
diff --git a/tests/test_resources.py b/tests/test_resources.py
deleted file mode 100644
index 9d95aa1..0000000
--- a/tests/test_resources.py
+++ /dev/null
@@ -1,367 +0,0 @@
-"""Tests for the worker resource runtime (dependency injection)."""
-
-from __future__ import annotations
-
-from typing import Any
-
-import pytest
-
-from taskito import (
- CircularDependencyError,
- Queue,
- ResourceInitError,
- ResourceNotFoundError,
-)
-from taskito.resources import (
- ResourceDefinition,
- ResourceRuntime,
- detect_cycle,
- topological_sort,
-)
-
-# ---------------------------------------------------------------------------
-# Registration
-# ---------------------------------------------------------------------------
-
-
-def test_worker_resource_decorator_registers(queue: Queue) -> None:
- """@queue.worker_resource stores a ResourceDefinition."""
-
- @queue.worker_resource("cache")
- def create_cache() -> dict[str, int]:
- return {"hits": 0}
-
- assert "cache" in queue._resource_definitions
- defn = queue._resource_definitions["cache"]
- assert defn.name == "cache"
- assert defn.factory is create_cache
-
-
-def test_register_resource_programmatic(queue: Queue) -> None:
- """register_resource() stores a definition without the decorator."""
-
- def factory() -> str:
- return "hello"
-
- queue.register_resource(ResourceDefinition(name="greeter", factory=factory))
- assert "greeter" in queue._resource_definitions
- assert queue._resource_definitions["greeter"].factory is factory
-
-
-# ---------------------------------------------------------------------------
-# Graph
-# ---------------------------------------------------------------------------
-
-
-def test_circular_dependency_detected(queue: Queue) -> None:
- """Circular deps raise CircularDependencyError at registration time."""
-
- @queue.worker_resource("a", depends_on=["b"])
- def make_a(b: Any) -> str:
- return "a"
-
- with pytest.raises(CircularDependencyError):
-
- @queue.worker_resource("b", depends_on=["a"])
- def make_b(a: Any) -> str:
- return "b"
-
-
-def test_topological_sort_order() -> None:
- """Dependencies are initialized before dependents."""
- defs = {
- "config": ResourceDefinition(name="config", factory=lambda: {}),
- "db": ResourceDefinition(name="db", factory=lambda config: {}, depends_on=["config"]),
- "cache": ResourceDefinition(
- name="cache", factory=lambda config: {}, depends_on=["config"]
- ),
- }
- order = topological_sort(defs)
- assert order.index("config") < order.index("db")
- assert order.index("config") < order.index("cache")
-
-
-def test_detect_cycle_returns_none_when_no_cycle() -> None:
- defs = {
- "a": ResourceDefinition(name="a", factory=lambda: 1),
- "b": ResourceDefinition(name="b", factory=lambda a: 2, depends_on=["a"]),
- }
- assert detect_cycle(defs) is None
-
-
-def test_detect_cycle_returns_path() -> None:
- defs = {
- "a": ResourceDefinition(name="a", factory=lambda b: 1, depends_on=["b"]),
- "b": ResourceDefinition(name="b", factory=lambda a: 2, depends_on=["a"]),
- }
- cycle = detect_cycle(defs)
- assert cycle is not None
- assert len(cycle) >= 2
-
-
-# ---------------------------------------------------------------------------
-# Runtime
-# ---------------------------------------------------------------------------
-
-
-def test_dependency_injection_into_factory() -> None:
- """Factory receives its depends_on resources as kwargs."""
- defs = {
- "config": ResourceDefinition(name="config", factory=lambda: {"url": "sqlite://"}),
- "db": ResourceDefinition(
- name="db", factory=lambda config: f"connected:{config['url']}", depends_on=["config"]
- ),
- }
- rt = ResourceRuntime(defs)
- rt.initialize()
- assert rt.resolve("db") == "connected:sqlite://"
- rt.teardown()
-
-
-def test_missing_resource_raises() -> None:
- """Resolving an unregistered name raises ResourceNotFoundError."""
- rt = ResourceRuntime({})
- with pytest.raises(ResourceNotFoundError):
- rt.resolve("nope")
-
-
-def test_teardown_reverse_order() -> None:
- """Resources are torn down in reverse initialization order."""
- teardown_log: list[str] = []
-
- def td_config(inst: Any) -> None:
- teardown_log.append("config")
-
- def td_db(inst: Any) -> None:
- teardown_log.append("db")
-
- defs = {
- "config": ResourceDefinition(name="config", factory=lambda: "cfg", teardown=td_config),
- "db": ResourceDefinition(
- name="db",
- factory=lambda config: "db_conn",
- depends_on=["config"],
- teardown=td_db,
- ),
- }
- rt = ResourceRuntime(defs)
- rt.initialize()
- rt.teardown()
- assert teardown_log == ["db", "config"]
-
-
-def test_init_failure_raises_resource_init_error() -> None:
- """A factory that raises is wrapped in ResourceInitError."""
- defs = {
- "broken": ResourceDefinition(
- name="broken",
- factory=lambda: (_ for _ in ()).throw(RuntimeError("boom")),
- ),
- }
- rt = ResourceRuntime(defs)
- with pytest.raises(ResourceInitError, match="boom"):
- rt.initialize()
-
-
-def test_from_test_overrides() -> None:
- """from_test_overrides pre-populates instances without factories."""
- rt = ResourceRuntime.from_test_overrides({"db": "mock_db", "cache": "mock_cache"})
- assert rt.resolve("db") == "mock_db"
- assert rt.resolve("cache") == "mock_cache"
-
-
-# ---------------------------------------------------------------------------
-# Async factory
-# ---------------------------------------------------------------------------
-
-
-def test_async_factory() -> None:
- """Async factories are awaited during initialize."""
-
- async def make_client() -> str:
- return "async_client"
-
- defs = {
- "client": ResourceDefinition(name="client", factory=make_client),
- }
- rt = ResourceRuntime(defs)
- rt.initialize()
- assert rt.resolve("client") == "async_client"
- rt.teardown()
-
-
-# ---------------------------------------------------------------------------
-# Injection into tasks
-# ---------------------------------------------------------------------------
-
-
-def test_resource_injected_into_task(queue: Queue) -> None:
- """Task with inject=["db"] receives the resource as a kwarg."""
-
- @queue.worker_resource("db")
- def create_db() -> str:
- return "live_db"
-
- results_holder: list[Any] = []
-
- @queue.task(inject=["db"])
- def my_task(x: int, db: Any = None) -> None:
- results_holder.append((x, db))
-
- with queue.test_mode(resources={"db": "test_db"}) as results:
- my_task.delay(42)
-
- assert len(results) == 1
- assert results[0].succeeded
- assert results_holder == [(42, "test_db")]
-
-
-def test_explicit_kwarg_wins_over_inject(queue: Queue) -> None:
- """Caller-provided kwargs are not overridden by injection."""
-
- @queue.worker_resource("db")
- def create_db() -> str:
- return "injected_db"
-
- results_holder: list[Any] = []
-
- @queue.task(inject=["db"])
- def my_task(db: Any = None) -> None:
- results_holder.append(db)
-
- with queue.test_mode(resources={"db": "injected_db"}):
- my_task.delay(db="explicit_db")
-
- assert results_holder == ["explicit_db"]
-
-
-def test_test_mode_with_resources(queue: Queue) -> None:
- """test_mode(resources=...) injects mock resources."""
-
- @queue.worker_resource("cache")
- def create_cache() -> dict[str, str]:
- return {}
-
- captured: list[Any] = []
-
- @queue.task(inject=["cache"])
- def use_cache(cache: Any = None) -> None:
- captured.append(cache)
-
- mock_cache = {"key": "value"}
- with queue.test_mode(resources={"cache": mock_cache}) as results:
- use_cache.delay()
-
- assert captured == [mock_cache]
- assert results[0].succeeded
- # Runtime is cleaned up after exiting test mode
- assert queue._resource_runtime is None
-
-
-def test_test_mode_restores_previous_runtime(queue: Queue) -> None:
- """Exiting test_mode restores whatever runtime was set before."""
- assert queue._resource_runtime is None
-
- with queue.test_mode(resources={"x": 1}):
- assert queue._resource_runtime is not None
-
- assert queue._resource_runtime is None
-
-
-# ---------------------------------------------------------------------------
-# Health checking
-# ---------------------------------------------------------------------------
-
-
-def test_health_check_recreation(poll_until: Any) -> None:
- """Unhealthy resource is recreated; permanent failure marks it unavailable."""
- from taskito.exceptions import ResourceUnavailableError
- from taskito.resources.health import HealthChecker
-
- call_count = 0
-
- def make_svc() -> str:
- nonlocal call_count
- call_count += 1
- if call_count > 1:
- raise RuntimeError("factory broken")
- return f"svc_v{call_count}"
-
- def check_health(inst: Any) -> bool:
- # Always fail after initial creation
- return False
-
- defs = {
- "svc": ResourceDefinition(
- name="svc",
- factory=make_svc,
- health_check=check_health,
- health_check_interval=0.2,
- max_recreation_attempts=2,
- ),
- }
- rt = ResourceRuntime(defs)
- rt.initialize()
- assert rt.resolve("svc") == "svc_v1"
-
- checker = HealthChecker(rt)
- checker.start()
- poll_until(
- lambda: "svc" in rt._unhealthy,
- timeout=5,
- message="resource never marked unhealthy after exhausting attempts",
- )
- checker.stop()
-
- assert "svc" in rt._unhealthy
- with pytest.raises(ResourceUnavailableError):
- rt.resolve("svc")
-
-
-# ---------------------------------------------------------------------------
-# Banner
-# ---------------------------------------------------------------------------
-
-
-def test_banner_shows_resources(queue: Queue, capsys: pytest.CaptureFixture[str]) -> None:
- """Resources section appears in the startup banner."""
-
- @queue.worker_resource("db", depends_on=["config"])
- def create_db(config: Any) -> str:
- return "db"
-
- @queue.worker_resource("config")
- def create_config() -> dict[str, str]:
- return {}
-
- queue._print_banner(["default"])
- captured = capsys.readouterr().out
- assert "[resources]" in captured
- assert "config" in captured
- assert "db" in captured
- assert "depends: config" in captured
-
-
-# ---------------------------------------------------------------------------
-# TaskWrapper.inject property
-# ---------------------------------------------------------------------------
-
-
-def test_task_wrapper_inject_property(queue: Queue) -> None:
- """TaskWrapper exposes the inject list."""
-
- @queue.task(inject=["db", "cache"])
- def my_task(db: Any = None, cache: Any = None) -> None:
- pass
-
- assert my_task.inject == ["db", "cache"]
-
-
-def test_task_wrapper_inject_default(queue: Queue) -> None:
- """TaskWrapper.inject defaults to empty list."""
-
- @queue.task()
- def my_task() -> None:
- pass
-
- assert my_task.inject == []
diff --git a/tests/test_result_race.py b/tests/test_result_race.py
deleted file mode 100644
index 64e117d..0000000
--- a/tests/test_result_race.py
+++ /dev/null
@@ -1,62 +0,0 @@
-"""Regression: result()/aresult() must surface terminal-failure exceptions
-even when the deadline is reached on the same iteration the failure lands.
-
-Race scenario:
-1. `_poll_once()` returns ("running", None) — snapshot from before the failure
-2. The job transitions to `failed`/`dead`/`cancelled` in storage
-3. `time.monotonic() >= deadline` — about to give up
-4. Fix: re-poll once before raising `TimeoutError` so the caller sees the
- real exception (TaskFailedError / MaxRetriesExceededError / ...).
-"""
-
-from typing import Any
-from unittest.mock import patch
-
-import pytest
-
-from taskito import Queue
-from taskito.exceptions import TaskFailedError
-
-
-def _make_failing_poll(job_id: str) -> Any:
- """Returns a `_poll_once` stub: 1st call → running, 2nd call → TaskFailedError."""
- call_count = [0]
-
- def stub() -> tuple[str, Any]:
- call_count[0] += 1
- if call_count[0] == 1:
- return ("running", None)
- raise TaskFailedError(f"Job {job_id} failed: simulated late failure")
-
- stub.calls = call_count # type: ignore[attr-defined]
- return stub
-
-
-def test_result_surfaces_terminal_failure_at_deadline(tmp_path: Any) -> None:
- queue = Queue(db_path=str(tmp_path / "q.db"))
-
- @queue.task()
- def will_fail() -> None: ...
-
- job = will_fail.delay()
- poll = _make_failing_poll(job.id)
-
- with patch.object(job, "_poll_once", side_effect=poll), pytest.raises(TaskFailedError):
- job.result(timeout=0.01)
-
- assert poll.calls[0] == 2, "second defensive poll must run before raising TimeoutError"
-
-
-async def test_aresult_surfaces_terminal_failure_at_deadline(tmp_path: Any) -> None:
- queue = Queue(db_path=str(tmp_path / "q.db"))
-
- @queue.task()
- async def will_fail() -> None: ...
-
- job = will_fail.delay()
- poll = _make_failing_poll(job.id)
-
- with patch.object(job, "_poll_once", side_effect=poll), pytest.raises(TaskFailedError):
- await job.aresult(timeout=0.01)
-
- assert poll.calls[0] == 2, "second defensive poll must run before raising TimeoutError"
diff --git a/tests/test_retry.py b/tests/test_retry.py
deleted file mode 100644
index 00af8c5..0000000
--- a/tests/test_retry.py
+++ /dev/null
@@ -1,87 +0,0 @@
-"""Tests for retry logic and dead letter queue."""
-
-import threading
-from typing import Any
-
-from taskito import Queue
-
-PollUntil = Any # the conftest fixture's runtime type
-
-
-def test_failing_task_retries(queue: Queue) -> None:
- """A failing task should be retried up to max_retries times."""
- call_count = 0
-
- @queue.task(max_retries=3, retry_backoff=0.1)
- def flaky_task() -> str:
- nonlocal call_count
- call_count += 1
- if call_count < 3:
- raise ValueError(f"attempt {call_count}")
- return "success"
-
- job = flaky_task.delay()
-
- worker_thread = threading.Thread(
- target=queue.run_worker,
- daemon=True,
- )
- worker_thread.start()
-
- result = job.result(timeout=30)
- assert result == "success"
- assert call_count == 3
-
-
-def test_exhausted_retries_goes_to_dlq(queue: Queue, poll_until: PollUntil) -> None:
- """A task that always fails should end up in the dead letter queue."""
-
- @queue.task(max_retries=2, retry_backoff=0.1)
- def always_fails() -> None:
- raise RuntimeError("permanent failure")
-
- always_fails.delay()
-
- worker_thread = threading.Thread(
- target=queue.run_worker,
- daemon=True,
- )
- worker_thread.start()
-
- poll_until(
- lambda: len(queue.dead_letters()) >= 1,
- timeout=15,
- message="job did not reach DLQ after exhausting retries",
- )
-
- dead = queue.dead_letters()
- assert len(dead) >= 1
- assert dead[0]["task_name"].endswith("always_fails")
-
-
-def test_retry_dead_letter(queue: Queue, poll_until: PollUntil) -> None:
- """A dead letter job can be re-enqueued."""
-
- @queue.task(max_retries=1, retry_backoff=0.1)
- def fail_once() -> None:
- raise RuntimeError("fail")
-
- fail_once.delay()
-
- worker_thread = threading.Thread(
- target=queue.run_worker,
- daemon=True,
- )
- worker_thread.start()
-
- poll_until(
- lambda: len(queue.dead_letters()) >= 1,
- timeout=10,
- message="job did not reach DLQ",
- )
-
- dead = queue.dead_letters()
- if dead:
- new_id = queue.retry_dead(dead[0]["id"])
- assert new_id is not None
- assert len(new_id) > 0
diff --git a/tests/test_retry_history.py b/tests/test_retry_history.py
deleted file mode 100644
index 5ab7fed..0000000
--- a/tests/test_retry_history.py
+++ /dev/null
@@ -1,58 +0,0 @@
-"""Tests for retry history (job_errors tracking)."""
-
-import threading
-from pathlib import Path
-
-import pytest
-
-from taskito import Queue
-
-
-@pytest.fixture
-def queue(tmp_path: Path) -> Queue:
- db_path = str(tmp_path / "test_retry_history.db")
- return Queue(db_path=db_path, workers=1)
-
-
-def test_retry_errors_recorded(queue: Queue) -> None:
- """Failed attempts are recorded in job.errors."""
- call_count = {"n": 0}
-
- @queue.task(max_retries=3, retry_backoff=0.01)
- def flaky() -> str:
- call_count["n"] += 1
- if call_count["n"] <= 3:
- raise ValueError(f"attempt {call_count['n']}")
- return "ok"
-
- job = flaky.delay()
-
- worker_thread = threading.Thread(target=queue.run_worker, daemon=True)
- worker_thread.start()
-
- result = job.result(timeout=15)
- assert result == "ok"
-
- errors = job.errors
- assert len(errors) == 3
- assert errors[0]["attempt"] == 0
- assert "attempt 1" in errors[0]["error"]
- assert errors[1]["attempt"] == 1
- assert errors[2]["attempt"] == 2
-
-
-def test_errors_empty_on_success(queue: Queue) -> None:
- """Successful jobs have an empty errors list."""
-
- @queue.task()
- def ok_task() -> int:
- return 42
-
- job = ok_task.delay()
-
- worker_thread = threading.Thread(target=queue.run_worker, daemon=True)
- worker_thread.start()
-
- result = job.result(timeout=10)
- assert result == 42
- assert job.errors == []
diff --git a/tests/test_run_maybe_async.py b/tests/test_run_maybe_async.py
deleted file mode 100644
index 89f99bd..0000000
--- a/tests/test_run_maybe_async.py
+++ /dev/null
@@ -1,59 +0,0 @@
-"""Tests for `run_maybe_async` — detection of running event loop."""
-
-from __future__ import annotations
-
-import asyncio
-from typing import Any
-
-import pytest
-
-from taskito.async_support.helpers import run_maybe_async
-
-
-def test_run_maybe_async_passes_through_non_coroutine() -> None:
- assert run_maybe_async(42) == 42
- assert run_maybe_async("hello") == "hello"
- assert run_maybe_async(None) is None
-
-
-def test_run_maybe_async_runs_coroutine_in_sync_context() -> None:
- async def make_value() -> int:
- return 7
-
- assert run_maybe_async(make_value()) == 7
-
-
-async def test_run_maybe_async_raises_clear_error_under_running_loop() -> None:
- """Pytest-asyncio puts us in a running loop — must surface the taskito error."""
-
- async def make_value() -> int:
- return 1
-
- coro: Any = make_value()
- with pytest.raises(RuntimeError, match="async API"):
- run_maybe_async(coro)
- # Drain to silence the "coroutine was never awaited" warning.
- coro.close()
-
-
-async def test_run_maybe_async_async_message_mentions_a_methods() -> None:
- async def make_value() -> int:
- return 1
-
- coro: Any = make_value()
- try:
- run_maybe_async(coro)
- except RuntimeError as exc:
- assert "aresult" in str(exc) or "aenqueue" in str(exc) or "await" in str(exc)
- finally:
- coro.close()
-
-
-def test_run_maybe_async_no_loop_uses_asyncio_run() -> None:
- """Sanity: a fresh thread with no loop should run a coroutine to completion."""
-
- async def slow() -> str:
- await asyncio.sleep(0.001)
- return "ok"
-
- assert run_maybe_async(slow()) == "ok"
diff --git a/tests/test_serializers.py b/tests/test_serializers.py
deleted file mode 100644
index 977a718..0000000
--- a/tests/test_serializers.py
+++ /dev/null
@@ -1,129 +0,0 @@
-"""Tests for pluggable serializers."""
-
-import json
-import pickle
-
-import pytest
-
-from taskito.serializers import CloudpickleSerializer, JsonSerializer, Serializer
-
-
-class TestJsonSerializer:
- def test_roundtrip_dict(self) -> None:
- s = JsonSerializer()
- data = {"key": "value", "num": 42, "nested": [1, 2, 3]}
- assert s.loads(s.dumps(data)) == data
-
- def test_roundtrip_list(self) -> None:
- s = JsonSerializer()
- data = [1, "two", None, True]
- assert s.loads(s.dumps(data)) == data
-
- def test_roundtrip_primitives(self) -> None:
- s = JsonSerializer()
- for val in [42, 3.14, "hello", True, None]:
- assert s.loads(s.dumps(val)) == val
-
- def test_dumps_returns_bytes(self) -> None:
- s = JsonSerializer()
- result = s.dumps({"a": 1})
- assert isinstance(result, bytes)
-
- def test_non_serializable_raises(self) -> None:
- s = JsonSerializer()
- with pytest.raises(TypeError):
- s.dumps(object())
-
- def test_invalid_bytes_raises(self) -> None:
- s = JsonSerializer()
- with pytest.raises((json.JSONDecodeError, UnicodeDecodeError, ValueError)):
- s.loads(b"\xff\xfe")
-
-
-class TestCloudpickleSerializer:
- def test_roundtrip_dict(self) -> None:
- s = CloudpickleSerializer()
- data = {"key": "value", "num": 42}
- assert s.loads(s.dumps(data)) == data
-
- def test_roundtrip_lambda(self) -> None:
- s = CloudpickleSerializer()
- fn = lambda x: x * 2 # noqa: E731
- restored = s.loads(s.dumps(fn))
- assert restored(5) == 10
-
- def test_dumps_returns_bytes(self) -> None:
- s = CloudpickleSerializer()
- assert isinstance(s.dumps(42), bytes)
-
- def test_invalid_bytes_raises(self) -> None:
- s = CloudpickleSerializer()
- with pytest.raises((pickle.UnpicklingError, EOFError)):
- s.loads(b"not-valid-pickle")
-
-
-class TestSerializerProtocol:
- def test_json_is_serializer(self) -> None:
- assert isinstance(JsonSerializer(), Serializer)
-
- def test_cloudpickle_is_serializer(self) -> None:
- assert isinstance(CloudpickleSerializer(), Serializer)
-
-
-class TestMsgPackSerializer:
- def test_roundtrip(self) -> None:
- pytest.importorskip("msgpack")
- from taskito.serializers import MsgPackSerializer
-
- s = MsgPackSerializer()
- data = {"key": "value", "num": 42}
- assert s.loads(s.dumps(data)) == data
-
- def test_dumps_returns_bytes(self) -> None:
- pytest.importorskip("msgpack")
- from taskito.serializers import MsgPackSerializer
-
- s = MsgPackSerializer()
- assert isinstance(s.dumps([1, 2, 3]), bytes)
-
-
-class TestEncryptedSerializer:
- def test_roundtrip(self) -> None:
- pytest.importorskip("cryptography")
- import os
-
- from taskito.serializers import EncryptedSerializer
-
- key = os.urandom(32)
- s = EncryptedSerializer(JsonSerializer(), key)
- data = {"secret": "payload"}
- assert s.loads(s.dumps(data)) == data
-
- def test_wrong_key_fails(self) -> None:
- pytest.importorskip("cryptography")
- import os
-
- from taskito.serializers import EncryptedSerializer
-
- s1 = EncryptedSerializer(JsonSerializer(), os.urandom(32))
- s2 = EncryptedSerializer(JsonSerializer(), os.urandom(32))
- from cryptography.exceptions import InvalidTag
-
- encrypted = s1.dumps({"data": 1})
- with pytest.raises(InvalidTag):
- s2.loads(encrypted)
-
- def test_tampered_ciphertext_fails(self) -> None:
- pytest.importorskip("cryptography")
- import os
-
- from taskito.serializers import EncryptedSerializer
-
- key = os.urandom(32)
- s = EncryptedSerializer(JsonSerializer(), key)
- encrypted = s.dumps("hello")
- from cryptography.exceptions import InvalidTag
-
- tampered = encrypted[:-1] + bytes([encrypted[-1] ^ 0xFF])
- with pytest.raises(InvalidTag):
- s.loads(tampered)
diff --git a/tests/test_shutdown.py b/tests/test_shutdown.py
deleted file mode 100644
index 276ed5e..0000000
--- a/tests/test_shutdown.py
+++ /dev/null
@@ -1,62 +0,0 @@
-"""Tests for graceful shutdown."""
-
-import threading
-import time
-from typing import Any
-
-from taskito import Queue
-
-PollUntil = Any # the conftest fixture's runtime type
-
-
-def test_graceful_shutdown_completes_inflight(queue: Queue, poll_until: PollUntil) -> None:
- """Graceful shutdown waits for in-flight tasks to complete."""
- completed = threading.Event()
-
- @queue.task()
- def slow_task() -> str:
- # Intentional pacing — the test asserts the worker waits for this
- # to finish before shutting down.
- time.sleep(1)
- completed.set()
- return "done"
-
- job = slow_task.delay()
-
- worker_thread = threading.Thread(target=queue.run_worker, daemon=True)
- worker_thread.start()
-
- # Wait for the task to actually start running before triggering shutdown.
- poll_until(
- lambda: (j := queue.get_job(job.id)) is not None and j.status == "running",
- message="slow_task never reached running state",
- )
-
- # Request graceful shutdown
- queue._inner.request_shutdown()
-
- # Worker should finish the in-flight task
- worker_thread.join(timeout=10)
-
- assert completed.is_set()
- fetched = queue.get_job(job.id)
- assert fetched is not None
- assert fetched.status == "complete"
-
-
-def test_shutdown_stops_worker(queue: Queue) -> None:
- """request_shutdown causes run_worker to return."""
-
- @queue.task()
- def noop() -> None:
- pass
-
- worker_thread = threading.Thread(target=queue.run_worker, daemon=True)
- worker_thread.start()
-
- # Tiny grace window so the worker reaches its poll loop before shutdown.
- time.sleep(0.1)
- queue._inner.request_shutdown()
-
- worker_thread.join(timeout=10)
- assert not worker_thread.is_alive()
diff --git a/tests/test_streaming.py b/tests/test_streaming.py
deleted file mode 100644
index affeee3..0000000
--- a/tests/test_streaming.py
+++ /dev/null
@@ -1,154 +0,0 @@
-"""Tests for partial result streaming via current_job.publish() and job.stream()."""
-
-from __future__ import annotations
-
-import threading
-import time
-from pathlib import Path
-
-from taskito import Queue
-
-
-def test_publish_writes_result_log(tmp_path: Path) -> None:
- """publish() stores data as a task log with level='result'."""
- queue = Queue(db_path=str(tmp_path / "test.db"))
-
- @queue.task()
- def emit_data() -> str:
- from taskito.context import current_job
-
- current_job.publish({"step": 1, "value": "hello"})
- current_job.publish({"step": 2, "value": "world"})
- return "done"
-
- job = emit_data.delay()
-
- worker = threading.Thread(target=queue.run_worker, daemon=True)
- worker.start()
-
- result = job.result(timeout=10)
- assert result == "done"
-
- # Check that partial results are stored as task logs
- logs = queue.task_logs(job.id)
- result_logs = [lg for lg in logs if lg["level"] == "result"]
- assert len(result_logs) == 2
- assert '"step": 1' in result_logs[0]["extra"]
- assert '"step": 2' in result_logs[1]["extra"]
-
- queue._inner.request_shutdown()
-
-
-def test_stream_yields_partial_results(tmp_path: Path) -> None:
- """job.stream() yields published partial results."""
- queue = Queue(db_path=str(tmp_path / "test.db"))
-
- @queue.task()
- def batch_process() -> str:
- from taskito.context import current_job
-
- for i in range(3):
- current_job.publish({"item": i, "status": "processed"})
- time.sleep(0.1)
- return "all done"
-
- job = batch_process.delay()
-
- worker = threading.Thread(target=queue.run_worker, daemon=True)
- worker.start()
-
- # Collect streamed results
- results = list(job.stream(timeout=15, poll_interval=0.3))
-
- assert len(results) == 3
- assert results[0]["item"] == 0
- assert results[1]["item"] == 1
- assert results[2]["item"] == 2
- assert all(r["status"] == "processed" for r in results)
-
- queue._inner.request_shutdown()
-
-
-def test_stream_stops_on_completion(tmp_path: Path) -> None:
- """stream() stops iterating when the job completes."""
- queue = Queue(db_path=str(tmp_path / "test.db"))
-
- @queue.task()
- def quick_task() -> int:
- from taskito.context import current_job
-
- current_job.publish({"msg": "started"})
- return 42
-
- job = quick_task.delay()
-
- worker = threading.Thread(target=queue.run_worker, daemon=True)
- worker.start()
-
- results = list(job.stream(timeout=10, poll_interval=0.2))
- assert len(results) >= 1
- assert results[0]["msg"] == "started"
-
- queue._inner.request_shutdown()
-
-
-async def test_astream_async(tmp_path: Path) -> None:
- """astream() works as an async iterator."""
- queue = Queue(db_path=str(tmp_path / "test.db"))
-
- @queue.task()
- def async_batch() -> str:
- from taskito.context import current_job
-
- current_job.publish({"phase": "init"})
- current_job.publish({"phase": "done"})
- return "ok"
-
- job = async_batch.delay()
-
- worker = threading.Thread(target=queue.run_worker, daemon=True)
- worker.start()
-
- results: list[dict] = []
- async for partial in job.astream(timeout=10, poll_interval=0.3):
- results.append(partial)
-
- assert len(results) >= 1
- assert results[0]["phase"] == "init"
-
- queue._inner.request_shutdown()
-
-
-def test_publish_non_dict_data(tmp_path: Path) -> None:
- """publish() handles non-dict data (strings, lists, numbers).
-
- Since stream() polls, we verify via task_logs directly to avoid
- timing-dependent polling issues.
- """
- queue = Queue(db_path=str(tmp_path / "test.db"))
-
- @queue.task()
- def varied_output() -> str:
- from taskito.context import current_job
-
- current_job.publish("plain string")
- current_job.publish([1, 2, 3])
- current_job.publish(42)
- return "done"
-
- job = varied_output.delay()
-
- worker = threading.Thread(target=queue.run_worker, daemon=True)
- worker.start()
-
- job.result(timeout=10)
-
- # Verify all published data via task logs
- logs = queue.task_logs(job.id)
- result_logs = [lg for lg in logs if lg["level"] == "result"]
- extras = [lg["extra"] for lg in result_logs]
- assert '"plain string"' in extras
- assert "[1, 2, 3]" in extras
- assert "42" in extras
-
- queue._inner.request_shutdown()
diff --git a/tests/test_unique.py b/tests/test_unique.py
deleted file mode 100644
index c428bfc..0000000
--- a/tests/test_unique.py
+++ /dev/null
@@ -1,65 +0,0 @@
-"""Tests for unique task deduplication."""
-
-from __future__ import annotations
-
-import threading
-
-from taskito import Queue
-
-
-def test_unique_key_dedup(queue: Queue) -> None:
- """Two jobs with the same unique_key should return the same job ID."""
-
- @queue.task()
- def process(data: str) -> str:
- return data
-
- job1 = process.apply_async(args=("a",), unique_key="dedup-1")
- job2 = process.apply_async(args=("b",), unique_key="dedup-1")
-
- assert job1.id == job2.id
-
-
-def test_different_unique_keys(queue: Queue) -> None:
- """Different unique keys should create separate jobs."""
-
- @queue.task()
- def process(data: str) -> str:
- return data
-
- job1 = process.apply_async(args=("a",), unique_key="key-a")
- job2 = process.apply_async(args=("b",), unique_key="key-b")
-
- assert job1.id != job2.id
-
-
-def test_unique_key_allows_after_complete(queue: Queue) -> None:
- """After a unique job completes, a new one with the same key can be created."""
-
- @queue.task()
- def fast_task() -> str:
- return "done"
-
- job1 = fast_task.apply_async(unique_key="once")
-
- worker_thread = threading.Thread(target=queue.run_worker, daemon=True)
- worker_thread.start()
-
- job1.result(timeout=10)
-
- # Now enqueue again with the same key — should create a new job
- job2 = fast_task.apply_async(unique_key="once")
- assert job2.id != job1.id
-
-
-def test_no_unique_key_allows_duplicates(queue: Queue) -> None:
- """Without unique_key, duplicate jobs are allowed."""
-
- @queue.task()
- def process(data: str) -> str:
- return data
-
- job1 = process.delay("a")
- job2 = process.delay("a")
-
- assert job1.id != job2.id
diff --git a/tests/test_webhooks.py b/tests/test_webhooks.py
deleted file mode 100644
index ac41b42..0000000
--- a/tests/test_webhooks.py
+++ /dev/null
@@ -1,138 +0,0 @@
-"""Tests for webhook delivery system."""
-
-import hashlib
-import hmac
-import json
-import threading
-from collections.abc import Generator
-from http.server import BaseHTTPRequestHandler, HTTPServer
-from typing import Any
-
-import pytest
-
-from taskito.events import EventType
-from taskito.webhooks import WebhookManager
-
-PollUntil = Any # the conftest fixture's runtime type
-
-
-@pytest.fixture
-def webhook_server() -> Generator[tuple[str, list[dict[str, Any]]]]:
- """Start a local HTTP server that records webhook deliveries."""
- received: list[dict[str, Any]] = []
-
- class Handler(BaseHTTPRequestHandler):
- def do_POST(self) -> None:
- length = int(self.headers.get("Content-Length", 0))
- body = self.rfile.read(length)
- received.append(
- {
- "body": json.loads(body),
- "headers": dict(self.headers),
- }
- )
- self.send_response(200)
- self.end_headers()
-
- def log_message(self, *args: Any) -> None:
- pass
-
- server = HTTPServer(("127.0.0.1", 0), Handler)
- port = server.server_address[1]
- thread = threading.Thread(target=server.serve_forever, daemon=True)
- thread.start()
-
- yield f"http://127.0.0.1:{port}", received
-
- server.shutdown()
-
-
-def test_webhook_delivery(
- webhook_server: tuple[str, list[dict[str, Any]]], poll_until: PollUntil
-) -> None:
- """Webhooks are delivered to registered URLs."""
- url, received = webhook_server
- mgr = WebhookManager()
- mgr.add_webhook(url)
-
- mgr.notify(EventType.JOB_COMPLETED, {"job_id": "abc"})
- poll_until(lambda: len(received) >= 1, message="webhook not delivered")
-
- assert len(received) == 1
- assert received[0]["body"]["event"] == "job.completed"
- assert received[0]["body"]["job_id"] == "abc"
-
-
-def test_webhook_event_filtering(
- webhook_server: tuple[str, list[dict[str, Any]]], poll_until: PollUntil
-) -> None:
- """Webhooks with event filters only receive matching events."""
- url, received = webhook_server
- mgr = WebhookManager()
- mgr.add_webhook(url, events=[EventType.JOB_FAILED])
-
- mgr.notify(EventType.JOB_COMPLETED, {"job_id": "1"})
- mgr.notify(EventType.JOB_FAILED, {"job_id": "2", "error": "boom"})
- poll_until(lambda: len(received) >= 1, message="filtered webhook not delivered")
-
- assert len(received) == 1
- assert received[0]["body"]["event"] == "job.failed"
-
-
-def test_webhook_hmac_signing(
- webhook_server: tuple[str, list[dict[str, Any]]], poll_until: PollUntil
-) -> None:
- """Webhooks with a secret include a valid HMAC signature."""
- url, received = webhook_server
- secret = "my-secret-key"
- mgr = WebhookManager()
- mgr.add_webhook(url, secret=secret)
-
- mgr.notify(EventType.JOB_ENQUEUED, {"job_id": "xyz"})
- poll_until(lambda: len(received) >= 1, message="signed webhook not delivered")
-
- assert len(received) == 1
- sig_header = received[0]["headers"].get("X-Taskito-Signature")
- assert sig_header is not None
- assert sig_header.startswith("sha256=")
-
- # Verify the signature
- body_bytes = json.dumps(received[0]["body"], default=str).encode("utf-8")
- expected_sig = hmac.new(secret.encode(), body_bytes, hashlib.sha256).hexdigest()
- assert sig_header == f"sha256={expected_sig}"
-
-
-def test_webhook_url_validation() -> None:
- """Only http:// and https:// URLs are accepted."""
- mgr = WebhookManager()
-
- with pytest.raises(ValueError, match="http:// or https://"):
- mgr.add_webhook("ftp://example.com/hook")
-
- with pytest.raises(ValueError, match="http:// or https://"):
- mgr.add_webhook("javascript:alert(1)")
-
- # Valid URLs should not raise
- mgr.add_webhook("http://localhost:8080/hook")
- mgr.add_webhook("https://example.com/hook")
-
-
-def test_webhook_custom_headers(
- webhook_server: tuple[str, list[dict[str, Any]]], poll_until: PollUntil
-) -> None:
- """Custom headers are included in webhook requests."""
- url, received = webhook_server
- mgr = WebhookManager()
- mgr.add_webhook(url, headers={"X-Custom": "test-value"})
-
- mgr.notify(EventType.JOB_COMPLETED, {"job_id": "1"})
- poll_until(lambda: len(received) >= 1, message="custom-headers webhook not delivered")
-
- assert len(received) == 1
- assert received[0]["headers"].get("X-Custom") == "test-value"
-
-
-def test_webhook_no_subscribers() -> None:
- """Notifying with no matching webhooks doesn't raise."""
- mgr = WebhookManager()
- mgr.notify(EventType.JOB_COMPLETED, {"job_id": "1"})
diff --git a/tests/test_worker.py b/tests/test_worker.py
deleted file mode 100644
index f2ae258..0000000
--- a/tests/test_worker.py
+++ /dev/null
@@ -1,114 +0,0 @@
-"""Tests for worker behavior."""
-
-import threading
-import time
-from pathlib import Path
-
-import pytest
-
-from taskito import Queue
-
-
-def test_multiple_tasks(queue: Queue) -> None:
- """Worker handles multiple different task types."""
-
- @queue.task()
- def task_a(x: int) -> int:
- return x * 2
-
- @queue.task()
- def task_b(x: int) -> int:
- return x + 10
-
- job_a = task_a.delay(5)
- job_b = task_b.delay(5)
-
- worker_thread = threading.Thread(
- target=queue.run_worker,
- daemon=True,
- )
- worker_thread.start()
-
- assert job_a.result(timeout=10) == 10
- assert job_b.result(timeout=10) == 15
-
-
-def test_get_job(queue: Queue) -> None:
- """Can retrieve a job by ID."""
-
- @queue.task()
- def simple() -> int:
- return 42
-
- job = simple.delay()
- fetched = queue.get_job(job.id)
- assert fetched is not None
- assert fetched.id == job.id
-
-
-def test_job_status_progression(queue: Queue) -> None:
- """Job status progresses from pending through complete."""
-
- @queue.task()
- def slow() -> str:
- time.sleep(0.5)
- return "done"
-
- job = slow.delay()
-
- # Initially pending
- fetched = queue.get_job(job.id)
- assert fetched is not None
- assert fetched.status == "pending"
-
- worker_thread = threading.Thread(
- target=queue.run_worker,
- daemon=True,
- )
- worker_thread.start()
-
- result = job.result(timeout=10)
- assert result == "done"
-
- # After completion
- fetched = queue.get_job(job.id)
- assert fetched is not None
- assert fetched.status == "complete"
-
-
-@pytest.mark.asyncio
-async def test_async_result(tmp_path: Path) -> None:
- """Async result retrieval works."""
- db_path = str(tmp_path / "test_async.db")
- queue = Queue(db_path=db_path, workers=2)
-
- @queue.task()
- def add(a: int, b: int) -> int:
- return a + b
-
- job = add.delay(10, 20)
-
- worker_thread = threading.Thread(
- target=queue.run_worker,
- daemon=True,
- )
- worker_thread.start()
-
- result = await job.aresult(timeout=10)
- assert result == 30
-
-
-@pytest.mark.asyncio
-async def test_async_stats(tmp_path: Path) -> None:
- """Async stats work."""
- db_path = str(tmp_path / "test_async_stats.db")
- queue = Queue(db_path=db_path, workers=2)
-
- @queue.task()
- def noop() -> None:
- pass
-
- noop.delay()
-
- stats = await queue.astats()
- assert stats["pending"] == 1
diff --git a/tests/test_worker_resources.py b/tests/test_worker_resources.py
deleted file mode 100644
index c6ebe84..0000000
--- a/tests/test_worker_resources.py
+++ /dev/null
@@ -1,181 +0,0 @@
-"""Tests for worker resource advertisement (Phase 4 — storage extension)."""
-
-from __future__ import annotations
-
-import json
-import threading
-from typing import Any
-
-from taskito import Queue
-
-# ---------------------------------------------------------------------------
-# Direct storage-level tests (no worker needed)
-# ---------------------------------------------------------------------------
-
-
-class TestWorkerAdvertisement:
- def test_no_resources_returns_none(self, tmp_path: Any) -> None:
- """_build_resource_health_json returns None when no resources."""
- queue = Queue(db_path=str(tmp_path / "q.db"))
- assert queue._build_resource_health_json() is None
-
- def test_build_resource_health_json_with_resources(self, tmp_path: Any) -> None:
- """_build_resource_health_json returns correct JSON."""
- queue = Queue(db_path=str(tmp_path / "q.db"))
-
- @queue.worker_resource("db")
- def create_db() -> str:
- return "db_instance"
-
- @queue.worker_resource("cache")
- def create_cache() -> str:
- return "cache_instance"
-
- health_json = queue._build_resource_health_json()
- assert health_json is not None
- health = json.loads(health_json)
- assert health == {"db": "healthy", "cache": "healthy"}
-
- def test_build_resource_health_reflects_unhealthy(self, tmp_path: Any) -> None:
- """_build_resource_health_json marks unhealthy resources."""
- from taskito.resources.runtime import ResourceRuntime
-
- queue = Queue(db_path=str(tmp_path / "q.db"))
-
- @queue.worker_resource("db")
- def create_db() -> str:
- return "db_instance"
-
- # Simulate an initialized runtime with an unhealthy resource
- runtime = ResourceRuntime(queue._resource_definitions)
- runtime._instances = {"db": "db_instance"}
- runtime._init_order = ["db"]
- runtime._unhealthy = {"db"}
- queue._resource_runtime = runtime
-
- health_json = queue._build_resource_health_json()
- assert health_json is not None
- health = json.loads(health_json)
- assert health["db"] == "unhealthy"
-
- def test_worker_heartbeat_method(self, tmp_path: Any) -> None:
- """worker_heartbeat can be called without error."""
- queue = Queue(db_path=str(tmp_path / "q.db"))
- # Heartbeat for a non-existent worker is a no-op (updates 0 rows)
- queue._inner.worker_heartbeat("nonexistent-worker")
-
- def test_worker_heartbeat_with_health(self, tmp_path: Any) -> None:
- """worker_heartbeat accepts resource_health JSON."""
- queue = Queue(db_path=str(tmp_path / "q.db"))
- health = json.dumps({"db": "healthy"})
- queue._inner.worker_heartbeat("w-test", health)
-
- def test_list_workers_empty(self, tmp_path: Any) -> None:
- """list_workers returns empty list when no workers registered."""
- queue = Queue(db_path=str(tmp_path / "q.db"))
- workers = queue.workers()
- assert workers == []
-
-
-# ---------------------------------------------------------------------------
-# Integration tests with actual worker
-# ---------------------------------------------------------------------------
-
-
-class TestWorkerResourceIntegration:
- def test_worker_advertises_resources_and_threads(self, tmp_path: Any, poll_until: Any) -> None:
- """A running worker stores resources and threads in storage."""
- queue = Queue(db_path=str(tmp_path / "q.db"), workers=2)
-
- @queue.worker_resource("db")
- def create_db() -> str:
- return "db_instance"
-
- @queue.task()
- def noop() -> None:
- pass
-
- # Start worker in thread, wait for it to register
- thread = threading.Thread(target=queue.run_worker, daemon=True)
- thread.start()
-
- try:
- # Poll until a worker appears with resource_health populated
- # (initial registration has None; first heartbeat sets it).
- poll_until(
- lambda: bool((ws := queue.workers()) and ws[0].get("resource_health") is not None),
- timeout=15,
- message="worker did not publish resource_health",
- )
- workers: list[dict[str, Any]] = queue.workers()
- assert len(workers) >= 1
- w = workers[0]
- assert "resources" in w
- assert "resource_health" in w
- assert "threads" in w
- assert w["threads"] == 2
-
- # Parse resources JSON
- resources = json.loads(w["resources"])
- assert "db" in resources
-
- # Parse health JSON
- health = json.loads(w["resource_health"])
- assert health.get("db") == "healthy"
- finally:
- queue._inner.request_shutdown()
- thread.join(timeout=10)
-
- def test_worker_no_resources(self, tmp_path: Any, poll_until: Any) -> None:
- """Worker without resources stores None for resource fields."""
- queue = Queue(db_path=str(tmp_path / "q.db"), workers=1)
-
- @queue.task()
- def noop() -> None:
- pass
-
- thread = threading.Thread(target=queue.run_worker, daemon=True)
- thread.start()
-
- try:
- poll_until(
- lambda: bool(queue.workers()), timeout=10, message="worker did not register"
- )
- workers: list[dict[str, Any]] = queue.workers()
- assert len(workers) >= 1
- w = workers[0]
- assert w["resources"] is None
- assert w["resource_health"] is None
- finally:
- queue._inner.request_shutdown()
- thread.join(timeout=10)
-
- def test_heartbeat_updates_health(self, tmp_path: Any, poll_until: Any) -> None:
- """Heartbeat thread updates resource_health in storage."""
- queue = Queue(db_path=str(tmp_path / "q.db"), workers=1)
-
- @queue.worker_resource("db")
- def create_db() -> str:
- return "db_instance"
-
- @queue.task()
- def noop() -> None:
- pass
-
- thread = threading.Thread(target=queue.run_worker, daemon=True)
- thread.start()
-
- try:
- # Wait for worker + first heartbeat
- poll_until(
- lambda: bool((ws := queue.workers()) and ws[0].get("resource_health")),
- timeout=10,
- message="heartbeat never published resource_health",
- )
- workers: list[dict[str, Any]] = queue.workers()
- assert len(workers) >= 1
- health = json.loads(workers[0]["resource_health"])
- assert "db" in health
- finally:
- queue._inner.request_shutdown()
- thread.join(timeout=10)
diff --git a/tests/test_workflows_analysis.py b/tests/test_workflows_analysis.py
deleted file mode 100644
index dc01185..0000000
--- a/tests/test_workflows_analysis.py
+++ /dev/null
@@ -1,137 +0,0 @@
-"""Tests for Phase 6 workflow graph analysis."""
-
-from __future__ import annotations
-
-import pytest
-
-from taskito.workflows import Workflow
-
-
-class _FakeTask:
- _task_name = "fake"
-
-
-def _linear() -> Workflow:
- """a → b → c"""
- wf = Workflow(name="linear")
- wf.step("a", _FakeTask())
- wf.step("b", _FakeTask(), after="a")
- wf.step("c", _FakeTask(), after="b")
- return wf
-
-
-def _diamond() -> Workflow:
- """a → {b, c} → d"""
- wf = Workflow(name="diamond")
- wf.step("a", _FakeTask())
- wf.step("b", _FakeTask(), after="a")
- wf.step("c", _FakeTask(), after="a")
- wf.step("d", _FakeTask(), after=["b", "c"])
- return wf
-
-
-# ── Ancestors / Descendants ──────────────────────────────────────
-
-
-def test_ancestors_linear() -> None:
- wf = _linear()
- assert wf.ancestors("c") == ["a", "b"]
- assert wf.ancestors("b") == ["a"]
- assert wf.ancestors("a") == []
-
-
-def test_descendants_linear() -> None:
- wf = _linear()
- assert wf.descendants("a") == ["b", "c"]
- assert wf.descendants("b") == ["c"]
- assert wf.descendants("c") == []
-
-
-def test_ancestors_diamond() -> None:
- wf = _diamond()
- assert wf.ancestors("d") == ["a", "b", "c"]
- assert set(wf.ancestors("b")) == {"a"}
-
-
-def test_ancestors_unknown_node() -> None:
- wf = _linear()
- with pytest.raises(KeyError, match="nonexistent"):
- wf.ancestors("nonexistent")
-
-
-# ── Topological Levels ───────────────────────────────────────────
-
-
-def test_topological_levels_linear() -> None:
- wf = _linear()
- assert wf.topological_levels() == [["a"], ["b"], ["c"]]
-
-
-def test_topological_levels_diamond() -> None:
- wf = _diamond()
- levels = wf.topological_levels()
- assert levels[0] == ["a"]
- assert sorted(levels[1]) == ["b", "c"]
- assert levels[2] == ["d"]
-
-
-# ── Stats ────────────────────────────────────────────────────────
-
-
-def test_stats() -> None:
- wf = _diamond()
- s = wf.stats()
- assert s["nodes"] == 4
- assert s["edges"] == 4 # a→b, a→c, b→d, c→d
- assert s["depth"] == 3
- assert s["width"] == 2 # level 1 has b and c
- assert 0 < s["density"] <= 1.0
-
-
-# ── Critical Path ────────────────────────────────────────────────
-
-
-def test_critical_path_linear() -> None:
- wf = _linear()
- path, cost = wf.critical_path({"a": 1.0, "b": 2.0, "c": 3.0})
- assert path == ["a", "b", "c"]
- assert cost == 6.0
-
-
-def test_critical_path_diamond() -> None:
- wf = _diamond()
- # b branch is heavier: a(1) + b(5) + d(1) = 7
- # c branch: a(1) + c(2) + d(1) = 4
- path, cost = wf.critical_path({"a": 1.0, "b": 5.0, "c": 2.0, "d": 1.0})
- assert path == ["a", "b", "d"]
- assert cost == 7.0
-
-
-# ── Execution Plan ───────────────────────────────────────────────
-
-
-def test_execution_plan_parallelism() -> None:
- wf = _diamond()
- plan = wf.execution_plan(max_workers=2)
- # Level 0: [a], Level 1: [b, c] fits in 1 batch, Level 2: [d]
- assert plan == [["a"], ["b", "c"], ["d"]]
-
-
-def test_execution_plan_worker_limit() -> None:
- wf = _diamond()
- plan = wf.execution_plan(max_workers=1)
- # Level 1 splits into two batches
- assert plan == [["a"], ["b"], ["c"], ["d"]]
-
-
-# ── Bottleneck Analysis ──────────────────────────────────────────
-
-
-def test_bottleneck_analysis() -> None:
- wf = _diamond()
- result = wf.bottleneck_analysis({"a": 1.0, "b": 5.0, "c": 2.0, "d": 1.0})
- assert result["node"] == "b"
- assert result["cost"] == 5.0
- assert result["percentage"] > 50
- assert "b" in result["suggestion"]
- assert result["critical_path"] == ["a", "b", "d"]
diff --git a/tests/test_workflows_caching.py b/tests/test_workflows_caching.py
deleted file mode 100644
index cfba4c8..0000000
--- a/tests/test_workflows_caching.py
+++ /dev/null
@@ -1,215 +0,0 @@
-"""Tests for Phase 7 incremental execution and caching."""
-
-from __future__ import annotations
-
-import threading
-from collections.abc import Callable
-from contextlib import AbstractContextManager
-
-from taskito import Queue
-from taskito.workflows import NodeStatus, Workflow, WorkflowState
-
-WorkflowWorkerFactory = Callable[[], AbstractContextManager[threading.Thread]]
-
-
-def test_result_hash_stored(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
- """Completed nodes have a non-None result_hash."""
-
- @queue.task()
- def ok_task() -> str:
- return "hello"
-
- wf = Workflow(name="hash_stored")
- wf.step("a", ok_task)
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- run.wait(timeout=15)
-
- # Check the node data via base_run_node_data
- nodes = queue._inner.get_base_run_node_data(run.id)
- assert len(nodes) == 1
- name, status, _result_hash = nodes[0]
- assert name == "a"
- assert status == "completed"
- # Hash may be None if result wasn't stored before event fired (best-effort).
- # In practice it's usually populated.
-
-
-def test_incremental_skips_completed(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
- """Incremental run marks base-completed nodes as CACHE_HIT."""
-
- executed: list[str] = []
-
- @queue.task()
- def step_a() -> str:
- executed.append("a")
- return "a"
-
- @queue.task()
- def step_b() -> str:
- executed.append("b")
- return "b"
-
- wf = Workflow(name="incr_skip")
- wf.step("a", step_a)
- wf.step("b", step_b, after="a")
-
- # First run: everything executes.
- with workflow_worker():
- run1 = queue.submit_workflow(wf)
- run1.wait(timeout=15)
-
- assert run1.status().state == WorkflowState.COMPLETED
- executed.clear()
-
- # Second run: incremental.
- with workflow_worker():
- run2 = queue.submit_workflow(wf, incremental=True, base_run=run1.id)
- run2.wait(timeout=15)
-
- final = run2.status()
- assert final.state == WorkflowState.COMPLETED
-
- # If hashes were stored, both nodes are CACHE_HIT and nothing re-ran.
- cache_hits = [n for n in final.nodes.values() if n.status == NodeStatus.CACHE_HIT]
- if cache_hits:
- assert len(cache_hits) == 2
- assert executed == [] # nothing re-executed
-
-
-def test_incremental_reruns_failed(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
- """Failed nodes in the base run get re-executed."""
-
- call_count = {"n": 0}
-
- @queue.task(max_retries=0)
- def flaky() -> str:
- call_count["n"] += 1
- if call_count["n"] == 1:
- raise RuntimeError("first call fails")
- return "ok"
-
- wf = Workflow(name="incr_rerun")
- wf.step("a", flaky)
-
- # First run: fails.
- with workflow_worker():
- run1 = queue.submit_workflow(wf)
- run1.wait(timeout=15)
-
- assert run1.status().state == WorkflowState.FAILED
-
- # Second run incremental: failed node re-executes.
- with workflow_worker():
- run2 = queue.submit_workflow(wf, incremental=True, base_run=run1.id)
- run2.wait(timeout=15)
-
- assert run2.status().state == WorkflowState.COMPLETED
- assert call_count["n"] == 2
-
-
-def test_dirty_propagation(queue: Queue) -> None:
- """If a root node is dirty, all downstream re-execute even if they were cached."""
- from taskito.workflows.incremental import compute_dirty_set
-
- successors = {"a": ["b"], "b": ["c"], "c": []}
- predecessors = {"a": [], "b": ["a"], "c": ["b"]}
-
- # Simulate "a" being dirty (not in base).
- base_nodes_missing_a: list[tuple[str, str, str | None]] = [
- ("b", "completed", "hash_b"),
- ("c", "completed", "hash_c"),
- ]
-
- dirty, cached = compute_dirty_set(
- base_nodes=base_nodes_missing_a,
- new_node_names=["a", "b", "c"],
- successors=successors,
- predecessors=predecessors,
- )
-
- assert "a" in dirty
- assert "b" in dirty # propagated from a
- assert "c" in dirty # propagated from b
- assert not cached
-
-
-def test_cache_hit_is_terminal(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
- """CACHE_HIT nodes are terminal and don't block the workflow."""
-
- @queue.task()
- def ok_task() -> str:
- return "ok"
-
- wf = Workflow(name="cache_terminal")
- wf.step("a", ok_task)
- wf.step("b", ok_task, after="a")
-
- # First run.
- with workflow_worker():
- run1 = queue.submit_workflow(wf)
- run1.wait(timeout=15)
-
- # Second incremental run.
- with workflow_worker():
- run2 = queue.submit_workflow(wf, incremental=True, base_run=run1.id)
- final = run2.wait(timeout=15)
-
- # Workflow should complete (CACHE_HIT is terminal).
- assert final.state == WorkflowState.COMPLETED
-
-
-def test_full_refresh_ignores_cache(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
- """incremental=False always re-runs everything."""
-
- executed: list[str] = []
-
- @queue.task()
- def step_a() -> str:
- executed.append("a")
- return "a"
-
- wf = Workflow(name="full_refresh")
- wf.step("a", step_a)
-
- # First run.
- with workflow_worker():
- run1 = queue.submit_workflow(wf)
- run1.wait(timeout=15)
-
- executed.clear()
-
- # Second run without incremental — should re-execute.
- with workflow_worker():
- run2 = queue.submit_workflow(wf)
- run2.wait(timeout=15)
-
- assert executed == ["a"]
-
-
-def test_cache_ttl_expires() -> None:
- """Expired base run results trigger re-execution."""
- from taskito.workflows.incremental import compute_dirty_set
-
- base_nodes: list[tuple[str, str, str | None]] = [
- ("a", "completed", "hash_a"),
- ]
-
- # base_run_completed_at is 1000 seconds ago, TTL is 500s → expired
- import time
-
- now_ms = int(time.time() * 1000)
- old_completed = now_ms - 1_000_000 # 1000 seconds ago
-
- dirty, cached = compute_dirty_set(
- base_nodes=base_nodes,
- new_node_names=["a"],
- successors={"a": []},
- predecessors={"a": []},
- cache_ttl=500.0,
- base_run_completed_at=old_completed,
- )
-
- assert "a" in dirty
- assert not cached
diff --git a/tests/test_workflows_conditions.py b/tests/test_workflows_conditions.py
deleted file mode 100644
index 6c48468..0000000
--- a/tests/test_workflows_conditions.py
+++ /dev/null
@@ -1,405 +0,0 @@
-"""Tests for Phase 4 conditional execution and error handling."""
-
-from __future__ import annotations
-
-import threading
-from collections.abc import Callable
-from contextlib import AbstractContextManager
-
-from taskito import Queue
-from taskito.workflows import NodeStatus, Workflow, WorkflowContext, WorkflowState
-
-WorkflowWorkerFactory = Callable[[], AbstractContextManager[threading.Thread]]
-
-
-def test_on_failure_step_runs(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
- """A step with condition='on_failure' runs when predecessor fails."""
-
- @queue.task(max_retries=0)
- def fail_task() -> str:
- raise RuntimeError("boom")
-
- collected: list[str] = []
-
- @queue.task()
- def cleanup() -> str:
- collected.append("cleanup ran")
- return "cleaned"
-
- wf = Workflow(name="on_failure_runs")
- wf.step("a", fail_task)
- wf.step("b", cleanup, after="a", condition="on_failure")
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- final = run.wait(timeout=20)
-
- assert final.nodes["a"].status == NodeStatus.FAILED
- assert final.nodes["b"].status == NodeStatus.COMPLETED
- assert collected == ["cleanup ran"]
-
-
-def test_on_failure_step_skipped_on_success(
- queue: Queue, workflow_worker: WorkflowWorkerFactory
-) -> None:
- """A step with condition='on_failure' is SKIPPED when predecessor succeeds."""
-
- @queue.task()
- def ok_task() -> str:
- return "ok"
-
- @queue.task()
- def rollback() -> str:
- return "should not run"
-
- wf = Workflow(name="on_failure_skipped")
- wf.step("a", ok_task)
- wf.step("b", rollback, after="a", condition="on_failure")
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- final = run.wait(timeout=20)
-
- assert final.nodes["a"].status == NodeStatus.COMPLETED
- assert final.nodes["b"].status == NodeStatus.SKIPPED
-
-
-def test_always_step_runs_on_success(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
- """A step with condition='always' runs when predecessor succeeds."""
-
- collected: list[str] = []
-
- @queue.task()
- def ok_task() -> str:
- return "ok"
-
- @queue.task()
- def always_task() -> str:
- collected.append("always ran")
- return "done"
-
- wf = Workflow(name="always_on_success")
- wf.step("a", ok_task)
- wf.step("b", always_task, after="a", condition="always")
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- final = run.wait(timeout=20)
-
- assert final.nodes["b"].status == NodeStatus.COMPLETED
- assert collected == ["always ran"]
-
-
-def test_always_step_runs_on_failure(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
- """A step with condition='always' runs even when predecessor fails."""
-
- collected: list[str] = []
-
- @queue.task(max_retries=0)
- def fail_task() -> str:
- raise RuntimeError("boom")
-
- @queue.task()
- def always_task() -> str:
- collected.append("always ran")
- return "done"
-
- wf = Workflow(name="always_on_failure")
- wf.step("a", fail_task)
- wf.step("b", always_task, after="a", condition="always")
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- final = run.wait(timeout=20)
-
- assert final.nodes["a"].status == NodeStatus.FAILED
- assert final.nodes["b"].status == NodeStatus.COMPLETED
- assert collected == ["always ran"]
-
-
-def test_on_success_default(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
- """Default condition (on_success) skips the step when predecessor fails."""
-
- @queue.task(max_retries=0)
- def fail_task() -> str:
- raise RuntimeError("boom")
-
- @queue.task()
- def next_task() -> str:
- return "should not run"
-
- wf = Workflow(name="on_success_default")
- wf.step("a", fail_task)
- wf.step("b", next_task, after="a")
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- final = run.wait(timeout=20)
-
- assert final.state == WorkflowState.FAILED
- assert final.nodes["a"].status == NodeStatus.FAILED
- assert final.nodes["b"].status == NodeStatus.SKIPPED
-
-
-def test_continue_mode_independent_branches(
- queue: Queue, workflow_worker: WorkflowWorkerFactory
-) -> None:
- """on_failure='continue' lets independent branches keep running."""
-
- order: list[str] = []
-
- @queue.task(max_retries=0)
- def fail_task() -> str:
- order.append("fail")
- raise RuntimeError("boom")
-
- @queue.task()
- def ok_task() -> str:
- order.append("ok")
- return "ok"
-
- @queue.task()
- def after_fail() -> str:
- order.append("after_fail")
- return "nope"
-
- @queue.task()
- def after_ok() -> str:
- order.append("after_ok")
- return "yes"
-
- # Diamond: root → {fail_branch, ok_branch} → {after_fail, after_ok}
- wf = Workflow(name="continue_branches", on_failure="continue")
- wf.step("root", ok_task)
- wf.step("fail_branch", fail_task, after="root")
- wf.step("ok_branch", ok_task, after="root")
- wf.step("after_fail", after_fail, after="fail_branch")
- wf.step("after_ok", after_ok, after="ok_branch")
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- final = run.wait(timeout=20)
-
- # fail_branch failed → after_fail skipped (condition=on_success, pred failed)
- # ok_branch succeeded → after_ok ran
- assert final.state == WorkflowState.FAILED # overall: has failures
- assert final.nodes["fail_branch"].status == NodeStatus.FAILED
- assert final.nodes["ok_branch"].status == NodeStatus.COMPLETED
- assert final.nodes["after_fail"].status == NodeStatus.SKIPPED
- assert final.nodes["after_ok"].status == NodeStatus.COMPLETED
- assert "after_ok" in order
-
-
-def test_continue_mode_skips_downstream(
- queue: Queue, workflow_worker: WorkflowWorkerFactory
-) -> None:
- """In continue mode, failure skips on_success downstream in the chain."""
-
- @queue.task(max_retries=0)
- def fail_task() -> str:
- raise RuntimeError("boom")
-
- @queue.task()
- def ok_task() -> str:
- return "ok"
-
- wf = Workflow(name="continue_chain", on_failure="continue")
- wf.step("a", ok_task)
- wf.step("b", fail_task, after="a")
- wf.step("c", ok_task, after="b")
- wf.step("d", ok_task, after="c")
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- final = run.wait(timeout=20)
-
- assert final.state == WorkflowState.FAILED
- assert final.nodes["a"].status == NodeStatus.COMPLETED
- assert final.nodes["b"].status == NodeStatus.FAILED
- assert final.nodes["c"].status == NodeStatus.SKIPPED
- assert final.nodes["d"].status == NodeStatus.SKIPPED
-
-
-def test_callable_condition_true(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
- """A callable condition that returns True lets the step run."""
-
- @queue.task()
- def ok_task() -> str:
- return "ok"
-
- collected: list[str] = []
-
- @queue.task()
- def guarded() -> str:
- collected.append("ran")
- return "done"
-
- wf = Workflow(name="callable_true")
- wf.step("a", ok_task)
- wf.step("b", guarded, after="a", condition=lambda ctx: True)
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- final = run.wait(timeout=20)
-
- assert final.nodes["b"].status == NodeStatus.COMPLETED
- assert collected == ["ran"]
-
-
-def test_callable_condition_false(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
- """A callable condition that returns False skips the step."""
-
- @queue.task()
- def ok_task() -> str:
- return "ok"
-
- @queue.task()
- def guarded() -> str:
- return "should not run"
-
- wf = Workflow(name="callable_false")
- wf.step("a", ok_task)
- wf.step("b", guarded, after="a", condition=lambda ctx: False)
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- final = run.wait(timeout=20)
-
- assert final.nodes["b"].status == NodeStatus.SKIPPED
-
-
-def test_callable_accesses_results(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
- """A callable condition can access predecessor results via ctx.results."""
-
- @queue.task()
- def score_task() -> dict:
- return {"score": 0.98}
-
- collected: list[str] = []
-
- @queue.task()
- def deploy() -> str:
- collected.append("deployed")
- return "ok"
-
- @queue.task()
- def skip_deploy() -> str:
- collected.append("should not deploy")
- return "skip"
-
- def high_score(ctx: WorkflowContext) -> bool:
- return bool(ctx.results.get("validate", {}).get("score", 0) > 0.95)
-
- wf = Workflow(name="callable_results")
- wf.step("validate", score_task)
- wf.step("deploy", deploy, after="validate", condition=high_score)
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- final = run.wait(timeout=20)
-
- assert final.nodes["deploy"].status == NodeStatus.COMPLETED
- assert collected == ["deployed"]
-
-
-def test_fail_fast_backward_compat(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
- """Phase 2 regression: fail_fast (default) cascades all pending nodes."""
-
- @queue.task(max_retries=0)
- def fail_task() -> str:
- raise RuntimeError("boom")
-
- @queue.task()
- def ok_task() -> str:
- return "ok"
-
- wf = Workflow(name="fail_fast_compat")
- wf.step("a", fail_task)
- wf.step("b", ok_task, after="a")
- wf.step("c", ok_task, after="b")
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- final = run.wait(timeout=15)
-
- assert final.state == WorkflowState.FAILED
- assert final.nodes["a"].status == NodeStatus.FAILED
- assert final.nodes["b"].status == NodeStatus.SKIPPED
- assert final.nodes["c"].status == NodeStatus.SKIPPED
-
-
-def test_skip_propagation_respects_always(
- queue: Queue, workflow_worker: WorkflowWorkerFactory
-) -> None:
- """A→B→C: A fails, B(on_success) skipped, C(always) still runs."""
-
- @queue.task(max_retries=0)
- def fail_task() -> str:
- raise RuntimeError("boom")
-
- @queue.task()
- def ok_task() -> str:
- return "ok"
-
- collected: list[str] = []
-
- @queue.task()
- def always_task() -> str:
- collected.append("always ran")
- return "done"
-
- wf = Workflow(name="skip_propagation")
- wf.step("a", fail_task)
- wf.step("b", ok_task, after="a")
- wf.step("c", always_task, after="b", condition="always")
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- final = run.wait(timeout=20)
-
- assert final.nodes["a"].status == NodeStatus.FAILED
- assert final.nodes["b"].status == NodeStatus.SKIPPED
- assert final.nodes["c"].status == NodeStatus.COMPLETED
- assert collected == ["always ran"]
-
-
-def test_fan_out_with_on_failure_downstream(
- queue: Queue, workflow_worker: WorkflowWorkerFactory
-) -> None:
- """Fan-out child fails, downstream on_failure step runs."""
-
- @queue.task()
- def source() -> list[int]:
- return [1, 2]
-
- @queue.task(max_retries=0)
- def process(x: int) -> int:
- if x == 2:
- raise RuntimeError("boom")
- return x * 10
-
- @queue.task()
- def aggregate(results: list[int]) -> str:
- return "agg"
-
- collected: list[str] = []
-
- @queue.task()
- def on_error() -> str:
- collected.append("error handled")
- return "handled"
-
- wf = Workflow(name="fan_out_on_failure")
- wf.step("fetch", source)
- wf.step("process", process, after="fetch", fan_out="each")
- wf.step("collect", aggregate, after="process", fan_in="all")
- wf.step("handle_error", on_error, after="process", condition="on_failure")
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- final = run.wait(timeout=20)
-
- assert final.nodes["process"].status == NodeStatus.FAILED
- assert final.nodes["collect"].status == NodeStatus.SKIPPED
- assert final.nodes["handle_error"].status == NodeStatus.COMPLETED
- assert collected == ["error handled"]
diff --git a/tests/test_workflows_cron.py b/tests/test_workflows_cron.py
deleted file mode 100644
index 0aef3f5..0000000
--- a/tests/test_workflows_cron.py
+++ /dev/null
@@ -1,30 +0,0 @@
-"""Tests for Phase 5D cron-scheduled workflows."""
-
-from __future__ import annotations
-
-from taskito import Queue
-from taskito.workflows import Workflow
-
-
-def test_periodic_workflow_registers_launcher(queue: Queue) -> None:
- """@queue.periodic + @queue.workflow registers a launcher task."""
-
- @queue.task()
- def extract() -> str:
- return "data"
-
- @queue.periodic(cron="0 0 2 * * *")
- @queue.workflow("nightly")
- def nightly() -> Workflow:
- wf = Workflow()
- wf.step("extract", extract)
- return wf
-
- # The launcher task should be registered
- launcher_name = "_wf_launcher_nightly"
- assert launcher_name in queue._task_registry
- # The periodic config should reference the launcher
- assert any(pc["task_name"] == launcher_name for pc in queue._periodic_configs)
- # The workflow proxy should still be returned
- assert hasattr(nightly, "submit")
- assert hasattr(nightly, "build")
diff --git a/tests/test_workflows_fan_out.py b/tests/test_workflows_fan_out.py
deleted file mode 100644
index c7f92e6..0000000
--- a/tests/test_workflows_fan_out.py
+++ /dev/null
@@ -1,359 +0,0 @@
-"""Tests for Phase 3 fan-out / fan-in workflow execution."""
-
-from __future__ import annotations
-
-import threading
-import time
-from collections.abc import Callable
-from contextlib import AbstractContextManager
-from typing import Any
-
-import pytest
-
-from taskito import Queue
-from taskito.workflows import NodeStatus, Workflow, WorkflowState
-
-WorkflowWorkerFactory = Callable[[], AbstractContextManager[threading.Thread]]
-
-
-def test_fan_out_each(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
- """fan_out='each' splits a list into N parallel jobs, fan_in='all' collects."""
-
- @queue.task()
- def source() -> list[int]:
- return [10, 20, 30]
-
- @queue.task()
- def double(x: int) -> int:
- return x * 2
-
- collected: list[Any] = []
-
- @queue.task()
- def aggregate(results: list[int]) -> str:
- collected.extend(results)
- return "done"
-
- wf = Workflow(name="fan_out_each")
- wf.step("fetch", source)
- wf.step("process", double, after="fetch", fan_out="each")
- wf.step("collect", aggregate, after="process", fan_in="all")
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- final = run.wait(timeout=20)
-
- assert final.state == WorkflowState.COMPLETED
- assert sorted(collected) == [20, 40, 60]
-
-
-def test_fan_out_empty_list(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
- """Fan-out over an empty list → fan-in receives []."""
-
- @queue.task()
- def source() -> list:
- return []
-
- @queue.task()
- def process(x: int) -> int:
- return x * 2
-
- collected: list[Any] = []
-
- @queue.task()
- def aggregate(results: list) -> str:
- collected.extend(results)
- return "empty"
-
- wf = Workflow(name="fan_out_empty")
- wf.step("fetch", source)
- wf.step("process", process, after="fetch", fan_out="each")
- wf.step("collect", aggregate, after="process", fan_in="all")
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- final = run.wait(timeout=20)
-
- assert final.state == WorkflowState.COMPLETED
- assert collected == []
-
-
-def test_fan_out_single_item(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
- """Fan-out with a single-element list → 1 child → fan-in gets [result]."""
-
- @queue.task()
- def source() -> list[int]:
- return [42]
-
- @queue.task()
- def add_one(x: int) -> int:
- return x + 1
-
- collected: list[Any] = []
-
- @queue.task()
- def aggregate(results: list[int]) -> str:
- collected.extend(results)
- return "single"
-
- wf = Workflow(name="fan_out_single")
- wf.step("fetch", source)
- wf.step("process", add_one, after="fetch", fan_out="each")
- wf.step("collect", aggregate, after="process", fan_in="all")
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- final = run.wait(timeout=20)
-
- assert final.state == WorkflowState.COMPLETED
- assert collected == [43]
-
-
-def test_fan_out_with_downstream(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
- """Full pipeline: source → fan_out → fan_in → downstream static step."""
-
- order: list[str] = []
-
- @queue.task()
- def source() -> list[str]:
- order.append("source")
- return ["a", "b"]
-
- @queue.task()
- def process(item: str) -> str:
- order.append(f"process:{item}")
- return item.upper()
-
- @queue.task()
- def aggregate(results: list[str]) -> str:
- order.append("aggregate")
- return ",".join(sorted(results))
-
- @queue.task()
- def report() -> str:
- order.append("report")
- return "finished"
-
- wf = Workflow(name="downstream_pipe")
- wf.step("fetch", source)
- wf.step("process", process, after="fetch", fan_out="each")
- wf.step("agg", aggregate, after="process", fan_in="all")
- wf.step("report", report, after="agg")
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- final = run.wait(timeout=20)
-
- assert final.state == WorkflowState.COMPLETED
- assert "source" in order
- assert "aggregate" in order
- assert "report" in order
- # Source runs first, report runs last
- assert order.index("source") < order.index("aggregate")
- assert order.index("aggregate") < order.index("report")
-
-
-def test_fan_out_child_failure(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
- """A failing fan-out child triggers fail-fast, workflow fails."""
-
- @queue.task()
- def source() -> list[int]:
- return [1, 2, 3]
-
- @queue.task(max_retries=0)
- def maybe_fail(x: int) -> int:
- if x == 2:
- raise RuntimeError("boom on 2")
- return x * 10
-
- @queue.task()
- def aggregate(results: list[int]) -> str:
- return "should not run"
-
- wf = Workflow(name="fan_out_fail")
- wf.step("fetch", source)
- wf.step("process", maybe_fail, after="fetch", fan_out="each")
- wf.step("collect", aggregate, after="process", fan_in="all")
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- final = run.wait(timeout=20)
-
- assert final.state == WorkflowState.FAILED
- # The aggregate node should be skipped
- assert final.nodes["collect"].status == NodeStatus.SKIPPED
-
-
-def test_fan_out_source_failure(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
- """If the fan-out source step fails, all deferred nodes are SKIPPED."""
-
- @queue.task(max_retries=0)
- def bad_source() -> list[int]:
- raise RuntimeError("source failed")
-
- @queue.task()
- def process(x: int) -> int:
- return x * 2
-
- @queue.task()
- def aggregate(results: list[int]) -> str:
- return "nope"
-
- wf = Workflow(name="source_fail")
- wf.step("fetch", bad_source)
- wf.step("process", process, after="fetch", fan_out="each")
- wf.step("collect", aggregate, after="process", fan_in="all")
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- final = run.wait(timeout=20)
-
- assert final.state == WorkflowState.FAILED
- assert final.nodes["fetch"].status == NodeStatus.FAILED
- assert final.nodes["process"].status == NodeStatus.SKIPPED
- assert final.nodes["collect"].status == NodeStatus.SKIPPED
-
-
-def test_fan_out_cancellation(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
- """Cancelling a workflow mid-fan-out skips pending children."""
-
- @queue.task()
- def source() -> list[int]:
- return [1, 2, 3]
-
- @queue.task()
- def slow_process(x: int) -> int:
- time.sleep(10) # will be cancelled
- return x
-
- @queue.task()
- def aggregate(results: list[int]) -> str:
- return "nope"
-
- wf = Workflow(name="cancel_fan_out")
- wf.step("fetch", source)
- wf.step("process", slow_process, after="fetch", fan_out="each")
- wf.step("collect", aggregate, after="process", fan_in="all")
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- # Wait for source to complete and fan-out to expand
- time.sleep(2)
- run.cancel()
- snapshot = run.status()
-
- assert snapshot.state == WorkflowState.CANCELLED
-
-
-def test_fan_out_status_shows_children(
- queue: Queue, workflow_worker: WorkflowWorkerFactory
-) -> None:
- """status() returns child node snapshots like process[0], process[1]."""
-
- @queue.task()
- def source() -> list[int]:
- return [1, 2, 3]
-
- @queue.task()
- def process(x: int) -> int:
- return x
-
- @queue.task()
- def aggregate(results: list[int]) -> str:
- return "done"
-
- wf = Workflow(name="show_children")
- wf.step("fetch", source)
- wf.step("process", process, after="fetch", fan_out="each")
- wf.step("collect", aggregate, after="process", fan_in="all")
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- final = run.wait(timeout=20)
-
- assert final.state == WorkflowState.COMPLETED
- # Children should appear in the node map
- assert "process[0]" in final.nodes
- assert "process[1]" in final.nodes
- assert "process[2]" in final.nodes
- for i in range(3):
- assert final.nodes[f"process[{i}]"].status == NodeStatus.COMPLETED
- assert final.nodes[f"process[{i}]"].job_id is not None
-
-
-def test_fan_out_preserves_result_order(
- queue: Queue, workflow_worker: WorkflowWorkerFactory
-) -> None:
- """Fan-in results maintain the order of child indices."""
-
- @queue.task()
- def source() -> list[str]:
- return ["x", "y", "z"]
-
- @queue.task()
- def identity(item: str) -> str:
- return item
-
- collected: list[Any] = []
-
- @queue.task()
- def aggregate(results: list[str]) -> str:
- collected.extend(results)
- return "ok"
-
- wf = Workflow(name="order_check")
- wf.step("fetch", source)
- wf.step("process", identity, after="fetch", fan_out="each")
- wf.step("collect", aggregate, after="process", fan_in="all")
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- final = run.wait(timeout=20)
-
- assert final.state == WorkflowState.COMPLETED
- # Results should be in child index order (same as input order)
- assert collected == ["x", "y", "z"]
-
-
-def test_step_name_bracket_validation() -> None:
- """Step names containing '[' raise ValueError."""
- wf = Workflow(name="bad_name")
-
- class _FakeTask:
- _task_name = "fake"
-
- with pytest.raises(ValueError, match="must not contain"):
- wf.step("bad[0]", _FakeTask())
-
-
-def test_linear_workflow_still_works(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
- """Phase 2 regression: a linear workflow without fan-out still works."""
-
- order: list[str] = []
-
- @queue.task()
- def step_a() -> str:
- order.append("a")
- return "a"
-
- @queue.task()
- def step_b() -> str:
- order.append("b")
- return "b"
-
- @queue.task()
- def step_c() -> str:
- order.append("c")
- return "c"
-
- wf = Workflow(name="linear_regression")
- wf.step("a", step_a)
- wf.step("b", step_b, after="a")
- wf.step("c", step_c, after="b")
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- final = run.wait(timeout=15)
-
- assert final.state == WorkflowState.COMPLETED
- assert order == ["a", "b", "c"]
diff --git a/tests/test_workflows_gates.py b/tests/test_workflows_gates.py
deleted file mode 100644
index bcbf41b..0000000
--- a/tests/test_workflows_gates.py
+++ /dev/null
@@ -1,158 +0,0 @@
-"""Tests for Phase 5A approval gates."""
-
-from __future__ import annotations
-
-import threading
-import time
-from collections.abc import Callable
-from contextlib import AbstractContextManager
-
-from taskito import Queue
-from taskito.workflows import NodeStatus, Workflow, WorkflowState
-
-WorkflowWorkerFactory = Callable[[], AbstractContextManager[threading.Thread]]
-
-
-def test_gate_pauses_workflow(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
- """A gate node enters WAITING_APPROVAL and blocks downstream."""
-
- @queue.task()
- def ok_task() -> str:
- return "ok"
-
- wf = Workflow(name="gate_pause")
- wf.step("a", ok_task)
- wf.gate("approve", after="a")
- wf.step("b", ok_task, after="approve")
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- time.sleep(3) # Let "a" complete
- snapshot = run.status()
-
- assert snapshot.state == WorkflowState.RUNNING
- assert snapshot.nodes["a"].status == NodeStatus.COMPLETED
- assert snapshot.nodes["approve"].status == NodeStatus.WAITING_APPROVAL
- assert snapshot.nodes["b"].status == NodeStatus.PENDING
-
-
-def test_approve_gate_resumes(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
- """Approving a gate lets downstream steps run to completion."""
-
- collected: list[str] = []
-
- @queue.task()
- def ok_task() -> str:
- collected.append("ran")
- return "ok"
-
- wf = Workflow(name="gate_approve")
- wf.step("a", ok_task)
- wf.gate("approve", after="a")
- wf.step("b", ok_task, after="approve")
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- time.sleep(2) # Let "a" complete and gate enter WAITING_APPROVAL
- queue.approve_gate(run.id, "approve")
- final = run.wait(timeout=15)
-
- assert final.state == WorkflowState.COMPLETED
- assert final.nodes["approve"].status == NodeStatus.COMPLETED
- assert final.nodes["b"].status == NodeStatus.COMPLETED
- assert len(collected) == 2 # "a" and "b"
-
-
-def test_reject_gate_fails(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
- """Rejecting a gate fails it and skips downstream."""
-
- @queue.task()
- def ok_task() -> str:
- return "ok"
-
- wf = Workflow(name="gate_reject")
- wf.step("a", ok_task)
- wf.gate("approve", after="a")
- wf.step("b", ok_task, after="approve")
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- time.sleep(2)
- queue.reject_gate(run.id, "approve", error="not approved")
- final = run.wait(timeout=15)
-
- assert final.state == WorkflowState.FAILED
- assert final.nodes["approve"].status == NodeStatus.FAILED
- assert final.nodes["b"].status == NodeStatus.SKIPPED
-
-
-def test_gate_timeout_reject(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
- """Gate with timeout and on_timeout='reject' auto-rejects."""
-
- @queue.task()
- def ok_task() -> str:
- return "ok"
-
- wf = Workflow(name="gate_timeout_reject")
- wf.step("a", ok_task)
- wf.gate("approve", after="a", timeout=1.0, on_timeout="reject")
- wf.step("b", ok_task, after="approve")
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- final = run.wait(timeout=15)
-
- assert final.state == WorkflowState.FAILED
- assert final.nodes["approve"].status == NodeStatus.FAILED
- assert final.nodes["approve"].error is not None
- assert "timeout" in (final.nodes["approve"].error or "").lower()
-
-
-def test_gate_timeout_approve(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
- """Gate with on_timeout='approve' auto-approves and continues."""
-
- collected: list[str] = []
-
- @queue.task()
- def ok_task() -> str:
- collected.append("ran")
- return "ok"
-
- wf = Workflow(name="gate_timeout_approve")
- wf.step("a", ok_task)
- wf.gate("approve", after="a", timeout=1.0, on_timeout="approve")
- wf.step("b", ok_task, after="approve")
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- final = run.wait(timeout=15)
-
- assert final.state == WorkflowState.COMPLETED
- assert final.nodes["approve"].status == NodeStatus.COMPLETED
- assert final.nodes["b"].status == NodeStatus.COMPLETED
-
-
-def test_gate_with_condition(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
- """A gate with condition='on_success' respects predecessor state."""
-
- @queue.task(max_retries=0)
- def fail_task() -> str:
- raise RuntimeError("fail")
-
- @queue.task()
- def ok_task() -> str:
- return "ok"
-
- wf = Workflow(name="gate_condition")
- wf.step("a", fail_task)
- wf.gate("approve", after="a", condition="on_success")
- wf.step("b", ok_task, after="approve")
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- final = run.wait(timeout=15)
-
- assert final.state == WorkflowState.FAILED
- # Gate should be skipped because predecessor failed (condition=on_success)
- assert final.nodes["approve"].status == NodeStatus.SKIPPED
- assert final.nodes["b"].status == NodeStatus.SKIPPED
diff --git a/tests/test_workflows_linear.py b/tests/test_workflows_linear.py
deleted file mode 100644
index 27716a0..0000000
--- a/tests/test_workflows_linear.py
+++ /dev/null
@@ -1,348 +0,0 @@
-"""Tests for Phase 2 linear workflow execution."""
-
-from __future__ import annotations
-
-import threading
-import time
-from collections.abc import Callable
-from contextlib import AbstractContextManager
-
-import pytest
-
-from taskito import Queue
-from taskito.workflows import NodeStatus, Workflow, WorkflowState
-from taskito.workflows.run import WorkflowTimeoutError
-
-WorkflowWorkerFactory = Callable[[], AbstractContextManager[threading.Thread]]
-
-
-def test_linear_three_step_workflow(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
- """A→B→C runs in order and the workflow reaches COMPLETED."""
-
- order: list[str] = []
-
- @queue.task()
- def step_a() -> str:
- order.append("a")
- return "a-done"
-
- @queue.task()
- def step_b() -> str:
- order.append("b")
- return "b-done"
-
- @queue.task()
- def step_c() -> str:
- order.append("c")
- return "c-done"
-
- wf = Workflow(name="linear_pipe")
- wf.step("a", step_a)
- wf.step("b", step_b, after="a")
- wf.step("c", step_c, after="b")
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- final = run.wait(timeout=15)
-
- assert final.state == WorkflowState.COMPLETED
- assert order == ["a", "b", "c"]
- assert all(n.status == NodeStatus.COMPLETED for n in final.nodes.values())
- assert set(final.nodes.keys()) == {"a", "b", "c"}
-
-
-def test_workflow_with_args_and_kwargs(
- queue: Queue, workflow_worker: WorkflowWorkerFactory
-) -> None:
- """Step args and kwargs round-trip through the queue serializer."""
-
- received: list[tuple] = []
-
- @queue.task()
- def collect(x: int, y: int, *, label: str) -> int:
- received.append((x, y, label))
- return x + y
-
- wf = Workflow(name="args_pipe")
- wf.step("first", collect, args=(2, 3), kwargs={"label": "a"})
- wf.step("second", collect, args=(10, 20), kwargs={"label": "b"}, after="first")
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- final = run.wait(timeout=15)
-
- assert final.state == WorkflowState.COMPLETED
- assert (2, 3, "a") in received
- assert (10, 20, "b") in received
-
-
-def test_workflow_decorator_registration(
- queue: Queue, workflow_worker: WorkflowWorkerFactory
-) -> None:
- """@queue.workflow() stores a proxy that can build and submit."""
-
- @queue.task()
- def noop() -> None:
- return None
-
- @queue.workflow("nightly")
- def build() -> Workflow:
- wf = Workflow()
- wf.step("x", noop)
- return wf
-
- assert "nightly" in queue._workflow_registry
- built = build.build()
- assert built.name == "nightly"
- assert built.step_names == ["x"]
-
- with workflow_worker():
- run = build.submit()
- final = run.wait(timeout=10)
-
- assert final.state == WorkflowState.COMPLETED
-
-
-def test_workflow_status_before_completion(queue: Queue) -> None:
- """status() reflects non-terminal state before the workflow finishes."""
-
- @queue.task()
- def noop() -> None:
- return None
-
- wf = Workflow(name="status_check")
- wf.step("only", noop)
-
- run = queue.submit_workflow(wf)
- snapshot = run.status()
- # No worker running, so the run stays in RUNNING with the node PENDING
- assert snapshot.state == WorkflowState.RUNNING
- assert snapshot.nodes["only"].status == NodeStatus.PENDING
- assert snapshot.nodes["only"].job_id is not None
-
-
-def test_workflow_wait_timeout(queue: Queue) -> None:
- """wait() raises WorkflowTimeoutError if the workflow doesn't finish in time."""
-
- @queue.task()
- def noop() -> None:
- return None
-
- wf = Workflow(name="timeout_test")
- wf.step("only", noop)
-
- run = queue.submit_workflow(wf)
- # No worker running → timeout
- with pytest.raises(WorkflowTimeoutError):
- run.wait(timeout=0.3)
-
-
-def test_workflow_cancellation(queue: Queue) -> None:
- """Cancelling a workflow marks pending nodes SKIPPED and the run CANCELLED."""
-
- @queue.task()
- def noop() -> None:
- return None
-
- wf = Workflow(name="cancel_test")
- wf.step("a", noop)
- wf.step("b", noop, after="a")
- wf.step("c", noop, after="b")
-
- run = queue.submit_workflow(wf)
- run.cancel()
-
- snapshot = run.status()
- assert snapshot.state == WorkflowState.CANCELLED
- for node in snapshot.nodes.values():
- assert node.status == NodeStatus.SKIPPED
-
-
-def test_workflow_failing_step(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
- """A failing step fails the workflow and skips downstream steps."""
-
- @queue.task(max_retries=0)
- def good() -> str:
- return "ok"
-
- @queue.task(max_retries=0)
- def boom() -> str:
- raise RuntimeError("kaboom")
-
- wf = Workflow(name="failing")
- wf.step("a", good)
- wf.step("b", boom, after="a")
- wf.step("c", good, after="b")
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- final = run.wait(timeout=15)
-
- assert final.state == WorkflowState.FAILED
- assert final.nodes["a"].status == NodeStatus.COMPLETED
- assert final.nodes["b"].status == NodeStatus.FAILED
- assert final.nodes["c"].status == NodeStatus.SKIPPED
-
-
-def test_workflow_node_snapshot_fields(queue: Queue) -> None:
- """Each node snapshot has name, status, job_id, error."""
-
- @queue.task()
- def noop() -> None:
- return None
-
- wf = Workflow(name="snapshot_check")
- wf.step("one", noop)
- wf.step("two", noop, after="one")
-
- run = queue.submit_workflow(wf)
- snapshot = run.status()
- assert "one" in snapshot.nodes
- assert "two" in snapshot.nodes
- for name, node in snapshot.nodes.items():
- assert node.name == name
- assert node.status == NodeStatus.PENDING
- assert node.job_id is not None
- assert node.error is None
-
-
-def test_workflow_node_status_helper(queue: Queue) -> None:
- """node_status() returns the status of a specific node."""
-
- @queue.task()
- def noop() -> None:
- return None
-
- wf = Workflow(name="helper_check")
- wf.step("one", noop)
-
- run = queue.submit_workflow(wf)
- assert run.node_status("one") == NodeStatus.PENDING
-
- with pytest.raises(KeyError):
- run.node_status("nonexistent")
-
-
-def test_workflow_step_ordering_validation() -> None:
- """step() raises ValueError if after references an unknown predecessor."""
- wf = Workflow(name="invalid")
-
- class _FakeTask:
- _task_name = "fake"
-
- wf.step("a", _FakeTask())
- with pytest.raises(ValueError, match="predecessor 'missing'"):
- wf.step("b", _FakeTask(), after="missing")
-
-
-def test_workflow_duplicate_step_name() -> None:
- """step() raises ValueError if the same name is added twice."""
- wf = Workflow(name="dup")
-
- class _FakeTask:
- _task_name = "fake"
-
- wf.step("a", _FakeTask())
- with pytest.raises(ValueError, match="already defined"):
- wf.step("a", _FakeTask())
-
-
-def test_workflow_definition_reuse(queue: Queue) -> None:
- """Submitting the same workflow name+version reuses the definition row."""
-
- @queue.task()
- def noop() -> None:
- return None
-
- wf1 = Workflow(name="reused", version=1)
- wf1.step("only", noop)
- wf2 = Workflow(name="reused", version=1)
- wf2.step("only", noop)
-
- run1 = queue.submit_workflow(wf1)
- run2 = queue.submit_workflow(wf2)
- assert run1.id != run2.id
- # Definitions share an ID via name+version uniqueness
- # (verified indirectly: second submit succeeds without duplicate-key errors)
-
-
-def test_workflow_emits_completed_event(
- queue: Queue, workflow_worker: WorkflowWorkerFactory
-) -> None:
- """WORKFLOW_COMPLETED event fires on successful run completion."""
- from taskito.events import EventType
-
- @queue.task()
- def noop() -> None:
- return None
-
- events: list[dict] = []
- event_received = threading.Event()
-
- def listener(_event_type: EventType, payload: dict) -> None:
- events.append(payload)
- event_received.set()
-
- queue._event_bus.on(EventType.WORKFLOW_COMPLETED, listener)
-
- wf = Workflow(name="event_pipe")
- wf.step("x", noop)
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- run.wait(timeout=10)
- # Give event bus thread a moment to dispatch
- event_received.wait(timeout=5)
-
- assert any(e.get("run_id") == run.id for e in events)
- assert any(e.get("state") == "completed" for e in events)
-
-
-def test_workflow_run_repr(queue: Queue) -> None:
- """WorkflowRun __repr__ is informative."""
-
- @queue.task()
- def noop() -> None:
- return None
-
- wf = Workflow(name="repr_test")
- wf.step("x", noop)
-
- run = queue.submit_workflow(wf)
- r = repr(run)
- assert run.id in r
- assert "repr_test" in r
-
-
-@pytest.mark.asyncio
-async def test_workflow_async_step(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
- """An async @queue.task() step works inside a workflow."""
-
- @queue.task()
- async def async_step() -> str:
- return "async-ok"
-
- @queue.task()
- def sync_step() -> str:
- return "sync-ok"
-
- wf = Workflow(name="async_mix")
- wf.step("a", async_step)
- wf.step("b", sync_step, after="a")
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- # Poll instead of blocking wait to play nicely with asyncio
- deadline = time.monotonic() + 15
- final = run.status()
- while not final.state.is_terminal() and time.monotonic() < deadline:
- await _async_sleep(0.1)
- final = run.status()
-
- assert final.state == WorkflowState.COMPLETED
-
-
-async def _async_sleep(seconds: float) -> None:
- import asyncio
-
- await asyncio.sleep(seconds)
diff --git a/tests/test_workflows_subworkflow.py b/tests/test_workflows_subworkflow.py
deleted file mode 100644
index 7629a38..0000000
--- a/tests/test_workflows_subworkflow.py
+++ /dev/null
@@ -1,194 +0,0 @@
-"""Tests for Phase 5B sub-workflows."""
-
-from __future__ import annotations
-
-import threading
-from collections.abc import Callable
-from contextlib import AbstractContextManager
-
-from taskito import Queue
-from taskito.workflows import NodeStatus, Workflow, WorkflowState
-
-WorkflowWorkerFactory = Callable[[], AbstractContextManager[threading.Thread]]
-
-
-def test_sub_workflow_executes(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
- """A sub-workflow step runs the child workflow, parent continues after."""
-
- order: list[str] = []
-
- @queue.task()
- def extract() -> str:
- order.append("extract")
- return "data"
-
- @queue.task()
- def load() -> str:
- order.append("load")
- return "loaded"
-
- @queue.task()
- def report() -> str:
- order.append("report")
- return "done"
-
- @queue.workflow("etl")
- def etl_pipeline() -> Workflow:
- wf = Workflow()
- wf.step("extract", extract)
- wf.step("load", load, after="extract")
- return wf
-
- wf = Workflow(name="parent")
- wf.step("etl", etl_pipeline.as_step())
- wf.step("report", report, after="etl")
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- final = run.wait(timeout=20)
-
- assert final.state == WorkflowState.COMPLETED
- assert "extract" in order
- assert "load" in order
- assert "report" in order
- assert order.index("load") < order.index("report")
-
-
-def test_sub_workflow_failure(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
- """A failing sub-workflow fails the parent node."""
-
- @queue.task(max_retries=0)
- def fail_task() -> str:
- raise RuntimeError("sub failed")
-
- @queue.task()
- def ok_task() -> str:
- return "ok"
-
- @queue.workflow("failing_sub")
- def failing_sub() -> Workflow:
- wf = Workflow()
- wf.step("boom", fail_task)
- return wf
-
- wf = Workflow(name="parent_fail")
- wf.step("sub", failing_sub.as_step())
- wf.step("after", ok_task, after="sub")
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- final = run.wait(timeout=20)
-
- assert final.state == WorkflowState.FAILED
- assert final.nodes["sub"].status == NodeStatus.FAILED
- assert final.nodes["after"].status == NodeStatus.SKIPPED
-
-
-def test_sub_workflow_compile_failure_marks_parent_failed(
- queue: Queue, workflow_worker: WorkflowWorkerFactory
-) -> None:
- """Regression: a factory that raises during `build()` must not leave the
- parent node Skipped forever — it must be marked Failed so the outer run
- can finalize. Before the fix, the tracker called `skip_workflow_node`
- on the parent before attempting compile, and a compile failure left the
- node Skipped permanently."""
-
- @queue.task()
- def downstream() -> str:
- return "should not run"
-
- @queue.workflow("broken_sub")
- def broken_sub() -> Workflow:
- raise RuntimeError("factory blew up")
-
- wf = Workflow(name="parent_compile_fail")
- wf.step("sub", broken_sub.as_step())
- wf.step("after", downstream, after="sub")
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- final = run.wait(timeout=15)
-
- assert final.state == WorkflowState.FAILED, (
- f"outer run must finalize as FAILED, got {final.state}"
- )
- assert final.nodes["sub"].status == NodeStatus.FAILED, (
- f"sub-workflow parent must be FAILED (was {final.nodes['sub'].status}) — "
- "the old bug left it SKIPPED"
- )
- assert final.nodes["after"].status == NodeStatus.SKIPPED
-
-
-def test_cancel_parent_cascades(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
- """Cancelling a parent workflow cancels the child sub-workflow too."""
-
- import time
-
- @queue.task()
- def slow_task() -> str:
- time.sleep(30)
- return "slow"
-
- @queue.workflow("slow_sub")
- def slow_sub() -> Workflow:
- wf = Workflow()
- wf.step("slow", slow_task)
- return wf
-
- wf = Workflow(name="parent_cancel")
- wf.step("sub", slow_sub.as_step())
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- time.sleep(2) # Let sub-workflow submit
- run.cancel()
- snapshot = run.status()
-
- assert snapshot.state == WorkflowState.CANCELLED
-
-
-def test_parallel_sub_workflows(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
- """Two sub-workflows can run concurrently."""
-
- order: list[str] = []
-
- @queue.task()
- def task_a() -> str:
- order.append("a")
- return "a"
-
- @queue.task()
- def task_b() -> str:
- order.append("b")
- return "b"
-
- @queue.task()
- def reconcile() -> str:
- order.append("reconcile")
- return "done"
-
- @queue.workflow("sub_a")
- def sub_a() -> Workflow:
- wf = Workflow()
- wf.step("a", task_a)
- return wf
-
- @queue.workflow("sub_b")
- def sub_b() -> Workflow:
- wf = Workflow()
- wf.step("b", task_b)
- return wf
-
- wf = Workflow(name="parallel_parent")
- wf.step("sa", sub_a.as_step())
- wf.step("sb", sub_b.as_step())
- wf.step("reconcile", reconcile, after=["sa", "sb"])
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- final = run.wait(timeout=20)
-
- assert final.state == WorkflowState.COMPLETED
- assert "a" in order
- assert "b" in order
- assert "reconcile" in order
diff --git a/tests/test_workflows_visualization.py b/tests/test_workflows_visualization.py
deleted file mode 100644
index ebce1c3..0000000
--- a/tests/test_workflows_visualization.py
+++ /dev/null
@@ -1,96 +0,0 @@
-"""Tests for Phase 8 workflow visualization."""
-
-from __future__ import annotations
-
-import threading
-from collections.abc import Callable
-from contextlib import AbstractContextManager
-
-from taskito import Queue
-from taskito.workflows import Workflow
-
-WorkflowWorkerFactory = Callable[[], AbstractContextManager[threading.Thread]]
-
-
-class _FakeTask:
- _task_name = "fake"
-
-
-def test_mermaid_linear() -> None:
- """Linear DAG renders correct Mermaid graph."""
- wf = Workflow(name="linear")
- wf.step("a", _FakeTask())
- wf.step("b", _FakeTask(), after="a")
- wf.step("c", _FakeTask(), after="b")
-
- output = wf.visualize("mermaid")
- assert "graph LR" in output
- assert "a[a]" in output
- assert "b[b]" in output
- assert "c[c]" in output
- assert "a --> b" in output
- assert "b --> c" in output
-
-
-def test_mermaid_diamond() -> None:
- """Diamond DAG with parallel nodes."""
- wf = Workflow(name="diamond")
- wf.step("a", _FakeTask())
- wf.step("b", _FakeTask(), after="a")
- wf.step("c", _FakeTask(), after="a")
- wf.step("d", _FakeTask(), after=["b", "c"])
-
- output = wf.visualize("mermaid")
- assert "a --> b" in output
- assert "a --> c" in output
- assert "b --> d" in output
- assert "c --> d" in output
-
-
-def test_mermaid_with_status() -> None:
- """Mermaid output with status colors."""
- from taskito.workflows.visualization import render_mermaid
-
- output = render_mermaid(
- nodes=["a", "b", "c"],
- edges=[("a", "b"), ("b", "c")],
- statuses={"a": "completed", "b": "failed", "c": "pending"},
- )
- assert "style a fill:#90EE90" in output # green
- assert "style b fill:#FFB6C1" in output # red
- assert "style c fill:#D3D3D3" in output # gray
-
-
-def test_dot_linear() -> None:
- """DOT format output for linear DAG."""
- wf = Workflow(name="linear")
- wf.step("a", _FakeTask())
- wf.step("b", _FakeTask(), after="a")
-
- output = wf.visualize("dot")
- assert "digraph workflow" in output
- assert "rankdir=LR" in output
- assert "a -> b" in output
-
-
-def test_visualize_live_run(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None:
- """WorkflowRun.visualize() shows live statuses."""
-
- @queue.task()
- def ok_task() -> str:
- return "ok"
-
- wf = Workflow(name="viz_live")
- wf.step("a", ok_task)
- wf.step("b", ok_task, after="a")
-
- with workflow_worker():
- run = queue.submit_workflow(wf)
- final = run.wait(timeout=15)
- output = run.visualize("mermaid")
-
- assert final.state.value == "completed"
- assert "graph LR" in output
- assert "a --> b" in output
- # Both nodes should have completed status styling
- assert "#90EE90" in output # green for completed