-
Notifications
You must be signed in to change notification settings - Fork 27.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix RecurrentGemma device_map #30273
Fix RecurrentGemma device_map #30273
Conversation
@@ -252,7 +252,7 @@ def _update_cache(self, key_states, value_states, **cache_kwargs): | |||
to_shift = cache_position >= self.config.attention_window_size - 1 | |||
indices = (slicing + to_shift[-1].int() - 1) % self.config.attention_window_size | |||
|
|||
k_out, v_out = self.key_states, self.value_states | |||
k_out, v_out = self.key_states.to(key_states.device), self.value_states.to(value_states.device) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Due to _setup_cache
, self.key_states
and self.value_states
are initialized on the device of the hidden state that we pass to the model in generate (e.g. cuda:0). However, this layer might not be on the same device as the hidden state if we use multi-gpu. Hence, we need to make sure that self.key_states
is on the same device as key_states
. Same for value_states
.
contextualized_states = recurrent_gate.type(acc_dtype) * recurrent_states[:, None].to( | ||
recurrent_gate.device | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same issue with recurrent_gate
which is initialized in _setup_cache
.
@@ -387,7 +389,7 @@ def _rnn_scan( | |||
|
|||
contextualized_states = torch.zeros_like(hidden_states) | |||
for t in range(hidden_states.shape[1]): | |||
recurrent_states = recurrent_gate[:, t].type(acc_dtype) * recurrent_states | |||
recurrent_states = recurrent_gate[:, t].type(acc_dtype) * recurrent_states.to(recurrent_gate.device) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here also !
self.register_buffer( | ||
"normalizer", torch.tensor(self.config.hidden_size**0.5, dtype=torch.bfloat16), persistent=False | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We don't need this to be persistant. This fixes an issue that we get with accelerate too.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, the device thing could be fixed by placing them on the same device as self.key_states? rather than the device passed?
Also tad bit scared of the slow down of doing it there? But LGTM otherwise
self.register_buffer( | ||
"normalizer", torch.tensor(self.config.hidden_size**0.5, dtype=torch.bfloat16), persistent=False | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch
I think it will slow down if why place them on the same device as self.key_states for example. Let's say |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks
* Switch to non persistant buffer * fix device mismatch issue due to cache * style
* Switch to non persistant buffer * fix device mismatch issue due to cache * style
* Switch to non persistant buffer * fix device mismatch issue due to cache * style
* Switch to non persistant buffer * fix device mismatch issue due to cache * style
What does this PR do ?
This PR makes gemma compatible with multi-gpu device_map. To try out:
I get the same output in the single gpu or multi gpu setup.