diff --git a/tests/test_manifest.py b/tests/test_manifest.py
index dc30169..b8cc334 100644
--- a/tests/test_manifest.py
+++ b/tests/test_manifest.py
@@ -861,3 +861,9 @@ def _update_local_manifest_with_merged_entries(
     )
     # pyre-fixme[7]: Expected `None` but got `Dict[typing.Any, typing.Any]`.
     return merged_local_manifest
+
+
+def test_get_tensor_shape() -> None:
+    # pyre-ignore Undefined attribute [16]: `Entry` has no attribute `shards`.
+    shards = [_MANIFEST_0[f"{i}/foo/qux"].shards[0] for i in range(4)]
+    assert ShardedTensorEntry(shards=shards).get_tensor_shape() == [4, 8]
diff --git a/tests/test_read_object.py b/tests/test_read_object.py
index 00fecee..a188430 100644
--- a/tests/test_read_object.py
+++ b/tests/test_read_object.py
@@ -67,10 +67,20 @@ def _test_read_sharded_tensor() -> None:
             path=path, app_state={"state": torchsnapshot.StateDict(foo=foo)}
         )
         snapshot.read_object("0/state/foo", obj_out=bar)
+        baz = snapshot.read_object("0/state/foo")
         for foo_shard, bar_shard in zip(foo.local_shards(), bar.local_shards()):
             tc.assertTrue(torch.allclose(foo_shard.tensor, bar_shard.tensor))

+        tc.assertEqual(baz.shape, torch.Size([20_000, 128]))
+
+        gathered_foo_tensor = torch.empty(20_000, 128)
+        if dist.get_rank() == 0:
+            foo.gather(dst=0, out=gathered_foo_tensor)
+            tc.assertTrue(torch.allclose(baz, gathered_foo_tensor))
+        else:
+            foo.gather(dst=0, out=None)
+
     def test_read_sharded_tensor(self) -> None:
         lc = get_pet_launch_config(nproc=4)
         pet.elastic_launch(lc, entrypoint=self._test_read_sharded_tensor)()

diff --git a/torchsnapshot/manifest.py b/torchsnapshot/manifest.py
index f615373..6593ce0 100644
--- a/torchsnapshot/manifest.py
+++ b/torchsnapshot/manifest.py
@@ -137,6 +137,34 @@ def from_yaml_obj(cls, yaml_obj: Any) -> "ShardedTensorEntry":
         ]
         return cls(**yaml_obj)

+    def get_tensor_shape(self) -> List[int]:
+        """
+        Computes the shape of the entire tensor.
+
+        Returns:
+            List[int]: shape of the entire tensor
+
+        .. note::
+            The shape is found by taking the shard whose per-dimension
+            ``size + offset`` sums dominate those of every other shard. Shard
+            sizes/offsets are non-decreasing in each dimension across the list.
+        """
+        assert len(self.shards) > 0, "No shards found."
+
+        first_shard = self.shards[0]
+        shape = [
+            size + offset
+            for size, offset in zip(first_shard.sizes, first_shard.offsets)
+        ]
+        for shard in self.shards[1:]:
+            sizes = shard.sizes
+            offsets = shard.offsets
+            # sum element-wise
+            candidate_shape = [size + offset for size, offset in zip(sizes, offsets)]
+            if all(x >= y for x, y in zip(candidate_shape, shape)):
+                shape = candidate_shape
+        return shape
+

 @dataclass
 class ChunkedTensorEntry(Entry):
diff --git a/torchsnapshot/snapshot.py b/torchsnapshot/snapshot.py
index 0bf73e7..ce9a8ce 100644
--- a/torchsnapshot/snapshot.py
+++ b/torchsnapshot/snapshot.py
@@ -28,6 +28,8 @@
 from torch.distributed._tensor import DTensor
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torchsnapshot.dtensor_utils import is_sharded
+from torchsnapshot.manifest import ShardedTensorEntry
+from torchsnapshot.serialization import string_to_dtype

 from .batcher import batch_read_requests, batch_write_requests

@@ -384,8 +386,9 @@
                 type. Otherwise, ``obj_out`` is ignored.

                 .. note::
-                    When the target object is a ``ShardedTensor``, ``obj_out``
-                    must be specified.
+                    When the target object is a ``ShardedTensor`` and
+                    ``obj_out`` is ``None``, a full, unsharded copy of the
+                    tensor is returned on CPU.

             memory_budget_bytes (int, optional): When specified, the read
                 operation will keep the temporary memory buffer size below this
@@ -435,6 +438,14 @@ def read_object(
         entry = merged_sd_entries.get(unranked_path) or manifest[unranked_path]
         if isinstance(entry, PrimitiveEntry):
             return cast(T, entry.get_value())
+        elif obj_out is None and isinstance(entry, ShardedTensorEntry):
+            # construct tensor for `obj_out` to fill in-place
+            # by reading shard metadata
+            shape = entry.get_tensor_shape()
+            dtype = entry.shards[0].tensor.dtype
+            tensor = torch.empty(shape, dtype=string_to_dtype(dtype))
+            obj_out = tensor
+
         read_reqs, fut = prepare_read(
             entry=entry,
             obj_out=obj_out,
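
# A minimal, self-contained sketch of the dominance rule that
# `ShardedTensorEntry.get_tensor_shape` applies in the patch above. The plain
# (offsets, sizes) tuples here are illustrative stand-ins for the manifest's
# shard entries, not the real Shard type.
from typing import List, Sequence, Tuple


def shape_from_shards(shards: Sequence[Tuple[List[int], List[int]]]) -> List[int]:
    # Each shard is an (offsets, sizes) pair; the full tensor shape is the
    # per-dimension `offset + size` of the shard that dominates all others.
    assert len(shards) > 0, "No shards found."
    offsets, sizes = shards[0]
    shape = [o + s for o, s in zip(offsets, sizes)]
    for offsets, sizes in shards[1:]:
        candidate = [o + s for o, s in zip(offsets, sizes)]
        # adopt the candidate only if it is >= the current shape in every dim
        if all(c >= cur for c, cur in zip(candidate, shape)):
            shape = candidate
    return shape


# Four row-shards of a [4, 8] tensor, mirroring test_get_tensor_shape:
# offsets (0, 0)..(3, 0), each shard sized (1, 8) -> full shape [4, 8].
assert shape_from_shards([([i, 0], [1, 8]) for i in range(4)]) == [4, 8]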
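
# A hedged usage sketch of the new read path: with `obj_out` omitted, reading
# a ShardedTensor entry materializes the full tensor on CPU. This assumes a
# snapshot taken from a 4-rank run with a ShardedTensor saved under
# "0/state/foo" (as in the test above); the snapshot path is illustrative.
import torch
import torchsnapshot

snapshot = torchsnapshot.Snapshot(path="/tmp/sharded-snapshot")
full = snapshot.read_object("0/state/foo")  # no obj_out -> full CPU tensor
assert isinstance(full, torch.Tensor)
assert full.shape == torch.Size([20_000, 128])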