diff --git a/ci/test_python.sh b/ci/test_python.sh
index 3a47d7e1490..447b0a3e786 100755
--- a/ci/test_python.sh
+++ b/ci/test_python.sh
@@ -3,6 +3,10 @@
 
 set -euo pipefail
 
+# TODO: Enable dask query planning (by default) once some bugs are fixed.
+# xref: https://github.com/rapidsai/cudf/issues/15027
+export DASK_DATAFRAME__QUERY_PLANNING=False
+
 # Support invoking test_python.sh outside the script directory
 cd "$(dirname "$(realpath "${BASH_SOURCE[0]}")")"/../
 
diff --git a/ci/test_wheel.sh b/ci/test_wheel.sh
index 158704e08d1..cda40d92c74 100755
--- a/ci/test_wheel.sh
+++ b/ci/test_wheel.sh
@@ -3,6 +3,10 @@
 
 set -eoxu pipefail
 
+# TODO: Enable dask query planning (by default) once some bugs are fixed.
+# xref: https://github.com/rapidsai/cudf/issues/15027
+export DASK_DATAFRAME__QUERY_PLANNING=False
+
 package_name=$1
 package_dir=$2
 
diff --git a/python/cugraph/cugraph/dask/common/part_utils.py b/python/cugraph/cugraph/dask/common/part_utils.py
index 25311902b29..d362502f239 100644
--- a/python/cugraph/cugraph/dask/common/part_utils.py
+++ b/python/cugraph/cugraph/dask/common/part_utils.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2019-2023, NVIDIA CORPORATION.
+# Copyright (c) 2019-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -115,7 +115,10 @@ def persist_dask_df_equal_parts_per_worker(
         raise ValueError("return_type must be either 'dask_cudf.DataFrame' or 'dict'")
 
     ddf_keys = dask_df.to_delayed()
-    workers = client.scheduler_info()["workers"].keys()
+    rank_to_worker = Comms.rank_to_worker(client)
+    # rank-worker mappings are in ascending order
+    workers = dict(sorted(rank_to_worker.items())).values()
+
     ddf_keys_ls = _chunk_lst(ddf_keys, len(workers))
     persisted_keys_d = {}
     for w, ddf_k in zip(workers, ddf_keys_ls):
diff --git a/python/cugraph/cugraph/dask/comms/comms.py b/python/cugraph/cugraph/dask/comms/comms.py
index d623f20a038..5499b13af03 100644
--- a/python/cugraph/cugraph/dask/comms/comms.py
+++ b/python/cugraph/cugraph/dask/comms/comms.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2018-2023, NVIDIA CORPORATION.
+# Copyright (c) 2018-2024, NVIDIA CORPORATION.
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
 # You may obtain a copy of the License at
@@ -265,3 +265,16 @@ def get_n_workers(sID=None, dask_worker=None):
         dask_worker = get_worker()
     sessionstate = get_raft_comm_state(sID, dask_worker)
     return sessionstate["nworkers"]
+
+
+def rank_to_worker(client):
+    """
+    Return a mapping of ranks to dask workers.
+    """
+    workers = client.scheduler_info()["workers"].keys()
+    worker_info = __instance.worker_info(workers)
+    rank_to_worker = {}
+    for w in worker_info:
+        rank_to_worker[worker_info[w]["rank"]] = w
+
+    return rank_to_worker