From 9c956e7fd578e25404d4d6d31402c445c73e49f0 Mon Sep 17 00:00:00 2001 From: Holger Kohr Date: Mon, 4 Jun 2018 00:11:19 +0200 Subject: [PATCH] Fix broken progress bar - Fix broken update calculation - Make progress bar use the neat `unit_scale` feature of tqdm --- torchvision/datasets/utils.py | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py index 4675013ac33..fef84cbeea7 100644 --- a/torchvision/datasets/utils.py +++ b/torchvision/datasets/utils.py @@ -5,10 +5,12 @@ from tqdm import tqdm -def gen_bar_updator(pbar): +def gen_bar_updater(pbar): def bar_update(count, block_size, total_size): - pbar.total = total_size / block_size - pbar.update(count) + if pbar.total is None and total_size: + pbar.total = total_size + progress_bytes = count * block_size + pbar.update(progress_bytes - pbar.n) return bar_update @@ -47,13 +49,19 @@ def download_url(url, root, filename, md5): else: try: print('Downloading ' + url + ' to ' + fpath) - urllib.request.urlretrieve(url, fpath, reporthook=gen_bar_updator(tqdm())) + urllib.request.urlretrieve( + url, fpath, + reporthook=gen_bar_updater(tqdm(unit='B', unit_scale=True)) + ) except: if url[:5] == 'https': url = url.replace('https:', 'http:') print('Failed download. Trying https -> http instead.' ' Downloading ' + url + ' to ' + fpath) - urllib.request.urlretrieve(url, fpath) + urllib.request.urlretrieve( + url, fpath, + reporthook=gen_bar_updater(tqdm(unit='B', unit_scale=True)) + ) def list_dir(root, prefix=False):