
fix llama meta data error and LLaMa lm_head wrongly loading error#3914

Closed
baodii wants to merge 7 commits into deepspeedai:master from baodii:baodi/fix_llama

Conversation

baodii (Contributor) commented Jul 10, 2023

  • fix 'NotImplementedError: Cannot copy out of meta tensor; no data!', when loading LLaMa from device meta

baodii (Contributor, Author) commented Jul 10, 2023

mpjlu (Contributor) commented Jul 12, 2023

The master llama container doesn't support meta tensors, so this PR still cannot work.

delete lm_head weight load part
baodii (Contributor, Author) commented Jul 21, 2023

> The master llama container doesn't support meta tensors, so this PR still cannot work.

This PR is for use with autoTP, not kernel injection. It works when the model passed to the init_inference API is on the meta device.
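The error named in the PR title can be reproduced with plain PyTorch, independent of DeepSpeed. A minimal sketch: a module created on the meta device has no backing data, so copying its parameters to a real device fails until they are materialized (`to_empty` is the standard PyTorch way to allocate real but uninitialized storage, which a deferred loader then fills from a checkpoint):

```python
import torch

# A module on the meta device carries shapes/dtypes but no data.
layer = torch.nn.Linear(4, 4, device="meta")

try:
    layer.weight.to("cpu")  # attempts to copy out of a meta tensor
except NotImplementedError as e:
    # This is the error quoted in this PR:
    # "Cannot copy out of meta tensor; no data!"
    print(type(e).__name__)

# Materialize uninitialized storage, then fill it from a checkpoint.
layer = layer.to_empty(device="cpu")
print(layer.weight.device.type)  # cpu
```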

delete debug code
xs1997zju commented

Appreciate this PR!

puneeshkhanna (Contributor) commented

When will this be merged?

cmikeh2 (Contributor) left a comment


Thank you!

jeffra enabled auto-merge July 27, 2023 21:15
jeffra added this pull request to the merge queue Jul 27, 2023
github-merge-queue bot removed this pull request from the merge queue due to a conflict with the base branch Jul 27, 2023
baodii (Contributor, Author) commented Jul 28, 2023

Resolved the conflict.

zeyugao (Contributor) commented Aug 2, 2023

@baodii Sorry to bother you. Could you please append a new commit or create a new PR for meta tensor loading when using kernel injection? See the following lines in that PR: https://github.com/microsoft/DeepSpeed/pull/3608/files#diff-ad3c4426f1e24b0f6abe2a5b01757eb9d621f67917d46aec05f5e8bc8d757553L88-L89 and https://github.com/microsoft/DeepSpeed/pull/3608/files#diff-ad3c4426f1e24b0f6abe2a5b01757eb9d621f67917d46aec05f5e8bc8d757553R23

In a nutshell, the problem occurs in these two lines: https://github.com/microsoft/DeepSpeed/blob/94c7233a8bb51e068ff8dd5d3e03f2e9b5ab248e/deepspeed/module_inject/containers/llama.py#L105-L106

It tries to:

  • load prefix + param_names[7] into transformer_param_names[8] (attn_nw)
  • load prefix + param_names[8] into transformer_param_names[10] (norm_w).

Originally, param_names[7] is input_layernorm.weight and param_names[8] is post_attention_layernorm.weight. However, it seems they are in the wrong order.

That is:

  • norm_w should be loaded from input_layernorm.weight
  • attn_nw should be loaded from post_attention_layernorm.weight
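The swap described above can be sketched as follows. Only indices 7 and 8 are taken from this thread; the remaining entries of `param_names` are placeholders for illustration, not the actual DeepSpeed list:

```python
# Hypothetical reconstruction of the name list discussed above; only
# indices 7 and 8 come from the thread, the rest are placeholders.
param_names = [
    "self_attn.q_proj.weight",          # 0 (placeholder)
    "self_attn.k_proj.weight",          # 1 (placeholder)
    "self_attn.v_proj.weight",          # 2 (placeholder)
    "self_attn.o_proj.weight",          # 3 (placeholder)
    "mlp.gate_proj.weight",             # 4 (placeholder)
    "mlp.up_proj.weight",               # 5 (placeholder)
    "mlp.down_proj.weight",             # 6 (placeholder)
    "input_layernorm.weight",           # 7
    "post_attention_layernorm.weight",  # 8
]

# Mapping reported as buggy: the two layernorms end up swapped.
buggy = {"attn_nw": param_names[7], "norm_w": param_names[8]}

# Corrected mapping: swap indices 7 and 8.
fixed = {"norm_w": param_names[7], "attn_nw": param_names[8]}

print(fixed["norm_w"])   # input_layernorm.weight
print(fixed["attn_nw"])  # post_attention_layernorm.weight
```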

According to feedback and testing, kernel injection with meta tensor loading is only partially fixed. With the 7B llama, the output matches the Hugging Face output. With the 65B llama, the output differs: it is not garbage, but genuinely meaningful words that are completely different from the Hugging Face output.

Some extra words:

I fixed meta tensor loading (the PR by you) and kernel-injected inference months before the individual PRs, because I ran into these issues myself and debugged them. But no one reviewed my PR, even though the feedback from others was positive. So I turn to you, hoping that what I found can be fixed, reducing duplicated effort in debugging this codebase.

norm_w

As for norm_w: it is used in https://github.com/microsoft/DeepSpeed/blob/94c7233a8bb51e068ff8dd5d3e03f2e9b5ab248e/deepspeed/model_implementations/transformers/ds_transformer.py#L171-L181 and passed as the gamma parameter in https://github.com/microsoft/DeepSpeed/blob/94c7233a8bb51e068ff8dd5d3e03f2e9b5ab248e/deepspeed/ops/transformer/inference/ds_attention.py#L155 (llama is pre_layer_norm), then into one of the CUDA functions rms_qkv_gemm_* (llama uses rms_norm), whose implementation is

https://github.com/microsoft/DeepSpeed/blob/94c7233a8bb51e068ff8dd5d3e03f2e9b5ab248e/csrc/transformer/inference/csrc/pt_binding.cpp#L945

The gamma parameter is used in

https://github.com/microsoft/DeepSpeed/blob/94c7233a8bb51e068ff8dd5d3e03f2e9b5ab248e/csrc/transformer/inference/csrc/pt_binding.cpp#L966-L974

Then take the Python implementation of llama as a reference: https://github.com/huggingface/transformers/blob/904e7e0f3cee944bffc54e2a084dfcab47ef2036/src/transformers/models/llama/modeling_llama.py#L410

So norm_w corresponds to input_layernorm.

attn_nw

attn_nw is used in ds_mlp as the gamma parameter in
https://github.com/microsoft/DeepSpeed/blob/94c7233a8bb51e068ff8dd5d3e03f2e9b5ab248e/deepspeed/ops/transformer/inference/ds_mlp.py#L105-L112

and

https://github.com/microsoft/DeepSpeed/blob/94c7233a8bb51e068ff8dd5d3e03f2e9b5ab248e/deepspeed/ops/transformer/inference/op_binding/mlp_gemm.py#L88-L99

The implementation is at

https://github.com/microsoft/DeepSpeed/blob/94c7233a8bb51e068ff8dd5d3e03f2e9b5ab248e/csrc/transformer/inference/csrc/pt_binding.cpp#L1540

and the gamma parameter is used at

https://github.com/microsoft/DeepSpeed/blob/94c7233a8bb51e068ff8dd5d3e03f2e9b5ab248e/csrc/transformer/inference/csrc/pt_binding.cpp#L1576-L1584

Considering that this is the MLP module following the attention, and the rms_norm is applied right before the MLP operation, taking the Python llama as a reference, it is the post-attention layer norm.

That is, attn_nw should correspond to post_attention_layernorm.
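The ordering argument above can be sketched following the Hugging Face pre-norm layout. This is a simplified illustration with stand-in callables, not the actual module code (the real layer uses RMSNorm and torch modules):

```python
# Simplified sketch of the pre-norm ordering in HF's LlamaDecoderLayer.
def decoder_layer(x, input_layernorm, self_attn, post_attention_layernorm, mlp):
    # norm_w (input_layernorm) is applied before attention
    h = x + self_attn(input_layernorm(x))
    # attn_nw (post_attention_layernorm) is applied before the MLP
    return h + mlp(post_attention_layernorm(h))

# Trace the call order with labelled identity stand-ins:
calls = []
def tag(name):
    def f(v):
        calls.append(name)
        return v
    return f

decoder_layer(0, tag("input_layernorm"), tag("self_attn"),
              tag("post_attention_layernorm"), tag("mlp"))
print(calls)
# ['input_layernorm', 'self_attn', 'post_attention_layernorm', 'mlp']
```

The trace shows why the two norms must not be swapped: input_layernorm feeds the attention path (norm_w) and post_attention_layernorm feeds the MLP path (attn_nw).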

baodii closed this Aug 8, 2023


7 participants