Skip to content

Commit

Permalink
ENH: vendor SerializableLock from dask and use as default backend loc…
Browse files Browse the repository at this point in the history
…k, adapt tests (#8571)

* vendor SerializableLock from dask, adapt tests
* Update doc/whats-new.rst

---------

Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>
  • Loading branch information
kmuehlbauer and dcherian authored Jan 4, 2024
1 parent 92f79a0 commit 693f0b9
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 12 deletions.
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ Bug fixes

- Reverse index output of bottleneck's rolling move_argmax/move_argmin functions (:issue:`8541`, :pull:`8552`).
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
- Vendor `SerializableLock` from dask and use as default lock for netcdf4 backends (:issue:`8442`, :pull:`8571`).
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.


Documentation
Expand Down
84 changes: 76 additions & 8 deletions xarray/backends/locks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,83 @@

import multiprocessing
import threading
import uuid
import weakref
from collections.abc import MutableMapping
from typing import Any

try:
from dask.utils import SerializableLock
except ImportError:
# no need to worry about serializing the lock
SerializableLock = threading.Lock # type: ignore
from collections.abc import Hashable, MutableMapping
from typing import Any, ClassVar
from weakref import WeakValueDictionary


# SerializableLock is adapted from Dask:
# https://github.com/dask/dask/blob/74e898f0ec712e8317ba86cc3b9d18b6b9922be0/dask/utils.py#L1160-L1224
# Used under the terms of Dask's license, see licenses/DASK_LICENSE.
class SerializableLock:
"""A Serializable per-process Lock
This wraps a normal ``threading.Lock`` object and satisfies the same
interface. However, this lock can also be serialized and sent to different
processes. It will not block concurrent operations between processes (for
this you should look at ``dask.multiprocessing.Lock`` or ``locket.lock_file``
but will consistently deserialize into the same lock.
So if we make a lock in one process::
lock = SerializableLock()
And then send it over to another process multiple times::
bytes = pickle.dumps(lock)
a = pickle.loads(bytes)
b = pickle.loads(bytes)
Then the deserialized objects will operate as though they were the same
lock, and collide as appropriate.
This is useful for consistently protecting resources on a per-process
level.
The creation of locks is itself not threadsafe.
"""

_locks: ClassVar[
WeakValueDictionary[Hashable, threading.Lock]
] = WeakValueDictionary()
token: Hashable
lock: threading.Lock

def __init__(self, token: Hashable | None = None):
self.token = token or str(uuid.uuid4())
if self.token in SerializableLock._locks:
self.lock = SerializableLock._locks[self.token]
else:
self.lock = threading.Lock()
SerializableLock._locks[self.token] = self.lock

def acquire(self, *args, **kwargs):
return self.lock.acquire(*args, **kwargs)

def release(self, *args, **kwargs):
return self.lock.release(*args, **kwargs)

def __enter__(self):
self.lock.__enter__()

def __exit__(self, *args):
self.lock.__exit__(*args)

def locked(self):
return self.lock.locked()

def __getstate__(self):
return self.token

def __setstate__(self, token):
self.__init__(token)

def __str__(self):
return f"<{self.__class__.__name__}: {self.token}>"

__repr__ = __str__


# Locks used by multiple backends.
Expand Down
2 changes: 0 additions & 2 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,8 +432,6 @@ def test_dataset_compute(self) -> None:
assert_identical(expected, computed)

def test_pickle(self) -> None:
if not has_dask:
pytest.xfail("pickling requires dask for SerializableLock")
expected = Dataset({"foo": ("x", [42])})
with self.roundtrip(expected, allow_cleanup_failure=ON_WINDOWS) as roundtripped:
with roundtripped:
Expand Down
4 changes: 2 additions & 2 deletions xarray/tests/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
)

import xarray as xr
from xarray.backends.locks import HDF5_LOCK, CombinedLock
from xarray.backends.locks import HDF5_LOCK, CombinedLock, SerializableLock
from xarray.tests import (
assert_allclose,
assert_identical,
Expand Down Expand Up @@ -273,7 +273,7 @@ async def test_async(c, s, a, b) -> None:


def test_hdf5_lock() -> None:
assert isinstance(HDF5_LOCK, dask.utils.SerializableLock)
assert isinstance(HDF5_LOCK, SerializableLock)


@gen_cluster(client=True)
Expand Down

0 comments on commit 693f0b9

Please sign in to comment.