diff --git a/core/api/middleware/database_connection_middleware.py b/core/api/middleware/database_connection_middleware.py index 8d179ce..88a00b0 100644 --- a/core/api/middleware/database_connection_middleware.py +++ b/core/api/middleware/database_connection_middleware.py @@ -12,6 +12,7 @@ def __init__(self, app: ASGIApp, database: Database) -> None: super().__init__(app=app) self.database = database + # NOTE(krishan711): see note in database.py about why this can cause problems with concurrent operations async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response: # NOTE(krishan711): hack to prevent running this for streaming endpoints because streaming # endpoints return a response with a generator inside it so this middleware wouldn't work diff --git a/core/store/database.py b/core/store/database.py index 42242b4..29f2365 100644 --- a/core/store/database.py +++ b/core/store/database.py @@ -4,12 +4,14 @@ from collections.abc import AsyncIterator from typing import TypeVar +import sqlalchemy from sqlalchemy.engine import Result from sqlalchemy.ext.asyncio import AsyncConnection from sqlalchemy.ext.asyncio import AsyncEngine from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.sql.selectable import TypedReturnsRows +from core import logging from core.exceptions import InternalServerErrorException DatabaseConnection = AsyncConnection @@ -34,9 +36,16 @@ def __init__(self, connectionString: str) -> None: self._engine: AsyncEngine | None = None self._connectionContext = contextvars.ContextVar[DatabaseConnection | None]('_connectionContext') - async def connect(self) -> None: + async def connect(self, poolSize: int = 100) -> None: if not self._engine: - self._engine = create_async_engine(self.connectionString, future=True) + self._engine = create_async_engine( + self.connectionString, + # echo_pool=True, + # hide_parameters=False, + pool_size=poolSize, + pool_recycle=3600, + pool_pre_ping=True, + ) async def disconnect(self) -> None: if self._engine: @@ -59,18 +68,32 @@ def _get_context_connection(self) -> DatabaseConnection | None: pass return None + # NOTE(krishan711): this is a little confusing. We creaete a connection for each erquest + # but if anything inside that request wants to do parallel queries, they should create + # their own transaction using `self.database.create_transaction()`, because asyncpg (and psql) + # do not support parallel queries on the same connection. This shows up badly if there is an + # uncaught exception raised whilst parallel queries are running. + # We have the forced reconnect at the bottom just to catch for this wierd case. @contextlib.asynccontextmanager async def create_context_connection(self) -> AsyncIterator[DatabaseConnection]: if not self._engine: raise InternalServerErrorException(message='Engine has not been established. Please called collect() first.') if self._get_context_connection() is not None: raise InternalServerErrorException(message='Connection has already been established in this context.') - async with self._engine.begin() as connection: - self._connectionContext.set(connection) - try: - yield connection - finally: - self._connectionContext.set(None) + connection = None + try: + async with self._engine.begin() as connection: + self._connectionContext.set(connection) + try: + yield connection + finally: + self._connectionContext.set(None) + except sqlalchemy.exc.InterfaceError as exception: + if 'cannot perform operation: another operation is in progress' not in str(exception): + raise + logging.error(f'Database connection error (likely concurrent operations): {exception}. Forcing reconnect. You MUST ensure that you are not running parallel queries on the same connection.') + await self.disconnect() + await self.connect() async def execute(self, query: TypedReturnsRows[ResultType], connection: DatabaseConnection | None = None) -> Result[ResultType]: if not self._engine: