Skip to content

Disconnect clients from pool safely #753

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
from distutils.version import StrictVersion
from itertools import chain
from select import select
from copy import copy

import os
import socket
import sys
Expand Down Expand Up @@ -534,6 +536,26 @@ def disconnect(self):
pass
self._sock = None

def shutdown_socket(self):
"""
Shutdown the socket hold by the current connection, called from
the connection pool class u other manager to singal it that has to be
disconnected in a thread safe way. Later the connection instance
will get an error and will call `disconnect` by it self.
"""
try:
self._sock.shutdown(socket.SHUT_RDWR)
except AttributeError:
# either _sock attribute does not exist or
# connection thread removed it.
pass
except OSError as e:
if e.errno == 107:
# Transport endpoint is not connected
pass
else:
raise

def send_packed_command(self, command):
"Send an already packed command to the Redis server"
if not self._sock:
Expand Down Expand Up @@ -950,10 +972,10 @@ def release(self, connection):

def disconnect(self):
"Disconnects all connections in the pool"
all_conns = chain(self._available_connections,
self._in_use_connections)
all_conns = chain(copy(self._available_connections),
copy(self._in_use_connections))
for connection in all_conns:
connection.disconnect()
connection.shutdown_socket()


class BlockingConnectionPool(ConnectionPool):
Expand Down Expand Up @@ -1072,4 +1094,4 @@ def release(self, connection):
def disconnect(self):
"Disconnects all connections in the pool."
for connection in self._connections:
connection.disconnect()
connection.shutdown_socket()
25 changes: 25 additions & 0 deletions tests/test_connection_pool.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import with_statement

import os
import pytest
import redis
Expand Down Expand Up @@ -69,6 +70,30 @@ def test_repr_contains_db_info_unix(self):
expected = 'ConnectionPool<UnixDomainSocketConnection<path=/abc,db=1>>'
assert repr(pool) == expected

def test_disconnect_active_connections(self):

class MyConnection(redis.Connection):

connect_calls = 0

def __init__(self, *args, **kwargs):
super(MyConnection, self).__init__(*args, **kwargs)
self.register_connect_callback(self.count_connect)

def count_connect(self, connection):
MyConnection.connect_calls += 1

pool = self.get_pool(connection_class=MyConnection)
r = redis.StrictRedis(connection_pool=pool)
r.ping()
pool.disconnect()
r.ping()

# If the connection is not disconnected by the pool the
# callback belonging to Connection will be called just
# one time.
assert MyConnection.connect_calls == 2


class TestBlockingConnectionPool(object):
def get_pool(self, connection_kwargs=None, max_connections=10, timeout=20):
Expand Down