Add torch.serialization.skip_data context manager #134504
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/134504
Note: Links to docs will display an error until the docs builds have been completed. ✅ No failures as of commit 0cc1afb with merge base 5a0e7a4. This comment was automatically generated by Dr. CI and updates every 15 minutes.
The semantic is

(1) If real tensors are passed, storage bytes will not be written, but zipfile metadata + sufficient space will be reserved in the checkpoint for them

```python
import torch
import torch.nn as nn

sd = nn.Linear(3, 5).state_dict()
torch.save(sd, 'foo.pt', metadata_only=True)
print(torch.load('foo.pt', weights_only=True))
# OrderedDict([('weight', tensor([[0., 0., 0.],
#                                 [0., 0., 0.],
#                                 [0., 0., 0.],
#                                 [0., 0., 0.],
#                                 [0., 0., 0.]])),
#              ('bias', tensor([0., 0., 0., 0., 0.]))])
```

(2) If FakeTensor is passed, space will be saved in the checkpoint for the associated storage

```python
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    m = nn.Linear(3, 5, dtype=torch.float16, device='cuda')
    sd = m.state_dict()

torch.save(sd, 'bla.pt', metadata_only=True)
print(torch.load('bla.pt', weights_only=True))
# OrderedDict([('weight', tensor([[0., 0., 0.],
#                                 [0., 0., 0.],
#                                 [0., 0., 0.],
#                                 [0., 0., 0.],
#                                 [0., 0., 0.]], device='cuda:0', dtype=torch.float16)),
#              ('bias', tensor([0., 0., 0., 0., 0.], device='cuda:0', dtype=torch.float16))])
```

[ghstack-poisoned]
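To make the reserved-space claim concrete, here is a small check one could run at this point in the stack (the `metadata_only` flag is the API as proposed above; the near-equal file sizes are the expected consequence of reserving space, not verified output):

```python
import os

import torch
import torch.nn as nn

sd = nn.Linear(3, 5).state_dict()
torch.save(sd, 'full.pt')                      # storage bytes written
torch.save(sd, 'meta.pt', metadata_only=True)  # storage bytes skipped, space reserved
# Since zipfile metadata plus space for the storage bytes is still written,
# the two checkpoints should be roughly the same size.
print(os.path.getsize('full.pt'), os.path.getsize('meta.pt'))
```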
```diff
@@ -978,7 +1004,7 @@ def persistent_id(obj):
         # If storage is allocated, ensure that any other saved storages
         # pointing to the same data all have the same dtype. If storage is
         # not allocated, don't perform this check
-        if storage.data_ptr() != 0:
+        if str(storage.device) != "meta" and storage.data_ptr() != 0:
```
Using `str(storage.device) != "meta"` to avoid the deprecation message mentioned above when accessing the data_ptr of a FakeTensor's storage.
...are there other cases where the device is not meta but the storage data_ptr is 0?
XLA, IIRC
OK, then I think this fix is good, since it avoids accessing data_ptr specifically for FakeTensor storage.
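For reference, a minimal sketch of the guarded check in context (simplified from `persistent_id`; the `storage_dtypes` bookkeeping and the error message are assumptions based on the surrounding code, not an exact excerpt):

```python
def check_shared_storage_dtype(storage, storage_dtype, storage_dtypes):
    # Meta-device storages (which is what FakeTensor storages report) are
    # skipped entirely, so their data_ptr is never accessed and the
    # deprecation warning mentioned above is avoided.
    if str(storage.device) != "meta" and storage.data_ptr() != 0:
        ptr = storage.data_ptr()
        if ptr in storage_dtypes:
            if storage_dtype != storage_dtypes[ptr]:
                raise RuntimeError(
                    "Cannot save multiple tensors or storages that "
                    "view the same data as different types"
                )
        else:
            storage_dtypes[ptr] = storage_dtype
```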
The semantic is

(1) By default, `torch.serialization.skip_data(materialize_fake_tensors=False)` will make `torch.save` skip writing storages (but reserve space for them in the checkpoint).

```python
import torch
import torch.nn as nn

sd = nn.Linear(3, 5).state_dict()
with torch.serialization.skip_data():
    torch.save(sd, 'foo.pt')
print(torch.load('foo.pt', weights_only=True))
```

(2) With `torch.serialization.skip_data(materialize_fake_tensors=True)`, if FakeTensor is passed to `torch.save()`, the pickler will treat these FakeTensors as being "materialized": space will be reserved in the checkpoint for the associated storage bytes, and when loading, the type will be Tensor instead of FakeTensor.

```python
import torch
import torch.nn as nn
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    m = nn.Linear(3, 5, dtype=torch.float16, device='cuda')
    sd = m.state_dict()

with torch.serialization.skip_data(materialize_fake_tensors=True):
    torch.save(sd, 'bla.pt')

print(torch.load('bla.pt', weights_only=True))
# OrderedDict([('weight', tensor([[0., 0., 0.],
#                                 [0., 0., 0.],
#                                 [0., 0., 0.],
#                                 [0., 0., 0.],
#                                 [0., 0., 0.]], device='cuda:0', dtype=torch.float16)),
#              ('bias', tensor([0., 0., 0., 0., 0.], device='cuda:0', dtype=torch.float16))])
```

[ghstack-poisoned]
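A quick sanity check for case (1): tensor metadata round-trips even though no storage bytes are written (that the values are unspecified is an assumption of the reserved-space semantic, not documented behavior):

```python
import torch
import torch.nn as nn

sd = nn.Linear(3, 5).state_dict()
with torch.serialization.skip_data():
    torch.save(sd, 'foo.pt')

loaded = torch.load('foo.pt', weights_only=True)
for name, t in loaded.items():
    # Shapes and dtypes survive the round trip; the values come from the
    # reserved-but-unwritten region, so do not rely on them.
    assert t.shape == sd[name].shape and t.dtype == sd[name].dtype
```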
```python
state = torch._utils._get_obj_state(self)
if type(self) is Tensor and not state:
    # Ignore all state when using FakeTensor with skip_data(materialize_fake_tensors) because FakeTensor has
```
I think this is sub-optimal in the long term. We should have a clearer contract on how we go from a FakeTensor to a real Tensor with no data.
Dropping every single field from it might not be acceptable for everyone, and it might be interesting if there were a method to extract such a Tensor from it easily.
This sounds OK as a temporary solution here, I think.
I agree. Technically what we want is to drop every attribute that is in the dict of a "normal" FakeTensor but keep anything else, is that right?
But agreed that for now it seems tricky to handle this, because I'm not sure that what is in a "normal" FakeTensor's dict is stable.
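For illustration, a hedged sketch of that alternative; the baseline attribute names below are hypothetical, which is exactly the stability problem mentioned:

```python
def strip_plain_fake_tensor_state(state: dict) -> dict:
    # Hypothetical helper: drop only the attributes a "plain" FakeTensor is
    # assumed to carry, and keep any user-added state. This set is an
    # illustration, not a stable FakeTensor contract.
    plain_fake_tensor_attrs = {"fake_device", "fake_mode", "constant", "real_tensor"}
    return {k: v for k, v in state.items() if k not in plain_fake_tensor_attrs}
```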
torch/_tensor.py (Outdated)
```python
),
(
    isinstance(
        self, torch._subclasses.functional_tensor.FunctionalTensor
```
Should we just split functional + fake Tensor into another condition just above? I have to admit I can't read this condition anymore :D
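Something like the following, perhaps (a sketch of the suggested restructuring; the surrounding condition is not shown in this hunk, so placement and names are illustrative):

```python
# Hypothetical split: name the subclass checks instead of folding them
# into one large boolean expression.
is_functional = isinstance(
    self, torch._subclasses.functional_tensor.FunctionalTensor
)
is_fake = isinstance(self, torch._subclasses.fake_tensor.FakeTensor)
if is_functional or is_fake:
    ...  # handle these subclasses in their own branch
```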
torch/serialization.py (Outdated)
```python
# (1) map_location (needed for wrapper subclasses/third party devices to torch._utils)
# (2) skip_data (needed for torch.Tensor.__reduce_ex__ for skip_data ctx)
# (3) materialize_fake_tensors (needed for torch.Tensor.__reduce_ex__ for skip_data ctx)
_serialization_tls = threading.local()
```
Oh, I would have expected that you just did `torch._C._stash_obj_in_tls("_serialization_tls", _serialization_tls)`
here and nothing else below?
Or, if we don't know how to do that, it's OK to remove it all and open an issue as a TODO for later.
Removed, and opened an issue here: #134680
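For context, a minimal sketch of module-level serialization TLS as described in the hunk above (field names come from the comment in the diff; the context-manager wiring is an assumed illustration, not the merged implementation):

```python
import threading

# Thread-local state consulted during (de)serialization.
_serialization_tls = threading.local()
_serialization_tls.map_location = None
_serialization_tls.skip_data = False
_serialization_tls.materialize_fake_tensors = False

class skip_data:
    def __init__(self, materialize_fake_tensors: bool = False):
        self.materialize_fake_tensors = materialize_fake_tensors

    def __enter__(self):
        # Flags are read by torch.save / torch.Tensor.__reduce_ex__.
        _serialization_tls.skip_data = True
        _serialization_tls.materialize_fake_tensors = self.materialize_fake_tensors

    def __exit__(self, exc_type, exc_value, traceback):
        _serialization_tls.skip_data = False
        _serialization_tls.materialize_fake_tensors = False
```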
@mikaylagawarecki has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@mikaylagawarecki has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
The repro given no longer fails; see D62238610.
SGTM!
@pytorchbot merge
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
@pytorchbot merge
The merge job was canceled or timed out. This most often happens if two merge requests were issued for the same PR, or if the merge job was waiting for more than 6 hours for tests to finish. In the latter case, please do not hesitate to reissue the merge command.
Merge started. Your change will be merged once all checks pass (ETA 0-4 hours). Learn more about merging in the wiki. Questions? Feedback? Please reach out to the PyTorch DevX Team.
…4504)" This reverts commit 202600b. Reverted pytorch#134504 on behalf of https://github.com/mikaylagawarecki due to This is breaking Windows docs tests due to NamedTemporaryFile on Windows not working well ([comment](pytorch#134504 (comment)))
…4504)" This reverts commit 94db935. Reverted pytorch#134504 on behalf of https://github.com/kit1980 due to See D62082697 ([comment](pytorch#134504 (comment)))
## Semantic

(1) By default, `torch.serialization.skip_data(materialize_fake_tensors=False)` will make `torch.save` skip writing storages (but reserve space for them in the checkpoint).

(2) With `torch.serialization.skip_data(materialize_fake_tensors=True)`, if FakeTensor is passed to `torch.save`, the pickler will treat these FakeTensors as being "materialized": space will be reserved in the checkpoint for the associated storage bytes, and when loading, the type will be Tensor instead of FakeTensor.

## Follow Ups

- [ ] `torch.load` semantic for the skip_data context manager
- [ ] Mechanism for getting offsets of storages saved via this method (for writing in a separate pass)

Stack from ghstack (oldest at bottom):

Differential Revision: [D62238610](https://our.internmc.facebook.com/intern/diff/D62238610)