diff --git a/docs/source/guides/inference.md b/docs/source/guides/inference.md
index a6ea1c0800..043b156da7 100644
--- a/docs/source/guides/inference.md
+++ b/docs/source/guides/inference.md
@@ -139,7 +139,7 @@ has a simple API that supports the most common tasks. Here is a list of the curr
 | | [Summarization](https://huggingface.co/tasks/summarization) | ✅ | [`~InferenceClient.summarization`] |
 | | [Table Question Answering](https://huggingface.co/tasks/table-question-answering) | | |
 | | [Text Classification](https://huggingface.co/tasks/text-classification) | | |
-| | [Text Generation](https://huggingface.co/tasks/text-generation) | | |
+| | [Text Generation](https://huggingface.co/tasks/text-generation) | ✅ | [`~InferenceClient.text_generation`] |
 | | [Token Classification](https://huggingface.co/tasks/token-classification) | | |
 | | [Translation](https://huggingface.co/tasks/translation) | | |
 | | [Zero Shot Classification](https://huggingface.co/tasks/zero-shot-image-classification) | | |
diff --git a/docs/source/package_reference/inference_client.md b/docs/source/package_reference/inference_client.md
index ded8be8d84..9ce0fa7be6 100644
--- a/docs/source/package_reference/inference_client.md
+++ b/docs/source/package_reference/inference_client.md
@@ -36,6 +36,35 @@ For most tasks, the return value has a built-in type (string, list, image...). H
 
 [[autodoc]] huggingface_hub.inference._types.ImageSegmentationOutput
 
+### Text generation types
+
+The [`~InferenceClient.text_generation`] task has more extensive support than the other tasks in `InferenceClient`. In
+particular, user inputs and server outputs are validated using [Pydantic](https://docs.pydantic.dev/latest/)
+if that package is installed. We therefore recommend installing it (`pip install pydantic`)
+for a better user experience.
+
+The dataclasses used to validate data are listed below, in particular [`~huggingface_hub.inference._text_generation.TextGenerationParameters`] (input),
+[`~huggingface_hub.inference._text_generation.TextGenerationResponse`] (output) and
+[`~huggingface_hub.inference._text_generation.TextGenerationStreamResponse`] (streaming output).
+
+[[autodoc]] huggingface_hub.inference._text_generation.TextGenerationParameters
+
+[[autodoc]] huggingface_hub.inference._text_generation.TextGenerationResponse
+
+[[autodoc]] huggingface_hub.inference._text_generation.TextGenerationStreamResponse
+
+[[autodoc]] huggingface_hub.inference._text_generation.InputToken
+
+[[autodoc]] huggingface_hub.inference._text_generation.Token
+
+[[autodoc]] huggingface_hub.inference._text_generation.FinishReason
+
+[[autodoc]] huggingface_hub.inference._text_generation.BestOfSequence
+
+[[autodoc]] huggingface_hub.inference._text_generation.Details
+
+[[autodoc]] huggingface_hub.inference._text_generation.StreamDetails
+
 ## InferenceAPI
 
 [`InferenceAPI`] is the legacy way to call the Inference API.
The interface is more simplistic and requires knowing diff --git a/pyproject.toml b/pyproject.toml index c7e8c4c3ed..e21bd2d4cf 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,6 +7,9 @@ preview = true ignore_missing_imports = true no_implicit_optional = true scripts_are_modules = true +plugins = [ + "pydantic.mypy" +] [tool.ruff] # Ignored rules: diff --git a/setup.py b/setup.py index c852d528ce..e4c91f36e7 100644 --- a/setup.py +++ b/setup.py @@ -55,6 +55,7 @@ def get_version() -> str: "Pillow", "gradio", # to test webhooks "numpy", # for embeddings + "pydantic", # for text-generation-inference ] # Typing extra dependencies list is duplicated in `.pre-commit-config.yaml` @@ -66,6 +67,7 @@ def get_version() -> str: "types-toml", "types-tqdm", "types-urllib3", + "pydantic", # for text-generation dataclasses ] extras["quality"] = [ diff --git a/src/huggingface_hub/inference/_client.py b/src/huggingface_hub/inference/_client.py index 798c842e57..568a06d5de 100644 --- a/src/huggingface_hub/inference/_client.py +++ b/src/huggingface_hub/inference/_client.py @@ -36,19 +36,48 @@ # - Only the main parameters are publicly exposed. Power users can always read the docs for more options. import base64 import io +import json import logging import time import warnings from contextlib import contextmanager +from dataclasses import asdict from pathlib import Path -from typing import TYPE_CHECKING, Any, BinaryIO, ContextManager, Dict, Generator, List, Optional, Union, overload +from typing import ( + TYPE_CHECKING, + Any, + BinaryIO, + ContextManager, + Dict, + Generator, + Iterable, + List, + Optional, + Set, + Union, + overload, +) from requests import HTTPError, Response from requests.structures import CaseInsensitiveDict from ..constants import ENDPOINT, INFERENCE_ENDPOINT -from ..utils import build_hf_headers, get_session, hf_raise_for_status, is_numpy_available, is_pillow_available +from ..utils import ( + BadRequestError, + build_hf_headers, + get_session, + hf_raise_for_status, + is_numpy_available, + is_pillow_available, +) from ..utils._typing import Literal +from ._text_generation import ( + TextGenerationParameters, + TextGenerationRequest, + TextGenerationResponse, + TextGenerationStreamResponse, + raise_text_generation_error, +) from ._types import ClassificationOutput, ConversationalOutput, ImageSegmentationOutput @@ -121,6 +150,7 @@ def post( data: Optional[ContentT] = None, model: Optional[str] = None, task: Optional[str] = None, + stream: bool = False, ) -> Response: """ Make a POST request to the inference server. @@ -138,6 +168,8 @@ def post( task (`str`, *optional*): The task to perform on the inference. Used only to default to a recommended model if `model` is not provided. At least `model` or `task` must be provided. Defaults to None. + stream (`bool`, *optional*): + Whether to iterate over streaming APIs. Returns: Response: The `requests` HTTP response. 
@@ -165,6 +197,7 @@ def post( headers=self.headers, cookies=self.cookies, timeout=self.timeout, + stream=stream, ) except TimeoutError as error: # Convert any `TimeoutError` to a `InferenceTimeoutError` @@ -659,6 +692,374 @@ def summarization( response = self.post(json=payload, model=model, task="summarization") return response.json()[0]["summary_text"] + @overload + def text_generation( # type: ignore + self, + prompt: str, + *, + details: Literal[False] = ..., + stream: Literal[False] = ..., + model: Optional[str] = None, + do_sample: bool = False, + max_new_tokens: int = 20, + best_of: Optional[int] = None, + repetition_penalty: Optional[float] = None, + return_full_text: bool = False, + seed: Optional[int] = None, + stop_sequences: Optional[List[str]] = None, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + truncate: Optional[int] = None, + typical_p: Optional[float] = None, + watermark: bool = False, + ) -> str: + ... + + @overload + def text_generation( # type: ignore + self, + prompt: str, + *, + details: Literal[True] = ..., + stream: Literal[False] = ..., + model: Optional[str] = None, + do_sample: bool = False, + max_new_tokens: int = 20, + best_of: Optional[int] = None, + repetition_penalty: Optional[float] = None, + return_full_text: bool = False, + seed: Optional[int] = None, + stop_sequences: Optional[List[str]] = None, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + truncate: Optional[int] = None, + typical_p: Optional[float] = None, + watermark: bool = False, + ) -> TextGenerationResponse: + ... + + @overload + def text_generation( # type: ignore + self, + prompt: str, + *, + details: Literal[False] = ..., + stream: Literal[True] = ..., + model: Optional[str] = None, + do_sample: bool = False, + max_new_tokens: int = 20, + best_of: Optional[int] = None, + repetition_penalty: Optional[float] = None, + return_full_text: bool = False, + seed: Optional[int] = None, + stop_sequences: Optional[List[str]] = None, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + truncate: Optional[int] = None, + typical_p: Optional[float] = None, + watermark: bool = False, + ) -> Iterable[str]: + ... + + @overload + def text_generation( + self, + prompt: str, + *, + details: Literal[True] = ..., + stream: Literal[True] = ..., + model: Optional[str] = None, + do_sample: bool = False, + max_new_tokens: int = 20, + best_of: Optional[int] = None, + repetition_penalty: Optional[float] = None, + return_full_text: bool = False, + seed: Optional[int] = None, + stop_sequences: Optional[List[str]] = None, + temperature: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + truncate: Optional[int] = None, + typical_p: Optional[float] = None, + watermark: bool = False, + ) -> Iterable[TextGenerationStreamResponse]: + ... 
+
+    def text_generation(
+        self,
+        prompt: str,
+        *,
+        details: bool = False,
+        stream: bool = False,
+        model: Optional[str] = None,
+        do_sample: bool = False,
+        max_new_tokens: int = 20,
+        best_of: Optional[int] = None,
+        repetition_penalty: Optional[float] = None,
+        return_full_text: bool = False,
+        seed: Optional[int] = None,
+        stop_sequences: Optional[List[str]] = None,
+        temperature: Optional[float] = None,
+        top_k: Optional[int] = None,
+        top_p: Optional[float] = None,
+        truncate: Optional[int] = None,
+        typical_p: Optional[float] = None,
+        watermark: bool = False,
+        decoder_input_details: bool = False,
+    ) -> Union[str, TextGenerationResponse, Iterable[str], Iterable[TextGenerationStreamResponse]]:
+        """
+        Given a prompt, generate the text that follows it.
+
+        It is recommended to have Pydantic installed in order to get inputs validated. This is preferable as it allows
+        early failures.
+
+        The API endpoint is expected to run with the `text-generation-inference` backend (TGI). This backend is the
+        go-to solution to run large language models at scale. However, for some smaller models (e.g. "gpt2") the
+        default `transformers` + `api-inference` solution is still in use. Both approaches have very similar APIs, but
+        they are not identical. This method is compatible with both approaches, but some parameters are only available
+        with `text-generation-inference`. If some parameters are ignored, a warning message is triggered but the
+        process continues correctly.
+
+        To learn more about the TGI project, please refer to https://github.com/huggingface/text-generation-inference.
+
+        Args:
+            prompt (`str`):
+                Input text.
+            details (`bool`, *optional*):
+                By default, text_generation returns a string. Pass `details=True` if you want a detailed output (tokens,
+                probabilities, seed, finish reason, etc.). Only available for models running with the
+                `text-generation-inference` backend.
+            stream (`bool`, *optional*):
+                By default, text_generation returns the full generated text. Pass `stream=True` if you want a stream of
+                tokens to be returned. Only available for models running with the `text-generation-inference`
+                backend.
+            model (`str`, *optional*):
+                The model to use for inference. Can be a model ID hosted on the Hugging Face Hub or a URL to a deployed
+                Inference Endpoint. This parameter overrides the model defined at the instance level. Defaults to None.
+            do_sample (`bool`):
+                Activate logits sampling.
+            max_new_tokens (`int`):
+                Maximum number of generated tokens.
+            best_of (`int`):
+                Generate `best_of` sequences and return the one with the highest token logprobs.
+            repetition_penalty (`float`):
+                The parameter for repetition penalty. 1.0 means no penalty. See [this
+                paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
+            return_full_text (`bool`):
+                Whether to prepend the prompt to the generated text.
+            seed (`int`):
+                Random sampling seed.
+            stop_sequences (`List[str]`):
+                Stop generating tokens if a member of `stop_sequences` is generated.
+            temperature (`float`):
+                The value used to modulate the logits distribution.
+            top_k (`int`):
+                The number of highest probability vocabulary tokens to keep for top-k-filtering.
+            top_p (`float`):
+                If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
+                higher are kept for generation.
+            truncate (`int`):
+                Truncate input tokens to the given size.
+            typical_p (`float`):
+                Typical Decoding mass.
+                See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information.
+            watermark (`bool`):
+                Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226).
+            decoder_input_details (`bool`):
+                Return the decoder input token logprobs and ids. You must set `details=True` as well for it to be taken
+                into account. Defaults to `False`.
+
+        Returns:
+            `Union[str, TextGenerationResponse, Iterable[str], Iterable[TextGenerationStreamResponse]]`:
+            Generated text returned from the server:
+            - if `stream=False` and `details=False`, the generated text is returned as a `str` (default)
+            - if `stream=True` and `details=False`, the generated text is returned token by token as an `Iterable[str]`
+            - if `stream=False` and `details=True`, the generated text is returned with more details as a [`~huggingface_hub.inference._text_generation.TextGenerationResponse`]
+            - if `details=True` and `stream=True`, the generated text is returned token by token as an iterable of [`~huggingface_hub.inference._text_generation.TextGenerationStreamResponse`]
+
+        Raises:
+            `ValidationError`:
+                If input values are not valid. No HTTP call is made to the server.
+            [`InferenceTimeoutError`]:
+                If the model is unavailable or the request times out.
+            `HTTPError`:
+                If the request fails with an HTTP error status code other than HTTP 503.
+
+        Example:
+        ```py
+        >>> from huggingface_hub import InferenceClient
+        >>> client = InferenceClient()
+
+        # Case 1: generate text
+        >>> client.text_generation("The huggingface_hub library is ", max_new_tokens=12)
+        '100% open source and built to be easy to use.'
+
+        # Case 2: iterate over the generated tokens. Useful for large generation.
+        >>> for token in client.text_generation("The huggingface_hub library is ", max_new_tokens=12, stream=True):
+        ...     print(token)
+        100
+        %
+        open
+        source
+        and
+        built
+        to
+        be
+        easy
+        to
+        use
+        .
+
+        # Case 3: get more details about the generation process.
+        >>> client.text_generation("The huggingface_hub library is ", max_new_tokens=12, details=True)
+        TextGenerationResponse(
+            generated_text='100% open source and built to be easy to use.',
+            details=Details(
+                finish_reason=<FinishReason.Length: 'length'>,
+                generated_tokens=12,
+                seed=None,
+                prefill=[
+                    InputToken(id=487, text='The', logprob=None),
+                    InputToken(id=53789, text=' hugging', logprob=-13.171875),
+                    (...)
+                    InputToken(id=204, text=' ', logprob=-7.0390625)
+                ],
+                tokens=[
+                    Token(id=1425, text='100', logprob=-1.0175781, special=False),
+                    Token(id=16, text='%', logprob=-0.0463562, special=False),
+                    (...)
+                    Token(id=25, text='.', logprob=-0.5703125, special=False)
+                ],
+                best_of_sequences=None
+            )
+        )
+
+        # Case 4: iterate over the generated tokens with more details.
+        # Last object is more complete, containing the full generated text and the finish reason.
+        >>> for details in client.text_generation("The huggingface_hub library is ", max_new_tokens=12, details=True, stream=True):
+        ...     print(details)
+        ...
+        TextGenerationStreamResponse(token=Token(id=1425, text='100', logprob=-1.0175781, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(id=16, text='%', logprob=-0.0463562, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(id=1314, text=' open', logprob=-1.3359375, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(id=3178, text=' source', logprob=-0.28100586, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(id=273, text=' and', logprob=-0.5961914, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(id=3426, text=' built', logprob=-1.9423828, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(id=271, text=' to', logprob=-1.4121094, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(id=314, text=' be', logprob=-1.5224609, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(id=1833, text=' easy', logprob=-2.1132812, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(id=271, text=' to', logprob=-0.08520508, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(id=745, text=' use', logprob=-0.39453125, special=False), generated_text=None, details=None)
+        TextGenerationStreamResponse(token=Token(
+            id=25,
+            text='.',
+            logprob=-0.5703125,
+            special=False),
+            generated_text='100% open source and built to be easy to use.',
+            details=StreamDetails(finish_reason=<FinishReason.Length: 'length'>, generated_tokens=12, seed=None)
+        )
+        ```
+        """
+        # NOTE: Text-generation integration is taken from the text-generation-inference project. It has more features
+        # like input/output validation (if Pydantic is installed). See `_text_generation.py` header for more details.
+
+        if decoder_input_details and not details:
+            warnings.warn(
+                "`decoder_input_details=True` has been passed to the server but `details=False` is set meaning that"
+                " the output from the server will be truncated."
+            )
+            decoder_input_details = False
+
+        # Validate parameters
+        parameters = TextGenerationParameters(
+            best_of=best_of,
+            details=details,
+            do_sample=do_sample,
+            max_new_tokens=max_new_tokens,
+            repetition_penalty=repetition_penalty,
+            return_full_text=return_full_text,
+            seed=seed,
+            stop=stop_sequences if stop_sequences is not None else [],
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            truncate=truncate,
+            typical_p=typical_p,
+            watermark=watermark,
+            decoder_input_details=decoder_input_details,
+        )
+        request = TextGenerationRequest(inputs=prompt, stream=stream, parameters=parameters)
+        payload = asdict(request)
+
+        # Remove some parameters if not a TGI server
+        if not _is_tgi_server(model):
+            ignored_parameters = []
+            for key in "watermark", "stop", "details", "decoder_input_details":
+                if payload["parameters"][key] is not None:
+                    ignored_parameters.append(key)
+                del payload["parameters"][key]
+            if len(ignored_parameters) > 0:
+                warnings.warn(
+                    (
+                        "API endpoint/model for text-generation is not served via TGI. Ignoring parameters"
+                        f" {ignored_parameters}."
+                    ),
+                    UserWarning,
+                )
+            if details:
+                warnings.warn(
+                    (
+                        "API endpoint/model for text-generation is not served via TGI. Parameter `details=True` will"
+                        " be ignored meaning only the generated text will be returned."
+ ), + UserWarning, + ) + details = False + if stream: + raise ValueError( + "API endpoint/model for text-generation is not served via TGI. Cannot return output as a stream." + " Please pass `stream=False` as input." + ) + + # Handle errors separately for more precise error messages + try: + response = self.post(json=payload, model=model, task="text-generation", stream=stream) + except HTTPError as e: + if isinstance(e, BadRequestError) and "The following `model_kwargs` are not used by the model" in str(e): + _set_as_non_tgi(model) + return self.text_generation( # type: ignore + prompt=prompt, + details=details, + stream=stream, + model=model, + do_sample=do_sample, + max_new_tokens=max_new_tokens, + best_of=best_of, + repetition_penalty=repetition_penalty, + return_full_text=return_full_text, + seed=seed, + stop_sequences=stop_sequences, + temperature=temperature, + top_k=top_k, + top_p=top_p, + truncate=truncate, + typical_p=typical_p, + watermark=watermark, + decoder_input_details=decoder_input_details, + ) + raise_text_generation_error(e) + + # Parse output + if stream: + return _stream_text_generation_response(response, details) # type: ignore + elif details: + return TextGenerationResponse(**response.json()[0]) + else: + return response.json()[0]["generated_text"] + def text_to_image( self, prompt: str, @@ -888,6 +1289,26 @@ def _response_to_image(response: Response) -> "Image": return Image.open(io.BytesIO(response.content)) +def _stream_text_generation_response( + response: Response, details: bool +) -> Union[Iterable[str], Iterable[TextGenerationStreamResponse]]: + # Parse ServerSentEvents + for byte_payload in response.iter_lines(): + # Skip line + if byte_payload == b"\n": + continue + + payload = byte_payload.decode("utf-8") + + # Event data + if payload.startswith("data:"): + # Decode payload + json_payload = json.loads(payload.lstrip("data:").rstrip("/n")) + # Parse payload + output = TextGenerationStreamResponse(**json_payload) + yield output.token.text if not details else output + + def _import_pil_image(): """Make sure `PIL` is installed on the machine.""" if not is_pillow_available(): @@ -914,3 +1335,28 @@ def _first_or_none(items: List[Any]) -> Optional[Any]: return items[0] or None except IndexError: return None + + +# "TGI servers" are servers running with the `text-generation-inference` backend. +# This backend is the go-to solution to run large language models at scale. However, +# for some smaller models (e.g. "gpt2") the default `transformers` + `api-inference` +# solution is still in use. +# +# Both approaches have very similar APIs, but not exactly the same. What we do first in +# the `text_generation` method is to assume the model is served via TGI. If we realize +# it's not the case (i.e. we receive an HTTP 400 Bad Request), we fallback to the +# default API with a warning message. We remember for each model if it's a TGI server +# or not using `_NON_TGI_SERVERS` global variable. +# +# For more details, see https://github.com/huggingface/text-generation-inference and +# https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task. 
+ +_NON_TGI_SERVERS: Set[Optional[str]] = set() + + +def _set_as_non_tgi(model: Optional[str]) -> None: + _NON_TGI_SERVERS.add(model) + + +def _is_tgi_server(model: Optional[str]) -> bool: + return model not in _NON_TGI_SERVERS diff --git a/src/huggingface_hub/inference/_text_generation.py b/src/huggingface_hub/inference/_text_generation.py new file mode 100644 index 0000000000..b7a26c7e60 --- /dev/null +++ b/src/huggingface_hub/inference/_text_generation.py @@ -0,0 +1,477 @@ +# coding=utf-8 +# Copyright 2023-present, the HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Original implementation taken from the `text-generation` Python client (see https://pypi.org/project/text-generation/ +# and https://github.com/huggingface/text-generation-inference/tree/main/clients/python) +# +# Changes compared to original implementation: +# - use pydantic.dataclasses instead of BaseModel +# - default to Python's dataclasses if Pydantic is not installed (same implementation but no validation) +# - added default values for all parameters (not needed in BaseModel but dataclasses yes) +# - integrated in `huggingface_hub.InferenceClient`` +# - added `stream: bool` and `details: bool` in the `text_generation` method instead of having different methods for each use case +# - NO asyncio support yet => TODO soon + +from dataclasses import field +from enum import Enum +from typing import List, NoReturn, Optional + +from requests import HTTPError + +from ..utils import is_pydantic_available + + +if is_pydantic_available(): + from pydantic import validator + from pydantic.dataclasses import dataclass +else: + # No validation if Pydantic is not installed + from dataclasses import dataclass # type: ignore + + def validator(x): # type: ignore + return lambda y: y + + +@dataclass +class TextGenerationParameters: + """ + Parameters for text generation. + + Args: + do_sample (`bool`, *optional*): + Activate logits sampling. Defaults to False. + max_new_tokens (`int`, *optional*): + Maximum number of generated tokens. Defaults to 20. + repetition_penalty (`Optional[float]`, *optional*): + The parameter for repetition penalty. A value of 1.0 means no penalty. See [this paper](https://arxiv.org/pdf/1909.05858.pdf) + for more details. Defaults to None. + return_full_text (`bool`, *optional*): + Whether to prepend the prompt to the generated text. Defaults to False. + stop (`List[str]`, *optional*): + Stop generating tokens if a member of `stop_sequences` is generated. Defaults to an empty list. + seed (`Optional[int]`, *optional*): + Random sampling seed. Defaults to None. + temperature (`Optional[float]`, *optional*): + The value used to modulate the logits distribution. Defaults to None. + top_k (`Optional[int]`, *optional*): + The number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to None. 
+ top_p (`Optional[float]`, *optional*): + If set to a value less than 1, only the smallest set of most probable tokens with probabilities that add up + to `top_p` or higher are kept for generation. Defaults to None. + truncate (`Optional[int]`, *optional*): + Truncate input tokens to the given size. Defaults to None. + typical_p (`Optional[float]`, *optional*): + Typical Decoding mass. See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) + for more information. Defaults to None. + best_of (`Optional[int]`, *optional*): + Generate `best_of` sequences and return the one with the highest token logprobs. Defaults to None. + watermark (`bool`, *optional*): + Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226). Defaults to False. + details (`bool`, *optional*): + Get generation details. Defaults to False. + decoder_input_details (`bool`, *optional*): + Get decoder input token logprobs and ids. Defaults to False. + """ + + # Activate logits sampling + do_sample: bool = False + # Maximum number of generated tokens + max_new_tokens: int = 20 + # The parameter for repetition penalty. 1.0 means no penalty. + # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. + repetition_penalty: Optional[float] = None + # Whether to prepend the prompt to the generated text + return_full_text: bool = False + # Stop generating tokens if a member of `stop_sequences` is generated + stop: List[str] = field(default_factory=lambda: []) + # Random sampling seed + seed: Optional[int] = None + # The value used to module the logits distribution. + temperature: Optional[float] = None + # The number of highest probability vocabulary tokens to keep for top-k-filtering. + top_k: Optional[int] = None + # If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or + # higher are kept for generation. 
+ top_p: Optional[float] = None + # truncate inputs tokens to the given size + truncate: Optional[int] = None + # Typical Decoding mass + # See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information + typical_p: Optional[float] = None + # Generate best_of sequences and return the one if the highest token logprobs + best_of: Optional[int] = None + # Watermarking with [A Watermark for Large Language Models](https://arxiv.org/abs/2301.10226) + watermark: bool = False + # Get generation details + details: bool = False + # Get decoder input token logprobs and ids + decoder_input_details: bool = False + + @validator("best_of") + def valid_best_of(cls, field_value, values): + if field_value is not None: + if field_value <= 0: + raise ValueError("`best_of` must be strictly positive") + if field_value > 1 and values["seed"] is not None: + raise ValueError("`seed` must not be set when `best_of` is > 1") + sampling = ( + values["do_sample"] + | (values["temperature"] is not None) + | (values["top_k"] is not None) + | (values["top_p"] is not None) + | (values["typical_p"] is not None) + ) + if field_value > 1 and not sampling: + raise ValueError("you must use sampling when `best_of` is > 1") + + return field_value + + @validator("repetition_penalty") + def valid_repetition_penalty(cls, v): + if v is not None and v <= 0: + raise ValueError("`repetition_penalty` must be strictly positive") + return v + + @validator("seed") + def valid_seed(cls, v): + if v is not None and v < 0: + raise ValueError("`seed` must be positive") + return v + + @validator("temperature") + def valid_temp(cls, v): + if v is not None and v <= 0: + raise ValueError("`temperature` must be strictly positive") + return v + + @validator("top_k") + def valid_top_k(cls, v): + if v is not None and v <= 0: + raise ValueError("`top_k` must be strictly positive") + return v + + @validator("top_p") + def valid_top_p(cls, v): + if v is not None and (v <= 0 or v >= 1.0): + raise ValueError("`top_p` must be > 0.0 and < 1.0") + return v + + @validator("truncate") + def valid_truncate(cls, v): + if v is not None and v <= 0: + raise ValueError("`truncate` must be strictly positive") + return v + + @validator("typical_p") + def valid_typical_p(cls, v): + if v is not None and (v <= 0 or v >= 1.0): + raise ValueError("`typical_p` must be > 0.0 and < 1.0") + return v + + +@dataclass +class TextGenerationRequest: + """ + Request object for text generation (only for internal use). + + Args: + inputs (`str`): + The prompt for text generation. + parameters (`Optional[TextGenerationParameters]`, *optional*): + Generation parameters. + stream (`bool`, *optional*): + Whether to stream output tokens. Defaults to False. + """ + + # Prompt + inputs: str + # Generation parameters + parameters: Optional[TextGenerationParameters] = None + # Whether to stream output tokens + stream: bool = False + + @validator("inputs") + def valid_input(cls, v): + if not v: + raise ValueError("`inputs` cannot be empty") + return v + + @validator("stream") + def valid_best_of_stream(cls, field_value, values): + parameters = values["parameters"] + if parameters is not None and parameters.best_of is not None and parameters.best_of > 1 and field_value: + raise ValueError("`best_of` != 1 is not supported when `stream` == True") + return field_value + + +# Decoder input tokens +@dataclass +class InputToken: + """ + Represents an input token. + + Args: + id (`int`): + Token ID from the model tokenizer. + text (`str`): + Token text. 
+ logprob (`float` or `None`): + Log probability of the token. Optional since the logprob of the first token cannot be computed. + """ + + # Token ID from the model tokenizer + id: int + # Token text + text: str + # Logprob + # Optional since the logprob of the first token cannot be computed + logprob: Optional[float] = None + + +# Generated tokens +@dataclass +class Token: + """ + Represents a token. + + Args: + id (`int`): + Token ID from the model tokenizer. + text (`str`): + Token text. + logprob (`float`): + Log probability of the token. + special (`bool`): + Indicates whether the token is a special token. It can be used to ignore + tokens when concatenating. + """ + + # Token ID from the model tokenizer + id: int + # Token text + text: str + # Logprob + logprob: float + # Is the token a special token + # Can be used to ignore tokens when concatenating + special: bool + + +# Generation finish reason +class FinishReason(str, Enum): + # number of generated tokens == `max_new_tokens` + Length = "length" + # the model generated its end of sequence token + EndOfSequenceToken = "eos_token" + # the model generated a text included in `stop_sequences` + StopSequence = "stop_sequence" + + +# Additional sequences when using the `best_of` parameter +@dataclass +class BestOfSequence: + """ + Represents a best-of sequence generated during text generation. + + Args: + generated_text (`str`): + The generated text. + finish_reason (`FinishReason`): + The reason for the generation to finish, represented by a `FinishReason` value. + generated_tokens (`int`): + The number of generated tokens in the sequence. + seed (`Optional[int]`): + The sampling seed if sampling was activated. + prefill (`List[InputToken]`): + The decoder input tokens. Empty if `decoder_input_details` is False. Defaults to an empty list. + tokens (`List[Token]`): + The generated tokens. Defaults to an empty list. + """ + + # Generated text + generated_text: str + # Generation finish reason + finish_reason: FinishReason + # Number of generated tokens + generated_tokens: int + # Sampling seed if sampling was activated + seed: Optional[int] = None + # Decoder input tokens, empty if decoder_input_details is False + prefill: List[InputToken] = field(default_factory=lambda: []) + # Generated tokens + tokens: List[Token] = field(default_factory=lambda: []) + + +# `generate` details +@dataclass +class Details: + """ + Represents details of a text generation. + + Args: + finish_reason (`FinishReason`): + The reason for the generation to finish, represented by a `FinishReason` value. + generated_tokens (`int`): + The number of generated tokens. + seed (`Optional[int]`): + The sampling seed if sampling was activated. + prefill (`List[InputToken]`, *optional*): + The decoder input tokens. Empty if `decoder_input_details` is False. Defaults to an empty list. + tokens (`List[Token]`): + The generated tokens. Defaults to an empty list. + best_of_sequences (`Optional[List[BestOfSequence]]`): + Additional sequences when using the `best_of` parameter. 
+ """ + + # Generation finish reason + finish_reason: FinishReason + # Number of generated tokens + generated_tokens: int + # Sampling seed if sampling was activated + seed: Optional[int] = None + # Decoder input tokens, empty if decoder_input_details is False + prefill: List[InputToken] = field(default_factory=lambda: []) + # Generated tokens + tokens: List[Token] = field(default_factory=lambda: []) + # Additional sequences when using the `best_of` parameter + best_of_sequences: Optional[List[BestOfSequence]] = None + + +# `generate` return value +@dataclass +class TextGenerationResponse: + """ + Represents a response for text generation. + + In practice, if `details=False` is passed (default), only the generated text is returned. + + Args: + generated_text (`str`): + The generated text. + details (`Optional[Details]`): + Generation details. Returned only if `details=True` is sent to the server. + """ + + # Generated text + generated_text: str + # Generation details + details: Optional[Details] = None + + +# `generate_stream` details +@dataclass +class StreamDetails: + """ + Represents details of a text generation stream. + + Args: + finish_reason (`FinishReason`): + The reason for the generation to finish, represented by a `FinishReason` value. + generated_tokens (`int`): + The number of generated tokens. + seed (`Optional[int]`): + The sampling seed if sampling was activated. + """ + + # Generation finish reason + finish_reason: FinishReason + # Number of generated tokens + generated_tokens: int + # Sampling seed if sampling was activated + seed: Optional[int] = None + + +# `generate_stream` return value +@dataclass +class TextGenerationStreamResponse: + """ + Represents a response for text generation when `stream=True` is passed + + Args: + token (`Token`): + The generated token. + generated_text (`Optional[str]`, *optional*): + The complete generated text. Only available when the generation is finished. + details (`Optional[StreamDetails]`, *optional*): + Generation details. Only available when the generation is finished. + """ + + # Generated token + token: Token + # Complete generated text + # Only available when the generation is finished + generated_text: Optional[str] = None + # Generation details + # Only available when the generation is finished + details: Optional[StreamDetails] = None + + +# TEXT GENERATION ERRORS +# ---------------------- +# Text-generation errors are parsed separately to handle as much as possible the errors returned by the text generation +# inference project (https://github.com/huggingface/text-generation-inference). +# ---------------------- + + +class TextGenerationError(HTTPError): + """Generic error raised if text-generation went wrong.""" + + +# Text Generation Inference Errors +class ValidationError(TextGenerationError): + """Server-side validation error.""" + + +class GenerationError(TextGenerationError): + pass + + +class OverloadedError(TextGenerationError): + pass + + +class IncompleteGenerationError(TextGenerationError): + pass + + +def raise_text_generation_error(http_error: HTTPError) -> NoReturn: + """ + Try to parse text-generation-inference error message and raise HTTPError in any case. + + Args: + error (`HTTPError`): + The HTTPError that have been raised. 
+ """ + # Try to parse a Text Generation Inference error + try: + payload = http_error.response.json() + message = payload.get("error") + error_type = payload.get("error_type") + except Exception: # no payload + raise http_error + + # If error_type => more information than `hf_raise_for_status` + if error_type is not None: + if error_type == "generation": + raise GenerationError(message) from http_error + if error_type == "incomplete_generation": + raise IncompleteGenerationError(message) from http_error + if error_type == "overloaded": + raise OverloadedError(message) from http_error + if error_type == "validation": + raise ValidationError(message) from http_error + + # Otherwise, fallback to default error + raise http_error diff --git a/src/huggingface_hub/utils/__init__.py b/src/huggingface_hub/utils/__init__.py index ac33c14049..06fca95c0a 100644 --- a/src/huggingface_hub/utils/__init__.py +++ b/src/huggingface_hub/utils/__init__.py @@ -58,6 +58,7 @@ get_jinja_version, get_numpy_version, get_pillow_version, + get_pydantic_version, get_pydot_version, get_python_version, get_tensorboard_version, @@ -73,6 +74,7 @@ is_jinja_available, is_notebook, is_pillow_available, + is_pydantic_available, is_pydot_available, is_tensorboard_available, is_tf_available, diff --git a/src/huggingface_hub/utils/_runtime.py b/src/huggingface_hub/utils/_runtime.py index 52a975762e..6bd2501239 100644 --- a/src/huggingface_hub/utils/_runtime.py +++ b/src/huggingface_hub/utils/_runtime.py @@ -41,6 +41,7 @@ "jinja": {"Jinja2"}, "numpy": {"numpy"}, "pillow": {"Pillow"}, + "pydantic": {"pydantic"}, "pydot": {"pydot"}, "tensorboard": {"tensorboardX"}, "tensorflow": ( @@ -159,6 +160,15 @@ def get_pillow_version() -> str: return _get_version("pillow") +# Pydantic +def is_pydantic_available() -> bool: + return _is_available("pydantic") + + +def get_pydantic_version() -> str: + return _get_version("pydantic") + + # Pydot def is_pydot_available() -> bool: return _is_available("pydot") @@ -287,6 +297,7 @@ def dump_environment_info() -> Dict[str, Any]: info["gradio"] = get_gradio_version() info["tensorboard"] = get_tensorboard_version() info["numpy"] = get_numpy_version() + info["pydantic"] = get_pydantic_version() # Environment variables info["ENDPOINT"] = constants.ENDPOINT diff --git a/tests/cassettes/TestTextGenerationClientVCR.test_generate_best_of.yaml b/tests/cassettes/TestTextGenerationClientVCR.test_generate_best_of.yaml new file mode 100644 index 0000000000..d4a7f5d521 --- /dev/null +++ b/tests/cassettes/TestTextGenerationClientVCR.test_generate_best_of.yaml @@ -0,0 +1,53 @@ +interactions: +- request: + body: '{"inputs": "test", "parameters": {"do_sample": true, "max_new_tokens": + 1, "repetition_penalty": null, "return_full_text": false, "stop": [], "seed": + null, "temperature": null, "top_k": null, "top_p": null, "truncate": null, "typical_p": + null, "best_of": 2, "watermark": false, "details": true, "decoder_input_details": + true}, "stream": false}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + Content-Length: + - '342' + Content-Type: + - application/json + user-agent: + - unknown/None; hf_hub/0.16.0.dev0; python/3.10.6; torch/1.12.1; tensorflow/2.11.0; + fastcore/1.5.23 + method: POST + uri: https://api-inference.huggingface.co/models/google/flan-t5-xxl + response: + body: + string: 
'[{"generated_text":"5","details":{"finish_reason":"length","generated_tokens":1,"seed":2835381823020158940,"prefill":[{"id":0,"text":"","logprob":null}],"tokens":[{"id":305,"text":" + 5","logprob":-6.0976562,"special":false}],"best_of_sequences":[{"generated_text":"10-","finish_reason":"length","generated_tokens":1,"seed":1621834340256795699,"prefill":[{"id":0,"text":"","logprob":null}],"tokens":[{"id":9445,"text":" + 10-","logprob":-10.6875,"special":false}]}]}}]' + headers: + Connection: + - keep-alive + Content-Length: + - '474' + Content-Type: + - application/json + Date: + - Mon, 19 Jun 2023 15:19:38 GMT + access-control-allow-credentials: + - 'true' + vary: + - Origin, Access-Control-Request-Method, Access-Control-Request-Headers + x-compute-time: + - '110' + x-compute-type: + - cache + x-request-id: + - YrO0tIq-p_KLVqDmvlPjF + x-sha: + - ad196ce8c46191d6a52592960835ff96d30152b5 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/cassettes/TestTextGenerationClientVCR.test_generate_no_details.yaml b/tests/cassettes/TestTextGenerationClientVCR.test_generate_no_details.yaml new file mode 100644 index 0000000000..9bbf0a76bf --- /dev/null +++ b/tests/cassettes/TestTextGenerationClientVCR.test_generate_no_details.yaml @@ -0,0 +1,51 @@ +interactions: +- request: + body: '{"inputs": "test", "parameters": {"do_sample": false, "max_new_tokens": + 1, "repetition_penalty": null, "return_full_text": false, "stop": [], "seed": + null, "temperature": null, "top_k": null, "top_p": null, "truncate": null, "typical_p": + null, "best_of": null, "watermark": false, "details": false, "decoder_input_details": + false}, "stream": false}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + Content-Length: + - '348' + Content-Type: + - application/json + user-agent: + - unknown/None; hf_hub/0.16.0.dev0; python/3.10.6; torch/1.12.1; tensorflow/2.11.0; + fastcore/1.5.23 + method: POST + uri: https://api-inference.huggingface.co/models/google/flan-t5-xxl + response: + body: + string: '[{"generated_text":""}]' + headers: + Connection: + - keep-alive + Content-Length: + - '23' + Content-Type: + - application/json + Date: + - Mon, 19 Jun 2023 15:19:39 GMT + access-control-allow-credentials: + - 'true' + vary: + - Origin, Access-Control-Request-Method, Access-Control-Request-Headers + x-compute-time: + - '57' + x-compute-type: + - cache + x-request-id: + - P3ynVdl0lrN0nMWY4Di62 + x-sha: + - ad196ce8c46191d6a52592960835ff96d30152b5 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/cassettes/TestTextGenerationClientVCR.test_generate_non_tgi_endpoint.yaml b/tests/cassettes/TestTextGenerationClientVCR.test_generate_non_tgi_endpoint.yaml new file mode 100644 index 0000000000..dccb99f5da --- /dev/null +++ b/tests/cassettes/TestTextGenerationClientVCR.test_generate_non_tgi_endpoint.yaml @@ -0,0 +1,210 @@ +interactions: +- request: + body: '{"inputs": "0 1 2", "parameters": {"do_sample": false, "max_new_tokens": + 10, "repetition_penalty": null, "return_full_text": false, "stop": [], "seed": + null, "temperature": null, "top_k": null, "top_p": null, "truncate": null, "typical_p": + null, "best_of": null, "watermark": false, "details": false, "decoder_input_details": + false}, "stream": false}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + Content-Length: + - '350' + Content-Type: + - application/json + user-agent: + - unknown/None; hf_hub/0.16.0.dev0; python/3.10.6; torch/1.12.1; 
tensorflow/2.11.0; + fastcore/1.5.23 + method: POST + uri: https://api-inference.huggingface.co/models/gpt2 + response: + body: + string: '{"error":"The following `model_kwargs` are not used by the model: [''watermark'', + ''stop'', ''details'', ''decoder_input_details''] (note: typos in the generate + arguments will also show up in this list)","warnings":["There was an inference + error: The following `model_kwargs` are not used by the model: [''watermark'', + ''stop'', ''details'', ''decoder_input_details''] (note: typos in the generate + arguments will also show up in this list)"]}' + headers: + Connection: + - keep-alive + Content-Type: + - application/json + Date: + - Tue, 20 Jun 2023 14:33:02 GMT + Transfer-Encoding: + - chunked + access-control-allow-credentials: + - 'true' + access-control-expose-headers: + - x-compute-type, x-compute-time + server: + - uvicorn + vary: + - Origin, Access-Control-Request-Method, Access-Control-Request-Headers + x-compute-time: + - '0.003' + x-compute-type: + - cpu + x-request-id: + - XMk4gR6bC7lYiuNKaa7dP + x-sha: + - e7da7f221d5bf496a48136c0cd264e630fe9fcc8 + status: + code: 400 + message: Bad Request +- request: + body: '{"inputs": "0 1 2", "parameters": {"do_sample": false, "max_new_tokens": + 10, "repetition_penalty": null, "return_full_text": false, "seed": null, "temperature": + null, "top_k": null, "top_p": null, "truncate": null, "typical_p": null, "best_of": + null}, "stream": false}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + Content-Length: + - '268' + Content-Type: + - application/json + user-agent: + - unknown/None; hf_hub/0.16.0.dev0; python/3.10.6; torch/1.12.1; tensorflow/2.11.0; + fastcore/1.5.23 + method: POST + uri: https://api-inference.huggingface.co/models/gpt2 + response: + body: + string: '[{"generated_text":" 3 4 5 6 7 8 9 10 11 12"}]' + headers: + Connection: + - keep-alive + Content-Length: + - '46' + Content-Type: + - application/json + Date: + - Tue, 20 Jun 2023 14:33:02 GMT + access-control-allow-credentials: + - 'true' + vary: + - Origin, Access-Control-Request-Method, Access-Control-Request-Headers + x-compute-time: + - '0.215' + x-compute-type: + - cache + x-request-id: + - Xu4BHkuFIXmDdmVaAw6ZL + x-sha: + - e7da7f221d5bf496a48136c0cd264e630fe9fcc8 + status: + code: 200 + message: OK +- request: + body: '{"inputs": "4 5 6", "parameters": {"do_sample": false, "max_new_tokens": + 10, "repetition_penalty": null, "return_full_text": false, "seed": null, "temperature": + null, "top_k": null, "top_p": null, "truncate": null, "typical_p": null, "best_of": + null}, "stream": false}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + Content-Length: + - '268' + Content-Type: + - application/json + user-agent: + - unknown/None; hf_hub/0.16.0.dev0; python/3.10.6; torch/1.12.1; tensorflow/2.11.0; + fastcore/1.5.23 + method: POST + uri: https://api-inference.huggingface.co/models/gpt2 + response: + body: + string: '[{"generated_text":" 7 8 9 10 11 12 13 14 15 16"}]' + headers: + Connection: + - keep-alive + Content-Type: + - application/json + Date: + - Tue, 20 Jun 2023 14:33:03 GMT + Transfer-Encoding: + - chunked + access-control-allow-credentials: + - 'true' + access-control-expose-headers: + - x-compute-type, x-compute-time + server: + - uvicorn + vary: + - Origin, Access-Control-Request-Method, Access-Control-Request-Headers + x-compute-characters: + - '5' + x-compute-time: + - '0.284' + x-compute-type: + - cpu + 
x-request-id: + - KD-x8pfpYvdw7n5zT4InB + x-sha: + - e7da7f221d5bf496a48136c0cd264e630fe9fcc8 + status: + code: 200 + message: OK +- request: + body: '{"inputs": "0 1 2", "parameters": {"do_sample": false, "max_new_tokens": + 10, "repetition_penalty": null, "return_full_text": false, "seed": null, "temperature": + null, "top_k": null, "top_p": null, "truncate": null, "typical_p": null, "best_of": + null}, "stream": false}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + Content-Length: + - '268' + Content-Type: + - application/json + user-agent: + - unknown/None; hf_hub/0.16.0.dev0; python/3.10.6; torch/1.12.1; tensorflow/2.11.0; + fastcore/1.5.23 + method: POST + uri: https://api-inference.huggingface.co/models/gpt2 + response: + body: + string: '[{"generated_text":" 3 4 5 6 7 8 9 10 11 12"}]' + headers: + Connection: + - keep-alive + Content-Length: + - '46' + Content-Type: + - application/json + Date: + - Tue, 20 Jun 2023 14:33:03 GMT + access-control-allow-credentials: + - 'true' + vary: + - Origin, Access-Control-Request-Method, Access-Control-Request-Headers + x-compute-time: + - '0.215' + x-compute-type: + - cache + x-request-id: + - DXrY2BH9DQPbdWmrH3-gM + x-sha: + - e7da7f221d5bf496a48136c0cd264e630fe9fcc8 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/cassettes/TestTextGenerationClientVCR.test_generate_stream_no_details.yaml b/tests/cassettes/TestTextGenerationClientVCR.test_generate_stream_no_details.yaml new file mode 100644 index 0000000000..1f11b0980b --- /dev/null +++ b/tests/cassettes/TestTextGenerationClientVCR.test_generate_stream_no_details.yaml @@ -0,0 +1,52 @@ +interactions: +- request: + body: '{"inputs": "test", "parameters": {"do_sample": false, "max_new_tokens": + 1, "repetition_penalty": null, "return_full_text": false, "stop": [], "seed": + null, "temperature": null, "top_k": null, "top_p": null, "truncate": null, "typical_p": + null, "best_of": null, "watermark": false, "details": true, "decoder_input_details": + false}, "stream": true}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + Content-Length: + - '346' + Content-Type: + - application/json + user-agent: + - unknown/None; hf_hub/0.16.0.dev0; python/3.10.6; torch/1.12.1; tensorflow/2.11.0; + fastcore/1.5.23 + method: POST + uri: https://api-inference.huggingface.co/models/google/flan-t5-xxl + response: + body: + string: 'data:{"token":{"id":3,"text":" ","logprob":-2.0078125,"special":false},"generated_text":"","details":{"finish_reason":"length","generated_tokens":1,"seed":null}} + + + ' + headers: + Connection: + - keep-alive + Content-Length: + - '163' + Content-Type: + - text/event-stream + Date: + - Mon, 19 Jun 2023 15:22:08 GMT + access-control-allow-credentials: + - 'true' + vary: + - Origin, Access-Control-Request-Method, Access-Control-Request-Headers + x-compute-type: + - cache + x-request-id: + - Cclm6imN3ko6EnvbACvvt + x-sha: + - ad196ce8c46191d6a52592960835ff96d30152b5 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/cassettes/TestTextGenerationClientVCR.test_generate_stream_with_details.yaml b/tests/cassettes/TestTextGenerationClientVCR.test_generate_stream_with_details.yaml new file mode 100644 index 0000000000..0001ff406f --- /dev/null +++ b/tests/cassettes/TestTextGenerationClientVCR.test_generate_stream_with_details.yaml @@ -0,0 +1,52 @@ +interactions: +- request: + body: '{"inputs": "test", "parameters": {"do_sample": false, "max_new_tokens": + 1, 
"repetition_penalty": null, "return_full_text": false, "stop": [], "seed": + null, "temperature": null, "top_k": null, "top_p": null, "truncate": null, "typical_p": + null, "best_of": null, "watermark": false, "details": true, "decoder_input_details": + false}, "stream": true}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + Content-Length: + - '346' + Content-Type: + - application/json + user-agent: + - unknown/None; hf_hub/0.16.0.dev0; python/3.10.6; torch/1.12.1; tensorflow/2.11.0; + fastcore/1.5.23 + method: POST + uri: https://api-inference.huggingface.co/models/google/flan-t5-xxl + response: + body: + string: 'data:{"token":{"id":3,"text":" ","logprob":-2.0078125,"special":false},"generated_text":"","details":{"finish_reason":"length","generated_tokens":1,"seed":null}} + + + ' + headers: + Connection: + - keep-alive + Content-Length: + - '163' + Content-Type: + - text/event-stream + Date: + - Mon, 19 Jun 2023 15:22:09 GMT + access-control-allow-credentials: + - 'true' + vary: + - Origin, Access-Control-Request-Method, Access-Control-Request-Headers + x-compute-type: + - cache + x-request-id: + - _72TkULJsztkhCE-ss7vX + x-sha: + - ad196ce8c46191d6a52592960835ff96d30152b5 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/cassettes/TestTextGenerationClientVCR.test_generate_validation_error.yaml b/tests/cassettes/TestTextGenerationClientVCR.test_generate_validation_error.yaml new file mode 100644 index 0000000000..1fdefdde28 --- /dev/null +++ b/tests/cassettes/TestTextGenerationClientVCR.test_generate_validation_error.yaml @@ -0,0 +1,50 @@ +interactions: +- request: + body: '{"inputs": "test", "parameters": {"do_sample": false, "max_new_tokens": + 10000, "repetition_penalty": null, "return_full_text": false, "stop": [], "seed": + null, "temperature": null, "top_k": null, "top_p": null, "truncate": null, "typical_p": + null, "best_of": null, "watermark": false, "details": false, "decoder_input_details": + false}, "stream": false}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + Content-Length: + - '352' + Content-Type: + - application/json + user-agent: + - unknown/None; hf_hub/0.16.0.dev0; python/3.10.6; torch/1.12.1; tensorflow/2.11.0; + fastcore/1.5.23 + method: POST + uri: https://api-inference.huggingface.co/models/google/flan-t5-xxl + response: + body: + string: '{"error":"Input validation error: `inputs` tokens + `max_new_tokens` + must be <= 1512. 
Given: 2 `inputs` tokens and 10000 `max_new_tokens`","error_type":"validation"}' + headers: + Connection: + - keep-alive + Content-Type: + - application/json + Date: + - Mon, 19 Jun 2023 15:22:09 GMT + Transfer-Encoding: + - chunked + access-control-allow-credentials: + - 'true' + access-control-allow-origin: + - '*' + vary: + - origin, Origin, Access-Control-Request-Method, Access-Control-Request-Headers + x-request-id: + - JXgh_Lwucgb8AbjAKAXMn + x-sha: + - ad196ce8c46191d6a52592960835ff96d30152b5 + status: + code: 422 + message: Unprocessable Entity +version: 1 diff --git a/tests/cassettes/TestTextGenerationClientVCR.test_generate_with_details.yaml b/tests/cassettes/TestTextGenerationClientVCR.test_generate_with_details.yaml new file mode 100644 index 0000000000..cb1cee4a2c --- /dev/null +++ b/tests/cassettes/TestTextGenerationClientVCR.test_generate_with_details.yaml @@ -0,0 +1,52 @@ +interactions: +- request: + body: '{"inputs": "test", "parameters": {"do_sample": false, "max_new_tokens": + 1, "repetition_penalty": null, "return_full_text": false, "stop": [], "seed": + null, "temperature": null, "top_k": null, "top_p": null, "truncate": null, "typical_p": + null, "best_of": null, "watermark": false, "details": true, "decoder_input_details": + true}, "stream": false}' + headers: + Accept: + - '*/*' + Accept-Encoding: + - gzip, deflate + Connection: + - keep-alive + Content-Length: + - '346' + Content-Type: + - application/json + user-agent: + - unknown/None; hf_hub/0.16.0.dev0; python/3.10.6; torch/1.12.1; tensorflow/2.11.0; + fastcore/1.5.23 + method: POST + uri: https://api-inference.huggingface.co/models/google/flan-t5-xxl + response: + body: + string: '[{"generated_text":"","details":{"finish_reason":"length","generated_tokens":1,"seed":null,"prefill":[{"id":0,"text":"","logprob":null}],"tokens":[{"id":3,"text":" + ","logprob":-2.0078125,"special":false}]}}]' + headers: + Connection: + - keep-alive + Content-Length: + - '212' + Content-Type: + - application/json + Date: + - Mon, 19 Jun 2023 15:22:10 GMT + access-control-allow-credentials: + - 'true' + vary: + - Origin, Access-Control-Request-Method, Access-Control-Request-Headers + x-compute-time: + - '56' + x-compute-type: + - cache + x-request-id: + - mYbIoIndwB6eIqjB5CNM5 + x-sha: + - ad196ce8c46191d6a52592960835ff96d30152b5 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/test_inference_client.py b/tests/test_inference_client.py index bb0d7745a8..58bfe6e10d 100644 --- a/tests/test_inference_client.py +++ b/tests/test_inference_client.py @@ -304,4 +304,5 @@ def test_mocked_post(self, get_session_mock: MagicMock) -> None: headers={"user-agent": expected_user_agent, "X-My-Header": "foo"}, cookies={"my-cookie": "bar"}, timeout=None, + stream=False, ) diff --git a/tests/test_inference_text_generation.py b/tests/test_inference_text_generation.py new file mode 100644 index 0000000000..a426fb8b3d --- /dev/null +++ b/tests/test_inference_text_generation.py @@ -0,0 +1,222 @@ +# Original implementation taken from the `text-generation` Python client (see https://pypi.org/project/text-generation/ +# and https://github.com/huggingface/text-generation-inference/tree/main/clients/python) +# +# See './src/huggingface_hub/inference/_text_generation.py' for details. 
+import unittest +from typing import Dict +from unittest.mock import MagicMock, patch + +import pytest +from pydantic import ValidationError +from requests import HTTPError + +from huggingface_hub import InferenceClient +from huggingface_hub.inference._client import _NON_TGI_SERVERS +from huggingface_hub.inference._text_generation import ( + FinishReason, + GenerationError, + IncompleteGenerationError, + InputToken, + OverloadedError, + TextGenerationParameters, + TextGenerationRequest, + raise_text_generation_error, +) +from huggingface_hub.inference._text_generation import ( + ValidationError as TextGenerationValidationError, +) + + +class TestTextGenerationTypes(unittest.TestCase): + def test_parameters_validation(self): + # Test best_of + TextGenerationParameters(best_of=1) + with self.assertRaises(ValidationError): + TextGenerationParameters(best_of=0) + with self.assertRaises(ValidationError): + TextGenerationParameters(best_of=-1) + TextGenerationParameters(best_of=2, do_sample=True) + with self.assertRaises(ValidationError): + TextGenerationParameters(best_of=2) + with self.assertRaises(ValidationError): + TextGenerationParameters(best_of=2, seed=1) + + # Test repetition_penalty + TextGenerationParameters(repetition_penalty=1) + with self.assertRaises(ValidationError): + TextGenerationParameters(repetition_penalty=0) + with self.assertRaises(ValidationError): + TextGenerationParameters(repetition_penalty=-1) + + # Test seed + TextGenerationParameters(seed=1) + with self.assertRaises(ValidationError): + TextGenerationParameters(seed=-1) + + # Test temperature + TextGenerationParameters(temperature=1) + with self.assertRaises(ValidationError): + TextGenerationParameters(temperature=0) + with self.assertRaises(ValidationError): + TextGenerationParameters(temperature=-1) + + # Test top_k + TextGenerationParameters(top_k=1) + with self.assertRaises(ValidationError): + TextGenerationParameters(top_k=0) + with self.assertRaises(ValidationError): + TextGenerationParameters(top_k=-1) + + # Test top_p + TextGenerationParameters(top_p=0.5) + with self.assertRaises(ValidationError): + TextGenerationParameters(top_p=0) + with self.assertRaises(ValidationError): + TextGenerationParameters(top_p=-1) + with self.assertRaises(ValidationError): + TextGenerationParameters(top_p=1) + + # Test truncate + TextGenerationParameters(truncate=1) + with self.assertRaises(ValidationError): + TextGenerationParameters(truncate=0) + with self.assertRaises(ValidationError): + TextGenerationParameters(truncate=-1) + + # Test typical_p + TextGenerationParameters(typical_p=0.5) + with self.assertRaises(ValidationError): + TextGenerationParameters(typical_p=0) + with self.assertRaises(ValidationError): + TextGenerationParameters(typical_p=-1) + with self.assertRaises(ValidationError): + TextGenerationParameters(typical_p=1) + + def test_request_validation(self): + TextGenerationRequest(inputs="test") + + with self.assertRaises(ValidationError): + TextGenerationRequest(inputs="") + + TextGenerationRequest(inputs="test", stream=True) + TextGenerationRequest(inputs="test", parameters=TextGenerationParameters(best_of=2, do_sample=True)) + + with self.assertRaises(ValidationError): + TextGenerationRequest( + inputs="test", parameters=TextGenerationParameters(best_of=2, do_sample=True), stream=True + ) + + +class TestTextGenerationErrors(unittest.TestCase): + def test_generation_error(self): + error = _mocked_error({"error_type": "generation", "error": "test"}) + with self.assertRaises(GenerationError): + 
raise_text_generation_error(error) + + def test_incomplete_generation_error(self): + error = _mocked_error({"error_type": "incomplete_generation", "error": "test"}) + with self.assertRaises(IncompleteGenerationError): + raise_text_generation_error(error) + + def test_overloaded_error(self): + error = _mocked_error({"error_type": "overloaded", "error": "test"}) + with self.assertRaises(OverloadedError): + raise_text_generation_error(error) + + def test_validation_error(self): + error = _mocked_error({"error_type": "validation", "error": "test"}) + with self.assertRaises(TextGenerationValidationError): + raise_text_generation_error(error) + + +def _mocked_error(payload: Dict) -> MagicMock: + error = HTTPError(response=MagicMock()) + error.response.json.return_value = payload + return error + + +@pytest.mark.vcr +@patch.dict("huggingface_hub.inference._client._NON_TGI_SERVERS", {}) +class TestTextGenerationClientVCR(unittest.TestCase): + """Use VCR test to avoid making requests to the prod infra.""" + + def setUp(self) -> None: + self.client = InferenceClient(model="google/flan-t5-xxl") + return super().setUp() + + def test_generate_no_details(self): + response = self.client.text_generation("test", details=False, max_new_tokens=1) + + assert response == "" + + def test_generate_with_details(self): + response = self.client.text_generation("test", details=True, max_new_tokens=1, decoder_input_details=True) + + assert response.generated_text == "" + assert response.details.finish_reason == FinishReason.Length + assert response.details.generated_tokens == 1 + assert response.details.seed is None + assert len(response.details.prefill) == 1 + assert response.details.prefill[0] == InputToken(id=0, text="", logprob=None) + assert len(response.details.tokens) == 1 + assert response.details.tokens[0].id == 3 + assert response.details.tokens[0].text == " " + assert not response.details.tokens[0].special + + def test_generate_best_of(self): + response = self.client.text_generation( + "test", max_new_tokens=1, best_of=2, do_sample=True, decoder_input_details=True, details=True + ) + + assert response.details.seed is not None + assert response.details.best_of_sequences is not None + assert len(response.details.best_of_sequences) == 1 + assert response.details.best_of_sequences[0].seed is not None + + def test_generate_validation_error(self): + with self.assertRaises(TextGenerationValidationError): + self.client.text_generation("test", max_new_tokens=10_000) + + def test_generate_stream_no_details(self): + responses = [ + response for response in self.client.text_generation("test", max_new_tokens=1, stream=True, details=True) + ] + + assert len(responses) == 1 + response = responses[0] + + assert response.generated_text == "" + assert response.details.finish_reason == FinishReason.Length + assert response.details.generated_tokens == 1 + assert response.details.seed is None + + def test_generate_stream_with_details(self): + responses = [ + response for response in self.client.text_generation("test", max_new_tokens=1, stream=True, details=True) + ] + + assert len(responses) == 1 + response = responses[0] + + assert response.generated_text == "" + assert response.details.finish_reason == FinishReason.Length + assert response.details.generated_tokens == 1 + assert response.details.seed is None + + def test_generate_non_tgi_endpoint(self): + text = self.client.text_generation("0 1 2", model="gpt2", max_new_tokens=10) + self.assertEqual(text, " 3 4 5 6 7 8 9 10 11 12") + self.assertIn("gpt2", _NON_TGI_SERVERS) + 
+
+        # Watermark is ignored (+ warning)
+        with self.assertWarns(UserWarning):
+            self.client.text_generation("4 5 6", model="gpt2", max_new_tokens=10, watermark=True)
+
+        # Returns a plain string even if details=True (+ warning)
+        with self.assertWarns(UserWarning):
+            text = self.client.text_generation("0 1 2", model="gpt2", max_new_tokens=10, details=True)
+        self.assertIsInstance(text, str)
+
+        # Requesting a stream raises an error
+        with self.assertRaises(ValueError):
+            self.client.text_generation("0 1 2", model="gpt2", max_new_tokens=10, stream=True)