Fix RecurrentGemma device_map #30273

Merged
merged 3 commits on Apr 18, 2024
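For context, a minimal repro sketch of the multi-GPU loading path this fix targets: `device_map="auto"` shards the model across the available GPUs. The checkpoint id, prompt, and GPU count are assumptions for illustration, not taken from the PR.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/recurrentgemma-2b"  # assumed checkpoint name
tokenizer = AutoTokenizer.from_pretrained(model_id)
# device_map="auto" places different layers on different GPUs (assumes >= 2 visible GPUs)
model = AutoModelForCausalLM.from_pretrained(
    model_id, device_map="auto", torch_dtype=torch.bfloat16
)

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
# Before this fix, generate() could fail with cross-device tensor errors, because
# cache/state tensors created in _setup_cache stayed on the first GPU.
out = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```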
@@ -252,7 +252,7 @@ def _update_cache(self, key_states, value_states, **cache_kwargs):
to_shift = cache_position >= self.config.attention_window_size - 1
indices = (slicing + to_shift[-1].int() - 1) % self.config.attention_window_size

-k_out, v_out = self.key_states, self.value_states
+k_out, v_out = self.key_states.to(key_states.device), self.value_states.to(value_states.device)
@SunMarc (Member Author) commented on Apr 16, 2024:
Due to _setup_cache, self.key_states and self.value_states are initialized on the device of the hidden state that we pass to the model in generate (e.g. cuda:0). However, this layer might not be on the same device as the hidden state if we use multi-GPU. Hence, we need to make sure that self.key_states is on the same device as key_states, and likewise for value_states.

k_out = k_out[:, :, indices]
v_out = v_out[:, :, indices]
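To illustrate the device mismatch the comment above describes, here is a minimal sketch with hypothetical shapes and devices (assumes two visible GPUs; this is not the model code itself):

```python
import torch

# Cache buffers created in _setup_cache live on the device of the input hidden
# states (e.g. cuda:0), while this layer's projections may sit on another GPU.
key_cache = torch.zeros(1, 8, 2048, 64, device="cuda:0")   # stand-in for self.key_states
key_states = torch.randn(1, 8, 1, 64, device="cuda:1")     # produced by a layer on cuda:1

# Writing key_states into the cache (or slicing the cache to build k_out) would
# raise a cross-device error, so the cache is first moved to the layer's device:
k_out = key_cache.to(key_states.device)
k_out[:, :, -1:] = key_states                               # both tensors now on cuda:1
```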

@@ -376,7 +376,9 @@ def _rnn_scan(
return hidden_states, hidden_states[:, 0].type(acc_dtype)

else:
-contextualized_states = recurrent_gate.type(acc_dtype) * recurrent_states[:, None]
+contextualized_states = recurrent_gate.type(acc_dtype) * recurrent_states[:, None].to(
+    recurrent_gate.device
+)
@SunMarc (Member Author) commented on lines +379 to +381:

Same issue with recurrent_gate, which is initialized in _setup_cache.

contextualized_states += hidden_states.type(acc_dtype)
return contextualized_states.type(hidden_states.dtype), contextualized_states[:, -1]

@@ -387,7 +389,7 @@ def _rnn_scan(

contextualized_states = torch.zeros_like(hidden_states)
for t in range(hidden_states.shape[1]):
-recurrent_states = recurrent_gate[:, t].type(acc_dtype) * recurrent_states
+recurrent_states = recurrent_gate[:, t].type(acc_dtype) * recurrent_states.to(recurrent_gate.device)
@SunMarc (Member Author) commented:
Here also!

recurrent_states = recurrent_states + hidden_states[:, t].type(acc_dtype)
contextualized_states[:, t] = recurrent_states.type(hidden_states.dtype)

@@ -658,7 +660,9 @@ def __init__(self, config: RecurrentGemmaConfig):
self.final_norm = RecurrentGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False

self.register_buffer("normalizer", torch.tensor(self.config.hidden_size**0.5, dtype=torch.bfloat16))
self.register_buffer(
"normalizer", torch.tensor(self.config.hidden_size**0.5, dtype=torch.bfloat16), persistent=False
)
@SunMarc (Member Author) commented on lines +663 to +665:
We don't need this to be persistent. This fixes an issue that we get with accelerate too.

A Collaborator replied:
Good catch
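A small sketch of what persistent=False changes (the Demo module and hidden size are illustrative, not the model code): the buffer is kept out of the state_dict, so checkpoint loading and accelerate's dispatch never look for a "normalizer" entry that was never saved with the weights.

```python
import torch
from torch import nn

class Demo(nn.Module):
    def __init__(self, hidden_size: int = 2048):
        super().__init__()
        # persistent=False keeps the buffer out of state_dict(), so loading a
        # checkpoint (or dispatching with accelerate) does not expect a
        # "normalizer" entry; the value is simply recomputed at init time.
        self.register_buffer(
            "normalizer",
            torch.tensor(hidden_size**0.5, dtype=torch.bfloat16),
            persistent=False,
        )

print("normalizer" in Demo().state_dict())  # False
```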

# Initialize weights and apply final processing
self.post_init()
