Ensure data is not lost if deserialization is slow
fjetter committed Oct 26, 2022
1 parent d91556a commit ffb3d08
Showing 2 changed files with 172 additions and 25 deletions.
46 changes: 24 additions & 22 deletions distributed/shuffle/_shuffle_extension.py
@@ -214,31 +214,31 @@ async def _receive(self, data: list[pa.Buffer]) -> None:
        if self._exception:
            raise self._exception

-        self.total_recvd += sum(map(len, data))
-        # TODO: Is it actually a good idea to dispatch multiple times instead of
-        # only once?
-        # An ugly way of turning these batches back into an arrow table
-        with self.time("cpu"):
-            data = await self.offload(
-                list_of_buffers_to_table,
-                data,
-                self.schema,
-            )
+        try:
+            self.total_recvd += sum(map(len, data))
+            # TODO: Is it actually a good idea to dispatch multiple times instead of
+            # only once?
+            # An ugly way of turning these batches back into an arrow table
+            with self.time("cpu"):
+                data = await self.offload(
+                    list_of_buffers_to_table,
+                    data,
+                    self.schema,
+                )

-        groups = await self.offload(split_by_partition, data, self.column)
+            groups = await self.offload(split_by_partition, data, self.column)

-        assert len(data) == sum(map(len, groups.values()))
-        del data
+            assert len(data) == sum(map(len, groups.values()))
+            del data

-        with self.time("cpu"):
-            groups = await self.offload(
-                lambda: {
-                    k: [batch.serialize() for batch in v.to_batches()]
-                    for k, v in groups.items()
-                }
-            )
-        try:
-            await self.multi_file.put(groups)
+            with self.time("cpu"):
+                groups = await self.offload(
+                    lambda: {
+                        k: [batch.serialize() for batch in v.to_batches()]
+                        for k, v in groups.items()
+                    }
+                )
+            await self.multi_file.put(groups)
        except Exception as e:
            self._exception = e

@@ -290,6 +290,8 @@ def done(self) -> bool:

    async def flush_receive(self) -> None:
        await asyncio.gather(*self._tasks)
+        if self._exception:
+            raise self._exception
        await self.multi_file.flush()

    async def close(self) -> None:
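
The change above establishes a simple contract for the background receive path: any exception raised while offloading, splitting, or writing shards is stored in self._exception instead of dying inside the task, and it is re-raised on the next _receive call and in flush_receive, so a slow or failing deserialization can no longer silently drop data. Below is a minimal, self-contained sketch of that deferred-exception pattern; it is not distributed's actual API, and the Receiver class, its receive/flush methods, and the "bad" payload are illustrative names only.

from __future__ import annotations

import asyncio


class Receiver:
    """Toy consumer: payloads are processed by background tasks.

    Mirrors the pattern above: an exception raised while handling a payload
    is stored on the instance and re-raised on the next receive or on
    flush(), instead of disappearing inside the task.
    """

    def __init__(self) -> None:
        self._exception: Exception | None = None
        self._tasks: list[asyncio.Task] = []

    async def _process(self, payload: str) -> None:
        if self._exception:
            raise self._exception
        try:
            if payload == "bad":
                raise RuntimeError("Error during deserialization")
            await asyncio.sleep(0.01)  # stand-in for slow offloaded work
        except Exception as e:
            # Remember the failure instead of losing it in the task.
            self._exception = e

    def receive(self, payload: str) -> None:
        self._tasks.append(asyncio.create_task(self._process(payload)))

    async def flush(self) -> None:
        await asyncio.gather(*self._tasks)
        if self._exception:
            raise self._exception


async def main() -> None:
    r = Receiver()
    r.receive("ok")
    r.receive("bad")
    try:
        await r.flush()
    except RuntimeError as e:
        print(f"flush surfaced the error: {e}")


asyncio.run(main())

Running the sketch prints the stored error when flush() is awaited, which is the behaviour the new test_error_offload below exercises against the real Shuffle.
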
151 changes: 148 additions & 3 deletions distributed/shuffle/tests/test_shuffle.py
@@ -497,7 +497,7 @@ async def _(**kwargs):
    return _


-class TestShufflePool:
+class ShuffleTestPool:
    def __init__(self, *args, **kwargs):
        self.shuffles = {}
        super().__init__(*args, **kwargs)
@@ -513,7 +513,9 @@ async def fake_broadcast(self, msg):
            out[addr] = await getattr(s, op)()
        return out

-    def new_shuffle(self, name, worker_for_mapping, schema, directory, loop):
+    def new_shuffle(
+        self, name, worker_for_mapping, schema, directory, loop, Shuffle=Shuffle
+    ):
        s = Shuffle(
            column="_partition",
            worker_for=worker_for_mapping,
@@ -563,7 +565,7 @@ async def test_basic_lowlevel_shuffle(
    assert len(set(worker_for_mapping.values())) == min(n_workers, npartitions)
    schema = pa.Schema.from_pandas(dfs[0])

-    local_shuffle_pool = TestShufflePool()
+    local_shuffle_pool = ShuffleTestPool()
    shuffles = []
    for ix in range(n_workers):
        shuffles.append(
@@ -606,3 +608,146 @@ def _done():
    finally:
        await asyncio.gather(*[s.close() for s in shuffles])
    assert len(df_after) == len(pd.concat(dfs))


+@gen_test()
+async def test_slow_offload(tmpdir, loop_in_thread):
+    # Ensure that all data is received even if the offloaded deserialization is slow
+    dfs = []
+    rows_per_df = 10
+    n_input_partitions = 2
+    npartitions = 2
+    for ix in range(n_input_partitions):
+        df = pd.DataFrame({"x": range(rows_per_df * ix, rows_per_df * (ix + 1))})
+        df["_partition"] = df.x % npartitions
+        dfs.append(df)
+
+    workers = ["A", "B"]
+
+    worker_for_mapping = {}
+    partitions_for_worker = defaultdict(list)
+
+    for part in range(npartitions):
+        worker_for_mapping[part] = w = get_worker_for(part, workers, npartitions)
+        partitions_for_worker[w].append(part)
+    schema = pa.Schema.from_pandas(dfs[0])
+
+    local_shuffle_pool = ShuffleTestPool()
+
+    block_offload = asyncio.Event()
+
+    class SlowOffload(Shuffle):
+        async def offload(self, func, *args):
+            await block_offload.wait()
+            return await super().offload(func, *args)
+
+    sA = local_shuffle_pool.new_shuffle(
+        name="A",
+        worker_for_mapping=worker_for_mapping,
+        schema=schema,
+        directory=tmpdir,
+        loop=loop_in_thread,
+        Shuffle=SlowOffload,
+    )
+    sB = local_shuffle_pool.new_shuffle(
+        name="B",
+        worker_for_mapping=worker_for_mapping,
+        schema=schema,
+        directory=tmpdir,
+        loop=loop_in_thread,
+    )
+    try:
+        sA.add_partition(dfs[0])
+        sB.add_partition(dfs[1])
+
+        await sB.barrier()
+
+        assert len(partitions_for_worker["A"]) == 1
+
+        get_partition_A = asyncio.create_task(
+            sA.get_output_partition(partitions_for_worker["A"][0])
+        )
+        partition_available = asyncio.Event()
+        get_partition_A.add_done_callback(lambda _: partition_available.set())
+
+        with pytest.raises(asyncio.TimeoutError):
+            await asyncio.wait_for(partition_available.wait(), 0.2)
+
+        # Fetching from B is not a problem
+        assert len(partitions_for_worker["B"]) == 1
+        df1 = await sB.get_output_partition(partitions_for_worker["B"][0])
+
+        # After unblocking we should receive the data as usual w/out data loss
+        block_offload.set()
+        df2 = await get_partition_A
+
+        df_after = pd.concat([df1, df2])
+        assert len(df_after) == len(pd.concat(dfs))
+    finally:
+        await asyncio.gather(*[s.close() for s in [sA, sB]])


+@gen_test()
+async def test_error_offload(tmpdir, loop_in_thread):
+    # Ensure that an error during deserialization is raised instead of silently losing data
+    dfs = []
+    rows_per_df = 10
+    n_input_partitions = 2
+    npartitions = 2
+    for ix in range(n_input_partitions):
+        df = pd.DataFrame({"x": range(rows_per_df * ix, rows_per_df * (ix + 1))})
+        df["_partition"] = df.x % npartitions
+        dfs.append(df)
+
+    workers = ["A", "B"]
+
+    worker_for_mapping = {}
+    partitions_for_worker = defaultdict(list)
+
+    for part in range(npartitions):
+        worker_for_mapping[part] = w = get_worker_for(part, workers, npartitions)
+        partitions_for_worker[w].append(part)
+    schema = pa.Schema.from_pandas(dfs[0])
+
+    local_shuffle_pool = ShuffleTestPool()
+
+    block_offload = asyncio.Event()
+
+    class ErrorOffload(Shuffle):
+        async def offload(self, func, *args):
+            raise RuntimeError("Error during deserialization")
+
+    sA = local_shuffle_pool.new_shuffle(
+        name="A",
+        worker_for_mapping=worker_for_mapping,
+        schema=schema,
+        directory=tmpdir,
+        loop=loop_in_thread,
+        Shuffle=ErrorOffload,
+    )
+    sB = local_shuffle_pool.new_shuffle(
+        name="B",
+        worker_for_mapping=worker_for_mapping,
+        schema=schema,
+        directory=tmpdir,
+        loop=loop_in_thread,
+    )
+    try:
+        with pytest.raises(RuntimeError, match="Error during deserialization"):
+            sA.add_partition(dfs[0])
+            sB.add_partition(dfs[1])
+
+            await sB.barrier()
+
+            assert len(partitions_for_worker["A"]) == 1
+
+            # The error should be raised here. Functionally speaking, we're fine
+            # as long as it is raised before we collect the last shard.
+            await sA.get_output_partition(partitions_for_worker["A"][0])
+
+        # Fetching from B is not a problem
+        assert len(partitions_for_worker["B"]) == 1
+        await sB.get_output_partition(partitions_for_worker["B"][0])
+
+    finally:
+        await asyncio.gather(*[s.close() for s in [sA, sB]])
