diff --git a/docs/api.md b/docs/api.md index 6ef4ac4b..4781381f 100644 --- a/docs/api.md +++ b/docs/api.md @@ -120,7 +120,7 @@ what gets sent over the wire.* >>> response = client.send(request) ``` -* `def __init__(method, url, [params], [headers], [cookies], [content], [data], [files], [json], [stream])` +* `def __init__(method, url, [params], [headers], [cookies], [content], [data], [files], [json], [json_serializer], [json_deserializer], [stream])` * `.method` - **str** * `.url` - **URL** * `.content` - **byte**, **byte iterator**, or **byte async iterator** diff --git a/docs/quickstart.md b/docs/quickstart.md index 7cb3929c..518e4249 100644 --- a/docs/quickstart.md +++ b/docs/quickstart.md @@ -249,6 +249,17 @@ For more complicated data structures you'll often want to use JSON encoding inst } ``` +The JSON serializer and deserializer can be customized per request or per client. +The serializer must accept the object passed as `json=...` and return either +`str` or `bytes`. This allows serializers such as `orjson.dumps` to be used +directly. + +```pycon +>>> import orjson +>>> with httpx.Client(json_serializer=orjson.dumps, json_deserializer=orjson.loads) as client: +... r = client.post("https://httpbin.org/post", json=data) +``` + ## Sending Binary Request Data For other encodings, you should use the `content=...` parameter, passing diff --git a/src/httpx2/httpx2/_api.py b/src/httpx2/httpx2/_api.py index 95810605..09b823e7 100644 --- a/src/httpx2/httpx2/_api.py +++ b/src/httpx2/httpx2/_api.py @@ -10,6 +10,8 @@ AuthTypes, CookieTypes, HeaderTypes, + JsonDeserializer, + JsonSerializer, ProxyTypes, QueryParamTypes, RequestContent, @@ -45,6 +47,8 @@ def request( data: RequestData | None = None, files: RequestFiles | None = None, json: typing.Any | None = None, + json_serializer: JsonSerializer | None = None, + json_deserializer: JsonDeserializer | None = None, headers: HeaderTypes | None = None, cookies: CookieTypes | None = None, auth: AuthTypes | None = None, @@ -105,6 +109,7 @@ def request( verify=verify, timeout=timeout, trust_env=trust_env, + json_deserializer=json_deserializer, ) as client: return client.request( method=method, @@ -113,6 +118,8 @@ def request( data=data, files=files, json=json, + json_serializer=json_serializer, + json_deserializer=json_deserializer, params=params, headers=headers, auth=auth, @@ -130,6 +137,8 @@ def stream( data: RequestData | None = None, files: RequestFiles | None = None, json: typing.Any | None = None, + json_serializer: JsonSerializer | None = None, + json_deserializer: JsonDeserializer | None = None, headers: HeaderTypes | None = None, cookies: CookieTypes | None = None, auth: AuthTypes | None = None, @@ -155,6 +164,7 @@ def stream( verify=verify, timeout=timeout, trust_env=trust_env, + json_deserializer=json_deserializer, ) as client: with client.stream( method=method, @@ -163,6 +173,8 @@ def stream( data=data, files=files, json=json, + json_serializer=json_serializer, + json_deserializer=json_deserializer, params=params, headers=headers, auth=auth, @@ -182,6 +194,7 @@ def get( follow_redirects: bool = False, verify: ssl.SSLContext | str | bool = True, timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + json_deserializer: JsonDeserializer | None = None, trust_env: bool = True, ) -> Response: """ @@ -203,6 +216,7 @@ def get( follow_redirects=follow_redirects, verify=verify, timeout=timeout, + json_deserializer=json_deserializer, trust_env=trust_env, ) @@ -218,6 +232,7 @@ def options( follow_redirects: bool = False, verify: ssl.SSLContext | str | bool = True, timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + json_deserializer: JsonDeserializer | None = None, trust_env: bool = True, ) -> Response: """ @@ -239,6 +254,7 @@ def options( follow_redirects=follow_redirects, verify=verify, timeout=timeout, + json_deserializer=json_deserializer, trust_env=trust_env, ) @@ -254,6 +270,7 @@ def head( follow_redirects: bool = False, verify: ssl.SSLContext | str | bool = True, timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + json_deserializer: JsonDeserializer | None = None, trust_env: bool = True, ) -> Response: """ @@ -275,6 +292,7 @@ def head( follow_redirects=follow_redirects, verify=verify, timeout=timeout, + json_deserializer=json_deserializer, trust_env=trust_env, ) @@ -286,6 +304,7 @@ def post( data: RequestData | None = None, files: RequestFiles | None = None, json: typing.Any | None = None, + json_serializer: JsonSerializer | None = None, params: QueryParamTypes | None = None, headers: HeaderTypes | None = None, cookies: CookieTypes | None = None, @@ -294,6 +313,7 @@ def post( follow_redirects: bool = False, verify: ssl.SSLContext | str | bool = True, timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + json_deserializer: JsonDeserializer | None = None, trust_env: bool = True, ) -> Response: """ @@ -308,6 +328,7 @@ def post( data=data, files=files, json=json, + json_serializer=json_serializer, params=params, headers=headers, cookies=cookies, @@ -316,6 +337,7 @@ def post( follow_redirects=follow_redirects, verify=verify, timeout=timeout, + json_deserializer=json_deserializer, trust_env=trust_env, ) @@ -327,6 +349,7 @@ def put( data: RequestData | None = None, files: RequestFiles | None = None, json: typing.Any | None = None, + json_serializer: JsonSerializer | None = None, params: QueryParamTypes | None = None, headers: HeaderTypes | None = None, cookies: CookieTypes | None = None, @@ -335,6 +358,7 @@ def put( follow_redirects: bool = False, verify: ssl.SSLContext | str | bool = True, timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + json_deserializer: JsonDeserializer | None = None, trust_env: bool = True, ) -> Response: """ @@ -349,6 +373,7 @@ def put( data=data, files=files, json=json, + json_serializer=json_serializer, params=params, headers=headers, cookies=cookies, @@ -357,6 +382,7 @@ def put( follow_redirects=follow_redirects, verify=verify, timeout=timeout, + json_deserializer=json_deserializer, trust_env=trust_env, ) @@ -368,6 +394,7 @@ def patch( data: RequestData | None = None, files: RequestFiles | None = None, json: typing.Any | None = None, + json_serializer: JsonSerializer | None = None, params: QueryParamTypes | None = None, headers: HeaderTypes | None = None, cookies: CookieTypes | None = None, @@ -376,6 +403,7 @@ def patch( follow_redirects: bool = False, verify: ssl.SSLContext | str | bool = True, timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + json_deserializer: JsonDeserializer | None = None, trust_env: bool = True, ) -> Response: """ @@ -390,6 +418,7 @@ def patch( data=data, files=files, json=json, + json_serializer=json_serializer, params=params, headers=headers, cookies=cookies, @@ -398,6 +427,7 @@ def patch( follow_redirects=follow_redirects, verify=verify, timeout=timeout, + json_deserializer=json_deserializer, trust_env=trust_env, ) @@ -412,6 +442,7 @@ def delete( proxy: ProxyTypes | None = None, follow_redirects: bool = False, timeout: TimeoutTypes = DEFAULT_TIMEOUT_CONFIG, + json_deserializer: JsonDeserializer | None = None, verify: ssl.SSLContext | str | bool = True, trust_env: bool = True, ) -> Response: @@ -434,5 +465,6 @@ def delete( follow_redirects=follow_redirects, verify=verify, timeout=timeout, + json_deserializer=json_deserializer, trust_env=trust_env, ) diff --git a/src/httpx2/httpx2/_client.py b/src/httpx2/httpx2/_client.py index 3c950679..5b0e28c7 100644 --- a/src/httpx2/httpx2/_client.py +++ b/src/httpx2/httpx2/_client.py @@ -36,6 +36,8 @@ CertTypes, CookieTypes, HeaderTypes, + JsonDeserializer, + JsonSerializer, ProxyTypes, QueryParamTypes, RequestContent, @@ -190,6 +192,8 @@ def __init__( base_url: URL | str = "", trust_env: bool = True, default_encoding: str | typing.Callable[[bytes], str] = "utf-8", + json_serializer: JsonSerializer | None = None, + json_deserializer: JsonDeserializer | None = None, ) -> None: event_hooks = {} if event_hooks is None else event_hooks @@ -208,6 +212,8 @@ def __init__( } self._trust_env = trust_env self._default_encoding = default_encoding + self._json_serializer = json_serializer + self._json_deserializer = json_deserializer self._state = ClientState.UNOPENED @property @@ -331,6 +337,8 @@ def build_request( data: RequestData | None = None, files: RequestFiles | None = None, json: typing.Any | None = None, + json_serializer: JsonSerializer | UseClientDefault | None = USE_CLIENT_DEFAULT, + json_deserializer: JsonDeserializer | UseClientDefault | None = USE_CLIENT_DEFAULT, params: QueryParamTypes | None = None, headers: HeaderTypes | None = None, cookies: CookieTypes | None = None, @@ -353,6 +361,8 @@ def build_request( cookies = self._merge_cookies(cookies) params = self._merge_queryparams(params) extensions = {} if extensions is None else extensions + if not isinstance(json_deserializer, UseClientDefault): + extensions = dict(**extensions, json_deserializer=json_deserializer) if "timeout" not in extensions: timeout = self.timeout if isinstance(timeout, UseClientDefault) else Timeout(timeout) extensions = dict(**extensions, timeout=timeout.as_dict()) @@ -363,6 +373,7 @@ def build_request( data=data, files=files, json=json, + json_serializer=self._json_serializer if isinstance(json_serializer, UseClientDefault) else json_serializer, params=params, headers=headers, cookies=cookies, @@ -629,6 +640,8 @@ def __init__( base_url: URL | str = "", transport: BaseTransport | None = None, default_encoding: str | typing.Callable[[bytes], str] = "utf-8", + json_serializer: JsonSerializer | None = None, + json_deserializer: JsonDeserializer | None = None, ) -> None: super().__init__( auth=auth, @@ -642,6 +655,8 @@ def __init__( base_url=base_url, trust_env=trust_env, default_encoding=default_encoding, + json_serializer=json_serializer, + json_deserializer=json_deserializer, ) if http2: @@ -746,6 +761,8 @@ def request( data: RequestData | None = None, files: RequestFiles | None = None, json: typing.Any | None = None, + json_serializer: JsonSerializer | UseClientDefault | None = USE_CLIENT_DEFAULT, + json_deserializer: JsonDeserializer | UseClientDefault | None = USE_CLIENT_DEFAULT, params: QueryParamTypes | None = None, headers: HeaderTypes | None = None, cookies: CookieTypes | None = None, @@ -785,6 +802,8 @@ def request( data=data, files=files, json=json, + json_serializer=json_serializer, + json_deserializer=json_deserializer, params=params, headers=headers, cookies=cookies, @@ -803,6 +822,8 @@ def stream( data: RequestData | None = None, files: RequestFiles | None = None, json: typing.Any | None = None, + json_serializer: JsonSerializer | UseClientDefault | None = USE_CLIENT_DEFAULT, + json_deserializer: JsonDeserializer | UseClientDefault | None = USE_CLIENT_DEFAULT, params: QueryParamTypes | None = None, headers: HeaderTypes | None = None, cookies: CookieTypes | None = None, @@ -828,6 +849,8 @@ def stream( data=data, files=files, json=json, + json_serializer=json_serializer, + json_deserializer=json_deserializer, params=params, headers=headers, cookies=cookies, @@ -980,6 +1003,7 @@ def _send_single_request(self, request: Request) -> Response: response.stream = BoundSyncStream(response.stream, response=response, start=start) self.cookies.extract_cookies(response) response.default_encoding = self._default_encoding + response._json_deserializer = request.extensions.get("json_deserializer", self._json_deserializer) logger.info( 'HTTP Request: %s %s "%s %d %s"', @@ -1002,6 +1026,7 @@ def get( auth: AuthTypes | UseClientDefault | None = USE_CLIENT_DEFAULT, follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + json_deserializer: JsonDeserializer | UseClientDefault | None = USE_CLIENT_DEFAULT, extensions: RequestExtensions | None = None, ) -> Response: """ @@ -1018,6 +1043,7 @@ def get( auth=auth, follow_redirects=follow_redirects, timeout=timeout, + json_deserializer=json_deserializer, extensions=extensions, ) @@ -1031,6 +1057,7 @@ def options( auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + json_deserializer: JsonDeserializer | UseClientDefault | None = USE_CLIENT_DEFAULT, extensions: RequestExtensions | None = None, ) -> Response: """ @@ -1047,6 +1074,7 @@ def options( auth=auth, follow_redirects=follow_redirects, timeout=timeout, + json_deserializer=json_deserializer, extensions=extensions, ) @@ -1060,6 +1088,7 @@ def head( auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + json_deserializer: JsonDeserializer | UseClientDefault | None = USE_CLIENT_DEFAULT, extensions: RequestExtensions | None = None, ) -> Response: """ @@ -1076,6 +1105,7 @@ def head( auth=auth, follow_redirects=follow_redirects, timeout=timeout, + json_deserializer=json_deserializer, extensions=extensions, ) @@ -1087,6 +1117,8 @@ def post( data: RequestData | None = None, files: RequestFiles | None = None, json: typing.Any | None = None, + json_serializer: JsonSerializer | UseClientDefault | None = USE_CLIENT_DEFAULT, + json_deserializer: JsonDeserializer | UseClientDefault | None = USE_CLIENT_DEFAULT, params: QueryParamTypes | None = None, headers: HeaderTypes | None = None, cookies: CookieTypes | None = None, @@ -1107,6 +1139,8 @@ def post( data=data, files=files, json=json, + json_serializer=json_serializer, + json_deserializer=json_deserializer, params=params, headers=headers, cookies=cookies, @@ -1124,6 +1158,8 @@ def put( data: RequestData | None = None, files: RequestFiles | None = None, json: typing.Any | None = None, + json_serializer: JsonSerializer | UseClientDefault | None = USE_CLIENT_DEFAULT, + json_deserializer: JsonDeserializer | UseClientDefault | None = USE_CLIENT_DEFAULT, params: QueryParamTypes | None = None, headers: HeaderTypes | None = None, cookies: CookieTypes | None = None, @@ -1144,6 +1180,8 @@ def put( data=data, files=files, json=json, + json_serializer=json_serializer, + json_deserializer=json_deserializer, params=params, headers=headers, cookies=cookies, @@ -1161,6 +1199,8 @@ def patch( data: RequestData | None = None, files: RequestFiles | None = None, json: typing.Any | None = None, + json_serializer: JsonSerializer | UseClientDefault | None = USE_CLIENT_DEFAULT, + json_deserializer: JsonDeserializer | UseClientDefault | None = USE_CLIENT_DEFAULT, params: QueryParamTypes | None = None, headers: HeaderTypes | None = None, cookies: CookieTypes | None = None, @@ -1181,6 +1221,8 @@ def patch( data=data, files=files, json=json, + json_serializer=json_serializer, + json_deserializer=json_deserializer, params=params, headers=headers, cookies=cookies, @@ -1200,6 +1242,7 @@ def delete( auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + json_deserializer: JsonDeserializer | UseClientDefault | None = USE_CLIENT_DEFAULT, extensions: RequestExtensions | None = None, ) -> Response: """ @@ -1216,6 +1259,7 @@ def delete( auth=auth, follow_redirects=follow_redirects, timeout=timeout, + json_deserializer=json_deserializer, extensions=extensions, ) @@ -1331,6 +1375,8 @@ def __init__( transport: AsyncBaseTransport | None = None, trust_env: bool = True, default_encoding: str | typing.Callable[[bytes], str] = "utf-8", + json_serializer: JsonSerializer | None = None, + json_deserializer: JsonDeserializer | None = None, ) -> None: super().__init__( auth=auth, @@ -1344,6 +1390,8 @@ def __init__( base_url=base_url, trust_env=trust_env, default_encoding=default_encoding, + json_serializer=json_serializer, + json_deserializer=json_deserializer, ) if http2: @@ -1448,6 +1496,8 @@ async def request( data: RequestData | None = None, files: RequestFiles | None = None, json: typing.Any | None = None, + json_serializer: JsonSerializer | UseClientDefault | None = USE_CLIENT_DEFAULT, + json_deserializer: JsonDeserializer | UseClientDefault | None = USE_CLIENT_DEFAULT, params: QueryParamTypes | None = None, headers: HeaderTypes | None = None, cookies: CookieTypes | None = None, @@ -1488,6 +1538,8 @@ async def request( data=data, files=files, json=json, + json_serializer=json_serializer, + json_deserializer=json_deserializer, params=params, headers=headers, cookies=cookies, @@ -1506,6 +1558,8 @@ async def stream( data: RequestData | None = None, files: RequestFiles | None = None, json: typing.Any | None = None, + json_serializer: JsonSerializer | UseClientDefault | None = USE_CLIENT_DEFAULT, + json_deserializer: JsonDeserializer | UseClientDefault | None = USE_CLIENT_DEFAULT, params: QueryParamTypes | None = None, headers: HeaderTypes | None = None, cookies: CookieTypes | None = None, @@ -1531,6 +1585,8 @@ async def stream( data=data, files=files, json=json, + json_serializer=json_serializer, + json_deserializer=json_deserializer, params=params, headers=headers, cookies=cookies, @@ -1683,6 +1739,7 @@ async def _send_single_request(self, request: Request) -> Response: response.stream = BoundAsyncStream(response.stream, response=response, start=start) self.cookies.extract_cookies(response) response.default_encoding = self._default_encoding + response._json_deserializer = request.extensions.get("json_deserializer", self._json_deserializer) logger.info( 'HTTP Request: %s %s "%s %d %s"', @@ -1705,6 +1762,7 @@ async def get( auth: AuthTypes | UseClientDefault | None = USE_CLIENT_DEFAULT, follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + json_deserializer: JsonDeserializer | UseClientDefault | None = USE_CLIENT_DEFAULT, extensions: RequestExtensions | None = None, ) -> Response: """ @@ -1721,6 +1779,7 @@ async def get( auth=auth, follow_redirects=follow_redirects, timeout=timeout, + json_deserializer=json_deserializer, extensions=extensions, ) @@ -1734,6 +1793,7 @@ async def options( auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + json_deserializer: JsonDeserializer | UseClientDefault | None = USE_CLIENT_DEFAULT, extensions: RequestExtensions | None = None, ) -> Response: """ @@ -1750,6 +1810,7 @@ async def options( auth=auth, follow_redirects=follow_redirects, timeout=timeout, + json_deserializer=json_deserializer, extensions=extensions, ) @@ -1763,6 +1824,7 @@ async def head( auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + json_deserializer: JsonDeserializer | UseClientDefault | None = USE_CLIENT_DEFAULT, extensions: RequestExtensions | None = None, ) -> Response: """ @@ -1779,6 +1841,7 @@ async def head( auth=auth, follow_redirects=follow_redirects, timeout=timeout, + json_deserializer=json_deserializer, extensions=extensions, ) @@ -1790,6 +1853,8 @@ async def post( data: RequestData | None = None, files: RequestFiles | None = None, json: typing.Any | None = None, + json_serializer: JsonSerializer | UseClientDefault | None = USE_CLIENT_DEFAULT, + json_deserializer: JsonDeserializer | UseClientDefault | None = USE_CLIENT_DEFAULT, params: QueryParamTypes | None = None, headers: HeaderTypes | None = None, cookies: CookieTypes | None = None, @@ -1810,6 +1875,8 @@ async def post( data=data, files=files, json=json, + json_serializer=json_serializer, + json_deserializer=json_deserializer, params=params, headers=headers, cookies=cookies, @@ -1827,6 +1894,8 @@ async def put( data: RequestData | None = None, files: RequestFiles | None = None, json: typing.Any | None = None, + json_serializer: JsonSerializer | UseClientDefault | None = USE_CLIENT_DEFAULT, + json_deserializer: JsonDeserializer | UseClientDefault | None = USE_CLIENT_DEFAULT, params: QueryParamTypes | None = None, headers: HeaderTypes | None = None, cookies: CookieTypes | None = None, @@ -1847,6 +1916,8 @@ async def put( data=data, files=files, json=json, + json_serializer=json_serializer, + json_deserializer=json_deserializer, params=params, headers=headers, cookies=cookies, @@ -1864,6 +1935,8 @@ async def patch( data: RequestData | None = None, files: RequestFiles | None = None, json: typing.Any | None = None, + json_serializer: JsonSerializer | UseClientDefault | None = USE_CLIENT_DEFAULT, + json_deserializer: JsonDeserializer | UseClientDefault | None = USE_CLIENT_DEFAULT, params: QueryParamTypes | None = None, headers: HeaderTypes | None = None, cookies: CookieTypes | None = None, @@ -1884,6 +1957,8 @@ async def patch( data=data, files=files, json=json, + json_serializer=json_serializer, + json_deserializer=json_deserializer, params=params, headers=headers, cookies=cookies, @@ -1903,6 +1978,7 @@ async def delete( auth: AuthTypes | UseClientDefault = USE_CLIENT_DEFAULT, follow_redirects: bool | UseClientDefault = USE_CLIENT_DEFAULT, timeout: TimeoutTypes | UseClientDefault = USE_CLIENT_DEFAULT, + json_deserializer: JsonDeserializer | UseClientDefault | None = USE_CLIENT_DEFAULT, extensions: RequestExtensions | None = None, ) -> Response: """ @@ -1919,6 +1995,7 @@ async def delete( auth=auth, follow_redirects=follow_redirects, timeout=timeout, + json_deserializer=json_deserializer, extensions=extensions, ) diff --git a/src/httpx2/httpx2/_content.py b/src/httpx2/httpx2/_content.py index 71969ecc..9f2e06aa 100644 --- a/src/httpx2/httpx2/_content.py +++ b/src/httpx2/httpx2/_content.py @@ -17,6 +17,7 @@ from ._multipart import MultipartStream from ._types import ( AsyncByteStream, + JsonSerializer, RequestContent, RequestData, RequestFiles, @@ -28,6 +29,10 @@ __all__ = ["ByteStream"] +def default_json_serializer(json: Any) -> str: + return json_dumps(json, ensure_ascii=False, separators=(",", ":"), allow_nan=False) + + class ByteStream(AsyncByteStream, SyncByteStream): def __init__(self, stream: bytes) -> None: self._stream = stream @@ -173,8 +178,13 @@ def encode_html(html: str) -> tuple[dict[str, str], ByteStream]: return headers, ByteStream(body) -def encode_json(json: Any) -> tuple[dict[str, str], ByteStream]: - body = json_dumps(json, ensure_ascii=False, separators=(",", ":"), allow_nan=False).encode("utf-8") +def encode_json(json: Any, json_serializer: JsonSerializer | None = None) -> tuple[dict[str, str], ByteStream]: + dumps = default_json_serializer if json_serializer is None else json_serializer + body = dumps(json) + if isinstance(body, str): + body = body.encode("utf-8") + elif not isinstance(body, bytes): + raise TypeError(f"JSON serializer returned unsupported type {type(body)!r}; expected str or bytes") content_length = str(len(body)) content_type = "application/json" headers = {"Content-Length": content_length, "Content-Type": content_type} @@ -186,6 +196,7 @@ def encode_request( data: RequestData | None = None, files: RequestFiles | None = None, json: Any | None = None, + json_serializer: JsonSerializer | None = None, boundary: bytes | None = None, ) -> tuple[dict[str, str], SyncByteStream | AsyncByteStream]: """ @@ -211,7 +222,7 @@ def encode_request( elif data: return encode_urlencoded_data(data) elif json is not None: - return encode_json(json) + return encode_json(json, json_serializer=json_serializer) return {}, ByteStream(b"") @@ -221,6 +232,7 @@ def encode_response( text: str | None = None, html: str | None = None, json: Any | None = None, + json_serializer: JsonSerializer | None = None, ) -> tuple[dict[str, str], SyncByteStream | AsyncByteStream]: """ Handles encoding the given `content`, returning a two-tuple of @@ -233,6 +245,6 @@ def encode_response( elif html is not None: return encode_html(html) elif json is not None: - return encode_json(json) + return encode_json(json, json_serializer=json_serializer) return {}, ByteStream(b"") diff --git a/src/httpx2/httpx2/_models.py b/src/httpx2/httpx2/_models.py index 02baa3eb..a94993de 100644 --- a/src/httpx2/httpx2/_models.py +++ b/src/httpx2/httpx2/_models.py @@ -36,6 +36,8 @@ AsyncByteStream, CookieTypes, HeaderTypes, + JsonDeserializer, + JsonSerializer, QueryParamTypes, RequestContent, RequestData, @@ -381,6 +383,7 @@ def __init__( data: RequestData | None = None, files: RequestFiles | None = None, json: typing.Any | None = None, + json_serializer: JsonSerializer | None = None, stream: SyncByteStream | AsyncByteStream | None = None, extensions: RequestExtensions | None = None, ) -> None: @@ -399,6 +402,7 @@ def __init__( data=data, files=files, json=json, + json_serializer=json_serializer, boundary=get_multipart_boundary_from_content_type( content_type=content_type.encode(self.headers.encoding) if content_type else None ), @@ -503,6 +507,8 @@ def __init__( text: str | None = None, html: str | None = None, json: typing.Any = None, + json_serializer: JsonSerializer | None = None, + json_deserializer: JsonDeserializer | None = None, stream: SyncByteStream | AsyncByteStream | None = None, request: Request | None = None, extensions: ResponseExtensions | None = None, @@ -513,6 +519,7 @@ def __init__( self.headers = Headers(headers) self._request: Request | None = request + self._json_deserializer = json_deserializer # When follow_redirects=False and a redirect is received, # the client will set `response.next_request`. @@ -527,7 +534,7 @@ def __init__( self.default_encoding = default_encoding if stream is None: - headers, stream = encode_response(content, text, html, json) + headers, stream = encode_response(content, text, html, json, json_serializer=json_serializer) self._prepare(headers) self.stream = stream if isinstance(stream, ByteStream): @@ -802,6 +809,10 @@ def raise_for_status(self) -> Response: raise HTTPStatusError(message, request=request, response=self) def json(self, **kwargs: typing.Any) -> typing.Any: + if self._json_deserializer is not None: + if kwargs: + raise TypeError("Response.json() keyword arguments are only supported with the default JSON decoder") + return self._json_deserializer(self.content) return jsonlib.loads(self.content, **kwargs) @property diff --git a/src/httpx2/httpx2/_types.py b/src/httpx2/httpx2/_types.py index 99a91d3b..ff0e85fb 100644 --- a/src/httpx2/httpx2/_types.py +++ b/src/httpx2/httpx2/_types.py @@ -70,6 +70,8 @@ ResponseExtensions = Mapping[str, Any] RequestData = Mapping[str, Any] +JsonSerializer = Callable[[Any], Union[str, bytes]] +JsonDeserializer = Callable[[Union[bytes, str]], Any] FileContent = Union[IO[bytes], bytes, str] FileTypes = Union[ diff --git a/tests/httpx2/client/test_async_client.py b/tests/httpx2/client/test_async_client.py index ccc73813..18727648 100644 --- a/tests/httpx2/client/test_async_client.py +++ b/tests/httpx2/client/test_async_client.py @@ -67,6 +67,36 @@ async def test_post_json(server): assert response.status_code == 200 +@pytest.mark.anyio +async def test_async_client_json_serializer(): + def echo_body(request: httpx2.Request) -> httpx2.Response: + return httpx2.Response(200, content=request.content) + + def dumps(data: typing.Any) -> bytes: + assert data == {"text": "Hello, world!"} + return b'{"text":"custom!"}' + + async with httpx2.AsyncClient(transport=httpx2.MockTransport(echo_body), json_serializer=dumps) as client: + response = await client.post("https://example.org", json={"text": "Hello, world!"}) + + assert response.content == b'{"text":"custom!"}' + + +@pytest.mark.anyio +async def test_async_client_json_deserializer(): + def json_response(request: httpx2.Request) -> httpx2.Response: + return httpx2.Response(200, content=b'{"text":"Hello, world!"}') + + def loads(content: bytes | str) -> typing.Any: + assert content == b'{"text":"Hello, world!"}' + return {"text": "custom!"} + + async with httpx2.AsyncClient(transport=httpx2.MockTransport(json_response), json_deserializer=loads) as client: + response = await client.get("https://example.org") + + assert response.json() == {"text": "custom!"} + + @pytest.mark.anyio async def test_stream_response(server): async with httpx2.AsyncClient() as client: diff --git a/tests/httpx2/client/test_client.py b/tests/httpx2/client/test_client.py index e39e17d7..2ec11288 100644 --- a/tests/httpx2/client/test_client.py +++ b/tests/httpx2/client/test_client.py @@ -89,6 +89,92 @@ def test_post_json(server): assert response.reason_phrase == "OK" +def test_client_json_serializer(): + def echo_body(request: httpx2.Request) -> httpx2.Response: + return httpx2.Response(200, content=request.content) + + def dumps(data: typing.Any) -> bytes: + assert data == {"text": "Hello, world!"} + return b'{"text":"custom!"}' + + with httpx2.Client(transport=httpx2.MockTransport(echo_body), json_serializer=dumps) as client: + response = client.post("https://example.org", json={"text": "Hello, world!"}) + + assert response.content == b'{"text":"custom!"}' + + +def test_request_json_serializer_overrides_client_default(): + def echo_body(request: httpx2.Request) -> httpx2.Response: + return httpx2.Response(200, content=request.content) + + def client_dumps(data: typing.Any) -> bytes: + return b'{"source":"client"}' + + def request_dumps(data: typing.Any) -> bytes: + return b'{"source":"request"}' + + with httpx2.Client(transport=httpx2.MockTransport(echo_body), json_serializer=client_dumps) as client: + response = client.post("https://example.org", json={"text": "Hello, world!"}, json_serializer=request_dumps) + + assert response.content == b'{"source":"request"}' + + +def test_request_json_serializer_can_use_builtin_default(): + def echo_body(request: httpx2.Request) -> httpx2.Response: + return httpx2.Response(200, content=request.content) + + def client_dumps(data: typing.Any) -> bytes: + return b'{"source":"client"}' + + with httpx2.Client(transport=httpx2.MockTransport(echo_body), json_serializer=client_dumps) as client: + response = client.post("https://example.org", json={"text": "Hello, world!"}, json_serializer=None) + + assert response.content == b'{"text":"Hello, world!"}' + + +def test_client_json_deserializer(): + def json_response(request: httpx2.Request) -> httpx2.Response: + return httpx2.Response(200, content=b'{"text":"Hello, world!"}') + + def loads(content: bytes | str) -> typing.Any: + assert content == b'{"text":"Hello, world!"}' + return {"text": "custom!"} + + with httpx2.Client(transport=httpx2.MockTransport(json_response), json_deserializer=loads) as client: + response = client.get("https://example.org") + + assert response.json() == {"text": "custom!"} + + +def test_request_json_deserializer_overrides_client_default(): + def json_response(request: httpx2.Request) -> httpx2.Response: + return httpx2.Response(200, content=b'{"text":"Hello, world!"}') + + def client_loads(content: bytes | str) -> typing.Any: + return {"source": "client"} + + def request_loads(content: bytes | str) -> typing.Any: + return {"source": "request"} + + with httpx2.Client(transport=httpx2.MockTransport(json_response), json_deserializer=client_loads) as client: + response = client.get("https://example.org", json_deserializer=request_loads) + + assert response.json() == {"source": "request"} + + +def test_request_json_deserializer_can_use_builtin_default(): + def json_response(request: httpx2.Request) -> httpx2.Response: + return httpx2.Response(200, content=b'{"text":"Hello, world!"}') + + def client_loads(content: bytes | str) -> typing.Any: + return {"source": "client"} + + with httpx2.Client(transport=httpx2.MockTransport(json_response), json_deserializer=client_loads) as client: + response = client.get("https://example.org", json_deserializer=None) + + assert response.json() == {"text": "Hello, world!"} + + def test_stream_response(server): with httpx2.Client() as client: with client.stream("GET", server.url) as response: diff --git a/tests/httpx2/test_content.py b/tests/httpx2/test_content.py index 77a3b7f5..1b8483f4 100644 --- a/tests/httpx2/test_content.py +++ b/tests/httpx2/test_content.py @@ -180,6 +180,36 @@ async def test_json_content(): assert async_content == b'{"Hello":"world!"}' +@pytest.mark.anyio +async def test_json_content_with_custom_serializer(): + def dumps(data: typing.Any) -> bytes: + assert data == {"Hello": "world!"} + return b'{"Hello":"custom!"}' + + request = httpx2.Request(method, url, json={"Hello": "world!"}, json_serializer=dumps) + assert isinstance(request.stream, typing.Iterable) + assert isinstance(request.stream, typing.AsyncIterable) + + sync_content = b"".join(list(request.stream)) + async_content = b"".join([part async for part in request.stream]) + + assert request.headers == { + "Host": "www.example.com", + "Content-Length": "19", + "Content-Type": "application/json", + } + assert sync_content == b'{"Hello":"custom!"}' + assert async_content == b'{"Hello":"custom!"}' + + +def test_json_serializer_must_return_str_or_bytes(): + def dumps(data: typing.Any) -> object: + return data + + with pytest.raises(TypeError, match="JSON serializer returned unsupported type"): + httpx2.Request(method, url, json={"Hello": "world!"}, json_serializer=dumps) # type: ignore[arg-type] + + @pytest.mark.anyio async def test_urlencoded_content(): request = httpx2.Request(method, url, data={"Hello": "world!"}) @@ -502,6 +532,35 @@ def test_separators_for_compact_json(): assert response.headers["Content-Type"] == "application/json" +def test_response_json_with_custom_serializer(): + def dumps(data: typing.Any) -> str: + assert data == {"Hello": "world!"} + return '{"Hello":"custom!"}' + + response = httpx2.Response(200, json={"Hello": "world!"}, json_serializer=dumps) + + assert response.content == b'{"Hello":"custom!"}' + assert response.headers["Content-Length"] == "19" + assert response.headers["Content-Type"] == "application/json" + + +def test_response_json_with_custom_deserializer(): + def loads(content: bytes | str) -> typing.Any: + assert content == b'{"Hello":"world!"}' + return {"Hello": "custom!"} + + response = httpx2.Response(200, content=b'{"Hello":"world!"}', json_deserializer=loads) + + assert response.json() == {"Hello": "custom!"} + + +def test_custom_json_deserializer_rejects_loads_kwargs(): + response = httpx2.Response(200, content=b'{"Hello":"world!"}', json_deserializer=lambda content: {}) + + with pytest.raises(TypeError, match="keyword arguments are only supported"): + response.json(parse_float=float) + + def test_allow_nan_false(): data_with_nan = {"nombre": float("nan")} data_with_inf = {"nombre": float("inf")}