
Compared with the first version of Qwen, why has GPU memory usage increased so much? #240

Closed

charliedream1 opened this issue Apr 2, 2024 · 11 comments

Comments

@charliedream1

Shrinking the context window and the positional-embedding size doesn't seem to help. What is causing the increase in GPU memory usage? Compared with the first version of Qwen, memory usage has grown a lot.

@jklj077
Collaborator

jklj077 commented Apr 2, 2024

Please first check the pinned issue and see if your memory profiling matches ours.

@charliedream1
Author

charliedream1 commented Apr 2, 2024 via email

For inference it is normal, but training takes more memory.

@baisechundu

Maybe you are using the eager (default) mode for attention. Here is the relevant part of the source code:

    QWEN2_ATTENTION_CLASSES = {
        "eager": Qwen2Attention,
        "flash_attention_2": Qwen2FlashAttention2,
        "sdpa": Qwen2SdpaAttention,
    }

    self.self_attn = QWEN2_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
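
A minimal sketch of steering this dispatch from user code, assuming a recent transformers release that accepts the attn_implementation keyword in from_pretrained (the checkpoint id below is illustrative):

    import torch
    from transformers import AutoModelForCausalLM

    # Illustrative checkpoint; substitute your own path or Hub id.
    model_id = "Qwen/Qwen1.5-7B-Chat"

    # "sdpa" or "flash_attention_2" avoids the eager path, which materializes
    # the full attention matrix and therefore uses more memory on long inputs.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        attn_implementation="sdpa",
        torch_dtype=torch.bfloat16,
    )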

@charliedream1
Author

charliedream1 commented Apr 8, 2024 via email

I didn't see this in config.json. What should I do to change this eager mode? What does this mode do? If I change it, will it impact performance?

@baisechundu

Two methods for you; this works for me:

  1. Add this parameter to config.json: "_attn_implementation": "sdpa" (a programmatic equivalent is sketched below the code)
  2. Pass it in your SFT code:

    model = AutoModelForCausalLM.from_pretrained(
        model_args.model_name_or_path,
        config=config,
        cache_dir=training_args.cache_dir,
        device_map=device_map,
        attn_implementation="sdpa",  # add attn_implementation here
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=compute_dtype,
        )
        if training_args.use_lora and lora_args.q_lora
        else None,
        **model_load_kwargs,
    )
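
For method 1, a minimal sketch of the same change done programmatically rather than by editing config.json on disk, assuming transformers honours the config's _attn_implementation attribute (the checkpoint id is illustrative):

    from transformers import AutoConfig, AutoModelForCausalLM

    # Illustrative checkpoint; replace with your own path.
    config = AutoConfig.from_pretrained("Qwen/Qwen1.5-7B-Chat")
    config._attn_implementation = "sdpa"  # same effect as the config.json edit above

    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen1.5-7B-Chat",
        config=config,
        torch_dtype="auto",
    )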

@charliedream1
Author

charliedream1 commented Apr 8, 2024 via email

@WeixuanXiong

After switching to flash-attn 2.5.7, I also observe that memory usage is much higher than with Qwen1. Is there any way to optimize this?

@jklj077
Collaborator

jklj077 commented May 22, 2024

Qwen(1.0) will automatically enable flash attention if it is installed, which is no longer the case for Qwen1.5.

To enable flash attention in Qwen1.5, please follow the instructions in the official transformers documentation at https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2. In short, when calling from_pretrained, make sure attn_implementation is set to "flash_attention_2" and torch_dtype is set to "auto", torch.bfloat16, or torch.float16 for it to take effect.
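
A minimal sketch of that loading call, assuming flash-attn is installed and the GPU supports it (the checkpoint id is illustrative):

    import torch
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen1.5-7B-Chat",                   # illustrative checkpoint
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,               # flash attention needs fp16 or bf16
        device_map="auto",
    )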

We don't recommend bitsandbytes as you may suffer from substantial accuracy loss. If you must use quantization, try loading the GPTQ or the AWQ version and then use QLoRA.
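
A hedged sketch of that alternative: loading a pre-quantized GPTQ checkpoint and attaching LoRA adapters with peft. The checkpoint name and LoRA hyperparameters below are assumptions for illustration, not settings from this thread:

    import torch
    from transformers import AutoModelForCausalLM
    from peft import LoraConfig, get_peft_model

    # Assumed name of the official GPTQ-quantized checkpoint; verify it on the Hub.
    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen1.5-7B-Chat-GPTQ-Int4",
        attn_implementation="flash_attention_2",
        torch_dtype=torch.float16,
        device_map="auto",
    )

    # Attach LoRA adapters on top of the frozen quantized weights (QLoRA-style training).
    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        lora_dropout=0.05,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()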

@jklj077 jklj077 closed this as completed May 22, 2024
@LuJunru

LuJunru commented May 27, 2024

I think this may be related to this transformers issue: huggingface/transformers#30860, since many models are affected. The original Qwen code has no logits = logits.float().
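
A back-of-the-envelope sketch of why that single upcast matters, assuming one 2048-token sequence and Qwen1.5's roughly 152k-entry vocabulary (numbers are illustrative):

    # Extra memory for the float32 logits tensor created by logits.float()
    batch, seq_len, vocab = 1, 2048, 151936
    fp32_bytes = batch * seq_len * vocab * 4      # 4 bytes per float32 element
    print(f"{fp32_bytes / 1024**3:.2f} GiB")      # ~1.16 GiB, on top of the bf16 logits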

@philipgao518

philipgao518 commented Aug 8, 2024

I have tested this: with dp3 and flash_attention_2, fine-tuning Qwen1.5-72B on 16 A10 GPUs works up to 2048 tokens, while Qwen2 with the same settings only manages 1024 tokens. GPU memory consumption has increased quite a bit. Has the model architecture changed?


This issue has been automatically locked since there has not been any recent activity after it was closed. Please open a new issue for related bugs.

@github-actions github-actions bot locked as resolved and limited conversation to collaborators Feb 26, 2025