Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 0 additions & 29 deletions backend/app/celery/celery_app.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import importlib
import logging

from celery import Celery
Expand All @@ -9,34 +8,6 @@

logger = logging.getLogger(__name__)

# All modules referenced as function_path in execute_high/low_priority_task calls.
# Pre-importing these at worker process startup eliminates the 2-5s cold-import
# penalty on the first task execution.
_JOB_MODULES = [
    "app.services.llm.jobs",
    "app.services.response.jobs",
    "app.services.doctransform.job",
    "app.services.collections.create_collection",
    "app.services.collections.delete_collection",
    "app.services.stt_evaluations.batch_job",
    "app.services.tts_evaluations.batch_job",
    "app.services.tts_evaluations.batch_result_processing",
    "app.services.stt_evaluations.metric_job",
]


@worker_process_init.connect
def warmup_job_modules(sender, **kwargs: object) -> None:
    """Pre-import all job modules so the first task execution is not delayed by cold imports.

    Connected to Celery's ``worker_process_init`` signal, so it runs once in
    each worker process after fork.

    Args:
        sender: Signal sender supplied by Celery (unused).
        **kwargs: Additional signal keyword arguments (unused).
    """
    for module_path in _JOB_MODULES:
        try:
            importlib.import_module(module_path)
            # Lazy %-style args: the message is only formatted when DEBUG is enabled.
            logger.debug("[warmup_job_modules] Pre-imported %s", module_path)
        except Exception as exc:
            # Deliberate best-effort: a module that fails to pre-import here will
            # still be imported (and fail loudly) when its task actually executes.
            logger.warning(
                "[warmup_job_modules] Failed to pre-import %s: %s", module_path, exc
            )


# Create Celery instance
celery_app = Celery(
Expand Down
37 changes: 32 additions & 5 deletions backend/app/celery/tasks/job_execution.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,38 @@
import logging
import importlib
from collections.abc import Callable
from celery import current_task
from asgi_correlation_id import correlation_id

from app.celery.celery_app import celery_app
import app.services.llm.jobs as _llm_jobs
import app.services.response.jobs as _response_jobs
import app.services.doctransform.job as _doctransform_job
import app.services.collections.create_collection as _create_collection
import app.services.collections.delete_collection as _delete_collection
import app.services.stt_evaluations.batch_job as _stt_batch_job
import app.services.stt_evaluations.metric_job as _stt_metric_job
import app.services.tts_evaluations.batch_job as _tts_batch_job
import app.services.tts_evaluations.batch_result_processing as _tts_result_processing

logger = logging.getLogger(__name__)

# Hardcoded dispatch table — avoids dynamic importlib at task execution time.
# Imports above happen once in the main Celery process before worker forks,
# so all child workers inherit them via copy-on-write instead of each loading
# them independently (which was causing OOM with warmup_job_modules).
#
# Keys must exactly match the `function_path` strings looked up by
# _execute_job_internal; an unknown path there raises ValueError. Adding a new
# job function therefore requires both a module import above and an entry here.
# NOTE(review): presumably the key strings mirror the dotted
# "module.path.function_name" convention used by callers of
# start_high/low_priority_job — confirm against those call sites.
_FUNCTION_REGISTRY: dict[str, Callable] = {
    "app.services.llm.jobs.execute_job": _llm_jobs.execute_job,
    "app.services.llm.jobs.execute_chain_job": _llm_jobs.execute_chain_job,
    "app.services.response.jobs.execute_job": _response_jobs.execute_job,
    "app.services.doctransform.job.execute_job": _doctransform_job.execute_job,
    "app.services.collections.create_collection.execute_job": _create_collection.execute_job,
    "app.services.collections.delete_collection.execute_job": _delete_collection.execute_job,
    "app.services.stt_evaluations.batch_job.execute_batch_submission": _stt_batch_job.execute_batch_submission,
    "app.services.stt_evaluations.metric_job.execute_metric_computation": _stt_metric_job.execute_metric_computation,
    "app.services.tts_evaluations.batch_job.execute_batch_submission": _tts_batch_job.execute_batch_submission,
    "app.services.tts_evaluations.batch_result_processing.execute_tts_result_processing": _tts_result_processing.execute_tts_result_processing,
}


@celery_app.task(bind=True, queue="high_priority")
def execute_high_priority_task(
Expand Down Expand Up @@ -85,10 +111,11 @@ def _execute_job_internal(
logger.info(f"Set correlation ID context: {trace_id} for job {job_id}")

try:
# Dynamically import and resolve the function
module_path, function_name = function_path.rsplit(".", 1)
module = importlib.import_module(module_path)
execute_function = getattr(module, function_name)
execute_function = _FUNCTION_REGISTRY.get(function_path)
if execute_function is None:
raise ValueError(
f"[_execute_job_internal] Unknown function path: {function_path}"
)

logger.info(
f"Executing {priority} job {job_id} (task {task_id}) using function {function_path}"
Expand Down
8 changes: 4 additions & 4 deletions backend/app/celery/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@
from celery.result import AsyncResult

from app.celery.celery_app import celery_app
from app.celery.tasks.job_execution import (
execute_high_priority_task,
execute_low_priority_task,
)

logger = logging.getLogger(__name__)

Expand All @@ -31,6 +27,8 @@ def start_high_priority_job(
Returns:
Celery task ID (different from job_id)
"""
from app.celery.tasks.job_execution import execute_high_priority_task

task = execute_high_priority_task.delay(
function_path=function_path,
project_id=project_id,
Expand Down Expand Up @@ -59,6 +57,8 @@ def start_low_priority_job(
Returns:
Celery task ID (different from job_id)
"""
from app.celery.tasks.job_execution import execute_low_priority_task

task = execute_low_priority_task.delay(
function_path=function_path,
project_id=project_id,
Expand Down
Loading