From fa12463a674d67e1205f78ac0f7dac88cb373ef8 Mon Sep 17 00:00:00 2001 From: Tom Christie Date: Tue, 16 Sep 2025 16:18:20 +0100 Subject: [PATCH] Configurable ssl_ctx on NetworkBackend --- src/ahttpx/_network.py | 10 +++++++--- src/httpx/_network.py | 12 +++++++----- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/src/ahttpx/_network.py b/src/ahttpx/_network.py index d895fb9..957e036 100644 --- a/src/ahttpx/_network.py +++ b/src/ahttpx/_network.py @@ -81,8 +81,12 @@ async def __aexit__( class NetworkBackend: - def __init__(self): - self._ssl_context = ssl.create_default_context(cafile=certifi.where()) + def __init__(self, ssl_ctx: ssl.SSLContext | None = None): + self._ssl_ctx = self.create_default_context() if ssl_ctx is None else ssl_ctx + + def create_default_context(self) -> ssl.SSLContext: + import certifi + return ssl.create_default_context(cafile=certifi.where()) async def connect(self, host: str, port: int) -> NetworkStream: """ @@ -98,7 +102,7 @@ async def connect_tls(self, host: str, port: int, hostname: str = '') -> Network """ address = f"{host}:{port}" reader, writer = await asyncio.open_connection(host, port) - await writer.start_tls(self._ssl_context, server_hostname=hostname) + await writer.start_tls(self._ssl_ctx, server_hostname=hostname) return NetworkStream(reader, writer, address=address) async def serve(self, host: str, port: int, handler: typing.Callable[[NetworkStream], None]) -> NetworkServer: diff --git a/src/httpx/_network.py b/src/httpx/_network.py index 551d824..5ea9bb5 100644 --- a/src/httpx/_network.py +++ b/src/httpx/_network.py @@ -9,8 +9,6 @@ import types import typing -import certifi - from ._streams import Stream @@ -193,8 +191,12 @@ def _handler(self, stream): class NetworkBackend: - def __init__(self): - self._ssl_context = ssl.create_default_context(cafile=certifi.where()) + def __init__(self, ssl_ctx: ssl.SSLContext | None = None): + self._ssl_ctx = self.create_default_context() if ssl_ctx is None else ssl_ctx + + def create_default_context(self) -> ssl.SSLContext: + import certifi + return ssl.create_default_context(cafile=certifi.where()) def connect(self, host: str, port: int) -> NetworkStream: """ @@ -213,7 +215,7 @@ def connect_tls(self, host: str, port: int, hostname: str = '') -> NetworkStream hostname = hostname or host timeout = get_current_timeout() sock = socket.create_connection(address, timeout=timeout) - sock = self._ssl_context.wrap_socket(sock, server_hostname=hostname) + sock = self._ssl_ctx.wrap_socket(sock, server_hostname=hostname) return NetworkStream(sock, address) def listen(self, host: str, port: int) -> NetworkListener: