diff --git a/CHANGES/11283.bugfix.rst b/CHANGES/11283.bugfix.rst new file mode 100644 index 00000000000..966b9afbd00 --- /dev/null +++ b/CHANGES/11283.bugfix.rst @@ -0,0 +1,3 @@ +Fixed access log timestamps ignoring daylight saving time (DST) changes. The +previous implementation used :py:data:`time.timezone` which is a constant and +does not reflect DST transitions -- by :user:`nightcityblade`. diff --git a/CHANGES/11989.feature.rst b/CHANGES/11989.feature.rst new file mode 100644 index 00000000000..ced05b5e100 --- /dev/null +++ b/CHANGES/11989.feature.rst @@ -0,0 +1,7 @@ +Added explicit APIs for bytes-returning JSON serializer: +``JSONBytesEncoder`` type, ``JsonBytesPayload``, +:func:`~aiohttp.web.json_bytes_response`, +:meth:`~aiohttp.web.WebSocketResponse.send_json_bytes` and +:meth:`~aiohttp.ClientWebSocketResponse.send_json_bytes` methods, and +``json_serialize_bytes`` parameter for :class:`~aiohttp.ClientSession` +-- by :user:`kevinpark1217`. diff --git a/aiohttp/client.py b/aiohttp/client.py index 555d5678d4d..2aa8dfb8acf 100644 --- a/aiohttp/client.py +++ b/aiohttp/client.py @@ -106,7 +106,13 @@ from .http import WS_KEY, HttpVersion, WebSocketReader, WebSocketWriter from .http_websocket import WSHandshakeError, ws_ext_gen, ws_ext_parse from .tracing import Trace, TraceConfig -from .typedefs import JSONEncoder, LooseCookies, LooseHeaders, StrOrURL +from .typedefs import ( + JSONBytesEncoder, + JSONEncoder, + LooseCookies, + LooseHeaders, + StrOrURL, +) __all__ = ( # client_exceptions @@ -278,6 +284,7 @@ class ClientSession: "_default_auth", "_version", "_json_serialize", + "_json_serialize_bytes", "_requote_redirect_url", "_timeout", "_raise_for_status", @@ -312,6 +319,7 @@ def __init__( skip_auto_headers: Iterable[str] | None = None, auth: BasicAuth | None = None, json_serialize: JSONEncoder = json.dumps, + json_serialize_bytes: JSONBytesEncoder | None = None, request_class: type[ClientRequest] = ClientRequest, response_class: type[ClientResponse] = ClientResponse, ws_response_class: type[ClientWebSocketResponse] = ClientWebSocketResponse, @@ -390,6 +398,7 @@ def __init__( self._default_auth = auth self._version = version self._json_serialize = json_serialize + self._json_serialize_bytes = json_serialize_bytes self._raise_for_status = raise_for_status self._auto_decompress = auto_decompress self._trust_env = trust_env @@ -518,7 +527,10 @@ async def _request( "data and json parameters can not be used at the same time" ) elif json is not None: - data = payload.JsonPayload(json, dumps=self._json_serialize) + if self._json_serialize_bytes is not None: + data = payload.JsonBytesPayload(json, dumps=self._json_serialize_bytes) + else: + data = payload.JsonPayload(json, dumps=self._json_serialize) redirects = 0 history: list[ClientResponse] = [] diff --git a/aiohttp/client_ws.py b/aiohttp/client_ws.py index b387f4bfc94..ef48b430162 100644 --- a/aiohttp/client_ws.py +++ b/aiohttp/client_ws.py @@ -24,6 +24,7 @@ from .typedefs import ( DEFAULT_JSON_DECODER, DEFAULT_JSON_ENCODER, + JSONBytesEncoder, JSONDecoder, JSONEncoder, ) @@ -302,6 +303,20 @@ async def send_json( ) -> None: await self.send_str(dumps(data), compress=compress) + async def send_json_bytes( + self, + data: Any, + compress: int | None = None, + *, + dumps: JSONBytesEncoder, + ) -> None: + """Send JSON data using a bytes-returning encoder as a binary frame. + + Use this when your JSON encoder (like orjson) returns bytes + instead of str, avoiding the encode/decode overhead. + """ + await self.send_bytes(dumps(data), compress=compress) + async def close(self, *, code: int = WSCloseCode.OK, message: bytes = b"") -> bool: # we need to break `receive()` cycle first, # `close()` may be called from different task diff --git a/aiohttp/payload.py b/aiohttp/payload.py index d5996a0e915..9a8dc2f3262 100644 --- a/aiohttp/payload.py +++ b/aiohttp/payload.py @@ -23,7 +23,7 @@ sentinel, ) from .streams import StreamReader -from .typedefs import JSONEncoder +from .typedefs import JSONBytesEncoder, JSONEncoder __all__ = ( "PAYLOAD_REGISTRY", @@ -38,6 +38,7 @@ "TextIOPayload", "StringIOPayload", "JsonPayload", + "JsonBytesPayload", "AsyncIterablePayload", ) @@ -939,6 +940,29 @@ def __init__( ) +class JsonBytesPayload(BytesPayload): + """JSON payload for encoders that return bytes directly. + + Use this when your JSON encoder (like orjson) returns bytes + instead of str, avoiding the encode/decode overhead. + """ + + def __init__( + self, + value: Any, + dumps: JSONBytesEncoder, + content_type: str = "application/json", + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__( + dumps(value), + content_type=content_type, + *args, + **kwargs, + ) + + class AsyncIterablePayload(Payload): _iter: AsyncIterator[bytes] | None = None _value: AsyncIterable[bytes] diff --git a/aiohttp/typedefs.py b/aiohttp/typedefs.py index cd016a4e3c4..a98d793c7a7 100644 --- a/aiohttp/typedefs.py +++ b/aiohttp/typedefs.py @@ -15,6 +15,7 @@ Byteish = bytes | bytearray | memoryview JSONEncoder = Callable[[Any], str] +JSONBytesEncoder = Callable[[Any], bytes] JSONDecoder = Callable[[str], Any] LooseHeaders = ( Mapping[str, str] diff --git a/aiohttp/web.py b/aiohttp/web.py index b116b5913d1..15cfcc99b98 100644 --- a/aiohttp/web.py +++ b/aiohttp/web.py @@ -81,7 +81,13 @@ from .web_middlewares import middleware, normalize_path_middleware from .web_protocol import PayloadAccessError, RequestHandler, RequestPayloadError from .web_request import BaseRequest, FileField, Request -from .web_response import ContentCoding, Response, StreamResponse, json_response +from .web_response import ( + ContentCoding, + Response, + StreamResponse, + json_bytes_response, + json_response, +) from .web_routedef import ( AbstractRouteDef, RouteDef, @@ -208,6 +214,7 @@ "ContentCoding", "Response", "StreamResponse", + "json_bytes_response", "json_response", "ResponseKey", # web_routedef diff --git a/aiohttp/web_log.py b/aiohttp/web_log.py index 95b34e1029a..aafbf237ca6 100644 --- a/aiohttp/web_log.py +++ b/aiohttp/web_log.py @@ -5,7 +5,8 @@ import re import time as time_mod from collections import namedtuple -from typing import Any, Callable, Dict, Iterable, List, Tuple # noqa +from collections.abc import Iterable +from typing import Callable, ClassVar from .abc import AbstractAccessLogger from .web_request import BaseRequest @@ -60,6 +61,9 @@ class AccessLogger(AbstractAccessLogger): CLEANUP_RE = re.compile(r"(%[^s])") _FORMAT_CACHE: dict[str, tuple[str, list[KeyMethod]]] = {} + _cached_tz: ClassVar[datetime.timezone | None] = None + _cached_tz_expires: ClassVar[float] = 0.0 + def __init__(self, logger: logging.Logger, log_format: str = LOG_FORMAT) -> None: """Initialise the logger. @@ -136,10 +140,24 @@ def _format_a(request: BaseRequest, response: StreamResponse, time: float) -> st ip = request.remote return ip if ip is not None else "-" + @classmethod + def _get_local_time(cls) -> datetime.datetime: + if cls._cached_tz is None or time_mod.time() >= cls._cached_tz_expires: + gmtoff = time_mod.localtime().tm_gmtoff + cls._cached_tz = tz = datetime.timezone(datetime.timedelta(seconds=gmtoff)) + + now = datetime.datetime.now(tz) + # Expire at every 30 mins, as any DST change should occur at 0/30 mins past. + d = now + datetime.timedelta(minutes=30) + d = d.replace(minute=30 if d.minute >= 30 else 0, second=0, microsecond=0) + cls._cached_tz_expires = d.timestamp() + return now + + return datetime.datetime.now(cls._cached_tz) + @staticmethod def _format_t(request: BaseRequest, response: StreamResponse, time: float) -> str: - tz = datetime.timezone(datetime.timedelta(seconds=-time_mod.timezone)) - now = datetime.datetime.now(tz) + now = AccessLogger._get_local_time() start_time = now - datetime.timedelta(seconds=time) return start_time.strftime("[%d/%b/%Y:%H:%M:%S %z]") diff --git a/aiohttp/web_response.py b/aiohttp/web_response.py index 7b506716f8b..c88911c4dcc 100644 --- a/aiohttp/web_response.py +++ b/aiohttp/web_response.py @@ -32,12 +32,18 @@ ) from .http import SERVER_SOFTWARE, HttpVersion10, HttpVersion11 from .payload import Payload -from .typedefs import JSONEncoder, LooseHeaders +from .typedefs import JSONBytesEncoder, JSONEncoder, LooseHeaders REASON_PHRASES = {http_status.value: http_status.phrase for http_status in HTTPStatus} LARGE_BODY_SIZE = 1024**2 -__all__ = ("ContentCoding", "StreamResponse", "Response", "json_response") +__all__ = ( + "ContentCoding", + "StreamResponse", + "Response", + "json_response", + "json_bytes_response", +) if TYPE_CHECKING: @@ -758,3 +764,32 @@ def json_response( headers=headers, content_type=content_type, ) + + +def json_bytes_response( + data: Any = sentinel, + *, + dumps: JSONBytesEncoder, + body: bytes | None = None, + status: int = 200, + reason: str | None = None, + headers: LooseHeaders | None = None, + content_type: str = "application/json", +) -> Response: + """Create a JSON response using a bytes-returning encoder. + + Use this when your JSON encoder (like orjson) returns bytes + instead of str, avoiding the encode/decode overhead. + """ + if data is not sentinel: + if body is not None: + raise ValueError("only one of data or body should be specified") + else: + body = dumps(data) + return Response( + body=body, + status=status, + reason=reason, + headers=headers, + content_type=content_type, + ) diff --git a/aiohttp/web_ws.py b/aiohttp/web_ws.py index dee7225d428..ca129bb0f30 100644 --- a/aiohttp/web_ws.py +++ b/aiohttp/web_ws.py @@ -37,7 +37,7 @@ from .http_websocket import _INTERNAL_RECEIVE_TYPES, WSMessageError from .log import ws_logger from .streams import EofStream -from .typedefs import JSONDecoder, JSONEncoder +from .typedefs import JSONBytesEncoder, JSONDecoder, JSONEncoder from .web_exceptions import HTTPBadRequest, HTTPException from .web_request import BaseRequest from .web_response import StreamResponse @@ -481,6 +481,20 @@ async def send_json( ) -> None: await self.send_str(dumps(data), compress=compress) + async def send_json_bytes( + self, + data: Any, + compress: int | None = None, + *, + dumps: JSONBytesEncoder, + ) -> None: + """Send JSON data using a bytes-returning encoder as a binary frame. + + Use this when your JSON encoder (like orjson) returns bytes + instead of str, avoiding the encode/decode overhead. + """ + await self.send_bytes(dumps(data), compress=compress) + async def write_eof(self) -> None: # type: ignore[override] if self._eof_sent: return diff --git a/docs/client_reference.rst b/docs/client_reference.rst index ddd561c7e60..033dfb53ca7 100644 --- a/docs/client_reference.rst +++ b/docs/client_reference.rst @@ -1842,6 +1842,28 @@ manually. The method is converted into :term:`coroutine`, *compress* parameter added. + .. method:: send_json_bytes(data, compress=None, *, dumps) + :async: + + Send *data* to peer as a JSON binary frame using a bytes-returning encoder. + + :param data: data to send. + + :param int compress: sets specific level of compression for + single message, + ``None`` for not overriding per-socket setting. + + :param collections.abc.Callable dumps: any :term:`callable` that accepts an object and + returns JSON as :class:`bytes` + (e.g. ``orjson.dumps``). + + :raise RuntimeError: if connection is not started or closing + + :raise ValueError: if data is not serializable object + + :raise TypeError: if value returned by ``dumps(data)`` is not + :class:`bytes` + .. method:: send_frame(message, opcode, compress=None) :async: diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 82e1e65487d..9890b228e01 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -229,6 +229,7 @@ nowait OAuth Online optimizations +orjson os outcoming Overridable diff --git a/docs/web_reference.rst b/docs/web_reference.rst index 63fb510562d..fe2488c63b8 100644 --- a/docs/web_reference.rst +++ b/docs/web_reference.rst @@ -1208,6 +1208,27 @@ and :ref:`aiohttp-web-signals` handlers:: The method is converted into :term:`coroutine`, *compress* parameter added. + .. method:: send_json_bytes(data, compress=None, *, dumps) + :async: + + Send *data* to peer as a JSON binary frame using a bytes-returning encoder. + + :param data: data to send. + + :param int compress: sets specific level of compression for + single message, + ``None`` for not overriding per-socket setting. + + :param collections.abc.Callable dumps: any :term:`callable` that accepts an object and + returns JSON as :class:`bytes` + (e.g. ``orjson.dumps``). + + :raise RuntimeError: if the connection is not started. + + :raise ValueError: if data is not serializable object + + :raise TypeError: if value returned by ``dumps`` param is not :class:`bytes` + .. method:: send_frame(message, opcode, compress=None) :async: @@ -1389,6 +1410,18 @@ content type and *data* encoded by ``dumps`` parameter (:func:`json.dumps` by default). +.. function:: json_bytes_response([data], *, dumps, body=None, \ + status=200, reason=None, headers=None, \ + content_type='application/json') + +Return :class:`Response` with predefined ``'application/json'`` +content type and *data* encoded by ``dumps`` parameter +which must return :class:`bytes` directly (e.g. ``orjson.dumps``). + +Use this when your JSON encoder returns :class:`bytes` instead of :class:`str`, +avoiding the :class:`str`-to-:class:`bytes` encoding overhead. + + .. class:: ResponseKey(name, t) :canonical: aiohttp.helpers.ResponseKey diff --git a/tests/test_client_functional.py b/tests/test_client_functional.py index c191ce10910..0c432602893 100644 --- a/tests/test_client_functional.py +++ b/tests/test_client_functional.py @@ -2193,6 +2193,27 @@ def dumps(obj: Any) -> str: await client.post("/", data="some data", json={"some": "data"}) +async def test_json_serialize_bytes(aiohttp_client: AiohttpClient) -> None: + """Test ClientSession.json_serialize_bytes with bytes-returning encoder.""" + + async def handler(request: web.Request) -> web.Response: + assert request.content_type == "application/json" + data = await request.json() + return web.Response(body=aiohttp.JsonPayload(data)) + + json_bytes_encoder = mock.Mock(side_effect=lambda x: json.dumps(x).encode("utf-8")) + + app = web.Application() + app.router.add_post("/", handler) + client = await aiohttp_client(app, json_serialize_bytes=json_bytes_encoder) + + async with client.post("/", json={"some": "data"}) as resp: + assert resp.status == 200 + assert json_bytes_encoder.called + content = await resp.json() + assert content == {"some": "data"} + + async def test_expect_continue(aiohttp_client: AiohttpClient) -> None: expect_called = False diff --git a/tests/test_client_ws_functional.py b/tests/test_client_ws_functional.py index ddd1404579f..c58613c6ca9 100644 --- a/tests/test_client_ws_functional.py +++ b/tests/test_client_ws_functional.py @@ -207,6 +207,64 @@ async def handler(request: web.Request) -> web.WebSocketResponse: await resp.close() +async def test_send_json_bytes_client(aiohttp_client: AiohttpClient) -> None: + """Test ClientWebSocketResponse.send_json_bytes sends binary frame.""" + + async def handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse() + await ws.prepare(request) + + msg = await ws.receive() + assert msg.type is WSMsgType.BINARY + data = json.loads(msg.data) + await ws.send_json_bytes( + {"response": data["request"]}, + dumps=lambda x: json.dumps(x).encode("utf-8"), + ) + await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + resp = await client.ws_connect("/") + test_payload = {"request": "test"} + await resp.send_json_bytes( + test_payload, dumps=lambda x: json.dumps(x).encode("utf-8") + ) + + msg = await resp.receive() + assert msg.type is WSMsgType.BINARY + data = json.loads(msg.data) + assert data["response"] == test_payload["request"] + await resp.close() + + +async def test_send_json_bytes_custom_encoder(aiohttp_client: AiohttpClient) -> None: + """Test send_json_bytes with custom bytes-returning encoder.""" + + async def handler(request: web.Request) -> web.WebSocketResponse: + ws = web.WebSocketResponse() + await ws.prepare(request) + + msg = await ws.receive() + assert msg.type is WSMsgType.BINARY + # Custom encoder uses compact separators + assert msg.data == b'{"test":"value"}' + await ws.close() + return ws + + app = web.Application() + app.router.add_route("GET", "/", handler) + client = await aiohttp_client(app) + resp = await client.ws_connect("/") + await resp.send_json_bytes( + {"test": "value"}, + dumps=lambda x: json.dumps(x, separators=(",", ":")).encode("utf-8"), + ) + await resp.close() + + async def test_send_recv_frame(aiohttp_client: AiohttpClient) -> None: async def handler(request: web.Request) -> web.WebSocketResponse: ws = web.WebSocketResponse() diff --git a/tests/test_payload.py b/tests/test_payload.py index d5c2a9a0246..dd25ccfc459 100644 --- a/tests/test_payload.py +++ b/tests/test_payload.py @@ -1221,6 +1221,40 @@ def test_json_payload_size() -> None: assert jp_custom.size == len(expected_custom.encode("utf-16")) +def test_json_bytes_payload() -> None: + """Test JsonBytesPayload with a bytes-returning encoder.""" + data = {"hello": "world"} + + # Test with standard library encoder + jp = payload.JsonBytesPayload(data, dumps=lambda x: json.dumps(x).encode("utf-8")) + expected = json.dumps(data).encode("utf-8") + assert jp.size == len(expected) + + # Test with custom bytes-returning encoder (compact separators) + jp_custom = payload.JsonBytesPayload( + data, dumps=lambda x: json.dumps(x, separators=(",", ":")).encode("utf-8") + ) + expected_custom = json.dumps(data, separators=(",", ":")).encode("utf-8") + assert jp_custom.size == len(expected_custom) + + +def test_json_bytes_payload_content_type() -> None: + """Test JsonBytesPayload content_type.""" + data = {"test": "data"} + + # Default content type + jp = payload.JsonBytesPayload(data, dumps=lambda x: json.dumps(x).encode("utf-8")) + assert jp.content_type == "application/json" + + # Custom content type + jp_custom = payload.JsonBytesPayload( + data, + dumps=lambda x: json.dumps(x).encode("utf-8"), + content_type="application/vnd.api+json", + ) + assert jp_custom.content_type == "application/vnd.api+json" + + async def test_text_io_payload_size_matches_file_encoding(tmp_path: Path) -> None: """Test TextIOPayload.size when file encoding matches payload encoding.""" # Create UTF-8 file diff --git a/tests/test_web_log.py b/tests/test_web_log.py index c78f360f05e..c77c60f0ae4 100644 --- a/tests/test_web_log.py +++ b/tests/test_web_log.py @@ -94,10 +94,28 @@ def test_access_logger_atoms( class PatchedDatetime(datetime.datetime): @classmethod def now(cls, tz: datetime.tzinfo | None = None) -> Self: - return cls(1843, 1, 1, 0, 30, tzinfo=tz) + assert tz is not None + # Simulate: real UTC time is 1842-12-31 16:30, convert to local tz + utc = datetime.datetime(1842, 12, 31, 16, 30, tzinfo=datetime.timezone.utc) + local = utc.astimezone(tz) + return cls( + local.year, + local.month, + local.day, + local.hour, + local.minute, + local.second, + tzinfo=tz, + ) monkeypatch.setattr("datetime.datetime", PatchedDatetime) - monkeypatch.setattr("time.timezone", -28800) + # Mock localtime to return CST (+0800 = 28800 seconds) + mock_localtime = mock.Mock() + mock_localtime.return_value.tm_gmtoff = 28800 + monkeypatch.setattr("aiohttp.web_log.time_mod.localtime", mock_localtime) + # Clear cached timezone so it gets rebuilt with our mock + AccessLogger._cached_tz = None + AccessLogger._cached_tz_expires = 0.0 monkeypatch.setattr("os.getpid", lambda: 42) mock_logger = mock.Mock() access_logger = AccessLogger(mock_logger, log_format) @@ -115,6 +133,84 @@ def now(cls, tz: datetime.tzinfo | None = None) -> Self: mock_logger.info.assert_called_with(expected, extra=extra) +@pytest.mark.skipif( + IS_PYPY, + reason="PyPy has issues with patching datetime.datetime", +) +def test_access_logger_dst_timezone(monkeypatch: pytest.MonkeyPatch) -> None: + """Test that _format_t uses the current local UTC offset, not a cached one. + + This ensures timestamps are correct during DST transitions. The old + implementation used time.timezone which is a constant and doesn't + reflect DST changes. + """ + # Simulate a timezone that observes DST (e.g., US Eastern) + # During EST: UTC-5 (-18000s), during EDT: UTC-4 (-14400s) + gmtoff_est = -18000 # UTC-5 + gmtoff_edt = -14400 # UTC-4 + + class PatchedDatetime(datetime.datetime): + @classmethod + def now(cls, tz: datetime.tzinfo | None = None) -> Self: + assert tz is not None + # Simulate: real UTC time is 07:00, convert to local tz + utc = datetime.datetime(2024, 3, 10, 7, 0, 0, tzinfo=datetime.timezone.utc) + local = utc.astimezone(tz) + return cls( + local.year, + local.month, + local.day, + local.hour, + local.minute, + local.second, + tzinfo=tz, + ) + + monkeypatch.setattr("datetime.datetime", PatchedDatetime) + mock_localtime = mock.Mock() + mock_localtime.return_value.tm_gmtoff = gmtoff_est + monkeypatch.setattr("aiohttp.web_log.time_mod.localtime", mock_localtime) + # Force cache refresh + AccessLogger._cached_tz = None + AccessLogger._cached_tz_expires = 0.0 + + mock_logger = mock.Mock() + access_logger = AccessLogger(mock_logger, "%t") + request = mock.Mock( + headers={}, method="GET", path_qs="/", version=(1, 1), remote="127.0.0.1" + ) + response = mock.Mock(headers={}, body_length=0, status=200) + + # During EST (UTC-5): time is 07:00-05:00 = 02:00 EST + access_logger.log(request, response, 0.0) + call1 = mock_logger.info.call_args[0][0] + assert "-0500" in call1, f"Expected EST offset in {call1}" + + mock_logger.reset_mock() + + # Switch to EDT (UTC-4): force cache invalidation + mock_localtime.return_value.tm_gmtoff = gmtoff_edt + AccessLogger._cached_tz = None + AccessLogger._cached_tz_expires = 0.0 + access_logger.log(request, response, 0.0) + call2 = mock_logger.info.call_args[0][0] + assert "-0400" in call2, f"Expected EDT offset in {call2}" + + # Verify the hour changed too (02:00 -> 03:00) + assert "02:00:00 -0500" in call1 + assert "03:00:00 -0400" in call2 + + # Verify cached tz works too + assert access_logger._cached_tz is not None + with mock.patch( + "aiohttp.web_log.time_mod.time", + return_value=access_logger._cached_tz_expires - 1, + ): + access_logger.log(request, response, 0.0) + call3 = mock_logger.info.call_args[0][0] + assert "-0400" in call3, f"Expected EDT offset in {call3}" + + def test_access_logger_dicts() -> None: log_format = "%{User-Agent}i %{Content-Length}o %{None}i" mock_logger = mock.Mock() diff --git a/tests/test_web_response.py b/tests/test_web_response.py index bc729506e0d..becbfedc965 100644 --- a/tests/test_web_response.py +++ b/tests/test_web_response.py @@ -1462,6 +1462,50 @@ def test_content_type_is_overrideable(self) -> None: assert "application/vnd.json+api" == resp.content_type +class TestJSONBytesResponse: + def test_content_type_is_application_json_by_default(self) -> None: + resp = web.json_bytes_response( + "", dumps=lambda x: json.dumps(x).encode("utf-8") + ) + assert "application/json" == resp.content_type + + def test_passing_body_only(self) -> None: + resp = web.json_bytes_response( + dumps=lambda x: json.dumps(x).encode("utf-8"), + body=b'"jaysawn"', + ) + assert resp.body == b'"jaysawn"' + + def test_data_and_body_raises_value_error(self) -> None: + with pytest.raises(ValueError) as excinfo: + web.json_bytes_response( + data="foo", dumps=lambda x: json.dumps(x).encode("utf-8"), body=b"bar" + ) + expected_message = "only one of data or body should be specified" + assert expected_message == excinfo.value.args[0] + + def test_body_is_json_encoded_bytes(self) -> None: + resp = web.json_bytes_response( + {"foo": 42}, dumps=lambda x: json.dumps(x).encode("utf-8") + ) + assert json.dumps({"foo": 42}).encode("utf-8") == resp.body + + def test_content_type_is_overrideable(self) -> None: + resp = web.json_bytes_response( + {"foo": 42}, + dumps=lambda x: json.dumps(x).encode("utf-8"), + content_type="application/vnd.json+api", + ) + assert "application/vnd.json+api" == resp.content_type + + def test_custom_dumps(self) -> None: + resp = web.json_bytes_response( + {"foo": 42}, + dumps=lambda x: json.dumps(x, separators=(",", ":")).encode("utf-8"), + ) + assert b'{"foo":42}' == resp.body + + @pytest.mark.dev_mode async def test_no_warn_small_cookie( buf: bytearray, writer: AbstractStreamWriter diff --git a/tests/test_web_websocket.py b/tests/test_web_websocket.py index d3ec524b345..12edb532e28 100644 --- a/tests/test_web_websocket.py +++ b/tests/test_web_websocket.py @@ -1,4 +1,5 @@ import asyncio +import json import time from typing import Protocol from unittest import mock @@ -190,6 +191,26 @@ async def test_send_json_nonjson(make_request: _RequestMaker) -> None: await ws.send_json(set()) +async def test_nonstarted_send_json_bytes() -> None: + ws = web.WebSocketResponse() + with pytest.raises(RuntimeError): + await ws.send_json_bytes( + {"type": "json"}, dumps=lambda x: json.dumps(x).encode("utf-8") + ) + + +async def test_send_json_bytes_nonjson(make_request: _RequestMaker) -> None: + req = make_request("GET", "/") + ws = web.WebSocketResponse() + await ws.prepare(req) + with pytest.raises(TypeError): + await ws.send_json_bytes(set(), dumps=lambda x: json.dumps(x).encode("utf-8")) + + assert ws._reader is not None + ws._reader.feed_data(WS_CLOSED_MESSAGE) + await ws.close() + + async def test_write_non_prepared() -> None: ws = web.WebSocketResponse() with pytest.raises(RuntimeError): @@ -383,6 +404,20 @@ async def test_send_json_closed(make_request: _RequestMaker) -> None: await ws.send_json({"type": "json"}) +async def test_send_json_bytes_closed(make_request: _RequestMaker) -> None: + req = make_request("GET", "/") + ws = web.WebSocketResponse() + await ws.prepare(req) + assert ws._reader is not None + ws._reader.feed_data(WS_CLOSED_MESSAGE) + await ws.close() + + with pytest.raises(ConnectionError): + await ws.send_json_bytes( + {"type": "json"}, dumps=lambda x: json.dumps(x).encode("utf-8") + ) + + async def test_send_frame_closed(make_request: _RequestMaker) -> None: req = make_request("GET", "/") ws = web.WebSocketResponse()