diff --git a/docs/changes.rst b/docs/changes.rst index 3be2a186..cb63da52 100644 --- a/docs/changes.rst +++ b/docs/changes.rst @@ -16,6 +16,8 @@ v0.9 estimators. :pr:`384` by :user:`Reid Johnson `. - Fix an issue with visualizing Skops files for `scikit-learn` tree estimators. :pr:`386` by :user:`Reid Johnson `. +- :func:`skops.hug_utils.get_model_output` is deprecated and will be removed in version + 0.10. :pr:`396` by `Adrin Jalali`_. v0.8 ---- diff --git a/skops/hub_utils/_hf_hub.py b/skops/hub_utils/_hf_hub.py index 41eaa160..fb161388 100644 --- a/skops/hub_utils/_hf_hub.py +++ b/skops/hub_utils/_hf_hub.py @@ -9,6 +9,7 @@ import json import os import shutil +import warnings from pathlib import Path from typing import Any, List, Literal, MutableMapping, Optional, Sequence, Union @@ -702,11 +703,16 @@ def download( shutil.rmtree(path=cached_folder) +# TODO(v0.10): remove this function def get_model_output(repo_id: str, data: Any, token: Optional[str] = None) -> Any: """Returns the output of the model using Hugging Face Hub's inference API. See the :ref:`User Guide ` for more details. + .. deprecated:: 0.9 + Will be removed in version 0.10. Use ``huggingface_hub.InferenceClient`` + instead. + Parameters ---------- repo_id: str @@ -737,8 +743,12 @@ def get_model_output(repo_id: str, data: Any, token: Optional[str] = None) -> An Also note that if the model repo is private, the inference API would not be available. """ - # TODO: the "type: ignore" should eventually become unncessary when hf_hub - # is updated + warnings.warn( + "This feature is no longer free on hf.co and therefore this function will" + " be removed in the next release. Use `huggingface_hub.InferenceClient`" + " instead.", + FutureWarning, + ) model_info = HfApi().model_info(repo_id=repo_id, use_auth_token=token) # type: ignore if not model_info.pipeline_tag: raise ValueError( diff --git a/skops/hub_utils/tests/test_hf_hub.py b/skops/hub_utils/tests/test_hf_hub.py index d2ee3931..59b51ebb 100644 --- a/skops/hub_utils/tests/test_hf_hub.py +++ b/skops/hub_utils/tests/test_hf_hub.py @@ -503,7 +503,8 @@ def test_inference( X_test = data.data.head(5) y_pred = model.predict(X_test) - output = get_model_output(repo_id, data=X_test, token=HF_HUB_TOKEN) + with pytest.warns(FutureWarning): + output = get_model_output(repo_id, data=X_test, token=HF_HUB_TOKEN) # cleanup client.delete_repo(repo_id=repo_id, token=HF_HUB_TOKEN) @@ -512,6 +513,12 @@ def test_inference( assert np.allclose(output, y_pred) +def test_get_model_output_deprecated(): + with pytest.raises(Exception): + with pytest.warns(FutureWarning, match="This feature is no longer free"): + get_model_output("dummy", data=iris.data) + + def test_get_config(repo_path, config_json): config_path, file_format = config_json config = get_config(repo_path)