Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 7 additions & 121 deletions py_src/taskito/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,16 @@
The Queue class itself (this file) handles only:
- Constructor and storage backend initialization
- enqueue() / enqueue_many() job submission
- _wrap_task() task body wrapping with hooks, middleware, proxies, resources
- Internal helpers (``_get_serializer``, ``_deserialize_payload``,
``_get_middleware_chain``)
- Internal helpers (``_get_serializer``, ``_deserialize_payload``)

Task-body wrapping (``_wrap_task``) and the per-task middleware chain
(``_get_middleware_chain``) live on ``QueueDecoratorMixin`` —
``_wrap_task`` is invoked from the decorator hot path, so co-locating it
with the decorator keeps the interactions in one file.
"""

from __future__ import annotations

import functools
import hashlib
import logging
import os
Expand All @@ -24,14 +26,11 @@
from typing import TYPE_CHECKING, Any

from taskito._taskito import PyQueue
from taskito.async_support.helpers import run_maybe_async
from taskito.async_support.mixins import AsyncQueueMixin
from taskito.context import _clear_context, current_job
from taskito.events import EventBus, EventType
from taskito.interception import ArgumentInterceptor
from taskito.interception.built_in import build_default_registry
from taskito.interception.metrics import InterceptionMetrics
from taskito.interception.reconstruct import reconstruct_args
from taskito.middleware import TaskMiddleware
from taskito.mixins import (
QueueDecoratorMixin,
Expand All @@ -43,7 +42,7 @@
QueueResourceMixin,
QueueSettingsMixin,
)
from taskito.proxies import ProxyRegistry, cleanup_proxies, reconstruct_proxies
from taskito.proxies import ProxyRegistry
from taskito.proxies.built_in import register_builtin_handlers
from taskito.proxies.metrics import ProxyMetrics
from taskito.result import JobResult
Expand Down Expand Up @@ -302,119 +301,6 @@ def _deserialize_payload(self, task_name: str, payload: bytes) -> tuple:
"""Deserialize a job payload using the per-task or queue-level serializer."""
return self._get_serializer(task_name).loads(payload) # type: ignore[no-any-return]

def _get_middleware_chain(self, task_name: str) -> list[TaskMiddleware]:
"""Get the combined global + per-task middleware list."""
per_task = self._task_middleware.get(task_name, [])
return self._global_middleware + per_task

def _wrap_task(
    self, fn: Callable, task_name: str, soft_timeout: float | None = None
) -> Callable:
    """Wrap a task function with hooks, middleware, and job context.

    The returned wrapper runs, in order: argument reconstruction, proxy
    reconstruction, resource injection, middleware ``before()``,
    ``before_task`` hooks, the task body, success/failure hooks, then
    (in ``finally``) resource release, proxy cleanup, ``after_task``
    hooks, middleware ``after()``, lifecycle event emission, and job
    context teardown.

    Args:
        fn: The user task function (sync or async; executed via
            ``run_maybe_async``).
        task_name: Registry name used to look up middleware, the inject
            map, and passed to every hook.
        soft_timeout: Optional soft timeout installed on the job context
            before execution; ``None`` leaves the context untouched.

    Returns:
        The wrapped callable. It returns the task's result on success
        and re-raises the task's exception on failure.
    """
    hooks = self._hooks
    # Close over the queue explicitly so the inner wrapper reads queue
    # state through one stable name.
    queue_ref = self

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # Reconstruct intercepted arguments (CONVERT markers → original types)
        redirects: dict[str, str] = {}
        if queue_ref._interceptor is not None:
            args, kwargs, redirects = reconstruct_args(args, kwargs)

        # Reconstruct proxy markers (PROXY → live objects)
        proxy_cleanup: list[Any] = []
        if queue_ref._proxy_registry is not None and not queue_ref._test_mode_active:
            args, kwargs, proxy_cleanup = reconstruct_proxies(
                args,
                kwargs,
                queue_ref._proxy_registry,
                signing_secret=queue_ref._recipe_signing_key,
                max_timeout=queue_ref._max_reconstruction_timeout,
                metrics=queue_ref._proxy_metrics,
            )

        # Inject resources from runtime
        release_callbacks: list[Any] = []
        runtime = queue_ref._resource_runtime
        if runtime is not None:
            # From explicit inject=["db"] on task decorator
            for res_name in queue_ref._task_inject_map.get(task_name, []):
                # Caller-supplied kwargs win over injected resources.
                if res_name not in kwargs:
                    instance, release = runtime.acquire_for_task(res_name)
                    kwargs[res_name] = instance
                    if release is not None:
                        release_callbacks.append(release)
            # From interception REDIRECT markers
            for kwarg_name, resource_name in redirects.items():
                instance, release = runtime.acquire_for_task(resource_name)
                kwargs[kwarg_name] = instance
                if release is not None:
                    release_callbacks.append(release)

        middleware_chain = queue_ref._get_middleware_chain(task_name)

        # Set soft timeout on context if configured
        if soft_timeout is not None:
            current_job._set_soft_timeout(soft_timeout)

        # Run middleware before hooks
        completed_mw: list[Any] = []
        for mw in middleware_chain:
            try:
                mw.before(current_job)
                completed_mw.append(mw)
            except Exception:
                # A failing before() is logged and skipped; only
                # middleware recorded in completed_mw gets after().
                logger.exception("middleware before() error")

        for hook in hooks["before_task"]:
            hook(task_name, args, kwargs)

        error = None
        result = None
        try:
            ret = run_maybe_async(fn(*args, **kwargs))
            result = ret
        except Exception as exc:
            error = exc
            for hook in hooks["on_failure"]:
                hook(task_name, args, kwargs, exc)
            # Re-raise after failure hooks so the caller/worker still
            # observes the original exception (finally runs first).
            raise
        else:
            for hook in hooks["on_success"]:
                hook(task_name, args, kwargs, result)
            return result
        finally:
            # Release task/request-scoped resources
            for release_fn in release_callbacks:
                try:
                    release_fn()
                except Exception:
                    logger.exception("resource release error")
            # Clean up reconstructed proxies (LIFO order)
            cleanup_proxies(proxy_cleanup, metrics=queue_ref._proxy_metrics)
            for hook in hooks["after_task"]:
                hook(task_name, args, kwargs, result, error)
            # Run middleware after hooks (only those whose before() succeeded)
            for mw in completed_mw:
                try:
                    mw.after(current_job, result, error)
                except Exception:
                    logger.exception("middleware after() error")
            # Emit job lifecycle events
            event_payload: dict[str, Any] = {
                "task_name": task_name,
                "job_id": current_job.id,
                "queue": current_job.queue_name,
            }
            if error is not None:
                event_payload["error"] = str(error)
                queue_ref._emit_event(EventType.JOB_FAILED, event_payload)
            else:
                queue_ref._emit_event(EventType.JOB_COMPLETED, event_payload)
            # Always reset the job-local context for the next task.
            _clear_context()

    return wrapper

def enqueue(
self,
task_name: str,
Expand Down
140 changes: 139 additions & 1 deletion py_src/taskito/mixins/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,35 @@
import contextlib
import functools
import inspect
import logging
import os
import sys
import typing
from collections.abc import Callable
from typing import TYPE_CHECKING, Any

from taskito._taskito import PyTaskConfig
from taskito.async_support.helpers import run_maybe_async
from taskito.context import _clear_context, current_job
from taskito.events import EventType
from taskito.inject import Inject, _InjectAlias
from taskito.interception.reconstruct import reconstruct_args
from taskito.interception.strategy import Strategy as S
from taskito.proxies import cleanup_proxies, reconstruct_proxies
from taskito.task import TaskWrapper

if TYPE_CHECKING:
from taskito.interception import ArgumentInterceptor
from taskito.middleware import TaskMiddleware
from taskito.proxies import ProxyRegistry
from taskito.proxies.metrics import ProxyMetrics
from taskito.resources.runtime import ResourceRuntime
from taskito.serializers import Serializer


logger = logging.getLogger("taskito")


def _resolve_module_name(module_name: str) -> str:
"""Resolve __main__ to the actual module name."""
if module_name != "__main__":
Expand Down Expand Up @@ -51,8 +63,134 @@ class QueueDecoratorMixin:
# --- Attributes provided by the composed Queue (declared here so mypy can
# --- see them from this mixin; the real values are set in Queue.__init__) ---
# Per-task exception filters; presumably retry/no-retry lists keyed by
# task name — TODO confirm against the retry machinery.
_task_retry_filters: dict[str, dict[str, list[type[Exception]]]]
# Task name -> resource names requested via inject=[...] on the decorator;
# consumed by _wrap_task for runtime resource injection.
_task_inject_map: dict[str, list[str]]
# When set, CONVERT markers in task args are expanded via reconstruct_args.
_interceptor: ArgumentInterceptor | None
# When set (and not in test mode), PROXY markers are expanded into live
# objects via reconstruct_proxies.
_proxy_registry: ProxyRegistry | None
# Metrics sink passed to proxy reconstruction and cleanup.
_proxy_metrics: ProxyMetrics
# Runtime used to acquire/release injected resources per task invocation.
_resource_runtime: ResourceRuntime | None
# True disables proxy reconstruction in _wrap_task.
_test_mode_active: bool
# Signing secret forwarded to reconstruct_proxies (recipe verification).
_recipe_signing_key: str | None
# Upper bound (seconds, presumably — verify) for proxy reconstruction.
_max_reconstruction_timeout: int
# Middleware applied to every task, ahead of per-task middleware.
_global_middleware: list[TaskMiddleware]
# Per-queue configuration; usage not visible in this file — see Queue.
_queue_configs: dict[str, dict[str, Any]]

# ``_emit_event`` is provided by ``QueueEventsMixin`` on the composed
# Queue. Declaring it as a class-level callable attribute (not a method)
# lets mypy see it from this mixin without overriding the real
# implementation through MRO.
_emit_event: Callable[..., None]

def _get_middleware_chain(self, task_name: str) -> list[TaskMiddleware]:
"""Get the combined global + per-task middleware list."""
per_task = self._task_middleware.get(task_name, [])
return self._global_middleware + per_task

def _wrap_task(
    self, fn: Callable, task_name: str, soft_timeout: float | None = None
) -> Callable:
    """Wrap a task function with hooks, middleware, and job context.

    Invoked from the ``task`` decorator hot path. The returned wrapper
    runs, in order: argument reconstruction, proxy reconstruction,
    resource injection, middleware ``before()``, ``before_task`` hooks,
    the task body, success/failure hooks, then (in ``finally``) resource
    release, proxy cleanup, ``after_task`` hooks, middleware ``after()``,
    lifecycle event emission, and job context teardown.

    Args:
        fn: The user task function (sync or async; executed via
            ``run_maybe_async``).
        task_name: Registry name used to look up middleware, the inject
            map, and passed to every hook.
        soft_timeout: Optional soft timeout installed on the job context
            before execution; ``None`` leaves the context untouched.

    Returns:
        The wrapped callable. It returns the task's result on success
        and re-raises the task's exception on failure.
    """
    hooks = self._hooks
    # Close over the queue explicitly so the inner wrapper reads queue
    # state through one stable name.
    queue_ref = self

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # Reconstruct intercepted arguments (CONVERT markers → original types)
        redirects: dict[str, str] = {}
        if queue_ref._interceptor is not None:
            args, kwargs, redirects = reconstruct_args(args, kwargs)

        # Reconstruct proxy markers (PROXY → live objects)
        proxy_cleanup: list[Any] = []
        if queue_ref._proxy_registry is not None and not queue_ref._test_mode_active:
            args, kwargs, proxy_cleanup = reconstruct_proxies(
                args,
                kwargs,
                queue_ref._proxy_registry,
                signing_secret=queue_ref._recipe_signing_key,
                max_timeout=queue_ref._max_reconstruction_timeout,
                metrics=queue_ref._proxy_metrics,
            )

        # Inject resources from runtime
        release_callbacks: list[Any] = []
        runtime = queue_ref._resource_runtime
        if runtime is not None:
            # From explicit inject=["db"] on task decorator
            for res_name in queue_ref._task_inject_map.get(task_name, []):
                # Caller-supplied kwargs win over injected resources.
                if res_name not in kwargs:
                    instance, release = runtime.acquire_for_task(res_name)
                    kwargs[res_name] = instance
                    if release is not None:
                        release_callbacks.append(release)
            # From interception REDIRECT markers
            for kwarg_name, resource_name in redirects.items():
                instance, release = runtime.acquire_for_task(resource_name)
                kwargs[kwarg_name] = instance
                if release is not None:
                    release_callbacks.append(release)

        middleware_chain = queue_ref._get_middleware_chain(task_name)

        # Set soft timeout on context if configured
        if soft_timeout is not None:
            current_job._set_soft_timeout(soft_timeout)

        # Run middleware before hooks
        completed_mw: list[Any] = []
        for mw in middleware_chain:
            try:
                mw.before(current_job)
                completed_mw.append(mw)
            except Exception:
                # A failing before() is logged and skipped; only
                # middleware recorded in completed_mw gets after().
                logger.exception("middleware before() error")

        for hook in hooks["before_task"]:
            hook(task_name, args, kwargs)

        error = None
        result = None
        try:
            ret = run_maybe_async(fn(*args, **kwargs))
            result = ret
        except Exception as exc:
            error = exc
            for hook in hooks["on_failure"]:
                hook(task_name, args, kwargs, exc)
            # Re-raise after failure hooks so the caller/worker still
            # observes the original exception (finally runs first).
            raise
        else:
            for hook in hooks["on_success"]:
                hook(task_name, args, kwargs, result)
            return result
        finally:
            # Release task/request-scoped resources
            for release_fn in release_callbacks:
                try:
                    release_fn()
                except Exception:
                    logger.exception("resource release error")
            # Clean up reconstructed proxies (LIFO order)
            cleanup_proxies(proxy_cleanup, metrics=queue_ref._proxy_metrics)
            for hook in hooks["after_task"]:
                hook(task_name, args, kwargs, result, error)
            # Run middleware after hooks (only those whose before() succeeded)
            for mw in completed_mw:
                try:
                    mw.after(current_job, result, error)
                except Exception:
                    logger.exception("middleware after() error")
            # Emit job lifecycle events
            event_payload: dict[str, Any] = {
                "task_name": task_name,
                "job_id": current_job.id,
                "queue": current_job.queue_name,
            }
            if error is not None:
                event_payload["error"] = str(error)
                queue_ref._emit_event(EventType.JOB_FAILED, event_payload)
            else:
                queue_ref._emit_event(EventType.JOB_COMPLETED, event_payload)
            # Always reset the job-local context for the next task.
            _clear_context()

    return wrapper

def task(
self,
name: str | None = None,
Expand Down Expand Up @@ -160,7 +298,7 @@ def decorator(fn: Callable) -> TaskWrapper:
self._task_inject_map[task_name] = final_inject

# Wrap the function with hooks, middleware, and context
wrapped = self._wrap_task(fn, task_name, soft_timeout) # type: ignore[attr-defined]
wrapped = self._wrap_task(fn, task_name, soft_timeout)
self._task_registry[task_name] = wrapped

cb_threshold = None
Expand Down
Loading