
Serialization: take into account meta tensor when splitting the state_dict #2591

Merged: 9 commits merged into main from serialization-meta-device on Oct 10, 2024

Conversation

@SunMarc (Member) commented Oct 7, 2024

What does this PR do?

Fixes huggingface/transformers#33209 cc @xenova

When a meta tensor is present in the state dict, it gets assigned to the same shard file as every other meta tensor, since they all share the same storage_id. This creates one very large shard file when using transformers' CPU-offload saving functionality.
If there are meta tensors in the state_dict, we should treat them as not sharing storage. Right now, we put meta tensors that have the same size in the same bucket.
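
To illustrate (a minimal sketch on recent PyTorch): every meta tensor reports a null storage pointer, so the (device, storage pointer, storage size) triple that get_torch_storage_id builds is identical for any two same-sized meta tensors.

import torch

a = torch.empty(1024, device="meta")
b = torch.empty(1024, device="meta")  # unrelated tensor, same size

# Meta tensors carry no real data: their storage pointer is null,
# so storage-based ids cannot tell them apart.
print(a.untyped_storage().data_ptr())  # 0
print(b.untyped_storage().data_ptr())  # 0
print(a.device == b.device)            # True: both on "meta"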

I'm not sure what the best way to deal with this is. I considered:

  • modifying the split_state_dict_into_shards_factory function, but that would require adding torch-specific code to it. See this commit.
  • modifying get_torch_storage_id to return None when the tensor is a meta tensor, since returning None is the default behavior of get_storage_id -> latest commit (see the sketch after this list)
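
For context, here is a rough sketch of how the sharding loop consumes this id (a simplified paraphrase with hypothetical names, not the verbatim split_state_dict_into_shards_factory code): equal non-None ids force tensors into the same shard, while a None id opts the tensor out of the grouping, which is what the fix relies on.

from typing import Any, Dict, List, Optional, Tuple

StorageId = Tuple[Any, ...]

def place_tensor(
    name: str,
    storage_id: Optional[StorageId],
    storage_to_shard: Dict[StorageId, List[str]],
    current_shard: List[str],
) -> None:
    """Hypothetical helper mirroring the grouping rule in the sharding loop."""
    if storage_id is not None:
        shard = storage_to_shard.get(storage_id)
        if shard is not None:
            # Shares storage with an already-placed tensor: reuse its shard
            # instead of counting the weight a second time elsewhere.
            shard.append(name)
            return
        # First tensor seen for this storage: remember where it lands.
        storage_to_shard[storage_id] = current_shard
    # A None id (meta tensors after this PR) skips the grouping entirely,
    # so each tensor is placed with normal per-shard size accounting.
    current_shard.append(name)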

Failing tests are not related to this PR.

Example

I get the right number of shards and the expected total_size:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

# Load with part of the model offloaded to CPU: weights beyond the
# 40GB GPU budget are placed on the "cpu" device.
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-2-27b-it",
    device_map="auto",
    torch_dtype=torch.bfloat16,
    max_memory={0: "40GB", "cpu": "100GB"},
)
# Saving the offloaded model goes through the state_dict sharding logic.
model.save_pretrained("output")
del model

# Reload the saved checkpoint entirely on GPU and sanity-check generation.
new_model = AutoModelForCausalLM.from_pretrained(
    "output",
    device_map="auto",
    torch_dtype=torch.bfloat16,
    max_memory={0: "60GB"},
)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-27b-it")

input_text = "Write me a poem about Machine Learning."
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = new_model.generate(**input_ids)
print(tokenizer.decode(outputs[0]))
[Screenshot (2024-10-07): saved checkpoint shards showing the expected number of files and total_size]

@SunMarc requested a review from @Wauplin on October 7, 2024 16:48
@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Comment on lines +382 to +385
if tensor.device.type == "meta":
    return None
else:
    return tensor.device, _get_unique_id(tensor), get_torch_storage_size(tensor)
Contributor:

I'm fine with the changes you've made 👍
Could you also update the docstring just above to reflect this change? Currently it says:

Multiple different tensors can share the same underlying storage. For example, "meta" tensors all share the same storage, and thus their identifier will all be equal. This identifier is guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with non-overlapping lifetimes may have the same id.

which is no longer true (meta tensors now always return None).

And thanks for looking into it in the first place!
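
For reference, one possible wording for the updated docstring (a sketch only; the phrasing actually merged in the PR may differ):

def get_torch_storage_id(tensor):
    """
    Return a unique identifier for a tensor's storage.

    Multiple different tensors can share the same underlying storage; their
    identifiers will then be equal. The identifier is guaranteed to be unique
    and constant for this tensor's storage during its lifetime. Two tensor
    storages with non-overlapping lifetimes may have the same id.

    In the case of meta tensors, None is returned instead, since meta tensors
    cannot be told apart by storage.
    """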

Member Author (@SunMarc):

Done!

@SunMarc requested a review from @Wauplin on October 9, 2024 16:20
@Wauplin (Contributor) left a comment:

Thanks!

src/huggingface_hub/serialization/_torch.py (review thread: outdated, resolved)
SunMarc and others added 2 commits October 9, 2024 18:26
Co-authored-by: Lucain <lucain@huggingface.co>
@Wauplin (Contributor) commented Oct 10, 2024:

Thanks @SunMarc!

@Wauplin merged commit 8cb81ac into main on Oct 10, 2024
19 checks passed
@Wauplin deleted the serialization-meta-device branch on October 10, 2024 12:15