Skip to content

Commit fc58dc5

Browse files
Adding a test for cancelling upload
1 parent 46f65ee commit fc58dc5

1 file changed

Lines changed: 116 additions & 2 deletions

File tree

tests/v2/test_v2_session_lifecycle.py

Lines changed: 116 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
11
import asyncio
22
import logging
3-
from typing import AsyncIterator, Awaitable, Callable, TypeAlias, TypedDict
3+
from typing import (
4+
Any,
5+
AsyncIterator,
6+
Awaitable,
7+
Callable,
8+
Literal,
9+
TypeAlias,
10+
TypedDict,
11+
)
412

513
import msgpack
614
import nanoid
@@ -20,7 +28,12 @@
2028
)
2129
from replit_river.transport_options import TransportOptions, UriAndMetadata
2230
from replit_river.v2.client import Client
23-
from replit_river.v2.session import STREAM_CLOSED_BIT, Session
31+
from replit_river.v2.session import STREAM_CANCEL_BIT, STREAM_CLOSED_BIT, Session
32+
33+
34+
class OuterPayload[A](TypedDict):
35+
ok: Literal[True]
36+
payload: A
2437

2538

2639
class _PermissiveRateLimiter(RateLimiter):
@@ -247,3 +260,104 @@ async def handle_server_messages() -> None:
247260
# Ensure we're listening to close messages as well
248261
stream_handler.cancel()
249262
await stream_handler
263+
264+
265+
async def test_upload_cancel(ws_server: WsServerFixture) -> None:
266+
(urimeta, recv, conn) = ws_server
267+
268+
client = Client(
269+
client_id="CLIENT1",
270+
server_id="SERVER",
271+
transport_options=TransportOptions(),
272+
uri_and_metadata_factory=urimeta,
273+
)
274+
275+
connecting = asyncio.create_task(client.ensure_connected())
276+
request_msg = parse_transport_msg(await recv.get())
277+
278+
assert not isinstance(request_msg, str)
279+
assert (serverconn := conn())
280+
handshake_request: ControlMessageHandshakeRequest[None] = (
281+
ControlMessageHandshakeRequest(**request_msg.payload)
282+
)
283+
284+
handshake_resp = ControlMessageHandshakeResponse(
285+
status=HandShakeStatus(
286+
ok=True,
287+
),
288+
)
289+
handshake_request.sessionId
290+
291+
msg = TransportMessage(
292+
from_=request_msg.from_,
293+
to=request_msg.to,
294+
streamId=request_msg.streamId,
295+
controlFlags=0,
296+
id=nanoid.generate(),
297+
seq=0,
298+
ack=0,
299+
payload=handshake_resp.model_dump(),
300+
)
301+
packed = msgpack.packb(
302+
msg.model_dump(by_alias=True, exclude_none=True), datetime=True
303+
)
304+
await serverconn.send(packed)
305+
306+
async def handle_server_messages() -> None:
307+
request_msg = parse_transport_msg(await recv.get())
308+
assert not isinstance(request_msg, str)
309+
310+
logging.debug("request_msg: %r", repr(request_msg))
311+
312+
msg = TransportMessage(**msgpack.unpackb(await recv.get()))
313+
while msg.payload.get("payload", {}).get("hello") == "world":
314+
logging.debug("Found a hello:world %r", repr(msg))
315+
msg = TransportMessage(**msgpack.unpackb(await recv.get()))
316+
317+
assert msg.controlFlags == STREAM_CANCEL_BIT
318+
319+
stream_handler = asyncio.create_task(handle_server_messages())
320+
321+
sent_waiter = asyncio.Event()
322+
323+
async def upload_chunks() -> AsyncIterator[OuterPayload[dict[Any, Any]]]:
324+
count = 0
325+
while True:
326+
yield {
327+
"ok": True,
328+
"payload": {
329+
"hello": "world",
330+
},
331+
}
332+
count += 1
333+
if count > 5:
334+
sent_waiter.set()
335+
await asyncio.sleep(0.1)
336+
337+
upload_task = asyncio.create_task(
338+
client.send_upload(
339+
"test",
340+
"bigstream",
341+
{},
342+
upload_chunks(),
343+
lambda x: x,
344+
lambda x: x,
345+
lambda x: x,
346+
lambda x: x,
347+
)
348+
)
349+
350+
await sent_waiter.wait()
351+
352+
upload_task.cancel()
353+
try:
354+
await upload_task
355+
except asyncio.CancelledError:
356+
pass
357+
358+
await client.close()
359+
await connecting
360+
361+
# Ensure we're listening to close messages as well
362+
stream_handler.cancel()
363+
await stream_handler

0 commit comments

Comments
 (0)