diff --git a/src/agentlab/llm/chat_api.py b/src/agentlab/llm/chat_api.py
index f8f02766..afc6d158 100644
--- a/src/agentlab/llm/chat_api.py
+++ b/src/agentlab/llm/chat_api.py
@@ -261,7 +261,7 @@ def __init__(
             **client_args,
         )
 
-    def __call__(self, messages: list[dict], n_samples: int = 1) -> dict:
+    def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float = None) -> dict:
         # Initialize retry tracking attributes
         self.retries = 0
         self.success = False
@@ -271,12 +271,13 @@ def __call__(self, messages: list[dict], n_samples: int = 1) -> dict:
         e = None
+        temperature = temperature if temperature is not None else self.temperature
         for itr in range(self.max_retry):
             self.retries += 1
             try:
                 completion = self.client.chat.completions.create(
                     model=self.model_name,
                     messages=messages,
                     n=n_samples,
-                    temperature=self.temperature,
+                    temperature=temperature,
                     max_tokens=self.max_tokens,
                 )
 
@@ -414,11 +415,10 @@ def __init__(
         super().__init__(model_name, n_retry_server)
         if temperature < 1e-3:
             logging.warning("Models might behave weirdly when temperature is too low.")
+        self.temperature = temperature
         if token is None:
             token = os.environ["TGI_TOKEN"]
 
         client = InferenceClient(model=model_url, token=token)
-        self.llm = partial(
-            client.text_generation, temperature=temperature, max_new_tokens=max_new_tokens
-        )
+        self.llm = partial(client.text_generation, max_new_tokens=max_new_tokens)
 
diff --git a/src/agentlab/llm/huggingface_utils.py b/src/agentlab/llm/huggingface_utils.py
index 9bb2d7ab..364221b5 100644
--- a/src/agentlab/llm/huggingface_utils.py
+++ b/src/agentlab/llm/huggingface_utils.py
@@ -1,6 +1,6 @@
 import logging
 import time
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Union
 
 from pydantic import Field
 from transformers import AutoTokenizer, GPT2TokenizerFast
@@ -12,7 +12,7 @@
 
 class HFBaseChatModel(AbstractChatModel):
     """
-    Custom LLM Chatbot that can interface with HuggingFace models.
+    Custom LLM chatbot that interfaces with HuggingFace models, with support for multiple samples.
 
     This class allows for the creation of a custom chatbot using models hosted
     on HuggingFace Hub or a local checkpoint. It provides flexibility in defining
@@ -22,6 +22,8 @@ class HFBaseChatModel(AbstractChatModel):
     Attributes:
         llm (Any): The HuggingFaceHub model instance.
         prompt_template (Any): Template for the prompt to be used for the model's input sequence.
+        tokenizer (Any): The tokenizer to use for the model.
+        n_retry_server (int): Number of times to retry on server failure.
     """
 
     llm: Any = Field(description="The HuggingFaceHub model instance")
@@ -53,12 +55,25 @@ def __init__(self, model_name, n_retry_server):
     def __call__(
         self,
         messages: list[dict],
-    ) -> dict:
-
-        # NOTE: The `stop`, `run_manager`, and `kwargs` arguments are ignored in this implementation.
-
+        n_samples: int = 1,
+        temperature: Optional[float] = None,
+    ) -> Union[AIMessage, List[AIMessage]]:
+        """
+        Generate one or more responses for the given messages.
+
+        Args:
+            messages: List of message dictionaries containing the conversation history.
+            n_samples: Number of independent responses to generate. Defaults to 1.
+            temperature: Sampling temperature. If None, falls back to the model's default.
+
+        Returns:
+            If n_samples=1, returns a single AIMessage.
+            If n_samples>1, returns a list of AIMessages.
+
+        Raises:
+            Exception: If the server fails to respond after n_retry_server attempts or if the chat template fails.
+        """
         if self.tokenizer:
-            # messages_formated = _convert_messages_to_dict(messages)  ## ?
             try:
                 if isinstance(messages, Discussion):
                     messages.merge()
@@ -66,31 +81,37 @@ def __call__(
             except Exception as e:
                 if "Conversation roles must alternate" in str(e):
                     logging.warning(
-                        f"Failed to apply the chat template. Maybe because it doesn't support the 'system' role"
+                        f"Failed to apply the chat template. Maybe because it doesn't support the 'system' role. "
                         "Retrying with the 'system' role appended to the 'user' role."
                     )
                     messages = _prepend_system_to_first_user(messages)
                     prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
                 else:
                     raise e
-
         elif self.prompt_template:
             prompt = self.prompt_template.construct_prompt(messages)
 
-        itr = 0
-        while True:
-            try:
-                response = AIMessage(self.llm(prompt))
-                return response
-            except Exception as e:
-                if itr == self.n_retry_server - 1:
-                    raise e
-                logging.warning(
-                    f"Failed to get a response from the server: \n{e}\n"
-                    f"Retrying... ({itr+1}/{self.n_retry_server})"
-                )
-                time.sleep(5)
-                itr += 1
+        # An explicit temperature argument overrides the model's default.
+        temperature = temperature if temperature is not None else self.temperature
+        responses = []
+        for _ in range(n_samples):
+            itr = 0
+            while True:
+                try:
+                    response = AIMessage(self.llm(prompt, temperature=temperature))
+                    responses.append(response)
+                    break
+                except Exception as e:
+                    if itr == self.n_retry_server - 1:
+                        raise e
+                    logging.warning(
+                        f"Failed to get a response from the server: \n{e}\n"
+                        f"Retrying... ({itr+1}/{self.n_retry_server})"
+                    )
+                    time.sleep(5)
+                    itr += 1
+
+        return responses[0] if n_samples == 1 else responses
 
     def _llm_type(self):
         return "huggingface"
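
For reviewers, a minimal sketch of how the new per-call overrides behave once this patch is applied. The `chat_model` instance below is a placeholder (how it is constructed is outside this diff); the calls illustrate the `n_samples` / `temperature` semantics added above.

```python
# Hypothetical usage sketch, assuming `chat_model` is any chat model
# instance exposing the new __call__(messages, n_samples, temperature)
# signature introduced by this patch.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Summarize this page in one sentence."},
]

# Default behavior: one sample, using the temperature stored on the instance.
answer = chat_model(messages)

# Per-call override: an explicit temperature takes precedence over
# self.temperature, and n_samples > 1 returns a list of AIMessages
# instead of a single one.
answers = chat_model(messages, n_samples=3, temperature=0.7)
assert isinstance(answers, list) and len(answers) == 3
```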