diff --git a/contributing/LOCKING.md b/contributing/LOCKING.md index e23fb41f9..c89ad526a 100644 --- a/contributing/LOCKING.md +++ b/contributing/LOCKING.md @@ -1,7 +1,6 @@ # Locking -The `dstack` server supports SQLite and Postgres databases -with two implementations of resource locking to handle concurrent access: +The `dstack` server supports SQLite and Postgres databases with two implementations of resource locking to handle concurrent access: * In-memory locking for SQLite. * DB-level locking for Postgres. @@ -34,11 +33,11 @@ There are few places that rely on advisory locks as when generating unique resou ## Working with locks -Concurrency is hard. Below you'll find common patterns and gotchas when working with locks to make it a bit more manageable. +Concurrency is hard. Concurrency with locking is especially hard. Below you'll find common patterns and gotchas when working with locks to make it a bit more manageable. **A task should acquire locks on resources it modifies** -This is a common sense approach. An alternative could be the inverse: job processing cannot run in parallel with run processing, so job processing takes run lock. This indirection complicates things and is discouraged. In this example, run processing should take job lock instead. +This is common sense. An alternative could be the inverse: job processing cannot run in parallel with run processing, so job processing takes run lock. This indirection complicates things and is discouraged. In this example, run processing should take job lock instead. **Start new transaction after acquiring a lock to see other transactions changes in SQLite.** @@ -75,15 +74,19 @@ unlock resources If a transaction releases a lock before committing changes, the changes may not be visible to another transaction that acquired the lock and relies upon seeing all committed changes. -**Don't use `joinedload` when selecting `.with_for_update()`** +**Using `joinedload` when selecting `.with_for_update()`** -In fact, using `joinedload` and `.with_for_update()` will trigger an error because `joinedload` produces OUTER LEFT JOIN that cannot be used with SELECT FOR UPDATE. A regular `.join()` can be used to lock related resources but it may lead to no rows if there is no row to join. Usually, you'd select with `selectinload` or first select with `.with_for_update()` without loading related attributes and then re-selecting with `joinedload` without `.with_for_update()`. +Using `joinedload` and `.with_for_update()` triggers an error in case of no related rows because `joinedload` produces OUTER LEFT JOIN and SELECT FOR UPDATE cannot be applied to the nullable side of an OUTER JOIN. Here's the options: + +* Use `.with_for_update(of=MainModel)`. +* Select with `selectinload` +* First select with `.with_for_update()` without loading related attributes and then re-select with `joinedload` without `.with_for_update()`. +* Use regular `.join()` to lock related resources, but you may get 0 rows if there is no related row to join. **Always use `.with_for_update(key_share=True)` unless you plan to delete rows or update a primary key column** If you `SELECT FOR UPDATE` from a table that is referenced in a child table via a foreign key, it can lead to deadlocks if the child table is updated because Postgres will issue a `FOR KEY SHARE` lock on the parent table rows to ensure valid foreign keys. For this reason, you should always do `SELECT FOR NO KEY UPDATE` (.`with_for_update(key_share=True)`) if primary key columns are not modified. `SELECT FOR NO KEY UPDATE` is not blocked by a `FOR KEY SHARE` lock, so no deadlock. - **Lock unique names** The following pattern can be used to lock a unique name of some resource type: diff --git a/contributing/PIPELINES.md b/contributing/PIPELINES.md new file mode 100644 index 000000000..89037bb6b --- /dev/null +++ b/contributing/PIPELINES.md @@ -0,0 +1,74 @@ +# Pipelines + +This document describes how the `dstack` server implements background processing via so-called "pipelines". + +*Historical context: `dstack` used to do all background processing via scheduled tasks. A scheduled task would process a specific resource type like volumes or runs by keeping DB transaction open for the entire processing duration and keeping the resource lock with SELECT FOR UPDATE (or in-memory lock on SQLite). This approach didn't scale well because the number of DB connections was a huge bottleneck. Pipelines replaced scheduled tasks: the do all the heavy processing outside of DB transactions and write locks to DB columns.* + +## Overview + +* Resources are continuously processed in the background by pipelines. A pipeline consists of a fetcher, workers, and a heartbeater. +* A fetcher selects rows to be processed from the DB, marks them as locked in the DB, and puts them into an in-memory queue. +* Workers consume rows from the in-memory queue, process the rows, and unlock them. +* The locking (unlocking) is done by setting (unsetting) `lock_expires_at`, `lock_token`, `lock_owner`. +* If the replica/pipeline dies, the rows stay locked in the db. Another replica picks up the rows after `lock_expires_at`. +* `lock_token` prevents stale replica/pipeline to update the rows already picked up by the new replica. +* `lock_owner` stores the pipeline that's locked the row so that only that pipeline can recover if it's stale. +* A heartbeater tracks all rows in the pipeline (in the queue or in processing), and updates the lock expiration. This allows setting small `lock_expires_at` and picking up stale rows quickly +* A fetcher performs the fetch when the queue size goes under a configured lower limit. It has exponential retry delays between empty fetches, thus reducing load on the DB. +* There is a fetch hint mechanism that services can use to notify the pipelines within the replica – in that case the fetcher stops sleeping and fetches immediately. +* Each pipeline locks one main resource but may lock related resources as well. It's not necessary to heartbeat related resources if the pipeline ensures no one else can re-lock them. This is typically done via setting and respecting `lock_owner`. + +Related notes: + +* All write APIs must respect DB-level locks. The endpoints can either try to acquire the lock with a timeout and error or provide an async API by storing the request in the DB. + +## Implementation checklist + +Brief checklist for implementing a new pipeline: + +1. Fetcher locks only rows that are ready for processing: +`status/time` filters, `lock_expires_at` is empty or expired, and `lock_owner` is empty or equal to the pipeline name. Keep the fetch order stable with `last_processed_at`. +2. Fetcher takes row locks with `skip_locked` and updates `lock_expires_at`, `lock_token`, `lock_owner` before enqueueing items. +3. Worker keeps heavy work outside DB sessions. DB sessions should be short and used only for refetch/locking and final apply. +4. Apply stage updates rows using update maps/update rows, not by relying on mutating detached ORM models. +5. Main apply update is guarded by `id + lock_token`. If the update affects `0` rows, the item is stale and processing results must not be applied. +6. Successful apply updates `last_processed_at` and unlocks resources that were locked by this item. +7. If related lock is unavailable, reset main lock for retry: keep `lock_owner`, clear `lock_token` and `lock_expires_at`, and set `last_processed_at` to now. +8. Register the pipeline in `PipelineManager` and hint fetch from services after commit via `pipeline_hinter.hint_fetch(Model.__name__)`. +9. Add minimum tests: fetch eligibility/order, successful unlock path, stale lock token path, and related lock contention retry path. + +## Implementation patterns + +**Guarded apply by lock token** + +When writing processing results, update the main row with a filter by both `id` and `lock_token`. This guarantees that only the worker that still owns the lock can apply its results. If the update affects no rows, treat the item as stale and skip applying other changes (status changes, related updates, events). A stale item means another worker or replica already continued processing. + +**Locking many related resources** + +A pipeline may need to lock a potentially big set of related resource, e.g. fleet pipeline locking all fleet's instances. For this, do one SELECT FOR UPDATE of non-locked instances and one SELECT to see how many instances there are, and check if you managed to lock all of them. If fail to lock, release the main lock and try processing on another fetch iteration. You may keep `lock_owner` on the main resource or set `lock_owner` on locked related resource and make other pipelines respect that to guarantee the eventual locking of all related resources and avoid lock starvation. + +**Locking a shared related resource** + +Multiple main resources may need to lock the same related resource, e.g. multiple jobs may need to change the shared instance. In this case it's not sufficient to set `lock_owner` on the related resource to the pipeline name because workers processing different main resources can still race with each other. To avoid heartbeating the related resource, you may include main resource id in `lock_owner`, e.g. set `lock_owner = f"{Pipeline.__name__}:{item.id}"`. + +**Reset-and-retry when related lock is unavailable** + +If a worker cannot lock a required related resource, it should release only the main lock state needed for fast retry: unset `lock_token` and `lock_expires_at`, keep `lock_owner`, and set `last_processed_at` to now. This avoids long waiting and lets the same pipeline retry quickly on the next fetch iteration while other pipelines can still respect ownership intent. + +**Dealing with side effects** + +If processing has side effects and the apply phase fails due to a lock mismatch, there are several options: a) revert side effects b) make processing idempotent, i.e. next processing iteration detects side effects does not perform duplicating actions c) log side effects as errors and warn user about possible issues such as orphaned instances – as a temporary solution. + +**Bulk apply with one consistent current time** + +When apply needs to update multiple rows (main + related resources), build update maps/update rows first and resolve current-time placeholders once in the apply transaction using `NOW_PLACEHOLDER` + `resolve_now_placeholders()`. This keeps timestamps consistent across all rows and avoids subtle ordering bugs when the same processing pass writes several `*_at` fields. + +## Performance analysis + +* Pipeline throughput = workers_num / worker_processing_time. So quick tasks easily give high-throughput pipelines, e.g. 1s task with 20 workers is 1200 tasks/min. +A slow 30s task gives only 40 tasks/min with the same number of workers. We can increase the number of workers but the peak memory usage will grow proportionally. +In general, workers should be optimized to be as quick as possible to improve throughput. +* Processing latency (wait) is close to 0 due to fetch hints if the pipeline is not saturated. In general, latency = queue_size / throughput. +* In-memory queue maxsize provides a cap on memory usage and recovery time after crashes (number of locked items to retry). +* Fetcher's DB load is proportional to the number of pipelines and is expected to be negligible. Workers can put a considerable read/write DB load as it's proportional to the number of workers. This can be optimized by batching workers' writes. Workers do processing outside of transactions so DB connections won't be a bottleneck. +* There is a risk of lock starvation if a worker needs to lock all related resources. This is to be mitigated by 1) related pipelines checking `lock_owner` and skip locking to let the parent pipeline acquire all the locks eventually and 2) do the related resource locking only on paths that require it. diff --git a/src/dstack/_internal/server/background/pipeline_tasks/__init__.py b/src/dstack/_internal/server/background/pipeline_tasks/__init__.py index 556e13daa..7b2d79047 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/__init__.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/__init__.py @@ -5,6 +5,9 @@ from dstack._internal.server.background.pipeline_tasks.fleets import FleetPipeline from dstack._internal.server.background.pipeline_tasks.gateways import GatewayPipeline from dstack._internal.server.background.pipeline_tasks.instances import InstancePipeline +from dstack._internal.server.background.pipeline_tasks.jobs_terminating import ( + JobTerminatingPipeline, +) from dstack._internal.server.background.pipeline_tasks.placement_groups import ( PlacementGroupPipeline, ) @@ -20,6 +23,7 @@ def __init__(self) -> None: ComputeGroupPipeline(), FleetPipeline(), GatewayPipeline(), + JobTerminatingPipeline(), InstancePipeline(), PlacementGroupPipeline(), VolumePipeline(), diff --git a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py index 2a63e21bd..3065c1e09 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/fleets.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/fleets.py @@ -152,7 +152,7 @@ async def fetch(self, limit: int) -> list[PipelineItem]: ) .order_by(FleetModel.last_processed_at.asc()) .limit(limit) - .with_for_update(skip_locked=True, key_share=True) + .with_for_update(skip_locked=True, key_share=True, of=FleetModel) .options( load_only( FleetModel.id, @@ -352,7 +352,7 @@ async def _lock_fleet_instances_for_consolidation( InstanceModel.lock_owner == FleetPipeline.__name__, ), ) - .with_for_update(skip_locked=True, key_share=True) + .with_for_update(skip_locked=True, key_share=True, of=InstanceModel) ) locked_instance_models = list(res.scalars().all()) locked_instance_ids = {instance_model.id for instance_model in locked_instance_models} @@ -369,7 +369,7 @@ async def _lock_fleet_instances_for_consolidation( "Failed to lock fleet %s instances. The fleet will be processed later.", item.id, ) - # Keep `lock_owner` so that `InstancePipeline` sees that the fleet is being locked + # Keep `lock_owner` so that `InstancePipeline` can check that the fleet is being locked # but unset `lock_expires_at` to process the item again ASAP (after `min_processing_interval`). # Unset `lock_token` so that heartbeater can no longer update the item. res = await session.execute( diff --git a/src/dstack/_internal/server/background/pipeline_tasks/gateways.py b/src/dstack/_internal/server/background/pipeline_tasks/gateways.py index 81ba2ae70..a051c5a96 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/gateways.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/gateways.py @@ -152,7 +152,7 @@ async def fetch(self, limit: int) -> list[GatewayPipelineItem]: ) .order_by(GatewayModel.last_processed_at.asc()) .limit(limit) - .with_for_update(skip_locked=True, key_share=True) + .with_for_update(skip_locked=True, key_share=True, of=GatewayModel) .options( load_only( GatewayModel.id, diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py new file mode 100644 index 000000000..61c8adaee --- /dev/null +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_terminating.py @@ -0,0 +1,923 @@ +import asyncio +import uuid +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from typing import Optional, Sequence, TypedDict + +import httpx +from sqlalchemy import delete, or_, select, update +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import joinedload, load_only + +from dstack._internal.core.backends.base.backend import Backend +from dstack._internal.core.backends.base.compute import ComputeWithVolumeSupport +from dstack._internal.core.consts import DSTACK_SHIM_HTTP_PORT +from dstack._internal.core.errors import BackendError, GatewayError, SSHError +from dstack._internal.core.models.instances import InstanceStatus, InstanceTerminationReason +from dstack._internal.core.models.runs import ( + JobProvisioningData, + JobRuntimeData, + JobSpec, + JobStatus, + JobTerminationReason, + RunTerminationReason, +) +from dstack._internal.server import settings +from dstack._internal.server.background.pipeline_tasks.base import ( + NOW_PLACEHOLDER, + Fetcher, + Heartbeater, + ItemUpdateMap, + Pipeline, + PipelineItem, + UpdateMapDateTime, + Worker, + log_lock_token_changed_after_processing, + log_lock_token_changed_on_reset, + log_lock_token_mismatch, + resolve_now_placeholders, + set_processed_update_map_fields, + set_unlock_update_map_fields, +) +from dstack._internal.server.db import get_db, get_session_ctx +from dstack._internal.server.models import ( + InstanceModel, + JobModel, + ProjectModel, + RunModel, + VolumeAttachmentModel, + VolumeModel, +) +from dstack._internal.server.services import backends as backends_services +from dstack._internal.server.services import events +from dstack._internal.server.services.gateways import get_or_add_gateway_connection +from dstack._internal.server.services.instances import ( + emit_instance_status_change_event, + get_instance_ssh_private_keys, +) +from dstack._internal.server.services.jobs import ( + emit_job_status_change_event, + get_job_provisioning_data, + get_job_runtime_data, +) +from dstack._internal.server.services.locking import get_locker +from dstack._internal.server.services.logging import fmt +from dstack._internal.server.services.runner import client +from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel +from dstack._internal.server.services.volumes import ( + volume_model_to_volume, +) +from dstack._internal.server.utils import sentry_utils +from dstack._internal.utils import common +from dstack._internal.utils.common import get_current_datetime, get_or_error +from dstack._internal.utils.logging import get_logger + +logger = get_logger(__name__) + + +@dataclass +class JobTerminatingPipelineItem(PipelineItem): + volumes_detached_at: Optional[datetime] + + +class JobTerminatingPipeline(Pipeline[JobTerminatingPipelineItem]): + def __init__( + self, + workers_num: int = 10, + queue_lower_limit_factor: float = 0.5, + queue_upper_limit_factor: float = 2.0, + min_processing_interval: timedelta = timedelta(seconds=5), + lock_timeout: timedelta = timedelta(seconds=30), + heartbeat_trigger: timedelta = timedelta(seconds=15), + ) -> None: + super().__init__( + workers_num=workers_num, + queue_lower_limit_factor=queue_lower_limit_factor, + queue_upper_limit_factor=queue_upper_limit_factor, + min_processing_interval=min_processing_interval, + lock_timeout=lock_timeout, + heartbeat_trigger=heartbeat_trigger, + ) + self.__heartbeater = Heartbeater[JobTerminatingPipelineItem]( + model_type=JobModel, + lock_timeout=self._lock_timeout, + heartbeat_trigger=self._heartbeat_trigger, + ) + self.__fetcher = JobTerminatingFetcher( + queue=self._queue, + queue_desired_minsize=self._queue_desired_minsize, + min_processing_interval=self._min_processing_interval, + lock_timeout=self._lock_timeout, + heartbeater=self._heartbeater, + ) + self.__workers = [ + JobTerminatingWorker(queue=self._queue, heartbeater=self._heartbeater) + for _ in range(self._workers_num) + ] + + @property + def hint_fetch_model_name(self) -> str: + return JobModel.__name__ + + @property + def _heartbeater(self) -> Heartbeater[JobTerminatingPipelineItem]: + return self.__heartbeater + + @property + def _fetcher(self) -> Fetcher[JobTerminatingPipelineItem]: + return self.__fetcher + + @property + def _workers(self) -> Sequence["JobTerminatingWorker"]: + return self.__workers + + +class JobTerminatingFetcher(Fetcher[JobTerminatingPipelineItem]): + def __init__( + self, + queue: asyncio.Queue[JobTerminatingPipelineItem], + queue_desired_minsize: int, + min_processing_interval: timedelta, + lock_timeout: timedelta, + heartbeater: Heartbeater[JobTerminatingPipelineItem], + queue_check_delay: float = 1.0, + ) -> None: + super().__init__( + queue=queue, + queue_desired_minsize=queue_desired_minsize, + min_processing_interval=min_processing_interval, + lock_timeout=lock_timeout, + heartbeater=heartbeater, + queue_check_delay=queue_check_delay, + ) + + @sentry_utils.instrument_named_task("pipeline_tasks.JobTerminatingFetcher.fetch") + async def fetch(self, limit: int) -> list[JobTerminatingPipelineItem]: + job_lock, _ = get_locker(get_db().dialect_name).get_lockset(JobModel.__tablename__) + async with job_lock: + async with get_session_ctx() as session: + now = get_current_datetime() + res = await session.execute( + select(JobModel) + .where( + JobModel.status == JobStatus.TERMINATING, + or_( + JobModel.remove_at.is_(None), + JobModel.remove_at < now, + ), + JobModel.last_processed_at <= now - self._min_processing_interval, + or_( + JobModel.lock_expires_at.is_(None), + JobModel.lock_expires_at < now, + ), + or_( + JobModel.lock_owner.is_(None), + JobModel.lock_owner == JobTerminatingPipeline.__name__, + ), + ) + .order_by(JobModel.last_processed_at.asc()) + .limit(limit) + .with_for_update(skip_locked=True, key_share=True, of=JobModel) + .options( + load_only( + JobModel.id, + JobModel.lock_token, + JobModel.lock_expires_at, + JobModel.volumes_detached_at, + ) + ) + ) + job_models = list(res.scalars().all()) + lock_expires_at = get_current_datetime() + self._lock_timeout + lock_token = uuid.uuid4() + items = [] + for job_model in job_models: + prev_lock_expired = job_model.lock_expires_at is not None + job_model.lock_expires_at = lock_expires_at + job_model.lock_token = lock_token + job_model.lock_owner = JobTerminatingPipeline.__name__ + items.append( + JobTerminatingPipelineItem( + __tablename__=JobModel.__tablename__, + id=job_model.id, + lock_expires_at=lock_expires_at, + lock_token=lock_token, + prev_lock_expired=prev_lock_expired, + volumes_detached_at=job_model.volumes_detached_at, + ) + ) + await session.commit() + return items + + +class JobTerminatingWorker(Worker[JobTerminatingPipelineItem]): + def __init__( + self, + queue: asyncio.Queue[JobTerminatingPipelineItem], + heartbeater: Heartbeater[JobTerminatingPipelineItem], + ) -> None: + super().__init__( + queue=queue, + heartbeater=heartbeater, + ) + + @sentry_utils.instrument_named_task("pipeline_tasks.JobTerminatingWorker.process") + async def process(self, item: JobTerminatingPipelineItem): + async with get_session_ctx() as session: + job_model = await _refetch_locked_job(session=session, item=item) + if job_model is None: + log_lock_token_mismatch(logger, item) + return + + instance_model: Optional[InstanceModel] = None + if job_model.used_instance_id is not None: + instance_model = await _lock_related_instance( + session=session, + item=item, + instance_id=job_model.used_instance_id, + ) + if instance_model is None: + await _reset_job_lock_for_retry(session=session, item=item) + return + + if job_model.volumes_detached_at is None: + result = await _process_terminating_job( + job_model=job_model, + instance_model=instance_model, + ) + else: + result = await _process_job_volumes_detaching( + job_model=job_model, + instance_model=get_or_error(instance_model), + ) + + set_processed_update_map_fields(result.job_update_map) + set_unlock_update_map_fields(result.job_update_map) + if instance_model is not None: + if result.instance_update_map is None: + result.instance_update_map = _InstanceUpdateMap() + instance_update_map = result.instance_update_map + set_processed_update_map_fields(instance_update_map) + set_unlock_update_map_fields(instance_update_map) + await _apply_process_result( + item=item, + job_model=job_model, + instance_model=instance_model, + result=result, + ) + + +class _JobUpdateMap(ItemUpdateMap, total=False): + status: JobStatus + termination_reason: Optional[JobTerminationReason] + termination_reason_message: Optional[str] + instance_id: Optional[uuid.UUID] + volumes_detached_at: UpdateMapDateTime + registered: bool + + +class _InstanceUpdateMap(ItemUpdateMap, total=False): + status: InstanceStatus + termination_reason: Optional[InstanceTerminationReason] + termination_reason_message: Optional[str] + busy_blocks: int + last_job_processed_at: UpdateMapDateTime + + +class _VolumeUpdateRow(TypedDict): + id: uuid.UUID + last_job_processed_at: UpdateMapDateTime + + +@dataclass +class _ProcessResult: + job_update_map: _JobUpdateMap = field(default_factory=_JobUpdateMap) + instance_update_map: Optional[_InstanceUpdateMap] = None + volume_update_rows: list[_VolumeUpdateRow] = field(default_factory=list) + detached_volume_ids: set[uuid.UUID] = field(default_factory=set) + unassign_event_message: Optional[str] = None + emit_unregister_replica_event: bool = False + unregister_gateway_target: Optional[events.Target] = None + + +@dataclass +class _VolumeDetachResult: + all_detached: bool + detached_volume_ids: set[uuid.UUID] = field(default_factory=set) + set_volumes_detached_at: bool = False + + +async def _refetch_locked_job( + session: AsyncSession, item: JobTerminatingPipelineItem +) -> Optional[JobModel]: + res = await session.execute( + select(JobModel) + .where( + JobModel.id == item.id, + JobModel.lock_token == item.lock_token, + ) + .options( + joinedload(JobModel.run).load_only( + RunModel.id, + RunModel.project_id, + RunModel.run_name, + RunModel.gateway_id, + RunModel.termination_reason, + ), + joinedload(JobModel.run) + .joinedload(RunModel.project) + .load_only(ProjectModel.id, ProjectModel.name), + ) + .execution_options(populate_existing=True) + ) + return res.unique().scalar_one_or_none() + + +async def _lock_related_instance( + session: AsyncSession, + item: JobTerminatingPipelineItem, + instance_id: uuid.UUID, +) -> Optional[InstanceModel]: + lock_owner = _get_related_instance_lock_owner(item.id) + instance_lock, _ = get_locker(get_db().dialect_name).get_lockset(InstanceModel.__tablename__) + async with instance_lock: + res = await session.execute( + select(InstanceModel) + .where( + InstanceModel.id == instance_id, + or_( + InstanceModel.lock_expires_at.is_(None), + InstanceModel.lock_expires_at < get_current_datetime(), + ), + or_( + InstanceModel.lock_owner.is_(None), + InstanceModel.lock_owner == lock_owner, + ), + ) + .options(joinedload(InstanceModel.project).joinedload(ProjectModel.backends)) + .options( + joinedload(InstanceModel.volume_attachments).joinedload( + VolumeAttachmentModel.volume + ) + ) + .options(joinedload(InstanceModel.jobs).load_only(JobModel.id)) + .with_for_update(skip_locked=True, key_share=True, of=InstanceModel) + ) + instance_model = res.unique().scalar_one_or_none() + if instance_model is None: + return None + instance_model.lock_expires_at = item.lock_expires_at + instance_model.lock_token = item.lock_token + instance_model.lock_owner = lock_owner + return instance_model + + +async def _load_job_volume_models( + job_model: JobModel, + instance_model: Optional[InstanceModel], +) -> list[VolumeModel]: + if instance_model is None: + return [] + jrd = get_job_runtime_data(job_model) + volume_names = ( + jrd.volume_names + if jrd and jrd.volume_names + else [va.volume.name for va in instance_model.volume_attachments] + ) + if len(volume_names) == 0: + return [] + async with get_session_ctx() as session: + res = await session.execute( + select(VolumeModel) + .where( + VolumeModel.project_id == instance_model.project.id, + VolumeModel.name.in_(volume_names), + VolumeModel.deleted == False, + ) + .options(joinedload(VolumeModel.project)) + .options(joinedload(VolumeModel.user)) + .options( + joinedload(VolumeModel.attachments) + .joinedload(VolumeAttachmentModel.instance) + .joinedload(InstanceModel.fleet) + ) + ) + return list(res.unique().scalars().all()) + + +async def _reset_job_lock_for_retry(session: AsyncSession, item: JobTerminatingPipelineItem): + res = await session.execute( + update(JobModel) + .where( + JobModel.id == item.id, + JobModel.lock_token == item.lock_token, + ) + # Keep `lock_owner` so that `InstancePipeline` can check that the job is being locked + # but unset `lock_expires_at` to process the item again ASAP (after `min_processing_interval`). + # Unset `lock_token` so that heartbeater can no longer update the item. + .values( + lock_expires_at=None, + lock_token=None, + last_processed_at=get_current_datetime(), + ) + .returning(JobModel.id) + ) + updated_ids = list(res.scalars().all()) + if len(updated_ids) == 0: + log_lock_token_changed_on_reset(logger) + + +async def _apply_process_result( + item: JobTerminatingPipelineItem, + job_model: JobModel, + instance_model: Optional[InstanceModel], + result: _ProcessResult, +) -> None: + async with get_session_ctx() as session: + now = get_current_datetime() + related_instance_lock_owner = _get_related_instance_lock_owner(item.id) + instance_update_map = result.instance_update_map + if instance_model is None: + instance_update_map = None + resolve_now_placeholders(result.job_update_map, now=now) + if instance_update_map is not None: + resolve_now_placeholders(instance_update_map, now=now) + if result.volume_update_rows: + resolve_now_placeholders(result.volume_update_rows, now=now) + + res = await session.execute( + update(JobModel) + .where( + JobModel.id == item.id, + JobModel.lock_token == item.lock_token, + ) + .values(**result.job_update_map) + .returning(JobModel.id) + ) + updated_ids = list(res.scalars().all()) + if len(updated_ids) == 0: + log_lock_token_changed_after_processing(logger, item) + if instance_model is not None: + await _unlock_related_instance( + session=session, + item=item, + instance_id=instance_model.id, + ) + return + + if instance_model is not None and instance_update_map is not None: + res = await session.execute( + update(InstanceModel) + .where( + InstanceModel.id == instance_model.id, + InstanceModel.lock_token == item.lock_token, + InstanceModel.lock_owner == related_instance_lock_owner, + ) + .values(**instance_update_map) + .returning(InstanceModel.id) + ) + updated_ids = list(res.scalars().all()) + if len(updated_ids) == 0: + logger.error( + "Failed to update related instance %s for terminating job %s.", + instance_model.id, + item.id, + ) + + if result.volume_update_rows: + await session.execute(update(VolumeModel), result.volume_update_rows) + + if result.detached_volume_ids and instance_model is not None: + await session.execute( + delete(VolumeAttachmentModel).where( + VolumeAttachmentModel.instance_id == instance_model.id, + VolumeAttachmentModel.volume_id.in_(result.detached_volume_ids), + ) + ) + + emit_job_status_change_event( + session=session, + job_model=job_model, + old_status=job_model.status, + new_status=result.job_update_map.get("status", job_model.status), + termination_reason=result.job_update_map.get( + "termination_reason", job_model.termination_reason + ), + termination_reason_message=result.job_update_map.get( + "termination_reason_message", + job_model.termination_reason_message, + ), + ) + + if instance_model is not None and instance_update_map is not None: + emit_instance_status_change_event( + session=session, + instance_model=instance_model, + old_status=instance_model.status, + new_status=instance_update_map.get("status", instance_model.status), + termination_reason=instance_update_map.get( + "termination_reason", + instance_model.termination_reason, + ), + termination_reason_message=instance_update_map.get( + "termination_reason_message", + instance_model.termination_reason_message, + ), + ) + + if result.unassign_event_message is not None and instance_model is not None: + events.emit( + session, + result.unassign_event_message, + actor=events.SystemActor(), + targets=[ + events.Target.from_model(job_model), + events.Target.from_model(instance_model), + ], + ) + + if result.emit_unregister_replica_event: + targets = [events.Target.from_model(job_model)] + if result.unregister_gateway_target is not None: + targets.append(result.unregister_gateway_target) + events.emit( + session, + "Service replica unregistered from receiving requests", + actor=events.SystemActor(), + targets=targets, + ) + + +async def _unlock_related_instance( + session: AsyncSession, + item: JobTerminatingPipelineItem, + instance_id: uuid.UUID, +) -> None: + await session.execute( + update(InstanceModel) + .where( + InstanceModel.id == instance_id, + InstanceModel.lock_token == item.lock_token, + InstanceModel.lock_owner == _get_related_instance_lock_owner(item.id), + ) + .values( + lock_expires_at=None, + lock_token=None, + lock_owner=None, + ) + ) + + +async def _process_terminating_job( + job_model: JobModel, + instance_model: Optional[InstanceModel], +) -> _ProcessResult: + """ + Stops the job: tells shim to stop the container, detaches the job from the instance, + and detaches volumes from the instance. + Graceful stop should already be done by `process_terminating_run`. + """ + instance_update_map = None if instance_model is None else _InstanceUpdateMap() + result = _ProcessResult(instance_update_map=instance_update_map) + + if instance_model is None: + await _unregister_replica_and_update_result(result=result, job_model=job_model) + result.job_update_map["status"] = _get_job_termination_status(job_model) + return result + + jrd = get_job_runtime_data(job_model) + jpd = get_job_provisioning_data(job_model) + if jpd is not None: + logger.debug("%s: stopping container", fmt(job_model)) + ssh_private_keys = get_instance_ssh_private_keys(instance_model) + if not await _stop_container(job_model, jpd, ssh_private_keys): + logger.warning( + ( + "%s: could not stop container, possibly due to a communication error." + " See debug logs for details." + " Ignoring, can attempt to remove the container later" + ), + fmt(job_model), + ) + + ( + result.volume_update_rows, + detach_result, + ) = await _detach_job_volumes( + job_model=job_model, + instance_model=instance_model, + job_provisioning_data=jpd, + ) + result.detached_volume_ids = detach_result.detached_volume_ids + if detach_result.set_volumes_detached_at: + result.job_update_map["volumes_detached_at"] = NOW_PLACEHOLDER + + instance_update_map = get_or_error(result.instance_update_map) + busy_blocks = instance_model.busy_blocks - _get_job_occupied_blocks(jrd) + instance_update_map["busy_blocks"] = busy_blocks + if instance_model.status != InstanceStatus.BUSY or jpd is None or not jpd.dockerized: + if instance_model.status not in InstanceStatus.finished_statuses(): + instance_update_map["termination_reason"] = InstanceTerminationReason.JOB_FINISHED + instance_update_map["status"] = InstanceStatus.TERMINATING + elif not [j for j in instance_model.jobs if j.id != job_model.id]: + instance_update_map["status"] = InstanceStatus.IDLE + + result.job_update_map["instance_id"] = None + instance_update_map["last_job_processed_at"] = NOW_PLACEHOLDER + result.unassign_event_message = ( + "Job unassigned from instance." + f" Instance blocks: {busy_blocks}/{instance_model.total_blocks} busy" + ) + + await _unregister_replica_and_update_result(result=result, job_model=job_model) + if detach_result.all_detached: + result.job_update_map["status"] = _get_job_termination_status(job_model) + return result + + +async def _process_job_volumes_detaching( + job_model: JobModel, + instance_model: InstanceModel, +) -> _ProcessResult: + """ + Called after job's volumes have been soft detached to check if they are detached. + Terminates the job when all the volumes are detached. + If the volumes fail to detach, force detaches them. + """ + result = _ProcessResult(instance_update_map=_InstanceUpdateMap()) + jpd = get_or_error(get_job_provisioning_data(job_model)) + ( + result.volume_update_rows, + detach_result, + ) = await _detach_job_volumes( + job_model=job_model, + instance_model=instance_model, + job_provisioning_data=jpd, + ) + result.detached_volume_ids = detach_result.detached_volume_ids + if detach_result.all_detached: + result.job_update_map["status"] = _get_job_termination_status(job_model) + return result + + +async def _detach_job_volumes( + job_model: JobModel, + instance_model: InstanceModel, + job_provisioning_data: Optional[JobProvisioningData], +) -> tuple[list[_VolumeUpdateRow], _VolumeDetachResult]: + volume_models = await _load_job_volume_models( + job_model=job_model, instance_model=instance_model + ) + volume_update_rows = _get_volume_update_rows(volume_models) + if len(volume_models) == 0: + return volume_update_rows, _VolumeDetachResult(all_detached=True) + + if job_provisioning_data is None: + return volume_update_rows, _VolumeDetachResult(all_detached=True) + + logger.info("Detaching volumes: %s", [v.name for v in volume_models]) + detach_result = await _detach_volumes_from_job_instance( + job_model=job_model, + instance_model=instance_model, + volume_models=volume_models, + jpd=job_provisioning_data, + run_termination_reason=job_model.run.termination_reason, + ) + return volume_update_rows, detach_result + + +async def _unregister_replica_and_update_result( + result: _ProcessResult, job_model: JobModel +) -> None: + result.unregister_gateway_target = await _unregister_replica(job_model=job_model) + if job_model.registered: + result.job_update_map["registered"] = False + result.emit_unregister_replica_event = True + + +async def _unregister_replica( + job_model: JobModel, +) -> Optional[events.Target]: + if not job_model.registered: + return None + gateway_target = None + run_model = job_model.run + if run_model.gateway_id is not None: + async with get_session_ctx() as session: + gateway, conn = await get_or_add_gateway_connection(session, run_model.gateway_id) + gateway_target = events.Target.from_model(gateway) + try: + logger.debug( + "%s: unregistering replica from service %s", fmt(job_model), job_model.run_id.hex + ) + async with conn.client() as client: + await client.unregister_replica( + project=run_model.project.name, + run_name=run_model.run_name, + job_id=job_model.id, + ) + except GatewayError as e: + logger.warning("%s: unregistering replica from service: %s", fmt(job_model), e) + except (httpx.RequestError, SSHError) as e: + logger.debug("Gateway request failed", exc_info=True) + # FIXME: Unhandled exception raised. + # Handle and retry unregister with timeout. + raise GatewayError(repr(e)) + return gateway_target + + +def _get_job_termination_status(job_model: JobModel) -> JobStatus: + if job_model.termination_reason is not None: + return job_model.termination_reason.to_status() + return JobStatus.FAILED + + +def _get_volume_update_rows(volume_models: list[VolumeModel]) -> list[_VolumeUpdateRow]: + return [ + { + "id": volume_model.id, + "last_job_processed_at": NOW_PLACEHOLDER, + } + for volume_model in volume_models + ] + + +def _get_job_occupied_blocks(jrd: Optional[JobRuntimeData]) -> int: + if jrd is not None and jrd.offer is not None: + return jrd.offer.blocks + return 1 + + +async def _stop_container( + job_model: JobModel, + job_provisioning_data: JobProvisioningData, + ssh_private_keys: tuple[str, Optional[str]], +) -> bool: + if job_provisioning_data.dockerized: + return await common.run_async( + _shim_submit_stop, + ssh_private_keys, + job_provisioning_data, + None, + job_model, + ) + return True + + +@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT]) +def _shim_submit_stop(ports: dict[int, int], job_model: JobModel) -> bool: + shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT]) + + resp = shim_client.healthcheck() + if resp is None: + logger.debug("%s: can't stop container, shim is not available yet", fmt(job_model)) + return False + + if shim_client.is_api_v2_supported(): + reason = ( + None if job_model.termination_reason is None else job_model.termination_reason.value + ) + shim_client.terminate_task( + task_id=job_model.id, + reason=reason, + message=job_model.termination_reason_message, + timeout=0, + ) + if not settings.SERVER_KEEP_SHIM_TASKS: + shim_client.remove_task(task_id=job_model.id) + else: + shim_client.stop(force=True) + return True + + +async def _detach_volumes_from_job_instance( + job_model: JobModel, + instance_model: InstanceModel, + volume_models: list[VolumeModel], + jpd: JobProvisioningData, + run_termination_reason: Optional[RunTerminationReason], +) -> _VolumeDetachResult: + job_spec = JobSpec.__response__.parse_raw(job_model.job_spec_data) + backend = await backends_services.get_project_backend_by_type( + project=instance_model.project, + backend_type=jpd.backend, + ) + if backend is None: + logger.error( + "Failed to detach volumes from %s. Backend not available.", instance_model.name + ) + return _VolumeDetachResult(all_detached=False) + + detached_volume_ids = set() + all_detached = True + for volume_model in volume_models: + detached = await _detach_volume_from_job_instance( + backend=backend, + job_model=job_model, + jpd=jpd, + job_spec=job_spec, + instance_model=instance_model, + volume_model=volume_model, + run_termination_reason=run_termination_reason, + ) + if detached: + detached_volume_ids.add(volume_model.id) + else: + all_detached = False + + return _VolumeDetachResult( + all_detached=all_detached, + detached_volume_ids=detached_volume_ids, + set_volumes_detached_at=job_model.volumes_detached_at is None, + ) + + +async def _detach_volume_from_job_instance( + backend: Backend, + job_model: JobModel, + jpd: JobProvisioningData, + job_spec: JobSpec, + instance_model: InstanceModel, + volume_model: VolumeModel, + run_termination_reason: Optional[RunTerminationReason], +) -> bool: + detached = True + volume = volume_model_to_volume(volume_model) + if volume.provisioning_data is None or not volume.provisioning_data.detachable: + return detached + compute = backend.compute() + assert isinstance(compute, ComputeWithVolumeSupport) + try: + if job_model.volumes_detached_at is None: + await common.run_async( + compute.detach_volume, + volume=volume, + provisioning_data=jpd, + force=False, + ) + detached = await common.run_async( + compute.is_volume_detached, + volume=volume, + provisioning_data=jpd, + ) + else: + detached = await common.run_async( + compute.is_volume_detached, + volume=volume, + provisioning_data=jpd, + ) + if not detached and _should_force_detach_volume( + job_model=job_model, + run_termination_reason=run_termination_reason, + stop_duration=job_spec.stop_duration, + ): + logger.info( + "Force detaching volume %s from %s", + volume_model.name, + instance_model.name, + ) + await common.run_async( + compute.detach_volume, + volume=volume, + provisioning_data=jpd, + force=True, + ) + except BackendError as e: + logger.error( + "Failed to detach volume %s from %s: %s", + volume_model.name, + instance_model.name, + repr(e), + ) + except Exception: + logger.exception( + "Got exception when detaching volume %s from instance %s", + volume_model.name, + instance_model.name, + ) + return detached + + +_MIN_FORCE_DETACH_WAIT_PERIOD = timedelta(seconds=60) + + +def _should_force_detach_volume( + job_model: JobModel, + run_termination_reason: Optional[RunTerminationReason], + stop_duration: Optional[int], +) -> bool: + now = get_current_datetime() + return ( + job_model.volumes_detached_at is not None + and now > job_model.volumes_detached_at + _MIN_FORCE_DETACH_WAIT_PERIOD + and ( + job_model.termination_reason == JobTerminationReason.ABORTED_BY_USER + or run_termination_reason == RunTerminationReason.ABORTED_BY_USER + or stop_duration is not None + and now > job_model.volumes_detached_at + timedelta(seconds=stop_duration) + ) + ) + + +def _get_related_instance_lock_owner(job_id: uuid.UUID) -> str: + return f"{JobTerminatingPipeline.__name__}:{job_id}" diff --git a/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py b/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py index 552ae00dc..8fb5bfd4b 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/placement_groups.py @@ -136,7 +136,7 @@ async def fetch(self, limit: int) -> list[PipelineItem]: ) .order_by(PlacementGroupModel.last_processed_at.asc()) .limit(limit) - .with_for_update(skip_locked=True, key_share=True) + .with_for_update(skip_locked=True, key_share=True, of=PlacementGroupModel) .options( load_only( PlacementGroupModel.id, diff --git a/src/dstack/_internal/server/background/pipeline_tasks/volumes.py b/src/dstack/_internal/server/background/pipeline_tasks/volumes.py index 81d94c361..89eea7d92 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/volumes.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/volumes.py @@ -154,7 +154,7 @@ async def fetch(self, limit: int) -> list[VolumePipelineItem]: ) .order_by(VolumeModel.last_processed_at.asc()) .limit(limit) - .with_for_update(skip_locked=True, key_share=True) + .with_for_update(skip_locked=True, key_share=True, of=VolumeModel) .options( load_only( VolumeModel.id, diff --git a/src/dstack/_internal/server/background/scheduled_tasks/__init__.py b/src/dstack/_internal/server/background/scheduled_tasks/__init__.py index 2994fca37..a0b97e7d5 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/__init__.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/__init__.py @@ -134,12 +134,6 @@ def start_scheduled_tasks() -> AsyncIOScheduler: kwargs={"batch_size": 5}, max_instances=2 if replica == 0 else 1, ) - _scheduler.add_job( - process_terminating_jobs, - IntervalTrigger(seconds=4, jitter=2), - kwargs={"batch_size": 5}, - max_instances=2 if replica == 0 else 1, - ) _scheduler.add_job( process_runs, IntervalTrigger(seconds=2, jitter=1), @@ -159,5 +153,11 @@ def start_scheduled_tasks() -> AsyncIOScheduler: kwargs={"batch_size": 1}, max_instances=2 if replica == 0 else 1, ) + _scheduler.add_job( + process_terminating_jobs, + IntervalTrigger(seconds=4, jitter=2), + kwargs={"batch_size": 5}, + max_instances=2 if replica == 0 else 1, + ) _scheduler.start() return _scheduler diff --git a/src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py index 27163b53d..4cf63f2b7 100644 --- a/src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/terminating_jobs.py @@ -1,24 +1,56 @@ import asyncio +from datetime import timedelta +from typing import Optional from sqlalchemy import or_, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload -from dstack._internal.core.models.runs import JobStatus +from dstack._internal.core.backends.base.backend import Backend +from dstack._internal.core.backends.base.compute import ComputeWithVolumeSupport +from dstack._internal.core.consts import DSTACK_SHIM_HTTP_PORT +from dstack._internal.core.errors import BackendError +from dstack._internal.core.models.instances import InstanceStatus, InstanceTerminationReason +from dstack._internal.core.models.runs import ( + JobProvisioningData, + JobRuntimeData, + JobSpec, + JobStatus, + JobTerminationReason, + RunTerminationReason, +) +from dstack._internal.server import settings from dstack._internal.server.db import get_db, get_session_ctx from dstack._internal.server.models import ( InstanceModel, JobModel, ProjectModel, + RunModel, VolumeAttachmentModel, + VolumeModel, +) +from dstack._internal.server.services import backends as backends_services +from dstack._internal.server.services import events, services +from dstack._internal.server.services.instances import ( + format_instance_blocks_for_event, + get_instance_ssh_private_keys, + switch_instance_status, ) from dstack._internal.server.services.jobs import ( - process_terminating_job, - process_volumes_detaching, + get_job_provisioning_data, + get_job_runtime_data, + switch_job_status, ) from dstack._internal.server.services.locking import get_locker from dstack._internal.server.services.logging import fmt +from dstack._internal.server.services.runner import client +from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel +from dstack._internal.server.services.volumes import ( + list_project_volume_models, + volume_model_to_volume, +) from dstack._internal.server.utils import sentry_utils +from dstack._internal.utils import common from dstack._internal.utils.common import ( get_current_datetime, get_or_error, @@ -28,6 +60,8 @@ logger = get_logger(__name__) +# NOTE: This scheduled task is going to be deprecated in favor of `JobTerminatingPipeline`. +# If this logic changes before removal, keep `pipeline_tasks/jobs_terminating.py` in sync. async def process_terminating_jobs(batch_size: int = 1): tasks = [] for _ in range(batch_size): @@ -101,9 +135,361 @@ async def _process_job(session: AsyncSession, job_model: JobModel): ) instance_model = res.unique().scalar() if job_model.volumes_detached_at is None: - await process_terminating_job(session, job_model, instance_model) + await _process_terminating_job(session, job_model, instance_model) else: instance_model = get_or_error(instance_model) - await process_volumes_detaching(session, job_model, instance_model) + await _process_volumes_detaching(session, job_model, instance_model) job_model.last_processed_at = get_current_datetime() await session.commit() + + +async def _process_terminating_job( + session: AsyncSession, + job_model: JobModel, + instance_model: Optional[InstanceModel], +): + """ + Stops the job: tells shim to stop the container, detaches the job from the instance, + and detaches volumes from the instance. + Graceful stop should already be done by `process_terminating_run`. + Caller must acquire the locks on the job and the job's instance. + """ + if instance_model is None: + # Possible if the job hasn't been assigned an instance yet + await services.unregister_replica(session, job_model) + _set_job_termination_status(session, job_model) + return + + all_volumes_detached: bool = True + jrd = get_job_runtime_data(job_model) + jpd = get_job_provisioning_data(job_model) + volume_models = await _get_job_volume_models( + session=session, + job_model=job_model, + instance_model=instance_model, + jrd=jrd, + ) + if jpd is not None: + logger.debug("%s: stopping container", fmt(job_model)) + ssh_private_keys = get_instance_ssh_private_keys(instance_model) + if not await _stop_container(job_model, jpd, ssh_private_keys): + # The dangling container can be removed later during instance processing + logger.warning( + ( + "%s: could not stop container, possibly due to a communication error." + " See debug logs for details." + " Ignoring, can attempt to remove the container later" + ), + fmt(job_model), + ) + if len(volume_models) > 0: + logger.info("Detaching volumes: %s", [v.name for v in volume_models]) + all_volumes_detached = await _detach_volumes_from_job_instance( + session=session, + project=instance_model.project, + job_model=job_model, + jpd=jpd, + instance_model=instance_model, + volume_models=volume_models, + ) + + instance_model.busy_blocks -= _get_job_occupied_blocks(jrd) + if instance_model.status != InstanceStatus.BUSY or jpd is None or not jpd.dockerized: + # Terminate instances that: + # - have not finished provisioning yet + # - belong to container-based backends, and hence cannot be reused + if instance_model.status not in InstanceStatus.finished_statuses(): + instance_model.termination_reason = InstanceTerminationReason.JOB_FINISHED + switch_instance_status(session, instance_model, InstanceStatus.TERMINATING) + elif not [j for j in instance_model.jobs if j.id != job_model.id]: + # no other jobs besides this one + switch_instance_status(session, instance_model, InstanceStatus.IDLE) + + # The instance should be released even if detach fails + # so that stuck volumes don't prevent the instance from terminating. + job_model.instance_id = None + instance_model.last_job_processed_at = common.get_current_datetime() + + events.emit( + session, + ( + "Job unassigned from instance." + f" Instance blocks: {format_instance_blocks_for_event(instance_model)}" + ), + actor=events.SystemActor(), + targets=[ + events.Target.from_model(job_model), + events.Target.from_model(instance_model), + ], + ) + + # Volumes are not locked because no other place can update attached active volumes. + for volume_model in volume_models: + volume_model.last_job_processed_at = common.get_current_datetime() + + await services.unregister_replica(session, job_model) + if all_volumes_detached: + # Do not terminate while some volumes are not detached. + _set_job_termination_status(session, job_model) + + +async def _get_job_volume_models( + session: AsyncSession, + job_model: JobModel, + instance_model: InstanceModel, + jrd: Optional[JobRuntimeData], +) -> list[VolumeModel]: + volume_names = ( + jrd.volume_names + if jrd and jrd.volume_names + else [va.volume.name for va in instance_model.volume_attachments] + ) + if len(volume_names) == 0: + return [] + return await list_project_volume_models( + session=session, project=instance_model.project, names=volume_names + ) + + +def _get_job_occupied_blocks(jrd: Optional[JobRuntimeData]) -> int: + if jrd is not None and jrd.offer is not None: + return jrd.offer.blocks + # Old job submitted before jrd or blocks were introduced + return 1 + + +async def _process_volumes_detaching( + session: AsyncSession, + job_model: JobModel, + instance_model: InstanceModel, +): + """ + Called after job's volumes have been soft detached to check if they are detached. + Terminates the job when all the volumes are detached. + If the volumes fail to detach, force detaches them. + """ + jpd = get_or_error(get_job_provisioning_data(job_model)) + jrd = get_job_runtime_data(job_model) + volume_models = await _get_job_volume_models( + session=session, + job_model=job_model, + instance_model=instance_model, + jrd=jrd, + ) + logger.info("Detaching volumes: %s", [v.name for v in volume_models]) + all_volumes_detached = await _detach_volumes_from_job_instance( + session=session, + project=instance_model.project, + job_model=job_model, + jpd=jpd, + instance_model=instance_model, + volume_models=volume_models, + ) + if all_volumes_detached: + # Do not terminate the job while some volumes are not detached. + # If force detach never succeeds, the job will be stuck terminating. + # The job releases the instance when soft detaching, so the instance won't be stuck. + _set_job_termination_status(session, job_model) + + +def _set_job_termination_status(session: AsyncSession, job_model: JobModel): + if job_model.termination_reason is not None: + status = job_model.termination_reason.to_status() + else: + status = JobStatus.FAILED + switch_job_status(session, job_model, status) + + +async def _stop_container( + job_model: JobModel, + job_provisioning_data: JobProvisioningData, + ssh_private_keys: tuple[str, Optional[str]], +) -> bool: + if job_provisioning_data.dockerized: + # send a request to the shim to terminate the docker container + # SSHError and RequestException are caught in the `runner_ssh_tunner` decorator + return await common.run_async( + _shim_submit_stop, + ssh_private_keys, + job_provisioning_data, + None, + job_model, + ) + return True + + +@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT]) +def _shim_submit_stop(ports: dict[int, int], job_model: JobModel) -> bool: + shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT]) + + resp = shim_client.healthcheck() + if resp is None: + logger.debug("%s: can't stop container, shim is not available yet", fmt(job_model)) + return False # shim is not available yet + + # we force-kill container because the runner had time to gracefully stop the job + if shim_client.is_api_v2_supported(): + if job_model.termination_reason is None: + reason = None + else: + reason = job_model.termination_reason.value + shim_client.terminate_task( + task_id=job_model.id, + reason=reason, + message=job_model.termination_reason_message, + timeout=0, + ) + # maybe somehow postpone removing old tasks to allow inspecting failed jobs without + # the following setting? + if not settings.SERVER_KEEP_SHIM_TASKS: + shim_client.remove_task(task_id=job_model.id) + else: + shim_client.stop(force=True) + return True + + +async def _detach_volumes_from_job_instance( + session: AsyncSession, + project: ProjectModel, + job_model: JobModel, + jpd: JobProvisioningData, + instance_model: InstanceModel, + volume_models: list[VolumeModel], +) -> bool: + job_spec = JobSpec.__response__.parse_raw(job_model.job_spec_data) + backend = await backends_services.get_project_backend_by_type( + project=project, + backend_type=jpd.backend, + ) + if backend is None: + logger.error( + "Failed to detach volumes from %s. Backend not available.", instance_model.name + ) + return False + + all_detached = True + detached_volumes = [] + run_termination_reason = await _get_run_termination_reason(session, job_model) + for volume_model in volume_models: + detached = await _detach_volume_from_job_instance( + backend=backend, + job_model=job_model, + jpd=jpd, + job_spec=job_spec, + instance_model=instance_model, + volume_model=volume_model, + run_termination_reason=run_termination_reason, + ) + if detached: + detached_volumes.append(volume_model) + else: + all_detached = False + + if job_model.volumes_detached_at is None: + job_model.volumes_detached_at = common.get_current_datetime() + detached_volumes_ids = {v.id for v in detached_volumes} + instance_model.volume_attachments = [ + va for va in instance_model.volume_attachments if va.volume_id not in detached_volumes_ids + ] + return all_detached + + +async def _detach_volume_from_job_instance( + backend: Backend, + job_model: JobModel, + jpd: JobProvisioningData, + job_spec: JobSpec, + instance_model: InstanceModel, + volume_model: VolumeModel, + run_termination_reason: Optional[RunTerminationReason], +) -> bool: + detached = True + volume = volume_model_to_volume(volume_model) + if volume.provisioning_data is None or not volume.provisioning_data.detachable: + # Backends without `detach_volume` detach volumes automatically + return detached + compute = backend.compute() + assert isinstance(compute, ComputeWithVolumeSupport) + try: + if job_model.volumes_detached_at is None: + # We haven't tried detaching volumes yet, try soft detach first + await common.run_async( + compute.detach_volume, + volume=volume, + provisioning_data=jpd, + force=False, + ) + # For some backends, the volume may be detached immediately + detached = await common.run_async( + compute.is_volume_detached, + volume=volume, + provisioning_data=jpd, + ) + else: + detached = await common.run_async( + compute.is_volume_detached, + volume=volume, + provisioning_data=jpd, + ) + if not detached and _should_force_detach_volume( + job_model, + run_termination_reason=run_termination_reason, + stop_duration=job_spec.stop_duration, + ): + logger.info( + "Force detaching volume %s from %s", + volume_model.name, + instance_model.name, + ) + await common.run_async( + compute.detach_volume, + volume=volume, + provisioning_data=jpd, + force=True, + ) + # Let the next iteration check if force detach worked + except BackendError as e: + logger.error( + "Failed to detach volume %s from %s: %s", + volume_model.name, + instance_model.name, + repr(e), + ) + except Exception: + logger.exception( + "Got exception when detaching volume %s from instance %s", + volume_model.name, + instance_model.name, + ) + return detached + + +async def _get_run_termination_reason( + session: AsyncSession, job_model: JobModel +) -> Optional[RunTerminationReason]: + res = await session.execute( + select(RunModel.termination_reason).where(RunModel.id == job_model.run_id) + ) + return res.scalar_one_or_none() + + +_MIN_FORCE_DETACH_WAIT_PERIOD = timedelta(seconds=60) + + +def _should_force_detach_volume( + job_model: JobModel, + run_termination_reason: Optional[RunTerminationReason], + stop_duration: Optional[int], +) -> bool: + return ( + job_model.volumes_detached_at is not None + and common.get_current_datetime() + > job_model.volumes_detached_at + _MIN_FORCE_DETACH_WAIT_PERIOD + and ( + job_model.termination_reason == JobTerminationReason.ABORTED_BY_USER + or run_termination_reason == RunTerminationReason.ABORTED_BY_USER + or stop_duration is not None + and common.get_current_datetime() + > job_model.volumes_detached_at + timedelta(seconds=stop_duration) + ) + ) diff --git a/src/dstack/_internal/server/migrations/versions/2026/03_09_0928_6026b29d78c7_add_jobmodel_pipeline_columns.py b/src/dstack/_internal/server/migrations/versions/2026/03_09_0928_6026b29d78c7_add_jobmodel_pipeline_columns.py new file mode 100644 index 000000000..84126aa1c --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/2026/03_09_0928_6026b29d78c7_add_jobmodel_pipeline_columns.py @@ -0,0 +1,47 @@ +"""Add JobModel pipeline columns + +Revision ID: 6026b29d78c7 +Revises: a13f5b55af01 +Create Date: 2026-03-09 09:28:17.993416+00:00 + +""" + +import sqlalchemy as sa +import sqlalchemy_utils +from alembic import op + +import dstack._internal.server.models + +# revision identifiers, used by Alembic. +revision = "6026b29d78c7" +down_revision = "a13f5b55af01" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("jobs", schema=None) as batch_op: + batch_op.add_column( + sa.Column( + "lock_expires_at", dstack._internal.server.models.NaiveDateTime(), nullable=True + ) + ) + batch_op.add_column( + sa.Column( + "lock_token", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=True + ) + ) + batch_op.add_column(sa.Column("lock_owner", sa.String(length=100), nullable=True)) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("jobs", schema=None) as batch_op: + batch_op.drop_column("lock_owner") + batch_op.drop_column("lock_token") + batch_op.drop_column("lock_expires_at") + + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/migrations/versions/2026/03_10_1130_8b6d5d8c1b9a_add_ix_jobs_pipeline_fetch_q_index.py b/src/dstack/_internal/server/migrations/versions/2026/03_10_1130_8b6d5d8c1b9a_add_ix_jobs_pipeline_fetch_q_index.py new file mode 100644 index 000000000..b6a1ca924 --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/2026/03_10_1130_8b6d5d8c1b9a_add_ix_jobs_pipeline_fetch_q_index.py @@ -0,0 +1,51 @@ +"""Add ix_jobs_pipeline_fetch_q index + +Revision ID: 8b6d5d8c1b9a +Revises: 6026b29d78c7 +Create Date: 2026-03-10 11:30:00.000000+00:00 + +""" + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "8b6d5d8c1b9a" +down_revision = "6026b29d78c7" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.get_context().autocommit_block(): + op.drop_index( + "ix_jobs_pipeline_fetch_q", + table_name="jobs", + if_exists=True, + postgresql_concurrently=True, + ) + op.create_index( + "ix_jobs_pipeline_fetch_q", + "jobs", + [sa.literal_column("last_processed_at ASC")], + unique=False, + sqlite_where=sa.text("(status NOT IN ('TERMINATED', 'ABORTED', 'FAILED', 'DONE'))"), + postgresql_where=sa.text( + "(status NOT IN ('TERMINATED', 'ABORTED', 'FAILED', 'DONE'))" + ), + postgresql_concurrently=True, + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.get_context().autocommit_block(): + op.drop_index( + "ix_jobs_pipeline_fetch_q", + table_name="jobs", + if_exists=True, + postgresql_concurrently=True, + ) + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index 2f82b5e59..3b4986b48 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -443,7 +443,7 @@ class RunModel(BaseModel): __table_args__ = (Index("ix_submitted_at_id", submitted_at.desc(), id),) -class JobModel(BaseModel): +class JobModel(PipelineModelMixin, BaseModel): __tablename__ = "jobs" id: Mapped[uuid.UUID] = mapped_column( @@ -515,6 +515,15 @@ class JobModel(BaseModel): should be processed only one-by-one. """ + __table_args__ = ( + Index( + "ix_jobs_pipeline_fetch_q", + last_processed_at.asc(), + postgresql_where=status.not_in(JobStatus.finished_statuses()), + sqlite_where=status.not_in(JobStatus.finished_statuses()), + ), + ) + class GatewayModel(PipelineModelMixin, BaseModel): __tablename__ = "gateways" @@ -810,6 +819,9 @@ class VolumeModel(PipelineModelMixin, BaseModel): NaiveDateTime, default=get_current_datetime ) last_job_processed_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime) + """`last_job_processed_at` records the last time the volume was used by a job. + Updated when a job terminates and used to delete volumes on `auto_cleanup_duration`. + """ deleted: Mapped[bool] = mapped_column(Boolean, default=False) deleted_at: Mapped[Optional[datetime]] = mapped_column(NaiveDateTime) to_be_deleted: Mapped[bool] = mapped_column(Boolean, server_default=false()) diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py index 4c094c676..f718a80ce 100644 --- a/src/dstack/_internal/server/services/jobs/__init__.py +++ b/src/dstack/_internal/server/services/jobs/__init__.py @@ -9,19 +9,14 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload, load_only -import dstack._internal.server.services.backends as backends_services -from dstack._internal.core.backends.base.backend import Backend -from dstack._internal.core.backends.base.compute import ComputeWithVolumeSupport -from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT, DSTACK_SHIM_HTTP_PORT +from dstack._internal.core.consts import DSTACK_RUNNER_HTTP_PORT from dstack._internal.core.errors import ( - BackendError, ResourceNotExistsError, ServerClientError, SSHError, ) from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.configurations import RunConfigurationType -from dstack._internal.core.models.instances import InstanceStatus, InstanceTerminationReason from dstack._internal.core.models.runs import ( Job, JobProvisioningData, @@ -31,10 +26,8 @@ JobSubmission, JobTerminationReason, RunSpec, - RunTerminationReason, ) from dstack._internal.core.models.volumes import Volume, VolumeMountPoint, VolumeStatus -from dstack._internal.server import settings from dstack._internal.server.models import ( InstanceModel, JobModel, @@ -42,12 +35,10 @@ RunModel, VolumeModel, ) -from dstack._internal.server.services import events, services +from dstack._internal.server.services import events from dstack._internal.server.services import volumes as volumes_services from dstack._internal.server.services.instances import ( - format_instance_blocks_for_event, get_instance_ssh_private_keys, - switch_instance_status, ) from dstack._internal.server.services.jobs.configurators.base import ( JobConfigurator, @@ -60,12 +51,8 @@ from dstack._internal.server.services.probes import probe_model_to_probe from dstack._internal.server.services.runner import client from dstack._internal.server.services.runner.ssh import runner_ssh_tunnel -from dstack._internal.server.services.volumes import ( - list_project_volume_models, - volume_model_to_volume, -) from dstack._internal.utils import common -from dstack._internal.utils.common import get_or_error, run_async +from dstack._internal.utils.common import run_async from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) @@ -88,15 +75,55 @@ def switch_job_status( return job_model.status = new_status + emit_job_status_change_event( + session=session, + job_model=job_model, + old_status=old_status, + new_status=new_status, + termination_reason=job_model.termination_reason, + termination_reason_message=job_model.termination_reason_message, + actor=actor, + ) + +def get_job_status_change_message( + old_status: JobStatus, + new_status: JobStatus, + termination_reason: Optional[JobTerminationReason], + termination_reason_message: Optional[str], +) -> str: msg = f"Job status changed {old_status.upper()} -> {new_status.upper()}" if new_status == JobStatus.TERMINATING: - if job_model.termination_reason is None: + if termination_reason is None: raise ValueError("termination_reason must be set when switching to TERMINATING status") - msg += f". Termination reason: {job_model.termination_reason.upper()}" - if job_model.termination_reason_message: - msg += f" ({job_model.termination_reason_message})" - events.emit(session, msg, actor=actor, targets=[events.Target.from_model(job_model)]) + msg += f". Termination reason: {termination_reason.upper()}" + if termination_reason_message: + msg += f" ({termination_reason_message})" + return msg + + +def emit_job_status_change_event( + session: AsyncSession, + job_model: JobModel, + old_status: JobStatus, + new_status: JobStatus, + termination_reason: Optional[JobTerminationReason], + termination_reason_message: Optional[str], + actor: events.AnyActor = events.SystemActor(), +) -> None: + if old_status == new_status: + return + events.emit( + session, + get_job_status_change_message( + old_status=old_status, + new_status=new_status, + termination_reason=termination_reason, + termination_reason_message=termination_reason_message, + ), + actor=actor, + targets=[events.Target.from_model(job_model)], + ) async def get_jobs_from_run_spec( @@ -302,208 +329,6 @@ def _stop_runner( logger.exception("%s: failed to stop runner gracefully", fmt(job_model)) -async def process_terminating_job( - session: AsyncSession, - job_model: JobModel, - instance_model: Optional[InstanceModel], -): - """ - Stops the job: tells shim to stop the container, detaches the job from the instance, - and detaches volumes from the instance. - Graceful stop should already be done by `process_terminating_run`. - Caller must acquire the locks on the job and the job's instance. - """ - if job_model.remove_at is not None and job_model.remove_at > common.get_current_datetime(): - # it's too early to terminate the instance - return - - if instance_model is None: - # Possible if the job hasn't been assigned an instance yet - await services.unregister_replica(session, job_model) - _set_job_termination_status(session, job_model) - return - - all_volumes_detached: bool = True - jrd = get_job_runtime_data(job_model) - jpd = get_job_provisioning_data(job_model) - if jpd is not None: - logger.debug("%s: stopping container", fmt(job_model)) - ssh_private_keys = get_instance_ssh_private_keys(instance_model) - if not await stop_container(job_model, jpd, ssh_private_keys): - # The dangling container can be removed later during instance processing - logger.warning( - ( - "%s: could not stop container, possibly due to a communication error." - " See debug logs for details." - " Ignoring, can attempt to remove the container later" - ), - fmt(job_model), - ) - if jrd is not None and jrd.volume_names is not None: - volume_names = jrd.volume_names - else: - # Legacy jobs before job_runtime_data/blocks were introduced - volume_names = [va.volume.name for va in instance_model.volume_attachments] - volume_models = await list_project_volume_models( - session=session, project=instance_model.project, names=volume_names - ) - if len(volume_models) > 0: - logger.info("Detaching volumes: %s", [v.name for v in volume_models]) - all_volumes_detached = await _detach_volumes_from_job_instance( - session=session, - project=instance_model.project, - job_model=job_model, - jpd=jpd, - instance_model=instance_model, - volume_models=volume_models, - ) - - if jrd is not None and jrd.offer is not None: - blocks = jrd.offer.blocks - else: - # Old job submitted before jrd or blocks were introduced - blocks = 1 - instance_model.busy_blocks -= blocks - - if instance_model.status != InstanceStatus.BUSY or jpd is None or not jpd.dockerized: - # Terminate instances that: - # - have not finished provisioning yet - # - belong to container-based backends, and hence cannot be reused - if instance_model.status not in InstanceStatus.finished_statuses(): - instance_model.termination_reason = InstanceTerminationReason.JOB_FINISHED - switch_instance_status(session, instance_model, InstanceStatus.TERMINATING) - elif not [j for j in instance_model.jobs if j.id != job_model.id]: - # no other jobs besides this one - switch_instance_status(session, instance_model, InstanceStatus.IDLE) - - # The instance should be released even if detach fails - # so that stuck volumes don't prevent the instance from terminating. - job_model.instance_id = None - instance_model.last_job_processed_at = common.get_current_datetime() - - events.emit( - session, - ( - "Job unassigned from instance." - f" Instance blocks: {format_instance_blocks_for_event(instance_model)}" - ), - actor=events.SystemActor(), - targets=[ - events.Target.from_model(job_model), - events.Target.from_model(instance_model), - ], - ) - - volume_names = ( - jrd.volume_names - if jrd and jrd.volume_names - else [va.volume.name for va in instance_model.volume_attachments] - ) - if volume_names: - volumes = await list_project_volume_models( - session=session, project=instance_model.project, names=volume_names - ) - for volume in volumes: - volume.last_job_processed_at = common.get_current_datetime() - - await services.unregister_replica(session, job_model) - if all_volumes_detached: - # Do not terminate while some volumes are not detached. - _set_job_termination_status(session, job_model) - - -async def process_volumes_detaching( - session: AsyncSession, - job_model: JobModel, - instance_model: InstanceModel, -): - """ - Called after job's volumes have been soft detached to check if they are detached. - Terminates the job when all the volumes are detached. - If the volumes fail to detach, force detaches them. - """ - jpd = get_or_error(get_job_provisioning_data(job_model)) - jrd = get_job_runtime_data(job_model) - if jrd is not None and jrd.volume_names is not None: - volume_names = jrd.volume_names - else: - # Legacy jobs before job_runtime_data/blocks were introduced - volume_names = [va.volume.name for va in instance_model.volume_attachments] - volume_models = await list_project_volume_models( - session=session, project=instance_model.project, names=volume_names - ) - logger.info("Detaching volumes: %s", [v.name for v in volume_models]) - all_volumes_detached = await _detach_volumes_from_job_instance( - session=session, - project=instance_model.project, - job_model=job_model, - jpd=jpd, - instance_model=instance_model, - volume_models=volume_models, - ) - if all_volumes_detached: - # Do not terminate the job while some volumes are not detached. - # If force detach never succeeds, the job will be stuck terminating. - # The job releases the instance when soft detaching, so the instance won't be stuck. - _set_job_termination_status(session, job_model) - - -def _set_job_termination_status(session: AsyncSession, job_model: JobModel): - if job_model.termination_reason is not None: - status = job_model.termination_reason.to_status() - else: - status = JobStatus.FAILED - switch_job_status(session, job_model, status) - - -async def stop_container( - job_model: JobModel, - job_provisioning_data: JobProvisioningData, - ssh_private_keys: tuple[str, Optional[str]], -) -> bool: - if job_provisioning_data.dockerized: - # send a request to the shim to terminate the docker container - # SSHError and RequestException are caught in the `runner_ssh_tunner` decorator - return await run_async( - _shim_submit_stop, - ssh_private_keys, - job_provisioning_data, - None, - job_model, - ) - return True - - -@runner_ssh_tunnel(ports=[DSTACK_SHIM_HTTP_PORT]) -def _shim_submit_stop(ports: Dict[int, int], job_model: JobModel) -> bool: - shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT]) - - resp = shim_client.healthcheck() - if resp is None: - logger.debug("%s: can't stop container, shim is not available yet", fmt(job_model)) - return False # shim is not available yet - - # we force-kill container because the runner had time to gracefully stop the job - if shim_client.is_api_v2_supported(): - if job_model.termination_reason is None: - reason = None - else: - reason = job_model.termination_reason.value - shim_client.terminate_task( - task_id=job_model.id, - reason=reason, - message=job_model.termination_reason_message, - timeout=0, - ) - # maybe somehow postpone removing old tasks to allow inspecting failed jobs without - # the following setting? - if not settings.SERVER_KEEP_SHIM_TASKS: - shim_client.remove_task(task_id=job_model.id) - else: - shim_client.stop(force=True) - return True - - def group_jobs_by_replica_latest(jobs: List[JobModel]) -> Iterable[Tuple[int, List[JobModel]]]: """ Args: @@ -525,153 +350,6 @@ def group_jobs_by_replica_latest(jobs: List[JobModel]) -> Iterable[Tuple[int, Li yield replica_num, replica_jobs -async def _detach_volumes_from_job_instance( - session: AsyncSession, - project: ProjectModel, - job_model: JobModel, - jpd: JobProvisioningData, - instance_model: InstanceModel, - volume_models: list[VolumeModel], -) -> bool: - job_spec = JobSpec.__response__.parse_raw(job_model.job_spec_data) - backend = await backends_services.get_project_backend_by_type( - project=project, - backend_type=jpd.backend, - ) - if backend is None: - logger.error( - "Failed to detach volumes from %s. Backend not available.", instance_model.name - ) - return False - - all_detached = True - detached_volumes = [] - run_termination_reason = await _get_run_termination_reason(session, job_model) - for volume_model in volume_models: - detached = await _detach_volume_from_job_instance( - backend=backend, - job_model=job_model, - jpd=jpd, - job_spec=job_spec, - instance_model=instance_model, - volume_model=volume_model, - run_termination_reason=run_termination_reason, - ) - if detached: - detached_volumes.append(volume_model) - else: - all_detached = False - - if job_model.volumes_detached_at is None: - job_model.volumes_detached_at = common.get_current_datetime() - detached_volumes_ids = {v.id for v in detached_volumes} - instance_model.volume_attachments = [ - va for va in instance_model.volume_attachments if va.volume_id not in detached_volumes_ids - ] - return all_detached - - -async def _detach_volume_from_job_instance( - backend: Backend, - job_model: JobModel, - jpd: JobProvisioningData, - job_spec: JobSpec, - instance_model: InstanceModel, - volume_model: VolumeModel, - run_termination_reason: Optional[RunTerminationReason], -) -> bool: - detached = True - volume = volume_model_to_volume(volume_model) - if volume.provisioning_data is None or not volume.provisioning_data.detachable: - # Backends without `detach_volume` detach volumes automatically - return detached - compute = backend.compute() - assert isinstance(compute, ComputeWithVolumeSupport) - try: - if job_model.volumes_detached_at is None: - # We haven't tried detaching volumes yet, try soft detach first - await run_async( - compute.detach_volume, - volume=volume, - provisioning_data=jpd, - force=False, - ) - # For some backends, the volume may be detached immediately - detached = await run_async( - compute.is_volume_detached, - volume=volume, - provisioning_data=jpd, - ) - else: - detached = await run_async( - compute.is_volume_detached, - volume=volume, - provisioning_data=jpd, - ) - if not detached and _should_force_detach_volume( - job_model, - run_termination_reason=run_termination_reason, - stop_duration=job_spec.stop_duration, - ): - logger.info( - "Force detaching volume %s from %s", - volume_model.name, - instance_model.name, - ) - await run_async( - compute.detach_volume, - volume=volume, - provisioning_data=jpd, - force=True, - ) - # Let the next iteration check if force detach worked - except BackendError as e: - logger.error( - "Failed to detach volume %s from %s: %s", - volume_model.name, - instance_model.name, - repr(e), - ) - except Exception: - logger.exception( - "Got exception when detaching volume %s from instance %s", - volume_model.name, - instance_model.name, - ) - return detached - - -MIN_FORCE_DETACH_WAIT_PERIOD = timedelta(seconds=60) - - -async def _get_run_termination_reason( - session: AsyncSession, job_model: JobModel -) -> Optional[RunTerminationReason]: - res = await session.execute( - select(RunModel.termination_reason).where(RunModel.id == job_model.run_id) - ) - return res.scalar_one_or_none() - - -def _should_force_detach_volume( - job_model: JobModel, - run_termination_reason: Optional[RunTerminationReason], - stop_duration: Optional[int], -) -> bool: - return ( - job_model.volumes_detached_at is not None - and common.get_current_datetime() - > job_model.volumes_detached_at + MIN_FORCE_DETACH_WAIT_PERIOD - and ( - job_model.termination_reason == JobTerminationReason.ABORTED_BY_USER - or run_termination_reason == RunTerminationReason.ABORTED_BY_USER - or stop_duration is not None - and common.get_current_datetime() - > job_model.volumes_detached_at + timedelta(seconds=stop_duration) - ) - ) - - async def get_instances_ids_with_detaching_volumes(session: AsyncSession) -> List[UUID]: res = await session.execute( select(JobModel) diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_terminating_jobs.py b/src/tests/_internal/server/background/pipeline_tasks/test_terminating_jobs.py new file mode 100644 index 000000000..822d6798a --- /dev/null +++ b/src/tests/_internal/server/background/pipeline_tasks/test_terminating_jobs.py @@ -0,0 +1,746 @@ +import asyncio +import uuid +from datetime import datetime, timedelta, timezone +from unittest.mock import Mock, patch + +import pytest +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import joinedload + +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.instances import InstanceStatus +from dstack._internal.core.models.runs import JobStatus, JobTerminationReason +from dstack._internal.core.models.volumes import VolumeStatus +from dstack._internal.server.background.pipeline_tasks.jobs_terminating import ( + JobTerminatingFetcher, + JobTerminatingPipeline, + JobTerminatingPipelineItem, + JobTerminatingWorker, + _get_related_instance_lock_owner, +) +from dstack._internal.server.models import InstanceModel, JobModel, VolumeAttachmentModel +from dstack._internal.server.testing.common import ( + ComputeMockSpec, + create_instance, + create_job, + create_project, + create_repo, + create_run, + create_user, + create_volume, + get_instance_offer_with_availability, + get_job_provisioning_data, + get_job_runtime_data, + get_volume_configuration, + get_volume_provisioning_data, + list_events, +) +from dstack._internal.utils.common import get_current_datetime + + +@pytest.fixture +def worker() -> JobTerminatingWorker: + return JobTerminatingWorker(queue=Mock(), heartbeater=Mock()) + + +@pytest.fixture +def fetcher() -> JobTerminatingFetcher: + return JobTerminatingFetcher( + queue=asyncio.Queue(), + queue_desired_minsize=1, + min_processing_interval=timedelta(seconds=15), + lock_timeout=timedelta(seconds=30), + heartbeater=Mock(), + ) + + +def _job_to_pipeline_item(job_model: JobModel) -> JobTerminatingPipelineItem: + assert job_model.lock_token is not None + assert job_model.lock_expires_at is not None + return JobTerminatingPipelineItem( + __tablename__=job_model.__tablename__, + id=job_model.id, + lock_token=job_model.lock_token, + lock_expires_at=job_model.lock_expires_at, + prev_lock_expired=False, + volumes_detached_at=job_model.volumes_detached_at, + ) + + +def _lock_job(job_model: JobModel): + job_model.lock_token = uuid.uuid4() + job_model.lock_expires_at = get_current_datetime() + timedelta(seconds=30) + job_model.lock_owner = JobTerminatingPipeline.__name__ + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +@pytest.mark.usefixtures("image_config_mock") +class TestJobTerminatingFetcher: + async def test_fetch_selects_eligible_jobs_and_sets_lock_fields( + self, test_db, session: AsyncSession, fetcher: JobTerminatingFetcher + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + now = get_current_datetime() + stale = now - timedelta(minutes=1) + + terminating = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + submitted_at=stale - timedelta(minutes=2), + last_processed_at=stale - timedelta(seconds=2), + ) + past_remove_at = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + submitted_at=stale - timedelta(minutes=2), + last_processed_at=stale - timedelta(seconds=1), + ) + past_remove_at.remove_at = stale + past_remove_at.volumes_detached_at = stale - timedelta(seconds=30) + + future_remove_at = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + submitted_at=stale - timedelta(minutes=2), + last_processed_at=stale, + ) + future_remove_at.remove_at = now + timedelta(minutes=1) + + non_terminating = await create_job( + session=session, + run=run, + status=JobStatus.RUNNING, + submitted_at=stale - timedelta(minutes=2), + last_processed_at=stale, + ) + + recent = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + submitted_at=stale - timedelta(minutes=2), + last_processed_at=now, + ) + + locked = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + submitted_at=stale - timedelta(minutes=2), + last_processed_at=stale + timedelta(seconds=1), + ) + locked.lock_expires_at = now + timedelta(minutes=1) + locked.lock_token = uuid.uuid4() + locked.lock_owner = "OtherPipeline" + + expired_same_owner = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + submitted_at=stale - timedelta(minutes=2), + last_processed_at=stale + timedelta(seconds=2), + ) + expired_same_owner.lock_expires_at = stale + expired_same_owner.lock_token = uuid.uuid4() + expired_same_owner.lock_owner = JobTerminatingPipeline.__name__ + await session.commit() + + items = await fetcher.fetch(limit=10) + + assert [item.id for item in items] == [ + terminating.id, + past_remove_at.id, + expired_same_owner.id, + ] + assert {(item.id, item.volumes_detached_at) for item in items} == { + (terminating.id, None), + (past_remove_at.id, past_remove_at.volumes_detached_at), + (expired_same_owner.id, None), + } + + for job in [ + terminating, + past_remove_at, + future_remove_at, + non_terminating, + recent, + locked, + expired_same_owner, + ]: + await session.refresh(job) + + fetched_jobs = [terminating, past_remove_at, expired_same_owner] + assert all(job.lock_owner == JobTerminatingPipeline.__name__ for job in fetched_jobs) + assert all(job.lock_expires_at is not None for job in fetched_jobs) + assert all(job.lock_token is not None for job in fetched_jobs) + assert len({job.lock_token for job in fetched_jobs}) == 1 + + assert future_remove_at.lock_owner is None + assert non_terminating.lock_owner is None + assert recent.lock_owner is None + assert locked.lock_owner == "OtherPipeline" + + async def test_fetch_returns_oldest_jobs_first_up_to_limit( + self, test_db, session: AsyncSession, fetcher: JobTerminatingFetcher + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + now = get_current_datetime() + + oldest = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + submitted_at=now - timedelta(minutes=5), + last_processed_at=now - timedelta(minutes=3), + ) + middle = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + submitted_at=now - timedelta(minutes=4), + last_processed_at=now - timedelta(minutes=2), + ) + newest = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + submitted_at=now - timedelta(minutes=3), + last_processed_at=now - timedelta(minutes=1), + ) + + items = await fetcher.fetch(limit=2) + + assert [item.id for item in items] == [oldest.id, middle.id] + + await session.refresh(oldest) + await session.refresh(middle) + await session.refresh(newest) + + assert oldest.lock_owner == JobTerminatingPipeline.__name__ + assert middle.lock_owner == JobTerminatingPipeline.__name__ + assert newest.lock_owner is None + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +@pytest.mark.usefixtures("image_config_mock") +class TestJobTerminatingWorker: + async def test_terminates_job( + self, test_db, session: AsyncSession, worker: JobTerminatingWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + job_provisioning_data = get_job_provisioning_data(dockerized=True) + job = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + termination_reason=JobTerminationReason.TERMINATED_BY_USER, + submitted_at=datetime(2023, 1, 2, 5, 12, 30, 5, tzinfo=timezone.utc), + job_provisioning_data=job_provisioning_data, + instance=instance, + ) + _lock_job(job) + await session.commit() + + with ( + patch("dstack._internal.server.services.runner.ssh.SSHTunnel") as SSHTunnelMock, + patch("dstack._internal.server.services.runner.client.ShimClient") as ShimClientMock, + ): + shim_client_mock = ShimClientMock.return_value + await worker.process(_job_to_pipeline_item(job)) + SSHTunnelMock.assert_called_once() + shim_client_mock.healthcheck.assert_called_once() + + await session.refresh(job) + await session.refresh(instance) + assert job.status == JobStatus.TERMINATED + assert job.lock_token is None + assert job.lock_expires_at is None + assert instance.lock_token is None + assert instance.lock_owner is None + + events = await list_events(session) + assert any( + event.message == "Job status changed TERMINATING -> TERMINATED" for event in events + ) + + async def test_detaches_job_volumes( + self, test_db, session: AsyncSession, worker: JobTerminatingWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + volume = await create_volume( + session=session, + project=project, + user=user, + status=VolumeStatus.ACTIVE, + backend=BackendType.AWS, + volume_provisioning_data=get_volume_provisioning_data(), + ) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + volumes=[volume], + ) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + job_provisioning_data = get_job_provisioning_data(dockerized=False) + job = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + termination_reason=JobTerminationReason.TERMINATED_BY_USER, + submitted_at=datetime(2023, 1, 2, 5, 12, 30, 5, tzinfo=timezone.utc), + job_provisioning_data=job_provisioning_data, + instance=instance, + ) + _lock_job(job) + await session.commit() + + with patch( + "dstack._internal.server.background.pipeline_tasks.jobs_terminating.backends_services.get_project_backend_by_type" + ) as m: + backend_mock = Mock() + m.return_value = backend_mock + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value.is_volume_detached.return_value = True + await worker.process(_job_to_pipeline_item(job)) + m.assert_awaited_once() + backend_mock.compute.return_value.detach_volume.assert_called_once() + backend_mock.compute.return_value.is_volume_detached.assert_called_once() + + await session.refresh(job) + assert job.status == JobStatus.TERMINATED + await session.refresh(volume) + assert volume.last_job_processed_at is not None + + async def test_force_detaches_job_volumes( + self, test_db, session: AsyncSession, worker: JobTerminatingWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + volume = await create_volume( + session=session, + project=project, + user=user, + status=VolumeStatus.ACTIVE, + backend=BackendType.AWS, + volume_provisioning_data=get_volume_provisioning_data(), + ) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + volumes=[volume], + ) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + job_provisioning_data = get_job_provisioning_data(dockerized=False) + job = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + termination_reason=JobTerminationReason.TERMINATED_BY_USER, + submitted_at=datetime(2023, 1, 2, 5, 12, 30, 5, tzinfo=timezone.utc), + job_provisioning_data=job_provisioning_data, + instance=instance, + ) + _lock_job(job) + await session.commit() + + with patch( + "dstack._internal.server.background.pipeline_tasks.jobs_terminating.backends_services.get_project_backend_by_type" + ) as m: + backend_mock = Mock() + m.return_value = backend_mock + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value.is_volume_detached.return_value = False + await worker.process(_job_to_pipeline_item(job)) + + await session.refresh(job) + res = await session.execute(select(JobModel).options(joinedload(JobModel.instance))) + job = res.unique().scalar_one() + assert job.status == JobStatus.TERMINATING + assert job.instance is None + assert job.volumes_detached_at is not None + + _lock_job(job) + await session.commit() + with ( + patch( + "dstack._internal.server.background.pipeline_tasks.jobs_terminating.backends_services.get_project_backend_by_type" + ) as m, + patch( + "dstack._internal.server.background.pipeline_tasks.jobs_terminating.get_current_datetime" + ) as datetime_mock, + ): + datetime_mock.return_value = job.volumes_detached_at.replace( + tzinfo=timezone.utc + ) + timedelta(minutes=30) + backend_mock = Mock() + m.return_value = backend_mock + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value.is_volume_detached.return_value = False + await worker.process(_job_to_pipeline_item(job)) + backend_mock.compute.return_value.detach_volume.assert_called_once() + detach_kwargs = backend_mock.compute.return_value.detach_volume.call_args.kwargs + assert detach_kwargs["force"] is True + assert detach_kwargs["volume"].id == volume.id + assert ( + detach_kwargs["provisioning_data"].instance_id == job_provisioning_data.instance_id + ) + backend_mock.compute.return_value.is_volume_detached.assert_called_once() + + await session.refresh(job) + assert job.status == JobStatus.TERMINATING + + _lock_job(job) + await session.commit() + with patch( + "dstack._internal.server.background.pipeline_tasks.jobs_terminating.backends_services.get_project_backend_by_type" + ) as m: + backend_mock = Mock() + m.return_value = backend_mock + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value.is_volume_detached.return_value = True + await worker.process(_job_to_pipeline_item(job)) + backend_mock.compute.return_value.is_volume_detached.assert_called_once() + + await session.refresh(job) + await session.refresh(instance, ["volume_attachments"]) + res = await session.execute( + select(InstanceModel) + .where(InstanceModel.id == instance.id) + .options(joinedload(InstanceModel.volume_attachments)) + .execution_options(populate_existing=True) + ) + instance = res.unique().scalar_one() + assert job.status == JobStatus.TERMINATED + assert len(instance.volume_attachments) == 0 + + async def test_terminates_job_on_shared_instance( + self, test_db, session: AsyncSession, worker: JobTerminatingWorker + ): + project = await create_project(session) + user = await create_user(session) + repo = await create_repo(session=session, project_id=project.id) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + total_blocks=4, + busy_blocks=3, + ) + run = await create_run(session=session, project=project, repo=repo, user=user) + shared_offer = get_instance_offer_with_availability(blocks=2, total_blocks=4) + jrd = get_job_runtime_data(offer=shared_offer) + job = await create_job( + session=session, + run=run, + instance_assigned=True, + instance=instance, + job_runtime_data=jrd, + status=JobStatus.TERMINATING, + termination_reason=JobTerminationReason.TERMINATED_BY_USER, + ) + _lock_job(job) + await session.commit() + + await worker.process(_job_to_pipeline_item(job)) + + await session.refresh(job) + await session.refresh(instance) + res = await session.execute(select(JobModel).options(joinedload(JobModel.instance))) + job = res.unique().scalar_one() + assert job.status == JobStatus.TERMINATED + assert job.instance_assigned + assert job.instance is None + assert instance.busy_blocks == 1 + + async def test_detaches_job_volumes_on_shared_instance( + self, test_db, session: AsyncSession, worker: JobTerminatingWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + volume_conf_1 = get_volume_configuration(name="vol-1") + volume_1 = await create_volume( + session=session, + project=project, + user=user, + status=VolumeStatus.ACTIVE, + backend=BackendType.AWS, + configuration=volume_conf_1, + volume_provisioning_data=get_volume_provisioning_data(), + ) + volume_conf_2 = get_volume_configuration(name="vol-2") + volume_2 = await create_volume( + session=session, + project=project, + user=user, + status=VolumeStatus.ACTIVE, + backend=BackendType.AWS, + configuration=volume_conf_2, + volume_provisioning_data=get_volume_provisioning_data(), + ) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + volumes=[volume_1, volume_2], + ) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + job_provisioning_data = get_job_provisioning_data(dockerized=False) + job = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + termination_reason=JobTerminationReason.TERMINATED_BY_USER, + submitted_at=datetime(2023, 1, 2, 5, 12, 30, 5, tzinfo=timezone.utc), + job_provisioning_data=job_provisioning_data, + job_runtime_data=get_job_runtime_data(volume_names=["vol-1"]), + instance=instance, + ) + _lock_job(job) + await session.commit() + + with patch( + "dstack._internal.server.background.pipeline_tasks.jobs_terminating.backends_services.get_project_backend_by_type" + ) as m: + backend_mock = Mock() + m.return_value = backend_mock + backend_mock.compute.return_value = Mock(spec=ComputeMockSpec) + backend_mock.compute.return_value.is_volume_detached.return_value = True + + await worker.process(_job_to_pipeline_item(job)) + + backend_mock.compute.return_value.detach_volume.assert_called_once() + backend_mock.compute.return_value.is_volume_detached.assert_called_once() + + await session.refresh(job) + await session.refresh(instance) + res = await session.execute( + select(InstanceModel).options( + joinedload(InstanceModel.volume_attachments).joinedload( + VolumeAttachmentModel.volume + ) + ) + ) + instance = res.unique().scalar_one() + assert job.status == JobStatus.TERMINATED + assert len(instance.volume_attachments) == 1 + assert instance.volume_attachments[0].volume == volume_2 + + async def test_resets_job_for_retry_if_related_instance_is_locked( + self, test_db, session: AsyncSession, worker: JobTerminatingWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + instance.lock_owner = "OtherPipeline" + instance.lock_token = uuid.uuid4() + instance.lock_expires_at = get_current_datetime() + timedelta(minutes=1) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + job = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + termination_reason=JobTerminationReason.TERMINATED_BY_USER, + instance=instance, + ) + _lock_job(job) + last_processed_before = job.last_processed_at + await session.commit() + + await worker.process(_job_to_pipeline_item(job)) + + await session.refresh(job) + assert job.lock_token is None + assert job.lock_expires_at is None + assert job.lock_owner == JobTerminatingPipeline.__name__ + assert job.last_processed_at > last_processed_before + + async def test_resets_job_for_retry_if_related_instance_is_locked_by_another_job( + self, test_db, session: AsyncSession, worker: JobTerminatingWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + other_job = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + termination_reason=JobTerminationReason.TERMINATED_BY_USER, + instance=instance, + ) + job = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + termination_reason=JobTerminationReason.TERMINATED_BY_USER, + instance=instance, + ) + instance.lock_owner = _get_related_instance_lock_owner(other_job.id) + instance.lock_token = uuid.uuid4() + instance.lock_expires_at = get_current_datetime() - timedelta(minutes=1) + _lock_job(job) + last_processed_before = job.last_processed_at + await session.commit() + + await worker.process(_job_to_pipeline_item(job)) + + await session.refresh(job) + await session.refresh(instance) + assert job.lock_token is None + assert job.lock_expires_at is None + assert job.lock_owner == JobTerminatingPipeline.__name__ + assert job.last_processed_at > last_processed_before + assert instance.lock_owner == _get_related_instance_lock_owner(other_job.id) + + async def test_finishes_job_when_used_instance_is_not_set( + self, test_db, session: AsyncSession, worker: JobTerminatingWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + job = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + termination_reason=JobTerminationReason.TERMINATED_BY_USER, + ) + _lock_job(job) + await session.commit() + + await worker.process(_job_to_pipeline_item(job)) + + await session.refresh(job) + assert job.status == JobStatus.TERMINATED + assert job.lock_token is None + assert job.lock_expires_at is None + + async def test_retries_detaching_when_used_instance_is_missing( + self, test_db, session: AsyncSession, worker: JobTerminatingWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + job = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + termination_reason=JobTerminationReason.TERMINATED_BY_USER, + ) + job.instance_id = None + job.used_instance_id = uuid.uuid4() + job.volumes_detached_at = get_current_datetime() + _lock_job(job) + last_processed_before = job.last_processed_at + await session.commit() + + await worker.process(_job_to_pipeline_item(job)) + + await session.refresh(job) + assert job.status == JobStatus.TERMINATING + assert job.lock_token is None + assert job.lock_expires_at is None + assert job.lock_owner == JobTerminatingPipeline.__name__ + assert job.last_processed_at > last_processed_before + + async def test_retries_terminating_when_used_instance_is_missing( + self, test_db, session: AsyncSession, worker: JobTerminatingWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + job = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + termination_reason=JobTerminationReason.TERMINATED_BY_USER, + ) + job.used_instance_id = uuid.uuid4() + _lock_job(job) + last_processed_before = job.last_processed_at + await session.commit() + + await worker.process(_job_to_pipeline_item(job)) + + await session.refresh(job) + assert job.status == JobStatus.TERMINATING + assert job.lock_token is None + assert job.lock_expires_at is None + assert job.lock_owner == JobTerminatingPipeline.__name__ + assert job.last_processed_at > last_processed_before + + async def test_keeps_related_instance_locked_on_processing_exception( + self, test_db, session: AsyncSession, worker: JobTerminatingWorker + ): + project = await create_project(session=session) + user = await create_user(session=session) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, repo=repo, user=user) + job = await create_job( + session=session, + run=run, + status=JobStatus.TERMINATING, + termination_reason=JobTerminationReason.TERMINATED_BY_USER, + job_provisioning_data=get_job_provisioning_data(dockerized=True), + instance=instance, + ) + _lock_job(job) + job_lock_token = job.lock_token + await session.commit() + + with patch( + "dstack._internal.server.background.pipeline_tasks.jobs_terminating._process_terminating_job", + side_effect=RuntimeError("boom"), + ): + with pytest.raises(RuntimeError, match="boom"): + await worker.process(_job_to_pipeline_item(job)) + + await session.refresh(job) + await session.refresh(instance) + assert job.lock_token == job_lock_token + assert job.lock_owner == JobTerminatingPipeline.__name__ + assert instance.lock_token == job_lock_token + assert instance.lock_owner == _get_related_instance_lock_owner(job.id)