Add list_deployed_models to inference client #1622

Merged
38 changes: 38 additions & 0 deletions src/huggingface_hub/constants.py
@@ -129,3 +129,41 @@ def _as_int(value: Optional[str]) -> Optional[int]:
HF_HUB_LOCAL_DIR_AUTO_SYMLINK_THRESHOLD: int = (
_as_int(os.environ.get("HF_HUB_LOCAL_DIR_AUTO_SYMLINK_THRESHOLD")) or 5 * 1024 * 1024
)

# List frameworks that are handled by the InferenceAPI service. Useful to scan endpoints and check which models are
# deployed and running. Since 95% of the models are using the top 4 frameworks listed below, we scan only those by
# default. We still keep the full list of supported frameworks in case we want to scan all of them.
MAIN_INFERENCE_API_FRAMEWORKS = [
Contributor:

Current distribution of libraries (a bit outdated, but a good proxy):

  • transformers 108k repos
  • SB3 8.6k repos (no API)
  • diffusers 8k repos
  • ml-agents 4k repos (no API)
  • ST 3k repos
  • cleanRL 1.5k repos (no API)
  • timm 1.2k repos (we have API, but quite limited)

I think going with transformers, ST and diffusers support is good for now!

cc @LysandreJik and @julien-c for vis
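For reference, a rough sketch of how such Hub-side counts could be reproduced (assuming the library name is exposed as a model tag that `list_models(filter=...)` can match; counts drift over time):

```python
# Rough sketch: count Hub repos per library tag. Iterating the full listing
# can be slow for large libraries such as transformers.
from huggingface_hub import HfApi

api = HfApi()
for library in ["transformers", "diffusers", "sentence-transformers", "timm"]:
    count = sum(1 for _ in api.list_models(filter=library))
    print(f"{library}: {count} repos")
```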

@Wauplin (Contributor), Sep 8, 2023:

I think we have a confusion between what we consider a "framework" on the Hub and what is considered a "framework" under the hood by the InferenceAPI. As I see it, list_deployed_models should be a helper that improves discoverability of already-deployed models so that users don't have to wait for a model to load. I don't think we should choose the MAIN_INFERENCE_API_FRAMEWORKS list based on the Hub distribution, but rather on the InferenceAPI distribution.

I know it is only a proxy but if I check every possible inference framework currently deployed, I get this:

diffusers 112
sentence-transformers 13
text-generation-inference 20
transformers 298
adapter-transformers 0
allennlp 0
asteroid 1
bertopic 0
doctr 0
espnet 2
fairseq 4
fastai 0
fasttext 0
flair 2
generic 2
k2 0
keras 0
mindspore 0
nemo 1
open_clip 0
paddlenlp 0
peft 0
pyannote-audio 1
sklearn 0
spacy 0
span-marker 0
speechbrain 3
stanza 0
timm 1

Which is why I chose those 4 frameworks. Also, AFAIK, the TGI models deployed on InferenceAPI are in fact a curated list of models that we consider worth running, which makes them even more important to list by default IMO (since users cannot spin up LLMs by themselves the same way as any small-enough diffusers/transformers model).
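For what it's worth, these deployed counts can be reproduced with a small script against the public framework route (a sketch, assuming each `/framework/<name>` endpoint returns a JSON list of deployed models, which is what this PR relies on; it uses the constant added in this PR, and the numbers fluctuate):

```python
# Sketch: count models currently deployed per InferenceAPI framework.
import requests

from huggingface_hub.constants import ALL_INFERENCE_API_FRAMEWORKS, INFERENCE_ENDPOINT

for framework in ALL_INFERENCE_API_FRAMEWORKS:
    response = requests.get(f"{INFERENCE_ENDPOINT}/framework/{framework}")
    response.raise_for_status()
    print(framework, len(response.json()))
```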

@Wauplin (Contributor), Sep 8, 2023:

Discussed offline.

The current solution is not ideal (it might lead to confusion between Hub frameworks and InferenceAPI frameworks), but hopefully it shouldn't be too bad, as we (I) expect most users not to use the frameworks parameter and just take the default output. I added two Tip sections in the docstrings to remind users that list_deployed_models and get_model_status are meant to be complementary (the first one for discoverability, the second one for users who already know what they want).
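Roughly, the intended split between the two helpers looks like this (a usage sketch; model and task names are illustrative):

```python
from huggingface_hub import InferenceClient

client = InferenceClient()

# Discoverability: what is already deployed? (default: main frameworks only)
deployed = client.list_deployed_models()
print(deployed.get("text-classification", [])[:3])

# Targeted check: is the model I already have in mind up and running?
status = client.get_model_status("gpt2")
print(status.loaded, status.state)
```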

"diffusers",
"sentence-transformers",
"text-generation-inference",
Contributor:

text-generation-inference is not a framework. It's another type of tag: it always refers to transformers models and is determined by the architecture. See internal PR https://github.com/huggingface/moon-landing/pull/6894

Contributor:

I consider it a framework because that's how the InferenceAPI route calls it: https://api-inference.huggingface.co/framework/text-generation-inference . I don't think it's worth changing the naming here (arguably, TGI is not a library like transformers, but it is still a framework to power inference on the server, isn't it?).
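For context, a minimal sketch of querying that route directly (assuming each JSON entry exposes `model_id` and `task` fields, as `_unpack_response` in this PR expects):

```python
import requests

url = "https://api-inference.huggingface.co/framework/text-generation-inference"
response = requests.get(url)
response.raise_for_status()
for entry in response.json():
    # e.g. {"model_id": "bigcode/starcoder", "task": "text-generation", ...}
    print(entry["model_id"], entry["task"])
```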

"transformers",
]

ALL_INFERENCE_API_FRAMEWORKS = MAIN_INFERENCE_API_FRAMEWORKS + [
"adapter-transformers",
"allennlp",
"asteroid",
"bertopic",
"doctr",
"espnet",
"fairseq",
"fastai",
"fasttext",
"flair",
"generic",
"k2",
"keras",
"mindspore",
"nemo",
"open_clip",
"paddlenlp",
"peft",
"pyannote-audio",
"sklearn",
"spacy",
"span-marker",
"speechbrain",
"stanza",
"timm",
]
86 changes: 84 additions & 2 deletions src/huggingface_hub/inference/_client.py
@@ -52,7 +52,7 @@
from requests import HTTPError
from requests.structures import CaseInsensitiveDict

from huggingface_hub.constants import INFERENCE_ENDPOINT
from huggingface_hub.constants import ALL_INFERENCE_API_FRAMEWORKS, INFERENCE_ENDPOINT, MAIN_INFERENCE_API_FRAMEWORKS
from huggingface_hub.inference._common import (
TASKS_EXPECTING_IMAGES,
ContentT,
@@ -756,6 +756,81 @@ def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> str:
response = self.post(data=image, model=model, task="image-to-text")
return _bytes_to_dict(response)[0]["generated_text"]

def list_deployed_models(
self, frameworks: Union[None, str, Literal["all"], List[str]] = None
) -> Dict[str, List[str]]:
"""
List models currently deployed on the Inference API service.

This helper checks deployed models framework by framework. By default, it will check the 4 main frameworks that
are supported and account for 95% of the hosted models. However, if you want a complete list of models, you can
specify `frameworks="all"` as input. Alternatively, if you know beforehand which framework you are interested
in, you can also restrict the search to this one (e.g. `frameworks="text-generation-inference"`). The more
frameworks are checked, the more time it will take.

<Tip>

This endpoint is mostly useful for discoverability. If you already know which model you want to use and want to
check its availability, you can directly use [`~InferenceClient.get_model_status`].

</Tip>

Args:
frameworks (`Literal["all"]` or `List[str]` or `str`, *optional*):
The frameworks to filter on. By default only a subset of the available frameworks are tested. If set to
"all", all available frameworks will be tested. It is also possible to provide a single framework or a
custom set of frameworks to check.

Returns:
`Dict[str, List[str]]`: A dictionary mapping task names to a sorted list of model IDs.

Example:
```python
>>> from huggingface_hub import InferenceClient
>>> client = InferenceClient()

# Discover zero-shot-classification models currently deployed
>>> models = client.list_deployed_models()
>>> models["zero-shot-classification"]
['Narsil/deberta-large-mnli-zero-cls', 'facebook/bart-large-mnli', ...]

# List from only 1 framework
>>> client.list_deployed_models("text-generation-inference")
{'text-generation': ['bigcode/starcoder', 'meta-llama/Llama-2-70b-chat-hf', ...], ...}
```
"""
# Resolve which frameworks to check
if frameworks is None:
frameworks = MAIN_INFERENCE_API_FRAMEWORKS
elif frameworks == "all":
frameworks = ALL_INFERENCE_API_FRAMEWORKS
elif isinstance(frameworks, str):
frameworks = [frameworks]
frameworks = list(set(frameworks))

# Fetch them iteratively
models_by_task: Dict[str, List[str]] = {}

def _unpack_response(framework: str, items: List[Dict]) -> None:
for model in items:
if framework == "sentence-transformers":
# Models running with the `sentence-transformers` framework can work with both tasks even if not
# branded as such in the API response
models_by_task.setdefault("feature-extraction", []).append(model["model_id"])
models_by_task.setdefault("sentence-similarity", []).append(model["model_id"])
else:
models_by_task.setdefault(model["task"], []).append(model["model_id"])

for framework in frameworks:
response = get_session().get(f"{INFERENCE_ENDPOINT}/framework/{framework}", headers=self.headers)
hf_raise_for_status(response)
_unpack_response(framework, response.json())

# Sort alphabetically for discoverability and return
for task, models in models_by_task.items():
models_by_task[task] = sorted(set(models), key=lambda x: x.lower())
return models_by_task

def object_detection(
self,
image: ContentT,
@@ -1800,7 +1875,14 @@ def _resolve_url(self, model: Optional[str] = None, task: Optional[str] = None)

def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
"""
A function which returns the status of a specific model, from the Inference API.
Get the status of a model hosted on the Inference API.

<Tip>

This endpoint is mostly useful when you already know which model you want to use and want to check its
availability. If you want to discover already deployed models, you should rather use [`~InferenceClient.list_deployed_models`].

</Tip>

Args:
model (`str`, *optional*):
92 changes: 90 additions & 2 deletions src/huggingface_hub/inference/_generated/_async_client.py
@@ -36,7 +36,7 @@

from requests.structures import CaseInsensitiveDict

from huggingface_hub.constants import INFERENCE_ENDPOINT
from huggingface_hub.constants import ALL_INFERENCE_API_FRAMEWORKS, INFERENCE_ENDPOINT, MAIN_INFERENCE_API_FRAMEWORKS
from huggingface_hub.inference._common import (
TASKS_EXPECTING_IMAGES,
ContentT,
@@ -763,6 +763,87 @@ async def image_to_text(self, image: ContentT, *, model: Optional[str] = None) -> str:
response = await self.post(data=image, model=model, task="image-to-text")
return _bytes_to_dict(response)[0]["generated_text"]

async def list_deployed_models(
self, frameworks: Union[None, str, Literal["all"], List[str]] = None
) -> Dict[str, List[str]]:
"""
List models currently deployed on the Inference API service.

This helper checks deployed models framework by framework. By default, it will check the 4 main frameworks that
are supported and account for 95% of the hosted models. However, if you want a complete list of models, you can
specify `frameworks="all"` as input. Alternatively, if you know beforehand which framework you are interested
in, you can also restrict the search to this one (e.g. `frameworks="text-generation-inference"`). The more
frameworks are checked, the more time it will take.

<Tip>

This endpoint is mostly useful for discoverability. If you already know which model you want to use and want to
check its availability, you can directly use [`~InferenceClient.get_model_status`].

</Tip>

Args:
frameworks (`Literal["all"]` or `List[str]` or `str`, *optional*):
The frameworks to filter on. By default only a subset of the available frameworks are tested. If set to
"all", all available frameworks will be tested. It is also possible to provide a single framework or a
custom set of frameworks to check.

Returns:
`Dict[str, List[str]]`: A dictionary mapping task names to a sorted list of model IDs.

Example:
```py
# Must be run in an async context
>>> from huggingface_hub import AsyncInferenceClient
>>> client = AsyncInferenceClient()

# Discover zero-shot-classification models currently deployed
>>> models = await client.list_deployed_models()
>>> models["zero-shot-classification"]
['Narsil/deberta-large-mnli-zero-cls', 'facebook/bart-large-mnli', ...]

# List from only 1 framework
>>> await client.list_deployed_models("text-generation-inference")
{'text-generation': ['bigcode/starcoder', 'meta-llama/Llama-2-70b-chat-hf', ...], ...}
```
"""
# Resolve which frameworks to check
if frameworks is None:
frameworks = MAIN_INFERENCE_API_FRAMEWORKS
elif frameworks == "all":
frameworks = ALL_INFERENCE_API_FRAMEWORKS
elif isinstance(frameworks, str):
frameworks = [frameworks]
frameworks = list(set(frameworks))

# Fetch them iteratively
models_by_task: Dict[str, List[str]] = {}

def _unpack_response(framework: str, items: List[Dict]) -> None:
for model in items:
if framework == "sentence-transformers":
# Models running with the `sentence-transformers` framework can work with both tasks even if not
# branded as such in the API response
models_by_task.setdefault("feature-extraction", []).append(model["model_id"])
models_by_task.setdefault("sentence-similarity", []).append(model["model_id"])
else:
models_by_task.setdefault(model["task"], []).append(model["model_id"])

async def _fetch_framework(framework: str) -> None:
async with _import_aiohttp().ClientSession(headers=self.headers) as client:
response = await client.get(f"{INFERENCE_ENDPOINT}/framework/{framework}")
response.raise_for_status()
_unpack_response(framework, await response.json())

import asyncio

await asyncio.gather(*[_fetch_framework(framework) for framework in frameworks])

# Sort alphabetically for discoverability and return
for task, models in models_by_task.items():
models_by_task[task] = sorted(set(models), key=lambda x: x.lower())
return models_by_task

async def object_detection(
self,
image: ContentT,
@@ -1823,7 +1904,14 @@ def _resolve_url(self, model: Optional[str] = None, task: Optional[str] = None)

async def get_model_status(self, model: Optional[str] = None) -> ModelStatus:
"""
A function which returns the status of a specific model, from the Inference API.
Get the status of a model hosted on the Inference API.

<Tip>

This endpoint is mostly useful when you already know which model you want to use and want to check its
availability. If you want to discover already deployed models, you should rather use [`~InferenceClient.list_deployed_models`].

</Tip>

Args:
model (`str`, *optional*):
14 changes: 14 additions & 0 deletions tests/test_inference_async_client.py
@@ -229,3 +229,17 @@ async def test_get_status_unknown_model() -> None:
async def test_get_status_model_as_url() -> None:
with pytest.raises(NotImplementedError):
await AsyncInferenceClient().get_model_status("https://unkown/model")


@pytest.mark.asyncio
async def test_list_deployed_models_single_frameworks() -> None:
models_by_task = await AsyncInferenceClient().list_deployed_models("text-generation-inference")
assert isinstance(models_by_task, dict)
for task, models in models_by_task.items():
assert isinstance(task, str)
assert isinstance(models, list)
for model in models:
assert isinstance(model, str)

assert "text-generation" in models_by_task
assert "bigscience/bloom" in models_by_task["text-generation"]
31 changes: 31 additions & 0 deletions tests/test_inference_client.py
@@ -21,6 +21,7 @@
from PIL import Image

from huggingface_hub import InferenceClient, hf_hub_download
from huggingface_hub.constants import ALL_INFERENCE_API_FRAMEWORKS, MAIN_INFERENCE_API_FRAMEWORKS
from huggingface_hub.inference._client import _open_as_binary
from huggingface_hub.utils import HfHubHTTPError, build_hf_headers

@@ -528,3 +529,33 @@ def test_model_as_url(self) -> None:
client = InferenceClient()
with self.assertRaises(NotImplementedError):
client.get_model_status("https://unkown/model")


class TestListDeployedModels(unittest.TestCase):
@patch("huggingface_hub.inference._client.get_session")
def test_list_deployed_models_main_frameworks_mock(self, get_session_mock: MagicMock) -> None:
InferenceClient().list_deployed_models()
self.assertEqual(
len(get_session_mock.return_value.get.call_args_list),
len(MAIN_INFERENCE_API_FRAMEWORKS),
)

@patch("huggingface_hub.inference._client.get_session")
def test_list_deployed_models_all_frameworks_mock(self, get_session_mock: MagicMock) -> None:
InferenceClient().list_deployed_models("all")
self.assertEqual(
len(get_session_mock.return_value.get.call_args_list),
len(ALL_INFERENCE_API_FRAMEWORKS),
)

def test_list_deployed_models_single_frameworks(self) -> None:
models_by_task = InferenceClient().list_deployed_models("text-generation-inference")
self.assertIsInstance(models_by_task, dict)
for task, models in models_by_task.items():
self.assertIsInstance(task, str)
self.assertIsInstance(models, list)
for model in models:
self.assertIsInstance(model, str)

self.assertIn("text-generation", models_by_task)
self.assertIn("bigscience/bloom", models_by_task["text-generation"])
25 changes: 25 additions & 0 deletions utils/generate_async_inference_client.py
@@ -57,6 +57,10 @@ def generate_async_client_code(code: str) -> str:

# Adapt get_model_status
code = _adapt_get_model_status(code)

# Adapt list_deployed_models
code = _adapt_list_deployed_models(code)

return code


@@ -359,6 +363,27 @@ def _adapt_get_model_status(code: str) -> str:
return code.replace(sync_snippet, async_snippet)


def _adapt_list_deployed_models(code: str) -> str:
sync_snippet = """
for framework in frameworks:
response = get_session().get(f"{INFERENCE_ENDPOINT}/framework/{framework}", headers=self.headers)
hf_raise_for_status(response)
_unpack_response(framework, response.json())""".strip()

async_snippet = """
async def _fetch_framework(framework: str) -> None:
async with _import_aiohttp().ClientSession(headers=self.headers) as client:
response = await client.get(f"{INFERENCE_ENDPOINT}/framework/{framework}")
response.raise_for_status()
_unpack_response(framework, await response.json())

import asyncio

await asyncio.gather(*[_fetch_framework(framework) for framework in frameworks])""".strip()

return code.replace(sync_snippet, async_snippet)


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(