From 5fdbbd292536ceee61945c2c0eaa197a84c447e2 Mon Sep 17 00:00:00 2001 From: Patrick Date: Mon, 4 Nov 2024 15:31:34 +0000 Subject: [PATCH 1/7] minor update to train scripts --- examples/multimodal/pretrain_mistral_clip.sh | 11 +++++------ examples/multimodal/run_text_generation.py | 18 ++++++++++++------ examples/multimodal/sft_mistral_clip.sh | 17 ++++++----------- .../multimodal/text_generation_mistral_clip.sh | 13 ++++++++++--- 4 files changed, 33 insertions(+), 26 deletions(-) diff --git a/examples/multimodal/pretrain_mistral_clip.sh b/examples/multimodal/pretrain_mistral_clip.sh index b06dbfe53c..eefdff0673 100755 --- a/examples/multimodal/pretrain_mistral_clip.sh +++ b/examples/multimodal/pretrain_mistral_clip.sh @@ -29,7 +29,7 @@ if [[ -z $TOKENIZER_MODEL ]]; then exit 1 fi -CHECKPOINT_DIR="${WORKSPACE}/${LOAD_NAME}/checkpoints" +CHECKPOINT_DIR="$LOAD_NAME" DATA_TRAIN="${SOURCE}/examples/multimodal/pretrain_dataset.yaml" @@ -62,7 +62,6 @@ OPTIONS=" \ --num-query-groups 8 \ --no-masked-softmax-fusion \ --num-workers ${NW} \ - --exit-duration-in-mins 230 \ --use-flash-attn \ --untie-embeddings-and-output-weights \ --disable-bias-linear \ @@ -82,7 +81,7 @@ OPTIONS=" \ --max-position-embeddings 4096 \ --ffn-hidden-size 14336 \ --train-iters 20000 \ - --micro-batch-size 1 \ + --micro-batch-size 16 \ --global-batch-size ${BZ} \ --lr-decay-iters 20000 \ --lr-warmup-fraction .01 \ @@ -93,10 +92,10 @@ OPTIONS=" \ --eval-iters 10 \ --eval-interval 1000 \ --tokenizer-type HuggingFaceTokenizer \ - --tokenizer-model ${WORKSPACE}/${TOKENIZER_MODEL} \ + --tokenizer-model ${TOKENIZER_MODEL} \ --data-path ${DATA_TRAIN} \ --prompt-path ${SOURCE}/examples/multimodal/manual_prompts.json \ - --save-interval 1000 \ + --save-interval 2000 \ --save ${FINETUNE_DIR} \ --load ${FINETUNE_DIR} \ --dataloader-save ${FINETUNE_DIR}/dataloader \ @@ -128,4 +127,4 @@ OPTIONS=" \ export NVTE_APPLY_QK_LAYER_SCALING=0 export NVTE_ALLOW_NONDETERMINISTIC_ALGO=${NONDETERMINISTIC_ATTN} -torchrun --nproc_per_node 8 examples/multimodal/train.py ${OPTIONS} \ No newline at end of file +torchrun --nproc_per_node 4 examples/multimodal/train.py ${OPTIONS} \ No newline at end of file diff --git a/examples/multimodal/run_text_generation.py b/examples/multimodal/run_text_generation.py index 37d9072f0a..b781b77ff1 100644 --- a/examples/multimodal/run_text_generation.py +++ b/examples/multimodal/run_text_generation.py @@ -186,7 +186,17 @@ def __init__( max_num_tiles, use_thumbnail, ): - image_files = sorted(glob.glob(input_image_path + "/*")) + #image_files = sorted(glob.glob(input_image_path + "/*")) + + gts = json.load(open(gt_path)) + answers = defaultdict(list) + image_files = list() + + for gt in gts: + image_files.append(input_image_path + "/" + gt["image"]) + answers[gt["image"]] = gt['caption'] + + image_files = sorted(image_files) # Optionally, process only a subset of the input files. 
if num_partitions > 0: @@ -194,11 +204,7 @@ def __init__( len(image_files), num_samples_per_partition, num_partitions, partition_id ) image_files = image_files[lb:ub] - - gts = json.load(open(gt_path)) - answers = defaultdict(list) - for gt in gts["annotations"]: - answers[gt["image_id"]].append(gt['caption']) + self._image_files = image_files self._answers = answers diff --git a/examples/multimodal/sft_mistral_clip.sh b/examples/multimodal/sft_mistral_clip.sh index 46fc996055..38cf676738 100755 --- a/examples/multimodal/sft_mistral_clip.sh +++ b/examples/multimodal/sft_mistral_clip.sh @@ -24,17 +24,12 @@ if [[ -z $LOAD_NAME ]]; then exit 1 fi -if [[ -z $LOAD_ITER ]]; then - echo "Please set LOAD_ITER for pre-trained input model iteration." - exit 1 -fi - if [[ -z $TOKENIZER_MODEL ]]; then echo "Please set TOKENIZER_MODEL for tokenizer model name." exit 1 fi -CHECKPOINT_DIR="${WORKSPACE}/${LOAD_NAME}/checkpoints" +CHECKPOINT_DIR="$LOAD_NAME" DATA_TRAIN="${SOURCE}/examples/multimodal/sft_dataset.yaml" @@ -67,7 +62,6 @@ OPTIONS=" \ --num-query-groups 8 \ --no-masked-softmax-fusion \ --num-workers ${NW} \ - --exit-duration-in-mins 230 \ --use-flash-attn \ --untie-embeddings-and-output-weights \ --disable-bias-linear \ @@ -87,7 +81,7 @@ OPTIONS=" \ --max-position-embeddings 4096 \ --ffn-hidden-size 14336 \ --train-iters 20000 \ - --micro-batch-size 1 \ + --micro-batch-size 8 \ --global-batch-size ${BZ} \ --lr-decay-iters 20000 \ --lr-warmup-fraction .01 \ @@ -98,7 +92,7 @@ OPTIONS=" \ --eval-iters 10 \ --eval-interval 500 \ --tokenizer-type HuggingFaceTokenizer \ - --tokenizer-model ${WORKSPACE}/${TOKENIZER_MODEL} \ + --tokenizer-model ${TOKENIZER_MODEL} \ --data-path ${DATA_TRAIN} \ --prompt-path ${SOURCE}/examples/multimodal/manual_prompts.json \ --save-interval 500 \ @@ -115,6 +109,7 @@ OPTIONS=" \ --log-params-norm \ --log-num-zeros-in-grad \ --eod-mask-loss \ + --bf16 \ --freeze-ViT \ --patch-dim 14 \ --img-h 336 \ @@ -124,10 +119,10 @@ OPTIONS=" \ --language-model-type=mistral_7b \ --disable-vision-class-token \ ${EXTRA_ARGS} \ - --distributed-timeout-minutes 60 \ + --distributed-timeout-minutes 60 " export NVTE_APPLY_QK_LAYER_SCALING=0 export NVTE_ALLOW_NONDETERMINISTIC_ALGO=${NONDETERMINISTIC_ATTN} -torchrun --nproc_per_node 8 examples/multimodal/train.py ${OPTIONS} +torchrun --nproc_per_node 4 --master-port 29700 examples/multimodal/train.py ${OPTIONS} diff --git a/examples/multimodal/text_generation_mistral_clip.sh b/examples/multimodal/text_generation_mistral_clip.sh index 6423464e6d..3edad3d592 100755 --- a/examples/multimodal/text_generation_mistral_clip.sh +++ b/examples/multimodal/text_generation_mistral_clip.sh @@ -6,6 +6,7 @@ export NVTE_APPLY_QK_LAYER_SCALING=0 GROUNDTRUTH_PATH="placeholder" NUM_FRAMES=1 +NUM_GPUS=4 while [[ $# -gt 0 ]]; do case $1 in @@ -49,6 +50,11 @@ while [[ $# -gt 0 ]]; do shift shift ;; + --num-gpus) + NUM_GPUS="$2" + shift + shift + ;; -*|--*) echo "Invalid option $1" exit 1 @@ -63,7 +69,7 @@ END=0 for PARTITION_ID in $( eval echo {$START..$END} ) do - torchrun --nproc_per_node 8 examples/multimodal/run_text_generation.py \ + torchrun --nproc_per_node ${NUM_GPUS} examples/multimodal/run_text_generation.py \ --apply-layernorm-1p \ --attention-softmax-in-fp32 \ --use-flash-attn \ @@ -94,8 +100,9 @@ do --tokenizer-type HuggingFaceTokenizer \ --tokenizer-model ${TOKENIZER_PATH} \ --bf16 \ - --micro-batch-size 1 \ - --seq-length 2048 \ + --micro-batch-size 32 \ + --seq-length 576 \ + --decoder-seq-length 2048 \ --out-seq-length 12 \ --temperature 1.0 \ 
--img-h 336 \ From 30e02a299602f2c4963c3038a8d7f73290a76288 Mon Sep 17 00:00:00 2001 From: Patrick Date: Wed, 20 Nov 2024 19:15:41 +0000 Subject: [PATCH 2/7] Add convertion script to HF LlaVA --- examples/multimodal/config.py | 5 +- examples/multimodal/convert_to_hf.py | 501 +++++++++++++++++++ examples/multimodal/model.py | 8 +- examples/multimodal/pretrain_dataset.yaml | 4 +- examples/multimodal/pretrain_mistral_clip.sh | 25 +- examples/multimodal/sft_dataset.yaml | 4 +- examples/multimodal/sft_mistral_clip.sh | 5 +- 7 files changed, 535 insertions(+), 17 deletions(-) create mode 100644 examples/multimodal/convert_to_hf.py diff --git a/examples/multimodal/config.py b/examples/multimodal/config.py index cf48b131a7..e9f35df7fe 100644 --- a/examples/multimodal/config.py +++ b/examples/multimodal/config.py @@ -114,7 +114,7 @@ def get_vision_model_config(config, apply_query_key_layer_scaling): def get_vision_projection_config(config, hidden_size): config.gated_linear_unit = False config.bias_activation_fusion = False - config.add_bias_linear = False + config.add_bias_linear = True # This was changed to make it compatible with HF's LLava config.hidden_size = hidden_size # Used as the vision projection output size, i.e., the input to the language model. if config.language_model_type == "2b": config.ffn_hidden_size = 5440 @@ -126,7 +126,8 @@ def get_vision_projection_config(config, hidden_size): config.ffn_hidden_size = 14336 config.activation_func = torch.nn.functional.gelu elif config.language_model_type == "mistral_7b": - config.ffn_hidden_size = 14336 + # TODO: check what needs to be done for other models + config.ffn_hidden_size = hidden_size # This was changed to make it compatible with HF's LLava config.activation_func = torch.nn.functional.gelu return config diff --git a/examples/multimodal/convert_to_hf.py b/examples/multimodal/convert_to_hf.py new file mode 100644 index 0000000000..ff49af53e5 --- /dev/null +++ b/examples/multimodal/convert_to_hf.py @@ -0,0 +1,501 @@ +import argparse +import os +import torch + +from megatron.core import parallel_state +from megatron.core import dist_checkpointing +from megatron.core.tensor_parallel.random import model_parallel_cuda_manual_seed + +from transformers import ( + AutoConfig, + LlavaForConditionalGeneration, + AutoTokenizer, + CLIPVisionConfig, + AddedToken, + AutoImageProcessor, + LlavaProcessor, +) +from transformers import LlavaConfig + +from model import model_provider + + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--mcore-load-dir", required=True) + parser.add_argument("--hf-save-dir", required=True) + parser.add_argument("--original-text-model-id", required=True) + parser.add_argument("--original-vision-model-id", required=True) + parser.add_argument("--target-params-dtype", type=str, default="float16") + return parser.parse_args() + + +def main(): + initialize_distributed(tensor_model_parallel_size=1, pipeline_model_parallel_size=1) + model_parallel_cuda_manual_seed(123) + args = parse_args() + convert_mcore2hf(args) + + +def initialize_distributed(tensor_model_parallel_size=1, pipeline_model_parallel_size=1): + parallel_state.destroy_model_parallel() + + rank = int(os.environ['LOCAL_RANK']) + world_size = torch.cuda.device_count() + torch.cuda.set_device(rank) + torch.distributed.init_process_group(world_size=world_size, rank=rank) + # Megatron core distributed training initialization + parallel_state.initialize_model_parallel( + tensor_model_parallel_size=tensor_model_parallel_size, + 
pipeline_model_parallel_size=pipeline_model_parallel_size, + ) + +def convert_mcore2hf(args): + """Main function to convert MCore checkpoint to HF format""" + # TODO: add support for casting explicitly to dtype + dtype = getattr(torch, args.target_params_dtype) + + print(f"> Loading MCore checkpoints") + assert os.path.exists(f"{args.mcore_load_dir}/latest_checkpointed_iteration.txt") + assert os.path.isfile(f"{args.mcore_load_dir}/latest_checkpointed_iteration.txt") + + with open(f"{args.mcore_load_dir}/latest_checkpointed_iteration.txt", "r") as f: + iteration = int(f.read().strip()) + iter_dir = f"{args.mcore_load_dir}/iter_{iteration:07d}" + + # start by loading the args from the checkpoint + margs = dist_checkpointing.load_common_state_dict(iter_dir)['args'] + print(f"> Loaded args from checkpoint: {margs}") + args.tensor_model_parallel_size = 1 + + # load the model checkpoint itself + model = model_provider(args=margs) + sharded_state_dict = model.sharded_state_dict(prefix='') + checkpoint = dist_checkpointing.load( + sharded_state_dict=sharded_state_dict, checkpoint_dir=iter_dir + ) + + # import pdb; pdb.set_trace() + # create the HF config + hf_config = create_hf_config(args.original_text_model_id, args.original_vision_model_id, margs) + + # create the tokenizer and processor + processor = create_hf_processor(hf_config, args.original_text_model_id, args.original_vision_model_id) + processor.save_pretrained(args.hf_save_dir) + + # Convert the state dict + print(f"> Converting weights from MCore to HF format") + hf_state_dict = {} + + # Convert vision model weights + vision_state_dict = convert_mcore2hf_vision_model(checkpoint) + + # Convert language model weights + language_state_dict = convert_mcore2hf_language_model(checkpoint) + + # Convert projection weights + projection_state_dict = convert_mcore2hf_vision_projection(checkpoint) + + # Combine all state dicts + hf_state_dict.update(vision_state_dict) + hf_state_dict.update(language_state_dict) + hf_state_dict.update(projection_state_dict) + + # create the HF model + print(f"> Loading HF model and converted weights") + hf_model = LlavaForConditionalGeneration(config=hf_config) + hf_model.load_state_dict(hf_state_dict, strict=True) + + # extend the embeddings + extend_embeddings(hf_model, hf_config) + + print(f"> Saving HF model to {args.hf_save_dir}") + hf_model.save_pretrained(args.hf_save_dir) + + +def create_hf_config(original_text_model_id, original_vision_model_id, margs): + """Create HF config from Megatron checkpoint""" + # Extract model args from checkpoint + assert margs.transformer_impl == "transformer_engine" + assert margs.position_embedding_type == "rope" + assert margs.normalization == "RMSNorm" + assert margs.swiglu + # assert margs.untie_embeddings_and_output_weights + + # TODO: atm, both vision and language towers + # assume that there was an initial HF model + # that we can use to get the config + # however, in the ideal world, we would directly + # use the Megatron config to create the HF config + # but we leave this for later + + # Create CLIP vision config + # get config for openai/clip-vit-large-patch14-336 + vision_config = CLIPVisionConfig.from_pretrained(original_vision_model_id) + + # Create language model config (using LlamaConfig as base) + text_config = AutoConfig.from_pretrained(original_text_model_id) + + # Create final LLaVA config combining both + hf_config = LlavaConfig( + vision_config=vision_config, + text_config=text_config, + # Add any other LLaVA specific configs here + ) + return hf_config 
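+
+# Illustrative invocation (paths are placeholders; the model IDs below are the ones
+# assumed elsewhere in these examples). initialize_distributed() reads LOCAL_RANK and
+# expects one process per visible GPU, so run the script under torchrun, e.g.:
+#
+#   torchrun --nproc_per_node <num_visible_gpus> examples/multimodal/convert_to_hf.py \
+#       --mcore-load-dir  /path/to/mcore/checkpoints \
+#       --hf-save-dir     /path/to/hf/output \
+#       --original-text-model-id   mistralai/Mistral-7B-Instruct-v0.3 \
+#       --original-vision-model-id openai/clip-vit-large-patch14-336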
+ + +def create_hf_processor(hf_config, text_model_id, vision_model_id): + tokenizer = AutoTokenizer.from_pretrained(text_model_id) + tokenizer.add_tokens(AddedToken("", special=True, normalized=False), special_tokens=True) + tokenizer.add_special_tokens({"pad_token": ""}) + hf_config.image_token_index = tokenizer.convert_tokens_to_ids("") + hf_config.pad_token_id = tokenizer.pad_token_id + + try: + from transformers.models.llava.image_processing_llava import LlavaImageProcessor + image_processor = LlavaImageProcessor( + do_megatron_pp=True, + ) + except ImportError: + print("> WARNING: could not import LlavaImageProcessor, using AutoImageProcessor instead") + print("> This might lead to performance degradation due to slightly different image pre-processing") + image_processor = AutoImageProcessor.from_pretrained(vision_model_id) + + processor = LlavaProcessor(tokenizer=tokenizer, image_processor=image_processor) + return processor + + +def convert_mcore2hf_vision_model(mcore_sd): + """Convert vision model weights from Megatron to HF format""" + state_dict = {} + + # Vision embedding layers + state_dict.update( + { + "vision_tower.vision_model.embeddings.class_embedding": mcore_sd[ + "vision_model.class_token" + ].squeeze(), + "vision_tower.vision_model.embeddings.position_embedding.weight": mcore_sd[ + "vision_model.position_embeddings.weight" + ], + "vision_tower.vision_model.embeddings.patch_embedding.weight": mcore_sd[ + "vision_model.conv1.weight" + ], + "vision_tower.vision_model.pre_layrnorm.weight": mcore_sd["vision_model.ln_pre.weight"], + "vision_tower.vision_model.pre_layrnorm.bias": mcore_sd["vision_model.ln_pre.bias"], + } + ) + + # Vision transformer layers + # TODO: for some reason, this is not in the args?? + clip_num_layers = 24 + for layer_i in range(clip_num_layers): + hf_layer_prefix = f"vision_tower.vision_model.encoder.layers.{layer_i}" + mcore_layer_prefix = f"vision_model.decoder.layers.{layer_i}" + + # Get QKV weights and biases + qkv_weight = mcore_sd[f"{mcore_layer_prefix}.self_attention.linear_qkv.weight"] + qkv_bias = mcore_sd[f"{mcore_layer_prefix}.self_attention.linear_qkv.bias"] + + # Split into Q, K, V following CLIP's original ordering + hidden_size = qkv_weight.shape[1] + num_heads = 16 # CLIP ViT-L/14 uses 16 heads + head_dim = hidden_size // num_heads + + # Reshape and split QKV similar to language model approach + qkv = qkv_weight.reshape(num_heads, 3 * head_dim, -1) + + # Split into Q, K, V components + q_proj = qkv[:, :head_dim, :] + k_proj = qkv[:, head_dim : 2 * head_dim, :] + v_proj = qkv[:, 2 * head_dim :, :] + + # Reshape back to original dimensions + q_proj = q_proj.reshape(num_heads * head_dim, -1) + k_proj = k_proj.reshape(num_heads * head_dim, -1) + v_proj = v_proj.reshape(num_heads * head_dim, -1) + + # Do the same for biases + qkv_bias = qkv_bias.reshape(num_heads, 3 * head_dim) + q_bias = qkv_bias[:, :head_dim].reshape(-1) + k_bias = qkv_bias[:, head_dim : 2 * head_dim].reshape(-1) + v_bias = qkv_bias[:, 2 * head_dim :].reshape(-1) + + state_dict.update( + { + # Attention weights + f"{hf_layer_prefix}.self_attn.q_proj.weight": q_proj, + f"{hf_layer_prefix}.self_attn.k_proj.weight": k_proj, + f"{hf_layer_prefix}.self_attn.v_proj.weight": v_proj, + f"{hf_layer_prefix}.self_attn.q_proj.bias": q_bias, + f"{hf_layer_prefix}.self_attn.k_proj.bias": k_bias, + f"{hf_layer_prefix}.self_attn.v_proj.bias": v_bias, + # Output projection + f"{hf_layer_prefix}.self_attn.out_proj.weight": mcore_sd[ + 
f"{mcore_layer_prefix}.self_attention.linear_proj.weight" + ], + f"{hf_layer_prefix}.self_attn.out_proj.bias": mcore_sd[ + f"{mcore_layer_prefix}.self_attention.linear_proj.bias" + ], + # Layer norms + f"{hf_layer_prefix}.layer_norm1.weight": mcore_sd[ + f"{mcore_layer_prefix}.self_attention.linear_qkv.layer_norm_weight" + ], + f"{hf_layer_prefix}.layer_norm1.bias": mcore_sd[ + f"{mcore_layer_prefix}.self_attention.linear_qkv.layer_norm_bias" + ], + f"{hf_layer_prefix}.layer_norm2.weight": mcore_sd[ + f"{mcore_layer_prefix}.mlp.linear_fc1.layer_norm_weight" + ], + f"{hf_layer_prefix}.layer_norm2.bias": mcore_sd[ + f"{mcore_layer_prefix}.mlp.linear_fc1.layer_norm_bias" + ], + # MLP weights + f"{hf_layer_prefix}.mlp.fc1.weight": mcore_sd[ + f"{mcore_layer_prefix}.mlp.linear_fc1.weight" + ], + f"{hf_layer_prefix}.mlp.fc1.bias": mcore_sd[ + f"{mcore_layer_prefix}.mlp.linear_fc1.bias" + ], + f"{hf_layer_prefix}.mlp.fc2.weight": mcore_sd[ + f"{mcore_layer_prefix}.mlp.linear_fc2.weight" + ], + f"{hf_layer_prefix}.mlp.fc2.bias": mcore_sd[ + f"{mcore_layer_prefix}.mlp.linear_fc2.bias" + ], + } + ) + + # NOTE: for some reason, Megatron removes the post_layernorm weights and biases + # so we need to add them back in for the HF model, + # ensuring they perform the identity mapping + state_dict["vision_tower.vision_model.post_layernorm.weight"] = torch.ones(1024) + state_dict["vision_tower.vision_model.post_layernorm.bias"] = torch.zeros(1024) + + return state_dict + + +def convert_mcore2hf_vision_model_new(mcore_sd): + """Convert vision model weights from Megatron to HF format""" + state_dict = {} + + # Vision embedding layers + state_dict.update( + { + "vision_tower.vision_model.embeddings.class_embedding": mcore_sd[ + "vision_model.class_token" + ].squeeze(), + "vision_tower.vision_model.embeddings.position_embedding.weight": mcore_sd[ + "vision_model.position_embeddings.weight" + ], + "vision_tower.vision_model.embeddings.patch_embedding.weight": mcore_sd[ + "vision_model.conv1.weight" + ], + "vision_tower.vision_model.pre_layrnorm.weight": mcore_sd["vision_model.ln_pre.weight"], + "vision_tower.vision_model.pre_layrnorm.bias": mcore_sd["vision_model.ln_pre.bias"], + } + ) + + # Vision transformer layers + clip_num_layers = 24 + for layer_i in range(clip_num_layers): + hf_layer_prefix = f"vision_tower.vision_model.encoder.layers.{layer_i}" + mcore_layer_prefix = f"vision_model.decoder.layers.{layer_i}" + + # Get QKV weights and biases + qkv_weight = mcore_sd[f"{mcore_layer_prefix}.self_attention.linear_qkv.weight"] + qkv_bias = mcore_sd[f"{mcore_layer_prefix}.self_attention.linear_qkv.bias"] + + # Calculate dimensions + hidden_dim = qkv_weight.shape[1] + num_heads = mcore_sd["args"].num_attention_heads + head_dim = hidden_dim // num_heads + + # Split QKV weights and biases + q_weight, k_weight, v_weight = qkv_weight.chunk(3, dim=0) + q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0) + + # Ensure these are correctly assigned in the state_dict + state_dict.update( + { + f"{hf_layer_prefix}.self_attn.q_proj.weight": q_weight, + f"{hf_layer_prefix}.self_attn.k_proj.weight": k_weight, + f"{hf_layer_prefix}.self_attn.v_proj.weight": v_weight, + f"{hf_layer_prefix}.self_attn.q_proj.bias": q_bias, + f"{hf_layer_prefix}.self_attn.k_proj.bias": k_bias, + f"{hf_layer_prefix}.self_attn.v_proj.bias": v_bias, + } + ) + + # NOTE: for some reason, Megatron removes the post_layernorm weights and biases + # so we need to add them back in for the HF model, + # ensuring they perform the identity mapping + 
state_dict["vision_tower.vision_model.post_layernorm.weight"] = torch.ones(1024) + state_dict["vision_tower.vision_model.post_layernorm.bias"] = torch.zeros(1024) + return state_dict + + +def convert_mcore2hf_language_model(mcore_sd): + """Convert language model weights from Megatron to HF format""" + state_dict = {} + + # Embeddings + state_dict["language_model.model.embed_tokens.weight"] = mcore_sd[ + "language_model.embedding.word_embeddings.weight" + ] + + # Final layer norm and output + state_dict["language_model.model.norm.weight"] = mcore_sd[ + "language_model.decoder.final_layernorm.weight" + ] + state_dict["language_model.lm_head.weight"] = mcore_sd["language_model.output_layer.weight"] + + # Transformer layers + for layer_i in range(mcore_sd["args"].num_layers): + mcore_prefix = f"language_model.decoder.layers.{layer_i}" + hf_prefix = f"language_model.model.layers.{layer_i}" + + # Layer norms + state_dict.update( + { + f"{hf_prefix}.input_layernorm.weight": mcore_sd[ + f"{mcore_prefix}.self_attention.linear_qkv.layer_norm_weight" + ], + f"{hf_prefix}.post_attention_layernorm.weight": mcore_sd[ + f"{mcore_prefix}.mlp.linear_fc1.layer_norm_weight" + ], + } + ) + + # Attention weights + qkv_weight = mcore_sd[f"{mcore_prefix}.self_attention.linear_qkv.weight"] + # Ensure the shape is divisible by 3 + + # load transformer llava and do the same + # from transformers import LlavaForConditionalGeneration + # llava_model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf") + + hidden_size = qkv_weight.shape[1] + num_kv_heads = mcore_sd["args"].num_query_groups + num_heads = mcore_sd["args"].num_attention_heads + num_queries_per_group = num_heads // num_kv_heads + head_dim = hidden_size // num_heads + + qkv_size, _ = qkv_weight.size() + expected_qkv_size = num_kv_heads * (num_queries_per_group + 2) * head_dim + if qkv_size != expected_qkv_size: + raise ValueError("qkv_size does not match expected size") + + qkv = qkv_weight.reshape(num_kv_heads, (num_queries_per_group + 2) * head_dim, -1) + + # Split qkv into q_proj, k_proj, v_proj + q_proj = qkv[:, : num_queries_per_group * head_dim, :] + k_proj = qkv[ + :, num_queries_per_group * head_dim : (num_queries_per_group + 1) * head_dim, : + ] + v_proj = qkv[ + :, (num_queries_per_group + 1) * head_dim : (num_queries_per_group + 2) * head_dim, : + ] + + # Reshape projections to match HuggingFace format + q_proj = q_proj.reshape(num_kv_heads * num_queries_per_group * head_dim, -1) + k_proj = k_proj.reshape(num_kv_heads * head_dim, -1) + v_proj = v_proj.reshape(num_kv_heads * head_dim, -1) + + # import pdb; pdb.set_trace() + + state_dict.update( + { + f"{hf_prefix}.self_attn.q_proj.weight": q_proj, + f"{hf_prefix}.self_attn.k_proj.weight": k_proj, + f"{hf_prefix}.self_attn.v_proj.weight": v_proj, + f"{hf_prefix}.self_attn.o_proj.weight": mcore_sd[ + f"{mcore_prefix}.self_attention.linear_proj.weight" + ], + } + ) + + # MLP weights + # Note: In LLaMA, gate_proj and up_proj together form what was fc1 in the original architecture + fc1_weight = mcore_sd[f"{mcore_prefix}.mlp.linear_fc1.weight"] + gate_size = fc1_weight.shape[0] // 2 + state_dict.update( + { + f"{hf_prefix}.mlp.gate_proj.weight": fc1_weight[:gate_size], + f"{hf_prefix}.mlp.up_proj.weight": fc1_weight[gate_size:], + f"{hf_prefix}.mlp.down_proj.weight": mcore_sd[ + f"{mcore_prefix}.mlp.linear_fc2.weight" + ], + } + ) + + return state_dict + + +def convert_mcore2hf_vision_projection(mcore_sd): + """Convert vision projection weights from Megatron to HF format""" + 
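+    # The projection MLP maps one-to-one onto HF LLaVA's multi_modal_projector:
+    #   vision_projection.encoder.linear_fc1 -> multi_modal_projector.linear_1
+    #   vision_projection.encoder.linear_fc2 -> multi_modal_projector.linear_2
+    # Weights and biases are copied as-is (the projector bias is enabled in config.py
+    # for exactly this compatibility).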
state_dict = {} + + # Map the weights from Megatron to Hugging Face format + state_dict["multi_modal_projector.linear_1.weight"] = mcore_sd[ + "vision_projection.encoder.linear_fc1.weight" + ] + state_dict["multi_modal_projector.linear_1.bias"] = mcore_sd[ + "vision_projection.encoder.linear_fc1.bias" + ] + state_dict["multi_modal_projector.linear_2.weight"] = mcore_sd[ + "vision_projection.encoder.linear_fc2.weight" + ] + state_dict["multi_modal_projector.linear_2.bias"] = mcore_sd[ + "vision_projection.encoder.linear_fc2.bias" + ] + + return state_dict + +def extend_embeddings(hf_model, hf_config): + # Initialize new embeddings for additional tokens + # We use the average of the pre-expansion embeddings as the mean + # and a small covariance matrix to ensure the new embeddings are close to the old ones + # adapted from + # https://github.com/huggingface/transformers/blob/bf42c3bd4b088fd9df1086e63d47a8e33048e5e1/src/transformers/models/llava/convert_llava_weights_to_hf.py#L100 + # TODO: it seems this might not be needed anymore in the new versions of HF?? + # double check + pre_expansion_embeddings = hf_model.language_model.model.embed_tokens.weight.data + mu = torch.mean(pre_expansion_embeddings, dim=0).float() + n = pre_expansion_embeddings.size()[0] + sigma = ((pre_expansion_embeddings - mu).T @ (pre_expansion_embeddings - mu)) / n + dist = torch.distributions.multivariate_normal.MultivariateNormal( + mu, covariance_matrix=1e-5 * sigma + ) + + # We add an image token so we resize the model and pad to 64 for performance reasons + pad_shape = 64 + vocab_size = hf_config.text_config.vocab_size + hf_model.resize_token_embeddings(vocab_size + 2, pad_shape) + hf_model.language_model.model.embed_tokens.weight.data[vocab_size:] = torch.stack( + tuple( + ( + dist.sample() + for _ in range( + hf_model.language_model.model.embed_tokens.weight.data[vocab_size:].shape[0] + ) + ) + ), + dim=0, + ) + hf_model.language_model.lm_head.weight.data[vocab_size:] = torch.stack( + tuple( + ( + dist.sample() + for _ in range(hf_model.language_model.lm_head.weight.data[vocab_size:].shape[0]) + ) + ), + dim=0, + ) + + + +if __name__ == "__main__": + main() diff --git a/examples/multimodal/model.py b/examples/multimodal/model.py index b4bab73cfb..f35a31b328 100644 --- a/examples/multimodal/model.py +++ b/examples/multimodal/model.py @@ -13,7 +13,7 @@ def model_provider( - pre_process=True, post_process=True, add_encoder=True, add_decoder=True, parallel_output=True + pre_process=True, post_process=True, add_encoder=True, add_decoder=True, parallel_output=True, args=None ) -> LLaVAModel: """Builds the model. @@ -29,7 +29,7 @@ def model_provider( Returns: model: A multimodal model. 
""" - args = get_args() + args = get_args() if args is None else args use_te = args.use_te @@ -41,7 +41,7 @@ def model_provider( ) old_seq_length = args.seq_length args.seq_length = args.encoder_seq_length = num_image_embeddings - if torch.distributed.get_rank() == 0 and old_seq_length != args.seq_length: + if (not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0) and old_seq_length != args.seq_length: warnings.warn( f"Changed seq_length and encoder_seq_length (vision model sequence length) from {old_seq_length} to num_image_tokens ({num_image_embeddings})" ) @@ -60,7 +60,7 @@ def model_provider( f"Expanded max_position_embeddings to {args.max_position_embeddings} to accommodate the maximum language model sequence length" ) - base_config = core_transformer_config_from_args(get_args()) + base_config = core_transformer_config_from_args(args) base_config.language_model_type = args.language_model_type base_config.vision_model_type = args.vision_model_type base_config.calculate_per_token_loss = True diff --git a/examples/multimodal/pretrain_dataset.yaml b/examples/multimodal/pretrain_dataset.yaml index f27bccba30..3581d5a67f 100644 --- a/examples/multimodal/pretrain_dataset.yaml +++ b/examples/multimodal/pretrain_dataset.yaml @@ -4,12 +4,12 @@ splits: train: datasets: - weight: 1. - path: + path: /mnt/cephfs-nvme/nunomg/Megatron-LM/LLaVA-Pretrain/wds subflavors: augmentation: false val: datasets: - weight: 1. - path: + path: /mnt/cephfs-nvme/nunomg/Megatron-LM/LLaVA-Pretrain/wds subflavors: augmentation: false diff --git a/examples/multimodal/pretrain_mistral_clip.sh b/examples/multimodal/pretrain_mistral_clip.sh index eefdff0673..1cca88255b 100755 --- a/examples/multimodal/pretrain_mistral_clip.sh +++ b/examples/multimodal/pretrain_mistral_clip.sh @@ -42,6 +42,21 @@ if [[ $DEBUG -eq 1 ]]; then EXTRA_ARGS="" NONDETERMINISTIC_ATTN=1 else + # Original hparams + # TRAIN_ITERS=20000 + # SAVE_INTERVAL=2000 + # EVAL_INTERVAL=1000 + # LR=0.00015 + # LR_WARMUP_FRACTION=0.01 + + # LlaVA hparams + TRAIN_ITERS=2000 + SAVE_INTERVAL=500 + EVAL_INTERVAL=500 + LR=0.001 + LR_WARMUP_FRACTION=0.03 + + BZ=256 NW=2 HD=0.1 @@ -80,22 +95,22 @@ OPTIONS=" \ --decoder-seq-length 1024 \ --max-position-embeddings 4096 \ --ffn-hidden-size 14336 \ - --train-iters 20000 \ + --train-iters ${TRAIN_ITERS} \ --micro-batch-size 16 \ --global-batch-size ${BZ} \ - --lr-decay-iters 20000 \ + --lr-decay-iters ${TRAIN_ITERS} \ --lr-warmup-fraction .01 \ - --lr 0.00015 \ + --lr ${LR} \ --min-lr 1.0e-5 \ --lr-decay-style cosine \ --log-interval ${LI} \ --eval-iters 10 \ - --eval-interval 1000 \ + --eval-interval ${EVAL_INTERVAL} \ --tokenizer-type HuggingFaceTokenizer \ --tokenizer-model ${TOKENIZER_MODEL} \ --data-path ${DATA_TRAIN} \ --prompt-path ${SOURCE}/examples/multimodal/manual_prompts.json \ - --save-interval 2000 \ + --save-interval ${SAVE_INTERVAL} \ --save ${FINETUNE_DIR} \ --load ${FINETUNE_DIR} \ --dataloader-save ${FINETUNE_DIR}/dataloader \ diff --git a/examples/multimodal/sft_dataset.yaml b/examples/multimodal/sft_dataset.yaml index c9f0257ae7..1e992974f1 100644 --- a/examples/multimodal/sft_dataset.yaml +++ b/examples/multimodal/sft_dataset.yaml @@ -4,12 +4,12 @@ splits: train: datasets: - weight: 1. - path: + path: /mnt/cephfs-nvme/jalves/tower_vision/data/LLaVA-Instruct-150K/wds subflavors: augmentation: false val: datasets: - weight: 1. 
- path: + path: /mnt/cephfs-nvme/jalves/tower_vision/data/LLaVA-Instruct-150K/wds subflavors: augmentation: false diff --git a/examples/multimodal/sft_mistral_clip.sh b/examples/multimodal/sft_mistral_clip.sh index 38cf676738..97c9ad12ce 100755 --- a/examples/multimodal/sft_mistral_clip.sh +++ b/examples/multimodal/sft_mistral_clip.sh @@ -48,6 +48,7 @@ else LI=10 EXTRA_ARGS="" NONDETERMINISTIC_ATTN=1 + TRAIN_ITERS=5000 fi OPTIONS=" \ @@ -80,10 +81,10 @@ OPTIONS=" \ --decoder-seq-length 2048 \ --max-position-embeddings 4096 \ --ffn-hidden-size 14336 \ - --train-iters 20000 \ + --train-iters ${TRAIN_ITERS} \ --micro-batch-size 8 \ --global-batch-size ${BZ} \ - --lr-decay-iters 20000 \ + --lr-decay-iters ${TRAIN_ITERS} \ --lr-warmup-fraction .01 \ --lr 1e-6 \ --min-lr 1e-7 \ From 575aaced7ec071e6ee7cd7a9e952accdca6ef053 Mon Sep 17 00:00:00 2001 From: Patrick Date: Mon, 25 Nov 2024 15:47:32 +0000 Subject: [PATCH 3/7] update code --- conda_install.sh | 3 +- examples/multimodal/combine_mistral_clip.sh | 5 +- examples/multimodal/download_hf_model.py | 17 ++ .../model_converter/clip_converter.py | 2 +- examples/multimodal/pretrain_dataset.yaml | 4 +- examples/multimodal/sft_dataset.yaml | 4 +- tapes/main.tape | 214 ++++++++++++++++++ tapes/scslurm.tape | 86 +++++++ 8 files changed, 328 insertions(+), 7 deletions(-) create mode 100644 examples/multimodal/download_hf_model.py create mode 100644 tapes/main.tape create mode 100644 tapes/scslurm.tape diff --git a/conda_install.sh b/conda_install.sh index edfe26940c..aaa026176e 100644 --- a/conda_install.sh +++ b/conda_install.sh @@ -29,13 +29,14 @@ echo "Megatron-LM dir: $DIR" source ${CONDA_HOME}/etc/profile.d/conda.sh # python can't handle this dependency madness, switch to C++ -conda create -y -n ${ENV_NAME} python=3.10 +# conda create -y -n ${ENV_NAME} python=3.10 conda activate ${ENV_NAME} pip install ninja # install our own copy of CUDA and set environment variables +conda install -y openldap conda install -y -c "nvidia/label/cuda-12.4.0" cuda-toolkit cuda-nvcc cudnn export PATH=${CONDA_ENVS}/${ENV_NAME}/bin:$PATH diff --git a/examples/multimodal/combine_mistral_clip.sh b/examples/multimodal/combine_mistral_clip.sh index ff866c7f72..9871004c8d 100755 --- a/examples/multimodal/combine_mistral_clip.sh +++ b/examples/multimodal/combine_mistral_clip.sh @@ -1,9 +1,12 @@ #/bin/bash +# get dir of this script +SCRIPT_DIR=$(dirname $(readlink -f $0)) + MCORE_MISTRAL=$1 # MCORE_CLIP=$2 # OUTPUT_DIR=$3 # -python examples/multimodal/combine_state_dicts.py \ +python ${SCRIPT_DIR}/combine_state_dicts.py \ --input \ ${MCORE_MISTRAL}/iter_0000001/mp_rank_00/model_optim_rng.pt \ ${MCORE_CLIP}/iter_0000001/mp_rank_00/model_optim_rng.pt \ diff --git a/examples/multimodal/download_hf_model.py b/examples/multimodal/download_hf_model.py new file mode 100644 index 0000000000..ba61902885 --- /dev/null +++ b/examples/multimodal/download_hf_model.py @@ -0,0 +1,17 @@ +import argparse + +from transformers import AutoModelForCausalLM + +def read_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--model", type=str, required=True) + parser.add_argument("--output-dir", type=str, required=True) + return parser.parse_args() + +def main(): + args = read_args() + model = AutoModelForCausalLM.from_pretrained(args.model) + model.save_pretrained(args.output_dir) + +if __name__ == "__main__": + main() diff --git a/examples/multimodal/model_converter/clip_converter.py b/examples/multimodal/model_converter/clip_converter.py index 696c810890..e9a0aa7d98 100644 --- 
a/examples/multimodal/model_converter/clip_converter.py +++ b/examples/multimodal/model_converter/clip_converter.py @@ -8,7 +8,7 @@ def convert(download_root, output_path, tensor_parallel_size, use_te): - device = "cuda" + device = "cpu" model, _ = clip.load("ViT-L/14@336px", device=device, download_root=download_root) diff --git a/examples/multimodal/pretrain_dataset.yaml b/examples/multimodal/pretrain_dataset.yaml index 3581d5a67f..4fd9a10a7a 100644 --- a/examples/multimodal/pretrain_dataset.yaml +++ b/examples/multimodal/pretrain_dataset.yaml @@ -4,12 +4,12 @@ splits: train: datasets: - weight: 1. - path: /mnt/cephfs-nvme/nunomg/Megatron-LM/LLaVA-Pretrain/wds + path: /lustre/fswork/projects/rech/qjm/ued79zb/prt-wds subflavors: augmentation: false val: datasets: - weight: 1. - path: /mnt/cephfs-nvme/nunomg/Megatron-LM/LLaVA-Pretrain/wds + path: /lustre/fswork/projects/rech/qjm/ued79zb/prt-wds subflavors: augmentation: false diff --git a/examples/multimodal/sft_dataset.yaml b/examples/multimodal/sft_dataset.yaml index 1e992974f1..835caedf63 100644 --- a/examples/multimodal/sft_dataset.yaml +++ b/examples/multimodal/sft_dataset.yaml @@ -4,12 +4,12 @@ splits: train: datasets: - weight: 1. - path: /mnt/cephfs-nvme/jalves/tower_vision/data/LLaVA-Instruct-150K/wds + path: /lustre/fswork/projects/rech/qjm/ued79zb/sft-wds subflavors: augmentation: false val: datasets: - weight: 1. - path: /mnt/cephfs-nvme/jalves/tower_vision/data/LLaVA-Instruct-150K/wds + path: /lustre/fswork/projects/rech/qjm/ued79zb/sft-wds subflavors: augmentation: false diff --git a/tapes/main.tape b/tapes/main.tape new file mode 100644 index 0000000000..45b6a739de --- /dev/null +++ b/tapes/main.tape @@ -0,0 +1,214 @@ +import "scslurm.tape" + +global { + ducttape_experimental_imports=true + ducttape_experimental_submitters=true + ducttape_experimental_multiproc=true + + ducttape_output=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_outputs + repo=/linkhome/rech/genrce01/ued79zb/repos/Megatron-LM-pretrain + + # multimodal model parameters + # (base lm, tp, etc...) + clip_original_dir=/lustre/fswork/projects/rech/qjm/ued79zb/clip_model_og/ + mistral_model="mistralai/Mistral-7B-Instruct-v0.3" + tp=4 + pp=1 + + # pre-training arguments + external_model_dir=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_ckpts/mistral_7b_instruct_prt + external_tensorboard=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_tbs/mistral_7b_instruct_prt + pretrain_iters=2000 + pretrain_bsz=256 + pretrain_lr=0.001 + pretrain_lr_warmup=0.03 + pretrain_save_interval=500 + pretrain_eval_interval=500 + + + # -- submitter arguments -- + submitter=scslurm + + prepare_account="qjm@cpu" + prepare_time="1:00:00" + prepare_cpus=4 + prepare_partition="prepost" + + pretrain_C="h100" + pretrain_account="qjm@h100" + pretrain_time="10:00:00" + pretrain_cpus=80 + pretrain_gres="gpu:4" +} + +task PrepareModel + > initial_model + :: repo=@ + :: clip_original_dir=@ + :: mistral_model=@ + :: tp=@ + :: pp=@ + :: .submitter=@ + :: .account=$prepare_account + :: .time=$prepare_time + :: .cpus=$prepare_cpus + :: .partition=$prepare_partition +{ + # Download & convert CLIP model + echo "Downloading & converting CLIP model..." + python $repo/examples/multimodal/model_converter/clip_converter.py \ + --download-root $clip_original_dir \ + --output clip_mcore_dir \ + --tensor-parallel-size ${tp} \ + --use-te + + # Download & convert language model + echo "Downloading & converting language model..." 
+ python $repo/examples/multimodal/download_hf_model.py \ + --model ${mistral_model} \ + --output-dir mistral_original_dir + python $repo/tools/checkpoint/convert.py --model-type GPT \ + --loader llama_mistral \ + --saver mcore \ + --checkpoint-type hf \ + --model-size mistral-7B \ + --load-dir mistral_original_dir \ + --save-dir mistral_mcore_dir \ + --tokenizer-model ${mistral_model} \ + --target-tensor-parallel-size ${tp} \ + --target-pipeline-parallel-size ${pp} \ + --bf16 + + # Combine models + echo "Combining language and vision models..." + bash $repo/examples/multimodal/combine_mistral_clip.sh \ + mistral_mcore_dir \ + clip_mcore_dir \ + $initial_model + + # remove original and intermediate converted models to save space + rm -rf mistral_original_dir + rm -rf clip_mcore_dir mistral_mcore_dir +} + +task PretrainModel + < initial_model=@PrepareModel + > model_dir + :: repo=@ + :: tokenizer_model=$mistral_model + :: train_iters=$pretrain_iters + :: batch_size=$pretrain_bsz + :: lr=$pretrain_lr + :: lr_warmup_fraction=$pretrain_lr_warmup + :: save_interval=$pretrain_save_interval + :: eval_interval=$pretrain_eval_interval + :: external_resume=true + :: external_model_dir=@ + :: external_tensorboard=@ + :: .submitter=@ + :: .C=$pretrain_C + :: .account=$pretrain_account + :: .time=$pretrain_time + :: .cpus=$pretrain_cpus + :: .gres=$pretrain_gres +{ + export NCCL_IB_SL=1 + export CUDA_DEVICE_MAX_CONNECTIONS=1 + + # if `save_external` is set, symlink it to the `model_dir` + # and copy the config file to the `model_dir` + if [ "$external_model_dir" != "" ]; then + if [ "$external_resume" == false ]; then + rm -rf $external_model_dir + fi + mkdir -p $external_model_dir + ln -sf $external_model_dir $model_dir + fi + + if [ "$external_tensorboard" != "" ]; then + mkdir -p $external_tensorboard + tensorboard=$external_tensorboard + else + mkdir -p tensorboard + tensorboard=tensorboard + fi + + export NVTE_APPLY_QK_LAYER_SCALING=0 + export NVTE_ALLOW_NONDETERMINISTIC_ALGO=1 + + torchrun --nproc_per_node 4 $repo/examples/multimodal/train.py \ + --apply-layernorm-1p \ + --attention-softmax-in-fp32 \ + --use-checkpoint-args \ + --use-distributed-optimizer \ + --transformer-impl transformer_engine \ + --use-te \ + --normalization RMSNorm \ + --group-query-attention \ + --num-query-groups 8 \ + --no-masked-softmax-fusion \ + --num-workers 2 \ + --use-flash-attn \ + --untie-embeddings-and-output-weights \ + --disable-bias-linear \ + --position-embedding-type rope \ + --rotary-percent 1.0 \ + --rotary-base 1000000 \ + --swiglu \ + --attention-dropout 0.0 \ + --hidden-dropout 0.1 \ + --tensor-model-parallel-size 4 \ + --pipeline-model-parallel-size 1 \ + --num-layers 32 \ + --hidden-size 4096 \ + --num-attention-heads 32 \ + --seq-length 576 \ + --decoder-seq-length 1024 \ + --max-position-embeddings 4096 \ + --ffn-hidden-size 14336 \ + --train-iters ${train_iters} \ + --micro-batch-size 16 \ + --global-batch-size ${batch_size} \ + --lr-decay-iters ${train_iters} \ + --lr-warmup-fraction ${lr_warmup_fraction} \ + --lr ${lr} \ + --min-lr 1.0e-5 \ + --lr-decay-style cosine \ + --log-interval 10 \ + --eval-iters 10 \ + --eval-interval ${eval_interval} \ + --tokenizer-type HuggingFaceTokenizer \ + --tokenizer-model ${tokenizer_model} \ + --data-path ${repo}/examples/multimodal/pretrain_dataset.yaml \ + --prompt-path ${repo}/examples/multimodal/manual_prompts.json \ + --save-interval ${save_interval} \ + --save ${model_dir} \ + --load ${model_dir} \ + --dataloader-save ${model_dir}/dataloader \ + 
--pretrained-checkpoint ${initial_model}/mistral_instruct_clip336_tp4_combined_mcore \ + --split 100,0,0 \ + --clip-grad 1.0 \ + --weight-decay 1e-2 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --init-method-std 0.014 \ + --log-params-norm \ + --log-num-zeros-in-grad \ + --bf16 \ + --eod-mask-loss \ + --freeze-LM \ + --freeze-ViT \ + --patch-dim 14 \ + --img-h 336 \ + --img-w 336 \ + --dataloader-type external \ + --tensorboard-dir ${tensorboard} \ + --language-model-type=mistral_7b \ + --disable-vision-class-token \ + --distributed-timeout-minutes 60 \ + --allow-missing-vision-projection-checkpoint +} + +plan TrainLLaVA { + reach PretrainModel +} \ No newline at end of file diff --git a/tapes/scslurm.tape b/tapes/scslurm.tape new file mode 100644 index 0000000000..b3dc3c68c5 --- /dev/null +++ b/tapes/scslurm.tape @@ -0,0 +1,86 @@ +submitter scslurm :: COMMANDS + :: TASK REALIZATION CONFIGURATION { + action run > exit_code { + # Returns true iff the first parameter + # is the name of a defined variable + function var_defined { + eval '[[ -n ${'$1'+set} && ${'$1'} != "none" ]]' + } + + # define script + wrapper="ducttape_job.sh" + + echo "#!/bin/bash" > $wrapper + echo "set -e # stop on errors" >> $wrapper + echo "set -o pipefail # stop on pipeline errors" >> $wrapper + echo "set -u # stop on undeclared variables" >> $wrapper + + # print actual jobs + echo "$COMMANDS" >> $wrapper + + SLURM_ARGS="--job-name=$TASK" + if (var_defined C); then + SLURM_ARGS+=" -C $C" + fi + if (var_defined account); then + SLURM_ARGS+=" -A $account" + fi + if (var_defined config); then + SLURM_ARGS+=" -C ${config}" + fi + if (var_defined reservation); then + SLURM_ARGS+=" --reservation=$reservation" + fi + if (var_defined nodes); then + SLURM_ARGS+=" --nodes=$nodes --ntasks-per-node=1" + fi + if (var_defined cpus); then + SLURM_ARGS+=" --cpus-per-task=$cpus" + fi + if (var_defined mem); then + SLURM_ARGS+=" --mem=$mem" + fi + if (var_defined gres); then + SLURM_ARGS+=" --gres=$gres" + fi + if (var_defined time); then + SLURM_ARGS+=" --time=$time" + fi + if (var_defined partition); then + SLURM_ARGS+=" --partition=$partition" + fi + if (var_defined qos); then + SLURM_ARGS+=" --qos=$qos" + fi + + SLURM_ARGS+=" --hint=nomultithread" + + echo $SLURM_ARGS + if (var_defined restart_on_timeout) && [ $restart_on_timeout == "true" ]; then + echo "Will restart on timeout!" + set +e # don't stop on errors + # loop until the job completes successfully + while true; do + # submit the job and capture its exit code + srun $SLURM_ARGS bash $wrapper + exit_code=$? + + # if the job completed successfully, break the loop + if [ $exit_code -eq 0 ]; then + break + # else if code is not 143 (timeout) propagate error + elif [ $exit_code -ne 143 ]; then + echo "Slurm job failed with exit code $exit_code" + exit $exit_code + else + echo "Job timed out, resubmitting..." 
+ fi + + # otherwise, wait a bit and then continue the loop to resubmit the job + sleep 60 + done + else + srun $SLURM_ARGS --hint=nomultithread bash $wrapper + fi + } +} \ No newline at end of file From 01eab326c7e4bbf734c1ad0652e92a3004ab29a9 Mon Sep 17 00:00:00 2001 From: Patrick Date: Tue, 26 Nov 2024 20:39:53 +0000 Subject: [PATCH 4/7] Update convert script for new version --- examples/multimodal/convert_to_hf.py | 229 +++++++++++------- examples/multimodal/dataloader_provider.py | 2 +- examples/multimodal/evaluation_datasets.py | 15 +- examples/multimodal/run_text_generation.py | 2 +- .../text_generation_mistral_clip.sh | 11 +- tapes/main.tape | 220 ++++++++++++++++- 6 files changed, 368 insertions(+), 111 deletions(-) diff --git a/examples/multimodal/convert_to_hf.py b/examples/multimodal/convert_to_hf.py index ff49af53e5..1f6c2b7747 100644 --- a/examples/multimodal/convert_to_hf.py +++ b/examples/multimodal/convert_to_hf.py @@ -27,34 +27,118 @@ def parse_args(): parser.add_argument("--hf-save-dir", required=True) parser.add_argument("--original-text-model-id", required=True) parser.add_argument("--original-vision-model-id", required=True) + parser.add_argument("--source-tp-size", type=int, default=4) parser.add_argument("--target-params-dtype", type=str, default="float16") return parser.parse_args() def main(): - initialize_distributed(tensor_model_parallel_size=1, pipeline_model_parallel_size=1) - model_parallel_cuda_manual_seed(123) args = parse_args() convert_mcore2hf(args) +def get_checkpoint_sub_dir_name(*, tp_rank): + """Get the checkpoint subdirectory name based on parallel ranks.""" + sub_dir_name = f"mp_rank_{tp_rank:02d}" + return sub_dir_name -def initialize_distributed(tensor_model_parallel_size=1, pipeline_model_parallel_size=1): - parallel_state.destroy_model_parallel() - rank = int(os.environ['LOCAL_RANK']) - world_size = torch.cuda.device_count() - torch.cuda.set_device(rank) - torch.distributed.init_process_group(world_size=world_size, rank=rank) - # Megatron core distributed training initialization - parallel_state.initialize_model_parallel( - tensor_model_parallel_size=tensor_model_parallel_size, - pipeline_model_parallel_size=pipeline_model_parallel_size, - ) +def gather_state_dict(mcore_sds, *, tp_size): + """ + Gather all tensor parallel shards into a single state dict. + Only concatenates parameters that are actually sharded. 
+ + Args: + mcore_sds: List of state dicts [tp_rank] + tp_size: Number of tensor parallel shards + + Returns: + Consolidated state dict with all shards merged + """ + consolidated_sd = {} + + # Get all unique keys from all shards + all_keys = set() + for tp_rank in range(tp_size): + all_keys.update(mcore_sds[tp_rank].keys()) + + # Define patterns for sharded parameters + sharded_output_dim_patterns = [ + "self_attention.linear_qkv.weight", + "self_attention.linear_qkv.bias", + "mlp.linear_fc1.weight", + "mlp.linear_fc1.bias", + "encoder.linear_fc1.weight", + "encoder.linear_fc1.bias", + "word_embeddings.weight", + "output_layer.weight" + ] + + sharded_input_dim_patterns = [ + "self_attention.linear_proj.weight", + "mlp.linear_fc2.weight", + "encoder.linear_fc2.weight", + ] + + # For each key, try to gather the shards + for key in all_keys: + # Skip args and other non-tensor keys + if not isinstance(mcore_sds[0].get(key, None), torch.Tensor): + consolidated_sd[key] = mcore_sds[0][key] + continue + + # Collect shards if they exist + shards = [] + for tp_rank in range(tp_size): + if key in mcore_sds[tp_rank]: + shards.append(mcore_sds[tp_rank][key]) + + # If only one shard exists, use it directly + if len(shards) == 1: + consolidated_sd[key] = shards[0] + continue + + # Check if this parameter should be sharded + is_sharded_output = any(pattern in key for pattern in sharded_output_dim_patterns) + is_sharded_input = any(pattern in key for pattern in sharded_input_dim_patterns) + # special case for language model mlp fc1 weights due to swiglu activation + if "language_model" in key and "mlp.linear_fc1.weight" in key: + print(key) + # Get all shards + shards = [] + for tp_rank in range(tp_size): + if key in mcore_sds[tp_rank]: + shard = mcore_sds[tp_rank][key] + # Split each shard into gate and up parts + shard_size = shard.shape[0] // 2 + gate_shard = shard[:shard_size] + up_shard = shard[shard_size:] + shards.append((gate_shard, up_shard)) + + # Concatenate gate and up parts separately + gate_shards = [s[0] for s in shards] + up_shards = [s[1] for s in shards] + + # Then concatenate them in the correct order + consolidated_sd[key] = torch.cat([ + torch.cat(gate_shards, dim=0), + torch.cat(up_shards, dim=0) + ], dim=0) + elif is_sharded_output: + consolidated_sd[key] = torch.cat(shards, dim=0) + elif is_sharded_input: + consolidated_sd[key] = torch.cat(shards, dim=1) + else: + # For non-sharded parameters, just take the first shard + consolidated_sd[key] = shards[0] + + return consolidated_sd + def convert_mcore2hf(args): """Main function to convert MCore checkpoint to HF format""" # TODO: add support for casting explicitly to dtype dtype = getattr(torch, args.target_params_dtype) + tp_size = args.source_tp_size print(f"> Loading MCore checkpoints") assert os.path.exists(f"{args.mcore_load_dir}/latest_checkpointed_iteration.txt") @@ -64,17 +148,31 @@ def convert_mcore2hf(args): iteration = int(f.read().strip()) iter_dir = f"{args.mcore_load_dir}/iter_{iteration:07d}" - # start by loading the args from the checkpoint - margs = dist_checkpointing.load_common_state_dict(iter_dir)['args'] + # Initialize nested list to store state dicts for each parallel rank + mcore_sds = [{} for _ in range(tp_size)] + + # Load all checkpoint shards + mcore_args = [] + for tp_rank in range(tp_size): + print(f" > Loading tp_rank={tp_rank}") + sub_dir_name = get_checkpoint_sub_dir_name(tp_rank=tp_rank,) + checkpoint_path = f"{iter_dir}/{sub_dir_name}/model_optim_rng.pt" + checkpoint = torch.load(checkpoint_path, 
map_location='cpu') + mcore_args.append(checkpoint['args']) + mcore_sds[tp_rank] = checkpoint['model'] + + # Verify all args are the same + # TODO: it fails even tho they are the same? + # assert all([mcore_args[0] == mcore_arg for mcore_arg in mcore_args]) + margs = mcore_args[0] print(f"> Loaded args from checkpoint: {margs}") - args.tensor_model_parallel_size = 1 - # load the model checkpoint itself - model = model_provider(args=margs) - sharded_state_dict = model.sharded_state_dict(prefix='') - checkpoint = dist_checkpointing.load( - sharded_state_dict=sharded_state_dict, checkpoint_dir=iter_dir - ) + #import pdb; pdb.set_trace() + + # Consolidate sharded state dict + print(f"> Consolidating sharded checkpoints") + checkpoint = gather_state_dict(mcore_sds, tp_size=tp_size) + # import pdb; pdb.set_trace() # create the HF config @@ -89,13 +187,13 @@ def convert_mcore2hf(args): hf_state_dict = {} # Convert vision model weights - vision_state_dict = convert_mcore2hf_vision_model(checkpoint) + vision_state_dict = convert_mcore2hf_vision_model(checkpoint, margs) # Convert language model weights - language_state_dict = convert_mcore2hf_language_model(checkpoint) + language_state_dict = convert_mcore2hf_language_model(checkpoint, margs) # Convert projection weights - projection_state_dict = convert_mcore2hf_vision_projection(checkpoint) + projection_state_dict = convert_mcore2hf_vision_projection(checkpoint, margs) # Combine all state dicts hf_state_dict.update(vision_state_dict) @@ -141,6 +239,7 @@ def create_hf_config(original_text_model_id, original_vision_model_id, margs): hf_config = LlavaConfig( vision_config=vision_config, text_config=text_config, + vision_feature_layer=-1, # megatron uses the last layer # Add any other LLaVA specific configs here ) return hf_config @@ -154,10 +253,12 @@ def create_hf_processor(hf_config, text_model_id, vision_model_id): hf_config.pad_token_id = tokenizer.pad_token_id try: + # WARNING: this was used for custom version transformer + # that implemented that llave image processing pipeline + # see https://github.com/huggingface/transformers/pull/33191 + # for a status on it being merged to the official repo from transformers.models.llava.image_processing_llava import LlavaImageProcessor - image_processor = LlavaImageProcessor( - do_megatron_pp=True, - ) + image_processor = LlavaImageProcessor() except ImportError: print("> WARNING: could not import LlavaImageProcessor, using AutoImageProcessor instead") print("> This might lead to performance degradation due to slightly different image pre-processing") @@ -167,7 +268,7 @@ def create_hf_processor(hf_config, text_model_id, vision_model_id): return processor -def convert_mcore2hf_vision_model(mcore_sd): +def convert_mcore2hf_vision_model(mcore_sd, margs): """Convert vision model weights from Megatron to HF format""" state_dict = {} @@ -277,67 +378,7 @@ def convert_mcore2hf_vision_model(mcore_sd): return state_dict -def convert_mcore2hf_vision_model_new(mcore_sd): - """Convert vision model weights from Megatron to HF format""" - state_dict = {} - - # Vision embedding layers - state_dict.update( - { - "vision_tower.vision_model.embeddings.class_embedding": mcore_sd[ - "vision_model.class_token" - ].squeeze(), - "vision_tower.vision_model.embeddings.position_embedding.weight": mcore_sd[ - "vision_model.position_embeddings.weight" - ], - "vision_tower.vision_model.embeddings.patch_embedding.weight": mcore_sd[ - "vision_model.conv1.weight" - ], - "vision_tower.vision_model.pre_layrnorm.weight": 
mcore_sd["vision_model.ln_pre.weight"], - "vision_tower.vision_model.pre_layrnorm.bias": mcore_sd["vision_model.ln_pre.bias"], - } - ) - - # Vision transformer layers - clip_num_layers = 24 - for layer_i in range(clip_num_layers): - hf_layer_prefix = f"vision_tower.vision_model.encoder.layers.{layer_i}" - mcore_layer_prefix = f"vision_model.decoder.layers.{layer_i}" - - # Get QKV weights and biases - qkv_weight = mcore_sd[f"{mcore_layer_prefix}.self_attention.linear_qkv.weight"] - qkv_bias = mcore_sd[f"{mcore_layer_prefix}.self_attention.linear_qkv.bias"] - - # Calculate dimensions - hidden_dim = qkv_weight.shape[1] - num_heads = mcore_sd["args"].num_attention_heads - head_dim = hidden_dim // num_heads - - # Split QKV weights and biases - q_weight, k_weight, v_weight = qkv_weight.chunk(3, dim=0) - q_bias, k_bias, v_bias = qkv_bias.chunk(3, dim=0) - - # Ensure these are correctly assigned in the state_dict - state_dict.update( - { - f"{hf_layer_prefix}.self_attn.q_proj.weight": q_weight, - f"{hf_layer_prefix}.self_attn.k_proj.weight": k_weight, - f"{hf_layer_prefix}.self_attn.v_proj.weight": v_weight, - f"{hf_layer_prefix}.self_attn.q_proj.bias": q_bias, - f"{hf_layer_prefix}.self_attn.k_proj.bias": k_bias, - f"{hf_layer_prefix}.self_attn.v_proj.bias": v_bias, - } - ) - - # NOTE: for some reason, Megatron removes the post_layernorm weights and biases - # so we need to add them back in for the HF model, - # ensuring they perform the identity mapping - state_dict["vision_tower.vision_model.post_layernorm.weight"] = torch.ones(1024) - state_dict["vision_tower.vision_model.post_layernorm.bias"] = torch.zeros(1024) - return state_dict - - -def convert_mcore2hf_language_model(mcore_sd): +def convert_mcore2hf_language_model(mcore_sd, margs): """Convert language model weights from Megatron to HF format""" state_dict = {} @@ -353,7 +394,7 @@ def convert_mcore2hf_language_model(mcore_sd): state_dict["language_model.lm_head.weight"] = mcore_sd["language_model.output_layer.weight"] # Transformer layers - for layer_i in range(mcore_sd["args"].num_layers): + for layer_i in range(margs.num_layers): mcore_prefix = f"language_model.decoder.layers.{layer_i}" hf_prefix = f"language_model.model.layers.{layer_i}" @@ -378,8 +419,8 @@ def convert_mcore2hf_language_model(mcore_sd): # llava_model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf") hidden_size = qkv_weight.shape[1] - num_kv_heads = mcore_sd["args"].num_query_groups - num_heads = mcore_sd["args"].num_attention_heads + num_kv_heads = margs.num_query_groups + num_heads = margs.num_attention_heads num_queries_per_group = num_heads // num_kv_heads head_dim = hidden_size // num_heads @@ -434,7 +475,7 @@ def convert_mcore2hf_language_model(mcore_sd): return state_dict -def convert_mcore2hf_vision_projection(mcore_sd): +def convert_mcore2hf_vision_projection(mcore_sd, margs): """Convert vision projection weights from Megatron to HF format""" state_dict = {} diff --git a/examples/multimodal/dataloader_provider.py b/examples/multimodal/dataloader_provider.py index 923b518643..230e7c953d 100644 --- a/examples/multimodal/dataloader_provider.py +++ b/examples/multimodal/dataloader_provider.py @@ -143,7 +143,7 @@ def train_valid_test_dataloaders_provider(train_val_test_num_samples): ] test_dataloader = None - return EnergonDataloader(train_dataloader), valid_dataloader, EnergonDataloader(test_dataloader) + return EnergonDataloader(train_dataloader), valid_dataloader, test_dataloader # EnergonDataloader(test_dataloader) class 
EnergonDataloader: diff --git a/examples/multimodal/evaluation_datasets.py b/examples/multimodal/evaluation_datasets.py index 97f9ba926f..82e781ec0d 100644 --- a/examples/multimodal/evaluation_datasets.py +++ b/examples/multimodal/evaluation_datasets.py @@ -124,7 +124,15 @@ def __init__( use_thumbnail, vision_model_type, ): - image_files = sorted(glob.glob(input_image_path + "/*")) + gts = json.load(open(gt_path)) + answers = defaultdict(list) + image_files = list() + + for gt in gts: + image_files.append(input_image_path + "/" + gt["image"]) + answers[gt["image"]] = gt['caption'] + + image_files = sorted(image_files) # Optionally, process only a subset of the input files. if num_partitions > 0: @@ -133,11 +141,6 @@ def __init__( ) image_files = image_files[lb:ub] - gts = json.load(open(gt_path)) - answers = defaultdict(list) - for gt in gts["annotations"]: - answers[gt["image_id"]].append(gt['caption']) - self._image_files = image_files self._answers = answers self._img_h = img_h diff --git a/examples/multimodal/run_text_generation.py b/examples/multimodal/run_text_generation.py index f4bb5025ff..cadc1120bc 100644 --- a/examples/multimodal/run_text_generation.py +++ b/examples/multimodal/run_text_generation.py @@ -354,7 +354,7 @@ def _forward(self, tokens, position_ids, attention_mask): ) def __call__(self, tokens, position_ids, attention_mask): - num_image_tokens = (tokens == self.model.image_token_index).sum().item() + num_image_tokens = (tokens == self.model.module.image_token_index).sum().item() num_tokens = tokens.size(1) recv_buffer_seq_length = None if num_image_tokens > 0: diff --git a/examples/multimodal/text_generation_mistral_clip.sh b/examples/multimodal/text_generation_mistral_clip.sh index 1fa4bb4063..51710ca237 100755 --- a/examples/multimodal/text_generation_mistral_clip.sh +++ b/examples/multimodal/text_generation_mistral_clip.sh @@ -1,4 +1,5 @@ #!/bin/bash +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) export NCCL_IB_SL=1 export CUDA_DEVICE_MAX_CONNECTIONS=1 @@ -7,6 +8,7 @@ export NVTE_APPLY_QK_LAYER_SCALING=0 GROUNDTRUTH_PATH="placeholder" NUM_FRAMES=1 NUM_GPUS=4 +BATCH_SIZE=32 while [[ $# -gt 0 ]]; do case $1 in @@ -50,6 +52,11 @@ while [[ $# -gt 0 ]]; do shift shift ;; + -b|--batch-size) + BATCH_SIZE="$2" + shift + shift + ;; --num-gpus) NUM_GPUS="$2" shift @@ -69,7 +76,7 @@ END=0 for PARTITION_ID in $( eval echo {$START..$END} ) do - torchrun --nproc_per_node ${NUM_GPUS} examples/multimodal/run_text_generation.py \ + torchrun --nproc_per_node ${NUM_GPUS} ${SCRIPT_DIR}/run_text_generation.py \ --apply-layernorm-1p \ --attention-softmax-in-fp32 \ --use-flash-attn \ @@ -101,7 +108,7 @@ do --tokenizer-model ${TOKENIZER_PATH} \ --tokenizer-prompt-format mistral \ --bf16 \ - --micro-batch-size 32 \ + --micro-batch-size ${BATCH_SIZE} \ --seq-length 576 \ --decoder-seq-length 2048 \ --out-seq-length 12 \ diff --git a/tapes/main.tape b/tapes/main.tape index 45b6a739de..8b083fc245 100644 --- a/tapes/main.tape +++ b/tapes/main.tape @@ -25,6 +25,19 @@ global { pretrain_save_interval=500 pretrain_eval_interval=500 + # fine-tuning arguments + finetune_model_dir=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_ckpts/mistral_7b_instruct_sft + finetune_tensorboard=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_tbs/mistral_7b_instruct_sft + finetune_iters=5000 + finetune_bsz=128 + finetune_lr=1e-6 + finetune_lr_warmup=0.01 + finetune_save_interval=1000 + finetune_eval_interval=1000 + + # eval arguments + 
coco_dir=/lustre/fswork/projects/rech/qjm/ued79zb/coco/ + eval_bsz=32 # -- submitter arguments -- submitter=scslurm @@ -39,6 +52,18 @@ global { pretrain_time="10:00:00" pretrain_cpus=80 pretrain_gres="gpu:4" + + finetune_C="h100" + finetune_account="qjm@h100" + finetune_time="10:00:00" + finetune_cpus=80 + finetune_gres="gpu:4" + + eval_C="h100" + eval_account="qjm@h100" + eval_time="1:00:00" + eval_cpus=80 + eval_gres="gpu:4" } task PrepareModel @@ -81,7 +106,7 @@ task PrepareModel # Combine models echo "Combining language and vision models..." - bash $repo/examples/multimodal/combine_mistral_clip.sh \ + bash $repo/examples/multimodal/combine_lm_vision_checkpoints.sh \ mistral_mcore_dir \ clip_mcore_dir \ $initial_model @@ -177,15 +202,16 @@ task PretrainModel --log-interval 10 \ --eval-iters 10 \ --eval-interval ${eval_interval} \ - --tokenizer-type HuggingFaceTokenizer \ + --tokenizer-type MultimodalTokenizer \ --tokenizer-model ${tokenizer_model} \ + --tokenizer-prompt-format mistral \ --data-path ${repo}/examples/multimodal/pretrain_dataset.yaml \ --prompt-path ${repo}/examples/multimodal/manual_prompts.json \ --save-interval ${save_interval} \ --save ${model_dir} \ --load ${model_dir} \ --dataloader-save ${model_dir}/dataloader \ - --pretrained-checkpoint ${initial_model}/mistral_instruct_clip336_tp4_combined_mcore \ + --pretrained-checkpoint ${initial_model} \ --split 100,0,0 \ --clip-grad 1.0 \ --weight-decay 1e-2 \ @@ -206,9 +232,189 @@ task PretrainModel --language-model-type=mistral_7b \ --disable-vision-class-token \ --distributed-timeout-minutes 60 \ - --allow-missing-vision-projection-checkpoint + --allow-missing-vision-projection-checkpoint \ + --ckpt-format torch } -plan TrainLLaVA { - reach PretrainModel -} \ No newline at end of file + +task EvaluatePretrainedModel + < coco_dir=@ + < model_dir=@PretrainModel + > coco_results + :: repo=@ + :: eval_bsz=@ + :: tokenizer_model=$mistral_model + :: .submitter=@ + :: .C=$eval_C + :: .account=$eval_account + :: .time=$eval_time + :: .cpus=$eval_cpus + :: .gres=$eval_gres +{ + echo "Evaluating pretrained model..." 
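+    # Note (annotation, behaviour inferred from get_output_path in evaluate_coco.py,
+    # not part of the original patch): the generation step below writes one JSONL
+    # shard per data-parallel rank, e.g. coco_outputs-captioning-dprank=0-partition=0.jsonl,
+    # which evaluate_coco.py then globs, merges and scores with pycocoevalcap.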
+ # TODO: inline the bash script + bash $repo/examples/multimodal/text_generation_mistral_clip.sh \ + --input-image-path $coco_dir \ + --gt-path $coco_dir/coco_karpathy_test.json \ + --output-path coco_outputs \ + --model-path $model_dir \ + --tokenizer-path $tokenizer_model \ + --batch-size $eval_bsz \ + --task captioning + python $repo/examples/multimodal/evaluate_coco.py \ + --input-path coco_outputs \ + --groundtruth-path $coco_dir/coco_karpathy_test_gt.json \ + > $coco_results +} + +task FineTuneModel + < pretrained_dir=$model_dir@PretrainModel + > finetuned_dir + :: repo=@ + :: tokenizer_model=$mistral_model + :: train_iters=$finetune_iters + :: batch_size=$finetune_bsz + :: lr=$finetune_lr + :: lr_warmup_fraction=$finetune_lr_warmup + :: save_interval=$finetune_save_interval + :: eval_interval=$finetune_eval_interval + :: external_resume=true + :: external_model_dir=$finetune_model_dir + :: external_tensorboard=$finetune_tensorboard + :: .submitter=@ + :: .C=$finetune_C + :: .account=$finetune_account + :: .time=$finetune_time + :: .cpus=$finetune_cpus + :: .gres=$finetune_gres +{ + export NCCL_IB_SL=1 + export CUDA_DEVICE_MAX_CONNECTIONS=1 + + # Handle external directories similar to PretrainModel + if [ "$external_model_dir" != "" ]; then + if [ "$external_resume" == false ]; then + rm -rf $external_model_dir + fi + mkdir -p $external_model_dir + ln -sf $external_model_dir $finetuned_dir + fi + + if [ "$external_tensorboard" != "" ]; then + mkdir -p $external_tensorboard + tensorboard=$external_tensorboard + else + mkdir -p tensorboard + tensorboard=tensorboard + fi + + export NVTE_APPLY_QK_LAYER_SCALING=0 + export NVTE_ALLOW_NONDETERMINISTIC_ALGO=1 + + torchrun --nproc_per_node 4 $repo/examples/multimodal/train.py \ + --apply-layernorm-1p \ + --attention-softmax-in-fp32 \ + --use-checkpoint-args \ + --use-distributed-optimizer \ + --transformer-impl transformer_engine \ + --use-te \ + --normalization RMSNorm \ + --group-query-attention \ + --num-query-groups 8 \ + --no-masked-softmax-fusion \ + --num-workers 2 \ + --use-flash-attn \ + --untie-embeddings-and-output-weights \ + --disable-bias-linear \ + --position-embedding-type rope \ + --rotary-percent 1.0 \ + --rotary-base 1000000 \ + --swiglu \ + --attention-dropout 0.0 \ + --hidden-dropout 0.1 \ + --tensor-model-parallel-size 4 \ + --pipeline-model-parallel-size 1 \ + --num-layers 32 \ + --hidden-size 4096 \ + --num-attention-heads 32 \ + --seq-length 576 \ + --decoder-seq-length 2048 \ + --max-position-embeddings 4096 \ + --ffn-hidden-size 14336 \ + --train-iters ${train_iters} \ + --micro-batch-size 8 \ + --global-batch-size ${batch_size} \ + --lr-decay-iters ${train_iters} \ + --lr-warmup-fraction ${lr_warmup_fraction} \ + --lr ${lr} \ + --min-lr 1.0e-7 \ + --lr-decay-style cosine \ + --log-interval 10 \ + --eval-iters 10 \ + --eval-interval ${eval_interval} \ + --tokenizer-type MultimodalTokenizer \ + --tokenizer-model ${tokenizer_model} \ + --tokenizer-prompt-format mistral \ + --data-path ${repo}/examples/multimodal/sft_dataset.yaml \ + --prompt-path ${repo}/examples/multimodal/manual_prompts.json \ + --save-interval ${save_interval} \ + --save ${finetuned_dir} \ + --load ${finetuned_dir} \ + --dataloader-save ${finetuned_dir}/dataloader \ + --pretrained-checkpoint ${pretrained_dir} \ + --split 100,0,0 \ + --clip-grad 0.5 \ + --weight-decay 0.1 \ + --adam-beta1 0.9 \ + --adam-beta2 0.95 \ + --init-method-std 0.014 \ + --log-params-norm \ + --log-num-zeros-in-grad \ + --bf16 \ + --eod-mask-loss \ + --freeze-ViT \ + 
--patch-dim 14 \ + --img-h 336 \ + --img-w 336 \ + --dataloader-type external \ + --tensorboard-dir ${tensorboard} \ + --language-model-type=mistral_7b \ + --disable-vision-class-token \ + --distributed-timeout-minutes 60 \ + --ckpt-format torch +} + +#task EvaluateFinetunedModel +# < mmmu_dir=@ +# < finetuned_dir=@FineTuneModel +# > mmmu_results +# :: repo=@ +# :: tokenizer_model=$mistral_model +# :: .submitter=@ +# :: .C=$eval_C +# :: .account=$eval_account +# :: .time=$eval_time +# :: .cpus=$eval_cpus +# :: .gres=$eval_gres +#{ +# bash $repo/examples/multimodal/text_generation_mistral_clip.sh \ +# --input-image-path $mmmu_dir \ +# --gt-path $mmmu_dir/mmmu_test_gt.json \ +# --output-path mmmu_outputs \ +# --model-path $finetuned_dir \ +# --tokenizer-path $tokenizer_model \ +# --batch-size $eval_bsz \ +# --task mmmu +# python $repo/examples/multimodal/evaluate_mmmu.py \ +# --input-path "mmmu_outputs-MMMU-dprank\=0-partition\=0.jsonl" +#} + +#func ConvertToHF +# < model_dir=@ +# > hf_model_dir +# :: repo +# :: submitter +#{ +# python $repo/examples/multimodal/convert_to_hf.py \ +#} +# \ No newline at end of file From 46ea312ee09f69cf54ecd6fae6d29bdb433a58c9 Mon Sep 17 00:00:00 2001 From: Patrick Date: Wed, 27 Nov 2024 11:00:05 +0000 Subject: [PATCH 5/7] Update tapes and scripts to be compatible with tapes --- examples/multimodal/convert_to_hf.py | 4 +- examples/multimodal/evaluate_coco.py | 48 +++++- examples/multimodal/evaluate_mmmu.py | 4 +- examples/multimodal/evaluation_datasets.py | 3 +- .../text_generation_mistral_clip.sh | 9 +- tapes/main.tape | 146 +++++++++++++----- 6 files changed, 169 insertions(+), 45 deletions(-) diff --git a/examples/multimodal/convert_to_hf.py b/examples/multimodal/convert_to_hf.py index 1f6c2b7747..b7d90811de 100644 --- a/examples/multimodal/convert_to_hf.py +++ b/examples/multimodal/convert_to_hf.py @@ -29,6 +29,7 @@ def parse_args(): parser.add_argument("--original-vision-model-id", required=True) parser.add_argument("--source-tp-size", type=int, default=4) parser.add_argument("--target-params-dtype", type=str, default="float16") + parser.add_argument("--upload-to-hub", default=None) return parser.parse_args() @@ -210,7 +211,8 @@ def convert_mcore2hf(args): print(f"> Saving HF model to {args.hf_save_dir}") hf_model.save_pretrained(args.hf_save_dir) - + if args.upload_to_hub is not None: + hf_model.push_to_hub(args.upload_to_hub) def create_hf_config(original_text_model_id, original_vision_model_id, margs): """Create HF config from Megatron checkpoint""" diff --git a/examples/multimodal/evaluate_coco.py b/examples/multimodal/evaluate_coco.py index a717090c92..e8072ae01c 100644 --- a/examples/multimodal/evaluate_coco.py +++ b/examples/multimodal/evaluate_coco.py @@ -1,11 +1,57 @@ # Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. 
import argparse import json +import glob +import os +from dataclasses import dataclass -from evaluate_mmmu import get_input_output_paths +#from evaluate_mmmu import get_input_output_paths from pycocoevalcap.eval import COCOEvalCap from pycocotools.coco import COCO +@dataclass +class EvaluationConfig: + """Evaluation related configuration.""" + task: str + + temperature: float = 1.0 + top_p: float = 0.0 + top_k: int = 0 + + out_seq_length: int = 32 + + output_path: str = "" + + input_image_path: str = "" + gt_path: str = "" + + num_partitions: int = 1 + partition_id: int = 0 + num_samples_per_partition: int = 0 + + +def get_output_path(config, dp_rank): + """Generation output path.""" + return ( + f"{config.output_path}-{config.task}-dprank={dp_rank}-partition={config.partition_id}.jsonl" + ) + +def get_input_output_paths(input_path, task): + """Get all input files and an output path for a merged file.""" + # Single input file. + if os.path.exists(input_path): + input_file_paths = [input_path] + output_file_path = input_path.replace(".jsonl", "-merged.json") + # Select multiple partitions and dp ranks. + else: + cfg = EvaluationConfig(task=task, output_path=input_path, partition_id="*") + pattern = get_output_path(cfg, dp_rank="*") + input_file_paths = glob.glob(pattern) + + output_file_path = input_path + f"-{task}-merged.json" + + return input_file_paths, output_file_path + def convert_to_coco_format(input_path): """Convert input files to COCO compatible format.""" diff --git a/examples/multimodal/evaluate_mmmu.py b/examples/multimodal/evaluate_mmmu.py index 66118fa905..49ef9a4968 100644 --- a/examples/multimodal/evaluate_mmmu.py +++ b/examples/multimodal/evaluate_mmmu.py @@ -64,7 +64,7 @@ def mmmu_eval(input_path, groundtruth_path): output = subprocess.run( [ "python", - "examples/multimodal/MMMU/mmmu/main_eval_only.py", + "/linkhome/rech/genrce01/ued79zb/repos/Megatron-LM-pretrain/examples/multimodal/MMMU/mmmu/main_eval_only.py", "--output_path", result_file, "--answer_path", @@ -85,7 +85,7 @@ def mmmu_eval(input_path, groundtruth_path): def main(): """Run MMMU evaluation.""" # Using the validation groundtruth file from the MMMU repo by default. This assumes you have cloned the MMMU github repo here. - default_groundtruth_path = "examples/multimodal/MMMU/mmmu/answer_dict_val.json" + default_groundtruth_path = "/linkhome/rech/genrce01/ued79zb/repos/Megatron-LM-pretrain/examples/multimodal/MMMU/mmmu/answer_dict_val.json" parser = argparse.ArgumentParser() parser.add_argument("--input-path", type=str, required=True, help="Path to input file(s)") diff --git a/examples/multimodal/evaluation_datasets.py b/examples/multimodal/evaluation_datasets.py index 82e781ec0d..66fbcdc1bd 100644 --- a/examples/multimodal/evaluation_datasets.py +++ b/examples/multimodal/evaluation_datasets.py @@ -234,7 +234,8 @@ def __init__( dataset = dataset[lb:ub] # Using the LLaVA config from the MMMU repo. - config = load_yaml("examples/multimodal/MMMU/mmmu/configs/llava1.5.yaml") + # TODO: remove the hardcoded path. + config = load_yaml("/linkhome/rech/genrce01/ued79zb/repos/Megatron-LM-pretrain/examples/multimodal/MMMU/mmmu/configs/llava1.5.yaml") for k, v in config.items(): if isinstance(v, list): assert len(v) == 1, "only one value supported." 
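Both captioning loaders changed above (run_text_generation.py and evaluation_datasets.py) now expect the file passed via --gt-path to be a flat JSON list of records with "image" and "caption" fields, with paths resolved relative to --input-image-path, instead of the COCO-style {"annotations": [...]} file. The following is a minimal sketch of that assumed layout and of the loading logic it feeds; the file name matches coco_karpathy_test.json used in the tape, but the image names and captions are illustrative only.

# Assumed ground-truth layout, e.g. coco_karpathy_test.json:
#   [
#     {"image": "images/0001.jpg", "caption": "A dog running on a beach."},
#     {"image": "images/0002.jpg", "caption": "Two cups of coffee on a table."}
#   ]
import json
from collections import defaultdict

def load_captioning_gt(gt_path, input_image_path):
    """Mirrors the dataset __init__ logic in the hunks above."""
    gts = json.load(open(gt_path))
    answers = defaultdict(list)
    image_files = []
    for gt in gts:
        image_files.append(input_image_path + "/" + gt["image"])
        answers[gt["image"]] = gt["caption"]  # stores a single caption string per image key
    return sorted(image_files), answers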
diff --git a/examples/multimodal/text_generation_mistral_clip.sh b/examples/multimodal/text_generation_mistral_clip.sh index 51710ca237..6add85c92b 100755 --- a/examples/multimodal/text_generation_mistral_clip.sh +++ b/examples/multimodal/text_generation_mistral_clip.sh @@ -9,7 +9,7 @@ GROUNDTRUTH_PATH="placeholder" NUM_FRAMES=1 NUM_GPUS=4 BATCH_SIZE=32 - +PROMPT_FORMAT="mistral" while [[ $# -gt 0 ]]; do case $1 in --input-image-path) @@ -42,6 +42,11 @@ while [[ $# -gt 0 ]]; do shift shift ;; + --tokenizer-prompt-format) + PROMPT_FORMAT="$2" + shift + shift + ;; --task) TASK="$2" shift @@ -106,7 +111,7 @@ do --load ${MODEL_PATH} \ --tokenizer-type MultimodalTokenizer \ --tokenizer-model ${TOKENIZER_PATH} \ - --tokenizer-prompt-format mistral \ + --tokenizer-prompt-format ${PROMPT_FORMAT} \ --bf16 \ --micro-batch-size ${BATCH_SIZE} \ --seq-length 576 \ diff --git a/tapes/main.tape b/tapes/main.tape index 8b083fc245..7456d4257e 100644 --- a/tapes/main.tape +++ b/tapes/main.tape @@ -11,7 +11,16 @@ global { # multimodal model parameters # (base lm, tp, etc...) clip_original_dir=/lustre/fswork/projects/rech/qjm/ued79zb/clip_model_og/ - mistral_model="mistralai/Mistral-7B-Instruct-v0.3" + mistral_model=( + TextModel: + mistral="mistralai/Mistral-7B-Instruct-v0.3" + tower="Unbabel/TowerInstruct-Mistral-7B-v0.2" + ) + prompt_format=( + TextModel: + mistral="mistral" + tower="chatml" + ) tp=4 pp=1 @@ -37,8 +46,12 @@ global { # eval arguments coco_dir=/lustre/fswork/projects/rech/qjm/ued79zb/coco/ + mmmu_dir=/dev/null eval_bsz=32 + # convert arguments + upload_id="patricksf/mistral-7b-clip-prt" + # -- submitter arguments -- submitter=scslurm @@ -64,6 +77,11 @@ global { eval_time="1:00:00" eval_cpus=80 eval_gres="gpu:4" + + convert_account="qjm@cpu" + convert_time="1:00:00" + convert_cpus=4 + convert_partition="prepost" } task PrepareModel @@ -121,6 +139,7 @@ task PretrainModel > model_dir :: repo=@ :: tokenizer_model=$mistral_model + :: prompt_format=@ :: train_iters=$pretrain_iters :: batch_size=$pretrain_bsz :: lr=$pretrain_lr @@ -204,7 +223,7 @@ task PretrainModel --eval-interval ${eval_interval} \ --tokenizer-type MultimodalTokenizer \ --tokenizer-model ${tokenizer_model} \ - --tokenizer-prompt-format mistral \ + --tokenizer-prompt-format ${prompt_format} \ --data-path ${repo}/examples/multimodal/pretrain_dataset.yaml \ --prompt-path ${repo}/examples/multimodal/manual_prompts.json \ --save-interval ${save_interval} \ @@ -244,6 +263,7 @@ task EvaluatePretrainedModel :: repo=@ :: eval_bsz=@ :: tokenizer_model=$mistral_model + :: prompt_format=@ :: .submitter=@ :: .C=$eval_C :: .account=$eval_account @@ -259,12 +279,13 @@ task EvaluatePretrainedModel --output-path coco_outputs \ --model-path $model_dir \ --tokenizer-path $tokenizer_model \ + --tokenizer-prompt-format ${prompt_format} \ --batch-size $eval_bsz \ --task captioning python $repo/examples/multimodal/evaluate_coco.py \ --input-path coco_outputs \ --groundtruth-path $coco_dir/coco_karpathy_test_gt.json \ - > $coco_results + | tee $coco_results } task FineTuneModel @@ -272,6 +293,7 @@ task FineTuneModel > finetuned_dir :: repo=@ :: tokenizer_model=$mistral_model + :: prompt_format=@ :: train_iters=$finetune_iters :: batch_size=$finetune_bsz :: lr=$finetune_lr @@ -354,7 +376,7 @@ task FineTuneModel --eval-interval ${eval_interval} \ --tokenizer-type MultimodalTokenizer \ --tokenizer-model ${tokenizer_model} \ - --tokenizer-prompt-format mistral \ + --tokenizer-prompt-format ${prompt_format} \ --data-path 
${repo}/examples/multimodal/sft_dataset.yaml \ --prompt-path ${repo}/examples/multimodal/manual_prompts.json \ --save-interval ${save_interval} \ @@ -384,37 +406,85 @@ task FineTuneModel --ckpt-format torch } -#task EvaluateFinetunedModel -# < mmmu_dir=@ -# < finetuned_dir=@FineTuneModel -# > mmmu_results -# :: repo=@ -# :: tokenizer_model=$mistral_model -# :: .submitter=@ -# :: .C=$eval_C -# :: .account=$eval_account -# :: .time=$eval_time -# :: .cpus=$eval_cpus -# :: .gres=$eval_gres -#{ -# bash $repo/examples/multimodal/text_generation_mistral_clip.sh \ -# --input-image-path $mmmu_dir \ -# --gt-path $mmmu_dir/mmmu_test_gt.json \ -# --output-path mmmu_outputs \ -# --model-path $finetuned_dir \ -# --tokenizer-path $tokenizer_model \ -# --batch-size $eval_bsz \ -# --task mmmu -# python $repo/examples/multimodal/evaluate_mmmu.py \ -# --input-path "mmmu_outputs-MMMU-dprank\=0-partition\=0.jsonl" -#} - -#func ConvertToHF -# < model_dir=@ -# > hf_model_dir -# :: repo -# :: submitter -#{ -# python $repo/examples/multimodal/convert_to_hf.py \ -#} -# \ No newline at end of file +task EvaluateFinetunedModel + < mmmu_dir=@ + < finetuned_dir=@FineTuneModel + > mmmu_results + :: repo=@ + :: eval_bsz=@ + :: tokenizer_model=$mistral_model + :: prompt_format=@ + :: .submitter=@ + :: .C=$eval_C + :: .account=$eval_account + :: .time=$eval_time + :: .cpus=$eval_cpus + :: .gres=$eval_gres +{ + export PYTHONPATH="$PYTHONPATH:$repo/examples/multimodal" + bash $repo/examples/multimodal/text_generation_mistral_clip.sh \ + --input-image-path none \ + --output-path mmmu_outputs \ + --model-path $finetuned_dir \ + --tokenizer-path $tokenizer_model \ + --tokenizer-prompt-format ${prompt_format} \ + --batch-size $eval_bsz \ + --task MMMU + python $repo/examples/multimodal/evaluate_mmmu.py \ + --input-path mmmu_outputs \ + | tee $mmmu_results +} + +func ConvertToHF + < model_dir + > hf_model_dir + :: repo + :: mistral_model + :: upload_id +{ + python $repo/examples/multimodal/convert_to_hf.py \ + --mcore-load-dir $model_dir \ + --hf-save-dir $hf_model_dir \ + --original-text-model-id $mistral_model \ + --original-vision-model-id openai/clip-vit-large-patch14-336 \ + --upload-to-hub $upload_id +} + +task ConvertPretrainedModel calls ConvertToHF + < model_dir=@PretrainModel + > hf_model_dir + :: repo=@ + :: mistral_model=@ + :: upload_id=@ + :: .submitter=@ + :: .account=$convert_account + :: .time=$convert_time + :: .cpus=$convert_cpus + :: .partition=$convert_partition + +task ConvertFinetunedModel calls ConvertToHF + < model_dir=$finetuned_dir@FineTuneModel + > hf_model_dir + :: repo=@ + :: mistral_model=@ + :: .submitter=@ + :: .account=$convert_account + :: .time=$convert_time + :: .cpus=$convert_cpus + :: .partition=$convert_partition + + +plan TrainPipelineMistral { + reach EvaluatePretrainedModel + reach EvaluateFinetunedModel +} + +plan TrainPipelineTower { + reach EvaluatePretrainedModel via (TextModel: tower) + reach EvaluateFinetunedModel via (TextModel: tower) +} + +plan ConvertModels { + reach ConvertPretrainedModel + #reach ConvertFinetunedModel +} From 04c54e788f15cc5d437864cafa18b7f60113348d Mon Sep 17 00:00:00 2001 From: Patrick Date: Mon, 6 Jan 2025 11:08:25 +0000 Subject: [PATCH 6/7] update overall codebase to support nvlm-style models --- examples/multimodal/check_num_samples.py | 23 + examples/multimodal/config.py | 42 +- examples/multimodal/convert_to_hf.py | 192 +++-- examples/multimodal/evaluation_datasets.py | 4 +- .../hf_modeling_files/configuration_nvlm_d.py | 96 +++ 
.../image_processing_nvlm_d.py | 209 ++++++ .../hf_modeling_files/modeling_nvlm_d.py | 260 +++++++ .../hf_modeling_files/processing_nvlm_d.py | 168 +++++ examples/multimodal/model.py | 9 +- examples/multimodal/pangea_instruct.yaml | 15 + examples/multimodal/pixmo_caps.yaml | 15 + examples/multimodal/run_text_generation.py | 8 + .../core/models/multimodal/llava_model.py | 47 +- megatron/core/models/vision/clip_vit_model.py | 4 +- tapes/main.tape | 685 +++++++++++++----- tapes/og_datasets.tconf | 210 ++++++ tapes/pixcaps_pangea.tconf | 145 ++++ tools/checkpoint/loader_llama_mistral.py | 5 +- 18 files changed, 1896 insertions(+), 241 deletions(-) create mode 100644 examples/multimodal/check_num_samples.py create mode 100644 examples/multimodal/hf_modeling_files/configuration_nvlm_d.py create mode 100644 examples/multimodal/hf_modeling_files/image_processing_nvlm_d.py create mode 100644 examples/multimodal/hf_modeling_files/modeling_nvlm_d.py create mode 100644 examples/multimodal/hf_modeling_files/processing_nvlm_d.py create mode 100644 examples/multimodal/pangea_instruct.yaml create mode 100644 examples/multimodal/pixmo_caps.yaml create mode 100644 tapes/og_datasets.tconf create mode 100644 tapes/pixcaps_pangea.tconf diff --git a/examples/multimodal/check_num_samples.py b/examples/multimodal/check_num_samples.py new file mode 100644 index 0000000000..6e9286ea16 --- /dev/null +++ b/examples/multimodal/check_num_samples.py @@ -0,0 +1,23 @@ +import argparse + +from megatron.energon import get_train_dataset, get_loader, WorkerConfig + +def read_args(): + parser = argparse.ArgumentParser() + parser.add_argument("dataset_path", type=str, required=True) + return parser.parse_args() + +if __name__ == "__main__": + args = read_args() + simple_worker_config = WorkerConfig(rank=0, world_size=1, num_workers=1) + + train_ds = get_train_dataset( + args.dataset_path, + batch_size=1, + shuffle_buffer_size=None, + max_samples_per_sequence=None, + worker_config=simple_worker_config, + ) + + # print number of samples + print(f"Number of samples: {len(train_ds)}") \ No newline at end of file diff --git a/examples/multimodal/config.py b/examples/multimodal/config.py index f737fd31f7..511a63950b 100644 --- a/examples/multimodal/config.py +++ b/examples/multimodal/config.py @@ -87,6 +87,33 @@ def get_language_model_config(config): config.apply_rope_fusion = False config.attention_softmax_in_fp32 = True config.ffn_hidden_size = 29568 + elif config.language_model_type == "qwen2.5_7b": + config.activation_func = torch.nn.functional.silu + config.add_bias_linear = False + config.add_qkv_bias = True + config.bias_activation_fusion = False + config.gated_linear_unit = True + config.apply_query_key_layer_scaling = False + config.layernorm_zero_centered_gamma = ( + False # Zero centered gamma not supported for RMSNorm + ) + config.bias_dropout_fusion = False + config.apply_rope_fusion = False + config.attention_softmax_in_fp32 = True + config.ffn_hidden_size = 18944 + elif config.language_model_type == "eurollm_9b": + config.activation_func = torch.nn.functional.silu + config.add_bias_linear = False + config.bias_activation_fusion = False + config.gated_linear_unit = True + config.apply_query_key_layer_scaling = False + config.layernorm_zero_centered_gamma = ( + False # Zero centered gamma not supported for RMSNorm + ) + config.bias_dropout_fusion = False + config.apply_rope_fusion = False + config.attention_softmax_in_fp32 = True + config.ffn_hidden_size = 12288 else: raise ValueError(f"unknown language model type 
{config.language_model_type}") @@ -177,10 +204,6 @@ def get_vision_projection_config(config, hidden_size): elif config.language_model_type == "llama3_8b": config.ffn_hidden_size = 14336 config.activation_func = torch.nn.functional.gelu - elif config.language_model_type == "mistral_7b": - # TODO: check what needs to be done for other models - config.ffn_hidden_size = hidden_size # This was changed to make it compatible with HF's LLava - config.activation_func = torch.nn.functional.gelu elif config.language_model_type == "yi-34b": config.ffn_hidden_size = 20480 config.normalization = 'LayerNorm' @@ -189,6 +212,17 @@ def get_vision_projection_config(config, hidden_size): config.ffn_hidden_size = 29568 config.normalization = 'LayerNorm' config.activation_func = torch.nn.functional.gelu + # The following two have been changed to make it compatible with HF's LLava + elif config.language_model_type == "mistral_7b": + # TODO: check what needs to be done for other models + config.ffn_hidden_size = hidden_size # This was changed to make it compatible with HF's LLava + config.activation_func = torch.nn.functional.gelu + elif config.language_model_type == "qwen2.5_7b": + config.ffn_hidden_size = hidden_size # This was changed to make it compatible with HF's LLava + config.activation_func = torch.nn.functional.gelu + elif config.language_model_type == "eurollm_9b": + config.ffn_hidden_size = hidden_size # This was changed to make it compatible with HF's LLava + config.activation_func = torch.nn.functional.gelu else: raise ValueError(f"unknown language model type {config.language_model_type}") diff --git a/examples/multimodal/convert_to_hf.py b/examples/multimodal/convert_to_hf.py index b7d90811de..f9c89d77a5 100644 --- a/examples/multimodal/convert_to_hf.py +++ b/examples/multimodal/convert_to_hf.py @@ -17,8 +17,9 @@ ) from transformers import LlavaConfig -from model import model_provider +from huggingface_hub import HfApi +from model import model_provider def parse_args(): @@ -28,7 +29,8 @@ def parse_args(): parser.add_argument("--original-text-model-id", required=True) parser.add_argument("--original-vision-model-id", required=True) parser.add_argument("--source-tp-size", type=int, default=4) - parser.add_argument("--target-params-dtype", type=str, default="float16") + parser.add_argument("--target-params-dtype", type=str, default="bfloat16") + parser.add_argument("--hf-model-type", type=str, default="llava", choices=["llava", "nvlm_d"]) parser.add_argument("--upload-to-hub", default=None) return parser.parse_args() @@ -37,6 +39,7 @@ def main(): args = parse_args() convert_mcore2hf(args) + def get_checkpoint_sub_dir_name(*, tp_rank): """Get the checkpoint subdirectory name based on parallel ranks.""" sub_dir_name = f"mp_rank_{tp_rank:02d}" @@ -47,21 +50,21 @@ def gather_state_dict(mcore_sds, *, tp_size): """ Gather all tensor parallel shards into a single state dict. Only concatenates parameters that are actually sharded. 
- + Args: mcore_sds: List of state dicts [tp_rank] tp_size: Number of tensor parallel shards - + Returns: Consolidated state dict with all shards merged """ consolidated_sd = {} - + # Get all unique keys from all shards all_keys = set() for tp_rank in range(tp_size): all_keys.update(mcore_sds[tp_rank].keys()) - + # Define patterns for sharded parameters sharded_output_dim_patterns = [ "self_attention.linear_qkv.weight", @@ -71,28 +74,28 @@ def gather_state_dict(mcore_sds, *, tp_size): "encoder.linear_fc1.weight", "encoder.linear_fc1.bias", "word_embeddings.weight", - "output_layer.weight" - ] - + "output_layer.weight", + ] + sharded_input_dim_patterns = [ "self_attention.linear_proj.weight", "mlp.linear_fc2.weight", "encoder.linear_fc2.weight", ] - + # For each key, try to gather the shards for key in all_keys: # Skip args and other non-tensor keys if not isinstance(mcore_sds[0].get(key, None), torch.Tensor): consolidated_sd[key] = mcore_sds[0][key] continue - + # Collect shards if they exist shards = [] for tp_rank in range(tp_size): if key in mcore_sds[tp_rank]: shards.append(mcore_sds[tp_rank][key]) - + # If only one shard exists, use it directly if len(shards) == 1: consolidated_sd[key] = shards[0] @@ -114,16 +117,15 @@ def gather_state_dict(mcore_sds, *, tp_size): gate_shard = shard[:shard_size] up_shard = shard[shard_size:] shards.append((gate_shard, up_shard)) - + # Concatenate gate and up parts separately gate_shards = [s[0] for s in shards] up_shards = [s[1] for s in shards] - + # Then concatenate them in the correct order - consolidated_sd[key] = torch.cat([ - torch.cat(gate_shards, dim=0), - torch.cat(up_shards, dim=0) - ], dim=0) + consolidated_sd[key] = torch.cat( + [torch.cat(gate_shards, dim=0), torch.cat(up_shards, dim=0)], dim=0 + ) elif is_sharded_output: consolidated_sd[key] = torch.cat(shards, dim=0) elif is_sharded_input: @@ -131,7 +133,7 @@ def gather_state_dict(mcore_sds, *, tp_size): else: # For non-sharded parameters, just take the first shard consolidated_sd[key] = shards[0] - + return consolidated_sd @@ -156,7 +158,7 @@ def convert_mcore2hf(args): mcore_args = [] for tp_rank in range(tp_size): print(f" > Loading tp_rank={tp_rank}") - sub_dir_name = get_checkpoint_sub_dir_name(tp_rank=tp_rank,) + sub_dir_name = get_checkpoint_sub_dir_name(tp_rank=tp_rank) checkpoint_path = f"{iter_dir}/{sub_dir_name}/model_optim_rng.pt" checkpoint = torch.load(checkpoint_path, map_location='cpu') mcore_args.append(checkpoint['args']) @@ -168,19 +170,22 @@ def convert_mcore2hf(args): margs = mcore_args[0] print(f"> Loaded args from checkpoint: {margs}") - #import pdb; pdb.set_trace() + # import pdb; pdb.set_trace() # Consolidate sharded state dict print(f"> Consolidating sharded checkpoints") checkpoint = gather_state_dict(mcore_sds, tp_size=tp_size) - # import pdb; pdb.set_trace() # create the HF config - hf_config = create_hf_config(args.original_text_model_id, args.original_vision_model_id, margs) + hf_config = create_hf_config( + args.original_text_model_id, args.original_vision_model_id, margs, args.hf_model_type + ) # create the tokenizer and processor - processor = create_hf_processor(hf_config, args.original_text_model_id, args.original_vision_model_id) + processor = create_hf_processor( + hf_config, args.original_text_model_id, args.original_vision_model_id, args.hf_model_type + ) processor.save_pretrained(args.hf_save_dir) # Convert the state dict @@ -203,18 +208,70 @@ def convert_mcore2hf(args): # create the HF model print(f"> Loading HF model and converted weights") - 
hf_model = LlavaForConditionalGeneration(config=hf_config) + if args.hf_model_type == "llava": + hf_model = LlavaForConditionalGeneration(config=hf_config) + elif args.hf_model_type == "nvlm_d": + from hf_modeling_files.modeling_nvlm_d import NVLM_D_Model + + hf_model = NVLM_D_Model(config=hf_config) + else: + raise ValueError(f"Unsupported model type: {args.hf_model_type}") + + # cast the model to the target dtype + hf_model.to(dtype=dtype) + + # TODO: for now, if megatron already expanded the embeddings, + # we re-shorten them to the original vocab size + # and then extend them again after loading the weights + if ( + args.hf_model_type == "llava" + and hf_state_dict["language_model.model.embed_tokens.weight"].size(0) + != hf_config.text_config.vocab_size + ): + # shorten the embeddings and output layer + hf_state_dict["language_model.model.embed_tokens.weight"] = hf_state_dict[ + "language_model.model.embed_tokens.weight" + ][: hf_config.text_config.vocab_size] + hf_state_dict["language_model.lm_head.weight"] = hf_state_dict[ + "language_model.lm_head.weight" + ][: hf_config.text_config.vocab_size] + hf_model.load_state_dict(hf_state_dict, strict=True) # extend the embeddings - extend_embeddings(hf_model, hf_config) + # TODO: double check why only llava model has this + if args.hf_model_type == "llava": + extend_embeddings(hf_model, hf_config) print(f"> Saving HF model to {args.hf_save_dir}") hf_model.save_pretrained(args.hf_save_dir) if args.upload_to_hub is not None: - hf_model.push_to_hub(args.upload_to_hub) + # TODO: still need to add the auto-map to the config files + # push everything to the hub + hf_model.push_to_hub(args.upload_to_hub, private=True) + processor.push_to_hub(args.upload_to_hub) + if args.hf_model_type == "nvlm_d": + # push the hf_modeling_files folder to the hub + pi = HfApi() + # get directory of the current script + script_dir = os.path.dirname(os.path.abspath(__file__)) + modeling_files_dir = f"{script_dir}/hf_modeling_files" + for file in os.listdir(modeling_files_dir): + # only send .py files + if not file.endswith(".py"): + continue + + pi.upload_file( + path_or_fileobj=f"{modeling_files_dir}/{file}", + repo_id=args.upload_to_hub, + path_in_repo=file, + repo_type="model", + ) + -def create_hf_config(original_text_model_id, original_vision_model_id, margs): +def create_hf_config( + original_text_model_id, original_vision_model_id, margs, hf_model_type="llava" +): """Create HF config from Megatron checkpoint""" # Extract model args from checkpoint assert margs.transformer_impl == "transformer_engine" @@ -238,16 +295,37 @@ def create_hf_config(original_text_model_id, original_vision_model_id, margs): text_config = AutoConfig.from_pretrained(original_text_model_id) # Create final LLaVA config combining both - hf_config = LlavaConfig( - vision_config=vision_config, - text_config=text_config, - vision_feature_layer=-1, # megatron uses the last layer - # Add any other LLaVA specific configs here - ) + if hf_model_type == "llava": + hf_config = LlavaConfig( + vision_config=vision_config, + text_config=text_config, + vision_feature_layer=-1, # megatron uses the last layer + # Add any other LLaVA specific configs here + ) + elif hf_model_type == "nvlm_d": + from hf_modeling_files.configuration_nvlm_d import NVLM_D_Config + + hf_config = NVLM_D_Config( + vision_config=vision_config, + text_config=text_config, + # Add any other NVLM-D specific configs here + ) + else: + raise ValueError(f"Unsupported model type: {hf_model_type}") + return hf_config -def 
create_hf_processor(hf_config, text_model_id, vision_model_id): +def create_hf_processor(hf_config, text_model_id, vision_model_id, hf_model_type="llava"): + if hf_model_type == "nvlm_d": + from hf_modeling_files.image_processing_nvlm_d import NVLM_D_ImageProcessor + from hf_modeling_files.processing_nvlm_d import NVLM_D_Processor + + image_processor = NVLM_D_ImageProcessor() + tokenizer = AutoTokenizer.from_pretrained(text_model_id) + processor = NVLM_D_Processor(image_processor=image_processor, tokenizer=tokenizer) + return processor + tokenizer = AutoTokenizer.from_pretrained(text_model_id) tokenizer.add_tokens(AddedToken("", special=True, normalized=False), special_tokens=True) tokenizer.add_special_tokens({"pad_token": ""}) @@ -256,16 +334,19 @@ def create_hf_processor(hf_config, text_model_id, vision_model_id): try: # WARNING: this was used for custom version transformer - # that implemented that llave image processing pipeline + # that implemented that llave image processing pipeline # see https://github.com/huggingface/transformers/pull/33191 # for a status on it being merged to the official repo from transformers.models.llava.image_processing_llava import LlavaImageProcessor + image_processor = LlavaImageProcessor() except ImportError: print("> WARNING: could not import LlavaImageProcessor, using AutoImageProcessor instead") - print("> This might lead to performance degradation due to slightly different image pre-processing") + print( + "> This might lead to performance degradation due to slightly different image pre-processing" + ) image_processor = AutoImageProcessor.from_pretrained(vision_model_id) - + processor = LlavaProcessor(tokenizer=tokenizer, image_processor=image_processor) return processor @@ -414,12 +495,6 @@ def convert_mcore2hf_language_model(mcore_sd, margs): # Attention weights qkv_weight = mcore_sd[f"{mcore_prefix}.self_attention.linear_qkv.weight"] - # Ensure the shape is divisible by 3 - - # load transformer llava and do the same - # from transformers import LlavaForConditionalGeneration - # llava_model = LlavaForConditionalGeneration.from_pretrained("llava-hf/llava-1.5-7b-hf") - hidden_size = qkv_weight.shape[1] num_kv_heads = margs.num_query_groups num_heads = margs.num_attention_heads @@ -460,6 +535,31 @@ def convert_mcore2hf_language_model(mcore_sd, margs): } ) + # some models have bias in the attention + if f"{mcore_prefix}.self_attention.linear_qkv.bias" in mcore_sd: + # split the bias into q, k, v, similar to the weights + qkv_bias = mcore_sd[f"{mcore_prefix}.self_attention.linear_qkv.bias"] + qkv_bias = qkv_bias.reshape(num_kv_heads, (num_queries_per_group + 2) * head_dim) + + # Split into q, k, v components + q_bias = qkv_bias[:, : num_queries_per_group * head_dim] + k_bias = qkv_bias[ + :, num_queries_per_group * head_dim : (num_queries_per_group + 1) * head_dim + ] + v_bias = qkv_bias[:, (num_queries_per_group + 1) * head_dim :] + + # Reshape to match HuggingFace format + q_bias = q_bias.reshape(-1) # Flatten to 1D + k_bias = k_bias.reshape(-1) + v_bias = v_bias.reshape(-1) + state_dict.update( + { + f"{hf_prefix}.self_attn.q_proj.bias": q_bias, + f"{hf_prefix}.self_attn.k_proj.bias": k_bias, + f"{hf_prefix}.self_attn.v_proj.bias": v_bias, + } + ) + # MLP weights # Note: In LLaMA, gate_proj and up_proj together form what was fc1 in the original architecture fc1_weight = mcore_sd[f"{mcore_prefix}.mlp.linear_fc1.weight"] @@ -497,11 +597,12 @@ def convert_mcore2hf_vision_projection(mcore_sd, margs): return state_dict + def extend_embeddings(hf_model, 
hf_config): # Initialize new embeddings for additional tokens # We use the average of the pre-expansion embeddings as the mean # and a small covariance matrix to ensure the new embeddings are close to the old ones - # adapted from + # adapted from # https://github.com/huggingface/transformers/blob/bf42c3bd4b088fd9df1086e63d47a8e33048e5e1/src/transformers/models/llava/convert_llava_weights_to_hf.py#L100 # TODO: it seems this might not be needed anymore in the new versions of HF?? # double check @@ -539,6 +640,5 @@ def extend_embeddings(hf_model, hf_config): ) - if __name__ == "__main__": main() diff --git a/examples/multimodal/evaluation_datasets.py b/examples/multimodal/evaluation_datasets.py index 66fbcdc1bd..32c112e18e 100644 --- a/examples/multimodal/evaluation_datasets.py +++ b/examples/multimodal/evaluation_datasets.py @@ -662,7 +662,9 @@ def __len__(self): return len(self._gt) def __getitem__(self, idx): - img_path = os.path.join(self._input_image_path, self._gt[idx]['image']) + # HACK: img_path assumes a certain structure, we want more flexibility + suffix_path = self._gt[idx]['image'].replace("data/ai2diagram/AI2D_TEST/", "") + img_path = os.path.join(self._input_image_path, suffix_path) if self._no_mask: img_path.replace("AI2D_TEST", "AI2D_TEST_NO_MASK_IMAGES") diff --git a/examples/multimodal/hf_modeling_files/configuration_nvlm_d.py b/examples/multimodal/hf_modeling_files/configuration_nvlm_d.py new file mode 100644 index 0000000000..b8b6f5f7b3 --- /dev/null +++ b/examples/multimodal/hf_modeling_files/configuration_nvlm_d.py @@ -0,0 +1,96 @@ +# -------------------------------------------------------- +# Adapted from https://huggingface.co/nvidia/NVLM-D-72B under MIT License +# -------------------------------------------------------- + +import copy + +from transformers import CONFIG_MAPPING +from transformers.configuration_utils import PretrainedConfig +from transformers.utils import logging + +logger = logging.get_logger(__name__) + + +class NVLM_D_Config(PretrainedConfig): + model_type = 'NVLM_D' + is_composition = True + auto_map = { + "AutoModel": "modeling_nvlm_d.NVLM_D_Model", + "AutoModelForCausalLM": "modeling_nvlm_d.NVLM_D_Model", + "AutoModelForConditionalGeneration": "modeling_nvlm_d.NVLM_D_Model", + } + + def __init__( + self, + vision_config=None, + text_config=None, + use_backbone_lora=0, + use_llm_lora=0, + select_layer=-1, + projector_hidden_act='gelu', + force_image_size=None, + downsample_ratio=0.5, + template="chatml", + dynamic_image_size=True, + use_thumbnail=True, + ps_version='v2', + min_dynamic_patch=1, + max_dynamic_patch=6, + **kwargs + ): + super().__init__(**kwargs) + + # if configs are dicts, convert them to config objects + if isinstance(vision_config, dict): + vision_config = CONFIG_MAPPING[vision_config['model_type']](**vision_config) + if isinstance(text_config, dict): + text_config = CONFIG_MAPPING[text_config['model_type']](**text_config) + + # then use the provided vision and text configs + self.vision_config = vision_config + self.text_config = text_config + + # Assign configuration values + self.use_backbone_lora = use_backbone_lora + self.use_llm_lora = use_llm_lora + self.select_layer = select_layer + self.projector_hidden_act = projector_hidden_act + self.force_image_size = force_image_size + self.downsample_ratio = downsample_ratio + self.template = template + self.dynamic_image_size = dynamic_image_size + self.use_thumbnail = use_thumbnail + self.ps_version = ps_version # Pixel shuffle version + self.min_dynamic_patch = min_dynamic_patch 
+ self.max_dynamic_patch = max_dynamic_patch + + # Log important parameters + logger.info(f'vision_select_layer: {self.select_layer}') + logger.info(f'ps_version: {self.ps_version}') + logger.info(f'min_dynamic_patch: {self.min_dynamic_patch}') + logger.info(f'max_dynamic_patch: {self.max_dynamic_patch}') + + def to_dict(self): + """ + Serializes this instance to a Python dictionary. Overrides the default `PretrainedConfig.to_dict`. + + Returns: + Dict[str, Any]: Dictionary of all the attributes that make up this configuration instance. + """ + output = copy.deepcopy(self.__dict__) + output['vision_config'] = self.vision_config.to_dict() + output['text_config'] = self.text_config.to_dict() + output['model_type'] = self.model_type + output['use_backbone_lora'] = self.use_backbone_lora + output['use_llm_lora'] = self.use_llm_lora + output['select_layer'] = self.select_layer + output['force_image_size'] = self.force_image_size + output['downsample_ratio'] = self.downsample_ratio + output['template'] = self.template + output['dynamic_image_size'] = self.dynamic_image_size + output['use_thumbnail'] = self.use_thumbnail + output['ps_version'] = self.ps_version + output['min_dynamic_patch'] = self.min_dynamic_patch + output['max_dynamic_patch'] = self.max_dynamic_patch + + return output \ No newline at end of file diff --git a/examples/multimodal/hf_modeling_files/image_processing_nvlm_d.py b/examples/multimodal/hf_modeling_files/image_processing_nvlm_d.py new file mode 100644 index 0000000000..8939c355ef --- /dev/null +++ b/examples/multimodal/hf_modeling_files/image_processing_nvlm_d.py @@ -0,0 +1,209 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for NVLM-D.""" + +from typing import Dict, List, Optional, Union, Tuple, Set +import torch +import numpy as np +from PIL import Image + +import torchvision.transforms as T +from torchvision.transforms.functional import InterpolationMode + +from transformers.image_processing_utils import BaseImageProcessor, BatchFeature +from transformers.image_utils import ImageInput, valid_images +from transformers.utils import TensorType, logging + +logger = logging.get_logger(__name__) + +CLIP_MEAN = [0.48145466, 0.4578275, 0.40821073] +CLIP_STD = [0.26862954, 0.26130258, 0.27577711] + +class NVLM_D_ImageProcessor(BaseImageProcessor): + """ + Image processor for NVLM-D. Handles dynamic high-resolution image processing with tile-tagging. 
+ """ + + model_input_names = ["pixel_values"] + auto_map = { + "AutoImageProcessor": "image_processing_nvlm_d.NVLM_D_ImageProcessor", + } + + def __init__( + self, + image_size: int = 336, + max_num: int = 6, + min_num: int = 1, + use_thumbnail: bool = True, + mean: List[float] = CLIP_MEAN, + std: List[float] = CLIP_STD, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.image_size = image_size + self.max_num = max_num + self.min_num = min_num + self.use_thumbnail = use_thumbnail + self.mean = mean + self.std = std + + def build_transform(self, input_size: int) -> T.Compose: + """Build the transformation pipeline.""" + # TODO: only CLIP ordering for now + return T.Compose([ + T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC), + T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img), + T.ToTensor(), + T.Normalize(mean=self.mean, std=self.std) + ]) + + def find_closest_aspect_ratio( + self, + aspect_ratio: float, + target_ratios: Set[Tuple[int, int]], + width: int, + height: int, + image_size: int + ) -> Tuple[int, int]: + """Find the closest aspect ratio from target ratios.""" + best_ratio_diff = float('inf') + best_ratio = (1, 1) + area = width * height + + for ratio in target_ratios: + target_aspect_ratio = ratio[0] / ratio[1] + ratio_diff = abs(aspect_ratio - target_aspect_ratio) + if ratio_diff < best_ratio_diff: + best_ratio_diff = ratio_diff + best_ratio = ratio + elif ratio_diff == best_ratio_diff: + if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]: + best_ratio = ratio + return best_ratio + + def dynamic_preprocess( + self, + image: Image.Image, + min_num: int = None, + max_num: int = None, + image_size: int = None, + use_thumbnail: bool = None + ) -> List[Image.Image]: + """Process image into tiles based on aspect ratio.""" + min_num = min_num if min_num is not None else self.min_num + max_num = max_num if max_num is not None else self.max_num + image_size = image_size if image_size is not None else self.image_size + use_thumbnail = use_thumbnail if use_thumbnail is not None else self.use_thumbnail + + orig_width, orig_height = image.size + aspect_ratio = orig_width / orig_height + + # Calculate target ratios + target_ratios = set( + (i, j) for n in range(min_num, max_num + 1) + for i in range(1, n + 1) + for j in range(1, n + 1) + if i * j <= max_num and i * j >= min_num + ) + target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1]) + + # Find closest aspect ratio + target_aspect_ratio = self.find_closest_aspect_ratio( + aspect_ratio, target_ratios, orig_width, orig_height, image_size + ) + + # Calculate target dimensions + target_width = image_size * target_aspect_ratio[0] + target_height = image_size * target_aspect_ratio[1] + blocks = target_aspect_ratio[0] * target_aspect_ratio[1] + + # Resize and split image + resized_img = image.resize((target_width, target_height)) + processed_images = [] + + for i in range(blocks): + box = ( + (i % (target_width // image_size)) * image_size, + (i // (target_width // image_size)) * image_size, + ((i % (target_width // image_size)) + 1) * image_size, + ((i // (target_width // image_size)) + 1) * image_size + ) + split_img = resized_img.crop(box) + processed_images.append(split_img) + + assert len(processed_images) == blocks + + if use_thumbnail and len(processed_images) != 1: + thumbnail_img = image.resize((image_size, image_size)) + processed_images.append(thumbnail_img) + + return processed_images + + def preprocess( + self, + images: Union[ImageInput, 
List[ImageInput]], + input_size: Optional[int] = None, + max_num: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + ) -> BatchFeature: + """ + Preprocess an image or batch of images for the NVLM-D model. + + Args: + images: A single image or list of images + input_size: Size to resize image patches to + max_num: Maximum number of patches + return_tensors: Type of tensors to return ("pt" for PyTorch) + + Returns: + BatchFeature containing preprocessed pixel values + """ + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." + ) + + input_size = input_size if input_size is not None else self.image_size + max_num = max_num if max_num is not None else self.max_num + + if not isinstance(images, (list, tuple)): + images = [images] + + transform = self.build_transform(input_size=input_size) + + all_pixel_values = [] + for image in images: + if not isinstance(image, Image.Image): + image = Image.fromarray(image) + + processed_images = self.dynamic_preprocess( + image, + image_size=input_size, + max_num=max_num + ) + pixel_values = [transform(img) for img in processed_images] + pixel_values = torch.stack(pixel_values) + all_pixel_values.append(pixel_values) + + if len(all_pixel_values) == 1: + all_pixel_values = all_pixel_values[0] + else: + all_pixel_values = torch.stack(all_pixel_values) + + return BatchFeature( + data={"pixel_values": all_pixel_values}, + tensor_type=return_tensors + ) \ No newline at end of file diff --git a/examples/multimodal/hf_modeling_files/modeling_nvlm_d.py b/examples/multimodal/hf_modeling_files/modeling_nvlm_d.py new file mode 100644 index 0000000000..561cd12b38 --- /dev/null +++ b/examples/multimodal/hf_modeling_files/modeling_nvlm_d.py @@ -0,0 +1,260 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers import AutoModel, AutoModelForCausalLM, PreTrainedModel +from transformers.activations import ACT2FN +from transformers.modeling_outputs import ModelOutput +from transformers.utils import logging + +from .configuration_nvlm_d import NVLM_D_Config + +@dataclass +class NVLM_D_CausalLMOutputWithPast(ModelOutput): + """ + Output class for NVLM-D causal language model outputs. + + Args: + loss (`torch.FloatTensor` of shape `(1,)`, *optional*): + Language modeling loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head. + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*): + Contains pre-computed hidden-states for faster sequential decoding. + hidden_states (`tuple(torch.FloatTensor)`, *optional*): + Hidden-states of the model. 
+ attentions (`tuple(torch.FloatTensor)`, *optional*): + Attention weights. + image_hidden_states (`torch.FloatTensor`, *optional*): + Image features after projection. + """ + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + past_key_values: Optional[List[torch.FloatTensor]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + image_hidden_states: Optional[torch.FloatTensor] = None + + +class NVLM_D_MultiModalProjector(nn.Module): + def __init__(self, config: NVLM_D_Config): + super().__init__() + projector_input_size = config.vision_config.hidden_size * int(1 / config.downsample_ratio) ** 2 + self.linear_1 = nn.Linear(projector_input_size, config.text_config.hidden_size, bias=True) + self.act = ACT2FN[config.projector_hidden_act] + self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size, bias=True) + + def forward(self, image_features): + hidden_states = self.linear_1(image_features) + hidden_states = self.act(hidden_states) + hidden_states = self.linear_2(hidden_states) + return hidden_states + + +class NVLM_D_PreTrainedModel(PreTrainedModel): + """Base class for NVLM-D model.""" + config_class = NVLM_D_Config + base_model_prefix = "nvlm_d" + supports_gradient_checkpointing = True + _no_split_modules = ["CLIPVisionModel", "Qwen2DecoderLayer"] + + def _init_weights(self, module): + std = self.config.text_config.initializer_range + if isinstance(module, (nn.Linear, nn.Conv2d)): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + +# TODO: what to do to support flash attention? +class NVLM_D_Model(NVLM_D_PreTrainedModel): + """NVLM-D model for dynamic high-resolution image understanding.""" + + def __init__(self, config: NVLM_D_Config): + super().__init__(config) + + self.config = config + self.vision_tower = AutoModel.from_config(config.vision_config) + self.language_model = AutoModelForCausalLM.from_config(config.text_config) + self.multi_modal_projector = NVLM_D_MultiModalProjector(config) + + # Model attributes + self.patch_size = config.vision_config.patch_size + self.select_layer = config.select_layer + self.downsample_ratio = config.downsample_ratio + self.ps_version = config.ps_version + + image_size = config.force_image_size or config.vision_config.image_size + self.num_image_tokens = int((image_size // self.patch_size) ** 2 * (config.downsample_ratio ** 2)) + + # Generation attributes + # TODO: hard-coded for now + self.img_context_token_id = 151654 + self.main_input_name = "input_ids" + + def pixel_shuffle(self, x: torch.Tensor, scale_factor: float = 0.5) -> torch.Tensor: + """Perform pixel shuffling for downsampling.""" + n, w, h, c = x.size() + x = x.view(n, w, int(h * scale_factor), int(c / scale_factor)) + x = x.permute(0, 2, 1, 3).contiguous() + x = x.view(n, int(h * scale_factor), int(w * scale_factor), + int(c / (scale_factor * scale_factor))) + if self.ps_version != 'v1': + x = x.permute(0, 2, 1, 3).contiguous() + return x + + def get_image_features(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor: + """Extract and project image features.""" + if self.select_layer == -1: + vit_embeds = self.vision_tower( + pixel_values=pixel_values, + output_hidden_states=False, + return_dict=True + ).last_hidden_state + else: + vit_embeds = self.vision_tower( + pixel_values=pixel_values, + 
output_hidden_states=True, + return_dict=True + ).hidden_states[self.select_layer] + + vit_embeds = vit_embeds[:, 1:, :] + h = w = int(vit_embeds.shape[1] ** 0.5) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1) + vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio) + vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1]) + return self.multi_modal_projector(vit_embeds) + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + pixel_values: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, NVLM_D_CausalLMOutputWithPast]: + """Forward pass of the model.""" + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds") + + if inputs_embeds is None: + inputs_embeds = self.language_model.get_input_embeddings()(input_ids) + + image_features = None + if pixel_values is not None: + image_features = self.get_image_features(pixel_values) + + # Replace image context tokens with image features + B, N, C = inputs_embeds.shape + inputs_embeds = inputs_embeds.reshape(B * N, C) + input_ids = input_ids.reshape(B * N) + selected = (input_ids == self.img_context_token_id) + + try: + inputs_embeds[selected] = image_features.reshape(-1, C) + except Exception as e: + image_features = image_features.reshape(-1, C) + n_token = selected.sum() + inputs_embeds[selected] = image_features[:n_token] + + inputs_embeds = inputs_embeds.reshape(B, N, C) + + outputs = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + logits = outputs[0] if not return_dict else outputs.logits + + loss = None + if labels is not None: + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.text_config.vocab_size) + shift_labels = shift_labels.view(-1) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return NVLM_D_CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values if return_dict else outputs[1], + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + image_hidden_states=image_features, + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + pixel_values=None, + **kwargs + ): + """Prepare inputs for text generation.""" + if 
past_key_values: + input_ids = input_ids[:, -1:] + + # Only forward pixel_values on the first call + if past_key_values is None: + model_inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "inputs_embeds": inputs_embeds, + "pixel_values": pixel_values, + } + else: + model_inputs = { + "input_ids": input_ids, + "attention_mask": attention_mask, + "past_key_values": past_key_values, + "inputs_embeds": inputs_embeds, + } + + return model_inputs + diff --git a/examples/multimodal/hf_modeling_files/processing_nvlm_d.py b/examples/multimodal/hf_modeling_files/processing_nvlm_d.py new file mode 100644 index 0000000000..153423ff4d --- /dev/null +++ b/examples/multimodal/hf_modeling_files/processing_nvlm_d.py @@ -0,0 +1,168 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for NVLM-D. +""" + +from typing import List, Union + +from transformers.feature_extraction_utils import BatchFeature +from transformers.processing_utils import ProcessingKwargs, ProcessorMixin, Unpack +from transformers.tokenization_utils_base import PreTokenizedInput, TextInput +from transformers.utils import logging +from transformers.image_utils import ImageInput + +logger = logging.get_logger(__name__) + + +class NVLM_D_ProcessorKwargs(ProcessingKwargs, total=False): + _defaults = { + "text_kwargs": { + "padding": False, + }, + "images_kwargs": {}, + } + + +class NVLM_D_Processor(ProcessorMixin): + r""" + Constructs a NVLM-D processor which wraps a NVLM-D image processor and a tokenizer into a single processor. + + Args: + image_processor ([`NVLM_D_ImageProcessor`]): + The image processor for handling dynamic high-resolution image processing. + tokenizer ([`PreTrainedTokenizer`]): + The tokenizer for processing text. + image_token (`str`, *optional*, defaults to ""): + Special token used to denote image location in text. + tile_token_format (`str`, *optional*, defaults to ""): + Format string for tile position tokens. + global_token (`str`, *optional*, defaults to ""): + Token used for the global thumbnail image. 
+ """ + + attributes = ["image_processor", "tokenizer"] + image_processor_class = "AutoImageProcessor" + tokenizer_class = "AutoTokenizer" + auto_map = { + "AutoProcessor": "processing_nvlm_d.NVLM_D_Processor", + } + + def __init__( + self, + image_processor=None, + tokenizer=None, + image_token="", + tile_token_format="", + global_token="", + image_context_token="<|vision_pad|>", + num_image_tokens=144, + **kwargs, + ): + super().__init__(image_processor, tokenizer) + self.image_token = image_token + self.tile_token_format = tile_token_format + self.global_token = global_token + self.image_context_token = image_context_token + self.num_image_tokens = num_image_tokens + def __call__( + self, + images: ImageInput = None, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None, + **kwargs: Unpack[NVLM_D_ProcessorKwargs], + ) -> BatchFeature: + """ + Main method to prepare inputs for the NVLM-D model. Processes both images and text. + + Args: + images: The image or batch of images to be processed. + text: The text or batch of texts to be processed. + + Returns: + BatchFeature: A BatchFeature with the following fields: + - input_ids: Token ids for the text + - attention_mask: Attention mask for text tokens + - pixel_values: Processed image patches + - image_sizes: Original sizes of the images + """ + if images is None and text is None: + raise ValueError("You have to specify at least images or text.") + + output_kwargs = self._merge_kwargs( + NVLM_D_ProcessorKwargs, + tokenizer_init_kwargs=self.tokenizer.init_kwargs, + **kwargs, + ) + + # Process images if provided + if images is not None: + image_inputs = self.image_processor(images, **output_kwargs["images_kwargs"]) + num_patches = image_inputs.pixel_values.shape[0] # Get number of patches including thumbnail + else: + image_inputs = {} + num_patches = 0 + + # Process text + if text is not None: + if isinstance(text, str): + text = [text] + + # Replace image tokens with appropriate tile tokens + processed_texts = [] + for txt in text: + if self.image_token in txt and num_patches > 0: + # Generate tile tokens + tile_tokens = [] + for i in range(1, num_patches): # Start from 1 as per original code + tile_tokens.append(self.tile_token_format.format(i)) + if num_patches > 1: # Add global thumbnail token if we have multiple patches + tile_tokens.append(self.global_token) + + # Create image token sequence + image_token_sequence = "" + for tile_token in tile_tokens: + image_token_sequence += tile_token + self.image_context_token * self.num_image_tokens + + # Replace with the full sequence + txt = txt.replace(self.image_token, f"{image_token_sequence}") + + processed_texts.append(txt) + + text_inputs = self.tokenizer(processed_texts, **output_kwargs["text_kwargs"]) + else: + text_inputs = {} + + return BatchFeature(data={**text_inputs, **image_inputs}) + + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to the tokenizer's batch_decode. + """ + return self.tokenizer.batch_decode(*args, **kwargs) + + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to the tokenizer's decode. + """ + return self.tokenizer.decode(*args, **kwargs) + + @property + def model_input_names(self): + """ + Get the model input names from both tokenizer and image processor. 
+ """ + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) \ No newline at end of file diff --git a/examples/multimodal/model.py b/examples/multimodal/model.py index 74898e73a3..991454f01a 100644 --- a/examples/multimodal/model.py +++ b/examples/multimodal/model.py @@ -37,6 +37,8 @@ def model_provider( print_rank_0('building a multimodal model ...') + tokenizer = get_tokenizer() + tile_tag_length = len(tokenizer.tokenize(f"")) num_image_embeddings = get_num_image_embeddings( args.img_h, args.img_w, @@ -46,6 +48,7 @@ def model_provider( 1, args.pixel_shuffle, args.use_tile_tags, + tile_tag_length, ) old_seq_length = args.seq_length args.seq_length = args.encoder_seq_length = num_image_embeddings @@ -136,7 +139,6 @@ def model_provider( else: vision_projection_layer_spec = get_mlp_module_spec(use_te=use_te).submodules - tokenizer = get_tokenizer() image_token_index = tokenizer.convert_tokens_to_ids(IMAGE_TOKEN) tile_tags = _get_tile_tags(args, tokenizer) @@ -168,6 +170,7 @@ def model_provider( image_token_index=image_token_index, pixel_shuffle=args.pixel_shuffle, tile_tags=tile_tags, + tile_tag_length=tile_tag_length, ) model.freeze( @@ -188,7 +191,9 @@ def _get_tile_tags(args, tokenizer): thumbnail_tag_text = "" if args.tokenizer_prompt_format == "nvlm-yi-34b": thumbnail_tag_text = "" - + if args.tokenizer_model == "utter-project/EuroLLM-9B-Instruct": + thumbnail_tag_text = "" + assert args.max_num_tiles <= 6, "Up to 6 tile tags used" tile_tags_text = [f"" for i in range(1, args.max_num_tiles + 1)] + [thumbnail_tag_text] diff --git a/examples/multimodal/pangea_instruct.yaml b/examples/multimodal/pangea_instruct.yaml new file mode 100644 index 0000000000..ddefcc102a --- /dev/null +++ b/examples/multimodal/pangea_instruct.yaml @@ -0,0 +1,15 @@ +__module__: megatron.energon +__class__: Metadataset +splits: + train: + datasets: + - weight: 1. + path: /lustre/fswork/projects/rech/qjm/ued79zb/visionblocks_datasets/Unbabel-VisionBlocks-pangea-instruct/ + subflavors: + augmentation: false + val: + datasets: + - weight: 1. + path: /lustre/fswork/projects/rech/qjm/ued79zb/visionblocks_datasets/Unbabel-VisionBlocks-pangea-instruct/ + subflavors: + augmentation: false \ No newline at end of file diff --git a/examples/multimodal/pixmo_caps.yaml b/examples/multimodal/pixmo_caps.yaml new file mode 100644 index 0000000000..53b85d294f --- /dev/null +++ b/examples/multimodal/pixmo_caps.yaml @@ -0,0 +1,15 @@ +__module__: megatron.energon +__class__: Metadataset +splits: + train: + datasets: + - weight: 1. + path: /lustre/fswork/projects/rech/qjm/ued79zb/visionblocks_datasets/Unbabel-VisionBlocks-pixmo-cap/ + subflavors: + augmentation: false + val: + datasets: + - weight: 1. 
+ path: /lustre/fswork/projects/rech/qjm/ued79zb/visionblocks_datasets/Unbabel-VisionBlocks-pixmo-cap/ + subflavors: + augmentation: false \ No newline at end of file diff --git a/examples/multimodal/run_text_generation.py b/examples/multimodal/run_text_generation.py index cadc1120bc..89f4244fe5 100644 --- a/examples/multimodal/run_text_generation.py +++ b/examples/multimodal/run_text_generation.py @@ -140,6 +140,9 @@ def generate_samples(model, config: EvaluationConfig, print_output): args.vision_model_type, ) + tokenizer = get_tokenizer() + tile_tag_length = len(tokenizer.tokenize(f"")) + num_img_embeddings_per_tile = get_num_image_embeddings( args.img_h, args.img_w, @@ -149,6 +152,7 @@ def generate_samples(model, config: EvaluationConfig, print_output): 1, args.pixel_shuffle, args.use_tile_tags, + tile_tag_length, ) for idx, (imgs, num_tiles, sample_id, question, answers, metadata) in enumerate(dataloader): @@ -461,6 +465,10 @@ def get_prompt_and_generated(prompt_and_generation, prompt_format): generated = generated.split("")[0] elif prompt_format == "chatml": splitted = prompt_and_generation.split("<|im_start|> assistant\n") + if len(splitted) == 1: + # TODO: better handling of this case + splitted = prompt_and_generation.split("<|im_start|>assistant\n") + prompt = splitted[0] generated = splitted[1] generated = generated.split("<|im_end|>")[0] diff --git a/megatron/core/models/multimodal/llava_model.py b/megatron/core/models/multimodal/llava_model.py index 3b46487f87..210678d604 100644 --- a/megatron/core/models/multimodal/llava_model.py +++ b/megatron/core/models/multimodal/llava_model.py @@ -99,6 +99,7 @@ def __init__( image_token_index: int = DEFAULT_IMAGE_TOKEN_INDEX, pixel_shuffle: bool = False, tile_tags: Optional[list] = None, + tile_tag_length: int = 5, ) -> None: super().__init__(config=language_transformer_config) @@ -210,6 +211,7 @@ def __init__( class_token_len, pixel_shuffle, tile_tags is not None, # Tile tags enabled/disabled. + tile_tag_length, ) self.image_token_index = image_token_index @@ -520,8 +522,9 @@ def _preprocess_data( return final_embedding, final_labels, final_loss_mask, attention_mask - def _apply_tile_tagging(self, image_embeddings, num_image_tiles): - """Apply tile tagging. + def _apply_tile_tagging_1bsz(self, image_embeddings, num_image_tiles): + """WARNING: This is deprecated, in favour of the new batched tile tagging. Kept for reference. + Apply tile tagging. The image embeddings of multiple tiles are prepended with tile tags such as . This implements the method used in NVLM https://arxiv.org/pdf/2409.11402. @@ -536,7 +539,7 @@ def _apply_tile_tagging(self, image_embeddings, num_image_tiles): """ assert ( num_image_tiles.shape[0] == 1 and len(num_image_tiles) == 1 - ), "multiple input images are not supported yet." + ), "multiple input images are not supported yet, got {}".format(num_image_tiles) num_tiles = num_image_tiles[0].item() tile_tags = self._tile_tags[: num_tiles - 1] + [self._tile_tags[-1]] @@ -556,6 +559,44 @@ def _apply_tile_tagging(self, image_embeddings, num_image_tiles): return image_embeddings # [tile_seq_len + img_seq_len, num_tiles, h_language] + + def _apply_tile_tagging(self, image_embeddings, num_image_tiles): + """Apply tile tagging for batched processing. 
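+
+        Unlike the deprecated per-sample variant above, this version supports a batch of
+        images with varying tile counts: image_embeddings is split along the tile dimension
+        according to num_image_tiles, each image's tiles are prepended with their own tile
+        tag embeddings (tags 1..num_tiles-1 plus the global thumbnail tag), and the tagged
+        chunks are concatenated back along the tile dimension.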
+ + Args: + image_embeddings (torch.Tensor): [img_seq_len, total_num_tiles, h_language] + where total_num_tiles is the sum of tiles across all images + num_image_tiles (torch.Tensor): Number of tiles for each input image [num_images] + + Returns: + torch.Tensor: Tile tags prepended to image embeddings. + [tile_seq_len + img_seq_len, total_num_tiles, h_language] + """ + # Split embeddings by number of tiles per image + tile_splits = torch.split(image_embeddings, num_image_tiles.tolist(), dim=1) + + processed_embeddings = [] + for tiles, num_tiles in zip(tile_splits, num_image_tiles): + # Get the tile tags for current image + tile_tags = self._tile_tags[: num_tiles - 1] + [self._tile_tags[-1]] + + # [num_tiles, tile_seq_len] + tile_tag_input_ids = torch.tensor( + tile_tags, dtype=torch.int64, device=num_image_tiles.device + ) + + # [tile_seq_len, num_tiles, h_language] + tile_tag_embeds = self.language_model.embedding(tile_tag_input_ids, position_ids=None) + + # Concatenate tile tags with image embeddings + # tiles shape: [img_seq_len, num_tiles, h_language] + # tile_tag_embeds shape: [tile_seq_len, num_tiles, h_language] + processed = torch.cat([tile_tag_embeds, tiles], dim=0) + processed_embeddings.append(processed) + + # Concatenate all processed embeddings along the tiles dimension + return torch.cat(processed_embeddings, dim=1) + def forward( self, images: torch.Tensor, diff --git a/megatron/core/models/vision/clip_vit_model.py b/megatron/core/models/vision/clip_vit_model.py index 2fdc77a4f7..1728989713 100644 --- a/megatron/core/models/vision/clip_vit_model.py +++ b/megatron/core/models/vision/clip_vit_model.py @@ -195,6 +195,7 @@ def get_num_image_embeddings( class_token_len, pixel_shuffle=False, use_tile_tags=False, + tile_tag_length=5, ): """Get the number of image embeddings per image tile.""" if vision_model_type == "siglip": @@ -213,7 +214,6 @@ def get_num_image_embeddings( num_image_embeddings_per_tile = int(num_image_embeddings_per_tile * (0.5**2)) if use_tile_tags: - # The length of tile tags tokenized. Currently, the same across tokenizers used. - num_image_embeddings_per_tile += 5 + num_image_embeddings_per_tile += tile_tag_length return num_image_embeddings_per_tile diff --git a/tapes/main.tape b/tapes/main.tape index 7456d4257e..849a7f251e 100644 --- a/tapes/main.tape +++ b/tapes/main.tape @@ -4,98 +4,23 @@ global { ducttape_experimental_imports=true ducttape_experimental_submitters=true ducttape_experimental_multiproc=true - - ducttape_output=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_outputs - repo=/linkhome/rech/genrce01/ued79zb/repos/Megatron-LM-pretrain - - # multimodal model parameters - # (base lm, tp, etc...) 
- clip_original_dir=/lustre/fswork/projects/rech/qjm/ued79zb/clip_model_og/ - mistral_model=( - TextModel: - mistral="mistralai/Mistral-7B-Instruct-v0.3" - tower="Unbabel/TowerInstruct-Mistral-7B-v0.2" - ) - prompt_format=( - TextModel: - mistral="mistral" - tower="chatml" - ) - tp=4 - pp=1 - - # pre-training arguments - external_model_dir=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_ckpts/mistral_7b_instruct_prt - external_tensorboard=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_tbs/mistral_7b_instruct_prt - pretrain_iters=2000 - pretrain_bsz=256 - pretrain_lr=0.001 - pretrain_lr_warmup=0.03 - pretrain_save_interval=500 - pretrain_eval_interval=500 - - # fine-tuning arguments - finetune_model_dir=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_ckpts/mistral_7b_instruct_sft - finetune_tensorboard=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_tbs/mistral_7b_instruct_sft - finetune_iters=5000 - finetune_bsz=128 - finetune_lr=1e-6 - finetune_lr_warmup=0.01 - finetune_save_interval=1000 - finetune_eval_interval=1000 - - # eval arguments - coco_dir=/lustre/fswork/projects/rech/qjm/ued79zb/coco/ - mmmu_dir=/dev/null - eval_bsz=32 - - # convert arguments - upload_id="patricksf/mistral-7b-clip-prt" - - # -- submitter arguments -- - submitter=scslurm - - prepare_account="qjm@cpu" - prepare_time="1:00:00" - prepare_cpus=4 - prepare_partition="prepost" - - pretrain_C="h100" - pretrain_account="qjm@h100" - pretrain_time="10:00:00" - pretrain_cpus=80 - pretrain_gres="gpu:4" - - finetune_C="h100" - finetune_account="qjm@h100" - finetune_time="10:00:00" - finetune_cpus=80 - finetune_gres="gpu:4" - - eval_C="h100" - eval_account="qjm@h100" - eval_time="1:00:00" - eval_cpus=80 - eval_gres="gpu:4" - - convert_account="qjm@cpu" - convert_time="1:00:00" - convert_cpus=4 - convert_partition="prepost" } task PrepareModel > initial_model :: repo=@ :: clip_original_dir=@ - :: mistral_model=@ + :: model_name=@ + :: model_type=@ :: tp=@ :: pp=@ :: .submitter=@ + :: .C=$prepare_C :: .account=$prepare_account :: .time=$prepare_time :: .cpus=$prepare_cpus :: .partition=$prepare_partition + :: .gres=$prepare_gres { # Download & convert CLIP model echo "Downloading & converting CLIP model..." @@ -108,16 +33,16 @@ task PrepareModel # Download & convert language model echo "Downloading & converting language model..." 
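    # Note (sketch): ${model_name} is the HF repo id and ${model_type} must match one of the
    # --model-size choices accepted by tools/checkpoint/loader_llama_mistral.py. Per the tconf
    # files in this series, the mapping is:
    #   mistralai/Mistral-7B-Instruct-v0.3, Unbabel/TowerInstruct-Mistral-7B-v0.2 -> mistral-7B
    #   Qwen/Qwen2.5-7B-Instruct                                                  -> qwen2.5-7B
    #   utter-project/EuroLLM-9B-Instruct                                         -> eurollm-9B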
python $repo/examples/multimodal/download_hf_model.py \ - --model ${mistral_model} \ + --model ${model_name} \ --output-dir mistral_original_dir python $repo/tools/checkpoint/convert.py --model-type GPT \ --loader llama_mistral \ --saver mcore \ --checkpoint-type hf \ - --model-size mistral-7B \ + --model-size ${model_type} \ --load-dir mistral_original_dir \ --save-dir mistral_mcore_dir \ - --tokenizer-model ${mistral_model} \ + --tokenizer-model ${model_name} \ --target-tensor-parallel-size ${tp} \ --target-pipeline-parallel-size ${pp} \ --bf16 @@ -138,17 +63,24 @@ task PretrainModel < initial_model=@PrepareModel > model_dir :: repo=@ - :: tokenizer_model=$mistral_model + :: tokenizer_model=$model_name + :: model_type=@ :: prompt_format=@ + :: image_preproc=@ + :: pixel_shuffle=@ + :: pretrain_dataset=@ :: train_iters=$pretrain_iters :: batch_size=$pretrain_bsz :: lr=$pretrain_lr :: lr_warmup_fraction=$pretrain_lr_warmup + :: unfreeze_lm=$pretrain_unfreeze_lm + :: unfreeze_vit=$pretrain_unfreeze_vit :: save_interval=$pretrain_save_interval :: eval_interval=$pretrain_eval_interval :: external_resume=true :: external_model_dir=@ :: external_tensorboard=@ + :: num_workers=@ :: .submitter=@ :: .C=$pretrain_C :: .account=$pretrain_account @@ -180,6 +112,97 @@ task PretrainModel export NVTE_APPLY_QK_LAYER_SCALING=0 export NVTE_ALLOW_NONDETERMINISTIC_ALGO=1 + # define custom arguments based on model type + if [ "$model_type" == "mistral-7B" ]; then + # mistral-7B is the only model that requires the --disable-vision-class-token flag + MODEL_ARGS="--language-model-type=mistral_7b" + MODEL_ARGS="$MODEL_ARGS --disable-vision-class-token" + MODEL_ARGS="$MODEL_ARGS --normalization RMSNorm" + MODEL_ARGS="$MODEL_ARGS --group-query-attention" + MODEL_ARGS="$MODEL_ARGS --num-query-groups 8" + MODEL_ARGS="$MODEL_ARGS --no-masked-softmax-fusion" + MODEL_ARGS="$MODEL_ARGS --use-flash-attn" + MODEL_ARGS="$MODEL_ARGS --untie-embeddings-and-output-weights" + MODEL_ARGS="$MODEL_ARGS --disable-bias-linear" + MODEL_ARGS="$MODEL_ARGS --position-embedding-type rope" + MODEL_ARGS="$MODEL_ARGS --rotary-percent 1.0" + MODEL_ARGS="$MODEL_ARGS --rotary-base 1000000" + MODEL_ARGS="$MODEL_ARGS --swiglu" + MODEL_ARGS="$MODEL_ARGS --attention-dropout 0.0" + MODEL_ARGS="$MODEL_ARGS --hidden-dropout 0.1" + MODEL_ARGS="$MODEL_ARGS --tensor-model-parallel-size 4" + MODEL_ARGS="$MODEL_ARGS --pipeline-model-parallel-size 1" + MODEL_ARGS="$MODEL_ARGS --num-layers 32" + MODEL_ARGS="$MODEL_ARGS --hidden-size 4096" + MODEL_ARGS="$MODEL_ARGS --num-attention-heads 32" + MODEL_ARGS="$MODEL_ARGS --decoder-seq-length 1024" + MODEL_ARGS="$MODEL_ARGS --max-position-embeddings 4096" + MODEL_ARGS="$MODEL_ARGS --ffn-hidden-size 14336" + elif [ "$model_type" == "qwen2.5-7B" ]; then + MODEL_ARGS="--language-model-type=qwen2.5_7b" + MODEL_ARGS="$MODEL_ARGS --normalization RMSNorm" + MODEL_ARGS="$MODEL_ARGS --group-query-attention" + MODEL_ARGS="$MODEL_ARGS --num-query-groups 4" + MODEL_ARGS="$MODEL_ARGS --no-masked-softmax-fusion" + MODEL_ARGS="$MODEL_ARGS --use-flash-attn" + MODEL_ARGS="$MODEL_ARGS --untie-embeddings-and-output-weights" + MODEL_ARGS="$MODEL_ARGS --disable-bias-linear" + MODEL_ARGS="$MODEL_ARGS --add-qkv-bias" + MODEL_ARGS="$MODEL_ARGS --position-embedding-type rope" + MODEL_ARGS="$MODEL_ARGS --rotary-base 1000000" + MODEL_ARGS="$MODEL_ARGS --swiglu" + MODEL_ARGS="$MODEL_ARGS --no-bias-swiglu-fusion" + MODEL_ARGS="$MODEL_ARGS --no-rope-fusion" + MODEL_ARGS="$MODEL_ARGS --attention-dropout 0.0" + MODEL_ARGS="$MODEL_ARGS 
--hidden-dropout 0.0" + MODEL_ARGS="$MODEL_ARGS --tensor-model-parallel-size 4" + MODEL_ARGS="$MODEL_ARGS --pipeline-model-parallel-size 1" + MODEL_ARGS="$MODEL_ARGS --num-layers 28" + MODEL_ARGS="$MODEL_ARGS --hidden-size 3584" + MODEL_ARGS="$MODEL_ARGS --num-attention-heads 28" + MODEL_ARGS="$MODEL_ARGS --decoder-seq-length 1024" + MODEL_ARGS="$MODEL_ARGS --max-position-embeddings 131072" + MODEL_ARGS="$MODEL_ARGS --ffn-hidden-size 18944" + MODEL_ARGS="$MODEL_ARGS --norm-epsilon 1e-6" + elif [ "$model_type" == "eurollm-9B" ]; then + # TODO: check if rotary-base and norm-epsilon are correct + MODEL_ARGS="--language-model-type=eurollm_9b" + MODEL_ARGS="$MODEL_ARGS --normalization RMSNorm" + MODEL_ARGS="$MODEL_ARGS --group-query-attention" + MODEL_ARGS="$MODEL_ARGS --num-query-groups 8" + MODEL_ARGS="$MODEL_ARGS --no-masked-softmax-fusion" + MODEL_ARGS="$MODEL_ARGS --use-flash-attn" + MODEL_ARGS="$MODEL_ARGS --untie-embeddings-and-output-weights" + MODEL_ARGS="$MODEL_ARGS --disable-bias-linear" + MODEL_ARGS="$MODEL_ARGS --position-embedding-type rope" + MODEL_ARGS="$MODEL_ARGS --rotary-base 10000" + MODEL_ARGS="$MODEL_ARGS --swiglu" + MODEL_ARGS="$MODEL_ARGS --attention-dropout 0.0" + MODEL_ARGS="$MODEL_ARGS --hidden-dropout 0.0" + MODEL_ARGS="$MODEL_ARGS --tensor-model-parallel-size 4" + MODEL_ARGS="$MODEL_ARGS --pipeline-model-parallel-size 1" + MODEL_ARGS="$MODEL_ARGS --num-layers 42" + MODEL_ARGS="$MODEL_ARGS --hidden-size 4096" + MODEL_ARGS="$MODEL_ARGS --num-attention-heads 32" + MODEL_ARGS="$MODEL_ARGS --decoder-seq-length 1024" + MODEL_ARGS="$MODEL_ARGS --max-position-embeddings 4096" + MODEL_ARGS="$MODEL_ARGS --ffn-hidden-size 12288" + fi + + if [ "$image_preproc" == "basic" ]; then + MODEL_ARGS="$MODEL_ARGS --seq-length 576" + elif [ "$image_preproc" == "nvlm" ]; then + MODEL_ARGS="$MODEL_ARGS --seq-length 256" # Image embeddings sequence length + MODEL_ARGS="$MODEL_ARGS --image-tag-type nvlm" + else + echo "Invalid image preproc: $image_preproc" + exit 1 + fi + + if [ "$pixel_shuffle" == true ]; then + MODEL_ARGS="$MODEL_ARGS --pixel-shuffle" + fi + torchrun --nproc_per_node 4 $repo/examples/multimodal/train.py \ --apply-layernorm-1p \ --attention-softmax-in-fp32 \ @@ -187,29 +210,7 @@ task PretrainModel --use-distributed-optimizer \ --transformer-impl transformer_engine \ --use-te \ - --normalization RMSNorm \ - --group-query-attention \ - --num-query-groups 8 \ - --no-masked-softmax-fusion \ - --num-workers 2 \ - --use-flash-attn \ - --untie-embeddings-and-output-weights \ - --disable-bias-linear \ - --position-embedding-type rope \ - --rotary-percent 1.0 \ - --rotary-base 1000000 \ - --swiglu \ - --attention-dropout 0.0 \ - --hidden-dropout 0.1 \ - --tensor-model-parallel-size 4 \ - --pipeline-model-parallel-size 1 \ - --num-layers 32 \ - --hidden-size 4096 \ - --num-attention-heads 32 \ - --seq-length 576 \ - --decoder-seq-length 1024 \ - --max-position-embeddings 4096 \ - --ffn-hidden-size 14336 \ + $MODEL_ARGS \ --train-iters ${train_iters} \ --micro-batch-size 16 \ --global-batch-size ${batch_size} \ @@ -224,13 +225,14 @@ task PretrainModel --tokenizer-type MultimodalTokenizer \ --tokenizer-model ${tokenizer_model} \ --tokenizer-prompt-format ${prompt_format} \ - --data-path ${repo}/examples/multimodal/pretrain_dataset.yaml \ + --data-path $pretrain_dataset \ --prompt-path ${repo}/examples/multimodal/manual_prompts.json \ --save-interval ${save_interval} \ --save ${model_dir} \ --load ${model_dir} \ --dataloader-save ${model_dir}/dataloader \ --pretrained-checkpoint 
${initial_model} \ + --num-workers $num_workers \ --split 100,0,0 \ --clip-grad 1.0 \ --weight-decay 1e-2 \ @@ -241,29 +243,172 @@ task PretrainModel --log-num-zeros-in-grad \ --bf16 \ --eod-mask-loss \ - --freeze-LM \ - --freeze-ViT \ + $([ "$unfreeze_lm" == false ] && echo "--freeze-LM" || echo "") \ + $([ "$unfreeze_vit" == false ] && echo "--freeze-ViT" || echo "") \ --patch-dim 14 \ --img-h 336 \ --img-w 336 \ --dataloader-type external \ --tensorboard-dir ${tensorboard} \ - --language-model-type=mistral_7b \ --disable-vision-class-token \ --distributed-timeout-minutes 60 \ --allow-missing-vision-projection-checkpoint \ --ckpt-format torch } +func GenerateTestTask + < model_dir + > outputs + :: repo + :: tokenizer_model + :: model_type + :: prompt_format + :: image_preproc + :: pixel_shuffle + :: input_image_path + :: gt_path + :: eval_bsz + :: task +{ + export NCCL_IB_SL=1 + export CUDA_DEVICE_MAX_CONNECTIONS=1 + export NVTE_APPLY_QK_LAYER_SCALING=0 + export PYTHONPATH="$PYTHONPATH:$repo/examples/multimodal" + + # define custom arguments based on model type + if [ "$model_type" == "mistral-7B" ]; then + MODEL_ARGS="--language-model-type=mistral_7b" + MODEL_ARGS="$MODEL_ARGS --normalization RMSNorm" + MODEL_ARGS="$MODEL_ARGS --group-query-attention" + MODEL_ARGS="$MODEL_ARGS --num-query-groups 8" + MODEL_ARGS="$MODEL_ARGS --num-layers 32" + MODEL_ARGS="$MODEL_ARGS --hidden-size 4096" + MODEL_ARGS="$MODEL_ARGS --ffn-hidden-size 14336" + MODEL_ARGS="$MODEL_ARGS --num-attention-heads 32" + MODEL_ARGS="$MODEL_ARGS --max-position-embeddings 4096" + MODEL_ARGS="$MODEL_ARGS --decoder-seq-length 2048" + elif [ "$model_type" == "qwen2.5-7B" ]; then + MODEL_ARGS="--language-model-type=qwen2.5_7b" + MODEL_ARGS="$MODEL_ARGS --normalization RMSNorm" + MODEL_ARGS="$MODEL_ARGS --group-query-attention" + MODEL_ARGS="$MODEL_ARGS --num-query-groups 4" + MODEL_ARGS="$MODEL_ARGS --add-qkv-bias" + MODEL_ARGS="$MODEL_ARGS --no-bias-swiglu-fusion" + MODEL_ARGS="$MODEL_ARGS --no-rope-fusion" + MODEL_ARGS="$MODEL_ARGS --num-layers 28" + MODEL_ARGS="$MODEL_ARGS --hidden-size 3584" + MODEL_ARGS="$MODEL_ARGS --num-attention-heads 28" + MODEL_ARGS="$MODEL_ARGS --max-position-embeddings 131072" + MODEL_ARGS="$MODEL_ARGS --ffn-hidden-size 18944" + MODEL_ARGS="$MODEL_ARGS --norm-epsilon 1e-6" + MODEL_ARGS="$MODEL_ARGS --decoder-seq-length 2048" + elif [ "$model_type" == "eurollm-9B" ]; then + # TODO: check if rotary-base and norm-epsilon are correct + MODEL_ARGS="--language-model-type=eurollm_9b" + MODEL_ARGS="$MODEL_ARGS --normalization RMSNorm" + MODEL_ARGS="$MODEL_ARGS --group-query-attention" + MODEL_ARGS="$MODEL_ARGS --num-query-groups 8" + MODEL_ARGS="$MODEL_ARGS --num-layers 42" + MODEL_ARGS="$MODEL_ARGS --hidden-size 4096" + MODEL_ARGS="$MODEL_ARGS --num-attention-heads 32" + MODEL_ARGS="$MODEL_ARGS --max-position-embeddings 4096" + MODEL_ARGS="$MODEL_ARGS --ffn-hidden-size 12288" + MODEL_ARGS="$MODEL_ARGS --decoder-seq-length 2048" + fi + + if [ "$image_preproc" == "basic" ]; then + MODEL_ARGS="$MODEL_ARGS --seq-length 576" + elif [ "$image_preproc" == "nvlm" ]; then + MODEL_ARGS="$MODEL_ARGS --seq-length 261" # 256 image embeddings + 5 tile tag embeddings + MODEL_ARGS="$MODEL_ARGS --use-tiling" + MODEL_ARGS="$MODEL_ARGS --max-num-tiles 6" + MODEL_ARGS="$MODEL_ARGS --use-thumbnail" + MODEL_ARGS="$MODEL_ARGS --use-tile-tags" + MODEL_ARGS="$MODEL_ARGS --image-tag-type nvlm" + else + echo "Invalid image preproc: $image_preproc" + exit 1 + fi -task EvaluatePretrainedModel - < coco_dir=@ + if [ "$pixel_shuffle" 
== true ]; then + MODEL_ARGS="$MODEL_ARGS --pixel-shuffle" + fi + + mkdir -p $outputs + + torchrun --nproc_per_node 4 $repo/examples/multimodal/run_text_generation.py \ + --apply-layernorm-1p \ + --attention-softmax-in-fp32 \ + --use-flash-attn \ + --transformer-impl transformer_engine \ + --use-te \ + --use-checkpoint-args \ + $MODEL_ARGS \ + --no-masked-softmax-fusion \ + --untie-embeddings-and-output-weights \ + --disable-bias-linear \ + --position-embedding-type rope \ + --rotary-base 1000000 \ + --swiglu \ + --attention-dropout 0.0 \ + --hidden-dropout 0.0 \ + --tensor-model-parallel-size 4 \ + --pipeline-model-parallel-size 1 \ + --load ${model_dir} \ + --tokenizer-type MultimodalTokenizer \ + --tokenizer-model ${tokenizer_model} \ + --tokenizer-prompt-format ${prompt_format} \ + --bf16 \ + --micro-batch-size ${eval_bsz} \ + --out-seq-length 12 \ + --temperature 1.0 \ + --img-h 336 \ + --img-w 336 \ + --patch-dim 14 \ + --disable-vision-class-token \ + --seed 153 \ + --top_k 1 \ + --no-load-rng \ + --no-load-optim \ + --input-image-path ${input_image_path} \ + --num-partitions 0 \ + --partition-id 0 \ + --output-path $outputs/${task}_outputs \ + --gt-path $gt_path \ + --task ${task} \ + --num-frames 1 \ + --ckpt-format torch +} + + +task GenerateTestCoco calls GenerateTestTask < model_dir=@PretrainModel + > outputs + :: repo=@ + :: tokenizer_model=$model_name + :: model_type=@ + :: prompt_format=@ + :: image_preproc=@ + :: pixel_shuffle=@ + :: eval_bsz=@ + :: task=captioning + :: input_image_path=$coco_dir + :: gt_path=$coco_gt_path + :: .submitter=@ + :: .C=$eval_C + :: .account=$eval_account + :: .time=$eval_time + :: .cpus=$eval_cpus + :: .gres=$eval_gres + +task EvaluatePretrainedModel + < outputs=$outputs@GenerateTestCoco > coco_results :: repo=@ :: eval_bsz=@ - :: tokenizer_model=$mistral_model + :: tokenizer_model=$model_name :: prompt_format=@ + :: coco_dir=@ :: .submitter=@ :: .C=$eval_C :: .account=$eval_account @@ -271,19 +416,9 @@ task EvaluatePretrainedModel :: .cpus=$eval_cpus :: .gres=$eval_gres { - echo "Evaluating pretrained model..." 
- # TODO: inline the bash script - bash $repo/examples/multimodal/text_generation_mistral_clip.sh \ - --input-image-path $coco_dir \ - --gt-path $coco_dir/coco_karpathy_test.json \ - --output-path coco_outputs \ - --model-path $model_dir \ - --tokenizer-path $tokenizer_model \ - --tokenizer-prompt-format ${prompt_format} \ - --batch-size $eval_bsz \ - --task captioning + # Evaluate results python $repo/examples/multimodal/evaluate_coco.py \ - --input-path coco_outputs \ + --input-path $outputs/captioning_outputs \ --groundtruth-path $coco_dir/coco_karpathy_test_gt.json \ | tee $coco_results } @@ -292,23 +427,36 @@ task FineTuneModel < pretrained_dir=$model_dir@PretrainModel > finetuned_dir :: repo=@ - :: tokenizer_model=$mistral_model + :: tokenizer_model=$model_name + :: model_type=@ :: prompt_format=@ + :: image_preproc=@ + :: pixel_shuffle=@ + :: sft_dataset=@ :: train_iters=$finetune_iters :: batch_size=$finetune_bsz + :: micro_batch_size=$finetune_micro_bsz :: lr=$finetune_lr :: lr_warmup_fraction=$finetune_lr_warmup + :: unfreeze_vit=$finetune_unfreeze_vit :: save_interval=$finetune_save_interval :: eval_interval=$finetune_eval_interval :: external_resume=true :: external_model_dir=$finetune_model_dir :: external_tensorboard=$finetune_tensorboard + :: nnodes=$finetune_nnodes + :: gpus=$finetune_gpus + :: num_workers=@ + :: master_addr=@ + :: master_port=@ :: .submitter=@ :: .C=$finetune_C :: .account=$finetune_account :: .time=$finetune_time :: .cpus=$finetune_cpus + :: .nodes=$finetune_nodes :: .gres=$finetune_gres + :: .qos=$finetune_qos { export NCCL_IB_SL=1 export CUDA_DEVICE_MAX_CONNECTIONS=1 @@ -333,38 +481,117 @@ task FineTuneModel export NVTE_APPLY_QK_LAYER_SCALING=0 export NVTE_ALLOW_NONDETERMINISTIC_ALGO=1 - torchrun --nproc_per_node 4 $repo/examples/multimodal/train.py \ + # define custom arguments based on model type + if [ "$model_type" == "mistral-7B" ]; then + # mistral-7B is the only model that requires the --disable-vision-class-token flag + MODEL_ARGS="--language-model-type=mistral_7b" + MODEL_ARGS="$MODEL_ARGS --normalization RMSNorm" + MODEL_ARGS="$MODEL_ARGS --group-query-attention" + MODEL_ARGS="$MODEL_ARGS --num-query-groups 8" + MODEL_ARGS="$MODEL_ARGS --no-masked-softmax-fusion" + MODEL_ARGS="$MODEL_ARGS --use-flash-attn" + MODEL_ARGS="$MODEL_ARGS --untie-embeddings-and-output-weights" + MODEL_ARGS="$MODEL_ARGS --disable-bias-linear" + MODEL_ARGS="$MODEL_ARGS --position-embedding-type rope" + MODEL_ARGS="$MODEL_ARGS --rotary-percent 1.0" + MODEL_ARGS="$MODEL_ARGS --rotary-base 1000000" + MODEL_ARGS="$MODEL_ARGS --swiglu" + MODEL_ARGS="$MODEL_ARGS --attention-dropout 0.0" + MODEL_ARGS="$MODEL_ARGS --hidden-dropout 0.1" + MODEL_ARGS="$MODEL_ARGS --tensor-model-parallel-size 4" + MODEL_ARGS="$MODEL_ARGS --pipeline-model-parallel-size 1" + MODEL_ARGS="$MODEL_ARGS --num-layers 32" + MODEL_ARGS="$MODEL_ARGS --hidden-size 4096" + MODEL_ARGS="$MODEL_ARGS --num-attention-heads 32" + MODEL_ARGS="$MODEL_ARGS --decoder-seq-length 2048" + MODEL_ARGS="$MODEL_ARGS --max-position-embeddings 4096" + MODEL_ARGS="$MODEL_ARGS --ffn-hidden-size 14336" + elif [ "$model_type" == "qwen2.5-7B" ]; then + MODEL_ARGS="--language-model-type=qwen2.5_7b" + MODEL_ARGS="$MODEL_ARGS --normalization RMSNorm" + MODEL_ARGS="$MODEL_ARGS --group-query-attention" + MODEL_ARGS="$MODEL_ARGS --num-query-groups 4" + MODEL_ARGS="$MODEL_ARGS --no-masked-softmax-fusion" + MODEL_ARGS="$MODEL_ARGS --use-flash-attn" + MODEL_ARGS="$MODEL_ARGS --untie-embeddings-and-output-weights" + MODEL_ARGS="$MODEL_ARGS 
--disable-bias-linear" + MODEL_ARGS="$MODEL_ARGS --add-qkv-bias" + MODEL_ARGS="$MODEL_ARGS --position-embedding-type rope" + MODEL_ARGS="$MODEL_ARGS --rotary-base 1000000" + MODEL_ARGS="$MODEL_ARGS --swiglu" + MODEL_ARGS="$MODEL_ARGS --no-bias-swiglu-fusion" + MODEL_ARGS="$MODEL_ARGS --no-rope-fusion" + MODEL_ARGS="$MODEL_ARGS --attention-dropout 0.0" + MODEL_ARGS="$MODEL_ARGS --hidden-dropout 0.0" + MODEL_ARGS="$MODEL_ARGS --tensor-model-parallel-size 4" + MODEL_ARGS="$MODEL_ARGS --pipeline-model-parallel-size 1" + MODEL_ARGS="$MODEL_ARGS --num-layers 28" + MODEL_ARGS="$MODEL_ARGS --hidden-size 3584" + MODEL_ARGS="$MODEL_ARGS --num-attention-heads 28" + MODEL_ARGS="$MODEL_ARGS --decoder-seq-length 2048" + MODEL_ARGS="$MODEL_ARGS --max-position-embeddings 131072" + MODEL_ARGS="$MODEL_ARGS --ffn-hidden-size 18944" + MODEL_ARGS="$MODEL_ARGS --norm-epsilon 1e-6" + elif [ "$model_type" == "eurollm-9B" ]; then + # TODO: check if rotary-base and norm-epsilon are correct + MODEL_ARGS="--language-model-type=eurollm_9b" + MODEL_ARGS="$MODEL_ARGS --normalization RMSNorm" + MODEL_ARGS="$MODEL_ARGS --group-query-attention" + MODEL_ARGS="$MODEL_ARGS --num-query-groups 8" + MODEL_ARGS="$MODEL_ARGS --no-masked-softmax-fusion" + MODEL_ARGS="$MODEL_ARGS --use-flash-attn" + MODEL_ARGS="$MODEL_ARGS --untie-embeddings-and-output-weights" + MODEL_ARGS="$MODEL_ARGS --disable-bias-linear" + MODEL_ARGS="$MODEL_ARGS --position-embedding-type rope" + MODEL_ARGS="$MODEL_ARGS --rotary-base 10000" + MODEL_ARGS="$MODEL_ARGS --swiglu" + MODEL_ARGS="$MODEL_ARGS --attention-dropout 0.0" + MODEL_ARGS="$MODEL_ARGS --hidden-dropout 0.0" + MODEL_ARGS="$MODEL_ARGS --tensor-model-parallel-size 4" + MODEL_ARGS="$MODEL_ARGS --pipeline-model-parallel-size 1" + MODEL_ARGS="$MODEL_ARGS --num-layers 42" + MODEL_ARGS="$MODEL_ARGS --hidden-size 4096" + MODEL_ARGS="$MODEL_ARGS --num-attention-heads 32" + MODEL_ARGS="$MODEL_ARGS --decoder-seq-length 2048" + MODEL_ARGS="$MODEL_ARGS --max-position-embeddings 4096" + MODEL_ARGS="$MODEL_ARGS --ffn-hidden-size 12288" + fi + + if [ "$image_preproc" == "basic" ]; then + MODEL_ARGS="$MODEL_ARGS --seq-length 576" + elif [ "$image_preproc" == "nvlm" ]; then + # NVLM-specific settings based on Qwen 72B SFT script + MODEL_ARGS="$MODEL_ARGS --seq-length 261" # 256 image embeddings + 5 tile tag embeddings + MODEL_ARGS="$MODEL_ARGS --image-tag-type nvlm" + # Tiling-specific arguments + MODEL_ARGS="$MODEL_ARGS --use-tiling" + MODEL_ARGS="$MODEL_ARGS --max-num-tiles 6" + MODEL_ARGS="$MODEL_ARGS --use-thumbnail" + MODEL_ARGS="$MODEL_ARGS --use-tile-tags" + else + echo "Invalid image preproc: $image_preproc" + exit 1 + fi + + if [ "$pixel_shuffle" == true ]; then + MODEL_ARGS="$MODEL_ARGS --pixel-shuffle" + fi + + #export NCCL_ASYNC_ERROR_HANDLING=1 + distributed_args="--nnodes=$nnodes --nproc_per_node=$gpus" + #distributed_args="${distributed_args} --rdzv_backend c10d --rdzv_endpoint $master_addr:$master_port" + #distributed_args="${distributed_args} --node_rank=$SLURM_PROCID" + + torchrun $distributed_args $repo/examples/multimodal/train.py \ --apply-layernorm-1p \ --attention-softmax-in-fp32 \ --use-checkpoint-args \ --use-distributed-optimizer \ --transformer-impl transformer_engine \ --use-te \ - --normalization RMSNorm \ - --group-query-attention \ - --num-query-groups 8 \ - --no-masked-softmax-fusion \ - --num-workers 2 \ - --use-flash-attn \ - --untie-embeddings-and-output-weights \ - --disable-bias-linear \ - --position-embedding-type rope \ - --rotary-percent 1.0 \ - --rotary-base 1000000 \ - 
--swiglu \ - --attention-dropout 0.0 \ - --hidden-dropout 0.1 \ - --tensor-model-parallel-size 4 \ - --pipeline-model-parallel-size 1 \ - --num-layers 32 \ - --hidden-size 4096 \ - --num-attention-heads 32 \ - --seq-length 576 \ - --decoder-seq-length 2048 \ - --max-position-embeddings 4096 \ - --ffn-hidden-size 14336 \ + $MODEL_ARGS \ --train-iters ${train_iters} \ - --micro-batch-size 8 \ + --micro-batch-size ${micro_batch_size} \ --global-batch-size ${batch_size} \ --lr-decay-iters ${train_iters} \ --lr-warmup-fraction ${lr_warmup_fraction} \ @@ -377,13 +604,14 @@ task FineTuneModel --tokenizer-type MultimodalTokenizer \ --tokenizer-model ${tokenizer_model} \ --tokenizer-prompt-format ${prompt_format} \ - --data-path ${repo}/examples/multimodal/sft_dataset.yaml \ + --data-path $sft_dataset \ --prompt-path ${repo}/examples/multimodal/manual_prompts.json \ --save-interval ${save_interval} \ --save ${finetuned_dir} \ --load ${finetuned_dir} \ --dataloader-save ${finetuned_dir}/dataloader \ --pretrained-checkpoint ${pretrained_dir} \ + --num-workers $num_workers \ --split 100,0,0 \ --clip-grad 0.5 \ --weight-decay 0.1 \ @@ -394,25 +622,87 @@ task FineTuneModel --log-num-zeros-in-grad \ --bf16 \ --eod-mask-loss \ - --freeze-ViT \ + $([ "$unfreeze_vit" == false ] && echo "--freeze-ViT" || echo "") \ --patch-dim 14 \ --img-h 336 \ --img-w 336 \ --dataloader-type external \ --tensorboard-dir ${tensorboard} \ - --language-model-type=mistral_7b \ --disable-vision-class-token \ --distributed-timeout-minutes 60 \ --ckpt-format torch } +task GenerateTestMMMU calls GenerateTestTask + < model_dir=$finetuned_dir@FineTuneModel + > outputs + :: repo=@ + :: tokenizer_model=$model_name + :: model_type=@ + :: prompt_format=@ + :: image_preproc=@ + :: pixel_shuffle=@ + :: eval_bsz=@ + :: task=MMMU + :: input_image_path=none + :: gt_path=none + :: .submitter=@ + :: .C=$eval_C + :: .account=$eval_account + :: .time=$eval_time + :: .cpus=$eval_cpus + :: .gres=$eval_gres + +task GenerateTestTextVQA calls GenerateTestTask + < model_dir=$finetuned_dir@FineTuneModel + > outputs + :: repo=@ + :: tokenizer_model=$model_name + :: model_type=@ + :: prompt_format=@ + :: image_preproc=@ + :: pixel_shuffle=@ + :: eval_bsz=@ + :: task=TextVQA + :: input_image_path=$textvqa_dir + :: gt_path=$textvqa_gt_path + :: .submitter=@ + :: .C=$eval_C + :: .account=$eval_account + :: .time=$eval_time + :: .cpus=$eval_cpus + :: .gres=$eval_gres + +task GenerateTestAI2D calls GenerateTestTask + < model_dir=$finetuned_dir@FineTuneModel + > outputs + :: repo=@ + :: tokenizer_model=$model_name + :: model_type=@ + :: prompt_format=@ + :: image_preproc=@ + :: pixel_shuffle=@ + :: eval_bsz=@ + :: task=AI2D + :: input_image_path=$ai2d_dir + :: gt_path=$ai2d_gt_path + :: .submitter=@ + :: .C=$eval_C + :: .account=$eval_account + :: .time=$eval_time + :: .cpus=$eval_cpus + :: .gres=$eval_gres + task EvaluateFinetunedModel - < mmmu_dir=@ - < finetuned_dir=@FineTuneModel + < ai2d_outputs=$outputs@GenerateTestAI2D + < textvqa_outputs=$outputs@GenerateTestTextVQA + < mmmu_outputs=$outputs@GenerateTestMMMU + > ai2d_results + > textvqa_results > mmmu_results :: repo=@ :: eval_bsz=@ - :: tokenizer_model=$mistral_model + :: tokenizer_model=$model_name :: prompt_format=@ :: .submitter=@ :: .C=$eval_C @@ -421,32 +711,33 @@ task EvaluateFinetunedModel :: .cpus=$eval_cpus :: .gres=$eval_gres { - export PYTHONPATH="$PYTHONPATH:$repo/examples/multimodal" - bash $repo/examples/multimodal/text_generation_mistral_clip.sh \ - --input-image-path none \ - 
--output-path mmmu_outputs \ - --model-path $finetuned_dir \ - --tokenizer-path $tokenizer_model \ - --tokenizer-prompt-format ${prompt_format} \ - --batch-size $eval_bsz \ - --task MMMU python $repo/examples/multimodal/evaluate_mmmu.py \ - --input-path mmmu_outputs \ + --input-path $mmmu_outputs/MMMU_outputs \ | tee $mmmu_results + + python $repo/examples/multimodal/evaluate_textvqa.py \ + --input-path $textvqa_outputs/TextVQA_outputs \ + | tee $textvqa_results + + python $repo/examples/multimodal/evaluate_ai2d.py \ + --input-path $ai2d_outputs/AI2D_outputs \ + | tee $ai2d_results } func ConvertToHF < model_dir > hf_model_dir :: repo - :: mistral_model + :: model_name :: upload_id + :: hf_model_type { python $repo/examples/multimodal/convert_to_hf.py \ --mcore-load-dir $model_dir \ --hf-save-dir $hf_model_dir \ - --original-text-model-id $mistral_model \ + --original-text-model-id $model_name \ --original-vision-model-id openai/clip-vit-large-patch14-336 \ + --hf-model-type $hf_model_type \ --upload-to-hub $upload_id } @@ -454,8 +745,9 @@ task ConvertPretrainedModel calls ConvertToHF < model_dir=@PretrainModel > hf_model_dir :: repo=@ - :: mistral_model=@ - :: upload_id=@ + :: model_name=@ + :: upload_id=$prt_upload_id + :: hf_model_type=@ :: .submitter=@ :: .account=$convert_account :: .time=$convert_time @@ -466,25 +758,54 @@ task ConvertFinetunedModel calls ConvertToHF < model_dir=$finetuned_dir@FineTuneModel > hf_model_dir :: repo=@ - :: mistral_model=@ + :: model_name=@ + :: upload_id=$sft_upload_id + :: hf_model_type=@ :: .submitter=@ :: .account=$convert_account :: .time=$convert_time :: .cpus=$convert_cpus :: .partition=$convert_partition - -plan TrainPipelineMistral { - reach EvaluatePretrainedModel - reach EvaluateFinetunedModel +summary Evaluation { + of EvaluatePretrainedModel > COCOCider { + cat $coco_results | grep -o "CIDEr: [0-9.]\+" | sed "s/CIDEr: //" > $COCOCider + } + of EvaluateFinetunedModel > MMMUAccuracy TextVQAAccuracy AI2DAccuracy { + cat $mmmu_results | grep -o "MMMU average accuracy: [0-9.]\+" | sed "s/MMMU average accuracy: //" > $MMMUAccuracy + cat $textvqa_results | grep -o "TextVQA Accuracy [0-9.]\+" | sed "s/TextVQA Accuracy //" > $TextVQAAccuracy + cat $ai2d_results | grep -o "AI2D Accuracy [0-9.]\+" | sed "s/AI2D Accuracy //" > $AI2DAccuracy + } } -plan TrainPipelineTower { - reach EvaluatePretrainedModel via (TextModel: tower) - reach EvaluateFinetunedModel via (TextModel: tower) +plan TrainPipelineQwen25 { + reach EvaluateFinetunedModel via (TextModel: qwen2p5_7b) } -plan ConvertModels { - reach ConvertPretrainedModel - #reach ConvertFinetunedModel +plan TrainPipelineEuroLLM { + reach EvaluateFinetunedModel via (TextModel: eurollm_9b) } + +# plan SweepTrainQwen25 { +# reach EvaluatePretrainedModel via (TextModel: qwen2p5_7b) * (PretrainIters: 2000 5000) * (PretrainLR: 0p001) +# reach EvaluatePretrainedModel via (TextModel: qwen2p5_7b) * (PretrainIters: 5000) * (PretrainLR: 0p0005) +# reach EvaluatePretrainedModel via (TextModel: qwen2p5_7b) * (FullUnfreeze: true) * (PretrainIters: 2000 5000) * (PretrainLR: 0p001 0p0005) +# } +# +# plan TrainPipelineEuroLLM9B { +# reach EvaluatePretrainedModel, EvaluateFinetunedModel via (TextModel: eurollm_9b) +# reach EvaluateFinetunedModel via (TextModel: eurollm_9b) * (ImagePreproc: nvlm) * (PixelShuffle: true) +# } +# +# plan BackboneComparison { +# reach EvaluateFinetunedModel via (TextModel: mistral qwen2p5_7b eurollm_9b) +# reach EvaluateFinetunedModel via (TextModel: qwen2p5_7b) * (ImagePreproc: nvlm) * 
(PixelShuffle: true) +# } +# +# plan ConvertModels { +# reach ConvertFinetunedModel via (TextModel: qwen2p5_7b) * (ImagePreproc: nvlm) * (PixelShuffle: true) +# } +# +# plan TestPipeline { +# reach PretrainModel via (TextModel: eurollm_9b) +# } diff --git a/tapes/og_datasets.tconf b/tapes/og_datasets.tconf new file mode 100644 index 0000000000..a3c9e9418b --- /dev/null +++ b/tapes/og_datasets.tconf @@ -0,0 +1,210 @@ +global { + ducttape_experimental_imports=true + ducttape_experimental_submitters=true + ducttape_experimental_multiproc=true + + ducttape_output=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_outputs + repo=/linkhome/rech/genrce01/ued79zb/repos/Megatron-LM-pretrain + + # multimodal model parameters + # (base lm, tp, etc...) + clip_original_dir=/lustre/fswork/projects/rech/qjm/ued79zb/clip_model_og/ + model_name=( + TextModel: + mistral="mistralai/Mistral-7B-Instruct-v0.3" + tower="Unbabel/TowerInstruct-Mistral-7B-v0.2" + qwen2p5_7b="Qwen/Qwen2.5-7B-Instruct" + eurollm_9b="utter-project/EuroLLM-9B-Instruct" + ) + model_type=( + TextModel: + mistral="mistral-7B" + tower="mistral-7B" + qwen2p5_7b="qwen2.5-7B" + eurollm_9b="eurollm-9B" + ) + prompt_format=( + TextModel: + mistral="mistral" + tower="chatml" + qwen2p5_7b="qwen2p0" + eurollm_9b="chatml" + ) + image_preproc=(ImagePreproc: basic nvlm) + pixel_shuffle=(PixelShuffle: false true) + tp=4 + pp=1 + + # pre-training arguments + external_model_dir=( + TextModel: + mistral=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_ckpts/mistral_7b_instruct_prt + tower=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_ckpts/tower_7b_instruct_prt + qwen2p5_7b=(ImagePreproc: + basic=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_ckpts/qwen2p5_7b_instruct_prt + nvlm=( + PixelShuffle: + false=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_ckpts/qwen2p5_7b_instruct_prt_nvlm + true=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_ckpts/qwen2p5_7b_instruct_prt_nvlm_ps + )) + eurollm_9b=(ImagePreproc: + basic=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_ckpts/eurollm_9b_instruct_prt + nvlm=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_ckpts/eurollm_9b_instruct_prt_nvlm + ) + ) + external_tensorboard=( + TextModel: + mistral=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_tbs/mistral_7b_instruct_prt + tower=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_tbs/tower_7b_instruct_prt + qwen2p5_7b=(ImagePreproc: + basic=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_tbs/qwen2p5_7b_instruct_prt + nvlm=( + PixelShuffle: + false=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_tbs/qwen2p5_7b_instruct_prt_nvlm + true=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_tbs/qwen2p5_7b_instruct_prt_nvlm_ps + )) + eurollm_9b=(ImagePreproc: + basic=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_tbs/eurollm_9b_instruct_prt + nvlm=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_tbs/eurollm_9b_instruct_prt_nvlm + ) + ) + pretrain_dataset=/linkhome/rech/genrce01/ued79zb/repos/Megatron-LM-pretrain/examples/multimodal/pixmo_caps.yaml + pretrain_iters=(PretrainIters: 2000 5000) + pretrain_bsz=256 + pretrain_lr=(PretrainLR: 0p001=0.001 0p0005=0.0005) + pretrain_lr_warmup=0.03 + pretrain_unfreeze_lm=(FullUnfreeze: false true) + pretrain_unfreeze_vit=false + #pretrain_save_interval=500 + pretrain_save_interval=2000 + pretrain_eval_interval=500 + + # fine-tuning arguments + sft_dataset=/linkhome/rech/genrce01/ued79zb/repos/Megatron-LM-pretrain/examples/multimodal/sft_dataset.yaml + 
finetune_model_dir=( + TextModel: + mistral=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_ckpts/mistral_7b_instruct_sft + tower=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_ckpts/tower_7b_instruct_sft + qwen2p5_7b=(ImagePreproc: + basic=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_ckpts/qwen2p5_7b_instruct_sft + nvlm=( + PixelShuffle: + false=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_ckpts/qwen2p5_7b_instruct_sft_nvlm + true=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_ckpts/qwen2p5_7b_instruct_sft_nvlm_ps + )) + eurollm_9b=(ImagePreproc: + basic=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_ckpts/eurollm_9b_instruct_sft + nvlm=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_ckpts/eurollm_9b_instruct_sft_nvlm + ) + ) + finetune_tensorboard=( + TextModel: + mistral=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_tbs/mistral_7b_instruct_sft + tower=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_tbs/tower_7b_instruct_sft + qwen2p5_7b=(ImagePreproc: + basic=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_tbs/qwen2p5_7b_instruct_sft + nvlm=( + PixelShuffle: + false=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_tbs/qwen2p5_7b_instruct_sft_nvlm + true=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_tbs/qwen2p5_7b_instruct_sft_nvlm_ps + )) + eurollm_9b=(ImagePreproc: + basic=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_tbs/eurollm_9b_instruct_sft + nvlm=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_tbs/eurollm_9b_instruct_sft_nvlm + ) + ) + finetune_iters=5000 + finetune_bsz=( + TextModel: + mistral=128 + tower=128 + qwen2p5_7b=128 + eurollm_9b=120 + ) + finetune_micro_bsz=( + TextModel: + mistral=8 + tower=8 + qwen2p5_7b=8 + eurollm_9b=6 + ) + finetune_lr=1e-6 + finetune_lr_warmup=0.01 + #finetune_unfreeze_vit=(FullUnfreeze: false true) + finetune_unfreeze_vit=(SFTUnfreezeViT: false true) + #finetune_save_interval=1000 + finetune_save_interval=2500 + finetune_eval_interval=1000 + + num_workers=16 + + # eval arguments + coco_dir=/lustre/fswork/projects/rech/qjm/ued79zb/coco/ + coco_gt_path=/lustre/fswork/projects/rech/qjm/ued79zb/coco/coco_karpathy_test.json + textvqa_dir=/lustre/fswork/projects/rech/qjm/ued79zb/text_vqa/train_images + textvqa_gt_path=/lustre/fswork/projects/rech/qjm/ued79zb/text_vqa/TextVQA_0.5.1_val.json + ai2d_dir=/lustre/fswork/projects/rech/qjm/ued79zb/ai2diagram/AI2D_TEST + ai2d_gt_path=/lustre/fswork/projects/rech/qjm/ued79zb/ai2diagram/test_vlmevalkit.jsonl + eval_bsz=1 + + # convert arguments + # upload_id="patricksf/mistral-7b-clip-prt" + prt_upload_id=( + TextModel: + mistral="Unbabel/mistral-7b-clip-prt-v1" + tower="Unbabel/tower-7b-clip-prt-v1" + qwen2p5_7b="Unbabel/qwen2p5-7b-clip-prt-v1" + eurollm_9b="Unbabel/eurollm-9b-clip-prt-v1" + ) + sft_upload_id=( + TextModel: + mistral="Unbabel/mistral-7b-clip-sft-v1" + tower="Unbabel/tower-7b-clip-sft-v1" + qwen2p5_7b=( + ImagePreproc: + basic="Unbabel/qwen2.5-7b-clip-sft-v1" + nvlm="Unbabel/qwen2p5-7b-clip-hdr-sft-v4" + ) + eurollm_9b="Unbabel/eurollm-9b-clip-sft-v1" + ) + hf_model_type=(ImagePreproc: + basic="llava" + nvlm="nvlm_d" + ) + + # -- submitter arguments -- + submitter=scslurm + + prepare_C="h100" + #prepare_account="qjm@cpu" + prepare_account="qjm@h100" + prepare_time="1:00:00" + prepare_cpus=24 + prepare_gres="gpu:1" + #prepare_partition="prepost" + #prepare_cpus=32 + prepare_partition=none + + pretrain_C="h100" + pretrain_account="qjm@h100" + pretrain_time="10:00:00" + pretrain_cpus=80 + pretrain_gres="gpu:4" + + 
finetune_C="h100" + finetune_account="qjm@h100" + finetune_time="10:00:00" + finetune_cpus=80 + finetune_gres="gpu:4" + + eval_C="h100" + eval_account="qjm@h100" + eval_time="2:00:00" + eval_cpus=80 + eval_gres="gpu:4" + + convert_account="qjm@cpu" + convert_time="1:00:00" + convert_cpus=4 + convert_partition="prepost" +} \ No newline at end of file diff --git a/tapes/pixcaps_pangea.tconf b/tapes/pixcaps_pangea.tconf new file mode 100644 index 0000000000..978e3812f3 --- /dev/null +++ b/tapes/pixcaps_pangea.tconf @@ -0,0 +1,145 @@ +global { + ducttape_output=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_outputs_pixcaps_pangea + repo=/linkhome/rech/genrce01/ued79zb/repos/Megatron-LM-pretrain + + # multimodal model parameters + # (base lm, tp, etc...) + clip_original_dir=/lustre/fswork/projects/rech/qjm/ued79zb/clip_model_og/ + model_name=( + TextModel: + qwen2p5_7b="Qwen/Qwen2.5-7B-Instruct" + eurollm_9b="utter-project/EuroLLM-9B-Instruct" + ) + model_type=( + TextModel: + qwen2p5_7b="qwen2.5-7B" + eurollm_9b="eurollm-9B" + ) + prompt_format=( + TextModel: + qwen2p5_7b="qwen2p0" + eurollm_9b="chatml" + ) + image_preproc=nvlm + pixel_shuffle=true + tp=4 + pp=1 + + # pre-training arguments + external_model_dir=( + TextModel: + qwen2p5_7b=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_ckpts/qwen2p5_7b_instruct_prt + eurollm_9b=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_ckpts/eurollm_9b_instruct_prt + ) + external_tensorboard=( + TextModel: + qwen2p5_7b=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_tbs/qwen2p5_7b_instruct_prt + eurollm_9b=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_tbs/eurollm_9b_instruct_prt + ) + pretrain_dataset=/linkhome/rech/genrce01/ued79zb/repos/Megatron-LM-pretrain/examples/multimodal/pixmo_caps.yaml + pretrain_iters=5000 + pretrain_bsz=256 + pretrain_lr=0.001 + pretrain_lr_warmup=0.03 + pretrain_unfreeze_lm=false + pretrain_unfreeze_vit=false + pretrain_save_interval=5000 + pretrain_eval_interval=1000 + + # fine-tuning arguments + sft_dataset=/linkhome/rech/genrce01/ued79zb/repos/Megatron-LM-pretrain/examples/multimodal/pangea_instruct.yaml + finetune_model_dir=( + TextModel: + qwen2p5_7b=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_ckpts/qwen2p5_7b_instruct_sft + eurollm_9b=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_ckpts/eurollm_9b_instruct_sft + ) + finetune_tensorboard=( + TextModel: + qwen2p5_7b=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_tbs/qwen2p5_7b_instruct_sft + eurollm_9b=/lustre/fswork/projects/rech/qjm/ued79zb/towervision_tbs/eurollm_9b_instruct_sft + ) + finetune_iters=40000 + finetune_bsz=( + TextModel: + qwen2p5_7b=120 + eurollm_9b=128 + ) + finetune_micro_bsz=( + TextModel: + qwen2p5_7b=6 + eurollm_9b=4 + ) + finetune_lr=1e-6 + finetune_lr_warmup=0.01 + finetune_unfreeze_vit=false + finetune_save_interval=10000 + finetune_eval_interval=1000 + finetune_nnodes=1 + finetune_gpus=4 + master_addr=localhost + master_port=29800 + + num_workers=16 + + # eval arguments + coco_dir=/lustre/fswork/projects/rech/qjm/ued79zb/coco/ + coco_gt_path=/lustre/fswork/projects/rech/qjm/ued79zb/coco/coco_karpathy_test.json + textvqa_dir=/lustre/fswork/projects/rech/qjm/ued79zb/text_vqa/train_images + textvqa_gt_path=/lustre/fswork/projects/rech/qjm/ued79zb/text_vqa/TextVQA_0.5.1_val.json + ai2d_dir=/lustre/fswork/projects/rech/qjm/ued79zb/ai2diagram/AI2D_TEST + ai2d_gt_path=/lustre/fswork/projects/rech/qjm/ued79zb/ai2diagram/test_vlmevalkit.jsonl + eval_bsz=1 + + # convert arguments + # 
upload_id="patricksf/mistral-7b-clip-prt" + prt_upload_id=( + TextModel: + qwen2p5_7b="Unbabel/qwen2p5-7b-hdr-prt-pixcaps" + eurollm_9b="Unbabel/eurollm-9b-hdr-prt-pixcaps" + ) + sft_upload_id=( + TextModel: + qwen2p5_7b="Unbabel/qwen2p5-7b-hdr-sft-pangea" + eurollm_9b="Unbabel/eurollm-9b-hdr-sft-pangea" + ) + hf_model_type=nvlm_d + + # -- submitter arguments -- + submitter=scslurm + + prepare_C="h100" + #prepare_account="qjm@cpu" + prepare_account="qjm@h100" + prepare_time="1:00:00" + #prepare_gres="none" + prepare_gres="gpu:1" + prepare_partition=none + prepare_cpus=24 + #prepare_partition="prepost" + #prepare_cpus=32 + + pretrain_C="h100" + pretrain_account="qjm@h100" + pretrain_time="10:00:00" + pretrain_cpus=80 + pretrain_gres="gpu:4" + + finetune_C="h100" + finetune_account="qjm@h100" + finetune_time="99:00:00" + finetune_cpus=80 + finetune_nodes=1 + finetune_gres="gpu:4" + finetune_qos=qos_gpu_h100-t4 + + eval_C="h100" + eval_account="qjm@h100" + eval_time="2:00:00" + eval_cpus=80 + eval_gres="gpu:4" + + convert_account="qjm@cpu" + convert_time="1:00:00" + convert_cpus=4 + convert_partition="prepost" +} \ No newline at end of file diff --git a/tools/checkpoint/loader_llama_mistral.py b/tools/checkpoint/loader_llama_mistral.py index 87062fe079..08466713c3 100644 --- a/tools/checkpoint/loader_llama_mistral.py +++ b/tools/checkpoint/loader_llama_mistral.py @@ -19,7 +19,7 @@ def add_arguments(parser): # TODO(jbarker): Need assertion to make sure *exactly* one of these is used parser.add_argument('--model-size', type=str, required=True, - choices=['llama2-7B', 'llama2-13B', 'llama2-70B', 'llama2-7Bf', 'llama2-13Bf', 'llama2-70Bf', 'llama3-8B', 'llama3-70B', 'llama3-8Bf', 'llama3-70Bf', 'mistral-7B', 'mistral-7Bf', 'yi-34B', 'qwen2.5-7B', 'qwen2.5-72B', 'qwen2.5-7Bf', 'qwen2.5-72Bf'], + choices=['llama2-7B', 'llama2-13B', 'llama2-70B', 'llama2-7Bf', 'llama2-13Bf', 'llama2-70Bf', 'llama3-8B', 'llama3-70B', 'llama3-8Bf', 'llama3-70Bf', 'mistral-7B', 'mistral-7Bf', 'yi-34B', 'qwen2.5-7B', 'qwen2.5-72B', 'qwen2.5-7Bf', 'qwen2.5-72Bf', 'eurollm-9B'], help='Model size can be `llama2-7B`, `llama2-13B`, `llama2-70B`, `llama3-8B`, `llama3-70B`, `mistral-7B`, `qwen2.5-7B`, `qwen2.5-72B` (for pretrained models), ' 'and `llama2-7Bf`, `llama2-13Bf`, `llama2-70Bf`, `llama3-8Bf`, `llama3-70bf`, `mistral-7Bf`, `qwen2.5-7Bf`, and `qwen2.5-72Bf` (for chat-finetuned models).') parser.add_argument('--checkpoint-type', type=str, required=True, @@ -64,6 +64,7 @@ def verify_transformers_version(): "qwen2.5-7Bf": 1, "qwen2.5-72B": 8, "qwen2.5-72Bf": 8, + "eurollm-9B": 1, } @@ -478,6 +479,8 @@ def _load_checkpoint(queue, args): elif "qwen2.5" in args.model_size: margs.tokenizer_type = "HuggingFaceTokenizer" margs.add_qkv_bias = True + elif "eurollm-9B" in args.model_size: + margs.tokenizer_type = "HuggingFaceTokenizer" # Arguments do sanity checks on the world size, but we don't care, # so trick it into thinking we are plenty of processes. 
From f865125d425fa3772f8047e739c24ff988c9d759 Mon Sep 17 00:00:00 2001
From: Patrick Fernandes
Date: Mon, 6 Jan 2025 11:26:14 +0000
Subject: [PATCH 7/7] Update conda_install.sh

---
 conda_install.sh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/conda_install.sh b/conda_install.sh
index aaa026176e..3effa2928e 100644
--- a/conda_install.sh
+++ b/conda_install.sh
@@ -29,7 +29,7 @@ echo "Megatron-LM dir: $DIR"
 source ${CONDA_HOME}/etc/profile.d/conda.sh
 # python can't handle this dependency madness, switch to C++
-# conda create -y -n ${ENV_NAME} python=3.10
+conda create -y -n ${ENV_NAME} python=3.10
 conda activate ${ENV_NAME}
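As a closing usage illustration for the HF checkpoints exported by convert_to_hf.py with --hf-model-type nvlm_d: a minimal sketch, under assumptions, that the converted repo registers the NVLM-D processor and model classes via auto_map with trust_remote_code, that the image placeholder string is "<image>", and that generation is available through the standard GenerationMixin; the repo id is one of the sft_upload_id values from the tconf files and is only a placeholder.

    # Minimal sketch (assumptions above): load a converted NVLM-D checkpoint and caption an image.
    from PIL import Image
    from transformers import AutoModelForCausalLM, AutoProcessor

    repo_id = "Unbabel/qwen2p5-7b-clip-hdr-sft-v4"  # placeholder: any exported nvlm_d repo
    processor = AutoProcessor.from_pretrained(repo_id, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(repo_id, trust_remote_code=True).eval()

    image = Image.open("example.jpg")
    # The processor expands the image placeholder into per-tile tags plus num_image_tokens
    # copies of the image context token (see processing_nvlm_d.py above); "<image>" is an
    # assumed placeholder string here, not confirmed by this patch series.
    inputs = processor(images=image, text="<image>\nDescribe this image.", return_tensors="pt")
    out = model.generate(**inputs, max_new_tokens=64)
    print(processor.batch_decode(out, skip_special_tokens=True)[0])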