From 8d7a870fe59df4fdaa51e876a54a7c48257e5705 Mon Sep 17 00:00:00 2001 From: Dmitry Meyer Date: Tue, 10 Mar 2026 16:42:06 +0000 Subject: [PATCH] Add API for SSH proxy Part-of: https://github.com/dstackai/dstack/issues/3644 --- pyproject.toml | 2 + src/dstack/_internal/server/app.py | 2 + .../_internal/server/routers/sshproxy.py | 39 +++ .../_internal/server/schemas/sshproxy.py | 27 ++ .../_internal/server/security/permissions.py | 27 +- .../server/services/jobs/__init__.py | 4 + .../server/services/runs/__init__.py | 6 +- src/dstack/_internal/server/services/ssh.py | 106 +++--- .../_internal/server/services/sshproxy.py | 86 +++++ src/dstack/_internal/server/testing/common.py | 14 +- .../server/routers/test_prometheus.py | 4 +- .../_internal/server/routers/test_sshproxy.py | 189 +++++++++++ .../_internal/server/services/test_ssh.py | 310 ++++++++++++++++++ 13 files changed, 766 insertions(+), 50 deletions(-) create mode 100644 src/dstack/_internal/server/routers/sshproxy.py create mode 100644 src/dstack/_internal/server/schemas/sshproxy.py create mode 100644 src/dstack/_internal/server/services/sshproxy.py create mode 100644 src/tests/_internal/server/routers/test_sshproxy.py create mode 100644 src/tests/_internal/server/services/test_ssh.py diff --git a/pyproject.toml b/pyproject.toml index 8c53b2e166..f3542a7440 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -126,6 +126,7 @@ markers = [ ] env = [ "DSTACK_CLI_RICH_FORCE_TERMINAL=0", + "DSTACK_SSHPROXY_API_TOKEN=test-token", ] filterwarnings = [ # testcontainers modules use deprecated decorators – nothing we can do: @@ -142,6 +143,7 @@ dev = [ "pytest-httpbin>=2.1.0", "pytest-socket>=0.7.0", "pytest-env>=1.1.0", + "pytest-unordered>=0.7.0", "httpbin>=0.10.2", # indirect to make compatible with Werkzeug 3 "requests-mock>=1.12.1", "openai>=1.68.2", diff --git a/src/dstack/_internal/server/app.py b/src/dstack/_internal/server/app.py index 56ea7f0386..28135b0b4e 100644 --- a/src/dstack/_internal/server/app.py +++ 
b/src/dstack/_internal/server/app.py @@ -44,6 +44,7 @@ runs, secrets, server, + sshproxy, templates, users, volumes, @@ -253,6 +254,7 @@ def register_routes(app: FastAPI, ui: bool = True): app.include_router(files.router) app.include_router(events.root_router) app.include_router(templates.router) + app.include_router(sshproxy.router) @app.exception_handler(ForbiddenError) async def forbidden_error_handler(request: Request, exc: ForbiddenError): diff --git a/src/dstack/_internal/server/routers/sshproxy.py b/src/dstack/_internal/server/routers/sshproxy.py new file mode 100644 index 0000000000..3edc927e96 --- /dev/null +++ b/src/dstack/_internal/server/routers/sshproxy.py @@ -0,0 +1,39 @@ +import os +from typing import Annotated + +from fastapi import APIRouter, Depends +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.errors import ResourceNotExistsError +from dstack._internal.server.db import get_session +from dstack._internal.server.schemas.sshproxy import GetUpstreamRequest, GetUpstreamResponse +from dstack._internal.server.security.permissions import AlwaysForbidden, ServiceAccount +from dstack._internal.server.services.sshproxy import get_upstream_response +from dstack._internal.server.utils.routers import ( + CustomORJSONResponse, + get_base_api_additional_responses, +) + +if _token := os.getenv("DSTACK_SSHPROXY_API_TOKEN"): + _auth = ServiceAccount(_token) +else: + _auth = AlwaysForbidden() + + +router = APIRouter( + prefix="/api/sshproxy", + tags=["sshproxy"], + responses=get_base_api_additional_responses(), + dependencies=[Depends(_auth)], +) + + +@router.post("/get_upstream", response_model=GetUpstreamResponse) +async def get_upstream( + body: GetUpstreamRequest, + session: Annotated[AsyncSession, Depends(get_session)], +): + response = await get_upstream_response(session=session, upstream_id=body.id) + if response is None: + raise ResourceNotExistsError() + return CustomORJSONResponse(response) diff --git 
a/src/dstack/_internal/server/schemas/sshproxy.py b/src/dstack/_internal/server/schemas/sshproxy.py new file mode 100644 index 0000000000..10c9297d88 --- /dev/null +++ b/src/dstack/_internal/server/schemas/sshproxy.py @@ -0,0 +1,27 @@ +from typing import Annotated + +from pydantic import Field + +from dstack._internal.core.models.common import CoreModel + + +class GetUpstreamRequest(CoreModel): + # The format of id is intentionally not limited to UUID to allow further extensions + id: str + + +class UpstreamHost(CoreModel): + host: Annotated[str, Field(description="The hostname or IP address")] + port: Annotated[int, Field(description="The SSH port")] + user: Annotated[str, Field(description="The user to log in")] + private_key: Annotated[str, Field(description="The private key in OpenSSH file format")] + + +class GetUpstreamResponse(CoreModel): + hosts: Annotated[ + list[UpstreamHost], + Field(description="The chain of SSH hosts, the jump host(s) first, the target host last"), + ] + authorized_keys: Annotated[ + list[str], Field(description="The list of authorized public keys in OpenSSH file format") + ] diff --git a/src/dstack/_internal/server/security/permissions.py b/src/dstack/_internal/server/security/permissions.py index 107e526d30..a343152e6e 100644 --- a/src/dstack/_internal/server/security/permissions.py +++ b/src/dstack/_internal/server/security/permissions.py @@ -1,3 +1,4 @@ +from secrets import compare_digest from typing import Annotated, Optional, Tuple from uuid import UUID @@ -219,9 +220,23 @@ async def __call__( raise error_forbidden() -class OptionalServiceAccount: +class ServiceAccount: + def __init__(self, token: str) -> None: + self._token = token.encode() + + async def __call__( + self, token: Annotated[HTTPAuthorizationCredentials, Security(HTTPBearer())] + ) -> None: + if not compare_digest(token.credentials.encode(), self._token): + raise error_invalid_token() + + +class OptionalServiceAccount(ServiceAccount): + _token: Optional[bytes] = 
None + def __init__(self, token: Optional[str]) -> None: - self._token = token + if token is not None: + super().__init__(token) async def __call__( self, @@ -233,8 +248,12 @@ async def __call__( return if token is None: raise error_forbidden() - if token.credentials != self._token: - raise error_invalid_token() + await super().__call__(token) + + +class AlwaysForbidden: + async def __call__(self) -> None: + raise error_forbidden() async def get_project_member( diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py index f718a80ce6..62254d7659 100644 --- a/src/dstack/_internal/server/services/jobs/__init__.py +++ b/src/dstack/_internal/server/services/jobs/__init__.py @@ -267,6 +267,10 @@ def get_job_runtime_data(job_model: JobModel) -> Optional[JobRuntimeData]: return JobRuntimeData.__response__.parse_raw(job_model.job_runtime_data) +def get_job_spec(job_model: JobModel) -> JobSpec: + return JobSpec.__response__.parse_raw(job_model.job_spec_data) + + def delay_job_instance_termination(job_model: JobModel): job_model.remove_at = common.get_current_datetime() + timedelta(seconds=15) diff --git a/src/dstack/_internal/server/services/runs/__init__.py b/src/dstack/_internal/server/services/runs/__init__.py index f8aa3f288b..8bb4f2ae7c 100644 --- a/src/dstack/_internal/server/services/runs/__init__.py +++ b/src/dstack/_internal/server/services/runs/__init__.py @@ -112,6 +112,10 @@ def switch_run_status( events.emit(session, msg, actor=actor, targets=[events.Target.from_model(run_model)]) +def get_run_spec(run_model: RunModel) -> RunSpec: + return RunSpec.__response__.parse_raw(run_model.run_spec) + + async def list_user_runs( session: AsyncSession, user: UserModel, @@ -743,7 +747,7 @@ def run_model_to_run( include_sensitive=include_sensitive, ) - run_spec = RunSpec.__response__.parse_raw(run_model.run_spec) + run_spec = get_run_spec(run_model) latest_job_submission = None if len(jobs) > 0 and 
len(jobs[0].job_submissions) > 0: diff --git a/src/dstack/_internal/server/services/ssh.py b/src/dstack/_internal/server/services/ssh.py index 0fa7c189e2..2c0970f680 100644 --- a/src/dstack/_internal/server/services/ssh.py +++ b/src/dstack/_internal/server/services/ssh.py @@ -1,67 +1,93 @@ from collections.abc import Iterable -from typing import Optional -import dstack._internal.server.services.jobs as jobs_services from dstack._internal.core.consts import DSTACK_RUNNER_SSH_PORT from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.instances import SSHConnectionParams -from dstack._internal.core.models.runs import JobProvisioningData from dstack._internal.core.services.ssh.tunnel import SSH_DEFAULT_OPTIONS, SocketPair, SSHTunnel from dstack._internal.server.models import JobModel from dstack._internal.server.services.instances import get_instance_remote_connection_info +from dstack._internal.server.services.jobs import get_job_provisioning_data, get_job_runtime_data from dstack._internal.utils.common import get_or_error from dstack._internal.utils.path import FileContent -def container_ssh_tunnel( - job: JobModel, - forwarded_sockets: Iterable[SocketPair] = (), - options: dict[str, str] = SSH_DEFAULT_OPTIONS, -) -> SSHTunnel: +def get_container_ssh_credentials(job: JobModel) -> list[tuple[SSHConnectionParams, FileContent]]: """ - Build SSHTunnel for connecting to the container running the specified job. + Returns the information needed to connect to the SSH server inside the job container. + + The user of the target host (container) is set to: + * VM-based backends and SSH instances: "root" + * container-based backends: `JobProvisioningData.username`, which, as of 2026-03-10, + is always "root" on all supported backends (Runpod, Vast.ai, Kubernetes) + + Args: + job: `JobModel` with `instance` and `instance.project` fields loaded.
+ + Returns: + A list of hosts credentials as (host's `SSHConnectionParams`, private key's `FileContent`) + pairs ordered from the first proxy jump (if any) to the target host (container). """ - jpd: JobProvisioningData = JobProvisioningData.__response__.parse_raw( - job.job_provisioning_data - ) + hosts: list[tuple[SSHConnectionParams, FileContent]] = [] + + instance = get_or_error(job.instance) + project_key = FileContent(instance.project.ssh_private_key) + + rci = get_instance_remote_connection_info(instance) + if rci is not None and (head_proxy := rci.ssh_proxy) is not None: + head_key = FileContent(get_or_error(get_or_error(rci.ssh_proxy_keys)[0].private)) + hosts.append((head_proxy, head_key)) + + jpd = get_job_provisioning_data(job) + assert jpd is not None assert jpd.hostname is not None assert jpd.ssh_port is not None - if not jpd.dockerized: - ssh_destination = f"{jpd.username}@{jpd.hostname}" - ssh_port = jpd.ssh_port - ssh_proxy = jpd.ssh_proxy - else: - ssh_destination = "root@localhost" + + if jpd.dockerized: + if jpd.backend != BackendType.LOCAL: + instance_proxy = SSHConnectionParams( + hostname=jpd.hostname, + username=jpd.username, + port=jpd.ssh_port, + ) + hosts.append((instance_proxy, project_key)) ssh_port = DSTACK_RUNNER_SSH_PORT - job_submission = jobs_services.job_model_to_job_submission(job) - jrd = job_submission.job_runtime_data + jrd = get_job_runtime_data(job) if jrd is not None and jrd.ports is not None: ssh_port = jrd.ports.get(ssh_port, ssh_port) - ssh_proxy = SSHConnectionParams( + target_host = SSHConnectionParams( + hostname="localhost", + username="root", + port=ssh_port, + ) + hosts.append((target_host, project_key)) + else: + if jpd.ssh_proxy is not None: + hosts.append((jpd.ssh_proxy, project_key)) + target_host = SSHConnectionParams( hostname=jpd.hostname, username=jpd.username, port=jpd.ssh_port, ) - if jpd.backend == BackendType.LOCAL: - ssh_proxy = None - ssh_head_proxy: Optional[SSHConnectionParams] = None - 
ssh_head_proxy_private_key: Optional[str] = None - instance = get_or_error(job.instance) - rci = get_instance_remote_connection_info(instance) - if rci is not None and rci.ssh_proxy is not None: - ssh_head_proxy = rci.ssh_proxy - ssh_head_proxy_private_key = get_or_error(rci.ssh_proxy_keys)[0].private - ssh_proxies = [] - if ssh_head_proxy is not None: - ssh_head_proxy_private_key = get_or_error(ssh_head_proxy_private_key) - ssh_proxies.append((ssh_head_proxy, FileContent(ssh_head_proxy_private_key))) - if ssh_proxy is not None: - ssh_proxies.append((ssh_proxy, None)) + hosts.append((target_host, project_key)) + + return hosts + + +def container_ssh_tunnel( + job: JobModel, + forwarded_sockets: Iterable[SocketPair] = (), + options: dict[str, str] = SSH_DEFAULT_OPTIONS, +) -> SSHTunnel: + """ + Build SSHTunnel for connecting to the container running the specified job. + """ + hosts = get_container_ssh_credentials(job) + target, identity = hosts[-1] return SSHTunnel( - destination=ssh_destination, - port=ssh_port, - ssh_proxies=ssh_proxies, - identity=FileContent(instance.project.ssh_private_key), + destination=f"{target.username}@{target.hostname}", + port=target.port, + ssh_proxies=hosts[:-1], + identity=identity, forwarded_sockets=forwarded_sockets, options=options, ) diff --git a/src/dstack/_internal/server/services/sshproxy.py b/src/dstack/_internal/server/services/sshproxy.py new file mode 100644 index 0000000000..724e68e912 --- /dev/null +++ b/src/dstack/_internal/server/services/sshproxy.py @@ -0,0 +1,86 @@ +from typing import Optional +from uuid import UUID + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import joinedload + +from dstack._internal.core.models.runs import JobStatus +from dstack._internal.server.models import ( + InstanceModel, + JobModel, + ProjectModel, + RunModel, + UserModel, +) +from dstack._internal.server.schemas.sshproxy import GetUpstreamResponse, UpstreamHost +from 
dstack._internal.server.services.jobs import get_job_runtime_data, get_job_spec +from dstack._internal.server.services.runs import get_run_spec +from dstack._internal.server.services.ssh import get_container_ssh_credentials + + +async def get_upstream_response( + session: AsyncSession, + upstream_id: str, +) -> Optional[GetUpstreamResponse]: + # The format of upstream_id is intentionally not limited to UUID in the API schema to allow + # further extensions. Currently, it's just a JobModel.id + try: + job_id = UUID(upstream_id) + except ValueError: + return None + + res = await session.execute( + select(JobModel) + .where( + JobModel.id == job_id, + JobModel.status == JobStatus.RUNNING, + ) + .options( + ( + joinedload(JobModel.instance, innerjoin=True) + .load_only(InstanceModel.remote_connection_info) + .joinedload(InstanceModel.project, innerjoin=True) + .load_only(ProjectModel.ssh_private_key) + ), + ( + joinedload(JobModel.run, innerjoin=True) + .load_only(RunModel.run_spec) + .joinedload(RunModel.user, innerjoin=True) + .load_only(UserModel.ssh_public_key) + ), + ) + ) + job = res.scalar_one_or_none() + if job is None: + return None + + hosts: list[UpstreamHost] = [] + for ssh_params, private_key in get_container_ssh_credentials(job): + hosts.append( + UpstreamHost( + host=ssh_params.hostname, + port=ssh_params.port, + user=ssh_params.username, + private_key=private_key.content, + ) + ) + + username: Optional[str] = None + if (jrd := get_job_runtime_data(job)) is not None: + username = jrd.username + if username is None and (job_spec_user := get_job_spec(job).user) is not None: + username = job_spec_user.username + if username is not None: + hosts[-1].user = username + + authorized_keys: set[str] = set() + if (run_spec_key := get_run_spec(job.run).ssh_key_pub) is not None: + authorized_keys.add(run_spec_key) + if (user_key := job.run.user.ssh_public_key) is not None: + authorized_keys.add(user_key) + + return GetUpstreamResponse( + hosts=hosts, + 
authorized_keys=list(authorized_keys), + ) diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 89072e1555..366c944948 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -55,6 +55,7 @@ InstanceType, RemoteConnectionInfo, Resources, + SSHConnectionParams, SSHKey, ) from dstack._internal.core.models.placement import ( @@ -424,6 +425,9 @@ def get_job_provisioning_data( internal_ip: Optional[str] = "127.0.0.4", price: float = 10.5, instance_type: Optional[InstanceType] = None, + username: str = "ubuntu", + ssh_port: int = 22, + ssh_proxy: Optional[SSHConnectionParams] = None, ) -> JobProvisioningData: gpus = [ Gpu( @@ -447,11 +451,11 @@ def get_job_provisioning_data( internal_ip=internal_ip, region=region, price=price, - username="ubuntu", - ssh_port=22, + username=username, + ssh_port=ssh_port, dockerized=dockerized, backend_data=None, - ssh_proxy=None, + ssh_proxy=ssh_proxy, ) @@ -865,6 +869,8 @@ def get_remote_connection_info( port: int = 22, ssh_user: str = "ubuntu", ssh_keys: Optional[list[SSHKey]] = None, + ssh_proxy: Optional[SSHConnectionParams] = None, + ssh_proxy_keys: Optional[list[SSHKey]] = None, env: Optional[Union[Env, dict]] = None, ): if ssh_keys is None: @@ -878,6 +884,8 @@ def get_remote_connection_info( port=port, ssh_user=ssh_user, ssh_keys=ssh_keys, + ssh_proxy=ssh_proxy, + ssh_proxy_keys=ssh_proxy_keys, env=env, ) diff --git a/src/tests/_internal/server/routers/test_prometheus.py b/src/tests/_internal/server/routers/test_prometheus.py index ab9549965d..f87f43a80f 100644 --- a/src/tests/_internal/server/routers/test_prometheus.py +++ b/src/tests/_internal/server/routers/test_prometheus.py @@ -369,7 +369,7 @@ async def test_returns_404_if_not_enabled( async def test_returns_403_if_not_authenticated( self, monkeypatch: pytest.MonkeyPatch, client: AsyncClient, token: Optional[str] ): - 
monkeypatch.setattr("dstack._internal.server.routers.prometheus._auth._token", "secret") + monkeypatch.setattr("dstack._internal.server.routers.prometheus._auth._token", b"secret") if token is not None: headers = get_auth_headers(token) else: @@ -380,7 +380,7 @@ async def test_returns_403_if_not_authenticated( async def test_returns_200_if_token_is_valid( self, monkeypatch: pytest.MonkeyPatch, client: AsyncClient ): - monkeypatch.setattr("dstack._internal.server.routers.prometheus._auth._token", "secret") + monkeypatch.setattr("dstack._internal.server.routers.prometheus._auth._token", b"secret") response = await client.get("/metrics", headers=get_auth_headers("secret")) assert response.status_code == 200 diff --git a/src/tests/_internal/server/routers/test_sshproxy.py b/src/tests/_internal/server/routers/test_sshproxy.py new file mode 100644 index 0000000000..2b546d7d66 --- /dev/null +++ b/src/tests/_internal/server/routers/test_sshproxy.py @@ -0,0 +1,189 @@ +import os +from typing import Optional + +import pytest +from httpx import AsyncClient +from pytest_unordered import unordered +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.errors import ServerClientErrorCode +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.configurations import DevEnvironmentConfiguration +from dstack._internal.core.models.runs import ( + JobStatus, +) +from dstack._internal.server.testing.common import ( + create_instance, + create_job, + create_project, + create_repo, + create_run, + create_user, + get_auth_headers, + get_job_provisioning_data, + get_job_runtime_data, + get_run_spec, +) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +@pytest.mark.usefixtures("image_config_mock", "test_db") +class TestGetUpstream: + @pytest.fixture + def token(self) -> str: + token_var = "DSTACK_SSHPROXY_API_TOKEN" + token = os.getenv(token_var) + assert token is not 
None, f"{token_var} must be set via pytest-env" + return token + + async def test_returns_40x_if_no_api_token_provided(self, client: AsyncClient): + response = await client.post("/api/sshproxy/get_upstream") + + assert response.status_code in [401, 403] + + async def test_returns_40x_if_api_token_is_not_valid(self, client: AsyncClient): + response = await client.post( + "/api/sshproxy/get_upstream", headers=get_auth_headers("invalid-token") + ) + + assert response.status_code in [401, 403] + + async def test_returns_resource_not_exists_if_upstream_id_is_not_uuid( + self, client: AsyncClient, token: str + ): + response = await client.post( + "/api/sshproxy/get_upstream", + headers=get_auth_headers(token), + json={"id": "some-string"}, + ) + + assert response.json()["detail"][0]["code"] == ServerClientErrorCode.RESOURCE_NOT_EXISTS + + async def test_returns_resource_not_exists_if_job_is_not_running( + self, + session: AsyncSession, + client: AsyncClient, + token: str, + ): + project = await create_project(session=session) + instance = await create_instance(session=session, project=project) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run = await create_run(session=session, project=project, user=user, repo=repo) + job = await create_job( + session=session, + run=run, + instance=instance, + status=JobStatus.TERMINATING, + ) + + response = await client.post( + "/api/sshproxy/get_upstream", + headers=get_auth_headers(token), + json={"id": str(job.id)}, + ) + + assert response.json()["detail"][0]["code"] == ServerClientErrorCode.RESOURCE_NOT_EXISTS + + async def test_response( + self, + session: AsyncSession, + client: AsyncClient, + token: str, + ): + project = await create_project(session=session, ssh_private_key="project-key") + instance = await create_instance( + session=session, project=project, backend=BackendType.RUNPOD + ) + user = await create_user(session=session, ssh_public_key="user-key") + 
repo = await create_repo(session=session, project_id=project.id) + run_spec = get_run_spec(repo_id=repo.name, ssh_key_pub="run-spec-key") + run = await create_run( + session=session, project=project, user=user, repo=repo, run_spec=run_spec + ) + jpd = get_job_provisioning_data( + dockerized=False, + backend=BackendType.RUNPOD, + hostname="100.100.100.100", + username="root", + ssh_port=32768, + ssh_proxy=None, + ) + jrd = get_job_runtime_data(username="test-user") + job = await create_job( + session=session, + run=run, + instance=instance, + job_provisioning_data=jpd, + job_runtime_data=jrd, + status=JobStatus.RUNNING, + ) + + response = await client.post( + "/api/sshproxy/get_upstream", + headers=get_auth_headers(token), + json={"id": str(job.id)}, + ) + + assert response.json() == { + "hosts": [ + { + "host": "100.100.100.100", + "port": 32768, + "private_key": "project-key", + "user": "test-user", + }, + ], + "authorized_keys": unordered( + [ + "user-key", + "run-spec-key", + ] + ), + } + + @pytest.mark.parametrize( + ["jrd_user", "conf_user", "expected_user"], + [ + pytest.param("jrd", "conf", "jrd", id="from-runner"), + pytest.param(None, "conf", "conf", id="from-configuration"), + pytest.param(None, None, "root", id="default"), + ], + ) + async def test_username_fallbacks( + self, + session: AsyncSession, + client: AsyncClient, + token: str, + jrd_user: Optional[str], + conf_user: Optional[str], + expected_user: str, + ): + project = await create_project(session=session, ssh_private_key="project-key") + instance = await create_instance(session=session, project=project, backend=BackendType.AWS) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + configuration = DevEnvironmentConfiguration(ide="vscode", user=conf_user) + run_spec = get_run_spec(repo_id=repo.name, configuration=configuration) + run = await create_run( + session=session, project=project, user=user, repo=repo, run_spec=run_spec + ) + jpd 
= get_job_provisioning_data(dockerized=True, backend=BackendType.AWS, username="root") + jrd = get_job_runtime_data(username=jrd_user) + job = await create_job( + session=session, + run=run, + instance=instance, + job_provisioning_data=jpd, + job_runtime_data=jrd, + status=JobStatus.RUNNING, + ) + + response = await client.post( + "/api/sshproxy/get_upstream", + headers=get_auth_headers(token), + json={"id": str(job.id)}, + ) + + assert response.json()["hosts"][-1]["user"] == expected_user diff --git a/src/tests/_internal/server/services/test_ssh.py b/src/tests/_internal/server/services/test_ssh.py new file mode 100644 index 0000000000..d05b4709f8 --- /dev/null +++ b/src/tests/_internal/server/services/test_ssh.py @@ -0,0 +1,310 @@ +from typing import Optional + +import pytest +import pytest_asyncio +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.consts import DSTACK_RUNNER_SSH_PORT +from dstack._internal.core.models.backends.base import BackendType +from dstack._internal.core.models.common import NetworkMode +from dstack._internal.core.models.instances import SSHConnectionParams, SSHKey +from dstack._internal.core.models.runs import ( + JobRuntimeData, +) +from dstack._internal.server.models import ProjectModel, RunModel +from dstack._internal.server.services.ssh import get_container_ssh_credentials +from dstack._internal.server.testing.common import ( + create_instance, + create_job, + create_project, + create_repo, + create_run, + create_user, + get_job_provisioning_data, + get_job_runtime_data, + get_remote_connection_info, +) +from dstack._internal.utils.path import FileContent + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +@pytest.mark.usefixtures("test_db", "image_config_mock") +class TestGetContainerSSHCredentials: + project_key = "project-key" + + @pytest_asyncio.fixture + async def project(self, session: AsyncSession) -> ProjectModel: + return await 
create_project(session=session, ssh_private_key=self.project_key) + + @pytest_asyncio.fixture + async def run(self, session: AsyncSession, project: ProjectModel) -> RunModel: + user = await create_user(session=session) + repo = await create_repo( + session=session, + project_id=project.id, + ) + return await create_run(session=session, project=project, user=user, repo=repo) + + @pytest.mark.parametrize( + ["jrd", "expected_port"], + [ + pytest.param(None, DSTACK_RUNNER_SSH_PORT, id="no-jrd"), + pytest.param( + get_job_runtime_data(network_mode=NetworkMode.HOST, ports={}), + DSTACK_RUNNER_SSH_PORT, + id="host", + ), + pytest.param( + get_job_runtime_data( + network_mode=NetworkMode.HOST, ports={DSTACK_RUNNER_SSH_PORT: 32772} + ), + 32772, + id="bridge", + ), + ], + ) + async def test_vm_based_backend( + self, + session: AsyncSession, + project: ProjectModel, + run: RunModel, + jrd: Optional[JobRuntimeData], + expected_port: int, + ): + instance = await create_instance(session=session, project=project, backend=BackendType.AWS) + jpd = get_job_provisioning_data( + backend=BackendType.AWS, + dockerized=True, + hostname="80.80.80.80", + username="ubuntu", + ssh_port=22, + ssh_proxy=None, + ) + job = await create_job( + session=session, + run=run, + instance=instance, + job_provisioning_data=jpd, + job_runtime_data=jrd, + ) + + hosts = get_container_ssh_credentials(job) + + assert hosts == [ + ( + SSHConnectionParams( + hostname="80.80.80.80", + username="ubuntu", + port=22, + ), + FileContent(self.project_key), + ), + ( + SSHConnectionParams( + hostname="localhost", + username="root", + port=expected_port, + ), + FileContent(self.project_key), + ), + ] + + async def test_container_based_backend( + self, + session: AsyncSession, + project: ProjectModel, + run: RunModel, + ): + instance = await create_instance( + session=session, project=project, backend=BackendType.RUNPOD + ) + jpd = get_job_provisioning_data( + backend=BackendType.RUNPOD, + dockerized=False, + 
hostname="100.100.100.100", + username="root", + ssh_port=32768, + ssh_proxy=None, + ) + job = await create_job( + session=session, + run=run, + instance=instance, + job_provisioning_data=jpd, + ) + + hosts = get_container_ssh_credentials(job) + + assert hosts == [ + ( + SSHConnectionParams( + hostname="100.100.100.100", + username="root", + port=32768, + ), + FileContent(self.project_key), + ), + ] + + async def test_container_based_backend_with_proxy( + self, + session: AsyncSession, + project: ProjectModel, + run: RunModel, + ): + instance = await create_instance( + session=session, project=project, backend=BackendType.KUBERNETES + ) + jpd = get_job_provisioning_data( + backend=BackendType.KUBERNETES, + dockerized=False, + hostname="10.105.30.22", + username="root", + ssh_port=DSTACK_RUNNER_SSH_PORT, + ssh_proxy=SSHConnectionParams( + hostname="120.120.120.120", + username="root", + port=30022, + ), + ) + job = await create_job( + session=session, + run=run, + instance=instance, + job_provisioning_data=jpd, + ) + + hosts = get_container_ssh_credentials(job) + + assert hosts == [ + ( + SSHConnectionParams( + hostname="120.120.120.120", + username="root", + port=30022, + ), + FileContent(self.project_key), + ), + ( + SSHConnectionParams( + hostname="10.105.30.22", + username="root", + port=DSTACK_RUNNER_SSH_PORT, + ), + FileContent(self.project_key), + ), + ] + + async def test_ssh_instance_with_head_proxy( + self, + session: AsyncSession, + project: ProjectModel, + run: RunModel, + ): + rci = get_remote_connection_info( + host="192.168.100.50", + port=22222, + ssh_user="ubuntu", + # User-provided key is only used for instance provisioning, then we always use + # the project key, which is added during provisioning + ssh_keys=[SSHKey(public="public", private="instance-key")], + ssh_proxy=SSHConnectionParams( + hostname="140.140.140.140", + username="bastion", + port=22, + ), + ssh_proxy_keys=[SSHKey(public="public", private="head-key")], + ) + instance = await 
create_instance( + session=session, + project=project, + backend=BackendType.REMOTE, + remote_connection_info=rci, + ) + jpd = get_job_provisioning_data( + backend=BackendType.REMOTE, + dockerized=True, + hostname="192.168.100.50", + username="ubuntu", + ssh_port=22222, + # Actually, JobModel.job_provisioning_data.ssh_proxy is set to + # InstanceModel.remote_connection_info.ssh_proxy but not used in the function we test + ssh_proxy=None, + ) + job = await create_job( + session=session, + run=run, + instance=instance, + job_provisioning_data=jpd, + # jrd is tested in vm-based backend tests + job_runtime_data=None, + ) + + hosts = get_container_ssh_credentials(job) + + assert hosts == [ + ( + SSHConnectionParams( + hostname="140.140.140.140", + username="bastion", + port=22, + ), + FileContent("head-key"), + ), + ( + SSHConnectionParams( + hostname="192.168.100.50", + username="ubuntu", + port=22222, + ), + FileContent(self.project_key), + ), + ( + SSHConnectionParams( + hostname="localhost", + username="root", + port=DSTACK_RUNNER_SSH_PORT, + ), + FileContent(self.project_key), + ), + ] + + async def test_local_backend( + self, + session: AsyncSession, + project: ProjectModel, + run: RunModel, + ): + instance = await create_instance( + session=session, project=project, backend=BackendType.LOCAL + ) + jpd = get_job_provisioning_data( + backend=BackendType.LOCAL, + dockerized=True, + hostname="127.0.0.1", + username="root", + ssh_port=DSTACK_RUNNER_SSH_PORT, + ssh_proxy=None, + ) + job = await create_job( + session=session, + run=run, + instance=instance, + job_provisioning_data=jpd, + # jrd is tested in vm-based backend tests + job_runtime_data=None, + ) + + hosts = get_container_ssh_credentials(job) + + assert hosts == [ + ( + SSHConnectionParams( + hostname="localhost", + username="root", + port=DSTACK_RUNNER_SSH_PORT, + ), + FileContent(self.project_key), + ), + ]