Skip to content

Commit

Permalink
[Model] Add monkey-chat (#54)
Browse files Browse the repository at this point in the history
* add monkey-chat

* add monkey-chat

---------

Co-authored-by: yuluoyun <1731396519@qq.com>
  • Loading branch information
ShuoZhang2003 and echo840 authored Jan 17, 2024
1 parent 7cd26fc commit 0c4239c
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 2 deletions.
1 change: 1 addition & 0 deletions vlmeval/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
'emu2':partial(Emu, name='emu2'),
'emu2_chat':partial(Emu, name='emu2_chat'),
'monkey':partial(Monkey, model_path='echo840/Monkey'),
'monkey-chat':partial(MonkeyChat, model_path='echo840/Monkey-Chat'),
}

api_models = {
Expand Down
2 changes: 1 addition & 1 deletion vlmeval/vlm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
from .cogvlm import CogVlm
from .sharedcaptioner import SharedCaptioner
from .emu import Emu
from .monkey import Monkey
from .monkey import Monkey, MonkeyChat
46 changes: 45 additions & 1 deletion vlmeval/vlm/monkey.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(self, model_path='echo840/Monkey', **kwargs):
torch.cuda.empty_cache()

def generate(self, image_path, prompt, dataset=None):
cur_prompt = f'<img>{image_path}</img>\n{prompt} Answer:'
cur_prompt = f'<img>{image_path}</img> {prompt} Answer:'
input_ids = self.tokenizer(cur_prompt, return_tensors='pt', padding='longest')
attention_mask = input_ids.attention_mask
input_ids = input_ids.input_ids
Expand All @@ -44,3 +44,47 @@ def generate(self, image_path, prompt, dataset=None):
response = self.tokenizer.decode(output_ids[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip()

return response

class MonkeyChat:

INSTALL_REQ = False

def __init__(self, model_path='echo840/Monkey-Chat', **kwargs):
assert model_path is not None
self.model_path = model_path
self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
self.model = AutoModelForCausalLM.from_pretrained(model_path, device_map='cuda', trust_remote_code=True).eval()
self.kwargs = kwargs

self.tokenizer.padding_side = 'left'
self.tokenizer.pad_token_id = self.tokenizer.eod_id

warnings.warn(f"Following kwargs received: {self.kwargs}, will use as generation config. ")
torch.cuda.empty_cache()

def generate(self, image_path, prompt, dataset=None):
if dataset == 'MMVet':
cur_prompt = f'<img>{image_path}</img> {prompt} Answer: '
else:
cur_prompt = f'<img>{image_path}</img> \n {prompt} Answer: '
input_ids = self.tokenizer(cur_prompt, return_tensors='pt', padding='longest')
attention_mask = input_ids.attention_mask
input_ids = input_ids.input_ids

output_ids = self.model.generate(
input_ids=input_ids.cuda(),
attention_mask=attention_mask.cuda(),
do_sample=False,
num_beams=1,
max_new_tokens=10,
min_new_tokens=1,
length_penalty=1,
num_return_sequences=1,
output_hidden_states=True,
use_cache=True,
pad_token_id=self.tokenizer.eod_id,
eos_token_id=self.tokenizer.eod_id,
)
response = self.tokenizer.decode(output_ids[0][input_ids.size(1):].cpu(), skip_special_tokens=True).strip()

return response

0 comments on commit 0c4239c

Please sign in to comment.