diff --git a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
index 26cdc437d0f99e..c21f99ce48bd32 100644
--- a/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
+++ b/src/transformers/models/recurrent_gemma/modeling_recurrent_gemma.py
@@ -252,7 +252,7 @@ def _update_cache(self, key_states, value_states, **cache_kwargs):
         to_shift = cache_position >= self.config.attention_window_size - 1
         indices = (slicing + to_shift[-1].int() - 1) % self.config.attention_window_size
 
-        k_out, v_out = self.key_states, self.value_states
+        k_out, v_out = self.key_states.to(key_states.device), self.value_states.to(value_states.device)
         k_out = k_out[:, :, indices]
         v_out = v_out[:, :, indices]
 
@@ -376,7 +376,9 @@ def _rnn_scan(
                 return hidden_states, hidden_states[:, 0].type(acc_dtype)
 
             else:
-                contextualized_states = recurrent_gate.type(acc_dtype) * recurrent_states[:, None]
+                contextualized_states = recurrent_gate.type(acc_dtype) * recurrent_states[:, None].to(
+                    recurrent_gate.device
+                )
                 contextualized_states += hidden_states.type(acc_dtype)
                 return contextualized_states.type(hidden_states.dtype), contextualized_states[:, -1]
 
@@ -387,7 +389,7 @@ def _rnn_scan(
 
             contextualized_states = torch.zeros_like(hidden_states)
             for t in range(hidden_states.shape[1]):
-                recurrent_states = recurrent_gate[:, t].type(acc_dtype) * recurrent_states
+                recurrent_states = recurrent_gate[:, t].type(acc_dtype) * recurrent_states.to(recurrent_gate.device)
                 recurrent_states = recurrent_states + hidden_states[:, t].type(acc_dtype)
                 contextualized_states[:, t] = recurrent_states.type(hidden_states.dtype)
 
@@ -658,7 +660,9 @@ def __init__(self, config: RecurrentGemmaConfig):
         self.final_norm = RecurrentGemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.gradient_checkpointing = False
 
-        self.register_buffer("normalizer", torch.tensor(self.config.hidden_size**0.5, dtype=torch.bfloat16))
+        self.register_buffer(
+            "normalizer", torch.tensor(self.config.hidden_size**0.5, dtype=torch.bfloat16), persistent=False
+        )
         # Initialize weights and apply final processing
         self.post_init()
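
The diff combines two recurring patterns for models that get sharded across devices (e.g. with `device_map="auto"`): tensors cached on one device (`self.key_states`, `self.value_states`, the carried `recurrent_states` in `_rnn_scan`) are moved to the device of the incoming activations before they are combined, and a constant derived from the config (`normalizer`) is registered as a non-persistent buffer so it is never written to or expected from a checkpoint's `state_dict`. Below is a minimal, self-contained sketch of both ideas; `TinyCache`, its `update` method, and the toy shapes are hypothetical stand-ins for illustration, not transformers APIs.

```python
import torch
from torch import nn


class TinyCache(nn.Module):
    """Hypothetical module sketching the two patterns used in this diff."""

    def __init__(self, hidden_size: int = 4, window: int = 8):
        super().__init__()
        # persistent=False keeps this derived constant out of state_dict(), so
        # checkpoints saved without the buffer still load without missing-key noise.
        self.register_buffer(
            "normalizer", torch.tensor(hidden_size**0.5, dtype=torch.bfloat16), persistent=False
        )
        # Cache tensor that may live on a different device than a given layer's
        # inputs once the model is split across devices.
        self.key_states = torch.zeros(1, 1, window, hidden_size)

    def update(self, key_states: torch.Tensor) -> torch.Tensor:
        # Mirror of the .to(key_states.device) calls in the diff: bring the cached
        # tensor onto the device of the incoming states before combining them.
        k_out = self.key_states.to(key_states.device)
        return k_out + key_states


cache = TinyCache()
print("normalizer" in cache.state_dict())          # False: non-persistent buffers are not serialized
print(cache.update(torch.ones(1, 1, 8, 4)).shape)  # torch.Size([1, 1, 8, 4])
```

The same reasoning drives the `_rnn_scan` changes: the carried `recurrent_states` comes from whichever device produced the previous step, so it is moved to `recurrent_gate.device` before the element-wise product.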