From 26eb163c64e6b0947fcedc2376f01d27b2345149 Mon Sep 17 00:00:00 2001 From: generatedunixname89002005307016 Date: Wed, 16 Oct 2024 18:34:25 -0700 Subject: [PATCH] Add type error suppressions for upcoming upgrade Reviewed By: MaggieMoss Differential Revision: D64502856 fbshipit-source-id: 82c7bc157cfed0539319785987d56946b1aac94c --- tests/test_memoryview_stream.py | 1 + tests/test_snapshot.py | 1 + torchsnapshot/memoryview_stream.py | 4 ++++ torchsnapshot/serialization.py | 9 ++++++++- 4 files changed, 14 insertions(+), 1 deletion(-) diff --git a/tests/test_memoryview_stream.py b/tests/test_memoryview_stream.py index 641b057..35d47e3 100644 --- a/tests/test_memoryview_stream.py +++ b/tests/test_memoryview_stream.py @@ -17,6 +17,7 @@ class MemoryviewStreamTest(unittest.TestCase): def test_memoryview_stream(self) -> None: tensor = torch.rand(1000) + # pyre-fixme[6]: For 1st argument expected `Buffer` but got `ndarray[Any, Any]`. mv = memoryview(tensor.numpy()).cast("b") self.assertEqual(len(mv), 4000) diff --git a/tests/test_snapshot.py b/tests/test_snapshot.py index 40481ef..399fedc 100644 --- a/tests/test_snapshot.py +++ b/tests/test_snapshot.py @@ -180,6 +180,7 @@ def test_app_state_with_primitive_types(tmp_path: Path) -> None: bytes_key=None, ) + # pyre-fixme[53]: Captured variable `snapshot` is not annotated. def _assert_primitive_entry_with_type( location_key: str, expected_type_name: str ) -> None: diff --git a/torchsnapshot/memoryview_stream.py b/torchsnapshot/memoryview_stream.py index 31f5716..d383e24 100644 --- a/torchsnapshot/memoryview_stream.py +++ b/torchsnapshot/memoryview_stream.py @@ -13,10 +13,13 @@ # pyre-fixme[13]: Attribute `write` is never initialized. class MemoryviewStream(io.IOBase): + # pyre-fixme[24]: Generic type `memoryview` expects 1 type parameter. def __init__(self, mv: memoryview) -> None: + # pyre-fixme[24]: Generic type `memoryview` expects 1 type parameter. self._mv: memoryview = mv.cast("b") self._pos = 0 + # pyre-fixme[24]: Generic type `memoryview` expects 1 type parameter. def read(self, size: Optional[int] = -1) -> memoryview: if self.closed: raise ValueError("read from closed file") @@ -38,6 +41,7 @@ def read(self, size: Optional[int] = -1) -> memoryview: self._pos = newpos return b + # pyre-fixme[24]: Generic type `memoryview` expects 1 type parameter. def read1(self, size: int = -1) -> memoryview: """This is the same as read.""" return self.read(size) diff --git a/torchsnapshot/serialization.py b/torchsnapshot/serialization.py index b48be71..f991466 100644 --- a/torchsnapshot/serialization.py +++ b/torchsnapshot/serialization.py @@ -173,6 +173,7 @@ class Serializer(Enum): ] +# pyre-fixme[24]: Generic type `memoryview` expects 1 type parameter. def tensor_as_memoryview(tensor: torch.Tensor) -> memoryview: """ Obtain the class::`memoryview` of a class::`torch.Tensor`. @@ -199,9 +200,11 @@ def tensor_as_memoryview(tensor: torch.Tensor) -> memoryview: tensor = tensor.contiguous() if tensor.dtype == torch.bfloat16: return _tensor_as_memoryview_via_untyped_storage(tensor) + # pyre-fixme[6]: For 1st argument expected `Buffer` but got `ndarray[Any, Any]`. return memoryview(tensor.numpy()).cast("b") +# pyre-fixme[24]: Generic type `memoryview` expects 1 type parameter. def _tensor_as_memoryview_via_untyped_storage(tensor: torch.Tensor) -> memoryview: """ Obtain the class::`memoryview` of a class::`torch.Tensor` via untyped storage. @@ -223,6 +226,7 @@ def _tensor_as_memoryview_via_untyped_storage(tensor: torch.Tensor) -> memoryvie untyped_storage = contiguous_view_as_untyped_storage(tensor) tensor = torch.empty((0)) tensor.set_(untyped_storage) + # pyre-fixme[6]: For 1st argument expected `Buffer` but got `ndarray[Any, Any]`. return memoryview(tensor.numpy()).cast("b") @@ -249,7 +253,10 @@ def contiguous_view_as_untyped_storage(tensor: torch.Tensor) -> UntypedStorage: def tensor_from_memoryview( - mv: memoryview, dtype: torch.dtype, shape: List[int] + # pyre-fixme[24]: Generic type `memoryview` expects 1 type parameter. + mv: memoryview, + dtype: torch.dtype, + shape: List[int], ) -> torch.Tensor: # PyTorch issues a warning if the given memoryview is non-writable. This is # not a concern for torchsnapshot, as tensors created from non-writable