Skip to content

Commit edfb1e9

Browse files
Fix server streaming handler not cancelled on client disconnect (#175)
Signed-off-by: Stefan VanBuren <svanburen@buf.build>
1 parent 9ae3455 commit edfb1e9

2 files changed

Lines changed: 122 additions & 13 deletions

File tree

src/connectrpc/_server_async.py

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from __future__ import annotations
22

33
import base64
4+
import contextlib
45
import functools
56
import inspect
67
from abc import ABC, abstractmethod
7-
from asyncio import CancelledError, sleep
8+
from asyncio import CancelledError, Event, create_task, sleep
89
from dataclasses import replace
910
from http import HTTPStatus
1011
from typing import TYPE_CHECKING, Generic, TypeVar, cast
@@ -387,6 +388,9 @@ async def _handle_stream(
387388
self._read_max_bytes,
388389
)
389390

391+
disconnect_detected: Event | None = None
392+
monitor_task = None
393+
390394
match endpoint:
391395
case EndpointUnary():
392396
request = await _consume_single_request(request_stream)
@@ -398,22 +402,50 @@ async def _handle_stream(
398402
case EndpointServerStream():
399403
request = await _consume_single_request(request_stream)
400404
response_stream = endpoint.function(request, ctx)
405+
406+
# The request has been fully consumed; monitor receive() for a
407+
# client disconnect so we can stop streaming promptly.
408+
disconnect_detected = Event()
409+
410+
async def _watch_for_disconnect() -> None:
411+
while True:
412+
msg = await receive()
413+
if msg["type"] == "http.disconnect":
414+
disconnect_detected.set()
415+
return
416+
417+
monitor_task = create_task(_watch_for_disconnect())
401418
case EndpointBidiStream():
402419
response_stream = endpoint.function(request_stream, ctx)
403420

404-
async for message in response_stream:
405-
# Don't send headers until the first message to allow logic a chance to add
406-
# response headers.
407-
if not sent_headers:
408-
await _send_stream_response_headers(
409-
send, protocol, codec, resp_compression.name(), ctx
421+
try:
422+
async for message in response_stream:
423+
if disconnect_detected is not None and disconnect_detected.is_set():
424+
raise ConnectError(Code.CANCELED, "Client disconnected")
425+
# Don't send headers until the first message to allow logic a chance to add
426+
# response headers.
427+
if not sent_headers:
428+
await _send_stream_response_headers(
429+
send, protocol, codec, resp_compression.name(), ctx
430+
)
431+
sent_headers = True
432+
433+
body = writer.write(message)
434+
await send(
435+
{"type": "http.response.body", "body": body, "more_body": True}
410436
)
411-
sent_headers = True
412-
413-
body = writer.write(message)
414-
await send(
415-
{"type": "http.response.body", "body": body, "more_body": True}
416-
)
437+
finally:
438+
# Cancel the monitor first so a throwing generator finally-block
439+
# doesn't leak the task.
440+
if monitor_task is not None:
441+
monitor_task.cancel()
442+
with contextlib.suppress(CancelledError):
443+
await monitor_task
444+
# Explicitly close the stream so that any generator finally-blocks
445+
# run promptly (Python defers async-generator cleanup to GC otherwise).
446+
aclose = getattr(response_stream, "aclose", None)
447+
if aclose is not None:
448+
await aclose()
417449
except CancelledError as e:
418450
raise ConnectError(Code.CANCELED, "Request was cancelled") from e
419451
except Exception as e:

test/test_roundtrip.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
import asyncio
4+
import struct
35
from typing import TYPE_CHECKING
46

57
import pytest
@@ -23,6 +25,8 @@
2325
if TYPE_CHECKING:
2426
from collections.abc import AsyncIterator, Iterator
2527

28+
from asgiref.typing import HTTPDisconnectEvent, HTTPRequestEvent, HTTPScope
29+
2630

2731
@pytest.mark.parametrize("proto_json", [False, True])
2832
@pytest.mark.parametrize("compression_name", ["gzip", "br", "zstd", "identity"])
@@ -280,3 +284,76 @@ async def request_stream():
280284
else:
281285
assert len(requests) == 2
282286
assert len(responses) == 1
287+
288+
289+
@pytest.mark.asyncio
290+
async def test_server_stream_client_disconnect() -> None:
291+
"""Server streaming generator should be closed when the client disconnects.
292+
293+
Regression test for https://github.com/connectrpc/connect-python/issues/174.
294+
"""
295+
generator_closed = asyncio.Event()
296+
297+
class InfiniteHaberdasher(Haberdasher):
298+
async def make_similar_hats(self, request, ctx):
299+
try:
300+
while True:
301+
yield Hat(size=request.inches, color="green")
302+
await asyncio.sleep(0) # yield control to event loop
303+
finally:
304+
generator_closed.set()
305+
306+
app = HaberdasherASGIApplication(InfiniteHaberdasher())
307+
308+
# Encode a Connect protocol (application/connect+proto) request for Size(inches=10).
309+
request_bytes = Size(inches=10).SerializeToString()
310+
request_body = struct.pack(">BI", 0, len(request_bytes)) + request_bytes
311+
312+
# We invoke the ASGI app directly rather than using a real client with a
313+
# short timeout because a real client could trigger the disconnect before the
314+
# request body has been fully read, which would be a different code path.
315+
disconnect_trigger = asyncio.Event()
316+
response_count = 0
317+
call_count = 0
318+
319+
async def receive() -> HTTPRequestEvent | HTTPDisconnectEvent:
320+
nonlocal call_count
321+
call_count += 1
322+
if call_count == 1:
323+
return {"type": "http.request", "body": request_body, "more_body": False}
324+
# Block until the test is ready to simulate a disconnect.
325+
await disconnect_trigger.wait()
326+
return {"type": "http.disconnect"}
327+
328+
async def send(message):
329+
nonlocal response_count
330+
if message.get("type") == "http.response.body" and message.get(
331+
"more_body", False
332+
):
333+
response_count += 1
334+
if response_count >= 3:
335+
disconnect_trigger.set()
336+
337+
scope: HTTPScope = {
338+
"type": "http",
339+
"asgi": {"spec_version": "2.0", "version": "3.0"},
340+
"http_version": "1.1",
341+
"method": "POST",
342+
"scheme": "http",
343+
"path": "/connectrpc.example.Haberdasher/MakeSimilarHats",
344+
"raw_path": b"/connectrpc.example.Haberdasher/MakeSimilarHats",
345+
"query_string": b"",
346+
"root_path": "",
347+
"headers": [(b"content-type", b"application/connect+proto")],
348+
"client": None,
349+
"server": None,
350+
"extensions": None,
351+
}
352+
353+
# Without the fix the app hangs forever (generator never stopped), causing a
354+
# TimeoutError here. With the fix it terminates promptly after the disconnect.
355+
await asyncio.wait_for(app(scope, receive, send), timeout=5.0)
356+
357+
assert generator_closed.is_set(), (
358+
"generator should be closed after client disconnect"
359+
)

0 commit comments

Comments
 (0)