From 012d1fc269b5d0a087ee2b31bf277dab46b699f7 Mon Sep 17 00:00:00 2001 From: Meet Vadakkanchery Date: Wed, 18 Jun 2025 10:48:15 -0700 Subject: [PATCH] Implement torch.Tensor APIs for TorchRec wrappers (#3096) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/3096 ### Diff Context Sometimes the trainer `state_dict` input to checkpointing can contain `LocalShardsWrapper` from TorchRec, which is a `torch.Tensor`. However, it doesn't implement some `torch.Tensor` operations like `copy_`, `zeros_like`, and `empty_like`. This diff aims to implement those. Reviewed By: iamzainhuda, pradeepfn Differential Revision: D75553113 --- torchrec/distributed/shards_wrapper.py | 33 ++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/torchrec/distributed/shards_wrapper.py b/torchrec/distributed/shards_wrapper.py index e7fc1e52b..6475452f7 100644 --- a/torchrec/distributed/shards_wrapper.py +++ b/torchrec/distributed/shards_wrapper.py @@ -116,6 +116,9 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): aten.detach.default: cls.handle_detach, aten.clone.default: cls.handle_clone, aten.new_empty.default: cls.handle_new_empty, + aten.copy_.default: cls.handle_copy_, + aten.zeros_like.default: cls.handle_zeros_like, + aten.empty_like.default: cls.handle_empty_like, } if func in dispatcher: @@ -125,6 +128,36 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None): f"{func} is not supported for LocalShardsWrapper!" ) + @staticmethod + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + def handle_zeros_like(args, kwargs): + return LocalShardsWrapper( + [torch.zeros_like(shard, **kwargs) for shard in args[0].local_shards()], + args[0].local_offsets(), + ) + + @staticmethod + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. 
+ def handle_empty_like(args, kwargs): + return LocalShardsWrapper( + [torch.empty_like(shard, **kwargs) for shard in args[0].local_shards()], + args[0].local_offsets(), + ) + + @staticmethod + # pyre-fixme[3]: Return type must be annotated. + # pyre-fixme[2]: Parameter must be annotated. + def handle_copy_(args, kwargs): + src = args[1] + dst = args[0] + + for i, shard in enumerate(src.local_shards()): + dst.local_shards()[i].copy_(shard, **kwargs) + + return args[0] + @staticmethod # pyre-fixme[3]: Return type must be annotated. # pyre-fixme[2]: Parameter must be annotated.