Skip to content

Commit

Permalink
Remove force=True from release and check if lock is actually held. Ra…
Browse files Browse the repository at this point in the history
…ise error if expired or not acquired. Closes #25.
  • Loading branch information
ionelmc committed Dec 26, 2015
1 parent 15749c0 commit ecb405b
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 59 deletions.
16 changes: 6 additions & 10 deletions docs/usage.rst
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,9 @@ The above example could be rewritten using context manager::
time.sleep(5)

In cases, where lock not necessarily in acquired state, and
user need to ensure, that it's released, ``force`` parameter could be used::

lock = Lock(conn, "foo")
try:
if lock.acquire(block=False):
print("Got the lock. Do crazy dance")
else:
print("Didn't get the lock. Do normal dance")
finally:
lock.release(force=True)
user need to ensure, that it has a matching ``id``, example::

lock1 = Lock(conn, "foo")
lock1.acquire()
lock2 = Lock(conn, "foo", id=lock1.id)
lock2.release()
50 changes: 30 additions & 20 deletions src/redis_lock/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,15 @@

logger = getLogger(__name__)

# Check if the id match. If not, return an error code.
UNLOCK_SCRIPT = b"""
if redis.call("get", KEYS[1]) == ARGV[1] then
if redis.call("get", KEYS[1]) ~= ARGV[1] then
return 1
else
redis.call("del", KEYS[2])
redis.call("lpush", KEYS[2], 1)
redis.call("expire", KEYS[2], 1)
return redis.call("del", KEYS[1])
else
redis.call("del", KEYS[1])
return 0
end
"""
Expand Down Expand Up @@ -128,6 +130,9 @@ def __init__(self, redis_client, name, expire=None, id=None, auto_renewal=False)
:param id:
The ID (redis value) the lock should have. A random value is
generated when left at the default.
Note that if you specify this then the lock is marked as "held". Acquires
won't be possible.
:param auto_renewal:
If set to True, Lock will automatically renew the lock so that it
doesn't expire for as long as the lock is held (acquire() called
Expand All @@ -145,7 +150,7 @@ def __init__(self, redis_client, name, expire=None, id=None, auto_renewal=False)
self._client = redis_client
self._expire = expire if expire is None else int(expire)
self._id = urandom(16) if id is None else id
self._held = False
self._held = id is not None
self._name = 'lock:'+name
self._signal = 'lock-signal:'+name
self._lock_renewal_interval = expire*2/3 if auto_renewal else None
Expand Down Expand Up @@ -280,26 +285,31 @@ def __enter__(self):
assert acquired, "Lock wasn't acquired, but blocking=True"
return self

def __exit__(self, exc_type=None, exc_value=None, traceback=None, force=False):
if not (self._held or force):
raise NotAcquired("This Lock instance didn't acquire the lock.")
if self._lock_renewal_thread is not None:
self._stop_lock_renewer()
logger.debug("Releasing %r.", self._name)
_eval_script(self._client, UNLOCK,
2, self._name, self._signal, self._id)
def __exit__(self, exc_type=None, exc_value=None, traceback=None):
self.release()

def release(self):
"""Releases the lock, that was acquired with the same object.
self._held = False
.. note::
def release(self, force=False):
"""Releases the lock, that was acquired in the same Python context.
If you want to release a lock that you acquired in a different place you have two choices:
:param force:
If ``False`` - fail with exception if this instance was not in
acquired state in the same Python context.
If ``True`` - fail silently.
* Use ``Lock("name", id=id_from_other_place).release()``
* Use ``Lock("name").reset()``
"""
return self.__exit__(force=force)
if not self._held:
raise NotAcquired("This Lock instance didn't acquire the lock.")
if self._lock_renewal_thread is not None:
self._stop_lock_renewer()
logger.debug("Releasing %r.", self._name)
error = _eval_script(self._client, UNLOCK, self._name, self._signal, args=(self._id,))
if error == 1:
raise NotAcquired("Lock %s is not acquired or it already expired." % self._name)
elif error:
raise RuntimeError("Unsupported error code %s from EXTEND script." % error)
else:
self._held = False


class InterruptableThread(threading.Thread):
Expand Down
2 changes: 1 addition & 1 deletion tests/helper.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from __future__ import print_function
from __future__ import print_function, division

import logging
import os
Expand Down
66 changes: 38 additions & 28 deletions tests/test_redis_lock.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from __future__ import print_function
from __future__ import print_function, division

import os
import platform
Expand Down Expand Up @@ -105,9 +105,10 @@ def test_timeout(conn):


def test_timeout_expire(conn):
with Lock(conn, "foobar", expire=1):
lock = Lock(conn, "foobar")
assert lock.acquire(timeout=2)
lock1 = Lock(conn, "foobar", expire=1)
lock1.acquire()
lock2 = Lock(conn, "foobar")
assert lock2.acquire(timeout=2)


def test_timeout_expire_with_renewal(conn):
Expand Down Expand Up @@ -152,18 +153,19 @@ def test_invalid_timeout(conn):


def test_expire(conn):
with Lock(conn, "foobar", expire=TIMEOUT/4):
with TestProcess(sys.executable, HELPER, 'test_expire') as proc:
with dump_on_error(proc.read):
name = 'lock:foobar'
wait_for_strings(
proc.read, TIMEOUT,
'Getting %r ...' % name,
'Got lock for %r.' % name,
'Releasing %r.' % name,
'UNLOCK_SCRIPT not cached.',
'DIED.',
)
lock = Lock(conn, "foobar", expire=TIMEOUT/4)
lock.acquire()
with TestProcess(sys.executable, HELPER, 'test_expire') as proc:
with dump_on_error(proc.read):
name = 'lock:foobar'
wait_for_strings(
proc.read, TIMEOUT,
'Getting %r ...' % name,
'Got lock for %r.' % name,
'Releasing %r.' % name,
'UNLOCK_SCRIPT not cached.',
'DIED.',
)
lock = Lock(conn, "foobar")
try:
assert lock.acquire(blocking=False) == True
Expand Down Expand Up @@ -216,11 +218,11 @@ def test_extend_another_instance(conn):
"""
name = 'foobar'
key_name = 'lock:' + name
lock = Lock(conn, name, id='spam', expire=100)
lock = Lock(conn, name, expire=100)
lock.acquire()
assert 0 <= conn.ttl(key_name) <= 100

another_lock = Lock(conn, name, id='spam')
another_lock = Lock(conn, name, id=lock.id)
another_lock.extend(1000)

assert conn.ttl(key_name) > 100
Expand All @@ -232,15 +234,16 @@ def test_extend_another_instance_different_id_fail(conn):
"""
name = 'foobar'
key_name = 'lock:' + name
lock = Lock(conn, name, expire=100, id='spam')
lock = Lock(conn, name, expire=100)
lock.acquire()
assert 0 <= conn.ttl(key_name) <= 100

another_lock = Lock(conn, name, id='eggs')
another_lock = Lock(conn, name)
with pytest.raises(NotAcquired):
another_lock.extend(1000)

assert conn.ttl(key_name) <= 100
assert lock.id != another_lock.id


def test_double_acquire(conn):
Expand Down Expand Up @@ -353,11 +356,12 @@ def workerfn(go, count_lock, count):


def test_reset(conn):
with Lock(conn, "foobar") as lock:
lock.reset()
new_lock = Lock(conn, "foobar")
new_lock.acquire(blocking=False)
new_lock.release()
lock = Lock(conn, "foobar")
lock.reset()
new_lock = Lock(conn, "foobar")
new_lock.acquire(blocking=False)
new_lock.release()
pytest.raises(NotAcquired, lock.release)


def test_reset_all(conn):
Expand All @@ -379,8 +383,12 @@ def test_owner_id(conn):
lock = Lock(conn, "foobar-tok", expire=TIMEOUT/4, id=unique_identifier)
lock_id = lock.id
assert lock_id == unique_identifier
lock.acquire(blocking=False)
assert lock.get_owner_id() == unique_identifier


def test_get_owner_id(conn):
lock = Lock(conn, "foobar-tok")
lock.acquire()
assert lock.get_owner_id() == lock.id
lock.release()


Expand All @@ -395,7 +403,9 @@ def test_token(conn):
def test_bogus_release(conn):
lock = Lock(conn, "foobar-tok")
pytest.raises(NotAcquired, lock.release)
lock.release(force=True)
lock.acquire()
lock2 = Lock(conn, "foobar-tok", id=lock.id)
lock2.release()


def test_release_from_nonblocking_leaving_garbage(conn):
Expand Down

0 comments on commit ecb405b

Please sign in to comment.