device_map not working #1840

Closed
2 of 4 tasks
bxrjmfh opened this issue Jun 10, 2024 · 2 comments

bxrjmfh commented Jun 10, 2024

System Info

peft 0.11

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder
  • My own task or dataset (give details below)

Reproduction

I currently have the original LLM weights (llama3-8B) and the corresponding LoRA weights. When loading them, I use the following script:

from typing import List
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
import json


class load_test:
    def __init__(self) -> None:
        pass

    def _load_model_tokenizer(self, args):
        # args.rp_path: base model weights; args.rp_lora_path: LoRA adapter weights
        self.model = AutoModelForCausalLM.from_pretrained(args.rp_path, torch_dtype=torch.bfloat16).to(self.device)
        if args.rp_lora_path != '':
            self.model = PeftModel.from_pretrained(self.model, args.rp_lora_path, adapter_name="default11").to(self.device)
            self.model.set_adapter("default11")
        torch.cuda.empty_cache()


lt = load_test()
lt.device = 'cuda:1'
lt._load_model_tokenizer(args)

The expected behavior is that all weights are loaded onto the same device (cuda:1), but there is still some occupation on cuda:0.
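
To double-check this from Python, a quick diagnostic along the following lines can be used (a sketch; it assumes the model was loaded with the script above and is reachable as lt.model):

# Sketch: report which devices the parameters ended up on and how much memory
# PyTorch has allocated on each GPU. Assumes `lt` from the reproduction script above.
from collections import Counter

print(Counter(str(p.device) for p in lt.model.parameters()))
for i in range(torch.cuda.device_count()):
    print(f"cuda:{i}:", torch.cuda.memory_allocated(i), "bytes allocated")

Note that nvidia-smi also counts per-device CUDA context overhead, which torch.cuda.memory_allocated() does not report.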

(screenshot: GPU memory usage showing allocation on cuda:0 as well as cuda:1)

Expected behavior

All weights are loaded onto the specified device.

Additionally, I have checked the previous issues, but none of them provided a solution.

BenjaminBossan (Member) commented Jun 10, 2024

Could you pass the torch_device argument, i.e.:

PeftModel.from_pretrained(self.model, args.rp_lora_path, adapter_name="default11", torch_device=self.device)

For me, that solved the issue; otherwise, PEFT guesses which device to load the PEFT weights on.

This is hard to find because the argument is not documented; I created a PR (#1843) to add it to the docs.

Strangely, I also had to remove the torch.cuda.empty_cache() call in my tests; otherwise memory would still be assigned to cuda:0, and I have no idea how that's possible.
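
Put together, the loading helper from the reproduction could look roughly like this (a sketch based on the suggestion above, not the exact code from the issue; the only PEFT-specific change is torch_device, and the empty_cache() call is dropped):

import torch
from transformers import AutoModelForCausalLM
from peft import PeftModel

def load_model(rp_path, rp_lora_path, device="cuda:1"):
    # Load the base model directly onto the target device.
    model = AutoModelForCausalLM.from_pretrained(rp_path, torch_dtype=torch.bfloat16).to(device)
    if rp_lora_path != '':
        # torch_device tells PEFT where to put the adapter weights instead of
        # letting it guess the device.
        model = PeftModel.from_pretrained(model, rp_lora_path, adapter_name="default11", torch_device=device)
        model.set_adapter("default11")
    # No torch.cuda.empty_cache() here; in the tests above it led to memory
    # being assigned to cuda:0.
    return model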

bxrjmfh (Author) commented Jun 10, 2024


Thanks for your help! It actually works. The torch.cuda.empty_cache() call still causes the bad allocation on cuda:0.
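
If the leftover allocation on cuda:0 has to be avoided entirely, a common workaround (not mentioned in this thread, so treat it as an assumption rather than the fix discussed here) is to hide all other GPUs from the process before torch is imported:

# Sketch: expose only the intended physical GPU to this process, so any implicit
# CUDA context creation can only land on that device. Must run before `import torch`.
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

import torch
device = "cuda:0"  # with the mask above, cuda:0 refers to physical GPU 1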
