diff --git a/python/raft-dask/raft_dask/common/utils.py b/python/raft-dask/raft_dask/common/utils.py index 78a899aa50..dcc53fda9a 100644 --- a/python/raft-dask/raft_dask/common/utils.py +++ b/python/raft-dask/raft_dask/common/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/python/raft-dask/raft_dask/test/__init__.py b/python/raft-dask/raft_dask/test/__init__.py new file mode 100644 index 0000000000..764e0f32fd --- /dev/null +++ b/python/raft-dask/raft_dask/test/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2020-2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/python/raft-dask/raft_dask/test/conftest.py b/python/raft-dask/raft_dask/test/conftest.py index 39ee21cbaa..d1baa684d4 100644 --- a/python/raft-dask/raft_dask/test/conftest.py +++ b/python/raft-dask/raft_dask/test/conftest.py @@ -1,54 +1,71 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. import os import pytest from dask.distributed import Client -from dask_cuda import LocalCUDACluster, initialize +from dask_cuda import LocalCUDACluster os.environ["UCX_LOG_LEVEL"] = "error" -enable_tcp_over_ucx = True -enable_nvlink = False -enable_infiniband = False - - @pytest.fixture(scope="session") def cluster(): - cluster = LocalCUDACluster(protocol="tcp", scheduler_port=0) - yield cluster - cluster.close() + scheduler_file = os.environ.get("SCHEDULER_FILE") + if scheduler_file: + yield scheduler_file + else: + cluster = LocalCUDACluster(protocol="tcp", scheduler_port=0) + yield cluster + cluster.close() @pytest.fixture(scope="session") def ucx_cluster(): - initialize.initialize( - create_cuda_context=True, - enable_tcp_over_ucx=enable_tcp_over_ucx, - enable_nvlink=enable_nvlink, - enable_infiniband=enable_infiniband, - ) - cluster = LocalCUDACluster( - protocol="ucx", - enable_tcp_over_ucx=enable_tcp_over_ucx, - enable_nvlink=enable_nvlink, - enable_infiniband=enable_infiniband, - ) - yield cluster - cluster.close() + scheduler_file = os.environ.get("SCHEDULER_FILE") + if scheduler_file: + yield scheduler_file + else: + cluster = LocalCUDACluster( + protocol="ucx", + ) + yield cluster + cluster.close() @pytest.fixture(scope="session") def client(cluster): - client = Client(cluster) + client = create_client(cluster) yield client client.close() @pytest.fixture() def ucx_client(ucx_cluster): - client = Client(cluster) + client = create_client(ucx_cluster) yield client client.close() + + +def create_client(cluster): + """ + Create a Dask distributed client for a specified cluster. + + Parameters + ---------- + cluster : LocalCUDACluster instance or str + If a LocalCUDACluster instance is provided, a client will be created + for it directly. If a string is provided, it should specify the path to + a Dask scheduler file. A client will then be created for the cluster + referenced by this scheduler file. + + Returns + ------- + dask.distributed.Client + A client connected to the specified cluster. + """ + if isinstance(cluster, LocalCUDACluster): + return Client(cluster) + else: + return Client(scheduler_file=cluster) diff --git a/python/raft-dask/raft_dask/test/test_comms.py b/python/raft-dask/raft_dask/test/test_comms.py index 3a430f9270..5c69a94fd8 100644 --- a/python/raft-dask/raft_dask/test/test_comms.py +++ b/python/raft-dask/raft_dask/test/test_comms.py @@ -17,7 +17,9 @@ import pytest -from dask.distributed import Client, get_worker, wait +from dask.distributed import get_worker, wait + +from .conftest import create_client try: from raft_dask.common import ( @@ -43,9 +45,7 @@ def test_comms_init_no_p2p(cluster): - - client = Client(cluster) - + client = create_client(cluster) try: cb = Comms(verbose=True) cb.init() @@ -121,8 +121,7 @@ def func_check_uid_on_worker(sessionId, uniqueId, dask_worker=None): def test_handles(cluster): - - client = Client(cluster) + client = create_client(cluster) def _has_handle(sessionId): return local_handle(sessionId, dask_worker=get_worker()) is not None