Merge pull request #173 from ServiceNow/multiple-samples-hf-model
Adapt multiple samples for HF models
recursix authored Dec 5, 2024
2 parents 6defb41 + 6a2c783 commit c52b7cd
Showing 2 changed files with 48 additions and 28 deletions.
10 changes: 5 additions & 5 deletions src/agentlab/llm/chat_api.py
@@ -261,7 +261,7 @@ def __init__(
             **client_args,
         )

-    def __call__(self, messages: list[dict], n_samples: int = 1) -> dict:
+    def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float = None) -> dict:
         # Initialize retry tracking attributes
         self.retries = 0
         self.success = False
@@ -271,12 +271,13 @@ def __call__(self, messages: list[dict], n_samples: int = 1) -> dict:
         e = None
         for itr in range(self.max_retry):
             self.retries += 1
+            temperature = temperature if temperature is not None else self.temperature
             try:
                 completion = self.client.chat.completions.create(
                     model=self.model_name,
                     messages=messages,
                     n=n_samples,
-                    temperature=self.temperature,
+                    temperature=temperature,
                     max_tokens=self.max_tokens,
                 )

@@ -414,11 +415,10 @@ def __init__(
         super().__init__(model_name, n_retry_server)
         if temperature < 1e-3:
             logging.warning("Models might behave weirdly when temperature is too low.")
+        self.temperature = temperature

         if token is None:
             token = os.environ["TGI_TOKEN"]

         client = InferenceClient(model=model_url, token=token)
-        self.llm = partial(
-            client.text_generation, temperature=temperature, max_new_tokens=max_new_tokens
-        )
+        self.llm = partial(client.text_generation, max_new_tokens=max_new_tokens)
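Note on the change above: the TGI-backed model no longer bakes `temperature` into the `functools.partial` wrapper; it stores `self.temperature` and resolves the value at call time, falling back to the stored default when the caller passes `None`. Below is a minimal, self-contained sketch of that pattern; `text_generation` here is a hypothetical stand-in, not the real `huggingface_hub.InferenceClient` method, and the variable names are illustrative.

from functools import partial

def text_generation(prompt: str, temperature: float = 0.2, max_new_tokens: int = 64) -> str:
    # Stand-in stub for a text-generation client call (illustrative only).
    return f"[t={temperature}, max={max_new_tokens}] {prompt}"

# Before: temperature bound at construction time -- every call uses the same value.
llm_fixed = partial(text_generation, temperature=0.1, max_new_tokens=512)

# After: only max_new_tokens is bound; temperature is decided per call,
# defaulting to the value stored on the model (self.temperature in the diff).
llm_flexible = partial(text_generation, max_new_tokens=512)

default_temperature = 0.1  # stands in for self.temperature
override = None            # per-call override, as in __call__(..., temperature=None)
temperature = override if override is not None else default_temperature
print(llm_flexible("hello", temperature=temperature))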
66 changes: 43 additions & 23 deletions src/agentlab/llm/huggingface_utils.py
@@ -1,6 +1,6 @@
 import logging
 import time
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Union

 from pydantic import Field
 from transformers import AutoTokenizer, GPT2TokenizerFast
@@ -12,7 +12,7 @@

 class HFBaseChatModel(AbstractChatModel):
     """
-    Custom LLM Chatbot that can interface with HuggingFace models.
+    Custom LLM Chatbot that can interface with HuggingFace models with support for multiple samples.

     This class allows for the creation of a custom chatbot using models hosted
     on HuggingFace Hub or a local checkpoint. It provides flexibility in defining
@@ -22,6 +22,8 @@ class HFBaseChatModel(AbstractChatModel):
     Attributes:
         llm (Any): The HuggingFaceHub model instance.
         prompt_template (Any): Template for the prompt to be used for the model's input sequence.
+        tokenizer (Any): The tokenizer to use for the model.
+        n_retry_server (int): Number of times to retry on server failure.
     """

     llm: Any = Field(description="The HuggingFaceHub model instance")
@@ -53,44 +55,62 @@ def __init__(self, model_name, n_retry_server):
     def __call__(
         self,
         messages: list[dict],
-    ) -> dict:
-
-        # NOTE: The `stop`, `run_manager`, and `kwargs` arguments are ignored in this implementation.
-
+        n_samples: int = 1,
+        temperature: float = None,
+    ) -> Union[AIMessage, List[AIMessage]]:
+        """
+        Generate one or more responses for the given messages.
+
+        Args:
+            messages: List of message dictionaries containing the conversation history.
+            n_samples: Number of independent responses to generate. Defaults to 1.
+            temperature: The temperature for response sampling. Defaults to None.
+
+        Returns:
+            If n_samples=1, returns a single AIMessage.
+            If n_samples>1, returns a list of AIMessages.
+
+        Raises:
+            Exception: If the server fails to respond after n_retry_server attempts or if the chat template fails.
+        """
         if self.tokenizer:
-            # messages_formated = _convert_messages_to_dict(messages) ## ?
             try:
                 if isinstance(messages, Discussion):
                     messages.merge()
                 prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
             except Exception as e:
                 if "Conversation roles must alternate" in str(e):
                     logging.warning(
-                        f"Failed to apply the chat template. Maybe because it doesn't support the 'system' role"
+                        f"Failed to apply the chat template. Maybe because it doesn't support the 'system' role. "
                         "Retrying with the 'system' role appended to the 'user' role."
                     )
                     messages = _prepend_system_to_first_user(messages)
                     prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
                 else:
                     raise e

         elif self.prompt_template:
             prompt = self.prompt_template.construct_prompt(messages)

-        itr = 0
-        while True:
-            try:
-                response = AIMessage(self.llm(prompt))
-                return response
-            except Exception as e:
-                if itr == self.n_retry_server - 1:
-                    raise e
-                logging.warning(
-                    f"Failed to get a response from the server: \n{e}\n"
-                    f"Retrying... ({itr+1}/{self.n_retry_server})"
-                )
-                time.sleep(5)
-                itr += 1
+        responses = []
+        for _ in range(n_samples):
+            itr = 0
+            while True:
+                try:
+                    temperature = temperature if temperature is not None else self.temperature
+                    response = AIMessage(self.llm(prompt, temperature=temperature))
+                    responses.append(response)
+                    break
+                except Exception as e:
+                    if itr == self.n_retry_server - 1:
+                        raise e
+                    logging.warning(
+                        f"Failed to get a response from the server: \n{e}\n"
+                        f"Retrying... ({itr+1}/{self.n_retry_server})"
+                    )
+                    time.sleep(5)
+                    itr += 1
+
+        return responses[0] if n_samples == 1 else responses

     def _llm_type(self):
         return "huggingface"
