Add distributed.comm.ucx.create-cuda-context config #5526

Merged (2 commits) on Nov 19, 2021
17 changes: 17 additions & 0 deletions distributed/comm/tests/test_ucx.py
@@ -2,6 +2,8 @@

import pytest

import dask

pytestmark = pytest.mark.gpu

ucp = pytest.importorskip("ucp")
@@ -10,6 +12,7 @@
from distributed.comm import connect, listen, parse_address, ucx
from distributed.comm.registry import backends, get_backend
from distributed.deploy.local import LocalCluster
from distributed.diagnostics.nvml import has_cuda_context
from distributed.protocol import to_serialize
from distributed.utils_test import inc

@@ -326,6 +329,20 @@ async def test_simple():
            assert await client.submit(lambda x: x + 1, 10) == 11


@pytest.mark.asyncio
async def test_cuda_context():
    with dask.config.set({"distributed.comm.ucx.create_cuda_context": True}):
        async with LocalCluster(
            protocol="ucx", n_workers=1, asynchronous=True
        ) as cluster:
            async with Client(cluster, asynchronous=True) as client:
                assert cluster.scheduler_address.startswith("ucx://")
                assert has_cuda_context() == 0
                worker_cuda_context = await client.run(has_cuda_context)
                assert len(worker_cuda_context) == 1
                assert list(worker_cuda_context.values())[0] == 0


@pytest.mark.asyncio
async def test_transpose():
da = pytest.importorskip("dask.array")
13 changes: 13 additions & 0 deletions distributed/comm/tests/test_ucx_config.py
@@ -79,6 +79,19 @@ async def test_ucx_config(cleanup):
    assert ucx_config.get("TLS") == "rc,tcp,rdmacm,cuda_copy"
    assert ucx_config.get("SOCKADDR_TLS_PRIORITY") == "rdmacm"

    ucx = {
        "nvlink": None,
        "infiniband": None,
        "rdmacm": None,
        "net-devices": None,
        "tcp": None,
        "cuda_copy": None,
    }

    with dask.config.set({"distributed.comm.ucx": ucx}):
        ucx_config = _scrub_ucx_config()
        assert ucx_config == {}


@pytest.mark.flaky(
    reruns=10, reruns_delay=5, condition=ucp.get_ucx_version() < (1, 11, 0)
4 changes: 3 additions & 1 deletion distributed/comm/ucx.py
@@ -70,7 +70,9 @@ def init_once():
     # We ensure the CUDA context is created before initializing UCX. This can't
     # be safely handled externally because communications in Dask start before
     # preload scripts run.
-    if "TLS" in ucx_config and "cuda_copy" in ucx_config["TLS"]:
+    if dask.config.get("distributed.comm.ucx.create_cuda_context") is True or (
+        "TLS" in ucx_config and "cuda_copy" in ucx_config["TLS"]
+    ):
         try:
             import numba.cuda
         except ImportError:
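
For context on the hunk above, which is truncated right after the ImportError guard: the point is simply that a CUDA context must already exist before UCX itself is initialized. Below is a minimal sketch of that pattern; the helper name and error message are illustrative rather than taken from this PR, and it assumes numba.cuda.current_context() is what creates the context on the first device in CUDA_VISIBLE_DEVICES.

import dask


def _ensure_cuda_context(ucx_config):
    # Illustrative sketch only, not the exact code added in this PR.
    create_requested = dask.config.get(
        "distributed.comm.ucx.create_cuda_context", default=None
    )
    if create_requested is True or (
        "TLS" in ucx_config and "cuda_copy" in ucx_config["TLS"]
    ):
        try:
            import numba.cuda
        except ImportError:
            raise ImportError("Creating a CUDA context for UCX requires numba")
        # current_context() lazily creates a CUDA context on the first device
        # listed in CUDA_VISIBLE_DEVICES if one does not already exist, which
        # is what UCX needs to see before probing GPU/NIC connectivity.
        numba.cuda.current_context()
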
9 changes: 9 additions & 0 deletions distributed/distributed-schema.yaml
@@ -834,6 +834,15 @@ properties:
      introduced to resolve an issue with CUDA IPC that has been fixed in UCX 1.10, but
      can cause establishing endpoints to be very slow, this is particularly noticeable in
      clusters of more than a few dozen workers.
  create-cuda-context:
    type: [boolean, 'null']
    description: |
      Creates a CUDA context before UCX is initialized. This is necessary to enable UCX to
      properly identify connectivity of GPUs with specialized networking hardware, such as
      InfiniBand. This permits UCX to choose transports automatically, without specifying
      additional variables for each transport, while ensuring optimal connectivity. When
      ``True``, a CUDA context will be created on the first device listed in
      ``CUDA_VISIBLE_DEVICES``.

websockets:
  type: object
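
As a usage note, the new option goes through the normal Dask configuration mechanism. The snippet below mirrors the test_cuda_context test added in this PR, with illustrative cluster parameters; it assumes ucx-py and a visible GPU are available.

import dask
from distributed import Client, LocalCluster

# Opt in to creating a CUDA context before UCX is initialized.
with dask.config.set({"distributed.comm.ucx.create_cuda_context": True}):
    with LocalCluster(protocol="ucx", n_workers=1) as cluster:
        with Client(cluster) as client:
            print(cluster.scheduler_address)  # e.g. ucx://...
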
11 changes: 6 additions & 5 deletions distributed/distributed.yaml
@@ -187,13 +187,14 @@ distributed:
     socket-backlog: 2048
     recent-messages-log-length: 0 # number of messages to keep for debugging
     ucx:
-      cuda_copy: False # enable cuda-copy
-      tcp: False # enable tcp
-      nvlink: False # enable cuda_ipc
-      infiniband: False # enable Infiniband
-      rdmacm: False # enable RDMACM
+      cuda_copy: null # enable cuda-copy
+      tcp: null # enable tcp
+      nvlink: null # enable cuda_ipc
+      infiniband: null # enable Infiniband
+      rdmacm: null # enable RDMACM
       net-devices: null # define what interface to use for UCX comm
       reuse-endpoints: null # enable endpoint reuse
+      create-cuda-context: null # create CUDA context before UCX initialization

     zstd:
       level: 3 # Compression level, between 1 and 22.
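
Finally, a quick way to see the effect of the new defaults (a sketch; it assumes distributed is importable so that its configuration defaults are registered): every UCX flag is now null, meaning unset, so nothing is forced onto UCX unless the user opts in, which is exactly what the _scrub_ucx_config assertion above checks.

import distributed  # noqa: F401  (importing registers the distributed.* config defaults)
import dask

# All UCX options now default to None ("let UCX decide"), including the new one.
for key in ["cuda_copy", "tcp", "nvlink", "infiniband", "rdmacm", "create_cuda_context"]:
    print(key, dask.config.get(f"distributed.comm.ucx.{key}"))  # prints None for each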