Commit 4f7c70f

Use dependency injection for proc memory mocks
1 parent 6dd928b commit 4f7c70f

3 files changed: +124 -36 lines changed
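
In short, rather than monkeypatching worker.monitor.get_process_memory after the worker is up, tests now subclass WorkerMemoryManager, override get_process_memory, and inject the subclass through the worker's new memory_manager_cls keyword. A minimal sketch of the pattern, using only names introduced in this diff; the class name and the 800 MB figure below are illustrative, not part of the commit:

    # Illustrative sketch of the injection pattern; FakeConstantRSS and the
    # 800 MB figure are made up for this example.
    from distributed.worker_memory import WorkerMemoryManager

    class FakeConstantRSS(WorkerMemoryManager):
        def get_process_memory(self) -> int:
            # Report a fixed "process memory" so that spill/pause behaviour
            # is deterministic in tests, independent of the real RSS.
            return 800_000_000

    # In an async test body, the fake is handed to the worker at construction
    # time instead of patching it in afterwards:
    #
    #     async with Worker(s.address, memory_limit="1000 MB",
    #                       memory_manager_cls=FakeConstantRSS) as a:
    #         ...  # the memory monitor now sees 800 MB on every tick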

distributed/tests/test_worker_memory.py

Lines changed: 70 additions & 8 deletions
@@ -14,14 +14,67 @@
 from distributed.core import Status
 from distributed.spill import has_zict_210
 from distributed.utils_test import captured_logger, gen_cluster, inc
-from distributed.worker_memory import parse_memory_limit
+from distributed.worker_memory import WorkerMemoryManager, parse_memory_limit

 requires_zict_210 = pytest.mark.skipif(
     not has_zict_210,
     reason="requires zict version >= 2.1.0",
 )


+def get_fake_wmm_fast_static(value: float) -> type[WorkerMemoryManager]:
+    """Fake factory for WorkerMemoryManager, for convenience.
+
+    The returned class reports a constant process memory of ``value`` whenever
+    there is data in ``data.fast``.
+    """
+
+    class FakeWMMFastStatic(WorkerMemoryManager):
+        def get_process_memory(self):
+            worker = self._worker()
+            if worker and worker.data.fast:
+                return value
+            else:
+                return 0
+
+    return FakeWMMFastStatic
+
+
+def get_fake_wmm_fast_dynamic(value: float) -> type[WorkerMemoryManager]:
+    """Fake factory for WorkerMemoryManager, for convenience.
+
+    The returned class reports a process memory of ``value`` times the number
+    of elements in ``data.fast``.
+    """
+
+    class FakeWMMFastDyn(WorkerMemoryManager):
+        def get_process_memory(self):
+            worker = self._worker()
+            if worker and worker.data.fast:
+                return value * len(worker.data.fast)
+            else:
+                return 0
+
+    return FakeWMMFastDyn
+
+
+def get_fake_wmm_all_static(value: float) -> type[WorkerMemoryManager]:
+    """Fake factory for WorkerMemoryManager, for convenience.
+
+    The returned class reports a constant process memory of ``value`` as long
+    as there is any data in the buffer.
+    """
+
+    class FakeWMMAll(WorkerMemoryManager):
+        def get_process_memory(self):
+            worker = self._worker()
+            if worker and worker.data:
+                return value
+            else:
+                return 0
+
+    return FakeWMMAll
+
+
 def memory_monitor_running(dask_worker: Worker | Nanny) -> bool:
     return "memory_monitor" in dask_worker.periodic_callbacks

@@ -109,7 +162,9 @@ async def test_fail_to_pickle_target_1(c, s, a, b):
 @gen_cluster(
     client=True,
     nthreads=[("", 1)],
-    worker_kwargs={"memory_limit": "1 kiB"},
+    worker_kwargs={
+        "memory_limit": "1 kiB",
+    },
     config={
         "distributed.worker.memory.target": 0.5,
         "distributed.worker.memory.spill": False,
@@ -142,7 +197,10 @@ async def test_fail_to_pickle_target_2(c, s, a):
 @gen_cluster(
     client=True,
     nthreads=[("", 1)],
-    worker_kwargs={"memory_limit": "1 kB"},
+    worker_kwargs={
+        "memory_limit": "1 kB",
+        "memory_manager_cls": get_fake_wmm_fast_static(701),
+    },
     config={
         "distributed.worker.memory.target": False,
         "distributed.worker.memory.spill": 0.7,
@@ -151,7 +209,6 @@ async def test_fail_to_pickle_target_2(c, s, a):
 )
 async def test_fail_to_pickle_spill(c, s, a):
     """Test failure to evict a key, triggered by the spill threshold"""
-    a.monitor.get_process_memory = lambda: 701 if a.data.fast else 0

     with captured_logger(logging.getLogger("distributed.spill")) as logs:
         bad = c.submit(FailToPickle, key="bad")
@@ -268,7 +325,10 @@ async def test_spill_constrained(c, s, w):
 @gen_cluster(
     nthreads=[("", 1)],
     client=True,
-    worker_kwargs={"memory_limit": "1000 MB"},
+    worker_kwargs={
+        "memory_limit": "1000 MB",
+        "memory_manager_cls": get_fake_wmm_fast_static(800_000_000),
+    },
     config={
         "distributed.worker.memory.target": False,
         "distributed.worker.memory.spill": 0.7,
@@ -282,7 +342,6 @@ async def test_spill_spill_threshold(c, s, a):
     reported by sizeof(), which may be inaccurate.
     """
     assert memory_monitor_running(a)
-    a.monitor.get_process_memory = lambda: 800_000_000 if a.data.fast else 0
     x = c.submit(inc, 0, key="x")
     while not a.data.disk:
         await asyncio.sleep(0.01)
@@ -326,8 +385,11 @@ def __sizeof__(self):
             return managed

     with dask.config.set({"distributed.worker.memory.target": target}):
-        async with Worker(s.address, memory_limit="1000 MB") as a:
-            a.monitor.get_process_memory = lambda: 50_000_000 * len(a.data.fast)
+        async with Worker(
+            s.address,
+            memory_limit="1000 MB",
+            memory_manager_cls=get_fake_wmm_fast_dynamic(50_000_000),
+        ) as a:

             # Add 500MB (reported) process memory. Spilling must not happen.
             futures = [c.submit(C, pure=False) for _ in range(10)]

distributed/worker.py

Lines changed: 2 additions & 1 deletion
@@ -455,6 +455,7 @@ def __init__(
         lifetime_restart: bool | None = None,
         ###################################
         # Parameters to WorkerMemoryManager
+        memory_manager_cls: type[WorkerMemoryManager] = WorkerMemoryManager,
         memory_limit: str | float = "auto",
         # Allow overriding the dict-like that stores the task outputs.
         # This is meant for power users only. See WorkerMemoryManager for details.
@@ -786,7 +787,7 @@ def __init__(
         for ext in extensions:
             ext(self)

-        self.memory_manager = WorkerMemoryManager(
+        self.memory_manager = memory_manager_cls(
             self,
             data=data,
             memory_limit=memory_limit,

distributed/worker_memory.py

Lines changed: 52 additions & 27 deletions
@@ -25,9 +25,9 @@
 import os
 import sys
 import warnings
+import weakref
 from collections.abc import Callable, MutableMapping
 from contextlib import suppress
-from functools import partial
 from typing import TYPE_CHECKING, Any, Container, Literal, cast

 import psutil
@@ -135,23 +135,32 @@ def __init__(
         )
         assert isinstance(self.memory_monitor_interval, (int, float))

+        self._worker = weakref.ref(worker)
+
         if self.memory_limit and (
             self.memory_spill_fraction is not False
             or self.memory_pause_fraction is not False
         ):
             assert self.memory_monitor_interval is not None
             pc = PeriodicCallback(
-                # Don't store worker as self.worker to avoid creating a circular
-                # dependency. We could have alternatively used a weakref.
                 # FIXME annotations: https://github.com/tornadoweb/tornado/issues/3117
-                partial(self.memory_monitor, worker),  # type: ignore
+                self.memory_monitor,  # type: ignore
                 self.memory_monitor_interval * 1000,
             )
             worker.periodic_callbacks["memory_monitor"] = pc

         self._throttled_gc = ThrottledGC(logger=logger)

-    async def memory_monitor(self, worker: Worker) -> None:
+    def get_process_memory(self) -> int:
+        """Get a measure for process memory.
+        This can be a mock target.
+        """
+        worker = self._worker()
+        if worker:
+            return worker.monitor.get_process_memory()
+        return -1
+
+    async def memory_monitor(self) -> None:
         """Track this process's memory usage and act accordingly.
         If process memory rises above the spill threshold (70%), start dumping data to
         disk until it goes below the target threshold (60%).
@@ -166,16 +175,18 @@ async def memory_monitor(self, worker: Worker) -> None:
             # Don't use psutil directly; instead read from the same API that is used
             # to send info to the Scheduler (e.g. for the benefit of Active Memory
             # Manager) and which can be easily mocked in unit tests.
-            memory = worker.monitor.get_process_memory()
-            self._maybe_pause_or_unpause(worker, memory)
-            await self._maybe_spill(worker, memory)
+            memory = self.get_process_memory()
+            self._maybe_pause_or_unpause(memory)
+            await self._maybe_spill(memory)
         finally:
             self._memory_monitoring = False

-    def _maybe_pause_or_unpause(self, worker: Worker, memory: int) -> None:
+    def _maybe_pause_or_unpause(self, memory: int) -> None:
         if self.memory_pause_fraction is False:
             return
-
+        worker = self._worker()
+        if not worker:
+            return
         assert self.memory_limit
         frac = memory / self.memory_limit
         # Pause worker threads if above 80% memory use
@@ -205,7 +216,7 @@ def _maybe_pause_or_unpause(self, worker: Worker, memory: int) -> None:
             )
             worker.status = Status.running

-    async def _maybe_spill(self, worker: Worker, memory: int) -> None:
+    async def _maybe_spill(self, memory: int) -> None:
         if self.memory_spill_fraction is False:
             return

@@ -257,15 +268,15 @@ async def _maybe_spill(self, worker: Worker, memory: int) -> None:
             count += 1
             await asyncio.sleep(0)

-            memory = worker.monitor.get_process_memory()
+            memory = self.get_process_memory()
             if total_spilled > need and memory > target:
                 # Issue a GC to ensure that the evicted data is actually
                 # freed from memory and taken into account by the monitor
                 # before trying to evict even more data.
                 self._throttled_gc.collect()
-                memory = worker.monitor.get_process_memory()
+                memory = self.get_process_memory()

-        self._maybe_pause_or_unpause(worker, memory)
+        self._maybe_pause_or_unpause(memory)
         if count:
             logger.debug(
                 "Moved %d tasks worth %s to disk",
@@ -302,32 +313,46 @@ def __init__(
             dask.config.get("distributed.worker.memory.monitor-interval"),
             default=None,
         )
+        self.nanny = weakref.ref(nanny)
         assert isinstance(self.memory_monitor_interval, (int, float))
         if self.memory_limit and self.memory_terminate_fraction is not False:
             pc = PeriodicCallback(
-                partial(self.memory_monitor, nanny),
+                self.memory_monitor,
                 self.memory_monitor_interval * 1000,
             )
             nanny.periodic_callbacks["memory_monitor"] = pc

-    def memory_monitor(self, nanny: Nanny) -> None:
+    def get_process_memory(self) -> int:
+        """Get a measure for process memory.
+        This can be a mock target.
+        """
+        nanny = self.nanny()
+        if nanny:
+            try:
+                proc = nanny._psutil_process
+                return proc.memory_info().rss
+            except (ProcessLookupError, psutil.NoSuchProcess, psutil.AccessDenied):
+                return -1  # pragma: nocover
+        return -1
+
+    def memory_monitor(self) -> None:
         """Track worker's memory. Restart if it goes above terminate fraction."""
-        if nanny.status != Status.running:
-            return  # pragma: nocover
-        if nanny.process is None or nanny.process.process is None:
-            return  # pragma: nocover
-        process = nanny.process.process
-        try:
-            proc = nanny._psutil_process
-            memory = proc.memory_info().rss
-        except (ProcessLookupError, psutil.NoSuchProcess, psutil.AccessDenied):
+        nanny = self.nanny()
+        if (
+            not nanny
+            or nanny.status != Status.running
+            or nanny.process is None
+            or nanny.process.process is None
+            or self.memory_limit is None
+        ):
             return  # pragma: nocover
-
+        memory = self.get_process_memory()
         if memory / self.memory_limit > self.memory_terminate_fraction:
             logger.warning(
                 "Worker exceeded %d%% memory budget. Restarting",
                 100 * self.memory_terminate_fraction,
             )
+            process = nanny.process.process
             process.terminate()


@@ -403,4 +428,4 @@ def __get__(self, instance: Nanny | Worker | None, owner):
             # This is triggered by Sphinx
             return None  # pragma: nocover
         _warn_deprecated(instance, "memory_monitor")
-        return partial(instance.memory_manager.memory_monitor, instance)
+        return instance.memory_manager.memory_monitor
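
The constructor changes above swap functools.partial(self.memory_monitor, worker) for a weak reference to the owning Worker or Nanny: the periodic callback is now bound to the manager alone, and the manager cannot keep its owner alive or form a reference cycle. A small self-contained sketch of the weakref behaviour that the new get_process_memory methods rely on (standard library only; the class names are illustrative, not from the diff):

    # Illustrative sketch: dereferencing a weakref after the referent is gone.
    import weakref

    class Owner:            # stands in for Worker / Nanny
        pass

    class Manager:          # stands in for WorkerMemoryManager
        def __init__(self, owner: Owner) -> None:
            self._owner = weakref.ref(owner)   # weak, so no reference cycle

        def probe(self) -> int:
            owner = self._owner()              # Owner instance, or None
            return 42 if owner is not None else -1

    o = Owner()
    m = Manager(o)
    assert m.probe() == 42
    del o                   # last strong reference gone; CPython frees it now
    assert m.probe() == -1  # the weakref now dereferences to None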
