From ffb3d08e7c523f46bc51f1a1a37ce0fabf8b0832 Mon Sep 17 00:00:00 2001
From: fjetter
Date: Wed, 26 Oct 2022 14:07:00 +0200
Subject: [PATCH] Ensure data is not lost if deserialization is slow

---
 distributed/shuffle/_shuffle_extension.py |  46 +++----
 distributed/shuffle/tests/test_shuffle.py | 151 +++++++++++++++++++++-
 2 files changed, 172 insertions(+), 25 deletions(-)

diff --git a/distributed/shuffle/_shuffle_extension.py b/distributed/shuffle/_shuffle_extension.py
index f29b131dff0..72763107b31 100644
--- a/distributed/shuffle/_shuffle_extension.py
+++ b/distributed/shuffle/_shuffle_extension.py
@@ -214,31 +214,31 @@ async def _receive(self, data: list[pa.Buffer]) -> None:
         if self._exception:
             raise self._exception
 
-        self.total_recvd += sum(map(len, data))
-        # TODO: Is it actually a good idea to dispatch multiple times instead of
-        # onyl once?
-        # An ugly way of turning these batches back into an arrow table
-        with self.time("cpu"):
-            data = await self.offload(
-                list_of_buffers_to_table,
-                data,
-                self.schema,
-            )
+        try:
+            self.total_recvd += sum(map(len, data))
+            # TODO: Is it actually a good idea to dispatch multiple times instead of
+            # only once?
+            # An ugly way of turning these batches back into an arrow table
+            with self.time("cpu"):
+                data = await self.offload(
+                    list_of_buffers_to_table,
+                    data,
+                    self.schema,
+                )
 
-        groups = await self.offload(split_by_partition, data, self.column)
+            groups = await self.offload(split_by_partition, data, self.column)
 
-        assert len(data) == sum(map(len, groups.values()))
-        del data
+            assert len(data) == sum(map(len, groups.values()))
+            del data
 
-        with self.time("cpu"):
-            groups = await self.offload(
-                lambda: {
-                    k: [batch.serialize() for batch in v.to_batches()]
-                    for k, v in groups.items()
-                }
-            )
-        try:
-            await self.multi_file.put(groups)
+            with self.time("cpu"):
+                groups = await self.offload(
+                    lambda: {
+                        k: [batch.serialize() for batch in v.to_batches()]
+                        for k, v in groups.items()
+                    }
+                )
+            await self.multi_file.put(groups)
         except Exception as e:
             self._exception = e
             raise
@@ -290,6 +290,8 @@ def done(self) -> bool:
 
     async def flush_receive(self) -> None:
         await asyncio.gather(*self._tasks)
+        if self._exception:
+            raise self._exception
         await self.multi_file.flush()
 
     async def close(self) -> None:
diff --git a/distributed/shuffle/tests/test_shuffle.py b/distributed/shuffle/tests/test_shuffle.py
index fcf328821c7..db94600b7e3 100644
--- a/distributed/shuffle/tests/test_shuffle.py
+++ b/distributed/shuffle/tests/test_shuffle.py
@@ -497,7 +497,7 @@ async def _(**kwargs):
     return _
 
 
-class TestShufflePool:
+class ShuffleTestPool:
     def __init__(self, *args, **kwargs):
         self.shuffles = {}
         super().__init__(*args, **kwargs)
@@ -513,7 +513,9 @@ async def fake_broadcast(self, msg):
             out[addr] = await getattr(s, op)()
         return out
 
-    def new_shuffle(self, name, worker_for_mapping, schema, directory, loop):
+    def new_shuffle(
+        self, name, worker_for_mapping, schema, directory, loop, Shuffle=Shuffle
+    ):
         s = Shuffle(
             column="_partition",
             worker_for=worker_for_mapping,
@@ -563,7 +565,7 @@ async def test_basic_lowlevel_shuffle(
     assert len(set(worker_for_mapping.values())) == min(n_workers, npartitions)
     schema = pa.Schema.from_pandas(dfs[0])
 
-    local_shuffle_pool = TestShufflePool()
+    local_shuffle_pool = ShuffleTestPool()
     shuffles = []
     for ix in range(n_workers):
         shuffles.append(
@@ -606,3 +608,146 @@ def _done():
     finally:
         await asyncio.gather(*[s.close() for s in shuffles])
     assert len(df_after) == len(pd.concat(dfs))
+
+
+@gen_test()
+async def test_slow_offload(tmpdir, loop_in_thread):
+    # Ensure that all data is received even if deserialization is slow.
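+    # The SlowOffload subclass defined below blocks the offload step until
+    # block_offload is set, simulating slow deserialization on worker A.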
+    dfs = []
+    rows_per_df = 10
+    n_input_partitions = 2
+    npartitions = 2
+    for ix in range(n_input_partitions):
+        df = pd.DataFrame({"x": range(rows_per_df * ix, rows_per_df * (ix + 1))})
+        df["_partition"] = df.x % npartitions
+        dfs.append(df)
+
+    workers = ["A", "B"]
+
+    worker_for_mapping = {}
+    partitions_for_worker = defaultdict(list)
+
+    for part in range(npartitions):
+        worker_for_mapping[part] = w = get_worker_for(part, workers, npartitions)
+        partitions_for_worker[w].append(part)
+    schema = pa.Schema.from_pandas(dfs[0])
+
+    local_shuffle_pool = ShuffleTestPool()
+
+    block_offload = asyncio.Event()
+
+    class SlowOffload(Shuffle):
+        async def offload(self, func, *args):
+            await block_offload.wait()
+            return await super().offload(func, *args)
+
+    sA = local_shuffle_pool.new_shuffle(
+        name="A",
+        worker_for_mapping=worker_for_mapping,
+        schema=schema,
+        directory=tmpdir,
+        loop=loop_in_thread,
+        Shuffle=SlowOffload,
+    )
+    sB = local_shuffle_pool.new_shuffle(
+        name="B",
+        worker_for_mapping=worker_for_mapping,
+        schema=schema,
+        directory=tmpdir,
+        loop=loop_in_thread,
+    )
+    try:
+        sA.add_partition(dfs[0])
+        sB.add_partition(dfs[1])
+
+        await sB.barrier()
+
+        assert len(partitions_for_worker["A"]) == 1
+
+        get_partition_A = asyncio.create_task(
+            sA.get_output_partition(partitions_for_worker["A"][0])
+        )
+        partition_available = asyncio.Event()
+        get_partition_A.add_done_callback(lambda _: partition_available.set())
+
+        with pytest.raises(asyncio.TimeoutError):
+            await asyncio.wait_for(partition_available.wait(), 0.2)
+
+        # Fetching from B is not a problem
+        assert len(partitions_for_worker["B"]) == 1
+        df1 = await sB.get_output_partition(partitions_for_worker["B"][0])
+
+        # After unblocking we should receive the data as usual without data loss
+        block_offload.set()
+        df2 = await get_partition_A
+
+        df_after = pd.concat([df1, df2])
+        assert len(df_after) == len(pd.concat(dfs))
+    finally:
+        await asyncio.gather(*[s.close() for s in [sA, sB]])
+
+
+@gen_test()
+async def test_error_offload(tmpdir, loop_in_thread):
+    # Ensure that an error raised in the offload step is propagated instead of
+    # silently losing the data.
+    dfs = []
+    rows_per_df = 10
+    n_input_partitions = 2
+    npartitions = 2
+    for ix in range(n_input_partitions):
+        df = pd.DataFrame({"x": range(rows_per_df * ix, rows_per_df * (ix + 1))})
+        df["_partition"] = df.x % npartitions
+        dfs.append(df)
+
+    workers = ["A", "B"]
+
+    worker_for_mapping = {}
+    partitions_for_worker = defaultdict(list)
+
+    for part in range(npartitions):
+        worker_for_mapping[part] = w = get_worker_for(part, workers, npartitions)
+        partitions_for_worker[w].append(part)
+    schema = pa.Schema.from_pandas(dfs[0])
+
+    local_shuffle_pool = ShuffleTestPool()
+
+    block_offload = asyncio.Event()
+
+    class ErrorOffload(Shuffle):
+        async def offload(self, func, *args):
+            raise RuntimeError("Error during deserialization")
+
+    sA = local_shuffle_pool.new_shuffle(
+        name="A",
+        worker_for_mapping=worker_for_mapping,
+        schema=schema,
+        directory=tmpdir,
+        loop=loop_in_thread,
+        Shuffle=ErrorOffload,
+    )
+    sB = local_shuffle_pool.new_shuffle(
+        name="B",
+        worker_for_mapping=worker_for_mapping,
+        schema=schema,
+        directory=tmpdir,
+        loop=loop_in_thread,
+    )
+    try:
+        with pytest.raises(RuntimeError, match="Error during deserialization"):
+            sA.add_partition(dfs[0])
+            sB.add_partition(dfs[1])
+
+            await sB.barrier()
+
+            assert len(partitions_for_worker["A"]) == 1
+
+            # The error should be raised here. Functionally speaking, we're fine
+            # as long as it is raised before we collect the last shard.
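+            # The exception captured in _receive() is stored on self._exception
+            # and re-raised by flush_receive(), so a failed offload is not
+            # silently swallowed.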
+            await sA.get_output_partition(partitions_for_worker["A"][0])
+
+            # Fetching from B is not a problem
+            assert len(partitions_for_worker["B"]) == 1
+            await sB.get_output_partition(partitions_for_worker["B"][0])
+
+    finally:
+        await asyncio.gather(*[s.close() for s in [sA, sB]])