Skip to content

Commit

Permalink
Merge pull request #90 from Visual-Behavior/75-bar-progress-samples-d…
Browse files Browse the repository at this point in the history
…ownload

sample download progress bar and skip user prompt
  • Loading branch information
thibo73800 authored Oct 1, 2021
2 parents c925712 + e92f8a6 commit c3b7032
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 9 deletions.
30 changes: 21 additions & 9 deletions alodataset/base_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import shutil
import torch
import json
from tqdm import tqdm
from typing import List, Callable, Dict
from enum import Enum

Expand Down Expand Up @@ -248,13 +249,14 @@ def set_dataset_dir(self, dataset_dir: str):
content = json.loads(f.read())

if dataset_dir is None:
dataset_dir = _user_prompt(
f"{self.name} does not exist in config file. "
+ "Do you want to download and use a sample?: (Y)es or (N)o: "
)
if dataset_dir.lower() in ["y", "yes"]: # Download sample and change root directory
self.sample = True
return os.path.join(self.vb_folder, "samples")
if self.name in DATASETS_DOWNLOAD_PATHS:
dataset_dir = _user_prompt(
f"{self.name} does not exist in config file. "
+ "Do you want to download and use a sample?: (Y)es or (N)o: "
)
if dataset_dir.lower() in ["y", "yes"]: # Download sample and change root directory
self.sample = True
return os.path.join(self.vb_folder, "samples")
dataset_dir = _user_prompt(f"Please write a new root directory for {self.name} dataset: ")
dataset_dir = os.path.expanduser(dataset_dir)

Expand Down Expand Up @@ -344,8 +346,18 @@ def download_sample(self) -> str:
if not os.path.exists(os.path.join(dest)):
print(f"Download {self.name} sample...")
if "http" in src:
r = requests.get(src, allow_redirects=True)
open(dest, "wb").write(r.content)
with open(dest, "wb") as f:
response = requests.get(src, stream=True)
total_length = response.headers.get("content-length")

if total_length is None: # no content length header
f.write(response.content)
else:
pbar = tqdm()
pbar.reset(total=int(total_length)) # initialise with new `total`
for data in response.iter_content(chunk_size=4096):
f.write(data)
pbar.update(len(data))
else:
shutil.copy2(src, dest)

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ pytorch_lightning==1.4.1
opencv-python==4.5.3.56
python-dateutil==2.8.2
urllib3==1.26.6
tqdm==4.62.3
#onnx_graphsurgeon==0.0.1

0 comments on commit c3b7032

Please sign in to comment.