Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,3 @@ COPY uv.lock .
RUN make install

COPY . .
Comment thread
krishan711 marked this conversation as resolved.
RUN make type-check
88 changes: 86 additions & 2 deletions core/api/authorizer.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,32 @@
import functools
import typing
from collections.abc import AsyncIterator
from typing import ParamSpec

from mypy_extensions import Arg
from pydantic import BaseModel

from core import logging
from core.api.api_request import KibaApiRequest
from core.exceptions import ForbiddenException
from core.exceptions import UnauthorizedException
from core.http.basic_authentication import BasicAuthentication
from core.http.jwt import Jwt

_P = ParamSpec('_P')
_AnyReturn = typing.Awaitable[typing.Any] | AsyncIterator[typing.Any] # type: ignore[explicit-any]


class Authorizer:
async def validate_jwt(self, jwtString: str) -> Jwt:
raise NotImplementedError


class SignatureAuthorizer:
async def retrieve_signature_signer(self, signatureString: str) -> str:
raise NotImplementedError


async def _authorize_bearer_jwt[ApiRequest: BaseModel](request: KibaApiRequest[ApiRequest], authorizer: Authorizer) -> Jwt:
authorization = request.headers.get('Authorization')
if not authorization:
Expand All @@ -33,8 +43,8 @@ async def _authorize_bearer_jwt[ApiRequest: BaseModel](request: KibaApiRequest[A

def authorize_bearer_jwt[ApiRequest: BaseModel]( # type: ignore[explicit-any]
authorizer: Authorizer,
) -> typing.Callable[[typing.Callable[[Arg(KibaApiRequest[ApiRequest], 'request')], typing.Awaitable[typing.Any]]], typing.Callable[_P, typing.Any]]:
def decorator(func: typing.Callable[[Arg(KibaApiRequest[ApiRequest], 'request')], typing.Awaitable[typing.Any]]) -> typing.Callable[_P, typing.Any]: # type: ignore[explicit-any]
) -> typing.Callable[[typing.Callable[[Arg(KibaApiRequest[ApiRequest], 'request')], _AnyReturn]], typing.Callable[_P, typing.Any]]:
def decorator(func: typing.Callable[[Arg(KibaApiRequest[ApiRequest], 'request')], _AnyReturn]) -> typing.Callable[_P, typing.Any]: # type: ignore[explicit-any]
@functools.wraps(func)
async def async_wrapper(request: KibaApiRequest[ApiRequest]) -> typing.Any: # type: ignore[explicit-any, misc]
request.authJwt = await _authorize_bearer_jwt(request=request, authorizer=authorizer)
Expand All @@ -48,3 +58,77 @@ async def async_wrapper(request: KibaApiRequest[ApiRequest]) -> typing.Any: # t
return async_wrapper # type: ignore[return-value]

return decorator


async def get_basic_authentication_from_authorization_signature[ApiRequest: BaseModel](request: KibaApiRequest[ApiRequest], authorizer: SignatureAuthorizer) -> BasicAuthentication:
authorization = request.headers.get('Authorization')
if not authorization:
raise ForbiddenException(message='AUTH_NOT_PROVIDED')
if not authorization.startswith('Signature '):
raise ForbiddenException(message='AUTH_INVALID')
signatureString = authorization.replace('Signature ', '', 1)
try:
signerId = await authorizer.retrieve_signature_signer(signatureString=signatureString)
except UnauthorizedException:
raise
except BaseException as exception: # noqa: BLE001
logging.exception(exception) # type: ignore[arg-type]
raise ForbiddenException(message='AUTH_INVALID')
return BasicAuthentication(username=signerId, password=signatureString)


def authorize_signature[ApiRequest: BaseModel]( # type: ignore[explicit-any]
authorizer: SignatureAuthorizer,
) -> typing.Callable[[typing.Callable[[Arg(KibaApiRequest[ApiRequest], 'request')], _AnyReturn]], typing.Callable[_P, typing.Any]]:
def decorator(func: typing.Callable[[Arg(KibaApiRequest[ApiRequest], 'request')], _AnyReturn]) -> typing.Callable[_P, typing.Any]: # type: ignore[explicit-any]
@functools.wraps(func)
async def async_wrapper(request: KibaApiRequest[ApiRequest]) -> typing.Any: # type: ignore[explicit-any, misc]
request.authBasic = await get_basic_authentication_from_authorization_signature(request=request, authorizer=authorizer)
result = func(request=request)
# NOTE(krishan711): this is here to support streaming responses which return an async generator
if hasattr(result, '__aiter__'):
return result
return await result

# TODO(krishan711): figure out correct typing here
return async_wrapper # type: ignore[return-value]

return decorator


class TokenAuthorizer:
async def validate_token(self, token: str) -> None:
raise NotImplementedError


class StaticTokenAuthorizer(TokenAuthorizer):
def __init__(self, token: str) -> None:
self._token = token

async def validate_token(self, token: str) -> None:
if token != self._token:
raise ForbiddenException(message='AUTH_INVALID')


def authorize_token[ApiRequest: BaseModel]( # type: ignore[explicit-any]
authorizer: TokenAuthorizer,
) -> typing.Callable[[typing.Callable[[Arg(KibaApiRequest[ApiRequest], 'request')], _AnyReturn]], typing.Callable[_P, typing.Any]]:
def decorator(func: typing.Callable[[Arg(KibaApiRequest[ApiRequest], 'request')], _AnyReturn]) -> typing.Callable[_P, typing.Any]: # type: ignore[explicit-any]
@functools.wraps(func)
async def async_wrapper(request: KibaApiRequest[ApiRequest]) -> typing.Any: # type: ignore[explicit-any, misc]
authorization = request.headers.get('Authorization')
if not authorization:
raise ForbiddenException(message='AUTH_NOT_PROVIDED')
if not authorization.startswith('Token '):
raise ForbiddenException(message='AUTH_INVALID')
await authorizer.validate_token(authorization[6:])
result = func(request=request)
# NOTE(krishan711): this is here to support streaming responses which return an async generator
if hasattr(result, '__aiter__'):
return result
return await result

# TODO(krishan711): figure out correct typing here
return async_wrapper # type: ignore[return-value]

return decorator
25 changes: 11 additions & 14 deletions core/api/middleware/database_connection_middleware.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,21 @@
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.base import RequestResponseEndpoint
from starlette.requests import Request
from starlette.responses import Response
from starlette.types import ASGIApp
from starlette.types import Receive
from starlette.types import Scope
from starlette.types import Send

from core.store.database import Database


class DatabaseConnectionMiddleware(BaseHTTPMiddleware):
class DatabaseConnectionMiddleware:
def __init__(self, app: ASGIApp, database: Database) -> None:
super().__init__(app=app)
self.app = app
self.database = database

# NOTE(krishan711): see note in database.py about why this can cause problems with concurrent operations
async def dispatch(self, request: Request, call_next: RequestResponseEndpoint) -> Response:
# NOTE(krishan711): hack to prevent running this for streaming endpoints because streaming
# endpoints return a response with a generator inside it so this middleware wouldn't work
if request.scope['path'].endswith('-streamed'):
return await call_next(request)
# isReadonly = request.method in {'GET', 'OPTIONS', 'HEAD'}
# NOTE(krishan711): raw ASGI (not BaseHTTPMiddleware) so the DB connection stays open across streaming body
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope['type'] != 'http':
await self.app(scope, receive, send)
return
async with self.database.create_context_connection():
response = await call_next(request)
return response
await self.app(scope, receive, send)
Comment on lines +19 to +21
Copy link

Copilot AI Apr 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

create_context_connection() uses engine.begin() (a transaction) and this middleware now keeps that transaction open until the streaming response fully completes. For long-lived streams this can hold locks/retain MVCC snapshots and tie up a pooled connection for the entire client session. Consider using a non-transactional connection context for request-scoped connections (e.g., engine.connect()), or splitting streaming endpoints so DB work completes (and transaction closes) before yielding to the client.

Copilot uses AI. Check for mistakes.
2 changes: 1 addition & 1 deletion core/api/streaming_json_route.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ async def async_wrapper(*args: typing.Any, **kwargs: typing.Any) -> StreamingRes
raise BadRequestException(f'Invalid request: {validationErrorMessage}')
kibaRequest: KibaApiRequest[ApiRequest] = KibaApiRequest(scope=receivedRequest.scope, receive=receivedRequest._receive, send=receivedRequest._send) # noqa: SLF001
kibaRequest.data = requestParams
responseGeneratorOrAwaitable = func(request=kibaRequest)
responseGeneratorOrAwaitable = func(kibaRequest)
responseGenerator = await responseGeneratorOrAwaitable if inspect.isawaitable(responseGeneratorOrAwaitable) else responseGeneratorOrAwaitable
wrappedGenerator = _convert_to_json_generator(typing.cast(AsyncIterator[BaseModel], responseGenerator), expectedType=typing.cast(typing.Type[BaseModel], responseType))
# NOTE(krishan711): we set content-encoding to identity to prevent gzip from trying to process it (cos it buffers all the content)
Expand Down
Loading
Loading