From ff0465df39dfa7ebc9545cf9d172e68cbbe247c4 Mon Sep 17 00:00:00 2001 From: Lucain Date: Mon, 25 Sep 2023 10:30:07 +0200 Subject: [PATCH] Add Collection API (#1687) * Add Collection API * add unit tests * collections * add guide (wip) * completed guide * fix collection tests * uncomment * document 500 limit * delete * tpyo * better formatting * docs * fix test * Apply suggestions from code review Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * requested changes --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/_toctree.yml | 4 + docs/source/en/guides/collections.md | 185 +++++++ .../en/package_reference/collections.md | 24 + src/huggingface_hub/__init__.py | 18 + src/huggingface_hub/hf_api.py | 501 +++++++++++++++++- src/huggingface_hub/utils/_errors.py | 9 + tests/test_hf_api.py | 153 +++++- tests/test_utils_errors.py | 3 +- 8 files changed, 883 insertions(+), 14 deletions(-) create mode 100644 docs/source/en/guides/collections.md create mode 100644 docs/source/en/package_reference/collections.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index e8ab4a1241..632bbf5086 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -24,6 +24,8 @@ title: Inference - local: guides/community title: Community Tab + - local: guides/collections + title: Collections - local: guides/manage-cache title: Cache - local: guides/model-cards @@ -68,6 +70,8 @@ title: Repo Cards and Repo Card Data - local: package_reference/space_runtime title: Space runtime + - local: package_reference/collections + title: Collections - local: package_reference/tensorboard title: TensorBoard logger - local: package_reference/webhooks_server diff --git a/docs/source/en/guides/collections.md b/docs/source/en/guides/collections.md new file mode 100644 index 0000000000..9421b7a50f --- /dev/null +++ b/docs/source/en/guides/collections.md @@ -0,0 +1,185 @@ + + +# Collections + +A 
collection is a group of related items on the Hub (models, datasets, Spaces, papers) that are organized together on the same page. Collections are useful for creating your own portfolio, bookmarking content in categories, or presenting a curated list of items you want to share. Check out this [guide](https://huggingface.co/docs/hub/collections) to understand in more detail what collections are and how they look on the Hub. + +You can directly manage collections in the browser, but in this guide, we will focus on how to manage them programmatically. + +## Fetch a collection + +Use [`get_collection`] to fetch your collections or any public ones. You must have the collection's *slug* to retrieve a collection. A slug is an identifier for a collection based on the title and a unique ID. You can find the slug in the URL of the collection page. + +
+ +
+ +Let's fetch the collection with the slug `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`: + +```py +>>> from huggingface_hub import get_collection +>>> collection = get_collection("TheBloke/recent-models-64f9a55bb3115b4f513ec026") +>>> collection +Collection: { + {'description': "Models I've recently quantized.", + 'items': [...], + 'last_updated': datetime.datetime(2023, 9, 21, 7, 26, 28, 57000, tzinfo=datetime.timezone.utc), + 'owner': 'TheBloke', + 'position': 1, + 'private': False, + 'slug': 'TheBloke/recent-models-64f9a55bb3115b4f513ec026', + 'theme': 'green', + 'title': 'Recent models'} +} +>>> collection.items[0] +CollectionItem: { + {'item_object_id': '6507f6d5423b46492ee1413e', + 'author': 'TheBloke', + 'item_id': 'TheBloke/TigerBot-70B-Chat-GPTQ', + 'item_type': 'model', + 'lastModified': '2023-09-19T12:55:21.000Z', + 'position': 0, + 'private': False, + 'repoType': 'model' + (...) + } +} +``` + +The [`Collection`] object returned by [`get_collection`] contains: +- high-level metadata: `slug`, `owner`, `title`, `description`, etc. +- a list of [`CollectionItem`] objects; each item represents a model, a dataset, a Space, or a paper. + +All collection items are guaranteed to have: +- a unique `item_object_id`: this is the id of the collection item in the database +- an `item_id`: this is the id on the Hub of the underlying item (model, dataset, Space, paper); it is not necessarily unique, and only the `item_id`/`item_type` pair is unique +- an `item_type`: model, dataset, Space, paper +- the `position` of the item in the collection, which can be updated to reorganize your collection (see [`update_collection_item`] below) + +A `note` can also be attached to the item. This is useful to add additional information about the item (a comment, a link to a blog post, etc.). The attribute still has a `None` value if an item doesn't have a note. 
+ +In addition to these base attributes, returned items can have additional attributes depending on their type: `author`, `private`, `lastModified`, `gated`, `title`, `likes`, `upvotes`, etc. None of these attributes are guaranteed to be returned. + +## Create a new collection + +Now that we know how to get a [`Collection`], let's create our own! Use [`create_collection`] with a title and description. To create a collection on an organization page, pass `namespace="my-cool-org"` when creating the collection. Finally, you can also create private collections by passing `private=True`. + +```py +>>> from huggingface_hub import create_collection + +>>> collection = create_collection( +... title="ICCV 2023", +... description="Portfolio of models, papers and demos I presented at ICCV 2023", +... ) +``` + +It will return a [`Collection`] object with the high-level metadata (title, description, owner, etc.) and an empty list of items. You will now be able to refer to this collection using its `slug`. + +```py +>>> collection.slug +'username/iccv-2023-64f9a55bb3115b4f513ec026' +>>> collection.title +"ICCV 2023" +>>> collection.owner +"username" +``` + +## Manage items in a collection + +Now that we have a [`Collection`], we want to add items to it and organize them. + +### Add items + +Items have to be added one by one using [`add_collection_item`]. You only need to know the `collection_slug`, `item_id` and `item_type`. Optionally, you can also add a `note` to the item (500 characters maximum). + +```py +>>> from huggingface_hub import create_collection, add_collection_item + +>>> collection = create_collection(title="OS Week Highlights - Sept 18 - 24", namespace="osanseviero") +>>> collection.slug +"osanseviero/os-week-highlights-sept-18-24-650bfed7f795a59f491afb80" + +>>> add_collection_item(collection.slug, item_id="coqui/xtts", item_type="space") +>>> add_collection_item( +... collection.slug, +... item_id="warp-ai/wuerstchen", +... item_type="model", +... 
note="Würstchen is a new fast and efficient high resolution text-to-image architecture and model" +... ) +>>> add_collection_item(collection.slug, item_id="lmsys/lmsys-chat-1m", item_type="dataset") +>>> add_collection_item(collection.slug, item_id="warp-ai/wuerstchen", item_type="space") # same item_id, different item_type +``` + +If an item already exists in a collection (same `item_id`/`item_type` pair), an HTTP 409 error will be raised. You can choose to ignore this error by setting `exists_ok=True`. + +### Add a note to an existing item + +You can modify an existing item to add or modify the note attached to it using [`update_collection_item`]. Let's reuse the example above: + +```py +>>> from huggingface_hub import get_collection, update_collection_item + +# Fetch collection with newly added items +>>> collection_slug = "osanseviero/os-week-highlights-sept-18-24-650bfed7f795a59f491afb80" +>>> collection = get_collection(collection_slug) + +# Add a note to the `lmsys-chat-1m` dataset +>>> update_collection_item( +... collection_slug=collection_slug, +... item_object_id=collection.items[2].item_object_id, +... note="This dataset contains one million real-world conversations with 25 state-of-the-art LLMs.", +... ) +``` + +### Reorder items + +Items in a collection are ordered. The order is determined by the `position` attribute of each item. By default, items are ordered by appending new items at the end of the collection. You can update the order using [`update_collection_item`] the same way you would add a note. + +Let's reuse our example above: + +```py +>>> from huggingface_hub import get_collection, update_collection_item + +# Fetch collection +>>> collection_slug = "osanseviero/os-week-highlights-sept-18-24-650bfed7f795a59f491afb80" +>>> collection = get_collection(collection_slug) + +# Reorder to place the two `Wuerstchen` items together +>>> update_collection_item( +... collection_slug=collection_slug, +... 
item_object_id=collection.items[3].item_object_id, +... position=2, +... ) +``` + +### Remove items + +Finally, you can also remove an item using [`delete_collection_item`]. + +```py +>>> from huggingface_hub import get_collection, delete_collection_item + +# Fetch collection +>>> collection_slug = "osanseviero/os-week-highlights-sept-18-24-650bfed7f795a59f491afb80" +>>> collection = get_collection(collection_slug) + +# Remove `coqui/xtts` Space from the list +>>> delete_collection_item(collection_slug=collection_slug, item_object_id=collection.items[0].item_object_id) +``` + +## Delete collection + +A collection can be deleted using [`delete_collection`]. + + + +This is a non-revertible action. A deleted collection cannot be restored. + + + +```py +>>> from huggingface_hub import delete_collection +>>> delete_collection("username/useless-collection-64f9a55bb3115b4f513ec026", missing_ok=True) +``` \ No newline at end of file diff --git a/docs/source/en/package_reference/collections.md b/docs/source/en/package_reference/collections.md new file mode 100644 index 0000000000..37fe25039c --- /dev/null +++ b/docs/source/en/package_reference/collections.md @@ -0,0 +1,24 @@ + + +# Managing collections + +Check out the [`HfApi`] documentation page for the reference of methods to manage your collections on the Hub. 
+ +- Get collection content: [`get_collection`] +- Create new collection: [`create_collection`] +- Update a collection: [`update_collection_metadata`] +- Delete a collection: [`delete_collection`] +- Add an item to a collection: [`add_collection_item`] +- Update an item in a collection: [`update_collection_item`] +- Remove an item from a collection: [`delete_collection_item`] + + +### Collection + +[[autodoc]] Collection + +### CollectionItem + +[[autodoc]] CollectionItem \ No newline at end of file diff --git a/src/huggingface_hub/__init__.py b/src/huggingface_hub/__init__.py index 9d75889f3b..bf137fcf13 100644 --- a/src/huggingface_hub/__init__.py +++ b/src/huggingface_hub/__init__.py @@ -129,6 +129,8 @@ "try_to_load_from_cache", ], "hf_api": [ + "Collection", + "CollectionItem", "CommitInfo", "CommitOperation", "CommitOperationAdd", @@ -142,11 +144,13 @@ "ModelSearchArguments", "RepoUrl", "UserLikes", + "add_collection_item", "add_space_secret", "add_space_variable", "change_discussion_status", "comment_discussion", "create_branch", + "create_collection", "create_commit", "create_commits_on_pr", "create_discussion", @@ -155,6 +159,8 @@ "create_tag", "dataset_info", "delete_branch", + "delete_collection", + "delete_collection_item", "delete_file", "delete_folder", "delete_repo", @@ -165,6 +171,7 @@ "duplicate_space", "edit_discussion_comment", "file_exists", + "get_collection", "get_dataset_tags", "get_discussion_details", "get_full_repo_name", @@ -199,6 +206,8 @@ "space_info", "super_squash_history", "unlike", + "update_collection_item", + "update_collection_metadata", "update_repo_visibility", "upload_file", "upload_folder", @@ -438,6 +447,8 @@ def __dir__(): try_to_load_from_cache, # noqa: F401 ) from .hf_api import ( + Collection, # noqa: F401 + CollectionItem, # noqa: F401 CommitInfo, # noqa: F401 CommitOperation, # noqa: F401 CommitOperationAdd, # noqa: F401 @@ -451,11 +462,13 @@ def __dir__(): ModelSearchArguments, # noqa: F401 RepoUrl, # noqa: F401 
UserLikes, # noqa: F401 + add_collection_item, # noqa: F401 add_space_secret, # noqa: F401 add_space_variable, # noqa: F401 change_discussion_status, # noqa: F401 comment_discussion, # noqa: F401 create_branch, # noqa: F401 + create_collection, # noqa: F401 create_commit, # noqa: F401 create_commits_on_pr, # noqa: F401 create_discussion, # noqa: F401 @@ -464,6 +477,8 @@ def __dir__(): create_tag, # noqa: F401 dataset_info, # noqa: F401 delete_branch, # noqa: F401 + delete_collection, # noqa: F401 + delete_collection_item, # noqa: F401 delete_file, # noqa: F401 delete_folder, # noqa: F401 delete_repo, # noqa: F401 @@ -474,6 +489,7 @@ def __dir__(): duplicate_space, # noqa: F401 edit_discussion_comment, # noqa: F401 file_exists, # noqa: F401 + get_collection, # noqa: F401 get_dataset_tags, # noqa: F401 get_discussion_details, # noqa: F401 get_full_repo_name, # noqa: F401 @@ -508,6 +524,8 @@ def __dir__(): space_info, # noqa: F401 super_squash_history, # noqa: F401 unlike, # noqa: F401 + update_collection_item, # noqa: F401 + update_collection_metadata, # noqa: F401 update_repo_visibility, # noqa: F401 upload_file, # noqa: F401 upload_folder, # noqa: F401 diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index f4eb4b5961..a357cf1b61 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -132,6 +132,7 @@ R = TypeVar("R") # Return type +CollectionItemType_T = Literal["model", "dataset", "space", "paper"] USERNAME_PLACEHOLDER = "hf_user" _REGEX_DISCUSSION_URL = re.compile(r".*/discussions/(\d+)$") @@ -143,6 +144,13 @@ class ReprMixin: """Mixin to create the __repr__ for a class""" + def __init__(self, **kwargs) -> None: + # Store all the other fields returned by the API + # Hack to ensure backward compatibility with future versions of the API. 
+ # See discussion in https://github.com/huggingface/huggingface_hub/pull/951#discussion_r926460408 + for k, v in kwargs.items(): + setattr(self, k, v) + def __repr__(self): formatted_value = pprint.pformat(self.__dict__, width=119, compact=True) if "\n" in formatted_value: @@ -392,10 +400,8 @@ def __init__( self.blob_id = blobId self.lfs = lfs - # Hack to ensure backward compatibility with future versions of the API. - # See discussion in https://github.com/huggingface/huggingface_hub/pull/951#discussion_r926460408 - for k, v in kwargs.items(): - setattr(self, k, v) + # Store all the other fields returned by the API + super().__init__(**kwargs) class ModelInfo(ReprMixin): @@ -453,8 +459,9 @@ def __init__( self.author = author self.config = config self.securityStatus = securityStatus - for k, v in kwargs.items(): - setattr(self, k, v) + + # Store all the other fields returned by the API + super().__init__(**kwargs) def __str__(self): r = f"Model Name: {self.modelId}, Tags: {self.tags}" @@ -521,8 +528,7 @@ def __init__( # because of old versions of the datasets lib that need this field kwargs.pop("key", None) # Store all the other fields returned by the API - for k, v in kwargs.items(): - setattr(self, k, v) + super().__init__(**kwargs) def __str__(self): r = f"Dataset Name: {self.id}, Tags: {self.tags}" @@ -570,8 +576,8 @@ def __init__( self.siblings = [RepoFile(**x) for x in siblings] if siblings is not None else [] self.private = private self.author = author - for k, v in kwargs.items(): - setattr(self, k, v) + # Store all the other fields returned by the API + super().__init__(**kwargs) class MetricInfo(ReprMixin): @@ -594,14 +600,97 @@ def __init__( # because of old versions of the datasets lib that need this field kwargs.pop("key", None) # Store all the other fields returned by the API - for k, v in kwargs.items(): - setattr(self, k, v) + super().__init__(**kwargs) def __str__(self): r = f"Metric Name: {self.id}" return r +class CollectionItem(ReprMixin): + 
"""Contains information about an item of a Collection (model, dataset, Space or paper). + + Args: + item_object_id (`str`): + Unique ID of the item in the collection. + item_id (`str`): + ID of the underlying object on the Hub. Can be either a repo_id or a paper id + e.g. `"jbilcke-hf/ai-comic-factory"`, `"2307.09288"`. + item_type (`str`): + Type of the underlying object. Can be one of `"model"`, `"dataset"`, `"space"` or `"paper"`. + position (`int`): + Position of the item in the collection. + note (`str`, *optional*): + Note associated with the item, as plain text. + kwargs (`Dict`, *optional*): + Any other attribute returned by the server. Those attributes depend on the `item_type`: "author", "private", + "lastModified", "gated", "title", "likes", "upvotes", etc. + """ + + def __init__( + self, _id: str, id: str, type: CollectionItemType_T, position: int, note: Optional[Dict] = None, **kwargs + ) -> None: + self.item_object_id: str = _id # id in database + self.item_id: str = id # repo_id or paper id + self.item_type: CollectionItemType_T = type + self.position: int = position + self.note: str = note["text"] if note is not None else None + + # Store all the other fields returned by the API + super().__init__(**kwargs) + + +class Collection(ReprMixin): + """ + Contains information about a Collection on the Hub. + + Args: + slug (`str`): + Slug of the collection. E.g. `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`. + title (`str`): + Title of the collection. E.g. `"Recent models"`. + owner (`str`): + Owner of the collection. E.g. `"TheBloke"`. + description (`str`, *optional*): + Description of the collection, as plain text. + items (`List[CollectionItem]`): + List of items in the collection. + last_updated (`datetime`): + Date of the last update of the collection. + position (`int`): + Position of the collection in the list of collections of the owner. + private (`bool`): + Whether the collection is private or not. 
+ theme (`str`): + Theme of the collection. E.g. `"green"`. + """ + + slug: str + title: str + owner: str + description: Optional[str] + items: List[CollectionItem] + + last_updated: datetime + position: int + private: bool + theme: str + + def __init__(self, data: Dict) -> None: + # Collection info + self.slug = data["slug"] + self.title = data["title"] + self.owner = data["owner"]["name"] + self.description = data.get("description") + self.items = [CollectionItem(**item) for item in data["items"]] + + # Metadata + self.last_updated = parse_datetime(data["lastUpdated"]) + self.private = data["private"] + self.position = data["position"] + self.theme = data["theme"] + + class ModelSearchArguments(AttributeDictionary): """ A nested namespace object holding all possible values for properties of @@ -5710,6 +5799,384 @@ def delete_space_storage( hf_raise_for_status(r) return SpaceRuntime(r.json()) + ######################## + # Collection Endpoints # + ######################## + + def get_collection(self, collection_slug: str, *, token: Optional[str] = None) -> Collection: + """Gets information about a Collection on the Hub. + + Args: + collection_slug (`str`): + Slug of the collection of the Hub. Example: `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`. + token (`str`, *optional*): + Hugging Face token. Will default to the locally saved token if not provided. + + Returns: [`Collection`] + + Example: + + ```py + >>> from huggingface_hub import get_collection + >>> collection = get_collection("TheBloke/recent-models-64f9a55bb3115b4f513ec026") + >>> collection.title + 'Recent models' + >>> len(collection.items) + 37 + >>> collection.items[0] + CollectionItem: { + {'item_object_id': '6507f6d5423b46492ee1413e', + 'item_id': 'TheBloke/TigerBot-70B-Chat-GPTQ', + 'author': 'TheBloke', + 'item_type': 'model', + 'lastModified': '2023-09-19T12:55:21.000Z', + (...) 
+ }} + ``` + """ + r = get_session().get( + f"{self.endpoint}/api/collections/{collection_slug}", headers=self._build_hf_headers(token=token) + ) + hf_raise_for_status(r) + return Collection(r.json()) + + def create_collection( + self, + title: str, + *, + namespace: Optional[str] = None, + description: Optional[str] = None, + private: bool = False, + exists_ok: bool = False, + token: Optional[str] = None, + ) -> Collection: + """Create a new Collection on the Hub. + + Args: + title (`str`): + Title of the collection to create. Example: `"Recent models"`. + namespace (`str`, *optional*): + Namespace of the collection to create (username or org). Will default to the owner name. + description (`str`, *optional*): + Description of the collection to create. + private (`bool`, *optional*): + Whether the collection should be private or not. Defaults to `False` (i.e. public collection). + exists_ok (`bool`, *optional*): + If `True`, do not raise an error if collection already exists. + token (`str`, *optional*): + Hugging Face token. Will default to the locally saved token if not provided. + + Returns: [`Collection`] + + Example: + + ```py + >>> from huggingface_hub import create_collection + >>> collection = create_collection( + ... title="ICCV 2023", + ... description="Portfolio of models, papers and demos I presented at ICCV 2023", + ... 
) + >>> collection.slug + "username/iccv-2023-64f9a55bb3115b4f513ec026" + ``` + """ + if namespace is None: + namespace = self.whoami(token)["name"] + + payload = { + "title": title, + "namespace": namespace, + "private": private, + } + if description is not None: + payload["description"] = description + + r = get_session().post( + f"{self.endpoint}/api/collections", headers=self._build_hf_headers(token=token), json=payload + ) + try: + hf_raise_for_status(r) + except HTTPError as err: + if exists_ok and err.response.status_code == 409: + # Collection already exists and `exists_ok=True` + slug = r.json()["slug"] + return self.get_collection(slug, token=token) + else: + raise + return Collection(r.json()) + + def update_collection_metadata( + self, + collection_slug: str, + *, + title: Optional[str] = None, + description: Optional[str] = None, + position: Optional[int] = None, + private: Optional[bool] = None, + theme: Optional[str] = None, + token: Optional[str] = None, + ) -> Collection: + """Update metadata of a collection on the Hub. + + All arguments are optional. Only provided metadata will be updated. + + Args: + collection_slug (`str`): + Slug of the collection to update. Example: `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`. + title (`str`): + Title of the collection to update. + description (`str`, *optional*): + Description of the collection to update. + position (`int`, *optional*): + New position of the collection in the list of collections of the user. + private (`bool`, *optional*): + Whether the collection should be private or not. + theme (`str`, *optional*): + Theme of the collection on the Hub. + token (`str`, *optional*): + Hugging Face token. Will default to the locally saved token if not provided. + + Returns: [`Collection`] + + Example: + + ```py + >>> from huggingface_hub import update_collection_metadata + >>> collection = update_collection_metadata( + ... collection_slug="username/iccv-2023-64f9a55bb3115b4f513ec026", + ... 
title="ICCV Oct. 2023" + ... description="Portfolio of models, datasets, papers and demos I presented at ICCV Oct. 2023", + ... private=False, + ... theme="pink", + ... ) + >>> collection.slug + "username/iccv-oct-2023-64f9a55bb3115b4f513ec026" + # ^collection slug got updated but not the trailing ID + ``` + """ + payload = { + "position": position, + "private": private, + "theme": theme, + "title": title, + "description": description, + } + r = get_session().patch( + f"{self.endpoint}/api/collections/{collection_slug}", + headers=self._build_hf_headers(token=token), + # Only send not-none values to the API + json={key: value for key, value in payload.items() if value is not None}, + ) + hf_raise_for_status(r) + return Collection(r.json()["data"]) + + def delete_collection( + self, collection_slug: str, *, missing_ok: bool = False, token: Optional[str] = None + ) -> None: + """Delete a collection on the Hub. + + Args: + collection_slug (`str`): + Slug of the collection to delete. Example: `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`. + missing_ok (`bool`, *optional*): + If `True`, do not raise an error if collection doesn't exists. + token (`str`, *optional*): + Hugging Face token. Will default to the locally saved token if not provided. + + Example: + + ```py + >>> from huggingface_hub import delete_collection + >>> collection = delete_collection("username/useless-collection-64f9a55bb3115b4f513ec026", missing_ok=True) + ``` + + + + This is a non-revertible action. A deleted collection cannot be restored. 
+ + + """ + r = get_session().delete( + f"{self.endpoint}/api/collections/{collection_slug}", headers=self._build_hf_headers(token=token) + ) + try: + hf_raise_for_status(r) + except HTTPError as err: + if missing_ok and err.response.status_code == 404: + # Collection doesn't exists and `missing_ok=True` + return + else: + raise + + def add_collection_item( + self, + collection_slug: str, + item_id: str, + item_type: CollectionItemType_T, + *, + note: Optional[str] = None, + exists_ok: bool = False, + token: Optional[str] = None, + ) -> Collection: + """Add an item to a collection on the Hub. + + Args: + collection_slug (`str`): + Slug of the collection to update. Example: `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`. + item_id (`str`): + ID of the item to add to the collection. It can be the ID of a repo on the Hub (e.g. `"facebook/bart-large-mnli"`) + or a paper id (e.g. `"2307.09288"`). + item_type (`str`): + Type of the item to add. Can be one of `"model"`, `"dataset"`, `"space"` or `"paper"`. + note (`str`, *optional*): + A note to attach to the item in the collection. The maximum size for a note is 500 characters. + exists_ok (`bool`, *optional*): + If `True`, do not raise an error if item already exists. + token (`str`, *optional*): + Hugging Face token. Will default to the locally saved token if not provided. + + Returns: [`Collection`] + + Example: + + ```py + >>> from huggingface_hub import add_collection_item + >>> collection = add_collection_item( + ... collection_slug="davanstrien/climate-64f99dc2a5067f6b65531bab", + ... item_id="pierre-loic/climate-news-articles", + ... item_type="dataset" + ... ) + >>> collection.items[-1].item_id + "pierre-loic/climate-news-articles" + # ^item got added to the collection on last position + + # Add collection with a note + >>> add_collection_item( + ... collection_slug="davanstrien/climate-64f99dc2a5067f6b65531bab", + ... item_id="datasets/climate_fever", + ... item_type="dataset" + ... 
note="This dataset adopts the FEVER methodology that consists of 1,535 real-world claims regarding climate-change collected on the internet." + ... ) + (...) + ``` + """ + payload: Dict[str, Any] = {"item": {"id": item_id, "type": item_type}} + if note is not None: + payload["note"] = note + r = get_session().post( + f"{self.endpoint}/api/collections/{collection_slug}/items", + headers=self._build_hf_headers(token=token), + json=payload, + ) + try: + hf_raise_for_status(r) + except HTTPError as err: + if exists_ok and err.response.status_code == 409: + # Item already exists and `exists_ok=True` + return self.get_collection(collection_slug, token=token) + else: + raise + return Collection(r.json()) + + def update_collection_item( + self, + collection_slug: str, + item_object_id: str, + *, + note: Optional[str] = None, + position: Optional[int] = None, + token: Optional[str] = None, + ) -> None: + """Update an item in a collection. + + Args: + collection_slug (`str`): + Slug of the collection to update. Example: `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`. + item_object_id (`str`): + ID of the item in the collection. This is not the id of the item on the Hub (repo_id or paper id). + It must be retrieved from a [`CollectionItem`] object. Example: `collection.items[0]._id`. + note (`str`, *optional*): + A note to attach to the item in the collection. The maximum size for a note is 500 characters. + position (`int`, *optional*): + New position of the item in the collection. + token (`str`, *optional*): + Hugging Face token. Will default to the locally saved token if not provided. + + Example: + + ```py + >>> from huggingface_hub import get_collection, update_collection_item + + # Get collection first + >>> collection = get_collection("TheBloke/recent-models-64f9a55bb3115b4f513ec026") + + # Update item based on its ID (add note + update position) + >>> update_collection_item( + ... collection_slug="TheBloke/recent-models-64f9a55bb3115b4f513ec026", + ... 
item_object_id=collection.items[-1].item_object_id, + ... note="Newly updated model!" + ... position=0, + ... ) + ``` + """ + payload = {"position": position, "note": note} + r = get_session().patch( + f"{self.endpoint}/api/collections/{collection_slug}/items/{item_object_id}", + headers=self._build_hf_headers(token=token), + # Only send not-none values to the API + json={key: value for key, value in payload.items() if value is not None}, + ) + hf_raise_for_status(r) + + def delete_collection_item( + self, + collection_slug: str, + item_object_id: str, + *, + missing_ok: bool = False, + token: Optional[str] = None, + ) -> None: + """Delete an item from a collection. + + Args: + collection_slug (`str`): + Slug of the collection to update. Example: `"TheBloke/recent-models-64f9a55bb3115b4f513ec026"`. + item_object_id (`str`): + ID of the item in the collection. This is not the id of the item on the Hub (repo_id or paper id). + It must be retrieved from a [`CollectionItem`] object. Example: `collection.items[0]._id`. + missing_ok (`bool`, *optional*): + If `True`, do not raise an error if item doesn't exists. + token (`str`, *optional*): + Hugging Face token. Will default to the locally saved token if not provided. + + Example: + + ```py + >>> from huggingface_hub import get_collection, delete_collection_item + + # Get collection first + >>> collection = get_collection("TheBloke/recent-models-64f9a55bb3115b4f513ec026") + + # Delete item based on its ID + >>> delete_collection_item( + ... collection_slug="TheBloke/recent-models-64f9a55bb3115b4f513ec026", + ... item_object_id=collection.items[-1].item_object_id, + ... 
) + ``` + """ + r = get_session().delete( + f"{self.endpoint}/api/collections/{collection_slug}/items/{item_object_id}", + headers=self._build_hf_headers(token=token), + ) + try: + hf_raise_for_status(r) + except HTTPError as err: + if missing_ok and err.response.status_code == 404: + # Item already deleted and `missing_ok=True` + return + else: + raise + def _build_hf_headers( self, token: Optional[Union[bool, str]] = None, @@ -5902,3 +6369,13 @@ def _parse_revision_from_pr_url(pr_url: str) -> str: duplicate_space = api.duplicate_space request_space_storage = api.request_space_storage delete_space_storage = api.delete_space_storage + +# Collections API +get_collection = api.get_collection +create_collection = api.create_collection +update_collection_metadata = api.update_collection_metadata +delete_collection = api.delete_collection +add_collection_item = api.add_collection_item +update_collection_item = api.update_collection_item +delete_collection_item = api.delete_collection_item +delete_collection_item = api.delete_collection_item diff --git a/src/huggingface_hub/utils/_errors.py b/src/huggingface_hub/utils/_errors.py index 47f66d8b4c..586aa2acea 100644 --- a/src/huggingface_hub/utils/_errors.py +++ b/src/huggingface_hub/utils/_errors.py @@ -283,6 +283,15 @@ def hf_raise_for_status(response: Response, endpoint_name: Optional[str] = None) ) raise GatedRepoError(message, response) from e + elif ( + response.status_code == 401 + and response.request.url is not None + and "/api/collections" in response.request.url + ): + # Collection not found. We don't raise a custom error for this. + # This prevent from raising a misleading `RepositoryNotFoundError` (see below). 
+ pass + elif error_code == "RepoNotFound" or response.status_code == 401: # 401 is misleading as it is returned for: # - private and gated repos if user is not authenticated diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index f37e992170..6fb7110fd0 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -19,12 +19,13 @@ import time import types import unittest +import uuid import warnings from concurrent.futures import Future from functools import partial from io import BytesIO from pathlib import Path -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Optional, Union from unittest.mock import Mock, patch from urllib.parse import quote @@ -49,6 +50,7 @@ ) from huggingface_hub.file_download import hf_hub_download from huggingface_hub.hf_api import ( + Collection, CommitInfo, DatasetInfo, HfApi, @@ -3199,3 +3201,152 @@ def __init__(self, **kwargs: Dict[str, Any]) -> None: repr(MyClass(foo="foo", bar="bar")), "MyClass: {'bar': 'bar', 'foo': 'foo'}", # keys are sorted ) + + +class CollectionAPITest(HfApiCommonTest): + def setUp(self) -> None: + id = uuid.uuid4() + self.title = f"My cool stuff {id}" + self.slug_prefix = f"{USER}/my-cool-stuff-{id}" + self.slug: Optional[str] = None # Populated by the tests => use to delete in tearDown + return super().setUp() + + def tearDown(self) -> None: + if self.slug is not None: # Delete collection even if test failed + self._api.delete_collection(self.slug, missing_ok=True) + return super().tearDown() + + def test_create_collection_with_description(self) -> None: + collection = self._api.create_collection(self.title, description="Contains a lot of cool stuff") + self.slug = collection.slug + + self.assertIsInstance(collection, Collection) + self.assertEqual(collection.title, self.title) + self.assertEqual(collection.description, "Contains a lot of cool stuff") + self.assertEqual(collection.items, []) + self.assertTrue(collection.slug.startswith(self.slug_prefix)) + + def 
test_create_collection_exists_ok(self) -> None: + # Create collection once without description + collection_1 = self._api.create_collection(self.title) + self.slug = collection_1.slug + + # Cannot create twice with same title + with self.assertRaises(HTTPError): # already exists + self._api.create_collection(self.title) + + # Can ignore error + collection_2 = self._api.create_collection(self.title, description="description", exists_ok=True) + + self.assertEqual(collection_1.slug, collection_2.slug) + self.assertIsNone(collection_1.description) + self.assertIsNone(collection_2.description) # Did not get updated! + + def test_create_private_collection(self) -> None: + collection = self._api.create_collection(self.title, private=True) + self.slug = collection.slug + + # Get private collection + self._api.get_collection(collection.slug) # no error + with self.assertRaises(HTTPError): + self._api.get_collection(collection.slug, token=OTHER_TOKEN) # not authorized + + # Get public collection + self._api.update_collection_metadata(collection.slug, private=False) + self._api.get_collection(collection.slug) # no error + self._api.get_collection(collection.slug, token=OTHER_TOKEN) # no error + + def test_update_collection(self) -> None: + # Create collection + collection_1 = self._api.create_collection(self.title) + self.slug = collection_1.slug + + # Update metadata + new_title = f"New title {uuid.uuid4()}" + collection_2 = self._api.update_collection_metadata( + collection_slug=collection_1.slug, + title=new_title, + description="New description", + private=True, + theme="pink", + ) + + self.assertEqual(collection_2.title, new_title) + self.assertEqual(collection_2.description, "New description") + self.assertEqual(collection_2.private, True) + self.assertEqual(collection_2.theme, "pink") + self.assertNotEqual(collection_1.slug, collection_2.slug) + + # Different slug, same id + self.assertEqual(collection_1.slug.split("-")[-1], collection_2.slug.split("-")[-1]) + + #
Works with both slugs, same collection returned + self.assertEqual(self._api.get_collection(collection_1.slug).slug, collection_2.slug) + self.assertEqual(self._api.get_collection(collection_2.slug).slug, collection_2.slug) + + def test_delete_collection(self) -> None: + collection = self._api.create_collection(self.title) + + self._api.delete_collection(collection.slug) + + # Cannot delete the same collection twice + with self.assertRaises(HTTPError): # already deleted + self._api.delete_collection(collection.slug) + + # Possible to ignore error + self._api.delete_collection(collection.slug, missing_ok=True) + + def test_collection_items(self) -> None: + # Create some repos + model_id = self._api.create_repo(repo_name()).repo_id + dataset_id = self._api.create_repo(repo_name(), repo_type="dataset").repo_id + + # Create collection + add items to it + collection = self._api.create_collection(self.title) + self._api.add_collection_item(collection.slug, model_id, "model", note="This is my model") + self._api.add_collection_item(collection.slug, dataset_id, "dataset") # note is optional + + # Check consistency + collection = self._api.get_collection(collection.slug) + self.assertEqual(len(collection.items), 2) + self.assertEqual(collection.items[0].item_id, model_id) + self.assertEqual(collection.items[0].item_type, "model") + self.assertEqual(collection.items[0].note, "This is my model") + + self.assertEqual(collection.items[1].item_id, dataset_id) + self.assertEqual(collection.items[1].item_type, "dataset") + self.assertIsNone(collection.items[1].note) + + # Add existing item fails (except if ignore error) + with self.assertRaises(HTTPError): + self._api.add_collection_item(collection.slug, model_id, "model") + self._api.add_collection_item(collection.slug, model_id, "model", exists_ok=True) + + # Adding a nonexistent item fails + with self.assertRaises(HTTPError): + self._api.add_collection_item(collection.slug, model_id, "dataset") + + # Update first item +
self._api.update_collection_item( + collection.slug, collection.items[0].item_object_id, note="New note", position=1 + ) + + # Check consistency + collection = self._api.get_collection(collection.slug) + self.assertEqual(collection.items[0].item_id, dataset_id) # position got updated + self.assertEqual(collection.items[1].item_id, model_id) + self.assertEqual(collection.items[1].note, "New note") # note got updated + + # Delete last item + self._api.delete_collection_item(collection.slug, collection.items[1].item_object_id) + self._api.delete_collection_item(collection.slug, collection.items[1].item_object_id, missing_ok=True) + + # Check consistency + collection = self._api.get_collection(collection.slug) + self.assertEqual(len(collection.items), 1) # only 1 item remaining + self.assertEqual(collection.items[0].item_id, dataset_id) # position got updated + + # Delete everything + self._api.delete_repo(model_id) + self._api.delete_repo(dataset_id, repo_type="dataset") + self._api.delete_collection(collection.slug) diff --git a/tests/test_utils_errors.py b/tests/test_utils_errors.py index af469f9c29..b0362a1998 100644 --- a/tests/test_utils_errors.py +++ b/tests/test_utils_errors.py @@ -1,6 +1,6 @@ import unittest -from requests.models import Response +from requests.models import PreparedRequest, Response from huggingface_hub.utils._errors import ( BadRequestError, @@ -27,6 +27,7 @@ def test_hf_raise_for_status_repo_not_found_without_error_code(self) -> None: response = Response() response.headers = {"X-Request-Id": 123} response.status_code = 401 + response.request = PreparedRequest() with self.assertRaisesRegex(RepositoryNotFoundError, "Repository Not Found") as context: hf_raise_for_status(response)