From 2abe062c78e0d3ec8cece9d9830e924abd2ba7cf Mon Sep 17 00:00:00 2001 From: Pratyush Sharma <56130065+pratyush618@users.noreply.github.com> Date: Fri, 8 May 2026 07:30:14 +0530 Subject: [PATCH] refactor(decorators): move _wrap_task into mixin --- py_src/taskito/app.py | 128 ++----------------------- py_src/taskito/mixins/decorators.py | 140 +++++++++++++++++++++++++++- 2 files changed, 146 insertions(+), 122 deletions(-) diff --git a/py_src/taskito/app.py b/py_src/taskito/app.py index 6f84b10..d5b06c5 100644 --- a/py_src/taskito/app.py +++ b/py_src/taskito/app.py @@ -8,14 +8,16 @@ The Queue class itself (this file) handles only: - Constructor and storage backend initialization - enqueue() / enqueue_many() job submission -- _wrap_task() task body wrapping with hooks, middleware, proxies, resources -- Internal helpers (``_get_serializer``, ``_deserialize_payload``, - ``_get_middleware_chain``) +- Internal helpers (``_get_serializer``, ``_deserialize_payload``) + +Task-body wrapping (``_wrap_task``) and the per-task middleware chain +(``_get_middleware_chain``) live on ``QueueDecoratorMixin`` — +``_wrap_task`` is invoked from the decorator hot path, so co-locating it +with the decorator keeps the interactions in one file. """ from __future__ import annotations -import functools import hashlib import logging import os @@ -24,14 +26,11 @@ from typing import TYPE_CHECKING, Any from taskito._taskito import PyQueue -from taskito.async_support.helpers import run_maybe_async from taskito.async_support.mixins import AsyncQueueMixin -from taskito.context import _clear_context, current_job from taskito.events import EventBus, EventType from taskito.interception import ArgumentInterceptor from taskito.interception.built_in import build_default_registry from taskito.interception.metrics import InterceptionMetrics -from taskito.interception.reconstruct import reconstruct_args from taskito.middleware import TaskMiddleware from taskito.mixins import ( QueueDecoratorMixin, @@ -43,7 +42,7 @@ QueueResourceMixin, QueueSettingsMixin, ) -from taskito.proxies import ProxyRegistry, cleanup_proxies, reconstruct_proxies +from taskito.proxies import ProxyRegistry from taskito.proxies.built_in import register_builtin_handlers from taskito.proxies.metrics import ProxyMetrics from taskito.result import JobResult @@ -302,119 +301,6 @@ def _deserialize_payload(self, task_name: str, payload: bytes) -> tuple: """Deserialize a job payload using the per-task or queue-level serializer.""" return self._get_serializer(task_name).loads(payload) # type: ignore[no-any-return] - def _get_middleware_chain(self, task_name: str) -> list[TaskMiddleware]: - """Get the combined global + per-task middleware list.""" - per_task = self._task_middleware.get(task_name, []) - return self._global_middleware + per_task - - def _wrap_task( - self, fn: Callable, task_name: str, soft_timeout: float | None = None - ) -> Callable: - """Wrap a task function with hooks, middleware, and job context.""" - hooks = self._hooks - queue_ref = self - - @functools.wraps(fn) - def wrapper(*args: Any, **kwargs: Any) -> Any: - # Reconstruct intercepted arguments (CONVERT markers → original types) - redirects: dict[str, str] = {} - if queue_ref._interceptor is not None: - args, kwargs, redirects = reconstruct_args(args, kwargs) - - # Reconstruct proxy markers (PROXY → live objects) - proxy_cleanup: list[Any] = [] - if queue_ref._proxy_registry is not None and not queue_ref._test_mode_active: - args, kwargs, proxy_cleanup = reconstruct_proxies( - args, - kwargs, - queue_ref._proxy_registry, - signing_secret=queue_ref._recipe_signing_key, - max_timeout=queue_ref._max_reconstruction_timeout, - metrics=queue_ref._proxy_metrics, - ) - - # Inject resources from runtime - release_callbacks: list[Any] = [] - runtime = queue_ref._resource_runtime - if runtime is not None: - # From explicit inject=["db"] on task decorator - for res_name in queue_ref._task_inject_map.get(task_name, []): - if res_name not in kwargs: - instance, release = runtime.acquire_for_task(res_name) - kwargs[res_name] = instance - if release is not None: - release_callbacks.append(release) - # From interception REDIRECT markers - for kwarg_name, resource_name in redirects.items(): - instance, release = runtime.acquire_for_task(resource_name) - kwargs[kwarg_name] = instance - if release is not None: - release_callbacks.append(release) - - middleware_chain = queue_ref._get_middleware_chain(task_name) - - # Set soft timeout on context if configured - if soft_timeout is not None: - current_job._set_soft_timeout(soft_timeout) - - # Run middleware before hooks - completed_mw: list[Any] = [] - for mw in middleware_chain: - try: - mw.before(current_job) - completed_mw.append(mw) - except Exception: - logger.exception("middleware before() error") - - for hook in hooks["before_task"]: - hook(task_name, args, kwargs) - - error = None - result = None - try: - ret = run_maybe_async(fn(*args, **kwargs)) - result = ret - except Exception as exc: - error = exc - for hook in hooks["on_failure"]: - hook(task_name, args, kwargs, exc) - raise - else: - for hook in hooks["on_success"]: - hook(task_name, args, kwargs, result) - return result - finally: - # Release task/request-scoped resources - for release_fn in release_callbacks: - try: - release_fn() - except Exception: - logger.exception("resource release error") - # Clean up reconstructed proxies (LIFO order) - cleanup_proxies(proxy_cleanup, metrics=queue_ref._proxy_metrics) - for hook in hooks["after_task"]: - hook(task_name, args, kwargs, result, error) - # Run middleware after hooks (only those whose before() succeeded) - for mw in completed_mw: - try: - mw.after(current_job, result, error) - except Exception: - logger.exception("middleware after() error") - # Emit job lifecycle events - event_payload = { - "task_name": task_name, - "job_id": current_job.id, - "queue": current_job.queue_name, - } - if error is not None: - event_payload["error"] = str(error) - queue_ref._emit_event(EventType.JOB_FAILED, event_payload) - else: - queue_ref._emit_event(EventType.JOB_COMPLETED, event_payload) - _clear_context() - - return wrapper - def enqueue( self, task_name: str, diff --git a/py_src/taskito/mixins/decorators.py b/py_src/taskito/mixins/decorators.py index 15582a0..3898b95 100644 --- a/py_src/taskito/mixins/decorators.py +++ b/py_src/taskito/mixins/decorators.py @@ -5,6 +5,7 @@ import contextlib import functools import inspect +import logging import os import sys import typing @@ -12,16 +13,27 @@ from typing import TYPE_CHECKING, Any from taskito._taskito import PyTaskConfig +from taskito.async_support.helpers import run_maybe_async +from taskito.context import _clear_context, current_job +from taskito.events import EventType from taskito.inject import Inject, _InjectAlias +from taskito.interception.reconstruct import reconstruct_args from taskito.interception.strategy import Strategy as S +from taskito.proxies import cleanup_proxies, reconstruct_proxies from taskito.task import TaskWrapper if TYPE_CHECKING: from taskito.interception import ArgumentInterceptor from taskito.middleware import TaskMiddleware + from taskito.proxies import ProxyRegistry + from taskito.proxies.metrics import ProxyMetrics + from taskito.resources.runtime import ResourceRuntime from taskito.serializers import Serializer +logger = logging.getLogger("taskito") + + def _resolve_module_name(module_name: str) -> str: """Resolve __main__ to the actual module name.""" if module_name != "__main__": @@ -51,8 +63,134 @@ class QueueDecoratorMixin: _task_retry_filters: dict[str, dict[str, list[type[Exception]]]] _task_inject_map: dict[str, list[str]] _interceptor: ArgumentInterceptor | None + _proxy_registry: ProxyRegistry | None + _proxy_metrics: ProxyMetrics + _resource_runtime: ResourceRuntime | None + _test_mode_active: bool + _recipe_signing_key: str | None + _max_reconstruction_timeout: int + _global_middleware: list[TaskMiddleware] _queue_configs: dict[str, dict[str, Any]] + # ``_emit_event`` is provided by ``QueueEventsMixin`` on the composed + # Queue. Declaring it as a class-level callable attribute (not a method) + # lets mypy see it from this mixin without overriding the real + # implementation through MRO. + _emit_event: Callable[..., None] + + def _get_middleware_chain(self, task_name: str) -> list[TaskMiddleware]: + """Get the combined global + per-task middleware list.""" + per_task = self._task_middleware.get(task_name, []) + return self._global_middleware + per_task + + def _wrap_task( + self, fn: Callable, task_name: str, soft_timeout: float | None = None + ) -> Callable: + """Wrap a task function with hooks, middleware, and job context.""" + hooks = self._hooks + queue_ref = self + + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + # Reconstruct intercepted arguments (CONVERT markers → original types) + redirects: dict[str, str] = {} + if queue_ref._interceptor is not None: + args, kwargs, redirects = reconstruct_args(args, kwargs) + + # Reconstruct proxy markers (PROXY → live objects) + proxy_cleanup: list[Any] = [] + if queue_ref._proxy_registry is not None and not queue_ref._test_mode_active: + args, kwargs, proxy_cleanup = reconstruct_proxies( + args, + kwargs, + queue_ref._proxy_registry, + signing_secret=queue_ref._recipe_signing_key, + max_timeout=queue_ref._max_reconstruction_timeout, + metrics=queue_ref._proxy_metrics, + ) + + # Inject resources from runtime + release_callbacks: list[Any] = [] + runtime = queue_ref._resource_runtime + if runtime is not None: + # From explicit inject=["db"] on task decorator + for res_name in queue_ref._task_inject_map.get(task_name, []): + if res_name not in kwargs: + instance, release = runtime.acquire_for_task(res_name) + kwargs[res_name] = instance + if release is not None: + release_callbacks.append(release) + # From interception REDIRECT markers + for kwarg_name, resource_name in redirects.items(): + instance, release = runtime.acquire_for_task(resource_name) + kwargs[kwarg_name] = instance + if release is not None: + release_callbacks.append(release) + + middleware_chain = queue_ref._get_middleware_chain(task_name) + + # Set soft timeout on context if configured + if soft_timeout is not None: + current_job._set_soft_timeout(soft_timeout) + + # Run middleware before hooks + completed_mw: list[Any] = [] + for mw in middleware_chain: + try: + mw.before(current_job) + completed_mw.append(mw) + except Exception: + logger.exception("middleware before() error") + + for hook in hooks["before_task"]: + hook(task_name, args, kwargs) + + error = None + result = None + try: + ret = run_maybe_async(fn(*args, **kwargs)) + result = ret + except Exception as exc: + error = exc + for hook in hooks["on_failure"]: + hook(task_name, args, kwargs, exc) + raise + else: + for hook in hooks["on_success"]: + hook(task_name, args, kwargs, result) + return result + finally: + # Release task/request-scoped resources + for release_fn in release_callbacks: + try: + release_fn() + except Exception: + logger.exception("resource release error") + # Clean up reconstructed proxies (LIFO order) + cleanup_proxies(proxy_cleanup, metrics=queue_ref._proxy_metrics) + for hook in hooks["after_task"]: + hook(task_name, args, kwargs, result, error) + # Run middleware after hooks (only those whose before() succeeded) + for mw in completed_mw: + try: + mw.after(current_job, result, error) + except Exception: + logger.exception("middleware after() error") + # Emit job lifecycle events + event_payload: dict[str, Any] = { + "task_name": task_name, + "job_id": current_job.id, + "queue": current_job.queue_name, + } + if error is not None: + event_payload["error"] = str(error) + queue_ref._emit_event(EventType.JOB_FAILED, event_payload) + else: + queue_ref._emit_event(EventType.JOB_COMPLETED, event_payload) + _clear_context() + + return wrapper + def task( self, name: str | None = None, @@ -160,7 +298,7 @@ def decorator(fn: Callable) -> TaskWrapper: self._task_inject_map[task_name] = final_inject # Wrap the function with hooks, middleware, and context - wrapped = self._wrap_task(fn, task_name, soft_timeout) # type: ignore[attr-defined] + wrapped = self._wrap_task(fn, task_name, soft_timeout) self._task_registry[task_name] = wrapped cb_threshold = None