qlora support more models (NVIDIA#9488)
Signed-off-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Alex Cui <alexcui1994@gmail.com>
cuichenx authored and BuyuanCui committed Jul 12, 2024
1 parent 9880716 commit 98cd6f1
Showing 2 changed files with 13 additions and 12 deletions.
@@ -19,7 +19,6 @@
 from megatron.core.fusions.fused_bias_swiglu import bias_swiglu_impl
 from megatron.core.models.common.embeddings.language_model_embedding import LanguageModelEmbedding
 from megatron.core.models.common.embeddings.rotary_pos_embedding import apply_rotary_pos_emb
-from megatron.core.tensor_parallel import ColumnParallelLinear
 from megatron.core.transformer.attention import SelfAttention
 from megatron.core.transformer.custom_layers.transformer_engine import SplitAlongDim
 from megatron.core.transformer.mlp import MLP
@@ -305,14 +304,16 @@ def mcore_register_adapters(self):
 
     def forward(self, hidden_states, expert_idx=None):
         # [s, b, 4 * h/p]
-        if isinstance(self.linear_fc1, ColumnParallelLinear):
-            layernorm_output = hidden_states
-            intermediate_parallel, bias_parallel = self.linear_fc1(hidden_states)
-        elif self.linear_fc1.te_return_bias:
-            intermediate_parallel, bias_parallel, layernorm_output = self.linear_fc1(hidden_states)
+        output = self.linear_fc1(hidden_states)
+        if isinstance(output, tuple) and len(output) == 2:
+            intermediate_parallel, bias_parallel = output
+            if isinstance(intermediate_parallel, tuple) and len(intermediate_parallel) == 2:
+                intermediate_parallel, layernorm_output = intermediate_parallel
+            else:
+                layernorm_output = hidden_states
         else:
-            # bias_parallel is None
-            (intermediate_parallel, layernorm_output), bias_parallel = self.linear_fc1(hidden_states)
+            # self.linear_fc1.te_return_bias == True
+            intermediate_parallel, bias_parallel, layernorm_output = output
 
         # LoRA logic
         if self.is_adapter_available():
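For orientation, the new forward() path above is a shape dispatch on whatever self.linear_fc1 returns: a plain ColumnParallelLinear returns (output, bias), a TE linear that also returns its layernorm output yields ((output, layernorm_output), bias), and a TE linear with te_return_bias set returns a flat 3-tuple. Below is a minimal, self-contained sketch of that dispatch; the helper name unpack_fc1_output is hypothetical and not part of the commit.

from typing import Any, Tuple


def unpack_fc1_output(output: Any, hidden_states: Any) -> Tuple[Any, Any, Any]:
    """Mirror of the branching added in forward(): normalize the possible
    return shapes of linear_fc1 into (intermediate, bias, layernorm_output)."""
    if isinstance(output, tuple) and len(output) == 2:
        # Either (output, bias) or ((output, layernorm_output), bias)
        intermediate_parallel, bias_parallel = output
        if isinstance(intermediate_parallel, tuple) and len(intermediate_parallel) == 2:
            # Fused-layernorm linear: the layernorm output rides along with the projection
            intermediate_parallel, layernorm_output = intermediate_parallel
        else:
            # Plain ColumnParallelLinear: no fused layernorm, fall back to the input
            layernorm_output = hidden_states
    else:
        # te_return_bias case: flat (output, bias, layernorm_output) 3-tuple
        intermediate_parallel, bias_parallel, layernorm_output = output
    return intermediate_parallel, bias_parallel, layernorm_output


# Example: a plain (output, bias) pair falls back to hidden_states for layernorm_output
out, bias, ln_out = unpack_fc1_output(("proj", None), "hidden")
assert ln_out == "hidden"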
@@ -228,12 +228,12 @@ def qlora_load_model(model: 'MCoreGPTModel', model_cfg: 'DictConfig', checkpoint
     def replace_linear(module: nn.Module, prefix=""):
         for name, child in module.named_children():
             if name in qlora_targets:
-                bf16_weight = checkpoint[f"{prefix}.{name}.weight"]
+                bf16_weight = checkpoint[f"{prefix}.{name}.weight"].to(torch.bfloat16)
                 logging.info(f'QLoRA: Quantizing linear layer: {prefix}.{name}')
-                if name in ['linear_proj', 'linear_fc2']:
+                layer_norm_weight = checkpoint.get(f"{prefix}.{name}.layer_norm_weight", None)
+                if layer_norm_weight is None:
                     setattr(module, name, NF4LinearWrapper(bf16_weight))
-                else: # name in ['linear_qkv', 'linear_fc1']
-                    layer_norm_weight = checkpoint[f"{prefix}.{name}.layer_norm_weight"]
+                else:
                     layer_norm_bias = checkpoint.get(f"{prefix}.{name}.layer_norm_bias", None)
                     normalization = module.config.normalization
                     zero_centered_gamma = module.config.layernorm_zero_centered_gamma
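The replace_linear change above swaps a hard-coded layer-name check for a checkpoint probe: the weight is cast to bfloat16 before NF4 quantization, and the presence of a layer_norm_weight entry decides between the plain NF4LinearWrapper and a layernorm-fused variant (its construction is not shown in this diff). Below is a rough, self-contained sketch of that decision, assuming an MCore-style state dict; plan_nf4_wrap and the toy checkpoint keys are illustrative, not part of the commit.

import torch


def plan_nf4_wrap(checkpoint: dict, prefix: str, name: str):
    """Illustrative dispatch: decide the wrapper kind from the checkpoint contents."""
    # Cast the stored weight to bf16 before quantization (added in this commit)
    bf16_weight = checkpoint[f"{prefix}.{name}.weight"].to(torch.bfloat16)

    # Probe the checkpoint instead of matching on the layer name
    layer_norm_weight = checkpoint.get(f"{prefix}.{name}.layer_norm_weight", None)
    if layer_norm_weight is None:
        # Corresponds to NF4LinearWrapper(bf16_weight) in the real code
        return "plain", bf16_weight, None, None

    # Fused layernorm present: also pick up the (optional) layernorm bias
    layer_norm_bias = checkpoint.get(f"{prefix}.{name}.layer_norm_bias", None)
    return "layernorm_fused", bf16_weight, layer_norm_weight, layer_norm_bias


# Toy state dict: only linear_qkv carries a fused layernorm
ckpt = {
    "decoder.layers.0.self_attention.linear_qkv.weight": torch.randn(12, 4),
    "decoder.layers.0.self_attention.linear_qkv.layer_norm_weight": torch.ones(4),
    "decoder.layers.0.self_attention.linear_proj.weight": torch.randn(4, 4),
}
print(plan_nf4_wrap(ckpt, "decoder.layers.0.self_attention", "linear_qkv")[0])   # layernorm_fused
print(plan_nf4_wrap(ckpt, "decoder.layers.0.self_attention", "linear_proj")[0])  # plain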
