diff --git a/core/api/authorizer.py b/core/api/authorizer.py index 5ebe04f..1a6c74b 100644 --- a/core/api/authorizer.py +++ b/core/api/authorizer.py @@ -38,7 +38,11 @@ def decorator(func: typing.Callable[[Arg(KibaApiRequest[ApiRequest], 'request')] @functools.wraps(func) async def async_wrapper(request: KibaApiRequest[ApiRequest]) -> typing.Any: # type: ignore[explicit-any, misc] request.authJwt = await _authorize_bearer_jwt(request=request, authorizer=authorizer) - return await func(request=request) + result = func(request=request) + # NOTE(krishan711): this is here to support streaming responses which return an async generator + if hasattr(result, '__aiter__'): + return result + return await result # TODO(krishan711): figure out correct typing here return async_wrapper # type: ignore[return-value] diff --git a/core/api/streaming_json_route.py b/core/api/streaming_json_route.py index 56c9e7f..0bd7728 100644 --- a/core/api/streaming_json_route.py +++ b/core/api/streaming_json_route.py @@ -1,4 +1,5 @@ import functools +import inspect import typing from collections.abc import AsyncIterator from typing import ParamSpec @@ -29,11 +30,13 @@ def streaming_json_route[ApiRequest: BaseModel, ApiResponse: BaseModel]( ) -> typing.Callable[[typing.Callable[[KibaApiRequest[ApiRequest]], AsyncIterator[ApiResponse]]], typing.Callable[_P, StreamingResponse]]: def decorator(func: typing.Callable[[KibaApiRequest[ApiRequest]], AsyncIterator[ApiResponse]]) -> typing.Callable[_P, StreamingResponse]: @functools.wraps(func) - async def async_wrapper(*args: typing.Any) -> StreamingResponse: # type: ignore[explicit-any, misc] - receivedRequest = args[0] + async def async_wrapper(*args: typing.Any, **kwargs: typing.Any) -> StreamingResponse: # type: ignore[explicit-any, misc] + receivedRequest = kwargs.get('request', args[0] if args else None) + if receivedRequest is None: + raise BadRequestException('Missing request') pathParams = receivedRequest.path_params queryParams = receivedRequest.query_params - bodyBytes = await args[0].body() + bodyBytes = await receivedRequest.body() if len(bodyBytes) == 0: body: JsonObject = {} else: @@ -49,9 +52,11 @@ async def async_wrapper(*args: typing.Any) -> StreamingResponse: # type: ignore raise BadRequestException(f'Invalid request: {validationErrorMessage}') kibaRequest: KibaApiRequest[ApiRequest] = KibaApiRequest(scope=receivedRequest.scope, receive=receivedRequest._receive, send=receivedRequest._send) # noqa: SLF001 kibaRequest.data = requestParams - responseGenerator = func(kibaRequest) + responseGeneratorOrAwaitable = func(kibaRequest) + responseGenerator = await responseGeneratorOrAwaitable if inspect.isawaitable(responseGeneratorOrAwaitable) else responseGeneratorOrAwaitable wrappedGenerator = _convert_to_json_generator(typing.cast(AsyncIterator[BaseModel], responseGenerator), expectedType=typing.cast(typing.Type[BaseModel], responseType)) - return StreamingResponse(content=wrappedGenerator, media_type='application/x-ndjson') + # NOTE(krishan711): we set content-encoding to identity to prevent gzip from trying to process it (cos it buffers all the content) + return StreamingResponse(content=wrappedGenerator, media_type='application/x-ndjson', headers={'Content-Encoding': 'identity'}) # TODO(krishan711): figure out correct typing here return async_wrapper # type: ignore[return-value] diff --git a/tests/api/test_streaming_json_route.py b/tests/api/test_streaming_json_route.py index 02bce66..8056f5c 100644 --- a/tests/api/test_streaming_json_route.py +++ b/tests/api/test_streaming_json_route.py @@ -84,6 +84,7 @@ def test_streaming_json_route_with_valid_body(client): ) assert response.status_code == 200 assert response.headers["content-type"] == "application/x-ndjson" + assert response.headers["content-encoding"] == "identity" dataList = parse_ndjson_response(response) assert len(dataList) == 2 first_message = dataList[0]