Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid port collisions when defining port ranges #6054

Merged
merged 7 commits into from
Apr 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 99 additions & 16 deletions distributed/cli/dask_worker.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import asyncio
import atexit
import gc
Expand All @@ -6,7 +8,9 @@
import signal
import sys
import warnings
from collections.abc import Iterator
from contextlib import suppress
from typing import Any

import click
from tlz import valmap
Expand All @@ -24,7 +28,7 @@
enable_proctitle_on_children,
enable_proctitle_on_current,
)
from distributed.utils import import_term
from distributed.utils import import_term, parse_ports

logger = logging.getLogger("distributed.dask_worker")

Expand Down Expand Up @@ -254,10 +258,10 @@
def main(
scheduler,
host,
worker_port,
worker_port: str | None,
listen_address,
contact_address,
nanny_port,
nanny_port: str | None,
nthreads,
nprocs,
n_workers,
Expand Down Expand Up @@ -364,7 +368,8 @@ def main(

try:
if listen_address:
(host, worker_port) = get_address_host_port(listen_address, strict=True)
host, _ = get_address_host_port(listen_address, strict=True)
worker_port = str(_)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mypy hack

if ":" in host:
# IPv6 -- bracket to pass as user args
host = f"[{host}]"
Expand All @@ -379,11 +384,6 @@ def main(
logger.error("Failed to launch worker. " + str(e))
sys.exit(1)

if nanny:
port = nanny_port
else:
port = worker_port

if not nthreads:
nthreads = CPU_COUNT // n_workers

Expand All @@ -407,16 +407,16 @@ def del_pid_file():
loop = IOLoop.current()

worker_class = import_term(worker_class)

port_kwargs = _apportion_ports(worker_port, nanny_port, n_workers, nanny)
assert len(port_kwargs) == n_workers

if nanny:
kwargs["worker_class"] = worker_class
kwargs["preload_nanny"] = preload_nanny

if nanny:
kwargs.update({"worker_port": worker_port, "listen_address": listen_address})
kwargs["listen_address"] = listen_address
t = Nanny
else:
if nanny_port:
kwargs["service_ports"] = {"nanny": nanny_port}
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like legacy cruft. service_ports is a property and can't be passed to Worker.__init__.

t = worker_class

if (
Expand All @@ -442,15 +442,15 @@ def del_pid_file():
security=sec,
contact_address=contact_address,
host=host,
port=port,
dashboard=dashboard,
dashboard_address=dashboard_address,
name=name
if n_workers == 1 or name is None or name == ""
else str(name) + "-" + str(i),
**kwargs,
**port_kwargs_i,
)
for i in range(n_workers)
for i, port_kwargs_i in enumerate(port_kwargs)
]

async def close_all():
Expand Down Expand Up @@ -486,5 +486,88 @@ async def run():
logger.info("End worker")


def _apportion_ports(
worker_port: str | None, nanny_port: str | None, n_workers: int, nanny: bool
) -> list[dict[str, Any]]:
"""Spread out evenly --worker-port and/or --nanny-port ranges to the workers and
nannies, avoiding overlap.

Returns
=======
List of kwargs to pass to the Worker or Nanny construtors
"""
seen = set()

def parse_unique(s: str | None) -> Iterator[int | None]:
ports = parse_ports(s)
if ports in ([0], [None]):
for _ in range(n_workers):
yield ports[0]
else:
for port in ports:
if port not in seen:
seen.add(port)
yield port

worker_ports_iter = parse_unique(worker_port)
nanny_ports_iter = parse_unique(nanny_port)

# [(worker ports, nanny ports), ...]
ports: list[tuple[set[int | None], set[int | None]]] = [
(set(), set()) for _ in range(n_workers)
]

ports_iter = iter(ports)
more_wps = True
more_nps = True
while more_wps or more_nps:
try:
worker_ports_i, nanny_ports_i = next(ports_iter)
except StopIteration:
# Start again in round-robin
ports_iter = iter(ports)
continue

try:
worker_ports_i.add(next(worker_ports_iter))
except StopIteration:
more_wps = False
try:
nanny_ports_i.add(next(nanny_ports_iter))
except StopIteration:
more_nps = False

kwargs = []
for worker_ports_i, nanny_ports_i in ports:
if not worker_ports_i or not nanny_ports_i:
if nanny:
raise ValueError(
f"Not enough ports in range --worker_port {worker_port} "
f"--nanny_port {nanny_port} for {n_workers} workers"
)
else:
raise ValueError(
f"Not enough ports in range --worker_port {worker_port} "
f"for {n_workers} workers"
)

# None and int can't be sorted together,
# but None and 0 are guaranteed to be alone
wp: Any = sorted(worker_ports_i)
if len(wp) == 1:
wp = wp[0]
if nanny:
np: Any = sorted(nanny_ports_i)
if len(np) == 1:
np = np[0]
kwargs_i = {"port": np, "worker_port": wp}
else:
kwargs_i = {"port": wp}

kwargs.append(kwargs_i)

return kwargs


# Script entry point: run ``main`` when this module is executed directly.
# NOTE(review): ``main`` is called without arguments although its signature takes
# many parameters -- presumably the click decorators (not visible in this view)
# supply them from the command line; confirm against the full file.
if __name__ == "__main__":
    main()  # pragma: no cover
Loading