diff --git a/torchvision/datasets/utils.py b/torchvision/datasets/utils.py index 4675013ac33..fef84cbeea7 100644 --- a/torchvision/datasets/utils.py +++ b/torchvision/datasets/utils.py @@ -5,10 +5,12 @@ from tqdm import tqdm -def gen_bar_updator(pbar): +def gen_bar_updater(pbar): def bar_update(count, block_size, total_size): - pbar.total = total_size / block_size - pbar.update(count) + if pbar.total is None and total_size: + pbar.total = total_size + progress_bytes = count * block_size + pbar.update(progress_bytes - pbar.n) return bar_update @@ -47,13 +49,19 @@ def download_url(url, root, filename, md5): else: try: print('Downloading ' + url + ' to ' + fpath) - urllib.request.urlretrieve(url, fpath, reporthook=gen_bar_updator(tqdm())) + urllib.request.urlretrieve( + url, fpath, + reporthook=gen_bar_updater(tqdm(unit='B', unit_scale=True)) + ) except: if url[:5] == 'https': url = url.replace('https:', 'http:') print('Failed download. Trying https -> http instead.' ' Downloading ' + url + ' to ' + fpath) - urllib.request.urlretrieve(url, fpath) + urllib.request.urlretrieve( + url, fpath, + reporthook=gen_bar_updater(tqdm(unit='B', unit_scale=True)) + ) def list_dir(root, prefix=False):