add sharded tensor support to read_object when obj_out=None (#167)
Summary:
Pull Request resolved: #167

# Context
It is currently not possible to read a sharded tensor directly from a snapshot without manually constructing a sharded tensor and passing it via `obj_out` in `read_object`.

# This Diff
Adds a `get_tensor_shape` method to `ShardedTensorEntry`, which computes the sharded tensor's overall size (the largest offsets + sizes tuple across all shards). `read_object` calls this and sets `obj_out` to a regular CPU tensor of that size.
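
A minimal usage sketch of the new code path (the snapshot path and manifest key below are illustrative assumptions, not taken from this diff):

```python
import torchsnapshot

# Hypothetical snapshot path and entry key, for illustration only.
snapshot = torchsnapshot.Snapshot(path="/tmp/my-snapshot")

# Previously, reading a sharded tensor required constructing a
# ShardedTensor and passing it via `obj_out`. With this diff, omitting
# `obj_out` returns a regular CPU tensor with the full, unsharded shape.
full_tensor = snapshot.read_object("0/state/foo")
print(full_tensor.shape, full_tensor.device)
```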

Reviewed By: galrotem, RdoubleA, yifuwang

Differential Revision: D52258262

fbshipit-source-id: 01231e4132af1b0459c16e6bb8523c62d6217858
JKSenthil authored and facebook-github-bot committed Dec 19, 2023
1 parent 1f42a31 commit 0e60109
Showing 4 changed files with 57 additions and 2 deletions.
6 changes: 6 additions & 0 deletions tests/test_manifest.py
@@ -861,3 +861,9 @@ def _update_local_manifest_with_merged_entries(
    )
    # pyre-fixme[7]: Expected `None` but got `Dict[typing.Any, typing.Any]`.
    return merged_local_manifest


def test_get_tensor_shape() -> None:
    # pyre-ignore Undefined attribute [16]: `Entry` has no attribute `shards`.
    shards = [_MANIFEST_0[f"{i}/foo/qux"].shards[0] for i in range(4)]
    assert ShardedTensorEntry(shards=shards).get_tensor_shape() == [4, 8]
10 changes: 10 additions & 0 deletions tests/test_read_object.py
@@ -67,10 +67,20 @@ def _test_read_sharded_tensor() -> None:
            path=path, app_state={"state": torchsnapshot.StateDict(foo=foo)}
        )
        snapshot.read_object("0/state/foo", obj_out=bar)
        baz = snapshot.read_object("0/state/foo")

        for foo_shard, bar_shard in zip(foo.local_shards(), bar.local_shards()):
            tc.assertTrue(torch.allclose(foo_shard.tensor, bar_shard.tensor))

        tc.assertEqual(baz.shape, torch.Size([20_000, 128]))

        gathered_foo_tensor = torch.empty(20_000, 128)
        if dist.get_rank() == 0:
            foo.gather(dst=0, out=gathered_foo_tensor)
            tc.assertTrue(torch.allclose(baz, gathered_foo_tensor))
        else:
            foo.gather(dst=0, out=None)

    def test_read_sharded_tensor(self) -> None:
        lc = get_pet_launch_config(nproc=4)
        pet.elastic_launch(lc, entrypoint=self._test_read_sharded_tensor)()
28 changes: 28 additions & 0 deletions torchsnapshot/manifest.py
@@ -137,6 +137,34 @@ def from_yaml_obj(cls, yaml_obj: Any) -> "ShardedTensorEntry":
        ]
        return cls(**yaml_obj)

    def get_tensor_shape(self) -> List[int]:
        """
        Computes the shape of the entire tensor.

        Returns:
            List[int]: shape of the entire tensor

        .. note::
            The shape is computed by taking the element-wise maximum of
            (size + offset) across all shards. The shards' sizes/offsets
            are equal or increasing in each dimension as the shards
            progress through the list.
        """
        assert len(self.shards) > 0, "No shards found."

        first_shard = self.shards[0]
        shape = [
            size + offset
            for size, offset in zip(first_shard.sizes, first_shard.offsets)
        ]
        for shard in self.shards[1:]:
            sizes = shard.sizes
            offsets = shard.offsets
            # sum element-wise
            candidate_shape = [size + offset for size, offset in zip(sizes, offsets)]
            if all(x >= y for x, y in zip(candidate_shape, shape)):
                shape = candidate_shape
        return shape


@dataclass
class ChunkedTensorEntry(Entry):
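For intuition, here is a minimal standalone sketch of the element-wise (offset + size) maximum that `get_tensor_shape` performs, using made-up offsets/sizes for a [4, 8] tensor split row-wise into four shards (plain tuples stand in for torchsnapshot's shard metadata; the values mirror `test_get_tensor_shape` above but are assumed, not taken from `_MANIFEST_0`):

```python
from typing import List, Tuple

# (offsets, sizes) per shard -- hypothetical values for a [4, 8]
# tensor sharded into four [1, 8] row slices.
shards: List[Tuple[List[int], List[int]]] = [
    ([0, 0], [1, 8]),
    ([1, 0], [1, 8]),
    ([2, 0], [1, 8]),
    ([3, 0], [1, 8]),  # last shard: offsets + sizes == [4, 8]
]

shape = [0, 0]
for offsets, sizes in shards:
    # element-wise sum; adopt it if it dominates the current shape
    candidate = [o + s for o, s in zip(offsets, sizes)]
    if all(c >= cur for c, cur in zip(candidate, shape)):
        shape = candidate

assert shape == [4, 8]
```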
15 changes: 13 additions & 2 deletions torchsnapshot/snapshot.py
@@ -28,6 +28,8 @@
from torch.distributed._tensor import DTensor
from torch.nn.parallel import DistributedDataParallel as DDP
from torchsnapshot.dtensor_utils import is_sharded
from torchsnapshot.manifest import ShardedTensorEntry
from torchsnapshot.serialization import string_to_dtype

from .batcher import batch_read_requests, batch_write_requests

@@ -384,8 +386,9 @@ def read_object(
                type. Otherwise, ``obj_out`` is ignored.

                .. note::
                    When the target object is a ``ShardedTensor`` and
                    ``obj_out`` is None, a regular CPU tensor containing the
                    full (unsharded) tensor is returned.

            memory_budget_bytes (int, optional): When specified, the read
                operation will keep the temporary memory buffer size below this
@@ -435,6 +438,14 @@
        entry = merged_sd_entries.get(unranked_path) or manifest[unranked_path]
        if isinstance(entry, PrimitiveEntry):
            return cast(T, entry.get_value())
        elif obj_out is None and isinstance(entry, ShardedTensorEntry):
            # construct tensor for `obj_out` to fill in-place
            # by reading shard metadata
            shape = entry.get_tensor_shape()
            dtype = entry.shards[0].tensor.dtype
            tensor = torch.empty(shape, dtype=string_to_dtype(dtype))
            obj_out = tensor

        read_reqs, fut = prepare_read(
            entry=entry,
            obj_out=obj_out,
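The `string_to_dtype` helper imported above maps the serialized dtype string stored in shard metadata back to a `torch.dtype`. A rough sketch of such a lookup, assuming a simple table-based mapping (the actual implementation in `torchsnapshot.serialization` may support more dtypes or extra validation):

```python
import torch

# Hypothetical stand-in for torchsnapshot.serialization.string_to_dtype.
_STR_TO_DTYPE = {
    "float16": torch.float16,
    "float32": torch.float32,
    "float64": torch.float64,
    "int32": torch.int32,
    "int64": torch.int64,
    "uint8": torch.uint8,
    "bool": torch.bool,
}

def string_to_dtype_sketch(s: str) -> torch.dtype:
    # Look up the serialized name and return the torch dtype object.
    return _STR_TO_DTYPE[s]

assert string_to_dtype_sketch("float32") is torch.float32
```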
