diff --git a/tests/prefork_apps/__init__.py b/tests/prefork_apps/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/prefork_apps/cancel_app.py b/tests/prefork_apps/cancel_app.py deleted file mode 100644 index 20bf770..0000000 --- a/tests/prefork_apps/cancel_app.py +++ /dev/null @@ -1,33 +0,0 @@ -"""Module-level Queue + tasks used by the prefork cancel regression tests. - -The Queue inside this module must be importable both in the parent test -process and inside each prefork child interpreter. The DB path comes from -``TASKITO_CANCEL_TEST_DB`` so each test run can use its own tmp file while -still letting the parent and child build identical Queue instances from -the same module path. -""" - -from __future__ import annotations - -import os -import time - -from taskito import Queue -from taskito.context import current_job - -queue = Queue(db_path=os.environ.get("TASKITO_CANCEL_TEST_DB", "/tmp/taskito-cancel.db")) - - -@queue.task(timeout=30, max_retries=0) -def cooperative_loop(max_iters: int = 600) -> int: - """Loop calling ``check_cancelled()`` so cancel can stop the task quickly.""" - for _ in range(max_iters): - current_job.check_cancelled() - time.sleep(0.05) - return max_iters - - -@queue.task(max_retries=0) -def quick(x: int) -> int: - """Returns immediately — used to verify the child still serves jobs after a cancel.""" - return x * 2 diff --git a/tests/prefork_apps/timeout_app.py b/tests/prefork_apps/timeout_app.py deleted file mode 100644 index b546e34..0000000 --- a/tests/prefork_apps/timeout_app.py +++ /dev/null @@ -1,41 +0,0 @@ -"""Module-level Queue + tasks used by the prefork timeout regression tests. - -Prefork children import the app module independently, so the task registry -must live at a module path resolvable inside the child interpreter — that's -why this module exists as a sibling of ``test_prefork.py`` rather than being -defined inline in the test. - -The DB path comes from ``TASKITO_TIMEOUT_TEST_DB`` so each test run can use a -unique tmp file while still letting the parent and child build identical -Queue instances from this same module. -""" - -from __future__ import annotations - -import os -import time - -from taskito import Queue - -queue = Queue(db_path=os.environ.get("TASKITO_TIMEOUT_TEST_DB", "/tmp/taskito-timeout.db")) - - -@queue.task(timeout=2, max_retries=0) -def hang() -> None: - """Spin forever — used to trigger the watchdog's SIGKILL path.""" - while True: - pass - - -@queue.task() -def quick(x: int) -> int: - """Returns immediately — used to verify timeout=0 (no timeout) is unaffected.""" - return x * 2 - - -@queue.task(timeout=2, max_retries=0) -def sleep_then_finish(seconds: float) -> str: - """Sleeps for `seconds`, then finishes — used to verify the watchdog only - fires when the deadline is actually exceeded.""" - time.sleep(seconds) - return "done" diff --git a/tests/test_basic.py b/tests/test_basic.py deleted file mode 100644 index 9a42053..0000000 --- a/tests/test_basic.py +++ /dev/null @@ -1,140 +0,0 @@ -"""Basic tests for taskito — enqueue, dequeue, result retrieval.""" - -import threading - -from taskito import Queue - - -def test_task_registration(queue: Queue) -> None: - """Tasks can be registered with the decorator.""" - - @queue.task() - def add(a: int, b: int) -> int: - return a + b - - assert add.name.endswith("add") - assert add.name in queue._task_registry - - -def test_enqueue_returns_job_result(queue: Queue) -> None: - """Enqueueing a task returns a JobResult handle.""" - - @queue.task() - def noop() -> None: - pass - - job = noop.delay() - assert job.id is not None - assert len(job.id) > 0 - - -def test_task_direct_call(queue: Queue) -> None: - """Decorated tasks can still be called directly.""" - - @queue.task() - def multiply(a: int, b: int) -> int: - return a * b - - assert multiply(3, 4) == 12 - - -def test_apply_async_with_delay(queue: Queue) -> None: - """apply_async accepts a delay parameter.""" - - @queue.task() - def slow_task() -> None: - pass - - job = slow_task.apply_async(delay=60) - assert job.id is not None - - -def test_apply_async_with_overrides(queue: Queue) -> None: - """apply_async can override default task settings.""" - - @queue.task(priority=1, queue="default") - def configurable_task(x: int) -> int: - return x - - job = configurable_task.apply_async( - args=(42,), - priority=10, - queue="urgent", - max_retries=5, - timeout=600, - ) - assert job.id is not None - - -def test_queue_stats(queue: Queue) -> None: - """stats() returns counts by status.""" - - @queue.task() - def sample_task() -> None: - pass - - sample_task.delay() - sample_task.delay() - - stats = queue.stats() - assert stats["pending"] == 2 - assert stats["running"] == 0 - - -def test_worker_executes_task(queue: Queue) -> None: - """Worker processes tasks and stores results.""" - - @queue.task() - def add(a: int, b: int) -> int: - return a + b - - job = add.delay(2, 3) - - # Run worker in a background thread - worker_thread = threading.Thread( - target=queue.run_worker, - daemon=True, - ) - worker_thread.start() - - # Wait for result - result = job.result(timeout=10) - assert result == 5 - - -def test_worker_handles_kwargs(queue: Queue) -> None: - """Worker correctly passes keyword arguments.""" - - @queue.task() - def greet(name: str, greeting: str = "Hello") -> str: - return f"{greeting}, {name}!" - - job = greet.delay("World", greeting="Hi") - - worker_thread = threading.Thread( - target=queue.run_worker, - daemon=True, - ) - worker_thread.start() - - result = job.result(timeout=10) - assert result == "Hi, World!" - - -def test_worker_none_result(queue: Queue) -> None: - """Tasks returning None work correctly.""" - - @queue.task() - def void_task() -> None: - pass - - job = void_task.delay() - - worker_thread = threading.Thread( - target=queue.run_worker, - daemon=True, - ) - worker_thread.start() - - result = job.result(timeout=10) - assert result is None diff --git a/tests/test_batch.py b/tests/test_batch.py deleted file mode 100644 index 70cec6c..0000000 --- a/tests/test_batch.py +++ /dev/null @@ -1,146 +0,0 @@ -"""Tests for batch enqueue (enqueue_many / task.map).""" - -import threading -from typing import Any - -from taskito import Queue -from taskito.middleware import TaskMiddleware - - -def test_enqueue_many(queue: Queue) -> None: - """enqueue_many enqueues all items in a single batch.""" - - @queue.task() - def double(x: int) -> int: - return x * 2 - - jobs = queue.enqueue_many( - task_name=double.name, - args_list=[(i,) for i in range(10)], - ) - assert len(jobs) == 10 - - stats = queue.stats() - assert stats["pending"] == 10 - - -def test_task_map(queue: Queue) -> None: - """TaskWrapper.map() enqueues and returns results.""" - - @queue.task() - def add(a: int, b: int) -> int: - return a + b - - jobs = add.map([(1, 2), (3, 4), (5, 6)]) - assert len(jobs) == 3 - - worker_thread = threading.Thread(target=queue.run_worker, daemon=True) - worker_thread.start() - - results = [j.result(timeout=10) for j in jobs] - assert sorted(results) == [3, 7, 11] - - -def test_batch_stats(queue: Queue) -> None: - """Batch enqueue of 50 items shows correct pending count.""" - - @queue.task() - def noop() -> None: - pass - - queue.enqueue_many( - task_name=noop.name, - args_list=[() for _ in range(50)], - ) - - stats = queue.stats() - assert stats["pending"] == 50 - - -def test_enqueue_many_invokes_on_enqueue_per_job(tmp_path: Any) -> None: - """`on_enqueue` middleware must receive each job's own args/kwargs. - - Regression: the previous implementation always passed `args_list[0]` - and a fresh empty options dict to every middleware call, so middleware - could not distinguish jobs in the batch. - """ - - class RecordingMiddleware(TaskMiddleware): - def __init__(self) -> None: - self.calls: list[tuple[tuple, dict]] = [] - - def on_enqueue(self, task_name: str, args: tuple, kwargs: dict, options: dict) -> None: - self.calls.append((args, dict(kwargs))) - - mw = RecordingMiddleware() - q = Queue(db_path=str(tmp_path / "test.db"), middleware=[mw]) - - @q.task() - def add(a: int, b: int) -> int: - return a + b - - q.enqueue_many( - task_name=add.name, - args_list=[(1, 2), (3, 4), (5, 6)], - kwargs_list=[{"trace": "alpha"}, {"trace": "beta"}, {"trace": "gamma"}], - ) - - assert mw.calls == [ - ((1, 2), {"trace": "alpha"}), - ((3, 4), {"trace": "beta"}), - ((5, 6), {"trace": "gamma"}), - ] - - -def test_enqueue_many_applies_option_mutations(tmp_path: Any) -> None: - """Mutations to the options dict inside on_enqueue must propagate to the - enqueued jobs — matching the documented behaviour of single-enqueue. - Regression: the previous implementation discarded mutations because the - hook ran *after* `enqueue_batch` and against a fresh empty dict. - """ - - class PerJobBoosterMiddleware(TaskMiddleware): - def on_enqueue(self, task_name: str, args: tuple, kwargs: dict, options: dict) -> None: - # Bump priority based on the first argument so each job sees a - # distinct mutation. - options["priority"] = int(args[0]) * 10 - - q = Queue(db_path=str(tmp_path / "test.db"), middleware=[PerJobBoosterMiddleware()]) - - @q.task() - def task_one(n: int) -> int: - return n - - results = q.enqueue_many(task_name=task_one.name, args_list=[(1,), (2,), (3,)]) - - priorities = [q.get_job(r.id).to_dict()["priority"] for r in results] # type: ignore[union-attr] - assert priorities == [10, 20, 30] - - -def test_enqueue_many_logs_middleware_exceptions(tmp_path: Any, caplog: Any) -> None: - """Middleware exceptions must be logged, not silently swallowed. - - Regression: the previous implementation used a bare `except: pass`, - making misbehaving middleware effectively invisible. - """ - import logging - - class ExplodingMiddleware(TaskMiddleware): - def on_enqueue(self, task_name: str, args: tuple, kwargs: dict, options: dict) -> None: - raise RuntimeError("middleware boom") - - q = Queue(db_path=str(tmp_path / "test.db"), middleware=[ExplodingMiddleware()]) - - @q.task() - def my_task(n: int) -> int: - return n - - with caplog.at_level(logging.ERROR, logger="taskito.app"): - results = q.enqueue_many(task_name=my_task.name, args_list=[(1,), (2,)]) - - # Jobs are still enqueued — middleware errors must not block enqueue - assert len(results) == 2 - - # And the error was surfaced via the logger - assert any("middleware on_enqueue() error" in rec.message for rec in caplog.records) - assert any("middleware boom" in (rec.exc_text or "") for rec in caplog.records) diff --git a/tests/test_cancel.py b/tests/test_cancel.py deleted file mode 100644 index 42c09e3..0000000 --- a/tests/test_cancel.py +++ /dev/null @@ -1,62 +0,0 @@ -"""Tests for job cancellation.""" - -from __future__ import annotations - -import threading - -from taskito import Queue - - -def test_cancel_pending_job(queue: Queue) -> None: - """A pending job can be cancelled.""" - - @queue.task() - def slow_task() -> str: - return "done" - - job = slow_task.delay() - assert queue.cancel_job(job.id) is True - - refreshed = queue.get_job(job.id) - assert refreshed is not None - assert refreshed.status == "cancelled" - - -def test_cancel_nonexistent_job(queue: Queue) -> None: - """Cancelling a nonexistent job returns False.""" - - @queue.task() - def dummy() -> None: - pass - - assert queue.cancel_job("nonexistent-id") is False - - -def test_cancel_completed_job(queue: Queue) -> None: - """Cancelling a completed job returns False (only pending can be cancelled).""" - - @queue.task() - def quick_task() -> int: - return 42 - - job = quick_task.delay() - - worker_thread = threading.Thread(target=queue.run_worker, daemon=True) - worker_thread.start() - - job.result(timeout=10) - assert queue.cancel_job(job.id) is False - - -def test_cancelled_in_stats(queue: Queue) -> None: - """Cancelled jobs show up in stats.""" - - @queue.task() - def task_a() -> None: - pass - - job = task_a.delay() - queue.cancel_job(job.id) - - stats = queue.stats() - assert stats["cancelled"] == 1 diff --git a/tests/test_chain.py b/tests/test_chain.py deleted file mode 100644 index a99a862..0000000 --- a/tests/test_chain.py +++ /dev/null @@ -1,82 +0,0 @@ -"""Tests for task chaining (chain, group, chord).""" - -from __future__ import annotations - -import threading -from pathlib import Path - -import pytest - -from taskito import Queue, chain, chord, group - - -@pytest.fixture -def queue(tmp_path: Path) -> Queue: - db_path = str(tmp_path / "test_chain.db") - q = Queue(db_path=db_path, workers=4) - - # Start worker in background - t = threading.Thread(target=q.run_worker, daemon=True) - t.start() - - return q - - -def test_chain_executes_in_order(queue: Queue) -> None: - """chain pipes results through signatures in order.""" - - @queue.task() - def add(a: int, b: int) -> int: - return a + b - - @queue.task() - def double(x: int) -> int: - return x * 2 - - result = chain(add.s(2, 3), double.s()) - last_job = result.apply(queue) - assert last_job.result(timeout=30) == 10 # (2+3) * 2 = 10 - - -def test_chain_with_immutable(queue: Queue) -> None: - """si() signatures ignore previous results.""" - - @queue.task() - def add(a: int, b: int) -> int: - return a + b - - @queue.task() - def constant() -> int: - return 99 - - result = chain(add.s(1, 2), constant.si()) - last_job = result.apply(queue) - assert last_job.result(timeout=30) == 99 - - -def test_group_parallel(queue: Queue) -> None: - """group enqueues tasks in parallel.""" - - @queue.task() - def square(x: int) -> int: - return x * x - - jobs = group(square.s(2), square.s(3), square.s(4)).apply(queue) - results = [j.result(timeout=30) for j in jobs] - assert sorted(results) == [4, 9, 16] - - -def test_chord_callback(queue: Queue) -> None: - """chord runs group, then callback with collected results.""" - - @queue.task() - def add(a: int, b: int) -> int: - return a + b - - @queue.task() - def total(results: list[int]) -> int: - return sum(results) - - grp = group(add.s(1, 2), add.s(3, 4), add.s(5, 6)) - result_job = chord(grp, total.s()).apply(queue) - assert result_job.result(timeout=30) == 21 # 3 + 7 + 11 = 21 diff --git a/tests/test_cli.py b/tests/test_cli.py deleted file mode 100644 index 174e079..0000000 --- a/tests/test_cli.py +++ /dev/null @@ -1,41 +0,0 @@ -"""Tests for CLI info command.""" - -from pathlib import Path - -import pytest - -from taskito.cli import _load_queue, _print_stats - - -def test_load_queue_invalid_format() -> None: - """_load_queue rejects paths without a colon.""" - with pytest.raises(SystemExit): - _load_queue("no_colon_here") - - -def test_load_queue_missing_module() -> None: - """_load_queue exits on missing module.""" - with pytest.raises(SystemExit): - _load_queue("nonexistent.module:queue") - - -def test_print_stats_format(capsys: pytest.CaptureFixture[str], tmp_path: Path) -> None: - """_print_stats prints a formatted stats table.""" - from taskito import Queue - - db_path = str(tmp_path / "test_cli_stats.db") - queue = Queue(db_path=db_path) - - @queue.task() - def noop() -> None: - pass - - noop.delay() - noop.delay() - - _print_stats(queue) - output = capsys.readouterr().out - - assert "taskito queue statistics" in output - assert "pending" in output - assert "total" in output diff --git a/tests/test_context.py b/tests/test_context.py deleted file mode 100644 index b76f284..0000000 --- a/tests/test_context.py +++ /dev/null @@ -1,65 +0,0 @@ -"""Tests for job context — current_job inside running tasks.""" - -import threading -from pathlib import Path -from typing import Any - -import pytest - -from taskito import Queue -from taskito.context import current_job - - -@pytest.fixture -def queue(tmp_path: Path) -> Queue: - db_path = str(tmp_path / "test_context.db") - return Queue(db_path=db_path, workers=1) - - -def test_current_job_raises_outside_task() -> None: - """current_job properties raise RuntimeError outside a task.""" - with pytest.raises(RuntimeError, match="No active job context"): - _ = current_job.id - - -def test_current_job_id_available_in_task(queue: Queue) -> None: - """current_job.id is accessible inside a running task.""" - captured: dict[str, Any] = {} - - @queue.task() - def capture_context() -> str: - captured["id"] = current_job.id - captured["task_name"] = current_job.task_name - captured["retry_count"] = current_job.retry_count - captured["queue_name"] = current_job.queue_name - return "ok" - - job = capture_context.delay() - - worker_thread = threading.Thread(target=queue.run_worker, daemon=True) - worker_thread.start() - - result = job.result(timeout=10) - assert result == "ok" - assert captured["id"] == job.id - assert captured["task_name"].endswith("capture_context") - assert captured["retry_count"] == 0 - assert captured["queue_name"] == "default" - - -def test_current_job_update_progress(queue: Queue) -> None: - """current_job.update_progress() works inside a running task.""" - - @queue.task() - def task_with_progress() -> str: - current_job.update_progress(50) - current_job.update_progress(100) - return "done" - - job = task_with_progress.delay() - - worker_thread = threading.Thread(target=queue.run_worker, daemon=True) - worker_thread.start() - - result = job.result(timeout=10) - assert result == "done" diff --git a/tests/test_contrib.py b/tests/test_contrib.py deleted file mode 100644 index 5df8773..0000000 --- a/tests/test_contrib.py +++ /dev/null @@ -1,325 +0,0 @@ -"""Tests for contrib middleware modules using mocks (no hard dependencies).""" - -from __future__ import annotations - -import types -from typing import Any -from unittest.mock import MagicMock, patch - -# ── Helpers ────────────────────────────────────────────────────────── - - -def _make_ctx( - job_id: str = "job-1", - task_name: str = "my_task", - queue_name: str = "default", - retry_count: int = 0, -) -> MagicMock: - ctx = MagicMock() - ctx.id = job_id - ctx.task_name = task_name - ctx.queue_name = queue_name - ctx.retry_count = retry_count - return ctx - - -# ── OpenTelemetry ──────────────────────────────────────────────────── - - -class TestOpenTelemetryMiddleware: - def test_before_starts_span(self) -> None: - otel = _try_import_otel() - if otel is None: - return - - mock_tracer = MagicMock() - mock_span = MagicMock() - mock_tracer.start_span.return_value = mock_span - - mw = otel.OpenTelemetryMiddleware.__new__(otel.OpenTelemetryMiddleware) - import threading - - mw._tracer = mock_tracer - mw._span_name_fn = None - mw._attr_prefix = "taskito" - mw._extra_attributes_fn = None - mw._task_filter = None - mw._spans = {} - mw._lock = threading.Lock() - - ctx = _make_ctx() - mw.before(ctx) - - mock_tracer.start_span.assert_called_once() - assert "job-1" in mw._spans - - def test_after_ends_span_success(self) -> None: - otel = _try_import_otel() - if otel is None: - return - - mock_span = MagicMock() - mw = otel.OpenTelemetryMiddleware.__new__(otel.OpenTelemetryMiddleware) - import threading - - mw._tracer = MagicMock() - mw._span_name_fn = None - mw._attr_prefix = "taskito" - mw._extra_attributes_fn = None - mw._task_filter = None - mw._spans = {"job-1": mock_span} - mw._lock = threading.Lock() - - ctx = _make_ctx() - mw.after(ctx, result="ok", error=None) - - mock_span.set_status.assert_called_once() - mock_span.end.assert_called_once() - mock_span.record_exception.assert_not_called() - - def test_after_records_exception_on_error(self) -> None: - otel = _try_import_otel() - if otel is None: - return - - mock_span = MagicMock() - mw = otel.OpenTelemetryMiddleware.__new__(otel.OpenTelemetryMiddleware) - import threading - - mw._tracer = MagicMock() - mw._span_name_fn = None - mw._attr_prefix = "taskito" - mw._extra_attributes_fn = None - mw._task_filter = None - mw._spans = {"job-1": mock_span} - mw._lock = threading.Lock() - - ctx = _make_ctx() - exc = ValueError("boom") - mw.after(ctx, result=None, error=exc) - - mock_span.record_exception.assert_called_once_with(exc) - mock_span.end.assert_called_once() - - -def _try_import_otel() -> types.ModuleType | None: - """Import otel module with mocked opentelemetry if not installed.""" - try: - import sys - - # Provide mock opentelemetry modules if not installed - mock_trace = MagicMock() - mock_trace.StatusCode.OK = "OK" - mock_trace.StatusCode.ERROR = "ERROR" - - with patch.dict( - sys.modules, - { - "opentelemetry": MagicMock(), - "opentelemetry.trace": mock_trace, - }, - ): - if "taskito.contrib.otel" in sys.modules: - del sys.modules["taskito.contrib.otel"] - from taskito.contrib import otel - - # Patch module-level references - otel.trace = mock_trace - otel.StatusCode = mock_trace.StatusCode - return otel - except Exception: - return None - - -# ── Sentry ─────────────────────────────────────────────────────────── - - -class TestSentryMiddleware: - def test_before_pushes_scope(self) -> None: - sentry_mod = _try_import_sentry() - if sentry_mod is None: - return - - mock_sdk = sentry_mod.sentry_sdk - ctx = _make_ctx() - - mw = sentry_mod.SentryMiddleware.__new__(sentry_mod.SentryMiddleware) - mw._tag_prefix = "taskito" - mw._transaction_name_fn = None - mw._task_filter = None - mw._extra_tags_fn = None - mw.before(ctx) - - mock_sdk.push_scope.assert_called_once() - - def test_after_pops_scope(self) -> None: - sentry_mod = _try_import_sentry() - if sentry_mod is None: - return - - mock_sdk = sentry_mod.sentry_sdk - ctx = _make_ctx() - - mw = sentry_mod.SentryMiddleware.__new__(sentry_mod.SentryMiddleware) - mw._tag_prefix = "taskito" - mw._transaction_name_fn = None - mw._task_filter = None - mw._extra_tags_fn = None - mw.after(ctx, result="ok", error=None) - - mock_sdk.pop_scope_unsafe.assert_called_once() - - def test_after_captures_exception_on_error(self) -> None: - sentry_mod = _try_import_sentry() - if sentry_mod is None: - return - - mock_sdk = sentry_mod.sentry_sdk - ctx = _make_ctx() - exc = RuntimeError("oops") - - mw = sentry_mod.SentryMiddleware.__new__(sentry_mod.SentryMiddleware) - mw._tag_prefix = "taskito" - mw._transaction_name_fn = None - mw._task_filter = None - mw._extra_tags_fn = None - mw.after(ctx, result=None, error=exc) - - mock_sdk.capture_exception.assert_called_once_with(exc) - - -def _try_import_sentry() -> types.ModuleType | None: - try: - import sys - - mock_sdk = MagicMock() - with patch.dict(sys.modules, {"sentry_sdk": mock_sdk}): - if "taskito.contrib.sentry" in sys.modules: - del sys.modules["taskito.contrib.sentry"] - from taskito.contrib import sentry - - sentry.sentry_sdk = mock_sdk - return sentry - except Exception: - return None - - -# ── Prometheus ─────────────────────────────────────────────────────── - - -def _make_mock_metrics() -> dict[str, Any]: - """Create a mock metrics dict matching the instance-based store format.""" - return { - "jobs_total": MagicMock(), - "job_duration": MagicMock(), - "active_workers": MagicMock(), - "retries_total": MagicMock(), - "queue_depth": MagicMock(), - "dlq_size": MagicMock(), - "worker_utilization": MagicMock(), - "resource_health": MagicMock(), - "resource_recreations": MagicMock(), - "resource_init_duration": MagicMock(), - "proxy_reconstruct_duration": MagicMock(), - "proxy_reconstruct_total": MagicMock(), - "proxy_reconstruct_errors": MagicMock(), - "intercept_duration": MagicMock(), - "intercept_strategy_total": MagicMock(), - "pool_size": MagicMock(), - "pool_active": MagicMock(), - "pool_idle": MagicMock(), - "pool_timeouts": MagicMock(), - } - - -class TestPrometheusMiddleware: - def test_before_increments_active_workers(self) -> None: - prom = _try_import_prometheus() - if prom is None: - return - - import threading - - metrics = _make_mock_metrics() - mw = prom.PrometheusMiddleware.__new__(prom.PrometheusMiddleware) - mw._metrics = metrics - mw._extra_labels_fn = None - mw._task_filter = None - mw._start_times = {} - mw._lock = threading.Lock() - - ctx = _make_ctx() - mw.before(ctx) - - metrics["active_workers"].inc.assert_called() - - def test_after_tracks_counter_and_histogram(self) -> None: - prom = _try_import_prometheus() - if prom is None: - return - - import threading - - metrics = _make_mock_metrics() - mw = prom.PrometheusMiddleware.__new__(prom.PrometheusMiddleware) - mw._metrics = metrics - mw._extra_labels_fn = None - mw._task_filter = None - mw._start_times = {"job-1": 0.0} - mw._lock = threading.Lock() - - ctx = _make_ctx() - mw.after(ctx, result="ok", error=None) - - metrics["active_workers"].dec.assert_called() - metrics["jobs_total"].labels.assert_called_with(task="my_task", status="completed") - metrics["jobs_total"].labels().inc.assert_called() - - def test_after_tracks_failure(self) -> None: - prom = _try_import_prometheus() - if prom is None: - return - - import threading - - metrics = _make_mock_metrics() - mw = prom.PrometheusMiddleware.__new__(prom.PrometheusMiddleware) - mw._metrics = metrics - mw._extra_labels_fn = None - mw._task_filter = None - mw._start_times = {"job-1": 0.0} - mw._lock = threading.Lock() - - ctx = _make_ctx() - exc = ValueError("fail") - mw.after(ctx, result=None, error=exc) - - metrics["jobs_total"].labels.assert_called_with(task="my_task", status="failed") - - -def _try_import_prometheus() -> types.ModuleType | None: - try: - import sys - - mock_counter = MagicMock() - mock_gauge = MagicMock() - mock_histogram = MagicMock() - - with patch.dict( - sys.modules, - { - "prometheus_client": MagicMock( - Counter=mock_counter, - Gauge=mock_gauge, - Histogram=mock_histogram, - start_http_server=MagicMock(), - ), - }, - ): - if "taskito.contrib.prometheus" in sys.modules: - del sys.modules["taskito.contrib.prometheus"] - from taskito.contrib import prometheus - - return prometheus - except Exception: - return None diff --git a/tests/test_customizability.py b/tests/test_customizability.py deleted file mode 100644 index b33f3c0..0000000 --- a/tests/test_customizability.py +++ /dev/null @@ -1,297 +0,0 @@ -"""Tests for customizability configuration options.""" - -from __future__ import annotations - -from typing import Any -from unittest.mock import MagicMock - -from taskito.app import Queue -from taskito.events import EventType -from taskito.middleware import TaskMiddleware -from taskito.webhooks import WebhookManager - -# ── Middleware Hooks ────────────────────────────────────────────────── - - -class RecordingMiddleware(TaskMiddleware): - """Middleware that records all hook calls for testing.""" - - def __init__(self) -> None: - self.calls: list[tuple[str, Any]] = [] - - def before(self, ctx: Any) -> None: - self.calls.append(("before", ctx.task_name)) - - def after(self, ctx: Any, result: Any, error: Any) -> None: - self.calls.append(("after", ctx.task_name)) - - def on_enqueue(self, task_name: str, args: tuple, kwargs: dict, options: dict) -> None: - self.calls.append(("on_enqueue", task_name)) - - def on_dead_letter(self, ctx: Any, error: Exception) -> None: - self.calls.append(("on_dead_letter", ctx.task_name)) - - def on_timeout(self, ctx: Any) -> None: - self.calls.append(("on_timeout", ctx.task_name)) - - def on_cancel(self, ctx: Any) -> None: - self.calls.append(("on_cancel", ctx.task_name)) - - -class TestMiddlewareHooks: - def test_on_enqueue_called(self, tmp_path: Any) -> None: - mw = RecordingMiddleware() - q = Queue(db_path=str(tmp_path / "test.db"), middleware=[mw]) - - @q.task() - def my_task() -> None: - pass - - my_task.delay() - assert ("on_enqueue", my_task.name) in mw.calls - - def test_on_enqueue_can_mutate_options(self, tmp_path: Any) -> None: - """on_enqueue can modify the options dict to change enqueue params.""" - - class PriorityBoostMiddleware(TaskMiddleware): - def on_enqueue(self, task_name: str, args: tuple, kwargs: dict, options: dict) -> None: - options["priority"] = 99 - - mw = PriorityBoostMiddleware() - q = Queue(db_path=str(tmp_path / "test.db"), middleware=[mw]) - - @q.task() - def my_task() -> None: - pass - - result = my_task.delay() - job = q.get_job(result.id) - assert job is not None - assert job.to_dict()["priority"] == 99 - - def test_default_hooks_are_noop(self) -> None: - """Base TaskMiddleware hooks should not raise.""" - mw = TaskMiddleware() - mw.on_enqueue("test", (), {}, {}) - mw.on_dead_letter(MagicMock(), Exception("test")) - mw.on_timeout(MagicMock()) - mw.on_cancel(MagicMock()) - - -# ── Event System ───────────────────────────────────────────────────── - - -class TestEventSystem: - def test_new_event_types_exist(self) -> None: - assert EventType.WORKER_STARTED.value == "worker.started" - assert EventType.WORKER_STOPPED.value == "worker.stopped" - assert EventType.QUEUE_PAUSED.value == "queue.paused" - assert EventType.QUEUE_RESUMED.value == "queue.resumed" - - def test_event_workers_param(self, tmp_path: Any) -> None: - q = Queue(db_path=str(tmp_path / "test.db"), event_workers=2) - assert q._event_bus._executor._max_workers == 2 - - def test_on_event_public_api(self, tmp_path: Any, poll_until: Any) -> None: - q = Queue(db_path=str(tmp_path / "test.db")) - received: list[Any] = [] - - def callback(event_type: EventType, payload: dict) -> None: - received.append((event_type, payload)) - - q.on_event(EventType.JOB_ENQUEUED, callback) - - @q.task() - def my_task() -> None: - pass - - my_task.delay() - poll_until(lambda: len(received) >= 1, message="JOB_ENQUEUED event not delivered") - assert len(received) == 1 - assert received[0][0] == EventType.JOB_ENQUEUED - - -# ── Webhook Configuration ──────────────────────────────────────────── - - -class TestWebhookConfig: - def test_add_webhook_with_retry_params(self) -> None: - mgr = WebhookManager() - mgr.add_webhook( - "https://example.com/hook", - max_retries=5, - timeout=30.0, - retry_backoff=3.0, - ) - wh = mgr._webhooks[0] - assert wh["max_retries"] == 5 - assert wh["timeout"] == 30.0 - assert wh["retry_backoff"] == 3.0 - - def test_add_webhook_defaults(self) -> None: - mgr = WebhookManager() - mgr.add_webhook("https://example.com/hook") - wh = mgr._webhooks[0] - assert wh["max_retries"] == 3 - assert wh["timeout"] == 10.0 - assert wh["retry_backoff"] == 2.0 - - def test_queue_add_webhook_passes_params(self, tmp_path: Any) -> None: - q = Queue(db_path=str(tmp_path / "test.db")) - q.add_webhook( - "https://example.com/hook", - max_retries=1, - timeout=5.0, - retry_backoff=1.5, - ) - wh = q._webhook_manager._webhooks[0] - assert wh["max_retries"] == 1 - assert wh["timeout"] == 5.0 - - -# ── Queue Configuration ────────────────────────────────────────────── - - -class TestQueueConfig: - def test_scheduler_timing_params(self, tmp_path: Any) -> None: - q = Queue( - db_path=str(tmp_path / "test.db"), - scheduler_poll_interval_ms=100, - scheduler_reap_interval=50, - scheduler_cleanup_interval=600, - ) - # These are passed to the Rust side — just verify they don't error - assert q._inner is not None - - def test_scheduler_timing_defaults(self, tmp_path: Any) -> None: - q = Queue(db_path=str(tmp_path / "test.db")) - # Defaults should work fine - assert q._inner is not None - - -# ── Per-Task Configuration ─────────────────────────────────────────── - - -class TestPerTaskConfig: - def test_max_retry_delay_param(self, tmp_path: Any) -> None: - q = Queue(db_path=str(tmp_path / "test.db")) - - @q.task(max_retry_delay=60) - def my_task() -> None: - pass - - config = q._task_configs[-1] - assert config.max_retry_delay == 60 - - def test_max_retry_delay_default(self, tmp_path: Any) -> None: - q = Queue(db_path=str(tmp_path / "test.db")) - - @q.task() - def my_task() -> None: - pass - - config = q._task_configs[-1] - assert config.max_retry_delay is None - - def test_max_concurrent_param(self, tmp_path: Any) -> None: - q = Queue(db_path=str(tmp_path / "test.db")) - - @q.task(max_concurrent=5) - def my_task() -> None: - pass - - config = q._task_configs[-1] - assert config.max_concurrent == 5 - - def test_max_concurrent_default(self, tmp_path: Any) -> None: - q = Queue(db_path=str(tmp_path / "test.db")) - - @q.task() - def my_task() -> None: - pass - - config = q._task_configs[-1] - assert config.max_concurrent is None - - -# ── Per-Task Serializer ────────────────────────────────────────────── - - -class TestPerTaskSerializer: - def test_task_level_serializer_used_for_enqueue(self, tmp_path: Any) -> None: - """Per-task serializer is used instead of queue-level serializer.""" - mock_serializer = MagicMock() - mock_serializer.dumps.return_value = b"\x80\x04\x95" - - q = Queue(db_path=str(tmp_path / "test.db")) - - @q.task(serializer=mock_serializer) - def my_task(x: int) -> None: - pass - - my_task.delay(42) - mock_serializer.dumps.assert_called_once() - - def test_queue_serializer_used_when_no_task_serializer(self, tmp_path: Any) -> None: - q = Queue(db_path=str(tmp_path / "test.db")) - - @q.task() - def my_task() -> None: - pass - - # Should use the default CloudpickleSerializer without error - my_task.delay() - assert my_task.name not in q._task_serializers - - -# ── Flask CLI ───────────────────────────────────────────────────────── - - -# ── Queue-Level Limits ──────────────────────────────────────────────── - - -class TestQueueLevelLimits: - def test_set_queue_rate_limit(self, tmp_path: Any) -> None: - q = Queue(db_path=str(tmp_path / "test.db")) - q.set_queue_rate_limit("default", "100/m") - assert q._queue_configs["default"]["rate_limit"] == "100/m" - - def test_set_queue_concurrency(self, tmp_path: Any) -> None: - q = Queue(db_path=str(tmp_path / "test.db")) - q.set_queue_concurrency("default", 10) - assert q._queue_configs["default"]["max_concurrent"] == 10 - - def test_set_both_on_same_queue(self, tmp_path: Any) -> None: - q = Queue(db_path=str(tmp_path / "test.db")) - q.set_queue_rate_limit("emails", "50/m") - q.set_queue_concurrency("emails", 5) - assert q._queue_configs["emails"]["rate_limit"] == "50/m" - assert q._queue_configs["emails"]["max_concurrent"] == 5 - - def test_queue_configs_serialized_to_json(self, tmp_path: Any) -> None: - import json - - q = Queue(db_path=str(tmp_path / "test.db")) - q.set_queue_rate_limit("default", "10/s") - q.set_queue_concurrency("default", 3) - serialized = json.dumps(q._queue_configs) - parsed = json.loads(serialized) - assert parsed["default"]["rate_limit"] == "10/s" - assert parsed["default"]["max_concurrent"] == 3 - - -# ── Flask CLI ───────────────────────────────────────────────────────── - - -class TestFlaskConfig: - def test_cli_group_param(self) -> None: - from taskito.contrib.flask import Taskito - - ext = Taskito(cli_group="jobs") - assert ext._cli_group == "jobs" - - def test_cli_group_default(self) -> None: - from taskito.contrib.flask import Taskito - - ext = Taskito() - assert ext._cli_group == "taskito" diff --git a/tests/test_dashboard.py b/tests/test_dashboard.py deleted file mode 100644 index dc281e1..0000000 --- a/tests/test_dashboard.py +++ /dev/null @@ -1,403 +0,0 @@ -"""Tests for dashboard API endpoints and list_jobs/to_dict.""" - -import json -import threading -import urllib.error -import urllib.request -from collections.abc import Generator -from pathlib import Path -from typing import Any - -import pytest - -from taskito import Queue - - -@pytest.fixture -def queue(tmp_path: Path) -> Queue: - """Create a fresh queue with some test data pre-registered.""" - db_path = str(tmp_path / "test_dashboard.db") - q = Queue(db_path=db_path, workers=2) - - @q.task(queue="default") - def task_a(x: int) -> int: - return x * 2 - - @q.task(queue="email") - def task_b(x: int) -> int: - return x + 1 - - return q - - -@pytest.fixture -def populated_queue(queue: Queue) -> tuple[Queue, list[Any]]: - """Queue with several jobs enqueued.""" - task_a_name: str = "" - task_b_name: str = "" - for name, _fn in queue._task_registry.items(): - if "task_a" in name: - task_a_name = name - elif "task_b" in name: - task_b_name = name - - jobs: list[Any] = [] - for i in range(5): - jobs.append(queue.enqueue(task_a_name, args=(i,))) - for i in range(3): - jobs.append(queue.enqueue(task_b_name, args=(i,), queue="email")) - return queue, jobs - - -# ── list_jobs tests ────────────────────────────────────── - - -def test_list_jobs_returns_all(populated_queue: tuple[Queue, list[Any]]) -> None: - """list_jobs() with no filters returns all jobs.""" - queue, _ = populated_queue - result = queue.list_jobs() - assert len(result) == 8 - - -def test_list_jobs_filter_by_queue(populated_queue: tuple[Queue, list[Any]]) -> None: - """list_jobs() can filter by queue name.""" - queue, _ = populated_queue - result = queue.list_jobs(queue="email") - assert len(result) == 3 - for j in result: - d = j.to_dict() - assert d["queue"] == "email" - - -def test_list_jobs_filter_by_status(populated_queue: tuple[Queue, list[Any]]) -> None: - """list_jobs() can filter by status.""" - queue, _ = populated_queue - result = queue.list_jobs(status="pending") - assert len(result) == 8 # all are pending - - result = queue.list_jobs(status="running") - assert len(result) == 0 - - -def test_list_jobs_filter_by_task_name(populated_queue: tuple[Queue, list[Any]]) -> None: - """list_jobs() can filter by task name.""" - queue, _ = populated_queue - # Find the task_a name - task_a_name = None - for name in queue._task_registry: - if "task_a" in name: - task_a_name = name - break - - result = queue.list_jobs(task_name=task_a_name) - assert len(result) == 5 - - -def test_list_jobs_pagination(populated_queue: tuple[Queue, list[Any]]) -> None: - """list_jobs() respects limit and offset.""" - queue, _ = populated_queue - page1 = queue.list_jobs(limit=3, offset=0) - page2 = queue.list_jobs(limit=3, offset=3) - - assert len(page1) == 3 - assert len(page2) == 3 - - ids1 = {j.id for j in page1} - ids2 = {j.id for j in page2} - assert ids1.isdisjoint(ids2) - - -def test_list_jobs_invalid_status(queue: Queue) -> None: - """list_jobs() raises on invalid status string.""" - with pytest.raises(ValueError): - queue.list_jobs(status="bogus") - - -# ── to_dict tests ──────────────────────────────────────── - - -def test_to_dict_fields(queue: Queue) -> None: - """to_dict() returns all expected fields.""" - - @queue.task() - def dummy() -> None: - pass - - job = dummy.delay() - d = job.to_dict() - - expected_keys = { - "id", - "queue", - "task_name", - "status", - "priority", - "progress", - "retry_count", - "max_retries", - "created_at", - "scheduled_at", - "started_at", - "completed_at", - "error", - "timeout_ms", - "unique_key", - "metadata", - "namespace", - } - assert set(d.keys()) == expected_keys - assert d["status"] == "pending" - assert d["id"] == job.id - - -def test_to_dict_is_json_serializable(queue: Queue) -> None: - """to_dict() output can be serialized to JSON.""" - - @queue.task() - def dummy() -> None: - pass - - job = dummy.delay() - d = job.to_dict() - serialized = json.dumps(d) - assert isinstance(serialized, str) - - -# ── Dashboard HTTP tests ───────────────────────────────── - - -def _start_dashboard(queue: Queue, *, static_assets: Any = None) -> tuple[str, Any]: - """Boot a dashboard server bound to a random port. - - Returns the (url, server) pair so callers can shut it down explicitly. - Optionally takes a ``StaticAssets`` instance — production code uses - the package-bundled default; tests inject their own. - """ - from http.server import ThreadingHTTPServer - - from taskito.dashboard import _make_handler - - handler = _make_handler(queue, static_assets=static_assets) - server = ThreadingHTTPServer(("127.0.0.1", 0), handler) - port = server.server_address[1] - thread = threading.Thread(target=server.serve_forever, daemon=True) - thread.start() - return f"http://127.0.0.1:{port}", server - - -@pytest.fixture -def dashboard_server( - populated_queue: tuple[Queue, list[Any]], -) -> Generator[tuple[str, Queue, list[Any]]]: - """Start a dashboard server on a random port.""" - queue, jobs = populated_queue - url, server = _start_dashboard(queue) - try: - yield url, queue, jobs - finally: - server.shutdown() - - -def _get(url: str) -> Any: - """GET request and parse JSON.""" - with urllib.request.urlopen(url) as resp: - return json.loads(resp.read()) - - -def _post(url: str) -> Any: - """POST request and parse JSON.""" - req = urllib.request.Request(url, method="POST", data=b"") - with urllib.request.urlopen(req) as resp: - return json.loads(resp.read()) - - -def test_api_stats(dashboard_server: tuple[str, Queue, list[Any]]) -> None: - """GET /api/stats returns valid stats dict.""" - base, _, __ = dashboard_server - data = _get(f"{base}/api/stats") - assert "pending" in data - assert data["pending"] == 8 - - -def test_api_jobs_list(dashboard_server: tuple[str, Queue, list[Any]]) -> None: - """GET /api/jobs returns job list.""" - base, _, __ = dashboard_server - data = _get(f"{base}/api/jobs") - assert isinstance(data, list) - assert len(data) == 8 - - -def test_api_jobs_filter_status(dashboard_server: tuple[str, Queue, list[Any]]) -> None: - """GET /api/jobs?status=pending filters correctly.""" - base, _, __ = dashboard_server - data = _get(f"{base}/api/jobs?status=pending") - assert len(data) == 8 - - data = _get(f"{base}/api/jobs?status=running") - assert len(data) == 0 - - -def test_api_jobs_filter_queue(dashboard_server: tuple[str, Queue, list[Any]]) -> None: - """GET /api/jobs?queue=email filters correctly.""" - base, _, __ = dashboard_server - data = _get(f"{base}/api/jobs?queue=email") - assert len(data) == 3 - - -def test_api_jobs_pagination(dashboard_server: tuple[str, Queue, list[Any]]) -> None: - """GET /api/jobs?limit=3&offset=0 paginates.""" - base, _, __ = dashboard_server - data = _get(f"{base}/api/jobs?limit=3&offset=0") - assert len(data) == 3 - - -def test_api_job_detail(dashboard_server: tuple[str, Queue, list[Any]]) -> None: - """GET /api/jobs/{id} returns job dict.""" - base, _, jobs = dashboard_server - job_id = jobs[0].id - data = _get(f"{base}/api/jobs/{job_id}") - assert data["id"] == job_id - assert "status" in data - - -def test_api_job_not_found(dashboard_server: tuple[str, Queue, list[Any]]) -> None: - """GET /api/jobs/nonexistent returns 404.""" - base, _, __ = dashboard_server - try: - _get(f"{base}/api/jobs/nonexistent-id") - raise AssertionError("Expected 404") - except urllib.error.HTTPError as e: - assert e.code == 404 - - -def test_api_cancel_job(dashboard_server: tuple[str, Queue, list[Any]]) -> None: - """POST /api/jobs/{id}/cancel cancels a pending job.""" - base, _, jobs = dashboard_server - job_id = jobs[0].id - data = _post(f"{base}/api/jobs/{job_id}/cancel") - assert data["cancelled"] is True - - -def test_api_dead_letters_empty(dashboard_server: tuple[str, Queue, list[Any]]) -> None: - """GET /api/dead-letters returns empty list initially.""" - base, _, __ = dashboard_server - data = _get(f"{base}/api/dead-letters") - assert data == [] - - -def test_spa_html_served( - populated_queue: tuple[Queue, list[Any]], - tmp_path: Path, -) -> None: - """GET / serves the SPA index.html when assets are bundled. - - Tests inject a ``StaticAssets`` rooted at a tmp directory so the test - is self-contained — no dependency on a prior frontend build. - """ - from taskito.dashboard import StaticAssets - - tmp_path.joinpath("index.html").write_text( - '
', - encoding="utf-8", - ) - queue, _ = populated_queue - url, server = _start_dashboard(queue, static_assets=StaticAssets(tmp_path)) - try: - with urllib.request.urlopen(url) as resp: - assert resp.status == 200 - assert resp.headers.get("Content-Type", "").startswith("text/html") - html = resp.read().decode() - assert "" in html.lower() - assert 'id="app"' in html - finally: - server.shutdown() - - -def test_spa_assets_resolved_under_root( - populated_queue: tuple[Queue, list[Any]], - tmp_path: Path, -) -> None: - """Hashed asset paths resolve under the bundle root and get long - immutable cache headers.""" - from taskito.dashboard import StaticAssets - - tmp_path.joinpath("index.html").write_text("", encoding="utf-8") - assets_dir = tmp_path / "assets" - assets_dir.mkdir() - assets_dir.joinpath("index-abc.js").write_text("export {};", encoding="utf-8") - - queue, _ = populated_queue - url, server = _start_dashboard(queue, static_assets=StaticAssets(tmp_path)) - try: - with urllib.request.urlopen(f"{url}/assets/index-abc.js") as resp: - assert resp.status == 200 - assert resp.headers.get("Content-Type", "").startswith("application/javascript") - assert "immutable" in resp.headers.get("Cache-Control", "") - assert resp.read().decode() == "export {};" - finally: - server.shutdown() - - -def test_spa_unknown_route_falls_back_to_index( - populated_queue: tuple[Queue, list[Any]], - tmp_path: Path, -) -> None: - """Client-side routes (anything that's not /assets/* or a real file) - resolve to index.html so deep links keep working.""" - from taskito.dashboard import StaticAssets - - tmp_path.joinpath("index.html").write_text( - '
', - encoding="utf-8", - ) - queue, _ = populated_queue - url, server = _start_dashboard(queue, static_assets=StaticAssets(tmp_path)) - try: - with urllib.request.urlopen(f"{url}/jobs/some-id") as resp: - assert resp.status == 200 - assert 'id="app"' in resp.read().decode() - finally: - server.shutdown() - - -def test_spa_missing_asset_under_assets_returns_404( - populated_queue: tuple[Queue, list[Any]], - tmp_path: Path, -) -> None: - """A miss inside ``/assets/`` returns 404 — never the index fallback, - so a stale page can't confuse the browser into running old chunks.""" - from taskito.dashboard import StaticAssets - - tmp_path.joinpath("index.html").write_text("", encoding="utf-8") - queue, _ = populated_queue - url, server = _start_dashboard(queue, static_assets=StaticAssets(tmp_path)) - try: - try: - urllib.request.urlopen(f"{url}/assets/missing.js") - pytest.fail("Expected 404") - except urllib.error.HTTPError as exc: - assert exc.code == 404 - finally: - server.shutdown() - - -def test_spa_missing_assets_returns_503( - populated_queue: tuple[Queue, list[Any]], -) -> None: - """When the frontend build wasn't run, the dashboard returns 503 with - actionable rebuild instructions rather than silently 404-ing.""" - from taskito.dashboard import StaticAssets - - queue, _ = populated_queue - url, server = _start_dashboard(queue, static_assets=StaticAssets(None)) - try: - try: - urllib.request.urlopen(url) - pytest.fail("Expected 503") - except urllib.error.HTTPError as exc: - assert exc.code == 503 - body = exc.read().decode() - assert "not bundled" in body.lower() - assert "pnpm" in body.lower() - finally: - server.shutdown() diff --git a/tests/test_dashboard_settings.py b/tests/test_dashboard_settings.py deleted file mode 100644 index ee9888d..0000000 --- a/tests/test_dashboard_settings.py +++ /dev/null @@ -1,199 +0,0 @@ -"""Tests for the dashboard settings key/value store. - -Covers: -- ``Queue.get_setting`` / ``set_setting`` / ``delete_setting`` / ``list_settings`` -- HTTP endpoints under ``/api/settings`` -""" - -from __future__ import annotations - -import json -import threading -import urllib.error -import urllib.request -from collections.abc import Generator -from pathlib import Path -from typing import Any - -import pytest - -from taskito import Queue - - -@pytest.fixture -def queue(tmp_path: Path) -> Queue: - return Queue(db_path=str(tmp_path / "settings.db")) - - -def _put(url: str, body: dict) -> Any: - req = urllib.request.Request( - url, - method="PUT", - data=json.dumps(body).encode(), - headers={"Content-Type": "application/json"}, - ) - with urllib.request.urlopen(req) as resp: - return json.loads(resp.read()) - - -def _delete(url: str) -> Any: - req = urllib.request.Request(url, method="DELETE") - with urllib.request.urlopen(req) as resp: - return json.loads(resp.read()) - - -def _get(url: str) -> Any: - with urllib.request.urlopen(url) as resp: - return json.loads(resp.read()) - - -# ── Python API ────────────────────────────────────────── - - -def test_get_setting_returns_none_when_unset(queue: Queue) -> None: - assert queue.get_setting("missing") is None - - -def test_set_and_get_setting(queue: Queue) -> None: - queue.set_setting("dashboard.title", "My Queue") - assert queue.get_setting("dashboard.title") == "My Queue" - - -def test_set_setting_overwrites(queue: Queue) -> None: - queue.set_setting("k", "v1") - queue.set_setting("k", "v2") - assert queue.get_setting("k") == "v2" - - -def test_delete_setting(queue: Queue) -> None: - queue.set_setting("k", "v") - assert queue.delete_setting("k") is True - assert queue.get_setting("k") is None - # Delete on missing key is a no-op returning False. - assert queue.delete_setting("k") is False - - -def test_list_settings_returns_all(queue: Queue) -> None: - queue.set_setting("a", "1") - queue.set_setting("b", "2") - snapshot = queue.list_settings() - assert snapshot == {"a": "1", "b": "2"} - - -def test_setting_preserves_unicode(queue: Queue) -> None: - queue.set_setting("greeting", "안녕하세요 🌏") - assert queue.get_setting("greeting") == "안녕하세요 🌏" - - -def test_setting_preserves_json(queue: Queue) -> None: - payload = json.dumps({"label": "Grafana", "url": "https://example/dash"}) - queue.set_setting("dashboard.links.0", payload) - assert json.loads(queue.get_setting("dashboard.links.0") or "") == { - "label": "Grafana", - "url": "https://example/dash", - } - - -# ── HTTP endpoints ────────────────────────────────────── - - -@pytest.fixture -def dashboard_server(queue: Queue) -> Generator[tuple[str, Queue]]: - from http.server import ThreadingHTTPServer - - from taskito.dashboard import _make_handler - - handler = _make_handler(queue) - server = ThreadingHTTPServer(("127.0.0.1", 0), handler) - port = server.server_address[1] - thread = threading.Thread(target=server.serve_forever, daemon=True) - thread.start() - try: - yield f"http://127.0.0.1:{port}", queue - finally: - server.shutdown() - - -def test_get_settings_returns_empty_dict(dashboard_server: tuple[str, Queue]) -> None: - base, _ = dashboard_server - assert _get(f"{base}/api/settings") == {} - - -def test_put_then_get_setting(dashboard_server: tuple[str, Queue]) -> None: - base, _ = dashboard_server - _put(f"{base}/api/settings/dashboard.title", {"value": "My Queue"}) - - data = _get(f"{base}/api/settings/dashboard.title") - assert data == {"key": "dashboard.title", "value": "My Queue"} - - snapshot = _get(f"{base}/api/settings") - assert snapshot == {"dashboard.title": "My Queue"} - - -def test_put_setting_with_json_value(dashboard_server: tuple[str, Queue]) -> None: - """Non-string ``value`` is JSON-encoded before persistence.""" - base, queue = dashboard_server - payload = [ - {"label": "Grafana", "url": "https://grafana.example/d/abc"}, - {"label": "Sentry", "url": "https://sentry.example/issues"}, - ] - _put(f"{base}/api/settings/dashboard.external_links", {"value": payload}) - - stored = queue.get_setting("dashboard.external_links") - assert stored is not None - assert json.loads(stored) == payload - - -def test_get_unknown_setting_returns_404(dashboard_server: tuple[str, Queue]) -> None: - base, _ = dashboard_server - with pytest.raises(urllib.error.HTTPError) as exc_info: - _get(f"{base}/api/settings/missing.key") - assert exc_info.value.code == 404 - - -def test_put_setting_with_missing_value_field_returns_400( - dashboard_server: tuple[str, Queue], -) -> None: - base, _ = dashboard_server - with pytest.raises(urllib.error.HTTPError) as exc_info: - _put(f"{base}/api/settings/k", {"not_value": 1}) - assert exc_info.value.code == 400 - - -def test_put_setting_rejects_invalid_json_body(dashboard_server: tuple[str, Queue]) -> None: - base, _ = dashboard_server - req = urllib.request.Request( - f"{base}/api/settings/k", - method="PUT", - data=b"{not json", - headers={"Content-Type": "application/json"}, - ) - with pytest.raises(urllib.error.HTTPError) as exc_info: - urllib.request.urlopen(req) - assert exc_info.value.code == 400 - - -def test_delete_setting_returns_true_when_exists( - dashboard_server: tuple[str, Queue], -) -> None: - base, queue = dashboard_server - queue.set_setting("k", "v") - assert _delete(f"{base}/api/settings/k") == {"deleted": True} - assert queue.get_setting("k") is None - - -def test_delete_missing_setting_returns_false( - dashboard_server: tuple[str, Queue], -) -> None: - base, _ = dashboard_server - assert _delete(f"{base}/api/settings/missing") == {"deleted": False} - - -def test_settings_persist_across_queue_instances(tmp_path: Path) -> None: - """A fresh Queue instance pointed at the same DB sees prior writes.""" - db = str(tmp_path / "persist.db") - q1 = Queue(db_path=db) - q1.set_setting("k", "v") - - q2 = Queue(db_path=db) - assert q2.get_setting("k") == "v" diff --git a/tests/test_dashboard_static.py b/tests/test_dashboard_static.py deleted file mode 100644 index 4fc61c2..0000000 --- a/tests/test_dashboard_static.py +++ /dev/null @@ -1,142 +0,0 @@ -"""Tests for dashboard static asset resolution and Content-Type mapping.""" - -from __future__ import annotations - -from pathlib import Path - -import pytest - -from taskito.dashboard import _content_type_for, _resolve_static_node - - -@pytest.fixture -def static_root(tmp_path: Path) -> Path: - """Layout mirroring a built Vite SPA tree.""" - (tmp_path / "index.html").write_text("") - assets = tmp_path / "assets" - assets.mkdir() - (assets / "index-abc.js").write_text("// js") - (assets / "index-abc.css").write_text("/* css */") - (assets / "nested").mkdir() - (assets / "nested" / "deep.png").write_bytes(b"\x89PNG") - return tmp_path - - -# ── _resolve_static_node ──────────────────────────────────────────── - - -def test_resolve_index_html(static_root: Path) -> None: - node = _resolve_static_node(static_root, "/index.html") - assert node is not None - assert node.read_text() == "" - - -def test_resolve_hashed_asset(static_root: Path) -> None: - node = _resolve_static_node(static_root, "/assets/index-abc.js") - assert node is not None - assert node.read_text() == "// js" - - -def test_resolve_nested_asset(static_root: Path) -> None: - node = _resolve_static_node(static_root, "/assets/nested/deep.png") - assert node is not None - assert node.read_bytes() == b"\x89PNG" - - -def test_resolve_missing_file_returns_none(static_root: Path) -> None: - assert _resolve_static_node(static_root, "/assets/missing.js") is None - - -def test_resolve_directory_returns_none(static_root: Path) -> None: - # A directory matches joinpath but is_file() is False - assert _resolve_static_node(static_root, "/assets") is None - - -def test_resolve_empty_path_returns_none(static_root: Path) -> None: - assert _resolve_static_node(static_root, "") is None - assert _resolve_static_node(static_root, "/") is None - - -def test_resolve_rejects_parent_traversal(static_root: Path) -> None: - assert _resolve_static_node(static_root, "/../secret") is None - assert _resolve_static_node(static_root, "/assets/../../secret") is None - - -def test_resolve_rejects_current_directory(static_root: Path) -> None: - assert _resolve_static_node(static_root, "/./index.html") is None - - -def test_resolve_rejects_null_byte(static_root: Path) -> None: - assert _resolve_static_node(static_root, "/index.html\x00.png") is None - - -def test_resolve_rejects_backslash(static_root: Path) -> None: - # Windows-style separators should be rejected to avoid ambiguity - assert _resolve_static_node(static_root, "/assets\\index-abc.js") is None - - -def test_resolve_rejects_double_slash(static_root: Path) -> None: - # Empty segments from double slashes are rejected - assert _resolve_static_node(static_root, "/assets//index-abc.js") is None - - -# ── _content_type_for ────────────────────────────────────────────── - - -@pytest.mark.parametrize( - ("path", "expected"), - [ - ("/index.html", "text/html; charset=utf-8"), - ("/assets/index-abc.js", "application/javascript; charset=utf-8"), - ("/assets/index-abc.mjs", "application/javascript; charset=utf-8"), - ("/assets/index-abc.css", "text/css; charset=utf-8"), - ("/icon.svg", "image/svg+xml"), - ("/icon.png", "image/png"), - ("/favicon.ico", "image/x-icon"), - ("/fonts/inter.woff2", "font/woff2"), - ("/fonts/inter.woff", "font/woff"), - ("/app.webmanifest", "application/manifest+json"), - ("/data.json", "application/json; charset=utf-8"), - ("/unknown.bin", "application/octet-stream"), - ("/no-extension", "application/octet-stream"), - ], -) -def test_content_type_for(path: str, expected: str) -> None: - assert _content_type_for(path) == expected - - -def test_content_type_case_insensitive() -> None: - # Uppercase extensions should still match - assert _content_type_for("/IMAGE.PNG") == "image/png" - assert _content_type_for("/script.JS") == "application/javascript; charset=utf-8" - - -# ── _safe_path (log injection guard) ───────────────────────────────── - - -def test_safe_path_strips_crlf() -> None: - from taskito.dashboard.server import _safe_path - - assert _safe_path("/api/jobs\r\nFAKE LOG ENTRY") == "/api/jobsFAKE LOG ENTRY" - - -def test_safe_path_strips_null_byte() -> None: - from taskito.dashboard.server import _safe_path - - assert _safe_path("/api/jobs\x00admin") == "/api/jobsadmin" - - -def test_safe_path_strips_all_control_chars_except_tab() -> None: - from taskito.dashboard.server import _safe_path - - raw = "/api\x01\x02\x1f\x7fpath\twith-tab" - assert _safe_path(raw) == "/apipath\twith-tab" - - -def test_safe_path_truncates_long_input() -> None: - from taskito.dashboard.server import _safe_path - - raw = "/api/" + "x" * 1000 - out = _safe_path(raw) - assert len(out) == 256 - assert out.startswith("/api/x") diff --git a/tests/test_dependencies.py b/tests/test_dependencies.py deleted file mode 100644 index d5fa75d..0000000 --- a/tests/test_dependencies.py +++ /dev/null @@ -1,117 +0,0 @@ -"""Tests for task dependency feature.""" - -import threading - -import pytest - -from taskito import Queue - - -def test_enqueue_with_depends_on(queue: Queue) -> None: - """Jobs can declare dependencies on other jobs.""" - - @queue.task() - def step(x: int) -> int: - return x - - job_a = step.delay(1) - job_b = step.apply_async(args=(2,), depends_on=job_a.id) - - assert job_b.dependencies == [job_a.id] - assert job_a.dependents == [job_b.id] - - -def test_enqueue_with_multiple_deps(queue: Queue) -> None: - """Jobs can depend on multiple other jobs.""" - - @queue.task() - def step(x: int) -> int: - return x - - job_a = step.delay(1) - job_b = step.delay(2) - job_c = step.apply_async(args=(3,), depends_on=[job_a.id, job_b.id]) - - deps = set(job_c.dependencies) - assert deps == {job_a.id, job_b.id} - - -def test_depends_on_string_coercion(queue: Queue) -> None: - """depends_on accepts a single string ID.""" - - @queue.task() - def step(x: int) -> int: - return x - - job_a = step.delay(1) - # Pass as single string, not list - job_b = queue.enqueue( - task_name=step.name, - args=(2,), - depends_on=job_a.id, - ) - - assert job_b.dependencies == [job_a.id] - - -def test_dependency_blocks_execution(queue: Queue) -> None: - """Dependent job waits until dependency completes.""" - - @queue.task() - def step(x: int) -> int: - return x * 10 - - job_a = step.delay(1) - job_b = step.apply_async(args=(2,), depends_on=job_a.id) - - worker_thread = threading.Thread(target=queue.run_worker, daemon=True) - worker_thread.start() - - # Both should complete — job_a first, then job_b - result_a = job_a.result(timeout=10) - result_b = job_b.result(timeout=10) - - assert result_a == 10 - assert result_b == 20 - - -def test_cascade_cancel_on_job_cancel(queue: Queue) -> None: - """Cancelling a job cascades to its dependents.""" - - @queue.task() - def step(x: int) -> int: - return x - - job_a = step.delay(1) - job_b = step.apply_async(args=(2,), depends_on=job_a.id) - job_c = step.apply_async(args=(3,), depends_on=job_b.id) - - queue.cancel_job(job_a.id) - - job_b.refresh() - job_c.refresh() - assert job_b.status == "cancelled" - assert job_c.status == "cancelled" - - -def test_no_dependencies_property_when_none(queue: Queue) -> None: - """Jobs without dependencies return empty list.""" - - @queue.task() - def step(x: int) -> int: - return x - - job = step.delay(1) - assert job.dependencies == [] - assert job.dependents == [] - - -def test_enqueue_rejects_missing_dependency(queue: Queue) -> None: - """Enqueuing with a nonexistent dependency raises an error.""" - - @queue.task() - def step(x: int) -> int: - return x - - with pytest.raises(RuntimeError): - step.apply_async(args=(1,), depends_on="nonexistent-id") diff --git a/tests/test_dlq.py b/tests/test_dlq.py deleted file mode 100644 index 14b484d..0000000 --- a/tests/test_dlq.py +++ /dev/null @@ -1,48 +0,0 @@ -"""Tests for dead letter queue management.""" - -import threading -from typing import Any - -from taskito import Queue - -PollUntil = Any # the conftest fixture's runtime type - - -def test_dead_letters_empty(queue: Queue) -> None: - """Empty DLQ returns empty list.""" - dead = queue.dead_letters() - assert dead == [] - - -def test_purge_dead(queue: Queue, poll_until: PollUntil) -> None: - """Purging removes old dead letter entries.""" - - @queue.task(max_retries=0, retry_backoff=0.1) - def instant_fail() -> None: - raise RuntimeError("fail") - - instant_fail.delay() - - worker_thread = threading.Thread( - target=queue.run_worker, - daemon=True, - ) - worker_thread.start() - - poll_until( - lambda: len(queue.dead_letters()) >= 1, - timeout=10, - message="failed job did not reach DLQ", - ) - dead = queue.dead_letters() - assert len(dead) >= 1 - - # Purge entries older than 0 seconds (purge everything) - purged = queue.purge_dead(older_than=0) - # Note: purge uses "older_than" seconds ago, so older_than=0 means - # cutoff is now, which purges everything before now - # The entry was just created so it might not be purged with 0 - # Use a large value instead - queue.dead_letters() - # Just verify the API works without error - assert isinstance(purged, int) diff --git a/tests/test_events.py b/tests/test_events.py deleted file mode 100644 index 6d6e5c6..0000000 --- a/tests/test_events.py +++ /dev/null @@ -1,103 +0,0 @@ -"""Tests for EventBus event dispatch.""" - -import time -from typing import Any - -from taskito.events import EventBus, EventType - -PollUntil = Any # the conftest fixture's runtime type - - -def test_callback_receives_event(poll_until: PollUntil) -> None: - """Registered callbacks receive emitted events.""" - received: list[tuple[EventType, dict[str, Any]]] = [] - bus = EventBus() - bus.on(EventType.JOB_COMPLETED, lambda et, p: received.append((et, p))) - - bus.emit(EventType.JOB_COMPLETED, {"job_id": "123"}) - poll_until(lambda: len(received) >= 1, message="event was not delivered") - - assert len(received) == 1 - assert received[0][0] == EventType.JOB_COMPLETED - assert received[0][1]["job_id"] == "123" - - -def test_multiple_callbacks(poll_until: PollUntil) -> None: - """Multiple callbacks for the same event type all fire.""" - counts = {"a": 0, "b": 0} - bus = EventBus() - bus.on(EventType.JOB_FAILED, lambda et, p: counts.__setitem__("a", counts["a"] + 1)) - bus.on(EventType.JOB_FAILED, lambda et, p: counts.__setitem__("b", counts["b"] + 1)) - - bus.emit(EventType.JOB_FAILED, {"error": "boom"}) - poll_until( - lambda: counts["a"] == 1 and counts["b"] == 1, - message="not all callbacks fired", - ) - - assert counts["a"] == 1 - assert counts["b"] == 1 - - -def test_event_filtering() -> None: - """Callbacks only fire for their registered event type.""" - received: list[str] = [] - bus = EventBus() - bus.on(EventType.JOB_COMPLETED, lambda et, p: received.append("completed")) - - bus.emit(EventType.JOB_FAILED, {"error": "boom"}) - # Brief settle so a (would-be incorrect) cross-type dispatch could land. - time.sleep(0.2) - - assert received == [] - - -def test_exception_in_callback_does_not_crash(poll_until: PollUntil) -> None: - """A raising callback doesn't prevent other events from processing.""" - results: list[str] = [] - bus = EventBus() - - def bad_callback(et: EventType, p: dict[str, Any]) -> None: - raise RuntimeError("callback error") - - def good_callback(et: EventType, p: dict[str, Any]) -> None: - results.append("ok") - - bus.on(EventType.JOB_ENQUEUED, bad_callback) - bus.on(EventType.JOB_ENQUEUED, good_callback) - - bus.emit(EventType.JOB_ENQUEUED, {}) - poll_until(lambda: results == ["ok"], message="good_callback did not run") - - assert results == ["ok"] - - -def test_emit_with_no_listeners() -> None: - """Emitting an event with no listeners doesn't raise.""" - bus = EventBus() - bus.emit(EventType.JOB_DEAD, {"job_id": "456"}) - - -def test_all_event_types_exist() -> None: - """All expected event types are defined.""" - expected = { - "job.enqueued", - "job.completed", - "job.failed", - "job.retrying", - "job.dead", - "job.cancelled", - "worker.started", - "worker.stopped", - "worker.online", - "worker.offline", - "worker.unhealthy", - "queue.paused", - "queue.resumed", - "workflow.submitted", - "workflow.completed", - "workflow.failed", - "workflow.cancelled", - "workflow.gate_reached", - } - assert {e.value for e in EventType} == expected diff --git a/tests/test_fastapi.py b/tests/test_fastapi.py deleted file mode 100644 index fdbcc1f..0000000 --- a/tests/test_fastapi.py +++ /dev/null @@ -1,190 +0,0 @@ -"""Tests for FastAPI integration.""" - -import threading -from typing import Any - -import pytest - -# Skip entire module if fastapi is not installed -fastapi = pytest.importorskip("fastapi") -httpx = pytest.importorskip("httpx") - -from fastapi import FastAPI # noqa: E402 -from fastapi.testclient import TestClient # noqa: E402 - -from taskito import Queue # noqa: E402 -from taskito.contrib.fastapi import TaskitoRouter # noqa: E402 - - -@pytest.fixture -def app(queue: Queue) -> FastAPI: - """Create a FastAPI app with TaskitoRouter.""" - app = FastAPI() - app.include_router(TaskitoRouter(queue), prefix="/tasks") - return app - - -@pytest.fixture -def client(app: FastAPI) -> TestClient: - """Create a TestClient.""" - return TestClient(app) - - -@pytest.fixture -def populated(queue: Queue, client: TestClient) -> tuple[Queue, TestClient, list[Any], Any]: - """Queue with a task and some jobs.""" - - @queue.task() - def add(a: int, b: int) -> int: - return a + b - - jobs = [add.delay(i, i + 1) for i in range(5)] - return queue, client, jobs, add - - -# ── Stats ──────────────────────────────────────────────── - - -def test_stats(populated: tuple[Queue, TestClient, list[Any], Any]) -> None: - _queue, client, _jobs, _add = populated - resp = client.get("/tasks/stats") - assert resp.status_code == 200 - data = resp.json() - assert data["pending"] == 5 - assert data["running"] == 0 - - -# ── Job detail ─────────────────────────────────────────── - - -def test_get_job(populated: tuple[Queue, TestClient, list[Any], Any]) -> None: - _queue, client, jobs, _add = populated - job_id = jobs[0].id - resp = client.get(f"/tasks/jobs/{job_id}") - assert resp.status_code == 200 - data = resp.json() - assert data["id"] == job_id - assert data["status"] == "pending" - - -def test_get_job_not_found(client: TestClient) -> None: - resp = client.get("/tasks/jobs/nonexistent") - assert resp.status_code == 404 - - -# ── Job errors ─────────────────────────────────────────── - - -def test_get_job_errors_empty(populated: tuple[Queue, TestClient, list[Any], Any]) -> None: - _queue, client, jobs, _add = populated - job_id = jobs[0].id - resp = client.get(f"/tasks/jobs/{job_id}/errors") - assert resp.status_code == 200 - assert resp.json() == [] - - -# ── Job result ─────────────────────────────────────────── - - -def test_get_job_result_pending(populated: tuple[Queue, TestClient, list[Any], Any]) -> None: - _queue, client, jobs, _add = populated - job_id = jobs[0].id - resp = client.get(f"/tasks/jobs/{job_id}/result") - assert resp.status_code == 200 - data = resp.json() - assert data["status"] == "pending" - assert data["result"] is None - - -def test_get_job_result_completed(populated: tuple[Queue, TestClient, list[Any], Any]) -> None: - queue, client, jobs, _add = populated - - worker = threading.Thread(target=queue.run_worker, daemon=True) - worker.start() - - # Wait for first job to complete - jobs[0].result(timeout=10) - - resp = client.get(f"/tasks/jobs/{jobs[0].id}/result") - assert resp.status_code == 200 - data = resp.json() - assert data["status"] == "complete" - assert data["result"] == 1 # add(0, 1) = 1 - - -# ── Cancel ─────────────────────────────────────────────── - - -def test_cancel_job(populated: tuple[Queue, TestClient, list[Any], Any]) -> None: - _queue, client, jobs, _add = populated - job_id = jobs[0].id - resp = client.post(f"/tasks/jobs/{job_id}/cancel") - assert resp.status_code == 200 - assert resp.json()["cancelled"] is True - - # Cancel again — should be false - resp = client.post(f"/tasks/jobs/{job_id}/cancel") - assert resp.json()["cancelled"] is False - - -# ── Dead letters ───────────────────────────────────────── - - -def test_dead_letters_empty(client: TestClient) -> None: - resp = client.get("/tasks/dead-letters") - assert resp.status_code == 200 - assert resp.json() == [] - - -# ── Progress SSE ───────────────────────────────────────── - - -def test_progress_stream(populated: tuple[Queue, TestClient, list[Any], Any]) -> None: - queue, client, jobs, _add = populated - - # Start worker so the job completes - worker = threading.Thread(target=queue.run_worker, daemon=True) - worker.start() - - job_id = jobs[0].id - # Wait for job to finish first - jobs[0].result(timeout=10) - - with client.stream("GET", f"/tasks/jobs/{job_id}/progress") as resp: - assert resp.status_code == 200 - lines: list[str] = [] - for line in resp.iter_lines(): - if line.startswith("data:"): - lines.append(line) - break # Just check first event - - assert len(lines) >= 1 - # The first (and only) event should show complete status - import json - - data = json.loads(lines[0].replace("data: ", "")) - assert data["status"] == "complete" - - -def test_progress_stream_not_found(client: TestClient) -> None: - resp = client.get("/tasks/jobs/nonexistent/progress") - assert resp.status_code == 404 - - -# ── Router config ──────────────────────────────────────── - - -def test_router_custom_tags(queue: Queue) -> None: - """TaskitoRouter accepts standard APIRouter kwargs.""" - router = TaskitoRouter(queue, tags=["my-tasks"]) - assert "my-tasks" in router.tags - - -def test_router_custom_prefix(queue: Queue) -> None: - """Router can be mounted with a custom prefix.""" - app = FastAPI() - app.include_router(TaskitoRouter(queue), prefix="/api/v1/queue") - client = TestClient(app) - - resp = client.get("/api/v1/queue/stats") - assert resp.status_code == 200 diff --git a/tests/test_hooks.py b/tests/test_hooks.py deleted file mode 100644 index e96d4f5..0000000 --- a/tests/test_hooks.py +++ /dev/null @@ -1,83 +0,0 @@ -"""Tests for the hooks/middleware system.""" - -from __future__ import annotations - -import threading -from typing import Any - -from taskito import Queue - - -def test_before_and_after_hooks(queue: Queue) -> None: - """before_task and after_task hooks fire around task execution.""" - events: list[tuple[Any, ...]] = [] - - @queue.before_task - def on_before(task_name: str, args: tuple, kwargs: dict) -> None: - events.append(("before", task_name)) - - @queue.after_task - def on_after(task_name: str, args: tuple, kwargs: dict, result: Any, error: Any) -> None: - events.append(("after", task_name, result, error)) - - @queue.task() - def add(a: int, b: int) -> int: - return a + b - - job = add.delay(1, 2) - - worker_thread = threading.Thread(target=queue.run_worker, daemon=True) - worker_thread.start() - - result = job.result(timeout=10) - assert result == 3 - - # Verify hooks fired - assert any(e[0] == "before" for e in events) - assert any(e[0] == "after" and len(e) > 2 and e[2] == 3 and e[3] is None for e in events) - - -def test_on_success_hook(queue: Queue) -> None: - """on_success hook fires when task succeeds.""" - success_results: list[Any] = [] - - @queue.on_success - def on_success(task_name: str, args: tuple, kwargs: dict, result: Any) -> None: - success_results.append(result) - - @queue.task() - def multiply(a: int, b: int) -> int: - return a * b - - job = multiply.delay(3, 4) - - worker_thread = threading.Thread(target=queue.run_worker, daemon=True) - worker_thread.start() - - result = job.result(timeout=10) - assert result == 12 - assert 12 in success_results - - -def test_on_failure_hook(queue: Queue) -> None: - """on_failure hook fires when task raises.""" - failure_errors: list[str] = [] - - @queue.on_failure - def on_failure(task_name: str, args: tuple, kwargs: dict, error: Exception) -> None: - failure_errors.append(str(error)) - - @queue.task(max_retries=1, retry_backoff=0.1) - def always_fails() -> None: - raise ValueError("boom") - - always_fails.delay() - - worker_thread = threading.Thread(target=queue.run_worker, daemon=True) - worker_thread.start() - - import time - - time.sleep(3) - - assert any("boom" in e for e in failure_errors) diff --git a/tests/test_idempotent.py b/tests/test_idempotent.py deleted file mode 100644 index fefe3ef..0000000 --- a/tests/test_idempotent.py +++ /dev/null @@ -1,123 +0,0 @@ -"""Tests for auto-derived idempotency on @queue.task(idempotent=True).""" - -from __future__ import annotations - -import threading - -from taskito import Queue - - -def test_idempotent_task_dedupes_by_args(queue: Queue) -> None: - """Two enqueues with identical args under idempotent=True share a job.""" - - @queue.task(idempotent=True) - def charge(customer_id: int, amount: int) -> int: - return amount - - job1 = charge.delay(42, 1000) - job2 = charge.delay(42, 1000) - - assert job1.id == job2.id - - -def test_idempotent_task_distinct_args_distinct_jobs(queue: Queue) -> None: - """Different arguments produce different auto-keys → different jobs.""" - - @queue.task(idempotent=True) - def charge(customer_id: int, amount: int) -> int: - return amount - - job_a = charge.delay(1, 100) - job_b = charge.delay(2, 100) - job_c = charge.delay(1, 200) - - assert len({job_a.id, job_b.id, job_c.id}) == 3 - - -def test_idempotent_per_call_key_overrides_auto(queue: Queue) -> None: - """An explicit idempotency_key wins over the auto-derived key.""" - - @queue.task(idempotent=True) - def process(data: str) -> str: - return data - - auto_job = process.delay("payload") - explicit_job = process.apply_async(args=("payload",), idempotency_key="custom-key") - - # Auto-key and explicit-key collide on the same args only by coincidence — - # since the explicit key is "custom-key", they should be different jobs. - assert auto_job.id != explicit_job.id - - -def test_idempotent_per_call_disable_creates_new_job(queue: Queue) -> None: - """idempotent=False on apply_async overrides the per-task default.""" - - @queue.task(idempotent=True) - def process(data: str) -> str: - return data - - auto_job = process.delay("same-args") - forced_new = process.apply_async(args=("same-args",), idempotent=False) - - assert auto_job.id != forced_new.id - - -def test_non_idempotent_task_allows_duplicates(queue: Queue) -> None: - """Without idempotent=True, identical calls produce distinct jobs.""" - - @queue.task() - def process(data: str) -> str: - return data - - job1 = process.delay("payload") - job2 = process.delay("payload") - - assert job1.id != job2.id - - -def test_idempotent_clears_after_completion(queue: Queue) -> None: - """After an idempotent job finishes, the next call creates a new job.""" - - @queue.task(idempotent=True) - def fast() -> str: - return "done" - - job1 = fast.delay() - - worker = threading.Thread(target=queue.run_worker, daemon=True) - worker.start() - - job1.result(timeout=10) - - job2 = fast.delay() - assert job2.id != job1.id - - -def test_idempotent_via_enqueue_kwarg(queue: Queue) -> None: - """Per-call idempotent=True works without a registered task default.""" - - @queue.task() - def process(data: str) -> str: - return data - - job1 = process.apply_async(args=("x",), idempotent=True) - job2 = process.apply_async(args=("x",), idempotent=True) - job3 = process.apply_async(args=("y",), idempotent=True) - - assert job1.id == job2.id - assert job1.id != job3.id - - -def test_idempotent_unique_key_takes_precedence(queue: Queue) -> None: - """An explicit unique_key beats both auto-derivation and idempotency_key.""" - - @queue.task(idempotent=True) - def process(data: str) -> str: - return data - - explicit = process.apply_async(args=("payload",), unique_key="explicit-uk") - auto = process.delay("payload") - - # The auto key is "auto:" and the explicit one is "explicit-uk", - # so they must differ. - assert explicit.id != auto.id diff --git a/tests/test_interception.py b/tests/test_interception.py deleted file mode 100644 index 4f7db15..0000000 --- a/tests/test_interception.py +++ /dev/null @@ -1,446 +0,0 @@ -"""Tests for the argument interception system (Layer 1).""" - -from __future__ import annotations - -import datetime -import decimal -import enum -import pathlib -import re -import socket -import threading -import uuid -from dataclasses import dataclass -from typing import Any - -import pytest - -from taskito import Queue -from taskito.interception import ( - ArgumentInterceptor, - InterceptionError, - InterceptionReport, -) -from taskito.interception.built_in import build_default_registry -from taskito.interception.converters import reconstruct_converted -from taskito.interception.reconstruct import reconstruct_args -from taskito.interception.strategy import Strategy - -# -- Fixtures -- - - -@pytest.fixture -def registry() -> Any: - return build_default_registry() - - -@pytest.fixture -def strict(registry: Any) -> ArgumentInterceptor: - return ArgumentInterceptor(registry, mode="strict") - - -@pytest.fixture -def lenient(registry: Any) -> ArgumentInterceptor: - return ArgumentInterceptor(registry, mode="lenient") - - -# -- PASS strategy -- - - -class TestPassStrategy: - def test_int_passes_through(self, strict: ArgumentInterceptor) -> None: - args, _kw = strict.intercept((42,), {}) - assert args == (42,) - - def test_str_passes_through(self, strict: ArgumentInterceptor) -> None: - args, _kw = strict.intercept(("hello",), {}) - assert args == ("hello",) - - def test_float_passes_through(self, strict: ArgumentInterceptor) -> None: - args, _kw = strict.intercept((3.14,), {}) - assert args == (3.14,) - - def test_bool_passes_through(self, strict: ArgumentInterceptor) -> None: - args, _kw = strict.intercept((True, False), {}) - assert args == (True, False) - - def test_none_passes_through(self, strict: ArgumentInterceptor) -> None: - args, _kw = strict.intercept((None,), {}) - assert args == (None,) - - def test_bytes_passes_through(self, strict: ArgumentInterceptor) -> None: - args, _kw = strict.intercept((b"data",), {}) - assert args == (b"data",) - - def test_mixed_primitives(self, strict: ArgumentInterceptor) -> None: - args, kwargs = strict.intercept( - (1, "two", 3.0, True, None, b"six"), - {"key": "val"}, - ) - assert args == (1, "two", 3.0, True, None, b"six") - assert kwargs == {"key": "val"} - - -# -- CONVERT strategy -- - - -class TestConvertStrategy: - def test_uuid_round_trip(self, strict: ArgumentInterceptor) -> None: - original = uuid.UUID("12345678-1234-5678-1234-567812345678") - args, _kw = strict.intercept((original,), {}) - assert args[0]["__taskito_convert__"] is True - assert args[0]["type_key"] == "uuid" - # Reconstruct - restored = reconstruct_converted(args[0]) - assert restored == original - - def test_datetime_round_trip(self, strict: ArgumentInterceptor) -> None: - original = datetime.datetime(2025, 3, 10, 12, 0, 0) - args, _ = strict.intercept((original,), {}) - assert args[0]["type_key"] == "datetime" - restored = reconstruct_converted(args[0]) - assert restored == original - - def test_date_round_trip(self, strict: ArgumentInterceptor) -> None: - original = datetime.date(2025, 3, 10) - args, _ = strict.intercept((original,), {}) - assert args[0]["type_key"] == "date" - restored = reconstruct_converted(args[0]) - assert restored == original - - def test_time_round_trip(self, strict: ArgumentInterceptor) -> None: - original = datetime.time(14, 30, 0) - args, _ = strict.intercept((original,), {}) - assert args[0]["type_key"] == "time" - restored = reconstruct_converted(args[0]) - assert restored == original - - def test_timedelta_round_trip(self, strict: ArgumentInterceptor) -> None: - original = datetime.timedelta(hours=1, minutes=30) - args, _ = strict.intercept((original,), {}) - assert args[0]["type_key"] == "timedelta" - restored = reconstruct_converted(args[0]) - assert restored == original - - def test_decimal_round_trip(self, strict: ArgumentInterceptor) -> None: - original = decimal.Decimal("3.14159") - args, _ = strict.intercept((original,), {}) - assert args[0]["type_key"] == "decimal" - restored = reconstruct_converted(args[0]) - assert restored == original - - def test_path_round_trip(self, strict: ArgumentInterceptor) -> None: - original = pathlib.Path("/tmp/data.csv") - args, _ = strict.intercept((original,), {}) - assert args[0]["type_key"] == "path" - restored = reconstruct_converted(args[0]) - assert restored == original - - def test_pattern_round_trip(self, strict: ArgumentInterceptor) -> None: - original = re.compile(r"\d+", re.IGNORECASE) - args, _ = strict.intercept((original,), {}) - assert args[0]["type_key"] == "pattern" - restored = reconstruct_converted(args[0]) - assert restored.pattern == original.pattern - assert restored.flags == original.flags - - def test_enum_converts(self, strict: ArgumentInterceptor) -> None: - class Color(enum.Enum): - RED = "red" - GREEN = "green" - - args, _ = strict.intercept((Color.RED,), {}) - assert args[0]["type_key"] == "enum" - assert args[0]["value"] == "red" - - def test_dataclass_converts(self, strict: ArgumentInterceptor) -> None: - @dataclass - class Point: - x: int - y: int - - original = Point(x=1, y=2) - args, _ = strict.intercept((original,), {}) - assert args[0]["__taskito_convert__"] is True - assert args[0]["type_key"] == "dataclass" - assert args[0]["value"] == {"x": 1, "y": 2} - - def test_datetime_before_date(self, strict: ArgumentInterceptor) -> None: - """datetime is a subclass of date — datetime must match first.""" - dt = datetime.datetime(2025, 1, 1, 12, 0, 0) - args, _ = strict.intercept((dt,), {}) - assert args[0]["type_key"] == "datetime" - - -# -- REDIRECT strategy -- - - -class TestRedirectStrategy: - def test_redirect_produces_marker(self, strict: ArgumentInterceptor) -> None: - """Test that redirect types produce markers (if sqlalchemy is installed).""" - try: - from sqlalchemy.orm import Session # type: ignore[import-not-found] # noqa: F401 - - sqlalchemy_available = True - except ImportError: - sqlalchemy_available = False - - if not sqlalchemy_available: - pytest.skip("sqlalchemy not installed") - - -# -- REJECT strategy -- - - -class TestRejectStrategy: - def test_threading_lock_rejected(self, strict: ArgumentInterceptor) -> None: - lock = threading.Lock() - with pytest.raises(InterceptionError) as exc_info: - strict.intercept((lock,), {}) - assert len(exc_info.value.failures) == 1 - assert "args[0]" in exc_info.value.failures[0].path - assert "lock" in exc_info.value.failures[0].type_name.lower() - - def test_socket_rejected(self, strict: ArgumentInterceptor) -> None: - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - try: - with pytest.raises(InterceptionError): - strict.intercept((sock,), {}) - finally: - sock.close() - - def test_generator_rejected(self, strict: ArgumentInterceptor) -> None: - gen = (x for x in range(10)) - with pytest.raises(InterceptionError): - strict.intercept((gen,), {}) - - def test_reject_error_includes_path(self, strict: ArgumentInterceptor) -> None: - lock = threading.Lock() - with pytest.raises(InterceptionError) as exc_info: - strict.intercept((), {"session": lock}) - assert "kwargs.session" in exc_info.value.failures[0].path - - def test_reject_error_has_suggestions(self, strict: ArgumentInterceptor) -> None: - lock = threading.Lock() - with pytest.raises(InterceptionError) as exc_info: - strict.intercept((lock,), {}) - assert len(exc_info.value.failures[0].suggestions) > 0 - - def test_multiple_rejections_collected(self, strict: ArgumentInterceptor) -> None: - lock = threading.Lock() - event = threading.Event() - with pytest.raises(InterceptionError) as exc_info: - strict.intercept((lock, event), {}) - assert len(exc_info.value.failures) == 2 - - -# -- Lenient mode -- - - -class TestLenientMode: - def test_rejected_arg_dropped(self, lenient: ArgumentInterceptor) -> None: - lock = threading.Lock() - args, _kw = lenient.intercept((42, lock), {}) - assert args[0] == 42 - assert args[1] is None # dropped to None - - def test_rejected_kwarg_dropped(self, lenient: ArgumentInterceptor) -> None: - lock = threading.Lock() - _args, kwargs = lenient.intercept((), {"x": 1, "lock": lock}) - assert kwargs == {"x": 1} - - -# -- Off mode -- - - -class TestOffMode: - def test_passthrough_no_interception(self, registry: Any) -> None: - interceptor = ArgumentInterceptor(registry, mode="off") - lock = threading.Lock() - args, _kw = interceptor.intercept((lock,), {}) - assert args[0] is lock - - -# -- Recursive walking -- - - -class TestRecursiveWalking: - def test_nested_uuid_in_dict(self, strict: ArgumentInterceptor) -> None: - uid = uuid.uuid4() - _, kwargs = strict.intercept((), {"config": {"user_id": uid}}) - assert kwargs["config"]["user_id"]["__taskito_convert__"] is True - - def test_uuid_in_list(self, strict: ArgumentInterceptor) -> None: - uid = uuid.uuid4() - args, _ = strict.intercept(([uid],), {}) - assert args[0][0]["__taskito_convert__"] is True - - def test_depth_limit(self, registry: Any) -> None: - interceptor = ArgumentInterceptor(registry, mode="strict", max_depth=2) - uid = uuid.uuid4() - # Depth 3 — beyond limit, should pass through - deep = {"a": {"b": {"c": uid}}} - _, kwargs = interceptor.intercept((), {"x": deep}) - # uid at depth 3 should pass through as-is (beyond max_depth=2) - assert kwargs["x"]["a"]["b"]["c"] is uid - - def test_circular_reference_handled(self, strict: ArgumentInterceptor) -> None: - d: dict[str, Any] = {"value": 42} - d["self"] = d # circular! - args, _ = strict.intercept((d,), {}) - # Should not infinite loop — circular ref is detected and passed through - assert args[0]["value"] == 42 - - -# -- Full round trip (intercept + reconstruct) -- - - -class TestRoundTrip: - def test_uuid_full_round_trip(self, strict: ArgumentInterceptor) -> None: - uid = uuid.UUID("aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee") - intercepted_args, intercepted_kwargs = strict.intercept((uid,), {"id": uid}) - args, kwargs, _redirects = reconstruct_args(intercepted_args, intercepted_kwargs) - assert args[0] == uid - assert kwargs["id"] == uid - - def test_mixed_types_round_trip(self, strict: ArgumentInterceptor) -> None: - dt = datetime.datetime(2025, 6, 15, 10, 30) - path = pathlib.Path("/data/file.txt") - intercepted_args, intercepted_kwargs = strict.intercept((42, "hello", dt), {"path": path}) - args, kwargs, _redirects = reconstruct_args(intercepted_args, intercepted_kwargs) - assert args[0] == 42 - assert args[1] == "hello" - assert args[2] == dt - assert kwargs["path"] == path - - def test_nested_convert_round_trip(self, strict: ArgumentInterceptor) -> None: - uid = uuid.uuid4() - intercepted_args, _ = strict.intercept(({"ids": [uid]},), {}) - args, _, _ = reconstruct_args(intercepted_args, {}) - assert args[0]["ids"][0] == uid - - -# -- Analyze / Report -- - - -class TestAnalyze: - def test_analyze_returns_report(self, strict: ArgumentInterceptor) -> None: - report = strict.analyze((42, "hello"), {"uid": uuid.uuid4()}) - assert isinstance(report, InterceptionReport) - assert len(report.entries) == 3 - - def test_analyze_shows_strategies(self, strict: ArgumentInterceptor) -> None: - uid = uuid.uuid4() - report = strict.analyze((42, uid), {}) - strategies = [e.strategy for e in report.entries] - assert Strategy.PASS in strategies - assert Strategy.CONVERT in strategies - - def test_analyze_on_off_mode_returns_empty(self, registry: Any) -> None: - interceptor = ArgumentInterceptor(registry, mode="off") - report = interceptor.analyze((42,), {}) - assert len(report.entries) == 0 - - def test_report_str_format(self, strict: ArgumentInterceptor) -> None: - report = strict.analyze((42,), {}) - text = str(report) - assert "Argument Analysis:" in text - - -# -- Queue integration -- - - -class TestQueueIntegration: - def test_queue_default_interception_off(self, tmp_path: Any) -> None: - q = Queue(db_path=str(tmp_path / "test.db")) - assert q._interceptor is None - - def test_queue_strict_mode(self, tmp_path: Any) -> None: - q = Queue(db_path=str(tmp_path / "test.db"), interception="strict") - assert q._interceptor is not None - assert q._interceptor.mode == "strict" - - def test_queue_enqueue_with_interception(self, tmp_path: Any) -> None: - q = Queue(db_path=str(tmp_path / "test.db"), interception="strict") - - @q.task() - def add(a: int, b: int) -> int: - return a + b - - # Simple args should work fine - result = add.delay(1, 2) - assert result.id is not None - - def test_queue_enqueue_rejects_lock(self, tmp_path: Any) -> None: - q = Queue(db_path=str(tmp_path / "test.db"), interception="strict") - - @q.task() - def bad_task(lock: Any) -> None: - pass - - with pytest.raises(InterceptionError): - bad_task.delay(threading.Lock()) - - def test_task_analyze(self, tmp_path: Any) -> None: - q = Queue(db_path=str(tmp_path / "test.db"), interception="strict") - - @q.task() - def my_task(user_id: int, created_at: datetime.datetime) -> None: - pass - - report = my_task.analyze(42, datetime.datetime.now()) - assert len(report.entries) == 2 - - def test_task_analyze_off_mode(self, tmp_path: Any) -> None: - q = Queue(db_path=str(tmp_path / "test.db")) - - @q.task() - def my_task(x: int) -> None: - pass - - report = my_task.analyze(42) - assert len(report.entries) == 0 - - -# -- Custom type registration -- - - -class TestCustomRegistration: - def test_register_custom_reject(self, registry: Any, strict: ArgumentInterceptor) -> None: - class MyLock: - pass - - registry.register( - MyLock, - Strategy.REJECT, - priority=40, - reject_reason="Use distributed locking instead.", - reject_suggestions=["Use queue.lock()"], - ) - with pytest.raises(InterceptionError) as exc_info: - strict.intercept((MyLock(),), {}) - assert "distributed locking" in str(exc_info.value) - - def test_register_custom_convert(self, registry: Any, strict: ArgumentInterceptor) -> None: - class Money: - def __init__(self, amount: int, currency: str) -> None: - self.amount = amount - self.currency = currency - - def convert_money(obj: Money) -> dict[str, Any]: - return { - "__taskito_convert__": True, - "type_key": "money", - "value": {"amount": obj.amount, "currency": obj.currency}, - } - - registry.register( - Money, - Strategy.CONVERT, - priority=15, - converter=convert_money, - type_key="money", - ) - args, _ = strict.intercept((Money(100, "USD"),), {}) - assert args[0]["__taskito_convert__"] is True - assert args[0]["value"]["amount"] == 100 diff --git a/tests/test_keda.py b/tests/test_keda.py deleted file mode 100644 index a298e34..0000000 --- a/tests/test_keda.py +++ /dev/null @@ -1,153 +0,0 @@ -"""Tests for KEDA-compatible /api/scaler endpoint contract.""" - -import json -import threading -import urllib.request -from collections.abc import Generator -from http.server import ThreadingHTTPServer -from pathlib import Path - -import pytest - -from taskito import Queue -from taskito.dashboard import build_scaler_response - - -@pytest.fixture() -def empty_queue(tmp_path: Path) -> Queue: - """Queue with no pending jobs.""" - return Queue(db_path=str(tmp_path / "keda.db"), workers=4) - - -@pytest.fixture() -def populated_queue(tmp_path: Path) -> Queue: - """Queue with pending jobs across two queues.""" - q = Queue(db_path=str(tmp_path / "keda.db"), workers=4) - - @q.task(queue="emails") - def send_email(to: str) -> None: - pass - - @q.task(queue="reports") - def generate_report(name: str) -> None: - pass - - for i in range(5): - send_email.delay(f"user{i}@example.com") - for i in range(3): - generate_report.delay(f"report_{i}") - - return q - - -@pytest.fixture() -def scaler_server(empty_queue: Queue) -> Generator[str]: - """Start a scaler HTTP server on a random port, yield the base URL.""" - from taskito.scaler import _make_scaler_handler - - handler = _make_scaler_handler(empty_queue, target_queue_depth=10) - server = ThreadingHTTPServer(("127.0.0.1", 0), handler) - port = server.server_address[1] - thread = threading.Thread(target=server.serve_forever, daemon=True) - thread.start() - yield f"http://127.0.0.1:{port}" - server.shutdown() - - -# ─── Unit tests: build_scaler_response() ─── - - -class TestScalerResponseShape: - def test_required_fields(self, empty_queue: Queue) -> None: - resp = build_scaler_response(empty_queue) - assert "metricName" in resp - assert "metricValue" in resp - assert "isActive" in resp - assert "liveWorkers" in resp - assert "totalCapacity" in resp - assert "targetQueueDepth" in resp - - def test_metric_value_is_pending_count(self, populated_queue: Queue) -> None: - resp = build_scaler_response(populated_queue) - assert resp["metricValue"] == 8 # 5 emails + 3 reports - - def test_is_active_false_when_empty(self, empty_queue: Queue) -> None: - resp = build_scaler_response(empty_queue) - assert resp["isActive"] is False - assert resp["metricValue"] == 0 - - def test_is_active_true_when_pending(self, populated_queue: Queue) -> None: - resp = build_scaler_response(populated_queue) - assert resp["isActive"] is True - - def test_per_queue_filter(self, populated_queue: Queue) -> None: - resp = build_scaler_response(populated_queue, queue_name="emails") - assert resp["metricValue"] == 5 - assert resp["isActive"] is True - - def test_metric_name_namespaced_for_queue(self, populated_queue: Queue) -> None: - resp = build_scaler_response(populated_queue, queue_name="reports") - assert resp["metricName"] == "taskito_queue_depth_reports" - - def test_default_metric_name(self, empty_queue: Queue) -> None: - resp = build_scaler_response(empty_queue) - assert resp["metricName"] == "taskito_queue_depth" - - def test_worker_utilization_present(self, empty_queue: Queue) -> None: - resp = build_scaler_response(empty_queue) - # workers=4, 0 running → utilization = 0.0 - assert "workerUtilization" in resp - assert resp["workerUtilization"] == 0.0 - - def test_worker_utilization_is_zero_when_idle(self, empty_queue: Queue) -> None: - resp = build_scaler_response(empty_queue) - assert resp["workerUtilization"] == 0.0 - - def test_per_queue_stats_present(self, populated_queue: Queue) -> None: - resp = build_scaler_response(populated_queue) - assert "perQueue" in resp - assert "emails" in resp["perQueue"] - assert "reports" in resp["perQueue"] - assert resp["perQueue"]["emails"]["pending"] == 5 - assert resp["perQueue"]["reports"]["pending"] == 3 - - def test_target_queue_depth_default(self, empty_queue: Queue) -> None: - resp = build_scaler_response(empty_queue) - assert resp["targetQueueDepth"] == 10 - - def test_target_queue_depth_custom(self, empty_queue: Queue) -> None: - resp = build_scaler_response(empty_queue, target_queue_depth=25) - assert resp["targetQueueDepth"] == 25 - - def test_live_workers_zero_before_start(self, empty_queue: Queue) -> None: - resp = build_scaler_response(empty_queue) - assert resp["liveWorkers"] == 0 - - def test_total_capacity_matches_config(self, empty_queue: Queue) -> None: - resp = build_scaler_response(empty_queue) - assert resp["totalCapacity"] == 4 - - -# ─── Integration tests: HTTP server ─── - - -class TestScalerHTTP: - def test_scaler_endpoint_returns_json(self, scaler_server: str) -> None: - resp = urllib.request.urlopen(f"{scaler_server}/api/scaler") - assert resp.status == 200 - data = json.loads(resp.read()) - assert "metricName" in data - assert "metricValue" in data - - def test_health_endpoint(self, scaler_server: str) -> None: - resp = urllib.request.urlopen(f"{scaler_server}/health") - assert resp.status == 200 - data = json.loads(resp.read()) - assert data["status"] == "ok" - - def test_unknown_path_returns_404(self, scaler_server: str) -> None: - try: - urllib.request.urlopen(f"{scaler_server}/unknown") - pytest.fail("Expected HTTPError") - except urllib.error.HTTPError as e: - assert e.code == 404 diff --git a/tests/test_namespace.py b/tests/test_namespace.py deleted file mode 100644 index 704b343..0000000 --- a/tests/test_namespace.py +++ /dev/null @@ -1,118 +0,0 @@ -"""Tests for namespace-based routing and isolation.""" - -import threading -from pathlib import Path - -from taskito import Queue - - -def test_namespace_enqueue_sets_namespace(tmp_path: Path) -> None: - """Jobs enqueued on a namespaced Queue carry the namespace.""" - queue = Queue(db_path=str(tmp_path / "test.db"), namespace="team-a") - - @queue.task() - def add(x: int, y: int) -> int: - return x + y - - job = add.delay(1, 2) - py_job = queue._inner.get_job(job.id) - assert py_job is not None - assert py_job.namespace == "team-a" - - -def test_no_namespace_jobs_have_none(tmp_path: Path) -> None: - """Jobs enqueued without a namespace have namespace=None.""" - queue = Queue(db_path=str(tmp_path / "test.db")) - - @queue.task() - def noop() -> None: - pass - - job = noop.delay() - py_job = queue._inner.get_job(job.id) - assert py_job is not None - assert py_job.namespace is None - - -def test_namespace_isolation_worker(tmp_path: Path) -> None: - """A namespaced worker only processes jobs from its namespace.""" - db = str(tmp_path / "test.db") - - # Create two queues sharing the same DB but different namespaces - q_a = Queue(db_path=db, namespace="team-a") - q_b = Queue(db_path=db, namespace="team-b") - - results: list[str] = [] - - @q_a.task() - def task_a() -> str: - results.append("a") - return "a" - - @q_b.task() - def task_b() -> str: - results.append("b") - return "b" - - # Enqueue one job on each namespace - job_a = task_a.delay() - job_b = task_b.delay() - - # Run worker for team-a only - worker = threading.Thread(target=q_a.run_worker, daemon=True) - worker.start() - - # Wait for team-a's job to complete - job_a.result(timeout=10) - - # team-b's job should still be pending - job_b.refresh() - assert job_b.status == "pending" - - # Shut down team-a worker (daemon thread exits on its own) - q_a._inner.request_shutdown() - - -def test_namespace_list_jobs_scoped(tmp_path: Path) -> None: - """list_jobs defaults to the queue's namespace.""" - db = str(tmp_path / "test.db") - q_a = Queue(db_path=db, namespace="ns-a") - q_b = Queue(db_path=db, namespace="ns-b") - - @q_a.task() - def task_x() -> None: - pass - - @q_b.task() - def task_y() -> None: - pass - - task_x.delay() - task_x.delay() - task_y.delay() - - # Each queue sees only its own jobs by default - assert len(q_a.list_jobs()) == 2 - assert len(q_b.list_jobs()) == 1 - - # Passing namespace=None shows all - assert len(q_a.list_jobs(namespace=None)) == 3 - - -def test_namespace_preserved_in_job_result(tmp_path: Path) -> None: - """JobResult.to_dict() includes the namespace.""" - queue = Queue(db_path=str(tmp_path / "test.db"), namespace="my-ns") - - @queue.task() - def greet(name: str) -> str: - return f"hi {name}" - - job = greet.delay("world") - - worker = threading.Thread(target=queue.run_worker, daemon=True) - worker.start() - - result = job.result(timeout=10) - assert result == "hi world" - - queue._inner.request_shutdown() diff --git a/tests/test_native_async.py b/tests/test_native_async.py deleted file mode 100644 index f6fa7d1..0000000 --- a/tests/test_native_async.py +++ /dev/null @@ -1,494 +0,0 @@ -"""Tests for native async task support.""" - -from __future__ import annotations - -import asyncio -import threading -from pathlib import Path -from typing import Any -from unittest.mock import MagicMock - -from taskito import Queue, TaskCancelledError, current_job -from taskito.async_support.context import ( - clear_async_context, - get_async_context, - set_async_context, -) -from taskito.middleware import TaskMiddleware - -PollUntil = Any # the conftest fixture's runtime type - -# ── Async detection ────────────────────────────────────────────── - - -def test_async_task_detected(tmp_path: Path) -> None: - """_taskito_is_async is True for async functions.""" - queue = Queue(db_path=str(tmp_path / "test.db")) - - @queue.task() - async def my_async_task() -> None: - pass - - assert my_async_task._taskito_is_async is True - assert hasattr(my_async_task, "_taskito_async_fn") - - -def test_sync_task_not_async(tmp_path: Path) -> None: - """_taskito_is_async is False for sync functions.""" - queue = Queue(db_path=str(tmp_path / "test.db")) - - @queue.task() - def my_sync_task() -> None: - pass - - assert my_sync_task._taskito_is_async is False - assert my_sync_task._taskito_async_fn is None - - -# ── Async context (contextvars) ────────────────────────────────── - - -def test_async_context_var() -> None: - """set/get/clear async context via contextvars.""" - token = set_async_context("job-1", "my_task", 0, "default") - ctx = get_async_context() - assert ctx is not None - assert ctx.job_id == "job-1" - assert ctx.task_name == "my_task" - assert ctx.retry_count == 0 - assert ctx.queue_name == "default" - clear_async_context(token) - assert get_async_context() is None - - -def test_async_context_isolated_between_tasks() -> None: - """Each async task gets its own contextvar context (no cross-contamination).""" - results: list[str | None] = [] - - async def coro(job_id: str) -> None: - token = set_async_context(job_id, "task", 0, "q") - await asyncio.sleep(0.01) - ctx = get_async_context() - results.append(ctx.job_id if ctx else None) - clear_async_context(token) - - async def run_both() -> None: - await asyncio.gather(coro("a"), coro("b")) - - asyncio.run(run_both()) - - assert sorted(r for r in results if r is not None) == ["a", "b"] - - -def test_sync_context_unchanged(tmp_path: Path) -> None: - """current_job still works via threading.local for sync tasks.""" - from taskito.context import _clear_context, _set_context - - _set_context("sync-job", "sync_task", 2, "high") - assert current_job.id == "sync-job" - assert current_job.task_name == "sync_task" - assert current_job.retry_count == 2 - assert current_job.queue_name == "high" - _clear_context() - - -def test_async_context_fallback_to_sync() -> None: - """_require_context falls back to threading.local when no async context.""" - from taskito.context import _clear_context, _set_context - - # No async context set - assert get_async_context() is None - # Set sync context - _set_context("sync-id", "t", 0, "q") - assert current_job.id == "sync-id" - _clear_context() - - -def test_async_context_preferred_over_sync() -> None: - """When both async and sync contexts exist, async wins.""" - from taskito.context import _clear_context, _set_context - - _set_context("sync-id", "t", 0, "q") - token = set_async_context("async-id", "t", 0, "q") - assert current_job.id == "async-id" - clear_async_context(token) - assert current_job.id == "sync-id" - _clear_context() - - -# ── AsyncTaskExecutor unit tests ───────────────────────────────── - - -def test_async_executor_lifecycle() -> None: - """Start/stop executor without errors.""" - from taskito.async_support.executor import AsyncTaskExecutor - - sender = MagicMock() - registry: dict[str, Any] = {} - queue_ref = MagicMock() - executor = AsyncTaskExecutor(sender, registry, queue_ref, max_concurrency=10) - executor.start() - assert executor._loop is not None - assert executor._loop.is_running() - executor.stop() - - -def test_async_executor_submit_and_execute(poll_until: PollUntil) -> None: - """Basic async task produces correct result via executor.""" - import cloudpickle - - from taskito.async_support.executor import AsyncTaskExecutor - - sender = MagicMock() - - async def my_task(x: int, y: int) -> int: - return x + y - - # Build a minimal wrapper that the executor expects - class FakeWrapper: - _taskito_async_fn = staticmethod(my_task) - - registry: dict[str, Any] = {"test_mod.my_task": FakeWrapper()} - - queue_ref = MagicMock() - queue_ref._interceptor = None - queue_ref._proxy_registry = None - queue_ref._test_mode_active = False - queue_ref._resource_runtime = None - queue_ref._task_inject_map = {} - queue_ref._task_retry_filters = {} - queue_ref._get_middleware_chain.return_value = [] - queue_ref._proxy_metrics = None - - executor = AsyncTaskExecutor(sender, registry, queue_ref, max_concurrency=10) - executor.start() - - payload = cloudpickle.dumps(((2, 3), {})) - executor.submit_job("job-1", "test_mod.my_task", payload, 0, 3, "default") - poll_until(lambda: sender.report_success.called, message="job-1 result not reported") - executor.stop() - - sender.report_success.assert_called_once() - call_args = sender.report_success.call_args - assert call_args[0][0] == "job-1" - assert call_args[0][1] == "test_mod.my_task" - result = cloudpickle.loads(call_args[0][2]) - assert result == 5 - - -def test_async_exception_reported(poll_until: PollUntil) -> None: - """Exception in async task → failure result with traceback.""" - import cloudpickle - - from taskito.async_support.executor import AsyncTaskExecutor - - sender = MagicMock() - - async def failing_task() -> None: - raise ValueError("boom") - - class FakeWrapper: - _taskito_async_fn = staticmethod(failing_task) - - registry: dict[str, Any] = {"mod.failing_task": FakeWrapper()} - - queue_ref = MagicMock() - queue_ref._interceptor = None - queue_ref._proxy_registry = None - queue_ref._test_mode_active = False - queue_ref._resource_runtime = None - queue_ref._task_inject_map = {} - queue_ref._task_retry_filters = {} - queue_ref._get_middleware_chain.return_value = [] - queue_ref._proxy_metrics = None - - executor = AsyncTaskExecutor(sender, registry, queue_ref, max_concurrency=10) - executor.start() - - payload = cloudpickle.dumps(((), {})) - executor.submit_job("job-2", "mod.failing_task", payload, 0, 3, "default") - poll_until(lambda: sender.report_failure.called, message="job-2 failure not reported") - executor.stop() - - sender.report_failure.assert_called_once() - call_args = sender.report_failure.call_args - assert call_args[0][0] == "job-2" - assert "boom" in call_args[0][2] - assert call_args[0][6] is True # should_retry - - -def test_async_cancellation(poll_until: PollUntil) -> None: - """TaskCancelledError → cancelled result.""" - import cloudpickle - - from taskito.async_support.executor import AsyncTaskExecutor - - sender = MagicMock() - - async def cancelling_task() -> None: - raise TaskCancelledError("cancelled") - - class FakeWrapper: - _taskito_async_fn = staticmethod(cancelling_task) - - registry: dict[str, Any] = {"mod.cancelling_task": FakeWrapper()} - - queue_ref = MagicMock() - queue_ref._interceptor = None - queue_ref._proxy_registry = None - queue_ref._test_mode_active = False - queue_ref._resource_runtime = None - queue_ref._task_inject_map = {} - queue_ref._task_retry_filters = {} - queue_ref._get_middleware_chain.return_value = [] - queue_ref._proxy_metrics = None - - executor = AsyncTaskExecutor(sender, registry, queue_ref, max_concurrency=10) - executor.start() - - payload = cloudpickle.dumps(((), {})) - executor.submit_job("job-3", "mod.cancelling_task", payload, 0, 3, "default") - poll_until(lambda: sender.report_cancelled.called, message="job-3 cancellation not reported") - executor.stop() - - sender.report_cancelled.assert_called_once() - assert sender.report_cancelled.call_args[0][0] == "job-3" - - -def test_async_retry_filter(poll_until: PollUntil) -> None: - """Failed async task respects retry_on filter.""" - import cloudpickle - - from taskito.async_support.executor import AsyncTaskExecutor - - sender = MagicMock() - - async def flaky_task() -> None: - raise TypeError("wrong type") - - class FakeWrapper: - _taskito_async_fn = staticmethod(flaky_task) - - registry: dict[str, Any] = {"mod.flaky_task": FakeWrapper()} - - queue_ref = MagicMock() - queue_ref._interceptor = None - queue_ref._proxy_registry = None - queue_ref._test_mode_active = False - queue_ref._resource_runtime = None - queue_ref._task_inject_map = {} - # Only retry on ValueError, not TypeError - queue_ref._task_retry_filters = { - "mod.flaky_task": {"retry_on": [ValueError], "dont_retry_on": []}, - } - queue_ref._get_middleware_chain.return_value = [] - queue_ref._proxy_metrics = None - - executor = AsyncTaskExecutor(sender, registry, queue_ref, max_concurrency=10) - executor.start() - - payload = cloudpickle.dumps(((), {})) - executor.submit_job("job-4", "mod.flaky_task", payload, 0, 3, "default") - poll_until(lambda: sender.report_failure.called, message="job-4 failure not reported") - executor.stop() - - sender.report_failure.assert_called_once() - assert sender.report_failure.call_args[0][6] is False # should_retry = False - - -def test_async_concurrency_limit(poll_until: PollUntil) -> None: - """Semaphore bounds concurrent async tasks.""" - import cloudpickle - - from taskito.async_support.executor import AsyncTaskExecutor - - sender = MagicMock() - max_concurrent = 0 - current = 0 - lock = threading.Lock() - - async def slow_task() -> None: - nonlocal max_concurrent, current - with lock: - current += 1 - max_concurrent = max(max_concurrent, current) - await asyncio.sleep(0.1) - with lock: - current -= 1 - - class FakeWrapper: - _taskito_async_fn = staticmethod(slow_task) - - registry: dict[str, Any] = {"mod.slow_task": FakeWrapper()} - - queue_ref = MagicMock() - queue_ref._interceptor = None - queue_ref._proxy_registry = None - queue_ref._test_mode_active = False - queue_ref._resource_runtime = None - queue_ref._task_inject_map = {} - queue_ref._task_retry_filters = {} - queue_ref._get_middleware_chain.return_value = [] - queue_ref._proxy_metrics = None - - # Set concurrency to 2 - executor = AsyncTaskExecutor(sender, registry, queue_ref, max_concurrency=2) - executor.start() - - payload = cloudpickle.dumps(((), {})) - for i in range(5): - executor.submit_job(f"job-{i}", "mod.slow_task", payload, 0, 3, "default") - - poll_until( - lambda: sender.report_success.call_count >= 5, - timeout=10, - message="not all 5 slow_task jobs reported success", - ) - executor.stop() - - assert max_concurrent <= 2 - assert sender.report_success.call_count == 5 - - -def test_async_middleware_hooks(poll_until: PollUntil) -> None: - """Middleware before/after called for async tasks.""" - import cloudpickle - - from taskito.async_support.executor import AsyncTaskExecutor - - before_called: list[str] = [] - after_called: list[str] = [] - - class TestMiddleware(TaskMiddleware): - def before(self, job_context: Any) -> None: - before_called.append(job_context.id) - - def after(self, job_context: Any, result: Any, error: Any) -> None: - after_called.append(job_context.id) - - sender = MagicMock() - - async def simple_task() -> int: - return 42 - - class FakeWrapper: - _taskito_async_fn = staticmethod(simple_task) - - registry: dict[str, Any] = {"mod.simple_task": FakeWrapper()} - - queue_ref = MagicMock() - queue_ref._interceptor = None - queue_ref._proxy_registry = None - queue_ref._test_mode_active = False - queue_ref._resource_runtime = None - queue_ref._task_inject_map = {} - queue_ref._task_retry_filters = {} - queue_ref._get_middleware_chain.return_value = [TestMiddleware()] - queue_ref._proxy_metrics = None - - executor = AsyncTaskExecutor(sender, registry, queue_ref, max_concurrency=10) - executor.start() - - payload = cloudpickle.dumps(((), {})) - executor.submit_job("mw-job", "mod.simple_task", payload, 0, 3, "default") - poll_until(lambda: "mw-job" in after_called, message="middleware after hook not called") - executor.stop() - - assert "mw-job" in before_called - assert "mw-job" in after_called - - -def test_async_task_with_injection(poll_until: PollUntil) -> None: - """inject=["db"] works for async tasks via executor.""" - import cloudpickle - - from taskito.async_support.executor import AsyncTaskExecutor - - sender = MagicMock() - - async def db_task(db: Any = None) -> str: - return f"got-{db}" - - class FakeWrapper: - _taskito_async_fn = staticmethod(db_task) - - registry: dict[str, Any] = {"mod.db_task": FakeWrapper()} - - fake_db = "fake-conn" - - queue_ref = MagicMock() - queue_ref._interceptor = None - queue_ref._proxy_registry = None - queue_ref._test_mode_active = False - queue_ref._task_inject_map = {"mod.db_task": ["db"]} - queue_ref._task_retry_filters = {} - queue_ref._get_middleware_chain.return_value = [] - queue_ref._proxy_metrics = None - - # Mock resource runtime - runtime = MagicMock() - runtime.acquire_for_task.return_value = (fake_db, None) - queue_ref._resource_runtime = runtime - - executor = AsyncTaskExecutor(sender, registry, queue_ref, max_concurrency=10) - executor.start() - - payload = cloudpickle.dumps(((), {})) - executor.submit_job("inj-job", "mod.db_task", payload, 0, 3, "default") - poll_until(lambda: sender.report_success.called, message="inj-job result not reported") - executor.stop() - - sender.report_success.assert_called_once() - result = cloudpickle.loads(sender.report_success.call_args[0][2]) - assert result == "got-fake-conn" - - -def test_async_context_available_inside_task(poll_until: PollUntil) -> None: - """current_job.id works inside an async task via contextvars.""" - import cloudpickle - - from taskito.async_support.executor import AsyncTaskExecutor - - sender = MagicMock() - captured_id: list[str] = [] - - async def ctx_task() -> str: - captured_id.append(current_job.id) - return "ok" - - class FakeWrapper: - _taskito_async_fn = staticmethod(ctx_task) - - registry: dict[str, Any] = {"mod.ctx_task": FakeWrapper()} - - queue_ref = MagicMock() - queue_ref._interceptor = None - queue_ref._proxy_registry = None - queue_ref._test_mode_active = False - queue_ref._resource_runtime = None - queue_ref._task_inject_map = {} - queue_ref._task_retry_filters = {} - queue_ref._get_middleware_chain.return_value = [] - queue_ref._proxy_metrics = None - - executor = AsyncTaskExecutor(sender, registry, queue_ref, max_concurrency=10) - executor.start() - - payload = cloudpickle.dumps(((), {})) - executor.submit_job("ctx-job", "mod.ctx_task", payload, 0, 3, "default") - poll_until(lambda: captured_id == ["ctx-job"], message="ctx-job context not captured") - executor.stop() - - assert captured_id == ["ctx-job"] - - -def test_async_concurrency_parameter(tmp_path: Path) -> None: - """Queue accepts async_concurrency parameter.""" - queue = Queue(db_path=str(tmp_path / "test.db"), async_concurrency=50) - assert queue._async_concurrency == 50 - - -def test_async_concurrency_default(tmp_path: Path) -> None: - """Default async_concurrency is 100.""" - queue = Queue(db_path=str(tmp_path / "test.db")) - assert queue._async_concurrency == 100 diff --git a/tests/test_observability.py b/tests/test_observability.py deleted file mode 100644 index 7e0f17a..0000000 --- a/tests/test_observability.py +++ /dev/null @@ -1,309 +0,0 @@ -"""Tests for resource observability (Phase 6 — status API, CLI, health checks).""" - -from __future__ import annotations - -import time -from typing import Any - -import pytest - -from taskito import Queue -from taskito.resources import ResourceDefinition, ResourceRuntime - -# --------------------------------------------------------------------------- -# ResourceRuntime.status() -# --------------------------------------------------------------------------- - - -class TestResourceRuntimeStatus: - def test_status_returns_all_resources(self) -> None: - """status() returns an entry for every initialized resource.""" - defs = { - "config": ResourceDefinition(name="config", factory=lambda: {}), - "db": ResourceDefinition( - name="db", factory=lambda config: "conn", depends_on=["config"] - ), - } - rt = ResourceRuntime(defs) - rt.initialize() - - entries = rt.status() - assert len(entries) == 2 - names = [e["name"] for e in entries] - assert "config" in names - assert "db" in names - rt.teardown() - - def test_status_healthy_resource(self) -> None: - """A successfully initialized resource has health='healthy'.""" - defs = {"svc": ResourceDefinition(name="svc", factory=lambda: "ok")} - rt = ResourceRuntime(defs) - rt.initialize() - - entry = rt.status()[0] - assert entry["health"] == "healthy" - assert entry["name"] == "svc" - assert entry["scope"] == "worker" - assert entry["recreations"] == 0 - assert entry["init_duration_ms"] >= 0 - assert entry["depends_on"] == [] - rt.teardown() - - def test_status_unhealthy_resource(self) -> None: - """An unhealthy resource is reported as such.""" - defs = {"svc": ResourceDefinition(name="svc", factory=lambda: "ok")} - rt = ResourceRuntime(defs) - rt.initialize() - rt._unhealthy.add("svc") - - entry = rt.status()[0] - assert entry["health"] == "unhealthy" - rt.teardown() - - def test_status_tracks_init_duration(self) -> None: - """init_duration_ms is populated after initialize.""" - - def slow_factory() -> str: - time.sleep(0.05) - return "result" - - defs = {"slow": ResourceDefinition(name="slow", factory=slow_factory)} - rt = ResourceRuntime(defs) - rt.initialize() - - entry = rt.status()[0] - assert entry["init_duration_ms"] >= 40 # at least ~50ms minus timing jitter - rt.teardown() - - def test_status_tracks_recreations(self) -> None: - """recreation count is incremented on successful recreate.""" - call_count = 0 - - def make_svc() -> str: - nonlocal call_count - call_count += 1 - return f"v{call_count}" - - defs = { - "svc": ResourceDefinition( - name="svc", - factory=make_svc, - health_check=lambda inst: True, - health_check_interval=1.0, - ), - } - rt = ResourceRuntime(defs) - rt.initialize() - assert rt._recreation_count.get("svc", 0) == 0 - - # Manually trigger recreation - rt.recreate("svc") - assert rt._recreation_count["svc"] == 1 - - entry = rt.status()[0] - assert entry["recreations"] == 1 - rt.teardown() - - def test_status_includes_depends_on(self) -> None: - """depends_on list is included in status output.""" - defs = { - "config": ResourceDefinition(name="config", factory=lambda: {}), - "db": ResourceDefinition( - name="db", factory=lambda config: "conn", depends_on=["config"] - ), - } - rt = ResourceRuntime(defs) - rt.initialize() - - entries = {e["name"]: e for e in rt.status()} - assert entries["config"]["depends_on"] == [] - assert entries["db"]["depends_on"] == ["config"] - rt.teardown() - - def test_from_test_overrides_status(self) -> None: - """Test-override runtime reports healthy status.""" - rt = ResourceRuntime.from_test_overrides({"db": "mock"}) - entries = rt.status() - assert len(entries) == 1 - assert entries[0]["name"] == "db" - assert entries[0]["health"] == "healthy" - - -# --------------------------------------------------------------------------- -# Queue.resource_status() -# --------------------------------------------------------------------------- - - -class TestQueueResourceStatus: - def test_resource_status_with_runtime(self, tmp_path: Any) -> None: - """resource_status() delegates to runtime when initialized.""" - queue = Queue(db_path=str(tmp_path / "q.db")) - - @queue.worker_resource("db") - def create_db() -> str: - return "conn" - - # Manually initialize runtime - rt = ResourceRuntime(queue._resource_definitions) - rt.initialize() - queue._resource_runtime = rt - - status = queue.resource_status() - assert len(status) == 1 - assert status[0]["name"] == "db" - assert status[0]["health"] == "healthy" - rt.teardown() - - def test_resource_status_without_runtime(self, tmp_path: Any) -> None: - """resource_status() returns definitions with not_initialized health.""" - queue = Queue(db_path=str(tmp_path / "q.db")) - - @queue.worker_resource("db") - def create_db() -> str: - return "conn" - - status = queue.resource_status() - assert len(status) == 1 - assert status[0]["name"] == "db" - assert status[0]["health"] == "not_initialized" - - def test_resource_status_empty(self, tmp_path: Any) -> None: - """resource_status() returns empty list with no resources.""" - queue = Queue(db_path=str(tmp_path / "q.db")) - assert queue.resource_status() == [] - - -# --------------------------------------------------------------------------- -# Health check integration -# --------------------------------------------------------------------------- - - -class TestHealthCheckIntegration: - def test_readiness_reports_healthy_resources(self, tmp_path: Any) -> None: - """check_readiness includes resource status when all healthy.""" - from taskito.health import check_readiness - - queue = Queue(db_path=str(tmp_path / "q.db")) - - @queue.worker_resource("db") - def create_db() -> str: - return "conn" - - rt = ResourceRuntime(queue._resource_definitions) - rt.initialize() - queue._resource_runtime = rt - - result = check_readiness(queue) - assert "resources" in result["checks"] - res_check = result["checks"]["resources"] - assert res_check["status"] == "ok" - assert res_check["count"] == 1 - assert res_check["unhealthy"] == [] - rt.teardown() - - def test_readiness_reports_unhealthy_resources(self, tmp_path: Any) -> None: - """check_readiness marks status as degraded for unhealthy resources.""" - from taskito.health import check_readiness - - queue = Queue(db_path=str(tmp_path / "q.db")) - - @queue.worker_resource("db") - def create_db() -> str: - return "conn" - - rt = ResourceRuntime(queue._resource_definitions) - rt.initialize() - rt._unhealthy.add("db") - queue._resource_runtime = rt - - result = check_readiness(queue) - assert result["status"] == "degraded" - res_check = result["checks"]["resources"] - assert res_check["status"] == "degraded" - assert "db" in res_check["unhealthy"] - rt.teardown() - - def test_readiness_no_resources(self, tmp_path: Any) -> None: - """check_readiness works fine without any resources.""" - from taskito.health import check_readiness - - queue = Queue(db_path=str(tmp_path / "q.db")) - result = check_readiness(queue) - # No resources section when empty - assert ( - "resources" not in result["checks"] - or result["checks"].get("resources", {}).get("count", 0) == 0 - ) - - def test_health_check_always_ok(self) -> None: - """check_health is always ok regardless of resources.""" - from taskito.health import check_health - - assert check_health() == {"status": "ok"} - - -# --------------------------------------------------------------------------- -# CLI resources subcommand -# --------------------------------------------------------------------------- - - -class TestCLIResources: - def test_run_resources_no_resources(self, tmp_path: Any) -> None: - """resource_status returns empty list when no resources registered.""" - queue = Queue(db_path=str(tmp_path / "q.db")) - assert queue.resource_status() == [] - - def test_resource_status_table_format( - self, tmp_path: Any, capsys: pytest.CaptureFixture[str] - ) -> None: - """Verify table output format from CLI helper.""" - queue = Queue(db_path=str(tmp_path / "q.db")) - - @queue.worker_resource("config") - def create_config() -> dict[str, str]: - return {} - - @queue.worker_resource("db", depends_on=["config"]) - def create_db(config: Any) -> str: - return "conn" - - rt = ResourceRuntime(queue._resource_definitions) - rt.initialize() - queue._resource_runtime = rt - - # Simulate what run_resources prints - resources = queue.resource_status() - assert len(resources) == 2 - names = {r["name"] for r in resources} - assert names == {"config", "db"} - - # Verify structure matches CLI expectations - for r in resources: - assert "name" in r - assert "scope" in r - assert "health" in r - assert "init_duration_ms" in r - assert "recreations" in r - assert "depends_on" in r - - rt.teardown() - - -# --------------------------------------------------------------------------- -# Prometheus resource metrics -# --------------------------------------------------------------------------- - - -class TestPrometheusResourceMetrics: - def test_prometheus_middleware_has_resource_metrics(self) -> None: - """Verify resource metric singletons are initialized.""" - pytest.importorskip("prometheus_client") - from taskito.contrib.prometheus import _init_metrics # type: ignore[attr-defined] - - _init_metrics() - - from taskito.contrib import prometheus as pmod - - assert pmod._resource_health is not None # type: ignore[attr-defined] - assert pmod._resource_recreations is not None # type: ignore[attr-defined] - assert pmod._resource_init_duration is not None # type: ignore[attr-defined] diff --git a/tests/test_periodic.py b/tests/test_periodic.py deleted file mode 100644 index 3cabf77..0000000 --- a/tests/test_periodic.py +++ /dev/null @@ -1,59 +0,0 @@ -"""Tests for periodic (cron-scheduled) tasks.""" - -import threading -from pathlib import Path -from typing import Any - -import pytest - -from taskito import Queue - - -@pytest.fixture -def queue(tmp_path: Path) -> Queue: - db_path = str(tmp_path / "test_periodic.db") - return Queue(db_path=db_path, workers=1) - - -def test_periodic_task_registration(queue: Queue) -> None: - """Periodic tasks are registered as both regular tasks and periodic configs.""" - - @queue.periodic(cron="0 * * * * *") - def every_minute() -> str: - return "tick" - - assert every_minute.name.endswith("every_minute") - assert every_minute.name in queue._task_registry - assert len(queue._periodic_configs) == 1 - assert queue._periodic_configs[0]["cron_expr"] == "0 * * * * *" - - -def test_periodic_task_direct_call(queue: Queue) -> None: - """Periodic tasks can still be called directly like regular tasks.""" - - @queue.periodic(cron="0 * * * * *") - def add(a: int, b: int) -> int: - return a + b - - assert add(3, 4) == 7 - - -def test_periodic_task_triggers(queue: Queue, poll_until: Any) -> None: - """Periodic task gets enqueued by the scheduler when due.""" - results: list[int] = [] - - @queue.periodic(cron="* * * * * *") # every second - def frequent_task() -> str: - results.append(1) - return "done" - - worker_thread = threading.Thread(target=queue.run_worker, daemon=True) - worker_thread.start() - - poll_until( - lambda: queue.stats()["completed"] >= 1, - timeout=15, - message="periodic task never triggered", - ) - - assert queue.stats()["completed"] >= 1 diff --git a/tests/test_prefork.py b/tests/test_prefork.py deleted file mode 100644 index fa1a366..0000000 --- a/tests/test_prefork.py +++ /dev/null @@ -1,323 +0,0 @@ -"""Tests for the prefork (multi-process) worker pool.""" - -from __future__ import annotations - -import contextlib -import importlib -import os -import sys -import threading -import time -from collections.abc import Iterator -from pathlib import Path -from typing import Any - -import pytest - -from taskito import Queue -from taskito.context import JobContext -from taskito.middleware import TaskMiddleware - - -@pytest.mark.skipif( - sys.platform == "win32", - reason="prefork is rejected on Windows before app= is checked", -) -def test_prefork_requires_app_path(tmp_path: Path) -> None: - """pool='prefork' without app= raises ValueError.""" - queue = Queue(db_path=str(tmp_path / "test.db")) - - @queue.task() - def noop() -> None: - pass - - with pytest.raises(ValueError, match="app= is required"): - queue.run_worker(pool="prefork") - - -def test_prefork_rejected_on_windows(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None: - """pool='prefork' on Windows fails fast with NotImplementedError.""" - monkeypatch.setattr(sys, "platform", "win32") - queue = Queue(db_path=str(tmp_path / "test.db")) - with pytest.raises(NotImplementedError, match="not supported on Windows"): - queue.run_worker(pool="prefork", app="x:y") - - -def test_prefork_basic_execution(tmp_path: Path) -> None: - """A task enqueued and processed by a prefork worker returns the correct result. - - NOTE: Prefork children import the app module independently, so the task name - must resolve to the same module path in both parent and child. Tasks defined - inside test functions can't be imported by children — use module-level tasks. - This test is currently skipped; see test_prefork_module_level_queue for the - working version. - """ - pytest.skip("Tasks defined inside functions can't be imported by prefork children") - - -def test_prefork_thread_pool_unchanged(tmp_path: Path) -> None: - """pool='thread' (default) still works normally.""" - q = Queue(db_path=str(tmp_path / "test.db")) - - @q.task() - def multiply(x: int, y: int) -> int: - return x * y - - job = multiply.delay(3, 7) - - worker = threading.Thread(target=q.run_worker, daemon=True) - worker.start() - - result = job.result(timeout=10) - assert result == 21 - - q._inner.request_shutdown() - - -# --------------------------------------------------------------------------- -# Per-job timeout enforcement (issue #81) -# --------------------------------------------------------------------------- - -# The prefork pool is Unix-oriented: child processes communicate over anonymous -# stdio pipes, which on Windows have different blocking semantics that make -# parent-side reader threads hang after `TerminateProcess`. Per-job timeout -# behaviour itself is identical, but the surrounding pool plumbing isn't -# Windows-ready, so these end-to-end tests are skipped there. -prefork_unix_only = pytest.mark.skipif( - sys.platform == "win32", - reason="prefork pool is Unix-only — child stdio pipe semantics differ on Windows", -) - -PREFORK_APP_PATH = "prefork_apps.timeout_app:queue" -PREFORK_APP_DIR = str(Path(__file__).parent) - - -@pytest.fixture -def timeout_app(tmp_path: Path) -> Iterator[object]: - """Set up the module-level timeout-test app with a per-test DB path. - - The Queue inside ``prefork_apps.timeout_app`` is constructed at import time - from ``$TASKITO_TIMEOUT_TEST_DB``, and the prefork child re-imports the - same module fresh in its own interpreter — so the env var must be set in - the parent process before that import happens, and propagates to the - child via inherited env. - """ - db_path = str(tmp_path / "timeout.db") - prev_db = os.environ.get("TASKITO_TIMEOUT_TEST_DB") - prev_pythonpath = os.environ.get("PYTHONPATH") - - os.environ["TASKITO_TIMEOUT_TEST_DB"] = db_path - # Make `prefork_apps.timeout_app` importable in both parent and (inherited) - # child without depending on pytest's rootdir manipulation. - os.environ["PYTHONPATH"] = ( - f"{PREFORK_APP_DIR}{os.pathsep}{prev_pythonpath}" if prev_pythonpath else PREFORK_APP_DIR - ) - if PREFORK_APP_DIR not in sys.path: - sys.path.insert(0, PREFORK_APP_DIR) - - # Force a fresh module import so the Queue picks up the per-test DB path. - sys.modules.pop("prefork_apps.timeout_app", None) - sys.modules.pop("prefork_apps", None) - module = importlib.import_module("prefork_apps.timeout_app") - - try: - yield module - finally: - with contextlib.suppress(Exception): - module.queue._inner.request_shutdown() - if prev_db is None: - os.environ.pop("TASKITO_TIMEOUT_TEST_DB", None) - else: - os.environ["TASKITO_TIMEOUT_TEST_DB"] = prev_db - if prev_pythonpath is None: - os.environ.pop("PYTHONPATH", None) - else: - os.environ["PYTHONPATH"] = prev_pythonpath - - -def _start_prefork_worker(queue: Queue) -> threading.Thread: - """Start a prefork worker for ``queue`` in a daemon thread.""" - thread = threading.Thread( - target=queue.run_worker, - kwargs={"pool": "prefork", "app": PREFORK_APP_PATH}, - daemon=True, - ) - thread.start() - return thread - - -def _wait_for_terminal(job: Any, timeout: float) -> str: - """Poll a JobResult.refresh() until the status is terminal or `timeout` elapses.""" - deadline = time.monotonic() + timeout - while time.monotonic() < deadline: - job.refresh() - status: str = job.status - if status in {"complete", "failed", "dead", "cancelled"}: - return status - time.sleep(0.1) - job.refresh() - final_status: str = job.status - return final_status - - -@prefork_unix_only -def test_prefork_kills_hung_task(timeout_app: object) -> None: - """A task that hangs past its `timeout=` is SIGKILLed by the watchdog and - reported as a timeout failure within the timeout + watchdog tick budget.""" - timeouts_seen: list[str] = [] - - class TimeoutSpy(TaskMiddleware): - def on_timeout(self, ctx: JobContext) -> None: - timeouts_seen.append(ctx.id) - - queue: Queue = timeout_app.queue # type: ignore[attr-defined] - queue._global_middleware.append(TimeoutSpy()) - - started = time.monotonic() - job = timeout_app.hang.delay() # type: ignore[attr-defined] - _start_prefork_worker(queue) - - # timeout=2s, watchdog tick=250ms → kill within ~2.25s; allow generous - # headroom for child spawn and CI noise. - status = _wait_for_terminal(job, timeout=15) - elapsed = time.monotonic() - started - - assert status == "dead", f"expected 'dead', got {status!r} (error={job.error!r})" - assert "timed out" in (job.error or "").lower() - assert elapsed < 12, f"hung task took {elapsed:.1f}s to be killed (expected < 12s)" - assert job.id in timeouts_seen, "on_timeout middleware did not fire" - - -@prefork_unix_only -def test_prefork_no_timeout_unaffected(timeout_app: object) -> None: - """A task with no timeout (timeout=0) runs to completion — the watchdog - must not kill jobs that have no deadline configured.""" - queue: Queue = timeout_app.queue # type: ignore[attr-defined] - - job = timeout_app.quick.delay(21) # type: ignore[attr-defined] - _start_prefork_worker(queue) - - result = job.result(timeout=15) - assert result == 42 - - -@prefork_unix_only -def test_prefork_finishes_before_deadline(timeout_app: object) -> None: - """A task that completes well before its deadline returns normally — the - watchdog only fires when the deadline is actually crossed.""" - queue: Queue = timeout_app.queue # type: ignore[attr-defined] - - # timeout=2s, sleep 0.5s — should finish cleanly. - job = timeout_app.sleep_then_finish.delay(0.5) # type: ignore[attr-defined] - _start_prefork_worker(queue) - - result = job.result(timeout=15) - assert result == "done" - - -# --------------------------------------------------------------------------- -# Cooperative cancellation propagation (issue #82) -# --------------------------------------------------------------------------- - -CANCEL_APP_PATH = "prefork_apps.cancel_app:queue" - - -@pytest.fixture -def cancel_app(tmp_path: Path) -> Iterator[object]: - """Set up the module-level cancel-test app with a per-test DB path. - - Mirrors ``timeout_app`` — must set the env var before the import so the - parent's Queue construction and the child's re-import see the same DB. - """ - db_path = str(tmp_path / "cancel.db") - prev_db = os.environ.get("TASKITO_CANCEL_TEST_DB") - prev_pythonpath = os.environ.get("PYTHONPATH") - - os.environ["TASKITO_CANCEL_TEST_DB"] = db_path - os.environ["PYTHONPATH"] = ( - f"{PREFORK_APP_DIR}{os.pathsep}{prev_pythonpath}" if prev_pythonpath else PREFORK_APP_DIR - ) - if PREFORK_APP_DIR not in sys.path: - sys.path.insert(0, PREFORK_APP_DIR) - - sys.modules.pop("prefork_apps.cancel_app", None) - sys.modules.pop("prefork_apps", None) - module = importlib.import_module("prefork_apps.cancel_app") - - try: - yield module - finally: - with contextlib.suppress(Exception): - module.queue._inner.request_shutdown() - if prev_db is None: - os.environ.pop("TASKITO_CANCEL_TEST_DB", None) - else: - os.environ["TASKITO_CANCEL_TEST_DB"] = prev_db - if prev_pythonpath is None: - os.environ.pop("PYTHONPATH", None) - else: - os.environ["PYTHONPATH"] = prev_pythonpath - - -def _start_cancel_worker(queue: Queue) -> threading.Thread: - thread = threading.Thread( - target=queue.run_worker, - kwargs={"pool": "prefork", "app": CANCEL_APP_PATH}, - daemon=True, - ) - thread.start() - return thread - - -@prefork_unix_only -def test_prefork_cancel_running_job_stops_quickly(cancel_app: object, poll_until: Any) -> None: - """``cancel_running_job`` propagates to the prefork child and stops a - cooperative task within a small budget — the regression test for #82.""" - queue: Queue = cancel_app.queue # type: ignore[attr-defined] - - cancels_seen: list[str] = [] - - class CancelSpy(TaskMiddleware): - def on_cancel(self, ctx: JobContext) -> None: - cancels_seen.append(ctx.id) - - queue._global_middleware.append(CancelSpy()) - - job = cancel_app.cooperative_loop.delay(600) # type: ignore[attr-defined] - _start_cancel_worker(queue) - - # Wait until the job is actually running on a child before cancelling. - def _running() -> bool: - job.refresh() - return bool(job.status == "running") - - poll_until(_running, timeout=10, message="job never reached running state") - - assert queue.cancel_running_job(job.id) is True - - status = _wait_for_terminal(job, timeout=10) - assert status == "cancelled", f"expected 'cancelled', got {status!r} (error={job.error!r})" - assert job.id in cancels_seen, "on_cancel middleware did not fire" - - -@prefork_unix_only -def test_prefork_cancel_does_not_kill_child(cancel_app: object, poll_until: Any) -> None: - """A cancel must stop the running task without killing the child — the - next job dispatched to the same pool should still complete normally.""" - queue: Queue = cancel_app.queue # type: ignore[attr-defined] - - long_job = cancel_app.cooperative_loop.delay(600) # type: ignore[attr-defined] - _start_cancel_worker(queue) - - def _running() -> bool: - long_job.refresh() - return bool(long_job.status == "running") - - poll_until(_running, timeout=10, message="long_job never reached running state") - assert queue.cancel_running_job(long_job.id) is True - status = _wait_for_terminal(long_job, timeout=10) - assert status == "cancelled" - - follow_up = cancel_app.quick.delay(21) # type: ignore[attr-defined] - result = follow_up.result(timeout=15) - assert result == 42 diff --git a/tests/test_priority.py b/tests/test_priority.py deleted file mode 100644 index 13c8034..0000000 --- a/tests/test_priority.py +++ /dev/null @@ -1,50 +0,0 @@ -"""Tests for priority scheduling.""" - -import threading -from pathlib import Path -from typing import Any - -import pytest - -from taskito import Queue - - -@pytest.fixture -def queue(tmp_path: Path) -> Queue: - db_path = str(tmp_path / "test_priority.db") - return Queue(db_path=db_path, workers=1) # 1 worker for ordering - - -def test_priority_ordering(queue: Queue, poll_until: Any) -> None: - """Higher priority jobs should be processed first.""" - results: list[str] = [] - - @queue.task() - def record_task(label: str) -> str: - results.append(label) - return label - - # Enqueue low priority first, then high priority. All three jobs are - # synchronously visible after `apply_async` returns, so no pre-worker - # delay is needed. - record_task.apply_async(args=("low",), priority=1) - record_task.apply_async(args=("medium",), priority=5) - record_task.apply_async(args=("high",), priority=10) - - worker_thread = threading.Thread( - target=queue.run_worker, - daemon=True, - ) - worker_thread.start() - - poll_until( - lambda: len(results) >= 3, - timeout=10, - message="not all priority jobs completed", - ) - - # High priority should have been processed first - assert len(results) == 3 - assert results[0] == "high" - assert results[1] == "medium" - assert results[2] == "low" diff --git a/tests/test_progress.py b/tests/test_progress.py deleted file mode 100644 index 58cd7c6..0000000 --- a/tests/test_progress.py +++ /dev/null @@ -1,44 +0,0 @@ -"""Tests for progress tracking.""" - -from __future__ import annotations - -import time - -from taskito import Queue - - -def test_update_progress(queue: Queue) -> None: - """Progress can be updated and read back.""" - - @queue.task() - def slow_task() -> str: - - time.sleep(0.5) - return "done" - - job = slow_task.delay() - - # Update progress directly via the queue API - queue.update_progress(job.id, 50) - - refreshed = queue.get_job(job.id) - assert refreshed is not None - assert refreshed.progress == 50 - - queue.update_progress(job.id, 100) - refreshed = queue.get_job(job.id) - assert refreshed is not None - assert refreshed.progress == 100 - - -def test_progress_starts_none(queue: Queue) -> None: - """Progress is None by default.""" - - @queue.task() - def task_a() -> int: - return 1 - - job = task_a.delay() - refreshed = queue.get_job(job.id) - assert refreshed is not None - assert refreshed.progress is None diff --git a/tests/test_proxies.py b/tests/test_proxies.py deleted file mode 100644 index b6c960f..0000000 --- a/tests/test_proxies.py +++ /dev/null @@ -1,429 +0,0 @@ -"""Tests for the proxy system (Layer 3 — transparent reconstruction).""" - -from __future__ import annotations - -import logging -from typing import Any - -import pytest - -from taskito import ProxyReconstructionError, Queue -from taskito.proxies import ProxyRegistry, cleanup_proxies, reconstruct_proxies -from taskito.proxies.built_in import register_builtin_handlers -from taskito.proxies.handlers.file import FileHandler -from taskito.proxies.handlers.logger import LoggerHandler - -# --------------------------------------------------------------------------- -# FileHandler -# --------------------------------------------------------------------------- - - -class TestFileHandler: - def test_detect_open_file(self, tmp_path: Any) -> None: - f = open(tmp_path / "test.txt", "w") # noqa: SIM115 - try: - handler = FileHandler() - assert handler.detect(f) is True - finally: - f.close() - - def test_detect_closed_file(self, tmp_path: Any) -> None: - f = open(tmp_path / "test.txt", "w") # noqa: SIM115 - f.close() - handler = FileHandler() - assert handler.detect(f) is False - - def test_detect_stdin(self) -> None: - import sys - - handler = FileHandler() - assert handler.detect(sys.stdin) is False - - def test_deconstruct_text_file(self, tmp_path: Any) -> None: - path = tmp_path / "data.txt" - path.write_text("hello world") - with open(path) as f: - handler = FileHandler() - recipe = handler.deconstruct(f) - assert recipe["path"] == str(path) - assert recipe["mode"] == "r" - assert recipe["encoding"] is not None - assert recipe["position"] == 0 - - def test_deconstruct_binary_file(self, tmp_path: Any) -> None: - path = tmp_path / "data.bin" - path.write_bytes(b"\x00\x01\x02") - with open(path, "rb") as f: - handler = FileHandler() - recipe = handler.deconstruct(f) - assert recipe["mode"] == "rb" - assert recipe["encoding"] is None - - def test_reconstruct_text_file(self, tmp_path: Any) -> None: - path = tmp_path / "data.txt" - path.write_text("hello world") - handler = FileHandler() - recipe = {"path": str(path), "mode": "r", "encoding": "utf-8", "position": 0} - f = handler.reconstruct(recipe, version=1) - try: - assert f.read() == "hello world" - finally: - f.close() - - def test_reconstruct_at_position(self, tmp_path: Any) -> None: - path = tmp_path / "data.txt" - path.write_text("hello world") - handler = FileHandler() - recipe = {"path": str(path), "mode": "r", "encoding": "utf-8", "position": 6} - f = handler.reconstruct(recipe, version=1) - try: - assert f.read() == "world" - finally: - f.close() - - def test_cleanup_closes_file(self, tmp_path: Any) -> None: - path = tmp_path / "data.txt" - path.write_text("test") - f = open(path) # noqa: SIM115 - handler = FileHandler() - handler.cleanup(f) - assert f.closed - - def test_cleanup_already_closed(self, tmp_path: Any) -> None: - path = tmp_path / "data.txt" - path.write_text("test") - f = open(path) # noqa: SIM115 - f.close() - handler = FileHandler() - handler.cleanup(f) # no error - - -# --------------------------------------------------------------------------- -# LoggerHandler -# --------------------------------------------------------------------------- - - -class TestLoggerHandler: - def test_detect_logger(self) -> None: - handler = LoggerHandler() - lgr = logging.getLogger("test.proxies.detect") - assert handler.detect(lgr) is True - - def test_detect_non_logger(self) -> None: - handler = LoggerHandler() - assert handler.detect("not a logger") is False - - def test_round_trip(self) -> None: - handler = LoggerHandler() - lgr = logging.getLogger("test.proxies.roundtrip") - lgr.setLevel(logging.WARNING) - recipe = handler.deconstruct(lgr) - reconstructed = handler.reconstruct(recipe, version=1) - assert reconstructed.name == "test.proxies.roundtrip" - assert reconstructed.level == logging.WARNING - - def test_cleanup_noop(self) -> None: - handler = LoggerHandler() - lgr = logging.getLogger("test.proxies.cleanup") - handler.cleanup(lgr) # no error - - -# --------------------------------------------------------------------------- -# ProxyRegistry -# --------------------------------------------------------------------------- - - -class TestProxyRegistry: - def test_register_and_get(self) -> None: - reg = ProxyRegistry() - handler = FileHandler() - reg.register(handler) - assert reg.get("file") is handler - - def test_get_missing(self) -> None: - reg = ProxyRegistry() - assert reg.get("nonexistent") is None - - def test_find_handler(self, tmp_path: Any) -> None: - reg = ProxyRegistry() - register_builtin_handlers(reg) - f = open(tmp_path / "test.txt", "w") # noqa: SIM115 - try: - found = reg.find_handler(f) - assert found is not None - assert found.name == "file" - finally: - f.close() - - def test_find_handler_no_match(self) -> None: - reg = ProxyRegistry() - register_builtin_handlers(reg) - assert reg.find_handler(42) is None - - -# --------------------------------------------------------------------------- -# Reconstruction -# --------------------------------------------------------------------------- - - -class TestReconstruct: - def test_proxy_marker_reconstructed(self, tmp_path: Any) -> None: - path = tmp_path / "data.txt" - path.write_text("content") - reg = ProxyRegistry() - reg.register(FileHandler()) - - marker = { - "__taskito_proxy__": True, - "handler": "file", - "version": 1, - "identity": "id-1", - "recipe": { - "path": str(path), - "mode": "r", - "encoding": "utf-8", - "position": 0, - }, - } - args, _kwargs, cleanup_list = reconstruct_proxies((marker,), {}, reg) - try: - assert hasattr(args[0], "read") - assert args[0].read() == "content" - assert len(cleanup_list) == 1 - finally: - cleanup_proxies(cleanup_list) - - def test_cleanup_list_populated(self, tmp_path: Any) -> None: - path = tmp_path / "data.txt" - path.write_text("test") - reg = ProxyRegistry() - reg.register(FileHandler()) - - marker = { - "__taskito_proxy__": True, - "handler": "file", - "version": 1, - "identity": "id-2", - "recipe": { - "path": str(path), - "mode": "r", - "encoding": "utf-8", - "position": 0, - }, - } - _, _, cleanup_list = reconstruct_proxies((marker,), {}, reg) - assert len(cleanup_list) == 1 - handler, obj = cleanup_list[0] - assert handler.name == "file" - assert not obj.closed - cleanup_proxies(cleanup_list) - assert obj.closed - - def test_cleanup_runs_lifo(self, tmp_path: Any) -> None: - """Cleanup runs in reverse reconstruction order.""" - p1 = tmp_path / "a.txt" - p1.write_text("a") - p2 = tmp_path / "b.txt" - p2.write_text("b") - reg = ProxyRegistry() - reg.register(FileHandler()) - - markers = [ - { - "__taskito_proxy__": True, - "handler": "file", - "version": 1, - "identity": f"id-{i}", - "recipe": { - "path": str(p), - "mode": "r", - "encoding": "utf-8", - "position": 0, - }, - } - for i, p in enumerate([p1, p2]) - ] - args, _, cleanup_list = reconstruct_proxies(tuple(markers), {}, reg) - assert not args[0].closed - assert not args[1].closed - cleanup_proxies(cleanup_list) - assert args[0].closed - assert args[1].closed - - def test_cleanup_catches_errors(self) -> None: - """Cleanup errors are logged, not raised.""" - - class BadHandler: - name = "bad" - version = 1 - handled_types: tuple[type, ...] = () - - def detect(self, obj: Any) -> bool: - return False - - def deconstruct(self, obj: Any) -> dict[str, Any]: - return {} - - def reconstruct(self, recipe: dict[str, Any], version: int) -> Any: - return "obj" - - def cleanup(self, obj: Any) -> None: - raise RuntimeError("cleanup boom") - - cleanup_list: list[tuple[Any, Any]] = [(BadHandler(), "obj")] - cleanup_proxies(cleanup_list) # should not raise - - def test_missing_handler_raises(self) -> None: - reg = ProxyRegistry() - marker = { - "__taskito_proxy__": True, - "handler": "nonexistent", - "version": 1, - "recipe": {}, - } - with pytest.raises(ProxyReconstructionError, match="No proxy handler"): - reconstruct_proxies((marker,), {}, reg) - - def test_no_markers_passthrough(self) -> None: - """Args without markers pass through unchanged.""" - reg = ProxyRegistry() - register_builtin_handlers(reg) - args, kwargs, cleanup = reconstruct_proxies((1, "hello", [3, 4]), {"key": "val"}, reg) - assert args == (1, "hello", [3, 4]) - assert kwargs == {"key": "val"} - assert cleanup == [] - - -# --------------------------------------------------------------------------- -# Identity tracking -# --------------------------------------------------------------------------- - - -class TestIdentityTracking: - def test_same_object_deduped(self, tmp_path: Any) -> None: - """Same file handle passed twice produces one marker and one ref.""" - path = tmp_path / "data.txt" - path.write_text("test") - - queue = Queue(db_path=str(tmp_path / "q.db"), interception="strict") - f = open(path) # noqa: SIM115 - try: - assert queue._interceptor is not None - walker = queue._interceptor._walker - args, _kw, _res = walker.walk((f, f), {}) - - # First should be a full proxy marker - assert args[0].get("__taskito_proxy__") is True - # Second should be a reference to the first - assert "__taskito_ref__" in args[1] - assert args[1]["__taskito_ref__"] == args[0]["identity"] - finally: - f.close() - - def test_identity_reconstructed_once(self, tmp_path: Any) -> None: - """Reconstruction creates one object; both positions share it.""" - path = tmp_path / "data.txt" - path.write_text("shared") - reg = ProxyRegistry() - reg.register(FileHandler()) - - identity = "shared-id" - marker = { - "__taskito_proxy__": True, - "handler": "file", - "version": 1, - "identity": identity, - "recipe": { - "path": str(path), - "mode": "r", - "encoding": "utf-8", - "position": 0, - }, - } - ref = {"__taskito_ref__": identity} - - args, _, cleanup = reconstruct_proxies((marker, ref), {}, reg) - try: - assert args[0] is args[1] # same object - assert len(cleanup) == 1 # only one cleanup entry - finally: - cleanup_proxies(cleanup) - - def test_different_objects_separate(self, tmp_path: Any) -> None: - """Two different file handles get separate recipes.""" - p1 = tmp_path / "a.txt" - p1.write_text("a") - p2 = tmp_path / "b.txt" - p2.write_text("b") - - queue = Queue(db_path=str(tmp_path / "q.db"), interception="strict") - f1 = open(p1) # noqa: SIM115 - f2 = open(p2) # noqa: SIM115 - try: - assert queue._interceptor is not None - walker = queue._interceptor._walker - args, _, _ = walker.walk((f1, f2), {}) - assert args[0].get("__taskito_proxy__") is True - assert args[1].get("__taskito_proxy__") is True - assert args[0]["identity"] != args[1]["identity"] - finally: - f1.close() - f2.close() - - -# --------------------------------------------------------------------------- -# Proxy in nested structures -# --------------------------------------------------------------------------- - - -def test_proxy_in_nested_dict(tmp_path: Any) -> None: - """File inside a dict is proxied.""" - path = tmp_path / "nested.txt" - path.write_text("nested content") - - queue = Queue(db_path=str(tmp_path / "q.db"), interception="strict") - f = open(path) # noqa: SIM115 - try: - assert queue._interceptor is not None - walker = queue._interceptor._walker - _, kwargs, _ = walker.walk((), {"config": {"file": f}}) - inner = kwargs["config"]["file"] - assert inner.get("__taskito_proxy__") is True - assert inner["handler"] == "file" - finally: - f.close() - - -# --------------------------------------------------------------------------- -# End-to-end with Queue test mode -# --------------------------------------------------------------------------- - - -def test_proxy_roundtrip_in_test_mode(queue: Queue) -> None: - """In test mode, original objects pass through (no serialization).""" - captured: list[Any] = [] - - @queue.task() - def use_logger(lgr: Any) -> None: - captured.append(lgr) - - lgr = logging.getLogger("test.proxy.e2e") - with queue.test_mode() as results: - use_logger.delay(lgr) - - assert len(results) == 1 - assert results[0].succeeded - assert captured[0] is lgr - - -def test_logger_proxy_marker_production(tmp_path: Any) -> None: - """Logger produces a proxy marker when interception is on.""" - queue = Queue(db_path=str(tmp_path / "q.db"), interception="strict") - lgr = logging.getLogger("test.proxy.marker") - - assert queue._interceptor is not None - walker = queue._interceptor._walker - args, _, _ = walker.walk((lgr,), {}) - assert args[0].get("__taskito_proxy__") is True - assert args[0]["handler"] == "logger" - assert args[0]["recipe"]["name"] == "test.proxy.marker" diff --git a/tests/test_rate_limit.py b/tests/test_rate_limit.py deleted file mode 100644 index 6a03dd3..0000000 --- a/tests/test_rate_limit.py +++ /dev/null @@ -1,37 +0,0 @@ -"""Tests for rate limiting.""" - -import threading -import time - -from taskito import Queue - - -def test_rate_limit_throttles(queue: Queue) -> None: - """Rate-limited tasks should be throttled.""" - timestamps: list[float] = [] - - @queue.task(rate_limit="2/s") - def rate_limited_task(n: int) -> int: - timestamps.append(time.monotonic()) - return n - - # Enqueue 4 tasks (at 2/s, should take ~2s) - for i in range(4): - rate_limited_task.delay(i) - - worker_thread = threading.Thread( - target=queue.run_worker, - daemon=True, - ) - worker_thread.start() - - # Wait for all tasks - time.sleep(5) - - # Should have all 4 results - assert len(timestamps) == 4 - - # The time span should be >= 1s (since 4 tasks at 2/s = 2s minimum) - if len(timestamps) >= 2: - span = timestamps[-1] - timestamps[0] - assert span >= 0.5 # Allow some slack diff --git a/tests/test_resource_system_full.py b/tests/test_resource_system_full.py deleted file mode 100644 index 084339d..0000000 --- a/tests/test_resource_system_full.py +++ /dev/null @@ -1,675 +0,0 @@ -"""Tests for the full resource system — all phases A through M.""" - -import collections -import contextvars -import tempfile -import threading -from typing import NamedTuple -from unittest.mock import MagicMock - -import pytest - -from taskito import Inject, MockResource, Queue -from taskito.exceptions import ( - ProxyReconstructionError, - ResourceError, - ResourceUnavailableError, -) -from taskito.inject import _InjectAlias -from taskito.interception.built_in import build_default_registry -from taskito.interception.converters import ( - convert_named_tuple, - convert_ordered_dict, - reconstruct_converted, - reconstruct_named_tuple, - reconstruct_ordered_dict, -) -from taskito.interception.metrics import InterceptionMetrics -from taskito.interception.walker import ArgumentWalker -from taskito.proxies.metrics import ProxyMetrics -from taskito.proxies.no_proxy import NoProxy as NoProxyClass -from taskito.proxies.schema import FieldSpec, validate_recipe -from taskito.proxies.signing import sign_recipe, verify_recipe -from taskito.resources.definition import ResourceDefinition, ResourceScope -from taskito.resources.frozen import FrozenResource -from taskito.resources.pool import PoolConfig, ResourcePool -from taskito.resources.runtime import ResourceRuntime -from taskito.resources.thread_local import ThreadLocalStore - -# ─── Phase A — NamedTuple / OrderedDict / lambda / tempfile ─── - - -class Point(NamedTuple): - x: int - y: int - - -class TestNamedTupleConvert: - def test_round_trip(self) -> None: - p = Point(3, 7) - data = convert_named_tuple(p) - assert data["__taskito_convert__"] is True - assert data["type_key"] == "named_tuple" - result = reconstruct_named_tuple(data) - assert result == p - assert isinstance(result, Point) - - def test_via_reconstruct_dispatch(self) -> None: - data = convert_named_tuple(Point(1, 2)) - result = reconstruct_converted(data) - assert result == Point(1, 2) - - -class TestOrderedDictConvert: - def test_round_trip(self) -> None: - od = collections.OrderedDict([("b", 2), ("a", 1)]) - data = convert_ordered_dict(od) - assert data["__taskito_convert__"] is True - assert data["type_key"] == "ordered_dict" - result = reconstruct_ordered_dict(data) - assert result == od - assert isinstance(result, collections.OrderedDict) - assert list(result.keys()) == ["b", "a"] # order preserved - - -class TestWalkerPhaseA: - def _make_walker(self) -> ArgumentWalker: - reg = build_default_registry() - return ArgumentWalker(reg, max_depth=10) - - def test_named_tuple_auto_detected(self) -> None: - walker = self._make_walker() - args = (Point(10, 20),) - new_args, _, _result = walker.walk(args, {}) - assert new_args[0]["__taskito_convert__"] is True - assert new_args[0]["type_key"] == "named_tuple" - - def test_lambda_rejected(self) -> None: - walker = self._make_walker() - fn = lambda x: x + 1 # noqa: E731 - args = (fn,) - _, _, result = walker.walk(args, {}) - assert len(result.failures) == 1 - assert "lambda" in result.failures[0].type_name.lower() - - def test_tempfile_rejected(self) -> None: - walker = self._make_walker() - with tempfile.NamedTemporaryFile() as f: - _, _, result = walker.walk((f,), {}) - assert len(result.failures) == 1 - reason = result.failures[0].reason.lower() - assert "tempfile" in reason or "temporary" in reason - - def test_ordered_dict_converted(self) -> None: - walker = self._make_walker() - od = collections.OrderedDict([("x", 1)]) - new_args, _, _ = walker.walk((od,), {}) - assert new_args[0]["__taskito_convert__"] is True - assert new_args[0]["type_key"] == "ordered_dict" - - def test_contextvars_context_rejected(self) -> None: - reg = build_default_registry() - walker = ArgumentWalker(reg, max_depth=10) - ctx = contextvars.copy_context() - _, _, result = walker.walk((ctx,), {}) - assert len(result.failures) == 1 - assert "context" in result.failures[0].reason.lower() - - -class TestRegisterType: - def test_register_redirect(self, tmp_path: object) -> None: - q = Queue( - db_path=":memory:", - interception="strict", - ) - - class MyDB: - pass - - q.register_type(MyDB, "redirect", resource="db") - # Verify it's registered - assert q._interceptor is not None - entry = q._interceptor._registry.resolve(MyDB()) - assert entry is not None - - def test_requires_interception(self) -> None: - q = Queue(db_path=":memory:", interception="off") - with pytest.raises(RuntimeError, match="Interception is disabled"): - q.register_type(int, "pass") - - -# ─── Phase B — Proxy signing / schema / NoProxy ─── - - -class TestRecipeSigning: - def test_valid_signature(self) -> None: - recipe = {"path": "/tmp/f", "mode": "r"} - sig = sign_recipe("file", 1, recipe, "secret") - assert verify_recipe("file", 1, recipe, sig, "secret") - - def test_invalid_signature_raises(self) -> None: - recipe = {"path": "/tmp/f", "mode": "r"} - with pytest.raises(ProxyReconstructionError, match="checksum mismatch"): - verify_recipe("file", 1, recipe, "bad-checksum", "secret") - - def test_deterministic(self) -> None: - recipe = {"b": 2, "a": 1} - s1 = sign_recipe("h", 1, recipe, "key") - s2 = sign_recipe("h", 1, recipe, "key") - assert s1 == s2 - - -class TestSchemaValidation: - def test_valid_recipe(self) -> None: - schema = {"path": FieldSpec(str), "mode": FieldSpec(str)} - validate_recipe("test", {"path": "/f", "mode": "r"}, schema) - - def test_extra_key_raises(self) -> None: - schema = {"path": FieldSpec(str)} - with pytest.raises(ProxyReconstructionError, match="unexpected keys"): - validate_recipe("test", {"path": "/f", "extra": 1}, schema) - - def test_missing_required_raises(self) -> None: - schema = {"path": FieldSpec(str), "mode": FieldSpec(str)} - with pytest.raises(ProxyReconstructionError, match="missing required"): - validate_recipe("test", {"path": "/f"}, schema) - - def test_optional_field_ok(self) -> None: - schema = {"path": FieldSpec(str), "enc": FieldSpec(str, required=False)} - validate_recipe("test", {"path": "/f"}, schema) - - def test_wrong_type_raises(self) -> None: - schema = {"count": FieldSpec(int)} - with pytest.raises(ProxyReconstructionError, match="expected"): - validate_recipe("test", {"count": "not-an-int"}, schema) - - -class TestNoProxy: - def test_unwrap_in_walker(self) -> None: - reg = build_default_registry() - walker = ArgumentWalker(reg, max_depth=10) - sentinel = object() - wrapped = NoProxyClass(sentinel) - new_args, _, _ = walker.walk((wrapped,), {}) - assert new_args[0] is sentinel - - -# ─── Phase D — Resource scopes ─── - - -class TestResourceScopes: - def test_all_scopes_exist(self) -> None: - assert ResourceScope.WORKER.value == "worker" - assert ResourceScope.TASK.value == "task" - assert ResourceScope.THREAD.value == "thread" - assert ResourceScope.REQUEST.value == "request" - - def test_pool_config(self) -> None: - cfg = PoolConfig(pool_size=5, pool_min=2) - assert cfg.pool_size == 5 - assert cfg.pool_min == 2 - - def test_resource_pool_acquire_release(self) -> None: - created: list[dict[str, int]] = [] - - def factory() -> dict[str, int]: - d: dict[str, int] = {"id": len(created)} - created.append(d) - return d - - pool = ResourcePool( - "test", - factory, - teardown=None, - config=PoolConfig(pool_size=2, acquire_timeout=1.0), - ) - inst = pool.acquire() - assert inst["id"] == 0 - pool.release(inst) - # Re-acquire should return same instance - inst2 = pool.acquire() - assert inst2 is inst - pool.shutdown() - - def test_pool_exhaustion_raises(self) -> None: - pool = ResourcePool( - "test", - lambda: {}, - teardown=None, - config=PoolConfig(pool_size=1, acquire_timeout=0.1), - ) - pool.acquire() - with pytest.raises(ResourceUnavailableError, match="timed out"): - pool.acquire() - pool.shutdown() - - def test_pool_stats(self) -> None: - pool = ResourcePool( - "test", - lambda: {}, - teardown=None, - config=PoolConfig(pool_size=3), - ) - inst = pool.acquire() - s = pool.stats() - assert s["active"] == 1 - assert s["size"] == 3 - pool.release(inst) - s2 = pool.stats() - assert s2["active"] == 0 - assert s2["idle"] == 1 - pool.shutdown() - - def test_pool_factory_failure_does_not_underflow_active_count(self) -> None: - """Active count must remain at zero (not negative) when the factory raises. - - Regression: previously the increment ran before the factory call and - was decremented in the except branch. Any future re-ordering or - intervening release() risked underflowing the counter, surfacing as a - negative `active` in `stats()`. - """ - - def boom() -> None: - raise RuntimeError("factory blew up") - - pool = ResourcePool( - "test", - boom, - teardown=None, - config=PoolConfig(pool_size=2, acquire_timeout=0.1), - ) - - for _ in range(3): - with pytest.raises(RuntimeError, match="factory blew up"): - pool.acquire() - - s = pool.stats() - assert s["active"] == 0 - assert s["total_acquisitions"] == 0 # Failed attempts don't count - # Pool capacity must not leak: a fresh acquire after failures still works. - # (Use a successful factory now to confirm the semaphore wasn't consumed.) - ok_pool = ResourcePool( - "test2", - lambda: {"ok": True}, - teardown=None, - config=PoolConfig(pool_size=1, acquire_timeout=0.1), - ) - for _ in range(3): - inst = ok_pool.acquire() - ok_pool.release(inst) - assert ok_pool.stats()["active"] == 0 - ok_pool.shutdown() - pool.shutdown() - - def test_thread_local_store(self) -> None: - counter = {"n": 0} - - def factory() -> int: - counter["n"] += 1 - return counter["n"] - - store = ThreadLocalStore("test", factory, teardown=None) - v1 = store.get_or_create() - v2 = store.get_or_create() - assert v1 == v2 # Same thread → same instance - - results: list[int] = [] - - def worker() -> None: - results.append(store.get_or_create()) - - t = threading.Thread(target=worker) - t.start() - t.join() - assert results[0] != v1 # Different thread → different instance - store.teardown_all() - - -class TestFrozenResource: - def test_read_allowed(self) -> None: - class Config: - db_url = "postgres://localhost" - - frozen = FrozenResource(Config(), "config") - assert frozen.db_url == "postgres://localhost" - - def test_write_raises(self) -> None: - class Config: - db_url = "x" - - frozen = FrozenResource(Config(), "config") - with pytest.raises(ResourceError, match="read-only"): - frozen.db_url = "y" - - def test_delete_raises(self) -> None: - class Config: - db_url = "x" - - frozen = FrozenResource(Config(), "config") - with pytest.raises(ResourceError, match="read-only"): - del frozen.db_url - - -class TestRuntimeScopeAware: - def test_worker_scope_resolve(self) -> None: - defn = ResourceDefinition( - name="config", - factory=lambda: {"key": "value"}, - scope=ResourceScope.WORKER, - ) - rt = ResourceRuntime({"config": defn}) - rt.initialize() - result = rt.resolve("config") - assert result == {"key": "value"} - rt.teardown() - - def test_acquire_for_task_worker(self) -> None: - defn = ResourceDefinition( - name="config", - factory=lambda: 42, - scope=ResourceScope.WORKER, - ) - rt = ResourceRuntime({"config": defn}) - rt.initialize() - inst, release = rt.acquire_for_task("config") - assert inst == 42 - assert release is None - rt.teardown() - - def test_acquire_for_task_request_scope(self) -> None: - counter = {"n": 0} - - def factory() -> int: - counter["n"] += 1 - return counter["n"] - - defn = ResourceDefinition( - name="req", - factory=factory, - scope=ResourceScope.REQUEST, - ) - rt = ResourceRuntime({"req": defn}) - rt.initialize() - inst1, release1 = rt.acquire_for_task("req") - inst2, release2 = rt.acquire_for_task("req") - assert inst1 != inst2 # Fresh each time - assert release1 is not None - assert release2 is not None - release1() - release2() - rt.teardown() - - def test_frozen_resource_in_runtime(self) -> None: - defn = ResourceDefinition( - name="cfg", - factory=lambda: MagicMock(value=10), - scope=ResourceScope.WORKER, - frozen=True, - ) - rt = ResourceRuntime({"cfg": defn}) - rt.initialize() - inst = rt.resolve("cfg") - assert inst.value == 10 - with pytest.raises(ResourceError, match="read-only"): - inst.new_attr = "x" - rt.teardown() - - def test_reload_reloadable(self) -> None: - counter = {"n": 0} - - def factory() -> int: - counter["n"] += 1 - return counter["n"] - - defn = ResourceDefinition( - name="svc", - factory=factory, - scope=ResourceScope.WORKER, - reloadable=True, - ) - rt = ResourceRuntime({"svc": defn}) - rt.initialize() - assert rt.resolve("svc") == 1 - results = rt.reload() - assert results["svc"] is True - assert rt.resolve("svc") == 2 - rt.teardown() - - def test_reload_skips_non_reloadable(self) -> None: - defn = ResourceDefinition( - name="fixed", - factory=lambda: 1, - scope=ResourceScope.WORKER, - reloadable=False, - ) - rt = ResourceRuntime({"fixed": defn}) - rt.initialize() - results = rt.reload() - assert "fixed" not in results - rt.teardown() - - def test_status_includes_pool(self) -> None: - defn = ResourceDefinition( - name="db", - factory=lambda: {}, - scope=ResourceScope.TASK, - pool_size=5, - ) - rt = ResourceRuntime({"db": defn}) - rt.initialize() - status = rt.status() - assert len(status) == 1 - assert status[0]["name"] == "db" - assert "pool" in status[0] - assert status[0]["pool"]["size"] == 5 - rt.teardown() - - -# ─── Phase E — Inject annotation ─── - - -class TestInjectAnnotation: - def test_inject_alias_created(self) -> None: - alias = Inject["db"] # type: ignore[type-arg,name-defined] - assert isinstance(alias, _InjectAlias) - assert alias.resource_name == "db" - - def test_inject_alias_equality(self) -> None: - assert Inject["db"] == Inject["db"] - assert Inject["db"] != Inject["redis"] - - def test_inject_annotation_detected_in_task(self) -> None: - q = Queue(db_path=":memory:") - - @q.task() - def my_task(x: int, db: Inject["db"]) -> None: # type: ignore[type-arg,name-defined] # noqa: F821 - pass - - assert "db" in q._task_inject_map.get(my_task.name, []) - - def test_inject_annotation_merged_with_explicit(self) -> None: - q = Queue(db_path=":memory:") - - @q.task(inject=["redis"]) - def my_task(x: int, db: Inject["db"]) -> None: # type: ignore[type-arg,name-defined] # noqa: F821 - pass - - injects = q._task_inject_map.get(my_task.name, []) - assert "redis" in injects - assert "db" in injects - - -# ─── Phase F — TOML config ─── - - -class TestTomlConfig: - def test_load_resources(self, tmp_path: object) -> None: - import pathlib - - path = pathlib.Path(str(tmp_path)) / "resources.toml" - path.write_text('[resources.config]\nfactory = "builtins:dict"\nscope = "worker"\n') - from taskito.resources.toml_config import load_resources_from_toml - - defs = load_resources_from_toml(str(path)) - assert len(defs) == 1 - assert defs[0].name == "config" - assert defs[0].scope == ResourceScope.WORKER - - def test_missing_factory_raises(self, tmp_path: object) -> None: - import pathlib - - path = pathlib.Path(str(tmp_path)) / "bad.toml" - path.write_text('[resources.db]\nscope = "worker"\n') - from taskito.resources.toml_config import load_resources_from_toml - - with pytest.raises(ValueError, match="missing required 'factory'"): - load_resources_from_toml(str(path)) - - def test_queue_load_resources(self, tmp_path: object) -> None: - import pathlib - - path = pathlib.Path(str(tmp_path)) / "res.toml" - path.write_text('[resources.cfg]\nfactory = "builtins:dict"\n') - q = Queue(db_path=":memory:") - q.load_resources(str(path)) - assert "cfg" in q._resource_definitions - - -# ─── Phase H — Proxy metrics ─── - - -class TestProxyMetrics: - def test_record_and_retrieve(self) -> None: - m = ProxyMetrics() - m.record_reconstruction("file", 15.5) - m.record_reconstruction("file", 20.0) - m.record_error("file") - result = m.to_list() - assert len(result) == 1 - assert result[0]["handler"] == "file" - assert result[0]["total_reconstructions"] == 2 - assert result[0]["total_errors"] == 1 - assert result[0]["avg_duration_ms"] == 17.75 - - def test_queue_proxy_stats(self) -> None: - q = Queue(db_path=":memory:") - stats = q.proxy_stats() - assert isinstance(stats, list) - - -# ─── Phase I — Interception metrics ─── - - -class TestInterceptionMetrics: - def test_record_and_retrieve(self) -> None: - m = InterceptionMetrics() - m.record(5.0, {"pass": 3, "convert": 1}, max_depth=2) - d = m.to_dict() - assert d["total_intercepts"] == 1 - assert d["strategy_counts"]["pass"] == 3 - assert d["max_depth_reached"] == 2 - - def test_queue_interception_stats(self) -> None: - q = Queue(db_path=":memory:", interception="strict") - stats = q.interception_stats() - assert "total_intercepts" in stats - - def test_walker_tracks_strategy_counts(self) -> None: - reg = build_default_registry() - walker = ArgumentWalker(reg, max_depth=10) - _, _, result = walker.walk((42, "hello"), {}) - assert result.strategy_counts.get("pass", 0) >= 2 - - -# ─── Phase L — MockResource ─── - - -class TestMockResource: - def test_return_value(self) -> None: - mock = MockResource("db", return_value="fake-db") - assert mock.get() == "fake-db" - - def test_wraps(self) -> None: - real = {"conn": True} - mock = MockResource("db", wraps=real) - assert mock.get() is real - - def test_track_calls(self) -> None: - mock = MockResource("db", return_value="x", track_calls=True) - mock.get() - mock.get() - assert mock.call_count == 2 - - def test_mock_in_test_mode(self) -> None: - q = Queue(db_path=":memory:") - mock_db = MockResource("db", return_value="mock-db", track_calls=True) - - @q.task(inject=["db"]) - def use_db(db: object = None) -> object: - return db - - with q.test_mode(resources={"db": mock_db}) as results: - use_db.delay() - - assert len(results) == 1 - assert results[0].return_value == "mock-db" - assert mock_db.call_count == 1 # .get() called once during setup - - -# ─── Phase M — Test-mode proxy passthrough ─── - - -class TestTestModePassthrough: - def test_test_mode_sets_flag(self) -> None: - q = Queue(db_path=":memory:") - assert q._test_mode_active is False - with q.test_mode(): - assert q._test_mode_active is True - assert q._test_mode_active is False - - def test_interception_skipped_in_test_mode(self) -> None: - q = Queue(db_path=":memory:", interception="strict") - - @q.task() - def identity(x: object) -> object: - return x - - sentinel = object() - with q.test_mode() as results: - identity.delay(sentinel) - # In test mode, the sentinel passes through without interception - assert len(results) == 1 - - -# ─── Integration: resource injection end-to-end ─── - - -class TestResourceInjectionE2E: - def test_inject_in_test_mode(self) -> None: - q = Queue(db_path=":memory:") - - @q.worker_resource("db") - def create_db() -> str: - return "real-db-connection" - - @q.task(inject=["db"]) - def process(order_id: int, db: object = None) -> str: - return f"processed {order_id} with {db}" - - with q.test_mode(resources={"db": "mock-db"}) as results: - process.delay(42) - - assert len(results) == 1 - assert results[0].return_value == "processed 42 with mock-db" - - def test_inject_annotation_in_test_mode(self) -> None: - q = Queue(db_path=":memory:") - - @q.task() - def process(order_id: int, db: Inject["db"] = None) -> str: # type: ignore[type-arg,name-defined,assignment] # noqa: F821 - return f"{order_id}:{db}" - - with q.test_mode(resources={"db": "injected"}) as results: - process.delay(1) - - assert results[0].return_value == "1:injected" diff --git a/tests/test_resources.py b/tests/test_resources.py deleted file mode 100644 index 9d95aa1..0000000 --- a/tests/test_resources.py +++ /dev/null @@ -1,367 +0,0 @@ -"""Tests for the worker resource runtime (dependency injection).""" - -from __future__ import annotations - -from typing import Any - -import pytest - -from taskito import ( - CircularDependencyError, - Queue, - ResourceInitError, - ResourceNotFoundError, -) -from taskito.resources import ( - ResourceDefinition, - ResourceRuntime, - detect_cycle, - topological_sort, -) - -# --------------------------------------------------------------------------- -# Registration -# --------------------------------------------------------------------------- - - -def test_worker_resource_decorator_registers(queue: Queue) -> None: - """@queue.worker_resource stores a ResourceDefinition.""" - - @queue.worker_resource("cache") - def create_cache() -> dict[str, int]: - return {"hits": 0} - - assert "cache" in queue._resource_definitions - defn = queue._resource_definitions["cache"] - assert defn.name == "cache" - assert defn.factory is create_cache - - -def test_register_resource_programmatic(queue: Queue) -> None: - """register_resource() stores a definition without the decorator.""" - - def factory() -> str: - return "hello" - - queue.register_resource(ResourceDefinition(name="greeter", factory=factory)) - assert "greeter" in queue._resource_definitions - assert queue._resource_definitions["greeter"].factory is factory - - -# --------------------------------------------------------------------------- -# Graph -# --------------------------------------------------------------------------- - - -def test_circular_dependency_detected(queue: Queue) -> None: - """Circular deps raise CircularDependencyError at registration time.""" - - @queue.worker_resource("a", depends_on=["b"]) - def make_a(b: Any) -> str: - return "a" - - with pytest.raises(CircularDependencyError): - - @queue.worker_resource("b", depends_on=["a"]) - def make_b(a: Any) -> str: - return "b" - - -def test_topological_sort_order() -> None: - """Dependencies are initialized before dependents.""" - defs = { - "config": ResourceDefinition(name="config", factory=lambda: {}), - "db": ResourceDefinition(name="db", factory=lambda config: {}, depends_on=["config"]), - "cache": ResourceDefinition( - name="cache", factory=lambda config: {}, depends_on=["config"] - ), - } - order = topological_sort(defs) - assert order.index("config") < order.index("db") - assert order.index("config") < order.index("cache") - - -def test_detect_cycle_returns_none_when_no_cycle() -> None: - defs = { - "a": ResourceDefinition(name="a", factory=lambda: 1), - "b": ResourceDefinition(name="b", factory=lambda a: 2, depends_on=["a"]), - } - assert detect_cycle(defs) is None - - -def test_detect_cycle_returns_path() -> None: - defs = { - "a": ResourceDefinition(name="a", factory=lambda b: 1, depends_on=["b"]), - "b": ResourceDefinition(name="b", factory=lambda a: 2, depends_on=["a"]), - } - cycle = detect_cycle(defs) - assert cycle is not None - assert len(cycle) >= 2 - - -# --------------------------------------------------------------------------- -# Runtime -# --------------------------------------------------------------------------- - - -def test_dependency_injection_into_factory() -> None: - """Factory receives its depends_on resources as kwargs.""" - defs = { - "config": ResourceDefinition(name="config", factory=lambda: {"url": "sqlite://"}), - "db": ResourceDefinition( - name="db", factory=lambda config: f"connected:{config['url']}", depends_on=["config"] - ), - } - rt = ResourceRuntime(defs) - rt.initialize() - assert rt.resolve("db") == "connected:sqlite://" - rt.teardown() - - -def test_missing_resource_raises() -> None: - """Resolving an unregistered name raises ResourceNotFoundError.""" - rt = ResourceRuntime({}) - with pytest.raises(ResourceNotFoundError): - rt.resolve("nope") - - -def test_teardown_reverse_order() -> None: - """Resources are torn down in reverse initialization order.""" - teardown_log: list[str] = [] - - def td_config(inst: Any) -> None: - teardown_log.append("config") - - def td_db(inst: Any) -> None: - teardown_log.append("db") - - defs = { - "config": ResourceDefinition(name="config", factory=lambda: "cfg", teardown=td_config), - "db": ResourceDefinition( - name="db", - factory=lambda config: "db_conn", - depends_on=["config"], - teardown=td_db, - ), - } - rt = ResourceRuntime(defs) - rt.initialize() - rt.teardown() - assert teardown_log == ["db", "config"] - - -def test_init_failure_raises_resource_init_error() -> None: - """A factory that raises is wrapped in ResourceInitError.""" - defs = { - "broken": ResourceDefinition( - name="broken", - factory=lambda: (_ for _ in ()).throw(RuntimeError("boom")), - ), - } - rt = ResourceRuntime(defs) - with pytest.raises(ResourceInitError, match="boom"): - rt.initialize() - - -def test_from_test_overrides() -> None: - """from_test_overrides pre-populates instances without factories.""" - rt = ResourceRuntime.from_test_overrides({"db": "mock_db", "cache": "mock_cache"}) - assert rt.resolve("db") == "mock_db" - assert rt.resolve("cache") == "mock_cache" - - -# --------------------------------------------------------------------------- -# Async factory -# --------------------------------------------------------------------------- - - -def test_async_factory() -> None: - """Async factories are awaited during initialize.""" - - async def make_client() -> str: - return "async_client" - - defs = { - "client": ResourceDefinition(name="client", factory=make_client), - } - rt = ResourceRuntime(defs) - rt.initialize() - assert rt.resolve("client") == "async_client" - rt.teardown() - - -# --------------------------------------------------------------------------- -# Injection into tasks -# --------------------------------------------------------------------------- - - -def test_resource_injected_into_task(queue: Queue) -> None: - """Task with inject=["db"] receives the resource as a kwarg.""" - - @queue.worker_resource("db") - def create_db() -> str: - return "live_db" - - results_holder: list[Any] = [] - - @queue.task(inject=["db"]) - def my_task(x: int, db: Any = None) -> None: - results_holder.append((x, db)) - - with queue.test_mode(resources={"db": "test_db"}) as results: - my_task.delay(42) - - assert len(results) == 1 - assert results[0].succeeded - assert results_holder == [(42, "test_db")] - - -def test_explicit_kwarg_wins_over_inject(queue: Queue) -> None: - """Caller-provided kwargs are not overridden by injection.""" - - @queue.worker_resource("db") - def create_db() -> str: - return "injected_db" - - results_holder: list[Any] = [] - - @queue.task(inject=["db"]) - def my_task(db: Any = None) -> None: - results_holder.append(db) - - with queue.test_mode(resources={"db": "injected_db"}): - my_task.delay(db="explicit_db") - - assert results_holder == ["explicit_db"] - - -def test_test_mode_with_resources(queue: Queue) -> None: - """test_mode(resources=...) injects mock resources.""" - - @queue.worker_resource("cache") - def create_cache() -> dict[str, str]: - return {} - - captured: list[Any] = [] - - @queue.task(inject=["cache"]) - def use_cache(cache: Any = None) -> None: - captured.append(cache) - - mock_cache = {"key": "value"} - with queue.test_mode(resources={"cache": mock_cache}) as results: - use_cache.delay() - - assert captured == [mock_cache] - assert results[0].succeeded - # Runtime is cleaned up after exiting test mode - assert queue._resource_runtime is None - - -def test_test_mode_restores_previous_runtime(queue: Queue) -> None: - """Exiting test_mode restores whatever runtime was set before.""" - assert queue._resource_runtime is None - - with queue.test_mode(resources={"x": 1}): - assert queue._resource_runtime is not None - - assert queue._resource_runtime is None - - -# --------------------------------------------------------------------------- -# Health checking -# --------------------------------------------------------------------------- - - -def test_health_check_recreation(poll_until: Any) -> None: - """Unhealthy resource is recreated; permanent failure marks it unavailable.""" - from taskito.exceptions import ResourceUnavailableError - from taskito.resources.health import HealthChecker - - call_count = 0 - - def make_svc() -> str: - nonlocal call_count - call_count += 1 - if call_count > 1: - raise RuntimeError("factory broken") - return f"svc_v{call_count}" - - def check_health(inst: Any) -> bool: - # Always fail after initial creation - return False - - defs = { - "svc": ResourceDefinition( - name="svc", - factory=make_svc, - health_check=check_health, - health_check_interval=0.2, - max_recreation_attempts=2, - ), - } - rt = ResourceRuntime(defs) - rt.initialize() - assert rt.resolve("svc") == "svc_v1" - - checker = HealthChecker(rt) - checker.start() - poll_until( - lambda: "svc" in rt._unhealthy, - timeout=5, - message="resource never marked unhealthy after exhausting attempts", - ) - checker.stop() - - assert "svc" in rt._unhealthy - with pytest.raises(ResourceUnavailableError): - rt.resolve("svc") - - -# --------------------------------------------------------------------------- -# Banner -# --------------------------------------------------------------------------- - - -def test_banner_shows_resources(queue: Queue, capsys: pytest.CaptureFixture[str]) -> None: - """Resources section appears in the startup banner.""" - - @queue.worker_resource("db", depends_on=["config"]) - def create_db(config: Any) -> str: - return "db" - - @queue.worker_resource("config") - def create_config() -> dict[str, str]: - return {} - - queue._print_banner(["default"]) - captured = capsys.readouterr().out - assert "[resources]" in captured - assert "config" in captured - assert "db" in captured - assert "depends: config" in captured - - -# --------------------------------------------------------------------------- -# TaskWrapper.inject property -# --------------------------------------------------------------------------- - - -def test_task_wrapper_inject_property(queue: Queue) -> None: - """TaskWrapper exposes the inject list.""" - - @queue.task(inject=["db", "cache"]) - def my_task(db: Any = None, cache: Any = None) -> None: - pass - - assert my_task.inject == ["db", "cache"] - - -def test_task_wrapper_inject_default(queue: Queue) -> None: - """TaskWrapper.inject defaults to empty list.""" - - @queue.task() - def my_task() -> None: - pass - - assert my_task.inject == [] diff --git a/tests/test_result_race.py b/tests/test_result_race.py deleted file mode 100644 index 64e117d..0000000 --- a/tests/test_result_race.py +++ /dev/null @@ -1,62 +0,0 @@ -"""Regression: result()/aresult() must surface terminal-failure exceptions -even when the deadline is reached on the same iteration the failure lands. - -Race scenario: -1. `_poll_once()` returns ("running", None) — snapshot from before the failure -2. The job transitions to `failed`/`dead`/`cancelled` in storage -3. `time.monotonic() >= deadline` — about to give up -4. Fix: re-poll once before raising `TimeoutError` so the caller sees the - real exception (TaskFailedError / MaxRetriesExceededError / ...). -""" - -from typing import Any -from unittest.mock import patch - -import pytest - -from taskito import Queue -from taskito.exceptions import TaskFailedError - - -def _make_failing_poll(job_id: str) -> Any: - """Returns a `_poll_once` stub: 1st call → running, 2nd call → TaskFailedError.""" - call_count = [0] - - def stub() -> tuple[str, Any]: - call_count[0] += 1 - if call_count[0] == 1: - return ("running", None) - raise TaskFailedError(f"Job {job_id} failed: simulated late failure") - - stub.calls = call_count # type: ignore[attr-defined] - return stub - - -def test_result_surfaces_terminal_failure_at_deadline(tmp_path: Any) -> None: - queue = Queue(db_path=str(tmp_path / "q.db")) - - @queue.task() - def will_fail() -> None: ... - - job = will_fail.delay() - poll = _make_failing_poll(job.id) - - with patch.object(job, "_poll_once", side_effect=poll), pytest.raises(TaskFailedError): - job.result(timeout=0.01) - - assert poll.calls[0] == 2, "second defensive poll must run before raising TimeoutError" - - -async def test_aresult_surfaces_terminal_failure_at_deadline(tmp_path: Any) -> None: - queue = Queue(db_path=str(tmp_path / "q.db")) - - @queue.task() - async def will_fail() -> None: ... - - job = will_fail.delay() - poll = _make_failing_poll(job.id) - - with patch.object(job, "_poll_once", side_effect=poll), pytest.raises(TaskFailedError): - await job.aresult(timeout=0.01) - - assert poll.calls[0] == 2, "second defensive poll must run before raising TimeoutError" diff --git a/tests/test_retry.py b/tests/test_retry.py deleted file mode 100644 index 00af8c5..0000000 --- a/tests/test_retry.py +++ /dev/null @@ -1,87 +0,0 @@ -"""Tests for retry logic and dead letter queue.""" - -import threading -from typing import Any - -from taskito import Queue - -PollUntil = Any # the conftest fixture's runtime type - - -def test_failing_task_retries(queue: Queue) -> None: - """A failing task should be retried up to max_retries times.""" - call_count = 0 - - @queue.task(max_retries=3, retry_backoff=0.1) - def flaky_task() -> str: - nonlocal call_count - call_count += 1 - if call_count < 3: - raise ValueError(f"attempt {call_count}") - return "success" - - job = flaky_task.delay() - - worker_thread = threading.Thread( - target=queue.run_worker, - daemon=True, - ) - worker_thread.start() - - result = job.result(timeout=30) - assert result == "success" - assert call_count == 3 - - -def test_exhausted_retries_goes_to_dlq(queue: Queue, poll_until: PollUntil) -> None: - """A task that always fails should end up in the dead letter queue.""" - - @queue.task(max_retries=2, retry_backoff=0.1) - def always_fails() -> None: - raise RuntimeError("permanent failure") - - always_fails.delay() - - worker_thread = threading.Thread( - target=queue.run_worker, - daemon=True, - ) - worker_thread.start() - - poll_until( - lambda: len(queue.dead_letters()) >= 1, - timeout=15, - message="job did not reach DLQ after exhausting retries", - ) - - dead = queue.dead_letters() - assert len(dead) >= 1 - assert dead[0]["task_name"].endswith("always_fails") - - -def test_retry_dead_letter(queue: Queue, poll_until: PollUntil) -> None: - """A dead letter job can be re-enqueued.""" - - @queue.task(max_retries=1, retry_backoff=0.1) - def fail_once() -> None: - raise RuntimeError("fail") - - fail_once.delay() - - worker_thread = threading.Thread( - target=queue.run_worker, - daemon=True, - ) - worker_thread.start() - - poll_until( - lambda: len(queue.dead_letters()) >= 1, - timeout=10, - message="job did not reach DLQ", - ) - - dead = queue.dead_letters() - if dead: - new_id = queue.retry_dead(dead[0]["id"]) - assert new_id is not None - assert len(new_id) > 0 diff --git a/tests/test_retry_history.py b/tests/test_retry_history.py deleted file mode 100644 index 5ab7fed..0000000 --- a/tests/test_retry_history.py +++ /dev/null @@ -1,58 +0,0 @@ -"""Tests for retry history (job_errors tracking).""" - -import threading -from pathlib import Path - -import pytest - -from taskito import Queue - - -@pytest.fixture -def queue(tmp_path: Path) -> Queue: - db_path = str(tmp_path / "test_retry_history.db") - return Queue(db_path=db_path, workers=1) - - -def test_retry_errors_recorded(queue: Queue) -> None: - """Failed attempts are recorded in job.errors.""" - call_count = {"n": 0} - - @queue.task(max_retries=3, retry_backoff=0.01) - def flaky() -> str: - call_count["n"] += 1 - if call_count["n"] <= 3: - raise ValueError(f"attempt {call_count['n']}") - return "ok" - - job = flaky.delay() - - worker_thread = threading.Thread(target=queue.run_worker, daemon=True) - worker_thread.start() - - result = job.result(timeout=15) - assert result == "ok" - - errors = job.errors - assert len(errors) == 3 - assert errors[0]["attempt"] == 0 - assert "attempt 1" in errors[0]["error"] - assert errors[1]["attempt"] == 1 - assert errors[2]["attempt"] == 2 - - -def test_errors_empty_on_success(queue: Queue) -> None: - """Successful jobs have an empty errors list.""" - - @queue.task() - def ok_task() -> int: - return 42 - - job = ok_task.delay() - - worker_thread = threading.Thread(target=queue.run_worker, daemon=True) - worker_thread.start() - - result = job.result(timeout=10) - assert result == 42 - assert job.errors == [] diff --git a/tests/test_run_maybe_async.py b/tests/test_run_maybe_async.py deleted file mode 100644 index 89f99bd..0000000 --- a/tests/test_run_maybe_async.py +++ /dev/null @@ -1,59 +0,0 @@ -"""Tests for `run_maybe_async` — detection of running event loop.""" - -from __future__ import annotations - -import asyncio -from typing import Any - -import pytest - -from taskito.async_support.helpers import run_maybe_async - - -def test_run_maybe_async_passes_through_non_coroutine() -> None: - assert run_maybe_async(42) == 42 - assert run_maybe_async("hello") == "hello" - assert run_maybe_async(None) is None - - -def test_run_maybe_async_runs_coroutine_in_sync_context() -> None: - async def make_value() -> int: - return 7 - - assert run_maybe_async(make_value()) == 7 - - -async def test_run_maybe_async_raises_clear_error_under_running_loop() -> None: - """Pytest-asyncio puts us in a running loop — must surface the taskito error.""" - - async def make_value() -> int: - return 1 - - coro: Any = make_value() - with pytest.raises(RuntimeError, match="async API"): - run_maybe_async(coro) - # Drain to silence the "coroutine was never awaited" warning. - coro.close() - - -async def test_run_maybe_async_async_message_mentions_a_methods() -> None: - async def make_value() -> int: - return 1 - - coro: Any = make_value() - try: - run_maybe_async(coro) - except RuntimeError as exc: - assert "aresult" in str(exc) or "aenqueue" in str(exc) or "await" in str(exc) - finally: - coro.close() - - -def test_run_maybe_async_no_loop_uses_asyncio_run() -> None: - """Sanity: a fresh thread with no loop should run a coroutine to completion.""" - - async def slow() -> str: - await asyncio.sleep(0.001) - return "ok" - - assert run_maybe_async(slow()) == "ok" diff --git a/tests/test_serializers.py b/tests/test_serializers.py deleted file mode 100644 index 977a718..0000000 --- a/tests/test_serializers.py +++ /dev/null @@ -1,129 +0,0 @@ -"""Tests for pluggable serializers.""" - -import json -import pickle - -import pytest - -from taskito.serializers import CloudpickleSerializer, JsonSerializer, Serializer - - -class TestJsonSerializer: - def test_roundtrip_dict(self) -> None: - s = JsonSerializer() - data = {"key": "value", "num": 42, "nested": [1, 2, 3]} - assert s.loads(s.dumps(data)) == data - - def test_roundtrip_list(self) -> None: - s = JsonSerializer() - data = [1, "two", None, True] - assert s.loads(s.dumps(data)) == data - - def test_roundtrip_primitives(self) -> None: - s = JsonSerializer() - for val in [42, 3.14, "hello", True, None]: - assert s.loads(s.dumps(val)) == val - - def test_dumps_returns_bytes(self) -> None: - s = JsonSerializer() - result = s.dumps({"a": 1}) - assert isinstance(result, bytes) - - def test_non_serializable_raises(self) -> None: - s = JsonSerializer() - with pytest.raises(TypeError): - s.dumps(object()) - - def test_invalid_bytes_raises(self) -> None: - s = JsonSerializer() - with pytest.raises((json.JSONDecodeError, UnicodeDecodeError, ValueError)): - s.loads(b"\xff\xfe") - - -class TestCloudpickleSerializer: - def test_roundtrip_dict(self) -> None: - s = CloudpickleSerializer() - data = {"key": "value", "num": 42} - assert s.loads(s.dumps(data)) == data - - def test_roundtrip_lambda(self) -> None: - s = CloudpickleSerializer() - fn = lambda x: x * 2 # noqa: E731 - restored = s.loads(s.dumps(fn)) - assert restored(5) == 10 - - def test_dumps_returns_bytes(self) -> None: - s = CloudpickleSerializer() - assert isinstance(s.dumps(42), bytes) - - def test_invalid_bytes_raises(self) -> None: - s = CloudpickleSerializer() - with pytest.raises((pickle.UnpicklingError, EOFError)): - s.loads(b"not-valid-pickle") - - -class TestSerializerProtocol: - def test_json_is_serializer(self) -> None: - assert isinstance(JsonSerializer(), Serializer) - - def test_cloudpickle_is_serializer(self) -> None: - assert isinstance(CloudpickleSerializer(), Serializer) - - -class TestMsgPackSerializer: - def test_roundtrip(self) -> None: - pytest.importorskip("msgpack") - from taskito.serializers import MsgPackSerializer - - s = MsgPackSerializer() - data = {"key": "value", "num": 42} - assert s.loads(s.dumps(data)) == data - - def test_dumps_returns_bytes(self) -> None: - pytest.importorskip("msgpack") - from taskito.serializers import MsgPackSerializer - - s = MsgPackSerializer() - assert isinstance(s.dumps([1, 2, 3]), bytes) - - -class TestEncryptedSerializer: - def test_roundtrip(self) -> None: - pytest.importorskip("cryptography") - import os - - from taskito.serializers import EncryptedSerializer - - key = os.urandom(32) - s = EncryptedSerializer(JsonSerializer(), key) - data = {"secret": "payload"} - assert s.loads(s.dumps(data)) == data - - def test_wrong_key_fails(self) -> None: - pytest.importorskip("cryptography") - import os - - from taskito.serializers import EncryptedSerializer - - s1 = EncryptedSerializer(JsonSerializer(), os.urandom(32)) - s2 = EncryptedSerializer(JsonSerializer(), os.urandom(32)) - from cryptography.exceptions import InvalidTag - - encrypted = s1.dumps({"data": 1}) - with pytest.raises(InvalidTag): - s2.loads(encrypted) - - def test_tampered_ciphertext_fails(self) -> None: - pytest.importorskip("cryptography") - import os - - from taskito.serializers import EncryptedSerializer - - key = os.urandom(32) - s = EncryptedSerializer(JsonSerializer(), key) - encrypted = s.dumps("hello") - from cryptography.exceptions import InvalidTag - - tampered = encrypted[:-1] + bytes([encrypted[-1] ^ 0xFF]) - with pytest.raises(InvalidTag): - s.loads(tampered) diff --git a/tests/test_shutdown.py b/tests/test_shutdown.py deleted file mode 100644 index 276ed5e..0000000 --- a/tests/test_shutdown.py +++ /dev/null @@ -1,62 +0,0 @@ -"""Tests for graceful shutdown.""" - -import threading -import time -from typing import Any - -from taskito import Queue - -PollUntil = Any # the conftest fixture's runtime type - - -def test_graceful_shutdown_completes_inflight(queue: Queue, poll_until: PollUntil) -> None: - """Graceful shutdown waits for in-flight tasks to complete.""" - completed = threading.Event() - - @queue.task() - def slow_task() -> str: - # Intentional pacing — the test asserts the worker waits for this - # to finish before shutting down. - time.sleep(1) - completed.set() - return "done" - - job = slow_task.delay() - - worker_thread = threading.Thread(target=queue.run_worker, daemon=True) - worker_thread.start() - - # Wait for the task to actually start running before triggering shutdown. - poll_until( - lambda: (j := queue.get_job(job.id)) is not None and j.status == "running", - message="slow_task never reached running state", - ) - - # Request graceful shutdown - queue._inner.request_shutdown() - - # Worker should finish the in-flight task - worker_thread.join(timeout=10) - - assert completed.is_set() - fetched = queue.get_job(job.id) - assert fetched is not None - assert fetched.status == "complete" - - -def test_shutdown_stops_worker(queue: Queue) -> None: - """request_shutdown causes run_worker to return.""" - - @queue.task() - def noop() -> None: - pass - - worker_thread = threading.Thread(target=queue.run_worker, daemon=True) - worker_thread.start() - - # Tiny grace window so the worker reaches its poll loop before shutdown. - time.sleep(0.1) - queue._inner.request_shutdown() - - worker_thread.join(timeout=10) - assert not worker_thread.is_alive() diff --git a/tests/test_streaming.py b/tests/test_streaming.py deleted file mode 100644 index affeee3..0000000 --- a/tests/test_streaming.py +++ /dev/null @@ -1,154 +0,0 @@ -"""Tests for partial result streaming via current_job.publish() and job.stream().""" - -from __future__ import annotations - -import threading -import time -from pathlib import Path - -from taskito import Queue - - -def test_publish_writes_result_log(tmp_path: Path) -> None: - """publish() stores data as a task log with level='result'.""" - queue = Queue(db_path=str(tmp_path / "test.db")) - - @queue.task() - def emit_data() -> str: - from taskito.context import current_job - - current_job.publish({"step": 1, "value": "hello"}) - current_job.publish({"step": 2, "value": "world"}) - return "done" - - job = emit_data.delay() - - worker = threading.Thread(target=queue.run_worker, daemon=True) - worker.start() - - result = job.result(timeout=10) - assert result == "done" - - # Check that partial results are stored as task logs - logs = queue.task_logs(job.id) - result_logs = [lg for lg in logs if lg["level"] == "result"] - assert len(result_logs) == 2 - assert '"step": 1' in result_logs[0]["extra"] - assert '"step": 2' in result_logs[1]["extra"] - - queue._inner.request_shutdown() - - -def test_stream_yields_partial_results(tmp_path: Path) -> None: - """job.stream() yields published partial results.""" - queue = Queue(db_path=str(tmp_path / "test.db")) - - @queue.task() - def batch_process() -> str: - from taskito.context import current_job - - for i in range(3): - current_job.publish({"item": i, "status": "processed"}) - time.sleep(0.1) - return "all done" - - job = batch_process.delay() - - worker = threading.Thread(target=queue.run_worker, daemon=True) - worker.start() - - # Collect streamed results - results = list(job.stream(timeout=15, poll_interval=0.3)) - - assert len(results) == 3 - assert results[0]["item"] == 0 - assert results[1]["item"] == 1 - assert results[2]["item"] == 2 - assert all(r["status"] == "processed" for r in results) - - queue._inner.request_shutdown() - - -def test_stream_stops_on_completion(tmp_path: Path) -> None: - """stream() stops iterating when the job completes.""" - queue = Queue(db_path=str(tmp_path / "test.db")) - - @queue.task() - def quick_task() -> int: - from taskito.context import current_job - - current_job.publish({"msg": "started"}) - return 42 - - job = quick_task.delay() - - worker = threading.Thread(target=queue.run_worker, daemon=True) - worker.start() - - results = list(job.stream(timeout=10, poll_interval=0.2)) - assert len(results) >= 1 - assert results[0]["msg"] == "started" - - queue._inner.request_shutdown() - - -async def test_astream_async(tmp_path: Path) -> None: - """astream() works as an async iterator.""" - queue = Queue(db_path=str(tmp_path / "test.db")) - - @queue.task() - def async_batch() -> str: - from taskito.context import current_job - - current_job.publish({"phase": "init"}) - current_job.publish({"phase": "done"}) - return "ok" - - job = async_batch.delay() - - worker = threading.Thread(target=queue.run_worker, daemon=True) - worker.start() - - results: list[dict] = [] - async for partial in job.astream(timeout=10, poll_interval=0.3): - results.append(partial) - - assert len(results) >= 1 - assert results[0]["phase"] == "init" - - queue._inner.request_shutdown() - - -def test_publish_non_dict_data(tmp_path: Path) -> None: - """publish() handles non-dict data (strings, lists, numbers). - - Since stream() polls, we verify via task_logs directly to avoid - timing-dependent polling issues. - """ - queue = Queue(db_path=str(tmp_path / "test.db")) - - @queue.task() - def varied_output() -> str: - from taskito.context import current_job - - current_job.publish("plain string") - current_job.publish([1, 2, 3]) - current_job.publish(42) - return "done" - - job = varied_output.delay() - - worker = threading.Thread(target=queue.run_worker, daemon=True) - worker.start() - - job.result(timeout=10) - - # Verify all published data via task logs - logs = queue.task_logs(job.id) - result_logs = [lg for lg in logs if lg["level"] == "result"] - extras = [lg["extra"] for lg in result_logs] - assert '"plain string"' in extras - assert "[1, 2, 3]" in extras - assert "42" in extras - - queue._inner.request_shutdown() diff --git a/tests/test_unique.py b/tests/test_unique.py deleted file mode 100644 index c428bfc..0000000 --- a/tests/test_unique.py +++ /dev/null @@ -1,65 +0,0 @@ -"""Tests for unique task deduplication.""" - -from __future__ import annotations - -import threading - -from taskito import Queue - - -def test_unique_key_dedup(queue: Queue) -> None: - """Two jobs with the same unique_key should return the same job ID.""" - - @queue.task() - def process(data: str) -> str: - return data - - job1 = process.apply_async(args=("a",), unique_key="dedup-1") - job2 = process.apply_async(args=("b",), unique_key="dedup-1") - - assert job1.id == job2.id - - -def test_different_unique_keys(queue: Queue) -> None: - """Different unique keys should create separate jobs.""" - - @queue.task() - def process(data: str) -> str: - return data - - job1 = process.apply_async(args=("a",), unique_key="key-a") - job2 = process.apply_async(args=("b",), unique_key="key-b") - - assert job1.id != job2.id - - -def test_unique_key_allows_after_complete(queue: Queue) -> None: - """After a unique job completes, a new one with the same key can be created.""" - - @queue.task() - def fast_task() -> str: - return "done" - - job1 = fast_task.apply_async(unique_key="once") - - worker_thread = threading.Thread(target=queue.run_worker, daemon=True) - worker_thread.start() - - job1.result(timeout=10) - - # Now enqueue again with the same key — should create a new job - job2 = fast_task.apply_async(unique_key="once") - assert job2.id != job1.id - - -def test_no_unique_key_allows_duplicates(queue: Queue) -> None: - """Without unique_key, duplicate jobs are allowed.""" - - @queue.task() - def process(data: str) -> str: - return data - - job1 = process.delay("a") - job2 = process.delay("a") - - assert job1.id != job2.id diff --git a/tests/test_webhooks.py b/tests/test_webhooks.py deleted file mode 100644 index ac41b42..0000000 --- a/tests/test_webhooks.py +++ /dev/null @@ -1,138 +0,0 @@ -"""Tests for webhook delivery system.""" - -import hashlib -import hmac -import json -import threading -from collections.abc import Generator -from http.server import BaseHTTPRequestHandler, HTTPServer -from typing import Any - -import pytest - -from taskito.events import EventType -from taskito.webhooks import WebhookManager - -PollUntil = Any # the conftest fixture's runtime type - - -@pytest.fixture -def webhook_server() -> Generator[tuple[str, list[dict[str, Any]]]]: - """Start a local HTTP server that records webhook deliveries.""" - received: list[dict[str, Any]] = [] - - class Handler(BaseHTTPRequestHandler): - def do_POST(self) -> None: - length = int(self.headers.get("Content-Length", 0)) - body = self.rfile.read(length) - received.append( - { - "body": json.loads(body), - "headers": dict(self.headers), - } - ) - self.send_response(200) - self.end_headers() - - def log_message(self, *args: Any) -> None: - pass - - server = HTTPServer(("127.0.0.1", 0), Handler) - port = server.server_address[1] - thread = threading.Thread(target=server.serve_forever, daemon=True) - thread.start() - - yield f"http://127.0.0.1:{port}", received - - server.shutdown() - - -def test_webhook_delivery( - webhook_server: tuple[str, list[dict[str, Any]]], poll_until: PollUntil -) -> None: - """Webhooks are delivered to registered URLs.""" - url, received = webhook_server - mgr = WebhookManager() - mgr.add_webhook(url) - - mgr.notify(EventType.JOB_COMPLETED, {"job_id": "abc"}) - poll_until(lambda: len(received) >= 1, message="webhook not delivered") - - assert len(received) == 1 - assert received[0]["body"]["event"] == "job.completed" - assert received[0]["body"]["job_id"] == "abc" - - -def test_webhook_event_filtering( - webhook_server: tuple[str, list[dict[str, Any]]], poll_until: PollUntil -) -> None: - """Webhooks with event filters only receive matching events.""" - url, received = webhook_server - mgr = WebhookManager() - mgr.add_webhook(url, events=[EventType.JOB_FAILED]) - - mgr.notify(EventType.JOB_COMPLETED, {"job_id": "1"}) - mgr.notify(EventType.JOB_FAILED, {"job_id": "2", "error": "boom"}) - poll_until(lambda: len(received) >= 1, message="filtered webhook not delivered") - - assert len(received) == 1 - assert received[0]["body"]["event"] == "job.failed" - - -def test_webhook_hmac_signing( - webhook_server: tuple[str, list[dict[str, Any]]], poll_until: PollUntil -) -> None: - """Webhooks with a secret include a valid HMAC signature.""" - url, received = webhook_server - secret = "my-secret-key" - mgr = WebhookManager() - mgr.add_webhook(url, secret=secret) - - mgr.notify(EventType.JOB_ENQUEUED, {"job_id": "xyz"}) - poll_until(lambda: len(received) >= 1, message="signed webhook not delivered") - - assert len(received) == 1 - sig_header = received[0]["headers"].get("X-Taskito-Signature") - assert sig_header is not None - assert sig_header.startswith("sha256=") - - # Verify the signature - body_bytes = json.dumps(received[0]["body"], default=str).encode("utf-8") - expected_sig = hmac.new(secret.encode(), body_bytes, hashlib.sha256).hexdigest() - assert sig_header == f"sha256={expected_sig}" - - -def test_webhook_url_validation() -> None: - """Only http:// and https:// URLs are accepted.""" - mgr = WebhookManager() - - with pytest.raises(ValueError, match="http:// or https://"): - mgr.add_webhook("ftp://example.com/hook") - - with pytest.raises(ValueError, match="http:// or https://"): - mgr.add_webhook("javascript:alert(1)") - - # Valid URLs should not raise - mgr.add_webhook("http://localhost:8080/hook") - mgr.add_webhook("https://example.com/hook") - - -def test_webhook_custom_headers( - webhook_server: tuple[str, list[dict[str, Any]]], poll_until: PollUntil -) -> None: - """Custom headers are included in webhook requests.""" - url, received = webhook_server - mgr = WebhookManager() - mgr.add_webhook(url, headers={"X-Custom": "test-value"}) - - mgr.notify(EventType.JOB_COMPLETED, {"job_id": "1"}) - poll_until(lambda: len(received) >= 1, message="custom-headers webhook not delivered") - - assert len(received) == 1 - assert received[0]["headers"].get("X-Custom") == "test-value" - - -def test_webhook_no_subscribers() -> None: - """Notifying with no matching webhooks doesn't raise.""" - mgr = WebhookManager() - mgr.notify(EventType.JOB_COMPLETED, {"job_id": "1"}) diff --git a/tests/test_worker.py b/tests/test_worker.py deleted file mode 100644 index f2ae258..0000000 --- a/tests/test_worker.py +++ /dev/null @@ -1,114 +0,0 @@ -"""Tests for worker behavior.""" - -import threading -import time -from pathlib import Path - -import pytest - -from taskito import Queue - - -def test_multiple_tasks(queue: Queue) -> None: - """Worker handles multiple different task types.""" - - @queue.task() - def task_a(x: int) -> int: - return x * 2 - - @queue.task() - def task_b(x: int) -> int: - return x + 10 - - job_a = task_a.delay(5) - job_b = task_b.delay(5) - - worker_thread = threading.Thread( - target=queue.run_worker, - daemon=True, - ) - worker_thread.start() - - assert job_a.result(timeout=10) == 10 - assert job_b.result(timeout=10) == 15 - - -def test_get_job(queue: Queue) -> None: - """Can retrieve a job by ID.""" - - @queue.task() - def simple() -> int: - return 42 - - job = simple.delay() - fetched = queue.get_job(job.id) - assert fetched is not None - assert fetched.id == job.id - - -def test_job_status_progression(queue: Queue) -> None: - """Job status progresses from pending through complete.""" - - @queue.task() - def slow() -> str: - time.sleep(0.5) - return "done" - - job = slow.delay() - - # Initially pending - fetched = queue.get_job(job.id) - assert fetched is not None - assert fetched.status == "pending" - - worker_thread = threading.Thread( - target=queue.run_worker, - daemon=True, - ) - worker_thread.start() - - result = job.result(timeout=10) - assert result == "done" - - # After completion - fetched = queue.get_job(job.id) - assert fetched is not None - assert fetched.status == "complete" - - -@pytest.mark.asyncio -async def test_async_result(tmp_path: Path) -> None: - """Async result retrieval works.""" - db_path = str(tmp_path / "test_async.db") - queue = Queue(db_path=db_path, workers=2) - - @queue.task() - def add(a: int, b: int) -> int: - return a + b - - job = add.delay(10, 20) - - worker_thread = threading.Thread( - target=queue.run_worker, - daemon=True, - ) - worker_thread.start() - - result = await job.aresult(timeout=10) - assert result == 30 - - -@pytest.mark.asyncio -async def test_async_stats(tmp_path: Path) -> None: - """Async stats work.""" - db_path = str(tmp_path / "test_async_stats.db") - queue = Queue(db_path=db_path, workers=2) - - @queue.task() - def noop() -> None: - pass - - noop.delay() - - stats = await queue.astats() - assert stats["pending"] == 1 diff --git a/tests/test_worker_resources.py b/tests/test_worker_resources.py deleted file mode 100644 index c6ebe84..0000000 --- a/tests/test_worker_resources.py +++ /dev/null @@ -1,181 +0,0 @@ -"""Tests for worker resource advertisement (Phase 4 — storage extension).""" - -from __future__ import annotations - -import json -import threading -from typing import Any - -from taskito import Queue - -# --------------------------------------------------------------------------- -# Direct storage-level tests (no worker needed) -# --------------------------------------------------------------------------- - - -class TestWorkerAdvertisement: - def test_no_resources_returns_none(self, tmp_path: Any) -> None: - """_build_resource_health_json returns None when no resources.""" - queue = Queue(db_path=str(tmp_path / "q.db")) - assert queue._build_resource_health_json() is None - - def test_build_resource_health_json_with_resources(self, tmp_path: Any) -> None: - """_build_resource_health_json returns correct JSON.""" - queue = Queue(db_path=str(tmp_path / "q.db")) - - @queue.worker_resource("db") - def create_db() -> str: - return "db_instance" - - @queue.worker_resource("cache") - def create_cache() -> str: - return "cache_instance" - - health_json = queue._build_resource_health_json() - assert health_json is not None - health = json.loads(health_json) - assert health == {"db": "healthy", "cache": "healthy"} - - def test_build_resource_health_reflects_unhealthy(self, tmp_path: Any) -> None: - """_build_resource_health_json marks unhealthy resources.""" - from taskito.resources.runtime import ResourceRuntime - - queue = Queue(db_path=str(tmp_path / "q.db")) - - @queue.worker_resource("db") - def create_db() -> str: - return "db_instance" - - # Simulate an initialized runtime with an unhealthy resource - runtime = ResourceRuntime(queue._resource_definitions) - runtime._instances = {"db": "db_instance"} - runtime._init_order = ["db"] - runtime._unhealthy = {"db"} - queue._resource_runtime = runtime - - health_json = queue._build_resource_health_json() - assert health_json is not None - health = json.loads(health_json) - assert health["db"] == "unhealthy" - - def test_worker_heartbeat_method(self, tmp_path: Any) -> None: - """worker_heartbeat can be called without error.""" - queue = Queue(db_path=str(tmp_path / "q.db")) - # Heartbeat for a non-existent worker is a no-op (updates 0 rows) - queue._inner.worker_heartbeat("nonexistent-worker") - - def test_worker_heartbeat_with_health(self, tmp_path: Any) -> None: - """worker_heartbeat accepts resource_health JSON.""" - queue = Queue(db_path=str(tmp_path / "q.db")) - health = json.dumps({"db": "healthy"}) - queue._inner.worker_heartbeat("w-test", health) - - def test_list_workers_empty(self, tmp_path: Any) -> None: - """list_workers returns empty list when no workers registered.""" - queue = Queue(db_path=str(tmp_path / "q.db")) - workers = queue.workers() - assert workers == [] - - -# --------------------------------------------------------------------------- -# Integration tests with actual worker -# --------------------------------------------------------------------------- - - -class TestWorkerResourceIntegration: - def test_worker_advertises_resources_and_threads(self, tmp_path: Any, poll_until: Any) -> None: - """A running worker stores resources and threads in storage.""" - queue = Queue(db_path=str(tmp_path / "q.db"), workers=2) - - @queue.worker_resource("db") - def create_db() -> str: - return "db_instance" - - @queue.task() - def noop() -> None: - pass - - # Start worker in thread, wait for it to register - thread = threading.Thread(target=queue.run_worker, daemon=True) - thread.start() - - try: - # Poll until a worker appears with resource_health populated - # (initial registration has None; first heartbeat sets it). - poll_until( - lambda: bool((ws := queue.workers()) and ws[0].get("resource_health") is not None), - timeout=15, - message="worker did not publish resource_health", - ) - workers: list[dict[str, Any]] = queue.workers() - assert len(workers) >= 1 - w = workers[0] - assert "resources" in w - assert "resource_health" in w - assert "threads" in w - assert w["threads"] == 2 - - # Parse resources JSON - resources = json.loads(w["resources"]) - assert "db" in resources - - # Parse health JSON - health = json.loads(w["resource_health"]) - assert health.get("db") == "healthy" - finally: - queue._inner.request_shutdown() - thread.join(timeout=10) - - def test_worker_no_resources(self, tmp_path: Any, poll_until: Any) -> None: - """Worker without resources stores None for resource fields.""" - queue = Queue(db_path=str(tmp_path / "q.db"), workers=1) - - @queue.task() - def noop() -> None: - pass - - thread = threading.Thread(target=queue.run_worker, daemon=True) - thread.start() - - try: - poll_until( - lambda: bool(queue.workers()), timeout=10, message="worker did not register" - ) - workers: list[dict[str, Any]] = queue.workers() - assert len(workers) >= 1 - w = workers[0] - assert w["resources"] is None - assert w["resource_health"] is None - finally: - queue._inner.request_shutdown() - thread.join(timeout=10) - - def test_heartbeat_updates_health(self, tmp_path: Any, poll_until: Any) -> None: - """Heartbeat thread updates resource_health in storage.""" - queue = Queue(db_path=str(tmp_path / "q.db"), workers=1) - - @queue.worker_resource("db") - def create_db() -> str: - return "db_instance" - - @queue.task() - def noop() -> None: - pass - - thread = threading.Thread(target=queue.run_worker, daemon=True) - thread.start() - - try: - # Wait for worker + first heartbeat - poll_until( - lambda: bool((ws := queue.workers()) and ws[0].get("resource_health")), - timeout=10, - message="heartbeat never published resource_health", - ) - workers: list[dict[str, Any]] = queue.workers() - assert len(workers) >= 1 - health = json.loads(workers[0]["resource_health"]) - assert "db" in health - finally: - queue._inner.request_shutdown() - thread.join(timeout=10) diff --git a/tests/test_workflows_analysis.py b/tests/test_workflows_analysis.py deleted file mode 100644 index dc01185..0000000 --- a/tests/test_workflows_analysis.py +++ /dev/null @@ -1,137 +0,0 @@ -"""Tests for Phase 6 workflow graph analysis.""" - -from __future__ import annotations - -import pytest - -from taskito.workflows import Workflow - - -class _FakeTask: - _task_name = "fake" - - -def _linear() -> Workflow: - """a → b → c""" - wf = Workflow(name="linear") - wf.step("a", _FakeTask()) - wf.step("b", _FakeTask(), after="a") - wf.step("c", _FakeTask(), after="b") - return wf - - -def _diamond() -> Workflow: - """a → {b, c} → d""" - wf = Workflow(name="diamond") - wf.step("a", _FakeTask()) - wf.step("b", _FakeTask(), after="a") - wf.step("c", _FakeTask(), after="a") - wf.step("d", _FakeTask(), after=["b", "c"]) - return wf - - -# ── Ancestors / Descendants ────────────────────────────────────── - - -def test_ancestors_linear() -> None: - wf = _linear() - assert wf.ancestors("c") == ["a", "b"] - assert wf.ancestors("b") == ["a"] - assert wf.ancestors("a") == [] - - -def test_descendants_linear() -> None: - wf = _linear() - assert wf.descendants("a") == ["b", "c"] - assert wf.descendants("b") == ["c"] - assert wf.descendants("c") == [] - - -def test_ancestors_diamond() -> None: - wf = _diamond() - assert wf.ancestors("d") == ["a", "b", "c"] - assert set(wf.ancestors("b")) == {"a"} - - -def test_ancestors_unknown_node() -> None: - wf = _linear() - with pytest.raises(KeyError, match="nonexistent"): - wf.ancestors("nonexistent") - - -# ── Topological Levels ─────────────────────────────────────────── - - -def test_topological_levels_linear() -> None: - wf = _linear() - assert wf.topological_levels() == [["a"], ["b"], ["c"]] - - -def test_topological_levels_diamond() -> None: - wf = _diamond() - levels = wf.topological_levels() - assert levels[0] == ["a"] - assert sorted(levels[1]) == ["b", "c"] - assert levels[2] == ["d"] - - -# ── Stats ──────────────────────────────────────────────────────── - - -def test_stats() -> None: - wf = _diamond() - s = wf.stats() - assert s["nodes"] == 4 - assert s["edges"] == 4 # a→b, a→c, b→d, c→d - assert s["depth"] == 3 - assert s["width"] == 2 # level 1 has b and c - assert 0 < s["density"] <= 1.0 - - -# ── Critical Path ──────────────────────────────────────────────── - - -def test_critical_path_linear() -> None: - wf = _linear() - path, cost = wf.critical_path({"a": 1.0, "b": 2.0, "c": 3.0}) - assert path == ["a", "b", "c"] - assert cost == 6.0 - - -def test_critical_path_diamond() -> None: - wf = _diamond() - # b branch is heavier: a(1) + b(5) + d(1) = 7 - # c branch: a(1) + c(2) + d(1) = 4 - path, cost = wf.critical_path({"a": 1.0, "b": 5.0, "c": 2.0, "d": 1.0}) - assert path == ["a", "b", "d"] - assert cost == 7.0 - - -# ── Execution Plan ─────────────────────────────────────────────── - - -def test_execution_plan_parallelism() -> None: - wf = _diamond() - plan = wf.execution_plan(max_workers=2) - # Level 0: [a], Level 1: [b, c] fits in 1 batch, Level 2: [d] - assert plan == [["a"], ["b", "c"], ["d"]] - - -def test_execution_plan_worker_limit() -> None: - wf = _diamond() - plan = wf.execution_plan(max_workers=1) - # Level 1 splits into two batches - assert plan == [["a"], ["b"], ["c"], ["d"]] - - -# ── Bottleneck Analysis ────────────────────────────────────────── - - -def test_bottleneck_analysis() -> None: - wf = _diamond() - result = wf.bottleneck_analysis({"a": 1.0, "b": 5.0, "c": 2.0, "d": 1.0}) - assert result["node"] == "b" - assert result["cost"] == 5.0 - assert result["percentage"] > 50 - assert "b" in result["suggestion"] - assert result["critical_path"] == ["a", "b", "d"] diff --git a/tests/test_workflows_caching.py b/tests/test_workflows_caching.py deleted file mode 100644 index cfba4c8..0000000 --- a/tests/test_workflows_caching.py +++ /dev/null @@ -1,215 +0,0 @@ -"""Tests for Phase 7 incremental execution and caching.""" - -from __future__ import annotations - -import threading -from collections.abc import Callable -from contextlib import AbstractContextManager - -from taskito import Queue -from taskito.workflows import NodeStatus, Workflow, WorkflowState - -WorkflowWorkerFactory = Callable[[], AbstractContextManager[threading.Thread]] - - -def test_result_hash_stored(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: - """Completed nodes have a non-None result_hash.""" - - @queue.task() - def ok_task() -> str: - return "hello" - - wf = Workflow(name="hash_stored") - wf.step("a", ok_task) - - with workflow_worker(): - run = queue.submit_workflow(wf) - run.wait(timeout=15) - - # Check the node data via base_run_node_data - nodes = queue._inner.get_base_run_node_data(run.id) - assert len(nodes) == 1 - name, status, _result_hash = nodes[0] - assert name == "a" - assert status == "completed" - # Hash may be None if result wasn't stored before event fired (best-effort). - # In practice it's usually populated. - - -def test_incremental_skips_completed(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: - """Incremental run marks base-completed nodes as CACHE_HIT.""" - - executed: list[str] = [] - - @queue.task() - def step_a() -> str: - executed.append("a") - return "a" - - @queue.task() - def step_b() -> str: - executed.append("b") - return "b" - - wf = Workflow(name="incr_skip") - wf.step("a", step_a) - wf.step("b", step_b, after="a") - - # First run: everything executes. - with workflow_worker(): - run1 = queue.submit_workflow(wf) - run1.wait(timeout=15) - - assert run1.status().state == WorkflowState.COMPLETED - executed.clear() - - # Second run: incremental. - with workflow_worker(): - run2 = queue.submit_workflow(wf, incremental=True, base_run=run1.id) - run2.wait(timeout=15) - - final = run2.status() - assert final.state == WorkflowState.COMPLETED - - # If hashes were stored, both nodes are CACHE_HIT and nothing re-ran. - cache_hits = [n for n in final.nodes.values() if n.status == NodeStatus.CACHE_HIT] - if cache_hits: - assert len(cache_hits) == 2 - assert executed == [] # nothing re-executed - - -def test_incremental_reruns_failed(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: - """Failed nodes in the base run get re-executed.""" - - call_count = {"n": 0} - - @queue.task(max_retries=0) - def flaky() -> str: - call_count["n"] += 1 - if call_count["n"] == 1: - raise RuntimeError("first call fails") - return "ok" - - wf = Workflow(name="incr_rerun") - wf.step("a", flaky) - - # First run: fails. - with workflow_worker(): - run1 = queue.submit_workflow(wf) - run1.wait(timeout=15) - - assert run1.status().state == WorkflowState.FAILED - - # Second run incremental: failed node re-executes. - with workflow_worker(): - run2 = queue.submit_workflow(wf, incremental=True, base_run=run1.id) - run2.wait(timeout=15) - - assert run2.status().state == WorkflowState.COMPLETED - assert call_count["n"] == 2 - - -def test_dirty_propagation(queue: Queue) -> None: - """If a root node is dirty, all downstream re-execute even if they were cached.""" - from taskito.workflows.incremental import compute_dirty_set - - successors = {"a": ["b"], "b": ["c"], "c": []} - predecessors = {"a": [], "b": ["a"], "c": ["b"]} - - # Simulate "a" being dirty (not in base). - base_nodes_missing_a: list[tuple[str, str, str | None]] = [ - ("b", "completed", "hash_b"), - ("c", "completed", "hash_c"), - ] - - dirty, cached = compute_dirty_set( - base_nodes=base_nodes_missing_a, - new_node_names=["a", "b", "c"], - successors=successors, - predecessors=predecessors, - ) - - assert "a" in dirty - assert "b" in dirty # propagated from a - assert "c" in dirty # propagated from b - assert not cached - - -def test_cache_hit_is_terminal(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: - """CACHE_HIT nodes are terminal and don't block the workflow.""" - - @queue.task() - def ok_task() -> str: - return "ok" - - wf = Workflow(name="cache_terminal") - wf.step("a", ok_task) - wf.step("b", ok_task, after="a") - - # First run. - with workflow_worker(): - run1 = queue.submit_workflow(wf) - run1.wait(timeout=15) - - # Second incremental run. - with workflow_worker(): - run2 = queue.submit_workflow(wf, incremental=True, base_run=run1.id) - final = run2.wait(timeout=15) - - # Workflow should complete (CACHE_HIT is terminal). - assert final.state == WorkflowState.COMPLETED - - -def test_full_refresh_ignores_cache(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: - """incremental=False always re-runs everything.""" - - executed: list[str] = [] - - @queue.task() - def step_a() -> str: - executed.append("a") - return "a" - - wf = Workflow(name="full_refresh") - wf.step("a", step_a) - - # First run. - with workflow_worker(): - run1 = queue.submit_workflow(wf) - run1.wait(timeout=15) - - executed.clear() - - # Second run without incremental — should re-execute. - with workflow_worker(): - run2 = queue.submit_workflow(wf) - run2.wait(timeout=15) - - assert executed == ["a"] - - -def test_cache_ttl_expires() -> None: - """Expired base run results trigger re-execution.""" - from taskito.workflows.incremental import compute_dirty_set - - base_nodes: list[tuple[str, str, str | None]] = [ - ("a", "completed", "hash_a"), - ] - - # base_run_completed_at is 1000 seconds ago, TTL is 500s → expired - import time - - now_ms = int(time.time() * 1000) - old_completed = now_ms - 1_000_000 # 1000 seconds ago - - dirty, cached = compute_dirty_set( - base_nodes=base_nodes, - new_node_names=["a"], - successors={"a": []}, - predecessors={"a": []}, - cache_ttl=500.0, - base_run_completed_at=old_completed, - ) - - assert "a" in dirty - assert not cached diff --git a/tests/test_workflows_conditions.py b/tests/test_workflows_conditions.py deleted file mode 100644 index 6c48468..0000000 --- a/tests/test_workflows_conditions.py +++ /dev/null @@ -1,405 +0,0 @@ -"""Tests for Phase 4 conditional execution and error handling.""" - -from __future__ import annotations - -import threading -from collections.abc import Callable -from contextlib import AbstractContextManager - -from taskito import Queue -from taskito.workflows import NodeStatus, Workflow, WorkflowContext, WorkflowState - -WorkflowWorkerFactory = Callable[[], AbstractContextManager[threading.Thread]] - - -def test_on_failure_step_runs(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: - """A step with condition='on_failure' runs when predecessor fails.""" - - @queue.task(max_retries=0) - def fail_task() -> str: - raise RuntimeError("boom") - - collected: list[str] = [] - - @queue.task() - def cleanup() -> str: - collected.append("cleanup ran") - return "cleaned" - - wf = Workflow(name="on_failure_runs") - wf.step("a", fail_task) - wf.step("b", cleanup, after="a", condition="on_failure") - - with workflow_worker(): - run = queue.submit_workflow(wf) - final = run.wait(timeout=20) - - assert final.nodes["a"].status == NodeStatus.FAILED - assert final.nodes["b"].status == NodeStatus.COMPLETED - assert collected == ["cleanup ran"] - - -def test_on_failure_step_skipped_on_success( - queue: Queue, workflow_worker: WorkflowWorkerFactory -) -> None: - """A step with condition='on_failure' is SKIPPED when predecessor succeeds.""" - - @queue.task() - def ok_task() -> str: - return "ok" - - @queue.task() - def rollback() -> str: - return "should not run" - - wf = Workflow(name="on_failure_skipped") - wf.step("a", ok_task) - wf.step("b", rollback, after="a", condition="on_failure") - - with workflow_worker(): - run = queue.submit_workflow(wf) - final = run.wait(timeout=20) - - assert final.nodes["a"].status == NodeStatus.COMPLETED - assert final.nodes["b"].status == NodeStatus.SKIPPED - - -def test_always_step_runs_on_success(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: - """A step with condition='always' runs when predecessor succeeds.""" - - collected: list[str] = [] - - @queue.task() - def ok_task() -> str: - return "ok" - - @queue.task() - def always_task() -> str: - collected.append("always ran") - return "done" - - wf = Workflow(name="always_on_success") - wf.step("a", ok_task) - wf.step("b", always_task, after="a", condition="always") - - with workflow_worker(): - run = queue.submit_workflow(wf) - final = run.wait(timeout=20) - - assert final.nodes["b"].status == NodeStatus.COMPLETED - assert collected == ["always ran"] - - -def test_always_step_runs_on_failure(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: - """A step with condition='always' runs even when predecessor fails.""" - - collected: list[str] = [] - - @queue.task(max_retries=0) - def fail_task() -> str: - raise RuntimeError("boom") - - @queue.task() - def always_task() -> str: - collected.append("always ran") - return "done" - - wf = Workflow(name="always_on_failure") - wf.step("a", fail_task) - wf.step("b", always_task, after="a", condition="always") - - with workflow_worker(): - run = queue.submit_workflow(wf) - final = run.wait(timeout=20) - - assert final.nodes["a"].status == NodeStatus.FAILED - assert final.nodes["b"].status == NodeStatus.COMPLETED - assert collected == ["always ran"] - - -def test_on_success_default(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: - """Default condition (on_success) skips the step when predecessor fails.""" - - @queue.task(max_retries=0) - def fail_task() -> str: - raise RuntimeError("boom") - - @queue.task() - def next_task() -> str: - return "should not run" - - wf = Workflow(name="on_success_default") - wf.step("a", fail_task) - wf.step("b", next_task, after="a") - - with workflow_worker(): - run = queue.submit_workflow(wf) - final = run.wait(timeout=20) - - assert final.state == WorkflowState.FAILED - assert final.nodes["a"].status == NodeStatus.FAILED - assert final.nodes["b"].status == NodeStatus.SKIPPED - - -def test_continue_mode_independent_branches( - queue: Queue, workflow_worker: WorkflowWorkerFactory -) -> None: - """on_failure='continue' lets independent branches keep running.""" - - order: list[str] = [] - - @queue.task(max_retries=0) - def fail_task() -> str: - order.append("fail") - raise RuntimeError("boom") - - @queue.task() - def ok_task() -> str: - order.append("ok") - return "ok" - - @queue.task() - def after_fail() -> str: - order.append("after_fail") - return "nope" - - @queue.task() - def after_ok() -> str: - order.append("after_ok") - return "yes" - - # Diamond: root → {fail_branch, ok_branch} → {after_fail, after_ok} - wf = Workflow(name="continue_branches", on_failure="continue") - wf.step("root", ok_task) - wf.step("fail_branch", fail_task, after="root") - wf.step("ok_branch", ok_task, after="root") - wf.step("after_fail", after_fail, after="fail_branch") - wf.step("after_ok", after_ok, after="ok_branch") - - with workflow_worker(): - run = queue.submit_workflow(wf) - final = run.wait(timeout=20) - - # fail_branch failed → after_fail skipped (condition=on_success, pred failed) - # ok_branch succeeded → after_ok ran - assert final.state == WorkflowState.FAILED # overall: has failures - assert final.nodes["fail_branch"].status == NodeStatus.FAILED - assert final.nodes["ok_branch"].status == NodeStatus.COMPLETED - assert final.nodes["after_fail"].status == NodeStatus.SKIPPED - assert final.nodes["after_ok"].status == NodeStatus.COMPLETED - assert "after_ok" in order - - -def test_continue_mode_skips_downstream( - queue: Queue, workflow_worker: WorkflowWorkerFactory -) -> None: - """In continue mode, failure skips on_success downstream in the chain.""" - - @queue.task(max_retries=0) - def fail_task() -> str: - raise RuntimeError("boom") - - @queue.task() - def ok_task() -> str: - return "ok" - - wf = Workflow(name="continue_chain", on_failure="continue") - wf.step("a", ok_task) - wf.step("b", fail_task, after="a") - wf.step("c", ok_task, after="b") - wf.step("d", ok_task, after="c") - - with workflow_worker(): - run = queue.submit_workflow(wf) - final = run.wait(timeout=20) - - assert final.state == WorkflowState.FAILED - assert final.nodes["a"].status == NodeStatus.COMPLETED - assert final.nodes["b"].status == NodeStatus.FAILED - assert final.nodes["c"].status == NodeStatus.SKIPPED - assert final.nodes["d"].status == NodeStatus.SKIPPED - - -def test_callable_condition_true(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: - """A callable condition that returns True lets the step run.""" - - @queue.task() - def ok_task() -> str: - return "ok" - - collected: list[str] = [] - - @queue.task() - def guarded() -> str: - collected.append("ran") - return "done" - - wf = Workflow(name="callable_true") - wf.step("a", ok_task) - wf.step("b", guarded, after="a", condition=lambda ctx: True) - - with workflow_worker(): - run = queue.submit_workflow(wf) - final = run.wait(timeout=20) - - assert final.nodes["b"].status == NodeStatus.COMPLETED - assert collected == ["ran"] - - -def test_callable_condition_false(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: - """A callable condition that returns False skips the step.""" - - @queue.task() - def ok_task() -> str: - return "ok" - - @queue.task() - def guarded() -> str: - return "should not run" - - wf = Workflow(name="callable_false") - wf.step("a", ok_task) - wf.step("b", guarded, after="a", condition=lambda ctx: False) - - with workflow_worker(): - run = queue.submit_workflow(wf) - final = run.wait(timeout=20) - - assert final.nodes["b"].status == NodeStatus.SKIPPED - - -def test_callable_accesses_results(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: - """A callable condition can access predecessor results via ctx.results.""" - - @queue.task() - def score_task() -> dict: - return {"score": 0.98} - - collected: list[str] = [] - - @queue.task() - def deploy() -> str: - collected.append("deployed") - return "ok" - - @queue.task() - def skip_deploy() -> str: - collected.append("should not deploy") - return "skip" - - def high_score(ctx: WorkflowContext) -> bool: - return bool(ctx.results.get("validate", {}).get("score", 0) > 0.95) - - wf = Workflow(name="callable_results") - wf.step("validate", score_task) - wf.step("deploy", deploy, after="validate", condition=high_score) - - with workflow_worker(): - run = queue.submit_workflow(wf) - final = run.wait(timeout=20) - - assert final.nodes["deploy"].status == NodeStatus.COMPLETED - assert collected == ["deployed"] - - -def test_fail_fast_backward_compat(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: - """Phase 2 regression: fail_fast (default) cascades all pending nodes.""" - - @queue.task(max_retries=0) - def fail_task() -> str: - raise RuntimeError("boom") - - @queue.task() - def ok_task() -> str: - return "ok" - - wf = Workflow(name="fail_fast_compat") - wf.step("a", fail_task) - wf.step("b", ok_task, after="a") - wf.step("c", ok_task, after="b") - - with workflow_worker(): - run = queue.submit_workflow(wf) - final = run.wait(timeout=15) - - assert final.state == WorkflowState.FAILED - assert final.nodes["a"].status == NodeStatus.FAILED - assert final.nodes["b"].status == NodeStatus.SKIPPED - assert final.nodes["c"].status == NodeStatus.SKIPPED - - -def test_skip_propagation_respects_always( - queue: Queue, workflow_worker: WorkflowWorkerFactory -) -> None: - """A→B→C: A fails, B(on_success) skipped, C(always) still runs.""" - - @queue.task(max_retries=0) - def fail_task() -> str: - raise RuntimeError("boom") - - @queue.task() - def ok_task() -> str: - return "ok" - - collected: list[str] = [] - - @queue.task() - def always_task() -> str: - collected.append("always ran") - return "done" - - wf = Workflow(name="skip_propagation") - wf.step("a", fail_task) - wf.step("b", ok_task, after="a") - wf.step("c", always_task, after="b", condition="always") - - with workflow_worker(): - run = queue.submit_workflow(wf) - final = run.wait(timeout=20) - - assert final.nodes["a"].status == NodeStatus.FAILED - assert final.nodes["b"].status == NodeStatus.SKIPPED - assert final.nodes["c"].status == NodeStatus.COMPLETED - assert collected == ["always ran"] - - -def test_fan_out_with_on_failure_downstream( - queue: Queue, workflow_worker: WorkflowWorkerFactory -) -> None: - """Fan-out child fails, downstream on_failure step runs.""" - - @queue.task() - def source() -> list[int]: - return [1, 2] - - @queue.task(max_retries=0) - def process(x: int) -> int: - if x == 2: - raise RuntimeError("boom") - return x * 10 - - @queue.task() - def aggregate(results: list[int]) -> str: - return "agg" - - collected: list[str] = [] - - @queue.task() - def on_error() -> str: - collected.append("error handled") - return "handled" - - wf = Workflow(name="fan_out_on_failure") - wf.step("fetch", source) - wf.step("process", process, after="fetch", fan_out="each") - wf.step("collect", aggregate, after="process", fan_in="all") - wf.step("handle_error", on_error, after="process", condition="on_failure") - - with workflow_worker(): - run = queue.submit_workflow(wf) - final = run.wait(timeout=20) - - assert final.nodes["process"].status == NodeStatus.FAILED - assert final.nodes["collect"].status == NodeStatus.SKIPPED - assert final.nodes["handle_error"].status == NodeStatus.COMPLETED - assert collected == ["error handled"] diff --git a/tests/test_workflows_cron.py b/tests/test_workflows_cron.py deleted file mode 100644 index 0aef3f5..0000000 --- a/tests/test_workflows_cron.py +++ /dev/null @@ -1,30 +0,0 @@ -"""Tests for Phase 5D cron-scheduled workflows.""" - -from __future__ import annotations - -from taskito import Queue -from taskito.workflows import Workflow - - -def test_periodic_workflow_registers_launcher(queue: Queue) -> None: - """@queue.periodic + @queue.workflow registers a launcher task.""" - - @queue.task() - def extract() -> str: - return "data" - - @queue.periodic(cron="0 0 2 * * *") - @queue.workflow("nightly") - def nightly() -> Workflow: - wf = Workflow() - wf.step("extract", extract) - return wf - - # The launcher task should be registered - launcher_name = "_wf_launcher_nightly" - assert launcher_name in queue._task_registry - # The periodic config should reference the launcher - assert any(pc["task_name"] == launcher_name for pc in queue._periodic_configs) - # The workflow proxy should still be returned - assert hasattr(nightly, "submit") - assert hasattr(nightly, "build") diff --git a/tests/test_workflows_fan_out.py b/tests/test_workflows_fan_out.py deleted file mode 100644 index c7f92e6..0000000 --- a/tests/test_workflows_fan_out.py +++ /dev/null @@ -1,359 +0,0 @@ -"""Tests for Phase 3 fan-out / fan-in workflow execution.""" - -from __future__ import annotations - -import threading -import time -from collections.abc import Callable -from contextlib import AbstractContextManager -from typing import Any - -import pytest - -from taskito import Queue -from taskito.workflows import NodeStatus, Workflow, WorkflowState - -WorkflowWorkerFactory = Callable[[], AbstractContextManager[threading.Thread]] - - -def test_fan_out_each(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: - """fan_out='each' splits a list into N parallel jobs, fan_in='all' collects.""" - - @queue.task() - def source() -> list[int]: - return [10, 20, 30] - - @queue.task() - def double(x: int) -> int: - return x * 2 - - collected: list[Any] = [] - - @queue.task() - def aggregate(results: list[int]) -> str: - collected.extend(results) - return "done" - - wf = Workflow(name="fan_out_each") - wf.step("fetch", source) - wf.step("process", double, after="fetch", fan_out="each") - wf.step("collect", aggregate, after="process", fan_in="all") - - with workflow_worker(): - run = queue.submit_workflow(wf) - final = run.wait(timeout=20) - - assert final.state == WorkflowState.COMPLETED - assert sorted(collected) == [20, 40, 60] - - -def test_fan_out_empty_list(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: - """Fan-out over an empty list → fan-in receives [].""" - - @queue.task() - def source() -> list: - return [] - - @queue.task() - def process(x: int) -> int: - return x * 2 - - collected: list[Any] = [] - - @queue.task() - def aggregate(results: list) -> str: - collected.extend(results) - return "empty" - - wf = Workflow(name="fan_out_empty") - wf.step("fetch", source) - wf.step("process", process, after="fetch", fan_out="each") - wf.step("collect", aggregate, after="process", fan_in="all") - - with workflow_worker(): - run = queue.submit_workflow(wf) - final = run.wait(timeout=20) - - assert final.state == WorkflowState.COMPLETED - assert collected == [] - - -def test_fan_out_single_item(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: - """Fan-out with a single-element list → 1 child → fan-in gets [result].""" - - @queue.task() - def source() -> list[int]: - return [42] - - @queue.task() - def add_one(x: int) -> int: - return x + 1 - - collected: list[Any] = [] - - @queue.task() - def aggregate(results: list[int]) -> str: - collected.extend(results) - return "single" - - wf = Workflow(name="fan_out_single") - wf.step("fetch", source) - wf.step("process", add_one, after="fetch", fan_out="each") - wf.step("collect", aggregate, after="process", fan_in="all") - - with workflow_worker(): - run = queue.submit_workflow(wf) - final = run.wait(timeout=20) - - assert final.state == WorkflowState.COMPLETED - assert collected == [43] - - -def test_fan_out_with_downstream(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: - """Full pipeline: source → fan_out → fan_in → downstream static step.""" - - order: list[str] = [] - - @queue.task() - def source() -> list[str]: - order.append("source") - return ["a", "b"] - - @queue.task() - def process(item: str) -> str: - order.append(f"process:{item}") - return item.upper() - - @queue.task() - def aggregate(results: list[str]) -> str: - order.append("aggregate") - return ",".join(sorted(results)) - - @queue.task() - def report() -> str: - order.append("report") - return "finished" - - wf = Workflow(name="downstream_pipe") - wf.step("fetch", source) - wf.step("process", process, after="fetch", fan_out="each") - wf.step("agg", aggregate, after="process", fan_in="all") - wf.step("report", report, after="agg") - - with workflow_worker(): - run = queue.submit_workflow(wf) - final = run.wait(timeout=20) - - assert final.state == WorkflowState.COMPLETED - assert "source" in order - assert "aggregate" in order - assert "report" in order - # Source runs first, report runs last - assert order.index("source") < order.index("aggregate") - assert order.index("aggregate") < order.index("report") - - -def test_fan_out_child_failure(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: - """A failing fan-out child triggers fail-fast, workflow fails.""" - - @queue.task() - def source() -> list[int]: - return [1, 2, 3] - - @queue.task(max_retries=0) - def maybe_fail(x: int) -> int: - if x == 2: - raise RuntimeError("boom on 2") - return x * 10 - - @queue.task() - def aggregate(results: list[int]) -> str: - return "should not run" - - wf = Workflow(name="fan_out_fail") - wf.step("fetch", source) - wf.step("process", maybe_fail, after="fetch", fan_out="each") - wf.step("collect", aggregate, after="process", fan_in="all") - - with workflow_worker(): - run = queue.submit_workflow(wf) - final = run.wait(timeout=20) - - assert final.state == WorkflowState.FAILED - # The aggregate node should be skipped - assert final.nodes["collect"].status == NodeStatus.SKIPPED - - -def test_fan_out_source_failure(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: - """If the fan-out source step fails, all deferred nodes are SKIPPED.""" - - @queue.task(max_retries=0) - def bad_source() -> list[int]: - raise RuntimeError("source failed") - - @queue.task() - def process(x: int) -> int: - return x * 2 - - @queue.task() - def aggregate(results: list[int]) -> str: - return "nope" - - wf = Workflow(name="source_fail") - wf.step("fetch", bad_source) - wf.step("process", process, after="fetch", fan_out="each") - wf.step("collect", aggregate, after="process", fan_in="all") - - with workflow_worker(): - run = queue.submit_workflow(wf) - final = run.wait(timeout=20) - - assert final.state == WorkflowState.FAILED - assert final.nodes["fetch"].status == NodeStatus.FAILED - assert final.nodes["process"].status == NodeStatus.SKIPPED - assert final.nodes["collect"].status == NodeStatus.SKIPPED - - -def test_fan_out_cancellation(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: - """Cancelling a workflow mid-fan-out skips pending children.""" - - @queue.task() - def source() -> list[int]: - return [1, 2, 3] - - @queue.task() - def slow_process(x: int) -> int: - time.sleep(10) # will be cancelled - return x - - @queue.task() - def aggregate(results: list[int]) -> str: - return "nope" - - wf = Workflow(name="cancel_fan_out") - wf.step("fetch", source) - wf.step("process", slow_process, after="fetch", fan_out="each") - wf.step("collect", aggregate, after="process", fan_in="all") - - with workflow_worker(): - run = queue.submit_workflow(wf) - # Wait for source to complete and fan-out to expand - time.sleep(2) - run.cancel() - snapshot = run.status() - - assert snapshot.state == WorkflowState.CANCELLED - - -def test_fan_out_status_shows_children( - queue: Queue, workflow_worker: WorkflowWorkerFactory -) -> None: - """status() returns child node snapshots like process[0], process[1].""" - - @queue.task() - def source() -> list[int]: - return [1, 2, 3] - - @queue.task() - def process(x: int) -> int: - return x - - @queue.task() - def aggregate(results: list[int]) -> str: - return "done" - - wf = Workflow(name="show_children") - wf.step("fetch", source) - wf.step("process", process, after="fetch", fan_out="each") - wf.step("collect", aggregate, after="process", fan_in="all") - - with workflow_worker(): - run = queue.submit_workflow(wf) - final = run.wait(timeout=20) - - assert final.state == WorkflowState.COMPLETED - # Children should appear in the node map - assert "process[0]" in final.nodes - assert "process[1]" in final.nodes - assert "process[2]" in final.nodes - for i in range(3): - assert final.nodes[f"process[{i}]"].status == NodeStatus.COMPLETED - assert final.nodes[f"process[{i}]"].job_id is not None - - -def test_fan_out_preserves_result_order( - queue: Queue, workflow_worker: WorkflowWorkerFactory -) -> None: - """Fan-in results maintain the order of child indices.""" - - @queue.task() - def source() -> list[str]: - return ["x", "y", "z"] - - @queue.task() - def identity(item: str) -> str: - return item - - collected: list[Any] = [] - - @queue.task() - def aggregate(results: list[str]) -> str: - collected.extend(results) - return "ok" - - wf = Workflow(name="order_check") - wf.step("fetch", source) - wf.step("process", identity, after="fetch", fan_out="each") - wf.step("collect", aggregate, after="process", fan_in="all") - - with workflow_worker(): - run = queue.submit_workflow(wf) - final = run.wait(timeout=20) - - assert final.state == WorkflowState.COMPLETED - # Results should be in child index order (same as input order) - assert collected == ["x", "y", "z"] - - -def test_step_name_bracket_validation() -> None: - """Step names containing '[' raise ValueError.""" - wf = Workflow(name="bad_name") - - class _FakeTask: - _task_name = "fake" - - with pytest.raises(ValueError, match="must not contain"): - wf.step("bad[0]", _FakeTask()) - - -def test_linear_workflow_still_works(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: - """Phase 2 regression: a linear workflow without fan-out still works.""" - - order: list[str] = [] - - @queue.task() - def step_a() -> str: - order.append("a") - return "a" - - @queue.task() - def step_b() -> str: - order.append("b") - return "b" - - @queue.task() - def step_c() -> str: - order.append("c") - return "c" - - wf = Workflow(name="linear_regression") - wf.step("a", step_a) - wf.step("b", step_b, after="a") - wf.step("c", step_c, after="b") - - with workflow_worker(): - run = queue.submit_workflow(wf) - final = run.wait(timeout=15) - - assert final.state == WorkflowState.COMPLETED - assert order == ["a", "b", "c"] diff --git a/tests/test_workflows_gates.py b/tests/test_workflows_gates.py deleted file mode 100644 index bcbf41b..0000000 --- a/tests/test_workflows_gates.py +++ /dev/null @@ -1,158 +0,0 @@ -"""Tests for Phase 5A approval gates.""" - -from __future__ import annotations - -import threading -import time -from collections.abc import Callable -from contextlib import AbstractContextManager - -from taskito import Queue -from taskito.workflows import NodeStatus, Workflow, WorkflowState - -WorkflowWorkerFactory = Callable[[], AbstractContextManager[threading.Thread]] - - -def test_gate_pauses_workflow(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: - """A gate node enters WAITING_APPROVAL and blocks downstream.""" - - @queue.task() - def ok_task() -> str: - return "ok" - - wf = Workflow(name="gate_pause") - wf.step("a", ok_task) - wf.gate("approve", after="a") - wf.step("b", ok_task, after="approve") - - with workflow_worker(): - run = queue.submit_workflow(wf) - time.sleep(3) # Let "a" complete - snapshot = run.status() - - assert snapshot.state == WorkflowState.RUNNING - assert snapshot.nodes["a"].status == NodeStatus.COMPLETED - assert snapshot.nodes["approve"].status == NodeStatus.WAITING_APPROVAL - assert snapshot.nodes["b"].status == NodeStatus.PENDING - - -def test_approve_gate_resumes(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: - """Approving a gate lets downstream steps run to completion.""" - - collected: list[str] = [] - - @queue.task() - def ok_task() -> str: - collected.append("ran") - return "ok" - - wf = Workflow(name="gate_approve") - wf.step("a", ok_task) - wf.gate("approve", after="a") - wf.step("b", ok_task, after="approve") - - with workflow_worker(): - run = queue.submit_workflow(wf) - time.sleep(2) # Let "a" complete and gate enter WAITING_APPROVAL - queue.approve_gate(run.id, "approve") - final = run.wait(timeout=15) - - assert final.state == WorkflowState.COMPLETED - assert final.nodes["approve"].status == NodeStatus.COMPLETED - assert final.nodes["b"].status == NodeStatus.COMPLETED - assert len(collected) == 2 # "a" and "b" - - -def test_reject_gate_fails(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: - """Rejecting a gate fails it and skips downstream.""" - - @queue.task() - def ok_task() -> str: - return "ok" - - wf = Workflow(name="gate_reject") - wf.step("a", ok_task) - wf.gate("approve", after="a") - wf.step("b", ok_task, after="approve") - - with workflow_worker(): - run = queue.submit_workflow(wf) - time.sleep(2) - queue.reject_gate(run.id, "approve", error="not approved") - final = run.wait(timeout=15) - - assert final.state == WorkflowState.FAILED - assert final.nodes["approve"].status == NodeStatus.FAILED - assert final.nodes["b"].status == NodeStatus.SKIPPED - - -def test_gate_timeout_reject(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: - """Gate with timeout and on_timeout='reject' auto-rejects.""" - - @queue.task() - def ok_task() -> str: - return "ok" - - wf = Workflow(name="gate_timeout_reject") - wf.step("a", ok_task) - wf.gate("approve", after="a", timeout=1.0, on_timeout="reject") - wf.step("b", ok_task, after="approve") - - with workflow_worker(): - run = queue.submit_workflow(wf) - final = run.wait(timeout=15) - - assert final.state == WorkflowState.FAILED - assert final.nodes["approve"].status == NodeStatus.FAILED - assert final.nodes["approve"].error is not None - assert "timeout" in (final.nodes["approve"].error or "").lower() - - -def test_gate_timeout_approve(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: - """Gate with on_timeout='approve' auto-approves and continues.""" - - collected: list[str] = [] - - @queue.task() - def ok_task() -> str: - collected.append("ran") - return "ok" - - wf = Workflow(name="gate_timeout_approve") - wf.step("a", ok_task) - wf.gate("approve", after="a", timeout=1.0, on_timeout="approve") - wf.step("b", ok_task, after="approve") - - with workflow_worker(): - run = queue.submit_workflow(wf) - final = run.wait(timeout=15) - - assert final.state == WorkflowState.COMPLETED - assert final.nodes["approve"].status == NodeStatus.COMPLETED - assert final.nodes["b"].status == NodeStatus.COMPLETED - - -def test_gate_with_condition(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: - """A gate with condition='on_success' respects predecessor state.""" - - @queue.task(max_retries=0) - def fail_task() -> str: - raise RuntimeError("fail") - - @queue.task() - def ok_task() -> str: - return "ok" - - wf = Workflow(name="gate_condition") - wf.step("a", fail_task) - wf.gate("approve", after="a", condition="on_success") - wf.step("b", ok_task, after="approve") - - with workflow_worker(): - run = queue.submit_workflow(wf) - final = run.wait(timeout=15) - - assert final.state == WorkflowState.FAILED - # Gate should be skipped because predecessor failed (condition=on_success) - assert final.nodes["approve"].status == NodeStatus.SKIPPED - assert final.nodes["b"].status == NodeStatus.SKIPPED diff --git a/tests/test_workflows_linear.py b/tests/test_workflows_linear.py deleted file mode 100644 index 27716a0..0000000 --- a/tests/test_workflows_linear.py +++ /dev/null @@ -1,348 +0,0 @@ -"""Tests for Phase 2 linear workflow execution.""" - -from __future__ import annotations - -import threading -import time -from collections.abc import Callable -from contextlib import AbstractContextManager - -import pytest - -from taskito import Queue -from taskito.workflows import NodeStatus, Workflow, WorkflowState -from taskito.workflows.run import WorkflowTimeoutError - -WorkflowWorkerFactory = Callable[[], AbstractContextManager[threading.Thread]] - - -def test_linear_three_step_workflow(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: - """A→B→C runs in order and the workflow reaches COMPLETED.""" - - order: list[str] = [] - - @queue.task() - def step_a() -> str: - order.append("a") - return "a-done" - - @queue.task() - def step_b() -> str: - order.append("b") - return "b-done" - - @queue.task() - def step_c() -> str: - order.append("c") - return "c-done" - - wf = Workflow(name="linear_pipe") - wf.step("a", step_a) - wf.step("b", step_b, after="a") - wf.step("c", step_c, after="b") - - with workflow_worker(): - run = queue.submit_workflow(wf) - final = run.wait(timeout=15) - - assert final.state == WorkflowState.COMPLETED - assert order == ["a", "b", "c"] - assert all(n.status == NodeStatus.COMPLETED for n in final.nodes.values()) - assert set(final.nodes.keys()) == {"a", "b", "c"} - - -def test_workflow_with_args_and_kwargs( - queue: Queue, workflow_worker: WorkflowWorkerFactory -) -> None: - """Step args and kwargs round-trip through the queue serializer.""" - - received: list[tuple] = [] - - @queue.task() - def collect(x: int, y: int, *, label: str) -> int: - received.append((x, y, label)) - return x + y - - wf = Workflow(name="args_pipe") - wf.step("first", collect, args=(2, 3), kwargs={"label": "a"}) - wf.step("second", collect, args=(10, 20), kwargs={"label": "b"}, after="first") - - with workflow_worker(): - run = queue.submit_workflow(wf) - final = run.wait(timeout=15) - - assert final.state == WorkflowState.COMPLETED - assert (2, 3, "a") in received - assert (10, 20, "b") in received - - -def test_workflow_decorator_registration( - queue: Queue, workflow_worker: WorkflowWorkerFactory -) -> None: - """@queue.workflow() stores a proxy that can build and submit.""" - - @queue.task() - def noop() -> None: - return None - - @queue.workflow("nightly") - def build() -> Workflow: - wf = Workflow() - wf.step("x", noop) - return wf - - assert "nightly" in queue._workflow_registry - built = build.build() - assert built.name == "nightly" - assert built.step_names == ["x"] - - with workflow_worker(): - run = build.submit() - final = run.wait(timeout=10) - - assert final.state == WorkflowState.COMPLETED - - -def test_workflow_status_before_completion(queue: Queue) -> None: - """status() reflects non-terminal state before the workflow finishes.""" - - @queue.task() - def noop() -> None: - return None - - wf = Workflow(name="status_check") - wf.step("only", noop) - - run = queue.submit_workflow(wf) - snapshot = run.status() - # No worker running, so the run stays in RUNNING with the node PENDING - assert snapshot.state == WorkflowState.RUNNING - assert snapshot.nodes["only"].status == NodeStatus.PENDING - assert snapshot.nodes["only"].job_id is not None - - -def test_workflow_wait_timeout(queue: Queue) -> None: - """wait() raises WorkflowTimeoutError if the workflow doesn't finish in time.""" - - @queue.task() - def noop() -> None: - return None - - wf = Workflow(name="timeout_test") - wf.step("only", noop) - - run = queue.submit_workflow(wf) - # No worker running → timeout - with pytest.raises(WorkflowTimeoutError): - run.wait(timeout=0.3) - - -def test_workflow_cancellation(queue: Queue) -> None: - """Cancelling a workflow marks pending nodes SKIPPED and the run CANCELLED.""" - - @queue.task() - def noop() -> None: - return None - - wf = Workflow(name="cancel_test") - wf.step("a", noop) - wf.step("b", noop, after="a") - wf.step("c", noop, after="b") - - run = queue.submit_workflow(wf) - run.cancel() - - snapshot = run.status() - assert snapshot.state == WorkflowState.CANCELLED - for node in snapshot.nodes.values(): - assert node.status == NodeStatus.SKIPPED - - -def test_workflow_failing_step(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: - """A failing step fails the workflow and skips downstream steps.""" - - @queue.task(max_retries=0) - def good() -> str: - return "ok" - - @queue.task(max_retries=0) - def boom() -> str: - raise RuntimeError("kaboom") - - wf = Workflow(name="failing") - wf.step("a", good) - wf.step("b", boom, after="a") - wf.step("c", good, after="b") - - with workflow_worker(): - run = queue.submit_workflow(wf) - final = run.wait(timeout=15) - - assert final.state == WorkflowState.FAILED - assert final.nodes["a"].status == NodeStatus.COMPLETED - assert final.nodes["b"].status == NodeStatus.FAILED - assert final.nodes["c"].status == NodeStatus.SKIPPED - - -def test_workflow_node_snapshot_fields(queue: Queue) -> None: - """Each node snapshot has name, status, job_id, error.""" - - @queue.task() - def noop() -> None: - return None - - wf = Workflow(name="snapshot_check") - wf.step("one", noop) - wf.step("two", noop, after="one") - - run = queue.submit_workflow(wf) - snapshot = run.status() - assert "one" in snapshot.nodes - assert "two" in snapshot.nodes - for name, node in snapshot.nodes.items(): - assert node.name == name - assert node.status == NodeStatus.PENDING - assert node.job_id is not None - assert node.error is None - - -def test_workflow_node_status_helper(queue: Queue) -> None: - """node_status() returns the status of a specific node.""" - - @queue.task() - def noop() -> None: - return None - - wf = Workflow(name="helper_check") - wf.step("one", noop) - - run = queue.submit_workflow(wf) - assert run.node_status("one") == NodeStatus.PENDING - - with pytest.raises(KeyError): - run.node_status("nonexistent") - - -def test_workflow_step_ordering_validation() -> None: - """step() raises ValueError if after references an unknown predecessor.""" - wf = Workflow(name="invalid") - - class _FakeTask: - _task_name = "fake" - - wf.step("a", _FakeTask()) - with pytest.raises(ValueError, match="predecessor 'missing'"): - wf.step("b", _FakeTask(), after="missing") - - -def test_workflow_duplicate_step_name() -> None: - """step() raises ValueError if the same name is added twice.""" - wf = Workflow(name="dup") - - class _FakeTask: - _task_name = "fake" - - wf.step("a", _FakeTask()) - with pytest.raises(ValueError, match="already defined"): - wf.step("a", _FakeTask()) - - -def test_workflow_definition_reuse(queue: Queue) -> None: - """Submitting the same workflow name+version reuses the definition row.""" - - @queue.task() - def noop() -> None: - return None - - wf1 = Workflow(name="reused", version=1) - wf1.step("only", noop) - wf2 = Workflow(name="reused", version=1) - wf2.step("only", noop) - - run1 = queue.submit_workflow(wf1) - run2 = queue.submit_workflow(wf2) - assert run1.id != run2.id - # Definitions share an ID via name+version uniqueness - # (verified indirectly: second submit succeeds without duplicate-key errors) - - -def test_workflow_emits_completed_event( - queue: Queue, workflow_worker: WorkflowWorkerFactory -) -> None: - """WORKFLOW_COMPLETED event fires on successful run completion.""" - from taskito.events import EventType - - @queue.task() - def noop() -> None: - return None - - events: list[dict] = [] - event_received = threading.Event() - - def listener(_event_type: EventType, payload: dict) -> None: - events.append(payload) - event_received.set() - - queue._event_bus.on(EventType.WORKFLOW_COMPLETED, listener) - - wf = Workflow(name="event_pipe") - wf.step("x", noop) - - with workflow_worker(): - run = queue.submit_workflow(wf) - run.wait(timeout=10) - # Give event bus thread a moment to dispatch - event_received.wait(timeout=5) - - assert any(e.get("run_id") == run.id for e in events) - assert any(e.get("state") == "completed" for e in events) - - -def test_workflow_run_repr(queue: Queue) -> None: - """WorkflowRun __repr__ is informative.""" - - @queue.task() - def noop() -> None: - return None - - wf = Workflow(name="repr_test") - wf.step("x", noop) - - run = queue.submit_workflow(wf) - r = repr(run) - assert run.id in r - assert "repr_test" in r - - -@pytest.mark.asyncio -async def test_workflow_async_step(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: - """An async @queue.task() step works inside a workflow.""" - - @queue.task() - async def async_step() -> str: - return "async-ok" - - @queue.task() - def sync_step() -> str: - return "sync-ok" - - wf = Workflow(name="async_mix") - wf.step("a", async_step) - wf.step("b", sync_step, after="a") - - with workflow_worker(): - run = queue.submit_workflow(wf) - # Poll instead of blocking wait to play nicely with asyncio - deadline = time.monotonic() + 15 - final = run.status() - while not final.state.is_terminal() and time.monotonic() < deadline: - await _async_sleep(0.1) - final = run.status() - - assert final.state == WorkflowState.COMPLETED - - -async def _async_sleep(seconds: float) -> None: - import asyncio - - await asyncio.sleep(seconds) diff --git a/tests/test_workflows_subworkflow.py b/tests/test_workflows_subworkflow.py deleted file mode 100644 index 7629a38..0000000 --- a/tests/test_workflows_subworkflow.py +++ /dev/null @@ -1,194 +0,0 @@ -"""Tests for Phase 5B sub-workflows.""" - -from __future__ import annotations - -import threading -from collections.abc import Callable -from contextlib import AbstractContextManager - -from taskito import Queue -from taskito.workflows import NodeStatus, Workflow, WorkflowState - -WorkflowWorkerFactory = Callable[[], AbstractContextManager[threading.Thread]] - - -def test_sub_workflow_executes(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: - """A sub-workflow step runs the child workflow, parent continues after.""" - - order: list[str] = [] - - @queue.task() - def extract() -> str: - order.append("extract") - return "data" - - @queue.task() - def load() -> str: - order.append("load") - return "loaded" - - @queue.task() - def report() -> str: - order.append("report") - return "done" - - @queue.workflow("etl") - def etl_pipeline() -> Workflow: - wf = Workflow() - wf.step("extract", extract) - wf.step("load", load, after="extract") - return wf - - wf = Workflow(name="parent") - wf.step("etl", etl_pipeline.as_step()) - wf.step("report", report, after="etl") - - with workflow_worker(): - run = queue.submit_workflow(wf) - final = run.wait(timeout=20) - - assert final.state == WorkflowState.COMPLETED - assert "extract" in order - assert "load" in order - assert "report" in order - assert order.index("load") < order.index("report") - - -def test_sub_workflow_failure(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: - """A failing sub-workflow fails the parent node.""" - - @queue.task(max_retries=0) - def fail_task() -> str: - raise RuntimeError("sub failed") - - @queue.task() - def ok_task() -> str: - return "ok" - - @queue.workflow("failing_sub") - def failing_sub() -> Workflow: - wf = Workflow() - wf.step("boom", fail_task) - return wf - - wf = Workflow(name="parent_fail") - wf.step("sub", failing_sub.as_step()) - wf.step("after", ok_task, after="sub") - - with workflow_worker(): - run = queue.submit_workflow(wf) - final = run.wait(timeout=20) - - assert final.state == WorkflowState.FAILED - assert final.nodes["sub"].status == NodeStatus.FAILED - assert final.nodes["after"].status == NodeStatus.SKIPPED - - -def test_sub_workflow_compile_failure_marks_parent_failed( - queue: Queue, workflow_worker: WorkflowWorkerFactory -) -> None: - """Regression: a factory that raises during `build()` must not leave the - parent node Skipped forever — it must be marked Failed so the outer run - can finalize. Before the fix, the tracker called `skip_workflow_node` - on the parent before attempting compile, and a compile failure left the - node Skipped permanently.""" - - @queue.task() - def downstream() -> str: - return "should not run" - - @queue.workflow("broken_sub") - def broken_sub() -> Workflow: - raise RuntimeError("factory blew up") - - wf = Workflow(name="parent_compile_fail") - wf.step("sub", broken_sub.as_step()) - wf.step("after", downstream, after="sub") - - with workflow_worker(): - run = queue.submit_workflow(wf) - final = run.wait(timeout=15) - - assert final.state == WorkflowState.FAILED, ( - f"outer run must finalize as FAILED, got {final.state}" - ) - assert final.nodes["sub"].status == NodeStatus.FAILED, ( - f"sub-workflow parent must be FAILED (was {final.nodes['sub'].status}) — " - "the old bug left it SKIPPED" - ) - assert final.nodes["after"].status == NodeStatus.SKIPPED - - -def test_cancel_parent_cascades(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: - """Cancelling a parent workflow cancels the child sub-workflow too.""" - - import time - - @queue.task() - def slow_task() -> str: - time.sleep(30) - return "slow" - - @queue.workflow("slow_sub") - def slow_sub() -> Workflow: - wf = Workflow() - wf.step("slow", slow_task) - return wf - - wf = Workflow(name="parent_cancel") - wf.step("sub", slow_sub.as_step()) - - with workflow_worker(): - run = queue.submit_workflow(wf) - time.sleep(2) # Let sub-workflow submit - run.cancel() - snapshot = run.status() - - assert snapshot.state == WorkflowState.CANCELLED - - -def test_parallel_sub_workflows(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: - """Two sub-workflows can run concurrently.""" - - order: list[str] = [] - - @queue.task() - def task_a() -> str: - order.append("a") - return "a" - - @queue.task() - def task_b() -> str: - order.append("b") - return "b" - - @queue.task() - def reconcile() -> str: - order.append("reconcile") - return "done" - - @queue.workflow("sub_a") - def sub_a() -> Workflow: - wf = Workflow() - wf.step("a", task_a) - return wf - - @queue.workflow("sub_b") - def sub_b() -> Workflow: - wf = Workflow() - wf.step("b", task_b) - return wf - - wf = Workflow(name="parallel_parent") - wf.step("sa", sub_a.as_step()) - wf.step("sb", sub_b.as_step()) - wf.step("reconcile", reconcile, after=["sa", "sb"]) - - with workflow_worker(): - run = queue.submit_workflow(wf) - final = run.wait(timeout=20) - - assert final.state == WorkflowState.COMPLETED - assert "a" in order - assert "b" in order - assert "reconcile" in order diff --git a/tests/test_workflows_visualization.py b/tests/test_workflows_visualization.py deleted file mode 100644 index ebce1c3..0000000 --- a/tests/test_workflows_visualization.py +++ /dev/null @@ -1,96 +0,0 @@ -"""Tests for Phase 8 workflow visualization.""" - -from __future__ import annotations - -import threading -from collections.abc import Callable -from contextlib import AbstractContextManager - -from taskito import Queue -from taskito.workflows import Workflow - -WorkflowWorkerFactory = Callable[[], AbstractContextManager[threading.Thread]] - - -class _FakeTask: - _task_name = "fake" - - -def test_mermaid_linear() -> None: - """Linear DAG renders correct Mermaid graph.""" - wf = Workflow(name="linear") - wf.step("a", _FakeTask()) - wf.step("b", _FakeTask(), after="a") - wf.step("c", _FakeTask(), after="b") - - output = wf.visualize("mermaid") - assert "graph LR" in output - assert "a[a]" in output - assert "b[b]" in output - assert "c[c]" in output - assert "a --> b" in output - assert "b --> c" in output - - -def test_mermaid_diamond() -> None: - """Diamond DAG with parallel nodes.""" - wf = Workflow(name="diamond") - wf.step("a", _FakeTask()) - wf.step("b", _FakeTask(), after="a") - wf.step("c", _FakeTask(), after="a") - wf.step("d", _FakeTask(), after=["b", "c"]) - - output = wf.visualize("mermaid") - assert "a --> b" in output - assert "a --> c" in output - assert "b --> d" in output - assert "c --> d" in output - - -def test_mermaid_with_status() -> None: - """Mermaid output with status colors.""" - from taskito.workflows.visualization import render_mermaid - - output = render_mermaid( - nodes=["a", "b", "c"], - edges=[("a", "b"), ("b", "c")], - statuses={"a": "completed", "b": "failed", "c": "pending"}, - ) - assert "style a fill:#90EE90" in output # green - assert "style b fill:#FFB6C1" in output # red - assert "style c fill:#D3D3D3" in output # gray - - -def test_dot_linear() -> None: - """DOT format output for linear DAG.""" - wf = Workflow(name="linear") - wf.step("a", _FakeTask()) - wf.step("b", _FakeTask(), after="a") - - output = wf.visualize("dot") - assert "digraph workflow" in output - assert "rankdir=LR" in output - assert "a -> b" in output - - -def test_visualize_live_run(queue: Queue, workflow_worker: WorkflowWorkerFactory) -> None: - """WorkflowRun.visualize() shows live statuses.""" - - @queue.task() - def ok_task() -> str: - return "ok" - - wf = Workflow(name="viz_live") - wf.step("a", ok_task) - wf.step("b", ok_task, after="a") - - with workflow_worker(): - run = queue.submit_workflow(wf) - final = run.wait(timeout=15) - output = run.visualize("mermaid") - - assert final.state.value == "completed" - assert "graph LR" in output - assert "a --> b" in output - # Both nodes should have completed status styling - assert "#90EE90" in output # green for completed