diff --git a/redis/lock.py b/redis/lock.py index e6070c4fed..adf90e0458 100644 --- a/redis/lock.py +++ b/redis/lock.py @@ -1,3 +1,4 @@ +import threading import time as mod_time import uuid from redis.exceptions import LockError, WatchError @@ -44,7 +45,8 @@ def __init__(self, redis, name, timeout=None, sleep=0.1, self.sleep = sleep self.blocking = blocking self.blocking_timeout = blocking_timeout - self.token = None + self.local = threading.local() + self.local.token = None if self.timeout and self.sleep > self.timeout: raise LockError("'sleep' must be less than 'timeout'") @@ -79,7 +81,7 @@ def acquire(self, blocking=None, blocking_timeout=None): stop_trying_at = mod_time.time() + self.blocking_timeout while 1: if self.do_acquire(token): - self.token = token + self.local.token = token return True if not blocking: return False @@ -98,10 +100,10 @@ def do_acquire(self, token): def release(self): "Releases the already acquired lock" - if self.token is None: + expected_token = self.local.token + if expected_token is None: raise LockError("Cannot release an unlocked lock") - expected_token = self.token - self.token = None + self.local.token = None self.do_release(expected_token) def do_release(self, expected_token): @@ -122,7 +124,7 @@ def extend(self, additional_time): ``additional_time`` can be specified as an integer or a float, both representing the number of seconds to add. """ - if self.token is None: + if self.local.token is None: raise LockError("Cannot extend an unlocked lock") if self.timeout is None: raise LockError("Cannot extend a lock with no timeout") @@ -132,7 +134,7 @@ def do_extend(self, additional_time): pipe = self.redis.pipeline() pipe.watch(self.name) lock_value = pipe.get(self.name) - if lock_value != self.token: + if lock_value != self.local.token: raise LockError("Cannot extend a lock that's no longer owned") expiration = pipe.pttl(self.name) if expiration is None or expiration < 0: @@ -236,7 +238,7 @@ def do_release(self, expected_token): def do_extend(self, additional_time): additional_time = int(additional_time * 1000) if not bool(self.lua_extend(keys=[self.name], - args=[self.token, additional_time], + args=[self.local.token, additional_time], client=self.redis)): raise LockError("Cannot extend a lock that's no longer owned") return True diff --git a/tests/test_lock.py b/tests/test_lock.py index 028f9a6021..d732ae1f1e 100644 --- a/tests/test_lock.py +++ b/tests/test_lock.py @@ -16,7 +16,7 @@ def get_lock(self, redis, *args, **kwargs): def test_lock(self, sr): lock = self.get_lock(sr, 'foo') assert lock.acquire(blocking=False) - assert sr.get('foo') == lock.token + assert sr.get('foo') == lock.local.token assert sr.ttl('foo') == -1 lock.release() assert sr.get('foo') is None @@ -56,7 +56,7 @@ def test_context_manager(self, sr): # blocking_timeout prevents a deadlock if the lock can't be acquired # for some reason with self.get_lock(sr, 'foo', blocking_timeout=0.2) as lock: - assert sr.get('foo') == lock.token + assert sr.get('foo') == lock.local.token assert sr.get('foo') is None def test_high_sleep_raises_error(self, sr): @@ -77,7 +77,7 @@ def test_releasing_lock_no_longer_owned_raises_error(self, sr): with pytest.raises(LockError): lock.release() # even though we errored, the token is still cleared - assert lock.token is None + assert lock.local.token is None def test_extend_lock(self, sr): lock = self.get_lock(sr, 'foo', timeout=10)