diff --git a/src/huggingface_hub/inference/_client.py b/src/huggingface_hub/inference/_client.py
index 4dfe6e9de8..ef29242007 100644
--- a/src/huggingface_hub/inference/_client.py
+++ b/src/huggingface_hub/inference/_client.py
@@ -814,51 +814,51 @@ def chat_completion(
         # `self.xxx` takes precedence over the method argument only in `chat_completion`
         # since `chat_completion(..., model=xxx)` is also a payload parameter for the
         # server, we need to handle it differently
-        model = self.base_url or self.model or model or self.get_recommended_model("text-generation")
-        is_url = model.startswith(("http://", "https://"))
+        model_id_or_url = self.base_url or self.model or model or self.get_recommended_model("text-generation")
+        is_url = model_id_or_url.startswith(("http://", "https://"))

         # First, resolve the model chat completions URL
-        if model == self.base_url:
+        if model_id_or_url == self.base_url:
             # base_url passed => add server route
-            model_url = model.rstrip("/")
+            model_url = model_id_or_url.rstrip("/")
             if not model_url.endswith("/v1"):
                 model_url += "/v1"
             model_url += "/chat/completions"
         elif is_url:
             # model is a URL => use it directly
-            model_url = model
+            model_url = model_id_or_url
         else:
             # model is a model ID => resolve it + add server route
-            model_url = self._resolve_url(model).rstrip("/") + "/v1/chat/completions"
+            model_url = self._resolve_url(model_id_or_url).rstrip("/") + "/v1/chat/completions"

         # `model` is sent in the payload. Not used by the server but can be useful for debugging/routing.
         # If it's a ID on the Hub => use it. Otherwise, we use a random string.
-        model_id = model if not is_url and model.count("/") == 1 else "tgi"
-
-        data = self.post(
-            model=model_url,
-            json=dict(
-                model=model_id,
-                messages=messages,
-                frequency_penalty=frequency_penalty,
-                logit_bias=logit_bias,
-                logprobs=logprobs,
-                max_tokens=max_tokens,
-                n=n,
-                presence_penalty=presence_penalty,
-                response_format=response_format,
-                seed=seed,
-                stop=stop,
-                temperature=temperature,
-                tool_choice=tool_choice,
-                tool_prompt=tool_prompt,
-                tools=tools,
-                top_logprobs=top_logprobs,
-                top_p=top_p,
-                stream=stream,
-            ),
+        model_id = model or self.model or "tgi"
+        if model_id.startswith(("http://", "https://")):
+            model_id = "tgi"  # dummy value
+
+        payload = dict(
+            model=model_id,
+            messages=messages,
+            frequency_penalty=frequency_penalty,
+            logit_bias=logit_bias,
+            logprobs=logprobs,
+            max_tokens=max_tokens,
+            n=n,
+            presence_penalty=presence_penalty,
+            response_format=response_format,
+            seed=seed,
+            stop=stop,
+            temperature=temperature,
+            tool_choice=tool_choice,
+            tool_prompt=tool_prompt,
+            tools=tools,
+            top_logprobs=top_logprobs,
+            top_p=top_p,
             stream=stream,
         )
+        payload = {key: value for key, value in payload.items() if value is not None}
+        data = self.post(model=model_url, json=payload, stream=stream)

         if stream:
             return _stream_chat_completion_response(data)  # type: ignore[arg-type]
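The same change is mirrored in the generated async client below. To make the routing logic concrete, here is a minimal, self-contained sketch of the three resolution branches; `resolve_hub_url` is a hypothetical stand-in for the client's internal `_resolve_url` helper, and the endpoint format it returns is illustrative only.

```python
from typing import Optional


def resolve_hub_url(model_id: str) -> str:
    # Hypothetical stand-in for `InferenceClient._resolve_url`; the real helper
    # resolves a Hub model ID against the Inference API.
    return f"https://api-inference.huggingface.co/models/{model_id}"


def build_chat_completions_url(model_id_or_url: str, base_url: Optional[str] = None) -> str:
    if model_id_or_url == base_url:
        # base_url passed => append the server route, tolerating "/" or "/v1" suffixes
        url = model_id_or_url.rstrip("/")
        if not url.endswith("/v1"):
            url += "/v1"
        return url + "/chat/completions"
    if model_id_or_url.startswith(("http://", "https://")):
        # full URL passed => use it directly
        return model_id_or_url
    # model ID => resolve it, then append the server route
    return resolve_hub_url(model_id_or_url).rstrip("/") + "/v1/chat/completions"


# All three input styles converge on a /v1/chat/completions endpoint:
assert (
    build_chat_completions_url("http://localhost:8080", base_url="http://localhost:8080")
    == "http://localhost:8080/v1/chat/completions"
)
assert build_chat_completions_url("http://localhost:8080/v1/chat/completions").endswith("/chat/completions")
assert build_chat_completions_url("meta-llama/Meta-Llama-3-8B-Instruct").endswith("/v1/chat/completions")
```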
diff --git a/src/huggingface_hub/inference/_generated/_async_client.py b/src/huggingface_hub/inference/_generated/_async_client.py
index 0301bd42d7..b7d6bc20ab 100644
--- a/src/huggingface_hub/inference/_generated/_async_client.py
+++ b/src/huggingface_hub/inference/_generated/_async_client.py
@@ -824,51 +824,51 @@ async def chat_completion(
         # `self.xxx` takes precedence over the method argument only in `chat_completion`
         # since `chat_completion(..., model=xxx)` is also a payload parameter for the
         # server, we need to handle it differently
-        model = self.base_url or self.model or model or self.get_recommended_model("text-generation")
-        is_url = model.startswith(("http://", "https://"))
+        model_id_or_url = self.base_url or self.model or model or self.get_recommended_model("text-generation")
+        is_url = model_id_or_url.startswith(("http://", "https://"))

         # First, resolve the model chat completions URL
-        if model == self.base_url:
+        if model_id_or_url == self.base_url:
             # base_url passed => add server route
-            model_url = model.rstrip("/")
+            model_url = model_id_or_url.rstrip("/")
             if not model_url.endswith("/v1"):
                 model_url += "/v1"
             model_url += "/chat/completions"
         elif is_url:
             # model is a URL => use it directly
-            model_url = model
+            model_url = model_id_or_url
         else:
             # model is a model ID => resolve it + add server route
-            model_url = self._resolve_url(model).rstrip("/") + "/v1/chat/completions"
+            model_url = self._resolve_url(model_id_or_url).rstrip("/") + "/v1/chat/completions"

         # `model` is sent in the payload. Not used by the server but can be useful for debugging/routing.
         # If it's a ID on the Hub => use it. Otherwise, we use a random string.
-        model_id = model if not is_url and model.count("/") == 1 else "tgi"
-
-        data = await self.post(
-            model=model_url,
-            json=dict(
-                model=model_id,
-                messages=messages,
-                frequency_penalty=frequency_penalty,
-                logit_bias=logit_bias,
-                logprobs=logprobs,
-                max_tokens=max_tokens,
-                n=n,
-                presence_penalty=presence_penalty,
-                response_format=response_format,
-                seed=seed,
-                stop=stop,
-                temperature=temperature,
-                tool_choice=tool_choice,
-                tool_prompt=tool_prompt,
-                tools=tools,
-                top_logprobs=top_logprobs,
-                top_p=top_p,
-                stream=stream,
-            ),
+        model_id = model or self.model or "tgi"
+        if model_id.startswith(("http://", "https://")):
+            model_id = "tgi"  # dummy value
+
+        payload = dict(
+            model=model_id,
+            messages=messages,
+            frequency_penalty=frequency_penalty,
+            logit_bias=logit_bias,
+            logprobs=logprobs,
+            max_tokens=max_tokens,
+            n=n,
+            presence_penalty=presence_penalty,
+            response_format=response_format,
+            seed=seed,
+            stop=stop,
+            temperature=temperature,
+            tool_choice=tool_choice,
+            tool_prompt=tool_prompt,
+            tools=tools,
+            top_logprobs=top_logprobs,
+            top_p=top_p,
             stream=stream,
         )
+        payload = {key: value for key, value in payload.items() if value is not None}
+        data = await self.post(model=model_url, json=payload, stream=stream)

         if stream:
             return _async_stream_chat_completion_response(data)  # type: ignore[arg-type]
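Both clients now build the request body identically: the payload `model` field falls back to the dummy value `"tgi"` whenever only a URL is known, and parameters left as `None` are stripped before posting so the server applies its own defaults. A short sketch of that behavior, with illustrative values:

```python
from typing import Optional


def pick_payload_model_id(model: Optional[str], client_model: Optional[str]) -> str:
    # Mirrors the fallback above: prefer the call argument, then the client-level
    # model, then the dummy value "tgi"; a URL is never sent as the payload model.
    model_id = model or client_model or "tgi"
    if model_id.startswith(("http://", "https://")):
        model_id = "tgi"  # dummy value
    return model_id


assert pick_payload_model_id("meta-llama/Meta-Llama-3-8B-Instruct", None) == "meta-llama/Meta-Llama-3-8B-Instruct"
assert pick_payload_model_id("http://localhost:8080", None) == "tgi"
assert pick_payload_model_id(None, None) == "tgi"

# Unset (None) parameters are dropped, so the server only sees keys the caller set:
payload = dict(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    messages=[{"role": "user", "content": "Hello!"}],
    temperature=None,  # not set by the caller => removed below
    max_tokens=100,
    seed=None,  # not set by the caller => removed below
    stream=False,
)
payload = {key: value for key, value in payload.items() if value is not None}
assert payload == {
    "model": "meta-llama/Meta-Llama-3-8B-Instruct",
    "messages": [{"role": "user", "content": "Hello!"}],
    "max_tokens": 100,
    "stream": False,
}
```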