diff --git a/src/huggingface_hub/__init__.py b/src/huggingface_hub/__init__.py
index a33cffedf8..01efac8676 100644
--- a/src/huggingface_hub/__init__.py
+++ b/src/huggingface_hub/__init__.py
@@ -155,6 +155,7 @@
         "get_space_runtime",
         "like",
         "list_datasets",
+        "list_files_info",
         "list_liked_repos",
         "list_metrics",
         "list_models",
@@ -434,6 +435,7 @@ def __dir__():
         get_space_runtime,  # noqa: F401
         like,  # noqa: F401
         list_datasets,  # noqa: F401
+        list_files_info,  # noqa: F401
         list_liked_repos,  # noqa: F401
         list_metrics,  # noqa: F401
         list_models,  # noqa: F401
diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py
index 36bf6adaba..fe464a5d86 100644
--- a/src/huggingface_hub/hf_api.py
+++ b/src/huggingface_hub/hf_api.py
@@ -189,6 +189,7 @@ def repo_type_and_id_from_hf_id(hf_id: str, hub_url: Optional[str] = None) -> Tu
 class BlobLfsInfo(TypedDict, total=False):
     size: int
     sha256: str
+    pointer_size: int
 
 
 @dataclass
@@ -309,23 +310,21 @@ def __repr__(self) -> str:
 
 class RepoFile(ReprMixin):
     """
-    Data structure that represents a public file inside a repo, accessible from
-    huggingface.co
+    Data structure that represents a public file inside a repo, accessible from huggingface.co
 
     Args:
         rfilename (str):
-            file name, relative to the repo root. This is the only attribute
-            that's guaranteed to be here, but under certain conditions there can
-            certain other stuff.
+            file name, relative to the repo root. This is the only attribute that's guaranteed to be here, but under
+            certain conditions there can be certain other attributes.
         size (`int`, *optional*):
-            The file's size, in bytes. This attribute is present when `files_metadata` argument
-            of [`repo_info`] is set to `True`. It's `None` otherwise.
+            The file's size, in bytes. This attribute is present when `files_metadata` argument of [`repo_info`] is set
+            to `True`. It's `None` otherwise.
         blob_id (`str`, *optional*):
-            The file's git OID. This attribute is present when `files_metadata` argument
-            of [`repo_info`] is set to `True`. It's `None` otherwise.
+            The file's git OID. This attribute is present when `files_metadata` argument of [`repo_info`] is set to
+            `True`. It's `None` otherwise.
         lfs (`BlobLfsInfo`, *optional*):
-            The file's LFS metadata. This attribute is present when`files_metadata` argument
-            of [`repo_info`] is set to `True` and the file is stored with Git LFS. It's `None` otherwise.
+            The file's LFS metadata. This attribute is present when `files_metadata` argument of [`repo_info`] is set to
+            `True` and the file is stored with Git LFS. It's `None` otherwise.
     """
 
     def __init__(
@@ -1781,6 +1780,141 @@ def repo_info(
             files_metadata=files_metadata,
         )
 
+    @validate_hf_hub_args
+    def list_files_info(
+        self,
+        repo_id: str,
+        paths: Union[List[str], str, None] = None,
+        *,
+        revision: Optional[str] = None,
+        repo_type: Optional[str] = None,
+        token: Optional[Union[bool, str]] = None,
+    ) -> Iterable[RepoFile]:
+        """
+        List files on a repo and get information about them.
+
+        Takes as input a list of paths. Those paths can be either files or folders. Two server endpoints are called:
+        1. POST "/paths-info" to get information about the provided paths. Called once.
+        2. GET "/tree?recursive=True" to paginate over the input folders. Called only if a folder path is provided as
+           input. Will be called multiple times to follow pagination.
+        If no path is provided as input, step 1 is skipped and all files from the repo are listed.
+
+        Args:
+            repo_id (`str`):
+                A namespace (user or an organization) and a repo name separated by a `/`.
+            paths (`Union[List[str], str, None]`, *optional*):
+                The paths to get information about. Paths to files are directly resolved. Paths to folders are resolved
+                recursively, which means that information is returned about all files in the folder and its subfolders.
+                If `None`, all files are returned (the default). If a path does not exist, it is ignored without raising
+                an exception.
+            revision (`str`, *optional*):
+                The revision of the repository from which to get the information. Defaults to `"main"` branch.
+            repo_type (`str`, *optional*):
+                The type of the repository from which to get the information (`"model"`, `"dataset"` or `"space"`).
+                Defaults to `"model"`.
+            token (`bool` or `str`, *optional*):
+                A valid authentication token (see https://huggingface.co/settings/token). If `None` or `True` and
+                machine is logged in (through `huggingface-cli login` or [`~huggingface_hub.login`]), token will be
+                retrieved from the cache. If `False`, token is not sent in the request header.
+
+        Returns:
+            `Iterable[RepoFile]`:
+                The information about the files, as an iterable of [`RepoFile`] objects. The order of the files is
+                not guaranteed.
+
+        Raises:
+            [`~utils.RepositoryNotFoundError`]:
+                If repository is not found (error 404): wrong repo_id/repo_type, private but not authenticated or repo
+                does not exist.
+            [`~utils.RevisionNotFoundError`]:
+                If revision is not found (error 404) on the repo.
+
+        Examples:
+
+        Get information about files on a repo.
+        ```py
+        >>> from huggingface_hub import list_files_info
+        >>> files_info = list_files_info("lysandre/arxiv-nlp", ["README.md", "config.json"])
+        >>> files_info
+        <generator object HfApi.list_files_info at 0x...>
+        >>> list(files_info)
+        [
+            RepoFile: {"blob_id": "43bd404b159de6fba7c2f4d3264347668d43af25", "lfs": None, "rfilename": "README.md", "size": 391},
+            RepoFile: {"blob_id": "2f9618c3a19b9a61add74f70bfb121335aeef666", "lfs": None, "rfilename": "config.json", "size": 554},
+        ]
+        ```
+
+        List LFS files from the "vae/" folder in the "stabilityai/stable-diffusion-2" repository.
+
+        ```py
+        >>> from huggingface_hub import list_files_info
+        >>> [info.rfilename for info in list_files_info("stabilityai/stable-diffusion-2", "vae") if info.lfs is not None]
+        ['vae/diffusion_pytorch_model.bin', 'vae/diffusion_pytorch_model.safetensors']
+        ```
+
+        List all files on a repo.
+        ```py
+        >>> from huggingface_hub import list_files_info
+        >>> [info.rfilename for info in list_files_info("glue", repo_type="dataset")]
+        ['.gitattributes', 'README.md', 'dataset_infos.json', 'glue.py']
+        ```
+        """
+        repo_type = repo_type or REPO_TYPE_MODEL
+        revision = quote(revision, safe="") if revision is not None else DEFAULT_REVISION
+        headers = self._build_hf_headers(token=token)
+
+        def _format_as_repo_file(info: Dict) -> RepoFile:
+            # Quick alias very specific to the server return type of /paths-info and /tree endpoints. Let's keep this
+            # logic here.
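+            # Expected payload keys: "path", "size" and "oid", plus optional "lfs", "type" and "lastCommit" entries,
+            # all handled below.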
+            rfilename = info.pop("path")
+            size = info.pop("size")
+            blobId = info.pop("oid")
+            lfs = info.pop("lfs", None)
+            info.pop("type", None)  # "file" or "folder" -> not needed in practice since we know it's a file
+            # "lastCommit": behavior might change server-side in the near future (it might become optional)
+            # In the meantime, let's remove it so that users don't expect it
+            # TODO: set it back when https://github.com/huggingface/moon-landing/issues/5993 is settled
+            info.pop("lastCommit", None)
+            if lfs is not None:
+                lfs = BlobLfsInfo(size=lfs["size"], sha256=lfs["oid"], pointer_size=lfs["pointerSize"])
+            return RepoFile(rfilename=rfilename, size=size, blobId=blobId, lfs=lfs, **info)
+
+        folder_paths = []
+        if paths is None:
+            # `paths` is not provided => list all files from the repo
+            folder_paths.append("")
+        elif paths == []:
+            # corner case: server would return a 400 error if `paths` is an empty list. Let's return early.
+            return
+        else:
+            # `paths` is provided => get info about those
+            response = get_session().post(
+                f"{self.endpoint}/api/{repo_type}s/{repo_id}/paths-info/{revision}",
+                data={
+                    "paths": paths if isinstance(paths, list) else [paths],
+                    # "expand": True,  # TODO: related to "lastCommit" (see above). Do not return it for now.
+                },
+                headers=headers,
+            )
+            hf_raise_for_status(response)
+            paths_info = response.json()
+
+            # List top-level files first
+            for path_info in paths_info:
+                if path_info["type"] == "file":
+                    yield _format_as_repo_file(path_info)
+                else:
+                    folder_paths.append(path_info["path"])
+
+        # List files in subdirectories
+        for path in folder_paths:
+            encoded_path = "/" + quote(path, safe="") if path else ""
+            tree_url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/tree/{revision}{encoded_path}"
+            for subpath_info in paginate(path=tree_url, headers=headers, params={"recursive": True}):
+                if subpath_info["type"] == "file":
+                    yield _format_as_repo_file(subpath_info)
+
+    @_deprecate_arguments(version="0.17", deprecated_args=["timeout"], custom_message="timeout is not used anymore.")
     @validate_hf_hub_args
     def list_repo_files(
         self,
@@ -1796,35 +1930,26 @@ def list_repo_files(
 
         Args:
             repo_id (`str`):
-                A namespace (user or an organization) and a repo name separated
-                by a `/`.
+                A namespace (user or an organization) and a repo name separated by a `/`.
             revision (`str`, *optional*):
-                The revision of the model repository from which to get the
-                information.
+                The revision of the model repository from which to get the information.
             repo_type (`str`, *optional*):
-                Set to `"dataset"` or `"space"` if uploading to a dataset or
-                space, `None` or `"model"` if uploading to a model. Default is
-                `None`.
-            timeout (`float`, *optional*):
-                Whether to set a timeout for the request to the Hub.
+                Set to `"dataset"` or `"space"` if listing files from a dataset or space, `None` or `"model"` if
+                listing from a model. Default is `None`.
             token (`bool` or `str`, *optional*):
-                A valid authentication token (see https://huggingface.co/settings/token).
-                If `None` or `True` and machine is logged in (through `huggingface-cli login`
-                or [`~huggingface_hub.login`]), token will be retrieved from the cache.
-                If `False`, token is not sent in the request header.
+                A valid authentication token (see https://huggingface.co/settings/token). If `None` or `True` and
+                machine is logged in (through `huggingface-cli login` or [`~huggingface_hub.login`]), token will be
+                retrieved from the cache. If `False`, token is not sent in the request header.
 
         Returns:
             `List[str]`: the list of files in a given repository.
         """
-        # TODO: use https://huggingface.co/api/{repo_type}/{repo_id}/tree/{revision}/{subfolder}
-        repo_info = self.repo_info(
-            repo_id,
-            revision=revision,
-            repo_type=repo_type,
-            token=token,
-            timeout=timeout,
-        )
-        return [f.rfilename for f in repo_info.siblings]
+        return [
+            f.rfilename
+            for f in self.list_files_info(
+                repo_id=repo_id, paths=None, revision=revision, repo_type=repo_type, token=token
+            )
+        ]
 
     @validate_hf_hub_args
     def list_repo_refs(
@@ -4296,6 +4421,7 @@ def _parse_revision_from_pr_url(pr_url: str) -> str:
 list_repo_files = api.list_repo_files
 list_repo_refs = api.list_repo_refs
 list_repo_commits = api.list_repo_commits
+list_files_info = api.list_files_info
 list_metrics = api.list_metrics
diff --git a/src/huggingface_hub/hf_file_system.py b/src/huggingface_hub/hf_file_system.py
index 25235f6183..c8d7d43b3b 100644
--- a/src/huggingface_hub/hf_file_system.py
+++ b/src/huggingface_hub/hf_file_system.py
@@ -291,12 +291,14 @@ def ls(
         return out if detail else [o["name"] for o in out]
 
     def _iter_tree(self, path: str, revision: Optional[str] = None):
+        # TODO: use HfApi.list_files_info instead when it supports "lastCommit" and "expand=True"
+        # See https://github.com/huggingface/moon-landing/issues/5993
         resolved_path = self.resolve_path(path, revision=revision)
         path = f"{self._api.endpoint}/api/{resolved_path.repo_type}s/{resolved_path.repo_id}/tree/{safe_quote(resolved_path.revision)}/{resolved_path.path_in_repo}".rstrip(
             "/"
         )
         headers = self._api._build_hf_headers()
-        yield from paginate(path, params={}, headers=headers)
+        yield from paginate(path, params={"expand": True}, headers=headers)
 
     def cp_file(self, path1: str, path2: str, revision: Optional[str] = None, **kwargs) -> None:
         resolved_path1 = self.resolve_path(path1, revision=revision)
diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py
index 131a5f1c72..a9d1e9eacf 100644
--- a/tests/test_hf_api.py
+++ b/tests/test_hf_api.py
@@ -1031,6 +1031,116 @@ def test_create_commit_failing_implicit_delete_folder(self):
         )
 
 
+class HfApiListFilesInfoTest(HfApiCommonTest):
+    @classmethod
+    def setUpClass(cls):
+        super().setUpClass()
+        cls.repo_id = cls._api.create_repo(repo_id=repo_name()).repo_id
+
+        cls._api.create_commit(
+            repo_id=cls.repo_id,
+            commit_message="A first repo",
+            operations=[
+                CommitOperationAdd(path_or_fileobj=b"data", path_in_repo="file.md"),
+                CommitOperationAdd(path_or_fileobj=b"data", path_in_repo="lfs.bin"),
+                CommitOperationAdd(path_or_fileobj=b"data", path_in_repo="1/file_1.md"),
+                CommitOperationAdd(path_or_fileobj=b"data", path_in_repo="1/2/file_1_2.md"),
+                CommitOperationAdd(path_or_fileobj=b"data", path_in_repo="2/file_2.md"),
+            ],
+        )
+
+        cls._api.create_commit(
+            repo_id=cls.repo_id,
+            commit_message="Another commit",
+            operations=[
+                CommitOperationAdd(path_or_fileobj=b"data2", path_in_repo="3/file_3.md"),
+            ],
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        cls._api.delete_repo(repo_id=cls.repo_id)
+
+    def test_get_regular_file_info(self):
+        files = list(self._api.list_files_info(repo_id=self.repo_id, paths="file.md"))
+        self.assertEqual(len(files), 1)
+        file = files[0]
+
+        self.assertEqual(file.rfilename, "file.md")
+        self.assertIsNone(file.lfs)
+        self.assertEqual(file.size, 4)
+        self.assertEqual(file.blob_id, "6320cd248dd8aeaab759d5871f8781b5c0505172")
+
+    def test_get_lfs_file_info(self):
+        files = list(self._api.list_files_info(repo_id=self.repo_id, paths="lfs.bin"))
+        self.assertEqual(len(files), 1)
+        file = files[0]
+
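+        # `lfs.bin` was committed with content b"data" (4 bytes), which the assertions below rely on.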
+        self.assertEqual(file.rfilename, "lfs.bin")
+        self.assertEqual(
+            file.lfs,
+            {
+                "size": 4,
+                "sha256": "3a6eb0790f39ac87c94f3856b2dd2c5d110e6811602261a9a923d3bb23adc8b7",
+                "pointer_size": 126,
+            },
+        )
+        self.assertEqual(file.size, 4)
+        self.assertEqual(file.blob_id, "0a828055346279420bd02a4221c177bbcdc045d8")
+
+    def test_list_files(self):
+        files = list(self._api.list_files_info(repo_id=self.repo_id, paths=["file.md", "lfs.bin", "2/file_2.md"]))
+        self.assertEqual(len(files), 3)
+        self.assertEqual({f.rfilename for f in files}, {"file.md", "lfs.bin", "2/file_2.md"})
+
+    def test_list_files_and_folder(self):
+        files = list(self._api.list_files_info(repo_id=self.repo_id, paths=["file.md", "lfs.bin", "2"]))
+        self.assertEqual(len(files), 3)
+        self.assertEqual({f.rfilename for f in files}, {"file.md", "lfs.bin", "2/file_2.md"})
+
+    def test_list_unknown_path_among_other(self):
+        files = list(self._api.list_files_info(repo_id=self.repo_id, paths=["file.md", "unknown"]))
+        self.assertEqual(len(files), 1)
+
+    def test_list_unknown_path_alone(self):
+        files = list(self._api.list_files_info(repo_id=self.repo_id, paths="unknown"))
+        self.assertEqual(len(files), 0)
+
+    def test_list_folder_flat(self):
+        files = list(self._api.list_files_info(repo_id=self.repo_id, paths=["2"]))
+        self.assertEqual(len(files), 1)
+        self.assertEqual(files[0].rfilename, "2/file_2.md")
+
+    def test_list_folder_recursively(self):
+        files = list(self._api.list_files_info(repo_id=self.repo_id, paths=["1"]))
+        self.assertEqual(len(files), 2)
+        self.assertEqual({f.rfilename for f in files}, {"1/2/file_1_2.md", "1/file_1.md"})
+
+    def test_list_repo_files_manually(self):
+        files = list(self._api.list_files_info(repo_id=self.repo_id))
+        self.assertEqual(len(files), 7)
+        self.assertEqual(
+            {f.rfilename for f in files},
+            {".gitattributes", "1/2/file_1_2.md", "1/file_1.md", "2/file_2.md", "3/file_3.md", "file.md", "lfs.bin"},
+        )
+
+    def test_list_repo_files_alias(self):
+        self.assertEqual(
+            set(self._api.list_repo_files(repo_id=self.repo_id)),
+            {".gitattributes", "1/2/file_1_2.md", "1/file_1.md", "2/file_2.md", "3/file_3.md", "file.md", "lfs.bin"},
+        )
+
+    def test_list_with_root_path_is_ignored(self):
+        # must use `paths=None`
+        files = list(self._api.list_files_info(repo_id=self.repo_id, paths="/"))
+        self.assertEqual(len(files), 0)
+
+    def test_list_with_empty_path_is_invalid(self):
+        # must use `paths=None`
+        with self.assertRaises(BadRequestError):
+            list(self._api.list_files_info(repo_id=self.repo_id, paths=""))
+
+
 class HfApiTagEndpointTest(HfApiCommonTest):
     @retry_endpoint
     @use_tmp_repo("model")