From e4838e6b892c460d012aaf8feb222579f8b57ccf Mon Sep 17 00:00:00 2001 From: Pratyush Sharma <56130065+pratyush618@users.noreply.github.com> Date: Mon, 11 May 2026 09:29:57 +0530 Subject: [PATCH 01/11] feat(predicates): scaffold composable predicates package --- py_src/taskito/predicates/__init__.py | 79 +++++ py_src/taskito/predicates/context.py | 180 +++++++++++ py_src/taskito/predicates/core.py | 177 +++++++++++ py_src/taskito/predicates/evaluate.py | 73 +++++ py_src/taskito/predicates/metrics.py | 60 ++++ py_src/taskito/predicates/outcomes.py | 38 +++ py_src/taskito/predicates/providers.py | 52 ++++ py_src/taskito/predicates/recipes/__init__.py | 59 ++++ .../taskito/predicates/recipes/attributes.py | 112 +++++++ py_src/taskito/predicates/recipes/config.py | 55 ++++ py_src/taskito/predicates/recipes/system.py | 77 +++++ py_src/taskito/predicates/recipes/time.py | 223 ++++++++++++++ tests/python/test_predicates_core.py | 227 ++++++++++++++ tests/python/test_predicates_recipes.py | 290 ++++++++++++++++++ 14 files changed, 1702 insertions(+) create mode 100644 py_src/taskito/predicates/__init__.py create mode 100644 py_src/taskito/predicates/context.py create mode 100644 py_src/taskito/predicates/core.py create mode 100644 py_src/taskito/predicates/evaluate.py create mode 100644 py_src/taskito/predicates/metrics.py create mode 100644 py_src/taskito/predicates/outcomes.py create mode 100644 py_src/taskito/predicates/providers.py create mode 100644 py_src/taskito/predicates/recipes/__init__.py create mode 100644 py_src/taskito/predicates/recipes/attributes.py create mode 100644 py_src/taskito/predicates/recipes/config.py create mode 100644 py_src/taskito/predicates/recipes/system.py create mode 100644 py_src/taskito/predicates/recipes/time.py create mode 100644 tests/python/test_predicates_core.py create mode 100644 tests/python/test_predicates_recipes.py diff --git a/py_src/taskito/predicates/__init__.py b/py_src/taskito/predicates/__init__.py new file mode 100644 
index 0000000..73890b1 --- /dev/null +++ b/py_src/taskito/predicates/__init__.py @@ -0,0 +1,79 @@ +"""Composable, fail-closed predicates for gating tasks. + +A predicate is any subclass of :class:`Predicate` whose +:meth:`Predicate.evaluate` returns ``True`` (allow), ``False`` (deny), +:class:`Defer` (skip now, retry later), or :class:`Cancel` (skip +permanently). Predicates compose with ``&`` / ``|`` / ``~``:: + + from taskito.predicates import is_business_hours, queue_paused + + @queue.task(predicate=is_business_hours() & ~queue_paused()) + def send_report(): ... + +Built-in recipes are imported from :mod:`taskito.predicates.recipes`. +""" + +from __future__ import annotations + +from taskito.predicates.context import PredicateContext +from taskito.predicates.core import ( + AndPredicate, + NotPredicate, + OrPredicate, + Predicate, + coerce_predicate, +) +from taskito.predicates.evaluate import evaluate_predicate +from taskito.predicates.metrics import PredicateMetrics +from taskito.predicates.outcomes import Cancel, Defer, PredicateOutcome +from taskito.predicates.providers import FeatureFlagProvider, env_feature_flag_provider +from taskito.predicates.recipes import ( + after, + before, + by_priority_at_least, + by_queue, + by_task, + env_var_truthy, + error_rate_under, + feature_flag, + in_time_window, + in_timezone, + is_business_hours, + is_weekend, + payload_matches, + queue_paused, + queue_size_under, + retry_count_under, +) + +__all__ = [ + "AndPredicate", + "Cancel", + "Defer", + "FeatureFlagProvider", + "NotPredicate", + "OrPredicate", + "Predicate", + "PredicateContext", + "PredicateMetrics", + "PredicateOutcome", + "after", + "before", + "by_priority_at_least", + "by_queue", + "by_task", + "coerce_predicate", + "env_feature_flag_provider", + "env_var_truthy", + "error_rate_under", + "evaluate_predicate", + "feature_flag", + "in_time_window", + "in_timezone", + "is_business_hours", + "is_weekend", + "payload_matches", + "queue_paused", + 
"queue_size_under", + "retry_count_under", +] diff --git a/py_src/taskito/predicates/context.py b/py_src/taskito/predicates/context.py new file mode 100644 index 0000000..10dff1d --- /dev/null +++ b/py_src/taskito/predicates/context.py @@ -0,0 +1,180 @@ +"""Predicate evaluation context.""" + +from __future__ import annotations + +import time +import weakref +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from taskito.app import Queue + + +@dataclass +class PredicateContext: + """Per-evaluation context passed to :meth:`Predicate.evaluate`. + + Carries job metadata, a weak reference to the owning :class:`Queue` so + system-state recipes can read queue depth or pause status, and a + per-evaluation memo dict so composed predicates that hit the same + storage call only pay for it once. + + Instances are short-lived: one per predicate evaluation. Holding a + reference past a single call is undefined. + """ + + task_name: str + queue: str + priority: int = 0 + retry_count: int = 0 + args: tuple[Any, ...] 
= () + kwargs: dict[str, Any] = field(default_factory=dict) + job_id: str | None = None + scheduled_at: datetime | None = None + created_at: datetime | None = None + payload_size: int = 0 + extras: dict[str, Any] = field(default_factory=dict) + + _queue_weakref: Any = None # weakref.ref[Queue] | None + _state_cache: dict[tuple[Any, ...], Any] = field(default_factory=dict) + + @classmethod + def for_enqueue( + cls, + *, + task_name: str, + queue: str, + priority: int | None, + args: tuple[Any, ...], + kwargs: dict[str, Any], + payload_size: int, + delay_seconds: float | None, + extras: dict[str, Any] | None, + queue_ref: Queue | None, + ) -> PredicateContext: + """Build a context for an enqueue-time evaluation.""" + now = datetime.now(timezone.utc) + scheduled = now if not delay_seconds else _add_seconds(now, delay_seconds) + return cls( + task_name=task_name, + queue=queue, + priority=priority or 0, + retry_count=0, + args=tuple(args), + kwargs=dict(kwargs), + job_id=None, + scheduled_at=scheduled, + created_at=now, + payload_size=payload_size, + extras=dict(extras or {}), + _queue_weakref=weakref.ref(queue_ref) if queue_ref is not None else None, + ) + + @classmethod + def for_dispatch( + cls, + *, + task_name: str, + queue: str, + priority: int, + retry_count: int, + args: tuple[Any, ...], + kwargs: dict[str, Any], + job_id: str, + payload_size: int, + extras: dict[str, Any] | None, + queue_ref: Queue | None, + ) -> PredicateContext: + """Build a context for a worker-dispatch-time evaluation.""" + now = datetime.now(timezone.utc) + return cls( + task_name=task_name, + queue=queue, + priority=priority, + retry_count=retry_count, + args=tuple(args), + kwargs=dict(kwargs), + job_id=job_id, + scheduled_at=now, + created_at=now, + payload_size=payload_size, + extras=dict(extras or {}), + _queue_weakref=weakref.ref(queue_ref) if queue_ref is not None else None, + ) + + def now(self) -> datetime: + """Wall-clock time, timezone-aware (UTC).""" + return 
datetime.now(timezone.utc) + + def monotonic(self) -> float: + """Process-local monotonic clock for short timing windows.""" + return time.monotonic() + + def queue_size(self, queue_name: str | None = None) -> int: + """Pending job count in ``queue_name`` (defaults to this job's queue). + + Memoised within a single predicate evaluation, so composed + predicates sharing the same call do not re-hit storage. + """ + name = queue_name or self.queue + key = ("queue_size", name) + if key in self._state_cache: + return int(self._state_cache[key]) + queue = self._resolve_queue() + if queue is None: + self._state_cache[key] = 0 + return 0 + try: + stats = queue._inner.stats_by_queue(name) + value = int(stats.get("pending", 0)) + except Exception: + value = 0 + self._state_cache[key] = value + return value + + def queue_paused(self, queue_name: str | None = None) -> bool: + """Whether ``queue_name`` is paused (defaults to this job's queue).""" + name = queue_name or self.queue + key = ("queue_paused", name) + if key in self._state_cache: + return bool(self._state_cache[key]) + queue = self._resolve_queue() + if queue is None: + self._state_cache[key] = False + return False + try: + paused = name in queue._inner.list_paused_queues() + except Exception: + paused = False + self._state_cache[key] = paused + return paused + + def stats(self) -> dict[str, int]: + """Backend-wide stats (memoised per evaluation).""" + key: tuple[Any, ...] 
= ("stats",) + if key in self._state_cache: + return dict(self._state_cache[key]) + queue = self._resolve_queue() + if queue is None: + self._state_cache[key] = {} + return {} + try: + value = dict(queue._inner.stats()) + except Exception: + value = {} + self._state_cache[key] = value + return value + + def _resolve_queue(self) -> Queue | None: + ref = self._queue_weakref + if ref is None: + return None + return ref() # type: ignore[no-any-return] + + +def _add_seconds(dt: datetime, seconds: float) -> datetime: + from datetime import timedelta + + return dt + timedelta(seconds=seconds) diff --git a/py_src/taskito/predicates/core.py b/py_src/taskito/predicates/core.py new file mode 100644 index 0000000..7d5a735 --- /dev/null +++ b/py_src/taskito/predicates/core.py @@ -0,0 +1,177 @@ +"""Predicate ABC and composition primitives.""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import Awaitable, Callable +from typing import TYPE_CHECKING, Any + +from taskito.predicates.outcomes import Cancel, Defer + +if TYPE_CHECKING: + from taskito.predicates.context import PredicateContext + + +PredicateReturn = bool | Defer | Cancel | Awaitable[Any] +"""What a :meth:`Predicate.evaluate` is allowed to return. + +Synchronous predicates return :class:`bool`, :class:`Defer`, or +:class:`Cancel`. Async predicates may also return an awaitable that +resolves to one of those values. +""" + + +class Predicate(ABC): + """Base class for task predicates. + + Subclass and implement :meth:`evaluate`. Predicates compose with the + standard boolean operators:: + + @queue.task(predicate=is_business_hours() & ~queue_paused()) + def send_report(): ... + + Composition short-circuits: ``And`` stops at the first non-True, + ``Or`` stops at the first ``True``. ``Defer`` and ``Cancel`` outcomes + are propagated unchanged through ``~`` and through ``And``/``Or`` when + they cannot be overridden by the other operand. 
+ """ + + @abstractmethod + def evaluate(self, ctx: PredicateContext) -> PredicateReturn: + """Return ``True`` to allow the job, ``False`` / ``Defer`` / ``Cancel`` to gate it.""" + + def __and__(self, other: Predicate) -> Predicate: + return AndPredicate(self, other) + + def __or__(self, other: Predicate) -> Predicate: + return OrPredicate(self, other) + + def __invert__(self) -> Predicate: + return NotPredicate(self) + + def __repr__(self) -> str: + return f"{type(self).__name__}()" + + +class AndPredicate(Predicate): + """Logical AND with short-circuit evaluation.""" + + __slots__ = ("_left", "_right") + + def __init__(self, left: Predicate, right: Predicate) -> None: + self._left = left + self._right = right + + def evaluate(self, ctx: PredicateContext) -> PredicateReturn: + # Note: deferred imports avoided. Compose-time evaluation calls + # _resolve_sync which itself handles awaitables when present. + from taskito.predicates.evaluate import _resolve_outcome + + left = _resolve_outcome(self._left, ctx) + if isinstance(left, (Defer, Cancel)) or left is False: + return left + return _resolve_outcome(self._right, ctx) + + def __repr__(self) -> str: + return f"({self._left!r} & {self._right!r})" + + +class OrPredicate(Predicate): + """Logical OR with short-circuit evaluation.""" + + __slots__ = ("_left", "_right") + + def __init__(self, left: Predicate, right: Predicate) -> None: + self._left = left + self._right = right + + def evaluate(self, ctx: PredicateContext) -> PredicateReturn: + from taskito.predicates.evaluate import _resolve_outcome + + left = _resolve_outcome(self._left, ctx) + if left is True: + return True + right = _resolve_outcome(self._right, ctx) + if right is True: + return True + # Neither side allows. Prefer the most informative gating: Cancel + # wins over Defer, Defer wins over False. 
+ if isinstance(left, Cancel) or isinstance(right, Cancel): + return left if isinstance(left, Cancel) else right + if isinstance(left, Defer) or isinstance(right, Defer): + return left if isinstance(left, Defer) else right + return False + + def __repr__(self) -> str: + return f"({self._left!r} | {self._right!r})" + + +class NotPredicate(Predicate): + """Logical NOT. + + Inverts ``True`` / ``False``. ``Defer`` and ``Cancel`` outcomes pass + through unchanged — they are terminal, not booleans. + """ + + __slots__ = ("_inner",) + + def __init__(self, inner: Predicate) -> None: + self._inner = inner + + def evaluate(self, ctx: PredicateContext) -> PredicateReturn: + from taskito.predicates.evaluate import _resolve_outcome + + outcome = _resolve_outcome(self._inner, ctx) + if isinstance(outcome, (Defer, Cancel)): + return outcome + return not outcome + + def __repr__(self) -> str: + return f"~{self._inner!r}" + + +class _CallablePredicate(Predicate): + """Adapter that wraps a plain callable as a :class:`Predicate`. + + Two signatures are supported: + + * ``Callable[[PredicateContext], bool | Defer | Cancel]`` — receives + the full context. + * ``Callable[[str], bool]`` — receives just the task name. Used to + preserve back-compat with the contrib middleware ``task_filter`` arg. + """ + + __slots__ = ("_fn", "_takes_str") + + def __init__(self, fn: Callable[..., Any], *, takes_str: bool = False) -> None: + self._fn = fn + self._takes_str = takes_str + + def evaluate(self, ctx: PredicateContext) -> PredicateReturn: + if self._takes_str: + return bool(self._fn(ctx.task_name)) + return self._fn(ctx) # type: ignore[no-any-return] + + def __repr__(self) -> str: + return f"callable({getattr(self._fn, '__qualname__', repr(self._fn))})" + + +def coerce_predicate( + value: Predicate | Callable[..., Any] | None, + *, + str_callable: bool = False, +) -> Predicate | None: + """Wrap a callable as a :class:`Predicate`. ``None`` passes through. 
+ + Use ``str_callable=True`` to accept the legacy ``Callable[[str], bool]`` + contrib ``task_filter`` shape. + """ + if value is None: + return None + if isinstance(value, Predicate): + return value + if callable(value): + return _CallablePredicate(value, takes_str=str_callable) + raise TypeError( + f"predicate must be a Predicate, callable, or None; got {type(value).__name__}" + ) diff --git a/py_src/taskito/predicates/evaluate.py b/py_src/taskito/predicates/evaluate.py new file mode 100644 index 0000000..b751f15 --- /dev/null +++ b/py_src/taskito/predicates/evaluate.py @@ -0,0 +1,73 @@ +"""Synchronous predicate evaluation runner. + +Wraps a predicate's ``evaluate`` call with fail-closed error handling and +normalises the return type. Async evaluation lives in +:mod:`taskito.async_support.predicates` to keep all ``asyncio`` use inside +the async package. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING + +from taskito.async_support.helpers import run_maybe_async +from taskito.predicates.outcomes import Cancel, Defer + +if TYPE_CHECKING: + from taskito.predicates.context import PredicateContext + from taskito.predicates.core import Predicate + from taskito.predicates.metrics import PredicateMetrics + +logger = logging.getLogger("taskito.predicates") + + +def evaluate_predicate( + predicate: Predicate, + ctx: PredicateContext, + *, + metrics: PredicateMetrics | None = None, +) -> bool | Defer | Cancel: + """Evaluate ``predicate`` fail-closed. + + On exception, returns ``False`` and (if provided) records an error on + ``metrics``. Coroutine returns are resolved synchronously via + :func:`run_maybe_async` so this is safe to call from worker threads. 
+ """ + outcome = _resolve_outcome(predicate, ctx, metrics=metrics) + if metrics is not None: + if isinstance(outcome, Defer): + metrics.record_deferred() + elif isinstance(outcome, Cancel): + metrics.record_cancelled() + elif outcome is True: + metrics.record_allowed() + else: + metrics.record_denied() + return outcome + + +def _resolve_outcome( + predicate: Predicate, + ctx: PredicateContext, + *, + metrics: PredicateMetrics | None = None, +) -> bool | Defer | Cancel: + """Run ``evaluate`` with fail-closed semantics and normalise the return. + + Used internally by composition operators (And/Or/Not) — they do NOT + record outcome metrics themselves to avoid double-counting; only the + top-level call from :func:`evaluate_predicate` records. + """ + try: + raw = predicate.evaluate(ctx) + resolved = run_maybe_async(raw) + except Exception: + logger.exception("Predicate %r raised; treating as False (fail-closed)", predicate) + if metrics is not None: + metrics.record_error() + return False + + if isinstance(resolved, (Defer, Cancel)): + return resolved + return bool(resolved) diff --git a/py_src/taskito/predicates/metrics.py b/py_src/taskito/predicates/metrics.py new file mode 100644 index 0000000..0943dcd --- /dev/null +++ b/py_src/taskito/predicates/metrics.py @@ -0,0 +1,60 @@ +"""Lightweight counters for predicate outcomes.""" + +from __future__ import annotations + +import threading +from dataclasses import dataclass, field + + +@dataclass +class PredicateMetrics: + """Per-queue counters for predicate outcomes. + + All counters are atomic under a single lock — predicate evaluation is + not on the hot path of dispatch, so contention is negligible. 
+ """ + + _allowed: int = 0 + _denied: int = 0 + _deferred: int = 0 + _cancelled: int = 0 + _errors: int = 0 + _lock: threading.Lock = field(default_factory=threading.Lock, repr=False) + + def record_allowed(self) -> None: + with self._lock: + self._allowed += 1 + + def record_denied(self) -> None: + with self._lock: + self._denied += 1 + + def record_deferred(self) -> None: + with self._lock: + self._deferred += 1 + + def record_cancelled(self) -> None: + with self._lock: + self._cancelled += 1 + + def record_error(self) -> None: + with self._lock: + self._errors += 1 + + def snapshot(self) -> dict[str, int]: + with self._lock: + return { + "allowed": self._allowed, + "denied": self._denied, + "deferred": self._deferred, + "cancelled": self._cancelled, + "errors": self._errors, + } + + def reset(self) -> None: + with self._lock: + self._allowed = 0 + self._denied = 0 + self._deferred = 0 + self._cancelled = 0 + self._errors = 0 diff --git a/py_src/taskito/predicates/outcomes.py b/py_src/taskito/predicates/outcomes.py new file mode 100644 index 0000000..9d207a8 --- /dev/null +++ b/py_src/taskito/predicates/outcomes.py @@ -0,0 +1,38 @@ +"""Sentinel return types for predicate evaluation.""" + +from __future__ import annotations + +from dataclasses import dataclass + + +@dataclass(frozen=True, slots=True) +class Defer: + """Sentinel returned by a predicate to defer the job. + + At enqueue time, ``seconds`` is added to the caller's delay before the + job is saved. At worker-dispatch time, the current job is cancelled and + a fresh job is re-enqueued with ``seconds`` of delay (preserving task + name, queue, arguments, and dispatch metadata). + """ + + seconds: float + + def __post_init__(self) -> None: + if self.seconds < 0: + raise ValueError(f"Defer seconds must be >= 0, got {self.seconds}") + + +@dataclass(frozen=True, slots=True) +class Cancel: + """Sentinel returned by a predicate to terminally skip the job. + + At enqueue time, no row is created. 
At worker-dispatch time, the job is + cancelled and ``PREDICATE_CANCELLED`` is emitted. ``reason`` is included + in the emitted event payload. + """ + + reason: str = "" + + +PredicateOutcome = bool | Defer | Cancel +"""Anything a :meth:`Predicate.evaluate` is allowed to return.""" diff --git a/py_src/taskito/predicates/providers.py b/py_src/taskito/predicates/providers.py new file mode 100644 index 0000000..5617d88 --- /dev/null +++ b/py_src/taskito/predicates/providers.py @@ -0,0 +1,52 @@ +"""Pluggable provider Protocols used by recipe predicates.""" + +from __future__ import annotations + +import os +from typing import TYPE_CHECKING, Protocol, runtime_checkable + +if TYPE_CHECKING: + from taskito.predicates.context import PredicateContext + + +_TRUTHY = frozenset({"1", "true", "t", "yes", "y", "on"}) + + +@runtime_checkable +class FeatureFlagProvider(Protocol): + """Plug a feature-flag system (LaunchDarkly, Statsig, custom) into recipes. + + Implementations return ``True`` if the named flag is enabled for the + given evaluation context. Errors should not propagate — the recipe is + wrapped in fail-closed evaluation, but providers themselves should + handle their own errors and return a sensible default rather than + raising. + """ + + def is_enabled(self, name: str, ctx: PredicateContext, /) -> bool: ... + + +class _EnvFeatureFlagProvider: + """Default provider: reads ``${prefix}{NAME}`` from process env. + + A value of any of ``1 / true / yes / on`` (case-insensitive) is + considered enabled. Anything else, including a missing variable, is + disabled. 
+ """ + + __slots__ = ("_prefix",) + + def __init__(self, prefix: str = "FF_") -> None: + self._prefix = prefix + + def is_enabled(self, name: str, ctx: PredicateContext, /) -> bool: + key = f"{self._prefix}{name.upper()}" + return os.environ.get(key, "").strip().lower() in _TRUTHY + + def __repr__(self) -> str: + return f"EnvFeatureFlagProvider(prefix={self._prefix!r})" + + +def env_feature_flag_provider(prefix: str = "FF_") -> FeatureFlagProvider: + """Build an env-var-backed :class:`FeatureFlagProvider`.""" + return _EnvFeatureFlagProvider(prefix=prefix) diff --git a/py_src/taskito/predicates/recipes/__init__.py b/py_src/taskito/predicates/recipes/__init__.py new file mode 100644 index 0000000..b04c0eb --- /dev/null +++ b/py_src/taskito/predicates/recipes/__init__.py @@ -0,0 +1,59 @@ +"""Predefined predicate recipes. + +Recipes are factory functions that return :class:`~taskito.predicates.Predicate` +instances. Each recipe accepts plain Python values so it can be used in +decorator declarations:: + + @queue.task( + predicate=is_business_hours(tz="US/Pacific") + & ~queue_paused() + | by_priority_at_least(8), + ) + def send_report(): ... 
+""" + +from __future__ import annotations + +from taskito.predicates.recipes.attributes import ( + by_priority_at_least, + by_queue, + by_task, + payload_matches, + retry_count_under, +) +from taskito.predicates.recipes.config import ( + env_var_truthy, + feature_flag, +) +from taskito.predicates.recipes.system import ( + error_rate_under, + queue_paused, + queue_size_under, +) +from taskito.predicates.recipes.time import ( + after, + before, + in_time_window, + in_timezone, + is_business_hours, + is_weekend, +) + +__all__ = [ + "after", + "before", + "by_priority_at_least", + "by_queue", + "by_task", + "env_var_truthy", + "error_rate_under", + "feature_flag", + "in_time_window", + "in_timezone", + "is_business_hours", + "is_weekend", + "payload_matches", + "queue_paused", + "queue_size_under", + "retry_count_under", +] diff --git a/py_src/taskito/predicates/recipes/attributes.py b/py_src/taskito/predicates/recipes/attributes.py new file mode 100644 index 0000000..cd2cd8b --- /dev/null +++ b/py_src/taskito/predicates/recipes/attributes.py @@ -0,0 +1,112 @@ +"""Predicates that read job metadata.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +from taskito.predicates.core import Predicate + +if TYPE_CHECKING: + from taskito.predicates.context import PredicateContext + + +@dataclass(frozen=True) +class _ByQueue(Predicate): + name: str + + def evaluate(self, ctx: PredicateContext) -> bool: + return ctx.queue == self.name + + +def by_queue(name: str) -> Predicate: + """Allow only jobs whose target queue equals ``name``.""" + if not name: + raise ValueError("queue name must be non-empty") + return _ByQueue(name=name) + + +@dataclass(frozen=True) +class _ByTask(Predicate): + name: str + + def evaluate(self, ctx: PredicateContext) -> bool: + return ctx.task_name == self.name + + +def by_task(name: str) -> Predicate: + """Allow only jobs for the given task name (full module-qualified).""" + if not 
name: + raise ValueError("task name must be non-empty") + return _ByTask(name=name) + + +@dataclass(frozen=True) +class _ByPriorityAtLeast(Predicate): + threshold: int + + def evaluate(self, ctx: PredicateContext) -> bool: + return ctx.priority >= self.threshold + + +def by_priority_at_least(threshold: int) -> Predicate: + """Allow jobs whose priority is ``>= threshold``.""" + return _ByPriorityAtLeast(threshold=threshold) + + +@dataclass(frozen=True) +class _RetryCountUnder(Predicate): + limit: int + + def evaluate(self, ctx: PredicateContext) -> bool: + return ctx.retry_count < self.limit + + +def retry_count_under(limit: int) -> Predicate: + """Allow jobs whose retry counter is strictly less than ``limit``.""" + if limit < 0: + raise ValueError("limit must be >= 0") + return _RetryCountUnder(limit=limit) + + +@dataclass(frozen=True) +class _PayloadMatches(Predicate): + path: tuple[str, ...] + expected: Any + + def evaluate(self, ctx: PredicateContext) -> bool: + node: Any = {"args": ctx.args, "kwargs": ctx.kwargs} + for segment in self.path: + node = _safe_lookup(node, segment) + if node is _MISSING: + return False + return bool(node == self.expected) + + +_MISSING = object() + + +def _safe_lookup(node: Any, key: str) -> Any: + if isinstance(node, dict): + return node.get(key, _MISSING) + if isinstance(node, (list, tuple)): + try: + return node[int(key)] + except (ValueError, IndexError): + return _MISSING + return getattr(node, key, _MISSING) + + +def payload_matches(path: str, expected: Any) -> Predicate: + """Match a value in args/kwargs by dotted path. + + ``path`` is a dotted string addressing a value within + ``{"args": (...), "kwargs": {...}}``. 
Examples: + + * ``"kwargs.tenant_id"`` → ``ctx.kwargs["tenant_id"]`` + * ``"args.0"`` → ``ctx.args[0]`` + * ``"kwargs.config.region"`` → ``ctx.kwargs["config"]["region"]`` + """ + if not path: + raise ValueError("path must be non-empty") + return _PayloadMatches(path=tuple(path.split(".")), expected=expected) diff --git a/py_src/taskito/predicates/recipes/config.py b/py_src/taskito/predicates/recipes/config.py new file mode 100644 index 0000000..7c2dd87 --- /dev/null +++ b/py_src/taskito/predicates/recipes/config.py @@ -0,0 +1,55 @@ +"""Predicates that consult external configuration (env vars, feature flags).""" + +from __future__ import annotations + +import os +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from taskito.predicates.core import Predicate +from taskito.predicates.providers import ( + FeatureFlagProvider, + env_feature_flag_provider, +) + +if TYPE_CHECKING: + from taskito.predicates.context import PredicateContext + + +_TRUTHY = frozenset({"1", "true", "t", "yes", "y", "on"}) + + +@dataclass(frozen=True) +class _EnvVarTruthy(Predicate): + name: str + + def evaluate(self, ctx: PredicateContext) -> bool: + return os.environ.get(self.name, "").strip().lower() in _TRUTHY + + +def env_var_truthy(name: str) -> Predicate: + """Allow when env var ``name`` is set to ``1``/``true``/``yes``/``on``.""" + if not name: + raise ValueError("env var name must be non-empty") + return _EnvVarTruthy(name=name) + + +@dataclass(frozen=True) +class _FeatureFlag(Predicate): + flag: str + provider: FeatureFlagProvider + + def evaluate(self, ctx: PredicateContext) -> bool: + return bool(self.provider.is_enabled(self.flag, ctx)) + + +def feature_flag(name: str, *, provider: FeatureFlagProvider | None = None) -> Predicate: + """Allow when feature flag ``name`` is enabled. + + Defaults to an env-var-backed provider with prefix ``FF_`` (e.g. + ``feature_flag("new-billing")`` reads ``FF_NEW-BILLING``). 
Pass a + custom provider to integrate LaunchDarkly, Statsig, etc. + """ + if not name: + raise ValueError("flag name must be non-empty") + return _FeatureFlag(flag=name, provider=provider or env_feature_flag_provider()) diff --git a/py_src/taskito/predicates/recipes/system.py b/py_src/taskito/predicates/recipes/system.py new file mode 100644 index 0000000..0cbc4c0 --- /dev/null +++ b/py_src/taskito/predicates/recipes/system.py @@ -0,0 +1,77 @@ +"""Predicates that read live system state from the queue.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +from taskito.predicates.core import Predicate + +if TYPE_CHECKING: + from taskito.predicates.context import PredicateContext + + +@dataclass(frozen=True) +class _QueueSizeUnder(Predicate): + limit: int + queue_name: str | None + + def evaluate(self, ctx: PredicateContext) -> bool: + return ctx.queue_size(self.queue_name) < self.limit + + +def queue_size_under(limit: int, *, queue: str | None = None) -> Predicate: + """Allow only when the queue has fewer than ``limit`` pending jobs. + + ``queue=None`` (the default) inspects the job's own queue. Pass an + explicit name to gate based on a different queue. + """ + if limit <= 0: + raise ValueError("limit must be > 0") + return _QueueSizeUnder(limit=limit, queue_name=queue) + + +@dataclass(frozen=True) +class _QueuePaused(Predicate): + queue_name: str | None + + def evaluate(self, ctx: PredicateContext) -> bool: + return ctx.queue_paused(self.queue_name) + + +def queue_paused(queue: str | None = None) -> Predicate: + """True when the queue is currently paused. + + Typically used inverted: ``~queue_paused()``. 
+ """ + return _QueuePaused(queue_name=queue) + + +@dataclass(frozen=True) +class _ErrorRateUnder(Predicate): + max_rate: float + + def evaluate(self, ctx: PredicateContext) -> bool: + stats = ctx.stats() + if not stats: + return True + completed = int(stats.get("completed", 0)) + failed = int(stats.get("failed", 0)) + dead = int(stats.get("dead", 0)) + total = completed + failed + dead + if total <= 0: + return True + rate = (failed + dead) / total + return rate < self.max_rate + + +def error_rate_under(max_rate: float) -> Predicate: + """Allow only when the global failure ratio is below ``max_rate``. + + ``max_rate`` is a fraction in (0, 1]. The metric is computed from + backend-wide stats: ``(failed + dead) / (completed + failed + dead)``. + When no jobs have run yet, the predicate allows. + """ + if not 0.0 < max_rate <= 1.0: + raise ValueError("max_rate must be in (0, 1]") + return _ErrorRateUnder(max_rate=max_rate) diff --git a/py_src/taskito/predicates/recipes/time.py b/py_src/taskito/predicates/recipes/time.py new file mode 100644 index 0000000..252cf6f --- /dev/null +++ b/py_src/taskito/predicates/recipes/time.py @@ -0,0 +1,223 @@ +"""Time-based predicates. + +All recipes work with timezone-aware datetimes. ``tz`` arguments accept an +IANA timezone string (e.g. ``"US/Pacific"``); ``None`` means UTC. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime, timedelta, timezone +from typing import TYPE_CHECKING + +from taskito.predicates.core import Predicate +from taskito.predicates.outcomes import Defer + +if TYPE_CHECKING: + from taskito.predicates.context import PredicateContext + +try: + from zoneinfo import ZoneInfo, ZoneInfoNotFoundError + + _HAS_ZONEINFO = True +except ImportError: # pragma: no cover - zoneinfo is stdlib on 3.9+ + _HAS_ZONEINFO = False + + class ZoneInfoNotFoundError(Exception): # type: ignore[no-redef] + pass + + +_ONE_DAY = timedelta(days=1) + + +def _resolve_tz(tz: str | None) -> timezone: + """Resolve a tz string to a ``tzinfo``-compatible object.""" + if tz is None: + return timezone.utc + if not _HAS_ZONEINFO: + raise RuntimeError("zoneinfo is not available; pass tz=None or install Python 3.9+") + try: + return ZoneInfo(tz) # type: ignore[return-value] + except ZoneInfoNotFoundError as exc: + raise ValueError(f"unknown timezone: {tz!r}") from exc + + +def _seconds_until_window(now_local: datetime, start_hour: int, start_minute: int) -> float: + """Seconds from ``now_local`` until the next occurrence of (h, m).""" + target = now_local.replace(hour=start_hour, minute=start_minute, second=0, microsecond=0) + if target <= now_local: + target = target + _ONE_DAY + return max(1.0, (target - now_local).total_seconds()) + + +@dataclass(frozen=True) +class _IsBusinessHours(Predicate): + """True Mon-Fri between 09:00 and 17:00 in the configured timezone.""" + + start_hour: int = 9 + end_hour: int = 17 + tz: str | None = None + weekdays_only: bool = True + + def evaluate(self, ctx: PredicateContext) -> bool | Defer: + tz_info = _resolve_tz(self.tz) + local = ctx.now().astimezone(tz_info) + if self.weekdays_only and local.weekday() >= 5: + # Sat / Sun → defer until Monday at start_hour + days_to_monday = 7 - local.weekday() + defer = local.replace( + hour=self.start_hour, minute=0, 
second=0, microsecond=0 + ) + timedelta(days=days_to_monday) + return Defer(seconds=max(1.0, (defer - local).total_seconds())) + if local.hour < self.start_hour: + return Defer(seconds=_seconds_until_window(local, self.start_hour, 0)) + if local.hour >= self.end_hour: + # Defer to next day at start_hour + next_day = local + _ONE_DAY + target = next_day.replace(hour=self.start_hour, minute=0, second=0, microsecond=0) + return Defer(seconds=max(1.0, (target - local).total_seconds())) + return True + + +def is_business_hours( + start_hour: int = 9, + end_hour: int = 17, + *, + tz: str | None = None, + weekdays_only: bool = True, +) -> Predicate: + """Allow only during business hours; otherwise :class:`Defer` to the next window. + + Defaults to 09:00-17:00 Mon-Fri in UTC. Pass ``tz="US/Pacific"`` (or + any IANA name) for a local-time window. + """ + if not 0 <= start_hour < 24 or not 0 < end_hour <= 24 or start_hour >= end_hour: + raise ValueError(f"invalid business-hours window: {start_hour}-{end_hour}") + return _IsBusinessHours( + start_hour=start_hour, end_hour=end_hour, tz=tz, weekdays_only=weekdays_only + ) + + +@dataclass(frozen=True) +class _IsWeekend(Predicate): + tz: str | None = None + + def evaluate(self, ctx: PredicateContext) -> bool: + tz_info = _resolve_tz(self.tz) + local = ctx.now().astimezone(tz_info) + return local.weekday() >= 5 + + +def is_weekend(*, tz: str | None = None) -> Predicate: + """True on Saturday or Sunday in the configured timezone.""" + return _IsWeekend(tz=tz) + + +@dataclass(frozen=True) +class _InTimeWindow(Predicate): + start_hour: int + start_minute: int + end_hour: int + end_minute: int + tz: str | None = None + + def evaluate(self, ctx: PredicateContext) -> bool | Defer: + tz_info = _resolve_tz(self.tz) + local = ctx.now().astimezone(tz_info) + start_minutes = self.start_hour * 60 + self.start_minute + end_minutes = self.end_hour * 60 + self.end_minute + cur_minutes = local.hour * 60 + local.minute + if start_minutes <= 
cur_minutes < end_minutes: + return True + if cur_minutes < start_minutes: + return Defer(seconds=(start_minutes - cur_minutes) * 60.0) + # past end — defer to tomorrow's start + next_day = local + _ONE_DAY + target = next_day.replace( + hour=self.start_hour, minute=self.start_minute, second=0, microsecond=0 + ) + return Defer(seconds=max(1.0, (target - local).total_seconds())) + + +def in_time_window(start: str, end: str, *, tz: str | None = None) -> Predicate: + """Allow during ``[start, end)`` in the configured timezone; defer otherwise. + + ``start`` and ``end`` are ``"HH:MM"`` strings. End is exclusive. + """ + sh, sm = _parse_hhmm(start) + eh, em = _parse_hhmm(end) + if (sh, sm) >= (eh, em): + raise ValueError(f"start ({start}) must be before end ({end})") + return _InTimeWindow(start_hour=sh, start_minute=sm, end_hour=eh, end_minute=em, tz=tz) + + +def _parse_hhmm(value: str) -> tuple[int, int]: + parts = value.split(":") + if len(parts) != 2: + raise ValueError(f"expected 'HH:MM', got {value!r}") + try: + h, m = int(parts[0]), int(parts[1]) + except ValueError as exc: + raise ValueError(f"expected 'HH:MM', got {value!r}") from exc + if not 0 <= h <= 23 or not 0 <= m <= 59: + raise ValueError(f"hour/minute out of range: {value!r}") + return h, m + + +@dataclass(frozen=True) +class _After(Predicate): + target: datetime + + def evaluate(self, ctx: PredicateContext) -> bool | Defer: + now = ctx.now() + if now >= self.target: + return True + return Defer(seconds=(self.target - now).total_seconds()) + + +def after(target: datetime) -> Predicate: + """Allow only when wall-clock time is at or past ``target``.""" + if target.tzinfo is None: + target = target.replace(tzinfo=timezone.utc) + return _After(target=target) + + +@dataclass(frozen=True) +class _Before(Predicate): + target: datetime + + def evaluate(self, ctx: PredicateContext) -> bool: + return ctx.now() < self.target + + +def before(target: datetime) -> Predicate: + """Allow only when wall-clock time is 
strictly before ``target``.""" + if target.tzinfo is None: + target = target.replace(tzinfo=timezone.utc) + return _Before(target=target) + + +@dataclass(frozen=True) +class _InTimezone(Predicate): + """Validates that the recipe-configured timezone resolves on this host. + + Returns ``True`` unconditionally — useful as a composable building + block when paired with other time predicates. The cheaper alternative + is to call :func:`_resolve_tz` at recipe-construction time, which is + what this does. + """ + + tz: str + + def evaluate(self, ctx: PredicateContext) -> bool: + return True + + +def in_timezone(tz: str) -> Predicate: + """No-op predicate that validates ``tz`` resolves on this host. + + Mostly useful as a guard at decoration time: passing an unknown tz + here raises immediately rather than at first execution. + """ + _resolve_tz(tz) + return _InTimezone(tz=tz) diff --git a/tests/python/test_predicates_core.py b/tests/python/test_predicates_core.py new file mode 100644 index 0000000..9056f2a --- /dev/null +++ b/tests/python/test_predicates_core.py @@ -0,0 +1,227 @@ +"""Core predicate algebra and evaluation tests.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +import pytest + +from taskito.predicates import ( + Cancel, + Defer, + Predicate, + PredicateContext, + PredicateMetrics, + coerce_predicate, + evaluate_predicate, +) + + +@dataclass(frozen=True) +class _Const(Predicate): + value: Any + + def evaluate(self, ctx: PredicateContext) -> Any: + return self.value + + +def _ctx() -> PredicateContext: + return PredicateContext(task_name="t", queue="default") + + +# -- Outcome sentinels ------------------------------------------------------- + + +def test_defer_rejects_negative_seconds() -> None: + with pytest.raises(ValueError): + Defer(seconds=-1) + + +def test_defer_and_cancel_are_frozen() -> None: + d = Defer(seconds=10) + c = Cancel(reason="x") + with pytest.raises(AttributeError): + d.seconds = 1 # 
type: ignore[misc] + with pytest.raises(AttributeError): + c.reason = "y" # type: ignore[misc] + + +# -- Composition operators --------------------------------------------------- + + +def test_and_short_circuits_on_false() -> None: + calls: list[str] = [] + + @dataclass(frozen=True) + class _Track(Predicate): + label: str + value: bool + + def evaluate(self, ctx: PredicateContext) -> bool: + calls.append(self.label) + return self.value + + combined = _Track(label="left", value=False) & _Track(label="right", value=True) + assert evaluate_predicate(combined, _ctx()) is False + assert calls == ["left"] + + +def test_or_short_circuits_on_true() -> None: + calls: list[str] = [] + + @dataclass(frozen=True) + class _Track(Predicate): + label: str + value: bool + + def evaluate(self, ctx: PredicateContext) -> bool: + calls.append(self.label) + return self.value + + combined = _Track(label="left", value=True) | _Track(label="right", value=False) + assert evaluate_predicate(combined, _ctx()) is True + assert calls == ["left"] + + +def test_not_inverts_boolean() -> None: + assert evaluate_predicate(~_Const(True), _ctx()) is False + assert evaluate_predicate(~_Const(False), _ctx()) is True + + +def test_not_passes_defer_and_cancel_through() -> None: + d = Defer(seconds=30) + c = Cancel(reason="bad") + assert evaluate_predicate(~_Const(d), _ctx()) == d + assert evaluate_predicate(~_Const(c), _ctx()) == c + + +def test_and_propagates_defer_from_left() -> None: + d = Defer(seconds=30) + combined = _Const(d) & _Const(True) + assert evaluate_predicate(combined, _ctx()) == d + + +def test_and_passes_through_when_left_true() -> None: + d = Defer(seconds=30) + combined = _Const(True) & _Const(d) + assert evaluate_predicate(combined, _ctx()) == d + + +def test_or_prefers_cancel_when_both_deny() -> None: + d = Defer(seconds=5) + c = Cancel(reason="permanent") + combined = _Const(d) | _Const(c) + assert evaluate_predicate(combined, _ctx()) == c + + +def test_or_prefers_defer_over_false() 
-> None: + d = Defer(seconds=5) + combined = _Const(False) | _Const(d) + assert evaluate_predicate(combined, _ctx()) == d + + +# -- repr stability --------------------------------------------------------- + + +def test_repr_uses_operator_syntax() -> None: + expr = _Const(True) & ~_Const(False) | _Const(False) + text = repr(expr) + assert "&" in text + assert "|" in text + assert "~" in text + + +# -- Fail-closed ----------------------------------------------------------- + + +def test_evaluate_returns_false_on_exception() -> None: + @dataclass(frozen=True) + class _Boom(Predicate): + def evaluate(self, ctx: PredicateContext) -> bool: + raise RuntimeError("intentional") + + metrics = PredicateMetrics() + assert evaluate_predicate(_Boom(), _ctx(), metrics=metrics) is False + snap = metrics.snapshot() + assert snap["errors"] == 1 + assert snap["denied"] == 1 + + +def test_metrics_count_outcomes() -> None: + metrics = PredicateMetrics() + evaluate_predicate(_Const(True), _ctx(), metrics=metrics) + evaluate_predicate(_Const(False), _ctx(), metrics=metrics) + evaluate_predicate(_Const(Defer(seconds=1)), _ctx(), metrics=metrics) + evaluate_predicate(_Const(Cancel(reason="x")), _ctx(), metrics=metrics) + snap = metrics.snapshot() + assert snap == { + "allowed": 1, + "denied": 1, + "deferred": 1, + "cancelled": 1, + "errors": 0, + } + + +def test_metrics_reset() -> None: + metrics = PredicateMetrics() + evaluate_predicate(_Const(True), _ctx(), metrics=metrics) + metrics.reset() + assert metrics.snapshot()["allowed"] == 0 + + +# -- coerce_predicate ------------------------------------------------------- + + +def test_coerce_passes_through_predicate() -> None: + p = _Const(True) + assert coerce_predicate(p) is p + + +def test_coerce_wraps_callable_with_context() -> None: + def fn(ctx: PredicateContext) -> bool: + return ctx.task_name == "t" + + wrapped = coerce_predicate(fn) + assert wrapped is not None + assert evaluate_predicate(wrapped, _ctx()) is True + + +def 
test_coerce_wraps_str_callable_for_backcompat() -> None: + def fn(name: str) -> bool: + return name == "t" + + wrapped = coerce_predicate(fn, str_callable=True) + assert wrapped is not None + assert evaluate_predicate(wrapped, _ctx()) is True + + +def test_coerce_none_returns_none() -> None: + assert coerce_predicate(None) is None + + +def test_coerce_rejects_non_callable() -> None: + with pytest.raises(TypeError): + coerce_predicate(42) # type: ignore[arg-type] + + +# -- Async predicate is awaited --------------------------------------------- + + +def test_async_predicate_is_run_to_completion() -> None: + class _Async(Predicate): + async def evaluate(self, ctx: PredicateContext) -> bool: + return True + + assert evaluate_predicate(_Async(), _ctx()) is True + + +def test_async_predicate_can_return_defer() -> None: + d = Defer(seconds=42) + + class _Async(Predicate): + async def evaluate(self, ctx: PredicateContext) -> Defer: + return d + + assert evaluate_predicate(_Async(), _ctx()) == d diff --git a/tests/python/test_predicates_recipes.py b/tests/python/test_predicates_recipes.py new file mode 100644 index 0000000..2bfdf9b --- /dev/null +++ b/tests/python/test_predicates_recipes.py @@ -0,0 +1,290 @@ +"""Built-in predicate recipe tests.""" + +from __future__ import annotations + +import os +from datetime import datetime, timedelta, timezone +from unittest import mock + +import pytest + +from taskito.predicates import ( + Defer, + PredicateContext, + after, + before, + by_priority_at_least, + by_queue, + by_task, + env_var_truthy, + feature_flag, + in_time_window, + in_timezone, + is_business_hours, + is_weekend, + payload_matches, + queue_paused, + queue_size_under, + retry_count_under, +) +from taskito.predicates.providers import FeatureFlagProvider + + +def _ctx(**overrides: object) -> PredicateContext: + defaults: dict[str, object] = {"task_name": "t", "queue": "default"} + defaults.update(overrides) + return PredicateContext(**defaults) # type: ignore[arg-type] 
+ + +# -- Time recipes ----------------------------------------------------------- + + +def test_after_allows_when_past_target() -> None: + target = datetime.now(timezone.utc) - timedelta(hours=1) + assert after(target).evaluate(_ctx()) is True + + +def test_after_defers_when_before_target() -> None: + target = datetime.now(timezone.utc) + timedelta(hours=1) + outcome = after(target).evaluate(_ctx()) + assert isinstance(outcome, Defer) + assert 3000 < outcome.seconds < 3700 + + +def test_before_allows_when_strictly_earlier() -> None: + target = datetime.now(timezone.utc) + timedelta(hours=1) + assert before(target).evaluate(_ctx()) is True + + +def test_before_denies_when_past_target() -> None: + target = datetime.now(timezone.utc) - timedelta(hours=1) + assert before(target).evaluate(_ctx()) is False + + +def test_in_time_window_parses_hh_mm() -> None: + assert in_time_window("09:00", "17:00") is not None + + +def test_in_time_window_rejects_inverted_range() -> None: + with pytest.raises(ValueError): + in_time_window("17:00", "09:00") + + +def test_in_time_window_rejects_bad_format() -> None: + with pytest.raises(ValueError): + in_time_window("nine", "five") + + +def test_in_time_window_returns_defer_when_outside() -> None: + pred = in_time_window("09:00", "17:00") + fake_now = datetime(2026, 5, 11, 8, 0, tzinfo=timezone.utc) + ctx = _ctx() + with mock.patch.object(ctx, "now", return_value=fake_now): + outcome = pred.evaluate(ctx) + assert isinstance(outcome, Defer) + assert outcome.seconds == 3600.0 + + +def test_is_business_hours_returns_defer_on_weekend() -> None: + pred = is_business_hours(tz=None) + # 2026-05-09 is a Saturday + sat = datetime(2026, 5, 9, 12, 0, tzinfo=timezone.utc) + ctx = _ctx() + with mock.patch.object(ctx, "now", return_value=sat): + outcome = pred.evaluate(ctx) + assert isinstance(outcome, Defer) + assert outcome.seconds > 0 + + +def test_is_business_hours_allows_during_window() -> None: + pred = is_business_hours(tz=None) + # 2026-05-11 
is a Monday + weekday_noon = datetime(2026, 5, 11, 12, 0, tzinfo=timezone.utc) + ctx = _ctx() + with mock.patch.object(ctx, "now", return_value=weekday_noon): + assert pred.evaluate(ctx) is True + + +def test_is_business_hours_rejects_invalid_window() -> None: + with pytest.raises(ValueError): + is_business_hours(start_hour=20, end_hour=10) + + +def test_is_weekend_uses_utc_by_default() -> None: + pred = is_weekend() + sat = datetime(2026, 5, 9, 12, 0, tzinfo=timezone.utc) + weekday = datetime(2026, 5, 11, 12, 0, tzinfo=timezone.utc) + ctx = _ctx() + with mock.patch.object(ctx, "now", return_value=sat): + assert pred.evaluate(ctx) is True + with mock.patch.object(ctx, "now", return_value=weekday): + assert pred.evaluate(ctx) is False + + +def test_in_timezone_validates_at_construction() -> None: + in_timezone("UTC") + with pytest.raises(ValueError): + in_timezone("Mars/Olympus_Mons") + + +# -- Attribute recipes ------------------------------------------------------ + + +def test_by_queue_matches_name() -> None: + assert by_queue("default").evaluate(_ctx()) is True + assert by_queue("other").evaluate(_ctx()) is False + + +def test_by_queue_rejects_empty() -> None: + with pytest.raises(ValueError): + by_queue("") + + +def test_by_task_matches_name() -> None: + assert by_task("t").evaluate(_ctx()) is True + assert by_task("other").evaluate(_ctx()) is False + + +def test_by_priority_at_least() -> None: + pred = by_priority_at_least(5) + assert pred.evaluate(_ctx(priority=10)) is True + assert pred.evaluate(_ctx(priority=5)) is True + assert pred.evaluate(_ctx(priority=4)) is False + + +def test_retry_count_under_validates() -> None: + with pytest.raises(ValueError): + retry_count_under(-1) + + +def test_retry_count_under_compares_strict() -> None: + pred = retry_count_under(3) + assert pred.evaluate(_ctx(retry_count=2)) is True + assert pred.evaluate(_ctx(retry_count=3)) is False + + +def test_payload_matches_kwargs() -> None: + pred = payload_matches("kwargs.tenant", 
"acme") + assert pred.evaluate(_ctx(kwargs={"tenant": "acme"})) is True + assert pred.evaluate(_ctx(kwargs={"tenant": "other"})) is False + assert pred.evaluate(_ctx(kwargs={})) is False + + +def test_payload_matches_args_by_index() -> None: + pred = payload_matches("args.0", "x") + assert pred.evaluate(_ctx(args=("x", "y"))) is True + assert pred.evaluate(_ctx(args=("y",))) is False + + +def test_payload_matches_nested_dict() -> None: + pred = payload_matches("kwargs.config.region", "us-east") + assert pred.evaluate(_ctx(kwargs={"config": {"region": "us-east"}})) is True + assert pred.evaluate(_ctx(kwargs={"config": {"region": "eu-west"}})) is False + + +def test_payload_matches_rejects_empty_path() -> None: + with pytest.raises(ValueError): + payload_matches("", "x") + + +# -- System-state recipes --------------------------------------------------- + + +def test_queue_size_under_uses_context_helper() -> None: + pred = queue_size_under(100) + ctx = _ctx() + with mock.patch.object(ctx, "queue_size", return_value=50): + assert pred.evaluate(ctx) is True + with mock.patch.object(ctx, "queue_size", return_value=100): + assert pred.evaluate(ctx) is False + + +def test_queue_paused_uses_context_helper() -> None: + pred = queue_paused() + ctx = _ctx() + with mock.patch.object(ctx, "queue_paused", return_value=True): + assert pred.evaluate(ctx) is True + with mock.patch.object(ctx, "queue_paused", return_value=False): + assert pred.evaluate(ctx) is False + + +def test_queue_size_under_rejects_zero() -> None: + with pytest.raises(ValueError): + queue_size_under(0) + + +def test_error_rate_under_allows_with_no_jobs() -> None: + pred = pred = __import__("taskito.predicates", fromlist=["error_rate_under"]).error_rate_under( + 0.5 + ) + ctx = _ctx() + with mock.patch.object(ctx, "stats", return_value={}): + assert pred.evaluate(ctx) is True + + +def test_error_rate_under_compares_ratio() -> None: + from taskito.predicates import error_rate_under + + pred = 
error_rate_under(0.2) + ctx = _ctx() + with mock.patch.object(ctx, "stats", return_value={"completed": 90, "failed": 5, "dead": 5}): + # rate = 10/100 = 0.1 < 0.2 + assert pred.evaluate(ctx) is True + with mock.patch.object(ctx, "stats", return_value={"completed": 70, "failed": 20, "dead": 10}): + # rate = 30/100 = 0.3 > 0.2 + assert pred.evaluate(ctx) is False + + +def test_error_rate_validates_range() -> None: + from taskito.predicates import error_rate_under + + with pytest.raises(ValueError): + error_rate_under(0.0) + with pytest.raises(ValueError): + error_rate_under(1.5) + + +# -- Config recipes --------------------------------------------------------- + + +def test_env_var_truthy_reads_env() -> None: + pred = env_var_truthy("MY_FLAG") + with mock.patch.dict(os.environ, {"MY_FLAG": "1"}): + assert pred.evaluate(_ctx()) is True + with mock.patch.dict(os.environ, {"MY_FLAG": "no"}): + assert pred.evaluate(_ctx()) is False + with mock.patch.dict(os.environ, {}, clear=True): + assert pred.evaluate(_ctx()) is False + + +def test_env_var_truthy_rejects_empty_name() -> None: + with pytest.raises(ValueError): + env_var_truthy("") + + +def test_feature_flag_default_provider_reads_ff_prefix() -> None: + pred = feature_flag("billing") + with mock.patch.dict(os.environ, {"FF_BILLING": "true"}): + assert pred.evaluate(_ctx()) is True + with mock.patch.dict(os.environ, {}, clear=True): + assert pred.evaluate(_ctx()) is False + + +def test_feature_flag_custom_provider() -> None: + class _Stub: + def __init__(self) -> None: + self.calls: list[tuple[str, str]] = [] + + def is_enabled(self, name: str, ctx: PredicateContext) -> bool: + self.calls.append((name, ctx.task_name)) + return name == "yes" + + stub: FeatureFlagProvider = _Stub() + assert feature_flag("yes", provider=stub).evaluate(_ctx()) is True + assert feature_flag("no", provider=stub).evaluate(_ctx()) is False + assert stub.calls == [("yes", "t"), ("no", "t")] # type: ignore[attr-defined] + + +def 
test_feature_flag_rejects_empty_name() -> None: + with pytest.raises(ValueError): + feature_flag("") From 331d50109f63abe547927f6ee30b19e19f3b5591 Mon Sep 17 00:00:00 2001 From: Pratyush Sharma <56130065+pratyush618@users.noreply.github.com> Date: Mon, 11 May 2026 09:34:32 +0530 Subject: [PATCH 02/11] feat(predicates): wire enqueue-time gating and @task kwargs --- py_src/taskito/__init__.py | 2 + py_src/taskito/app.py | 90 ++++++++++++++ py_src/taskito/exceptions.py | 17 +++ py_src/taskito/mixins/decorators.py | 39 ++++++ tests/python/test_predicates_enqueue.py | 151 ++++++++++++++++++++++++ 5 files changed, 299 insertions(+) create mode 100644 tests/python/test_predicates_enqueue.py diff --git a/py_src/taskito/__init__.py b/py_src/taskito/__init__.py index cbdf5fd..a1056f7 100644 --- a/py_src/taskito/__init__.py +++ b/py_src/taskito/__init__.py @@ -9,6 +9,7 @@ CircularDependencyError, JobNotFoundError, MaxRetriesExceededError, + PredicateRejectedError, ProxyCleanupError, ProxyReconstructionError, QueueError, @@ -56,6 +57,7 @@ "MockResource", "MsgPackSerializer", "NoProxy", + "PredicateRejectedError", "ProxyCleanupError", "ProxyReconstructionError", "Queue", diff --git a/py_src/taskito/app.py b/py_src/taskito/app.py index d5b06c5..5476340 100644 --- a/py_src/taskito/app.py +++ b/py_src/taskito/app.py @@ -28,6 +28,7 @@ from taskito._taskito import PyQueue from taskito.async_support.mixins import AsyncQueueMixin from taskito.events import EventBus, EventType +from taskito.exceptions import PredicateRejectedError from taskito.interception import ArgumentInterceptor from taskito.interception.built_in import build_default_registry from taskito.interception.metrics import InterceptionMetrics @@ -42,6 +43,10 @@ QueueResourceMixin, QueueSettingsMixin, ) +from taskito.predicates.context import PredicateContext +from taskito.predicates.evaluate import evaluate_predicate +from taskito.predicates.metrics import PredicateMetrics +from taskito.predicates.outcomes import Cancel, 
Defer from taskito.proxies import ProxyRegistry from taskito.proxies.built_in import register_builtin_handlers from taskito.proxies.metrics import ProxyMetrics @@ -214,6 +219,11 @@ def __init__( self._global_middleware: list[TaskMiddleware] = middleware or [] self._task_middleware: dict[str, list[TaskMiddleware]] = {} self._task_retry_filters: dict[str, dict[str, list[type[Exception]]]] = {} + self._task_predicates: dict[str, Any] = {} + self._task_predicate_on_false: dict[str, str] = {} + self._task_predicate_extras: dict[str, dict[str, Any]] = {} + self._task_default_defer: dict[str, float] = {} + self._predicate_metrics = PredicateMetrics() self._drain_timeout = drain_timeout self._queue_configs: dict[str, dict[str, Any]] = {} self._event_bus = EventBus(max_workers=event_workers) @@ -301,6 +311,53 @@ def _deserialize_payload(self, task_name: str, payload: bytes) -> tuple: """Deserialize a job payload using the per-task or queue-level serializer.""" return self._get_serializer(task_name).loads(payload) # type: ignore[no-any-return] + def _apply_enqueue_predicate( + self, + *, + predicate: Any, + task_name: str, + queue_name: str, + priority: int | None, + args: tuple, + kwargs: dict, + payload_size: int, + delay: float | None, + ) -> float | None: + """Evaluate an enqueue-time predicate; return adjusted ``delay``. + + Raises :class:`~taskito.exceptions.PredicateRejectedError` when the + outcome is a :class:`~taskito.predicates.Cancel`, or a plain + ``False`` paired with ``on_false="cancel"``. Returns the (possibly + bumped) ``delay`` for the caller to pass through to the Rust + ``enqueue``. 
+ """ + ctx = PredicateContext.for_enqueue( + task_name=task_name, + queue=queue_name, + priority=priority, + args=tuple(args), + kwargs=dict(kwargs), + payload_size=payload_size, + delay_seconds=delay, + extras=self._task_predicate_extras.get(task_name), + queue_ref=self, + ) + outcome = evaluate_predicate(predicate, ctx, metrics=self._predicate_metrics) + + if outcome is True: + return delay + if isinstance(outcome, Defer): + return (delay or 0.0) + outcome.seconds + if isinstance(outcome, Cancel): + raise PredicateRejectedError(task_name, outcome.reason) + + # Plain False — branch on the task's on_false setting. + action = self._task_predicate_on_false.get(task_name, "defer") + if action == "cancel": + raise PredicateRejectedError(task_name) + # defer + return (delay or 0.0) + self._task_default_defer.get(task_name, 60.0) + def enqueue( self, task_name: str, @@ -387,6 +444,22 @@ def enqueue( task_serializer = self._get_serializer(task_name) payload = task_serializer.dumps((final_args, final_kwargs)) + # Evaluate enqueue-time predicate (if registered). Outcome may + # adjust the delay (Defer / False+defer), raise (Cancel / + # False+cancel), or pass through unchanged. + predicate = self._task_predicates.get(task_name) + if predicate is not None: + delay = self._apply_enqueue_predicate( + predicate=predicate, + task_name=task_name, + queue_name=queue or "default", + priority=priority, + args=final_args, + kwargs=final_kwargs, + payload_size=len(payload), + delay=delay, + ) + unique_key = self._resolve_unique_key( task_name=task_name, payload=payload, @@ -552,6 +625,23 @@ def enqueue_many( for i in range(count) ] + # Evaluate enqueue-time predicate per row. Cancel raises for the + # whole batch — all-or-nothing semantics. Defer adjusts that row's + # delay only. 
class PredicateRejectedError(TaskitoError):
    """Raised when an enqueue-time predicate cancels the submission.

    Attributes:
        task_name: Name of the task whose enqueue was rejected.
        reason: Message carried by :class:`~taskito.predicates.Cancel`,
            or an empty string when a plain ``False`` outcome triggered
            cancellation under ``on_false="cancel"``.
    """

    def __init__(self, task_name: str, reason: str = "") -> None:
        self.task_name = task_name
        self.reason = reason
        suffix = f": {reason}" if reason else ""
        super().__init__(f"predicate rejected enqueue of {task_name!r}{suffix}")
Callable[[Callable[..., Any]], TaskWrapper]: """Decorator to register a function as a background task. @@ -244,7 +256,24 @@ def task( ``idempotency_key="..."`` overrides the derived key; per-call ``idempotent=False`` disables auto-derivation for that one submission. + predicate: A :class:`~taskito.predicates.Predicate` (or plain + callable receiving a :class:`~taskito.predicates.PredicateContext`) + evaluated both at enqueue time and at worker-dispatch time + to decide whether the job runs. + on_false: What to do when the predicate returns ``False`` — + ``"defer"`` (re-schedule with ``default_defer_seconds``), + ``"cancel"`` (terminally skip). + predicate_extras: Arbitrary dict forwarded to the predicate via + ``PredicateContext.extras``. Useful for passing static + config without re-instantiating the predicate. + default_defer_seconds: Default delay when ``on_false="defer"`` + and the predicate returns plain ``False`` (no explicit + ``Defer(seconds=...)``). Ignored otherwise. """ + if on_false not in {"defer", "cancel"}: + raise ValueError(f"on_false must be 'defer' or 'cancel', got {on_false!r}") + if default_defer_seconds < 0: + raise ValueError("default_defer_seconds must be >= 0") def decorator(fn: Callable) -> TaskWrapper: task_name = name or f"{_resolve_module_name(fn.__module__)}.{fn.__qualname__}" @@ -293,6 +322,16 @@ def decorator(fn: Callable) -> TaskWrapper: if idempotent: self._task_idempotent[task_name] = True + # Store predicate (and its on_false/extras/default_defer) + if predicate is not None: + coerced = coerce_predicate(predicate) + if coerced is not None: + self._task_predicates[task_name] = coerced + self._task_predicate_on_false[task_name] = on_false + if predicate_extras: + self._task_predicate_extras[task_name] = dict(predicate_extras) + self._task_default_defer[task_name] = default_defer_seconds + # Store inject map for resource injection if final_inject: self._task_inject_map[task_name] = final_inject diff --git 
a/tests/python/test_predicates_enqueue.py b/tests/python/test_predicates_enqueue.py new file mode 100644 index 0000000..99d835e --- /dev/null +++ b/tests/python/test_predicates_enqueue.py @@ -0,0 +1,151 @@ +"""Enqueue-time predicate gating tests.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +import pytest + +from taskito.app import Queue +from taskito.exceptions import PredicateRejectedError +from taskito.predicates import ( + Cancel, + Defer, + Predicate, + PredicateContext, +) + + +@dataclass(frozen=True) +class _Const(Predicate): + value: Any + + def evaluate(self, ctx: PredicateContext) -> Any: + return self.value + + +def test_true_predicate_lets_job_through(queue: Queue) -> None: + @queue.task(predicate=_Const(True)) + def t() -> str: + return "ran" + + job = t.delay() + assert job.id + + +def test_false_predicate_defers_by_default(queue: Queue) -> None: + @queue.task(predicate=_Const(False), default_defer_seconds=120.0) + def t() -> str: + return "ran" + + job = t.delay() + # Job is enqueued with a future scheduled_at — we cannot directly read + # the delay from the JobResult, but we can confirm the job exists + # and is pending (not cancelled). 
+ job.refresh() + assert job.status == "pending" + + +def test_false_predicate_cancel_action_raises(queue: Queue) -> None: + @queue.task(predicate=_Const(False), on_false="cancel") + def t() -> str: + return "ran" + + with pytest.raises(PredicateRejectedError) as excinfo: + t.delay() + assert excinfo.value.task_name.endswith("t") + + +def test_cancel_outcome_always_raises(queue: Queue) -> None: + @queue.task(predicate=_Const(Cancel(reason="bad tenant"))) + def t() -> str: + return "ran" + + with pytest.raises(PredicateRejectedError) as excinfo: + t.delay() + assert excinfo.value.reason == "bad tenant" + + +def test_defer_outcome_adds_to_caller_delay(queue: Queue) -> None: + @queue.task(predicate=_Const(Defer(seconds=300.0))) + def t() -> str: + return "ran" + + job = t.delay() + job.refresh() + assert job.status == "pending" + + +def test_on_false_rejects_invalid_value(queue: Queue) -> None: + with pytest.raises(ValueError): + + @queue.task(predicate=_Const(False), on_false="invalid") + def t() -> None: ... + + +def test_default_defer_seconds_must_be_non_negative(queue: Queue) -> None: + with pytest.raises(ValueError): + + @queue.task(predicate=_Const(False), default_defer_seconds=-1.0) + def t() -> None: ... 
+ + +def test_predicate_extras_reach_context(queue: Queue) -> None: + seen: dict[str, object] = {} + + class _Probe(Predicate): + def evaluate(self, ctx: PredicateContext) -> bool: + seen.update(ctx.extras) + return True + + @queue.task(predicate=_Probe(), predicate_extras={"tenant": "acme"}) + def t() -> str: + return "ran" + + t.delay() + assert seen == {"tenant": "acme"} + + +def test_plain_callable_is_accepted_as_predicate(queue: Queue) -> None: + @queue.task(predicate=lambda ctx: ctx.task_name.endswith("t")) + def t() -> str: + return "ran" + + job = t.delay() + assert job.id + + +def test_enqueue_many_cancels_entire_batch(queue: Queue) -> None: + @queue.task(predicate=_Const(Cancel(reason="nope"))) + def t(x: int) -> int: + return x + + with pytest.raises(PredicateRejectedError): + queue.enqueue_many(t.name, [(1,), (2,), (3,)]) + + +def test_enqueue_many_with_true_predicate(queue: Queue) -> None: + @queue.task(predicate=_Const(True)) + def t(x: int) -> int: + return x + + results = queue.enqueue_many(t.name, [(1,), (2,), (3,)]) + assert len(results) == 3 + + +def test_metrics_record_outcomes(queue: Queue) -> None: + @queue.task(predicate=_Const(True)) + def allowed() -> int: + return 1 + + @queue.task(predicate=_Const(Defer(seconds=10.0))) + def deferred() -> int: + return 1 + + allowed.delay() + allowed.delay() + deferred.delay() + snap = queue._predicate_metrics.snapshot() + assert snap["allowed"] == 2 + assert snap["deferred"] == 1 From 55035aea2bd18edcad46422509454e8482c15cb0 Mon Sep 17 00:00:00 2001 From: Pratyush Sharma <56130065+pratyush618@users.noreply.github.com> Date: Mon, 11 May 2026 09:37:56 +0530 Subject: [PATCH 03/11] feat(predicates): wire worker-dispatch gating sync + async --- py_src/taskito/app.py | 100 +++++++++++- py_src/taskito/async_support/executor.py | 11 ++ py_src/taskito/mixins/decorators.py | 17 ++ tests/python/test_predicates_worker.py | 194 +++++++++++++++++++++++ 4 files changed, 321 insertions(+), 1 deletion(-) create mode 
100644 tests/python/test_predicates_worker.py diff --git a/py_src/taskito/app.py b/py_src/taskito/app.py index 5476340..8217fb3 100644 --- a/py_src/taskito/app.py +++ b/py_src/taskito/app.py @@ -28,7 +28,7 @@ from taskito._taskito import PyQueue from taskito.async_support.mixins import AsyncQueueMixin from taskito.events import EventBus, EventType -from taskito.exceptions import PredicateRejectedError +from taskito.exceptions import PredicateRejectedError, TaskCancelledError from taskito.interception import ArgumentInterceptor from taskito.interception.built_in import build_default_registry from taskito.interception.metrics import InterceptionMetrics @@ -311,6 +311,104 @@ def _deserialize_payload(self, task_name: str, payload: bytes) -> tuple: """Deserialize a job payload using the per-task or queue-level serializer.""" return self._get_serializer(task_name).loads(payload) # type: ignore[no-any-return] + def _apply_dispatch_predicate( + self, + *, + task_name: str, + args: tuple, + kwargs: dict, + job_id: str, + queue_name: str, + priority: int = 0, + retry_count: int = 0, + ) -> None: + """Evaluate the worker-dispatch predicate; defer or cancel as needed. + + ``Defer`` (or ``False`` with ``on_false="defer"``) re-enqueues a + fresh job with the same payload and a delay, then raises + :class:`TaskCancelledError` so the current execution is marked + cancelled by the Rust runner. ``Cancel`` (or ``False`` with + ``on_false="cancel"``) raises :class:`TaskCancelledError` directly + without re-enqueueing. + + Returns silently when the predicate allows. 
+ """ + predicate = self._task_predicates.get(task_name) + if predicate is None: + return + + ctx = PredicateContext.for_dispatch( + task_name=task_name, + queue=queue_name, + priority=priority, + retry_count=retry_count, + args=tuple(args), + kwargs=dict(kwargs), + job_id=job_id, + payload_size=0, + extras=self._task_predicate_extras.get(task_name), + queue_ref=self, + ) + outcome = evaluate_predicate(predicate, ctx, metrics=self._predicate_metrics) + + if outcome is True: + return + + if isinstance(outcome, Cancel): + raise TaskCancelledError( + f"predicate cancelled job {job_id}: {outcome.reason}" + if outcome.reason + else f"predicate cancelled job {job_id}" + ) + + if isinstance(outcome, Defer): + self._reenqueue_after_defer( + task_name=task_name, + args=args, + kwargs=kwargs, + queue_name=queue_name, + delay_seconds=outcome.seconds, + ) + raise TaskCancelledError(f"predicate deferred job {job_id} by {outcome.seconds:.1f}s") + + # Plain False — branch on on_false. + action = self._task_predicate_on_false.get(task_name, "defer") + if action == "cancel": + raise TaskCancelledError(f"predicate rejected job {job_id}") + self._reenqueue_after_defer( + task_name=task_name, + args=args, + kwargs=kwargs, + queue_name=queue_name, + delay_seconds=self._task_default_defer.get(task_name, 60.0), + ) + raise TaskCancelledError(f"predicate deferred job {job_id}") + + def _reenqueue_after_defer( + self, + *, + task_name: str, + args: tuple, + kwargs: dict, + queue_name: str, + delay_seconds: float, + ) -> None: + """Re-enqueue a job with a delay, bypassing predicate re-evaluation. + + Args/kwargs are serialized fresh via the task's serializer. We go + straight to the Rust queue to avoid running enqueue-time + middleware or re-evaluating the predicate (which would create an + infinite ping-pong). 
+ """ + serializer = self._get_serializer(task_name) + payload = serializer.dumps((tuple(args), dict(kwargs))) + self._inner.enqueue( + task_name=task_name, + payload=payload, + queue=queue_name, + delay_seconds=delay_seconds, + ) + def _apply_enqueue_predicate( self, *, diff --git a/py_src/taskito/async_support/executor.py b/py_src/taskito/async_support/executor.py index a1ae934..5aba575 100644 --- a/py_src/taskito/async_support/executor.py +++ b/py_src/taskito/async_support/executor.py @@ -98,6 +98,17 @@ async def _execute( args, kwargs = cloudpickle.loads(payload_bytes) queue = self._queue_ref + # Worker-dispatch predicate gate (raw args, pre-reconstruction). + if task_name in queue._task_predicates: + queue._apply_dispatch_predicate( + task_name=task_name, + args=args, + kwargs=kwargs, + job_id=job_id, + queue_name=queue_name, + retry_count=retry_count, + ) + # Reconstruct intercepted arguments redirects: dict[str, str] = {} if queue._interceptor is not None: diff --git a/py_src/taskito/mixins/decorators.py b/py_src/taskito/mixins/decorators.py index 5df7b1c..c2231e4 100644 --- a/py_src/taskito/mixins/decorators.py +++ b/py_src/taskito/mixins/decorators.py @@ -85,6 +85,10 @@ class QueueDecoratorMixin: # lets mypy see it from this mixin without overriding the real # implementation through MRO. _emit_event: Callable[..., None] + # ``_apply_dispatch_predicate`` is defined on the Queue itself + # (alongside enqueue) since it needs ``_inner`` and the task + # serializer. Declared here so mypy sees it through the mixin. + _apply_dispatch_predicate: Callable[..., None] def _get_middleware_chain(self, task_name: str) -> list[TaskMiddleware]: """Get the combined global + per-task middleware list.""" @@ -100,6 +104,19 @@ def _wrap_task( @functools.wraps(fn) def wrapper(*args: Any, **kwargs: Any) -> Any: + # Worker-dispatch predicate gate. Evaluated on the raw + # deserialized payload (before arg/proxy reconstruction) so + # re-enqueue on defer can round-trip cleanly. 
+ if task_name in queue_ref._task_predicates: + queue_ref._apply_dispatch_predicate( + task_name=task_name, + args=args, + kwargs=kwargs, + job_id=current_job.id, + queue_name=current_job.queue_name, + retry_count=current_job.retry_count, + ) + # Reconstruct intercepted arguments (CONVERT markers → original types) redirects: dict[str, str] = {} if queue_ref._interceptor is not None: diff --git a/tests/python/test_predicates_worker.py b/tests/python/test_predicates_worker.py new file mode 100644 index 0000000..e6a9b2d --- /dev/null +++ b/tests/python/test_predicates_worker.py @@ -0,0 +1,194 @@ +"""Worker-dispatch predicate gating tests.""" + +from __future__ import annotations + +import threading +import time +from dataclasses import dataclass +from typing import Any + +import pytest + +from taskito.app import Queue +from taskito.exceptions import PredicateRejectedError +from taskito.predicates import ( + Cancel, + Defer, + Predicate, + PredicateContext, +) + + +@dataclass(frozen=True) +class _Const(Predicate): + value: Any + + def evaluate(self, ctx: PredicateContext) -> Any: + return self.value + + +class _DispatchOnly(Predicate): + """Allow enqueue (job_id None); apply ``dispatch_value`` at dispatch.""" + + def __init__(self, dispatch_value: Any) -> None: + self.dispatch_value = dispatch_value + self.dispatch_calls = 0 + + def evaluate(self, ctx: PredicateContext) -> Any: + if ctx.job_id is None: + return True + self.dispatch_calls += 1 + return self.dispatch_value + + +def _start_worker(queue: Queue) -> threading.Thread: + thread = threading.Thread(target=queue.run_worker, daemon=True) + thread.start() + return thread + + +def _stop_worker(queue: Queue, thread: threading.Thread) -> None: + queue._inner.request_shutdown() + thread.join(timeout=5) + + +def _wait_for_status(queue: Queue, job_id: str, statuses: set[str], timeout: float = 8.0) -> str: + deadline = time.monotonic() + timeout + last_status = "" + while time.monotonic() < deadline: + job = 
queue._inner.get_job(job_id) + if job is not None: + last_status = job.status + if last_status in statuses: + return last_status + time.sleep(0.05) + return last_status + + +def _wait_for_completed_task(queue: Queue, task_name: str, timeout: float = 10.0) -> bool: + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + jobs = queue._inner.list_jobs(status="complete", task_name=task_name) + if jobs: + return True + time.sleep(0.1) + return False + + +def test_dispatch_predicate_allows_normal_run(queue: Queue) -> None: + @queue.task(predicate=_Const(True)) + def t() -> str: + return "ran" + + job = t.delay() + thread = _start_worker(queue) + try: + assert _wait_for_status(queue, job.id, {"complete"}) == "complete" + finally: + _stop_worker(queue, thread) + + +def test_dispatch_predicate_cancel_terminates_job(queue: Queue) -> None: + @queue.task(predicate=_DispatchOnly(Cancel(reason="blocked"))) + def t() -> str: + return "ran" + + job = t.delay() + thread = _start_worker(queue) + try: + assert _wait_for_status(queue, job.id, {"cancelled"}) == "cancelled" + finally: + _stop_worker(queue, thread) + + +def test_dispatch_predicate_false_with_cancel_action(queue: Queue) -> None: + @queue.task(predicate=_DispatchOnly(False), on_false="cancel") + def t() -> str: + return "ran" + + job = t.delay() + thread = _start_worker(queue) + try: + assert _wait_for_status(queue, job.id, {"cancelled"}) == "cancelled" + finally: + _stop_worker(queue, thread) + + +def test_dispatch_predicate_defer_reenqueues(queue: Queue) -> None: + # First dispatch -> Defer(seconds=1); second dispatch -> True. + # The original job is cancelled and a new one is enqueued with 1s delay. 
+ state = {"calls": 0} + + class _Pred(Predicate): + def evaluate(self, ctx: PredicateContext) -> Any: + if ctx.job_id is None: + return True # allow enqueue + state["calls"] += 1 + if state["calls"] == 1: + return Defer(seconds=1.0) + return True + + @queue.task(predicate=_Pred()) + def t() -> str: + return "ran" + + t.delay() + thread = _start_worker(queue) + try: + assert _wait_for_completed_task(queue, t.name, timeout=15.0), ( + "deferred job never completed after re-enqueue" + ) + # Verify the predicate was called at least twice (defer + run) + assert state["calls"] >= 2 + finally: + _stop_worker(queue, thread) + + +def test_dispatch_predicate_fail_closed_on_exception_at_dispatch( + queue: Queue, +) -> None: + class _DispatchBoom(Predicate): + def evaluate(self, ctx: PredicateContext) -> bool: + if ctx.job_id is None: + return True # let enqueue through + raise RuntimeError("boom") + + @queue.task(predicate=_DispatchBoom(), on_false="cancel") + def t() -> str: + return "ran" + + job = t.delay() + thread = _start_worker(queue) + try: + # Fail-closed → False → on_false=cancel → cancelled + assert _wait_for_status(queue, job.id, {"cancelled"}) == "cancelled" + finally: + _stop_worker(queue, thread) + + +def test_dispatch_metrics_record_outcome(queue: Queue) -> None: + @queue.task(predicate=_DispatchOnly(Cancel(reason="x"))) + def t() -> str: + return "ran" + + job = t.delay() + thread = _start_worker(queue) + try: + _wait_for_status(queue, job.id, {"cancelled"}) + finally: + _stop_worker(queue, thread) + snap = queue._predicate_metrics.snapshot() + # allowed=1 (enqueue), cancelled=1 (dispatch) + assert snap["allowed"] >= 1 + assert snap["cancelled"] >= 1 + + +def test_enqueue_time_cancel_still_raises_for_dispatch_predicate(queue: Queue) -> None: + # If predicate cancels at enqueue (regardless of dispatch behavior), + # the enqueue raises and no job is created. 
+ @queue.task(predicate=_Const(Cancel(reason="nope"))) + def t() -> str: + return "ran" + + with pytest.raises(PredicateRejectedError): + t.delay() From bd73a3b8a43628cfeb38e428fbe7c65b7e685025 Mon Sep 17 00:00:00 2001 From: Pratyush Sharma <56130065+pratyush618@users.noreply.github.com> Date: Mon, 11 May 2026 09:48:27 +0530 Subject: [PATCH 04/11] feat(predicates): accept predicate= in TaskMiddleware base + contrib --- py_src/taskito/app.py | 4 + py_src/taskito/async_support/executor.py | 4 +- py_src/taskito/contrib/otel.py | 22 ++--- py_src/taskito/contrib/prometheus.py | 15 +-- py_src/taskito/contrib/sentry.py | 14 +-- py_src/taskito/middleware.py | 64 ++++++++++++ py_src/taskito/mixins/decorators.py | 5 +- tests/python/test_predicates_middleware.py | 110 +++++++++++++++++++++ 8 files changed, 204 insertions(+), 34 deletions(-) create mode 100644 tests/python/test_predicates_middleware.py diff --git a/py_src/taskito/app.py b/py_src/taskito/app.py index 8217fb3..99d442f 100644 --- a/py_src/taskito/app.py +++ b/py_src/taskito/app.py @@ -518,6 +518,8 @@ def enqueue( "idempotent": idempotent, } for mw in self._global_middleware: + if not mw._should_apply(None, task_name=task_name): + continue try: mw.on_enqueue(task_name, final_args, final_kwargs, enqueue_options) except Exception: @@ -682,6 +684,8 @@ def enqueue_many( chain = self._get_middleware_chain(task_name) for i in range(count): for mw in chain: + if not mw._should_apply(None, task_name=task_name): + continue try: mw.on_enqueue(task_name, args_list[i], kw_list[i], per_job_options[i]) except Exception: diff --git a/py_src/taskito/async_support/executor.py b/py_src/taskito/async_support/executor.py index 5aba575..f25e13d 100644 --- a/py_src/taskito/async_support/executor.py +++ b/py_src/taskito/async_support/executor.py @@ -140,9 +140,11 @@ async def _execute( if release is not None: release_callbacks.append(release) - # Middleware before hooks + # Middleware before hooks (skipping filtered middlewares) 
middleware_chain = queue._get_middleware_chain(task_name) for mw in middleware_chain: + if not mw._should_apply(current_job): + continue try: mw.before(current_job) completed_mw.append(mw) diff --git a/py_src/taskito/contrib/otel.py b/py_src/taskito/contrib/otel.py index e133f0e..d6f2de5 100644 --- a/py_src/taskito/contrib/otel.py +++ b/py_src/taskito/contrib/otel.py @@ -17,10 +17,11 @@ from collections.abc import Callable from typing import TYPE_CHECKING, Any -from taskito.middleware import TaskMiddleware +from taskito.middleware import TaskMiddleware, legacy_task_filter_to_predicate if TYPE_CHECKING: from taskito.context import JobContext + from taskito.predicates import Predicate try: from opentelemetry import trace @@ -49,8 +50,12 @@ class OpenTelemetryMiddleware(TaskMiddleware): attribute_prefix: Prefix for span attribute keys (default ``"taskito"``). extra_attributes_fn: Callable that returns extra attributes to add to each span. Receives a :class:`~taskito.context.JobContext`. - task_filter: Predicate that receives a task name and returns ``True`` - to trace the task. ``None`` traces all tasks. + task_filter: Legacy ``Callable[[task_name], bool]`` filter. Kept for + back-compat — prefer ``predicate=`` which accepts richer + :class:`~taskito.predicates.Predicate` objects. + predicate: Optional :class:`~taskito.predicates.Predicate` (or + callable taking a :class:`~taskito.predicates.PredicateContext`) + controlling which tasks this middleware applies to. """ def __init__( @@ -61,32 +66,27 @@ def __init__( attribute_prefix: str = "taskito", extra_attributes_fn: Callable[[JobContext], dict[str, Any]] | None = None, task_filter: Callable[[str], bool] | None = None, + predicate: Predicate | Callable[..., Any] | None = None, ): if trace is None: raise ImportError( "opentelemetry-api is required for OpenTelemetryMiddleware. 
" "Install it with: pip install taskito[otel]" ) + super().__init__(predicate=legacy_task_filter_to_predicate(task_filter, predicate)) self._tracer = trace.get_tracer(tracer_name) self._span_name_fn = span_name_fn self._attr_prefix = attribute_prefix self._extra_attributes_fn = extra_attributes_fn - self._task_filter = task_filter self._spans: dict[str, Any] = {} self._lock = threading.Lock() - def _should_trace(self, task_name: str) -> bool: - return self._task_filter is None or self._task_filter(task_name) - def _span_name(self, ctx: JobContext) -> str: if self._span_name_fn is not None: return self._span_name_fn(ctx) return f"{self._attr_prefix}.execute.{ctx.task_name}" def before(self, ctx: JobContext) -> None: - if not self._should_trace(ctx.task_name): - return - prefix = self._attr_prefix attributes: dict[str, Any] = { f"{prefix}.job_id": ctx.id, @@ -105,7 +105,7 @@ def after(self, ctx: JobContext, result: Any, error: Exception | None) -> None: with self._lock: span = self._spans.pop(ctx.id, None) if span is None: - return + return # before() didn't emit a span (predicate filtered, or error) try: if error is not None: diff --git a/py_src/taskito/contrib/prometheus.py b/py_src/taskito/contrib/prometheus.py index 4444a64..37a070d 100644 --- a/py_src/taskito/contrib/prometheus.py +++ b/py_src/taskito/contrib/prometheus.py @@ -20,11 +20,12 @@ from collections.abc import Callable from typing import TYPE_CHECKING, Any -from taskito.middleware import TaskMiddleware +from taskito.middleware import TaskMiddleware, legacy_task_filter_to_predicate if TYPE_CHECKING: from taskito.app import Queue from taskito.context import JobContext + from taskito.predicates import Predicate logger = logging.getLogger("taskito.prometheus") @@ -187,24 +188,20 @@ def __init__( extra_labels_fn: Callable[[JobContext], dict[str, str]] | None = None, disabled_metrics: set[str] | None = None, task_filter: Callable[[str], bool] | None = None, + predicate: Predicate | Callable[..., Any] | None 
= None, ) -> None: if Counter is None: raise ImportError( "prometheus-client is required for PrometheusMiddleware. " "Install it with: pip install taskito[prometheus]" ) + super().__init__(predicate=legacy_task_filter_to_predicate(task_filter, predicate)) self._metrics = _get_or_create_metrics(namespace, disabled_metrics) self._extra_labels_fn = extra_labels_fn - self._task_filter = task_filter self._start_times: dict[str, float] = {} self._lock = threading.Lock() - def _should_track(self, task_name: str) -> bool: - return self._task_filter is None or self._task_filter(task_name) - def before(self, ctx: JobContext) -> None: - if not self._should_track(ctx.task_name): - return with self._lock: self._start_times[ctx.id] = time.monotonic() m = self._metrics["active_workers"] @@ -212,8 +209,6 @@ def before(self, ctx: JobContext) -> None: m.inc() def after(self, ctx: JobContext, result: Any, error: Exception | None) -> None: - if not self._should_track(ctx.task_name): - return m = self._metrics["active_workers"] if m is not None: m.dec() @@ -232,8 +227,6 @@ def after(self, ctx: JobContext, result: Any, error: Exception | None) -> None: m.labels(task=ctx.task_name).observe(duration) def on_retry(self, ctx: JobContext, error: Exception, retry_count: int) -> None: - if not self._should_track(ctx.task_name): - return m = self._metrics["retries_total"] if m is not None: m.labels(task=ctx.task_name).inc() diff --git a/py_src/taskito/contrib/sentry.py b/py_src/taskito/contrib/sentry.py index b550777..08b241e 100644 --- a/py_src/taskito/contrib/sentry.py +++ b/py_src/taskito/contrib/sentry.py @@ -16,10 +16,11 @@ from collections.abc import Callable from typing import TYPE_CHECKING, Any -from taskito.middleware import TaskMiddleware +from taskito.middleware import TaskMiddleware, legacy_task_filter_to_predicate if TYPE_CHECKING: from taskito.context import JobContext + from taskito.predicates import Predicate try: import sentry_sdk @@ -52,24 +53,19 @@ def __init__( 
transaction_name_fn: Callable[[JobContext], str] | None = None, task_filter: Callable[[str], bool] | None = None, extra_tags_fn: Callable[[JobContext], dict[str, str]] | None = None, + predicate: Predicate | Callable[..., Any] | None = None, ) -> None: if sentry_sdk is None: raise ImportError( "sentry-sdk is required for SentryMiddleware. " "Install it with: pip install taskito[sentry]" ) + super().__init__(predicate=legacy_task_filter_to_predicate(task_filter, predicate)) self._tag_prefix = tag_prefix self._transaction_name_fn = transaction_name_fn - self._task_filter = task_filter self._extra_tags_fn = extra_tags_fn - def _should_report(self, task_name: str) -> bool: - return self._task_filter is None or self._task_filter(task_name) - def before(self, ctx: JobContext) -> None: - if not self._should_report(ctx.task_name): - return - sentry_sdk.push_scope() try: scope = sentry_sdk.get_current_scope() @@ -90,8 +86,6 @@ def before(self, ctx: JobContext) -> None: raise def after(self, ctx: JobContext, result: Any, error: Exception | None) -> None: - if not self._should_report(ctx.task_name): - return if error is not None: sentry_sdk.capture_exception(error) sentry_sdk.pop_scope_unsafe() diff --git a/py_src/taskito/middleware.py b/py_src/taskito/middleware.py index 1502fb4..8650641 100644 --- a/py_src/taskito/middleware.py +++ b/py_src/taskito/middleware.py @@ -2,18 +2,48 @@ from __future__ import annotations +from collections.abc import Callable from typing import TYPE_CHECKING, Any +from taskito.predicates.context import PredicateContext +from taskito.predicates.core import Predicate, coerce_predicate +from taskito.predicates.evaluate import evaluate_predicate + if TYPE_CHECKING: from taskito.context import JobContext +def legacy_task_filter_to_predicate( + task_filter: Callable[[str], bool] | None, + predicate: Predicate | Callable[..., Any] | None, +) -> Predicate | Callable[..., Any] | None: + """Translate the contrib-legacy ``task_filter`` kwarg into a predicate. 
+ + Contrib middlewares historically accepted ``task_filter=Callable[[str], bool]``. + The modern equivalent is ``predicate=`` taking a Predicate or a + callable receiving a PredicateContext. This helper preserves + back-compat: ``predicate`` wins if both are supplied; otherwise the + legacy callable is wrapped into a Predicate that ignores everything + except ``ctx.task_name``. + """ + if predicate is not None: + return predicate + if task_filter is None: + return None + return coerce_predicate(task_filter, str_callable=True) + + class TaskMiddleware: """Base class for task middleware. Subclass and override any of the hooks. Register globally via ``Queue(middleware=[...])`` or per-task via ``@queue.task(middleware=[...])``. + A ``predicate`` may be passed to gate which tasks the middleware + applies to: only when the predicate returns ``True`` will the hooks + fire. Plain ``Callable[[task_name], bool]`` callables are accepted for + backwards compatibility with contrib middleware ``task_filter`` kwargs. + Example:: class LoggingMiddleware(TaskMiddleware): @@ -25,6 +55,40 @@ def after(self, ctx, result, error): print(f"Finished {ctx.task_name}: {status}") """ + def __init__( + self, + *, + predicate: Predicate | Callable[..., Any] | None = None, + ) -> None: + self._predicate = coerce_predicate(predicate) + + def _should_apply(self, ctx: JobContext | None, task_name: str = "") -> bool: + """Decide whether this middleware's hooks should fire for ``ctx``. + + Returns ``True`` when no predicate is configured. When a predicate + is configured, builds a lightweight :class:`PredicateContext` from + the available data and evaluates it. ``Defer`` and ``Cancel`` from + a middleware predicate are interpreted as "skip this middleware + only" — they do not influence job dispatch. + + Subclasses that skip ``super().__init__()`` fall through to the + default ``True`` behaviour (no predicate). 
+ """ + predicate = getattr(self, "_predicate", None) + if predicate is None: + return True + if ctx is not None: + pctx = PredicateContext( + task_name=ctx.task_name, + queue=ctx.queue_name, + retry_count=ctx.retry_count, + job_id=ctx.id, + ) + else: + pctx = PredicateContext(task_name=task_name, queue="default") + outcome = evaluate_predicate(predicate, pctx) + return outcome is True + def before(self, ctx: JobContext) -> None: """Called before task execution.""" diff --git a/py_src/taskito/mixins/decorators.py b/py_src/taskito/mixins/decorators.py index c2231e4..60bfe31 100644 --- a/py_src/taskito/mixins/decorators.py +++ b/py_src/taskito/mixins/decorators.py @@ -158,9 +158,12 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: if soft_timeout is not None: current_job._set_soft_timeout(soft_timeout) - # Run middleware before hooks + # Run middleware before hooks (skipping middlewares whose + # predicate filter excludes this job) completed_mw: list[Any] = [] for mw in middleware_chain: + if not mw._should_apply(current_job): + continue try: mw.before(current_job) completed_mw.append(mw) diff --git a/tests/python/test_predicates_middleware.py b/tests/python/test_predicates_middleware.py new file mode 100644 index 0000000..8f00093 --- /dev/null +++ b/tests/python/test_predicates_middleware.py @@ -0,0 +1,110 @@ +"""Predicate-driven middleware filtering tests.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +from taskito.app import Queue +from taskito.context import JobContext +from taskito.middleware import TaskMiddleware, legacy_task_filter_to_predicate +from taskito.predicates import Predicate, PredicateContext + + +@dataclass(frozen=True) +class _Const(Predicate): + value: Any + + def evaluate(self, ctx: PredicateContext) -> Any: + return self.value + + +class _RecordingMiddleware(TaskMiddleware): + def __init__(self, **kw: Any) -> None: + super().__init__(**kw) + self.before_calls: list[str] = [] + 
self.after_calls: list[str] = [] + self.enqueue_calls: list[str] = [] + + def before(self, ctx: JobContext) -> None: + self.before_calls.append(ctx.task_name) + + def after(self, ctx: JobContext, result: Any, error: Exception | None) -> None: + self.after_calls.append(ctx.task_name) + + def on_enqueue(self, task_name: str, args: tuple, kwargs: dict, options: dict) -> None: + self.enqueue_calls.append(task_name) + + +def test_no_predicate_means_always_apply(tmp_path: Any) -> None: + mw = _RecordingMiddleware() + queue = Queue(db_path=str(tmp_path / "t.db"), workers=1, middleware=[mw]) + + @queue.task() + def t() -> str: + return "ok" + + t.delay() + assert mw.enqueue_calls == [t.name] + + +def test_legacy_task_filter_callable_still_works(tmp_path: Any) -> None: + mw = _RecordingMiddleware(predicate=None) # set up base first + # Then re-init via the legacy helper + mw2 = _RecordingMiddleware( + predicate=legacy_task_filter_to_predicate(lambda name: name.endswith(".allowed"), None) + ) + queue = Queue(db_path=str(tmp_path / "t.db"), workers=1, middleware=[mw, mw2]) + + @queue.task(name="x.allowed") + def allowed() -> int: + return 1 + + @queue.task(name="x.blocked") + def blocked() -> int: + return 1 + + allowed.delay() + blocked.delay() + + assert mw2.enqueue_calls == ["x.allowed"] + assert mw.enqueue_calls == ["x.allowed", "x.blocked"] + + +def test_predicate_filters_enqueue_hook(tmp_path: Any) -> None: + mw = _RecordingMiddleware(predicate=_Const(False)) + queue = Queue(db_path=str(tmp_path / "t.db"), workers=1, middleware=[mw]) + + @queue.task() + def t() -> str: + return "ok" + + t.delay() + assert mw.enqueue_calls == [] + + +def test_predicate_with_callable(tmp_path: Any) -> None: + mw = _RecordingMiddleware(predicate=lambda ctx: ctx.task_name.endswith(".ok")) + queue = Queue(db_path=str(tmp_path / "t.db"), workers=1, middleware=[mw]) + + @queue.task(name="x.ok") + def ok() -> int: + return 1 + + @queue.task(name="x.skip") + def skip() -> int: + return 1 + + 
ok.delay() + skip.delay() + assert mw.enqueue_calls == ["x.ok"] + + +def test_legacy_helper_predicate_wins_when_both_supplied() -> None: + pred = _Const(True) + legacy = lambda name: False # noqa: E731 + assert legacy_task_filter_to_predicate(legacy, pred) is pred + + +def test_legacy_helper_returns_none_when_both_none() -> None: + assert legacy_task_filter_to_predicate(None, None) is None From 9af0da376094afbdddd4ca9a4363ac1650950591 Mon Sep 17 00:00:00 2001 From: Pratyush Sharma <56130065+pratyush618@users.noreply.github.com> Date: Mon, 11 May 2026 09:50:12 +0530 Subject: [PATCH 05/11] feat(predicates): emit PREDICATE_DEFERRED/CANCELLED/REJECTED events --- py_src/taskito/app.py | 77 ++++++++++++++++++++++++- py_src/taskito/events.py | 3 + tests/python/test_predicates_enqueue.py | 44 ++++++++++++++ 3 files changed, 122 insertions(+), 2 deletions(-) diff --git a/py_src/taskito/app.py b/py_src/taskito/app.py index 99d442f..e9260b3 100644 --- a/py_src/taskito/app.py +++ b/py_src/taskito/app.py @@ -355,6 +355,16 @@ def _apply_dispatch_predicate( return if isinstance(outcome, Cancel): + self._emit_event( + EventType.PREDICATE_CANCELLED, + { + "task_name": task_name, + "job_id": job_id, + "queue": queue_name, + "reason": outcome.reason, + "phase": "dispatch", + }, + ) raise TaskCancelledError( f"predicate cancelled job {job_id}: {outcome.reason}" if outcome.reason @@ -369,18 +379,49 @@ def _apply_dispatch_predicate( queue_name=queue_name, delay_seconds=outcome.seconds, ) + self._emit_event( + EventType.PREDICATE_DEFERRED, + { + "task_name": task_name, + "job_id": job_id, + "queue": queue_name, + "defer_seconds": outcome.seconds, + "phase": "dispatch", + }, + ) raise TaskCancelledError(f"predicate deferred job {job_id} by {outcome.seconds:.1f}s") # Plain False — branch on on_false. 
action = self._task_predicate_on_false.get(task_name, "defer") if action == "cancel": + self._emit_event( + EventType.PREDICATE_CANCELLED, + { + "task_name": task_name, + "job_id": job_id, + "queue": queue_name, + "phase": "dispatch", + }, + ) raise TaskCancelledError(f"predicate rejected job {job_id}") + + defer_seconds = self._task_default_defer.get(task_name, 60.0) self._reenqueue_after_defer( task_name=task_name, args=args, kwargs=kwargs, queue_name=queue_name, - delay_seconds=self._task_default_defer.get(task_name, 60.0), + delay_seconds=defer_seconds, + ) + self._emit_event( + EventType.PREDICATE_DEFERRED, + { + "task_name": task_name, + "job_id": job_id, + "queue": queue_name, + "defer_seconds": defer_seconds, + "phase": "dispatch", + }, ) raise TaskCancelledError(f"predicate deferred job {job_id}") @@ -445,16 +486,48 @@ def _apply_enqueue_predicate( if outcome is True: return delay if isinstance(outcome, Defer): + self._emit_event( + EventType.PREDICATE_DEFERRED, + { + "task_name": task_name, + "queue": queue_name, + "defer_seconds": outcome.seconds, + "phase": "enqueue", + }, + ) return (delay or 0.0) + outcome.seconds if isinstance(outcome, Cancel): + self._emit_event( + EventType.PREDICATE_REJECTED, + { + "task_name": task_name, + "queue": queue_name, + "reason": outcome.reason, + "phase": "enqueue", + }, + ) raise PredicateRejectedError(task_name, outcome.reason) # Plain False — branch on the task's on_false setting. 
action = self._task_predicate_on_false.get(task_name, "defer") if action == "cancel": + self._emit_event( + EventType.PREDICATE_REJECTED, + {"task_name": task_name, "queue": queue_name, "phase": "enqueue"}, + ) raise PredicateRejectedError(task_name) # defer - return (delay or 0.0) + self._task_default_defer.get(task_name, 60.0) + defer_seconds = self._task_default_defer.get(task_name, 60.0) + self._emit_event( + EventType.PREDICATE_DEFERRED, + { + "task_name": task_name, + "queue": queue_name, + "defer_seconds": defer_seconds, + "phase": "enqueue", + }, + ) + return (delay or 0.0) + defer_seconds def enqueue( self, diff --git a/py_src/taskito/events.py b/py_src/taskito/events.py index 4dcbedd..4284980 100644 --- a/py_src/taskito/events.py +++ b/py_src/taskito/events.py @@ -33,6 +33,9 @@ class EventType(enum.Enum): WORKFLOW_FAILED = "workflow.failed" WORKFLOW_CANCELLED = "workflow.cancelled" WORKFLOW_GATE_REACHED = "workflow.gate_reached" + PREDICATE_DEFERRED = "predicate.deferred" + PREDICATE_CANCELLED = "predicate.cancelled" + PREDICATE_REJECTED = "predicate.rejected" class EventBus: diff --git a/tests/python/test_predicates_enqueue.py b/tests/python/test_predicates_enqueue.py index 99d835e..83a3f9e 100644 --- a/tests/python/test_predicates_enqueue.py +++ b/tests/python/test_predicates_enqueue.py @@ -149,3 +149,47 @@ def deferred() -> int: snap = queue._predicate_metrics.snapshot() assert snap["allowed"] == 2 assert snap["deferred"] == 1 + + +def test_defer_emits_predicate_deferred_event(queue: Queue) -> None: + from taskito.events import EventType + + received: list[dict] = [] + queue._event_bus.on(EventType.PREDICATE_DEFERRED, lambda _e, p: received.append(p)) + + @queue.task(predicate=_Const(Defer(seconds=99.0))) + def t() -> str: + return "ran" + + t.delay() + # Event bus dispatches in a thread pool — give it a moment. 
+ import time + + for _ in range(20): + if received: + break + time.sleep(0.05) + assert received and received[0]["defer_seconds"] == 99.0 + assert received[0]["phase"] == "enqueue" + + +def test_cancel_emits_predicate_rejected_event(queue: Queue) -> None: + from taskito.events import EventType + + received: list[dict] = [] + queue._event_bus.on(EventType.PREDICATE_REJECTED, lambda _e, p: received.append(p)) + + @queue.task(predicate=_Const(Cancel(reason="no good"))) + def t() -> str: + return "ran" + + with pytest.raises(PredicateRejectedError): + t.delay() + + import time + + for _ in range(20): + if received: + break + time.sleep(0.05) + assert received and received[0]["reason"] == "no good" From 604c913988fcdd19ee552d75dd0b7b60b0d28908 Mon Sep 17 00:00:00 2001 From: Pratyush Sharma <56130065+pratyush618@users.noreply.github.com> Date: Mon, 11 May 2026 09:56:26 +0530 Subject: [PATCH 06/11] docs(predicates): document predicate system + recipes --- docs/content/docs/guides/core/meta.json | 1 + docs/content/docs/guides/core/predicates.mdx | 153 +++++++++++++++++++ tests/observability/test_events.py | 3 + 3 files changed, 157 insertions(+) create mode 100644 docs/content/docs/guides/core/predicates.mdx diff --git a/docs/content/docs/guides/core/meta.json b/docs/content/docs/guides/core/meta.json index 270020d..57b6fec 100644 --- a/docs/content/docs/guides/core/meta.json +++ b/docs/content/docs/guides/core/meta.json @@ -6,6 +6,7 @@ "workers", "execution-model", "scheduling", + "predicates", "workflows" ] } diff --git a/docs/content/docs/guides/core/predicates.mdx b/docs/content/docs/guides/core/predicates.mdx new file mode 100644 index 0000000..a427150 --- /dev/null +++ b/docs/content/docs/guides/core/predicates.mdx @@ -0,0 +1,153 @@ +--- +title: Predicates +description: "Compose Predicate objects to gate when a task is enqueued or dispatched, with built-in recipes for time windows, queue health, feature flags, and more." 
+--- + +A **predicate** is a composable, fail-closed gate attached to a task. It decides — at enqueue time and again at worker dispatch — whether the job runs, is deferred, or is cancelled. The same predicates also filter which middlewares see a given job. + +```python +from taskito import Queue +from taskito.predicates import ( + is_business_hours, + queue_paused, + by_priority_at_least, +) + +queue = Queue() + +@queue.task( + predicate=is_business_hours(tz="US/Pacific") + & ~queue_paused() + | by_priority_at_least(8), + on_false="defer", +) +def send_report(): ... +``` + +## Outcomes + +A predicate's `evaluate` returns one of four values: + +| Outcome | Meaning | +|---|---| +| `True` | Allow the job to proceed. | +| `False` | Deny. The task's `on_false` action decides what happens — `"defer"` (default) or `"cancel"`. | +| `Defer(seconds=N)` | Skip now, retry after `N` seconds. At enqueue time the delay is added; at dispatch time the job is re-enqueued with the delay. | +| `Cancel(reason="...")` | Permanently skip. At enqueue time raises `PredicateRejectedError`; at dispatch time marks the job cancelled. | + +A predicate that raises an exception is treated as `False` (fail-closed). The error is logged and counted in `PredicateMetrics`. + +## Composition + +`Predicate` overloads `&`, `|`, and `~` with short-circuit semantics: + +```python +allow = is_business_hours() & ~queue_paused() +allow |= by_priority_at_least(8) # urgent jobs bypass both gates +``` + +* `A & B` — both must allow; short-circuits if `A` denies. +* `A | B` — either can allow; short-circuits if `A` allows. If both deny, the most informative outcome wins (`Cancel` > `Defer` > `False`). +* `~A` — inverts `True`/`False`. `Defer` and `Cancel` pass through unchanged. + +## `@queue.task` options + +| Parameter | Type | Default | Description | +|---|---|---|---| +| `predicate` | `Predicate \| Callable \| None` | `None` | The predicate to evaluate. 
Plain callables receiving a `PredicateContext` are accepted. | +| `on_false` | `"defer" \| "cancel"` | `"defer"` | What to do when the predicate returns plain `False`. | +| `predicate_extras` | `dict \| None` | `None` | Static dict forwarded to the predicate via `ctx.extras`. | +| `default_defer_seconds` | `float` | `60.0` | Delay applied when `on_false="defer"` and the predicate returns plain `False`. | + +## Built-in recipes + +All recipes are factory functions returning a `Predicate`. Import from `taskito.predicates`. + +### Time-based + +* `is_business_hours(start=9, end=17, *, tz=None, weekdays_only=True)` — defers to the next window when outside. +* `is_weekend(*, tz=None)` +* `in_time_window("09:00", "17:00", *, tz=None)` — `HH:MM` strings; end exclusive; defers when outside. +* `after(target_datetime)` — defers until `target`. +* `before(target_datetime)` +* `in_timezone(tz)` — no-op at runtime; validates the tz string at construction. + +### Job-attribute + +* `by_queue(name)` / `by_task(name)` +* `by_priority_at_least(n)` +* `retry_count_under(n)` +* `payload_matches("kwargs.tenant", "acme")` — dotted-path lookup into `{"args": ..., "kwargs": ...}`. + +### System-state + +* `queue_size_under(limit, *, queue=None)` — reads `stats_by_queue()`. +* `queue_paused(queue=None)` +* `error_rate_under(max_rate)` — `(failed + dead) / (completed + failed + dead)` over backend-wide stats. + +System-state recipes are memoised within a single evaluation, so `queue_size_under(100) & queue_size_under(200)` only reads stats once. + +### External-config + +* `env_var_truthy("MY_FLAG")` — truthy values: `1`, `true`, `t`, `yes`, `y`, `on` (case-insensitive). +* `feature_flag("billing", provider=...)` — defaults to env-var provider reading `FF_`. Pass a custom `FeatureFlagProvider` to integrate LaunchDarkly/Statsig/etc. + +## Where predicates run + +A predicate registered on a task is evaluated **twice**: + +1. **Enqueue time** — inside `Queue.enqueue()` / `enqueue_many()`. 
`Defer` adjusts the caller's `delay=` value. `Cancel` raises `PredicateRejectedError` immediately. The job is never saved when cancelled. +2. **Worker-dispatch time** — just before the task body executes (sync and async paths). `Cancel` raises `TaskCancelledError`, which Rust records as `cancelled`. `Defer` re-enqueues a fresh job with the same payload + delay, and the current job is marked cancelled. + +The `PredicateContext` passed to `evaluate` distinguishes the two phases: `ctx.job_id is None` at enqueue time, set at dispatch time. + +## Writing a custom predicate + +```python +from taskito.predicates import Predicate, PredicateContext, Defer + +class TenantQuotaUnder(Predicate): + def __init__(self, limit: int) -> None: + self.limit = limit + + def evaluate(self, ctx: PredicateContext) -> bool | Defer: + tenant = ctx.kwargs.get("tenant") + if tenant is None: + return False + used = my_quota_service.usage(tenant) + if used < self.limit: + return True + return Defer(seconds=300.0) # check again in 5 minutes +``` + +Predicates may be **async**: return a coroutine and it will be awaited transparently. + +## Middleware filtering + +`TaskMiddleware` accepts `predicate=` to gate which jobs the middleware applies to. The legacy contrib `task_filter=Callable[[str], bool]` kwarg is still supported and translates internally to a predicate. + +```python +from taskito.contrib.sentry import SentryMiddleware +from taskito.predicates import by_queue + +queue = Queue(middleware=[ + SentryMiddleware(predicate=by_queue("critical")), +]) +``` + +When the middleware predicate denies, **only that middleware** is skipped for the current job — job dispatch is unaffected. 
+ +## Metrics and events + +Each `Queue` carries a `PredicateMetrics` instance at `queue._predicate_metrics`: + +```python +queue._predicate_metrics.snapshot() +# {"allowed": 412, "denied": 3, "deferred": 8, "cancelled": 1, "errors": 0} +``` + +Three events are emitted on the queue's event bus: + +* `EventType.PREDICATE_DEFERRED` — payload includes `task_name`, `defer_seconds`, `phase` (`"enqueue"` or `"dispatch"`). +* `EventType.PREDICATE_CANCELLED` — worker-dispatch cancellations. +* `EventType.PREDICATE_REJECTED` — enqueue-time rejections that raised `PredicateRejectedError`. diff --git a/tests/observability/test_events.py b/tests/observability/test_events.py index 6d6e5c6..d724e77 100644 --- a/tests/observability/test_events.py +++ b/tests/observability/test_events.py @@ -99,5 +99,8 @@ def test_all_event_types_exist() -> None: "workflow.failed", "workflow.cancelled", "workflow.gate_reached", + "predicate.deferred", + "predicate.cancelled", + "predicate.rejected", } assert {e.value for e in EventType} == expected From 89d009c4d06f7cc98b434a7149d1f0d933389268 Mon Sep 17 00:00:00 2001 From: Pratyush Sharma <56130065+pratyush618@users.noreply.github.com> Date: Mon, 11 May 2026 10:06:33 +0530 Subject: [PATCH 07/11] docs(predicates): add inline examples and end-to-end recipe page --- docs/content/docs/guides/core/predicates.mdx | 159 ++++++++++++ docs/content/docs/more/examples/index.mdx | 1 + docs/content/docs/more/examples/meta.json | 1 + .../more/examples/predicate-gated-jobs.mdx | 244 ++++++++++++++++++ 4 files changed, 405 insertions(+) create mode 100644 docs/content/docs/more/examples/predicate-gated-jobs.mdx diff --git a/docs/content/docs/guides/core/predicates.mdx b/docs/content/docs/guides/core/predicates.mdx index a427150..a816332 100644 --- a/docs/content/docs/guides/core/predicates.mdx +++ b/docs/content/docs/guides/core/predicates.mdx @@ -151,3 +151,162 @@ Three events are emitted on the queue's event bus: * `EventType.PREDICATE_DEFERRED` — payload 
includes `task_name`, `defer_seconds`, `phase` (`"enqueue"` or `"dispatch"`). * `EventType.PREDICATE_CANCELLED` — worker-dispatch cancellations. * `EventType.PREDICATE_REJECTED` — enqueue-time rejections that raised `PredicateRejectedError`. + +## Examples + +### Business-hours-only report + +Run a report task only between 09:00–17:00 Pacific on weekdays; defer to the next window otherwise. Urgent jobs (priority ≥ 8) bypass the window. + +```python +from taskito import Queue +from taskito.predicates import is_business_hours, by_priority_at_least + +queue = Queue() + +@queue.task( + predicate=is_business_hours(tz="US/Pacific") | by_priority_at_least(8), +) +def send_daily_report(team_id: str) -> None: + ... + +# At 02:00 PT → enqueued with delay until 09:00 PT +send_daily_report.delay("alpha") + +# Same task, urgent → runs immediately regardless of hour +queue.enqueue(send_daily_report.name, args=("alpha",), priority=10) +``` + +### Drop jobs when the system is under load + +Use `queue_size_under` and `error_rate_under` to shed load. Compose them with `&` so a single backpressure breach denies. + +```python +from taskito.predicates import queue_size_under, error_rate_under, queue_paused + +healthy = ( + queue_size_under(1_000, queue="default") + & error_rate_under(0.1) + & ~queue_paused() +) + +@queue.task(predicate=healthy, on_false="defer", default_defer_seconds=30.0) +def import_csv(path: str) -> int: + ... +``` + +### Feature-flag rollout + +Gate a new code path behind a flag. The default provider reads `FF_` from the environment; swap in LaunchDarkly with a 5-line adapter. 
+ +```python +import os +from taskito.predicates import feature_flag, FeatureFlagProvider, PredicateContext + + +class LaunchDarklyProvider: + """Adapter for the LaunchDarkly SDK.""" + + def __init__(self, client: object, user_attr: str = "tenant_id") -> None: + self._client = client + self._user_attr = user_attr + + def is_enabled(self, name: str, ctx: PredicateContext) -> bool: + user = {"key": ctx.kwargs.get(self._user_attr, "anon")} + return bool(self._client.variation(name, user, False)) + + +# In production: provider = LaunchDarklyProvider(ldclient.get()) +# In dev: default env-var provider, FF_NEW_BILLING=1 enables it. +@queue.task(predicate=feature_flag("new_billing")) +def charge_card(tenant_id: str, amount_cents: int) -> str: + ... +``` + +### Tenant allowlist with `payload_matches` + +```python +from taskito.predicates import Predicate, PredicateContext + +class TenantIn(Predicate): + def __init__(self, *tenants: str) -> None: + self._tenants = frozenset(tenants) + + def evaluate(self, ctx: PredicateContext) -> bool: + return ctx.kwargs.get("tenant_id") in self._tenants + + +@queue.task(predicate=TenantIn("acme", "globex"), on_false="cancel") +def reindex_search(tenant_id: str) -> None: + ... + +# raises PredicateRejectedError — never enqueued +queue.enqueue(reindex_search.name, kwargs={"tenant_id": "initech"}) +``` + +### Async predicate with quota service + +`evaluate` may be a coroutine — it's awaited transparently. 
+ +```python +import httpx +from taskito.predicates import Predicate, PredicateContext, Defer + + +class QuotaAvailable(Predicate): + def __init__(self, url: str) -> None: + self._url = url + + async def evaluate(self, ctx: PredicateContext) -> bool | Defer: + async with httpx.AsyncClient(timeout=2.0) as client: + r = await client.get(f"{self._url}/quota/{ctx.kwargs['tenant']}") + remaining = r.json()["remaining"] + if remaining > 0: + return True + return Defer(seconds=60.0) + + +@queue.task(predicate=QuotaAvailable("http://quota.internal")) +async def index_document(tenant: str, doc_id: str) -> None: + ... +``` + +### Per-middleware filtering + +Route a single middleware to only one queue without touching the queue-level middleware list. + +```python +from taskito.contrib.sentry import SentryMiddleware +from taskito.contrib.prometheus import PrometheusMiddleware +from taskito.predicates import by_queue + +queue = Queue(middleware=[ + # Sentry only for the "critical" queue + SentryMiddleware(predicate=by_queue("critical")), + # Prometheus everywhere + PrometheusMiddleware(), +]) +``` + +### Listening for predicate events + +Wire dashboards or alerts by subscribing to the predicate event types. 
+ +```python +from taskito.events import EventType + +def on_predicate_event(event_type: EventType, payload: dict) -> None: + if event_type is EventType.PREDICATE_DEFERRED: + print(f"{payload['task_name']} deferred {payload['defer_seconds']}s " + f"at {payload['phase']}") + elif event_type is EventType.PREDICATE_CANCELLED: + print(f"{payload['task_name']} cancelled at dispatch: " + f"{payload.get('reason', '')}") + +for event in ( + EventType.PREDICATE_DEFERRED, + EventType.PREDICATE_CANCELLED, + EventType.PREDICATE_REJECTED, +): + queue._event_bus.on(event, on_predicate_event) +``` diff --git a/docs/content/docs/more/examples/index.mdx b/docs/content/docs/more/examples/index.mdx index 5c0658a..0ebfe6a 100644 --- a/docs/content/docs/more/examples/index.mdx +++ b/docs/content/docs/more/examples/index.mdx @@ -12,4 +12,5 @@ End-to-end examples demonstrating common taskito patterns. | [Web Scraper Pipeline](/more/examples/web-scraper) | Distributed scraping with chains and error handling | | [Data Pipeline](/more/examples/data-pipeline) | ETL pipeline with dependencies, groups, and chords | | [DAG Workflows](/more/examples/workflows) | Fan-out, conditions, gates, sub-workflows, incremental runs | +| [Predicate-Gated Jobs](/more/examples/predicate-gated-jobs) | Business-hours windows, load-shedding, feature flags, per-tenant quotas | | [Benchmark](/more/examples/benchmark) | Performance benchmarks comparing taskito to alternatives | diff --git a/docs/content/docs/more/examples/meta.json b/docs/content/docs/more/examples/meta.json index 5c86917..c442177 100644 --- a/docs/content/docs/more/examples/meta.json +++ b/docs/content/docs/more/examples/meta.json @@ -6,6 +6,7 @@ "web-scraper", "data-pipeline", "workflows", + "predicate-gated-jobs", "benchmark" ] } diff --git a/docs/content/docs/more/examples/predicate-gated-jobs.mdx b/docs/content/docs/more/examples/predicate-gated-jobs.mdx new file mode 100644 index 0000000..792ff50 --- /dev/null +++ 
b/docs/content/docs/more/examples/predicate-gated-jobs.mdx @@ -0,0 +1,244 @@ +--- +title: Predicate-Gated Jobs +description: "End-to-end example of using Predicate objects to gate task execution by business hours, system load, feature flags, and per-tenant quotas." +--- + +A small reporting service that uses predicates to gate when tasks may run. It demonstrates every layer of the predicate system: enqueue-time rejection, worker-dispatch defer-and-retry, composable recipes, custom predicates, async predicates, and middleware filtering. + +## What the service does + +* Sends per-tenant **daily reports** during business hours only. +* Defers heavy **reindex jobs** when the queue is backed up or the error rate is high. +* Gates a new **billing workflow** behind a feature flag. +* Routes Sentry alerts only for the `critical` queue. + +## Project structure + +``` +predicate_gated/ + app.py # Queue + tasks + predicates.py # Custom predicates + main.py # Submit jobs + worker.py # Run the worker +``` + +## predicates.py + +A custom predicate that consults an HTTP quota service. Async — taskito awaits it transparently. + +```python +"""Custom predicates for the example service.""" + +from __future__ import annotations + +import httpx + +from taskito.predicates import Defer, Predicate, PredicateContext + + +class TenantQuotaUnder(Predicate): + """Defer until the tenant has quota remaining.""" + + def __init__(self, url: str, limit: int) -> None: + self._url = url + self._limit = limit + + async def evaluate(self, ctx: PredicateContext) -> bool | Defer: + tenant = ctx.kwargs.get("tenant") + if tenant is None: + return False + async with httpx.AsyncClient(timeout=2.0) as client: + r = await client.get(f"{self._url}/quota/{tenant}") + used = r.json()["used"] + if used < self._limit: + return True + # Quota busts → re-check in 5 minutes. 
+ return Defer(seconds=300.0) + + +class TenantIn(Predicate): + """Hard allowlist; rejects at enqueue time.""" + + def __init__(self, *tenants: str) -> None: + self._tenants = frozenset(tenants) + + def evaluate(self, ctx: PredicateContext) -> bool: + return ctx.kwargs.get("tenant") in self._tenants +``` + +## app.py + +Composes built-in recipes with the custom predicates. Each task gets a different predicate shape. + +```python +"""Queue, tasks, and predicate wiring.""" + +from taskito import Queue +from taskito.contrib.sentry import SentryMiddleware +from taskito.predicates import ( + by_priority_at_least, + by_queue, + error_rate_under, + feature_flag, + is_business_hours, + queue_paused, + queue_size_under, +) + +from .predicates import TenantIn, TenantQuotaUnder + +queue = Queue( + db_path=".taskito/predicate_gated.db", + workers=4, + middleware=[ + # Sentry only for the "critical" queue. + SentryMiddleware(predicate=by_queue("critical")), + ], +) + + +# ── Reports: business hours + tenant allowlist ──────────────── +@queue.task( + predicate=is_business_hours(tz="US/Pacific") & TenantIn("acme", "globex"), + on_false="defer", + default_defer_seconds=900.0, # 15 minutes +) +def send_daily_report(tenant: str) -> str: + """Defers outside Mon–Fri 09:00–17:00 Pacific. 
+ Cancels at enqueue if tenant is not in the allowlist.""" + return f"report sent for {tenant}" + + +# ── Reindex: load-shed + bypass when urgent ─────────────────── +healthy = ( + queue_size_under(1_000, queue="bulk") + & error_rate_under(0.1) + & ~queue_paused("bulk") +) + +@queue.task( + queue="bulk", + predicate=healthy | by_priority_at_least(8), + on_false="defer", + default_defer_seconds=30.0, +) +def reindex_search(tenant: str) -> int: + """Skipped when default queue is over 1k pending or error_rate > 10%, + unless explicitly enqueued with priority>=8.""" + return 42 # rows indexed + + +# ── Billing: gated by FF_NEW_BILLING env var ────────────────── +@queue.task( + queue="critical", + predicate=feature_flag("new_billing"), + on_false="cancel", +) +def charge_card(tenant: str, amount_cents: int) -> str: + return f"charged {amount_cents}c to {tenant}" + + +# ── Async + quota-bound indexer ─────────────────────────────── +@queue.task( + predicate=TenantQuotaUnder("http://quota.internal", limit=10_000), + on_false="defer", +) +async def index_document(tenant: str, doc_id: str) -> str: + return f"indexed {doc_id}" +``` + +## main.py + +Submitting jobs. Demonstrates each outcome. + +```python +"""Submit some jobs and watch the predicates kick in.""" + +from taskito.exceptions import PredicateRejectedError + +from .app import ( + charge_card, + index_document, + queue, + reindex_search, + send_daily_report, +) + + +def main() -> None: + # 1. send_daily_report: allowed for "acme"; deferred to 09:00 PT outside hours. + send_daily_report.delay(tenant="acme") + + # 2. send_daily_report: cancel at enqueue — tenant not in allowlist. + try: + send_daily_report.delay(tenant="initech") + except PredicateRejectedError as exc: + print(f"rejected: {exc.task_name} — {exc.reason}") + + # 3. reindex_search: normal path + reindex_search.delay(tenant="acme") + + # 4. 
reindex_search: bypass backpressure with priority + queue.enqueue(reindex_search.name, kwargs={"tenant": "acme"}, priority=9) + + # 5. charge_card: gated behind FF_NEW_BILLING; cancels if flag is off. + try: + charge_card.delay(tenant="acme", amount_cents=4200) + except PredicateRejectedError: + print("billing flag is off") + + # 6. async predicate + async task — both are awaited transparently + index_document.delay(tenant="acme", doc_id="doc-1") + + # Watch metrics live + print(queue._predicate_metrics.snapshot()) + + +if __name__ == "__main__": + main() +``` + +## worker.py + +```python +from .app import queue + +if __name__ == "__main__": + queue.run_worker() +``` + +## Subscribing to events + +Plug predicate outcomes into your alerting or dashboards. + +```python +from taskito.events import EventType + +def on_predicate_event(event_type: EventType, payload: dict) -> None: + match event_type: + case EventType.PREDICATE_DEFERRED: + print( + f"[defer] {payload['task_name']} +{payload['defer_seconds']}s " + f"phase={payload['phase']}" + ) + case EventType.PREDICATE_CANCELLED: + print(f"[cancel] {payload['task_name']} reason={payload.get('reason', '')}") + case EventType.PREDICATE_REJECTED: + print(f"[reject] {payload['task_name']} reason={payload.get('reason', '')}") + +for event in ( + EventType.PREDICATE_DEFERRED, + EventType.PREDICATE_CANCELLED, + EventType.PREDICATE_REJECTED, +): + queue._event_bus.on(event, on_predicate_event) +``` + +## Notes + +* Predicates are evaluated **twice** for each job (once at enqueue, once at dispatch). Use `ctx.job_id` to tell the phases apart inside a custom predicate. +* `on_false` only controls what happens when the predicate returns a plain `False`. Returning `Defer(seconds=...)` or `Cancel(reason=...)` is honored regardless. +* System-state recipes (`queue_size_under`, `error_rate_under`, `queue_paused`) memoize within a single evaluation, so composing them is cheap. 
+* A predicate that raises is treated as `False` (fail-closed) — the error is logged and `PredicateMetrics.errors` is incremented. + +See the [Predicates guide](/docs/guides/core/predicates) for the full API reference. From 6ef2a2f966666c562aae2950854dc4222c9e1765 Mon Sep 17 00:00:00 2001 From: Pratyush Sharma <56130065+pratyush618@users.noreply.github.com> Date: Mon, 11 May 2026 10:38:10 +0530 Subject: [PATCH 08/11] =?UTF-8?q?feat(predicates):=20v2=20=E2=80=94=20seri?= =?UTF-8?q?alizable=20AST,=20JSON=20+=20string=20DSL,=20drop=20redundant?= =?UTF-8?q?=20recipes?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- py_src/taskito/app.py | 58 ++++ py_src/taskito/mixins/decorators.py | 11 +- py_src/taskito/predicates/__init__.py | 42 +-- py_src/taskito/predicates/core.py | 280 +++++++++++++-- py_src/taskito/predicates/parser.py | 326 ++++++++++++++++++ py_src/taskito/predicates/recipes/__init__.py | 41 +-- .../taskito/predicates/recipes/attributes.py | 114 +++--- py_src/taskito/predicates/recipes/config.py | 110 +++++- py_src/taskito/predicates/recipes/system.py | 79 ++--- py_src/taskito/predicates/recipes/time.py | 278 +++++++++------ py_src/taskito/predicates/registry.py | 75 ++++ tests/python/test_predicates_core.py | 2 +- tests/python/test_predicates_dsl.py | 146 ++++++++ tests/python/test_predicates_persistence.py | 108 ++++++ tests/python/test_predicates_recipes.py | 133 ++----- 15 files changed, 1355 insertions(+), 448 deletions(-) create mode 100644 py_src/taskito/predicates/parser.py create mode 100644 py_src/taskito/predicates/registry.py create mode 100644 tests/python/test_predicates_dsl.py create mode 100644 tests/python/test_predicates_persistence.py diff --git a/py_src/taskito/app.py b/py_src/taskito/app.py index e9260b3..1040d12 100644 --- a/py_src/taskito/app.py +++ b/py_src/taskito/app.py @@ -44,9 +44,16 @@ QueueSettingsMixin, ) from taskito.predicates.context import PredicateContext +from 
taskito.predicates.core import Predicate as _Predicate from taskito.predicates.evaluate import evaluate_predicate from taskito.predicates.metrics import PredicateMetrics from taskito.predicates.outcomes import Cancel, Defer +from taskito.predicates.registry import ( + PredicateValidationError, +) +from taskito.predicates.registry import ( + register_predicate as _register_predicate, +) from taskito.proxies import ProxyRegistry from taskito.proxies.built_in import register_builtin_handlers from taskito.proxies.metrics import ProxyMetrics @@ -223,6 +230,7 @@ def __init__( self._task_predicate_on_false: dict[str, str] = {} self._task_predicate_extras: dict[str, dict[str, Any]] = {} self._task_default_defer: dict[str, float] = {} + self._task_predicate_serialized: dict[str, dict[str, Any] | None] = {} self._predicate_metrics = PredicateMetrics() self._drain_timeout = drain_timeout self._queue_configs: dict[str, dict[str, Any]] = {} @@ -311,6 +319,56 @@ def _deserialize_payload(self, task_name: str, payload: bytes) -> tuple: """Deserialize a job payload using the per-task or queue-level serializer.""" return self._get_serializer(task_name).loads(payload) # type: ignore[no-any-return] + # -- Predicate inspection / registration ----------------------------- + + def list_predicates(self) -> dict[str, dict[str, Any] | None]: + """Return the serialized predicate (or ``None`` for bare callables) + registered for every task that has one. + + The values are JSON-safe dicts produced by + :meth:`Predicate.to_dict`. Consumers (dashboard, audit logs) can + feed each value back through :meth:`Predicate.from_dict` to + rebuild the AST. 
+ """ + return dict(self._task_predicate_serialized) + + def predicate_for(self, task_name: str) -> dict[str, Any] | None: + """Return the serialized predicate for ``task_name`` or ``None``.""" + return self._task_predicate_serialized.get(task_name) + + def register_predicate(self, op: str, *, replace: bool = False) -> Callable[[type], type]: + """Class decorator: register a custom :class:`Predicate` subclass. + + Example:: + + from taskito.predicates import Predicate + + @queue.register_predicate("tenant_quota_under") + class TenantQuotaUnder(Predicate): + OP = "tenant_quota_under" + ... + + The ``OP`` set on the class must match ``op``. Once registered, + the predicate participates in JSON serialization and the string + DSL just like a built-in recipe. + """ + + def decorator(cls: type) -> type: + if not isinstance(cls, type) or not issubclass(cls, _Predicate): + raise PredicateValidationError( + f"register_predicate target must subclass Predicate; got {cls!r}" + ) + declared = cls.__dict__.get("OP") + if declared and declared != op: + raise PredicateValidationError( + f"OP mismatch: decorator says {op!r}, class declares {declared!r}" + ) + cls.OP = op + _register_predicate(op, cls, replace=replace) + return cls + + return decorator + def _apply_dispatch_predicate( self, *, diff --git a/py_src/taskito/mixins/decorators.py b/py_src/taskito/mixins/decorators.py index 60bfe31..cca936d 100644 --- a/py_src/taskito/mixins/decorators.py +++ b/py_src/taskito/mixins/decorators.py @@ -69,6 +69,7 @@ class QueueDecoratorMixin: _task_predicate_on_false: dict[str, str] _task_predicate_extras: dict[str, dict[str, Any]] _task_default_defer: dict[str, float] + _task_predicate_serialized: dict[str, dict[str, Any] | None] _predicate_metrics: PredicateMetrics _interceptor: ArgumentInterceptor | None _proxy_registry: ProxyRegistry | None @@ -342,7 +343,11 @@ def decorator(fn: Callable) -> TaskWrapper: if idempotent: self._task_idempotent[task_name] = True - # Store predicate (and its 
on_false/extras/default_defer) + # Store predicate (and its on_false/extras/default_defer). + # Also serialize a JSON snapshot so the inspection API and + # dashboard can show "gated by: ..." without keeping a live + # Python reference. Bare callables can't be serialized; the + # snapshot is None in that case. if predicate is not None: coerced = coerce_predicate(predicate) if coerced is not None: @@ -351,6 +356,10 @@ def decorator(fn: Callable) -> TaskWrapper: if predicate_extras: self._task_predicate_extras[task_name] = dict(predicate_extras) self._task_default_defer[task_name] = default_defer_seconds + try: + self._task_predicate_serialized[task_name] = coerced.to_dict() + except Exception: + self._task_predicate_serialized[task_name] = None # Store inject map for resource injection if final_inject: diff --git a/py_src/taskito/predicates/__init__.py b/py_src/taskito/predicates/__init__.py index 73890b1..25713c6 100644 --- a/py_src/taskito/predicates/__init__.py +++ b/py_src/taskito/predicates/__init__.py @@ -1,14 +1,11 @@ """Composable, fail-closed predicates for gating tasks. -A predicate is any subclass of :class:`Predicate` whose -:meth:`Predicate.evaluate` returns ``True`` (allow), ``False`` (deny), -:class:`Defer` (skip now, retry later), or :class:`Cancel` (skip -permanently). Predicates compose with ``&`` / ``|`` / ``~``:: - - from taskito.predicates import is_business_hours, queue_paused - - @queue.task(predicate=is_business_hours() & ~queue_paused()) - def send_report(): ... +A predicate is a serializable AST node — a subclass of +:class:`Predicate` whose :meth:`Predicate.evaluate` returns ``True`` +(allow), ``False`` (deny), :class:`Defer` (skip now, retry later), or +:class:`Cancel` (skip permanently). Predicates compose with ``&`` / +``|`` / ``~``; every resulting tree serializes through +:meth:`Predicate.to_dict` and :func:`parse` / :meth:`Predicate.format`. Built-in recipes are imported from :mod:`taskito.predicates.recipes`. 
""" @@ -26,15 +23,12 @@ def send_report(): ... from taskito.predicates.evaluate import evaluate_predicate from taskito.predicates.metrics import PredicateMetrics from taskito.predicates.outcomes import Cancel, Defer, PredicateOutcome +from taskito.predicates.parser import format_predicate, parse from taskito.predicates.providers import FeatureFlagProvider, env_feature_flag_provider from taskito.predicates.recipes import ( after, before, - by_priority_at_least, - by_queue, - by_task, env_var_truthy, - error_rate_under, feature_flag, in_time_window, in_timezone, @@ -42,8 +36,13 @@ def send_report(): ... is_weekend, payload_matches, queue_paused, - queue_size_under, - retry_count_under, + register_feature_flag_provider, +) +from taskito.predicates.registry import ( + PredicateRegistry, + PredicateValidationError, + default_registry, + register_predicate, ) __all__ = [ @@ -57,23 +56,24 @@ def send_report(): ... "PredicateContext", "PredicateMetrics", "PredicateOutcome", + "PredicateRegistry", + "PredicateValidationError", "after", "before", - "by_priority_at_least", - "by_queue", - "by_task", "coerce_predicate", + "default_registry", "env_feature_flag_provider", "env_var_truthy", - "error_rate_under", "evaluate_predicate", "feature_flag", + "format_predicate", "in_time_window", "in_timezone", "is_business_hours", "is_weekend", + "parse", "payload_matches", "queue_paused", - "queue_size_under", - "retry_count_under", + "register_feature_flag_provider", + "register_predicate", ] diff --git a/py_src/taskito/predicates/core.py b/py_src/taskito/predicates/core.py index 7d5a735..ec7188f 100644 --- a/py_src/taskito/predicates/core.py +++ b/py_src/taskito/predicates/core.py @@ -1,19 +1,36 @@ -"""Predicate ABC and composition primitives.""" +"""Predicate AST root + boolean combinator nodes. + +The :class:`Predicate` base is a serializable AST node. 
Every concrete +subclass declares a stable ``OP`` name that registers it with the +default :class:`~taskito.predicates.registry.PredicateRegistry`. + +Authors compose predicates with the standard Python operators (``&``, +``|``, ``~``); each operator produces a typed AST node so the resulting +tree round-trips through JSON via :meth:`Predicate.to_dict` and +:meth:`Predicate.from_dict`. A human-readable string surface is +available through :meth:`format` and the matching +:func:`~taskito.predicates.parser.parse` function. +""" from __future__ import annotations +import dataclasses from abc import ABC, abstractmethod from collections.abc import Awaitable, Callable -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, ClassVar from taskito.predicates.outcomes import Cancel, Defer +from taskito.predicates.registry import ( + PredicateValidationError, + default_registry, +) if TYPE_CHECKING: from taskito.predicates.context import PredicateContext PredicateReturn = bool | Defer | Cancel | Awaitable[Any] -"""What a :meth:`Predicate.evaluate` is allowed to return. +"""Anything a :meth:`Predicate.evaluate` is allowed to return. Synchronous predicates return :class:`bool`, :class:`Defer`, or :class:`Cancel`. Async predicates may also return an awaitable that @@ -21,24 +38,47 @@ """ +# Operator precedence used by :meth:`Predicate.format` so children of a +# higher-precedence node are wrapped in parentheses only when needed. +_PREC_OR = 1 +_PREC_AND = 2 +_PREC_NOT = 3 +_PREC_ATOM = 4 + + class Predicate(ABC): - """Base class for task predicates. + """Base class for every node in the predicate AST. - Subclass and implement :meth:`evaluate`. Predicates compose with the - standard boolean operators:: + Concrete subclasses set ``OP`` to a stable, registry-unique string. + Built-in subclasses register themselves automatically through + ``__init_subclass__``. 
Custom op classes either inherit and set + ``OP``, or register manually via + :func:`~taskito.predicates.registry.register_predicate`. - @queue.task(predicate=is_business_hours() & ~queue_paused()) - def send_report(): ... + Serialization: - Composition short-circuits: ``And`` stops at the first non-True, - ``Or`` stops at the first ``True``. ``Defer`` and ``Cancel`` outcomes - are propagated unchanged through ``~`` and through ``And``/``Or`` when - they cannot be overridden by the other operand. + * :meth:`to_dict` — emit ``{"op": OP, **fields}``. + * :meth:`from_dict` — dispatch via the registry, recurse for + composite nodes. + * :meth:`format` — stable, parser-compatible string surface. """ + OP: ClassVar[str | None] = None + PREC: ClassVar[int] = _PREC_ATOM + + def __init_subclass__(cls, **kw: Any) -> None: + super().__init_subclass__(**kw) + op = cls.__dict__.get("OP") + if op: + default_registry().register(op, cls) + + # -- Evaluation ------------------------------------------------------- + @abstractmethod def evaluate(self, ctx: PredicateContext) -> PredicateReturn: - """Return ``True`` to allow the job, ``False`` / ``Defer`` / ``Cancel`` to gate it.""" + """Return ``True`` to allow, ``False`` / Defer / Cancel to gate.""" + + # -- Composition ------------------------------------------------------ def __and__(self, other: Predicate) -> Predicate: return AndPredicate(self, other) @@ -49,22 +89,121 @@ def __or__(self, other: Predicate) -> Predicate: def __invert__(self) -> Predicate: return NotPredicate(self) - def __repr__(self) -> str: - return f"{type(self).__name__}()" + # -- Serialization ---------------------------------------------------- + + def to_dict(self) -> dict[str, Any]: + """Serialize the node to a JSON-safe dict. + + Default implementation works for any frozen dataclass: every + dataclass field is emitted as a key. Non-dataclass nodes + (``and``/``or``/``not``, custom classes) override this. 
+ """ + if not self.OP: + raise PredicateValidationError( + f"{type(self).__name__} has no OP name; cannot serialize" + ) + if dataclasses.is_dataclass(self): + fields = { + f.name: getattr(self, f.name) + for f in dataclasses.fields(self) + if not f.name.startswith("_") + } + return {"op": self.OP, **fields} + return {"op": self.OP} + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> Predicate: + """Build a predicate from a JSON-safe dict. + + Dispatches to the registry on ``data["op"]``. Unknown ops raise + :class:`PredicateValidationError`. + """ + if not isinstance(data, dict): + raise PredicateValidationError(f"expected dict, got {type(data).__name__}") + op = data.get("op") + if not isinstance(op, str) or not op: + raise PredicateValidationError(f"missing or invalid 'op' field: {data!r}") + target = default_registry().lookup(op) + kwargs = {k: v for k, v in data.items() if k != "op"} + try: + return target._from_kwargs(kwargs) + except (TypeError, ValueError) as exc: + raise PredicateValidationError( + f"cannot construct {op!r} from {kwargs!r}: {exc}" + ) from exc + + @classmethod + def _from_kwargs(cls, kwargs: dict[str, Any]) -> Predicate: + """Build an instance from deserialized kwargs. + + Default works for dataclasses with primitive fields. Composite + nodes (and/or/not) override to recurse. + """ + return cls(**kwargs) + + # -- String surface --------------------------------------------------- + + def format(self) -> str: + """Render this node in the parser-compatible string DSL.""" + return self._format_atom() + + def _format_atom(self) -> str: + """Default leaf rendering: ``op_name(field=value, ...)``.""" + if not self.OP: + # Custom Predicate subclasses without an OP can still appear + # in debug output of composed trees; emit a stable opaque + # marker rather than raising. 
+ return f"<{type(self).__name__}>" + if dataclasses.is_dataclass(self): + parts: list[str] = [] + for f in dataclasses.fields(self): + if f.name.startswith("_"): + continue + value = getattr(self, f.name) + default = ( + f.default + if f.default is not dataclasses.MISSING + else ( + f.default_factory() + if f.default_factory is not dataclasses.MISSING + else _UNSET + ) + ) + if default is not _UNSET and value == default: + continue + parts.append(f"{f.name}={_format_literal(value)}") + args = ", ".join(parts) + return f"{self.OP}({args})" + return f"{self.OP}()" + + def _format_child(self, child: Predicate) -> str: + text = child.format() + return f"({text})" if child.PREC < self.PREC else text + + # -- Repr passthrough ------------------------------------------------- + + def __repr__(self) -> str: # pragma: no cover - human aid + try: + return self.format() + except Exception: + return f"{type(self).__name__}()" + + +# -- Boolean combinator nodes -------------------------------------------- class AndPredicate(Predicate): """Logical AND with short-circuit evaluation.""" __slots__ = ("_left", "_right") + OP: ClassVar[str | None] = "and" + PREC = _PREC_AND def __init__(self, left: Predicate, right: Predicate) -> None: self._left = left self._right = right def evaluate(self, ctx: PredicateContext) -> PredicateReturn: - # Note: deferred imports avoided. Compose-time evaluation calls - # _resolve_sync which itself handles awaitables when present. 
from taskito.predicates.evaluate import _resolve_outcome left = _resolve_outcome(self._left, ctx) @@ -72,14 +211,26 @@ def evaluate(self, ctx: PredicateContext) -> PredicateReturn: return left return _resolve_outcome(self._right, ctx) - def __repr__(self) -> str: - return f"({self._left!r} & {self._right!r})" + def to_dict(self) -> dict[str, Any]: + return {"op": "and", "args": [self._left.to_dict(), self._right.to_dict()]} + + @classmethod + def _from_kwargs(cls, kwargs: dict[str, Any]) -> Predicate: + args = kwargs.get("args") + if not isinstance(args, list) or len(args) != 2: + raise PredicateValidationError(f"'and' expects exactly 2 args in a list, got {args!r}") + return cls(Predicate.from_dict(args[0]), Predicate.from_dict(args[1])) + + def format(self) -> str: + return f"{self._format_child(self._left)} & {self._format_child(self._right)}" class OrPredicate(Predicate): """Logical OR with short-circuit evaluation.""" __slots__ = ("_left", "_right") + OP: ClassVar[str | None] = "or" + PREC = _PREC_OR def __init__(self, left: Predicate, right: Predicate) -> None: self._left = left @@ -94,26 +245,37 @@ def evaluate(self, ctx: PredicateContext) -> PredicateReturn: right = _resolve_outcome(self._right, ctx) if right is True: return True - # Neither side allows. Prefer the most informative gating: Cancel - # wins over Defer, Defer wins over False. + # Both deny — surface the most informative outcome. 
if isinstance(left, Cancel) or isinstance(right, Cancel): return left if isinstance(left, Cancel) else right if isinstance(left, Defer) or isinstance(right, Defer): return left if isinstance(left, Defer) else right return False - def __repr__(self) -> str: - return f"({self._left!r} | {self._right!r})" + def to_dict(self) -> dict[str, Any]: + return {"op": "or", "args": [self._left.to_dict(), self._right.to_dict()]} + + @classmethod + def _from_kwargs(cls, kwargs: dict[str, Any]) -> Predicate: + args = kwargs.get("args") + if not isinstance(args, list) or len(args) != 2: + raise PredicateValidationError(f"'or' expects exactly 2 args in a list, got {args!r}") + return cls(Predicate.from_dict(args[0]), Predicate.from_dict(args[1])) + + def format(self) -> str: + return f"{self._format_child(self._left)} | {self._format_child(self._right)}" class NotPredicate(Predicate): """Logical NOT. - Inverts ``True`` / ``False``. ``Defer`` and ``Cancel`` outcomes pass + Inverts ``True``/``False``. ``Defer`` and ``Cancel`` outcomes pass through unchanged — they are terminal, not booleans. 
""" __slots__ = ("_inner",) + OP: ClassVar[str | None] = "not" + PREC = _PREC_NOT def __init__(self, inner: Predicate) -> None: self._inner = inner @@ -126,22 +288,34 @@ def evaluate(self, ctx: PredicateContext) -> PredicateReturn: return outcome return not outcome - def __repr__(self) -> str: - return f"~{self._inner!r}" + def to_dict(self) -> dict[str, Any]: + return {"op": "not", "arg": self._inner.to_dict()} + + @classmethod + def _from_kwargs(cls, kwargs: dict[str, Any]) -> Predicate: + arg = kwargs.get("arg") + if not isinstance(arg, dict): + raise PredicateValidationError(f"'not' expects an 'arg' dict, got {arg!r}") + return cls(Predicate.from_dict(arg)) + + def format(self) -> str: + return f"!{self._format_child(self._inner)}" + + +# -- Callable adapter (non-serializable) --------------------------------- class _CallablePredicate(Predicate): """Adapter that wraps a plain callable as a :class:`Predicate`. - Two signatures are supported: - - * ``Callable[[PredicateContext], bool | Defer | Cancel]`` — receives - the full context. - * ``Callable[[str], bool]`` — receives just the task name. Used to - preserve back-compat with the contrib middleware ``task_filter`` arg. + Used for back-compat with the contrib middleware ``task_filter`` + kwarg and for end-users who pass a bare callable to ``@queue.task``. + Not registered; ``to_dict`` / ``format`` raise because a callable + has no stable schema. 
""" __slots__ = ("_fn", "_takes_str") + OP: ClassVar[str | None] = None def __init__(self, fn: Callable[..., Any], *, takes_str: bool = False) -> None: self._fn = fn @@ -152,8 +326,17 @@ def evaluate(self, ctx: PredicateContext) -> PredicateReturn: return bool(self._fn(ctx.task_name)) return self._fn(ctx) # type: ignore[no-any-return] - def __repr__(self) -> str: - return f"callable({getattr(self._fn, '__qualname__', repr(self._fn))})" + def to_dict(self) -> dict[str, Any]: + raise PredicateValidationError( + "bare callables cannot be serialized; wrap your logic in a " + "Predicate subclass with a stable OP name" + ) + + def format(self) -> str: + return f"" + + +# -- Public coercion helper ---------------------------------------------- def coerce_predicate( @@ -163,8 +346,8 @@ def coerce_predicate( ) -> Predicate | None: """Wrap a callable as a :class:`Predicate`. ``None`` passes through. - Use ``str_callable=True`` to accept the legacy ``Callable[[str], bool]`` - contrib ``task_filter`` shape. + Use ``str_callable=True`` to accept the legacy + ``Callable[[str], bool]`` contrib ``task_filter`` shape. """ if value is None: return None @@ -175,3 +358,28 @@ def coerce_predicate( raise TypeError( f"predicate must be a Predicate, callable, or None; got {type(value).__name__}" ) + + +# -- Internal helpers ---------------------------------------------------- + + +_UNSET = object() + + +def _format_literal(value: Any) -> str: + """Render ``value`` as a parser-compatible literal.""" + if isinstance(value, str): + # Use double quotes; escape inner doubles. 
+ return '"' + value.replace("\\", "\\\\").replace('"', '\\"') + '"' + if isinstance(value, bool): + return "true" if value else "false" + if value is None: + return "null" + if isinstance(value, (int, float)): + return repr(value) + if isinstance(value, (list, tuple)): + return "[" + ", ".join(_format_literal(v) for v in value) + "]" + # Fallback: dataclasses / objects we don't deeply support yet — let + # repr handle it. Round-trip via the string parser isn't guaranteed + # for non-primitive values, but JSON round-trip still works. + return repr(value) diff --git a/py_src/taskito/predicates/parser.py b/py_src/taskito/predicates/parser.py new file mode 100644 index 0000000..a47957b --- /dev/null +++ b/py_src/taskito/predicates/parser.py @@ -0,0 +1,326 @@ +"""Tiny recursive-descent parser for the predicate string DSL. + +Grammar:: + + expr := or_expr + or_expr := and_expr ("|" and_expr)* + and_expr := unary ("&" unary)* + unary := "!" unary | atom + atom := IDENT "(" kwargs? ")" | "(" expr ")" + kwargs := kwarg ("," kwarg)* + kwarg := IDENT "=" literal + literal := STRING | NUMBER | "true" | "false" | "null" | "[" list_items? "]" + +Examples:: + + is_business_hours(tz="US/Pacific") & !queue_paused() + feature_flag(name="billing") | payload_matches(path="kwargs.tenant", expected="acme") + +The parser builds the same :class:`~taskito.predicates.core.Predicate` +AST that :meth:`Predicate.from_dict` produces, so it round-trips with +both JSON and Python-operator forms. 
+""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +from taskito.predicates.core import Predicate +from taskito.predicates.registry import PredicateValidationError, default_registry + +# -- Token model --------------------------------------------------------- + + +class _TokType: + IDENT = "IDENT" + STRING = "STRING" + NUMBER = "NUMBER" + BOOL = "BOOL" + NULL = "NULL" + LPAREN = "LPAREN" + RPAREN = "RPAREN" + LBRACKET = "LBRACKET" + RBRACKET = "RBRACKET" + EQ = "EQ" + COMMA = "COMMA" + AND = "AND" + OR = "OR" + NOT = "NOT" + EOF = "EOF" + + +@dataclass(frozen=True) +class _Token: + kind: str + value: Any + pos: int + + +# -- Lexer --------------------------------------------------------------- + + +_KEYWORDS = {"true": True, "false": False, "null": None} + + +def _tokenize(source: str) -> list[_Token]: + tokens: list[_Token] = [] + i = 0 + n = len(source) + while i < n: + ch = source[i] + if ch.isspace(): + i += 1 + continue + if ch == "(": + tokens.append(_Token(_TokType.LPAREN, ch, i)) + i += 1 + continue + if ch == ")": + tokens.append(_Token(_TokType.RPAREN, ch, i)) + i += 1 + continue + if ch == "[": + tokens.append(_Token(_TokType.LBRACKET, ch, i)) + i += 1 + continue + if ch == "]": + tokens.append(_Token(_TokType.RBRACKET, ch, i)) + i += 1 + continue + if ch == ",": + tokens.append(_Token(_TokType.COMMA, ch, i)) + i += 1 + continue + if ch == "=": + tokens.append(_Token(_TokType.EQ, ch, i)) + i += 1 + continue + if ch == "&": + tokens.append(_Token(_TokType.AND, ch, i)) + i += 1 + continue + if ch == "|": + tokens.append(_Token(_TokType.OR, ch, i)) + i += 1 + continue + if ch == "!": + tokens.append(_Token(_TokType.NOT, ch, i)) + i += 1 + continue + if ch == '"' or ch == "'": + string_value, j = _lex_string(source, i) + tokens.append(_Token(_TokType.STRING, string_value, i)) + i = j + continue + if ch.isdigit() or (ch == "-" and i + 1 < n and source[i + 1].isdigit()): + number_value, j = 
_lex_number(source, i) + tokens.append(_Token(_TokType.NUMBER, number_value, i)) + i = j + continue + if ch.isalpha() or ch == "_": + value, j = _lex_ident(source, i) + if value in _KEYWORDS: + kind = _TokType.BOOL if isinstance(_KEYWORDS[value], bool) else _TokType.NULL + tokens.append(_Token(kind, _KEYWORDS[value], i)) + elif value == "and": + tokens.append(_Token(_TokType.AND, "and", i)) + elif value == "or": + tokens.append(_Token(_TokType.OR, "or", i)) + elif value == "not": + tokens.append(_Token(_TokType.NOT, "not", i)) + else: + tokens.append(_Token(_TokType.IDENT, value, i)) + i = j + continue + raise PredicateValidationError(f"unexpected character {ch!r} at position {i}") + tokens.append(_Token(_TokType.EOF, None, n)) + return tokens + + +def _lex_string(source: str, start: int) -> tuple[str, int]: + quote = source[start] + out: list[str] = [] + i = start + 1 + n = len(source) + while i < n: + ch = source[i] + if ch == "\\" and i + 1 < n: + nxt = source[i + 1] + escapes = {"n": "\n", "t": "\t", "r": "\r", "\\": "\\", '"': '"', "'": "'"} + out.append(escapes.get(nxt, nxt)) + i += 2 + continue + if ch == quote: + return "".join(out), i + 1 + out.append(ch) + i += 1 + raise PredicateValidationError(f"unterminated string starting at position {start}") + + +def _lex_number(source: str, start: int) -> tuple[int | float, int]: + i = start + n = len(source) + if source[i] == "-": + i += 1 + saw_dot = False + saw_exp = False + while i < n: + ch = source[i] + if ch.isdigit(): + i += 1 + elif ch == "." 
and not saw_dot and not saw_exp: + saw_dot = True + i += 1 + elif ch in ("e", "E") and not saw_exp: + saw_exp = True + saw_dot = True # forbid more dots after exp + i += 1 + if i < n and source[i] in ("+", "-"): + i += 1 + else: + break + text = source[start:i] + if saw_dot or saw_exp: + return float(text), i + return int(text), i + + +def _lex_ident(source: str, start: int) -> tuple[str, int]: + i = start + n = len(source) + while i < n and (source[i].isalnum() or source[i] == "_"): + i += 1 + return source[start:i], i + + +# -- Parser -------------------------------------------------------------- + + +class _Parser: + __slots__ = ("_pos", "_tokens") + + def __init__(self, tokens: list[_Token]) -> None: + self._tokens = tokens + self._pos = 0 + + def parse(self) -> Predicate: + node = self._or() + if self._peek().kind != _TokType.EOF: + tok = self._peek() + raise PredicateValidationError(f"unexpected token {tok.value!r} at position {tok.pos}") + return node + + def _peek(self) -> _Token: + return self._tokens[self._pos] + + def _eat(self, kind: str) -> _Token: + tok = self._peek() + if tok.kind != kind: + raise PredicateValidationError( + f"expected {kind}, got {tok.kind} ({tok.value!r}) at position {tok.pos}" + ) + self._pos += 1 + return tok + + def _consume(self, kind: str) -> bool: + if self._peek().kind == kind: + self._pos += 1 + return True + return False + + def _or(self) -> Predicate: + left = self._and() + while self._consume(_TokType.OR): + right = self._and() + left = left | right + return left + + def _and(self) -> Predicate: + left = self._unary() + while self._consume(_TokType.AND): + right = self._unary() + left = left & right + return left + + def _unary(self) -> Predicate: + if self._consume(_TokType.NOT): + return ~self._unary() + return self._atom() + + def _atom(self) -> Predicate: + tok = self._peek() + if tok.kind == _TokType.LPAREN: + self._eat(_TokType.LPAREN) + inner = self._or() + self._eat(_TokType.RPAREN) + return inner + if tok.kind 
== _TokType.IDENT: + return self._call() + raise PredicateValidationError( + f"expected predicate atom, got {tok.kind} ({tok.value!r}) at position {tok.pos}" + ) + + def _call(self) -> Predicate: + ident = self._eat(_TokType.IDENT) + self._eat(_TokType.LPAREN) + kwargs: dict[str, Any] = {} + if not self._consume(_TokType.RPAREN): + while True: + name = self._eat(_TokType.IDENT).value + self._eat(_TokType.EQ) + kwargs[name] = self._literal() + if not self._consume(_TokType.COMMA): + break + self._eat(_TokType.RPAREN) + op_name = ident.value + target = default_registry().lookup(op_name) + try: + return target._from_kwargs(kwargs) + except (TypeError, ValueError, PredicateValidationError) as exc: + raise PredicateValidationError( + f"cannot construct {op_name!r} from {kwargs!r}: {exc}" + ) from exc + + def _literal(self) -> Any: + tok = self._peek() + if tok.kind in (_TokType.STRING, _TokType.NUMBER, _TokType.BOOL, _TokType.NULL): + self._pos += 1 + return tok.value + if tok.kind == _TokType.LBRACKET: + self._eat(_TokType.LBRACKET) + items: list[Any] = [] + if not self._consume(_TokType.RBRACKET): + while True: + items.append(self._literal()) + if not self._consume(_TokType.COMMA): + break + self._eat(_TokType.RBRACKET) + return items + raise PredicateValidationError( + f"expected literal, got {tok.kind} ({tok.value!r}) at position {tok.pos}" + ) + + +# -- Public API ---------------------------------------------------------- + + +def parse(source: str) -> Predicate: + """Parse a predicate string into a :class:`Predicate` AST. + + Raises :class:`~taskito.predicates.PredicateValidationError` on + malformed input or unknown op names. + """ + if not isinstance(source, str): + raise PredicateValidationError(f"parse() expected str, got {type(source).__name__}") + tokens = _tokenize(source) + return _Parser(tokens).parse() + + +def format_predicate(node: Predicate) -> str: + """Render a :class:`Predicate` in the string DSL. 
+ + Convenience wrapper around :meth:`Predicate.format`. The output is + stable and ``parse(format_predicate(p))`` rebuilds an equivalent AST. + """ + return node.format() diff --git a/py_src/taskito/predicates/recipes/__init__.py b/py_src/taskito/predicates/recipes/__init__.py index b04c0eb..329431b 100644 --- a/py_src/taskito/predicates/recipes/__init__.py +++ b/py_src/taskito/predicates/recipes/__init__.py @@ -1,35 +1,25 @@ -"""Predefined predicate recipes. +"""Built-in predicate recipes. -Recipes are factory functions that return :class:`~taskito.predicates.Predicate` -instances. Each recipe accepts plain Python values so it can be used in -decorator declarations:: +Each recipe is a factory function returning a +:class:`~taskito.predicates.Predicate`. Every recipe is a registered AST +node (`OP` declared) so it round-trips through JSON and the string DSL. - @queue.task( - predicate=is_business_hours(tz="US/Pacific") - & ~queue_paused() - | by_priority_at_least(8), - ) - def send_report(): ... +Recipes that duplicated Rust-side enforcement (``by_queue``, +``by_task``, ``queue_size_under``, ``error_rate_under``, +``retry_count_under``, ``by_priority_at_least``) were intentionally +removed in v2 — use the corresponding ``@queue.task`` / ``Queue.set_*`` +options instead. """ from __future__ import annotations -from taskito.predicates.recipes.attributes import ( - by_priority_at_least, - by_queue, - by_task, - payload_matches, - retry_count_under, -) +from taskito.predicates.recipes.attributes import payload_matches from taskito.predicates.recipes.config import ( env_var_truthy, feature_flag, + register_feature_flag_provider, ) -from taskito.predicates.recipes.system import ( - error_rate_under, - queue_paused, - queue_size_under, -) +from taskito.predicates.recipes.system import queue_paused from taskito.predicates.recipes.time import ( after, before, @@ -42,11 +32,7 @@ def send_report(): ... 
__all__ = [ "after", "before", - "by_priority_at_least", - "by_queue", - "by_task", "env_var_truthy", - "error_rate_under", "feature_flag", "in_time_window", "in_timezone", @@ -54,6 +40,5 @@ def send_report(): ... "is_weekend", "payload_matches", "queue_paused", - "queue_size_under", - "retry_count_under", + "register_feature_flag_provider", ] diff --git a/py_src/taskito/predicates/recipes/attributes.py b/py_src/taskito/predicates/recipes/attributes.py index cd2cd8b..3303544 100644 --- a/py_src/taskito/predicates/recipes/attributes.py +++ b/py_src/taskito/predicates/recipes/attributes.py @@ -1,102 +1,68 @@ -"""Predicates that read job metadata.""" +"""Predicates that read job payload fields. + +The other "attribute"-flavoured recipes (``by_queue``, ``by_task``, +``by_priority_at_least``, ``retry_count_under``) were intentionally +removed in v2: each duplicated a gate the Rust scheduler already +enforces (per-task queue routing, ``priority`` ordering, +``max_concurrent``/``rate_limit``, ``max_retries``). Restating those +inside a Python predicate produces weaker, non-atomic gates that race +with the authoritative enforcement path. + +What remains here is :func:`payload_matches` — a genuinely new +capability that lets predicates branch on the deserialized argument +graph at runtime. 
+""" from __future__ import annotations -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, ClassVar from taskito.predicates.core import Predicate +from taskito.predicates.registry import PredicateValidationError if TYPE_CHECKING: from taskito.predicates.context import PredicateContext -@dataclass(frozen=True) -class _ByQueue(Predicate): - name: str - - def evaluate(self, ctx: PredicateContext) -> bool: - return ctx.queue == self.name - - -def by_queue(name: str) -> Predicate: - """Allow only jobs whose target queue equals ``name``.""" - if not name: - raise ValueError("queue name must be non-empty") - return _ByQueue(name=name) - - -@dataclass(frozen=True) -class _ByTask(Predicate): - name: str - - def evaluate(self, ctx: PredicateContext) -> bool: - return ctx.task_name == self.name - - -def by_task(name: str) -> Predicate: - """Allow only jobs for the given task name (full module-qualified).""" - if not name: - raise ValueError("task name must be non-empty") - return _ByTask(name=name) - - -@dataclass(frozen=True) -class _ByPriorityAtLeast(Predicate): - threshold: int - - def evaluate(self, ctx: PredicateContext) -> bool: - return ctx.priority >= self.threshold +_MISSING = object() -def by_priority_at_least(threshold: int) -> Predicate: - """Allow jobs whose priority is ``>= threshold``.""" - return _ByPriorityAtLeast(threshold=threshold) +def _safe_lookup(node: Any, key: str) -> Any: + if isinstance(node, dict): + return node.get(key, _MISSING) + if isinstance(node, (list, tuple)): + try: + return node[int(key)] + except (ValueError, IndexError): + return _MISSING + return getattr(node, key, _MISSING) @dataclass(frozen=True) -class _RetryCountUnder(Predicate): - limit: int - - def evaluate(self, ctx: PredicateContext) -> bool: - return ctx.retry_count < self.limit - +class PayloadMatches(Predicate): + """Match a value in ``args``/``kwargs`` by dotted path.""" 
-def retry_count_under(limit: int) -> Predicate: - """Allow jobs whose retry counter is strictly less than ``limit``.""" - if limit < 0: - raise ValueError("limit must be >= 0") - return _RetryCountUnder(limit=limit) + OP: ClassVar[str | None] = "payload_matches" + path: str = "" + expected: Any = None + _segments: tuple[str, ...] = field(default=(), init=False, repr=False, compare=False) -@dataclass(frozen=True) -class _PayloadMatches(Predicate): - path: tuple[str, ...] - expected: Any + def __post_init__(self) -> None: + if not self.path: + raise PredicateValidationError("payload_matches: path must be non-empty") + object.__setattr__(self, "_segments", tuple(self.path.split("."))) def evaluate(self, ctx: PredicateContext) -> bool: node: Any = {"args": ctx.args, "kwargs": ctx.kwargs} - for segment in self.path: + for segment in self._segments: node = _safe_lookup(node, segment) if node is _MISSING: return False return bool(node == self.expected) -_MISSING = object() - - -def _safe_lookup(node: Any, key: str) -> Any: - if isinstance(node, dict): - return node.get(key, _MISSING) - if isinstance(node, (list, tuple)): - try: - return node[int(key)] - except (ValueError, IndexError): - return _MISSING - return getattr(node, key, _MISSING) - - def payload_matches(path: str, expected: Any) -> Predicate: """Match a value in args/kwargs by dotted path. 
@@ -107,6 +73,4 @@ def payload_matches(path: str, expected: Any) -> Predicate: * ``"args.0"`` → ``ctx.args[0]`` * ``"kwargs.config.region"`` → ``ctx.kwargs["config"]["region"]`` """ - if not path: - raise ValueError("path must be non-empty") - return _PayloadMatches(path=tuple(path.split(".")), expected=expected) + return PayloadMatches(path=path, expected=expected) diff --git a/py_src/taskito/predicates/recipes/config.py b/py_src/taskito/predicates/recipes/config.py index 7c2dd87..e1b2260 100644 --- a/py_src/taskito/predicates/recipes/config.py +++ b/py_src/taskito/predicates/recipes/config.py @@ -3,14 +3,15 @@ from __future__ import annotations import os -from dataclasses import dataclass -from typing import TYPE_CHECKING +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, ClassVar from taskito.predicates.core import Predicate from taskito.predicates.providers import ( FeatureFlagProvider, env_feature_flag_provider, ) +from taskito.predicates.registry import PredicateValidationError if TYPE_CHECKING: from taskito.predicates.context import PredicateContext @@ -20,8 +21,16 @@ @dataclass(frozen=True) -class _EnvVarTruthy(Predicate): - name: str +class EnvVarTruthy(Predicate): + """Allow when env var ``name`` is set to a truthy value.""" + + OP: ClassVar[str | None] = "env_var_truthy" + + name: str = "" + + def __post_init__(self) -> None: + if not self.name: + raise PredicateValidationError("env_var_truthy: name must be non-empty") def evaluate(self, ctx: PredicateContext) -> bool: return os.environ.get(self.name, "").strip().lower() in _TRUTHY @@ -29,27 +38,94 @@ def evaluate(self, ctx: PredicateContext) -> bool: def env_var_truthy(name: str) -> Predicate: """Allow when env var ``name`` is set to ``1``/``true``/``yes``/``on``.""" + return EnvVarTruthy(name=name) + + +# Per-process feature-flag providers, addressable by stable string id. 
+# Allows ``feature_flag("billing", provider="my-ld")`` to round-trip +# through JSON/string forms — the provider object itself isn't JSON, +# but its registered name is. +_FF_PROVIDERS: dict[str, FeatureFlagProvider] = {} + + +def register_feature_flag_provider(name: str, provider: FeatureFlagProvider) -> None: + """Register a :class:`FeatureFlagProvider` under a stable string id. + + Once registered, ``feature_flag("flag", provider="")`` + serializes and deserializes cleanly. The default ``"env"`` provider + is always available. + """ if not name: - raise ValueError("env var name must be non-empty") - return _EnvVarTruthy(name=name) + raise PredicateValidationError("provider name must be non-empty") + _FF_PROVIDERS[name] = provider + + +def _resolve_ff_provider( + value: FeatureFlagProvider | str | None, +) -> tuple[FeatureFlagProvider, str]: + """Return (provider, stable_name).""" + if value is None: + return env_feature_flag_provider(), "env" + if isinstance(value, str): + if value == "env": + return env_feature_flag_provider(), "env" + try: + return _FF_PROVIDERS[value], value + except KeyError: + raise PredicateValidationError( + f"unknown feature-flag provider {value!r}; register it via " + "register_feature_flag_provider() before deserializing" + ) from None + # Instance — best-effort reverse lookup; falls back to a synthetic + # marker that fails clean round-trip. 
+ for stored_name, stored in _FF_PROVIDERS.items(): + if stored is value: + return value, stored_name + return value, "" @dataclass(frozen=True) -class _FeatureFlag(Predicate): - flag: str - provider: FeatureFlagProvider +class FeatureFlag(Predicate): + """Allow when feature flag ``flag`` is enabled via ``provider``.""" + + OP: ClassVar[str | None] = "feature_flag" + + flag: str = "" + provider: str = "env" + _resolved: FeatureFlagProvider | None = field( + default=None, init=False, repr=False, compare=False + ) + + def __post_init__(self) -> None: + if not self.flag: + raise PredicateValidationError("feature_flag: flag must be non-empty") + resolved, _name = _resolve_ff_provider(self.provider) + object.__setattr__(self, "_resolved", resolved) def evaluate(self, ctx: PredicateContext) -> bool: - return bool(self.provider.is_enabled(self.flag, ctx)) + assert self._resolved is not None # set in __post_init__ + return bool(self._resolved.is_enabled(self.flag, ctx)) + def to_dict(self) -> dict[str, Any]: + return {"op": self.OP, "flag": self.flag, "provider": self.provider} -def feature_flag(name: str, *, provider: FeatureFlagProvider | None = None) -> Predicate: + +def feature_flag( + name: str, + *, + provider: FeatureFlagProvider | str | None = None, +) -> Predicate: """Allow when feature flag ``name`` is enabled. - Defaults to an env-var-backed provider with prefix ``FF_`` (e.g. - ``feature_flag("new-billing")`` reads ``FF_NEW-BILLING``). Pass a - custom provider to integrate LaunchDarkly, Statsig, etc. + ``provider`` may be: + + * ``None`` (default) — env-var-backed reading ``FF_``. + * A registered provider name (e.g. ``"launchdarkly"``) previously + bound via :func:`register_feature_flag_provider`. + * A :class:`FeatureFlagProvider` instance — used directly. If the + instance was registered, serialization round-trips by name; + otherwise serialization emits ``""`` and re-deserializing + requires re-binding. 
""" - if not name: - raise ValueError("flag name must be non-empty") - return _FeatureFlag(flag=name, provider=provider or env_feature_flag_provider()) + _resolved, stable_name = _resolve_ff_provider(provider) + return FeatureFlag(flag=name, provider=stable_name) diff --git a/py_src/taskito/predicates/recipes/system.py b/py_src/taskito/predicates/recipes/system.py index 0cbc4c0..9909615 100644 --- a/py_src/taskito/predicates/recipes/system.py +++ b/py_src/taskito/predicates/recipes/system.py @@ -1,9 +1,21 @@ -"""Predicates that read live system state from the queue.""" +"""Observational predicates over live queue state. + +These recipes **read** state the Rust scheduler maintains. They are not +primary enforcement: hard backpressure belongs in +``@queue.task(max_concurrent=...)``, ``rate_limit=...``, or +``circuit_breaker=...`` — all enforced atomically in the Rust poller. + +``queue_paused`` is kept as a *defensive* composable: e.g. +``~queue_paused() & is_business_hours()`` lets a middleware skip a +gauge update when the queue is administratively paused. It does not +replace ``Queue.pause()`` / ``Queue.resume()``, which are the +authoritative pause mechanism. +""" from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, ClassVar from taskito.predicates.core import Predicate @@ -12,66 +24,21 @@ @dataclass(frozen=True) -class _QueueSizeUnder(Predicate): - limit: int - queue_name: str | None - - def evaluate(self, ctx: PredicateContext) -> bool: - return ctx.queue_size(self.queue_name) < self.limit - +class QueuePaused(Predicate): + """True when the given queue is currently paused. -def queue_size_under(limit: int, *, queue: str | None = None) -> Predicate: - """Allow only when the queue has fewer than ``limit`` pending jobs. - - ``queue=None`` (the default) inspects the job's own queue. Pass an - explicit name to gate based on a different queue. 
+ Reads ``list_paused_queues()`` from storage. Typically used + inverted: ``~queue_paused()``. """ - if limit <= 0: - raise ValueError("limit must be > 0") - return _QueueSizeUnder(limit=limit, queue_name=queue) + OP: ClassVar[str | None] = "queue_paused" -@dataclass(frozen=True) -class _QueuePaused(Predicate): - queue_name: str | None + queue: str | None = None def evaluate(self, ctx: PredicateContext) -> bool: - return ctx.queue_paused(self.queue_name) + return ctx.queue_paused(self.queue) def queue_paused(queue: str | None = None) -> Predicate: - """True when the queue is currently paused. - - Typically used inverted: ``~queue_paused()``. - """ - return _QueuePaused(queue_name=queue) - - -@dataclass(frozen=True) -class _ErrorRateUnder(Predicate): - max_rate: float - - def evaluate(self, ctx: PredicateContext) -> bool: - stats = ctx.stats() - if not stats: - return True - completed = int(stats.get("completed", 0)) - failed = int(stats.get("failed", 0)) - dead = int(stats.get("dead", 0)) - total = completed + failed + dead - if total <= 0: - return True - rate = (failed + dead) / total - return rate < self.max_rate - - -def error_rate_under(max_rate: float) -> Predicate: - """Allow only when the global failure ratio is below ``max_rate``. - - ``max_rate`` is a fraction in (0, 1]. The metric is computed from - backend-wide stats: ``(failed + dead) / (completed + failed + dead)``. - When no jobs have run yet, the predicate allows. - """ - if not 0.0 < max_rate <= 1.0: - raise ValueError("max_rate must be in (0, 1]") - return _ErrorRateUnder(max_rate=max_rate) + """True when ``queue`` (defaults to the job's own queue) is paused.""" + return QueuePaused(queue=queue) diff --git a/py_src/taskito/predicates/recipes/time.py b/py_src/taskito/predicates/recipes/time.py index 252cf6f..a6c651d 100644 --- a/py_src/taskito/predicates/recipes/time.py +++ b/py_src/taskito/predicates/recipes/time.py @@ -1,17 +1,18 @@ """Time-based predicates. 
-All recipes work with timezone-aware datetimes. ``tz`` arguments accept an -IANA timezone string (e.g. ``"US/Pacific"``); ``None`` means UTC. +All recipes work with timezone-aware datetimes. ``tz`` arguments accept +an IANA timezone string (e.g. ``"US/Pacific"``); ``None`` means UTC. """ from __future__ import annotations from dataclasses import dataclass -from datetime import datetime, timedelta, timezone -from typing import TYPE_CHECKING +from datetime import datetime, timedelta, timezone, tzinfo +from typing import TYPE_CHECKING, ClassVar from taskito.predicates.core import Predicate from taskito.predicates.outcomes import Defer +from taskito.predicates.registry import PredicateValidationError if TYPE_CHECKING: from taskito.predicates.context import PredicateContext @@ -30,52 +31,83 @@ class ZoneInfoNotFoundError(Exception): # type: ignore[no-redef] _ONE_DAY = timedelta(days=1) -def _resolve_tz(tz: str | None) -> timezone: - """Resolve a tz string to a ``tzinfo``-compatible object.""" +def _resolve_tz(tz: str | None) -> tzinfo: if tz is None: return timezone.utc if not _HAS_ZONEINFO: - raise RuntimeError("zoneinfo is not available; pass tz=None or install Python 3.9+") + raise PredicateValidationError( + "zoneinfo is not available; pass tz=None or install Python 3.9+" + ) try: - return ZoneInfo(tz) # type: ignore[return-value] + return ZoneInfo(tz) except ZoneInfoNotFoundError as exc: - raise ValueError(f"unknown timezone: {tz!r}") from exc + raise PredicateValidationError(f"unknown timezone: {tz!r}") from exc + + +def _parse_hhmm(value: str) -> tuple[int, int]: + parts = value.split(":") + if len(parts) != 2: + raise PredicateValidationError(f"expected 'HH:MM', got {value!r}") + try: + h, m = int(parts[0]), int(parts[1]) + except ValueError as exc: + raise PredicateValidationError(f"expected 'HH:MM', got {value!r}") from exc + if not 0 <= h <= 23 or not 0 <= m <= 59: + raise PredicateValidationError(f"hour/minute out of range: {value!r}") + return h, m -def 
_seconds_until_window(now_local: datetime, start_hour: int, start_minute: int) -> float: - """Seconds from ``now_local`` until the next occurrence of (h, m).""" - target = now_local.replace(hour=start_hour, minute=start_minute, second=0, microsecond=0) - if target <= now_local: - target = target + _ONE_DAY +def _seconds_until(now_local: datetime, target: datetime) -> float: return max(1.0, (target - now_local).total_seconds()) +# -- is_business_hours --------------------------------------------------- + + @dataclass(frozen=True) -class _IsBusinessHours(Predicate): - """True Mon-Fri between 09:00 and 17:00 in the configured timezone.""" +class IsBusinessHours(Predicate): + """Allow only Mon-Fri ``start_hour..end_hour`` in the configured tz. + + Outside the window the predicate returns :class:`Defer` aimed at the + next opening, so jobs land neatly at the next business start. + """ + + OP: ClassVar[str | None] = "is_business_hours" start_hour: int = 9 end_hour: int = 17 tz: str | None = None weekdays_only: bool = True + def __post_init__(self) -> None: + if not (0 <= self.start_hour < 24 and 0 < self.end_hour <= 24): + raise PredicateValidationError( + f"is_business_hours: hours out of range ({self.start_hour}, {self.end_hour})" + ) + if self.start_hour >= self.end_hour: + raise PredicateValidationError( + f"is_business_hours: start_hour ({self.start_hour}) must be < " + f"end_hour ({self.end_hour})" + ) + # Validate tz at construction. 
+ _resolve_tz(self.tz) + def evaluate(self, ctx: PredicateContext) -> bool | Defer: - tz_info = _resolve_tz(self.tz) - local = ctx.now().astimezone(tz_info) + local = ctx.now().astimezone(_resolve_tz(self.tz)) if self.weekdays_only and local.weekday() >= 5: - # Sat / Sun → defer until Monday at start_hour days_to_monday = 7 - local.weekday() - defer = local.replace( + target = (local + timedelta(days=days_to_monday)).replace( hour=self.start_hour, minute=0, second=0, microsecond=0 - ) + timedelta(days=days_to_monday) - return Defer(seconds=max(1.0, (defer - local).total_seconds())) + ) + return Defer(seconds=_seconds_until(local, target)) if local.hour < self.start_hour: - return Defer(seconds=_seconds_until_window(local, self.start_hour, 0)) + target = local.replace(hour=self.start_hour, minute=0, second=0, microsecond=0) + return Defer(seconds=_seconds_until(local, target)) if local.hour >= self.end_hour: - # Defer to next day at start_hour - next_day = local + _ONE_DAY - target = next_day.replace(hour=self.start_hour, minute=0, second=0, microsecond=0) - return Defer(seconds=max(1.0, (target - local).total_seconds())) + target = (local + _ONE_DAY).replace( + hour=self.start_hour, minute=0, second=0, microsecond=0 + ) + return Defer(seconds=_seconds_until(local, target)) return True @@ -86,138 +118,168 @@ def is_business_hours( tz: str | None = None, weekdays_only: bool = True, ) -> Predicate: - """Allow only during business hours; otherwise :class:`Defer` to the next window. - - Defaults to 09:00-17:00 Mon-Fri in UTC. Pass ``tz="US/Pacific"`` (or - any IANA name) for a local-time window. 
- """ - if not 0 <= start_hour < 24 or not 0 < end_hour <= 24 or start_hour >= end_hour: - raise ValueError(f"invalid business-hours window: {start_hour}-{end_hour}") - return _IsBusinessHours( + """Allow during business hours; defer to the next window otherwise.""" + return IsBusinessHours( start_hour=start_hour, end_hour=end_hour, tz=tz, weekdays_only=weekdays_only ) +# -- is_weekend ---------------------------------------------------------- + + @dataclass(frozen=True) -class _IsWeekend(Predicate): +class IsWeekend(Predicate): + """True on Saturday or Sunday in the configured timezone.""" + + OP: ClassVar[str | None] = "is_weekend" + tz: str | None = None + def __post_init__(self) -> None: + _resolve_tz(self.tz) + def evaluate(self, ctx: PredicateContext) -> bool: - tz_info = _resolve_tz(self.tz) - local = ctx.now().astimezone(tz_info) + local = ctx.now().astimezone(_resolve_tz(self.tz)) return local.weekday() >= 5 def is_weekend(*, tz: str | None = None) -> Predicate: - """True on Saturday or Sunday in the configured timezone.""" - return _IsWeekend(tz=tz) + """True when the current local day is Saturday or Sunday.""" + return IsWeekend(tz=tz) + + +# -- in_time_window ------------------------------------------------------ @dataclass(frozen=True) -class _InTimeWindow(Predicate): - start_hour: int - start_minute: int - end_hour: int - end_minute: int +class InTimeWindow(Predicate): + """Allow during ``[start, end)``; defer otherwise. + + Stored as ``"HH:MM"`` strings so the AST serializes cleanly. 
+ """ + + OP: ClassVar[str | None] = "in_time_window" + + start: str = "00:00" + end: str = "23:59" tz: str | None = None + def __post_init__(self) -> None: + sh, sm = _parse_hhmm(self.start) + eh, em = _parse_hhmm(self.end) + if (sh, sm) >= (eh, em): + raise PredicateValidationError( + f"in_time_window: start ({self.start}) must be before end ({self.end})" + ) + _resolve_tz(self.tz) + def evaluate(self, ctx: PredicateContext) -> bool | Defer: - tz_info = _resolve_tz(self.tz) - local = ctx.now().astimezone(tz_info) - start_minutes = self.start_hour * 60 + self.start_minute - end_minutes = self.end_hour * 60 + self.end_minute - cur_minutes = local.hour * 60 + local.minute - if start_minutes <= cur_minutes < end_minutes: + sh, sm = _parse_hhmm(self.start) + eh, em = _parse_hhmm(self.end) + local = ctx.now().astimezone(_resolve_tz(self.tz)) + cur = local.hour * 60 + local.minute + start_m = sh * 60 + sm + end_m = eh * 60 + em + if start_m <= cur < end_m: return True - if cur_minutes < start_minutes: - return Defer(seconds=(start_minutes - cur_minutes) * 60.0) - # past end — defer to tomorrow's start - next_day = local + _ONE_DAY - target = next_day.replace( - hour=self.start_hour, minute=self.start_minute, second=0, microsecond=0 - ) - return Defer(seconds=max(1.0, (target - local).total_seconds())) + if cur < start_m: + return Defer(seconds=(start_m - cur) * 60.0) + target = (local + _ONE_DAY).replace(hour=sh, minute=sm, second=0, microsecond=0) + return Defer(seconds=_seconds_until(local, target)) def in_time_window(start: str, end: str, *, tz: str | None = None) -> Predicate: - """Allow during ``[start, end)`` in the configured timezone; defer otherwise. + """Allow during ``[start, end)`` in the configured timezone.""" + return InTimeWindow(start=start, end=end, tz=tz) - ``start`` and ``end`` are ``"HH:MM"`` strings. End is exclusive. 
- """ - sh, sm = _parse_hhmm(start) - eh, em = _parse_hhmm(end) - if (sh, sm) >= (eh, em): - raise ValueError(f"start ({start}) must be before end ({end})") - return _InTimeWindow(start_hour=sh, start_minute=sm, end_hour=eh, end_minute=em, tz=tz) +# -- after / before ------------------------------------------------------ -def _parse_hhmm(value: str) -> tuple[int, int]: - parts = value.split(":") - if len(parts) != 2: - raise ValueError(f"expected 'HH:MM', got {value!r}") + +def _iso(dt: datetime) -> str: + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + return dt.isoformat() + + +def _parse_iso(value: str) -> datetime: try: - h, m = int(parts[0]), int(parts[1]) + dt = datetime.fromisoformat(value) except ValueError as exc: - raise ValueError(f"expected 'HH:MM', got {value!r}") from exc - if not 0 <= h <= 23 or not 0 <= m <= 59: - raise ValueError(f"hour/minute out of range: {value!r}") - return h, m + raise PredicateValidationError(f"invalid ISO datetime: {value!r}") from exc + if dt.tzinfo is None: + dt = dt.replace(tzinfo=timezone.utc) + return dt @dataclass(frozen=True) -class _After(Predicate): - target: datetime +class After(Predicate): + """Allow only when wall-clock time is at or past ``target``.""" + + OP: ClassVar[str | None] = "after" + + target: str = "" + + def __post_init__(self) -> None: + if not self.target: + raise PredicateValidationError("after: target must be non-empty") + _parse_iso(self.target) def evaluate(self, ctx: PredicateContext) -> bool | Defer: + target = _parse_iso(self.target) now = ctx.now() - if now >= self.target: + if now >= target: return True - return Defer(seconds=(self.target - now).total_seconds()) + return Defer(seconds=max(1.0, (target - now).total_seconds())) -def after(target: datetime) -> Predicate: - """Allow only when wall-clock time is at or past ``target``.""" - if target.tzinfo is None: - target = target.replace(tzinfo=timezone.utc) - return _After(target=target) +def after(target: datetime | str) -> 
Predicate: + """Allow only at or after ``target`` (datetime or ISO string).""" + iso = _iso(target) if isinstance(target, datetime) else target + return After(target=iso) @dataclass(frozen=True) -class _Before(Predicate): - target: datetime +class Before(Predicate): + """Allow only when wall-clock time is strictly before ``target``.""" + + OP: ClassVar[str | None] = "before" + + target: str = "" + + def __post_init__(self) -> None: + if not self.target: + raise PredicateValidationError("before: target must be non-empty") + _parse_iso(self.target) def evaluate(self, ctx: PredicateContext) -> bool: - return ctx.now() < self.target + return ctx.now() < _parse_iso(self.target) -def before(target: datetime) -> Predicate: - """Allow only when wall-clock time is strictly before ``target``.""" - if target.tzinfo is None: - target = target.replace(tzinfo=timezone.utc) - return _Before(target=target) +def before(target: datetime | str) -> Predicate: + """Allow only strictly before ``target`` (datetime or ISO string).""" + iso = _iso(target) if isinstance(target, datetime) else target + return Before(target=iso) + + +# -- in_timezone --------------------------------------------------------- @dataclass(frozen=True) -class _InTimezone(Predicate): - """Validates that the recipe-configured timezone resolves on this host. +class InTimezone(Predicate): + """No-op predicate that validates ``tz`` resolves on this host.""" - Returns ``True`` unconditionally — useful as a composable building - block when paired with other time predicates. The cheaper alternative - is to call :func:`_resolve_tz` at recipe-construction time, which is - what this does. - """ + OP: ClassVar[str | None] = "in_timezone" - tz: str + tz: str = "UTC" + + def __post_init__(self) -> None: + _resolve_tz(self.tz) def evaluate(self, ctx: PredicateContext) -> bool: return True def in_timezone(tz: str) -> Predicate: - """No-op predicate that validates ``tz`` resolves on this host. 
- - Mostly useful as a guard at decoration time: passing an unknown tz - here raises immediately rather than at first execution. - """ - _resolve_tz(tz) - return _InTimezone(tz=tz) + """Validate ``tz`` is installed; always allows at runtime.""" + return InTimezone(tz=tz) diff --git a/py_src/taskito/predicates/registry.py b/py_src/taskito/predicates/registry.py new file mode 100644 index 0000000..b73f976 --- /dev/null +++ b/py_src/taskito/predicates/registry.py @@ -0,0 +1,75 @@ +"""Registry of named predicate ops. + +Every node in the predicate AST has a stable string identifier (the +``OP`` class var). The registry maps ``OP`` → class, enabling JSON +deserialization, string parsing, and dashboard lookup. Built-in ops +self-register via ``__init_subclass__`` on :class:`~taskito.predicates.core.Predicate`. + +User-defined predicates register through :func:`register_predicate` (or +the ``Queue.register_predicate`` decorator that calls into it). +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from taskito.predicates.core import Predicate + + +class PredicateValidationError(ValueError): + """Raised when a predicate description (JSON / string) cannot be resolved. + + Common causes: unknown ``op`` name, missing required field, type + mismatch in declared field, or schema drift between writer and + reader. 
+ """ + + +class PredicateRegistry: + """Closed set of named predicate op classes.""" + + __slots__ = ("_ops",) + + def __init__(self) -> None: + self._ops: dict[str, type[Predicate]] = {} + + def register(self, op: str, cls: type[Predicate], *, replace: bool = False) -> None: + if not op: + raise PredicateValidationError("op name must be non-empty") + if not replace and op in self._ops and self._ops[op] is not cls: + existing = self._ops[op] + raise PredicateValidationError( + f"op {op!r} already registered to " + f"{existing.__module__}.{existing.__qualname__}; " + f"refusing to overwrite without replace=True" + ) + self._ops[op] = cls + + def lookup(self, op: str) -> type[Predicate]: + try: + return self._ops[op] + except KeyError: + known = ", ".join(sorted(self._ops)) or "" + raise PredicateValidationError( + f"unknown predicate op: {op!r} (known ops: {known})" + ) from None + + def names(self) -> list[str]: + return sorted(self._ops) + + def __contains__(self, op: str) -> bool: + return op in self._ops + + +_DEFAULT_REGISTRY = PredicateRegistry() + + +def default_registry() -> PredicateRegistry: + """Return the process-wide default registry used by built-in ops.""" + return _DEFAULT_REGISTRY + + +def register_predicate(op: str, cls: type[Predicate], *, replace: bool = False) -> None: + """Register ``cls`` under ``op`` in the default registry.""" + _DEFAULT_REGISTRY.register(op, cls, replace=replace) diff --git a/tests/python/test_predicates_core.py b/tests/python/test_predicates_core.py index 9056f2a..93a0234 100644 --- a/tests/python/test_predicates_core.py +++ b/tests/python/test_predicates_core.py @@ -129,7 +129,7 @@ def test_repr_uses_operator_syntax() -> None: text = repr(expr) assert "&" in text assert "|" in text - assert "~" in text + assert "!" in text # DSL uses ! 
for negation # -- Fail-closed ----------------------------------------------------------- diff --git a/tests/python/test_predicates_dsl.py b/tests/python/test_predicates_dsl.py new file mode 100644 index 0000000..02ba7ee --- /dev/null +++ b/tests/python/test_predicates_dsl.py @@ -0,0 +1,146 @@ +"""JSON serialization and string-parser round-trip tests for the predicate DSL.""" + +from __future__ import annotations + +import pytest + +from taskito.predicates import ( + Predicate, + PredicateValidationError, + after, + before, + env_var_truthy, + feature_flag, + format_predicate, + in_time_window, + is_business_hours, + is_weekend, + parse, + payload_matches, + queue_paused, +) + +# Every builtin recipe with sample args. ``format(parse(x)) == x`` is +# the contract these tests defend. +_RECIPES: list[Predicate] = [ + is_business_hours(), + is_business_hours(start_hour=10, end_hour=18, weekdays_only=False), + is_weekend(), + in_time_window("09:00", "17:00"), + after("2026-05-11T09:00:00+00:00"), + before("2026-12-31T00:00:00+00:00"), + queue_paused(), + queue_paused(queue="bulk"), + payload_matches("kwargs.tenant", "acme"), + payload_matches("args.0", 42), + env_var_truthy("MY_FLAG"), + feature_flag("billing"), +] + + +@pytest.mark.parametrize("p", _RECIPES, ids=lambda p: p.OP) +def test_recipe_json_round_trip(p: Predicate) -> None: + blob = p.to_dict() + restored = Predicate.from_dict(blob) + assert restored.to_dict() == blob + + +@pytest.mark.parametrize("p", _RECIPES, ids=lambda p: p.OP) +def test_recipe_string_round_trip(p: Predicate) -> None: + s = format_predicate(p) + restored = parse(s) + assert format_predicate(restored) == s + assert restored.to_dict() == p.to_dict() + + +def test_composition_round_trips_through_json() -> None: + p = is_business_hours() & ~queue_paused() | payload_matches( + path="kwargs.tenant", expected="acme" + ) + snap = p.to_dict() + assert snap["op"] == "or" + restored = Predicate.from_dict(snap) + assert restored.to_dict() == snap + + 
+def test_composition_round_trips_through_string() -> None: + p = is_business_hours() & ~queue_paused() | payload_matches( + path="kwargs.tenant", expected="acme" + ) + s = format_predicate(p) + assert "&" in s and "|" in s and "!" in s + restored = parse(s) + assert restored.to_dict() == p.to_dict() + + +def test_parse_supports_keyword_aliases() -> None: + # "and" / "or" / "not" tokens should be equivalent to & / | / ! + p1 = parse("is_business_hours() and not queue_paused()") + p2 = parse("is_business_hours() & !queue_paused()") + assert p1.to_dict() == p2.to_dict() + + +def test_parse_supports_parenthesised_groups() -> None: + p = parse( + '(is_business_hours() | payload_matches(path="kwargs.tenant", expected="acme")) ' + "& !queue_paused()" + ) + snap = p.to_dict() + assert snap["op"] == "and" + assert snap["args"][0]["op"] == "or" + + +def test_parse_unknown_op_raises() -> None: + with pytest.raises(PredicateValidationError): + parse("not_a_real_op()") + + +def test_parse_malformed_string_raises() -> None: + with pytest.raises(PredicateValidationError): + parse("queue_paused(") + + +def test_parse_rejects_unknown_kwarg() -> None: + with pytest.raises(PredicateValidationError): + parse('queue_paused(unknown_field="x")') + + +def test_from_dict_unknown_op_raises() -> None: + with pytest.raises(PredicateValidationError): + Predicate.from_dict({"op": "no_such_op"}) + + +def test_from_dict_missing_op_raises() -> None: + with pytest.raises(PredicateValidationError): + Predicate.from_dict({"no": "op"}) + + +def test_from_dict_non_dict_raises() -> None: + with pytest.raises(PredicateValidationError): + Predicate.from_dict("not a dict") # type: ignore[arg-type] + + +def test_callable_predicate_cannot_serialize() -> None: + from taskito.predicates import coerce_predicate + + p = coerce_predicate(lambda ctx: True) + assert p is not None + with pytest.raises(PredicateValidationError): + p.to_dict() + + +def test_literal_types_round_trip_through_string() -> None: + # int, 
float, bool, str, null, list — all parseable. + inputs = [ + 'payload_matches(path="x", expected=1)', + 'payload_matches(path="x", expected=1.5)', + 'payload_matches(path="x", expected=true)', + 'payload_matches(path="x", expected=null)', + 'payload_matches(path="x", expected="hi")', + 'payload_matches(path="x", expected=[1, 2, 3])', + ] + for s in inputs: + node = parse(s) + # Round-trip JSON keeps the value as-is. + restored = Predicate.from_dict(node.to_dict()) + assert restored.to_dict() == node.to_dict() diff --git a/tests/python/test_predicates_persistence.py b/tests/python/test_predicates_persistence.py new file mode 100644 index 0000000..8a4d47d --- /dev/null +++ b/tests/python/test_predicates_persistence.py @@ -0,0 +1,108 @@ +"""Queue.list_predicates() and Queue.register_predicate() tests.""" + +from __future__ import annotations + +from typing import Any + +import pytest + +from taskito.app import Queue +from taskito.predicates import ( + Predicate, + PredicateContext, + PredicateValidationError, + is_business_hours, + payload_matches, + queue_paused, +) + + +def test_list_predicates_returns_serialized_dicts(queue: Queue) -> None: + @queue.task(predicate=is_business_hours() & ~queue_paused()) + def gated() -> int: + return 1 + + @queue.task() + def ungated() -> int: + return 1 + + snap = queue.list_predicates() + assert gated.name in snap + assert ungated.name not in snap + + blob = snap[gated.name] + assert blob is not None + assert blob["op"] == "and" + + # Round-trip: from_dict produces an equivalent predicate. 
+ restored = Predicate.from_dict(blob) + assert restored.to_dict() == blob + + +def test_predicate_for_returns_none_for_unregistered_task(queue: Queue) -> None: + @queue.task() + def t() -> int: + return 1 + + assert queue.predicate_for(t.name) is None + + +def test_list_predicates_includes_none_for_bare_callable(queue: Queue) -> None: + @queue.task(predicate=lambda ctx: True) + def t() -> int: + return 1 + + assert queue.list_predicates()[t.name] is None + + +def test_register_predicate_decorator(queue: Queue) -> None: + @queue.register_predicate("min_priority") + class MinPriority(Predicate): + def __init__(self, threshold: int = 0) -> None: + self.threshold = threshold + + def evaluate(self, ctx: PredicateContext) -> bool: + return ctx.priority >= self.threshold + + def to_dict(self) -> dict[str, Any]: + return {"op": "min_priority", "threshold": self.threshold} + + @classmethod + def _from_kwargs(cls, kwargs: dict[str, Any]) -> Predicate: + return cls(**kwargs) + + # Class is now in the registry; from_dict resolves it. 
+ p = Predicate.from_dict({"op": "min_priority", "threshold": 5}) + assert isinstance(p, MinPriority) + assert p.threshold == 5 + + +def test_register_predicate_rejects_non_predicate(queue: Queue) -> None: + with pytest.raises(PredicateValidationError): + + @queue.register_predicate("bad") + class Bad: + pass + + +def test_register_predicate_rejects_op_mismatch(queue: Queue) -> None: + with pytest.raises(PredicateValidationError): + + @queue.register_predicate("decorator_name") + class Mismatched(Predicate): + OP = "different_name" + + def evaluate(self, ctx: PredicateContext) -> bool: + return True + + +def test_payload_matches_registered_serializes(queue: Queue) -> None: + @queue.task(predicate=payload_matches("kwargs.tenant", "acme")) + def t(tenant: str = "") -> int: + return 1 + + blob = queue.predicate_for(t.name) + assert blob is not None + assert blob["op"] == "payload_matches" + assert blob["path"] == "kwargs.tenant" + assert blob["expected"] == "acme" diff --git a/tests/python/test_predicates_recipes.py b/tests/python/test_predicates_recipes.py index 2bfdf9b..c42aef2 100644 --- a/tests/python/test_predicates_recipes.py +++ b/tests/python/test_predicates_recipes.py @@ -1,4 +1,4 @@ -"""Built-in predicate recipe tests.""" +"""Built-in predicate recipe tests (v2).""" from __future__ import annotations @@ -10,12 +10,11 @@ from taskito.predicates import ( Defer, + Predicate, PredicateContext, + PredicateValidationError, after, before, - by_priority_at_least, - by_queue, - by_task, env_var_truthy, feature_flag, in_time_window, @@ -24,10 +23,8 @@ is_weekend, payload_matches, queue_paused, - queue_size_under, - retry_count_under, + register_feature_flag_provider, ) -from taskito.predicates.providers import FeatureFlagProvider def _ctx(**overrides: object) -> PredicateContext: @@ -66,12 +63,12 @@ def test_in_time_window_parses_hh_mm() -> None: def test_in_time_window_rejects_inverted_range() -> None: - with pytest.raises(ValueError): + with 
pytest.raises(PredicateValidationError): in_time_window("17:00", "09:00") def test_in_time_window_rejects_bad_format() -> None: - with pytest.raises(ValueError): + with pytest.raises(PredicateValidationError): in_time_window("nine", "five") @@ -87,7 +84,6 @@ def test_in_time_window_returns_defer_when_outside() -> None: def test_is_business_hours_returns_defer_on_weekend() -> None: pred = is_business_hours(tz=None) - # 2026-05-09 is a Saturday sat = datetime(2026, 5, 9, 12, 0, tzinfo=timezone.utc) ctx = _ctx() with mock.patch.object(ctx, "now", return_value=sat): @@ -98,7 +94,6 @@ def test_is_business_hours_returns_defer_on_weekend() -> None: def test_is_business_hours_allows_during_window() -> None: pred = is_business_hours(tz=None) - # 2026-05-11 is a Monday weekday_noon = datetime(2026, 5, 11, 12, 0, tzinfo=timezone.utc) ctx = _ctx() with mock.patch.object(ctx, "now", return_value=weekday_noon): @@ -106,7 +101,7 @@ def test_is_business_hours_allows_during_window() -> None: def test_is_business_hours_rejects_invalid_window() -> None: - with pytest.raises(ValueError): + with pytest.raises(PredicateValidationError): is_business_hours(start_hour=20, end_hour=10) @@ -123,44 +118,11 @@ def test_is_weekend_uses_utc_by_default() -> None: def test_in_timezone_validates_at_construction() -> None: in_timezone("UTC") - with pytest.raises(ValueError): + with pytest.raises(PredicateValidationError): in_timezone("Mars/Olympus_Mons") -# -- Attribute recipes ------------------------------------------------------ - - -def test_by_queue_matches_name() -> None: - assert by_queue("default").evaluate(_ctx()) is True - assert by_queue("other").evaluate(_ctx()) is False - - -def test_by_queue_rejects_empty() -> None: - with pytest.raises(ValueError): - by_queue("") - - -def test_by_task_matches_name() -> None: - assert by_task("t").evaluate(_ctx()) is True - assert by_task("other").evaluate(_ctx()) is False - - -def test_by_priority_at_least() -> None: - pred = by_priority_at_least(5) - 
assert pred.evaluate(_ctx(priority=10)) is True - assert pred.evaluate(_ctx(priority=5)) is True - assert pred.evaluate(_ctx(priority=4)) is False - - -def test_retry_count_under_validates() -> None: - with pytest.raises(ValueError): - retry_count_under(-1) - - -def test_retry_count_under_compares_strict() -> None: - pred = retry_count_under(3) - assert pred.evaluate(_ctx(retry_count=2)) is True - assert pred.evaluate(_ctx(retry_count=3)) is False +# -- Payload recipe --------------------------------------------------------- def test_payload_matches_kwargs() -> None: @@ -183,20 +145,11 @@ def test_payload_matches_nested_dict() -> None: def test_payload_matches_rejects_empty_path() -> None: - with pytest.raises(ValueError): + with pytest.raises(PredicateValidationError): payload_matches("", "x") -# -- System-state recipes --------------------------------------------------- - - -def test_queue_size_under_uses_context_helper() -> None: - pred = queue_size_under(100) - ctx = _ctx() - with mock.patch.object(ctx, "queue_size", return_value=50): - assert pred.evaluate(ctx) is True - with mock.patch.object(ctx, "queue_size", return_value=100): - assert pred.evaluate(ctx) is False +# -- Defensive system recipe ------------------------------------------------ def test_queue_paused_uses_context_helper() -> None: @@ -208,42 +161,6 @@ def test_queue_paused_uses_context_helper() -> None: assert pred.evaluate(ctx) is False -def test_queue_size_under_rejects_zero() -> None: - with pytest.raises(ValueError): - queue_size_under(0) - - -def test_error_rate_under_allows_with_no_jobs() -> None: - pred = pred = __import__("taskito.predicates", fromlist=["error_rate_under"]).error_rate_under( - 0.5 - ) - ctx = _ctx() - with mock.patch.object(ctx, "stats", return_value={}): - assert pred.evaluate(ctx) is True - - -def test_error_rate_under_compares_ratio() -> None: - from taskito.predicates import error_rate_under - - pred = error_rate_under(0.2) - ctx = _ctx() - with 
mock.patch.object(ctx, "stats", return_value={"completed": 90, "failed": 5, "dead": 5}): - # rate = 10/100 = 0.1 < 0.2 - assert pred.evaluate(ctx) is True - with mock.patch.object(ctx, "stats", return_value={"completed": 70, "failed": 20, "dead": 10}): - # rate = 30/100 = 0.3 > 0.2 - assert pred.evaluate(ctx) is False - - -def test_error_rate_validates_range() -> None: - from taskito.predicates import error_rate_under - - with pytest.raises(ValueError): - error_rate_under(0.0) - with pytest.raises(ValueError): - error_rate_under(1.5) - - # -- Config recipes --------------------------------------------------------- @@ -258,7 +175,7 @@ def test_env_var_truthy_reads_env() -> None: def test_env_var_truthy_rejects_empty_name() -> None: - with pytest.raises(ValueError): + with pytest.raises(PredicateValidationError): env_var_truthy("") @@ -270,21 +187,27 @@ def test_feature_flag_default_provider_reads_ff_prefix() -> None: assert pred.evaluate(_ctx()) is False -def test_feature_flag_custom_provider() -> None: +def test_feature_flag_named_provider_roundtrips() -> None: class _Stub: - def __init__(self) -> None: - self.calls: list[tuple[str, str]] = [] - def is_enabled(self, name: str, ctx: PredicateContext) -> bool: - self.calls.append((name, ctx.task_name)) return name == "yes" - stub: FeatureFlagProvider = _Stub() - assert feature_flag("yes", provider=stub).evaluate(_ctx()) is True - assert feature_flag("no", provider=stub).evaluate(_ctx()) is False - assert stub.calls == [("yes", "t"), ("no", "t")] # type: ignore[attr-defined] + register_feature_flag_provider("stub", _Stub()) + pred = feature_flag("yes", provider="stub") + assert pred.evaluate(_ctx()) is True + + # JSON round-trip preserves the provider name. 
+ snap = pred.to_dict() + assert snap == {"op": "feature_flag", "flag": "yes", "provider": "stub"} + restored = Predicate.from_dict(snap) + assert restored.evaluate(_ctx()) is True + + +def test_feature_flag_unknown_provider_raises_on_resolve() -> None: + with pytest.raises(PredicateValidationError): + feature_flag("x", provider="never-registered") def test_feature_flag_rejects_empty_name() -> None: - with pytest.raises(ValueError): + with pytest.raises(PredicateValidationError): feature_flag("") From a43bde53c2d95fdbaa924d0394a954c36dd04a8a Mon Sep 17 00:00:00 2001 From: Pratyush Sharma <56130065+pratyush618@users.noreply.github.com> Date: Mon, 11 May 2026 10:38:22 +0530 Subject: [PATCH 09/11] docs(predicates): rewrite guide and example around AST + JSON + string DSL --- docs/content/docs/guides/core/predicates.mdx | 432 +++++++++++------- .../more/examples/predicate-gated-jobs.mdx | 237 +++++----- 2 files changed, 394 insertions(+), 275 deletions(-) diff --git a/docs/content/docs/guides/core/predicates.mdx b/docs/content/docs/guides/core/predicates.mdx index a816332..136be63 100644 --- a/docs/content/docs/guides/core/predicates.mdx +++ b/docs/content/docs/guides/core/predicates.mdx @@ -1,312 +1,418 @@ --- title: Predicates -description: "Compose Predicate objects to gate when a task is enqueued or dispatched, with built-in recipes for time windows, queue health, feature flags, and more." +description: "A composable, serializable DSL for gating task execution. Authors compose predicates in Python; runtime evaluates them at enqueue and worker dispatch; persistence and dashboards see the same AST as JSON or string." --- -A **predicate** is a composable, fail-closed gate attached to a task. It decides — at enqueue time and again at worker dispatch — whether the job runs, is deferred, or is cancelled. The same predicates also filter which middlewares see a given job. 
+A **predicate** is a serializable AST node that decides — at enqueue +time and again at worker dispatch — whether a job runs, is deferred, or +is cancelled. Predicates compose with `&` / `|` / `~`, round-trip +through JSON and a small string DSL, and are inspectable from the +dashboard. ```python from taskito import Queue -from taskito.predicates import ( - is_business_hours, - queue_paused, - by_priority_at_least, -) +from taskito.predicates import is_business_hours, queue_paused, feature_flag queue = Queue() @queue.task( predicate=is_business_hours(tz="US/Pacific") & ~queue_paused() - | by_priority_at_least(8), + & feature_flag("new_billing"), on_false="defer", ) -def send_report(): ... +def send_invoice(tenant_id: str): ... ``` +## What predicates *do not* do + +Predicates are not the right place to rebuild taskito's atomic +enforcement primitives. Use the dedicated knobs: + +| Goal | Use this — *not* a predicate | +|---|---| +| Cap concurrent executions of a task | `@queue.task(max_concurrent=N)` | +| Cap concurrent executions of a queue | `Queue.set_queue_concurrency(name, N)` | +| Rate-limit a task or queue | `@queue.task(rate_limit="100/m")` | +| Trip an outage breaker | `@queue.task(circuit_breaker={...})` | +| Drain / pause a queue | `Queue.pause(name)` / `resume(name)` | +| Cap retry count or filter retries | `max_retries=`, `retry_on=`, `dont_retry_on=` | +| Route to a queue | `@queue.task(queue="emails")` | + +Each is enforced atomically in the Rust scheduler. A Python predicate +that races against them is weaker and harder to reason about. +Predicates are for **gates the Rust scheduler doesn't already provide**: +time windows, payload-driven branching, feature flags, environment +config, and custom Python logic. + ## Outcomes -A predicate's `evaluate` returns one of four values: +`evaluate()` returns one of four values: | Outcome | Meaning | |---|---| | `True` | Allow the job to proceed. | -| `False` | Deny. 
The task's `on_false` action decides what happens — `"defer"` (default) or `"cancel"`. | -| `Defer(seconds=N)` | Skip now, retry after `N` seconds. At enqueue time the delay is added; at dispatch time the job is re-enqueued with the delay. | -| `Cancel(reason="...")` | Permanently skip. At enqueue time raises `PredicateRejectedError`; at dispatch time marks the job cancelled. | +| `False` | Deny. The task's `on_false` decides: `"defer"` (default) or `"cancel"`. | +| `Defer(seconds=N)` | Skip now, retry after `N` seconds. | +| `Cancel(reason="…")` | Permanently skip. Raises `PredicateRejectedError` at enqueue; cancels the job at dispatch. | -A predicate that raises an exception is treated as `False` (fail-closed). The error is logged and counted in `PredicateMetrics`. +A predicate that raises is treated as `False` (fail-closed). Errors are +logged and counted on `PredicateMetrics`. ## Composition -`Predicate` overloads `&`, `|`, and `~` with short-circuit semantics: - ```python allow = is_business_hours() & ~queue_paused() -allow |= by_priority_at_least(8) # urgent jobs bypass both gates +allow_or_urgent = allow | feature_flag("urgent_bypass") ``` * `A & B` — both must allow; short-circuits if `A` denies. -* `A | B` — either can allow; short-circuits if `A` allows. If both deny, the most informative outcome wins (`Cancel` > `Defer` > `False`). -* `~A` — inverts `True`/`False`. `Defer` and `Cancel` pass through unchanged. +* `A | B` — either may allow; if both deny, the most informative + outcome wins (`Cancel` > `Defer` > `False`). +* `~A` — inverts `True`/`False`. `Defer` and `Cancel` pass through + unchanged (terminal outcomes, not booleans). ## `@queue.task` options | Parameter | Type | Default | Description | |---|---|---|---| -| `predicate` | `Predicate \| Callable \| None` | `None` | The predicate to evaluate. Plain callables receiving a `PredicateContext` are accepted. | +| `predicate` | `Predicate \| Callable \| None` | `None` | The gate. 
Plain callables receiving a `PredicateContext` are accepted but cannot be serialized. | | `on_false` | `"defer" \| "cancel"` | `"defer"` | What to do when the predicate returns plain `False`. | | `predicate_extras` | `dict \| None` | `None` | Static dict forwarded to the predicate via `ctx.extras`. | -| `default_defer_seconds` | `float` | `60.0` | Delay applied when `on_false="defer"` and the predicate returns plain `False`. | +| `default_defer_seconds` | `float` | `60.0` | Default delay when `on_false="defer"` and the predicate returns plain `False`. | + +## The DSL surface + +Every predicate is a node in a closed, registered AST. Three equivalent +authoring paths produce the same tree: + +### Python operators + +```python +p = is_business_hours(tz="US/Pacific") & ~queue_paused() | feature_flag("new") +``` + +### JSON + +```python +from taskito.predicates import Predicate + +p = Predicate.from_dict({ + "op": "or", + "args": [ + {"op": "and", "args": [ + {"op": "is_business_hours", "tz": "US/Pacific"}, + {"op": "not", "arg": {"op": "queue_paused"}}, + ]}, + {"op": "feature_flag", "flag": "new", "provider": "env"}, + ], +}) +``` + +### String + +```python +from taskito.predicates import parse + +p = parse( + 'is_business_hours(tz="US/Pacific") & !queue_paused() ' + '| feature_flag(flag="new", provider="env")' +) +``` + +The string DSL supports `&` / `|` / `!` and the long forms `and` / `or` +/ `not`. Literals: strings, ints, floats, `true` / `false` / `null`, +and `[lists, of, literals]`. Round-trip is stable: + +```python +from taskito.predicates import format_predicate, parse + +p1 = is_business_hours() & ~queue_paused() +s = format_predicate(p1) # 'is_business_hours() & !queue_paused()' +p2 = parse(s) +assert p1.to_dict() == p2.to_dict() +``` ## Built-in recipes -All recipes are factory functions returning a `Predicate`. Import from `taskito.predicates`. +All recipes are registered AST ops. Import from `taskito.predicates`. 
+ +### Time + +| Recipe | Description | +|---|---| +| `is_business_hours(start=9, end=17, *, tz=None, weekdays_only=True)` | Allow during business hours; defer to the next opening otherwise. | +| `is_weekend(*, tz=None)` | True on Sat/Sun. | +| `in_time_window("09:00", "17:00", *, tz=None)` | Allow within `[start, end)`; defer otherwise. | +| `after(target)` | Defer until `target` (`datetime` or ISO string). | +| `before(target)` | Allow only strictly before `target`. | +| `in_timezone(tz)` | No-op at runtime; validates the tz string at construction time. | -### Time-based +### Payload -* `is_business_hours(start=9, end=17, *, tz=None, weekdays_only=True)` — defers to the next window when outside. -* `is_weekend(*, tz=None)` -* `in_time_window("09:00", "17:00", *, tz=None)` — `HH:MM` strings; end exclusive; defers when outside. -* `after(target_datetime)` — defers until `target`. -* `before(target_datetime)` -* `in_timezone(tz)` — no-op at runtime; validates the tz string at construction. +| Recipe | Description | +|---|---| +| `payload_matches("kwargs.tenant", "acme")` | Dotted-path lookup into `{"args": …, "kwargs": …}`; compare against `expected`. | -### Job-attribute +### Defensive -* `by_queue(name)` / `by_task(name)` -* `by_priority_at_least(n)` -* `retry_count_under(n)` -* `payload_matches("kwargs.tenant", "acme")` — dotted-path lookup into `{"args": ..., "kwargs": ...}`. +| Recipe | Description | +|---|---| +| `queue_paused(name=None)` | Read pause state. Use inverted (`~queue_paused()`) as a defensive composable, not as primary enforcement (which is `Queue.pause()` / `resume()`). | -### System-state +### External config -* `queue_size_under(limit, *, queue=None)` — reads `stats_by_queue()`. -* `queue_paused(queue=None)` -* `error_rate_under(max_rate)` — `(failed + dead) / (completed + failed + dead)` over backend-wide stats. 
+| Recipe | Description | +|---|---| +| `env_var_truthy("MY_FLAG")` | Truthy values: `1`, `true`, `t`, `yes`, `y`, `on` (case-insensitive). | +| `feature_flag("billing", *, provider=None)` | Defaults to env-var provider reading `FF_`. Pass a registered provider name to swap in LaunchDarkly / Statsig / etc. | -System-state recipes are memoised within a single evaluation, so `queue_size_under(100) & queue_size_under(200)` only reads stats once. +#### Pluggable feature-flag provider -### External-config +```python +from taskito.predicates import FeatureFlagProvider, register_feature_flag_provider -* `env_var_truthy("MY_FLAG")` — truthy values: `1`, `true`, `t`, `yes`, `y`, `on` (case-insensitive). -* `feature_flag("billing", provider=...)` — defaults to env-var provider reading `FF_`. Pass a custom `FeatureFlagProvider` to integrate LaunchDarkly/Statsig/etc. +class LaunchDarklyProvider: + def __init__(self, client) -> None: + self._c = client + + def is_enabled(self, name, ctx) -> bool: + user = {"key": ctx.kwargs.get("tenant_id", "anon")} + return bool(self._c.variation(name, user, False)) + +register_feature_flag_provider("launchdarkly", LaunchDarklyProvider(ldclient.get())) + +# Now both Python and string forms round-trip: +p = feature_flag("billing", provider="launchdarkly") +assert Predicate.from_dict(p.to_dict()).to_dict() == p.to_dict() +``` ## Where predicates run A predicate registered on a task is evaluated **twice**: -1. **Enqueue time** — inside `Queue.enqueue()` / `enqueue_many()`. `Defer` adjusts the caller's `delay=` value. `Cancel` raises `PredicateRejectedError` immediately. The job is never saved when cancelled. -2. **Worker-dispatch time** — just before the task body executes (sync and async paths). `Cancel` raises `TaskCancelledError`, which Rust records as `cancelled`. `Defer` re-enqueues a fresh job with the same payload + delay, and the current job is marked cancelled. +1. **Enqueue time** — inside `Queue.enqueue()` / `enqueue_many()`. 
+ `Defer` adjusts the caller's `delay=`. `Cancel` (or `False` with + `on_false="cancel"`) raises `PredicateRejectedError` and the job is + never saved. +2. **Worker-dispatch time** — just before the task body runs. + `Cancel` raises `TaskCancelledError` (Rust marks the job cancelled). + `Defer` re-enqueues a fresh job with the same payload + delay; the + current job is cancelled. + +`PredicateContext.job_id` is `None` at enqueue time, set at dispatch +time — useful for custom predicates that need to behave differently +across phases. -The `PredicateContext` passed to `evaluate` distinguishes the two phases: `ctx.job_id is None` at enqueue time, set at dispatch time. +## Custom predicates -## Writing a custom predicate +Subclass `Predicate`, declare `OP`, and register through the queue: ```python -from taskito.predicates import Predicate, PredicateContext, Defer +from typing import Any +from taskito import Queue +from taskito.predicates import Predicate, PredicateContext + +queue = Queue() +@queue.register_predicate("tenant_quota_under") class TenantQuotaUnder(Predicate): - def __init__(self, limit: int) -> None: + OP = "tenant_quota_under" + + def __init__(self, limit: int = 0) -> None: self.limit = limit - def evaluate(self, ctx: PredicateContext) -> bool | Defer: + def evaluate(self, ctx: PredicateContext) -> bool: tenant = ctx.kwargs.get("tenant") - if tenant is None: - return False - used = my_quota_service.usage(tenant) - if used < self.limit: - return True - return Defer(seconds=300.0) # check again in 5 minutes + return tenant is not None and quota_service.used(tenant) < self.limit + + def to_dict(self) -> dict[str, Any]: + return {"op": "tenant_quota_under", "limit": self.limit} + + @classmethod + def _from_kwargs(cls, kwargs: dict[str, Any]) -> "Predicate": + return cls(**kwargs) ``` -Predicates may be **async**: return a coroutine and it will be awaited transparently. 
+Once registered, the new op is usable in JSON, the string DSL, and +operator composition exactly like a built-in recipe. + +Predicates may be **async** — return a coroutine from `evaluate` and it +will be awaited transparently. ## Middleware filtering -`TaskMiddleware` accepts `predicate=` to gate which jobs the middleware applies to. The legacy contrib `task_filter=Callable[[str], bool]` kwarg is still supported and translates internally to a predicate. +`TaskMiddleware` accepts `predicate=` to gate which jobs the middleware +applies to. The legacy contrib `task_filter=Callable[[str], bool]` +kwarg still works and is translated internally to a predicate. ```python from taskito.contrib.sentry import SentryMiddleware -from taskito.predicates import by_queue +from taskito.predicates import payload_matches queue = Queue(middleware=[ - SentryMiddleware(predicate=by_queue("critical")), + SentryMiddleware(predicate=payload_matches("kwargs.env", "prod")), ]) ``` -When the middleware predicate denies, **only that middleware** is skipped for the current job — job dispatch is unaffected. +When the middleware predicate denies, **only that middleware** is +skipped for the current job — job dispatch is unaffected. -## Metrics and events +## Persistence & inspection -Each `Queue` carries a `PredicateMetrics` instance at `queue._predicate_metrics`: +Every registered predicate is serialized once at decoration time. The +JSON is reachable through: ```python -queue._predicate_metrics.snapshot() -# {"allowed": 412, "denied": 3, "deferred": 8, "cancelled": 1, "errors": 0} +queue.list_predicates() +# {"app.send_invoice": {"op": "and", "args": [...]}} + +queue.predicate_for("app.send_invoice") +# {"op": "and", "args": [...]} ``` -Three events are emitted on the queue's event bus: +Bare callables (`predicate=lambda ctx: …`) appear in `list_predicates` +with `None` as their value, since callables have no stable schema. 
-* `EventType.PREDICATE_DEFERRED` — payload includes `task_name`, `defer_seconds`, `phase` (`"enqueue"` or `"dispatch"`). -* `EventType.PREDICATE_CANCELLED` — worker-dispatch cancellations. -* `EventType.PREDICATE_REJECTED` — enqueue-time rejections that raised `PredicateRejectedError`. +The dashboard reads `list_predicates()` to display "gated by: …" beside +each task. To round-trip a predicate from the inspection output: -## Examples +```python +from taskito.predicates import Predicate -### Business-hours-only report +blob = queue.predicate_for("app.send_invoice") +restored = Predicate.from_dict(blob) +# restored.format() — stable string form +# restored.evaluate(ctx) — re-evaluate against a context +``` -Run a report task only between 09:00–17:00 Pacific on weekdays; defer to the next window otherwise. Urgent jobs (priority ≥ 8) bypass the window. +## Metrics & events ```python -from taskito import Queue -from taskito.predicates import is_business_hours, by_priority_at_least - -queue = Queue() - -@queue.task( - predicate=is_business_hours(tz="US/Pacific") | by_priority_at_least(8), -) -def send_daily_report(team_id: str) -> None: - ... +queue._predicate_metrics.snapshot() +# {"allowed": 412, "denied": 3, "deferred": 8, "cancelled": 1, "errors": 0} +``` -# At 02:00 PT → enqueued with delay until 09:00 PT -send_daily_report.delay("alpha") +Three events fire on the queue's event bus: -# Same task, urgent → runs immediately regardless of hour -queue.enqueue(send_daily_report.name, args=("alpha",), priority=10) -``` +* `EventType.PREDICATE_DEFERRED` — payload includes `task_name`, + `defer_seconds`, `phase` (`"enqueue"` or `"dispatch"`). +* `EventType.PREDICATE_CANCELLED` — worker-dispatch cancellations. +* `EventType.PREDICATE_REJECTED` — enqueue-time rejections that raised + `PredicateRejectedError`. -### Drop jobs when the system is under load +## Examples -Use `queue_size_under` and `error_rate_under` to shed load. 
Compose them with `&` so a single backpressure breach denies. +### Business-hours-only report with urgent override ```python -from taskito.predicates import queue_size_under, error_rate_under, queue_paused +from taskito.predicates import is_business_hours, payload_matches -healthy = ( - queue_size_under(1_000, queue="default") - & error_rate_under(0.1) - & ~queue_paused() +@queue.task( + predicate=is_business_hours(tz="US/Pacific") + | payload_matches("kwargs.urgent", True), ) +def send_daily_report(team_id: str, urgent: bool = False): ... -@queue.task(predicate=healthy, on_false="defer", default_defer_seconds=30.0) -def import_csv(path: str) -> int: - ... +send_daily_report.delay("alpha") # defers to 09:00 PT +send_daily_report.delay("alpha", urgent=True) # runs now ``` -### Feature-flag rollout - -Gate a new code path behind a flag. The default provider reads `FF_` from the environment; swap in LaunchDarkly with a 5-line adapter. +### Feature-flag rollout with a custom provider ```python -import os -from taskito.predicates import feature_flag, FeatureFlagProvider, PredicateContext - - -class LaunchDarklyProvider: - """Adapter for the LaunchDarkly SDK.""" +from taskito.predicates import feature_flag, register_feature_flag_provider - def __init__(self, client: object, user_attr: str = "tenant_id") -> None: - self._client = client - self._user_attr = user_attr +register_feature_flag_provider("launchdarkly", my_ld_provider) - def is_enabled(self, name: str, ctx: PredicateContext) -> bool: - user = {"key": ctx.kwargs.get(self._user_attr, "anon")} - return bool(self._client.variation(name, user, False)) - - -# In production: provider = LaunchDarklyProvider(ldclient.get()) -# In dev: default env-var provider, FF_NEW_BILLING=1 enables it. -@queue.task(predicate=feature_flag("new_billing")) -def charge_card(tenant_id: str, amount_cents: int) -> str: - ... 
+@queue.task(predicate=feature_flag("new_billing", provider="launchdarkly")) +def charge_card(tenant_id: str, amount_cents: int): ... ``` -### Tenant allowlist with `payload_matches` +### Tenant allowlist as a custom predicate ```python +from typing import Any from taskito.predicates import Predicate, PredicateContext +@queue.register_predicate("tenant_in") class TenantIn(Predicate): - def __init__(self, *tenants: str) -> None: - self._tenants = frozenset(tenants) + OP = "tenant_in" + + def __init__(self, *, tenants: list[str]) -> None: + self.tenants = list(tenants) def evaluate(self, ctx: PredicateContext) -> bool: - return ctx.kwargs.get("tenant_id") in self._tenants + return ctx.kwargs.get("tenant") in self.tenants + def to_dict(self) -> dict[str, Any]: + return {"op": "tenant_in", "tenants": self.tenants} -@queue.task(predicate=TenantIn("acme", "globex"), on_false="cancel") -def reindex_search(tenant_id: str) -> None: - ... + @classmethod + def _from_kwargs(cls, kwargs: dict[str, Any]) -> "Predicate": + return cls(**kwargs) -# raises PredicateRejectedError — never enqueued -queue.enqueue(reindex_search.name, kwargs={"tenant_id": "initech"}) -``` -### Async predicate with quota service +@queue.task( + predicate=TenantIn(tenants=["acme", "globex"]), + on_false="cancel", +) +def reindex_search(tenant: str): ... +``` -`evaluate` may be a coroutine — it's awaited transparently. 
+### Async predicate calling an HTTP quota service ```python import httpx -from taskito.predicates import Predicate, PredicateContext, Defer - +from taskito.predicates import Defer, Predicate, PredicateContext class QuotaAvailable(Predicate): - def __init__(self, url: str) -> None: - self._url = url + OP = "quota_available" + + def __init__(self, *, url: str) -> None: + self.url = url async def evaluate(self, ctx: PredicateContext) -> bool | Defer: - async with httpx.AsyncClient(timeout=2.0) as client: - r = await client.get(f"{self._url}/quota/{ctx.kwargs['tenant']}") - remaining = r.json()["remaining"] - if remaining > 0: + async with httpx.AsyncClient(timeout=2.0) as c: + r = await c.get(f"{self.url}/quota/{ctx.kwargs['tenant']}") + if r.json()["remaining"] > 0: return True return Defer(seconds=60.0) -@queue.task(predicate=QuotaAvailable("http://quota.internal")) -async def index_document(tenant: str, doc_id: str) -> None: - ... -``` +queue.register_predicate("quota_available")(QuotaAvailable) -### Per-middleware filtering +@queue.task(predicate=QuotaAvailable(url="http://quota.internal")) +async def index_document(tenant: str, doc_id: str): ... +``` -Route a single middleware to only one queue without touching the queue-level middleware list. +### Editing a predicate as a string ```python -from taskito.contrib.sentry import SentryMiddleware -from taskito.contrib.prometheus import PrometheusMiddleware -from taskito.predicates import by_queue +from taskito.predicates import parse -queue = Queue(middleware=[ - # Sentry only for the "critical" queue - SentryMiddleware(predicate=by_queue("critical")), - # Prometheus everywhere - PrometheusMiddleware(), -]) +@queue.task(predicate=parse( + 'is_business_hours(tz="US/Pacific") & !queue_paused()' +)) +def send_report(): ... ``` -### Listening for predicate events +The string form is the right surface for ops-style editing (e.g. 
+storing the gate in a config file or letting an operator tweak it +through a dashboard textbox without code deploys). -Wire dashboards or alerts by subscribing to the predicate event types. +### Listening for predicate events ```python from taskito.events import EventType -def on_predicate_event(event_type: EventType, payload: dict) -> None: - if event_type is EventType.PREDICATE_DEFERRED: - print(f"{payload['task_name']} deferred {payload['defer_seconds']}s " - f"at {payload['phase']}") - elif event_type is EventType.PREDICATE_CANCELLED: - print(f"{payload['task_name']} cancelled at dispatch: " - f"{payload.get('reason', '')}") +def on_event(event_type: EventType, payload: dict) -> None: + print(event_type.value, payload) for event in ( EventType.PREDICATE_DEFERRED, EventType.PREDICATE_CANCELLED, EventType.PREDICATE_REJECTED, ): - queue._event_bus.on(event, on_predicate_event) + queue._event_bus.on(event, on_event) ``` diff --git a/docs/content/docs/more/examples/predicate-gated-jobs.mdx b/docs/content/docs/more/examples/predicate-gated-jobs.mdx index 792ff50..37ac994 100644 --- a/docs/content/docs/more/examples/predicate-gated-jobs.mdx +++ b/docs/content/docs/more/examples/predicate-gated-jobs.mdx @@ -1,36 +1,42 @@ --- title: Predicate-Gated Jobs -description: "End-to-end example of using Predicate objects to gate task execution by business hours, system load, feature flags, and per-tenant quotas." +description: "End-to-end example of the predicate DSL: serializable AST, JSON + string forms, custom ops, feature flags, and dashboard-visible gates." --- -A small reporting service that uses predicates to gate when tasks may run. It demonstrates every layer of the predicate system: enqueue-time rejection, worker-dispatch defer-and-retry, composable recipes, custom predicates, async predicates, and middleware filtering. +A small service that gates tasks with predicates — composed with +Python operators, persisted as JSON, and round-tripped through the +string DSL. 
## What the service does * Sends per-tenant **daily reports** during business hours only. -* Defers heavy **reindex jobs** when the queue is backed up or the error rate is high. -* Gates a new **billing workflow** behind a feature flag. -* Routes Sentry alerts only for the `critical` queue. +* Gates a **billing** workflow behind a feature flag. +* Holds **search reindex** to an allowlist of tenants via a custom + predicate. +* Calls an **async quota service** to decide whether to enqueue more + indexing work. +* Exposes `queue.list_predicates()` so the dashboard can show every + gate. ## Project structure ``` predicate_gated/ app.py # Queue + tasks - predicates.py # Custom predicates + predicates.py # Custom predicate classes main.py # Submit jobs worker.py # Run the worker ``` ## predicates.py -A custom predicate that consults an HTTP quota service. Async — taskito awaits it transparently. - ```python -"""Custom predicates for the example service.""" +"""Custom predicates registered with the queue.""" from __future__ import annotations +from typing import Any + import httpx from taskito.predicates import Defer, Predicate, PredicateContext @@ -39,109 +45,95 @@ from taskito.predicates import Defer, Predicate, PredicateContext class TenantQuotaUnder(Predicate): """Defer until the tenant has quota remaining.""" - def __init__(self, url: str, limit: int) -> None: - self._url = url - self._limit = limit + OP = "tenant_quota_under" + + def __init__(self, *, url: str = "", limit: int = 0) -> None: + self.url = url + self.limit = limit async def evaluate(self, ctx: PredicateContext) -> bool | Defer: tenant = ctx.kwargs.get("tenant") if tenant is None: return False async with httpx.AsyncClient(timeout=2.0) as client: - r = await client.get(f"{self._url}/quota/{tenant}") - used = r.json()["used"] - if used < self._limit: + r = await client.get(f"{self.url}/quota/{tenant}") + if r.json()["used"] < self.limit: return True - # Quota busts → re-check in 5 minutes. 
return Defer(seconds=300.0) + def to_dict(self) -> dict[str, Any]: + return {"op": self.OP, "url": self.url, "limit": self.limit} + + @classmethod + def _from_kwargs(cls, kwargs: dict[str, Any]) -> "Predicate": + return cls(**kwargs) + class TenantIn(Predicate): - """Hard allowlist; rejects at enqueue time.""" + """Hard allowlist — rejects at enqueue time when on_false='cancel'.""" - def __init__(self, *tenants: str) -> None: - self._tenants = frozenset(tenants) + OP = "tenant_in" + + def __init__(self, *, tenants: list[str]) -> None: + self.tenants = list(tenants) def evaluate(self, ctx: PredicateContext) -> bool: - return ctx.kwargs.get("tenant") in self._tenants + return ctx.kwargs.get("tenant") in self.tenants + + def to_dict(self) -> dict[str, Any]: + return {"op": self.OP, "tenants": self.tenants} + + @classmethod + def _from_kwargs(cls, kwargs: dict[str, Any]) -> "Predicate": + return cls(**kwargs) ``` ## app.py -Composes built-in recipes with the custom predicates. Each task gets a different predicate shape. - ```python -"""Queue, tasks, and predicate wiring.""" +"""Queue + tasks. All gates are composable, serializable predicates.""" from taskito import Queue -from taskito.contrib.sentry import SentryMiddleware from taskito.predicates import ( - by_priority_at_least, - by_queue, - error_rate_under, feature_flag, is_business_hours, queue_paused, - queue_size_under, ) from .predicates import TenantIn, TenantQuotaUnder -queue = Queue( - db_path=".taskito/predicate_gated.db", - workers=4, - middleware=[ - # Sentry only for the "critical" queue. - SentryMiddleware(predicate=by_queue("critical")), - ], -) +queue = Queue(db_path=".taskito/predicate_gated.db", workers=4) +# Register the custom predicate ops so they show up in JSON, the +# string DSL, and queue.list_predicates(). 
+queue.register_predicate("tenant_quota_under")(TenantQuotaUnder) +queue.register_predicate("tenant_in")(TenantIn) -# ── Reports: business hours + tenant allowlist ──────────────── -@queue.task( - predicate=is_business_hours(tz="US/Pacific") & TenantIn("acme", "globex"), - on_false="defer", - default_defer_seconds=900.0, # 15 minutes -) -def send_daily_report(tenant: str) -> str: - """Defers outside Mon–Fri 09:00–17:00 Pacific. - Cancels at enqueue if tenant is not in the allowlist.""" - return f"report sent for {tenant}" - - -# ── Reindex: load-shed + bypass when urgent ─────────────────── -healthy = ( - queue_size_under(1_000, queue="bulk") - & error_rate_under(0.1) - & ~queue_paused("bulk") -) +# ── Reports: business hours + tenant allowlist ────────────────────── @queue.task( - queue="bulk", - predicate=healthy | by_priority_at_least(8), + predicate=is_business_hours(tz="US/Pacific") + & TenantIn(tenants=["acme", "globex"]), on_false="defer", - default_defer_seconds=30.0, + default_defer_seconds=900.0, ) -def reindex_search(tenant: str) -> int: - """Skipped when default queue is over 1k pending or error_rate > 10%, - unless explicitly enqueued with priority>=8.""" - return 42 # rows indexed +def send_daily_report(tenant: str) -> str: + return f"sent {tenant}" -# ── Billing: gated by FF_NEW_BILLING env var ────────────────── +# ── Billing: gated behind FF_NEW_BILLING ──────────────────────────── @queue.task( queue="critical", - predicate=feature_flag("new_billing"), + predicate=feature_flag("new_billing") & ~queue_paused(), on_false="cancel", ) -def charge_card(tenant: str, amount_cents: int) -> str: - return f"charged {amount_cents}c to {tenant}" +def charge_card(tenant: str, cents: int) -> str: + return f"charged {cents}c to {tenant}" -# ── Async + quota-bound indexer ─────────────────────────────── +# ── Indexing: async quota-bound ───────────────────────────────────── @queue.task( - predicate=TenantQuotaUnder("http://quota.internal", limit=10_000), - 
on_false="defer",
+    predicate=TenantQuotaUnder(url="http://quota.internal", limit=10_000),
 )
 async def index_document(tenant: str, doc_id: str) -> str:
     return f"indexed {doc_id}"
@@ -149,49 +141,45 @@ async def index_document(tenant: str, doc_id: str) -> str:
 
 ## main.py
 
-Submitting jobs. Demonstrates each outcome.
-
 ```python
-"""Submit some jobs and watch the predicates kick in."""
+"""Submit jobs and watch each layer of the predicate system kick in."""
 
 from taskito.exceptions import PredicateRejectedError
 
-from .app import (
-    charge_card,
-    index_document,
-    queue,
-    reindex_search,
-    send_daily_report,
-)
+from .app import charge_card, index_document, queue, send_daily_report
 
 
 def main() -> None:
-    # 1. send_daily_report: allowed for "acme"; deferred to 09:00 PT outside hours.
+    # 1. Allowed (acme is in the allowlist).
    send_daily_report.delay(tenant="acme")
 
-    # 2. send_daily_report: cancel at enqueue — tenant not in allowlist.
+    # 2. Tenant not in the allowlist: TenantIn returns False, and
+    # because this task uses on_false="defer" the call is deferred
+    # rather than rejected, so no PredicateRejectedError is raised
+    # here. Declare the task with on_false="cancel" instead to get a
+    # hard rejection at enqueue time:
    try:
        send_daily_report.delay(tenant="initech")
-    except PredicateRejectedError as exc:
-        print(f"rejected: {exc.task_name} — {exc.reason}")
-
-    # 3. reindex_search: normal path
-    reindex_search.delay(tenant="acme")
-
-    # 4. reindex_search: bypass backpressure with priority
-    queue.enqueue(reindex_search.name, kwargs={"tenant": "acme"}, priority=9)
-
-    # 5. charge_card: gated behind FF_NEW_BILLING; cancels if flag is off.
-    try:
-        charge_card.delay(tenant="acme", amount_cents=4200)
    except PredicateRejectedError:
-        print("billing flag is off")
+        print("not allowed")
 
-    # 6. async predicate + async task — both are awaited transparently
+    # 3. Bypass the billing flag by setting it in env.
+ import os + os.environ["FF_NEW_BILLING"] = "true" + charge_card.delay(tenant="acme", cents=4200) + + # 4. Async predicate + async task — both awaited transparently. index_document.delay(tenant="acme", doc_id="doc-1") - # Watch metrics live - print(queue._predicate_metrics.snapshot()) + # ── Inspection: dashboard reads the same shape ───────────────── + print(queue.list_predicates()) + # { + # "app.send_daily_report": {"op": "and", "args": [...]}, + # "app.charge_card": {"op": "and", "args": [...]}, + # "app.index_document": {"op": "tenant_quota_under", + # "url": "http://quota.internal", + # "limit": 10000}, + # } if __name__ == "__main__": @@ -207,38 +195,63 @@ if __name__ == "__main__": queue.run_worker() ``` -## Subscribing to events +## Round-tripping a predicate through the DSL surfaces -Plug predicate outcomes into your alerting or dashboards. +Any predicate is serializable in three equivalent forms: + +```python +from taskito.predicates import Predicate, format_predicate, parse + +p = queue._task_predicates["app.send_daily_report"] + +# 1. JSON for storage / API / dashboard +blob = p.to_dict() + +# 2. Stable, parseable string for ops-side editing +text = format_predicate(p) +# 'is_business_hours(tz="US/Pacific") & tenant_in(tenants=["acme", "globex"])' + +# 3. 
Rebuild from either form +assert Predicate.from_dict(blob).to_dict() == blob +assert parse(text).to_dict() == blob +``` + +## Subscribing to predicate events ```python from taskito.events import EventType -def on_predicate_event(event_type: EventType, payload: dict) -> None: +def log_event(event_type: EventType, payload: dict) -> None: match event_type: case EventType.PREDICATE_DEFERRED: - print( - f"[defer] {payload['task_name']} +{payload['defer_seconds']}s " - f"phase={payload['phase']}" - ) + print(f"[defer] {payload['task_name']} +{payload['defer_seconds']}s " + f"phase={payload['phase']}") case EventType.PREDICATE_CANCELLED: - print(f"[cancel] {payload['task_name']} reason={payload.get('reason', '')}") + print(f"[cancel] {payload['task_name']}: " + f"{payload.get('reason', '')}") case EventType.PREDICATE_REJECTED: - print(f"[reject] {payload['task_name']} reason={payload.get('reason', '')}") + print(f"[reject] {payload['task_name']}: " + f"{payload.get('reason', '')}") for event in ( EventType.PREDICATE_DEFERRED, EventType.PREDICATE_CANCELLED, EventType.PREDICATE_REJECTED, ): - queue._event_bus.on(event, on_predicate_event) + queue._event_bus.on(event, log_event) ``` ## Notes -* Predicates are evaluated **twice** for each job (once at enqueue, once at dispatch). Use `ctx.job_id` to tell the phases apart inside a custom predicate. -* `on_false` only controls what happens when the predicate returns a plain `False`. Returning `Defer(seconds=...)` or `Cancel(reason=...)` is honored regardless. -* System-state recipes (`queue_size_under`, `error_rate_under`, `queue_paused`) memoize within a single evaluation, so composing them is cheap. -* A predicate that raises is treated as `False` (fail-closed) — the error is logged and `PredicateMetrics.errors` is incremented. - -See the [Predicates guide](/docs/guides/core/predicates) for the full API reference. +* Predicates are evaluated **twice** per job (enqueue + dispatch). 
Use + `ctx.job_id is None` to tell the phases apart inside a custom + predicate. +* `on_false` only controls what happens when the predicate returns a + plain `False`. Returning `Defer(...)` or `Cancel(...)` is honored + regardless. +* For atomic backpressure use `max_concurrent` / `rate_limit` / + `circuit_breaker` — they are enforced in the Rust scheduler. + Predicates are for the gates the scheduler doesn't already provide: + time, payload, feature flags, custom logic. +* See the [Predicates guide](/docs/guides/core/predicates) for the + full API surface. From 8fbf316ebce100230a5c34aadf3c15ca15dddd8a Mon Sep 17 00:00:00 2001 From: Pratyush Sharma <56130065+pratyush618@users.noreply.github.com> Date: Mon, 11 May 2026 11:03:55 +0530 Subject: [PATCH 10/11] fix(tests): poll for on_cancel middleware in prefork cancel test --- tests/worker/test_prefork.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/worker/test_prefork.py b/tests/worker/test_prefork.py index fa8915e..3bbaa1e 100644 --- a/tests/worker/test_prefork.py +++ b/tests/worker/test_prefork.py @@ -304,7 +304,15 @@ def _running() -> bool: status = _wait_for_terminal(job, timeout=10) assert status == "cancelled", f"expected 'cancelled', got {status!r} (error={job.error!r})" - assert job.id in cancels_seen, "on_cancel middleware did not fire" + # Same race as `on_timeout` in #154: handle_result flips the DB status to + # 'cancelled' before dispatch_outcome fires on_cancel on a separate + # thread, so a fast assertion can race past the middleware call. Poll + # for the spy with a small budget instead. 
+ poll_until( + lambda: job.id in cancels_seen, + timeout=5, + message="on_cancel middleware did not fire", + ) @prefork_unix_only From 5896544b4643c323427a8d81f684561b41d52797 Mon Sep 17 00:00:00 2001 From: Pratyush Sharma <56130065+pratyush618@users.noreply.github.com> Date: Mon, 11 May 2026 11:30:32 +0530 Subject: [PATCH 11/11] fix(predicates): fast-path UTC and require tzdata on Windows --- py_src/taskito/predicates/recipes/time.py | 9 ++++++++- pyproject.toml | 5 ++++- 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/py_src/taskito/predicates/recipes/time.py b/py_src/taskito/predicates/recipes/time.py index a6c651d..749ab2e 100644 --- a/py_src/taskito/predicates/recipes/time.py +++ b/py_src/taskito/predicates/recipes/time.py @@ -34,6 +34,10 @@ class ZoneInfoNotFoundError(Exception): # type: ignore[no-redef] def _resolve_tz(tz: str | None) -> tzinfo: if tz is None: return timezone.utc + # Fast-path UTC so the predicate works on hosts that lack an IANA tz + # database (notably Windows without the ``tzdata`` package). 
+ if tz.upper() == "UTC": + return timezone.utc if not _HAS_ZONEINFO: raise PredicateValidationError( "zoneinfo is not available; pass tz=None or install Python 3.9+" @@ -41,7 +45,10 @@ def _resolve_tz(tz: str | None) -> tzinfo: try: return ZoneInfo(tz) except ZoneInfoNotFoundError as exc: - raise PredicateValidationError(f"unknown timezone: {tz!r}") from exc + raise PredicateValidationError( + f"unknown timezone: {tz!r} " + "(on Windows, install the `tzdata` package for IANA timezone support)" + ) from exc def _parse_hhmm(value: str) -> tuple[int, int]: diff --git a/pyproject.toml b/pyproject.toml index ed85b08..78e19ee 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,10 @@ classifiers = [ "Topic :: Software Development :: Libraries :: Python Modules", "Topic :: System :: Distributed Computing", ] -dependencies = ["cloudpickle>=3.0"] +dependencies = [ + "cloudpickle>=3.0", + 'tzdata; platform_system == "Windows"', +] [project.urls] Homepage = "https://github.com/ByteVeda/taskito" Documentation = "https://docs.byteveda.org/taskito"