diff --git a/distributed/cli/dask_worker.py b/distributed/cli/dask_worker.py
index f13672a4d10..376f2a1c62b 100755
--- a/distributed/cli/dask_worker.py
+++ b/distributed/cli/dask_worker.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import asyncio
 import atexit
 import gc
@@ -6,7 +8,9 @@
 import signal
 import sys
 import warnings
+from collections.abc import Iterator
 from contextlib import suppress
+from typing import Any
 
 import click
 from tlz import valmap
@@ -24,7 +28,7 @@
     enable_proctitle_on_children,
     enable_proctitle_on_current,
 )
-from distributed.utils import import_term
+from distributed.utils import import_term, parse_ports
 
 logger = logging.getLogger("distributed.dask_worker")
 
@@ -254,10 +258,10 @@
 def main(
     scheduler,
     host,
-    worker_port,
+    worker_port: str | None,
     listen_address,
     contact_address,
-    nanny_port,
+    nanny_port: str | None,
     nthreads,
     nprocs,
     n_workers,
@@ -364,7 +368,8 @@ def main(
     try:
         if listen_address:
-            (host, worker_port) = get_address_host_port(listen_address, strict=True)
+            host, _ = get_address_host_port(listen_address, strict=True)
+            worker_port = str(_)
             if ":" in host:
                 # IPv6 -- bracket to pass as user args
                 host = f"[{host}]"
 
@@ -379,11 +384,6 @@ def main(
         logger.error("Failed to launch worker. " + str(e))
         sys.exit(1)
 
-    if nanny:
-        port = nanny_port
-    else:
-        port = worker_port
-
     if not nthreads:
         nthreads = CPU_COUNT // n_workers
 
@@ -407,16 +407,16 @@ def del_pid_file():
     loop = IOLoop.current()
 
     worker_class = import_term(worker_class)
+
+    port_kwargs = _apportion_ports(worker_port, nanny_port, n_workers, nanny)
+    assert len(port_kwargs) == n_workers
+
     if nanny:
         kwargs["worker_class"] = worker_class
         kwargs["preload_nanny"] = preload_nanny
-
-    if nanny:
-        kwargs.update({"worker_port": worker_port, "listen_address": listen_address})
+        kwargs["listen_address"] = listen_address
         t = Nanny
     else:
-        if nanny_port:
-            kwargs["service_ports"] = {"nanny": nanny_port}
         t = worker_class
 
     if (
@@ -442,15 +442,15 @@ def del_pid_file():
             security=sec,
             contact_address=contact_address,
             host=host,
-            port=port,
             dashboard=dashboard,
             dashboard_address=dashboard_address,
             name=name
             if n_workers == 1 or name is None or name == ""
             else str(name) + "-" + str(i),
             **kwargs,
+            **port_kwargs_i,
         )
-        for i in range(n_workers)
+        for i, port_kwargs_i in enumerate(port_kwargs)
     ]
 
     async def close_all():
@@ -486,5 +486,88 @@ async def run():
     logger.info("End worker")
 
 
+def _apportion_ports(
+    worker_port: str | None, nanny_port: str | None, n_workers: int, nanny: bool
+) -> list[dict[str, Any]]:
+    """Spread out evenly --worker-port and/or --nanny-port ranges to the workers and
+    nannies, avoiding overlap.
+
+    Returns
+    =======
+    List of kwargs to pass to the Worker or Nanny constructors
+    """
+    seen = set()
+
+    def parse_unique(s: str | None) -> Iterator[int | None]:
+        ports = parse_ports(s)
+        if ports in ([0], [None]):
+            for _ in range(n_workers):
+                yield ports[0]
+        else:
+            for port in ports:
+                if port not in seen:
+                    seen.add(port)
+                    yield port
+
+    worker_ports_iter = parse_unique(worker_port)
+    nanny_ports_iter = parse_unique(nanny_port)
+
+    # [(worker ports, nanny ports), ...]
+    ports: list[tuple[set[int | None], set[int | None]]] = [
+        (set(), set()) for _ in range(n_workers)
+    ]
+
+    ports_iter = iter(ports)
+    more_wps = True
+    more_nps = True
+    while more_wps or more_nps:
+        try:
+            worker_ports_i, nanny_ports_i = next(ports_iter)
+        except StopIteration:
+            # Start again in round-robin
+            ports_iter = iter(ports)
+            continue
+
+        try:
+            worker_ports_i.add(next(worker_ports_iter))
+        except StopIteration:
+            more_wps = False
+        try:
+            nanny_ports_i.add(next(nanny_ports_iter))
+        except StopIteration:
+            more_nps = False
+
+    kwargs = []
+    for worker_ports_i, nanny_ports_i in ports:
+        if not worker_ports_i or not nanny_ports_i:
+            if nanny:
+                raise ValueError(
+                    f"Not enough ports in range --worker-port {worker_port} "
+                    f"--nanny-port {nanny_port} for {n_workers} workers"
+                )
+            else:
+                raise ValueError(
+                    f"Not enough ports in range --worker-port {worker_port} "
+                    f"for {n_workers} workers"
+                )
+
+        # None and int can't be sorted together,
+        # but None and 0 are guaranteed to be alone
+        wp: Any = sorted(worker_ports_i)
+        if len(wp) == 1:
+            wp = wp[0]
+        if nanny:
+            np: Any = sorted(nanny_ports_i)
+            if len(np) == 1:
+                np = np[0]
+            kwargs_i = {"port": np, "worker_port": wp}
+        else:
+            kwargs_i = {"port": wp}
+
+        kwargs.append(kwargs_i)
+
+    return kwargs
+
+
 if __name__ == "__main__":
     main()  # pragma: no cover
diff --git a/distributed/cli/tests/test_dask_worker.py b/distributed/cli/tests/test_dask_worker.py
index 5be7d4b4d8a..e9355c82225 100644
--- a/distributed/cli/tests/test_dask_worker.py
+++ b/distributed/cli/tests/test_dask_worker.py
@@ -1,5 +1,8 @@
+from __future__ import annotations
+
 import asyncio
 import os
+import sys
 from multiprocessing import cpu_count
 from time import sleep
 
@@ -8,14 +11,152 @@
 
 from dask.utils import tmpfile
 
-import distributed.cli.dask_worker
 from distributed import Client
+from distributed.cli.dask_worker import _apportion_ports, main
 from distributed.compatibility import LINUX, to_thread
 from distributed.deploy.utils import nprocesses_nthreads
 from distributed.metrics import time
 from distributed.utils_test import gen_cluster, popen, requires_ipv6
 
 
+@pytest.mark.parametrize(
+    # args: (worker_port, nanny_port, n_workers, nanny)
+    # Passing args as a single tuple improves readability with black
+    "args,expect",
+    [
+        # Single worker
+        (
+            (None, None, 1, False),
+            [{"port": None}],
+        ),
+        (
+            (None, None, 1, True),
+            [{"port": None, "worker_port": None}],
+        ),
+        (("123", None, 1, False), [{"port": 123}]),
+        (
+            ("123", None, 1, True),
+            [{"port": None, "worker_port": 123}],
+        ),
+        (
+            (None, "456", 1, True),
+            [{"port": 456, "worker_port": None}],
+        ),
+        (
+            ("123", "456", 1, True),
+            [{"port": 456, "worker_port": 123}],
+        ),
+        # port=None or 0 and multiple workers
+        (
+            (None, None, 2, False),
+            [
+                {"port": None},
+                {"port": None},
+            ],
+        ),
+        (
+            (None, None, 2, True),
+            [
+                {"port": None, "worker_port": None},
+                {"port": None, "worker_port": None},
+            ],
+        ),
+        (
+            (0, "0", 2, True),
+            [
+                {"port": 0, "worker_port": 0},
+                {"port": 0, "worker_port": 0},
+            ],
+        ),
+        (
+            ("0", None, 2, True),
+            [
+                {"port": None, "worker_port": 0},
+                {"port": None, "worker_port": 0},
+            ],
+        ),
+        # port ranges
+        (
+            ("100:103", None, 1, False),
+            [{"port": [100, 101, 102, 103]}],
+        ),
+        (
+            ("100:103", None, 2, False),
+            [
+                {"port": [100, 102]},  # Round-robin apportionment
+                {"port": [101, 103]},
+            ],
+        ),
+        # port range is not an exact multiple of n_workers
+        (
+            ("100:107", None, 3, False),
+            [
+                {"port": [100, 103, 106]},
+                {"port": [101, 104, 107]},
+ {"port": [102, 105]}, + ], + ), + ( + ("100:103", None, 2, True), + [ + {"port": None, "worker_port": [100, 102]}, + {"port": None, "worker_port": [101, 103]}, + ], + ), + ( + (None, "110:113", 2, True), + [ + {"port": [110, 112], "worker_port": None}, + {"port": [111, 113], "worker_port": None}, + ], + ), + # port ranges have different length between nannies and workers + ( + ("100:103", "110:114", 2, True), + [ + {"port": [110, 112, 114], "worker_port": [100, 102]}, + {"port": [111, 113], "worker_port": [101, 103]}, + ], + ), + # identical port ranges + ( + ("100:103", "100:103", 2, True), + [ + {"port": 101, "worker_port": 100}, + {"port": 103, "worker_port": 102}, + ], + ), + # overlapping port ranges + ( + ("100:105", "104:106", 2, True), + [ + {"port": [104, 106], "worker_port": [100, 102]}, + {"port": 105, "worker_port": [101, 103]}, + ], + ), + ], +) +def test_apportion_ports(args, expect): + assert _apportion_ports(*args) == expect + + +def test_apportion_ports_bad(): + with pytest.raises(ValueError, match="Not enough ports in range"): + _apportion_ports("100:102", None, 4, False) + with pytest.raises(ValueError, match="Not enough ports in range"): + _apportion_ports(None, "100:102", 4, False) + with pytest.raises(ValueError, match="Not enough ports in range"): + _apportion_ports("100:102", "100:102", 3, True) + with pytest.raises(ValueError, match="Not enough ports in range"): + _apportion_ports("100:102", "102:104", 3, True) + with pytest.raises(ValueError, match="port_stop must be greater than port_start"): + _apportion_ports("102:100", None, 4, False) + with pytest.raises(ValueError, match="invalid literal for int"): + _apportion_ports("foo", None, 1, False) + with pytest.raises(ValueError, match="too many values to unpack"): + _apportion_ports("100:101:102", None, 1, False) + + @pytest.mark.slow @gen_cluster(client=True, nthreads=[]) async def test_nanny_worker_ports(c, s): @@ -38,8 +179,32 @@ async def test_nanny_worker_ports(c, s): @pytest.mark.slow +@pytest.mark.flaky( + LINUX and sys.version_info == (3, 9), + reason="Race condition in Nanny.process.start? 
" + "See https://github.com/dask/distributed/issues/6045", +) @gen_cluster(client=True, nthreads=[]) async def test_nanny_worker_port_range(c, s): + async def assert_ports(min_: int, max_: int, nanny: bool) -> None: + port_ranges = await c.run( + lambda dask_worker: dask_worker._start_port, nanny=nanny + ) + + for a in port_ranges.values(): + assert isinstance(a, list) + assert len(a) in (333, 334) + assert all(min_ <= i <= max_ for i in a) + + # Test no overlap + for b in port_ranges.values(): + assert a is b or not set(a) & set(b) + + ports = await c.run(lambda dask_worker: dask_worker.port, nanny=nanny) + assert all(min_ <= p <= max_ for p in ports.values()) + for addr, range in port_ranges.items(): + assert ports[addr] in range + with popen( [ "dask-worker", @@ -56,10 +221,8 @@ async def test_nanny_worker_port_range(c, s): ] ): await c.wait_for_workers(3) - worker_ports = await c.run(lambda dask_worker: dask_worker.port) - assert all(10000 <= p <= 11000 for p in worker_ports.values()) - nanny_ports = await c.run(lambda dask_worker: dask_worker.port, nanny=True) - assert all(11000 <= p <= 12000 for p in nanny_ports.values()) + await assert_ports(10000, 10999, nanny=False) + await assert_ports(11000, 12000, nanny=True) @gen_cluster(nthreads=[]) @@ -80,7 +243,9 @@ async def test_nanny_worker_port_range_too_many_workers_raises(s): ], flush_output=False, ) as worker: - assert any(b"Could not start" in worker.stdout.readline() for _ in range(100)) + assert any( + b"Not enough ports in range" in worker.stdout.readline() for _ in range(100) + ) @pytest.mark.slow @@ -416,7 +581,7 @@ async def test_dashboard_non_standard_ports(c, s): def test_version_option(): runner = CliRunner() - result = runner.invoke(distributed.cli.dask_worker.main, ["--version"]) + result = runner.invoke(main, ["--version"]) assert result.exit_code == 0 @@ -427,7 +592,7 @@ def test_worker_timeout(no_nanny): args = ["192.168.1.100:7777", "--death-timeout=1"] if no_nanny: args.append("--no-nanny") - result = runner.invoke(distributed.cli.dask_worker.main, args) + result = runner.invoke(main, args) assert result.exit_code != 0 @@ -437,14 +602,14 @@ def test_bokeh_deprecation(): runner = CliRunner() with pytest.warns(UserWarning, match="dashboard"): try: - runner.invoke(distributed.cli.dask_worker.main, ["--bokeh"]) + runner.invoke(main, ["--bokeh"]) except ValueError: # didn't pass scheduler pass with pytest.warns(UserWarning, match="dashboard"): try: - runner.invoke(distributed.cli.dask_worker.main, ["--no-bokeh"]) + runner.invoke(main, ["--no-bokeh"]) except ValueError: # didn't pass scheduler pass diff --git a/distributed/nanny.py b/distributed/nanny.py index e946655f501..db2371523e8 100644 --- a/distributed/nanny.py +++ b/distributed/nanny.py @@ -9,6 +9,7 @@ import uuid import warnings import weakref +from collections.abc import Collection from contextlib import suppress from inspect import isawaitable from queue import Empty @@ -24,7 +25,7 @@ from dask.utils import parse_timedelta from distributed import preloading -from distributed.comm import get_address_host, unparse_host_port +from distributed.comm import get_address_host from distributed.comm.addressing import address_from_user_args from distributed.core import ( CommClosedError, @@ -96,12 +97,16 @@ class Nanny(ServerNode): status = Status.undefined memory_manager: NannyMemoryManager + # Inputs to parse_ports() + _given_worker_port: int | str | Collection[int] | None + _start_port: int | str | Collection[int] | None + def __init__( self, scheduler_ip=None, 
         scheduler_port=None,
         scheduler_file=None,
-        worker_port=0,
+        worker_port: int | str | Collection[int] | None = 0,
         nthreads=None,
         loop=None,
         local_dir=None,
@@ -126,7 +131,7 @@ def __init__(
         env=None,
         interface=None,
         host=None,
-        port=None,
+        port: int | str | Collection[int] | None = None,
         protocol=None,
         config=None,
         **worker_kwargs,
@@ -373,14 +378,6 @@ async def instantiate(self) -> Status:
 
         Blocks until the process is up and the scheduler is properly informed
         """
-        if self._listen_address:
-            start_arg = self._listen_address
-        else:
-            host = self.listener.bound_address[0]
-            start_arg = self.listener.prefix + unparse_host_port(
-                host, self._given_worker_port
-            )
-
         if self.process is None:
             worker_kwargs = dict(
                 scheduler_ip=self.scheduler_addr,
@@ -403,7 +400,6 @@
             worker_kwargs.update(self.worker_kwargs)
             self.process = WorkerProcess(
                 worker_kwargs=worker_kwargs,
-                worker_start_args=(start_arg,),
                 silence_logs=self.silence_logs,
                 on_exit=self._on_exit_sync,
                 worker=self.Worker,
@@ -617,7 +613,6 @@ class WorkerProcess:
     def __init__(
         self,
         worker_kwargs,
-        worker_start_args,
         silence_logs,
         on_exit,
         worker,
@@ -627,7 +622,6 @@ def __init__(
         self.status = Status.init
         self.silence_logs = silence_logs
         self.worker_kwargs = worker_kwargs
-        self.worker_start_args = worker_start_args
         self.on_exit = on_exit
         self.process = None
         self.Worker = worker
@@ -658,7 +652,6 @@ async def start(self) -> Status:
             name="Dask Worker process (from Nanny)",
             kwargs=dict(
                 worker_kwargs=self.worker_kwargs,
-                worker_start_args=self.worker_start_args,
                 silence_logs=self.silence_logs,
                 init_result_q=self.init_result_q,
                 child_stop_q=self.child_stop_q,
@@ -808,7 +801,6 @@ async def _wait_until_connected(self, uid):
     def _run(
         cls,
         worker_kwargs,
-        worker_start_args,
         silence_logs,
         init_result_q,
         child_stop_q,
diff --git a/distributed/tests/test_utils.py b/distributed/tests/test_utils.py
index 931db6e090c..9163458a5af 100644
--- a/distributed/tests/test_utils.py
+++ b/distributed/tests/test_utils.py
@@ -528,9 +528,22 @@ def test_parse_ports():
     assert parse_ports(23) == [23]
     assert parse_ports("45") == [45]
     assert parse_ports("100:103") == [100, 101, 102, 103]
+    assert parse_ports([100, 101, 102, 103]) == [100, 101, 102, 103]
+
+    out = parse_ports((100, 101, 102, 103))
+    assert out == [100, 101, 102, 103]
+    assert isinstance(out, list)
 
     with pytest.raises(ValueError, match="port_stop must be greater than port_start"):
         parse_ports("103:100")
+    with pytest.raises(TypeError):
+        parse_ports(100.5)
+    with pytest.raises(TypeError):
+        parse_ports([100, 100.5])
+    with pytest.raises(ValueError):
+        parse_ports("foo")
+    with pytest.raises(ValueError):
+        parse_ports("100.5")
 
 
 def test_lru():
diff --git a/distributed/utils.py b/distributed/utils.py
index 64203488d12..cd0d93d08dd 100644
--- a/distributed/utils.py
+++ b/distributed/utils.py
@@ -20,7 +20,7 @@
 import xml.etree.ElementTree
 from asyncio import TimeoutError
 from collections import OrderedDict, UserDict, deque
-from collections.abc import Container, KeysView, ValuesView
+from collections.abc import Collection, Container, KeysView, ValuesView
 from concurrent.futures import CancelledError, ThreadPoolExecutor  # noqa: F401
 from contextlib import contextmanager, suppress
 from contextvars import ContextVar
@@ -1148,15 +1148,15 @@ def format_dashboard_link(host, port):
     )
 
 
-def parse_ports(port):
+def parse_ports(port: int | str | Collection[int] | None) -> list[int] | list[None]:
     """Parse input port information into list of ports
 
     Parameters
     ----------
-    port : int, str, None
+    port : int, str, list[int], None
         Input port or ports.  Can be an integer like 8787, a string for a
-        single port like "8787", a string for a sequential range of ports like
-        "8000:8200", or None.
+        single port like "8787", a string for a sequential range of ports like
+        "8000:8200", a collection of ints, or None.
 
     Returns
     -------
@@ -1188,12 +1188,7 @@
     [None]
     """
 
-    if isinstance(port, str) and ":" not in port:
-        port = int(port)
-
-    if isinstance(port, (int, type(None))):
-        ports = [port]
-    else:
+    if isinstance(port, str) and ":" in port:
         port_start, port_stop = map(int, port.split(":"))
         if port_stop <= port_start:
             raise ValueError(
@@ -1201,9 +1196,20 @@
                 "port_stop must be greater than port_start, but got "
                 f"{port_start=} and {port_stop=}"
             )
-        ports = list(range(port_start, port_stop + 1))
+        return list(range(port_start, port_stop + 1))
+
+    if isinstance(port, str):
+        return [int(port)]
+
+    if isinstance(port, int) or port is None:
+        return [port]  # type: ignore
+
+    if isinstance(port, Collection):
+        if not all(isinstance(p, int) for p in port):
+            raise TypeError(port)
+        return list(port)  # type: ignore
 
-    return ports
+    raise TypeError(port)
 
 
 is_coroutine_function = iscoroutinefunction
diff --git a/distributed/worker.py b/distributed/worker.py
index 33b43bb2c02..137e9169bc4 100644
--- a/distributed/worker.py
+++ b/distributed/worker.py
@@ -380,7 +380,7 @@ class Worker(ServerNode):
     bandwidth_types: defaultdict[type, tuple[float, int]]
     preloads: list[preloading.Preload]
     contact_address: str | None
-    _start_port: int | None
+    _start_port: int | str | Collection[int] | None = None
     _start_host: str | None
     _interface: str | None
     _protocol: str
@@ -446,7 +446,7 @@ def __init__(
         ] = DEFAULT_STARTUP_INFORMATION,
         interface: str | None = None,
         host: str | None = None,
-        port: int | None = None,
+        port: int | str | Collection[int] | None = None,
         protocol: str | None = None,
         dashboard_address: str | None = None,
         dashboard: bool = False,
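
For reference, a minimal sketch of how the new round-robin apportionment behaves; the values below mirror the parametrized test cases above, and the import assumes this branch of distributed is installed:

    # Sketch only: values copied from test_apportion_ports in this patch.
    from distributed.cli.dask_worker import _apportion_ports

    # A worker port range dealt out round-robin across two workers, no nannies:
    # each worker receives every other port from the range.
    assert _apportion_ports("100:103", None, 2, False) == [
        {"port": [100, 102]},
        {"port": [101, 103]},
    ]

    # Identical worker and nanny ranges: each port is handed out only once,
    # so worker and nanny ports never overlap.
    assert _apportion_ports("100:103", "100:103", 2, True) == [
        {"port": 101, "worker_port": 100},
        {"port": 103, "worker_port": 102},
    ]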