diff --git a/.github/workflows/cicd-main.yml b/.github/workflows/cicd-main.yml index d67bf4c6d381..77d97fd6e061 100644 --- a/.github/workflows/cicd-main.yml +++ b/.github/workflows/cicd-main.yml @@ -4185,6 +4185,55 @@ jobs: AFTER_SCRIPT: | rm -f examples/asr/evaluation_transcripts.json + L2_Stable_Diffusion_Training: + needs: [cicd-test-container-setup] + uses: ./.github/workflows/_test_template.yml + with: + RUNNER: self-hosted-azure-gpus-1 + SCRIPT: | + rm -rf examples/multimodal/text_to_image/sd_train_results + + python examples/multimodal/text_to_image/stable_diffusion/sd_train.py \ + trainer.devices=1 \ + trainer.max_steps=3 \ + +trainer.val_check_interval=10 \ + trainer.limit_val_batches=2 \ + trainer.gradient_clip_val=0 \ + exp_manager.exp_dir=examples/multimodal/text_to_image/sd_train_results \ + exp_manager.create_checkpoint_callback=False \ + exp_manager.resume_if_exists=False \ + model.resume_from_checkpoint=null \ + model.precision=16 \ + model.micro_batch_size=1 \ + model.global_batch_size=1 \ + model.first_stage_key=moments \ + model.cond_stage_key=encoded \ + +model.load_vae=False \ + +model.load_unet=False \ + +model.load_encoder=False \ + model.parameterization=v \ + model.load_only_unet=False \ + model.text_embedding_dropout_rate=0.0 \ + model.inductor=True \ + model.inductor_cudagraphs=False \ + model.capture_cudagraph_iters=15 \ + +model.unet_config.num_head_channels=64 \ + +model.unet_config.use_linear_in_transformer=True \ + model.unet_config.context_dim=1024 \ + model.unet_config.use_flash_attention=null \ + model.unet_config.resblock_gn_groups=16 \ + model.unet_config.unet_precision=fp16 \ + +model.unet_config.timesteps=1000 \ + model.optim.name=megatron_fused_adam \ + +model.optim.capturable=True \ + +model.optim.master_weights=True \ + model.optim.weight_decay=0.01 \ + model.first_stage_config.from_pretrained=null \ + model.data.num_workers=16 \ + model.data.synthetic_data=True + AFTER_SCRIPT: | + rm -rf examples/multimodal/text_to_image/sd_train_results + Nemo_CICD_Test: needs: #- OPTIONAL_L0_Unit_Tests_GPU @@ -4279,6 +4328,7 @@ jobs: - L2_TTS_Fast_dev_runs_1_Mixer-TTS - L2_TTS_Fast_dev_runs_1_Hifigan - Speech_Checkpoints_tests + - L2_Stable_Diffusion_Training if: always() runs-on: ubuntu-latest steps: diff --git a/Dockerfile b/Dockerfile index c27048784244..b03c3414e505 100644 --- a/Dockerfile +++ b/Dockerfile @@ -66,8 +66,7 @@ WORKDIR /workspace/ # We leave it here in case we need to work off of a specific commit in main RUN git clone https://github.com/NVIDIA/Megatron-LM.git && \ cd Megatron-LM && \ - git checkout 36e9b6bf3d8034b10c9bbd9fc357c2df2bd1515c && \ - git cherry-pick -n e69187bc3679ea5841030a165d587bb48b56ee77 && \ + git checkout 02871b4df8c69fac687ab6676c4246e936ce92d0 && \ pip install . # Performance optimizations for distributed optimizer: https://github.com/NVIDIA/apex/pull/1771 diff --git a/Dockerfile.ci b/Dockerfile.ci index a1bc61dece62..6d59d300b26f 100644 --- a/Dockerfile.ci +++ b/Dockerfile.ci @@ -34,7 +34,7 @@ WORKDIR /workspace # Install NeMo requirements ARG TE_TAG=bfe21c3d68b0a9951e5716fb520045db53419c5e ARG MODELOPT_VERSION=0.11.0 -ARG MCORE_TAG=c90aa1671fc0b97f80fa6c3bb892ce6f8e88e7c9 +ARG MCORE_TAG=02871b4df8c69fac687ab6676c4246e936ce92d0 ARG APEX_TAG=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c RUN \ --mount=type=bind,source=requirements,target=requirements \ diff --git a/README.rst b/README.rst index ab3a4b6b06c9..e24ce6f05a36 100644 --- a/README.rst +++ b/README.rst @@ -45,6 +45,20 @@ Latest News
 Large Language Models and Multimodal
+  NVIDIA releases 340B base, instruct, and reward models pretrained on a total of 9T tokens. (2024-06-18)
+  See documentation and tutorials for SFT, PEFT, and PTQ with Nemotron 340B in the NeMo Framework User Guide.
@@ -417,7 +431,7 @@ The most recent working versions of these dependencies are here: export apex_commit=810ffae374a2b9cb4b5c5e28eaeca7d7998fca0c export te_commit=bfe21c3d68b0a9951e5716fb520045db53419c5e - export mcore_commit=fbb375d4b5e88ce52f5f7125053068caff47f93f + export mcore_commit=02871b4df8c69fac687ab6676c4246e936ce92d0 export nv_pytorch_tag=24.02-py3 When using a released version of NeMo, please refer to the `Software Component Versions `_ for the correct versions. diff --git a/docs/source/nlp/quantization.rst b/docs/source/nlp/quantization.rst index 747938bebedd..500c37dcfb26 100644 --- a/docs/source/nlp/quantization.rst +++ b/docs/source/nlp/quantization.rst @@ -103,7 +103,7 @@ The TensorRT-LLM engine can be conveniently built and run using ``TensorRTLLM`` .. code-block:: python - from nemo.export import TensorRTLLM + from nemo.export.tensorrt_llm import TensorRTLLM trt_llm_exporter = TensorRTLLM(model_dir="/path/to/trt_llm_engine_folder") diff --git a/examples/multimodal/multimodal_llm/neva/conf/llava_config.yaml b/examples/multimodal/multimodal_llm/neva/conf/llava_config.yaml index b47c719fef1d..3ec90b2d1b53 100644 --- a/examples/multimodal/multimodal_llm/neva/conf/llava_config.yaml +++ b/examples/multimodal/multimodal_llm/neva/conf/llava_config.yaml @@ -86,7 +86,7 @@ model: # LLM configs # use GPTModel from megatron.core - mcore_gpt: False + mcore_gpt: True # model architecture encoder_seq_length: 4096 @@ -149,7 +149,7 @@ model: bias_activation_fusion: False megatron_legacy: False - transformer_engine: False + transformer_engine: True fp8: False # enables fp8 in TransformerLayer forward fp8_e4m3: False # sets fp8_format = recipe.Format.E4M3 fp8_hybrid: False # sets fp8_format = recipe.Format.HYBRID diff --git a/examples/nlp/language_modeling/megatron_gpt_continue_training.py b/examples/nlp/language_modeling/megatron_gpt_continue_training.py index 73cbb2abcce8..fd02414f6478 100755 --- a/examples/nlp/language_modeling/megatron_gpt_continue_training.py +++ b/examples/nlp/language_modeling/megatron_gpt_continue_training.py @@ -115,7 +115,11 @@ def load_from_checkpoint_dir(cls, cfg, trainer, modify_confg_fn): gpt_cfg = modify_confg_fn(hparams_file.cfg, cfg, add_cfg_to_tree=True) with tempfile.NamedTemporaryFile(suffix='.yaml') as f: OmegaConf.save(config=gpt_cfg, f=f.name) - model = cls.load_from_checkpoint(checkpoint_path=checkpoint_path, trainer=trainer, hparams_file=f.name,) + model = cls.load_from_checkpoint( + checkpoint_path=checkpoint_path, + trainer=trainer, + hparams_file=f.name, + ) return model @@ -141,11 +145,12 @@ def main(cfg) -> None: gradient_as_bucket_view=cfg.model.gradient_as_bucket_view, find_unused_parameters=False, ) + precision = cfg.trainer.precision if cfg.trainer.precision in [16, '16', 'bf16', '16-mixed', 'bf16-mixed']: scaler = None if cfg.trainer.precision in [16, '16', '16-mixed']: scaler = GradScaler( - init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32), + init_scale=cfg.model.get('native_amp_init_scale', 2**32), growth_interval=cfg.model.get('native_amp_growth_interval', 1000), hysteresis=cfg.model.get('hysteresis', 2), ) @@ -156,7 +161,7 @@ def main(cfg) -> None: plugins.append(MegatronHalfPrecisionPlugin(precision=plugin_precision, device='cuda', scaler=scaler)) else: plugins.append(PipelineMixedPrecisionPlugin(precision=plugin_precision, device='cuda', scaler=scaler)) - + cfg.trainer.precision = None if cfg.get('cluster_type', None) == 'BCP': plugins.append(TorchElasticEnvironment()) @@ -165,6 +170,7 @@ def main(cfg) -> None: if 
'enable_progress_bar' not in cfg.trainer or cfg.trainer.enable_progress_bar: callbacks.append(CustomProgressBar()) trainer = Trainer(plugins=plugins, strategy=strategy, **cfg.trainer, callbacks=callbacks) + cfg.trainer.precision = precision exp_manager(trainer, cfg.exp_manager) diff --git a/nemo/collections/llm/__init__.py b/nemo/collections/llm/__init__.py index 0f60fd7438b9..19911b544f43 100644 --- a/nemo/collections/llm/__init__.py +++ b/nemo/collections/llm/__init__.py @@ -13,11 +13,30 @@ SquadDataModule, ) from nemo.collections.llm.gpt.model import ( + CodeGemmaConfig2B, + CodeGemmaConfig7B, + CodeLlamaConfig7B, + CodeLlamaConfig13B, + CodeLlamaConfig34B, + CodeLlamaConfig70B, + GemmaConfig, + GemmaConfig2B, + GemmaConfig7B, + GemmaModel, GPTConfig, GPTModel, + Llama2Config7B, + Llama2Config13B, + Llama2Config70B, + Llama3Config8B, + Llama3Config70B, + LlamaConfig, + LlamaModel, MaskedTokenLossReduction, Mistral7BConfig, Mistral7BModel, + MixtralConfig, + MixtralModel, gpt_data_step, gpt_forward_step, ) @@ -31,6 +50,25 @@ "MaskedTokenLossReduction", "Mistral7BConfig", "Mistral7BModel", + "MixtralConfig", + "MixtralModel", + "LlamaConfig", + "Llama2Config7B", + "Llama2Config13B", + "Llama2Config70B", + "Llama3Config8B", + "Llama3Config70B", + "CodeLlamaConfig7B", + "CodeLlamaConfig13B", + "CodeLlamaConfig34B", + "CodeLlamaConfig70B", + "LlamaModel", + "GemmaConfig", + "GemmaConfig2B", + "GemmaConfig7B", + "CodeGemmaConfig2B", + "CodeGemmaConfig7B", + "GemmaModel", "PreTrainingDataModule", "FineTuningDataModule", "SquadDataModule", diff --git a/nemo/collections/llm/api.py b/nemo/collections/llm/api.py index 035f9d448bce..30b1bccdcb26 100644 --- a/nemo/collections/llm/api.py +++ b/nemo/collections/llm/api.py @@ -15,7 +15,7 @@ def train( trainer: Trainer, log: Annotated[Optional[NeMoLogger], Config[NeMoLogger]] = None, resume: Annotated[Optional[AutoResume], Config[AutoResume]] = None, - opt: Optional[OptimizerModule] = None, + optim: Optional[OptimizerModule] = None, tokenizer: Optional[str] = None, # TODO: Fix export export: Optional[str] = None, ) -> Path: @@ -28,7 +28,7 @@ def train( trainer (Trainer): The trainer instance configured with a MegatronStrategy. log (NeMoLogger): A nemologger instance. resume (Optional[Union[AutoResume, Resume]]): Resume training from a checkpoint. - opt (Optional[OptimizerModule]): The optimizer module to be used. If not provided, the default optimizer + optim (Optional[OptimizerModule]): The optimizer module to be used. If not provided, the default optimizer from the model will be used. tokenizer (Optional[str]): Tokenizer setting to be applied. Can be 'data' or 'model'. export (Optional[str]): Filename to save the exported checkpoint after training. 
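With this rename the optimizer module is passed as optim and connected to the model before trainer.fit (see the hunk below). A minimal calling sketch, assuming model, data, trainer, and an OptimizerModule instance named my_optim already exist (all names illustrative):

    from nemo.collections.llm.api import train

    # "optim" replaces the old "opt" keyword; tokenizer="data" mirrors the docstring example above.
    train(model=model, data=data, trainer=trainer, optim=my_optim, tokenizer="data")
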
@@ -49,27 +49,18 @@ def train( >>> train(model, data, trainer, tokenizer='data', source='path/to/ckpt.ckpt', export='final.ckpt') PosixPath('/path/to/log_dir') """ - if not isinstance(trainer.strategy, MegatronStrategy): - raise ValueError("Only MegatronStrategy is supported") - _log = log or NeMoLogger() - - if tokenizer: # TODO: Improve this - _use_tokenizer(model, data, tokenizer) - app_state = _log.setup( trainer, resume_if_exists=getattr(resume, "resume_if_exists", False), + task_config=getattr(train, "__io__", None), ) if resume is not None: resume.setup(model, trainer) - if opt: - opt.connect(model) - - trainer.fit(model, data) - - if hasattr(train, "__io__"): - _save_config_img(app_state.exp_dir, train.__io__) + if optim: + optim.connect(model) + if tokenizer: # TODO: Improve this + _use_tokenizer(model, data, tokenizer) trainer.fit(model, data) diff --git a/nemo/collections/llm/fn/activation.py b/nemo/collections/llm/fn/activation.py new file mode 100644 index 000000000000..89b5ba93f0f6 --- /dev/null +++ b/nemo/collections/llm/fn/activation.py @@ -0,0 +1,11 @@ +import torch + + +@torch.jit.script +def gelu_impl(x): + """OpenAI's gelu implementation.""" + return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x))) + + +def openai_gelu(x): + return gelu_impl(x) diff --git a/nemo/collections/llm/gpt/data/pre_training.py b/nemo/collections/llm/gpt/data/pre_training.py index 80e099290b1d..a659823b085e 100644 --- a/nemo/collections/llm/gpt/data/pre_training.py +++ b/nemo/collections/llm/gpt/data/pre_training.py @@ -3,6 +3,7 @@ import pytorch_lightning as pl from pytorch_lightning.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS +from torch.utils import data from torch.utils.data import DataLoader from nemo.lightning.pytorch.plugins import MegatronDataSampler @@ -121,7 +122,7 @@ def _create_dataloader(self, dataset, **kwargs) -> DataLoader: num_workers=self.num_workers, pin_memory=self.pin_memory, persistent_workers=self.persistent_workers, - collate_fn=dataset.collate_fn, + collate_fn=getattr(dataset, 'collate_fn', data.dataloader.default_collate), **kwargs, ) diff --git a/nemo/collections/llm/gpt/model/__init__.py b/nemo/collections/llm/gpt/model/__init__.py index fcb78d6cd397..4f2de2df690e 100644 --- a/nemo/collections/llm/gpt/model/__init__.py +++ b/nemo/collections/llm/gpt/model/__init__.py @@ -5,13 +5,54 @@ gpt_data_step, gpt_forward_step, ) +from nemo.collections.llm.gpt.model.gemma import ( + CodeGemmaConfig2B, + CodeGemmaConfig7B, + GemmaConfig, + GemmaConfig2B, + GemmaConfig7B, + GemmaModel, +) +from nemo.collections.llm.gpt.model.llama import ( + CodeLlamaConfig7B, + CodeLlamaConfig13B, + CodeLlamaConfig34B, + CodeLlamaConfig70B, + Llama2Config7B, + Llama2Config13B, + Llama2Config70B, + Llama3Config8B, + Llama3Config70B, + LlamaConfig, + LlamaModel, +) from nemo.collections.llm.gpt.model.mistral_7b import Mistral7BConfig, Mistral7BModel +from nemo.collections.llm.gpt.model.mixtral import MixtralConfig, MixtralModel __all__ = [ "GPTConfig", "GPTModel", "Mistral7BConfig", "Mistral7BModel", + "MixtralConfig", + "MixtralModel", + "LlamaConfig", + "Llama2Config7B", + "Llama2Config13B", + "Llama2Config70B", + "Llama3Config8B", + "Llama3Config70B", + "CodeLlamaConfig7B", + "CodeLlamaConfig13B", + "CodeLlamaConfig34B", + "CodeLlamaConfig70B", + "GemmaConfig", + "GemmaConfig2B", + "GemmaConfig7B", + "CodeGemmaConfig2B", + "CodeGemmaConfig7B", + "GemmaModel", + "LlamaModel", "MaskedTokenLossReduction", "gpt_data_step", "gpt_forward_step", diff --git 
a/nemo/collections/llm/gpt/model/base.py b/nemo/collections/llm/gpt/model/base.py index e577ddb63d26..f5823fa9acd6 100644 --- a/nemo/collections/llm/gpt/model/base.py +++ b/nemo/collections/llm/gpt/model/base.py @@ -10,7 +10,7 @@ from nemo.collections.llm import fn from nemo.lightning import get_vocab_size, io from nemo.lightning.megatron_parallel import MaskedTokenLossReduction -from nemo.lightning.pytorch.opt import MegatronOptimizerModule, OptimizerModule +from nemo.lightning.pytorch.optim import MegatronOptimizerModule, OptimizerModule if TYPE_CHECKING: from megatron.core.models.gpt.gpt_model import GPTModel as MCoreGPTModel @@ -19,11 +19,11 @@ @dataclass -class GPTConfig(TransformerConfig): +class GPTConfig(TransformerConfig, io.IOMixin): # From megatron.core.models.gpt.gpt_model.GPTModel fp16_lm_cross_entropy: bool = False parallel_output: bool = True - share_embeddings_and_output_weights: bool = False + share_embeddings_and_output_weights: bool = True make_vocab_size_divisible_by: int = 128 position_embedding_type: Literal["learned_absolute", "rope"] = "learned_absolute" rotary_base: int = 10000 @@ -48,7 +48,7 @@ def configure_model(self, tokenizer) -> "MCoreGPTModel": return MCoreGPTModel( self, - transformer_layer_spec=get_gpt_layer_with_transformer_engine_spec(), + transformer_layer_spec=get_gpt_layer_with_transformer_engine_spec(self.num_moe_experts), vocab_size=get_vocab_size(self, tokenizer.vocab_size, self.make_vocab_size_divisible_by), max_sequence_length=self.seq_length, fp16_lm_cross_entropy=self.fp16_lm_cross_entropy, @@ -78,7 +78,8 @@ def __init__( self.optim.connect(self) # This will bind the `configure_optimizers` method def configure_model(self) -> None: - self.module = self.config.configure_model(self.tokenizer) + if not hasattr(self, "module"): + self.module = self.config.configure_model(self.tokenizer) def forward( self, @@ -170,7 +171,7 @@ def gpt_forward_step(model, batch) -> torch.Tensor: def get_batch_on_this_context_parallel_rank(batch): from megatron.core import parallel_state - if cp_size := parallel_state.get_context_parallel_world_size() > 1: + if (cp_size := parallel_state.get_context_parallel_world_size()) > 1: num_valid_tokens_in_ub = None if 'loss_mask' in batch and batch['loss_mask'] is not None: num_valid_tokens_in_ub = batch['loss_mask'].sum() @@ -200,7 +201,7 @@ def get_packed_seq_params(batch): cu_seqlens = batch['cu_seqlens'].squeeze() # remove batch size dimension (mbs=1) # remove -1 "paddings" added in collate_fn - if cu_seqlens_argmin := batch.get('cu_seqlens_argmin', None) is not None: + if (cu_seqlens_argmin := batch.get('cu_seqlens_argmin', None)) is not None: # pre-compute cu_seqlens_argmin in dataset class for perf cu_seqlens = cu_seqlens[: cu_seqlens_argmin.item()] else: diff --git a/nemo/collections/llm/gpt/model/gemma.py b/nemo/collections/llm/gpt/model/gemma.py new file mode 100644 index 000000000000..e58c9152d098 --- /dev/null +++ b/nemo/collections/llm/gpt/model/gemma.py @@ -0,0 +1,299 @@ +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Annotated, Callable, Optional + +import torch + +from nemo.collections.llm.fn.activation import openai_gelu +from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel +from nemo.collections.llm.utils import Config +from nemo.lightning import OptimizerModule, io, teardown + +if TYPE_CHECKING: + from transformers import GemmaForCausalLM + + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + from 
nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec + + +# Note: Gemma requires huggingface transformers >= 4.38 +# Note: these Gemma configs are copied from the corresponding HF model. You may need to modify the parameter for +# your own needs, in particular: seq_length and rotary_base. +@dataclass +class GemmaConfig(GPTConfig): + # configs that are common across model sizes + normalization: str = "RMSNorm" + activation_func: Callable = openai_gelu + gated_linear_unit: bool = True + position_embedding_type: str = "rope" + add_bias_linear: bool = False + seq_length: int = 8192 + kv_channels: int = 256 + share_embeddings_and_output_weights: bool = True + # Note: different behavior compared to Legacy NeMo + # Legacy NeMo does not set layernorm_zero_centered_gamma and instead adds 1 in the HF -> NeMo conversion script + # The present implementation is more in line with the official implementation + layernorm_zero_centered_gamma: bool = True + + +@dataclass +class GemmaConfig2B(GemmaConfig): + num_layers: int = 18 + hidden_size: int = 2048 + num_attention_heads: int = 8 + num_query_groups: int = 1 + ffn_hidden_size: int = 16384 + + +@dataclass +class GemmaConfig7B(GemmaConfig): + num_layers: int = 28 + hidden_size: int = 3072 + num_attention_heads: int = 16 + num_query_groups: int = 16 + ffn_hidden_size: int = 24576 + + +class CodeGemmaConfig2B(GemmaConfig2B): + pass + + +class CodeGemmaConfig7B(GemmaConfig7B): + pass + + +class GemmaModel(GPTModel): + def __init__( + self, + config: Annotated[Optional[GemmaConfig], Config[GemmaConfig]] = None, + optim: Optional[OptimizerModule] = None, + tokenizer: Optional["TokenizerSpec"] = None, + ): + super().__init__(config or GemmaConfig(), optim=optim, tokenizer=tokenizer) + + +@io.model_importer(GemmaModel, "hf") +class HFGemmaImporter(io.ModelConnector["GemmaForCausalLM", GemmaModel]): + def init(self) -> GemmaModel: + return GemmaModel(self.config, tokenizer=self.tokenizer) + + def apply(self, output_path: Path) -> Path: + from transformers import GemmaForCausalLM + + source = GemmaForCausalLM.from_pretrained(str(self)) + target = self.init() + trainer = self.nemo_setup(target) + self.convert_state(source, target) + self.nemo_save(output_path, trainer) + + print(f"Converted Gemma model to Nemo, model saved to {output_path}") + + teardown(trainer, target) + del trainer, target + + return output_path + + def convert_state(self, source, target): + mapping = { + "model.embed_tokens.weight": "embedding.word_embeddings.weight", + "model.layers.*.self_attn.o_proj.weight": "decoder.layers.*.self_attention.linear_proj.weight", + "model.layers.*.mlp.down_proj.weight": "decoder.layers.*.mlp.linear_fc2.weight", + "model.layers.*.input_layernorm.weight": "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight", + "model.layers.*.post_attention_layernorm.weight": "decoder.layers.*.mlp.linear_fc1.layer_norm_weight", + "model.norm.weight": "decoder.final_layernorm.weight", + } + + return io.apply_transforms(source, target, mapping=mapping, transforms=[_import_qkv, _import_linear_fc1]) + + @property + def tokenizer(self) -> "AutoTokenizer": + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + + return AutoTokenizer(str(self)) + + @property + def config(self) -> GemmaConfig: + from transformers import GemmaConfig as HFGemmaConfig + + source = HFGemmaConfig.from_pretrained(str(self)) + + def make_vocab_size_divisible_by(vocab_size): + base = 128 + while vocab_size % base != 0: + base //= 2 + return base + 
+ output = GemmaConfig( + num_layers=source.num_hidden_layers, + hidden_size=source.hidden_size, + ffn_hidden_size=source.intermediate_size, + num_attention_heads=source.num_attention_heads, + init_method_std=source.initializer_range, + layernorm_epsilon=source.rms_norm_eps, + num_query_groups=source.num_key_value_heads, + rotary_base=source.rope_theta, + gated_linear_unit=True, + make_vocab_size_divisible_by=make_vocab_size_divisible_by(source.vocab_size), + share_embeddings_and_output_weights=False, + ) + + return output + + +@io.model_exporter(GemmaModel, "hf") +class HFGemmaExporter(io.ModelConnector[GemmaModel, "GemmaForCausalLM"]): + def init(self) -> "GemmaForCausalLM": + from transformers import AutoModelForCausalLM + + return AutoModelForCausalLM.from_config(self.config) + + def apply(self, output_path: Path) -> Path: + target = self.init() + source, _ = self.nemo_load(str(self)) + target = self.convert_state(source, target) + + target = target.cpu() + target.save_pretrained(output_path) + self.tokenizer.save_pretrained(output_path) + + return output_path + + def convert_state(self, source, target): + mapping = { + "embedding.word_embeddings.weight": "model.embed_tokens.weight", + "decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight", + "decoder.layers.*.mlp.linear_fc2.weight": "model.layers.*.mlp.down_proj.weight", + "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.layers.*.input_layernorm.weight", + "decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "model.layers.*.post_attention_layernorm.weight", + "decoder.final_layernorm.weight": "model.norm.weight", + } + + return io.apply_transforms(source, target, mapping=mapping, transforms=[_export_qkv, _export_linear_fc1]) + + @property + def tokenizer(self): + return io.load_ckpt(str(self)).model.tokenizer.tokenizer + + @property + def config(self) -> "GemmaConfig": + source: GemmaConfig = io.load_ckpt(str(self)).model.config + + from transformers import GemmaConfig as HFGemmaConfig + + return HFGemmaConfig( + num_hidden_layers=source.num_layers, + hidden_size=source.hidden_size, + intermediate_size=source.ffn_hidden_size, + num_attention_heads=source.num_attention_heads, + max_position_embeddings=source.seq_length, + initializer_range=source.init_method_std, + rms_norm_eps=source.layernorm_epsilon, + num_key_value_heads=source.num_query_groups, + vocab_size=self.tokenizer.vocab_size, + ) + + +@io.state_transform( + source_key=( + "model.layers.*.self_attn.q_proj.weight", + "model.layers.*.self_attn.k_proj.weight", + "model.layers.*.self_attn.v_proj.weight", + ), + target_key="decoder.layers.*.self_attention.linear_qkv.weight", +) +def _import_qkv(ctx: io.TransformCTX, q, k, v): + megatron_config = ctx.target.config + + head_num = megatron_config.num_attention_heads + num_query_groups = megatron_config.num_query_groups + heads_per_group = head_num // num_query_groups + hidden_size = megatron_config.hidden_size + head_num = megatron_config.num_attention_heads + head_size = hidden_size // head_num + + old_tensor_shape = q.size() + new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:] + new_kv_tensor_shape = (num_query_groups, head_size) + old_tensor_shape[1:] + + q = q.view(*new_q_tensor_shape) + k = k.view(*new_kv_tensor_shape) + v = v.view(*new_kv_tensor_shape) + + qkv_weights_l = [] + for i in range(num_query_groups): + qkv_weights_l.append(q[i * heads_per_group : (i + 1) * heads_per_group, :, :]) + qkv_weights_l.append(k[i : i + 1, :, :]) + 
qkv_weights_l.append(v[i : i + 1, :, :]) + qkv_weights = torch.cat(qkv_weights_l) + assert qkv_weights.ndim == 3, qkv_weights.shape + assert qkv_weights.shape[0] == (heads_per_group + 2) * num_query_groups, qkv_weights.shape + assert qkv_weights.shape[1] == head_size, qkv_weights.shape + assert qkv_weights.shape[2] == old_tensor_shape[1], qkv_weights.shape + + qkv_weights = qkv_weights.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size]) + + return qkv_weights + + +@io.state_transform( + source_key="decoder.layers.*.self_attention.linear_qkv.weight", + target_key=( + "model.layers.*.self_attn.q_proj.weight", + "model.layers.*.self_attn.k_proj.weight", + "model.layers.*.self_attn.v_proj.weight", + ), +) +def _export_qkv(ctx: io.TransformCTX, linear_qkv): + megatron_config = ctx.source.config + + head_num = megatron_config.num_attention_heads + num_query_groups = megatron_config.num_query_groups + heads_per_group = head_num // num_query_groups + hidden_size = megatron_config.hidden_size + head_num = megatron_config.num_attention_heads + head_size = hidden_size // head_num + qkv_total_dim = head_num + 2 * num_query_groups + + linear_qkv = linear_qkv.reshape([qkv_total_dim, head_size, hidden_size]) + q_slice = torch.cat( + [ + torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) + for i in range(num_query_groups) + ] + ) + k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2)) + v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2)) + + q_proj = linear_qkv[q_slice].reshape(-1, hidden_size).cpu() + k_proj = linear_qkv[k_slice].reshape(-1, hidden_size).cpu() + v_proj = linear_qkv[v_slice].reshape(-1, hidden_size).cpu() + + return q_proj, k_proj, v_proj + + +@io.state_transform( + source_key=("model.layers.*.mlp.gate_proj.weight", "model.layers.*.mlp.up_proj.weight"), + target_key="decoder.layers.*.mlp.linear_fc1.weight", +) +def _import_linear_fc1(down, gate): + return torch.cat((down, gate), axis=0).float() + + +@io.state_transform( + source_key="decoder.layers.*.mlp.linear_fc1.weight", + target_key=("model.layers.*.mlp.gate_proj.weight", "model.layers.*.mlp.up_proj.weight"), +) +def _export_linear_fc1(linear_fc1): + gate_proj, up_proj = torch.chunk(linear_fc1, 2, dim=0) + + return gate_proj, up_proj + + +__all__ = [ + "GemmaConfig", + "GemmaConfig2B", + "GemmaConfig7B", + "CodeGemmaConfig2B", + "CodeGemmaConfig7B", + "GemmaModel", +] diff --git a/nemo/collections/llm/gpt/model/llama.py b/nemo/collections/llm/gpt/model/llama.py new file mode 100644 index 000000000000..aa089b077041 --- /dev/null +++ b/nemo/collections/llm/gpt/model/llama.py @@ -0,0 +1,342 @@ +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Annotated, Callable, Optional + +import torch +import torch.nn.functional as F + +from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel +from nemo.collections.llm.utils import Config +from nemo.lightning import OptimizerModule, io, teardown + +if TYPE_CHECKING: + from transformers import LlamaConfig as HFLlamaConfig + from transformers import LlamaForCausalLM + + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec + + +# Note: these Llama configs are copied from the corresponding HF model. You may need to modify the parameter for +# your own needs, in particular: seq_length and rotary_base. 
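The note above applies to each preset that follows; a hedged sketch of overriding the two fields it calls out, using one of the dataclasses defined below (values are illustrative, not recommendations):

    # Sketch: the presets are plain dataclasses, so fields can be overridden at construction time.
    config = Llama3Config8B(seq_length=4096, rotary_base=500_000)
    model = LlamaModel(config)
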
+@dataclass +class LlamaConfig(GPTConfig): + # configs that are common across model sizes + normalization: str = "RMSNorm" + activation_func: Callable = F.silu + gated_linear_unit: bool = True + position_embedding_type: str = "rope" + add_bias_linear: bool = False + seq_length: int = 4096 + + +@dataclass +class Llama2Config7B(LlamaConfig): + num_layers: int = 32 + hidden_size: int = 4096 + num_attention_heads: int = 32 + num_query_groups: int = 32 + ffn_hidden_size: int = 11008 + + +@dataclass +class Llama2Config13B(LlamaConfig): + num_layers: int = 40 + hidden_size: int = 5120 + num_attention_heads: int = 40 + num_query_groups: int = 40 + ffn_hidden_size: int = 13824 + + +@dataclass +class Llama2Config70B(LlamaConfig): + num_layers: int = 80 + hidden_size: int = 8192 + num_attention_heads: int = 64 + num_query_groups: int = 8 + ffn_hidden_size: int = 28672 + + +@dataclass +class Llama3Config8B(Llama2Config7B): + seq_length: int = 8192 + num_query_groups: int = 8 + ffn_hidden_size: int = 14336 + + +@dataclass +class Llama3Config70B(Llama2Config70B): + seq_length: int = 8192 + + +@dataclass +class CodeLlamaConfig7B(Llama2Config7B): + rotary_base: int = 1_000_000 + seq_length: int = 16384 + + +@dataclass +class CodeLlamaConfig13B(Llama2Config13B): + rotary_base: int = 1_000_000 + seq_length: int = 16384 + + +@dataclass +class CodeLlamaConfig34B(LlamaConfig): + num_layers: int = 48 + hidden_size: int = 8192 + num_attention_heads: int = 64 + num_query_groups: int = 8 + ffn_hidden_size: int = 22016 + rotary_base: int = 1_000_000 + seq_length: int = 16384 + + +@dataclass +class CodeLlamaConfig70B(Llama2Config70B): + pass + + +class LlamaModel(GPTModel): + def __init__( + self, + config: Annotated[Optional[LlamaConfig], Config[LlamaConfig]] = None, + optim: Optional[OptimizerModule] = None, + tokenizer: Optional["TokenizerSpec"] = None, + ): + super().__init__(config or LlamaConfig(), optim=optim, tokenizer=tokenizer) + + +@io.model_importer(LlamaModel, "hf") +class HFLlamaImporter(io.ModelConnector["LlamaForCausalLM", LlamaModel]): + def init(self) -> LlamaModel: + return LlamaModel(self.config, tokenizer=self.tokenizer) + + def apply(self, output_path: Path) -> Path: + from transformers import LlamaForCausalLM + + source = LlamaForCausalLM.from_pretrained(str(self)) + target = self.init() + trainer = self.nemo_setup(target) + self.convert_state(source, target) + self.nemo_save(output_path, trainer) + + print(f"Converted Llama model to Nemo, model saved to {output_path}") + + teardown(trainer, target) + del trainer, target + + return output_path + + def convert_state(self, source, target): + mapping = { + "model.embed_tokens.weight": "embedding.word_embeddings.weight", + "model.layers.*.self_attn.o_proj.weight": "decoder.layers.*.self_attention.linear_proj.weight", + "model.layers.*.mlp.down_proj.weight": "decoder.layers.*.mlp.linear_fc2.weight", + "model.layers.*.input_layernorm.weight": "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight", + "model.layers.*.post_attention_layernorm.weight": "decoder.layers.*.mlp.linear_fc1.layer_norm_weight", + "model.norm.weight": "decoder.final_layernorm.weight", + "lm_head.weight": "output_layer.weight", + } + + return io.apply_transforms(source, target, mapping=mapping, transforms=[_import_qkv, _import_linear_fc1]) + + @property + def tokenizer(self) -> "AutoTokenizer": + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + + return AutoTokenizer(str(self)) + + @property + def config(self) -> LlamaConfig: + 
from transformers import LlamaConfig as HFLlamaConfig + + source = HFLlamaConfig.from_pretrained(str(self)) + + def make_vocab_size_divisible_by(vocab_size): + base = 128 + while vocab_size % base != 0: + base //= 2 + return base + + output = LlamaConfig( + num_layers=source.num_hidden_layers, + hidden_size=source.hidden_size, + ffn_hidden_size=source.intermediate_size, + num_attention_heads=source.num_attention_heads, + init_method_std=source.initializer_range, + layernorm_epsilon=source.rms_norm_eps, + num_query_groups=source.num_key_value_heads, + rotary_base=source.rope_theta, + gated_linear_unit=True, + make_vocab_size_divisible_by=make_vocab_size_divisible_by(source.vocab_size), + share_embeddings_and_output_weights=False, + ) + + return output + + +@io.model_exporter(LlamaModel, "hf") +class HFLlamaExporter(io.ModelConnector[LlamaModel, "LlamaForCausalLM"]): + def init(self) -> "LlamaForCausalLM": + from transformers import AutoModelForCausalLM + + return AutoModelForCausalLM.from_config(self.config) + + def apply(self, output_path: Path) -> Path: + target = self.init() + source, _ = self.nemo_load(str(self)) + target = self.convert_state(source, target) + + target = target.cpu() + target.save_pretrained(output_path) + self.tokenizer.save_pretrained(output_path) + + return output_path + + def convert_state(self, source, target): + mapping = { + "embedding.word_embeddings.weight": "model.embed_tokens.weight", + "decoder.layers.*.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight", + "decoder.layers.*.mlp.linear_fc2.weight": "model.layers.*.mlp.down_proj.weight", + "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight": "model.layers.*.input_layernorm.weight", + "decoder.layers.*.mlp.linear_fc1.layer_norm_weight": "model.layers.*.post_attention_layernorm.weight", + "decoder.final_layernorm.weight": "model.norm.weight", + "output_layer.weight": "lm_head.weight", + } + + return io.apply_transforms(source, target, mapping=mapping, transforms=[_export_qkv, _export_linear_fc1]) + + @property + def tokenizer(self): + return io.load_ckpt(str(self)).model.tokenizer.tokenizer + + @property + def config(self) -> "HFLlamaConfig": + source: LlamaConfig = io.load_ckpt(str(self)).model.config + + from transformers import LlamaConfig as HFLlamaConfig + + return HFLlamaConfig( + num_hidden_layers=source.num_layers, + hidden_size=source.hidden_size, + intermediate_size=source.ffn_hidden_size, + num_attention_heads=source.num_attention_heads, + max_position_embeddings=source.seq_length, + initializer_range=source.init_method_std, + rms_norm_eps=source.layernorm_epsilon, + num_key_value_heads=source.num_query_groups, + rope_theta=source.rotary_base, + vocab_size=self.tokenizer.vocab_size, + ) + + +@io.state_transform( + source_key=( + "model.layers.*.self_attn.q_proj.weight", + "model.layers.*.self_attn.k_proj.weight", + "model.layers.*.self_attn.v_proj.weight", + ), + target_key="decoder.layers.*.self_attention.linear_qkv.weight", +) +def _import_qkv(ctx: io.TransformCTX, q, k, v): + megatron_config = ctx.target.config + + head_num = megatron_config.num_attention_heads + num_query_groups = megatron_config.num_query_groups + heads_per_group = head_num // num_query_groups + hidden_size = megatron_config.hidden_size + head_num = megatron_config.num_attention_heads + head_size = hidden_size // head_num + + old_tensor_shape = q.size() + new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:] + new_kv_tensor_shape = (num_query_groups, head_size) + 
old_tensor_shape[1:] + + q = q.view(*new_q_tensor_shape) + k = k.view(*new_kv_tensor_shape) + v = v.view(*new_kv_tensor_shape) + + qkv_weights_l = [] + for i in range(num_query_groups): + qkv_weights_l.append(q[i * heads_per_group : (i + 1) * heads_per_group, :, :]) + qkv_weights_l.append(k[i : i + 1, :, :]) + qkv_weights_l.append(v[i : i + 1, :, :]) + qkv_weights = torch.cat(qkv_weights_l) + assert qkv_weights.ndim == 3, qkv_weights.shape + assert qkv_weights.shape[0] == (heads_per_group + 2) * num_query_groups, qkv_weights.shape + assert qkv_weights.shape[1] == head_size, qkv_weights.shape + assert qkv_weights.shape[2] == old_tensor_shape[1], qkv_weights.shape + + qkv_weights = qkv_weights.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size]) + + return qkv_weights + + +@io.state_transform( + source_key="decoder.layers.*.self_attention.linear_qkv.weight", + target_key=( + "model.layers.*.self_attn.q_proj.weight", + "model.layers.*.self_attn.k_proj.weight", + "model.layers.*.self_attn.v_proj.weight", + ), +) +def _export_qkv(ctx: io.TransformCTX, linear_qkv): + megatron_config = ctx.source.config + + head_num = megatron_config.num_attention_heads + num_query_groups = megatron_config.num_query_groups + heads_per_group = head_num // num_query_groups + hidden_size = megatron_config.hidden_size + head_num = megatron_config.num_attention_heads + head_size = hidden_size // head_num + qkv_total_dim = head_num + 2 * num_query_groups + + linear_qkv = linear_qkv.reshape([qkv_total_dim, head_size, hidden_size]) + q_slice = torch.cat( + [ + torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) + for i in range(num_query_groups) + ] + ) + k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2)) + v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2)) + + q_proj = linear_qkv[q_slice].reshape(-1, hidden_size).cpu() + k_proj = linear_qkv[k_slice].reshape(-1, hidden_size).cpu() + v_proj = linear_qkv[v_slice].reshape(-1, hidden_size).cpu() + + return q_proj, k_proj, v_proj + + +@io.state_transform( + source_key=("model.layers.*.mlp.gate_proj.weight", "model.layers.*.mlp.up_proj.weight"), + target_key="decoder.layers.*.mlp.linear_fc1.weight", +) +def _import_linear_fc1(down, gate): + return torch.cat((down, gate), axis=0).float() + + +@io.state_transform( + source_key="decoder.layers.*.mlp.linear_fc1.weight", + target_key=("model.layers.*.mlp.gate_proj.weight", "model.layers.*.mlp.up_proj.weight"), +) +def _export_linear_fc1(linear_fc1): + gate_proj, up_proj = torch.chunk(linear_fc1, 2, dim=0) + + return gate_proj, up_proj + + +__all__ = [ + "LlamaConfig", + "Llama2Config7B", + "Llama2Config13B", + "Llama2Config70B", + "Llama3Config8B", + "Llama3Config70B", + "CodeLlamaConfig7B", + "CodeLlamaConfig13B", + "CodeLlamaConfig34B", + "CodeLlamaConfig70B", + "LlamaModel", +] diff --git a/nemo/collections/llm/gpt/model/mistral_7b.py b/nemo/collections/llm/gpt/model/mistral_7b.py index 56dd0090346b..619cbb40526e 100644 --- a/nemo/collections/llm/gpt/model/mistral_7b.py +++ b/nemo/collections/llm/gpt/model/mistral_7b.py @@ -10,7 +10,7 @@ from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel from nemo.collections.llm.utils import Config from nemo.lightning import io, teardown -from nemo.lightning.pytorch.opt import OptimizerModule +from nemo.lightning.pytorch.optim import OptimizerModule if TYPE_CHECKING: from transformers import MistralConfig, MistralForCausalLM @@ -71,9 +71,6 @@ def apply(self, output_path: 
Path) -> Path: return output_path - def on_import_ckpt(self, model: pl.LightningModule): - model.tokenizer = self.tokenizer - def convert_state(self, source, target): mapping = { "model.embed_tokens.weight": "embedding.word_embeddings.weight", @@ -111,7 +108,7 @@ def make_vocab_size_divisible_by(mistral_vocab_size): hidden_size=source.hidden_size, ffn_hidden_size=source.intermediate_size, num_attention_heads=source.num_attention_heads, - max_position_embeddings=source.max_position_embeddings, + # max_position_embeddings=source.max_position_embeddings, init_method_std=source.initializer_range, layernorm_epsilon=source.rms_norm_eps, num_query_groups=source.num_key_value_heads, @@ -119,6 +116,7 @@ def make_vocab_size_divisible_by(mistral_vocab_size): gated_linear_unit=True, make_vocab_size_divisible_by=make_vocab_size_divisible_by(source.vocab_size), window_size=[source.sliding_window, 0], + share_embeddings_and_output_weights=False, ) return output diff --git a/nemo/collections/llm/gpt/model/mixtral.py b/nemo/collections/llm/gpt/model/mixtral.py new file mode 100644 index 000000000000..bd0b79f1137a --- /dev/null +++ b/nemo/collections/llm/gpt/model/mixtral.py @@ -0,0 +1,183 @@ +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Callable, Optional + +import torch +import torch.nn.functional as F + +from nemo.collections.llm.gpt.model.base import GPTConfig, GPTModel +from nemo.lightning import io, teardown +from nemo.lightning.pytorch.optim import OptimizerModule + +if TYPE_CHECKING: + from transformers import MistralConfig, MistralForCausalLM + + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + + +@dataclass +class MixtralConfig(GPTConfig): + """ + Config for Mixtral-8x7B model + Official announcement: https://mistral.ai/news/mixtral-of-experts/ + """ + + normalization: str = "RMSNorm" + activation_func: Callable = F.silu + position_embedding_type: str = "rope" + add_bias_linear: bool = False + gated_linear_unit: bool = True + apply_query_key_layer_scaling: bool = False # TODO: Should this be True? 
+ + num_layers: int = 32 + hidden_size: int = 4096 + num_attention_heads: int = 32 + num_query_groups: int = 8 + ffn_hidden_size: int = 14336 + max_position_embeddings: int = 4096 # 32768 + seq_length: int = 4096 # 32768 + # MoE + num_moe_experts: int = 8 + moe_router_topk: int = 1 + + init_method_std: float = 0.02 + layernorm_epsilon: float = 1e-5 + # rotary + rotary_percent: float = 0.5 + rotary_base: float = 10000 + + +class MixtralModel(GPTModel): + def __init__( + self, + config: Optional[MixtralConfig] = None, + optim: Optional[OptimizerModule] = None, + tokenizer: Optional["TokenizerSpec"] = None, + ): + super().__init__(config or MixtralConfig(), optim=optim, tokenizer=tokenizer) + + +@io.model_importer(MixtralModel, ext="hf") +class HFMixtralImporter(io.ModelConnector["MixtralForCausalLM", MixtralModel]): + def init(self) -> MixtralModel: + return MixtralModel(self.config, tokenizer=self.tokenizer) + + def apply(self, output_path: Path) -> Path: + from transformers import MixtralForCausalLM + + source = MixtralForCausalLM.from_pretrained(str(self)) + target = self.init() + trainer = self.nemo_setup(target) + self.convert_state(source, target) + self.nemo_save(output_path, trainer) + + teardown(trainer, target) + del trainer, target + + return output_path + + def convert_state(self, source, target): + mapping = { + "model.embed_tokens.weight": "embedding.word_embeddings.weight", + "model.layers.*.self_attn.o_proj.weight": "decoder.layers.*.self_attention.linear_proj.weight", + "model.layers.*.input_layernorm.weight": "decoder.layers.*.self_attention.linear_qkv.layer_norm_weight", + "model.layers.*.post_attention_layernorm.weight": "decoder.layers.*.pre_mlp_layernorm.weight", + # MoE + "model.layers.*.block_sparse_moe.experts.*.w2.weight": "decoder.layers.*.mlp.experts.local_experts.*.linear_fc2.weight", + "model.layers.*.block_sparse_moe.gate.weight": "decoder.layers.*.mlp.router.weight", + # lm-head + "model.norm.weight": "decoder.final_layernorm.weight", + "lm_head.weight": "output_layer.weight", + } + + return io.apply_transforms(source, target, mapping=mapping, transforms=[_import_qkv, _import_moe_w1_w3]) + + @property + def tokenizer(self) -> "AutoTokenizer": + from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer + + return AutoTokenizer(str(self)) + + @property + def config(self) -> MixtralConfig: + from transformers import MixtralConfig as HfMixtralConfig + + config = HfMixtralConfig.from_pretrained(str(self)) + return MixtralConfig( + activation_func=F.silu, + # network + num_layers=config.num_hidden_layers, + hidden_size=config.hidden_size, + ffn_hidden_size=config.intermediate_size, + max_position_embeddings=config.max_position_embeddings, # TODO + seq_length=config.max_position_embeddings, + # RoPE + position_embedding_type='rope', + rotary_base=config.rope_theta, + # Transformer config + num_attention_heads=config.num_attention_heads, + num_query_groups=config.num_key_value_heads, + num_moe_experts=config.num_local_experts, + moe_router_topk=config.num_experts_per_tok, + # norm + normalization='RMSNorm', + layernorm_epsilon=config.rms_norm_eps, + # Init + init_method_std=config.initializer_range, + gated_linear_unit=True, + # Vocab + make_vocab_size_divisible_by=128, + ) + + +@io.state_transform( + source_key=( + "model.layers.*.self_attn.q_proj.weight", + "model.layers.*.self_attn.k_proj.weight", + "model.layers.*.self_attn.v_proj.weight", + ), + target_key="decoder.layers.*.self_attention.linear_qkv.weight", +) +def _import_qkv(ctx: 
io.TransformCTX, q, k, v): + megatron_config = ctx.target.config + + head_num = megatron_config.num_attention_heads + num_query_groups = megatron_config.num_query_groups + heads_per_group = head_num // num_query_groups + hidden_size = megatron_config.hidden_size + head_num = megatron_config.num_attention_heads + head_size = hidden_size // head_num + + old_tensor_shape = q.size() + new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:] + new_kv_tensor_shape = (num_query_groups, head_size) + old_tensor_shape[1:] + + q = q.view(*new_q_tensor_shape) + k = k.view(*new_kv_tensor_shape) + v = v.view(*new_kv_tensor_shape) + + qkv_weights_l = [] + for i in range(num_query_groups): + qkv_weights_l.append(q[i * heads_per_group : (i + 1) * heads_per_group, :, :]) + qkv_weights_l.append(k[i : i + 1, :, :]) + qkv_weights_l.append(v[i : i + 1, :, :]) + qkv_weights = torch.cat(qkv_weights_l) + assert qkv_weights.ndim == 3, qkv_weights.shape + assert qkv_weights.shape[0] == (heads_per_group + 2) * num_query_groups, qkv_weights.shape + assert qkv_weights.shape[1] == head_size, qkv_weights.shape + assert qkv_weights.shape[2] == old_tensor_shape[1], qkv_weights.shape + + qkv_weights = qkv_weights.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size]) + + return qkv_weights + + +@io.state_transform( + source_key=( + "model.layers.*.block_sparse_moe.experts.*.w1.weight", + "model.layers.*.block_sparse_moe.experts.*.w3.weight", + ), + target_key="decoder.layers.*.mlp.experts.local_experts.*.linear_fc1.weight", +) +def _import_moe_w1_w3(gate_proj, up_proj): + return torch.cat((gate_proj, up_proj), axis=0) diff --git a/nemo/collections/nlp/models/language_modeling/megatron/gpt_layer_modelopt_spec.py b/nemo/collections/nlp/models/language_modeling/megatron/gpt_layer_modelopt_spec.py index f75bd37f91f8..f001e8f58d25 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron/gpt_layer_modelopt_spec.py +++ b/nemo/collections/nlp/models/language_modeling/megatron/gpt_layer_modelopt_spec.py @@ -23,6 +23,7 @@ from megatron.core.transformer.enums import AttnMaskType from megatron.core.transformer.identity_op import IdentityOp from megatron.core.transformer.mlp import MLP, MLPSubmodules + from megatron.core.transformer.moe.moe_layer import MoELayer from megatron.core.transformer.spec_utils import ModuleSpec from megatron.core.transformer.transformer_layer import TransformerLayer, TransformerLayerSubmodules @@ -40,7 +41,7 @@ # Use this spec for Model Optimizer PTQ and TensorRT-LLM export -def get_gpt_layer_modelopt_spec() -> ModuleSpec: +def get_gpt_layer_modelopt_spec(num_experts: int = None) -> ModuleSpec: """Mix the native spec with TENorm. 
This is essentially the native local spec except for the layernorm implementation @@ -67,18 +68,38 @@ def get_gpt_layer_modelopt_spec() -> ModuleSpec: ), self_attn_bda=get_bias_dropout_add, pre_mlp_layernorm=TENorm, - mlp=ModuleSpec( - module=MLP, - submodules=MLPSubmodules( - linear_fc1=ColumnParallelLinear, - linear_fc2=RowParallelLinear, - ), - ), + mlp=_get_mlp_module_spec(num_experts=num_experts), mlp_bda=get_bias_dropout_add, # Map TE-layernorm-fusion keys back sharded_state_dict_keys_map={ 'input_layernorm.': 'self_attention.linear_qkv.layer_norm_', - 'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_', + **({'pre_mlp_layernorm.': 'mlp.linear_fc1.layer_norm_'} if num_experts is None else {}), }, ), ) + + +# Helper function to get module spec for MLP/MoE +def _get_mlp_module_spec(num_experts: int = None, moe_grouped_gemm: bool = False) -> ModuleSpec: + if num_experts is None: + # Dense MLP w/ or w/o TE modules. + return ModuleSpec( + module=MLP, + submodules=MLPSubmodules( + linear_fc1=ColumnParallelLinear, + linear_fc2=RowParallelLinear, + ), + ) + else: + # Mixture of experts with modules in megatron core. + return ModuleSpec( + module=MoELayer, + submodules=( + MLPSubmodules( + linear_fc1=ColumnParallelLinear, + linear_fc2=RowParallelLinear, + ) + if not moe_grouped_gemm + else None + ), + ) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index eb7d7b694e2f..ae409b1b72bf 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -155,7 +155,7 @@ def get_specs(spec_name, num_experts=None, moe_grouped_gemm=False, use_te=True, "te_gpt": get_gpt_layer_with_transformer_engine_spec(num_experts, moe_grouped_gemm), "megatron_falcon_gpt": get_falcon_layer_spec(), "megatron_gpt_full_te_layer_autocast": get_gpt_full_te_layer_autocast_spec(), - "modelopt": get_gpt_layer_modelopt_spec(), + "modelopt": get_gpt_layer_modelopt_spec(num_experts), "te_gpt_hyena": get_gpt_layer_with_te_and_hyena_spec(hyena_cfg), } if spec_name not in name_spec_dict: @@ -535,8 +535,6 @@ def setup_mcore_distributed_parallel(self): config, ddp_config, model_chunk, - data_parallel_group=parallel_state.get_data_parallel_group(with_context_parallel=True), - expert_data_parallel_group=parallel_state.get_data_modulo_expert_parallel_group(), # Turn off bucketing for model_chunk 2 onwards, since communication for these # model chunks is overlapped with compute anyway. disable_bucketing=(model_chunk_idx > 0), @@ -1474,15 +1472,16 @@ def build_train_valid_test_datasets(self): # E = argmin_e e * N_d >= N, or equivalently E = ceildiv(N, N_d) # Where N_d is the total number of samples in a dataset (files), and N is the requested number of samples (provided for every split in the list below). 
# Setting N = 1 we force E to be 1 as well + legacy_dataset = self.cfg.data.get("legacy_dataset", False) if self.trainer.limit_val_batches <= 1.0 and isinstance(self.trainer.limit_val_batches, float): - train_valid_test_num_samples[1] = None + train_valid_test_num_samples[1] = 1 if legacy_dataset else None # Add extra FIM tokens to tokenizer if self.cfg.data.get('add_fim', False) and self.cfg.tokenizer.library == 'megatron': fim_tokens = self.cfg.data.fim.extra_tokens fim_tokens = [fim_tokens.prefix, fim_tokens.middle, fim_tokens.suffix, fim_tokens.pad, fim_tokens.eod] self.tokenizer.add_special_tokens({'additional_special_tokens': fim_tokens}) - if self.cfg.data.get("legacy_dataset", False): + if legacy_dataset: self._train_ds, self._validation_ds, self._test_ds = build_train_valid_test_datasets( cfg=self.cfg, trainer=self.trainer, diff --git a/nemo/collections/nlp/parts/nlp_overrides.py b/nemo/collections/nlp/parts/nlp_overrides.py index 0555776457a5..2fdb1906c31f 100644 --- a/nemo/collections/nlp/parts/nlp_overrides.py +++ b/nemo/collections/nlp/parts/nlp_overrides.py @@ -444,6 +444,9 @@ def _check_param_groups_mismatch(self, checkpoint_path: Union[str, Path], sharde bool: True if the number of param groups does not match """ common_state_dict = dist_checkpointing.load_common_state_dict(checkpoint_path) + # @akoumparouli: check if it contains an mcore dist opt + if common_state_dict.get('optimizer_states', [{}])[0].get('param_groups', None) is None: + return False model_param_groups = self._get_param_group(common_state_dict) checkpoint_param_groups = self._get_param_group(sharded_state_dict) return len(model_param_groups) != len(checkpoint_param_groups) diff --git a/nemo/deploy/deploy_pytriton.py b/nemo/deploy/deploy_pytriton.py index 25e09cf3eacc..1e1333f03b55 100644 --- a/nemo/deploy/deploy_pytriton.py +++ b/nemo/deploy/deploy_pytriton.py @@ -29,7 +29,7 @@ class DeployPyTriton(DeployBase): Example: from nemo.deploy import DeployPyTriton, NemoQueryLLM - from nemo.export import TensorRTLLM + from nemo.export.tensorrt_llm import TensorRTLLM trt_llm_exporter = TensorRTLLM(model_dir="/path/for/model/files") trt_llm_exporter.export( diff --git a/nemo/deploy/nlp/__init__.py b/nemo/deploy/nlp/__init__.py index ae4db1ce6f2a..a2110931c6df 100644 --- a/nemo/deploy/nlp/__init__.py +++ b/nemo/deploy/nlp/__init__.py @@ -19,4 +19,8 @@ except Exception: use_query_llm = False -from nemo.deploy.nlp.megatronllm_deployable import MegatronLLMDeployable +use_megatron_llm = True +try: + from nemo.deploy.nlp.megatronllm_deployable import MegatronLLMDeployable +except Exception: + use_megatron_llm = False diff --git a/nemo/export/__init__.py b/nemo/export/__init__.py index 55712d98852c..d9155f923f18 100644 --- a/nemo/export/__init__.py +++ b/nemo/export/__init__.py @@ -11,15 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
- - -import logging - -LOGGER = logging.getLogger("NeMo") - - -use_TensorRTLLM = True -try: - from nemo.export.tensorrt_llm import TensorRTLLM -except Exception as e: - LOGGER.warning("TensorRTLLM could not be imported.") diff --git a/nemo/export/trt_llm/nemo_ckpt_loader/sentencepiece_tokenizer.py b/nemo/export/sentencepiece_tokenizer.py similarity index 93% rename from nemo/export/trt_llm/nemo_ckpt_loader/sentencepiece_tokenizer.py rename to nemo/export/sentencepiece_tokenizer.py index 1f86c5887a5e..e47b1c665af5 100644 --- a/nemo/export/trt_llm/nemo_ckpt_loader/sentencepiece_tokenizer.py +++ b/nemo/export/sentencepiece_tokenizer.py @@ -22,7 +22,7 @@ class SentencePieceTokenizer: """ - Sentencepiecetokenizer https://github.com/google/sentencepiece + SentencePieceTokenizer https://github.com/google/sentencepiece Args: model_path: path to sentence piece tokenizer model. @@ -247,3 +247,21 @@ def vocab(self): for i in range(self.vocab_size - self.original_vocab_size) ] return main_vocab + special_tokens + + ### Below are a few methods that mimic transformers.PreTrainedTokenizer for vLLM + + def convert_ids_to_tokens(self, ids, skip_special_tokens: bool = False): + return self.ids_to_tokens(ids) # TODO: support skip_special_tokens + + def convert_tokens_to_string(self, tokens: List[str]): + return self.tokens_to_text(tokens) + + def __len__(self): + return self.vocab_size + + @property + def is_fast(self): + return True + + def get_added_vocab(self): + return None diff --git a/nemo/export/tensorrt_llm.py b/nemo/export/tensorrt_llm.py index 7cc92f0ca588..8016c352d4b1 100644 --- a/nemo/export/tensorrt_llm.py +++ b/nemo/export/tensorrt_llm.py @@ -68,7 +68,7 @@ class TensorRTLLM(ITritonDeployable): Exports nemo checkpoints to TensorRT-LLM and run fast inference. Example: - from nemo.export import TensorRTLLM + from nemo.export.tensorrt_llm import TensorRTLLM trt_llm_exporter = TensorRTLLM(model_dir="/path/for/model/files") trt_llm_exporter.export( @@ -132,6 +132,7 @@ def export( use_embedding_sharing: bool = False, paged_kv_cache: bool = True, remove_input_padding: bool = True, + paged_context_fmha: bool = False, dtype: str = "bfloat16", load_model: bool = True, enable_multi_block_mode: bool = False, @@ -162,6 +163,7 @@ def export( use_parallel_embedding (bool): whether to use parallel embedding feature of TRT-LLM or not use_embedding_sharing (bool): paged_kv_cache (bool): if True, uses kv cache feature of the TensorRT-LLM. + paged_context_fmha (bool): whether to use paged context fmha feature of TRT-LLM or not remove_input_padding (bool): enables removing input padding or not. dtype (str): Floating point type for model weights (Supports BFloat16/Float16). load_model (bool): load TensorRT-LLM model after the export. @@ -295,6 +297,7 @@ def export( enable_multi_block_mode=enable_multi_block_mode, paged_kv_cache=paged_kv_cache, remove_input_padding=remove_input_padding, + paged_context_fmha=paged_context_fmha, max_num_tokens=max_num_tokens, opt_num_tokens=opt_num_tokens, ) diff --git a/nemo/export/trt_llm/nemo_ckpt_loader/__init__.py b/nemo/export/trt_llm/nemo_ckpt_loader/__init__.py index c9c6f65d27e0..d9155f923f18 100644 --- a/nemo/export/trt_llm/nemo_ckpt_loader/__init__.py +++ b/nemo/export/trt_llm/nemo_ckpt_loader/__init__.py @@ -11,6 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
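Putting the corrected import path (see the quantization.rst and docstring hunks earlier in this diff) together with the new paged_context_fmha flag on export(), a hedged usage sketch (paths are illustrative; the checkpoint keyword is assumed from the pre-existing export API, not introduced here):

    from nemo.export.tensorrt_llm import TensorRTLLM

    trt_llm_exporter = TensorRTLLM(model_dir="/path/for/model/files")
    trt_llm_exporter.export(
        nemo_checkpoint_path="/path/to/model.nemo",  # assumed keyword from the existing API
        paged_kv_cache=True,
        paged_context_fmha=True,  # new flag, forwarded to plugin_config.use_paged_context_fmha
    )
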
- - -from nemo.export.trt_llm.nemo_ckpt_loader.sentencepiece_tokenizer import SentencePieceTokenizer diff --git a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py index 09eae628999a..1d473f497f51 100644 --- a/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py +++ b/nemo/export/trt_llm/nemo_ckpt_loader/nemo_file.py @@ -28,8 +28,8 @@ from torch.distributed.checkpoint import FileSystemReader from transformers import AutoTokenizer, PreTrainedTokenizer +from nemo.export.sentencepiece_tokenizer import SentencePieceTokenizer from nemo.export.tarutils import TarPath, ZarrPathStore -from nemo.export.trt_llm.nemo_ckpt_loader.sentencepiece_tokenizer import SentencePieceTokenizer LOGGER = logging.getLogger("NeMo") diff --git a/nemo/export/trt_llm/qnemo/tokenizer_utils.py b/nemo/export/trt_llm/qnemo/tokenizer_utils.py index 4b0775a0aa2a..c3dd5c2befc9 100644 --- a/nemo/export/trt_llm/qnemo/tokenizer_utils.py +++ b/nemo/export/trt_llm/qnemo/tokenizer_utils.py @@ -17,7 +17,7 @@ from omegaconf import OmegaConf from transformers import AutoTokenizer -from nemo.export.trt_llm.nemo_ckpt_loader.sentencepiece_tokenizer import SentencePieceTokenizer +from nemo.export.sentencepiece_tokenizer import SentencePieceTokenizer # TODO: use get_nmt_tokenizer helper below to instantiate tokenizer once environment / dependencies get stable # from nemo.collections.nlp.modules.common.tokenizer_utils import get_nmt_tokenizer diff --git a/nemo/export/trt_llm/tensorrt_llm_build.py b/nemo/export/trt_llm/tensorrt_llm_build.py index ef9a14c1d582..f73ac309a475 100644 --- a/nemo/export/trt_llm/tensorrt_llm_build.py +++ b/nemo/export/trt_llm/tensorrt_llm_build.py @@ -44,6 +44,7 @@ def build_and_save_engine( enable_multi_block_mode: bool = False, paged_kv_cache: bool = True, remove_input_padding: bool = True, + paged_context_fmha: bool = False, max_num_tokens: int = None, opt_num_tokens: int = None, max_beam_width: int = 1, @@ -65,6 +66,7 @@ def build_and_save_engine( else: plugin_config.paged_kv_cache = False plugin_config.remove_input_padding = remove_input_padding + plugin_config.use_paged_context_fmha = paged_context_fmha max_num_tokens, opt_num_tokens = check_max_num_tokens( max_num_tokens=max_num_tokens, diff --git a/nemo/export/vllm/__init__.py b/nemo/export/vllm/__init__.py new file mode 100644 index 000000000000..d9155f923f18 --- /dev/null +++ b/nemo/export/vllm/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/nemo/export/vllm/engine.py b/nemo/export/vllm/engine.py new file mode 100644 index 000000000000..0a3600e7b1eb --- /dev/null +++ b/nemo/export/vllm/engine.py @@ -0,0 +1,101 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +from pathlib import Path + +from vllm import LLMEngine +from vllm.transformers_utils.tokenizer_group.tokenizer_group import TokenizerGroup + +from nemo.export.sentencepiece_tokenizer import SentencePieceTokenizer +from nemo.export.tarutils import TarPath +from nemo.export.vllm.tokenizer_group import NemoTokenizerGroup + +LOGGER = logging.getLogger("NeMo") + + +class NemoLLMEngine(LLMEngine): + """ + Overrides some functionality from vllm.LLMEngine to use our custom tokenizer + instead of one from Transformers. + """ + + def _init_tokenizer(self, **tokenizer_init_kwargs): + # Find the tokenizer file name in the Nemo checkpoint config + tokenizer_config = self.model_config.nemo_model_config.get('tokenizer', {}) + tokenizer_model = tokenizer_config.get('model', tokenizer_config.get('tokenizer_model', None)) + + # If there is no tokenizer file specified but there's a reference to an HF tokenizer, use that + if tokenizer_model is None and tokenizer_config.get('library') == 'huggingface': + tokenizer_type = tokenizer_config.get('type') + if tokenizer_type is not None: + tokenizer_group = TokenizerGroup( + tokenizer_id=tokenizer_type, + enable_lora=bool(self.lora_config), + max_num_seqs=self.scheduler_config.max_num_seqs, + max_input_length=None, + ) + + # Update the HF config fields that come from the tokenizer in NeMo + self.model_config.hf_config.vocab_size = tokenizer_group.tokenizer.vocab_size + self.model_config.hf_config.bos_token_id = tokenizer_group.tokenizer.bos_token_id + self.model_config.hf_config.eos_token_id = tokenizer_group.tokenizer.eos_token_id + self.model_config.hf_config.pad_token_id = tokenizer_group.tokenizer.pad_token_id + + return tokenizer_group + + # Open the checkpoint archive + with TarPath(self.model_config.nemo_checkpoint) as archive: + tokenizer_model_file = None + if isinstance(tokenizer_model, str) and tokenizer_model.startswith('nemo:'): + tokenizer_model = tokenizer_model[len('nemo:') :] + tokenizer_model_file = archive / tokenizer_model + if not tokenizer_model_file.exists(): + LOGGER.warn( + f'Tokenizer model file {tokenizer_model} specified in the model_config does not ' + + 'exist in the checkpoint.' + ) + tokenizer_model_file = None + + if tokenizer_model_file is None: + for path in archive.glob('*tokenizer*.model'): + LOGGER.info(f'Found tokenizer model file {path}.') + tokenizer_model_file = path + break + + if tokenizer_model_file is None: + raise RuntimeError('No tokenizer model file found, aborting.') + + # Extract the tokenizer model file into the model directory, + # because sentencepiece cannot load it directly from TarPath. 
+ extracted_tokenizer_model = Path(self.model_config.model) / 'tokenizer.model' + with tokenizer_model_file.open('rb') as infile: + with extracted_tokenizer_model.open('wb') as outfile: + outfile.write(infile.read()) + + # Construct the tokenizer object and wrapper + tokenizer = SentencePieceTokenizer(str(extracted_tokenizer_model)) + + # Determine if the model needs a bos token (which is not stored in Nemo checkpoints) + add_bos_token = self.model_config.model_converter.requires_bos_token() + + tokenizer_group = NemoTokenizerGroup(tokenizer, add_bos_token=add_bos_token) + + # Update the HF config fields that come from the tokenizer in NeMo + self.model_config.hf_config.vocab_size = tokenizer.vocab_size + self.model_config.hf_config.bos_token_id = tokenizer.bos_token_id + self.model_config.hf_config.eos_token_id = tokenizer.eos_token_id + self.model_config.hf_config.pad_token_id = tokenizer.pad_id + + return tokenizer_group diff --git a/nemo/export/vllm/model_config.py b/nemo/export/vllm/model_config.py new file mode 100644 index 000000000000..0a98a9180c1d --- /dev/null +++ b/nemo/export/vllm/model_config.py @@ -0,0 +1,135 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Union + +import torch +import yaml +from transformers import AutoConfig +from vllm.config import ModelConfig, _get_and_verify_dtype, _get_and_verify_max_len +from vllm.transformers_utils.config import get_hf_text_config + +from nemo.export.tarutils import TarPath +from nemo.export.vllm.model_converters import get_model_converter + + +class NemoModelConfig(ModelConfig): + """ + This class pretents to be a vllm.config.ModelConfig (with extra fields) but skips + some of its initialization code, and initializes the configuration from a Nemo checkpoint instead. + """ + + def __init__( + self, + nemo_checkpoint: str, + model_dir: str, + model_type: str, + tokenizer_mode: str, + dtype: Union[str, torch.dtype], + seed: int, + revision: Optional[str] = None, + code_revision: Optional[str] = None, + rope_scaling: Optional[dict] = None, + rope_theta: Optional[float] = None, + tokenizer_revision: Optional[str] = None, + max_model_len: Optional[int] = None, + quantization: Optional[str] = None, + quantization_param_path: Optional[str] = None, + enforce_eager: bool = False, + max_seq_len_to_capture: Optional[int] = None, + max_logprobs: int = 5, + disable_sliding_window: bool = False, + ) -> None: + # Don't call ModelConfig.__init__ because we don't want it to call + # transformers.AutoConfig.from_pretrained(...) 
+ + # TODO: Do something about vLLM's call to _load_generation_config_dict in LLMEngine.__init__ + # because it calls transformers.GenerationConfig.from_pretrained(...), which tries to download things + + self.nemo_checkpoint = nemo_checkpoint + self.model = model_dir + self.model_type = model_type + self.tokenizer = None + self.tokenizer_mode = tokenizer_mode + self.skip_tokenizer_init = False + self.trust_remote_code = False + self.seed = seed + self.revision = revision + self.code_revision = code_revision + self.rope_scaling = rope_scaling + self.rope_theta = rope_theta + self.tokenizer_revision = tokenizer_revision + self.quantization = quantization + self.quantization_param_path = quantization_param_path + self.enforce_eager = enforce_eager + self.max_seq_len_to_capture = max_seq_len_to_capture + self.max_logprobs = max_logprobs + self.disable_sliding_window = disable_sliding_window + self.served_model_name = nemo_checkpoint + + self.model_converter = get_model_converter(model_type) + if self.model_converter is None: + raise RuntimeError(f'Unknown model type "{model_type}"') + + hf_to_nemo_dict = { + 'hidden_size': 'hidden_size', + 'intermediate_size': 'ffn_hidden_size', + 'num_hidden_layers': 'num_layers', + 'num_attention_heads': 'num_attention_heads', + 'num_key_value_heads': 'num_query_groups', + # 'hidden_act': 'activation', ## <- vLLM has good defaults for the models, nemo values are wrong + 'max_position_embeddings': ['max_position_embeddings', 'encoder_seq_length'], + 'rms_norm_eps': 'layernorm_epsilon', + 'attention_dropout': 'attention_dropout', + 'initializer_range': 'init_method_std', + 'norm_epsilon': 'layernorm_epsilon', + 'rope_theta': 'rotary_base', + 'use_bias': 'bias', + } + + with TarPath(nemo_checkpoint) as archive: + with (archive / "model_config.yaml").open("r") as model_config_file: + self.nemo_model_config = yaml.load(model_config_file, Loader=yaml.SafeLoader) + + hf_args = {} + for hf_arg, nemo_arg in hf_to_nemo_dict.items(): + if not isinstance(nemo_arg, list): + nemo_arg = [nemo_arg] + + for nemo_arg_option in nemo_arg: + value = self.nemo_model_config.get(nemo_arg_option) + if value is not None: + hf_args[hf_arg] = value + break + + self.model_converter.convert_config(self.nemo_model_config, hf_args) + + self.hf_config = AutoConfig.for_model(model_type, **hf_args) + + self.hf_config.architectures = [self.model_converter.get_architecture()] + if self.rope_scaling is not None: + self.hf_config['rope_scaling'] = rope_scaling + + self.hf_text_config = get_hf_text_config(self.hf_config) + self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) + self.max_model_len = _get_and_verify_max_len( + hf_config=self.hf_text_config, + max_model_len=max_model_len, + disable_sliding_window=self.disable_sliding_window, + sliding_window_len=self.get_hf_config_sliding_window(), + ) + self._verify_tokenizer_mode() + self._verify_embedding_mode() + self._verify_quantization() + self._verify_cuda_graph() diff --git a/nemo/export/vllm/model_converters.py b/nemo/export/vllm/model_converters.py new file mode 100644 index 000000000000..595ceecf0b18 --- /dev/null +++ b/nemo/export/vllm/model_converters.py @@ -0,0 +1,410 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod +from typing import Optional, Sequence, Tuple + +import torch + + +class ModelConverter(ABC): + """ + Abstract class that defines the interface for a converter that implements model-specific conversion functions + for deploying NeMo checkpoints on vLLM. + """ + + def __init__(self, model_type: str): + self.model_type = model_type + + @abstractmethod + def get_architecture(self) -> Optional[str]: + """ + Returns the HF architecture name for the current model, such as 'LlamaForCausalLM'. + """ + pass + + def convert_config(self, nemo_model_config: dict, hf_config: dict) -> None: + """ + Implements any custom HF configuration adjustments in the 'hf_config' dict that are necessary + for this model after the common translation takes place in NemoModelConfig's constructor. + """ + pass + + @abstractmethod + def convert_weights(self, nemo_model_config: dict, state_dict: dict) -> Sequence[Tuple[str, torch.tensor]]: + """ + Returns or yields a sequence of (name, tensor) tuples that contain model weights in the HF format. + """ + pass + + def requires_bos_token(self) -> bool: + """ + Returns True if the model requires a 'bos' token to be used at the beginning of the input sequence. + NeMo checkpoints do not store this information. + """ + return False + + +class LlamaConverter(ModelConverter): + + def get_architecture(self): + if self.model_type == 'llama': + return 'LlamaForCausalLM' + if self.model_type == 'mistral': + return 'MistralForCausalLM' + return None + + def convert_weights(self, nemo_model_config, state_dict): + hidden_size = nemo_model_config["hidden_size"] + head_num = nemo_model_config["num_attention_heads"] + num_query_groups = nemo_model_config["num_query_groups"] + num_layers = nemo_model_config["num_layers"] + head_size = hidden_size // head_num + heads_per_group = head_num // num_query_groups + qkv_total_dim = head_num + 2 * num_query_groups + + yield ('model.embed_tokens.weight', state_dict['model.embedding.word_embeddings.weight']) + yield ('model.norm.weight', state_dict['model.decoder.final_layernorm.weight']) + yield ('lm_head.weight', state_dict['model.output_layer.weight']) + + for layer in range(int(num_layers)): + qkv_weights = state_dict['model.decoder.layers.self_attention.linear_qkv.weight'][layer] + qkv_weights = qkv_weights.reshape([qkv_total_dim, head_size, hidden_size]) + + q_slice = torch.cat( + [ + torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) + for i in range(num_query_groups) + ] + ) + k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2)) + v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2)) + + for name, slice in [('q_proj', q_slice), ('k_proj', k_slice), ('v_proj', v_slice)]: + weight_name = f'model.layers.{layer}.self_attn.{name}.weight' + yield (weight_name, qkv_weights[slice].reshape(-1, hidden_size)) + + linear_proj_weight = state_dict['model.decoder.layers.self_attention.linear_proj.weight'][layer] + yield (f'model.layers.{layer}.self_attn.o_proj.weight', linear_proj_weight) + + gate_proj_weight, up_proj_weight = 
torch.chunk( + state_dict['model.decoder.layers.mlp.linear_fc1.weight'][layer], 2, dim=0 + ) + yield (f'model.layers.{layer}.mlp.gate_proj.weight', gate_proj_weight) + yield (f'model.layers.{layer}.mlp.up_proj.weight', up_proj_weight) + + mlp_up_weight = state_dict['model.decoder.layers.mlp.linear_fc2.weight'][layer] + yield (f'model.layers.{layer}.mlp.down_proj.weight', mlp_up_weight) + + input_layernorm_weight = state_dict['model.decoder.layers.self_attention.linear_qkv.layer_norm_weight'][ + layer + ] + yield (f'model.layers.{layer}.input_layernorm.weight', input_layernorm_weight) + + post_attn_layernorm_weight = state_dict['model.decoder.layers.mlp.linear_fc1.layer_norm_weight'][layer] + yield (f'model.layers.{layer}.post_attention_layernorm.weight', post_attn_layernorm_weight) + + def requires_bos_token(self): + return True + + +class MixtralConverter(ModelConverter): + + def get_architecture(self): + if self.model_type == 'mixtral': + return 'MixtralForCausalLM' + return None + + def convert_weights(self, nemo_model_config, state_dict): + hidden_size = nemo_model_config["hidden_size"] + head_num = nemo_model_config["num_attention_heads"] + num_query_groups = nemo_model_config["num_query_groups"] + num_layers = nemo_model_config["num_layers"] + num_moe_experts = nemo_model_config["num_moe_experts"] + head_size = hidden_size // head_num + heads_per_group = head_num // num_query_groups + qkv_total_dim = head_num + 2 * num_query_groups + + yield ('model.embed_tokens.weight', state_dict['model.embedding.word_embeddings.weight']) + yield ('model.norm.weight', state_dict['model.decoder.final_layernorm.weight']) + yield ('lm_head.weight', state_dict['model.output_layer.weight']) + + for layer in range(int(num_layers)): + qkv_weights = state_dict['model.decoder.layers.self_attention.linear_qkv.weight'][layer] + qkv_weights = qkv_weights.reshape([qkv_total_dim, head_size, hidden_size]) + + q_slice = torch.cat( + [ + torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) + for i in range(num_query_groups) + ] + ) + k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2)) + v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2)) + + for name, slice in [('q_proj', q_slice), ('k_proj', k_slice), ('v_proj', v_slice)]: + weight_name = f'model.layers.{layer}.self_attn.{name}.weight' + yield (weight_name, qkv_weights[slice].reshape(-1, hidden_size)) + + linear_proj_weight = state_dict['model.decoder.layers.self_attention.linear_proj.weight'][layer] + yield (f'model.layers.{layer}.self_attn.o_proj.weight', linear_proj_weight) + + mlp_router_weight = state_dict['model.decoder.layers.mlp.router.weight'][layer] + yield (f'model.layers.{layer}.block_sparse_moe.gate.weight', mlp_router_weight) + + for expert in range(num_moe_experts): + linear_fc1_weight = state_dict['model.decoder.layers.mlp.experts.experts.linear_fc1.weight'][layer][ + expert + ] + gate_proj_weight, up_proj_weight = torch.chunk(linear_fc1_weight, 2, dim=0) + yield (f'model.layers.{layer}.block_sparse_moe.experts.{expert}.w1.weight', gate_proj_weight) + yield (f'model.layers.{layer}.block_sparse_moe.experts.{expert}.w3.weight', up_proj_weight) + + linear_fc2_weight = state_dict['model.decoder.layers.mlp.experts.experts.linear_fc2.weight'][layer][ + expert + ] + yield (f'model.layers.{layer}.block_sparse_moe.experts.{expert}.w2.weight', linear_fc2_weight) + + input_layernorm_weight = state_dict['model.decoder.layers.self_attention.linear_qkv.layer_norm_weight'][ + 
layer + ] + yield (f'model.layers.{layer}.input_layernorm.weight', input_layernorm_weight) + + post_attn_layernorm_weight = state_dict['model.decoder.layers.pre_mlp_layernorm.weight'][layer] + yield (f'model.layers.{layer}.post_attention_layernorm.weight', post_attn_layernorm_weight) + + def requires_bos_token(self): + return True + + +class GemmaConverter(ModelConverter): + + def get_architecture(self): + if self.model_type == 'gemma': + return 'GemmaForCausalLM' + return None + + def convert_weights(self, nemo_model_config, state_dict): + num_layers = nemo_model_config["num_layers"] + num_query_groups = nemo_model_config["num_query_groups"] + head_num = nemo_model_config["num_attention_heads"] + head_size = nemo_model_config["kv_channels"] + hidden_size = nemo_model_config["hidden_size"] + heads_per_group = head_num // num_query_groups + + yield ('model.embed_tokens.weight', state_dict['model.embedding.word_embeddings.weight']) + + final_layernorm_weight = state_dict['model.decoder.final_layernorm.weight'] + final_layernorm_weight -= 1.0 + yield ('model.norm.weight', final_layernorm_weight) + + for layer in range(int(num_layers)): + input_layernorm_weight = state_dict['model.decoder.layers.self_attention.linear_qkv.layer_norm_weight'][ + layer + ] + input_layernorm_weight -= 1.0 + yield (f'model.layers.{layer}.input_layernorm.weight', input_layernorm_weight) + + post_attention_layernorm_weight = state_dict['model.decoder.layers.mlp.linear_fc1.layer_norm_weight'][ + layer + ] + post_attention_layernorm_weight -= 1.0 + yield (f'model.layers.{layer}.post_attention_layernorm.weight', post_attention_layernorm_weight) + + gate_up_combined_weight = state_dict['model.decoder.layers.mlp.linear_fc1.weight'][layer] + gate_size = gate_up_combined_weight.shape[0] // 2 + yield (f'model.layers.{layer}.mlp.gate_proj.weight', gate_up_combined_weight[:gate_size, :]) + yield (f'model.layers.{layer}.mlp.up_proj.weight', gate_up_combined_weight[gate_size:, :]) + + down_proj_weight = state_dict['model.decoder.layers.mlp.linear_fc2.weight'][layer] + yield (f'model.layers.{layer}.mlp.down_proj.weight', down_proj_weight) + + self_attn_o_proj_weight = state_dict['model.decoder.layers.self_attention.linear_proj.weight'][layer] + yield (f'model.layers.{layer}.self_attn.o_proj.weight', self_attn_o_proj_weight) + + qkv_weight = state_dict['model.decoder.layers.self_attention.linear_qkv.weight'][layer] + qkv_intermediate_size = head_num + 2 * num_query_groups + qkv_weight = qkv_weight.reshape(qkv_intermediate_size, head_size, hidden_size) + + q_weight = torch.empty((head_num, head_size, hidden_size), dtype=qkv_weight.dtype) + k_weight = torch.empty((num_query_groups, head_size, hidden_size), dtype=qkv_weight.dtype) + v_weight = torch.empty((num_query_groups, head_size, hidden_size), dtype=qkv_weight.dtype) + + ptr = 0 + for i in range(num_query_groups): + q_weight[i * heads_per_group : (i + 1) * heads_per_group, :, :] = qkv_weight[ + ptr : ptr + heads_per_group, :: + ] + ptr += heads_per_group + k_weight[i : i + 1, :, :] = qkv_weight[ptr : ptr + 1, :, :] + ptr += 1 + v_weight[i : i + 1, :, :] = qkv_weight[ptr : ptr + 1, :, :] + ptr += 1 + assert ptr == qkv_intermediate_size + + q_weight = q_weight.reshape(head_num * head_size, hidden_size) + k_weight = k_weight.reshape(num_query_groups * head_size, hidden_size) + v_weight = v_weight.reshape(num_query_groups * head_size, hidden_size) + + yield (f'model.layers.{layer}.self_attn.q_proj.weight', q_weight) + yield (f'model.layers.{layer}.self_attn.k_proj.weight', k_weight) 
+ yield (f'model.layers.{layer}.self_attn.v_proj.weight', v_weight) + + def requires_bos_token(self): + return True + + +class Starcoder2Converter(ModelConverter): + + def get_architecture(self): + if self.model_type == 'starcoder2': + return 'Starcoder2ForCausalLM' + return None + + def convert_config(self, nemo_model_config, hf_config): + window_sizes = nemo_model_config.get('window_size') + if window_sizes is not None: + hf_config['sliding_window'] = window_sizes[0] + + # 'tie_word_embeddings = False' means that there is a 'lm_head.weight' tensor. + # This converter assumes that it's always there. + # If there is a version of starcoder2 where it's not there, we'll need to copy + # 'model.embed_tokens.weight' into 'lm_head.weight' and still set 'tie_word_embeddings = False' + # because at this point we don't know if the weight is there or not, and this configuration + # is not stored in NeMo checkpoints. + hf_config['tie_word_embeddings'] = False + + def convert_weights(self, nemo_model_config, state_dict): + num_layers = nemo_model_config["num_layers"] + num_query_groups = nemo_model_config["num_query_groups"] + head_num = nemo_model_config["num_attention_heads"] + hidden_size = nemo_model_config["hidden_size"] + head_size = hidden_size // head_num + heads_per_group = head_num // num_query_groups + qkv_total_dim = head_num + 2 * num_query_groups + has_bias = nemo_model_config["bias"] + + yield ('model.embed_tokens.weight', state_dict['model.embedding.word_embeddings.weight']) + + yield ('model.norm.weight', state_dict['model.decoder.final_layernorm.weight']) + if has_bias: + yield ('model.norm.bias', state_dict['model.decoder.final_layernorm.bias']) + + yield ('lm_head.weight', state_dict['model.output_layer.weight']) + + for layer in range(int(num_layers)): + # q,k,v + qkv_weights = state_dict['model.decoder.layers.self_attention.linear_qkv.weight'][layer] + qkv_weights = qkv_weights.reshape([qkv_total_dim, head_size, hidden_size]) + if has_bias: + qkv_bias = state_dict['model.decoder.layers.self_attention.linear_qkv.bias'][layer] + qkv_bias = qkv_bias.reshape([qkv_total_dim, head_size]) + + q_slice = torch.cat( + [ + torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) + for i in range(num_query_groups) + ] + ) + k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2)) + v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2)) + + for name, slice in [('q_proj', q_slice), ('k_proj', k_slice), ('v_proj', v_slice)]: + qkv_weights_slice = qkv_weights[slice].reshape(-1, hidden_size) + yield (f'model.layers.{layer}.self_attn.{name}.weight', qkv_weights_slice) + if has_bias: + qkv_bias_slice = qkv_bias[slice].reshape(-1) + yield (f'model.layers.{layer}.self_attn.{name}.bias', qkv_bias_slice) + + # Attention dense + yield ( + f'model.layers.{layer}.self_attn.o_proj.weight', + state_dict[f'model.decoder.layers.self_attention.linear_proj.weight'][layer], + ) + if has_bias: + yield ( + f'model.layers.{layer}.self_attn.o_proj.bias', + state_dict['model.decoder.layers.self_attention.linear_proj.bias'][layer], + ) + + # MLP FC1 + yield ( + f'model.layers.{layer}.mlp.c_fc.weight', + state_dict['model.decoder.layers.mlp.linear_fc1.weight'][layer], + ) + if has_bias: + yield ( + f'model.layers.{layer}.mlp.c_fc.bias', + state_dict['model.decoder.layers.mlp.linear_fc1.bias'][layer], + ) + + # MLP FC2 + yield ( + f'model.layers.{layer}.mlp.c_proj.weight', + state_dict['model.decoder.layers.mlp.linear_fc2.weight'][layer], + ) + 
if has_bias: + yield ( + f'model.layers.{layer}.mlp.c_proj.bias', + state_dict['model.decoder.layers.mlp.linear_fc2.bias'][layer], + ) + + # Input LayerNorm + yield ( + f'model.layers.{layer}.input_layernorm.weight', + state_dict['model.decoder.layers.self_attention.linear_qkv.layer_norm_weight'][layer], + ) + if has_bias: + yield ( + f'model.layers.{layer}.input_layernorm.bias', + state_dict['model.decoder.layers.self_attention.linear_qkv.layer_norm_bias'][layer], + ) + + # Post-attention LayerNorm + yield ( + f'model.layers.{layer}.post_attention_layernorm.weight', + state_dict['model.decoder.layers.mlp.linear_fc1.layer_norm_weight'][layer], + ) + if has_bias: + yield ( + f'model.layers.{layer}.post_attention_layernorm.bias', + state_dict['model.decoder.layers.mlp.linear_fc1.layer_norm_bias'][layer], + ) + + +_MODEL_CONVERTERS = { + 'llama': LlamaConverter, + 'mistral': LlamaConverter, + 'mixtral': MixtralConverter, + 'gemma': GemmaConverter, + 'starcoder2': Starcoder2Converter, +} + + +def register_model_converter(model_type, cls): + """ + Establishes a mapping from short model type to a class that converts the model from Nemo format + to a vLLM compatible format. + """ + _MODEL_CONVERTERS[model_type] = cls + + +def get_model_converter(model_type) -> ModelConverter: + """ + Returns an instance of the the model conversion class for the given model type, or None. + """ + cls = _MODEL_CONVERTERS.get(model_type, None) + if cls is None: + return None + return cls(model_type) diff --git a/nemo/export/vllm/model_loader.py b/nemo/export/vllm/model_loader.py new file mode 100644 index 000000000000..e7f3f1d1569f --- /dev/null +++ b/nemo/export/vllm/model_loader.py @@ -0,0 +1,120 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import logging +import os.path +from typing import Optional + +import numpy +import safetensors.torch +import tensorstore # needed to register 'bfloat16' dtype with numpy for zarr compatibility +import torch +import zarr +from vllm.config import CacheConfig, DeviceConfig, LoRAConfig, ParallelConfig, SchedulerConfig, VisionLanguageConfig +from vllm.model_executor.model_loader.loader import BaseModelLoader, _initialize_model +from vllm.model_executor.model_loader.utils import set_default_torch_dtype + +from nemo.export.tarutils import TarPath, ZarrPathStore +from nemo.export.vllm.model_config import NemoModelConfig + +LOGGER = logging.getLogger("NeMo") + + +class NemoModelLoader(BaseModelLoader): + """ + Implements a custom ModelLoader for vLLM that reads the weights from a Nemo checkpoint + and converts them to a vLLM compatible format at load time. + + Also supports an ahead-of-time conversion that stores new weights in a Safetensors file, + see convert_and_store_nemo_weights(...) 
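A hedged sketch of how a project could plug an additional converter into the registry defined above via register_model_converter(); the converter class, model type, and architecture names here are hypothetical.

    from nemo.export.vllm.model_converters import ModelConverter, register_model_converter

    class MyCustomConverter(ModelConverter):  # hypothetical converter for a hypothetical 'mymodel' type
        def get_architecture(self):
            # Return the HF architecture name vLLM should instantiate for this model type.
            return 'MyModelForCausalLM' if self.model_type == 'mymodel' else None

        def convert_weights(self, nemo_model_config, state_dict):
            # Yield (hf_name, tensor) pairs, mirroring the built-in converters above;
            # only the embedding table is shown in this sketch.
            yield ('model.embed_tokens.weight', state_dict['model.embedding.word_embeddings.weight'])

    register_model_converter('mymodel', MyCustomConverter)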
+ """ + + @staticmethod + def _load_nemo_checkpoint_state(nemo_file: str): + sharded_state_dict = {} + + LOGGER.info(f'Loading weights from {nemo_file}...') + + with TarPath(nemo_file) as archive: + for subdir in archive.iterdir(): + if not subdir.is_dir() or not (subdir / '.zarray').exists(): + continue + key = subdir.name + + zstore = ZarrPathStore(subdir) + arr = zarr.open(zstore, 'r') + + if arr.dtype.name == "bfloat16": + sharded_state_dict[key] = torch.from_numpy(arr[:].view(numpy.int16)).view(torch.bfloat16) + else: + sharded_state_dict[key] = torch.from_numpy(arr[:]) + + arr = None + gc.collect() + + LOGGER.debug(f'Loaded tensor "{key}": {sharded_state_dict[key].shape}') + + return sharded_state_dict + + def load_model( + self, + *, + model_config: NemoModelConfig, + device_config: DeviceConfig, + lora_config: Optional[LoRAConfig], + vision_language_config: Optional[VisionLanguageConfig], + parallel_config: ParallelConfig, + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + ) -> torch.nn.Module: + """ + Overrides the load_model function from BaseModelLoader to convert Nemo weights at load time. + """ + + assert isinstance(model_config, NemoModelConfig) + state_dict = NemoModelLoader._load_nemo_checkpoint_state(model_config.nemo_checkpoint) + + with set_default_torch_dtype(model_config.dtype): + with torch.device(device_config.device): + model = _initialize_model( + model_config, self.load_config, lora_config, vision_language_config, cache_config + ) + + weights_iterator = model_config.model_converter.convert_weights(model_config.nemo_model_config, state_dict) + + model.load_weights(weights_iterator) + + return model.eval() + + @staticmethod + def convert_and_store_nemo_weights(model_config: NemoModelConfig, safetensors_file: str): + """ + Converts Nemo weights and stores the converted weights in a Safetensors file. + """ + + assert isinstance(model_config, NemoModelConfig) + assert os.path.exists(model_config.model) + + state_dict = NemoModelLoader._load_nemo_checkpoint_state(model_config.nemo_checkpoint) + + tensors = { + name: tensor + for name, tensor in model_config.model_converter.convert_weights( + model_config.nemo_model_config, state_dict + ) + } + + LOGGER.info(f'Saving weights to {safetensors_file}...') + safetensors.torch.save_file(tensors, safetensors_file) diff --git a/nemo/export/vllm/tokenizer_group.py b/nemo/export/vllm/tokenizer_group.py new file mode 100644 index 000000000000..6e4aedc14acb --- /dev/null +++ b/nemo/export/vllm/tokenizer_group.py @@ -0,0 +1,55 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional + +from vllm.lora.request import LoRARequest +from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import BaseTokenizerGroup + +from nemo.export.sentencepiece_tokenizer import SentencePieceTokenizer + + +class NemoTokenizerGroup(BaseTokenizerGroup): + """ + Implements a custom tokenizer for vLLM, based on SentencePieceTokenizer. 
+ """ + + def __init__(self, tokenizer: SentencePieceTokenizer, add_bos_token: bool = False): + self.tokenizer = tokenizer + self.add_bos_token = add_bos_token + + def ping(self) -> bool: + return True + + def get_max_input_len(self, lora_request: Optional[LoRARequest] = None) -> Optional[int]: + return None + + def encode( + self, prompt: str, request_id: Optional[str] = None, lora_request: Optional[LoRARequest] = None + ) -> List[int]: + ids = self.tokenizer.encode(prompt) + if self.add_bos_token: + ids = [self.tokenizer.bos_token_id] + ids + return ids + + async def encode_async( + self, prompt: str, request_id: Optional[str] = None, lora_request: Optional[LoRARequest] = None + ) -> List[int]: + return self.tokenizer.encode(prompt) # TODO: not sure how this is supposed to work + + def get_lora_tokenizer(self, lora_request: Optional[LoRARequest] = None) -> SentencePieceTokenizer: + return self.tokenizer + + async def get_lora_tokenizer_async(self, lora_request: Optional[LoRARequest] = None) -> SentencePieceTokenizer: + return self.tokenizer diff --git a/nemo/export/vllm_exporter.py b/nemo/export/vllm_exporter.py new file mode 100644 index 000000000000..f3dd6c8a248b --- /dev/null +++ b/nemo/export/vllm_exporter.py @@ -0,0 +1,417 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os.path +from typing import Iterable, List, Optional, Union + +import numpy +import wrapt +from vllm import RequestOutput, SamplingParams +from vllm.config import CacheConfig, DeviceConfig, LoadConfig, LoadFormat, ParallelConfig, SchedulerConfig +from vllm.executor.ray_utils import initialize_ray_cluster + +from nemo.deploy import ITritonDeployable +from nemo.deploy.utils import cast_output +from nemo.export.vllm.engine import NemoLLMEngine +from nemo.export.vllm.model_config import NemoModelConfig +from nemo.export.vllm.model_loader import NemoModelLoader + +LOGGER = logging.getLogger("NeMo") + + +@wrapt.decorator +def noop_decorator(func): + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + + return wrapper + + +use_pytriton = True +try: + from pytriton.decorators import batch + from pytriton.model_config import Tensor +except Exception: + use_pytriton = False + + +class vLLMExporter(ITritonDeployable): + """ + The Exporter class implements conversion from a Nemo checkpoint format to something compatible with vLLM, + loading the model in vLLM, and binding that model to a Triton server. 
+ + Example: + from nemo.export.vllm import Exporter + from nemo.deploy import DeployPyTriton + + exporter = Exporter() + exporter.export( + nemo_checkpoint='/path/to/checkpoint.nemo', + model_dir='/path/to/temp_dir', + model_type='llama') + + server = DeployPyTriton( + model=exporter, + triton_model_name='LLAMA') + + server.deploy() + server.serve() + server.stop() + """ + + def __init__(self): + self.request_id = 0 + + def export( + self, + nemo_checkpoint: str, + model_dir: str, + model_type: str, + device: str = 'auto', + tensor_parallel_size: int = 1, + pipeline_parallel_size: int = 1, + max_model_len: int = None, + dtype: str = 'auto', + seed: int = 0, + log_stats: bool = True, + weight_storage: str = 'auto', + gpu_memory_utilization: float = 0.9, + ): + """ + Exports the Nemo checkpoint to vLLM and initializes the engine. + + Args: + nemo_checkpoint (str): path to the nemo checkpoint. + model_dir (str): path to a temporary directory to store weights and the tokenizer model. + The temp dir may persist between subsequent export operations, in which case + converted weights may be reused to speed up the export. + model_type (str): type of the model, such as "llama", "mistral", "mixtral". + Needs to be compatible with transformers.AutoConfig. + device (str): type of the device to use by the vLLM engine. + Supported values are "auto", "cuda", "cpu", "neuron". + tensor_parallel_size (int): tensor parallelism. + pipeline_parallel_size (int): pipeline parallelism. + Values over 1 are not currently supported by vLLM. + max_model_len (int): model context length. + dtype (str): data type for model weights and activations. + Possible choices: auto, half, float16, bfloat16, float, float32 + "auto" will use FP16 precision for FP32 and FP16 models, + and BF16 precision for BF16 models. + seed (int): random seed value. + log_stats (bool): enables logging inference performance statistics by vLLM. + weight_storage (str): controls how converted weights are stored: + "file" - always write weights into a file inside 'model_dir', + "memory" - always do an in-memory conversion, + "cache" - reuse existing files if they are newer than the nemo checkpoint, + "auto" - use "cache" for multi-GPU runs and "memory" for single-GPU runs. + gpu_memory_utilization (float): The fraction of GPU memory to be used for the model + executor, which can range from 0 to 1. + """ + + # Pouplate the basic configuration structures + device_config = DeviceConfig(device) + + model_config = NemoModelConfig( + nemo_checkpoint, + model_dir, + model_type, + tokenizer_mode='auto', + dtype=dtype, + seed=seed, + revision=None, + code_revision=None, + tokenizer_revision=None, + max_model_len=max_model_len, + quantization=None, # TODO ??? 
+ quantization_param_path=None, + enforce_eager=False, + max_seq_len_to_capture=None, + ) + + parallel_config = ParallelConfig( + pipeline_parallel_size=pipeline_parallel_size, tensor_parallel_size=tensor_parallel_size + ) + + # See if we have an up-to-date safetensors file + safetensors_file = os.path.join(model_config.model, 'model.safetensors') + safetensors_file_valid = os.path.exists(safetensors_file) and os.path.getmtime( + safetensors_file + ) > os.path.getmtime(nemo_checkpoint) + + # Decide how we're going to convert the weights + if weight_storage == 'auto': + if parallel_config.distributed_executor_backend is not None: + save_weights = not safetensors_file_valid + inmemory_weight_conversion = False + else: + save_weights = False + inmemory_weight_conversion = True + + elif weight_storage == 'cache': + save_weights = not safetensors_file_valid + inmemory_weight_conversion = False + + elif weight_storage == 'file': + save_weights = True + inmemory_weight_conversion = False + + elif weight_storage == 'memory': + save_weights = False + inmemory_weight_conversion = True + + else: + raise ValueError(f'Unsupported value for weight_storage: "{weight_storage}"') + + # Convert the weights ahead-of-time, if needed + if save_weights: + NemoModelLoader.convert_and_store_nemo_weights(model_config, safetensors_file) + elif not inmemory_weight_conversion: + LOGGER.info(f'Using cached weights in {safetensors_file}') + + # TODO: these values are the defaults from vllm.EngineArgs. + cache_config = CacheConfig( + block_size=16, + gpu_memory_utilization=gpu_memory_utilization, + swap_space=4, + cache_dtype='auto', + sliding_window=model_config.get_sliding_window(), + ) + + # TODO: these values are the defaults from vllm.EngineArgs. + scheduler_config = SchedulerConfig( + max_num_batched_tokens=None, + max_num_seqs=256, + # Note: max_model_len can be derived by model_config if the input value is None + max_model_len=model_config.max_model_len, + use_v2_block_manager=False, + num_lookahead_slots=0, + delay_factor=0.0, + enable_chunked_prefill=False, + ) + + load_config = LoadConfig( + load_format=NemoModelLoader if inmemory_weight_conversion else LoadFormat.SAFETENSORS, + download_dir=None, + model_loader_extra_config=None, + ) + + # Initialize the cluster and specify the executor class. + if device_config.device_type == "neuron": + from vllm.executor.neuron_executor import NeuronExecutor + + executor_class = NeuronExecutor + elif device_config.device_type == "cpu": + from vllm.executor.cpu_executor import CPUExecutor + + executor_class = CPUExecutor + elif parallel_config.distributed_executor_backend == "ray": + initialize_ray_cluster(parallel_config) + from vllm.executor.ray_gpu_executor import RayGPUExecutor + + executor_class = RayGPUExecutor + elif parallel_config.distributed_executor_backend == "mp": + from vllm.executor.multiproc_gpu_executor import MultiprocessingGPUExecutor + + executor_class = MultiprocessingGPUExecutor + else: + assert parallel_config.world_size == 1, "Ray is required if parallel_config.world_size > 1." 
+ from vllm.executor.gpu_executor import GPUExecutor + + executor_class = GPUExecutor + + # Initialize the engine + self.engine = NemoLLMEngine( + model_config=model_config, + cache_config=cache_config, + parallel_config=parallel_config, + scheduler_config=scheduler_config, + device_config=device_config, + load_config=load_config, + lora_config=None, + vision_language_config=None, + speculative_config=None, + decoding_config=None, + executor_class=executor_class, + log_stats=log_stats, + ) + + def _add_request_to_engine( + self, prompt: str, max_output_len: int, temperature: float = 1.0, top_k: int = 1, top_p: float = 0.0 + ) -> str: + if top_p <= 0.0: + top_p = 1.0 + + sampling_params = SamplingParams(max_tokens=max_output_len, temperature=temperature, top_k=top_k, top_p=top_p) + + request_id = str(self.request_id) + self.request_id += 1 + + self.engine.add_request(request_id, prompt, sampling_params) + + return request_id + + def _forward_regular(self, request_ids: List[str]): + responses = [None] * len(request_ids) + finished = [False] * len(request_ids) + + while not all(finished): + request_outputs: List[RequestOutput] = self.engine.step() + + for request_output in request_outputs: + if not request_output.finished: + continue + + try: + request_index = request_ids.index(request_output.request_id) + except ValueError: + continue + + finished[request_index] = request_output.finished + output_text = request_output.outputs[-1].text + responses[request_index] = output_text + + return [[response] for response in responses] + + def _forward_streaming(self, request_ids: List[str]): + responses = [None] * len(request_ids) + finished = [False] * len(request_ids) + + while not all(finished): + request_outputs: List[RequestOutput] = self.engine.step() + + for request_output in request_outputs: + try: + request_index = request_ids.index(request_output.request_id) + except ValueError: + continue + + finished[request_index] = request_output.finished + output_text = request_output.outputs[-1].text + responses[request_index] = output_text + + yield [[response] for response in responses] + + def _add_triton_request_to_engine(self, inputs: numpy.ndarray, index: int) -> str: + return self._add_request_to_engine( + prompt=inputs['prompts'][index][0].decode('UTF-8'), + max_output_len=inputs['max_output_len'][index][0], + temperature=inputs['temperature'][index][0], + top_k=inputs['top_k'][index][0], + top_p=inputs['top_p'][index][0], + ) + + @property + def get_triton_input(self): + inputs = ( + Tensor(name="prompts", shape=(-1,), dtype=bytes), + Tensor(name="max_output_len", shape=(-1,), dtype=numpy.int_, optional=True), + Tensor(name="top_k", shape=(-1,), dtype=numpy.int_, optional=True), + Tensor(name="top_p", shape=(-1,), dtype=numpy.single, optional=True), + Tensor(name="temperature", shape=(-1,), dtype=numpy.single, optional=True), + ) + return inputs + + @property + def get_triton_output(self): + outputs = (Tensor(name="outputs", shape=(-1,), dtype=bytes),) + return outputs + + @batch + def triton_infer_fn(self, **inputs: numpy.ndarray): + request_ids = [] + num_requests = len(inputs["prompts"]) + for index in range(num_requests): + request_id = self._add_triton_request_to_engine(inputs, index) + request_ids.append(request_id) + + responses = self._forward_regular(request_ids) + responses = [r[0] for r in responses] + + output_tensor = cast_output(responses, numpy.bytes_) + return {'outputs': output_tensor} + + @batch + def triton_infer_fn_streaming(self, **inputs: numpy.ndarray): + request_ids = [] 
+ num_requests = len(inputs["prompts"]) + for index in range(num_requests): + request_id = self._add_triton_request_to_engine(inputs, index) + request_ids.append(request_id) + + for responses in self._forward_streaming(request_ids): + responses = [r[0] for r in responses] + output_tensor = cast_output(responses, numpy.bytes_) + yield {'outputs': output_tensor} + + # Mimic the TensorRTLLM exporter's forward function, even though we don't support many of its features. + def forward( + self, + input_texts: List[str], + max_output_len: int = 64, + top_k: int = 1, + top_p: float = 0.0, + temperature: float = 1.0, + stop_words_list: Optional[List[str]] = None, + bad_words_list: Optional[List[str]] = None, + no_repeat_ngram_size: Optional[int] = None, + task_ids: Optional[List[str]] = None, + lora_uids: Optional[List[str]] = None, + prompt_embeddings_table=None, + prompt_embeddings_checkpoint_path: Optional[str] = None, + streaming: bool = False, + output_log_probs: bool = False, + ) -> Union[List[List[str]], Iterable[List[List[str]]]]: + """ + The forward function performs LLM evaluation on the provided array of prompts with other parameters shared, + and returns the generated texts. If 'streaming' is True, the output texts are returned incrementally + with a generator: one token appended to each output at a time. If 'streaming' is false, the final output texts + are returned as a single list of responses. + """ + + if stop_words_list is not None and stop_words_list != []: + raise NotImplementedError("stop_words_list is not supported") + + if bad_words_list is not None and bad_words_list != []: + raise NotImplementedError("bad_words_list is not supported") + + if no_repeat_ngram_size is not None: + raise NotImplementedError("no_repeat_ngram_size is not supported") + + if task_ids is not None and task_ids != []: + raise NotImplementedError("task_ids is not supported") + + if lora_uids is not None and lora_uids != []: + raise NotImplementedError("lora_uids is not supported") + + if prompt_embeddings_table is not None: + raise NotImplementedError("prompt_embeddings_table is not supported") + + if prompt_embeddings_checkpoint_path is not None: + raise NotImplementedError("prompt_embeddings_checkpoint_path is not supported") + + if output_log_probs: + raise NotImplementedError("output_log_probs is not supported") + + request_ids = [] + for prompt in input_texts: + request_id = self._add_request_to_engine( + prompt=prompt, max_output_len=max_output_len, temperature=temperature, top_k=top_k, top_p=top_p + ) + request_ids.append(request_id) + + if streaming: + return self._forward_streaming(request_ids) + else: + return self._forward_regular(request_ids) diff --git a/nemo/lightning/__init__.py b/nemo/lightning/__init__.py index 0c5379fb6e82..9484a1dcbd13 100644 --- a/nemo/lightning/__init__.py +++ b/nemo/lightning/__init__.py @@ -12,7 +12,7 @@ from nemo.lightning.base import get_vocab_size, teardown from nemo.lightning.nemo_logger import NeMoLogger from nemo.lightning.pytorch.callbacks.megatron_model_checkpoint import ModelCheckpoint -from nemo.lightning.pytorch.opt import LRSchedulerModule, MegatronOptimizerModule, OptimizerModule +from nemo.lightning.pytorch.optim import LRSchedulerModule, MegatronOptimizerModule, OptimizerModule from nemo.lightning.pytorch.plugins import MegatronDataSampler, MegatronMixedPrecision from nemo.lightning.pytorch.plugins import data_sampler as _data_sampler from nemo.lightning.pytorch.strategies import MegatronStrategy diff --git a/nemo/lightning/io/__init__.py 
b/nemo/lightning/io/__init__.py index d1a193c5e728..1bf17786cf56 100644 --- a/nemo/lightning/io/__init__.py +++ b/nemo/lightning/io/__init__.py @@ -2,9 +2,10 @@ from nemo.lightning.io.capture import reinit from nemo.lightning.io.connector import Connector, ModelConnector from nemo.lightning.io.mixin import ConnectorMixin, IOMixin -from nemo.lightning.io.pl import TrainerCheckpoint, is_distributed_ckpt +from nemo.lightning.io.pl import TrainerContext, is_distributed_ckpt from nemo.lightning.io.state import TransformCTX, apply_transforms, state_transform + __all__ = [ "apply_transforms", "Connector", @@ -20,6 +21,6 @@ "model_exporter", 'reinit', "state_transform", - "TrainerCheckpoint", + "TrainerContext", "TransformCTX", ] diff --git a/nemo/lightning/io/api.py b/nemo/lightning/io/api.py index fbe764d67e3d..a99e0b8d8a92 100644 --- a/nemo/lightning/io/api.py +++ b/nemo/lightning/io/api.py @@ -1,12 +1,12 @@ -import pickle from pathlib import Path from typing import Any, Callable, Optional, Type, TypeVar import fiddle as fdl import pytorch_lightning as pl +from fiddle._src.experimental import serialization from nemo.lightning.io.mixin import ConnectorMixin, ConnT, ModelConnector -from nemo.lightning.io.pl import TrainerCheckpoint +from nemo.lightning.io.pl import TrainerContext CkptType = TypeVar("CkptType") @@ -34,34 +34,34 @@ def load(path: Path, output_type: Type[CkptType] = Any) -> CkptType: _path = Path(path) if hasattr(_path, 'is_dir') and _path.is_dir(): - _path = Path(_path) / "io.pkl" + _path = Path(_path) / "io.json" elif hasattr(_path, 'isdir') and _path.isdir: - _path = Path(_path) / "io.pkl" + _path = Path(_path) / "io.json" if not _path.is_file(): raise FileNotFoundError(f"No such file: '{_path}'") with open(_path, "rb") as f: - config = pickle.load(f) + config = serialization.load_json(f.read()) return fdl.build(config) -def load_ckpt(path: Path) -> TrainerCheckpoint: +def load_ckpt(path: Path) -> TrainerContext: """ - Loads a TrainerCheckpoint from a pickle file or directory. + Loads a TrainerContext from a json-file or directory. Args: - path (Path): The path to the pickle file or directory containing 'io.pkl'. + path (Path): The path to the json-file or directory containing 'io.json'. Returns ------- - TrainerCheckpoint: The loaded TrainerCheckpoint instance. + TrainerContext: The loaded TrainerContext instance. Example: - checkpoint: TrainerCheckpoint = load_ckpt("/path/to/checkpoint") + checkpoint: TrainerContext = load_ckpt("/path/to/checkpoint") """ - return load(path, output_type=TrainerCheckpoint) + return load(path, output_type=TrainerContext) def model_importer(target: Type[ConnectorMixin], ext: str) -> Callable[[Type[ConnT]], Type[ConnT]]: diff --git a/nemo/lightning/io/connector.py b/nemo/lightning/io/connector.py index a6ab4afd6d1b..41c81582bb63 100644 --- a/nemo/lightning/io/connector.py +++ b/nemo/lightning/io/connector.py @@ -217,4 +217,5 @@ def local_path(self, base_path: Optional[Path] = None) -> Path: return _base / str(self).replace("://", "/") - def on_import_ckpt(self, model: pl.LightningModule): ... + def on_import_ckpt(self, model: pl.LightningModule): + model.tokenizer = self.tokenizer diff --git a/nemo/lightning/io/fdl_torch.py b/nemo/lightning/io/fdl_torch.py new file mode 100644 index 000000000000..c74e48e1c411 --- /dev/null +++ b/nemo/lightning/io/fdl_torch.py @@ -0,0 +1,116 @@ +"""Fiddle extensions to handle PyTorch code more elegantly. 
+ +This module provides extensions for better handling of PyTorch types and functions +in codegen, graphviz, and other debugging functions. +""" + +import types + +import libcst as cst +import torch +import torch.nn as nn +from fiddle._src import daglish_extensions +from fiddle._src.codegen import import_manager, py_val_to_cst_converter, special_value_codegen +from fiddle._src.experimental import serialization + + +def _make_torch_importable(name: str) -> special_value_codegen.Importable: + return special_value_codegen.SingleImportable("torch", lambda torch_name: f"{torch_name}.{name}") + + +_torch_type_importables = ( + (torch.bool, _make_torch_importable("bool")), + (torch.uint8, _make_torch_importable("uint8")), + (torch.int8, _make_torch_importable("int8")), + (torch.int16, _make_torch_importable("int16")), + (torch.int32, _make_torch_importable("int32")), + (torch.int64, _make_torch_importable("int64")), + (torch.float16, _make_torch_importable("float16")), + (torch.bfloat16, _make_torch_importable("bfloat16")), + (torch.float32, _make_torch_importable("float32")), + (torch.float64, _make_torch_importable("float64")), + (torch.complex64, _make_torch_importable("complex64")), + (torch.complex128, _make_torch_importable("complex128")), +) + +_torch_initializers = ( + nn.init.constant_, + nn.init.dirac_, + nn.init.xavier_normal_, + nn.init.xavier_uniform_, + nn.init.kaiming_normal_, + nn.init.kaiming_uniform_, + nn.init.normal_, + nn.init.ones_, + nn.init.orthogonal_, + nn.init.uniform_, + nn.init.zeros_, +) + +_import_aliases = (("torch.nn.init", "from torch.nn import init"),) + + +def _make_torch_nn_importable(name: str) -> special_value_codegen.Importable: + return special_value_codegen.SingleImportable("torch", lambda torch_mod_name: f"{torch_mod_name}.nn.{name}") + + +_nn_type_importables = ( + (nn.ReLU, _make_torch_nn_importable("ReLU")), + (nn.GELU, _make_torch_nn_importable("GELU")), + (nn.ReLU6, _make_torch_nn_importable("ReLU6")), + (nn.SiLU, _make_torch_nn_importable("SiLU")), + (nn.Sigmoid, _make_torch_nn_importable("Sigmoid")), + (nn.SELU, _make_torch_nn_importable("SELU")), + (nn.Hardtanh, _make_torch_nn_importable("Hardtanh")), + (nn.Tanh, _make_torch_nn_importable("Tanh")), +) + + +def is_torch_tensor(value): + """Returns true if `value` is a PyTorch Tensor.""" + return isinstance(value, torch.Tensor) + + +def convert_torch_tensor_to_cst(value, convert_child): + return cst.Call( + func=cst.Attribute(value=convert_child(torch), attr=cst.Name("tensor")), + args=[ + cst.Arg(convert_child(value.tolist())), + py_val_to_cst_converter.kwarg_to_cst("dtype", convert_child(value.dtype)), + ], + ) + + +def enable(): + """Registers PyTorch fiddle extensions. + + This allows for things like nicer handling of torch dtypes. 
+ """ + for value, importable in _torch_type_importables: + special_value_codegen.register_exact_value(value, importable) + + for value, importable in _nn_type_importables: + special_value_codegen.register_exact_value(value, importable) + + for module_str, import_stmt in _import_aliases: + import_manager.register_import_alias(module_str, import_stmt) + + py_val_to_cst_converter.register_py_val_to_cst_converter(is_torch_tensor)(convert_torch_tensor_to_cst) + + for dtype, _ in _torch_type_importables: + daglish_extensions.register_immutable(dtype) + lib, symbol = str(dtype).split(".") + serialization.register_constant(lib, symbol, compare_by_identity=True) + + for init in _torch_initializers: + daglish_extensions.register_immutable(init) + daglish_extensions.register_function_with_immutable_return_value(init) + + # Monkey-patch the Serialization class to handle things like activation-functions + def _modified_serialize(self, value, current_path, all_paths=None): + if isinstance(value, types.BuiltinFunctionType): + return self._pyref(value, current_path) + return self._original_serialize(value, current_path, all_paths) + + serialization.Serialization._original_serialize = serialization.Serialization._serialize + serialization.Serialization._serialize = _modified_serialize diff --git a/nemo/lightning/io/mixin.py b/nemo/lightning/io/mixin.py index 62b9a165c542..2e0867cbe39e 100644 --- a/nemo/lightning/io/mixin.py +++ b/nemo/lightning/io/mixin.py @@ -1,3 +1,4 @@ +import base64 import functools import inspect from dataclasses import is_dataclass @@ -5,13 +6,17 @@ from typing import Any, Callable, Dict, Optional, Type, TypeVar, Union import fiddle as fdl -from cloudpickle import dump +import fiddle._src.experimental.dataclasses as fdl_dc +from cloudpickle import dumps, loads +from fiddle._src.experimental import serialization from typing_extensions import Self from nemo.lightning.io.capture import IOProtocol from nemo.lightning.io.connector import ModelConnector +from nemo.lightning.io.fdl_torch import enable as _enable_ext ConnT = TypeVar('ConnT', bound=ModelConnector) +_enable_ext() class IOMixin: @@ -54,7 +59,7 @@ def __init__(self, param1, param2): """ - __io__ = fdl.Config[Self] + __io__: fdl.Config[Self] def __new__(cls, *args, **kwargs): """ @@ -82,6 +87,14 @@ def wrapped_init(self, *args, **kwargs): return output + def __init_subclass__(cls): + serialization.register_node_traverser( + cls, + flatten_fn=_io_flatten_object, + unflatten_fn=_io_unflatten_object, + path_elements_fn=_io_path_elements_fn, + ) + def io_transform_args(self, init_fn, *args, **kwargs) -> Dict[str, Any]: """ Transforms and captures the arguments passed to the `__init__` method, filtering out @@ -106,10 +119,11 @@ def io_transform_args(self, init_fn, *args, **kwargs) -> Dict[str, Any]: for key in config_kwargs: if isinstance(config_kwargs[key], IOProtocol): config_kwargs[key] = config_kwargs[key].__io__ - if is_dataclass(self): + if is_dataclass(config_kwargs[key]): + config_kwargs[key] = fdl_dc.convert_dataclasses_to_configs(config_kwargs[key], allow_post_init=True) # Check if the arg is a factory (dataclasses.field) - if config_kwargs[key].__class__.__name__ == "_HAS_DEFAULT_FACTORY_CLASS": - to_del.append(key) + if config_kwargs[key].__class__.__name__ == "_HAS_DEFAULT_FACTORY_CLASS": + to_del.append(key) for key in to_del: del config_kwargs[key] @@ -137,9 +151,10 @@ def io_dump(self, output: Path): Args: output (Path): The path to the file where the configuration object will be serialized. 
""" - config_path = Path(output) / "io.pkl" - with open(config_path, "wb") as f: - dump(self.__io__, f) + config_path = Path(output) / "io.json" + with open(config_path, "w") as f: + json = serialization.dump_json(self.__io__) + f.write(json) class ConnectorMixin: @@ -198,7 +213,7 @@ def register_importer(cls, ext: str, default_path: Optional[str] = None) -> Call """ def decorator(connector: Type[ConnT]) -> Type[ConnT]: - cls._IMPORTERS[ext] = connector + cls._IMPORTERS[str(cls) + ext] = connector if default_path: connector.default_path = default_path return connector @@ -221,7 +236,7 @@ def register_exporter(cls, ext: str, default_path: Optional[str] = None) -> Call """ def decorator(connector: Type[ConnT]) -> Type[ConnT]: - cls._EXPORTERS[ext] = connector + cls._EXPORTERS[str(cls) + ext] = connector if default_path: connector.default_path = default_path return connector @@ -310,7 +325,7 @@ def _get_connector(cls, ext, path=None, importer=True) -> ModelConnector: else: _path = path - connector = cls._IMPORTERS.get(ext) if importer else cls._EXPORTERS.get(ext) + connector = cls._IMPORTERS.get(str(cls) + ext) if importer else cls._EXPORTERS.get(str(cls) + ext) if not connector: raise ValueError(f"No connector found for extension '{ext}'") @@ -321,3 +336,32 @@ def _get_connector(cls, ext, path=None, importer=True) -> ModelConnector: return connector() return connector(_path) + + +def _io_flatten_object(instance): + try: + serialization.dump_json(instance.__io__) + except serialization.UnserializableValueError as e: + pickled_data = dumps(instance.__io__) + encoded_data = base64.b64encode(pickled_data).decode('utf-8') + return (encoded_data,), None + + return instance.__io__.__flatten__() + + +def _io_unflatten_object(values, metadata): + if len(values) == 1: + encoded_data = values[0] + pickled_data = base64.b64decode(encoded_data.encode('utf-8')) + return loads(pickled_data) + + return fdl.Config.__unflatten__(values, metadata) + + +def _io_path_elements_fn(x): + try: + serialization.dump_json(x.__io__) + except serialization.UnserializableValueError: + return (serialization.IdentityElement(),) + + return x.__io__.__path_elements__() diff --git a/nemo/lightning/io/pl.py b/nemo/lightning/io/pl.py index 35dfb077bb9e..cf81cc847444 100644 --- a/nemo/lightning/io/pl.py +++ b/nemo/lightning/io/pl.py @@ -1,22 +1,19 @@ import logging from dataclasses import dataclass, field from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Dict, Generic, Optional, Protocol, TypeVar, Union +from typing import Any, Callable, Dict, Generic, Optional, TypeVar, Union import pytorch_lightning as pl import torch from lightning_fabric.plugins.io.checkpoint_io import CheckpointIO from lightning_fabric.utilities.cloud_io import get_filesystem from lightning_fabric.utilities.types import _PATH -from megatron.core.dist_checkpointing.strategies import tensorstore from torch import nn from typing_extensions import Self, override from nemo.lightning.io.capture import IOProtocol from nemo.lightning.io.mixin import IOMixin -if TYPE_CHECKING: - from nemo.lightning.pytorch.strategies import MegatronStrategy log = logging.getLogger(__name__) @@ -26,39 +23,29 @@ @dataclass -class TrainerCheckpoint(IOMixin, Generic[LightningModuleT]): +class TrainerContext(IOMixin, Generic[LightningModuleT]): model: LightningModuleT trainer: pl.Trainer extra: Dict[str, Any] = field(default_factory=dict) @classmethod - def from_strategy(cls, strategy: "MegatronStrategy") -> Self: - if not isinstance(strategy.trainer, IOProtocol): 
+ def from_trainer(cls, trainer: pl.Trainer) -> Self: + if not hasattr(trainer, "__io__"): raise ValueError(f"Trainer must be an instance of {IOProtocol}. Please use the Trainer from nemo.") - - if not isinstance(strategy.lightning_module, IOProtocol): + if not hasattr(trainer.lightning_module, "__io__"): raise ValueError("LightningModule must extend IOMixin.") - return cls(trainer=strategy.trainer, model=strategy.lightning_module, extra=cls.construct_extra(strategy)) + return cls(trainer=trainer, model=trainer.lightning_module, extra=cls.construct_extra(trainer)) @classmethod - def construct_extra(cls, strategy: "MegatronStrategy") -> Dict[str, Any]: + def construct_extra(cls, trainer: pl.Trainer) -> Dict[str, Any]: extra = {} - if hasattr(strategy.trainer, "datamodule") and isinstance(strategy.trainer.datamodule, IOProtocol): - extra["datamodule"] = strategy.trainer.datamodule.__io__ - - # TODO: Add optimizer to extra + if hasattr(trainer, "datamodule") and hasattr(trainer.datamodule, "__io__"): + extra["datamodule"] = trainer.datamodule.__io__ return extra -class TrainerCkptProtocol(Protocol): - @classmethod - def from_strategy(cls, strategy: "MegatronStrategy") -> Self: ... - - def io_dump(self, output: Path): ... - - class MegatronCheckpointIO(CheckpointIO): """CheckpointIO that utilizes :func:`torch.save` and :func:`torch.load` to save and load checkpoints respectively, common for most use cases. diff --git a/nemo/lightning/io/state.py b/nemo/lightning/io/state.py index ed481cfcfe08..b69fed9d0f4f 100644 --- a/nemo/lightning/io/state.py +++ b/nemo/lightning/io/state.py @@ -217,15 +217,15 @@ def __call__(self, ctx: TransformCTX) -> TransformCTX: source_key_dict = source_key source_matches_dict = {k: _match_keys(list(source_dict.keys()), v) for k, v in source_key_dict.items()} target_matches = _match_keys(list(target_dict.keys()), target_key) - - for target_index, target_match in np.ndenumerate(target_matches): - kwargs = {} - for param in fn_params: - if param in source_matches_dict: - source_match = source_matches_dict[param][target_index[:-1]] - kwargs[param] = source_dict[source_match[target_index]] - - target_dict[target_match] = self.call_transform(ctx, **kwargs) + param_names = list(filter(lambda x: x in source_matches_dict, fn_params)) + for layer_names_group in zip(*([source_matches_dict[v] for v in param_names] + [target_matches])): + # Wrap in a list if it's a single layer (ie non-expert) + if isinstance(layer_names_group[0], str): + layer_names_group = [[x] for x in layer_names_group] + for layer_names in zip(*layer_names_group): + target_dict[layer_names[-1]] = self.call_transform( + ctx, **dict(zip(param_names, [source_dict[x] for x in layer_names[:-1]])) + ) else: source_keys = list(source_dict.keys()) target_keys = list(target_dict.keys()) diff --git a/nemo/lightning/megatron_parallel.py b/nemo/lightning/megatron_parallel.py index 44556a15c13a..4eab2fc4ea38 100644 --- a/nemo/lightning/megatron_parallel.py +++ b/nemo/lightning/megatron_parallel.py @@ -110,6 +110,7 @@ def __init__( vp_size: Optional[int] = None, ddp_config: Optional[DistributedDataParallelConfig] = None, cpu: bool = False, + convert_module_fn: Optional[Callable[[nn.Module], nn.Module]] = None, ) -> None: from apex.transformer.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes from megatron.core import parallel_state @@ -134,6 +135,10 @@ def __init__( _model.configure_model() _pipeline.append(_model) + if convert_module_fn: + for i in range(len(_pipeline)): + _pipeline[i] 
= convert_module_fn(_pipeline[i]) + if isinstance(ddp_config, DistributedDataParallelConfig): for model_chunk_idx, model_chunk in enumerate(_pipeline): module = model_chunk.module @@ -280,12 +285,6 @@ def forward( if loss_mean == []: loss_mean = None - ## TODO: is this where logging should go? - model = pipeline - if isinstance(pipeline, list): - model = pipeline[0] - pipeline.log('train_loss', loss_mean) - return loss_mean def wrapped_forward_step( diff --git a/nemo/lightning/nemo_logger.py b/nemo/lightning/nemo_logger.py index 2ad0753d04c5..093e4f2ed589 100644 --- a/nemo/lightning/nemo_logger.py +++ b/nemo/lightning/nemo_logger.py @@ -7,6 +7,7 @@ import lightning_fabric as fl import pytorch_lightning as pl +from fiddle._src.experimental import serialization from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint as PTLModelCheckpoint from nemo.lightning.pytorch.callbacks import ModelCheckpoint @@ -48,11 +49,7 @@ def __post_init__(self): f"Cannot set both log_local_rank_0_only and log_global_rank_0_only to True. Please set either one or neither." ) - def setup( - self, - trainer: Union[pl.Trainer, fl.Fabric], - resume_if_exists: bool = False, - ): + def setup(self, trainer: Union[pl.Trainer, fl.Fabric], resume_if_exists: bool = False, task_config=None): """Setup the logger for the experiment. Args: @@ -100,6 +97,7 @@ def setup( "No version folders would be created under the log folder as 'resume_if_exists' is enabled." ) version = None + trainer.logger._version = version or "" if version: if is_global_rank_zero(): os.environ[NEMO_ENV_VARNAME_VERSION] = version @@ -115,6 +113,12 @@ def setup( os.makedirs(log_dir, exist_ok=True) # Cannot limit creation to global zero as all ranks write to own log file logging.info(f'Experiments will be logged at {log_dir}') + if task_config and is_global_rank_zero(): + task_config.save_config_img(log_dir / "task.png") + task_json = serialization.dump_json(task_config) + with open(log_dir / "task.json", "w") as f: + f.write(task_json) + if isinstance(trainer, pl.Trainer): if self.ckpt: _overwrite_i = None @@ -160,6 +164,12 @@ def setup( # This is set if the env var NEMO_TESTING is set to True. 
nemo_testing = get_envbool(NEMO_ENV_VARNAME_TESTING, False) + files_to_move = [] + if Path(log_dir).exists(): + for child in Path(log_dir).iterdir(): + if child.is_file(): + files_to_move.append(child) + # Handle logging to file log_file = log_dir / f'nemo_log_globalrank-{global_rank}_localrank-{local_rank}.txt' if self.log_local_rank_0_only is True and not nemo_testing: @@ -174,6 +184,7 @@ def setup( add_handlers_to_mcore_logger() + app_state.files_to_move = files_to_move app_state.files_to_copy = self.files_to_copy app_state.cmd_args = sys.argv diff --git a/nemo/lightning/pytorch/callbacks/megatron_model_checkpoint.py b/nemo/lightning/pytorch/callbacks/megatron_model_checkpoint.py index fb10ad3a218b..63164513c901 100644 --- a/nemo/lightning/pytorch/callbacks/megatron_model_checkpoint.py +++ b/nemo/lightning/pytorch/callbacks/megatron_model_checkpoint.py @@ -26,6 +26,7 @@ from pytorch_lightning.callbacks.model_checkpoint import _is_local_file_protocol from pytorch_lightning.utilities import rank_zero_info +from nemo.lightning.io.pl import TrainerContext from nemo.utils import logging from nemo.utils.app_state import AppState from nemo.utils.model_utils import ckpt_to_dir @@ -48,10 +49,12 @@ def __init__( train_time_interval: Optional[timedelta] = None, save_best_model: bool = False, save_on_train_epoch_end: Optional[bool] = False, # Save after training, not after validation + enable_nemo_ckpt_io: bool = True, **kwargs, ): self.save_best_model = save_best_model self.previous_best_path = "" + self.enable_nemo_ckpt_io = enable_nemo_ckpt_io # Call the parent class constructor with the remaining kwargs. super().__init__( @@ -82,11 +85,7 @@ def on_train_start(self, trainer, pl_module): log_dir = app_state.log_dir # Check to see if any files exist that need to be moved - files_to_move = [] - if Path(log_dir).exists(): - for child in Path(log_dir).iterdir(): - if child.is_file(): - files_to_move.append(child) + files_to_move = app_state.files_to_move if len(files_to_move) > 0: # Move old files to a new folder @@ -106,8 +105,9 @@ def on_train_start(self, trainer, pl_module): shutil.copy(Path(_file), log_dir) # Create files for cmd args and git info - with open(log_dir / 'cmd-args.log', 'w', encoding='utf-8') as _file: - _file.write(" ".join(app_state.cmd_args)) + if app_state.cmd_args: + with open(log_dir / 'cmd-args.log', 'w', encoding='utf-8') as _file: + _file.write(" ".join(app_state.cmd_args)) # Try to get git hash git_repo, git_hash = get_git_hash() @@ -366,6 +366,7 @@ def _save_checkpoint(self, trainer: 'pytorch_lightning.Trainer', filepath: str) # if anything goes wrong during checkpointing, we should be able to detect that data is incomplete. 
self.set_checkpoint_unfinished_marker(filepath, barrier_after=True) ema_callback = self._ema_callback(trainer) + if ema_callback is not None: with ema_callback.save_original_optimizer_state(trainer): super()._save_checkpoint(trainer, filepath) @@ -394,6 +395,11 @@ def _cb(): self._last_global_step_saved = global_step self._last_checkpoint_saved = filepath + from nemo.utils.get_rank import is_global_rank_zero + + if self.enable_nemo_ckpt_io and is_global_rank_zero(): + TrainerContext.from_trainer(trainer).io_dump(ckpt_to_dir(filepath)) + # notify loggers if trainer.is_global_zero: for logger in trainer.loggers: diff --git a/nemo/lightning/pytorch/opt/__init__.py b/nemo/lightning/pytorch/optim/__init__.py similarity index 81% rename from nemo/lightning/pytorch/opt/__init__.py rename to nemo/lightning/pytorch/optim/__init__.py index ded886bf1e6c..d23494a96a5f 100644 --- a/nemo/lightning/pytorch/opt/__init__.py +++ b/nemo/lightning/pytorch/optim/__init__.py @@ -1,5 +1,5 @@ -from nemo.lightning.pytorch.opt.base import LRSchedulerModule, OptimizerModule -from nemo.lightning.pytorch.opt.lr_scheduler import ( +from nemo.lightning.pytorch.optim.base import LRSchedulerModule, OptimizerModule +from nemo.lightning.pytorch.optim.lr_scheduler import ( CosineAnnealingScheduler, InverseSquareRootAnnealingScheduler, NoamAnnealingScheduler, @@ -13,7 +13,7 @@ WarmupHoldPolicyScheduler, WarmupPolicyScheduler, ) -from nemo.lightning.pytorch.opt.megatron import MegatronOptimizerModule +from nemo.lightning.pytorch.optim.megatron import MegatronOptimizerModule __all__ = [ "OptimizerModule", diff --git a/nemo/lightning/pytorch/opt/base.py b/nemo/lightning/pytorch/optim/base.py similarity index 93% rename from nemo/lightning/pytorch/opt/base.py rename to nemo/lightning/pytorch/optim/base.py index fda3b9defb9e..0d8c1f2dcaf9 100644 --- a/nemo/lightning/pytorch/opt/base.py +++ b/nemo/lightning/pytorch/optim/base.py @@ -129,6 +129,11 @@ def custom_configure_optimizers(lightning_module_self, megatron_parallel=None): return opt model.configure_optimizers = types.MethodType(custom_configure_optimizers, model) + model.optim = self + + if hasattr(self, "__io__") and hasattr(model, "__io__"): + if hasattr(model.__io__, "optim"): + model.__io__.optim = self.__io__ @abstractmethod def optimizers(self, model) -> List[Optimizer]: @@ -142,6 +147,11 @@ def optimizers(self, model) -> List[Optimizer]: """ raise NotImplementedError("The optimizers method should be implemented by subclasses.") + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx) -> None: + if self._optimizers is not None: + lr = self._optimizers[0].param_groups[0]['lr'] + pl_module.log('lr', lr, rank_zero_only=True, batch_size=1) + def __call__(self, model: L.LightningModule, megatron_parallel=None) -> OptimizerLRScheduler: """Calls the setup and optimizers methods. 
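The `on_train_batch_end` hook added to `OptimizerModule` above reads the current learning rate from the first param group of the first optimizer and logs it once per step. A tiny stand-alone check of that lookup; `DummyOptim` and `current_lr` are illustrative only, not part of this patch:

class DummyOptim:
    # Minimal stand-in exposing the torch.optim.Optimizer attribute used by the hook.
    def __init__(self, lr):
        self.param_groups = [{"lr": lr}]

def current_lr(optimizers):
    # Mirrors the hook: lr of the first param group of the first optimizer.
    return optimizers[0].param_groups[0]["lr"] if optimizers else None

assert current_lr([DummyOptim(3e-4)]) == 3e-4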
diff --git a/nemo/lightning/pytorch/opt/lr_scheduler.py b/nemo/lightning/pytorch/optim/lr_scheduler.py similarity index 99% rename from nemo/lightning/pytorch/opt/lr_scheduler.py rename to nemo/lightning/pytorch/optim/lr_scheduler.py index 689eb2faa839..1c602d8111de 100644 --- a/nemo/lightning/pytorch/opt/lr_scheduler.py +++ b/nemo/lightning/pytorch/optim/lr_scheduler.py @@ -13,7 +13,7 @@ WarmupHoldPolicy, WarmupPolicy, ) -from nemo.lightning.pytorch.opt.base import LRSchedulerModule +from nemo.lightning.pytorch.optim.base import LRSchedulerModule class WarmupPolicyScheduler(LRSchedulerModule): diff --git a/nemo/lightning/pytorch/opt/megatron.py b/nemo/lightning/pytorch/optim/megatron.py similarity index 86% rename from nemo/lightning/pytorch/opt/megatron.py rename to nemo/lightning/pytorch/optim/megatron.py index 697e2010d1b4..814f58f2c195 100644 --- a/nemo/lightning/pytorch/opt/megatron.py +++ b/nemo/lightning/pytorch/optim/megatron.py @@ -7,7 +7,7 @@ from torch.optim import Optimizer from nemo.lightning.megatron_parallel import MegatronParallel -from nemo.lightning.pytorch.opt.base import LRSchedulerModule, OptimizerModule +from nemo.lightning.pytorch.optim.base import LRSchedulerModule, OptimizerModule class MegatronOptimizerModule(OptimizerModule): @@ -84,6 +84,16 @@ def optimizers(self, model: MegatronParallel) -> List[Optimizer]: from nemo.core.optim import McoreDistributedOptimizer + class McoreOpt(McoreDistributedOptimizer): + def sharded_state_dict( + self, + model_sharded_state_dict, + optimizer_state_dict=None, + is_loading=False, + dist_ckpt_parallel_save=False, + ): + return self.mcore_optimizer.sharded_state_dict(model_sharded_state_dict, is_loading=is_loading) + mcore_opt = get_megatron_optimizer( self.config, list(model), @@ -92,7 +102,7 @@ def optimizers(self, model: MegatronParallel) -> List[Optimizer]: lr_mult=self.lr_mult, ) - return [McoreDistributedOptimizer(mcore_opt)] + return [McoreOpt(mcore_opt)] def finalize_model_grads(self, *args, **kwargs): return finalize_model_grads(*args, **kwargs) diff --git a/nemo/lightning/pytorch/plugins/data_sampler.py b/nemo/lightning/pytorch/plugins/data_sampler.py index 470b7f3984f2..c6ff3b7ccaaa 100644 --- a/nemo/lightning/pytorch/plugins/data_sampler.py +++ b/nemo/lightning/pytorch/plugins/data_sampler.py @@ -94,6 +94,14 @@ def on_megatron_step_end(self, trainer: pl.Trainer, pl_module: pl.LightningModul # TODO: Add consumed samples consumed_samples = self.compute_consumed_samples(trainer.global_step + 1 - self.init_global_step) + pl_module.log( + 'consumed_samples', + consumed_samples, + prog_bar=True, + rank_zero_only=True, + batch_size=1, + ) + self.prev_consumed_samples = consumed_samples num_microbatch_calculator = ( diff --git a/nemo/lightning/pytorch/plugins/mixed_precision.py b/nemo/lightning/pytorch/plugins/mixed_precision.py index 6c3d556816d2..923bd625da62 100644 --- a/nemo/lightning/pytorch/plugins/mixed_precision.py +++ b/nemo/lightning/pytorch/plugins/mixed_precision.py @@ -13,6 +13,7 @@ # limitations under the License. 
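The data-sampler hunk above logs `consumed_samples` on every step; the actual computation lives in `compute_consumed_samples`, which is not shown in this diff. A common formulation, assuming a fixed global batch size, looks like this sketch:

def consumed_samples_sketch(init_consumed: int, steps_completed: int, global_batch_size: int) -> int:
    # Samples seen so far = samples at resume time + steps since resume * global batch size.
    return init_consumed + steps_completed * global_batch_size

# Resuming with 1_000 samples already consumed, after 10 further steps at GBS=128:
assert consumed_samples_sketch(1_000, 10, 128) == 2_280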
from contextlib import contextmanager +from types import SimpleNamespace from typing import Any, Callable, Generator, List, Literal, Tuple, TypeVar, Union import pytorch_lightning as pl @@ -57,7 +58,7 @@ def float16_convertor(val): raise ValueError("precision must be '16-mixed' or 'bf16-mixed'") self.dtype = dtype - torch.set_autocast_gpu_dtype(dtype) + # torch.set_autocast_gpu_dtype(dtype) self.float16_convertor = float16_convertor self.amp_O2 = amp_O2 @@ -81,10 +82,15 @@ def convert_module(self, module: Module) -> Module: This is optional and depends on the precision limitations during optimization. """ - if self.precision == "bf16-mixed": - return module.bfloat16() - if self.precision == "16-mixed": - return module.half() + from megatron.core.distributed import DistributedDataParallel + from megatron.core.transformer.module import Float16Module + from megatron.core.utils import get_model_config + + if self.precision in ["16-mixed", "bf16-mixed"]: + config = get_model_config(module.module) + config.fp16 = self.precision == "16-mixed" + config.bf16 = self.precision == "bf16-mixed" + module.module = Float16Module(config, module.module) return module @@ -112,6 +118,8 @@ def convert_input(self, data: AnyT) -> AnyT: parallel_state.is_pipeline_first_stage() """ + return data + from megatron.core.transformer.module import fp32_to_float16 return fp32_to_float16(data, self.float16_convertor) @@ -123,6 +131,8 @@ def convert_output(self, data: AnyT) -> AnyT: parallel_state.is_pipeline_last_stage() """ + return data + from megatron.core.transformer.module import float16_to_fp32 return float16_to_fp32(data) diff --git a/nemo/lightning/pytorch/strategies.py b/nemo/lightning/pytorch/strategies.py index 833a1be3905a..9bffbf374183 100644 --- a/nemo/lightning/pytorch/strategies.py +++ b/nemo/lightning/pytorch/strategies.py @@ -1,6 +1,7 @@ import functools import inspect import logging +import os import shutil from collections import OrderedDict from contextlib import ExitStack @@ -13,6 +14,7 @@ from lightning_fabric.plugins import CheckpointIO, ClusterEnvironment from lightning_fabric.utilities.optimizer import _optimizers_to_device from megatron.core.distributed import DistributedDataParallelConfig +from megatron.core.optimizer import OptimizerConfig from pytorch_lightning.accelerators import CPUAccelerator from pytorch_lightning.callbacks.progress import TQDMProgressBar from pytorch_lightning.loops import _AutomaticOptimization, evaluation_loop, fit_loop, prediction_loop @@ -30,7 +32,7 @@ from typing_extensions import override from nemo.lightning import _strategy_lib, io -from nemo.lightning.io.pl import MegatronCheckpointIO, TrainerCheckpoint, TrainerCkptProtocol +from nemo.lightning.io.pl import MegatronCheckpointIO from nemo.lightning.megatron_parallel import CallbackConnector, MegatronParallel, _ModuleStepFunction from nemo.lightning.pytorch.callbacks import MegatronProgressBar @@ -46,27 +48,58 @@ class MegatronStrategy(DDPStrategy, io.IOMixin): """Megatron plugin for Pytorch Lightning. + This strategy implements model parallelism using NVIDIA's Megatron-LM framework. It supports + various forms of parallelism including tensor model parallelism, pipeline model parallelism, + sequence parallelism, and expert parallelism for efficient training of large language models. + Args: - no_ddp_communication_hook: Disable DDP communication hook when using AMP-O2 - with FP32 gradient accumulation. + tensor_model_parallel_size (int): Intra-layer model parallelism. Splits tensors across GPU ranks. 
+            Defaults to 1.
+        pipeline_model_parallel_size (int): Inter-layer model parallelism. Splits transformer layers
+            across GPU ranks. Defaults to 1.
+        virtual_pipeline_model_parallel_size (Optional[int]): Interleaved pipeline parallelism used to
+            improve performance by reducing the pipeline bubble. Defaults to None.
+        context_parallel_size (int): Splits network input along sequence dimension across GPU ranks.
+            Defaults to 1.
+        sequence_parallel (bool): Makes tensor parallelism more memory efficient for LLMs (20B+) by
+            parallelizing layer norms and dropout sequentially. Defaults to False.
+        expert_model_parallel_size (int): Distributes MoE Experts across sub data parallel dimension.
+            Defaults to 1.
+        moe_extended_tp (bool): Alternative parallelization strategy for expert parallelism. Defaults to False.
+        data_sampler (Optional['DataSampler']): Custom data sampler for distributed training. Defaults to None.
+        parallel_devices (Optional[List[torch.device]]): List of devices to use for parallelism. Defaults to None.
+        cluster_environment: Cluster environment for distributed training. Defaults to None.
+        checkpoint_io: Checkpoint I/O handler. Defaults to None.
+        find_unused_parameters (bool): Find unused parameters in DDP. Defaults to False.
+        ckpt_include_optimizer (bool): Include optimizer state in checkpoint. Defaults to False.
+        ddp (Union[DDPLiteral, DistributedDataParallelConfig]): DDP configuration. Defaults to "megatron".
+        lazy_init (bool): Use lazy initialization for model parallel parameters. Defaults to False.
+        pipeline_dtype (Optional[torch.dtype]): Data type for pipeline parallelism. Defaults to None.
+        **kwargs: Additional keyword arguments.
+
+    Note:
+        This strategy is designed to work with NVIDIA's Megatron-LM framework and requires
+        specific model implementations that are compatible with Megatron's parallelism techniques.
""" trainer: pl.Trainer - ## TODO: support context parallel def __init__( self, tensor_model_parallel_size: int = 1, pipeline_model_parallel_size: int = 1, virtual_pipeline_model_parallel_size: Optional[int] = None, + context_parallel_size: int = 1, sequence_parallel: bool = False, + expert_model_parallel_size: int = 1, + moe_extended_tp: bool = False, data_sampler: Optional['DataSampler'] = None, parallel_devices: Optional[List[torch.device]] = None, cluster_environment=None, # TODO: Add type-hint checkpoint_io=None, # TODO: Add type-hint find_unused_parameters: bool = False, - enable_nemo_ckpt_io: bool = True, - ckpt_type: TrainerCkptProtocol = TrainerCheckpoint, ckpt_include_optimizer: bool = False, ddp: Union[DDPLiteral, DistributedDataParallelConfig] = "megatron", lazy_init: bool = False, @@ -85,16 +118,19 @@ def __init__( self.data_sampler: Optional['DataSampler'] = data_sampler self.tensor_model_parallel_size = tensor_model_parallel_size self.pipeline_model_parallel_size = pipeline_model_parallel_size + self.context_parallel_size = context_parallel_size + self.expert_model_parallel_size = expert_model_parallel_size + self.moe_extended_tp = moe_extended_tp self.virtual_pipeline_model_parallel_size = virtual_pipeline_model_parallel_size self.sequence_parallel = sequence_parallel - self.enable_nemo_ckpt_io = enable_nemo_ckpt_io - self.ckpt_type = ckpt_type self.lazy_init = lazy_init self.ckpt_include_optimizer = ckpt_include_optimizer self.pipeline_dtype = pipeline_dtype + self.log_train_loss = bool(int(os.getenv("NEMO_LOG_TRAIN_LOSS", 1))) + self.log_memory_usage = bool(int(os.getenv("NEMO_LOG_MEMORY_USAGE", 0))) if ddp == "megatron": - self.ddp_config = DistributedDataParallelConfig(use_distributed_optimizer=True) + self.ddp_config = DistributedDataParallelConfig() elif isinstance(ddp, DistributedDataParallelConfig): self.ddp_config = ddp elif ddp == "pytorch": @@ -122,9 +158,27 @@ def connect(self, model: pl.LightningModule) -> None: config.tensor_model_parallel_size = self.tensor_model_parallel_size config.pipeline_model_parallel_size = self.pipeline_model_parallel_size config.virtual_pipeline_model_parallel_size = self.virtual_pipeline_model_parallel_size + config.context_parallel_size = self.context_parallel_size + config.expert_model_parallel_size = self.expert_model_parallel_size + config.moe_extended_tp = self.moe_extended_tp config.sequence_parallel = self.sequence_parallel self._mcore_config = config + has_optim = getattr(model, "optim", None) + if has_optim: + opt_config = getattr(model.optim, "config", None) + if isinstance(opt_config, OptimizerConfig): + mcore_opt_config: OptimizerConfig = cast(OptimizerConfig, opt_config) + if not self.ddp_config: + raise ValueError("PyTorch DDP is not enabled for mcore optimizer") + ddp_config = cast(DistributedDataParallelConfig, self.ddp_config) + + if mcore_opt_config.use_distributed_optimizer != ddp_config.use_distributed_optimizer: + from nemo.utils import logging + + logging.info("Fixing mis-match between ddp-config & mcore-optimizer config") + ddp_config.use_distributed_optimizer = mcore_opt_config.use_distributed_optimizer + @override def setup(self, trainer: pl.Trainer, setup_optimizers: bool = True) -> None: assert self.accelerator is not None @@ -208,12 +262,17 @@ def process_dataloader(self, dataloader: DataLoader) -> DataLoader: def setup_megatron_parallel(self, trainer: pl.Trainer, setup_optimizers: bool = True) -> None: assert self.model is not None, "Model is not set" + convert_module_fn = None + if 
hasattr(self.precision_plugin, "convert_module"): + convert_module_fn = self.precision_plugin.convert_module + self.megatron_parallel = MegatronParallel( self.model, precision_plugin=self.precision_plugin, vp_size=self.virtual_pipeline_model_parallel_size, cpu=isinstance(trainer.accelerator, CPUAccelerator), ddp_config=self.ddp_config, + convert_module_fn=convert_module_fn, ) self.megatron_parallel.trainer = trainer @@ -227,18 +286,16 @@ def setup_megatron_parallel(self, trainer: pl.Trainer, setup_optimizers: bool = if setup_optimizers: self.setup_optimizers(trainer) - # TODO: Throw an execption if we have a mcore optimizer and no ddp_config - if hasattr(self.precision_plugin, "convert_optimizer"): - _optimizers = [*self.optimizers] - _optimizers[0] = self.precision_plugin.convert_optimizer(self.optimizers[0]) - self.optimizers = _optimizers + # TODO: Throw an execption if we have a mcore optimizer and no ddp_config - _optimizers_to_device(self.optimizers, self.root_device) + if hasattr(self.precision_plugin, "convert_optimizer"): + _optimizers = [*self.optimizers] + _optimizers[0] = self.precision_plugin.convert_optimizer(self.optimizers[0]) + self.optimizers = _optimizers - self.model = self.megatron_parallel + _optimizers_to_device(self.optimizers, self.root_device) - if hasattr(self.precision_plugin, "convert_module"): - self.model = self.precision_plugin.convert_module(self.model) + self.model = self.megatron_parallel self.model.callbacks.add(getattr(trainer, "callbacks")) if self.data_sampler: @@ -299,7 +356,50 @@ def training_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTP for opt in self.optimizers: opt.zero_grad() - return self.model(dataloader_iter, forward_only=False, *args, **kwargs) + out = self.model(dataloader_iter, forward_only=False, *args, **kwargs) + + self.lightning_module.log( + 'global_step', + self.trainer.global_step, + prog_bar=True, + rank_zero_only=True, + batch_size=1, + ) + + if self.log_memory_usage: + max_memory_reserved = torch.cuda.max_memory_reserved() + memory_allocated = torch.cuda.memory_allocated() + self.lightning_module.log( + "peak_memory_usage", + max_memory_reserved, + prog_bar=True, + rank_zero_only=True, + batch_size=1, + ) + self.lightning_module.log( + "memory_allocated", + memory_allocated, + prog_bar=True, + rank_zero_only=True, + batch_size=1, + ) + + if self.log_train_loss: + from megatron.core import parallel_state + + from nemo.collections.nlp.parts.utils_funcs import get_last_rank + + # When using pipeline parallelism, loss is calculated only in the last pipeline stage and + # it should be casted to other pipeline stages for logging. 
+ # we can avoid this broadcast by updating the PTL log function to accept specific ranks + if parallel_state.get_pipeline_model_parallel_world_size() > 1: + if torch.distributed.get_rank() == get_last_rank(): + torch.distributed.send(out, 0) + elif torch.distributed.get_rank() == 0: + torch.distributed.recv(out, get_last_rank()) + self.lightning_module.log('reduced_train_loss', out, prog_bar=True, rank_zero_only=True, batch_size=1) + + return out @override def validation_step(self, dataloader_iter, *args: Any, **kwargs: Any) -> STEP_OUTPUT: @@ -389,12 +489,10 @@ def save_checkpoint( ) -> None: checkpoint["state_dict"] = OrderedDict([]) # remove device state_dict checkpoint["sharded_state_dict"] = self.megatron_parallel.sharded_state_dict() - if self.trainer.state.fn == TrainerFn.FITTING: + if self.trainer.state.fn == TrainerFn.FITTING and self.ckpt_include_optimizer: checkpoint["optimizer"] = [self.optimizer_sharded_state_dict()] self.checkpoint_io.save_checkpoint(checkpoint, filepath, storage_options=storage_options) - if self.enable_nemo_ckpt_io and self.is_global_zero and self.ckpt_type: - self.ckpt_type.from_strategy(self).io_dump(ckpt_to_dir(filepath)) @override def load_checkpoint(self, checkpoint_path: Union[str, Path]) -> Dict[str, Any]: @@ -430,16 +528,36 @@ def load_model_state_dict(self, checkpoint: Mapping[str, Any], strict: bool = Tr checkpoint_state_dict = checkpoint['state_dict'][f'model_{index}'] else: checkpoint_state_dict = checkpoint['state_dict'] - # checkpoint_state_dict has "model." but module does not so we need to remove it when loading - checkpoint_state_dict = { - key.replace('model.', ''): checkpoint_state_dict.pop(key) for key in list(checkpoint_state_dict.keys()) - } + + mcore_model = self.lightning_module.module + current = self.model[0] + n_nesting = 2 + while current != mcore_model: + current = current.module + n_nesting += 1 + + _state_dict = {} + for key, value in checkpoint_state_dict.items(): + # Count the number of "module." at the start of the key + count, _key = 0, key + while _key.startswith("module."): + _key = _key[len("module.") :] + count += 1 + + # Adjust the number of "module." prefixes + if count < n_nesting: + to_add = "module." * (n_nesting - count) + _state_dict[f"{to_add}{key}"] = value + elif count > n_nesting: + to_remove = "module." * (count - n_nesting) + _state_dict[key[len(to_remove) :]] = value + checkpoint_state_dict = _state_dict + module.load_state_dict(checkpoint_state_dict, strict=strict) @property @override def checkpoint_io(self) -> CheckpointIO: - if self._checkpoint_io is None: self._checkpoint_io = MegatronCheckpointIO() elif isinstance(self._checkpoint_io, _WrappingCheckpointIO): diff --git a/nemo/utils/app_state.py b/nemo/utils/app_state.py index 4d1d7387ba90..7a60c3969df3 100644 --- a/nemo/utils/app_state.py +++ b/nemo/utils/app_state.py @@ -81,8 +81,10 @@ def __init__(self): self._model_guid_map = {} # type: Dict[str, ModelMetadataRegistry] self._restore = False # TODO: are this and _is_model_being_restored both needed? 
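The `load_model_state_dict` rewrite above re-nests checkpoint keys so that the number of leading "module." prefixes matches the wrapping depth of the live model. The core of that adjustment, extracted here as a sketch for readability (function name is illustrative):

def renest_key(key: str, n_nesting: int) -> str:
    # Count existing "module." prefixes, then pad or strip to reach n_nesting.
    count, stripped = 0, key
    while stripped.startswith("module."):
        stripped = stripped[len("module."):]
        count += 1
    if count < n_nesting:
        return "module." * (n_nesting - count) + key
    if count > n_nesting:
        return key[len("module.") * (count - n_nesting):]
    return key

assert renest_key("decoder.weight", 2) == "module.module.decoder.weight"
assert renest_key("module.module.decoder.weight", 1) == "module.decoder.weight"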
+ # files from a previous run to move into a new directory + self.files_to_move = [] # files to copy into log dir - self._files_to_copy = None + self._files_to_copy = [] # command-ling arguments for run self._cmd_args = None @@ -560,6 +562,20 @@ def checkpoint_callback_params(self, params): """ self._checkpoint_callback_params = params + @property + def files_to_move(self): + """Returns the list of files to move into a separate directory.""" + return self._files_to_move + + @files_to_move.setter + def files_to_move(self, files): + """Sets the files_to_move property. + + Args: + files (list[str]): list of filenames to move. + """ + self._files_to_move = files + @property def files_to_copy(self): """Returns the list of files to copy into the log dir.""" diff --git a/requirements/requirements_vllm.txt b/requirements/requirements_vllm.txt new file mode 100644 index 000000000000..a603b3c4ec53 --- /dev/null +++ b/requirements/requirements_vllm.txt @@ -0,0 +1 @@ +vllm==0.5.0 diff --git a/scripts/checkpoint_converters/convert_gemma_hf_to_nemo.py b/scripts/checkpoint_converters/convert_gemma_hf_to_nemo.py index de12aefd1844..9ce51e544661 100644 --- a/scripts/checkpoint_converters/convert_gemma_hf_to_nemo.py +++ b/scripts/checkpoint_converters/convert_gemma_hf_to_nemo.py @@ -127,8 +127,8 @@ def adjust_tensor_shapes(model, nemo_state_dict): model_config = model.cfg num_query_groups = model_config["num_query_groups"] head_num = model_config["num_attention_heads"] - head_size = model_config["kv_channels"] hidden_size = model_config["hidden_size"] + head_size = model_config["kv_channels"] heads_per_group = head_num // num_query_groups # Note: For 'key' and 'value' weight and biases, NeMo uses a consolidated tensor 'query_key_value'. diff --git a/scripts/checkpoint_converters/convert_gemma_pyt_to_nemo.py b/scripts/checkpoint_converters/convert_gemma_pyt_to_nemo.py index d14e5f7de551..3cf3ed021527 100644 --- a/scripts/checkpoint_converters/convert_gemma_pyt_to_nemo.py +++ b/scripts/checkpoint_converters/convert_gemma_pyt_to_nemo.py @@ -133,8 +133,8 @@ def adjust_tensor_shapes(model, nemo_state_dict): model_config = model.cfg num_query_groups = model_config["num_query_groups"] head_num = model_config["num_attention_heads"] - head_size = model_config["kv_channels"] hidden_size = model_config["hidden_size"] + head_size = model_config["kv_channels"] heads_per_group = head_num // num_query_groups # Note: For 'key' and 'value' weight and biases, NeMo uses a consolidated tensor 'query_key_value'. diff --git a/scripts/checkpoint_converters/convert_llava_hf_to_nemo.py b/scripts/checkpoint_converters/convert_llava_hf_to_nemo.py new file mode 100644 index 000000000000..d91899348e8c --- /dev/null +++ b/scripts/checkpoint_converters/convert_llava_hf_to_nemo.py @@ -0,0 +1,331 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" + python3 /opt/NeMo/scripts/checkpoint_converters/convert_llava_hf_to_nemo.py \ + --input_name_or_path llava-hf/llava-1.5-7b-hf \ + --output_path /path/to/llava-7b.nemo \ + --tokenizer_path /path/to/tokenizer.model +""" + +import os +from argparse import ArgumentParser + +import torch +from omegaconf import OmegaConf +from transformers import LlamaTokenizer, LlavaForConditionalGeneration + +from nemo.collections.multimodal.models.multimodal_llm.neva.neva_model import MegatronNevaModel +from nemo.collections.nlp.modules.common.megatron.utils import get_ltor_masks_and_position_ids +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder +from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision +from nemo.utils import logging + + +def create_rename_keys(num_hidden_layers): + rename_keys = [] + for i in range(num_hidden_layers): + # Attention layers + rename_keys.extend( + [ + ( + f"language_model.model.layers.{i}.self_attn.o_proj.weight", + f"model.decoder.layers.{i}.self_attention.linear_proj.weight", + ), + ( + f"language_model.model.layers.{i}.self_attn.q_proj.weight", + f"model.decoder.layers.{i}.self_attention.linear_q.weight", + ), + ( + f"language_model.model.layers.{i}.self_attn.k_proj.weight", + f"model.decoder.layers.{i}.self_attention.linear_k.weight", + ), + ( + f"language_model.model.layers.{i}.self_attn.v_proj.weight", + f"model.decoder.layers.{i}.self_attention.linear_v.weight", + ), + # MLP and LayerNorm + ( + f"language_model.model.layers.{i}.mlp.gate_proj.weight", + f"model.decoder.layers.{i}.mlp.linear_fc1_gate.weight", + ), + ( + f"language_model.model.layers.{i}.mlp.up_proj.weight", + f"model.decoder.layers.{i}.mlp.linear_fc1_proj.weight", + ), + ( + f"language_model.model.layers.{i}.mlp.down_proj.weight", + f"model.decoder.layers.{i}.mlp.linear_fc2.weight", + ), + ( + f"language_model.model.layers.{i}.input_layernorm.weight", + f"model.decoder.layers.{i}.self_attention.linear_qkv.layer_norm_weight", + ), + ( + f"language_model.model.layers.{i}.post_attention_layernorm.weight", + f"model.decoder.layers.{i}.mlp.linear_fc1.layer_norm_weight", + ), + ] + ) + + rename_keys.extend( + [ + ( + "multi_modal_projector.linear_1.weight", + "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector.0.weight", + ), + ( + "multi_modal_projector.linear_1.bias", + "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector.0.bias", + ), + ( + "multi_modal_projector.linear_2.weight", + "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector.2.weight", + ), + ( + "multi_modal_projector.linear_2.bias", + "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector.2.bias", + ), + ("language_model.model.embed_tokens.weight", "model.embedding.word_embeddings.weight"), + ("language_model.model.norm.weight", "model.decoder.final_layernorm.weight"), + ("language_model.lm_head.weight", "model.output_layer.weight"), + ] + ) + + return rename_keys + + +def rename_model_keys(model_state_dict, rename_keys): + """ + Rename keys in the model's state dictionary based on the provided mappings. + + Parameters: + model_state_dict (dict): The state dictionary of the model. + rename_keys (list): A list of tuples with the mapping (old_key, new_key). + + Returns: + dict: A new state dictionary with updated key names. 
+ """ + + # Create a new state dictionary with updated key names + new_state_dict = {} + + # Track keys from the original state dict to ensure all are processed + remaining_keys = set(model_state_dict.keys()) + + # Iterate over the rename mappings + for old_key, new_key in rename_keys: + if old_key in model_state_dict: + # Rename the key and remove it from the tracking set + new_state_dict[new_key] = model_state_dict[old_key] + remaining_keys.remove(old_key) + + # Check if any keys were not converted from old to new + for old_key in remaining_keys: + print(f"Warning: Key '{old_key}' was not converted.") + + return new_state_dict + + +def adjust_tensor_shapes(model, nemo_state_dict): + """ + Adapt tensor shapes in the state dictionary to ensure compatibility with a different model structure. + + Parameters: + nemo_state_dict (dict): The state dictionary of the model. + + Returns: + dict: The updated state dictionary with modified tensor shapes for compatibility. + """ + model_config = model.cfg + num_query_groups = model_config["num_query_groups"] + head_num = model_config["num_attention_heads"] + hidden_size = model_config["hidden_size"] + head_size = model_config["kv_channels"] + heads_per_group = head_num // num_query_groups + + # Note: For 'key' and 'value' weight and biases, NeMo uses a consolidated tensor 'query_key_value'. + for key_ in list(nemo_state_dict.keys()): + if 'vision_towel' in key_: + del nemo_state_dict[key_] + + if 'word_embeddings.weight' in key_ or 'output_layer.weight' in key_: + # padding + loaded_weight = nemo_state_dict[key_] + new_weight = model.state_dict()[key_] + new_weight[: loaded_weight.shape[0], : loaded_weight.shape[1]] = loaded_weight + nemo_state_dict[key_] = new_weight + + if 'mlp.linear_fc1_gate.weight' in key_: + key_gate = key_ + key_proj = key_.replace('mlp.linear_fc1_gate.weight', 'mlp.linear_fc1_proj.weight') + new_key = key_.replace('mlp.linear_fc1_gate.weight', 'mlp.linear_fc1.weight') + gate_weight = nemo_state_dict[key_gate] + proj_weight = nemo_state_dict[key_proj] + nemo_state_dict[new_key] = torch.cat((gate_weight, proj_weight)) + del nemo_state_dict[key_gate], nemo_state_dict[key_proj] + + if 'self_attention.linear_q.weight' in key_: + key_q = key_ + key_k = key_.replace('linear_q', 'linear_k') + key_v = key_.replace('linear_q', 'linear_v') + key_qkv = key_.replace('linear_q', 'linear_qkv') + + # [(head_num + 2 * num_query_groups) * head_size, hidden_size] + # -> [head_num, head_size, hidden_size], 2 * [num_query_groups, head_size, hidden_size] + q_weight, k_weight, v_weight = nemo_state_dict[key_q], nemo_state_dict[key_k], nemo_state_dict[key_v] + q_weight = q_weight.reshape(head_num, head_size, hidden_size) + k_weight = k_weight.reshape(num_query_groups, head_size, hidden_size) + v_weight = v_weight.reshape(num_query_groups, head_size, hidden_size) + + qkv_weight = torch.empty((0, head_size, hidden_size), device=q_weight.device) + for i in range(num_query_groups): + qkv_weight = torch.cat((qkv_weight, q_weight[i * heads_per_group : (i + 1) * heads_per_group, :, :])) + qkv_weight = torch.cat((qkv_weight, k_weight[i : i + 1, :, :])) + qkv_weight = torch.cat((qkv_weight, v_weight[i : i + 1, :, :])) + qkv_weight = qkv_weight.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size]) + nemo_state_dict[key_qkv] = qkv_weight + del nemo_state_dict[key_q], nemo_state_dict[key_k], nemo_state_dict[key_v] + + return nemo_state_dict + + +def adjust_nemo_config(model_config, ref_config): + model_config.mm_cfg.mm_mlp_adapter_type = 
"mlp2x_gelu" + if ref_config["vision_config"].image_size == 336: + model_config.mm_cfg.vision_encoder.from_pretrained = "openai/clip-vit-large-patch14-336" + model_config.data.image_token_len = 576 + else: + model_config.mm_cfg.vision_encoder.from_pretrained = "openai/clip-vit-large-patch14" + model_config.data.image_token_len = 256 + + ref_config = ref_config['text_config'].__dict__ + model_config["encoder_seq_length"] = ref_config["max_position_embeddings"] + model_config["num_layers"] = ref_config["num_hidden_layers"] + model_config["ffn_hidden_size"] = ref_config["intermediate_size"] + model_config["hidden_size"] = ref_config["hidden_size"] + model_config["num_attention_heads"] = ref_config["num_attention_heads"] + model_config["num_query_groups"] = ref_config["num_key_value_heads"] + model_config["layernorm_epsilon"] = ref_config["rms_norm_eps"] + model_config["init_method_std"] = ref_config["initializer_range"] + model_config["kv_channels"] = ref_config.get( + "head_dim", model_config["hidden_size"] // model_config["num_attention_heads"] + ) + if ref_config.get("rope_scaling") is not None: + if ref_config["rope_scaling"]["type"] == "linear": + model_config["seq_len_interpolation_factor"] = ref_config["rope_scaling"]["factor"] + else: + raise ValueError("Only linear rope scaling type is supported now") + model_config["use_cpu_initialization"] = True + + return model_config + + +def get_args(): + parser = ArgumentParser() + parser.add_argument("--input_name_or_path", type=str) + parser.add_argument("--tokenizer_path", type=str) + parser.add_argument("--conv_template", default="v1", type=str) + parser.add_argument( + "--hparams_file", + type=str, + default=os.path.join( + os.path.dirname(__file__), '../../examples/multimodal/multimodal_llm/neva/conf/llava_config.yaml' + ), + required=False, + help="Path config for restoring. It's created during training and may need to be modified during restore if restore environment is different than training. 
Ex: /raid/nemo_experiments/megatron_gpt/hparams.yaml", + ) + parser.add_argument("--output_path", type=str, default=None, help="Path to output .nemo file.") + parser.add_argument( + "--precision", type=str, default="bf16", choices=["bf16", "32"], help="Precision for checkpoint weight saved" + ) + parser.add_argument("--skip_verification", action="store_true") + + args = parser.parse_args() + return args + + +def convert(args): + logging.info(f"Loading checkpoint from HF Llava: `{args.input_name_or_path}`") + hf_tokenizer = LlamaTokenizer.from_pretrained(args.input_name_or_path) + hf_model = LlavaForConditionalGeneration.from_pretrained(args.input_name_or_path) + logging.info("HF Model loading done.") + + nemo_config = OmegaConf.load(args.hparams_file) + nemo_config.model = adjust_nemo_config(nemo_config.model, hf_model.config.__dict__) + nemo_config.model.data["conv_template"] = args.conv_template + nemo_config.model.mm_cfg.llm["model_type"] = args.conv_template + nemo_config.model.tokenizer["model"] = args.tokenizer_path + + nemo_config.trainer["precision"] = args.precision + trainer = MegatronTrainerBuilder(nemo_config).create_trainer() + model = MegatronNevaModel(nemo_config.model, trainer) + + rename_keys = create_rename_keys(nemo_config.model.num_layers) + old_state_dict = hf_model.state_dict() + new_state_dict = rename_model_keys(model_state_dict=old_state_dict, rename_keys=rename_keys) + + nemo_state_dict = adjust_tensor_shapes(model, new_state_dict) + model.load_state_dict(nemo_state_dict, strict=False) + + logging.info(f'=' * 100) + if not args.skip_verification: + # Verifications + input_texts = [ + 'query: how much protein should a female eat', + ] + logging.info(f"Running verifications {input_texts} ...") + + # Tokenize the input texts + hf_tokenizer.pad_token = hf_tokenizer.eos_token + batch_dict = hf_tokenizer(input_texts, max_length=512, padding=True, truncation=True, return_tensors='pt') + batch_dict_cuda = {k: v.cuda() for k, v in batch_dict.items()} + hf_model = hf_model.cuda().eval() + model = model.eval() + + hf_outputs = hf_model(**batch_dict_cuda, output_hidden_states=True) + ids = batch_dict_cuda['input_ids'] + + id_tensors = [torch.unsqueeze(torch.LongTensor(id_list), dim=0) for id_list in ids.cpu()] + + masks_and_position_ids = [ + get_ltor_masks_and_position_ids(id_tensor, hf_tokenizer.eos_token, False, False, False) + for id_tensor in id_tensors + ] + for tokens, attn_mask_and_pos_ids in zip(id_tensors, masks_and_position_ids): + attn_mask, _, pos_ids = attn_mask_and_pos_ids + + outputs = model( + tokens=tokens, text_position_ids=pos_ids.cuda(), attention_mask=attn_mask.cuda(), labels=None + ) + + hf_next_token = hf_outputs.logits[0, -1].argmax() + next_token = outputs.squeeze()[-1].argmax() + + logging.info(f"HF predicted next token is: '{hf_tokenizer._convert_id_to_token(int(hf_next_token))}'.") + logging.info(f"NeMo predicted next token is: '{hf_tokenizer._convert_id_to_token(int(next_token))}'.") + assert ( + hf_next_token == next_token + ), f'prediction mismatch: {hf_tokenizer.decode(hf_next_token)} != {hf_tokenizer.decode(next_token)}' + logging.info(f'=' * 100) + + dtype = torch_dtype_from_precision(args.precision) + model = model.to(dtype=dtype) + model.save_to(args.output_path) + logging.info(f'NeMo model saved to: {args.output_path}') + + +if __name__ == '__main__': + args = get_args() + convert(args) diff --git a/scripts/checkpoint_converters/convert_llava_nemo_to_hf.py b/scripts/checkpoint_converters/convert_llava_nemo_to_hf.py new file mode 100644 
index 000000000000..430a74567ec2 --- /dev/null +++ b/scripts/checkpoint_converters/convert_llava_nemo_to_hf.py @@ -0,0 +1,337 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" + python3 /opt/NeMo/scripts/nlp_language_modeling/convert_gemma_hf_to_nemo.py \ + --input_name_or_path /path/to/llava-v1.5-7b.nemo \ + --hf_input_path llava-hf/llava-1.5-7b-hf \ + --hf_output_path=/path/to/hf_updated_checkpoint +""" + +import os +from argparse import ArgumentParser + +import torch +from omegaconf import OmegaConf +from transformers import LlamaTokenizer, LlavaForConditionalGeneration + +from nemo.collections.multimodal.models.multimodal_llm.neva.neva_model import MegatronNevaModel +from nemo.collections.nlp.modules.common.megatron.utils import get_ltor_masks_and_position_ids +from nemo.collections.nlp.parts.megatron_trainer_builder import MegatronTrainerBuilder +from nemo.collections.nlp.parts.nlp_overrides import NLPSaveRestoreConnector +from nemo.utils import logging + + +def create_rename_keys(num_hidden_layers): + rename_keys = [] + for i in range(num_hidden_layers): + # Attention layers + rename_keys.extend( + [ + ( + f"language_model.model.layers.{i}.self_attn.o_proj.weight", + f"model.decoder.layers.{i}.self_attention.linear_proj.weight", + ), + ( + f"language_model.model.layers.{i}.self_attn.q_proj.weight", + f"model.decoder.layers.{i}.self_attention.linear_q.weight", + ), + ( + f"language_model.model.layers.{i}.self_attn.k_proj.weight", + f"model.decoder.layers.{i}.self_attention.linear_k.weight", + ), + ( + f"language_model.model.layers.{i}.self_attn.v_proj.weight", + f"model.decoder.layers.{i}.self_attention.linear_v.weight", + ), + # MLP and LayerNorm + ( + f"language_model.model.layers.{i}.mlp.gate_proj.weight", + f"model.decoder.layers.{i}.mlp.linear_fc1_gate.weight", + ), + ( + f"language_model.model.layers.{i}.mlp.up_proj.weight", + f"model.decoder.layers.{i}.mlp.linear_fc1_proj.weight", + ), + ( + f"language_model.model.layers.{i}.mlp.down_proj.weight", + f"model.decoder.layers.{i}.mlp.linear_fc2.weight", + ), + ( + f"language_model.model.layers.{i}.input_layernorm.weight", + f"model.decoder.layers.{i}.self_attention.linear_qkv.layer_norm_weight", + ), + ( + f"language_model.model.layers.{i}.post_attention_layernorm.weight", + f"model.decoder.layers.{i}.mlp.linear_fc1.layer_norm_weight", + ), + ] + ) + + rename_keys.extend( + [ + ( + "multi_modal_projector.linear_1.weight", + "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector.0.weight", + ), + ( + "multi_modal_projector.linear_1.bias", + "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector.0.bias", + ), + ( + "multi_modal_projector.linear_2.weight", + "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector.2.weight", + ), + ( + "multi_modal_projector.linear_2.bias", + "model.embedding.word_embeddings.adapter_layer.mm_projector_adapter.mm_projector.2.bias", + ), + 
("language_model.model.embed_tokens.weight", "model.embedding.word_embeddings.weight"), + ("language_model.model.norm.weight", "model.decoder.final_layernorm.weight"), + ("language_model.lm_head.weight", "model.output_layer.weight"), + ] + ) + + return rename_keys + + +def rename_model_keys(model_state_dict, rename_keys): + """ + Rename keys in the model's state dictionary based on the provided mappings. + + Parameters: + model_state_dict (dict): The state dictionary of the model. + rename_keys (list): A list of tuples with the mapping (old_key, new_key). + + Returns: + dict: A new state dictionary with updated key names. + """ + + # Create a new state dictionary with updated key names + new_state_dict = {} + + # Track keys from the original state dict to ensure all are processed + remaining_keys = set(model_state_dict.keys()) + + # Iterate over the rename mappings + for new_key, old_key in rename_keys: + if old_key in model_state_dict: + # Rename the key and remove it from the tracking set + new_state_dict[new_key] = model_state_dict[old_key] + remaining_keys.remove(old_key) + + # Check if any keys were not converted from old to new + for old_key in remaining_keys: + print(f"Warning: Key '{old_key}' was not converted.") + + return new_state_dict + + +def reverse_adjust_tensor_shapes(model, hf_model, nemo_state_dict): + """ + Reverse the tensor adjustments made in the state dictionary to retrieve the original model structure. + + Parameters: + model (torch.nn.Module): The model instance to reference the state dictionary. + nemo_state_dict (dict): The state dictionary containing the adjusted tensors. + + Returns: + dict: The updated state dictionary with original tensor shapes and structures. + """ + model_config = model.cfg + num_query_groups = model_config["num_query_groups"] + head_num = model_config["num_attention_heads"] + hidden_size = model_config["hidden_size"] + head_size = model_config["kv_channels"] + if head_size is None: + head_size = hidden_size // head_num + heads_per_group = head_num // num_query_groups + vocab_size = hf_model.config.vocab_size + + for key_ in list(nemo_state_dict.keys()): + if 'word_embeddings.weight' in key_ or 'output_layer.weight' in key_: + # Reverse padding + loaded_weight = model.state_dict()[key_] + nemo_state_dict[key_] = loaded_weight[:vocab_size] + + if 'mlp.linear_fc1.weight' in key_: + new_key_gate = key_.replace('mlp.linear_fc1.weight', 'mlp.linear_fc1_gate.weight') + new_key_proj = key_.replace('mlp.linear_fc1.weight', 'mlp.linear_fc1_proj.weight') + + # Split concatenated gate and projection weights + combined_weight = nemo_state_dict[key_] + gate_weight, proj_weight = torch.chunk(combined_weight, 2, dim=0) + nemo_state_dict[new_key_gate] = gate_weight + nemo_state_dict[new_key_proj] = proj_weight + del nemo_state_dict[key_] + + if 'self_attention.linear_qkv.weight' in key_: + key_qkv = key_ + key_q = key_qkv.replace('linear_qkv', 'linear_q') + key_k = key_qkv.replace('linear_qkv', 'linear_k') + key_v = key_qkv.replace('linear_qkv', 'linear_v') + qkv_weight = nemo_state_dict[key_qkv].reshape(-1, head_size, hidden_size) + q_weight = torch.empty((head_num, head_size, hidden_size), device=qkv_weight.device) + k_weight = torch.empty((num_query_groups, head_size, hidden_size), device=qkv_weight.device) + v_weight = torch.empty((num_query_groups, head_size, hidden_size), device=qkv_weight.device) + + qkv_index = 0 + for i in range(num_query_groups): + q_weight[i * heads_per_group : (i + 1) * heads_per_group, :, :] = qkv_weight[ + qkv_index : 
qkv_index + heads_per_group, :, : + ] + qkv_index += heads_per_group + k_weight[i, :, :] = qkv_weight[qkv_index, :, :] + qkv_index += 1 + v_weight[i, :, :] = qkv_weight[qkv_index, :, :] + qkv_index += 1 + + nemo_state_dict[key_q] = q_weight.reshape(head_num * head_size, hidden_size) + nemo_state_dict[key_k] = k_weight.reshape(num_query_groups * head_size, hidden_size) + nemo_state_dict[key_v] = v_weight.reshape(num_query_groups * head_size, hidden_size) + + del nemo_state_dict[key_qkv] + + return nemo_state_dict + + +def adjust_nemo_config(model_config, ref_config): + model_config.mm_cfg.mm_mlp_adapter_type = "mlp2x_gelu" + if ref_config["vision_config"].image_size == 336: + model_config.mm_cfg.vision_encoder.from_pretrained = "openai/clip-vit-large-patch14-336" + model_config.data.image_token_len = 576 + else: + model_config.mm_cfg.vision_encoder.from_pretrained = "openai/clip-vit-large-patch14" + model_config.data.image_token_len = 256 + + ref_config = ref_config['text_config'].__dict__ + model_config["encoder_seq_length"] = ref_config["max_position_embeddings"] + model_config["num_layers"] = ref_config["num_hidden_layers"] + model_config["ffn_hidden_size"] = ref_config["intermediate_size"] + model_config["hidden_size"] = ref_config["hidden_size"] + model_config["num_attention_heads"] = ref_config["num_attention_heads"] + model_config["num_query_groups"] = ref_config["num_key_value_heads"] + model_config["layernorm_epsilon"] = ref_config["rms_norm_eps"] + model_config["init_method_std"] = ref_config["initializer_range"] + model_config["kv_channels"] = ref_config.get( + "head_dim", model_config["hidden_size"] // model_config["num_attention_heads"] + ) + if ref_config.get("rope_scaling") is not None: + if ref_config["rope_scaling"]["type"] == "linear": + model_config["seq_len_interpolation_factor"] = ref_config["rope_scaling"]["factor"] + else: + raise ValueError("Only linear rope scaling type is supported now") + model_config["use_cpu_initialization"] = True + + return model_config + + +def get_args(): + parser = ArgumentParser() + parser.add_argument( + "--input_name_or_path", + type=str, + default=None, + required=True, + help="Path to .nemo file or extracted folder", + ) + parser.add_argument( + "--hf_input_path", + type=str, + default=None, + help="A HF model path, " "e.g. 
a folder containing https://huggingface.co/meta-llama/Llama-2-7b-hf/tree/main", + ) + parser.add_argument( + "--hf_output_path", + type=str, + default=None, + help="Output HF model path, " "with the same format as above but user's own weights", + ) + parser.add_argument("--skip_verification", action="store_true") + + args = parser.parse_args() + return args + + +def convert(args): + logging.info(f"Loading checkpoint from HF Llava: `{args.hf_input_path}`") + hf_tokenizer = LlamaTokenizer.from_pretrained(args.hf_input_path) + hf_model = LlavaForConditionalGeneration.from_pretrained(args.hf_input_path) + logging.info("HF Model loading done.") + + nemo_config = OmegaConf.load( + os.path.join(os.path.dirname(__file__), '../../examples/multimodal/multimodal_llm/neva/conf/llava_config.yaml') + ) + trainer = MegatronTrainerBuilder(nemo_config).create_trainer() + model = MegatronNevaModel.restore_from( + restore_path=args.input_name_or_path, + trainer=trainer, + save_restore_connector=NLPSaveRestoreConnector(), + ) + + rename_keys = create_rename_keys(model.cfg.num_layers) + old_state_dict = model.state_dict() + nemo_state_dict = reverse_adjust_tensor_shapes(model, hf_model, old_state_dict) + hf_state_dict = rename_model_keys(model_state_dict=nemo_state_dict, rename_keys=rename_keys) + + hf_model.load_state_dict(hf_state_dict, strict=False) + + logging.info(f'=' * 100) + if not args.skip_verification: + # Verifications + input_texts = [ + 'query: how much protein should a female eat', + ] + logging.info(f"Running verifications {input_texts} ...") + + # Tokenize the input texts + hf_tokenizer.pad_token = hf_tokenizer.eos_token + batch_dict = hf_tokenizer(input_texts, max_length=512, padding=True, truncation=True, return_tensors='pt') + batch_dict_cuda = {k: v.cuda() for k, v in batch_dict.items()} + hf_model = hf_model.cuda().eval() + model = model.eval() + + hf_outputs = hf_model(**batch_dict_cuda, output_hidden_states=True) + ids = batch_dict_cuda['input_ids'] + + id_tensors = [torch.unsqueeze(torch.LongTensor(id_list), dim=0) for id_list in ids.cpu()] + + masks_and_position_ids = [ + get_ltor_masks_and_position_ids(id_tensor, hf_tokenizer.eos_token, False, False, False) + for id_tensor in id_tensors + ] + for tokens, attn_mask_and_pos_ids in zip(id_tensors, masks_and_position_ids): + attn_mask, _, pos_ids = attn_mask_and_pos_ids + + outputs = model( + tokens=tokens, text_position_ids=pos_ids.cuda(), attention_mask=attn_mask.cuda(), labels=None + ) + + hf_next_token = hf_outputs.logits[0, -1].argmax() + next_token = outputs.squeeze()[-1].argmax() + + logging.info(f"HF predicted next token is: '{hf_tokenizer._convert_id_to_token(int(hf_next_token))}'.") + logging.info(f"NeMo predicted next token is: '{hf_tokenizer._convert_id_to_token(int(next_token))}'.") + assert ( + hf_next_token == next_token + ), f'prediction mismatch: {hf_tokenizer.decode(hf_next_token)} != {hf_tokenizer.decode(next_token)}' + logging.info(f'=' * 100) + + hf_model.save_pretrained(args.hf_output_path) + logging.info(f"Full HF model saved to {args.hf_output_path}") + + +if __name__ == '__main__': + args = get_args() + convert(args) diff --git a/scripts/deploy/nlp/deploy_triton.py b/scripts/deploy/nlp/deploy_triton.py index d0854916cd38..8916fec0b1dd 100755 --- a/scripts/deploy/nlp/deploy_triton.py +++ b/scripts/deploy/nlp/deploy_triton.py @@ -16,14 +16,34 @@ import logging import os import sys +import tempfile from pathlib import Path from nemo.deploy import DeployPyTriton -from nemo.deploy.nlp import MegatronLLMDeployable 
-from nemo.export import TensorRTLLM LOGGER = logging.getLogger("NeMo") +megatron_llm_supported = True +try: + from nemo.deploy.nlp import MegatronLLMDeployable +except Exception as e: + LOGGER.warning(f"Cannot import MegatronLLMDeployable, it will not be available. {type(e).__name__}: {e}") + megatron_llm_supported = False + +trt_llm_supported = True +try: + from nemo.export.tensorrt_llm import TensorRTLLM +except Exception as e: + LOGGER.warning(f"Cannot import the TensorRTLLM exporter, it will not be available. {type(e).__name__}: {e}") + trt_llm_supported = False + +vllm_supported = True +try: + from nemo.export.vllm_exporter import vLLMExporter +except Exception as e: + LOGGER.warning(f"Cannot import the vLLM exporter, it will not be available. {type(e).__name__}: {e}") + vllm_supported = False + def get_args(argv): parser = argparse.ArgumentParser( @@ -69,7 +89,7 @@ def get_args(argv): choices=["bfloat16", "float16", "fp8", "int8"], default="bfloat16", type=str, - help="dtype of the model on TensorRT-LLM", + help="dtype of the model on TensorRT-LLM or vLLM", ) parser.add_argument("-mil", "--max_input_len", default=256, type=int, help="Max input length of the model") parser.add_argument("-mol", "--max_output_len", default=256, type=int, help="Max output length of the model") @@ -150,7 +170,23 @@ def get_args(argv): help="Different options to deploy nemo model.", ) parser.add_argument("-dm", "--debug_mode", default=False, action='store_true', help="Enable debug mode") - + parser.add_argument( + '-ws', + '--weight_storage', + default='auto', + choices=['auto', 'cache', 'file', 'memory'], + help='Strategy for storing converted weights for vLLM: "file" - always write weights into a file, ' + '"memory" - always do an in-memory conversion, "cache" - reuse existing files if they are ' + 'newer than the nemo checkpoint, "auto" - use "cache" for multi-GPU runs and "memory" ' + 'for single-GPU runs.', + ) + parser.add_argument( + "-gmu", + '--gpu_memory_utilization', + default=0.9, + type=float, + help="GPU memory utilization percentage for vLLM.", + ) args = parser.parse_args(argv) return args @@ -160,8 +196,8 @@ def get_trtllm_deployable(args): trt_llm_path = "/tmp/trt_llm_model_dir/" LOGGER.info( "/tmp/trt_llm_model_dir/ path will be used as the TensorRT LLM folder. " - "Please set this parameter if you'd like to use a path that has already " - "included the TensorRT LLM model files." + "Please set the --triton_model_repository parameter if you'd like to use a path that already " + "includes the TensorRT LLM model files." ) Path(trt_llm_path).mkdir(parents=True, exist_ok=True) else: @@ -261,6 +297,45 @@ def get_trtllm_deployable(args): return trt_llm_exporter +def get_vllm_deployable(args): + if args.ptuning_nemo_checkpoint is not None: + raise ValueError("vLLM backend doesn't support P-tuning at this time.") + if args.lora_ckpt is not None: + raise ValueError("vLLM backend doesn't support LoRA at this time.") + + tempdir = None + model_dir = args.triton_model_repository + if model_dir is None: + tempdir = tempfile.TemporaryDirectory() + model_dir = tempdir.name + LOGGER.info( + f"{model_dir} path will be used as the vLLM intermediate folder. " + + "Please set the --triton_model_repository parameter if you'd like to use a path that already " + + "includes the vLLM model files." 
+ ) + elif not os.path.exists(model_dir): + os.makedirs(model_dir) + + try: + exporter = vLLMExporter() + exporter.export( + nemo_checkpoint=args.nemo_checkpoint, + model_dir=model_dir, + model_type=args.model_type, + tensor_parallel_size=args.num_gpus, + max_model_len=args.max_input_len + args.max_output_len, + dtype=args.dtype, + weight_storage=args.weight_storage, + gpu_memory_utilization=args.gpu_memory_utilization, + ) + return exporter + except Exception as error: + raise RuntimeError("An error has occurred during the model export. Error message: " + str(error)) + finally: + if tempdir is not None: + tempdir.cleanup() + + def get_nemo_deployable(args): if args.nemo_checkpoint is None: raise ValueError("In-Framework deployment requires a .nemo checkpoint") @@ -282,11 +357,17 @@ def nemo_deploy(argv): backend = args.backend.lower() if backend == 'tensorrt-llm': + if not trt_llm_supported: + raise ValueError("TensorRT-LLM engine is not supported in this environment.") triton_deployable = get_trtllm_deployable(args) elif backend == 'in-framework': + if not megatron_llm_supported: + raise ValueError("MegatronLLMDeployable is not supported in this environment.") triton_deployable = get_nemo_deployable(args) elif backend == 'vllm': - raise ValueError("vLLM will be supported in the next release.") + if not vllm_supported: + raise ValueError("vLLM engine is not supported in this environment.") + triton_deployable = get_vllm_deployable(args) else: raise ValueError("Backend: {0} is not supported.".format(backend)) diff --git a/scripts/export/export_to_trt_llm.py b/scripts/export/export_to_trt_llm.py index a0c70c8bbd85..49fefd40561b 100644 --- a/scripts/export/export_to_trt_llm.py +++ b/scripts/export/export_to_trt_llm.py @@ -16,7 +16,7 @@ import logging import sys -from nemo.export import TensorRTLLM +from nemo.export.tensorrt_llm import TensorRTLLM LOGGER = logging.getLogger("NeMo") diff --git a/tests/export/nemo_export.py b/tests/export/nemo_export.py index 5541cc0f8673..013a22deee3b 100644 --- a/tests/export/nemo_export.py +++ b/tests/export/nemo_export.py @@ -14,46 +14,85 @@ import argparse import json +import logging import shutil +import sys import time +from dataclasses import dataclass from pathlib import Path +from typing import Dict, List, Optional, Tuple + import torch -from tests.infer_data_path import get_infer_test_data +# Import infer_data_path from the parent folder assuming that the 'tests' package is not installed. +sys.path.append(str(Path(__file__).parent.parent)) +from infer_data_path import get_infer_test_data + +LOGGER = logging.getLogger("NeMo") -run_export_tests = True +triton_supported = True try: from nemo.deploy import DeployPyTriton from nemo.deploy.nlp import NemoQueryLLM - from nemo.export import TensorRTLLM except Exception as e: - run_export_tests = False + LOGGER.warning(f"Cannot import Triton, deployment will not be available. {type(e).__name__}: {e}") + triton_supported = False + +trt_llm_supported = True +try: + from nemo.export.tensorrt_llm import TensorRTLLM +except Exception as e: + LOGGER.warning(f"Cannot import the TensorRTLLM exporter, it will not be available. {type(e).__name__}: {e}") + trt_llm_supported = False + +vllm_supported = True +try: + from nemo.export.vllm_exporter import vLLMExporter +except Exception as e: + LOGGER.warning(f"Cannot import the vLLM exporter, it will not be available. 
{type(e).__name__}: {e}") + vllm_supported = False -def get_accuracy_with_lambada(model, nq, task_ids, lora_uids, test_data_path=None): +class UsageError(Exception): + pass + + +@dataclass +class FunctionalResult: + regular_pass: Optional[bool] = None + deployed_pass: Optional[bool] = None + + +@dataclass +class AccuracyResult: + accuracy: float + accuracy_relaxed: float + deployed_accuracy: float + deployed_accuracy_relaxed: float + evaluation_time: float + + +def get_accuracy_with_lambada(model, nq, task_ids, lora_uids, test_data_path): # lambada dataset based accuracy test, which includes more than 5000 sentences. # Use generated last token with original text's last token for accuracy comparison. # If the generated last token start with the original token, trtllm_correct make an increment. # It generates a CSV file for text comparison detail. - if test_data_path is None: - raise Exception("test_data_path cannot be None.") - - trtllm_correct = 0 - trtllm_deployed_correct = 0 - trtllm_correct_relaxed = 0 - trtllm_deployed_correct_relaxed = 0 + correct_answers = 0 + correct_answers_deployed = 0 + correct_answers_relaxed = 0 + correct_answers_deployed_relaxed = 0 all_expected_outputs = [] - all_trtllm_outputs = [] + all_actual_outputs = [] with open(test_data_path, 'r') as file: records = json.load(file) - eval_start = time.perf_counter() + eval_start = time.monotonic() for record in records: prompt = record["text_before_last_word"] expected_output = record["last_word"].strip().lower() - trtllm_output = model.forward( + model_output = model.forward( input_texts=[prompt], max_output_len=1, top_k=1, @@ -62,22 +101,22 @@ def get_accuracy_with_lambada(model, nq, task_ids, lora_uids, test_data_path=Non task_ids=task_ids, lora_uids=lora_uids, ) - trtllm_output = trtllm_output[0][0].strip().lower() + model_output = model_output[0][0].strip().lower() all_expected_outputs.append(expected_output) - all_trtllm_outputs.append(trtllm_output) + all_actual_outputs.append(model_output) - if expected_output == trtllm_output: - trtllm_correct += 1 + if expected_output == model_output: + correct_answers += 1 if ( - expected_output == trtllm_output - or trtllm_output.startswith(expected_output) - or expected_output.startswith(trtllm_output) + expected_output == model_output + or model_output.startswith(expected_output) + or expected_output.startswith(model_output) ): - if len(trtllm_output) == 1 and len(expected_output) > 1: + if len(model_output) == 1 and len(expected_output) > 1: continue - trtllm_correct_relaxed += 1 + correct_answers_relaxed += 1 if nq is not None: trtllm_deployed_output = nq.query_llm( @@ -91,7 +130,7 @@ def get_accuracy_with_lambada(model, nq, task_ids, lora_uids, test_data_path=Non trtllm_deployed_output = trtllm_deployed_output[0][0].strip().lower() if expected_output == trtllm_deployed_output: - trtllm_deployed_correct += 1 + correct_answers_deployed += 1 if ( expected_output == trtllm_deployed_output @@ -100,32 +139,47 @@ def get_accuracy_with_lambada(model, nq, task_ids, lora_uids, test_data_path=Non ): if len(trtllm_deployed_output) == 1 and len(expected_output) > 1: continue - trtllm_deployed_correct_relaxed += 1 - eval_end = time.perf_counter() + correct_answers_deployed_relaxed += 1 + eval_end = time.monotonic() + + return AccuracyResult( + accuracy=correct_answers / len(all_expected_outputs), + accuracy_relaxed=correct_answers_relaxed / len(all_expected_outputs), + deployed_accuracy=correct_answers_deployed / len(all_expected_outputs), + 
deployed_accuracy_relaxed=correct_answers_deployed_relaxed / len(all_expected_outputs), + evaluation_time=eval_end - eval_start, + ) - trtllm_accuracy = trtllm_correct / len(all_expected_outputs) - trtllm_accuracy_relaxed = trtllm_correct_relaxed / len(all_expected_outputs) - trtllm_deployed_accuracy = trtllm_deployed_correct / len(all_expected_outputs) - trtllm_deployed_accuracy_relaxed = trtllm_deployed_correct_relaxed / len(all_expected_outputs) +# Tests if the model outputs contain the expected keywords. +def check_model_outputs(streaming: bool, model_outputs, expected_outputs: List[str]) -> bool: - evaluation_time = eval_end - eval_start + # In streaming mode, we get a list of lists of lists, and we only care about the last item in that list + if streaming: + if len(model_outputs) == 0: + return False + model_outputs = model_outputs[-1] - return ( - trtllm_accuracy, - trtllm_accuracy_relaxed, - trtllm_deployed_accuracy, - trtllm_deployed_accuracy_relaxed, - evaluation_time, - ) + # See if we have the right number of final answers. + if len(model_outputs) != len(expected_outputs): + return False + + # Check the presence of keywords in the final answers. + for i in range(len(model_outputs)): + if expected_outputs[i] not in model_outputs[i][0]: + return False + return True -def run_trt_llm_inference( + +def run_inference( model_name, model_type, - prompt, + prompts, + expected_outputs, checkpoint_path, - trt_llm_model_dir, + model_dir, + use_vllm, n_gpu=1, max_batch_size=8, use_embedding_sharing=False, @@ -135,8 +189,8 @@ def run_trt_llm_inference( p_tuning_checkpoint=None, lora=False, lora_checkpoint=None, - tp_size=None, - pp_size=None, + tp_size=1, + pp_size=1, top_k=1, top_p=0.0, temperature=1.0, @@ -147,7 +201,7 @@ def run_trt_llm_inference( test_deployment=False, test_data_path=None, save_trt_engine=False, -): +) -> Tuple[Optional[FunctionalResult], Optional[AccuracyResult]]: if Path(checkpoint_path).exists(): if n_gpu > torch.cuda.device_count(): print( @@ -155,9 +209,9 @@ def run_trt_llm_inference( checkpoint_path, model_name, n_gpu, torch.cuda.device_count() ) ) - return None, None, None, None, None + return (None, None) - Path(trt_llm_model_dir).mkdir(parents=True, exist_ok=True) + Path(model_dir).mkdir(parents=True, exist_ok=True) if debug: print("") @@ -182,7 +236,7 @@ def run_trt_llm_inference( print("---- PTuning enabled.") else: print("---- PTuning could not be enabled and skipping the test.") - return None, None, None, None, None + return (None, None) lora_ckpt_list = None lora_uids = None @@ -199,36 +253,48 @@ def run_trt_llm_inference( print("---- LoRA enabled.") else: print("---- LoRA could not be enabled and skipping the test.") - return None, None, None, None, None - - trt_llm_exporter = TensorRTLLM(trt_llm_model_dir, lora_ckpt_list, load_model=False) - - trt_llm_exporter.export( - nemo_checkpoint_path=checkpoint_path, - model_type=model_type, - n_gpus=n_gpu, - tensor_parallel_size=tp_size, - pipeline_parallel_size=pp_size, - max_input_len=max_input_len, - max_output_len=max_output_len, - max_batch_size=max_batch_size, - max_prompt_embedding_table_size=max_prompt_embedding_table_size, - use_lora_plugin=use_lora_plugin, - lora_target_modules=lora_target_modules, - max_num_tokens=int(max_input_len * max_batch_size * 0.2), - opt_num_tokens=60, - use_embedding_sharing=use_embedding_sharing, - save_nemo_model_config=True, - ) + return (None, None) + + if use_vllm: + exporter = vLLMExporter() + + exporter.export( + nemo_checkpoint=checkpoint_path, + model_dir=model_dir, + 
model_type=model_type, + tensor_parallel_size=tp_size, + pipeline_parallel_size=pp_size, + max_model_len=max_input_len + max_output_len, + ) + else: + exporter = TensorRTLLM(model_dir, lora_ckpt_list, load_model=False) + + exporter.export( + nemo_checkpoint_path=checkpoint_path, + model_type=model_type, + n_gpus=n_gpu, + tensor_parallel_size=tp_size, + pipeline_parallel_size=pp_size, + max_input_len=max_input_len, + max_output_len=max_output_len, + max_batch_size=max_batch_size, + max_prompt_embedding_table_size=max_prompt_embedding_table_size, + use_lora_plugin=use_lora_plugin, + lora_target_modules=lora_target_modules, + max_num_tokens=int(max_input_len * max_batch_size * 0.2), + opt_num_tokens=60, + use_embedding_sharing=use_embedding_sharing, + save_nemo_model_config=True, + ) if ptuning: - trt_llm_exporter.add_prompt_table( + exporter.add_prompt_table( task_name="0", prompt_embeddings_checkpoint_path=prompt_embeddings_checkpoint_path, ) - output = trt_llm_exporter.forward( - input_texts=prompt, + output = exporter.forward( + input_texts=prompts, max_output_len=max_output_len, top_k=top_k, top_p=top_p, @@ -239,10 +305,21 @@ def run_trt_llm_inference( stop_words_list=stop_words_list, ) - if not use_lora_plugin and not ptuning: + # Unwrap the generator if needed + output = list(output) + + functional_result = FunctionalResult() + + # Check non-deployed funcitonal correctness + functional_result.regular_pass = True + if not check_model_outputs(streaming, output, expected_outputs): + LOGGER.warning("Model outputs don't match the expected result.") + functional_result.regular_pass = False + + if not use_lora_plugin and not ptuning and not use_vllm: test_cpp_runtime( - engine_path=trt_llm_model_dir, - prompt=prompt, + engine_path=model_dir, + prompt=prompts, max_output_len=max_output_len, debug=True, ) @@ -252,7 +329,7 @@ def run_trt_llm_inference( output_deployed = "" if test_deployment: nm = DeployPyTriton( - model=trt_llm_exporter, + model=exporter, triton_model_name=model_name, port=8000, ) @@ -261,7 +338,7 @@ def run_trt_llm_inference( nq = NemoQueryLLM(url="localhost:8000", model_name=model_name) output_deployed = nq.query_llm( - prompts=prompt, + prompts=prompts, max_output_len=max_output_len, top_k=1, top_p=0.0, @@ -269,33 +346,38 @@ def run_trt_llm_inference( lora_uids=lora_uids, ) - if debug: + # Unwrap the generator if needed + output_deployed = list(output_deployed) + + # Check deployed funcitonal correctness + functional_result.deployed_pass = True + if not check_model_outputs(streaming, output_deployed, expected_outputs): + LOGGER.warning("Deployed model outputs don't match the expected result.") + functional_result.deployed_pass = False + + if debug or functional_result.regular_pass == False or functional_result.deployed_pass == False: print("") - print("--- Prompt: ", prompt) + print("--- Prompt: ", prompts) print("") - print("--- Output: ", output) + print("--- Expected keywords: ", expected_outputs) print("") + print("--- Output: ", output) print("") print("--- Output deployed: ", output_deployed) print("") + accuracy_result = None if run_accuracy: print("Start model accuracy testing ...") - result = get_accuracy_with_lambada(trt_llm_exporter, nq, task_ids, lora_uids, test_data_path) - if test_deployment: - nm.stop() - - if not save_trt_engine: - shutil.rmtree(trt_llm_model_dir) - return result + accuracy_result = get_accuracy_with_lambada(exporter, nq, task_ids, lora_uids, test_data_path) if test_deployment: nm.stop() if not save_trt_engine: - 
shutil.rmtree(trt_llm_model_dir) + shutil.rmtree(model_dir) - return None, None, None, None, None + return (functional_result, accuracy_result) else: raise Exception("Checkpoint {0} could not be found.".format(checkpoint_path)) @@ -323,6 +405,7 @@ def test_cpp_runtime( def run_existing_checkpoints( model_name, + use_vllm, n_gpus, tp_size=None, pp_size=None, @@ -334,10 +417,10 @@ def run_existing_checkpoints( stop_words_list=None, test_data_path=None, save_trt_engine=False, -): +) -> Tuple[Optional[FunctionalResult], Optional[AccuracyResult]]: if n_gpus > torch.cuda.device_count(): print("Skipping the test due to not enough number of GPUs") - return None, None, None, None, None + return (None, None) test_data = get_infer_test_data() if not (model_name in test_data.keys()): @@ -347,7 +430,7 @@ def run_existing_checkpoints( if n_gpus < model_info["min_gpus"]: print("Min n_gpus for this model is {0}".format(n_gpus)) - return None, None, None, None, None + return (None, None) p_tuning_checkpoint = None if ptuning: @@ -369,12 +452,13 @@ def run_existing_checkpoints( else: use_embedding_sharing = False - return run_trt_llm_inference( + return run_inference( model_name=model_name, model_type=model_info["model_type"], - prompt=model_info["prompt_template"], + prompts=model_info["prompt_template"], checkpoint_path=model_info["checkpoint"], - trt_llm_model_dir=model_info["trt_llm_model_dir"], + model_dir=model_info["model_dir"], + use_vllm=use_vllm, n_gpu=n_gpus, max_batch_size=model_info["max_batch_size"], use_embedding_sharing=use_embedding_sharing, @@ -437,7 +521,7 @@ def get_args(): required=False, ) parser.add_argument( - "--trt_llm_model_dir", + "--model_dir", type=str, ) parser.add_argument( @@ -475,10 +559,12 @@ def get_args(): ) parser.add_argument( "--tp_size", + default=1, type=int, ) parser.add_argument( "--pp_size", + default=1, type=int, ) parser.add_argument( @@ -527,31 +613,48 @@ def get_args(): type=str, default="False", ) + parser.add_argument( + "--use_vllm", + type=str, + default="False", + ) + + args = parser.parse_args() + + def str_to_bool(name: str, s: str) -> bool: + true_strings = ["true", "1"] + false_strings = ["false", "0"] + if s.lower() in true_strings: + return True + if s.lower() in false_strings: + return False + raise UsageError(f"Invalid boolean value for argument --{name}: '{s}'") + + args.test_deployment = str_to_bool("test_deployment", args.test_deployment) + args.save_trt_engine = str_to_bool("save_trt_engin", args.save_trt_engine) + args.run_accuracy = str_to_bool("run_accuracy", args.run_accuracy) + args.use_vllm = str_to_bool("use_vllm", args.use_vllm) - return parser.parse_args() + return args def run_inference_tests(args): - if args.test_deployment == "True": - args.test_deployment = True - else: - args.test_deployment = False + if not args.use_vllm and not trt_llm_supported: + raise UsageError("TensorRT-LLM engine is not supported in this environment.") - if args.save_trt_engine == "True": - args.save_trt_engine = True - else: - args.save_trt_engine = False + if args.use_vllm and not vllm_supported: + raise UsageError("vLLM engine is not supported in this environment.") - if args.run_accuracy == "True": - args.run_accuracy = True - else: - args.run_accuracy = False + if args.use_vllm and (args.ptuning or args.lora): + raise UsageError("The vLLM integration currently does not support P-tuning or LoRA.") - if args.run_accuracy: - if args.test_data_path is None: - raise Exception("test_data_path param cannot be None.") + if args.test_deployment and not 
triton_supported: + raise UsageError("Deployment tests are not available because Triton is not supported in this environment.") - result_dic = {} + if args.run_accuracy and args.test_data_path is None: + raise UsageError("Accuracy testing requires the --test_data_path argument.") + + result_dic: Dict[int, Tuple[FunctionalResult, Optional[AccuracyResult]]] = {} if args.existing_test_models: n_gpus = args.min_gpus @@ -561,6 +664,7 @@ def run_inference_tests(args): while n_gpus <= args.max_gpus: result_dic[n_gpus] = run_existing_checkpoints( model_name=args.model_name, + use_vllm=args.use_vllm, n_gpus=n_gpus, ptuning=args.ptuning, lora=args.lora, @@ -575,18 +679,24 @@ def run_inference_tests(args): n_gpus = n_gpus * 2 else: - prompt_template = ["The capital of France is", "Largest animal in the sea is"] + if args.model_dir is None: + raise Exception("When using custom checkpoints, --model_dir is required.") + + prompts = ["The capital of France is", "Largest animal in the sea is"] + expected_outputs = ["Paris", "blue whale"] n_gpus = args.min_gpus if args.max_gpus is None: args.max_gpus = args.min_gpus while n_gpus <= args.max_gpus: - result_dic[n_gpus] = run_trt_llm_inference( + result_dic[n_gpus] = run_inference( model_name=args.model_name, model_type=args.model_type, - prompt=prompt_template, + prompts=prompts, + expected_outputs=expected_outputs, checkpoint_path=args.checkpoint_dir, - trt_llm_model_dir=args.trt_llm_model_dir, + model_dir=args.model_dir, + use_vllm=args.use_vllm, n_gpu=n_gpus, max_batch_size=args.max_batch_size, max_input_len=args.max_input_len, @@ -610,31 +720,59 @@ def run_inference_tests(args): n_gpus = n_gpus * 2 - test_result = "PASS" + functional_test_result = "PASS" + accuracy_test_result = "PASS" print_separator = False print("============= Test Summary ============") - for i, results in result_dic.items(): - if not results[0] is None and not results[1] is None: - if print_separator: - print("---------------------------------------") - print( - "Number of GPUS: {}\n" - "Model Accuracy: {:.4f}\n" - "Relaxed Model Accuracy: {:.4f}\n" - "Deployed Model Accuracy: {:.4f}\n" - "Deployed Relaxed Model Accuracy: {:.4f}\n" - "Evaluation Time [s]: {:.2f}".format(i, *results) - ) - print_separator = True - if results[1] < 0.5: - test_result = "FAIL" + for num_gpus, results in result_dic.items(): + functional_result, accuracy_result = results + + if print_separator: + print("---------------------------------------") + print_separator = True + + def optional_bool_to_pass_fail(b: Optional[bool]): + if b is None: + return "N/A" + return "PASS" if b else "FAIL" + + print(f"Number of GPUS: {num_gpus}") + + if functional_result is not None: + print(f"Functional Test: {optional_bool_to_pass_fail(functional_result.regular_pass)}") + print(f"Deployed Functional Test: {optional_bool_to_pass_fail(functional_result.deployed_pass)}") + + if functional_result.regular_pass == False: + functional_test_result = "FAIL" + if functional_result.deployed_pass == False: + functional_test_result = "FAIL" + + if accuracy_result is not None: + print(f"Model Accuracy: {accuracy_result.accuracy:.4f}") + print(f"Relaxed Model Accuracy: {accuracy_result.accuracy_relaxed:.4f}") + print(f"Deployed Model Accuracy: {accuracy_result.deployed_accuracy:.4f}") + print(f"Deployed Relaxed Model Accuracy: {accuracy_result.deployed_accuracy_relaxed:.4f}") + print(f"Evaluation Time [s]: {accuracy_result.evaluation_time:.2f}") + if accuracy_result.accuracy_relaxed < 0.5: + accuracy_test_result = "FAIL" 
     print("=======================================")
-    print("TEST: " + test_result)
-    if test_result == "FAIL":
+    print(f"Functional: {functional_test_result}")
+    if args.run_accuracy:
+        print(f"Accuracy: {accuracy_test_result}")
+
+    if functional_test_result == "FAIL":
+        raise Exception("Functional test failed")
+
+    if accuracy_test_result == "FAIL":
         raise Exception("Model accuracy is below 0.5")
 
 
 if __name__ == '__main__':
-    args = get_args()
-    run_inference_tests(args)
+    try:
+        args = get_args()
+        run_inference_tests(args)
+    except UsageError as e:
+        LOGGER.error(f"{e}")
+    except argparse.ArgumentError as e:
+        LOGGER.error(f"{e}")
diff --git a/tests/lightning/io/test_api.py b/tests/lightning/io/test_api.py
index 9872d0860193..d13573de180f 100644
--- a/tests/lightning/io/test_api.py
+++ b/tests/lightning/io/test_api.py
@@ -16,7 +16,7 @@ def test_reload_ckpt(self, tmpdir):
         )
     )
 
-    ckpt = io.TrainerCheckpoint(model, trainer)
+    ckpt = io.TrainerContext(model, trainer)
     ckpt.io_dump(tmpdir)
     loaded = io.load_ckpt(tmpdir)
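
Note on the optional-backend guards added in scripts/deploy/nlp/deploy_triton.py and tests/export/nemo_export.py above: each exporter import is wrapped individually, so a missing optional dependency only disables that one backend instead of breaking the whole script. A minimal standalone sketch of the pattern; the check_backend helper is illustrative only and not part of the change:

import logging

LOGGER = logging.getLogger("NeMo")

# Each optional backend gets its own guard; a failed import only flips a flag.
trt_llm_supported = True
try:
    from nemo.export.tensorrt_llm import TensorRTLLM
except Exception as e:
    LOGGER.warning(f"Cannot import the TensorRTLLM exporter, it will not be available. {type(e).__name__}: {e}")
    trt_llm_supported = False

vllm_supported = True
try:
    from nemo.export.vllm_exporter import vLLMExporter
except Exception as e:
    LOGGER.warning(f"Cannot import the vLLM exporter, it will not be available. {type(e).__name__}: {e}")
    vllm_supported = False


def check_backend(backend: str) -> None:
    # Illustrative helper: fail fast with a clear message when the requested
    # backend could not be imported in this environment.
    if backend == "tensorrt-llm" and not trt_llm_supported:
        raise ValueError("TensorRT-LLM engine is not supported in this environment.")
    if backend == "vllm" and not vllm_supported:
        raise ValueError("vLLM engine is not supported in this environment.")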
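For reference, a hedged sketch of the new vLLM export path exercised by get_vllm_deployable() and run_inference() above. The keyword arguments mirror the calls in the diff; the checkpoint path, model type, and the length/dtype values below are placeholders, not values taken from the change:

import tempfile

from nemo.export.vllm_exporter import vLLMExporter

nemo_checkpoint = "/path/to/model.nemo"  # placeholder path

with tempfile.TemporaryDirectory() as model_dir:
    exporter = vLLMExporter()
    exporter.export(
        nemo_checkpoint=nemo_checkpoint,
        model_dir=model_dir,            # intermediate folder, analogous to --triton_model_repository
        model_type="llama",             # placeholder model type
        tensor_parallel_size=1,
        max_model_len=256 + 256,        # max_input_len + max_output_len
        dtype="bfloat16",
        weight_storage="auto",          # see --weight_storage: auto | cache | file | memory
        gpu_memory_utilization=0.9,     # see --gpu_memory_utilization
    )

    # Generation goes through the same forward() interface used for TensorRT-LLM.
    output = exporter.forward(
        input_texts=["The capital of France is"],
        max_output_len=16,
        top_k=1,
        top_p=0.0,
    )
    # The result may be a generator; the test unwraps it the same way.
    output = list(output)

The resulting exporter object is then handed to DeployPyTriton exactly like the TensorRT-LLM exporter, which is why nemo_deploy() only has to pick the deployable and the rest of the serving path stays unchanged.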
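The functional check added to tests/export/nemo_export.py compares each final answer against an expected keyword and, in streaming mode, only inspects the last snapshot of the stream. A self-contained toy example of how the output shapes are interpreted; the function body is copied from the diff, while the sample data is made up:

from typing import List


def check_model_outputs(streaming: bool, model_outputs, expected_outputs: List[str]) -> bool:
    # Same logic as in tests/export/nemo_export.py: in streaming mode only the
    # last element (the final state of every stream) is inspected.
    if streaming:
        if len(model_outputs) == 0:
            return False
        model_outputs = model_outputs[-1]

    # One final answer per expected keyword.
    if len(model_outputs) != len(expected_outputs):
        return False

    # Keyword must appear somewhere in the corresponding answer.
    for i in range(len(model_outputs)):
        if expected_outputs[i] not in model_outputs[i][0]:
            return False
    return True


# Non-streaming: one [text] entry per prompt.
non_streaming = [
    ["The capital of France is Paris."],
    ["Largest animal in the sea is the blue whale."],
]
assert check_model_outputs(False, non_streaming, ["Paris", "blue whale"])

# Streaming: a list of snapshots; only the last snapshot is checked.
streaming = [
    [["The"], ["Largest"]],
    [["The capital of France is Paris."], ["Largest animal in the sea is the blue whale."]],
]
assert check_model_outputs(True, streaming, ["Paris", "blue whale"])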