Merge branch 'main' into garbage-collect-nannies
hendrikmakait committed May 14, 2022
2 parents 8671450 + 6e0fe58 commit 5ca412c
Showing 30 changed files with 509 additions and 417 deletions.
4 changes: 1 addition & 3 deletions .pre-commit-config.yaml
@@ -10,9 +10,7 @@ repos:
       - id: isort
         language_version: python3
   - repo: https://github.com/asottile/pyupgrade
-    # Do not upgrade: there's a bug in Cython that causes sum(... for ...) to fail;
-    # it needs sum([... for ...])
-    rev: v2.13.0
+    rev: v2.32.0
     hooks:
       - id: pyupgrade
         args:
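
The pin removed here existed because older pyupgrade rewrote sum([... for ...]) into sum(... for ...), which at the time triggered a Cython bug. A minimal sketch of the two spellings involved (illustrative Python, not from this repo):

    values = [1, 2, 3]

    total_eager = sum([v * 2 for v in values])  # old spelling: builds a list first
    total_lazy = sum(v * 2 for v in values)     # pyupgrade's rewrite: lazy generator

    assert total_eager == total_lazy == 12

The dashboard changes further down in this commit apply exactly this rewrite by hand.
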
13 changes: 7 additions & 6 deletions distributed/client.py
@@ -52,7 +52,7 @@
 from tornado.ioloop import PeriodicCallback
 
 from distributed import cluster_dump, preloading
-from distributed import versions as version_module  # type: ignore
+from distributed import versions as version_module
 from distributed.batched import BatchedSend
 from distributed.cfexecutor import ClientExecutor
 from distributed.core import (
@@ -4418,16 +4418,14 @@ async def _get_task_stream(
         else:
             return msgs
 
-    async def _register_scheduler_plugin(self, plugin, name, **kwargs):
-        if isinstance(plugin, type):
-            plugin = plugin(**kwargs)
-
+    async def _register_scheduler_plugin(self, plugin, name, idempotent=False):
         return await self.scheduler.register_scheduler_plugin(
             plugin=dumps(plugin, protocol=4),
             name=name,
+            idempotent=idempotent,
         )
 
-    def register_scheduler_plugin(self, plugin, name=None):
+    def register_scheduler_plugin(self, plugin, name=None, idempotent=False):
         """Register a scheduler plugin.
         See https://distributed.readthedocs.io/en/latest/plugins.html#scheduler-plugins
@@ -4439,6 +4437,8 @@ def register_scheduler_plugin(self, plugin, name=None):
         name : str
             Name for the plugin; if None, a name is taken from the
             plugin instance or automatically generated if not present.
+        idempotent : bool
+            Do not re-register if a plugin of the given name already exists.
         """
         if name is None:
             name = _get_plugin_name(plugin)
@@ -4447,6 +4447,7 @@ def register_scheduler_plugin(self, plugin, name=None):
             self._register_scheduler_plugin,
             plugin=plugin,
             name=name,
+            idempotent=idempotent,
         )
 
     def register_worker_callbacks(self, setup=None):
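
For context, a sketch of how the new idempotent flag is used from the client side; the Counter plugin and the throwaway local Client() are illustrative, not part of this commit:

    from distributed import Client
    from distributed.diagnostics.plugin import SchedulerPlugin

    class Counter(SchedulerPlugin):
        """Toy plugin that counts tasks reaching memory."""

        def __init__(self):
            self.n = 0

        def transition(self, key, start, finish, *args, **kwargs):
            if finish == "memory":
                self.n += 1

    client = Client()  # assumes a local cluster is acceptable
    client.register_scheduler_plugin(Counter(), name="counter")
    # With idempotent=True, re-registering under an existing name is a no-op
    # rather than replacing the running plugin.
    client.register_scheduler_plugin(Counter(), name="counter", idempotent=True)

Note that the changed _register_scheduler_plugin no longer instantiates plugin classes for the caller, so instances are passed in, as above.
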
28 changes: 16 additions & 12 deletions distributed/comm/asyncio_tcp.py
@@ -6,6 +6,7 @@
 import os
 import socket
 import struct
+import sys
 import weakref
 from itertools import islice
 from typing import Any
@@ -26,7 +27,7 @@
     host_array,
     to_frames,
 )
-from distributed.utils import ensure_ip, get_ip, get_ipv6
+from distributed.utils import ensure_ip, ensure_memoryview, get_ip, get_ipv6
 
 logger = logging.getLogger(__name__)
 
@@ -379,7 +380,9 @@ async def write(self, frames: list[bytes]) -> int:
             await drain_waiter
 
         # Ensure all memoryviews are in single-byte format
-        frames = [f.cast("B") if isinstance(f, memoryview) else f for f in frames]
+        frames = [
+            ensure_memoryview(f) if isinstance(f, memoryview) else f for f in frames
+        ]
 
         nframes = len(frames)
         frames_nbytes = [len(f) for f in frames]
@@ -776,10 +779,14 @@ class _ZeroCopyWriter:
     # (which would be very large), and set a limit on the number of buffers to
     # pass to sendmsg.
     if hasattr(socket.socket, "sendmsg"):
-        try:
-            SENDMSG_MAX_COUNT = os.sysconf("SC_IOV_MAX")  # type: ignore
-        except Exception:
-            SENDMSG_MAX_COUNT = 16  # Should be supported on all systems
+        # Note: can't use WINDOWS constant as it upsets mypy
+        if sys.platform == "win32":
+            SENDMSG_MAX_COUNT = 16  # No os.sysconf available
+        else:
+            try:
+                SENDMSG_MAX_COUNT = os.sysconf("SC_IOV_MAX")
+            except Exception:
+                SENDMSG_MAX_COUNT = 16  # Should be supported on all systems
     else:
         SENDMSG_MAX_COUNT = 1  # sendmsg not supported, use send instead
 
@@ -847,12 +854,9 @@ def _buffer_clear(self):
 
     def _buffer_append(self, data: bytes) -> None:
         """Append new data to the send buffer"""
-        if not isinstance(data, memoryview):
-            data = memoryview(data)
-        if data.format != "B":
-            data = data.cast("B")
-        self._size += len(data)
-        self._buffers.append(data)
+        mv = ensure_memoryview(data)
+        self._size += len(mv)
+        self._buffers.append(mv)
 
     def _buffer_peek(self) -> list[memoryview]:
         """Get one or more buffers to write to the socket"""
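
Several call sites in this file now funnel through distributed.utils.ensure_memoryview, whose implementation is not part of this diff. The sketch below only illustrates the contract the callers rely on, namely a flat single-byte view where len(mv) == mv.nbytes:

    import array

    def ensure_memoryview_sketch(data) -> memoryview:
        """Illustrative stand-in for distributed.utils.ensure_memoryview."""
        mv = data if isinstance(data, memoryview) else memoryview(data)
        # cast("B") needs a C-contiguous view; callers here pass contiguous frames
        return mv if mv.format == "B" and mv.ndim == 1 else mv.cast("B")

    mv = ensure_memoryview_sketch(array.array("d", [1.0, 2.0]))  # two 8-byte floats
    assert mv.format == "B" and len(mv) == mv.nbytes == 16
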
2 changes: 1 addition & 1 deletion distributed/comm/registry.py
@@ -15,7 +15,7 @@ def __call__(self, **kwargs: str) -> Iterable[importlib.metadata.EntryPoint]:
 if sys.version_info >= (3, 10):
     # py3.10 importlib.metadata type annotations are not in mypy yet
     # https://github.com/python/typeshed/pull/7331
-    _entry_points: _EntryPoints = importlib.metadata.entry_points  # type: ignore[assignment]
+    _entry_points: _EntryPoints = importlib.metadata.entry_points
 else:
 
     def _entry_points(
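
The dropped # type: ignore[assignment] reflects typeshed catching up with the py3.10 "selectable" entry-points API. Minimal usage on 3.10+; the group name shown is the one distributed registers for comm backends, stated here as an assumption:

    import importlib.metadata

    # Filter entry points by group at lookup time (py3.10+ signature)
    for ep in importlib.metadata.entry_points(group="distributed.comm.backends"):
        print(ep.name, "->", ep.value)
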
4 changes: 2 additions & 2 deletions distributed/comm/tcp.py
@@ -46,7 +46,7 @@
 )
 from distributed.protocol.utils import pack_frames_prelude, unpack_frames
 from distributed.system import MEMORY_LIMIT
-from distributed.utils import ensure_ip, get_ip, get_ipv6, nbytes
+from distributed.utils import ensure_ip, ensure_memoryview, get_ip, get_ipv6, nbytes
 
 logger = logging.getLogger(__name__)
 
@@ -305,7 +305,7 @@ async def write(self, msg, serializers=None, on_error="message"):
                 if isinstance(each_frame, memoryview):
                     # Make sure that `len(data) == data.nbytes`
                     # See <https://github.com/tornadoweb/tornado/pull/2996>
-                    each_frame = memoryview(each_frame).cast("B")
+                    each_frame = ensure_memoryview(each_frame)
 
                 stream._write_buffer.append(each_frame)
                 stream._total_write_index += each_frame_nbytes
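
The tornado workaround referenced in the comment exists because len() on a multi-byte memoryview counts items rather than bytes, while tornado's write-buffer accounting needs bytes. A quick illustration:

    import array

    mv = memoryview(array.array("I", [1, 2, 3]))  # itemsize is typically 4
    assert len(mv) == 3 and mv.nbytes == 3 * mv.itemsize
    flat = mv.cast("B")  # the flat view ensure_memoryview is expected to provide
    assert len(flat) == flat.nbytes
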
2 changes: 1 addition & 1 deletion distributed/comm/ucx.py
@@ -40,7 +40,7 @@
     except ImportError:
         pass
 else:
-    ucp = None  # type: ignore
+    ucp = None
 
 device_array = None
 pre_existing_cuda_context = False
2 changes: 1 addition & 1 deletion distributed/compatibility.py
@@ -9,7 +9,7 @@
 
 LINUX = sys.platform == "linux"
 MACOS = sys.platform == "darwin"
-WINDOWS = sys.platform.startswith("win")
+WINDOWS = sys.platform == "win32"
 
 
 if sys.version_info >= (3, 9):
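
The switch from startswith("win") to a direct comparison matters to mypy: it special-cases the literal check sys.platform == "win32" and type-checks each branch only for the matching platform, which a derived bool constant defeats (hence the "upsets mypy" comments elsewhere in this commit). A rough illustration, with the imported modules as examples only:

    import sys

    if sys.platform == "win32":
        import msvcrt  # analyzed by mypy only when targeting Windows
    else:
        import fcntl   # POSIX-only module; skipped when targeting Windows
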
4 changes: 2 additions & 2 deletions distributed/dashboard/components/scheduler.py
@@ -2794,7 +2794,7 @@ def _get_timeseries(self, restrict_to_existing=False):
             back = None
         # Remove any periods of zero compute at the front or back of the timeseries
         if len(self.plugin.compute):
-            agg = sum([np.array(v[front:]) for v in self.plugin.compute.values()])
+            agg = sum(np.array(v[front:]) for v in self.plugin.compute.values())
             front2 = len(agg) - len(np.trim_zeros(agg, trim="f"))
             front += front2
             back = len(np.trim_zeros(agg, trim="b")) - len(agg) or None
@@ -3192,7 +3192,7 @@ def update(self):
             "names": ["Scheduler", "Workers"],
             "values": [
                 s._tick_interval_observed,
-                sum([w.metrics["event_loop_interval"] for w in s.workers.values()])
+                sum(w.metrics["event_loop_interval"] for w in s.workers.values())
                 / (len(s.workers) or 1),
             ],
         }
2 changes: 1 addition & 1 deletion distributed/http/utils.py
@@ -38,7 +38,7 @@ def get_handlers(server, modules: list[str], prefix="/"):
     _routes = []
     for module_name in modules:
         module = importlib.import_module(module_name)
-        _routes.extend(module.routes)  # type: ignore
+        _routes.extend(module.routes)
 
     routes = []
 
4 changes: 2 additions & 2 deletions distributed/protocol/serialize.py
@@ -88,7 +88,7 @@ def pickle_loads(header, frames):
         writeable = len(buffers) * (None,)
 
     new = []
-    memoryviews = map(memoryview, buffers)
+    memoryviews = map(ensure_memoryview, buffers)
     for w, mv in zip(writeable, memoryviews):
         if w == mv.readonly:
             if w:
@@ -785,7 +785,7 @@ def _serialize_memoryview(obj):
 @dask_deserialize.register(memoryview)
 def _deserialize_memoryview(header, frames):
     if len(frames) == 1:
-        out = memoryview(frames[0]).cast("B")
+        out = ensure_memoryview(frames[0])
     else:
         out = memoryview(b"".join(frames))
         out = out.cast(header["format"], header["shape"])
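
A brief illustration of the readonly/writeable distinction pickle_loads checks above (plain Python, not the distributed codepath itself): a frame that must be writeable but arrives readonly has to be copied into a mutable buffer.

    frame = memoryview(b"immutable frame")
    assert frame.readonly                    # bytes-backed views are readonly
    writable = memoryview(bytearray(frame))  # copy into a mutable backing store
    writable[:1] = b"I"                      # in-place writes now succeed
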
6 changes: 3 additions & 3 deletions distributed/pytest_resourceleaks.py
@@ -55,7 +55,6 @@ def test1():
 import psutil
 import pytest
 
-from distributed.compatibility import WINDOWS
 from distributed.metrics import time
 
 
@@ -155,10 +154,11 @@ def format(self, before: int, after: int) -> str:
 
 class FDChecker(ResourceChecker, name="fds"):
     def measure(self) -> int:
-        if WINDOWS:
+        # Note: can't use WINDOWS constant as it upsets mypy
+        if sys.platform == "win32":
             # Don't use num_handles(); you'll get tens of thousands of reported leaks
             return 0
-        return psutil.Process().num_fds()  # type: ignore
+        return psutil.Process().num_fds()
 
     def has_leak(self, before: int, after: int) -> bool:
         return after > before
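
psutil's num_fds() is POSIX-only, hence the platform guard replacing the old WINDOWS constant. A small self-check of the measurement FDChecker relies on, assuming a POSIX system:

    import sys
    import tempfile

    import psutil

    if sys.platform != "win32":
        proc = psutil.Process()
        before = proc.num_fds()
        with tempfile.TemporaryFile():  # typically holds exactly one fd while open
            assert proc.num_fds() == before + 1
        assert proc.num_fds() == before  # released on close
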

0 comments on commit 5ca412c
