From d984eecf0e3481c9c80b3d6921ca19d43b9da2f4 Mon Sep 17 00:00:00 2001 From: Kurt McKee Date: Sun, 13 Oct 2024 17:15:19 -0500 Subject: [PATCH] Test and improve the synchronization code Added ----- * Add tests to bring `Synchronization.py` to 100% coverage * Add type annotations to the synchronization primitives * Add `typing_extensions` as a Python 3.9 requirement; this is needed to annotate that the `synchronized()` decorator takes a function with a certain set of parameter and return types and returns a function with the same parameter and return types Changed ------- * Make the synchronized function's `self` parameter explicit; this allows the `self` parameter to be type-annotated so the dependency on the `self.mutex` attribute is explicit * Use the `self.mutex` lock as a context manager * Support keyword arguments to synchronized functions Fixed ----- * Wrap synchronized functions correctly; previous behavior was to lose the function name and docstring Removed ------- * Remove `print()` lines that are commented out * Remove a `bytes` instance check; method names can only be strings --- setup.py | 3 ++ src/smartcard/Synchronization.py | 44 +++++++++------ test/test_synchronization.py | 93 ++++++++++++++++++++++++++++++++ 3 files changed, 124 insertions(+), 16 deletions(-) create mode 100644 test/test_synchronization.py diff --git a/setup.py b/setup.py index 47fa1d9c..078e398a 100644 --- a/setup.py +++ b/setup.py @@ -128,6 +128,9 @@ def run(self): swig_opts=["-outdir", "src/smartcard/scard"] + platform_swig_opts, ) ], + "install_requires": [ + "typing_extensions; python_version=='3.9'", + ], "extras_require": { "Gui": ["wxPython"], }, diff --git a/src/smartcard/Synchronization.py b/src/smartcard/Synchronization.py index c30fa580..cf6b0e17 100644 --- a/src/smartcard/Synchronization.py +++ b/src/smartcard/Synchronization.py @@ -8,39 +8,51 @@ keyword, from Peter Norvig. """ -from threading import RLock +from __future__ import annotations +import functools +import sys +import threading +from collections.abc import Iterable +from typing import Any, Callable, Protocol, TypeVar -def synchronized(method): +if sys.version_info >= (3, 10): + from typing import ParamSpec +else: + from typing_extensions import ParamSpec - def f(*args): - self = args[0] - self.mutex.acquire() - # print(method.__name__, 'acquired') - try: - return method(*args) - finally: - self.mutex.release() - # print(method.__name__, 'released') + +T = TypeVar("T") +P = ParamSpec("P") + + +def synchronized(method: Callable[P, T]) -> Callable[P, T]: + @functools.wraps(method) + def f(self: _SynchronizationProtocol, *args: Any, **kwargs: Any) -> Any: + with self.mutex: + return method(self, *args, **kwargs) return f -def synchronize(klass, names=None): +def synchronize(klass: type, names: str | Iterable[str] | None = None) -> None: """Synchronize methods in the given class. Only synchronize the methods whose names are given, or all methods if names=None.""" - if isinstance(names, (str, bytes)): + if isinstance(names, str): names = names.split() for name, val in list(klass.__dict__.items()): if callable(val) and name != "__init__" and (names is None or name in names): - # print("synchronizing", name) setattr(klass, name, synchronized(val)) -class Synchronization: +class _SynchronizationProtocol(Protocol): + mutex: threading.Lock | threading.RLock + + +class Synchronization(_SynchronizationProtocol): # You can create your own self.mutex, or inherit from this class: def __init__(self): - self.mutex = RLock() + self.mutex = threading.RLock() diff --git a/test/test_synchronization.py b/test/test_synchronization.py new file mode 100644 index 00000000..996ba445 --- /dev/null +++ b/test/test_synchronization.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +import pytest + +import smartcard.Synchronization + + +@pytest.mark.parametrize( + "defined_methods, names, modified_methods", + ( + # Nothing should be wrapped (str version) + pytest.param(set(), "", set(), id="wrap nothing (str)"), + pytest.param({"a"}, "", set(), id="wrap nothing (a, str)"), + pytest.param({"a", "b"}, "", set(), id="wrap nothing (a+b, str)"), + # Nothing should be wrapped (iterable version) + pytest.param(set(), [], set(), id="wrap nothing (list)"), + pytest.param({"a"}, [], set(), id="wrap nothing (a, list)"), + pytest.param({"a", "b"}, [], set(), id="wrap nothing (a+b, list)"), + # Everything should be wrapped + pytest.param(set(), None, set(), id="wrap all"), + pytest.param({"a"}, None, {"a"}, id="wrap all (a)"), + pytest.param({"a", "b"}, None, {"a", "b"}, id="wrap all (a+b)"), + # Only "a" should be wrapped (str version) + pytest.param({"a"}, "a", {"a"}, id="wrap a only (a, str)"), + pytest.param({"a", "b"}, "a", {"a"}, id="wrap a only (a+b, str)"), + # Only "a" should be wrapped (list version) + pytest.param({"a"}, ["a"], {"a"}, id="wrap a only (a, list)"), + pytest.param({"a", "b"}, ["a"], {"a"}, id="wrap a only (a+b, list)"), + ), +) +def test_synchronize( + defined_methods: set[str], + names: None | str | list[str], + modified_methods: set[str], +): + """Verify synchronize() wraps class methods as expected.""" + + method_map = {method: lambda self: None for method in defined_methods} + class_ = type("A", (object,), method_map) + + smartcard.Synchronization.synchronize(class_, names) + + for modified_method in modified_methods: + assert getattr(class_, modified_method) is not method_map[modified_method] + for unmodified_method in defined_methods - modified_methods: + assert getattr(class_, unmodified_method) is method_map[unmodified_method] + + +def test_synchronization_reentrant_lock(): + """Verify Synchronization mutex locks are re-entrant by default.""" + + class A(smartcard.Synchronization.Synchronization): + def level_1(self): + self.level_2() + + def level_2(self): + return self + + smartcard.Synchronization.synchronize(A) + + instance = A() + # If the synchronization lock is NOT re-entrant by default, + # the test suite will hang when it reaches this line. + instance.level_1() + + +def test_synchronization_wrapping(): + """Verify synchronized functions have correct names and docstrings.""" + + class A(smartcard.Synchronization.Synchronization): + def apple(self): + """KEEP ME""" + + smartcard.Synchronization.synchronize(A) + + assert A.apple.__name__ == "apple" + assert "KEEP ME" in A.apple.__doc__ + + +def test_synchronization_kwargs(): + """Verify synchronized functions support arguments and keyword arguments.""" + + class A(smartcard.Synchronization.Synchronization): + def positional_only(self, positional, /): + return positional + + def keyword_only(self, *, keyword): + return keyword + + smartcard.Synchronization.synchronize(A) + + A().positional_only(True) + A().keyword_only(keyword=True)