Skip to content

Commit

Permalink
Add type error suppressions for upcoming upgrade
Browse files Browse the repository at this point in the history
Reviewed By: MaggieMoss

Differential Revision: D64502856

fbshipit-source-id: 82c7bc157cfed0539319785987d56946b1aac94c
  • Loading branch information
generatedunixname89002005307016 authored and facebook-github-bot committed Oct 17, 2024
1 parent 1ff027d commit 26eb163
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 1 deletion.
1 change: 1 addition & 0 deletions tests/test_memoryview_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
class MemoryviewStreamTest(unittest.TestCase):
def test_memoryview_stream(self) -> None:
tensor = torch.rand(1000)
# pyre-fixme[6]: For 1st argument expected `Buffer` but got `ndarray[Any, Any]`.
mv = memoryview(tensor.numpy()).cast("b")
self.assertEqual(len(mv), 4000)

Expand Down
1 change: 1 addition & 0 deletions tests/test_snapshot.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def test_app_state_with_primitive_types(tmp_path: Path) -> None:
bytes_key=None,
)

# pyre-fixme[53]: Captured variable `snapshot` is not annotated.
def _assert_primitive_entry_with_type(
location_key: str, expected_type_name: str
) -> None:
Expand Down
4 changes: 4 additions & 0 deletions torchsnapshot/memoryview_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,13 @@

# pyre-fixme[13]: Attribute `write` is never initialized.
class MemoryviewStream(io.IOBase):
# pyre-fixme[24]: Generic type `memoryview` expects 1 type parameter.
def __init__(self, mv: memoryview) -> None:
# pyre-fixme[24]: Generic type `memoryview` expects 1 type parameter.
self._mv: memoryview = mv.cast("b")
self._pos = 0

# pyre-fixme[24]: Generic type `memoryview` expects 1 type parameter.
def read(self, size: Optional[int] = -1) -> memoryview:
if self.closed:
raise ValueError("read from closed file")
Expand All @@ -38,6 +41,7 @@ def read(self, size: Optional[int] = -1) -> memoryview:
self._pos = newpos
return b

# pyre-fixme[24]: Generic type `memoryview` expects 1 type parameter.
def read1(self, size: int = -1) -> memoryview:
"""This is the same as read."""
return self.read(size)
Expand Down
9 changes: 8 additions & 1 deletion torchsnapshot/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ class Serializer(Enum):
]


# pyre-fixme[24]: Generic type `memoryview` expects 1 type parameter.
def tensor_as_memoryview(tensor: torch.Tensor) -> memoryview:
"""
Obtain the class::`memoryview` of a class::`torch.Tensor`.
Expand All @@ -199,9 +200,11 @@ def tensor_as_memoryview(tensor: torch.Tensor) -> memoryview:
tensor = tensor.contiguous()
if tensor.dtype == torch.bfloat16:
return _tensor_as_memoryview_via_untyped_storage(tensor)
# pyre-fixme[6]: For 1st argument expected `Buffer` but got `ndarray[Any, Any]`.
return memoryview(tensor.numpy()).cast("b")


# pyre-fixme[24]: Generic type `memoryview` expects 1 type parameter.
def _tensor_as_memoryview_via_untyped_storage(tensor: torch.Tensor) -> memoryview:
"""
Obtain the class::`memoryview` of a class::`torch.Tensor` via untyped storage.
Expand All @@ -223,6 +226,7 @@ def _tensor_as_memoryview_via_untyped_storage(tensor: torch.Tensor) -> memoryvie
untyped_storage = contiguous_view_as_untyped_storage(tensor)
tensor = torch.empty((0))
tensor.set_(untyped_storage)
# pyre-fixme[6]: For 1st argument expected `Buffer` but got `ndarray[Any, Any]`.
return memoryview(tensor.numpy()).cast("b")


Expand All @@ -249,7 +253,10 @@ def contiguous_view_as_untyped_storage(tensor: torch.Tensor) -> UntypedStorage:


def tensor_from_memoryview(
mv: memoryview, dtype: torch.dtype, shape: List[int]
# pyre-fixme[24]: Generic type `memoryview` expects 1 type parameter.
mv: memoryview,
dtype: torch.dtype,
shape: List[int],
) -> torch.Tensor:
# PyTorch issues a warning if the given memoryview is non-writable. This is
# not a concern for torchsnapshot, as tensors created from non-writable
Expand Down

0 comments on commit 26eb163

Please sign in to comment.