diff --git a/vlmeval/config.py b/vlmeval/config.py index 340310f4a..9e033a89f 100644 --- a/vlmeval/config.py +++ b/vlmeval/config.py @@ -46,6 +46,7 @@ 'emu2':partial(Emu, name='emu2'), 'emu2_chat':partial(Emu, name='emu2_chat'), 'monkey':partial(Monkey, model_path='echo840/Monkey'), + 'monkey-chat':partial(MonkeyChat, model_path='echo840/Monkey-Chat'), } api_models = { diff --git a/vlmeval/vlm/__init__.py b/vlmeval/vlm/__init__.py index dc0266b73..25a0050cd 100644 --- a/vlmeval/vlm/__init__.py +++ b/vlmeval/vlm/__init__.py @@ -16,4 +16,4 @@ from .cogvlm import CogVlm from .sharedcaptioner import SharedCaptioner from .emu import Emu -from .monkey import Monkey +from .monkey import Monkey, MonkeyChat diff --git a/vlmeval/vlm/monkey.py b/vlmeval/vlm/monkey.py index c83d6e19a..d1da2dc37 100644 --- a/vlmeval/vlm/monkey.py +++ b/vlmeval/vlm/monkey.py @@ -22,7 +22,7 @@ def __init__(self, model_path='echo840/Monkey', **kwargs): torch.cuda.empty_cache() def generate(self, image_path, prompt, dataset=None): - cur_prompt = f'{image_path}\n{prompt} Answer:' + cur_prompt = f'{image_path} {prompt} Answer:' input_ids = self.tokenizer(cur_prompt, return_tensors='pt', padding='longest') attention_mask = input_ids.attention_mask input_ids = input_ids.input_ids @@ -44,3 +44,47 @@ def generate(self, image_path, prompt, dataset=None): response = self.tokenizer.decode(output_ids[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip() return response + +class MonkeyChat: + + INSTALL_REQ = False + + def __init__(self, model_path='echo840/Monkey-Chat', **kwargs): + assert model_path is not None + self.model_path = model_path + self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + self.model = AutoModelForCausalLM.from_pretrained(model_path, device_map='cuda', trust_remote_code=True).eval() + self.kwargs = kwargs + + self.tokenizer.padding_side = 'left' + self.tokenizer.pad_token_id = self.tokenizer.eod_id + + warnings.warn(f"Following kwargs received: {self.kwargs}, will use as generation config. ") + torch.cuda.empty_cache() + + def generate(self, image_path, prompt, dataset=None): + if dataset == 'MMVet': + cur_prompt = f'{image_path} {prompt} Answer: ' + else: + cur_prompt = f'{image_path} \n {prompt} Answer: ' + input_ids = self.tokenizer(cur_prompt, return_tensors='pt', padding='longest') + attention_mask = input_ids.attention_mask + input_ids = input_ids.input_ids + + output_ids = self.model.generate( + input_ids=input_ids.cuda(), + attention_mask=attention_mask.cuda(), + do_sample=False, + num_beams=1, + max_new_tokens=10, + min_new_tokens=1, + length_penalty=1, + num_return_sequences=1, + output_hidden_states=True, + use_cache=True, + pad_token_id=self.tokenizer.eod_id, + eos_token_id=self.tokenizer.eod_id, + ) + response = self.tokenizer.decode(output_ids[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip() + + return response