From 3be2afb4e970406d82396497899ab0036e583022 Mon Sep 17 00:00:00 2001 From: Lucain Date: Mon, 26 Aug 2024 15:06:01 +0200 Subject: [PATCH] [Inference] Support `stop` parameter in `text-generation` instead of `stop_sequences` (#2473) * Support stop parameter in text-generation instead of stop_sequences * no need for a special case for stop_sequences * comment * deprecate properly --- src/huggingface_hub/inference/_client.py | 35 ++++++++++++++----- .../inference/_generated/_async_client.py | 35 ++++++++++++++----- 2 files changed, 52 insertions(+), 18 deletions(-) diff --git a/src/huggingface_hub/inference/_client.py b/src/huggingface_hub/inference/_client.py index ef29242007..84131d298e 100644 --- a/src/huggingface_hub/inference/_client.py +++ b/src/huggingface_hub/inference/_client.py @@ -1655,7 +1655,8 @@ def text_generation( # type: ignore repetition_penalty: Optional[float] = None, return_full_text: Optional[bool] = False, # Manual default value seed: Optional[int] = None, - stop_sequences: Optional[List[str]] = None, # Same as `stop` + stop: Optional[List[str]] = None, + stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead temperature: Optional[float] = None, top_k: Optional[int] = None, top_n_tokens: Optional[int] = None, @@ -1684,7 +1685,8 @@ def text_generation( # type: ignore repetition_penalty: Optional[float] = None, return_full_text: Optional[bool] = False, # Manual default value seed: Optional[int] = None, - stop_sequences: Optional[List[str]] = None, # Same as `stop` + stop: Optional[List[str]] = None, + stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead temperature: Optional[float] = None, top_k: Optional[int] = None, top_n_tokens: Optional[int] = None, @@ -1713,7 +1715,8 @@ def text_generation( # type: ignore repetition_penalty: Optional[float] = None, return_full_text: Optional[bool] = False, # Manual default value seed: Optional[int] = None, - stop_sequences: Optional[List[str]] = None, # Same as `stop` + stop: Optional[List[str]] = None, + stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead temperature: Optional[float] = None, top_k: Optional[int] = None, top_n_tokens: Optional[int] = None, @@ -1742,7 +1745,8 @@ def text_generation( # type: ignore repetition_penalty: Optional[float] = None, return_full_text: Optional[bool] = False, # Manual default value seed: Optional[int] = None, - stop_sequences: Optional[List[str]] = None, # Same as `stop` + stop: Optional[List[str]] = None, + stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead temperature: Optional[float] = None, top_k: Optional[int] = None, top_n_tokens: Optional[int] = None, @@ -1771,7 +1775,8 @@ def text_generation( repetition_penalty: Optional[float] = None, return_full_text: Optional[bool] = False, # Manual default value seed: Optional[int] = None, - stop_sequences: Optional[List[str]] = None, # Same as `stop` + stop: Optional[List[str]] = None, + stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead temperature: Optional[float] = None, top_k: Optional[int] = None, top_n_tokens: Optional[int] = None, @@ -1799,7 +1804,8 @@ def text_generation( repetition_penalty: Optional[float] = None, return_full_text: Optional[bool] = False, # Manual default value seed: Optional[int] = None, - stop_sequences: Optional[List[str]] = None, # Same as `stop` + stop: Optional[List[str]] = None, + stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead temperature: Optional[float] = None, top_k: Optional[int] = None, top_n_tokens: Optional[int] = None, @@ -1864,8 +1870,10 @@ def text_generation( Whether to prepend the prompt to the generated text seed (`int`, *optional*): Random sampling seed + stop (`List[str]`, *optional*): + Stop generating tokens if a member of `stop` is generated. stop_sequences (`List[str]`, *optional*): - Stop generating tokens if a member of `stop_sequences` is generated + Deprecated argument. Use `stop` instead. temperature (`float`, *optional*): The value used to module the logits distribution. top_n_tokens (`int`, *optional*): @@ -2009,6 +2017,15 @@ def text_generation( ) decoder_input_details = False + if stop_sequences is not None: + warnings.warn( + "`stop_sequences` is a deprecated argument for `text_generation` task" + " and will be removed in version '0.28.0'. Use `stop` instead.", + FutureWarning, + ) + if stop is None: + stop = stop_sequences # use deprecated arg if provided + # Build payload parameters = { "adapter_id": adapter_id, @@ -2022,7 +2039,7 @@ def text_generation( "repetition_penalty": repetition_penalty, "return_full_text": return_full_text, "seed": seed, - "stop": stop_sequences if stop_sequences is not None else [], + "stop": stop if stop is not None else [], "temperature": temperature, "top_k": top_k, "top_n_tokens": top_n_tokens, @@ -2092,7 +2109,7 @@ def text_generation( repetition_penalty=repetition_penalty, return_full_text=return_full_text, seed=seed, - stop_sequences=stop_sequences, + stop=stop, temperature=temperature, top_k=top_k, top_n_tokens=top_n_tokens, diff --git a/src/huggingface_hub/inference/_generated/_async_client.py b/src/huggingface_hub/inference/_generated/_async_client.py index b7d6bc20ab..f34b4e33fd 100644 --- a/src/huggingface_hub/inference/_generated/_async_client.py +++ b/src/huggingface_hub/inference/_generated/_async_client.py @@ -1688,7 +1688,8 @@ async def text_generation( # type: ignore repetition_penalty: Optional[float] = None, return_full_text: Optional[bool] = False, # Manual default value seed: Optional[int] = None, - stop_sequences: Optional[List[str]] = None, # Same as `stop` + stop: Optional[List[str]] = None, + stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead temperature: Optional[float] = None, top_k: Optional[int] = None, top_n_tokens: Optional[int] = None, @@ -1717,7 +1718,8 @@ async def text_generation( # type: ignore repetition_penalty: Optional[float] = None, return_full_text: Optional[bool] = False, # Manual default value seed: Optional[int] = None, - stop_sequences: Optional[List[str]] = None, # Same as `stop` + stop: Optional[List[str]] = None, + stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead temperature: Optional[float] = None, top_k: Optional[int] = None, top_n_tokens: Optional[int] = None, @@ -1746,7 +1748,8 @@ async def text_generation( # type: ignore repetition_penalty: Optional[float] = None, return_full_text: Optional[bool] = False, # Manual default value seed: Optional[int] = None, - stop_sequences: Optional[List[str]] = None, # Same as `stop` + stop: Optional[List[str]] = None, + stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead temperature: Optional[float] = None, top_k: Optional[int] = None, top_n_tokens: Optional[int] = None, @@ -1775,7 +1778,8 @@ async def text_generation( # type: ignore repetition_penalty: Optional[float] = None, return_full_text: Optional[bool] = False, # Manual default value seed: Optional[int] = None, - stop_sequences: Optional[List[str]] = None, # Same as `stop` + stop: Optional[List[str]] = None, + stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead temperature: Optional[float] = None, top_k: Optional[int] = None, top_n_tokens: Optional[int] = None, @@ -1804,7 +1808,8 @@ async def text_generation( repetition_penalty: Optional[float] = None, return_full_text: Optional[bool] = False, # Manual default value seed: Optional[int] = None, - stop_sequences: Optional[List[str]] = None, # Same as `stop` + stop: Optional[List[str]] = None, + stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead temperature: Optional[float] = None, top_k: Optional[int] = None, top_n_tokens: Optional[int] = None, @@ -1832,7 +1837,8 @@ async def text_generation( repetition_penalty: Optional[float] = None, return_full_text: Optional[bool] = False, # Manual default value seed: Optional[int] = None, - stop_sequences: Optional[List[str]] = None, # Same as `stop` + stop: Optional[List[str]] = None, + stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead temperature: Optional[float] = None, top_k: Optional[int] = None, top_n_tokens: Optional[int] = None, @@ -1897,8 +1903,10 @@ async def text_generation( Whether to prepend the prompt to the generated text seed (`int`, *optional*): Random sampling seed + stop (`List[str]`, *optional*): + Stop generating tokens if a member of `stop` is generated. stop_sequences (`List[str]`, *optional*): - Stop generating tokens if a member of `stop_sequences` is generated + Deprecated argument. Use `stop` instead. temperature (`float`, *optional*): The value used to module the logits distribution. top_n_tokens (`int`, *optional*): @@ -2043,6 +2051,15 @@ async def text_generation( ) decoder_input_details = False + if stop_sequences is not None: + warnings.warn( + "`stop_sequences` is a deprecated argument for `text_generation` task" + " and will be removed in version '0.28.0'. Use `stop` instead.", + FutureWarning, + ) + if stop is None: + stop = stop_sequences # use deprecated arg if provided + # Build payload parameters = { "adapter_id": adapter_id, @@ -2056,7 +2073,7 @@ async def text_generation( "repetition_penalty": repetition_penalty, "return_full_text": return_full_text, "seed": seed, - "stop": stop_sequences if stop_sequences is not None else [], + "stop": stop if stop is not None else [], "temperature": temperature, "top_k": top_k, "top_n_tokens": top_n_tokens, @@ -2126,7 +2143,7 @@ async def text_generation( repetition_penalty=repetition_penalty, return_full_text=return_full_text, seed=seed, - stop_sequences=stop_sequences, + stop=stop, temperature=temperature, top_k=top_k, top_n_tokens=top_n_tokens,