From a9ad9f588da20d0cb3611ae7201109bb7510b501 Mon Sep 17 00:00:00 2001
From: Giorgio Salluzzo
Date: Sun, 22 Feb 2026 20:54:59 +0100
Subject: [PATCH 01/14] Adding type hints and docstrings.
---
mocket/__init__.py | 2 +
mocket/compat.py | 36 +++-
mocket/decorators/async_mocket.py | 33 ++-
mocket/decorators/mocketizer.py | 120 +++++++++--
mocket/entry.py | 60 +++++-
mocket/exceptions.py | 7 +
mocket/inject.py | 17 ++
mocket/io.py | 24 ++-
mocket/mocket.py | 86 +++++++-
mocket/mocks/mockhttp.py | 284 +++++++++++++++++---------
mocket/mocks/mockredis.py | 124 ++++++++++--
mocket/mode.py | 23 ++-
mocket/recording.py | 80 +++++++-
mocket/socket.py | 321 ++++++++++++++++++++++++++++--
mocket/ssl/context.py | 70 ++++++-
mocket/ssl/socket.py | 69 ++++++-
mocket/types.py | 2 +
mocket/urllib3.py | 20 ++
mocket/utils.py | 52 ++++-
tests/test_asyncio.py | 4 +-
tests/test_http.py | 2 +-
tests/test_http_httpx.py | 1 -
tests/test_httpretty.py | 3 +-
tests/test_https.py | 2 +-
tests/test_httpx.py | 4 +-
tests/test_mocket.py | 2 +-
tests/test_mode.py | 4 +-
tests/test_pook.py | 1 -
tests/test_redis.py | 2 +-
tests/test_socket.py | 2 +-
tests/test_utils.py | 1 -
31 files changed, 1252 insertions(+), 206 deletions(-)
diff --git a/mocket/__init__.py b/mocket/__init__.py
index 857ed2ed..2103f97e 100644
--- a/mocket/__init__.py
+++ b/mocket/__init__.py
@@ -1,3 +1,5 @@
+"""Mocket - socket mocking library for Python."""
+
import importlib
import sys
diff --git a/mocket/compat.py b/mocket/compat.py
index 1ac2fc89..a8e726f6 100644
--- a/mocket/compat.py
+++ b/mocket/compat.py
@@ -11,23 +11,57 @@
def encode_to_bytes(s: str | bytes, encoding: str = ENCODING) -> bytes:
+ """Encode a string or bytes to bytes.
+
+ Args:
+ s: String or bytes to encode
+ encoding: Encoding to use (default: utf-8 or MOCKET_ENCODING env var)
+
+ Returns:
+ Encoded bytes
+ """
if isinstance(s, str):
s = s.encode(encoding)
return bytes(s)
def decode_from_bytes(s: str | bytes, encoding: str = ENCODING) -> str:
+ """Decode bytes or string to string.
+
+ Args:
+ s: String or bytes to decode
+ encoding: Encoding to use (default: utf-8 or MOCKET_ENCODING env var)
+
+ Returns:
+ Decoded string
+ """
if isinstance(s, bytes):
s = codecs.decode(s, encoding, "ignore")
return str(s)
def shsplit(s: str | bytes) -> list[str]:
+ """Split a shell command string into arguments.
+
+ Args:
+ s: Shell command string or bytes
+
+ Returns:
+ List of shell command arguments
+ """
s = decode_from_bytes(s)
return shlex.split(s)
-def do_the_magic(body):
+def do_the_magic(body: bytes) -> str:
+ """Detect MIME type of binary data using puremagic.
+
+ Args:
+ body: Binary data to analyze
+
+ Returns:
+ MIME type string
+ """
try:
magic = puremagic.magic_string(body)
except puremagic.PureError:
diff --git a/mocket/decorators/async_mocket.py b/mocket/decorators/async_mocket.py
index 3839d5f1..53b966c0 100644
--- a/mocket/decorators/async_mocket.py
+++ b/mocket/decorators/async_mocket.py
@@ -1,15 +1,34 @@
+"""Async version of Mocket decorator."""
+
+from __future__ import annotations
+
+from typing import Any, Callable
+
from mocket.decorators.mocketizer import Mocketizer
from mocket.utils import get_mocketize
async def wrapper(
- test,
- truesocket_recording_dir=None,
- strict_mode=False,
- strict_mode_allowed=None,
- *args,
- **kwargs,
-):
+ test: Callable,
+ truesocket_recording_dir: str | None = None,
+ strict_mode: bool = False,
+ strict_mode_allowed: list | None = None,
+ *args: Any,
+ **kwargs: Any,
+) -> Any:
+ """Async wrapper function for @async_mocketize decorator.
+
+ Args:
+ test: Async test function to wrap
+ truesocket_recording_dir: Directory for recording true socket calls
+ strict_mode: Enable STRICT mode to forbid real socket calls
+ strict_mode_allowed: List of allowed hosts in STRICT mode
+ *args: Test arguments
+ **kwargs: Test keyword arguments
+
+ Returns:
+ Result of the test function
+ """
async with Mocketizer.factory(
test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args
):
diff --git a/mocket/decorators/mocketizer.py b/mocket/decorators/mocketizer.py
index fb7c811b..b067ffdf 100644
--- a/mocket/decorators/mocketizer.py
+++ b/mocket/decorators/mocketizer.py
@@ -1,17 +1,34 @@
+"""Mocketizer decorator for managing Mocket lifecycle in tests."""
+
+from __future__ import annotations
+
+from typing import Any, Callable
+
from mocket.mocket import Mocket
from mocket.mode import MocketMode
from mocket.utils import get_mocketize
class Mocketizer:
+ """Context manager and decorator for managing Mocket lifecycle in tests."""
+
def __init__(
self,
- instance=None,
- namespace=None,
- truesocket_recording_dir=None,
- strict_mode=False,
- strict_mode_allowed=None,
- ):
+ instance: Any | None = None,
+ namespace: str | None = None,
+ truesocket_recording_dir: str | None = None,
+ strict_mode: bool = False,
+ strict_mode_allowed: list | None = None,
+ ) -> None:
+ """Initialize the Mocketizer.
+
+ Args:
+ instance: Test instance (optional)
+ namespace: Namespace for recordings
+ truesocket_recording_dir: Directory for recording true socket calls
+ strict_mode: Enable STRICT mode to forbid real socket calls
+ strict_mode_allowed: List of allowed hosts in STRICT mode
+ """
self.instance = instance
self.truesocket_recording_dir = truesocket_recording_dir
self.namespace = namespace or str(id(self))
@@ -23,7 +40,8 @@ def __init__(
"Allowed locations are only accepted when STRICT mode is active."
)
- def enter(self):
+ def enter(self) -> None:
+ """Enter the Mocketizer context (enable Mocket)."""
Mocket.enable(
namespace=self.namespace,
truesocket_recording_dir=self.truesocket_recording_dir,
@@ -31,33 +49,80 @@ def enter(self):
if self.instance:
self.check_and_call("mocketize_setup")
- def __enter__(self):
+ def __enter__(self) -> Mocketizer:
+ """Enter context manager.
+
+ Returns:
+ Self for use in `with` statements
+ """
self.enter()
return self
- def exit(self):
+ def exit(self) -> None:
+ """Exit the Mocketizer context (disable Mocket)."""
if self.instance:
self.check_and_call("mocketize_teardown")
Mocket.disable()
- def __exit__(self, type, value, tb):
+ def __exit__(self, type: Any, value: Any, tb: Any) -> None:
+ """Exit context manager.
+
+ Args:
+ type: Exception type
+ value: Exception value
+ tb: Traceback
+ """
self.exit()
- async def __aenter__(self, *args, **kwargs):
+ async def __aenter__(self, *args: Any, **kwargs: Any) -> Mocketizer:
+ """Enter async context manager.
+
+ Returns:
+ Self for use in `async with` statements
+ """
self.enter()
return self
- async def __aexit__(self, *args, **kwargs):
+ async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
+ """Exit async context manager.
+
+ Args:
+ *args: Exception arguments
+ **kwargs: Exception keyword arguments
+ """
self.exit()
- def check_and_call(self, method_name):
+ def check_and_call(self, method_name: str) -> None:
+ """Check if instance has a method and call it.
+
+ Args:
+ method_name: Name of method to check and call
+ """
method = getattr(self.instance, method_name, None)
if callable(method):
method()
@staticmethod
- def factory(test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args):
+ def factory(
+ test: Callable,
+ truesocket_recording_dir: str | None,
+ strict_mode: bool,
+ strict_mode_allowed: list | None,
+ args: tuple,
+ ) -> Mocketizer:
+ """Create a Mocketizer instance for a test function.
+
+ Args:
+ test: Test function being decorated
+ truesocket_recording_dir: Recording directory
+ strict_mode: Enable STRICT mode
+ strict_mode_allowed: Allowed hosts in STRICT mode
+ args: Positional arguments to test
+
+ Returns:
+ Configured Mocketizer instance
+ """
instance = args[0] if args else None
namespace = None
if truesocket_recording_dir:
@@ -79,13 +144,26 @@ def factory(test, truesocket_recording_dir, strict_mode, strict_mode_allowed, ar
def wrapper(
- test,
- truesocket_recording_dir=None,
- strict_mode=False,
- strict_mode_allowed=None,
- *args,
- **kwargs,
-):
+ test: Callable,
+ truesocket_recording_dir: str | None = None,
+ strict_mode: bool = False,
+ strict_mode_allowed: list | None = None,
+ *args: Any,
+ **kwargs: Any,
+) -> Any:
+ """Wrapper function for @mocketize decorator.
+
+ Args:
+ test: Test function to wrap
+ truesocket_recording_dir: Recording directory
+ strict_mode: Enable STRICT mode
+ strict_mode_allowed: Allowed hosts in STRICT mode
+ *args: Test arguments
+ **kwargs: Test keyword arguments
+
+ Returns:
+ Result of the test function
+ """
with Mocketizer.factory(
test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args
):
diff --git a/mocket/entry.py b/mocket/entry.py
index 9dbbf442..2d618472 100644
--- a/mocket/entry.py
+++ b/mocket/entry.py
@@ -1,22 +1,38 @@
+"""Mocket entry base class for registering mock responses."""
+
+from __future__ import annotations
+
import collections.abc
+from typing import Any
from mocket.compat import encode_to_bytes
from mocket.mocket import Mocket
class MocketEntry:
+ """Base class for Mocket entries that match requests and return responses."""
+
class Response(bytes):
+ """Response wrapper class that extends bytes."""
+
@property
- def data(self):
+ def data(self) -> bytes:
+ """Get the response data."""
return self
- response_index = 0
- request_cls = bytes
- response_cls = Response
- responses = None
- _served = None
+ response_index: int = 0
+ request_cls: type = bytes
+ response_cls: type = Response
+ responses: list | None = None
+ _served: bool | None = None
+
+ def __init__(self, location: tuple, responses: Any) -> None:
+ """Initialize a Mocket entry.
- def __init__(self, location, responses):
+ Args:
+ location: Tuple of (host, port)
+ responses: Single response or list of responses to cycle through
+ """
self._served = False
self.location = location
@@ -34,18 +50,40 @@ def __init__(self, location, responses):
r = self.response_cls(r)
self.responses.append(r)
- def __repr__(self):
+ def __repr__(self) -> str:
+ """Return a string representation of the entry."""
return f"{self.__class__.__name__}(location={self.location})"
@staticmethod
- def can_handle(data):
+ def can_handle(data: bytes) -> bool:
+ """Check if this entry can handle the given request data.
+
+ Args:
+ data: Request data to check
+
+ Returns:
+ True if this entry can handle the request, False otherwise
+ """
return True
- def collect(self, data):
+ def collect(self, data: bytes) -> None:
+ """Collect the request data in the Mocket singleton.
+
+ Args:
+ data: Request data to collect
+ """
req = self.request_cls(data)
Mocket.collect(req)
- def get_response(self):
+ def get_response(self) -> bytes:
+ """Get the next response to send.
+
+ Returns:
+ Response bytes to send to the client
+
+ Raises:
+ BaseException: If a response is an exception, it will be raised
+ """
response = self.responses[self.response_index]
if self.response_index < len(self.responses) - 1:
self.response_index += 1
diff --git a/mocket/exceptions.py b/mocket/exceptions.py
index f5537568..db78dbf5 100644
--- a/mocket/exceptions.py
+++ b/mocket/exceptions.py
@@ -1,6 +1,13 @@
+"""Mocket exception classes."""
+
+
class MocketException(Exception):
+ """Base exception class for Mocket errors."""
+
pass
class StrictMocketException(MocketException):
+ """Exception raised when a socket operation is not allowed in STRICT mode."""
+
pass
diff --git a/mocket/inject.py b/mocket/inject.py
index 866ee563..e788a929 100644
--- a/mocket/inject.py
+++ b/mocket/inject.py
@@ -1,3 +1,5 @@
+"""Socket patching and restoration for Mocket injection."""
+
from __future__ import annotations
import contextlib
@@ -12,17 +14,31 @@
def _patch(module: ModuleType, name: str, patched_value: Any) -> None:
+ """Patch a module with a new value and store the original.
+
+ Args:
+ module: Module to patch
+ name: Attribute name to patch
+ patched_value: New value to set
+ """
with contextlib.suppress(KeyError):
original_value, module.__dict__[name] = module.__dict__[name], patched_value
_patches_restore[(module, name)] = original_value
def _restore(module: ModuleType, name: str) -> None:
+ """Restore a module's original attribute value.
+
+ Args:
+ module: Module to restore
+ name: Attribute name to restore
+ """
if original_value := _patches_restore.pop((module, name)):
module.__dict__[name] = original_value
def enable() -> None:
+ """Enable Mocket by patching socket, ssl, and urllib3 modules."""
from mocket.socket import (
MocketSocket,
mock_create_connection,
@@ -71,6 +87,7 @@ def enable() -> None:
def disable() -> None:
+ """Disable Mocket by restoring all patched modules."""
for module, name in list(_patches_restore.keys()):
_restore(module, name)
diff --git a/mocket/io.py b/mocket/io.py
index 0334410b..e815e0ec 100644
--- a/mocket/io.py
+++ b/mocket/io.py
@@ -1,3 +1,7 @@
+"""Mocket socket I/O implementation."""
+
+from __future__ import annotations
+
import io
import os
@@ -5,13 +9,29 @@
class MocketSocketIO(io.BytesIO):
- def __init__(self, address) -> None:
+ """A BytesIO wrapper that integrates with Mocket's pipe-based I/O."""
+
+ def __init__(self, address: tuple) -> None:
+ """Initialize the socket I/O with a socket address.
+
+ Args:
+ address: Tuple of (host, port)
+ """
self._address = address
super().__init__()
- def write(self, content):
+ def write(self, content: bytes) -> int:
+ """Write content to the buffer and the pipe if available.
+
+ Args:
+ content: Bytes to write
+
+ Returns:
+ Number of bytes written
+ """
super().write(content)
_, w_fd = Mocket.get_pair(self._address)
if w_fd:
os.write(w_fd, content)
+ return len(content)
diff --git a/mocket/mocket.py b/mocket/mocket.py
index a8dc7997..75ae6285 100644
--- a/mocket/mocket.py
+++ b/mocket/mocket.py
@@ -1,3 +1,5 @@
+"""Core Mocket singleton for socket mocking management."""
+
from __future__ import annotations
import collections
@@ -18,6 +20,8 @@
class Mocket:
+ """Singleton class managing all mock socket operations and entries."""
+
_socket_pairs: ClassVar[dict[Address, tuple[int, int]]] = {}
_address: ClassVar[Address | tuple[None, None]] = (None, None)
_entries: ClassVar[dict[Address, list[MocketEntry]]] = collections.defaultdict(list)
@@ -30,6 +34,12 @@ def enable(
namespace: str | None = None,
truesocket_recording_dir: str | None = None,
) -> None:
+ """Enable Mocket socket mocking.
+
+ Args:
+ namespace: Namespace for recording storage (defaults to id of _entries)
+ truesocket_recording_dir: Directory to store recorded requests/responses
+ """
if namespace is None:
namespace = str(id(cls._entries))
@@ -47,33 +57,61 @@ def enable(
@classmethod
def disable(cls) -> None:
+ """Disable Mocket socket mocking and clean up resources."""
cls.reset()
mocket.inject.disable()
@classmethod
def get_pair(cls, address: Address) -> tuple[int, int] | tuple[None, None]:
- """
+ """Get the file descriptor pair for a socket address.
+
Given the id() of the caller, return a pair of file descriptors
as a tuple of two integers: (, )
+
+ Args:
+ address: (host, port) tuple
+
+ Returns:
+ Tuple of (read_fd, write_fd) or (None, None) if not found
"""
return cls._socket_pairs.get(address, (None, None))
@classmethod
def set_pair(cls, address: Address, pair: tuple[int, int]) -> None:
- """
- Store a pair of file descriptors under the key `id_`
+ """Store a file descriptor pair for a socket address.
+
+ Store a pair of file descriptors under the key `address`
as a tuple of two integers: (, )
+
+ Args:
+ address: (host, port) tuple
+ pair: Tuple of (read_fd, write_fd)
"""
cls._socket_pairs[address] = pair
@classmethod
def register(cls, *entries: MocketEntry) -> None:
+ """Register mock entries with Mocket.
+
+ Args:
+ *entries: Variable number of MocketEntry instances to register
+ """
for entry in entries:
cls._entries[entry.location].append(entry)
@classmethod
- def get_entry(cls, host: str, port: int, data) -> MocketEntry | None:
+ def get_entry(cls, host: str, port: int, data: Any) -> MocketEntry | None:
+ """Get a matching entry for the given request data.
+
+ Args:
+ host: Hostname
+ port: Port number
+ data: Request data
+
+ Returns:
+ Matching MocketEntry or None
+ """
host = host or cls._address[0]
port = port or cls._address[1]
entries = cls._entries.get((host, port), [])
@@ -83,11 +121,17 @@ def get_entry(cls, host: str, port: int, data) -> MocketEntry | None:
return None
@classmethod
- def collect(cls, data) -> None:
+ def collect(cls, data: Any) -> None:
+ """Collect a request in the list of all requests.
+
+ Args:
+ data: Request data to collect
+ """
cls._requests.append(data)
@classmethod
def reset(cls) -> None:
+ """Reset all Mocket state and clean up file descriptors."""
for r_fd, w_fd in cls._socket_pairs.values():
os.close(r_fd)
os.close(w_fd)
@@ -98,32 +142,62 @@ def reset(cls) -> None:
@classmethod
def last_request(cls) -> Any:
+ """Get the last request made.
+
+ Returns:
+ Last request data or None if no requests
+ """
if cls.has_requests():
return cls._requests[-1]
@classmethod
def request_list(cls) -> list[Any]:
+ """Get the list of all requests.
+
+ Returns:
+ List of all collected requests
+ """
return cls._requests
@classmethod
def remove_last_request(cls) -> None:
+ """Remove the last request from the request list."""
if cls.has_requests():
del cls._requests[-1]
@classmethod
def has_requests(cls) -> bool:
+ """Check if any requests have been made.
+
+ Returns:
+ True if there are requests, False otherwise
+ """
return bool(cls.request_list())
@classmethod
def get_namespace(cls) -> str | None:
+ """Get the recording namespace.
+
+ Returns:
+ Namespace string or None if recording is not enabled
+ """
return cls._record_storage.namespace if cls._record_storage else None
@classmethod
def get_truesocket_recording_dir(cls) -> str | None:
+ """Get the true socket recording directory.
+
+ Returns:
+ Directory path as string or None if recording is not enabled
+ """
return str(cls._record_storage.directory) if cls._record_storage else None
@classmethod
def assert_fail_if_entries_not_served(cls) -> None:
- """Mocket checks that all entries have been served at least once."""
+ """Assert that all registered entries have been served at least once.
+
+ Raises:
+ AssertionError: If any entries have not been served
+ """
if not all(entry._served for entry in itertools.chain(*cls._entries.values())):
raise AssertionError("Some Mocket entries have not been served")
diff --git a/mocket/mocks/mockhttp.py b/mocket/mocks/mockhttp.py
index 50a6f952..e7e5a7b9 100644
--- a/mocket/mocks/mockhttp.py
+++ b/mocket/mocks/mockhttp.py
@@ -1,8 +1,12 @@
+"""HTTP mocking implementation for Mocket."""
+
+from __future__ import annotations
+
import re
import time
from functools import cached_property
from http.server import BaseHTTPRequestHandler
-from typing import Callable, Optional
+from typing import Any, Callable
from urllib.parse import parse_qs, unquote, urlsplit
from h11 import SERVER, Connection, Data
@@ -12,42 +16,79 @@
from mocket.entry import MocketEntry
from mocket.mocket import Mocket
-STATUS = {k: v[0] for k, v in BaseHTTPRequestHandler.responses.items()}
-CRLF = "\r\n"
-ASCII = "ascii"
+STATUS: dict = {k: v[0] for k, v in BaseHTTPRequestHandler.responses.items()}
+CRLF: str = "\r\n"
+ASCII: str = "ascii"
class Request:
- _parser = None
- _event = None
+ """HTTP request parser using h11."""
+
+ _parser: Connection | None = None
+ _event: Any | None = None
+
+ def __init__(self, data: bytes) -> None:
+ """Initialize the request parser.
- def __init__(self, data):
+ Args:
+ data: Raw HTTP request data
+ """
self._parser = Connection(SERVER)
self.add_data(data)
- def add_data(self, data):
+ def add_data(self, data: bytes) -> None:
+ """Add more data to the request.
+
+ Args:
+ data: Additional raw request data
+ """
self._parser.receive_data(data)
@property
- def event(self):
+ def event(self) -> Any:
+ """Get the parsed request event.
+
+ Returns:
+ The h11 request event
+ """
if not self._event:
self._event = self._parser.next_event()
return self._event
@cached_property
- def method(self):
+ def method(self) -> str:
+ """Get the HTTP method.
+
+ Returns:
+ HTTP method (GET, POST, etc.)
+ """
return self.event.method.decode(ASCII)
@cached_property
- def path(self):
+ def path(self) -> str:
+ """Get the request path.
+
+ Returns:
+ Request path with query string
+ """
return self.event.target.decode(ASCII)
@cached_property
- def headers(self):
+ def headers(self) -> dict:
+ """Get the request headers.
+
+ Returns:
+ Dictionary of header names to values
+ """
return {k.decode(ASCII): v.decode(ASCII) for k, v in self.event.headers}
@cached_property
- def querystring(self):
+ def querystring(self) -> dict:
+ """Get the parsed query string.
+
+ Returns:
+ Dictionary of query parameter names to lists of values
+ """
parts = self.path.split("?", 1)
return (
parse_qs(unquote(parts[1]), keep_blank_values=True)
@@ -56,7 +97,12 @@ def querystring(self):
)
@cached_property
- def body(self):
+ def body(self) -> str:
+ """Get the request body.
+
+ Returns:
+ Decoded request body string
+ """
while True:
event = self._parser.next_event()
if isinstance(event, H11Request):
@@ -64,15 +110,31 @@ def body(self):
elif isinstance(event, Data):
return event.data.decode(ENCODING)
- def __str__(self):
+ def __str__(self) -> str:
+ """Get string representation of request.
+
+ Returns:
+ Formatted request string
+ """
return f"{self.method} - {self.path} - {self.headers}"
class Response:
- headers = None
- is_file_object = False
+ """HTTP response builder."""
+
+ headers: dict | None = None
+ is_file_object: bool = False
- def __init__(self, body="", status=200, headers=None):
+ def __init__(
+ self, body: Any = "", status: int = 200, headers: dict | None = None
+ ) -> None:
+ """Initialize an HTTP response.
+
+ Args:
+ body: Response body (string, bytes, or file-like object)
+ status: HTTP status code
+ headers: Dictionary of response headers
+ """
headers = headers or {}
try:
# File Objects
@@ -88,6 +150,14 @@ def __init__(self, body="", status=200, headers=None):
self.data = self.get_protocol_data() + self.body
def get_protocol_data(self, str_format_fun_name: str = "capitalize") -> bytes:
+ """Get the HTTP protocol headers and status line.
+
+ Args:
+ str_format_fun_name: Name of string formatting method to use
+
+ Returns:
+ Bytes of protocol headers (status line and headers)
+ """
status_line = f"HTTP/1.1 {self.status} {STATUS[self.status]}"
header_lines = CRLF.join(
(
@@ -97,7 +167,8 @@ def get_protocol_data(self, str_format_fun_name: str = "capitalize") -> bytes:
)
return f"{status_line}\r\n{header_lines}\r\n\r\n".encode(ENCODING)
- def set_base_headers(self):
+ def set_base_headers(self) -> None:
+ """Set the base response headers."""
self.headers = {
"Status": str(self.status),
"Date": time.strftime("%a, %d %b %Y %H:%M:%S GMT", time.gmtime()),
@@ -110,22 +181,19 @@ def set_base_headers(self):
else:
self.headers["Content-Type"] = do_the_magic(self.body)
- def set_extra_headers(self, headers):
- r"""
- >>> r = Response(body="")
- >>> len(r.headers.keys())
- 6
- >>> r.set_extra_headers({"foo-bar": "Foobar"})
- >>> len(r.headers.keys())
- 7
- >>> encode_to_bytes(r.headers.get("Foo-Bar")) == encode_to_bytes("Foobar")
- True
+ def set_extra_headers(self, headers: dict) -> None:
+ """Add extra headers to the response.
+
+ Args:
+ headers: Dictionary of additional headers
"""
for k, v in headers.items():
self.headers["-".join(token.capitalize() for token in k.split("-"))] = v
class Entry(MocketEntry):
+ """HTTP entry for matching and responding to HTTP requests."""
+
CONNECT = "CONNECT"
DELETE = "DELETE"
GET = "GET"
@@ -136,22 +204,31 @@ class Entry(MocketEntry):
PUT = "PUT"
TRACE = "TRACE"
- METHODS = (CONNECT, DELETE, GET, HEAD, OPTIONS, PATCH, POST, PUT, TRACE)
+ METHODS: tuple = (CONNECT, DELETE, GET, HEAD, OPTIONS, PATCH, POST, PUT, TRACE)
- request_cls = Request
- response_cls = Response
+ request_cls: type = Request
+ response_cls: type = Response
- default_config = {"match_querystring": True, "can_handle_fun": None}
- _can_handle_fun: Optional[Callable] = None
+ default_config: dict = {"match_querystring": True, "can_handle_fun": None}
+ _can_handle_fun: Callable | None = None
def __init__(
self,
- uri,
- method,
- responses,
+ uri: str,
+ method: str,
+ responses: Any,
match_querystring: bool = True,
- can_handle_fun: Optional[Callable] = None,
- ):
+ can_handle_fun: Callable | None = None,
+ ) -> None:
+ """Initialize an HTTP entry.
+
+ Args:
+ uri: URI to match (http://host:port/path?query)
+ method: HTTP method (GET, POST, etc.)
+ responses: Response(s) to return
+ match_querystring: Whether to match query strings
+ can_handle_fun: Custom matching function
+ """
self._can_handle_fun = can_handle_fun if can_handle_fun else self._can_handle
uri = urlsplit(uri)
@@ -168,10 +245,23 @@ def __init__(
self._sent_data = b""
self._match_querystring = match_querystring
- def __repr__(self):
+ def __repr__(self) -> str:
+ """Get string representation of the entry.
+
+ Returns:
+ String representation
+ """
return f"{self.__class__.__name__}(method={self.method!r}, schema={self.schema!r}, location={self.location!r}, path={self.path!r}, query={self.query!r})"
- def collect(self, data):
+ def collect(self, data: bytes) -> bool:
+ """Collect the request data.
+
+ Args:
+ data: Request data
+
+ Returns:
+ Whether to consume the response
+ """
consume_response = True
decoded_data = decode_from_bytes(data)
@@ -187,9 +277,14 @@ def collect(self, data):
return consume_response
def _can_handle(self, path: str, qs_dict: dict) -> bool:
- """
- The default can_handle function, which checks if the path match,
- and if match_querystring is True, also checks if the querystring matches.
+ """Default can_handle function checking path and query string.
+
+ Args:
+ path: Request path
+ qs_dict: Parsed query string parameters
+
+ Returns:
+ True if this entry can handle the request
"""
can_handle = path == self.path
if self._match_querystring:
@@ -198,14 +293,14 @@ def _can_handle(self, path: str, qs_dict: dict) -> bool:
)
return can_handle
- def can_handle(self, data):
- r"""
- >>> e = Entry('http://www.github.com/?bar=foo&foobar', Entry.GET, (Response(b''),))
- >>> e.can_handle(b'GET /?bar=foo HTTP/1.1\r\nHost: github.com\r\nAccept-Encoding: gzip, deflate\r\nConnection: keep-alive\r\nUser-Agent: python-requests/2.7.0 CPython/3.4.3 Linux/3.19.0-16-generic\r\nAccept: */*\r\n\r\n')
- False
- >>> e = Entry('http://www.github.com/?bar=foo&foobar', Entry.GET, (Response(b''),))
- >>> e.can_handle(b'GET /?bar=foo&foobar HTTP/1.1\r\nHost: github.com\r\nAccept-Encoding: gzip, deflate\r\nConnection: keep-alive\r\nUser-Agent: python-requests/2.7.0 CPython/3.4.3 Linux/3.19.0-16-generic\r\nAccept: */*\r\n\r\n')
- True
+ def can_handle(self, data: bytes) -> bool:
+ """Check if this entry can handle the given request data.
+
+ Args:
+ data: Request data
+
+ Returns:
+ True if this entry can handle the request
"""
try:
requestline, _ = decode_from_bytes(data).split(CRLF, 1)
@@ -224,18 +319,17 @@ def can_handle(self, data):
return can_handle
@staticmethod
- def _parse_requestline(line):
- """
- http://www.w3.org/Protocols/rfc2616/rfc2616-sec5.html#sec5
-
- >>> Entry._parse_requestline('GET / HTTP/1.0') == ('GET', '/', '1.0')
- True
- >>> Entry._parse_requestline('post /testurl htTP/1.1') == ('POST', '/testurl', '1.1')
- True
- >>> Entry._parse_requestline('Im not a RequestLine')
- Traceback (most recent call last):
- ...
- ValueError: Not a Request-Line
+ def _parse_requestline(line: str) -> tuple:
+ """Parse an HTTP request line.
+
+ Args:
+ line: HTTP request line string
+
+ Returns:
+ Tuple of (method, path, version)
+
+ Raises:
+ ValueError: If line is not a valid request line
"""
m = re.match(
r"({})\s+(.*)\s+HTTP/(1.[0|1])".format("|".join(Entry.METHODS)), line, re.I
@@ -245,7 +339,19 @@ def _parse_requestline(line):
raise ValueError("Not a Request-Line")
@classmethod
- def register(cls, method, uri, *responses, **config):
+ def register(cls, method: str, uri: str, *responses: Any, **config: Any) -> None:
+ """Register an HTTP entry for multiple responses.
+
+ Args:
+ method: HTTP method (GET, POST, etc.)
+ uri: URI to match
+ *responses: Response(s) to cycle through
+ **config: Configuration options (match_querystring, can_handle_fun)
+
+ Raises:
+ AttributeError: If using body/status params (use single_register instead)
+ KeyError: If invalid config keys provided
+ """
if "body" in config or "status" in config:
raise AttributeError("Did you mean `Entry.single_register(...)`?")
@@ -262,33 +368,31 @@ def register(cls, method, uri, *responses, **config):
@classmethod
def single_register(
cls,
- method,
- uri,
- body="",
- status=200,
- headers=None,
- exception=None,
- match_querystring=True,
- can_handle_fun=None,
- **config,
- ):
- """
- A helper method to register a single Response for a given URI and method.
- Instead of passing a list of Response objects, you can just pass the response
- parameters directly.
+ method: str,
+ uri: str,
+ body: Any = "",
+ status: int = 200,
+ headers: dict | None = None,
+ exception: Exception | None = None,
+ match_querystring: bool = True,
+ can_handle_fun: Callable | None = None,
+ **config: Any,
+ ) -> None:
+ """Register a single HTTP response for a URI and method.
+
+ This is a convenience method that creates a single Response object
+ instead of requiring a list.
Args:
- method (str): The HTTP method (e.g., 'GET', 'POST').
- uri (str): The URI to register the response for.
- body (str, optional): The body of the response. Defaults to an empty string.
- status (int, optional): The HTTP status code. Defaults to 200.
- headers (dict, optional): A dictionary of headers to include in the response. Defaults to None.
- exception (Exception, optional): An exception to raise instead of returning a response. Defaults to None.
- match_querystring (bool, optional): Whether to match the querystring in the URI. Defaults to True.
- can_handle_fun (Callable, optional): A custom function to determine if the Entry can handle a request.
- Defaults to None. If None, the default matching logic is used. The function should accept two parameters:
- path (str), and querystring params (dict), and return a boolean. Method is matched before the function call.
- **config: Additional configuration options.
+ method: HTTP method (GET, POST, etc.)
+ uri: URI to match
+ body: Response body content
+ status: HTTP status code
+ headers: Dictionary of response headers
+ exception: Exception to raise instead of returning response
+ match_querystring: Whether to match query strings
+ can_handle_fun: Custom matching function
+ **config: Additional configuration options
"""
response = (
exception
diff --git a/mocket/mocks/mockredis.py b/mocket/mocks/mockredis.py
index fc386e2d..eee2d6c8 100644
--- a/mocket/mocks/mockredis.py
+++ b/mocket/mocks/mockredis.py
@@ -1,4 +1,9 @@
+"""Redis mocking implementation for Mocket."""
+
+from __future__ import annotations
+
from itertools import chain
+from typing import Any
from mocket.compat import (
decode_from_bytes,
@@ -7,29 +12,63 @@
)
from mocket.entry import MocketEntry
from mocket.mocket import Mocket
+from mocket.types import Address
class Request:
- def __init__(self, data):
+ """Redis request wrapper."""
+
+ def __init__(self, data: bytes) -> None:
+ """Initialize a Redis request.
+
+ Args:
+ data: Raw Redis command data
+ """
self.data = data
class Response:
- def __init__(self, data=None):
+ """Redis response wrapper."""
+
+ def __init__(self, data: Any = None) -> None:
+ """Initialize a Redis response.
+
+ Args:
+ data: Response data (will be "redisize"d)
+ """
self.data = Redisizer.redisize(data or OK)
class Redisizer(bytes):
+ """Convert Python types to Redis protocol format."""
+
@staticmethod
- def tokens(iterable):
+ def tokens(iterable: list[Any]) -> list[bytes]:
+ """Convert an iterable to Redis tokens.
+
+ Args:
+ iterable: List of items to convert
+
+ Returns:
+ List of Redis protocol bytes
+ """
iterable = [encode_to_bytes(x) for x in iterable]
return [f"*{len(iterable)}".encode()] + list(
chain(*zip([f"${len(x)}".encode() for x in iterable], iterable))
)
@staticmethod
- def redisize(data):
- def get_conversion(t):
+ def redisize(data: Any) -> Redisizer:
+ """Convert Python data to Redis protocol format.
+
+ Args:
+ data: Python data to convert
+
+ Returns:
+ Redisizer bytes
+ """
+
+ def get_conversion(t: type) -> Any:
return {
dict: lambda x: b"\r\n".join(
Redisizer.tokens(list(chain(*tuple(x.items()))))
@@ -48,11 +87,28 @@ def get_conversion(t):
return Redisizer(get_conversion(data.__class__)(data) + b"\r\n")
@staticmethod
- def command(description, _type="+"):
+ def command(description: str, _type: str = "+") -> Redisizer:
+ """Create a Redis command response.
+
+ Args:
+ description: Response description
+ _type: Response type prefix (+, -, :, $, *)
+
+ Returns:
+ Formatted Redis response
+ """
return Redisizer("{}{}{}".format(_type, description, "\r\n").encode("utf-8"))
@staticmethod
- def error(description):
+ def error(description: str) -> Redisizer:
+ """Create a Redis error response.
+
+ Args:
+ description: Error description
+
+ Returns:
+ Formatted Redis error response
+ """
return Redisizer.command(description, _type="-")
@@ -62,20 +118,46 @@ def error(description):
class Entry(MocketEntry):
+ """Redis entry for matching and responding to Redis commands."""
+
request_cls = Request
response_cls = Response
- def __init__(self, addr, command, responses):
+ def __init__(
+ self, addr: Address | None, command: str, responses: list[Any]
+ ) -> None:
+ """Initialize a Redis entry.
+
+ Args:
+ addr: (host, port) tuple or None for default
+ command: Redis command string to match
+ responses: List of responses to cycle through
+ """
super().__init__(addr or ("localhost", 6379), responses)
d = shsplit(command)
d[0] = d[0].upper()
self.command = Redisizer.tokens(d)
- def can_handle(self, data):
+ def can_handle(self, data: bytes) -> bool:
+ """Check if this entry can handle the given command.
+
+ Args:
+ data: Raw Redis command data
+
+ Returns:
+ True if this entry matches the command
+ """
return data.splitlines() == self.command
@classmethod
- def register(cls, addr, command, *responses):
+ def register(cls, addr: Address | None, command: str, *responses: Any) -> None:
+ """Register a Redis entry.
+
+ Args:
+ addr: (host, port) tuple or None for default
+ command: Redis command to match
+ *responses: Responses to cycle through
+ """
responses = [
r if isinstance(r, BaseException) else cls.response_cls(r)
for r in responses
@@ -83,9 +165,27 @@ def register(cls, addr, command, *responses):
Mocket.register(cls(addr, command, responses))
@classmethod
- def register_response(cls, command, response, addr=None):
+ def register_response(
+ cls, command: str, response: Any, addr: Address | None = None
+ ) -> None:
+ """Register a single response for a command.
+
+ Args:
+ command: Redis command to match
+ response: Response to return
+ addr: (host, port) tuple or None for default
+ """
cls.register(addr, command, response)
@classmethod
- def register_responses(cls, command, responses, addr=None):
+ def register_responses(
+ cls, command: str, responses: list[Any], addr: Address | None = None
+ ) -> None:
+ """Register multiple responses for a command.
+
+ Args:
+ command: Redis command to match
+ responses: List of responses to cycle through
+ addr: (host, port) tuple or None for default
+ """
cls.register(addr, command, *responses)
diff --git a/mocket/mode.py b/mocket/mode.py
index ac2ca16a..ffb23a44 100644
--- a/mocket/mode.py
+++ b/mocket/mode.py
@@ -1,3 +1,5 @@
+"""Mocket mode management for strict socket enforcement."""
+
from __future__ import annotations
from typing import TYPE_CHECKING, Any, ClassVar
@@ -10,17 +12,27 @@
class _MocketMode:
+ """Singleton class for managing Mocket's strict mode enforcement."""
+
__shared_state: ClassVar[dict[str, Any]] = {}
STRICT: ClassVar = None
STRICT_ALLOWED: ClassVar = None
def __init__(self) -> None:
+ """Initialize the MocketMode singleton with shared state."""
self.__dict__ = self.__shared_state
def is_allowed(self, location: str | tuple[str, int]) -> bool:
- """
+ """Check if a location is allowed to perform real socket calls.
+
Checks if (`host`, `port`) or at least `host`
are allowed locations to perform real `socket` calls
+
+ Args:
+ location: Hostname string or (host, port) tuple
+
+ Returns:
+ True if the location is allowed, False if in STRICT mode and not allowed
"""
if not self.STRICT:
return True
@@ -35,6 +47,15 @@ def raise_not_allowed(
address: tuple[str, int] | None = None,
data: bytes | None = None,
) -> NoReturn:
+ """Raise an exception when a socket operation is not allowed in STRICT mode.
+
+ Args:
+ address: The (host, port) tuple that was attempted
+ data: The request data that was sent
+
+ Raises:
+ StrictMocketException: Always raised with detailed context
+ """
current_entries = [
(location, "\n ".join(map(str, entries)))
for location, entries in Mocket._entries.items()
diff --git a/mocket/recording.py b/mocket/recording.py
index 97d2adbe..95faf126 100644
--- a/mocket/recording.py
+++ b/mocket/recording.py
@@ -1,3 +1,5 @@
+"""Request/response recording for playback during tests."""
+
from __future__ import annotations
import contextlib
@@ -6,12 +8,13 @@
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
+from typing import Any
from mocket.compat import decode_from_bytes, encode_to_bytes
from mocket.types import Address
from mocket.utils import hexdump, hexload
-hash_function = hashlib.md5
+hash_function: Any = hashlib.md5
with contextlib.suppress(ImportError):
from xxhash_cffi import xxh32 as xxhash_cffi_xxh32
@@ -25,22 +28,48 @@
def _hash_prepare_request(data: bytes) -> bytes:
+ """Prepare request data for hashing by sorting headers.
+
+ Args:
+ data: Raw request data
+
+ Returns:
+ Prepared bytes for hashing
+ """
_data = decode_from_bytes(data)
return encode_to_bytes("".join(sorted(_data.split("\r\n"))))
def _hash_request(data: bytes) -> str:
+ """Hash a request using the best available hash function.
+
+ Args:
+ data: Raw request data
+
+ Returns:
+ Hex digest of the hash
+ """
_data = _hash_prepare_request(data)
return hash_function(_data).hexdigest()
def _hash_request_fallback(data: bytes) -> str:
+ """Hash a request using MD5 as fallback.
+
+ Args:
+ data: Raw request data
+
+ Returns:
+ Hex digest of the MD5 hash
+ """
_data = _hash_prepare_request(data)
return hashlib.md5(_data).hexdigest()
@dataclass
class MocketRecord:
+ """A record of a request and its corresponding response."""
+
host: str
port: int
request: bytes
@@ -48,7 +77,15 @@ class MocketRecord:
class MocketRecordStorage:
+ """Storage for recording and retrieving request/response pairs."""
+
def __init__(self, directory: Path, namespace: str) -> None:
+ """Initialize the record storage.
+
+ Args:
+ directory: Path to directory for storing recordings
+ namespace: Namespace for grouping records
+ """
self._directory = directory
self._namespace = namespace
self._records: defaultdict[Address, defaultdict[str, MocketRecord]] = (
@@ -59,17 +96,33 @@ def __init__(self, directory: Path, namespace: str) -> None:
@property
def directory(self) -> Path:
+ """Get the recording directory.
+
+ Returns:
+ Path to recording directory
+ """
return self._directory
@property
def namespace(self) -> str:
+ """Get the recording namespace.
+
+ Returns:
+ Namespace string
+ """
return self._namespace
@property
def file(self) -> Path:
+ """Get the path to the namespace's JSON file.
+
+ Returns:
+ Path to JSON recording file
+ """
return self._directory / f"{self._namespace}.json"
def _load(self) -> None:
+ """Load recordings from disk."""
if not self.file.exists():
return
@@ -92,6 +145,7 @@ def _load(self) -> None:
)
def _save(self) -> None:
+ """Save recordings to disk."""
data: dict[str, dict[str, dict[str, dict[str, str]]]] = defaultdict(
lambda: defaultdict(defaultdict)
)
@@ -108,9 +162,26 @@ def _save(self) -> None:
self.file.write_text(json_data)
def get_records(self, address: Address) -> list[MocketRecord]:
+ """Get all records for an address.
+
+ Args:
+ address: (host, port) tuple
+
+ Returns:
+ List of MocketRecord instances
+ """
return list(self._records[address].values())
def get_record(self, address: Address, request: bytes) -> MocketRecord | None:
+ """Get a specific record matching the request.
+
+ Args:
+ address: (host, port) tuple
+ request: Request bytes
+
+ Returns:
+ Matching MocketRecord or None
+ """
# NOTE for backward-compat
request_signature_fallback = _hash_request_fallback(request)
if request_signature_fallback in self._records[address]:
@@ -128,6 +199,13 @@ def put_record(
request: bytes,
response: bytes,
) -> None:
+ """Store a new record.
+
+ Args:
+ address: (host, port) tuple
+ request: Request bytes
+ response: Response bytes
+ """
host, port = address
record = MocketRecord(
host=host,
diff --git a/mocket/socket.py b/mocket/socket.py
index e06a1a8e..03868bff 100644
--- a/mocket/socket.py
+++ b/mocket/socket.py
@@ -1,3 +1,5 @@
+"""Mock socket implementation for Mocket."""
+
from __future__ import annotations
import contextlib
@@ -25,7 +27,21 @@
true_socket = socket.socket
-def mock_create_connection(address, timeout=None, source_address=None):
+def mock_create_connection(
+ address: Address,
+ timeout: float | None = None,
+ source_address: Address | None = None,
+) -> socket.socket:
+ """Create a mock socket connection.
+
+ Args:
+ address: (host, port) tuple
+ timeout: Connection timeout in seconds
+ source_address: Source address for binding (unused)
+
+ Returns:
+ MocketSocket instance
+ """
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP)
if timeout:
s.settimeout(timeout)
@@ -41,29 +57,80 @@ def mock_getaddrinfo(
proto: int = 0,
flags: int = 0,
) -> list[tuple[int, int, int, str, tuple[str, int]]]:
+ """Mock socket.getaddrinfo function.
+
+ Args:
+ host: Hostname
+ port: Port number
+ family: Address family (ignored)
+ type: Socket type (ignored)
+ proto: Protocol (ignored)
+ flags: Flags (ignored)
+
+ Returns:
+ List of address info tuples
+ """
return [(2, 1, 6, "", (host, port))]
def mock_gethostbyname(hostname: str) -> str:
+ """Mock socket.gethostbyname function.
+
+ Args:
+ hostname: Hostname to resolve (unused)
+
+ Returns:
+ Localhost IP address
+ """
return "127.0.0.1"
def mock_gethostname() -> str:
+ """Mock socket.gethostname function.
+
+ Returns:
+ Localhost hostname
+ """
return "localhost"
def mock_inet_pton(address_family: int, ip_string: str) -> bytes:
+ """Mock socket.inet_pton function.
+
+ Args:
+ address_family: Address family (unused)
+ ip_string: IP string (unused)
+
+ Returns:
+ Localhost as bytes
+ """
return bytes("\x7f\x00\x00\x01", "utf-8")
-def mock_socketpair(*args, **kwargs):
- """Returns a real socketpair() used by asyncio loop for supporting calls made by fastapi and similar services."""
+def mock_socketpair(
+ *args: Any,
+ **kwargs: Any,
+) -> tuple[socket.socket, socket.socket]:
+ """Mock socket.socketpair function.
+
+ Returns a real socketpair() used by asyncio loop for supporting
+ calls made by fastapi and similar services.
+
+ Args:
+ *args: Positional arguments
+ **kwargs: Keyword arguments
+
+ Returns:
+ Tuple of two connected sockets
+ """
import _socket
return _socket.socketpair(*args, **kwargs)
class MocketSocket:
+ """Mock socket implementation for Mocket."""
+
def __init__(
self,
family: socket.AddressFamily | int = socket.AF_INET,
@@ -72,6 +139,15 @@ def __init__(
fileno: int | None = None,
**kwargs: Any,
) -> None:
+ """Initialize a Mocket socket.
+
+ Args:
+ family: Address family
+ type: Socket type
+ proto: Protocol number
+ fileno: File descriptor (unused)
+ **kwargs: Additional keyword arguments
+ """
self._family = family
self._type = type
self._proto = proto
@@ -90,9 +166,11 @@ def __init__(
self._entry = None
def __str__(self) -> str:
+ """Return a string representation of the socket."""
return f"({self.__class__.__name__})(family={self.family} type={self.type} protocol={self.proto})"
def __enter__(self) -> Self:
+ """Enter context manager."""
return self
def __exit__(
@@ -101,27 +179,37 @@ def __exit__(
value: BaseException | None,
traceback: TracebackType | None,
) -> None:
+ """Exit context manager and close socket."""
self.close()
@property
def family(self) -> int:
+ """Get the address family."""
return self._family
@property
def type(self) -> int:
+ """Get the socket type."""
return self._type
@property
def proto(self) -> int:
+ """Get the protocol number."""
return self._proto
@property
def io(self) -> MocketSocketIO:
+ """Get or create the socket I/O object."""
if self._io is None:
self._io = MocketSocketIO((self._host, self._port))
return self._io
def fileno(self) -> int:
+ """Get the file descriptor for reading.
+
+ Returns:
+ File descriptor number
+ """
address = (self._host, self._port)
r_fd, _ = Mocket.get_pair(address)
if not r_fd:
@@ -130,10 +218,21 @@ def fileno(self) -> int:
return r_fd
def gettimeout(self) -> float | None:
+ """Get the socket timeout.
+
+ Returns:
+ Timeout in seconds or None
+ """
return self._timeout
- # FIXME the arguments here seem wrong. they should be `level: int, optname: int, value: int | ReadableBuffer | None`
def setsockopt(self, family: int, type: int, proto: int) -> None:
+ """Set socket options.
+
+ Args:
+ family: Address family
+ type: Socket type
+ proto: Protocol number
+ """
self._family = family
self._type = type
self._proto = proto
@@ -142,40 +241,124 @@ def setsockopt(self, family: int, type: int, proto: int) -> None:
self._true_socket.setsockopt(family, type, proto)
def settimeout(self, timeout: float | None) -> None:
+ """Set the socket timeout.
+
+ Args:
+ timeout: Timeout in seconds or None
+ """
self._timeout = timeout
@staticmethod
def getsockopt(level: int, optname: int, buflen: int | None = None) -> int:
+ """Get socket option (mock implementation).
+
+ Args:
+ level: Socket option level
+ optname: Socket option name
+ buflen: Buffer length (unused)
+
+ Returns:
+ SOCK_STREAM constant
+ """
return socket.SOCK_STREAM
def getpeername(self) -> _RetAddress:
+ """Get the remote socket address.
+
+ Returns:
+ Address of the remote socket
+ """
return self._address
def setblocking(self, block: bool) -> None:
+ """Set the socket to blocking or non-blocking mode.
+
+ Args:
+ block: True for blocking, False for non-blocking
+ """
self.settimeout(None) if block else self.settimeout(0.0)
def getblocking(self) -> bool:
+ """Check if the socket is in blocking mode.
+
+ Returns:
+ True if blocking, False otherwise
+ """
return self.gettimeout() is None
def getsockname(self) -> _RetAddress:
+ """Get the local socket address.
+
+ Returns:
+ Local socket address
+ """
return socket.gethostbyname(self._address[0]), self._address[1]
def connect(self, address: Address) -> None:
+ """Connect the socket to a remote address.
+
+ Args:
+ address: (host, port) tuple
+ """
self._address = self._host, self._port = address
Mocket._address = address
def makefile(self, mode: str = "r", bufsize: int = -1) -> MocketSocketIO:
+ """Create a file object for the socket.
+
+ Args:
+ mode: Mode string (unused)
+ bufsize: Buffer size (unused)
+
+ Returns:
+ MocketSocketIO object
+ """
return self.io
def get_entry(self, data: bytes) -> MocketEntry | None:
+ """Get a matching entry for the given data.
+
+ Args:
+ data: Request data
+
+ Returns:
+ Matching MocketEntry or None
+ """
return Mocket.get_entry(self._host, self._port, data)
- def sendto(self, data: ReadableBuffer, address: Address | None = None) -> int:
+ def sendto(
+ self,
+ data: ReadableBuffer,
+ address: Address | None = None,
+ ) -> int:
+ """Send data to a specific address (UDP-like).
+
+ Args:
+ data: Data to send
+ address: Destination address
+
+ Returns:
+ Number of bytes sent
+ """
self.connect(address)
self.sendall(data)
return len(data)
- def sendall(self, data, entry=None, *args, **kwargs):
+ def sendall(
+ self,
+ data: ReadableBuffer,
+ entry: MocketEntry | None = None,
+ *args: Any,
+ **kwargs: Any,
+ ) -> None:
+ """Send all data through the socket.
+
+ Args:
+ data: Data to send
+ entry: Pre-matched entry (optional)
+ *args: Additional arguments
+ **kwargs: Additional keyword arguments
+ """
if entry is None:
entry = self.get_entry(data)
@@ -198,6 +381,17 @@ def sendmsg(
flags: int = 0,
address: Address | None = None,
) -> int:
+ """Send a message through multiple buffers.
+
+ Args:
+ buffers: List of buffers to send
+ ancdata: Ancillary data (unused)
+ flags: Flags (unused)
+ address: Destination address (unused)
+
+ Returns:
+ Number of bytes sent
+ """
if not buffers:
return 0
@@ -211,16 +405,23 @@ def recvmsg(
ancbufsize: int | None = None,
flags: int = 0,
) -> tuple[bytes, list[tuple[int, bytes]]]:
- """
- Receive a message from the socket.
+ """Receive a message from the socket.
+
This is a mock implementation that reads from the MocketSocketIO.
+
+ Args:
+ buffersize: Size of buffer to receive
+ ancbufsize: Ancillary buffer size (unused)
+ flags: Flags (unused)
+
+ Returns:
+ Tuple of (data, ancillary_data)
"""
try:
data = self.recv(buffersize)
except BlockingIOError:
return b"", []
- # Mocking the ancillary data and flags as empty
return data, []
def recvmsg_into(
@@ -229,10 +430,19 @@ def recvmsg_into(
ancbufsize: int | None = None,
flags: int = 0,
address: Address | None = None,
- ):
- """
- Receive a message into multiple buffers.
+ ) -> int:
+ """Receive a message into multiple buffers.
+
This is a mock implementation that reads from the MocketSocketIO.
+
+ Args:
+ buffers: List of buffers to receive into
+ ancbufsize: Ancillary buffer size (unused)
+ flags: Flags (unused)
+ address: Address (unused)
+
+ Returns:
+ Number of bytes received
"""
if not buffers:
return 0
@@ -254,10 +464,16 @@ def recvfrom_into(
buffer: WriteableBuffer,
buffersize: int | None = None,
flags: int | None = None,
- ):
- """
- Receive data into a buffer and return the number of bytes received.
- This is a mock implementation that reads from the MocketSocketIO.
+ ) -> tuple[int, _RetAddress]:
+ """Receive data into a buffer and return the source address.
+
+ Args:
+ buffer: Buffer to receive into
+ buffersize: Size to receive
+ flags: Flags (unused)
+
+ Returns:
+ Tuple of (bytes_received, source_address)
"""
return self.recv_into(buffer, buffersize, flags), self._address
@@ -267,10 +483,19 @@ def recv_into(
buffersize: int | None = None,
flags: int | None = None,
) -> int:
+ """Receive data into a buffer.
+
+ Args:
+ buffer: Buffer to receive into
+ buffersize: Number of bytes to receive
+ flags: Flags (unused)
+
+ Returns:
+ Number of bytes received
+ """
if hasattr(buffer, "write"):
return buffer.write(self.recv(buffersize))
- # buffer is a memoryview
if buffersize is None:
buffersize = len(buffer)
@@ -282,9 +507,30 @@ def recv_into(
def recvfrom(
self, buffersize: int, flags: int | None = None
) -> tuple[bytes, _RetAddress]:
+ """Receive data and the source address.
+
+ Args:
+ buffersize: Number of bytes to receive
+ flags: Flags (unused)
+
+ Returns:
+ Tuple of (data, source_address)
+ """
return self.recv(buffersize, flags), self._address
def recv(self, buffersize: int, flags: int | None = None) -> bytes:
+ """Receive data from the socket.
+
+ Args:
+ buffersize: Maximum number of bytes to receive
+ flags: Flags (unused)
+
+ Returns:
+ Received bytes
+
+ Raises:
+ BlockingIOError: If socket is non-blocking and no data available
+ """
r_fd, _ = Mocket.get_pair((self._host, self._port))
if r_fd:
return os.read(r_fd, buffersize)
@@ -298,6 +544,19 @@ def recv(self, buffersize: int, flags: int | None = None) -> bytes:
raise exc
def true_sendall(self, data: bytes, *args: Any, **kwargs: Any) -> bytes:
+ """Send data through the real socket and receive response.
+
+ Args:
+ data: Data to send
+ *args: Additional arguments
+ **kwargs: Additional keyword arguments
+
+ Returns:
+ Response bytes from the real socket
+
+ Raises:
+ StrictMocketException: If operation not allowed in STRICT mode
+ """
if not MocketMode.is_allowed(self._address):
MocketMode.raise_not_allowed(self._address, data)
@@ -344,7 +603,17 @@ def send(
data: ReadableBuffer,
*args: Any,
**kwargs: Any,
- ) -> int: # pragma: no cover
+ ) -> int:
+ """Send data through the socket.
+
+ Args:
+ data: Data to send
+ *args: Additional arguments
+ **kwargs: Additional keyword arguments
+
+ Returns:
+ Number of bytes sent
+ """
entry = self.get_entry(data)
if not entry or (entry and self._entry != entry):
kwargs["entry"] = entry
@@ -357,7 +626,11 @@ def send(
return len(data)
def accept(self) -> tuple[MocketSocket, _RetAddress]:
- """Accept a connection and return a new MocketSocket object."""
+ """Accept a connection and return a new MocketSocket object.
+
+ Returns:
+ Tuple of (new_socket, client_address)
+ """
new_socket = MocketSocket(
family=self._family,
type=self._type,
@@ -369,11 +642,19 @@ def accept(self) -> tuple[MocketSocket, _RetAddress]:
return new_socket, (self._host, self._port)
def close(self) -> None:
+ """Close the socket and underlying true socket."""
if self._true_socket and not self._true_socket._closed:
self._true_socket.close()
def __getattr__(self, name: str) -> Any:
- """Do nothing catchall function, for methods like shutdown()"""
+ """Do-nothing catchall function for methods like shutdown().
+
+ Args:
+ name: Method name
+
+ Returns:
+ A callable that does nothing
+ """
def do_nothing(*args: Any, **kwargs: Any) -> Any:
pass
diff --git a/mocket/ssl/context.py b/mocket/ssl/context.py
index 6d5e7307..aeaab6b5 100644
--- a/mocket/ssl/context.py
+++ b/mocket/ssl/context.py
@@ -1,3 +1,5 @@
+"""Mocket SSL context implementation."""
+
from __future__ import annotations
from typing import Any
@@ -7,10 +9,13 @@
class _MocketSSLContext:
- """For Python 3.6 and newer."""
+ """Mock SSL context for Python 3.6 and newer."""
class FakeSetter(int):
+ """Descriptor that ignores assignment."""
+
def __set__(self, *args: Any) -> None:
+ """Ignore any assignment attempts."""
pass
minimum_version = FakeSetter()
@@ -20,29 +25,49 @@ def __set__(self, *args: Any) -> None:
class MocketSSLContext(_MocketSSLContext):
- DUMMY_METHODS = (
+ """Mock SSL context that wraps sockets in MocketSSLSocket."""
+
+ DUMMY_METHODS: tuple = (
"load_default_certs",
"load_verify_locations",
"set_alpn_protocols",
"set_ciphers",
"set_default_verify_paths",
)
- sock = None
- post_handshake_auth = None
- _check_hostname = False
+ sock: MocketSocket | None = None
+ post_handshake_auth: bool | None = None
+ _check_hostname: bool = False
@property
def check_hostname(self) -> bool:
+ """Get the check_hostname setting.
+
+ Returns:
+ Always False (mock implementation)
+ """
return self._check_hostname
@check_hostname.setter
def check_hostname(self, _: bool) -> None:
+ """Set the check_hostname setting (mocked).
+
+ Args:
+ _: Value (ignored, always set to False)
+ """
self._check_hostname = False
def __init__(self, *args: Any, **kwargs: Any) -> None:
+ """Initialize the SSL context.
+
+ Args:
+ *args: Positional arguments (ignored)
+ **kwargs: Keyword arguments (ignored)
+ """
self._set_dummy_methods()
def _set_dummy_methods(self) -> None:
+ """Set all dummy methods that do nothing."""
+
def dummy_method(*args: Any, **kwargs: Any) -> Any:
pass
@@ -55,15 +80,36 @@ def wrap_socket(
*args: Any,
**kwargs: Any,
) -> MocketSSLSocket:
+ """Wrap a socket in an SSL socket.
+
+ Args:
+ sock: Socket to wrap
+ *args: Additional arguments
+ **kwargs: Additional keyword arguments
+
+ Returns:
+ MocketSSLSocket instance
+ """
return MocketSSLSocket._create(sock, *args, **kwargs)
def wrap_bio(
self,
- incoming: Any, # _ssl.MemoryBIO
- outgoing: Any, # _ssl.MemoryBIO
+ incoming: Any,
+ outgoing: Any,
server_side: bool = False,
server_hostname: str | bytes | None = None,
) -> MocketSSLSocket:
+ """Wrap BIO objects in an SSL socket (mock implementation).
+
+ Args:
+ incoming: Incoming BIO (_ssl.MemoryBIO)
+ outgoing: Outgoing BIO (_ssl.MemoryBIO)
+ server_side: Whether this is server side
+ server_hostname: Server hostname
+
+ Returns:
+ MocketSSLSocket instance
+ """
ssl_obj = MocketSSLSocket()
ssl_obj._host = server_hostname
return ssl_obj
@@ -74,5 +120,15 @@ def mock_wrap_socket(
*args: Any,
**kwargs: Any,
) -> MocketSSLSocket:
+ """Mock ssl.wrap_socket function.
+
+ Args:
+ sock: Socket to wrap
+ *args: Additional arguments
+ **kwargs: Additional keyword arguments
+
+ Returns:
+ MocketSSLSocket instance
+ """
context = MocketSSLContext()
return context.wrap_socket(sock, *args, **kwargs)
diff --git a/mocket/ssl/socket.py b/mocket/ssl/socket.py
index 6dcd7817..94984fce 100644
--- a/mocket/ssl/socket.py
+++ b/mocket/ssl/socket.py
@@ -1,3 +1,5 @@
+"""Mocket SSL socket implementation."""
+
from __future__ import annotations
import ssl
@@ -12,14 +14,33 @@
class MocketSSLSocket(MocketSocket):
+ """Mock SSL socket that extends MocketSocket with SSL-specific behavior."""
+
def __init__(self, *args: Any, **kwargs: Any) -> None:
+ """Initialize an SSL socket.
+
+ Args:
+ *args: Positional arguments
+ **kwargs: Keyword arguments
+ """
super().__init__(*args, **kwargs)
- self._did_handshake = False
- self._sent_non_empty_bytes = False
+ self._did_handshake: bool = False
+ self._sent_non_empty_bytes: bool = False
self._original_socket: MocketSocket = self
def read(self, buffersize: int | None = None) -> bytes:
+ """Read data from the SSL socket.
+
+ Args:
+ buffersize: Maximum bytes to read
+
+ Returns:
+ Bytes read from the socket
+
+ Raises:
+ ssl.SSLWantReadError: If handshake not completed and no data
+ """
rv = self.io.read(buffersize)
if rv:
self._sent_non_empty_bytes = True
@@ -28,12 +49,29 @@ def read(self, buffersize: int | None = None) -> bytes:
return rv
def write(self, data: bytes) -> int | None:
+ """Write data to the SSL socket.
+
+ Args:
+ data: Bytes to write
+
+ Returns:
+ Number of bytes written
+ """
return self.send(encode_to_bytes(data))
def do_handshake(self) -> None:
+ """Perform SSL handshake (mock implementation)."""
self._did_handshake = True
def getpeercert(self, binary_form: bool = False) -> _PeerCertRetDictType:
+ """Get the peer certificate (mock implementation).
+
+ Args:
+ binary_form: Whether to return binary form (unused)
+
+ Returns:
+ Mock certificate dictionary
+ """
if not (self._host and self._port):
self._address = self._host, self._port = Mocket._address
@@ -54,12 +92,27 @@ def getpeercert(self, binary_form: bool = False) -> _PeerCertRetDictType:
}
def ciper(self) -> tuple[str, str, str]:
+ """Get cipher information (mock implementation).
+
+ Returns:
+ Tuple of (cipher_name, protocol, key_exchange_algorithm)
+ """
return "ADH", "AES256", "SHA"
def compression(self) -> Options:
+ """Get compression options (mock implementation).
+
+ Returns:
+ SSL options constant
+ """
return ssl.OP_NO_COMPRESSION
def unwrap(self) -> MocketSocket:
+ """Unwrap the SSL socket and return the underlying socket.
+
+ Returns:
+ The original MocketSocket
+ """
return self._original_socket
@classmethod
@@ -71,6 +124,18 @@ def _create(
*args: Any,
**kwargs: Any,
) -> MocketSSLSocket:
+ """Create an SSL socket from a regular socket.
+
+ Args:
+ sock: Socket to wrap
+ ssl_context: SSL context (optional)
+ server_hostname: Server hostname
+ *args: Additional arguments
+ **kwargs: Additional keyword arguments
+
+ Returns:
+ New MocketSSLSocket instance
+ """
ssl_socket = MocketSSLSocket()
ssl_socket._original_socket = sock
ssl_socket._true_socket = sock._true_socket
diff --git a/mocket/types.py b/mocket/types.py
index 562648c7..fedfd37f 100644
--- a/mocket/types.py
+++ b/mocket/types.py
@@ -1,3 +1,5 @@
+"""Type aliases and definitions for Mocket."""
+
from __future__ import annotations
from typing import Any, Dict, Tuple, Union
diff --git a/mocket/urllib3.py b/mocket/urllib3.py
index e89bc7b5..872efc5f 100644
--- a/mocket/urllib3.py
+++ b/mocket/urllib3.py
@@ -1,3 +1,5 @@
+"""Urllib3 specific socket mocking."""
+
from __future__ import annotations
from typing import Any
@@ -8,6 +10,14 @@
def mock_match_hostname(*args: Any) -> None:
+ """Mock urllib3's match_hostname function.
+
+ Args:
+ *args: Ignored arguments
+
+ Returns:
+ None
+ """
return None
@@ -16,5 +26,15 @@ def mock_ssl_wrap_socket(
*args: Any,
**kwargs: Any,
) -> MocketSSLSocket:
+ """Mock urllib3's ssl_wrap_socket function.
+
+ Args:
+ sock: The socket to wrap
+ *args: Additional arguments
+ **kwargs: Additional keyword arguments
+
+ Returns:
+ MocketSSLSocket instance
+ """
context = MocketSSLContext()
return context.wrap_socket(sock, *args, **kwargs)
diff --git a/mocket/utils.py b/mocket/utils.py
index 6180ae3f..749b2b70 100644
--- a/mocket/utils.py
+++ b/mocket/utils.py
@@ -1,3 +1,5 @@
+"""Utility functions for Mocket."""
+
from __future__ import annotations
import binascii
@@ -14,12 +16,13 @@
class MocketizeDecorator(Protocol):
- """
+ """Protocol for a flexible decorator that can be used in multiple ways.
+
This is a generic decorator signature, currently applicable to get_mocketize.
- Decorators can be used as:
+ Decorators implementing this protocol can be used as:
1. A function that transforms func (the parameter) into func1 (the returned object).
- 2. A function that takes keyword arguments and returns 1.
+ 2. A function that takes keyword arguments and returns a decorator.
"""
@overload
@@ -32,18 +35,37 @@ def __call__(
def hexdump(binary_string: bytes) -> str:
- r"""
- >>> hexdump(b"bar foobar foo") == decode_from_bytes(encode_to_bytes("62 61 72 20 66 6F 6F 62 61 72 20 66 6F 6F"))
- True
+ """Convert binary data to space-separated hex string.
+
+ Args:
+ binary_string: Binary data to convert
+
+ Returns:
+ Space-separated hexadecimal representation
+
+ Example:
+ >>> hexdump(b"bar foobar foo") == decode_from_bytes(encode_to_bytes("62 61 72 20 66 6F 6F 62 61 72 20 66 6F 6F"))
+ True
"""
bs = decode_from_bytes(binascii.hexlify(binary_string).upper())
return " ".join(a + b for a, b in zip(bs[::2], bs[1::2]))
def hexload(string: str) -> bytes:
- r"""
- >>> hexload("62 61 72 20 66 6F 6F 62 61 72 20 66 6F 6F") == encode_to_bytes("bar foobar foo")
- True
+ """Convert space-separated hex string to binary data.
+
+ Args:
+ string: Space-separated hexadecimal string
+
+ Returns:
+ Binary data
+
+ Raises:
+ ValueError: If the hex string is invalid
+
+ Example:
+ >>> hexload("62 61 72 20 66 6F 6F 62 61 72 20 66 6F 6F") == encode_to_bytes("bar foobar foo")
+ True
"""
string_no_spaces = "".join(string.split())
try:
@@ -53,6 +75,18 @@ def hexload(string: str) -> bytes:
def get_mocketize(wrapper_: Callable) -> MocketizeDecorator:
+ """Get a mocketize decorator from a wrapper function.
+
+ Decorators can be used as:
+ 1. A function that transforms func (the parameter) into func1 (the returned object).
+ 2. A function that takes keyword arguments and returns 1.
+
+ Args:
+ wrapper_: The wrapper function to convert to a decorator
+
+ Returns:
+ A MocketizeDecorator instance that can be used as a flexible decorator
+ """
# trying to support different versions of `decorator`
with contextlib.suppress(TypeError):
return decorator.decorator(wrapper_, kwsyntax=True) # type: ignore[return-value, call-arg, unused-ignore]
diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py
index a1eae240..3ee91d3d 100644
--- a/tests/test_asyncio.py
+++ b/tests/test_asyncio.py
@@ -6,11 +6,11 @@
import aiohttp
import pytest
-
-from mocket import Mocketizer, async_mocketize
from mocket.mockhttp import Entry
from mocket.plugins.aiohttp_connector import MocketTCPConnector
+from mocket import Mocketizer, async_mocketize
+
def test_asyncio_record_replay():
async def test_asyncio_connection():
diff --git a/tests/test_http.py b/tests/test_http.py
index 3d3e5b8e..2bf72620 100644
--- a/tests/test_http.py
+++ b/tests/test_http.py
@@ -10,9 +10,9 @@
import pytest
import requests
+from mocket.mocks.mockhttp import Entry, Response
from mocket import Mocket, Mocketizer, mocketize
-from mocket.mocks.mockhttp import Entry, Response
class HttpTestCase(TestCase):
diff --git a/tests/test_http_httpx.py b/tests/test_http_httpx.py
index 6fb0fcab..3b088505 100644
--- a/tests/test_http_httpx.py
+++ b/tests/test_http_httpx.py
@@ -2,7 +2,6 @@
from unittest import IsolatedAsyncioTestCase
import httpx
-
from mocket.plugins.httpretty import HTTPretty, async_httprettified
diff --git a/tests/test_httpretty.py b/tests/test_httpretty.py
index 2b00a154..ae125d60 100644
--- a/tests/test_httpretty.py
+++ b/tests/test_httpretty.py
@@ -24,9 +24,8 @@
import requests
-from sure import expect
-
from mocket.plugins.httpretty import HTTPretty, httprettified, httpretty
+from sure import expect
@httprettified
diff --git a/tests/test_https.py b/tests/test_https.py
index 4685f4eb..8c6e0b1b 100644
--- a/tests/test_https.py
+++ b/tests/test_https.py
@@ -5,9 +5,9 @@
import pytest
import requests
+from mocket.mockhttp import Entry # noqa - test retrocompatibility
from mocket import Mocket, Mocketizer, mocketize
-from mocket.mockhttp import Entry # noqa - test retrocompatibility
@pytest.fixture
diff --git a/tests/test_httpx.py b/tests/test_httpx.py
index add53de8..a9026c8b 100644
--- a/tests/test_httpx.py
+++ b/tests/test_httpx.py
@@ -6,11 +6,11 @@
from asgiref.sync import async_to_sync
from fastapi import FastAPI
from fastapi.testclient import TestClient
-
-from mocket import Mocket, Mocketizer, async_mocketize, mocketize
from mocket.mockhttp import Entry
from mocket.plugins.httpretty import httprettified, httpretty
+from mocket import Mocket, Mocketizer, async_mocketize, mocketize
+
@mocketize
@pytest.mark.parametrize("url", ("http://httpbin.org/ip", "https://httpbin.org/ip"))
diff --git a/tests/test_mocket.py b/tests/test_mocket.py
index 8810a5b9..82a2c86a 100644
--- a/tests/test_mocket.py
+++ b/tests/test_mocket.py
@@ -7,9 +7,9 @@
import httpx
import psutil
import pytest
+from mocket.compat import encode_to_bytes
from mocket import Mocket, MocketEntry, Mocketizer, mocketize
-from mocket.compat import encode_to_bytes
class MocketTestCase(TestCase):
diff --git a/tests/test_mode.py b/tests/test_mode.py
index bfdb2a79..db197c59 100644
--- a/tests/test_mode.py
+++ b/tests/test_mode.py
@@ -1,11 +1,11 @@
import pytest
import requests
-
-from mocket import Mocketizer, mocketize
from mocket.exceptions import StrictMocketException
from mocket.mockhttp import Entry, Response
from mocket.mode import MocketMode
+from mocket import Mocketizer, mocketize
+
@mocketize(strict_mode=True)
def test_strict_mode_fails():
diff --git a/tests/test_pook.py b/tests/test_pook.py
index 56721b5f..012fcdfb 100644
--- a/tests/test_pook.py
+++ b/tests/test_pook.py
@@ -3,7 +3,6 @@
with contextlib.suppress(ModuleNotFoundError):
import pook
import requests
-
from mocket.plugins.pook_mock_engine import MocketEngine
pook.set_mock_engine(MocketEngine)
diff --git a/tests/test_redis.py b/tests/test_redis.py
index fb6ec355..ccd43ca2 100644
--- a/tests/test_redis.py
+++ b/tests/test_redis.py
@@ -3,9 +3,9 @@
import pytest
import redis
+from mocket.mockredis import ERROR, OK, Entry, Redisizer
from mocket import Mocket, mocketize
-from mocket.mockredis import ERROR, OK, Entry, Redisizer
class RedisizerTestCase(TestCase):
diff --git a/tests/test_socket.py b/tests/test_socket.py
index dad62a33..9933c864 100644
--- a/tests/test_socket.py
+++ b/tests/test_socket.py
@@ -1,9 +1,9 @@
import socket
import pytest
+from mocket.socket import MocketSocket
from mocket import Mocket, MocketEntry, mocketize
-from mocket.socket import MocketSocket
@pytest.mark.parametrize("blocking", (False, True))
diff --git a/tests/test_utils.py b/tests/test_utils.py
index d3b5eba7..e3325f80 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -3,7 +3,6 @@
from unittest.mock import NonCallableMock, patch
import decorator
-
from mocket.utils import get_mocketize
From e011f2747ebf443a197becd59ca47283fa881ca0 Mon Sep 17 00:00:00 2001
From: Copilot <198982749+Copilot@users.noreply.github.com>
Date: Sun, 22 Feb 2026 21:16:36 +0100
Subject: [PATCH 02/14] Restore doctests in mockhttp.py (#318)
* Initial plan
* Restore doctests in mockhttp.py for set_extra_headers, can_handle, and _parse_requestline
Co-authored-by: mindflayer <527325+mindflayer@users.noreply.github.com>
* Add coverage.xml to .gitignore and remove from tracking
Co-authored-by: mindflayer <527325+mindflayer@users.noreply.github.com>
---------
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: mindflayer <527325+mindflayer@users.noreply.github.com>
---
.gitignore | 1 +
mocket/mocks/mockhttp.py | 31 +++++++++++++++++++++++++++++--
2 files changed, 30 insertions(+), 2 deletions(-)
diff --git a/.gitignore b/.gitignore
index 564b8ce6..9bacc469 100644
--- a/.gitignore
+++ b/.gitignore
@@ -28,3 +28,4 @@ shippable
.vscode/
Pipfile.lock
requirements.txt
+coverage.xml
diff --git a/mocket/mocks/mockhttp.py b/mocket/mocks/mockhttp.py
index e7e5a7b9..5ec14a62 100644
--- a/mocket/mocks/mockhttp.py
+++ b/mocket/mocks/mockhttp.py
@@ -182,10 +182,19 @@ def set_base_headers(self) -> None:
self.headers["Content-Type"] = do_the_magic(self.body)
def set_extra_headers(self, headers: dict) -> None:
- """Add extra headers to the response.
+ r"""Add extra headers to the response.
Args:
headers: Dictionary of additional headers
+
+ >>> r = Response(body="")
+ >>> len(r.headers.keys())
+ 6
+ >>> r.set_extra_headers({"foo-bar": "Foobar"})
+ >>> len(r.headers.keys())
+ 7
+ >>> encode_to_bytes(r.headers.get("Foo-Bar")) == encode_to_bytes("Foobar")
+ True
"""
for k, v in headers.items():
self.headers["-".join(token.capitalize() for token in k.split("-"))] = v
@@ -294,13 +303,20 @@ def _can_handle(self, path: str, qs_dict: dict) -> bool:
return can_handle
def can_handle(self, data: bytes) -> bool:
- """Check if this entry can handle the given request data.
+ r"""Check if this entry can handle the given request data.
Args:
data: Request data
Returns:
True if this entry can handle the request
+
+ >>> e = Entry('http://www.github.com/?bar=foo&foobar', Entry.GET, (Response(b''),))
+ >>> e.can_handle(b'GET /?bar=foo HTTP/1.1\r\nHost: github.com\r\nAccept-Encoding: gzip, deflate\r\nConnection: keep-alive\r\nUser-Agent: python-requests/2.7.0 CPython/3.4.3 Linux/3.19.0-16-generic\r\nAccept: */*\r\n\r\n')
+ False
+ >>> e = Entry('http://www.github.com/?bar=foo&foobar', Entry.GET, (Response(b'
'),))
+ >>> e.can_handle(b'GET /?bar=foo&foobar HTTP/1.1\r\nHost: github.com\r\nAccept-Encoding: gzip, deflate\r\nConnection: keep-alive\r\nUser-Agent: python-requests/2.7.0 CPython/3.4.3 Linux/3.19.0-16-generic\r\nAccept: */*\r\n\r\n')
+ True
"""
try:
requestline, _ = decode_from_bytes(data).split(CRLF, 1)
@@ -322,6 +338,8 @@ def can_handle(self, data: bytes) -> bool:
def _parse_requestline(line: str) -> tuple:
"""Parse an HTTP request line.
+ http://www.w3.org/Protocols/rfc2616/rfc2616-sec5.html#sec5
+
Args:
line: HTTP request line string
@@ -330,6 +348,15 @@ def _parse_requestline(line: str) -> tuple:
Raises:
ValueError: If line is not a valid request line
+
+ >>> Entry._parse_requestline('GET / HTTP/1.0') == ('GET', '/', '1.0')
+ True
+ >>> Entry._parse_requestline('post /testurl htTP/1.1') == ('POST', '/testurl', '1.1')
+ True
+ >>> Entry._parse_requestline('Im not a RequestLine')
+ Traceback (most recent call last):
+ ...
+ ValueError: Not a Request-Line
"""
m = re.match(
r"({})\s+(.*)\s+HTTP/(1.[0|1])".format("|".join(Entry.METHODS)), line, re.I
From 5da06d3f934460abf8cdaa9897593a2a4c840bdb Mon Sep 17 00:00:00 2001
From: Giorgio Salluzzo
Date: Sun, 22 Feb 2026 21:27:06 +0100
Subject: [PATCH 03/14] Revert
---
tests/test_utils.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/tests/test_utils.py b/tests/test_utils.py
index e3325f80..d3b5eba7 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -3,6 +3,7 @@
from unittest.mock import NonCallableMock, patch
import decorator
+
from mocket.utils import get_mocketize
From 30668d522c7556f9da57843682f398659b5d2f6f Mon Sep 17 00:00:00 2001
From: Giorgio Salluzzo
Date: Sun, 22 Feb 2026 21:28:17 +0100
Subject: [PATCH 04/14] Revert
---
tests/test_redis.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/tests/test_redis.py b/tests/test_redis.py
index ccd43ca2..fb6ec355 100644
--- a/tests/test_redis.py
+++ b/tests/test_redis.py
@@ -3,9 +3,9 @@
import pytest
import redis
-from mocket.mockredis import ERROR, OK, Entry, Redisizer
from mocket import Mocket, mocketize
+from mocket.mockredis import ERROR, OK, Entry, Redisizer
class RedisizerTestCase(TestCase):
From 1ac0ad1f4d7c9c2fd6f466ac1db82de47e31939a Mon Sep 17 00:00:00 2001
From: Giorgio Salluzzo
Date: Sun, 22 Feb 2026 21:29:18 +0100
Subject: [PATCH 05/14] Revert
---
tests/test_socket.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/tests/test_socket.py b/tests/test_socket.py
index 9933c864..dad62a33 100644
--- a/tests/test_socket.py
+++ b/tests/test_socket.py
@@ -1,9 +1,9 @@
import socket
import pytest
-from mocket.socket import MocketSocket
from mocket import Mocket, MocketEntry, mocketize
+from mocket.socket import MocketSocket
@pytest.mark.parametrize("blocking", (False, True))
From 70dc4fab95ff236b0694f3dbf05674a13dd42183 Mon Sep 17 00:00:00 2001
From: Giorgio Salluzzo
Date: Sun, 22 Feb 2026 21:30:19 +0100
Subject: [PATCH 06/14] Revert
Removed duplicate import of Mocketizer and mocketize.
---
tests/test_mode.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/tests/test_mode.py b/tests/test_mode.py
index db197c59..bfdb2a79 100644
--- a/tests/test_mode.py
+++ b/tests/test_mode.py
@@ -1,11 +1,11 @@
import pytest
import requests
+
+from mocket import Mocketizer, mocketize
from mocket.exceptions import StrictMocketException
from mocket.mockhttp import Entry, Response
from mocket.mode import MocketMode
-from mocket import Mocketizer, mocketize
-
@mocketize(strict_mode=True)
def test_strict_mode_fails():
From 7ea685d2624196836690fdf9b24eaf13dd3a93a7 Mon Sep 17 00:00:00 2001
From: Giorgio Salluzzo
Date: Sun, 22 Feb 2026 21:31:02 +0100
Subject: [PATCH 07/14] Revert
---
tests/test_mocket.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/tests/test_mocket.py b/tests/test_mocket.py
index 82a2c86a..8810a5b9 100644
--- a/tests/test_mocket.py
+++ b/tests/test_mocket.py
@@ -7,9 +7,9 @@
import httpx
import psutil
import pytest
-from mocket.compat import encode_to_bytes
from mocket import Mocket, MocketEntry, Mocketizer, mocketize
+from mocket.compat import encode_to_bytes
class MocketTestCase(TestCase):
From 948f4301189166bf592cc8ccec7090b423f6b56c Mon Sep 17 00:00:00 2001
From: Giorgio Salluzzo
Date: Sun, 22 Feb 2026 21:31:43 +0100
Subject: [PATCH 08/14] Revert
Removed duplicate import statements for Mocketizer and async_mocketize.
---
tests/test_asyncio.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/tests/test_asyncio.py b/tests/test_asyncio.py
index 3ee91d3d..a1eae240 100644
--- a/tests/test_asyncio.py
+++ b/tests/test_asyncio.py
@@ -6,10 +6,10 @@
import aiohttp
import pytest
-from mocket.mockhttp import Entry
-from mocket.plugins.aiohttp_connector import MocketTCPConnector
from mocket import Mocketizer, async_mocketize
+from mocket.mockhttp import Entry
+from mocket.plugins.aiohttp_connector import MocketTCPConnector
def test_asyncio_record_replay():
From 5b6a08f9deba5573b40046fdacf7293eaae30da7 Mon Sep 17 00:00:00 2001
From: Giorgio Salluzzo
Date: Sun, 22 Feb 2026 21:32:29 +0100
Subject: [PATCH 09/14] Revert
Removed duplicate import statements for Mocket.
---
tests/test_httpx.py | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/tests/test_httpx.py b/tests/test_httpx.py
index a9026c8b..add53de8 100644
--- a/tests/test_httpx.py
+++ b/tests/test_httpx.py
@@ -6,10 +6,10 @@
from asgiref.sync import async_to_sync
from fastapi import FastAPI
from fastapi.testclient import TestClient
-from mocket.mockhttp import Entry
-from mocket.plugins.httpretty import httprettified, httpretty
from mocket import Mocket, Mocketizer, async_mocketize, mocketize
+from mocket.mockhttp import Entry
+from mocket.plugins.httpretty import httprettified, httpretty
@mocketize
From 86e699a24a23ac9d0f11008e5830ae4e903c47a7 Mon Sep 17 00:00:00 2001
From: Giorgio Salluzzo
Date: Sun, 22 Feb 2026 21:33:55 +0100
Subject: [PATCH 10/14] Revert
---
tests/test_http.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/tests/test_http.py b/tests/test_http.py
index 2bf72620..3d3e5b8e 100644
--- a/tests/test_http.py
+++ b/tests/test_http.py
@@ -10,9 +10,9 @@
import pytest
import requests
-from mocket.mocks.mockhttp import Entry, Response
from mocket import Mocket, Mocketizer, mocketize
+from mocket.mocks.mockhttp import Entry, Response
class HttpTestCase(TestCase):
From de300c30f3c29d1e2c73abdbbf86597be8caa3dc Mon Sep 17 00:00:00 2001
From: Giorgio Salluzzo
Date: Sun, 22 Feb 2026 21:34:33 +0100
Subject: [PATCH 11/14] Revert
---
tests/test_https.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/tests/test_https.py b/tests/test_https.py
index 8c6e0b1b..4685f4eb 100644
--- a/tests/test_https.py
+++ b/tests/test_https.py
@@ -5,9 +5,9 @@
import pytest
import requests
-from mocket.mockhttp import Entry # noqa - test retrocompatibility
from mocket import Mocket, Mocketizer, mocketize
+from mocket.mockhttp import Entry # noqa - test retrocompatibility
@pytest.fixture
From 0e2c017d5266c58cbdc630653f3fad84d3cc8d35 Mon Sep 17 00:00:00 2001
From: Giorgio Salluzzo
Date: Sun, 22 Feb 2026 21:35:30 +0100
Subject: [PATCH 12/14] Revert
---
tests/test_httpretty.py | 3 ++-
1 file changed, 2 insertions(+), 1 deletion(-)
diff --git a/tests/test_httpretty.py b/tests/test_httpretty.py
index ae125d60..2b00a154 100644
--- a/tests/test_httpretty.py
+++ b/tests/test_httpretty.py
@@ -24,9 +24,10 @@
import requests
-from mocket.plugins.httpretty import HTTPretty, httprettified, httpretty
from sure import expect
+from mocket.plugins.httpretty import HTTPretty, httprettified, httpretty
+
@httprettified
def test_httpretty_should_mock_a_simple_get_with_requests_read():
From b74910a9c15568cc50a374f692df9c37323a72b8 Mon Sep 17 00:00:00 2001
From: Giorgio Salluzzo
Date: Sun, 22 Feb 2026 21:36:00 +0100
Subject: [PATCH 13/14] Revert
---
tests/test_http_httpx.py | 1 +
1 file changed, 1 insertion(+)
diff --git a/tests/test_http_httpx.py b/tests/test_http_httpx.py
index 3b088505..6fb0fcab 100644
--- a/tests/test_http_httpx.py
+++ b/tests/test_http_httpx.py
@@ -2,6 +2,7 @@
from unittest import IsolatedAsyncioTestCase
import httpx
+
from mocket.plugins.httpretty import HTTPretty, async_httprettified
From 119d3b87567af7f7e54e9497197f162511c08c70 Mon Sep 17 00:00:00 2001
From: Copilot <198982749+Copilot@users.noreply.github.com>
Date: Sun, 22 Feb 2026 22:40:09 +0100
Subject: [PATCH 14/14] Fix `setsockopt` signature to match standard socket API
(#319)
* Initial plan
* Fix setsockopt signature to match standard socket API
Co-authored-by: mindflayer <527325+mindflayer@users.noreply.github.com>
* Add tests for setsockopt with and without optlen
Co-authored-by: mindflayer <527325+mindflayer@users.noreply.github.com>
---------
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: mindflayer <527325+mindflayer@users.noreply.github.com>
---
mocket/socket.py | 26 ++++++++++++++++----------
tests/test_socket.py | 21 +++++++++++++++++++++
2 files changed, 37 insertions(+), 10 deletions(-)
diff --git a/mocket/socket.py b/mocket/socket.py
index 03868bff..bd79528c 100644
--- a/mocket/socket.py
+++ b/mocket/socket.py
@@ -225,20 +225,26 @@ def gettimeout(self) -> float | None:
"""
return self._timeout
- def setsockopt(self, family: int, type: int, proto: int) -> None:
- """Set socket options.
+ def setsockopt(
+ self,
+ level: int,
+ optname: int,
+ value: int | bytes | None,
+ optlen: int | None = None,
+ ) -> None:
+ """Set socket option.
Args:
- family: Address family
- type: Socket type
- proto: Protocol number
+ level: Socket option level (e.g., socket.SOL_SOCKET)
+ optname: Socket option name (e.g., socket.SO_REUSEADDR)
+ value: Option value as an integer or bytes, or None when optlen is provided
+ optlen: Option length (used when value is None)
"""
- self._family = family
- self._type = type
- self._proto = proto
-
if self._true_socket:
- self._true_socket.setsockopt(family, type, proto)
+ if optlen is not None:
+ self._true_socket.setsockopt(level, optname, value, optlen)
+ else:
+ self._true_socket.setsockopt(level, optname, value)
def settimeout(self, timeout: float | None) -> None:
"""Set the socket timeout.
diff --git a/tests/test_socket.py b/tests/test_socket.py
index dad62a33..68e71aee 100644
--- a/tests/test_socket.py
+++ b/tests/test_socket.py
@@ -1,4 +1,6 @@
import socket
+import struct
+from unittest.mock import MagicMock
import pytest
@@ -126,3 +128,22 @@ def test_recvfrom_into():
assert nbytes == len(test_data)
assert buf[:nbytes] == test_data
assert addr == sock._address
+
+
+def test_setsockopt_without_optlen():
+ sock = MocketSocket(socket.AF_INET, socket.SOCK_STREAM)
+ sock._true_socket = MagicMock()
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
+ sock._true_socket.setsockopt.assert_called_once_with(
+ socket.SOL_SOCKET, socket.SO_REUSEADDR, 1
+ )
+
+
+def test_setsockopt_with_optlen():
+ sock = MocketSocket(socket.AF_INET, socket.SOCK_STREAM)
+ sock._true_socket = MagicMock()
+ linger_value = struct.pack("ii", 1, 5)
+ sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, linger_value, len(linger_value))
+ sock._true_socket.setsockopt.assert_called_once_with(
+ socket.SOL_SOCKET, socket.SO_LINGER, linger_value, len(linger_value)
+ )