diff --git a/backends/huggingface_multimodal_api.py b/backends/huggingface_multimodal_api.py index e62af6a5ad..28c3f87229 100644 --- a/backends/huggingface_multimodal_api.py +++ b/backends/huggingface_multimodal_api.py @@ -106,7 +106,11 @@ def get_images(messages: list[Dict]) -> list: return loaded_images -def generate_idefics_output(messages: list[Dict], model: IdeficsForVisionText2Text, processor: AutoProcessor, device) -> list[str]: +def generate_idefics_output(messages: list[Dict], + model: IdeficsForVisionText2Text, + processor: AutoProcessor, + max_tokens: int, + device) -> list[str]: ''' Return generated text from Idefics model @@ -136,8 +140,8 @@ def generate_idefics_output(messages: list[Dict], model: IdeficsForVisionText2Te # Generation args for Idefics exit_condition = processor.tokenizer("", add_special_tokens=False).input_ids bad_words_ids = processor.tokenizer(["", ""], add_special_tokens=False).input_ids - generated_ids = model.generate(**inputs, eos_token_id=exit_condition, bad_words_ids=bad_words_ids, max_length=100) - generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True) + generated_ids = model.generate(**inputs, eos_token_id=exit_condition, bad_words_ids=bad_words_ids, max_length=max_tokens) + generated_text = processor.batch_decode(generated_ids) return generated_text @@ -167,8 +171,7 @@ def __init__(self, model_spec: backends.ModelSpec): if model_spec["padding"]: self.padding = True - def generate_response(self, messages: List[Dict], - log_messages: bool = False) -> Tuple[Any, Any, str]: + def generate_response(self, messages: List[Dict]) -> Tuple[Any, Any, str]: """ :param messages: for example [ @@ -203,6 +206,7 @@ def generate_response(self, messages: List[Dict], generated_text = generate_idefics_output(messages=messages, model=self.multimodal_model, processor=self.processor, + max_tokens=self.get_max_tokens(), device=self.device)