Remove the assumption made on the client data's keys #3835
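Background for the change: Dask graph keys are not guaranteed to be strings. For dataframe collections they are typically tuples of the form (name, partition_index), and the scheduler APIs (has_what, who_has) report the same raw keys, so stringifying one side of a comparison silently breaks the match. A minimal sketch with hypothetical key values:

    # Raw Dask-style keys are often tuples; the scheduler reports them as-is.
    key = ("from_pandas-abc123", 0)  # hypothetical (name, partition_index) key
    reported_keys = {("from_pandas-abc123", 0), ("from_pandas-abc123", 1)}

    assert key in reported_keys           # the raw tuple matches
    assert str(key) not in reported_keys  # "('from_pandas-abc123', 0)" does not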

python/cugraph/cugraph/dask/common/part_utils.py: 20 changes (10 additions, 10 deletions)
@@ -73,7 +73,7 @@ def persist_distributed_data(dask_df, client):
     _keys = dask_df.__dask_keys__()
     worker_dict = {}
     for i, key in enumerate(_keys):
-        worker_dict[str(key)] = tuple([worker_addresses[i]])
+        worker_dict[key] = tuple([worker_addresses[i]])
     persisted = client.persist(dask_df, workers=worker_dict)
     parts = futures_of(persisted)
     return parts
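For context, persist_distributed_data pins each partition to one worker by keying the workers mapping with the partition's raw key. A sketch of the fixed behavior, mirroring the call above but with dask.dataframe and a LocalCluster standing in for dask_cudf on a GPU cluster (stand-ins, not part of this PR):

    import pandas as pd
    import dask.dataframe as dd
    from dask.distributed import Client, LocalCluster, futures_of

    if __name__ == "__main__":
        with LocalCluster(n_workers=2) as cluster, Client(cluster) as client:
            ddf = dd.from_pandas(pd.DataFrame({"x": range(8)}), npartitions=2)
            workers = list(client.scheduler_info()["workers"])

            # Key the mapping with the raw key, not str(key), so it matches
            # what the scheduler expects.
            worker_dict = {
                key: (workers[i % len(workers)],)
                for i, key in enumerate(ddf.__dask_keys__())
            }
            persisted = client.persist(ddf, workers=worker_dict)
            print([f.key for f in futures_of(persisted)])  # e.g. (name, 0), ...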
@@ -89,7 +89,7 @@ def get_persisted_df_worker_map(dask_df, client):
     ddf_keys = futures_of(dask_df)
     output_map = {}
     for w, w_keys in client.has_what().items():
-        output_map[w] = [ddf_k for ddf_k in ddf_keys if str(ddf_k.key) in w_keys]
+        output_map[w] = [ddf_k for ddf_k in ddf_keys if ddf_k.key in w_keys]
         if len(output_map[w]) == 0:
             output_map[w] = _create_empty_dask_df_future(dask_df._meta, client, w)
     return output_map
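client.has_what() returns a mapping of the form {worker_address: (raw_key, ...)}, so the membership test above now compares raw keys on both sides. A minimal sketch with hypothetical keys and worker addresses:

    # Hypothetical shape of client.has_what() on a two-worker cluster.
    has_what = {
        "tcp://127.0.0.1:40001": (("frame-abc", 0), ("frame-abc", 1)),
        "tcp://127.0.0.1:40002": (("frame-abc", 2),),
    }
    future_keys = [("frame-abc", i) for i in range(3)]

    output_map = {
        w: [k for k in future_keys if k in w_keys]  # raw keys, no str()
        for w, w_keys in has_what.items()
    }
    print(output_map)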
@@ -157,7 +157,7 @@ async def _extract_partitions(
     # NOTE: We colocate (X, y) here by zipping delayed
     # n partitions of them as (X1, y1), (X2, y2)...
     # and asking client to compute a single future for
-    # each tuple in the list
+    # each tuple in the list.
     dela = [np.asarray(d.to_delayed()) for d in dask_obj]

     # TODO: ravel() is causing strange behavior w/ delayed Arrays which are
@@ -167,7 +167,7 @@
     parts = client.compute([p for p in zip(*raveled)])

     await wait(parts)
-    key_to_part = [(str(part.key), part) for part in parts]
+    key_to_part = [(part.key, part) for part in parts]
     who_has = await client.who_has(parts)
     return [(first(who_has[key]), part) for key, part in key_to_part]
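client.who_has() is the inverse view, {raw_key: (worker_address, ...)}; with unstringified keys the lookup who_has[key] needs no conversion, and first(...) in the source simply picks the first worker holding the part. A sketch with hypothetical values:

    # Hypothetical shape of `await client.who_has(parts)`.
    who_has = {
        ("zip-abc", 0): ("tcp://127.0.0.1:40001",),
        ("zip-abc", 1): ("tcp://127.0.0.1:40002",),
    }
    key_to_part = [(key, object()) for key in who_has]  # stand-in futures
    placement = [(who_has[key][0], part) for key, part in key_to_part]
    print([w for w, _ in placement])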

@@ -229,7 +229,7 @@ def load_balance_func(ddf_, by, client=None):
     wait(parts)

     who_has = client.who_has(parts)
-    key_to_part = [(str(part.key), part) for part in parts]
+    key_to_part = [(part.key, part) for part in parts]
     gpu_fututres = [
         (first(who_has[key]), part.key[1], part) for key, part in key_to_part
     ]
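Note that these triples read part.key[1], the partition index, which only exists while the key keeps its (name, partition_index) tuple shape; flattening it with str() would lose that structure. A short sketch with a hypothetical key:

    key = ("shuffle-def", 2)  # hypothetical (name, partition_index) key
    print(key[1])             # 2, the partition index
    print(str(key)[1])        # a meaningless character from the flattened string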
@@ -245,7 +245,7 @@ def load_balance_func(ddf_, by, client=None):
     for cumsum in cumsum_parts:
         num_rows.append(cumsum.iloc[-1])

-    # Calculate current partition divisions
+    # Calculate current partition divisions.
     divisions = [sum(num_rows[0:x:1]) for x in range(0, len(num_rows) + 1)]
     divisions[-1] = divisions[-1] - 1
     divisions = tuple(divisions)
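A worked example of the divisions arithmetic in this hunk, assuming three partitions holding 3, 4, and 5 rows after balancing:

    num_rows = [3, 4, 5]
    divisions = [sum(num_rows[0:x:1]) for x in range(0, len(num_rows) + 1)]
    print(divisions)         # [0, 3, 7, 12] -- cumulative row offsets
    divisions[-1] -= 1       # the last division is the final index, hence 11
    print(tuple(divisions))  # (0, 3, 7, 11)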
@@ -271,25 +271,25 @@ def load_balance_func(ddf_, by, client=None):

 def concat_dfs(df_list):
     """
-    Concat a list of cudf dataframes
+    Concat a list of cudf dataframes.
     """
     return cudf.concat(df_list)


 def get_delayed_dict(ddf):
     """
     Returns a dicitionary with the dataframe tasks as keys and
-    the dataframe delayed objects as values
+    the dataframe delayed objects as values.
     """
     df_delayed = {}
     for delayed_obj in ddf.to_delayed():
-        df_delayed[str(delayed_obj.key)] = delayed_obj
+        df_delayed[delayed_obj.key] = delayed_obj
     return df_delayed


 def concat_within_workers(client, ddf):
     """
-    Concats all partitions within workers without transfers
+    Concats all partitions within workers without transfers.
     """
     df_delayed = get_delayed_dict(ddf)