
static cache: RuntimeError: cannot mutate tensors with frozen storage #33178

Closed

dvrogozh opened this issue Aug 28, 2024 · 6 comments

Comments

@dvrogozh (Contributor) commented Aug 28, 2024

With:

On:

  • CPU
  • NVidia A10

The test "static cache works with torch.export()" fails with:

# RUN_SLOW=1 python3 -m pytest --pspec -vv -k CacheTest tests/utils/test_cache_utils.py

RuntimeError: cannot mutate tensors with frozen storage

While executing %index_copy_ : [num_users=0] = call_method[target=index_copy_](args = (%k_out, 2, %l_input_pos_, %k_embed), kwargs = {})
Original traceback:
  File "/home/dvrogozh/git/huggingface/transformers/tests/utils/test_cache_utils.py", line 210, in forward
    outs = self.model(
  File "/home/dvrogozh/git/huggingface/transformers/src/transformers/models/gemma/modeling_gemma.py", line 1076, in forward
    outputs = self.model(
  File "/home/dvrogozh/git/huggingface/transformers/src/transformers/models/gemma/modeling_gemma.py", line 889, in forward
    layer_outputs = decoder_layer(
  File "/home/dvrogozh/git/huggingface/transformers/src/transformers/models/gemma/modeling_gemma.py", line 611, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/dvrogozh/git/huggingface/transformers/src/transformers/models/gemma/modeling_gemma.py", line 521, in forward
    key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
  File "/home/dvrogozh/git/huggingface/transformers/src/transformers/cache_utils.py", line 1101, in update
    k_out.index_copy_(2, cache_position, key_states)
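
For context, here is a minimal standalone sketch of the failing pattern. It is illustrative, not the exact transformers test: the module and tensor names are made up, and whether it reproduces depends on the PyTorch version.

# Sketch: an in-place index_copy_ on a module buffer traced through
# torch.export, mirroring the StaticCache update pattern above.
import torch

class TinyStaticCache(torch.nn.Module):  # hypothetical stand-in for StaticCache
    def __init__(self):
        super().__init__()
        # Pre-allocated cache buffer, analogous to a StaticCache.key_cache entry.
        self.register_buffer("k_cache", torch.zeros(1, 1, 8, 4))

    def forward(self, cache_position, key_states):
        k_out = self.k_cache
        # Same in-place update as cache_utils.py line 1101 in the traceback.
        k_out.index_copy_(2, cache_position, key_states)
        return k_out

ep = torch.export.export(
    TinyStaticCache(),
    (torch.tensor([0]), torch.ones(1, 1, 1, 4)),
)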

I observe that adding a .clone() to the following two tensors fixes the issue. This solution was suggested in pytorch/pytorch#127571 (comment); however, I am not sure it is the correct fix. See the #33178 draft PR with this change.

k_out = self.key_cache[layer_idx]
v_out = self.value_cache[layer_idx]
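
For reference, the draft change amounts to roughly the following inside StaticCache.update (a sketch assuming the surrounding method body in cache_utils.py; the index_copy_ calls are the existing ones):

# .clone() yields tensors whose storage is not frozen, so the in-place
# index_copy_ below no longer fails under torch.export. The copies are
# also what makes this fix questionable.
k_out = self.key_cache[layer_idx].clone()
v_out = self.value_cache[layer_idx].clone()

k_out.index_copy_(2, cache_position, key_states)
v_out.index_copy_(2, cache_position, value_states)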

CC: @gante @SunMarc

dvrogozh added a commit to dvrogozh/transformers that referenced this issue Aug 28, 2024
For huggingface#33178

Signed-off-by: Dmitry Rogozhkin <dmitry.v.rogozhkin@intel.com>
@dvrogozh (Contributor, Author)

Another observation: this issue seems to appear after commit 1c36db6 (#32543). @SunMarc

@dvrogozh (Contributor, Author)

I also found this issue on the PyTorch side, which seems relevant:

@guangy10 (Contributor) commented Sep 4, 2024

Workaround in #33287

@SunMarc (Member) commented Sep 4, 2024

The workaround proposed by @guangy10 sounds better as it doesn't involve copying!
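
(Illustration only: one copy-free direction is torch.export's non-strict mode, which may be more permissive about in-place buffer mutation. Whether this is what #33287 actually does is an assumption here, and model / input_ids are placeholder names.)

# Hedged sketch, not necessarily the approach taken in #33287.
exported_program = torch.export.export(model, (input_ids,), strict=False)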


This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@dvrogozh (Contributor, Author) commented Oct 1, 2024

This issue was addressed by:

dvrogozh closed this as completed Oct 1, 2024