[hot-fix] Handle [DONE] signal from TGI + remove logic for "non-TGI servers" #2410

Merged: 8 commits, Jul 23, 2024
src/huggingface_hub/inference/_client.py: 165 changes (45 additions, 120 deletions)
@@ -66,11 +66,9 @@
_fetch_recommended_models,
_get_unsupported_text_generation_kwargs,
_import_numpy,
_is_chat_completion_server,
_open_as_binary,
_set_as_non_chat_completion_server,
_set_unsupported_text_generation_kwargs,
_stream_chat_completion_response_from_bytes,
_stream_chat_completion_response,
_stream_text_generation_response,
raise_text_generation_error,
)
@@ -82,8 +80,6 @@
ChatCompletionInputTool,
ChatCompletionInputToolTypeClass,
ChatCompletionOutput,
ChatCompletionOutputComplete,
ChatCompletionOutputMessage,
ChatCompletionStreamOutput,
DocumentQuestionAnsweringOutputElement,
FillMaskOutputElement,
@@ -189,7 +185,7 @@ def __init__(
)

self.model: Optional[str] = model
self.token: Union[str, bool, None] = token or api_key
self.token: Union[str, bool, None] = token if token is not None else api_key
self.headers = CaseInsensitiveDict(build_hf_headers(token=self.token)) # 'authorization' + 'user-agent'
if headers is not None:
self.headers.update(headers)
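
Side note on the hunk above, for readers skimming the diff: switching from `token or api_key` to an explicit `None` check matters because `token` is typed `Union[str, bool, None]`, so a caller passing `token=False` (presumably to opt out of sending credentials) would otherwise be silently overridden by `api_key`. A minimal sketch of the difference, with placeholder values:

from typing import Union

token: Union[str, bool, None] = False  # caller explicitly disables authentication
api_key = "hf_xxx"                     # hypothetical placeholder key

resolved_old = token or api_key                         # -> "hf_xxx": the False is discarded
resolved_new = token if token is not None else api_key  # -> False: the opt-out is preserved
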
@@ -818,123 +814,52 @@ def chat_completion(
# since `chat_completion(..., model=xxx)` is also a payload parameter for the
# server, we need to handle it differently
model = self.base_url or self.model or model or self.get_recommended_model("text-generation")
is_url = model.startswith(("http://", "https://"))

# First, resolve the model chat completions URL
if model == self.base_url:
# base_url passed => add server route
model_url = model + "/v1/chat/completions"
elif is_url:
# model is a URL => use it directly
model_url = model
else:
# model is a model ID => resolve it + add server route
model_url = self._resolve_url(model) + "/v1/chat/completions"

# `model` is sent in the payload. Not used by the server but can be useful for debugging/routing.
# If it's a ID on the Hub => use it. Otherwise, we use a random string.
model_id = model if not is_url and model.count("/") == 1 else "tgi"

data = self.post(
model=model_url,
json=dict(
model=model_id,
messages=messages,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
temperature=temperature,
tool_choice=tool_choice,
tool_prompt=tool_prompt,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
stream=stream,
),
stream=stream,
)

if _is_chat_completion_server(model):
# First, let's consider the server has a `/v1/chat/completions` endpoint.
# If that's the case, we don't have to render the chat template client-side.
model_url = self._resolve_url(model)
if not model_url.endswith("/chat/completions"):
model_url += "/v1/chat/completions"

# `model` is sent in the payload. Not used by the server but can be useful for debugging/routing.
if not model.startswith("http") and model.count("/") == 1:
# If it's a ID on the Hub => use it
model_id = model
else:
# Otherwise, we use a random string
model_id = "tgi"

try:
data = self.post(
model=model_url,
json=dict(
model=model_id,
messages=messages,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
temperature=temperature,
tool_choice=tool_choice,
tool_prompt=tool_prompt,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
stream=stream,
),
stream=stream,
)
except HTTPError as e:
if e.response.status_code in (400, 404, 500):
# Let's consider the server is not a chat completion server.
# Then we call again `chat_completion` which will render the chat template client side.
# (can be HTTP 500, HTTP 400, HTTP 404 depending on the server)
_set_as_non_chat_completion_server(model)
logger.warning(
f"Server {model_url} does not seem to support chat completion. Falling back to text generation. Error: {e}"
)
return self.chat_completion(
messages=messages,
model=model,
stream=stream,
max_tokens=max_tokens,
seed=seed,
stop=stop,
temperature=temperature,
top_p=top_p,
)
raise

if stream:
return _stream_chat_completion_response_from_bytes(data) # type: ignore[arg-type]

return ChatCompletionOutput.parse_obj_as_instance(data) # type: ignore[arg-type]

# At this point, we know the server is not a chat completion server.
# It means it's a transformers-backed server for which we can send a list of messages directly to the
# `text-generation` pipeline. We won't receive a detailed response but only the generated text.
if stream:
raise ValueError(
"Streaming token is not supported by the model. This is due to the model not been served by a "
"Text-Generation-Inference server. Please pass `stream=False` as input."
)
if tool_choice is not None or tool_prompt is not None or tools is not None:
warnings.warn(
"Tools are not supported by the model. This is due to the model not been served by a "
"Text-Generation-Inference server. The provided tool parameters will be ignored."
)
if response_format is not None:
warnings.warn(
"Response format is not supported by the model. This is due to the model not been served by a "
"Text-Generation-Inference server. The provided response format will be ignored."
)

# generate response
text_generation_output = self.text_generation(
prompt=messages, # type: ignore # Not correct type but works implicitly
model=model,
stream=False,
details=False,
max_new_tokens=max_tokens,
seed=seed,
stop_sequences=stop,
temperature=temperature,
top_p=top_p,
)
return _stream_chat_completion_response(data) # type: ignore[arg-type]

# Format as a ChatCompletionOutput with dummy values for fields we can't provide
return ChatCompletionOutput(
id="dummy",
model="dummy",
system_fingerprint="dummy",
usage=None, # type: ignore # set to `None` as we don't want to provide false information
created=int(time.time()),
choices=[
ChatCompletionOutputComplete(
finish_reason="unk", # type: ignore # set to `unk` as we don't want to provide false information
index=0,
message=ChatCompletionOutputMessage(
content=text_generation_output,
role="assistant",
),
)
],
)
return ChatCompletionOutput.parse_obj_as_instance(data) # type: ignore[arg-type]

def conversational(
self,
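
For reference, and not part of the diff itself: a standalone sketch of the URL and payload resolution that the new `chat_completion` path performs, with `resolve_url` standing in for the client's internal resolver. A raw model URL is used as-is, a client base URL gets the `/v1/chat/completions` route appended, a Hub model ID is resolved and then gets the route appended, and the payload's `model` field falls back to the placeholder "tgi" when no Hub ID is available:

from typing import Callable, Optional, Tuple

def resolve_chat_completion_target(
    model: str,
    base_url: Optional[str],
    resolve_url: Callable[[str], str],
) -> Tuple[str, str]:
    """Return (model_url, model_id) the way the new client code does (sketch only)."""
    is_url = model.startswith(("http://", "https://"))

    if base_url is not None and model == base_url:
        model_url = model + "/v1/chat/completions"  # base_url passed => append server route
    elif is_url:
        model_url = model  # full URL passed => use it as-is
    else:
        model_url = resolve_url(model) + "/v1/chat/completions"  # Hub ID => resolve, then add route

    # The `model` field in the payload is informational only: keep the Hub ID if we have one.
    model_id = model if not is_url and model.count("/") == 1 else "tgi"
    return model_url, model_id

For example, a Hub ID such as "HuggingFaceH4/zephyr-7b-beta" is resolved and kept as the payload's model ID, while a local TGI URL is forwarded untouched and tagged as "tgi".
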
src/huggingface_hub/inference/_common.py: 86 changes (25 additions, 61 deletions)
@@ -34,7 +34,6 @@
Literal,
NoReturn,
Optional,
Set,
Union,
overload,
)
@@ -61,8 +60,6 @@
)
from ._generated.types import (
ChatCompletionStreamOutput,
ChatCompletionStreamOutputChoice,
ChatCompletionStreamOutputDelta,
TextGenerationStreamOutput,
)

@@ -271,7 +268,10 @@ def _stream_text_generation_response(
"""Used in `InferenceClient.text_generation`."""
# Parse ServerSentEvents
for byte_payload in bytes_output_as_lines:
output = _format_text_generation_stream_output(byte_payload, details)
try:
output = _format_text_generation_stream_output(byte_payload, details)
except StopIteration:
break
if output is not None:
yield output

@@ -282,7 +282,10 @@ async def _async_stream_text_generation_response(
"""Used in `AsyncInferenceClient.text_generation`."""
# Parse ServerSentEvents
async for byte_payload in bytes_output_as_lines:
output = _format_text_generation_stream_output(byte_payload, details)
try:
output = _format_text_generation_stream_output(byte_payload, details)
except StopIteration:
break
if output is not None:
yield output

@@ -293,6 +296,9 @@ def _format_text_generation_stream_output(
if not byte_payload.startswith(b"data:"):
return None # empty line

if byte_payload == b"data: [DONE]":
raise StopIteration("[DONE] signal received.")

# Decode payload
payload = byte_payload.decode("utf-8")
json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
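
Illustration only, not part of the diff: a self-contained sketch of the streaming pattern the hunks above introduce, assuming the same server-sent-event framing TGI uses (one `data: ...` line per event, terminated by a literal `data: [DONE]` line). The parsing helper raises `StopIteration` on the sentinel and the streaming loop catches it explicitly, so the generator ends cleanly instead of failing to JSON-decode `[DONE]`. The names below are hypothetical, not the library's:

import json
from typing import Iterable, Iterator, Optional

def parse_sse_line(byte_payload: bytes) -> Optional[dict]:
    """Parse one SSE line; return None for empty/keep-alive lines."""
    if not byte_payload.startswith(b"data:"):
        return None  # empty line or comment => nothing to yield
    if byte_payload.strip() == b"data: [DONE]":
        # Sentinel emitted by the server once the stream is complete.
        raise StopIteration("[DONE] signal received.")
    return json.loads(byte_payload.decode("utf-8").removeprefix("data:"))

def stream_events(lines: Iterable[bytes]) -> Iterator[dict]:
    for line in lines:
        try:
            event = parse_sse_line(line)
        except StopIteration:
            # Caught here on purpose: a StopIteration escaping a generator
            # would be turned into a RuntimeError (PEP 479).
            break
        if event is not None:
            yield event

With this shape, `list(stream_events([b'data: {"x": 1}', b"data: [DONE]"]))` returns `[{"x": 1}]` instead of raising a JSON decoding error on the sentinel.
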
@@ -306,72 +312,41 @@
return output.token.text if not details else output


def _format_chat_completion_stream_output_from_text_generation(
item: TextGenerationStreamOutput, created: int
) -> ChatCompletionStreamOutput:
if item.details is None:
# new token generated => return delta
return ChatCompletionStreamOutput(
# explicitly set 'dummy' values to reduce expectations from users
id="dummy",
model="dummy",
system_fingerprint="dummy",
choices=[
ChatCompletionStreamOutputChoice(
delta=ChatCompletionStreamOutputDelta(
role="assistant",
content=item.token.text,
),
finish_reason=None,
index=0,
)
],
created=created,
)
else:
# generation is completed => return finish reason
return ChatCompletionStreamOutput(
# explicitly set 'dummy' values to reduce expectations from users
id="dummy",
model="dummy",
system_fingerprint="dummy",
choices=[
ChatCompletionStreamOutputChoice(
delta=ChatCompletionStreamOutputDelta(role="assistant"),
finish_reason=item.details.finish_reason,
index=0,
)
],
created=created,
)


def _stream_chat_completion_response_from_bytes(
def _stream_chat_completion_response(
bytes_lines: Iterable[bytes],
) -> Iterable[ChatCompletionStreamOutput]:
"""Used in `InferenceClient.chat_completion` if model is served with TGI."""
for item in bytes_lines:
output = _format_chat_completion_stream_output_from_text_generation_from_bytes(item)
try:
output = _format_chat_completion_stream_output(item)
except StopIteration:
break
if output is not None:
yield output


async def _async_stream_chat_completion_response_from_bytes(
async def _async_stream_chat_completion_response(
bytes_lines: AsyncIterable[bytes],
) -> AsyncIterable[ChatCompletionStreamOutput]:
"""Used in `AsyncInferenceClient.chat_completion`."""
async for item in bytes_lines:
output = _format_chat_completion_stream_output_from_text_generation_from_bytes(item)
try:
output = _format_chat_completion_stream_output(item)
except StopIteration:
break
if output is not None:
yield output


def _format_chat_completion_stream_output_from_text_generation_from_bytes(
def _format_chat_completion_stream_output(
byte_payload: bytes,
) -> Optional[ChatCompletionStreamOutput]:
if not byte_payload.startswith(b"data:"):
return None # empty line

if byte_payload == b"data: [DONE]":
raise StopIteration("[DONE] signal received.")

# Decode payload
payload = byte_payload.decode("utf-8")
json_payload = json.loads(payload.lstrip("data:").rstrip("/n"))
@@ -413,17 +388,6 @@ def _get_unsupported_text_generation_kwargs(model: Optional[str]) -> List[str]:
return _UNSUPPORTED_TEXT_GENERATION_KWARGS.get(model, [])


_NON_CHAT_COMPLETION_SERVER: Set[str] = set()


def _set_as_non_chat_completion_server(model: str) -> None:
_NON_CHAT_COMPLETION_SERVER.add(model)


def _is_chat_completion_server(model: str) -> bool:
return model not in _NON_CHAT_COMPLETION_SERVER


# TEXT GENERATION ERRORS
# ----------------------
# Text-generation errors are parsed separately to handle as much as possible the errors returned by the text generation
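
Finally, a hedged end-to-end sketch (not taken from the PR) of the behaviour this change enables: streaming a chat completion from a TGI server now stops cleanly when the `[DONE]` sentinel arrives, with no client-side fallback to rendering the chat template. The local endpoint URL is an assumption for the example:

from huggingface_hub import InferenceClient

client = InferenceClient(base_url="http://localhost:8080")  # assumed local TGI endpoint

stream = client.chat_completion(
    messages=[{"role": "user", "content": "Say hello in one word."}],
    max_tokens=16,
    stream=True,
)
for chunk in stream:
    # Each chunk is a ChatCompletionStreamOutput; iteration ends when [DONE] is received.
    print(chunk.choices[0].delta.content or "", end="")
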