From 2e1eb8234cc7b48b8b59afc61f7e7c7396515c81 Mon Sep 17 00:00:00 2001 From: Jason Etcovitch Date: Mon, 11 May 2026 11:33:10 -0400 Subject: [PATCH] Add Python cloud session SDK API Implement cloud session support in the Python SDK, matching the Node SDK behavior from PR #1256. This adds: - copilot.cloud sub-package with CloudSession, MissionControlClient, and all related types - CopilotClient.create_cloud_session() for creating sandbox-backed cloud sessions through Mission Control - CopilotClient.connect_cloud_session() for attaching to existing tasks - Event polling with deduplication and chronological sorting - Steering helpers: send, abort, respond_to_permission, respond_to_ask_user, respond_to_elicitation, respond_to_exit_plan_mode, switch_mode - MissionControlCommandType enum covering user_message, abort, permission_response, ask_user_response, plan_approval_response, elicitation_response, and mode_switch - 13 unit tests covering task creation, repo-less tasks, owner validation, steer API, event sorting/dedup, error handling, typed handlers, and context manager support Uses X-Copilot-Agent-Slug: copilot-developer-sandbox. Requires callers to pass explicit repository or owner (no Git inference). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- python/copilot/__init__.py | 48 +- python/copilot/client.py | 268 +++++++ python/copilot/cloud/__init__.py | 52 ++ python/copilot/cloud/cloud_session.py | 458 ++++++++++++ .../copilot/cloud/mission_control_client.py | 322 +++++++++ python/copilot/cloud/types.py | 309 ++++++++ python/test_cloud_session.py | 670 ++++++++++++++++++ 7 files changed, 2124 insertions(+), 3 deletions(-) create mode 100644 python/copilot/cloud/__init__.py create mode 100644 python/copilot/cloud/cloud_session.py create mode 100644 python/copilot/cloud/mission_control_client.py create mode 100644 python/copilot/cloud/types.py create mode 100644 python/test_cloud_session.py diff --git a/python/copilot/__init__.py b/python/copilot/__init__.py index 1963a2d41..c19e02cd7 100644 --- a/python/copilot/__init__.py +++ b/python/copilot/__init__.py @@ -13,6 +13,28 @@ ModelVisionLimitsOverride, SubprocessConfig, ) +from .cloud import ( + CloudAskUserResponsePayload, + CloudConnectOptions, + CloudElicitationResponsePayload, + CloudModeSwitchPayload, + CloudPermissionResponsePayload, + CloudPlanApprovalResponsePayload, + CloudProgressEvent, + CloudProgressPhase, + CloudRepository, + CloudSession, + CloudSessionError, + CloudSessionEvent, + CloudSessionEventHandler, + CloudSessionFailureReason, + CloudSessionMetadata, + CloudSessionOptions, + MissionControlClient, + MissionControlCommandType, + MissionControlTask, + MissionControlTaskSession, +) from .session import ( AutoModeSwitchHandler, AutoModeSwitchRequest, @@ -45,23 +67,43 @@ __version__ = "0.1.0" __all__ = [ - "CommandContext", "AutoModeSwitchHandler", "AutoModeSwitchRequest", "AutoModeSwitchResponse", + "CloudAskUserResponsePayload", + "CloudConnectOptions", + "CloudElicitationResponsePayload", + "CloudModeSwitchPayload", + "CloudPermissionResponsePayload", + "CloudPlanApprovalResponsePayload", + "CloudProgressEvent", + "CloudProgressPhase", + "CloudRepository", + "CloudSession", + "CloudSessionError", + "CloudSessionEvent", + "CloudSessionEventHandler", + "CloudSessionFailureReason", + "CloudSessionMetadata", + "CloudSessionOptions", + "CommandContext", "CommandDefinition", "CopilotClient", "CopilotSession", "CreateSessionFsHandler", + "ElicitationContext", "ElicitationHandler", "ElicitationParams", - "ElicitationContext", "ElicitationResult", "ExitPlanModeHandler", "ExitPlanModeRequest", "ExitPlanModeResult", "ExternalServerConfig", "InputOptions", + "MissionControlClient", + "MissionControlCommandType", + "MissionControlTask", + "MissionControlTaskSession", "ModelCapabilitiesOverride", "ModelLimitsOverride", "ModelSupportsOverride", @@ -71,10 +113,10 @@ "SessionFsConfig", "SessionFsFileInfo", "SessionFsProvider", - "create_session_fs_adapter", "SessionUiApi", "SessionUiCapabilities", "SubprocessConfig", "convert_mcp_call_tool_result", + "create_session_fs_adapter", "define_tool", ] diff --git a/python/copilot/client.py b/python/copilot/client.py index 848af4b92..fb343bd4c 100644 --- a/python/copilot/client.py +++ b/python/copilot/client.py @@ -27,6 +27,7 @@ import uuid from collections.abc import Awaitable, Callable from dataclasses import KW_ONLY, dataclass, field +from datetime import UTC from pathlib import Path from types import TracebackType from typing import Any, Literal, TypedDict, cast, overload @@ -35,6 +36,15 @@ from ._jsonrpc import JsonRpcClient, JsonRpcError, ProcessExitedError from ._sdk_protocol_version import get_sdk_protocol_version from ._telemetry import get_trace_context, trace_context +from .cloud.cloud_session import CloudSession +from .cloud.mission_control_client import MissionControlClient +from .cloud.types import ( + CloudConnectOptions, + CloudRepository, + CloudSessionMetadata, + CloudSessionOptions, + MissionControlTask, +) from .generated.rpc import ( ClientSessionApiHandlers, ConnectRequest, @@ -90,6 +100,17 @@ def _validate_session_fs_config(config: SessionFsConfig) -> None: raise ValueError("session_fs.conventions must be either 'posix' or 'windows'") +def _strip_trailing_slash(value: str) -> str: + return value.rstrip("/") + + +def _normalize_token(value: str | None) -> str | None: + if value is None: + return None + trimmed = value.strip() + return trimmed if trimmed else None + + class TelemetryConfig(TypedDict, total=False): """Configuration for OpenTelemetry integration with the Copilot CLI.""" @@ -2069,6 +2090,253 @@ async def get_auth_status(self) -> GetAuthStatusResponse: result = await self._client.request("auth.getStatus", {}) return GetAuthStatusResponse.from_dict(result) + # ------------------------------------------------------------------ + # Cloud sessions + # ------------------------------------------------------------------ + + def _create_mission_control_client( + self, + options: CloudSessionOptions | CloudConnectOptions, + ) -> MissionControlClient: + """Build a MissionControlClient from cloud session options and env.""" + cfg = self._config + if isinstance(cfg, SubprocessConfig) and cfg.env is not None: + env: dict[str, str] = cfg.env + else: + env = dict(os.environ) + + copilot_api_base_url = _strip_trailing_slash( + options.get("copilot_api_base_url") + or env.get("COPILOT_API_BASE_URL") + or env.get("COPILOT_API_URL") + or "https://api.githubcopilot.com" + ) + base_url = ( + options.get("mission_control_base_url") + or env.get("COPILOT_MC_BASE_URL") + or f"{copilot_api_base_url}/agents" + ) + + github_token = cfg.github_token if isinstance(cfg, SubprocessConfig) else None + auth_token = ( + _normalize_token(options.get("auth_token")) + or _normalize_token(env.get("COPILOT_MC_ACCESS_TOKEN")) + or _normalize_token(github_token) + ) + + frontend_base_url = ( + options.get("frontend_base_url") + or env.get("COPILOT_MC_FRONTEND_URL") + or "https://github.com" + ) + + return MissionControlClient( + base_url=base_url, + auth_token=auth_token, + integration_id=options.get("integration_id"), + frontend_base_url=frontend_base_url, + ) + + @staticmethod + def _create_cloud_session_metadata( + task: MissionControlTask, + mc_client: MissionControlClient, + repository: CloudRepository | None = None, + owner: str | None = None, + ) -> CloudSessionMetadata: + from datetime import datetime + + return CloudSessionMetadata( + task_id=task.id, + mission_control_session_id=(task.sessions[-1].id if task.sessions else None), + frontend_url=mc_client.get_frontend_url(task.id), + owner=owner, + repository=repository, + created_at=datetime.fromisoformat(task.created_at), + updated_at=datetime.fromisoformat(task.updated_at), + state=task.state, + status=task.status, + ) + + @staticmethod + def _create_fallback_cloud_session_metadata( + task_id: str, + mc_client: MissionControlClient, + repository: CloudRepository | None = None, + owner: str | None = None, + ) -> CloudSessionMetadata: + from datetime import datetime + + now = datetime.now(UTC) + return CloudSessionMetadata( + task_id=task_id, + frontend_url=mc_client.get_frontend_url(task_id), + owner=owner, + repository=repository, + created_at=now, + updated_at=now, + ) + + async def create_cloud_session( + self, + options: CloudSessionOptions | None = None, + ) -> CloudSession: + """Create a sandbox-backed cloud session through Mission Control. + + This does not create a local runtime session. The agent runs inside the + provisioned cloud sandbox; this SDK instance polls Mission Control for + events and sends user actions through the task steer API. + + Args: + options: Cloud session options. Either ``repository`` or ``owner`` + must be provided. If ``repository`` is omitted, ``owner`` is + required for billing/authorization. + + Returns: + A connected :class:`~copilot.cloud.CloudSession` instance. + + Raises: + ValueError: If neither ``repository`` nor ``owner`` is provided. + CloudSessionError: On Mission Control API errors. + """ + from .cloud.mission_control_client import _CreateCloudTaskParams, _CreateCloudTaskRepository + + if options is None: + options = CloudSessionOptions() + + started_at = time.monotonic() + mc_client = self._create_mission_control_client(options) + owner = _normalize_token(options.get("owner")) + repository = options.get("repository") + + if not repository and not owner: + raise ValueError("CloudSessionOptions.owner is required when repository is omitted") + + on_progress = options.get("on_progress") + if on_progress: + from .cloud.types import CloudProgressEvent + + on_progress(CloudProgressEvent(phase="creating_task", elapsed_ms=0)) + on_progress( + CloudProgressEvent( + phase="provisioning_sandbox", + elapsed_ms=(time.monotonic() - started_at) * 1000, + ) + ) + + task_repo = ( + _CreateCloudTaskRepository(owner=repository["owner"], name=repository["name"]) + if repository + else None + ) + task = await mc_client.create_cloud_task( + _CreateCloudTaskParams(owner=owner, repository=task_repo) + ) + + on_cloud_task_created = options.get("on_cloud_task_created") + if on_cloud_task_created: + on_cloud_task_created(task) + + if on_progress: + on_progress( + CloudProgressEvent( + phase="waiting_for_session", + elapsed_ms=(time.monotonic() - started_at) * 1000, + task_id=task.id, + ) + ) + + session = CloudSession( + client=mc_client, + metadata=self._create_cloud_session_metadata(task, mc_client, repository, owner), + poll_interval_ms=options.get("poll_interval_ms"), + initial_event_timeout_ms=options.get("initial_event_timeout_ms"), + initial_event_poll_interval_ms=options.get("initial_event_poll_interval_ms"), + on_event_poll_error=options.get("on_event_poll_error"), + ) + await session.connect() + + if on_progress: + on_progress( + CloudProgressEvent( + phase="connected", + elapsed_ms=(time.monotonic() - started_at) * 1000, + task_id=task.id, + ) + ) + + return session + + async def connect_cloud_session( + self, + task_or_session_id: str, + options: CloudConnectOptions | None = None, + ) -> CloudSession: + """Attach to an existing Mission Control cloud task. + + The identifier is treated as a task ID. If Mission Control can return + task metadata, it populates the session metadata; otherwise the SDK + still attaches by polling task events for the provided identifier. + + Args: + task_or_session_id: The Mission Control task ID to connect to. + options: Connection options. + + Returns: + A connected :class:`~copilot.cloud.CloudSession` instance. + """ + if options is None: + options = CloudConnectOptions() + + started_at = time.monotonic() + mc_client = self._create_mission_control_client(options) + + on_progress = options.get("on_progress") + if on_progress: + from .cloud.types import CloudProgressEvent + + on_progress( + CloudProgressEvent( + phase="waiting_for_session", + elapsed_ms=0, + task_id=task_or_session_id, + ) + ) + + task = await mc_client.get_task(task_or_session_id) + owner = _normalize_token(options.get("owner")) + repository = options.get("repository") + + if task: + metadata = self._create_cloud_session_metadata(task, mc_client, repository, owner) + else: + metadata = self._create_fallback_cloud_session_metadata( + task_or_session_id, mc_client, repository, owner + ) + + session = CloudSession( + client=mc_client, + metadata=metadata, + poll_interval_ms=options.get("poll_interval_ms"), + initial_event_timeout_ms=options.get("initial_event_timeout_ms"), + initial_event_poll_interval_ms=options.get("initial_event_poll_interval_ms"), + on_event_poll_error=options.get("on_event_poll_error"), + ) + await session.connect() + + if on_progress: + from .cloud.types import CloudProgressEvent + + on_progress( + CloudProgressEvent( + phase="connected", + elapsed_ms=(time.monotonic() - started_at) * 1000, + task_id=metadata.task_id, + ) + ) + + return session + async def list_models(self) -> list[ModelInfo]: """ List available models with their metadata. diff --git a/python/copilot/cloud/__init__.py b/python/copilot/cloud/__init__.py new file mode 100644 index 000000000..718707dae --- /dev/null +++ b/python/copilot/cloud/__init__.py @@ -0,0 +1,52 @@ +""" +Cloud session support for the Copilot SDK. + +This sub-package provides the :class:`CloudSession` class for creating and +controlling sandbox-backed cloud sessions through Mission Control, along with +the low-level :class:`MissionControlClient` HTTP client. +""" + +from .cloud_session import CloudSession +from .mission_control_client import CloudSessionError, MissionControlClient +from .types import ( + CloudAskUserResponsePayload, + CloudConnectOptions, + CloudElicitationResponsePayload, + CloudModeSwitchPayload, + CloudPermissionResponsePayload, + CloudPlanApprovalResponsePayload, + CloudProgressEvent, + CloudProgressPhase, + CloudRepository, + CloudSessionEvent, + CloudSessionEventHandler, + CloudSessionFailureReason, + CloudSessionMetadata, + CloudSessionOptions, + MissionControlCommandType, + MissionControlTask, + MissionControlTaskSession, +) + +__all__ = [ + "CloudAskUserResponsePayload", + "CloudConnectOptions", + "CloudElicitationResponsePayload", + "CloudModeSwitchPayload", + "CloudPermissionResponsePayload", + "CloudPlanApprovalResponsePayload", + "CloudProgressEvent", + "CloudProgressPhase", + "CloudRepository", + "CloudSession", + "CloudSessionError", + "CloudSessionEvent", + "CloudSessionEventHandler", + "CloudSessionFailureReason", + "CloudSessionMetadata", + "CloudSessionOptions", + "MissionControlClient", + "MissionControlCommandType", + "MissionControlTask", + "MissionControlTaskSession", +] diff --git a/python/copilot/cloud/cloud_session.py b/python/copilot/cloud/cloud_session.py new file mode 100644 index 000000000..89ad3219c --- /dev/null +++ b/python/copilot/cloud/cloud_session.py @@ -0,0 +1,458 @@ +""" +Cloud session — remote-control client for a sandbox-backed cloud session. + +The :class:`CloudSession` class polls Mission Control for task events and +exposes methods for sending user messages and steering commands. +""" + +from __future__ import annotations + +import asyncio +import json +import logging +from collections.abc import Callable +from typing import Any, overload + +from .mission_control_client import MissionControlClient +from .types import ( + CloudAskUserResponsePayload, + CloudElicitationResponsePayload, + CloudModeSwitchPayload, + CloudPermissionResponsePayload, + CloudPlanApprovalResponsePayload, + CloudSessionEvent, + CloudSessionEventHandler, + CloudSessionMetadata, + ElicitationResult, + ExitPlanModeResult, + MissionControlCommandType, +) + +logger = logging.getLogger(__name__) + +_DEFAULT_POLL_INTERVAL_S = 5.0 +_DEFAULT_INITIAL_EVENT_TIMEOUT_S = 10.0 +_DEFAULT_INITIAL_EVENT_POLL_INTERVAL_S = 0.5 + + +class CloudSession: + """Remote-control client for a cloud sandbox session. + + After construction, call :meth:`connect` to start receiving events. + + Args: + client: The Mission Control HTTP client. + metadata: Metadata describing the cloud session/task. + poll_interval_ms: Milliseconds between event polls (default: 5000). + initial_event_timeout_ms: Milliseconds to wait for the first event (default: 10000). + initial_event_poll_interval_ms: Milliseconds between polls while waiting for the + first event (default: 500). + on_event_poll_error: Callback invoked when an event poll fails. + """ + + def __init__( + self, + *, + client: MissionControlClient, + metadata: CloudSessionMetadata, + poll_interval_ms: int | None = None, + initial_event_timeout_ms: int | None = None, + initial_event_poll_interval_ms: int | None = None, + on_event_poll_error: Callable[[Exception], None] | None = None, + ) -> None: + self._client = client + self.metadata = metadata + self.session_id = metadata.mission_control_session_id or metadata.task_id + + self._poll_interval_s = ( + poll_interval_ms / 1000 if poll_interval_ms is not None else _DEFAULT_POLL_INTERVAL_S + ) + self._initial_event_timeout_s = ( + initial_event_timeout_ms / 1000 + if initial_event_timeout_ms is not None + else _DEFAULT_INITIAL_EVENT_TIMEOUT_S + ) + self._initial_event_poll_interval_s = ( + initial_event_poll_interval_ms / 1000 + if initial_event_poll_interval_ms is not None + else _DEFAULT_INITIAL_EVENT_POLL_INTERVAL_S + ) + self._on_event_poll_error = on_event_poll_error + + self._event_handlers: set[CloudSessionEventHandler] = set() + self._typed_event_handlers: dict[str, set[CloudSessionEventHandler]] = {} + self._events: list[CloudSessionEvent] = [] + self._seen_event_ids: set[str] = set() + self._seen_event_ids_at_last_timestamp: set[str] = set() + self._last_seen_timestamp: str | None = None + self._poller_task: asyncio.Task[None] | None = None + self._is_polling = False + self._is_disconnected = False + self._remote_steerable = True + + # ------------------------------------------------------------------ + # Connection lifecycle + # ------------------------------------------------------------------ + + async def connect(self) -> None: + """Connect to the cloud session and start event polling.""" + initial_events = await self._wait_for_initial_events() + self._record_events(initial_events) + self._start_event_polling() + + # ------------------------------------------------------------------ + # Event subscription + # ------------------------------------------------------------------ + + @overload + def on(self, handler: CloudSessionEventHandler, /) -> Callable[[], None]: ... + + @overload + def on(self, event_type: str, handler: CloudSessionEventHandler, /) -> Callable[[], None]: ... + + def on( + self, + event_type_or_handler: str | CloudSessionEventHandler, + handler: CloudSessionEventHandler | None = None, + /, + ) -> Callable[[], None]: + """Register an event handler. + + Can be called with a wildcard handler:: + + session.on(lambda event: print(event.type)) + + Or with a specific event type:: + + session.on("session.idle", lambda event: print("idle")) + + Returns a callable that removes the handler when invoked. + """ + if isinstance(event_type_or_handler, str) and handler is not None: + event_type = event_type_or_handler + if event_type not in self._typed_event_handlers: + self._typed_event_handlers[event_type] = set() + self._typed_event_handlers[event_type].add(handler) + + def _unsubscribe() -> None: + handlers = self._typed_event_handlers.get(event_type) + if handlers: + handlers.discard(handler) + + return _unsubscribe + + wildcard_handler = event_type_or_handler + assert callable(wildcard_handler) + self._event_handlers.add(wildcard_handler) + + def _unsubscribe() -> None: + self._event_handlers.discard(wildcard_handler) + + return _unsubscribe + + # ------------------------------------------------------------------ + # Sending messages & steering + # ------------------------------------------------------------------ + + async def send(self, *, prompt: str) -> None: + """Send a user message to the cloud session. + + Args: + prompt: The message text to send. + """ + self._assert_connected() + await self.submit_remote_command(MissionControlCommandType.USER_MESSAGE, prompt) + + async def send_and_wait( + self, + *, + prompt: str, + timeout: float | None = None, + ) -> CloudSessionEvent | None: + """Send a message and wait for the session to reach idle. + + Returns the last ``assistant.message`` event received before idle, + or ``None`` if the session became idle without an assistant message. + + Args: + prompt: The message text to send. + timeout: Maximum seconds to wait (default: 60). + """ + effective_timeout = timeout if timeout is not None else 60.0 + last_assistant_message: CloudSessionEvent | None = None + done = asyncio.Event() + error_holder: list[Exception] = [] + + def _handler(event: CloudSessionEvent) -> None: + nonlocal last_assistant_message + if event.type == "assistant.message": + last_assistant_message = event + elif event.type == "session.idle": + done.set() + elif event.type == "session.error": + msg = event.data.get("message", "Unknown error") if event.data else "Unknown error" + error_holder.append(Exception(msg)) + done.set() + + unsubscribe = self.on(_handler) + try: + await self.send(prompt=prompt) + await asyncio.wait_for(done.wait(), timeout=effective_timeout) + if error_holder: + raise error_holder[0] + return last_assistant_message + except TimeoutError as exc: + raise TimeoutError( + f"Timeout after {effective_timeout}s waiting for session.idle" + ) from exc + finally: + unsubscribe() + + async def abort(self) -> None: + """Abort the current cloud session operation.""" + self._assert_connected() + await self.submit_remote_command(MissionControlCommandType.ABORT) + + async def submit_remote_command( + self, + command_type: MissionControlCommandType, + content: str | None = None, + ) -> None: + """Send a raw steering command to Mission Control. + + Args: + command_type: The type of steering command. + content: Optional payload content. + """ + self._assert_connected() + if not self._remote_steerable: + raise RuntimeError("This session is read-only — remote steering is not enabled") + request: dict[str, Any] = {"type": command_type.value} + if content is not None: + request["content"] = content + await self._client.steer_task(self.metadata.task_id, request) + + # ------------------------------------------------------------------ + # Response helpers + # ------------------------------------------------------------------ + + async def respond_to_permission(self, payload: CloudPermissionResponsePayload) -> None: + """Respond to a permission request.""" + wire = _to_camel_case_dict(payload) + await self.submit_remote_command( + MissionControlCommandType.PERMISSION_RESPONSE, json.dumps(wire) + ) + + async def respond_to_ask_user(self, payload: CloudAskUserResponsePayload) -> None: + """Respond to an ask-user prompt.""" + wire = _to_camel_case_dict(payload) + await self.submit_remote_command( + MissionControlCommandType.ASK_USER_RESPONSE, json.dumps(wire) + ) + + async def respond_to_elicitation(self, payload: CloudElicitationResponsePayload) -> None: + """Respond to an elicitation prompt.""" + wire = _to_camel_case_dict(payload) + await self.submit_remote_command( + MissionControlCommandType.ELICITATION_RESPONSE, json.dumps(wire) + ) + + async def respond_to_exit_plan_mode(self, payload: CloudPlanApprovalResponsePayload) -> None: + """Respond to a plan approval prompt.""" + wire = _to_camel_case_dict(payload) + await self.submit_remote_command( + MissionControlCommandType.PLAN_APPROVAL_RESPONSE, json.dumps(wire) + ) + + async def switch_mode(self, payload: CloudModeSwitchPayload) -> None: + """Switch the cloud session mode.""" + await self.submit_remote_command( + MissionControlCommandType.MODE_SWITCH, json.dumps(dict(payload)) + ) + + async def respond_to_elicitation_result( + self, prompt_id: str, result: ElicitationResult + ) -> None: + """Convenience: respond to an elicitation with a prompt ID and result.""" + payload = CloudElicitationResponsePayload( + prompt_id=prompt_id, + **result, # type: ignore[typeddict-item] + ) + await self.respond_to_elicitation(payload) + + async def respond_to_plan_approval(self, prompt_id: str, result: ExitPlanModeResult) -> None: + """Convenience: respond to a plan approval with a prompt ID and result.""" + payload = CloudPlanApprovalResponsePayload( + prompt_id=prompt_id, + **result, # type: ignore[typeddict-item] + ) + await self.respond_to_exit_plan_mode(payload) + + # ------------------------------------------------------------------ + # Event access + # ------------------------------------------------------------------ + + def get_messages(self) -> list[CloudSessionEvent]: + """Return a copy of all events received so far.""" + return list(self._events) + + # ------------------------------------------------------------------ + # Disconnect + # ------------------------------------------------------------------ + + async def disconnect(self) -> None: + """Disconnect from the cloud session and stop event polling.""" + self._stop_event_polling() + self._event_handlers.clear() + self._typed_event_handlers.clear() + self._is_disconnected = True + + async def destroy(self) -> None: + """Alias for :meth:`disconnect`.""" + await self.disconnect() + + async def __aenter__(self) -> CloudSession: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: Any, + ) -> None: + await self.disconnect() + + # ------------------------------------------------------------------ + # Event polling internals + # ------------------------------------------------------------------ + + def _start_event_polling(self) -> None: + if self._poller_task is not None or self._is_disconnected: + return + + self._poller_task = asyncio.ensure_future(self._poll_loop()) + + def _stop_event_polling(self) -> None: + if self._poller_task is not None: + self._poller_task.cancel() + self._poller_task = None + + async def _poll_loop(self) -> None: + try: + while not self._is_disconnected: + await asyncio.sleep(self._poll_interval_s) + try: + await self._poll_events() + except Exception as exc: + self._report_poll_error(exc) + except asyncio.CancelledError: + pass + + async def _wait_for_initial_events(self) -> list[CloudSessionEvent]: + deadline = asyncio.get_event_loop().time() + self._initial_event_timeout_s + while True: + events = await self._client.list_task_events(self.metadata.task_id) + if events: + return _sort_events_chronologically(events) + if self._initial_event_timeout_s <= 0 or asyncio.get_event_loop().time() >= deadline: + return [] + await asyncio.sleep(self._initial_event_poll_interval_s) + + async def _poll_events(self) -> None: + if self._is_polling or self._is_disconnected: + return + self._is_polling = True + try: + events = await self._client.list_task_events(self.metadata.task_id) + new_events = self._collect_new_events(events) + self._record_events(new_events) + finally: + self._is_polling = False + + def _collect_new_events(self, events: list[CloudSessionEvent]) -> list[CloudSessionEvent]: + new: list[CloudSessionEvent] = [] + for event in events: + if event.id in self._seen_event_ids: + continue + if self._last_seen_timestamp is None: + new.append(event) + continue + order = _compare_strings(event.timestamp, self._last_seen_timestamp) + if order > 0: + new.append(event) + elif order == 0 and event.id not in self._seen_event_ids_at_last_timestamp: + new.append(event) + return _sort_events_chronologically(new) + + def _record_events(self, events: list[CloudSessionEvent]) -> None: + for event in _sort_events_chronologically(events): + if event.id in self._seen_event_ids: + continue + self._seen_event_ids.add(event.id) + self._events.append(event) + self._mark_event_as_seen_at_timestamp(event) + self._update_remote_steerable(event) + self._dispatch_event(event) + + def _mark_event_as_seen_at_timestamp(self, event: CloudSessionEvent) -> None: + if self._last_seen_timestamp != event.timestamp: + self._last_seen_timestamp = event.timestamp + self._seen_event_ids_at_last_timestamp = set() + self._seen_event_ids_at_last_timestamp.add(event.id) + + def _update_remote_steerable(self, event: CloudSessionEvent) -> None: + if event.type == "session.remote_steerable_changed" and event.data: + self._remote_steerable = bool(event.data.get("remoteSteerable", True)) + + def _dispatch_event(self, event: CloudSessionEvent) -> None: + typed_handlers = self._typed_event_handlers.get(event.type) + if typed_handlers: + for handler in list(typed_handlers): + try: + handler(event) + except Exception: + pass + + for handler in list(self._event_handlers): + try: + handler(event) + except Exception: + pass + + def _report_poll_error(self, error: Exception) -> None: + if self._on_event_poll_error: + self._on_event_poll_error(error) + + def _assert_connected(self) -> None: + if self._is_disconnected: + raise RuntimeError("Cloud session is disconnected") + + +# ------------------------------------------------------------------ +# Module-level helpers +# ------------------------------------------------------------------ + + +def _sort_events_chronologically( + events: list[CloudSessionEvent], +) -> list[CloudSessionEvent]: + return sorted(events, key=lambda e: (e.timestamp, e.id)) + + +def _compare_strings(a: str, b: str) -> int: + if a > b: + return 1 + if a < b: + return -1 + return 0 + + +def _snake_to_camel(name: str) -> str: + """Convert a snake_case name to camelCase.""" + parts = name.split("_") + return parts[0] + "".join(p.capitalize() for p in parts[1:]) + + +def _to_camel_case_dict(d: dict[str, Any]) -> dict[str, Any]: + """Convert a dict with snake_case keys to camelCase keys.""" + return {_snake_to_camel(k): v for k, v in d.items()} diff --git a/python/copilot/cloud/mission_control_client.py b/python/copilot/cloud/mission_control_client.py new file mode 100644 index 000000000..f17296335 --- /dev/null +++ b/python/copilot/cloud/mission_control_client.py @@ -0,0 +1,322 @@ +""" +HTTP client for the Mission Control API. + +Provides :class:`MissionControlClient` for creating cloud tasks, polling +task events, and steering tasks through the Mission Control REST API. +""" + +from __future__ import annotations + +import asyncio +import json +import logging +import urllib.error +import urllib.request +from typing import Any + +from .types import ( + CloudSessionEvent, + CloudSessionFailureReason, + MissionControlTask, +) + +logger = logging.getLogger(__name__) + +CLOUD_SANDBOX_AGENT_SLUG = "copilot-developer-sandbox" + +_DEFAULT_REQUEST_TIMEOUT_S = 10 +_DEFAULT_CREATE_CLOUD_TASK_TIMEOUT_S = 10 * 60 + + +class CloudSessionError(Exception): + """Error from a Mission Control API request. + + Attributes: + reason: Categorised failure reason. + status: HTTP status code, if available. + """ + + def __init__( + self, + message: str, + reason: CloudSessionFailureReason, + status: int | None = None, + ) -> None: + super().__init__(message) + self.reason = reason + self.status = status + + +class _CreateCloudTaskRepository: + """Repository reference sent in a create-task request body.""" + + __slots__ = ("owner", "name") + + def __init__(self, owner: str, name: str) -> None: + self.owner = owner + self.name = name + + +class _CreateCloudTaskParams: + """Parameters for creating a cloud task.""" + + __slots__ = ("owner", "repository") + + def __init__( + self, + owner: str | None = None, + repository: _CreateCloudTaskRepository | None = None, + ) -> None: + self.owner = owner + self.repository = repository + + +class MissionControlClient: + """HTTP client for the Mission Control task API. + + Args: + base_url: Base URL for the Mission Control API (e.g. ``https://api.githubcopilot.com/agents``). + auth_token: Bearer token for authentication. + integration_id: Copilot integration identifier. + frontend_base_url: Base URL for task frontend links. + request_timeout_s: Timeout for normal requests in seconds. + create_cloud_task_timeout_s: Timeout for task creation in seconds. + """ + + def __init__( + self, + *, + base_url: str, + auth_token: str | None = None, + integration_id: str | None = None, + frontend_base_url: str, + request_timeout_s: float | None = None, + create_cloud_task_timeout_s: float | None = None, + ) -> None: + self._base_url = base_url.rstrip("/") + self._auth_token = auth_token.strip() if auth_token and auth_token.strip() else None + self._integration_id = integration_id or "copilot-cli" + self._frontend_base_url = frontend_base_url.rstrip("/") + self._request_timeout_s = request_timeout_s or _DEFAULT_REQUEST_TIMEOUT_S + self._create_cloud_task_timeout_s = ( + create_cloud_task_timeout_s or _DEFAULT_CREATE_CLOUD_TASK_TIMEOUT_S + ) + + async def create_cloud_task( + self, params: _CreateCloudTaskParams | None = None + ) -> MissionControlTask: + """Create a new cloud sandbox task.""" + if params is None: + params = _CreateCloudTaskParams() + + body: dict[str, Any] = {} + if params.owner: + body["owner"] = params.owner + if params.repository: + body["repositories"] = [ + {"owner": params.repository.owner, "name": params.repository.name} + ] + + data = await self._request_json( + f"{self._base_url}/tasks", + method="POST", + headers=self._headers({"X-Copilot-Agent-Slug": CLOUD_SANDBOX_AGENT_SLUG}), + body=json.dumps(body), + timeout=self._create_cloud_task_timeout_s, + ) + return MissionControlTask.from_dict(data) + + async def list_task_events(self, task_id: str) -> list[CloudSessionEvent]: + """Poll task events from Mission Control.""" + encoded_id = urllib.request.quote(task_id, safe="") + data = await self._request_json( + f"{self._base_url}/tasks/{encoded_id}/events", + method="GET", + headers=self._headers(), + timeout=self._request_timeout_s, + ) + + events_raw = data.get("events") if isinstance(data, dict) else None + if not isinstance(events_raw, list): + raise CloudSessionError( + f"Unexpected Mission Control events response for task {task_id}", + "server", + ) + + return [CloudSessionEvent.from_dict(e) for e in events_raw if _is_cloud_session_event(e)] + + async def steer_task( + self, + task_id: str, + request: dict[str, Any], + ) -> None: + """Send a steering command to a running task.""" + encoded_id = urllib.request.quote(task_id, safe="") + await self._request_ok( + f"{self._base_url}/tasks/{encoded_id}/steer", + method="POST", + headers=self._headers(), + body=json.dumps(request), + timeout=self._request_timeout_s, + ) + + async def get_task(self, task_id: str) -> MissionControlTask | None: + """Get task metadata. Returns ``None`` if the task is not found.""" + encoded_id = urllib.request.quote(task_id, safe="") + try: + data = await self._request_json( + f"{self._base_url}/tasks/{encoded_id}", + method="GET", + headers=self._headers(), + timeout=self._request_timeout_s, + ) + return MissionControlTask.from_dict(data) + except CloudSessionError as exc: + if exc.status == 404: + return None + raise + + def get_frontend_url(self, task_id: str) -> str: + """Build the frontend URL for a task.""" + encoded_id = urllib.request.quote(task_id, safe="") + return f"{self._frontend_base_url}/copilot/tasks/{encoded_id}" + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _headers(self, extra: dict[str, str] | None = None) -> dict[str, str]: + headers: dict[str, str] = { + "Content-Type": "application/json", + "Copilot-Integration-Id": self._integration_id, + } + if self._auth_token: + headers["Authorization"] = f"Bearer {self._auth_token}" + if extra: + headers.update(extra) + return headers + + async def _request_json( + self, + url: str, + *, + method: str, + headers: dict[str, str], + body: str | None = None, + timeout: float, + ) -> Any: + response_body = await self._request_ok( + url, method=method, headers=headers, body=body, timeout=timeout + ) + if not response_body: + return None + try: + return json.loads(response_body) + except json.JSONDecodeError as exc: + raise CloudSessionError( + f"Mission Control returned invalid JSON: {exc}", + "server", + ) from exc + + async def _request_ok( + self, + url: str, + *, + method: str, + headers: dict[str, str], + body: str | None = None, + timeout: float, + ) -> str: + """Perform an HTTP request and return the response body text. + + Raises :class:`CloudSessionError` on non-2xx responses, timeouts, + and network errors. + """ + try: + return await asyncio.wait_for( + self._do_request(url, method=method, headers=headers, body=body), + timeout=timeout, + ) + except TimeoutError as exc: + raise CloudSessionError( + "Mission Control request timed out", + "timeout", + ) from exc + except CloudSessionError: + raise + except Exception as exc: + raise CloudSessionError( + f"Mission Control request failed: {exc}", + "network", + ) from exc + + @staticmethod + async def _do_request( + url: str, + *, + method: str, + headers: dict[str, str], + body: str | None = None, + ) -> str: + """Execute the HTTP request in a thread to avoid blocking the event loop.""" + loop = asyncio.get_running_loop() + + def _sync_request() -> str: + data = body.encode("utf-8") if body else None + req = urllib.request.Request(url, data=data, headers=headers, method=method) + try: + with urllib.request.urlopen(req) as resp: + return resp.read().decode("utf-8") + except urllib.error.HTTPError as http_err: + error_body = "" + try: + error_body = http_err.read().decode("utf-8") + except Exception: + pass + message = _extract_mission_control_message(error_body) or ( + f"Mission Control request failed with HTTP {http_err.code}" + ) + raise CloudSessionError( + message, + _reason_for_status(http_err.code), + http_err.code, + ) from http_err + + return await loop.run_in_executor(None, _sync_request) + + +# ------------------------------------------------------------------ +# Module-level helpers +# ------------------------------------------------------------------ + + +def _reason_for_status(status: int) -> CloudSessionFailureReason: + if status == 403: + return "policy_blocked" + if status in (400, 422): + return "validation" + return "server" + + +def _extract_mission_control_message(text: str) -> str | None: + if not text: + return None + try: + parsed = json.loads(text) + if isinstance(parsed, dict): + msg = parsed.get("message") + if isinstance(msg, str) and msg: + return msg + except json.JSONDecodeError: + pass + return text + + +def _is_cloud_session_event(value: Any) -> bool: + if not isinstance(value, dict): + return False + return ( + isinstance(value.get("id"), str) + and isinstance(value.get("timestamp"), str) + and isinstance(value.get("type"), str) + ) diff --git a/python/copilot/cloud/types.py b/python/copilot/cloud/types.py new file mode 100644 index 000000000..7edc1d466 --- /dev/null +++ b/python/copilot/cloud/types.py @@ -0,0 +1,309 @@ +""" +Type definitions for the cloud session API. + +These types mirror the Node SDK's cloud session types, adapted to Python +conventions (snake_case, dataclasses, TypedDict, enums). +""" + +from __future__ import annotations + +import enum +from collections.abc import Callable +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any, Literal, NotRequired, TypedDict + +# ============================================================================ +# Repository & Progress +# ============================================================================ + + +class CloudRepository(TypedDict): + """Repository context used when creating a cloud sandbox task.""" + + owner: str + name: str + branch: NotRequired[str] + + +CloudProgressPhase = Literal[ + "creating_task", + "provisioning_sandbox", + "waiting_for_session", + "connected", +] + + +@dataclass +class CloudProgressEvent: + """Progress phases emitted while creating or attaching to a cloud session.""" + + phase: CloudProgressPhase + elapsed_ms: float | None = None + task_id: str | None = None + + +CloudSessionFailureReason = Literal[ + "policy_blocked", + "validation", + "timeout", + "network", + "server", +] + + +# ============================================================================ +# Mission Control Task +# ============================================================================ + + +@dataclass +class MissionControlTaskSession: + """A session within a Mission Control task.""" + + id: str + task_id: str + state: str + created_at: str + updated_at: str + owner_id: int + agent_task_id: str | None = None + name: str | None = None + repo_id: int | None = None + + @staticmethod + def from_dict(data: dict[str, Any]) -> MissionControlTaskSession: + return MissionControlTaskSession( + id=data["id"], + task_id=data["task_id"], + state=data["state"], + created_at=data["created_at"], + updated_at=data["updated_at"], + owner_id=data["owner_id"], + agent_task_id=data.get("agent_task_id"), + name=data.get("name"), + repo_id=data.get("repo_id"), + ) + + +@dataclass +class MissionControlTask: + """Represents a Mission Control task.""" + + id: str + name: str + state: str + status: str + creator_id: int + owner_id: int + session_count: int + created_at: str + updated_at: str + repo_id: int | None = None + sessions: list[MissionControlTaskSession] = field(default_factory=list) + + @staticmethod + def from_dict(data: dict[str, Any]) -> MissionControlTask: + sessions_data = data.get("sessions") or [] + return MissionControlTask( + id=data["id"], + name=data["name"], + state=data["state"], + status=data["status"], + creator_id=data["creator_id"], + owner_id=data["owner_id"], + session_count=data["session_count"], + created_at=data["created_at"], + updated_at=data["updated_at"], + repo_id=data.get("repo_id"), + sessions=[MissionControlTaskSession.from_dict(s) for s in sessions_data], + ) + + +# ============================================================================ +# Cloud Session Metadata +# ============================================================================ + + +@dataclass +class CloudSessionMetadata: + """Metadata about a cloud session, populated from a Mission Control task.""" + + task_id: str + frontend_url: str + created_at: datetime + updated_at: datetime + mission_control_session_id: str | None = None + owner: str | None = None + repository: CloudRepository | None = None + state: str | None = None + status: str | None = None + + +# ============================================================================ +# Cloud Session Events +# ============================================================================ + + +@dataclass +class CloudSessionEvent: + """A single event from a cloud session's event stream. + + Cloud session events include standard session events plus cloud-specific + events like ``session.requested``. + """ + + id: str + timestamp: str + type: str + parent_id: str | None = None + data: dict[str, Any] | None = None + ephemeral: bool | None = None + + @staticmethod + def from_dict(data: dict[str, Any]) -> CloudSessionEvent: + return CloudSessionEvent( + id=data["id"], + timestamp=data["timestamp"], + type=data["type"], + parent_id=data.get("parentId"), + data=data.get("data"), + ephemeral=data.get("ephemeral"), + ) + + +CloudSessionEventHandler = Callable[[CloudSessionEvent], None] +"""Event handler callback type for cloud session events.""" + + +# ============================================================================ +# Mission Control Command Types +# ============================================================================ + + +class MissionControlCommandType(enum.StrEnum): + """Command types for steering a cloud session through Mission Control.""" + + USER_MESSAGE = "user_message" + ASK_USER_RESPONSE = "ask_user_response" + PLAN_APPROVAL_RESPONSE = "plan_approval_response" + PERMISSION_RESPONSE = "permission_response" + ELICITATION_RESPONSE = "elicitation_response" + ABORT = "abort" + MODE_SWITCH = "mode_switch" + + +# ============================================================================ +# Steering Payloads +# ============================================================================ + + +class CloudAskUserResponsePayload(TypedDict): + """Payload for responding to an ask_user prompt.""" + + prompt_id: str + answer: str + was_freeform: bool + dismissed: NotRequired[bool] + + +class CloudPlanApprovalResponsePayload(TypedDict): + """Payload for responding to a plan approval prompt.""" + + prompt_id: str + approved: bool + selected_action: NotRequired[str] + auto_approve_edits: NotRequired[bool] + feedback: NotRequired[str] + + +class CloudPermissionResponsePayload(TypedDict): + """Payload for responding to a permission prompt.""" + + prompt_id: str + approved: bool + scope: Literal["once", "session"] + + +ElicitationFieldValue = str | int | float | bool | list[str] +"""Primitive field value in an elicitation result.""" + + +class CloudElicitationResponsePayload(TypedDict): + """Payload for responding to an elicitation prompt.""" + + prompt_id: str + action: Literal["accept", "decline", "cancel"] + content: NotRequired[dict[str, ElicitationFieldValue]] + + +class CloudModeSwitchPayload(TypedDict): + """Payload for switching the session mode.""" + + mode: Literal["interactive", "plan", "autopilot"] + + +class ElicitationResult(TypedDict): + """Result returned from an elicitation request.""" + + action: Literal["accept", "decline", "cancel"] + content: NotRequired[dict[str, ElicitationFieldValue]] + + +class ExitPlanModeResult(TypedDict): + """Result returned from an exit-plan-mode request.""" + + approved: bool + selected_action: NotRequired[str] + feedback: NotRequired[str] + + +# ============================================================================ +# Options Types +# ============================================================================ + + +class CloudSessionOptions(TypedDict, total=False): + """Options for creating a new cloud session. + + Either ``repository`` or ``owner`` must be provided. If ``repository`` is + omitted, ``owner`` is required for billing/authorization. + """ + + owner: str + """Billing/authorization owner for repo-less cloud sandboxes.""" + + repository: CloudRepository + """Repository context for the cloud sandbox.""" + + mission_control_base_url: str + copilot_api_base_url: str + frontend_base_url: str + auth_token: str + integration_id: str + poll_interval_ms: int + initial_event_timeout_ms: int + initial_event_poll_interval_ms: int + on_progress: Callable[[CloudProgressEvent], None] + on_cloud_task_created: Callable[[MissionControlTask], None] + on_event_poll_error: Callable[[Exception], None] + + +class CloudConnectOptions(TypedDict, total=False): + """Options for connecting to an existing cloud session. + + Same as :class:`CloudSessionOptions` but ``repository`` is optional. + """ + + owner: str + repository: CloudRepository + mission_control_base_url: str + copilot_api_base_url: str + frontend_base_url: str + auth_token: str + integration_id: str + poll_interval_ms: int + initial_event_timeout_ms: int + initial_event_poll_interval_ms: int + on_progress: Callable[[CloudProgressEvent], None] + on_event_poll_error: Callable[[Exception], None] diff --git a/python/test_cloud_session.py b/python/test_cloud_session.py new file mode 100644 index 000000000..860680900 --- /dev/null +++ b/python/test_cloud_session.py @@ -0,0 +1,670 @@ +""" +Cloud Session Unit Tests + +Tests for the cloud session SDK API: Mission Control task creation, +event polling, steering, and error handling. +""" + +from __future__ import annotations + +import asyncio +import json +from datetime import UTC +from typing import Any +from unittest.mock import patch + +import pytest + +from copilot import ( + CloudSession, + CloudSessionError, + CloudSessionEvent, + CopilotClient, + MissionControlCommandType, + MissionControlTask, + MissionControlTaskSession, + SubprocessConfig, +) +from copilot.cloud.mission_control_client import MissionControlClient + +# --------------------------------------------------------------------------- +# Fixtures / helpers +# --------------------------------------------------------------------------- + +TASK = MissionControlTask( + id="task-1", + name="Cloud task", + state="running", + status="ready", + creator_id=1, + owner_id=2, + session_count=1, + created_at="2026-05-11T10:00:00.000Z", + updated_at="2026-05-11T10:01:00.000Z", + repo_id=3, + sessions=[ + MissionControlTaskSession( + id="mc-session-1", + task_id="task-1", + state="running", + created_at="2026-05-11T10:00:30.000Z", + updated_at="2026-05-11T10:00:30.000Z", + owner_id=2, + repo_id=3, + ) + ], +) + +REQUESTED_EVENT = CloudSessionEvent( + id="event-1", + timestamp="2026-05-11T10:00:00.000Z", + type="session.requested", + parent_id=None, +) + +IDLE_EVENT = CloudSessionEvent( + id="event-2", + timestamp="2026-05-11T10:00:01.000Z", + type="session.idle", + parent_id="event-1", + data={}, +) + + +def _task_dict() -> dict[str, Any]: + """Return a JSON-serialisable dict for the test task.""" + return { + "id": TASK.id, + "name": TASK.name, + "state": TASK.state, + "status": TASK.status, + "creator_id": TASK.creator_id, + "owner_id": TASK.owner_id, + "session_count": TASK.session_count, + "created_at": TASK.created_at, + "updated_at": TASK.updated_at, + "repo_id": TASK.repo_id, + "sessions": [ + { + "id": s.id, + "task_id": s.task_id, + "state": s.state, + "created_at": s.created_at, + "updated_at": s.updated_at, + "owner_id": s.owner_id, + "repo_id": s.repo_id, + } + for s in TASK.sessions + ], + } + + +def _event_dict(event: CloudSessionEvent) -> dict[str, Any]: + """Return the wire-format dict for an event.""" + d: dict[str, Any] = { + "id": event.id, + "timestamp": event.timestamp, + "type": event.type, + "parentId": event.parent_id, + } + if event.data is not None: + d["data"] = event.data + if event.ephemeral is not None: + d["ephemeral"] = event.ephemeral + return d + + +class _FakeHTTPResponse: + """Simulates urllib responses for mocking.""" + + def __init__(self, body: str, status: int = 200) -> None: + self._body = body.encode("utf-8") + self.status = status + self.code = status + + def read(self) -> bytes: + return self._body + + def __enter__(self) -> _FakeHTTPResponse: + return self + + def __exit__(self, *args: Any) -> None: + pass + + +def _make_url_responses(responses: list[tuple[str, int]]) -> Any: + """Create a side_effect for urllib.request.urlopen that returns successive responses.""" + import urllib.error + + call_idx = 0 + + def _side_effect(req: Any) -> _FakeHTTPResponse: + nonlocal call_idx + if call_idx >= len(responses): + raise RuntimeError("Unexpected HTTP request") + body, status = responses[call_idx] + call_idx += 1 + if status >= 400: + error = urllib.error.HTTPError( + req.full_url if hasattr(req, "full_url") else str(req), + status, + "Error", + {}, # type: ignore[arg-type] + None, + ) + # Patch the read method to return the body + error.read = lambda: body.encode("utf-8") # type: ignore[assignment] + raise error + return _FakeHTTPResponse(body, status) + + return _side_effect + + +def _make_url_responses_tracking( + responses: list[tuple[str, int]], +) -> tuple[Any, list[Any]]: + """Like _make_url_responses but also returns a list of captured requests.""" + import urllib.error + + captured: list[Any] = [] + call_idx = 0 + + def _side_effect(req: Any) -> _FakeHTTPResponse: + nonlocal call_idx + captured.append(req) + if call_idx >= len(responses): + raise RuntimeError("Unexpected HTTP request") + body, status = responses[call_idx] + call_idx += 1 + if status >= 400: + error = urllib.error.HTTPError( + req.full_url if hasattr(req, "full_url") else str(req), + status, + "Error", + {}, # type: ignore[arg-type] + None, + ) + error.read = lambda: body.encode("utf-8") # type: ignore[assignment] + raise error + return _FakeHTTPResponse(body, status) + + return _side_effect, captured + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestCloudSessions: + @pytest.mark.asyncio + async def test_creates_mission_control_cloud_task_and_attaches(self) -> None: + """Creates a Mission Control cloud task and attaches to task events.""" + responses = [ + (json.dumps(_task_dict()), 200), + (json.dumps({"events": [_event_dict(REQUESTED_EVENT)]}), 200), + ] + side_effect, captured = _make_url_responses_tracking(responses) + progress: list[str] = [] + + client = CopilotClient( + SubprocessConfig( + cli_path="/dev/null", + github_token="token-1", + env={ + "COPILOT_MC_BASE_URL": "https://mc.test/agents", + "COPILOT_MC_FRONTEND_URL": "https://github.test", + }, + ), + auto_start=False, + ) + + with patch("urllib.request.urlopen", side_effect=side_effect): + session = await client.create_cloud_session( + { + "repository": {"owner": "github", "name": "copilot-sdk", "branch": "main"}, + "initial_event_timeout_ms": 0, + "on_progress": lambda event: progress.append(event.phase), + } + ) + + assert session.metadata.task_id == "task-1" + assert session.metadata.mission_control_session_id == "mc-session-1" + assert session.metadata.frontend_url == "https://github.test/copilot/tasks/task-1" + assert session.metadata.repository == { + "owner": "github", + "name": "copilot-sdk", + "branch": "main", + } + assert session.metadata.state == "running" + assert session.metadata.status == "ready" + + messages = session.get_messages() + assert len(messages) == 1 + assert messages[0].id == "event-1" + assert messages[0].type == "session.requested" + + assert progress == [ + "creating_task", + "provisioning_sandbox", + "waiting_for_session", + "connected", + ] + + # Verify the create-task request + create_req = captured[0] + assert create_req.full_url == "https://mc.test/agents/tasks" + assert create_req.method == "POST" + assert create_req.get_header("Authorization") == "Bearer token-1" + assert create_req.get_header("X-copilot-agent-slug") == "copilot-developer-sandbox" + assert json.loads(create_req.data) == { + "repositories": [{"owner": "github", "name": "copilot-sdk"}], + } + + # Verify the events request + events_req = captured[1] + assert events_req.full_url == "https://mc.test/agents/tasks/task-1/events" + assert events_req.method == "GET" + + await session.disconnect() + + @pytest.mark.asyncio + async def test_creates_repo_less_cloud_task(self) -> None: + """Creates a repo-less cloud task when owner is provided.""" + responses = [ + (json.dumps(_task_dict()), 200), + (json.dumps({"events": []}), 200), + ] + side_effect, captured = _make_url_responses_tracking(responses) + + client = CopilotClient( + SubprocessConfig( + cli_path="/dev/null", env={"COPILOT_MC_BASE_URL": "https://mc.test/agents"} + ), + auto_start=False, + ) + + with patch("urllib.request.urlopen", side_effect=side_effect): + session = await client.create_cloud_session( + {"owner": "github", "initial_event_timeout_ms": 0} + ) + + assert session.metadata.owner == "github" + + # Verify the create-task request body + create_req = captured[0] + assert json.loads(create_req.data) == {"owner": "github"} + + await session.disconnect() + + @pytest.mark.asyncio + async def test_requires_owner_when_no_repository(self) -> None: + """Requires an owner when creating a repo-less cloud task.""" + client = CopilotClient( + SubprocessConfig( + cli_path="/dev/null", env={"COPILOT_MC_BASE_URL": "https://mc.test/agents"} + ), + auto_start=False, + ) + + with pytest.raises(ValueError, match="owner is required when repository is omitted"): + await client.create_cloud_session({"initial_event_timeout_ms": 0}) + + @pytest.mark.asyncio + async def test_sends_user_messages_through_steer_api(self) -> None: + """Sends cloud session user messages through the Mission Control steer API.""" + responses = [ + ("", 404), # get_task returns 404 + (json.dumps({"events": []}), 200), # list_task_events + ("", 200), # steer + ] + side_effect, captured = _make_url_responses_tracking(responses) + + client = CopilotClient( + SubprocessConfig( + cli_path="/dev/null", env={"COPILOT_MC_BASE_URL": "https://mc.test/agents"} + ), + auto_start=False, + ) + + with patch("urllib.request.urlopen", side_effect=side_effect): + session = await client.connect_cloud_session("task-1", {"initial_event_timeout_ms": 0}) + await session.send(prompt="hello cloud") + + steer_req = captured[2] + assert steer_req.full_url == "https://mc.test/agents/tasks/task-1/steer" + assert steer_req.method == "POST" + assert json.loads(steer_req.data) == { + "type": "user_message", + "content": "hello cloud", + } + + await session.disconnect() + + @pytest.mark.asyncio + async def test_sorts_and_deduplicates_events(self) -> None: + """Sorts replayed events and deduplicates events observed during polling.""" + polled_event = CloudSessionEvent( + id="event-3", + timestamp="2026-05-11T10:00:02.000Z", + type="session.idle", + parent_id="event-2", + data={}, + ) + # Initial connect: get_task returns task, list_task_events returns events out of order + responses_connect = [ + (json.dumps(_task_dict()), 200), + ( + json.dumps({"events": [_event_dict(IDLE_EVENT), _event_dict(REQUESTED_EVENT)]}), + 200, + ), + ] + + side_effect_connect, _ = _make_url_responses_tracking(responses_connect) + + client = CopilotClient( + SubprocessConfig( + cli_path="/dev/null", env={"COPILOT_MC_BASE_URL": "https://mc.test/agents"} + ), + auto_start=False, + ) + + with patch("urllib.request.urlopen", side_effect=side_effect_connect): + session = await client.connect_cloud_session( + "task-1", + {"initial_event_timeout_ms": 0, "poll_interval_ms": 50}, + ) + + # Events should be sorted chronologically + message_ids = [e.id for e in session.get_messages()] + assert message_ids == ["event-1", "event-2"] + + # Now set up poll responses (includes previously seen events + new one) + poll_responses = [ + ( + json.dumps( + { + "events": [ + _event_dict(IDLE_EVENT), + _event_dict(REQUESTED_EVENT), + _event_dict(polled_event), + ] + } + ), + 200, + ), + ] + poll_side_effect, _ = _make_url_responses_tracking(poll_responses) + + seen: list[str] = [] + session.on(lambda event: seen.append(event.id)) + + with patch("urllib.request.urlopen", side_effect=poll_side_effect): + # Wait for the poller to run + await asyncio.sleep(0.15) + + # Only the new event should have been dispatched to the handler + assert seen == ["event-3"] + # All events should be in order + assert [e.id for e in session.get_messages()] == ["event-1", "event-2", "event-3"] + + await session.disconnect() + + @pytest.mark.asyncio + async def test_surfaces_error_responses_as_cloud_session_errors(self) -> None: + """Surfaces Mission Control error responses as typed CloudSessionError.""" + responses = [ + (json.dumps({"message": "blocked"}), 403), + ] + side_effect, _ = _make_url_responses_tracking(responses) + + client = CopilotClient( + SubprocessConfig( + cli_path="/dev/null", env={"COPILOT_MC_BASE_URL": "https://mc.test/agents"} + ), + auto_start=False, + ) + + with patch("urllib.request.urlopen", side_effect=side_effect): + with pytest.raises(CloudSessionError) as exc_info: + await client.create_cloud_session( + { + "repository": {"owner": "github", "name": "copilot-sdk"}, + "initial_event_timeout_ms": 0, + } + ) + + err = exc_info.value + assert str(err) == "blocked" + assert err.reason == "policy_blocked" + assert err.status == 403 + + @pytest.mark.asyncio + async def test_connect_cloud_session_with_existing_task(self) -> None: + """connectCloudSession populates metadata from existing task.""" + responses = [ + (json.dumps(_task_dict()), 200), # get_task + (json.dumps({"events": [_event_dict(REQUESTED_EVENT)]}), 200), + ] + side_effect, _ = _make_url_responses_tracking(responses) + + client = CopilotClient( + SubprocessConfig( + cli_path="/dev/null", + env={ + "COPILOT_MC_BASE_URL": "https://mc.test/agents", + "COPILOT_MC_FRONTEND_URL": "https://github.test", + }, + ), + auto_start=False, + ) + + with patch("urllib.request.urlopen", side_effect=side_effect): + session = await client.connect_cloud_session("task-1", {"initial_event_timeout_ms": 0}) + + assert session.metadata.task_id == "task-1" + assert session.metadata.mission_control_session_id == "mc-session-1" + assert session.metadata.frontend_url == "https://github.test/copilot/tasks/task-1" + + await session.disconnect() + + @pytest.mark.asyncio + async def test_connect_cloud_session_with_missing_task(self) -> None: + """connectCloudSession uses fallback metadata when task not found.""" + responses = [ + ("", 404), # get_task returns 404 + (json.dumps({"events": []}), 200), + ] + side_effect, _ = _make_url_responses_tracking(responses) + + client = CopilotClient( + SubprocessConfig( + cli_path="/dev/null", + env={ + "COPILOT_MC_BASE_URL": "https://mc.test/agents", + "COPILOT_MC_FRONTEND_URL": "https://github.test", + }, + ), + auto_start=False, + ) + + with patch("urllib.request.urlopen", side_effect=side_effect): + session = await client.connect_cloud_session( + "task-missing", {"initial_event_timeout_ms": 0} + ) + + assert session.metadata.task_id == "task-missing" + assert session.metadata.mission_control_session_id is None + assert session.metadata.frontend_url == "https://github.test/copilot/tasks/task-missing" + + await session.disconnect() + + @pytest.mark.asyncio + async def test_disconnect_prevents_further_sends(self) -> None: + """Disconnected sessions reject send calls.""" + responses = [ + ("", 404), + (json.dumps({"events": []}), 200), + ] + side_effect, _ = _make_url_responses_tracking(responses) + + client = CopilotClient( + SubprocessConfig( + cli_path="/dev/null", env={"COPILOT_MC_BASE_URL": "https://mc.test/agents"} + ), + auto_start=False, + ) + + with patch("urllib.request.urlopen", side_effect=side_effect): + session = await client.connect_cloud_session("task-1", {"initial_event_timeout_ms": 0}) + + await session.disconnect() + + with pytest.raises(RuntimeError, match="disconnected"): + await session.send(prompt="should fail") + + @pytest.mark.asyncio + async def test_mission_control_command_types(self) -> None: + """Verify MissionControlCommandType enum values.""" + assert MissionControlCommandType.USER_MESSAGE.value == "user_message" + assert MissionControlCommandType.ASK_USER_RESPONSE.value == "ask_user_response" + assert MissionControlCommandType.PLAN_APPROVAL_RESPONSE.value == "plan_approval_response" + assert MissionControlCommandType.PERMISSION_RESPONSE.value == "permission_response" + assert MissionControlCommandType.ELICITATION_RESPONSE.value == "elicitation_response" + assert MissionControlCommandType.ABORT.value == "abort" + assert MissionControlCommandType.MODE_SWITCH.value == "mode_switch" + + @pytest.mark.asyncio + async def test_typed_event_handlers(self) -> None: + """Typed event handlers only receive matching event types.""" + responses = [ + (json.dumps(_task_dict()), 200), + ( + json.dumps({"events": [_event_dict(REQUESTED_EVENT), _event_dict(IDLE_EVENT)]}), + 200, + ), + ] + side_effect, _ = _make_url_responses_tracking(responses) + + client = CopilotClient( + SubprocessConfig( + cli_path="/dev/null", env={"COPILOT_MC_BASE_URL": "https://mc.test/agents"} + ), + auto_start=False, + ) + + idle_events: list[CloudSessionEvent] = [] + + with patch("urllib.request.urlopen", side_effect=side_effect): + session = await client.connect_cloud_session( + "task-1", {"initial_event_timeout_ms": 0, "poll_interval_ms": 50} + ) + + # Register typed handler — events from connect() are already recorded + session.on("session.idle", lambda e: idle_events.append(e)) + + # Send events again via polling to trigger typed handler + poll_new_idle = CloudSessionEvent( + id="event-4", + timestamp="2026-05-11T10:00:04.000Z", + type="session.idle", + data={}, + ) + poll_responses = [ + ( + json.dumps({"events": [_event_dict(poll_new_idle)]}), + 200, + ), + ] + poll_side_effect, _ = _make_url_responses_tracking(poll_responses) + + with patch("urllib.request.urlopen", side_effect=poll_side_effect): + await asyncio.sleep(0.15) + + assert len(idle_events) == 1 + assert idle_events[0].id == "event-4" + + await session.disconnect() + + @pytest.mark.asyncio + async def test_event_handler_unsubscribe(self) -> None: + """Unsubscribing removes the handler from future events.""" + mc_client = MissionControlClient( + base_url="https://mc.test/agents", + frontend_base_url="https://github.test", + ) + + from datetime import datetime + + from copilot.cloud.types import CloudSessionMetadata + + metadata = CloudSessionMetadata( + task_id="task-1", + frontend_url="https://github.test/copilot/tasks/task-1", + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + + session = CloudSession( + client=mc_client, + metadata=metadata, + initial_event_timeout_ms=0, + ) + + # Manually connect with empty events + with patch( + "urllib.request.urlopen", + side_effect=_make_url_responses([(json.dumps({"events": []}), 200)]), + ): + await session.connect() + + seen: list[str] = [] + unsubscribe = session.on(lambda e: seen.append(e.id)) + + # Dispatch an event manually + session._record_events( + [CloudSessionEvent(id="e1", timestamp="2026-01-01T00:00:00Z", type="test")] + ) + assert seen == ["e1"] + + # Unsubscribe and dispatch another + unsubscribe() + session._record_events( + [CloudSessionEvent(id="e2", timestamp="2026-01-01T00:00:01Z", type="test")] + ) + assert seen == ["e1"] # Still only e1 + + await session.disconnect() + + @pytest.mark.asyncio + async def test_context_manager(self) -> None: + """CloudSession works as an async context manager.""" + mc_client = MissionControlClient( + base_url="https://mc.test/agents", + frontend_base_url="https://github.test", + ) + + from datetime import datetime + + from copilot.cloud.types import CloudSessionMetadata + + metadata = CloudSessionMetadata( + task_id="task-1", + frontend_url="https://github.test/copilot/tasks/task-1", + created_at=datetime.now(UTC), + updated_at=datetime.now(UTC), + ) + + with patch( + "urllib.request.urlopen", + side_effect=_make_url_responses([(json.dumps({"events": []}), 200)]), + ): + async with CloudSession( + client=mc_client, + metadata=metadata, + initial_event_timeout_ms=0, + ) as session: + await session.connect() + assert not session._is_disconnected + + assert session._is_disconnected