From ec5812dc6e85c32539e70695983e0dc07ecc7741 Mon Sep 17 00:00:00 2001 From: Swapnil Jikar <112884653+WizKnight@users.noreply.github.com> Date: Wed, 21 Aug 2024 18:23:58 +0530 Subject: [PATCH] Update constants imports with module level access #1172 (#2469) * Update constants imports with module level access and backward compatibility #1172 * Use constants.ENDPOINT and fix patch * remove useless import --------- Co-authored-by: Lucain Pouget Co-authored-by: Lucain --- src/huggingface_hub/hf_api.py | 267 +++++++++++++++++----------------- tests/test_hf_api.py | 46 +++--- tests/testing_utils.py | 2 +- 3 files changed, 157 insertions(+), 158 deletions(-) diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index 2cebc0b187..78619f1707 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -47,6 +47,7 @@ from tqdm.auto import tqdm as base_tqdm from tqdm.contrib.concurrent import thread_map +from . import constants from ._commit_api import ( CommitOperation, CommitOperationAdd, @@ -249,7 +250,7 @@ def repo_type_and_id_from_hf_id(hf_id: str, hub_url: Optional[str] = None) -> Tu """ input_hf_id = hf_id - hub_url = re.sub(r"https?://", "", hub_url if hub_url is not None else ENDPOINT) + hub_url = re.sub(r"https?://", "", hub_url if hub_url is not None else constants.ENDPOINT) is_hf_url = hub_url in hf_id and "@" not in hf_id HFFS_PREFIX = "hf://" @@ -266,9 +267,9 @@ def repo_type_and_id_from_hf_id(hf_id: str, hub_url: Optional[str] = None) -> Tu namespace = None if len(url_segments) > 2 and hub_url not in url_segments[-3]: repo_type = url_segments[-3] - elif namespace in REPO_TYPES_MAPPING: + elif namespace in constants.REPO_TYPES_MAPPING: # Mean canonical dataset or model - repo_type = REPO_TYPES_MAPPING[namespace] + repo_type = constants.REPO_TYPES_MAPPING[namespace] namespace = None else: repo_type = None @@ -277,9 +278,9 @@ def repo_type_and_id_from_hf_id(hf_id: str, hub_url: Optional[str] = None) -> Tu # Passed // or // repo_type, namespace, repo_id = url_segments[-3:] elif len(url_segments) == 2: - if url_segments[0] in REPO_TYPES_MAPPING: + if url_segments[0] in constants.REPO_TYPES_MAPPING: # Passed '' or 'datasets/' for a canonical model or dataset - repo_type = REPO_TYPES_MAPPING[url_segments[0]] + repo_type = constants.REPO_TYPES_MAPPING[url_segments[0]] namespace = None repo_id = hf_id.split("/")[-1] else: @@ -294,11 +295,11 @@ def repo_type_and_id_from_hf_id(hf_id: str, hub_url: Optional[str] = None) -> Tu raise ValueError(f"Unable to retrieve user and repo ID from the passed HF ID: {hf_id}") # Check if repo type is known (mapping "spaces" => "space" + empty value => `None`) - if repo_type in REPO_TYPES_MAPPING: - repo_type = REPO_TYPES_MAPPING[repo_type] + if repo_type in constants.REPO_TYPES_MAPPING: + repo_type = constants.REPO_TYPES_MAPPING[repo_type] if repo_type == "": repo_type = None - if repo_type not in REPO_TYPES: + if repo_type not in constants.REPO_TYPES: raise ValueError(f"Unknown `repo_type`: '{repo_type}' ('{input_hf_id}')") return repo_type, namespace, repo_id @@ -492,7 +493,7 @@ class WebhookInfo: id: str url: str watched: List[WebhookWatchedItem] - domains: List[WEBHOOK_DOMAIN_T] + domains: List[constants.WEBHOOK_DOMAIN_T] secret: Optional[str] disabled: bool @@ -544,14 +545,14 @@ def __new__(cls, url: Any, endpoint: Optional[str] = None): def __init__(self, url: Any, endpoint: Optional[str] = None) -> None: super().__init__() # Parse URL - self.endpoint = endpoint or ENDPOINT + self.endpoint = endpoint or constants.ENDPOINT repo_type, namespace, repo_name = repo_type_and_id_from_hf_id(self, hub_url=self.endpoint) # Populate fields self.namespace = namespace self.repo_name = repo_name self.repo_id = repo_name if namespace is None else f"{namespace}/{repo_name}" - self.repo_type = repo_type or REPO_TYPE_MODEL + self.repo_type = repo_type or constants.REPO_TYPE_MODEL self.url = str(self) # just in case it's needed def __repr__(self) -> str: @@ -1188,7 +1189,7 @@ def __init__(self, **kwargs) -> None: self.description = kwargs.pop("description", None) endpoint = kwargs.pop("endpoint", None) if endpoint is None: - endpoint = ENDPOINT + endpoint = constants.ENDPOINT self._url = f"{endpoint}/collections/{self.slug}" @property @@ -1472,7 +1473,7 @@ def __init__( Additional headers to be sent with each request. Example: `{"X-My-Header": "value"}`. Headers passed here are taking precedence over the default headers. """ - self.endpoint = endpoint if endpoint is not None else ENDPOINT + self.endpoint = endpoint if endpoint is not None else constants.ENDPOINT self.token = token self.library_name = library_name self.library_version = library_version @@ -2159,7 +2160,7 @@ def like( ``` """ if repo_type is None: - repo_type = REPO_TYPE_MODEL + repo_type = constants.REPO_TYPE_MODEL response = get_session().post( url=f"{self.endpoint}/api/{repo_type}s/{repo_id}/like", headers=self._build_hf_headers(token=token), @@ -2210,7 +2211,7 @@ def unlike( ``` """ if repo_type is None: - repo_type = REPO_TYPE_MODEL + repo_type = constants.REPO_TYPE_MODEL response = get_session().delete( url=f"{self.endpoint}/api/{repo_type}s/{repo_id}/like", headers=self._build_hf_headers(token=token) ) @@ -2326,7 +2327,7 @@ def list_repo_likers( # Construct the API endpoint if repo_type is None: - repo_type = REPO_TYPE_MODEL + repo_type = constants.REPO_TYPE_MODEL path = f"{self.endpoint}/api/{repo_type}s/{repo_id}/likers" headers = self._build_hf_headers(token=token) @@ -2952,8 +2953,8 @@ def list_repo_tree( ] ``` """ - repo_type = repo_type or REPO_TYPE_MODEL - revision = quote(revision, safe="") if revision is not None else DEFAULT_REVISION + repo_type = repo_type or constants.REPO_TYPE_MODEL + revision = quote(revision, safe="") if revision is not None else constants.DEFAULT_REVISION headers = self._build_hf_headers(token=token) encoded_path_in_repo = "/" + quote(path_in_repo, safe="") if path_in_repo else "" @@ -3012,7 +3013,7 @@ def list_repo_refs( [`GitRefs`]: object containing all information about branches and tags for a repo on the Hub. """ - repo_type = repo_type or REPO_TYPE_MODEL + repo_type = repo_type or constants.REPO_TYPE_MODEL response = get_session().get( f"{self.endpoint}/api/{repo_type}s/{repo_id}/refs", headers=self._build_hf_headers(token=token), @@ -3098,8 +3099,8 @@ def list_repo_commits( [`~utils.RevisionNotFoundError`]: If revision is not found (error 404) on the repo. """ - repo_type = repo_type or REPO_TYPE_MODEL - revision = quote(revision, safe="") if revision is not None else DEFAULT_REVISION + repo_type = repo_type or constants.REPO_TYPE_MODEL + revision = quote(revision, safe="") if revision is not None else constants.DEFAULT_REVISION # Paginate over results and return the list of commits. return [ @@ -3177,8 +3178,8 @@ def get_paths_info( ] ``` """ - repo_type = repo_type or REPO_TYPE_MODEL - revision = quote(revision, safe="") if revision is not None else DEFAULT_REVISION + repo_type = repo_type or constants.REPO_TYPE_MODEL + revision = quote(revision, safe="") if revision is not None else constants.DEFAULT_REVISION headers = self._build_hf_headers(token=token) response = get_session().post( @@ -3267,11 +3268,11 @@ def super_squash_history( ``` """ if repo_type is None: - repo_type = REPO_TYPE_MODEL - if repo_type not in REPO_TYPES: + repo_type = constants.REPO_TYPE_MODEL + if repo_type not in constants.REPO_TYPES: raise ValueError("Invalid repo type") if branch is None: - branch = DEFAULT_REVISION + branch = constants.DEFAULT_REVISION # Prepare request url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/super-squash/{branch}" @@ -3349,7 +3350,7 @@ def create_repo( path = f"{self.endpoint}/api/repos/create" - if repo_type not in REPO_TYPES: + if repo_type not in constants.REPO_TYPES: raise ValueError("Invalid repo type") json: Dict[str, Any] = {"name": name, "organization": organization, "private": private} @@ -3359,10 +3360,10 @@ def create_repo( if space_sdk is None: raise ValueError( "No space_sdk provided. `create_repo` expects space_sdk to be one" - f" of {SPACES_SDK_TYPES} when repo_type is 'space'`" + f" of {constants.SPACES_SDK_TYPES} when repo_type is 'space'`" ) - if space_sdk not in SPACES_SDK_TYPES: - raise ValueError(f"Invalid space_sdk. Please choose one of {SPACES_SDK_TYPES}.") + if space_sdk not in constants.SPACES_SDK_TYPES: + raise ValueError(f"Invalid space_sdk. Please choose one of {constants.SPACES_SDK_TYPES}.") json["sdk"] = space_sdk if space_sdk is not None and repo_type != "space": @@ -3418,7 +3419,7 @@ def create_repo( # No write permission on the namespace but repo might already exist try: self.repo_info(repo_id=repo_id, repo_type=repo_type, token=token) - if repo_type is None or repo_type == REPO_TYPE_MODEL: + if repo_type is None or repo_type == constants.REPO_TYPE_MODEL: return RepoUrl(f"{self.endpoint}/{repo_id}") return RepoUrl(f"{self.endpoint}/{repo_type}/{repo_id}") except HfHubHTTPError: @@ -3464,7 +3465,7 @@ def delete_repo( path = f"{self.endpoint}/api/repos/delete" - if repo_type not in REPO_TYPES: + if repo_type not in constants.REPO_TYPES: raise ValueError("Invalid repo type") json = {"name": name, "organization": organization} @@ -3518,10 +3519,10 @@ def update_repo_visibility( """ - if repo_type not in REPO_TYPES: - raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}") + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") if repo_type is None: - repo_type = REPO_TYPE_MODEL # default repo type + repo_type = constants.REPO_TYPE_MODEL # default repo type r = get_session().put( url=f"{self.endpoint}/api/{repo_type}s/{repo_id}/settings", @@ -3580,7 +3581,7 @@ def move_repo( raise ValueError(f"Invalid repo_id: {to_id}. It should have a namespace (:namespace:/:repo_name:)") if repo_type is None: - repo_type = REPO_TYPE_MODEL # Hub won't accept `None`. + repo_type = constants.REPO_TYPE_MODEL # Hub won't accept `None`. json = {"fromRepo": from_id, "toRepo": to_id, "type": repo_type} @@ -3750,19 +3751,19 @@ def create_commit( If repository is not found (error 404): wrong repo_id/repo_type, private but not authenticated or repo does not exist. """ - if parent_commit is not None and not REGEX_COMMIT_OID.fullmatch(parent_commit): + if parent_commit is not None and not constants.REGEX_COMMIT_OID.fullmatch(parent_commit): raise ValueError( - f"`parent_commit` is not a valid commit OID. It must match the following regex: {REGEX_COMMIT_OID}" + f"`parent_commit` is not a valid commit OID. It must match the following regex: {constants.REGEX_COMMIT_OID}" ) if commit_message is None or len(commit_message) == 0: raise ValueError("`commit_message` can't be empty, please pass a value.") commit_description = commit_description if commit_description is not None else "" - repo_type = repo_type if repo_type is not None else REPO_TYPE_MODEL - if repo_type not in REPO_TYPES: - raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}") - unquoted_revision = revision or DEFAULT_REVISION + repo_type = repo_type if repo_type is not None else constants.REPO_TYPE_MODEL + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") + unquoted_revision = revision or constants.DEFAULT_REVISION revision = quote(unquoted_revision, safe="") create_pr = create_pr if create_pr is not None else False @@ -3850,7 +3851,7 @@ def create_commit( # Return commit info based on latest commit url_prefix = self.endpoint - if repo_type is not None and repo_type != REPO_TYPE_MODEL: + if repo_type is not None and repo_type != constants.REPO_TYPE_MODEL: url_prefix = f"{url_prefix}/{repo_type}s" return CommitInfo( commit_url=f"{url_prefix}/{repo_id}/commit/{info.sha}", @@ -4085,7 +4086,7 @@ def create_commits_on_pr( commits_on_main_branch = { commit.commit_id for commit in self.list_repo_commits( - repo_id=repo_id, repo_type=repo_type, token=token, revision=DEFAULT_REVISION + repo_id=repo_id, repo_type=repo_type, token=token, revision=constants.DEFAULT_REVISION ) } pr_commits = [ @@ -4304,10 +4305,10 @@ def preupload_lfs_files( >>> create_commit(repo_id, operations=operations, commit_message="Commit all shards") ``` """ - repo_type = repo_type if repo_type is not None else REPO_TYPE_MODEL - if repo_type not in REPO_TYPES: - raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}") - revision = quote(revision, safe="") if revision is not None else DEFAULT_REVISION + repo_type = repo_type if repo_type is not None else constants.REPO_TYPE_MODEL + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") + revision = quote(revision, safe="") if revision is not None else constants.DEFAULT_REVISION create_pr = create_pr if create_pr is not None else False headers = self._build_hf_headers(token=token) @@ -4535,8 +4536,8 @@ def upload_file( "https://huggingface.co/username/my-model/blob/refs%2Fpr%2F1/remote/file/path.h5" ``` """ - if repo_type not in REPO_TYPES: - raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}") + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") commit_message = ( commit_message if commit_message is not None else f"Upload {path_in_repo} with huggingface_hub" @@ -4560,9 +4561,9 @@ def upload_file( if commit_info.pr_url is not None: revision = quote(_parse_revision_from_pr_url(commit_info.pr_url), safe="") - if repo_type in REPO_TYPES_URL_PREFIXES: - repo_id = REPO_TYPES_URL_PREFIXES[repo_type] + repo_id - revision = revision if revision is not None else DEFAULT_REVISION + if repo_type in constants.REPO_TYPES_URL_PREFIXES: + repo_id = constants.REPO_TYPES_URL_PREFIXES[repo_type] + repo_id + revision = revision if revision is not None else constants.DEFAULT_REVISION return CommitInfo( commit_url=commit_info.commit_url, @@ -4837,11 +4838,11 @@ def upload_folder( ``` """ - if repo_type not in REPO_TYPES: - raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}") + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") if multi_commits: - if revision is not None and revision != DEFAULT_REVISION: + if revision is not None and revision != constants.DEFAULT_REVISION: raise ValueError("Cannot use `multi_commit` to commit changes other than the main branch.") # By default, upload folder to the root directory in repo. @@ -4858,7 +4859,7 @@ def upload_folder( delete_operations = self._prepare_folder_deletions( repo_id=repo_id, repo_type=repo_type, - revision=DEFAULT_REVISION if create_pr else revision, + revision=constants.DEFAULT_REVISION if create_pr else revision, token=token, path_in_repo=path_in_repo, delete_patterns=delete_patterns, @@ -4913,9 +4914,9 @@ def upload_folder( # Create url to uploaded folder (for legacy return value) if create_pr and commit_info.pr_url is not None: revision = quote(_parse_revision_from_pr_url(commit_info.pr_url), safe="") - if repo_type in REPO_TYPES_URL_PREFIXES: - repo_id = REPO_TYPES_URL_PREFIXES[repo_type] + repo_id - revision = revision if revision is not None else DEFAULT_REVISION + if repo_type in constants.REPO_TYPES_URL_PREFIXES: + repo_id = constants.REPO_TYPES_URL_PREFIXES[repo_type] + repo_id + revision = revision if revision is not None else constants.DEFAULT_REVISION return CommitInfo( commit_url=commit_info.commit_url, @@ -5169,7 +5170,7 @@ def get_hf_file_metadata( url: str, token: Union[bool, str, None] = None, proxies: Optional[Dict] = None, - timeout: Optional[float] = DEFAULT_REQUEST_TIMEOUT, + timeout: Optional[float] = constants.DEFAULT_REQUEST_TIMEOUT, ) -> HfFileMetadata: """Fetch metadata of a file versioned on the Hub for a given url. @@ -5216,7 +5217,7 @@ def hf_hub_download( local_dir: Union[str, Path, None] = None, force_download: bool = False, proxies: Optional[Dict] = None, - etag_timeout: float = DEFAULT_ETAG_TIMEOUT, + etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT, token: Union[bool, str, None] = None, local_files_only: bool = False, # Deprecated args @@ -5355,7 +5356,7 @@ def snapshot_download( cache_dir: Union[str, Path, None] = None, local_dir: Union[str, Path, None] = None, proxies: Optional[Dict] = None, - etag_timeout: float = DEFAULT_ETAG_TIMEOUT, + etag_timeout: float = constants.DEFAULT_ETAG_TIMEOUT, force_download: bool = False, token: Union[bool, str, None] = None, local_files_only: bool = False, @@ -5545,14 +5546,14 @@ def get_safetensors_metadata( """ if self.file_exists( # Single safetensors file => non-sharded model repo_id=repo_id, - filename=SAFETENSORS_SINGLE_FILE, + filename=constants.SAFETENSORS_SINGLE_FILE, repo_type=repo_type, revision=revision, token=token, ): file_metadata = self.parse_safetensors_file_metadata( repo_id=repo_id, - filename=SAFETENSORS_SINGLE_FILE, + filename=constants.SAFETENSORS_SINGLE_FILE, repo_type=repo_type, revision=revision, token=token, @@ -5560,12 +5561,14 @@ def get_safetensors_metadata( return SafetensorsRepoMetadata( metadata=None, sharded=False, - weight_map={tensor_name: SAFETENSORS_SINGLE_FILE for tensor_name in file_metadata.tensors.keys()}, - files_metadata={SAFETENSORS_SINGLE_FILE: file_metadata}, + weight_map={ + tensor_name: constants.SAFETENSORS_SINGLE_FILE for tensor_name in file_metadata.tensors.keys() + }, + files_metadata={constants.SAFETENSORS_SINGLE_FILE: file_metadata}, ) elif self.file_exists( # Multiple safetensors files => sharded with index repo_id=repo_id, - filename=SAFETENSORS_INDEX_FILE, + filename=constants.SAFETENSORS_INDEX_FILE, repo_type=repo_type, revision=revision, token=token, @@ -5573,7 +5576,7 @@ def get_safetensors_metadata( # Fetch index index_file = self.hf_hub_download( repo_id=repo_id, - filename=SAFETENSORS_INDEX_FILE, + filename=constants.SAFETENSORS_INDEX_FILE, repo_type=repo_type, revision=revision, token=token, @@ -5607,7 +5610,7 @@ def _parse(filename: str) -> None: else: # Not a safetensors repo raise NotASafetensorsRepoError( - f"'{repo_id}' is not a safetensors repo. Couldn't find '{SAFETENSORS_INDEX_FILE}' or '{SAFETENSORS_SINGLE_FILE}' files." + f"'{repo_id}' is not a safetensors repo. Couldn't find '{constants.SAFETENSORS_INDEX_FILE}' or '{constants.SAFETENSORS_SINGLE_FILE}' files." ) def parse_safetensors_file_metadata( @@ -5668,11 +5671,11 @@ def parse_safetensors_file_metadata( # 2. Parse metadata size metadata_size = struct.unpack(" SAFETENSORS_MAX_HEADER_LENGTH: + if metadata_size > constants.SAFETENSORS_MAX_HEADER_LENGTH: raise SafetensorsParsingError( f"Failed to parse safetensors header for '{filename}' (repo '{repo_id}', revision " - f"'{revision or DEFAULT_REVISION}'): safetensors header is too big. Maximum supported size is " - f"{SAFETENSORS_MAX_HEADER_LENGTH} bytes (got {metadata_size})." + f"'{revision or constants.DEFAULT_REVISION}'): safetensors header is too big. Maximum supported size is " + f"{constants.SAFETENSORS_MAX_HEADER_LENGTH} bytes (got {metadata_size})." ) # 3.a. Get metadata from payload @@ -5689,7 +5692,7 @@ def parse_safetensors_file_metadata( except json.JSONDecodeError as e: raise SafetensorsParsingError( f"Failed to parse safetensors header for '{filename}' (repo '{repo_id}', revision " - f"'{revision or DEFAULT_REVISION}'): header is not json-encoded string. Please make sure this is a " + f"'{revision or constants.DEFAULT_REVISION}'): header is not json-encoded string. Please make sure this is a " "correctly formatted safetensors file." ) from e @@ -5709,7 +5712,7 @@ def parse_safetensors_file_metadata( except (KeyError, IndexError) as e: raise SafetensorsParsingError( f"Failed to parse safetensors header for '{filename}' (repo '{repo_id}', revision " - f"'{revision or DEFAULT_REVISION}'): header format not recognized. Please make sure this is a correctly" + f"'{revision or constants.DEFAULT_REVISION}'): header format not recognized. Please make sure this is a correctly" " formatted safetensors file." ) from e @@ -5765,7 +5768,7 @@ def create_branch( set to `False`. """ if repo_type is None: - repo_type = REPO_TYPE_MODEL + repo_type = constants.REPO_TYPE_MODEL branch = quote(branch, safe="") # Prepare request @@ -5834,7 +5837,7 @@ def delete_branch( """ if repo_type is None: - repo_type = REPO_TYPE_MODEL + repo_type = constants.REPO_TYPE_MODEL branch = quote(branch, safe="") # Prepare request @@ -5901,8 +5904,8 @@ def create_tag( set to `False`. """ if repo_type is None: - repo_type = REPO_TYPE_MODEL - revision = quote(revision, safe="") if revision is not None else DEFAULT_REVISION + repo_type = constants.REPO_TYPE_MODEL + revision = quote(revision, safe="") if revision is not None else constants.DEFAULT_REVISION # Prepare request tag_url = f"{self.endpoint}/api/{repo_type}s/{repo_id}/tag/{revision}" @@ -5957,7 +5960,7 @@ def delete_tag( If tag is not found. """ if repo_type is None: - repo_type = REPO_TYPE_MODEL + repo_type = constants.REPO_TYPE_MODEL tag = quote(tag, safe="") # Prepare request @@ -6012,8 +6015,8 @@ def get_repo_discussions( repo_id: str, *, author: Optional[str] = None, - discussion_type: Optional[DiscussionTypeFilter] = None, - discussion_status: Optional[DiscussionStatusFilter] = None, + discussion_type: Optional[constants.DiscussionTypeFilter] = None, + discussion_status: Optional[constants.DiscussionStatusFilter] = None, repo_type: Optional[str] = None, token: Union[bool, str, None] = None, ) -> Iterator[Discussion]: @@ -6065,14 +6068,14 @@ def get_repo_discussions( ... print(discussion.num, discussion.title) ``` """ - if repo_type not in REPO_TYPES: - raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}") + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") if repo_type is None: - repo_type = REPO_TYPE_MODEL - if discussion_type is not None and discussion_type not in DISCUSSION_TYPES: - raise ValueError(f"Invalid discussion_type, must be one of {DISCUSSION_TYPES}") - if discussion_status is not None and discussion_status not in DISCUSSION_STATUS: - raise ValueError(f"Invalid discussion_status, must be one of {DISCUSSION_STATUS}") + repo_type = constants.REPO_TYPE_MODEL + if discussion_type is not None and discussion_type not in constants.DISCUSSION_TYPES: + raise ValueError(f"Invalid discussion_type, must be one of {constants.DISCUSSION_TYPES}") + if discussion_status is not None and discussion_status not in constants.DISCUSSION_STATUS: + raise ValueError(f"Invalid discussion_status, must be one of {constants.DISCUSSION_STATUS}") headers = self._build_hf_headers(token=token) path = f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions" @@ -6159,10 +6162,10 @@ def get_discussion_details( """ if not isinstance(discussion_num, int) or discussion_num <= 0: raise ValueError("Invalid discussion_num, must be a positive integer") - if repo_type not in REPO_TYPES: - raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}") + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") if repo_type is None: - repo_type = REPO_TYPE_MODEL + repo_type = constants.REPO_TYPE_MODEL path = f"{self.endpoint}/api/{repo_type}s/{repo_id}/discussions/{discussion_num}" headers = self._build_hf_headers(token=token) @@ -6249,10 +6252,10 @@ def create_discussion( or because it is set to `private` and you do not have access. """ - if repo_type not in REPO_TYPES: - raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}") + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") if repo_type is None: - repo_type = REPO_TYPE_MODEL + repo_type = constants.REPO_TYPE_MODEL if description is not None: description = description.strip() @@ -6359,10 +6362,10 @@ def _post_discussion_changes( """Internal utility to POST changes to a Discussion or Pull Request""" if not isinstance(discussion_num, int) or discussion_num <= 0: raise ValueError("Invalid discussion_num, must be a positive integer") - if repo_type not in REPO_TYPES: - raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}") + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") if repo_type is None: - repo_type = REPO_TYPE_MODEL + repo_type = constants.REPO_TYPE_MODEL repo_id = f"{repo_type}s/{repo_id}" path = f"{self.endpoint}/api/{repo_id}/discussions/{discussion_num}/{resource}" @@ -7386,7 +7389,7 @@ def list_inference_endpoints( namespace = namespace or self._get_namespace(token=token) response = get_session().get( - f"{INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}", + f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}", headers=self._build_hf_headers(token=token), ) hf_raise_for_status(response) @@ -7549,7 +7552,7 @@ def create_inference_endpoint( } response = get_session().post( - f"{INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}", + f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}", headers=self._build_hf_headers(token=token), json=payload, ) @@ -7597,7 +7600,7 @@ def get_inference_endpoint( namespace = namespace or self._get_namespace(token=token) response = get_session().get( - f"{INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}", + f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}", headers=self._build_hf_headers(token=token), ) hf_raise_for_status(response) @@ -7700,7 +7703,7 @@ def update_inference_endpoint( } response = get_session().put( - f"{INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}", + f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}", headers=self._build_hf_headers(token=token), json=payload, ) @@ -7731,7 +7734,7 @@ def delete_inference_endpoint( """ namespace = namespace or self._get_namespace(token=token) response = get_session().delete( - f"{INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}", + f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}", headers=self._build_hf_headers(token=token), ) hf_raise_for_status(response) @@ -7764,7 +7767,7 @@ def pause_inference_endpoint( namespace = namespace or self._get_namespace(token=token) response = get_session().post( - f"{INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}/pause", + f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}/pause", headers=self._build_hf_headers(token=token), ) hf_raise_for_status(response) @@ -7803,7 +7806,7 @@ def resume_inference_endpoint( namespace = namespace or self._get_namespace(token=token) response = get_session().post( - f"{INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}/resume", + f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}/resume", headers=self._build_hf_headers(token=token), ) try: @@ -7845,7 +7848,7 @@ def scale_to_zero_inference_endpoint( namespace = namespace or self._get_namespace(token=token) response = get_session().post( - f"{INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}/scale-to-zero", + f"{constants.INFERENCE_ENDPOINTS_ENDPOINT}/endpoint/{namespace}/{name}/scale-to-zero", headers=self._build_hf_headers(token=token), ) hf_raise_for_status(response) @@ -8527,13 +8530,13 @@ def _list_access_requests( repo_type: Optional[str] = None, token: Union[bool, str, None] = None, ) -> List[AccessRequest]: - if repo_type not in REPO_TYPES: - raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}") + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") if repo_type is None: - repo_type = REPO_TYPE_MODEL + repo_type = constants.REPO_TYPE_MODEL response = get_session().get( - f"{ENDPOINT}/api/{repo_type}s/{repo_id}/user-access-request/{status}", + f"{constants.ENDPOINT}/api/{repo_type}s/{repo_id}/user-access-request/{status}", headers=self._build_hf_headers(token=token), ) hf_raise_for_status(response) @@ -8682,13 +8685,13 @@ def _handle_access_request( repo_type: Optional[str] = None, token: Union[bool, str, None] = None, ) -> None: - if repo_type not in REPO_TYPES: - raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}") + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") if repo_type is None: - repo_type = REPO_TYPE_MODEL + repo_type = constants.REPO_TYPE_MODEL response = get_session().post( - f"{ENDPOINT}/api/{repo_type}s/{repo_id}/user-access-request/handle", + f"{constants.ENDPOINT}/api/{repo_type}s/{repo_id}/user-access-request/handle", headers=self._build_hf_headers(token=token), json={"user": user, "status": status}, ) @@ -8732,13 +8735,13 @@ def grant_access( [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): HTTP 404 if the user does not exist on the Hub. """ - if repo_type not in REPO_TYPES: - raise ValueError(f"Invalid repo type, must be one of {REPO_TYPES}") + if repo_type not in constants.REPO_TYPES: + raise ValueError(f"Invalid repo type, must be one of {constants.REPO_TYPES}") if repo_type is None: - repo_type = REPO_TYPE_MODEL + repo_type = constants.REPO_TYPE_MODEL response = get_session().post( - f"{ENDPOINT}/api/models/{repo_id}/user-access-request/grant", + f"{constants.ENDPOINT}/api/models/{repo_id}/user-access-request/grant", headers=self._build_hf_headers(token=token), json={"user": user}, ) @@ -8781,7 +8784,7 @@ def get_webhook(self, webhook_id: str, *, token: Union[bool, str, None] = None) ``` """ response = get_session().get( - f"{ENDPOINT}/api/settings/webhooks/{webhook_id}", + f"{constants.ENDPOINT}/api/settings/webhooks/{webhook_id}", headers=self._build_hf_headers(token=token), ) hf_raise_for_status(response) @@ -8832,7 +8835,7 @@ def list_webhooks(self, *, token: Union[bool, str, None] = None) -> List[Webhook ``` """ response = get_session().get( - f"{ENDPOINT}/api/settings/webhooks", + f"{constants.ENDPOINT}/api/settings/webhooks", headers=self._build_hf_headers(token=token), ) hf_raise_for_status(response) @@ -8856,7 +8859,7 @@ def create_webhook( *, url: str, watched: List[Union[Dict, WebhookWatchedItem]], - domains: Optional[List[WEBHOOK_DOMAIN_T]] = None, + domains: Optional[List[constants.WEBHOOK_DOMAIN_T]] = None, secret: Optional[str] = None, token: Union[bool, str, None] = None, ) -> WebhookInfo: @@ -8904,7 +8907,7 @@ def create_webhook( watched_dicts = [asdict(item) if isinstance(item, WebhookWatchedItem) else item for item in watched] response = get_session().post( - f"{ENDPOINT}/api/settings/webhooks", + f"{constants.ENDPOINT}/api/settings/webhooks", json={"watched": watched_dicts, "url": url, "domains": domains, "secret": secret}, headers=self._build_hf_headers(token=token), ) @@ -8930,7 +8933,7 @@ def update_webhook( *, url: Optional[str] = None, watched: Optional[List[Union[Dict, WebhookWatchedItem]]] = None, - domains: Optional[List[WEBHOOK_DOMAIN_T]] = None, + domains: Optional[List[constants.WEBHOOK_DOMAIN_T]] = None, secret: Optional[str] = None, token: Union[bool, str, None] = None, ) -> WebhookInfo: @@ -8982,7 +8985,7 @@ def update_webhook( watched_dicts = [asdict(item) if isinstance(item, WebhookWatchedItem) else item for item in watched] response = get_session().post( - f"{ENDPOINT}/api/settings/webhooks/{webhook_id}", + f"{constants.ENDPOINT}/api/settings/webhooks/{webhook_id}", json={"watched": watched_dicts, "url": url, "domains": domains, "secret": secret}, headers=self._build_hf_headers(token=token), ) @@ -9034,7 +9037,7 @@ def enable_webhook(self, webhook_id: str, *, token: Union[bool, str, None] = Non ``` """ response = get_session().post( - f"{ENDPOINT}/api/settings/webhooks/{webhook_id}/enable", + f"{constants.ENDPOINT}/api/settings/webhooks/{webhook_id}/enable", headers=self._build_hf_headers(token=token), ) hf_raise_for_status(response) @@ -9085,7 +9088,7 @@ def disable_webhook(self, webhook_id: str, *, token: Union[bool, str, None] = No ``` """ response = get_session().post( - f"{ENDPOINT}/api/settings/webhooks/{webhook_id}/disable", + f"{constants.ENDPOINT}/api/settings/webhooks/{webhook_id}/disable", headers=self._build_hf_headers(token=token), ) hf_raise_for_status(response) @@ -9126,7 +9129,7 @@ def delete_webhook(self, webhook_id: str, *, token: Union[bool, str, None] = Non ``` """ response = get_session().delete( - f"{ENDPOINT}/api/settings/webhooks/{webhook_id}", + f"{constants.ENDPOINT}/api/settings/webhooks/{webhook_id}", headers=self._build_hf_headers(token=token), ) hf_raise_for_status(response) @@ -9272,7 +9275,7 @@ def _validate_yaml(self, content: str, *, repo_type: Optional[str] = None, token - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) if YAML is invalid """ - repo_type = repo_type if repo_type is not None else REPO_TYPE_MODEL + repo_type = repo_type if repo_type is not None else constants.REPO_TYPE_MODEL headers = self._build_hf_headers(token=token) response = get_session().post( @@ -9309,7 +9312,7 @@ def get_user_overview(self, username: str) -> User: [`HTTPError`](https://requests.readthedocs.io/en/latest/api/#requests.HTTPError): HTTP 404 If the user does not exist on the Hub. """ - r = get_session().get(f"{ENDPOINT}/api/users/{username}/overview") + r = get_session().get(f"{constants.ENDPOINT}/api/users/{username}/overview") hf_raise_for_status(r) return User(**r.json()) @@ -9331,7 +9334,7 @@ def list_organization_members(self, organization: str) -> Iterable[User]: """ - r = get_session().get(f"{ENDPOINT}/api/organizations/{organization}/members") + r = get_session().get(f"{constants.ENDPOINT}/api/organizations/{organization}/members") hf_raise_for_status(r) @@ -9355,7 +9358,7 @@ def list_user_followers(self, username: str) -> Iterable[User]: """ - r = get_session().get(f"{ENDPOINT}/api/users/{username}/followers") + r = get_session().get(f"{constants.ENDPOINT}/api/users/{username}/followers") hf_raise_for_status(r) @@ -9379,7 +9382,7 @@ def list_user_following(self, username: str) -> Iterable[User]: """ - r = get_session().get(f"{ENDPOINT}/api/users/{username}/following") + r = get_session().get(f"{constants.ENDPOINT}/api/users/{username}/following") hf_raise_for_status(r) diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index 233b2677e3..8a794f30d0 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -35,7 +35,7 @@ from requests.exceptions import HTTPError import huggingface_hub.lfs -from huggingface_hub import HfApi, SpaceHardware, SpaceStage, SpaceStorage +from huggingface_hub import HfApi, SpaceHardware, SpaceStage, SpaceStorage, constants from huggingface_hub._commit_api import ( CommitOperationAdd, CommitOperationCopy, @@ -43,12 +43,6 @@ _fetch_upload_modes, ) from huggingface_hub.community import DiscussionComment, DiscussionWithDetails -from huggingface_hub.constants import ( - REPO_TYPE_DATASET, - REPO_TYPE_MODEL, - REPO_TYPE_SPACE, - SPACES_SDK_TYPES, -) from huggingface_hub.errors import ( BadRequestError, EntryNotFoundError, @@ -231,41 +225,43 @@ def test_create_update_and_delete_repo(self): self._api.delete_repo(repo_id=repo_id) def test_create_update_and_delete_model_repo(self): - repo_id = self._api.create_repo(repo_id=repo_name(), repo_type=REPO_TYPE_MODEL).repo_id - res = self._api.update_repo_visibility(repo_id=repo_id, private=True, repo_type=REPO_TYPE_MODEL) + repo_id = self._api.create_repo(repo_id=repo_name(), repo_type=constants.REPO_TYPE_MODEL).repo_id + res = self._api.update_repo_visibility(repo_id=repo_id, private=True, repo_type=constants.REPO_TYPE_MODEL) assert res["private"] - res = self._api.update_repo_visibility(repo_id=repo_id, private=False, repo_type=REPO_TYPE_MODEL) + res = self._api.update_repo_visibility(repo_id=repo_id, private=False, repo_type=constants.REPO_TYPE_MODEL) assert not res["private"] - self._api.delete_repo(repo_id=repo_id, repo_type=REPO_TYPE_MODEL) + self._api.delete_repo(repo_id=repo_id, repo_type=constants.REPO_TYPE_MODEL) def test_create_update_and_delete_dataset_repo(self): - repo_id = self._api.create_repo(repo_id=repo_name(), repo_type=REPO_TYPE_DATASET).repo_id - res = self._api.update_repo_visibility(repo_id=repo_id, private=True, repo_type=REPO_TYPE_DATASET) + repo_id = self._api.create_repo(repo_id=repo_name(), repo_type=constants.REPO_TYPE_DATASET).repo_id + res = self._api.update_repo_visibility(repo_id=repo_id, private=True, repo_type=constants.REPO_TYPE_DATASET) assert res["private"] - res = self._api.update_repo_visibility(repo_id=repo_id, private=False, repo_type=REPO_TYPE_DATASET) + res = self._api.update_repo_visibility(repo_id=repo_id, private=False, repo_type=constants.REPO_TYPE_DATASET) assert not res["private"] - self._api.delete_repo(repo_id=repo_id, repo_type=REPO_TYPE_DATASET) + self._api.delete_repo(repo_id=repo_id, repo_type=constants.REPO_TYPE_DATASET) def test_create_update_and_delete_space_repo(self): with pytest.raises(ValueError, match=r"No space_sdk provided.*"): - self._api.create_repo(repo_id=repo_name(), repo_type=REPO_TYPE_SPACE, space_sdk=None) + self._api.create_repo(repo_id=repo_name(), repo_type=constants.REPO_TYPE_SPACE, space_sdk=None) with pytest.raises(ValueError, match=r"Invalid space_sdk.*"): - self._api.create_repo(repo_id=repo_name(), repo_type=REPO_TYPE_SPACE, space_sdk="something") + self._api.create_repo(repo_id=repo_name(), repo_type=constants.REPO_TYPE_SPACE, space_sdk="something") - for sdk in SPACES_SDK_TYPES: - repo_id = self._api.create_repo(repo_id=repo_name(), repo_type=REPO_TYPE_SPACE, space_sdk=sdk).repo_id - res = self._api.update_repo_visibility(repo_id=repo_id, private=True, repo_type=REPO_TYPE_SPACE) + for sdk in constants.SPACES_SDK_TYPES: + repo_id = self._api.create_repo( + repo_id=repo_name(), repo_type=constants.REPO_TYPE_SPACE, space_sdk=sdk + ).repo_id + res = self._api.update_repo_visibility(repo_id=repo_id, private=True, repo_type=constants.REPO_TYPE_SPACE) assert res["private"] - res = self._api.update_repo_visibility(repo_id=repo_id, private=False, repo_type=REPO_TYPE_SPACE) + res = self._api.update_repo_visibility(repo_id=repo_id, private=False, repo_type=constants.REPO_TYPE_SPACE) assert not res["private"] - self._api.delete_repo(repo_id=repo_id, repo_type=REPO_TYPE_SPACE) + self._api.delete_repo(repo_id=repo_id, repo_type=constants.REPO_TYPE_SPACE) def test_move_repo_normal_usage(self): repo_id = f"{USER}/{repo_name()}" new_repo_id = f"{USER}/{repo_name()}" # Spaces not tested on staging (error 500) - for repo_type in [None, REPO_TYPE_MODEL, REPO_TYPE_DATASET]: + for repo_type in [None, constants.REPO_TYPE_MODEL, constants.REPO_TYPE_DATASET]: self._api.create_repo(repo_id=repo_id, repo_type=repo_type) self._api.move_repo(from_id=repo_id, to_id=new_repo_id, repo_type=repo_type) self._api.delete_repo(repo_id=new_repo_id, repo_type=repo_type) @@ -278,7 +274,7 @@ def test_move_repo_target_already_exists(self) -> None: self._api.create_repo(repo_id=repo_id_2) with pytest.raises(HfHubHTTPError, match=r"A model repository called .* already exists"): - self._api.move_repo(from_id=repo_id_1, to_id=repo_id_2, repo_type=REPO_TYPE_MODEL) + self._api.move_repo(from_id=repo_id_1, to_id=repo_id_2, repo_type=constants.REPO_TYPE_MODEL) self._api.delete_repo(repo_id=repo_id_1) self._api.delete_repo(repo_id=repo_id_2) @@ -3675,7 +3671,7 @@ def test_user_agent_is_overwritten(self, mock_build_hf_headers: Mock) -> None: self.assertEqual(mock_build_hf_headers.call_args[1]["user_agent"], {"A": "B"}) -@patch("huggingface_hub.hf_api.ENDPOINT", "https://huggingface.co") +@patch("huggingface_hub.constants.ENDPOINT", "https://huggingface.co") class RepoUrlTest(unittest.TestCase): def test_repo_url_class(self): url = RepoUrl("https://huggingface.co/gpt2") diff --git a/tests/testing_utils.py b/tests/testing_utils.py index e952dc5192..d5e88ba080 100644 --- a/tests/testing_utils.py +++ b/tests/testing_utils.py @@ -217,7 +217,7 @@ def rmtree_with_retry(path: Union[str, Path]) -> None: def with_production_testing(func): file_download = patch("huggingface_hub.file_download.HUGGINGFACE_CO_URL_TEMPLATE", ENDPOINT_PRODUCTION_URL_SCHEME) - hf_api = patch("huggingface_hub.hf_api.ENDPOINT", ENDPOINT_PRODUCTION) + hf_api = patch("huggingface_hub.constants.ENDPOINT", ENDPOINT_PRODUCTION) return hf_api(file_download(func))