diff --git a/nemo/export/tensorrt_llm.py b/nemo/export/tensorrt_llm.py
index 08b0b822cad4..ef3cf9898ede 100644
--- a/nemo/export/tensorrt_llm.py
+++ b/nemo/export/tensorrt_llm.py
@@ -30,10 +30,11 @@
 import wrapt
 from tensorrt_llm._utils import numpy_to_torch
 
+from nemo.collections.nlp.parts.utils_funcs import torch_dtype_from_precision
 from nemo.deploy import ITritonDeployable
 from nemo.export.tarutils import TarPath, unpack_tarball
 from nemo.export.trt_llm.converter.model_converter import model_to_trtllm_ckpt
-from nemo.export.trt_llm.converter.model_to_trt_llm_ckpt import dist_model_to_trt_llm_ckpt
+from nemo.export.trt_llm.converter.model_to_trt_llm_ckpt import dist_model_to_trt_llm_ckpt, get_layer_prefix
 from nemo.export.trt_llm.converter.utils import init_model_parallel_from_nemo
 from nemo.export.trt_llm.nemo_ckpt_loader.nemo_file import (
     build_tokenizer,
@@ -65,6 +66,8 @@
 @wrapt.decorator
 def noop_decorator(func):
+    """No op decorator"""
+
     def wrapper(*args, **kwargs):
         return func(*args, **kwargs)
 
@@ -80,6 +83,7 @@ def wrapper(*args, **kwargs):
     use_pytriton = False
 
+# pylint: disable=line-too-long
 class TensorRTLLM(ITritonDeployable):
     """
     Exports nemo checkpoints to TensorRT-LLM and run fast inference.
@@ -343,43 +347,14 @@ def export(
                 DEFAULT_CONVERSION_DICT,
             )
             from megatron.core.export.trtllm.trtllm_helper import TRTLLMHelper
-            from megatron.core.transformer.transformer_config import TransformerConfig
             from tensorrt_llm.layers import MoeConfig
 
-            def get_transformer_config(nemo_model_config):
-                normalization = nemo_model_config.get('normalization', 'layernorm')
-                transformer_config_normalization = 'LayerNorm'
-                layernorm_zero_centered_gamma = False
-                if normalization == 'layernorm1p':
-                    layernorm_zero_centered_gamma = True
-                elif normalization == 'rmsnorm':
-                    transformer_config_normalization = 'RMSNorm'
-
-                conf = TransformerConfig(
-                    num_layers=nemo_model_config.get('num_layers'),
-                    moe_router_topk=nemo_model_config.get('moe_router_topk', 0),
-                    num_attention_heads=nemo_model_config.get('num_attention_heads'),
-                    num_query_groups=nemo_model_config.get(
-                        'num_query_groups', nemo_model_config['num_attention_heads']
-                    ),
-                    kv_channels=nemo_model_config.get("kv_channels", None),
-                    hidden_size=nemo_model_config.get('hidden_size'),
-                    ffn_hidden_size=nemo_model_config.get('ffn_hidden_size'),
-                    layernorm_epsilon=nemo_model_config.get('layernorm_epsilon'),
-                    add_bias_linear=nemo_model_config.get('bias'),
-                    num_moe_experts=nemo_model_config.get('num_moe_experts', None),
-                    normalization=transformer_config_normalization,
-                    layernorm_zero_centered_gamma=layernorm_zero_centered_gamma,
-                )
-
-                return conf
-
             # We build the transformer config using the nemo model config.
-            transformer_config = get_transformer_config(model_configs)
+            transformer_config = self.get_transformer_config(model_configs)
             input_model_type = getattr(ModelType, model_type)
 
             # MCore export supports some default conversion dictionaries
-            mcore_model_conversion_dict = DEFAULT_CONVERSION_DICT[input_model_type]
+            mcore_model_conversion_dict = DEFAULT_CONVERSION_DICT
             # All Mcore conversion dicts start with "decoder.layers.4.blah.blah" , while nemo models start with "model.decoder.layers.4.blahblah". so we append model. to the keys
             nemo_model_conversion_dict = {
                 f'model.{key}': value for key, value in mcore_model_conversion_dict.items()
             }
@@ -518,6 +493,34 @@ def get_transformer_config(nemo_model_config):
         if load_model:
             self._load()
 
+    def get_transformer_config(self, nemo_model_config):
+        """Build a TransformerConfig from the given nemo model config."""
+        from megatron.core.transformer.transformer_config import TransformerConfig
+
+        normalization = nemo_model_config.get('normalization', 'layernorm')
+        transformer_config_normalization = 'LayerNorm'
+        layernorm_zero_centered_gamma = False
+        if normalization == 'layernorm1p':
+            layernorm_zero_centered_gamma = True
+        elif normalization == 'rmsnorm':
+            transformer_config_normalization = 'RMSNorm'
+
+        conf = TransformerConfig(
+            num_layers=nemo_model_config.get('num_layers'),
+            moe_router_topk=nemo_model_config.get('moe_router_topk', 0),
+            num_attention_heads=nemo_model_config.get('num_attention_heads'),
+            num_query_groups=nemo_model_config.get('num_query_groups', nemo_model_config['num_attention_heads']),
+            kv_channels=nemo_model_config.get("kv_channels", None),
+            hidden_size=nemo_model_config.get('hidden_size'),
+            ffn_hidden_size=nemo_model_config.get('ffn_hidden_size'),
+            layernorm_epsilon=nemo_model_config.get('layernorm_epsilon'),
+            add_bias_linear=nemo_model_config.get('bias'),
+            num_moe_experts=nemo_model_config.get('num_moe_experts', None),
+            normalization=transformer_config_normalization,
+            layernorm_zero_centered_gamma=layernorm_zero_centered_gamma,
+        )
+        return conf
+
     def convert_to_safe_tensors(
         self,
         nemo_checkpoint_path: str,
@@ -530,6 +533,7 @@ def convert_to_safe_tensors(
         use_embedding_sharing: bool = False,
         dtype: str = "bfloat16",
     ):
+        """Convert checkpoint weights to safetensors format."""
         gpus_per_node = tensor_parallelism_size if gpus_per_node is None else gpus_per_node
 
         if Path(self.model_dir).exists():
@@ -595,6 +599,167 @@ def convert_to_safe_tensors(
         if tensorrt_llm.mpi_world_size() > 1:
             tensorrt_llm.mpi_barrier()
 
+    def gather_and_reshard_model(self, model_config, model, storage_dtype):
+        """
+        Accumulate all VP model chunks together and reshard the model (i.e., gather
+        layers across PP ranks if required), returning the final model state dict.
+        """
+
+        def _get_layer_index(split_key):
+            for index, key in enumerate(split_key):
+                if key == "layers":
+                    return index + 1
+            raise ValueError(f"Unknown layer name format: {split_key}")
+
+        def rename_layer_num(param_name, layer_num):
+            split_key = param_name.split(".")
+            layer_index = int(_get_layer_index(split_key))
+            split_key[layer_index] = str(layer_num)
+            return ".".join(split_key)
+
+        def get_layer_num(param_name):
+            split_key = param_name.split(".")
+            layer_index = int(_get_layer_index(split_key))
+            return int(split_key[layer_index])
+
+        from megatron.core import parallel_state
+
+        tp_size = parallel_state.get_tensor_model_parallel_world_size()
+        pp_rank = parallel_state.get_pipeline_model_parallel_rank()
+        pp_first_rank = parallel_state.get_pipeline_model_parallel_first_rank()
+        pp_last_rank = parallel_state.get_pipeline_model_parallel_last_rank()
+        pp_size = parallel_state.get_pipeline_model_parallel_world_size()
+        pp_group = parallel_state.get_pipeline_model_parallel_group()
+        vp_size = parallel_state.get_virtual_pipeline_model_parallel_world_size()
+        if not vp_size:
+            vp_size = 1
+
+        inference_tp_size = self.tp_size
+        inference_pp_size = self.pp_size
+        reshard_model = False
+        if inference_tp_size != tp_size or inference_pp_size != pp_size:
+            LOGGER.info("Training/Generation model parallelism resharding enabled")
+            if inference_pp_size == 1 and pp_size > 1 and inference_tp_size == tp_size:
+                reshard_model = True
+            else:
+                raise NotImplementedError(
+                    "NeMo currently only supports PP>1 -> PP=1 resharding; other types of resharding will come in future releases."
+                )
+
+        num_layers = model_config["num_layers"]
+        layers_per_pp = num_layers // pp_size
+        layers_per_chunk = layers_per_pp // vp_size
+
+        tl_params = {}
+        model_level_params = {}
+        if vp_size > 1:  # consolidate params across model chunks
+            for idx, model_chunk in enumerate(model):
+                for key, val in model_chunk.state_dict().items():
+                    if torch.is_tensor(val):
+                        if 'layers' in key:
+                            key2 = rename_layer_num(key, get_layer_num(key) + idx * pp_size * layers_per_chunk)
+                            tl_params[key2] = val
+                        else:
+                            model_level_params[key] = val
+        else:
+            for key, val in model.state_dict().items():
+                if torch.is_tensor(val):
+                    if 'decoder.layers' in key:
+                        tl_params[key] = val
+                    else:
+                        model_level_params[key] = val
+
+        if vp_size > 1 or reshard_model:
+            # gather layers across pp ranks
+            gathered_params = {}
+            for key, val in tl_params.items():
+                weight_list = [torch.zeros_like(val) for _ in range(pp_size)]
+                torch.distributed.all_gather(weight_list, val, group=pp_group)
+                for idx in range(pp_size):
+                    layer_num = get_layer_num(key) + idx * layers_per_chunk
+                    key2 = rename_layer_num(key, layer_num)
+                    if not reshard_model:  # Save only layers of 1 single PP stage
+                        layers_start = layers_per_pp * pp_rank
+                        layers_end = layers_per_pp * (pp_rank + 1) - 1
+                        if layer_num >= layers_start and layer_num <= layers_end:
+                            key2 = rename_layer_num(key, layer_num % layers_per_pp)
+                            gathered_params[key2] = weight_list[idx]
+                    else:
+                        gathered_params[key2] = weight_list[idx]
+            tl_params = gathered_params
+
+        model_state_dict = model_level_params
+        model_state_dict.update(tl_params)
+
+        def get_tensor_if_available(key, pp_src_idx, group):
+            tensor = model_state_dict.get(key)
+            if tensor is not None:
+                tensor_shape = [tensor.shape]
+            else:
+                tensor_shape = [None]
+
+            torch.distributed.broadcast_object_list(tensor_shape, pp_src_idx, group=group)
+
+            if tensor_shape[0] is None:
+                return None
+            if torch.distributed.get_rank() != pp_src_idx:
+                tensor = torch.empty(tensor_shape[0], dtype=storage_dtype).cuda()
+
+            torch.distributed.broadcast(tensor.contiguous(), pp_src_idx, group=pp_group)
+            return tensor
+
+        if reshard_model:
+            key = 'decoder.final_layernorm.weight'
+            tensor = get_tensor_if_available(key, pp_last_rank, pp_group)
+            if tensor is not None:
+                model_state_dict[key] = tensor
+
+            key = 'decoder.final_layernorm.bias'
+            tensor = get_tensor_if_available(key, pp_last_rank, pp_group)
+            if tensor is not None:
+                model_state_dict[key] = tensor
+
+            key = 'embedding.word_embeddings.weight'
+            tensor = get_tensor_if_available(key, pp_first_rank, pp_group)
+            if tensor is not None:
+                model_state_dict[key] = tensor
+
+            key = 'output_layer.weight'
+            tensor = get_tensor_if_available(key, pp_last_rank, pp_group)
+            if tensor is not None:
+                model_state_dict[key] = tensor
+
+        return model_state_dict
+
+    def get_input_dtype(self, storage_dtype):
+        """
+        Return the MCore export DataType for the given torch dtype.
+        """
+        from megatron.core.export.data_type import DataType
+
+        if storage_dtype == torch.bfloat16:
+            return DataType.bfloat16
+        elif storage_dtype == torch.float32:
+            return DataType.float32
+        elif storage_dtype == torch.float16:
+            return DataType.float16
+
"decoder.layers.4.blah.blah" , while nemo models sometimes start with "model.decoder.layers.4.blahblah". so we append model prefix. to the keys + """ + from megatron.core.export.trtllm.model_to_trllm_mapping.default_conversion_dict import DEFAULT_CONVERSION_DICT + + model_prefix, _ = get_layer_prefix(layer_names=model_state_dict.keys(), is_mcore=True) + + nemo_model_conversion_dict = {} + for key, value in DEFAULT_CONVERSION_DICT.items(): + if 'layers' in key: + nemo_model_conversion_dict[f'{model_prefix}.{key}'] = value + else: + nemo_model_conversion_dict[key] = value + return DEFAULT_CONVERSION_DICT + def build( self, model, @@ -607,6 +772,7 @@ def build( max_batch_size: int = 4, use_refit: bool = True, reshard_model: bool = False, + use_mcore_path: bool = True, ): """ Convert a model parallel nemo model to TensorRT-LLM. @@ -621,31 +787,103 @@ def build( if self.dp_size > 1: self.model_dir = os.path.join(self.model_dir, f"dp_rank{self.dp_rank}") - weights, model_config = model_to_trtllm_ckpt( - model=model, - nemo_model_config=model_config, - nemo_export_dir=self.model_dir, - decoder_type=model_type, - tensor_parallel_size=self.tp_size, - pipeline_parallel_size=self.pp_size, - gpus_per_node=gpus_per_node, - use_parallel_embedding=True, - use_distributed_convert=True, - model_parallel_rank=self.mp_rank, - vocab_size=self.tokenizer.vocab_size, - ) + if use_mcore_path: + from megatron.core.export.model_type import ModelType + from megatron.core.export.trtllm.trtllm_helper import TRTLLMHelper + from tensorrt_llm.layers import MoeConfig + + storage_dtype = torch_dtype_from_precision(model_config.precision) + model_state_dict = self.gather_and_reshard_model(model_config, model, storage_dtype) + # We build the transformer config using the nemo model config. + transformer_config = self.get_transformer_config(model_config) + input_model_type = getattr(ModelType, model_type) + + nemo_model_conversion_dict = self.get_nemo_to_trtllm_conversion_dict(model_state_dict) + + trtllm_helper = TRTLLMHelper( + transformer_config=transformer_config, + model_type=input_model_type, + trtllm_conversion_dict=nemo_model_conversion_dict, + position_embedding_type=model_config.get('position_embedding_type'), + max_position_embeddings=model_config.get('max_position_embeddings'), + rotary_percentage=model_config.get('rotary_percentage', 1.0), + rotary_base=model_config.get('rotary_base', 10000), + moe_tp_mode=model_config.get('moe_tp_mode', 2), + multi_query_mode=model_config.get("multi_query_mode", False), + activation=model_config.get('activation', "gelu"), + seq_len_interpolation_factor=model_config.get("seq_len_interpolation_factor"), + moe_renorm_mode=model_config.get( + 'moe_renorm_mode', MoeConfig.ExpertScaleNormalizationMode.RENORMALIZE + ), + share_embeddings_and_output_weights=model_config.get("share_embeddings_and_output_weights", False), + ) + + input_dtype = self.get_input_dtype(storage_dtype) + + trtllm_model_weights_list, trtllm_model_config_list = ( + trtllm_helper.get_trtllm_pretrained_config_and_model_weights( + model_state_dict=model_state_dict, + dtype=input_dtype, + state_dict_split_by_layer_numbers=True, + on_device_distributed_conversion=True, + vocab_size=self.tokenizer.vocab_size, + gpus_per_node=gpus_per_node, + ) + ) + trtllm_model_config = trtllm_model_config_list[0] + trtllm_model_weights = trtllm_model_weights_list[0] + + if reshard_model: + assert self.pp_size == 1, 'Reshard is true, but pp size is not one' + # MCORE Export will use parallel_state to determine pp . 
+                # Since we reshard to pp = 1, we need to modify the config and mapping
+                world_size = self.tp_size * self.pp_size
+                trtllm_model_config.pp_size = self.pp_size
+                trtllm_model_config.world_size = world_size
+                trtllm_model_config.mapping = tensorrt_llm.Mapping(
+                    world_size=world_size,
+                    rank=self.mp_rank,
+                    tp_size=self.tp_size,
+                    pp_size=self.pp_size,
+                )
+
+            engine = trtllm_helper.build_and_save_engine(
+                max_input_len=max_input_len,
+                max_output_len=max_output_len,
+                max_seq_len=max_input_len + max_output_len,
+                max_batch_size=max_batch_size,
+                trtllm_model_config=trtllm_model_config,
+                trtllm_model_weights=trtllm_model_weights,
+                engine_dir=self.model_dir,
+                use_refit=use_refit,
+            )
+        else:
+            weights, model_config = model_to_trtllm_ckpt(
+                model=model,
+                nemo_model_config=model_config,
+                nemo_export_dir=self.model_dir,
+                decoder_type=model_type,
+                tensor_parallel_size=self.tp_size,
+                pipeline_parallel_size=self.pp_size,
+                gpus_per_node=gpus_per_node,
+                use_parallel_embedding=True,
+                use_distributed_convert=True,
+                model_parallel_rank=self.mp_rank,
+                vocab_size=self.tokenizer.vocab_size,
+            )
+
+            engine = build_and_save_engine(
+                max_input_len=max_input_len,
+                max_output_len=max_output_len,
+                max_seq_len=max_input_len + max_output_len,
+                max_batch_size=max_batch_size,
+                model_config=model_config[0],
+                model_weights=weights[0],
+                model_dir=self.model_dir,
+                model_type=model_type,
+                use_refit=use_refit,
+            )
-        engine = build_and_save_engine(
-            max_input_len=max_input_len,
-            max_output_len=max_output_len,
-            max_seq_len=max_input_len + max_output_len,
-            max_batch_size=max_batch_size,
-            model_config=model_config[0],
-            model_weights=weights[0],
-            model_dir=self.model_dir,
-            model_type=model_type,
-            use_refit=use_refit,
-        )
         torch.distributed.barrier()
 
         cfg_path = Path(os.path.join(self.model_dir, f'config_{torch.distributed.get_rank()}.json'))
@@ -654,18 +892,44 @@ def build(
         load_distributed(self.model_dir, self.mp_rank, gpus_per_node)
 
-    def refit(self, model, model_config):
+    def refit(self, model, model_config, use_mcore_path=True):
         """
         Refits an TensorRT engine using an instantiated nemo model.
         This function should only be used after calling build()
         """
-        weights_dict = dist_model_to_trt_llm_ckpt(
-            model=model,
-            nemo_model_config=model_config,
-            inference_tp_size=self.tp_size,
-            inference_pp_size=self.pp_size,
-            tokenizer_vocab_size=self.tokenizer.vocab_size,
-        )
+        weights_dict = None
+        if use_mcore_path:
+            from megatron.core.export.trtllm.trtllm_weights_converter.distributed_trtllm_model_weights_converter import (
+                DistributedTRTLLMModelWeightsConverter,
+            )
+
+            transformer_config = self.get_transformer_config(model_config)
+            storage_dtype = torch_dtype_from_precision(model_config.precision)
+            dtype = self.get_input_dtype(storage_dtype)
+
+            dist_trtllm_model_weights_converter = DistributedTRTLLMModelWeightsConverter(
+                transformer_config=transformer_config,
+                dtype=dtype,
+                multi_query_mode=model_config.get("multi_query_mode", False),
+                activation=model_config.get('activation', "gelu"),
+            )
+
+            model_state_dict = self.gather_and_reshard_model(model_config, model, storage_dtype)
+            nemo_model_conversion_dict = self.get_nemo_to_trtllm_conversion_dict(model_state_dict)
+            dist_trtllm_model_weights_converter.convert(
+                model_state_dict=model_state_dict,
+                tokenizer_vocab_size=self.tokenizer.vocab_size,
+                trtllm_conversion_dict=nemo_model_conversion_dict,
+            )
+            weights_dict = dist_trtllm_model_weights_converter.trtllm_model_weights
+        else:
+            weights_dict = dist_model_to_trt_llm_ckpt(
+                model=model,
+                nemo_model_config=model_config,
+                inference_tp_size=self.tp_size,
+                inference_pp_size=self.pp_size,
+                tokenizer_vocab_size=self.tokenizer.vocab_size,
+            )
 
         load_distributed(self.model_dir, self.mp_rank, self.gpus_per_node)
         gc.collect()
         torch.cuda.empty_cache()
@@ -806,6 +1070,7 @@ def forward(
             )
 
     def add_prompt_table(self, task_name: str, prompt_embeddings_checkpoint_path: str):
+        """Add prompt table"""
         if self.model is None:
             raise Exception(
                 "A nemo checkpoint should be exported to TensorRT-LLM and "
@@ -827,6 +1092,7 @@ def add_prompt_table(self, task_name: str, prompt_embeddings_checkpoint_path: st
         self._prep_ptuning_table()
 
     def remove_prompt_table(self, task_name: str):
+        """Remove prompt table"""
         if self.ptuning_tables is not None:
             for i in range(len(self.ptuning_tables)):
                 if self.ptuning_tables[i]["task_name"] == task_name:
@@ -838,11 +1104,13 @@ def remove_prompt_table(self, task_name: str):
 
     @property
     def get_supported_models_list(self):
+        """Supported model list"""
         # gpt and gptnext are the same. Keeping the gptnext due to backward compatibility.
return ["gpt", "gptnext", "llama", "falcon", "starcoder", "mixtral", "gemma"] @property def get_hidden_size(self): + """Get hidden size""" if self.config is None: return None else: @@ -850,6 +1118,7 @@ def get_hidden_size(self): @property def get_triton_input(self): + """Get triton input""" inputs = ( Tensor(name="prompts", shape=(-1,), dtype=bytes), Tensor(name="max_output_len", shape=(-1,), dtype=np.int_, optional=True), @@ -867,11 +1136,13 @@ def get_triton_input(self): @property def get_triton_output(self): + """Get Triton Output""" outputs = (Tensor(name="outputs", shape=(-1,), dtype=bytes),) return outputs @batch def triton_infer_fn(self, **inputs: np.ndarray): + """Triton infer function for streaming""" try: infer_input = {"input_texts": str_ndarray2list(inputs.pop("prompts"))} if "max_output_len" in inputs: @@ -909,6 +1180,7 @@ def triton_infer_fn(self, **inputs: np.ndarray): @batch def triton_infer_fn_streaming(self, **inputs: np.ndarray): + """Triton infer function for streaming""" try: infer_input = {"input_texts": str_ndarray2list(inputs.pop("prompts"))} if "max_output_len" in inputs: @@ -1118,4 +1390,5 @@ def _load(self): ) from error def unload_engine(self): + """Unload engine""" unload_engine()