From 28326b4ce065d25912ffc0c7b8c7fb8edd749a0c Mon Sep 17 00:00:00 2001 From: Stefan de Vogelaere Date: Wed, 13 May 2026 19:40:27 +0200 Subject: [PATCH 1/5] feat: add reset_timeout_on_progress and max_total_timeout to send_request MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit When an MCP client calls a long-running tool that sends periodic progress notifications, the request currently times out on a fixed wall-clock budget even though the server is actively sending progress updates. This breaks blocking UI-card tools in agent frameworks (e.g., DevFlow's devflow_user_choice, devflow_suggest_create_pr) that intentionally wait for human interaction and rely on progress notifications as heartbeats. Add two new opt-in parameters to BaseSession.send_request(): - reset_timeout_on_progress (bool, default False): when enabled, each matching notifications/progress resets the timeout window to current_time + timeout. - max_total_timeout (float | None, default None): optional absolute ceiling measured from the original request start time. If exceeded, the request fails even if progress keeps arriving. Both parameters are threaded through ClientSession.call_tool() and are backward-compatible — existing code is unaffected. Implementation uses anyio.move_on_after() in a loop, checking a per-request _ProgressTimeoutInfo.was_reset flag set by the receive loop when a matching progress notification arrives. This avoids cross-task CancelScope manipulation and works with both asyncio and trio backends. Semantics mirror the TypeScript SDK's resetTimeoutOnProgress / maxTotalTimeout in RequestOptions. Tests cover: no progress → timeout, progress resets timeout, max total timeout ceiling, progress stops → timeout fires, multiple progress notifications, default-off behavior, and call_tool passthrough. DEV-3332 --- src/mcp/client/session.py | 18 +- src/mcp/shared/session.py | 132 ++++++- tests/shared/test_progress_reset_timeout.py | 376 ++++++++++++++++++++ 3 files changed, 513 insertions(+), 13 deletions(-) create mode 100644 tests/shared/test_progress_reset_timeout.py diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 0cea454a7..ad947a64f 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -304,8 +304,22 @@ async def call_tool( progress_callback: ProgressFnT | None = None, *, meta: RequestParamsMeta | None = None, + reset_timeout_on_progress: bool = False, + max_total_timeout: float | None = None, ) -> types.CallToolResult: - """Send a tools/call request with optional progress callback support.""" + """Send a tools/call request with optional progress callback support. + + Args: + name: The tool name. + arguments: Optional arguments for the tool. + read_timeout_seconds: Per-request timeout override. + progress_callback: Optional callback for progress notifications. + meta: Optional request metadata. + reset_timeout_on_progress: When True, each progress notification + resets the timeout window. + max_total_timeout: Optional absolute ceiling (seconds) measured + from request start. + """ result = await self.send_request( types.CallToolRequest( @@ -314,6 +328,8 @@ async def call_tool( types.CallToolResult, request_read_timeout_seconds=read_timeout_seconds, progress_callback=progress_callback, + reset_timeout_on_progress=reset_timeout_on_progress, + max_total_timeout=max_total_timeout, ) if not result.is_error: diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 243eef5ae..ba1293141 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -2,8 +2,10 @@ import contextvars import logging +import time from collections.abc import Callable from contextlib import AsyncExitStack +from dataclasses import dataclass, field from types import TracebackType from typing import Any, Generic, Protocol, TypeVar @@ -57,6 +59,30 @@ async def __call__( ) -> None: ... # pragma: no branch +@dataclass +class _ProgressTimeoutInfo: + """Tracks timeout state for a request with progress-reset support. + + When ``reset_timeout_on_progress`` is enabled, each matching + ``notifications/progress`` sets ``was_reset`` so the send_request loop + can restart its timeout window. ``max_total_timeout`` provides an + optional absolute ceiling measured from ``start_time``. + """ + + timeout: float + start_time: float + max_total_timeout: float | None + was_reset: bool = False + max_exceeded: bool = False + + +async def _noop_progress_callback( + progress: float, total: float | None, message: str | None +) -> None: + """No-op progress callback used when reset_timeout_on_progress is set + without an explicit progress_callback.""" + + class RequestResponder(Generic[ReceiveRequestT, SendResultT]): """Handles responding to MCP requests and manages request lifecycle. @@ -183,6 +209,7 @@ class BaseSession( _request_id: int _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]] _progress_callbacks: dict[RequestId, ProgressFnT] + _progress_timeout_infos: dict[RequestId, _ProgressTimeoutInfo] _response_routers: list[ResponseRouter] def __init__( @@ -199,6 +226,7 @@ def __init__( self._session_read_timeout_seconds = read_timeout_seconds self._in_flight = {} self._progress_callbacks = {} + self._progress_timeout_infos = {} self._response_routers = [] self._exit_stack = AsyncExitStack() @@ -243,6 +271,8 @@ async def send_request( request_read_timeout_seconds: float | None = None, metadata: MessageMetadata = None, progress_callback: ProgressFnT | None = None, + reset_timeout_on_progress: bool = False, + max_total_timeout: float | None = None, ) -> ReceiveResultT: """Sends a request and waits for a response. @@ -250,6 +280,21 @@ async def send_request( precedence over the session read timeout. Do not use this method to emit notifications! Use send_notification() instead. + + Args: + request: The request to send. + result_type: The expected result type. + request_read_timeout_seconds: Per-request timeout override. + metadata: Optional metadata for the message. + progress_callback: Optional callback invoked on progress notifications. + reset_timeout_on_progress: When True, each ``notifications/progress`` + for this request resets the timeout window. The new deadline is + ``current_time + timeout``. Requires a non-None timeout. + max_total_timeout: Optional absolute ceiling (seconds) measured from + the original request start time. Only meaningful when + ``reset_timeout_on_progress=True``. If the elapsed time since + ``start_time`` exceeds this value, the request fails immediately + even if progress keeps arriving. """ request_id = self._request_id self._request_id = request_id + 1 @@ -257,17 +302,22 @@ async def send_request( response_stream, response_stream_reader = anyio.create_memory_object_stream[JSONRPCResponse | JSONRPCError](1) self._response_streams[request_id] = response_stream - # Set up progress token if progress callback is provided + # Set up progress token if a progress callback is provided or if + # reset_timeout_on_progress is enabled (which needs the token to + # match incoming progress notifications). request_data = request.model_dump(by_alias=True, mode="json", exclude_none=True) - if progress_callback is not None: - # Use request_id as progress token + needs_progress_token = progress_callback is not None or reset_timeout_on_progress + if needs_progress_token: if "params" not in request_data: # pragma: lax no cover request_data["params"] = {} if "_meta" not in request_data["params"]: # pragma: lax no cover request_data["params"]["_meta"] = {} request_data["params"]["_meta"]["progressToken"] = request_id - # Store the callback for this request - self._progress_callbacks[request_id] = progress_callback + # Register the callback (no-op if only reset_timeout_on_progress is set) + if progress_callback is not None: + self._progress_callbacks[request_id] = progress_callback + else: + self._progress_callbacks[request_id] = _noop_progress_callback try: target = request_data.get("params", {}).get("name") @@ -288,13 +338,56 @@ async def send_request( # request read timeout takes precedence over session read timeout timeout = request_read_timeout_seconds or self._session_read_timeout_seconds - try: - with anyio.fail_after(timeout): - response_or_error = await response_stream_reader.receive() - except TimeoutError: - class_name = request.__class__.__name__ - message = f"Timed out while waiting for response to {class_name}. Waited {timeout} seconds." - raise MCPError(code=REQUEST_TIMEOUT, message=message) + # --- Progress-reset timeout path --- + if reset_timeout_on_progress and timeout is not None: + timeout_info = _ProgressTimeoutInfo( + timeout=timeout, + start_time=time.monotonic(), + max_total_timeout=max_total_timeout, + ) + self._progress_timeout_infos[request_id] = timeout_info + + try: + while True: + with anyio.move_on_after(timeout_info.timeout): + response_or_error = await response_stream_reader.receive() + break # received response + + # Timeout window expired — check for progress reset + if timeout_info.was_reset: + timeout_info.was_reset = False + # Check absolute ceiling + if timeout_info.max_total_timeout is not None: + elapsed = time.monotonic() - timeout_info.start_time + if elapsed >= timeout_info.max_total_timeout: + class_name = request.__class__.__name__ + raise MCPError( + code=REQUEST_TIMEOUT, + message=( + f"Maximum total timeout exceeded for {class_name}. " + f"Elapsed {elapsed:.1f}s, limit {timeout_info.max_total_timeout:.1f}s." + ), + ) + continue # restart timeout window + + # No progress reset — genuine timeout + class_name = request.__class__.__name__ + raise MCPError( + code=REQUEST_TIMEOUT, + message=f"Timed out while waiting for response to {class_name}. Waited {timeout} seconds.", + ) + finally: + self._progress_timeout_infos.pop(request_id, None) + + # --- Standard timeout path --- + else: + try: + with anyio.fail_after(timeout): + response_or_error = await response_stream_reader.receive() + except TimeoutError: + class_name = request.__class__.__name__ + message = f"Timed out while waiting for response to {class_name}. Waited {timeout} seconds." + raise MCPError(code=REQUEST_TIMEOUT, message=message) if isinstance(response_or_error, JSONRPCError): raise MCPError.from_jsonrpc_error(response_or_error) @@ -405,6 +498,21 @@ async def _handle_session_message(message: SessionMessage) -> None: # If there is a progress callback for this token, # call it with the progress information if progress_token in self._progress_callbacks: + # Signal timeout reset before invoking callback + timeout_info = self._progress_timeout_infos.get(progress_token) + if timeout_info is not None: + # Check max total timeout first + if timeout_info.max_total_timeout is not None: + elapsed = time.monotonic() - timeout_info.start_time + if elapsed >= timeout_info.max_total_timeout: + timeout_info.max_exceeded = True + timeout_info.was_reset = True + # Skip progress callback (matches TypeScript SDK) + await self._received_notification(notification) + await self._handle_incoming(notification) + return + timeout_info.was_reset = True + callback = self._progress_callbacks[progress_token] try: await callback( diff --git a/tests/shared/test_progress_reset_timeout.py b/tests/shared/test_progress_reset_timeout.py new file mode 100644 index 000000000..81ecf78a2 --- /dev/null +++ b/tests/shared/test_progress_reset_timeout.py @@ -0,0 +1,376 @@ +"""Tests for reset_timeout_on_progress and max_total_timeout semantics. + +Mirrors the TypeScript SDK test coverage in test/shared/protocol.test.ts. +""" + +import anyio +import pytest + +from mcp import types +from mcp.client.session import ClientSession +from mcp.server import Server, ServerRequestContext +from mcp.shared.exceptions import MCPError +from mcp.shared.memory import create_client_server_memory_streams +from mcp.shared.message import SessionMessage +from mcp.types import ( + JSONRPCRequest, + JSONRPCResponse, +) + + +def _make_progress_notification( + progress_token: int | str, + progress: float, + total: float | None = None, +) -> SessionMessage: + """Build a raw progress notification to inject into the client's read stream.""" + from mcp.types import JSONRPCNotification + + notification = JSONRPCNotification( + jsonrpc="2.0", + method="notifications/progress", + params={ + "progressToken": progress_token, + "progress": progress, + "total": total, + }, + ) + return SessionMessage(message=notification) + + +@pytest.mark.anyio +async def test_no_progress_no_reset_timeout_fires(): + """Without progress notifications, the timeout fires normally even when + reset_timeout_on_progress is True.""" + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async def mock_server(): + """Receive the request but never respond.""" + await server_read.receive() + # Never respond — let the client timeout + await anyio.sleep(5) + + async with ( + anyio.create_task_group() as tg, + ClientSession(read_stream=client_read, write_stream=client_write) as client_session, + ): + tg.start_soon(mock_server) + + with pytest.raises(MCPError) as exc_info: + await client_session.send_request( + types.PingRequest(params=types.RequestParams()), + types.EmptyResult, + request_read_timeout_seconds=0.3, + reset_timeout_on_progress=True, + ) + assert "Timed out" in str(exc_info.value) + + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_progress_resets_timeout(): + """A progress notification received before the timeout window expires + resets the deadline, keeping the request alive.""" + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async def mock_server(): + """Receive the request, send progress at ~50% of the timeout window, + then respond after the original timeout would have expired.""" + msg = await server_read.receive() + assert isinstance(msg.message, JSONRPCRequest) + request_id = msg.message.id + + # Wait half the timeout, then send progress + await anyio.sleep(0.15) + await server_write.send( + _make_progress_notification(progress_token=request_id, progress=0.5, total=1.0) + ) + + # Wait past the original timeout (0.3s total) so the request + # would have timed out without the reset + await anyio.sleep(0.25) + + # Now respond + await server_write.send( + SessionMessage( + message=JSONRPCResponse(jsonrpc="2.0", id=request_id, result={}) + ) + ) + + result_holder: list[types.EmptyResult] = [] + + async def make_request(client_session: ClientSession): + result = await client_session.send_request( + types.PingRequest(params=types.RequestParams()), + types.EmptyResult, + request_read_timeout_seconds=0.3, + reset_timeout_on_progress=True, + ) + result_holder.append(result) + + async with ( + anyio.create_task_group() as tg, + ClientSession(read_stream=client_read, write_stream=client_write) as client_session, + ): + tg.start_soon(mock_server) + tg.start_soon(make_request, client_session) + + with anyio.fail_after(5): + while not result_holder: + await anyio.sleep(0.05) + + assert len(result_holder) == 1 + assert isinstance(result_holder[0], types.EmptyResult) + + +@pytest.mark.anyio +async def test_max_total_timeout_exceeded(): + """When max_total_timeout is set and the elapsed time exceeds it, the + request fails even though progress keeps resetting the per-window timeout.""" + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async def mock_server(): + """Send multiple progress notifications that keep resetting the + per-window timeout but eventually exceed max_total_timeout.""" + msg = await server_read.receive() + assert isinstance(msg.message, JSONRPCRequest) + request_id = msg.message.id + + # Send progress notifications every 80ms (well within the 0.3s window) + for i in range(5): + await anyio.sleep(0.08) + await server_write.send( + _make_progress_notification( + progress_token=request_id, + progress=float(i + 1) / 5, + total=1.0, + ) + ) + + async with ( + anyio.create_task_group() as tg, + ClientSession(read_stream=client_read, write_stream=client_write) as client_session, + ): + tg.start_soon(mock_server) + + with pytest.raises(MCPError) as exc_info: + await client_session.send_request( + types.PingRequest(params=types.RequestParams()), + types.EmptyResult, + request_read_timeout_seconds=0.3, + reset_timeout_on_progress=True, + max_total_timeout=0.35, + ) + assert "Maximum total timeout exceeded" in str(exc_info.value) + + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_progress_stops_timeout_fires(): + """When progress notifications stop arriving, the per-window timeout + eventually fires even though reset_timeout_on_progress is enabled.""" + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async def mock_server(): + """Send a few progress notifications, then stop.""" + msg = await server_read.receive() + assert isinstance(msg.message, JSONRPCRequest) + request_id = msg.message.id + + # Send progress at 80ms and 160ms (within the 0.3s window) + await anyio.sleep(0.08) + await server_write.send( + _make_progress_notification(progress_token=request_id, progress=0.3, total=1.0) + ) + await anyio.sleep(0.08) + await server_write.send( + _make_progress_notification(progress_token=request_id, progress=0.6, total=1.0) + ) + # Stop sending progress — let the client timeout after the + # last reset window expires + + async with ( + anyio.create_task_group() as tg, + ClientSession(read_stream=client_read, write_stream=client_write) as client_session, + ): + tg.start_soon(mock_server) + + with pytest.raises(MCPError) as exc_info: + await client_session.send_request( + types.PingRequest(params=types.RequestParams()), + types.EmptyResult, + request_read_timeout_seconds=0.3, + reset_timeout_on_progress=True, + ) + assert "Timed out" in str(exc_info.value) + + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_multiple_progress_notifications(): + """Multiple progress notifications each reset the timeout, keeping the + request alive for well beyond the original timeout.""" + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async def mock_server(): + """Send progress every 80ms with a 0.3s timeout (3 progress + notifications), then respond.""" + msg = await server_read.receive() + assert isinstance(msg.message, JSONRPCRequest) + request_id = msg.message.id + + for i in range(3): + await anyio.sleep(0.08) + await server_write.send( + _make_progress_notification( + progress_token=request_id, + progress=float(i + 1) / 3, + total=1.0, + ) + ) + + # Respond after the 3rd progress + await anyio.sleep(0.05) + await server_write.send( + SessionMessage( + message=JSONRPCResponse(jsonrpc="2.0", id=request_id, result={}) + ) + ) + + result_holder: list[types.EmptyResult] = [] + + async def make_request(client_session: ClientSession): + result = await client_session.send_request( + types.PingRequest(params=types.RequestParams()), + types.EmptyResult, + request_read_timeout_seconds=0.3, + reset_timeout_on_progress=True, + ) + result_holder.append(result) + + async with ( + anyio.create_task_group() as tg, + ClientSession(read_stream=client_read, write_stream=client_write) as client_session, + ): + tg.start_soon(mock_server) + tg.start_soon(make_request, client_session) + + with anyio.fail_after(5): + while not result_holder: + await anyio.sleep(0.05) + + assert len(result_holder) == 1 + + +@pytest.mark.anyio +async def test_reset_timeout_false_by_default(): + """When reset_timeout_on_progress is False (default), progress notifications + do NOT reset the timeout.""" + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async def mock_server(): + """Send progress before the timeout, then wait.""" + msg = await server_read.receive() + assert isinstance(msg.message, JSONRPCRequest) + request_id = msg.message.id + + # Send progress at 80ms (before the 0.3s timeout) + await anyio.sleep(0.08) + await server_write.send( + _make_progress_notification(progress_token=request_id, progress=0.5, total=1.0) + ) + + # Wait past the original timeout + await anyio.sleep(0.5) + + progress_received: list[float] = [] + + async def on_progress(progress: float, total: float | None, message: str | None) -> None: + progress_received.append(progress) + + async with ( + anyio.create_task_group() as tg, + ClientSession(read_stream=client_read, write_stream=client_write) as client_session, + ): + tg.start_soon(mock_server) + + with pytest.raises(MCPError) as exc_info: + await client_session.send_request( + types.PingRequest(params=types.RequestParams()), + types.EmptyResult, + request_read_timeout_seconds=0.3, + progress_callback=on_progress, + reset_timeout_on_progress=False, + ) + assert "Timed out" in str(exc_info.value) + + tg.cancel_scope.cancel() + + # Progress callback was still invoked (just didn't reset the timeout) + assert len(progress_received) == 1 + + +@pytest.mark.anyio +async def test_call_tool_threads_reset_timeout(): + """Verify that ClientSession.call_tool passes reset_timeout_on_progress + through to send_request, keeping a slow tool alive via progress.""" + + async def handle_call_tool( + ctx: ServerRequestContext, params: types.CallToolRequestParams + ) -> types.CallToolResult: + assert ctx.request_id is not None + # Send progress to keep the request alive + for i in range(3): + await anyio.sleep(0.08) + await ctx.session.send_progress_notification( + progress_token=ctx.request_id, + progress=float(i + 1) / 3, + total=1.0, + ) + return types.CallToolResult(content=[types.TextContent(type="text", text="done")]) + + async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="slow_tool", input_schema={})]) + + server = Server( + name="TestServer", + on_call_tool=handle_call_tool, + on_list_tools=handle_list_tools, + ) + + from mcp import Client + + async with Client(server) as client: + result = await client.session.call_tool( + "slow_tool", + arguments={}, + read_timeout_seconds=0.3, + reset_timeout_on_progress=True, + ) + assert len(result.content) == 1 + text = result.content[0] + assert isinstance(text, types.TextContent) + assert text.text == "done" From d5df222b8bc1e7f6549e439e75082fc3f08ab56f Mon Sep 17 00:00:00 2001 From: Stefan de Vogelaere Date: Wed, 13 May 2026 20:11:34 +0200 Subject: [PATCH 2/5] style: apply ruff formatting to test file --- tests/shared/test_progress_reset_timeout.py | 32 +++++---------------- 1 file changed, 7 insertions(+), 25 deletions(-) diff --git a/tests/shared/test_progress_reset_timeout.py b/tests/shared/test_progress_reset_timeout.py index 81ecf78a2..82b5ae711 100644 --- a/tests/shared/test_progress_reset_timeout.py +++ b/tests/shared/test_progress_reset_timeout.py @@ -89,20 +89,14 @@ async def mock_server(): # Wait half the timeout, then send progress await anyio.sleep(0.15) - await server_write.send( - _make_progress_notification(progress_token=request_id, progress=0.5, total=1.0) - ) + await server_write.send(_make_progress_notification(progress_token=request_id, progress=0.5, total=1.0)) # Wait past the original timeout (0.3s total) so the request # would have timed out without the reset await anyio.sleep(0.25) # Now respond - await server_write.send( - SessionMessage( - message=JSONRPCResponse(jsonrpc="2.0", id=request_id, result={}) - ) - ) + await server_write.send(SessionMessage(message=JSONRPCResponse(jsonrpc="2.0", id=request_id, result={}))) result_holder: list[types.EmptyResult] = [] @@ -193,13 +187,9 @@ async def mock_server(): # Send progress at 80ms and 160ms (within the 0.3s window) await anyio.sleep(0.08) - await server_write.send( - _make_progress_notification(progress_token=request_id, progress=0.3, total=1.0) - ) + await server_write.send(_make_progress_notification(progress_token=request_id, progress=0.3, total=1.0)) await anyio.sleep(0.08) - await server_write.send( - _make_progress_notification(progress_token=request_id, progress=0.6, total=1.0) - ) + await server_write.send(_make_progress_notification(progress_token=request_id, progress=0.6, total=1.0)) # Stop sending progress — let the client timeout after the # last reset window expires @@ -249,11 +239,7 @@ async def mock_server(): # Respond after the 3rd progress await anyio.sleep(0.05) - await server_write.send( - SessionMessage( - message=JSONRPCResponse(jsonrpc="2.0", id=request_id, result={}) - ) - ) + await server_write.send(SessionMessage(message=JSONRPCResponse(jsonrpc="2.0", id=request_id, result={}))) result_holder: list[types.EmptyResult] = [] @@ -297,9 +283,7 @@ async def mock_server(): # Send progress at 80ms (before the 0.3s timeout) await anyio.sleep(0.08) - await server_write.send( - _make_progress_notification(progress_token=request_id, progress=0.5, total=1.0) - ) + await server_write.send(_make_progress_notification(progress_token=request_id, progress=0.5, total=1.0)) # Wait past the original timeout await anyio.sleep(0.5) @@ -336,9 +320,7 @@ async def test_call_tool_threads_reset_timeout(): """Verify that ClientSession.call_tool passes reset_timeout_on_progress through to send_request, keeping a slow tool alive via progress.""" - async def handle_call_tool( - ctx: ServerRequestContext, params: types.CallToolRequestParams - ) -> types.CallToolResult: + async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: assert ctx.request_id is not None # Send progress to keep the request alive for i in range(3): From fac6effc0d52a0e3f1bd2dbe3d906adc727be36f Mon Sep 17 00:00:00 2001 From: Stefan de Vogelaere Date: Wed, 13 May 2026 20:19:38 +0200 Subject: [PATCH 3/5] style: apply ruff formatting to session.py --- src/mcp/shared/session.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index ba1293141..bd0a6e170 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -76,9 +76,7 @@ class _ProgressTimeoutInfo: max_exceeded: bool = False -async def _noop_progress_callback( - progress: float, total: float | None, message: str | None -) -> None: +async def _noop_progress_callback(progress: float, total: float | None, message: str | None) -> None: """No-op progress callback used when reset_timeout_on_progress is set without an explicit progress_callback.""" From 580564e8dd6ae8479a25f1fd3e821899ff1c7d89 Mon Sep 17 00:00:00 2001 From: Stefan de Vogelaere Date: Wed, 13 May 2026 20:31:51 +0200 Subject: [PATCH 4/5] style: fix ruff lint issues in session.py Remove unused `field` import and break long line (E501). --- src/mcp/shared/session.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index bd0a6e170..cd11e6e27 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -5,7 +5,7 @@ import time from collections.abc import Callable from contextlib import AsyncExitStack -from dataclasses import dataclass, field +from dataclasses import dataclass from types import TracebackType from typing import Any, Generic, Protocol, TypeVar @@ -372,7 +372,9 @@ async def send_request( class_name = request.__class__.__name__ raise MCPError( code=REQUEST_TIMEOUT, - message=f"Timed out while waiting for response to {class_name}. Waited {timeout} seconds.", + message=( + f"Timed out while waiting for response to {class_name}. Waited {timeout} seconds." + ), ) finally: self._progress_timeout_infos.pop(request_id, None) From bdc4795820eeecb494ecf0ca7024fbaab59bfd40 Mon Sep 17 00:00:00 2001 From: Stefan de Vogelaere Date: Wed, 13 May 2026 20:39:02 +0200 Subject: [PATCH 5/5] fix: resolve pyright type errors in test file Add isinstance(msg, SessionMessage) guard before accessing .message to satisfy pyright's type narrowing. Mark unused server_write with underscore prefix. --- tests/shared/test_progress_reset_timeout.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/shared/test_progress_reset_timeout.py b/tests/shared/test_progress_reset_timeout.py index 82b5ae711..dd9a41677 100644 --- a/tests/shared/test_progress_reset_timeout.py +++ b/tests/shared/test_progress_reset_timeout.py @@ -45,7 +45,7 @@ async def test_no_progress_no_reset_timeout_fires(): async with create_client_server_memory_streams() as (client_streams, server_streams): client_read, client_write = client_streams - server_read, server_write = server_streams + server_read, _server_write = server_streams async def mock_server(): """Receive the request but never respond.""" @@ -84,6 +84,7 @@ async def mock_server(): """Receive the request, send progress at ~50% of the timeout window, then respond after the original timeout would have expired.""" msg = await server_read.receive() + assert isinstance(msg, SessionMessage) assert isinstance(msg.message, JSONRPCRequest) request_id = msg.message.id @@ -137,6 +138,7 @@ async def mock_server(): """Send multiple progress notifications that keep resetting the per-window timeout but eventually exceed max_total_timeout.""" msg = await server_read.receive() + assert isinstance(msg, SessionMessage) assert isinstance(msg.message, JSONRPCRequest) request_id = msg.message.id @@ -182,6 +184,7 @@ async def test_progress_stops_timeout_fires(): async def mock_server(): """Send a few progress notifications, then stop.""" msg = await server_read.receive() + assert isinstance(msg, SessionMessage) assert isinstance(msg.message, JSONRPCRequest) request_id = msg.message.id @@ -224,6 +227,7 @@ async def mock_server(): """Send progress every 80ms with a 0.3s timeout (3 progress notifications), then respond.""" msg = await server_read.receive() + assert isinstance(msg, SessionMessage) assert isinstance(msg.message, JSONRPCRequest) request_id = msg.message.id @@ -278,6 +282,7 @@ async def test_reset_timeout_false_by_default(): async def mock_server(): """Send progress before the timeout, then wait.""" msg = await server_read.receive() + assert isinstance(msg, SessionMessage) assert isinstance(msg.message, JSONRPCRequest) request_id = msg.message.id