[hot-fix] Handle [DONE] signal from TGI + remove logic for "non-TGI servers" (#2410)

* Handle [DONE] signal from TGI

* fix text_generation as well

* Handle error 404 correctly

* consistency + stop treating transformers-backed models differently

* fix test

* fix broken test on main

* fix test

* cleaner
Wauplin authored Jul 23, 2024
1 parent 78542a4 commit 91fe78e
Showing 8 changed files with 226 additions and 612 deletions.
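With this fix, streaming chat completion against TGI's OpenAI-compatible `/v1/chat/completions` route ends cleanly when the server emits its `data: [DONE]` sentinel, and models not served by TGI are no longer special-cased client-side. A minimal usage sketch (model name and endpoint are illustrative, not part of this diff):

```python
# Hedged usage sketch: streaming chat completion now stops cleanly when the
# server sends `data: [DONE]`. The model name is illustrative; any TGI-served
# model or a `base_url` pointing at a local TGI instance behaves the same way.
from huggingface_hub import InferenceClient

client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")  # or InferenceClient(base_url="http://localhost:8080")
for chunk in client.chat_completion(
    messages=[{"role": "user", "content": "Say hello"}],
    max_tokens=16,
    stream=True,
):
    # Each chunk is a ChatCompletionStreamOutput; the final chunk may carry only
    # a finish_reason, so guard against a missing delta content.
    print(chunk.choices[0].delta.content or "", end="")
```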
src/huggingface_hub/inference/_client.py (165 changes: 45 additions & 120 deletions)
@@ -66,11 +66,9 @@
_fetch_recommended_models,
_get_unsupported_text_generation_kwargs,
_import_numpy,
_is_chat_completion_server,
_open_as_binary,
_set_as_non_chat_completion_server,
_set_unsupported_text_generation_kwargs,
_stream_chat_completion_response_from_bytes,
_stream_chat_completion_response,
_stream_text_generation_response,
raise_text_generation_error,
)
@@ -82,8 +80,6 @@
ChatCompletionInputTool,
ChatCompletionInputToolTypeClass,
ChatCompletionOutput,
ChatCompletionOutputComplete,
ChatCompletionOutputMessage,
ChatCompletionStreamOutput,
DocumentQuestionAnsweringOutputElement,
FillMaskOutputElement,
@@ -189,7 +185,7 @@ def __init__(
)

self.model: Optional[str] = model
self.token: Union[str, bool, None] = token or api_key
self.token: Union[str, bool, None] = token if token is not None else api_key
self.headers = CaseInsensitiveDict(build_hf_headers(token=self.token)) # 'authorization' + 'user-agent'
if headers is not None:
self.headers.update(headers)
@@ -818,123 +814,52 @@ def chat_completion(
# since `chat_completion(..., model=xxx)` is also a payload parameter for the
# server, we need to handle it differently
model = self.base_url or self.model or model or self.get_recommended_model("text-generation")
is_url = model.startswith(("http://", "https://"))

# First, resolve the model chat completions URL
if model == self.base_url:
# base_url passed => add server route
model_url = model + "/v1/chat/completions"
elif is_url:
# model is a URL => use it directly
model_url = model
else:
# model is a model ID => resolve it + add server route
model_url = self._resolve_url(model) + "/v1/chat/completions"

# `model` is sent in the payload. Not used by the server but can be useful for debugging/routing.
# If it's a ID on the Hub => use it. Otherwise, we use a random string.
model_id = model if not is_url and model.count("/") == 1 else "tgi"

data = self.post(
model=model_url,
json=dict(
model=model_id,
messages=messages,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
temperature=temperature,
tool_choice=tool_choice,
tool_prompt=tool_prompt,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
stream=stream,
),
stream=stream,
)

if _is_chat_completion_server(model):
# First, let's consider the server has a `/v1/chat/completions` endpoint.
# If that's the case, we don't have to render the chat template client-side.
model_url = self._resolve_url(model)
if not model_url.endswith("/chat/completions"):
model_url += "/v1/chat/completions"

# `model` is sent in the payload. Not used by the server but can be useful for debugging/routing.
if not model.startswith("http") and model.count("/") == 1:
# If it's a ID on the Hub => use it
model_id = model
else:
# Otherwise, we use a random string
model_id = "tgi"

try:
data = self.post(
model=model_url,
json=dict(
model=model_id,
messages=messages,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
temperature=temperature,
tool_choice=tool_choice,
tool_prompt=tool_prompt,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
stream=stream,
),
stream=stream,
)
except HTTPError as e:
if e.response.status_code in (400, 404, 500):
# Let's consider the server is not a chat completion server.
# Then we call again `chat_completion` which will render the chat template client side.
# (can be HTTP 500, HTTP 400, HTTP 404 depending on the server)
_set_as_non_chat_completion_server(model)
logger.warning(
f"Server {model_url} does not seem to support chat completion. Falling back to text generation. Error: {e}"
)
return self.chat_completion(
messages=messages,
model=model,
stream=stream,
max_tokens=max_tokens,
seed=seed,
stop=stop,
temperature=temperature,
top_p=top_p,
)
raise

if stream:
return _stream_chat_completion_response_from_bytes(data) # type: ignore[arg-type]

return ChatCompletionOutput.parse_obj_as_instance(data) # type: ignore[arg-type]

# At this point, we know the server is not a chat completion server.
# It means it's a transformers-backed server for which we can send a list of messages directly to the
# `text-generation` pipeline. We won't receive a detailed response but only the generated text.
if stream:
raise ValueError(
"Streaming token is not supported by the model. This is due to the model not been served by a "
"Text-Generation-Inference server. Please pass `stream=False` as input."
)
if tool_choice is not None or tool_prompt is not None or tools is not None:
warnings.warn(
"Tools are not supported by the model. This is due to the model not been served by a "
"Text-Generation-Inference server. The provided tool parameters will be ignored."
)
if response_format is not None:
warnings.warn(
"Response format is not supported by the model. This is due to the model not been served by a "
"Text-Generation-Inference server. The provided response format will be ignored."
)

# generate response
text_generation_output = self.text_generation(
prompt=messages, # type: ignore # Not correct type but works implicitly
model=model,
stream=False,
details=False,
max_new_tokens=max_tokens,
seed=seed,
stop_sequences=stop,
temperature=temperature,
top_p=top_p,
)
return _stream_chat_completion_response(data) # type: ignore[arg-type]

# Format as a ChatCompletionOutput with dummy values for fields we can't provide
return ChatCompletionOutput(
id="dummy",
model="dummy",
system_fingerprint="dummy",
usage=None, # type: ignore # set to `None` as we don't want to provide false information
created=int(time.time()),
choices=[
ChatCompletionOutputComplete(
finish_reason="unk", # type: ignore # set to `unk` as we don't want to provide false information
index=0,
message=ChatCompletionOutputMessage(
content=text_generation_output,
role="assistant",
),
)
],
)
return ChatCompletionOutput.parse_obj_as_instance(data) # type: ignore[arg-type]

def conversational(
self,
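The rewritten `chat_completion` above always targets a `/v1/chat/completions` route: a configured `base_url` gets the route appended, a full URL is used as-is, and a plain model ID is resolved first. The payload `model` field is the Hub ID when one is available, otherwise the placeholder `"tgi"`. The small `__init__` change keeps an explicit falsy `token` (e.g. `token=False`) instead of letting it fall through to `api_key`. A standalone sketch of the resolution rule (`resolve_model_url` stands in for the client's internal `_resolve_url` and is an assumption of this example, as is the Inference API URL pattern):

```python
# Standalone sketch of the URL / model-id resolution rule from the diff above.
from typing import Callable, Optional, Tuple


def chat_completions_target(
    model: str,
    base_url: Optional[str],
    resolve_model_url: Callable[[str], str],
) -> Tuple[str, str]:
    """Return (request_url, payload_model_id) for a chat-completion call."""
    is_url = model.startswith(("http://", "https://"))
    if base_url is not None and model == base_url:
        url = model + "/v1/chat/completions"  # base_url passed => append the route
    elif is_url:
        url = model  # full URL passed => use it directly
    else:
        url = resolve_model_url(model) + "/v1/chat/completions"  # model ID => resolve + route
    # Payload `model` field: the Hub ID if we have one, otherwise a placeholder.
    model_id = model if not is_url and model.count("/") == 1 else "tgi"
    return url, model_id


url, model_id = chat_completions_target(
    "meta-llama/Meta-Llama-3-8B-Instruct",
    base_url=None,
    resolve_model_url=lambda m: f"https://api-inference.huggingface.co/models/{m}",
)
assert model_id == "meta-llama/Meta-Llama-3-8B-Instruct"
assert url.endswith("/v1/chat/completions")
```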
src/huggingface_hub/inference/_common.py (86 changes: 25 additions & 61 deletions)
@@ -34,7 +34,6 @@
Literal,
NoReturn,
Optional,
Set,
Union,
overload,
)
@@ -61,8 +60,6 @@
)
from ._generated.types import (
ChatCompletionStreamOutput,
ChatCompletionStreamOutputChoice,
ChatCompletionStreamOutputDelta,
TextGenerationStreamOutput,
)

@@ -271,7 +268,10 @@ def _stream_text_generation_response(
"""Used in `InferenceClient.text_generation`."""
# Parse ServerSentEvents
for byte_payload in bytes_output_as_lines:
output = _format_text_generation_stream_output(byte_payload, details)
try:
output = _format_text_generation_stream_output(byte_payload, details)
except StopIteration:
break
if output is not None:
yield output

@@ -282,7 +282,10 @@ async def _async_stream_text_generation_response(
"""Used in `AsyncInferenceClient.text_generation`."""
# Parse ServerSentEvents
async for byte_payload in bytes_output_as_lines:
output = _format_text_generation_stream_output(byte_payload, details)
try:
output = _format_text_generation_stream_output(byte_payload, details)
except StopIteration:
break
if output is not None:
yield output

@@ -293,6 +296,9 @@ def _format_text_generation_stream_output(
if not byte_payload.startswith(b"data:"):
return None # empty line

if byte_payload == b"data: [DONE]":
raise StopIteration("[DONE] signal received.")

# Decode payload
payload = byte_payload.decode("utf-8")
json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
@@ -306,72 +312,41 @@
return output.token.text if not details else output


def _format_chat_completion_stream_output_from_text_generation(
item: TextGenerationStreamOutput, created: int
) -> ChatCompletionStreamOutput:
if item.details is None:
# new token generated => return delta
return ChatCompletionStreamOutput(
# explicitly set 'dummy' values to reduce expectations from users
id="dummy",
model="dummy",
system_fingerprint="dummy",
choices=[
ChatCompletionStreamOutputChoice(
delta=ChatCompletionStreamOutputDelta(
role="assistant",
content=item.token.text,
),
finish_reason=None,
index=0,
)
],
created=created,
)
else:
# generation is completed => return finish reason
return ChatCompletionStreamOutput(
# explicitly set 'dummy' values to reduce expectations from users
id="dummy",
model="dummy",
system_fingerprint="dummy",
choices=[
ChatCompletionStreamOutputChoice(
delta=ChatCompletionStreamOutputDelta(role="assistant"),
finish_reason=item.details.finish_reason,
index=0,
)
],
created=created,
)


def _stream_chat_completion_response_from_bytes(
def _stream_chat_completion_response(
bytes_lines: Iterable[bytes],
) -> Iterable[ChatCompletionStreamOutput]:
"""Used in `InferenceClient.chat_completion` if model is served with TGI."""
for item in bytes_lines:
output = _format_chat_completion_stream_output_from_text_generation_from_bytes(item)
try:
output = _format_chat_completion_stream_output(item)
except StopIteration:
break
if output is not None:
yield output


async def _async_stream_chat_completion_response_from_bytes(
async def _async_stream_chat_completion_response(
bytes_lines: AsyncIterable[bytes],
) -> AsyncIterable[ChatCompletionStreamOutput]:
"""Used in `AsyncInferenceClient.chat_completion`."""
async for item in bytes_lines:
output = _format_chat_completion_stream_output_from_text_generation_from_bytes(item)
try:
output = _format_chat_completion_stream_output(item)
except StopIteration:
break
if output is not None:
yield output


def _format_chat_completion_stream_output_from_text_generation_from_bytes(
def _format_chat_completion_stream_output(
byte_payload: bytes,
) -> Optional[ChatCompletionStreamOutput]:
if not byte_payload.startswith(b"data:"):
return None # empty line

if byte_payload == b"data: [DONE]":
raise StopIteration("[DONE] signal received.")

# Decode payload
payload = byte_payload.decode("utf-8")
json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
@@ -413,17 +388,6 @@ def _get_unsupported_text_generation_kwargs(model: Optional[str]) -> List[str]:
return _UNSUPPORTED_TEXT_GENERATION_KWARGS.get(model, [])


_NON_CHAT_COMPLETION_SERVER: Set[str] = set()


def _set_as_non_chat_completion_server(model: str) -> None:
_NON_CHAT_COMPLETION_SERVER.add(model)


def _is_chat_completion_server(model: str) -> bool:
return model not in _NON_CHAT_COMPLETION_SERVER


# TEXT GENERATION ERRORS
# ----------------------
# Text-generation errors are parsed separately to handle as much as possible the errors returned by the text generation
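Both stream parsers in `_common.py` now follow the same pattern: the per-line formatter raises `StopIteration` when it sees `data: [DONE]`, and the surrounding generator catches it and breaks. The explicit `try`/`except` is needed because a `StopIteration` escaping a generator body is converted to a `RuntimeError` (PEP 479). A minimal sketch of that split, independent of the library's generated types (assumes Python 3.9+ for `bytes.removeprefix`):

```python
# Minimal sketch of the formatter/iterator split used above: the per-line
# formatter signals end-of-stream with StopIteration, and the generator catches
# it and breaks instead of letting it escape (PEP 479 would turn that into a
# RuntimeError).
from typing import Iterable, Iterator, Optional


def _format_line(line: bytes) -> Optional[str]:
    if not line.startswith(b"data:"):
        return None  # skip empty / keep-alive lines between events
    if line == b"data: [DONE]":
        raise StopIteration("[DONE] signal received.")
    return line.removeprefix(b"data:").strip().decode("utf-8")


def stream_tokens(lines: Iterable[bytes]) -> Iterator[str]:
    for line in lines:
        try:
            out = _format_line(line)
        except StopIteration:
            break  # end-of-stream sentinel reached
        if out is not None:
            yield out


assert list(stream_tokens([b"data: a", b"", b"data: b", b"data: [DONE]", b"data: never"])) == ["a", "b"]
```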