Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion src/mcp/client/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,8 +304,22 @@ async def call_tool(
progress_callback: ProgressFnT | None = None,
*,
meta: RequestParamsMeta | None = None,
reset_timeout_on_progress: bool = False,
max_total_timeout: float | None = None,
) -> types.CallToolResult:
"""Send a tools/call request with optional progress callback support."""
"""Send a tools/call request with optional progress callback support.

Args:
name: The tool name.
arguments: Optional arguments for the tool.
read_timeout_seconds: Per-request timeout override.
progress_callback: Optional callback for progress notifications.
meta: Optional request metadata.
reset_timeout_on_progress: When True, each progress notification
resets the timeout window.
max_total_timeout: Optional absolute ceiling (seconds) measured
from request start.
"""

result = await self.send_request(
types.CallToolRequest(
Expand All @@ -314,6 +328,8 @@ async def call_tool(
types.CallToolResult,
request_read_timeout_seconds=read_timeout_seconds,
progress_callback=progress_callback,
reset_timeout_on_progress=reset_timeout_on_progress,
max_total_timeout=max_total_timeout,
)

if not result.is_error:
Expand Down
132 changes: 120 additions & 12 deletions src/mcp/shared/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

import contextvars
import logging
import time
from collections.abc import Callable
from contextlib import AsyncExitStack
from dataclasses import dataclass
from types import TracebackType
from typing import Any, Generic, Protocol, TypeVar

Expand Down Expand Up @@ -57,6 +59,28 @@ async def __call__(
) -> None: ... # pragma: no branch


@dataclass
class _ProgressTimeoutInfo:
"""Tracks timeout state for a request with progress-reset support.

When ``reset_timeout_on_progress`` is enabled, each matching
``notifications/progress`` sets ``was_reset`` so the send_request loop
can restart its timeout window. ``max_total_timeout`` provides an
optional absolute ceiling measured from ``start_time``.
"""

timeout: float
start_time: float
max_total_timeout: float | None
was_reset: bool = False
max_exceeded: bool = False


async def _noop_progress_callback(progress: float, total: float | None, message: str | None) -> None:
"""No-op progress callback used when reset_timeout_on_progress is set
without an explicit progress_callback."""


class RequestResponder(Generic[ReceiveRequestT, SendResultT]):
"""Handles responding to MCP requests and manages request lifecycle.

Expand Down Expand Up @@ -183,6 +207,7 @@ class BaseSession(
_request_id: int
_in_flight: dict[RequestId, RequestResponder[ReceiveRequestT, SendResultT]]
_progress_callbacks: dict[RequestId, ProgressFnT]
_progress_timeout_infos: dict[RequestId, _ProgressTimeoutInfo]
_response_routers: list[ResponseRouter]

def __init__(
Expand All @@ -199,6 +224,7 @@ def __init__(
self._session_read_timeout_seconds = read_timeout_seconds
self._in_flight = {}
self._progress_callbacks = {}
self._progress_timeout_infos = {}
self._response_routers = []
self._exit_stack = AsyncExitStack()

Expand Down Expand Up @@ -243,31 +269,53 @@ async def send_request(
request_read_timeout_seconds: float | None = None,
metadata: MessageMetadata = None,
progress_callback: ProgressFnT | None = None,
reset_timeout_on_progress: bool = False,
max_total_timeout: float | None = None,
) -> ReceiveResultT:
"""Sends a request and waits for a response.

Raises an MCPError if the response contains an error. If a request read timeout is provided, it will take
precedence over the session read timeout.

Do not use this method to emit notifications! Use send_notification() instead.

Args:
request: The request to send.
result_type: The expected result type.
request_read_timeout_seconds: Per-request timeout override.
metadata: Optional metadata for the message.
progress_callback: Optional callback invoked on progress notifications.
reset_timeout_on_progress: When True, each ``notifications/progress``
for this request resets the timeout window. The new deadline is
``current_time + timeout``. Requires a non-None timeout.
max_total_timeout: Optional absolute ceiling (seconds) measured from
the original request start time. Only meaningful when
``reset_timeout_on_progress=True``. If the elapsed time since
``start_time`` exceeds this value, the request fails immediately
even if progress keeps arriving.
"""
request_id = self._request_id
self._request_id = request_id + 1

response_stream, response_stream_reader = anyio.create_memory_object_stream[JSONRPCResponse | JSONRPCError](1)
self._response_streams[request_id] = response_stream

# Set up progress token if progress callback is provided
# Set up progress token if a progress callback is provided or if
# reset_timeout_on_progress is enabled (which needs the token to
# match incoming progress notifications).
request_data = request.model_dump(by_alias=True, mode="json", exclude_none=True)
if progress_callback is not None:
# Use request_id as progress token
needs_progress_token = progress_callback is not None or reset_timeout_on_progress
if needs_progress_token:
if "params" not in request_data: # pragma: lax no cover
request_data["params"] = {}
if "_meta" not in request_data["params"]: # pragma: lax no cover
request_data["params"]["_meta"] = {}
request_data["params"]["_meta"]["progressToken"] = request_id
# Store the callback for this request
self._progress_callbacks[request_id] = progress_callback
# Register the callback (no-op if only reset_timeout_on_progress is set)
if progress_callback is not None:
self._progress_callbacks[request_id] = progress_callback
else:
self._progress_callbacks[request_id] = _noop_progress_callback

try:
target = request_data.get("params", {}).get("name")
Expand All @@ -288,13 +336,58 @@ async def send_request(
# request read timeout takes precedence over session read timeout
timeout = request_read_timeout_seconds or self._session_read_timeout_seconds

try:
with anyio.fail_after(timeout):
response_or_error = await response_stream_reader.receive()
except TimeoutError:
class_name = request.__class__.__name__
message = f"Timed out while waiting for response to {class_name}. Waited {timeout} seconds."
raise MCPError(code=REQUEST_TIMEOUT, message=message)
# --- Progress-reset timeout path ---
if reset_timeout_on_progress and timeout is not None:
timeout_info = _ProgressTimeoutInfo(
timeout=timeout,
start_time=time.monotonic(),
max_total_timeout=max_total_timeout,
)
self._progress_timeout_infos[request_id] = timeout_info

try:
while True:
with anyio.move_on_after(timeout_info.timeout):
response_or_error = await response_stream_reader.receive()
break # received response

# Timeout window expired — check for progress reset
if timeout_info.was_reset:
timeout_info.was_reset = False
# Check absolute ceiling
if timeout_info.max_total_timeout is not None:
elapsed = time.monotonic() - timeout_info.start_time
if elapsed >= timeout_info.max_total_timeout:
class_name = request.__class__.__name__
raise MCPError(
code=REQUEST_TIMEOUT,
message=(
f"Maximum total timeout exceeded for {class_name}. "
f"Elapsed {elapsed:.1f}s, limit {timeout_info.max_total_timeout:.1f}s."
),
)
continue # restart timeout window

# No progress reset — genuine timeout
class_name = request.__class__.__name__
raise MCPError(
code=REQUEST_TIMEOUT,
message=(
f"Timed out while waiting for response to {class_name}. Waited {timeout} seconds."
),
)
finally:
self._progress_timeout_infos.pop(request_id, None)

# --- Standard timeout path ---
else:
try:
with anyio.fail_after(timeout):
response_or_error = await response_stream_reader.receive()
except TimeoutError:
class_name = request.__class__.__name__
message = f"Timed out while waiting for response to {class_name}. Waited {timeout} seconds."
raise MCPError(code=REQUEST_TIMEOUT, message=message)

if isinstance(response_or_error, JSONRPCError):
raise MCPError.from_jsonrpc_error(response_or_error)
Expand Down Expand Up @@ -405,6 +498,21 @@ async def _handle_session_message(message: SessionMessage) -> None:
# If there is a progress callback for this token,
# call it with the progress information
if progress_token in self._progress_callbacks:
# Signal timeout reset before invoking callback
timeout_info = self._progress_timeout_infos.get(progress_token)
if timeout_info is not None:
# Check max total timeout first
if timeout_info.max_total_timeout is not None:
elapsed = time.monotonic() - timeout_info.start_time
if elapsed >= timeout_info.max_total_timeout:
timeout_info.max_exceeded = True
timeout_info.was_reset = True
# Skip progress callback (matches TypeScript SDK)
await self._received_notification(notification)
await self._handle_incoming(notification)
return
timeout_info.was_reset = True

callback = self._progress_callbacks[progress_token]
try:
await callback(
Expand Down
Loading
Loading