diff --git a/examples/multimodal/multimodal_llm/neva/conf/neva_config.yaml b/examples/multimodal/multimodal_llm/neva/conf/neva_config.yaml index b41f15c384a8..0caf4beb6a12 100644 --- a/examples/multimodal/multimodal_llm/neva/conf/neva_config.yaml +++ b/examples/multimodal/multimodal_llm/neva/conf/neva_config.yaml @@ -181,6 +181,7 @@ model: additional_special_tokens: null # ["", "", "", "", "", ""] data: + packed_sequence: False num_workers: 8 dataloader_type: cyclic data_path: diff --git a/examples/multimodal/multimodal_llm/neva/eval/vqa_science.py b/examples/multimodal/multimodal_llm/neva/eval/vqa_science.py index 8ea267ac8116..62d8788067bb 100644 --- a/examples/multimodal/multimodal_llm/neva/eval/vqa_science.py +++ b/examples/multimodal/multimodal_llm/neva/eval/vqa_science.py @@ -79,7 +79,8 @@ def eval_model(args): cfg.base_model_file = args.model_base cfg.inference.images_base_path = args.image_folder cfg.tensor_model_parallel_size = args.tp - cfg.trainer.devices = args.tp + cfg.pipeline_model_parallel_size = args.pp + cfg.trainer.devices = args.tp * args.pp model, image_processor = create_neva_model_and_processor(cfg) length_params: LengthParam = { @@ -102,7 +103,8 @@ def eval_model(args): questions = get_chunk(questions, args.num_chunks, args.chunk_idx) answers_file = os.path.expanduser(args.answers_file) os.makedirs(os.path.dirname(answers_file), exist_ok=True) - ans_file = open(answers_file, "w") + if is_global_rank_zero(): + ans_file = open(answers_file, "w") for i, line in enumerate(tqdm(questions, disable=(not is_global_rank_zero()))): idx = line["id"] question = line['conversations'][0] @@ -123,7 +125,8 @@ def eval_model(args): sampling_params=sampling_params, inference_config=cfg, ) - # import pdb; pdb.set_trace() + if responses is None: + continue outputs = responses[0]["clean_response"] # prompt for answer @@ -139,22 +142,24 @@ def eval_model(args): outputs = responses[0]["clean_response"] outputs = outputs_reasoning + '\n The answer is ' + outputs - ans_id = shortuuid.uuid() - ans_file.write( - json.dumps( - { - "question_id": idx, - "prompt": cur_prompt, - "text": outputs, - "answer_id": ans_id, - "model_id": args.model_path, - "metadata": {}, - } + if is_global_rank_zero(): + ans_id = shortuuid.uuid() + ans_file.write( + json.dumps( + { + "question_id": idx, + "prompt": cur_prompt, + "text": outputs, + "answer_id": ans_id, + "model_id": args.model_path, + "metadata": {}, + } + ) + + "\n" ) - + "\n" - ) - ans_file.flush() - ans_file.close() + ans_file.flush() + if is_global_rank_zero(): + ans_file.close() if __name__ == "__main__": @@ -166,6 +171,7 @@ def eval_model(args): parser.add_argument("--answers-file", type=str, default="answer.jsonl") parser.add_argument("--conv-mode", type=str, default="llava_v0") parser.add_argument("--tp", type=int, default=1) + parser.add_argument("--pp", type=int, default=1) parser.add_argument("--num-chunks", type=int, default=1) parser.add_argument("--chunk-idx", type=int, default=0) parser.add_argument("--temperature", type=float, default=0.2) diff --git a/examples/multimodal/multimodal_llm/neva/neva_evaluation.py b/examples/multimodal/multimodal_llm/neva/neva_evaluation.py index bd3f975e4d54..d9d9a71db757 100644 --- a/examples/multimodal/multimodal_llm/neva/neva_evaluation.py +++ b/examples/multimodal/multimodal_llm/neva/neva_evaluation.py @@ -20,6 +20,7 @@ from nemo.collections.multimodal.parts.utils import create_neva_model_and_processor from nemo.collections.nlp.modules.common.transformer.text_generation import LengthParam, 
SamplingParam from nemo.core.config import hydra_runner +from nemo.utils.get_rank import is_global_rank_zero try: @@ -121,22 +122,27 @@ def forward_loop(): ) # ============== Quantization End ========================= - results = [] - for response, prompt in zip(responses, final_prompts): - prompt['full_text'] = response["clean_text"] - prompt['text'] = response["clean_response"] - prompt['model_id'] = cfg.neva_model_file - if 'image_path' in prompt: - prompt['image'] = prompt.pop('image_path') - if 'answer_id' not in prompt: - prompt['answer_id'] = 0 - if 'metadata' not in prompt: - prompt['metadata'] = {} - results.append(prompt) - - with open(cfg.output_file, 'w') as f: - for result in results: - f.write(json.dumps(result) + '\n') + # PP middle stages do not yield any responses + if responses is None: + return + + if is_global_rank_zero(): + results = [] + for response, prompt in zip(responses, final_prompts): + prompt['full_text'] = response["clean_text"] + prompt['text'] = response["clean_response"] + prompt['model_id'] = cfg.neva_model_file + if 'image_path' in prompt: + prompt['image'] = prompt.pop('image_path') + if 'answer_id' not in prompt: + prompt['answer_id'] = 0 + if 'metadata' not in prompt: + prompt['metadata'] = {} + results.append(prompt) + + with open(cfg.output_file, 'w') as f: + for result in results: + f.write(json.dumps(result) + '\n') if __name__ == '__main__': diff --git a/examples/multimodal/multimodal_llm/neva/sequence_packing/preprocess_dataset.py b/examples/multimodal/multimodal_llm/neva/sequence_packing/preprocess_dataset.py new file mode 100644 index 000000000000..ee96ff6489d3 --- /dev/null +++ b/examples/multimodal/multimodal_llm/neva/sequence_packing/preprocess_dataset.py @@ -0,0 +1,354 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Example Usage: +-------------- +This script preprocesses a dataset for the NeMo Multimodal Learning framework. It requires specifying paths for data, images, and the tokenizer model, among other parameters. + +Command: +python examples/multimodal/multimodal_llm/neva/sequence_packing/preprocess_dataset.py \ + --data_path=/path/to/LLaVA-Instruct-150K/llava_v1_5_mix665k_filtered.json \ + --image_folder=/path/to/LLaVA-Instruct-150K/images \ + --tokenizer_path=/path/to/checkpoints/tokenizer_add_special.model \ + --output_dir=/path/to/LLaVA-Instruct-150K/packed_seq_4096_336_v1 \ + --max_seq_length=12288 \ + --packing_algorithm=first_fit_shuffle \ + --hf_vision_encoder=openai/clip-vit-large-patch14-336 \ + --conv_template=v1 \ + --image_aspect_ratio=pad \ + --seed=42 + +Parameters: +----------- +--data_path: Path to the dataset file in JSON format. +--image_folder: Directory containing the images referenced in the dataset. +--tokenizer_path: Path to the tokenizer model. +--output_dir: Directory where the processed dataset will be stored. +--max_seq_length: The maximum sequence length of the model. +--packing_algorithm: Algorithm used for packing sequences. 
Defaults to 'first_fit_shuffle'. +--hf_vision_encoder: The Hugging Face vision encoder to use. Default is 'openai/clip-vit-large-patch14-336'. +--conv_template: Template for data conversion. Default is 'plain', with 'v1' as an alternative. +--image_aspect_ratio: The aspect ratio for processing images. Defaults to 'square', 'pad' for padding to maintain aspect ratio. +--seed: Seed for random operations in 'first_fit_shuffle'. +--hparams_file: Optional path to a YAML file containing additional hyperparameters. +""" + +import collections +import os +import random +import re +from argparse import ArgumentParser +from concurrent.futures import ThreadPoolExecutor, as_completed + +import numpy as np +import torch +from megatron.core.datasets.indexed_dataset import IndexedDataset, IndexedDatasetBuilder, get_bin_path, get_idx_path +from omegaconf import OmegaConf +from torch.utils.data import DataLoader +from tqdm import tqdm + +from nemo.collections.multimodal.data.neva.neva_dataset import make_supervised_data_module +from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer +from nemo.utils import logging + +PACKING_ALGOS = ['first_fit_decreasing', 'first_fit_shuffle', 'shuffle_and_pack'] + + +def first_fit(seq_lens, max_seq_length): + """ + Assigns sequences to bins using the First Fit algorithm, by integrating the search + and assignment within the same function. It moves bins that can no longer fit the minimum sequence length + to a completed bins list, avoiding direct modification of the bins list during iteration. + + Parameters: + - seq_lens: List of sequence lengths. + - max_seq_length: Maximum capacity of each bin. + + Returns: + - List of bins with assigned sequence lengths. + """ + min_seq_len = min(seq_lens) # Find the minimum sequence length + completed_bins = [] # Initialize the completed bins list + bins = [] # Initialize the bins list to store active bins + + for s in tqdm(seq_lens): # Iterate through each sequence length + found_bin = False + for i, abin in enumerate(bins[:]): # Iterate over a shallow copy of bins + if sum(abin) + min_seq_len > max_seq_length: + completed_bins.append(abin) # Add to completed bins + bins[i] = 'TO_REMOVE' # Mark this bin for removal + continue + if sum(abin) + s <= max_seq_length: # Check if the bin can fit the sequence + bins[i].append(s) # If so, add the sequence to this bin + found_bin = True + break + + if not found_bin: # If no existing bin can fit the sequence + bins.append([s]) # Open a new bin for this sequence + + # Clean up bins marked 'TO_REMOVE' + bins = [bin for bin in bins if bin != 'TO_REMOVE'] + + # Combine completed bins with any remaining active bins + all_bins = completed_bins + bins + return all_bins + + +def chunkify(lst, n): + """Yield successive n-sized chunks from lst.""" + for i in range(0, len(lst), n): + yield lst[i : i + n] + + +def parallel_first_fit(seq_lens, max_seq_length, chunk_size, num_workers): + """ + Assigns sequences to bins in parallel using the First Fit algorithm. + + Parameters: + - seq_lens: List of sequence lengths. + - max_seq_length: Maximum capacity of each bin. + - chunk_size: Size of chunks to divide seq_lens into for parallel processing. + - num_workers: Number of worker threads to use in the ThreadPoolExecutor. + + Returns: + - List of bins with assigned sequence lengths. 
+ """ + # Split the sequence lengths into chunks + chunks = list(chunkify(seq_lens, chunk_size)) + + # Function to process each chunk + def process_chunk(chunk): + return first_fit(chunk, max_seq_length) + + bins = [] # This will hold the final bins + with ThreadPoolExecutor(max_workers=num_workers) as executor: + # Submit each chunk to the executor + futures = [executor.submit(process_chunk, chunk) for chunk in chunks] + + # As each future completes, combine its bins with the final bins + for future in as_completed(futures): + bins.extend(future.result()) + + return bins + + +def first_fit_decreasing(seq_lens, max_seq_length): + """ + Assigns sequences to bins using the First Fit Decreasing algorithm. + + Parameters: + - seq_lens: List of sequence lengths. + - max_seq_length: Maximum capacity of each bin. + + Returns: + - List of bins with assigned sequence lengths. + """ + sorted_seq_lens = sorted(seq_lens, reverse=True) + return first_fit(sorted_seq_lens, max_seq_length) + + +def first_fit_shuffle(seq_lens, max_seq_length): + """ + Assigns sequences to bins using a shuffled version of the First Fit algorithm. + + Parameters: + - seq_lens: List of sequence lengths. + - max_seq_length: Maximum capacity of each bin. + + Returns: + - List of bins with assigned sequence lengths. + """ + shuffled_seq_lens = seq_lens[:] + np.random.shuffle(shuffled_seq_lens) + return parallel_first_fit(shuffled_seq_lens, max_seq_length, 20000, 32) + + +def shuffle_and_pack(seq_lens, max_seq_length): + """ + Assigns sequences to bins with shuffling, trying to maximize the packing efficiency. + After shuffling the sequences, they will be added to one bin in order. Once the bin cannot + take more sequences, we will move on to the next bin. + + Parameters: + - seq_lens: List of sequence lengths. + - max_seq_length: Maximum capacity of each bin. + + Returns: + - List of bins with assigned sequence lengths. + """ + shuffled_seq_lens = np.array(seq_lens) + np.random.shuffle(shuffled_seq_lens) + bins = [[]] + cur_bin_total = 0 + for s in tqdm(shuffled_seq_lens): + if cur_bin_total + s <= max_seq_length: + bins[-1].append(s) + cur_bin_total += s + else: + bins.append([s]) + cur_bin_total = s + return bins + + +def get_args(): + parser = ArgumentParser() + parser.add_argument("--data_path", type=str) + parser.add_argument("--image_folder", type=str) + parser.add_argument("--tokenizer_path", type=str) + parser.add_argument('--output_dir', required=True, type=str) + parser.add_argument("--max_seq_length", default=4096, type=int) + parser.add_argument('--packing_algorithm', default='first_fit_shuffle', choices=PACKING_ALGOS, type=str) + parser.add_argument("--hf_vision_encoder", default='openai/clip-vit-large-patch14-336', type=str) + parser.add_argument("--conv_template", default='plain', type=str) + parser.add_argument("--image_aspect_ratio", default='square', type=str) + parser.add_argument('--seed', default=0, type=int, help="Seed for shuffling, used with first_fit_shuffle.") + parser.add_argument( + "--hparams_file", + type=str, + default=os.path.join(os.path.dirname(__file__), '../conf/llava_config.yaml'), + required=False, + help="Path to the hparams file.", + ) + return parser.parse_args() + + +def pack_sequence(args, seq_lens): + """ + Packs sequences according to the specified algorithm in args. + + Parameters: + - args: Command line arguments. + - seq_lens: List of sequence lengths. + + Returns: + - List of bins with assigned sequence lengths. 
+ """ + np.random.seed(args.seed) + random.seed(args.seed) + + packing_fn = globals()[args.packing_algorithm] + bins = packing_fn(seq_lens, args.max_seq_length) + return bins + + +def main(): + torch.multiprocessing.set_sharing_strategy('file_system') + + args = get_args() + nemo_config = OmegaConf.load(args.hparams_file) + nemo_config.model.mm_cfg.vision_encoder.from_pretrained = args.hf_vision_encoder + nemo_config.model.data.data_path = args.data_path + nemo_config.model.data.image_folder = args.image_folder + nemo_config.model.data.conv_template = args.conv_template + nemo_config.model.data.image_aspect_ratio = args.image_aspect_ratio + + tokenizer = get_nmt_tokenizer(library="sentencepiece", tokenizer_model=args.tokenizer_path,) + train_ds = make_supervised_data_module(tokenizer=tokenizer, model_cfg=nemo_config.model)["train_dataset"] + train_dl = DataLoader(train_ds, num_workers=32, collate_fn=None, shuffle=False) + # Example shape: {'tokens': torch.Size([1, 344]), 'labels': torch.Size([1, 344]), 'image': torch.Size([1, 1, 3, 224, 224])} + + output_dir = args.output_dir + os.makedirs(output_dir, exist_ok=True) + logging.info(f"Output directory: {output_dir}") + + prefix_path = f"{output_dir}/packed_seq_dataset" + # Original Datasets to Sequence Lengths Files + builders = {} + for item_dict in tqdm(train_dl, desc="Building indexed datasets"): + item_dict = {k: v[0] for k, v in item_dict.items()} + seq_len = len(item_dict['tokens']) + if seq_len in builders: + builder = builders[seq_len] + else: + builder_path = get_bin_path(f"{prefix_path}/seqlen_{seq_len}") + logging.info(f"Creating builder for sequence length {seq_len} at {builder_path}") + builder = IndexedDatasetBuilder(builder_path, dtype=np.float32, multimodal=True) + builders[seq_len] = builder + builder.add_item(item_dict['tokens']) + builder.add_item(item_dict['labels']) + builder.add_item(item_dict['image'], 1) + builder.end_document() + del item_dict + + for seq_len, builder in builders.items(): + idx_path = get_idx_path(f"{prefix_path}/seqlen_{seq_len}") + logging.info(f"Finalizing builder for sequence length {seq_len} at {idx_path}") + builder.finalize(idx_path) + + # Packing Sequences into Bins + files = os.listdir(f"{output_dir}/packed_seq_dataset") + pattern = rf"seqlen_(\d+).bin" + seq_len_list = [] + for file in files: + match = re.match(pattern, file) + if match: + seq_len = int(match.group(1)) + seq_len_list.append(seq_len) + + aggregated_seq_lens = [] + doc_pop_order = {} + indexed_datasets = {} + for seq_len in seq_len_list: + dataset_path = f"{prefix_path}/seqlen_{seq_len}" + dataset = IndexedDataset(dataset_path, multimodal=True) + aggregated_seq_lens.extend([seq_len] * (len(dataset.document_indices) - 1)) + doc_pop_order[seq_len] = list(np.random.permutation(len(dataset.document_indices) - 1)) + indexed_datasets[seq_len] = dataset + + logging.info("Getting bins") + bins = pack_sequence(args, aggregated_seq_lens) + logging.info("Finished getting bins") + + num_bins = len(bins) + avg_bins_len = sum([len(x) for x in bins]) / num_bins + avg_bins_sum = sum([sum(x) for x in bins]) / num_bins + logging.info(f"Number of bins: {num_bins}, Average bin length: {avg_bins_len}, Average bin sum: {avg_bins_sum}") + + # Reading Sequence Lengths and Packing into New Files + final_builder_path = get_bin_path(f"{prefix_path}") + logging.info(f"Creating final builder at {final_builder_path}") + final_builder = IndexedDatasetBuilder(final_builder_path, dtype=np.float32, multimodal=True) + + for assignment in tqdm(bins, 
desc="Building final dataset"): + packed_items = collections.defaultdict(list) + packed_items["seq_indices"] = [0] + for seq_len in assignment: + doc_index = doc_pop_order[seq_len].pop() + doc_start = indexed_datasets[seq_len].document_indices[doc_index] + doc_end = indexed_datasets[seq_len].document_indices[doc_index + 1] + item_dict = { + "tokens": torch.tensor((indexed_datasets[seq_len][doc_start:doc_end][0])[0]), + "labels": torch.tensor((indexed_datasets[seq_len][doc_start:doc_end][0])[1]), + "image": torch.tensor((indexed_datasets[seq_len][doc_start:doc_end][0])[2]), + } + for key in ["tokens", "labels", "image"]: + packed_items[key].append(item_dict[key]) + packed_items["seq_indices"].append(packed_items["seq_indices"][-1] + seq_len) + + for key in ["seq_indices", "tokens", "labels", "image"]: + final_builder.add_item( + torch.tensor(packed_items[key]) if key == "seq_indices" else torch.cat(packed_items[key], dim=0), + 1 if key == "image" else 0, + ) + final_builder.end_document() + + idx_path = get_idx_path(f"{prefix_path}") + logging.info(f"Finalizing final builder at {idx_path}") + final_builder.finalize(idx_path) + logging.info(f"Number of bins: {num_bins}, Average bin length: {avg_bins_len}, Average bin sum: {avg_bins_sum}") + + +if __name__ == "__main__": + main() diff --git a/nemo/collections/multimodal/data/neva/neva_dataset.py b/nemo/collections/multimodal/data/neva/neva_dataset.py index 71d9bda12de1..ddd409e928b2 100644 --- a/nemo/collections/multimodal/data/neva/neva_dataset.py +++ b/nemo/collections/multimodal/data/neva/neva_dataset.py @@ -18,7 +18,7 @@ import re import tarfile from dataclasses import dataclass -from typing import Any, Dict, List, Sequence, Union +from typing import Any, Dict, List, Sequence, Tuple, Union import torch import torch.nn.functional as F @@ -49,6 +49,15 @@ MAX_NUM_IMAGES = 1 IGNORE_INDEX = -1 +try: + from megatron.core.datasets.indexed_dataset import IndexedDataset + + HAVE_MEGATRON_CORE = True + +except (ImportError, ModuleNotFoundError): + + HAVE_MEGATRON_CORE = False + class TarOrFolderImageLoader: """ @@ -781,12 +790,27 @@ class DataCollatorForSupervisedDataset(object): tokenizer: transformers.PreTrainedTokenizer def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: + packed_sequence = "cu_seqlens" in instances[0] max_len = max(instance['tokens'].shape[0] for instance in instances) max_len = (max_len - 1) // 64 * 64 + 64 for instance in instances: pad_len = max_len - instance['tokens'].shape[0] instance['tokens'] = F.pad(instance['tokens'], (0, pad_len), 'constant', 0) instance['labels'] = F.pad(instance['labels'], (0, pad_len), 'constant', -1) + if packed_sequence and instance["cu_seqlens"][-1] != max_len: + instance["cu_seqlens"] = torch.cat((instance["cu_seqlens"], torch.IntTensor([max_len])), 0) + + if packed_sequence: + max_len_cu = max(instance['cu_seqlens'].shape[0] for instance in instances) + max_len_image = max(instance['image'].shape[0] for instance in instances) + for instance in instances: + pad_len_cu = max_len_cu - instance['cu_seqlens'].shape[0] + instance['cu_seqlens'] = F.pad(instance['cu_seqlens'], (0, pad_len_cu), 'constant', max_len) + + x = instance['image'] + num_pad = max_len_image - x.shape[0] + pad_tensor = torch.zeros(num_pad, *x.shape[1:], dtype=x.dtype, device=x.device) + instance['image'] = torch.cat((x, pad_tensor), dim=0) batch = default_collate(instances) tokenizer = self.tokenizer @@ -796,13 +820,25 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: labels = 
batch['labels'] media = batch.get('image') - attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( - data=tokens, - eod_token=tokenizer.eos_id, - eod_mask_loss=model_cfg.data.get("eod_mask_loss", False), - reset_attention_mask=False, - reset_position_ids=False, - ) + if packed_sequence: + cu_seqlens = batch["cu_seqlens"] + position_ids = [] + for cu_seqlen in cu_seqlens: + position_ids.append([]) + for ind in range(0, len(cu_seqlen) - 1): + seqlen = cu_seqlen[ind + 1] - cu_seqlen[ind] + position_ids[-1].extend(list(range(seqlen))) + position_ids = torch.LongTensor(position_ids) + loss_mask = torch.ones(tokens.size(), dtype=torch.float, device=tokens.device) + attention_mask = torch.ones(tokens.size(), dtype=torch.long, device=tokens.device) + else: + attention_mask, loss_mask, position_ids = get_ltor_masks_and_position_ids( + data=tokens, + eod_token=tokenizer.eos_id, + eod_mask_loss=model_cfg.data.get("eod_mask_loss", False), + reset_attention_mask=False, + reset_position_ids=False, + ) loss_mask[labels == -1] = 0.0 tokens[tokens == -1] = 0 @@ -821,6 +857,8 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: 'position_ids': position_ids, 'media': media, } + if packed_sequence: + batch["cu_seqlens"] = cu_seqlens return batch @@ -859,3 +897,23 @@ def make_supervised_data_module(tokenizer, model_cfg) -> Dict: ) return dict(train_dataset=train_dataset, eval_dataset=train_dataset) + + +class NevaPackedSeqDatatset(Dataset): + def __init__(self, data_path: str, crop_size: Tuple[int, int] = (224, 224)): + self.ds = IndexedDataset(data_path) + self.crop_size = crop_size + + def __len__(self): + return len(self.ds.document_indices) - 1 + + def __getitem__(self, i) -> Dict[str, torch.Tensor]: + doc_start = self.ds.document_indices[i] + batch = { + "cu_seqlens": torch.IntTensor(self.ds[doc_start]), + "tokens": torch.LongTensor(self.ds[doc_start + 1]), + "labels": torch.LongTensor(self.ds[doc_start + 2]), + "image": torch.FloatTensor(self.ds[doc_start + 3]).reshape(-1, 3, *self.crop_size), + } + + return batch diff --git a/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py b/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py index cff8ab1a7b5f..5b50a8340b06 100644 --- a/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py +++ b/nemo/collections/multimodal/models/multimodal_llm/neva/neva_model.py @@ -21,6 +21,7 @@ import torch.nn.functional as F from einops import rearrange, repeat from omegaconf.dictconfig import DictConfig +from pkg_resources import packaging from pytorch_lightning.trainer.trainer import Trainer from transformers import CLIPVisionModel @@ -28,6 +29,7 @@ from nemo.collections.multimodal.data.neva.conversation import DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN from nemo.collections.multimodal.data.neva.neva_dataset import ( DataCollatorForSupervisedDataset, + NevaPackedSeqDatatset, make_supervised_data_module, ) from nemo.collections.multimodal.models.vision_language_foundation.clip.megatron_clip_models import ( @@ -43,7 +45,10 @@ AdapterName, MultimodalProjectorAdapterConfig, ) -from nemo.collections.nlp.modules.common.megatron.utils import average_losses_across_data_parallel_group +from nemo.collections.nlp.modules.common.megatron.utils import ( + average_losses_across_data_parallel_group, + get_iterator_k_split, +) from nemo.collections.nlp.modules.common.text_generation_utils import ( generate, get_computeprob_response, @@ -61,6 +66,7 @@ try: import 
apex.transformer.pipeline_parallel.utils + from apex.transformer.pipeline_parallel.utils import get_num_microbatches HAVE_APEX = True @@ -71,6 +77,7 @@ try: from megatron.core import InferenceParams, dist_checkpointing, parallel_state from megatron.core.models.gpt import GPTModel as MCoreGPTModel + from megatron.core.pipeline_parallel.schedules import get_forward_backward_func HAVE_MEGATRON_CORE = True @@ -385,14 +392,24 @@ def __init__( NevaBaseModel.__init__(self, mm_cfg, media_start_id, media_end_id, mcore_gpt, **kwargs) def freeze_llm(self, mm_cfg): - for param in chain(self.embedding.parameters(), self.decoder.parameters(), self.output_layer.parameters(),): + if parallel_state.is_pipeline_first_stage(ignore_virtual=True): + embedding_parameters = self.embedding.parameters() + else: + embedding_parameters = {} + if parallel_state.is_pipeline_last_stage(ignore_virtual=True): + output_layer_parameters = self.output_layer.parameters() + else: + output_layer_parameters = {} + + for param in chain(embedding_parameters, self.decoder.parameters(), output_layer_parameters,): param.requires_grad = False def forward( self, *args, **kwargs, ): media = kwargs.pop('media', None) - self.embedding.word_embeddings.set_media(media) + if parallel_state.is_pipeline_first_stage(ignore_virtual=True): + self.embedding.word_embeddings.set_media(media) return MCoreGPTModel.forward(self, *args, **kwargs) @@ -418,7 +435,8 @@ def forward( self, *args, **kwargs, ): media = kwargs.pop('media', None) - self.embedding.word_embeddings.set_media(media) + if parallel_state.is_pipeline_first_stage(ignore_virtual=True): + self.embedding.word_embeddings.set_media(media) return GPTModel.forward(self, *args, **kwargs) @@ -611,7 +629,73 @@ def forward(self, tokens, text_position_ids, attention_mask, labels, media=None) return output_tensor def fwd_bwd_step(self, dataloader_iter, forward_only, first_val_step=None): - return MegatronGPTModel.fwd_bwd_step(self, dataloader_iter, forward_only, first_val_step) + if parallel_state.get_pipeline_model_parallel_world_size() == 1: + return MegatronGPTModel.fwd_bwd_step(self, dataloader_iter, forward_only, first_val_step) + else: + batch, _, _ = next(dataloader_iter) + _, seq_length = batch['tokens'].shape + batch_iter = get_iterator_k_split(batch, get_num_microbatches()) + + # handle asynchronous grad reduction + no_sync_func = None + grad_sync_func = None + param_sync_func = None + if not forward_only and self.with_distributed_adam: + no_sync_func = partial(self._optimizer.no_sync, greedy_grad_copy=self.megatron_amp_O2,) + grad_sync_func = self.reduce_overlap_gradients + param_sync_func = self.sync_overlap_parameters + + # pipeline schedules will get these from self.model.config + for module in self.get_model_module_list(): + module.config.no_sync_func = no_sync_func + module.config.grad_sync_func = grad_sync_func + module.config.param_sync_func = param_sync_func + + # run forward and backwards passes for an entire global batch + # we do this inside training_step to support pipeline parallelism + fwd_bwd_function = get_forward_backward_func() + # print(f"{torch.distributed.get_rank()}: {parallel_state.is_pipeline_last_stage()} {fwd_bwd_function}") + + # TODO @akhattar: add num_micro_batches_with_partial_activation_checkpoints when ready + losses_reduced_per_micro_batch = fwd_bwd_function( + forward_step_func=self.get_forward_output_and_loss_func(forward_only), + data_iterator=self._make_data_iterator_list(batch_iter), + model=self.model, + num_microbatches=get_num_microbatches(), + 
forward_only=forward_only, + seq_length=seq_length, + micro_batch_size=self.cfg.micro_batch_size, + first_val_step=first_val_step, + ) + + # only the last stages of the pipeline return losses + if losses_reduced_per_micro_batch: + if (not forward_only) or self.cfg.data.get('validation_drop_last', True): + # average loss across micro batches + loss_tensors_list = [loss_reduced['avg'] for loss_reduced in losses_reduced_per_micro_batch] + loss_tensor = torch.concat(loss_tensors_list) + loss_mean = loss_tensor.mean() + else: + # Get the total loss since micro batches sizes are not uniform + loss_sum_tensors_list = [ + loss_sum['loss_sum_and_ub_size'] + for loss_sum in losses_reduced_per_micro_batch + if loss_sum['loss_sum_and_ub_size'][1] > 0 + ] + loss_sum = ( + torch.vstack(loss_sum_tensors_list).sum(axis=0) + if len(loss_sum_tensors_list) > 0 + else torch.tensor([0.0, 0.0]).cuda() + ) + return loss_sum + else: + # we're not on the last pipeline stage so no losses + if forward_only: + loss_mean = [] + else: + loss_mean = torch.tensor(0.0).cuda() + + return loss_mean def training_step(self, dataloader_iter): """ @@ -631,7 +715,9 @@ def loss_func(output_tensor, loss_mask): return loss_for_ub, dict(avg=reduced_loss[0].unsqueeze(0)) def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_layers=None): - batch, _, _ = next(dataloader_iter) + batch = next(dataloader_iter) + if isinstance(batch, tuple): + batch = batch[0] if parallel_state.get_pipeline_model_parallel_world_size() == 1: for k in batch.keys(): if self.get_attention_mask_from_fusion: @@ -644,28 +730,36 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_ for k in batch.keys(): if self.get_attention_mask_from_fusion: batch[k] = ( - batch[k].cuda(non_blocking=True) if k in ['tokens', 'position_ids', 'media'] else None + batch[k].cuda(non_blocking=True) + if k in ['tokens', 'position_ids', 'media', 'cu_seqlens'] + else None ) else: batch[k] = ( batch[k].cuda(non_blocking=True) - if k in ['tokens', 'position_ids', 'attention_mask', 'media'] + if k in ['tokens', 'position_ids', 'attention_mask', 'media', 'cu_seqlens'] else None ) elif parallel_state.is_pipeline_last_stage(): # Last pipeline stage needs the labels, loss_mask, and attention_mask for k in batch.keys(): if self.get_attention_mask_from_fusion: - batch[k] = batch[k].cuda(non_blocking=True) if k in ['labels', 'loss_mask'] else None + batch[k] = ( + batch[k].cuda(non_blocking=True) + if k in ['labels', 'loss_mask', 'cu_seqlens'] + else None + ) else: batch[k] = ( batch[k].cuda(non_blocking=True) - if k in ['labels', 'loss_mask', 'attention_mask'] + if k in ['labels', 'loss_mask', 'attention_mask', 'cu_seqlens'] else None ) else: # Intermediate pipeline stage doesn't need any inputs - batch = {k: None for k in ['tokens', 'position_ids', 'attention_mask', 'labels', 'media']} + batch = { + k: None for k in ['tokens', 'position_ids', 'attention_mask', 'labels', 'media', 'loss_mask'] + } forward_args = { 'input_ids': batch['tokens'], @@ -678,16 +772,40 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_ if self.use_loss_mask: forward_args['loss_mask'] = batch['loss_mask'] forward_args['checkpoint_activations_all_layers'] = checkpoint_activations_all_layers + else: + if 'cu_seqlens' in batch: # packed sequence + # these args are passed eventually into TEDotProductAttention.forward() + cu_seqlens = batch['cu_seqlens'].squeeze() # remove batch size dimension (mbs=1) + max_seqlen = 
batch['max_seqlen'].squeeze() if 'max_seqlen' in batch else None + + try: + from megatron.core.packed_seq_params import PackedSeqParams + except (ImportError, ModuleNotFoundError) as e: + mcore_version = packaging.version.Version(version('megatron-core')) + logging.error( + f"megatron-core v{mcore_version} does not support training with packed sequence. " + "Please use megatron-core >= 0.5.0, or set model.data.train_ds.packed_sequence=False" + ) + raise e + forward_args['packed_seq_params'] = PackedSeqParams( + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_kv=max_seqlen, + qkv_format='thd', + ) output_tensor = model(**forward_args) - return output_tensor, partial(loss_func, loss_mask=batch['loss_mask']) + return output_tensor, partial(loss_func, loss_mask=batch.get('loss_mask')) return fwd_output_and_loss_func def get_forward_output_only_func(self): def fwd_output_only_func(dataloader_iter, model): - batch, _, _ = next(dataloader_iter) + batch = next(dataloader_iter) + if isinstance(batch, tuple): + batch = batch[0] extra_arg = {} ( tokens, @@ -859,9 +977,14 @@ def setup(self, stage=None): def build_train_valid_test_datasets(self): logging.info('Building Neva datasets.') - ds_dict = make_supervised_data_module(tokenizer=self.tokenizer, model_cfg=self.cfg,) - self._train_ds = ds_dict["train_dataset"] - self._validation_ds = ds_dict["eval_dataset"] + if self.cfg.data.get("packed_sequence", False): + assert self.cfg.micro_batch_size == 1, "Micro batch size must be 1 if using packed sequence" + self._train_ds = NevaPackedSeqDatatset(self.cfg.data.data_prefix, self.cfg.data.get("crop_size")) + self._validation_ds = NevaPackedSeqDatatset(self.cfg.data.data_prefix, self.cfg.data.get("crop_size")) + else: + ds_dict = make_supervised_data_module(tokenizer=self.tokenizer, model_cfg=self.cfg,) + self._train_ds = ds_dict["train_dataset"] + self._validation_ds = ds_dict["eval_dataset"] return self._train_ds, self._validation_ds @@ -872,12 +995,17 @@ def build_pretraining_data_loader( logging.info(f'Building dataloader with consumed samples: {consumed_samples}') # Megatron sampler + if parallel_state.get_pipeline_model_parallel_world_size() == 1: + micro_batch_size = self.cfg.micro_batch_size + else: + micro_batch_size = self.cfg.global_batch_size // parallel_state.get_data_parallel_world_size() + if hasattr(self.cfg.data, 'dataloader_type') and self.cfg.data.dataloader_type is not None: if self.cfg.data.dataloader_type == 'single': batch_sampler = MegatronPretrainingSampler( total_samples=len(dataset), consumed_samples=consumed_samples, - micro_batch_size=self.cfg.micro_batch_size, + micro_batch_size=micro_batch_size, data_parallel_rank=parallel_state.get_data_parallel_rank(), data_parallel_size=parallel_state.get_data_parallel_world_size(), drop_last=drop_last, @@ -889,7 +1017,7 @@ def build_pretraining_data_loader( dataset=dataset, total_samples=len(dataset), consumed_samples=consumed_samples, - micro_batch_size=self.cfg.micro_batch_size, + micro_batch_size=micro_batch_size, data_parallel_rank=parallel_state.get_data_parallel_rank(), data_parallel_size=parallel_state.get_data_parallel_world_size(), drop_last=self.cfg.get('drop_last', True), @@ -953,14 +1081,9 @@ def load_state_dict(self, state_dict, strict=False): def on_load_checkpoint(self, checkpoint) -> None: pass - # if self.mcore_gpt: - # state_dict = checkpoint["state_dict"] - # self.load_state_dict(state_dict) def sharded_state_dict(self, prefix: str = ''): return None - # sharded_state_dict = 
MegatronGPTModel.sharded_state_dict(self, prefix) - # return sharded_state_dict def predict_step(self, batch: Any, batch_idx: int, dataloader_idx: Optional[int] = None) -> Any: inference_config = self.get_inference_config() diff --git a/nemo/collections/multimodal/parts/utils.py b/nemo/collections/multimodal/parts/utils.py index 723e965eb8a8..71c28cf00855 100644 --- a/nemo/collections/multimodal/parts/utils.py +++ b/nemo/collections/multimodal/parts/utils.py @@ -320,7 +320,7 @@ def dummy(): def create_neva_model_and_processor(cfg): - from nemo.collections.multimodal.models.neva.neva_model import MegatronNevaModel + from nemo.collections.multimodal.models.multimodal_llm.neva.neva_model import MegatronNevaModel plugins = [] if cfg.get('cluster_type', None) == 'BCP': @@ -366,6 +366,7 @@ def create_neva_model_and_processor(cfg): neva_cfg.precision = trainer.precision neva_cfg.mm_cfg.llm.from_pretrained = cfg.get('base_model_file', None) neva_cfg.apply_rope_fusion = False + neva_cfg.fp8 = False # neva_cfg.mm_cfg.vision_encoder.from_pretrained = None model = MegatronNevaModel.restore_from( diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index c2e1f0ed48b7..7a2f3459470c 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -784,7 +784,11 @@ def training_step(self, dataloader_iter): self._optimizer._finish_bucket_grad_sync() elif self.megatron_amp_O2: # when using pipeline parallelism grads must be all-reduced after the pipeline (not asynchronously) - if self.cfg.get('pipeline_model_parallel_size', 1) > 1 or self.cfg.get('sequence_parallel', False): + if ( + self.cfg.get('pipeline_model_parallel_size', 1) > 1 + or self.cfg.get('sequence_parallel', False) + or not self.cfg.get('async_grad_allreduce', True) + ): # main grads are stored in the MainParamsOptimizer wrapper self._optimizer.allreduce_main_grads() else: diff --git a/nemo/collections/nlp/modules/common/text_generation_utils.py b/nemo/collections/nlp/modules/common/text_generation_utils.py index d130322404b6..b50c9de682f7 100644 --- a/nemo/collections/nlp/modules/common/text_generation_utils.py +++ b/nemo/collections/nlp/modules/common/text_generation_utils.py @@ -173,6 +173,10 @@ def megatron_neva_generate(model, prompt_dict_list, length_params, sampling_para **strategy_args, ) + # Middle stages of PP will return None + if response is None: + continue + # Regular expression pattern to match the sequence pattern = re.compile(rf'{DEFAULT_IM_START_TOKEN}( ⁇ )+{DEFAULT_IM_END_TOKEN}') pattern_nvgpt = re.compile(rf'{DEFAULT_IM_START_TOKEN}({DEFAULT_IMAGE_PATCH_TOKEN})+{DEFAULT_IM_END_TOKEN}') diff --git a/nemo/collections/vision/data/megatron/data_samplers.py b/nemo/collections/vision/data/megatron/data_samplers.py index 82fc49990c49..2f63e675731b 100644 --- a/nemo/collections/vision/data/megatron/data_samplers.py +++ b/nemo/collections/vision/data/megatron/data_samplers.py @@ -67,7 +67,9 @@ def __iter__(self): random_idx = torch.randperm(bucket_size, generator=g).tolist() idx_range = [start_idx + x for x in random_idx[bucket_offset:]] else: - full_bucket_size = (self.total_samples // self.micro_batch_size) * self.micro_batch_size + full_bucket_size = ( + self.total_samples // self.micro_batch_times_data_parallel_size + ) * self.micro_batch_times_data_parallel_size full_bucket_offset = current_epoch_samples g = torch.Generator() 
g.manual_seed(self.epoch)
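
Note on the packed-sequence data flow (illustrative sketch, not part of the diff): the preprocessing script above packs variable-length samples into bins, and DataCollatorForSupervisedDataset turns each bin into a cu_seqlens tensor holding the cumulative sub-sequence offsets, padded so the final entry equals the padded token length; position ids then restart at zero for every sub-sequence. The helpers below (bin_to_cu_seqlens, cu_seqlens_to_position_ids) are hypothetical names introduced only to illustrate that mapping, assuming micro_batch_size=1 as the packed path requires.

    import torch

    def bin_to_cu_seqlens(bin_seq_lens, padded_len):
        # Cumulative offsets of the packed sub-sequences; a trailing entry is
        # appended so the last offset matches the padded token length, mirroring
        # how the collator extends instance["cu_seqlens"] up to max_len.
        cu = [0]
        for s in bin_seq_lens:
            cu.append(cu[-1] + s)
        if cu[-1] != padded_len:
            cu.append(padded_len)
        return torch.IntTensor(cu)

    def cu_seqlens_to_position_ids(cu_seqlens):
        # Position ids restart at 0 for every packed sub-sequence, matching the
        # packed_sequence branch of DataCollatorForSupervisedDataset.__call__.
        pos = []
        for start, end in zip(cu_seqlens[:-1].tolist(), cu_seqlens[1:].tolist()):
            pos.extend(range(end - start))
        return torch.LongTensor(pos)

    cu = bin_to_cu_seqlens([344, 512, 1200], 4096)
    print(cu)                                    # tensor([0, 344, 856, 2056, 4096], dtype=torch.int32)
    print(cu_seqlens_to_position_ids(cu).shape)  # torch.Size([4096])

At run time these cu_seqlens are passed to PackedSeqParams(cu_seqlens_q=..., cu_seqlens_kv=..., qkv_format='thd') in the model's forward-and-loss function, which requires megatron-core >= 0.5.0 as the error message in the diff notes.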