Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 36 additions & 0 deletions tests/python/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import sys
import threading
import time
from collections.abc import Callable, Generator
from contextlib import AbstractContextManager, contextmanager
from pathlib import Path
Expand All @@ -17,6 +18,41 @@
# definition in one place.
WorkflowWorkerFactory = Callable[[], AbstractContextManager[threading.Thread]]

PollUntil = Callable[..., None]


@pytest.fixture
def poll_until() -> PollUntil:
    """Poll a predicate until it returns truthy, or fail on timeout.

    Replaces ``time.sleep(N)`` followed by an assertion in tests that wait
    for a background event (event-bus dispatch, webhook delivery, async
    executor completion). Polling shortens the typical wait while keeping
    a hard timeout for slow CI runners.

    Usage::

        poll_until(lambda: len(received) == 1)
        poll_until(lambda: counts["a"] == 1, timeout=10, message="callback never fired")
    """

    def _poll_until(
        predicate: Callable[[], bool],
        *,
        timeout: float = 5.0,
        interval: float = 0.05,
        message: str = "predicate did not become true",
    ) -> None:
        # Monotonic clock: immune to wall-clock adjustments during the wait.
        deadline = time.monotonic() + timeout
        while True:
            # Check first, so the predicate gets one final chance after the
            # last sleep (and at least one chance even when timeout <= 0).
            if predicate():
                return
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                break
            # Clamp the nap so we never oversleep past the deadline by a
            # full interval on the last iteration.
            time.sleep(min(interval, remaining))
        raise AssertionError(f"{message} (timeout {timeout}s)")

    return _poll_until


@pytest.fixture
def queue(tmp_path: Path) -> Queue:
Expand Down
5 changes: 2 additions & 3 deletions tests/python/test_customizability.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

from __future__ import annotations

import time
from typing import Any
from unittest.mock import MagicMock

Expand Down Expand Up @@ -93,7 +92,7 @@ def test_event_workers_param(self, tmp_path: Any) -> None:
q = Queue(db_path=str(tmp_path / "test.db"), event_workers=2)
assert q._event_bus._executor._max_workers == 2

def test_on_event_public_api(self, tmp_path: Any) -> None:
def test_on_event_public_api(self, tmp_path: Any, poll_until: Any) -> None:
q = Queue(db_path=str(tmp_path / "test.db"))
received: list[Any] = []

Expand All @@ -107,7 +106,7 @@ def my_task() -> None:
pass

my_task.delay()
time.sleep(0.2)
poll_until(lambda: len(received) >= 1, message="JOB_ENQUEUED event not delivered")
assert len(received) == 1
assert received[0][0] == EventType.JOB_ENQUEUED

Expand Down
17 changes: 10 additions & 7 deletions tests/python/test_dlq.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
"""Tests for dead letter queue management."""

import threading
import time
from typing import Any

from taskito import Queue

PollUntil = Any # the conftest fixture's runtime type


def test_dead_letters_empty(queue: Queue) -> None:
    """A fresh queue reports no dead-letter entries."""
    assert queue.dead_letters() == []


def test_purge_dead(queue: Queue) -> None:
def test_purge_dead(queue: Queue, poll_until: PollUntil) -> None:
"""Purging removes old dead letter entries."""

@queue.task(max_retries=0, retry_backoff=0.1)
Expand All @@ -27,11 +29,12 @@ def instant_fail() -> None:
)
worker_thread.start()

for _ in range(20):
time.sleep(0.5)
dead = queue.dead_letters()
if len(dead) >= 1:
break
poll_until(
lambda: len(queue.dead_letters()) >= 1,
timeout=10,
message="failed job did not reach DLQ",
)
dead = queue.dead_letters()
assert len(dead) >= 1

# Purge entries older than 0 seconds (purge everything)
Expand Down
20 changes: 13 additions & 7 deletions tests/python/test_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,35 @@

from taskito.events import EventBus, EventType

PollUntil = Any # the conftest fixture's runtime type

def test_callback_receives_event() -> None:

def test_callback_receives_event(poll_until: PollUntil) -> None:
    """Registered callbacks receive emitted events."""
    events: list[tuple[EventType, dict[str, Any]]] = []
    bus = EventBus()

    def _capture(et: EventType, payload: dict[str, Any]) -> None:
        events.append((et, payload))

    bus.on(EventType.JOB_COMPLETED, _capture)
    bus.emit(EventType.JOB_COMPLETED, {"job_id": "123"})

    # Dispatch is asynchronous; wait until the callback has landed.
    poll_until(lambda: len(events) >= 1, message="event was not delivered")

    assert len(events) == 1
    event_type, payload = events[0]
    assert event_type == EventType.JOB_COMPLETED
    assert payload["job_id"] == "123"


def test_multiple_callbacks() -> None:
def test_multiple_callbacks(poll_until: PollUntil) -> None:
    """Multiple callbacks for the same event type all fire."""
    counts = {"a": 0, "b": 0}
    bus = EventBus()

    # A named closure factory is clearer than smuggling a statement into a
    # lambda via ``counts.__setitem__(...)``; behavior is identical.
    def _incrementer(key: str):
        def _on_event(et: EventType, payload: dict[str, Any]) -> None:
            counts[key] += 1

        return _on_event

    bus.on(EventType.JOB_FAILED, _incrementer("a"))
    bus.on(EventType.JOB_FAILED, _incrementer("b"))

    bus.emit(EventType.JOB_FAILED, {"error": "boom"})
    # Dispatch is asynchronous; wait until both callbacks have run.
    poll_until(
        lambda: counts["a"] == 1 and counts["b"] == 1,
        message="not all callbacks fired",
    )

    assert counts["a"] == 1
    assert counts["b"] == 1
Expand All @@ -41,12 +46,13 @@ def test_event_filtering() -> None:
bus.on(EventType.JOB_COMPLETED, lambda et, p: received.append("completed"))

bus.emit(EventType.JOB_FAILED, {"error": "boom"})
time.sleep(0.5)
# Brief settle so a (would-be incorrect) cross-type dispatch could land.
time.sleep(0.2)

assert received == []


def test_exception_in_callback_does_not_crash() -> None:
def test_exception_in_callback_does_not_crash(poll_until: PollUntil) -> None:
"""A raising callback doesn't prevent other events from processing."""
results: list[str] = []
bus = EventBus()
Expand All @@ -61,7 +67,7 @@ def good_callback(et: EventType, p: dict[str, Any]) -> None:
bus.on(EventType.JOB_ENQUEUED, good_callback)

bus.emit(EventType.JOB_ENQUEUED, {})
time.sleep(0.5)
poll_until(lambda: results == ["ok"], message="good_callback did not run")

assert results == ["ok"]

Expand Down
39 changes: 22 additions & 17 deletions tests/python/test_native_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import asyncio
import threading
import time
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock
Expand All @@ -17,6 +16,8 @@
)
from taskito.middleware import TaskMiddleware

PollUntil = Any # the conftest fixture's runtime type

# ── Async detection ──────────────────────────────────────────────


Expand Down Expand Up @@ -132,7 +133,7 @@ def test_async_executor_lifecycle() -> None:
executor.stop()


def test_async_executor_submit_and_execute() -> None:
def test_async_executor_submit_and_execute(poll_until: PollUntil) -> None:
"""Basic async task produces correct result via executor."""
import cloudpickle

Expand Down Expand Up @@ -164,7 +165,7 @@ class FakeWrapper:

payload = cloudpickle.dumps(((2, 3), {}))
executor.submit_job("job-1", "test_mod.my_task", payload, 0, 3, "default")
time.sleep(0.5)
poll_until(lambda: sender.report_success.called, message="job-1 result not reported")
executor.stop()

sender.report_success.assert_called_once()
Expand All @@ -175,7 +176,7 @@ class FakeWrapper:
assert result == 5


def test_async_exception_reported() -> None:
def test_async_exception_reported(poll_until: PollUntil) -> None:
"""Exception in async task → failure result with traceback."""
import cloudpickle

Expand Down Expand Up @@ -206,7 +207,7 @@ class FakeWrapper:

payload = cloudpickle.dumps(((), {}))
executor.submit_job("job-2", "mod.failing_task", payload, 0, 3, "default")
time.sleep(0.5)
poll_until(lambda: sender.report_failure.called, message="job-2 failure not reported")
executor.stop()

sender.report_failure.assert_called_once()
Expand All @@ -216,7 +217,7 @@ class FakeWrapper:
assert call_args[0][6] is True # should_retry


def test_async_cancellation() -> None:
def test_async_cancellation(poll_until: PollUntil) -> None:
"""TaskCancelledError → cancelled result."""
import cloudpickle

Expand Down Expand Up @@ -247,14 +248,14 @@ class FakeWrapper:

payload = cloudpickle.dumps(((), {}))
executor.submit_job("job-3", "mod.cancelling_task", payload, 0, 3, "default")
time.sleep(0.5)
poll_until(lambda: sender.report_cancelled.called, message="job-3 cancellation not reported")
executor.stop()

sender.report_cancelled.assert_called_once()
assert sender.report_cancelled.call_args[0][0] == "job-3"


def test_async_retry_filter() -> None:
def test_async_retry_filter(poll_until: PollUntil) -> None:
"""Failed async task respects retry_on filter."""
import cloudpickle

Expand Down Expand Up @@ -288,14 +289,14 @@ class FakeWrapper:

payload = cloudpickle.dumps(((), {}))
executor.submit_job("job-4", "mod.flaky_task", payload, 0, 3, "default")
time.sleep(0.5)
poll_until(lambda: sender.report_failure.called, message="job-4 failure not reported")
executor.stop()

sender.report_failure.assert_called_once()
assert sender.report_failure.call_args[0][6] is False # should_retry = False


def test_async_concurrency_limit() -> None:
def test_async_concurrency_limit(poll_until: PollUntil) -> None:
"""Semaphore bounds concurrent async tasks."""
import cloudpickle

Expand Down Expand Up @@ -338,14 +339,18 @@ class FakeWrapper:
for i in range(5):
executor.submit_job(f"job-{i}", "mod.slow_task", payload, 0, 3, "default")

time.sleep(1.0)
poll_until(
lambda: sender.report_success.call_count >= 5,
timeout=10,
message="not all 5 slow_task jobs reported success",
)
executor.stop()

assert max_concurrent <= 2
assert sender.report_success.call_count == 5


def test_async_middleware_hooks() -> None:
def test_async_middleware_hooks(poll_until: PollUntil) -> None:
"""Middleware before/after called for async tasks."""
import cloudpickle

Expand Down Expand Up @@ -386,14 +391,14 @@ class FakeWrapper:

payload = cloudpickle.dumps(((), {}))
executor.submit_job("mw-job", "mod.simple_task", payload, 0, 3, "default")
time.sleep(0.5)
poll_until(lambda: "mw-job" in after_called, message="middleware after hook not called")
executor.stop()

assert "mw-job" in before_called
assert "mw-job" in after_called


def test_async_task_with_injection() -> None:
def test_async_task_with_injection(poll_until: PollUntil) -> None:
"""inject=["db"] works for async tasks via executor."""
import cloudpickle

Expand Down Expand Up @@ -430,15 +435,15 @@ class FakeWrapper:

payload = cloudpickle.dumps(((), {}))
executor.submit_job("inj-job", "mod.db_task", payload, 0, 3, "default")
time.sleep(0.5)
poll_until(lambda: sender.report_success.called, message="inj-job result not reported")
executor.stop()

sender.report_success.assert_called_once()
result = cloudpickle.loads(sender.report_success.call_args[0][2])
assert result == "got-fake-conn"


def test_async_context_available_inside_task() -> None:
def test_async_context_available_inside_task(poll_until: PollUntil) -> None:
"""current_job.id works inside an async task via contextvars."""
import cloudpickle

Expand Down Expand Up @@ -471,7 +476,7 @@ class FakeWrapper:

payload = cloudpickle.dumps(((), {}))
executor.submit_job("ctx-job", "mod.ctx_task", payload, 0, 3, "default")
time.sleep(0.5)
poll_until(lambda: captured_id == ["ctx-job"], message="ctx-job context not captured")
executor.stop()

assert captured_id == ["ctx-job"]
Expand Down
16 changes: 7 additions & 9 deletions tests/python/test_periodic.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""Tests for periodic (cron-scheduled) tasks."""

import threading
import time
from pathlib import Path
from typing import Any

import pytest

Expand Down Expand Up @@ -38,7 +38,7 @@ def add(a: int, b: int) -> int:
assert add(3, 4) == 7


def test_periodic_task_triggers(queue: Queue) -> None:
def test_periodic_task_triggers(queue: Queue, poll_until: Any) -> None:
"""Periodic task gets enqueued by the scheduler when due."""
results: list[int] = []

Expand All @@ -50,12 +50,10 @@ def frequent_task() -> str:
worker_thread = threading.Thread(target=queue.run_worker, daemon=True)
worker_thread.start()

# Wait for the periodic task to trigger at least once
deadline = time.time() + 15
while time.time() < deadline:
stats = queue.stats()
if stats["completed"] >= 1:
break
time.sleep(0.5)
poll_until(
lambda: queue.stats()["completed"] >= 1,
timeout=15,
message="periodic task never triggered",
)

assert queue.stats()["completed"] >= 1
Loading
Loading