Support torchrun for emu2 & emu2_chat and fix bug #52

Merged · 18 commits · Jan 17, 2024
vlmeval/utils/dataset_config.py (1 addition, 1 deletion)

@@ -49,7 +49,7 @@
     'ScienceQA_VAL': '96320d05e142e585e7204e72affd29f3',
     'ScienceQA_TEST': 'e42e9e00f9c59a80d8a5db35bc32b71f',
     'HallusionBench': '0c23ac0dc9ef46832d7a24504f2a0c7c',
-    "DocVQA_VAL": '3744f5df4aaf2781c85fe7677ae0a411',
+    "DocVQA_VAL": 'c911fdc5f4974513c112cc83a25c99d9',
     "AI2D": "53db8397adbe73e9cc0b4861227004d4",
     "LLaVABench": "d382a093f749a697820d3dadd61c8428"
 }
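
This hunk updates the MD5 checksum recorded for the DocVQA_VAL dataset file. As a minimal sketch, assuming the values above are plain MD5 digests of the downloaded TSV files (the real validation helper in vlmeval and the local file name here are assumptions):

    import hashlib

    def file_md5(path, chunk_size=1 << 20):
        # hash the file in chunks so large TSVs do not load fully into memory
        h = hashlib.md5()
        with open(path, 'rb') as f:
            while block := f.read(chunk_size):
                h.update(block)
        return h.hexdigest()

    # 'DocVQA_VAL.tsv' is a hypothetical local file name
    assert file_md5('DocVQA_VAL.tsv') == 'c911fdc5f4974513c112cc83a25c99d9'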
vlmeval/vlm/emu.py (13 additions, 3 deletions)

@@ -28,6 +28,16 @@ def __init__(self,
     from transformers import AutoModelForCausalLM, AutoTokenizer
     from accelerate import init_empty_weights, infer_auto_device_map, dispatch_model

+    local_rank, world_size = get_rank_and_world_size()
+
+    device_num = torch.cuda.device_count()
+    assert world_size * 2 <= device_num, 'The number of GPUs must be at least 2 * world_size'
+
+    device_1 = local_rank
+    device_2 = local_rank + world_size
+    torch.cuda.set_device(device_1)
+    torch.cuda.set_device(device_2)
Review comment (Member):
Add a check or assertion to make sure this arrangement is viable; for example, verify that 2 * world_size GPUs are available.
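
For context, the two-GPUs-per-rank arrangement relies on get_rank_and_world_size() picking up the process layout that torchrun exports. A minimal sketch of what that helper presumably does (the exact environment variables and defaults are assumptions; vlmeval's actual implementation may differ):

    import os

    def get_rank_and_world_size():
        # torchrun sets LOCAL_RANK and WORLD_SIZE for every worker it spawns
        local_rank = int(os.environ.get('LOCAL_RANK', 0))
        world_size = int(os.environ.get('WORLD_SIZE', 1))
        return local_rank, world_size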


tokenizer = AutoTokenizer.from_pretrained(model_path) # "BAAI/Emu2-Chat"
self.tokenizer = tokenizer
with init_empty_weights():
@@ -37,9 +47,9 @@
         low_cpu_mem_usage=True,
         trust_remote_code=True)

-    device_map = infer_auto_device_map(model, max_memory={0: '38GiB', 1: '38GiB'}, no_split_module_classes=['Block', 'LlamaDecoderLayer'])
+    device_map = infer_auto_device_map(model, max_memory={device_1: '38GiB', device_2: '38GiB'}, no_split_module_classes=['Block', 'LlamaDecoderLayer'])
     # input and output logits should be on the same device
-    device_map["model.decoder.lm.lm_head"] = 0
+    device_map["model.decoder.lm.lm_head"] = device_1

     model = dispatch_model(
         model,
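
To see why the assert above requires 2 * world_size GPUs: each worker shards its model replica across the GPU pair (local_rank, local_rank + world_size), so the pairs are disjoint across workers. A self-contained illustration of that arithmetic for a 4-GPU node:

    # Worked example, assuming world_size = 2 on a node with 4 GPUs:
    #   rank 0 -> GPUs 0 and 2
    #   rank 1 -> GPUs 1 and 3
    world_size = 2
    for local_rank in range(world_size):
        device_1 = local_rank               # first GPU of this worker's pair
        device_2 = local_rank + world_size  # second GPU, offset by world_size
        print(f'rank {local_rank}: GPUs {device_1} and {device_2}')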
@@ -79,4 +89,4 @@ def interleave_generate(self, ti_list, dataset=None):
 def generate(self, image_path, prompt, dataset=None):
     tl_list = [image_path, prompt]
     output = self.interleave_generate(tl_list, dataset)
-    return output
\ No newline at end of file
+    return output
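
Here generate() is a thin wrapper that packs the image path and prompt into one interleaved list and delegates to interleave_generate(). A runnable stub showing that control flow (the stub class is hypothetical; the real wrapper lives in vlmeval/vlm/emu.py):

    # _EmuStub only mirrors the delegation pattern in the diff above
    class _EmuStub:
        def interleave_generate(self, ti_list, dataset=None):
            # the real method feeds an interleaved image/text list to Emu2;
            # here we just echo the inputs
            return ' | '.join(map(str, ti_list))

        def generate(self, image_path, prompt, dataset=None):
            tl_list = [image_path, prompt]  # pack image and prompt together
            return self.interleave_generate(tl_list, dataset)

    print(_EmuStub().generate('example.jpg', 'Describe the image.'))

With this PR, each torchrun worker owns a disjoint GPU pair, so a 4-GPU node would be launched with two workers, e.g. torchrun --nproc_per_node=2 run.py ... (the entry-point flags are an assumption about the repo's launcher).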