Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH: vendor SerializableLock from dask and use as default backend lock, adapt tests #8571

Merged
merged 7 commits into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/whats-new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ Bug fixes

- Reverse index output of bottleneck's rolling move_argmax/move_argmin functions (:issue:`8541`, :pull:`8552`).
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.
- Vendor `SerializableLock` from dask and use as default lock for netcdf4 backends (:issue:`8442`, :pull:`8571`).
By `Kai Mühlbauer <https://github.com/kmuehlbauer>`_.


Documentation
Expand Down
84 changes: 76 additions & 8 deletions xarray/backends/locks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,83 @@

import multiprocessing
import threading
import uuid
import weakref
from collections.abc import MutableMapping
from typing import Any

try:
from dask.utils import SerializableLock
except ImportError:
# no need to worry about serializing the lock
SerializableLock = threading.Lock # type: ignore
from collections.abc import Hashable, MutableMapping
from typing import Any, ClassVar
from weakref import WeakValueDictionary


# SerializableLock is adapted from Dask:
# https://github.com/dask/dask/blob/74e898f0ec712e8317ba86cc3b9d18b6b9922be0/dask/utils.py#L1160-L1224
# Used under the terms of Dask's license, see licenses/DASK_LICENSE.
class SerializableLock:
"""A Serializable per-process Lock

This wraps a normal ``threading.Lock`` object and satisfies the same
interface. However, this lock can also be serialized and sent to different
processes. It will not block concurrent operations between processes (for
this you should look at ``dask.multiprocessing.Lock`` or ``locket.lock_file``
but will consistently deserialize into the same lock.

So if we make a lock in one process::

lock = SerializableLock()

And then send it over to another process multiple times::

bytes = pickle.dumps(lock)
a = pickle.loads(bytes)
b = pickle.loads(bytes)

Then the deserialized objects will operate as though they were the same
lock, and collide as appropriate.

This is useful for consistently protecting resources on a per-process
level.

The creation of locks is itself not threadsafe.
"""

_locks: ClassVar[
WeakValueDictionary[Hashable, threading.Lock]
] = WeakValueDictionary()
token: Hashable
lock: threading.Lock

def __init__(self, token: Hashable | None = None):
self.token = token or str(uuid.uuid4())
if self.token in SerializableLock._locks:
self.lock = SerializableLock._locks[self.token]
else:
self.lock = threading.Lock()
SerializableLock._locks[self.token] = self.lock

def acquire(self, *args, **kwargs):
return self.lock.acquire(*args, **kwargs)

def release(self, *args, **kwargs):
return self.lock.release(*args, **kwargs)

def __enter__(self):
self.lock.__enter__()

def __exit__(self, *args):
self.lock.__exit__(*args)

def locked(self):
return self.lock.locked()

def __getstate__(self):
return self.token

def __setstate__(self, token):
self.__init__(token)

def __str__(self):
return f"<{self.__class__.__name__}: {self.token}>"

__repr__ = __str__


# Locks used by multiple backends.
Expand Down
2 changes: 0 additions & 2 deletions xarray/tests/test_backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,8 +432,6 @@ def test_dataset_compute(self) -> None:
assert_identical(expected, computed)

def test_pickle(self) -> None:
if not has_dask:
pytest.xfail("pickling requires dask for SerializableLock")
expected = Dataset({"foo": ("x", [42])})
with self.roundtrip(expected, allow_cleanup_failure=ON_WINDOWS) as roundtripped:
with roundtripped:
Expand Down
4 changes: 2 additions & 2 deletions xarray/tests/test_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
)

import xarray as xr
from xarray.backends.locks import HDF5_LOCK, CombinedLock
from xarray.backends.locks import HDF5_LOCK, CombinedLock, SerializableLock
from xarray.tests import (
assert_allclose,
assert_identical,
Expand Down Expand Up @@ -273,7 +273,7 @@ async def test_async(c, s, a, b) -> None:


def test_hdf5_lock() -> None:
assert isinstance(HDF5_LOCK, dask.utils.SerializableLock)
assert isinstance(HDF5_LOCK, SerializableLock)


@gen_cluster(client=True)
Expand Down
Loading