diff --git a/distributed/tests/test_profile.py b/distributed/tests/test_profile.py index e0563b7c7f1..1d417cb19c5 100644 --- a/distributed/tests/test_profile.py +++ b/distributed/tests/test_profile.py @@ -185,53 +185,56 @@ def test_identifier(): def test_watch(): + stop_called = threading.Event() + watch_thread = None start = time() def stop(): + if not stop_called.is_set(): # Run setup code + nonlocal watch_thread + nonlocal start + watch_thread = threading.current_thread() + start = time() + stop_called.set() return time() > start + 0.500 - start_threads = threading.active_count() - log = watch(interval="10ms", cycle="50ms", stop=stop) - start = time() # wait until thread starts up - while threading.active_count() <= start_threads: - assert time() < start + 2 - sleep(0.01) - + stop_called.wait(2) sleep(0.5) assert 1 < len(log) < 10 - - start = time() - while threading.active_count() > start_threads: - assert time() < start + 2 - sleep(0.01) + watch_thread.join(2) def test_watch_requires_lock_to_run(): start = time() - def stop_lock(): - return time() > start + 0.600 + stop_profiling_called = threading.Event() + profiling_thread = None - def stop_profile(): + def stop_profiling(): + if not stop_profiling_called.is_set(): # Run setup code + nonlocal profiling_thread + nonlocal start + profiling_thread = threading.current_thread() + start = time() + stop_profiling_called.set() return time() > start + 0.500 - def hold_lock(stop): + release_lock = threading.Event() + + def block_lock(): with lock: - while not stop(): - sleep(0.1) + release_lock.wait() start_threads = threading.active_count() - # Hog the lock over the entire duration of watch - thread = threading.Thread( - target=hold_lock, name="Hold Lock", kwargs={"stop": stop_lock} - ) - thread.daemon = True - thread.start() + # Block the lock over the entire duration of watch + blocking_thread = threading.Thread(target=block_lock, name="Block Lock") + blocking_thread.daemon = True + blocking_thread.start() - log = watch(interval="10ms", cycle="50ms", stop=stop_profile) + log = watch(interval="10ms", cycle="50ms", stop=stop_profiling) start = time() # wait until thread starts up while threading.active_count() < start_threads + 2: @@ -240,11 +243,10 @@ def hold_lock(stop): sleep(0.5) assert len(log) == 0 + release_lock.set() - start = time() - while threading.active_count() > start_threads: - assert time() < start + 2 - sleep(0.01) + profiling_thread.join(2) + blocking_thread.join(2) @dataclasses.dataclass(frozen=True)