diff --git a/centml/cli/main.py b/centml/cli/main.py index b1ecc73..fb5b932 100644 --- a/centml/cli/main.py +++ b/centml/cli/main.py @@ -2,6 +2,7 @@ from centml.cli.login import login, logout from centml.cli.cluster import ls, get, delete, pause, resume +from centml.cli.shell import shell, exec_cmd @click.group() @@ -47,6 +48,8 @@ def ccluster(): ccluster.add_command(delete) ccluster.add_command(pause) ccluster.add_command(resume) +ccluster.add_command(shell) +ccluster.add_command(exec_cmd, name="exec") cli.add_command(ccluster, name="cluster") diff --git a/centml/cli/shell.py b/centml/cli/shell.py new file mode 100644 index 0000000..34abf6d --- /dev/null +++ b/centml/cli/shell.py @@ -0,0 +1,602 @@ +"""CLI commands for interactive shell and command execution in deployment pods.""" + +import asyncio +import json +import logging +import os +import shutil +import signal +import sys +import termios +import tty +import urllib.parse + +import click +import pyte +import websockets + +from centml.cli.cluster import handle_exception +from centml.sdk import PodStatus, auth +from centml.sdk.api import get_centml_client +from centml.sdk.config import settings + +_log = logging.getLogger("centml.cli.shell") + + +def _setup_debug_log(): + """Configure file-based debug logging (stdout unusable in raw mode).""" + log_path = os.environ.get( + "CENTML_SHELL_DEBUG_LOG", "/tmp/centml_shell_debug.log" + ) + handler = logging.FileHandler(log_path, mode="w") + handler.setFormatter( + logging.Formatter("%(asctime)s.%(msecs)03d %(message)s", datefmt="%H:%M:%S") + ) + _log.addHandler(handler) + _log.setLevel(logging.DEBUG) + _log.debug("=== shell debug log started (pid=%d) ===", os.getpid()) + + +# --------------------------------------------------------------------------- +# pyte screen renderer -- converts pyte's in-memory screen buffer to ANSI +# escape sequences for the local terminal. +# --------------------------------------------------------------------------- + +_PYTE_FG_TO_SGR = { + "default": "39", + "black": "30", + "red": "31", + "green": "32", + "brown": "33", + "blue": "34", + "magenta": "35", + "cyan": "36", + "white": "37", + "brightblack": "90", + "brightred": "91", + "brightgreen": "92", + "brightbrown": "93", + "brightblue": "94", + "brightmagenta": "95", + "brightcyan": "96", + "brightwhite": "97", +} + +_PYTE_BG_TO_SGR = { + "default": "49", + "black": "40", + "red": "41", + "green": "42", + "brown": "43", + "blue": "44", + "magenta": "45", + "cyan": "46", + "white": "47", + "brightblack": "100", + "brightred": "101", + "brightgreen": "102", + "brightbrown": "103", + "brightblue": "104", + "brightmagenta": "105", + "brightcyan": "106", + "brightwhite": "107", +} + + +def _color_sgr(color, is_bg=False): + """Convert a pyte color value to an SGR parameter string.""" + table = _PYTE_BG_TO_SGR if is_bg else _PYTE_FG_TO_SGR + if color in table: + default_val = "49" if is_bg else "39" + code = table[color] + return code if code != default_val else "" + # 6-char hex -> truecolor + if len(color) == 6: + try: + r, g, b = int(color[:2], 16), int(color[2:4], 16), int(color[4:], 16) + prefix = "48" if is_bg else "38" + return f"{prefix};2;{r};{g};{b}" + except ValueError: + return "" + return "" + + +def _char_to_sgr(char): + """Build the ANSI SGR parameter string for a pyte Char's attributes.""" + parts = [] + if char.bold: + parts.append("1") + if char.italics: + parts.append("3") + if char.underscore: + parts.append("4") + if char.blink: + parts.append("5") + if char.reverse: + parts.append("7") + if char.strikethrough: + parts.append("9") + fg = _color_sgr(char.fg, is_bg=False) + if fg: + parts.append(fg) + bg = _color_sgr(char.bg, is_bg=True) + if bg: + parts.append(bg) + return ";".join(parts) + + +def _render_dirty(screen, output): + """Render only the dirty lines from the pyte Screen to the terminal. + + Args: + screen: pyte.Screen instance. + output: Writable binary stream (e.g. sys.stdout.buffer). + """ + parts = [] + for row in sorted(screen.dirty): + # Position cursor at row (1-based), column 1; clear line. + parts.append(f"\033[{row + 1};1H\033[2K") + prev_sgr = "" + line_chars = [] + for col in range(screen.columns): + char = screen.buffer[row][col] + if char.data == "": + continue + sgr = _char_to_sgr(char) + if sgr != prev_sgr: + line_chars.append(f"\033[0m\033[{sgr}m" if sgr else "\033[0m") + prev_sgr = sgr + line_chars.append(char.data) + text = "".join(line_chars).rstrip() + parts.append(text) + # Reset attributes, position cursor. + parts.append("\033[0m") + parts.append(f"\033[{screen.cursor.y + 1};{screen.cursor.x + 1}H") + if screen.cursor.hidden: + parts.append("\033[?25l") + else: + parts.append("\033[?25h") + screen.dirty.clear() + output.write("".join(parts).encode("utf-8")) + output.flush() + + +def _build_ws_url(api_url, deployment_id, pod_name, shell_type=None): + """Build the WebSocket URL for a terminal connection.""" + parsed = urllib.parse.urlparse(api_url) + ws_scheme = "wss" if parsed.scheme == "https" else "ws" + ws_base = parsed._replace(scheme=ws_scheme).geturl() + url = f"{ws_base}/deployments/{deployment_id}/terminal?pod={urllib.parse.quote(pod_name)}" + if shell_type: + url += f"&shell={urllib.parse.quote(shell_type)}" + return url + + +def _resolve_pod(cclient, deployment_id, pod_name=None): + """Resolve which pod to connect to. + + Args: + cclient: CentMLClient instance. + deployment_id: The deployment ID. + pod_name: Optional specific pod name to target. + + Returns: + The pod name string to connect to. + + Raises: + click.ClickException: If no running pods or specified pod not found. + """ + status = cclient.get_status_v3(deployment_id) + running_pods = [] + for revision in status.revision_pod_details_list or []: + for pod in revision.pod_details_list or []: + if pod.status == PodStatus.RUNNING and pod.name: + running_pods.append(pod.name) + + if not running_pods: + raise click.ClickException( + f"No running pods found for deployment {deployment_id}" + ) + + if pod_name is not None: + if pod_name not in running_pods: + pods_list = ", ".join(running_pods) + raise click.ClickException( + f"Pod '{pod_name}' not found. Available running pods: {pods_list}" + ) + return pod_name + + if len(running_pods) > 1: + click.echo( + f"Multiple running pods found, connecting to {running_pods[0]}. " + f"Use --pod to specify a different pod.", + err=True, + ) + return running_pods[0] + + +async def _forward_io(ws, screen, stream, shutdown): + """Bidirectional forwarding between local stdin/stdout and WebSocket. + + Output flows through a pyte terminal emulator so that cursor + addressing, line wrapping, and colors are rendered correctly + regardless of the remote PTY dimensions. + + The platform API proxy does not close the WebSocket when the remote + shell exits, so we detect the ``exit`` echo and use a short idle + timeout to break out. + + Args: + ws: WebSocket connection. + screen: pyte.Screen instance sized to the local terminal. + stream: pyte.Stream attached to *screen*. + shutdown: asyncio.Event set by signal handlers to request exit. + + Returns: + The exit code (always 0 for interactive sessions). + """ + loop = asyncio.get_running_loop() + stdin_fd = sys.stdin.fileno() + stdin_closed = asyncio.Event() + + async def _read_ws(): + _log.debug("[read_ws] started") + msg_count = 0 + recv_timeout = None + try: + while True: + try: + raw_msg = await asyncio.wait_for(ws.recv(), timeout=recv_timeout) + except asyncio.TimeoutError: + _log.debug( + "[read_ws] idle timeout (%.1fs) after %d msgs -- shell exited", + recv_timeout, msg_count, + ) + return + msg_count += 1 + msg = json.loads(raw_msg) + keys = list(msg.keys()) + data = msg.get("data", "") + data_snippet = repr(data[:120]) if data else "" + _log.debug( + "[read_ws] msg#%d keys=%s data=%s", + msg_count, keys, data_snippet, + ) + if data: + stream.feed(data.replace("\n", "\r\n")) + _render_dirty(screen, sys.stdout.buffer) + # Detect shell exit echo. When the user runs ``exit``, + # the remote PTY echoes ``exit\r\n`` as the final + # output with nothing after it. If ``exit\r\n`` appears + # mid-message (e.g. ``echo exit`` produces + # ``exit\r\n``), the shell is still alive. + idx = data.rfind("exit\r\n") + if idx != -1: + after_exit = data[idx + len("exit\r\n"):] + if after_exit.strip(): + # New prompt follows -- shell is still alive. + _log.debug("[read_ws] exit echo with trailing data, not a real exit") + recv_timeout = None + else: + # Nothing meaningful after exit -- shell exited. + _log.debug("[read_ws] detected exit echo at end of data, exiting") + return + elif recv_timeout is not None: + # We previously armed a timeout but got more data + # (shouldn't normally happen, but be safe). + recv_timeout = None + elif msg.get("error"): + _log.debug("[read_ws] error: %s", msg["error"]) + stream.feed(f"Error: {msg['error']}\r\n") + _render_dirty(screen, sys.stdout.buffer) + except websockets.ConnectionClosed as exc: + _log.debug( + "[read_ws] ConnectionClosed after %d msgs: %s", msg_count, exc, + ) + return + + async def _read_stdin(): + _log.debug("[read_stdin] started") + read_queue = asyncio.Queue() + + def _on_stdin_ready(): + data = sys.stdin.buffer.read1(4096) + if data: + read_queue.put_nowait(data) + else: + _log.debug("[read_stdin] stdin EOF") + stdin_closed.set() + + loop.add_reader(stdin_fd, _on_stdin_ready) + try: + while not stdin_closed.is_set() and not shutdown.is_set(): + try: + data = await asyncio.wait_for(read_queue.get(), timeout=0.5) + except asyncio.TimeoutError: + continue + _log.debug("[read_stdin] sending %d bytes: %s", len(data), repr(data[:40])) + try: + await ws.send( + json.dumps( + { + "operation": "stdin", + "data": data.decode("utf-8", errors="replace"), + "rows": screen.lines, + "cols": screen.columns, + } + ) + ) + except websockets.ConnectionClosed: + _log.debug("[read_stdin] ws closed on send") + return + _log.debug( + "[read_stdin] loop exited: stdin_closed=%s shutdown=%s", + stdin_closed.is_set(), shutdown.is_set(), + ) + finally: + loop.remove_reader(stdin_fd) + + async def _watch_shutdown(): + while not shutdown.is_set(): + await asyncio.sleep(0.2) + + _log.debug("[forward_io] creating tasks") + task_ws = asyncio.create_task(_read_ws()) + task_stdin = asyncio.create_task(_read_stdin()) + task_shutdown = asyncio.create_task(_watch_shutdown()) + tasks = [task_ws, task_stdin, task_shutdown] + task_names = {id(task_ws): "read_ws", id(task_stdin): "read_stdin", id(task_shutdown): "watch_shutdown"} + + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + done_names = [task_names[id(t)] for t in done] + pending_names = [task_names[id(t)] for t in pending] + _log.debug("[forward_io] first completed: done=%s pending=%s", done_names, pending_names) + + for t in pending: + t.cancel() + for t in pending: + try: + await t + except (asyncio.CancelledError, Exception): + pass + for t in done: + if t.exception() is not None: + _log.debug("[forward_io] task exception: %s", t.exception()) + raise t.exception() + _log.debug("[forward_io] returning exit_code=0") + return 0 + + +async def _interactive_session(ws_url, token): + """Run an interactive terminal session over WebSocket. + + Enters raw mode, forwards I/O bidirectionally, and restores terminal + on exit. SIGTERM and SIGHUP are caught to ensure terminal settings + are always restored. + """ + fd = sys.stdin.fileno() + old_settings = termios.tcgetattr(fd) + try: + tty.setraw(fd) + cols, rows = shutil.get_terminal_size() + + screen = pyte.Screen(cols, rows) + stream = pyte.Stream(screen) + + # Switch to alternate screen buffer (disables scrollback) and clear. + sys.stdout.buffer.write(b"\033[?1049h\033[2J\033[H") + sys.stdout.buffer.flush() + + loop = asyncio.get_running_loop() + + shutdown = asyncio.Event() + loop.add_signal_handler(signal.SIGTERM, shutdown.set) + loop.add_signal_handler(signal.SIGHUP, shutdown.set) + + headers = {"Authorization": f"Bearer {token}"} + _log.debug("[session] connecting to %s", ws_url) + async with websockets.connect( + ws_url, additional_headers=headers, close_timeout=2 + ) as ws: + _log.debug("[session] connected, sending resize %dx%d", cols, rows) + + def _send_resize(): + c, r = shutil.get_terminal_size() + screen.resize(r, c) + screen.dirty.update(range(r)) + asyncio.ensure_future( + ws.send( + json.dumps({"operation": "resize", "rows": r, "cols": c}) + ) + ) + + loop.add_signal_handler(signal.SIGWINCH, _send_resize) + + await ws.send( + json.dumps( + {"operation": "resize", "rows": rows, "cols": cols} + ) + ) + try: + exit_code = await _forward_io(ws, screen, stream, shutdown) + finally: + loop.remove_signal_handler(signal.SIGWINCH) + + # Skip close handshake -- proxy won't send close frame promptly. + ws.close_timeout = 0 + _log.debug("[session] exiting with code %d", exit_code) + return exit_code + finally: + loop.remove_signal_handler(signal.SIGTERM) + loop.remove_signal_handler(signal.SIGHUP) + termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) + # Leave alternate screen buffer, restore cursor and attributes. + sys.stdout.buffer.write(b"\033[?1049l\033[?25h\033[0m") + sys.stdout.buffer.flush() + + +_BEGIN_MARKER = "__CENTML_BEGIN_5f3a__" +_END_MARKER = "__CENTML_END_5f3a__" + +# printf octal \137 = underscore. The decoded output matches _BEGIN/_END_MARKER, +# but the literal command text does NOT, so shell echo won't trigger false matches. +_PRINTF_BEGIN = r"\137\137CENTML_BEGIN_5f3a\137\137" +_PRINTF_END = r"\137\137CENTML_END_5f3a\137\137" + + +def _pyte_extract_text(line_stream, line_screen, text): + """Feed text through a single-row pyte screen and return visible characters. + + More robust than regex ANSI stripping: pyte interprets all VT100/VT220 + sequences including OSC, cursor repositioning, and truecolor escapes. + """ + line_screen.reset() + line_stream.feed(text) + return "".join( + line_screen.buffer[0][col].data for col in range(line_screen.columns) + ).rstrip() + + +async def _exec_session(ws_url, token, command): + """Execute a command in a pod and return its exit code. + + Does not enter raw mode -- output is pipe-friendly. + Suppresses shell echo and uses markers to capture only command output. + """ + cols, rows = shutil.get_terminal_size(fallback=(80, 24)) + # Single-row screen for interpreting escape sequences in marker detection. + line_screen = pyte.Screen(cols, 1) + line_stream = pyte.Stream(line_screen) + headers = {"Authorization": f"Bearer {token}"} + + async with websockets.connect( + ws_url, additional_headers=headers, close_timeout=2 + ) as ws: + await ws.send(json.dumps({"operation": "resize", "rows": rows, "cols": cols})) + + # Suppress echo/bracketed-paste, emit begin marker, run command, + # emit end marker with exit code, then exit. + # Markers use printf octal escapes so the literal marker string + # doesn't appear in the command echo. + wrapped = ( + f"stty -echo 2>/dev/null; printf '\\033[?2004l';" + f" printf '{_PRINTF_BEGIN}\\n';" + f" {command};" + f" __ec=$?;" + f" printf '\\n{_PRINTF_END}:%d\\n' \"$__ec\";" + f" exit $__ec\n" + ) + + await ws.send(json.dumps({"operation": "stdin", "data": wrapped})) + + exit_code = 0 + buffer = "" + is_capturing = False + is_done = False + msg_count = 0 + try: + async for raw_msg in ws: + msg_count += 1 + msg = json.loads(raw_msg) + keys = list(msg.keys()) + _log.debug( + "[exec] msg#%d keys=%s data_len=%d", + msg_count, keys, len(msg.get("data", "")), + ) + if msg.get("data"): + buffer += msg["data"] + while "\n" in buffer: + line, buffer = buffer.split("\n", 1) + clean = _pyte_extract_text( + line_stream, line_screen, line.rstrip("\r") + ) + if _BEGIN_MARKER in clean: + _log.debug("[exec] BEGIN marker found") + is_capturing = True + continue + if _END_MARKER in clean: + parts = clean.split(_END_MARKER + ":") + if len(parts) > 1: + try: + exit_code = int(parts[1].strip()) + except ValueError: + pass + _log.debug("[exec] END marker, exit_code=%d", exit_code) + is_done = True + break + if is_capturing: + sys.stdout.write(line + "\n") + sys.stdout.flush() + elif msg.get("error"): + _log.debug("[exec] error: %s", msg["error"]) + sys.stderr.write(f"Error: {msg['error']}\n") + return 1 + if is_done: + _log.debug("[exec] done, breaking") + break + except websockets.ConnectionClosed as exc: + _log.debug("[exec] ConnectionClosed: %s", exc) + if is_done: + # Skip the close handshake -- the platform API proxy does not + # proactively close its end, so waiting wastes close_timeout seconds. + ws.close_timeout = 0 + _log.debug("[exec] returning exit_code=%d", exit_code) + return exit_code + + +@click.command(help="Open an interactive shell to a deployment pod") +@click.argument("deployment_id", type=int) +@click.option( + "--pod", default=None, help="Specific pod name (auto-selects first running pod)" +) +@click.option( + "--shell", + "shell_type", + default=None, + type=click.Choice(["bash", "sh", "zsh"]), + help="Shell type", +) +@handle_exception +def shell(deployment_id, pod, shell_type): + _setup_debug_log() + if not sys.stdin.isatty(): + raise click.ClickException("Interactive shell requires a terminal (TTY)") + + with get_centml_client() as cclient: + pod_name = _resolve_pod(cclient, deployment_id, pod) + + ws_url = _build_ws_url( + settings.CENTML_PLATFORM_API_URL, deployment_id, pod_name, shell_type + ) + token = auth.get_centml_token() + exit_code = asyncio.run(_interactive_session(ws_url, token)) + sys.exit(exit_code) + + +@click.command( + help="Execute a command in a deployment pod", + context_settings={"ignore_unknown_options": True}, +) +@click.argument("deployment_id", type=int) +@click.argument("command", nargs=-1, required=True, type=click.UNPROCESSED) +@click.option("--pod", default=None, help="Specific pod name") +@click.option( + "--shell", + "shell_type", + default=None, + type=click.Choice(["bash", "sh", "zsh"]), + help="Shell type", +) +@handle_exception +def exec_cmd(deployment_id, command, pod, shell_type): + _setup_debug_log() + with get_centml_client() as cclient: + pod_name = _resolve_pod(cclient, deployment_id, pod) + + ws_url = _build_ws_url( + settings.CENTML_PLATFORM_API_URL, deployment_id, pod_name, shell_type + ) + token = auth.get_centml_token() + cmd_str = " ".join(command) + exit_code = asyncio.run(_exec_session(ws_url, token, cmd_str)) + sys.exit(exit_code) diff --git a/centml/sdk/api.py b/centml/sdk/api.py index e1e11d3..20dfa99 100644 --- a/centml/sdk/api.py +++ b/centml/sdk/api.py @@ -27,6 +27,9 @@ def get(self, depl_type): def get_status(self, id): return self._api.get_deployment_status_deployments_status_deployment_id_get(id) + def get_status_v3(self, deployment_id): + return self._api.get_deployment_status_v3_deployments_status_v3_deployment_id_get(deployment_id) + def get_inference(self, id): """Get Inference deployment details - automatically handles both V2 and V3 deployments""" # Try V3 first (recommended), fallback to V2 if deployment is V2 diff --git a/requirements.txt b/requirements.txt index c3b4961..9e79a4f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,6 @@ cryptography==44.0.1 prometheus-client>=0.20.0 scipy>=1.6.0 scikit-learn>=1.5.1 +websockets>=13.0 +pyte>=0.8.0 platform-api-python-client==4.6.0 diff --git a/tests/conftest.py b/tests/conftest.py index f3de342..1ab0b0e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,11 +3,22 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false" +# Tests that require PyTorch at import time -- skip during sanity runs +# where PyTorch is not installed. +_PYTORCH_TEST_FILES = ["test_backend.py", "test_helpers.py", "test_server.py"] + +collect_ignore = [] + def pytest_addoption(parser): parser.addoption("--sanity", action="store_true", help="Run sanity tests (exclude 'gpu' tests)") +def pytest_configure(config): + if config.getoption("--sanity", default=False): + collect_ignore.extend(_PYTORCH_TEST_FILES) + + def pytest_collection_modifyitems(config, items): if config.getoption("--sanity"): skip_gpu = pytest.mark.skip(reason="Skipping GPU tests for sanity run") diff --git a/tests/test_shell.py b/tests/test_shell.py new file mode 100644 index 0000000..a6afd43 --- /dev/null +++ b/tests/test_shell.py @@ -0,0 +1,889 @@ +"""Tests for centml.cli.shell -- CLI terminal access commands.""" + +import asyncio +import io +import json +import os +import signal +import urllib.parse +from unittest.mock import AsyncMock, MagicMock, patch + +import click +import pyte +import pytest + +from platform_api_python_client import PodStatus, PodDetails, RevisionPodDetails + + +def _async_iter_from_list(items): + """Create an async iterator from a list of items.""" + + async def _aiter(): + for item in items: + yield item + + return _aiter() + + +# --------------------------------------------------------------------------- +# Helpers to build mock status responses +# --------------------------------------------------------------------------- + + +def _make_pod(name, status=PodStatus.RUNNING): + pod = MagicMock(spec=PodDetails) + pod.name = name + pod.status = status + return pod + + +def _make_revision(pods): + rev = MagicMock(spec=RevisionPodDetails) + rev.pod_details_list = pods + return rev + + +def _make_status_response(revisions): + resp = MagicMock() + resp.revision_pod_details_list = revisions + return resp + + +# =========================================================================== +# _build_ws_url +# =========================================================================== + + +class TestBuildWsUrl: + def test_https_to_wss(self): + from centml.cli.shell import _build_ws_url + + url = _build_ws_url("https://api.centml.com", 123, "my-pod-abc") + parsed = urllib.parse.urlparse(url) + assert parsed.scheme == "wss" + assert parsed.netloc == "api.centml.com" + + def test_http_to_ws(self): + from centml.cli.shell import _build_ws_url + + url = _build_ws_url("http://localhost:16000", 42, "pod-1") + parsed = urllib.parse.urlparse(url) + assert parsed.scheme == "ws" + assert parsed.netloc == "localhost:16000" + + def test_contains_deployment_id_and_pod(self): + from centml.cli.shell import _build_ws_url + + url = _build_ws_url("https://api.centml.com", 99, "pod-xyz") + assert "/deployments/99/terminal" in url + assert "pod=pod-xyz" in url + + def test_with_shell(self): + from centml.cli.shell import _build_ws_url + + url = _build_ws_url("https://api.centml.com", 1, "p", shell_type="bash") + assert "shell=bash" in url + + def test_without_shell(self): + from centml.cli.shell import _build_ws_url + + url = _build_ws_url("https://api.centml.com", 1, "p") + assert "shell=" not in url + + def test_encodes_pod_name(self): + from centml.cli.shell import _build_ws_url + + url = _build_ws_url("https://api.centml.com", 1, "pod name/special") + assert "pod%20name" in url or "pod+name" in url + + +# =========================================================================== +# _resolve_pod +# =========================================================================== + + +class TestResolvePod: + def test_selects_first_running(self): + from centml.cli.shell import _resolve_pod + + cclient = MagicMock() + cclient.get_status_v3.return_value = _make_status_response( + [ + _make_revision( + [ + _make_pod("pod-a", PodStatus.RUNNING), + _make_pod("pod-b", PodStatus.RUNNING), + ] + ) + ] + ) + result = _resolve_pod(cclient, 1) + assert result == "pod-a" + + def test_raises_no_running_pods(self): + from centml.cli.shell import _resolve_pod + + cclient = MagicMock() + cclient.get_status_v3.return_value = _make_status_response( + [_make_revision([_make_pod("pod-err", PodStatus.ERROR)])] + ) + with pytest.raises(click.ClickException, match="No running pods"): + _resolve_pod(cclient, 1) + + def test_raises_specified_pod_not_found(self): + from centml.cli.shell import _resolve_pod + + cclient = MagicMock() + cclient.get_status_v3.return_value = _make_status_response( + [_make_revision([_make_pod("pod-a", PodStatus.RUNNING)])] + ) + with pytest.raises(click.ClickException, match="pod-missing"): + _resolve_pod(cclient, 1, pod_name="pod-missing") + + def test_returns_specified_pod(self): + from centml.cli.shell import _resolve_pod + + cclient = MagicMock() + cclient.get_status_v3.return_value = _make_status_response( + [ + _make_revision( + [ + _make_pod("pod-a", PodStatus.RUNNING), + _make_pod("pod-b", PodStatus.RUNNING), + ] + ) + ] + ) + result = _resolve_pod(cclient, 1, pod_name="pod-b") + assert result == "pod-b" + + def test_empty_revision_list(self): + from centml.cli.shell import _resolve_pod + + cclient = MagicMock() + cclient.get_status_v3.return_value = _make_status_response([]) + with pytest.raises(click.ClickException, match="No running pods"): + _resolve_pod(cclient, 1) + + def test_none_revision_list(self): + from centml.cli.shell import _resolve_pod + + cclient = MagicMock() + cclient.get_status_v3.return_value = _make_status_response(None) + cclient.get_status_v3.return_value.revision_pod_details_list = None + with pytest.raises(click.ClickException, match="No running pods"): + _resolve_pod(cclient, 1) + + def test_skips_pods_without_name(self): + from centml.cli.shell import _resolve_pod + + cclient = MagicMock() + cclient.get_status_v3.return_value = _make_status_response( + [ + _make_revision( + [ + _make_pod(None, PodStatus.RUNNING), + _make_pod("pod-real", PodStatus.RUNNING), + ] + ) + ] + ) + result = _resolve_pod(cclient, 1) + assert result == "pod-real" + + def test_multiple_revisions(self): + from centml.cli.shell import _resolve_pod + + cclient = MagicMock() + cclient.get_status_v3.return_value = _make_status_response( + [ + _make_revision([_make_pod("pod-old", PodStatus.ERROR)]), + _make_revision([_make_pod("pod-new", PodStatus.RUNNING)]), + ] + ) + result = _resolve_pod(cclient, 1) + assert result == "pod-new" + + +# =========================================================================== +# _exec_session +# =========================================================================== + + +class TestExecSession: + def test_sends_resize_and_wrapped_command(self): + from centml.cli.shell import _exec_session, _BEGIN_MARKER, _END_MARKER + + ws = AsyncMock() + messages = [ + json.dumps( + {"data": f"noise\n{_BEGIN_MARKER}\nhello world\n{_END_MARKER}:0\n"} + ), + json.dumps({"Code": 0}), + ] + ws.__aiter__ = MagicMock(return_value=_async_iter_from_list(messages)) + + with patch("centml.cli.shell.websockets") as mock_ws_mod: + mock_ws_mod.connect = MagicMock( + return_value=AsyncMock( + __aenter__=AsyncMock(return_value=ws), + __aexit__=AsyncMock(return_value=False), + ) + ) + + exit_code = asyncio.run( + _exec_session("wss://test/ws", "fake-token", "ls -la") + ) + + assert exit_code == 0 + assert ws.send.call_count == 2 + resize_msg = json.loads(ws.send.call_args_list[0][0][0]) + assert resize_msg["operation"] == "resize" + cmd_msg = json.loads(ws.send.call_args_list[1][0][0]) + assert cmd_msg["operation"] == "stdin" + assert "ls -la" in cmd_msg["data"] + assert "stty -echo" in cmd_msg["data"] + # Markers use printf octal escapes, so the literal marker + # should NOT appear in the command (prevents echo false-match). + assert _BEGIN_MARKER not in cmd_msg["data"] + assert "CENTML_BEGIN" in cmd_msg["data"] + + def test_returns_nonzero_exit_code_from_marker(self): + from centml.cli.shell import _exec_session, _BEGIN_MARKER, _END_MARKER + + ws = AsyncMock() + messages = [ + json.dumps({"data": f"{_BEGIN_MARKER}\n{_END_MARKER}:42\n"}), + json.dumps({"Code": 42}), + ] + ws.__aiter__ = MagicMock(return_value=_async_iter_from_list(messages)) + + with patch("centml.cli.shell.websockets") as mock_ws_mod: + mock_ws_mod.connect = MagicMock( + return_value=AsyncMock( + __aenter__=AsyncMock(return_value=ws), + __aexit__=AsyncMock(return_value=False), + ) + ) + + exit_code = asyncio.run( + _exec_session("wss://test/ws", "fake-token", "false") + ) + + assert exit_code == 42 + + def test_error_message_returns_one(self): + from centml.cli.shell import _exec_session + + ws = AsyncMock() + messages = [json.dumps({"error": "something went wrong"})] + ws.__aiter__ = MagicMock(return_value=_async_iter_from_list(messages)) + + with patch("centml.cli.shell.websockets") as mock_ws_mod: + mock_ws_mod.connect = MagicMock( + return_value=AsyncMock( + __aenter__=AsyncMock(return_value=ws), + __aexit__=AsyncMock(return_value=False), + ) + ) + + exit_code = asyncio.run(_exec_session("wss://test/ws", "fake-token", "bad")) + + assert exit_code == 1 + + def test_filters_noise_before_marker(self): + """Only output between BEGIN and END markers is written to stdout.""" + from centml.cli.shell import _exec_session, _BEGIN_MARKER, _END_MARKER + + ws = AsyncMock() + messages = [ + json.dumps( + { + "data": f"prompt$ command\n{_BEGIN_MARKER}\nreal output\n{_END_MARKER}:0\n" + } + ), + json.dumps({"Code": 0}), + ] + ws.__aiter__ = MagicMock(return_value=_async_iter_from_list(messages)) + + captured = [] + with patch("centml.cli.shell.websockets") as mock_ws_mod, patch( + "centml.cli.shell.sys" + ) as mock_sys: + mock_ws_mod.connect = MagicMock( + return_value=AsyncMock( + __aenter__=AsyncMock(return_value=ws), + __aexit__=AsyncMock(return_value=False), + ) + ) + mock_sys.stdout.write = lambda s: captured.append(s) + mock_sys.stdout.flush = MagicMock() + mock_sys.stderr.write = MagicMock() + + exit_code = asyncio.run( + _exec_session("wss://test/ws", "fake-token", "echo test") + ) + + assert exit_code == 0 + output = "".join(captured) + assert "real output" in output + assert "prompt$" not in output + + def test_connection_closed_returns_zero(self): + """Graceful exit when server closes connection without Code message.""" + from centml.cli.shell import _exec_session + + import websockets as _ws_lib + + ws = AsyncMock() + + async def _raise_closed(): + yield json.dumps({"data": "partial\n"}) + raise _ws_lib.ConnectionClosed(None, None) + + ws.__aiter__ = MagicMock(return_value=_raise_closed()) + + with patch("centml.cli.shell.websockets") as mock_ws_mod: + mock_ws_mod.connect = MagicMock( + return_value=AsyncMock( + __aenter__=AsyncMock(return_value=ws), + __aexit__=AsyncMock(return_value=False), + ) + ) + mock_ws_mod.ConnectionClosed = _ws_lib.ConnectionClosed + + exit_code = asyncio.run( + _exec_session("wss://test/ws", "fake-token", "exit") + ) + + assert exit_code == 0 + + def test_sets_zero_close_timeout_after_done(self): + """After END marker, close_timeout should be 0 to avoid waiting for server close.""" + from centml.cli.shell import _exec_session, _BEGIN_MARKER, _END_MARKER + + ws = AsyncMock() + messages = [ + json.dumps({"data": f"{_BEGIN_MARKER}\nhello\n{_END_MARKER}:0\n"}), + ] + ws.__aiter__ = MagicMock(return_value=_async_iter_from_list(messages)) + ws.close_timeout = 2 + + with patch("centml.cli.shell.websockets") as mock_ws_mod: + mock_ws_mod.connect = MagicMock( + return_value=AsyncMock( + __aenter__=AsyncMock(return_value=ws), + __aexit__=AsyncMock(return_value=False), + ) + ) + + asyncio.run(_exec_session("wss://test/ws", "fake-token", "echo hello")) + + assert ws.close_timeout == 0 + + def test_handles_ansi_around_markers(self): + """Markers wrapped in ANSI codes are still detected via pyte.""" + from centml.cli.shell import _exec_session, _BEGIN_MARKER, _END_MARKER + + ws = AsyncMock() + # Markers surrounded by ANSI color codes. + data = f"\x1b[32m{_BEGIN_MARKER}\x1b[0m\noutput\n\x1b[32m{_END_MARKER}:0\x1b[0m\n" + messages = [json.dumps({"data": data}), json.dumps({"Code": 0})] + ws.__aiter__ = MagicMock(return_value=_async_iter_from_list(messages)) + + captured = [] + with patch("centml.cli.shell.websockets") as mock_ws_mod, patch( + "centml.cli.shell.sys" + ) as mock_sys: + mock_ws_mod.connect = MagicMock( + return_value=AsyncMock( + __aenter__=AsyncMock(return_value=ws), + __aexit__=AsyncMock(return_value=False), + ) + ) + mock_sys.stdout.write = lambda s: captured.append(s) + mock_sys.stdout.flush = MagicMock() + mock_sys.stderr.write = MagicMock() + + exit_code = asyncio.run( + _exec_session("wss://test/ws", "fake-token", "echo test") + ) + + assert exit_code == 0 + output = "".join(captured) + assert "output" in output + + +# =========================================================================== +# _interactive_session -- terminal restore +# =========================================================================== + + +class TestInteractiveSessionTerminalRestore: + def test_restores_terminal_on_exception(self): + from centml.cli.shell import _interactive_session + + with patch("centml.cli.shell.sys") as mock_sys, patch( + "centml.cli.shell.termios" + ) as mock_termios, patch("centml.cli.shell.tty"), patch( + "centml.cli.shell.websockets" + ) as mock_ws_mod: + + mock_sys.stdin.fileno.return_value = 0 + mock_sys.stdout.buffer = io.BytesIO() + mock_termios.tcgetattr.return_value = ["old_settings"] + + mock_ws_mod.connect = MagicMock( + return_value=AsyncMock( + __aenter__=AsyncMock(side_effect=ConnectionRefusedError("fail")), + __aexit__=AsyncMock(return_value=False), + ) + ) + + with pytest.raises(ConnectionRefusedError): + asyncio.run(_interactive_session("wss://test/ws", "fake-token")) + + mock_termios.tcsetattr.assert_called_once() + restore_call = mock_termios.tcsetattr.call_args + assert restore_call[0][2] == ["old_settings"] + + +# =========================================================================== +# Click commands +# =========================================================================== + + +class TestShellCommand: + def test_rejects_non_tty(self): + from centml.cli.shell import shell + from click.testing import CliRunner + + runner = CliRunner() + result = runner.invoke(shell, ["123"]) + assert result.exit_code != 0 + assert "terminal" in result.output.lower() or "tty" in result.output.lower() + + def test_shell_option_forwarded(self): + from centml.cli.shell import shell + from click.testing import CliRunner + + with patch("centml.cli.shell._resolve_pod", return_value="pod-a"), patch( + "centml.cli.shell.get_centml_client" + ) as mock_ctx, patch("centml.cli.shell.auth") as mock_auth, patch( + "centml.cli.shell.settings" + ) as mock_settings, patch( + "centml.cli.shell.asyncio" + ) as mock_asyncio, patch( + "centml.cli.shell.sys" + ) as mock_sys: + + mock_ctx.return_value.__enter__ = MagicMock(return_value=MagicMock()) + mock_ctx.return_value.__exit__ = MagicMock(return_value=False) + mock_auth.get_centml_token.return_value = "token" + mock_settings.CENTML_PLATFORM_API_URL = "https://api.centml.com" + mock_sys.stdin.isatty.return_value = True + mock_asyncio.run.return_value = 0 + + runner = CliRunner() + runner.invoke(shell, ["123", "--shell", "bash"]) + + mock_asyncio.run.assert_called_once() + + def test_pod_option_forwarded(self): + from centml.cli.shell import shell + from click.testing import CliRunner + + with patch("centml.cli.shell._resolve_pod") as mock_resolve, patch( + "centml.cli.shell.get_centml_client" + ) as mock_ctx, patch("centml.cli.shell.auth") as mock_auth, patch( + "centml.cli.shell.settings" + ) as mock_settings, patch( + "centml.cli.shell.asyncio" + ) as mock_asyncio, patch( + "centml.cli.shell.sys" + ) as mock_sys: + + mock_ctx.return_value.__enter__ = MagicMock(return_value=MagicMock()) + mock_ctx.return_value.__exit__ = MagicMock(return_value=False) + mock_resolve.return_value = "my-pod" + mock_auth.get_centml_token.return_value = "token" + mock_settings.CENTML_PLATFORM_API_URL = "https://api.centml.com" + mock_sys.stdin.isatty.return_value = True + mock_asyncio.run.return_value = 0 + + runner = CliRunner() + runner.invoke(shell, ["123", "--pod", "my-pod"]) + + mock_resolve.assert_called_once() + assert ( + mock_resolve.call_args[1].get("pod_name") == "my-pod" + or mock_resolve.call_args[0][2] == "my-pod" + ) + + +class TestExecCommand: + def test_passes_command(self): + from centml.cli.shell import exec_cmd + from click.testing import CliRunner + + with patch("centml.cli.shell._resolve_pod", return_value="pod-a"), patch( + "centml.cli.shell.get_centml_client" + ) as mock_ctx, patch("centml.cli.shell.auth") as mock_auth, patch( + "centml.cli.shell.settings" + ) as mock_settings, patch( + "centml.cli.shell.asyncio" + ) as mock_asyncio: + + mock_ctx.return_value.__enter__ = MagicMock(return_value=MagicMock()) + mock_ctx.return_value.__exit__ = MagicMock(return_value=False) + mock_auth.get_centml_token.return_value = "token" + mock_settings.CENTML_PLATFORM_API_URL = "https://api.centml.com" + mock_asyncio.run.return_value = 0 + + runner = CliRunner() + runner.invoke(exec_cmd, ["123", "--", "ls", "-la"]) + + mock_asyncio.run.assert_called_once() + + def test_shell_option_forwarded(self): + from centml.cli.shell import exec_cmd + from click.testing import CliRunner + + with patch("centml.cli.shell._resolve_pod", return_value="pod-a"), patch( + "centml.cli.shell.get_centml_client" + ) as mock_ctx, patch("centml.cli.shell.auth") as mock_auth, patch( + "centml.cli.shell.settings" + ) as mock_settings, patch( + "centml.cli.shell.asyncio" + ) as mock_asyncio: + + mock_ctx.return_value.__enter__ = MagicMock(return_value=MagicMock()) + mock_ctx.return_value.__exit__ = MagicMock(return_value=False) + mock_auth.get_centml_token.return_value = "token" + mock_settings.CENTML_PLATFORM_API_URL = "https://api.centml.com" + mock_asyncio.run.return_value = 0 + + runner = CliRunner() + runner.invoke(exec_cmd, ["123", "--shell", "zsh", "--", "echo", "hi"]) + + mock_asyncio.run.assert_called_once() + + +# =========================================================================== +# pyte renderer: _char_to_sgr +# =========================================================================== + + +class TestCharToSgr: + def test_default_attrs_returns_empty(self): + from centml.cli.shell import _char_to_sgr + + char = pyte.screens.Char( + " ", "default", "default", False, False, False, False, False, False + ) + assert _char_to_sgr(char) == "" + + def test_bold_red_fg(self): + from centml.cli.shell import _char_to_sgr + + char = pyte.screens.Char( + "x", "red", "default", True, False, False, False, False, False + ) + sgr = _char_to_sgr(char) + assert "1" in sgr.split(";") + assert "31" in sgr.split(";") + + def test_bg_color(self): + from centml.cli.shell import _char_to_sgr + + char = pyte.screens.Char( + "x", "default", "blue", False, False, False, False, False, False + ) + sgr = _char_to_sgr(char) + assert "44" in sgr.split(";") + + def test_256_color_fg(self): + from centml.cli.shell import _char_to_sgr + + char = pyte.screens.Char( + "x", "ff0000", "default", False, False, False, False, False, False + ) + sgr = _char_to_sgr(char) + assert "38;2;255;0;0" in sgr + + def test_combined_attrs(self): + from centml.cli.shell import _char_to_sgr + + char = pyte.screens.Char( + "x", "green", "white", True, True, True, False, False, False + ) + sgr = _char_to_sgr(char) + parts = sgr.split(";") + assert "1" in parts # bold + assert "3" in parts # italics + assert "4" in parts # underscore + assert "32" in parts # green fg + assert "47" in parts # white bg + + +# =========================================================================== +# pyte renderer: _render_dirty +# =========================================================================== + + +class TestRenderDirty: + def test_renders_simple_text(self): + from centml.cli.shell import _render_dirty + + screen = pyte.Screen(40, 5) + stream = pyte.Stream(screen) + screen.dirty.clear() + stream.feed("hello") + buf = io.BytesIO() + _render_dirty(screen, buf) + output = buf.getvalue().decode("utf-8") + assert "hello" in output + assert len(screen.dirty) == 0 + + def test_clears_dirty_after_render(self): + from centml.cli.shell import _render_dirty + + screen = pyte.Screen(40, 5) + stream = pyte.Stream(screen) + screen.dirty.clear() + stream.feed("test") + assert len(screen.dirty) > 0 + _render_dirty(screen, io.BytesIO()) + assert len(screen.dirty) == 0 + + def test_cursor_position_in_output(self): + from centml.cli.shell import _render_dirty + + screen = pyte.Screen(40, 5) + stream = pyte.Stream(screen) + stream.feed("abc") + buf = io.BytesIO() + _render_dirty(screen, buf) + output = buf.getvalue().decode("utf-8") + # Cursor should be at row 1, col 4 (1-based: after "abc") + assert "\033[1;4H" in output + + def test_renders_only_dirty_lines(self): + from centml.cli.shell import _render_dirty + + screen = pyte.Screen(40, 5) + stream = pyte.Stream(screen) + stream.feed("line0\r\nline1\r\nline2") + # Render to clear dirty + _render_dirty(screen, io.BytesIO()) + # Now modify only line 0 + stream.feed("\033[1;1Hchanged") + buf = io.BytesIO() + _render_dirty(screen, buf) + output = buf.getvalue().decode("utf-8") + assert "changed" in output + # line1 and line2 should NOT be re-rendered + assert "line1" not in output + assert "line2" not in output + + +# =========================================================================== +# _forward_io -- exit detection and shutdown +# =========================================================================== + + +class TestForwardIo: + """Tests for _forward_io exit detection via idle timeout. + + Uses a real pipe fd so ``loop.add_reader`` works without OS errors. + """ + + def _run_forward_io(self, ws, shutdown=None): + """Helper: run _forward_io with a real pipe fd standing in for stdin.""" + from centml.cli.shell import _forward_io + + import websockets as _ws_lib + + screen = pyte.Screen(80, 24) + stream = pyte.Stream(screen) + if shutdown is None: + shutdown = asyncio.Event() + + read_fd, write_fd = os.pipe() + os.close(write_fd) + try: + with patch("centml.cli.shell.sys") as mock_sys, patch( + "centml.cli.shell.websockets" + ) as mock_ws_mod: + mock_sys.stdin.fileno.return_value = read_fd + mock_sys.stdin.buffer.read1 = lambda n: b"" + mock_sys.stdout.buffer = io.BytesIO() + mock_ws_mod.ConnectionClosed = _ws_lib.ConnectionClosed + + return asyncio.run(_forward_io(ws, screen, stream, shutdown)) + finally: + os.close(read_fd) + + def test_connection_closed_returns_zero(self): + """ConnectionClosed returns 0.""" + import websockets as _ws_lib + + ws = AsyncMock() + ws.recv = AsyncMock(side_effect=_ws_lib.ConnectionClosed(None, None)) + + assert self._run_forward_io(ws) == 0 + + def test_exit_echo_at_end_exits_immediately(self): + """'exit\\r\\n' at end of data (no trailing prompt) exits immediately.""" + ws = AsyncMock() + ws.recv = AsyncMock( + side_effect=[ + json.dumps({"data": "\r\n\x1b[?2004l\rexit\r\n"}), + # Should never be called -- _read_ws returns before this. + json.dumps({"data": "should not reach"}), + ] + ) + + assert self._run_forward_io(ws) == 0 + # Only one recv call -- exited immediately after exit echo. + assert ws.recv.call_count == 1 + + def test_exit_echo_with_prompt_continues(self): + """'exit\\r\\n' followed by a new prompt is not a real exit.""" + import websockets as _ws_lib + + ws = AsyncMock() + ws.recv = AsyncMock( + side_effect=[ + # ``echo exit`` -- exit echo with prompt trailing. + json.dumps({"data": "\r\n\x1b[?2004l\rexit\r\n\x1b[?2004huser@host:~$ "}), + _ws_lib.ConnectionClosed(None, None), + ] + ) + + assert self._run_forward_io(ws) == 0 + # Both recv calls made -- did not exit after the first message. + assert ws.recv.call_count == 2 + + def test_normal_data_no_early_exit(self): + """Data without 'exit\\r\\n' does not trigger early exit.""" + import websockets as _ws_lib + + ws = AsyncMock() + ws.recv = AsyncMock( + side_effect=[ + json.dumps({"data": "hello\r\n"}), + _ws_lib.ConnectionClosed(None, None), + ] + ) + + assert self._run_forward_io(ws) == 0 + + def test_shutdown_event_exits(self): + """shutdown event causes _forward_io to exit.""" + from centml.cli.shell import _forward_io + + import websockets as _ws_lib + + ws = AsyncMock() + + # recv that blocks until cancelled (simulates open WS with no data) + async def _block_recv(): + await asyncio.sleep(999) + + ws.recv = _block_recv + + screen = pyte.Screen(80, 24) + stream = pyte.Stream(screen) + shutdown = asyncio.Event() + + read_fd, write_fd = os.pipe() + os.close(write_fd) + try: + with patch("centml.cli.shell.sys") as mock_sys, patch( + "centml.cli.shell.websockets" + ) as mock_ws_mod: + mock_sys.stdin.fileno.return_value = read_fd + mock_sys.stdin.buffer.read1 = lambda n: b"" + mock_sys.stdout.buffer = io.BytesIO() + mock_ws_mod.ConnectionClosed = _ws_lib.ConnectionClosed + + async def _run(): + async def _set_shutdown(): + await asyncio.sleep(0.1) + shutdown.set() + + asyncio.create_task(_set_shutdown()) + return await _forward_io(ws, screen, stream, shutdown) + + assert asyncio.run(_run()) == 0 + finally: + os.close(read_fd) + + +# =========================================================================== +# _interactive_session -- signal handling +# =========================================================================== + + +class TestInteractiveSessionSignals: + """Tests for SIGTERM/SIGHUP restoring terminal settings.""" + + def test_sigterm_restores_terminal(self): + from centml.cli.shell import _interactive_session + + signal_handlers = {} + + def _fake_add_signal_handler(sig, handler): + signal_handlers[sig] = handler + + def _fake_remove_signal_handler(sig): + signal_handlers.pop(sig, None) + + async def _fake_forward_io(ws, screen, stream, shutdown): + if signal.SIGTERM in signal_handlers: + signal_handlers[signal.SIGTERM]() + return 0 + + with patch("centml.cli.shell.sys") as mock_sys, patch( + "centml.cli.shell.termios" + ) as mock_termios, patch("centml.cli.shell.tty"), patch( + "centml.cli.shell.websockets" + ) as mock_ws_mod, patch( + "centml.cli.shell._forward_io", side_effect=_fake_forward_io + ): + mock_sys.stdin.fileno.return_value = 0 + mock_sys.stdout.buffer = io.BytesIO() + mock_termios.tcgetattr.return_value = ["old_settings"] + + mock_ws = AsyncMock() + mock_ws_mod.connect = MagicMock( + return_value=AsyncMock( + __aenter__=AsyncMock(return_value=mock_ws), + __aexit__=AsyncMock(return_value=False), + ) + ) + + def _patched_run(coro): + loop = asyncio.new_event_loop() + loop.add_signal_handler = _fake_add_signal_handler + loop.remove_signal_handler = _fake_remove_signal_handler + try: + return loop.run_until_complete(coro) + finally: + loop.close() + + with patch("centml.cli.shell.asyncio") as mock_asyncio_mod: + mock_asyncio_mod.get_running_loop.return_value = MagicMock( + add_signal_handler=_fake_add_signal_handler, + remove_signal_handler=_fake_remove_signal_handler, + ) + mock_asyncio_mod.Event = asyncio.Event + mock_asyncio_mod.create_task = asyncio.ensure_future + + _patched_run( + _interactive_session("wss://test/ws", "fake-token") + ) + + mock_termios.tcsetattr.assert_called_once() + assert mock_termios.tcsetattr.call_args[0][2] == ["old_settings"]