Mistral + Mixtral Support for NeVa #9459

Merged · 15 commits · Jul 8, 2024
Changes from 11 commits
28 changes: 24 additions & 4 deletions nemo/collections/multimodal/data/neva/conversation.py
@@ -43,6 +43,7 @@ class SeparatorStyle(Enum):
PLAIN = auto()
LLAMA_2 = auto()
LLAMA_3 = auto()
MISTRAL = auto()
NVGPT = auto()


@@ -94,11 +95,15 @@ def get_prompt(self):
ret += " "
else:
ret += role + ":"
elif self.sep_style == SeparatorStyle.LLAMA_2:
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
elif self.sep_style == SeparatorStyle.LLAMA_2 or self.sep_style == SeparatorStyle.MISTRAL:
if self.sep_style == SeparatorStyle.LLAMA_2:
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
else:
wrap_sys = lambda msg: f"{msg}" + ("\n" if msg else "")
wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
ret = ""

if self.sep_style == SeparatorStyle.MISTRAL:
ret += DEFAULT_BOS_TOKEN
for i, (role, message) in enumerate(messages):
if i == 0:
assert message, "first message should not be none"
@@ -112,7 +117,10 @@ def get_prompt(self):
message = wrap_inst(message)
ret += self.sep + " " + message
else:
ret += " " + message + " " + self.sep2
if self.sep_style == SeparatorStyle.LLAMA_2:
ret += " " + message + " " + self.sep2
else:
ret += message + self.sep2
else:
ret += ""
ret = ret.lstrip(self.sep)
@@ -449,6 +457,17 @@ def dict(self):
version="v1_mmtag",
)

conv_mistral = Conversation(
system="",
roles=("USER", "ASSISTANT"),
version="mistral",
messages=(),
offset=0,
sep_style=SeparatorStyle.MISTRAL,
sep="",
sep2=DEFAULT_EOS_TOKEN,
)

default_conversation = conv_vicuna_v1
conv_templates = {
"default": conv_vicuna_v0,
@@ -466,6 +485,7 @@ def dict(self):
"nvgpt": conv_nvgpt,
"nv_steerlm": conv_nvgpt,
"nv_dpo": conv_nv_dpo,
"mistral": conv_mistral,
}

if __name__ == "__main__":
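For reference, here is a minimal standalone sketch of what the new MISTRAL branch produces, assuming DEFAULT_BOS_TOKEN is `<s>` and DEFAULT_EOS_TOKEN is `</s>` (the helper name `format_mistral` is hypothetical):

```python
# Hypothetical standalone version of the MISTRAL path in get_prompt above.
# Unlike LLAMA_2, the system message is prepended bare (plus a newline when
# non-empty) rather than wrapped in <<SYS>> tags, and assistant turns get no
# surrounding spaces before the EOS separator.
def format_mistral(system, messages, sep="", sep2="</s>"):
    wrap_sys = lambda msg: f"{msg}" + ("\n" if msg else "")
    wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
    ret = "<s>"  # DEFAULT_BOS_TOKEN, prepended only for MISTRAL
    for i, (role, message) in enumerate(messages):
        if i == 0:
            message = wrap_sys(system) + message
        if i % 2 == 0:  # user turn
            ret += sep + " " + wrap_inst(message)
        else:           # assistant turn
            ret += message + sep2
    return ret

print(format_mistral("", [("USER", "Hi"), ("ASSISTANT", "Hello!")]))
# <s> [INST] Hi [/INST]Hello!</s>
```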
34 changes: 28 additions & 6 deletions nemo/collections/multimodal/data/neva/neva_dataset.py
@@ -426,6 +426,7 @@ def preprocess_llama_2(
sources: dict,
tokenizer,
cfg,
is_mistral: bool = False,
) -> Dict:
"""
Preprocesses sources for the LLaMA 2 model configuration.
@@ -442,7 +443,10 @@
- Dict: A dictionary containing tokenized and labeled data suitable for the LLaMA 2 model.
This includes tokens, labels, and any special processing as defined in the configuration.
"""
conv = conversation_lib.conv_llava_llama_2.copy()
if is_mistral:
conv = conversation_lib.conv_mistral.copy()
else:
conv = conversation_lib.conv_llava_llama_2.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

# Apply prompt templates
@@ -477,7 +481,10 @@
labels = tokens.clone().detach()

# Mask labels
sep = "[/INST] "
if is_mistral:
sep = "[/INST]"
else:
sep = "[/INST] "
for conversation, target in zip(conversations, labels):
rounds = conversation.split(conv.sep2)
cur_len = 0
@@ -492,18 +499,23 @@
parts[0] += sep

round_len = len(tokenizer.text_to_ids(rou + conv.sep2))
instruction_len = len(tokenizer.text_to_ids(parts[0])) - 2

if is_mistral:
instruction_len = len(tokenizer.text_to_ids(parts[0])) - 1
else:
instruction_len = len(tokenizer.text_to_ids(parts[0])) - 2

if i > 0:
round_len -= 1 # Remove extra token added by sp tokenizer
else:
instruction_len += 1
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

cur_len += round_len
target[cur_len:] = IGNORE_INDEX

# Check if masking working correctly
# print([x for x in zip(tokens[0].numpy().tolist(), labels[0].numpy().tolist())])
# masking_test =[x for x in zip(tokens[0].numpy().tolist(), labels[0].numpy().tolist())]
# print(masking_test)

if add_extra_token:
tokens = tokens[:, :-1].contiguous()
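The off-by-one between the two branches comes from the separator: with a SentencePiece tokenizer, "[/INST] " (trailing space) and "[/INST]" tokenize to different lengths, so the Mistral path subtracts 1 instead of 2 when counting prompt tokens to mask. A toy illustration of the masking itself (the token ids below are made up):

```python
# Toy illustration of the label masking above; the token ids are made up.
import torch

IGNORE_INDEX = -100  # sentinel used by the NeVa dataset code
tokens = torch.tensor([[1, 733, 16289, 28793, 15359, 733, 28748, 16289, 28793, 22557, 2]])
labels = tokens.clone()

instruction_len = 9  # tokens up to and including "[/INST]" for this round
labels[0, :instruction_len] = IGNORE_INDEX
# Only the response tokens (here "Hello!" + EOS) keep real labels and
# contribute to the loss; the prompt tokens are ignored.
```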
@@ -990,7 +1002,10 @@ def expand2square(pil_img, background_color):
result.paste(pil_img, ((height - width) // 2, 0))
return result

frames = expand2square(frames, tuple(int(x * 255) for x in self.processor.image_mean))
frames = [
expand2square(frame, tuple(int(x * 255) for x in self.processor.image_mean))
for frame in frames
]
frames = self.processor.preprocess(frames, return_tensors='pt')['pixel_values']
else:
frames = self.processor.preprocess(frames, return_tensors='pt')['pixel_values']
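This fix matters because `expand2square` operates on a single PIL image, while video input arrives as a list of frames; the old call passed the whole list at once. A simplified sketch of the per-frame padding (the pad color is approximately CLIP's `image_mean`):

```python
# Simplified per-frame padding, mirroring the list comprehension above.
from PIL import Image

def expand2square(pil_img, background_color):
    # Pad a single image onto a square canvas filled with background_color.
    width, height = pil_img.size
    if width == height:
        return pil_img
    side = max(width, height)
    result = Image.new(pil_img.mode, (side, side), background_color)
    result.paste(pil_img, ((side - width) // 2, (side - height) // 2))
    return result

frames = [Image.new("RGB", (640, 360)) for _ in range(4)]  # stand-in video frames
mean_color = tuple(int(x * 255) for x in (0.481, 0.457, 0.408))  # ~CLIP image_mean
frames = [expand2square(f, mean_color) for f in frames]
assert all(f.size == (640, 640) for f in frames)
```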
@@ -1057,6 +1072,13 @@ def expand2square(pil_img, background_color):
self.tokenizer,
self.multimodal_cfg,
)
elif self.conv_template == "mistral":
data_dict = preprocess_llama_2(
sources,
self.tokenizer,
self.multimodal_cfg,
is_mistral=True,
)
elif self.conv_template == "plain":
data_dict = preprocess_plain(
sources,
@@ -550,7 +550,12 @@ def dummy():
media_end_id=media_end_id,
mcore_gpt=self.mcore_gpt,
config=self.transformer_config,
transformer_layer_spec=get_specs(self.spec_name),
transformer_layer_spec=get_specs(
self.spec_name,
self.transformer_config.num_moe_experts,
self.transformer_config.moe_grouped_gemm,
self.transformer_engine,
),
vocab_size=self.cfg.get('override_vocab_size', self.padded_vocab_size),
max_sequence_length=self.cfg.get('encoder_seq_length', 512),
pre_process=pre_process,
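Forwarding the MoE fields to `get_specs` is what lets the same NeVa code path build either a dense Mistral layer spec or Mixtral's mixture-of-experts spec. A hedged sketch of the relevant config knobs (values below are illustrative, not taken from this PR):

```python
# Illustrative only: the two fields forwarded to get_specs above live on the
# TransformerConfig. num_moe_experts=None selects a dense (Mistral-style)
# layer spec, while a positive count enables MoE layers (Mixtral-style).
from dataclasses import dataclass
from typing import Optional

@dataclass
class MoeKnobs:  # stand-in for the relevant TransformerConfig fields
    num_moe_experts: Optional[int] = None  # None -> dense MLP
    moe_grouped_gemm: bool = False         # batch expert GEMMs into grouped kernels

mistral = MoeKnobs()                                          # dense Mistral 7B
mixtral = MoeKnobs(num_moe_experts=8, moe_grouped_gemm=True)  # Mixtral 8x7B
```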
4 changes: 3 additions & 1 deletion nemo/collections/multimodal/parts/utils.py
@@ -135,8 +135,10 @@ def load_nemo_model_weights(nemo_path, sharded_state_dict=None):

# distributed checkpointing
if state_dict is None and sharded_state_dict is not None:

is_dist_ckpt = True
checkpoint = dict(state_dict=sharded_state_dict)

tmp_model_weights_ckpt = os.path.join(tmpdir, save_restore_connector.model_weights_ckpt)
tmp_model_weights_dir = os.path.splitext(tmp_model_weights_ckpt)[0]
assert os.path.isdir(tmp_model_weights_dir), f'Expected {tmp_model_weights_dir} to be a directory.'
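For context, a hedged sketch of how a sharded state dict like this is typically loaded with Megatron-Core's distributed checkpointing; the wrapper function is hypothetical, and only the `dict(state_dict=...)` wrapping is taken from the diff:

```python
# Hypothetical helper; assumes megatron.core's dist_checkpointing API.
from megatron.core import dist_checkpointing

def load_sharded_weights(sharded_state_dict, checkpoint_dir):
    # Wrap under 'state_dict' to mirror the NeMo checkpoint layout,
    # then load the sharded tensors in place from the checkpoint directory.
    checkpoint = dict(state_dict=sharded_state_dict)
    checkpoint = dist_checkpointing.load(
        sharded_state_dict=checkpoint, checkpoint_dir=checkpoint_dir
    )
    return checkpoint["state_dict"]
```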
@@ -501,7 +503,7 @@ def expand2square(pil_img, background_color):
result.paste(pil_img, ((height - width) // 2, 0))
return result

frames = expand2square(frames, tuple(int(x * 255) for x in processor.image_mean))
frames = [expand2square(frame, tuple(int(x * 255) for x in processor.image_mean)) for frame in frames]
frames = processor.preprocess(frames, return_tensors='pt')['pixel_values']
else:
frames = processor.preprocess(frames, return_tensors='pt')['pixel_values']
21 changes: 21 additions & 0 deletions nemo/collections/nlp/modules/common/text_generation_strategy.py
@@ -501,6 +501,27 @@ def neva_process_prompts(prompt, tokenizer, multimodal_cfg, num_media_latents, c
copy.deepcopy(list_data_dict), multimodal_cfg, num_media_latents
) # HARDCODED FOR NOW
data_dict = preprocess_llama_3(sources, tokenizer, multimodal_cfg)
elif multimodal_cfg["conv_template"] == "mistral":
record = {
'conversations': [
{
'from': 'human',
'value': prompt,
},
{
'from': 'gpt',
'value': '',
},
],
}
for turn in record['conversations']:
if turn.get('value') is not None:
turn['value'] = re.sub('<image>', f'{DEFAULT_IMAGE_TOKEN}\n', turn['value'])
list_data_dict.append(record)
sources = preprocess_multimodal(
copy.deepcopy(list_data_dict), multimodal_cfg, num_media_latents
) # HARDCODED FOR NOW
data_dict = preprocess_llama_2(sources, tokenizer, multimodal_cfg, is_mistral=True)
elif multimodal_cfg["conv_template"] == "v1":
record = {
'conversations': [
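End to end, the new branch turns a raw prompt into the same record shape the training data uses before handing it to `preprocess_llama_2(..., is_mistral=True)`. A small sketch, assuming `DEFAULT_IMAGE_TOKEN == "<image>"`:

```python
# Sketch of the record built by the mistral branch above
# (assumes DEFAULT_IMAGE_TOKEN == "<image>").
import re

DEFAULT_IMAGE_TOKEN = "<image>"
prompt = "<image> What is shown in this image?"

record = {
    'conversations': [
        {'from': 'human', 'value': prompt},
        {'from': 'gpt', 'value': ''},  # empty slot for the model's response
    ],
}
for turn in record['conversations']:
    if turn.get('value') is not None:
        turn['value'] = re.sub('<image>', f'{DEFAULT_IMAGE_TOKEN}\n', turn['value'])

# record then flows through preprocess_multimodal and
# preprocess_llama_2(..., is_mistral=True), producing the
# "<s>[INST] ... [/INST]" prompt format sketched earlier.
```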