diff --git a/backend/app/celery/celery_app.py b/backend/app/celery/celery_app.py
index 5335d19d5..1eeb98f1c 100644
--- a/backend/app/celery/celery_app.py
+++ b/backend/app/celery/celery_app.py
@@ -1,4 +1,3 @@
-import importlib
 import logging
 
 from celery import Celery
@@ -9,34 +8,6 @@
 
 logger = logging.getLogger(__name__)
 
 
-# All modules referenced as function_path in execute_high/low_priority_task calls.
-# Pre-importing these at worker process startup eliminates the 2-5s cold-import
-# penalty on the first task execution.
-_JOB_MODULES = [
-    "app.services.llm.jobs",
-    "app.services.response.jobs",
-    "app.services.doctransform.job",
-    "app.services.collections.create_collection",
-    "app.services.collections.delete_collection",
-    "app.services.stt_evaluations.batch_job",
-    "app.services.tts_evaluations.batch_job",
-    "app.services.tts_evaluations.batch_result_processing",
-    "app.services.stt_evaluations.metric_job",
-]
-
-
-@worker_process_init.connect
-def warmup_job_modules(sender, **kwargs: object) -> None:
-    """Pre-import all job modules so the first task execution is not delayed by cold imports."""
-    for module_path in _JOB_MODULES:
-        try:
-            importlib.import_module(module_path)
-            logger.debug(f"[warmup_job_modules] Pre-imported {module_path}")
-        except Exception as exc:
-            logger.warning(
-                f"[warmup_job_modules] Failed to pre-import {module_path}: {exc}"
-            )
-
 # Create Celery instance
 celery_app = Celery(
diff --git a/backend/app/celery/tasks/job_execution.py b/backend/app/celery/tasks/job_execution.py
index 8274de960..58c961902 100644
--- a/backend/app/celery/tasks/job_execution.py
+++ b/backend/app/celery/tasks/job_execution.py
@@ -1,12 +1,38 @@
 import logging
-import importlib
+from collections.abc import Callable
 
 from celery import current_task
 from asgi_correlation_id import correlation_id
 
 from app.celery.celery_app import celery_app
+import app.services.llm.jobs as _llm_jobs
+import app.services.response.jobs as _response_jobs
+import app.services.doctransform.job as _doctransform_job
+import app.services.collections.create_collection as _create_collection
+import app.services.collections.delete_collection as _delete_collection
+import app.services.stt_evaluations.batch_job as _stt_batch_job
+import app.services.stt_evaluations.metric_job as _stt_metric_job
+import app.services.tts_evaluations.batch_job as _tts_batch_job
+import app.services.tts_evaluations.batch_result_processing as _tts_result_processing
 
 logger = logging.getLogger(__name__)
 
+# Hardcoded dispatch table — avoids dynamic importlib at task execution time.
+# Imports above happen once in the main Celery process before worker forks,
+# so all child workers inherit them via copy-on-write instead of each loading
+# them independently (which was causing OOM with warmup_job_modules).
+_FUNCTION_REGISTRY: dict[str, Callable] = {
+    "app.services.llm.jobs.execute_job": _llm_jobs.execute_job,
+    "app.services.llm.jobs.execute_chain_job": _llm_jobs.execute_chain_job,
+    "app.services.response.jobs.execute_job": _response_jobs.execute_job,
+    "app.services.doctransform.job.execute_job": _doctransform_job.execute_job,
+    "app.services.collections.create_collection.execute_job": _create_collection.execute_job,
+    "app.services.collections.delete_collection.execute_job": _delete_collection.execute_job,
+    "app.services.stt_evaluations.batch_job.execute_batch_submission": _stt_batch_job.execute_batch_submission,
+    "app.services.stt_evaluations.metric_job.execute_metric_computation": _stt_metric_job.execute_metric_computation,
+    "app.services.tts_evaluations.batch_job.execute_batch_submission": _tts_batch_job.execute_batch_submission,
+    "app.services.tts_evaluations.batch_result_processing.execute_tts_result_processing": _tts_result_processing.execute_tts_result_processing,
+}
+
 @celery_app.task(bind=True, queue="high_priority")
 def execute_high_priority_task(
@@ -85,10 +111,11 @@ def _execute_job_internal(
         logger.info(f"Set correlation ID context: {trace_id} for job {job_id}")
 
     try:
-        # Dynamically import and resolve the function
-        module_path, function_name = function_path.rsplit(".", 1)
-        module = importlib.import_module(module_path)
-        execute_function = getattr(module, function_name)
+        execute_function = _FUNCTION_REGISTRY.get(function_path)
+        if execute_function is None:
+            raise ValueError(
+                f"[_execute_job_internal] Unknown function path: {function_path}"
+            )
 
         logger.info(
             f"Executing {priority} job {job_id} (task {task_id}) using function {function_path}"
diff --git a/backend/app/celery/utils.py b/backend/app/celery/utils.py
index 957c02d9a..e500a5d63 100644
--- a/backend/app/celery/utils.py
+++ b/backend/app/celery/utils.py
@@ -7,10 +7,6 @@
 from celery.result import AsyncResult
 
 from app.celery.celery_app import celery_app
-from app.celery.tasks.job_execution import (
-    execute_high_priority_task,
-    execute_low_priority_task,
-)
 
 logger = logging.getLogger(__name__)
 
@@ -31,6 +27,8 @@ def start_high_priority_job(
     Returns:
         Celery task ID (different from job_id)
     """
+    from app.celery.tasks.job_execution import execute_high_priority_task
+
     task = execute_high_priority_task.delay(
         function_path=function_path,
         project_id=project_id,
@@ -59,6 +57,8 @@ def start_low_priority_job(
     Returns:
         Celery task ID (different from job_id)
     """
+    from app.celery.tasks.job_execution import execute_low_priority_task
+
     task = execute_low_priority_task.delay(
         function_path=function_path,
         project_id=project_id,