Skip to content

Commit

Permalink
Fix RecurrentGemma device_map (#30273)
Browse files Browse the repository at this point in the history
* Switch to non persistant buffer

* fix device mismatch issue due to cache

* style
  • Loading branch information
SunMarc authored and ydshieh committed Apr 23, 2024
1 parent 5befc89 commit 005b9ec
Showing 1 changed file with 8 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def _update_cache(self, key_states, value_states, **cache_kwargs):
to_shift = cache_position >= self.config.attention_window_size - 1
indices = (slicing + to_shift[-1].int() - 1) % self.config.attention_window_size

k_out, v_out = self.key_states, self.value_states
k_out, v_out = self.key_states.to(key_states.device), self.value_states.to(value_states.device)
k_out = k_out[:, :, indices]
v_out = v_out[:, :, indices]

Expand Down Expand Up @@ -376,7 +376,9 @@ def _rnn_scan(
return hidden_states, hidden_states[:, 0].type(acc_dtype)

else:
contextualized_states = recurrent_gate.type(acc_dtype) * recurrent_states[:, None]
contextualized_states = recurrent_gate.type(acc_dtype) * recurrent_states[:, None].to(
recurrent_gate.device
)
contextualized_states += hidden_states.type(acc_dtype)
return contextualized_states.type(hidden_states.dtype), contextualized_states[:, -1]

Expand All @@ -387,7 +389,7 @@ def _rnn_scan(

contextualized_states = torch.zeros_like(hidden_states)
for t in range(hidden_states.shape[1]):
recurrent_states = recurrent_gate[:, t].type(acc_dtype) * recurrent_states
recurrent_states = recurrent_gate[:, t].type(acc_dtype) * recurrent_states.to(recurrent_gate.device)
recurrent_states = recurrent_states + hidden_states[:, t].type(acc_dtype)
contextualized_states[:, t] = recurrent_states.type(hidden_states.dtype)

Expand Down Expand Up @@ -658,7 +660,9 @@ def __init__(self, config: RecurrentGemmaConfig):
self.final_norm = RecurrentGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False

self.register_buffer("normalizer", torch.tensor(self.config.hidden_size**0.5, dtype=torch.bfloat16))
self.register_buffer(
"normalizer", torch.tensor(self.config.hidden_size**0.5, dtype=torch.bfloat16), persistent=False
)
# Initialize weights and apply final processing
self.post_init()

Expand Down

0 comments on commit 005b9ec

Please sign in to comment.