diff --git a/distributed/shuffle/__init__.py b/distributed/shuffle/__init__.py
index 2a073116781..64c7fe012b4 100644
--- a/distributed/shuffle/__init__.py
+++ b/distributed/shuffle/__init__.py
@@ -1,8 +1,14 @@
 from __future__ import annotations
 
-from distributed.shuffle.shuffle import rearrange_by_column_p2p
-from distributed.shuffle.shuffle_extension import (
-    ShuffleId,
+from distributed.shuffle._shuffle import P2PShuffleLayer, rearrange_by_column_p2p
+from distributed.shuffle._shuffle_extension import (
     ShuffleSchedulerExtension,
     ShuffleWorkerExtension,
 )
+
+__all__ = [
+    "P2PShuffleLayer",
+    "rearrange_by_column_p2p",
+    "ShuffleSchedulerExtension",
+    "ShuffleWorkerExtension",
+]
diff --git a/distributed/shuffle/arrow.py b/distributed/shuffle/_arrow.py
similarity index 100%
rename from distributed/shuffle/arrow.py
rename to distributed/shuffle/_arrow.py
diff --git a/distributed/shuffle/multi_comm.py b/distributed/shuffle/_multi_comm.py
similarity index 100%
rename from distributed/shuffle/multi_comm.py
rename to distributed/shuffle/_multi_comm.py
diff --git a/distributed/shuffle/multi_file.py b/distributed/shuffle/_multi_file.py
similarity index 100%
rename from distributed/shuffle/multi_file.py
rename to distributed/shuffle/_multi_file.py
diff --git a/distributed/shuffle/shuffle.py b/distributed/shuffle/_shuffle.py
similarity index 91%
rename from distributed/shuffle/shuffle.py
rename to distributed/shuffle/_shuffle.py
index ca3fe0c3c0d..df5f0c6b712 100644
--- a/distributed/shuffle/shuffle.py
+++ b/distributed/shuffle/_shuffle.py
@@ -6,7 +6,7 @@
 from dask.highlevelgraph import HighLevelGraph
 from dask.layers import SimpleShuffleLayer
 
-from distributed.shuffle.shuffle_extension import ShuffleId, ShuffleWorkerExtension
+from distributed.shuffle._shuffle_extension import ShuffleId, ShuffleWorkerExtension
 
 if TYPE_CHECKING:
     import pandas as pd
@@ -14,7 +14,7 @@
     from dask.dataframe import DataFrame
 
 
-def get_ext() -> ShuffleWorkerExtension:
+def _get_worker_extension() -> ShuffleWorkerExtension:
     from distributed import get_worker
 
     try:
@@ -25,7 +25,7 @@ def get_ext() -> ShuffleWorkerExtension:
             "please confirm that you've created a distributed Client and are submitting this computation through it."
         ) from e
     extension: ShuffleWorkerExtension | None = worker.extensions.get("shuffle")
-    if not extension:
+    if extension is None:
         raise RuntimeError(
             f"The worker {worker.address} does not have a ShuffleExtension. "
             "Is pandas installed on the worker?"
@@ -39,17 +39,19 @@ def shuffle_transfer(
     npartitions: int,
     column: str,
 ) -> None:
-    get_ext().add_partition(input, id, npartitions=npartitions, column=column)
+    _get_worker_extension().add_partition(
+        input, id, npartitions=npartitions, column=column
+    )
 
 
 def shuffle_unpack(
     id: ShuffleId, output_partition: int, barrier: object
 ) -> pd.DataFrame:
-    return get_ext().get_output_partition(id, output_partition)
+    return _get_worker_extension().get_output_partition(id, output_partition)
 
 
 def shuffle_barrier(id: ShuffleId, transfers: list[None]) -> None:
-    get_ext().barrier(id)
+    return _get_worker_extension().barrier(id)
 
 
 def rearrange_by_column_p2p(
diff --git a/distributed/shuffle/shuffle_extension.py b/distributed/shuffle/_shuffle_extension.py
similarity index 98%
rename from distributed/shuffle/shuffle_extension.py
rename to distributed/shuffle/_shuffle_extension.py
index bad94aeee95..ded1d84c985 100644
--- a/distributed/shuffle/shuffle_extension.py
+++ b/distributed/shuffle/_shuffle_extension.py
@@ -13,14 +13,14 @@
 import toolz
 
 from distributed.protocol import to_serialize
-from distributed.shuffle.arrow import (
+from distributed.shuffle._arrow import (
     deserialize_schema,
     dump_batch,
     list_of_buffers_to_table,
     load_arrow,
 )
-from distributed.shuffle.multi_comm import MultiComm
-from distributed.shuffle.multi_file import MultiFile
+from distributed.shuffle._multi_comm import MultiComm
+from distributed.shuffle._multi_file import MultiFile
 from distributed.utils import log_errors, sync
 
 if TYPE_CHECKING:
@@ -559,7 +559,9 @@ def split_by_worker(
     if not nrows:
         return {}
     # assert len(df) == nrows  # Not true if some outputs aren't wanted
-    t = pa.Table.from_pandas(df)
+    # FIXME: If we do not preserve the index something is corrupting the
+    # bytestream such that it cannot be deserialized anymore
+    t = pa.Table.from_pandas(df, preserve_index=True)
     t = t.sort_by("_worker")
     codes = np.asarray(t.select(["_worker"]))[0]
     t = t.drop(["_worker"])
diff --git a/distributed/shuffle/tests/test_graph.py b/distributed/shuffle/tests/test_graph.py
index 028b955aa5e..ffa84553110 100644
--- a/distributed/shuffle/tests/test_graph.py
+++ b/distributed/shuffle/tests/test_graph.py
@@ -13,7 +13,7 @@
 from dask.blockwise import Blockwise
 from dask.utils_test import hlg_layer_topological
 
-from distributed.shuffle.shuffle_extension import ShuffleWorkerExtension
+from distributed.shuffle._shuffle_extension import ShuffleWorkerExtension
 from distributed.utils_test import gen_cluster
 
 
diff --git a/distributed/shuffle/tests/test_multi_comm.py b/distributed/shuffle/tests/test_multi_comm.py
index d45a4d18b93..90919511701 100644
--- a/distributed/shuffle/tests/test_multi_comm.py
+++ b/distributed/shuffle/tests/test_multi_comm.py
@@ -5,7 +5,7 @@
 
 import pytest
 
-from distributed.shuffle.multi_comm import MultiComm
+from distributed.shuffle._multi_comm import MultiComm
 from distributed.utils_test import gen_test
 
 
diff --git a/distributed/shuffle/tests/test_multi_file.py b/distributed/shuffle/tests/test_multi_file.py
index 16ebc7562c9..9dddd3cd3a7 100644
--- a/distributed/shuffle/tests/test_multi_file.py
+++ b/distributed/shuffle/tests/test_multi_file.py
@@ -5,7 +5,7 @@
 
 import pytest
 
-from distributed.shuffle.multi_file import MultiFile
+from distributed.shuffle._multi_file import MultiFile
 from distributed.utils_test import gen_test
 
 
diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py
index 63e245bb23d..5d15f68ffc6 100644
--- a/distributed/shuffle/tests/test_shuffle.py
+++ b/distributed/shuffle/tests/test_shuffle.py
@@ -17,7 +17,7 @@
 from dask.distributed import Worker
 from dask.utils import stringify
 
-from distributed.shuffle.shuffle_extension import (
+from distributed.shuffle._shuffle_extension import (
     dump_batch,
     list_of_buffers_to_table,
     load_arrow,
diff --git a/distributed/shuffle/tests/test_shuffle_extension.py b/distributed/shuffle/tests/test_shuffle_extension.py
index b3ffbcc6988..78b6e3b64fd 100644
--- a/distributed/shuffle/tests/test_shuffle_extension.py
+++ b/distributed/shuffle/tests/test_shuffle_extension.py
@@ -5,7 +5,7 @@
 pd = pytest.importorskip("pandas")
 dd = pytest.importorskip("dask.dataframe")
 
-from distributed.shuffle.shuffle_extension import (
+from distributed.shuffle._shuffle_extension import (
     ShuffleWorkerExtension,
     get_worker_for,
     split_by_partition,
@@ -22,7 +22,6 @@ async def test_installation(s, a):
     assert a.handlers["shuffle_inputs_done"] == ext.shuffle_inputs_done
 
 
-@pytest.mark.skip
 def test_split_by_worker():
     df = pd.DataFrame(
         {
@@ -30,17 +29,32 @@ def test_split_by_worker():
             "_partition": [0, 1, 2, 0, 1],
         }
     )
+    workers = ["alice", "bob"]
+    worker_for_mapping = {}
     npartitions = 3
-
-    out = split_by_worker(df, "_partition", npartitions, workers)
+    for part in range(npartitions):
+        worker_for_mapping[part] = get_worker_for(part, workers, npartitions)
+    worker_for = pd.Series(worker_for_mapping, name="_workers").astype("category")
+    out = split_by_worker(df, "_partition", worker_for)
     assert set(out) == {"alice", "bob"}
-    assert out["alice"].column_names == list(df.columns)
+    assert list(out["alice"].to_pandas().columns) == list(df.columns)
     assert sum(map(len, out.values())) == len(df)
 
 
-@pytest.mark.skip
+def test_split_by_worker_empty():
+    df = pd.DataFrame(
+        {
+            "x": [1, 2, 3, 4, 5],
+            "_partition": [0, 1, 2, 0, 1],
+        }
+    )
+    worker_for = pd.Series({5: "chuck"}, name="_workers").astype("category")
+    out = split_by_worker(df, "_partition", worker_for)
+    assert out == {}
+
+
 def test_split_by_worker_many_workers():
     df = pd.DataFrame(
         {
@@ -50,8 +64,11 @@
     )
     workers = ["a", "b", "c", "d", "e", "f", "g", "h"]
     npartitions = 10
-
-    out = split_by_worker(df, "_partition", npartitions, workers)
+    worker_for_mapping = {}
+    for part in range(npartitions):
+        worker_for_mapping[part] = get_worker_for(part, workers, npartitions)
+    worker_for = pd.Series(worker_for_mapping, name="_workers").astype("category")
+    out = split_by_worker(df, "_partition", worker_for)
     assert get_worker_for(5, workers, npartitions) in out
     assert get_worker_for(0, workers, npartitions) in out
     assert get_worker_for(7, workers, npartitions) in out
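
A minimal usage sketch (not part of the patch) of the reworked split_by_worker
signature, as exercised by the updated tests above: instead of npartitions plus
a worker list, the function now takes a categorical worker_for Series that maps
each output partition number to a worker address. The sketch mirrors
test_split_by_worker and assumes pandas and pyarrow are installed:

    import pandas as pd

    from distributed.shuffle._shuffle_extension import get_worker_for, split_by_worker

    df = pd.DataFrame({"x": [1, 2, 3, 4, 5], "_partition": [0, 1, 2, 0, 1]})
    workers = ["alice", "bob"]
    npartitions = 3

    # Build the partition -> worker mapping once, then pass it to
    # split_by_worker as a categorical Series keyed by partition number.
    worker_for = pd.Series(
        {part: get_worker_for(part, workers, npartitions) for part in range(npartitions)},
        name="_workers",
    ).astype("category")

    # Per the tests, the result maps each worker address to a pyarrow.Table
    # holding that worker's rows; partitions absent from df produce no entry,
    # and a mapping that matches no rows yields an empty dict.
    out = split_by_worker(df, "_partition", worker_for)
    assert set(out) == {"alice", "bob"}
    assert sum(map(len, out.values())) == len(df)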