Mistral + Mixtral Support for NeVa #9459

Merged: 15 commits, merged on Jul 8, 2024

Changes from all commits
28 changes: 24 additions & 4 deletions nemo/collections/multimodal/data/neva/conversation.py
@@ -43,6 +43,7 @@ class SeparatorStyle(Enum):
PLAIN = auto()
LLAMA_2 = auto()
LLAMA_3 = auto()
MISTRAL = auto()
NVGPT = auto()


@@ -94,11 +95,15 @@ def get_prompt(self):
ret += " "
else:
ret += role + ":"
elif self.sep_style == SeparatorStyle.LLAMA_2:
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
elif self.sep_style == SeparatorStyle.LLAMA_2 or self.sep_style == SeparatorStyle.MISTRAL:
if self.sep_style == SeparatorStyle.LLAMA_2:
wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n"
else:
wrap_sys = lambda msg: f"{msg}" + ("\n" if msg else "")
wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
ret = ""

if self.sep_style == SeparatorStyle.MISTRAL:
ret += DEFAULT_BOS_TOKEN
for i, (role, message) in enumerate(messages):
if i == 0:
assert message, "first message should not be none"
@@ -112,7 +117,10 @@ def get_prompt(self):
message = wrap_inst(message)
ret += self.sep + " " + message
else:
ret += " " + message + " " + self.sep2
if self.sep_style == SeparatorStyle.LLAMA_2:
ret += " " + message + " " + self.sep2
else:
ret += message + self.sep2
else:
ret += ""
ret = ret.lstrip(self.sep)
@@ -449,6 +457,17 @@ def dict(self):
version="v1_mmtag",
)

conv_mistral = Conversation(
system="",
roles=("USER", "ASSISTANT"),
version="mistral",
messages=(),
offset=0,
sep_style=SeparatorStyle.MISTRAL,
sep="",
sep2=DEFAULT_EOS_TOKEN,
)

default_conversation = conv_vicuna_v1
conv_templates = {
"default": conv_vicuna_v0,
@@ -466,6 +485,7 @@ def dict(self):
"nvgpt": conv_nvgpt,
"nv_steerlm": conv_nvgpt,
"nv_dpo": conv_nv_dpo,
"mistral": conv_mistral,
}

if __name__ == "__main__":
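For reference, a minimal standalone sketch (not part of the PR) of the prompt string the MISTRAL style above produces with conv_mistral's settings, assuming DEFAULT_BOS_TOKEN is "<s>" and DEFAULT_EOS_TOKEN is "</s>"; the function name render_mistral_prompt is illustrative only:

```python
# Illustrative only: mirrors the MISTRAL branch of get_prompt() above with
# conv_mistral's settings (system="", sep="", sep2=EOS). Assumed tokens:
BOS, EOS = "<s>", "</s>"

def render_mistral_prompt(turns, system=""):
    wrap_sys = lambda msg: f"{msg}" + ("\n" if msg else "")
    wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
    ret = BOS  # BOS is prepended only for the MISTRAL style
    for i, (role, message) in enumerate(turns):
        if i == 0:
            message = wrap_sys(system) + message
        if i % 2 == 0:                      # user turn
            ret += " " + wrap_inst(message)
        else:                               # assistant turn, no space before the answer
            ret += message + EOS
    return ret

print(render_mistral_prompt([("USER", "Describe <image>"), ("ASSISTANT", "A cat.")]))
# -> <s> [INST] Describe <image> [/INST]A cat.</s>
```

Note that, unlike the LLAMA_2 style, the answer follows [/INST] with no surrounding spaces and each assistant turn is terminated with the EOS token (conv_mistral.sep2).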
34 changes: 28 additions & 6 deletions nemo/collections/multimodal/data/neva/neva_dataset.py
@@ -426,6 +426,7 @@ def preprocess_llama_2(
sources: dict,
tokenizer,
cfg,
is_mistral: bool = False,
) -> Dict:
"""
Preprocesses sources for the LLaMA 2 model configuration.
@@ -442,7 +443,10 @@
- Dict: A dictionary containing tokenized and labeled data suitable for the LLaMA 2 model.
This includes tokens, labels, and any special processing as defined in the configuration.
"""
conv = conversation_lib.conv_llava_llama_2.copy()
if is_mistral:
conv = conversation_lib.conv_mistral.copy()
else:
conv = conversation_lib.conv_llava_llama_2.copy()
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}

# Apply prompt templates
@@ -477,7 +481,10 @@
labels = tokens.clone().detach()

# Mask labels
sep = "[/INST] "
if is_mistral:
sep = "[/INST]"
else:
sep = "[/INST] "
for conversation, target in zip(conversations, labels):
rounds = conversation.split(conv.sep2)
cur_len = 0
@@ -492,18 +499,23 @@
parts[0] += sep

round_len = len(tokenizer.text_to_ids(rou + conv.sep2))
instruction_len = len(tokenizer.text_to_ids(parts[0])) - 2

if is_mistral:
instruction_len = len(tokenizer.text_to_ids(parts[0])) - 1
else:
instruction_len = len(tokenizer.text_to_ids(parts[0])) - 2

if i > 0:
round_len -= 1 # Remove extra token added by sp tokenizer
else:
instruction_len += 1
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX

cur_len += round_len
target[cur_len:] = IGNORE_INDEX

# Check if masking working correctly
# print([x for x in zip(tokens[0].numpy().tolist(), labels[0].numpy().tolist())])
# masking_test =[x for x in zip(tokens[0].numpy().tolist(), labels[0].numpy().tolist())]
# print(masking_test)

if add_extra_token:
tokens = tokens[:, :-1].contiguous()
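To make the Mistral masking concrete, here is a small illustration (not part of the PR) of how a rendered conversation splits into rounds; it assumes "</s>" as conv_mistral.sep2 and IGNORE_INDEX as the label value that excludes tokens from the loss:

```python
# Illustrative only: shows which spans the Mistral masking targets.
conversation = (
    "<s> [INST] Describe <image> [/INST]A cat.</s>"
    " [INST] What color is it? [/INST]Black.</s>"
)

for rou in conversation.split("</s>"):
    if rou == "":
        break
    instruction, answer = rou.split("[/INST]")
    # Tokens covering `instruction + "[/INST]"` are set to IGNORE_INDEX in the labels,
    # so only the answer tokens ("A cat.", "Black.") contribute to the loss.
    print(repr(instruction + "[/INST]"), "->", repr(answer))
```

The differing -1/-2 offsets on instruction_len above reflect that the Mistral separator is "[/INST]" with no trailing space, which changes how many SentencePiece tokens the instruction prefix occupies.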
@@ -990,7 +1002,10 @@ def expand2square(pil_img, background_color):
result.paste(pil_img, ((height - width) // 2, 0))
return result

frames = expand2square(frames, tuple(int(x * 255) for x in self.processor.image_mean))
frames = [
expand2square(frame, tuple(int(x * 255) for x in self.processor.image_mean))
for frame in frames
]
frames = self.processor.preprocess(frames, return_tensors='pt')['pixel_values']
else:
frames = self.processor.preprocess(frames, return_tensors='pt')['pixel_values']
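For completeness, a self-contained sketch of the per-frame padding the list comprehension above applies (the earlier code passed the whole frame list to expand2square, which operates on a single PIL image); the CLIP-style image_mean value is an assumption:

```python
from PIL import Image

def expand2square(pil_img, background_color):
    # Pad the shorter side so the image becomes square, centred on a solid background.
    width, height = pil_img.size
    if width == height:
        return pil_img
    side = max(width, height)
    result = Image.new(pil_img.mode, (side, side), background_color)
    result.paste(pil_img, ((side - width) // 2, (side - height) // 2))
    return result

# Hypothetical usage: image_mean follows CLIP-style preprocessors (values in [0, 1]).
image_mean = (0.48145466, 0.4578275, 0.40821073)
frames = [Image.new("RGB", (320, 180), (127, 127, 127)) for _ in range(3)]
frames = [expand2square(f, tuple(int(x * 255) for x in image_mean)) for f in frames]
```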
@@ -1057,6 +1072,13 @@ def expand2square(pil_img, background_color):
self.tokenizer,
self.multimodal_cfg,
)
elif self.conv_template == "mistral":
data_dict = preprocess_llama_2(
sources,
self.tokenizer,
self.multimodal_cfg,
is_mistral=True,
)
elif self.conv_template == "plain":
data_dict = preprocess_plain(
sources,
@@ -75,7 +75,7 @@
HAVE_APEX = False

try:
from megatron.core import InferenceParams, dist_checkpointing, parallel_state
from megatron.core import InferenceParams, dist_checkpointing, parallel_state, tensor_parallel

[Check notice from Code scanning / CodeQL: Unused import. Import of 'dist_checkpointing' is not used.]
from megatron.core.models.gpt import GPTModel as MCoreGPTModel
from megatron.core.pipeline_parallel.schedules import get_forward_backward_func
from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
@@ -154,10 +154,34 @@
self.media = media

def forward(self, input_ids, **kwargs):
media = self.media # avoid change the signature of embedding forward function
media = self.media # avoid changing the signature of embedding forward function

# TODO: Refactor replace_media_embedding to account for MCore's embedding communication optimization
# https://github.com/NVIDIA/Megatron-LM/commit/ee423e7 changes the way we handle embeddings with sequence parallelism
# When using reduce_scatter_embeddings, word_embedding_tensor is now in the following shape: [sequence/tp, batch_size, hidden_size]
# replace_media_embedding currently expects [batch_size, sequence, hidden_size]

# Check if reduce_scatter_embeddings is enabled in the embedding forward function
apply_reduce_scatter = getattr(self, 'reduce_scatter_embeddings', False)

# Set reduce_scatter_embeddings to False to keep the word-embedding
# tensor's shape unchanged for replace_media_embedding
if apply_reduce_scatter:
self.reduce_scatter_embeddings = False

words_embeddings = super().forward(input_ids, **kwargs)
words_embeddings = self.replace_media_embeddings(input_ids, words_embeddings, media)

return self.replace_media_embeddings(input_ids, words_embeddings, media)
# Scatter embeddings back to each TP rank if reduce_scatter_embeddings is enabled
if apply_reduce_scatter:
words_embeddings = self._apply_reduce_scatter(words_embeddings)
self.reduce_scatter_embeddings = True

return words_embeddings

def _apply_reduce_scatter(self, embeddings):
embeddings = embeddings.transpose(0, 1).contiguous()
return tensor_parallel.mappings.scatter_to_sequence_parallel_region(embeddings)

def encode_vision_x(self, vision_x: torch.Tensor):
"""
@@ -193,7 +217,6 @@
def replace_media_embeddings(self, input_ids, inputs_embeds, media):
if media is None:
return inputs_embeds

batch_size, sequence_length, hidden_size = inputs_embeds.shape

# calculate media features without gradients
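In short, the forward pass above temporarily disables the embedding layer's built-in reduce-scatter so that replace_media_embeddings sees word embeddings shaped [batch, sequence, hidden], then transposes to [sequence, batch, hidden] and scatters along the sequence dimension itself. A hedged standalone sketch of that pattern, assuming an initialized Megatron-LM tensor/sequence-parallel environment; embed_with_media and its arguments are illustrative names:

```python
from megatron.core import tensor_parallel

def embed_with_media(embedding, input_ids, media, replace_media_embeddings, **kwargs):
    # Disable the embedding's own reduce-scatter so the output stays [batch, seq, hidden].
    apply_reduce_scatter = getattr(embedding, "reduce_scatter_embeddings", False)
    if apply_reduce_scatter:
        embedding.reduce_scatter_embeddings = False

    words_embeddings = embedding(input_ids, **kwargs)
    words_embeddings = replace_media_embeddings(input_ids, words_embeddings, media)

    if apply_reduce_scatter:
        # Restore [seq, batch, hidden] and scatter along the sequence dim across TP ranks.
        words_embeddings = tensor_parallel.mappings.scatter_to_sequence_parallel_region(
            words_embeddings.transpose(0, 1).contiguous()
        )
        embedding.reduce_scatter_embeddings = True
    return words_embeddings
```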
@@ -550,7 +573,12 @@
media_end_id=media_end_id,
mcore_gpt=self.mcore_gpt,
config=self.transformer_config,
transformer_layer_spec=get_specs(self.spec_name),
transformer_layer_spec=get_specs(
self.spec_name,
self.transformer_config.num_moe_experts,
self.transformer_config.moe_grouped_gemm,
self.transformer_engine,
),
vocab_size=self.cfg.get('override_vocab_size', self.padded_vocab_size),
max_sequence_length=self.cfg.get('encoder_seq_length', 512),
pre_process=pre_process,
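The extra get_specs arguments let the transformer layer spec account for MoE settings when the underlying LLM is Mixtral-style. A hypothetical illustration of the relevant config fields follows; the field names mirror the transformer_config attributes used above, the values are purely illustrative, and whether grouped GEMM is actually usable depends on the installed Megatron-LM/Transformer Engine versions:

```python
from megatron.core.transformer.transformer_config import TransformerConfig

# Hypothetical Mixtral-8x7B-like settings; only the MoE-related fields matter here.
transformer_config = TransformerConfig(
    num_layers=32,
    hidden_size=4096,
    num_attention_heads=32,
    num_moe_experts=8,       # dense models leave this unset
    moe_grouped_gemm=True,   # batch the per-expert GEMMs where supported
)
```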
4 changes: 3 additions & 1 deletion nemo/collections/multimodal/parts/utils.py
@@ -135,8 +135,10 @@ def load_nemo_model_weights(nemo_path, sharded_state_dict=None):

# distributed checkpointing
if state_dict is None and sharded_state_dict is not None:

is_dist_ckpt = True
checkpoint = dict(state_dict=sharded_state_dict)

tmp_model_weights_ckpt = os.path.join(tmpdir, save_restore_connector.model_weights_ckpt)
tmp_model_weights_dir = os.path.splitext(tmp_model_weights_ckpt)[0]
assert os.path.isdir(tmp_model_weights_dir), f'Expected {tmp_model_weights_dir} to be a directory.'
@@ -501,7 +503,7 @@ def expand2square(pil_img, background_color):
result.paste(pil_img, ((height - width) // 2, 0))
return result

frames = expand2square(frames, tuple(int(x * 255) for x in processor.image_mean))
frames = [expand2square(frame, tuple(int(x * 255) for x in processor.image_mean)) for frame in frames]
frames = processor.preprocess(frames, return_tensors='pt')['pixel_values']
else:
frames = processor.preprocess(frames, return_tensors='pt')['pixel_values']
21 changes: 21 additions & 0 deletions nemo/collections/nlp/modules/common/text_generation_strategy.py
@@ -501,6 +501,27 @@ def neva_process_prompts(prompt, tokenizer, multimodal_cfg, num_media_latents, c
copy.deepcopy(list_data_dict), multimodal_cfg, num_media_latents
) # HARDCODED FOR NOW
data_dict = preprocess_llama_3(sources, tokenizer, multimodal_cfg)
elif multimodal_cfg["conv_template"] == "mistral":
record = {
'conversations': [
{
'from': 'human',
'value': prompt,
},
{
'from': 'gpt',
'value': '',
},
],
}
for turn in record['conversations']:
if turn.get('value') is not None:
turn['value'] = re.sub('<image>', f'{DEFAULT_IMAGE_TOKEN}\n', turn['value'])
list_data_dict.append(record)
sources = preprocess_multimodal(
copy.deepcopy(list_data_dict), multimodal_cfg, num_media_latents
) # HARDCODED FOR NOW
data_dict = preprocess_llama_2(sources, tokenizer, multimodal_cfg, is_mistral=True)
elif multimodal_cfg["conv_template"] == "v1":
record = {
'conversations': [
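For orientation, a hedged sketch (not NeMo-API-exact) of the single-turn record the new "mistral" branch builds and of the prompt it ultimately yields under the conv_mistral template; the DEFAULT_IMAGE_TOKEN, BOS, and EOS values are assumptions:

```python
import re

DEFAULT_IMAGE_TOKEN = "<image>"  # assumed value of the constant used above
prompt = "<image> What is shown in this picture?"

record = {
    'conversations': [
        {'from': 'human', 'value': prompt},
        {'from': 'gpt', 'value': ''},  # empty assistant turn: generation starts after [/INST]
    ],
}
for turn in record['conversations']:
    if turn.get('value') is not None:
        turn['value'] = re.sub('<image>', f'{DEFAULT_IMAGE_TOKEN}\n', turn['value'])

# After preprocess_multimodal() handles the image token and the conv_mistral template
# renders the turns, the model is prompted with roughly:
#   "<s> [INST] <image>\n What is shown in this picture? [/INST]"
# and decoding stops at the EOS token ("</s>") configured as conv_mistral.sep2.
```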