From 55bb9ec9663b24ab4c14ada38523c62fc8a4b282 Mon Sep 17 00:00:00 2001 From: Simon Brandeis <33657802+SBrandeis@users.noreply.github.com> Date: Tue, 13 Sep 2022 17:49:41 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=97=91=20Deprecate=20`token`=20in=20read-?= =?UTF-8?q?only=20methods=20of=20`HfApi`=20in=20favor=20of=20`use=5Fauth?= =?UTF-8?q?=5Ftoken`=20(#928)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * 🚚 Proposal: `use_auth_token` for read only `HfApi` methods This PR deprecates the `token` argument in read-only methods of `HfApi` in favor of `use_auth_token`. `use_auth_token` is more flexible as it allows not passing a token, whereas `token=None` fetches the token from `HfFolder`. I personally prefer the semantics of `use_auth_token`, that I find very legible: - `use_auth_token=False` - `use_auth_token="token"` - `use_auth_token=True` cc @lhoestq * 🔊 Deprecate in 0.12 * ♻ factorization * ✅ Update unit tests to ensure a warning is raised * 🎨 Smol typing fix * 💄 Code format * 🩹 --- src/huggingface_hub/_snapshot_download.py | 2 +- src/huggingface_hub/hf_api.py | 145 +++++++++++++++------- src/huggingface_hub/inference_api.py | 2 +- src/huggingface_hub/keras_mixin.py | 4 +- src/huggingface_hub/repository.py | 4 +- tests/test_hf_api.py | 35 +++++- tests/test_hubmixin.py | 8 +- tests/test_inference_api.py | 16 ++- tests/test_keras_integration.py | 2 +- 9 files changed, 157 insertions(+), 61 deletions(-) diff --git a/src/huggingface_hub/_snapshot_download.py b/src/huggingface_hub/_snapshot_download.py index b8978a17f3..c7e72580bb 100644 --- a/src/huggingface_hub/_snapshot_download.py +++ b/src/huggingface_hub/_snapshot_download.py @@ -167,7 +167,7 @@ def snapshot_download( # if we have internet connection we retrieve the correct folder name from the huggingface api _api = HfApi() repo_info = _api.repo_info( - repo_id=repo_id, repo_type=repo_type, revision=revision, token=token + repo_id=repo_id, repo_type=repo_type, revision=revision, use_auth_token=token ) filtered_repo_files = list( filter_repo_objects( diff --git a/src/huggingface_hub/hf_api.py b/src/huggingface_hub/hf_api.py index 6bcb88cd6b..e6ed7088da 100644 --- a/src/huggingface_hub/hf_api.py +++ b/src/huggingface_hub/hf_api.py @@ -92,8 +92,8 @@ def _validate_repo_id_deprecation(repo_id, name, organization): if repo_id and (name or organization): raise ValueError( - "Only pass `repo_id` and leave deprecated `name` and " - "`organization` to be None." + "Only pass `repo_id` and leave deprecated `name` and `organization` to be" + " None." ) elif name or organization: warnings.warn( @@ -617,8 +617,8 @@ def whoami(self, token: Optional[str] = None) -> Dict: token = HfFolder.get_token() if token is None: raise ValueError( - "You need to pass a valid `token` or login by using `huggingface-cli " - "login`" + "You need to pass a valid `token` or login by using `huggingface-cli" + " login`" ) path = f"{self.endpoint}/api/whoami-v2" r = requests.get(path, headers={"authorization": f"Bearer {token}"}) @@ -651,7 +651,7 @@ def _is_valid_token(self, token: str): def _validate_or_retrieve_token( self, - token: Optional[str] = None, + token: Optional[Union[str, bool]] = None, name: Optional[str] = None, function_name: Optional[str] = None, ): @@ -676,8 +676,8 @@ def _validate_or_retrieve_token( token = HfFolder.get_token() if token is None: raise EnvironmentError( - "You need to provide a `token` or be logged in to Hugging " - "Face with `huggingface-cli login`." + "You need to provide a `token` or be logged in to Hugging Face with" + " `huggingface-cli login`." ) if name is not None: if self._is_valid_token(name): @@ -696,6 +696,27 @@ def _validate_or_retrieve_token( return token, name + def _build_auth_headers( + self, *, token: Optional[str], use_auth_token: Optional[Union[str, bool]] + ) -> Dict[str, str]: + """Helper to build Authorization header from kwargs. To be removed in 0.12.0 when `token` is deprecated.""" + if token is not None: + warnings.warn( + "`token` is deprecated and will be removed in 0.12.0. Use" + " `use_auth_token` instead.", + FutureWarning, + ) + + auth_token = None + if use_auth_token is None and token is None: + # To maintain backwards-compatibility. To be removed in 0.12.0 + auth_token = HfFolder.get_token() + elif use_auth_token: + auth_token, _ = self._validate_or_retrieve_token(use_auth_token) + else: + auth_token = token + return {"authorization": f"Bearer {auth_token}"} if auth_token else {} + @staticmethod def set_access_token(access_token: str): """ @@ -840,9 +861,10 @@ def list_models( ``` """ path = f"{self.endpoint}/api/models" + headers = {} if use_auth_token: token, name = self._validate_or_retrieve_token(use_auth_token) - headers = {"authorization": f"Bearer {token}"} if use_auth_token else None + headers["authorization"] = f"Bearer {token}" params = {} if filter is not None: if isinstance(filter, ModelFilter): @@ -1037,9 +1059,10 @@ def list_datasets( ``` """ path = f"{self.endpoint}/api/datasets" + headers = {} if use_auth_token: token, name = self._validate_or_retrieve_token(use_auth_token) - headers = {"authorization": f"Bearer {token}"} if use_auth_token else None + headers["authorization"] = f"Bearer {token}" params = {} if filter is not None: if isinstance(filter, DatasetFilter): @@ -1176,9 +1199,10 @@ def list_spaces( `List[SpaceInfo]`: a list of [`huggingface_hub.hf_api.SpaceInfo`] objects """ path = f"{self.endpoint}/api/spaces" + headers = {} if use_auth_token: token, name = self._validate_or_retrieve_token(use_auth_token) - headers = {"authorization": f"Bearer {token}"} if use_auth_token else None + headers["authorization"] = f"Bearer {token}" params = {} if filter is not None: params.update({"filter": filter}) @@ -1217,6 +1241,7 @@ def model_info( timeout: Optional[float] = None, securityStatus: Optional[bool] = None, files_metadata: bool = False, + use_auth_token: Optional[Union[bool, str]] = None, ) -> ModelInfo: """ Get info on one specific model on huggingface.co @@ -1231,6 +1256,7 @@ def model_info( The revision of the model repository from which to get the information. token (`str`, *optional*): + Deprecated in favor of `use_auth_token`. Will be removed in 0.12.0. An authentication token (See https://huggingface.co/settings/token) timeout (`float`, *optional*): Whether to set a timeout for the request to the Hub. @@ -1240,6 +1266,10 @@ def model_info( files_metadata (`bool`, *optional*): Whether or not to retrieve metadata for files in the repository (size, LFS metadata, etc). Defaults to `False`. + use_auth_token (`bool` or `str`, *optional*): + Whether to use the `auth_token` provided from the + `huggingface_hub` cli. If not logged in, a valid `auth_token` + can be passed in as a string. Returns: [`huggingface_hub.hf_api.ModelInfo`]: The model repository information. @@ -1256,9 +1286,7 @@ def model_info( """ - if token is None: - token = HfFolder.get_token() - + headers = self._build_auth_headers(token=token, use_auth_token=use_auth_token) path = ( f"{self.endpoint}/api/models/{repo_id}" if revision is None @@ -1266,7 +1294,6 @@ def model_info( f"{self.endpoint}/api/models/{repo_id}/revision/{quote(revision, safe='')}" ) ) - headers = {"authorization": f"Bearer {token}"} if token is not None else None params = {} if securityStatus: params["securityStatus"] = True @@ -1291,6 +1318,7 @@ def dataset_info( token: Optional[str] = None, timeout: Optional[float] = None, files_metadata: bool = False, + use_auth_token: Optional[Union[bool, str]] = None, ) -> DatasetInfo: """ Get info on one specific dataset on huggingface.co. @@ -1305,12 +1333,17 @@ def dataset_info( The revision of the dataset repository from which to get the information. token (`str`, *optional*): + Deprecated in favor of `use_auth_token`. Will be removed in 0.12.0. An authentication token (See https://huggingface.co/settings/token) timeout (`float`, *optional*): Whether to set a timeout for the request to the Hub. files_metadata (`bool`, *optional*): Whether or not to retrieve metadata for files in the repository (size, LFS metadata, etc). Defaults to `False`. + use_auth_token (`bool` or `str`, *optional*): + Whether to use the `auth_token` provided from the + `huggingface_hub` cli. If not logged in, a valid `auth_token` + can be passed in as a string. Returns: [`hf_api.DatasetInfo`]: The dataset repository information. @@ -1327,8 +1360,7 @@ def dataset_info( """ - if token is None: - token = HfFolder.get_token() + headers = self._build_auth_headers(token=token, use_auth_token=use_auth_token) path = ( f"{self.endpoint}/api/datasets/{repo_id}" @@ -1337,7 +1369,6 @@ def dataset_info( f"{self.endpoint}/api/datasets/{repo_id}/revision/{quote(revision, safe='')}" ) ) - headers = {"authorization": f"Bearer {token}"} if token is not None else None params = {} if files_metadata: params["blobs"] = True @@ -1356,6 +1387,7 @@ def space_info( token: Optional[str] = None, timeout: Optional[float] = None, files_metadata: bool = False, + use_auth_token: Optional[Union[bool, str]] = None, ) -> SpaceInfo: """ Get info on one specific Space on huggingface.co. @@ -1370,12 +1402,17 @@ def space_info( The revision of the space repository from which to get the information. token (`str`, *optional*): + Deprecated in favor of `use_auth_token`. Will be removed in 0.12.0. An authentication token (See https://huggingface.co/settings/token) timeout (`float`, *optional*): Whether to set a timeout for the request to the Hub. files_metadata (`bool`, *optional*): Whether or not to retrieve metadata for files in the repository (size, LFS metadata, etc). Defaults to `False`. + use_auth_token (`bool` or `str`, *optional*): + Whether to use the `auth_token` provided from the + `huggingface_hub` cli. If not logged in, a valid `auth_token` + can be passed in as a string. Returns: [`~hf_api.SpaceInfo`]: The space repository information. @@ -1392,9 +1429,7 @@ def space_info( """ - if token is None: - token = HfFolder.get_token() - + headers = self._build_auth_headers(token=token, use_auth_token=use_auth_token) path = ( f"{self.endpoint}/api/spaces/{repo_id}" if revision is None @@ -1402,7 +1437,6 @@ def space_info( f"{self.endpoint}/api/spaces/{repo_id}/revision/{quote(revision, safe='')}" ) ) - headers = {"authorization": f"Bearer {token}"} if token is not None else None params = {} if files_metadata: params["blobs"] = True @@ -1422,6 +1456,7 @@ def repo_info( token: Optional[str] = None, timeout: Optional[float] = None, files_metadata: bool = False, + use_auth_token: Optional[Union[bool, str]] = None, ) -> Union[ModelInfo, DatasetInfo, SpaceInfo]: """ Get the info object for a given repo of a given type. @@ -1434,12 +1469,17 @@ def repo_info( The revision of the repository from which to get the information. token (`str`, *optional*): + Deprecated in favor of `use_auth_token`. Will be removed in 0.12.0. An authentication token (See https://huggingface.co/settings/token) timeout (`float`, *optional*): Whether to set a timeout for the request to the Hub. files_metadata (`bool`, *optional*): Whether or not to retrieve metadata for files in the repository (size, LFS metadata, etc). Defaults to `False`. + use_auth_token (`bool` or `str`, *optional*): + Whether to use the `auth_token` provided from the + `huggingface_hub` cli. If not logged in, a valid `auth_token` + can be passed in as a string. Returns: `Union[SpaceInfo, DatasetInfo, ModelInfo]`: The repository information, as a @@ -1459,31 +1499,21 @@ def repo_info( """ if repo_type is None or repo_type == "model": - return self.model_info( - repo_id, - revision=revision, - token=token, - timeout=timeout, - files_metadata=files_metadata, - ) + method = self.model_info elif repo_type == "dataset": - return self.dataset_info( - repo_id, - revision=revision, - token=token, - timeout=timeout, - files_metadata=files_metadata, - ) + method = self.dataset_info elif repo_type == "space": - return self.space_info( - repo_id, - revision=revision, - token=token, - timeout=timeout, - files_metadata=files_metadata, - ) + method = self.space_info else: raise ValueError("Unsupported repo type.") + return method( + repo_id, + revision=revision, + token=token, + timeout=timeout, + files_metadata=files_metadata, + use_auth_token=use_auth_token, + ) @validate_hf_hub_args def list_repo_files( @@ -1494,6 +1524,7 @@ def list_repo_files( repo_type: Optional[str] = None, token: Optional[str] = None, timeout: Optional[float] = None, + use_auth_token: Optional[Union[bool, str]] = None, ) -> List[str]: """ Get the list of files in a given repo. @@ -1510,9 +1541,14 @@ def list_repo_files( space, `None` or `"model"` if uploading to a model. Default is `None`. token (`str`, *optional*): + Deprecated in favor of `use_auth_token`. Will be removed in 0.12.0. An authentication token (See https://huggingface.co/settings/token) timeout (`float`, *optional*): Whether to set a timeout for the request to the Hub. + use_auth_token (`bool` or `str`, *optional*): + Whether to use the `auth_token` provided from the + `huggingface_hub` cli. If not logged in, a valid `auth_token` + can be passed in as a string. Returns: `List[str]`: the list of files in a given repository. @@ -1523,6 +1559,7 @@ def list_repo_files( repo_type=repo_type, token=token, timeout=timeout, + use_auth_token=use_auth_token, ) return [f.rfilename for f in repo_info.siblings] @@ -1997,8 +2034,8 @@ def create_commit( """ if parent_commit is not None and not REGEX_COMMIT_OID.fullmatch(parent_commit): raise ValueError( - "`parent_commit` is not a valid commit OID. It must match the following" - f" regex: {REGEX_COMMIT_OID}" + "`parent_commit` is not a valid commit OID. It must match the" + f" following regex: {REGEX_COMMIT_OID}" ) if commit_message is None or len(commit_message) == 0: @@ -2498,6 +2535,7 @@ def get_full_repo_name( *, organization: Optional[str] = None, token: Optional[str] = None, + use_auth_token: Optional[Union[bool, str]] = None, ): """ Returns the repository name for a given model ID and optional @@ -2510,13 +2548,28 @@ def get_full_repo_name( If passed, the repository name will be in the organization namespace instead of the user namespace. token (`str`, *optional*): + Deprecated in favor of `use_auth_token`. Will be removed in 0.12.0. The Hugging Face authentication token + use_auth_token (`bool` or `str`, *optional*): + Whether to use the `auth_token` provided from the + `huggingface_hub` cli. If not logged in, a valid `auth_token` + can be passed in as a string. Returns: `str`: The repository name in the user's namespace ({username}/{model_id}) if no organization is passed, and under the organization namespace ({organization}/{model_id}) otherwise. """ + if token is not None: + warnings.warn( + "`token` is deprecated and will be removed in 0.12.0. Use" + " `use_auth_token` instead.", + FutureWarning, + ) + + if token is None and use_auth_token: + token, name = self._validate_or_retrieve_token(use_auth_token) + if organization is None: if "/" in model_id: username = model_id.split("/")[0] @@ -3363,8 +3416,8 @@ def _parse_revision_from_pr_url(pr_url: str) -> str: re_match = re.match(_REGEX_DISCUSSION_URL, pr_url) if re_match is None: raise RuntimeError( - "Unexpected response from the hub, expected a Pull Request URL but" - f" got: '{pr_url}'" + "Unexpected response from the hub, expected a Pull Request URL but got:" + f" '{pr_url}'" ) return f"refs/pr/{re_match[1]}" diff --git a/src/huggingface_hub/inference_api.py b/src/huggingface_hub/inference_api.py index f55af74251..9b5dd9afc4 100644 --- a/src/huggingface_hub/inference_api.py +++ b/src/huggingface_hub/inference_api.py @@ -112,7 +112,7 @@ def __init__( self.headers["Authorization"] = f"Bearer {token}" # Configure task - model_info = HfApi().model_info(repo_id=repo_id, token=token) + model_info = HfApi().model_info(repo_id=repo_id, use_auth_token=token) if not model_info.pipeline_tag and not task: raise ValueError( "Task not specified in the repository. Please add it to the model card" diff --git a/src/huggingface_hub/keras_mixin.py b/src/huggingface_hub/keras_mixin.py index b834497895..dfcd90758b 100644 --- a/src/huggingface_hub/keras_mixin.py +++ b/src/huggingface_hub/keras_mixin.py @@ -439,7 +439,9 @@ def push_to_hub_keras( # Delete previous log files from Hub operations += [ CommitOperationDelete(path_in_repo=file) - for file in api.list_repo_files(repo_id=repo_id, token=token) + for file in api.list_repo_files( + repo_id=repo_id, use_auth_token=token + ) if file.startswith("logs/") ] diff --git a/src/huggingface_hub/repository.py b/src/huggingface_hub/repository.py index 86562e948c..7f23b9e34d 100644 --- a/src/huggingface_hub/repository.py +++ b/src/huggingface_hub/repository.py @@ -658,7 +658,9 @@ def clone_from(self, repo_url: str, use_auth_token: Union[bool, str, None] = Non if namespace == user or namespace in valid_organisations: try: _ = HfApi().repo_info( - f"{repo_id}", repo_type=self.repo_type, token=token + f"{repo_id}", + repo_type=self.repo_type, + use_auth_token=token, ) except HTTPError: if self.repo_type == "space": diff --git a/tests/test_hf_api.py b/tests/test_hf_api.py index bd8140ebd3..3bd41e2a99 100644 --- a/tests/test_hf_api.py +++ b/tests/test_hf_api.py @@ -609,11 +609,13 @@ def test_delete_file(self): self._api.delete_repo(repo_id=REPO_NAME, token=self._token) def test_get_full_repo_name(self): - repo_name_with_no_org = self._api.get_full_repo_name("model", token=self._token) + repo_name_with_no_org = self._api.get_full_repo_name( + "model", use_auth_token=self._token + ) self.assertEqual(repo_name_with_no_org, f"{USER}/model") repo_name_with_no_org = self._api.get_full_repo_name( - "model", organization="org", token=self._token + "model", organization="org", use_auth_token=self._token ) self.assertEqual(repo_name_with_no_org, "org/model") @@ -1468,11 +1470,38 @@ def test_model_info(self): ): _ = self._api.model_info(repo_id=f"{USER}/{self.REPO_NAME}") # Test we can access model info with a token + with self.assertWarns(FutureWarning): + model_info = self._api.model_info( + repo_id=f"{USER}/{self.REPO_NAME}", token=self._token + ) + self.assertIsInstance(model_info, ModelInfo) model_info = self._api.model_info( - repo_id=f"{USER}/{self.REPO_NAME}", token=self._token + repo_id=f"{USER}/{self.REPO_NAME}", use_auth_token=self._token ) self.assertIsInstance(model_info, ModelInfo) + def test_dataset_info(self): + shutil.rmtree(os.path.dirname(HfFolder.path_token), ignore_errors=True) + # Test we cannot access model info without a token + with self.assertRaisesRegex( + requests.exceptions.HTTPError, + re.compile( + r"401 Client Error(.+)\(Request ID: .+\)(.*)Repository Not Found", + flags=re.DOTALL, + ), + ): + _ = self._api.dataset_info(repo_id=f"{USER}/{self.REPO_NAME}") + # Test we can access model info with a token + with self.assertWarns(FutureWarning): + dataset_info = self._api.dataset_info( + repo_id=f"{USER}/{self.REPO_NAME}", token=self._token + ) + self.assertIsInstance(dataset_info, DatasetInfo) + dataset_info = self._api.dataset_info( + repo_id=f"{USER}/{self.REPO_NAME}", use_auth_token=self._token + ) + self.assertIsInstance(dataset_info, DatasetInfo) + def test_list_private_datasets(self): orig = len(self._api.list_datasets()) new = len(self._api.list_datasets(use_auth_token=self._token)) diff --git a/tests/test_hubmixin.py b/tests/test_hubmixin.py index 28b4f3ce8c..d78f6e2ef4 100644 --- a/tests/test_hubmixin.py +++ b/tests/test_hubmixin.py @@ -178,7 +178,7 @@ def test_push_to_hub_via_http_basic(self): ) # Test model id exists - model_info = self._api.model_info(repo_id, token=self._token) + model_info = self._api.model_info(repo_id, use_auth_token=self._token) self.assertEqual(model_info.modelId, repo_id) # Test config has been pushed to hub @@ -204,7 +204,7 @@ def test_push_to_hub_via_git_deprecated(self): use_auth_token=self._token, ) - model_info = self._api.model_info(repo_id, token=self._token) + model_info = self._api.model_info(repo_id, use_auth_token=self._token) self.assertEqual(model_info.modelId, repo_id) self._api.delete_repo(repo_id=repo_id, token=self._token) @@ -236,7 +236,9 @@ def test_push_to_hub_via_git_use_lfs_by_default(self): git_email="ci@dummy.com", ) - model_info = self._api.model_info(f"{USER}/{REPO_NAME}", token=self._token) + model_info = self._api.model_info( + f"{USER}/{REPO_NAME}", use_auth_token=self._token + ) self.assertTrue("large_file.txt" in [f.rfilename for f in model_info.siblings]) self._api.delete_repo(repo_id=f"{REPO_NAME}", token=self._token) diff --git a/tests/test_inference_api.py b/tests/test_inference_api.py index ede0ca53e1..d92aacbd92 100644 --- a/tests/test_inference_api.py +++ b/tests/test_inference_api.py @@ -67,9 +67,12 @@ def test_inference_with_dict_inputs(self): @with_production_testing def test_inference_with_audio(self): api = InferenceApi("facebook/wav2vec2-base-960h") - dataset = datasets.load_dataset( - "patrickvonplaten/librispeech_asr_dummy", "clean", split="validation" - ) + with self.assertWarns(FutureWarning): + dataset = datasets.load_dataset( + "patrickvonplaten/librispeech_asr_dummy", + "clean", + split="validation", + ) data = self.read(dataset["file"][0]) result = api(data=data) self.assertIsInstance(result, dict) @@ -78,7 +81,12 @@ def test_inference_with_audio(self): @with_production_testing def test_inference_with_image(self): api = InferenceApi("google/vit-base-patch16-224") - dataset = datasets.load_dataset("Narsil/image_dummy", "image", split="test") + with self.assertWarns(FutureWarning): + dataset = datasets.load_dataset( + "Narsil/image_dummy", + "image", + split="test", + ) data = self.read(dataset["file"][0]) result = api(data=data) self.assertIsInstance(result, list) diff --git a/tests/test_keras_integration.py b/tests/test_keras_integration.py index 5628e02e54..becd60658e 100644 --- a/tests/test_keras_integration.py +++ b/tests/test_keras_integration.py @@ -174,7 +174,7 @@ def test_push_to_hub_keras_mixin_via_http_basic(self): ) # Test model id exists - model_info = self._api.model_info(repo_id, token=self._token) + model_info = self._api.model_info(repo_id, use_auth_token=self._token) self.assertEqual(model_info.modelId, repo_id) # Test config has been pushed to hub