Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement list_files_info + tests #1435

Merged
merged 8 commits into from
Apr 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/huggingface_hub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@
"get_space_runtime",
"like",
"list_datasets",
"list_files_info",
"list_liked_repos",
"list_metrics",
"list_models",
Expand Down Expand Up @@ -434,6 +435,7 @@ def __dir__():
get_space_runtime, # noqa: F401
like, # noqa: F401
list_datasets, # noqa: F401
list_files_info, # noqa: F401
list_liked_repos, # noqa: F401
list_metrics, # noqa: F401
list_models, # noqa: F401
Expand Down
192 changes: 159 additions & 33 deletions src/huggingface_hub/hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def repo_type_and_id_from_hf_id(hf_id: str, hub_url: Optional[str] = None) -> Tu
class BlobLfsInfo(TypedDict, total=False):
size: int
sha256: str
pointer_size: int


@dataclass
Expand Down Expand Up @@ -309,23 +310,21 @@ def __repr__(self) -> str:

class RepoFile(ReprMixin):
"""
Data structure that represents a public file inside a repo, accessible from
huggingface.co
Data structure that represents a public file inside a repo, accessible from huggingface.co

Args:
rfilename (str):
file name, relative to the repo root. This is the only attribute
that's guaranteed to be here, but under certain conditions there can
certain other stuff.
file name, relative to the repo root. This is the only attribute that's guaranteed to be here, but under
certain conditions there can certain other stuff.
size (`int`, *optional*):
The file's size, in bytes. This attribute is present when `files_metadata` argument
of [`repo_info`] is set to `True`. It's `None` otherwise.
The file's size, in bytes. This attribute is present when `files_metadata` argument of [`repo_info`] is set
to `True`. It's `None` otherwise.
blob_id (`str`, *optional*):
The file's git OID. This attribute is present when `files_metadata` argument
of [`repo_info`] is set to `True`. It's `None` otherwise.
The file's git OID. This attribute is present when `files_metadata` argument of [`repo_info`] is set to
`True`. It's `None` otherwise.
lfs (`BlobLfsInfo`, *optional*):
The file's LFS metadata. This attribute is present when`files_metadata` argument
of [`repo_info`] is set to `True` and the file is stored with Git LFS. It's `None` otherwise.
The file's LFS metadata. This attribute is present when`files_metadata` argument of [`repo_info`] is set to
`True` and the file is stored with Git LFS. It's `None` otherwise.
"""

def __init__(
Expand Down Expand Up @@ -1781,6 +1780,141 @@ def repo_info(
files_metadata=files_metadata,
)

@validate_hf_hub_args
def list_files_info(
self,
repo_id: str,
paths: Union[List[str], str, None] = None,
*,
revision: Optional[str] = None,
repo_type: Optional[str] = None,
token: Optional[Union[bool, str]] = None,
) -> Iterable[RepoFile]:
"""
List files on a repo and get information about them.

Takes as input a list of paths. Those paths can be either files or folders. Two server endpoints are called:
1. POST "/paths-info" to get information about the provided paths. Called once.
2. GET "/tree?recursive=True" to paginate over the input folders. Called only if a folder path is provided as
input. Will be called multiple times to follow pagination.
If no path is provided as input, step 1. is ignored and all files from the repo are listed.

Args:
repo_id (`str`):
A namespace (user or an organization) and a repo name separated by a `/`.
paths (`Union[List[str], str, None]`, *optional*):
The paths to get information about. Paths to files are directly resolved. Paths to folders are resolved
recursively which means that information is returned about all files in the folder and its subfolders.
If `None`, all files are returned (the default). If a path do not exist, it is ignored without raising
an exception.
revision (`str`, *optional*):
The revision of the repository from which to get the information. Defaults to `"main"` branch.
repo_type (`str`, *optional*):
The type of the repository from which to get the information (`"model"`, `"dataset"` or `"space"`.
Defaults to `"model"`.
token (`bool` or `str`, *optional*):
A valid authentication token (see https://huggingface.co/settings/token). If `None` or `True` and
machine is logged in (through `huggingface-cli login` or [`~huggingface_hub.login`]), token will be
retrieved from the cache. If `False`, token is not sent in the request header.

Returns:
`Iterable[RepoFile]`:
The information about the files, as an iterable of [`RepoFile`] objects. The order of the files is
not guaranteed.

Raises:
[`~utils.RepositoryNotFoundError`]:
If repository is not found (error 404): wrong repo_id/repo_type, private but not authenticated or repo
does not exist.
[`~utils.RevisionNotFoundError`]:
If revision is not found (error 404) on the repo.

Examples:

Get information about files on a repo.
```py
>>> from huggingface_hub import list_files_info
>>> files_info = list_files_info("lysandre/arxiv-nlp", ["README.md", "config.json"])
>>> files_info
<generator object HfApi.list_files_info at 0x7f93b848e730>
>>> list(files_info)
[
RepoFile: {"blob_id": "43bd404b159de6fba7c2f4d3264347668d43af25", "lfs": None, "rfilename": "README.md", "size": 391},
RepoFile: {"blob_id": "2f9618c3a19b9a61add74f70bfb121335aeef666", "lfs": None, "rfilename": "config.json", "size": 554},
]
```

List LFS files from the "vae/" folder in "stabilityai/stable-diffusion-2" repository.

```py
>>> from huggingface_hub import list_files_info
>>> [info.rfilename for info in list_files_info("stabilityai/stable-diffusion-2", "vae") if info.lfs is not None]
['vae/diffusion_pytorch_model.bin', 'vae/diffusion_pytorch_model.safetensors']
```

List all files on a repo.
```py
>>> from huggingface_hub import list_files_info
>>> [info.rfilename for info in list_files_info("glue", repo_type="dataset")]
['.gitattributes', 'README.md', 'dataset_infos.json', 'glue.py']
```
"""
repo_type = repo_type or REPO_TYPE_MODEL
revision = quote(revision, safe="") if revision is not None else DEFAULT_REVISION
headers = self._build_hf_headers(token=token)

def _format_as_repo_file(info: Dict) -> RepoFile:
# Quick alias very specific to the server return type of /paths-info and /tree endpoints. Let's keep this
# logic here.
rfilename = info.pop("path")
size = info.pop("size")
blobId = info.pop("oid")
lfs = info.pop("lfs", None)
info.pop("type", None) # "file" or "folder" -> not needed in practice since we know it's a file
# "lastCommit": behavior might change server-side in the near future (it might become optional)
# In the meantime, let's remove it so that users don't expect it
# TODO: set it back when https://github.com/huggingface/moon-landing/issues/5993 is settled
info.pop("lastCommit", None)
if lfs is not None:
lfs = BlobLfsInfo(size=lfs["size"], sha256=lfs["oid"], pointer_size=lfs["pointerSize"])
return RepoFile(rfilename=rfilename, size=size, blobId=blobId, lfs=lfs, **info)

folder_paths = []
if paths is None:
# `paths` is not provided => list all files from the repo
folder_paths.append("")
elif paths == []:
# corner case: server would return a 400 error if `paths` is an empty list. Let's return early.
return
else:
# `paths` is provided => get info about those
response = get_session().post(
f"{self.endpoint}/api/{repo_type}s/{repo_id}/paths-info/{revision}",
data={
"paths": paths if isinstance(paths, list) else [paths],
# "expand": True, # TODO: related to "lastCommit" (see above). Do not return it for now.
},
headers=headers,
)
hf_raise_for_status(response)
paths_info = response.json()

# List top-level files first
for path_info in paths_info:
if path_info["type"] == "file":
yield _format_as_repo_file(path_info)
else:
folder_paths.append(path_info["path"])

# List files in subdirectories
for path in folder_paths:
encoded_path = "/" + quote(path, safe="") if path else ""
tree_url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/tree/{revision}{encoded_path}"
for subpath_info in paginate(path=tree_url, headers=headers, params={"recursive": True}):
if subpath_info["type"] == "file":
yield _format_as_repo_file(subpath_info)

@_deprecate_arguments(version="0.17", deprecated_args=["timeout"], custom_message="timeout is not used anymore.")
@validate_hf_hub_args
def list_repo_files(
self,
Expand All @@ -1796,35 +1930,26 @@ def list_repo_files(

Args:
repo_id (`str`):
A namespace (user or an organization) and a repo name separated
by a `/`.
A namespace (user or an organization) and a repo name separated by a `/`.
revision (`str`, *optional*):
The revision of the model repository from which to get the
information.
The revision of the model repository from which to get the information.
repo_type (`str`, *optional*):
Set to `"dataset"` or `"space"` if uploading to a dataset or
space, `None` or `"model"` if uploading to a model. Default is
`None`.
timeout (`float`, *optional*):
Whether to set a timeout for the request to the Hub.
Set to `"dataset"` or `"space"` if uploading to a dataset or space, `None` or `"model"` if uploading to
a model. Default is `None`.
token (`bool` or `str`, *optional*):
A valid authentication token (see https://huggingface.co/settings/token).
If `None` or `True` and machine is logged in (through `huggingface-cli login`
or [`~huggingface_hub.login`]), token will be retrieved from the cache.
If `False`, token is not sent in the request header.
A valid authentication token (see https://huggingface.co/settings/token). If `None` or `True` and
machine is logged in (through `huggingface-cli login` or [`~huggingface_hub.login`]), token will be
retrieved from the cache. If `False`, token is not sent in the request header.

Returns:
`List[str]`: the list of files in a given repository.
"""
# TODO: use https://huggingface.co/api/{repo_type}/{repo_id}/tree/{revision}/{subfolder}
repo_info = self.repo_info(
repo_id,
revision=revision,
repo_type=repo_type,
token=token,
timeout=timeout,
)
return [f.rfilename for f in repo_info.siblings]
return [
f.rfilename
for f in self.list_files_info(
repo_id=repo_id, paths=None, revision=revision, repo_type=repo_type, token=token
)
]

@validate_hf_hub_args
def list_repo_refs(
Expand Down Expand Up @@ -4296,6 +4421,7 @@ def _parse_revision_from_pr_url(pr_url: str) -> str:
list_repo_files = api.list_repo_files
list_repo_refs = api.list_repo_refs
list_repo_commits = api.list_repo_commits
list_files_info = api.list_files_info

list_metrics = api.list_metrics

Expand Down
4 changes: 3 additions & 1 deletion src/huggingface_hub/hf_file_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,12 +291,14 @@ def ls(
return out if detail else [o["name"] for o in out]

def _iter_tree(self, path: str, revision: Optional[str] = None):
# TODO: use HfApi.list_files_info instead when it supports "lastCommit" and "expand=True"
# See https://github.com/huggingface/moon-landing/issues/5993
resolved_path = self.resolve_path(path, revision=revision)
path = f"{self._api.endpoint}/api/{resolved_path.repo_type}s/{resolved_path.repo_id}/tree/{safe_quote(resolved_path.revision)}/{resolved_path.path_in_repo}".rstrip(
"/"
)
headers = self._api._build_hf_headers()
yield from paginate(path, params={}, headers=headers)
yield from paginate(path, params={"expand": True}, headers=headers)

def cp_file(self, path1: str, path2: str, revision: Optional[str] = None, **kwargs) -> None:
resolved_path1 = self.resolve_path(path1, revision=revision)
Expand Down
110 changes: 110 additions & 0 deletions tests/test_hf_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1031,6 +1031,116 @@ def test_create_commit_failing_implicit_delete_folder(self):
)


class HfApiListFilesInfoTest(HfApiCommonTest):
@classmethod
def setUpClass(cls):
super().setUpClass()
cls.repo_id = cls._api.create_repo(repo_id=repo_name()).repo_id

cls._api.create_commit(
repo_id=cls.repo_id,
commit_message="A first repo",
operations=[
CommitOperationAdd(path_or_fileobj=b"data", path_in_repo="file.md"),
CommitOperationAdd(path_or_fileobj=b"data", path_in_repo="lfs.bin"),
CommitOperationAdd(path_or_fileobj=b"data", path_in_repo="1/file_1.md"),
CommitOperationAdd(path_or_fileobj=b"data", path_in_repo="1/2/file_1_2.md"),
CommitOperationAdd(path_or_fileobj=b"data", path_in_repo="2/file_2.md"),
],
)

cls._api.create_commit(
repo_id=cls.repo_id,
commit_message="Another commit",
operations=[
CommitOperationAdd(path_or_fileobj=b"data2", path_in_repo="3/file_3.md"),
],
)

@classmethod
def tearDownClass(cls):
cls._api.delete_repo(repo_id=cls.repo_id)

def test_get_regular_file_info(self):
files = list(self._api.list_files_info(repo_id=self.repo_id, paths="file.md"))
self.assertEqual(len(files), 1)
file = files[0]

self.assertEqual(file.rfilename, "file.md")
self.assertIsNone(file.lfs)
self.assertEqual(file.size, 4)
self.assertEqual(file.blob_id, "6320cd248dd8aeaab759d5871f8781b5c0505172")

def test_get_lfs_file_info(self):
files = list(self._api.list_files_info(repo_id=self.repo_id, paths="lfs.bin"))
self.assertEqual(len(files), 1)
file = files[0]

self.assertEqual(file.rfilename, "lfs.bin")
self.assertEqual(
file.lfs,
{
"size": 4,
"sha256": "3a6eb0790f39ac87c94f3856b2dd2c5d110e6811602261a9a923d3bb23adc8b7",
"pointer_size": 126,
},
)
self.assertEqual(file.size, 4)
self.assertEqual(file.blob_id, "0a828055346279420bd02a4221c177bbcdc045d8")

def test_list_files(self):
files = list(self._api.list_files_info(repo_id=self.repo_id, paths=["file.md", "lfs.bin", "2/file_2.md"]))
self.assertEqual(len(files), 3)
self.assertEqual({f.rfilename for f in files}, {"file.md", "lfs.bin", "2/file_2.md"})

def test_list_files_and_folder(self):
files = list(self._api.list_files_info(repo_id=self.repo_id, paths=["file.md", "lfs.bin", "2"]))
self.assertEqual(len(files), 3)
self.assertEqual({f.rfilename for f in files}, {"file.md", "lfs.bin", "2/file_2.md"})

def test_list_unknown_path_among_other(self):
files = list(self._api.list_files_info(repo_id=self.repo_id, paths=["file.md", "unknown"]))
self.assertEqual(len(files), 1)

def test_list_unknown_path_alone(self):
files = list(self._api.list_files_info(repo_id=self.repo_id, paths="unknown"))
self.assertEqual(len(files), 0)

def test_list_folder_flat(self):
files = list(self._api.list_files_info(repo_id=self.repo_id, paths=["2"]))
self.assertEqual(len(files), 1)
self.assertEqual(files[0].rfilename, "2/file_2.md")

def test_list_folder_recursively(self):
files = list(self._api.list_files_info(repo_id=self.repo_id, paths=["1"]))
self.assertEqual(len(files), 2)
self.assertEqual({f.rfilename for f in files}, {"1/2/file_1_2.md", "1/file_1.md"})

def test_list_repo_files_manually(self):
files = list(self._api.list_files_info(repo_id=self.repo_id))
self.assertEqual(len(files), 7)
self.assertEqual(
{f.rfilename for f in files},
{".gitattributes", "1/2/file_1_2.md", "1/file_1.md", "2/file_2.md", "3/file_3.md", "file.md", "lfs.bin"},
)

def test_list_repo_files_alias(self):
self.assertEqual(
set(self._api.list_repo_files(repo_id=self.repo_id)),
{".gitattributes", "1/2/file_1_2.md", "1/file_1.md", "2/file_2.md", "3/file_3.md", "file.md", "lfs.bin"},
)

def test_list_with_root_path_is_ignored(self):
# must use `paths=None`
files = list(self._api.list_files_info(repo_id=self.repo_id, paths="/"))
self.assertEqual(len(files), 0)

def test_list_with_empty_path_is_invalid(self):
# must use `paths=None`
with self.assertRaises(BadRequestError):
list(self._api.list_files_info(repo_id=self.repo_id, paths=""))


class HfApiTagEndpointTest(HfApiCommonTest):
@retry_endpoint
@use_tmp_repo("model")
Expand Down