
Commit 24b9890

colin2328 authored and pytorchmergebot committed
[torchrec] [composable] update ShardedEmbeddingBagCollection to use registered EBCs with ShardedTensors as registered modules (#758) (pytorch#88026)
Summary:
X-link: meta-pytorch/torchrec#758

This PR fixes a bug in FSDP/DDP where ShardedTensors are not supported even when they are passed in as params to ignore. This matters for composability because TorchRec named_parameters() returns the FQNs of ShardedTensors (as defined in the composability goals).

It defines the device of a ShardedTensor to be None when local_tensor() does not exist on the rank, and updates ShardedEmbeddingBagCollection to be composable according to https://docs.google.com/document/d/1TBJSd5zgEg6cRcXv3Okuj7bBkqQwGS2IPh4TLWNNzFI/edit

Differential Revision: D40458625
Pull Request resolved: pytorch#88026
Approved by: https://github.com/wanchaol, https://github.com/rohan-varma
1 parent 1cd6ebe commit 24b9890
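
Before the file-by-file diff, a minimal, self-contained sketch of the ignore flow that this change makes safe for ShardedTensor parameters. It uses a single-process gloo group and ordinary dense parameters as stand-ins (a real ShardedTensor needs multiple ranks, as the new test below shows); the module, FQNs, and MASTER_ADDR/MASTER_PORT values are illustrative, not taken from the diff.

import os
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel

# single-process gloo group, only so that DDP construction is possible in a sketch
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", world_size=1, rank=0)

# "0.weight" / "0.bias" stand in for the FQNs of sharded parameters
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))

# register the FQNs DDP should skip, then wrap as usual; after this commit the
# same call also works when the named parameter is a ShardedTensor
DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
    module=model,
    params_and_buffers_to_ignore={"0.weight", "0.bias"},
)
ddp_model = DistributedDataParallel(model, broadcast_buffers=False)

dist.destroy_process_group()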

File tree

4 files changed: +82 additions, -23 deletions


test/distributed/test_c10d_gloo.py

Lines changed: 56 additions & 6 deletions
@@ -23,28 +23,35 @@
 import torch.nn.functional as F
 import torch.testing._internal.common_utils as common
 from test_c10d_common import (
-    LOOPBACK,
     gpus_for_rank,
-    Task,
+    LOOPBACK,
     ModuleForDdpCommHook,
     SparseGradientModule,
+    Task,
 )
 from torch import nn
+from torch.distributed._shard.sharded_tensor import (
+    init_from_local_shards,
+    Shard,
+    ShardedTensor,
+    ShardMetadata,
+)
 from torch.nn.parallel import DistributedDataParallel
+from torch.nn.parallel._replicated_tensor_ddp_utils import _ddp_replicated_tensor
 from torch.testing._internal.common_distributed import (
+    create_device,
     MultiProcessTestCase,
     requires_gloo,
-    skip_if_lt_x_gpu,
     simple_sparse_reduce_tests,
+    skip_if_lt_x_gpu,
     skip_if_win32,
-    create_device,
     verify_ddp_error_logged,
 )
 from torch.testing._internal.common_utils import (
-    TestCase,
-    run_tests,
     retry_on_connect_failures,
+    run_tests,
     sandcastle_skip,
+    TestCase,
 )


@@ -1754,6 +1761,49 @@ def forward(self, x):
             loss = criterion(output, target)
             loss.backward()

+    @requires_gloo()
+    @skip_if_lt_x_gpu(2)
+    def test_ignored_sharded_tensor(self):
+        class MyModule(nn.Module):
+            def __init__(self, shard_tensor: ShardedTensor) -> None:
+                super().__init__()
+                self.fc1 = nn.Linear(2, 10, bias=False)
+                self.st = nn.Parameter(shard_tensor)
+                self.relu = nn.ReLU()
+
+            def forward(self, x):
+                x = self.relu(self.fc1(x))
+                return F.softmax(x, dim=1)
+        pg = dist.init_process_group(
+            "gloo",
+            init_method=f"file://{self.file_name}",
+            world_size=self.world_size,
+            rank=self.rank,
+        )
+        device = torch.device(f"cuda:{self.rank}")
+        local_shard_metadata = ShardMetadata(
+            shard_offsets=[(self.rank % 2) * 5, 0],
+            shard_sizes=[5, 10],
+            placement=f"rank:{self.rank}/cuda:{self.rank}"
+        )
+        local_shards = [Shard(torch.randn(5, 10, device=device), local_shard_metadata)]
+        st = init_from_local_shards(local_shards, [10, 10])
+        m = MyModule(st)
+        with _ddp_replicated_tensor(False):
+            DistributedDataParallel._set_params_and_buffers_to_ignore_for_model(
+                module=m,
+                params_and_buffers_to_ignore={'st'}
+            )
+            # check that the DDP constructor does not fail when the module includes an ignored ShardedTensor
+            DistributedDataParallel(
+                m,
+                device_ids=[device] if device.type == "gpu" else None,
+                process_group=pg,
+                gradient_as_bucket_view=True,
+                broadcast_buffers=False,
+                static_graph=True,
+            )
+
     def _run_and_verify_sparse_gradients(self, vanilla_model, ddp_model):
         mult = 2
         batch_size = mult * self.world_size

torch/distributed/_shard/sharded_tensor/_ops/tensor_ops.py

Lines changed: 8 additions & 3 deletions
@@ -42,9 +42,14 @@ def tensor_device(types, args=(), kwargs=None, pg=None):
     # Validate types
     if not isinstance(self_st, ShardedTensor):
         raise TypeError("input needs to be a ShardedTensor")
-
-    return self_st.local_shards()[0].tensor.device
-
+    dev: torch.device
+    if self_st._local_shards:
+        dev = self_st._local_shards[0].tensor.device
+    elif pg and pg._get_backend_name() == "gloo":
+        dev = torch.device("cpu")
+    else:
+        dev = torch.device(torch.cuda.current_device())
+    return dev

 @_sharded_op_impl(torch.Tensor.is_meta.__get__)  # type: ignore[attr-defined]
 def st_is_meta(types, args=(), kwargs=None, pg=None):
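
For orientation only, a minimal sketch of the device fallback added above, run on a single-process gloo group with a CPU shard. It assumes that reading .device on a ShardedTensor dispatches to this tensor_device handler (the registration decorator is not visible in the hunk), and the MASTER_ADDR/MASTER_PORT values are placeholders.

import os
import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor import (
    init_from_local_shards,
    Shard,
    ShardMetadata,
)

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", world_size=1, rank=0)

# one CPU shard covering the whole [5, 10] global tensor on rank 0
meta = ShardMetadata(shard_offsets=[0, 0], shard_sizes=[5, 10], placement="rank:0/cpu")
st = init_from_local_shards([Shard(torch.randn(5, 10), meta)], [5, 10])

# rank with local shards: device comes from the first local shard
print(st.device)  # cpu
# on a rank that owns no shards, the handler above now falls back to cpu for
# gloo process groups, or to the current CUDA device otherwise, instead of
# indexing into an empty local_shards() list
dist.destroy_process_group()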

torch/distributed/_shard/sharded_tensor/api.py

Lines changed: 7 additions & 1 deletion
@@ -630,7 +630,13 @@ def cuda(
         return st_cuda

     def to(self, *args, **kwargs) -> ShardedTensor:
-        current_device = self._local_shards[0].tensor.device
+        current_device: torch.device
+        if self._local_shards:
+            current_device = self._local_shards[0].tensor.device
+        elif self._process_group._get_backend_name() == "gloo":
+            current_device = torch.device("cpu")
+        else:
+            current_device = torch.device(torch.cuda.current_device())
         current_dtype = self.dtype
         device_to = current_device
         dtype_to = current_dtype
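
A similarly hedged sketch for to(): with the fallback in place, the starting device is well defined even on a rank that owns no local shards. The dtype-only conversion below is an assumption about what to() accepts in this version, not something the diff itself exercises.

import os
import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor import (
    init_from_local_shards,
    Shard,
    ShardMetadata,
)

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", world_size=1, rank=0)

meta = ShardMetadata(shard_offsets=[0, 0], shard_sizes=[5, 10], placement="rank:0/cpu")
st = init_from_local_shards([Shard(torch.randn(5, 10), meta)], [5, 10])

# current_device is taken from the first local shard here; on a shard-less
# rank the new branches pick cpu (gloo) or the current CUDA device, so the
# call no longer fails before the conversion even starts
st64 = st.to(torch.float64)
print(st64.dtype)  # torch.float64

dist.destroy_process_group()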

torch/nn/parallel/distributed.py

Lines changed: 11 additions & 13 deletions
@@ -553,11 +553,15 @@ def __init__(
         gradient_as_bucket_view=False,
         static_graph=False,
     ):
-
         super(DistributedDataParallel, self).__init__()
         Joinable.__init__(self)
         self.logger = None
-        if not any((p.requires_grad for p in module.parameters())):
+        if hasattr(module, "_ddp_params_and_buffers_to_ignore"):
+            self.parameters_to_ignore = set(module._ddp_params_and_buffers_to_ignore)
+        else:
+            self.parameters_to_ignore = set()
+        self._module_parameters = [p for n, p in module.named_parameters() if n not in self.parameters_to_ignore]
+        if not any((p.requires_grad for p in self._module_parameters)):
             self._log_and_throw(
                 RuntimeError,
                 "DistributedDataParallel is not needed when a module "
@@ -570,10 +574,8 @@ def __init__(
                 "device_ids can only be None or contain a single element.",
             )

-        self.is_multi_device_module = (
-            len({p.device for p in module.parameters()}) > 1
-        )
-        distinct_device_types = {p.device.type for p in module.parameters()}
+        self.is_multi_device_module = len({p.device for p in self._module_parameters}) > 1
+        distinct_device_types = {p.device.type for p in self._module_parameters if p.device is not None}
         if len(distinct_device_types) != 1:
             self._log_and_throw(
                 ValueError,
@@ -599,7 +601,7 @@ def __init__(
                     "but got device_ids {}, output_device {}, and module parameters {}.".format(
                         device_ids,
                         output_device,
-                        {p.device for p in module.parameters()},
+                        {p.device for p in self._module_parameters},
                     ),
                 )

@@ -621,16 +623,12 @@ def __init__(
         self.static_graph = False
         self.dim = dim
         self.module = module
-        self.device = list(self.module.parameters())[0].device
+        self.device = list(self._module_parameters)[0].device
         self.broadcast_buffers = broadcast_buffers
         self.find_unused_parameters = find_unused_parameters
         self.require_backward_grad_sync = True
         self.require_forward_param_sync = True
         self.gradient_as_bucket_view = gradient_as_bucket_view
-        if hasattr(module, "_ddp_params_and_buffers_to_ignore"):
-            self.parameters_to_ignore = module._ddp_params_and_buffers_to_ignore
-        else:
-            self.parameters_to_ignore = []

         self._use_replicated_tensor_module = (
             _ddp_with_replicated_tensor_enabled()
@@ -647,7 +645,7 @@ def __init__(
             )

         # Check that a module does not have Uninitialized parameters
-        for param in module.parameters():
+        for param in self._module_parameters:
             if isinstance(param, torch.nn.parameter.UninitializedParameter):
                 self._log_and_throw(
                     RuntimeError,
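
Taken together, the distributed.py hunks make every parameter walk in __init__ use the pre-filtered self._module_parameters list instead of module.parameters(). A tiny stand-alone illustration of that filtering; the module and ignored FQN are made up for the sketch.

import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
ignored = {"0.weight"}  # FQNs registered via _set_params_and_buffers_to_ignore_for_model

# mirrors the new self._module_parameters computation in __init__ above
module_parameters = [p for n, p in model.named_parameters() if n not in ignored]

print([n for n, _ in model.named_parameters()])  # ['0.weight', '0.bias', '1.weight', '1.bias']
print(len(module_parameters))                    # 3 (the ignored '0.weight' is skipped)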
