diff --git a/.gitignore b/.gitignore index 564b8ce6..9bacc469 100644 --- a/.gitignore +++ b/.gitignore @@ -28,3 +28,4 @@ shippable .vscode/ Pipfile.lock requirements.txt +coverage.xml diff --git a/mocket/__init__.py b/mocket/__init__.py index 857ed2ed..2103f97e 100644 --- a/mocket/__init__.py +++ b/mocket/__init__.py @@ -1,3 +1,5 @@ +"""Mocket - socket mocking library for Python.""" + import importlib import sys diff --git a/mocket/compat.py b/mocket/compat.py index 1ac2fc89..a8e726f6 100644 --- a/mocket/compat.py +++ b/mocket/compat.py @@ -11,23 +11,57 @@ def encode_to_bytes(s: str | bytes, encoding: str = ENCODING) -> bytes: + """Encode a string or bytes to bytes. + + Args: + s: String or bytes to encode + encoding: Encoding to use (default: utf-8 or MOCKET_ENCODING env var) + + Returns: + Encoded bytes + """ if isinstance(s, str): s = s.encode(encoding) return bytes(s) def decode_from_bytes(s: str | bytes, encoding: str = ENCODING) -> str: + """Decode bytes or string to string. + + Args: + s: String or bytes to decode + encoding: Encoding to use (default: utf-8 or MOCKET_ENCODING env var) + + Returns: + Decoded string + """ if isinstance(s, bytes): s = codecs.decode(s, encoding, "ignore") return str(s) def shsplit(s: str | bytes) -> list[str]: + """Split a shell command string into arguments. + + Args: + s: Shell command string or bytes + + Returns: + List of shell command arguments + """ s = decode_from_bytes(s) return shlex.split(s) -def do_the_magic(body): +def do_the_magic(body: bytes) -> str: + """Detect MIME type of binary data using puremagic. + + Args: + body: Binary data to analyze + + Returns: + MIME type string + """ try: magic = puremagic.magic_string(body) except puremagic.PureError: diff --git a/mocket/decorators/async_mocket.py b/mocket/decorators/async_mocket.py index 3839d5f1..53b966c0 100644 --- a/mocket/decorators/async_mocket.py +++ b/mocket/decorators/async_mocket.py @@ -1,15 +1,34 @@ +"""Async version of Mocket decorator.""" + +from __future__ import annotations + +from typing import Any, Callable + from mocket.decorators.mocketizer import Mocketizer from mocket.utils import get_mocketize async def wrapper( - test, - truesocket_recording_dir=None, - strict_mode=False, - strict_mode_allowed=None, - *args, - **kwargs, -): + test: Callable, + truesocket_recording_dir: str | None = None, + strict_mode: bool = False, + strict_mode_allowed: list | None = None, + *args: Any, + **kwargs: Any, +) -> Any: + """Async wrapper function for @async_mocketize decorator. + + Args: + test: Async test function to wrap + truesocket_recording_dir: Directory for recording true socket calls + strict_mode: Enable STRICT mode to forbid real socket calls + strict_mode_allowed: List of allowed hosts in STRICT mode + *args: Test arguments + **kwargs: Test keyword arguments + + Returns: + Result of the test function + """ async with Mocketizer.factory( test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args ): diff --git a/mocket/decorators/mocketizer.py b/mocket/decorators/mocketizer.py index fb7c811b..b067ffdf 100644 --- a/mocket/decorators/mocketizer.py +++ b/mocket/decorators/mocketizer.py @@ -1,17 +1,34 @@ +"""Mocketizer decorator for managing Mocket lifecycle in tests.""" + +from __future__ import annotations + +from typing import Any, Callable + from mocket.mocket import Mocket from mocket.mode import MocketMode from mocket.utils import get_mocketize class Mocketizer: + """Context manager and decorator for managing Mocket lifecycle in tests.""" + def __init__( self, - instance=None, - namespace=None, - truesocket_recording_dir=None, - strict_mode=False, - strict_mode_allowed=None, - ): + instance: Any | None = None, + namespace: str | None = None, + truesocket_recording_dir: str | None = None, + strict_mode: bool = False, + strict_mode_allowed: list | None = None, + ) -> None: + """Initialize the Mocketizer. + + Args: + instance: Test instance (optional) + namespace: Namespace for recordings + truesocket_recording_dir: Directory for recording true socket calls + strict_mode: Enable STRICT mode to forbid real socket calls + strict_mode_allowed: List of allowed hosts in STRICT mode + """ self.instance = instance self.truesocket_recording_dir = truesocket_recording_dir self.namespace = namespace or str(id(self)) @@ -23,7 +40,8 @@ def __init__( "Allowed locations are only accepted when STRICT mode is active." ) - def enter(self): + def enter(self) -> None: + """Enter the Mocketizer context (enable Mocket).""" Mocket.enable( namespace=self.namespace, truesocket_recording_dir=self.truesocket_recording_dir, @@ -31,33 +49,80 @@ def enter(self): if self.instance: self.check_and_call("mocketize_setup") - def __enter__(self): + def __enter__(self) -> Mocketizer: + """Enter context manager. + + Returns: + Self for use in `with` statements + """ self.enter() return self - def exit(self): + def exit(self) -> None: + """Exit the Mocketizer context (disable Mocket).""" if self.instance: self.check_and_call("mocketize_teardown") Mocket.disable() - def __exit__(self, type, value, tb): + def __exit__(self, type: Any, value: Any, tb: Any) -> None: + """Exit context manager. + + Args: + type: Exception type + value: Exception value + tb: Traceback + """ self.exit() - async def __aenter__(self, *args, **kwargs): + async def __aenter__(self, *args: Any, **kwargs: Any) -> Mocketizer: + """Enter async context manager. + + Returns: + Self for use in `async with` statements + """ self.enter() return self - async def __aexit__(self, *args, **kwargs): + async def __aexit__(self, *args: Any, **kwargs: Any) -> None: + """Exit async context manager. + + Args: + *args: Exception arguments + **kwargs: Exception keyword arguments + """ self.exit() - def check_and_call(self, method_name): + def check_and_call(self, method_name: str) -> None: + """Check if instance has a method and call it. + + Args: + method_name: Name of method to check and call + """ method = getattr(self.instance, method_name, None) if callable(method): method() @staticmethod - def factory(test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args): + def factory( + test: Callable, + truesocket_recording_dir: str | None, + strict_mode: bool, + strict_mode_allowed: list | None, + args: tuple, + ) -> Mocketizer: + """Create a Mocketizer instance for a test function. + + Args: + test: Test function being decorated + truesocket_recording_dir: Recording directory + strict_mode: Enable STRICT mode + strict_mode_allowed: Allowed hosts in STRICT mode + args: Positional arguments to test + + Returns: + Configured Mocketizer instance + """ instance = args[0] if args else None namespace = None if truesocket_recording_dir: @@ -79,13 +144,26 @@ def factory(test, truesocket_recording_dir, strict_mode, strict_mode_allowed, ar def wrapper( - test, - truesocket_recording_dir=None, - strict_mode=False, - strict_mode_allowed=None, - *args, - **kwargs, -): + test: Callable, + truesocket_recording_dir: str | None = None, + strict_mode: bool = False, + strict_mode_allowed: list | None = None, + *args: Any, + **kwargs: Any, +) -> Any: + """Wrapper function for @mocketize decorator. + + Args: + test: Test function to wrap + truesocket_recording_dir: Recording directory + strict_mode: Enable STRICT mode + strict_mode_allowed: Allowed hosts in STRICT mode + *args: Test arguments + **kwargs: Test keyword arguments + + Returns: + Result of the test function + """ with Mocketizer.factory( test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args ): diff --git a/mocket/entry.py b/mocket/entry.py index 9dbbf442..2d618472 100644 --- a/mocket/entry.py +++ b/mocket/entry.py @@ -1,22 +1,38 @@ +"""Mocket entry base class for registering mock responses.""" + +from __future__ import annotations + import collections.abc +from typing import Any from mocket.compat import encode_to_bytes from mocket.mocket import Mocket class MocketEntry: + """Base class for Mocket entries that match requests and return responses.""" + class Response(bytes): + """Response wrapper class that extends bytes.""" + @property - def data(self): + def data(self) -> bytes: + """Get the response data.""" return self - response_index = 0 - request_cls = bytes - response_cls = Response - responses = None - _served = None + response_index: int = 0 + request_cls: type = bytes + response_cls: type = Response + responses: list | None = None + _served: bool | None = None + + def __init__(self, location: tuple, responses: Any) -> None: + """Initialize a Mocket entry. - def __init__(self, location, responses): + Args: + location: Tuple of (host, port) + responses: Single response or list of responses to cycle through + """ self._served = False self.location = location @@ -34,18 +50,40 @@ def __init__(self, location, responses): r = self.response_cls(r) self.responses.append(r) - def __repr__(self): + def __repr__(self) -> str: + """Return a string representation of the entry.""" return f"{self.__class__.__name__}(location={self.location})" @staticmethod - def can_handle(data): + def can_handle(data: bytes) -> bool: + """Check if this entry can handle the given request data. + + Args: + data: Request data to check + + Returns: + True if this entry can handle the request, False otherwise + """ return True - def collect(self, data): + def collect(self, data: bytes) -> None: + """Collect the request data in the Mocket singleton. + + Args: + data: Request data to collect + """ req = self.request_cls(data) Mocket.collect(req) - def get_response(self): + def get_response(self) -> bytes: + """Get the next response to send. + + Returns: + Response bytes to send to the client + + Raises: + BaseException: If a response is an exception, it will be raised + """ response = self.responses[self.response_index] if self.response_index < len(self.responses) - 1: self.response_index += 1 diff --git a/mocket/exceptions.py b/mocket/exceptions.py index f5537568..db78dbf5 100644 --- a/mocket/exceptions.py +++ b/mocket/exceptions.py @@ -1,6 +1,13 @@ +"""Mocket exception classes.""" + + class MocketException(Exception): + """Base exception class for Mocket errors.""" + pass class StrictMocketException(MocketException): + """Exception raised when a socket operation is not allowed in STRICT mode.""" + pass diff --git a/mocket/inject.py b/mocket/inject.py index 866ee563..e788a929 100644 --- a/mocket/inject.py +++ b/mocket/inject.py @@ -1,3 +1,5 @@ +"""Socket patching and restoration for Mocket injection.""" + from __future__ import annotations import contextlib @@ -12,17 +14,31 @@ def _patch(module: ModuleType, name: str, patched_value: Any) -> None: + """Patch a module with a new value and store the original. + + Args: + module: Module to patch + name: Attribute name to patch + patched_value: New value to set + """ with contextlib.suppress(KeyError): original_value, module.__dict__[name] = module.__dict__[name], patched_value _patches_restore[(module, name)] = original_value def _restore(module: ModuleType, name: str) -> None: + """Restore a module's original attribute value. + + Args: + module: Module to restore + name: Attribute name to restore + """ if original_value := _patches_restore.pop((module, name)): module.__dict__[name] = original_value def enable() -> None: + """Enable Mocket by patching socket, ssl, and urllib3 modules.""" from mocket.socket import ( MocketSocket, mock_create_connection, @@ -71,6 +87,7 @@ def enable() -> None: def disable() -> None: + """Disable Mocket by restoring all patched modules.""" for module, name in list(_patches_restore.keys()): _restore(module, name) diff --git a/mocket/io.py b/mocket/io.py index 0334410b..e815e0ec 100644 --- a/mocket/io.py +++ b/mocket/io.py @@ -1,3 +1,7 @@ +"""Mocket socket I/O implementation.""" + +from __future__ import annotations + import io import os @@ -5,13 +9,29 @@ class MocketSocketIO(io.BytesIO): - def __init__(self, address) -> None: + """A BytesIO wrapper that integrates with Mocket's pipe-based I/O.""" + + def __init__(self, address: tuple) -> None: + """Initialize the socket I/O with a socket address. + + Args: + address: Tuple of (host, port) + """ self._address = address super().__init__() - def write(self, content): + def write(self, content: bytes) -> int: + """Write content to the buffer and the pipe if available. + + Args: + content: Bytes to write + + Returns: + Number of bytes written + """ super().write(content) _, w_fd = Mocket.get_pair(self._address) if w_fd: os.write(w_fd, content) + return len(content) diff --git a/mocket/mocket.py b/mocket/mocket.py index a8dc7997..75ae6285 100644 --- a/mocket/mocket.py +++ b/mocket/mocket.py @@ -1,3 +1,5 @@ +"""Core Mocket singleton for socket mocking management.""" + from __future__ import annotations import collections @@ -18,6 +20,8 @@ class Mocket: + """Singleton class managing all mock socket operations and entries.""" + _socket_pairs: ClassVar[dict[Address, tuple[int, int]]] = {} _address: ClassVar[Address | tuple[None, None]] = (None, None) _entries: ClassVar[dict[Address, list[MocketEntry]]] = collections.defaultdict(list) @@ -30,6 +34,12 @@ def enable( namespace: str | None = None, truesocket_recording_dir: str | None = None, ) -> None: + """Enable Mocket socket mocking. + + Args: + namespace: Namespace for recording storage (defaults to id of _entries) + truesocket_recording_dir: Directory to store recorded requests/responses + """ if namespace is None: namespace = str(id(cls._entries)) @@ -47,33 +57,61 @@ def enable( @classmethod def disable(cls) -> None: + """Disable Mocket socket mocking and clean up resources.""" cls.reset() mocket.inject.disable() @classmethod def get_pair(cls, address: Address) -> tuple[int, int] | tuple[None, None]: - """ + """Get the file descriptor pair for a socket address. + Given the id() of the caller, return a pair of file descriptors as a tuple of two integers: (, ) + + Args: + address: (host, port) tuple + + Returns: + Tuple of (read_fd, write_fd) or (None, None) if not found """ return cls._socket_pairs.get(address, (None, None)) @classmethod def set_pair(cls, address: Address, pair: tuple[int, int]) -> None: - """ - Store a pair of file descriptors under the key `id_` + """Store a file descriptor pair for a socket address. + + Store a pair of file descriptors under the key `address` as a tuple of two integers: (, ) + + Args: + address: (host, port) tuple + pair: Tuple of (read_fd, write_fd) """ cls._socket_pairs[address] = pair @classmethod def register(cls, *entries: MocketEntry) -> None: + """Register mock entries with Mocket. + + Args: + *entries: Variable number of MocketEntry instances to register + """ for entry in entries: cls._entries[entry.location].append(entry) @classmethod - def get_entry(cls, host: str, port: int, data) -> MocketEntry | None: + def get_entry(cls, host: str, port: int, data: Any) -> MocketEntry | None: + """Get a matching entry for the given request data. + + Args: + host: Hostname + port: Port number + data: Request data + + Returns: + Matching MocketEntry or None + """ host = host or cls._address[0] port = port or cls._address[1] entries = cls._entries.get((host, port), []) @@ -83,11 +121,17 @@ def get_entry(cls, host: str, port: int, data) -> MocketEntry | None: return None @classmethod - def collect(cls, data) -> None: + def collect(cls, data: Any) -> None: + """Collect a request in the list of all requests. + + Args: + data: Request data to collect + """ cls._requests.append(data) @classmethod def reset(cls) -> None: + """Reset all Mocket state and clean up file descriptors.""" for r_fd, w_fd in cls._socket_pairs.values(): os.close(r_fd) os.close(w_fd) @@ -98,32 +142,62 @@ def reset(cls) -> None: @classmethod def last_request(cls) -> Any: + """Get the last request made. + + Returns: + Last request data or None if no requests + """ if cls.has_requests(): return cls._requests[-1] @classmethod def request_list(cls) -> list[Any]: + """Get the list of all requests. + + Returns: + List of all collected requests + """ return cls._requests @classmethod def remove_last_request(cls) -> None: + """Remove the last request from the request list.""" if cls.has_requests(): del cls._requests[-1] @classmethod def has_requests(cls) -> bool: + """Check if any requests have been made. + + Returns: + True if there are requests, False otherwise + """ return bool(cls.request_list()) @classmethod def get_namespace(cls) -> str | None: + """Get the recording namespace. + + Returns: + Namespace string or None if recording is not enabled + """ return cls._record_storage.namespace if cls._record_storage else None @classmethod def get_truesocket_recording_dir(cls) -> str | None: + """Get the true socket recording directory. + + Returns: + Directory path as string or None if recording is not enabled + """ return str(cls._record_storage.directory) if cls._record_storage else None @classmethod def assert_fail_if_entries_not_served(cls) -> None: - """Mocket checks that all entries have been served at least once.""" + """Assert that all registered entries have been served at least once. + + Raises: + AssertionError: If any entries have not been served + """ if not all(entry._served for entry in itertools.chain(*cls._entries.values())): raise AssertionError("Some Mocket entries have not been served") diff --git a/mocket/mocks/mockhttp.py b/mocket/mocks/mockhttp.py index 50a6f952..5ec14a62 100644 --- a/mocket/mocks/mockhttp.py +++ b/mocket/mocks/mockhttp.py @@ -1,8 +1,12 @@ +"""HTTP mocking implementation for Mocket.""" + +from __future__ import annotations + import re import time from functools import cached_property from http.server import BaseHTTPRequestHandler -from typing import Callable, Optional +from typing import Any, Callable from urllib.parse import parse_qs, unquote, urlsplit from h11 import SERVER, Connection, Data @@ -12,42 +16,79 @@ from mocket.entry import MocketEntry from mocket.mocket import Mocket -STATUS = {k: v[0] for k, v in BaseHTTPRequestHandler.responses.items()} -CRLF = "\r\n" -ASCII = "ascii" +STATUS: dict = {k: v[0] for k, v in BaseHTTPRequestHandler.responses.items()} +CRLF: str = "\r\n" +ASCII: str = "ascii" class Request: - _parser = None - _event = None + """HTTP request parser using h11.""" + + _parser: Connection | None = None + _event: Any | None = None - def __init__(self, data): + def __init__(self, data: bytes) -> None: + """Initialize the request parser. + + Args: + data: Raw HTTP request data + """ self._parser = Connection(SERVER) self.add_data(data) - def add_data(self, data): + def add_data(self, data: bytes) -> None: + """Add more data to the request. + + Args: + data: Additional raw request data + """ self._parser.receive_data(data) @property - def event(self): + def event(self) -> Any: + """Get the parsed request event. + + Returns: + The h11 request event + """ if not self._event: self._event = self._parser.next_event() return self._event @cached_property - def method(self): + def method(self) -> str: + """Get the HTTP method. + + Returns: + HTTP method (GET, POST, etc.) + """ return self.event.method.decode(ASCII) @cached_property - def path(self): + def path(self) -> str: + """Get the request path. + + Returns: + Request path with query string + """ return self.event.target.decode(ASCII) @cached_property - def headers(self): + def headers(self) -> dict: + """Get the request headers. + + Returns: + Dictionary of header names to values + """ return {k.decode(ASCII): v.decode(ASCII) for k, v in self.event.headers} @cached_property - def querystring(self): + def querystring(self) -> dict: + """Get the parsed query string. + + Returns: + Dictionary of query parameter names to lists of values + """ parts = self.path.split("?", 1) return ( parse_qs(unquote(parts[1]), keep_blank_values=True) @@ -56,7 +97,12 @@ def querystring(self): ) @cached_property - def body(self): + def body(self) -> str: + """Get the request body. + + Returns: + Decoded request body string + """ while True: event = self._parser.next_event() if isinstance(event, H11Request): @@ -64,15 +110,31 @@ def body(self): elif isinstance(event, Data): return event.data.decode(ENCODING) - def __str__(self): + def __str__(self) -> str: + """Get string representation of request. + + Returns: + Formatted request string + """ return f"{self.method} - {self.path} - {self.headers}" class Response: - headers = None - is_file_object = False + """HTTP response builder.""" + + headers: dict | None = None + is_file_object: bool = False + + def __init__( + self, body: Any = "", status: int = 200, headers: dict | None = None + ) -> None: + """Initialize an HTTP response. - def __init__(self, body="", status=200, headers=None): + Args: + body: Response body (string, bytes, or file-like object) + status: HTTP status code + headers: Dictionary of response headers + """ headers = headers or {} try: # File Objects @@ -88,6 +150,14 @@ def __init__(self, body="", status=200, headers=None): self.data = self.get_protocol_data() + self.body def get_protocol_data(self, str_format_fun_name: str = "capitalize") -> bytes: + """Get the HTTP protocol headers and status line. + + Args: + str_format_fun_name: Name of string formatting method to use + + Returns: + Bytes of protocol headers (status line and headers) + """ status_line = f"HTTP/1.1 {self.status} {STATUS[self.status]}" header_lines = CRLF.join( ( @@ -97,7 +167,8 @@ def get_protocol_data(self, str_format_fun_name: str = "capitalize") -> bytes: ) return f"{status_line}\r\n{header_lines}\r\n\r\n".encode(ENCODING) - def set_base_headers(self): + def set_base_headers(self) -> None: + """Set the base response headers.""" self.headers = { "Status": str(self.status), "Date": time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime()), @@ -110,8 +181,12 @@ def set_base_headers(self): else: self.headers["Content-Type"] = do_the_magic(self.body) - def set_extra_headers(self, headers): - r""" + def set_extra_headers(self, headers: dict) -> None: + r"""Add extra headers to the response. + + Args: + headers: Dictionary of additional headers + >>> r = Response(body="") >>> len(r.headers.keys()) 6 @@ -126,6 +201,8 @@ def set_extra_headers(self, headers): class Entry(MocketEntry): + """HTTP entry for matching and responding to HTTP requests.""" + CONNECT = "CONNECT" DELETE = "DELETE" GET = "GET" @@ -136,22 +213,31 @@ class Entry(MocketEntry): PUT = "PUT" TRACE = "TRACE" - METHODS = (CONNECT, DELETE, GET, HEAD, OPTIONS, PATCH, POST, PUT, TRACE) + METHODS: tuple = (CONNECT, DELETE, GET, HEAD, OPTIONS, PATCH, POST, PUT, TRACE) - request_cls = Request - response_cls = Response + request_cls: type = Request + response_cls: type = Response - default_config = {"match_querystring": True, "can_handle_fun": None} - _can_handle_fun: Optional[Callable] = None + default_config: dict = {"match_querystring": True, "can_handle_fun": None} + _can_handle_fun: Callable | None = None def __init__( self, - uri, - method, - responses, + uri: str, + method: str, + responses: Any, match_querystring: bool = True, - can_handle_fun: Optional[Callable] = None, - ): + can_handle_fun: Callable | None = None, + ) -> None: + """Initialize an HTTP entry. + + Args: + uri: URI to match (http://host:port/path?query) + method: HTTP method (GET, POST, etc.) + responses: Response(s) to return + match_querystring: Whether to match query strings + can_handle_fun: Custom matching function + """ self._can_handle_fun = can_handle_fun if can_handle_fun else self._can_handle uri = urlsplit(uri) @@ -168,10 +254,23 @@ def __init__( self._sent_data = b"" self._match_querystring = match_querystring - def __repr__(self): + def __repr__(self) -> str: + """Get string representation of the entry. + + Returns: + String representation + """ return f"{self.__class__.__name__}(method={self.method!r}, schema={self.schema!r}, location={self.location!r}, path={self.path!r}, query={self.query!r})" - def collect(self, data): + def collect(self, data: bytes) -> bool: + """Collect the request data. + + Args: + data: Request data + + Returns: + Whether to consume the response + """ consume_response = True decoded_data = decode_from_bytes(data) @@ -187,9 +286,14 @@ def collect(self, data): return consume_response def _can_handle(self, path: str, qs_dict: dict) -> bool: - """ - The default can_handle function, which checks if the path match, - and if match_querystring is True, also checks if the querystring matches. + """Default can_handle function checking path and query string. + + Args: + path: Request path + qs_dict: Parsed query string parameters + + Returns: + True if this entry can handle the request """ can_handle = path == self.path if self._match_querystring: @@ -198,8 +302,15 @@ def _can_handle(self, path: str, qs_dict: dict) -> bool: ) return can_handle - def can_handle(self, data): - r""" + def can_handle(self, data: bytes) -> bool: + r"""Check if this entry can handle the given request data. + + Args: + data: Request data + + Returns: + True if this entry can handle the request + >>> e = Entry('http://www.github.com/?bar=foo&foobar', Entry.GET, (Response(b''),)) >>> e.can_handle(b'GET /?bar=foo HTTP/1.1\r\nHost: github.com\r\nAccept-Encoding: gzip, deflate\r\nConnection: keep-alive\r\nUser-Agent: python-requests/2.7.0 CPython/3.4.3 Linux/3.19.0-16-generic\r\nAccept: */*\r\n\r\n') False @@ -224,10 +335,20 @@ def can_handle(self, data): return can_handle @staticmethod - def _parse_requestline(line): - """ + def _parse_requestline(line: str) -> tuple: + """Parse an HTTP request line. + http://www.w3.org/Protocols/rfc2616/rfc2616-sec5.html#sec5 + Args: + line: HTTP request line string + + Returns: + Tuple of (method, path, version) + + Raises: + ValueError: If line is not a valid request line + >>> Entry._parse_requestline('GET / HTTP/1.0') == ('GET', '/', '1.0') True >>> Entry._parse_requestline('post /testurl htTP/1.1') == ('POST', '/testurl', '1.1') @@ -245,7 +366,19 @@ def _parse_requestline(line): raise ValueError("Not a Request-Line") @classmethod - def register(cls, method, uri, *responses, **config): + def register(cls, method: str, uri: str, *responses: Any, **config: Any) -> None: + """Register an HTTP entry for multiple responses. + + Args: + method: HTTP method (GET, POST, etc.) + uri: URI to match + *responses: Response(s) to cycle through + **config: Configuration options (match_querystring, can_handle_fun) + + Raises: + AttributeError: If using body/status params (use single_register instead) + KeyError: If invalid config keys provided + """ if "body" in config or "status" in config: raise AttributeError("Did you mean `Entry.single_register(...)`?") @@ -262,33 +395,31 @@ def register(cls, method, uri, *responses, **config): @classmethod def single_register( cls, - method, - uri, - body="", - status=200, - headers=None, - exception=None, - match_querystring=True, - can_handle_fun=None, - **config, - ): - """ - A helper method to register a single Response for a given URI and method. - Instead of passing a list of Response objects, you can just pass the response - parameters directly. + method: str, + uri: str, + body: Any = "", + status: int = 200, + headers: dict | None = None, + exception: Exception | None = None, + match_querystring: bool = True, + can_handle_fun: Callable | None = None, + **config: Any, + ) -> None: + """Register a single HTTP response for a URI and method. + + This is a convenience method that creates a single Response object + instead of requiring a list. Args: - method (str): The HTTP method (e.g., 'GET', 'POST'). - uri (str): The URI to register the response for. - body (str, optional): The body of the response. Defaults to an empty string. - status (int, optional): The HTTP status code. Defaults to 200. - headers (dict, optional): A dictionary of headers to include in the response. Defaults to None. - exception (Exception, optional): An exception to raise instead of returning a response. Defaults to None. - match_querystring (bool, optional): Whether to match the querystring in the URI. Defaults to True. - can_handle_fun (Callable, optional): A custom function to determine if the Entry can handle a request. - Defaults to None. If None, the default matching logic is used. The function should accept two parameters: - path (str), and querystring params (dict), and return a boolean. Method is matched before the function call. - **config: Additional configuration options. + method: HTTP method (GET, POST, etc.) + uri: URI to match + body: Response body content + status: HTTP status code + headers: Dictionary of response headers + exception: Exception to raise instead of returning response + match_querystring: Whether to match query strings + can_handle_fun: Custom matching function + **config: Additional configuration options """ response = ( exception diff --git a/mocket/mocks/mockredis.py b/mocket/mocks/mockredis.py index fc386e2d..eee2d6c8 100644 --- a/mocket/mocks/mockredis.py +++ b/mocket/mocks/mockredis.py @@ -1,4 +1,9 @@ +"""Redis mocking implementation for Mocket.""" + +from __future__ import annotations + from itertools import chain +from typing import Any from mocket.compat import ( decode_from_bytes, @@ -7,29 +12,63 @@ ) from mocket.entry import MocketEntry from mocket.mocket import Mocket +from mocket.types import Address class Request: - def __init__(self, data): + """Redis request wrapper.""" + + def __init__(self, data: bytes) -> None: + """Initialize a Redis request. + + Args: + data: Raw Redis command data + """ self.data = data class Response: - def __init__(self, data=None): + """Redis response wrapper.""" + + def __init__(self, data: Any = None) -> None: + """Initialize a Redis response. + + Args: + data: Response data (will be "redisize"d) + """ self.data = Redisizer.redisize(data or OK) class Redisizer(bytes): + """Convert Python types to Redis protocol format.""" + @staticmethod - def tokens(iterable): + def tokens(iterable: list[Any]) -> list[bytes]: + """Convert an iterable to Redis tokens. + + Args: + iterable: List of items to convert + + Returns: + List of Redis protocol bytes + """ iterable = [encode_to_bytes(x) for x in iterable] return [f"*{len(iterable)}".encode()] + list( chain(*zip([f"${len(x)}".encode() for x in iterable], iterable)) ) @staticmethod - def redisize(data): - def get_conversion(t): + def redisize(data: Any) -> Redisizer: + """Convert Python data to Redis protocol format. + + Args: + data: Python data to convert + + Returns: + Redisizer bytes + """ + + def get_conversion(t: type) -> Any: return { dict: lambda x: b"\r\n".join( Redisizer.tokens(list(chain(*tuple(x.items())))) @@ -48,11 +87,28 @@ def get_conversion(t): return Redisizer(get_conversion(data.__class__)(data) + b"\r\n") @staticmethod - def command(description, _type="+"): + def command(description: str, _type: str = "+") -> Redisizer: + """Create a Redis command response. + + Args: + description: Response description + _type: Response type prefix (+, -, :, $, *) + + Returns: + Formatted Redis response + """ return Redisizer("{}{}{}".format(_type, description, "\r\n").encode("utf-8")) @staticmethod - def error(description): + def error(description: str) -> Redisizer: + """Create a Redis error response. + + Args: + description: Error description + + Returns: + Formatted Redis error response + """ return Redisizer.command(description, _type="-") @@ -62,20 +118,46 @@ def error(description): class Entry(MocketEntry): + """Redis entry for matching and responding to Redis commands.""" + request_cls = Request response_cls = Response - def __init__(self, addr, command, responses): + def __init__( + self, addr: Address | None, command: str, responses: list[Any] + ) -> None: + """Initialize a Redis entry. + + Args: + addr: (host, port) tuple or None for default + command: Redis command string to match + responses: List of responses to cycle through + """ super().__init__(addr or ("localhost", 6379), responses) d = shsplit(command) d[0] = d[0].upper() self.command = Redisizer.tokens(d) - def can_handle(self, data): + def can_handle(self, data: bytes) -> bool: + """Check if this entry can handle the given command. + + Args: + data: Raw Redis command data + + Returns: + True if this entry matches the command + """ return data.splitlines() == self.command @classmethod - def register(cls, addr, command, *responses): + def register(cls, addr: Address | None, command: str, *responses: Any) -> None: + """Register a Redis entry. + + Args: + addr: (host, port) tuple or None for default + command: Redis command to match + *responses: Responses to cycle through + """ responses = [ r if isinstance(r, BaseException) else cls.response_cls(r) for r in responses @@ -83,9 +165,27 @@ def register(cls, addr, command, *responses): Mocket.register(cls(addr, command, responses)) @classmethod - def register_response(cls, command, response, addr=None): + def register_response( + cls, command: str, response: Any, addr: Address | None = None + ) -> None: + """Register a single response for a command. + + Args: + command: Redis command to match + response: Response to return + addr: (host, port) tuple or None for default + """ cls.register(addr, command, response) @classmethod - def register_responses(cls, command, responses, addr=None): + def register_responses( + cls, command: str, responses: list[Any], addr: Address | None = None + ) -> None: + """Register multiple responses for a command. + + Args: + command: Redis command to match + responses: List of responses to cycle through + addr: (host, port) tuple or None for default + """ cls.register(addr, command, *responses) diff --git a/mocket/mode.py b/mocket/mode.py index ac2ca16a..ffb23a44 100644 --- a/mocket/mode.py +++ b/mocket/mode.py @@ -1,3 +1,5 @@ +"""Mocket mode management for strict socket enforcement.""" + from __future__ import annotations from typing import TYPE_CHECKING, Any, ClassVar @@ -10,17 +12,27 @@ class _MocketMode: + """Singleton class for managing Mocket's strict mode enforcement.""" + __shared_state: ClassVar[dict[str, Any]] = {} STRICT: ClassVar = None STRICT_ALLOWED: ClassVar = None def __init__(self) -> None: + """Initialize the MocketMode singleton with shared state.""" self.__dict__ = self.__shared_state def is_allowed(self, location: str | tuple[str, int]) -> bool: - """ + """Check if a location is allowed to perform real socket calls. + Checks if (`host`, `port`) or at least `host` are allowed locations to perform real `socket` calls + + Args: + location: Hostname string or (host, port) tuple + + Returns: + True if the location is allowed, False if in STRICT mode and not allowed """ if not self.STRICT: return True @@ -35,6 +47,15 @@ def raise_not_allowed( address: tuple[str, int] | None = None, data: bytes | None = None, ) -> NoReturn: + """Raise an exception when a socket operation is not allowed in STRICT mode. + + Args: + address: The (host, port) tuple that was attempted + data: The request data that was sent + + Raises: + StrictMocketException: Always raised with detailed context + """ current_entries = [ (location, "\n ".join(map(str, entries))) for location, entries in Mocket._entries.items() diff --git a/mocket/recording.py b/mocket/recording.py index 97d2adbe..95faf126 100644 --- a/mocket/recording.py +++ b/mocket/recording.py @@ -1,3 +1,5 @@ +"""Request/response recording for playback during tests.""" + from __future__ import annotations import contextlib @@ -6,12 +8,13 @@ from collections import defaultdict from dataclasses import dataclass from pathlib import Path +from typing import Any from mocket.compat import decode_from_bytes, encode_to_bytes from mocket.types import Address from mocket.utils import hexdump, hexload -hash_function = hashlib.md5 +hash_function: Any = hashlib.md5 with contextlib.suppress(ImportError): from xxhash_cffi import xxh32 as xxhash_cffi_xxh32 @@ -25,22 +28,48 @@ def _hash_prepare_request(data: bytes) -> bytes: + """Prepare request data for hashing by sorting headers. + + Args: + data: Raw request data + + Returns: + Prepared bytes for hashing + """ _data = decode_from_bytes(data) return encode_to_bytes("".join(sorted(_data.split("\r\n")))) def _hash_request(data: bytes) -> str: + """Hash a request using the best available hash function. + + Args: + data: Raw request data + + Returns: + Hex digest of the hash + """ _data = _hash_prepare_request(data) return hash_function(_data).hexdigest() def _hash_request_fallback(data: bytes) -> str: + """Hash a request using MD5 as fallback. + + Args: + data: Raw request data + + Returns: + Hex digest of the MD5 hash + """ _data = _hash_prepare_request(data) return hashlib.md5(_data).hexdigest() @dataclass class MocketRecord: + """A record of a request and its corresponding response.""" + host: str port: int request: bytes @@ -48,7 +77,15 @@ class MocketRecord: class MocketRecordStorage: + """Storage for recording and retrieving request/response pairs.""" + def __init__(self, directory: Path, namespace: str) -> None: + """Initialize the record storage. + + Args: + directory: Path to directory for storing recordings + namespace: Namespace for grouping records + """ self._directory = directory self._namespace = namespace self._records: defaultdict[Address, defaultdict[str, MocketRecord]] = ( @@ -59,17 +96,33 @@ def __init__(self, directory: Path, namespace: str) -> None: @property def directory(self) -> Path: + """Get the recording directory. + + Returns: + Path to recording directory + """ return self._directory @property def namespace(self) -> str: + """Get the recording namespace. + + Returns: + Namespace string + """ return self._namespace @property def file(self) -> Path: + """Get the path to the namespace's JSON file. + + Returns: + Path to JSON recording file + """ return self._directory / f"{self._namespace}.json" def _load(self) -> None: + """Load recordings from disk.""" if not self.file.exists(): return @@ -92,6 +145,7 @@ def _load(self) -> None: ) def _save(self) -> None: + """Save recordings to disk.""" data: dict[str, dict[str, dict[str, dict[str, str]]]] = defaultdict( lambda: defaultdict(defaultdict) ) @@ -108,9 +162,26 @@ def _save(self) -> None: self.file.write_text(json_data) def get_records(self, address: Address) -> list[MocketRecord]: + """Get all records for an address. + + Args: + address: (host, port) tuple + + Returns: + List of MocketRecord instances + """ return list(self._records[address].values()) def get_record(self, address: Address, request: bytes) -> MocketRecord | None: + """Get a specific record matching the request. + + Args: + address: (host, port) tuple + request: Request bytes + + Returns: + Matching MocketRecord or None + """ # NOTE for backward-compat request_signature_fallback = _hash_request_fallback(request) if request_signature_fallback in self._records[address]: @@ -128,6 +199,13 @@ def put_record( request: bytes, response: bytes, ) -> None: + """Store a new record. + + Args: + address: (host, port) tuple + request: Request bytes + response: Response bytes + """ host, port = address record = MocketRecord( host=host, diff --git a/mocket/socket.py b/mocket/socket.py index e06a1a8e..bd79528c 100644 --- a/mocket/socket.py +++ b/mocket/socket.py @@ -1,3 +1,5 @@ +"""Mock socket implementation for Mocket.""" + from __future__ import annotations import contextlib @@ -25,7 +27,21 @@ true_socket = socket.socket -def mock_create_connection(address, timeout=None, source_address=None): +def mock_create_connection( + address: Address, + timeout: float | None = None, + source_address: Address | None = None, +) -> socket.socket: + """Create a mock socket connection. + + Args: + address: (host, port) tuple + timeout: Connection timeout in seconds + source_address: Source address for binding (unused) + + Returns: + MocketSocket instance + """ s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP) if timeout: s.settimeout(timeout) @@ -41,29 +57,80 @@ def mock_getaddrinfo( proto: int = 0, flags: int = 0, ) -> list[tuple[int, int, int, str, tuple[str, int]]]: + """Mock socket.getaddrinfo function. + + Args: + host: Hostname + port: Port number + family: Address family (ignored) + type: Socket type (ignored) + proto: Protocol (ignored) + flags: Flags (ignored) + + Returns: + List of address info tuples + """ return [(2, 1, 6, "", (host, port))] def mock_gethostbyname(hostname: str) -> str: + """Mock socket.gethostbyname function. + + Args: + hostname: Hostname to resolve (unused) + + Returns: + Localhost IP address + """ return "127.0.0.1" def mock_gethostname() -> str: + """Mock socket.gethostname function. + + Returns: + Localhost hostname + """ return "localhost" def mock_inet_pton(address_family: int, ip_string: str) -> bytes: + """Mock socket.inet_pton function. + + Args: + address_family: Address family (unused) + ip_string: IP string (unused) + + Returns: + Localhost as bytes + """ return bytes("\x7f\x00\x00\x01", "utf-8") -def mock_socketpair(*args, **kwargs): - """Returns a real socketpair() used by asyncio loop for supporting calls made by fastapi and similar services.""" +def mock_socketpair( + *args: Any, + **kwargs: Any, +) -> tuple[socket.socket, socket.socket]: + """Mock socket.socketpair function. + + Returns a real socketpair() used by asyncio loop for supporting + calls made by fastapi and similar services. + + Args: + *args: Positional arguments + **kwargs: Keyword arguments + + Returns: + Tuple of two connected sockets + """ import _socket return _socket.socketpair(*args, **kwargs) class MocketSocket: + """Mock socket implementation for Mocket.""" + def __init__( self, family: socket.AddressFamily | int = socket.AF_INET, @@ -72,6 +139,15 @@ def __init__( fileno: int | None = None, **kwargs: Any, ) -> None: + """Initialize a Mocket socket. + + Args: + family: Address family + type: Socket type + proto: Protocol number + fileno: File descriptor (unused) + **kwargs: Additional keyword arguments + """ self._family = family self._type = type self._proto = proto @@ -90,9 +166,11 @@ def __init__( self._entry = None def __str__(self) -> str: + """Return a string representation of the socket.""" return f"({self.__class__.__name__})(family={self.family} type={self.type} protocol={self.proto})" def __enter__(self) -> Self: + """Enter context manager.""" return self def __exit__( @@ -101,27 +179,37 @@ def __exit__( value: BaseException | None, traceback: TracebackType | None, ) -> None: + """Exit context manager and close socket.""" self.close() @property def family(self) -> int: + """Get the address family.""" return self._family @property def type(self) -> int: + """Get the socket type.""" return self._type @property def proto(self) -> int: + """Get the protocol number.""" return self._proto @property def io(self) -> MocketSocketIO: + """Get or create the socket I/O object.""" if self._io is None: self._io = MocketSocketIO((self._host, self._port)) return self._io def fileno(self) -> int: + """Get the file descriptor for reading. + + Returns: + File descriptor number + """ address = (self._host, self._port) r_fd, _ = Mocket.get_pair(address) if not r_fd: @@ -130,52 +218,153 @@ def fileno(self) -> int: return r_fd def gettimeout(self) -> float | None: + """Get the socket timeout. + + Returns: + Timeout in seconds or None + """ return self._timeout - # FIXME the arguments here seem wrong. they should be `level: int, optname: int, value: int | ReadableBuffer | None` - def setsockopt(self, family: int, type: int, proto: int) -> None: - self._family = family - self._type = type - self._proto = proto + def setsockopt( + self, + level: int, + optname: int, + value: int | bytes | None, + optlen: int | None = None, + ) -> None: + """Set socket option. + Args: + level: Socket option level (e.g., socket.SOL_SOCKET) + optname: Socket option name (e.g., socket.SO_REUSEADDR) + value: Option value as an integer or bytes, or None when optlen is provided + optlen: Option length (used when value is None) + """ if self._true_socket: - self._true_socket.setsockopt(family, type, proto) + if optlen is not None: + self._true_socket.setsockopt(level, optname, value, optlen) + else: + self._true_socket.setsockopt(level, optname, value) def settimeout(self, timeout: float | None) -> None: + """Set the socket timeout. + + Args: + timeout: Timeout in seconds or None + """ self._timeout = timeout @staticmethod def getsockopt(level: int, optname: int, buflen: int | None = None) -> int: + """Get socket option (mock implementation). + + Args: + level: Socket option level + optname: Socket option name + buflen: Buffer length (unused) + + Returns: + SOCK_STREAM constant + """ return socket.SOCK_STREAM def getpeername(self) -> _RetAddress: + """Get the remote socket address. + + Returns: + Address of the remote socket + """ return self._address def setblocking(self, block: bool) -> None: + """Set the socket to blocking or non-blocking mode. + + Args: + block: True for blocking, False for non-blocking + """ self.settimeout(None) if block else self.settimeout(0.0) def getblocking(self) -> bool: + """Check if the socket is in blocking mode. + + Returns: + True if blocking, False otherwise + """ return self.gettimeout() is None def getsockname(self) -> _RetAddress: + """Get the local socket address. + + Returns: + Local socket address + """ return socket.gethostbyname(self._address[0]), self._address[1] def connect(self, address: Address) -> None: + """Connect the socket to a remote address. + + Args: + address: (host, port) tuple + """ self._address = self._host, self._port = address Mocket._address = address def makefile(self, mode: str = "r", bufsize: int = -1) -> MocketSocketIO: + """Create a file object for the socket. + + Args: + mode: Mode string (unused) + bufsize: Buffer size (unused) + + Returns: + MocketSocketIO object + """ return self.io def get_entry(self, data: bytes) -> MocketEntry | None: + """Get a matching entry for the given data. + + Args: + data: Request data + + Returns: + Matching MocketEntry or None + """ return Mocket.get_entry(self._host, self._port, data) - def sendto(self, data: ReadableBuffer, address: Address | None = None) -> int: + def sendto( + self, + data: ReadableBuffer, + address: Address | None = None, + ) -> int: + """Send data to a specific address (UDP-like). + + Args: + data: Data to send + address: Destination address + + Returns: + Number of bytes sent + """ self.connect(address) self.sendall(data) return len(data) - def sendall(self, data, entry=None, *args, **kwargs): + def sendall( + self, + data: ReadableBuffer, + entry: MocketEntry | None = None, + *args: Any, + **kwargs: Any, + ) -> None: + """Send all data through the socket. + + Args: + data: Data to send + entry: Pre-matched entry (optional) + *args: Additional arguments + **kwargs: Additional keyword arguments + """ if entry is None: entry = self.get_entry(data) @@ -198,6 +387,17 @@ def sendmsg( flags: int = 0, address: Address | None = None, ) -> int: + """Send a message through multiple buffers. + + Args: + buffers: List of buffers to send + ancdata: Ancillary data (unused) + flags: Flags (unused) + address: Destination address (unused) + + Returns: + Number of bytes sent + """ if not buffers: return 0 @@ -211,16 +411,23 @@ def recvmsg( ancbufsize: int | None = None, flags: int = 0, ) -> tuple[bytes, list[tuple[int, bytes]]]: - """ - Receive a message from the socket. + """Receive a message from the socket. + This is a mock implementation that reads from the MocketSocketIO. + + Args: + buffersize: Size of buffer to receive + ancbufsize: Ancillary buffer size (unused) + flags: Flags (unused) + + Returns: + Tuple of (data, ancillary_data) """ try: data = self.recv(buffersize) except BlockingIOError: return b"", [] - # Mocking the ancillary data and flags as empty return data, [] def recvmsg_into( @@ -229,10 +436,19 @@ def recvmsg_into( ancbufsize: int | None = None, flags: int = 0, address: Address | None = None, - ): - """ - Receive a message into multiple buffers. + ) -> int: + """Receive a message into multiple buffers. + This is a mock implementation that reads from the MocketSocketIO. + + Args: + buffers: List of buffers to receive into + ancbufsize: Ancillary buffer size (unused) + flags: Flags (unused) + address: Address (unused) + + Returns: + Number of bytes received """ if not buffers: return 0 @@ -254,10 +470,16 @@ def recvfrom_into( buffer: WriteableBuffer, buffersize: int | None = None, flags: int | None = None, - ): - """ - Receive data into a buffer and return the number of bytes received. - This is a mock implementation that reads from the MocketSocketIO. + ) -> tuple[int, _RetAddress]: + """Receive data into a buffer and return the source address. + + Args: + buffer: Buffer to receive into + buffersize: Size to receive + flags: Flags (unused) + + Returns: + Tuple of (bytes_received, source_address) """ return self.recv_into(buffer, buffersize, flags), self._address @@ -267,10 +489,19 @@ def recv_into( buffersize: int | None = None, flags: int | None = None, ) -> int: + """Receive data into a buffer. + + Args: + buffer: Buffer to receive into + buffersize: Number of bytes to receive + flags: Flags (unused) + + Returns: + Number of bytes received + """ if hasattr(buffer, "write"): return buffer.write(self.recv(buffersize)) - # buffer is a memoryview if buffersize is None: buffersize = len(buffer) @@ -282,9 +513,30 @@ def recv_into( def recvfrom( self, buffersize: int, flags: int | None = None ) -> tuple[bytes, _RetAddress]: + """Receive data and the source address. + + Args: + buffersize: Number of bytes to receive + flags: Flags (unused) + + Returns: + Tuple of (data, source_address) + """ return self.recv(buffersize, flags), self._address def recv(self, buffersize: int, flags: int | None = None) -> bytes: + """Receive data from the socket. + + Args: + buffersize: Maximum number of bytes to receive + flags: Flags (unused) + + Returns: + Received bytes + + Raises: + BlockingIOError: If socket is non-blocking and no data available + """ r_fd, _ = Mocket.get_pair((self._host, self._port)) if r_fd: return os.read(r_fd, buffersize) @@ -298,6 +550,19 @@ def recv(self, buffersize: int, flags: int | None = None) -> bytes: raise exc def true_sendall(self, data: bytes, *args: Any, **kwargs: Any) -> bytes: + """Send data through the real socket and receive response. + + Args: + data: Data to send + *args: Additional arguments + **kwargs: Additional keyword arguments + + Returns: + Response bytes from the real socket + + Raises: + StrictMocketException: If operation not allowed in STRICT mode + """ if not MocketMode.is_allowed(self._address): MocketMode.raise_not_allowed(self._address, data) @@ -344,7 +609,17 @@ def send( data: ReadableBuffer, *args: Any, **kwargs: Any, - ) -> int: # pragma: no cover + ) -> int: + """Send data through the socket. + + Args: + data: Data to send + *args: Additional arguments + **kwargs: Additional keyword arguments + + Returns: + Number of bytes sent + """ entry = self.get_entry(data) if not entry or (entry and self._entry != entry): kwargs["entry"] = entry @@ -357,7 +632,11 @@ def send( return len(data) def accept(self) -> tuple[MocketSocket, _RetAddress]: - """Accept a connection and return a new MocketSocket object.""" + """Accept a connection and return a new MocketSocket object. + + Returns: + Tuple of (new_socket, client_address) + """ new_socket = MocketSocket( family=self._family, type=self._type, @@ -369,11 +648,19 @@ def accept(self) -> tuple[MocketSocket, _RetAddress]: return new_socket, (self._host, self._port) def close(self) -> None: + """Close the socket and underlying true socket.""" if self._true_socket and not self._true_socket._closed: self._true_socket.close() def __getattr__(self, name: str) -> Any: - """Do nothing catchall function, for methods like shutdown()""" + """Do-nothing catchall function for methods like shutdown(). + + Args: + name: Method name + + Returns: + A callable that does nothing + """ def do_nothing(*args: Any, **kwargs: Any) -> Any: pass diff --git a/mocket/ssl/context.py b/mocket/ssl/context.py index 6d5e7307..aeaab6b5 100644 --- a/mocket/ssl/context.py +++ b/mocket/ssl/context.py @@ -1,3 +1,5 @@ +"""Mocket SSL context implementation.""" + from __future__ import annotations from typing import Any @@ -7,10 +9,13 @@ class _MocketSSLContext: - """For Python 3.6 and newer.""" + """Mock SSL context for Python 3.6 and newer.""" class FakeSetter(int): + """Descriptor that ignores assignment.""" + def __set__(self, *args: Any) -> None: + """Ignore any assignment attempts.""" pass minimum_version = FakeSetter() @@ -20,29 +25,49 @@ def __set__(self, *args: Any) -> None: class MocketSSLContext(_MocketSSLContext): - DUMMY_METHODS = ( + """Mock SSL context that wraps sockets in MocketSSLSocket.""" + + DUMMY_METHODS: tuple = ( "load_default_certs", "load_verify_locations", "set_alpn_protocols", "set_ciphers", "set_default_verify_paths", ) - sock = None - post_handshake_auth = None - _check_hostname = False + sock: MocketSocket | None = None + post_handshake_auth: bool | None = None + _check_hostname: bool = False @property def check_hostname(self) -> bool: + """Get the check_hostname setting. + + Returns: + Always False (mock implementation) + """ return self._check_hostname @check_hostname.setter def check_hostname(self, _: bool) -> None: + """Set the check_hostname setting (mocked). + + Args: + _: Value (ignored, always set to False) + """ self._check_hostname = False def __init__(self, *args: Any, **kwargs: Any) -> None: + """Initialize the SSL context. + + Args: + *args: Positional arguments (ignored) + **kwargs: Keyword arguments (ignored) + """ self._set_dummy_methods() def _set_dummy_methods(self) -> None: + """Set all dummy methods that do nothing.""" + def dummy_method(*args: Any, **kwargs: Any) -> Any: pass @@ -55,15 +80,36 @@ def wrap_socket( *args: Any, **kwargs: Any, ) -> MocketSSLSocket: + """Wrap a socket in an SSL socket. + + Args: + sock: Socket to wrap + *args: Additional arguments + **kwargs: Additional keyword arguments + + Returns: + MocketSSLSocket instance + """ return MocketSSLSocket._create(sock, *args, **kwargs) def wrap_bio( self, - incoming: Any, # _ssl.MemoryBIO - outgoing: Any, # _ssl.MemoryBIO + incoming: Any, + outgoing: Any, server_side: bool = False, server_hostname: str | bytes | None = None, ) -> MocketSSLSocket: + """Wrap BIO objects in an SSL socket (mock implementation). + + Args: + incoming: Incoming BIO (_ssl.MemoryBIO) + outgoing: Outgoing BIO (_ssl.MemoryBIO) + server_side: Whether this is server side + server_hostname: Server hostname + + Returns: + MocketSSLSocket instance + """ ssl_obj = MocketSSLSocket() ssl_obj._host = server_hostname return ssl_obj @@ -74,5 +120,15 @@ def mock_wrap_socket( *args: Any, **kwargs: Any, ) -> MocketSSLSocket: + """Mock ssl.wrap_socket function. + + Args: + sock: Socket to wrap + *args: Additional arguments + **kwargs: Additional keyword arguments + + Returns: + MocketSSLSocket instance + """ context = MocketSSLContext() return context.wrap_socket(sock, *args, **kwargs) diff --git a/mocket/ssl/socket.py b/mocket/ssl/socket.py index 6dcd7817..94984fce 100644 --- a/mocket/ssl/socket.py +++ b/mocket/ssl/socket.py @@ -1,3 +1,5 @@ +"""Mocket SSL socket implementation.""" + from __future__ import annotations import ssl @@ -12,14 +14,33 @@ class MocketSSLSocket(MocketSocket): + """Mock SSL socket that extends MocketSocket with SSL-specific behavior.""" + def __init__(self, *args: Any, **kwargs: Any) -> None: + """Initialize an SSL socket. + + Args: + *args: Positional arguments + **kwargs: Keyword arguments + """ super().__init__(*args, **kwargs) - self._did_handshake = False - self._sent_non_empty_bytes = False + self._did_handshake: bool = False + self._sent_non_empty_bytes: bool = False self._original_socket: MocketSocket = self def read(self, buffersize: int | None = None) -> bytes: + """Read data from the SSL socket. + + Args: + buffersize: Maximum bytes to read + + Returns: + Bytes read from the socket + + Raises: + ssl.SSLWantReadError: If handshake not completed and no data + """ rv = self.io.read(buffersize) if rv: self._sent_non_empty_bytes = True @@ -28,12 +49,29 @@ def read(self, buffersize: int | None = None) -> bytes: return rv def write(self, data: bytes) -> int | None: + """Write data to the SSL socket. + + Args: + data: Bytes to write + + Returns: + Number of bytes written + """ return self.send(encode_to_bytes(data)) def do_handshake(self) -> None: + """Perform SSL handshake (mock implementation).""" self._did_handshake = True def getpeercert(self, binary_form: bool = False) -> _PeerCertRetDictType: + """Get the peer certificate (mock implementation). + + Args: + binary_form: Whether to return binary form (unused) + + Returns: + Mock certificate dictionary + """ if not (self._host and self._port): self._address = self._host, self._port = Mocket._address @@ -54,12 +92,27 @@ def getpeercert(self, binary_form: bool = False) -> _PeerCertRetDictType: } def ciper(self) -> tuple[str, str, str]: + """Get cipher information (mock implementation). + + Returns: + Tuple of (cipher_name, protocol, key_exchange_algorithm) + """ return "ADH", "AES256", "SHA" def compression(self) -> Options: + """Get compression options (mock implementation). + + Returns: + SSL options constant + """ return ssl.OP_NO_COMPRESSION def unwrap(self) -> MocketSocket: + """Unwrap the SSL socket and return the underlying socket. + + Returns: + The original MocketSocket + """ return self._original_socket @classmethod @@ -71,6 +124,18 @@ def _create( *args: Any, **kwargs: Any, ) -> MocketSSLSocket: + """Create an SSL socket from a regular socket. + + Args: + sock: Socket to wrap + ssl_context: SSL context (optional) + server_hostname: Server hostname + *args: Additional arguments + **kwargs: Additional keyword arguments + + Returns: + New MocketSSLSocket instance + """ ssl_socket = MocketSSLSocket() ssl_socket._original_socket = sock ssl_socket._true_socket = sock._true_socket diff --git a/mocket/types.py b/mocket/types.py index 562648c7..fedfd37f 100644 --- a/mocket/types.py +++ b/mocket/types.py @@ -1,3 +1,5 @@ +"""Type aliases and definitions for Mocket.""" + from __future__ import annotations from typing import Any, Dict, Tuple, Union diff --git a/mocket/urllib3.py b/mocket/urllib3.py index e89bc7b5..872efc5f 100644 --- a/mocket/urllib3.py +++ b/mocket/urllib3.py @@ -1,3 +1,5 @@ +"""Urllib3 specific socket mocking.""" + from __future__ import annotations from typing import Any @@ -8,6 +10,14 @@ def mock_match_hostname(*args: Any) -> None: + """Mock urllib3's match_hostname function. + + Args: + *args: Ignored arguments + + Returns: + None + """ return None @@ -16,5 +26,15 @@ def mock_ssl_wrap_socket( *args: Any, **kwargs: Any, ) -> MocketSSLSocket: + """Mock urllib3's ssl_wrap_socket function. + + Args: + sock: The socket to wrap + *args: Additional arguments + **kwargs: Additional keyword arguments + + Returns: + MocketSSLSocket instance + """ context = MocketSSLContext() return context.wrap_socket(sock, *args, **kwargs) diff --git a/mocket/utils.py b/mocket/utils.py index 6180ae3f..749b2b70 100644 --- a/mocket/utils.py +++ b/mocket/utils.py @@ -1,3 +1,5 @@ +"""Utility functions for Mocket.""" + from __future__ import annotations import binascii @@ -14,12 +16,13 @@ class MocketizeDecorator(Protocol): - """ + """Protocol for a flexible decorator that can be used in multiple ways. + This is a generic decorator signature, currently applicable to get_mocketize. - Decorators can be used as: + Decorators implementing this protocol can be used as: 1. A function that transforms func (the parameter) into func1 (the returned object). - 2. A function that takes keyword arguments and returns 1. + 2. A function that takes keyword arguments and returns a decorator. """ @overload @@ -32,18 +35,37 @@ def __call__( def hexdump(binary_string: bytes) -> str: - r""" - >>> hexdump(b"bar foobar foo") == decode_from_bytes(encode_to_bytes("62 61 72 20 66 6F 6F 62 61 72 20 66 6F 6F")) - True + """Convert binary data to space-separated hex string. + + Args: + binary_string: Binary data to convert + + Returns: + Space-separated hexadecimal representation + + Example: + >>> hexdump(b"bar foobar foo") == decode_from_bytes(encode_to_bytes("62 61 72 20 66 6F 6F 62 61 72 20 66 6F 6F")) + True """ bs = decode_from_bytes(binascii.hexlify(binary_string).upper()) return " ".join(a + b for a, b in zip(bs[::2], bs[1::2])) def hexload(string: str) -> bytes: - r""" - >>> hexload("62 61 72 20 66 6F 6F 62 61 72 20 66 6F 6F") == encode_to_bytes("bar foobar foo") - True + """Convert space-separated hex string to binary data. + + Args: + string: Space-separated hexadecimal string + + Returns: + Binary data + + Raises: + ValueError: If the hex string is invalid + + Example: + >>> hexload("62 61 72 20 66 6F 6F 62 61 72 20 66 6F 6F") == encode_to_bytes("bar foobar foo") + True """ string_no_spaces = "".join(string.split()) try: @@ -53,6 +75,18 @@ def hexload(string: str) -> bytes: def get_mocketize(wrapper_: Callable) -> MocketizeDecorator: + """Get a mocketize decorator from a wrapper function. + + Decorators can be used as: + 1. A function that transforms func (the parameter) into func1 (the returned object). + 2. A function that takes keyword arguments and returns 1. + + Args: + wrapper_: The wrapper function to convert to a decorator + + Returns: + A MocketizeDecorator instance that can be used as a flexible decorator + """ # trying to support different versions of `decorator` with contextlib.suppress(TypeError): return decorator.decorator(wrapper_, kwsyntax=True) # type: ignore[return-value, call-arg, unused-ignore] diff --git a/tests/test_pook.py b/tests/test_pook.py index 56721b5f..012fcdfb 100644 --- a/tests/test_pook.py +++ b/tests/test_pook.py @@ -3,7 +3,6 @@ with contextlib.suppress(ModuleNotFoundError): import pook import requests - from mocket.plugins.pook_mock_engine import MocketEngine pook.set_mock_engine(MocketEngine) diff --git a/tests/test_socket.py b/tests/test_socket.py index dad62a33..68e71aee 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -1,4 +1,6 @@ import socket +import struct +from unittest.mock import MagicMock import pytest @@ -126,3 +128,22 @@ def test_recvfrom_into(): assert nbytes == len(test_data) assert buf[:nbytes] == test_data assert addr == sock._address + + +def test_setsockopt_without_optlen(): + sock = MocketSocket(socket.AF_INET, socket.SOCK_STREAM) + sock._true_socket = MagicMock() + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock._true_socket.setsockopt.assert_called_once_with( + socket.SOL_SOCKET, socket.SO_REUSEADDR, 1 + ) + + +def test_setsockopt_with_optlen(): + sock = MocketSocket(socket.AF_INET, socket.SOCK_STREAM) + sock._true_socket = MagicMock() + linger_value = struct.pack("ii", 1, 5) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, linger_value, len(linger_value)) + sock._true_socket.setsockopt.assert_called_once_with( + socket.SOL_SOCKET, socket.SO_LINGER, linger_value, len(linger_value) + )