Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion core/api/authorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@ def decorator(func: typing.Callable[[Arg(KibaApiRequest[ApiRequest], 'request')]
@functools.wraps(func)
async def async_wrapper(request: KibaApiRequest[ApiRequest]) -> typing.Any: # type: ignore[explicit-any, misc]
request.authJwt = await _authorize_bearer_jwt(request=request, authorizer=authorizer)
return await func(request=request)
result = func(request=request)
# NOTE(krishan711): this is here to support streaming responses which return an async generator
if hasattr(result, '__aiter__'):
return result
return await result

# TODO(krishan711): figure out correct typing here
return async_wrapper # type: ignore[return-value]
Expand Down
15 changes: 10 additions & 5 deletions core/api/streaming_json_route.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import functools
import inspect
import typing
from collections.abc import AsyncIterator
from typing import ParamSpec
Expand Down Expand Up @@ -29,11 +30,13 @@ def streaming_json_route[ApiRequest: BaseModel, ApiResponse: BaseModel](
) -> typing.Callable[[typing.Callable[[KibaApiRequest[ApiRequest]], AsyncIterator[ApiResponse]]], typing.Callable[_P, StreamingResponse]]:
def decorator(func: typing.Callable[[KibaApiRequest[ApiRequest]], AsyncIterator[ApiResponse]]) -> typing.Callable[_P, StreamingResponse]:
@functools.wraps(func)
async def async_wrapper(*args: typing.Any) -> StreamingResponse: # type: ignore[explicit-any, misc]
receivedRequest = args[0]
async def async_wrapper(*args: typing.Any, **kwargs: typing.Any) -> StreamingResponse: # type: ignore[explicit-any, misc]
receivedRequest = kwargs.get('request', args[0] if args else None)
if receivedRequest is None:
raise BadRequestException('Missing request')
pathParams = receivedRequest.path_params
queryParams = receivedRequest.query_params
bodyBytes = await args[0].body()
bodyBytes = await receivedRequest.body()
if len(bodyBytes) == 0:
body: JsonObject = {}
else:
Expand All @@ -49,9 +52,11 @@ async def async_wrapper(*args: typing.Any) -> StreamingResponse: # type: ignore
raise BadRequestException(f'Invalid request: {validationErrorMessage}')
kibaRequest: KibaApiRequest[ApiRequest] = KibaApiRequest(scope=receivedRequest.scope, receive=receivedRequest._receive, send=receivedRequest._send) # noqa: SLF001
kibaRequest.data = requestParams
responseGenerator = func(kibaRequest)
responseGeneratorOrAwaitable = func(kibaRequest)
responseGenerator = await responseGeneratorOrAwaitable if inspect.isawaitable(responseGeneratorOrAwaitable) else responseGeneratorOrAwaitable
wrappedGenerator = _convert_to_json_generator(typing.cast(AsyncIterator[BaseModel], responseGenerator), expectedType=typing.cast(typing.Type[BaseModel], responseType))
return StreamingResponse(content=wrappedGenerator, media_type='application/x-ndjson')
# NOTE(krishan711): we set content-encoding to identity to prevent gzip from trying to process it (cos it buffers all the content)
return StreamingResponse(content=wrappedGenerator, media_type='application/x-ndjson', headers={'Content-Encoding': 'identity'})

# TODO(krishan711): figure out correct typing here
return async_wrapper # type: ignore[return-value]
Expand Down
1 change: 1 addition & 0 deletions tests/api/test_streaming_json_route.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ def test_streaming_json_route_with_valid_body(client):
)
assert response.status_code == 200
assert response.headers["content-type"] == "application/x-ndjson"
assert response.headers["content-encoding"] == "identity"
Copy link

Copilot AI Mar 5, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This assertion bakes Content-Encoding: identity into the public contract of the route, even though the absence of the header already implies identity. This makes the test brittle and will fail if response compression is legitimately introduced later (or if the header is removed as redundant). Consider asserting that the response is not encoded (e.g., header missing or not equal to gzip/br), or omit this assertion unless callers truly require it.

Suggested change
assert response.headers["content-encoding"] == "identity"
content_encoding = response.headers.get("content-encoding", "").lower()
assert content_encoding not in ("gzip", "br")

Copilot uses AI. Check for mistakes.
dataList = parse_ndjson_response(response)
assert len(dataList) == 2
first_message = dataList[0]
Expand Down
Loading