Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ markers = [
]
env = [
"DSTACK_CLI_RICH_FORCE_TERMINAL=0",
"DSTACK_SSHPROXY_API_TOKEN=test-token",
]
filterwarnings = [
# testcontainers modules use deprecated decorators – nothing we can do:
Expand All @@ -142,6 +143,7 @@ dev = [
"pytest-httpbin>=2.1.0",
"pytest-socket>=0.7.0",
"pytest-env>=1.1.0",
"pytest-unordered>=0.7.0",
"httpbin>=0.10.2", # indirect to make compatible with Werkzeug 3
"requests-mock>=1.12.1",
"openai>=1.68.2",
Expand Down
2 changes: 2 additions & 0 deletions src/dstack/_internal/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
runs,
secrets,
server,
sshproxy,
templates,
users,
volumes,
Expand Down Expand Up @@ -253,6 +254,7 @@ def register_routes(app: FastAPI, ui: bool = True):
app.include_router(files.router)
app.include_router(events.root_router)
app.include_router(templates.router)
app.include_router(sshproxy.router)

@app.exception_handler(ForbiddenError)
async def forbidden_error_handler(request: Request, exc: ForbiddenError):
Expand Down
39 changes: 39 additions & 0 deletions src/dstack/_internal/server/routers/sshproxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import os
from typing import Annotated

from fastapi import APIRouter, Depends
from sqlalchemy.ext.asyncio import AsyncSession

from dstack._internal.core.errors import ResourceNotExistsError
from dstack._internal.server.db import get_session
from dstack._internal.server.schemas.sshproxy import GetUpstreamRequest, GetUpstreamResponse
from dstack._internal.server.security.permissions import AlwaysForbidden, ServiceAccount
from dstack._internal.server.services.sshproxy import get_upstream_response
from dstack._internal.server.utils.routers import (
CustomORJSONResponse,
get_base_api_additional_responses,
)

# Resolve the auth dependency once at import time: if the SSH proxy API
# token is set (and non-empty — an empty string is falsy and falls through
# to the else branch), require it as a bearer token on every route;
# otherwise reject all requests to this router.
if _token := os.getenv("DSTACK_SSHPROXY_API_TOKEN"):
    _auth = ServiceAccount(_token)
else:
    _auth = AlwaysForbidden()


# Router for the SSH proxy endpoints. Every route is guarded by the
# service-account bearer-token dependency configured above.
router = APIRouter(
    prefix="/api/sshproxy",
    tags=["sshproxy"],
    responses=get_base_api_additional_responses(),
    dependencies=[Depends(_auth)],
)


@router.post("/get_upstream", response_model=GetUpstreamResponse)
async def get_upstream(
    body: GetUpstreamRequest,
    session: Annotated[AsyncSession, Depends(get_session)],
):
    """Resolve an upstream id to its SSH host chain and authorized keys.

    Raises:
        ResourceNotExistsError: if the id does not resolve to an upstream.
    """
    upstream = await get_upstream_response(session=session, upstream_id=body.id)
    if upstream is None:
        raise ResourceNotExistsError()
    return CustomORJSONResponse(upstream)
27 changes: 27 additions & 0 deletions src/dstack/_internal/server/schemas/sshproxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from typing import Annotated

from pydantic import Field

from dstack._internal.core.models.common import CoreModel


class GetUpstreamRequest(CoreModel):
    """Request body for resolving an upstream by its id."""

    # The format of id is intentionally not limited to UUID to allow further extensions
    id: str


class UpstreamHost(CoreModel):
    """A single SSH host in the chain leading to the upstream."""

    host: Annotated[str, Field(description="The hostname or IP address")]
    port: Annotated[int, Field(description="The SSH port")]
    user: Annotated[str, Field(description="The user to log in")]
    private_key: Annotated[str, Field(description="The private key in OpenSSH file format")]


class GetUpstreamResponse(CoreModel):
    """Connection details for an upstream: the SSH host chain and client keys."""

    hosts: Annotated[
        list[UpstreamHost],
        Field(description="The chain of SSH hosts, the jump host(s) first, the target host last"),
    ]
    authorized_keys: Annotated[
        list[str], Field(description="The list of authorized public keys in OpenSSH file format")
    ]
27 changes: 23 additions & 4 deletions src/dstack/_internal/server/security/permissions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from secrets import compare_digest
from typing import Annotated, Optional, Tuple
from uuid import UUID

Expand Down Expand Up @@ -219,9 +220,23 @@ async def __call__(
raise error_forbidden()


class OptionalServiceAccount:
class ServiceAccount:
    """Auth dependency that accepts a single pre-shared bearer token."""

    def __init__(self, token: str) -> None:
        # Stored as bytes so it can be fed to compare_digest in __call__.
        self._token = token.encode()

    async def __call__(
        self, token: Annotated[HTTPAuthorizationCredentials, Security(HTTPBearer())]
    ) -> None:
        # secrets.compare_digest performs a constant-time comparison,
        # avoiding timing side channels on the token check.
        if not compare_digest(token.credentials.encode(), self._token):
            raise error_invalid_token()


class OptionalServiceAccount(ServiceAccount):
_token: Optional[bytes] = None

def __init__(self, token: Optional[str]) -> None:
self._token = token
if token is not None:
super().__init__(token)

async def __call__(
self,
Expand All @@ -233,8 +248,12 @@ async def __call__(
return
if token is None:
raise error_forbidden()
if token.credentials != self._token:
raise error_invalid_token()
await super().__call__(token)


class AlwaysForbidden:
    """Auth dependency that rejects every request (e.g. when a feature's token is not configured)."""

    async def __call__(self) -> None:
        raise error_forbidden()


async def get_project_member(
Expand Down
4 changes: 4 additions & 0 deletions src/dstack/_internal/server/services/jobs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,10 @@ def get_job_runtime_data(job_model: JobModel) -> Optional[JobRuntimeData]:
return JobRuntimeData.__response__.parse_raw(job_model.job_runtime_data)


def get_job_spec(job_model: JobModel) -> JobSpec:
    """Parse the JobSpec stored as raw JSON in `job_model.job_spec_data`."""
    return JobSpec.__response__.parse_raw(job_model.job_spec_data)

def delay_job_instance_termination(job_model: JobModel):
job_model.remove_at = common.get_current_datetime() + timedelta(seconds=15)

Expand Down
6 changes: 5 additions & 1 deletion src/dstack/_internal/server/services/runs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,10 @@ def switch_run_status(
events.emit(session, msg, actor=actor, targets=[events.Target.from_model(run_model)])


def get_run_spec(run_model: RunModel) -> RunSpec:
    """Parse the RunSpec stored as raw JSON in `run_model.run_spec`."""
    return RunSpec.__response__.parse_raw(run_model.run_spec)

async def list_user_runs(
session: AsyncSession,
user: UserModel,
Expand Down Expand Up @@ -743,7 +747,7 @@ def run_model_to_run(
include_sensitive=include_sensitive,
)

run_spec = RunSpec.__response__.parse_raw(run_model.run_spec)
run_spec = get_run_spec(run_model)

latest_job_submission = None
if len(jobs) > 0 and len(jobs[0].job_submissions) > 0:
Expand Down
106 changes: 66 additions & 40 deletions src/dstack/_internal/server/services/ssh.py
Original file line number Diff line number Diff line change
@@ -1,67 +1,93 @@
from collections.abc import Iterable
from typing import Optional

import dstack._internal.server.services.jobs as jobs_services
from dstack._internal.core.consts import DSTACK_RUNNER_SSH_PORT
from dstack._internal.core.models.backends.base import BackendType
from dstack._internal.core.models.instances import SSHConnectionParams
from dstack._internal.core.models.runs import JobProvisioningData
from dstack._internal.core.services.ssh.tunnel import SSH_DEFAULT_OPTIONS, SocketPair, SSHTunnel
from dstack._internal.server.models import JobModel
from dstack._internal.server.services.instances import get_instance_remote_connection_info
from dstack._internal.server.services.jobs import get_job_provisioning_data, get_job_runtime_data
from dstack._internal.utils.common import get_or_error
from dstack._internal.utils.path import FileContent


def container_ssh_tunnel(
job: JobModel,
forwarded_sockets: Iterable[SocketPair] = (),
options: dict[str, str] = SSH_DEFAULT_OPTIONS,
) -> SSHTunnel:
def get_container_ssh_credentials(job: JobModel) -> list[tuple[SSHConnectionParams, FileContent]]:
"""
Build SSHTunnel for connecting to the container running the specified job.
Returns the information needed to connect to the SSH server inside the job container.

The user of the target host (container) is set to:
* VM-based backends and SSH instances: "root"
* container-based backends: `JobProvisioningData.username`, which, as of 2026-03-10,
is always "root" on all supported backends (Runpod, Vast.ai, Kubernetes)

Args:
job: `JobModel` with `instance` and `instance.project` fields loaded.

Returns:
A list of hosts credentials as (host's `SSHConnectionParams`, private key's `FileContent`)
pairs ordered from the first proxy jump (if any) to the target host (container).
"""
jpd: JobProvisioningData = JobProvisioningData.__response__.parse_raw(
job.job_provisioning_data
)
hosts: list[tuple[SSHConnectionParams, FileContent]] = []

instance = get_or_error(job.instance)
project_key = FileContent(instance.project.ssh_private_key)

rci = get_instance_remote_connection_info(instance)
if rci is not None and (head_proxy := rci.ssh_proxy) is not None:
head_key = FileContent(get_or_error(get_or_error(rci.ssh_proxy_keys)[0].private))
hosts.append((head_proxy, head_key))

jpd = get_job_provisioning_data(job)
assert jpd is not None
assert jpd.hostname is not None
assert jpd.ssh_port is not None
if not jpd.dockerized:
ssh_destination = f"{jpd.username}@{jpd.hostname}"
ssh_port = jpd.ssh_port
ssh_proxy = jpd.ssh_proxy
else:
ssh_destination = "root@localhost"

if jpd.dockerized:
if jpd.backend != BackendType.LOCAL:
instance_proxy = SSHConnectionParams(
hostname=jpd.hostname,
username=jpd.username,
port=jpd.ssh_port,
)
hosts.append((instance_proxy, project_key))
ssh_port = DSTACK_RUNNER_SSH_PORT
job_submission = jobs_services.job_model_to_job_submission(job)
jrd = job_submission.job_runtime_data
jrd = get_job_runtime_data(job)
if jrd is not None and jrd.ports is not None:
ssh_port = jrd.ports.get(ssh_port, ssh_port)
ssh_proxy = SSHConnectionParams(
target_host = SSHConnectionParams(
hostname="localhost",
username="root",
port=ssh_port,
)
hosts.append((target_host, project_key))
else:
if jpd.ssh_proxy is not None:
hosts.append((jpd.ssh_proxy, project_key))
target_host = SSHConnectionParams(
hostname=jpd.hostname,
username=jpd.username,
port=jpd.ssh_port,
)
if jpd.backend == BackendType.LOCAL:
ssh_proxy = None
ssh_head_proxy: Optional[SSHConnectionParams] = None
ssh_head_proxy_private_key: Optional[str] = None
instance = get_or_error(job.instance)
rci = get_instance_remote_connection_info(instance)
if rci is not None and rci.ssh_proxy is not None:
ssh_head_proxy = rci.ssh_proxy
ssh_head_proxy_private_key = get_or_error(rci.ssh_proxy_keys)[0].private
ssh_proxies = []
if ssh_head_proxy is not None:
ssh_head_proxy_private_key = get_or_error(ssh_head_proxy_private_key)
ssh_proxies.append((ssh_head_proxy, FileContent(ssh_head_proxy_private_key)))
if ssh_proxy is not None:
ssh_proxies.append((ssh_proxy, None))
hosts.append((target_host, project_key))

return hosts


def container_ssh_tunnel(
job: JobModel,
forwarded_sockets: Iterable[SocketPair] = (),
options: dict[str, str] = SSH_DEFAULT_OPTIONS,
) -> SSHTunnel:
"""
Build SSHTunnel for connecting to the container running the specified job.
"""
hosts = get_container_ssh_credentials(job)
target, identity = hosts[-1]
return SSHTunnel(
destination=ssh_destination,
port=ssh_port,
ssh_proxies=ssh_proxies,
identity=FileContent(instance.project.ssh_private_key),
destination=f"{target.username}@{target.hostname}",
port=target.port,
ssh_proxies=hosts[:-1],
identity=identity,
forwarded_sockets=forwarded_sockets,
options=options,
)
86 changes: 86 additions & 0 deletions src/dstack/_internal/server/services/sshproxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from typing import Optional
from uuid import UUID

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload

from dstack._internal.core.models.runs import JobStatus
from dstack._internal.server.models import (
InstanceModel,
JobModel,
ProjectModel,
RunModel,
UserModel,
)
from dstack._internal.server.schemas.sshproxy import GetUpstreamResponse, UpstreamHost
from dstack._internal.server.services.jobs import get_job_runtime_data, get_job_spec
from dstack._internal.server.services.runs import get_run_spec
from dstack._internal.server.services.ssh import get_container_ssh_credentials


async def get_upstream_response(
    session: AsyncSession,
    upstream_id: str,
) -> Optional[GetUpstreamResponse]:
    """Resolve an upstream id to the SSH host chain and authorized keys.

    Returns None when the id is not a valid UUID or does not match a
    running job.
    """
    # The API schema keeps upstream_id a plain string to allow other id
    # formats later; currently it is just a JobModel.id.
    try:
        job_id = UUID(upstream_id)
    except ValueError:
        return None

    result = await session.execute(
        select(JobModel)
        .where(
            JobModel.id == job_id,
            JobModel.status == JobStatus.RUNNING,
        )
        .options(
            (
                joinedload(JobModel.instance, innerjoin=True)
                .load_only(InstanceModel.remote_connection_info)
                .joinedload(InstanceModel.project, innerjoin=True)
                .load_only(ProjectModel.ssh_private_key)
            ),
            (
                joinedload(JobModel.run, innerjoin=True)
                .load_only(RunModel.run_spec)
                .joinedload(RunModel.user, innerjoin=True)
                .load_only(UserModel.ssh_public_key)
            ),
        )
    )
    job = result.scalar_one_or_none()
    if job is None:
        return None

    hosts: list[UpstreamHost] = [
        UpstreamHost(
            host=params.hostname,
            port=params.port,
            user=params.username,
            private_key=key.content,
        )
        for params, key in get_container_ssh_credentials(job)
    ]

    # Prefer the username from the job runtime data, then the one from the
    # job spec; if neither is set, keep the target host's default user.
    runtime_data = get_job_runtime_data(job)
    username = runtime_data.username if runtime_data is not None else None
    if username is None:
        spec_user = get_job_spec(job).user
        if spec_user is not None:
            username = spec_user.username
    if username is not None:
        hosts[-1].user = username

    # A set deduplicates the keys: the run spec key and the run owner's
    # public key may be the same.
    authorized_keys: set[str] = set()
    run_spec_key = get_run_spec(job.run).ssh_key_pub
    if run_spec_key is not None:
        authorized_keys.add(run_spec_key)
    user_key = job.run.user.ssh_public_key
    if user_key is not None:
        authorized_keys.add(user_key)

    return GetUpstreamResponse(
        hosts=hosts,
        authorized_keys=list(authorized_keys),
    )
Loading