Make p2p shuffle submodules private (#7186)
fjetter authored Oct 26, 2022
1 parent 675cfed commit 5dccad4
Showing 11 changed files with 52 additions and 25 deletions.
distributed/shuffle/__init__.py (12 changes: 9 additions & 3 deletions)
@@ -1,8 +1,14 @@
 from __future__ import annotations
 
-from distributed.shuffle.shuffle import rearrange_by_column_p2p
-from distributed.shuffle.shuffle_extension import (
-    ShuffleId,
+from distributed.shuffle._shuffle import P2PShuffleLayer, rearrange_by_column_p2p
+from distributed.shuffle._shuffle_extension import (
     ShuffleSchedulerExtension,
     ShuffleWorkerExtension,
 )
+
+__all__ = [
+    "P2PShuffleLayer",
+    "rearrange_by_column_p2p",
+    "ShuffleSchedulerExtension",
+    "ShuffleWorkerExtension",
+]
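
The new __all__ makes the package namespace the public surface while the implementation moves into underscore-prefixed modules. A minimal sketch of the import style this implies for downstream code (module paths as defined by this commit's __init__.py):

# Public names re-exported by distributed/shuffle/__init__.py
from distributed.shuffle import (
    P2PShuffleLayer,
    ShuffleSchedulerExtension,
    ShuffleWorkerExtension,
    rearrange_by_column_p2p,
)

# The underscored modules remain importable, but are private by convention:
# from distributed.shuffle._shuffle_extension import ShuffleWorkerExtension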
Files renamed without changes: distributed/shuffle/arrow.py → _arrow.py, multi_comm.py → _multi_comm.py, multi_file.py → _multi_file.py
distributed/shuffle/shuffle.py → distributed/shuffle/_shuffle.py

@@ -6,15 +6,15 @@
 from dask.highlevelgraph import HighLevelGraph
 from dask.layers import SimpleShuffleLayer
 
-from distributed.shuffle.shuffle_extension import ShuffleId, ShuffleWorkerExtension
+from distributed.shuffle._shuffle_extension import ShuffleId, ShuffleWorkerExtension
 
 if TYPE_CHECKING:
     import pandas as pd
 
     from dask.dataframe import DataFrame
 
 
-def get_ext() -> ShuffleWorkerExtension:
+def _get_worker_extension() -> ShuffleWorkerExtension:
     from distributed import get_worker
 
     try:
@@ -25,7 +25,7 @@ def get_ext() -> ShuffleWorkerExtension:
             "please confirm that you've created a distributed Client and are submitting this computation through it."
         ) from e
     extension: ShuffleWorkerExtension | None = worker.extensions.get("shuffle")
-    if not extension:
+    if extension is None:
         raise RuntimeError(
             f"The worker {worker.address} does not have a ShuffleExtension. "
             "Is pandas installed on the worker?"
@@ -39,17 +39,19 @@ def shuffle_transfer(
     npartitions: int,
     column: str,
 ) -> None:
-    get_ext().add_partition(input, id, npartitions=npartitions, column=column)
+    _get_worker_extension().add_partition(
+        input, id, npartitions=npartitions, column=column
+    )
 
 
 def shuffle_unpack(
     id: ShuffleId, output_partition: int, barrier: object
 ) -> pd.DataFrame:
-    return get_ext().get_output_partition(id, output_partition)
+    return _get_worker_extension().get_output_partition(id, output_partition)
 
 
 def shuffle_barrier(id: ShuffleId, transfers: list[None]) -> None:
-    get_ext().barrier(id)
+    return _get_worker_extension().barrier(id)
 
 
 def rearrange_by_column_p2p(
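
Taken together, the three helpers above are the task-level entry points of the P2P shuffle: each input partition is handed to the worker extension (transfer), a single barrier waits for all transfers, and each output partition is then read back (unpack). The sketch below is a hedged illustration of how such tasks could be wired into a dask graph; the key names, partition counts, and ("input", i) placeholders are invented for the example, and this is not the actual P2PShuffleLayer implementation:

# Illustrative wiring of the helpers defined above into a task graph (hypothetical).
shuffle_id = "example-shuffle-id"        # stands in for a ShuffleId
n_in, n_out, column = 4, 4, "_partition"

dsk: dict = {}
for i in range(n_in):
    # one transfer task per input partition; ("input", i) is a placeholder key
    dsk[("transfer", i)] = (shuffle_transfer, ("input", i), shuffle_id, n_out, column)
# the barrier depends on every transfer finishing before any output is read
dsk[("barrier",)] = (shuffle_barrier, shuffle_id, [("transfer", i) for i in range(n_in)])
for j in range(n_out):
    # one unpack task per output partition, ordered after the barrier
    dsk[("unpack", j)] = (shuffle_unpack, shuffle_id, j, ("barrier",))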
distributed/shuffle/shuffle_extension.py → distributed/shuffle/_shuffle_extension.py

@@ -13,14 +13,14 @@
 import toolz
 
 from distributed.protocol import to_serialize
-from distributed.shuffle.arrow import (
+from distributed.shuffle._arrow import (
     deserialize_schema,
     dump_batch,
     list_of_buffers_to_table,
     load_arrow,
 )
-from distributed.shuffle.multi_comm import MultiComm
-from distributed.shuffle.multi_file import MultiFile
+from distributed.shuffle._multi_comm import MultiComm
+from distributed.shuffle._multi_file import MultiFile
 from distributed.utils import log_errors, sync
 
 if TYPE_CHECKING:
@@ -559,7 +559,9 @@ def split_by_worker(
     if not nrows:
         return {}
     # assert len(df) == nrows  # Not true if some outputs aren't wanted
-    t = pa.Table.from_pandas(df)
+    # FIXME: If we do not preserve the index something is corrupting the
+    # bytestream such that it cannot be deserialized anymore
+    t = pa.Table.from_pandas(df, preserve_index=True)
     t = t.sort_by("_worker")
     codes = np.asarray(t.select(["_worker"]))[0]
     t = t.drop(["_worker"])
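
For context on the FIXME: preserve_index=True materializes the pandas index as a column of the Arrow table, so it survives a serialize/deserialize round trip instead of being dropped. A minimal, self-contained round-trip sketch (independent of distributed's own serialization code paths):

import pandas as pd
import pyarrow as pa

df = pd.DataFrame({"x": [1, 2, 3], "_worker": ["a", "b", "a"]})

t = pa.Table.from_pandas(df, preserve_index=True)  # index stored as an extra column
roundtripped = t.to_pandas()                        # index restored from that column

assert roundtripped.index.equals(df.index)
assert list(roundtripped.columns) == list(df.columns)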
distributed/shuffle/tests/test_graph.py (2 changes: 1 addition & 1 deletion)
@@ -13,7 +13,7 @@
 from dask.blockwise import Blockwise
 from dask.utils_test import hlg_layer_topological
 
-from distributed.shuffle.shuffle_extension import ShuffleWorkerExtension
+from distributed.shuffle._shuffle_extension import ShuffleWorkerExtension
 from distributed.utils_test import gen_cluster
 
 
distributed/shuffle/tests/test_multi_comm.py (2 changes: 1 addition & 1 deletion)
@@ -5,7 +5,7 @@
 
 import pytest
 
-from distributed.shuffle.multi_comm import MultiComm
+from distributed.shuffle._multi_comm import MultiComm
 from distributed.utils_test import gen_test
 
 
distributed/shuffle/tests/test_multi_file.py (2 changes: 1 addition & 1 deletion)
@@ -5,7 +5,7 @@
 
 import pytest
 
-from distributed.shuffle.multi_file import MultiFile
+from distributed.shuffle._multi_file import MultiFile
 from distributed.utils_test import gen_test
 
 
distributed/shuffle/tests/test_shuffle.py (2 changes: 1 addition & 1 deletion)
@@ -17,7 +17,7 @@
 from dask.distributed import Worker
 from dask.utils import stringify
 
-from distributed.shuffle.shuffle_extension import (
+from distributed.shuffle._shuffle_extension import (
     dump_batch,
     list_of_buffers_to_table,
     load_arrow,
distributed/shuffle/tests/test_shuffle_extension.py (33 changes: 25 additions & 8 deletions)
@@ -5,7 +5,7 @@
 pd = pytest.importorskip("pandas")
 dd = pytest.importorskip("dask.dataframe")
 
-from distributed.shuffle.shuffle_extension import (
+from distributed.shuffle._shuffle_extension import (
     ShuffleWorkerExtension,
     get_worker_for,
     split_by_partition,
@@ -22,25 +22,39 @@ async def test_installation(s, a):
     assert a.handlers["shuffle_inputs_done"] == ext.shuffle_inputs_done
 
 
-@pytest.mark.skip
 def test_split_by_worker():
     df = pd.DataFrame(
         {
             "x": [1, 2, 3, 4, 5],
             "_partition": [0, 1, 2, 0, 1],
         }
     )
+
     workers = ["alice", "bob"]
+    worker_for_mapping = {}
     npartitions = 3
-
-    out = split_by_worker(df, "_partition", npartitions, workers)
+    for part in range(npartitions):
+        worker_for_mapping[part] = get_worker_for(part, workers, npartitions)
+    worker_for = pd.Series(worker_for_mapping, name="_workers").astype("category")
+    out = split_by_worker(df, "_partition", worker_for)
     assert set(out) == {"alice", "bob"}
-    assert out["alice"].column_names == list(df.columns)
+    assert list(out["alice"].to_pandas().columns) == list(df.columns)
 
     assert sum(map(len, out.values())) == len(df)
 
 
-@pytest.mark.skip
+def test_split_by_worker_empty():
+    df = pd.DataFrame(
+        {
+            "x": [1, 2, 3, 4, 5],
+            "_partition": [0, 1, 2, 0, 1],
+        }
+    )
+    worker_for = pd.Series({5: "chuck"}, name="_workers").astype("category")
+    out = split_by_worker(df, "_partition", worker_for)
+    assert out == {}
+
+
 def test_split_by_worker_many_workers():
     df = pd.DataFrame(
         {
@@ -50,8 +64,11 @@ def test_split_by_worker_many_workers():
     )
     workers = ["a", "b", "c", "d", "e", "f", "g", "h"]
     npartitions = 10
-
-    out = split_by_worker(df, "_partition", npartitions, workers)
+    worker_for_mapping = {}
+    for part in range(npartitions):
+        worker_for_mapping[part] = get_worker_for(part, workers, npartitions)
+    worker_for = pd.Series(worker_for_mapping, name="_workers").astype("category")
+    out = split_by_worker(df, "_partition", worker_for)
     assert get_worker_for(5, workers, npartitions) in out
     assert get_worker_for(0, workers, npartitions) in out
     assert get_worker_for(7, workers, npartitions) in out
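
Both rewritten tests build the same kind of argument before calling split_by_worker: a categorical pandas Series whose index is the output partition number and whose values are worker addresses, as returned by get_worker_for; split_by_worker then returns a dict keyed by worker address. A compact restatement of that setup, equivalent to the loops in the tests above (df, pd, get_worker_for, and split_by_worker as used there):

workers = ["alice", "bob"]
npartitions = 3
worker_for = pd.Series(
    {part: get_worker_for(part, workers, npartitions) for part in range(npartitions)},
    name="_workers",
).astype("category")
out = split_by_worker(df, "_partition", worker_for)  # dict keyed by worker address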
