Skip to content

Commit

Permalink
[Inference] Support `stop` parameter in text-generation instead of `stop_sequences` (#2473)
Browse files Browse the repository at this point in the history

* Support stop parameter in text-generation instead of stop_sequences

* no need for a special case for stop_sequences

* comment

* deprecate properly
  • Loading branch information
Wauplin authored Aug 26, 2024
1 parent 6e9e4e4 commit 3be2afb
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 18 deletions.
35 changes: 26 additions & 9 deletions src/huggingface_hub/inference/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1655,7 +1655,8 @@ def text_generation( # type: ignore
repetition_penalty: Optional[float] = None,
return_full_text: Optional[bool] = False, # Manual default value
seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None, # Same as `stop`
stop: Optional[List[str]] = None,
stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_n_tokens: Optional[int] = None,
Expand Down Expand Up @@ -1684,7 +1685,8 @@ def text_generation( # type: ignore
repetition_penalty: Optional[float] = None,
return_full_text: Optional[bool] = False, # Manual default value
seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None, # Same as `stop`
stop: Optional[List[str]] = None,
stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_n_tokens: Optional[int] = None,
Expand Down Expand Up @@ -1713,7 +1715,8 @@ def text_generation( # type: ignore
repetition_penalty: Optional[float] = None,
return_full_text: Optional[bool] = False, # Manual default value
seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None, # Same as `stop`
stop: Optional[List[str]] = None,
stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_n_tokens: Optional[int] = None,
Expand Down Expand Up @@ -1742,7 +1745,8 @@ def text_generation( # type: ignore
repetition_penalty: Optional[float] = None,
return_full_text: Optional[bool] = False, # Manual default value
seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None, # Same as `stop`
stop: Optional[List[str]] = None,
stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_n_tokens: Optional[int] = None,
Expand Down Expand Up @@ -1771,7 +1775,8 @@ def text_generation(
repetition_penalty: Optional[float] = None,
return_full_text: Optional[bool] = False, # Manual default value
seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None, # Same as `stop`
stop: Optional[List[str]] = None,
stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_n_tokens: Optional[int] = None,
Expand Down Expand Up @@ -1799,7 +1804,8 @@ def text_generation(
repetition_penalty: Optional[float] = None,
return_full_text: Optional[bool] = False, # Manual default value
seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None, # Same as `stop`
stop: Optional[List[str]] = None,
stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_n_tokens: Optional[int] = None,
Expand Down Expand Up @@ -1864,8 +1870,10 @@ def text_generation(
Whether to prepend the prompt to the generated text
seed (`int`, *optional*):
Random sampling seed
stop (`List[str]`, *optional*):
Stop generating tokens if a member of `stop` is generated.
stop_sequences (`List[str]`, *optional*):
Stop generating tokens if a member of `stop_sequences` is generated.
Deprecated argument. Use `stop` instead.
temperature (`float`, *optional*):
The value used to modulate the logits distribution.
top_n_tokens (`int`, *optional*):
Expand Down Expand Up @@ -2009,6 +2017,15 @@ def text_generation(
)
decoder_input_details = False

if stop_sequences is not None:
warnings.warn(
"`stop_sequences` is a deprecated argument for `text_generation` task"
" and will be removed in version '0.28.0'. Use `stop` instead.",
FutureWarning,
)
if stop is None:
stop = stop_sequences # use deprecated arg if provided

# Build payload
parameters = {
"adapter_id": adapter_id,
Expand All @@ -2022,7 +2039,7 @@ def text_generation(
"repetition_penalty": repetition_penalty,
"return_full_text": return_full_text,
"seed": seed,
"stop": stop_sequences if stop_sequences is not None else [],
"stop": stop if stop is not None else [],
"temperature": temperature,
"top_k": top_k,
"top_n_tokens": top_n_tokens,
Expand Down Expand Up @@ -2092,7 +2109,7 @@ def text_generation(
repetition_penalty=repetition_penalty,
return_full_text=return_full_text,
seed=seed,
stop_sequences=stop_sequences,
stop=stop,
temperature=temperature,
top_k=top_k,
top_n_tokens=top_n_tokens,
Expand Down
35 changes: 26 additions & 9 deletions src/huggingface_hub/inference/_generated/_async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1688,7 +1688,8 @@ async def text_generation( # type: ignore
repetition_penalty: Optional[float] = None,
return_full_text: Optional[bool] = False, # Manual default value
seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None, # Same as `stop`
stop: Optional[List[str]] = None,
stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_n_tokens: Optional[int] = None,
Expand Down Expand Up @@ -1717,7 +1718,8 @@ async def text_generation( # type: ignore
repetition_penalty: Optional[float] = None,
return_full_text: Optional[bool] = False, # Manual default value
seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None, # Same as `stop`
stop: Optional[List[str]] = None,
stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_n_tokens: Optional[int] = None,
Expand Down Expand Up @@ -1746,7 +1748,8 @@ async def text_generation( # type: ignore
repetition_penalty: Optional[float] = None,
return_full_text: Optional[bool] = False, # Manual default value
seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None, # Same as `stop`
stop: Optional[List[str]] = None,
stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_n_tokens: Optional[int] = None,
Expand Down Expand Up @@ -1775,7 +1778,8 @@ async def text_generation( # type: ignore
repetition_penalty: Optional[float] = None,
return_full_text: Optional[bool] = False, # Manual default value
seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None, # Same as `stop`
stop: Optional[List[str]] = None,
stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_n_tokens: Optional[int] = None,
Expand Down Expand Up @@ -1804,7 +1808,8 @@ async def text_generation(
repetition_penalty: Optional[float] = None,
return_full_text: Optional[bool] = False, # Manual default value
seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None, # Same as `stop`
stop: Optional[List[str]] = None,
stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_n_tokens: Optional[int] = None,
Expand Down Expand Up @@ -1832,7 +1837,8 @@ async def text_generation(
repetition_penalty: Optional[float] = None,
return_full_text: Optional[bool] = False, # Manual default value
seed: Optional[int] = None,
stop_sequences: Optional[List[str]] = None, # Same as `stop`
stop: Optional[List[str]] = None,
stop_sequences: Optional[List[str]] = None, # Deprecated, use `stop` instead
temperature: Optional[float] = None,
top_k: Optional[int] = None,
top_n_tokens: Optional[int] = None,
Expand Down Expand Up @@ -1897,8 +1903,10 @@ async def text_generation(
Whether to prepend the prompt to the generated text
seed (`int`, *optional*):
Random sampling seed
stop (`List[str]`, *optional*):
Stop generating tokens if a member of `stop` is generated.
stop_sequences (`List[str]`, *optional*):
Stop generating tokens if a member of `stop_sequences` is generated.
Deprecated argument. Use `stop` instead.
temperature (`float`, *optional*):
The value used to modulate the logits distribution.
top_n_tokens (`int`, *optional*):
Expand Down Expand Up @@ -2043,6 +2051,15 @@ async def text_generation(
)
decoder_input_details = False

if stop_sequences is not None:
warnings.warn(
"`stop_sequences` is a deprecated argument for `text_generation` task"
" and will be removed in version '0.28.0'. Use `stop` instead.",
FutureWarning,
)
if stop is None:
stop = stop_sequences # use deprecated arg if provided

# Build payload
parameters = {
"adapter_id": adapter_id,
Expand All @@ -2056,7 +2073,7 @@ async def text_generation(
"repetition_penalty": repetition_penalty,
"return_full_text": return_full_text,
"seed": seed,
"stop": stop_sequences if stop_sequences is not None else [],
"stop": stop if stop is not None else [],
"temperature": temperature,
"top_k": top_k,
"top_n_tokens": top_n_tokens,
Expand Down Expand Up @@ -2126,7 +2143,7 @@ async def text_generation(
repetition_penalty=repetition_penalty,
return_full_text=return_full_text,
seed=seed,
stop_sequences=stop_sequences,
stop=stop,
temperature=temperature,
top_k=top_k,
top_n_tokens=top_n_tokens,
Expand Down

0 comments on commit 3be2afb

Please sign in to comment.