Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use asyncio to parallelize notebook testing #1201

Merged
merged 5 commits into from
Apr 22, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 23 additions & 26 deletions scripts/nb-tester/test-notebook.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# been altered from the originals.

import argparse
import asyncio
import sys
import textwrap
from dataclasses import dataclass
Expand Down Expand Up @@ -81,26 +82,19 @@ class NotebookWarning:
cell_index: int
msg: str

def report(self):
def format(self) -> str:
"""
Format warning and print it
Format warning to pretty string
"""
message = f"Warning detected in cell {self.cell_index}:\n"
message = f"Cell {self.cell_index}:\n"
for line in self.msg.splitlines():
message += (
textwrap.fill(
line, width=77, initial_indent=" │ ", subsequent_indent=" │ "
)
+ "\n"
)
print_yellow(message, flush=True)


def print_yellow(s: str, **kwargs):
"""
Use ANSI escape codes to print yellow text
"""
print(f"\033[0;33m{s}\033[0m", **kwargs)
return message


def extract_warnings(notebook: nbformat.NotebookNode) -> list[NotebookWarning]:
Expand All @@ -121,33 +115,34 @@ def extract_warnings(notebook: nbformat.NotebookNode) -> list[NotebookWarning]:
return notebook_warnings


def execute_notebook(path: Path, args: argparse.Namespace) -> bool:
async def execute_notebook(path: Path, args: argparse.Namespace) -> bool:
"""
Wrapper function for `_execute_notebook` to print status
"""
print(f"▶️ {path}", end="", flush=True)
print(f"▶️ Executing {path}")
possible_exceptions = (
nbconvert.preprocessors.CellExecutionError,
nbclient.exceptions.CellTimeoutError,
)
try:
nb = _execute_notebook(path, args)
nb = await _execute_notebook(path, args)
except possible_exceptions as err:
print("\r❌\n")
print(err)
print(f"❌ Problem in {path}:\n{err}")
return False

notebook_warnings = extract_warnings(nb)
if notebook_warnings:
print("\r⚠️")
[w.report() for w in notebook_warnings]
print(
f"⚠️ Warnings in {path}:\n"
+ "\n".join((w.format() for w in notebook_warnings))
)
return False

print("\r✅")
print(f"✅ No problems in {path}")
return True


def _execute_notebook(filepath: Path, args: argparse.Namespace) -> nbformat.NotebookNode:
async def _execute_notebook(filepath: Path, args: argparse.Namespace) -> nbformat.NotebookNode:
"""
Use nbconvert to execute notebook.
"""
Expand All @@ -156,12 +151,14 @@ def _execute_notebook(filepath: Path, args: argparse.Namespace) -> nbformat.Note

processor = nbconvert.preprocessors.ExecutePreprocessor(
# If submitting jobs, we want to wait forever (-1 means no timeout)
timeout=-1 if submit_jobs else 100,
timeout=-1 if submit_jobs else 300,
kernel_name="python3",
extra_arguments=["--InlineBackend.figure_format='svg'"]
)

processor.preprocess(nb)
# This runs the notebook, including possibly submitting jobs. We run it in a
# new thread to avoid blocking other notebooks from submitting jobs.
await asyncio.to_thread(processor.preprocess, nb)

if not args.write:
return nb
Expand Down Expand Up @@ -191,7 +188,7 @@ def cancel_trailing_jobs(start_time: datetime) -> bool:
"""
Cancel any runtime jobs created after `start_time`.

Return True if non exist, False otherwise.
Return True if none exist, False otherwise.

Notebooks should not submit jobs during a normal test run. If they do, the
cell will time out and this function will cancel the job to avoid wasting
Expand Down Expand Up @@ -255,15 +252,15 @@ def create_argument_parser() -> argparse.ArgumentParser:
return parser


def main() -> None:
async def main() -> None:
args = create_argument_parser().parse_args()
paths = map(Path, args.filenames or find_notebooks())
filtered_paths = filter_paths(paths, args)

# Execute notebooks
start_time = datetime.now()
print("Executing notebooks:")
results = [execute_notebook(path, args) for path in filtered_paths]
results = await asyncio.gather(*(execute_notebook(path, args) for path in filtered_paths))
print("Checking for trailing jobs...")
results.append(cancel_trailing_jobs(start_time))
if not all(results):
Expand All @@ -272,4 +269,4 @@ def main() -> None:


if __name__ == "__main__":
main()
asyncio.run(main())
Loading