
"Attempted to access the data pointer on an invalid python storage" when saving model in TPU mode (Kaggle) #27578

Closed
Zaphat opened this issue Nov 18, 2023 · 8 comments

Zaphat commented Nov 18, 2023

System Info

It keeps happening whenever I try to use TPU mode to fine-tune a BERT model for sentiment analysis. Everything works fine in GPU mode. I even tried downgrading/upgrading TensorFlow & safetensors, but that didn't work either. Can you give me any suggestions?

Link to that notebook: https://www.kaggle.com/code/phttrnnguyngia/final

trainer.save_model('final-result')

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
File /kaggle/working/env/safetensors/torch.py:13, in storage_ptr(tensor)
     12 try:
---> 13     return tensor.untyped_storage().data_ptr()
     14 except Exception:
     15     # Fallback for torch==1.10

RuntimeError: Attempted to access the data pointer on an invalid python storage.

During handling of the above exception, another exception occurred:

RuntimeError                              Traceback (most recent call last)
Cell In[21], line 2
      1 # save the model
----> 2 trainer.save_model('final-result')

File /kaggle/working/env/transformers/trainer.py:2804, in Trainer.save_model(self, output_dir, _internal_call)
   2801     output_dir = self.args.output_dir
   2803 if is_torch_tpu_available():
-> 2804     self._save_tpu(output_dir)
   2805 elif is_sagemaker_mp_enabled():
   2806     # Calling the state_dict needs to be done on the wrapped model and on all processes.
   2807     os.makedirs(output_dir, exist_ok=True)

File /kaggle/working/env/transformers/trainer.py:2873, in Trainer._save_tpu(self, output_dir)
   2871         xm.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
   2872 else:
-> 2873     self.model.save_pretrained(output_dir, is_main_process=self.args.should_save, save_function=xm.save)
   2874 if self.tokenizer is not None and self.args.should_save:
   2875     self.tokenizer.save_pretrained(output_dir)

File /kaggle/working/env/transformers/modeling_utils.py:2187, in PreTrainedModel.save_pretrained(self, save_directory, is_main_process, state_dict, save_function, push_to_hub, max_shard_size, safe_serialization, variant, token, save_peft_format, **kwargs)
   2183 for shard_file, shard in shards.items():
   2184     if safe_serialization:
   2185         # At some point we will need to deal better with save_function (used for TPU and other distributed
   2186         # joyfulness), but for now this enough.
-> 2187         safe_save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"})
   2188     else:
   2189         save_function(shard, os.path.join(save_directory, shard_file))

File /kaggle/working/env/safetensors/torch.py:281, in save_file(tensors, filename, metadata)
    250 def save_file(
    251     tensors: Dict[str, torch.Tensor],
    252     filename: Union[str, os.PathLike],
    253     metadata: Optional[Dict[str, str]] = None,
    254 ):
    255     """
    256     Saves a dictionary of tensors into raw bytes in safetensors format.
    257 
   (...)
    279     ```
    280     """
--> 281     serialize_file(_flatten(tensors), filename, metadata=metadata)

File /kaggle/working/env/safetensors/torch.py:460, in _flatten(tensors)
    453 if invalid_tensors:
    454     raise ValueError(
    455         f"You are trying to save a sparse tensors: `{invalid_tensors}` which this library does not support."
    456         " You can make it a dense tensor before saving with `.to_dense()` but be aware this might"
    457         " make a much larger file than needed."
    458     )
--> 460 shared_pointers = _find_shared_tensors(tensors)
    461 failing = []
    462 for names in shared_pointers:

File /kaggle/working/env/safetensors/torch.py:72, in _find_shared_tensors(state_dict)
     70 tensors = defaultdict(set)
     71 for k, v in state_dict.items():
---> 72     if v.device != torch.device("meta") and storage_ptr(v) != 0 and storage_size(v) != 0:
     73         # Need to add device as key because of multiple GPU.
     74         tensors[(v.device, storage_ptr(v), storage_size(v))].add(k)
     75 tensors = list(sorted(tensors.values()))

File /kaggle/working/env/safetensors/torch.py:17, in storage_ptr(tensor)
     14 except Exception:
     15     # Fallback for torch==1.10
     16     try:
---> 17         return tensor.storage().data_ptr()
     18     except NotImplementedError:
     19         # Fallback for meta storage
     20         return 0

File /kaggle/working/env/torch/storage.py:909, in TypedStorage.data_ptr(self)
    907 def data_ptr(self):
    908     _warn_typed_storage_removal()
--> 909     return self._data_ptr()

File /kaggle/working/env/torch/storage.py:913, in TypedStorage._data_ptr(self)
    912 def _data_ptr(self):
--> 913     return self._untyped_storage.data_ptr()

RuntimeError: Attempted to access the data pointer on an invalid python storage.

Who can help?

No response

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • An officially supported task in the examples folder (such as GLUE/SQuAD, ...)
  • My own task or dataset (give details below)

Reproduction

Run the notebook on a Kaggle TPU (environment: "Always use latest environment"). The input data is included in the notebook.

Expected behavior

The model is expected to save successfully, as it does when using the GPU.

ArthurZucker (Collaborator) commented Nov 20, 2023

I know you provided a link to the notebook, but a minimal reproducer would still be welcome here! 🤗

cc @LysandreJik this might be related to the latest changes? Do you want to have a look?

LysandreJik (Member) commented

Would be eager to hear your thoughts on it @Narsil

robocan commented Dec 3, 2023

Hello,
I'm facing the same issue on Kaggle TPUs too. Did anybody find a solution for it?
Thanks

Narsil (Contributor) commented Dec 4, 2023

It seems there's a bug in torch itself here, since safetensors only uses the public API:

import torch
import torch_xla
import torch_xla.core.xla_model as xm

dev = xm.xla_device()

A = torch.zeros((2, 2), device=dev)
A.untyped_storage()
# <--- Crashes with: Attempted to set the storage of a tensor on device "cpu"
# to a storage on different device "xla:0". This is no longer allowed; the
# devices must match.

The CPU fix #27799 will work, but only by moving everything to CPU, which isn't desirable IMO.

Do we have XLA/torch experts who could shed some light on how to detect an XLA tensor specifically? (I would implement the same move to CPU in safetensors if the tensor is on an XLA device.)

This could also easily be reported as a bug to PyTorch, no? @LysandreJik

Minimal repro: https://colab.research.google.com/drive/1O9EqLD-Vfp7PGGldNeJtRtq3oUpLnOJV?usp=sharing (choose the TPU runtime)
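For context, the reason safetensors touches the storage pointer at all is to detect tensors that alias the same buffer. A minimal CPU-only illustration (my sketch, not the safetensors source; on XLA tensors this same call is what raises the error above):

```python
import torch

# Two views of the same buffer report the same storage pointer; this is
# how shared (tied) weights are grouped before serialization.
a = torch.zeros(4)
b = a[:2]            # view sharing a's storage
c = torch.zeros(4)   # independent buffer

print(a.untyped_storage().data_ptr() == b.untyped_storage().data_ptr())  # True
print(a.untyped_storage().data_ptr() == c.untyped_storage().data_ptr())  # False
```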

qihqi (Contributor) commented Dec 8, 2023

Do we have XLA/torch experts who could shed some light on how to detect an XLA tensor specifically?

You can check whether tensor.device.type == 'xla'
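Building on that check, a minimal sketch of a helper (my own, hypothetical function, not an existing transformers/safetensors API) that copies XLA-device tensors to host memory so their storage pointers become readable before serialization:

```python
import torch

def xla_state_dict_to_cpu(state_dict):
    """Copy any XLA-device tensors to CPU before serialization;
    tensors on other devices (cpu, cuda) pass through unchanged."""
    return {
        name: t.to("cpu") if t.device.type == "xla" else t
        for name, t in state_dict.items()
    }

# On a CPU-only runtime this is a no-op; on a TPU runtime the XLA
# tensors would be materialized on the host first.
sd = {"weight": torch.ones(2, 2)}
out = xla_state_dict_to_cpu(sd)
print(out["weight"].device.type)  # cpu
```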

yeounoh (Contributor) commented Dec 14, 2023

Also, could someone help land #27993? It could resolve this issue.


github-actions bot commented Jan 8, 2024

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

ArthurZucker (Collaborator) commented

As the PR was merged, let's close this.

jeffhataws added a commit to jeffhataws/transformers that referenced this issue Mar 17, 2024
save_safetensors=True is the default as of release 4.35.0, which then
required TPU hotfix huggingface#27799
(issue huggingface#27578).
However, when the flag save_safetensors is set to False (compatibility mode),
moving the model to CPU causes generation of too many graphs
during checkpointing (huggingface#28438).
This PR disables moving the model to CPU when save_safetensors=False.
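For anyone hitting the original error before upgrading, a possible stopgap (my sketch of the compatibility mode this commit message describes; it trades safetensors serialization for the plain xm.save path, with the extra-graphs caveat noted above) is to disable safetensors saving in the training arguments:

```python
from transformers import TrainingArguments

# Compatibility mode: skip safetensors serialization so the TPU-aware
# save_function (xm.save) is used directly, avoiding the storage-pointer
# access on XLA tensors.
args = TrainingArguments(
    output_dir="final-result",
    save_safetensors=False,
)
```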
jeffhataws added a commit to jeffhataws/transformers that referenced this issue Mar 17, 2024
amyeroberts pushed a commit that referenced this issue Apr 24, 2024
itazap pushed a commit that referenced this issue May 14, 2024