LoRA for MoE Layer (NVIDIA#9396)
* initial moe lora impl

Signed-off-by: Chen Cui <chcui@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: cuichenx <cuichenx@users.noreply.github.com>

* fix dangling adapter

Signed-off-by: Chen Cui <chcui@nvidia.com>

* update to newest mcore code

Signed-off-by: Chen Cui <chcui@nvidia.com>

---------

Signed-off-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: cuichenx <cuichenx@users.noreply.github.com>
Co-authored-by: cuichenx <cuichenx@users.noreply.github.com>
2 people authored and JesusPaz committed Jun 18, 2024
1 parent a7f15f5 commit d86d7fa
Showing 3 changed files with 173 additions and 28 deletions.
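For orientation, below is a minimal, standalone PyTorch sketch of the pattern this commit wires into NeMo: one low-rank (LoRA) adapter per MoE expert, selected by an expert index as the sequential-MLP loop dispatches each expert's contiguous token slice. All names here (PerExpertLoRA, experts, offsets, the toy sizes) are illustrative assumptions, not NeMo APIs; the actual implementation is the LoraMoeAdapter and MCoreSequentialMLPMixin classes in the diff that follows.

import torch
import torch.nn as nn

class PerExpertLoRA(nn.Module):
    """Illustrative sketch: one low-rank adapter per expert, dispatched by expert_idx."""

    def __init__(self, num_experts: int, in_features: int, out_features: int, rank: int):
        super().__init__()
        self.adapters = nn.ModuleList(
            nn.Sequential(
                nn.Linear(in_features, rank, bias=False),   # LoRA "A" projection
                nn.Linear(rank, out_features, bias=False),  # LoRA "B" projection
            )
            for _ in range(num_experts)
        )
        for adapter in self.adapters:
            nn.init.zeros_(adapter[1].weight)  # start as a no-op, analogous to row_init_method='zero'

    def forward(self, x: torch.Tensor, expert_idx: int) -> torch.Tensor:
        return self.adapters[expert_idx](x)

# Toy dispatch loop mirroring the idea in MCoreSequentialMLPMixin.forward:
# tokens are already permuted so each expert's tokens form a contiguous slice.
hidden, ffn, rank, num_experts = 16, 64, 4, 2
experts = nn.ModuleList(nn.Linear(hidden, ffn) for _ in range(num_experts))
lora = PerExpertLoRA(num_experts, hidden, ffn, rank)

tokens = torch.randn(10, hidden)
tokens_per_expert = torch.tensor([6, 4])
offsets = torch.cat([torch.zeros(1, dtype=torch.long), tokens_per_expert.cumsum(dim=0)])

output = torch.empty(10, ffn)
for expert_idx, expert in enumerate(experts):
    start, end = offsets[expert_idx], offsets[expert_idx + 1]
    chunk = tokens[start:end]
    output[start:end] = expert(chunk) + lora(chunk, expert_idx)  # base expert plus its own LoRA delta

The point mirrored in the diff is that routing itself is untouched: each expert's base MLP output and its expert-specific LoRA delta are simply summed over that expert's token slice.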
@@ -14,19 +14,16 @@

import torch
import torch.nn.functional as F
from megatron.core.fusions.fused_bias_dropout import get_bias_dropout_add
from megatron.core.fusions.fused_bias_geglu import bias_geglu_impl
from megatron.core.fusions.fused_bias_gelu import bias_gelu_impl
from megatron.core.fusions.fused_bias_swiglu import bias_swiglu_impl
from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
from megatron.core.models.common.embeddings.rotary_pos_embedding import apply_rotary_pos_emb
from megatron.core.tensor_parallel import ColumnParallelLinear
from megatron.core.transformer.attention import SelfAttention
from megatron.core.transformer.custom_layers.transformer_engine import (
SplitAlongDim,
TEColumnParallelLinear,
TELayerNormColumnParallelLinear,
)
from megatron.core.transformer.custom_layers.transformer_engine import SplitAlongDim
from megatron.core.transformer.mlp import MLP
from megatron.core.transformer.moe.experts import SequentialMLP
from megatron.core.transformer.transformer_layer import TransformerLayer
from megatron.core.utils import make_viewless_tensor

@@ -37,6 +34,8 @@
LoraDenseAttentionAdapterConfig,
LoraHto4HAdapterConfig,
LoraKQVAdapterConfig,
LoraMoe4HtoHAdapterConfig,
LoraMoeHto4HAdapterConfig,
LoraUnfusedHto4HAdapterConfig,
LoraUnfusedKQVAdapterConfig,
MLPInfusedAdapterConfig,
@@ -281,13 +280,15 @@ def forward(
class MCoreMLPMixin(MLP, MCoreAdapterModuleMixin):
def mcore_register_adapters(self):
"""
Setup NeMo IA3 adapter to this MCore layer.
Setup NeMo IA3 and LoRA adapter to this MCore layer.
"""
self.set_accepted_adapter_types(
[
LoraUnfusedHto4HAdapterConfig._target_,
LoraHto4HAdapterConfig._target_,
Lora4HtoHAdapterConfig._target_,
LoraMoeHto4HAdapterConfig._target_,
LoraMoe4HtoHAdapterConfig._target_,
MLPInfusedAdapterConfig._target_,
]
) # only self attn (packed qkv) for now
@@ -302,9 +303,12 @@ def mcore_register_adapters(self):
# overlap is used.
self.linear_fc1.return_layernorm_output_gathered = True

def forward(self, hidden_states):
def forward(self, hidden_states, expert_idx=None):
# [s, b, 4 * h/p]
if self.linear_fc1.te_return_bias:
if isinstance(self.linear_fc1, ColumnParallelLinear):
layernorm_output = hidden_states
intermediate_parallel, bias_parallel = self.linear_fc1(hidden_states)
elif self.linear_fc1.te_return_bias:
intermediate_parallel, bias_parallel, layernorm_output = self.linear_fc1(hidden_states)
else:
# bias_parallel is None
@@ -315,15 +319,19 @@
lora_adapter = None
lora_fc1_adapter = self.get_adapter_module(AdapterName.LORA_Hto4H_ADAPTER)
lora_unfused_fc1_adapter = self.get_adapter_module(AdapterName.LORA_UNFUSED_Hto4H_ADAPTER)
lora_moe_fc1_adapter = self.get_adapter_module(AdapterName.LORA_MOE_Hto4H_ADAPTER)
if lora_fc1_adapter and self.adapter_cfg[AdapterName.LORA_Hto4H_ADAPTER]['enabled']:
lora_adapter = lora_fc1_adapter
if lora_unfused_fc1_adapter and self.adapter_cfg[AdapterName.LORA_UNFUSED_Hto4H_ADAPTER]['enabled']:
assert lora_adapter is None, "Expected only one of LORA_Hto4H_ADAPTER or LORA_UNFUSED_Hto4H_ADAPTER"
lora_adapter = lora_unfused_fc1_adapter

lora_output = 0
if lora_adapter:
lora_output = lora_adapter(layernorm_output)
intermediate_parallel = intermediate_parallel + lora_output
elif lora_moe_fc1_adapter and self.adapter_cfg[AdapterName.LORA_MOE_Hto4H_ADAPTER]['enabled']:
lora_output = lora_moe_fc1_adapter(layernorm_output, expert_idx)
intermediate_parallel = intermediate_parallel + lora_output

if self.config.bias_activation_fusion:
if self.activation_func == F.gelu:
@@ -363,14 +371,51 @@ def glu(x):

# LoRA logic
if self.is_adapter_available():
lora_linear_fc2_adapter = self.get_adapter_module(AdapterName.LORA_4HtoH_ADAPTER)
if lora_linear_fc2_adapter and self.adapter_cfg[AdapterName.LORA_4HtoH_ADAPTER]['enabled']:
lora_output = lora_linear_fc2_adapter(intermediate_parallel)
output = output + lora_output
lora_fc2_adapter = self.get_adapter_module(AdapterName.LORA_4HtoH_ADAPTER)
lora_moe_fc2_adapter = self.get_adapter_module(AdapterName.LORA_MOE_4HtoH_ADAPTER)

lora_output = 0
if lora_fc2_adapter and self.adapter_cfg[AdapterName.LORA_4HtoH_ADAPTER]['enabled']:
lora_output = lora_fc2_adapter(intermediate_parallel)
elif lora_moe_fc2_adapter and self.adapter_cfg[AdapterName.LORA_MOE_4HtoH_ADAPTER]['enabled']:
lora_output = lora_moe_fc2_adapter(intermediate_parallel, expert_idx)

output = output + lora_output

return output, output_bias


class MCoreSequentialMLPMixin(SequentialMLP, MCoreAdapterModuleMixin):
def mcore_register_adapters(self):
"""
We don't want the SequentialMLP layer to take any adapters. We only want to override the forward() behavior
"""
pass

def forward(self, permuted_local_hidden_states, tokens_per_expert):
output_local = torch.zeros_like(permuted_local_hidden_states)
output_bias_local = None
if self.add_bias:
output_bias_local = torch.zeros_like(permuted_local_hidden_states)

cumsum_num_tokens = torch.cumsum(tokens_per_expert, dim=0)
# Insert zero at the beginning for offset index's convenience
zero_tensor = torch.zeros(1, dtype=torch.long, device=cumsum_num_tokens.device)
cumsum_num_tokens = torch.cat((zero_tensor, cumsum_num_tokens))
for expert_num, expert in enumerate(self.local_experts):
start = cumsum_num_tokens[expert_num]
end = cumsum_num_tokens[expert_num + 1]
hidden = permuted_local_hidden_states[start:end]
output, output_bias = expert(hidden, expert_num) # expert: MLP

output_local[start:end] = output
if self.add_bias:
output_bias = output_bias.expand_as(output)
output_bias_local[start:end, :] = output_bias

return output_local, output_bias_local


class MCoreGPTEmbeddingMixin(LanguageModelEmbedding, MCoreAdapterModuleMixin):
def mcore_register_adapters(self):
"""
@@ -83,6 +83,8 @@ class AdapterName(str, enum.Enum):
LORA_Hto4H_ADAPTER = "lora_hto4h_adapter"
LORA_UNFUSED_Hto4H_ADAPTER = "lora_unfused_hto4h_adapter"
LORA_4HtoH_ADAPTER = "lora_4htoh_adapter"
LORA_MOE_Hto4H_ADAPTER = "lora_moe_hto4h_adapter"
LORA_MOE_4HtoH_ADAPTER = "lora_moe_4htoh_adapter"
MULTIMODAL_PROJECTOR_ADAPTER = "mm_projector_adapter"
PARALLEL_LINEAR_ADAPTER = "parallel_linear_adapter"

@@ -611,6 +613,80 @@ class LoraUnfusedKQVAdapterConfig(AdapterConfig):
_target_: str = "{0}.{1}".format(LoraUnfusedKQVAdapter.__module__, LoraUnfusedKQVAdapter.__name__)


class LoraMoeAdapter(nn.Module, AdapterModuleUtil):
def __init__(
self,
num_moe_experts: int,
in_features: int,
out_features: int,
dim: int,
activation: str = 'identity',
norm_position: Optional[str] = None,
norm_type: Optional[str] = None,
column_init_method: str = 'xavier',
row_init_method: str = 'zero',
gather_output: bool = False,
input_is_parallel: bool = False,
dropout: float = 0.0,
model_parallel_config: Optional[ModelParallelConfig] = None,
alpha: float | None = None,
dropout_position: str = 'post',
a2a_experimental: bool = False,
**kwargs,
):
super().__init__()

self.num_moe_experts = num_moe_experts
adapter_args = {
"in_features": in_features,
"out_features": out_features,
"dim": dim,
"activation": activation,
"norm_position": norm_position,
"norm_type": norm_type,
"column_init_method": column_init_method,
"row_init_method": row_init_method,
"gather_output": gather_output,
"input_is_parallel": input_is_parallel,
"dropout": dropout,
"model_parallel_config": model_parallel_config,
"alpha": alpha,
"dropout_position": dropout_position,
"a2a_experimental": a2a_experimental,
}
self.expert_adapters = nn.ModuleList()
for i in range(num_moe_experts):
self.expert_adapters.append(ParallelLinearAdapter(**adapter_args))

def forward(self, x, expert_idx):
return self.expert_adapters[expert_idx](x)


@dataclass
class LoraMoeHto4HAdapterConfig(AdapterConfig):
num_moe_experts: int
in_features: int
out_features: int
dim: int
activation: str = 'identity'
norm_position: Optional[str] = None
norm_type: Optional[str] = None
column_init_method: str = 'xavier'
row_init_method: str = 'zero'
gather_output: bool = False
input_is_parallel: bool = False
dropout: float = 0.0
dropout_position: str = 'post'
alpha: float | None = None
a2a_experimental: bool = False
_target_: str = "{0}.{1}".format(LoraMoeAdapter.__module__, LoraMoeAdapter.__name__)


@dataclass
class LoraMoe4HtoHAdapterConfig(LoraMoeHto4HAdapterConfig):
input_is_parallel: bool = True


class PromptEncoderAdapter(nn.Module, AdapterModuleUtil):
"""
The Tensor Parallel MLP prompt encoder network that is used to generate the virtual
@@ -690,20 +766,14 @@ def set_inference_table(self, prompt_representation: torch.Tensor):
self.is_inference_ready = True
return True

def clear_inference_table(
self,
):
def clear_inference_table(self):
self.inference_table.fill_(0.0)
self.is_inference_ready = False

def get_inference_table(
self,
):
def get_inference_table(self):
return self.inference_table.data

def inner_forward(
self,
):
def inner_forward(self):
input_embeds = self.embedding(self.indices).unsqueeze(0)
intermediate_parallel, bias_parallel = self.first(input_embeds)
intermediate_parallel = fused_bias_gelu(intermediate_parallel, bias_parallel)
40 changes: 35 additions & 5 deletions nemo/collections/nlp/parts/peft_config.py
@@ -23,6 +23,7 @@
MCoreGPTEmbeddingMixin,
MCoreMLPMixin,
MCoreSelfAttentionMixin,
MCoreSequentialMLPMixin,
MCoreTransformerLayerMixin,
)
except (ImportError, ModuleNotFoundError):
@@ -36,6 +37,8 @@
LoraHto4HAdapterConfig,
LoraKQVAdapterConfig,
LoraKQVAdapterWeightTyingConfig,
LoraMoe4HtoHAdapterConfig,
LoraMoeHto4HAdapterConfig,
LoraUnfusedHto4HAdapterConfig,
LoraUnfusedKQVAdapterConfig,
MLPInfusedAdapterConfig,
@@ -176,7 +179,10 @@ def __init__(self, cfg):

elif module == PEFT_MODULE_MAP["hto4h_module"]:
hto4h_projection_size = cfg.ffn_hidden_size * 2 if fast_glu_activation else cfg.ffn_hidden_size
if lora_cfg.get("variant", "nemo") == "canonical":
if cfg.get('num_moe_experts', None):
_adapter_name = AdapterName.LORA_MOE_Hto4H_ADAPTER
_adapter_cfg_cls = LoraMoeHto4HAdapterConfig
elif lora_cfg.get("variant", "nemo") == "canonical":
_adapter_name = AdapterName.LORA_UNFUSED_Hto4H_ADAPTER
_adapter_cfg_cls = LoraUnfusedHto4HAdapterConfig
else:
@@ -187,13 +193,35 @@
cfg, lora_cfg, cfg.hidden_size, hto4h_projection_size, _adapter_cfg_cls
)
name_key_to_cfg[_adapter_name] = adapter_cfg
name_key_to_mcore_mixins[_adapter_name] = [("mlp", MCoreMLPMixin)]
if _adapter_name == AdapterName.LORA_MOE_Hto4H_ADAPTER:
name_key_to_mcore_mixins[_adapter_name] = [("mlp.experts", MCoreSequentialMLPMixin)]
for i in range(int(cfg.num_moe_experts)):
name_key_to_mcore_mixins[_adapter_name].append(
(f"mlp.experts.local_experts.{i}", MCoreMLPMixin)
)
else:
name_key_to_mcore_mixins[_adapter_name] = [("mlp", MCoreMLPMixin)]

elif module == PEFT_MODULE_MAP["4htoh_module"]:
if cfg.get('num_moe_experts', None):
_adapter_name = AdapterName.LORA_MOE_4HtoH_ADAPTER
_adapter_cfg_cls = LoraMoe4HtoHAdapterConfig
else:
_adapter_name = AdapterName.LORA_4HtoH_ADAPTER
_adapter_cfg_cls = Lora4HtoHAdapterConfig

adapter_cfg = self._create_lora_config(
cfg, lora_cfg, cfg.ffn_hidden_size, cfg.hidden_size, Lora4HtoHAdapterConfig
cfg, lora_cfg, cfg.ffn_hidden_size, cfg.hidden_size, _adapter_cfg_cls
)
name_key_to_cfg[AdapterName.LORA_4HtoH_ADAPTER] = adapter_cfg
name_key_to_mcore_mixins[AdapterName.LORA_4HtoH_ADAPTER] = [("mlp", MCoreMLPMixin)]
name_key_to_cfg[_adapter_name] = adapter_cfg
if _adapter_name == AdapterName.LORA_MOE_4HtoH_ADAPTER:
name_key_to_mcore_mixins[_adapter_name] = [("mlp.experts", MCoreSequentialMLPMixin)]
for i in range(int(cfg.num_moe_experts)):
name_key_to_mcore_mixins[_adapter_name].append(
(f"mlp.experts.local_experts.{i}", MCoreMLPMixin)
)
else:
name_key_to_mcore_mixins[_adapter_name] = [("mlp", MCoreMLPMixin)]
else:
logging.error(
f"Unrecognized target_module string: {module}.\n"
@@ -228,6 +256,8 @@ def _create_lora_config(
assert kv_channels is not None, "kv_channels must be provided for canonical Lora"
config_args.update({"num_query_groups": num_query_groups, "kv_channels": kv_channels})
config_args.pop("out_features")
elif adapter_cfg_cls in (LoraMoeHto4HAdapterConfig, LoraMoe4HtoHAdapterConfig):
config_args.update({'num_moe_experts': cfg.num_moe_experts})

if lora_cfg.weight_tying:
position_embedding_strategy = lora_cfg.get("position_embedding_strategy", None)
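To make the peft_config.py branches above concrete, the sketch below shows roughly the module-name-to-mixin mapping the MoE path would build for a hypothetical model with num_moe_experts=2. It is a comment-only illustration inferred from the loops in the diff; the real dict holds the mixin classes imported at the top of the file, and the exact entries depend on the model configuration.

# Illustrative shape of name_key_to_mcore_mixins when num_moe_experts=2 (assumption, not output of the code):
#   "lora_moe_hto4h_adapter": [("mlp.experts", MCoreSequentialMLPMixin),
#                              ("mlp.experts.local_experts.0", MCoreMLPMixin),
#                              ("mlp.experts.local_experts.1", MCoreMLPMixin)]
#   "lora_moe_4htoh_adapter": [("mlp.experts", MCoreSequentialMLPMixin),
#                              ("mlp.experts.local_experts.0", MCoreMLPMixin),
#                              ("mlp.experts.local_experts.1", MCoreMLPMixin)]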
