diff --git a/mocket/socket.py b/mocket/socket.py index 03868bf..bd79528 100644 --- a/mocket/socket.py +++ b/mocket/socket.py @@ -225,20 +225,26 @@ def gettimeout(self) -> float | None: """ return self._timeout - def setsockopt(self, family: int, type: int, proto: int) -> None: - """Set socket options. + def setsockopt( + self, + level: int, + optname: int, + value: int | bytes | None, + optlen: int | None = None, + ) -> None: + """Set socket option. Args: - family: Address family - type: Socket type - proto: Protocol number + level: Socket option level (e.g., socket.SOL_SOCKET) + optname: Socket option name (e.g., socket.SO_REUSEADDR) + value: Option value as an integer or bytes, or None when optlen is provided + optlen: Option length (used when value is None) """ - self._family = family - self._type = type - self._proto = proto - if self._true_socket: - self._true_socket.setsockopt(family, type, proto) + if optlen is not None: + self._true_socket.setsockopt(level, optname, value, optlen) + else: + self._true_socket.setsockopt(level, optname, value) def settimeout(self, timeout: float | None) -> None: """Set the socket timeout. diff --git a/tests/test_socket.py b/tests/test_socket.py index dad62a3..68e71ae 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -1,4 +1,6 @@ import socket +import struct +from unittest.mock import MagicMock import pytest @@ -126,3 +128,22 @@ def test_recvfrom_into(): assert nbytes == len(test_data) assert buf[:nbytes] == test_data assert addr == sock._address + + +def test_setsockopt_without_optlen(): + sock = MocketSocket(socket.AF_INET, socket.SOCK_STREAM) + sock._true_socket = MagicMock() + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock._true_socket.setsockopt.assert_called_once_with( + socket.SOL_SOCKET, socket.SO_REUSEADDR, 1 + ) + + +def test_setsockopt_with_optlen(): + sock = MocketSocket(socket.AF_INET, socket.SOCK_STREAM) + sock._true_socket = MagicMock() + linger_value = struct.pack("ii", 1, 5) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_LINGER, linger_value, len(linger_value)) + sock._true_socket.setsockopt.assert_called_once_with( + socket.SOL_SOCKET, socket.SO_LINGER, linger_value, len(linger_value) + )