Fix CUDA_VISIBLE_DEVICES tests #638

Merged
12 changes: 12 additions & 0 deletions dask_cuda/cli/dask_cuda_worker.py
@@ -8,6 +8,7 @@
from distributed.cli.utils import check_python_3, install_signal_handlers
from distributed.preloading import validate_preload_argv
from distributed.security import Security
from distributed.utils import import_term

from ..cuda_worker import CUDAWorker

@@ -248,6 +249,12 @@
``proxy_object.ProxyObject`` and ``proxify_host_file.ProxifyHostFile`` for more
info.""",
)
@click.option(
"--worker-class",
default=None,
help="""Use a different class than Distributed's default (``distributed.Worker``)
to spawn ``distributed.Nanny``.""",
)
def main(
scheduler,
host,
@@ -277,6 +284,7 @@ def main(
enable_rdmacm,
net_devices,
enable_jit_unspill,
worker_class,
**kwargs,
):
if tls_ca_file and tls_cert and tls_key:
@@ -293,6 +301,9 @@
"unsupported one. Scheduler address: %s" % scheduler
)

if worker_class is not None:
worker_class = import_term(worker_class)

worker = CUDAWorker(
scheduler,
host,
@@ -320,6 +331,7 @@ def main(
enable_rdmacm,
net_devices,
enable_jit_unspill,
worker_class,
**kwargs,
)

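Note on the change above: the new ``--worker-class`` flag arrives as a dotted string, and ``distributed.utils.import_term`` resolves it into the actual class before it reaches ``CUDAWorker``. A minimal sketch of that resolution step (the dotted path assumes the ``MockWorker`` helper added later in this PR):

```python
from distributed.utils import import_term

# The same dotted path the tests below pass via ``--worker-class``.
worker_cls = import_term("dask_cuda.utils.MockWorker")
assert isinstance(worker_cls, type)  # a class object, ready to hand to CUDAWorker
```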
2 changes: 2 additions & 0 deletions dask_cuda/cuda_worker.py
@@ -74,6 +74,7 @@ def __init__(
enable_rdmacm=False,
net_devices=None,
jit_unspill=None,
worker_class=None,
**kwargs,
):
# Required by RAPIDS libraries (e.g., cuDF) to ensure no context
@@ -235,6 +236,7 @@ def del_pid_file():
)
},
data=data(i),
worker_class=worker_class,
**kwargs,
)
for i in range(nprocs)
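``CUDAWorker`` forwards ``worker_class`` to the ``distributed.Nanny`` it creates for each visible device, and the Nanny instantiates that class inside the worker subprocess. A minimal sketch of that mechanism, assuming a scheduler is already running at a placeholder address and an arbitrary device ordering:

```python
from distributed import Nanny

from dask_cuda.utils import MockWorker


async def start_one_worker():
    # The scheduler address and device ordering below are placeholders.
    async with Nanny(
        "tcp://127.0.0.1:8786",
        worker_class=MockWorker,                  # the class CUDAWorker now forwards
        env={"CUDA_VISIBLE_DEVICES": "0,3,7,8"},  # per-worker device ordering
        nthreads=1,
    ) as nanny:
        print(nanny.worker_address)

# To run: ``import asyncio; asyncio.run(start_one_worker())``
```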
11 changes: 10 additions & 1 deletion dask_cuda/local_cuda_cluster.py
@@ -190,6 +190,7 @@ def __init__(
rmm_log_directory=None,
jit_unspill=None,
log_spilling=False,
worker_class=None,
**kwargs,
):
# Required by RAPIDS libraries (e.g., cuDF) to ensure no context
@@ -306,6 +307,14 @@
cuda_device_index=0,
)

if worker_class is not None:
from functools import partial

worker_class = partial(
LoggedNanny if log_spilling is True else Nanny,
worker_class=worker_class,
)

super().__init__(
n_workers=0,
threads_per_worker=threads_per_worker,
@@ -314,7 +323,7 @@
data=data,
local_directory=local_directory,
protocol=protocol,
-worker_class=LoggedNanny if log_spilling is True else Nanny,
+worker_class=worker_class,
config={
"ucx": get_ucx_config(
enable_tcp_over_ucx=enable_tcp_over_ucx,
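In ``LocalCUDACluster`` the ``worker_class`` argument of the parent ``LocalCluster``/``SpecCluster`` is the Nanny-like class that gets instantiated directly, so a plain ``Worker`` subclass cannot be passed through as-is; the PR instead pre-binds the requested class onto ``Nanny`` with ``functools.partial``. A small illustration of that pattern (``MyWorker`` is a placeholder, not part of the change):

```python
from functools import partial

from distributed import Nanny, Worker


class MyWorker(Worker):
    pass  # placeholder Worker subclass


# The cluster calls this exactly as it would call ``Nanny(...)``, but every
# Nanny created from it spawns ``MyWorker`` instead of the default ``Worker``.
nanny_factory = partial(Nanny, worker_class=MyWorker)
```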
6 changes: 4 additions & 2 deletions dask_cuda/tests/test_dask_cuda_worker.py
@@ -19,7 +19,7 @@


def test_cuda_visible_devices_and_memory_limit_and_nthreads(loop): # noqa: F811
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3,7,8"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,3,7,8"
nthreads = 4
try:
with popen(["dask-scheduler", "--port", "9359", "--no-dashboard"]):
@@ -34,6 +34,8 @@ def test_cuda_visible_devices_and_memory_limit_and_nthreads(loop): # noqa: F811
"--nthreads",
str(nthreads),
"--no-dashboard",
"--worker-class",
"dask_cuda.utils.MockWorker",
]
):
with Client("127.0.0.1:9359", loop=loop) as client:
@@ -44,7 +46,7 @@ def get_visible_devices():

# verify 4 workers with the 4 expected CUDA_VISIBLE_DEVICES
result = client.run(get_visible_devices)
expected = {"2,3,7,8": 1, "3,7,8,2": 1, "7,8,2,3": 1, "8,2,3,7": 1}
expected = {"0,3,7,8": 1, "3,7,8,0": 1, "7,8,0,3": 1, "8,0,3,7": 1}
for v in result.values():
del expected[v]

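The expected values in this test come from rotating a single device list once per worker, so every worker sees all four devices but each gets a different first (default) GPU. An illustrative helper reproducing the expected orderings (not the library's own implementation):

```python
def rotated_visible_devices(index, devices=("0", "3", "7", "8")):
    """Rotate the device list so worker ``index`` gets a different first GPU."""
    return ",".join(devices[index:] + devices[:index])


assert [rotated_visible_devices(i) for i in range(4)] == [
    "0,3,7,8",
    "3,7,8,0",
    "7,8,0,3",
    "8,0,3,7",
]
```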
16 changes: 10 additions & 6 deletions dask_cuda/tests/test_local_cuda_cluster.py
@@ -10,6 +10,7 @@

from dask_cuda import CUDAWorker, LocalCUDACluster, utils
from dask_cuda.initialize import initialize
from dask_cuda.utils import MockWorker

_driver_version = rmm._cuda.gpu.driverGetVersion()
_runtime_version = rmm._cuda.gpu.runtimeGetVersion()
@@ -51,10 +52,13 @@ def get_visible_devices():
# than 8 but as long as the test passes the errors can be ignored.
@gen_test(timeout=20)
async def test_with_subset_of_cuda_visible_devices():
os.environ["CUDA_VISIBLE_DEVICES"] = "2,3,6,7"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,3,6,8"
try:
async with LocalCUDACluster(
-scheduler_port=0, asynchronous=True, device_memory_limit=1
+scheduler_port=0,
+asynchronous=True,
+device_memory_limit=1,
+worker_class=MockWorker,
) as cluster:
async with Client(cluster, asynchronous=True) as client:
assert len(cluster.workers) == 4
@@ -68,10 +72,10 @@ def get_visible_devices():
assert all(len(v.split(",")) == 4 for v in result.values())
for i in range(4):
assert {int(v.split(",")[i]) for v in result.values()} == {
-2,
+0,
3,
6,
-7,
+8,
}
finally:
del os.environ["CUDA_VISIBLE_DEVICES"]
@@ -106,7 +110,7 @@ async def test_ucx_protocol_type_error():
@gen_test(timeout=20)
async def test_n_workers():
async with LocalCUDACluster(
-CUDA_VISIBLE_DEVICES="0,1", asynchronous=True
+CUDA_VISIBLE_DEVICES="0,1", worker_class=MockWorker, asynchronous=True
) as cluster:
assert len(cluster.workers) == 2
assert len(cluster.worker_spec) == 2
@@ -121,7 +125,7 @@ async def test_threads_per_worker():
@gen_test(timeout=20)
async def test_all_to_all():
async with LocalCUDACluster(
-CUDA_VISIBLE_DEVICES="0,1", asynchronous=True
+CUDA_VISIBLE_DEVICES="0,1", worker_class=MockWorker, asynchronous=True
) as cluster:
async with Client(cluster, asynchronous=True) as client:
workers = list(client.scheduler_info()["workers"])
27 changes: 26 additions & 1 deletion dask_cuda/utils.py
@@ -9,7 +9,7 @@
import pynvml
import toolz

-from distributed import wait
+from distributed import Worker, wait
from distributed.utils import parse_bytes

try:
@@ -530,3 +530,28 @@ def parse_device_memory_limit(device_memory_limit, device_index=0):
return parse_bytes(device_memory_limit)
else:
return int(device_memory_limit)


class MockWorker(Worker):
    """Mock Worker class that keeps the SystemMonitor from using NVML.

    By preventing the Worker from initializing NVML in the SystemMonitor, tests
    can mock multiple devices in ``CUDA_VISIBLE_DEVICES`` behavior on single-GPU
    machines.
    """

    def __init__(self, *args, **kwargs):
        import distributed

        # Save the original NVML hook before patching so ``__del__`` can restore it.
        self._device_get_count = distributed.diagnostics.nvml.device_get_count
        distributed.diagnostics.nvml.device_get_count = MockWorker.device_get_count
        super().__init__(*args, **kwargs)

    def __del__(self):
        import distributed

        distributed.diagnostics.nvml.device_get_count = self._device_get_count

    @staticmethod
    def device_get_count():
        return 0
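``MockWorker`` monkey-patches ``distributed.diagnostics.nvml.device_get_count`` to report zero devices, so the Worker's ``SystemMonitor`` never initializes NVML for GPUs that may not exist. A minimal usage sketch mirroring the tests above (the device IDs are placeholders; a CUDA-capable environment with dask-cuda installed is assumed):

```python
import asyncio

from dask_cuda import LocalCUDACluster
from dask_cuda.utils import MockWorker


async def main():
    # Four "visible" devices even on a single-GPU machine.
    async with LocalCUDACluster(
        CUDA_VISIBLE_DEVICES="0,1,2,3",
        worker_class=MockWorker,
        asynchronous=True,
    ) as cluster:
        print(len(cluster.workers))  # 4


# asyncio.run(main())
```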