
Fix llama meta tensor loading in AutoTP and kernel injected inference#3608

Merged
lekurile merged 16 commits into deepspeedai:master from zeyugao:master
Sep 20, 2023
Conversation

@zeyugao
Contributor

@zeyugao zeyugao commented May 25, 2023

Llama uses a non-conventional LayerNorm (RMSNorm), which is not handled when loading with meta tensors, so loading fails with NotImplementedError: Cannot copy out of meta tensor; no data!
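For context, a minimal reproduction of this failure mode, assuming a PyTorch build with meta-device support:

```python
import torch

# A meta tensor carries only shape and dtype metadata, no storage.
t = torch.empty(4, 4, device="meta")

try:
    t.to("cpu")  # materializing requires data that a meta tensor does not have
except NotImplementedError as err:
    # Message starts with "Cannot copy out of meta tensor; no data!"
    print(err)
```

Any module whose weights are skipped by the injection path stays on the meta device, and the first attempt to copy its parameters to a real device raises exactly this error.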

@zeyugao
Contributor Author

zeyugao commented May 30, 2023

@microsoft-github-policy-service agree

@zeyugao zeyugao changed the title from "Adapte to Llama when using meta tensor to load" to "Fix llama meta tensor loading, model tensor parallelism inference" May 30, 2023
@zeyugao
Contributor Author

zeyugao commented May 30, 2023

The second commit should fix #3452

@Yard1

Yard1 commented May 31, 2023

I can still reproduce the error in #3452 (comment) with this PR.

@zeyugao
Contributor Author

zeyugao commented Jun 1, 2023

I can still reproduce the error in #3452 (comment) with this PR.

I made some modifications to your code to run it on my machine. I am also using a V100. It works fine with mp_size=2 or 4. Can you try this code?

if True:
    import sys
    import os
    # New deepspeed path
    sys.path.insert(0, '/Code/DeepSpeed')
    import torch
    import deepspeed
    from transformers import LlamaForCausalLM, LlamaTokenizer
    import argparse

# here
deepspeed.init_distributed()

local_rank = int(os.environ.get('LOCAL_RANK', 0))
world_size = int(os.environ.get('WORLD_SIZE', 1))

print(f'local_rank: {local_rank}, world_size: {world_size}')

tokenizer = LlamaTokenizer.from_pretrained('./llama_7B_hf/')
model = LlamaForCausalLM.from_pretrained('./llama_7B_hf/')
model = deepspeed.init_inference(
    model,
    mp_size=world_size,
    dtype=torch.half,
    replace_with_kernel_inject=True
)

batch = tokenizer(
    "The primary use of LLaMA is research on large language models, including",
    return_tensors="pt",
    add_special_tokens=False
)
# here
batch = {k: v.cuda(local_rank) for k, v in batch.items()}
generated = model.generate(batch["input_ids"], max_length=100)
print(tokenizer.decode(generated[0]))

It yields

The primary use of LLaMA is research on large language models, including the BERT model.

\subsection{Learning Language Models}

LLaMA is a tool for training large language models. It is designed to be used with the BERT model, but it can also be used with other large language models.

(Edit: Some issues still exist when mp_size>2 or when padding='right')
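On the padding='right' issue: decoder-only models generally need left padding for batched generation, so that the prompt's last token sits directly before the generated tokens. A pure-Python sketch of the difference (`left_pad` is an illustrative helper, not a transformers API):

```python
def left_pad(seqs, pad_id):
    """Pad token-id sequences on the left, as batched decoder-only generation expects."""
    width = max(len(s) for s in seqs)
    input_ids = [[pad_id] * (width - len(s)) + s for s in seqs]
    attention_mask = [[0] * (width - len(s)) + [1] * len(s) for s in seqs]
    return input_ids, attention_mask

ids, mask = left_pad([[5, 6, 7], [9]], pad_id=0)
# ids  -> [[5, 6, 7], [0, 0, 9]]
# mask -> [[1, 1, 1], [0, 0, 1]]
```

With transformers, the equivalent is setting tokenizer.padding_side = "left" (and a pad token) before calling the tokenizer with padding=True.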

@Yard1

Yard1 commented Jun 1, 2023

@zeyugao Thank you, it works! I must have messed something up with the installation.

@lyy1994

lyy1994 commented Jun 8, 2023

Hi, I am trying to use your PR to run LLaMA-65B. How should I do this? Directly using LlamaForCausalLM.from_pretrained and launching with deepspeed --num_gpus 8 consumes a lot of RAM, yet meta tensors are not supported for LLaMA.

@zeyugao
Contributor Author

zeyugao commented Jun 8, 2023

@lyy1994 You can refer to https://github.com/microsoft/DeepSpeedExamples/blob/master/inference/huggingface/text-generation/inference-test.py#L24 and check how that example uses the variable use_meta_tensor to load with meta tensors. I doubt whether it is compatible with kernel injection, though.
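The meta-tensor path in that example builds a small JSON index of the checkpoint shards and passes it to init_inference as the checkpoint argument. A hedged sketch of that step (the helper name is hypothetical, and the "type"/"version" fields follow the convention used in the DeepSpeedExamples scripts):

```python
import json
import os

def write_checkpoints_json(checkpoint_dir, out_path="checkpoints.json"):
    # Index the weight shard files so DeepSpeed can stream them
    # into a model whose parameters still live on the meta device.
    shards = sorted(
        os.path.join(checkpoint_dir, f)
        for f in os.listdir(checkpoint_dir)
        if f.endswith(".bin") or f.endswith(".pt")
    )
    data = {"type": "ds_model", "checkpoints": shards, "version": 1.0}
    with open(out_path, "w") as fh:
        json.dump(data, fh)
    return out_path
```

The resulting path is what gets passed as checkpoint=... to deepspeed.init_inference when the model was constructed under deepspeed.OnDevice(device="meta").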

@lyy1994

lyy1994 commented Jun 8, 2023

@lyy1994 You can refer to https://github.com/microsoft/DeepSpeedExamples/blob/master/inference/huggingface/text-generation/inference-test.py#L24 and check how that example uses the variable use_meta_tensor to load with meta tensors. I doubt whether it is compatible with kernel injection, though.

Thanks for your suggestions! I have tried the code you pointed to, and it raises the following error:

AssertionError: Meta tensors are not supported for this model currently.

@lyy1994

lyy1994 commented Jun 8, 2023

@lyy1994 You can refer to https://github.com/microsoft/DeepSpeedExamples/blob/master/inference/huggingface/text-generation/inference-test.py#L24 and check how that example uses the variable use_meta_tensor to load with meta tensors. I doubt whether it is compatible with kernel injection, though.

Thanks for your suggestions! I have tried the code you pointed to, and it raises the following error:

AssertionError: Meta tensors are not supported for this model currently.

Sorry, I made a mistake in running this script before. The command I used is:

deepspeed --num_gpus 2 inference-test.py --checkpoint_path ./LLaMA-7B-Official --batch_size 2 --name ./LLaMA-7B-Official --use_meta_tensor

and it still gives the following error even using your PR:

NotImplementedError: Cannot copy out of meta tensor; no data!

@davidthomas426

(Edit: Some issues still exist when mp_size>2 or when padding='right')

I tried running this with mp_size 4, and it worked for me. As far as I can tell, this fixes #3628 as intended. Great work!

@RezaYazdaniAminabadi Could you look at merging this? The gated MLP intermediate weight sharding is broken in 0.9.3 / master, and this fixes it (it needs a strided copy like the QKV GEMM weight, since the two weight matrices are glued into one).
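To illustrate the strided-copy point with hypothetical shapes: a gated-MLP checkpoint stores gate_proj and up_proj glued along the output dimension, so a contiguous chunk per rank gives rank 0 only gate rows, while a strided copy shards each half separately and re-fuses them per rank:

```python
def chunk(rows, n):
    # Even contiguous split of a list of rows into n parts.
    k = len(rows) // n
    return [rows[i * k:(i + 1) * k] for i in range(n)]

world = 2
gate = [f"g{i}" for i in range(4)]  # gate_proj rows
up = [f"u{i}" for i in range(4)]    # up_proj rows
fused = gate + up                   # glued along the output dim, as in the checkpoint

naive = chunk(fused, world)         # rank 0 -> all gate rows, no up rows: wrong
strided = [chunk(gate, world)[r] + chunk(up, world)[r] for r in range(world)]
# rank 0 -> ['g0', 'g1', 'u0', 'u1']: half of each matrix, as tensor parallelism requires
```

The QKV weight has the same layout problem (three matrices glued into one), which is why the fix reuses the same strided-copy treatment.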

@zeyugao
Contributor Author

zeyugao commented Jun 10, 2023

I think this PR does not completely fix LLaMA inference, and the remaining issue is unrelated to tensor parallelism. During testing, I found that when the input context exceeds a certain length (e.g., 768, 1024, or 1536 tokens), the kernel-injected LLaMA produces incorrect results. In particular, the model correctly predicts the first output tokens (with do_sample=False), matching the model without kernel injection, but subsequent tokens are completely wrong. This problem persists even when tensor parallelism is disabled (tp_size=1, or without specifying the relevant parameters).

The exact cause remains uncertain; possibilities include KV cache errors or precision loss in the computation. Since shorter inputs generate accurate results, the issue is hard to pin down, and I lack debugging experience with CUDA kernels, so I have little idea how to debug it.

@zeyugao
Contributor Author

zeyugao commented Jun 12, 2023

@lyy1994 I don't observe this issue with my PR (without kernel injection enabled). Did you install the fixed DeepSpeed correctly? You can use PYTHONPATH to force Python to use the fixed code. For example, I use the following script:

PYTHONPATH=/Code/DeepSpeed deepspeed --include=localhost:1,2 inference-test.py --name /mnt/data/llama_7B_hf --checkpoint_path /mnt/data/llama_7B_hf --ds_inference --use_meta_tensor

@zeyugao
Contributor Author

zeyugao commented Jun 12, 2023

I think it is because of the max_out_tokens parameter in init_inference; enlarge it when needed. So I think this PR is complete.

Also, I added some commits that fix the meta tensor loading when kernel injection is enabled.
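The max_out_tokens reasoning can be captured in a small sizing guard (the helper and the 1024 default are assumptions; check your DeepSpeed version's init_inference signature): kernel injection preallocates KV-cache space for max_out_tokens total tokens, so the budget must cover the prompt plus everything generate() appends.

```python
def required_max_out_tokens(prompt_len, max_new_tokens, default=1024):
    # Hypothetical helper: the KV cache sized by max_out_tokens must hold
    # both the prompt tokens and all tokens generate() will append.
    return max(default, prompt_len + max_new_tokens)

print(required_max_out_tokens(prompt_len=1536, max_new_tokens=100))  # 1636
print(required_max_out_tokens(prompt_len=200, max_new_tokens=100))   # 1024
```

This matches the symptom reported above: prompts around or beyond the default budget produce garbage after the first tokens, while short prompts are fine.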

@thies1006

I tried it out using vicuna-13b and 8xT4. With one query (bs=1) it looks good, but when I try to send a batch it sometimes crashes.

@chhzh123

Any updates on this PR? I saw #3788 fixed the gated MLP, but it still does not work with meta tensors.

@ganyk

ganyk commented Jul 22, 2023

@zeyugao Hi, I am trying to use your PR to run inference for LLaMA-7B and 65B. It worked well for LLaMA-7B; however, when I used llama-65b, I got this error: "KeyError: 'model.layers.53.mlp.gate_proj.weight'".

I have checked the checkpoint file, and the key 'model.layers.53.mlp.gate_proj.weight' does exist.

Have you encountered similar issues?

Here is my code snippet:

import torch
import deepspeed
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

model_name = 'huggyllama/llama-65b'
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)
config = AutoConfig.from_pretrained(model_name)

with deepspeed.OnDevice(dtype=dtype, device="meta"):
    model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)

model = model.eval()
model = deepspeed.init_inference(
    model,
    mp_size=world_size,
    base_dir=repo_root,
    dtype=torch.float16,
    checkpoint=checkpoints_json,
    replace_with_kernel_inject=True,
)

@xs1997zju

Thanks for the reply and the experiments. But the "no meta tensor, without kernel injection" result seems weird; it should be the same as huggingface. Have you tried llama-7b? For llama-7b, these two results are the same.

@zeyugao
Contributor Author

zeyugao commented Aug 1, 2023

@xs1997zju In "no meta tensor without kernel injected", auto tensor parallelism is enabled, while the huggingface method is a bare pipeline.

And since the first few lines of the deepspeed result align with huggingface, the divergence may be due to precision differences between the implementations.

@xs1997zju

Anyway, for now I just use mp=4 without meta tensors to load the 65b model, but I have to load the model into CPU RAM on every rank and then call deepspeed.init_inference; the final result is the same as huggingface. My machine does not have enough RAM to load 8 copies of the model, one per rank.

@xs1997zju

So I think the problem is in the meta tensor loading path; the llama model adaptation has some issues.

@xs1997zju

With transformers version 4.28.1, using meta tensors and no kernel injection, the result is the same as huggingface.

@zeyugao
Contributor Author

zeyugao commented Aug 2, 2023

@xs1997zju Without kernel injection, model loading works somewhat differently, so that does not prove the meta tensor loading is entirely correct.

But I compared the models loaded with and without meta tensors (with kernel injection enabled), and the final model parameters are the same. I only checked rank == 0, though, because with mp_size = 8 it takes almost 30 minutes to load the model without meta tensors.
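The comparison described above can be sketched generically (a hypothetical helper; in practice the values would be torch tensors and the per-key check torch.allclose on each rank's state_dict):

```python
def state_dicts_match(a, b):
    # Compare two {param_name: values} mappings key-by-key.
    if set(a) != set(b):
        return False
    return all(a[k] == b[k] for k in a)

sd_meta = {"mlp.gate_proj.weight": [0.1, 0.2], "mlp.up_proj.weight": [0.3]}
sd_full = {"mlp.gate_proj.weight": [0.1, 0.2], "mlp.up_proj.weight": [0.3]}
print(state_dicts_match(sd_meta, sd_full))  # True
```

Running such a check on every rank, not just rank 0, would rule out sharding-dependent loading bugs at the cost of the longer load time mentioned above.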

@xs1997zju

any update?

@zeyugao
Contributor Author

zeyugao commented Aug 16, 2023

Fixed by a monkey patch, but it seems I accidentally pushed the wrong branch, which closed this PR.

@zeyugao zeyugao reopened this Aug 16, 2023
@zeyugao
Contributor Author

zeyugao commented Aug 16, 2023

I don't know why this bias variable matters. But considering that no one has reviewed this PR in the past three months, I will not respond to follow-up questions quickly; I will just take a look at it occasionally.

But the patch is simple enough that you can easily merge it onto the master branch.

@lekurile
Contributor

Hi @zeyugao,

Thank you for the continued debugging efforts here. We'd like to merge the PR, but there are a few conflicts in auto_tp.py and replace_module.py due to a recent PR that refactored a few things into a class called Loading in deepspeed/module_inject/auto_tp.py.

Can you please resolve these conflicts?

Thanks,
Lev

@lekurile lekurile enabled auto-merge September 20, 2023 19:09
@lekurile lekurile added this pull request to the merge queue Sep 20, 2023
Merged via the queue into deepspeedai:master with commit 4fc2c8e Sep 20, 2023
CurryRice233 added a commit to CurryRice233/DeepSpeed that referenced this pull request Sep 28, 2023
* origin/master:
  Allow multiple inference engines in single script (deepspeedai#4384)
  adds triton flash attention2 kernel (deepspeedai#4337)
  Fix llama meta tensor loading in AutoTP and kernel injected inference (deepspeedai#3608)
  Fix min torch version (deepspeedai#4375)
  Fix multinode runner to properly append to PDSH_SSH_ARGS_APPEND (deepspeedai#4373)
  add the missing method (deepspeedai#4363)
  Openfold fix (deepspeedai#4368)
  deepspeed4science japanese blog (deepspeedai#4369)
  deepspeed4science chinese blog (deepspeedai#4366)
  Enable workflow dispatch on Torch 1.10 CI tests (deepspeedai#4361)
  Update conda env to have max pydantic version (deepspeedai#4362)
  add deepspeed4science blog link (deepspeedai#4364)
  added check to avoid undefined behavior when the input_id length is greater than max_tokens (deepspeedai#4349)
  Add the policy to run llama model from the official repo (deepspeedai#4313)
  fix deepspeed4science links (deepspeedai#4358)
  DeepSpeed4Science (deepspeedai#4357)
  Support InternLM (deepspeedai#4137)
  Pass base_dir to model files can be loaded for auto-tp/meta-tensor. (deepspeedai#4348)

9 participants