Skip to content

Commit

Permalink
Use asyncio to parallelize notebook testing (Qiskit#1201)
Browse files Browse the repository at this point in the history
An alternative to Qiskit#1181.
This should help when executing many notebooks that submit jobs as
they'll all submit their jobs immediately and begin queueing in
parallel.

---------

Co-authored-by: Frank Harkins <frankharkins@hotmail.co.uk>
  • Loading branch information
Eric-Arellano and frankharkins authored Apr 22, 2024
1 parent 2b0ae5f commit aca0d17
Showing 1 changed file with 23 additions and 26 deletions.
49 changes: 23 additions & 26 deletions scripts/nb-tester/test-notebook.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# been altered from the originals.

import argparse
import asyncio
import sys
import textwrap
from dataclasses import dataclass
Expand Down Expand Up @@ -81,26 +82,19 @@ class NotebookWarning:
cell_index: int
msg: str

def report(self):
def format(self) -> str:
"""
Format warning and print it
Format warning to pretty string
"""
message = f"Warning detected in cell {self.cell_index}:\n"
message = f"Cell {self.cell_index}:\n"
for line in self.msg.splitlines():
message += (
textwrap.fill(
line, width=77, initial_indent=" │ ", subsequent_indent=" │ "
)
+ "\n"
)
print_yellow(message, flush=True)


def print_yellow(s: str, **kwargs):
"""
Use ANSI escape codes to print yellow text
"""
print(f"\033[0;33m{s}\033[0m", **kwargs)
return message


def extract_warnings(notebook: nbformat.NotebookNode) -> list[NotebookWarning]:
Expand All @@ -121,33 +115,34 @@ def extract_warnings(notebook: nbformat.NotebookNode) -> list[NotebookWarning]:
return notebook_warnings


def execute_notebook(path: Path, args: argparse.Namespace) -> bool:
async def execute_notebook(path: Path, args: argparse.Namespace) -> bool:
"""
Wrapper function for `_execute_notebook` to print status
"""
print(f"▶️ {path}", end="", flush=True)
print(f"▶️ Executing {path}")
possible_exceptions = (
nbconvert.preprocessors.CellExecutionError,
nbclient.exceptions.CellTimeoutError,
)
try:
nb = _execute_notebook(path, args)
nb = await _execute_notebook(path, args)
except possible_exceptions as err:
print("\r\n")
print(err)
print(f"❌ Problem in {path}:\n{err}")
return False

notebook_warnings = extract_warnings(nb)
if notebook_warnings:
print("\r⚠️")
[w.report() for w in notebook_warnings]
print(
f"⚠️ Warnings in {path}:\n"
+ "\n".join((w.format() for w in notebook_warnings))
)
return False

print("\r")
print(f"✅ No problems in {path}")
return True


def _execute_notebook(filepath: Path, args: argparse.Namespace) -> nbformat.NotebookNode:
async def _execute_notebook(filepath: Path, args: argparse.Namespace) -> nbformat.NotebookNode:
"""
Use nbconvert to execute notebook.
"""
Expand All @@ -156,12 +151,14 @@ def _execute_notebook(filepath: Path, args: argparse.Namespace) -> nbformat.Note

processor = nbconvert.preprocessors.ExecutePreprocessor(
# If submitting jobs, we want to wait forever (-1 means no timeout)
timeout=-1 if submit_jobs else 100,
timeout=-1 if submit_jobs else 300,
kernel_name="python3",
extra_arguments=["--InlineBackend.figure_format='svg'"]
)

processor.preprocess(nb)
# This runs the notebook, including possibly submitting jobs. We run it in a
# new thread to avoid blocking other notebooks from submitting jobs.
await asyncio.to_thread(processor.preprocess, nb)

if not args.write:
return nb
Expand Down Expand Up @@ -191,7 +188,7 @@ def cancel_trailing_jobs(start_time: datetime) -> bool:
"""
Cancel any runtime jobs created after `start_time`.
Return True if non exist, False otherwise.
Return True if none exist, False otherwise.
Notebooks should not submit jobs during a normal test run. If they do, the
cell will time out and this function will cancel the job to avoid wasting
Expand Down Expand Up @@ -255,15 +252,15 @@ def create_argument_parser() -> argparse.ArgumentParser:
return parser


def main() -> None:
async def main() -> None:
args = create_argument_parser().parse_args()
paths = map(Path, args.filenames or find_notebooks())
filtered_paths = filter_paths(paths, args)

# Execute notebooks
start_time = datetime.now()
print("Executing notebooks:")
results = [execute_notebook(path, args) for path in filtered_paths]
results = await asyncio.gather(*(execute_notebook(path, args) for path in filtered_paths))
print("Checking for trailing jobs...")
results.append(cancel_trailing_jobs(start_time))
if not all(results):
Expand All @@ -272,4 +269,4 @@ def main() -> None:


if __name__ == "__main__":
main()
asyncio.run(main())

0 comments on commit aca0d17

Please sign in to comment.