diff --git a/vlmeval/config.py b/vlmeval/config.py
index 340310f4a..9e033a89f 100644
--- a/vlmeval/config.py
+++ b/vlmeval/config.py
@@ -46,6 +46,7 @@
     'emu2':partial(Emu, name='emu2'),
     'emu2_chat':partial(Emu, name='emu2_chat'),
     'monkey':partial(Monkey, model_path='echo840/Monkey'),
+    'monkey-chat':partial(MonkeyChat, model_path='echo840/Monkey-Chat'),
 }
 
 api_models = {
diff --git a/vlmeval/vlm/__init__.py b/vlmeval/vlm/__init__.py
index dc0266b73..25a0050cd 100644
--- a/vlmeval/vlm/__init__.py
+++ b/vlmeval/vlm/__init__.py
@@ -16,4 +16,4 @@
 from .cogvlm import CogVlm
 from .sharedcaptioner import SharedCaptioner
 from .emu import Emu
-from .monkey import Monkey
+from .monkey import Monkey, MonkeyChat
diff --git a/vlmeval/vlm/monkey.py b/vlmeval/vlm/monkey.py
index c83d6e19a..d1da2dc37 100644
--- a/vlmeval/vlm/monkey.py
+++ b/vlmeval/vlm/monkey.py
@@ -22,7 +22,7 @@ def __init__(self, model_path='echo840/Monkey', **kwargs):
         torch.cuda.empty_cache()
 
     def generate(self, image_path, prompt, dataset=None):
-        cur_prompt = f'{image_path}\n{prompt} Answer:'
+        cur_prompt = f'{image_path} {prompt} Answer:'
         input_ids = self.tokenizer(cur_prompt, return_tensors='pt', padding='longest')
         attention_mask = input_ids.attention_mask
         input_ids = input_ids.input_ids
@@ -44,3 +44,56 @@ def generate(self, image_path, prompt, dataset=None):
         )
         response = self.tokenizer.decode(output_ids[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip()
         return response
+
+class MonkeyChat:
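+    # Wrapper for the Monkey-Chat checkpoint. Mirrors the Monkey class above,
+    # but builds its prompt per-dataset (see generate below).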
+
+    INSTALL_REQ = False
+
+    def __init__(self, model_path='echo840/Monkey-Chat', **kwargs):
+        assert model_path is not None
+        self.model_path = model_path
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+        self.model = AutoModelForCausalLM.from_pretrained(model_path, device_map='cuda', trust_remote_code=True).eval()
+        self.kwargs = kwargs
+
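+        # The tokenizer ships without a pad token; reuse eod_id and left-pad so
+        # generated tokens always start right after the prompt.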
+        self.tokenizer.padding_side = 'left'
+        self.tokenizer.pad_token_id = self.tokenizer.eod_id
+
+        warnings.warn(f"Following kwargs received: {self.kwargs}, will be used as generation config.")
+        torch.cuda.empty_cache()
+
+    def generate(self, image_path, prompt, dataset=None):
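+        # MMVet uses a single-line prompt; other datasets keep a newline
+        # between the image path and the question.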
+        if dataset == 'MMVet':
+            cur_prompt = f'{image_path} {prompt} Answer: '
+        else:
+            cur_prompt = f'{image_path} \n {prompt} Answer: '
+        input_ids = self.tokenizer(cur_prompt, return_tensors='pt', padding='longest')
+        attention_mask = input_ids.attention_mask
+        input_ids = input_ids.input_ids
+
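+        # Deterministic, short-answer decoding by default; any kwargs passed to
+        # __init__ override these defaults.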
+        gen_kwargs = dict(
+            do_sample=False,
+            num_beams=1,
+            max_new_tokens=10,
+            min_new_tokens=1,
+            length_penalty=1,
+            num_return_sequences=1,
+            output_hidden_states=True,
+            use_cache=True,
+            pad_token_id=self.tokenizer.eod_id,
+            eos_token_id=self.tokenizer.eod_id,
+        )
+        gen_kwargs.update(self.kwargs)
+        output_ids = self.model.generate(
+            input_ids=input_ids.cuda(),
+            attention_mask=attention_mask.cuda(),
+            **gen_kwargs,
+        )
+        response = self.tokenizer.decode(output_ids[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip()
+
+        return response