Skip to content
Open
2 changes: 1 addition & 1 deletion redis/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def int_or_str(value):
return value


__version__ = "7.1.0"
__version__ = "7.1.0+sb1"

VERSION = tuple(map(int_or_str, __version__.split(".")))

Expand Down
41 changes: 25 additions & 16 deletions redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import re
import threading
import time
from collections import defaultdict
from itertools import chain
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -39,6 +40,7 @@
SSLConnection,
UnixDomainSocketConnection,
)
from redis.crc import key_slot
from redis.credentials import CredentialProvider
from redis.event import (
AfterPooledConnectionsInstantiationEvent,
Expand Down Expand Up @@ -838,6 +840,7 @@ def __init__(
self.health_check_response = [b"pong", self.health_check_response_b]
if self.push_handler_func is None:
_set_info_logger()
self._connection_lock = threading.Lock()
self.reset()

def __enter__(self) -> "PubSub":
Expand Down Expand Up @@ -892,11 +895,14 @@ def on_connect(self, connection) -> None:
}
self.psubscribe(**patterns)
if self.shard_channels:
shard_channels = {
self.encoder.decode(k, force=True): v
for k, v in self.shard_channels.items()
}
self.ssubscribe(**shard_channels)
channels_by_slot = defaultdict(dict)
for k, v in self.shard_channels.items():
key = self.encoder.decode(k, force=True)
slot = key_slot(self.encoder.encode(key))
channels_by_slot[slot][key] = v

for slot, channels in channels_by_slot.items():
self.ssubscribe(**channels)

@property
def subscribed(self) -> bool:
Expand All @@ -911,17 +917,19 @@ def execute_command(self, *args):
# subscribed to one or more channels

if self.connection is None:
self.connection = self.connection_pool.get_connection()
# register a callback that re-subscribes to any channels we
# were listening to when we were disconnected
self.connection.register_connect_callback(self.on_connect)
if self.push_handler_func is not None:
self.connection._parser.set_pubsub_push_handler(self.push_handler_func)
self._event_dispatcher.dispatch(
AfterPubSubConnectionInstantiationEvent(
self.connection, self.connection_pool, ClientType.SYNC, self._lock
)
)
with self._connection_lock:
if self.connection is None:
self.connection = self.connection_pool.get_connection()
# register a callback that re-subscribes to any channels we
# were listening to when we were disconnected
self.connection.register_connect_callback(self.on_connect)
if self.push_handler_func is not None:
self.connection._parser.set_pubsub_push_handler(self.push_handler_func)
self._event_dispatcher.dispatch(
AfterPubSubConnectionInstantiationEvent(
self.connection, self.connection_pool, ClientType.SYNC, self._lock
)
)
connection = self.connection
kwargs = {"check_health": not self.subscribed}
if not self.subscribed:
Expand Down Expand Up @@ -1127,6 +1135,7 @@ def ssubscribe(self, *args, target_node=None, **kwargs):
new_s_channels = dict.fromkeys(args)
new_s_channels.update(kwargs)
ret_val = self.execute_command("SSUBSCRIBE", *new_s_channels.keys())

# update the s_channels dict AFTER we send the command. we don't want to
# subscribe twice to these channels, once for the command and again
# for the reconnection.
Expand Down
57 changes: 48 additions & 9 deletions redis/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -2051,6 +2051,7 @@ def __init__(
node=None,
host=None,
port=None,
replica=False,
push_handler_func=None,
event_dispatcher: Optional["EventDispatcher"] = None,
**kwargs,
Expand All @@ -2069,6 +2070,7 @@ def __init__(
:type port: int
"""
self.node = None
self.replica = replica
self.set_pubsub_node(redis_cluster, node, host, port)
connection_pool = (
None
Expand Down Expand Up @@ -2218,7 +2220,7 @@ def get_sharded_message(
if message["channel"] in self.pending_unsubscribe_shard_channels:
self.pending_unsubscribe_shard_channels.remove(message["channel"])
self.shard_channels.pop(message["channel"], None)
node = self.cluster.get_node_from_key(message["channel"])
node = self.cluster.get_node_from_key(message["channel"], self.replica)
if self.node_pubsub_mapping[node.name].subscribed is False:
self.node_pubsub_mapping.pop(node.name)
if not self.channels and not self.patterns and not self.shard_channels:
Expand All @@ -2235,7 +2237,7 @@ def ssubscribe(self, *args, **kwargs):
s_channels = dict.fromkeys(args)
s_channels.update(kwargs)
for s_channel, handler in s_channels.items():
node = self.cluster.get_node_from_key(s_channel)
node = self.cluster.get_node_from_key(s_channel, self.replica)
pubsub = self._get_node_pubsub(node)
if handler:
pubsub.ssubscribe(**{s_channel: handler})
Expand All @@ -2256,7 +2258,7 @@ def sunsubscribe(self, *args):
args = self.shard_channels

for s_channel in args:
node = self.cluster.get_node_from_key(s_channel)
node = self.cluster.get_node_from_key(s_channel, self.replica)
p = self._get_node_pubsub(node)
p.sunsubscribe(s_channel)
self.pending_unsubscribe_shard_channels.update(
Expand Down Expand Up @@ -2612,6 +2614,16 @@ def __init__(self, args, options=None, position=None):
self.asking = False
self.command_policies: Optional[CommandPolicies] = None

def __repr__(self):
return (
f"{self.__class__.__name__}<"
f"args={repr(self.args)},"
f"options={repr(self.options)},"
f"position={self.position},"
f"result={repr(self.result)}"
">"
)


class NodeCommands:
""" """
Expand All @@ -2623,6 +2635,14 @@ def __init__(self, parse_response, connection_pool, connection):
self.connection = connection
self.commands = []

def __repr__(self):
return (
f"{self.__class__.__name__}<"
f"connection={repr(self.connection)},"
f"commands={repr(self.commands)}"
">"
)

def append(self, c):
""" """
self.commands.append(c)
Expand Down Expand Up @@ -3045,14 +3065,21 @@ def _send_cluster_commands(
redis_node = self._pipe.get_redis_connection(node)
try:
connection = get_connection(redis_node)
except (ConnectionError, TimeoutError):
except BaseException as e:
for n in nodes.values():
n.connection_pool.release(n.connection)
# Connection retries are being handled in the node's
# Retry object. Reinitialize the node -> slot table.
self._nodes_manager.initialize()
if is_default_node:
self._pipe.replace_default_node()
n.connection = None
nodes = {}
if self._pipe.retry and self._pipe.retry.is_supported_error(e):
backoff = self._pipe.retry._backoff.compute(0)
if backoff > 0:
time.sleep(backoff)
if isinstance(e, (ConnectionError, TimeoutError)):
# Connection retries are being handled in the node's
# Retry object. Reinitialize the node -> slot table.
self._nodes_manager.initialize()
if is_default_node:
self._pipe.replace_default_node()
raise
nodes[node_name] = NodeCommands(
redis_node.parse_response,
Expand All @@ -3077,6 +3104,18 @@ def _send_cluster_commands(

for n in node_commands:
n.read()
except BaseException:
# if nodes is not empty, a problem must have occurred
# since we can't guarantee the state of the connections,
# disconnect before returning it to the connection pool
for n in nodes.values():
if n.connection:
n.connection.disconnect()
n.connection_pool.release(n.connection)
if len(nodes) > 0:
time.sleep(0.25)
nodes = {} # Clear to prevent double-release in finally
raise
finally:
# release all of the redis connections we allocated earlier
# back into the connection pool.
Expand Down
8 changes: 2 additions & 6 deletions redis/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5949,9 +5949,7 @@ def __call__(
except NoScriptError:
# Maybe the client is pointed to a different server than the client
# that created this instance?
# Overwrite the sha just in case there was a discrepancy.
self.sha = client.script_load(self.script)
return client.evalsha(self.sha, len(keys), *args)
return client.eval(self.script, len(keys), *args)

def get_encoder(self):
"""Get the encoder to encode string scripts into bytes."""
Expand Down Expand Up @@ -6020,9 +6018,7 @@ async def __call__(
except NoScriptError:
# Maybe the client is pointed to a different server than the client
# that created this instance?
# Overwrite the sha just in case there was a discrepancy.
self.sha = await client.script_load(self.script)
return await client.evalsha(self.sha, len(keys), *args)
return await client.eval(self.script, len(keys), *args)


class PubSubCommands(CommandsProtocol):
Expand Down
4 changes: 4 additions & 0 deletions redis/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,10 @@ def update_retries(self, value: int) -> None:
"""
self._retries = value

def is_supported_error(self, error: Exception) -> bool:
"""Check if the error is one of the supported error types."""
return isinstance(error, self._supported_errors)


class Retry(AbstractRetry[Exception]):
__hash__ = AbstractRetry.__hash__
Expand Down
45 changes: 45 additions & 0 deletions tests/test_pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import pytest
import redis
from redis._parsers.encoders import Encoder
from redis.exceptions import ConnectionError

from .conftest import (
Expand Down Expand Up @@ -42,6 +43,31 @@ def wait_for_message(
return None


def test_on_connect_resubscribes_shard_channels_grouped_by_slot():
connection_pool = mock.Mock()
connection_pool.get_encoder.return_value = Encoder("utf-8", "strict", False)

pubsub = redis.client.PubSub(connection_pool)
handler_a = mock.Mock()
handler_b = mock.Mock()
handler_c = mock.Mock()
pubsub.shard_channels = {
b"{same-slot}:a": handler_a,
b"{same-slot}:b": handler_b,
b"{other-slot}:c": handler_c,
}
pubsub.ssubscribe = mock.Mock()

pubsub.on_connect(mock.Mock())

resubscribe_groups = [call.kwargs for call in pubsub.ssubscribe.call_args_list]
assert {
"{same-slot}:a": handler_a,
"{same-slot}:b": handler_b,
} in resubscribe_groups
assert {"{other-slot}:c": handler_c} in resubscribe_groups


def make_message(type, channel, data, pattern=None):
return {
"type": type,
Expand Down Expand Up @@ -1160,3 +1186,22 @@ def get_msg():

# the timeout on the read should not cause disconnect
assert is_connected()


@pytest.mark.onlynoncluster
class TestConnectionLeak:
def test_connection_leak(self, r: redis.Redis):
pubsub = r.pubsub()

def test():
tid = threading.get_ident()
pubsub.subscribe(f"foo{tid}")

threads = [threading.Thread(target=test) for _ in range(10)]
for thread in threads:
thread.start()

for thread in threads:
thread.join()

assert r.connection_pool._created_connections == 2