From 40ea1ca25cc4b8cd7662933cb029575be2da3492 Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Tue, 13 Jun 2023 12:36:34 -0700 Subject: [PATCH 1/9] Add scheduler_file argument to support MNMG setup --- python/raft-dask/raft_dask/test/conftest.py | 52 ++++++++++----------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/python/raft-dask/raft_dask/test/conftest.py b/python/raft-dask/raft_dask/test/conftest.py index 39ee21cbaa..ac4f6d74d6 100644 --- a/python/raft-dask/raft_dask/test/conftest.py +++ b/python/raft-dask/raft_dask/test/conftest.py @@ -1,54 +1,54 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. import os import pytest from dask.distributed import Client -from dask_cuda import LocalCUDACluster, initialize +from dask_cuda import LocalCUDACluster os.environ["UCX_LOG_LEVEL"] = "error" -enable_tcp_over_ucx = True -enable_nvlink = False -enable_infiniband = False - - @pytest.fixture(scope="session") def cluster(): - cluster = LocalCUDACluster(protocol="tcp", scheduler_port=0) - yield cluster - cluster.close() + scheduler_file = os.environ.get("SCHEDULER_FILE") + if scheduler_file: + return scheduler_file + else: + cluster = LocalCUDACluster(protocol="tcp", scheduler_port=0) + yield cluster + cluster.close() @pytest.fixture(scope="session") def ucx_cluster(): - initialize.initialize( - create_cuda_context=True, - enable_tcp_over_ucx=enable_tcp_over_ucx, - enable_nvlink=enable_nvlink, - enable_infiniband=enable_infiniband, - ) - cluster = LocalCUDACluster( - protocol="ucx", - enable_tcp_over_ucx=enable_tcp_over_ucx, - enable_nvlink=enable_nvlink, - enable_infiniband=enable_infiniband, - ) - yield cluster - cluster.close() + scheduler_file = os.environ.get("SCHEDULER_FILE") + if scheduler_file: + return scheduler_file + else: + cluster = LocalCUDACluster( + protocol="ucx", + ) + yield cluster + cluster.close() @pytest.fixture(scope="session") def client(cluster): - client = Client(cluster) + if isinstance(cluster, LocalCUDACluster): + client = Client(cluster) + else: + client = Client(scheduler_file=cluster) yield client client.close() @pytest.fixture() def ucx_client(ucx_cluster): - client = Client(cluster) + if isinstance(ucx_cluster, LocalCUDACluster): + client = Client(ucx_cluster) + else: + client = Client(scheduler_file=cluster) yield client client.close() From 1e594dbee3baacf97b711ae70aaedcfe6a00125b Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Tue, 13 Jun 2023 15:00:51 -0700 Subject: [PATCH 2/9] Verfied working with non local cuda cluster --- python/raft-dask/raft_dask/common/utils.py | 12 ++++++++++-- python/raft-dask/raft_dask/test/__init__.py | 16 ++++++++++++++++ python/raft-dask/raft_dask/test/conftest.py | 17 ++++++----------- python/raft-dask/raft_dask/test/test_comms.py | 11 +++++------ 4 files changed, 37 insertions(+), 19 deletions(-) create mode 100644 python/raft-dask/raft_dask/test/__init__.py diff --git a/python/raft-dask/raft_dask/common/utils.py b/python/raft-dask/raft_dask/common/utils.py index 78a899aa50..4c7eafb43c 100644 --- a/python/raft-dask/raft_dask/common/utils.py +++ b/python/raft-dask/raft_dask/common/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,7 +13,8 @@ # limitations under the License. # -from dask.distributed import default_client +from dask.distributed import Client, default_client +from dask_cuda import LocalCUDACluster def get_client(client=None): @@ -37,3 +38,10 @@ def parse_host_port(address): host, port = address.split(":") port = int(port) return host, port + + +def create_client(cluster): + if isinstance(cluster, LocalCUDACluster): + return Client(cluster) + else: + return Client(scheduler_file=cluster) diff --git a/python/raft-dask/raft_dask/test/__init__.py b/python/raft-dask/raft_dask/test/__init__.py new file mode 100644 index 0000000000..f294906058 --- /dev/null +++ b/python/raft-dask/raft_dask/test/__init__.py @@ -0,0 +1,16 @@ +# Copyright (c) 2020-2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +__version__ = "23.08.00" diff --git a/python/raft-dask/raft_dask/test/conftest.py b/python/raft-dask/raft_dask/test/conftest.py index ac4f6d74d6..64afecdb3f 100644 --- a/python/raft-dask/raft_dask/test/conftest.py +++ b/python/raft-dask/raft_dask/test/conftest.py @@ -4,9 +4,10 @@ import pytest -from dask.distributed import Client from dask_cuda import LocalCUDACluster +from raft_dask.common.utils import create_client + os.environ["UCX_LOG_LEVEL"] = "error" @@ -14,7 +15,7 @@ def cluster(): scheduler_file = os.environ.get("SCHEDULER_FILE") if scheduler_file: - return scheduler_file + yield scheduler_file else: cluster = LocalCUDACluster(protocol="tcp", scheduler_port=0) yield cluster @@ -25,7 +26,7 @@ def cluster(): def ucx_cluster(): scheduler_file = os.environ.get("SCHEDULER_FILE") if scheduler_file: - return scheduler_file + yield scheduler_file else: cluster = LocalCUDACluster( protocol="ucx", @@ -36,19 +37,13 @@ def ucx_cluster(): @pytest.fixture(scope="session") def client(cluster): - if isinstance(cluster, LocalCUDACluster): - client = Client(cluster) - else: - client = Client(scheduler_file=cluster) + client = create_client(cluster) yield client client.close() @pytest.fixture() def ucx_client(ucx_cluster): - if isinstance(ucx_cluster, LocalCUDACluster): - client = Client(ucx_cluster) - else: - client = Client(scheduler_file=cluster) + client = create_client(ucx_cluster) yield client client.close() diff --git a/python/raft-dask/raft_dask/test/test_comms.py b/python/raft-dask/raft_dask/test/test_comms.py index 3a430f9270..bcf2edc976 100644 --- a/python/raft-dask/raft_dask/test/test_comms.py +++ b/python/raft-dask/raft_dask/test/test_comms.py @@ -17,7 +17,9 @@ import pytest -from dask.distributed import Client, get_worker, wait +from dask.distributed import get_worker, wait + +from raft_dask.common.utils import create_client try: from raft_dask.common import ( @@ -43,9 +45,7 @@ def test_comms_init_no_p2p(cluster): - - client = Client(cluster) - + client = create_client(cluster) try: cb = Comms(verbose=True) cb.init() @@ -121,8 +121,7 @@ def func_check_uid_on_worker(sessionId, uniqueId, dask_worker=None): def test_handles(cluster): - - client = Client(cluster) + client = create_client(cluster) def _has_handle(sessionId): return local_handle(sessionId, dask_worker=get_worker()) is not None From e943d2d904490b1988aaff5be679a56e623ddf2f Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Tue, 13 Jun 2023 15:10:17 -0700 Subject: [PATCH 3/9] Update __init__.py --- python/raft-dask/raft_dask/test/__init__.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/raft-dask/raft_dask/test/__init__.py b/python/raft-dask/raft_dask/test/__init__.py index f294906058..68884aab74 100644 --- a/python/raft-dask/raft_dask/test/__init__.py +++ b/python/raft-dask/raft_dask/test/__init__.py @@ -12,5 +12,3 @@ # See the License for the specific language governing permissions and # limitations under the License. # - -__version__ = "23.08.00" From 1271dc7a5addb4fbdabb28cfcc6b680d10cfd214 Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Tue, 13 Jun 2023 15:10:46 -0700 Subject: [PATCH 4/9] Update __init__.py --- python/raft-dask/raft_dask/test/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/raft-dask/raft_dask/test/__init__.py b/python/raft-dask/raft_dask/test/__init__.py index 68884aab74..764e0f32fd 100644 --- a/python/raft-dask/raft_dask/test/__init__.py +++ b/python/raft-dask/raft_dask/test/__init__.py @@ -11,4 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# From af8e2c0731b9ba4d4a312177d23e3d4e5c6ea51d Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Tue, 13 Jun 2023 15:30:41 -0700 Subject: [PATCH 5/9] Added docstring for create_client --- python/raft-dask/raft_dask/common/utils.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/python/raft-dask/raft_dask/common/utils.py b/python/raft-dask/raft_dask/common/utils.py index 4c7eafb43c..21c6ca72cc 100644 --- a/python/raft-dask/raft_dask/common/utils.py +++ b/python/raft-dask/raft_dask/common/utils.py @@ -41,6 +41,22 @@ def parse_host_port(address): def create_client(cluster): + """ + Create a Dask distributed client for a specified cluster. + + Parameters + ---------- + cluster : LocalCUDACluster instance or str + If a LocalCUDACluster instance is provided, a client will be created + for it directly. If a string is provided, it should specify the path to + a Dask scheduler file. A client will then be created for the cluster + referenced by this scheduler file. + + Returns + ------- + dask.distributed.Client + A client connected to the specified cluster. + """ if isinstance(cluster, LocalCUDACluster): return Client(cluster) else: From f4530c3a91a77d444d9facdd03df30e96d914eca Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Wed, 14 Jun 2023 21:03:04 -0700 Subject: [PATCH 6/9] Trying to fix import issues found on CI --- python/raft-dask/raft_dask/common/comms.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/raft-dask/raft_dask/common/comms.py b/python/raft-dask/raft_dask/common/comms.py index 7a0b786ec4..e39721f6d9 100644 --- a/python/raft-dask/raft_dask/common/comms.py +++ b/python/raft-dask/raft_dask/common/comms.py @@ -25,13 +25,13 @@ from pylibraft.common.handle import Handle -from .comms_utils import ( +from raft_dask.common.comms_utils import ( inject_comms_on_handle, inject_comms_on_handle_coll_only, ) -from .nccl import nccl -from .ucx import UCX -from .utils import parse_host_port +from raft_dask.common.nccl import nccl +from raft_dask.common.ucx import UCX +from raft_dask.common.utils import parse_host_port logger = logging.getLogger(__name__) From 5281041f8e1436f7a327edeecfc6b365d7fda4e1 Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Wed, 14 Jun 2023 23:08:38 -0700 Subject: [PATCH 7/9] fix import issue on CI --- python/raft-dask/raft_dask/common/__init__.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/raft-dask/raft_dask/common/__init__.py b/python/raft-dask/raft_dask/common/__init__.py index c8ce695def..aaaa373723 100644 --- a/python/raft-dask/raft_dask/common/__init__.py +++ b/python/raft-dask/raft_dask/common/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,8 +13,8 @@ # limitations under the License. # -from .comms import Comms, local_handle -from .comms_utils import ( +from raft_dask.common.comms import Comms, local_handle +from raft_dask.common.comms_utils import ( inject_comms_on_handle, inject_comms_on_handle_coll_only, perform_test_comm_split, @@ -30,4 +30,4 @@ perform_test_comms_reducescatter, perform_test_comms_send_recv, ) -from .ucx import UCX +from raft_dask.common.ucx import UCX From 6629add78c7ab634ed54e5c4a5ee9a7225d44259 Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Thu, 22 Jun 2023 10:50:42 -0700 Subject: [PATCH 8/9] Try to fix import issues --- python/raft-dask/raft_dask/common/utils.py | 26 +------------------ python/raft-dask/raft_dask/test/conftest.py | 26 +++++++++++++++++-- python/raft-dask/raft_dask/test/test_comms.py | 2 +- 3 files changed, 26 insertions(+), 28 deletions(-) diff --git a/python/raft-dask/raft_dask/common/utils.py b/python/raft-dask/raft_dask/common/utils.py index 21c6ca72cc..dcc53fda9a 100644 --- a/python/raft-dask/raft_dask/common/utils.py +++ b/python/raft-dask/raft_dask/common/utils.py @@ -13,8 +13,7 @@ # limitations under the License. # -from dask.distributed import Client, default_client -from dask_cuda import LocalCUDACluster +from dask.distributed import default_client def get_client(client=None): @@ -38,26 +37,3 @@ def parse_host_port(address): host, port = address.split(":") port = int(port) return host, port - - -def create_client(cluster): - """ - Create a Dask distributed client for a specified cluster. - - Parameters - ---------- - cluster : LocalCUDACluster instance or str - If a LocalCUDACluster instance is provided, a client will be created - for it directly. If a string is provided, it should specify the path to - a Dask scheduler file. A client will then be created for the cluster - referenced by this scheduler file. - - Returns - ------- - dask.distributed.Client - A client connected to the specified cluster. - """ - if isinstance(cluster, LocalCUDACluster): - return Client(cluster) - else: - return Client(scheduler_file=cluster) diff --git a/python/raft-dask/raft_dask/test/conftest.py b/python/raft-dask/raft_dask/test/conftest.py index 64afecdb3f..d1baa684d4 100644 --- a/python/raft-dask/raft_dask/test/conftest.py +++ b/python/raft-dask/raft_dask/test/conftest.py @@ -4,10 +4,9 @@ import pytest +from dask.distributed import Client from dask_cuda import LocalCUDACluster -from raft_dask.common.utils import create_client - os.environ["UCX_LOG_LEVEL"] = "error" @@ -47,3 +46,26 @@ def ucx_client(ucx_cluster): client = create_client(ucx_cluster) yield client client.close() + + +def create_client(cluster): + """ + Create a Dask distributed client for a specified cluster. + + Parameters + ---------- + cluster : LocalCUDACluster instance or str + If a LocalCUDACluster instance is provided, a client will be created + for it directly. If a string is provided, it should specify the path to + a Dask scheduler file. A client will then be created for the cluster + referenced by this scheduler file. + + Returns + ------- + dask.distributed.Client + A client connected to the specified cluster. + """ + if isinstance(cluster, LocalCUDACluster): + return Client(cluster) + else: + return Client(scheduler_file=cluster) diff --git a/python/raft-dask/raft_dask/test/test_comms.py b/python/raft-dask/raft_dask/test/test_comms.py index bcf2edc976..5c69a94fd8 100644 --- a/python/raft-dask/raft_dask/test/test_comms.py +++ b/python/raft-dask/raft_dask/test/test_comms.py @@ -19,7 +19,7 @@ from dask.distributed import get_worker, wait -from raft_dask.common.utils import create_client +from .conftest import create_client try: from raft_dask.common import ( From 15814662fc6bd9d364dd55aa9f3a6c6075aacb3d Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Thu, 22 Jun 2023 10:58:07 -0700 Subject: [PATCH 9/9] Try to fix import issues --- python/raft-dask/raft_dask/common/__init__.py | 8 ++++---- python/raft-dask/raft_dask/common/comms.py | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/python/raft-dask/raft_dask/common/__init__.py b/python/raft-dask/raft_dask/common/__init__.py index aaaa373723..c8ce695def 100644 --- a/python/raft-dask/raft_dask/common/__init__.py +++ b/python/raft-dask/raft_dask/common/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2023, NVIDIA CORPORATION. +# Copyright (c) 2020-2022, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,8 +13,8 @@ # limitations under the License. # -from raft_dask.common.comms import Comms, local_handle -from raft_dask.common.comms_utils import ( +from .comms import Comms, local_handle +from .comms_utils import ( inject_comms_on_handle, inject_comms_on_handle_coll_only, perform_test_comm_split, @@ -30,4 +30,4 @@ perform_test_comms_reducescatter, perform_test_comms_send_recv, ) -from raft_dask.common.ucx import UCX +from .ucx import UCX diff --git a/python/raft-dask/raft_dask/common/comms.py b/python/raft-dask/raft_dask/common/comms.py index e39721f6d9..7a0b786ec4 100644 --- a/python/raft-dask/raft_dask/common/comms.py +++ b/python/raft-dask/raft_dask/common/comms.py @@ -25,13 +25,13 @@ from pylibraft.common.handle import Handle -from raft_dask.common.comms_utils import ( +from .comms_utils import ( inject_comms_on_handle, inject_comms_on_handle_coll_only, ) -from raft_dask.common.nccl import nccl -from raft_dask.common.ucx import UCX -from raft_dask.common.utils import parse_host_port +from .nccl import nccl +from .ucx import UCX +from .utils import parse_host_port logger = logging.getLogger(__name__)