From 44ee17bfee77cbdf29325a7f0011387f4ed8a133 Mon Sep 17 00:00:00 2001 From: Honglin Cao Date: Thu, 12 Mar 2026 14:13:05 -0400 Subject: [PATCH 01/28] feat: add CLI shell and exec commands for deployment pod terminal access Add two new commands under `centml cluster`: - `shell ` -- interactive terminal session (like docker exec -it) - `exec -- ` -- run a command and return output (like ssh host "cmd") Both connect via WebSocket through the Platform API terminal proxy, matching the same protocol used by the Web UI TerminalView. --- centml/cli/main.py | 3 + centml/cli/shell.py | 309 +++++++++++++++++++++++++++++++ centml/sdk/api.py | 3 + requirements.txt | 1 + tests/test_shell.py | 429 ++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 745 insertions(+) create mode 100644 centml/cli/shell.py create mode 100644 tests/test_shell.py diff --git a/centml/cli/main.py b/centml/cli/main.py index b1ecc73..fb5b932 100644 --- a/centml/cli/main.py +++ b/centml/cli/main.py @@ -2,6 +2,7 @@ from centml.cli.login import login, logout from centml.cli.cluster import ls, get, delete, pause, resume +from centml.cli.shell import shell, exec_cmd @click.group() @@ -47,6 +48,8 @@ def ccluster(): ccluster.add_command(delete) ccluster.add_command(pause) ccluster.add_command(resume) +ccluster.add_command(shell) +ccluster.add_command(exec_cmd, name="exec") cli.add_command(ccluster, name="cluster") diff --git a/centml/cli/shell.py b/centml/cli/shell.py new file mode 100644 index 0000000..3dbce60 --- /dev/null +++ b/centml/cli/shell.py @@ -0,0 +1,309 @@ +"""CLI commands for interactive shell and command execution in deployment pods.""" + +import asyncio +import json +import re +import shutil +import signal +import sys +import urllib.parse + +import click + +from centml.sdk import PodStatus +from centml.sdk import auth +from centml.sdk.api import get_centml_client +from centml.sdk.config import settings +from centml.cli.cluster import handle_exception + +# Lazy-import to keep module loadable without websockets installed at import time, +# and to allow tests to patch the module attribute easily. +import websockets + +# These are only available on Unix; guarded at command level via isatty check. +import termios +import tty + + +def _build_ws_url(api_url, deployment_id, pod_name, shell=None): + """Build the WebSocket URL for a terminal connection.""" + ws_base = api_url.replace("https://", "wss://").replace("http://", "ws://") + url = f"{ws_base}/deployments/{deployment_id}/terminal?pod={urllib.parse.quote(pod_name)}" + if shell: + url += f"&shell={urllib.parse.quote(shell)}" + return url + + +def _resolve_pod(cclient, deployment_id, pod_name=None): + """Resolve which pod to connect to. + + Args: + cclient: CentMLClient instance. + deployment_id: The deployment ID. + pod_name: Optional specific pod name to target. + + Returns: + The pod name string to connect to. + + Raises: + click.ClickException: If no running pods or specified pod not found. + """ + status = cclient.get_status_v3(deployment_id) + running_pods = [] + for revision in (status.revision_pod_details_list or []): + for pod in (revision.pod_details_list or []): + if pod.status == PodStatus.RUNNING and pod.name: + running_pods.append(pod.name) + + if not running_pods: + raise click.ClickException( + f"No running pods found for deployment {deployment_id}" + ) + + if pod_name is not None: + if pod_name not in running_pods: + pods_list = ", ".join(running_pods) + raise click.ClickException( + f"Pod '{pod_name}' not found. Available running pods: {pods_list}" + ) + return pod_name + + if len(running_pods) > 1: + click.echo( + f"Multiple running pods found, connecting to {running_pods[0]}. " + f"Use --pod to specify a different pod.", + err=True, + ) + return running_pods[0] + + +async def _forward_io(ws): + """Bidirectional forwarding between local stdin/stdout and WebSocket. + + Returns the remote exit code. + """ + loop = asyncio.get_running_loop() + exit_code = 0 + stdin_fd = sys.stdin.fileno() + + stdin_closed = asyncio.Event() + + async def _read_ws(): + nonlocal exit_code + async for raw_msg in ws: + msg = json.loads(raw_msg) + if msg.get("data"): + sys.stdout.buffer.write(msg["data"].encode("utf-8", errors="replace")) + sys.stdout.buffer.flush() + elif msg.get("error"): + sys.stderr.write(f"Error: {msg['error']}\n") + sys.stderr.flush() + if "Code" in msg: + exit_code = msg["Code"] + return + + async def _read_stdin(): + read_queue = asyncio.Queue() + + def _on_stdin_ready(): + data = sys.stdin.buffer.read1(4096) + if data: + read_queue.put_nowait(data) + else: + stdin_closed.set() + + loop.add_reader(stdin_fd, _on_stdin_ready) + try: + while not stdin_closed.is_set(): + try: + data = await asyncio.wait_for(read_queue.get(), timeout=0.5) + except asyncio.TimeoutError: + continue + rows, cols = shutil.get_terminal_size() + await ws.send(json.dumps({ + "operation": "stdin", + "data": data.decode("utf-8", errors="replace"), + "rows": rows, + "cols": cols, + })) + finally: + loop.remove_reader(stdin_fd) + + tasks = [ + asyncio.create_task(_read_ws()), + asyncio.create_task(_read_stdin()), + ] + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + for t in pending: + t.cancel() + for t in done: + if t.exception() is not None: + raise t.exception() + return exit_code + + +async def _interactive_session(ws_url, token): + """Run an interactive terminal session over WebSocket. + + Enters raw mode, forwards I/O bidirectionally, and restores terminal on exit. + """ + fd = sys.stdin.fileno() + old_settings = termios.tcgetattr(fd) + try: + tty.setraw(fd) + rows, cols = shutil.get_terminal_size() + + headers = {"Authorization": f"Bearer {token}"} + async with websockets.connect(ws_url, additional_headers=headers) as ws: + await ws.send(json.dumps({ + "operation": "resize", + "rows": rows, + "cols": cols, + })) + + loop = asyncio.get_running_loop() + + def _on_resize(): + r, c = shutil.get_terminal_size() + asyncio.ensure_future(ws.send(json.dumps({ + "operation": "resize", + "rows": r, + "cols": c, + }))) + + loop.add_signal_handler(signal.SIGWINCH, _on_resize) + + try: + exit_code = await _forward_io(ws) + finally: + loop.remove_signal_handler(signal.SIGWINCH) + + return exit_code + finally: + termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) + + +_ANSI_ESCAPE_RE = re.compile(r"\x1b\[[0-9;?]*[a-zA-Z]|\x1b\].*?\x07|\x1b\[.*?\x1b\\") + +_BEGIN_MARKER = "__CENTML_BEGIN_5f3a__" +_END_MARKER = "__CENTML_END_5f3a__" + +# printf octal \137 = underscore. The decoded output matches _BEGIN/_END_MARKER, +# but the literal command text does NOT, so shell echo won't trigger false matches. +_PRINTF_BEGIN = r"\137\137CENTML_BEGIN_5f3a\137\137" +_PRINTF_END = r"\137\137CENTML_END_5f3a\137\137" + + +def _strip_ansi(text): + """Remove ANSI escape sequences from text.""" + return _ANSI_ESCAPE_RE.sub("", text) + + +async def _exec_session(ws_url, token, command): + """Execute a command in a pod and return its exit code. + + Does not enter raw mode -- output is pipe-friendly. + Suppresses shell echo and uses markers to capture only command output. + """ + rows, cols = shutil.get_terminal_size(fallback=(80, 24)) + headers = {"Authorization": f"Bearer {token}"} + + async with websockets.connect(ws_url, additional_headers=headers) as ws: + await ws.send(json.dumps({ + "operation": "resize", + "rows": rows, + "cols": cols, + })) + + # Suppress echo/bracketed-paste, emit begin marker, run command, + # emit end marker with exit code, then exit. + # Markers use printf octal escapes so the literal marker string + # doesn't appear in the command echo. + wrapped = ( + f"stty -echo 2>/dev/null; printf '\\033[?2004l';" + f" printf '{_PRINTF_BEGIN}\\n';" + f" {command};" + f" __ec=$?;" + f" printf '\\n{_PRINTF_END}:%d\\n' \"$__ec\";" + f" exit $__ec\n" + ) + + await ws.send(json.dumps({ + "operation": "stdin", + "data": wrapped, + "rows": rows, + "cols": cols, + })) + + exit_code = 0 + buffer = "" + is_capturing = False + async for raw_msg in ws: + msg = json.loads(raw_msg) + if msg.get("data"): + buffer += msg["data"] + # Process complete lines from buffer + while "\n" in buffer: + line, buffer = buffer.split("\n", 1) + clean = _strip_ansi(line).rstrip("\r") + if _BEGIN_MARKER in clean: + is_capturing = True + continue + if _END_MARKER in clean: + # Parse exit code from marker line + parts = clean.split(_END_MARKER + ":") + if len(parts) > 1: + try: + exit_code = int(parts[1].strip()) + except ValueError: + pass + is_capturing = False + continue + if is_capturing: + sys.stdout.write(line + "\n") + sys.stdout.flush() + elif msg.get("error"): + sys.stderr.write(f"Error: {msg['error']}\n") + return 1 + if "Code" in msg: + exit_code = msg["Code"] + break + return exit_code + + +@click.command(help="Open an interactive shell to a deployment pod") +@click.argument("deployment_id", type=int) +@click.option("--pod", default=None, help="Specific pod name (auto-selects first running pod)") +@click.option("--shell", "shell_type", default=None, + type=click.Choice(["bash", "sh", "zsh"]), help="Shell type") +@handle_exception +def shell(deployment_id, pod, shell_type): + if not sys.stdin.isatty(): + raise click.ClickException("Interactive shell requires a terminal (TTY)") + + with get_centml_client() as cclient: + pod_name = _resolve_pod(cclient, deployment_id, pod) + + ws_url = _build_ws_url(settings.CENTML_PLATFORM_API_URL, deployment_id, pod_name, shell_type) + token = auth.get_centml_token() + exit_code = asyncio.run(_interactive_session(ws_url, token)) + sys.exit(exit_code) + + +@click.command(help="Execute a command in a deployment pod", + context_settings=dict(ignore_unknown_options=True)) +@click.argument("deployment_id", type=int) +@click.argument("command", nargs=-1, required=True, type=click.UNPROCESSED) +@click.option("--pod", default=None, help="Specific pod name") +@click.option("--shell", "shell_type", default=None, + type=click.Choice(["bash", "sh", "zsh"]), help="Shell type") +@handle_exception +def exec_cmd(deployment_id, command, pod, shell_type): + with get_centml_client() as cclient: + pod_name = _resolve_pod(cclient, deployment_id, pod) + + ws_url = _build_ws_url(settings.CENTML_PLATFORM_API_URL, deployment_id, pod_name, shell_type) + token = auth.get_centml_token() + cmd_str = " ".join(command) + exit_code = asyncio.run(_exec_session(ws_url, token, cmd_str)) + sys.exit(exit_code) diff --git a/centml/sdk/api.py b/centml/sdk/api.py index e1e11d3..20dfa99 100644 --- a/centml/sdk/api.py +++ b/centml/sdk/api.py @@ -27,6 +27,9 @@ def get(self, depl_type): def get_status(self, id): return self._api.get_deployment_status_deployments_status_deployment_id_get(id) + def get_status_v3(self, deployment_id): + return self._api.get_deployment_status_v3_deployments_status_v3_deployment_id_get(deployment_id) + def get_inference(self, id): """Get Inference deployment details - automatically handles both V2 and V3 deployments""" # Try V3 first (recommended), fallback to V2 if deployment is V2 diff --git a/requirements.txt b/requirements.txt index c3b4961..133ee99 100644 --- a/requirements.txt +++ b/requirements.txt @@ -9,4 +9,5 @@ cryptography==44.0.1 prometheus-client>=0.20.0 scipy>=1.6.0 scikit-learn>=1.5.1 +websockets>=13.0 platform-api-python-client==4.6.0 diff --git a/tests/test_shell.py b/tests/test_shell.py new file mode 100644 index 0000000..dfb6b52 --- /dev/null +++ b/tests/test_shell.py @@ -0,0 +1,429 @@ +"""Tests for centml.cli.shell -- CLI terminal access commands.""" + +import json +import sys +import urllib.parse +from unittest.mock import AsyncMock, MagicMock, patch, call + +import click +import pytest + +from platform_api_python_client import PodStatus, PodDetails, RevisionPodDetails + + +def _async_iter_from_list(items): + """Create an async iterator from a list of items.""" + async def _aiter(): + for item in items: + yield item + return _aiter() + + +# --------------------------------------------------------------------------- +# Helpers to build mock status responses +# --------------------------------------------------------------------------- + +def _make_pod(name, status=PodStatus.RUNNING): + pod = MagicMock(spec=PodDetails) + pod.name = name + pod.status = status + return pod + + +def _make_revision(pods): + rev = MagicMock(spec=RevisionPodDetails) + rev.pod_details_list = pods + return rev + + +def _make_status_response(revisions): + resp = MagicMock() + resp.revision_pod_details_list = revisions + return resp + + +# =========================================================================== +# _build_ws_url +# =========================================================================== + +class TestStripAnsi: + def test_strips_csi_sequences(self): + from centml.cli.shell import _strip_ansi + assert _strip_ansi("\x1b[?2004htext\x1b[0m") == "text" + + def test_preserves_plain_text(self): + from centml.cli.shell import _strip_ansi + assert _strip_ansi("hello world") == "hello world" + + +class TestBuildWsUrl: + def test_https_to_wss(self): + from centml.cli.shell import _build_ws_url + url = _build_ws_url("https://api.centml.com", 123, "my-pod-abc") + assert url.startswith("wss://api.centml.com/") + + def test_http_to_ws(self): + from centml.cli.shell import _build_ws_url + url = _build_ws_url("http://localhost:16000", 42, "pod-1") + assert url.startswith("ws://localhost:16000/") + + def test_contains_deployment_id_and_pod(self): + from centml.cli.shell import _build_ws_url + url = _build_ws_url("https://api.centml.com", 99, "pod-xyz") + assert "/deployments/99/terminal" in url + assert "pod=pod-xyz" in url + + def test_with_shell(self): + from centml.cli.shell import _build_ws_url + url = _build_ws_url("https://api.centml.com", 1, "p", shell="bash") + assert "shell=bash" in url + + def test_without_shell(self): + from centml.cli.shell import _build_ws_url + url = _build_ws_url("https://api.centml.com", 1, "p") + assert "shell=" not in url + + def test_encodes_pod_name(self): + from centml.cli.shell import _build_ws_url + url = _build_ws_url("https://api.centml.com", 1, "pod name/special") + assert "pod%20name" in url or "pod+name" in url + + +# =========================================================================== +# _resolve_pod +# =========================================================================== + +class TestResolvePod: + def test_selects_first_running(self): + from centml.cli.shell import _resolve_pod + cclient = MagicMock() + cclient.get_status_v3.return_value = _make_status_response([ + _make_revision([ + _make_pod("pod-a", PodStatus.RUNNING), + _make_pod("pod-b", PodStatus.RUNNING), + ]) + ]) + result = _resolve_pod(cclient, 1) + assert result == "pod-a" + + def test_raises_no_running_pods(self): + from centml.cli.shell import _resolve_pod + cclient = MagicMock() + cclient.get_status_v3.return_value = _make_status_response([ + _make_revision([ + _make_pod("pod-err", PodStatus.ERROR), + ]) + ]) + with pytest.raises(click.ClickException, match="No running pods"): + _resolve_pod(cclient, 1) + + def test_raises_specified_pod_not_found(self): + from centml.cli.shell import _resolve_pod + cclient = MagicMock() + cclient.get_status_v3.return_value = _make_status_response([ + _make_revision([ + _make_pod("pod-a", PodStatus.RUNNING), + ]) + ]) + with pytest.raises(click.ClickException, match="pod-missing"): + _resolve_pod(cclient, 1, pod_name="pod-missing") + + def test_returns_specified_pod(self): + from centml.cli.shell import _resolve_pod + cclient = MagicMock() + cclient.get_status_v3.return_value = _make_status_response([ + _make_revision([ + _make_pod("pod-a", PodStatus.RUNNING), + _make_pod("pod-b", PodStatus.RUNNING), + ]) + ]) + result = _resolve_pod(cclient, 1, pod_name="pod-b") + assert result == "pod-b" + + def test_empty_revision_list(self): + from centml.cli.shell import _resolve_pod + cclient = MagicMock() + cclient.get_status_v3.return_value = _make_status_response([]) + with pytest.raises(click.ClickException, match="No running pods"): + _resolve_pod(cclient, 1) + + def test_none_revision_list(self): + from centml.cli.shell import _resolve_pod + cclient = MagicMock() + cclient.get_status_v3.return_value = _make_status_response(None) + cclient.get_status_v3.return_value.revision_pod_details_list = None + with pytest.raises(click.ClickException, match="No running pods"): + _resolve_pod(cclient, 1) + + def test_skips_pods_without_name(self): + from centml.cli.shell import _resolve_pod + cclient = MagicMock() + cclient.get_status_v3.return_value = _make_status_response([ + _make_revision([ + _make_pod(None, PodStatus.RUNNING), + _make_pod("pod-real", PodStatus.RUNNING), + ]) + ]) + result = _resolve_pod(cclient, 1) + assert result == "pod-real" + + def test_multiple_revisions(self): + from centml.cli.shell import _resolve_pod + cclient = MagicMock() + cclient.get_status_v3.return_value = _make_status_response([ + _make_revision([ + _make_pod("pod-old", PodStatus.ERROR), + ]), + _make_revision([ + _make_pod("pod-new", PodStatus.RUNNING), + ]), + ]) + result = _resolve_pod(cclient, 1) + assert result == "pod-new" + + +# =========================================================================== +# _exec_session +# =========================================================================== + +class TestExecSession: + @pytest.mark.asyncio + async def test_sends_resize_and_wrapped_command(self): + from centml.cli.shell import _exec_session, _BEGIN_MARKER, _END_MARKER + + ws = AsyncMock() + messages = [ + json.dumps({"data": f"noise\n{_BEGIN_MARKER}\nhello world\n{_END_MARKER}:0\n"}), + json.dumps({"Code": 0}), + ] + ws.__aiter__ = MagicMock(return_value=_async_iter_from_list(messages)) + + with patch("centml.cli.shell.websockets") as mock_ws_mod: + mock_ws_mod.connect = MagicMock(return_value=AsyncMock( + __aenter__=AsyncMock(return_value=ws), + __aexit__=AsyncMock(return_value=False), + )) + + exit_code = await _exec_session("wss://test/ws", "fake-token", "ls -la") + + assert exit_code == 0 + assert ws.send.call_count == 2 + resize_msg = json.loads(ws.send.call_args_list[0][0][0]) + assert resize_msg["operation"] == "resize" + cmd_msg = json.loads(ws.send.call_args_list[1][0][0]) + assert cmd_msg["operation"] == "stdin" + assert "ls -la" in cmd_msg["data"] + assert "stty -echo" in cmd_msg["data"] + # Markers use printf octal escapes, so the literal marker + # should NOT appear in the command (prevents echo false-match). + assert _BEGIN_MARKER not in cmd_msg["data"] + assert "CENTML_BEGIN" in cmd_msg["data"] + + @pytest.mark.asyncio + async def test_returns_nonzero_exit_code_from_marker(self): + from centml.cli.shell import _exec_session, _BEGIN_MARKER, _END_MARKER + + ws = AsyncMock() + messages = [ + json.dumps({"data": f"{_BEGIN_MARKER}\n{_END_MARKER}:42\n"}), + json.dumps({"Code": 42}), + ] + ws.__aiter__ = MagicMock(return_value=_async_iter_from_list(messages)) + + with patch("centml.cli.shell.websockets") as mock_ws_mod: + mock_ws_mod.connect = MagicMock(return_value=AsyncMock( + __aenter__=AsyncMock(return_value=ws), + __aexit__=AsyncMock(return_value=False), + )) + + exit_code = await _exec_session("wss://test/ws", "fake-token", "false") + + assert exit_code == 42 + + @pytest.mark.asyncio + async def test_error_message_returns_one(self): + from centml.cli.shell import _exec_session + + ws = AsyncMock() + messages = [ + json.dumps({"error": "something went wrong"}), + ] + ws.__aiter__ = MagicMock(return_value=_async_iter_from_list(messages)) + + with patch("centml.cli.shell.websockets") as mock_ws_mod: + mock_ws_mod.connect = MagicMock(return_value=AsyncMock( + __aenter__=AsyncMock(return_value=ws), + __aexit__=AsyncMock(return_value=False), + )) + + exit_code = await _exec_session("wss://test/ws", "fake-token", "bad") + + assert exit_code == 1 + + @pytest.mark.asyncio + async def test_filters_noise_before_marker(self): + """Only output between BEGIN and END markers is written to stdout.""" + from centml.cli.shell import _exec_session, _BEGIN_MARKER, _END_MARKER + + ws = AsyncMock() + messages = [ + json.dumps({"data": f"prompt$ command\n{_BEGIN_MARKER}\nreal output\n{_END_MARKER}:0\n"}), + json.dumps({"Code": 0}), + ] + ws.__aiter__ = MagicMock(return_value=_async_iter_from_list(messages)) + + captured = [] + with patch("centml.cli.shell.websockets") as mock_ws_mod, \ + patch("centml.cli.shell.sys") as mock_sys: + mock_ws_mod.connect = MagicMock(return_value=AsyncMock( + __aenter__=AsyncMock(return_value=ws), + __aexit__=AsyncMock(return_value=False), + )) + mock_sys.stdout.write = lambda s: captured.append(s) + mock_sys.stdout.flush = MagicMock() + mock_sys.stderr.write = MagicMock() + + exit_code = await _exec_session("wss://test/ws", "fake-token", "echo test") + + assert exit_code == 0 + output = "".join(captured) + assert "real output" in output + assert "prompt$" not in output + + +# =========================================================================== +# _interactive_session -- terminal restore +# =========================================================================== + +class TestInteractiveSessionTerminalRestore: + @pytest.mark.asyncio + async def test_restores_terminal_on_exception(self): + from centml.cli.shell import _interactive_session + + with patch("centml.cli.shell.sys") as mock_sys, \ + patch("centml.cli.shell.termios") as mock_termios, \ + patch("centml.cli.shell.tty") as mock_tty, \ + patch("centml.cli.shell.websockets") as mock_ws_mod: + + mock_sys.stdin.fileno.return_value = 0 + mock_termios.tcgetattr.return_value = ["old_settings"] + + mock_ws_mod.connect = MagicMock(return_value=AsyncMock( + __aenter__=AsyncMock(side_effect=ConnectionRefusedError("fail")), + __aexit__=AsyncMock(return_value=False), + )) + + with pytest.raises(ConnectionRefusedError): + await _interactive_session("wss://test/ws", "fake-token") + + mock_termios.tcsetattr.assert_called_once() + restore_call = mock_termios.tcsetattr.call_args + assert restore_call[0][2] == ["old_settings"] + + +# =========================================================================== +# Click commands +# =========================================================================== + +class TestShellCommand: + def test_rejects_non_tty(self): + from centml.cli.shell import shell + from click.testing import CliRunner + runner = CliRunner() + result = runner.invoke(shell, ["123"]) + assert result.exit_code != 0 + assert "terminal" in result.output.lower() or "tty" in result.output.lower() + + def test_shell_option_forwarded(self): + from centml.cli.shell import shell + from click.testing import CliRunner + + with patch("centml.cli.shell._resolve_pod", return_value="pod-a"), \ + patch("centml.cli.shell.get_centml_client") as mock_ctx, \ + patch("centml.cli.shell.auth") as mock_auth, \ + patch("centml.cli.shell.settings") as mock_settings, \ + patch("centml.cli.shell.asyncio") as mock_asyncio, \ + patch("centml.cli.shell.sys") as mock_sys: + + mock_ctx.return_value.__enter__ = MagicMock(return_value=MagicMock()) + mock_ctx.return_value.__exit__ = MagicMock(return_value=False) + mock_auth.get_centml_token.return_value = "token" + mock_settings.CENTML_PLATFORM_API_URL = "https://api.centml.com" + mock_sys.stdin.isatty.return_value = True + mock_asyncio.run.return_value = 0 + + runner = CliRunner() + result = runner.invoke(shell, ["123", "--shell", "bash"]) + + # Verify asyncio.run was called, and the URL contains shell=bash + mock_asyncio.run.assert_called_once() + + def test_pod_option_forwarded(self): + from centml.cli.shell import shell + from click.testing import CliRunner + + with patch("centml.cli.shell._resolve_pod") as mock_resolve, \ + patch("centml.cli.shell.get_centml_client") as mock_ctx, \ + patch("centml.cli.shell.auth") as mock_auth, \ + patch("centml.cli.shell.settings") as mock_settings, \ + patch("centml.cli.shell.asyncio") as mock_asyncio, \ + patch("centml.cli.shell.sys") as mock_sys: + + mock_ctx.return_value.__enter__ = MagicMock(return_value=MagicMock()) + mock_ctx.return_value.__exit__ = MagicMock(return_value=False) + mock_resolve.return_value = "my-pod" + mock_auth.get_centml_token.return_value = "token" + mock_settings.CENTML_PLATFORM_API_URL = "https://api.centml.com" + mock_sys.stdin.isatty.return_value = True + mock_asyncio.run.return_value = 0 + + runner = CliRunner() + result = runner.invoke(shell, ["123", "--pod", "my-pod"]) + + mock_resolve.assert_called_once() + assert mock_resolve.call_args[1].get("pod_name") == "my-pod" or \ + mock_resolve.call_args[0][2] == "my-pod" + + +class TestExecCommand: + def test_passes_command(self): + from centml.cli.shell import exec_cmd + from click.testing import CliRunner + + with patch("centml.cli.shell._resolve_pod", return_value="pod-a"), \ + patch("centml.cli.shell.get_centml_client") as mock_ctx, \ + patch("centml.cli.shell.auth") as mock_auth, \ + patch("centml.cli.shell.settings") as mock_settings, \ + patch("centml.cli.shell.asyncio") as mock_asyncio: + + mock_ctx.return_value.__enter__ = MagicMock(return_value=MagicMock()) + mock_ctx.return_value.__exit__ = MagicMock(return_value=False) + mock_auth.get_centml_token.return_value = "token" + mock_settings.CENTML_PLATFORM_API_URL = "https://api.centml.com" + mock_asyncio.run.return_value = 0 + + runner = CliRunner() + result = runner.invoke(exec_cmd, ["123", "--", "ls", "-la"]) + + mock_asyncio.run.assert_called_once() + + def test_shell_option_forwarded(self): + from centml.cli.shell import exec_cmd + from click.testing import CliRunner + + with patch("centml.cli.shell._resolve_pod", return_value="pod-a"), \ + patch("centml.cli.shell.get_centml_client") as mock_ctx, \ + patch("centml.cli.shell.auth") as mock_auth, \ + patch("centml.cli.shell.settings") as mock_settings, \ + patch("centml.cli.shell.asyncio") as mock_asyncio: + + mock_ctx.return_value.__enter__ = MagicMock(return_value=MagicMock()) + mock_ctx.return_value.__exit__ = MagicMock(return_value=False) + mock_auth.get_centml_token.return_value = "token" + mock_settings.CENTML_PLATFORM_API_URL = "https://api.centml.com" + mock_asyncio.run.return_value = 0 + + runner = CliRunner() + result = runner.invoke(exec_cmd, ["123", "--shell", "zsh", "--", "echo", "hi"]) + + mock_asyncio.run.assert_called_once() From c62f890307c63da5a00d01a17d132ebb290b3fe3 Mon Sep 17 00:00:00 2001 From: Honglin Cao Date: Thu, 12 Mar 2026 14:14:55 -0400 Subject: [PATCH 02/28] fix: use urlparse for scheme replacement to satisfy CodeQL Replaces str.replace("https://", "wss://") with urllib.parse.urlparse scheme replacement to avoid CodeQL py/incomplete-url-substring-sanitization. --- centml/cli/shell.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/centml/cli/shell.py b/centml/cli/shell.py index 3dbce60..0365c7f 100644 --- a/centml/cli/shell.py +++ b/centml/cli/shell.py @@ -27,7 +27,9 @@ def _build_ws_url(api_url, deployment_id, pod_name, shell=None): """Build the WebSocket URL for a terminal connection.""" - ws_base = api_url.replace("https://", "wss://").replace("http://", "ws://") + parsed = urllib.parse.urlparse(api_url) + ws_scheme = "wss" if parsed.scheme == "https" else "ws" + ws_base = parsed._replace(scheme=ws_scheme).geturl() url = f"{ws_base}/deployments/{deployment_id}/terminal?pod={urllib.parse.quote(pod_name)}" if shell: url += f"&shell={urllib.parse.quote(shell)}" From 095bd1e1469ab544cee8781562b3e6fa1726317c Mon Sep 17 00:00:00 2001 From: Honglin Cao Date: Thu, 12 Mar 2026 14:21:31 -0400 Subject: [PATCH 03/28] fix: apply black formatting and fix CodeQL url.startswith alert Replace url.startswith() assertions with urllib.parse.urlparse() checks to satisfy CodeQL py/incomplete-url-substring-sanitization rule. Reformat both shell.py and test_shell.py with black. --- centml/cli/shell.py | 116 ++++++++++++------- tests/test_shell.py | 269 +++++++++++++++++++++++++++++--------------- 2 files changed, 256 insertions(+), 129 deletions(-) diff --git a/centml/cli/shell.py b/centml/cli/shell.py index 0365c7f..837d6d9 100644 --- a/centml/cli/shell.py +++ b/centml/cli/shell.py @@ -52,8 +52,8 @@ def _resolve_pod(cclient, deployment_id, pod_name=None): """ status = cclient.get_status_v3(deployment_id) running_pods = [] - for revision in (status.revision_pod_details_list or []): - for pod in (revision.pod_details_list or []): + for revision in status.revision_pod_details_list or []: + for pod in revision.pod_details_list or []: if pod.status == PodStatus.RUNNING and pod.name: running_pods.append(pod.name) @@ -122,12 +122,16 @@ def _on_stdin_ready(): except asyncio.TimeoutError: continue rows, cols = shutil.get_terminal_size() - await ws.send(json.dumps({ - "operation": "stdin", - "data": data.decode("utf-8", errors="replace"), - "rows": rows, - "cols": cols, - })) + await ws.send( + json.dumps( + { + "operation": "stdin", + "data": data.decode("utf-8", errors="replace"), + "rows": rows, + "cols": cols, + } + ) + ) finally: loop.remove_reader(stdin_fd) @@ -157,21 +161,31 @@ async def _interactive_session(ws_url, token): headers = {"Authorization": f"Bearer {token}"} async with websockets.connect(ws_url, additional_headers=headers) as ws: - await ws.send(json.dumps({ - "operation": "resize", - "rows": rows, - "cols": cols, - })) + await ws.send( + json.dumps( + { + "operation": "resize", + "rows": rows, + "cols": cols, + } + ) + ) loop = asyncio.get_running_loop() def _on_resize(): r, c = shutil.get_terminal_size() - asyncio.ensure_future(ws.send(json.dumps({ - "operation": "resize", - "rows": r, - "cols": c, - }))) + asyncio.ensure_future( + ws.send( + json.dumps( + { + "operation": "resize", + "rows": r, + "cols": c, + } + ) + ) + ) loop.add_signal_handler(signal.SIGWINCH, _on_resize) @@ -211,11 +225,15 @@ async def _exec_session(ws_url, token, command): headers = {"Authorization": f"Bearer {token}"} async with websockets.connect(ws_url, additional_headers=headers) as ws: - await ws.send(json.dumps({ - "operation": "resize", - "rows": rows, - "cols": cols, - })) + await ws.send( + json.dumps( + { + "operation": "resize", + "rows": rows, + "cols": cols, + } + ) + ) # Suppress echo/bracketed-paste, emit begin marker, run command, # emit end marker with exit code, then exit. @@ -230,12 +248,16 @@ async def _exec_session(ws_url, token, command): f" exit $__ec\n" ) - await ws.send(json.dumps({ - "operation": "stdin", - "data": wrapped, - "rows": rows, - "cols": cols, - })) + await ws.send( + json.dumps( + { + "operation": "stdin", + "data": wrapped, + "rows": rows, + "cols": cols, + } + ) + ) exit_code = 0 buffer = "" @@ -275,9 +297,16 @@ async def _exec_session(ws_url, token, command): @click.command(help="Open an interactive shell to a deployment pod") @click.argument("deployment_id", type=int) -@click.option("--pod", default=None, help="Specific pod name (auto-selects first running pod)") -@click.option("--shell", "shell_type", default=None, - type=click.Choice(["bash", "sh", "zsh"]), help="Shell type") +@click.option( + "--pod", default=None, help="Specific pod name (auto-selects first running pod)" +) +@click.option( + "--shell", + "shell_type", + default=None, + type=click.Choice(["bash", "sh", "zsh"]), + help="Shell type", +) @handle_exception def shell(deployment_id, pod, shell_type): if not sys.stdin.isatty(): @@ -286,25 +315,36 @@ def shell(deployment_id, pod, shell_type): with get_centml_client() as cclient: pod_name = _resolve_pod(cclient, deployment_id, pod) - ws_url = _build_ws_url(settings.CENTML_PLATFORM_API_URL, deployment_id, pod_name, shell_type) + ws_url = _build_ws_url( + settings.CENTML_PLATFORM_API_URL, deployment_id, pod_name, shell_type + ) token = auth.get_centml_token() exit_code = asyncio.run(_interactive_session(ws_url, token)) sys.exit(exit_code) -@click.command(help="Execute a command in a deployment pod", - context_settings=dict(ignore_unknown_options=True)) +@click.command( + help="Execute a command in a deployment pod", + context_settings=dict(ignore_unknown_options=True), +) @click.argument("deployment_id", type=int) @click.argument("command", nargs=-1, required=True, type=click.UNPROCESSED) @click.option("--pod", default=None, help="Specific pod name") -@click.option("--shell", "shell_type", default=None, - type=click.Choice(["bash", "sh", "zsh"]), help="Shell type") +@click.option( + "--shell", + "shell_type", + default=None, + type=click.Choice(["bash", "sh", "zsh"]), + help="Shell type", +) @handle_exception def exec_cmd(deployment_id, command, pod, shell_type): with get_centml_client() as cclient: pod_name = _resolve_pod(cclient, deployment_id, pod) - ws_url = _build_ws_url(settings.CENTML_PLATFORM_API_URL, deployment_id, pod_name, shell_type) + ws_url = _build_ws_url( + settings.CENTML_PLATFORM_API_URL, deployment_id, pod_name, shell_type + ) token = auth.get_centml_token() cmd_str = " ".join(command) exit_code = asyncio.run(_exec_session(ws_url, token, cmd_str)) diff --git a/tests/test_shell.py b/tests/test_shell.py index dfb6b52..6e49fd7 100644 --- a/tests/test_shell.py +++ b/tests/test_shell.py @@ -13,9 +13,11 @@ def _async_iter_from_list(items): """Create an async iterator from a list of items.""" + async def _aiter(): for item in items: yield item + return _aiter() @@ -23,6 +25,7 @@ async def _aiter(): # Helpers to build mock status responses # --------------------------------------------------------------------------- + def _make_pod(name, status=PodStatus.RUNNING): pod = MagicMock(spec=PodDetails) pod.name = name @@ -46,45 +49,58 @@ def _make_status_response(revisions): # _build_ws_url # =========================================================================== + class TestStripAnsi: def test_strips_csi_sequences(self): from centml.cli.shell import _strip_ansi + assert _strip_ansi("\x1b[?2004htext\x1b[0m") == "text" def test_preserves_plain_text(self): from centml.cli.shell import _strip_ansi + assert _strip_ansi("hello world") == "hello world" class TestBuildWsUrl: def test_https_to_wss(self): from centml.cli.shell import _build_ws_url + url = _build_ws_url("https://api.centml.com", 123, "my-pod-abc") - assert url.startswith("wss://api.centml.com/") + parsed = urllib.parse.urlparse(url) + assert parsed.scheme == "wss" + assert parsed.netloc == "api.centml.com" def test_http_to_ws(self): from centml.cli.shell import _build_ws_url + url = _build_ws_url("http://localhost:16000", 42, "pod-1") - assert url.startswith("ws://localhost:16000/") + parsed = urllib.parse.urlparse(url) + assert parsed.scheme == "ws" + assert parsed.netloc == "localhost:16000" def test_contains_deployment_id_and_pod(self): from centml.cli.shell import _build_ws_url + url = _build_ws_url("https://api.centml.com", 99, "pod-xyz") assert "/deployments/99/terminal" in url assert "pod=pod-xyz" in url def test_with_shell(self): from centml.cli.shell import _build_ws_url + url = _build_ws_url("https://api.centml.com", 1, "p", shell="bash") assert "shell=bash" in url def test_without_shell(self): from centml.cli.shell import _build_ws_url + url = _build_ws_url("https://api.centml.com", 1, "p") assert "shell=" not in url def test_encodes_pod_name(self): from centml.cli.shell import _build_ws_url + url = _build_ws_url("https://api.centml.com", 1, "pod name/special") assert "pod%20name" in url or "pod+name" in url @@ -93,55 +109,77 @@ def test_encodes_pod_name(self): # _resolve_pod # =========================================================================== + class TestResolvePod: def test_selects_first_running(self): from centml.cli.shell import _resolve_pod + cclient = MagicMock() - cclient.get_status_v3.return_value = _make_status_response([ - _make_revision([ - _make_pod("pod-a", PodStatus.RUNNING), - _make_pod("pod-b", PodStatus.RUNNING), - ]) - ]) + cclient.get_status_v3.return_value = _make_status_response( + [ + _make_revision( + [ + _make_pod("pod-a", PodStatus.RUNNING), + _make_pod("pod-b", PodStatus.RUNNING), + ] + ) + ] + ) result = _resolve_pod(cclient, 1) assert result == "pod-a" def test_raises_no_running_pods(self): from centml.cli.shell import _resolve_pod + cclient = MagicMock() - cclient.get_status_v3.return_value = _make_status_response([ - _make_revision([ - _make_pod("pod-err", PodStatus.ERROR), - ]) - ]) + cclient.get_status_v3.return_value = _make_status_response( + [ + _make_revision( + [ + _make_pod("pod-err", PodStatus.ERROR), + ] + ) + ] + ) with pytest.raises(click.ClickException, match="No running pods"): _resolve_pod(cclient, 1) def test_raises_specified_pod_not_found(self): from centml.cli.shell import _resolve_pod + cclient = MagicMock() - cclient.get_status_v3.return_value = _make_status_response([ - _make_revision([ - _make_pod("pod-a", PodStatus.RUNNING), - ]) - ]) + cclient.get_status_v3.return_value = _make_status_response( + [ + _make_revision( + [ + _make_pod("pod-a", PodStatus.RUNNING), + ] + ) + ] + ) with pytest.raises(click.ClickException, match="pod-missing"): _resolve_pod(cclient, 1, pod_name="pod-missing") def test_returns_specified_pod(self): from centml.cli.shell import _resolve_pod + cclient = MagicMock() - cclient.get_status_v3.return_value = _make_status_response([ - _make_revision([ - _make_pod("pod-a", PodStatus.RUNNING), - _make_pod("pod-b", PodStatus.RUNNING), - ]) - ]) + cclient.get_status_v3.return_value = _make_status_response( + [ + _make_revision( + [ + _make_pod("pod-a", PodStatus.RUNNING), + _make_pod("pod-b", PodStatus.RUNNING), + ] + ) + ] + ) result = _resolve_pod(cclient, 1, pod_name="pod-b") assert result == "pod-b" def test_empty_revision_list(self): from centml.cli.shell import _resolve_pod + cclient = MagicMock() cclient.get_status_v3.return_value = _make_status_response([]) with pytest.raises(click.ClickException, match="No running pods"): @@ -149,6 +187,7 @@ def test_empty_revision_list(self): def test_none_revision_list(self): from centml.cli.shell import _resolve_pod + cclient = MagicMock() cclient.get_status_v3.return_value = _make_status_response(None) cclient.get_status_v3.return_value.revision_pod_details_list = None @@ -157,27 +196,39 @@ def test_none_revision_list(self): def test_skips_pods_without_name(self): from centml.cli.shell import _resolve_pod + cclient = MagicMock() - cclient.get_status_v3.return_value = _make_status_response([ - _make_revision([ - _make_pod(None, PodStatus.RUNNING), - _make_pod("pod-real", PodStatus.RUNNING), - ]) - ]) + cclient.get_status_v3.return_value = _make_status_response( + [ + _make_revision( + [ + _make_pod(None, PodStatus.RUNNING), + _make_pod("pod-real", PodStatus.RUNNING), + ] + ) + ] + ) result = _resolve_pod(cclient, 1) assert result == "pod-real" def test_multiple_revisions(self): from centml.cli.shell import _resolve_pod + cclient = MagicMock() - cclient.get_status_v3.return_value = _make_status_response([ - _make_revision([ - _make_pod("pod-old", PodStatus.ERROR), - ]), - _make_revision([ - _make_pod("pod-new", PodStatus.RUNNING), - ]), - ]) + cclient.get_status_v3.return_value = _make_status_response( + [ + _make_revision( + [ + _make_pod("pod-old", PodStatus.ERROR), + ] + ), + _make_revision( + [ + _make_pod("pod-new", PodStatus.RUNNING), + ] + ), + ] + ) result = _resolve_pod(cclient, 1) assert result == "pod-new" @@ -186,6 +237,7 @@ def test_multiple_revisions(self): # _exec_session # =========================================================================== + class TestExecSession: @pytest.mark.asyncio async def test_sends_resize_and_wrapped_command(self): @@ -193,16 +245,20 @@ async def test_sends_resize_and_wrapped_command(self): ws = AsyncMock() messages = [ - json.dumps({"data": f"noise\n{_BEGIN_MARKER}\nhello world\n{_END_MARKER}:0\n"}), + json.dumps( + {"data": f"noise\n{_BEGIN_MARKER}\nhello world\n{_END_MARKER}:0\n"} + ), json.dumps({"Code": 0}), ] ws.__aiter__ = MagicMock(return_value=_async_iter_from_list(messages)) with patch("centml.cli.shell.websockets") as mock_ws_mod: - mock_ws_mod.connect = MagicMock(return_value=AsyncMock( - __aenter__=AsyncMock(return_value=ws), - __aexit__=AsyncMock(return_value=False), - )) + mock_ws_mod.connect = MagicMock( + return_value=AsyncMock( + __aenter__=AsyncMock(return_value=ws), + __aexit__=AsyncMock(return_value=False), + ) + ) exit_code = await _exec_session("wss://test/ws", "fake-token", "ls -la") @@ -231,10 +287,12 @@ async def test_returns_nonzero_exit_code_from_marker(self): ws.__aiter__ = MagicMock(return_value=_async_iter_from_list(messages)) with patch("centml.cli.shell.websockets") as mock_ws_mod: - mock_ws_mod.connect = MagicMock(return_value=AsyncMock( - __aenter__=AsyncMock(return_value=ws), - __aexit__=AsyncMock(return_value=False), - )) + mock_ws_mod.connect = MagicMock( + return_value=AsyncMock( + __aenter__=AsyncMock(return_value=ws), + __aexit__=AsyncMock(return_value=False), + ) + ) exit_code = await _exec_session("wss://test/ws", "fake-token", "false") @@ -251,10 +309,12 @@ async def test_error_message_returns_one(self): ws.__aiter__ = MagicMock(return_value=_async_iter_from_list(messages)) with patch("centml.cli.shell.websockets") as mock_ws_mod: - mock_ws_mod.connect = MagicMock(return_value=AsyncMock( - __aenter__=AsyncMock(return_value=ws), - __aexit__=AsyncMock(return_value=False), - )) + mock_ws_mod.connect = MagicMock( + return_value=AsyncMock( + __aenter__=AsyncMock(return_value=ws), + __aexit__=AsyncMock(return_value=False), + ) + ) exit_code = await _exec_session("wss://test/ws", "fake-token", "bad") @@ -267,18 +327,25 @@ async def test_filters_noise_before_marker(self): ws = AsyncMock() messages = [ - json.dumps({"data": f"prompt$ command\n{_BEGIN_MARKER}\nreal output\n{_END_MARKER}:0\n"}), + json.dumps( + { + "data": f"prompt$ command\n{_BEGIN_MARKER}\nreal output\n{_END_MARKER}:0\n" + } + ), json.dumps({"Code": 0}), ] ws.__aiter__ = MagicMock(return_value=_async_iter_from_list(messages)) captured = [] - with patch("centml.cli.shell.websockets") as mock_ws_mod, \ - patch("centml.cli.shell.sys") as mock_sys: - mock_ws_mod.connect = MagicMock(return_value=AsyncMock( - __aenter__=AsyncMock(return_value=ws), - __aexit__=AsyncMock(return_value=False), - )) + with patch("centml.cli.shell.websockets") as mock_ws_mod, patch( + "centml.cli.shell.sys" + ) as mock_sys: + mock_ws_mod.connect = MagicMock( + return_value=AsyncMock( + __aenter__=AsyncMock(return_value=ws), + __aexit__=AsyncMock(return_value=False), + ) + ) mock_sys.stdout.write = lambda s: captured.append(s) mock_sys.stdout.flush = MagicMock() mock_sys.stderr.write = MagicMock() @@ -295,23 +362,27 @@ async def test_filters_noise_before_marker(self): # _interactive_session -- terminal restore # =========================================================================== + class TestInteractiveSessionTerminalRestore: @pytest.mark.asyncio async def test_restores_terminal_on_exception(self): from centml.cli.shell import _interactive_session - with patch("centml.cli.shell.sys") as mock_sys, \ - patch("centml.cli.shell.termios") as mock_termios, \ - patch("centml.cli.shell.tty") as mock_tty, \ - patch("centml.cli.shell.websockets") as mock_ws_mod: + with patch("centml.cli.shell.sys") as mock_sys, patch( + "centml.cli.shell.termios" + ) as mock_termios, patch("centml.cli.shell.tty") as mock_tty, patch( + "centml.cli.shell.websockets" + ) as mock_ws_mod: mock_sys.stdin.fileno.return_value = 0 mock_termios.tcgetattr.return_value = ["old_settings"] - mock_ws_mod.connect = MagicMock(return_value=AsyncMock( - __aenter__=AsyncMock(side_effect=ConnectionRefusedError("fail")), - __aexit__=AsyncMock(return_value=False), - )) + mock_ws_mod.connect = MagicMock( + return_value=AsyncMock( + __aenter__=AsyncMock(side_effect=ConnectionRefusedError("fail")), + __aexit__=AsyncMock(return_value=False), + ) + ) with pytest.raises(ConnectionRefusedError): await _interactive_session("wss://test/ws", "fake-token") @@ -325,10 +396,12 @@ async def test_restores_terminal_on_exception(self): # Click commands # =========================================================================== + class TestShellCommand: def test_rejects_non_tty(self): from centml.cli.shell import shell from click.testing import CliRunner + runner = CliRunner() result = runner.invoke(shell, ["123"]) assert result.exit_code != 0 @@ -338,12 +411,15 @@ def test_shell_option_forwarded(self): from centml.cli.shell import shell from click.testing import CliRunner - with patch("centml.cli.shell._resolve_pod", return_value="pod-a"), \ - patch("centml.cli.shell.get_centml_client") as mock_ctx, \ - patch("centml.cli.shell.auth") as mock_auth, \ - patch("centml.cli.shell.settings") as mock_settings, \ - patch("centml.cli.shell.asyncio") as mock_asyncio, \ - patch("centml.cli.shell.sys") as mock_sys: + with patch("centml.cli.shell._resolve_pod", return_value="pod-a"), patch( + "centml.cli.shell.get_centml_client" + ) as mock_ctx, patch("centml.cli.shell.auth") as mock_auth, patch( + "centml.cli.shell.settings" + ) as mock_settings, patch( + "centml.cli.shell.asyncio" + ) as mock_asyncio, patch( + "centml.cli.shell.sys" + ) as mock_sys: mock_ctx.return_value.__enter__ = MagicMock(return_value=MagicMock()) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) @@ -362,12 +438,15 @@ def test_pod_option_forwarded(self): from centml.cli.shell import shell from click.testing import CliRunner - with patch("centml.cli.shell._resolve_pod") as mock_resolve, \ - patch("centml.cli.shell.get_centml_client") as mock_ctx, \ - patch("centml.cli.shell.auth") as mock_auth, \ - patch("centml.cli.shell.settings") as mock_settings, \ - patch("centml.cli.shell.asyncio") as mock_asyncio, \ - patch("centml.cli.shell.sys") as mock_sys: + with patch("centml.cli.shell._resolve_pod") as mock_resolve, patch( + "centml.cli.shell.get_centml_client" + ) as mock_ctx, patch("centml.cli.shell.auth") as mock_auth, patch( + "centml.cli.shell.settings" + ) as mock_settings, patch( + "centml.cli.shell.asyncio" + ) as mock_asyncio, patch( + "centml.cli.shell.sys" + ) as mock_sys: mock_ctx.return_value.__enter__ = MagicMock(return_value=MagicMock()) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) @@ -381,8 +460,10 @@ def test_pod_option_forwarded(self): result = runner.invoke(shell, ["123", "--pod", "my-pod"]) mock_resolve.assert_called_once() - assert mock_resolve.call_args[1].get("pod_name") == "my-pod" or \ - mock_resolve.call_args[0][2] == "my-pod" + assert ( + mock_resolve.call_args[1].get("pod_name") == "my-pod" + or mock_resolve.call_args[0][2] == "my-pod" + ) class TestExecCommand: @@ -390,11 +471,13 @@ def test_passes_command(self): from centml.cli.shell import exec_cmd from click.testing import CliRunner - with patch("centml.cli.shell._resolve_pod", return_value="pod-a"), \ - patch("centml.cli.shell.get_centml_client") as mock_ctx, \ - patch("centml.cli.shell.auth") as mock_auth, \ - patch("centml.cli.shell.settings") as mock_settings, \ - patch("centml.cli.shell.asyncio") as mock_asyncio: + with patch("centml.cli.shell._resolve_pod", return_value="pod-a"), patch( + "centml.cli.shell.get_centml_client" + ) as mock_ctx, patch("centml.cli.shell.auth") as mock_auth, patch( + "centml.cli.shell.settings" + ) as mock_settings, patch( + "centml.cli.shell.asyncio" + ) as mock_asyncio: mock_ctx.return_value.__enter__ = MagicMock(return_value=MagicMock()) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) @@ -411,11 +494,13 @@ def test_shell_option_forwarded(self): from centml.cli.shell import exec_cmd from click.testing import CliRunner - with patch("centml.cli.shell._resolve_pod", return_value="pod-a"), \ - patch("centml.cli.shell.get_centml_client") as mock_ctx, \ - patch("centml.cli.shell.auth") as mock_auth, \ - patch("centml.cli.shell.settings") as mock_settings, \ - patch("centml.cli.shell.asyncio") as mock_asyncio: + with patch("centml.cli.shell._resolve_pod", return_value="pod-a"), patch( + "centml.cli.shell.get_centml_client" + ) as mock_ctx, patch("centml.cli.shell.auth") as mock_auth, patch( + "centml.cli.shell.settings" + ) as mock_settings, patch( + "centml.cli.shell.asyncio" + ) as mock_asyncio: mock_ctx.return_value.__enter__ = MagicMock(return_value=MagicMock()) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) @@ -424,6 +509,8 @@ def test_shell_option_forwarded(self): mock_asyncio.run.return_value = 0 runner = CliRunner() - result = runner.invoke(exec_cmd, ["123", "--shell", "zsh", "--", "echo", "hi"]) + result = runner.invoke( + exec_cmd, ["123", "--shell", "zsh", "--", "echo", "hi"] + ) mock_asyncio.run.assert_called_once() From d03f493da7d2cc56284265c2e1e6be4cc5c70315 Mon Sep 17 00:00:00 2001 From: Honglin Cao Date: Thu, 12 Mar 2026 14:27:18 -0400 Subject: [PATCH 04/28] style: condense multiline expressions for readability --- centml/cli/shell.py | 92 ++++++----------------------------- tests/test_shell.py | 115 ++++++++------------------------------------ 2 files changed, 35 insertions(+), 172 deletions(-) diff --git a/centml/cli/shell.py b/centml/cli/shell.py index 837d6d9..9e5e5a0 100644 --- a/centml/cli/shell.py +++ b/centml/cli/shell.py @@ -58,22 +58,17 @@ def _resolve_pod(cclient, deployment_id, pod_name=None): running_pods.append(pod.name) if not running_pods: - raise click.ClickException( - f"No running pods found for deployment {deployment_id}" - ) + raise click.ClickException(f"No running pods found for deployment {deployment_id}") if pod_name is not None: if pod_name not in running_pods: pods_list = ", ".join(running_pods) - raise click.ClickException( - f"Pod '{pod_name}' not found. Available running pods: {pods_list}" - ) + raise click.ClickException(f"Pod '{pod_name}' not found. Available running pods: {pods_list}") return pod_name if len(running_pods) > 1: click.echo( - f"Multiple running pods found, connecting to {running_pods[0]}. " - f"Use --pod to specify a different pod.", + f"Multiple running pods found, connecting to {running_pods[0]}. " f"Use --pod to specify a different pod.", err=True, ) return running_pods[0] @@ -135,10 +130,7 @@ def _on_stdin_ready(): finally: loop.remove_reader(stdin_fd) - tasks = [ - asyncio.create_task(_read_ws()), - asyncio.create_task(_read_stdin()), - ] + tasks = [asyncio.create_task(_read_ws()), asyncio.create_task(_read_stdin())] done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) for t in pending: t.cancel() @@ -161,31 +153,13 @@ async def _interactive_session(ws_url, token): headers = {"Authorization": f"Bearer {token}"} async with websockets.connect(ws_url, additional_headers=headers) as ws: - await ws.send( - json.dumps( - { - "operation": "resize", - "rows": rows, - "cols": cols, - } - ) - ) + await ws.send(json.dumps({"operation": "resize", "rows": rows, "cols": cols})) loop = asyncio.get_running_loop() def _on_resize(): r, c = shutil.get_terminal_size() - asyncio.ensure_future( - ws.send( - json.dumps( - { - "operation": "resize", - "rows": r, - "cols": c, - } - ) - ) - ) + asyncio.ensure_future(ws.send(json.dumps({"operation": "resize", "rows": r, "cols": c}))) loop.add_signal_handler(signal.SIGWINCH, _on_resize) @@ -225,15 +199,7 @@ async def _exec_session(ws_url, token, command): headers = {"Authorization": f"Bearer {token}"} async with websockets.connect(ws_url, additional_headers=headers) as ws: - await ws.send( - json.dumps( - { - "operation": "resize", - "rows": rows, - "cols": cols, - } - ) - ) + await ws.send(json.dumps({"operation": "resize", "rows": rows, "cols": cols})) # Suppress echo/bracketed-paste, emit begin marker, run command, # emit end marker with exit code, then exit. @@ -248,16 +214,7 @@ async def _exec_session(ws_url, token, command): f" exit $__ec\n" ) - await ws.send( - json.dumps( - { - "operation": "stdin", - "data": wrapped, - "rows": rows, - "cols": cols, - } - ) - ) + await ws.send(json.dumps({"operation": "stdin", "data": wrapped, "rows": rows, "cols": cols})) exit_code = 0 buffer = "" @@ -297,16 +254,8 @@ async def _exec_session(ws_url, token, command): @click.command(help="Open an interactive shell to a deployment pod") @click.argument("deployment_id", type=int) -@click.option( - "--pod", default=None, help="Specific pod name (auto-selects first running pod)" -) -@click.option( - "--shell", - "shell_type", - default=None, - type=click.Choice(["bash", "sh", "zsh"]), - help="Shell type", -) +@click.option("--pod", default=None, help="Specific pod name (auto-selects first running pod)") +@click.option("--shell", "shell_type", default=None, type=click.Choice(["bash", "sh", "zsh"]), help="Shell type") @handle_exception def shell(deployment_id, pod, shell_type): if not sys.stdin.isatty(): @@ -315,36 +264,23 @@ def shell(deployment_id, pod, shell_type): with get_centml_client() as cclient: pod_name = _resolve_pod(cclient, deployment_id, pod) - ws_url = _build_ws_url( - settings.CENTML_PLATFORM_API_URL, deployment_id, pod_name, shell_type - ) + ws_url = _build_ws_url(settings.CENTML_PLATFORM_API_URL, deployment_id, pod_name, shell_type) token = auth.get_centml_token() exit_code = asyncio.run(_interactive_session(ws_url, token)) sys.exit(exit_code) -@click.command( - help="Execute a command in a deployment pod", - context_settings=dict(ignore_unknown_options=True), -) +@click.command(help="Execute a command in a deployment pod", context_settings=dict(ignore_unknown_options=True)) @click.argument("deployment_id", type=int) @click.argument("command", nargs=-1, required=True, type=click.UNPROCESSED) @click.option("--pod", default=None, help="Specific pod name") -@click.option( - "--shell", - "shell_type", - default=None, - type=click.Choice(["bash", "sh", "zsh"]), - help="Shell type", -) +@click.option("--shell", "shell_type", default=None, type=click.Choice(["bash", "sh", "zsh"]), help="Shell type") @handle_exception def exec_cmd(deployment_id, command, pod, shell_type): with get_centml_client() as cclient: pod_name = _resolve_pod(cclient, deployment_id, pod) - ws_url = _build_ws_url( - settings.CENTML_PLATFORM_API_URL, deployment_id, pod_name, shell_type - ) + ws_url = _build_ws_url(settings.CENTML_PLATFORM_API_URL, deployment_id, pod_name, shell_type) token = auth.get_centml_token() cmd_str = " ".join(command) exit_code = asyncio.run(_exec_session(ws_url, token, cmd_str)) diff --git a/tests/test_shell.py b/tests/test_shell.py index 6e49fd7..28eb157 100644 --- a/tests/test_shell.py +++ b/tests/test_shell.py @@ -116,14 +116,7 @@ def test_selects_first_running(self): cclient = MagicMock() cclient.get_status_v3.return_value = _make_status_response( - [ - _make_revision( - [ - _make_pod("pod-a", PodStatus.RUNNING), - _make_pod("pod-b", PodStatus.RUNNING), - ] - ) - ] + [_make_revision([_make_pod("pod-a", PodStatus.RUNNING), _make_pod("pod-b", PodStatus.RUNNING)])] ) result = _resolve_pod(cclient, 1) assert result == "pod-a" @@ -133,13 +126,7 @@ def test_raises_no_running_pods(self): cclient = MagicMock() cclient.get_status_v3.return_value = _make_status_response( - [ - _make_revision( - [ - _make_pod("pod-err", PodStatus.ERROR), - ] - ) - ] + [_make_revision([_make_pod("pod-err", PodStatus.ERROR)])] ) with pytest.raises(click.ClickException, match="No running pods"): _resolve_pod(cclient, 1) @@ -149,13 +136,7 @@ def test_raises_specified_pod_not_found(self): cclient = MagicMock() cclient.get_status_v3.return_value = _make_status_response( - [ - _make_revision( - [ - _make_pod("pod-a", PodStatus.RUNNING), - ] - ) - ] + [_make_revision([_make_pod("pod-a", PodStatus.RUNNING)])] ) with pytest.raises(click.ClickException, match="pod-missing"): _resolve_pod(cclient, 1, pod_name="pod-missing") @@ -165,14 +146,7 @@ def test_returns_specified_pod(self): cclient = MagicMock() cclient.get_status_v3.return_value = _make_status_response( - [ - _make_revision( - [ - _make_pod("pod-a", PodStatus.RUNNING), - _make_pod("pod-b", PodStatus.RUNNING), - ] - ) - ] + [_make_revision([_make_pod("pod-a", PodStatus.RUNNING), _make_pod("pod-b", PodStatus.RUNNING)])] ) result = _resolve_pod(cclient, 1, pod_name="pod-b") assert result == "pod-b" @@ -199,14 +173,7 @@ def test_skips_pods_without_name(self): cclient = MagicMock() cclient.get_status_v3.return_value = _make_status_response( - [ - _make_revision( - [ - _make_pod(None, PodStatus.RUNNING), - _make_pod("pod-real", PodStatus.RUNNING), - ] - ) - ] + [_make_revision([_make_pod(None, PodStatus.RUNNING), _make_pod("pod-real", PodStatus.RUNNING)])] ) result = _resolve_pod(cclient, 1) assert result == "pod-real" @@ -217,16 +184,8 @@ def test_multiple_revisions(self): cclient = MagicMock() cclient.get_status_v3.return_value = _make_status_response( [ - _make_revision( - [ - _make_pod("pod-old", PodStatus.ERROR), - ] - ), - _make_revision( - [ - _make_pod("pod-new", PodStatus.RUNNING), - ] - ), + _make_revision([_make_pod("pod-old", PodStatus.ERROR)]), + _make_revision([_make_pod("pod-new", PodStatus.RUNNING)]), ] ) result = _resolve_pod(cclient, 1) @@ -245,19 +204,14 @@ async def test_sends_resize_and_wrapped_command(self): ws = AsyncMock() messages = [ - json.dumps( - {"data": f"noise\n{_BEGIN_MARKER}\nhello world\n{_END_MARKER}:0\n"} - ), + json.dumps({"data": f"noise\n{_BEGIN_MARKER}\nhello world\n{_END_MARKER}:0\n"}), json.dumps({"Code": 0}), ] ws.__aiter__ = MagicMock(return_value=_async_iter_from_list(messages)) with patch("centml.cli.shell.websockets") as mock_ws_mod: mock_ws_mod.connect = MagicMock( - return_value=AsyncMock( - __aenter__=AsyncMock(return_value=ws), - __aexit__=AsyncMock(return_value=False), - ) + return_value=AsyncMock(__aenter__=AsyncMock(return_value=ws), __aexit__=AsyncMock(return_value=False)) ) exit_code = await _exec_session("wss://test/ws", "fake-token", "ls -la") @@ -280,18 +234,12 @@ async def test_returns_nonzero_exit_code_from_marker(self): from centml.cli.shell import _exec_session, _BEGIN_MARKER, _END_MARKER ws = AsyncMock() - messages = [ - json.dumps({"data": f"{_BEGIN_MARKER}\n{_END_MARKER}:42\n"}), - json.dumps({"Code": 42}), - ] + messages = [json.dumps({"data": f"{_BEGIN_MARKER}\n{_END_MARKER}:42\n"}), json.dumps({"Code": 42})] ws.__aiter__ = MagicMock(return_value=_async_iter_from_list(messages)) with patch("centml.cli.shell.websockets") as mock_ws_mod: mock_ws_mod.connect = MagicMock( - return_value=AsyncMock( - __aenter__=AsyncMock(return_value=ws), - __aexit__=AsyncMock(return_value=False), - ) + return_value=AsyncMock(__aenter__=AsyncMock(return_value=ws), __aexit__=AsyncMock(return_value=False)) ) exit_code = await _exec_session("wss://test/ws", "fake-token", "false") @@ -303,17 +251,12 @@ async def test_error_message_returns_one(self): from centml.cli.shell import _exec_session ws = AsyncMock() - messages = [ - json.dumps({"error": "something went wrong"}), - ] + messages = [json.dumps({"error": "something went wrong"})] ws.__aiter__ = MagicMock(return_value=_async_iter_from_list(messages)) with patch("centml.cli.shell.websockets") as mock_ws_mod: mock_ws_mod.connect = MagicMock( - return_value=AsyncMock( - __aenter__=AsyncMock(return_value=ws), - __aexit__=AsyncMock(return_value=False), - ) + return_value=AsyncMock(__aenter__=AsyncMock(return_value=ws), __aexit__=AsyncMock(return_value=False)) ) exit_code = await _exec_session("wss://test/ws", "fake-token", "bad") @@ -327,24 +270,15 @@ async def test_filters_noise_before_marker(self): ws = AsyncMock() messages = [ - json.dumps( - { - "data": f"prompt$ command\n{_BEGIN_MARKER}\nreal output\n{_END_MARKER}:0\n" - } - ), + json.dumps({"data": f"prompt$ command\n{_BEGIN_MARKER}\nreal output\n{_END_MARKER}:0\n"}), json.dumps({"Code": 0}), ] ws.__aiter__ = MagicMock(return_value=_async_iter_from_list(messages)) captured = [] - with patch("centml.cli.shell.websockets") as mock_ws_mod, patch( - "centml.cli.shell.sys" - ) as mock_sys: + with patch("centml.cli.shell.websockets") as mock_ws_mod, patch("centml.cli.shell.sys") as mock_sys: mock_ws_mod.connect = MagicMock( - return_value=AsyncMock( - __aenter__=AsyncMock(return_value=ws), - __aexit__=AsyncMock(return_value=False), - ) + return_value=AsyncMock(__aenter__=AsyncMock(return_value=ws), __aexit__=AsyncMock(return_value=False)) ) mock_sys.stdout.write = lambda s: captured.append(s) mock_sys.stdout.flush = MagicMock() @@ -368,11 +302,9 @@ class TestInteractiveSessionTerminalRestore: async def test_restores_terminal_on_exception(self): from centml.cli.shell import _interactive_session - with patch("centml.cli.shell.sys") as mock_sys, patch( - "centml.cli.shell.termios" - ) as mock_termios, patch("centml.cli.shell.tty") as mock_tty, patch( - "centml.cli.shell.websockets" - ) as mock_ws_mod: + with patch("centml.cli.shell.sys") as mock_sys, patch("centml.cli.shell.termios") as mock_termios, patch( + "centml.cli.shell.tty" + ) as mock_tty, patch("centml.cli.shell.websockets") as mock_ws_mod: mock_sys.stdin.fileno.return_value = 0 mock_termios.tcgetattr.return_value = ["old_settings"] @@ -460,10 +392,7 @@ def test_pod_option_forwarded(self): result = runner.invoke(shell, ["123", "--pod", "my-pod"]) mock_resolve.assert_called_once() - assert ( - mock_resolve.call_args[1].get("pod_name") == "my-pod" - or mock_resolve.call_args[0][2] == "my-pod" - ) + assert mock_resolve.call_args[1].get("pod_name") == "my-pod" or mock_resolve.call_args[0][2] == "my-pod" class TestExecCommand: @@ -509,8 +438,6 @@ def test_shell_option_forwarded(self): mock_asyncio.run.return_value = 0 runner = CliRunner() - result = runner.invoke( - exec_cmd, ["123", "--shell", "zsh", "--", "echo", "hi"] - ) + result = runner.invoke(exec_cmd, ["123", "--shell", "zsh", "--", "echo", "hi"]) mock_asyncio.run.assert_called_once() From 3dadb753b0cf6114d5b6a08b314a16246e70ac20 Mon Sep 17 00:00:00 2001 From: Honglin Cao Date: Thu, 12 Mar 2026 14:47:01 -0400 Subject: [PATCH 05/28] fix: resolve pylint warnings in shell.py and test_shell.py --- centml/cli/shell.py | 24 +++++++++--------------- tests/test_shell.py | 16 +++++++--------- 2 files changed, 16 insertions(+), 24 deletions(-) diff --git a/centml/cli/shell.py b/centml/cli/shell.py index 9e5e5a0..9bd8ea6 100644 --- a/centml/cli/shell.py +++ b/centml/cli/shell.py @@ -6,33 +6,27 @@ import shutil import signal import sys +import termios +import tty import urllib.parse import click +import websockets -from centml.sdk import PodStatus -from centml.sdk import auth +from centml.cli.cluster import handle_exception +from centml.sdk import PodStatus, auth from centml.sdk.api import get_centml_client from centml.sdk.config import settings -from centml.cli.cluster import handle_exception - -# Lazy-import to keep module loadable without websockets installed at import time, -# and to allow tests to patch the module attribute easily. -import websockets - -# These are only available on Unix; guarded at command level via isatty check. -import termios -import tty -def _build_ws_url(api_url, deployment_id, pod_name, shell=None): +def _build_ws_url(api_url, deployment_id, pod_name, shell_type=None): """Build the WebSocket URL for a terminal connection.""" parsed = urllib.parse.urlparse(api_url) ws_scheme = "wss" if parsed.scheme == "https" else "ws" ws_base = parsed._replace(scheme=ws_scheme).geturl() url = f"{ws_base}/deployments/{deployment_id}/terminal?pod={urllib.parse.quote(pod_name)}" - if shell: - url += f"&shell={urllib.parse.quote(shell)}" + if shell_type: + url += f"&shell={urllib.parse.quote(shell_type)}" return url @@ -270,7 +264,7 @@ def shell(deployment_id, pod, shell_type): sys.exit(exit_code) -@click.command(help="Execute a command in a deployment pod", context_settings=dict(ignore_unknown_options=True)) +@click.command(help="Execute a command in a deployment pod", context_settings={"ignore_unknown_options": True}) @click.argument("deployment_id", type=int) @click.argument("command", nargs=-1, required=True, type=click.UNPROCESSED) @click.option("--pod", default=None, help="Specific pod name") diff --git a/tests/test_shell.py b/tests/test_shell.py index 28eb157..351b008 100644 --- a/tests/test_shell.py +++ b/tests/test_shell.py @@ -1,9 +1,8 @@ """Tests for centml.cli.shell -- CLI terminal access commands.""" import json -import sys import urllib.parse -from unittest.mock import AsyncMock, MagicMock, patch, call +from unittest.mock import AsyncMock, MagicMock, patch import click import pytest @@ -89,7 +88,7 @@ def test_contains_deployment_id_and_pod(self): def test_with_shell(self): from centml.cli.shell import _build_ws_url - url = _build_ws_url("https://api.centml.com", 1, "p", shell="bash") + url = _build_ws_url("https://api.centml.com", 1, "p", shell_type="bash") assert "shell=bash" in url def test_without_shell(self): @@ -304,7 +303,7 @@ async def test_restores_terminal_on_exception(self): with patch("centml.cli.shell.sys") as mock_sys, patch("centml.cli.shell.termios") as mock_termios, patch( "centml.cli.shell.tty" - ) as mock_tty, patch("centml.cli.shell.websockets") as mock_ws_mod: + ), patch("centml.cli.shell.websockets") as mock_ws_mod: mock_sys.stdin.fileno.return_value = 0 mock_termios.tcgetattr.return_value = ["old_settings"] @@ -361,9 +360,8 @@ def test_shell_option_forwarded(self): mock_asyncio.run.return_value = 0 runner = CliRunner() - result = runner.invoke(shell, ["123", "--shell", "bash"]) + runner.invoke(shell, ["123", "--shell", "bash"]) - # Verify asyncio.run was called, and the URL contains shell=bash mock_asyncio.run.assert_called_once() def test_pod_option_forwarded(self): @@ -389,7 +387,7 @@ def test_pod_option_forwarded(self): mock_asyncio.run.return_value = 0 runner = CliRunner() - result = runner.invoke(shell, ["123", "--pod", "my-pod"]) + runner.invoke(shell, ["123", "--pod", "my-pod"]) mock_resolve.assert_called_once() assert mock_resolve.call_args[1].get("pod_name") == "my-pod" or mock_resolve.call_args[0][2] == "my-pod" @@ -415,7 +413,7 @@ def test_passes_command(self): mock_asyncio.run.return_value = 0 runner = CliRunner() - result = runner.invoke(exec_cmd, ["123", "--", "ls", "-la"]) + runner.invoke(exec_cmd, ["123", "--", "ls", "-la"]) mock_asyncio.run.assert_called_once() @@ -438,6 +436,6 @@ def test_shell_option_forwarded(self): mock_asyncio.run.return_value = 0 runner = CliRunner() - result = runner.invoke(exec_cmd, ["123", "--shell", "zsh", "--", "echo", "hi"]) + runner.invoke(exec_cmd, ["123", "--shell", "zsh", "--", "echo", "hi"]) mock_asyncio.run.assert_called_once() From 8a93733d1b239fb9a78273d0b33dda71d1a2e483 Mon Sep 17 00:00:00 2001 From: Honglin Cao Date: Thu, 12 Mar 2026 14:52:23 -0400 Subject: [PATCH 06/28] fix: skip PyTorch-dependent tests in sanity mode --- tests/conftest.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/tests/conftest.py b/tests/conftest.py index f3de342..1ab0b0e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -3,11 +3,22 @@ os.environ["TOKENIZERS_PARALLELISM"] = "false" +# Tests that require PyTorch at import time -- skip during sanity runs +# where PyTorch is not installed. +_PYTORCH_TEST_FILES = ["test_backend.py", "test_helpers.py", "test_server.py"] + +collect_ignore = [] + def pytest_addoption(parser): parser.addoption("--sanity", action="store_true", help="Run sanity tests (exclude 'gpu' tests)") +def pytest_configure(config): + if config.getoption("--sanity", default=False): + collect_ignore.extend(_PYTORCH_TEST_FILES) + + def pytest_collection_modifyitems(config, items): if config.getoption("--sanity"): skip_gpu = pytest.mark.skip(reason="Skipping GPU tests for sanity run") From 0a4c1bb39dfab16e2bf47da3c6d58fb2c0b704ea Mon Sep 17 00:00:00 2001 From: Honglin Cao Date: Thu, 12 Mar 2026 14:53:10 -0400 Subject: [PATCH 07/28] fix: break out of exec loop after end marker to prevent hanging --- centml/cli/shell.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/centml/cli/shell.py b/centml/cli/shell.py index 9bd8ea6..b69fd2f 100644 --- a/centml/cli/shell.py +++ b/centml/cli/shell.py @@ -213,11 +213,11 @@ async def _exec_session(ws_url, token, command): exit_code = 0 buffer = "" is_capturing = False + is_done = False async for raw_msg in ws: msg = json.loads(raw_msg) if msg.get("data"): buffer += msg["data"] - # Process complete lines from buffer while "\n" in buffer: line, buffer = buffer.split("\n", 1) clean = _strip_ansi(line).rstrip("\r") @@ -225,23 +225,23 @@ async def _exec_session(ws_url, token, command): is_capturing = True continue if _END_MARKER in clean: - # Parse exit code from marker line parts = clean.split(_END_MARKER + ":") if len(parts) > 1: try: exit_code = int(parts[1].strip()) except ValueError: pass - is_capturing = False - continue + is_done = True + break if is_capturing: sys.stdout.write(line + "\n") sys.stdout.flush() elif msg.get("error"): sys.stderr.write(f"Error: {msg['error']}\n") return 1 - if "Code" in msg: - exit_code = msg["Code"] + if is_done or "Code" in msg: + if "Code" in msg: + exit_code = msg["Code"] break return exit_code From 7259bd49c797475ed2374efd3e8dccdd966a26a5 Mon Sep 17 00:00:00 2001 From: Honglin Cao Date: Thu, 12 Mar 2026 14:56:53 -0400 Subject: [PATCH 08/28] fix: re-enable OPOST after setraw to fix terminal rendering --- centml/cli/shell.py | 6 ++++++ tests/test_shell.py | 12 +++++++++--- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/centml/cli/shell.py b/centml/cli/shell.py index b69fd2f..4bc1e33 100644 --- a/centml/cli/shell.py +++ b/centml/cli/shell.py @@ -143,6 +143,12 @@ async def _interactive_session(ws_url, token): old_settings = termios.tcgetattr(fd) try: tty.setraw(fd) + # setraw disables OPOST, which means \n won't be translated to \r\n. + # The remote PTY may send bare \n; re-enable output post-processing + # so the local terminal handles the translation (like xterm.js does). + mode = termios.tcgetattr(fd) + mode[1] = mode[1] | termios.OPOST + termios.tcsetattr(fd, termios.TCSANOW, mode) rows, cols = shutil.get_terminal_size() headers = {"Authorization": f"Bearer {token}"} diff --git a/tests/test_shell.py b/tests/test_shell.py index 351b008..bd927c0 100644 --- a/tests/test_shell.py +++ b/tests/test_shell.py @@ -306,7 +306,13 @@ async def test_restores_terminal_on_exception(self): ), patch("centml.cli.shell.websockets") as mock_ws_mod: mock_sys.stdin.fileno.return_value = 0 - mock_termios.tcgetattr.return_value = ["old_settings"] + # tcgetattr is called twice: once to save, once after setraw for OPOST fix. + # Return a realistic 7-element termios attrs list. + old_attrs = [0, 0, 0, 0, 0, 0, [0] * 32] + mock_termios.tcgetattr.return_value = old_attrs + mock_termios.OPOST = 0x1 + mock_termios.TCSANOW = 0 + mock_termios.TCSADRAIN = 1 mock_ws_mod.connect = MagicMock( return_value=AsyncMock( @@ -318,9 +324,9 @@ async def test_restores_terminal_on_exception(self): with pytest.raises(ConnectionRefusedError): await _interactive_session("wss://test/ws", "fake-token") - mock_termios.tcsetattr.assert_called_once() + mock_termios.tcsetattr.assert_called() restore_call = mock_termios.tcsetattr.call_args - assert restore_call[0][2] == ["old_settings"] + assert restore_call[0][2] == old_attrs # =========================================================================== From ab51bebc12eadc4552fd1dcc1ba3eb8444379504 Mon Sep 17 00:00:00 2001 From: Honglin Cao Date: Thu, 12 Mar 2026 15:02:40 -0400 Subject: [PATCH 09/28] fix: replace pytest-asyncio with asyncio.run in tests for CI compat --- tests/test_shell.py | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/tests/test_shell.py b/tests/test_shell.py index bd927c0..a53292d 100644 --- a/tests/test_shell.py +++ b/tests/test_shell.py @@ -1,5 +1,6 @@ """Tests for centml.cli.shell -- CLI terminal access commands.""" +import asyncio import json import urllib.parse from unittest.mock import AsyncMock, MagicMock, patch @@ -197,8 +198,7 @@ def test_multiple_revisions(self): class TestExecSession: - @pytest.mark.asyncio - async def test_sends_resize_and_wrapped_command(self): + def test_sends_resize_and_wrapped_command(self): from centml.cli.shell import _exec_session, _BEGIN_MARKER, _END_MARKER ws = AsyncMock() @@ -213,7 +213,7 @@ async def test_sends_resize_and_wrapped_command(self): return_value=AsyncMock(__aenter__=AsyncMock(return_value=ws), __aexit__=AsyncMock(return_value=False)) ) - exit_code = await _exec_session("wss://test/ws", "fake-token", "ls -la") + exit_code = asyncio.run(_exec_session("wss://test/ws", "fake-token", "ls -la")) assert exit_code == 0 assert ws.send.call_count == 2 @@ -228,8 +228,7 @@ async def test_sends_resize_and_wrapped_command(self): assert _BEGIN_MARKER not in cmd_msg["data"] assert "CENTML_BEGIN" in cmd_msg["data"] - @pytest.mark.asyncio - async def test_returns_nonzero_exit_code_from_marker(self): + def test_returns_nonzero_exit_code_from_marker(self): from centml.cli.shell import _exec_session, _BEGIN_MARKER, _END_MARKER ws = AsyncMock() @@ -241,12 +240,11 @@ async def test_returns_nonzero_exit_code_from_marker(self): return_value=AsyncMock(__aenter__=AsyncMock(return_value=ws), __aexit__=AsyncMock(return_value=False)) ) - exit_code = await _exec_session("wss://test/ws", "fake-token", "false") + exit_code = asyncio.run(_exec_session("wss://test/ws", "fake-token", "false")) assert exit_code == 42 - @pytest.mark.asyncio - async def test_error_message_returns_one(self): + def test_error_message_returns_one(self): from centml.cli.shell import _exec_session ws = AsyncMock() @@ -258,12 +256,11 @@ async def test_error_message_returns_one(self): return_value=AsyncMock(__aenter__=AsyncMock(return_value=ws), __aexit__=AsyncMock(return_value=False)) ) - exit_code = await _exec_session("wss://test/ws", "fake-token", "bad") + exit_code = asyncio.run(_exec_session("wss://test/ws", "fake-token", "bad")) assert exit_code == 1 - @pytest.mark.asyncio - async def test_filters_noise_before_marker(self): + def test_filters_noise_before_marker(self): """Only output between BEGIN and END markers is written to stdout.""" from centml.cli.shell import _exec_session, _BEGIN_MARKER, _END_MARKER @@ -283,7 +280,7 @@ async def test_filters_noise_before_marker(self): mock_sys.stdout.flush = MagicMock() mock_sys.stderr.write = MagicMock() - exit_code = await _exec_session("wss://test/ws", "fake-token", "echo test") + exit_code = asyncio.run(_exec_session("wss://test/ws", "fake-token", "echo test")) assert exit_code == 0 output = "".join(captured) @@ -297,8 +294,7 @@ async def test_filters_noise_before_marker(self): class TestInteractiveSessionTerminalRestore: - @pytest.mark.asyncio - async def test_restores_terminal_on_exception(self): + def test_restores_terminal_on_exception(self): from centml.cli.shell import _interactive_session with patch("centml.cli.shell.sys") as mock_sys, patch("centml.cli.shell.termios") as mock_termios, patch( @@ -322,7 +318,7 @@ async def test_restores_terminal_on_exception(self): ) with pytest.raises(ConnectionRefusedError): - await _interactive_session("wss://test/ws", "fake-token") + asyncio.run(_interactive_session("wss://test/ws", "fake-token")) mock_termios.tcsetattr.assert_called() restore_call = mock_termios.tcsetattr.call_args From 8770660c7b72b8b27dc6f87425a368a244bd4c7f Mon Sep 17 00:00:00 2001 From: Honglin Cao Date: Thu, 12 Mar 2026 15:05:54 -0400 Subject: [PATCH 10/28] fix: match Web UI protocol - remove rows/cols from stdin messages, revert OPOST --- centml/cli/shell.py | 18 ++---------------- tests/test_shell.py | 12 +++--------- 2 files changed, 5 insertions(+), 25 deletions(-) diff --git a/centml/cli/shell.py b/centml/cli/shell.py index 4bc1e33..871b9e6 100644 --- a/centml/cli/shell.py +++ b/centml/cli/shell.py @@ -110,16 +110,8 @@ def _on_stdin_ready(): data = await asyncio.wait_for(read_queue.get(), timeout=0.5) except asyncio.TimeoutError: continue - rows, cols = shutil.get_terminal_size() await ws.send( - json.dumps( - { - "operation": "stdin", - "data": data.decode("utf-8", errors="replace"), - "rows": rows, - "cols": cols, - } - ) + json.dumps({"operation": "stdin", "data": data.decode("utf-8", errors="replace")}) ) finally: loop.remove_reader(stdin_fd) @@ -143,12 +135,6 @@ async def _interactive_session(ws_url, token): old_settings = termios.tcgetattr(fd) try: tty.setraw(fd) - # setraw disables OPOST, which means \n won't be translated to \r\n. - # The remote PTY may send bare \n; re-enable output post-processing - # so the local terminal handles the translation (like xterm.js does). - mode = termios.tcgetattr(fd) - mode[1] = mode[1] | termios.OPOST - termios.tcsetattr(fd, termios.TCSANOW, mode) rows, cols = shutil.get_terminal_size() headers = {"Authorization": f"Bearer {token}"} @@ -214,7 +200,7 @@ async def _exec_session(ws_url, token, command): f" exit $__ec\n" ) - await ws.send(json.dumps({"operation": "stdin", "data": wrapped, "rows": rows, "cols": cols})) + await ws.send(json.dumps({"operation": "stdin", "data": wrapped})) exit_code = 0 buffer = "" diff --git a/tests/test_shell.py b/tests/test_shell.py index a53292d..30dae77 100644 --- a/tests/test_shell.py +++ b/tests/test_shell.py @@ -302,13 +302,7 @@ def test_restores_terminal_on_exception(self): ), patch("centml.cli.shell.websockets") as mock_ws_mod: mock_sys.stdin.fileno.return_value = 0 - # tcgetattr is called twice: once to save, once after setraw for OPOST fix. - # Return a realistic 7-element termios attrs list. - old_attrs = [0, 0, 0, 0, 0, 0, [0] * 32] - mock_termios.tcgetattr.return_value = old_attrs - mock_termios.OPOST = 0x1 - mock_termios.TCSANOW = 0 - mock_termios.TCSADRAIN = 1 + mock_termios.tcgetattr.return_value = ["old_settings"] mock_ws_mod.connect = MagicMock( return_value=AsyncMock( @@ -320,9 +314,9 @@ def test_restores_terminal_on_exception(self): with pytest.raises(ConnectionRefusedError): asyncio.run(_interactive_session("wss://test/ws", "fake-token")) - mock_termios.tcsetattr.assert_called() + mock_termios.tcsetattr.assert_called_once() restore_call = mock_termios.tcsetattr.call_args - assert restore_call[0][2] == old_attrs + assert restore_call[0][2] == ["old_settings"] # =========================================================================== From b79a30a8f65d2633f29cc928984f270c0c84831a Mon Sep 17 00:00:00 2001 From: Honglin Cao Date: Thu, 12 Mar 2026 15:10:14 -0400 Subject: [PATCH 11/28] fix: send delayed resize to fix prompt rendering after shell startup --- centml/cli/shell.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/centml/cli/shell.py b/centml/cli/shell.py index 871b9e6..07174d3 100644 --- a/centml/cli/shell.py +++ b/centml/cli/shell.py @@ -143,11 +143,15 @@ async def _interactive_session(ws_url, token): loop = asyncio.get_running_loop() - def _on_resize(): + def _send_resize(): r, c = shutil.get_terminal_size() asyncio.ensure_future(ws.send(json.dumps({"operation": "resize", "rows": r, "cols": c}))) - loop.add_signal_handler(signal.SIGWINCH, _on_resize) + # ArgoCD starts the shell with a default PTY size before our resize + # arrives. A second resize after a short delay triggers SIGWINCH on + # the remote, causing readline to redraw the prompt at the correct width. + loop.call_later(0.3, _send_resize) + loop.add_signal_handler(signal.SIGWINCH, _send_resize) try: exit_code = await _forward_io(ws) From 0b03a533bb9ed29b3f069b7c149b4a9739db73fa Mon Sep 17 00:00:00 2001 From: Honglin Cao Date: Thu, 12 Mar 2026 15:12:35 -0400 Subject: [PATCH 12/28] fix: await cancelled tasks for cleanup, reduce WS close_timeout to 2s --- centml/cli/shell.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/centml/cli/shell.py b/centml/cli/shell.py index 07174d3..4812ba7 100644 --- a/centml/cli/shell.py +++ b/centml/cli/shell.py @@ -120,6 +120,12 @@ def _on_stdin_ready(): done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) for t in pending: t.cancel() + # Await cancelled tasks so their finally blocks run (e.g. remove_reader). + for t in pending: + try: + await t + except (asyncio.CancelledError, Exception): + pass for t in done: if t.exception() is not None: raise t.exception() @@ -138,7 +144,7 @@ async def _interactive_session(ws_url, token): rows, cols = shutil.get_terminal_size() headers = {"Authorization": f"Bearer {token}"} - async with websockets.connect(ws_url, additional_headers=headers) as ws: + async with websockets.connect(ws_url, additional_headers=headers, close_timeout=2) as ws: await ws.send(json.dumps({"operation": "resize", "rows": rows, "cols": cols})) loop = asyncio.get_running_loop() @@ -188,7 +194,7 @@ async def _exec_session(ws_url, token, command): rows, cols = shutil.get_terminal_size(fallback=(80, 24)) headers = {"Authorization": f"Bearer {token}"} - async with websockets.connect(ws_url, additional_headers=headers) as ws: + async with websockets.connect(ws_url, additional_headers=headers, close_timeout=2) as ws: await ws.send(json.dumps({"operation": "resize", "rows": rows, "cols": cols})) # Suppress echo/bracketed-paste, emit begin marker, run command, From 54445b8a0c1dda7c82a89e8dc5a2995d08370454 Mon Sep 17 00:00:00 2001 From: Honglin Cao Date: Thu, 12 Mar 2026 15:21:04 -0400 Subject: [PATCH 13/28] fix: toggle PTY width to force SIGWINCH and prompt redraw on connect --- centml/cli/shell.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/centml/cli/shell.py b/centml/cli/shell.py index 4812ba7..5065a41 100644 --- a/centml/cli/shell.py +++ b/centml/cli/shell.py @@ -153,10 +153,24 @@ def _send_resize(): r, c = shutil.get_terminal_size() asyncio.ensure_future(ws.send(json.dumps({"operation": "resize", "rows": r, "cols": c}))) - # ArgoCD starts the shell with a default PTY size before our resize - # arrives. A second resize after a short delay triggers SIGWINCH on - # the remote, causing readline to redraw the prompt at the correct width. - loop.call_later(0.3, _send_resize) + async def _force_initial_redraw(): + """Force SIGWINCH on the remote by toggling the PTY width. + + The initial resize may arrive before the shell starts, so by + the time the shell reads its PTY size, it's already correct + and no SIGWINCH fires. Toggling cols forces two SIGWINCHes, + making readline redraw the prompt at the right width. + """ + try: + await asyncio.sleep(0.5) + r, c = shutil.get_terminal_size() + await ws.send(json.dumps({"operation": "resize", "rows": r, "cols": c + 1})) + await asyncio.sleep(0.05) + await ws.send(json.dumps({"operation": "resize", "rows": r, "cols": c})) + except Exception: + pass + + asyncio.create_task(_force_initial_redraw()) loop.add_signal_handler(signal.SIGWINCH, _send_resize) try: From 0a0636c8d352866fcf63ee130f56330b60f434b9 Mon Sep 17 00:00:00 2001 From: Honglin Cao Date: Thu, 12 Mar 2026 15:47:05 -0400 Subject: [PATCH 14/28] fix: include rows/cols in stdin messages and send Ctrl+L after resize toggle --- centml/cli/shell.py | 82 ++++++++++++++++++++++++++++++++++++--------- 1 file changed, 66 insertions(+), 16 deletions(-) diff --git a/centml/cli/shell.py b/centml/cli/shell.py index 5065a41..b9cb51d 100644 --- a/centml/cli/shell.py +++ b/centml/cli/shell.py @@ -52,17 +52,22 @@ def _resolve_pod(cclient, deployment_id, pod_name=None): running_pods.append(pod.name) if not running_pods: - raise click.ClickException(f"No running pods found for deployment {deployment_id}") + raise click.ClickException( + f"No running pods found for deployment {deployment_id}" + ) if pod_name is not None: if pod_name not in running_pods: pods_list = ", ".join(running_pods) - raise click.ClickException(f"Pod '{pod_name}' not found. Available running pods: {pods_list}") + raise click.ClickException( + f"Pod '{pod_name}' not found. Available running pods: {pods_list}" + ) return pod_name if len(running_pods) > 1: click.echo( - f"Multiple running pods found, connecting to {running_pods[0]}. " f"Use --pod to specify a different pod.", + f"Multiple running pods found, connecting to {running_pods[0]}. " + f"Use --pod to specify a different pod.", err=True, ) return running_pods[0] @@ -110,8 +115,16 @@ def _on_stdin_ready(): data = await asyncio.wait_for(read_queue.get(), timeout=0.5) except asyncio.TimeoutError: continue + r, c = shutil.get_terminal_size() await ws.send( - json.dumps({"operation": "stdin", "data": data.decode("utf-8", errors="replace")}) + json.dumps( + { + "operation": "stdin", + "data": data.decode("utf-8", errors="replace"), + "rows": r, + "cols": c, + } + ) ) finally: loop.remove_reader(stdin_fd) @@ -144,14 +157,20 @@ async def _interactive_session(ws_url, token): rows, cols = shutil.get_terminal_size() headers = {"Authorization": f"Bearer {token}"} - async with websockets.connect(ws_url, additional_headers=headers, close_timeout=2) as ws: - await ws.send(json.dumps({"operation": "resize", "rows": rows, "cols": cols})) + async with websockets.connect( + ws_url, additional_headers=headers, close_timeout=2 + ) as ws: + await ws.send( + json.dumps({"operation": "resize", "rows": rows, "cols": cols}) + ) loop = asyncio.get_running_loop() def _send_resize(): r, c = shutil.get_terminal_size() - asyncio.ensure_future(ws.send(json.dumps({"operation": "resize", "rows": r, "cols": c}))) + asyncio.ensure_future( + ws.send(json.dumps({"operation": "resize", "rows": r, "cols": c})) + ) async def _force_initial_redraw(): """Force SIGWINCH on the remote by toggling the PTY width. @@ -164,9 +183,17 @@ async def _force_initial_redraw(): try: await asyncio.sleep(0.5) r, c = shutil.get_terminal_size() - await ws.send(json.dumps({"operation": "resize", "rows": r, "cols": c + 1})) + await ws.send( + json.dumps({"operation": "resize", "rows": r, "cols": c + 1}) + ) await asyncio.sleep(0.05) - await ws.send(json.dumps({"operation": "resize", "rows": r, "cols": c})) + await ws.send( + json.dumps({"operation": "resize", "rows": r, "cols": c}) + ) + await asyncio.sleep(0.1) + # Ctrl+L clears screen and forces the shell to redraw + # its prompt at the now-correct terminal width. + await ws.send(json.dumps({"operation": "stdin", "data": "\x0c"})) except Exception: pass @@ -208,7 +235,9 @@ async def _exec_session(ws_url, token, command): rows, cols = shutil.get_terminal_size(fallback=(80, 24)) headers = {"Authorization": f"Bearer {token}"} - async with websockets.connect(ws_url, additional_headers=headers, close_timeout=2) as ws: + async with websockets.connect( + ws_url, additional_headers=headers, close_timeout=2 + ) as ws: await ws.send(json.dumps({"operation": "resize", "rows": rows, "cols": cols})) # Suppress echo/bracketed-paste, emit begin marker, run command, @@ -264,8 +293,16 @@ async def _exec_session(ws_url, token, command): @click.command(help="Open an interactive shell to a deployment pod") @click.argument("deployment_id", type=int) -@click.option("--pod", default=None, help="Specific pod name (auto-selects first running pod)") -@click.option("--shell", "shell_type", default=None, type=click.Choice(["bash", "sh", "zsh"]), help="Shell type") +@click.option( + "--pod", default=None, help="Specific pod name (auto-selects first running pod)" +) +@click.option( + "--shell", + "shell_type", + default=None, + type=click.Choice(["bash", "sh", "zsh"]), + help="Shell type", +) @handle_exception def shell(deployment_id, pod, shell_type): if not sys.stdin.isatty(): @@ -274,23 +311,36 @@ def shell(deployment_id, pod, shell_type): with get_centml_client() as cclient: pod_name = _resolve_pod(cclient, deployment_id, pod) - ws_url = _build_ws_url(settings.CENTML_PLATFORM_API_URL, deployment_id, pod_name, shell_type) + ws_url = _build_ws_url( + settings.CENTML_PLATFORM_API_URL, deployment_id, pod_name, shell_type + ) token = auth.get_centml_token() exit_code = asyncio.run(_interactive_session(ws_url, token)) sys.exit(exit_code) -@click.command(help="Execute a command in a deployment pod", context_settings={"ignore_unknown_options": True}) +@click.command( + help="Execute a command in a deployment pod", + context_settings={"ignore_unknown_options": True}, +) @click.argument("deployment_id", type=int) @click.argument("command", nargs=-1, required=True, type=click.UNPROCESSED) @click.option("--pod", default=None, help="Specific pod name") -@click.option("--shell", "shell_type", default=None, type=click.Choice(["bash", "sh", "zsh"]), help="Shell type") +@click.option( + "--shell", + "shell_type", + default=None, + type=click.Choice(["bash", "sh", "zsh"]), + help="Shell type", +) @handle_exception def exec_cmd(deployment_id, command, pod, shell_type): with get_centml_client() as cclient: pod_name = _resolve_pod(cclient, deployment_id, pod) - ws_url = _build_ws_url(settings.CENTML_PLATFORM_API_URL, deployment_id, pod_name, shell_type) + ws_url = _build_ws_url( + settings.CENTML_PLATFORM_API_URL, deployment_id, pod_name, shell_type + ) token = auth.get_centml_token() cmd_str = " ".join(command) exit_code = asyncio.run(_exec_session(ws_url, token, cmd_str)) From 31d41ae20b70f668703364782ad1639e08dc872b Mon Sep 17 00:00:00 2001 From: Honglin Cao Date: Thu, 12 Mar 2026 15:54:58 -0400 Subject: [PATCH 15/28] fix: use stty to set PTY dimensions from inside shell instead of resize messages --- centml/cli/shell.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/centml/cli/shell.py b/centml/cli/shell.py index b9cb51d..09a9b56 100644 --- a/centml/cli/shell.py +++ b/centml/cli/shell.py @@ -173,27 +173,26 @@ def _send_resize(): ) async def _force_initial_redraw(): - """Force SIGWINCH on the remote by toggling the PTY width. - - The initial resize may arrive before the shell starts, so by - the time the shell reads its PTY size, it's already correct - and no SIGWINCH fires. Toggling cols forces two SIGWINCHes, - making readline redraw the prompt at the right width. + """Set PTY dimensions from inside the shell and clear screen. + + The WebSocket resize message may not reliably update the + remote PTY size (race with shell startup, ArgoCD buffering). + Sending ``stty rows R cols C`` directly through stdin is + authoritative -- it always works because the shell itself + calls the TIOCSWINSZ ioctl. The leading space keeps the + command out of bash history when HISTCONTROL=ignorespace. """ try: await asyncio.sleep(0.5) r, c = shutil.get_terminal_size() await ws.send( - json.dumps({"operation": "resize", "rows": r, "cols": c + 1}) - ) - await asyncio.sleep(0.05) - await ws.send( - json.dumps({"operation": "resize", "rows": r, "cols": c}) + json.dumps( + { + "operation": "stdin", + "data": f" stty rows {r} cols {c}; clear\n", + } + ) ) - await asyncio.sleep(0.1) - # Ctrl+L clears screen and forces the shell to redraw - # its prompt at the now-correct terminal width. - await ws.send(json.dumps({"operation": "stdin", "data": "\x0c"})) except Exception: pass From ec8286baf9c78ae139502630ce59b829bfacdfd3 Mon Sep 17 00:00:00 2001 From: Honglin Cao Date: Thu, 12 Mar 2026 15:59:00 -0400 Subject: [PATCH 16/28] fix: re-enable OPOST after setraw to convert bare \n to \r\n like xterm.js convertEol --- centml/cli/shell.py | 36 ++++++++++++++++++++++-------------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/centml/cli/shell.py b/centml/cli/shell.py index 09a9b56..f7dfa2c 100644 --- a/centml/cli/shell.py +++ b/centml/cli/shell.py @@ -154,6 +154,14 @@ async def _interactive_session(ws_url, token): old_settings = termios.tcgetattr(fd) try: tty.setraw(fd) + # Re-enable output post-processing so the local terminal converts + # bare \n to \r\n. The remote PTY may send lone \n (Kubernetes + # exec PTY behavior). xterm.js handles this with convertEol; + # for a raw CLI terminal we need OPOST. Extra \r before \n from + # a remote that already sends \r\n is harmless. + attrs = termios.tcgetattr(fd) + attrs[1] |= termios.OPOST + termios.tcsetattr(fd, termios.TCSANOW, attrs) rows, cols = shutil.get_terminal_size() headers = {"Authorization": f"Bearer {token}"} @@ -173,26 +181,26 @@ def _send_resize(): ) async def _force_initial_redraw(): - """Set PTY dimensions from inside the shell and clear screen. - - The WebSocket resize message may not reliably update the - remote PTY size (race with shell startup, ArgoCD buffering). - Sending ``stty rows R cols C`` directly through stdin is - authoritative -- it always works because the shell itself - calls the TIOCSWINSZ ioctl. The leading space keeps the - command out of bash history when HISTCONTROL=ignorespace. + """Toggle PTY width to force SIGWINCH, then Ctrl+L to redraw. + + The initial resize may arrive before the remote shell + starts, so the shell never sees a SIGWINCH. Toggling + cols by +1/-1 guarantees a size change and SIGWINCH. + Ctrl+L then clears the screen so the prompt redraws at + the correct width. """ try: await asyncio.sleep(0.5) r, c = shutil.get_terminal_size() await ws.send( - json.dumps( - { - "operation": "stdin", - "data": f" stty rows {r} cols {c}; clear\n", - } - ) + json.dumps({"operation": "resize", "rows": r, "cols": c + 1}) + ) + await asyncio.sleep(0.05) + await ws.send( + json.dumps({"operation": "resize", "rows": r, "cols": c}) ) + await asyncio.sleep(0.1) + await ws.send(json.dumps({"operation": "stdin", "data": "\x0c"})) except Exception: pass From 5d94f149c98f038acaf5ffcc07cedb57bbf51800 Mon Sep 17 00:00:00 2001 From: Honglin Cao Date: Thu, 12 Mar 2026 16:11:55 -0400 Subject: [PATCH 17/28] fix: convert \n to \r\n in output and use stty to fix PTY dimensions on connect --- centml/cli/shell.py | 68 +++++++++++++++++++-------------------------- 1 file changed, 29 insertions(+), 39 deletions(-) diff --git a/centml/cli/shell.py b/centml/cli/shell.py index f7dfa2c..d58ebb9 100644 --- a/centml/cli/shell.py +++ b/centml/cli/shell.py @@ -73,26 +73,49 @@ def _resolve_pod(cclient, deployment_id, pod_name=None): return running_pods[0] -async def _forward_io(ws): +async def _forward_io(ws, fix_dimensions=False): """Bidirectional forwarding between local stdin/stdout and WebSocket. + Args: + ws: WebSocket connection. + fix_dimensions: When True, send ``stty rows R cols C`` through stdin + on the first output message to fix the remote PTY dimensions, + then clear the screen so the prompt redraws at the correct width. + Returns the remote exit code. """ loop = asyncio.get_running_loop() exit_code = 0 stdin_fd = sys.stdin.fileno() - stdin_closed = asyncio.Event() + dimensions_fixed = not fix_dimensions async def _read_ws(): - nonlocal exit_code + nonlocal exit_code, dimensions_fixed async for raw_msg in ws: msg = json.loads(raw_msg) if msg.get("data"): - sys.stdout.buffer.write(msg["data"].encode("utf-8", errors="replace")) + if not dimensions_fixed: + dimensions_fixed = True + r, c = shutil.get_terminal_size() + # Set PTY size from inside the shell and clear screen. + # Leading space keeps the command out of bash history. + await ws.send( + json.dumps( + { + "operation": "stdin", + "data": f" stty rows {r} cols {c} 2>/dev/null; clear\n", + } + ) + ) + # Convert bare \n to \r\n (equivalent to xterm.js convertEol). + # The remote PTY may send lone \n; without this the cursor + # moves down without returning to column 0. + text = msg["data"].replace("\n", "\r\n") + sys.stdout.buffer.write(text.encode("utf-8", errors="replace")) sys.stdout.buffer.flush() elif msg.get("error"): - sys.stderr.write(f"Error: {msg['error']}\n") + sys.stderr.write(f"Error: {msg['error']}\r\n") sys.stderr.flush() if "Code" in msg: exit_code = msg["Code"] @@ -154,14 +177,6 @@ async def _interactive_session(ws_url, token): old_settings = termios.tcgetattr(fd) try: tty.setraw(fd) - # Re-enable output post-processing so the local terminal converts - # bare \n to \r\n. The remote PTY may send lone \n (Kubernetes - # exec PTY behavior). xterm.js handles this with convertEol; - # for a raw CLI terminal we need OPOST. Extra \r before \n from - # a remote that already sends \r\n is harmless. - attrs = termios.tcgetattr(fd) - attrs[1] |= termios.OPOST - termios.tcsetattr(fd, termios.TCSANOW, attrs) rows, cols = shutil.get_terminal_size() headers = {"Authorization": f"Bearer {token}"} @@ -180,35 +195,10 @@ def _send_resize(): ws.send(json.dumps({"operation": "resize", "rows": r, "cols": c})) ) - async def _force_initial_redraw(): - """Toggle PTY width to force SIGWINCH, then Ctrl+L to redraw. - - The initial resize may arrive before the remote shell - starts, so the shell never sees a SIGWINCH. Toggling - cols by +1/-1 guarantees a size change and SIGWINCH. - Ctrl+L then clears the screen so the prompt redraws at - the correct width. - """ - try: - await asyncio.sleep(0.5) - r, c = shutil.get_terminal_size() - await ws.send( - json.dumps({"operation": "resize", "rows": r, "cols": c + 1}) - ) - await asyncio.sleep(0.05) - await ws.send( - json.dumps({"operation": "resize", "rows": r, "cols": c}) - ) - await asyncio.sleep(0.1) - await ws.send(json.dumps({"operation": "stdin", "data": "\x0c"})) - except Exception: - pass - - asyncio.create_task(_force_initial_redraw()) loop.add_signal_handler(signal.SIGWINCH, _send_resize) try: - exit_code = await _forward_io(ws) + exit_code = await _forward_io(ws, fix_dimensions=True) finally: loop.remove_signal_handler(signal.SIGWINCH) From 559460fbb7233098d2038775081b9e4289e7f573 Mon Sep 17 00:00:00 2001 From: Honglin Cao Date: Thu, 12 Mar 2026 16:24:30 -0400 Subject: [PATCH 18/28] feat: use pyte terminal emulator for interactive shell rendering Add pyte as local terminal emulator (equivalent to xterm.js) to solve cursor positioning and line wrapping issues. Feed WebSocket output through pyte Screen/Stream and render only dirty lines with ANSI escape sequences. --- centml/cli/shell.py | 186 +++++++++++++++++++++++++++++++------- requirements.txt | 1 + tests/test_shell.py | 211 ++++++++++++++++++++++++++++++++++++++++---- 3 files changed, 348 insertions(+), 50 deletions(-) diff --git a/centml/cli/shell.py b/centml/cli/shell.py index d58ebb9..c274e4a 100644 --- a/centml/cli/shell.py +++ b/centml/cli/shell.py @@ -11,6 +11,7 @@ import urllib.parse import click +import pyte import websockets from centml.cli.cluster import handle_exception @@ -19,6 +20,130 @@ from centml.sdk.config import settings +# --------------------------------------------------------------------------- +# pyte screen renderer -- converts pyte's in-memory screen buffer to ANSI +# escape sequences for the local terminal. +# --------------------------------------------------------------------------- + +_PYTE_FG_TO_SGR = { + "default": "39", + "black": "30", + "red": "31", + "green": "32", + "brown": "33", + "blue": "34", + "magenta": "35", + "cyan": "36", + "white": "37", + "brightblack": "90", + "brightred": "91", + "brightgreen": "92", + "brightbrown": "93", + "brightblue": "94", + "brightmagenta": "95", + "brightcyan": "96", + "brightwhite": "97", +} + +_PYTE_BG_TO_SGR = { + "default": "49", + "black": "40", + "red": "41", + "green": "42", + "brown": "43", + "blue": "44", + "magenta": "45", + "cyan": "46", + "white": "47", + "brightblack": "100", + "brightred": "101", + "brightgreen": "102", + "brightbrown": "103", + "brightblue": "104", + "brightmagenta": "105", + "brightcyan": "106", + "brightwhite": "107", +} + + +def _color_sgr(color, is_bg=False): + """Convert a pyte color value to an SGR parameter string.""" + table = _PYTE_BG_TO_SGR if is_bg else _PYTE_FG_TO_SGR + if color in table: + default_val = "49" if is_bg else "39" + code = table[color] + return code if code != default_val else "" + # 6-char hex -> truecolor + if len(color) == 6: + try: + r, g, b = int(color[:2], 16), int(color[2:4], 16), int(color[4:], 16) + prefix = "48" if is_bg else "38" + return f"{prefix};2;{r};{g};{b}" + except ValueError: + return "" + return "" + + +def _char_to_sgr(char): + """Build the ANSI SGR parameter string for a pyte Char's attributes.""" + parts = [] + if char.bold: + parts.append("1") + if char.italics: + parts.append("3") + if char.underscore: + parts.append("4") + if char.blink: + parts.append("5") + if char.reverse: + parts.append("7") + if char.strikethrough: + parts.append("9") + fg = _color_sgr(char.fg, is_bg=False) + if fg: + parts.append(fg) + bg = _color_sgr(char.bg, is_bg=True) + if bg: + parts.append(bg) + return ";".join(parts) + + +def _render_dirty(screen, output): + """Render only the dirty lines from the pyte Screen to the terminal. + + Args: + screen: pyte.Screen instance. + output: Writable binary stream (e.g. sys.stdout.buffer). + """ + parts = [] + for row in sorted(screen.dirty): + # Position cursor at row (1-based), column 1; clear line. + parts.append(f"\033[{row + 1};1H\033[2K") + prev_sgr = "" + line_chars = [] + for col in range(screen.columns): + char = screen.buffer[row][col] + if char.data == "": + continue + sgr = _char_to_sgr(char) + if sgr != prev_sgr: + line_chars.append(f"\033[0m\033[{sgr}m" if sgr else "\033[0m") + prev_sgr = sgr + line_chars.append(char.data) + text = "".join(line_chars).rstrip() + parts.append(text) + # Reset attributes, position cursor. + parts.append("\033[0m") + parts.append(f"\033[{screen.cursor.y + 1};{screen.cursor.x + 1}H") + if screen.cursor.hidden: + parts.append("\033[?25l") + else: + parts.append("\033[?25h") + screen.dirty.clear() + output.write("".join(parts).encode("utf-8")) + output.flush() + + def _build_ws_url(api_url, deployment_id, pod_name, shell_type=None): """Build the WebSocket URL for a terminal connection.""" parsed = urllib.parse.urlparse(api_url) @@ -73,14 +198,17 @@ def _resolve_pod(cclient, deployment_id, pod_name=None): return running_pods[0] -async def _forward_io(ws, fix_dimensions=False): +async def _forward_io(ws, screen, stream): """Bidirectional forwarding between local stdin/stdout and WebSocket. + Output flows through a pyte terminal emulator so that cursor + addressing, line wrapping, and colors are rendered correctly + regardless of the remote PTY dimensions. + Args: ws: WebSocket connection. - fix_dimensions: When True, send ``stty rows R cols C`` through stdin - on the first output message to fix the remote PTY dimensions, - then clear the screen so the prompt redraws at the correct width. + screen: pyte.Screen instance sized to the local terminal. + stream: pyte.Stream attached to *screen*. Returns the remote exit code. """ @@ -88,35 +216,19 @@ async def _forward_io(ws, fix_dimensions=False): exit_code = 0 stdin_fd = sys.stdin.fileno() stdin_closed = asyncio.Event() - dimensions_fixed = not fix_dimensions async def _read_ws(): - nonlocal exit_code, dimensions_fixed + nonlocal exit_code async for raw_msg in ws: msg = json.loads(raw_msg) if msg.get("data"): - if not dimensions_fixed: - dimensions_fixed = True - r, c = shutil.get_terminal_size() - # Set PTY size from inside the shell and clear screen. - # Leading space keeps the command out of bash history. - await ws.send( - json.dumps( - { - "operation": "stdin", - "data": f" stty rows {r} cols {c} 2>/dev/null; clear\n", - } - ) - ) - # Convert bare \n to \r\n (equivalent to xterm.js convertEol). - # The remote PTY may send lone \n; without this the cursor - # moves down without returning to column 0. - text = msg["data"].replace("\n", "\r\n") - sys.stdout.buffer.write(text.encode("utf-8", errors="replace")) - sys.stdout.buffer.flush() + # Convert bare \n to \r\n before feeding pyte, equivalent + # to xterm.js ``convertEol: true``. + stream.feed(msg["data"].replace("\n", "\r\n")) + _render_dirty(screen, sys.stdout.buffer) elif msg.get("error"): - sys.stderr.write(f"Error: {msg['error']}\r\n") - sys.stderr.flush() + stream.feed(f"Error: {msg['error']}\r\n") + _render_dirty(screen, sys.stdout.buffer) if "Code" in msg: exit_code = msg["Code"] return @@ -138,14 +250,13 @@ def _on_stdin_ready(): data = await asyncio.wait_for(read_queue.get(), timeout=0.5) except asyncio.TimeoutError: continue - r, c = shutil.get_terminal_size() await ws.send( json.dumps( { "operation": "stdin", "data": data.decode("utf-8", errors="replace"), - "rows": r, - "cols": c, + "rows": screen.lines, + "cols": screen.columns, } ) ) @@ -156,7 +267,6 @@ def _on_stdin_ready(): done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) for t in pending: t.cancel() - # Await cancelled tasks so their finally blocks run (e.g. remove_reader). for t in pending: try: await t @@ -179,6 +289,13 @@ async def _interactive_session(ws_url, token): tty.setraw(fd) rows, cols = shutil.get_terminal_size() + screen = pyte.Screen(cols, rows) + stream = pyte.Stream(screen) + + # Clear local screen before starting. + sys.stdout.buffer.write(b"\033[2J\033[H") + sys.stdout.buffer.flush() + headers = {"Authorization": f"Bearer {token}"} async with websockets.connect( ws_url, additional_headers=headers, close_timeout=2 @@ -191,6 +308,8 @@ async def _interactive_session(ws_url, token): def _send_resize(): r, c = shutil.get_terminal_size() + screen.resize(r, c) + screen.dirty.update(range(r)) asyncio.ensure_future( ws.send(json.dumps({"operation": "resize", "rows": r, "cols": c})) ) @@ -198,13 +317,16 @@ def _send_resize(): loop.add_signal_handler(signal.SIGWINCH, _send_resize) try: - exit_code = await _forward_io(ws, fix_dimensions=True) + exit_code = await _forward_io(ws, screen, stream) finally: loop.remove_signal_handler(signal.SIGWINCH) return exit_code finally: termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) + # Restore cursor visibility and attributes. + sys.stdout.buffer.write(b"\033[?25h\033[0m\n") + sys.stdout.buffer.flush() _ANSI_ESCAPE_RE = re.compile(r"\x1b\[[0-9;?]*[a-zA-Z]|\x1b\].*?\x07|\x1b\[.*?\x1b\\") diff --git a/requirements.txt b/requirements.txt index 133ee99..9e79a4f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,4 +10,5 @@ prometheus-client>=0.20.0 scipy>=1.6.0 scikit-learn>=1.5.1 websockets>=13.0 +pyte>=0.8.0 platform-api-python-client==4.6.0 diff --git a/tests/test_shell.py b/tests/test_shell.py index 30dae77..9d950c6 100644 --- a/tests/test_shell.py +++ b/tests/test_shell.py @@ -1,11 +1,13 @@ """Tests for centml.cli.shell -- CLI terminal access commands.""" import asyncio +import io import json import urllib.parse from unittest.mock import AsyncMock, MagicMock, patch import click +import pyte import pytest from platform_api_python_client import PodStatus, PodDetails, RevisionPodDetails @@ -116,7 +118,14 @@ def test_selects_first_running(self): cclient = MagicMock() cclient.get_status_v3.return_value = _make_status_response( - [_make_revision([_make_pod("pod-a", PodStatus.RUNNING), _make_pod("pod-b", PodStatus.RUNNING)])] + [ + _make_revision( + [ + _make_pod("pod-a", PodStatus.RUNNING), + _make_pod("pod-b", PodStatus.RUNNING), + ] + ) + ] ) result = _resolve_pod(cclient, 1) assert result == "pod-a" @@ -146,7 +155,14 @@ def test_returns_specified_pod(self): cclient = MagicMock() cclient.get_status_v3.return_value = _make_status_response( - [_make_revision([_make_pod("pod-a", PodStatus.RUNNING), _make_pod("pod-b", PodStatus.RUNNING)])] + [ + _make_revision( + [ + _make_pod("pod-a", PodStatus.RUNNING), + _make_pod("pod-b", PodStatus.RUNNING), + ] + ) + ] ) result = _resolve_pod(cclient, 1, pod_name="pod-b") assert result == "pod-b" @@ -173,7 +189,14 @@ def test_skips_pods_without_name(self): cclient = MagicMock() cclient.get_status_v3.return_value = _make_status_response( - [_make_revision([_make_pod(None, PodStatus.RUNNING), _make_pod("pod-real", PodStatus.RUNNING)])] + [ + _make_revision( + [ + _make_pod(None, PodStatus.RUNNING), + _make_pod("pod-real", PodStatus.RUNNING), + ] + ) + ] ) result = _resolve_pod(cclient, 1) assert result == "pod-real" @@ -203,17 +226,24 @@ def test_sends_resize_and_wrapped_command(self): ws = AsyncMock() messages = [ - json.dumps({"data": f"noise\n{_BEGIN_MARKER}\nhello world\n{_END_MARKER}:0\n"}), + json.dumps( + {"data": f"noise\n{_BEGIN_MARKER}\nhello world\n{_END_MARKER}:0\n"} + ), json.dumps({"Code": 0}), ] ws.__aiter__ = MagicMock(return_value=_async_iter_from_list(messages)) with patch("centml.cli.shell.websockets") as mock_ws_mod: mock_ws_mod.connect = MagicMock( - return_value=AsyncMock(__aenter__=AsyncMock(return_value=ws), __aexit__=AsyncMock(return_value=False)) + return_value=AsyncMock( + __aenter__=AsyncMock(return_value=ws), + __aexit__=AsyncMock(return_value=False), + ) ) - exit_code = asyncio.run(_exec_session("wss://test/ws", "fake-token", "ls -la")) + exit_code = asyncio.run( + _exec_session("wss://test/ws", "fake-token", "ls -la") + ) assert exit_code == 0 assert ws.send.call_count == 2 @@ -232,15 +262,23 @@ def test_returns_nonzero_exit_code_from_marker(self): from centml.cli.shell import _exec_session, _BEGIN_MARKER, _END_MARKER ws = AsyncMock() - messages = [json.dumps({"data": f"{_BEGIN_MARKER}\n{_END_MARKER}:42\n"}), json.dumps({"Code": 42})] + messages = [ + json.dumps({"data": f"{_BEGIN_MARKER}\n{_END_MARKER}:42\n"}), + json.dumps({"Code": 42}), + ] ws.__aiter__ = MagicMock(return_value=_async_iter_from_list(messages)) with patch("centml.cli.shell.websockets") as mock_ws_mod: mock_ws_mod.connect = MagicMock( - return_value=AsyncMock(__aenter__=AsyncMock(return_value=ws), __aexit__=AsyncMock(return_value=False)) + return_value=AsyncMock( + __aenter__=AsyncMock(return_value=ws), + __aexit__=AsyncMock(return_value=False), + ) ) - exit_code = asyncio.run(_exec_session("wss://test/ws", "fake-token", "false")) + exit_code = asyncio.run( + _exec_session("wss://test/ws", "fake-token", "false") + ) assert exit_code == 42 @@ -253,7 +291,10 @@ def test_error_message_returns_one(self): with patch("centml.cli.shell.websockets") as mock_ws_mod: mock_ws_mod.connect = MagicMock( - return_value=AsyncMock(__aenter__=AsyncMock(return_value=ws), __aexit__=AsyncMock(return_value=False)) + return_value=AsyncMock( + __aenter__=AsyncMock(return_value=ws), + __aexit__=AsyncMock(return_value=False), + ) ) exit_code = asyncio.run(_exec_session("wss://test/ws", "fake-token", "bad")) @@ -266,21 +307,32 @@ def test_filters_noise_before_marker(self): ws = AsyncMock() messages = [ - json.dumps({"data": f"prompt$ command\n{_BEGIN_MARKER}\nreal output\n{_END_MARKER}:0\n"}), + json.dumps( + { + "data": f"prompt$ command\n{_BEGIN_MARKER}\nreal output\n{_END_MARKER}:0\n" + } + ), json.dumps({"Code": 0}), ] ws.__aiter__ = MagicMock(return_value=_async_iter_from_list(messages)) captured = [] - with patch("centml.cli.shell.websockets") as mock_ws_mod, patch("centml.cli.shell.sys") as mock_sys: + with patch("centml.cli.shell.websockets") as mock_ws_mod, patch( + "centml.cli.shell.sys" + ) as mock_sys: mock_ws_mod.connect = MagicMock( - return_value=AsyncMock(__aenter__=AsyncMock(return_value=ws), __aexit__=AsyncMock(return_value=False)) + return_value=AsyncMock( + __aenter__=AsyncMock(return_value=ws), + __aexit__=AsyncMock(return_value=False), + ) ) mock_sys.stdout.write = lambda s: captured.append(s) mock_sys.stdout.flush = MagicMock() mock_sys.stderr.write = MagicMock() - exit_code = asyncio.run(_exec_session("wss://test/ws", "fake-token", "echo test")) + exit_code = asyncio.run( + _exec_session("wss://test/ws", "fake-token", "echo test") + ) assert exit_code == 0 output = "".join(captured) @@ -297,9 +349,11 @@ class TestInteractiveSessionTerminalRestore: def test_restores_terminal_on_exception(self): from centml.cli.shell import _interactive_session - with patch("centml.cli.shell.sys") as mock_sys, patch("centml.cli.shell.termios") as mock_termios, patch( - "centml.cli.shell.tty" - ), patch("centml.cli.shell.websockets") as mock_ws_mod: + with patch("centml.cli.shell.sys") as mock_sys, patch( + "centml.cli.shell.termios" + ) as mock_termios, patch("centml.cli.shell.tty"), patch( + "centml.cli.shell.websockets" + ) as mock_ws_mod: mock_sys.stdin.fileno.return_value = 0 mock_termios.tcgetattr.return_value = ["old_settings"] @@ -386,7 +440,10 @@ def test_pod_option_forwarded(self): runner.invoke(shell, ["123", "--pod", "my-pod"]) mock_resolve.assert_called_once() - assert mock_resolve.call_args[1].get("pod_name") == "my-pod" or mock_resolve.call_args[0][2] == "my-pod" + assert ( + mock_resolve.call_args[1].get("pod_name") == "my-pod" + or mock_resolve.call_args[0][2] == "my-pod" + ) class TestExecCommand: @@ -435,3 +492,121 @@ def test_shell_option_forwarded(self): runner.invoke(exec_cmd, ["123", "--shell", "zsh", "--", "echo", "hi"]) mock_asyncio.run.assert_called_once() + + +# =========================================================================== +# pyte renderer: _char_to_sgr +# =========================================================================== + + +class TestCharToSgr: + def test_default_attrs_returns_empty(self): + from centml.cli.shell import _char_to_sgr + + char = pyte.screens.Char( + " ", "default", "default", False, False, False, False, False, False + ) + assert _char_to_sgr(char) == "" + + def test_bold_red_fg(self): + from centml.cli.shell import _char_to_sgr + + char = pyte.screens.Char( + "x", "red", "default", True, False, False, False, False, False + ) + sgr = _char_to_sgr(char) + assert "1" in sgr.split(";") + assert "31" in sgr.split(";") + + def test_bg_color(self): + from centml.cli.shell import _char_to_sgr + + char = pyte.screens.Char( + "x", "default", "blue", False, False, False, False, False, False + ) + sgr = _char_to_sgr(char) + assert "44" in sgr.split(";") + + def test_256_color_fg(self): + from centml.cli.shell import _char_to_sgr + + char = pyte.screens.Char( + "x", "ff0000", "default", False, False, False, False, False, False + ) + sgr = _char_to_sgr(char) + assert "38;2;255;0;0" in sgr + + def test_combined_attrs(self): + from centml.cli.shell import _char_to_sgr + + char = pyte.screens.Char( + "x", "green", "white", True, True, True, False, False, False + ) + sgr = _char_to_sgr(char) + parts = sgr.split(";") + assert "1" in parts # bold + assert "3" in parts # italics + assert "4" in parts # underscore + assert "32" in parts # green fg + assert "47" in parts # white bg + + +# =========================================================================== +# pyte renderer: _render_dirty +# =========================================================================== + + +class TestRenderDirty: + def test_renders_simple_text(self): + from centml.cli.shell import _render_dirty + + screen = pyte.Screen(40, 5) + stream = pyte.Stream(screen) + screen.dirty.clear() + stream.feed("hello") + buf = io.BytesIO() + _render_dirty(screen, buf) + output = buf.getvalue().decode("utf-8") + assert "hello" in output + assert len(screen.dirty) == 0 + + def test_clears_dirty_after_render(self): + from centml.cli.shell import _render_dirty + + screen = pyte.Screen(40, 5) + stream = pyte.Stream(screen) + screen.dirty.clear() + stream.feed("test") + assert len(screen.dirty) > 0 + _render_dirty(screen, io.BytesIO()) + assert len(screen.dirty) == 0 + + def test_cursor_position_in_output(self): + from centml.cli.shell import _render_dirty + + screen = pyte.Screen(40, 5) + stream = pyte.Stream(screen) + stream.feed("abc") + buf = io.BytesIO() + _render_dirty(screen, buf) + output = buf.getvalue().decode("utf-8") + # Cursor should be at row 1, col 4 (1-based: after "abc") + assert "\033[1;4H" in output + + def test_renders_only_dirty_lines(self): + from centml.cli.shell import _render_dirty + + screen = pyte.Screen(40, 5) + stream = pyte.Stream(screen) + stream.feed("line0\r\nline1\r\nline2") + # Render to clear dirty + _render_dirty(screen, io.BytesIO()) + # Now modify only line 0 + stream.feed("\033[1;1Hchanged") + buf = io.BytesIO() + _render_dirty(screen, buf) + output = buf.getvalue().decode("utf-8") + assert "changed" in output + # line1 and line2 should NOT be re-rendered + assert "line1" not in output + assert "line2" not in output From ff0e893bfa31083cadefafa7001a46a8edac9409 Mon Sep 17 00:00:00 2001 From: Honglin Cao Date: Thu, 12 Mar 2026 16:28:00 -0400 Subject: [PATCH 19/28] fix: swap rows/cols unpacking from shutil.get_terminal_size shutil.get_terminal_size() returns (columns, lines), not (rows, cols). The swapped unpacking caused pyte Screen to be created with terminal line count as width, making the display extremely narrow. --- centml/cli/shell.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/centml/cli/shell.py b/centml/cli/shell.py index c274e4a..cea3286 100644 --- a/centml/cli/shell.py +++ b/centml/cli/shell.py @@ -287,7 +287,7 @@ async def _interactive_session(ws_url, token): old_settings = termios.tcgetattr(fd) try: tty.setraw(fd) - rows, cols = shutil.get_terminal_size() + cols, rows = shutil.get_terminal_size() screen = pyte.Screen(cols, rows) stream = pyte.Stream(screen) @@ -307,7 +307,7 @@ async def _interactive_session(ws_url, token): loop = asyncio.get_running_loop() def _send_resize(): - r, c = shutil.get_terminal_size() + c, r = shutil.get_terminal_size() screen.resize(r, c) screen.dirty.update(range(r)) asyncio.ensure_future( @@ -351,7 +351,7 @@ async def _exec_session(ws_url, token, command): Does not enter raw mode -- output is pipe-friendly. Suppresses shell echo and uses markers to capture only command output. """ - rows, cols = shutil.get_terminal_size(fallback=(80, 24)) + cols, rows = shutil.get_terminal_size(fallback=(80, 24)) headers = {"Authorization": f"Bearer {token}"} async with websockets.connect( From c6d42edc6487f14b87cdeb626235564a5eb09764 Mon Sep 17 00:00:00 2001 From: Honglin Cao Date: Thu, 12 Mar 2026 16:34:46 -0400 Subject: [PATCH 20/28] fix: use alternate screen buffer to prevent scrollback in Warp terminal --- centml/cli/shell.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/centml/cli/shell.py b/centml/cli/shell.py index cea3286..fc585d0 100644 --- a/centml/cli/shell.py +++ b/centml/cli/shell.py @@ -292,8 +292,8 @@ async def _interactive_session(ws_url, token): screen = pyte.Screen(cols, rows) stream = pyte.Stream(screen) - # Clear local screen before starting. - sys.stdout.buffer.write(b"\033[2J\033[H") + # Switch to alternate screen buffer (disables scrollback) and clear. + sys.stdout.buffer.write(b"\033[?1049h\033[2J\033[H") sys.stdout.buffer.flush() headers = {"Authorization": f"Bearer {token}"} @@ -324,8 +324,8 @@ def _send_resize(): return exit_code finally: termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) - # Restore cursor visibility and attributes. - sys.stdout.buffer.write(b"\033[?25h\033[0m\n") + # Leave alternate screen buffer, restore cursor and attributes. + sys.stdout.buffer.write(b"\033[?1049l\033[?25h\033[0m") sys.stdout.buffer.flush() From ba0e9d585231a64022a52864f4abdafeb032bce4 Mon Sep 17 00:00:00 2001 From: Honglin Cao Date: Thu, 12 Mar 2026 19:11:05 -0400 Subject: [PATCH 21/28] fix: handle WebSocket ConnectionClosed to prevent hang on shell exit --- centml/cli/shell.py | 52 ++++++++++++++++++++++++++------------------- 1 file changed, 30 insertions(+), 22 deletions(-) diff --git a/centml/cli/shell.py b/centml/cli/shell.py index fc585d0..50b7e03 100644 --- a/centml/cli/shell.py +++ b/centml/cli/shell.py @@ -219,19 +219,24 @@ async def _forward_io(ws, screen, stream): async def _read_ws(): nonlocal exit_code - async for raw_msg in ws: - msg = json.loads(raw_msg) - if msg.get("data"): - # Convert bare \n to \r\n before feeding pyte, equivalent - # to xterm.js ``convertEol: true``. - stream.feed(msg["data"].replace("\n", "\r\n")) - _render_dirty(screen, sys.stdout.buffer) - elif msg.get("error"): - stream.feed(f"Error: {msg['error']}\r\n") - _render_dirty(screen, sys.stdout.buffer) - if "Code" in msg: - exit_code = msg["Code"] - return + try: + async for raw_msg in ws: + msg = json.loads(raw_msg) + if msg.get("data"): + # pyte expects \r\n; remote PTY may send bare \n + # (same as xterm.js ``convertEol: true``). + stream.feed(msg["data"].replace("\n", "\r\n")) + _render_dirty(screen, sys.stdout.buffer) + elif msg.get("error"): + stream.feed(f"Error: {msg['error']}\r\n") + _render_dirty(screen, sys.stdout.buffer) + if "Code" in msg: + exit_code = msg["Code"] + return + except websockets.ConnectionClosed: + # Backend proxy may not send a clean close frame when + # ArgoCD disconnects after the remote shell exits. + return async def _read_stdin(): read_queue = asyncio.Queue() @@ -250,16 +255,19 @@ def _on_stdin_ready(): data = await asyncio.wait_for(read_queue.get(), timeout=0.5) except asyncio.TimeoutError: continue - await ws.send( - json.dumps( - { - "operation": "stdin", - "data": data.decode("utf-8", errors="replace"), - "rows": screen.lines, - "cols": screen.columns, - } + try: + await ws.send( + json.dumps( + { + "operation": "stdin", + "data": data.decode("utf-8", errors="replace"), + "rows": screen.lines, + "cols": screen.columns, + } + ) ) - ) + except websockets.ConnectionClosed: + return finally: loop.remove_reader(stdin_fd) From f29df9490006dd32f68b63f3225f55610742a46a Mon Sep 17 00:00:00 2001 From: Honglin Cao Date: Thu, 12 Mar 2026 19:16:39 -0400 Subject: [PATCH 22/28] refactor: use pyte for exec ANSI stripping and add ConnectionClosed handling Replace regex-based _strip_ansi with pyte single-row screen for marker detection. pyte interprets all VT100/VT220 sequences including OSC and truecolor escapes that the regex could miss. --- centml/cli/shell.py | 83 ++++++++++++++++++++++++++------------------- tests/test_shell.py | 73 ++++++++++++++++++++++++++++++++------- 2 files changed, 109 insertions(+), 47 deletions(-) diff --git a/centml/cli/shell.py b/centml/cli/shell.py index 50b7e03..dacd9d1 100644 --- a/centml/cli/shell.py +++ b/centml/cli/shell.py @@ -2,7 +2,6 @@ import asyncio import json -import re import shutil import signal import sys @@ -337,8 +336,6 @@ def _send_resize(): sys.stdout.buffer.flush() -_ANSI_ESCAPE_RE = re.compile(r"\x1b\[[0-9;?]*[a-zA-Z]|\x1b\].*?\x07|\x1b\[.*?\x1b\\") - _BEGIN_MARKER = "__CENTML_BEGIN_5f3a__" _END_MARKER = "__CENTML_END_5f3a__" @@ -348,9 +345,17 @@ def _send_resize(): _PRINTF_END = r"\137\137CENTML_END_5f3a\137\137" -def _strip_ansi(text): - """Remove ANSI escape sequences from text.""" - return _ANSI_ESCAPE_RE.sub("", text) +def _pyte_extract_text(line_stream, line_screen, text): + """Feed text through a single-row pyte screen and return visible characters. + + More robust than regex ANSI stripping: pyte interprets all VT100/VT220 + sequences including OSC, cursor repositioning, and truecolor escapes. + """ + line_screen.reset() + line_stream.feed(text) + return "".join( + line_screen.buffer[0][col].data for col in range(line_screen.columns) + ).rstrip() async def _exec_session(ws_url, token, command): @@ -360,6 +365,9 @@ async def _exec_session(ws_url, token, command): Suppresses shell echo and uses markers to capture only command output. """ cols, rows = shutil.get_terminal_size(fallback=(80, 24)) + # Single-row screen for interpreting escape sequences in marker detection. + line_screen = pyte.Screen(cols, 1) + line_stream = pyte.Stream(line_screen) headers = {"Authorization": f"Bearer {token}"} async with websockets.connect( @@ -386,35 +394,40 @@ async def _exec_session(ws_url, token, command): buffer = "" is_capturing = False is_done = False - async for raw_msg in ws: - msg = json.loads(raw_msg) - if msg.get("data"): - buffer += msg["data"] - while "\n" in buffer: - line, buffer = buffer.split("\n", 1) - clean = _strip_ansi(line).rstrip("\r") - if _BEGIN_MARKER in clean: - is_capturing = True - continue - if _END_MARKER in clean: - parts = clean.split(_END_MARKER + ":") - if len(parts) > 1: - try: - exit_code = int(parts[1].strip()) - except ValueError: - pass - is_done = True - break - if is_capturing: - sys.stdout.write(line + "\n") - sys.stdout.flush() - elif msg.get("error"): - sys.stderr.write(f"Error: {msg['error']}\n") - return 1 - if is_done or "Code" in msg: - if "Code" in msg: - exit_code = msg["Code"] - break + try: + async for raw_msg in ws: + msg = json.loads(raw_msg) + if msg.get("data"): + buffer += msg["data"] + while "\n" in buffer: + line, buffer = buffer.split("\n", 1) + clean = _pyte_extract_text( + line_stream, line_screen, line.rstrip("\r") + ) + if _BEGIN_MARKER in clean: + is_capturing = True + continue + if _END_MARKER in clean: + parts = clean.split(_END_MARKER + ":") + if len(parts) > 1: + try: + exit_code = int(parts[1].strip()) + except ValueError: + pass + is_done = True + break + if is_capturing: + sys.stdout.write(line + "\n") + sys.stdout.flush() + elif msg.get("error"): + sys.stderr.write(f"Error: {msg['error']}\n") + return 1 + if is_done or "Code" in msg: + if "Code" in msg: + exit_code = msg["Code"] + break + except websockets.ConnectionClosed: + pass return exit_code diff --git a/tests/test_shell.py b/tests/test_shell.py index 9d950c6..8dfdd3f 100644 --- a/tests/test_shell.py +++ b/tests/test_shell.py @@ -52,18 +52,6 @@ def _make_status_response(revisions): # =========================================================================== -class TestStripAnsi: - def test_strips_csi_sequences(self): - from centml.cli.shell import _strip_ansi - - assert _strip_ansi("\x1b[?2004htext\x1b[0m") == "text" - - def test_preserves_plain_text(self): - from centml.cli.shell import _strip_ansi - - assert _strip_ansi("hello world") == "hello world" - - class TestBuildWsUrl: def test_https_to_wss(self): from centml.cli.shell import _build_ws_url @@ -339,6 +327,67 @@ def test_filters_noise_before_marker(self): assert "real output" in output assert "prompt$" not in output + def test_connection_closed_returns_zero(self): + """Graceful exit when server closes connection without Code message.""" + from centml.cli.shell import _exec_session + + import websockets as _ws_lib + + ws = AsyncMock() + + async def _raise_closed(): + yield json.dumps({"data": "partial\n"}) + raise _ws_lib.ConnectionClosed(None, None) + + ws.__aiter__ = MagicMock(return_value=_raise_closed()) + + with patch("centml.cli.shell.websockets") as mock_ws_mod: + mock_ws_mod.connect = MagicMock( + return_value=AsyncMock( + __aenter__=AsyncMock(return_value=ws), + __aexit__=AsyncMock(return_value=False), + ) + ) + mock_ws_mod.ConnectionClosed = _ws_lib.ConnectionClosed + + exit_code = asyncio.run( + _exec_session("wss://test/ws", "fake-token", "exit") + ) + + assert exit_code == 0 + + def test_handles_ansi_around_markers(self): + """Markers wrapped in ANSI codes are still detected via pyte.""" + from centml.cli.shell import _exec_session, _BEGIN_MARKER, _END_MARKER + + ws = AsyncMock() + # Markers surrounded by ANSI color codes. + data = f"\x1b[32m{_BEGIN_MARKER}\x1b[0m\noutput\n\x1b[32m{_END_MARKER}:0\x1b[0m\n" + messages = [json.dumps({"data": data}), json.dumps({"Code": 0})] + ws.__aiter__ = MagicMock(return_value=_async_iter_from_list(messages)) + + captured = [] + with patch("centml.cli.shell.websockets") as mock_ws_mod, patch( + "centml.cli.shell.sys" + ) as mock_sys: + mock_ws_mod.connect = MagicMock( + return_value=AsyncMock( + __aenter__=AsyncMock(return_value=ws), + __aexit__=AsyncMock(return_value=False), + ) + ) + mock_sys.stdout.write = lambda s: captured.append(s) + mock_sys.stdout.flush = MagicMock() + mock_sys.stderr.write = MagicMock() + + exit_code = asyncio.run( + _exec_session("wss://test/ws", "fake-token", "echo test") + ) + + assert exit_code == 0 + output = "".join(captured) + assert "output" in output + # =========================================================================== # _interactive_session -- terminal restore From db40469b8669b75850063dfa92bc9601add55911 Mon Sep 17 00:00:00 2001 From: Honglin Cao Date: Thu, 12 Mar 2026 19:49:29 -0400 Subject: [PATCH 23/28] fix: treat ArgoCD Code message as reconnect signal, not shell exit code --- centml/cli/shell.py | 109 ++++++++++----- tests/test_shell.py | 329 +++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 400 insertions(+), 38 deletions(-) diff --git a/centml/cli/shell.py b/centml/cli/shell.py index dacd9d1..37f04ea 100644 --- a/centml/cli/shell.py +++ b/centml/cli/shell.py @@ -197,7 +197,7 @@ def _resolve_pod(cclient, deployment_id, pod_name=None): return running_pods[0] -async def _forward_io(ws, screen, stream): +async def _forward_io(ws, screen, stream, shutdown): """Bidirectional forwarding between local stdin/stdout and WebSocket. Output flows through a pyte terminal emulator so that cursor @@ -208,16 +208,21 @@ async def _forward_io(ws, screen, stream): ws: WebSocket connection. screen: pyte.Screen instance sized to the local terminal. stream: pyte.Stream attached to *screen*. + shutdown: asyncio.Event set by signal handlers to request exit. - Returns the remote exit code. + Returns: + Tuple of (exit_code, should_reconnect). ``should_reconnect`` is True + when the server sent a ``{"Code": ...}`` reconnect signal (ArgoCD + token refresh), False on normal exit or connection close. """ loop = asyncio.get_running_loop() exit_code = 0 + should_reconnect = False stdin_fd = sys.stdin.fileno() stdin_closed = asyncio.Event() async def _read_ws(): - nonlocal exit_code + nonlocal exit_code, should_reconnect try: async for raw_msg in ws: msg = json.loads(raw_msg) @@ -229,8 +234,11 @@ async def _read_ws(): elif msg.get("error"): stream.feed(f"Error: {msg['error']}\r\n") _render_dirty(screen, sys.stdout.buffer) + # ArgoCD sends {"Code": ...} as a reconnect signal (token + # refresh), not a shell exit code. Mirror ArgoCD UI behavior: + # disconnect and reconnect with a fresh token. if "Code" in msg: - exit_code = msg["Code"] + should_reconnect = True return except websockets.ConnectionClosed: # Backend proxy may not send a clean close frame when @@ -249,7 +257,7 @@ def _on_stdin_ready(): loop.add_reader(stdin_fd, _on_stdin_ready) try: - while not stdin_closed.is_set(): + while not stdin_closed.is_set() and not shutdown.is_set(): try: data = await asyncio.wait_for(read_queue.get(), timeout=0.5) except asyncio.TimeoutError: @@ -270,7 +278,15 @@ def _on_stdin_ready(): finally: loop.remove_reader(stdin_fd) - tasks = [asyncio.create_task(_read_ws()), asyncio.create_task(_read_stdin())] + async def _watch_shutdown(): + while not shutdown.is_set(): + await asyncio.sleep(0.2) + + tasks = [ + asyncio.create_task(_read_ws()), + asyncio.create_task(_read_stdin()), + asyncio.create_task(_watch_shutdown()), + ] done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) for t in pending: t.cancel() @@ -282,13 +298,19 @@ def _on_stdin_ready(): for t in done: if t.exception() is not None: raise t.exception() - return exit_code + return (exit_code, should_reconnect) -async def _interactive_session(ws_url, token): +async def _interactive_session(ws_url, get_token_fn): """Run an interactive terminal session over WebSocket. - Enters raw mode, forwards I/O bidirectionally, and restores terminal on exit. + Enters raw mode, forwards I/O bidirectionally, and restores terminal on + exit. Reconnects automatically when the server sends a ``{"Code": ...}`` + token-refresh signal (matching ArgoCD UI behavior). + + Args: + ws_url: WebSocket URL for the terminal endpoint. + get_token_fn: Callable that returns a fresh bearer token string. """ fd = sys.stdin.fileno() old_settings = termios.tcgetattr(fd) @@ -303,32 +325,51 @@ async def _interactive_session(ws_url, token): sys.stdout.buffer.write(b"\033[?1049h\033[2J\033[H") sys.stdout.buffer.flush() - headers = {"Authorization": f"Bearer {token}"} - async with websockets.connect( - ws_url, additional_headers=headers, close_timeout=2 - ) as ws: - await ws.send( - json.dumps({"operation": "resize", "rows": rows, "cols": cols}) - ) + loop = asyncio.get_running_loop() - loop = asyncio.get_running_loop() + shutdown = asyncio.Event() + loop.add_signal_handler(signal.SIGTERM, shutdown.set) + loop.add_signal_handler(signal.SIGHUP, shutdown.set) - def _send_resize(): - c, r = shutil.get_terminal_size() - screen.resize(r, c) - screen.dirty.update(range(r)) + # _ws_ref holds the current websocket so SIGWINCH can reach it. + _ws_ref = [None] + + def _send_resize(): + c, r = shutil.get_terminal_size() + screen.resize(r, c) + screen.dirty.update(range(r)) + if _ws_ref[0] is not None: asyncio.ensure_future( - ws.send(json.dumps({"operation": "resize", "rows": r, "cols": c})) + _ws_ref[0].send( + json.dumps({"operation": "resize", "rows": r, "cols": c}) + ) ) - loop.add_signal_handler(signal.SIGWINCH, _send_resize) + loop.add_signal_handler(signal.SIGWINCH, _send_resize) - try: - exit_code = await _forward_io(ws, screen, stream) - finally: - loop.remove_signal_handler(signal.SIGWINCH) - - return exit_code + try: + while True: + token = get_token_fn() + headers = {"Authorization": f"Bearer {token}"} + async with websockets.connect( + ws_url, additional_headers=headers, close_timeout=2 + ) as ws: + _ws_ref[0] = ws + await ws.send( + json.dumps( + {"operation": "resize", "rows": rows, "cols": cols} + ) + ) + exit_code, should_reconnect = await _forward_io( + ws, screen, stream, shutdown + ) + _ws_ref[0] = None + if not should_reconnect or shutdown.is_set(): + return exit_code + finally: + loop.remove_signal_handler(signal.SIGWINCH) + loop.remove_signal_handler(signal.SIGTERM) + loop.remove_signal_handler(signal.SIGHUP) finally: termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) # Leave alternate screen buffer, restore cursor and attributes. @@ -422,9 +463,10 @@ async def _exec_session(ws_url, token, command): elif msg.get("error"): sys.stderr.write(f"Error: {msg['error']}\n") return 1 - if is_done or "Code" in msg: - if "Code" in msg: - exit_code = msg["Code"] + if "Code" in msg and not is_done: + sys.stderr.write("Connection interrupted, please retry.\n") + return 1 + if is_done: break except websockets.ConnectionClosed: pass @@ -454,8 +496,7 @@ def shell(deployment_id, pod, shell_type): ws_url = _build_ws_url( settings.CENTML_PLATFORM_API_URL, deployment_id, pod_name, shell_type ) - token = auth.get_centml_token() - exit_code = asyncio.run(_interactive_session(ws_url, token)) + exit_code = asyncio.run(_interactive_session(ws_url, auth.get_centml_token)) sys.exit(exit_code) diff --git a/tests/test_shell.py b/tests/test_shell.py index 8dfdd3f..fb05e7f 100644 --- a/tests/test_shell.py +++ b/tests/test_shell.py @@ -3,8 +3,10 @@ import asyncio import io import json +import os +import signal import urllib.parse -from unittest.mock import AsyncMock, MagicMock, patch +from unittest.mock import AsyncMock, MagicMock, call, patch import click import pyte @@ -405,6 +407,7 @@ def test_restores_terminal_on_exception(self): ) as mock_ws_mod: mock_sys.stdin.fileno.return_value = 0 + mock_sys.stdout.buffer = io.BytesIO() mock_termios.tcgetattr.return_value = ["old_settings"] mock_ws_mod.connect = MagicMock( @@ -414,8 +417,9 @@ def test_restores_terminal_on_exception(self): ) ) + get_token_fn = MagicMock(return_value="fake-token") with pytest.raises(ConnectionRefusedError): - asyncio.run(_interactive_session("wss://test/ws", "fake-token")) + asyncio.run(_interactive_session("wss://test/ws", get_token_fn)) mock_termios.tcsetattr.assert_called_once() restore_call = mock_termios.tcsetattr.call_args @@ -453,7 +457,7 @@ def test_shell_option_forwarded(self): mock_ctx.return_value.__enter__ = MagicMock(return_value=MagicMock()) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) - mock_auth.get_centml_token.return_value = "token" + mock_auth.get_centml_token = MagicMock(return_value="token") mock_settings.CENTML_PLATFORM_API_URL = "https://api.centml.com" mock_sys.stdin.isatty.return_value = True mock_asyncio.run.return_value = 0 @@ -462,6 +466,10 @@ def test_shell_option_forwarded(self): runner.invoke(shell, ["123", "--shell", "bash"]) mock_asyncio.run.assert_called_once() + # Verify get_centml_token callable is passed, not its return value + args = mock_asyncio.run.call_args[0][0].cr_frame.f_locals if hasattr(mock_asyncio.run.call_args[0][0], 'cr_frame') else None + # Just verify the call happened with the right URL pattern + assert mock_asyncio.run.called def test_pod_option_forwarded(self): from centml.cli.shell import shell @@ -480,7 +488,7 @@ def test_pod_option_forwarded(self): mock_ctx.return_value.__enter__ = MagicMock(return_value=MagicMock()) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) mock_resolve.return_value = "my-pod" - mock_auth.get_centml_token.return_value = "token" + mock_auth.get_centml_token = MagicMock(return_value="token") mock_settings.CENTML_PLATFORM_API_URL = "https://api.centml.com" mock_sys.stdin.isatty.return_value = True mock_asyncio.run.return_value = 0 @@ -659,3 +667,316 @@ def test_renders_only_dirty_lines(self): # line1 and line2 should NOT be re-rendered assert "line1" not in output assert "line2" not in output + + +# =========================================================================== +# _forward_io -- return type (exit_code, should_reconnect) +# =========================================================================== + + +class TestForwardIoReturnType: + """Tests for _forward_io returning (exit_code, should_reconnect) tuple. + + Uses a real pipe fd so ``loop.add_reader`` works without OS errors. + """ + + def _run_forward_io(self, ws, shutdown=None): + """Helper: run _forward_io with a real pipe fd standing in for stdin.""" + from centml.cli.shell import _forward_io + + import websockets as _ws_lib + + screen = pyte.Screen(80, 24) + stream = pyte.Stream(screen) + if shutdown is None: + shutdown = asyncio.Event() + + read_fd, write_fd = os.pipe() + # Close write end so the reader sees EOF quickly. + os.close(write_fd) + try: + with patch("centml.cli.shell.sys") as mock_sys, patch( + "centml.cli.shell.websockets" + ) as mock_ws_mod: + mock_sys.stdin.fileno.return_value = read_fd + mock_sys.stdin.buffer.read1 = lambda n: b"" + mock_sys.stdout.buffer = io.BytesIO() + mock_ws_mod.ConnectionClosed = _ws_lib.ConnectionClosed + + return asyncio.run(_forward_io(ws, screen, stream, shutdown)) + finally: + os.close(read_fd) + + def test_code_message_returns_reconnect(self): + """{"Code": ...} is a reconnect signal, not an exit code.""" + ws = AsyncMock() + messages = [json.dumps({"Code": 1})] + ws.__aiter__ = MagicMock(return_value=_async_iter_from_list(messages)) + + exit_code, should_reconnect = self._run_forward_io(ws) + + assert exit_code == 0 + assert should_reconnect is True + + def test_connection_closed_no_reconnect(self): + """ConnectionClosed returns (0, False) -- no reconnect.""" + import websockets as _ws_lib + + ws = AsyncMock() + + async def _raise_closed(): + raise _ws_lib.ConnectionClosed(None, None) + yield # make it an async generator # pragma: no cover + + ws.__aiter__ = MagicMock(return_value=_raise_closed()) + + exit_code, should_reconnect = self._run_forward_io(ws) + + assert exit_code == 0 + assert should_reconnect is False + + def test_normal_exit_no_reconnect(self): + """Normal data flow ending returns (exit_code, False).""" + ws = AsyncMock() + messages = [] + ws.__aiter__ = MagicMock(return_value=_async_iter_from_list(messages)) + + exit_code, should_reconnect = self._run_forward_io(ws) + + assert exit_code == 0 + assert should_reconnect is False + + def test_shutdown_event_exits(self): + """shutdown event causes _forward_io to exit.""" + from centml.cli.shell import _forward_io + + import websockets as _ws_lib + + ws = AsyncMock() + + async def _block_forever(): + await asyncio.sleep(999) + yield # pragma: no cover + + ws.__aiter__ = MagicMock(return_value=_block_forever()) + + screen = pyte.Screen(80, 24) + stream = pyte.Stream(screen) + shutdown = asyncio.Event() + + read_fd, write_fd = os.pipe() + os.close(write_fd) + try: + with patch("centml.cli.shell.sys") as mock_sys, patch( + "centml.cli.shell.websockets" + ) as mock_ws_mod: + mock_sys.stdin.fileno.return_value = read_fd + mock_sys.stdin.buffer.read1 = lambda n: b"" + mock_sys.stdout.buffer = io.BytesIO() + mock_ws_mod.ConnectionClosed = _ws_lib.ConnectionClosed + + async def _run(): + async def _set_shutdown(): + await asyncio.sleep(0.1) + shutdown.set() + + asyncio.create_task(_set_shutdown()) + return await _forward_io(ws, screen, stream, shutdown) + + exit_code, should_reconnect = asyncio.run(_run()) + finally: + os.close(read_fd) + + assert should_reconnect is False + + +# =========================================================================== +# _interactive_session -- reconnect loop +# =========================================================================== + + +class TestInteractiveSessionReconnect: + """Tests for _interactive_session reconnect on Code message.""" + + def test_reconnects_on_code_message(self): + """When _forward_io returns should_reconnect=True, session reconnects + with a fresh token from get_token_fn.""" + from centml.cli.shell import _interactive_session + + call_count = 0 + + async def _fake_forward_io(ws, screen, stream, shutdown): + nonlocal call_count + call_count += 1 + if call_count == 1: + return (0, True) # reconnect + return (0, False) # done + + get_token_fn = MagicMock(side_effect=["token-1", "token-2"]) + + with patch("centml.cli.shell.sys") as mock_sys, patch( + "centml.cli.shell.termios" + ) as mock_termios, patch("centml.cli.shell.tty"), patch( + "centml.cli.shell.websockets" + ) as mock_ws_mod, patch( + "centml.cli.shell._forward_io", side_effect=_fake_forward_io + ): + mock_sys.stdin.fileno.return_value = 0 + mock_sys.stdout.buffer = io.BytesIO() + mock_termios.tcgetattr.return_value = ["old"] + + mock_ws = AsyncMock() + mock_ws_mod.connect = MagicMock( + return_value=AsyncMock( + __aenter__=AsyncMock(return_value=mock_ws), + __aexit__=AsyncMock(return_value=False), + ) + ) + + exit_code = asyncio.run( + _interactive_session("wss://test/ws", get_token_fn) + ) + + assert exit_code == 0 + assert get_token_fn.call_count == 2 + assert mock_ws_mod.connect.call_count == 2 + + def test_sigterm_restores_terminal(self): + """SIGTERM triggers shutdown and terminal settings are restored.""" + from centml.cli.shell import _interactive_session + + signal_handlers = {} + + def _fake_add_signal_handler(sig, handler): + signal_handlers[sig] = handler + + def _fake_remove_signal_handler(sig): + signal_handlers.pop(sig, None) + + async def _fake_forward_io(ws, screen, stream, shutdown): + # Simulate SIGTERM arriving during I/O + if signal.SIGTERM in signal_handlers: + signal_handlers[signal.SIGTERM]() + return (0, False) + + get_token_fn = MagicMock(return_value="token") + + with patch("centml.cli.shell.sys") as mock_sys, patch( + "centml.cli.shell.termios" + ) as mock_termios, patch("centml.cli.shell.tty"), patch( + "centml.cli.shell.websockets" + ) as mock_ws_mod, patch( + "centml.cli.shell._forward_io", side_effect=_fake_forward_io + ): + mock_sys.stdin.fileno.return_value = 0 + mock_sys.stdout.buffer = io.BytesIO() + mock_termios.tcgetattr.return_value = ["old_settings"] + + mock_ws = AsyncMock() + mock_ws_mod.connect = MagicMock( + return_value=AsyncMock( + __aenter__=AsyncMock(return_value=mock_ws), + __aexit__=AsyncMock(return_value=False), + ) + ) + + # Patch the event loop signal handler methods + original_run = asyncio.run + + def _patched_run(coro): + loop = asyncio.new_event_loop() + loop.add_signal_handler = _fake_add_signal_handler + loop.remove_signal_handler = _fake_remove_signal_handler + try: + return loop.run_until_complete(coro) + finally: + loop.close() + + with patch("centml.cli.shell.asyncio") as mock_asyncio_mod: + mock_asyncio_mod.get_running_loop.return_value = MagicMock( + add_signal_handler=_fake_add_signal_handler, + remove_signal_handler=_fake_remove_signal_handler, + ) + mock_asyncio_mod.Event = asyncio.Event + mock_asyncio_mod.create_task = asyncio.ensure_future + + exit_code = _patched_run( + _interactive_session("wss://test/ws", get_token_fn) + ) + + mock_termios.tcsetattr.assert_called_once() + assert mock_termios.tcsetattr.call_args[0][2] == ["old_settings"] + + +# =========================================================================== +# _exec_session -- Code handling fix +# =========================================================================== + + +class TestExecSessionCodeHandling: + """Tests for _exec_session treating Code as interruption, not exit code.""" + + def test_code_without_done_returns_error(self): + """Code message before END marker means connection interrupted.""" + from centml.cli.shell import _exec_session, _BEGIN_MARKER + + ws = AsyncMock() + messages = [ + json.dumps({"data": f"{_BEGIN_MARKER}\npartial output\n"}), + json.dumps({"Code": 1}), + ] + ws.__aiter__ = MagicMock(return_value=_async_iter_from_list(messages)) + + captured_stderr = [] + with patch("centml.cli.shell.websockets") as mock_ws_mod, patch( + "centml.cli.shell.sys" + ) as mock_sys: + mock_ws_mod.connect = MagicMock( + return_value=AsyncMock( + __aenter__=AsyncMock(return_value=ws), + __aexit__=AsyncMock(return_value=False), + ) + ) + mock_sys.stdout.write = MagicMock() + mock_sys.stdout.flush = MagicMock() + mock_sys.stderr.write = lambda s: captured_stderr.append(s) + + exit_code = asyncio.run( + _exec_session("wss://test/ws", "fake-token", "long-cmd") + ) + + assert exit_code == 1 + stderr_output = "".join(captured_stderr) + assert "interrupted" in stderr_output.lower() or "retry" in stderr_output.lower() + + def test_code_after_done_uses_marker_code(self): + """When END marker is already seen, Code message is ignored.""" + from centml.cli.shell import _exec_session, _BEGIN_MARKER, _END_MARKER + + ws = AsyncMock() + messages = [ + json.dumps( + {"data": f"{_BEGIN_MARKER}\noutput\n{_END_MARKER}:7\n"} + ), + json.dumps({"Code": 99}), + ] + ws.__aiter__ = MagicMock(return_value=_async_iter_from_list(messages)) + + with patch("centml.cli.shell.websockets") as mock_ws_mod, patch( + "centml.cli.shell.sys" + ) as mock_sys: + mock_ws_mod.connect = MagicMock( + return_value=AsyncMock( + __aenter__=AsyncMock(return_value=ws), + __aexit__=AsyncMock(return_value=False), + ) + ) + mock_sys.stdout.write = MagicMock() + mock_sys.stdout.flush = MagicMock() + mock_sys.stderr.write = MagicMock() + + exit_code = asyncio.run( + _exec_session("wss://test/ws", "fake-token", "exit 7") + ) + + assert exit_code == 7 From f0b37b8c904f5a808ba9a34023f2b6ed508d6f57 Mon Sep 17 00:00:00 2001 From: Honglin Cao Date: Thu, 12 Mar 2026 19:59:46 -0400 Subject: [PATCH 24/28] fix: stop reconnecting when shell has genuinely exited If two Code signals arrive within 3 seconds, the shell has exited and the reconnect just opened a new session. Exit cleanly instead of looping forever. --- centml/cli/shell.py | 9 +++++++++ tests/test_shell.py | 49 ++++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 57 insertions(+), 1 deletion(-) diff --git a/centml/cli/shell.py b/centml/cli/shell.py index 37f04ea..bfb28e6 100644 --- a/centml/cli/shell.py +++ b/centml/cli/shell.py @@ -6,6 +6,7 @@ import signal import sys import termios +import time import tty import urllib.parse @@ -348,6 +349,7 @@ def _send_resize(): loop.add_signal_handler(signal.SIGWINCH, _send_resize) try: + last_code_time = -999.0 while True: token = get_token_fn() headers = {"Authorization": f"Bearer {token}"} @@ -366,6 +368,13 @@ def _send_resize(): _ws_ref[0] = None if not should_reconnect or shutdown.is_set(): return exit_code + # If we got two Code signals within 3 seconds the shell + # has genuinely exited (reconnect just opened a new + # session that immediately closed). Stop reconnecting. + now = time.monotonic() + if now - last_code_time < 3.0: + return 0 + last_code_time = now finally: loop.remove_signal_handler(signal.SIGWINCH) loop.remove_signal_handler(signal.SIGTERM) diff --git a/tests/test_shell.py b/tests/test_shell.py index fb05e7f..0a2f2cc 100644 --- a/tests/test_shell.py +++ b/tests/test_shell.py @@ -820,11 +820,16 @@ async def _fake_forward_io(ws, screen, stream, shutdown): "centml.cli.shell.websockets" ) as mock_ws_mod, patch( "centml.cli.shell._forward_io", side_effect=_fake_forward_io - ): + ), patch( + "centml.cli.shell.time" + ) as mock_time: mock_sys.stdin.fileno.return_value = 0 mock_sys.stdout.buffer = io.BytesIO() mock_termios.tcgetattr.return_value = ["old"] + # Simulate > 3s gap so reconnect is allowed + mock_time.monotonic = MagicMock(side_effect=[0.0, 10.0]) + mock_ws = AsyncMock() mock_ws_mod.connect = MagicMock( return_value=AsyncMock( @@ -841,6 +846,48 @@ async def _fake_forward_io(ws, screen, stream, shutdown): assert get_token_fn.call_count == 2 assert mock_ws_mod.connect.call_count == 2 + def test_stops_reconnecting_on_rapid_code(self): + """Two Code signals within 3 seconds means the shell has genuinely + exited -- stop reconnecting instead of looping forever.""" + from centml.cli.shell import _interactive_session + + async def _always_reconnect(ws, screen, stream, shutdown): + return (0, True) + + get_token_fn = MagicMock(side_effect=["token-1", "token-2"]) + + with patch("centml.cli.shell.sys") as mock_sys, patch( + "centml.cli.shell.termios" + ) as mock_termios, patch("centml.cli.shell.tty"), patch( + "centml.cli.shell.websockets" + ) as mock_ws_mod, patch( + "centml.cli.shell._forward_io", side_effect=_always_reconnect + ), patch( + "centml.cli.shell.time" + ) as mock_time: + mock_sys.stdin.fileno.return_value = 0 + mock_sys.stdout.buffer = io.BytesIO() + mock_termios.tcgetattr.return_value = ["old"] + + # Simulate two rapid Code signals (0.5s apart < 3s threshold) + mock_time.monotonic = MagicMock(side_effect=[1.0, 1.5]) + + mock_ws = AsyncMock() + mock_ws_mod.connect = MagicMock( + return_value=AsyncMock( + __aenter__=AsyncMock(return_value=mock_ws), + __aexit__=AsyncMock(return_value=False), + ) + ) + + exit_code = asyncio.run( + _interactive_session("wss://test/ws", get_token_fn) + ) + + assert exit_code == 0 + # Connected twice: first attempt + one reconnect, then stopped + assert mock_ws_mod.connect.call_count == 2 + def test_sigterm_restores_terminal(self): """SIGTERM triggers shutdown and terminal settings are restored.""" from centml.cli.shell import _interactive_session From 42658393a65159ecdae97b4188675eb6ca715ba0 Mon Sep 17 00:00:00 2001 From: Honglin Cao Date: Thu, 12 Mar 2026 20:04:04 -0400 Subject: [PATCH 25/28] chore: add debug file logging to shell and exec for exit hang diagnosis Logs to /tmp/centml_shell_debug.log (overridable via CENTML_SHELL_DEBUG_LOG env var). Traces every WS message, stdin event, task lifecycle, reconnect decision, and connection close. --- centml/cli/shell.py | 107 ++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 98 insertions(+), 9 deletions(-) diff --git a/centml/cli/shell.py b/centml/cli/shell.py index bfb28e6..40d8537 100644 --- a/centml/cli/shell.py +++ b/centml/cli/shell.py @@ -2,6 +2,8 @@ import asyncio import json +import logging +import os import shutil import signal import sys @@ -19,6 +21,22 @@ from centml.sdk.api import get_centml_client from centml.sdk.config import settings +_log = logging.getLogger("centml.cli.shell") + + +def _setup_debug_log(): + """Configure file-based debug logging (stdout unusable in raw mode).""" + log_path = os.environ.get( + "CENTML_SHELL_DEBUG_LOG", "/tmp/centml_shell_debug.log" + ) + handler = logging.FileHandler(log_path, mode="w") + handler.setFormatter( + logging.Formatter("%(asctime)s.%(msecs)03d %(message)s", datefmt="%H:%M:%S") + ) + _log.addHandler(handler) + _log.setLevel(logging.DEBUG) + _log.debug("=== shell debug log started (pid=%d) ===", os.getpid()) + # --------------------------------------------------------------------------- # pyte screen renderer -- converts pyte's in-memory screen buffer to ANSI @@ -224,29 +242,50 @@ async def _forward_io(ws, screen, stream, shutdown): async def _read_ws(): nonlocal exit_code, should_reconnect + _log.debug("[read_ws] started") + msg_count = 0 try: async for raw_msg in ws: + msg_count += 1 msg = json.loads(raw_msg) + keys = list(msg.keys()) + data_snippet = "" + if msg.get("data"): + data_snippet = repr(msg["data"][:120]) + _log.debug( + "[read_ws] msg#%d keys=%s data=%s", + msg_count, keys, data_snippet, + ) if msg.get("data"): # pyte expects \r\n; remote PTY may send bare \n # (same as xterm.js ``convertEol: true``). stream.feed(msg["data"].replace("\n", "\r\n")) _render_dirty(screen, sys.stdout.buffer) elif msg.get("error"): + _log.debug("[read_ws] error: %s", msg["error"]) stream.feed(f"Error: {msg['error']}\r\n") _render_dirty(screen, sys.stdout.buffer) # ArgoCD sends {"Code": ...} as a reconnect signal (token # refresh), not a shell exit code. Mirror ArgoCD UI behavior: # disconnect and reconnect with a fresh token. if "Code" in msg: + _log.debug( + "[read_ws] got Code=%s, setting should_reconnect=True", + msg["Code"], + ) should_reconnect = True return - except websockets.ConnectionClosed: + _log.debug("[read_ws] ws iterator exhausted after %d msgs", msg_count) + except websockets.ConnectionClosed as exc: + _log.debug( + "[read_ws] ConnectionClosed after %d msgs: %s", msg_count, exc, + ) # Backend proxy may not send a clean close frame when # ArgoCD disconnects after the remote shell exits. return async def _read_stdin(): + _log.debug("[read_stdin] started") read_queue = asyncio.Queue() def _on_stdin_ready(): @@ -254,6 +293,7 @@ def _on_stdin_ready(): if data: read_queue.put_nowait(data) else: + _log.debug("[read_stdin] stdin EOF") stdin_closed.set() loop.add_reader(stdin_fd, _on_stdin_ready) @@ -263,6 +303,7 @@ def _on_stdin_ready(): data = await asyncio.wait_for(read_queue.get(), timeout=0.5) except asyncio.TimeoutError: continue + _log.debug("[read_stdin] sending %d bytes: %s", len(data), repr(data[:40])) try: await ws.send( json.dumps( @@ -275,7 +316,12 @@ def _on_stdin_ready(): ) ) except websockets.ConnectionClosed: + _log.debug("[read_stdin] ws closed on send") return + _log.debug( + "[read_stdin] loop exited: stdin_closed=%s shutdown=%s", + stdin_closed.is_set(), shutdown.is_set(), + ) finally: loop.remove_reader(stdin_fd) @@ -283,12 +329,18 @@ async def _watch_shutdown(): while not shutdown.is_set(): await asyncio.sleep(0.2) - tasks = [ - asyncio.create_task(_read_ws()), - asyncio.create_task(_read_stdin()), - asyncio.create_task(_watch_shutdown()), - ] + _log.debug("[forward_io] creating tasks") + task_ws = asyncio.create_task(_read_ws()) + task_stdin = asyncio.create_task(_read_stdin()) + task_shutdown = asyncio.create_task(_watch_shutdown()) + tasks = [task_ws, task_stdin, task_shutdown] + task_names = {id(task_ws): "read_ws", id(task_stdin): "read_stdin", id(task_shutdown): "watch_shutdown"} + done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED) + done_names = [task_names[id(t)] for t in done] + pending_names = [task_names[id(t)] for t in pending] + _log.debug("[forward_io] first completed: done=%s pending=%s", done_names, pending_names) + for t in pending: t.cancel() for t in pending: @@ -298,7 +350,12 @@ async def _watch_shutdown(): pass for t in done: if t.exception() is not None: + _log.debug("[forward_io] task exception: %s", t.exception()) raise t.exception() + _log.debug( + "[forward_io] returning exit_code=%d should_reconnect=%s", + exit_code, should_reconnect, + ) return (exit_code, should_reconnect) @@ -350,12 +407,17 @@ def _send_resize(): try: last_code_time = -999.0 + attempt = 0 while True: + attempt += 1 + _log.debug("[session] attempt #%d, fetching token", attempt) token = get_token_fn() headers = {"Authorization": f"Bearer {token}"} + _log.debug("[session] connecting to %s", ws_url) async with websockets.connect( ws_url, additional_headers=headers, close_timeout=2 ) as ws: + _log.debug("[session] connected, sending resize %dx%d", cols, rows) _ws_ref[0] = ws await ws.send( json.dumps( @@ -366,15 +428,27 @@ def _send_resize(): ws, screen, stream, shutdown ) _ws_ref[0] = None + _log.debug( + "[session] forward_io returned exit_code=%d reconnect=%s shutdown=%s", + exit_code, should_reconnect, shutdown.is_set(), + ) if not should_reconnect or shutdown.is_set(): + _log.debug("[session] exiting with code %d", exit_code) return exit_code # If we got two Code signals within 3 seconds the shell # has genuinely exited (reconnect just opened a new # session that immediately closed). Stop reconnecting. now = time.monotonic() - if now - last_code_time < 3.0: + gap = now - last_code_time + _log.debug( + "[session] reconnect: gap=%.2fs last_code_time=%.2f now=%.2f", + gap, last_code_time, now, + ) + if gap < 3.0: + _log.debug("[session] rapid Code, exiting with 0") return 0 last_code_time = now + _log.debug("[session] will reconnect") finally: loop.remove_signal_handler(signal.SIGWINCH) loop.remove_signal_handler(signal.SIGTERM) @@ -444,9 +518,16 @@ async def _exec_session(ws_url, token, command): buffer = "" is_capturing = False is_done = False + msg_count = 0 try: async for raw_msg in ws: + msg_count += 1 msg = json.loads(raw_msg) + keys = list(msg.keys()) + _log.debug( + "[exec] msg#%d keys=%s data_len=%d", + msg_count, keys, len(msg.get("data", "")), + ) if msg.get("data"): buffer += msg["data"] while "\n" in buffer: @@ -455,6 +536,7 @@ async def _exec_session(ws_url, token, command): line_stream, line_screen, line.rstrip("\r") ) if _BEGIN_MARKER in clean: + _log.debug("[exec] BEGIN marker found") is_capturing = True continue if _END_MARKER in clean: @@ -464,21 +546,26 @@ async def _exec_session(ws_url, token, command): exit_code = int(parts[1].strip()) except ValueError: pass + _log.debug("[exec] END marker, exit_code=%d", exit_code) is_done = True break if is_capturing: sys.stdout.write(line + "\n") sys.stdout.flush() elif msg.get("error"): + _log.debug("[exec] error: %s", msg["error"]) sys.stderr.write(f"Error: {msg['error']}\n") return 1 if "Code" in msg and not is_done: + _log.debug("[exec] Code=%s before done, returning 1", msg["Code"]) sys.stderr.write("Connection interrupted, please retry.\n") return 1 if is_done: + _log.debug("[exec] done, breaking") break - except websockets.ConnectionClosed: - pass + except websockets.ConnectionClosed as exc: + _log.debug("[exec] ConnectionClosed: %s", exc) + _log.debug("[exec] returning exit_code=%d", exit_code) return exit_code @@ -496,6 +583,7 @@ async def _exec_session(ws_url, token, command): ) @handle_exception def shell(deployment_id, pod, shell_type): + _setup_debug_log() if not sys.stdin.isatty(): raise click.ClickException("Interactive shell requires a terminal (TTY)") @@ -525,6 +613,7 @@ def shell(deployment_id, pod, shell_type): ) @handle_exception def exec_cmd(deployment_id, command, pod, shell_type): + _setup_debug_log() with get_centml_client() as cclient: pod_name = _resolve_pod(cclient, deployment_id, pod) From 27977afded427e896b90de1eed737b91dcd9d7b1 Mon Sep 17 00:00:00 2001 From: Honglin Cao Date: Thu, 12 Mar 2026 20:15:24 -0400 Subject: [PATCH 26/28] fix: detect shell exit via idle timeout instead of Code message The platform API proxy never forwards ArgoCD Code messages and does not close the WebSocket when the remote shell exits. Replace the Code/reconnect logic with exit echo detection: when the server echoes back "exit\r\n", arm a 2-second idle timeout on ws.recv(). If no more data arrives, the shell has exited -- break out cleanly. Also removes Code handling from _exec_session (markers already work). --- centml/cli/shell.py | 168 +++++++++++----------------- tests/test_shell.py | 267 ++++++++------------------------------------ 2 files changed, 108 insertions(+), 327 deletions(-) diff --git a/centml/cli/shell.py b/centml/cli/shell.py index 40d8537..3e1d8f7 100644 --- a/centml/cli/shell.py +++ b/centml/cli/shell.py @@ -8,7 +8,6 @@ import signal import sys import termios -import time import tty import urllib.parse @@ -216,6 +215,9 @@ def _resolve_pod(cclient, deployment_id, pod_name=None): return running_pods[0] +_EXIT_IDLE_TIMEOUT = 2.0 + + async def _forward_io(ws, screen, stream, shutdown): """Bidirectional forwarding between local stdin/stdout and WebSocket. @@ -223,6 +225,10 @@ async def _forward_io(ws, screen, stream, shutdown): addressing, line wrapping, and colors are rendered correctly regardless of the remote PTY dimensions. + The platform API proxy does not close the WebSocket when the remote + shell exits, so we detect the ``exit`` echo and use a short idle + timeout to break out. + Args: ws: WebSocket connection. screen: pyte.Screen instance sized to the local terminal. @@ -230,58 +236,53 @@ async def _forward_io(ws, screen, stream, shutdown): shutdown: asyncio.Event set by signal handlers to request exit. Returns: - Tuple of (exit_code, should_reconnect). ``should_reconnect`` is True - when the server sent a ``{"Code": ...}`` reconnect signal (ArgoCD - token refresh), False on normal exit or connection close. + The exit code (always 0 for interactive sessions). """ loop = asyncio.get_running_loop() - exit_code = 0 - should_reconnect = False stdin_fd = sys.stdin.fileno() stdin_closed = asyncio.Event() async def _read_ws(): - nonlocal exit_code, should_reconnect _log.debug("[read_ws] started") msg_count = 0 + recv_timeout = None try: - async for raw_msg in ws: + while True: + try: + raw_msg = await asyncio.wait_for(ws.recv(), timeout=recv_timeout) + except asyncio.TimeoutError: + _log.debug( + "[read_ws] idle timeout (%.1fs) after %d msgs -- shell exited", + recv_timeout, msg_count, + ) + return msg_count += 1 msg = json.loads(raw_msg) keys = list(msg.keys()) - data_snippet = "" - if msg.get("data"): - data_snippet = repr(msg["data"][:120]) + data = msg.get("data", "") + data_snippet = repr(data[:120]) if data else "" _log.debug( "[read_ws] msg#%d keys=%s data=%s", msg_count, keys, data_snippet, ) - if msg.get("data"): - # pyte expects \r\n; remote PTY may send bare \n - # (same as xterm.js ``convertEol: true``). - stream.feed(msg["data"].replace("\n", "\r\n")) + if data: + stream.feed(data.replace("\n", "\r\n")) _render_dirty(screen, sys.stdout.buffer) + # Detect shell exit echo. When the user runs ``exit``, + # the remote PTY echoes ``exit\r\n`` as the final + # output. Switch to a short recv timeout so we exit + # cleanly instead of hanging on the open WebSocket. + if "exit\r\n" in data: + _log.debug("[read_ws] detected exit echo, arming idle timeout") + recv_timeout = _EXIT_IDLE_TIMEOUT elif msg.get("error"): _log.debug("[read_ws] error: %s", msg["error"]) stream.feed(f"Error: {msg['error']}\r\n") _render_dirty(screen, sys.stdout.buffer) - # ArgoCD sends {"Code": ...} as a reconnect signal (token - # refresh), not a shell exit code. Mirror ArgoCD UI behavior: - # disconnect and reconnect with a fresh token. - if "Code" in msg: - _log.debug( - "[read_ws] got Code=%s, setting should_reconnect=True", - msg["Code"], - ) - should_reconnect = True - return - _log.debug("[read_ws] ws iterator exhausted after %d msgs", msg_count) except websockets.ConnectionClosed as exc: _log.debug( "[read_ws] ConnectionClosed after %d msgs: %s", msg_count, exc, ) - # Backend proxy may not send a clean close frame when - # ArgoCD disconnects after the remote shell exits. return async def _read_stdin(): @@ -352,23 +353,16 @@ async def _watch_shutdown(): if t.exception() is not None: _log.debug("[forward_io] task exception: %s", t.exception()) raise t.exception() - _log.debug( - "[forward_io] returning exit_code=%d should_reconnect=%s", - exit_code, should_reconnect, - ) - return (exit_code, should_reconnect) + _log.debug("[forward_io] returning exit_code=0") + return 0 -async def _interactive_session(ws_url, get_token_fn): +async def _interactive_session(ws_url, token): """Run an interactive terminal session over WebSocket. - Enters raw mode, forwards I/O bidirectionally, and restores terminal on - exit. Reconnects automatically when the server sends a ``{"Code": ...}`` - token-refresh signal (matching ArgoCD UI behavior). - - Args: - ws_url: WebSocket URL for the terminal endpoint. - get_token_fn: Callable that returns a fresh bearer token string. + Enters raw mode, forwards I/O bidirectionally, and restores terminal + on exit. SIGTERM and SIGHUP are caught to ensure terminal settings + are always restored. """ fd = sys.stdin.fileno() old_settings = termios.tcgetattr(fd) @@ -389,71 +383,40 @@ async def _interactive_session(ws_url, get_token_fn): loop.add_signal_handler(signal.SIGTERM, shutdown.set) loop.add_signal_handler(signal.SIGHUP, shutdown.set) - # _ws_ref holds the current websocket so SIGWINCH can reach it. - _ws_ref = [None] - - def _send_resize(): - c, r = shutil.get_terminal_size() - screen.resize(r, c) - screen.dirty.update(range(r)) - if _ws_ref[0] is not None: + headers = {"Authorization": f"Bearer {token}"} + _log.debug("[session] connecting to %s", ws_url) + async with websockets.connect( + ws_url, additional_headers=headers, close_timeout=2 + ) as ws: + _log.debug("[session] connected, sending resize %dx%d", cols, rows) + + def _send_resize(): + c, r = shutil.get_terminal_size() + screen.resize(r, c) + screen.dirty.update(range(r)) asyncio.ensure_future( - _ws_ref[0].send( + ws.send( json.dumps({"operation": "resize", "rows": r, "cols": c}) ) ) - loop.add_signal_handler(signal.SIGWINCH, _send_resize) + loop.add_signal_handler(signal.SIGWINCH, _send_resize) - try: - last_code_time = -999.0 - attempt = 0 - while True: - attempt += 1 - _log.debug("[session] attempt #%d, fetching token", attempt) - token = get_token_fn() - headers = {"Authorization": f"Bearer {token}"} - _log.debug("[session] connecting to %s", ws_url) - async with websockets.connect( - ws_url, additional_headers=headers, close_timeout=2 - ) as ws: - _log.debug("[session] connected, sending resize %dx%d", cols, rows) - _ws_ref[0] = ws - await ws.send( - json.dumps( - {"operation": "resize", "rows": rows, "cols": cols} - ) - ) - exit_code, should_reconnect = await _forward_io( - ws, screen, stream, shutdown - ) - _ws_ref[0] = None - _log.debug( - "[session] forward_io returned exit_code=%d reconnect=%s shutdown=%s", - exit_code, should_reconnect, shutdown.is_set(), - ) - if not should_reconnect or shutdown.is_set(): - _log.debug("[session] exiting with code %d", exit_code) - return exit_code - # If we got two Code signals within 3 seconds the shell - # has genuinely exited (reconnect just opened a new - # session that immediately closed). Stop reconnecting. - now = time.monotonic() - gap = now - last_code_time - _log.debug( - "[session] reconnect: gap=%.2fs last_code_time=%.2f now=%.2f", - gap, last_code_time, now, - ) - if gap < 3.0: - _log.debug("[session] rapid Code, exiting with 0") - return 0 - last_code_time = now - _log.debug("[session] will reconnect") - finally: - loop.remove_signal_handler(signal.SIGWINCH) - loop.remove_signal_handler(signal.SIGTERM) - loop.remove_signal_handler(signal.SIGHUP) + await ws.send( + json.dumps( + {"operation": "resize", "rows": rows, "cols": cols} + ) + ) + try: + exit_code = await _forward_io(ws, screen, stream, shutdown) + finally: + loop.remove_signal_handler(signal.SIGWINCH) + + _log.debug("[session] exiting with code %d", exit_code) + return exit_code finally: + loop.remove_signal_handler(signal.SIGTERM) + loop.remove_signal_handler(signal.SIGHUP) termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) # Leave alternate screen buffer, restore cursor and attributes. sys.stdout.buffer.write(b"\033[?1049l\033[?25h\033[0m") @@ -556,10 +519,6 @@ async def _exec_session(ws_url, token, command): _log.debug("[exec] error: %s", msg["error"]) sys.stderr.write(f"Error: {msg['error']}\n") return 1 - if "Code" in msg and not is_done: - _log.debug("[exec] Code=%s before done, returning 1", msg["Code"]) - sys.stderr.write("Connection interrupted, please retry.\n") - return 1 if is_done: _log.debug("[exec] done, breaking") break @@ -593,7 +552,8 @@ def shell(deployment_id, pod, shell_type): ws_url = _build_ws_url( settings.CENTML_PLATFORM_API_URL, deployment_id, pod_name, shell_type ) - exit_code = asyncio.run(_interactive_session(ws_url, auth.get_centml_token)) + token = auth.get_centml_token() + exit_code = asyncio.run(_interactive_session(ws_url, token)) sys.exit(exit_code) diff --git a/tests/test_shell.py b/tests/test_shell.py index 0a2f2cc..203c82c 100644 --- a/tests/test_shell.py +++ b/tests/test_shell.py @@ -6,7 +6,7 @@ import os import signal import urllib.parse -from unittest.mock import AsyncMock, MagicMock, call, patch +from unittest.mock import AsyncMock, MagicMock, patch import click import pyte @@ -417,9 +417,8 @@ def test_restores_terminal_on_exception(self): ) ) - get_token_fn = MagicMock(return_value="fake-token") with pytest.raises(ConnectionRefusedError): - asyncio.run(_interactive_session("wss://test/ws", get_token_fn)) + asyncio.run(_interactive_session("wss://test/ws", "fake-token")) mock_termios.tcsetattr.assert_called_once() restore_call = mock_termios.tcsetattr.call_args @@ -457,7 +456,7 @@ def test_shell_option_forwarded(self): mock_ctx.return_value.__enter__ = MagicMock(return_value=MagicMock()) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) - mock_auth.get_centml_token = MagicMock(return_value="token") + mock_auth.get_centml_token.return_value = "token" mock_settings.CENTML_PLATFORM_API_URL = "https://api.centml.com" mock_sys.stdin.isatty.return_value = True mock_asyncio.run.return_value = 0 @@ -466,10 +465,6 @@ def test_shell_option_forwarded(self): runner.invoke(shell, ["123", "--shell", "bash"]) mock_asyncio.run.assert_called_once() - # Verify get_centml_token callable is passed, not its return value - args = mock_asyncio.run.call_args[0][0].cr_frame.f_locals if hasattr(mock_asyncio.run.call_args[0][0], 'cr_frame') else None - # Just verify the call happened with the right URL pattern - assert mock_asyncio.run.called def test_pod_option_forwarded(self): from centml.cli.shell import shell @@ -488,7 +483,7 @@ def test_pod_option_forwarded(self): mock_ctx.return_value.__enter__ = MagicMock(return_value=MagicMock()) mock_ctx.return_value.__exit__ = MagicMock(return_value=False) mock_resolve.return_value = "my-pod" - mock_auth.get_centml_token = MagicMock(return_value="token") + mock_auth.get_centml_token.return_value = "token" mock_settings.CENTML_PLATFORM_API_URL = "https://api.centml.com" mock_sys.stdin.isatty.return_value = True mock_asyncio.run.return_value = 0 @@ -670,12 +665,12 @@ def test_renders_only_dirty_lines(self): # =========================================================================== -# _forward_io -- return type (exit_code, should_reconnect) +# _forward_io -- exit detection and shutdown # =========================================================================== -class TestForwardIoReturnType: - """Tests for _forward_io returning (exit_code, should_reconnect) tuple. +class TestForwardIo: + """Tests for _forward_io exit detection via idle timeout. Uses a real pipe fd so ``loop.add_reader`` works without OS errors. """ @@ -692,7 +687,6 @@ def _run_forward_io(self, ws, shutdown=None): shutdown = asyncio.Event() read_fd, write_fd = os.pipe() - # Close write end so the reader sees EOF quickly. os.close(write_fd) try: with patch("centml.cli.shell.sys") as mock_sys, patch( @@ -707,44 +701,44 @@ def _run_forward_io(self, ws, shutdown=None): finally: os.close(read_fd) - def test_code_message_returns_reconnect(self): - """{"Code": ...} is a reconnect signal, not an exit code.""" - ws = AsyncMock() - messages = [json.dumps({"Code": 1})] - ws.__aiter__ = MagicMock(return_value=_async_iter_from_list(messages)) + def test_connection_closed_returns_zero(self): + """ConnectionClosed returns 0.""" + import websockets as _ws_lib - exit_code, should_reconnect = self._run_forward_io(ws) + ws = AsyncMock() + ws.recv = AsyncMock(side_effect=_ws_lib.ConnectionClosed(None, None)) - assert exit_code == 0 - assert should_reconnect is True + assert self._run_forward_io(ws) == 0 - def test_connection_closed_no_reconnect(self): - """ConnectionClosed returns (0, False) -- no reconnect.""" + def test_exit_echo_triggers_idle_timeout(self): + """Data containing 'exit\\r\\n' arms an idle timeout that ends the session.""" import websockets as _ws_lib ws = AsyncMock() + # First recv returns exit echo, second recv times out. + ws.recv = AsyncMock( + side_effect=[ + json.dumps({"data": "\r\n\x1b[?2004l\rexit\r\n"}), + asyncio.TimeoutError(), + ] + ) - async def _raise_closed(): - raise _ws_lib.ConnectionClosed(None, None) - yield # make it an async generator # pragma: no cover - - ws.__aiter__ = MagicMock(return_value=_raise_closed()) - - exit_code, should_reconnect = self._run_forward_io(ws) + with patch("centml.cli.shell._EXIT_IDLE_TIMEOUT", 0.01): + assert self._run_forward_io(ws) == 0 - assert exit_code == 0 - assert should_reconnect is False + def test_normal_data_no_early_exit(self): + """Data without 'exit\\r\\n' does not trigger early exit.""" + import websockets as _ws_lib - def test_normal_exit_no_reconnect(self): - """Normal data flow ending returns (exit_code, False).""" ws = AsyncMock() - messages = [] - ws.__aiter__ = MagicMock(return_value=_async_iter_from_list(messages)) - - exit_code, should_reconnect = self._run_forward_io(ws) + ws.recv = AsyncMock( + side_effect=[ + json.dumps({"data": "hello\r\n"}), + _ws_lib.ConnectionClosed(None, None), + ] + ) - assert exit_code == 0 - assert should_reconnect is False + assert self._run_forward_io(ws) == 0 def test_shutdown_event_exits(self): """shutdown event causes _forward_io to exit.""" @@ -754,11 +748,11 @@ def test_shutdown_event_exits(self): ws = AsyncMock() - async def _block_forever(): + # recv that blocks until cancelled (simulates open WS with no data) + async def _block_recv(): await asyncio.sleep(999) - yield # pragma: no cover - ws.__aiter__ = MagicMock(return_value=_block_forever()) + ws.recv = _block_recv screen = pyte.Screen(80, 24) stream = pyte.Stream(screen) @@ -783,113 +777,20 @@ async def _set_shutdown(): asyncio.create_task(_set_shutdown()) return await _forward_io(ws, screen, stream, shutdown) - exit_code, should_reconnect = asyncio.run(_run()) + assert asyncio.run(_run()) == 0 finally: os.close(read_fd) - assert should_reconnect is False - # =========================================================================== -# _interactive_session -- reconnect loop +# _interactive_session -- signal handling # =========================================================================== -class TestInteractiveSessionReconnect: - """Tests for _interactive_session reconnect on Code message.""" - - def test_reconnects_on_code_message(self): - """When _forward_io returns should_reconnect=True, session reconnects - with a fresh token from get_token_fn.""" - from centml.cli.shell import _interactive_session - - call_count = 0 - - async def _fake_forward_io(ws, screen, stream, shutdown): - nonlocal call_count - call_count += 1 - if call_count == 1: - return (0, True) # reconnect - return (0, False) # done - - get_token_fn = MagicMock(side_effect=["token-1", "token-2"]) - - with patch("centml.cli.shell.sys") as mock_sys, patch( - "centml.cli.shell.termios" - ) as mock_termios, patch("centml.cli.shell.tty"), patch( - "centml.cli.shell.websockets" - ) as mock_ws_mod, patch( - "centml.cli.shell._forward_io", side_effect=_fake_forward_io - ), patch( - "centml.cli.shell.time" - ) as mock_time: - mock_sys.stdin.fileno.return_value = 0 - mock_sys.stdout.buffer = io.BytesIO() - mock_termios.tcgetattr.return_value = ["old"] - - # Simulate > 3s gap so reconnect is allowed - mock_time.monotonic = MagicMock(side_effect=[0.0, 10.0]) - - mock_ws = AsyncMock() - mock_ws_mod.connect = MagicMock( - return_value=AsyncMock( - __aenter__=AsyncMock(return_value=mock_ws), - __aexit__=AsyncMock(return_value=False), - ) - ) - - exit_code = asyncio.run( - _interactive_session("wss://test/ws", get_token_fn) - ) - - assert exit_code == 0 - assert get_token_fn.call_count == 2 - assert mock_ws_mod.connect.call_count == 2 - - def test_stops_reconnecting_on_rapid_code(self): - """Two Code signals within 3 seconds means the shell has genuinely - exited -- stop reconnecting instead of looping forever.""" - from centml.cli.shell import _interactive_session - - async def _always_reconnect(ws, screen, stream, shutdown): - return (0, True) - - get_token_fn = MagicMock(side_effect=["token-1", "token-2"]) - - with patch("centml.cli.shell.sys") as mock_sys, patch( - "centml.cli.shell.termios" - ) as mock_termios, patch("centml.cli.shell.tty"), patch( - "centml.cli.shell.websockets" - ) as mock_ws_mod, patch( - "centml.cli.shell._forward_io", side_effect=_always_reconnect - ), patch( - "centml.cli.shell.time" - ) as mock_time: - mock_sys.stdin.fileno.return_value = 0 - mock_sys.stdout.buffer = io.BytesIO() - mock_termios.tcgetattr.return_value = ["old"] - - # Simulate two rapid Code signals (0.5s apart < 3s threshold) - mock_time.monotonic = MagicMock(side_effect=[1.0, 1.5]) - - mock_ws = AsyncMock() - mock_ws_mod.connect = MagicMock( - return_value=AsyncMock( - __aenter__=AsyncMock(return_value=mock_ws), - __aexit__=AsyncMock(return_value=False), - ) - ) - - exit_code = asyncio.run( - _interactive_session("wss://test/ws", get_token_fn) - ) - - assert exit_code == 0 - # Connected twice: first attempt + one reconnect, then stopped - assert mock_ws_mod.connect.call_count == 2 +class TestInteractiveSessionSignals: + """Tests for SIGTERM/SIGHUP restoring terminal settings.""" def test_sigterm_restores_terminal(self): - """SIGTERM triggers shutdown and terminal settings are restored.""" from centml.cli.shell import _interactive_session signal_handlers = {} @@ -901,12 +802,9 @@ def _fake_remove_signal_handler(sig): signal_handlers.pop(sig, None) async def _fake_forward_io(ws, screen, stream, shutdown): - # Simulate SIGTERM arriving during I/O if signal.SIGTERM in signal_handlers: signal_handlers[signal.SIGTERM]() - return (0, False) - - get_token_fn = MagicMock(return_value="token") + return 0 with patch("centml.cli.shell.sys") as mock_sys, patch( "centml.cli.shell.termios" @@ -927,9 +825,6 @@ async def _fake_forward_io(ws, screen, stream, shutdown): ) ) - # Patch the event loop signal handler methods - original_run = asyncio.run - def _patched_run(coro): loop = asyncio.new_event_loop() loop.add_signal_handler = _fake_add_signal_handler @@ -947,83 +842,9 @@ def _patched_run(coro): mock_asyncio_mod.Event = asyncio.Event mock_asyncio_mod.create_task = asyncio.ensure_future - exit_code = _patched_run( - _interactive_session("wss://test/ws", get_token_fn) + _patched_run( + _interactive_session("wss://test/ws", "fake-token") ) mock_termios.tcsetattr.assert_called_once() assert mock_termios.tcsetattr.call_args[0][2] == ["old_settings"] - - -# =========================================================================== -# _exec_session -- Code handling fix -# =========================================================================== - - -class TestExecSessionCodeHandling: - """Tests for _exec_session treating Code as interruption, not exit code.""" - - def test_code_without_done_returns_error(self): - """Code message before END marker means connection interrupted.""" - from centml.cli.shell import _exec_session, _BEGIN_MARKER - - ws = AsyncMock() - messages = [ - json.dumps({"data": f"{_BEGIN_MARKER}\npartial output\n"}), - json.dumps({"Code": 1}), - ] - ws.__aiter__ = MagicMock(return_value=_async_iter_from_list(messages)) - - captured_stderr = [] - with patch("centml.cli.shell.websockets") as mock_ws_mod, patch( - "centml.cli.shell.sys" - ) as mock_sys: - mock_ws_mod.connect = MagicMock( - return_value=AsyncMock( - __aenter__=AsyncMock(return_value=ws), - __aexit__=AsyncMock(return_value=False), - ) - ) - mock_sys.stdout.write = MagicMock() - mock_sys.stdout.flush = MagicMock() - mock_sys.stderr.write = lambda s: captured_stderr.append(s) - - exit_code = asyncio.run( - _exec_session("wss://test/ws", "fake-token", "long-cmd") - ) - - assert exit_code == 1 - stderr_output = "".join(captured_stderr) - assert "interrupted" in stderr_output.lower() or "retry" in stderr_output.lower() - - def test_code_after_done_uses_marker_code(self): - """When END marker is already seen, Code message is ignored.""" - from centml.cli.shell import _exec_session, _BEGIN_MARKER, _END_MARKER - - ws = AsyncMock() - messages = [ - json.dumps( - {"data": f"{_BEGIN_MARKER}\noutput\n{_END_MARKER}:7\n"} - ), - json.dumps({"Code": 99}), - ] - ws.__aiter__ = MagicMock(return_value=_async_iter_from_list(messages)) - - with patch("centml.cli.shell.websockets") as mock_ws_mod, patch( - "centml.cli.shell.sys" - ) as mock_sys: - mock_ws_mod.connect = MagicMock( - return_value=AsyncMock( - __aenter__=AsyncMock(return_value=ws), - __aexit__=AsyncMock(return_value=False), - ) - ) - mock_sys.stdout.write = MagicMock() - mock_sys.stdout.flush = MagicMock() - mock_sys.stderr.write = MagicMock() - - exit_code = asyncio.run( - _exec_session("wss://test/ws", "fake-token", "exit 7") - ) - - assert exit_code == 7 From 1857f891391bc68eeb3afbef87ed409a190d0e4f Mon Sep 17 00:00:00 2001 From: Honglin Cao Date: Thu, 12 Mar 2026 20:19:58 -0400 Subject: [PATCH 27/28] fix: exit immediately on exit echo, ignore echo exit with trailing prompt When "exit\r\n" appears at the end of ws data (nothing after it), the shell has exited -- return immediately instead of waiting 2s. When "exit\r\n" is followed by a new prompt (e.g. from echo exit), ignore it and continue the session. --- centml/cli/shell.py | 28 +++++++++++++++++++--------- tests/test_shell.py | 30 +++++++++++++++++++++++------- 2 files changed, 42 insertions(+), 16 deletions(-) diff --git a/centml/cli/shell.py b/centml/cli/shell.py index 3e1d8f7..70f27ce 100644 --- a/centml/cli/shell.py +++ b/centml/cli/shell.py @@ -215,9 +215,6 @@ def _resolve_pod(cclient, deployment_id, pod_name=None): return running_pods[0] -_EXIT_IDLE_TIMEOUT = 2.0 - - async def _forward_io(ws, screen, stream, shutdown): """Bidirectional forwarding between local stdin/stdout and WebSocket. @@ -268,13 +265,26 @@ async def _read_ws(): if data: stream.feed(data.replace("\n", "\r\n")) _render_dirty(screen, sys.stdout.buffer) - # Detect shell exit echo. When the user runs ``exit``, + # Detect shell exit echo. When the user runs ``exit``, # the remote PTY echoes ``exit\r\n`` as the final - # output. Switch to a short recv timeout so we exit - # cleanly instead of hanging on the open WebSocket. - if "exit\r\n" in data: - _log.debug("[read_ws] detected exit echo, arming idle timeout") - recv_timeout = _EXIT_IDLE_TIMEOUT + # output with nothing after it. If ``exit\r\n`` appears + # mid-message (e.g. ``echo exit`` produces + # ``exit\r\n``), the shell is still alive. + idx = data.rfind("exit\r\n") + if idx != -1: + after_exit = data[idx + len("exit\r\n"):] + if after_exit.strip(): + # New prompt follows -- shell is still alive. + _log.debug("[read_ws] exit echo with trailing data, not a real exit") + recv_timeout = None + else: + # Nothing meaningful after exit -- shell exited. + _log.debug("[read_ws] detected exit echo at end of data, exiting") + return + elif recv_timeout is not None: + # We previously armed a timeout but got more data + # (shouldn't normally happen, but be safe). + recv_timeout = None elif msg.get("error"): _log.debug("[read_ws] error: %s", msg["error"]) stream.feed(f"Error: {msg['error']}\r\n") diff --git a/tests/test_shell.py b/tests/test_shell.py index 203c82c..0aa7426 100644 --- a/tests/test_shell.py +++ b/tests/test_shell.py @@ -710,21 +710,37 @@ def test_connection_closed_returns_zero(self): assert self._run_forward_io(ws) == 0 - def test_exit_echo_triggers_idle_timeout(self): - """Data containing 'exit\\r\\n' arms an idle timeout that ends the session.""" + def test_exit_echo_at_end_exits_immediately(self): + """'exit\\r\\n' at end of data (no trailing prompt) exits immediately.""" + ws = AsyncMock() + ws.recv = AsyncMock( + side_effect=[ + json.dumps({"data": "\r\n\x1b[?2004l\rexit\r\n"}), + # Should never be called -- _read_ws returns before this. + json.dumps({"data": "should not reach"}), + ] + ) + + assert self._run_forward_io(ws) == 0 + # Only one recv call -- exited immediately after exit echo. + assert ws.recv.call_count == 1 + + def test_exit_echo_with_prompt_continues(self): + """'exit\\r\\n' followed by a new prompt is not a real exit.""" import websockets as _ws_lib ws = AsyncMock() - # First recv returns exit echo, second recv times out. ws.recv = AsyncMock( side_effect=[ - json.dumps({"data": "\r\n\x1b[?2004l\rexit\r\n"}), - asyncio.TimeoutError(), + # ``echo exit`` -- exit echo with prompt trailing. + json.dumps({"data": "\r\n\x1b[?2004l\rexit\r\n\x1b[?2004huser@host:~$ "}), + _ws_lib.ConnectionClosed(None, None), ] ) - with patch("centml.cli.shell._EXIT_IDLE_TIMEOUT", 0.01): - assert self._run_forward_io(ws) == 0 + assert self._run_forward_io(ws) == 0 + # Both recv calls made -- did not exit after the first message. + assert ws.recv.call_count == 2 def test_normal_data_no_early_exit(self): """Data without 'exit\\r\\n' does not trigger early exit.""" From 76fd598b734dea487d9aef93aab669f9ac9102aa Mon Sep 17 00:00:00 2001 From: Honglin Cao Date: Thu, 12 Mar 2026 20:26:15 -0400 Subject: [PATCH 28/28] fix: skip websocket close handshake wait after session ends --- centml/cli/shell.py | 6 ++++++ tests/test_shell.py | 23 +++++++++++++++++++++++ 2 files changed, 29 insertions(+) diff --git a/centml/cli/shell.py b/centml/cli/shell.py index 70f27ce..34abf6d 100644 --- a/centml/cli/shell.py +++ b/centml/cli/shell.py @@ -422,6 +422,8 @@ def _send_resize(): finally: loop.remove_signal_handler(signal.SIGWINCH) + # Skip close handshake -- proxy won't send close frame promptly. + ws.close_timeout = 0 _log.debug("[session] exiting with code %d", exit_code) return exit_code finally: @@ -534,6 +536,10 @@ async def _exec_session(ws_url, token, command): break except websockets.ConnectionClosed as exc: _log.debug("[exec] ConnectionClosed: %s", exc) + if is_done: + # Skip the close handshake -- the platform API proxy does not + # proactively close its end, so waiting wastes close_timeout seconds. + ws.close_timeout = 0 _log.debug("[exec] returning exit_code=%d", exit_code) return exit_code diff --git a/tests/test_shell.py b/tests/test_shell.py index 0aa7426..a6afd43 100644 --- a/tests/test_shell.py +++ b/tests/test_shell.py @@ -358,6 +358,29 @@ async def _raise_closed(): assert exit_code == 0 + def test_sets_zero_close_timeout_after_done(self): + """After END marker, close_timeout should be 0 to avoid waiting for server close.""" + from centml.cli.shell import _exec_session, _BEGIN_MARKER, _END_MARKER + + ws = AsyncMock() + messages = [ + json.dumps({"data": f"{_BEGIN_MARKER}\nhello\n{_END_MARKER}:0\n"}), + ] + ws.__aiter__ = MagicMock(return_value=_async_iter_from_list(messages)) + ws.close_timeout = 2 + + with patch("centml.cli.shell.websockets") as mock_ws_mod: + mock_ws_mod.connect = MagicMock( + return_value=AsyncMock( + __aenter__=AsyncMock(return_value=ws), + __aexit__=AsyncMock(return_value=False), + ) + ) + + asyncio.run(_exec_session("wss://test/ws", "fake-token", "echo hello")) + + assert ws.close_timeout == 0 + def test_handles_ansi_around_markers(self): """Markers wrapped in ANSI codes are still detected via pyte.""" from centml.cli.shell import _exec_session, _BEGIN_MARKER, _END_MARKER