
trainer save_model ValueError You are trying to save a non contiguous tensor #28293

Closed
siebeniris opened this issue Dec 31, 2023 · 15 comments · Fixed by #29906

Comments

@siebeniris

siebeniris commented Dec 31, 2023

System Info

Transformers version: 4.36.2
PyTorch version: 2.1.1
Python version: 3.10.13

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Fine-tuning an mT5 model on a task using the transformers Trainer, then trying to save the model, produces the following error:

  File "/home/xxx/xxx/xxx/run.py", line 17, in main
    experiment.run()
  File "/home/xxx/xxx/xxx/experiments.py", line 162, in run
    self.train()
  File "/home/xxx/xxx/xxx/experiments.py", line 207, in train
    trainer.save_model()  # Saves the tokenizer too for easy upload
  File "/home/xxx/.local/lib/python3.10/site-packages/transformers/trainer.py", line 2849, in save_model
    self._save(output_dir)
  File "/home/xxx/.local/lib/python3.10/site-packages/transformers/trainer.py", line 2909, in _save
    self.model.save_pretrained(
  File "/home/xxx/.local/lib/python3.10/site-packages/transformers/modeling_utils.py", line 2376, in save_pretrained
    safe_save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"})
  File "/home/xxx/.local/lib/python3.10/site-packages/safetensors/torch.py", line 281, in save_file
    serialize_file(_flatten(tensors), filename, metadata=metadata)
  File "/home/xxx/.local/lib/python3.10/site-packages/safetensors/torch.py", line 475, in _flatten
    return {
  File "/home/xxx/.local/lib/python3.10/site-packages/safetensors/torch.py", line 479, in <dictcomp>
    "data": _tobytes(v, k),
  File "/home/xxx/.local/lib/python3.10/site-packages/safetensors/torch.py", line 396, in _tobytes
    raise ValueError(
ValueError: You are trying to save a non contiguous tensor: `encoder_decoder.encoder.block.0.layer.0.SelfAttention.q.weight` which is not allowed. It either means you are trying to save tensors which are reference of each other in which case it's recommended to save only the full tensors, and reslice at load time, or simply call `.contiguous()` on your tensor to pack it before saving.

Expected behavior

Saving the fine-tuned mT5 model should succeed. Instead, it raises the error above; patching transformers/modeling_utils.py with state_dict = {k: v.contiguous() for k, v in state_dict.items()} works around the problem.
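
For reference, the same workaround can be applied from user code without patching the library. A minimal sketch, where the checkpoint and output directory names are illustrative stand-ins for the fine-tuned model above:

    from transformers import AutoModelForSeq2SeqLM

    # Stand-in for the fine-tuned model from the traceback above.
    model = AutoModelForSeq2SeqLM.from_pretrained("google/mt5-small")

    # Repack every tensor before the state dict reaches safetensors,
    # instead of editing modeling_utils.py in place.
    state_dict = {k: v.contiguous() for k, v in model.state_dict().items()}
    model.save_pretrained("finetuned-mt5", state_dict=state_dict)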

@LysandreJik
Member

Hmm, do you know what might be happening here with mT5, @Narsil?

@ArthurZucker
Collaborator

Might be fixed by #28414?

@Narsil
Contributor

Narsil commented Jan 12, 2024

Happy to take a look if I can get access either to the fine-tuned model (even a dummy one, I just need to look at those tensors) or a reproducer.

I have no idea what makes some tensors non-contiguous, or what kind of non-contiguity is involved.

@fxmarty
Contributor

fxmarty commented Jan 16, 2024

Non-contiguous parameters/buffers can be saved with safe_serialization=False but not with safe_serialization=True.
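
Concretely, the two save paths behave like this; a sketch with illustrative checkpoint and directory names:

    from transformers import AutoModelForSeq2SeqLM, TrainingArguments

    model = AutoModelForSeq2SeqLM.from_pretrained("google/mt5-small")

    # safetensors path: raises ValueError if any tensor is non-contiguous.
    model.save_pretrained("out", safe_serialization=True)

    # torch.save path: pickles the tensors, so contiguity does not matter.
    model.save_pretrained("out", safe_serialization=False)

    # When saving through the Trainer, the equivalent switch is the
    # save_safetensors training argument.
    args = TrainingArguments(output_dir="out", save_safetensors=False)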

@Narsil
Contributor

Narsil commented Jan 17, 2024

I was trying to ask more specifically: what library is actually creating the non-contiguous tensors? It seems odd to me that training would require creating non-contiguous tensors.

DeepSpeed, for instance, is not really non-contiguous; it is more that it abuses the storage system to force locality across several matmuls (which I think is meant to optimize network transport). That was easy to fix once identified, because in that situation it is easy to rework the tensors on the user's behalf, since the non-contiguity is not actually important to the model.
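
A toy illustration of that kind of layout: two weights carved out of one shared flat buffer, where only the strided (transposed) view is non-contiguous. This is a generic sketch, not DeepSpeed's actual mechanism:

    import torch

    flat = torch.randn(2 * 4 * 4)      # one backing buffer for two 4x4 weights
    w1 = flat[:16].view(4, 4)          # plain view: contiguous
    w2 = flat[16:].view(4, 4).t()      # transposed view: non-contiguous
    print(w1.is_contiguous(), w2.is_contiguous())  # True False

    # .contiguous() copies the data into fresh packed storage; this breaks
    # the sharing with `flat`, which is harmless when saving the model.
    print(w2.contiguous().is_contiguous())  # True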

@jaketae
Contributor

jaketae commented Jan 26, 2024

I ran into this issue due to a custom weight-tying scheme (the output layer is a transpose of the vocabulary embedding, so the former is not contiguous). I got around the error by turning off safe serialization, as noted above.
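
A self-contained sketch of that failure mode, with hypothetical sizes and tensor names:

    import torch.nn as nn
    from safetensors.torch import save_file

    emb = nn.Embedding(100, 16)
    # Tying the output projection to the transposed embedding: .t() returns
    # a non-contiguous view, and nn.Parameter does not repack it.
    lm_head_weight = nn.Parameter(emb.weight.data.t())
    print(lm_head_weight.is_contiguous())  # False

    # Raises: ValueError: You are trying to save a non contiguous tensor ...
    save_file({"lm_head.weight": lm_head_weight.data}, "repro.safetensors")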

@siebeniris
Author

Hi, thanks all for the comments. I have no idea why there are non-contiguous tensors in the first place; making them contiguous seems to make more sense, it solves the problem, and model training seems to go well. I find it odd that the error doesn't occur when training T5 models, only mT5 models, since mT5 is built on top of T5 in the transformers code.

@ArthurZucker
Collaborator

This is related to #28898, the PR that was reverted, I believe!

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

@ZhanGHanG9991

ZhanGHanG9991 commented Aug 7, 2024

I also encountered this problem when fine-tuning the google-t5/t5-base model and trying to save the fine-tuned model.

System Info:
Transformers version: 4.43.4
PyTorch version: 2.4.0
Python version: 3.8.19


@odunola499

odunola499 commented Aug 12, 2024

@ZhanGHanG9991 This is crude, but you can add this to your code right when you initialise your model:

    for param in model.parameters():
        param.data = param.data.contiguous()
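
A sketch of where that fits in a typical fine-tuning script (the checkpoint name is illustrative):

    from transformers import AutoModelForSeq2SeqLM

    model = AutoModelForSeq2SeqLM.from_pretrained("google-t5/t5-base")

    # Repack every parameter into contiguous memory right after loading,
    # so the later safetensors save does not trip on strided views.
    for param in model.parameters():
        param.data = param.data.contiguous()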

@OrrZwebner

@ZhanGHanG9991 This is crude, but you can add this to your code right when you initialise your model: for param in model.parameters(): param.data = param.data.contiguous()

Amazing, it worked for me! I added this in __init__(), just after initialising the model (I used a ByT5 model).

@duanfa

duanfa commented Sep 25, 2024

@ZhanGHanG9991 This is crude, but you can add this to your code right when you initialise your model: for param in model.parameters(): param.data = param.data.contiguous()

It worked! Thanks.

@StarSapph1re

@ZhanGHanG9991 This is crude, but you can add this to your code right when you initialise your model: for param in model.parameters(): param.data = param.data.contiguous()

It works! Thank you very much!

@ZhanGHanG9991

@ZhanGHanG9991 This is crude, but you can add this to your code right when you initialise your model: for param in model.parameters(): param.data = param.data.contiguous()

Thanks!!
