Skip to content

Commit

Permalink
Test and improve the synchronization code
Browse files Browse the repository at this point in the history
Added
-----

* Add tests to bring `Synchronization.py` to 100% coverage
* Add type annotations to the synchronization primitives
* Add `typing_extensions` as a Python 3.9 requirement;
  this is needed to annotate that the `synchronized()` decorator
  takes a function with a certain set of parameter and return types
  and returns a function with the same parameter and return types

Changed
-------

* Make the synchronized function's `self` parameter explicit;
  this allows the `self` parameter to be type-annotated
  so the dependency on the `self.mutex` attribute is explicit
* Use the `self.mutex` lock as a context manager
* Support keyword arguments to synchronized functions

Fixed
-----

* Wrap synchronized functions correctly;
  previous behavior was to lose the function name and docstring

Removed
-------

* Remove `print()` lines that are commented out
* Remove a `bytes` instance check; method names can only be strings
  • Loading branch information
kurtmckee committed Oct 13, 2024
1 parent f162aa2 commit d984eec
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 16 deletions.
3 changes: 3 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ def run(self):
swig_opts=["-outdir", "src/smartcard/scard"] + platform_swig_opts,
)
],
"install_requires": [
"typing_extensions; python_version=='3.9'",
],
"extras_require": {
"Gui": ["wxPython"],
},
Expand Down
44 changes: 28 additions & 16 deletions src/smartcard/Synchronization.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,39 +8,51 @@
keyword, from Peter Norvig.
"""

from threading import RLock
from __future__ import annotations

import functools
import sys
import threading
from collections.abc import Iterable
from typing import Any, Callable, Protocol, TypeVar

def synchronized(method):
if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec

def f(*args):
self = args[0]
self.mutex.acquire()
# print(method.__name__, 'acquired')
try:
return method(*args)
finally:
self.mutex.release()
# print(method.__name__, 'released')

T = TypeVar("T")
P = ParamSpec("P")


def synchronized(method: Callable[P, T]) -> Callable[P, T]:
@functools.wraps(method)
def f(self: _SynchronizationProtocol, *args: Any, **kwargs: Any) -> Any:
with self.mutex:
return method(self, *args, **kwargs)

return f


def synchronize(klass, names=None):
def synchronize(klass: type, names: str | Iterable[str] | None = None) -> None:
"""Synchronize methods in the given class.
Only synchronize the methods whose names are
given, or all methods if names=None."""

if isinstance(names, (str, bytes)):
if isinstance(names, str):
names = names.split()
for name, val in list(klass.__dict__.items()):
if callable(val) and name != "__init__" and (names is None or name in names):
# print("synchronizing", name)
setattr(klass, name, synchronized(val))


class Synchronization:
class _SynchronizationProtocol(Protocol):
mutex: threading.Lock | threading.RLock


class Synchronization(_SynchronizationProtocol):
# You can create your own self.mutex, or inherit from this class:

def __init__(self):
self.mutex = RLock()
self.mutex = threading.RLock()
93 changes: 93 additions & 0 deletions test/test_synchronization.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
from __future__ import annotations

import pytest

import smartcard.Synchronization


@pytest.mark.parametrize(
"defined_methods, names, modified_methods",
(
# Nothing should be wrapped (str version)
pytest.param(set(), "", set(), id="wrap nothing (str)"),
pytest.param({"a"}, "", set(), id="wrap nothing (a, str)"),
pytest.param({"a", "b"}, "", set(), id="wrap nothing (a+b, str)"),
# Nothing should be wrapped (iterable version)
pytest.param(set(), [], set(), id="wrap nothing (list)"),
pytest.param({"a"}, [], set(), id="wrap nothing (a, list)"),
pytest.param({"a", "b"}, [], set(), id="wrap nothing (a+b, list)"),
# Everything should be wrapped
pytest.param(set(), None, set(), id="wrap all"),
pytest.param({"a"}, None, {"a"}, id="wrap all (a)"),
pytest.param({"a", "b"}, None, {"a", "b"}, id="wrap all (a+b)"),
# Only "a" should be wrapped (str version)
pytest.param({"a"}, "a", {"a"}, id="wrap a only (a, str)"),
pytest.param({"a", "b"}, "a", {"a"}, id="wrap a only (a+b, str)"),
# Only "a" should be wrapped (list version)
pytest.param({"a"}, ["a"], {"a"}, id="wrap a only (a, list)"),
pytest.param({"a", "b"}, ["a"], {"a"}, id="wrap a only (a+b, list)"),
),
)
def test_synchronize(
defined_methods: set[str],
names: None | str | list[str],
modified_methods: set[str],
):
"""Verify synchronize() wraps class methods as expected."""

method_map = {method: lambda self: None for method in defined_methods}
class_ = type("A", (object,), method_map)

smartcard.Synchronization.synchronize(class_, names)

for modified_method in modified_methods:
assert getattr(class_, modified_method) is not method_map[modified_method]
for unmodified_method in defined_methods - modified_methods:
assert getattr(class_, unmodified_method) is method_map[unmodified_method]


def test_synchronization_reentrant_lock():
"""Verify Synchronization mutex locks are re-entrant by default."""

class A(smartcard.Synchronization.Synchronization):
def level_1(self):
self.level_2()

def level_2(self):
return self

smartcard.Synchronization.synchronize(A)

instance = A()
# If the synchronization lock is NOT re-entrant by default,
# the test suite will hang when it reaches this line.
instance.level_1()


def test_synchronization_wrapping():
"""Verify synchronized functions have correct names and docstrings."""

class A(smartcard.Synchronization.Synchronization):
def apple(self):
"""KEEP ME"""

smartcard.Synchronization.synchronize(A)

assert A.apple.__name__ == "apple"
assert "KEEP ME" in A.apple.__doc__


def test_synchronization_kwargs():
"""Verify synchronized functions support arguments and keyword arguments."""

class A(smartcard.Synchronization.Synchronization):
def positional_only(self, positional, /):
return positional

def keyword_only(self, *, keyword):
return keyword

smartcard.Synchronization.synchronize(A)

A().positional_only(True)
A().keyword_only(keyword=True)

0 comments on commit d984eec

Please sign in to comment.