Revert "Add torch.serialization.skip_data context manager (pytorch#13…
Browse files Browse the repository at this point in the history
…4504)"

This reverts commit 202600b.

Reverted pytorch#134504 on behalf of https://github.com/mikaylagawarecki because it is breaking Windows docs tests: NamedTemporaryFile does not work well on Windows ([comment](pytorch#134504 (comment)))
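
As context for the failure mode named above, here is a minimal, illustrative sketch (not taken from the failing docs tests, which are not part of this diff): on Windows, a file created by `tempfile.NamedTemporaryFile` with the default `delete=True` generally cannot be reopened by name while the creating handle is still open, so examples that pass `f.name` to `torch.save`/`torch.load` inside the `with` block fail there. The usual workaround is `delete=False` plus explicit cleanup:

```python
import os
import tempfile

import torch

# Hypothetical doc-style example; it shows the pattern, not the exact failing test.
# With the default delete=True, reopening tmp.name on Windows while the original
# handle is still open typically raises PermissionError.
tmp = tempfile.NamedTemporaryFile(delete=False)  # delete=False is Windows-safe
try:
    tmp.close()  # release our handle before reopening the file by name
    torch.save({"w": torch.randn(2, 3)}, tmp.name)
    state = torch.load(tmp.name, weights_only=True)
    print(state["w"].shape)  # torch.Size([2, 3])
finally:
    os.remove(tmp.name)  # manual cleanup, since delete=False
```

The `TemporaryFileName` helper used in the tests below exists to provide a similar Windows-safe pattern.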
pytorchmergebot authored and Chao1Han committed Sep 20, 2024
1 parent b375e67 commit 86adf0c
Showing 7 changed files with 26 additions and 260 deletions.
1 change: 0 additions & 1 deletion docs/source/notes/serialization.rst
@@ -398,4 +398,3 @@ The following utility functions are related to serialization:
.. autofunction:: clear_safe_globals
.. autofunction:: get_safe_globals
.. autoclass:: safe_globals
.. autoclass:: skip_data
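
For readers unfamiliar with the feature being reverted, the sketch below is reconstructed from the tests removed further down in this diff (the API is gone again once this revert lands): under `torch.serialization.skip_data()`, `torch.save` wrote out the checkpoint structure and tensor metadata but skipped the storage bytes, and `torch.load` had to be called outside the context manager.

```python
import io

import torch
from torch.serialization import skip_data  # removed again by this revert

sd = {"w": torch.randn(2, 3)}
buf = io.BytesIO()
with skip_data():  # serialize structure and metadata, skip tensor data
    torch.save(sd, buf)
buf.seek(0)
loaded = torch.load(buf, weights_only=True)  # load outside skip_data()
print(loaded["w"].shape)  # shape/dtype preserved; values are not (zeros in the tests)
```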
12 changes: 1 addition & 11 deletions test/test_cpp_extensions_open_device_registration.py
@@ -540,7 +540,7 @@ def test_open_device_tensorlist_type_fallback(self):
# call _fused_adamw_ with undefined tensor.
self.module.fallback_with_undefined_tensor()

def test_open_device_numpy_serialization(self):
def test_open_device_numpy_serialization_map_location(self):
torch.utils.rename_privateuse1_backend("foo")
device = self.module.custom_device()
default_protocol = torch.serialization.DEFAULT_PROTOCOL
@@ -553,7 +553,6 @@ def test_open_device_numpy_serialization(self):
self.assertTrue(
rebuild_func is torch._utils._rebuild_device_tensor_from_numpy
)
# Test map_location
with TemporaryFileName() as f:
torch.save(sd, f)
with safe_globals(
@@ -570,15 +569,6 @@
sd_loaded = torch.load(f, map_location="cpu")
self.assertTrue(sd_loaded["x"].is_cpu)

# Test metadata_only
with TemporaryFileName() as f:
with self.assertRaisesRegex(
RuntimeError,
"Cannot serialize tensors on backends with no storage under skip_data context manager",
):
with torch.serialization.skip_data():
torch.save(sd, f)


if __name__ == "__main__":
common.run_tests()
88 changes: 0 additions & 88 deletions test/test_serialization.py
@@ -1,6 +1,5 @@
# Owner(s): ["module: serialization"]

import contextlib
import copy
import gc
import gzip
@@ -20,7 +19,6 @@
from pathlib import Path

import torch
from torch._subclasses.fake_tensor import FakeTensorMode, FakeTensorConverter
from torch._utils import _rebuild_tensor
from torch._utils_internal import get_file_path_2
from torch.serialization import (
@@ -29,7 +27,6 @@
LoadEndianness,
safe_globals,
set_default_load_endianness,
skip_data,
SourceChangeWarning,
)
from torch.testing._internal.common_device_type import instantiate_device_type_tests
@@ -4215,91 +4212,6 @@ def test_filewriter_metadata_writing(self, filename):
sd_loaded_ref = torch.load(f)
self.assertEqual(sd_loaded, sd_loaded_ref)

@parametrize("materialize_fake", (True, False))
def test_skip_data_serialization(self, materialize_fake):
# Create one tensor that uses each of the paths in __reduce_ex__ that should work
t_device = "cuda" if torch.cuda.is_available() else "cpu"
t_v2 = torch.randn(2, 3, device=t_device)
t_v3 = torch.randn(2, 3, dtype=torch.complex32, device=t_device)
i = torch.tensor([[0, 1, 1],
[2, 0, 2]])
v = torch.tensor([3, 4, 5], dtype=torch.float32)
if not materialize_fake:
# FakeTensorConverter messes up sizes of i and v for the sparse tensor
st = torch.sparse_coo_tensor(i, v, (2, 4))
tt = TwoTensor(torch.randn(2, device=t_device), torch.randn(2, device=t_device))

mode, converter = FakeTensorMode(), FakeTensorConverter()

def fn(t):
return converter.from_real_tensor(mode, t) if materialize_fake else t

sd = {'t_v2': fn(t_v2), 't_v3': fn(t_v3), 'tt': fn(tt)}
sd_expected = {
't_v2': torch.zeros(2, 3, device=t_device),
't_v3': torch.zeros(2, 3, dtype=torch.complex32, device=t_device),
'tt': TwoTensor(torch.zeros(2, device=t_device), torch.zeros(2, device=t_device)),
}

if not materialize_fake:
sd['st'] = st
sd_expected['st'] = torch.sparse_coo_tensor(torch.zeros(2, 3), torch.zeros(3), (2, 4))

with BytesIOContext() as f:
with skip_data(materialize_fake_tensors=materialize_fake):
torch.save(sd, f)
f.seek(0)
with safe_globals([TwoTensor]):
sd_loaded = torch.load(f, weights_only=True)
self.assertEqual(sd_loaded, sd_expected, exact_device=True)
self.assertFalse(getattr(torch.serialization._serialization_tls, "materialize_fake_tensors", False))
self.assertFalse(getattr(torch.serialization._serialization_tls, "skip_data", False))

# Test that without materialize_fake_tensor, behavior for fake_tensors is not altered by ctx
if not materialize_fake:
ft = converter.from_real_tensor(mode, torch.randn(2, device=t_device))
with self.assertRaisesRegex(AttributeError, "Can't pickle local object 'WeakValueDictionary.__init__.<locals>.remove'"):
with skip_data(), BytesIOContext() as f:
torch.save(ft, f)

@parametrize("materialize_fake", (True, False))
def test_skip_data_serialization_preserves_views(self, materialize_fake):
ctx = FakeTensorMode if materialize_fake else contextlib.nullcontext
with ctx():
t = torch.randn(2, 3)
t_view = t.view(-1)
t_slice = t[1]
sd = {'t': t, 't_view': t_view, 't_slice': t_slice}
with BytesIOContext() as f:
with skip_data(materialize_fake_tensors=materialize_fake):
torch.save(sd, f)
f.seek(0)
sd_loaded = torch.load(f, weights_only=True)
self.assertTrue(id(sd_loaded['t_view'].untyped_storage()) == id(sd_loaded['t'].untyped_storage()))
self.assertTrue(id(sd_loaded['t_slice'].untyped_storage()) == id(sd_loaded['t'].untyped_storage()))

def test_skip_data_serialization_error_cases(self):
def _save_load(t):
with BytesIOContext() as f:
with skip_data():
torch.save(t, f)
f.seek(0)
torch.load(f, weights_only=True)

nt = torch.nested.nested_tensor([torch.randn(2), torch.randn(3)])
t = torch.randn(2, 3, device="meta")
with self.assertRaisesRegex(RuntimeError, "Cannot serialize nested tensor under skip_data context manager"):
_save_load(nt)

with self.assertWarnsRegex(UserWarning, "meta device under skip_data context manager is a no-op"):
_save_load(t)

with self.assertRaisesRegex(RuntimeError, "Please call torch.load outside the skip_data context manager"):
with skip_data(), BytesIOContext() as f:
torch.save(torch.randn(2, 3), f)
f.seek(0)
torch.load(f, weights_only=True)

def run(self, *args, **kwargs):
with serialization_method(use_zip=True):
return super().run(*args, **kwargs)
66 changes: 8 additions & 58 deletions torch/_tensor.py
@@ -209,16 +209,8 @@ def __deepcopy__(self, memo):
return new_tensor

def __reduce_ex__(self, proto):
materialize_fake_tensors = (
torch.serialization._serialization_tls.materialize_fake_tensors
)
state = torch._utils._get_obj_state(self)
# Ignore all state when using FakeTensor with skip_data(materialize_fake_tensors) because FakeTensor has
# some state that cannot be pickled
if (
type(self) is torch._subclasses.fake_tensor.FakeTensor
and materialize_fake_tensors
) or (type(self) is Tensor and not state):
if type(self) is Tensor and not state:
# Fast path for regular tensor without Python state.
return self._reduce_ex_internal(proto)
if has_torch_function_unary(self):
@@ -259,12 +251,6 @@ def _reduce_ex_internal(self, proto):
# See Note [Don't serialize hooks]
warn_if_has_hooks(self)
backward_hooks: Dict[Any, Any] = OrderedDict()

skip_data = torch.serialization._serialization_tls.skip_data
materialize_fake_tensors = (
torch.serialization._serialization_tls.materialize_fake_tensors
)

# Note: Numpy array is chosen to be the rebuild component for XLA, MTIA, MAIA Tensors.
# We considered a few options:
# 1. CPU tensor can't be used here.
@@ -282,10 +268,6 @@ def _reduce_ex_internal(self, proto):
# Convert BFloat16 tensors to Float32 before conversion to numpy, as numpy doesn't
# support BFloat16. The rebuild tensor from numpy takes in the original self.dtype,
# this would reconstruct the BFloat16 tensor from numpy.
if skip_data:
raise RuntimeError(
"Cannot serialize tensors on backends with no storage under skip_data context manager"
)
numpy_tensor = (
self.cpu().numpy()
if self.dtype != torch.bfloat16
@@ -298,10 +280,6 @@ def _reduce_ex_internal(self, proto):
if self.device.type == "meta":
# NB: This implementation BREAKS storage sharing. Current
# hypothesis is that no one cares for meta tensors.
if skip_data:
warnings.warn(
"Serializing tensors on the meta device under skip_data context manager is a no-op"
)
arg_meta = (
self.dtype,
tuple(self.size()),
@@ -310,10 +288,6 @@ def _reduce_ex_internal(self, proto):
)
return (torch._utils._rebuild_meta_tensor_no_storage, arg_meta)
if self.is_quantized:
if skip_data:
raise RuntimeError(
"Cannot serialize qtensor under skip_data context manager, file an issue if you need this feature"
)
# quantizer_params can be different type based on torch attribute
quantizer_params: Union[
Tuple[torch.qscheme, float, int], Tuple[Any, Tensor, Tensor, int]
@@ -395,10 +369,6 @@ def _reduce_ex_internal(self, proto):
)
return (torch._utils._rebuild_sparse_tensor, args_sparse_compressed)
elif self.is_nested:
if skip_data:
raise RuntimeError(
"Cannot serialize nested tensor under skip_data context manager, file an issue if you need this feature"
)
args_nested = (
# NB: values() currently returns the storage as a buffer in an unsafe way.
# Ideally, we'd use a private API for this instead. TODO: Switch to this if
@@ -413,30 +383,14 @@ def _reduce_ex_internal(self, proto):
type(self) is not torch.Tensor
and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__
and (
isinstance(self, torch._subclasses.functional_tensor.FunctionalTensor)
or (
not isinstance(self, torch._subclasses.fake_tensor.FakeTensor)
and self.data_ptr() == 0
isinstance(
self,
(
torch._subclasses.fake_tensor.FakeTensor,
torch._subclasses.functional_tensor.FunctionalTensor,
),
)
)
):
arg_wrapper_subclass = (
type(self),
self.dtype,
tuple(self.size()),
self.stride(),
self.storage_offset(),
self.layout,
self.device,
self.requires_grad,
)
return (torch._utils._rebuild_wrapper_subclass, arg_wrapper_subclass)
elif (
type(self) is not torch.Tensor
and type(self).__torch_dispatch__ is not torch.Tensor.__torch_dispatch__
and (
isinstance(self, torch._subclasses.fake_tensor.FakeTensor)
and not (skip_data and materialize_fake_tensors)
or self.data_ptr() == 0
)
):
arg_wrapper_subclass = (
Expand Down Expand Up @@ -464,10 +418,6 @@ def _reduce_ex_internal(self, proto):
dtype=self.dtype,
_internal=True,
) # type: ignore[assignment]

if isinstance(self, torch._subclasses.fake_tensor.FakeTensor) and skip_data:
storage._fake_device = self.device

args = (
storage,
self.storage_offset(),
6 changes: 5 additions & 1 deletion torch/_utils.py
@@ -3,6 +3,7 @@
import functools
import logging
import sys
import threading
import traceback
import warnings
from collections import defaultdict
@@ -108,13 +109,16 @@ def _get_async_or_non_blocking(function_name, non_blocking, kwargs):
return kwargs["async"]


_thread_local_state = threading.local()


def _get_restore_location(device):
"""Return the map_location location.
Used for rebuild functions where the tensor device is distinct from the storage
"""

map_location = torch.serialization._serialization_tls.map_location
map_location = getattr(_thread_local_state, "map_location", None)
if map_location is None:
return device
else:
(Diffs for the remaining 2 of the 7 changed files were not loaded.)
