diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 0cea454a7..ad947a64f 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -304,8 +304,22 @@ async def call_tool( progress_callback: ProgressFnT | None = None, *, meta: RequestParamsMeta | None = None, + reset_timeout_on_progress: bool = False, + max_total_timeout: float | None = None, ) -> types.CallToolResult: - """Send a tools/call request with optional progress callback support.""" + """Send a tools/call request with optional progress callback support. + + Args: + name: The tool name. + arguments: Optional arguments for the tool. + read_timeout_seconds: Per-request timeout override. + progress_callback: Optional callback for progress notifications. + meta: Optional request metadata. + reset_timeout_on_progress: When True, each progress notification + resets the timeout window. + max_total_timeout: Optional absolute ceiling (seconds) measured + from request start. + """ result = await self.send_request( types.CallToolRequest( @@ -314,6 +328,8 @@ async def call_tool( types.CallToolResult, request_read_timeout_seconds=read_timeout_seconds, progress_callback=progress_callback, + reset_timeout_on_progress=reset_timeout_on_progress, + max_total_timeout=max_total_timeout, ) if not result.is_error: diff --git a/src/mcp/shared/session.py b/src/mcp/shared/session.py index 243eef5ae..cd11e6e27 100644 --- a/src/mcp/shared/session.py +++ b/src/mcp/shared/session.py @@ -2,8 +2,10 @@ import contextvars import logging +import time from collections.abc import Callable from contextlib import AsyncExitStack +from dataclasses import dataclass from types import TracebackType from typing import Any, Generic, Protocol, TypeVar @@ -57,6 +59,28 @@ async def __call__( ) -> None: ... # pragma: no branch +@dataclass +class _ProgressTimeoutInfo: + """Tracks timeout state for a request with progress-reset support. + + When ``reset_timeout_on_progress`` is enabled, each matching + ``notifications/progress`` sets ``was_reset`` so the send_request loop + can restart its timeout window. ``max_total_timeout`` provides an + optional absolute ceiling measured from ``start_time``. + """ + + timeout: float + start_time: float + max_total_timeout: float | None + was_reset: bool = False + max_exceeded: bool = False + + +async def _noop_progress_callback(progress: float, total: float | None, message: str | None) -> None: + """No-op progress callback used when reset_timeout_on_progress is set + without an explicit progress_callback.""" + + class RequestResponder(Generic[ReceiveRequestT, SendResultT]): """Handles responding to MCP requests and manages request lifecycle. @@ -183,6 +207,7 @@ class BaseSession( _request_id: int _in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]] _progress_callbacks: dict[RequestId, ProgressFnT] + _progress_timeout_infos: dict[RequestId, _ProgressTimeoutInfo] _response_routers: list[ResponseRouter] def __init__( @@ -199,6 +224,7 @@ def __init__( self._session_read_timeout_seconds = read_timeout_seconds self._in_flight = {} self._progress_callbacks = {} + self._progress_timeout_infos = {} self._response_routers = [] self._exit_stack = AsyncExitStack() @@ -243,6 +269,8 @@ async def send_request( request_read_timeout_seconds: float | None = None, metadata: MessageMetadata = None, progress_callback: ProgressFnT | None = None, + reset_timeout_on_progress: bool = False, + max_total_timeout: float | None = None, ) -> ReceiveResultT: """Sends a request and waits for a response. @@ -250,6 +278,21 @@ async def send_request( precedence over the session read timeout. Do not use this method to emit notifications! Use send_notification() instead. + + Args: + request: The request to send. + result_type: The expected result type. + request_read_timeout_seconds: Per-request timeout override. + metadata: Optional metadata for the message. + progress_callback: Optional callback invoked on progress notifications. + reset_timeout_on_progress: When True, each ``notifications/progress`` + for this request resets the timeout window. The new deadline is + ``current_time + timeout``. Requires a non-None timeout. + max_total_timeout: Optional absolute ceiling (seconds) measured from + the original request start time. Only meaningful when + ``reset_timeout_on_progress=True``. If the elapsed time since + ``start_time`` exceeds this value, the request fails immediately + even if progress keeps arriving. """ request_id = self._request_id self._request_id = request_id + 1 @@ -257,17 +300,22 @@ async def send_request( response_stream, response_stream_reader = anyio.create_memory_object_stream[JSONRPCResponse | JSONRPCError](1) self._response_streams[request_id] = response_stream - # Set up progress token if progress callback is provided + # Set up progress token if a progress callback is provided or if + # reset_timeout_on_progress is enabled (which needs the token to + # match incoming progress notifications). request_data = request.model_dump(by_alias=True, mode="json", exclude_none=True) - if progress_callback is not None: - # Use request_id as progress token + needs_progress_token = progress_callback is not None or reset_timeout_on_progress + if needs_progress_token: if "params" not in request_data: # pragma: lax no cover request_data["params"] = {} if "_meta" not in request_data["params"]: # pragma: lax no cover request_data["params"]["_meta"] = {} request_data["params"]["_meta"]["progressToken"] = request_id - # Store the callback for this request - self._progress_callbacks[request_id] = progress_callback + # Register the callback (no-op if only reset_timeout_on_progress is set) + if progress_callback is not None: + self._progress_callbacks[request_id] = progress_callback + else: + self._progress_callbacks[request_id] = _noop_progress_callback try: target = request_data.get("params", {}).get("name") @@ -288,13 +336,58 @@ async def send_request( # request read timeout takes precedence over session read timeout timeout = request_read_timeout_seconds or self._session_read_timeout_seconds - try: - with anyio.fail_after(timeout): - response_or_error = await response_stream_reader.receive() - except TimeoutError: - class_name = request.__class__.__name__ - message = f"Timed out while waiting for response to {class_name}. Waited {timeout} seconds." - raise MCPError(code=REQUEST_TIMEOUT, message=message) + # --- Progress-reset timeout path --- + if reset_timeout_on_progress and timeout is not None: + timeout_info = _ProgressTimeoutInfo( + timeout=timeout, + start_time=time.monotonic(), + max_total_timeout=max_total_timeout, + ) + self._progress_timeout_infos[request_id] = timeout_info + + try: + while True: + with anyio.move_on_after(timeout_info.timeout): + response_or_error = await response_stream_reader.receive() + break # received response + + # Timeout window expired — check for progress reset + if timeout_info.was_reset: + timeout_info.was_reset = False + # Check absolute ceiling + if timeout_info.max_total_timeout is not None: + elapsed = time.monotonic() - timeout_info.start_time + if elapsed >= timeout_info.max_total_timeout: + class_name = request.__class__.__name__ + raise MCPError( + code=REQUEST_TIMEOUT, + message=( + f"Maximum total timeout exceeded for {class_name}. " + f"Elapsed {elapsed:.1f}s, limit {timeout_info.max_total_timeout:.1f}s." + ), + ) + continue # restart timeout window + + # No progress reset — genuine timeout + class_name = request.__class__.__name__ + raise MCPError( + code=REQUEST_TIMEOUT, + message=( + f"Timed out while waiting for response to {class_name}. Waited {timeout} seconds." + ), + ) + finally: + self._progress_timeout_infos.pop(request_id, None) + + # --- Standard timeout path --- + else: + try: + with anyio.fail_after(timeout): + response_or_error = await response_stream_reader.receive() + except TimeoutError: + class_name = request.__class__.__name__ + message = f"Timed out while waiting for response to {class_name}. Waited {timeout} seconds." + raise MCPError(code=REQUEST_TIMEOUT, message=message) if isinstance(response_or_error, JSONRPCError): raise MCPError.from_jsonrpc_error(response_or_error) @@ -405,6 +498,21 @@ async def _handle_session_message(message: SessionMessage) -> None: # If there is a progress callback for this token, # call it with the progress information if progress_token in self._progress_callbacks: + # Signal timeout reset before invoking callback + timeout_info = self._progress_timeout_infos.get(progress_token) + if timeout_info is not None: + # Check max total timeout first + if timeout_info.max_total_timeout is not None: + elapsed = time.monotonic() - timeout_info.start_time + if elapsed >= timeout_info.max_total_timeout: + timeout_info.max_exceeded = True + timeout_info.was_reset = True + # Skip progress callback (matches TypeScript SDK) + await self._received_notification(notification) + await self._handle_incoming(notification) + return + timeout_info.was_reset = True + callback = self._progress_callbacks[progress_token] try: await callback( diff --git a/tests/shared/test_progress_reset_timeout.py b/tests/shared/test_progress_reset_timeout.py new file mode 100644 index 000000000..dd9a41677 --- /dev/null +++ b/tests/shared/test_progress_reset_timeout.py @@ -0,0 +1,363 @@ +"""Tests for reset_timeout_on_progress and max_total_timeout semantics. + +Mirrors the TypeScript SDK test coverage in test/shared/protocol.test.ts. +""" + +import anyio +import pytest + +from mcp import types +from mcp.client.session import ClientSession +from mcp.server import Server, ServerRequestContext +from mcp.shared.exceptions import MCPError +from mcp.shared.memory import create_client_server_memory_streams +from mcp.shared.message import SessionMessage +from mcp.types import ( + JSONRPCRequest, + JSONRPCResponse, +) + + +def _make_progress_notification( + progress_token: int | str, + progress: float, + total: float | None = None, +) -> SessionMessage: + """Build a raw progress notification to inject into the client's read stream.""" + from mcp.types import JSONRPCNotification + + notification = JSONRPCNotification( + jsonrpc="2.0", + method="notifications/progress", + params={ + "progressToken": progress_token, + "progress": progress, + "total": total, + }, + ) + return SessionMessage(message=notification) + + +@pytest.mark.anyio +async def test_no_progress_no_reset_timeout_fires(): + """Without progress notifications, the timeout fires normally even when + reset_timeout_on_progress is True.""" + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, _server_write = server_streams + + async def mock_server(): + """Receive the request but never respond.""" + await server_read.receive() + # Never respond — let the client timeout + await anyio.sleep(5) + + async with ( + anyio.create_task_group() as tg, + ClientSession(read_stream=client_read, write_stream=client_write) as client_session, + ): + tg.start_soon(mock_server) + + with pytest.raises(MCPError) as exc_info: + await client_session.send_request( + types.PingRequest(params=types.RequestParams()), + types.EmptyResult, + request_read_timeout_seconds=0.3, + reset_timeout_on_progress=True, + ) + assert "Timed out" in str(exc_info.value) + + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_progress_resets_timeout(): + """A progress notification received before the timeout window expires + resets the deadline, keeping the request alive.""" + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async def mock_server(): + """Receive the request, send progress at ~50% of the timeout window, + then respond after the original timeout would have expired.""" + msg = await server_read.receive() + assert isinstance(msg, SessionMessage) + assert isinstance(msg.message, JSONRPCRequest) + request_id = msg.message.id + + # Wait half the timeout, then send progress + await anyio.sleep(0.15) + await server_write.send(_make_progress_notification(progress_token=request_id, progress=0.5, total=1.0)) + + # Wait past the original timeout (0.3s total) so the request + # would have timed out without the reset + await anyio.sleep(0.25) + + # Now respond + await server_write.send(SessionMessage(message=JSONRPCResponse(jsonrpc="2.0", id=request_id, result={}))) + + result_holder: list[types.EmptyResult] = [] + + async def make_request(client_session: ClientSession): + result = await client_session.send_request( + types.PingRequest(params=types.RequestParams()), + types.EmptyResult, + request_read_timeout_seconds=0.3, + reset_timeout_on_progress=True, + ) + result_holder.append(result) + + async with ( + anyio.create_task_group() as tg, + ClientSession(read_stream=client_read, write_stream=client_write) as client_session, + ): + tg.start_soon(mock_server) + tg.start_soon(make_request, client_session) + + with anyio.fail_after(5): + while not result_holder: + await anyio.sleep(0.05) + + assert len(result_holder) == 1 + assert isinstance(result_holder[0], types.EmptyResult) + + +@pytest.mark.anyio +async def test_max_total_timeout_exceeded(): + """When max_total_timeout is set and the elapsed time exceeds it, the + request fails even though progress keeps resetting the per-window timeout.""" + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async def mock_server(): + """Send multiple progress notifications that keep resetting the + per-window timeout but eventually exceed max_total_timeout.""" + msg = await server_read.receive() + assert isinstance(msg, SessionMessage) + assert isinstance(msg.message, JSONRPCRequest) + request_id = msg.message.id + + # Send progress notifications every 80ms (well within the 0.3s window) + for i in range(5): + await anyio.sleep(0.08) + await server_write.send( + _make_progress_notification( + progress_token=request_id, + progress=float(i + 1) / 5, + total=1.0, + ) + ) + + async with ( + anyio.create_task_group() as tg, + ClientSession(read_stream=client_read, write_stream=client_write) as client_session, + ): + tg.start_soon(mock_server) + + with pytest.raises(MCPError) as exc_info: + await client_session.send_request( + types.PingRequest(params=types.RequestParams()), + types.EmptyResult, + request_read_timeout_seconds=0.3, + reset_timeout_on_progress=True, + max_total_timeout=0.35, + ) + assert "Maximum total timeout exceeded" in str(exc_info.value) + + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_progress_stops_timeout_fires(): + """When progress notifications stop arriving, the per-window timeout + eventually fires even though reset_timeout_on_progress is enabled.""" + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async def mock_server(): + """Send a few progress notifications, then stop.""" + msg = await server_read.receive() + assert isinstance(msg, SessionMessage) + assert isinstance(msg.message, JSONRPCRequest) + request_id = msg.message.id + + # Send progress at 80ms and 160ms (within the 0.3s window) + await anyio.sleep(0.08) + await server_write.send(_make_progress_notification(progress_token=request_id, progress=0.3, total=1.0)) + await anyio.sleep(0.08) + await server_write.send(_make_progress_notification(progress_token=request_id, progress=0.6, total=1.0)) + # Stop sending progress — let the client timeout after the + # last reset window expires + + async with ( + anyio.create_task_group() as tg, + ClientSession(read_stream=client_read, write_stream=client_write) as client_session, + ): + tg.start_soon(mock_server) + + with pytest.raises(MCPError) as exc_info: + await client_session.send_request( + types.PingRequest(params=types.RequestParams()), + types.EmptyResult, + request_read_timeout_seconds=0.3, + reset_timeout_on_progress=True, + ) + assert "Timed out" in str(exc_info.value) + + tg.cancel_scope.cancel() + + +@pytest.mark.anyio +async def test_multiple_progress_notifications(): + """Multiple progress notifications each reset the timeout, keeping the + request alive for well beyond the original timeout.""" + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async def mock_server(): + """Send progress every 80ms with a 0.3s timeout (3 progress + notifications), then respond.""" + msg = await server_read.receive() + assert isinstance(msg, SessionMessage) + assert isinstance(msg.message, JSONRPCRequest) + request_id = msg.message.id + + for i in range(3): + await anyio.sleep(0.08) + await server_write.send( + _make_progress_notification( + progress_token=request_id, + progress=float(i + 1) / 3, + total=1.0, + ) + ) + + # Respond after the 3rd progress + await anyio.sleep(0.05) + await server_write.send(SessionMessage(message=JSONRPCResponse(jsonrpc="2.0", id=request_id, result={}))) + + result_holder: list[types.EmptyResult] = [] + + async def make_request(client_session: ClientSession): + result = await client_session.send_request( + types.PingRequest(params=types.RequestParams()), + types.EmptyResult, + request_read_timeout_seconds=0.3, + reset_timeout_on_progress=True, + ) + result_holder.append(result) + + async with ( + anyio.create_task_group() as tg, + ClientSession(read_stream=client_read, write_stream=client_write) as client_session, + ): + tg.start_soon(mock_server) + tg.start_soon(make_request, client_session) + + with anyio.fail_after(5): + while not result_holder: + await anyio.sleep(0.05) + + assert len(result_holder) == 1 + + +@pytest.mark.anyio +async def test_reset_timeout_false_by_default(): + """When reset_timeout_on_progress is False (default), progress notifications + do NOT reset the timeout.""" + + async with create_client_server_memory_streams() as (client_streams, server_streams): + client_read, client_write = client_streams + server_read, server_write = server_streams + + async def mock_server(): + """Send progress before the timeout, then wait.""" + msg = await server_read.receive() + assert isinstance(msg, SessionMessage) + assert isinstance(msg.message, JSONRPCRequest) + request_id = msg.message.id + + # Send progress at 80ms (before the 0.3s timeout) + await anyio.sleep(0.08) + await server_write.send(_make_progress_notification(progress_token=request_id, progress=0.5, total=1.0)) + + # Wait past the original timeout + await anyio.sleep(0.5) + + progress_received: list[float] = [] + + async def on_progress(progress: float, total: float | None, message: str | None) -> None: + progress_received.append(progress) + + async with ( + anyio.create_task_group() as tg, + ClientSession(read_stream=client_read, write_stream=client_write) as client_session, + ): + tg.start_soon(mock_server) + + with pytest.raises(MCPError) as exc_info: + await client_session.send_request( + types.PingRequest(params=types.RequestParams()), + types.EmptyResult, + request_read_timeout_seconds=0.3, + progress_callback=on_progress, + reset_timeout_on_progress=False, + ) + assert "Timed out" in str(exc_info.value) + + tg.cancel_scope.cancel() + + # Progress callback was still invoked (just didn't reset the timeout) + assert len(progress_received) == 1 + + +@pytest.mark.anyio +async def test_call_tool_threads_reset_timeout(): + """Verify that ClientSession.call_tool passes reset_timeout_on_progress + through to send_request, keeping a slow tool alive via progress.""" + + async def handle_call_tool(ctx: ServerRequestContext, params: types.CallToolRequestParams) -> types.CallToolResult: + assert ctx.request_id is not None + # Send progress to keep the request alive + for i in range(3): + await anyio.sleep(0.08) + await ctx.session.send_progress_notification( + progress_token=ctx.request_id, + progress=float(i + 1) / 3, + total=1.0, + ) + return types.CallToolResult(content=[types.TextContent(type="text", text="done")]) + + async def handle_list_tools( + ctx: ServerRequestContext, params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + return types.ListToolsResult(tools=[types.Tool(name="slow_tool", input_schema={})]) + + server = Server( + name="TestServer", + on_call_tool=handle_call_tool, + on_list_tools=handle_list_tools, + ) + + from mcp import Client + + async with Client(server) as client: + result = await client.session.call_tool( + "slow_tool", + arguments={}, + read_timeout_seconds=0.3, + reset_timeout_on_progress=True, + ) + assert len(result.content) == 1 + text = result.content[0] + assert isinstance(text, types.TextContent) + assert text.text == "done"