From 911084c93eaa56d2bd013e4e472dc78f68c085aa Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 9 Mar 2026 12:53:14 +0500 Subject: [PATCH 01/15] Move terminating jobs logic from services to background task --- .../scheduled_tasks/terminating_jobs.py | 390 +++++++++++++++++- .../server/services/jobs/__init__.py | 368 +---------------- 2 files changed, 388 insertions(+), 370 deletions(-) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py index 27163b53d..e7296956d 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py @@ -1,24 +1,55 @@ import asyncio +from datetime import timedelta +from typing import Optional from sqlalchemy import or_, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload -from dstack._internal.core.models.runs import JobStatus +from dstack._internal.core.backends.base.backend import Backend +from dstack._internal.core.backends.base.compute import ComputeWithVolumeSupport +from dstack._internal.core.consts import DSTACK_SHIM_HTTP_PORT +from dstack._internal.core.errors import BackendError +from dstack._internal.core.models.instances import InstanceStatus, InstanceTerminationReason +from dstack._internal.core.models.runs import ( + JobProvisioningData, + JobSpec, + JobStatus, + JobTerminationReason, + RunTerminationReason, +) +from dstack._internal.server import settings from dstack._internal.server.db import get_db, get_session_ctx from dstack._internal.server.models import ( InstanceModel, JobModel, ProjectModel, + RunModel, VolumeAttachmentModel, + VolumeModel, +) +from dstack._internal.server.services import backends as backends_services +from dstack._internal.server.services import events, services +from dstack._internal.server.services.instances import ( + 
format_instance_blocks_for_event, + get_instance_ssh_private_keys, + switch_instance_status, ) from dstack._internal.server.services.jobs import ( - process_terminating_job, - process_volumes_detaching, + get_job_provisioning_data, + get_job_runtime_data, + switch_job_status, ) from dstack._internal.server.services.locking import get_locker from dstack._internal.server.services.logging import fmt +from dstack._internal.server.services.runner import client +from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel +from dstack._internal.server.services.volumes import ( + list_project_volume_models, + volume_model_to_volume, +) from dstack._internal.server.utils import sentry_utils +from dstack._internal.utils import common from dstack._internal.utils.common import ( get_current_datetime, get_or_error, @@ -101,9 +132,358 @@ async def _process_job(session: AsyncSession, job_model: JobModel): ) instance_model = res.unique().scalar() if job_model.volumes_detached_at is None: - await process_terminating_job(session, job_model, instance_model) + await _process_terminating_job(session, job_model, instance_model) else: instance_model = get_or_error(instance_model) - await process_volumes_detaching(session, job_model, instance_model) + await _process_volumes_detaching(session, job_model, instance_model) job_model.last_processed_at = get_current_datetime() await session.commit() + + +async def _process_terminating_job( + session: AsyncSession, + job_model: JobModel, + instance_model: Optional[InstanceModel], +): + """ + Stops the job: tells shim to stop the container, detaches the job from the instance, + and detaches volumes from the instance. + Graceful stop should already be done by `process_terminating_run`. + Caller must acquire the locks on the job and the job's instance. 
+ """ + if job_model.remove_at is not None and job_model.remove_at > common.get_current_datetime(): + # it's too early to terminate the instance + return + + if instance_model is None: + # Possible if the job hasn't been assigned an instance yet + await services.unregister_replica(session, job_model) + _set_job_termination_status(session, job_model) + return + + all_volumes_detached: bool = True + jrd = get_job_runtime_data(job_model) + jpd = get_job_provisioning_data(job_model) + if jpd is not None: + logger.debug("%s: stopping container", fmt(job_model)) + ssh_private_keys = get_instance_ssh_private_keys(instance_model) + if not await _stop_container(job_model, jpd, ssh_private_keys): + # The dangling container can be removed later during instance processing + logger.warning( + ( + "%s: could not stop container, possibly due to a communication error." + " See debug logs for details." + " Ignoring, can attempt to remove the container later" + ), + fmt(job_model), + ) + if jrd is not None and jrd.volume_names is not None: + volume_names = jrd.volume_names + else: + # Legacy jobs before job_runtime_data/blocks were introduced + volume_names = [va.volume.name for va in instance_model.volume_attachments] + volume_models = await list_project_volume_models( + session=session, project=instance_model.project, names=volume_names + ) + if len(volume_models) > 0: + logger.info("Detaching volumes: %s", [v.name for v in volume_models]) + all_volumes_detached = await _detach_volumes_from_job_instance( + session=session, + project=instance_model.project, + job_model=job_model, + jpd=jpd, + instance_model=instance_model, + volume_models=volume_models, + ) + + if jrd is not None and jrd.offer is not None: + blocks = jrd.offer.blocks + else: + # Old job submitted before jrd or blocks were introduced + blocks = 1 + instance_model.busy_blocks -= blocks + + if instance_model.status != InstanceStatus.BUSY or jpd is None or not jpd.dockerized: + # Terminate instances that: + # - have 
not finished provisioning yet + # - belong to container-based backends, and hence cannot be reused + if instance_model.status not in InstanceStatus.finished_statuses(): + instance_model.termination_reason = InstanceTerminationReason.JOB_FINISHED + switch_instance_status(session, instance_model, InstanceStatus.TERMINATING) + elif not [j for j in instance_model.jobs if j.id != job_model.id]: + # no other jobs besides this one + switch_instance_status(session, instance_model, InstanceStatus.IDLE) + + # The instance should be released even if detach fails + # so that stuck volumes don't prevent the instance from terminating. + job_model.instance_id = None + instance_model.last_job_processed_at = common.get_current_datetime() + + events.emit( + session, + ( + "Job unassigned from instance." + f" Instance blocks: {format_instance_blocks_for_event(instance_model)}" + ), + actor=events.SystemActor(), + targets=[ + events.Target.from_model(job_model), + events.Target.from_model(instance_model), + ], + ) + + volume_names = ( + jrd.volume_names + if jrd and jrd.volume_names + else [va.volume.name for va in instance_model.volume_attachments] + ) + if volume_names: + volumes = await list_project_volume_models( + session=session, project=instance_model.project, names=volume_names + ) + for volume in volumes: + volume.last_job_processed_at = common.get_current_datetime() + + await services.unregister_replica(session, job_model) + if all_volumes_detached: + # Do not terminate while some volumes are not detached. + _set_job_termination_status(session, job_model) + + +async def _process_volumes_detaching( + session: AsyncSession, + job_model: JobModel, + instance_model: InstanceModel, +): + """ + Called after job's volumes have been soft detached to check if they are detached. + Terminates the job when all the volumes are detached. + If the volumes fail to detach, force detaches them. 
+ """ + jpd = get_or_error(get_job_provisioning_data(job_model)) + jrd = get_job_runtime_data(job_model) + if jrd is not None and jrd.volume_names is not None: + volume_names = jrd.volume_names + else: + # Legacy jobs before job_runtime_data/blocks were introduced + volume_names = [va.volume.name for va in instance_model.volume_attachments] + volume_models = await list_project_volume_models( + session=session, project=instance_model.project, names=volume_names + ) + logger.info("Detaching volumes: %s", [v.name for v in volume_models]) + all_volumes_detached = await _detach_volumes_from_job_instance( + session=session, + project=instance_model.project, + job_model=job_model, + jpd=jpd, + instance_model=instance_model, + volume_models=volume_models, + ) + if all_volumes_detached: + # Do not terminate the job while some volumes are not detached. + # If force detach never succeeds, the job will be stuck terminating. + # The job releases the instance when soft detaching, so the instance won't be stuck. 
+ _set_job_termination_status(session, job_model) + + +def _set_job_termination_status(session: AsyncSession, job_model: JobModel): + if job_model.termination_reason is not None: + status = job_model.termination_reason.to_status() + else: + status = JobStatus.FAILED + switch_job_status(session, job_model, status) + + +async def _stop_container( + job_model: JobModel, + job_provisioning_data: JobProvisioningData, + ssh_private_keys: tuple[str, Optional[str]], +) -> bool: + if job_provisioning_data.dockerized: + # send a request to the shim to terminate the docker container + # SSHError and RequestException are caught in the `runner_ssh_tunner` decorator + return await common.run_async( + _shim_submit_stop, + ssh_private_keys, + job_provisioning_data, + None, + job_model, + ) + return True + + +@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT]) +def _shim_submit_stop(ports: dict[int, int], job_model: JobModel) -> bool: + shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT]) + + resp = shim_client.healthcheck() + if resp is None: + logger.debug("%s: can't stop container, shim is not available yet", fmt(job_model)) + return False # shim is not available yet + + # we force-kill container because the runner had time to gracefully stop the job + if shim_client.is_api_v2_supported(): + if job_model.termination_reason is None: + reason = None + else: + reason = job_model.termination_reason.value + shim_client.terminate_task( + task_id=job_model.id, + reason=reason, + message=job_model.termination_reason_message, + timeout=0, + ) + # maybe somehow postpone removing old tasks to allow inspecting failed jobs without + # the following setting? 
+ if not settings.SERVER_KEEP_SHIM_TASKS: + shim_client.remove_task(task_id=job_model.id) + else: + shim_client.stop(force=True) + return True + + +async def _detach_volumes_from_job_instance( + session: AsyncSession, + project: ProjectModel, + job_model: JobModel, + jpd: JobProvisioningData, + instance_model: InstanceModel, + volume_models: list[VolumeModel], +) -> bool: + job_spec = JobSpec.__response__.parse_raw(job_model.job_spec_data) + backend = await backends_services.get_project_backend_by_type( + project=project, + backend_type=jpd.backend, + ) + if backend is None: + logger.error( + "Failed to detach volumes from %s. Backend not available.", instance_model.name + ) + return False + + all_detached = True + detached_volumes = [] + run_termination_reason = await _get_run_termination_reason(session, job_model) + for volume_model in volume_models: + detached = await _detach_volume_from_job_instance( + backend=backend, + job_model=job_model, + jpd=jpd, + job_spec=job_spec, + instance_model=instance_model, + volume_model=volume_model, + run_termination_reason=run_termination_reason, + ) + if detached: + detached_volumes.append(volume_model) + else: + all_detached = False + + if job_model.volumes_detached_at is None: + job_model.volumes_detached_at = common.get_current_datetime() + detached_volumes_ids = {v.id for v in detached_volumes} + instance_model.volume_attachments = [ + va for va in instance_model.volume_attachments if va.volume_id not in detached_volumes_ids + ] + return all_detached + + +async def _detach_volume_from_job_instance( + backend: Backend, + job_model: JobModel, + jpd: JobProvisioningData, + job_spec: JobSpec, + instance_model: InstanceModel, + volume_model: VolumeModel, + run_termination_reason: Optional[RunTerminationReason], +) -> bool: + detached = True + volume = volume_model_to_volume(volume_model) + if volume.provisioning_data is None or not volume.provisioning_data.detachable: + # Backends without `detach_volume` detach volumes 
automatically + return detached + compute = backend.compute() + assert isinstance(compute, ComputeWithVolumeSupport) + try: + if job_model.volumes_detached_at is None: + # We haven't tried detaching volumes yet, try soft detach first + await common.run_async( + compute.detach_volume, + volume=volume, + provisioning_data=jpd, + force=False, + ) + # For some backends, the volume may be detached immediately + detached = await common.run_async( + compute.is_volume_detached, + volume=volume, + provisioning_data=jpd, + ) + else: + detached = await common.run_async( + compute.is_volume_detached, + volume=volume, + provisioning_data=jpd, + ) + if not detached and _should_force_detach_volume( + job_model, + run_termination_reason=run_termination_reason, + stop_duration=job_spec.stop_duration, + ): + logger.info( + "Force detaching volume %s from %s", + volume_model.name, + instance_model.name, + ) + await common.run_async( + compute.detach_volume, + volume=volume, + provisioning_data=jpd, + force=True, + ) + # Let the next iteration check if force detach worked + except BackendError as e: + logger.error( + "Failed to detach volume %s from %s: %s", + volume_model.name, + instance_model.name, + repr(e), + ) + except Exception: + logger.exception( + "Got exception when detaching volume %s from instance %s", + volume_model.name, + instance_model.name, + ) + return detached + + +async def _get_run_termination_reason( + session: AsyncSession, job_model: JobModel +) -> Optional[RunTerminationReason]: + res = await session.execute( + select(RunModel.termination_reason).where(RunModel.id == job_model.run_id) + ) + return res.scalar_one_or_none() + + +_MIN_FORCE_DETACH_WAIT_PERIOD = timedelta(seconds=60) + + +def _should_force_detach_volume( + job_model: JobModel, + run_termination_reason: Optional[RunTerminationReason], + stop_duration: Optional[int], +) -> bool: + return ( + job_model.volumes_detached_at is not None + and common.get_current_datetime() + > 
job_model.volumes_detached_at + _MIN_FORCE_DETACH_WAIT_PERIOD + and ( + job_model.termination_reason == JobTerminationReason.ABORTED_BY_USER + or run_termination_reason == RunTerminationReason.ABORTED_BY_USER + or stop_duration is not None + and common.get_current_datetime() + > job_model.volumes_detached_at + timedelta(seconds=stop_duration) + ) + ) diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py index 4c094c676..66e7ebdb2 100644 --- a/src/dstack/_internal/server/services/jobs/__init__.py +++ b/src/dstack/_internal/server/services/jobs/__init__.py @@ -9,19 +9,14 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload, load_only -import dstack._internal.server.services.backends as backends_services -from dstack._internal.core.backends.base.backend import Backend -from dstack._internal.core.backends.base.compute import ComputeWithVolumeSupport -from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT, DSTACK_SHIM_HTTP_PORT +from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT from dstack._internal.core.errors import ( - BackendError, ResourceNotExistsError, ServerClientError, SSHError, ) from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.configurations import RunConfigurationType -from dstack._internal.core.models.instances import InstanceStatus, InstanceTerminationReason from dstack._internal.core.models.runs import ( Job, JobProvisioningData, @@ -31,10 +26,8 @@ JobSubmission, JobTerminationReason, RunSpec, - RunTerminationReason, ) from dstack._internal.core.models.volumes import Volume, VolumeMountPoint, VolumeStatus -from dstack._internal.server import settings from dstack._internal.server.models import ( InstanceModel, JobModel, @@ -42,12 +35,10 @@ RunModel, VolumeModel, ) -from dstack._internal.server.services import events, services +from dstack._internal.server.services import events 
from dstack._internal.server.services import volumes as volumes_services from dstack._internal.server.services.instances import ( - format_instance_blocks_for_event, get_instance_ssh_private_keys, - switch_instance_status, ) from dstack._internal.server.services.jobs.configurators.base import ( JobConfigurator, @@ -60,12 +51,8 @@ from dstack._internal.server.services.probes import probe_model_to_probe from dstack._internal.server.services.runner import client from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel -from dstack._internal.server.services.volumes import ( - list_project_volume_models, - volume_model_to_volume, -) from dstack._internal.utils import common -from dstack._internal.utils.common import get_or_error, run_async +from dstack._internal.utils.common import run_async from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) @@ -302,208 +289,6 @@ def _stop_runner( logger.exception("%s: failed to stop runner gracefully", fmt(job_model)) -async def process_terminating_job( - session: AsyncSession, - job_model: JobModel, - instance_model: Optional[InstanceModel], -): - """ - Stops the job: tells shim to stop the container, detaches the job from the instance, - and detaches volumes from the instance. - Graceful stop should already be done by `process_terminating_run`. - Caller must acquire the locks on the job and the job's instance. 
- """ - if job_model.remove_at is not None and job_model.remove_at > common.get_current_datetime(): - # it's too early to terminate the instance - return - - if instance_model is None: - # Possible if the job hasn't been assigned an instance yet - await services.unregister_replica(session, job_model) - _set_job_termination_status(session, job_model) - return - - all_volumes_detached: bool = True - jrd = get_job_runtime_data(job_model) - jpd = get_job_provisioning_data(job_model) - if jpd is not None: - logger.debug("%s: stopping container", fmt(job_model)) - ssh_private_keys = get_instance_ssh_private_keys(instance_model) - if not await stop_container(job_model, jpd, ssh_private_keys): - # The dangling container can be removed later during instance processing - logger.warning( - ( - "%s: could not stop container, possibly due to a communication error." - " See debug logs for details." - " Ignoring, can attempt to remove the container later" - ), - fmt(job_model), - ) - if jrd is not None and jrd.volume_names is not None: - volume_names = jrd.volume_names - else: - # Legacy jobs before job_runtime_data/blocks were introduced - volume_names = [va.volume.name for va in instance_model.volume_attachments] - volume_models = await list_project_volume_models( - session=session, project=instance_model.project, names=volume_names - ) - if len(volume_models) > 0: - logger.info("Detaching volumes: %s", [v.name for v in volume_models]) - all_volumes_detached = await _detach_volumes_from_job_instance( - session=session, - project=instance_model.project, - job_model=job_model, - jpd=jpd, - instance_model=instance_model, - volume_models=volume_models, - ) - - if jrd is not None and jrd.offer is not None: - blocks = jrd.offer.blocks - else: - # Old job submitted before jrd or blocks were introduced - blocks = 1 - instance_model.busy_blocks -= blocks - - if instance_model.status != InstanceStatus.BUSY or jpd is None or not jpd.dockerized: - # Terminate instances that: - # - have not 
finished provisioning yet - # - belong to container-based backends, and hence cannot be reused - if instance_model.status not in InstanceStatus.finished_statuses(): - instance_model.termination_reason = InstanceTerminationReason.JOB_FINISHED - switch_instance_status(session, instance_model, InstanceStatus.TERMINATING) - elif not [j for j in instance_model.jobs if j.id != job_model.id]: - # no other jobs besides this one - switch_instance_status(session, instance_model, InstanceStatus.IDLE) - - # The instance should be released even if detach fails - # so that stuck volumes don't prevent the instance from terminating. - job_model.instance_id = None - instance_model.last_job_processed_at = common.get_current_datetime() - - events.emit( - session, - ( - "Job unassigned from instance." - f" Instance blocks: {format_instance_blocks_for_event(instance_model)}" - ), - actor=events.SystemActor(), - targets=[ - events.Target.from_model(job_model), - events.Target.from_model(instance_model), - ], - ) - - volume_names = ( - jrd.volume_names - if jrd and jrd.volume_names - else [va.volume.name for va in instance_model.volume_attachments] - ) - if volume_names: - volumes = await list_project_volume_models( - session=session, project=instance_model.project, names=volume_names - ) - for volume in volumes: - volume.last_job_processed_at = common.get_current_datetime() - - await services.unregister_replica(session, job_model) - if all_volumes_detached: - # Do not terminate while some volumes are not detached. - _set_job_termination_status(session, job_model) - - -async def process_volumes_detaching( - session: AsyncSession, - job_model: JobModel, - instance_model: InstanceModel, -): - """ - Called after job's volumes have been soft detached to check if they are detached. - Terminates the job when all the volumes are detached. - If the volumes fail to detach, force detaches them. 
- """ - jpd = get_or_error(get_job_provisioning_data(job_model)) - jrd = get_job_runtime_data(job_model) - if jrd is not None and jrd.volume_names is not None: - volume_names = jrd.volume_names - else: - # Legacy jobs before job_runtime_data/blocks were introduced - volume_names = [va.volume.name for va in instance_model.volume_attachments] - volume_models = await list_project_volume_models( - session=session, project=instance_model.project, names=volume_names - ) - logger.info("Detaching volumes: %s", [v.name for v in volume_models]) - all_volumes_detached = await _detach_volumes_from_job_instance( - session=session, - project=instance_model.project, - job_model=job_model, - jpd=jpd, - instance_model=instance_model, - volume_models=volume_models, - ) - if all_volumes_detached: - # Do not terminate the job while some volumes are not detached. - # If force detach never succeeds, the job will be stuck terminating. - # The job releases the instance when soft detaching, so the instance won't be stuck. 
- _set_job_termination_status(session, job_model) - - -def _set_job_termination_status(session: AsyncSession, job_model: JobModel): - if job_model.termination_reason is not None: - status = job_model.termination_reason.to_status() - else: - status = JobStatus.FAILED - switch_job_status(session, job_model, status) - - -async def stop_container( - job_model: JobModel, - job_provisioning_data: JobProvisioningData, - ssh_private_keys: tuple[str, Optional[str]], -) -> bool: - if job_provisioning_data.dockerized: - # send a request to the shim to terminate the docker container - # SSHError and RequestException are caught in the `runner_ssh_tunner` decorator - return await run_async( - _shim_submit_stop, - ssh_private_keys, - job_provisioning_data, - None, - job_model, - ) - return True - - -@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT]) -def _shim_submit_stop(ports: Dict[int, int], job_model: JobModel) -> bool: - shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT]) - - resp = shim_client.healthcheck() - if resp is None: - logger.debug("%s: can't stop container, shim is not available yet", fmt(job_model)) - return False # shim is not available yet - - # we force-kill container because the runner had time to gracefully stop the job - if shim_client.is_api_v2_supported(): - if job_model.termination_reason is None: - reason = None - else: - reason = job_model.termination_reason.value - shim_client.terminate_task( - task_id=job_model.id, - reason=reason, - message=job_model.termination_reason_message, - timeout=0, - ) - # maybe somehow postpone removing old tasks to allow inspecting failed jobs without - # the following setting? 
- if not settings.SERVER_KEEP_SHIM_TASKS: - shim_client.remove_task(task_id=job_model.id) - else: - shim_client.stop(force=True) - return True - - def group_jobs_by_replica_latest(jobs: List[JobModel]) -> Iterable[Tuple[int, List[JobModel]]]: """ Args: @@ -525,153 +310,6 @@ def group_jobs_by_replica_latest(jobs: List[JobModel]) -> Iterable[Tuple[int, Li yield replica_num, replica_jobs -async def _detach_volumes_from_job_instance( - session: AsyncSession, - project: ProjectModel, - job_model: JobModel, - jpd: JobProvisioningData, - instance_model: InstanceModel, - volume_models: list[VolumeModel], -) -> bool: - job_spec = JobSpec.__response__.parse_raw(job_model.job_spec_data) - backend = await backends_services.get_project_backend_by_type( - project=project, - backend_type=jpd.backend, - ) - if backend is None: - logger.error( - "Failed to detach volumes from %s. Backend not available.", instance_model.name - ) - return False - - all_detached = True - detached_volumes = [] - run_termination_reason = await _get_run_termination_reason(session, job_model) - for volume_model in volume_models: - detached = await _detach_volume_from_job_instance( - backend=backend, - job_model=job_model, - jpd=jpd, - job_spec=job_spec, - instance_model=instance_model, - volume_model=volume_model, - run_termination_reason=run_termination_reason, - ) - if detached: - detached_volumes.append(volume_model) - else: - all_detached = False - - if job_model.volumes_detached_at is None: - job_model.volumes_detached_at = common.get_current_datetime() - detached_volumes_ids = {v.id for v in detached_volumes} - instance_model.volume_attachments = [ - va for va in instance_model.volume_attachments if va.volume_id not in detached_volumes_ids - ] - return all_detached - - -async def _detach_volume_from_job_instance( - backend: Backend, - job_model: JobModel, - jpd: JobProvisioningData, - job_spec: JobSpec, - instance_model: InstanceModel, - volume_model: VolumeModel, - run_termination_reason: 
Optional[RunTerminationReason], -) -> bool: - detached = True - volume = volume_model_to_volume(volume_model) - if volume.provisioning_data is None or not volume.provisioning_data.detachable: - # Backends without `detach_volume` detach volumes automatically - return detached - compute = backend.compute() - assert isinstance(compute, ComputeWithVolumeSupport) - try: - if job_model.volumes_detached_at is None: - # We haven't tried detaching volumes yet, try soft detach first - await run_async( - compute.detach_volume, - volume=volume, - provisioning_data=jpd, - force=False, - ) - # For some backends, the volume may be detached immediately - detached = await run_async( - compute.is_volume_detached, - volume=volume, - provisioning_data=jpd, - ) - else: - detached = await run_async( - compute.is_volume_detached, - volume=volume, - provisioning_data=jpd, - ) - if not detached and _should_force_detach_volume( - job_model, - run_termination_reason=run_termination_reason, - stop_duration=job_spec.stop_duration, - ): - logger.info( - "Force detaching volume %s from %s", - volume_model.name, - instance_model.name, - ) - await run_async( - compute.detach_volume, - volume=volume, - provisioning_data=jpd, - force=True, - ) - # Let the next iteration check if force detach worked - except BackendError as e: - logger.error( - "Failed to detach volume %s from %s: %s", - volume_model.name, - instance_model.name, - repr(e), - ) - except Exception: - logger.exception( - "Got exception when detaching volume %s from instance %s", - volume_model.name, - instance_model.name, - ) - return detached - - -MIN_FORCE_DETACH_WAIT_PERIOD = timedelta(seconds=60) - - -async def _get_run_termination_reason( - session: AsyncSession, job_model: JobModel -) -> Optional[RunTerminationReason]: - res = await session.execute( - select(RunModel.termination_reason).where(RunModel.id == job_model.run_id) - ) - return res.scalar_one_or_none() - - -def _should_force_detach_volume( - job_model: JobModel, - 
run_termination_reason: Optional[RunTerminationReason], - stop_duration: Optional[int], -) -> bool: - return ( - job_model.volumes_detached_at is not None - and common.get_current_datetime() - > job_model.volumes_detached_at + MIN_FORCE_DETACH_WAIT_PERIOD - and ( - job_model.termination_reason == JobTerminationReason.ABORTED_BY_USER - or run_termination_reason == RunTerminationReason.ABORTED_BY_USER - or stop_duration is not None - and common.get_current_datetime() - > job_model.volumes_detached_at + timedelta(seconds=stop_duration) - ) - ) - - async def get_instances_ids_with_detaching_volumes(session: AsyncSession) -> List[UUID]: res = await session.execute( select(JobModel) From 84fc86d8f3dbbfb52aaa6fda9f5080ab24a00b14 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 9 Mar 2026 14:20:25 +0500 Subject: [PATCH 02/15] Simplify _process_terminating_job() code --- .../scheduled_tasks/terminating_jobs.py | 75 ++++++++++--------- src/dstack/_internal/server/models.py | 3 + 2 files changed, 42 insertions(+), 36 deletions(-) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py index e7296956d..6fc391f01 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py @@ -13,6 +13,7 @@ from dstack._internal.core.models.instances import InstanceStatus, InstanceTerminationReason from dstack._internal.core.models.runs import ( JobProvisioningData, + JobRuntimeData, JobSpec, JobStatus, JobTerminationReason, @@ -151,10 +152,6 @@ async def _process_terminating_job( Graceful stop should already be done by `process_terminating_run`. Caller must acquire the locks on the job and the job's instance. 
""" - if job_model.remove_at is not None and job_model.remove_at > common.get_current_datetime(): - # it's too early to terminate the instance - return - if instance_model is None: # Possible if the job hasn't been assigned an instance yet await services.unregister_replica(session, job_model) @@ -164,6 +161,12 @@ async def _process_terminating_job( all_volumes_detached: bool = True jrd = get_job_runtime_data(job_model) jpd = get_job_provisioning_data(job_model) + volume_models = await _get_job_volume_models( + session=session, + job_model=job_model, + instance_model=instance_model, + jrd=jrd, + ) if jpd is not None: logger.debug("%s: stopping container", fmt(job_model)) ssh_private_keys = get_instance_ssh_private_keys(instance_model) @@ -177,14 +180,6 @@ async def _process_terminating_job( ), fmt(job_model), ) - if jrd is not None and jrd.volume_names is not None: - volume_names = jrd.volume_names - else: - # Legacy jobs before job_runtime_data/blocks were introduced - volume_names = [va.volume.name for va in instance_model.volume_attachments] - volume_models = await list_project_volume_models( - session=session, project=instance_model.project, names=volume_names - ) if len(volume_models) > 0: logger.info("Detaching volumes: %s", [v.name for v in volume_models]) all_volumes_detached = await _detach_volumes_from_job_instance( @@ -196,13 +191,7 @@ async def _process_terminating_job( volume_models=volume_models, ) - if jrd is not None and jrd.offer is not None: - blocks = jrd.offer.blocks - else: - # Old job submitted before jrd or blocks were introduced - blocks = 1 - instance_model.busy_blocks -= blocks - + instance_model.busy_blocks -= _get_job_occupied_blocks(jrd) if instance_model.status != InstanceStatus.BUSY or jpd is None or not jpd.dockerized: # Terminate instances that: # - have not finished provisioning yet @@ -232,22 +221,38 @@ async def _process_terminating_job( ], ) + for volume_model in volume_models: + volume_model.last_job_processed_at = 
common.get_current_datetime() + + await services.unregister_replica(session, job_model) + if all_volumes_detached: + # Do not terminate while some volumes are not detached. + _set_job_termination_status(session, job_model) + + +async def _get_job_volume_models( + session: AsyncSession, + job_model: JobModel, + instance_model: InstanceModel, + jrd: Optional[JobRuntimeData], +) -> list[VolumeModel]: volume_names = ( jrd.volume_names if jrd and jrd.volume_names else [va.volume.name for va in instance_model.volume_attachments] ) - if volume_names: - volumes = await list_project_volume_models( - session=session, project=instance_model.project, names=volume_names - ) - for volume in volumes: - volume.last_job_processed_at = common.get_current_datetime() + if len(volume_names) == 0: + return [] + return await list_project_volume_models( + session=session, project=instance_model.project, names=volume_names + ) - await services.unregister_replica(session, job_model) - if all_volumes_detached: - # Do not terminate while some volumes are not detached. 
- _set_job_termination_status(session, job_model) + +def _get_job_occupied_blocks(jrd: Optional[JobRuntimeData]) -> int: + if jrd is not None and jrd.offer is not None: + return jrd.offer.blocks + # Old job submitted before jrd or blocks were introduced + return 1 async def _process_volumes_detaching( @@ -262,13 +267,11 @@ async def _process_volumes_detaching( """ jpd = get_or_error(get_job_provisioning_data(job_model)) jrd = get_job_runtime_data(job_model) - if jrd is not None and jrd.volume_names is not None: - volume_names = jrd.volume_names - else: - # Legacy jobs before job_runtime_data/blocks were introduced - volume_names = [va.volume.name for va in instance_model.volume_attachments] - volume_models = await list_project_volume_models( - session=session, project=instance_model.project, names=volume_names + volume_models = await _get_job_volume_models( + session=session, + job_model=job_model, + instance_model=instance_model, + jrd=jrd, ) logger.info("Detaching volumes: %s", [v.name for v in volume_models]) all_volumes_detached = await _detach_volumes_from_job_instance( diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index da733054f..3550183bd 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -809,6 +809,9 @@ class VolumeModel(PipelineModelMixin, BaseModel): NaiveDateTime, default=get_current_datetime ) last_job_processed_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime) + """`last_job_processed_at` records the last time the volume was used by a job. + Updated when a job terminates and used to delete volumes on `auto_cleanup_duration`. 
+ """ deleted: Mapped[bool] = mapped_column(Boolean, default=False) deleted_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime) to_be_deleted: Mapped[bool] = mapped_column(Boolean, server_default=false()) From 7724189286331d3a74ccea3253d6a0aba4c1da0b Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 9 Mar 2026 14:39:52 +0500 Subject: [PATCH 03/15] Implement JobTerminatingPipeline scaffolding --- .../pipeline_tasks/terminating_jobs.py | 173 +++++++++++++ ...6b29d78c7_add_jobmodel_pipeline_columns.py | 47 ++++ src/dstack/_internal/server/models.py | 2 +- .../pipeline_tasks/test_terminating_jobs.py | 227 ++++++++++++++++++ 4 files changed, 448 insertions(+), 1 deletion(-) create mode 100644 src/dstack/_internal/server/background/pipeline_tasks/terminating_jobs.py create mode 100644 src/dstack/_internal/server/migrations/versions/2026/03_09_0928_6026b29d78c7_add_jobmodel_pipeline_columns.py create mode 100644 src/tests/_internal/server/background/pipeline_tasks/test_terminating_jobs.py diff --git a/src/dstack/_internal/server/background/pipeline_tasks/terminating_jobs.py b/src/dstack/_internal/server/background/pipeline_tasks/terminating_jobs.py new file mode 100644 index 000000000..01fcf82a7 --- /dev/null +++ b/src/dstack/_internal/server/background/pipeline_tasks/terminating_jobs.py @@ -0,0 +1,173 @@ +import asyncio +import uuid +from dataclasses import dataclass +from datetime import datetime, timedelta +from typing import Optional, Sequence + +from sqlalchemy import or_, select +from sqlalchemy.orm import load_only + +from dstack._internal.core.models.runs import JobStatus +from dstack._internal.server.background.pipeline_tasks.base import ( + Fetcher, + Heartbeater, + Pipeline, + PipelineItem, + Worker, +) +from dstack._internal.server.db import get_db, get_session_ctx +from dstack._internal.server.models import JobModel +from dstack._internal.server.services.locking import get_locker +from dstack._internal.server.utils import sentry_utils +from 
dstack._internal.utils.common import get_current_datetime + + +@dataclass +class JobTerminatingPipelineItem(PipelineItem): + volumes_detached_at: Optional[datetime] + + +class JobTerminatingPipeline(Pipeline[JobTerminatingPipelineItem]): + def __init__( + self, + workers_num: int = 10, + queue_lower_limit_factor: float = 0.5, + queue_upper_limit_factor: float = 2.0, + min_processing_interval: timedelta = timedelta(seconds=15), + lock_timeout: timedelta = timedelta(seconds=30), + heartbeat_trigger: timedelta = timedelta(seconds=15), + ) -> None: + super().__init__( + workers_num=workers_num, + queue_lower_limit_factor=queue_lower_limit_factor, + queue_upper_limit_factor=queue_upper_limit_factor, + min_processing_interval=min_processing_interval, + lock_timeout=lock_timeout, + heartbeat_trigger=heartbeat_trigger, + ) + self.__heartbeater = Heartbeater[JobTerminatingPipelineItem]( + model_type=JobModel, + lock_timeout=self._lock_timeout, + heartbeat_trigger=self._heartbeat_trigger, + ) + self.__fetcher = JobTerminatingFetcher( + queue=self._queue, + queue_desired_minsize=self._queue_desired_minsize, + min_processing_interval=self._min_processing_interval, + lock_timeout=self._lock_timeout, + heartbeater=self._heartbeater, + ) + self.__workers = [ + JobTerminatingWorker(queue=self._queue, heartbeater=self._heartbeater) + for _ in range(self._workers_num) + ] + + @property + def hint_fetch_model_name(self) -> str: + return JobModel.__name__ + + @property + def _heartbeater(self) -> Heartbeater[JobTerminatingPipelineItem]: + return self.__heartbeater + + @property + def _fetcher(self) -> Fetcher[JobTerminatingPipelineItem]: + return self.__fetcher + + @property + def _workers(self) -> Sequence["JobTerminatingWorker"]: + return self.__workers + + +class JobTerminatingFetcher(Fetcher[JobTerminatingPipelineItem]): + def __init__( + self, + queue: asyncio.Queue[JobTerminatingPipelineItem], + queue_desired_minsize: int, + min_processing_interval: timedelta, + lock_timeout: 
timedelta, + heartbeater: Heartbeater[JobTerminatingPipelineItem], + queue_check_delay: float = 1.0, + ) -> None: + super().__init__( + queue=queue, + queue_desired_minsize=queue_desired_minsize, + min_processing_interval=min_processing_interval, + lock_timeout=lock_timeout, + heartbeater=heartbeater, + queue_check_delay=queue_check_delay, + ) + + @sentry_utils.instrument_named_task("pipeline_tasks.JobTerminatingFetcher.fetch") + async def fetch(self, limit: int) -> list[JobTerminatingPipelineItem]: + job_lock, _ = get_locker(get_db().dialect_name).get_lockset(JobModel.__tablename__) + async with job_lock: + async with get_session_ctx() as session: + now = get_current_datetime() + res = await session.execute( + select(JobModel) + .where( + JobModel.status == JobStatus.TERMINATING, + or_( + JobModel.remove_at.is_(None), + JobModel.remove_at < now, + ), + JobModel.last_processed_at <= now - self._min_processing_interval, + or_( + JobModel.lock_expires_at.is_(None), + JobModel.lock_expires_at < now, + ), + or_( + JobModel.lock_owner.is_(None), + JobModel.lock_owner == JobTerminatingPipeline.__name__, + ), + ) + .order_by(JobModel.last_processed_at.asc()) + .limit(limit) + .with_for_update(skip_locked=True, key_share=True) + .options( + load_only( + JobModel.id, + JobModel.lock_token, + JobModel.lock_expires_at, + JobModel.volumes_detached_at, + ) + ) + ) + job_models = list(res.scalars().all()) + lock_expires_at = get_current_datetime() + self._lock_timeout + lock_token = uuid.uuid4() + items = [] + for job_model in job_models: + prev_lock_expired = job_model.lock_expires_at is not None + job_model.lock_expires_at = lock_expires_at + job_model.lock_token = lock_token + job_model.lock_owner = JobTerminatingPipeline.__name__ + items.append( + JobTerminatingPipelineItem( + __tablename__=JobModel.__tablename__, + id=job_model.id, + lock_expires_at=lock_expires_at, + lock_token=lock_token, + prev_lock_expired=prev_lock_expired, + 
volumes_detached_at=job_model.volumes_detached_at, + ) + ) + await session.commit() + return items + + +class JobTerminatingWorker(Worker[JobTerminatingPipelineItem]): + def __init__( + self, + queue: asyncio.Queue[JobTerminatingPipelineItem], + heartbeater: Heartbeater[JobTerminatingPipelineItem], + ) -> None: + super().__init__( + queue=queue, + heartbeater=heartbeater, + ) + + @sentry_utils.instrument_named_task("pipeline_tasks.JobTerminatingWorker.process") + async def process(self, item: JobTerminatingPipelineItem): + raise NotImplementedError("JobTerminatingWorker.process is not implemented yet") diff --git a/src/dstack/_internal/server/migrations/versions/2026/03_09_0928_6026b29d78c7_add_jobmodel_pipeline_columns.py b/src/dstack/_internal/server/migrations/versions/2026/03_09_0928_6026b29d78c7_add_jobmodel_pipeline_columns.py new file mode 100644 index 000000000..b2898e833 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/2026/03_09_0928_6026b29d78c7_add_jobmodel_pipeline_columns.py @@ -0,0 +1,47 @@ +"""Add JobModel pipeline columns + +Revision ID: 6026b29d78c7 +Revises: c7b0a8e57294 +Create Date: 2026-03-09 09:28:17.993416+00:00 + +""" + +import sqlalchemy as sa +import sqlalchemy_utils +from alembic import op + +import dstack._internal.server.models + +# revision identifiers, used by Alembic. +revision = "6026b29d78c7" +down_revision = "c7b0a8e57294" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + with op.batch_alter_table("jobs", schema=None) as batch_op: + batch_op.add_column( + sa.Column( + "lock_expires_at", dstack._internal.server.models.NaiveDateTime(), nullable=True + ) + ) + batch_op.add_column( + sa.Column( + "lock_token", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=True + ) + ) + batch_op.add_column(sa.Column("lock_owner", sa.String(length=100), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("jobs", schema=None) as batch_op: + batch_op.drop_column("lock_owner") + batch_op.drop_column("lock_token") + batch_op.drop_column("lock_expires_at") + + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index 3550183bd..3861907d1 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -442,7 +442,7 @@ class RunModel(BaseModel): __table_args__ = (Index("ix_submitted_at_id", submitted_at.desc(), id),) -class JobModel(BaseModel): +class JobModel(PipelineModelMixin, BaseModel): __tablename__ = "jobs" id: Mapped[uuid.UUID] = mapped_column( diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_terminating_jobs.py b/src/tests/_internal/server/background/pipeline_tasks/test_terminating_jobs.py new file mode 100644 index 000000000..95e4457c6 --- /dev/null +++ b/src/tests/_internal/server/background/pipeline_tasks/test_terminating_jobs.py @@ -0,0 +1,227 @@ +import asyncio +import uuid +from datetime import datetime, timedelta, timezone +from unittest.mock import Mock + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.models.runs import JobStatus +from dstack._internal.server.background.pipeline_tasks.terminating_jobs import ( + JobTerminatingFetcher, + JobTerminatingPipeline, + JobTerminatingPipelineItem, + JobTerminatingWorker, +) +from 
dstack._internal.server.models import JobModel +from dstack._internal.server.testing.common import ( + create_job, + create_project, + create_repo, + create_run, + create_user, +) +from dstack._internal.utils.common import get_current_datetime + + +@pytest.fixture +def worker() -> JobTerminatingWorker: + return JobTerminatingWorker(queue=Mock(), heartbeater=Mock()) + + +@pytest.fixture +def fetcher() -> JobTerminatingFetcher: + return JobTerminatingFetcher( + queue=asyncio.Queue(), + queue_desired_minsize=1, + min_processing_interval=timedelta(seconds=15), + lock_timeout=timedelta(seconds=30), + heartbeater=Mock(), + ) + + +def _job_to_pipeline_item(job_model: JobModel) -> JobTerminatingPipelineItem: + assert job_model.lock_token is not None + assert job_model.lock_expires_at is not None + return JobTerminatingPipelineItem( + __tablename__=job_model.__tablename__, + id=job_model.id, + lock_token=job_model.lock_token, + lock_expires_at=job_model.lock_expires_at, + prev_lock_expired=False, + volumes_detached_at=job_model.volumes_detached_at, + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +@pytest.mark.usefixtures("image_config_mock") +class TestJobTerminatingFetcher: + async def test_fetch_selects_eligible_jobs_and_sets_lock_fields( + self, test_db, session: AsyncSession, fetcher: JobTerminatingFetcher + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + now = get_current_datetime() + stale = now - timedelta(minutes=1) + + terminating = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + submitted_at=stale - timedelta(minutes=2), + last_processed_at=stale - timedelta(seconds=2), + ) + past_remove_at = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + 
submitted_at=stale - timedelta(minutes=2), + last_processed_at=stale - timedelta(seconds=1), + ) + past_remove_at.remove_at = stale + past_remove_at.volumes_detached_at = stale - timedelta(seconds=30) + + future_remove_at = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + submitted_at=stale - timedelta(minutes=2), + last_processed_at=stale, + ) + future_remove_at.remove_at = now + timedelta(minutes=1) + + non_terminating = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + submitted_at=stale - timedelta(minutes=2), + last_processed_at=stale, + ) + + recent = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + submitted_at=stale - timedelta(minutes=2), + last_processed_at=now, + ) + + locked = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + submitted_at=stale - timedelta(minutes=2), + last_processed_at=stale + timedelta(seconds=1), + ) + locked.lock_expires_at = now + timedelta(minutes=1) + locked.lock_token = uuid.uuid4() + locked.lock_owner = "OtherPipeline" + + expired_same_owner = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + submitted_at=stale - timedelta(minutes=2), + last_processed_at=stale + timedelta(seconds=2), + ) + expired_same_owner.lock_expires_at = stale + expired_same_owner.lock_token = uuid.uuid4() + expired_same_owner.lock_owner = JobTerminatingPipeline.__name__ + await session.commit() + + items = await fetcher.fetch(limit=10) + + assert [item.id for item in items] == [ + terminating.id, + past_remove_at.id, + expired_same_owner.id, + ] + assert {(item.id, item.volumes_detached_at) for item in items} == { + (terminating.id, None), + (past_remove_at.id, past_remove_at.volumes_detached_at), + (expired_same_owner.id, None), + } + + for job in [ + terminating, + past_remove_at, + future_remove_at, + non_terminating, + recent, + locked, + expired_same_owner, + ]: + await session.refresh(job) + + 
fetched_jobs = [terminating, past_remove_at, expired_same_owner] + assert all(job.lock_owner == JobTerminatingPipeline.__name__ for job in fetched_jobs) + assert all(job.lock_expires_at is not None for job in fetched_jobs) + assert all(job.lock_token is not None for job in fetched_jobs) + assert len({job.lock_token for job in fetched_jobs}) == 1 + + assert future_remove_at.lock_owner is None + assert non_terminating.lock_owner is None + assert recent.lock_owner is None + assert locked.lock_owner == "OtherPipeline" + + async def test_fetch_returns_oldest_jobs_first_up_to_limit( + self, test_db, session: AsyncSession, fetcher: JobTerminatingFetcher + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + now = get_current_datetime() + + oldest = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + submitted_at=now - timedelta(minutes=5), + last_processed_at=now - timedelta(minutes=3), + ) + middle = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + submitted_at=now - timedelta(minutes=4), + last_processed_at=now - timedelta(minutes=2), + ) + newest = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + submitted_at=now - timedelta(minutes=3), + last_processed_at=now - timedelta(minutes=1), + ) + + items = await fetcher.fetch(limit=2) + + assert [item.id for item in items] == [oldest.id, middle.id] + + await session.refresh(oldest) + await session.refresh(middle) + await session.refresh(newest) + + assert oldest.lock_owner == JobTerminatingPipeline.__name__ + assert middle.lock_owner == JobTerminatingPipeline.__name__ + assert newest.lock_owner is None + + +@pytest.mark.asyncio +class TestJobTerminatingWorker: + async def test_process_is_not_implemented(self, worker: JobTerminatingWorker): + 
item = JobTerminatingPipelineItem( + __tablename__=JobModel.__tablename__, + id=uuid.uuid4(), + lock_token=uuid.uuid4(), + lock_expires_at=datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc), + prev_lock_expired=False, + volumes_detached_at=None, + ) + + with pytest.raises(NotImplementedError, match="not implemented yet"): + await worker.process(item) From 7e02b1d7e5fe067d9a54c85c230712c48a2dfdb9 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 9 Mar 2026 16:53:53 +0500 Subject: [PATCH 04/15] Implement JobTerminatingWorker --- .../pipeline_tasks/terminating_jobs.py | 756 +++++++++++++++++- .../scheduled_tasks/terminating_jobs.py | 1 + .../server/services/jobs/__init__.py | 50 +- .../pipeline_tasks/test_terminating_jobs.py | 505 +++++++++++- 4 files changed, 1284 insertions(+), 28 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/terminating_jobs.py b/src/dstack/_internal/server/background/pipeline_tasks/terminating_jobs.py index 01fcf82a7..da2a6b503 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/terminating_jobs.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/terminating_jobs.py @@ -1,25 +1,78 @@ import asyncio import uuid -from dataclasses import dataclass +from dataclasses import dataclass, field from datetime import datetime, timedelta -from typing import Optional, Sequence +from typing import Optional, Sequence, TypedDict -from sqlalchemy import or_, select -from sqlalchemy.orm import load_only +import httpx +from sqlalchemy import delete, or_, select, update +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import joinedload, load_only -from dstack._internal.core.models.runs import JobStatus +from dstack._internal.core.backends.base.backend import Backend +from dstack._internal.core.backends.base.compute import ComputeWithVolumeSupport +from dstack._internal.core.consts import DSTACK_SHIM_HTTP_PORT +from dstack._internal.core.errors import BackendError, GatewayError, 
SSHError +from dstack._internal.core.models.instances import InstanceStatus, InstanceTerminationReason +from dstack._internal.core.models.runs import ( + JobProvisioningData, + JobRuntimeData, + JobSpec, + JobStatus, + JobTerminationReason, + RunTerminationReason, +) +from dstack._internal.server import settings from dstack._internal.server.background.pipeline_tasks.base import ( + NOW_PLACEHOLDER, Fetcher, Heartbeater, + ItemUpdateMap, Pipeline, PipelineItem, + UpdateMapDateTime, Worker, + log_lock_token_changed_after_processing, + log_lock_token_changed_on_reset, + log_lock_token_mismatch, + resolve_now_placeholders, + set_processed_update_map_fields, + set_unlock_update_map_fields, ) from dstack._internal.server.db import get_db, get_session_ctx -from dstack._internal.server.models import JobModel +from dstack._internal.server.models import ( + InstanceModel, + JobModel, + ProjectModel, + RunModel, + VolumeAttachmentModel, + VolumeModel, +) +from dstack._internal.server.services import backends as backends_services +from dstack._internal.server.services import events +from dstack._internal.server.services.gateways import get_or_add_gateway_connection +from dstack._internal.server.services.instances import ( + emit_instance_status_change_event, + get_instance_ssh_private_keys, +) +from dstack._internal.server.services.jobs import ( + emit_job_status_change_event, + get_job_provisioning_data, + get_job_runtime_data, +) from dstack._internal.server.services.locking import get_locker +from dstack._internal.server.services.logging import fmt +from dstack._internal.server.services.runner import client +from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel +from dstack._internal.server.services.volumes import ( + volume_model_to_volume, +) from dstack._internal.server.utils import sentry_utils -from dstack._internal.utils.common import get_current_datetime +from dstack._internal.utils import common +from dstack._internal.utils.common import 
get_current_datetime, get_or_error +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) @dataclass @@ -170,4 +223,691 @@ def __init__( @sentry_utils.instrument_named_task("pipeline_tasks.JobTerminatingWorker.process") async def process(self, item: JobTerminatingPipelineItem): - raise NotImplementedError("JobTerminatingWorker.process is not implemented yet") + async with get_session_ctx() as session: + job_model = await _refetch_locked_job(session=session, item=item) + if job_model is None: + log_lock_token_mismatch(logger, item) + return + + instance_model: Optional[InstanceModel] = None + if job_model.used_instance_id is not None: + instance_model = await _lock_related_instance( + session=session, + item=item, + instance_id=job_model.used_instance_id, + ) + if instance_model is None: + await _reset_job_lock_for_retry(session=session, item=item) + return + if job_model.volumes_detached_at is not None and instance_model is None: + logger.error( + "%s: expected used_instance_id while detaching volumes", fmt(job_model) + ) + await _reset_job_lock_for_retry(session=session, item=item) + return + + if job_model.volumes_detached_at is None: + result = await _process_terminating_job( + job_model=job_model, + instance_model=instance_model, + ) + else: + result = await _process_job_volumes_detaching( + job_model=job_model, + instance_model=get_or_error(instance_model), + ) + + set_processed_update_map_fields(result.job_update_map) + set_unlock_update_map_fields(result.job_update_map) + if instance_model is not None: + if result.instance_update_map is None: + result.instance_update_map = _InstanceUpdateMap() + set_processed_update_map_fields(result.instance_update_map) + set_unlock_update_map_fields(result.instance_update_map) + await _apply_process_result( + item=item, + job_model=job_model, + instance_model=instance_model, + result=result, + ) + + +class _JobUpdateMap(ItemUpdateMap, total=False): + status: JobStatus + termination_reason: 
Optional[JobTerminationReason] + termination_reason_message: Optional[str] + instance_id: Optional[uuid.UUID] + volumes_detached_at: UpdateMapDateTime + registered: bool + + +class _InstanceUpdateMap(ItemUpdateMap, total=False): + status: InstanceStatus + termination_reason: Optional[InstanceTerminationReason] + termination_reason_message: Optional[str] + busy_blocks: int + last_job_processed_at: UpdateMapDateTime + lock_expires_at: Optional[datetime] + lock_token: Optional[uuid.UUID] + lock_owner: Optional[str] + + +class _VolumeUpdateRow(TypedDict): + id: uuid.UUID + last_job_processed_at: UpdateMapDateTime + + +@dataclass +class _ProcessResult: + job_update_map: _JobUpdateMap = field(default_factory=_JobUpdateMap) + instance_update_map: Optional[_InstanceUpdateMap] = None + volume_update_rows: list[_VolumeUpdateRow] = field(default_factory=list) + detached_volume_ids: set[uuid.UUID] = field(default_factory=set) + unassign_event_message: Optional[str] = None + emit_unregister_replica_event: bool = False + unregister_gateway_target: Optional[events.Target] = None + + +@dataclass +class _VolumeDetachResult: + all_detached: bool + detached_volume_ids: set[uuid.UUID] = field(default_factory=set) + set_volumes_detached_at: bool = False + + +async def _refetch_locked_job( + session: AsyncSession, item: JobTerminatingPipelineItem +) -> Optional[JobModel]: + res = await session.execute( + select(JobModel) + .where( + JobModel.id == item.id, + JobModel.lock_token == item.lock_token, + ) + .options( + joinedload(JobModel.run).load_only( + RunModel.id, + RunModel.project_id, + RunModel.run_name, + RunModel.gateway_id, + RunModel.termination_reason, + ), + joinedload(JobModel.run) + .joinedload(RunModel.project) + .load_only(ProjectModel.id, ProjectModel.name), + ) + .execution_options(populate_existing=True) + ) + return res.unique().scalar_one_or_none() + + +async def _lock_related_instance( + session: AsyncSession, + item: JobTerminatingPipelineItem, + instance_id: 
uuid.UUID, +) -> Optional[InstanceModel]: + instance_lock, _ = get_locker(get_db().dialect_name).get_lockset(InstanceModel.__tablename__) + async with instance_lock: + res = await session.execute( + select(InstanceModel) + .where( + InstanceModel.id == instance_id, + or_( + InstanceModel.lock_expires_at.is_(None), + InstanceModel.lock_expires_at < get_current_datetime(), + ), + or_( + InstanceModel.lock_owner.is_(None), + InstanceModel.lock_owner == JobTerminatingPipeline.__name__, + ), + ) + .with_for_update(skip_locked=True, key_share=True) + .options(joinedload(InstanceModel.project).joinedload(ProjectModel.backends)) + .options( + joinedload(InstanceModel.volume_attachments).joinedload( + VolumeAttachmentModel.volume + ) + ) + .options(joinedload(InstanceModel.jobs).load_only(JobModel.id)) + ) + instance_model = res.unique().scalar_one_or_none() + if instance_model is None: + return None + instance_model.lock_expires_at = item.lock_expires_at + instance_model.lock_token = item.lock_token + instance_model.lock_owner = JobTerminatingPipeline.__name__ + return instance_model + + +async def _load_job_volume_models( + job_model: JobModel, + instance_model: Optional[InstanceModel], +) -> list[VolumeModel]: + if instance_model is None: + return [] + jrd = get_job_runtime_data(job_model) + volume_names = ( + jrd.volume_names + if jrd and jrd.volume_names + else [va.volume.name for va in instance_model.volume_attachments] + ) + if len(volume_names) == 0: + return [] + async with get_session_ctx() as session: + res = await session.execute( + select(VolumeModel) + .where( + VolumeModel.project_id == instance_model.project.id, + VolumeModel.name.in_(volume_names), + VolumeModel.deleted == False, + ) + .options(joinedload(VolumeModel.project)) + .options(joinedload(VolumeModel.user)) + .options( + joinedload(VolumeModel.attachments) + .joinedload(VolumeAttachmentModel.instance) + .joinedload(InstanceModel.fleet) + ) + ) + return list(res.unique().scalars().all()) + + +async 
def _reset_job_lock_for_retry(session: AsyncSession, item: JobTerminatingPipelineItem): + res = await session.execute( + update(JobModel) + .where( + JobModel.id == item.id, + JobModel.lock_token == item.lock_token, + ) + .values( + lock_expires_at=None, + lock_token=None, + last_processed_at=get_current_datetime(), + ) + .returning(JobModel.id) + ) + updated_ids = list(res.scalars().all()) + if len(updated_ids) == 0: + log_lock_token_changed_on_reset(logger) + + +async def _apply_process_result( + item: JobTerminatingPipelineItem, + job_model: JobModel, + instance_model: Optional[InstanceModel], + result: _ProcessResult, +) -> None: + async with get_session_ctx() as session: + now = get_current_datetime() + locked_instance_model: Optional[InstanceModel] = instance_model + instance_update_map: Optional[_InstanceUpdateMap] = None + if locked_instance_model is not None: + instance_update_map = get_or_error(result.instance_update_map) + resolve_now_placeholders(result.job_update_map, now=now) + if instance_update_map is not None: + resolve_now_placeholders(instance_update_map, now=now) + if result.volume_update_rows: + resolve_now_placeholders(result.volume_update_rows, now=now) + + res = await session.execute( + update(JobModel) + .where( + JobModel.id == item.id, + JobModel.lock_token == item.lock_token, + ) + .values(**result.job_update_map) + .returning(JobModel.id) + ) + updated_ids = list(res.scalars().all()) + if len(updated_ids) == 0: + log_lock_token_changed_after_processing(logger, item) + if locked_instance_model is not None: + await _unlock_related_instance( + session=session, + item=item, + instance_id=locked_instance_model.id, + ) + return + + if instance_update_map is not None: + locked_instance_model = get_or_error(locked_instance_model) + res = await session.execute( + update(InstanceModel) + .where( + InstanceModel.id == locked_instance_model.id, + InstanceModel.lock_token == item.lock_token, + InstanceModel.lock_owner == 
JobTerminatingPipeline.__name__, + ) + .values(**instance_update_map) + .returning(InstanceModel.id) + ) + updated_ids = list(res.scalars().all()) + if len(updated_ids) == 0: + logger.error( + "Failed to update related instance %s for terminating job %s.", + locked_instance_model.id, + item.id, + ) + + if result.volume_update_rows: + await session.execute(update(VolumeModel), result.volume_update_rows) + + if result.detached_volume_ids and locked_instance_model is not None: + await session.execute( + delete(VolumeAttachmentModel).where( + VolumeAttachmentModel.instance_id == locked_instance_model.id, + VolumeAttachmentModel.volume_id.in_(result.detached_volume_ids), + ) + ) + + emit_job_status_change_event( + session=session, + job_model=job_model, + old_status=job_model.status, + new_status=result.job_update_map.get("status", job_model.status), + termination_reason=result.job_update_map.get( + "termination_reason", job_model.termination_reason + ), + termination_reason_message=result.job_update_map.get( + "termination_reason_message", + job_model.termination_reason_message, + ), + ) + if instance_update_map is not None: + locked_instance_model = get_or_error(locked_instance_model) + emit_instance_status_change_event( + session=session, + instance_model=locked_instance_model, + old_status=locked_instance_model.status, + new_status=instance_update_map.get("status", locked_instance_model.status), + termination_reason=instance_update_map.get( + "termination_reason", locked_instance_model.termination_reason + ), + termination_reason_message=instance_update_map.get( + "termination_reason_message", + locked_instance_model.termination_reason_message, + ), + ) + if result.unassign_event_message is not None and locked_instance_model is not None: + events.emit( + session, + result.unassign_event_message, + actor=events.SystemActor(), + targets=[ + events.Target.from_model(job_model), + events.Target.from_model(locked_instance_model), + ], + ) + if 
result.emit_unregister_replica_event: + targets = [events.Target.from_model(job_model)] + if result.unregister_gateway_target is not None: + targets.append(result.unregister_gateway_target) + events.emit( + session, + "Service replica unregistered from receiving requests", + actor=events.SystemActor(), + targets=targets, + ) + + +async def _unlock_related_instance( + session: AsyncSession, + item: JobTerminatingPipelineItem, + instance_id: uuid.UUID, +) -> None: + await session.execute( + update(InstanceModel) + .where( + InstanceModel.id == instance_id, + InstanceModel.lock_token == item.lock_token, + InstanceModel.lock_owner == JobTerminatingPipeline.__name__, + ) + .values( + lock_expires_at=None, + lock_token=None, + lock_owner=None, + ) + ) + + +async def _process_terminating_job( + job_model: JobModel, + instance_model: Optional[InstanceModel], +) -> _ProcessResult: + """ + Stops the job: tells shim to stop the container, detaches the job from the instance, + and detaches volumes from the instance. + Graceful stop should already be done by `process_terminating_run`. + """ + instance_update_map = None if instance_model is None else _InstanceUpdateMap() + result = _ProcessResult(instance_update_map=instance_update_map) + + if instance_model is None: + result.unregister_gateway_target = await _unregister_replica(job_model=job_model) + if job_model.registered: + result.job_update_map["registered"] = False + result.emit_unregister_replica_event = True + result.job_update_map["status"] = _get_job_termination_status(job_model) + return result + + jrd = get_job_runtime_data(job_model) + jpd = get_job_provisioning_data(job_model) + if jpd is not None: + logger.debug("%s: stopping container", fmt(job_model)) + ssh_private_keys = get_instance_ssh_private_keys(instance_model) + if not await _stop_container(job_model, jpd, ssh_private_keys): + logger.warning( + ( + "%s: could not stop container, possibly due to a communication error." + " See debug logs for details." 
+ " Ignoring, can attempt to remove the container later" + ), + fmt(job_model), + ) + + ( + result.volume_update_rows, + detach_result, + ) = await _detach_job_volumes( + job_model=job_model, + instance_model=instance_model, + job_provisioning_data=jpd, + ) + result.detached_volume_ids = detach_result.detached_volume_ids + if detach_result.set_volumes_detached_at: + result.job_update_map["volumes_detached_at"] = NOW_PLACEHOLDER + + instance_update_map = get_or_error(result.instance_update_map) + busy_blocks = instance_model.busy_blocks - _get_job_occupied_blocks(jrd) + instance_update_map["busy_blocks"] = busy_blocks + if instance_model.status != InstanceStatus.BUSY or jpd is None or not jpd.dockerized: + if instance_model.status not in InstanceStatus.finished_statuses(): + instance_update_map["termination_reason"] = InstanceTerminationReason.JOB_FINISHED + instance_update_map["status"] = InstanceStatus.TERMINATING + elif not [j for j in instance_model.jobs if j.id != job_model.id]: + instance_update_map["status"] = InstanceStatus.IDLE + + result.job_update_map["instance_id"] = None + instance_update_map["last_job_processed_at"] = NOW_PLACEHOLDER + result.unassign_event_message = ( + "Job unassigned from instance." + f" Instance blocks: {busy_blocks}/{instance_model.total_blocks} busy" + ) + + result.unregister_gateway_target = await _unregister_replica(job_model=job_model) + if job_model.registered: + result.job_update_map["registered"] = False + result.emit_unregister_replica_event = True + if detach_result.all_detached: + result.job_update_map["status"] = _get_job_termination_status(job_model) + return result + + +async def _process_job_volumes_detaching( + job_model: JobModel, + instance_model: InstanceModel, +) -> _ProcessResult: + """ + Called after job's volumes have been soft detached to check if they are detached. + Terminates the job when all the volumes are detached. + If the volumes fail to detach, force detaches them. 
+ """ + result = _ProcessResult() + jpd = get_or_error(get_job_provisioning_data(job_model)) + ( + result.volume_update_rows, + detach_result, + ) = await _detach_job_volumes( + job_model=job_model, + instance_model=instance_model, + job_provisioning_data=jpd, + ) + result.detached_volume_ids = detach_result.detached_volume_ids + if detach_result.all_detached: + result.job_update_map["status"] = _get_job_termination_status(job_model) + return result + + +async def _detach_job_volumes( + job_model: JobModel, + instance_model: InstanceModel, + job_provisioning_data: Optional[JobProvisioningData], +) -> tuple[list[_VolumeUpdateRow], _VolumeDetachResult]: + volume_models = await _load_job_volume_models( + job_model=job_model, instance_model=instance_model + ) + volume_update_rows = _get_volume_update_rows(volume_models) + if len(volume_models) == 0: + return volume_update_rows, _VolumeDetachResult(all_detached=True) + + if job_provisioning_data is None: + return volume_update_rows, _VolumeDetachResult(all_detached=True) + + logger.info("Detaching volumes: %s", [v.name for v in volume_models]) + detach_result = await _detach_volumes_from_job_instance( + job_model=job_model, + instance_model=instance_model, + volume_models=volume_models, + jpd=job_provisioning_data, + run_termination_reason=job_model.run.termination_reason, + ) + return volume_update_rows, detach_result + + +async def _unregister_replica( + job_model: JobModel, +) -> Optional[events.Target]: + if not job_model.registered: + return None + gateway_target = None + run_model = job_model.run + if run_model.gateway_id is not None: + async with get_session_ctx() as session: + gateway, conn = await get_or_add_gateway_connection(session, run_model.gateway_id) + gateway_target = events.Target.from_model(gateway) + try: + logger.debug( + "%s: unregistering replica from service %s", fmt(job_model), job_model.run_id.hex + ) + async with conn.client() as client: + await client.unregister_replica( + 
project=run_model.project.name, + run_name=run_model.run_name, + job_id=job_model.id, + ) + except GatewayError as e: + logger.warning("%s: unregistering replica from service: %s", fmt(job_model), e) + except (httpx.RequestError, SSHError) as e: + logger.debug("Gateway request failed", exc_info=True) + raise GatewayError(repr(e)) + return gateway_target + + +def _get_job_termination_status(job_model: JobModel) -> JobStatus: + if job_model.termination_reason is not None: + return job_model.termination_reason.to_status() + return JobStatus.FAILED + + +def _get_volume_update_rows(volume_models: list[VolumeModel]) -> list[_VolumeUpdateRow]: + return [ + { + "id": volume_model.id, + "last_job_processed_at": NOW_PLACEHOLDER, + } + for volume_model in volume_models + ] + + +def _get_job_occupied_blocks(jrd: Optional[JobRuntimeData]) -> int: + if jrd is not None and jrd.offer is not None: + return jrd.offer.blocks + return 1 + + +async def _stop_container( + job_model: JobModel, + job_provisioning_data: JobProvisioningData, + ssh_private_keys: tuple[str, Optional[str]], +) -> bool: + if job_provisioning_data.dockerized: + return await common.run_async( + _shim_submit_stop, + ssh_private_keys, + job_provisioning_data, + None, + job_model, + ) + return True + + +@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT]) +def _shim_submit_stop(ports: dict[int, int], job_model: JobModel) -> bool: + shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT]) + + resp = shim_client.healthcheck() + if resp is None: + logger.debug("%s: can't stop container, shim is not available yet", fmt(job_model)) + return False + + if shim_client.is_api_v2_supported(): + reason = ( + None if job_model.termination_reason is None else job_model.termination_reason.value + ) + shim_client.terminate_task( + task_id=job_model.id, + reason=reason, + message=job_model.termination_reason_message, + timeout=0, + ) + if not settings.SERVER_KEEP_SHIM_TASKS: + shim_client.remove_task(task_id=job_model.id) + 
else: + shim_client.stop(force=True) + return True + + +async def _detach_volumes_from_job_instance( + job_model: JobModel, + instance_model: InstanceModel, + volume_models: list[VolumeModel], + jpd: JobProvisioningData, + run_termination_reason: Optional[RunTerminationReason], +) -> _VolumeDetachResult: + job_spec = JobSpec.__response__.parse_raw(job_model.job_spec_data) + backend = await backends_services.get_project_backend_by_type( + project=instance_model.project, + backend_type=jpd.backend, + ) + if backend is None: + logger.error( + "Failed to detach volumes from %s. Backend not available.", instance_model.name + ) + return _VolumeDetachResult(all_detached=False) + + detached_volume_ids = set() + all_detached = True + for volume_model in volume_models: + detached = await _detach_volume_from_job_instance( + backend=backend, + job_model=job_model, + jpd=jpd, + job_spec=job_spec, + instance_model=instance_model, + volume_model=volume_model, + run_termination_reason=run_termination_reason, + ) + if detached: + detached_volume_ids.add(volume_model.id) + else: + all_detached = False + + return _VolumeDetachResult( + all_detached=all_detached, + detached_volume_ids=detached_volume_ids, + set_volumes_detached_at=job_model.volumes_detached_at is None, + ) + + +async def _detach_volume_from_job_instance( + backend: Backend, + job_model: JobModel, + jpd: JobProvisioningData, + job_spec: JobSpec, + instance_model: InstanceModel, + volume_model: VolumeModel, + run_termination_reason: Optional[RunTerminationReason], +) -> bool: + detached = True + volume = volume_model_to_volume(volume_model) + if volume.provisioning_data is None or not volume.provisioning_data.detachable: + return detached + compute = backend.compute() + assert isinstance(compute, ComputeWithVolumeSupport) + try: + if job_model.volumes_detached_at is None: + await common.run_async( + compute.detach_volume, + volume=volume, + provisioning_data=jpd, + force=False, + ) + detached = await common.run_async( + 
compute.is_volume_detached, + volume=volume, + provisioning_data=jpd, + ) + else: + detached = await common.run_async( + compute.is_volume_detached, + volume=volume, + provisioning_data=jpd, + ) + if not detached and _should_force_detach_volume( + job_model=job_model, + run_termination_reason=run_termination_reason, + stop_duration=job_spec.stop_duration, + ): + logger.info( + "Force detaching volume %s from %s", + volume_model.name, + instance_model.name, + ) + await common.run_async( + compute.detach_volume, + volume=volume, + provisioning_data=jpd, + force=True, + ) + except BackendError as e: + logger.error( + "Failed to detach volume %s from %s: %s", + volume_model.name, + instance_model.name, + repr(e), + ) + except Exception: + logger.exception( + "Got exception when detaching volume %s from instance %s", + volume_model.name, + instance_model.name, + ) + return detached + + +def _should_force_detach_volume( + job_model: JobModel, + run_termination_reason: Optional[RunTerminationReason], + stop_duration: Optional[int], +) -> bool: + return ( + job_model.volumes_detached_at is not None + and get_current_datetime() > job_model.volumes_detached_at + timedelta(seconds=60) + and ( + job_model.termination_reason == JobTerminationReason.ABORTED_BY_USER + or run_termination_reason == RunTerminationReason.ABORTED_BY_USER + or stop_duration is not None + and get_current_datetime() + > job_model.volumes_detached_at + timedelta(seconds=stop_duration) + ) + ) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py index 6fc391f01..671a05b16 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py @@ -221,6 +221,7 @@ async def _process_terminating_job( ], ) + # Volumes are not locked because no other place can update attached active volumes. 
for volume_model in volume_models: volume_model.last_job_processed_at = common.get_current_datetime() diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py index 66e7ebdb2..f718a80ce 100644 --- a/src/dstack/_internal/server/services/jobs/__init__.py +++ b/src/dstack/_internal/server/services/jobs/__init__.py @@ -75,15 +75,55 @@ def switch_job_status( return job_model.status = new_status + emit_job_status_change_event( + session=session, + job_model=job_model, + old_status=old_status, + new_status=new_status, + termination_reason=job_model.termination_reason, + termination_reason_message=job_model.termination_reason_message, + actor=actor, + ) + +def get_job_status_change_message( + old_status: JobStatus, + new_status: JobStatus, + termination_reason: Optional[JobTerminationReason], + termination_reason_message: Optional[str], +) -> str: msg = f"Job status changed {old_status.upper()} -> {new_status.upper()}" if new_status == JobStatus.TERMINATING: - if job_model.termination_reason is None: + if termination_reason is None: raise ValueError("termination_reason must be set when switching to TERMINATING status") - msg += f". Termination reason: {job_model.termination_reason.upper()}" - if job_model.termination_reason_message: - msg += f" ({job_model.termination_reason_message})" - events.emit(session, msg, actor=actor, targets=[events.Target.from_model(job_model)]) + msg += f". 
Termination reason: {termination_reason.upper()}" + if termination_reason_message: + msg += f" ({termination_reason_message})" + return msg + + +def emit_job_status_change_event( + session: AsyncSession, + job_model: JobModel, + old_status: JobStatus, + new_status: JobStatus, + termination_reason: Optional[JobTerminationReason], + termination_reason_message: Optional[str], + actor: events.AnyActor = events.SystemActor(), +) -> None: + if old_status == new_status: + return + events.emit( + session, + get_job_status_change_message( + old_status=old_status, + new_status=new_status, + termination_reason=termination_reason, + termination_reason_message=termination_reason_message, + ), + actor=actor, + targets=[events.Target.from_model(job_model)], + ) async def get_jobs_from_run_spec( diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_terminating_jobs.py b/src/tests/_internal/server/background/pipeline_tasks/test_terminating_jobs.py index 95e4457c6..3e8ba188a 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_terminating_jobs.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_terminating_jobs.py @@ -1,25 +1,39 @@ import asyncio import uuid from datetime import datetime, timedelta, timezone -from unittest.mock import Mock +from unittest.mock import Mock, patch import pytest +from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import joinedload -from dstack._internal.core.models.runs import JobStatus +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.instances import InstanceStatus +from dstack._internal.core.models.runs import JobStatus, JobTerminationReason +from dstack._internal.core.models.volumes import VolumeStatus from dstack._internal.server.background.pipeline_tasks.terminating_jobs import ( JobTerminatingFetcher, JobTerminatingPipeline, JobTerminatingPipelineItem, JobTerminatingWorker, ) -from 
dstack._internal.server.models import JobModel +from dstack._internal.server.models import InstanceModel, JobModel, VolumeAttachmentModel from dstack._internal.server.testing.common import ( + ComputeMockSpec, + create_instance, create_job, create_project, create_repo, create_run, create_user, + create_volume, + get_instance_offer_with_availability, + get_job_provisioning_data, + get_job_runtime_data, + get_volume_configuration, + get_volume_provisioning_data, + list_events, ) from dstack._internal.utils.common import get_current_datetime @@ -53,6 +67,12 @@ def _job_to_pipeline_item(job_model: JobModel) -> JobTerminatingPipelineItem: ) +def _lock_job(job_model: JobModel): + job_model.lock_token = uuid.uuid4() + job_model.lock_expires_at = get_current_datetime() + timedelta(seconds=30) + job_model.lock_owner = JobTerminatingPipeline.__name__ + + @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) @pytest.mark.usefixtures("image_config_mock") @@ -212,16 +232,471 @@ async def test_fetch_returns_oldest_jobs_first_up_to_limit( @pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +@pytest.mark.usefixtures("image_config_mock") class TestJobTerminatingWorker: - async def test_process_is_not_implemented(self, worker: JobTerminatingWorker): - item = JobTerminatingPipelineItem( - __tablename__=JobModel.__tablename__, - id=uuid.uuid4(), - lock_token=uuid.uuid4(), - lock_expires_at=datetime(2025, 1, 2, 3, 4, tzinfo=timezone.utc), - prev_lock_expired=False, - volumes_detached_at=None, - ) - - with pytest.raises(NotImplementedError, match="not implemented yet"): - await worker.process(item) + async def test_terminates_job( + self, test_db, session: AsyncSession, worker: JobTerminatingWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) 
+ repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + job_provisioning_data = get_job_provisioning_data(dockerized=True) + job = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + termination_reason=JobTerminationReason.TERMINATED_BY_USER, + submitted_at=datetime(2023, 1, 2, 5, 12, 30, 5, tzinfo=timezone.utc), + job_provisioning_data=job_provisioning_data, + instance=instance, + ) + _lock_job(job) + await session.commit() + + with ( + patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock, + patch("dstack._internal.server.services.runner.client.ShimClient") as ShimClientMock, + ): + shim_client_mock = ShimClientMock.return_value + await worker.process(_job_to_pipeline_item(job)) + SSHTunnelMock.assert_called_once() + shim_client_mock.healthcheck.assert_called_once() + + await session.refresh(job) + await session.refresh(instance) + assert job.status == JobStatus.TERMINATED + assert job.lock_token is None + assert job.lock_expires_at is None + assert instance.lock_token is None + assert instance.lock_owner is None + + events = await list_events(session) + assert any( + event.message == "Job status changed TERMINATING -> TERMINATED" for event in events + ) + + async def test_detaches_job_volumes( + self, test_db, session: AsyncSession, worker: JobTerminatingWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + volume = await create_volume( + session=session, + project=project, + user=user, + status=VolumeStatus.ACTIVE, + backend=BackendType.AWS, + volume_provisioning_data=get_volume_provisioning_data(), + ) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + volumes=[volume], + ) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, 
repo=repo, user=user) + job_provisioning_data = get_job_provisioning_data(dockerized=False) + job = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + termination_reason=JobTerminationReason.TERMINATED_BY_USER, + submitted_at=datetime(2023, 1, 2, 5, 12, 30, 5, tzinfo=timezone.utc), + job_provisioning_data=job_provisioning_data, + instance=instance, + ) + _lock_job(job) + await session.commit() + + with patch( + "dstack._internal.server.background.pipeline_tasks.terminating_jobs.backends_services.get_project_backend_by_type" + ) as m: + backend_mock = Mock() + m.return_value = backend_mock + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value.is_volume_detached.return_value = True + await worker.process(_job_to_pipeline_item(job)) + m.assert_awaited_once() + backend_mock.compute.return_value.detach_volume.assert_called_once() + backend_mock.compute.return_value.is_volume_detached.assert_called_once() + + await session.refresh(job) + assert job.status == JobStatus.TERMINATED + await session.refresh(volume) + assert volume.last_job_processed_at is not None + + async def test_force_detaches_job_volumes( + self, test_db, session: AsyncSession, worker: JobTerminatingWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + volume = await create_volume( + session=session, + project=project, + user=user, + status=VolumeStatus.ACTIVE, + backend=BackendType.AWS, + volume_provisioning_data=get_volume_provisioning_data(), + ) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + volumes=[volume], + ) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + job_provisioning_data = get_job_provisioning_data(dockerized=False) + job = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, 
+ termination_reason=JobTerminationReason.TERMINATED_BY_USER, + submitted_at=datetime(2023, 1, 2, 5, 12, 30, 5, tzinfo=timezone.utc), + job_provisioning_data=job_provisioning_data, + instance=instance, + ) + _lock_job(job) + await session.commit() + + with patch( + "dstack._internal.server.background.pipeline_tasks.terminating_jobs.backends_services.get_project_backend_by_type" + ) as m: + backend_mock = Mock() + m.return_value = backend_mock + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value.is_volume_detached.return_value = False + await worker.process(_job_to_pipeline_item(job)) + + await session.refresh(job) + res = await session.execute(select(JobModel).options(joinedload(JobModel.instance))) + job = res.unique().scalar_one() + assert job.status == JobStatus.TERMINATING + assert job.instance is None + assert job.volumes_detached_at is not None + + _lock_job(job) + await session.commit() + with ( + patch( + "dstack._internal.server.background.pipeline_tasks.terminating_jobs.backends_services.get_project_backend_by_type" + ) as m, + patch( + "dstack._internal.server.background.pipeline_tasks.terminating_jobs.get_current_datetime" + ) as datetime_mock, + ): + datetime_mock.return_value = job.volumes_detached_at.replace( + tzinfo=timezone.utc + ) + timedelta(minutes=30) + backend_mock = Mock() + m.return_value = backend_mock + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value.is_volume_detached.return_value = False + await worker.process(_job_to_pipeline_item(job)) + backend_mock.compute.return_value.detach_volume.assert_called_once() + detach_kwargs = backend_mock.compute.return_value.detach_volume.call_args.kwargs + assert detach_kwargs["force"] is True + assert detach_kwargs["volume"].id == volume.id + assert ( + detach_kwargs["provisioning_data"].instance_id == job_provisioning_data.instance_id + ) + 
backend_mock.compute.return_value.is_volume_detached.assert_called_once() + + await session.refresh(job) + assert job.status == JobStatus.TERMINATING + + _lock_job(job) + await session.commit() + with patch( + "dstack._internal.server.background.pipeline_tasks.terminating_jobs.backends_services.get_project_backend_by_type" + ) as m: + backend_mock = Mock() + m.return_value = backend_mock + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value.is_volume_detached.return_value = True + await worker.process(_job_to_pipeline_item(job)) + backend_mock.compute.return_value.is_volume_detached.assert_called_once() + + await session.refresh(job) + await session.refresh(instance, ["volume_attachments"]) + res = await session.execute( + select(InstanceModel) + .where(InstanceModel.id == instance.id) + .options(joinedload(InstanceModel.volume_attachments)) + .execution_options(populate_existing=True) + ) + instance = res.unique().scalar_one() + assert job.status == JobStatus.TERMINATED + assert len(instance.volume_attachments) == 0 + + async def test_terminates_job_on_shared_instance( + self, test_db, session: AsyncSession, worker: JobTerminatingWorker + ): + project = await create_project(session) + user = await create_user(session) + repo = await create_repo(session=session, project_id=project.id) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + total_blocks=4, + busy_blocks=3, + ) + run = await create_run(session=session, project=project, repo=repo, user=user) + shared_offer = get_instance_offer_with_availability(blocks=2, total_blocks=4) + jrd = get_job_runtime_data(offer=shared_offer) + job = await create_job( + session=session, + run=run, + instance_assigned=True, + instance=instance, + job_runtime_data=jrd, + status=JobStatus.TERMINATING, + termination_reason=JobTerminationReason.TERMINATED_BY_USER, + ) + _lock_job(job) + await session.commit() + + await 
worker.process(_job_to_pipeline_item(job)) + + await session.refresh(job) + await session.refresh(instance) + res = await session.execute(select(JobModel).options(joinedload(JobModel.instance))) + job = res.unique().scalar_one() + assert job.status == JobStatus.TERMINATED + assert job.instance_assigned + assert job.instance is None + assert instance.busy_blocks == 1 + + async def test_detaches_job_volumes_on_shared_instance( + self, test_db, session: AsyncSession, worker: JobTerminatingWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + volume_conf_1 = get_volume_configuration(name="vol-1") + volume_1 = await create_volume( + session=session, + project=project, + user=user, + status=VolumeStatus.ACTIVE, + backend=BackendType.AWS, + configuration=volume_conf_1, + volume_provisioning_data=get_volume_provisioning_data(), + ) + volume_conf_2 = get_volume_configuration(name="vol-2") + volume_2 = await create_volume( + session=session, + project=project, + user=user, + status=VolumeStatus.ACTIVE, + backend=BackendType.AWS, + configuration=volume_conf_2, + volume_provisioning_data=get_volume_provisioning_data(), + ) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + volumes=[volume_1, volume_2], + ) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + job_provisioning_data = get_job_provisioning_data(dockerized=False) + job = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + termination_reason=JobTerminationReason.TERMINATED_BY_USER, + submitted_at=datetime(2023, 1, 2, 5, 12, 30, 5, tzinfo=timezone.utc), + job_provisioning_data=job_provisioning_data, + job_runtime_data=get_job_runtime_data(volume_names=["vol-1"]), + instance=instance, + ) + _lock_job(job) + await session.commit() + + with patch( + 
"dstack._internal.server.background.pipeline_tasks.terminating_jobs.backends_services.get_project_backend_by_type" + ) as m: + backend_mock = Mock() + m.return_value = backend_mock + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value.is_volume_detached.return_value = True + + await worker.process(_job_to_pipeline_item(job)) + + backend_mock.compute.return_value.detach_volume.assert_called_once() + backend_mock.compute.return_value.is_volume_detached.assert_called_once() + + await session.refresh(job) + await session.refresh(instance) + res = await session.execute( + select(InstanceModel).options( + joinedload(InstanceModel.volume_attachments).joinedload( + VolumeAttachmentModel.volume + ) + ) + ) + instance = res.unique().scalar_one() + assert job.status == JobStatus.TERMINATED + assert len(instance.volume_attachments) == 1 + assert instance.volume_attachments[0].volume == volume_2 + + async def test_resets_job_for_retry_if_related_instance_is_locked( + self, test_db, session: AsyncSession, worker: JobTerminatingWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + instance.lock_owner = "OtherPipeline" + instance.lock_token = uuid.uuid4() + instance.lock_expires_at = get_current_datetime() + timedelta(minutes=1) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + job = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + termination_reason=JobTerminationReason.TERMINATED_BY_USER, + instance=instance, + ) + _lock_job(job) + last_processed_before = job.last_processed_at + await session.commit() + + await worker.process(_job_to_pipeline_item(job)) + + await session.refresh(job) + assert job.lock_token is None + assert job.lock_expires_at 
is None + assert job.lock_owner == JobTerminatingPipeline.__name__ + assert job.last_processed_at > last_processed_before + + async def test_finishes_job_when_used_instance_is_not_set( + self, test_db, session: AsyncSession, worker: JobTerminatingWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + job = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + termination_reason=JobTerminationReason.TERMINATED_BY_USER, + ) + _lock_job(job) + await session.commit() + + await worker.process(_job_to_pipeline_item(job)) + + await session.refresh(job) + assert job.status == JobStatus.TERMINATED + assert job.lock_token is None + assert job.lock_expires_at is None + + async def test_retries_detaching_when_used_instance_is_missing( + self, test_db, session: AsyncSession, worker: JobTerminatingWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + job = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + termination_reason=JobTerminationReason.TERMINATED_BY_USER, + ) + job.instance_id = None + job.used_instance_id = uuid.uuid4() + job.volumes_detached_at = get_current_datetime() + _lock_job(job) + last_processed_before = job.last_processed_at + await session.commit() + + await worker.process(_job_to_pipeline_item(job)) + + await session.refresh(job) + assert job.status == JobStatus.TERMINATING + assert job.lock_token is None + assert job.lock_expires_at is None + assert job.lock_owner == JobTerminatingPipeline.__name__ + assert job.last_processed_at > last_processed_before + + async def 
test_retries_terminating_when_used_instance_is_missing( + self, test_db, session: AsyncSession, worker: JobTerminatingWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + job = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + termination_reason=JobTerminationReason.TERMINATED_BY_USER, + ) + job.used_instance_id = uuid.uuid4() + _lock_job(job) + last_processed_before = job.last_processed_at + await session.commit() + + await worker.process(_job_to_pipeline_item(job)) + + await session.refresh(job) + assert job.status == JobStatus.TERMINATING + assert job.lock_token is None + assert job.lock_expires_at is None + assert job.lock_owner == JobTerminatingPipeline.__name__ + assert job.last_processed_at > last_processed_before + + async def test_keeps_related_instance_locked_on_processing_exception( + self, test_db, session: AsyncSession, worker: JobTerminatingWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + job = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + termination_reason=JobTerminationReason.TERMINATED_BY_USER, + job_provisioning_data=get_job_provisioning_data(dockerized=True), + instance=instance, + ) + _lock_job(job) + job_lock_token = job.lock_token + await session.commit() + + with patch( + "dstack._internal.server.background.pipeline_tasks.terminating_jobs._process_terminating_job", + side_effect=RuntimeError("boom"), + ): + with pytest.raises(RuntimeError, match="boom"): + await 
worker.process(_job_to_pipeline_item(job)) + + await session.refresh(job) + await session.refresh(instance) + assert job.lock_token == job_lock_token + assert job.lock_owner == JobTerminatingPipeline.__name__ + assert instance.lock_token == job_lock_token + assert instance.lock_owner == JobTerminatingPipeline.__name__ From fb6306aa917a4461a432160360266f0a181306bc Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Mon, 9 Mar 2026 17:25:36 +0500 Subject: [PATCH 05/15] Minor refactoring --- .../background/pipeline_tasks/fleets.py | 2 +- .../pipeline_tasks/terminating_jobs.py | 94 ++++++++++--------- 2 files changed, 50 insertions(+), 46 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py index 2a63e21bd..0a4dfda25 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py @@ -369,7 +369,7 @@ async def _lock_fleet_instances_for_consolidation( "Failed to lock fleet %s instances. The fleet will be processed later.", item.id, ) - # Keep `lock_owner` so that `InstancePipeline` sees that the fleet is being locked + # Keep `lock_owner` so that `InstancePipeline` can check that the fleet is being locked # but unset `lock_expires_at` to process the item again ASAP (after `min_processing_interval`). # Unset `lock_token` so that heartbeater can no longer update the item. 
res = await session.execute( diff --git a/src/dstack/_internal/server/background/pipeline_tasks/terminating_jobs.py b/src/dstack/_internal/server/background/pipeline_tasks/terminating_jobs.py index da2a6b503..d3230b32f 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/terminating_jobs.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/terminating_jobs.py @@ -239,12 +239,6 @@ async def process(self, item: JobTerminatingPipelineItem): if instance_model is None: await _reset_job_lock_for_retry(session=session, item=item) return - if job_model.volumes_detached_at is not None and instance_model is None: - logger.error( - "%s: expected used_instance_id while detaching volumes", fmt(job_model) - ) - await _reset_job_lock_for_retry(session=session, item=item) - return if job_model.volumes_detached_at is None: result = await _process_terminating_job( @@ -262,8 +256,9 @@ async def process(self, item: JobTerminatingPipelineItem): if instance_model is not None: if result.instance_update_map is None: result.instance_update_map = _InstanceUpdateMap() - set_processed_update_map_fields(result.instance_update_map) - set_unlock_update_map_fields(result.instance_update_map) + instance_update_map = result.instance_update_map + set_processed_update_map_fields(instance_update_map) + set_unlock_update_map_fields(instance_update_map) await _apply_process_result( item=item, job_model=job_model, @@ -287,9 +282,6 @@ class _InstanceUpdateMap(ItemUpdateMap, total=False): termination_reason_message: Optional[str] busy_blocks: int last_job_processed_at: UpdateMapDateTime - lock_expires_at: Optional[datetime] - lock_token: Optional[uuid.UUID] - lock_owner: Optional[str] class _VolumeUpdateRow(TypedDict): @@ -361,7 +353,6 @@ async def _lock_related_instance( InstanceModel.lock_owner == JobTerminatingPipeline.__name__, ), ) - .with_for_update(skip_locked=True, key_share=True) .options(joinedload(InstanceModel.project).joinedload(ProjectModel.backends)) .options( 
joinedload(InstanceModel.volume_attachments).joinedload( @@ -369,6 +360,7 @@ async def _lock_related_instance( ) ) .options(joinedload(InstanceModel.jobs).load_only(JobModel.id)) + .with_for_update(skip_locked=True, key_share=True) ) instance_model = res.unique().scalar_one_or_none() if instance_model is None: @@ -419,6 +411,9 @@ async def _reset_job_lock_for_retry(session: AsyncSession, item: JobTerminatingP JobModel.id == item.id, JobModel.lock_token == item.lock_token, ) + # Keep `lock_owner` so that `InstancePipeline` can check that the job is being locked + # but unset `lock_expires_at` to process the item again ASAP (after `min_processing_interval`). + # Unset `lock_token` so that heartbeater can no longer update the item. .values( lock_expires_at=None, lock_token=None, @@ -439,10 +434,9 @@ async def _apply_process_result( ) -> None: async with get_session_ctx() as session: now = get_current_datetime() - locked_instance_model: Optional[InstanceModel] = instance_model - instance_update_map: Optional[_InstanceUpdateMap] = None - if locked_instance_model is not None: - instance_update_map = get_or_error(result.instance_update_map) + instance_update_map = result.instance_update_map + if instance_model is None: + instance_update_map = None resolve_now_placeholders(result.job_update_map, now=now) if instance_update_map is not None: resolve_now_placeholders(instance_update_map, now=now) @@ -461,20 +455,19 @@ async def _apply_process_result( updated_ids = list(res.scalars().all()) if len(updated_ids) == 0: log_lock_token_changed_after_processing(logger, item) - if locked_instance_model is not None: + if instance_model is not None: await _unlock_related_instance( session=session, item=item, - instance_id=locked_instance_model.id, + instance_id=instance_model.id, ) return - if instance_update_map is not None: - locked_instance_model = get_or_error(locked_instance_model) + if instance_model is not None and instance_update_map is not None: res = await session.execute( 
update(InstanceModel) .where( - InstanceModel.id == locked_instance_model.id, + InstanceModel.id == instance_model.id, InstanceModel.lock_token == item.lock_token, InstanceModel.lock_owner == JobTerminatingPipeline.__name__, ) @@ -485,17 +478,17 @@ async def _apply_process_result( if len(updated_ids) == 0: logger.error( "Failed to update related instance %s for terminating job %s.", - locked_instance_model.id, + instance_model.id, item.id, ) if result.volume_update_rows: await session.execute(update(VolumeModel), result.volume_update_rows) - if result.detached_volume_ids and locked_instance_model is not None: + if result.detached_volume_ids and instance_model is not None: await session.execute( delete(VolumeAttachmentModel).where( - VolumeAttachmentModel.instance_id == locked_instance_model.id, + VolumeAttachmentModel.instance_id == instance_model.id, VolumeAttachmentModel.volume_id.in_(result.detached_volume_ids), ) ) @@ -513,31 +506,34 @@ async def _apply_process_result( job_model.termination_reason_message, ), ) - if instance_update_map is not None: - locked_instance_model = get_or_error(locked_instance_model) + + if instance_model is not None and instance_update_map is not None: emit_instance_status_change_event( session=session, - instance_model=locked_instance_model, - old_status=locked_instance_model.status, - new_status=instance_update_map.get("status", locked_instance_model.status), + instance_model=instance_model, + old_status=instance_model.status, + new_status=instance_update_map.get("status", instance_model.status), termination_reason=instance_update_map.get( - "termination_reason", locked_instance_model.termination_reason + "termination_reason", + instance_model.termination_reason, ), termination_reason_message=instance_update_map.get( "termination_reason_message", - locked_instance_model.termination_reason_message, + instance_model.termination_reason_message, ), ) - if result.unassign_event_message is not None and locked_instance_model is not None: + 
+ if result.unassign_event_message is not None and instance_model is not None: events.emit( session, result.unassign_event_message, actor=events.SystemActor(), targets=[ events.Target.from_model(job_model), - events.Target.from_model(locked_instance_model), + events.Target.from_model(instance_model), ], ) + if result.emit_unregister_replica_event: targets = [events.Target.from_model(job_model)] if result.unregister_gateway_target is not None: @@ -583,10 +579,7 @@ async def _process_terminating_job( result = _ProcessResult(instance_update_map=instance_update_map) if instance_model is None: - result.unregister_gateway_target = await _unregister_replica(job_model=job_model) - if job_model.registered: - result.job_update_map["registered"] = False - result.emit_unregister_replica_event = True + await _unregister_replica_and_update_result(result=result, job_model=job_model) result.job_update_map["status"] = _get_job_termination_status(job_model) return result @@ -634,10 +627,7 @@ async def _process_terminating_job( f" Instance blocks: {busy_blocks}/{instance_model.total_blocks} busy" ) - result.unregister_gateway_target = await _unregister_replica(job_model=job_model) - if job_model.registered: - result.job_update_map["registered"] = False - result.emit_unregister_replica_event = True + await _unregister_replica_and_update_result(result=result, job_model=job_model) if detach_result.all_detached: result.job_update_map["status"] = _get_job_termination_status(job_model) return result @@ -652,7 +642,7 @@ async def _process_job_volumes_detaching( Terminates the job when all the volumes are detached. If the volumes fail to detach, force detaches them. 
""" - result = _ProcessResult() + result = _ProcessResult(instance_update_map=_InstanceUpdateMap()) jpd = get_or_error(get_job_provisioning_data(job_model)) ( result.volume_update_rows, @@ -694,6 +684,15 @@ async def _detach_job_volumes( return volume_update_rows, detach_result +async def _unregister_replica_and_update_result( + result: _ProcessResult, job_model: JobModel +) -> None: + result.unregister_gateway_target = await _unregister_replica(job_model=job_model) + if job_model.registered: + result.job_update_map["registered"] = False + result.emit_unregister_replica_event = True + + async def _unregister_replica( job_model: JobModel, ) -> Optional[events.Target]: @@ -719,6 +718,8 @@ async def _unregister_replica( logger.warning("%s: unregistering replica from service: %s", fmt(job_model), e) except (httpx.RequestError, SSHError) as e: logger.debug("Gateway request failed", exc_info=True) + # FIXME: Unhandled exception raised. + # Handle and retry unregister with timeout. raise GatewayError(repr(e)) return gateway_target @@ -895,19 +896,22 @@ async def _detach_volume_from_job_instance( return detached +_MIN_FORCE_DETACH_WAIT_PERIOD = timedelta(seconds=60) + + def _should_force_detach_volume( job_model: JobModel, run_termination_reason: Optional[RunTerminationReason], stop_duration: Optional[int], ) -> bool: + now = get_current_datetime() return ( job_model.volumes_detached_at is not None - and get_current_datetime() > job_model.volumes_detached_at + timedelta(seconds=60) + and now > job_model.volumes_detached_at + _MIN_FORCE_DETACH_WAIT_PERIOD and ( job_model.termination_reason == JobTerminationReason.ABORTED_BY_USER or run_termination_reason == RunTerminationReason.ABORTED_BY_USER or stop_duration is not None - and get_current_datetime() - > job_model.volumes_detached_at + timedelta(seconds=stop_duration) + and now > job_model.volumes_detached_at + timedelta(seconds=stop_duration) ) ) From d8ec8c59f2d73339661b5ca6ee004437193b494a Mon Sep 17 00:00:00 2001 From: 
Victor Skvortsov Date: Tue, 10 Mar 2026 10:44:57 +0500 Subject: [PATCH 06/15] Rename pipeline_tasks terminating_jobs.py to jobs_terminating.py --- .../{terminating_jobs.py => jobs_terminating.py} | 0 .../pipeline_tasks/test_terminating_jobs.py | 16 ++++++++-------- 2 files changed, 8 insertions(+), 8 deletions(-) rename src/dstack/_internal/server/background/pipeline_tasks/{terminating_jobs.py => jobs_terminating.py} (100%) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/terminating_jobs.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py similarity index 100% rename from src/dstack/_internal/server/background/pipeline_tasks/terminating_jobs.py rename to src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_terminating_jobs.py b/src/tests/_internal/server/background/pipeline_tasks/test_terminating_jobs.py index 3e8ba188a..acefd3bf4 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_terminating_jobs.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_terminating_jobs.py @@ -12,7 +12,7 @@ from dstack._internal.core.models.instances import InstanceStatus from dstack._internal.core.models.runs import JobStatus, JobTerminationReason from dstack._internal.core.models.volumes import VolumeStatus -from dstack._internal.server.background.pipeline_tasks.terminating_jobs import ( +from dstack._internal.server.background.pipeline_tasks.jobs_terminating import ( JobTerminatingFetcher, JobTerminatingPipeline, JobTerminatingPipelineItem, @@ -317,7 +317,7 @@ async def test_detaches_job_volumes( await session.commit() with patch( - "dstack._internal.server.background.pipeline_tasks.terminating_jobs.backends_services.get_project_backend_by_type" + "dstack._internal.server.background.pipeline_tasks.jobs_terminating.backends_services.get_project_backend_by_type" ) as m: backend_mock = Mock() m.return_value = 
backend_mock @@ -368,7 +368,7 @@ async def test_force_detaches_job_volumes( await session.commit() with patch( - "dstack._internal.server.background.pipeline_tasks.terminating_jobs.backends_services.get_project_backend_by_type" + "dstack._internal.server.background.pipeline_tasks.jobs_terminating.backends_services.get_project_backend_by_type" ) as m: backend_mock = Mock() m.return_value = backend_mock @@ -387,10 +387,10 @@ async def test_force_detaches_job_volumes( await session.commit() with ( patch( - "dstack._internal.server.background.pipeline_tasks.terminating_jobs.backends_services.get_project_backend_by_type" + "dstack._internal.server.background.pipeline_tasks.jobs_terminating.backends_services.get_project_backend_by_type" ) as m, patch( - "dstack._internal.server.background.pipeline_tasks.terminating_jobs.get_current_datetime" + "dstack._internal.server.background.pipeline_tasks.jobs_terminating.get_current_datetime" ) as datetime_mock, ): datetime_mock.return_value = job.volumes_detached_at.replace( @@ -416,7 +416,7 @@ async def test_force_detaches_job_volumes( _lock_job(job) await session.commit() with patch( - "dstack._internal.server.background.pipeline_tasks.terminating_jobs.backends_services.get_project_backend_by_type" + "dstack._internal.server.background.pipeline_tasks.jobs_terminating.backends_services.get_project_backend_by_type" ) as m: backend_mock = Mock() m.return_value = backend_mock @@ -524,7 +524,7 @@ async def test_detaches_job_volumes_on_shared_instance( await session.commit() with patch( - "dstack._internal.server.background.pipeline_tasks.terminating_jobs.backends_services.get_project_backend_by_type" + "dstack._internal.server.background.pipeline_tasks.jobs_terminating.backends_services.get_project_backend_by_type" ) as m: backend_mock = Mock() m.return_value = backend_mock @@ -688,7 +688,7 @@ async def test_keeps_related_instance_locked_on_processing_exception( await session.commit() with patch( - 
"dstack._internal.server.background.pipeline_tasks.terminating_jobs._process_terminating_job", + "dstack._internal.server.background.pipeline_tasks.jobs_terminating._process_terminating_job", side_effect=RuntimeError("boom"), ): with pytest.raises(RuntimeError, match="boom"): From c6ca7d8de661f9a1f4b0127035e6bf378bf4fd06 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 10 Mar 2026 10:47:17 +0500 Subject: [PATCH 07/15] Wire pipeline --- .../server/background/pipeline_tasks/__init__.py | 4 ++++ .../server/background/scheduled_tasks/__init__.py | 12 ++++++------ 2 files changed, 10 insertions(+), 6 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/__init__.py b/src/dstack/_internal/server/background/pipeline_tasks/__init__.py index 556e13daa..7b2d79047 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/__init__.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/__init__.py @@ -5,6 +5,9 @@ from dstack._internal.server.background.pipeline_tasks.fleets import FleetPipeline from dstack._internal.server.background.pipeline_tasks.gateways import GatewayPipeline from dstack._internal.server.background.pipeline_tasks.instances import InstancePipeline +from dstack._internal.server.background.pipeline_tasks.jobs_terminating import ( + JobTerminatingPipeline, +) from dstack._internal.server.background.pipeline_tasks.placement_groups import ( PlacementGroupPipeline, ) @@ -20,6 +23,7 @@ def __init__(self) -> None: ComputeGroupPipeline(), FleetPipeline(), GatewayPipeline(), + JobTerminatingPipeline(), InstancePipeline(), PlacementGroupPipeline(), VolumePipeline(), diff --git a/src/dstack/_internal/server/background/scheduled_tasks/__init__.py b/src/dstack/_internal/server/background/scheduled_tasks/__init__.py index 2994fca37..a0b97e7d5 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/__init__.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/__init__.py @@ -134,12 +134,6 @@ def 
start_scheduled_tasks() -> AsyncIOScheduler: kwargs={"batch_size": 5}, max_instances=2 if replica == 0 else 1, ) - _scheduler.add_job( - process_terminating_jobs, - IntervalTrigger(seconds=4, jitter=2), - kwargs={"batch_size": 5}, - max_instances=2 if replica == 0 else 1, - ) _scheduler.add_job( process_runs, IntervalTrigger(seconds=2, jitter=1), @@ -159,5 +153,11 @@ def start_scheduled_tasks() -> AsyncIOScheduler: kwargs={"batch_size": 1}, max_instances=2 if replica == 0 else 1, ) + _scheduler.add_job( + process_terminating_jobs, + IntervalTrigger(seconds=4, jitter=2), + kwargs={"batch_size": 5}, + max_instances=2 if replica == 0 else 1, + ) _scheduler.start() return _scheduler From faa4552c8cbcff8cf3f21748657fa357b2b9db73 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 10 Mar 2026 11:12:11 +0500 Subject: [PATCH 08/15] Use SELECT FOR UPDATE OF --- contributing/LOCKING.md | 17 ++++++++++------- .../server/background/pipeline_tasks/fleets.py | 4 ++-- .../background/pipeline_tasks/gateways.py | 2 +- .../pipeline_tasks/jobs_terminating.py | 4 ++-- .../pipeline_tasks/placement_groups.py | 2 +- .../server/background/pipeline_tasks/volumes.py | 2 +- 6 files changed, 17 insertions(+), 14 deletions(-) diff --git a/contributing/LOCKING.md b/contributing/LOCKING.md index e23fb41f9..c89ad526a 100644 --- a/contributing/LOCKING.md +++ b/contributing/LOCKING.md @@ -1,7 +1,6 @@ # Locking -The `dstack` server supports SQLite and Postgres databases -with two implementations of resource locking to handle concurrent access: +The `dstack` server supports SQLite and Postgres databases with two implementations of resource locking to handle concurrent access: * In-memory locking for SQLite. * DB-level locking for Postgres. @@ -34,11 +33,11 @@ There are few places that rely on advisory locks as when generating unique resou ## Working with locks -Concurrency is hard. Below you'll find common patterns and gotchas when working with locks to make it a bit more manageable. 
+Concurrency is hard. Concurrency with locking is especially hard. Below you'll find common patterns and gotchas when working with locks to make it a bit more manageable. **A task should acquire locks on resources it modifies** -This is a common sense approach. An alternative could be the inverse: job processing cannot run in parallel with run processing, so job processing takes run lock. This indirection complicates things and is discouraged. In this example, run processing should take job lock instead. +This is common sense. An alternative could be the inverse: job processing cannot run in parallel with run processing, so job processing takes run lock. This indirection complicates things and is discouraged. In this example, run processing should take job lock instead. **Start new transaction after acquiring a lock to see other transactions changes in SQLite.** @@ -75,15 +74,19 @@ unlock resources If a transaction releases a lock before committing changes, the changes may not be visible to another transaction that acquired the lock and relies upon seeing all committed changes. -**Don't use `joinedload` when selecting `.with_for_update()`** +**Using `joinedload` when selecting `.with_for_update()`** -In fact, using `joinedload` and `.with_for_update()` will trigger an error because `joinedload` produces OUTER LEFT JOIN that cannot be used with SELECT FOR UPDATE. A regular `.join()` can be used to lock related resources but it may lead to no rows if there is no row to join. Usually, you'd select with `selectinload` or first select with `.with_for_update()` without loading related attributes and then re-selecting with `joinedload` without `.with_for_update()`. +Using `joinedload` and `.with_for_update()` triggers an error in case of no related rows because `joinedload` produces OUTER LEFT JOIN and SELECT FOR UPDATE cannot be applied to the nullable side of an OUTER JOIN. Here's the options: + +* Use `.with_for_update(of=MainModel)`. 
+* Select with `selectinload` +* First select with `.with_for_update()` without loading related attributes and then re-select with `joinedload` without `.with_for_update()`. +* Use regular `.join()` to lock related resources, but you may get 0 rows if there is no related row to join. **Always use `.with_for_update(key_share=True)` unless you plan to delete rows or update a primary key column** If you `SELECT FOR UPDATE` from a table that is referenced in a child table via a foreign key, it can lead to deadlocks if the child table is updated because Postgres will issue a `FOR KEY SHARE` lock on the parent table rows to ensure valid foreign keys. For this reason, you should always do `SELECT FOR NO KEY UPDATE` (.`with_for_update(key_share=True)`) if primary key columns are not modified. `SELECT FOR NO KEY UPDATE` is not blocked by a `FOR KEY SHARE` lock, so no deadlock. - **Lock unique names** The following pattern can be used to lock a unique name of some resource type: diff --git a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py index 0a4dfda25..3065c1e09 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py @@ -152,7 +152,7 @@ async def fetch(self, limit: int) -> list[PipelineItem]: ) .order_by(FleetModel.last_processed_at.asc()) .limit(limit) - .with_for_update(skip_locked=True, key_share=True) + .with_for_update(skip_locked=True, key_share=True, of=FleetModel) .options( load_only( FleetModel.id, @@ -352,7 +352,7 @@ async def _lock_fleet_instances_for_consolidation( InstanceModel.lock_owner == FleetPipeline.__name__, ), ) - .with_for_update(skip_locked=True, key_share=True) + .with_for_update(skip_locked=True, key_share=True, of=InstanceModel) ) locked_instance_models = list(res.scalars().all()) locked_instance_ids = {instance_model.id for instance_model in locked_instance_models} diff --git 
a/src/dstack/_internal/server/background/pipeline_tasks/gateways.py b/src/dstack/_internal/server/background/pipeline_tasks/gateways.py index 81ba2ae70..a051c5a96 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/gateways.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/gateways.py @@ -152,7 +152,7 @@ async def fetch(self, limit: int) -> list[GatewayPipelineItem]: ) .order_by(GatewayModel.last_processed_at.asc()) .limit(limit) - .with_for_update(skip_locked=True, key_share=True) + .with_for_update(skip_locked=True, key_share=True, of=GatewayModel) .options( load_only( GatewayModel.id, diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py index d3230b32f..4e059f237 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py @@ -177,7 +177,7 @@ async def fetch(self, limit: int) -> list[JobTerminatingPipelineItem]: ) .order_by(JobModel.last_processed_at.asc()) .limit(limit) - .with_for_update(skip_locked=True, key_share=True) + .with_for_update(skip_locked=True, key_share=True, of=JobModel) .options( load_only( JobModel.id, @@ -360,7 +360,7 @@ async def _lock_related_instance( ) ) .options(joinedload(InstanceModel.jobs).load_only(JobModel.id)) - .with_for_update(skip_locked=True, key_share=True) + .with_for_update(skip_locked=True, key_share=True, of=InstanceModel) ) instance_model = res.unique().scalar_one_or_none() if instance_model is None: diff --git a/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py b/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py index 552ae00dc..8fb5bfd4b 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py @@ -136,7 +136,7 @@ async def 
fetch(self, limit: int) -> list[PipelineItem]: ) .order_by(PlacementGroupModel.last_processed_at.asc()) .limit(limit) - .with_for_update(skip_locked=True, key_share=True) + .with_for_update(skip_locked=True, key_share=True, of=PlacementGroupModel) .options( load_only( PlacementGroupModel.id, diff --git a/src/dstack/_internal/server/background/pipeline_tasks/volumes.py b/src/dstack/_internal/server/background/pipeline_tasks/volumes.py index 81d94c361..89eea7d92 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/volumes.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/volumes.py @@ -154,7 +154,7 @@ async def fetch(self, limit: int) -> list[VolumePipelineItem]: ) .order_by(VolumeModel.last_processed_at.asc()) .limit(limit) - .with_for_update(skip_locked=True, key_share=True) + .with_for_update(skip_locked=True, key_share=True, of=VolumeModel) .options( load_only( VolumeModel.id, From c44f1aec90d1e7c84695d3239cf05cf01cb2018c Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 10 Mar 2026 12:09:06 +0500 Subject: [PATCH 09/15] Add contributing/PIPELINES.md --- contributing/PIPELINES.md | 43 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 contributing/PIPELINES.md diff --git a/contributing/PIPELINES.md b/contributing/PIPELINES.md new file mode 100644 index 000000000..eefe607e0 --- /dev/null +++ b/contributing/PIPELINES.md @@ -0,0 +1,43 @@ +# Pipelines + +This document describes how the `dstack` server implements background processing via so-called "pipelines". + +*Historical context: `dstack` used to do all background processing via scheduled tasks. A scheduled task would process a specific resource type like volumes or runs by keeping DB transaction open for the entire processing duration and keeping the resource lock with SELECT FOR UPDATE (or in-memory lock on SQLite). This approach didn't scale well because the number of DB connections was a huge bottleneck. 
Pipelines replaced scheduled tasks: they do all the heavy processing outside of DB transactions and write locks to DB columns.*
+
+## Overview
+
+* Resources are continuously processed in the background by pipelines. A pipeline consists of a fetcher, workers, and a heartbeater.
+* A fetcher selects rows to be processed from the DB, marks them as locked in the DB, and puts them into an in-memory queue.
+* Workers consume rows from the in-memory queue, process the rows, and unlock them.
+* The locking (unlocking) is done by setting (unsetting) `lock_expires_at`, `lock_token`, `lock_owner`.
+* If the replica/pipeline dies, the rows stay locked in the db. Another replica picks up the rows after `lock_expires_at`.
+* `lock_token` prevents a stale replica/pipeline from updating the rows already picked up by the new replica.
+* `lock_owner` stores the pipeline that's locked the row so that only that pipeline can recover if it's stale.
+* A heartbeater tracks all rows in the pipeline (in the queue or in processing), and updates the lock expiration. This allows setting small `lock_expires_at` and picking up stale rows quickly.
+* A fetcher performs the fetch when the queue size goes under a configured lower limit. It has exponential retry delays between empty fetches, thus reducing load on the DB.
+* There is a fetch hint mechanism that services can use to notify the pipelines within the replica – in that case the fetcher stops sleeping and fetches immediately.
+* Each pipeline locks one main resource but may lock related resources as well. It's not necessary to heartbeat related resources if the pipeline ensures no one else can re-lock them. This is typically done via setting and respecting `lock_owner`.
+
+Related notes:
+
+* All write APIs must respect DB-level locks. The endpoints can either try to acquire the lock with a timeout and error or provide an async API by storing the request in the DB.
+ +## Implementation patterns + +**Locking many related resources** + +A pipeline may need to lock a potentially big set of related resource, e.g. fleet pipeline locking all fleet's instances. For this, do one SELECT FOR UPDATE of non-locked instances and one SELECT to see how many instances there are, and check if you managed to lock all of them. If fail to lock, release the main lock and try processing on another fetch iteration. You may keep `lock_owner` on the main resource or set `lock_owner` on locked related resource and make other pipelines respect that to guarantee the eventual locking of all related resources and avoid lock starvation. + +**Locking a shared related resource** + +Multiple main resources may need to lock the same related resource, e.g. multiple jobs may need to change the shared instance. In this case it's not sufficient to set `lock_owner` on the related resource to the pipeline name because workers processing different main resources can still race with each other. To avoid heartbeating the related resource, you may include main resource id in `lock_owner`, e.g. set `lock_owner = f"{Pipeline.__name__}:{item.id}"`. + +## Performance analysis + +* Pipeline throughput = workers_num / worker_processing_time. So quick tasks easily give high-throughput pipelines, e.g. 1s task with 20 workers is 1200 tasks/min. +A slow 30s task gives only 40 tasks/min with the same number of workers. We can increase the number of workers but the peak memory usage will grow proportionally. +In general, workers should be optimized to be as quick as possible to improve throughput. +* Processing latency (wait) is close to 0 due to fetch hints if the pipeline is not saturated. In general, latency = queue_size / throughput. +* In-memory queue maxsize provides a cap on memory usage and recovery time after crashes (number of locked items to retry). +* Fetcher's DB load is proportional to the number of pipelines and is expected to be negligible. 
Workers can put a considerable read/write DB load as it's proportional to the number of workers. This can be optimized by batching workers' writes. Workers do processing outside of transactions so DB connections won't be a bottleneck. +* There is a risk of lock starvation if a worker needs to lock all related resources. This is to be mitigated by 1) related pipelines checking `lock_owner` and skip locking to let the parent pipeline acquire all the locks eventually and 2) do the related resource locking only on paths that require it. From 6a7c63a5fdae6e6357863929625e7b75a5565334 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 10 Mar 2026 12:09:31 +0500 Subject: [PATCH 10/15] Set job-specific lock_owner --- .../pipeline_tasks/jobs_terminating.py | 14 ++++-- .../pipeline_tasks/test_terminating_jobs.py | 46 ++++++++++++++++++- 2 files changed, 55 insertions(+), 5 deletions(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py index 4e059f237..75ae08c2a 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py @@ -338,6 +338,7 @@ async def _lock_related_instance( item: JobTerminatingPipelineItem, instance_id: uuid.UUID, ) -> Optional[InstanceModel]: + lock_owner = _get_related_instance_lock_owner(item.id) instance_lock, _ = get_locker(get_db().dialect_name).get_lockset(InstanceModel.__tablename__) async with instance_lock: res = await session.execute( @@ -350,7 +351,7 @@ async def _lock_related_instance( ), or_( InstanceModel.lock_owner.is_(None), - InstanceModel.lock_owner == JobTerminatingPipeline.__name__, + InstanceModel.lock_owner == lock_owner, ), ) .options(joinedload(InstanceModel.project).joinedload(ProjectModel.backends)) @@ -367,7 +368,7 @@ async def _lock_related_instance( return None instance_model.lock_expires_at = 
item.lock_expires_at instance_model.lock_token = item.lock_token - instance_model.lock_owner = JobTerminatingPipeline.__name__ + instance_model.lock_owner = lock_owner return instance_model @@ -434,6 +435,7 @@ async def _apply_process_result( ) -> None: async with get_session_ctx() as session: now = get_current_datetime() + related_instance_lock_owner = _get_related_instance_lock_owner(item.id) instance_update_map = result.instance_update_map if instance_model is None: instance_update_map = None @@ -469,7 +471,7 @@ async def _apply_process_result( .where( InstanceModel.id == instance_model.id, InstanceModel.lock_token == item.lock_token, - InstanceModel.lock_owner == JobTerminatingPipeline.__name__, + InstanceModel.lock_owner == related_instance_lock_owner, ) .values(**instance_update_map) .returning(InstanceModel.id) @@ -556,7 +558,7 @@ async def _unlock_related_instance( .where( InstanceModel.id == instance_id, InstanceModel.lock_token == item.lock_token, - InstanceModel.lock_owner == JobTerminatingPipeline.__name__, + InstanceModel.lock_owner == _get_related_instance_lock_owner(item.id), ) .values( lock_expires_at=None, @@ -915,3 +917,7 @@ def _should_force_detach_volume( and now > job_model.volumes_detached_at + timedelta(seconds=stop_duration) ) ) + + +def _get_related_instance_lock_owner(job_id: uuid.UUID) -> str: + return f"{JobTerminatingPipeline.__name__}:{job_id}" diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_terminating_jobs.py b/src/tests/_internal/server/background/pipeline_tasks/test_terminating_jobs.py index acefd3bf4..822d6798a 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_terminating_jobs.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_terminating_jobs.py @@ -17,6 +17,7 @@ JobTerminatingPipeline, JobTerminatingPipelineItem, JobTerminatingWorker, + _get_related_instance_lock_owner, ) from dstack._internal.server.models import InstanceModel, JobModel, VolumeAttachmentModel from 
dstack._internal.server.testing.common import ( @@ -584,6 +585,49 @@ async def test_resets_job_for_retry_if_related_instance_is_locked( assert job.lock_owner == JobTerminatingPipeline.__name__ assert job.last_processed_at > last_processed_before + async def test_resets_job_for_retry_if_related_instance_is_locked_by_another_job( + self, test_db, session: AsyncSession, worker: JobTerminatingWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + other_job = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + termination_reason=JobTerminationReason.TERMINATED_BY_USER, + instance=instance, + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + termination_reason=JobTerminationReason.TERMINATED_BY_USER, + instance=instance, + ) + instance.lock_owner = _get_related_instance_lock_owner(other_job.id) + instance.lock_token = uuid.uuid4() + instance.lock_expires_at = get_current_datetime() - timedelta(minutes=1) + _lock_job(job) + last_processed_before = job.last_processed_at + await session.commit() + + await worker.process(_job_to_pipeline_item(job)) + + await session.refresh(job) + await session.refresh(instance) + assert job.lock_token is None + assert job.lock_expires_at is None + assert job.lock_owner == JobTerminatingPipeline.__name__ + assert job.last_processed_at > last_processed_before + assert instance.lock_owner == _get_related_instance_lock_owner(other_job.id) + async def test_finishes_job_when_used_instance_is_not_set( self, test_db, session: AsyncSession, worker: JobTerminatingWorker ): @@ -699,4 +743,4 @@ async def test_keeps_related_instance_locked_on_processing_exception( assert 
job.lock_token == job_lock_token assert job.lock_owner == JobTerminatingPipeline.__name__ assert instance.lock_token == job_lock_token - assert instance.lock_owner == JobTerminatingPipeline.__name__ + assert instance.lock_owner == _get_related_instance_lock_owner(job.id) From 58efc6ecef4f3bf4807937cbfbc6cfb71e347b20 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 10 Mar 2026 12:14:08 +0500 Subject: [PATCH 11/15] Decrease min_processing_interval --- .../server/background/pipeline_tasks/jobs_terminating.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py index 75ae08c2a..61c8adaee 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py @@ -86,7 +86,7 @@ def __init__( workers_num: int = 10, queue_lower_limit_factor: float = 0.5, queue_upper_limit_factor: float = 2.0, - min_processing_interval: timedelta = timedelta(seconds=15), + min_processing_interval: timedelta = timedelta(seconds=5), lock_timeout: timedelta = timedelta(seconds=30), heartbeat_trigger: timedelta = timedelta(seconds=15), ) -> None: From 301f217730387bb9a05fea3edb42393f377ea253 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 10 Mar 2026 12:21:46 +0500 Subject: [PATCH 12/15] Rebase migration --- .../03_09_0928_6026b29d78c7_add_jobmodel_pipeline_columns.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/dstack/_internal/server/migrations/versions/2026/03_09_0928_6026b29d78c7_add_jobmodel_pipeline_columns.py b/src/dstack/_internal/server/migrations/versions/2026/03_09_0928_6026b29d78c7_add_jobmodel_pipeline_columns.py index b2898e833..84126aa1c 100644 --- a/src/dstack/_internal/server/migrations/versions/2026/03_09_0928_6026b29d78c7_add_jobmodel_pipeline_columns.py +++ 
b/src/dstack/_internal/server/migrations/versions/2026/03_09_0928_6026b29d78c7_add_jobmodel_pipeline_columns.py @@ -1,7 +1,7 @@ """Add JobModel pipeline columns Revision ID: 6026b29d78c7 -Revises: c7b0a8e57294 +Revises: a13f5b55af01 Create Date: 2026-03-09 09:28:17.993416+00:00 """ @@ -14,7 +14,7 @@ # revision identifiers, used by Alembic. revision = "6026b29d78c7" -down_revision = "c7b0a8e57294" +down_revision = "a13f5b55af01" branch_labels = None depends_on = None From a04dd167fdd8d16fc5c3ce3f6ae3ece3e5721371 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 10 Mar 2026 14:09:02 +0500 Subject: [PATCH 13/15] Add ix_jobs_pipeline_fetch_q index --- ...1b9a_add_ix_jobs_pipeline_fetch_q_index.py | 51 +++++++++++++++++++ src/dstack/_internal/server/models.py | 9 ++++ 2 files changed, 60 insertions(+) create mode 100644 src/dstack/_internal/server/migrations/versions/2026/03_10_1130_8b6d5d8c1b9a_add_ix_jobs_pipeline_fetch_q_index.py diff --git a/src/dstack/_internal/server/migrations/versions/2026/03_10_1130_8b6d5d8c1b9a_add_ix_jobs_pipeline_fetch_q_index.py b/src/dstack/_internal/server/migrations/versions/2026/03_10_1130_8b6d5d8c1b9a_add_ix_jobs_pipeline_fetch_q_index.py new file mode 100644 index 000000000..b6a1ca924 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/2026/03_10_1130_8b6d5d8c1b9a_add_ix_jobs_pipeline_fetch_q_index.py @@ -0,0 +1,51 @@ +"""Add ix_jobs_pipeline_fetch_q index + +Revision ID: 8b6d5d8c1b9a +Revises: 6026b29d78c7 +Create Date: 2026-03-10 11:30:00.000000+00:00 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "8b6d5d8c1b9a" +down_revision = "6026b29d78c7" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + with op.get_context().autocommit_block(): + op.drop_index( + "ix_jobs_pipeline_fetch_q", + table_name="jobs", + if_exists=True, + postgresql_concurrently=True, + ) + op.create_index( + "ix_jobs_pipeline_fetch_q", + "jobs", + [sa.literal_column("last_processed_at ASC")], + unique=False, + sqlite_where=sa.text("(status NOT IN ('TERMINATED', 'ABORTED', 'FAILED', 'DONE'))"), + postgresql_where=sa.text( + "(status NOT IN ('TERMINATED', 'ABORTED', 'FAILED', 'DONE'))" + ), + postgresql_concurrently=True, + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.get_context().autocommit_block(): + op.drop_index( + "ix_jobs_pipeline_fetch_q", + table_name="jobs", + if_exists=True, + postgresql_concurrently=True, + ) + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index 4272cef4d..3b4986b48 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -515,6 +515,15 @@ class JobModel(PipelineModelMixin, BaseModel): should be processed only one-by-one. 
""" + __table_args__ = ( + Index( + "ix_jobs_pipeline_fetch_q", + last_processed_at.asc(), + postgresql_where=status.not_in(JobStatus.finished_statuses()), + sqlite_where=status.not_in(JobStatus.finished_statuses()), + ), + ) + class GatewayModel(PipelineModelMixin, BaseModel): __tablename__ = "gateways" From 73f343b4f1bb6e8b7d4dab1952ba6f3d33449a97 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 10 Mar 2026 14:29:15 +0500 Subject: [PATCH 14/15] Add deprecated note --- .../server/background/scheduled_tasks/terminating_jobs.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py index 671a05b16..4cf63f2b7 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py @@ -60,6 +60,8 @@ logger = get_logger(__name__) +# NOTE: This scheduled task is going to be deprecated in favor of `JobTerminatingPipeline`. +# If this logic changes before removal, keep `pipeline_tasks/jobs_terminating.py` in sync. async def process_terminating_jobs(batch_size: int = 1): tasks = [] for _ in range(batch_size): From 87911bdfdb77569525088bcf8cfef6eab10a5744 Mon Sep 17 00:00:00 2001 From: Victor Skvortsov Date: Tue, 10 Mar 2026 15:08:31 +0500 Subject: [PATCH 15/15] Expand PIPELINES.md --- contributing/PIPELINES.md | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/contributing/PIPELINES.md b/contributing/PIPELINES.md index eefe607e0..89037bb6b 100644 --- a/contributing/PIPELINES.md +++ b/contributing/PIPELINES.md @@ -22,8 +22,27 @@ Related notes: * All write APIs must respect DB-level locks. The endpoints can either try to acquire the lock with a timeout and error or provide an async API by storing the request in the DB. 
+## Implementation checklist + +Brief checklist for implementing a new pipeline: + +1. Fetcher locks only rows that are ready for processing: +`status/time` filters, `lock_expires_at` is empty or expired, and `lock_owner` is empty or equal to the pipeline name. Keep the fetch order stable with `last_processed_at`. +2. Fetcher takes row locks with `skip_locked` and updates `lock_expires_at`, `lock_token`, `lock_owner` before enqueueing items. +3. Worker keeps heavy work outside DB sessions. DB sessions should be short and used only for refetch/locking and final apply. +4. Apply stage updates rows using update maps/update rows, not by relying on mutating detached ORM models. +5. Main apply update is guarded by `id + lock_token`. If the update affects `0` rows, the item is stale and processing results must not be applied. +6. Successful apply updates `last_processed_at` and unlocks resources that were locked by this item. +7. If related lock is unavailable, reset main lock for retry: keep `lock_owner`, clear `lock_token` and `lock_expires_at`, and set `last_processed_at` to now. +8. Register the pipeline in `PipelineManager` and hint fetch from services after commit via `pipeline_hinter.hint_fetch(Model.__name__)`. +9. Add minimum tests: fetch eligibility/order, successful unlock path, stale lock token path, and related lock contention retry path. + ## Implementation patterns +**Guarded apply by lock token** + +When writing processing results, update the main row with a filter by both `id` and `lock_token`. This guarantees that only the worker that still owns the lock can apply its results. If the update affects no rows, treat the item as stale and skip applying other changes (status changes, related updates, events). A stale item means another worker or replica already continued processing. + **Locking many related resources** A pipeline may need to lock a potentially big set of related resource, e.g. fleet pipeline locking all fleet's instances. 
For this, do one SELECT FOR UPDATE of non-locked instances and one SELECT to see how many instances there are, and check if you managed to lock all of them. If fail to lock, release the main lock and try processing on another fetch iteration. You may keep `lock_owner` on the main resource or set `lock_owner` on locked related resource and make other pipelines respect that to guarantee the eventual locking of all related resources and avoid lock starvation.
@@ -32,6 +51,18 @@ A pipeline may need to lock a potentially big set of relate
 **Locking a shared related resource**
 
 Multiple main resources may need to lock the same related resource, e.g. multiple jobs may need to change the shared instance. In this case it's not sufficient to set `lock_owner` on the related resource to the pipeline name because workers processing different main resources can still race with each other. To avoid heartbeating the related resource, you may include main resource id in `lock_owner`, e.g. set `lock_owner = f"{Pipeline.__name__}:{item.id}"`.
 
+**Reset-and-retry when related lock is unavailable**
+
+If a worker cannot lock a required related resource, it should release only the main lock state needed for fast retry: unset `lock_token` and `lock_expires_at`, keep `lock_owner`, and set `last_processed_at` to now. This avoids long waiting and lets the same pipeline retry quickly on the next fetch iteration while other pipelines can still respect ownership intent.
+
+**Dealing with side effects**
+
+If processing has side effects and the apply phase fails due to a lock mismatch, there are several options: a) revert side effects, b) make processing idempotent, i.e. the next processing iteration detects side effects and does not perform duplicate actions, or c) log side effects as errors and warn the user about possible issues such as orphaned instances – as a temporary solution.
+ +**Bulk apply with one consistent current time** + +When apply needs to update multiple rows (main + related resources), build update maps/update rows first and resolve current-time placeholders once in the apply transaction using `NOW_PLACEHOLDER` + `resolve_now_placeholders()`. This keeps timestamps consistent across all rows and avoids subtle ordering bugs when the same processing pass writes several `*_at` fields. + ## Performance analysis * Pipeline throughput = workers_num / worker_processing_time. So quick tasks easily give high-throughput pipelines, e.g. 1s task with 20 workers is 1200 tasks/min.