Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 37 additions & 30 deletions httpx/_transports/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@

from __future__ import annotations

import contextlib
import typing
from functools import cache
from types import TracebackType

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -68,9 +68,8 @@

__all__ = ["AsyncHTTPTransport", "HTTPTransport"]

HTTPCORE_EXC_MAP: dict[type[Exception], type[httpx.HTTPError]] = {}


@cache
def _load_httpcore_exceptions() -> dict[type[Exception], type[httpx.HTTPError]]:
import httpcore

Expand All @@ -92,40 +91,38 @@ def _load_httpcore_exceptions() -> dict[type[Exception], type[httpx.HTTPError]]:
}


@contextlib.contextmanager
def map_httpcore_exceptions() -> typing.Iterator[None]:
global HTTPCORE_EXC_MAP
if len(HTTPCORE_EXC_MAP) == 0:
HTTPCORE_EXC_MAP = _load_httpcore_exceptions()
try:
yield
except Exception as exc:
mapped_exc = None
@cache
def _get_httpcore_exception_types() -> tuple[type[Exception], ...]:
return tuple(_load_httpcore_exceptions())


for from_exc, to_exc in HTTPCORE_EXC_MAP.items():
if not isinstance(exc, from_exc):
continue
# We want to map to the most specific exception we can find.
# Eg if `exc` is an `httpcore.ReadTimeout`, we want to map to
# `httpx.ReadTimeout`, not just `httpx.TimeoutException`.
if mapped_exc is None or issubclass(to_exc, mapped_exc):
mapped_exc = to_exc
def _map_httpcore_exception(exc: Exception) -> httpx.HTTPError:
mapped_exc = None
for from_exc, to_exc in _load_httpcore_exceptions().items():
if not isinstance(exc, from_exc):
continue
# We want to map to the most specific exception we can find.
# Eg if `exc` is an `httpcore.ReadTimeout`, we want to map to
# `httpx.ReadTimeout`, not just `httpx.TimeoutException`.
if mapped_exc is None or issubclass(to_exc, mapped_exc):
mapped_exc = to_exc

if mapped_exc is None: # pragma: no cover
raise
if mapped_exc is None: # pragma: no cover
raise

message = str(exc)
raise mapped_exc(message) from exc
return mapped_exc(str(exc))


class ResponseStream(SyncByteStream):
def __init__(self, httpcore_stream: typing.Iterable[bytes]) -> None:
self._httpcore_stream = httpcore_stream

def __iter__(self) -> typing.Iterator[bytes]:
with map_httpcore_exceptions():
try:
for part in self._httpcore_stream:
yield part
except _get_httpcore_exception_types() as exc:
raise _map_httpcore_exception(exc) from exc

def close(self) -> None:
if hasattr(self._httpcore_stream, "close"):
Expand Down Expand Up @@ -224,8 +221,10 @@ def __exit__(
exc_value: BaseException | None = None,
traceback: TracebackType | None = None,
) -> None:
with map_httpcore_exceptions():
try:
self._pool.__exit__(exc_type, exc_value, traceback)
except _get_httpcore_exception_types() as exc: # pragma: no cover
raise _map_httpcore_exception(exc) from exc

def handle_request(
self,
Expand All @@ -246,8 +245,10 @@ def handle_request(
content=request.stream,
extensions=request.extensions,
)
with map_httpcore_exceptions():
try:
resp = self._pool.handle_request(req)
except _get_httpcore_exception_types() as exc:
raise _map_httpcore_exception(exc) from exc

assert isinstance(resp.stream, typing.Iterable)

Expand All @@ -267,9 +268,11 @@ def __init__(self, httpcore_stream: typing.AsyncIterable[bytes]) -> None:
self._httpcore_stream = httpcore_stream

async def __aiter__(self) -> typing.AsyncIterator[bytes]:
with map_httpcore_exceptions():
try:
async for part in self._httpcore_stream:
yield part
except _get_httpcore_exception_types() as exc: # pragma: no cover
raise _map_httpcore_exception(exc) from exc

async def aclose(self) -> None:
if hasattr(self._httpcore_stream, "aclose"):
Expand Down Expand Up @@ -368,8 +371,10 @@ async def __aexit__(
exc_value: BaseException | None = None,
traceback: TracebackType | None = None,
) -> None:
with map_httpcore_exceptions():
try:
await self._pool.__aexit__(exc_type, exc_value, traceback)
except _get_httpcore_exception_types() as exc: # pragma: no cover
raise _map_httpcore_exception(exc) from exc

async def handle_async_request(
self,
Expand All @@ -390,8 +395,10 @@ async def handle_async_request(
content=request.stream,
extensions=request.extensions,
)
with map_httpcore_exceptions():
try:
resp = await self._pool.handle_async_request(req)
except _get_httpcore_exception_types() as exc:
raise _map_httpcore_exception(exc) from exc

assert isinstance(resp.stream, typing.AsyncIterable)

Expand Down