diff --git a/torchrec/distributed/shards_wrapper.py b/torchrec/distributed/shards_wrapper.py
index e7fc1e52b..6475452f7 100644
--- a/torchrec/distributed/shards_wrapper.py
+++ b/torchrec/distributed/shards_wrapper.py
@@ -116,6 +116,9 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
             aten.detach.default: cls.handle_detach,
             aten.clone.default: cls.handle_clone,
             aten.new_empty.default: cls.handle_new_empty,
+            aten.copy_.default: cls.handle_copy_,
+            aten.zeros_like.default: cls.handle_zeros_like,
+            aten.empty_like.default: cls.handle_empty_like,
         }
 
         if func in dispatcher:
@@ -125,6 +128,91 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                 f"{func} is not supported for LocalShardsWrapper!"
             )
 
+    @staticmethod
+    # pyre-fixme[3]: Return type must be annotated.
+    # pyre-fixme[2]: Parameter must be annotated.
+    def handle_zeros_like(args, kwargs):
+        """Handle ``aten.zeros_like``: build one zero-filled tensor per shard.
+
+        Args:
+            args: Dispatch positional arguments; ``args[0]`` is the wrapper
+                whose local shards and offsets are mirrored.
+            kwargs: Dispatch keyword arguments, forwarded to
+                ``torch.zeros_like`` for every shard (may be ``None``).
+
+        Returns:
+            A ``LocalShardsWrapper`` over the new shards, constructed with
+            ``args[0].local_offsets()``.
+        """
+        # ``__torch_dispatch__`` defaults ``kwargs`` to None; ``**None`` raises.
+        kwargs = kwargs or {}
+        return LocalShardsWrapper(
+            [torch.zeros_like(shard, **kwargs) for shard in args[0].local_shards()],
+            args[0].local_offsets(),
+        )
+
+    @staticmethod
+    # pyre-fixme[3]: Return type must be annotated.
+    # pyre-fixme[2]: Parameter must be annotated.
+    def handle_empty_like(args, kwargs):
+        """Handle ``aten.empty_like``: build one uninitialized tensor per shard.
+
+        Args:
+            args: Dispatch positional arguments; ``args[0]`` is the wrapper
+                whose local shards and offsets are mirrored.
+            kwargs: Dispatch keyword arguments, forwarded to
+                ``torch.empty_like`` for every shard (may be ``None``).
+
+        Returns:
+            A ``LocalShardsWrapper`` over the new shards, constructed with
+            ``args[0].local_offsets()``.
+        """
+        # ``__torch_dispatch__`` defaults ``kwargs`` to None; ``**None`` raises.
+        kwargs = kwargs or {}
+        return LocalShardsWrapper(
+            [torch.empty_like(shard, **kwargs) for shard in args[0].local_shards()],
+            args[0].local_offsets(),
+        )
+
+    @staticmethod
+    # pyre-fixme[3]: Return type must be annotated.
+    # pyre-fixme[2]: Parameter must be annotated.
+    def handle_copy_(args, kwargs):
+        """Handle ``aten.copy_``: copy ``src`` into ``dst`` shard by shard.
+
+        Args:
+            args: Dispatch positional arguments ``(dst, src, ...)``. Anything
+                after ``src`` (e.g. a positional ``non_blocking``) is
+                forwarded to every per-shard ``copy_`` call.
+            kwargs: Dispatch keyword arguments, forwarded to every per-shard
+                ``copy_`` call (may be ``None``).
+
+        Returns:
+            ``args[0]`` (the destination), after its shards were written.
+
+        Raises:
+            ValueError: If source and destination hold a different number of
+                local shards and so cannot be paired one-to-one.
+        """
+        dst, src, extra = args[0], args[1], args[2:]
+        # ``__torch_dispatch__`` defaults ``kwargs`` to None; ``**None`` raises.
+        kwargs = kwargs or {}
+
+        dst_shards = dst.local_shards()
+        src_shards = src.local_shards()
+        if len(dst_shards) != len(src_shards):
+            # Fail before touching any shard instead of copying only a prefix
+            # (src shorter) or hitting a bare IndexError midway (src longer).
+            raise ValueError(
+                "copy_ requires the same number of local shards in source and "
+                f"destination, got {len(src_shards)} and {len(dst_shards)}"
+            )
+
+        for dst_shard, src_shard in zip(dst_shards, src_shards):
+            dst_shard.copy_(src_shard, *extra, **kwargs)
+
+        return dst
+
     @staticmethod
     # pyre-fixme[3]: Return type must be annotated.
     # pyre-fixme[2]: Parameter must be annotated.