From 39d1d6b25770d1a6adfb63ee2ce4670007d34cfd Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Tue, 10 Mar 2026 01:41:53 +0100 Subject: [PATCH 1/3] Exports API --- src/dstack/_internal/core/models/exports.py | 19 + src/dstack/_internal/server/app.py | 2 + src/dstack/_internal/server/models.py | 18 +- .../_internal/server/routers/exports.py | 84 ++ .../_internal/server/schemas/exports.py | 19 + .../_internal/server/services/exports.py | 311 ++++++ .../_internal/server/services/fleets.py | 13 +- .../_internal/server/services/projects.py | 15 +- src/dstack/_internal/server/testing/common.py | 4 +- .../_internal/server/routers/test_exports.py | 948 ++++++++++++++++++ 10 files changed, 1415 insertions(+), 18 deletions(-) create mode 100644 src/dstack/_internal/core/models/exports.py create mode 100644 src/dstack/_internal/server/routers/exports.py create mode 100644 src/dstack/_internal/server/schemas/exports.py create mode 100644 src/dstack/_internal/server/services/exports.py create mode 100644 src/tests/_internal/server/routers/test_exports.py diff --git a/src/dstack/_internal/core/models/exports.py b/src/dstack/_internal/core/models/exports.py new file mode 100644 index 000000000..52cb4a65a --- /dev/null +++ b/src/dstack/_internal/core/models/exports.py @@ -0,0 +1,19 @@ +import uuid + +from dstack._internal.core.models.common import CoreModel + + +class ExportImport(CoreModel): + project_name: str + + +class ExportedFleet(CoreModel): + id: uuid.UUID + name: str + + +class Export(CoreModel): + id: uuid.UUID + name: str + imports: list[ExportImport] + exported_fleets: list[ExportedFleet] diff --git a/src/dstack/_internal/server/app.py b/src/dstack/_internal/server/app.py index 56ea7f038..a0e22b847 100644 --- a/src/dstack/_internal/server/app.py +++ b/src/dstack/_internal/server/app.py @@ -31,6 +31,7 @@ auth, backends, events, + exports, files, fleets, gateways, @@ -253,6 +254,7 @@ def register_routes(app: FastAPI, ui: bool = True): app.include_router(files.router) app.include_router(events.root_router) app.include_router(templates.router) + app.include_router(exports.project_router) @app.exception_handler(ForbiddenError) async def forbidden_error_handler(request: Request, exc: ForbiddenError): diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index 2f82b5e59..9e3ff608d 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -49,6 +49,9 @@ from dstack._internal.utils.logging import get_logger logger = get_logger(__name__) +# Default options (save-update, merge) + delete-orphan + delete (required by delete-orphan) +# delete-orphan allows to automatically delete entities removed from the relationship +CASCADE_DEFAULT_WITH_DELETE_ORPHAN = "save-update, merge, delete-orphan, delete" class NaiveDateTime(TypeDecorator): @@ -760,10 +763,7 @@ class InstanceModel(PipelineModelMixin, BaseModel): volume_attachments: Mapped[List["VolumeAttachmentModel"]] = relationship( back_populates="instance", - # Add delete-orphan option so that removing entries from volume_attachments - # automatically marks them for deletion. - # SQLAlchemy requires delete when using delete-orphan. - cascade="save-update, merge, delete-orphan, delete", + cascade=CASCADE_DEFAULT_WITH_DELETE_ORPHAN, ) __table_args__ = ( @@ -1043,8 +1043,14 @@ class ExportModel(BaseModel): ) project: Mapped["ProjectModel"] = relationship() created_at: Mapped[datetime] = mapped_column(NaiveDateTime, default=get_current_datetime) - imports: Mapped[List["ImportModel"]] = relationship(back_populates="export") - exported_fleets: Mapped[List["ExportedFleetModel"]] = relationship(back_populates="export") + imports: Mapped[List["ImportModel"]] = relationship( + back_populates="export", + cascade=CASCADE_DEFAULT_WITH_DELETE_ORPHAN, + ) + exported_fleets: Mapped[List["ExportedFleetModel"]] = relationship( + back_populates="export", + cascade=CASCADE_DEFAULT_WITH_DELETE_ORPHAN, + ) class ImportModel(BaseModel): diff --git a/src/dstack/_internal/server/routers/exports.py b/src/dstack/_internal/server/routers/exports.py new file mode 100644 index 000000000..bc30ad822 --- /dev/null +++ b/src/dstack/_internal/server/routers/exports.py @@ -0,0 +1,84 @@ +from typing import Annotated + +from fastapi import APIRouter, Depends +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.models.exports import Export +from dstack._internal.server.db import get_session +from dstack._internal.server.models import ProjectModel, UserModel +from dstack._internal.server.schemas.exports import ( + CreateExportRequest, + DeleteExportRequest, + UpdateExportRequest, +) +from dstack._internal.server.security.permissions import ProjectAdmin, ProjectMember +from dstack._internal.server.services import exports as exports_services +from dstack._internal.server.utils.routers import get_base_api_additional_responses + +project_router = APIRouter( + prefix="/api/project/{project_name}/exports", + tags=["exports"], + responses=get_base_api_additional_responses(), +) + + +@project_router.post("/create", response_model=Export) +async def create_export( + body: CreateExportRequest, + session: Annotated[AsyncSession, Depends(get_session)], + user_project: Annotated[tuple[UserModel, ProjectModel], Depends(ProjectAdmin())], +): + user, project = user_project + return await exports_services.create_export( + session=session, + project=project, + user=user, + name=body.name, + importer_project_names=body.importer_projects, + exported_fleet_names=body.exported_fleets, + ) + + +@project_router.post("/update", response_model=Export) +async def update_export( + body: UpdateExportRequest, + session: Annotated[AsyncSession, Depends(get_session)], + user_project: Annotated[tuple[UserModel, ProjectModel], Depends(ProjectAdmin())], +): + user, project = user_project + return await exports_services.update_export( + session=session, + project=project, + user=user, + name=body.name, + add_importer_project_names=body.add_importer_projects, + remove_importer_project_names=body.remove_importer_projects, + add_exported_fleet_names=body.add_exported_fleets, + remove_exported_fleet_names=body.remove_exported_fleets, + ) + + +@project_router.post("/delete") +async def delete_export( + body: DeleteExportRequest, + session: Annotated[AsyncSession, Depends(get_session)], + user_project: Annotated[tuple[UserModel, ProjectModel], Depends(ProjectAdmin())], +): + _, project = user_project + await exports_services.delete_export( + session=session, + project=project, + name=body.name, + ) + + +@project_router.post("/list", response_model=list[Export]) +async def list_exports( + session: Annotated[AsyncSession, Depends(get_session)], + user_project: Annotated[tuple[UserModel, ProjectModel], Depends(ProjectMember())], +): + _, project = user_project + return await exports_services.list_exports( + session=session, + project=project, + ) diff --git a/src/dstack/_internal/server/schemas/exports.py b/src/dstack/_internal/server/schemas/exports.py new file mode 100644 index 000000000..240b6364a --- /dev/null +++ b/src/dstack/_internal/server/schemas/exports.py @@ -0,0 +1,19 @@ +from dstack._internal.core.models.common import CoreModel + + +class CreateExportRequest(CoreModel): + name: str + importer_projects: list[str] = [] + exported_fleets: list[str] = [] + + +class UpdateExportRequest(CoreModel): + name: str + add_importer_projects: list[str] = [] + remove_importer_projects: list[str] = [] + add_exported_fleets: list[str] = [] + remove_exported_fleets: list[str] = [] + + +class DeleteExportRequest(CoreModel): + name: str diff --git a/src/dstack/_internal/server/services/exports.py b/src/dstack/_internal/server/services/exports.py new file mode 100644 index 000000000..4c47187dd --- /dev/null +++ b/src/dstack/_internal/server/services/exports.py @@ -0,0 +1,311 @@ +from collections.abc import AsyncGenerator +from contextlib import asynccontextmanager +from typing import Optional + +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from dstack._internal.core.errors import ( + ResourceExistsError, + ResourceNotExistsError, + ServerClientError, +) +from dstack._internal.core.models.exports import Export, ExportedFleet, ExportImport +from dstack._internal.core.models.users import GlobalRole +from dstack._internal.core.services import validate_dstack_resource_name +from dstack._internal.server.db import get_db, is_db_postgres, is_db_sqlite +from dstack._internal.server.models import ( + ExportedFleetModel, + ExportModel, + FleetModel, + ImportModel, + ProjectModel, + ProjectRole, + UserModel, +) +from dstack._internal.server.services.fleets import get_fleet_spec, list_project_fleet_models +from dstack._internal.server.services.locking import get_locker, string_to_lock_id +from dstack._internal.server.services.projects import ( + get_user_project_role, + list_user_project_models, +) + + +@asynccontextmanager +async def get_export_model_by_name_for_update( + session: AsyncSession, project: ProjectModel, name: str +) -> AsyncGenerator[Optional[ExportModel], None]: + """ + Fetch export from the database and lock it for update. + + **NOTE**: commit changes to the database before exiting from this context manager, + so that in-memory locks are only released after commit. + """ + filters = [ + ExportModel.project_id == project.id, + ExportModel.name == name, + ] + res = await session.execute(select(ExportModel.id).where(*filters)) + export_id = res.scalars().one_or_none() + if not export_id: + yield None + else: + async with get_locker(get_db().dialect_name).lock_ctx( + ExportModel.__tablename__, [export_id] + ): + # Refetch after lock + res = await session.execute( + select(ExportModel) + .where(ExportModel.id == export_id, *filters) + .options( + selectinload( + ExportModel.imports.and_( + ImportModel.project.has(ProjectModel.deleted == False) + ) + ) + .joinedload(ImportModel.project) + .load_only(ProjectModel.name), + selectinload( + ExportModel.exported_fleets.and_( + ExportedFleetModel.fleet.has(FleetModel.deleted == False) + ) + ) + .joinedload(ExportedFleetModel.fleet) + .load_only(FleetModel.name), + ) + .with_for_update(key_share=True) + ) + yield res.scalars().one_or_none() + + +async def export_exists(session: AsyncSession, project: ProjectModel, name: str) -> bool: + res = await session.execute( + select(func.count()) + .select_from(ExportModel) + .where(ExportModel.project_id == project.id, ExportModel.name == name) + ) + return res.scalar_one() > 0 + + +async def create_export( + session: AsyncSession, + project: ProjectModel, + user: UserModel, + name: str, + importer_project_names: list[str], + exported_fleet_names: list[str], +) -> Export: + validate_dstack_resource_name(name) + + lock_namespace = f"export_names_{project.name}" + if is_db_sqlite(): + # Start new transaction to see committed changes after lock + await session.commit() + elif is_db_postgres(): + await session.execute( + select(func.pg_advisory_xact_lock(string_to_lock_id(lock_namespace))) + ) + lock, _ = get_locker(get_db().dialect_name).get_lockset(lock_namespace) + + async with lock: + if await export_exists(session, project, name): + raise ResourceExistsError( + f"Export {name!r} already exists in project {project.name!r}" + ) + export = ExportModel( + name=name, + project=project, + imports=[], + exported_fleets=[], + ) + await add_importer_projects(session, user, export, importer_project_names) + await add_exported_fleets(session, export, exported_fleet_names) + session.add(export) + await session.commit() + return export_model_to_export(export) + + +async def update_export( + session: AsyncSession, + project: ProjectModel, + user: UserModel, + name: str, + add_importer_project_names: list[str], + remove_importer_project_names: list[str], + add_exported_fleet_names: list[str], + remove_exported_fleet_names: list[str], +) -> Export: + async with get_export_model_by_name_for_update(session, project, name) as export: + if export is None: + raise ResourceNotExistsError(f"Export {name!r} not found in project {project.name!r}") + + if ( + not add_importer_project_names + and not remove_importer_project_names + and not add_exported_fleet_names + and not remove_exported_fleet_names + ): + raise ServerClientError("No changes specified") + + add_importer_project_names = list(map(str.lower, add_importer_project_names)) + remove_importer_project_names = list(map(str.lower, remove_importer_project_names)) + + add_remove_conflict_projects = set(add_importer_project_names) & set( + remove_importer_project_names + ) + if add_remove_conflict_projects: + raise ServerClientError( + f"Projects {add_remove_conflict_projects} are listed for both addition and removal." + " Cannot add and remove at the same time" + ) + add_remove_conflict_fleets = set(add_exported_fleet_names) & set( + remove_exported_fleet_names + ) + if add_remove_conflict_fleets: + raise ServerClientError( + f"Fleets {add_remove_conflict_fleets} are listed for both addition and removal." + " Cannot add and remove at the same time" + ) + + await add_importer_projects(session, user, export, add_importer_project_names) + await add_exported_fleets(session, export, add_exported_fleet_names) + await remove_importer_projects(export, remove_importer_project_names) + await remove_exported_fleets(export, remove_exported_fleet_names) + + await session.commit() + return export_model_to_export(export) + + +async def add_importer_projects( + session: AsyncSession, user: UserModel, export: ExportModel, names: list[str] +) -> None: + if not names: + return + names = list(map(str.lower, names)) + if len(names) != len(set(names)): + raise ServerClientError("Some importer projects are listed for addition more than once") + already_importing = {imp.project.name.lower() for imp in export.imports} & set(names) + if already_importing: + raise ServerClientError( + f"Projects {already_importing} are already importing export {export.name!r}" + ) + if export.project.name.lower() in names: + raise ServerClientError(f"Project {export.project.name!r} cannot import from itself") + projects = await list_user_project_models(session, user, only_names=True, include_members=True) + projects = [p for p in projects if p.name.lower() in names] + if user.global_role != GlobalRole.ADMIN: + projects = [p for p in projects if get_user_project_role(user, p) == ProjectRole.ADMIN] + if missing := set(names) - {p.name.lower() for p in projects}: + raise ServerClientError( + f"Projects {missing} not found or you are not allowed to add them as importers." + " Only project admins can add a project as importer" + ) + for project in projects: + export.imports.append(ImportModel(project=project)) + + +async def add_exported_fleets( + session: AsyncSession, export: ExportModel, names: list[str] +) -> None: + if not names: + return + if len(names) != len(set(names)): + raise ServerClientError("Some fleets are listed for addition more than once") + already_exported = {ef.fleet.name for ef in export.exported_fleets} & set(names) + if already_exported: + raise ServerClientError( + f"Fleets {already_exported} are already exported by export {export.name!r}" + ) + fleets = await list_project_fleet_models( + session=session, + project=export.project, + names=names, + include_imported=False, + include_deleted=False, + include_instances=False, + ) + if missing := set(names) - {f.name for f in fleets}: + raise ResourceNotExistsError( + f"Fleets {missing} not found in project {export.project.name!r}" + ) + cloud_fleet_names = [ + f.name for f in fleets if get_fleet_spec(f).configuration.ssh_config is None + ] + if cloud_fleet_names: + raise ServerClientError( + f"Fleets {cloud_fleet_names} are cloud fleets. Can only export SSH fleets" + ) + for fleet in fleets: + export.exported_fleets.append(ExportedFleetModel(fleet=fleet)) + + +async def remove_importer_projects(export: ExportModel, names: list[str]) -> None: + names = list(map(str.lower, names)) + if len(names) != len(set(names)): + raise ServerClientError("Some importer projects are listed for removal more than once") + existing = {imp.project.name.lower() for imp in export.imports} + if missing := set(names) - existing: + raise ServerClientError(f"Projects {missing} are not importing export {export.name!r}") + export.imports = [imp for imp in export.imports if imp.project.name.lower() not in names] + + +async def remove_exported_fleets(export: ExportModel, names: list[str]) -> None: + if len(names) != len(set(names)): + raise ServerClientError("Some fleets are listed for removal more than once") + existing = {ef.fleet.name for ef in export.exported_fleets} + if missing := set(names) - existing: + raise ServerClientError(f"Fleets {missing} are not exported by export {export.name!r}") + export.exported_fleets = [ef for ef in export.exported_fleets if ef.fleet.name not in names] + + +async def delete_export(session: AsyncSession, project: ProjectModel, name: str) -> None: + async with get_export_model_by_name_for_update(session, project, name) as export: + if export is None: + raise ResourceNotExistsError(f"Export {name!r} not found in project {project.name!r}") + await session.delete(export) + await session.commit() + + +async def list_exports(session: AsyncSession, project: ProjectModel) -> list[Export]: + res = await session.execute( + select(ExportModel) + .where(ExportModel.project == project) + .options( + selectinload( + ExportModel.imports.and_(ImportModel.project.has(ProjectModel.deleted == False)) + ) + .joinedload(ImportModel.project) + .load_only(ProjectModel.name), + selectinload( + ExportModel.exported_fleets.and_( + ExportedFleetModel.fleet.has(FleetModel.deleted == False) + ) + ) + .joinedload(ExportedFleetModel.fleet) + .load_only(FleetModel.name), + ) + .order_by(ExportModel.created_at.desc()) + ) + exports = res.scalars().all() + return [export_model_to_export(export) for export in exports] + + +def export_model_to_export(export_model: ExportModel) -> Export: + return Export( + id=export_model.id, + name=export_model.name, + imports=[ + ExportImport( + project_name=import_model.project.name, + ) + for import_model in export_model.imports + ], + exported_fleets=[ + ExportedFleet( + id=exported_fleet_model.fleet.id, + name=exported_fleet_model.fleet.name, + ) + for exported_fleet_model in export_model.exported_fleets + ], + ) diff --git a/src/dstack/_internal/server/services/fleets.py b/src/dstack/_internal/server/services/fleets.py index 183e81b20..969ef1ace 100644 --- a/src/dstack/_internal/server/services/fleets.py +++ b/src/dstack/_internal/server/services/fleets.py @@ -310,6 +310,7 @@ async def list_project_fleet_models( names: Optional[List[str]] = None, include_imported: bool = False, include_deleted: bool = False, + include_instances: bool = True, ) -> List[FleetModel]: filters = [] is_fleet_imported_subquery = exists().where( @@ -327,14 +328,10 @@ async def list_project_fleet_models( filters.append(FleetModel.name.in_(names)) if not include_deleted: filters.append(FleetModel.deleted == False) - res = await session.execute( - select(FleetModel) - .where(*filters) - .options( - joinedload(FleetModel.project).load_only(ProjectModel.name), - selectinload(FleetModel.instances.and_(InstanceModel.deleted == False)), - ) - ) + options = [joinedload(FleetModel.project).load_only(ProjectModel.name)] + if include_instances: + options.append(selectinload(FleetModel.instances.and_(InstanceModel.deleted == False))) + res = await session.execute(select(FleetModel).where(*filters).options(*options)) return list(res.unique().scalars().all()) diff --git a/src/dstack/_internal/server/services/projects.py b/src/dstack/_internal/server/services/projects.py index d5aa6fb14..90a19f0f0 100644 --- a/src/dstack/_internal/server/services/projects.py +++ b/src/dstack/_internal/server/services/projects.py @@ -457,14 +457,20 @@ async def list_user_project_models( session: AsyncSession, user: UserModel, only_names: bool = False, + include_members: bool = False, ) -> List[ProjectModel]: load_only_attrs = [] if only_names: load_only_attrs += [ProjectModel.id, ProjectModel.name] if user.global_role == GlobalRole.ADMIN: - return await list_project_models(session=session, load_only_attrs=load_only_attrs) + return await list_project_models( + session=session, load_only_attrs=load_only_attrs, include_members=include_members + ) return await list_member_project_models( - session=session, user=user, load_only_attrs=load_only_attrs + session=session, + user=user, + load_only_attrs=load_only_attrs, + include_members=include_members, ) @@ -529,14 +535,17 @@ async def list_user_owned_project_models( async def list_project_models( session: AsyncSession, load_only_attrs: Optional[List[QueryableAttribute]] = None, + include_members: bool = False, ) -> List[ProjectModel]: options = [] + if include_members: + options.append(joinedload(ProjectModel.members)) if load_only_attrs: options.append(load_only(*load_only_attrs)) res = await session.execute( select(ProjectModel).where(ProjectModel.deleted == False).options(*options) ) - return list(res.scalars().all()) + return list(res.scalars().unique().all()) # TODO: Do not load ProjectModel.backends and ProjectModel.members by default when getting project diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 89072e155..2b4c34712 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -171,6 +171,7 @@ async def create_project( ssh_public_key: str = "", is_public: bool = False, templates_repo: Optional[str] = None, + deleted: bool = False, ) -> ProjectModel: if owner is None: owner = await create_user(session=session, name="test_owner") @@ -182,6 +183,7 @@ async def create_project( ssh_public_key=ssh_public_key, is_public=is_public, templates_repo=templates_repo, + deleted=deleted, ) session.add(project) await session.commit() @@ -526,7 +528,7 @@ async def create_export( exporter_project: ProjectModel, importer_projects: list[ProjectModel], exported_fleets: list[FleetModel], - name: str = "test_export", + name: str = "test-export", ) -> ExportModel: export = ExportModel( name=name, diff --git a/src/tests/_internal/server/routers/test_exports.py b/src/tests/_internal/server/routers/test_exports.py new file mode 100644 index 000000000..a6bed34c4 --- /dev/null +++ b/src/tests/_internal/server/routers/test_exports.py @@ -0,0 +1,948 @@ +from typing import Optional + +import pytest +from httpx import AsyncClient +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.models.users import GlobalRole, ProjectRole +from dstack._internal.server.models import ExportModel +from dstack._internal.server.services.projects import add_project_member +from dstack._internal.server.testing.common import ( + create_export, + create_fleet, + create_project, + create_user, + get_auth_headers, + get_fleet_spec, + get_ssh_fleet_configuration, +) + +pytestmark = [ + pytest.mark.asyncio, + pytest.mark.usefixtures("test_db"), + pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True), +] + + +class TestCreateExport: + async def test_returns_403_if_not_authenticated(self, client: AsyncClient): + response = await client.post( + "/api/project/TestProject/exports/create", + json={ + "name": "test-export", + "importer_projects": ["OtherProject"], + "exported_fleets": ["fleet1"], + }, + ) + assert response.status_code in [401, 403] + + async def test_returns_403_if_not_admin(self, session: AsyncSession, client: AsyncClient): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + response = await client.post( + f"/api/project/{project.name}/exports/create", + headers=get_auth_headers(user.token), + json={ + "name": "test-export", + "importer_projects": ["OtherProject"], + "exported_fleets": ["fleet1"], + }, + ) + assert response.status_code == 403 + + @pytest.mark.parametrize( + ("global_role", "importer_project_role"), + [(GlobalRole.ADMIN, None), (GlobalRole.USER, ProjectRole.ADMIN)], + ) + async def test_creates_export( + self, + session: AsyncSession, + client: AsyncClient, + global_role: GlobalRole, + importer_project_role: Optional[ProjectRole], + ): + user = await create_user(session=session, global_role=global_role) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.ADMIN + ) + + importer_project = await create_project( + session=session, name="ImporterProject", owner=user + ) + if importer_project_role is not None: + await add_project_member( + session=session, + project=importer_project, + user=user, + project_role=importer_project_role, + ) + await create_fleet( + session=session, + project=project, + name="fleet1", + spec=get_fleet_spec(get_ssh_fleet_configuration()), + ) + + response = await client.post( + f"/api/project/{project.name}/exports/create", + headers=get_auth_headers(user.token), + json={ + "name": "test-export", + "importer_projects": ["ImporterProject"], + "exported_fleets": ["fleet1"], + }, + ) + assert response.status_code == 200 + export_response = response.json() + assert export_response["name"] == "test-export" + assert len(export_response["imports"]) == 1 + assert export_response["imports"][0]["project_name"] == "ImporterProject" + assert len(export_response["exported_fleets"]) == 1 + assert export_response["exported_fleets"][0]["name"] == "fleet1" + + res = await session.execute(select(ExportModel).where(ExportModel.name == "test-export")) + assert res.scalar() is not None + + async def test_creates_empty_export(self, session: AsyncSession, client: AsyncClient): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.ADMIN + ) + + response = await client.post( + f"/api/project/{project.name}/exports/create", + headers=get_auth_headers(user.token), + json={ + "name": "empty-export", + }, + ) + assert response.status_code == 200 + export_response = response.json() + assert export_response["name"] == "empty-export" + assert len(export_response["imports"]) == 0 + assert len(export_response["exported_fleets"]) == 0 + + res = await session.execute(select(ExportModel).where(ExportModel.name == "empty-export")) + assert res.scalar() is not None + + @pytest.mark.parametrize( + "body,error", + [ + pytest.param( + { + "name": "test-export", + "importer_projects": ["nonexistent"], + }, + "Projects {'nonexistent'} not found or you are not allowed to add them as importers", + id="nonexistent-project", + ), + pytest.param( + { + "name": "test-export", + "importer_projects": ["NotPermittedProject"], + }, + "Projects {'notpermittedproject'} not found or you are not allowed to add them as importers", + id="not-permitted-project", + ), + pytest.param( + { + "name": "test-export", + "exported_fleets": ["nonexistent-fleet"], + }, + "Fleets {'nonexistent-fleet'} not found in project 'ExporterProject'", + id="nonexistent-fleet", + ), + pytest.param( + { + "name": "test-export", + "importer_projects": [ + "ImporterProject", + "iMpOrTeRpRoJeCt", + ], # case-insensitive + }, + "Some importer projects are listed for addition more than once", + id="duplicate-project", + ), + pytest.param( + { + "name": "test-export", + "exported_fleets": ["exported-fleet", "exported-fleet"], + }, + "Some fleets are listed for addition more than once", + id="duplicate-fleet", + ), + pytest.param( + { + "name": "test-export", + "exported_fleets": ["cloud-fleet"], + }, + "Fleets ['cloud-fleet'] are cloud fleets. Can only export SSH fleets", + id="cloud-fleet", + ), + pytest.param( + { + "name": "test-export", + "importer_projects": ["eXpOrTeRpRoJeCt"], # case-insensitive + }, + "Project 'ExporterProject' cannot import from itself", + id="self-import", + ), + pytest.param( + { + "name": "", + }, + "Resource name should match regex '^[a-z][a-z0-9-]{1,40}$'", + id="empty-name", + ), + pytest.param( + { + "name": "a" * 256, + }, + "Resource name should match regex '^[a-z][a-z0-9-]{1,40}$'", + id="long-name", + ), + pytest.param( + { + "name": "!@#$", + }, + "Resource name should match regex '^[a-z][a-z0-9-]{1,40}$'", + id="invalid-name", + ), + ], + ) + async def test_rejects_invalid_export( + self, session: AsyncSession, client: AsyncClient, body: dict, error: str + ): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, name="ExporterProject", owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.ADMIN + ) + importer_project = await create_project( + session=session, name="ImporterProject", owner=user + ) + await add_project_member( + session=session, project=importer_project, user=user, project_role=ProjectRole.ADMIN + ) + await create_fleet( + session=session, + project=project, + name="exported-fleet", + spec=get_fleet_spec(get_ssh_fleet_configuration()), + ) + await create_fleet(session=session, project=project, name="cloud-fleet") + not_permitted_project = await create_project( + session=session, name="NotPermittedProject", owner=user + ) + await add_project_member( + session=session, + project=not_permitted_project, + user=user, + project_role=ProjectRole.USER, + ) + + response = await client.post( + f"/api/project/{project.name}/exports/create", + headers=get_auth_headers(user.token), + json=body, + ) + assert response.status_code == 400 + assert error in response.json()["detail"][0]["msg"] + res = await session.execute(select(func.count()).select_from(ExportModel)) + assert res.scalar_one() == 0 + + async def test_rejects_export_on_name_conflict( + self, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, name="Project") + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.ADMIN + ) + await create_export( + session=session, + exporter_project=project, + importer_projects=[], + exported_fleets=[], + name="test-export", + ) + + response = await client.post( + f"/api/project/{project.name}/exports/create", + headers=get_auth_headers(user.token), + json={"name": "test-export"}, + ) + assert response.status_code == 400 + assert response.json()["detail"][0]["code"] == "resource_exists" + assert ( + response.json()["detail"][0]["msg"] + == "Export 'test-export' already exists in project 'Project'" + ) + res = await session.execute(select(func.count()).select_from(ExportModel)) + assert res.scalar_one() == 1 + + +class TestUpdateExport: + async def test_returns_403_if_not_authenticated(self, client: AsyncClient): + response = await client.post( + "/api/project/TestProject/exports/update", + json={ + "name": "test-export", + }, + ) + assert response.status_code in [401, 403] + + async def test_returns_403_if_not_admin(self, session: AsyncSession, client: AsyncClient): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + response = await client.post( + f"/api/project/{project.name}/exports/update", + headers=get_auth_headers(user.token), + json={ + "name": "test-export", + }, + ) + assert response.status_code == 403 + + @pytest.mark.parametrize( + ("global_role", "importer_project_role"), + [(GlobalRole.ADMIN, None), (GlobalRole.USER, ProjectRole.ADMIN)], + ) + async def test_updates_export( + self, + session: AsyncSession, + client: AsyncClient, + global_role: GlobalRole, + importer_project_role: Optional[ProjectRole], + ): + user = await create_user(session=session, global_role=global_role) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.ADMIN + ) + + other_project = await create_project(session=session, name="OtherProject", owner=user) + another_project = await create_project(session=session, name="AnotherProject", owner=user) + fleet1 = await create_fleet( + session=session, + project=project, + name="fleet1", + spec=get_fleet_spec(get_ssh_fleet_configuration()), + ) + fleet2 = await create_fleet( + session=session, + project=project, + name="fleet2", + spec=get_fleet_spec(get_ssh_fleet_configuration()), + ) + export = await create_export( + session=session, + exporter_project=project, + importer_projects=[other_project, another_project], + exported_fleets=[fleet1, fleet2], + name="test-export", + ) + + new_project1 = await create_project(session=session, name="NewProject1", owner=user) + new_project2 = await create_project(session=session, name="NewProject2", owner=user) + await create_fleet( + session=session, + project=project, + name="fleet3", + spec=get_fleet_spec(get_ssh_fleet_configuration()), + ) + await create_fleet( + session=session, + project=project, + name="fleet4", + spec=get_fleet_spec(get_ssh_fleet_configuration()), + ) + if importer_project_role is not None: + await add_project_member( + session=session, project=new_project1, user=user, project_role=ProjectRole.ADMIN + ) + await add_project_member( + session=session, project=new_project2, user=user, project_role=ProjectRole.ADMIN + ) + + response = await client.post( + f"/api/project/{project.name}/exports/update", + headers=get_auth_headers(user.token), + json={ + "name": "test-export", + "add_importer_projects": ["NewProject1", "NewProject2"], + "remove_importer_projects": ["AnotherProject"], + "add_exported_fleets": ["fleet3", "fleet4"], + "remove_exported_fleets": ["fleet2"], + }, + ) + assert response.status_code == 200 + export_response = response.json() + + assert export_response["name"] == "test-export" + assert len(export_response["imports"]) == 3 + assert {imp["project_name"] for imp in export_response["imports"]} == { + "OtherProject", + "NewProject1", + "NewProject2", + } + assert len(export_response["exported_fleets"]) == 3 + assert {fleet["name"] for fleet in export_response["exported_fleets"]} == { + "fleet1", + "fleet3", + "fleet4", + } + + await session.refresh(export, ["imports", "exported_fleets"]) + assert len(export.imports) == 3 + assert len(export.exported_fleets) == 3 + + response = await client.post( + f"/api/project/{project.name}/exports/list", headers=get_auth_headers(user.token) + ) + assert response.status_code == 200 + assert len(response.json()) == 1 + assert response.json()[0] == export_response + + async def test_can_add_same_entities_as_existing_deleted_ones( + self, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.ADMIN + ) + + deleted_importer_project = await create_project( + session=session, name="_deleted_ImporterProject", owner=user, deleted=True + ) + importer_project = await create_project( + session=session, name="ImporterProject", owner=user + ) + await add_project_member( + session=session, project=importer_project, user=user, project_role=ProjectRole.ADMIN + ) + deleted_fleet = await create_fleet( + session=session, + project=project, + name="fleet", + spec=get_fleet_spec(get_ssh_fleet_configuration()), + deleted=True, + ) + fleet = await create_fleet( + session=session, + project=project, + name=deleted_fleet.name, + spec=get_fleet_spec(get_ssh_fleet_configuration()), + ) + export = await create_export( + session=session, + exporter_project=project, + importer_projects=[deleted_importer_project], + exported_fleets=[deleted_fleet], + name="test-export", + ) + + response = await client.post( + f"/api/project/{project.name}/exports/update", + headers=get_auth_headers(user.token), + json={ + "name": "test-export", + "add_importer_projects": ["ImporterProject"], + "add_exported_fleets": ["fleet"], + }, + ) + assert response.status_code == 200 + export_response = response.json() + + assert export_response["name"] == "test-export" + assert len(export_response["imports"]) == 1 + assert export_response["imports"][0]["project_name"] == "ImporterProject" + assert len(export_response["exported_fleets"]) == 1 + assert export_response["exported_fleets"][0]["name"] == "fleet" + assert export_response["exported_fleets"][0]["id"] == str(fleet.id) + + await session.refresh(export, ["imports", "exported_fleets"]) + # deleted imports and fleets are still in the database, just not returned in the response + assert len(export.imports) == 2 + assert len(export.exported_fleets) == 2 + + response = await client.post( + f"/api/project/{project.name}/exports/list", headers=get_auth_headers(user.token) + ) + assert response.status_code == 200 + assert len(response.json()) == 1 + assert response.json()[0] == export_response + + @pytest.mark.parametrize( + "body,error", + [ + pytest.param( + { + "name": "nonexistent-export", + "add_importer_projects": ["NotImporterProject"], + }, + "Export 'nonexistent-export' not found in project 'ExporterProject'", + id="nonexistent-export", + ), + pytest.param( + { + "name": "test-export", + }, + "No changes specified", + id="no-changes", + ), + pytest.param( + { + "name": "test-export", + "add_importer_projects": ["nonexistent"], + }, + "Projects {'nonexistent'} not found or you are not allowed to add them as importers", + id="add-nonexistent-project", + ), + pytest.param( + { + "name": "test-export", + "add_importer_projects": ["NotPermittedProject"], + }, + "Projects {'notpermittedproject'} not found or you are not allowed to add them as importers", + id="add-not-permitted-project", + ), + pytest.param( + { + "name": "test-export", + "add_exported_fleets": ["nonexistent-fleet"], + }, + "Fleets {'nonexistent-fleet'} not found in project 'ExporterProject'", + id="add-nonexistent-fleet", + ), + pytest.param( + { + "name": "test-export", + "add_importer_projects": ["iMpOrTeRpRoJeCt"], # case-insensitive + }, + "Projects {'importerproject'} are already importing export 'test-export'", + id="add-already-added-project", + ), + pytest.param( + { + "name": "test-export", + "add_importer_projects": [ + "ImporterProject", + "iMpOrTeRpRoJeCt", + ], # case-insensitive + }, + "Some importer projects are listed for addition more than once", + id="add-duplicate-project", + ), + pytest.param( + { + "name": "test-export", + "add_exported_fleets": ["exported-fleet"], + }, + "Fleets {'exported-fleet'} are already exported by export 'test-export'", + id="add-already-added-fleet", + ), + pytest.param( + { + "name": "test-export", + "add_exported_fleets": ["exported-fleet", "exported-fleet"], + }, + "Some fleets are listed for addition more than once", + id="add-duplicate-fleet", + ), + pytest.param( + { + "name": "test-export", + "add_exported_fleets": ["cloud-fleet"], + }, + "Fleets ['cloud-fleet'] are cloud fleets. Can only export SSH fleets", + id="add-cloud-fleet", + ), + pytest.param( + { + "name": "test-export", + "add_importer_projects": ["eXpOrTeRpRoJeCt"], # case-insensitive + }, + "Project 'ExporterProject' cannot import from itself", + id="add-self-import", + ), + pytest.param( + { + "name": "test-export", + "remove_importer_projects": ["NotImporterProject"], + }, + "Projects {'notimporterproject'} are not importing export 'test-export'", + id="remove-not-added-project", + ), + pytest.param( + { + "name": "test-export", + "remove_importer_projects": ["nonexistent"], + }, + "Projects {'nonexistent'} are not importing export 'test-export'", + id="remove-nonexistent-project", + ), + pytest.param( + { + "name": "test-export", + "remove_exported_fleets": ["not-exported-fleet"], + }, + "Fleets {'not-exported-fleet'} are not exported by export 'test-export'", + id="remove-not-exported-fleet", + ), + pytest.param( + { + "name": "test-export", + "remove_exported_fleets": ["nonexistent-fleet"], + }, + "Fleets {'nonexistent-fleet'} are not exported by export 'test-export'", + id="remove-nonexistent-fleet", + ), + pytest.param( + { + "name": "test-export", + "remove_importer_projects": [ + "ImporterProject", + "iMpOrTeRpRoJeCt", + ], # case-insensitive + }, + "Some importer projects are listed for removal more than once", + id="remove-duplicate-project", + ), + pytest.param( + { + "name": "test-export", + "remove_exported_fleets": ["exported-fleet", "exported-fleet"], + }, + "Some fleets are listed for removal more than once", + id="remove-duplicate-fleet", + ), + pytest.param( + { + "name": "test-export", + "add_importer_projects": ["NotImporterProject"], + "remove_importer_projects": ["NoTiMpOrTeRpRoJeCt"], # case-insensitive + }, + "Projects {'notimporterproject'} are listed for both addition and removal. Cannot add and remove at the same time", + id="add-remove-same-project", + ), + pytest.param( + { + "name": "test-export", + "add_exported_fleets": ["not-exported-fleet"], + "remove_exported_fleets": ["not-exported-fleet"], + }, + "Fleets {'not-exported-fleet'} are listed for both addition and removal. Cannot add and remove at the same time", + id="add-remove-same-fleet", + ), + ], + ) + async def test_rejects_invalid_update( + self, session: AsyncSession, client: AsyncClient, body: dict, error: str + ): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, name="ExporterProject", owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.ADMIN + ) + importer_project = await create_project( + session=session, name="ImporterProject", owner=user + ) + await add_project_member( + session=session, project=importer_project, user=user, project_role=ProjectRole.ADMIN + ) + exported_fleet = await create_fleet( + session=session, + project=project, + name="exported-fleet", + spec=get_fleet_spec(get_ssh_fleet_configuration()), + ) + await create_export( + session=session, + exporter_project=project, + importer_projects=[importer_project], + exported_fleets=[exported_fleet], + name="test-export", + ) + await create_fleet(session=session, project=project, name="cloud-fleet") + await create_fleet( + session=session, + project=project, + name="not-exported-fleet", + spec=get_fleet_spec(get_ssh_fleet_configuration()), + ) + not_importer_project = await create_project( + session=session, name="NotImporterProject", owner=user + ) + await add_project_member( + session=session, + project=not_importer_project, + user=user, + project_role=ProjectRole.ADMIN, + ) + not_permitted_project = await create_project( + session=session, name="NotPermittedProject", owner=user + ) + await add_project_member( + session=session, + project=not_permitted_project, + user=user, + project_role=ProjectRole.USER, + ) + + response = await client.post( + f"/api/project/{project.name}/exports/list", headers=get_auth_headers(user.token) + ) + assert response.status_code == 200 + canonical_exports = response.json() + + response = await client.post( + f"/api/project/{project.name}/exports/update", + headers=get_auth_headers(user.token), + json=body, + ) + assert response.status_code == 400 + assert error in response.json()["detail"][0]["msg"] + + response = await client.post( + f"/api/project/{project.name}/exports/list", + headers=get_auth_headers(user.token), + ) + assert response.status_code == 200 + assert response.json() == canonical_exports + + +class TestDeleteExport: + async def test_returns_403_if_not_authenticated(self, client: AsyncClient): + response = await client.post( + "/api/project/TestProject/exports/delete", + json={"name": "test-export"}, + ) + assert response.status_code in [401, 403] + + async def test_returns_403_if_not_admin(self, session: AsyncSession, client: AsyncClient): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.USER + ) + response = await client.post( + f"/api/project/{project.name}/exports/delete", + headers=get_auth_headers(user.token), + json={"name": "test-export"}, + ) + assert response.status_code == 403 + + async def test_deletes_export(self, session: AsyncSession, client: AsyncClient): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.ADMIN + ) + + other_project = await create_project(session=session, name="OtherProject", owner=user) + fleet = await create_fleet( + session=session, + project=project, + name="fleet1", + spec=get_fleet_spec(get_ssh_fleet_configuration()), + ) + await create_export( + session=session, + exporter_project=project, + importer_projects=[other_project], + exported_fleets=[fleet], + name="test-export", + ) + + response = await client.post( + f"/api/project/{project.name}/exports/delete", + headers=get_auth_headers(user.token), + json={"name": "test-export"}, + ) + assert response.status_code == 200 + + res = await session.execute(select(ExportModel).where(ExportModel.name == "test-export")) + assert res.scalar() is None + + async def test_returns_400_for_nonexistent_export( + self, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.ADMIN + ) + response = await client.post( + f"/api/project/{project.name}/exports/delete", + headers=get_auth_headers(user.token), + json={"name": "nonexistent-export"}, + ) + assert response.status_code == 400 + assert response.json()["detail"][0]["code"] == "resource_not_exists" + + +class TestListExports: + async def test_returns_403_if_not_authenticated(self, client: AsyncClient): + response = await client.post( + "/api/project/TestProject/exports/list", + ) + assert response.status_code in [401, 403] + + async def test_returns_403_if_not_member(self, session: AsyncSession, client: AsyncClient): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + response = await client.post( + f"/api/project/{project.name}/exports/list", + headers=get_auth_headers(user.token), + ) + assert response.status_code == 403 + + @pytest.mark.parametrize( + "global_role, project_role", + [ + (GlobalRole.ADMIN, None), + (GlobalRole.USER, ProjectRole.USER), + ], + ) + async def test_lists_exports( + self, + session: AsyncSession, + client: AsyncClient, + global_role: GlobalRole, + project_role: Optional[ProjectRole], + ): + user = await create_user(session=session, global_role=global_role) + project = await create_project(session=session, owner=user) + if project_role: + await add_project_member( + session=session, project=project, user=user, project_role=project_role + ) + + other_project = await create_project(session=session, name="OtherProject", owner=user) + fleet1 = await create_fleet( + session=session, + project=project, + name="fleet1", + spec=get_fleet_spec(get_ssh_fleet_configuration()), + ) + fleet2 = await create_fleet( + session=session, + project=project, + name="fleet2", + spec=get_fleet_spec(get_ssh_fleet_configuration()), + ) + for name, fleet in (("export1", fleet1), ("export2", fleet2)): + await create_export( + session=session, + exporter_project=project, + importer_projects=[other_project], + exported_fleets=[fleet], + name=name, + ) + + response = await client.post( + f"/api/project/{project.name}/exports/list", + headers=get_auth_headers(user.token), + ) + assert response.status_code == 200 + exports = response.json() + assert len(exports) == 2 + exports.sort(key=lambda e: e["name"]) + + assert exports[0]["name"] == "export1" + assert len(exports[0]["imports"]) == 1 + assert exports[0]["imports"][0]["project_name"] == "OtherProject" + assert len(exports[0]["exported_fleets"]) == 1 + assert exports[0]["exported_fleets"][0]["name"] == "fleet1" + + assert exports[1]["name"] == "export2" + assert len(exports[1]["imports"]) == 1 + assert exports[1]["imports"][0]["project_name"] == "OtherProject" + assert len(exports[1]["exported_fleets"]) == 1 + assert exports[1]["exported_fleets"][0]["name"] == "fleet2" + + @pytest.mark.parametrize( + "global_role, project_role", + [ + (GlobalRole.ADMIN, None), + (GlobalRole.USER, ProjectRole.USER), + ], + ) + async def test_returns_empty_list_when_no_exports( + self, + session: AsyncSession, + client: AsyncClient, + global_role: GlobalRole, + project_role: Optional[ProjectRole], + ): + user = await create_user(session=session, global_role=global_role) + project = await create_project(session=session, owner=user) + if project_role: + await add_project_member( + session=session, project=project, user=user, project_role=project_role + ) + + response = await client.post( + f"/api/project/{project.name}/exports/list", + headers=get_auth_headers(user.token), + ) + assert response.status_code == 200 + assert response.json() == [] + + async def test_not_includes_deleted_entities(self, session: AsyncSession, client: AsyncClient): + user = await create_user(session=session, global_role=GlobalRole.USER) + project = await create_project(session=session, owner=user) + await add_project_member( + session=session, project=project, user=user, project_role=ProjectRole.ADMIN + ) + + importer_project = await create_project( + session=session, name="ImporterProject", owner=user + ) + deleted_importer_project = await create_project( + session=session, name="DeletedImporterProject", owner=user, deleted=True + ) + fleet = await create_fleet( + session=session, + project=project, + name="fleet", + spec=get_fleet_spec(get_ssh_fleet_configuration()), + ) + deleted_fleet = await create_fleet( + session=session, + project=project, + name="deleted-fleet", + spec=get_fleet_spec(get_ssh_fleet_configuration()), + deleted=True, + ) + await create_export( + session=session, + exporter_project=project, + importer_projects=[importer_project, deleted_importer_project], + exported_fleets=[fleet, deleted_fleet], + name="test-export", + ) + + response = await client.post( + f"/api/project/{project.name}/exports/list", + headers=get_auth_headers(user.token), + ) + assert response.status_code == 200 + exports = response.json() + assert len(exports) == 1 + assert exports[0]["name"] == "test-export" + assert len(exports[0]["imports"]) == 1 + assert exports[0]["imports"][0]["project_name"] == "ImporterProject" + assert len(exports[0]["exported_fleets"]) == 1 + assert exports[0]["exported_fleets"][0]["name"] == "fleet" From 94979946b3e26b9a42bc74bc28dd0ef5cb74297f Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Tue, 10 Mar 2026 14:07:22 +0100 Subject: [PATCH 2/3] Exports CLI --- src/dstack/_internal/cli/commands/export.py | 159 ++++++++++++++++++ src/dstack/_internal/cli/main.py | 2 + .../_internal/cli/services/completion.py | 5 + src/dstack/api/server/__init__.py | 6 + src/dstack/api/server/_exports.py | 57 +++++++ 5 files changed, 229 insertions(+) create mode 100644 src/dstack/_internal/cli/commands/export.py create mode 100644 src/dstack/api/server/_exports.py diff --git a/src/dstack/_internal/cli/commands/export.py b/src/dstack/_internal/cli/commands/export.py new file mode 100644 index 000000000..1d3a4566f --- /dev/null +++ b/src/dstack/_internal/cli/commands/export.py @@ -0,0 +1,159 @@ +import argparse +from typing import Any, Union + +from rich.table import Table + +from dstack._internal.cli.commands import APIBaseCommand +from dstack._internal.cli.services.completion import ExportNameCompleter +from dstack._internal.cli.utils.common import add_row_from_dict, confirm_ask, console +from dstack._internal.core.models.exports import Export + + +class ExportCommand(APIBaseCommand): + NAME = "export" + DESCRIPTION = "Manage exports" + + def _register(self): + super()._register() + self._parser.set_defaults(subfunc=self._list) + subparsers = self._parser.add_subparsers(dest="action") + + list_parser = subparsers.add_parser( + "list", help="List exports", formatter_class=self._parser.formatter_class + ) + list_parser.set_defaults(subfunc=self._list) + + create_parser = subparsers.add_parser( + "create", help="Create an export", formatter_class=self._parser.formatter_class + ) + create_parser.add_argument( + "name", + help="The name of the export", + ) + create_parser.add_argument( + "--importer", + action="append", + dest="importers", + help="Importer project name (can be specified multiple times)", + default=[], + ) + create_parser.add_argument( + "--fleet", + action="append", + dest="fleets", + help="Fleet name to export (can be specified multiple times)", + default=[], + ) + create_parser.set_defaults(subfunc=self._create) + + update_parser = subparsers.add_parser( + "update", help="Update an export", formatter_class=self._parser.formatter_class + ) + update_parser.add_argument( + "name", + help="The name of the export", + ).completer = ExportNameCompleter() # type: ignore[attr-defined] + update_parser.add_argument( + "--add-importer", + action="append", + dest="add_importers", + help="Importer project name to add (can be specified multiple times)", + default=[], + ) + update_parser.add_argument( + "--remove-importer", + action="append", + dest="remove_importers", + help="Importer project name to remove (can be specified multiple times)", + default=[], + ) + update_parser.add_argument( + "--add-fleet", + action="append", + dest="add_fleets", + help="Fleet name to add (can be specified multiple times)", + default=[], + ) + update_parser.add_argument( + "--remove-fleet", + action="append", + dest="remove_fleets", + help="Fleet name to remove (can be specified multiple times)", + default=[], + ) + update_parser.set_defaults(subfunc=self._update) + + delete_parser = subparsers.add_parser( + "delete", help="Delete an export", formatter_class=self._parser.formatter_class + ) + delete_parser.add_argument( + "name", + help="The name of the export", + ).completer = ExportNameCompleter() # type: ignore[attr-defined] + delete_parser.add_argument( + "-y", "--yes", help="Don't ask for confirmation", action="store_true" + ) + delete_parser.set_defaults(subfunc=self._delete) + + def _command(self, args: argparse.Namespace): + super()._command(args) + args.subfunc(args) + + def _list(self, args: argparse.Namespace): + exports = self.api.client.exports.list(self.api.project) + print_exports_table(exports) + + def _create(self, args: argparse.Namespace): + with console.status("Creating export..."): + export = self.api.client.exports.create( + project_name=self.api.project, + name=args.name, + importer_projects=args.importers, + exported_fleets=args.fleets, + ) + print_exports_table([export]) + + def _update(self, args: argparse.Namespace): + with console.status("Updating export..."): + export = self.api.client.exports.update( + project_name=self.api.project, + name=args.name, + add_importer_projects=args.add_importers, + remove_importer_projects=args.remove_importers, + add_exported_fleets=args.add_fleets, + remove_exported_fleets=args.remove_fleets, + ) + print_exports_table([export]) + + def _delete(self, args: argparse.Namespace): + if not args.yes and not confirm_ask(f"Delete the export [code]{args.name}[/]?"): + console.print("\nExiting...") + return + + with console.status("Deleting export..."): + self.api.client.exports.delete(project_name=self.api.project, name=args.name) + + console.print(f"Export [code]{args.name}[/] deleted") + + +def print_exports_table(exports: list[Export]): + table = Table(box=None) + table.add_column("NAME", no_wrap=True) + table.add_column("FLEETS") + table.add_column("IMPORTERS") + + for export in exports: + fleets = ( + ", ".join([f.name for f in export.exported_fleets]) if export.exported_fleets else "-" + ) + importers = ", ".join([i.project_name for i in export.imports]) if export.imports else "-" + + row: dict[Union[str, int], Any] = { + "NAME": export.name, + "FLEETS": fleets, + "IMPORTERS": importers, + } + add_row_from_dict(table, row) + + console.print(table) + console.print() diff --git a/src/dstack/_internal/cli/main.py b/src/dstack/_internal/cli/main.py index a5f678a98..be1f2605d 100644 --- a/src/dstack/_internal/cli/main.py +++ b/src/dstack/_internal/cli/main.py @@ -9,6 +9,7 @@ from dstack._internal.cli.commands.completion import CompletionCommand from dstack._internal.cli.commands.delete import DeleteCommand from dstack._internal.cli.commands.event import EventCommand +from dstack._internal.cli.commands.export import ExportCommand from dstack._internal.cli.commands.fleet import FleetCommand from dstack._internal.cli.commands.gateway import GatewayCommand from dstack._internal.cli.commands.init import InitCommand @@ -66,6 +67,7 @@ def main(): AttachCommand.register(subparsers) DeleteCommand.register(subparsers) EventCommand.register(subparsers) + ExportCommand.register(subparsers) FleetCommand.register(subparsers) GatewayCommand.register(subparsers) InitCommand.register(subparsers) diff --git a/src/dstack/_internal/cli/services/completion.py b/src/dstack/_internal/cli/services/completion.py index 39ddbc1f3..aee036835 100644 --- a/src/dstack/_internal/cli/services/completion.py +++ b/src/dstack/_internal/cli/services/completion.py @@ -80,6 +80,11 @@ def fetch_resource_names(self, api: Client) -> Iterable[str]: return [r.name for r in api.client.secrets.list(api.project)] +class ExportNameCompleter(BaseAPINameCompleter): + def fetch_resource_names(self, api: Client) -> Iterable[str]: + return [r.name for r in api.client.exports.list(api.project)] + + class ProjectNameCompleter(BaseCompleter): """ Completer for local project names. diff --git a/src/dstack/api/server/__init__.py b/src/dstack/api/server/__init__.py index 5d6ea0860..caa1e3419 100644 --- a/src/dstack/api/server/__init__.py +++ b/src/dstack/api/server/__init__.py @@ -17,6 +17,7 @@ from dstack.api.server._auth import AuthAPIClient from dstack.api.server._backends import BackendsAPIClient from dstack.api.server._events import EventsAPIClient +from dstack.api.server._exports import ExportsAPIClient from dstack.api.server._files import FilesAPIClient from dstack.api.server._fleets import FleetsAPIClient from dstack.api.server._gateways import GatewaysAPIClient @@ -50,6 +51,7 @@ class APIClient: logs: operations with logs gateways: operations with gateways volumes: operations with volumes + exports: operations with exports files: operations with files """ @@ -126,6 +128,10 @@ def gateways(self) -> GatewaysAPIClient: def volumes(self) -> VolumesAPIClient: return VolumesAPIClient(self._request, self._logger) + @property + def exports(self) -> ExportsAPIClient: + return ExportsAPIClient(self._request, self._logger) + @property def files(self) -> FilesAPIClient: return FilesAPIClient(self._request, self._logger) diff --git a/src/dstack/api/server/_exports.py b/src/dstack/api/server/_exports.py new file mode 100644 index 000000000..419a4179b --- /dev/null +++ b/src/dstack/api/server/_exports.py @@ -0,0 +1,57 @@ +from typing import List + +from pydantic import parse_obj_as + +from dstack._internal.core.models.exports import Export +from dstack._internal.server.schemas.exports import ( + CreateExportRequest, + DeleteExportRequest, + UpdateExportRequest, +) +from dstack.api.server._group import APIClientGroup + + +class ExportsAPIClient(APIClientGroup): + def list(self, project_name: str) -> List[Export]: + resp = self._request(f"/api/project/{project_name}/exports/list") + return parse_obj_as(List[Export.__response__], resp.json()) + + def create( + self, + project_name: str, + name: str, + *, + importer_projects: List[str] = [], + exported_fleets: List[str] = [], + ) -> Export: + body = CreateExportRequest( + name=name, + importer_projects=importer_projects, + exported_fleets=exported_fleets, + ) + resp = self._request(f"/api/project/{project_name}/exports/create", body=body.json()) + return parse_obj_as(Export.__response__, resp.json()) + + def update( + self, + project_name: str, + name: str, + *, + add_importer_projects: List[str] = [], + remove_importer_projects: List[str] = [], + add_exported_fleets: List[str] = [], + remove_exported_fleets: List[str] = [], + ) -> Export: + body = UpdateExportRequest( + name=name, + add_importer_projects=add_importer_projects, + remove_importer_projects=remove_importer_projects, + add_exported_fleets=add_exported_fleets, + remove_exported_fleets=remove_exported_fleets, + ) + resp = self._request(f"/api/project/{project_name}/exports/update", body=body.json()) + return parse_obj_as(Export.__response__, resp.json()) + + def delete(self, project_name: str, name: str) -> None: + body = DeleteExportRequest(name=name) + self._request(f"/api/project/{project_name}/exports/delete", body=body.json()) From 738d01c21f1622010942d6d60c9023382468be34 Mon Sep 17 00:00:00 2001 From: Jvst Me Date: Tue, 10 Mar 2026 20:10:18 +0100 Subject: [PATCH 3/3] Fix flaky test --- src/tests/_internal/server/routers/test_exports.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/tests/_internal/server/routers/test_exports.py b/src/tests/_internal/server/routers/test_exports.py index a6bed34c4..dd41419fd 100644 --- a/src/tests/_internal/server/routers/test_exports.py +++ b/src/tests/_internal/server/routers/test_exports.py @@ -410,8 +410,13 @@ async def test_updates_export( f"/api/project/{project.name}/exports/list", headers=get_auth_headers(user.token) ) assert response.status_code == 200 - assert len(response.json()) == 1 - assert response.json()[0] == export_response + export_list = response.json() + assert len(export_list) == 1 + export_response["imports"].sort(key=lambda i: i["project_name"]) + export_list[0]["imports"].sort(key=lambda i: i["project_name"]) + export_response["exported_fleets"].sort(key=lambda f: f["name"]) + export_list[0]["exported_fleets"].sort(key=lambda f: f["name"]) + assert export_list[0] == export_response async def test_can_add_same_entities_as_existing_deleted_ones( self, session: AsyncSession, client: AsyncClient