From fa3ff92b07ea5aaa633a2039818c310744f84d07 Mon Sep 17 00:00:00 2001
From: kcz358
Date: Mon, 6 May 2024 08:32:57 +0000
Subject: [PATCH] Fix llava conv template for llama3

---
 lmms_eval/models/llava.py | 12 ++++++++++--
 1 file changed, 10 insertions(+), 2 deletions(-)

diff --git a/lmms_eval/models/llava.py b/lmms_eval/models/llava.py
index 632859d7..66f9ed0f 100644
--- a/lmms_eval/models/llava.py
+++ b/lmms_eval/models/llava.py
@@ -223,7 +223,11 @@ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
                 image_tokens = " ".join(image_tokens)
                 prompts_input = image_tokens + "\n" + (contexts[0] if isinstance(contexts, list) else contexts)
 
-            conv = conv_templates[self.conv_template].copy()
+            # Deepcopy is safer for llama_3: that template now contains
+            # object-type fields that a shallow copy would share across requests
+            if "llama_3" in self.conv_template:
+                conv = copy.deepcopy(conv_templates[self.conv_template])
+            else:
+                conv = conv_templates[self.conv_template].copy()
             conv.append_message(conv.roles[0], prompts_input)
             conv.append_message(conv.roles[1], None)
             prompt = conv.get_prompt()
@@ -331,7 +335,11 @@ def _collate(x):
             else:
                 question = context
 
-            conv = conv_templates[self.conv_template].copy()
+            # Deepcopy is safer for llama_3: that template now contains
+            # object-type fields that a shallow copy would share across requests
+            if "llama_3" in self.conv_template:
+                conv = copy.deepcopy(conv_templates[self.conv_template])
+            else:
+                conv = conv_templates[self.conv_template].copy()
             conv.append_message(conv.roles[0], question)
             conv.append_message(conv.roles[1], None)
             prompt_question = conv.get_prompt()
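
Note: a minimal, self-contained sketch of the shallow-copy pitfall this patch
guards against. The Conversation class below is a hypothetical stand-in for
llava's real template class (the real one defines its own .copy() method;
copy.copy stands in for a shallow copy here), and the patch itself assumes
`copy` is already imported in lmms_eval/models/llava.py.

    import copy
    from dataclasses import dataclass, field

    # Hypothetical stand-in for a conversation template; the real class
    # lives in llava's conversation module. Names are illustrative only.
    @dataclass
    class Conversation:
        roles: tuple = ("user", "assistant")
        messages: list = field(default_factory=list)

        def append_message(self, role: str, message: str) -> None:
            self.messages.append([role, message])

    template = Conversation()  # module-level singleton, like conv_templates[...]

    shallow = copy.copy(template)         # shallow: shares the `messages` list
    shallow.append_message("user", "hi")
    print(len(template.messages))         # 1 -- the shared template was mutated

    deep = copy.deepcopy(template)        # deep: every field is duplicated
    deep.append_message("user", "hello")
    print(len(template.messages))         # still 1 -- the template is untouched

Per the patch comment, only the llama_3 template carries object-type fields,
so only it takes the deepcopy path; the other templates keep the cheaper
.copy() call.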