Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 15 additions & 12 deletions crates/taskito-python/src/py_queue/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,9 +252,9 @@ impl PyQueue {
task_names: Vec<String>,
payloads: Vec<Vec<u8>>,
queues: Option<Vec<String>>,
priorities: Option<Vec<i32>>,
max_retries_list: Option<Vec<i32>>,
timeouts: Option<Vec<i64>>,
priorities: Option<Vec<Option<i32>>>,
max_retries_list: Option<Vec<Option<i32>>>,
timeouts: Option<Vec<Option<i64>>>,
delay_seconds_list: Option<Vec<Option<f64>>>,
unique_keys: Option<Vec<Option<String>>>,
metadata_list: Option<Vec<Option<String>>>,
Expand Down Expand Up @@ -306,17 +306,20 @@ impl PyQueue {
}),
task_name: task_names[i].clone(),
payload: payloads[i].clone(),
priority: priorities.as_ref().map_or(self.default_priority, |p| {
p.get(i).copied().unwrap_or(self.default_priority)
}),
priority: priorities
.as_ref()
.and_then(|p| p.get(i).copied().flatten())
.unwrap_or(self.default_priority),
scheduled_at,
max_retries: max_retries_list.as_ref().map_or(self.default_retry, |r| {
r.get(i).copied().unwrap_or(self.default_retry)
}),
max_retries: max_retries_list
.as_ref()
.and_then(|r| r.get(i).copied().flatten())
.unwrap_or(self.default_retry),
timeout_ms: {
let t = timeouts.as_ref().map_or(self.default_timeout, |t| {
t.get(i).copied().unwrap_or(self.default_timeout)
});
let t = timeouts
.as_ref()
.and_then(|t| t.get(i).copied().flatten())
.unwrap_or(self.default_timeout);
t.checked_mul(1000).ok_or_else(|| {
pyo3::exceptions::PyValueError::new_err("timeout too large, would overflow")
})?
Expand Down
6 changes: 3 additions & 3 deletions py_src/taskito/_taskito.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,9 @@ class PyQueue:
task_names: list[str],
payloads: list[bytes],
queues: list[str] | None = None,
priorities: list[int] | None = None,
max_retries_list: list[int] | None = None,
timeouts: list[int] | None = None,
priorities: list[int | None] | None = None,
max_retries_list: list[int | None] | None = None,
timeouts: list[int | None] | None = None,
delay_seconds_list: list[float | None] | None = None,
unique_keys: list[str | None] | None = None,
metadata_list: list[str | None] | None = None,
Expand Down
62 changes: 43 additions & 19 deletions py_src/taskito/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,6 +536,48 @@ def enqueue_many(
f"args_list length ({len(args_list)})"
)
kw_list = kwargs_list or [{}] * count

# Build a per-job options dict so on_enqueue middleware can mutate
# priority/queue/delay/etc. on a per-job basis before the batch is
# committed. The dispatch must happen BEFORE enqueue_batch — running
# it after (as the previous implementation did) made mutations
# impossible to apply.
per_job_options: list[dict[str, Any]] = [
{
"priority": priority,
"queue": queue,
"max_retries": max_retries,
"timeout": timeout,
"delay": (delay_list[i] if delay_list is not None else delay),
"unique_key": (unique_keys[i] if unique_keys is not None else None),
"metadata": (metadata_list[i] if metadata_list is not None else metadata),
"expires": (expires_list[i] if expires_list is not None else expires),
"result_ttl": (result_ttl_list[i] if result_ttl_list is not None else result_ttl),
}
for i in range(count)
]

chain = self._get_middleware_chain(task_name)
for i in range(count):
for mw in chain:
try:
mw.on_enqueue(task_name, args_list[i], kw_list[i], per_job_options[i])
except Exception:
logger.exception("middleware on_enqueue() error")

# Read mutated per-job options back into the per-job lists passed to
# the Rust batch enqueue. `None` entries are forwarded so the Rust
# side falls back to its defaults.
queues_list = [opt["queue"] or "default" for opt in per_job_options]
priorities_list = [opt["priority"] for opt in per_job_options]
retries_list = [opt["max_retries"] for opt in per_job_options]
timeouts_list = [opt["timeout"] for opt in per_job_options]
delays = [opt["delay"] for opt in per_job_options]
per_job_unique_keys = [opt["unique_key"] for opt in per_job_options]
metas = [opt["metadata"] for opt in per_job_options]
exp_list = [opt["expires"] for opt in per_job_options]
ttl_list = [opt["result_ttl"] for opt in per_job_options]

task_serializer = self._get_serializer(task_name)
if self._interceptor is not None:
pairs = [
Expand All @@ -549,17 +591,6 @@ def enqueue_many(
]
task_names = [task_name] * count

queues_list = [queue or "default"] * count if queue else None
priorities_list = [priority] * count if priority is not None else None
retries_list = [max_retries] * count if max_retries is not None else None
timeouts_list = [timeout] * count if timeout is not None else None

# Build per-job optional lists
delays = delay_list or ([delay] * count if delay is not None else None)
metas = metadata_list or ([metadata] * count if metadata is not None else None)
exp_list = expires_list or ([expires] * count if expires is not None else None)
ttl_list = result_ttl_list or ([result_ttl] * count if result_ttl is not None else None)

py_jobs = self._inner.enqueue_batch(
task_names=task_names,
payloads=payloads,
Expand All @@ -568,25 +599,18 @@ def enqueue_many(
max_retries_list=retries_list,
timeouts=timeouts_list,
delay_seconds_list=delays,
unique_keys=unique_keys,
unique_keys=per_job_unique_keys,
metadata_list=metas,
expires_list=exp_list,
result_ttl_list=ttl_list,
)

results = [JobResult(py_job=pj, queue=self) for pj in py_jobs]

# Emit events and dispatch on_enqueue middleware
for job_result in results:
self._emit_event(
EventType.JOB_ENQUEUED,
{"job_id": job_result.id, "task_name": task_name},
)
for mw in self._get_middleware_chain(task_name):
try:
options: dict[str, Any] = {}
mw.on_enqueue(task_name, args_list[0], {}, options)
except Exception:
pass

return results
91 changes: 91 additions & 0 deletions tests/python/test_batch.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
"""Tests for batch enqueue (enqueue_many / task.map)."""

import threading
from typing import Any

from taskito import Queue
from taskito.middleware import TaskMiddleware


def test_enqueue_many(queue: Queue) -> None:
Expand Down Expand Up @@ -53,3 +55,92 @@ def noop() -> None:

stats = queue.stats()
assert stats["pending"] == 50


def test_enqueue_many_invokes_on_enqueue_per_job(tmp_path: Any) -> None:
    """Every job in a batch must reach `on_enqueue` with its own args/kwargs.

    Regression guard: the hook used to be called with `args_list[0]` and a
    brand-new empty options dict for every job, which made the jobs of a
    batch indistinguishable from the middleware's point of view.
    """
    arg_tuples = [(1, 2), (3, 4), (5, 6)]
    trace_labels = ["alpha", "beta", "gamma"]

    class RecordingMiddleware(TaskMiddleware):
        def __init__(self) -> None:
            # One (args, kwargs-copy) entry per on_enqueue invocation.
            self.calls: list[tuple[tuple, dict]] = []

        def on_enqueue(self, task_name: str, args: tuple, kwargs: dict, options: dict) -> None:
            # Copy kwargs so the recorded value is a snapshot taken at call time.
            kwargs_snapshot = dict(kwargs)
            self.calls.append((args, kwargs_snapshot))

    recorder = RecordingMiddleware()
    db_file = tmp_path / "test.db"
    batch_queue = Queue(db_path=str(db_file), middleware=[recorder])

    @batch_queue.task()
    def add(a: int, b: int) -> int:
        return a + b

    batch_queue.enqueue_many(
        task_name=add.name,
        args_list=list(arg_tuples),
        kwargs_list=[{"trace": label} for label in trace_labels],
    )

    # Expected calls are built from fresh dicts, independent of the ones
    # handed to enqueue_many.
    expected_calls = [
        (arg_tuple, {"trace": label}) for arg_tuple, label in zip(arg_tuples, trace_labels)
    ]
    assert recorder.calls == expected_calls


def test_enqueue_many_applies_option_mutations(tmp_path: Any) -> None:
    """Option changes made inside `on_enqueue` must end up on the stored jobs.

    This mirrors the documented behaviour of single-enqueue. Regression
    guard: mutations used to be dropped because the hook ran *after*
    `enqueue_batch` and operated on a fresh empty dict.
    """

    class PerJobBoosterMiddleware(TaskMiddleware):
        def on_enqueue(self, task_name: str, args: tuple, kwargs: dict, options: dict) -> None:
            # Derive the priority from the first positional argument so that
            # every job in the batch receives a different value.
            boosted = int(args[0]) * 10
            options["priority"] = boosted

    booster = PerJobBoosterMiddleware()
    db_file = tmp_path / "test.db"
    batch_queue = Queue(db_path=str(db_file), middleware=[booster])

    @batch_queue.task()
    def task_one(n: int) -> int:
        return n

    enqueued = batch_queue.enqueue_many(
        task_name=task_one.name,
        args_list=[(1,), (2,), (3,)],
    )

    # Read each job back and collect the priority it was stored with.
    stored_priorities = []
    for job_result in enqueued:
        job = batch_queue.get_job(job_result.id)
        stored_priorities.append(job.to_dict()["priority"])  # type: ignore[union-attr]

    assert stored_priorities == [10, 20, 30]


def test_enqueue_many_logs_middleware_exceptions(tmp_path: Any, caplog: Any) -> None:
    """A failing `on_enqueue` hook must be logged rather than silently dropped.

    Regression guard: the hook call used to be wrapped in a swallow-all
    `except: pass`, which hid misbehaving middleware completely.
    """
    import logging

    class ExplodingMiddleware(TaskMiddleware):
        def on_enqueue(self, task_name: str, args: tuple, kwargs: dict, options: dict) -> None:
            raise RuntimeError("middleware boom")

    db_file = tmp_path / "test.db"
    batch_queue = Queue(db_path=str(db_file), middleware=[ExplodingMiddleware()])

    @batch_queue.task()
    def my_task(n: int) -> int:
        return n

    with caplog.at_level(logging.ERROR, logger="taskito.app"):
        enqueued = batch_queue.enqueue_many(
            task_name=my_task.name,
            args_list=[(1,), (2,)],
        )

    # A raising middleware must not prevent the jobs from being enqueued.
    assert len(enqueued) == 2

    # The failure is reported through the logger: first the log message ...
    captured = caplog.records
    assert any("middleware on_enqueue() error" in record.message for record in captured)
    # ... and then the original exception text in the attached traceback.
    assert any("middleware boom" in (record.exc_text or "") for record in captured)
Loading