Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions src/ahttpx/_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,12 @@ async def __aexit__(


class NetworkBackend:
def __init__(self):
self._ssl_context = ssl.create_default_context(cafile=certifi.where())
def __init__(self, ssl_ctx: ssl.SSLContext | None = None):
self._ssl_ctx = self.create_default_context() if ssl_ctx is None else ssl_ctx

def create_default_context(self) -> ssl.SSLContext:
import certifi
return ssl.create_default_context(cafile=certifi.where())

async def connect(self, host: str, port: int) -> NetworkStream:
"""
Expand All @@ -98,7 +102,7 @@ async def connect_tls(self, host: str, port: int, hostname: str = '') -> Network
"""
address = f"{host}:{port}"
reader, writer = await asyncio.open_connection(host, port)
await writer.start_tls(self._ssl_context, server_hostname=hostname)
await writer.start_tls(self._ssl_ctx, server_hostname=hostname)
return NetworkStream(reader, writer, address=address)

async def serve(self, host: str, port: int, handler: typing.Callable[[NetworkStream], None]) -> NetworkServer:
Expand Down
12 changes: 7 additions & 5 deletions src/httpx/_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,6 @@
import types
import typing

import certifi

from ._streams import Stream


Expand Down Expand Up @@ -193,8 +191,12 @@ def _handler(self, stream):


class NetworkBackend:
def __init__(self):
self._ssl_context = ssl.create_default_context(cafile=certifi.where())
def __init__(self, ssl_ctx: ssl.SSLContext | None = None):
self._ssl_ctx = self.create_default_context() if ssl_ctx is None else ssl_ctx

def create_default_context(self) -> ssl.SSLContext:
import certifi
return ssl.create_default_context(cafile=certifi.where())

def connect(self, host: str, port: int) -> NetworkStream:
"""
Expand All @@ -213,7 +215,7 @@ def connect_tls(self, host: str, port: int, hostname: str = '') -> NetworkStream
hostname = hostname or host
timeout = get_current_timeout()
sock = socket.create_connection(address, timeout=timeout)
sock = self._ssl_context.wrap_socket(sock, server_hostname=hostname)
sock = self._ssl_ctx.wrap_socket(sock, server_hostname=hostname)
return NetworkStream(sock, address)

def listen(self, host: str, port: int) -> NetworkListener:
Expand Down