Add a lock to distributed.profile for better concurrency control (#6421)

Adds a Lock to distributed.profile to enable better concurrency control. In particular, it allows running garbage collection without a profiling thread holding references to objects, which is necessary for #6250.
hendrikmakait authored May 25, 2022
1 parent 0a77946 commit dea9ef2
Showing 10 changed files with 94 additions and 66 deletions.
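
The recurring pattern in the test changes below: where tests previously called wait_profiler() before asserting that an object had been freed, they now acquire profile.lock, which blocks until no watch() thread is mid-sample. A minimal, self-contained sketch of that pattern follows; the class C and the weakref are illustrative, not part of the commit:

    import gc
    import weakref

    from distributed import profile


    class C:
        """Stand-in for an object a profiling thread might briefly keep alive."""


    obj = C()
    ref = weakref.ref(obj)
    del obj

    # While profile.lock is held, no watch() thread is concurrently holding
    # references to sampled frames, so the collection below can free obj.
    with profile.lock:
        gc.collect()

    assert ref() is None  # obj is really gone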
39 changes: 13 additions & 26 deletions distributed/profile.py
@@ -44,6 +44,9 @@
 from distributed.metrics import time
 from distributed.utils import color_of
 
+#: This lock can be acquired to ensure that no instance of watch() is concurrently holding references to frames
+lock = threading.Lock()
+
 
 def identifier(frame: FrameType | None) -> str:
     """A string identifier from a frame
@@ -314,18 +317,6 @@ def traverse(state, start, stop, height):
     }
 
 
-_watch_running: set[int] = set()
-
-
-def wait_profiler() -> None:
-    """Wait until a moment when no instances of watch() are sampling the frames.
-    You must call this function whenever you would otherwise expect an object to be
-    immediately released after it's descoped.
-    """
-    while _watch_running:
-        sleep(0.0001)
-
-
 def _watch(
     thread_id: int,
     log: deque[tuple[float, dict[str, Any]]],  # [(timestamp, output of create()), ...]
@@ -337,24 +328,20 @@ def _watch(
 
     recent = create()
     last = time()
-    watch_id = threading.get_ident()
 
     while not stop():
-        _watch_running.add(watch_id)
-        try:
-            if time() > last + cycle:
+        if time() > last + cycle:
+            recent = create()
+            with lock:
                 log.append((time(), recent))
                 recent = create()
                 last = time()
-            try:
-                frame = sys._current_frames()[thread_id]
-            except KeyError:
-                return
-
-            process(frame, None, recent, omit=omit)
-            del frame
-        finally:
-            _watch_running.remove(watch_id)
+        try:
+            frame = sys._current_frames()[thread_id]
+        except KeyError:
+            return
+
+        process(frame, None, recent, omit=omit)
+        del frame
         sleep(interval)
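
For context: watch() (exercised by the new test added to distributed/tests/test_profile.py below) starts a daemon thread running _watch and returns the log deque of (timestamp, profile state) tuples. Because _watch appends to the log only while holding lock, and drops its frame reference after every sample, a reader can take the lock to block further appends and read the log consistently. A minimal sketch, assuming the default keyword arguments (the interval, cycle, and sleep values here are arbitrary):

    from time import sleep

    from distributed.profile import lock, watch

    # Start a daemon sampling thread: it samples frames every 20ms and, every
    # 250ms, appends a (timestamp, profile state) tuple to the returned log.
    # Without a stop= callback it keeps running for the life of the process.
    log = watch(interval="20ms", cycle="250ms")

    sleep(1)

    # _watch appends to the log only while holding `lock`, so acquiring it
    # blocks the sampler at its next cycle boundary and yields a consistent
    # snapshot of the log.
    with lock:
        snapshots = list(log)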


10 changes: 5 additions & 5 deletions distributed/protocol/tests/test_pickle.py
@@ -5,7 +5,7 @@
 
 import pytest
 
-from distributed.profile import wait_profiler
+from distributed import profile
 from distributed.protocol import deserialize, serialize
 from distributed.protocol.pickle import HIGHEST_PROTOCOL, dumps, loads
 
@@ -181,7 +181,7 @@ def funcs():
     assert func3(1) == func(1)
 
     del func, func2, func3
-    wait_profiler()
-    assert wr() is None
-    assert wr2() is None
-    assert wr3() is None
+    with profile.lock:
+        assert wr() is None
+        assert wr2() is None
+        assert wr3() is None
18 changes: 8 additions & 10 deletions distributed/tests/test_client.py
@@ -69,7 +69,6 @@
 from distributed.compatibility import LINUX, WINDOWS
 from distributed.core import Server, Status
 from distributed.metrics import time
-from distributed.profile import wait_profiler
 from distributed.scheduler import CollectTaskMetaDataPlugin, KilledWorker, Scheduler
 from distributed.sizeof import sizeof
 from distributed.utils import is_valid_xml, mp_context, sync, tmp_text
@@ -678,8 +677,8 @@ def test_no_future_references(c):
     futures = c.map(inc, range(10))
     ws.update(futures)
     del futures
-    wait_profiler()
-    assert not list(ws)
+    with profile.lock:
+        assert not list(ws)
 
 
 def test_get_sync_optimize_graph_passes_through(c):
@@ -811,9 +810,9 @@ async def test_recompute_released_key(c, s, a, b):
     result1 = await x
     xkey = x.key
     del x
-    wait_profiler()
-    await asyncio.sleep(0)
-    assert c.refcount[xkey] == 0
+    with profile.lock:
+        await asyncio.sleep(0)
+        assert c.refcount[xkey] == 0
 
     # 1 second batching needs a second action to trigger
     while xkey in s.tasks and s.tasks[xkey].who_has or xkey in a.data or xkey in b.data:
@@ -3483,10 +3482,9 @@ async def test_Client_clears_references_after_restart(c, s, a, b):
 
     key = x.key
     del x
-    wait_profiler()
-    await asyncio.sleep(0)
-
-    assert key not in c.refcount
+    with profile.lock:
+        await asyncio.sleep(0)
+        assert key not in c.refcount
 
 
 @gen_cluster(Worker=Nanny, client=True)
18 changes: 9 additions & 9 deletions distributed/tests/test_diskutils.py
@@ -12,10 +12,10 @@
 
 import dask
 
+from distributed import profile
 from distributed.compatibility import WINDOWS
 from distributed.diskutils import WorkSpace
 from distributed.metrics import time
-from distributed.profile import wait_profiler
 from distributed.utils import mp_context
 from distributed.utils_test import captured_logger
 
@@ -53,8 +53,8 @@ def test_workdir_simple(tmpdir):
     a.release()
     assert_contents(["bb", "bb.dirlock"])
     del b
-    wait_profiler()
-    gc.collect()
+    with profile.lock:
+        gc.collect()
     assert_contents([])
 
     # Generated temporary name with a prefix
@@ -89,12 +89,12 @@ def test_two_workspaces_in_same_directory(tmpdir):
 
     del ws
     del b
-    wait_profiler()
-    gc.collect()
+    with profile.lock:
+        gc.collect()
     assert_contents(["aa", "aa.dirlock"], trials=5)
     del a
-    wait_profiler()
-    gc.collect()
+    with profile.lock:
+        gc.collect()
     assert_contents([], trials=5)
 
 
@@ -188,8 +188,8 @@ def test_locking_disabled(tmpdir):
     a.release()
     assert_contents(["bb"])
     del b
-    wait_profiler()
-    gc.collect()
+    with profile.lock:
+        gc.collect()
     assert_contents([])
 
     lock_file.assert_not_called()
8 changes: 5 additions & 3 deletions distributed/tests/test_failed_workers.py
@@ -10,11 +10,10 @@
 
 from dask import delayed
 
-from distributed import Client, Nanny, wait
+from distributed import Client, Nanny, profile, wait
 from distributed.comm import CommClosedError
 from distributed.compatibility import MACOS
 from distributed.metrics import time
-from distributed.profile import wait_profiler
 from distributed.utils import CancelledError, sync
 from distributed.utils_test import (
     captured_logger,
@@ -262,7 +261,10 @@ async def test_forgotten_futures_dont_clean_up_new_futures(c, s, a, b):
     await c.restart()
     y = c.submit(inc, 1)
     del x
-    wait_profiler()
+
+    # Ensure that the profiler has stopped and released all references to x so that it can be garbage-collected
+    with profile.lock:
+        pass
     await asyncio.sleep(0.1)
     await y
 
7 changes: 3 additions & 4 deletions distributed/tests/test_nanny.py
@@ -19,12 +19,11 @@
 import dask
 from dask.utils import tmpfile
 
-from distributed import Nanny, Scheduler, Worker, rpc, wait, worker
+from distributed import Nanny, Scheduler, Worker, profile, rpc, wait, worker
 from distributed.compatibility import LINUX, WINDOWS
 from distributed.core import CommClosedError, Status
 from distributed.diagnostics import SchedulerPlugin
 from distributed.metrics import time
-from distributed.profile import wait_profiler
 from distributed.protocol.pickle import dumps
 from distributed.utils import TimeoutError, parse_ports
 from distributed.utils_test import (
@@ -170,8 +169,8 @@ async def test_num_fds(s):
     # Warm up
     async with Nanny(s.address):
         pass
-    wait_profiler()
-    gc.collect()
+    with profile.lock:
+        gc.collect()
 
     before = proc.num_fds()
 
40 changes: 40 additions & 0 deletions distributed/tests/test_profile.py
@@ -18,6 +18,7 @@
     info_frame,
     ll_get_stack,
     llprocess,
+    lock,
     merge,
     plot_data,
     process,
@@ -207,6 +208,45 @@ def stop():
         sleep(0.01)
 
 
+def test_watch_requires_lock_to_run():
+    start = time()
+
+    def stop_lock():
+        return time() > start + 0.600
+
+    def stop_profile():
+        return time() > start + 0.500
+
+    def hold_lock(stop):
+        with lock:
+            while not stop():
+                sleep(0.1)
+
+    start_threads = threading.active_count()
+
+    # Hog the lock over the entire duration of watch
+    thread = threading.Thread(
+        target=hold_lock, name="Hold Lock", kwargs={"stop": stop_lock}
+    )
+    thread.daemon = True
+    thread.start()
+
+    log = watch(interval="10ms", cycle="50ms", stop=stop_profile)
+
+    start = time()  # wait until thread starts up
+    while threading.active_count() < start_threads + 2:
+        assert time() < start + 2
+        sleep(0.01)
+
+    sleep(0.5)
+    assert len(log) == 0
+
+    start = time()
+    while threading.active_count() > start_threads:
+        assert time() < start + 2
+        sleep(0.01)
+
+
 @dataclasses.dataclass(frozen=True)
 class FakeCode:
     co_filename: str
7 changes: 5 additions & 2 deletions distributed/tests/test_spill.py
@@ -8,8 +8,8 @@
 
 from dask.sizeof import sizeof
 
+from distributed import profile
 from distributed.compatibility import WINDOWS
-from distributed.profile import wait_profiler
 from distributed.protocol import serialize_bytelist
 from distributed.spill import SpillBuffer, has_zict_210, has_zict_220
 from distributed.utils_test import captured_logger
@@ -338,7 +338,10 @@ def test_weakref_cache(tmpdir, cls, expect_cached, size):
     # the same id as a deleted one
     id_x = x.id
     del x
-    wait_profiler()
+
+    # Ensure that the profiler has stopped and released all references to x so that it can be garbage-collected
+    with profile.lock:
+        pass
 
     if size < 100:
         buf["y"]
7 changes: 3 additions & 4 deletions distributed/tests/test_steal.py
@@ -12,12 +12,11 @@
 
 import dask
 
-from distributed import Event, Lock, Nanny, Worker, wait, worker_client
+from distributed import Event, Lock, Nanny, Worker, profile, wait, worker_client
 from distributed.compatibility import LINUX
 from distributed.config import config
 from distributed.core import Status
 from distributed.metrics import time
-from distributed.profile import wait_profiler
 from distributed.scheduler import key_split
 from distributed.system import MEMORY_LIMIT
 from distributed.utils_test import (
@@ -948,8 +947,8 @@ class Foo:
 
     assert not s.tasks
 
-    wait_profiler()
-    assert not list(ws)
+    with profile.lock:
+        assert not list(ws)
 
 
 @gen_cluster(client=True, nthreads=[("127.0.0.1", 1)] * 2)
6 changes: 3 additions & 3 deletions distributed/tests/test_worker.py
@@ -34,6 +34,7 @@
     default_client,
     get_client,
     get_worker,
+    profile,
     wait,
 )
 from distributed.comm.registry import backends
@@ -42,7 +43,6 @@
 from distributed.diagnostics import nvml
 from distributed.diagnostics.plugin import PipInstall
 from distributed.metrics import time
-from distributed.profile import wait_profiler
 from distributed.protocol import pickle
 from distributed.scheduler import Scheduler
 from distributed.utils_test import (
@@ -1851,8 +1851,8 @@ class C:
     del f
     while "f" in a.data:
         await asyncio.sleep(0.01)
-    wait_profiler()
-    assert ref() is None
+    with profile.lock:
+        assert ref() is None
 
     story = a.stimulus_story("f", "f2")
     assert {ev.key for ev in story} == {"f", "f2"}
