From 3ecfc2d79f84a022ad340c4234875affea1b251d Mon Sep 17 00:00:00 2001 From: Pratyush Sharma <56130065+pratyush618@users.noreply.github.com> Date: Sat, 2 May 2026 12:31:30 +0530 Subject: [PATCH] fix(app): correct enqueue_many middleware args and option propagation `enqueue_many` had three bugs in the on_enqueue dispatch: 1. Always passed `args_list[0]` and `kw_list[0]` to every middleware call, so middleware could not distinguish jobs in the batch. 2. Created a fresh empty options dict per call that was never read back, so any option mutations (priority, queue, delay, ...) were silently discarded. 3. Bare `except: pass` swallowed middleware exceptions with no logging. Restructure the dispatch to mirror single-enqueue: build a per-job options dict initialised from the call-site arguments, run middleware *before* `enqueue_batch` so mutations can take effect, then read the mutated options back into the per-job lists. Replace the silent except with `logger.exception(...)` so misbehaving middleware is observable. To support per-job priority/retries/timeout mutations the Rust `enqueue_batch` signature is widened from `Option>` to `Option>>` (matching the existing pattern for `delay_seconds_list`, `metadata_list`, etc.); type stub follows. Three regression tests cover the per-job args, the option mutation propagation, and the logged-exception path. --- crates/taskito-python/src/py_queue/mod.rs | 27 ++++--- py_src/taskito/_taskito.pyi | 6 +- py_src/taskito/app.py | 62 ++++++++++----- tests/python/test_batch.py | 91 +++++++++++++++++++++++ 4 files changed, 152 insertions(+), 34 deletions(-) diff --git a/crates/taskito-python/src/py_queue/mod.rs b/crates/taskito-python/src/py_queue/mod.rs index 2b9d09d..7190bda 100644 --- a/crates/taskito-python/src/py_queue/mod.rs +++ b/crates/taskito-python/src/py_queue/mod.rs @@ -252,9 +252,9 @@ impl PyQueue { task_names: Vec, payloads: Vec>, queues: Option>, - priorities: Option>, - max_retries_list: Option>, - timeouts: Option>, + priorities: Option>>, + max_retries_list: Option>>, + timeouts: Option>>, delay_seconds_list: Option>>, unique_keys: Option>>, metadata_list: Option>>, @@ -306,17 +306,20 @@ impl PyQueue { }), task_name: task_names[i].clone(), payload: payloads[i].clone(), - priority: priorities.as_ref().map_or(self.default_priority, |p| { - p.get(i).copied().unwrap_or(self.default_priority) - }), + priority: priorities + .as_ref() + .and_then(|p| p.get(i).copied().flatten()) + .unwrap_or(self.default_priority), scheduled_at, - max_retries: max_retries_list.as_ref().map_or(self.default_retry, |r| { - r.get(i).copied().unwrap_or(self.default_retry) - }), + max_retries: max_retries_list + .as_ref() + .and_then(|r| r.get(i).copied().flatten()) + .unwrap_or(self.default_retry), timeout_ms: { - let t = timeouts.as_ref().map_or(self.default_timeout, |t| { - t.get(i).copied().unwrap_or(self.default_timeout) - }); + let t = timeouts + .as_ref() + .and_then(|t| t.get(i).copied().flatten()) + .unwrap_or(self.default_timeout); t.checked_mul(1000).ok_or_else(|| { pyo3::exceptions::PyValueError::new_err("timeout too large, would overflow") })? diff --git a/py_src/taskito/_taskito.pyi b/py_src/taskito/_taskito.pyi index aea0751..cf445da 100644 --- a/py_src/taskito/_taskito.pyi +++ b/py_src/taskito/_taskito.pyi @@ -111,9 +111,9 @@ class PyQueue: task_names: list[str], payloads: list[bytes], queues: list[str] | None = None, - priorities: list[int] | None = None, - max_retries_list: list[int] | None = None, - timeouts: list[int] | None = None, + priorities: list[int | None] | None = None, + max_retries_list: list[int | None] | None = None, + timeouts: list[int | None] | None = None, delay_seconds_list: list[float | None] | None = None, unique_keys: list[str | None] | None = None, metadata_list: list[str | None] | None = None, diff --git a/py_src/taskito/app.py b/py_src/taskito/app.py index 1d54a72..522d1f4 100644 --- a/py_src/taskito/app.py +++ b/py_src/taskito/app.py @@ -536,6 +536,48 @@ def enqueue_many( f"args_list length ({len(args_list)})" ) kw_list = kwargs_list or [{}] * count + + # Build a per-job options dict so on_enqueue middleware can mutate + # priority/queue/delay/etc. on a per-job basis before the batch is + # committed. The dispatch must happen BEFORE enqueue_batch — running + # it after (as the previous implementation did) made mutations + # impossible to apply. + per_job_options: list[dict[str, Any]] = [ + { + "priority": priority, + "queue": queue, + "max_retries": max_retries, + "timeout": timeout, + "delay": (delay_list[i] if delay_list is not None else delay), + "unique_key": (unique_keys[i] if unique_keys is not None else None), + "metadata": (metadata_list[i] if metadata_list is not None else metadata), + "expires": (expires_list[i] if expires_list is not None else expires), + "result_ttl": (result_ttl_list[i] if result_ttl_list is not None else result_ttl), + } + for i in range(count) + ] + + chain = self._get_middleware_chain(task_name) + for i in range(count): + for mw in chain: + try: + mw.on_enqueue(task_name, args_list[i], kw_list[i], per_job_options[i]) + except Exception: + logger.exception("middleware on_enqueue() error") + + # Read mutated per-job options back into the per-job lists passed to + # the Rust batch enqueue. `None` entries are forwarded so the Rust + # side falls back to its defaults. + queues_list = [opt["queue"] or "default" for opt in per_job_options] + priorities_list = [opt["priority"] for opt in per_job_options] + retries_list = [opt["max_retries"] for opt in per_job_options] + timeouts_list = [opt["timeout"] for opt in per_job_options] + delays = [opt["delay"] for opt in per_job_options] + per_job_unique_keys = [opt["unique_key"] for opt in per_job_options] + metas = [opt["metadata"] for opt in per_job_options] + exp_list = [opt["expires"] for opt in per_job_options] + ttl_list = [opt["result_ttl"] for opt in per_job_options] + task_serializer = self._get_serializer(task_name) if self._interceptor is not None: pairs = [ @@ -549,17 +591,6 @@ def enqueue_many( ] task_names = [task_name] * count - queues_list = [queue or "default"] * count if queue else None - priorities_list = [priority] * count if priority is not None else None - retries_list = [max_retries] * count if max_retries is not None else None - timeouts_list = [timeout] * count if timeout is not None else None - - # Build per-job optional lists - delays = delay_list or ([delay] * count if delay is not None else None) - metas = metadata_list or ([metadata] * count if metadata is not None else None) - exp_list = expires_list or ([expires] * count if expires is not None else None) - ttl_list = result_ttl_list or ([result_ttl] * count if result_ttl is not None else None) - py_jobs = self._inner.enqueue_batch( task_names=task_names, payloads=payloads, @@ -568,7 +599,7 @@ def enqueue_many( max_retries_list=retries_list, timeouts=timeouts_list, delay_seconds_list=delays, - unique_keys=unique_keys, + unique_keys=per_job_unique_keys, metadata_list=metas, expires_list=exp_list, result_ttl_list=ttl_list, @@ -576,17 +607,10 @@ def enqueue_many( results = [JobResult(py_job=pj, queue=self) for pj in py_jobs] - # Emit events and dispatch on_enqueue middleware for job_result in results: self._emit_event( EventType.JOB_ENQUEUED, {"job_id": job_result.id, "task_name": task_name}, ) - for mw in self._get_middleware_chain(task_name): - try: - options: dict[str, Any] = {} - mw.on_enqueue(task_name, args_list[0], {}, options) - except Exception: - pass return results diff --git a/tests/python/test_batch.py b/tests/python/test_batch.py index d38b8fb..70cec6c 100644 --- a/tests/python/test_batch.py +++ b/tests/python/test_batch.py @@ -1,8 +1,10 @@ """Tests for batch enqueue (enqueue_many / task.map).""" import threading +from typing import Any from taskito import Queue +from taskito.middleware import TaskMiddleware def test_enqueue_many(queue: Queue) -> None: @@ -53,3 +55,92 @@ def noop() -> None: stats = queue.stats() assert stats["pending"] == 50 + + +def test_enqueue_many_invokes_on_enqueue_per_job(tmp_path: Any) -> None: + """`on_enqueue` middleware must receive each job's own args/kwargs. + + Regression: the previous implementation always passed `args_list[0]` + and a fresh empty options dict to every middleware call, so middleware + could not distinguish jobs in the batch. + """ + + class RecordingMiddleware(TaskMiddleware): + def __init__(self) -> None: + self.calls: list[tuple[tuple, dict]] = [] + + def on_enqueue(self, task_name: str, args: tuple, kwargs: dict, options: dict) -> None: + self.calls.append((args, dict(kwargs))) + + mw = RecordingMiddleware() + q = Queue(db_path=str(tmp_path / "test.db"), middleware=[mw]) + + @q.task() + def add(a: int, b: int) -> int: + return a + b + + q.enqueue_many( + task_name=add.name, + args_list=[(1, 2), (3, 4), (5, 6)], + kwargs_list=[{"trace": "alpha"}, {"trace": "beta"}, {"trace": "gamma"}], + ) + + assert mw.calls == [ + ((1, 2), {"trace": "alpha"}), + ((3, 4), {"trace": "beta"}), + ((5, 6), {"trace": "gamma"}), + ] + + +def test_enqueue_many_applies_option_mutations(tmp_path: Any) -> None: + """Mutations to the options dict inside on_enqueue must propagate to the + enqueued jobs — matching the documented behaviour of single-enqueue. + Regression: the previous implementation discarded mutations because the + hook ran *after* `enqueue_batch` and against a fresh empty dict. + """ + + class PerJobBoosterMiddleware(TaskMiddleware): + def on_enqueue(self, task_name: str, args: tuple, kwargs: dict, options: dict) -> None: + # Bump priority based on the first argument so each job sees a + # distinct mutation. + options["priority"] = int(args[0]) * 10 + + q = Queue(db_path=str(tmp_path / "test.db"), middleware=[PerJobBoosterMiddleware()]) + + @q.task() + def task_one(n: int) -> int: + return n + + results = q.enqueue_many(task_name=task_one.name, args_list=[(1,), (2,), (3,)]) + + priorities = [q.get_job(r.id).to_dict()["priority"] for r in results] # type: ignore[union-attr] + assert priorities == [10, 20, 30] + + +def test_enqueue_many_logs_middleware_exceptions(tmp_path: Any, caplog: Any) -> None: + """Middleware exceptions must be logged, not silently swallowed. + + Regression: the previous implementation used a bare `except: pass`, + making misbehaving middleware effectively invisible. + """ + import logging + + class ExplodingMiddleware(TaskMiddleware): + def on_enqueue(self, task_name: str, args: tuple, kwargs: dict, options: dict) -> None: + raise RuntimeError("middleware boom") + + q = Queue(db_path=str(tmp_path / "test.db"), middleware=[ExplodingMiddleware()]) + + @q.task() + def my_task(n: int) -> int: + return n + + with caplog.at_level(logging.ERROR, logger="taskito.app"): + results = q.enqueue_many(task_name=my_task.name, args_list=[(1,), (2,)]) + + # Jobs are still enqueued — middleware errors must not block enqueue + assert len(results) == 2 + + # And the error was surfaced via the logger + assert any("middleware on_enqueue() error" in rec.message for rec in caplog.records) + assert any("middleware boom" in (rec.exc_text or "") for rec in caplog.records)