diff --git a/redis/__init__.py b/redis/__init__.py index 0d3cc21a08..238c3d58fc 100644 --- a/redis/__init__.py +++ b/redis/__init__.py @@ -46,7 +46,7 @@ def int_or_str(value): return value -__version__ = "7.1.0" +__version__ = "7.1.0+sb1" VERSION = tuple(map(int_or_str, __version__.split("."))) diff --git a/redis/client.py b/redis/client.py index d3ab3cfcfe..5b10f55d09 100755 --- a/redis/client.py +++ b/redis/client.py @@ -2,6 +2,7 @@ import re import threading import time +from collections import defaultdict from itertools import chain from typing import ( TYPE_CHECKING, @@ -39,6 +40,7 @@ SSLConnection, UnixDomainSocketConnection, ) +from redis.crc import key_slot from redis.credentials import CredentialProvider from redis.event import ( AfterPooledConnectionsInstantiationEvent, @@ -838,6 +840,7 @@ def __init__( self.health_check_response = [b"pong", self.health_check_response_b] if self.push_handler_func is None: _set_info_logger() + self._connection_lock = threading.Lock() self.reset() def __enter__(self) -> "PubSub": @@ -892,11 +895,14 @@ def on_connect(self, connection) -> None: } self.psubscribe(**patterns) if self.shard_channels: - shard_channels = { - self.encoder.decode(k, force=True): v - for k, v in self.shard_channels.items() - } - self.ssubscribe(**shard_channels) + channels_by_slot = defaultdict(dict) + for k, v in self.shard_channels.items(): + key = self.encoder.decode(k, force=True) + slot = key_slot(self.encoder.encode(key)) + channels_by_slot[slot][key] = v + + for slot, channels in channels_by_slot.items(): + self.ssubscribe(**channels) @property def subscribed(self) -> bool: @@ -911,17 +917,19 @@ def execute_command(self, *args): # subscribed to one or more channels if self.connection is None: - self.connection = self.connection_pool.get_connection() - # register a callback that re-subscribes to any channels we - # were listening to when we were disconnected - self.connection.register_connect_callback(self.on_connect) - if self.push_handler_func is not None: - self.connection._parser.set_pubsub_push_handler(self.push_handler_func) - self._event_dispatcher.dispatch( - AfterPubSubConnectionInstantiationEvent( - self.connection, self.connection_pool, ClientType.SYNC, self._lock - ) - ) + with self._connection_lock: + if self.connection is None: + self.connection = self.connection_pool.get_connection() + # register a callback that re-subscribes to any channels we + # were listening to when we were disconnected + self.connection.register_connect_callback(self.on_connect) + if self.push_handler_func is not None: + self.connection._parser.set_pubsub_push_handler(self.push_handler_func) + self._event_dispatcher.dispatch( + AfterPubSubConnectionInstantiationEvent( + self.connection, self.connection_pool, ClientType.SYNC, self._lock + ) + ) connection = self.connection kwargs = {"check_health": not self.subscribed} if not self.subscribed: @@ -1127,6 +1135,7 @@ def ssubscribe(self, *args, target_node=None, **kwargs): new_s_channels = dict.fromkeys(args) new_s_channels.update(kwargs) ret_val = self.execute_command("SSUBSCRIBE", *new_s_channels.keys()) + # update the s_channels dict AFTER we send the command. we don't want to # subscribe twice to these channels, once for the command and again # for the reconnection. diff --git a/redis/cluster.py b/redis/cluster.py index 33b54b1bed..ea91909a25 100644 --- a/redis/cluster.py +++ b/redis/cluster.py @@ -2051,6 +2051,7 @@ def __init__( node=None, host=None, port=None, + replica=False, push_handler_func=None, event_dispatcher: Optional["EventDispatcher"] = None, **kwargs, @@ -2069,6 +2070,7 @@ def __init__( :type port: int """ self.node = None + self.replica = replica self.set_pubsub_node(redis_cluster, node, host, port) connection_pool = ( None @@ -2218,7 +2220,7 @@ def get_sharded_message( if message["channel"] in self.pending_unsubscribe_shard_channels: self.pending_unsubscribe_shard_channels.remove(message["channel"]) self.shard_channels.pop(message["channel"], None) - node = self.cluster.get_node_from_key(message["channel"]) + node = self.cluster.get_node_from_key(message["channel"], self.replica) if self.node_pubsub_mapping[node.name].subscribed is False: self.node_pubsub_mapping.pop(node.name) if not self.channels and not self.patterns and not self.shard_channels: @@ -2235,7 +2237,7 @@ def ssubscribe(self, *args, **kwargs): s_channels = dict.fromkeys(args) s_channels.update(kwargs) for s_channel, handler in s_channels.items(): - node = self.cluster.get_node_from_key(s_channel) + node = self.cluster.get_node_from_key(s_channel, self.replica) pubsub = self._get_node_pubsub(node) if handler: pubsub.ssubscribe(**{s_channel: handler}) @@ -2256,7 +2258,7 @@ def sunsubscribe(self, *args): args = self.shard_channels for s_channel in args: - node = self.cluster.get_node_from_key(s_channel) + node = self.cluster.get_node_from_key(s_channel, self.replica) p = self._get_node_pubsub(node) p.sunsubscribe(s_channel) self.pending_unsubscribe_shard_channels.update( @@ -2612,6 +2614,16 @@ def __init__(self, args, options=None, position=None): self.asking = False self.command_policies: Optional[CommandPolicies] = None + def __repr__(self): + return ( + f"{self.__class__.__name__}<" + f"args={repr(self.args)}," + f"options={repr(self.options)}," + f"position={self.position}," + f"result={repr(self.result)}" + ">" + ) + class NodeCommands: """ """ @@ -2623,6 +2635,14 @@ def __init__(self, parse_response, connection_pool, connection): self.connection = connection self.commands = [] + def __repr__(self): + return ( + f"{self.__class__.__name__}<" + f"connection={repr(self.connection)}," + f"commands={repr(self.commands)}" + ">" + ) + def append(self, c): """ """ self.commands.append(c) @@ -3045,14 +3065,21 @@ def _send_cluster_commands( redis_node = self._pipe.get_redis_connection(node) try: connection = get_connection(redis_node) - except (ConnectionError, TimeoutError): + except BaseException as e: for n in nodes.values(): n.connection_pool.release(n.connection) - # Connection retries are being handled in the node's - # Retry object. Reinitialize the node -> slot table. - self._nodes_manager.initialize() - if is_default_node: - self._pipe.replace_default_node() + n.connection = None + nodes = {} + if self._pipe.retry and self._pipe.retry.is_supported_error(e): + backoff = self._pipe.retry._backoff.compute(0) + if backoff > 0: + time.sleep(backoff) + if isinstance(e, (ConnectionError, TimeoutError)): + # Connection retries are being handled in the node's + # Retry object. Reinitialize the node -> slot table. + self._nodes_manager.initialize() + if is_default_node: + self._pipe.replace_default_node() raise nodes[node_name] = NodeCommands( redis_node.parse_response, @@ -3077,6 +3104,18 @@ def _send_cluster_commands( for n in node_commands: n.read() + except BaseException: + # if nodes is not empty, a problem must have occurred + # since we can't guarantee the state of the connections, + # disconnect before returning it to the connection pool + for n in nodes.values(): + if n.connection: + n.connection.disconnect() + n.connection_pool.release(n.connection) + if len(nodes) > 0: + time.sleep(0.25) + nodes = {} # Clear to prevent double-release in finally + raise finally: # release all of the redis connections we allocated earlier # back into the connection pool. diff --git a/redis/commands/core.py b/redis/commands/core.py index 908895a846..7bda3ddc2a 100644 --- a/redis/commands/core.py +++ b/redis/commands/core.py @@ -5949,9 +5949,7 @@ def __call__( except NoScriptError: # Maybe the client is pointed to a different server than the client # that created this instance? - # Overwrite the sha just in case there was a discrepancy. - self.sha = client.script_load(self.script) - return client.evalsha(self.sha, len(keys), *args) + return client.eval(self.script, len(keys), *args) def get_encoder(self): """Get the encoder to encode string scripts into bytes.""" @@ -6020,9 +6018,7 @@ async def __call__( except NoScriptError: # Maybe the client is pointed to a different server than the client # that created this instance? - # Overwrite the sha just in case there was a discrepancy. - self.sha = await client.script_load(self.script) - return await client.evalsha(self.sha, len(keys), *args) + return await client.eval(self.script, len(keys), *args) class PubSubCommands(CommandsProtocol): diff --git a/redis/retry.py b/redis/retry.py index 225e431eb2..fc54dc43b2 100644 --- a/redis/retry.py +++ b/redis/retry.py @@ -71,6 +71,10 @@ def update_retries(self, value: int) -> None: """ self._retries = value + def is_supported_error(self, error: Exception) -> bool: + """Check if the error is one of the supported error types.""" + return isinstance(error, self._supported_errors) + class Retry(AbstractRetry[Exception]): __hash__ = AbstractRetry.__hash__ diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py index db313e2437..731503d845 100644 --- a/tests/test_pubsub.py +++ b/tests/test_pubsub.py @@ -9,6 +9,7 @@ import pytest import redis +from redis._parsers.encoders import Encoder from redis.exceptions import ConnectionError from .conftest import ( @@ -42,6 +43,31 @@ def wait_for_message( return None +def test_on_connect_resubscribes_shard_channels_grouped_by_slot(): + connection_pool = mock.Mock() + connection_pool.get_encoder.return_value = Encoder("utf-8", "strict", False) + + pubsub = redis.client.PubSub(connection_pool) + handler_a = mock.Mock() + handler_b = mock.Mock() + handler_c = mock.Mock() + pubsub.shard_channels = { + b"{same-slot}:a": handler_a, + b"{same-slot}:b": handler_b, + b"{other-slot}:c": handler_c, + } + pubsub.ssubscribe = mock.Mock() + + pubsub.on_connect(mock.Mock()) + + resubscribe_groups = [call.kwargs for call in pubsub.ssubscribe.call_args_list] + assert { + "{same-slot}:a": handler_a, + "{same-slot}:b": handler_b, + } in resubscribe_groups + assert {"{other-slot}:c": handler_c} in resubscribe_groups + + def make_message(type, channel, data, pattern=None): return { "type": type, @@ -1160,3 +1186,22 @@ def get_msg(): # the timeout on the read should not cause disconnect assert is_connected() + + +@pytest.mark.onlynoncluster +class TestConnectionLeak: + def test_connection_leak(self, r: redis.Redis): + pubsub = r.pubsub() + + def test(): + tid = threading.get_ident() + pubsub.subscribe(f"foo{tid}") + + threads = [threading.Thread(target=test) for _ in range(10)] + for thread in threads: + thread.start() + + for thread in threads: + thread.join() + + assert r.connection_pool._created_connections == 2