2525import os
2626import sys
2727import warnings
28+ import weakref
2829from collections .abc import Callable , MutableMapping
2930from contextlib import suppress
30- from functools import partial
3131from typing import TYPE_CHECKING , Any , Container , Literal , cast
3232
3333import psutil
@@ -135,23 +135,32 @@ def __init__(
135135 )
136136 assert isinstance (self .memory_monitor_interval , (int , float ))
137137
138+ self ._worker = weakref .ref (worker )
139+
138140 if self .memory_limit and (
139141 self .memory_spill_fraction is not False
140142 or self .memory_pause_fraction is not False
141143 ):
142144 assert self .memory_monitor_interval is not None
143145 pc = PeriodicCallback (
144- # Don't store worker as self.worker to avoid creating a circular
145- # dependency. We could have alternatively used a weakref.
146146 # FIXME annotations: https://github.com/tornadoweb/tornado/issues/3117
147- partial ( self .memory_monitor , worker ) , # type: ignore
147+ self .memory_monitor , # type: ignore
148148 self .memory_monitor_interval * 1000 ,
149149 )
150150 worker .periodic_callbacks ["memory_monitor" ] = pc
151151
152152 self ._throttled_gc = ThrottledGC (logger = logger )
153153
154- async def memory_monitor (self , worker : Worker ) -> None :
154+ def get_process_memory (self ) -> int :
155+ """Get a measure for process memory.
156+ This can be a mock target.
157+ """
158+ worker = self ._worker ()
159+ if worker :
160+ return worker .monitor .get_process_memory ()
161+ return - 1
162+
163+ async def memory_monitor (self ) -> None :
155164 """Track this process's memory usage and act accordingly.
156165 If process memory rises above the spill threshold (70%), start dumping data to
157166 disk until it goes below the target threshold (60%).
@@ -166,16 +175,18 @@ async def memory_monitor(self, worker: Worker) -> None:
166175 # Don't use psutil directly; instead read from the same API that is used
167176 # to send info to the Scheduler (e.g. for the benefit of Active Memory
168177 # Manager) and which can be easily mocked in unit tests.
169- memory = worker . monitor .get_process_memory ()
170- self ._maybe_pause_or_unpause (worker , memory )
171- await self ._maybe_spill (worker , memory )
178+ memory = self .get_process_memory ()
179+ self ._maybe_pause_or_unpause (memory )
180+ await self ._maybe_spill (memory )
172181 finally :
173182 self ._memory_monitoring = False
174183
175- def _maybe_pause_or_unpause (self , worker : Worker , memory : int ) -> None :
184+ def _maybe_pause_or_unpause (self , memory : int ) -> None :
176185 if self .memory_pause_fraction is False :
177186 return
178-
187+ worker = self ._worker ()
188+ if not worker :
189+ return
179190 assert self .memory_limit
180191 frac = memory / self .memory_limit
181192 # Pause worker threads if above 80% memory use
@@ -205,7 +216,7 @@ def _maybe_pause_or_unpause(self, worker: Worker, memory: int) -> None:
205216 )
206217 worker .status = Status .running
207218
208- async def _maybe_spill (self , worker : Worker , memory : int ) -> None :
219+ async def _maybe_spill (self , memory : int ) -> None :
209220 if self .memory_spill_fraction is False :
210221 return
211222
@@ -257,15 +268,15 @@ async def _maybe_spill(self, worker: Worker, memory: int) -> None:
257268 count += 1
258269 await asyncio .sleep (0 )
259270
260- memory = worker . monitor .get_process_memory ()
271+ memory = self .get_process_memory ()
261272 if total_spilled > need and memory > target :
262273 # Issue a GC to ensure that the evicted data is actually
263274 # freed from memory and taken into account by the monitor
264275 # before trying to evict even more data.
265276 self ._throttled_gc .collect ()
266- memory = worker . monitor .get_process_memory ()
277+ memory = self .get_process_memory ()
267278
268- self ._maybe_pause_or_unpause (worker , memory )
279+ self ._maybe_pause_or_unpause (memory )
269280 if count :
270281 logger .debug (
271282 "Moved %d tasks worth %s to disk" ,
@@ -302,32 +313,46 @@ def __init__(
302313 dask .config .get ("distributed.worker.memory.monitor-interval" ),
303314 default = None ,
304315 )
316+ self .nanny = weakref .ref (nanny )
305317 assert isinstance (self .memory_monitor_interval , (int , float ))
306318 if self .memory_limit and self .memory_terminate_fraction is not False :
307319 pc = PeriodicCallback (
308- partial ( self .memory_monitor , nanny ) ,
320+ self .memory_monitor ,
309321 self .memory_monitor_interval * 1000 ,
310322 )
311323 nanny .periodic_callbacks ["memory_monitor" ] = pc
312324
313- def memory_monitor (self , nanny : Nanny ) -> None :
325+ def get_process_memory (self ) -> int :
326+ """Get a measure for process memory.
327+ This can be a mock target.
328+ """
329+ nanny = self .nanny ()
330+ if nanny :
331+ try :
332+ proc = nanny ._psutil_process
333+ return proc .memory_info ().rss
334+ except (ProcessLookupError , psutil .NoSuchProcess , psutil .AccessDenied ):
335+ return - 1 # pragma: nocover
336+ return - 1
337+
338+ def memory_monitor (self ) -> None :
314339 """Track worker's memory. Restart if it goes above terminate fraction."""
315- if nanny .status != Status .running :
316- return # pragma: nocover
317- if nanny .process is None or nanny .process .process is None :
318- return # pragma: nocover
319- process = nanny .process .process
320- try :
321- proc = nanny ._psutil_process
322- memory = proc .memory_info ().rss
323- except (ProcessLookupError , psutil .NoSuchProcess , psutil .AccessDenied ):
340+ nanny = self .nanny ()
341+ if (
342+ not nanny
343+ or nanny .status != Status .running
344+ or nanny .process is None
345+ or nanny .process .process is None
346+ or self .memory_limit is None
347+ ):
324348 return # pragma: nocover
325-
349+ memory = self . get_process_memory ()
326350 if memory / self .memory_limit > self .memory_terminate_fraction :
327351 logger .warning (
328352 "Worker exceeded %d%% memory budget. Restarting" ,
329353 100 * self .memory_terminate_fraction ,
330354 )
355+ process = nanny .process .process
331356 process .terminate ()
332357
333358
@@ -403,4 +428,4 @@ def __get__(self, instance: Nanny | Worker | None, owner):
403428 # This is triggered by Sphinx
404429 return None # pragma: nocover
405430 _warn_deprecated (instance , "memory_monitor" )
406- return partial ( instance .memory_manager .memory_monitor , instance )
431+ return instance .memory_manager .memory_monitor
0 commit comments