From d86d7faa7e8815a20ef4158c6d3a2d082d33be71 Mon Sep 17 00:00:00 2001 From: Chen Cui Date: Tue, 11 Jun 2024 15:55:01 -0400 Subject: [PATCH] LoRA for MoE Layer (#9396) * initial moe lora impl Signed-off-by: Chen Cui * Apply isort and black reformatting Signed-off-by: cuichenx * fix dangling adapter Signed-off-by: Chen Cui * update to newest mcore code Signed-off-by: Chen Cui --------- Signed-off-by: Chen Cui Signed-off-by: cuichenx Co-authored-by: cuichenx --- .../common/megatron/adapters/mcore_mixins.py | 73 ++++++++++++--- .../megatron/adapters/parallel_adapters.py | 88 +++++++++++++++++-- nemo/collections/nlp/parts/peft_config.py | 40 +++++++-- 3 files changed, 173 insertions(+), 28 deletions(-) diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py index a85c155cc0a8..bcfe07f702a0 100644 --- a/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py +++ b/nemo/collections/nlp/modules/common/megatron/adapters/mcore_mixins.py @@ -14,19 +14,16 @@ import torch import torch.nn.functional as F -from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add from megatron.core.fusions.fused_bias_geglu import bias_geglu_impl from megatron.core.fusions.fused_bias_gelu import bias_gelu_impl from megatron.core.fusions.fused_bias_swiglu import bias_swiglu_impl from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding from megatron.core.models.common.embeddings.rotary_pos_embedding import apply_rotary_pos_emb +from megatron.core.tensor_parallel import ColumnParallelLinear from megatron.core.transformer.attention import SelfAttention -from megatron.core.transformer.custom_layers.transformer_engine import ( - SplitAlongDim, - TEColumnParallelLinear, - TELayerNormColumnParallelLinear, -) +from megatron.core.transformer.custom_layers.transformer_engine import SplitAlongDim from megatron.core.transformer.mlp import MLP +from megatron.core.transformer.moe.experts import SequentialMLP from megatron.core.transformer.transformer_layer import TransformerLayer from megatron.core.utils import make_viewless_tensor @@ -37,6 +34,8 @@ LoraDenseAttentionAdapterConfig, LoraHto4HAdapterConfig, LoraKQVAdapterConfig, + LoraMoe4HtoHAdapterConfig, + LoraMoeHto4HAdapterConfig, LoraUnfusedHto4HAdapterConfig, LoraUnfusedKQVAdapterConfig, MLPInfusedAdapterConfig, @@ -281,13 +280,15 @@ def forward( class MCoreMLPMixin(MLP, MCoreAdapterModuleMixin): def mcore_register_adapters(self): """ - Setup NeMo IA3 adapter to this MCore layer. + Setup NeMo IA3 and LoRA adapter to this MCore layer. """ self.set_accepted_adapter_types( [ LoraUnfusedHto4HAdapterConfig._target_, LoraHto4HAdapterConfig._target_, Lora4HtoHAdapterConfig._target_, + LoraMoeHto4HAdapterConfig._target_, + LoraMoe4HtoHAdapterConfig._target_, MLPInfusedAdapterConfig._target_, ] ) # only self attn (packed qkv) for now @@ -302,9 +303,12 @@ def mcore_register_adapters(self): # overlap is used. self.linear_fc1.return_layernorm_output_gathered = True - def forward(self, hidden_states): + def forward(self, hidden_states, expert_idx=None): # [s, b, 4 * h/p] - if self.linear_fc1.te_return_bias: + if isinstance(self.linear_fc1, ColumnParallelLinear): + layernorm_output = hidden_states + intermediate_parallel, bias_parallel = self.linear_fc1(hidden_states) + elif self.linear_fc1.te_return_bias: intermediate_parallel, bias_parallel, layernorm_output = self.linear_fc1(hidden_states) else: # bias_parallel is None @@ -315,15 +319,19 @@ def forward(self, hidden_states): lora_adapter = None lora_fc1_adapter = self.get_adapter_module(AdapterName.LORA_Hto4H_ADAPTER) lora_unfused_fc1_adapter = self.get_adapter_module(AdapterName.LORA_UNFUSED_Hto4H_ADAPTER) + lora_moe_fc1_adapter = self.get_adapter_module(AdapterName.LORA_MOE_Hto4H_ADAPTER) if lora_fc1_adapter and self.adapter_cfg[AdapterName.LORA_Hto4H_ADAPTER]['enabled']: lora_adapter = lora_fc1_adapter if lora_unfused_fc1_adapter and self.adapter_cfg[AdapterName.LORA_UNFUSED_Hto4H_ADAPTER]['enabled']: assert lora_adapter is None, "Expected only one of LORA_Hto4H_ADAPTER or LORA_UNFUSED_Hto4H_ADAPTER" lora_adapter = lora_unfused_fc1_adapter + lora_output = 0 if lora_adapter: lora_output = lora_adapter(layernorm_output) - intermediate_parallel = intermediate_parallel + lora_output + elif lora_moe_fc1_adapter and self.adapter_cfg[AdapterName.LORA_MOE_Hto4H_ADAPTER]['enabled']: + lora_output = lora_moe_fc1_adapter(layernorm_output, expert_idx) + intermediate_parallel = intermediate_parallel + lora_output if self.config.bias_activation_fusion: if self.activation_func == F.gelu: @@ -363,14 +371,51 @@ def glu(x): # LoRA logic if self.is_adapter_available(): - lora_linear_fc2_adapter = self.get_adapter_module(AdapterName.LORA_4HtoH_ADAPTER) - if lora_linear_fc2_adapter and self.adapter_cfg[AdapterName.LORA_4HtoH_ADAPTER]['enabled']: - lora_output = lora_linear_fc2_adapter(intermediate_parallel) - output = output + lora_output + lora_fc2_adapter = self.get_adapter_module(AdapterName.LORA_4HtoH_ADAPTER) + lora_moe_fc2_adapter = self.get_adapter_module(AdapterName.LORA_MOE_4HtoH_ADAPTER) + + lora_output = 0 + if lora_fc2_adapter and self.adapter_cfg[AdapterName.LORA_4HtoH_ADAPTER]['enabled']: + lora_output = lora_fc2_adapter(intermediate_parallel) + elif lora_moe_fc2_adapter and self.adapter_cfg[AdapterName.LORA_MOE_4HtoH_ADAPTER]['enabled']: + lora_output = lora_moe_fc2_adapter(intermediate_parallel, expert_idx) + + output = output + lora_output return output, output_bias +class MCoreSequentialMLPMixin(SequentialMLP, MCoreAdapterModuleMixin): + def mcore_register_adapters(self): + """ + We don't want the SequentialMLP layer to take any adapters. We only want to override the forward() behavior + """ + pass + + def forward(self, permuted_local_hidden_states, tokens_per_expert): + output_local = torch.zeros_like(permuted_local_hidden_states) + output_bias_local = None + if self.add_bias: + output_bias_local = torch.zeros_like(permuted_local_hidden_states) + + cumsum_num_tokens = torch.cumsum(tokens_per_expert, dim=0) + # Insert zero at the begining for offset index's convenience + zero_tensor = torch.zeros(1, dtype=torch.long, device=cumsum_num_tokens.device) + cumsum_num_tokens = torch.cat((zero_tensor, cumsum_num_tokens)) + for expert_num, expert in enumerate(self.local_experts): + start = cumsum_num_tokens[expert_num] + end = cumsum_num_tokens[expert_num + 1] + hidden = permuted_local_hidden_states[start:end] + output, output_bias = expert(hidden, expert_num) # expert: MLP + + output_local[start:end] = output + if self.add_bias: + output_bias = output_bias.expand_as(output) + output_bias_local[start:end, :] = output_bias + + return output_local, output_bias_local + + class MCoreGPTEmbeddingMixin(LanguageModelEmbedding, MCoreAdapterModuleMixin): def mcore_register_adapters(self): """ diff --git a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py index 61903e6b3673..21dace008877 100644 --- a/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py +++ b/nemo/collections/nlp/modules/common/megatron/adapters/parallel_adapters.py @@ -83,6 +83,8 @@ class AdapterName(str, enum.Enum): LORA_Hto4H_ADAPTER = "lora_hto4h_adapter" LORA_UNFUSED_Hto4H_ADAPTER = "lora_unfused_hto4h_adapter" LORA_4HtoH_ADAPTER = "lora_4htoh_adapter" + LORA_MOE_Hto4H_ADAPTER = "lora_moe_hto4h_adapter" + LORA_MOE_4HtoH_ADAPTER = "lora_moe_4htoh_adapter" MULTIMODAL_PROJECTOR_ADAPTER = "mm_projector_adapter" PARALLEL_LINEAR_ADAPTER = "parallel_linear_adapter" @@ -611,6 +613,80 @@ class LoraUnfusedKQVAdapterConfig(AdapterConfig): _target_: str = "{0}.{1}".format(LoraUnfusedKQVAdapter.__module__, LoraUnfusedKQVAdapter.__name__) +class LoraMoeAdapter(nn.Module, AdapterModuleUtil): + def __init__( + self, + num_moe_experts: int, + in_features: int, + out_features: int, + dim: int, + activation: str = 'identity', + norm_position: Optional[str] = None, + norm_type: Optional[str] = None, + column_init_method: str = 'xavier', + row_init_method: str = 'zero', + gather_output: bool = False, + input_is_parallel: bool = False, + dropout: float = 0.0, + model_parallel_config: Optional[ModelParallelConfig] = None, + alpha: float | None = None, + dropout_position: str = 'post', + a2a_experimental: bool = False, + **kwargs, + ): + super().__init__() + + self.num_moe_experts = num_moe_experts + adapter_args = { + "in_features": in_features, + "out_features": out_features, + "dim": dim, + "activation": activation, + "norm_position": norm_position, + "norm_type": norm_type, + "column_init_method": column_init_method, + "row_init_method": row_init_method, + "gather_output": gather_output, + "input_is_parallel": input_is_parallel, + "dropout": dropout, + "model_parallel_config": model_parallel_config, + "alpha": alpha, + "dropout_position": dropout_position, + "a2a_experimental": a2a_experimental, + } + self.expert_adapters = nn.ModuleList() + for i in range(num_moe_experts): + self.expert_adapters.append(ParallelLinearAdapter(**adapter_args)) + + def forward(self, x, expert_idx): + return self.expert_adapters[expert_idx](x) + + +@dataclass +class LoraMoeHto4HAdapterConfig(AdapterConfig): + num_moe_experts: int + in_features: int + out_features: int + dim: int + activation: str = 'identity' + norm_position: Optional[str] = None + norm_type: Optional[str] = None + column_init_method: str = 'xavier' + row_init_method: str = 'zero' + gather_output: bool = False + input_is_parallel: bool = False + dropout: float = 0.0 + dropout_position: str = 'post' + alpha: float | None = None + a2a_experimental: bool = False + _target_: str = "{0}.{1}".format(LoraMoeAdapter.__module__, LoraMoeAdapter.__name__) + + +@dataclass +class LoraMoe4HtoHAdapterConfig(LoraMoeHto4HAdapterConfig): + input_is_parallel: bool = True + + class PromptEncoderAdapter(nn.Module, AdapterModuleUtil): """ The Tensor Parallel MLP prompt encoder network that is used to generate the virtual @@ -690,20 +766,14 @@ def set_inference_table(self, prompt_representation: torch.Tensor): self.is_inference_ready = True return True - def clear_inference_table( - self, - ): + def clear_inference_table(self): self.inference_table.fill_(0.0) self.is_inference_ready = False - def get_inference_table( - self, - ): + def get_inference_table(self): return self.inference_table.data - def inner_forward( - self, - ): + def inner_forward(self): input_embeds = self.embedding(self.indices).unsqueeze(0) intermediate_parallel, bias_parallel = self.first(input_embeds) intermediate_parallel = fused_bias_gelu(intermediate_parallel, bias_parallel) diff --git a/nemo/collections/nlp/parts/peft_config.py b/nemo/collections/nlp/parts/peft_config.py index 4d558ce00114..50c97e349885 100644 --- a/nemo/collections/nlp/parts/peft_config.py +++ b/nemo/collections/nlp/parts/peft_config.py @@ -23,6 +23,7 @@ MCoreGPTEmbeddingMixin, MCoreMLPMixin, MCoreSelfAttentionMixin, + MCoreSequentialMLPMixin, MCoreTransformerLayerMixin, ) except (ImportError, ModuleNotFoundError): @@ -36,6 +37,8 @@ LoraHto4HAdapterConfig, LoraKQVAdapterConfig, LoraKQVAdapterWeightTyingConfig, + LoraMoe4HtoHAdapterConfig, + LoraMoeHto4HAdapterConfig, LoraUnfusedHto4HAdapterConfig, LoraUnfusedKQVAdapterConfig, MLPInfusedAdapterConfig, @@ -176,7 +179,10 @@ def __init__(self, cfg): elif module == PEFT_MODULE_MAP["hto4h_module"]: hto4h_projection_size = cfg.ffn_hidden_size * 2 if fast_glu_activation else cfg.ffn_hidden_size - if lora_cfg.get("variant", "nemo") == "canonical": + if cfg.get('num_moe_experts', None): + _adapter_name = AdapterName.LORA_MOE_Hto4H_ADAPTER + _adapter_cfg_cls = LoraMoeHto4HAdapterConfig + elif lora_cfg.get("variant", "nemo") == "canonical": _adapter_name = AdapterName.LORA_UNFUSED_Hto4H_ADAPTER _adapter_cfg_cls = LoraUnfusedHto4HAdapterConfig else: @@ -187,13 +193,35 @@ def __init__(self, cfg): cfg, lora_cfg, cfg.hidden_size, hto4h_projection_size, _adapter_cfg_cls ) name_key_to_cfg[_adapter_name] = adapter_cfg - name_key_to_mcore_mixins[_adapter_name] = [("mlp", MCoreMLPMixin)] + if _adapter_name == AdapterName.LORA_MOE_Hto4H_ADAPTER: + name_key_to_mcore_mixins[_adapter_name] = [("mlp.experts", MCoreSequentialMLPMixin)] + for i in range(int(cfg.num_moe_experts)): + name_key_to_mcore_mixins[_adapter_name].append( + (f"mlp.experts.local_experts.{i}", MCoreMLPMixin) + ) + else: + name_key_to_mcore_mixins[_adapter_name] = [("mlp", MCoreMLPMixin)] + elif module == PEFT_MODULE_MAP["4htoh_module"]: + if cfg.get('num_moe_experts', None): + _adapter_name = AdapterName.LORA_MOE_4HtoH_ADAPTER + _adapter_cfg_cls = LoraMoe4HtoHAdapterConfig + else: + _adapter_name = AdapterName.LORA_4HtoH_ADAPTER + _adapter_cfg_cls = Lora4HtoHAdapterConfig + adapter_cfg = self._create_lora_config( - cfg, lora_cfg, cfg.ffn_hidden_size, cfg.hidden_size, Lora4HtoHAdapterConfig + cfg, lora_cfg, cfg.ffn_hidden_size, cfg.hidden_size, _adapter_cfg_cls ) - name_key_to_cfg[AdapterName.LORA_4HtoH_ADAPTER] = adapter_cfg - name_key_to_mcore_mixins[AdapterName.LORA_4HtoH_ADAPTER] = [("mlp", MCoreMLPMixin)] + name_key_to_cfg[_adapter_name] = adapter_cfg + if _adapter_name == AdapterName.LORA_MOE_4HtoH_ADAPTER: + name_key_to_mcore_mixins[_adapter_name] = [("mlp.experts", MCoreSequentialMLPMixin)] + for i in range(int(cfg.num_moe_experts)): + name_key_to_mcore_mixins[_adapter_name].append( + (f"mlp.experts.local_experts.{i}", MCoreMLPMixin) + ) + else: + name_key_to_mcore_mixins[_adapter_name] = [("mlp", MCoreMLPMixin)] else: logging.error( f"Unrecognized target_module string: {module}.\n" @@ -228,6 +256,8 @@ def _create_lora_config( assert kv_channels is not None, "kv_channels must be provided for canonical Lora" config_args.update({"num_query_groups": num_query_groups, "kv_channels": kv_channels}) config_args.pop("out_features") + elif adapter_cfg_cls in (LoraMoeHto4HAdapterConfig, LoraMoe4HtoHAdapterConfig): + config_args.update({'num_moe_experts': cfg.num_moe_experts}) if lora_cfg.weight_tying: position_embedding_strategy = lora_cfg.get("position_embedding_strategy", None)