Fix resolve chat completion URL #2540

Merged · 2 commits · Sep 16, 2024
4 changes: 2 additions & 2 deletions src/huggingface_hub/hub_mixin.py
@@ -829,7 +829,7 @@ def _load_as_pickle(cls, model: T, model_file: str, map_location: str, strict: b

    @classmethod
    def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, strict: bool) -> T:
-       if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"):
+       if packaging.version.parse(safetensors.__version__) < packaging.version.parse("0.4.3"):  # type: ignore [attr-defined]
            load_model_as_safetensor(model, model_file, strict=strict)  # type: ignore [arg-type]
            if map_location != "cpu":
                logger.warning(
@@ -840,7 +840,7 @@ def _load_as_safetensor(cls, model: T, model_file: str, map_location: str, stric
                )
                model.to(map_location)  # type: ignore [attr-defined]
        else:
-           safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)
+           safetensors.torch.load_model(model, model_file, strict=strict, device=map_location)  # type: ignore [arg-type]
        return model


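For context on the gate above: the fallback branch (CPU load followed by `model.to(map_location)`) implies that `safetensors.torch.load_model` only accepts a `device=` argument from safetensors 0.4.3 onward. A minimal sketch of the same version check, assuming only the `packaging` dependency; `supports_device_kwarg` is a hypothetical helper for illustration:

import packaging.version

def supports_device_kwarg(installed_version: str) -> bool:
    # Mirrors the gate in _load_as_safetensor: 0.4.3 is the cutoff used above.
    return packaging.version.parse(installed_version) >= packaging.version.parse("0.4.3")

print(supports_device_kwarg("0.4.2"))  # False -> load on CPU, then model.to(map_location)
print(supports_device_kwarg("0.4.3"))  # True  -> load_model(..., device=map_location)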
46 changes: 26 additions & 20 deletions src/huggingface_hub/inference/_client.py
@@ -810,26 +810,7 @@ def chat_completion(
        '{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}'
        ```
        """
-       # Determine model
-       # `self.xxx` takes precedence over the method argument only in `chat_completion`
-       # since `chat_completion(..., model=xxx)` is also a payload parameter for the
-       # server, we need to handle it differently
-       model_id_or_url = self.base_url or self.model or model or self.get_recommended_model("text-generation")
-       is_url = model_id_or_url.startswith(("http://", "https://"))
-
-       # First, resolve the model chat completions URL
-       if model_id_or_url == self.base_url:
-           # base_url passed => add server route
-           model_url = model_id_or_url.rstrip("/")
-           if not model_url.endswith("/v1"):
-               model_url += "/v1"
-           model_url += "/chat/completions"
-       elif is_url:
-           # model is a URL => use it directly
-           model_url = model_id_or_url
-       else:
-           # model is a model ID => resolve it + add server route
-           model_url = self._resolve_url(model_id_or_url).rstrip("/") + "/v1/chat/completions"
+       model_url = self._resolve_chat_completion_url(model)

        # `model` is sent in the payload. Not used by the server but can be useful for debugging/routing.
        # If it's an ID on the Hub => use it. Otherwise, we use a random string.
@@ -865,6 +846,31 @@

        return ChatCompletionOutput.parse_obj_as_instance(data)  # type: ignore[arg-type]

+   def _resolve_chat_completion_url(self, model: Optional[str] = None) -> str:
+       # Since `chat_completion(..., model=xxx)` is also a payload parameter for the server, we need to handle 'model' differently.
+       # `self.base_url` and `self.model` take precedence over the 'model' argument only in `chat_completion`.
+       model_id_or_url = self.base_url or self.model or model or self.get_recommended_model("text-generation")
+
+       # Resolve URL if it's a model ID
+       model_url = (
+           model_id_or_url
+           if model_id_or_url.startswith(("http://", "https://"))
+           else self._resolve_url(model_id_or_url, task="text-generation")
+       )
+
+       # Strip trailing /
+       model_url = model_url.rstrip("/")
+
+       # Append /chat/completions if the URL ends with /v1 (OpenAI-style base URL)
+       if model_url.endswith("/v1"):
+           model_url += "/chat/completions"
+
+       # Append /v1/chat/completions if no chat-completions route is present yet
+       if not model_url.endswith("/chat/completions"):
+           model_url += "/v1/chat/completions"
+
+       return model_url
+
    def document_question_answering(
        self,
        image: ContentT,
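The new helper collapses every supported input shape (bare host, OpenAI-style `/v1` base URL, trailing slashes, or an already-complete route) onto the same endpoint. A standalone sketch of those rules; `normalize_chat_url` is a hypothetical name for illustration, the real logic is the method above:

def normalize_chat_url(url: str) -> str:
    url = url.rstrip("/")                      # drop trailing '/'
    if url.endswith("/v1"):                    # OpenAI-style base URL
        url += "/chat/completions"
    if not url.endswith("/chat/completions"):  # bare host or resolved model URL
        url += "/v1/chat/completions"
    return url

for url in (
    "http://0.0.0.0:8080",
    "http://0.0.0.0:8080/",
    "http://0.0.0.0:8080/v1",
    "http://0.0.0.0:8080/v1/",
    "http://0.0.0.0:8080/v1/chat/completions",
):
    assert normalize_chat_url(url) == "http://0.0.0.0:8080/v1/chat/completions"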
46 changes: 26 additions & 20 deletions src/huggingface_hub/inference/_generated/_async_client.py
@@ -850,26 +850,7 @@ async def chat_completion(
        '{\n\n"activity": "bike ride",\n"animals": ["puppy", "cat", "raccoon"],\n"animals_seen": 3,\n"location": "park"}'
        ```
        """
-       # Determine model
-       # `self.xxx` takes precedence over the method argument only in `chat_completion`
-       # since `chat_completion(..., model=xxx)` is also a payload parameter for the
-       # server, we need to handle it differently
-       model_id_or_url = self.base_url or self.model or model or self.get_recommended_model("text-generation")
-       is_url = model_id_or_url.startswith(("http://", "https://"))
-
-       # First, resolve the model chat completions URL
-       if model_id_or_url == self.base_url:
-           # base_url passed => add server route
-           model_url = model_id_or_url.rstrip("/")
-           if not model_url.endswith("/v1"):
-               model_url += "/v1"
-           model_url += "/chat/completions"
-       elif is_url:
-           # model is a URL => use it directly
-           model_url = model_id_or_url
-       else:
-           # model is a model ID => resolve it + add server route
-           model_url = self._resolve_url(model_id_or_url).rstrip("/") + "/v1/chat/completions"
+       model_url = self._resolve_chat_completion_url(model)

        # `model` is sent in the payload. Not used by the server but can be useful for debugging/routing.
        # If it's an ID on the Hub => use it. Otherwise, we use a random string.
@@ -905,6 +886,31 @@

        return ChatCompletionOutput.parse_obj_as_instance(data)  # type: ignore[arg-type]

+   def _resolve_chat_completion_url(self, model: Optional[str] = None) -> str:
+       # Since `chat_completion(..., model=xxx)` is also a payload parameter for the server, we need to handle 'model' differently.
+       # `self.base_url` and `self.model` take precedence over the 'model' argument only in `chat_completion`.
+       model_id_or_url = self.base_url or self.model or model or self.get_recommended_model("text-generation")
+
+       # Resolve URL if it's a model ID
+       model_url = (
+           model_id_or_url
+           if model_id_or_url.startswith(("http://", "https://"))
+           else self._resolve_url(model_id_or_url, task="text-generation")
+       )
+
+       # Strip trailing /
+       model_url = model_url.rstrip("/")
+
+       # Append /chat/completions if the URL ends with /v1 (OpenAI-style base URL)
+       if model_url.endswith("/v1"):
+           model_url += "/chat/completions"
+
+       # Append /v1/chat/completions if no chat-completions route is present yet
+       if not model_url.endswith("/chat/completions"):
+           model_url += "/v1/chat/completions"
+
+       return model_url
+
    async def document_question_answering(
        self,
        image: ContentT,
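The async client mirrors the sync implementation, so the precedence rule can be checked with either. A quick sketch that calls the private helper directly (no HTTP request is sent; `_resolve_chat_completion_url` is internal API and may change):

from huggingface_hub import InferenceClient

# A client-level base_url wins over the per-call `model` argument...
client = InferenceClient(base_url="http://0.0.0.0:8080/v1")
print(client._resolve_chat_completion_url("username/repo_name"))
# -> http://0.0.0.0:8080/v1/chat/completions

# ...while without a client-level model/base_url, the per-call argument is used.
client = InferenceClient()
print(client._resolve_chat_completion_url("http://0.0.0.0:8080"))
# -> http://0.0.0.0:8080/v1/chat/completions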
2 changes: 1 addition & 1 deletion src/huggingface_hub/utils/_http.py
@@ -473,7 +473,7 @@ def hf_raise_for_status(response: Response, endpoint_name: Optional[str] = None)

        # Convert `HTTPError` into a `HfHubHTTPError` to display request information
        # as well (request id and/or server error message)
-       raise _format(HfHubHTTPError, "", response) from e
+       raise _format(HfHubHTTPError, str(e), response) from e


def _format(error_type: Type[HfHubHTTPError], custom_message: str, response: Response) -> HfHubHTTPError:
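For context on the one-line change above: `_format` builds the `HfHubHTTPError` message from `custom_message` plus whatever request/server details it can extract, so passing `str(e)` instead of `""` surfaces the original `requests` error text in the raised exception. A sketch of what `str(e)` contains, using only `requests` (the `Response` is hand-built for illustration):

import requests

response = requests.Response()
response.status_code = 404
response.reason = "Not Found"
response.url = "https://huggingface.co/api/models/does-not-exist"

try:
    response.raise_for_status()
except requests.HTTPError as e:
    print(str(e))
    # 404 Client Error: Not Found for url: https://huggingface.co/api/models/does-not-exist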
127 changes: 106 additions & 21 deletions tests/test_inference_client.py
@@ -16,6 +16,7 @@
import time
import unittest
from pathlib import Path
+from typing import Optional
from unittest.mock import MagicMock, patch

import numpy as np
@@ -918,27 +919,6 @@ def test_model_and_base_url_mutually_exclusive(self):
            InferenceClient(model="meta-llama/Meta-Llama-3-8B-Instruct", base_url="http://127.0.0.1:8000")


-@pytest.mark.parametrize(
-    "base_url",
-    [
-        "http://0.0.0.0:8080/v1",  # expected from OpenAI client
-        "http://0.0.0.0:8080",  # but not mandatory
-        "http://0.0.0.0:8080/v1/",  # ok with trailing '/' as well
-        "http://0.0.0.0:8080/",  # ok with trailing '/' as well
-    ],
-)
-def test_chat_completion_base_url_works_with_v1(base_url: str):
-    """Test that `/v1/chat/completions` is correctly appended to the base URL.
-
-    This is a regression test for https://github.com/huggingface/huggingface_hub/issues/2414
-    """
-    with patch("huggingface_hub.inference._client.InferenceClient.post") as post_mock:
-        client = InferenceClient(base_url=base_url)
-        post_mock.return_value = "{}"
-        client.chat_completion(messages=CHAT_COMPLETION_MESSAGES, stream=False)
-        assert post_mock.call_args_list[0].kwargs["model"] == "http://0.0.0.0:8080/v1/chat/completions"
-

@pytest.mark.parametrize("stop_signal", [b"data: [DONE]", b"data: [DONE]\n", b"data: [DONE] "])
def test_stream_text_generation_response(stop_signal: bytes):
    data = [
@@ -970,3 +950,108 @@ def test_stream_chat_completion_response(stop_signal: bytes):
    assert len(output) == 2
    assert output[0].choices[0].delta.content == "Both"
    assert output[1].choices[0].delta.content == " Rust"


+INFERENCE_API_URL = "https://api-inference.huggingface.co/models"
+INFERENCE_ENDPOINT_URL = "https://rur2d6yoccusjxgn.us-east-1.aws.endpoints.huggingface.cloud"  # example
+LOCAL_TGI_URL = "http://0.0.0.0:8080"
+
+
+@pytest.mark.parametrize(
+    ("client_model", "client_base_url", "model", "expected_url"),
+    [
+        (
+            # Inference API - model global to client
+            "username/repo_name",
+            None,
+            None,
+            f"{INFERENCE_API_URL}/username/repo_name/v1/chat/completions",
+        ),
+        (
+            # Inference API - model specific to request
+            None,
+            None,
+            "username/repo_name",
+            f"{INFERENCE_API_URL}/username/repo_name/v1/chat/completions",
+        ),
+        (
+            # Inference Endpoint - url global to client as 'model'
+            INFERENCE_ENDPOINT_URL,
+            None,
+            None,
+            f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions",
+        ),
+        (
+            # Inference Endpoint - url global to client as 'base_url'
+            None,
+            INFERENCE_ENDPOINT_URL,
+            None,
+            f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions",
+        ),
+        (
+            # Inference Endpoint - url specific to request
+            None,
+            None,
+            INFERENCE_ENDPOINT_URL,
+            f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions",
+        ),
+        (
+            # Inference Endpoint - url global to client as 'base_url' - explicit model id
+            None,
+            INFERENCE_ENDPOINT_URL,
+            "username/repo_name",
+            f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions",
+        ),
+        (
+            # Inference Endpoint - full url global to client as 'model'
+            f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions",
+            None,
+            None,
+            f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions",
+        ),
+        (
+            # Inference Endpoint - full url global to client as 'base_url'
+            None,
+            f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions",
+            None,
+            f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions",
+        ),
+        (
+            # Inference Endpoint - full url specific to request
+            None,
+            None,
+            f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions",
+            f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions",
+        ),
+        (
+            # Inference Endpoint - url with '/v1' (OpenAI compatibility)
+            # Regression test for https://github.com/huggingface/huggingface_hub/pull/2418
+            None,
+            None,
+            f"{INFERENCE_ENDPOINT_URL}/v1",
+            f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions",
+        ),
+        (
+            # Inference Endpoint - url with '/v1/' (OpenAI compatibility)
+            # Regression test for https://github.com/huggingface/huggingface_hub/pull/2418
+            None,
+            None,
+            f"{INFERENCE_ENDPOINT_URL}/v1/",
+            f"{INFERENCE_ENDPOINT_URL}/v1/chat/completions",
+        ),
+        (
+            # Local TGI with trailing '/v1'
+            # Regression test for https://github.com/huggingface/huggingface_hub/issues/2414
+            f"{LOCAL_TGI_URL}/v1",  # expected from OpenAI client
+            None,
+            None,
+            f"{LOCAL_TGI_URL}/v1/chat/completions",
+        ),
+    ],
+)
+def test_resolve_chat_completion_url(
+    client_model: Optional[str], client_base_url: Optional[str], model: Optional[str], expected_url: str
+):
+    client = InferenceClient(model=client_model, base_url=client_base_url)
+    url = client._resolve_chat_completion_url(model)
+    assert url == expected_url
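These parametrized cases need no HTTP mocking: once the client is constructed, URL resolution is pure string handling. They can be run in isolation with `pytest tests/test_inference_client.py -k resolve_chat_completion_url`.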