Skip to content

Commit 526994d

Browse files
committed
add a progress_bar for loading a BalancingLearner
1 parent d79512f commit 526994d

File tree

3 files changed

+19
-3
lines changed

3 files changed

+19
-3
lines changed

adaptive/learner/balancing_learner.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ def save(self, fname, compress=True):
377377
for l in self.learners:
378378
l.save(fname(l), compress=compress)
379379

380-
def load(self, fname, compress=True):
380+
def load(self, fname, compress=True, with_progress_bar=False):
381381
"""Load the data of the child learners from pickle files
382382
in a directory.
383383
@@ -389,16 +389,30 @@ def load(self, fname, compress=True):
389389
compress : bool, default True
390390
If the data is compressed when saved, one must load it
391391
with compression too.
392+
with_progress_bar : bool, default False
393+
Display a progress bar using `tqdm`.
392394
393395
Example
394396
-------
395397
See the example in the `BalancingLearner.save` doc-string.
396398
"""
399+
def progress(seq):
400+
if not with_progress_bar:
401+
return seq
402+
else:
403+
from adaptive.notebook_integration import in_ipynb
404+
if in_ipynb():
405+
from tqdm import tqdm_notebook
406+
return tqdm_notebook(list(seq))
407+
else:
408+
from tqdm import tqdm
409+
return tqdm(list(seq))
410+
397411
if isinstance(fname, Iterable):
398-
for l, _fname in zip(self.learners, fname):
412+
for l, _fname in progress(zip(self.learners, fname)):
399413
l.load(_fname, compress=compress)
400414
else:
401-
for l in self.learners:
415+
for l in progress(self.learners):
402416
l.load(fname(l), compress=compress)
403417

404418
def _get_data(self):

environment.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@ dependencies:
1616
- ipywidgets
1717
- scikit-optimize
1818
- plotly
19+
- tqdm

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ def get_version_and_cmdclass(package_name):
3636
"bokeh",
3737
"matplotlib",
3838
"plotly",
39+
"tqdm",
3940
]
4041
}
4142

0 commit comments

Comments
 (0)