From 16569d32ace74013ffe740e56deaf2004cb62f1c Mon Sep 17 00:00:00 2001 From: Xiang Zhen Gan Date: Wed, 10 Jul 2024 16:39:05 +0800 Subject: [PATCH 1/3] support download dataset --- README.md | 15 +++ README_EN.md | 14 +++ pycsghub/constants.py | 4 + pycsghub/file_download.py | 33 ++--- pycsghub/snapshot_download.py | 29 +++-- pycsghub/test/snapshot_download_test.py | 13 +- pycsghub/utils.py | 157 ++++++++++++++++++++++-- 7 files changed, 226 insertions(+), 39 deletions(-) diff --git a/README.md b/README.md index e837240..a3e34e5 100644 --- a/README.md +++ b/README.md @@ -126,6 +126,21 @@ result = snapshot_download(repo_id, token=token) ``` +### 数据集下载 +```python +from pycsghub.snapshot_download import snapshot_download +token="xxxx" +endpoint = "https://hub.opencsg.com" +repo_id = 'AIWizards/tmmluplus' +repo_type="dataset" +cache_dir = '/Users/xiangzhen/Downloads/' +result = snapshot_download(repo_id, + repo_type=repo_type, + cache_dir=cache_dir, + endpoint=endpoint, + token=token) +``` + ### 兼容huggingface的模型加载 huggingface的transformers库支持直接输入huggingface上的repo_id以下载并读取相关模型,如下列所示: diff --git a/README_EN.md b/README_EN.md index 024d180..b4b7607 100644 --- a/README_EN.md +++ b/README_EN.md @@ -128,6 +128,20 @@ result = snapshot_download(repo_id, endpoint=endpoint, token=token) ``` +### Download dataset +```python +from pycsghub.snapshot_download import snapshot_download +token="xxxx" +endpoint = "https://hub.opencsg.com" +repo_id = 'AIWizards/tmmluplus' +repo_type="dataset" +cache_dir = '/Users/xiangzhen/Downloads/' +result = snapshot_download(repo_id, + repo_type=repo_type, + cache_dir=cache_dir, + endpoint=endpoint, + token=token) +``` ### Model loading compatible with huggingface diff --git a/pycsghub/constants.py b/pycsghub/constants.py index 83c77c7..0cccfd8 100644 --- a/pycsghub/constants.py +++ b/pycsghub/constants.py @@ -3,6 +3,10 @@ API_FILE_DOWNLOAD_TIMEOUT = 5 API_FILE_DOWNLOAD_RETRY_TIMES = 5 +REPO_TYPE_DATASET = "dataset" +REPO_TYPE_MODEL = "model" +REPO_TYPE_SPACE = "space" +REPO_TYPES = [None, REPO_TYPE_MODEL, REPO_TYPE_DATASET] CSGHUB_HOME = os.environ.get('CSGHUB_HOME', '/home') CSGHUB_TOKEN_PATH = os.environ.get("CSGHUB_TOKEN_PATH", os.path.join(CSGHUB_HOME, "token")) diff --git a/pycsghub/file_download.py b/pycsghub/file_download.py index ba6ce90..c3dc1bc 100644 --- a/pycsghub/file_download.py +++ b/pycsghub/file_download.py @@ -12,7 +12,7 @@ from pycsghub.utils import (build_csg_headers, get_cache_dir, model_id_to_group_owner_name, - pack_model_file_info, + pack_repo_file_info, get_file_download_url, get_endpoint) from pycsghub.constants import (API_FILE_DOWNLOAD_RETRY_TIMES, @@ -39,6 +39,7 @@ def csg_hub_download(): def get_csg_hub_url(): pass + def file_download( repo_id: str, *, @@ -98,9 +99,9 @@ def file_download( with tempfile.TemporaryDirectory( dir=temporary_cache_dir) as temp_cache_dir: - model_file_info = pack_model_file_info(file_name, revision) - if not cache.exists(model_file_info): - file_name = os.path.basename(model_file_info['Path']) + repo_file_info = pack_repo_file_info(file_name, revision) + if not cache.exists(repo_file_info): + file_name = os.path.basename(repo_file_info['Path']) # get download url url = get_file_download_url( model_id=repo_id, @@ -117,7 +118,7 @@ def file_download( # todo using hash to check file integrity temp_file = os.path.join(temp_cache_dir, file_name) - cache.put_file(model_file_info, temp_file) + cache.put_file(repo_file_info, temp_file) else: print( f'File {file_name} already in cache, skip downloading!' 
@@ -125,6 +126,7 @@ def file_download( cache.save_model_version(revision_info={'Revision': revision}) return os.path.join(cache.get_root_location(), file_name) + def http_get(*, url: str, local_dir: str, @@ -153,19 +155,20 @@ def http_get(*, downloaded_size = temp_file.tell() if downloaded_size > 0: get_headers['Range'] = 'bytes=%d-' % downloaded_size - r = requests.get(url, headers=get_headers, stream=True, cookies=cookies, timeout=API_FILE_DOWNLOAD_TIMEOUT) + r = requests.get(url, headers=get_headers, stream=True, + cookies=cookies, timeout=API_FILE_DOWNLOAD_TIMEOUT) r.raise_for_status() accept_ranges = r.headers.get('Accept-Ranges') content_length = r.headers.get('Content-Length') if accept_ranges == 'bytes': if downloaded_size == 0: - total_content_length = int(content_length) if content_length is not None else None + total_content_length = int(content_length) if content_length is not None else None else: if downloaded_size > 0: temp_file.truncate(0) downloaded_size = temp_file.tell() - total_content_length = int(content_length) if content_length is not None else None - + total_content_length = int(content_length) if content_length is not None else None + progress = tqdm( unit='B', unit_scale=True, @@ -187,8 +190,12 @@ def http_get(*, downloaded_length = os.path.getsize(temp_file.name) if total_content_length != downloaded_length: os.remove(temp_file.name) - msg = 'File %s download incomplete, content_length: %s but the file downloaded length: %s, please download again' % (file_name, total_content_length, downloaded_length) + msg = 'File %s download incomplete, content_length: %s but the file downloaded length: %s, please download again' % ( + file_name, total_content_length, downloaded_length) raise FileDownloadError(msg) + print(f"{temp_file.name}, {local_dir}, {file_name}") + # fix folder recursive issue + os.makedirs(os.path.dirname(os.path.join(local_dir, file_name)), exist_ok=True) os.replace(temp_file.name, os.path.join(local_dir, file_name)) return @@ -208,9 +215,3 @@ def http_get(*, file_name=file_name, headers=headers, cookies=cookies) - - - - - - diff --git a/pycsghub/snapshot_download.py b/pycsghub/snapshot_download.py index dac87cb..47e0b1d 100644 --- a/pycsghub/snapshot_download.py +++ b/pycsghub/snapshot_download.py @@ -9,17 +9,18 @@ model_id_to_group_owner_name) from pycsghub.cache import ModelFileSystemCache from pycsghub.utils import (get_cache_dir, - pack_model_file_info, + pack_repo_file_info, get_endpoint) from huggingface_hub.utils import filter_repo_objects from pycsghub.file_download import http_get -from pycsghub.constants import DEFAULT_REVISION +from pycsghub.constants import DEFAULT_REVISION, REPO_TYPES from pycsghub import utils def snapshot_download( repo_id: str, *, + repo_type: Optional[str] = None, revision: Optional[str] = DEFAULT_REVISION, cache_dir: Union[str, Path, None] = None, local_files_only: Optional[bool] = False, @@ -30,6 +31,10 @@ def snapshot_download( endpoint: Optional[str] = None, token: Optional[str] = None ) -> str: + if repo_type is None: + repo_type = "model" + if repo_type not in REPO_TYPES: + raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(REPO_TYPES)}") if cache_dir is None: cache_dir = get_cache_dir() if isinstance(cache_dir, Path): @@ -53,13 +58,14 @@ def snapshot_download( # make headers # todo need to add cookies? 
     repo_info = utils.get_repo_info(repo_id,
+                                    repo_type=repo_type,
                                     revision=revision,
                                     token=token,
                                     endpoint=endpoint if endpoint is not None else get_endpoint())

     assert repo_info.sha is not None, "Repo info returned from server must have a revision sha."
     assert repo_info.siblings is not None, "Repo info returned from server must have a siblings list."

-    model_files = list(
+    repo_files = list(
         filter_repo_objects(
             items=[f.rfilename for f in repo_info.siblings],
             allow_patterns=allow_patterns,
@@ -69,10 +75,10 @@ def snapshot_download(

     with tempfile.TemporaryDirectory(
             dir=temporary_cache_dir) as temp_cache_dir:
-        for model_file in model_files:
-            model_file_info = pack_model_file_info(model_file, revision)
-            if cache.exists(model_file_info):
-                file_name = os.path.basename(model_file_info['Path'])
+        for repo_file in repo_files:
+            repo_file_info = pack_repo_file_info(repo_file, revision)
+            if cache.exists(repo_file_info):
+                file_name = os.path.basename(repo_file_info['Path'])
                 print(
                     f'File {file_name} already in cache, skip downloading!'
                 )
@@ -81,20 +87,21 @@ def snapshot_download(
                 # get download url
                 url = get_file_download_url(
                     model_id=repo_id,
-                    file_path=model_file,
+                    file_path=repo_file,
+                    repo_type=repo_type,
                     revision=revision)
                 # todo support parallel download api
                 http_get(
                     url=url,
                     local_dir=temp_cache_dir,
-                    file_name=model_file,
+                    file_name=repo_file,
                     headers=headers,
                     cookies=cookies,
                     token=token)

                 # todo using hash to check file integrity
-                temp_file = os.path.join(temp_cache_dir, model_file)
-                cache.put_file(model_file_info, temp_file)
+                temp_file = os.path.join(temp_cache_dir, repo_file)
+                cache.put_file(repo_file_info, temp_file)

     cache.save_model_version(revision_info={'Revision': revision})
     return os.path.join(cache.get_root_location())
diff --git a/pycsghub/test/snapshot_download_test.py b/pycsghub/test/snapshot_download_test.py
index a2ddf8f..76411e5 100644
--- a/pycsghub/test/snapshot_download_test.py
+++ b/pycsghub/test/snapshot_download_test.py
@@ -19,7 +19,6 @@ def test_snapshot_download(self):
                                    token=token)
         print(result)

-
     def test_singlefile_download(self):
         token = ("f3a7b9c1d6e5f8e2a1b5d4f9e6a2b8d7c3a4e2b1d9f6e7a8d2c5a7b4c1e3f5b8a1d4f"
                  "9b7d6e2f8a5d3b1e7f9c6a8b2d1e4f7d5b6e9f2a4b3c8e1d7f995hd82hf")
@@ -47,7 +46,17 @@ def test_singlefile_download_not_exist(self):
         except InvalidParameter as e:
             self.assertEqual(str(e), "file wolegequ.hehe not in repo wayne0019/lwfmodel")

-
+    def test_dataset_snapshot_download(self):
+        token = ("4e5b97a59c1f8a954954971bf1cdbf3ce61a35sd5")
+        endpoint = "https://hub.opencsg.com"
+        repo_id = 'AIWizards/tmmluplus'
+        cache_dir = '~/Downloads/'
+        result = snapshot_download(repo_id,
+                                   repo_type="dataset",
+                                   cache_dir=cache_dir,
+                                   endpoint=endpoint,
+                                   token=token)
+        print(result)

 if __name__ == '__main__':
diff --git a/pycsghub/utils.py b/pycsghub/utils.py
index 9f506ec..5e1e356 100644
--- a/pycsghub/utils.py
+++ b/pycsghub/utils.py
@@ -2,9 +2,9 @@
 from pathlib import Path
 import os

-from pycsghub.constants import MODEL_ID_SEPARATOR, DEFAULT_CSG_GROUP, DEFAULT_CSGHUB_DOMAIN
+from pycsghub.constants import MODEL_ID_SEPARATOR, DEFAULT_CSG_GROUP, DEFAULT_CSGHUB_DOMAIN, REPO_TYPE_DATASET
 import requests
-from huggingface_hub.hf_api import ModelInfo
+from huggingface_hub.hf_api import ModelInfo, DatasetInfo, SpaceInfo
 import urllib
 import hashlib
 from pycsghub.errors import FileIntegrityError
@@ -94,7 +94,7 @@ def get_repo_info(
     files_metadata: bool = False,
     token: Union[bool, str, None] = None,
     endpoint: Optional[str] = None
-) -> ModelInfo:
+) -> Union[ModelInfo, DatasetInfo, SpaceInfo]:
     """
     Get the info object for a given repo of a given type.
@@ -136,7 +136,10 @@ def get_repo_info(
     """
     if repo_type is None or repo_type == "model":
         method = model_info
-    # todo dataset and spaceset are now not supported
+    elif repo_type == "dataset":
+        method = dataset_info
+    elif repo_type == "space":
+        method = space_info
     else:
         raise ValueError("Unsupported repo type.")
     return method(
@@ -149,6 +152,136 @@ def get_repo_info(
     )


+def dataset_info(
+    repo_id: str,
+    *,
+    revision: Optional[str] = None,
+    timeout: Optional[float] = None,
+    files_metadata: bool = False,
+    token: Union[bool, str, None] = None,
+    endpoint: Optional[str] = None,
+) -> DatasetInfo:
+    """
+    Get info on one specific dataset on CSGHub.
+
+    A dataset can be private if you pass an acceptable token.
+
+    Args:
+        repo_id (`str`):
+            A namespace (user or an organization) and a repo name separated
+            by a `/`.
+        revision (`str`, *optional*):
+            The revision of the dataset repository from which to get the
+            information.
+        timeout (`float`, *optional*):
+            Timeout in seconds for the request to the Hub.
+        files_metadata (`bool`, *optional*):
+            Whether or not to retrieve metadata for files in the repository
+            (size, LFS metadata, etc). Defaults to `False`.
+        token (Union[bool, str, None], optional):
+            A valid user access token (string). Defaults to the locally saved
+            token, which is the recommended method for authentication (see
+            https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+            To disable authentication, pass `False`.
+
+    Returns:
+        [`hf_api.DatasetInfo`]: The dataset repository information.
+
+    Raises the following errors:
+
+        - [`~utils.RepositoryNotFoundError`]
+          If the repository to download from cannot be found. This may be because it doesn't exist,
+          or because it is set to `private` and you do not have access.
+        - [`~utils.RevisionNotFoundError`]
+          If the revision to download from cannot be found.
+
+    """
+    headers = build_csg_headers(token=token)
+    path = (
+        f"{endpoint}/hf/api/datasets/{repo_id}"
+        if revision is None
+        else (f"{endpoint}/hf/api/datasets/{repo_id}/revision/{quote(revision, safe='')}")
+    )
+    params = {}
+    if files_metadata:
+        params["blobs"] = True
+    r = requests.get(path,
+                     headers=headers,
+                     timeout=timeout,
+                     params=params)
+    r.raise_for_status()
+    data = r.json()
+    return DatasetInfo(**data)
+
+
+def space_info(
+    repo_id: str,
+    *,
+    revision: Optional[str] = None,
+    timeout: Optional[float] = None,
+    files_metadata: bool = False,
+    token: Union[bool, str, None] = None,
+    endpoint: Optional[str] = None,
+) -> SpaceInfo:
+    """
+    Get info on one specific Space on CSGHub.
+
+    A Space can be private if you pass an acceptable token.
+
+    Args:
+        repo_id (`str`):
+            A namespace (user or an organization) and a repo name separated
+            by a `/`.
+        revision (`str`, *optional*):
+            The revision of the space repository from which to get the
+            information.
+        timeout (`float`, *optional*):
+            Timeout in seconds for the request to the Hub.
+        files_metadata (`bool`, *optional*):
+            Whether or not to retrieve metadata for files in the repository
+            (size, LFS metadata, etc). Defaults to `False`.
+        token (Union[bool, str, None], optional):
+            A valid user access token (string). Defaults to the locally saved
+            token, which is the recommended method for authentication (see
+            https://huggingface.co/docs/huggingface_hub/quick-start#authentication).
+            To disable authentication, pass `False`.
+
+    Returns:
+        [`~hf_api.SpaceInfo`]: The space repository information.
+
+    Raises the following errors:
+
+        - [`~utils.RepositoryNotFoundError`]
+          If the repository to download from cannot be found. This may be because it doesn't exist,
+          or because it is set to `private` and you do not have access.
+        - [`~utils.RevisionNotFoundError`]
+          If the revision to download from cannot be found.
+
+    """
+    headers = build_csg_headers(token=token)
+    path = (
+        f"{endpoint}/hf/api/spaces/{repo_id}"
+        if revision is None
+        else (f"{endpoint}/hf/api/spaces/{repo_id}/revision/{quote(revision, safe='')}")
+    )
+    params = {}
+    if files_metadata:
+        params["blobs"] = True
+    r = requests.get(path,
+                     headers=headers,
+                     timeout=timeout,
+                     params=params)
+    r.raise_for_status()
+    data = r.json()
+    return SpaceInfo(**data)
+
+
 def model_info(
     repo_id: str,
     *,
@@ -228,7 +361,9 @@ def get_endpoint():

 def get_file_download_url(model_id: str,
                           file_path: str,
-                          revision: str) -> str:
+                          revision: str,
+                          repo_type: Optional[str] = None,
+                          ) -> str:
     """Format file download url according to `model_id`, `revision` and `file_path`.
     Args:
         model_id (str): The model_id.
@@ -241,6 +376,8 @@ def get_file_download_url(model_id: str,
     file_path = urllib.parse.quote_plus(file_path)
     revision = urllib.parse.quote_plus(revision)
     download_url_template = '{endpoint}/hf/{model_id}/resolve/{revision}/{file_path}'
+    if repo_type == REPO_TYPE_DATASET:
+        download_url_template = '{endpoint}/hf/datasets/{model_id}/resolve/{revision}/{file_path}'
     return download_url_template.format(
         endpoint=get_endpoint(),
         model_id=model_id,
@@ -280,8 +417,8 @@ def compute_hash(file_path) -> str:
     return sha256_hash.hexdigest()


-def pack_model_file_info(model_file_path,
-                         revision) -> Dict[str, str]:
-    model_file_info = {'Path': model_file_path,
-                       'Revision': revision}
-    return model_file_info
+def pack_repo_file_info(repo_file_path,
+                        revision) -> Dict[str, str]:
+    repo_file_info = {'Path': repo_file_path,
+                      'Revision': revision}
+    return repo_file_info

From 6573ba5eafee3d95f8a127bf776a21b212faba7d Mon Sep 17 00:00:00 2001
From: Xiang Zhen Gan
Date: Wed, 10 Jul 2024 17:29:13 +0800
Subject: [PATCH 2/3] remove debug print

---
 pycsghub/file_download.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/pycsghub/file_download.py b/pycsghub/file_download.py
index c3dc1bc..0d7e5fc 100644
--- a/pycsghub/file_download.py
+++ b/pycsghub/file_download.py
@@ -193,7 +193,6 @@ def http_get(*,
             msg = 'File %s download incomplete, content_length: %s but the file downloaded length: %s, please download again' % (
                 file_name, total_content_length, downloaded_length)
             raise FileDownloadError(msg)
-        print(f"{temp_file.name}, {local_dir}, {file_name}")
         # fix folder recursive issue
         os.makedirs(os.path.dirname(os.path.join(local_dir, file_name)), exist_ok=True)
         os.replace(temp_file.name, os.path.join(local_dir, file_name))

From 6fc671bd67872702e8e4473228e8795300273c79 Mon Sep 17 00:00:00 2001
From: Xiang Zhen Gan
Date: Wed, 10 Jul 2024 22:34:16 +0800
Subject: [PATCH 3/3] set dataset dir for dataset

---
 pycsghub/file_download.py     | 5 +++--
 pycsghub/snapshot_download.py | 2 +-
 pycsghub/utils.py             | 8 ++++++--
 3 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/pycsghub/file_download.py b/pycsghub/file_download.py
index 0d7e5fc..bbcf209 100644
--- a/pycsghub/file_download.py
+++ b/pycsghub/file_download.py
@@ -52,10 +52,11 @@ def file_download(
     ignore_patterns: Optional[Union[List[str], str]] = None,
     headers: Optional[Dict[str, str]] = None,
     endpoint: Optional[str] = None,
-    token: Optional[str] = None
+    token: Optional[str] = None,
+    repo_type: Optional[str] = None
 ) -> str:
     if cache_dir is None:
-        cache_dir = get_cache_dir()
+        cache_dir = get_cache_dir(repo_type=repo_type)
     if isinstance(cache_dir, Path):
         cache_dir = str(cache_dir)
     temporary_cache_dir = os.path.join(cache_dir, 'temp')
diff --git a/pycsghub/snapshot_download.py b/pycsghub/snapshot_download.py
index 47e0b1d..389388e 100644
--- a/pycsghub/snapshot_download.py
+++ b/pycsghub/snapshot_download.py
@@ -36,7 +36,7 @@ def snapshot_download(
     if repo_type not in REPO_TYPES:
         raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(REPO_TYPES)}")
     if cache_dir is None:
-        cache_dir = get_cache_dir()
+        cache_dir = get_cache_dir(repo_type=repo_type)
     if isinstance(cache_dir, Path):
         cache_dir = str(cache_dir)
     temporary_cache_dir = os.path.join(cache_dir, 'temp')
diff --git a/pycsghub/utils.py b/pycsghub/utils.py
index 5e1e356..d348573 100644
--- a/pycsghub/utils.py
+++ b/pycsghub/utils.py
@@ -60,19 +60,23 @@ def model_id_to_group_owner_name(model_id: str) -> (str, str):
     return group_or_owner, name


-def get_cache_dir(model_id: Optional[str] = None) -> Union[str, Path]:
+def get_cache_dir(model_id: Optional[str] = None, repo_type: Optional[str] = None) -> Union[str, Path]:
     """cache dir precedence:
         function parameter > environment > ~/.cache/csg/hub

     Args:
         model_id (str, optional): The model id.
+        repo_type (str, optional): The repo type.

     Returns:
         str: the model_id dir if model_id not None, otherwise cache root dir.
     """
     default_cache_dir = get_default_cache_dir()
+    sub_dir = 'hub'
+    if repo_type == "dataset":
+        sub_dir = 'dataset'
     base_path = os.getenv('CSGHUB_CACHE',
-                          os.path.join(default_cache_dir, 'hub'))
+                          os.path.join(default_cache_dir, sub_dir))
     return base_path if model_id is None else os.path.join(
         base_path, model_id + '/')
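
For quick reference, below is a minimal usage sketch of the dataset download path added by this patch series. It is not part of the patches themselves: the token value, cache directory, and `*.csv` pattern are placeholders, and `allow_patterns` is assumed to behave as wired through `filter_repo_objects` in `snapshot_download` above.

```python
# Minimal usage sketch for the dataset download support added above.
# Placeholders: the token, the cache directory, and the "*.csv" pattern.
import os

from pycsghub.snapshot_download import snapshot_download

token = "your_access_token"            # placeholder access token
endpoint = "https://hub.opencsg.com"   # endpoint used in the README examples above
repo_id = "AIWizards/tmmluplus"

# repo_type="dataset" routes requests to the dataset URL template and the
# dataset cache sub-directory introduced in these patches; allow_patterns is
# forwarded to filter_repo_objects() to limit which files are fetched.
local_root = snapshot_download(
    repo_id,
    repo_type="dataset",
    allow_patterns=["*.csv"],
    cache_dir=os.path.expanduser("~/.cache/csg/dataset"),
    endpoint=endpoint,
    token=token,
)
print(local_root)  # root directory of the cached dataset snapshot
```

The cache path is expanded explicitly because `snapshot_download` joins `cache_dir` as given (see the `os.path.join(cache_dir, 'temp')` call in the diff).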