Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions src/httpx2/httpx2/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
TimeoutTypes,
)
from ._urls import URL, QueryParams
from ._utils import URLPattern, get_environment_proxies
from ._utils import Pattern, build_url_pattern, get_environment_proxies

if typing.TYPE_CHECKING:
import ssl # pragma: no cover
Expand Down Expand Up @@ -665,8 +665,8 @@ def __init__(
limits=limits,
transport=transport,
)
self._mounts: dict[URLPattern, BaseTransport | None] = {
URLPattern(key): None
self._mounts: dict[Pattern, BaseTransport | None] = {
build_url_pattern(key): None
if proxy is None
else self._init_proxy_transport(
proxy,
Expand All @@ -680,7 +680,7 @@ def __init__(
for key, proxy in proxy_map.items()
}
if mounts is not None:
self._mounts.update({URLPattern(key): transport for key, transport in mounts.items()})
self._mounts.update({build_url_pattern(key): transport for key, transport in mounts.items()})

self._mounts = dict(sorted(self._mounts.items()))

Expand Down Expand Up @@ -1368,8 +1368,8 @@ def __init__(
transport=transport,
)

self._mounts: dict[URLPattern, AsyncBaseTransport | None] = {
URLPattern(key): None
self._mounts: dict[Pattern, AsyncBaseTransport | None] = {
build_url_pattern(key): None
if proxy is None
else self._init_proxy_transport(
proxy,
Expand All @@ -1383,7 +1383,7 @@ def __init__(
for key, proxy in proxy_map.items()
}
if mounts is not None:
self._mounts.update({URLPattern(key): transport for key, transport in mounts.items()})
self._mounts.update({build_url_pattern(key): transport for key, transport in mounts.items()})
self._mounts = dict(sorted(self._mounts.items()))

def _init_transport(
Expand Down
81 changes: 72 additions & 9 deletions src/httpx2/httpx2/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import re
import typing
from abc import abstractmethod
from urllib.request import getproxies

from ._types import PrimitiveData
Expand Down Expand Up @@ -115,24 +116,41 @@ def peek_filelike_length(stream: typing.Any) -> int | None:
return length


class URLPattern:
class Pattern(typing.Protocol):
@abstractmethod
def matches(self, other: URL) -> bool:
"""this method should never be accessed"""

@property
@abstractmethod
def priority(self) -> tuple[int, int, int]:
"""this property should never be accessed"""

def __lt__(self, other: Pattern) -> bool:
"""this method should never be accessed"""

def __eq__(self, other: typing.Any) -> bool:
"""this method should never be accessed"""


class WildcardURLPattern(Pattern):
"""
A utility class currently used for making lookups against proxy keys...

# Wildcard matching...
>>> pattern = URLPattern("all://")
>>> pattern = WildcardURLPattern("all://")
>>> pattern.matches(httpx2.URL("http://example.com"))
True

# Witch scheme matching...
>>> pattern = URLPattern("https://")
>>> pattern = WildcardURLPattern("https://")
>>> pattern.matches(httpx2.URL("https://example.com"))
True
>>> pattern.matches(httpx2.URL("http://example.com"))
False

# With domain matching...
>>> pattern = URLPattern("https://example.com")
>>> pattern = WildcardURLPattern("https://example.com")
>>> pattern.matches(httpx2.URL("https://example.com"))
True
>>> pattern.matches(httpx2.URL("http://example.com"))
Expand All @@ -141,7 +159,7 @@ class URLPattern:
False

# Wildcard scheme, with domain matching...
>>> pattern = URLPattern("all://example.com")
>>> pattern = WildcardURLPattern("all://example.com")
>>> pattern.matches(httpx2.URL("https://example.com"))
True
>>> pattern.matches(httpx2.URL("http://example.com"))
Expand All @@ -150,7 +168,7 @@ class URLPattern:
False

# With port matching...
>>> pattern = URLPattern("https://example.com:1234")
>>> pattern = WildcardURLPattern("https://example.com:1234")
>>> pattern.matches(httpx2.URL("https://example.com:1234"))
True
>>> pattern.matches(httpx2.URL("https://example.com"))
Expand Down Expand Up @@ -199,7 +217,7 @@ def matches(self, other: URL) -> bool:
@property
def priority(self) -> tuple[int, int, int]:
"""
The priority allows URLPattern instances to be sortable, so that
The priority allows WildcardURLPattern instances to be sortable, so that
we can match from most specific to least specific.
"""
# URLs with a port should take priority over URLs without a port.
Expand All @@ -213,11 +231,56 @@ def priority(self) -> tuple[int, int, int]:
def __hash__(self) -> int:
return hash(self.pattern)

def __lt__(self, other: URLPattern) -> bool:
def __lt__(self, other: Pattern) -> bool:
return self.priority < other.priority

def __eq__(self, other: typing.Any) -> bool:
return isinstance(other, WildcardURLPattern) and self.pattern == other.pattern


class IPNetPattern(Pattern):
def __init__(self, ip_net: str) -> None:
try:
addr, range = ip_net.split("/", 1)
if addr[0] == "[" and addr[-1] == "]":
addr = addr[1:-1]
ip_net = f"{addr}/{range}"
except ValueError:
pass # not a range
self.net = ipaddress.ip_network(ip_net)

def matches(self, other: URL) -> bool:
try:
return ipaddress.ip_address(other.host) in self.net
except ValueError:
return False

@property
def priority(self) -> tuple[int, int, int]:
return -1, 0, 0 # higher priority than WildcardURLPatterns

def __hash__(self) -> int:
return hash(self.net)

def __lt__(self, other: Pattern) -> bool:
return self.priority < other.priority

def __eq__(self, other: typing.Any) -> bool:
return isinstance(other, URLPattern) and self.pattern == other.pattern
return isinstance(other, IPNetPattern) and self.net == other.net


# Backward-compatible alias so existing code using URLPattern("...") keeps working.
URLPattern = WildcardURLPattern


def build_url_pattern(pattern: str) -> Pattern:
try:
proto, rest = pattern.split("://", 1)
if proto == "all" and "/" in rest:
return IPNetPattern(rest)
except ValueError: # covers .split() and IPNetPattern
pass
return WildcardURLPattern(pattern)


def is_ipv4_hostname(hostname: str) -> bool:
Expand Down
59 changes: 49 additions & 10 deletions tests/httpx2/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
import pytest

import httpx2
from httpx2._utils import URLPattern, get_environment_proxies
from httpx2._utils import (
IPNetPattern,
WildcardURLPattern,
build_url_pattern,
get_environment_proxies,
)


@pytest.mark.parametrize(
Expand Down Expand Up @@ -126,24 +131,58 @@ def test_get_environment_proxies(environment, proxies):
("http://", "https://example.com", False),
("all://", "https://example.com:123", True),
("", "https://example.com:123", True),
("all://192.168.0.0/24", "http://192.168.0.1", True),
("all://192.168.0.0/24", "https://192.168.1.1", False),
("all://[2001:db8:abcd:0012::]/64", "http://[2001:db8:abcd:12::1]", True),
("all://[2001:db8:abcd:0012::]/64", "http://[2001:db8:abcd:13::1]:8080", False),
],
)
def test_url_matches(pattern, url, expected):
pattern = URLPattern(pattern)
pattern = build_url_pattern(pattern)
assert pattern.matches(httpx2.URL(url)) == expected


@pytest.mark.parametrize(
["pattern", "url", "expected"],
[
("all://192.168.0.0/24", "http://192.168.0.1", True),
("all://192.168.0.1", "http://192.168.0.1", True),
("all://192.168.0.0/24", "foobar", False),
],
)
def test_IPNetPattern(pattern, url, expected):
proto, rest = pattern.split("://", 1)
pattern = IPNetPattern(rest)
assert pattern.matches(httpx2.URL(url)) == expected


def test_build_url_pattern():
pattern1 = build_url_pattern("all://192.168.0.0/16")
pattern2 = build_url_pattern("all://192.168.0.0/16")
pattern3 = build_url_pattern("all://192.168.0.1")
assert isinstance(pattern1, IPNetPattern)
assert isinstance(pattern2, IPNetPattern)
assert isinstance(pattern3, WildcardURLPattern)
assert pattern1 == pattern2
assert pattern2 != pattern3
assert pattern1 < pattern3
assert hash(pattern1) == hash(pattern2)
assert hash(pattern2) != hash(pattern3)


def test_pattern_priority():
matchers = [
URLPattern("all://"),
URLPattern("http://"),
URLPattern("http://example.com"),
URLPattern("http://example.com:123"),
build_url_pattern("all://"),
build_url_pattern("http://"),
build_url_pattern("http://example.com"),
build_url_pattern("http://example.com:123"),
build_url_pattern("all://192.168.0.0/16"),
]
random.shuffle(matchers)
assert sorted(matchers) == [
URLPattern("http://example.com:123"),
URLPattern("http://example.com"),
URLPattern("http://"),
URLPattern("all://"),
build_url_pattern("all://192.168.0.0/16"),
build_url_pattern("http://example.com:123"),
build_url_pattern("http://example.com"),
build_url_pattern("http://"),
build_url_pattern("all://"),
]