diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py index bfef8bc78..bd1307df7 100644 --- a/src/dstack/_internal/core/models/runs.py +++ b/src/dstack/_internal/core/models/runs.py @@ -95,6 +95,10 @@ class RunTerminationReason(str, Enum): SERVER_ERROR = "server_error" def to_job_termination_reason(self) -> "JobTerminationReason": + """ + Converts run termination reason to job termination reason. + Used to set job termination reason for non-terminated jobs on run termination. + """ mapping = { self.ALL_JOBS_DONE: JobTerminationReason.DONE_BY_RUNNER, self.JOB_FAILED: JobTerminationReason.TERMINATED_BY_SERVER, diff --git a/src/dstack/_internal/server/background/scheduled_tasks/runs.py b/src/dstack/_internal/server/background/scheduled_tasks/runs.py index e0c6793ce..56d9fea77 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/runs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/runs.py @@ -66,6 +66,10 @@ logger = get_logger(__name__) MIN_PROCESSING_INTERVAL = datetime.timedelta(seconds=5) + +# No need to lock finished or terminating jobs since run processing does not update such jobs. +JOB_STATUSES_EXCLUDED_FOR_LOCKING = JobStatus.finished_statuses() + [JobStatus.TERMINATING] + ROLLING_DEPLOYMENT_MAX_SURGE = 1 # at most one extra replica during rolling deployment @@ -121,10 +125,9 @@ async def _process_next_run(): ) .options( joinedload(RunModel.jobs).load_only(JobModel.id), - # No need to lock finished jobs with_loader_criteria( JobModel, - JobModel.status.not_in(JobStatus.finished_statuses()), + JobModel.status.not_in(JOB_STATUSES_EXCLUDED_FOR_LOCKING), include_aliases=True, ), ) @@ -146,7 +149,7 @@ async def _process_next_run(): load_only(JobModel.id), with_loader_criteria( JobModel, - JobModel.status.not_in(JobStatus.finished_statuses()), + JobModel.status.not_in(JOB_STATUSES_EXCLUDED_FOR_LOCKING), include_aliases=True, ), ) diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py index bf0f65bb6..4c094c676 100644 --- a/src/dstack/_internal/server/services/jobs/__init__.py +++ b/src/dstack/_internal/server/services/jobs/__init__.py @@ -31,6 +31,7 @@ JobSubmission, JobTerminationReason, RunSpec, + RunTerminationReason, ) from dstack._internal.core.models.volumes import Volume, VolumeMountPoint, VolumeStatus from dstack._internal.server import settings @@ -349,6 +350,7 @@ async def process_terminating_job( if len(volume_models) > 0: logger.info("Detaching volumes: %s", [v.name for v in volume_models]) all_volumes_detached = await _detach_volumes_from_job_instance( + session=session, project=instance_model.project, job_model=job_model, jpd=jpd, @@ -432,6 +434,7 @@ async def process_volumes_detaching( ) logger.info("Detaching volumes: %s", [v.name for v in volume_models]) all_volumes_detached = await _detach_volumes_from_job_instance( + session=session, project=instance_model.project, job_model=job_model, jpd=jpd, @@ -523,6 +526,7 @@ def group_jobs_by_replica_latest(jobs: List[JobModel]) -> Iterable[Tuple[int, Li async def _detach_volumes_from_job_instance( + session: AsyncSession, project: ProjectModel, job_model: JobModel, jpd: JobProvisioningData, @@ -542,6 +546,7 @@ async def _detach_volumes_from_job_instance( all_detached = True detached_volumes = [] + run_termination_reason = await _get_run_termination_reason(session, job_model) for volume_model in volume_models: detached = await _detach_volume_from_job_instance( backend=backend, @@ -550,6 +555,7 @@ async def _detach_volumes_from_job_instance( job_spec=job_spec, instance_model=instance_model, volume_model=volume_model, + run_termination_reason=run_termination_reason, ) if detached: detached_volumes.append(volume_model) @@ -572,6 +578,7 @@ async def _detach_volume_from_job_instance( job_spec: JobSpec, instance_model: InstanceModel, volume_model: VolumeModel, + run_termination_reason: Optional[RunTerminationReason], ) -> bool: detached = True volume = volume_model_to_volume(volume_model) @@ -601,7 +608,11 @@ async def _detach_volume_from_job_instance( volume=volume, provisioning_data=jpd, ) - if not detached and _should_force_detach_volume(job_model, job_spec.stop_duration): + if not detached and _should_force_detach_volume( + job_model, + run_termination_reason=run_termination_reason, + stop_duration=job_spec.stop_duration, + ): logger.info( "Force detaching volume %s from %s", volume_model.name, @@ -633,13 +644,27 @@ async def _detach_volume_from_job_instance( MIN_FORCE_DETACH_WAIT_PERIOD = timedelta(seconds=60) -def _should_force_detach_volume(job_model: JobModel, stop_duration: Optional[int]) -> bool: +async def _get_run_termination_reason( + session: AsyncSession, job_model: JobModel +) -> Optional[RunTerminationReason]: + res = await session.execute( + select(RunModel.termination_reason).where(RunModel.id == job_model.run_id) + ) + return res.scalar_one_or_none() + + +def _should_force_detach_volume( + job_model: JobModel, + run_termination_reason: Optional[RunTerminationReason], + stop_duration: Optional[int], +) -> bool: return ( job_model.volumes_detached_at is not None and common.get_current_datetime() > job_model.volumes_detached_at + MIN_FORCE_DETACH_WAIT_PERIOD and ( job_model.termination_reason == JobTerminationReason.ABORTED_BY_USER + or run_termination_reason == RunTerminationReason.ABORTED_BY_USER or stop_duration is not None and common.get_current_datetime() > job_model.volumes_detached_at + timedelta(seconds=stop_duration) diff --git a/src/dstack/_internal/server/services/runs/__init__.py b/src/dstack/_internal/server/services/runs/__init__.py index 73966916b..f8aa3f288 100644 --- a/src/dstack/_internal/server/services/runs/__init__.py +++ b/src/dstack/_internal/server/services/runs/__init__.py @@ -1003,10 +1003,6 @@ async def process_terminating_run(session: AsyncSession, run_model: RunModel): continue unfinished_jobs_count += 1 if job_model.status == JobStatus.TERMINATING: - if job_termination_reason == JobTerminationReason.ABORTED_BY_USER: - # Override termination reason so that - # abort actions such as volume force detach are triggered - job_model.termination_reason = job_termination_reason continue if job_model.status == JobStatus.RUNNING and job_termination_reason not in { diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 2893418a0..fdf4251bf 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -305,7 +305,7 @@ async def create_run( repo: RepoModel, user: UserModel, fleet: Optional[FleetModel] = None, - run_name: str = "test-run", + run_name: Optional[str] = None, status: RunStatus = RunStatus.SUBMITTED, termination_reason: Optional[RunTerminationReason] = None, submitted_at: datetime = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), @@ -317,6 +317,8 @@ async def create_run( resubmission_attempt: int = 0, next_triggered_at: Optional[datetime] = None, ) -> RunModel: + if run_name is None: + run_name = "test-run" if run_spec is None: run_spec = get_run_spec( run_name=run_name,