From 693f0b91b4381f5a672cb93ff8113abd1dc4957c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Kai=20M=C3=BChlbauer?=
Date: Thu, 4 Jan 2024 07:41:35 +0100
Subject: [PATCH] ENH: vendor SerializableLock from dask and use as default
 backend lock, adapt tests (#8571)

* vendor SerializableLock from dask, adapt tests

* Update doc/whats-new.rst

---------

Co-authored-by: Deepak Cherian
---
 doc/whats-new.rst                |  2 +
 xarray/backends/locks.py         | 84 +++++++++++++++++++++++++++++---
 xarray/tests/test_backends.py    |  2 -
 xarray/tests/test_distributed.py |  4 +-
 4 files changed, 80 insertions(+), 12 deletions(-)

diff --git a/doc/whats-new.rst b/doc/whats-new.rst
index 3a6c15d1704..94b85ea224e 100644
--- a/doc/whats-new.rst
+++ b/doc/whats-new.rst
@@ -44,6 +44,8 @@ Bug fixes
 - Reverse index output of bottleneck's rolling move_argmax/move_argmin
   functions (:issue:`8541`, :pull:`8552`).
   By `Kai Mühlbauer `_.
+- Vendor `SerializableLock` from dask and use as default lock for netcdf4 backends (:issue:`8442`, :pull:`8571`).
+  By `Kai Mühlbauer `_.

 Documentation

diff --git a/xarray/backends/locks.py b/xarray/backends/locks.py
index bba12a29609..045ee522fa8 100644
--- a/xarray/backends/locks.py
+++ b/xarray/backends/locks.py
@@ -2,15 +2,83 @@

 import multiprocessing
 import threading
+import uuid
 import weakref
-from collections.abc import MutableMapping
-from typing import Any
-
-try:
-    from dask.utils import SerializableLock
-except ImportError:
-    # no need to worry about serializing the lock
-    SerializableLock = threading.Lock  # type: ignore
+from collections.abc import Hashable, MutableMapping
+from typing import Any, ClassVar
+from weakref import WeakValueDictionary
+
+
+# SerializableLock is adapted from Dask:
+# https://github.com/dask/dask/blob/74e898f0ec712e8317ba86cc3b9d18b6b9922be0/dask/utils.py#L1160-L1224
+# Used under the terms of Dask's license, see licenses/DASK_LICENSE.
+class SerializableLock:
+    """A Serializable per-process Lock
+
+    This wraps a normal ``threading.Lock`` object and satisfies the same
+    interface. However, this lock can also be serialized and sent to different
+    processes. It will not block concurrent operations between processes (for
+    this you should look at ``dask.multiprocessing.Lock`` or ``locket.lock_file``)
+    but will consistently deserialize into the same lock.
+
+    So if we make a lock in one process::
+
+        lock = SerializableLock()
+
+    And then send it over to another process multiple times::
+
+        bytes = pickle.dumps(lock)
+        a = pickle.loads(bytes)
+        b = pickle.loads(bytes)
+
+    Then the deserialized objects will operate as though they were the same
+    lock, and collide as appropriate.
+
+    This is useful for consistently protecting resources on a per-process
+    level.
+
+    The creation of locks is itself not threadsafe.
+    """
+
+    _locks: ClassVar[
+        WeakValueDictionary[Hashable, threading.Lock]
+    ] = WeakValueDictionary()
+    token: Hashable
+    lock: threading.Lock
+
+    def __init__(self, token: Hashable | None = None):
+        self.token = token or str(uuid.uuid4())
+        if self.token in SerializableLock._locks:
+            self.lock = SerializableLock._locks[self.token]
+        else:
+            self.lock = threading.Lock()
+            SerializableLock._locks[self.token] = self.lock
+
+    def acquire(self, *args, **kwargs):
+        return self.lock.acquire(*args, **kwargs)
+
+    def release(self, *args, **kwargs):
+        return self.lock.release(*args, **kwargs)
+
+    def __enter__(self):
+        self.lock.__enter__()
+
+    def __exit__(self, *args):
+        self.lock.__exit__(*args)
+
+    def locked(self):
+        return self.lock.locked()
+
+    def __getstate__(self):
+        return self.token
+
+    def __setstate__(self, token):
+        self.__init__(token)
+
+    def __str__(self):
+        return f"<{self.__class__.__name__}: {self.token}>"
+
+    __repr__ = __str__


 # Locks used by multiple backends.

diff --git a/xarray/tests/test_backends.py b/xarray/tests/test_backends.py
index 86986cedb06..dac241834d1 100644
--- a/xarray/tests/test_backends.py
+++ b/xarray/tests/test_backends.py
@@ -432,8 +432,6 @@ def test_dataset_compute(self) -> None:
         assert_identical(expected, computed)

     def test_pickle(self) -> None:
-        if not has_dask:
-            pytest.xfail("pickling requires dask for SerializableLock")
         expected = Dataset({"foo": ("x", [42])})
         with self.roundtrip(expected, allow_cleanup_failure=ON_WINDOWS) as roundtripped:
             with roundtripped:

diff --git a/xarray/tests/test_distributed.py b/xarray/tests/test_distributed.py
index bfc37121597..aa53bcf329b 100644
--- a/xarray/tests/test_distributed.py
+++ b/xarray/tests/test_distributed.py
@@ -27,7 +27,7 @@
 )

 import xarray as xr
-from xarray.backends.locks import HDF5_LOCK, CombinedLock
+from xarray.backends.locks import HDF5_LOCK, CombinedLock, SerializableLock
 from xarray.tests import (
     assert_allclose,
     assert_identical,
@@ -273,7 +273,7 @@ async def test_async(c, s, a, b) -> None:


 def test_hdf5_lock() -> None:
-    assert isinstance(HDF5_LOCK, dask.utils.SerializableLock)
+    assert isinstance(HDF5_LOCK, SerializableLock)


 @gen_cluster(client=True)
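
A minimal usage sketch of the behaviour the vendored docstring describes (illustrative
only, not part of the patch), assuming the patch above is applied and
``SerializableLock`` is importable from ``xarray.backends.locks``::

    # Two deserialized copies of one SerializableLock share the same
    # underlying threading.Lock within a process, keyed by their token.
    import pickle

    from xarray.backends.locks import SerializableLock

    lock = SerializableLock()
    payload = pickle.dumps(lock)  # picklable without dask installed
    a = pickle.loads(payload)
    b = pickle.loads(payload)

    assert a.token == b.token == lock.token
    assert a.lock is b.lock is lock.lock  # one per-process threading.Lock

    with a:                    # acquiring through one copy ...
        assert b.locked()      # ... is observable through the other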