From 9b7bfbc1ac283b3fd2b7ce907d64df5785f1e312 Mon Sep 17 00:00:00 2001 From: erjia Date: Mon, 27 Feb 2023 18:32:22 +0000 Subject: [PATCH 1/2] Create new eventloop to run coroutines --- torchdata/datapipes/iter/transform/callable.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/torchdata/datapipes/iter/transform/callable.py b/torchdata/datapipes/iter/transform/callable.py index 815f8a8db..643e36190 100644 --- a/torchdata/datapipes/iter/transform/callable.py +++ b/torchdata/datapipes/iter/transform/callable.py @@ -447,9 +447,16 @@ def __init__( self.max_concurrency = max_concurrency def __iter__(self): - for batch in self.source_datapipe: - new_batch = asyncio.run(self.processbatch(batch)) - yield new_batch + policy = asyncio.get_event_loop_policy() + loop = policy.new_event_loop() + try: + for batch in self.source_datapipe: + policy.set_event_loop(loop) + new_batch = loop.run_until_complete(self.processbatch(batch)) + yield new_batch + finally: + loop.run_until_complete(loop.shutdown_asyncgens()) + loop.close() async def processbatch(self, batch): sem = asyncio.Semaphore(self.max_concurrency) From 3260d3194420420df8e5a2136e01f349dc788971 Mon Sep 17 00:00:00 2001 From: erjia Date: Mon, 27 Feb 2023 20:52:05 +0000 Subject: [PATCH 2/2] Add unit tests for multiple coroutines --- test/test_iterdatapipe.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/test/test_iterdatapipe.py b/test/test_iterdatapipe.py index 8ad84298d..2f105615d 100644 --- a/test/test_iterdatapipe.py +++ b/test/test_iterdatapipe.py @@ -1601,6 +1601,15 @@ def _helper(input_data, exp_res, async_fn, input_col=None, output_col=None, max_ output_col=1, ) + # Test multiple asyncio eventloops + dp1 = IterableWrapper(range(50)) + dp1 = dp1.async_map_batches(_async_mul_ten, 16) + dp2 = IterableWrapper(range(50)) + dp2 = dp2.async_map_batches(_async_mul_ten, 16) + for v1, v2, exp in zip(dp1, dp2, [i * 10 for i in range(50)]): + self.assertEqual(v1, exp) + self.assertEqual(v2, exp) + if __name__ == "__main__": unittest.main()