Skip to content

Commit bf02c02

Browse files
Replacing "while" with an asyncio.Event
1 parent 810ed9f commit bf02c02

1 file changed

Lines changed: 20 additions & 16 deletions

File tree

src/replit_river/v2/session.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from collections.abc import AsyncIterable
55
from contextlib import asynccontextmanager
66
from dataclasses import dataclass
7-
from datetime import datetime, timedelta
7+
from datetime import timedelta
88
from typing import (
99
Any,
1010
AsyncGenerator,
@@ -169,6 +169,7 @@ class Session[HandshakeMetadata]:
169169

170170
# Terminating
171171
_terminating_task: asyncio.Task[None] | None
172+
_closing_waiter: asyncio.Event | None
172173

173174
def __init__(
174175
self,
@@ -228,6 +229,7 @@ def __init__(
228229

229230
# Terminating
230231
self._terminating_task = None
232+
self._closing_waiter = None
231233

232234
self._start_recv_from_ws()
233235
self._start_buffered_message_sender()
@@ -415,27 +417,25 @@ async def close(
415417
_wait_for_closed: bool = True,
416418
) -> None:
417419
"""Close the session and all associated streams."""
418-
if (current_state or self._state) in TerminalStates:
419-
start = datetime.now()
420-
while (
421-
_wait_for_closed
422-
and (current_state or self._state) != SessionState.CLOSED
423-
):
424-
elapsed = (datetime.now() - start).total_seconds()
425-
if elapsed >= SESSION_CLOSE_TIMEOUT_SEC:
426-
logger.warning(
427-
f"Session took longer than {SESSION_CLOSE_TIMEOUT_SEC} "
428-
"seconds to close, leaking",
429-
)
430-
break
420+
if self._closing_waiter:
421+
# Break early for internal callers
422+
if not _wait_for_closed:
423+
return
424+
try:
431425
logger.debug("Session already closing, waiting...")
432-
await asyncio.sleep(0.2)
433-
# already closing
426+
async with asyncio.timeout(SESSION_CLOSE_TIMEOUT_SEC):
427+
await self._closing_waiter.wait()
428+
except asyncio.TimeoutError:
429+
logger.warning(
430+
f"Session took longer than {SESSION_CLOSE_TIMEOUT_SEC} "
431+
"seconds to close, leaking",
432+
)
434433
return
435434
logger.info(
436435
f"{self.session_id} closing session to {self._server_id}, ws: {self._ws}"
437436
)
438437
self._state = SessionState.CLOSING
438+
self._closing_waiter = asyncio.Event()
439439

440440
# We're closing, so we need to wake up...
441441
# ... tasks waiting for connection to be established
@@ -501,6 +501,10 @@ async def close(
501501
# This will get us GC'd, so this should be the last thing.
502502
self._close_session_callback(self)
503503

504+
# Release waiters, then release the event
505+
self._closing_waiter.set()
506+
self._closing_waiter = None
507+
504508
def _start_buffered_message_sender(
505509
self,
506510
) -> None:

0 commit comments

Comments
 (0)