
Fix RecurrentGemma device_map #30273

Merged 3 commits on Apr 18, 2024

Conversation

SunMarc
Member

@SunMarc SunMarc commented Apr 16, 2024

What does this PR do?

This PR makes RecurrentGemma compatible with multi-GPU device_map. To try it out:

from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("google/recurrentgemma-2b-it")
model = AutoModelForCausalLM.from_pretrained(
    "google/recurrentgemma-2b-it", device_map="auto"
)
input_text = "Write me a poem about Machine Learning."
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

outputs = model.generate(**input_ids, use_cache=True)
print(tokenizer.decode(outputs[0]))

I get the same output in both the single-GPU and multi-GPU setups.

@SunMarc SunMarc requested a review from ArthurZucker April 16, 2024 14:48
@SunMarc SunMarc changed the title from "Fix recurrent gemma device_map" to "Fix RecurrentGemma device_map" Apr 16, 2024
@@ -252,7 +252,7 @@ def _update_cache(self, key_states, value_states, **cache_kwargs):
to_shift = cache_position >= self.config.attention_window_size - 1
indices = (slicing + to_shift[-1].int() - 1) % self.config.attention_window_size

-        k_out, v_out = self.key_states, self.value_states
+        k_out, v_out = self.key_states.to(key_states.device), self.value_states.to(value_states.device)
@SunMarc SunMarc (Member, Author) commented Apr 16, 2024:

Due to _setup_cache, self.key_states and self.value_states are initialized on the device of the hidden states that we pass to the model in generate (e.g. cuda:0). However, this layer might not be on the same device as the hidden states if we use multi-GPU inference. Hence, we need to make sure that self.key_states is on the same device as key_states, and likewise for value_states.
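A minimal sketch of the pattern (hypothetical class and shapes, not the actual RecurrentGemma code): the cache tensors are created once on the input's device, and each update first moves them to wherever the current layer's key/value states live.

```python
import torch

class ToyWindowCache:
    def __init__(self, batch, window, dim, device):
        # Analogous to _setup_cache: the cache is allocated on the device of
        # the hidden states passed to generate (e.g. cuda:0).
        self.key_states = torch.zeros(batch, window, dim, device=device)
        self.value_states = torch.zeros(batch, window, dim, device=device)

    def update(self, key_states, value_states):
        # The fix: move the cached tensors to the device of the freshly
        # computed states, instead of assuming they already share a device.
        k_out = self.key_states.to(key_states.device)
        v_out = self.value_states.to(value_states.device)
        k_out[:, -1:] = key_states
        v_out[:, -1:] = value_states
        self.key_states, self.value_states = k_out, v_out
        return k_out, v_out

cache = ToyWindowCache(1, 4, 8, device="cpu")
k = torch.randn(1, 1, 8)
k_out, v_out = cache.update(k, torch.randn(1, 1, 8))
assert k_out.device == k.device
```

On a single device the `.to(...)` calls are no-ops, so the change is free in the common case.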

Comment on lines +379 to +381
contextualized_states = recurrent_gate.type(acc_dtype) * recurrent_states[:, None].to(
recurrent_gate.device
)
SunMarc (Member, Author):

Same issue with recurrent_gate, which is initialized in _setup_cache.

@@ -387,7 +389,7 @@ def _rnn_scan(

contextualized_states = torch.zeros_like(hidden_states)
for t in range(hidden_states.shape[1]):
-            recurrent_states = recurrent_gate[:, t].type(acc_dtype) * recurrent_states
+            recurrent_states = recurrent_gate[:, t].type(acc_dtype) * recurrent_states.to(recurrent_gate.device)
SunMarc (Member, Author):
Here also!

Comment on lines +663 to +665
self.register_buffer(
"normalizer", torch.tensor(self.config.hidden_size**0.5, dtype=torch.bfloat16), persistent=False
)
SunMarc (Member, Author):

We don't need this buffer to be persistent. This also fixes an issue that we hit with accelerate.
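The effect of the change can be sketched in isolation (toy module name is hypothetical): with persistent=False, the buffer is kept on the module but excluded from the state_dict, so checkpoint loading and accelerate's weight dispatch never have to match it.

```python
import torch
from torch import nn

class ToyEmbedder(nn.Module):
    def __init__(self, hidden_size=256):
        super().__init__()
        # persistent=False keeps this derived constant out of the state_dict;
        # it is recomputed from the config at construction time instead.
        self.register_buffer(
            "normalizer",
            torch.tensor(hidden_size**0.5, dtype=torch.bfloat16),
            persistent=False,
        )

m = ToyEmbedder()
assert "normalizer" not in m.state_dict()  # not serialized
assert float(m.normalizer) == 16.0         # sqrt(256), exact in bfloat16
```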

Collaborator:

Good catch

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@ArthurZucker ArthurZucker (Collaborator) left a comment

Thanks, couldn't the device issue be fixed by placing them on the same device as self.key_states, rather than on the device passed in? Also a tad bit scared of the slowdown of doing it there. But LGTM otherwise


@SunMarc SunMarc (Member, Author) commented Apr 17, 2024:

> Thanks, couldn't the device issue be fixed by placing them on the same device as self.key_states, rather than on the device passed in? Also a tad bit scared of the slowdown of doing it there. But LGTM otherwise

I think it will slow things down if we place them on the same device as self.key_states. For example, say self.key_states is initialized on cuda:0 and we have 2 GPUs. The problem is that the computed key_states can be on cuda:0 or cuda:1 depending on where the layer is. Hence, it is better to move self.key_states to the device of key_states to limit data transfer between GPUs. Otherwise, we would need to move the data each time we hit a layer on cuda:1.
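The transfer-count argument can be illustrated with a toy simulation (device split and step count are made up): moving each layer's cache to its layer's device costs one transfer total, while moving key_states to the cache's device costs one transfer per misplaced layer on every generation step.

```python
# Hypothetical 4-layer model split across two GPUs; every per-layer cache
# starts on cuda:0 (the device of the input hidden states, as in _setup_cache).
layer_devices = ["cuda:0", "cuda:0", "cuda:1", "cuda:1"]

def count_transfers(num_steps, move_cache_to_layer):
    caches = ["cuda:0"] * len(layer_devices)
    transfers = 0
    for _ in range(num_steps):
        for i, layer_dev in enumerate(layer_devices):
            if caches[i] != layer_dev:
                transfers += 1
                if move_cache_to_layer:
                    # PR's approach: self.key_states.to(key_states.device)
                    # moves the cache once; it then stays on the layer's device.
                    caches[i] = layer_dev
                # Alternative: copy key_states to the cache's device, which
                # repeats on every step for every layer on cuda:1.
    return transfers

print(count_transfers(10, move_cache_to_layer=True))   # 2: one move per misplaced cache
print(count_transfers(10, move_cache_to_layer=False))  # 20: one copy per step per misplaced layer
```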

@ArthurZucker ArthurZucker (Collaborator) left a comment

Thanks

@ArthurZucker ArthurZucker merged commit 7509a0a into huggingface:main Apr 18, 2024
19 checks passed
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request Apr 18, 2024
* Switch to non persistant buffer

* fix device mismatch issue due to cache

* style
ArthurZucker pushed a commit that referenced this pull request Apr 22, 2024
ydshieh pushed a commit that referenced this pull request Apr 23, 2024
itazap pushed a commit that referenced this pull request May 14, 2024
3 participants