diff --git a/docs/source/_toctree.yml b/docs/source/_toctree.yml index de564fc48a..7ffd9e02c6 100644 --- a/docs/source/_toctree.yml +++ b/docs/source/_toctree.yml @@ -18,6 +18,8 @@ title: Repository - local: guides/search title: Search + - local: guides/hf_file_system + title: HfFileSystem - local: guides/inference title: Inference - local: guides/community @@ -52,6 +54,8 @@ title: Mixins & serialization methods - local: package_reference/inference_api title: Inference API + - local: package_reference/hf_file_system + title: HfFileSystem - local: package_reference/utilities title: Utilities - local: package_reference/community diff --git a/docs/source/guides/hf_file_system.mdx b/docs/source/guides/hf_file_system.mdx new file mode 100644 index 0000000000..7d0d5581a3 --- /dev/null +++ b/docs/source/guides/hf_file_system.mdx @@ -0,0 +1,105 @@
+# Interact with the Hub through the Filesystem API
+
+In addition to the [`HfApi`], the `huggingface_hub` library provides [`HfFileSystem`], a pythonic [fsspec-compatible](https://filesystem-spec.readthedocs.io/en/latest/) file interface to the Hugging Face Hub. The [`HfFileSystem`] builds on top of the [`HfApi`] and offers typical filesystem-style operations like `cp`, `mv`, `ls`, `du`, `glob`, `get_file`, and `put_file`.
+
+## Usage
+
+```python
+>>> from huggingface_hub import HfFileSystem
+>>> fs = HfFileSystem()
+
+>>> # List all files in a directory
+>>> fs.ls("datasets/my-username/my-dataset-repo/data", detail=False)
+['datasets/my-username/my-dataset-repo/data/train.csv', 'datasets/my-username/my-dataset-repo/data/test.csv']
+
+>>> # List all ".csv" files in a repo
+>>> fs.glob("datasets/my-username/my-dataset-repo/**.csv")
+['datasets/my-username/my-dataset-repo/data/train.csv', 'datasets/my-username/my-dataset-repo/data/test.csv']
+
+>>> # Read a remote file
+>>> with fs.open("datasets/my-username/my-dataset-repo/data/train.csv", "r") as f:
+...     train_data = f.readlines()
+
+>>> # Read the content of a remote file as a string
+>>> train_data = fs.read_text("datasets/my-username/my-dataset-repo/data/train.csv", revision="dev")
+
+>>> # Write a remote file
+>>> with fs.open("datasets/my-username/my-dataset-repo/data/validation.csv", "w") as f:
+...     f.write("text,label")
+...     f.write("Fantastic movie!,good")
+```
+
+The optional `revision` argument can be passed to run an operation from a specific revision, such as a branch, a tag name, or a commit hash.
+
+Unlike Python's built-in `open`, `fsspec`'s `open` defaults to binary mode, `"rb"`. This means you must explicitly set the mode to `"r"` for reading and `"w"` for writing in text mode. Appending to a file (modes `"a"` and `"ab"`) is not supported yet.
+
+## Integrations
+
+The [`HfFileSystem`] can be used with any library that integrates `fsspec`, provided the URL follows the scheme:
+
+```
+hf://[<repo_type_prefix>]<repo_id>[@<revision>]/<path/in/repo>
+```
+
+The `repo_type_prefix` is `datasets/` for datasets, `spaces/` for spaces, and models don't need a prefix in the URL.
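+
+For example, such a URL can be passed directly to `fsspec.open` (a minimal sketch, assuming the `hf` protocol is registered with `fsspec` and that the placeholder repo `my-username/my-dataset-repo` exists):
+
+```python
+>>> import fsspec
+
+>>> # "datasets/" selects the dataset repo type; model repos need no prefix
+>>> with fsspec.open("hf://datasets/my-username/my-dataset-repo/data/train.csv", "r") as f:
+...     header = f.readline()
+```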
+ +Some interesting integrations where [`HfFileSystem`] simplifies interacting with the Hub are listed below: + +* Reading/writing a [Pandas](https://pandas.pydata.org/pandas-docs/stable/user_guide/io.html#reading-writing-remote-files) DataFrame from/to a Hub repository: + + ```python + >>> import pandas as pd + + >>> # Read a remote CSV file into a dataframe + >>> df = pd.read_csv("hf://datasets/my-username/my-dataset-repo/train.csv") + + >>> # Write a dataframe to a remote CSV file + >>> df.to_csv("hf://datasets/my-username/my-dataset-repo/test.csv") + ``` + +The same workflow can also be used for [Dask](https://docs.dask.org/en/stable/how-to/connect-to-remote-data.html) and [Polars](https://pola-rs.github.io/polars/py-polars/html/reference/io.html) DataFrames. + +* Querying (remote) Hub files with [DuckDB](https://duckdb.org/docs/guides/python/filesystems): + + ```python + >>> from huggingface_hub import HfFileSystem + >>> import duckdb + + >>> fs = HfFileSystem() + >>> duckdb.register_filesystem(fs) + >>> # Query a remote file and get the result back as a dataframe + >>> fs_query_file = "hf://datasets/my-username/my-dataset-repo/data_dir/data.parquet" + >>> df = duckdb.query(f"SELECT * FROM '{fs_query_file}' LIMIT 10").df() + ``` + +* Using the Hub as an array store with [Zarr](https://zarr.readthedocs.io/en/stable/tutorial.html#io-with-fsspec): + + ```python + >>> import numpy as np + >>> import zarr + + >>> embeddings = np.random.randn(50000, 1000).astype("float32") + + >>> # Write an array to a repo + >>> with zarr.open_group("hf://my-username/my-model-repo/array-store", mode="w") as root: + ... foo = root.create_group("embeddings") + ... foobar = foo.zeros('experiment_0', shape=(50000, 1000), chunks=(10000, 1000), dtype='f4') + ... foobar[:] = embeddings + + >>> # Read an array from a repo + >>> with zarr.open_group("hf://my-username/my-model-repo/array-store", mode="r") as root: + ... first_row = root["embeddings/experiment_0"][0] + ``` + +## Authentication + +In many cases, you must be logged in with a Hugging Face account to interact with the Hub. Refer to the [Login](../quick-start#login) section of the documentation to learn more about authentication methods on the Hub. + +It is also possible to login programmatically by passing your `token` as an argument to [`HfFileSystem`]: + +```python +>>> from huggingface_hub import HfFileSystem +>>> fs = HfFileSystem(token=token) +``` + +If you login this way, be careful not to accidentally leak the token when sharing your source code! diff --git a/docs/source/guides/overview.mdx b/docs/source/guides/overview.mdx index 96820925a5..6551c839d2 100644 --- a/docs/source/guides/overview.mdx +++ b/docs/source/guides/overview.mdx @@ -42,6 +42,15 @@ Take a look at these guides to learn how to use huggingface_hub to solve real-wo

+ +
+ HfFileSystem +

+ How to interact with the Hub through a convenient interface that mimics Python's file interface? +

+
+
diff --git a/docs/source/package_reference/hf_file_system.mdx b/docs/source/package_reference/hf_file_system.mdx new file mode 100644 index 0000000000..17c9258d75 --- /dev/null +++ b/docs/source/package_reference/hf_file_system.mdx @@ -0,0 +1,12 @@ +# Filesystem API + +The `HfFileSystem` class provides a pythonic file interface to the Hugging Face Hub based on [`fssepc`](https://filesystem-spec.readthedocs.io/en/latest/). + +## HfFileSystem + +`HfFileSystem` is based on [fsspec](https://filesystem-spec.readthedocs.io/en/latest/), so it is compatible with most of the APIs that it offers. For more details, check out [our guide](../guides/filesystem) and the fsspec's [API Reference](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.spec.AbstractFileSystem). + +[[autodoc]] HfFileSystem + - __init__ + - resolve_path + - ls diff --git a/setup.cfg b/setup.cfg index 5d4938d997..9cc27b091c 100644 --- a/setup.cfg +++ b/setup.cfg @@ -13,6 +13,7 @@ known_third_party = faiss-cpu fastprogress fire + fsspec fugashi git graphviz diff --git a/setup.py b/setup.py index 60cf5afbeb..947e5eb4ac 100644 --- a/setup.py +++ b/setup.py @@ -13,6 +13,7 @@ def get_version() -> str: install_requires = [ "filelock", + "fsspec", "requests", "tqdm>=4.42.1", "pyyaml>=5.1", diff --git a/src/huggingface_hub/__init__.py b/src/huggingface_hub/__init__.py index eb6dc7d1c1..5994e800c1 100644 --- a/src/huggingface_hub/__init__.py +++ b/src/huggingface_hub/__init__.py @@ -162,6 +162,11 @@ "upload_folder", "whoami", ], + "hf_file_system": [ + "HfFileSystem", + "HfFileSystemFile", + "HfFileSystemResolvedPath", + ], "hub_mixin": [ "ModelHubMixin", "PyTorchModelHubMixin", @@ -421,6 +426,11 @@ def __dir__(): upload_folder, # noqa: F401 whoami, # noqa: F401 ) + from .hf_file_system import ( + HfFileSystem, # noqa: F401 + HfFileSystemFile, # noqa: F401 + HfFileSystemResolvedPath, # noqa: F401 + ) from .hub_mixin import ( ModelHubMixin, # noqa: F401 PyTorchModelHubMixin, # noqa: F401 diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index 349f6c2f5d..36bf6adaba 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -64,6 +64,7 @@ filter_repo_objects, hf_raise_for_status, logging, + paginate, parse_datetime, validate_hf_hub_args, ) @@ -71,7 +72,6 @@ _deprecate_arguments, _deprecate_list_output, ) -from .utils._pagination import paginate from .utils._typing import Literal, TypedDict from .utils.endpoint_helpers import ( AttributeDictionary, diff --git a/src/huggingface_hub/hf_file_system.py b/src/huggingface_hub/hf_file_system.py new file mode 100644 index 0000000000..5d68740cb4 --- /dev/null +++ b/src/huggingface_hub/hf_file_system.py @@ -0,0 +1,439 @@ +import itertools +import os +import tempfile +from dataclasses import dataclass +from datetime import datetime +from glob import has_magic +from typing import Any, Dict, List, Optional, Tuple, Union +from urllib.parse import quote, unquote + +import fsspec +import requests + +from ._commit_api import CommitOperationDelete +from .constants import DEFAULT_REVISION, ENDPOINT, REPO_TYPE_MODEL, REPO_TYPES_MAPPING, REPO_TYPES_URL_PREFIXES +from .hf_api import HfApi +from .utils import ( + EntryNotFoundError, + HFValidationError, + RepositoryNotFoundError, + RevisionNotFoundError, + hf_raise_for_status, + http_backoff, + paginate, + parse_datetime, +) + + +@dataclass +class HfFileSystemResolvedPath: + """Data structure containing information about a resolved hffs path.""" + + repo_type: str + repo_id: str + revision: str + 
path_in_repo: str + + def unresolve(self) -> str: + return ( + f"{REPO_TYPES_URL_PREFIXES.get(self.repo_type, '') + self.repo_id}@{safe_quote(self.revision)}/{self.path_in_repo}" + .rstrip("/") + ) + + +class HfFileSystem(fsspec.AbstractFileSystem): + """ + Access a remote Hugging Face Hub repository as if were a local file system. + + Args: + endpoint (`str`, *optional*): + The endpoint to use. If not provided, the default one (https://huggingface.co) is used. + token (`str`, *optional*): + Authentication token, obtained with [`HfApi.login`] method. Will default to the stored token. + + Usage: + + ```python + >>> import hffs + + >>> fs = hffs.HfFileSystem() + + >>> # List files + >>> fs.glob("my-username/my-model/*.bin") + ['my-username/my-model/pytorch_model.bin'] + >>> fs.ls("datasets/my-username/my-dataset", detail=False) + ['datasets/my-username/my-dataset/.gitattributes', 'datasets/my-username/my-dataset/README.md', 'datasets/my-username/my-dataset/data.json'] + + >>> # Read/write files + >>> with fs.open("my-username/my-model/pytorch_model.bin") as f: + ... data = f.read() + >>> with fs.open("my-username/my-model/pytorch_model.bin", "wb") as f: + ... f.write(data) + ``` + """ + + root_marker = "" + protocol = "hf" + + def __init__( + self, + *args, + endpoint: Optional[str] = None, + token: Optional[str] = None, + **storage_options, + ): + super().__init__(*args, **storage_options) + self.endpoint = endpoint or ENDPOINT + self.token = token + self._api = HfApi(endpoint=endpoint, token=token) + # Maps (repo_type, repo_id, revision) to a 2-tuple with: + # * the 1st element indicating whether the repositoy and the revision exist + # * the 2nd element being the exception raised if the repository or revision doesn't exist + self._repo_and_revision_exists_cache: Dict[ + Tuple[str, str, Optional[str]], Tuple[bool, Optional[Exception]] + ] = {} + + def _repo_and_revision_exist( + self, repo_type: str, repo_id: str, revision: Optional[str] + ) -> Tuple[bool, Optional[Exception]]: + if (repo_type, repo_id, revision) not in self._repo_and_revision_exists_cache: + try: + self._api.repo_info(repo_id, revision=revision, repo_type=repo_type) + except (RepositoryNotFoundError, HFValidationError) as e: + self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] = False, e + self._repo_and_revision_exists_cache[(repo_type, repo_id, None)] = False, e + except RevisionNotFoundError as e: + self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] = False, e + self._repo_and_revision_exists_cache[(repo_type, repo_id, None)] = True, None + else: + self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] = True, None + self._repo_and_revision_exists_cache[(repo_type, repo_id, None)] = True, None + return self._repo_and_revision_exists_cache[(repo_type, repo_id, revision)] + + def resolve_path(self, path: str, revision: Optional[str] = None) -> HfFileSystemResolvedPath: + def _align_revision_in_path_with_revision( + revision_in_path: Optional[str], revision: Optional[str] + ) -> Optional[str]: + if revision is not None: + if revision_in_path is not None and revision_in_path != revision: + raise ValueError( + f'Revision specified in path ("{revision_in_path}") and in `revision` argument ("{revision}")' + " are not the same." 
+ ) + else: + revision = revision_in_path + return revision + + path = self._strip_protocol(path) + if not path: + # can't list repositories at root + raise NotImplementedError("Access to repositories lists is not implemented.") + elif path.split("/")[0] + "/" in REPO_TYPES_URL_PREFIXES.values(): + if "/" not in path: + # can't list repositories at the repository type level + raise NotImplementedError("Acces to repositories lists is not implemented.") + repo_type, path = path.split("/", 1) + repo_type = REPO_TYPES_MAPPING[repo_type] + else: + repo_type = REPO_TYPE_MODEL + if path.count("/") > 0: + if "@" in path: + repo_id, revision_in_path = path.split("@", 1) + if "/" in revision_in_path: + revision_in_path, path_in_repo = revision_in_path.split("/", 1) + else: + path_in_repo = "" + revision_in_path = unquote(revision_in_path) + revision = _align_revision_in_path_with_revision(revision_in_path, revision) + repo_and_revision_exist, err = self._repo_and_revision_exist(repo_type, repo_id, revision) + if not repo_and_revision_exist: + raise FileNotFoundError(path) from err + else: + repo_id_with_namespace = "/".join(path.split("/")[:2]) + path_in_repo_with_namespace = "/".join(path.split("/")[2:]) + repo_id_without_namespace = path.split("/")[0] + path_in_repo_without_namespace = "/".join(path.split("/")[1:]) + repo_id = repo_id_with_namespace + path_in_repo = path_in_repo_with_namespace + repo_and_revision_exist, err = self._repo_and_revision_exist(repo_type, repo_id, revision) + if not repo_and_revision_exist: + if isinstance(err, (RepositoryNotFoundError, HFValidationError)): + repo_id = repo_id_without_namespace + path_in_repo = path_in_repo_without_namespace + repo_and_revision_exist, _ = self._repo_and_revision_exist(repo_type, repo_id, revision) + if not repo_and_revision_exist: + raise FileNotFoundError(path) from err + else: + raise FileNotFoundError(path) from err + else: + repo_id = path + path_in_repo = "" + if "@" in path: + repo_id, revision_in_path = path.split("@", 1) + revision_in_path = unquote(revision_in_path) + revision = _align_revision_in_path_with_revision(revision_in_path, revision) + repo_and_revision_exist, _ = self._repo_and_revision_exist(repo_type, repo_id, revision) + if not repo_and_revision_exist: + raise NotImplementedError("Acces to repositories lists is not implemented.") + + revision = revision if revision is not None else DEFAULT_REVISION + return HfFileSystemResolvedPath(repo_type, repo_id, revision, path_in_repo) + + def invalidate_cache(self, path: Optional[str] = None) -> None: + if not path: + self.dircache.clear() + self._repository_type_and_id_exists_cache.clear() + else: + path = self.resolve_path(path).unresolve() + while path: + self.dircache.pop(path, None) + path = self._parent(path) + + def _open( + self, + path: str, + mode: str = "rb", + revision: Optional[str] = None, + **kwargs, + ) -> "HfFileSystemFile": + if mode == "ab": + raise NotImplementedError("Appending to remote files is not yet supported.") + return HfFileSystemFile(self, path, mode=mode, revision=revision, **kwargs) + + def _rm(self, path: str, revision: Optional[str] = None, **kwargs) -> None: + resolved_path = self.resolve_path(path, revision=revision) + self._api.delete_file( + path_in_repo=resolved_path.path_in_repo, + repo_id=resolved_path.repo_id, + token=self.token, + repo_type=resolved_path.repo_type, + revision=resolved_path.revision, + commit_message=kwargs.get("commit_message"), + commit_description=kwargs.get("commit_description"), + ) + 
self.invalidate_cache(path=resolved_path.unresolve()) + + def rm( + self, + path: str, + recursive: bool = False, + maxdepth: Optional[int] = None, + revision: Optional[str] = None, + **kwargs, + ) -> None: + resolved_path = self.resolve_path(path, revision=revision) + root_path = REPO_TYPES_URL_PREFIXES.get(resolved_path.repo_type, "") + resolved_path.repo_id + paths = self.expand_path(path, recursive=recursive, maxdepth=maxdepth, revision=resolved_path.revision) + paths_in_repo = [path[len(root_path) + 1 :] for path in paths if not self.isdir(path)] + operations = [CommitOperationDelete(path_in_repo=path_in_repo) for path_in_repo in paths_in_repo] + commit_message = f"Delete {path} " + commit_message += "recursively " if recursive else "" + commit_message += f"up to depth {maxdepth} " if maxdepth is not None else "" + # TODO: use `commit_description` to list all the deleted paths? + self._api.create_commit( + repo_id=resolved_path.repo_id, + repo_type=resolved_path.repo_type, + token=self.token, + operations=operations, + revision=resolved_path.revision, + commit_message=kwargs.get("commit_message", commit_message), + commit_description=kwargs.get("commit_description"), + ) + self.invalidate_cache(path=resolved_path.unresolve()) + + def ls( + self, path: str, detail: bool = True, refresh: bool = False, revision: Optional[str] = None, **kwargs + ) -> List[Union[str, Dict[str, Any]]]: + """List the contents of a directory.""" + resolved_path = self.resolve_path(path, revision=revision) + revision_in_path = "@" + safe_quote(resolved_path.revision) + has_revision_in_path = revision_in_path in path + path = resolved_path.unresolve() + if path not in self.dircache or refresh: + path_prefix = ( + HfFileSystemResolvedPath( + resolved_path.repo_type, resolved_path.repo_id, resolved_path.revision, "" + ).unresolve() + + "/" + ) + tree_path = path + tree_iter = self._iter_tree(tree_path, revision=resolved_path.revision) + try: + tree_item = next(tree_iter) + except EntryNotFoundError: + if "/" in resolved_path.path_in_repo: + tree_path = self._parent(path) + tree_iter = self._iter_tree(tree_path, revision=resolved_path.revision) + else: + raise + else: + tree_iter = itertools.chain([tree_item], tree_iter) + child_infos = [] + for tree_item in tree_iter: + child_info = { + "name": path_prefix + tree_item["path"], + "size": tree_item["size"], + "type": tree_item["type"], + } + if tree_item["type"] == "file": + child_info.update( + { + "blob_id": tree_item["oid"], + "lfs": tree_item.get("lfs"), + "last_modified": parse_datetime(tree_item["lastCommit"]["date"]), + }, + ) + child_infos.append(child_info) + self.dircache[tree_path] = child_infos + out = self._ls_from_cache(path) + if not has_revision_in_path: + out = [{**o, "name": o["name"].replace(revision_in_path, "", 1)} for o in out] + return out if detail else [o["name"] for o in out] + + def _iter_tree(self, path: str, revision: Optional[str] = None): + resolved_path = self.resolve_path(path, revision=revision) + path = f"{self._api.endpoint}/api/{resolved_path.repo_type}s/{resolved_path.repo_id}/tree/{safe_quote(resolved_path.revision)}/{resolved_path.path_in_repo}".rstrip( + "/" + ) + headers = self._api._build_hf_headers() + yield from paginate(path, params={}, headers=headers) + + def cp_file(self, path1: str, path2: str, revision: Optional[str] = None, **kwargs) -> None: + resolved_path1 = self.resolve_path(path1, revision=revision) + resolved_path2 = self.resolve_path(path2, revision=revision) + + same_repo = ( + resolved_path1.repo_type == 
resolved_path2.repo_type and resolved_path1.repo_id == resolved_path2.repo_id + ) + + # TODO: Wait for https://github.com/huggingface/huggingface_hub/issues/1083 to be resolved to simplify this logic + if same_repo and self.info(path1, revision=resolved_path1.revision)["lfs"] is not None: + headers = self._api._build_hf_headers(is_write_action=True) + commit_message = f"Copy {path1} to {path2}" + payload = { + "summary": kwargs.get("commit_message", commit_message), + "description": kwargs.get("commit_description", ""), + "files": [], + "lfsFiles": [ + { + "path": resolved_path2.path_in_repo, + "algo": "sha256", + "oid": self.info(path1, revision=resolved_path1.revision)["lfs"]["oid"], + } + ], + "deletedFiles": [], + } + r = requests.post( + f"{self.endpoint}/api/{resolved_path1.repo_type}s/{resolved_path1.repo_id}/commit/{safe_quote(resolved_path2.revision)}", + json=payload, + headers=headers, + ) + hf_raise_for_status(r) + else: + with self.open(path1, "rb", revision=resolved_path1.revision) as f: + content = f.read() + commit_message = f"Copy {path1} to {path2}" + self._api.upload_file( + path_or_fileobj=content, + path_in_repo=resolved_path2.path_in_repo, + repo_id=resolved_path2.repo_id, + token=self.token, + repo_type=resolved_path2.repo_type, + revision=resolved_path2.revision, + commit_message=kwargs.get("commit_message", commit_message), + commit_description=kwargs.get("commit_description"), + ) + self.invalidate_cache(path=resolved_path1.unresolve()) + self.invalidate_cache(path=resolved_path2.unresolve()) + + def modified(self, path: str, **kwargs) -> datetime: + info = self.info(path, **kwargs) + if "last_modified" not in info: + raise IsADirectoryError(path) + return info["last_modified"] + + def info(self, path: str, **kwargs) -> Dict[str, Any]: + resolved_path = self.resolve_path(path) + if not resolved_path.path_in_repo: + revision_in_path = "@" + safe_quote(resolved_path.revision) + has_revision_in_path = revision_in_path in path + name = resolved_path.unresolve() + name = name.replace(revision_in_path, "", 1) if not has_revision_in_path else name + return {"name": name, "size": 0, "type": "directory"} + return super().info(path, **kwargs) + + def expand_path( + self, path: Union[str, List[str]], recursive: bool = False, maxdepth: Optional[int] = None, **kwargs + ) -> List[str]: + # The default implementation does not allow passing custom kwargs (e.g., we use these kwargs to propagate the `revision`) + if maxdepth is not None and maxdepth < 1: + raise ValueError("maxdepth must be at least 1") + + if isinstance(path, str): + return self.expand_path([path], recursive, maxdepth) + + out = set() + path = [self._strip_protocol(p) for p in path] + for p in path: + if has_magic(p): + bit = set(self.glob(p)) + out |= bit + if recursive: + out |= set(self.expand_path(list(bit), recursive=recursive, maxdepth=maxdepth, **kwargs)) + continue + elif recursive: + rec = set(self.find(p, maxdepth=maxdepth, withdirs=True, detail=False, **kwargs)) + out |= rec + if p not in out and (recursive is False or self.exists(p)): + # should only check once, for the root + out.add(p) + if not out: + raise FileNotFoundError(path) + return list(sorted(out)) + + +class HfFileSystemFile(fsspec.spec.AbstractBufferedFile): + def __init__(self, fs: HfFileSystem, path: str, revision: Optional[str] = None, **kwargs): + super().__init__(fs, path, **kwargs) + self.fs: HfFileSystem + self.resolved_path = fs.resolve_path(path, revision=revision) + + def _fetch_range(self, start: int, end: int) -> bytes: + 
headers = { + "range": f"bytes={start}-{end - 1}", + **self.fs._api._build_hf_headers(), + } + url = ( + f"{self.fs.endpoint}/{REPO_TYPES_URL_PREFIXES.get(self.resolved_path.repo_type, '') + self.resolved_path.repo_id}/resolve/{safe_quote(self.resolved_path.revision)}/{safe_quote(self.resolved_path.path_in_repo)}" + ) + r = http_backoff("GET", url, headers=headers) + hf_raise_for_status(r) + return r.content + + def _initiate_upload(self) -> None: + self.temp_file = tempfile.NamedTemporaryFile(prefix="hffs-", delete=False) + + def _upload_chunk(self, final: bool = False) -> None: + self.buffer.seek(0) + block = self.buffer.read() + self.temp_file.write(block) + if final: + self.temp_file.close() + self.fs._api.upload_file( + path_or_fileobj=self.temp_file.name, + path_in_repo=self.resolved_path.path_in_repo, + repo_id=self.resolved_path.repo_id, + token=self.fs.token, + repo_type=self.resolved_path.repo_type, + revision=self.resolved_path.revision, + commit_message=self.kwargs.get("commit_message"), + commit_description=self.kwargs.get("commit_description"), + ) + os.remove(self.temp_file.name) + self.fs.invalidate_cache( + path=self.resolved_path.unresolve(), + ) + + +def safe_quote(s: str) -> str: + return quote(s, safe="") diff --git a/src/huggingface_hub/utils/__init__.py b/src/huggingface_hub/utils/__init__.py index f3f545d250..db69f357ea 100644 --- a/src/huggingface_hub/utils/__init__.py +++ b/src/huggingface_hub/utils/__init__.py @@ -44,6 +44,7 @@ from ._headers import build_hf_headers, get_token_to_send from ._hf_folder import HfFolder from ._http import configure_http_backend, get_session, http_backoff +from ._pagination import paginate from ._paths import filter_repo_objects, IGNORE_GIT_FOLDER_PATTERNS from ._runtime import ( dump_environment_info, diff --git a/tests/test_hf_file_system.py b/tests/test_hf_file_system.py new file mode 100644 index 0000000000..7a40e1402a --- /dev/null +++ b/tests/test_hf_file_system.py @@ -0,0 +1,284 @@ +import datetime +import unittest +from typing import Optional +from unittest.mock import patch + +import fsspec +import pytest + +from huggingface_hub.constants import REPO_TYPES_URL_PREFIXES +from huggingface_hub.hf_file_system import HfFileSystem +from huggingface_hub.utils import RepositoryNotFoundError, RevisionNotFoundError + +from .testing_constants import ENDPOINT_STAGING, TOKEN, USER +from .testing_utils import repo_name, retry_endpoint + + +class HfFileSystemTests(unittest.TestCase): + @classmethod + def setUpClass(cls): + """Register `HfFileSystem` as a `fsspec` filesystem if not already registered.""" + if HfFileSystem.protocol not in fsspec.available_protocols(): + fsspec.register_implementation(HfFileSystem.protocol, HfFileSystem) + + def setUp(self): + self.repo_id = f"{USER}/{repo_name()}" + self.repo_type = "dataset" + self.hf_path = REPO_TYPES_URL_PREFIXES.get(self.repo_type, "") + self.repo_id + self.hffs = HfFileSystem(endpoint=ENDPOINT_STAGING, token=TOKEN) + self.api = self.hffs._api + + # Create dummy repo + self.api.create_repo(self.repo_id, repo_type=self.repo_type) + self.api.upload_file( + path_or_fileobj=b"dummy binary data on pr", + path_in_repo="data/binary_data_for_pr.bin", + repo_id=self.repo_id, + repo_type=self.repo_type, + create_pr=True, + ) + self.api.upload_file( + path_or_fileobj="dummy text data".encode("utf-8"), + path_in_repo="data/text_data.txt", + repo_id=self.repo_id, + repo_type=self.repo_type, + ) + self.api.upload_file( + path_or_fileobj=b"dummy binary data", + path_in_repo="data/binary_data.bin", + 
repo_id=self.repo_id, + repo_type=self.repo_type, + ) + + def tearDown(self): + self.api.delete_repo(self.repo_id, repo_type=self.repo_type) + + @retry_endpoint + def test_glob(self): + self.assertEqual( + sorted(self.hffs.glob(self.hf_path + "/*")), + sorted([self.hf_path + "/.gitattributes", self.hf_path + "/data"]), + ) + + self.assertEqual( + sorted(self.hffs.glob(self.hf_path + "/*", revision="main")), + sorted([self.hf_path + "/.gitattributes", self.hf_path + "/data"]), + ) + self.assertEqual( + sorted(self.hffs.glob(self.hf_path + "@main" + "/*")), + sorted([self.hf_path + "@main" + "/.gitattributes", self.hf_path + "@main" + "/data"]), + ) + + @retry_endpoint + def test_file_type(self): + self.assertTrue( + self.hffs.isdir(self.hf_path + "/data") and not self.hffs.isdir(self.hf_path + "/.gitattributes") + ) + self.assertTrue( + self.hffs.isfile(self.hf_path + "/data/text_data.txt") and not self.hffs.isfile(self.hf_path + "/data") + ) + + @retry_endpoint + def test_remove_file(self): + self.hffs.rm_file(self.hf_path + "/data/text_data.txt") + self.assertEqual(self.hffs.glob(self.hf_path + "/data/*"), [self.hf_path + "/data/binary_data.bin"]) + + @retry_endpoint + def test_remove_directory(self): + self.hffs.rm(self.hf_path + "/data", recursive=True) + self.assertNotIn(self.hf_path + "/data", self.hffs.ls(self.hf_path)) + + @retry_endpoint + def test_read_file(self): + with self.hffs.open(self.hf_path + "/data/text_data.txt", "r") as f: + self.assertEqual(f.read(), "dummy text data") + + @retry_endpoint + def test_write_file(self): + data = "new text data" + with self.hffs.open(self.hf_path + "/data/new_text_data.txt", "w") as f: + f.write(data) + self.assertIn(self.hf_path + "/data/new_text_data.txt", self.hffs.glob(self.hf_path + "/data/*")) + with self.hffs.open(self.hf_path + "/data/new_text_data.txt", "r") as f: + self.assertEqual(f.read(), data) + + @retry_endpoint + def test_write_file_multiple_chunks(self): + # TODO: try with files between 10 and 50MB (as of 16 March 2023 I was getting 504 errors on hub-ci) + data = "a" * (4 << 20) # 4MB + with self.hffs.open(self.hf_path + "/data/new_text_data_big.txt", "w") as f: + for _ in range(2): # 8MB in total + f.write(data) + + self.assertIn(self.hf_path + "/data/new_text_data_big.txt", self.hffs.glob(self.hf_path + "/data/*")) + with self.hffs.open(self.hf_path + "/data/new_text_data_big.txt", "r") as f: + for _ in range(2): + self.assertEqual(f.read(len(data)), data) + + @unittest.skip("Not implemented yet") + @retry_endpoint + def test_append_file(self): + with self.hffs.open(self.hf_path + "/data/text_data.txt", "a") as f: + f.write(" appended text") + + with self.hffs.open(self.hf_path + "/data/text_data.txt", "r") as f: + self.assertEqual(f.read(), "dummy text data appended text") + + @retry_endpoint + def test_copy_file(self): + # Non-LFS file + self.assertIsNone(self.hffs.info(self.hf_path + "/data/text_data.txt")["lfs"]) + self.hffs.cp_file(self.hf_path + "/data/text_data.txt", self.hf_path + "/data/text_data_copy.txt") + with self.hffs.open(self.hf_path + "/data/text_data_copy.txt", "r") as f: + self.assertEqual(f.read(), "dummy text data") + self.assertIsNone(self.hffs.info(self.hf_path + "/data/text_data_copy.txt")["lfs"]) + # LFS file + self.assertIsNotNone(self.hffs.info(self.hf_path + "/data/binary_data.bin")["lfs"]) + self.hffs.cp_file(self.hf_path + "/data/binary_data.bin", self.hf_path + "/data/binary_data_copy.bin") + with self.hffs.open(self.hf_path + "/data/binary_data_copy.bin", "rb") as f: + 
self.assertEqual(f.read(), b"dummy binary data") + self.assertIsNotNone(self.hffs.info(self.hf_path + "/data/binary_data_copy.bin")["lfs"]) + + @retry_endpoint + def test_modified_time(self): + self.assertIsInstance(self.hffs.modified(self.hf_path + "/data/text_data.txt"), datetime.datetime) + # should fail on a non-existing file + with self.assertRaises(FileNotFoundError): + self.hffs.modified(self.hf_path + "/data/not_existing_file.txt") + # should fail on a directory + with self.assertRaises(IsADirectoryError): + self.hffs.modified(self.hf_path + "/data") + + @retry_endpoint + def test_initialize_from_fsspec(self): + fs, _, paths = fsspec.get_fs_token_paths( + f"hf://{self.repo_type}s/{self.repo_id}/data/text_data.txt", + storage_options={ + "endpoint": ENDPOINT_STAGING, + "token": TOKEN, + }, + ) + self.assertIsInstance(fs, HfFileSystem) + self.assertEqual(fs._api.endpoint, ENDPOINT_STAGING) + self.assertEqual(fs.token, TOKEN) + self.assertEqual(paths, [self.hf_path + "/data/text_data.txt"]) + + fs, _, paths = fsspec.get_fs_token_paths(f"hf://{self.repo_id}/data/text_data.txt") + self.assertIsInstance(fs, HfFileSystem) + self.assertEqual(paths, [f"{self.repo_id}/data/text_data.txt"]) + + @retry_endpoint + def test_list_root_directory_no_revision(self): + files = self.hffs.ls(self.hf_path) + self.assertEqual(len(files), 2) + + self.assertEqual(files[0]["type"], "directory") + self.assertEqual(files[0]["size"], 0) + self.assertTrue(files[0]["name"].endswith("/data")) + + self.assertEqual(files[1]["type"], "file") + self.assertGreater(files[1]["size"], 0) # not empty + self.assertTrue(files[1]["name"].endswith("/.gitattributes")) + + @retry_endpoint + def test_list_data_directory_no_revision(self): + files = self.hffs.ls(self.hf_path + "/data") + self.assertEqual(len(files), 2) + + self.assertEqual(files[0]["type"], "file") + self.assertGreater(files[0]["size"], 0) # not empty + self.assertTrue(files[0]["name"].endswith("/data/binary_data.bin")) + self.assertIsNotNone(files[0]["lfs"]) + self.assertIn("oid", files[0]["lfs"]) + self.assertIn("size", files[0]["lfs"]) + self.assertIn("pointerSize", files[0]["lfs"]) + + self.assertEqual(files[1]["type"], "file") + self.assertGreater(files[1]["size"], 0) # not empty + self.assertTrue(files[1]["name"].endswith("/data/text_data.txt")) + self.assertIsNone(files[1]["lfs"]) + + @retry_endpoint + def test_list_data_directory_with_revision(self): + files = self.hffs.ls(self.hf_path + "@refs%2Fpr%2F1" + "/data") + + for test_name, files in { + "rev_in_path": self.hffs.ls(self.hf_path + "@refs%2Fpr%2F1" + "/data"), + "rev_as_arg": self.hffs.ls(self.hf_path + "/data", revision="refs/pr/1"), + "rev_in_path_and_as_arg": self.hffs.ls(self.hf_path + "@refs%2Fpr%2F1" + "/data", revision="refs/pr/1"), + }.items(): + with self.subTest(test_name): + self.assertEqual(len(files), 1) # only one file in PR + self.assertEqual(files[0]["type"], "file") + self.assertTrue(files[0]["name"].endswith("/data/binary_data_for_pr.bin")) # PR file + + +@pytest.mark.parametrize("path_in_repo", ["", "foo"]) +@pytest.mark.parametrize( + "root_path,revision,repo_type,repo_id,resolved_revision", + [ + # Parse without namespace + ("gpt2", None, "model", "gpt2", "main"), + ("gpt2", "dev", "model", "gpt2", "dev"), + ("gpt2@dev", None, "model", "gpt2", "dev"), + ("datasets/squad", None, "dataset", "squad", "main"), + ("datasets/squad", "dev", "dataset", "squad", "dev"), + ("datasets/squad@dev", None, "dataset", "squad", "dev"), + # Parse with namespace + ("username/my_model", None, 
"model", "username/my_model", "main"), + ("username/my_model", "dev", "model", "username/my_model", "dev"), + ("username/my_model@dev", None, "model", "username/my_model", "dev"), + ("datasets/username/my_dataset", None, "dataset", "username/my_dataset", "main"), + ("datasets/username/my_dataset", "dev", "dataset", "username/my_dataset", "dev"), + ("datasets/username/my_dataset@dev", None, "dataset", "username/my_dataset", "dev"), + # Parse with hf:// protocol + ("hf://gpt2", None, "model", "gpt2", "main"), + ("hf://gpt2", "dev", "model", "gpt2", "dev"), + ("hf://gpt2@dev", None, "model", "gpt2", "dev"), + ("hf://datasets/squad", None, "dataset", "squad", "main"), + ("hf://datasets/squad", "dev", "dataset", "squad", "dev"), + ("hf://datasets/squad@dev", None, "dataset", "squad", "dev"), + ], +) +def test_resolve_path( + root_path: str, + revision: Optional[str], + repo_type: str, + repo_id: str, + resolved_revision: str, + path_in_repo: str, +): + fs = HfFileSystem() + path = root_path + "/" + path_in_repo if path_in_repo else root_path + + def mock_repo_info(repo_id: str, *, revision: str, repo_type: str, **kwargs): + if repo_id not in ["gpt2", "squad", "username/my_dataset", "username/my_model"]: + raise RepositoryNotFoundError(repo_id) + if revision is not None and revision not in ["main", "dev"]: + raise RevisionNotFoundError(revision) + + with patch.object(fs._api, "repo_info", mock_repo_info): + resolved_path = fs.resolve_path(path, revision=revision) + assert ( + resolved_path.repo_type, + resolved_path.repo_id, + resolved_path.revision, + resolved_path.path_in_repo, + ) == (repo_type, repo_id, resolved_revision, path_in_repo) + + +def test_resolve_path_with_non_matching_revisions(): + fs = HfFileSystem() + with pytest.raises(ValueError): + fs.resolve_path("gpt2@dev", revision="main") + + +@pytest.mark.parametrize("not_supported_path", ["", "foo", "datasets", "datasets/foo"]) +def test_access_repositories_lists(not_supported_path): + fs = HfFileSystem() + with pytest.raises(NotImplementedError): + fs.ls(not_supported_path) + with pytest.raises(NotImplementedError): + fs.glob(not_supported_path + "/") + with pytest.raises(NotImplementedError): + fs.open(not_supported_path)