diff --git a/src/transformers/cache_utils.py b/src/transformers/cache_utils.py
index 3c157018ecd214..f6791ead8bda43 100644
--- a/src/transformers/cache_utils.py
+++ b/src/transformers/cache_utils.py
@@ -1090,8 +1090,15 @@ def update(
             A tuple containing the updated key and value states.
         """
         cache_position = cache_kwargs.get("cache_position")
-        self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device)
-        self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device)
+        if key_states.device != self.key_cache[layer_idx].device:
+            # Note: The static cache is also used by `torch.export`. Attempting to call `.to()` on it
+            # will raise an error: "cannot mutate tensors with frozen storage" when using `torch.export`.
+            # Since `torch.export` currently works on "cpu", the workaround is to ensure that any device
+            # change is scoped within this conditional check to avoid tracing the device change code.
+            # See: https://github.com/pytorch/pytorch/issues/131679
+            self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device=key_states.device)
+        if value_states.device != self.value_cache[layer_idx].device:
+            self.value_cache[layer_idx] = self.value_cache[layer_idx].to(device=value_states.device)
         k_out = self.key_cache[layer_idx]
         v_out = self.value_cache[layer_idx]
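
A minimal standalone sketch of the same guard pattern, for illustration only (the `TinyStaticCache` class below is hypothetical and not the actual `StaticCache` API): when the module is traced on CPU by `torch.export`, the device-equality check is false, so the `.to()` call that mutates the frozen cache storage is never traced, while GPU callers still get the device move at eager runtime.

import torch


class TinyStaticCache(torch.nn.Module):
    """Hypothetical stand-in for a single layer's pre-allocated key cache."""

    def __init__(self, max_len: int = 8, head_dim: int = 4):
        super().__init__()
        # Pre-allocated cache buffer, created on CPU (the device `torch.export` traces on).
        self.register_buffer("key_cache", torch.zeros(max_len, head_dim))

    def update(self, key_states: torch.Tensor, cache_position: torch.Tensor) -> torch.Tensor:
        if key_states.device != self.key_cache.device:
            # Only reached when the incoming states live on a different device (e.g. CUDA).
            # Under `torch.export` on CPU this branch is skipped, so the frozen-storage
            # mutation error from https://github.com/pytorch/pytorch/issues/131679 is avoided.
            self.key_cache = self.key_cache.to(device=key_states.device)
        # Write the new states into the pre-allocated slots and return the full cache.
        self.key_cache[cache_position] = key_states
        return self.key_cache


cache = TinyStaticCache()
out = cache.update(torch.randn(1, 4), torch.tensor([0]))
print(out.shape)  # torch.Size([8, 4])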