diff --git a/distributed/comm/tests/test_ucx.py b/distributed/comm/tests/test_ucx.py index b841eec99ed..a55dcf11f08 100644 --- a/distributed/comm/tests/test_ucx.py +++ b/distributed/comm/tests/test_ucx.py @@ -2,6 +2,8 @@ import pytest +import dask + pytestmark = pytest.mark.gpu ucp = pytest.importorskip("ucp") @@ -10,6 +12,7 @@ from distributed.comm import connect, listen, parse_address, ucx from distributed.comm.registry import backends, get_backend from distributed.deploy.local import LocalCluster +from distributed.diagnostics.nvml import has_cuda_context from distributed.protocol import to_serialize from distributed.utils_test import inc @@ -326,6 +329,20 @@ async def test_simple(): assert await client.submit(lambda x: x + 1, 10) == 11 +@pytest.mark.asyncio +async def test_cuda_context(): + with dask.config.set({"distributed.comm.ucx.create_cuda_context": True}): + async with LocalCluster( + protocol="ucx", n_workers=1, asynchronous=True + ) as cluster: + async with Client(cluster, asynchronous=True) as client: + assert cluster.scheduler_address.startswith("ucx://") + assert has_cuda_context() == 0 + worker_cuda_context = await client.run(has_cuda_context) + assert len(worker_cuda_context) == 1 + assert list(worker_cuda_context.values())[0] == 0 + + @pytest.mark.asyncio async def test_transpose(): da = pytest.importorskip("dask.array") diff --git a/distributed/comm/tests/test_ucx_config.py b/distributed/comm/tests/test_ucx_config.py index 8fa38290f0a..7bd6db66410 100644 --- a/distributed/comm/tests/test_ucx_config.py +++ b/distributed/comm/tests/test_ucx_config.py @@ -79,6 +79,19 @@ async def test_ucx_config(cleanup): assert ucx_config.get("TLS") == "rc,tcp,rdmacm,cuda_copy" assert ucx_config.get("SOCKADDR_TLS_PRIORITY") == "rdmacm" + ucx = { + "nvlink": None, + "infiniband": None, + "rdmacm": None, + "net-devices": None, + "tcp": None, + "cuda_copy": None, + } + + with dask.config.set({"distributed.comm.ucx": ucx}): + ucx_config = _scrub_ucx_config() + assert ucx_config == {} + @pytest.mark.flaky( reruns=10, reruns_delay=5, condition=ucp.get_ucx_version() < (1, 11, 0) diff --git a/distributed/comm/ucx.py b/distributed/comm/ucx.py index e47b61a4fa0..63a2556fde0 100644 --- a/distributed/comm/ucx.py +++ b/distributed/comm/ucx.py @@ -70,7 +70,9 @@ def init_once(): # We ensure the CUDA context is created before initializing UCX. This can't # be safely handled externally because communications in Dask start before # preload scripts run. - if "TLS" in ucx_config and "cuda_copy" in ucx_config["TLS"]: + if dask.config.get("distributed.comm.ucx.create_cuda_context") is True or ( + "TLS" in ucx_config and "cuda_copy" in ucx_config["TLS"] + ): try: import numba.cuda except ImportError: diff --git a/distributed/distributed-schema.yaml b/distributed/distributed-schema.yaml index a19a6fc6fcd..477187b9177 100644 --- a/distributed/distributed-schema.yaml +++ b/distributed/distributed-schema.yaml @@ -834,6 +834,15 @@ properties: introduced to resolve an issue with CUDA IPC that has been fixed in UCX 1.10, but can cause establishing endpoints to be very slow, this is particularly noticeable in clusters of more than a few dozen workers. + create-cuda-context: + type: [boolean, 'null'] + description: | + Creates a CUDA context before UCX is initialized. This is necessary to enable UCX to + properly identify connectivity of GPUs with specialized networking hardware, such as + InfiniBand. This permits UCX to choose transports automatically, without specifying + additional variables for each transport, while ensuring optimal connectivity. When + ``True``, a CUDA context will be created on the first device listed in + ``CUDA_VISIBLE_DEVICES``. websockets: type: object diff --git a/distributed/distributed.yaml b/distributed/distributed.yaml index c9677424aa9..97ac480a172 100644 --- a/distributed/distributed.yaml +++ b/distributed/distributed.yaml @@ -187,13 +187,14 @@ distributed: socket-backlog: 2048 recent-messages-log-length: 0 # number of messages to keep for debugging ucx: - cuda_copy: False # enable cuda-copy - tcp: False # enable tcp - nvlink: False # enable cuda_ipc - infiniband: False # enable Infiniband - rdmacm: False # enable RDMACM + cuda_copy: null # enable cuda-copy + tcp: null # enable tcp + nvlink: null # enable cuda_ipc + infiniband: null # enable Infiniband + rdmacm: null # enable RDMACM net-devices: null # define what interface to use for UCX comm reuse-endpoints: null # enable endpoint reuse + create-cuda-context: null # create CUDA context before UCX initialization zstd: level: 3 # Compression level, between 1 and 22.