
NotImplementedError: Cannot copy out of meta tensor; no data! when moving LLaVa from meta device to CUDA #28972

NielsRogge opened this issue Feb 12, 2024 · 3 comments


NielsRogge commented Feb 12, 2024

System Info

Transformers 4.37.0.dev0

Who can help?

@ArthurZucker @younesbelkada

Information

  • The official example scripts
  • My own modified scripts

Reproduction

Getting this error:

Traceback (most recent call last):
  File "src/transformers/models/llava/test_meta_device.py", line 10, in <module>
    model.to(device)
  File "/home/niels/python_projects/transformers/src/transformers/modeling_utils.py", line 2556, in to
    return super().to(*args, **kwargs)
  File "/home/niels/python_projects/transformers/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1160, in to
    return self._apply(convert)
  File "/home/niels/python_projects/transformers/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 810, in _apply
    module._apply(fn)
  File "/home/niels/python_projects/transformers/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 810, in _apply
    module._apply(fn)
  File "/home/niels/python_projects/transformers/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 810, in _apply
    module._apply(fn)
  [Previous line repeated 1 more time]
  File "/home/niels/python_projects/transformers/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 833, in _apply
    param_applied = fn(param)
  File "/home/niels/python_projects/transformers/env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1158, in convert
    return t.to(device, dtype if t.is_floating_point() or t.is_complex() else None, non_blocking)
NotImplementedError: Cannot copy out of meta tensor; no data!

when running this:

from transformers import LlavaConfig, LlavaForConditionalGeneration
import torch

config = LlavaConfig()

with torch.device("meta"):
    model = LlavaForConditionalGeneration(config)

# original_state_dict holds the weights of the original (non-HF) LLaVa checkpoint
model.load_state_dict(original_state_dict, assign=True)

device = "cuda:0"
model.to(device)

Taken from this script. Oddly enough, the same thing works for CogVLM, as seen here, but not for LLaVa.

Based on this PR: #26849 (which fixed a similar issue for DETR), this may have to do with some modules that cannot be split.

Expected behavior

I'd like to first load LLaVa on the "meta" device and then load the weights in, as I'm converting from the original repository. The "meta" device makes loading the HF model much faster, as I've observed with CogVLM (thanks @ArthurZucker, I didn't know about this!).
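For reference, a minimal sketch (not from the conversion script) of why meta-device construction is fast, and why the subsequent .to(device) fails: tensors created under the meta device carry shape and dtype but allocate no storage, so there is literally no data to copy.

import torch
from transformers import LlavaConfig, LlavaForConditionalGeneration

config = LlavaConfig()

with torch.device("meta"):
    model = LlavaForConditionalGeneration(config)

# Construction is near-instant: no weight memory is allocated, and the usual
# random initialization is a no-op on meta tensors.
param = next(model.parameters())
print(param.is_meta)  # True: the parameter has a shape but no data to copy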


NielsRogge commented Feb 12, 2024

OK, I'm seeing the same error with BERT:

from transformers import BertConfig, BertModel
import torch

config = BertConfig()

with torch.device("meta"):
    model = BertModel(config)

pretrained_model = BertModel.from_pretrained("bert-base-uncased")
model.load_state_dict(pretrained_model.state_dict(), assign=True)

device = "cuda:0"
model.to(device)

Trying to figure out why this works for CogVLM but not for BERT or LLaVa. Maybe @muellerzr has some insights, given that he knows a lot about big model inference.

Update: it also doesn't work for CogVLM if I use the same rotary embedding class as the one from Llama.
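A quick way to see what is actually left without data after the assign-based load (a diagnostic sketch, meant to run right after the BERT snippet above): assign=True swaps in real tensors for every entry in the state dict, but non-persistent buffers (such as the rotary embedding caches mentioned above) are absent from the state dict, so they stay on the meta device.

# Anything printed here still has no data and will break model.to("cuda")
for name, param in model.named_parameters():
    if param.is_meta:
        print("parameter still on meta:", name)
for name, buf in model.named_buffers():
    if buf.is_meta:
        print("buffer still on meta:", name)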


muellerzr commented Feb 12, 2024

@NielsRogge the issue lies in how the parameters are initialized. Instead of using with torch.device("meta"), use init_empty_weights from accelerate and it will work just fine (basically, some buffers and other things were causing problems):

from transformers import BertConfig, BertModel
from accelerate import init_empty_weights

config = BertConfig.from_pretrained("bert-base-uncased")

with init_empty_weights():
    model = BertModel(config)

pretrained_model = BertModel.from_pretrained("bert-base-uncased")
model.load_state_dict(pretrained_model.state_dict(), assign=True)

model.to("cuda")

NielsRogge commented

Thanks, I had indeed noticed it had something to do with buffers. Thanks a lot!
