Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add AnyIO and type-hints; other modernization. #22

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion fair_async_rlock/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from fair_async_rlock.fair_async_rlock import *
from .fair_async_rlock import FairAsyncRLock

__all__ = ["FairAsyncRLock"]
52 changes: 27 additions & 25 deletions fair_async_rlock/fair_async_rlock.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,31 @@
import asyncio
from collections import deque
from typing import List, Union

__all__ = [
'FairAsyncRLock'
]
import anyio

__all__: List[str] = ["FairAsyncRLock"]


class FairAsyncRLock:
"""
A fair reentrant lock for async programming. Fair means that it respects the order of acquisition.
"""
"""A fair reentrant lock for async programming. Fair means that it respects the order of acquisition."""

def __init__(self):
self._owner: asyncio.Task | None = None
def __init__(self) -> None:
self._owner: Union[anyio.TaskInfo, None] = None
self._count = 0
self._owner_transfer = False
self._queue = deque()
self._queue: deque[anyio.Event] = deque()

def is_owner(self, task=None):
def is_owner(self, task: Union[anyio.TaskInfo, None] = None) -> bool:
if task is None:
task = asyncio.current_task()
task = anyio.get_current_task()
return self._owner == task

def locked(self) -> bool:
return self._owner is not None

async def acquire(self):
async def acquire(self) -> None:
"""Acquire the lock."""
me = asyncio.current_task()
me = anyio.get_current_task()

# If the lock is reentrant, acquire it immediately
if self.is_owner(task=me):
Expand All @@ -41,7 +39,7 @@ async def acquire(self):
return

# Create an event for this task, to notify when it's ready for acquire
event = asyncio.Event()
event = anyio.Event()
self._queue.append(event)

# Wait for the lock to be free, then acquire
Expand All @@ -50,17 +48,19 @@ async def acquire(self):
self._owner_transfer = False
self._owner = me
self._count = 1
except asyncio.CancelledError:
except anyio.get_cancelled_exc_class():
try: # if in queue, then cancelled before release
self._queue.remove(event)
except ValueError: # otherwise, release happened, this was next, and we simulate passing on
except (
ValueError
): # otherwise, release happened, this was next, and we simulate passing on
self._owner_transfer = False
self._owner = me
self._count = 1
self._current_task_release()
raise

def _current_task_release(self):
def _current_task_release(self) -> None:
self._count -= 1
if self._count == 0:
self._owner = None
Expand All @@ -71,21 +71,23 @@ def _current_task_release(self):
# Setting this here prevents another task getting lock until owner transfer.
self._owner_transfer = True

def release(self):
"""Release the lock"""
me = asyncio.current_task()
def release(self) -> None:
"""Release the lock."""
me = anyio.get_current_task()

if self._owner is None:
raise RuntimeError(f"Cannot release un-acquired lock. {me} tried to release.")
msg = f"Cannot release un-acquired lock. {me} tried to release."
raise RuntimeError(msg)

if not self.is_owner(task=me):
raise RuntimeError(f"Cannot release foreign lock. {me} tried to unlock {self._owner}.")
msg = f"Cannot release foreign lock. {me} tried to unlock {self._owner}."
raise RuntimeError(msg)

self._current_task_release()

async def __aenter__(self):
async def __aenter__(self) -> "FairAsyncRLock":
await self.acquire()
return self

async def __aexit__(self, exc_type, exc, tb):
async def __aexit__(self, *exc: object) -> None:
self.release()
Empty file added fair_async_rlock/py.typed
Empty file.
Empty file.
Loading
Loading