|
4 | 4 | from collections.abc import AsyncIterable |
5 | 5 | from contextlib import asynccontextmanager |
6 | 6 | from dataclasses import dataclass |
7 | | -from datetime import datetime, timedelta |
| 7 | +from datetime import timedelta |
8 | 8 | from typing import ( |
9 | 9 | Any, |
10 | 10 | AsyncGenerator, |
@@ -169,6 +169,7 @@ class Session[HandshakeMetadata]: |
169 | 169 |
|
170 | 170 | # Terminating |
171 | 171 | _terminating_task: asyncio.Task[None] | None |
| 172 | + _closing_waiter: asyncio.Event | None |
172 | 173 |
|
173 | 174 | def __init__( |
174 | 175 | self, |
@@ -228,6 +229,7 @@ def __init__( |
228 | 229 |
|
229 | 230 | # Terminating |
230 | 231 | self._terminating_task = None |
| 232 | + self._closing_waiter = None |
231 | 233 |
|
232 | 234 | self._start_recv_from_ws() |
233 | 235 | self._start_buffered_message_sender() |
@@ -415,27 +417,25 @@ async def close( |
415 | 417 | _wait_for_closed: bool = True, |
416 | 418 | ) -> None: |
417 | 419 | """Close the session and all associated streams.""" |
418 | | - if (current_state or self._state) in TerminalStates: |
419 | | - start = datetime.now() |
420 | | - while ( |
421 | | - _wait_for_closed |
422 | | - and (current_state or self._state) != SessionState.CLOSED |
423 | | - ): |
424 | | - elapsed = (datetime.now() - start).total_seconds() |
425 | | - if elapsed >= SESSION_CLOSE_TIMEOUT_SEC: |
426 | | - logger.warning( |
427 | | - f"Session took longer than {SESSION_CLOSE_TIMEOUT_SEC} " |
428 | | - "seconds to close, leaking", |
429 | | - ) |
430 | | - break |
| 420 | + if self._closing_waiter: |
| 421 | + # Break early for internal callers |
| 422 | + if not _wait_for_closed: |
| 423 | + return |
| 424 | + try: |
431 | 425 | logger.debug("Session already closing, waiting...") |
432 | | - await asyncio.sleep(0.2) |
433 | | - # already closing |
| 426 | + async with asyncio.timeout(SESSION_CLOSE_TIMEOUT_SEC): |
| 427 | + await self._closing_waiter.wait() |
| 428 | + except asyncio.TimeoutError: |
| 429 | + logger.warning( |
| 430 | + f"Session took longer than {SESSION_CLOSE_TIMEOUT_SEC} " |
| 431 | + "seconds to close, leaking", |
| 432 | + ) |
434 | 433 | return |
435 | 434 | logger.info( |
436 | 435 | f"{self.session_id} closing session to {self._server_id}, ws: {self._ws}" |
437 | 436 | ) |
438 | 437 | self._state = SessionState.CLOSING |
| 438 | + self._closing_waiter = asyncio.Event() |
439 | 439 |
|
440 | 440 | # We're closing, so we need to wake up... |
441 | 441 | # ... tasks waiting for connection to be established |
@@ -501,6 +501,10 @@ async def close( |
501 | 501 | # This will get us GC'd, so this should be the last thing. |
502 | 502 | self._close_session_callback(self) |
503 | 503 |
|
| 504 | + # Release waiters, then release the event |
| 505 | + self._closing_waiter.set() |
| 506 | + self._closing_waiter = None |
| 507 | + |
504 | 508 | def _start_buffered_message_sender( |
505 | 509 | self, |
506 | 510 | ) -> None: |
|
0 commit comments