Skip to content

Commit

Permalink
Make async executors work in Jupyter notebooks (#661)
Browse files Browse the repository at this point in the history
  • Loading branch information
tomwhite authored Jan 11, 2025
1 parent 4fce79a commit 848a5e7
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 4 deletions.
3 changes: 2 additions & 1 deletion cubed/runtime/executors/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from cubed.runtime.pipeline import visit_node_generations, visit_nodes
from cubed.runtime.types import Callback, CubedPipeline, DagExecutor
from cubed.runtime.utils import (
asyncio_run,
execution_stats,
gensym,
handle_callbacks,
Expand Down Expand Up @@ -170,7 +171,7 @@ def execute_dag(
**kwargs,
) -> None:
merged_kwargs = {**self.kwargs, **kwargs}
asyncio.run(
asyncio_run(
async_execute_dag(
dag,
callbacks=callbacks,
Expand Down
5 changes: 3 additions & 2 deletions cubed/runtime/executors/local.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from cubed.runtime.pipeline import visit_node_generations, visit_nodes
from cubed.runtime.types import Callback, CubedPipeline, DagExecutor, TaskEndEvent
from cubed.runtime.utils import (
asyncio_run,
execution_stats,
execution_timing,
handle_callbacks,
Expand Down Expand Up @@ -271,7 +272,7 @@ def execute_dag(
**kwargs,
) -> None:
merged_kwargs = {**self.kwargs, **kwargs}
asyncio.run(
asyncio_run(
async_execute_dag(
dag,
callbacks=callbacks,
Expand Down Expand Up @@ -310,7 +311,7 @@ def execute_dag(
**kwargs,
) -> None:
merged_kwargs = {**self.kwargs, **kwargs}
asyncio.run(
asyncio_run(
async_execute_dag(
dag,
callbacks=callbacks,
Expand Down
3 changes: 2 additions & 1 deletion cubed/runtime/executors/modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from cubed.runtime.pipeline import visit_node_generations, visit_nodes
from cubed.runtime.types import Callback, DagExecutor
from cubed.runtime.utils import (
asyncio_run,
execute_with_stats,
handle_callbacks,
handle_operation_start_callbacks,
Expand Down Expand Up @@ -261,7 +262,7 @@ def execute_dag(
**kwargs,
) -> None:
merged_kwargs = {**self.kwargs, **kwargs}
asyncio.run(
asyncio_run(
async_execute_dag(
dag,
callbacks=callbacks,
Expand Down
15 changes: 15 additions & 0 deletions cubed/runtime/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import asyncio
import time
from concurrent.futures import ThreadPoolExecutor
from contextlib import nullcontext
from functools import partial
from itertools import islice
Expand Down Expand Up @@ -115,6 +117,19 @@ def handle_callbacks(callbacks, result, stats):
[callback.on_task_end(event) for callback in callbacks]


# Like asyncio.run(), but works in a Jupyter notebook
# Based on https://stackoverflow.com/a/75341431
def asyncio_run(coro):
try:
asyncio.get_running_loop() # Triggers RuntimeError if no running event loop
except RuntimeError:
return asyncio.run(coro)
else:
# Create a separate thread so we can block before returning
with ThreadPoolExecutor(1) as pool:
return pool.submit(lambda: asyncio.run(coro)).result()


# this will be in Python 3.12 https://docs.python.org/3.12/library/itertools.html#itertools.batched
def batched(iterable, n):
# batched('ABCDEFG', 3) --> ABC DEF G
Expand Down

0 comments on commit 848a5e7

Please sign in to comment.