Skip to content

Commit 8f36b0d

Browse files
committed
feat: added improved progress reporting with Rich
1 parent 895a806 commit 8f36b0d

File tree

4 files changed

+91
-26
lines changed

4 files changed

+91
-26
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ authors = [
99
requires-python = ">=3.10"
1010
dependencies = [
1111
"papermill>=2.6.0",
12+
"rich>=14.0.0",
1213
"typer>=0.16.0",
1314
]
1415

src/millrun/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def run(
7878
recursive,
7979
exclude_glob_pattern,
8080
include_glob_pattern,
81-
multiprocessing=True
81+
use_multiprocessing=True
8282
# **kwargs
8383
)
8484

src/millrun/millrun.py

Lines changed: 87 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33
from typing import Optional, Any
44
import papermill as pm
55
import functools as ft
6+
import multiprocessing
67
from concurrent.futures import ProcessPoolExecutor
8+
from rich.progress import Progress, BarColumn, TimeRemainingColumn, TimeElapsedColumn
9+
710

811

912
def execute_batch(
@@ -15,7 +18,7 @@ def execute_batch(
1518
recursive: bool = False,
1619
exclude_glob_pattern: Optional[str] = None,
1720
include_glob_pattern: Optional[str] = None,
18-
multiprocessing: bool = False,
21+
use_multiprocessing: bool = False,
1922
**kwargs,
2023
) -> list[pathlib.Path] | None:
2124
"""
@@ -95,7 +98,7 @@ def execute_batch(
9598
output_prepend_components,
9699
output_append_components,
97100
output_dir,
98-
multiprocessing
101+
use_multiprocessing
99102
)
100103
else:
101104
glob_method = notebook_dir.glob
@@ -111,7 +114,7 @@ def execute_batch(
111114
glob_pattern = "*.ipynb"
112115
included_paths = set(glob_method(glob_pattern))
113116

114-
notebook_paths = included_paths - excluded_paths
117+
notebook_paths = sorted(included_paths - excluded_paths)
115118

116119
for notebook_path in notebook_paths:
117120
execute_notebooks(
@@ -120,8 +123,13 @@ def execute_batch(
120123
output_prepend_components,
121124
output_append_components,
122125
output_dir,
123-
multiprocessing
126+
use_multiprocessing,
124127
)
128+
# Multiprocessing approach inspired by
129+
# https://www.deanmontgomery.com/2022/03/24/rich-progress-and-multiprocessing/
130+
131+
132+
125133

126134
def check_unequal_value_lengths(bulk_params: dict[str, list]) -> bool | dict:
127135
"""
@@ -175,43 +183,94 @@ def get_output_name(
175183
notebook_filename = pathlib.Path(notebook_filename)
176184
return "-".join([elem for elem in [prepend_str, notebook_filename.stem, append_str] if elem]) + notebook_filename.suffix
177185

186+
# notebook_path,
187+
# bulk_params_list,
188+
# output_prepend_components,
189+
# output_append_components,
190+
# output_dir,
191+
# _progress,
192+
# task_id
193+
178194

179195
def execute_notebooks(
180-
notebook_filename: pathlib.Path,
196+
notebook_path: pathlib.Path,
181197
bulk_params_list: dict[str, Any],
182198
output_prepend_components: list[str],
183199
output_append_components: list[str],
184200
output_dir: pathlib.Path,
185-
multiprocessing: bool = False,
186-
**kwargs,
201+
use_multiprocessing: bool,
202+
**kwargs
187203
):
188-
mp_execute_notebook = ft.partial(
189-
execute_notebook,
190-
notebook_filename=notebook_filename,
191-
output_prepend_components=output_prepend_components,
192-
output_append_components=output_append_components,
193-
output_dir=output_dir,
194-
)
195-
# print(mp_execute_notebook(notebook_params=bulk_params_list))
196-
if multiprocessing:
197-
with ProcessPoolExecutor() as executor:
198-
for result in executor.map(mp_execute_notebook, bulk_params_list):
199-
pass
204+
total_variations = len(bulk_params_list)
205+
if not use_multiprocessing:
206+
with Progress(
207+
"[progress.description]{task.description}",
208+
BarColumn(),
209+
"[progress.percentage]{task.percentage:>3.0f}%",
210+
TimeElapsedColumn(),
211+
refresh_per_second=1, # bit slower updates
212+
) as progress:
213+
task_id = progress.add_task(notebook_path.name, total=total_variations)
214+
for idx, notebook_params in (list(enumerate(bulk_params_list))):
215+
execute_notebook(
216+
notebook_filename=notebook_path,
217+
notebook_params=notebook_params,
218+
output_prepend_components=output_prepend_components,
219+
output_append_components=output_append_components,
220+
output_dir=output_dir,
221+
**kwargs,
222+
)
223+
progress.update(task_id, completed=idx + 2)
200224
else:
201-
for result in map(mp_execute_notebook, bulk_params_list):
202-
pass
225+
with Progress(
226+
"[progress.description]{task.description}",
227+
BarColumn(),
228+
"[progress.percentage]{task.percentage:>3.0f}%",
229+
TimeElapsedColumn(),
230+
refresh_per_second=1, # bit slower updates
231+
) as progress:
232+
futures = [] # keep track of the jobs
233+
with multiprocessing.Manager() as manager:
234+
_progress = manager.dict()
235+
overall_progress_task = progress.add_task(f"{notebook_path.name}", visible=True, total=total_variations)
236+
with ProcessPoolExecutor() as executor:
237+
for idx, notebook_params in enumerate(bulk_params_list):
238+
futures.append(
239+
executor.submit(
240+
execute_notebook,
241+
notebook_path,
242+
notebook_params,
243+
output_prepend_components,
244+
output_append_components,
245+
output_dir,
246+
total_variations,
247+
idx,
248+
_progress,
249+
overall_progress_task
250+
)
251+
)
203252

253+
# monitor the progress:
254+
while (n_finished := sum([future.done() for future in futures])) < len(
255+
futures
256+
):
257+
progress.update(
258+
overall_progress_task, completed=n_finished + 1
259+
)
204260

205261

206262
def execute_notebook(
207-
notebook_params: dict,
208263
notebook_filename: pathlib.Path,
264+
notebook_params: dict,
209265
output_prepend_components: list[str],
210266
output_append_components: list[str],
211267
output_dir: pathlib.Path,
268+
total_variations: Optional[int] = None,
269+
current_iteration: Optional[int] = None,
270+
_progress: Optional[dict] = None,
271+
_task_id: Optional[str] = None,
212272
**kwargs,
213273
):
214-
print(notebook_filename)
215274
output_name = get_output_name(
216275
notebook_filename,
217276
output_prepend_components,
@@ -222,6 +281,9 @@ def execute_notebook(
222281
notebook_filename,
223282
output_path=output_dir / output_name,
224283
parameters=notebook_params,
225-
progress_bar=True,
284+
progress_bar=False,
285+
cwd=str(notebook_filename.parent),
226286
**kwargs
227-
)
287+
)
288+
if _progress is not None:
289+
_progress[_task_id] = {"progress": current_iteration / total_variations, "total": total_variations}

uv.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)