Raise error in chat completion when unprocessable (#2257)
* Fix chat completion should throw error if not unprocessable

* add docstring

* forgot cassette
Wauplin authored Apr 29, 2024
1 parent 84c0fd2 commit 7e9196c
Showing 5 changed files with 104 additions and 32 deletions.
32 changes: 17 additions & 15 deletions src/huggingface_hub/inference/_client.py
@@ -728,21 +728,23 @@ def chat_completion(
                 ),
                 stream=stream,
             )
-        except HTTPError:
-            # Let's consider the server is not a chat completion server.
-            # Then we call again `chat_completion` which will render the chat template client side.
-            # (can be HTTP 500, HTTP 400, HTTP 404 depending on the server)
-            _set_as_non_chat_completion_server(model)
-            return self.chat_completion(
-                messages=messages,
-                model=model,
-                stream=stream,
-                max_tokens=max_tokens,
-                seed=seed,
-                stop=stop,
-                temperature=temperature,
-                top_p=top_p,
-            )
+        except HTTPError as e:
+            if e.response.status_code in (400, 404, 500):
+                # Let's consider the server is not a chat completion server.
+                # Then we call again `chat_completion` which will render the chat template client side.
+                # (can be HTTP 500, HTTP 400, HTTP 404 depending on the server)
+                _set_as_non_chat_completion_server(model)
+                return self.chat_completion(
+                    messages=messages,
+                    model=model,
+                    stream=stream,
+                    max_tokens=max_tokens,
+                    seed=seed,
+                    stop=stop,
+                    temperature=temperature,
+                    top_p=top_p,
+                )
+            raise

         if stream:
             return _stream_chat_completion_response_from_bytes(data)  # type: ignore[arg-type]
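
With this patch, the client-side chat-template fallback only kicks in for HTTP 400/404/500; any other failure, in particular the HTTP 422 Unprocessable Entity that motivated the change, is re-raised to the caller. Below is a minimal sketch of the resulting behaviour (not code from the diff), assuming an endpoint that rejects malformed input with a 422:

# Sketch only: illustrates the behaviour after this patch.
from huggingface_hub import InferenceClient
from huggingface_hub.utils import HfHubHTTPError

client = InferenceClient(model="meta-llama/Meta-Llama-3-70B-Instruct")

try:
    # `messages` must be a list of {"role": ..., "content": ...} dicts.
    # A bare string makes the server answer HTTP 422, which now propagates
    # to the caller instead of silently falling back to client-side templating.
    client.chat_completion("please output 'Observation'", max_tokens=200)
except HfHubHTTPError as e:
    print(f"Request rejected by the server: {e}")
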
32 changes: 17 additions & 15 deletions src/huggingface_hub/inference/_generated/_async_client.py
@@ -728,21 +728,23 @@ async def chat_completion(
                 ),
                 stream=stream,
             )
-        except _import_aiohttp().ClientResponseError:
-            # Let's consider the server is not a chat completion server.
-            # Then we call again `chat_completion` which will render the chat template client side.
-            # (can be HTTP 500, HTTP 400, HTTP 404 depending on the server)
-            _set_as_non_chat_completion_server(model)
-            return await self.chat_completion(
-                messages=messages,
-                model=model,
-                stream=stream,
-                max_tokens=max_tokens,
-                seed=seed,
-                stop=stop,
-                temperature=temperature,
-                top_p=top_p,
-            )
+        except _import_aiohttp().ClientResponseError as e:
+            if e.status in (400, 404, 500):
+                # Let's consider the server is not a chat completion server.
+                # Then we call again `chat_completion` which will render the chat template client side.
+                # (can be HTTP 500, HTTP 400, HTTP 404 depending on the server)
+                _set_as_non_chat_completion_server(model)
+                return await self.chat_completion(
+                    messages=messages,
+                    model=model,
+                    stream=stream,
+                    max_tokens=max_tokens,
+                    seed=seed,
+                    stop=stop,
+                    temperature=temperature,
+                    top_p=top_p,
+                )
+            raise

         if stream:
             return _async_stream_chat_completion_response_from_bytes(data)  # type: ignore[arg-type]
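
The async client mirrors the same guard, but the exception surface differs: requests' HTTPError carries the status code on an attached Response object, while aiohttp's ClientResponseError exposes it directly, which is why the generated code checks `e.status`. A standalone sketch of the two APIs (the URL below is only an illustrative endpoint expected to return an error status):

# Sketch only: not from the diff. Shows why the sync client reads
# `e.response.status_code` while the async client reads `e.status`.
import asyncio

import aiohttp
import requests


def sync_status(url: str) -> None:
    try:
        requests.get(url, timeout=10).raise_for_status()
    except requests.HTTPError as e:
        # requests attaches the failed Response to the exception.
        print("requests:", e.response.status_code)


async def async_status(url: str) -> None:
    async with aiohttp.ClientSession(raise_for_status=True) as session:
        try:
            async with session.get(url):
                pass
        except aiohttp.ClientResponseError as e:
            # aiohttp stores the status code directly on the exception.
            print("aiohttp:", e.status)


# Any URL known to return an error status works here; this one is illustrative.
URL = "https://huggingface.co/api/models/this-model-does-not-exist"
sync_status(URL)
asyncio.run(async_status(URL))
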
53 changes: 53 additions & 0 deletions (new cassette file)
@@ -0,0 +1,53 @@
interactions:
- request:
    body: '{"model": "tgi", "messages": "please output ''Observation''", "frequency_penalty":
      null, "logit_bias": null, "logprobs": null, "max_tokens": 200, "n": null, "presence_penalty":
      null, "seed": null, "stop": ["Observation", "Final Answer"], "temperature":
      null, "tool_choice": null, "tool_prompt": null, "tools": null, "top_logprobs":
      null, "top_p": null, "stream": false}'
    headers:
      Accept:
      - '*/*'
      Accept-Encoding:
      - gzip, deflate, br
      Connection:
      - keep-alive
      Content-Length:
      - '367'
      Content-Type:
      - application/json
      X-Amzn-Trace-Id:
      - 52f7edcd-57c0-4258-8ff1-e3c520a43ef7
      user-agent:
      - unknown/None; hf_hub/0.23.0.dev0; python/3.10.12; torch/2.2.1; tensorflow/2.15.0.post1;
        fastcore/1.5.23
    method: POST
    uri: https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-70B-Instruct/v1/chat/completions
  response:
    body:
      string: 'Failed to deserialize the JSON body into the target type: messages:
        invalid type: string "please output ''Observation''", expected a sequence
        at line 1 column 58'
    headers:
      Connection:
      - keep-alive
      Content-Type:
      - text/plain; charset=utf-8
      Date:
      - Mon, 29 Apr 2024 16:18:50 GMT
      Transfer-Encoding:
      - chunked
      access-control-allow-credentials:
      - 'true'
      access-control-allow-origin:
      - '*'
      vary:
      - origin, Origin, Access-Control-Request-Method, Access-Control-Request-Headers
      x-request-id:
      - yECOxGeXLet5-Ae5rw808
      x-sha:
      - e8cf5276ae3e97cfde8a058e64a636f2cde47820
    status:
      code: 422
      message: Unprocessable Entity
version: 1
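
The file above is a VCR-style cassette: it records the HTTP 422 exchange so the regression test below can replay it offline. A rough sketch of how such a cassette is typically recorded and replayed with vcrpy (the repository's actual test fixtures may be wired differently; the directory and cassette name here are hypothetical):

# Illustrative only: generic vcrpy usage, not the repository's exact test setup.
import vcr

recorder = vcr.VCR(
    cassette_library_dir="tests/cassettes",  # hypothetical directory
    record_mode="once",  # record real traffic on first run, replay afterwards
)

@recorder.use_cassette("chat_completion_unprocessable_entity.yaml")  # hypothetical name
def call_endpoint() -> None:
    # Any HTTP traffic issued here is matched against the cassette on replay.
    ...
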
13 changes: 13 additions & 0 deletions tests/test_inference_client.py
@@ -326,6 +326,19 @@ def test_chat_completion_with_tool(self) -> None:
             "location": "San Francisco, CA",
         }

+    def test_chat_completion_unprocessable_entity(self) -> None:
+        """Regression test for #2225.
+
+        See https://github.com/huggingface/huggingface_hub/issues/2225.
+        """
+        with self.assertRaises(HfHubHTTPError):
+            self.client.chat_completion(
+                "please output 'Observation'",  # Not a list of messages
+                stop=["Observation", "Final Answer"],
+                max_tokens=200,
+                model="meta-llama/Meta-Llama-3-70B-Instruct",
+            )
+
     @expect_deprecation("InferenceClient.conversational")
     def test_conversational(self) -> None:
         output = self.client.conversational("Hi, who are you?")
6 changes: 4 additions & 2 deletions utils/generate_async_inference_client.py
@@ -335,8 +335,10 @@ def _adapt_text_generation_to_async(code: str) -> str:
 def _adapt_chat_completion_to_async(code: str) -> str:
     # Catch `aiohttp` error instead of `requests` error
     code = code.replace(
-        "except HTTPError:",
-        "except _import_aiohttp().ClientResponseError:",
+        """        except HTTPError as e:
+            if e.response.status_code in (400, 404, 500):""",
+        """        except _import_aiohttp().ClientResponseError as e:
+            if e.status in (400, 404, 500):""",
     )

     # Await text-generation call
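
The generator keeps `_async_client.py` in sync with `_client.py` through plain text substitution, so the sync-side change has to be mirrored here with its aiohttp counterpart. A condensed sketch of the idea (strings shortened for readability; not the script itself):

# Sketch only: condensed illustration of the textual sync-to-async rewrite.
sync_snippet = (
    "        except HTTPError as e:\n"
    "            if e.response.status_code in (400, 404, 500):"
)
async_snippet = (
    "        except _import_aiohttp().ClientResponseError as e:\n"
    "            if e.status in (400, 404, 500):"
)


def adapt_chat_completion_to_async(code: str) -> str:
    # Replace the requests-flavoured error handling with the aiohttp one.
    return code.replace(sync_snippet, async_snippet)


print(adapt_chat_completion_to_async(sync_snippet))
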
