diff --git a/docs/source/notes/serialization.rst b/docs/source/notes/serialization.rst index c05dc028a471c..5541d28bdcafa 100644 --- a/docs/source/notes/serialization.rst +++ b/docs/source/notes/serialization.rst @@ -398,4 +398,3 @@ The following utility functions are related to serialization: .. autofunction:: clear_safe_globals .. autofunction:: get_safe_globals .. autoclass:: safe_globals -.. autoclass:: skip_data diff --git a/test/test_cpp_extensions_open_device_registration.py b/test/test_cpp_extensions_open_device_registration.py index 4e86ed458b078..616a6e0f4b551 100644 --- a/test/test_cpp_extensions_open_device_registration.py +++ b/test/test_cpp_extensions_open_device_registration.py @@ -540,7 +540,7 @@ def test_open_device_tensorlist_type_fallback(self): # call _fused_adamw_ with undefined tensor. self.module.fallback_with_undefined_tensor() - def test_open_device_numpy_serialization(self): + def test_open_device_numpy_serialization_map_location(self): torch.utils.rename_privateuse1_backend("foo") device = self.module.custom_device() default_protocol = torch.serialization.DEFAULT_PROTOCOL @@ -553,7 +553,6 @@ def test_open_device_numpy_serialization(self): self.assertTrue( rebuild_func is torch._utils._rebuild_device_tensor_from_numpy ) - # Test map_location with TemporaryFileName() as f: torch.save(sd, f) with safe_globals( @@ -570,15 +569,6 @@ def test_open_device_numpy_serialization(self): sd_loaded = torch.load(f, map_location="cpu") self.assertTrue(sd_loaded["x"].is_cpu) - # Test metadata_only - with TemporaryFileName() as f: - with self.assertRaisesRegex( - RuntimeError, - "Cannot serialize tensors on backends with no storage under skip_data context manager", - ): - with torch.serialization.skip_data(): - torch.save(sd, f) - if __name__ == "__main__": common.run_tests() diff --git a/test/test_serialization.py b/test/test_serialization.py index 3ba96b80541d8..a041473d195b1 100644 --- a/test/test_serialization.py +++ b/test/test_serialization.py @@ -1,6 +1,5 @@ # Owner(s): ["module: serialization"] -import contextlib import copy import gc import gzip @@ -20,7 +19,6 @@ from pathlib import Path import torch -from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensorConverter from torch._utils import _rebuild_tensor from torch._utils_internal import get_file_path_2 from torch.serialization import ( @@ -29,7 +27,6 @@ LoadEndianness, safe_globals, set_default_load_endianness, - skip_data, SourceChangeWarning, ) from torch.testing._internal.common_device_type import instantiate_device_type_tests @@ -4215,91 +4212,6 @@ def test_filewriter_metadata_writing(self, filename): sd_loaded_ref = torch.load(f) self.assertEqual(sd_loaded, sd_loaded_ref) - @parametrize("materialize_fake", (True, False)) - def test_skip_data_serialization(self, materialize_fake): - # Create one tensor that uses each of the paths in __reduce_ex__ that should work - t_device = "cuda" if torch.cuda.is_available() else "cpu" - t_v2 = torch.randn(2, 3, device=t_device) - t_v3 = torch.randn(2, 3, dtype=torch.complex32, device=t_device) - i = torch.tensor([[0, 1, 1], - [2, 0, 2]]) - v = torch.tensor([3, 4, 5], dtype=torch.float32) - if not materialize_fake: - # FakeTensorConverter messes up sizes of i and v for the sparse tensor - st = torch.sparse_coo_tensor(i, v, (2, 4)) - tt = TwoTensor(torch.randn(2, device=t_device), torch.randn(2, device=t_device)) - - mode, converter = FakeTensorMode(), FakeTensorConverter() - - def fn(t): - return converter.from_real_tensor(mode, t) if materialize_fake else t - - sd = {'t_v2': fn(t_v2), 't_v3': fn(t_v3), 'tt': fn(tt)} - sd_expected = { - 't_v2': torch.zeros(2, 3, device=t_device), - 't_v3': torch.zeros(2, 3, dtype=torch.complex32, device=t_device), - 'tt': TwoTensor(torch.zeros(2, device=t_device), torch.zeros(2, device=t_device)), - } - - if not materialize_fake: - sd['st'] = st - sd_expected['st'] = torch.sparse_coo_tensor(torch.zeros(2, 3), torch.zeros(3), (2, 4)) - - with BytesIOContext() as f: - with skip_data(materialize_fake_tensors=materialize_fake): - torch.save(sd, f) - f.seek(0) - with safe_globals([TwoTensor]): - sd_loaded = torch.load(f, weights_only=True) - self.assertEqual(sd_loaded, sd_expected, exact_device=True) - self.assertFalse(getattr(torch.serialization._serialization_tls, "materialize_fake_tensors", False)) - self.assertFalse(getattr(torch.serialization._serialization_tls, "skip_data", False)) - - # Test that without materialize_fake_tensor, behavior for fake_tensors is not altered by ctx - if not materialize_fake: - ft = converter.from_real_tensor(mode, torch.randn(2, device=t_device)) - with self.assertRaisesRegex(AttributeError, "Can't pickle local object 'WeakValueDictionary.__init__..remove'"): - with skip_data(), BytesIOContext() as f: - torch.save(ft, f) - - @parametrize("materialize_fake", (True, False)) - def test_skip_data_serialization_preserves_views(self, materialize_fake): - ctx = FakeTensorMode if materialize_fake else contextlib.nullcontext - with ctx(): - t = torch.randn(2, 3) - t_view = t.view(-1) - t_slice = t[1] - sd = {'t': t, 't_view': t_view, 't_slice': t_slice} - with BytesIOContext() as f: - with skip_data(materialize_fake_tensors=materialize_fake): - torch.save(sd, f) - f.seek(0) - sd_loaded = torch.load(f, weights_only=True) - self.assertTrue(id(sd_loaded['t_view'].untyped_storage()) == id(sd_loaded['t'].untyped_storage())) - self.assertTrue(id(sd_loaded['t_slice'].untyped_storage()) == id(sd_loaded['t'].untyped_storage())) - - def test_skip_data_serialization_error_cases(self): - def _save_load(t): - with BytesIOContext() as f: - with skip_data(): - torch.save(t, f) - f.seek(0) - torch.load(f, weights_only=True) - - nt = torch.nested.nested_tensor([torch.randn(2), torch.randn(3)]) - t = torch.randn(2, 3, device="meta") - with self.assertRaisesRegex(RuntimeError, "Cannot serialize nested tensor under skip_data context manager"): - _save_load(nt) - - with self.assertWarnsRegex(UserWarning, "meta device under skip_data context manager is a no-op"): - _save_load(t) - - with self.assertRaisesRegex(RuntimeError, "Please call torch.load outside the skip_data context manager"): - with skip_data(), BytesIOContext() as f: - torch.save(torch.randn(2, 3), f) - f.seek(0) - torch.load(f, weights_only=True) - def run(self, *args, **kwargs): with serialization_method(use_zip=True): return super().run(*args, **kwargs) diff --git a/torch/_tensor.py b/torch/_tensor.py index 61d5e3891b9e5..98563aebae9aa 100644 --- a/torch/_tensor.py +++ b/torch/_tensor.py @@ -209,16 +209,8 @@ def __deepcopy__(self, memo): return new_tensor def __reduce_ex__(self, proto): - materialize_fake_tensors = ( - torch.serialization._serialization_tls.materialize_fake_tensors - ) state = torch._utils._get_obj_state(self) - # Ignore all state when using FakeTensor with skip_data(materialize_fake_tensors) because FakeTensor has - # some state that cannot be pickled - if ( - type(self) is torch._subclasses.fake_tensor.FakeTensor - and materialize_fake_tensors - ) or (type(self) is Tensor and not state): + if type(self) is Tensor and not state: # Fast path for regular tensor without Python state. return self._reduce_ex_internal(proto) if has_torch_function_unary(self): @@ -259,12 +251,6 @@ def _reduce_ex_internal(self, proto): # See Note [Don't serialize hooks] warn_if_has_hooks(self) backward_hooks: Dict[Any, Any] = OrderedDict() - - skip_data = torch.serialization._serialization_tls.skip_data - materialize_fake_tensors = ( - torch.serialization._serialization_tls.materialize_fake_tensors - ) - # Note: Numpy array is chosen to be the rebuild component for XLA, MTIA, MAIA Tensors. # We considered a few options: # 1. CPU tensor can't be used here. @@ -282,10 +268,6 @@ def _reduce_ex_internal(self, proto): # Convert BFloat16 tesors to Float32 before conversion to numpy, as numpy doesn't # support BFloat16. The rebuild tensor from numpy takes in the original self.dtype, # this would reconstruct the BFloat16 tensor from numpy. - if skip_data: - raise RuntimeError( - "Cannot serialize tensors on backends with no storage under skip_data context manager" - ) numpy_tensor = ( self.cpu().numpy() if self.dtype != torch.bfloat16 @@ -298,10 +280,6 @@ def _reduce_ex_internal(self, proto): if self.device.type == "meta": # NB: This implementation BREAKS storage sharing. Current # hypothesis is that no one cares for meta tensors. - if skip_data: - warnings.warn( - "Serializing tensors on the meta device under skip_data context manager is a no-op" - ) arg_meta = ( self.dtype, tuple(self.size()), @@ -310,10 +288,6 @@ def _reduce_ex_internal(self, proto): ) return (torch._utils._rebuild_meta_tensor_no_storage, arg_meta) if self.is_quantized: - if skip_data: - raise RuntimeError( - "Cannot serialize qtensor under skip_data context manager, file an issue if you need this feature" - ) # quantizer_params can be different type based on torch attribute quantizer_params: Union[ Tuple[torch.qscheme, float, int], Tuple[Any, Tensor, Tensor, int] @@ -395,10 +369,6 @@ def _reduce_ex_internal(self, proto): ) return (torch._utils._rebuild_sparse_tensor, args_sparse_compressed) elif self.is_nested: - if skip_data: - raise RuntimeError( - "Cannot serialize nested tensor under skip_data context manager, file an issue if you need this feature" - ) args_nested = ( # NB: values() currently returns the storage as a buffer in an unsafe way. # Ideally, we'd use a private API for this instead. TODO: Switch to this if @@ -413,30 +383,14 @@ def _reduce_ex_internal(self, proto): type(self) is not torch.Tensor and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__ and ( - isinstance(self, torch._subclasses.functional_tensor.FunctionalTensor) - or ( - not isinstance(self, torch._subclasses.fake_tensor.FakeTensor) - and self.data_ptr() == 0 + isinstance( + self, + ( + torch._subclasses.fake_tensor.FakeTensor, + torch._subclasses.functional_tensor.FunctionalTensor, + ), ) - ) - ): - arg_wrapper_subclass = ( - type(self), - self.dtype, - tuple(self.size()), - self.stride(), - self.storage_offset(), - self.layout, - self.device, - self.requires_grad, - ) - return (torch._utils._rebuild_wrapper_subclass, arg_wrapper_subclass) - elif ( - type(self) is not torch.Tensor - and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__ - and ( - isinstance(self, torch._subclasses.fake_tensor.FakeTensor) - and not (skip_data and materialize_fake_tensors) + or self.data_ptr() == 0 ) ): arg_wrapper_subclass = ( @@ -464,10 +418,6 @@ def _reduce_ex_internal(self, proto): dtype=self.dtype, _internal=True, ) # type: ignore[assignment] - - if isinstance(self, torch._subclasses.fake_tensor.FakeTensor) and skip_data: - storage._fake_device = self.device - args = ( storage, self.storage_offset(), diff --git a/torch/_utils.py b/torch/_utils.py index f0d38daa81149..938392fa97159 100644 --- a/torch/_utils.py +++ b/torch/_utils.py @@ -3,6 +3,7 @@ import functools import logging import sys +import threading import traceback import warnings from collections import defaultdict @@ -108,13 +109,16 @@ def _get_async_or_non_blocking(function_name, non_blocking, kwargs): return kwargs["async"] +_thread_local_state = threading.local() + + def _get_restore_location(device): """Return the map_location location. Used for rebuild functions where the tensor device is distinct from the storage """ - map_location = torch.serialization._serialization_tls.map_location + map_location = getattr(_thread_local_state, "map_location", None) if map_location is None: return device else: diff --git a/torch/serialization.py b/torch/serialization.py index 954596ade61f5..2ac36ec371fe5 100644 --- a/torch/serialization.py +++ b/torch/serialization.py @@ -11,7 +11,6 @@ import sys import tarfile import tempfile -import threading import warnings from contextlib import closing, contextmanager from enum import Enum @@ -61,7 +60,6 @@ "get_safe_globals", "add_safe_globals", "safe_globals", - "skip_data", ] @@ -89,22 +87,6 @@ MAP_SHARED, MAP_PRIVATE = None, None # type: ignore[assignment] -# _serialization_tls is used to store thread local state specific to serialization -# that needs to be propagated to other files, in particular we use this for -# (1) map_location (needed for wrapper subclasses/third party devices to torch._utils) -# (2) skip_data (needed for torch.Tensor.__reduce_ex__ for skip_data ctx) -# (3) materialize_fake_tensors (needed for torch.Tensor.__reduce_ex__ for skip_data ctx) -class _SerializationLocal(threading.local): - def __init__(self): - super().__init__() - self.map_location: Optional[MAP_LOCATION] = None - self.skip_data: bool = False - self.materialize_fake_tensors: bool = False - - -_serialization_tls = _SerializationLocal() - - class SourceChangeWarning(Warning): pass @@ -286,46 +268,6 @@ class safe_globals(_weights_only_unpickler._safe_globals): """ -class skip_data: - """ - Context-manager that skips writing storage bytes for ``torch.save`` calls. - - Storages will still be saved, but the space that their bytes would usually be written to - will be empty space. The storage bytes can then be populated in a separate pass. - - .. warning:: - The ``skip_data`` context manager is an early prototype and is subject to change. - - Args: - materialize_fake_tensors: Whether to materialize FakeTensors. - - Example: - >>> import tempfile - >>> t = torch.randn(2, 3) - >>> with tempfile.NamedTemporaryFile() as f: - ... with torch.serialization.skip_data(): - ... torch.save(t, f.name) - ... torch.load(f.name, weights_only=True) - tensor([[0., 0., 0.], - [0., 0., 0.]]) - """ - - def __init__(self, materialize_fake_tensors: bool = False): - self.materialize_fake_tensors = materialize_fake_tensors - - def __enter__(self): - global _serialization_tls - self._old_skip_data = _serialization_tls.skip_data - self._old_materialize_fake_tensors = _serialization_tls.materialize_fake_tensors - _serialization_tls.skip_data = True - _serialization_tls.materialize_fake_tensors = self.materialize_fake_tensors - - def __exit__(self, type, value, tb): - global _serialization_tls - _serialization_tls.skip_data = self._old_skip_data - _serialization_tls.materialize_fake_tensors = self._old_materialize_fake_tensors - - def _is_zipfile(f) -> bool: # This is a stricter implementation than zipfile.is_zipfile(). # zipfile.is_zipfile() is True if the magic number appears anywhere in the @@ -855,11 +797,6 @@ def save( ) return else: - global _serialization_tls - if _serialization_tls.skip_data: - raise RuntimeError( - "Cannot use skip_data=True with _use_new_zipfile_serialization=False" - ) with _open_file_like(f, "wb") as opened_file: _legacy_save(obj, opened_file, pickle_module, pickle_protocol) @@ -1018,13 +955,7 @@ def persistent_id(obj: Any) -> Optional[Tuple]: ) -def _save( - obj, - zip_file, - pickle_module, - pickle_protocol, - _disable_byteorder_record, -): +def _save(obj, zip_file, pickle_module, pickle_protocol, _disable_byteorder_record): serialized_storages = {} id_map: Dict[int, str] = {} @@ -1059,7 +990,7 @@ def persistent_id(obj): # If storage is allocated, ensure that any other saved storages # pointing to the same data all have the same dtype. If storage is # not allocated, don't perform this check - if str(storage.device) != "meta" and storage.data_ptr() != 0: + if storage.data_ptr() != 0: if storage.data_ptr() in storage_dtypes: if storage_dtype != storage_dtypes[storage.data_ptr()]: raise RuntimeError( @@ -1070,10 +1001,7 @@ def persistent_id(obj): storage_dtypes[storage.data_ptr()] = storage_dtype storage_key = id_map.setdefault(storage._cdata, str(len(id_map))) - if hasattr(obj, "_fake_device") and obj._fake_device is not None: - location = str(obj._fake_device) - else: - location = location_tag(storage) + location = location_tag(storage) serialized_storages[storage_key] = storage return ("storage", storage_type, storage_key, location, storage_numel) @@ -1099,18 +1027,14 @@ def persistent_id(obj): for key in sorted(serialized_storages.keys()): name = f"data/{key}" storage = serialized_storages[key] + # given that we copy things around anyway, we might use storage.cpu() + # this means to that to get tensors serialized, you need to implement + # .cpu() on the underlying Storage + if storage.device.type != "cpu": + storage = storage.cpu() + # Now that it is on the CPU we can directly copy it into the zip file num_bytes = storage.nbytes() - global _serialization_tls - if _serialization_tls.skip_data: - zip_file.write_record_metadata(name, num_bytes) - else: - # given that we copy things around anyway, we might use storage.cpu() - # this means to that to get tensors serialized, you need to implement - # .cpu() on the underlying Storage - if storage.device.type != "cpu": - storage = storage.cpu() - # Now that it is on the CPU we can directly copy it into the zip file - zip_file.write_record(name, storage, num_bytes) + zip_file.write_record(name, storage, num_bytes) def load( @@ -1260,14 +1184,6 @@ def _get_wo_message(message: str) -> str: updated_message += message return updated_message + DOCS_MESSAGE - global _serialization_tls - skip_data = _serialization_tls.skip_data - if skip_data: - raise RuntimeError( - "`torch.load` called within a torch.serialization.skip_data context manager " - "is not supported yet. Please call torch.load outside the skip_data context manager." - ) - if weights_only is None: weights_only, warn_weights_only = False, True else: @@ -1842,10 +1758,9 @@ def find_class(self, mod_name, name): unpickler.persistent_load = persistent_load # Needed for tensors where storage device and rebuild tensor device are # not connected (wrapper subclasses and tensors rebuilt using numpy) - global _serialization_tls - _serialization_tls.map_location = map_location + torch._utils._thread_local_state.map_location = map_location result = unpickler.load() - _serialization_tls.map_location = None + del torch._utils._thread_local_state.map_location torch._utils._validate_loaded_sparse_tensors() torch._C._log_api_usage_metadata( diff --git a/torch/storage.py b/torch/storage.py index 8848649905f93..b6ba608c16e5c 100644 --- a/torch/storage.py +++ b/torch/storage.py @@ -39,8 +39,6 @@ class _StorageBase: is_sparse: _bool = False is_sparse_csr: _bool = False device: torch.device - # Used when stashing FakeTensor device onto storage in torch.save(metadata_only=True) - _fake_device: _Optional[torch.device] = None def __init__(self, *args, **kwargs): pass @@ -651,8 +649,6 @@ def _get_device_from_module(module: str): class TypedStorage: is_sparse: _bool = False - # Used when stashing FakeTensor device onto storage in torch.save(metadata_only=True) - _fake_device: _Optional[torch.device] = None dtype: torch.dtype