Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support download dataset #14

Merged
merged 3 commits into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,21 @@ result = snapshot_download(repo_id,
token=token)
```

### 数据集下载
```python
from pycsghub.snapshot_download import snapshot_download
token="xxxx"
endpoint = "https://hub.opencsg.com"
repo_id = 'AIWizards/tmmluplus'
repo_type="dataset"
cache_dir = '/path/to/cache_dir'
result = snapshot_download(repo_id,
repo_type=repo_type,
cache_dir=cache_dir,
endpoint=endpoint,
token=token)
```

### 兼容huggingface的模型加载

huggingface的transformers库支持直接输入huggingface上的repo_id以下载并读取相关模型,如下列所示:
Expand Down
14 changes: 14 additions & 0 deletions README_EN.md
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,20 @@ result = snapshot_download(repo_id,
endpoint=endpoint,
token=token)
```
### Download dataset
```python
from pycsghub.snapshot_download import snapshot_download
token="xxxx"
endpoint = "https://hub.opencsg.com"
repo_id = 'AIWizards/tmmluplus'
repo_type="dataset"
cache_dir = '/path/to/cache_dir'
result = snapshot_download(repo_id,
repo_type=repo_type,
cache_dir=cache_dir,
endpoint=endpoint,
token=token)
```

### Model loading compatible with huggingface

Expand Down
4 changes: 4 additions & 0 deletions pycsghub/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
API_FILE_DOWNLOAD_TIMEOUT = 5
API_FILE_DOWNLOAD_RETRY_TIMES = 5

REPO_TYPE_DATASET = "dataset"
REPO_TYPE_MODEL = "model"
REPO_TYPE_SPACE = "space"
REPO_TYPES = [None, REPO_TYPE_MODEL, REPO_TYPE_DATASET]
ganisback marked this conversation as resolved.
Show resolved Hide resolved

CSGHUB_HOME = os.environ.get('CSGHUB_HOME', '/home')
CSGHUB_TOKEN_PATH = os.environ.get("CSGHUB_TOKEN_PATH", os.path.join(CSGHUB_HOME, "token"))
Expand Down
37 changes: 19 additions & 18 deletions pycsghub/file_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pycsghub.utils import (build_csg_headers,
get_cache_dir,
model_id_to_group_owner_name,
pack_model_file_info,
pack_repo_file_info,
get_file_download_url,
get_endpoint)
from pycsghub.constants import (API_FILE_DOWNLOAD_RETRY_TIMES,
Expand All @@ -39,6 +39,7 @@ def csg_hub_download():
def get_csg_hub_url():
pass


def file_download(
repo_id: str,
*,
Expand All @@ -51,10 +52,11 @@ def file_download(
ignore_patterns: Optional[Union[List[str], str]] = None,
headers: Optional[Dict[str, str]] = None,
endpoint: Optional[str] = None,
token: Optional[str] = None
token: Optional[str] = None,
repo_type: Optional[str] = None
) -> str:
if cache_dir is None:
cache_dir = get_cache_dir()
cache_dir = get_cache_dir(repo_type=repo_type)
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
temporary_cache_dir = os.path.join(cache_dir, 'temp')
Expand Down Expand Up @@ -98,9 +100,9 @@ def file_download(

with tempfile.TemporaryDirectory(
dir=temporary_cache_dir) as temp_cache_dir:
model_file_info = pack_model_file_info(file_name, revision)
if not cache.exists(model_file_info):
file_name = os.path.basename(model_file_info['Path'])
repo_file_info = pack_repo_file_info(file_name, revision)
if not cache.exists(repo_file_info):
file_name = os.path.basename(repo_file_info['Path'])
# get download url
url = get_file_download_url(
model_id=repo_id,
Expand All @@ -117,14 +119,15 @@ def file_download(

# todo using hash to check file integrity
temp_file = os.path.join(temp_cache_dir, file_name)
cache.put_file(model_file_info, temp_file)
cache.put_file(repo_file_info, temp_file)
else:
print(
f'File {file_name} already in cache, skip downloading!'
)
cache.save_model_version(revision_info={'Revision': revision})
return os.path.join(cache.get_root_location(), file_name)


def http_get(*,
url: str,
local_dir: str,
Expand Down Expand Up @@ -153,19 +156,20 @@ def http_get(*,
downloaded_size = temp_file.tell()
if downloaded_size > 0:
get_headers['Range'] = 'bytes=%d-' % downloaded_size
r = requests.get(url, headers=get_headers, stream=True, cookies=cookies, timeout=API_FILE_DOWNLOAD_TIMEOUT)
r = requests.get(url, headers=get_headers, stream=True,
cookies=cookies, timeout=API_FILE_DOWNLOAD_TIMEOUT)
r.raise_for_status()
accept_ranges = r.headers.get('Accept-Ranges')
content_length = r.headers.get('Content-Length')
if accept_ranges == 'bytes':
if downloaded_size == 0:
total_content_length = int(content_length) if content_length is not None else None
total_content_length = int(content_length) if content_length is not None else None
else:
if downloaded_size > 0:
temp_file.truncate(0)
downloaded_size = temp_file.tell()
total_content_length = int(content_length) if content_length is not None else None
total_content_length = int(content_length) if content_length is not None else None

progress = tqdm(
unit='B',
unit_scale=True,
Expand All @@ -187,8 +191,11 @@ def http_get(*,
downloaded_length = os.path.getsize(temp_file.name)
if total_content_length != downloaded_length:
os.remove(temp_file.name)
msg = 'File %s download incomplete, content_length: %s but the file downloaded length: %s, please download again' % (file_name, total_content_length, downloaded_length)
msg = 'File %s download incomplete, content_length: %s but the file downloaded length: %s, please download again' % (
file_name, total_content_length, downloaded_length)
raise FileDownloadError(msg)
# fix folder recursive issue
os.makedirs(os.path.dirname(os.path.join(local_dir, file_name)), exist_ok=True)
os.replace(temp_file.name, os.path.join(local_dir, file_name))
return

Expand All @@ -208,9 +215,3 @@ def http_get(*,
file_name=file_name,
headers=headers,
cookies=cookies)






31 changes: 19 additions & 12 deletions pycsghub/snapshot_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,18 @@
model_id_to_group_owner_name)
from pycsghub.cache import ModelFileSystemCache
from pycsghub.utils import (get_cache_dir,
pack_model_file_info,
pack_repo_file_info,
get_endpoint)
from huggingface_hub.utils import filter_repo_objects
from pycsghub.file_download import http_get
from pycsghub.constants import DEFAULT_REVISION
from pycsghub.constants import DEFAULT_REVISION, REPO_TYPES
from pycsghub import utils


def snapshot_download(
repo_id: str,
*,
repo_type: Optional[str] = None,
revision: Optional[str] = DEFAULT_REVISION,
cache_dir: Union[str, Path, None] = None,
local_files_only: Optional[bool] = False,
Expand All @@ -30,8 +31,12 @@ def snapshot_download(
endpoint: Optional[str] = None,
token: Optional[str] = None
) -> str:
if repo_type is None:
repo_type = "model"
if repo_type not in REPO_TYPES:
raise ValueError(f"Invalid repo type: {repo_type}. Accepted repo types are: {str(REPO_TYPES)}")
if cache_dir is None:
cache_dir = get_cache_dir()
cache_dir = get_cache_dir(repo_type=repo_type)
if isinstance(cache_dir, Path):
cache_dir = str(cache_dir)
temporary_cache_dir = os.path.join(cache_dir, 'temp')
Expand All @@ -53,13 +58,14 @@ def snapshot_download(
# make headers
# todo need to add cookies?
repo_info = utils.get_repo_info(repo_id,
repo_type=repo_type,
revision=revision,
token=token,
endpoint=endpoint if endpoint is not None else get_endpoint())

assert repo_info.sha is not None, "Repo info returned from server must have a revision sha."
assert repo_info.siblings is not None, "Repo info returned from server must have a siblings list."
model_files = list(
repo_files = list(
filter_repo_objects(
items=[f.rfilename for f in repo_info.siblings],
allow_patterns=allow_patterns,
Expand All @@ -69,10 +75,10 @@ def snapshot_download(

with tempfile.TemporaryDirectory(
dir=temporary_cache_dir) as temp_cache_dir:
for model_file in model_files:
model_file_info = pack_model_file_info(model_file, revision)
if cache.exists(model_file_info):
file_name = os.path.basename(model_file_info['Path'])
for repo_file in repo_files:
repo_file_info = pack_repo_file_info(repo_file, revision)
if cache.exists(repo_file_info):
file_name = os.path.basename(repo_file_info['Path'])
print(
f'File {file_name} already in cache, skip downloading!'
)
Expand All @@ -81,20 +87,21 @@ def snapshot_download(
# get download url
url = get_file_download_url(
model_id=repo_id,
file_path=model_file,
file_path=repo_file,
repo_type=repo_type,
revision=revision)
# todo support parallel download api
http_get(
url=url,
local_dir=temp_cache_dir,
file_name=model_file,
file_name=repo_file,
headers=headers,
cookies=cookies,
token=token)

# todo using hash to check file integrity
temp_file = os.path.join(temp_cache_dir, model_file)
cache.put_file(model_file_info, temp_file)
temp_file = os.path.join(temp_cache_dir, repo_file)
cache.put_file(repo_file_info, temp_file)

cache.save_model_version(revision_info={'Revision': revision})
return os.path.join(cache.get_root_location())
13 changes: 11 additions & 2 deletions pycsghub/test/snapshot_download_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ def test_snapshot_download(self):
token=token)
print(result)


def test_singlefile_download(self):
token = ("f3a7b9c1d6e5f8e2a1b5d4f9e6a2b8d7c3a4e2b1d9f6e7a8d2c5a7b4c1e3f5b8a1d4f"
"9b7d6e2f8a5d3b1e7f9c6a8b2d1e4f7d5b6e9f2a4b3c8e1d7f995hd82hf")
Expand Down Expand Up @@ -47,7 +46,17 @@ def test_singlefile_download_not_exist(self):
except InvalidParameter as e:
self.assertEqual(str(e), "file wolegequ.hehe not in repo wayne0019/lwfmodel")


def test_snapshot_download(self):
token = ("4e5b97a59c1f8a954954971bf1cdbf3ce61a35sd5")
endpoint = "https://hub.opencsg.com"
repo_id = 'AIWizards/tmmluplus'
cache_dir = '~/Downloads/'
result = snapshot_download(repo_id,
repo_type="dataset",
cache_dir=cache_dir,
endpoint=endpoint,
token=token)
print(result)


if __name__ == '__main__':
Expand Down
Loading