Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions core/api/middleware/database_connection_middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ def __init__(self, app: ASGIApp, database: Database) -> None:
super().__init__(app=app)
self.database = database

# NOTE(krishan711): see note in database.py about why this can cause problems with concurrent operations
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
# NOTE(krishan711): hack to prevent running this for streaming endpoints because streaming
# endpoints return a response with a generator inside it so this middleware wouldn't work
Expand Down
39 changes: 31 additions & 8 deletions core/store/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
from collections.abc import AsyncIterator
from typing import TypeVar

import sqlalchemy
from sqlalchemy.engine import Result
from sqlalchemy.ext.asyncio import AsyncConnection
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.sql.selectable import TypedReturnsRows

from core import logging
from core.exceptions import InternalServerErrorException

DatabaseConnection = AsyncConnection
Expand All @@ -34,9 +36,16 @@ def __init__(self, connectionString: str) -> None:
self._engine: AsyncEngine | None = None
self._connectionContext = contextvars.ContextVar[DatabaseConnection | None]('_connectionContext')

async def connect(self) -> None:
async def connect(self, poolSize: int = 100) -> None:
if not self._engine:
self._engine = create_async_engine(self.connectionString, future=True)
self._engine = create_async_engine(
self.connectionString,
# echo_pool=True,
# hide_parameters=False,
pool_size=poolSize,
pool_recycle=3600,
pool_pre_ping=True,
)

async def disconnect(self) -> None:
if self._engine:
Expand All @@ -59,18 +68,32 @@ def _get_context_connection(self) -> DatabaseConnection | None:
pass
return None

# NOTE(krishan711): this is a little confusing. We creaete a connection for each erquest
# but if anything inside that request wants to do parallel queries, they should create
# their own transaction using `self.database.create_transaction()`, because asyncpg (and psql)
# do not support parallel queries on the same connection. This shows up badly if there is an
# uncaught exception raised whilst parallel queries are running.
# We have the forced reconnect at the bottom just to catch for this wierd case.
@contextlib.asynccontextmanager
async def create_context_connection(self) -> AsyncIterator[DatabaseConnection]:
if not self._engine:
raise InternalServerErrorException(message='Engine has not been established. Please called collect() first.')
if self._get_context_connection() is not None:
raise InternalServerErrorException(message='Connection has already been established in this context.')
async with self._engine.begin() as connection:
self._connectionContext.set(connection)
try:
yield connection
finally:
self._connectionContext.set(None)
connection = None
try:
async with self._engine.begin() as connection:
self._connectionContext.set(connection)
try:
yield connection
finally:
self._connectionContext.set(None)
except sqlalchemy.exc.InterfaceError as exception:
if 'cannot perform operation: another operation is in progress' not in str(exception):
raise
logging.error(f'Database connection error (likely concurrent operations): {exception}. Forcing reconnect. You MUST ensure that you are not running parallel queries on the same connection.')
await self.disconnect()
await self.connect()

async def execute(self, query: TypedReturnsRows[ResultType], connection: DatabaseConnection | None = None) -> Result[ResultType]:
if not self._engine:
Expand Down
Loading