Fix progress bar not always closed in file_download.py (#2308)
* Fix progress bar not always closed in file_download.py

* Fix feedback from charles
Wauplin authored Jun 3, 2024
1 parent c8dc5f5 commit 919ce7d
Showing 1 changed file with 69 additions and 64 deletions.
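The heart of the change is a conditional context manager: when `http_get` builds its own tqdm bar it should also close it, but when a caller hands one in via `_tqdm_bar`, the caller keeps ownership and the bar must stay open. The sketch below shows that pattern in isolation; it is not library code, and the names `download_with_progress` and `bar` are illustrative only.

import contextlib
from typing import Optional

from tqdm.auto import tqdm

def download_with_progress(total: int, bar: Optional[tqdm] = None) -> None:
    # No bar supplied: create one here; the `with` block closes it on exit.
    # Bar supplied: wrap it in `contextlib.nullcontext`, which returns the bar
    # from __enter__ and does nothing on __exit__, so the same code path runs
    # but the caller's bar is left open.
    progress_cm = tqdm(total=total, unit="B") if bar is None else contextlib.nullcontext(bar)
    with progress_cm as progress:
        for _ in range(total):
            progress.update(1)

Either way, the body only ever talks to `progress`; the two cases differ solely in who is responsible for closing the bar.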
133 changes: 69 additions & 64 deletions src/huggingface_hub/file_download.py
@@ -1,3 +1,4 @@
+import contextlib
 import copy
 import errno
 import fnmatch
@@ -487,9 +488,8 @@ def http_get(
         )

     # Stream file to buffer
-    progress = _tqdm_bar
-    if progress is None:
-        progress = tqdm(
+    progress_cm: tqdm = (
+        tqdm(  # type: ignore[assignment]
             unit="B",
             unit_scale=True,
             total=total,
@@ -500,71 +500,76 @@ def http_get(
             # see https://github.com/huggingface/huggingface_hub/pull/2000
             name="huggingface_hub.http_get",
         )
+        if _tqdm_bar is None
+        else contextlib.nullcontext(_tqdm_bar)
+        # ^ `contextlib.nullcontext` mimics a context manager that does nothing
+        # Makes it easier to use the same code path for both cases but in the latter
+        # case, the progress bar is not closed when exiting the context manager.
+    )

-    if hf_transfer and total is not None and total > 5 * DOWNLOAD_CHUNK_SIZE:
-        supports_callback = "callback" in inspect.signature(hf_transfer.download).parameters
-        if not supports_callback:
-            warnings.warn(
-                "You are using an outdated version of `hf_transfer`. "
-                "Consider upgrading to latest version to enable progress bars "
-                "using `pip install -U hf_transfer`."
-            )
-        try:
-            hf_transfer.download(
-                url=url,
-                filename=temp_file.name,
-                max_files=HF_TRANSFER_CONCURRENCY,
-                chunk_size=DOWNLOAD_CHUNK_SIZE,
-                headers=headers,
-                parallel_failures=3,
-                max_retries=5,
-                **({"callback": progress.update} if supports_callback else {}),
-            )
-        except Exception as e:
-            raise RuntimeError(
-                "An error occurred while downloading using `hf_transfer`. Consider"
-                " disabling HF_HUB_ENABLE_HF_TRANSFER for better error handling."
-            ) from e
-        if not supports_callback:
-            progress.update(total)
-        if expected_size is not None and expected_size != os.path.getsize(temp_file.name):
-            raise EnvironmentError(
-                consistency_error_message.format(
-                    actual_size=os.path.getsize(temp_file.name),
-                )
-            )
-        return
-    new_resume_size = resume_size
-    try:
-        for chunk in r.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE):
-            if chunk:  # filter out keep-alive new chunks
-                progress.update(len(chunk))
-                temp_file.write(chunk)
-                new_resume_size += len(chunk)
-                # Some data has been downloaded from the server so we reset the number of retries.
-                _nb_retries = 5
-    except (requests.ConnectionError, requests.ReadTimeout) as e:
-        # If ConnectionError (SSLError) or ReadTimeout happen while streaming data from the server, it is most likely
-        # a transient error (network outage?). We log a warning message and try to resume the download a few times
-        # before giving up. The retry mechanism is basic but should be enough in most cases.
-        if _nb_retries <= 0:
-            logger.warning("Error while downloading from %s: %s\nMax retries exceeded.", url, str(e))
-            raise
-        logger.warning("Error while downloading from %s: %s\nTrying to resume download...", url, str(e))
-        time.sleep(1)
-        reset_sessions()  # In case of SSLError it's best to reset the shared requests.Session objects
-        return http_get(
-            url=url,
-            temp_file=temp_file,
-            proxies=proxies,
-            resume_size=new_resume_size,
-            headers=initial_headers,
-            expected_size=expected_size,
-            _nb_retries=_nb_retries - 1,
-            _tqdm_bar=_tqdm_bar,
-        )
-
-    progress.close()
+    with progress_cm as progress:
+        if hf_transfer and total is not None and total > 5 * DOWNLOAD_CHUNK_SIZE:
+            supports_callback = "callback" in inspect.signature(hf_transfer.download).parameters
+            if not supports_callback:
+                warnings.warn(
+                    "You are using an outdated version of `hf_transfer`. "
+                    "Consider upgrading to latest version to enable progress bars "
+                    "using `pip install -U hf_transfer`."
+                )
+            try:
+                hf_transfer.download(
+                    url=url,
+                    filename=temp_file.name,
+                    max_files=HF_TRANSFER_CONCURRENCY,
+                    chunk_size=DOWNLOAD_CHUNK_SIZE,
+                    headers=headers,
+                    parallel_failures=3,
+                    max_retries=5,
+                    **({"callback": progress.update} if supports_callback else {}),
+                )
+            except Exception as e:
+                raise RuntimeError(
+                    "An error occurred while downloading using `hf_transfer`. Consider"
+                    " disabling HF_HUB_ENABLE_HF_TRANSFER for better error handling."
+                ) from e
+            if not supports_callback:
+                progress.update(total)
+            if expected_size is not None and expected_size != os.path.getsize(temp_file.name):
+                raise EnvironmentError(
+                    consistency_error_message.format(
+                        actual_size=os.path.getsize(temp_file.name),
+                    )
+                )
+            return
+        new_resume_size = resume_size
+        try:
+            for chunk in r.iter_content(chunk_size=DOWNLOAD_CHUNK_SIZE):
+                if chunk:  # filter out keep-alive new chunks
+                    progress.update(len(chunk))
+                    temp_file.write(chunk)
+                    new_resume_size += len(chunk)
+                    # Some data has been downloaded from the server so we reset the number of retries.
+                    _nb_retries = 5
+        except (requests.ConnectionError, requests.ReadTimeout) as e:
+            # If ConnectionError (SSLError) or ReadTimeout happen while streaming data from the server, it is most likely
+            # a transient error (network outage?). We log a warning message and try to resume the download a few times
+            # before giving up. The retry mechanism is basic but should be enough in most cases.
+            if _nb_retries <= 0:
+                logger.warning("Error while downloading from %s: %s\nMax retries exceeded.", url, str(e))
+                raise
+            logger.warning("Error while downloading from %s: %s\nTrying to resume download...", url, str(e))
+            time.sleep(1)
+            reset_sessions()  # In case of SSLError it's best to reset the shared requests.Session objects
+            return http_get(
+                url=url,
+                temp_file=temp_file,
+                proxies=proxies,
+                resume_size=new_resume_size,
+                headers=initial_headers,
+                expected_size=expected_size,
+                _nb_retries=_nb_retries - 1,
+                _tqdm_bar=_tqdm_bar,
+            )

     if expected_size is not None and expected_size != temp_file.tell():
         raise EnvironmentError(
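Why a `with` block is the right fix: in the removed code, `progress.close()` sat after the streaming loop, so the early `return` in the `hf_transfer` branch and any exception raised mid-download skipped it, leaving the bar open. Entering the bar as a context manager closes it on every exit path. A simplified before/after illustration, not library code (`fail` stands in for any mid-download error):

from tqdm.auto import tqdm

def old_style(fail: bool) -> None:
    progress = tqdm(total=10)
    if fail:
        raise RuntimeError("oops")  # the close() below never runs
    progress.close()

def new_style(fail: bool) -> None:
    with tqdm(total=10) as progress:  # tqdm.__exit__ calls close()
        if fail:
            raise RuntimeError("oops")  # bar is closed during unwinding

tqdm implements the context-manager protocol, so `with tqdm(...) as progress:` behaves like a try/finally that guarantees `close()`.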
