@pytest.fixture
def poll_until() -> PollUntil:
    """Return a helper that polls a predicate until it is truthy or times out.

    Replaces ``time.sleep(N)`` followed by an assertion in tests that wait
    for a background event (event-bus dispatch, webhook delivery, async
    executor completion). Polling shortens the typical wait while keeping
    a hard timeout for slow CI runners.

    The returned callable takes:

    * ``predicate`` -- zero-argument callable; polling stops as soon as it
      returns a truthy value.
    * ``timeout`` -- keyword-only wait budget in seconds (default ``5.0``).
    * ``interval`` -- keyword-only pause between attempts in seconds
      (default ``0.05``); the pause is shortened when less than ``interval``
      remains, so the helper never sleeps past the deadline.
    * ``message`` -- keyword-only text used in the failure.

    It returns ``None`` once the predicate is truthy and raises
    ``AssertionError`` (``"<message> (timeout <timeout>s)"``) if the
    predicate is still falsy when the budget is spent. The predicate is
    always evaluated at least once, even with ``timeout=0``, and is given
    one last evaluation after the deadline has passed.

    Usage::

        poll_until(lambda: len(received) == 1)
        poll_until(lambda: counts["a"] == 1, timeout=10, message="callback never fired")
    """

    def _poll_until(
        predicate: Callable[[], bool],
        *,
        timeout: float = 5.0,
        interval: float = 0.05,
        message: str = "predicate did not become true",
    ) -> None:
        # Monotonic clock: immune to wall-clock adjustments during the wait.
        deadline = time.monotonic() + timeout
        while True:
            if predicate():
                return
            remaining = deadline - time.monotonic()
            if remaining <= 0:
                # The check above already ran at/after the deadline, so the
                # predicate has had its final chance.
                raise AssertionError(f"{message} (timeout {timeout}s)")
            # Clamp the pause so a coarse ``interval`` cannot stretch the
            # total wait beyond ``timeout``.
            time.sleep(min(interval, remaining))

    return _poll_until
q._event_bus._executor._max_workers == 2 - def test_on_event_public_api(self, tmp_path: Any) -> None: + def test_on_event_public_api(self, tmp_path: Any, poll_until: Any) -> None: q = Queue(db_path=str(tmp_path / "test.db")) received: list[Any] = [] @@ -107,7 +106,7 @@ def my_task() -> None: pass my_task.delay() - time.sleep(0.2) + poll_until(lambda: len(received) >= 1, message="JOB_ENQUEUED event not delivered") assert len(received) == 1 assert received[0][0] == EventType.JOB_ENQUEUED diff --git a/tests/python/test_dlq.py b/tests/python/test_dlq.py index 4ce14c0..14b484d 100644 --- a/tests/python/test_dlq.py +++ b/tests/python/test_dlq.py @@ -1,10 +1,12 @@ """Tests for dead letter queue management.""" import threading -import time +from typing import Any from taskito import Queue +PollUntil = Any # the conftest fixture's runtime type + def test_dead_letters_empty(queue: Queue) -> None: """Empty DLQ returns empty list.""" @@ -12,7 +14,7 @@ def test_dead_letters_empty(queue: Queue) -> None: assert dead == [] -def test_purge_dead(queue: Queue) -> None: +def test_purge_dead(queue: Queue, poll_until: PollUntil) -> None: """Purging removes old dead letter entries.""" @queue.task(max_retries=0, retry_backoff=0.1) @@ -27,11 +29,12 @@ def instant_fail() -> None: ) worker_thread.start() - for _ in range(20): - time.sleep(0.5) - dead = queue.dead_letters() - if len(dead) >= 1: - break + poll_until( + lambda: len(queue.dead_letters()) >= 1, + timeout=10, + message="failed job did not reach DLQ", + ) + dead = queue.dead_letters() assert len(dead) >= 1 # Purge entries older than 0 seconds (purge everything) diff --git a/tests/python/test_events.py b/tests/python/test_events.py index b77f08a..6d6e5c6 100644 --- a/tests/python/test_events.py +++ b/tests/python/test_events.py @@ -5,22 +5,24 @@ from taskito.events import EventBus, EventType +PollUntil = Any # the conftest fixture's runtime type -def test_callback_receives_event() -> None: + +def 
test_callback_receives_event(poll_until: PollUntil) -> None: """Registered callbacks receive emitted events.""" received: list[tuple[EventType, dict[str, Any]]] = [] bus = EventBus() bus.on(EventType.JOB_COMPLETED, lambda et, p: received.append((et, p))) bus.emit(EventType.JOB_COMPLETED, {"job_id": "123"}) - time.sleep(0.5) + poll_until(lambda: len(received) >= 1, message="event was not delivered") assert len(received) == 1 assert received[0][0] == EventType.JOB_COMPLETED assert received[0][1]["job_id"] == "123" -def test_multiple_callbacks() -> None: +def test_multiple_callbacks(poll_until: PollUntil) -> None: """Multiple callbacks for the same event type all fire.""" counts = {"a": 0, "b": 0} bus = EventBus() @@ -28,7 +30,10 @@ def test_multiple_callbacks() -> None: bus.on(EventType.JOB_FAILED, lambda et, p: counts.__setitem__("b", counts["b"] + 1)) bus.emit(EventType.JOB_FAILED, {"error": "boom"}) - time.sleep(0.5) + poll_until( + lambda: counts["a"] == 1 and counts["b"] == 1, + message="not all callbacks fired", + ) assert counts["a"] == 1 assert counts["b"] == 1 @@ -41,12 +46,13 @@ def test_event_filtering() -> None: bus.on(EventType.JOB_COMPLETED, lambda et, p: received.append("completed")) bus.emit(EventType.JOB_FAILED, {"error": "boom"}) - time.sleep(0.5) + # Brief settle so a (would-be incorrect) cross-type dispatch could land. 
+ time.sleep(0.2) assert received == [] -def test_exception_in_callback_does_not_crash() -> None: +def test_exception_in_callback_does_not_crash(poll_until: PollUntil) -> None: """A raising callback doesn't prevent other events from processing.""" results: list[str] = [] bus = EventBus() @@ -61,7 +67,7 @@ def good_callback(et: EventType, p: dict[str, Any]) -> None: bus.on(EventType.JOB_ENQUEUED, good_callback) bus.emit(EventType.JOB_ENQUEUED, {}) - time.sleep(0.5) + poll_until(lambda: results == ["ok"], message="good_callback did not run") assert results == ["ok"] diff --git a/tests/python/test_native_async.py b/tests/python/test_native_async.py index de047ac..f6fa7d1 100644 --- a/tests/python/test_native_async.py +++ b/tests/python/test_native_async.py @@ -4,7 +4,6 @@ import asyncio import threading -import time from pathlib import Path from typing import Any from unittest.mock import MagicMock @@ -17,6 +16,8 @@ ) from taskito.middleware import TaskMiddleware +PollUntil = Any # the conftest fixture's runtime type + # ── Async detection ────────────────────────────────────────────── @@ -132,7 +133,7 @@ def test_async_executor_lifecycle() -> None: executor.stop() -def test_async_executor_submit_and_execute() -> None: +def test_async_executor_submit_and_execute(poll_until: PollUntil) -> None: """Basic async task produces correct result via executor.""" import cloudpickle @@ -164,7 +165,7 @@ class FakeWrapper: payload = cloudpickle.dumps(((2, 3), {})) executor.submit_job("job-1", "test_mod.my_task", payload, 0, 3, "default") - time.sleep(0.5) + poll_until(lambda: sender.report_success.called, message="job-1 result not reported") executor.stop() sender.report_success.assert_called_once() @@ -175,7 +176,7 @@ class FakeWrapper: assert result == 5 -def test_async_exception_reported() -> None: +def test_async_exception_reported(poll_until: PollUntil) -> None: """Exception in async task → failure result with traceback.""" import cloudpickle @@ -206,7 +207,7 @@ class 
FakeWrapper: payload = cloudpickle.dumps(((), {})) executor.submit_job("job-2", "mod.failing_task", payload, 0, 3, "default") - time.sleep(0.5) + poll_until(lambda: sender.report_failure.called, message="job-2 failure not reported") executor.stop() sender.report_failure.assert_called_once() @@ -216,7 +217,7 @@ class FakeWrapper: assert call_args[0][6] is True # should_retry -def test_async_cancellation() -> None: +def test_async_cancellation(poll_until: PollUntil) -> None: """TaskCancelledError → cancelled result.""" import cloudpickle @@ -247,14 +248,14 @@ class FakeWrapper: payload = cloudpickle.dumps(((), {})) executor.submit_job("job-3", "mod.cancelling_task", payload, 0, 3, "default") - time.sleep(0.5) + poll_until(lambda: sender.report_cancelled.called, message="job-3 cancellation not reported") executor.stop() sender.report_cancelled.assert_called_once() assert sender.report_cancelled.call_args[0][0] == "job-3" -def test_async_retry_filter() -> None: +def test_async_retry_filter(poll_until: PollUntil) -> None: """Failed async task respects retry_on filter.""" import cloudpickle @@ -288,14 +289,14 @@ class FakeWrapper: payload = cloudpickle.dumps(((), {})) executor.submit_job("job-4", "mod.flaky_task", payload, 0, 3, "default") - time.sleep(0.5) + poll_until(lambda: sender.report_failure.called, message="job-4 failure not reported") executor.stop() sender.report_failure.assert_called_once() assert sender.report_failure.call_args[0][6] is False # should_retry = False -def test_async_concurrency_limit() -> None: +def test_async_concurrency_limit(poll_until: PollUntil) -> None: """Semaphore bounds concurrent async tasks.""" import cloudpickle @@ -338,14 +339,18 @@ class FakeWrapper: for i in range(5): executor.submit_job(f"job-{i}", "mod.slow_task", payload, 0, 3, "default") - time.sleep(1.0) + poll_until( + lambda: sender.report_success.call_count >= 5, + timeout=10, + message="not all 5 slow_task jobs reported success", + ) executor.stop() assert 
max_concurrent <= 2 assert sender.report_success.call_count == 5 -def test_async_middleware_hooks() -> None: +def test_async_middleware_hooks(poll_until: PollUntil) -> None: """Middleware before/after called for async tasks.""" import cloudpickle @@ -386,14 +391,14 @@ class FakeWrapper: payload = cloudpickle.dumps(((), {})) executor.submit_job("mw-job", "mod.simple_task", payload, 0, 3, "default") - time.sleep(0.5) + poll_until(lambda: "mw-job" in after_called, message="middleware after hook not called") executor.stop() assert "mw-job" in before_called assert "mw-job" in after_called -def test_async_task_with_injection() -> None: +def test_async_task_with_injection(poll_until: PollUntil) -> None: """inject=["db"] works for async tasks via executor.""" import cloudpickle @@ -430,7 +435,7 @@ class FakeWrapper: payload = cloudpickle.dumps(((), {})) executor.submit_job("inj-job", "mod.db_task", payload, 0, 3, "default") - time.sleep(0.5) + poll_until(lambda: sender.report_success.called, message="inj-job result not reported") executor.stop() sender.report_success.assert_called_once() @@ -438,7 +443,7 @@ class FakeWrapper: assert result == "got-fake-conn" -def test_async_context_available_inside_task() -> None: +def test_async_context_available_inside_task(poll_until: PollUntil) -> None: """current_job.id works inside an async task via contextvars.""" import cloudpickle @@ -471,7 +476,7 @@ class FakeWrapper: payload = cloudpickle.dumps(((), {})) executor.submit_job("ctx-job", "mod.ctx_task", payload, 0, 3, "default") - time.sleep(0.5) + poll_until(lambda: captured_id == ["ctx-job"], message="ctx-job context not captured") executor.stop() assert captured_id == ["ctx-job"] diff --git a/tests/python/test_periodic.py b/tests/python/test_periodic.py index c8be23d..3cabf77 100644 --- a/tests/python/test_periodic.py +++ b/tests/python/test_periodic.py @@ -1,8 +1,8 @@ """Tests for periodic (cron-scheduled) tasks.""" import threading -import time from pathlib import Path +from 
typing import Any import pytest @@ -38,7 +38,7 @@ def add(a: int, b: int) -> int: assert add(3, 4) == 7 -def test_periodic_task_triggers(queue: Queue) -> None: +def test_periodic_task_triggers(queue: Queue, poll_until: Any) -> None: """Periodic task gets enqueued by the scheduler when due.""" results: list[int] = [] @@ -50,12 +50,10 @@ def frequent_task() -> str: worker_thread = threading.Thread(target=queue.run_worker, daemon=True) worker_thread.start() - # Wait for the periodic task to trigger at least once - deadline = time.time() + 15 - while time.time() < deadline: - stats = queue.stats() - if stats["completed"] >= 1: - break - time.sleep(0.5) + poll_until( + lambda: queue.stats()["completed"] >= 1, + timeout=15, + message="periodic task never triggered", + ) assert queue.stats()["completed"] >= 1 diff --git a/tests/python/test_prefork.py b/tests/python/test_prefork.py index 8277db1..fa1a366 100644 --- a/tests/python/test_prefork.py +++ b/tests/python/test_prefork.py @@ -270,7 +270,7 @@ def _start_cancel_worker(queue: Queue) -> threading.Thread: @prefork_unix_only -def test_prefork_cancel_running_job_stops_quickly(cancel_app: object) -> None: +def test_prefork_cancel_running_job_stops_quickly(cancel_app: object, poll_until: Any) -> None: """``cancel_running_job`` propagates to the prefork child and stops a cooperative task within a small budget — the regression test for #82.""" queue: Queue = cancel_app.queue # type: ignore[attr-defined] @@ -286,9 +286,12 @@ def on_cancel(self, ctx: JobContext) -> None: job = cancel_app.cooperative_loop.delay(600) # type: ignore[attr-defined] _start_cancel_worker(queue) - # Give the worker time to dispatch the job to a child and let the loop - # start spinning before we cancel. - time.sleep(1.0) + # Wait until the job is actually running on a child before cancelling. 
+ def _running() -> bool: + job.refresh() + return bool(job.status == "running") + + poll_until(_running, timeout=10, message="job never reached running state") assert queue.cancel_running_job(job.id) is True @@ -298,7 +301,7 @@ def on_cancel(self, ctx: JobContext) -> None: @prefork_unix_only -def test_prefork_cancel_does_not_kill_child(cancel_app: object) -> None: +def test_prefork_cancel_does_not_kill_child(cancel_app: object, poll_until: Any) -> None: """A cancel must stop the running task without killing the child — the next job dispatched to the same pool should still complete normally.""" queue: Queue = cancel_app.queue # type: ignore[attr-defined] @@ -306,7 +309,11 @@ def test_prefork_cancel_does_not_kill_child(cancel_app: object) -> None: long_job = cancel_app.cooperative_loop.delay(600) # type: ignore[attr-defined] _start_cancel_worker(queue) - time.sleep(1.0) + def _running() -> bool: + long_job.refresh() + return bool(long_job.status == "running") + + poll_until(_running, timeout=10, message="long_job never reached running state") assert queue.cancel_running_job(long_job.id) is True status = _wait_for_terminal(long_job, timeout=10) assert status == "cancelled" diff --git a/tests/python/test_priority.py b/tests/python/test_priority.py index 1b3d6b2..13c8034 100644 --- a/tests/python/test_priority.py +++ b/tests/python/test_priority.py @@ -1,8 +1,8 @@ """Tests for priority scheduling.""" import threading -import time from pathlib import Path +from typing import Any import pytest @@ -15,7 +15,7 @@ def queue(tmp_path: Path) -> Queue: return Queue(db_path=db_path, workers=1) # 1 worker for ordering -def test_priority_ordering(queue: Queue) -> None: +def test_priority_ordering(queue: Queue, poll_until: Any) -> None: """Higher priority jobs should be processed first.""" results: list[str] = [] @@ -24,22 +24,24 @@ def record_task(label: str) -> str: results.append(label) return label - # Enqueue low priority first, then high priority + # Enqueue low priority 
first, then high priority. All three jobs are + # synchronously visible after `apply_async` returns, so no pre-worker + # delay is needed. record_task.apply_async(args=("low",), priority=1) record_task.apply_async(args=("medium",), priority=5) record_task.apply_async(args=("high",), priority=10) - # Small delay to ensure all are enqueued before worker starts - time.sleep(0.1) - worker_thread = threading.Thread( target=queue.run_worker, daemon=True, ) worker_thread.start() - # Wait for all to complete - time.sleep(3) + poll_until( + lambda: len(results) >= 3, + timeout=10, + message="not all priority jobs completed", + ) # High priority should have been processed first assert len(results) == 3 diff --git a/tests/python/test_resources.py b/tests/python/test_resources.py index c94cc2b..9d95aa1 100644 --- a/tests/python/test_resources.py +++ b/tests/python/test_resources.py @@ -2,7 +2,6 @@ from __future__ import annotations -import time from typing import Any import pytest @@ -274,7 +273,7 @@ def test_test_mode_restores_previous_runtime(queue: Queue) -> None: # --------------------------------------------------------------------------- -def test_health_check_recreation() -> None: +def test_health_check_recreation(poll_until: Any) -> None: """Unhealthy resource is recreated; permanent failure marks it unavailable.""" from taskito.exceptions import ResourceUnavailableError from taskito.resources.health import HealthChecker @@ -307,11 +306,13 @@ def check_health(inst: Any) -> bool: checker = HealthChecker(rt) checker.start() - # Wait for health checks to run and exhaust recreation attempts - time.sleep(1.5) + poll_until( + lambda: "svc" in rt._unhealthy, + timeout=5, + message="resource never marked unhealthy after exhausting attempts", + ) checker.stop() - # After exhausting attempts, resource should be marked unhealthy assert "svc" in rt._unhealthy with pytest.raises(ResourceUnavailableError): rt.resolve("svc") diff --git a/tests/python/test_retry.py 
b/tests/python/test_retry.py index fe55909..00af8c5 100644 --- a/tests/python/test_retry.py +++ b/tests/python/test_retry.py @@ -1,10 +1,12 @@ """Tests for retry logic and dead letter queue.""" import threading -import time +from typing import Any from taskito import Queue +PollUntil = Any # the conftest fixture's runtime type + def test_failing_task_retries(queue: Queue) -> None: """A failing task should be retried up to max_retries times.""" @@ -31,7 +33,7 @@ def flaky_task() -> str: assert call_count == 3 -def test_exhausted_retries_goes_to_dlq(queue: Queue) -> None: +def test_exhausted_retries_goes_to_dlq(queue: Queue, poll_until: PollUntil) -> None: """A task that always fails should end up in the dead letter queue.""" @queue.task(max_retries=2, retry_backoff=0.1) @@ -46,16 +48,18 @@ def always_fails() -> None: ) worker_thread.start() - # Wait for the job to exhaust retries and move to DLQ - time.sleep(5) + poll_until( + lambda: len(queue.dead_letters()) >= 1, + timeout=15, + message="job did not reach DLQ after exhausting retries", + ) - # Check that it's in the DLQ dead = queue.dead_letters() assert len(dead) >= 1 assert dead[0]["task_name"].endswith("always_fails") -def test_retry_dead_letter(queue: Queue) -> None: +def test_retry_dead_letter(queue: Queue, poll_until: PollUntil) -> None: """A dead letter job can be re-enqueued.""" @queue.task(max_retries=1, retry_backoff=0.1) @@ -70,7 +74,11 @@ def fail_once() -> None: ) worker_thread.start() - time.sleep(3) + poll_until( + lambda: len(queue.dead_letters()) >= 1, + timeout=10, + message="job did not reach DLQ", + ) dead = queue.dead_letters() if dead: diff --git a/tests/python/test_shutdown.py b/tests/python/test_shutdown.py index 7c063b8..276ed5e 100644 --- a/tests/python/test_shutdown.py +++ b/tests/python/test_shutdown.py @@ -2,16 +2,21 @@ import threading import time +from typing import Any from taskito import Queue +PollUntil = Any # the conftest fixture's runtime type -def 
test_graceful_shutdown_completes_inflight(queue: Queue) -> None: + +def test_graceful_shutdown_completes_inflight(queue: Queue, poll_until: PollUntil) -> None: """Graceful shutdown waits for in-flight tasks to complete.""" completed = threading.Event() @queue.task() def slow_task() -> str: + # Intentional pacing — the test asserts the worker waits for this + # to finish before shutting down. time.sleep(1) completed.set() return "done" @@ -21,8 +26,11 @@ def slow_task() -> str: worker_thread = threading.Thread(target=queue.run_worker, daemon=True) worker_thread.start() - # Wait a bit for the task to start - time.sleep(0.3) + # Wait for the task to actually start running before triggering shutdown. + poll_until( + lambda: (j := queue.get_job(job.id)) is not None and j.status == "running", + message="slow_task never reached running state", + ) # Request graceful shutdown queue._inner.request_shutdown() @@ -46,7 +54,8 @@ def noop() -> None: worker_thread = threading.Thread(target=queue.run_worker, daemon=True) worker_thread.start() - time.sleep(0.3) + # Tiny grace window so the worker reaches its poll loop before shutdown. 
+ time.sleep(0.1) queue._inner.request_shutdown() worker_thread.join(timeout=10) diff --git a/tests/python/test_webhooks.py b/tests/python/test_webhooks.py index 39a47f6..ac41b42 100644 --- a/tests/python/test_webhooks.py +++ b/tests/python/test_webhooks.py @@ -4,7 +4,6 @@ import hmac import json import threading -import time from collections.abc import Generator from http.server import BaseHTTPRequestHandler, HTTPServer from typing import Any @@ -14,6 +13,8 @@ from taskito.events import EventType from taskito.webhooks import WebhookManager +PollUntil = Any # the conftest fixture's runtime type + @pytest.fixture def webhook_server() -> Generator[tuple[str, list[dict[str, Any]]]]: @@ -46,21 +47,25 @@ def log_message(self, *args: Any) -> None: server.shutdown() -def test_webhook_delivery(webhook_server: tuple[str, list[dict[str, Any]]]) -> None: +def test_webhook_delivery( + webhook_server: tuple[str, list[dict[str, Any]]], poll_until: PollUntil +) -> None: """Webhooks are delivered to registered URLs.""" url, received = webhook_server mgr = WebhookManager() mgr.add_webhook(url) mgr.notify(EventType.JOB_COMPLETED, {"job_id": "abc"}) - time.sleep(1) + poll_until(lambda: len(received) >= 1, message="webhook not delivered") assert len(received) == 1 assert received[0]["body"]["event"] == "job.completed" assert received[0]["body"]["job_id"] == "abc" -def test_webhook_event_filtering(webhook_server: tuple[str, list[dict[str, Any]]]) -> None: +def test_webhook_event_filtering( + webhook_server: tuple[str, list[dict[str, Any]]], poll_until: PollUntil +) -> None: """Webhooks with event filters only receive matching events.""" url, received = webhook_server mgr = WebhookManager() @@ -68,13 +73,15 @@ def test_webhook_event_filtering(webhook_server: tuple[str, list[dict[str, Any]] mgr.notify(EventType.JOB_COMPLETED, {"job_id": "1"}) mgr.notify(EventType.JOB_FAILED, {"job_id": "2", "error": "boom"}) - time.sleep(1) + poll_until(lambda: len(received) >= 1, message="filtered 
webhook not delivered") assert len(received) == 1 assert received[0]["body"]["event"] == "job.failed" -def test_webhook_hmac_signing(webhook_server: tuple[str, list[dict[str, Any]]]) -> None: +def test_webhook_hmac_signing( + webhook_server: tuple[str, list[dict[str, Any]]], poll_until: PollUntil +) -> None: """Webhooks with a secret include a valid HMAC signature.""" url, received = webhook_server secret = "my-secret-key" @@ -82,7 +89,7 @@ def test_webhook_hmac_signing(webhook_server: tuple[str, list[dict[str, Any]]]) mgr.add_webhook(url, secret=secret) mgr.notify(EventType.JOB_ENQUEUED, {"job_id": "xyz"}) - time.sleep(1) + poll_until(lambda: len(received) >= 1, message="signed webhook not delivered") assert len(received) == 1 sig_header = received[0]["headers"].get("X-Taskito-Signature") @@ -110,14 +117,16 @@ def test_webhook_url_validation() -> None: mgr.add_webhook("https://example.com/hook") -def test_webhook_custom_headers(webhook_server: tuple[str, list[dict[str, Any]]]) -> None: +def test_webhook_custom_headers( + webhook_server: tuple[str, list[dict[str, Any]]], poll_until: PollUntil +) -> None: """Custom headers are included in webhook requests.""" url, received = webhook_server mgr = WebhookManager() mgr.add_webhook(url, headers={"X-Custom": "test-value"}) mgr.notify(EventType.JOB_COMPLETED, {"job_id": "1"}) - time.sleep(1) + poll_until(lambda: len(received) >= 1, message="custom-headers webhook not delivered") assert len(received) == 1 assert received[0]["headers"].get("X-Custom") == "test-value" diff --git a/tests/python/test_worker_resources.py b/tests/python/test_worker_resources.py index dac4301..c6ebe84 100644 --- a/tests/python/test_worker_resources.py +++ b/tests/python/test_worker_resources.py @@ -4,7 +4,6 @@ import json import threading -import time from typing import Any from taskito import Queue @@ -84,7 +83,7 @@ def test_list_workers_empty(self, tmp_path: Any) -> None: class TestWorkerResourceIntegration: - def 
test_worker_advertises_resources_and_threads(self, tmp_path: Any) -> None: + def test_worker_advertises_resources_and_threads(self, tmp_path: Any, poll_until: Any) -> None: """A running worker stores resources and threads in storage.""" queue = Queue(db_path=str(tmp_path / "q.db"), workers=2) @@ -102,15 +101,13 @@ def noop() -> None: try: # Poll until a worker appears with resource_health populated - # (initial registration has None; first heartbeat sets it) - deadline = time.monotonic() + 15 - workers: list[dict[str, Any]] = [] - while time.monotonic() < deadline: - workers = queue.workers() - if workers and workers[0].get("resource_health") is not None: - break - time.sleep(0.5) - + # (initial registration has None; first heartbeat sets it). + poll_until( + lambda: bool((ws := queue.workers()) and ws[0].get("resource_health") is not None), + timeout=15, + message="worker did not publish resource_health", + ) + workers: list[dict[str, Any]] = queue.workers() assert len(workers) >= 1 w = workers[0] assert "resources" in w @@ -129,7 +126,7 @@ def noop() -> None: queue._inner.request_shutdown() thread.join(timeout=10) - def test_worker_no_resources(self, tmp_path: Any) -> None: + def test_worker_no_resources(self, tmp_path: Any, poll_until: Any) -> None: """Worker without resources stores None for resource fields.""" queue = Queue(db_path=str(tmp_path / "q.db"), workers=1) @@ -141,14 +138,10 @@ def noop() -> None: thread.start() try: - deadline = time.monotonic() + 10 - workers: list[dict[str, Any]] = [] - while time.monotonic() < deadline: - workers = queue.workers() - if workers: - break - time.sleep(0.2) - + poll_until( + lambda: bool(queue.workers()), timeout=10, message="worker did not register" + ) + workers: list[dict[str, Any]] = queue.workers() assert len(workers) >= 1 w = workers[0] assert w["resources"] is None @@ -157,7 +150,7 @@ def noop() -> None: queue._inner.request_shutdown() thread.join(timeout=10) - def test_heartbeat_updates_health(self, tmp_path: 
Any) -> None: + def test_heartbeat_updates_health(self, tmp_path: Any, poll_until: Any) -> None: """Heartbeat thread updates resource_health in storage.""" queue = Queue(db_path=str(tmp_path / "q.db"), workers=1) @@ -174,14 +167,12 @@ def noop() -> None: try: # Wait for worker + first heartbeat - deadline = time.monotonic() + 10 - workers: list[dict[str, Any]] = [] - while time.monotonic() < deadline: - workers = queue.workers() - if workers and workers[0].get("resource_health"): - break - time.sleep(0.5) - + poll_until( + lambda: bool((ws := queue.workers()) and ws[0].get("resource_health")), + timeout=10, + message="heartbeat never published resource_health", + ) + workers: list[dict[str, Any]] = queue.workers() assert len(workers) >= 1 health = json.loads(workers[0]["resource_health"]) assert "db" in health