
Commit 739e48e
add pyre-fixme annotations to problematic lines (#166)
Summary:

Adding pyre-fixme annotations to the violating lines.

Differential Revision: D51680177
galrotem authored and facebook-github-bot committed Nov 29, 2023
1 parent e2000bc commit 739e48e
Showing 9 changed files with 62 additions and 1 deletion.
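For context on what these annotations do: a `pyre-fixme[N]` comment suppresses the Pyre type error with code N on the line immediately below it, without changing runtime behavior. The following is a minimal, hypothetical sketch of the pattern this commit applies; the names (_THRESHOLDS, scale) are illustrative and do not come from the torchsnapshot codebase.

# Hypothetical illustration of pyre-fixme suppressions (not part of this diff).
from typing import List

# In strict mode, Pyre reports error [5] ("Global expression must be annotated")
# for this unannotated global; the fixme silences exactly that code on the next line.
# pyre-fixme[5]: Global expression must be annotated.
_THRESHOLDS = [0.1, 0.5, 0.9]

# pyre-fixme[3]: Return type must be annotated.
def scale(values: List[float], factor: float):
    # The fixme only quiets the type checker; the function runs normally.
    return [v * factor for v in values]

# The long-term fix is to add the annotation and drop the fixme:
def scale_annotated(values: List[float], factor: float) -> List[float]:
    return [v * factor for v in values]

Each suppression covers only the specific error code in brackets, which is why some lines in this diff carry more than one fixme.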
4 changes: 4 additions & 0 deletions tests/gpu_tests/test_dtensor_io_preparer.py
@@ -40,6 +40,7 @@
from torchsnapshot.test_utils import tensor_eq

WORLD_SIZE = 4
# pyre-fixme[5]: Global expression must be annotated.
_DEVICE_MESH = [
list(range(WORLD_SIZE)),
np.arange(WORLD_SIZE).reshape(2, 2).tolist(),
@@ -58,6 +59,7 @@ class TestDTensorIOPreparer(DTensorTestBase):
@parametrize("mesh", _DEVICE_MESH)
@parametrize("placements", _PLACEMENTS)
@skip_if_lt_x_gpu(WORLD_SIZE)
# pyre-fixme[56]: While applying decorator `torch.testing._internal.distributed._...
@with_comms
async def test_dtensor_io_preparer(
self,
@@ -92,6 +94,8 @@ async def test_dtensor_io_preparer(
# When subdivision is enabled, we have more write requests than local
# shards, and each write request corresponds to a subview of a local
# shard.
# pyre-fixme[6]: For 1st argument expected `pyre_extensions.ReadOnly[Sized]`
# but got `int`.
assert len(src._spec.num_shards) < len(write_reqs)
entry_total_size = 0
for shard_entry in entry.shards:
1 change: 1 addition & 0 deletions tests/gpu_tests/test_dtensor_utils.py
@@ -27,6 +27,7 @@
class TestDTensorUtils(DTensorTestBase):
@with_comms
@skip_if_lt_x_gpu(WORLD_SIZE)
# pyre-fixme[3]: Return type must be annotated.
def test_is_sharded_is_replicated(self):
mesh = DeviceMesh("cuda", mesh=[[0, 1], [2, 3]])
placements = [Replicate(), Shard(0)]
12 changes: 12 additions & 0 deletions tests/gpu_tests/test_manifest_utils.py
@@ -31,7 +31,12 @@
class TestManifestUtils(DTensorTestBase):
@parametrize("dtype", NCCL_SUPPORTED_DTYPES)
@skip_if_lt_x_gpu(WORLD_SIZE)
# pyre-fixme[56]: While applying decorator
# `torch.testing._internal.distributed._tensor.common_dtensor.with_comms`: For 1st
# argument expected `(object) -> object` but got `(self: TestManifestUtils, dtype:
# dtype) -> Any`.
@with_comms
# pyre-fixme[3]: Return type must be annotated.
def test_get_replicated_ranks(self, dtype: torch.dtype):
logical_path = "foo"
tensor, entry, wrs = _dtensor_test_case(
@@ -41,13 +46,19 @@ def test_get_replicated_ranks(self, dtype: torch.dtype):
rank=dist.get_rank(),
replicated=True,
)
# pyre-fixme[6]: For 1st argument expected `DTensorEntry` but got `Entry`.
actual_repranks = _get_replicated_ranks(entry=entry)
expected_repranks = [[0, 2], [1, 3]]
assert actual_repranks == expected_repranks

@parametrize("dtype", NCCL_SUPPORTED_DTYPES)
@skip_if_lt_x_gpu(WORLD_SIZE)
# pyre-fixme[56]: While applying decorator
# `torch.testing._internal.distributed._tensor.common_dtensor.with_comms`: For 1st
# argument expected `(object) -> object` but got `(self: TestManifestUtils, dtype:
# dtype) -> Any`.
@with_comms
# pyre-fixme[3]: Return type must be annotated.
def test_is_partially_replicated(self, dtype: torch.dtype):
logical_path = "foo"
tensor, entry, wrs = _dtensor_test_case(
@@ -60,6 +71,7 @@ def test_is_partially_replicated(self, dtype: torch.dtype):
assert is_partially_replicated_entry(entry=entry)

# Only replicated
# pyre-fixme[16]: `Entry` has no attribute `dim_map`.
entry.dim_map = [-1, -1]
assert not is_partially_replicated_entry(entry=entry)

4 changes: 4 additions & 0 deletions tests/gpu_tests/test_partitioner_dtensor.py
@@ -46,6 +46,10 @@ class TestPartitioner(DTensorTestBase):
@parametrize("dtype", NCCL_SUPPORTED_DTYPES)
@parametrize("enable_batcher", [True, False])
@skip_if_lt_x_gpu(WORLD_SIZE)
# pyre-fixme[56]: While applying decorator
# `torch.testing._internal.distributed._tensor.common_dtensor.with_comms`: For 1st
# argument expected `(object) -> object` but got `(self: TestPartitioner, dtype:
# dtype, enable_batcher: bool) -> Coroutine[typing.Any, typing.Any, None]`.
@with_comms
async def test_partitioner(
self,
10 changes: 10 additions & 0 deletions tests/gpu_tests/test_snapshot_dtensor.py
@@ -38,23 +38,28 @@


class DummyModel(torch.nn.Module):
# pyre-fixme[3]: Return type must be annotated.
def __init__(self):
super().__init__()
self.net1 = nn.Sequential(nn.Linear(8, 16), nn.ReLU())
self.net2 = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
self.net3 = nn.Linear(32, 64)
self.net4 = nn.Sequential(nn.ReLU(), nn.Linear(64, 8))

# pyre-fixme[3]: Return type must be annotated.
# pyre-fixme[2]: Parameter must be annotated.
def forward(self, x):
return self.net4(self.net3(self.net2(self.net1(x))))

# pyre-fixme[3]: Return type must be annotated.
def get_input(self):
return torch.rand(4, 8, device="cuda")


# TODO: Test different world sizes (may require not using DTensorTestBase)
# TODO: Test FSDP + TP once dim_map is updated for [Shard(0), Shard(0)] cases
class TestSnapshotWithDTensor(DTensorTestBase):
# pyre-fixme[3]: Return type must be annotated.
def _create_model(
self, seed: int, optim_lr: float, device_mesh: Optional[DeviceMesh] = None
):
@@ -74,6 +79,10 @@ def _create_model(
inter_node_pg = mesh_2d.get_dim_groups(mesh_dim=0)
model = FSDP(
DummyModel().cuda(),
# pyre-fixme[6]: For 2nd argument expected `Union[None,
# Tuple[ProcessGroup, ProcessGroup], ProcessGroup]` but got
# `Tuple[Union[List[ProcessGroup], ProcessGroup],
# Union[List[ProcessGroup], ProcessGroup]]`.
process_group=(intra_node_pg, inter_node_pg),
sharding_strategy=ShardingStrategy.HYBRID_SHARD,
)
@@ -96,6 +105,7 @@ def _create_model(

@with_comms
@skip_if_lt_x_gpu(WORLD_SIZE)
# pyre-fixme[3]: Return type must be annotated.
def test_save_and_load_same_world_size(self):
mesh_2d = init_device_mesh("cuda", (2, WORLD_SIZE // 2))
src_model, src_optim = self._create_model(
6 changes: 6 additions & 0 deletions tests/gpu_tests/test_snapshot_fsdp.py
@@ -41,6 +41,8 @@ def _create_fsdp_model(
@pytest.mark.skipif(
not torch.cuda.is_available(), reason="The test requires GPUs to run."
)
# pyre-fixme[56]: Pyre was not able to infer the type of the decorator
# `pytest.mark.gpu_only`.
@pytest.mark.gpu_only
@pytest.mark.usefixtures("toggle_batching")
# Sharded state dict will test ShardedTensors, full tests Tensors
@@ -76,7 +78,11 @@ def test_model_and_optim_fsdp(tmp_path: Path, state_dict_type: StateDictType) ->
bar_optim.step(closure=None)
bar_optim.zero_grad(set_to_none=True)

# pyre-fixme[6]: For 1st argument expected `FullyShardedDataParallel` but got
# `Module`.
foo_fsdp_optim = FSDPOptimizerAdapter(foo_fsdp, foo_optim)
# pyre-fixme[6]: For 1st argument expected `FullyShardedDataParallel` but got
# `Module`.
bar_fsdp_optim = FSDPOptimizerAdapter(bar_fsdp, bar_optim)

assert not check_state_dict_eq(
19 changes: 18 additions & 1 deletion tests/test_batcher.py
@@ -315,6 +315,11 @@ async def test_batcher(
Verify the behavior of the batcher.
"""
src_tensors, entries, write_reqs, dst_tensors = (
# pyre-fixme[6]: For 1st argument expected `List[Tensor]` but got
# `Union[List[Entry], List[WriteReq], List[Tensor]]`.
# pyre-fixme[58]: `+` is not supported for operand types
# `List[torch._tensor.Tensor]` and `Union[List[Entry], List[WriteReq],
# List[torch._tensor.Tensor]]`.
a + b + c
for a, b, c in zip(
tensor_test_cases, sharded_tensor_test_cases, dtensor_test_cases
@@ -326,7 +331,10 @@
entry, wrs = ObjectIOPreparer.prepare_write(
storage_path=f"object_{idx}", obj=object()
)
# pyre-fixme[6]: For 1st argument expected `Tensor` but got `ObjectEntry`.
entries.append(entry)
# pyre-fixme[6]: For 1st argument expected `Iterable[Tensor]` but got
# `List[WriteReq]`.
write_reqs.extend(wrs)

# The order of the write request shouldn't matter so we shuffle it
@@ -339,16 +347,25 @@
# Expect the batcher to fail if it did not receive all affected entries
# NOTE: don't shuffle the entries as they have to be aligned with
# src_tensors and dst_tensors.
# pyre-fixme[6]: For 1st argument expected `BufferStager` but got `Tensor`.
if is_batchable(entries[0]):
with pytest.raises(RuntimeError):
batch_write_requests(
# pyre-fixme[6]: For 1st argument expected `List[Entry]` but got
# `List[Tensor]`.
entries=entries[1:],
# pyre-fixme[6]: For 2nd argument expected `List[WriteReq]` but got
# `List[Tensor]`.
write_reqs=write_reqs,
)

# Batch the write requests
entries, batched_write_reqs = batch_write_requests(
entries=entries, write_reqs=copy.deepcopy(write_reqs)
# pyre-fixme[6]: For 1st argument expected `List[Entry]` but got `List[Tensor]`.
entries=entries,
# pyre-fixme[6]: For 2nd argument expected `List[WriteReq]` but got
# `List[Tensor]`.
write_reqs=copy.deepcopy(write_reqs),
)
assert len(batched_write_reqs) < len(write_reqs)
write_reqs = batched_write_reqs
4 changes: 4 additions & 0 deletions tests/test_ddp.py
@@ -172,4 +172,8 @@ def test_ddp_save_load_non_ddp(tmp_path: Path) -> None:
# The utility consume_prefix_in_state_dict_if_present re-inserts keys into the state dict
# which changes the order they appear in the state dict, as it is an OrderedDict.
# to test for equality, explicitly sort the state dicts by key before comparison
# pyre-fixme[6]: For 1st argument expected `Dict[typing.Any, typing.Any]` but
# got `List[str]`.
# pyre-fixme[6]: For 2nd argument expected `Dict[typing.Any, typing.Any]` but
# got `List[str]`.
assert check_state_dict_eq(sorted(restored_state_dict), sorted(consumed_state_dict))
3 changes: 3 additions & 0 deletions tests/test_manifest.py
@@ -686,6 +686,8 @@ def test_get_local_manifest(manifest: Dict[str, Entry], rank: int) -> None:
expected_local_manifest[local_path] = entry

merged_local_manifest = _update_local_manifest_with_merged_entries(local_manifest)
# pyre-fixme[6]: For 1st argument expected `SupportsKeysAndGetItem[typing.Any,
# typing.Any]` but got `None`.
expected_local_manifest.update(merged_local_manifest)

if rank >= _WORLD_SIZE:
@@ -857,4 +859,5 @@ def _update_local_manifest_with_merged_entries(
mesh=[[0, 1], [2, 3]],
dim_map=[[-1], [0]],
)
# pyre-fixme[7]: Expected `None` but got `Dict[typing.Any, typing.Any]`.
return merged_local_manifest
