Skip to content

Commit

Permalink
Create new eventloop to run coroutines (#1055)
Browse files Browse the repository at this point in the history
Summary:
Prevent calling `asyncio.run` within `__iter__` function. New event loop would be created rather than reusing the global event loop.
Using policy for the case of custom policy is used as suggested in https://docs.python.org/3.8/library/asyncio-eventloop.html#event-loop

Address the [comment](https://docs.python.org/3/library/asyncio-runner.html#:~:text=It%20should%20be%20used%20as,ideally%20only%20be%20called%20once.&text=New%20in%20version%203.7.)

Pull Request resolved: #1055

Reviewed By: wenleix, dracifer

Differential Revision: D43630185

Pulled By: ejguan

fbshipit-source-id: 2dbaffb379971433e29dfcbaa936064c4ce0a44b
  • Loading branch information
ejguan authored and facebook-github-bot committed Feb 28, 2023
1 parent d7835c0 commit f30281b
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
9 changes: 9 additions & 0 deletions test/test_iterdatapipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -1606,6 +1606,15 @@ def _helper(input_data, exp_res, async_fn, input_col=None, output_col=None, max_
output_col=1,
)

# Test multiple asyncio eventloops
dp1 = IterableWrapper(range(50))
dp1 = dp1.async_map_batches(_async_mul_ten, 16)
dp2 = IterableWrapper(range(50))
dp2 = dp2.async_map_batches(_async_mul_ten, 16)
for v1, v2, exp in zip(dp1, dp2, [i * 10 for i in range(50)]):
self.assertEqual(v1, exp)
self.assertEqual(v2, exp)


if __name__ == "__main__":
unittest.main()
13 changes: 10 additions & 3 deletions torchdata/datapipes/iter/transform/callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,9 +447,16 @@ def __init__(
self.max_concurrency = max_concurrency

def __iter__(self):
for batch in self.source_datapipe:
new_batch = asyncio.run(self.processbatch(batch))
yield new_batch
policy = asyncio.get_event_loop_policy()
loop = policy.new_event_loop()
try:
for batch in self.source_datapipe:
policy.set_event_loop(loop)
new_batch = loop.run_until_complete(self.processbatch(batch))
yield new_batch
finally:
loop.run_until_complete(loop.shutdown_asyncgens())
loop.close()

async def processbatch(self, batch):
sem = asyncio.Semaphore(self.max_concurrency)
Expand Down

0 comments on commit f30281b

Please sign in to comment.