Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Typing][B-58] Add type annotations for python/paddle/utils/download.py #65824

Merged
merged 2 commits into from
Jul 9, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions python/paddle/utils/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import hashlib
import os
import os.path as osp
Expand All @@ -20,6 +22,7 @@
import tarfile
import time
import zipfile
from typing import Literal

import httpx

Expand Down Expand Up @@ -58,7 +61,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
DOWNLOAD_RETRY_LIMIT = 3


def is_url(path):
def is_url(path: str) -> bool:
"""
Whether path is URL.
Args:
Expand All @@ -67,7 +70,7 @@ def is_url(path):
return path.startswith('http://') or path.startswith('https://')


def get_weights_path_from_url(url, md5sum=None):
def get_weights_path_from_url(url: str, md5sum: str = None) -> str:
SigureMo marked this conversation as resolved.
Show resolved Hide resolved
"""Get weights path from WEIGHT_HOME, if not exists,
download it from url.

Expand Down Expand Up @@ -114,8 +117,13 @@ def _get_unique_endpoints(trainer_endpoints):


def get_path_from_url(
url, root_dir, md5sum=None, check_exist=True, decompress=True, method='get'
):
url: str,
root_dir: str,
md5sum: str | None = None,
check_exist: bool = True,
decompress: bool = True,
method: Literal['wget', 'get'] = 'get',
) -> str:
"""Download from given url to root_dir.
if file or directory specified by url is exists under
root_dir, return the path directly, otherwise download
Expand All @@ -125,9 +133,9 @@ def get_path_from_url(
url (str): download url
root_dir (str): root dir for downloading, it should be
WEIGHTS_HOME or DATASET_HOME
md5sum (str): md5 sum of download package
decompress (bool): decompress zip or tar file. Default is `True`
method (str): which download method to use. Support `wget` and `get`. Default is `get`.
md5sum (str|None, optional): md5 sum of download package
decompress (bool, optional): decompress zip or tar file. Default is `True`
method (str, optional): which download method to use. Support `wget` and `get`. Default is `get`.

Returns:
str: a local path to save downloaded models & weights & datasets.
Expand Down