[fix][minor] Change empty shard handling for OSS, do not rely on asserts (#460)

* change empty shard handling for OSS, do not rely on asserts
* code review
blefaudeux authored Mar 5, 2021
1 parent f565d44 commit d1fab39
Showing 2 changed files with 26 additions and 17 deletions.
fairscale/optim/oss.py (7 additions, 10 deletions)
@@ -140,13 +140,6 @@ def partition_parameters(self) -> List[List[dict]]:
                     param_group_rank["params"] = params
                     self._partition_parameters[rank].append(param_group_rank)
 
-            assert min(sum(len(pg["params"]) for pg in partition) for partition in self._partition_parameters) > 0, (
-                "One or more empty shards detected, the world size is too big or the model too small.\n"
-                + "Please reduce your world size if this is the model you would like to train\n"
-                + f"Current world size: {self.world_size}\n"
-                + "Current number of parameters: {}".format(sum(len(pg["params"]) for pg in self.param_groups))
-            )
-
         return self._partition_parameters
 
     @property
@@ -552,8 +545,11 @@ def _broadcast_params(self) -> None:
 
         for device in self.buckets.keys():
             for src_rank, bucket in enumerate(self.buckets[device]):
-                global_src_rank = self.get_global_rank(self.group, src_rank)
-                last_work_handle = dist.broadcast(tensor=bucket, src=global_src_rank, group=self.group, async_op=True)
+                if bucket.numel() > 0:
+                    global_src_rank = self.get_global_rank(self.group, src_rank)
+                    last_work_handle = dist.broadcast(
+                        tensor=bucket, src=global_src_rank, group=self.group, async_op=True
+                    )
 
         # Only check on the last handle, they're all inlined on the same CUDA stream
         if last_work_handle:
@@ -597,4 +593,5 @@ def _setup_flat_buffers(self) -> None:
                     else:
                         self.buckets[device][dst_rank] = bucket
                 else:
-                    self.buckets[device].append(torch.zeros(1, device=device))
+                    # This rank has an empty shard, that's fine
+                    self.buckets[device].append(torch.zeros(0, device=device))
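For illustration, here is a minimal single-process sketch of the guard added to _broadcast_params above. It is not part of the commit; the broadcast_buckets helper and the gloo/localhost setup are assumptions made for the example. A zero-element tensor stands in for the bucket of a rank that owns no parameters: skipping it means no broadcast is ever issued for that rank, which is exactly why _setup_flat_buffers can now hand out zero-element buckets for empty shards.

# Minimal sketch, assuming a single-process gloo group on localhost.
# A zero-element bucket stands in for a rank with an empty shard; without the
# numel() guard the loop would try to broadcast from a rank that owns nothing.
import os

import torch
import torch.distributed as dist


def broadcast_buckets(buckets, group=None):
    # Mirrors the patched loop: only non-empty buckets are broadcast.
    last_work_handle = None
    for src_rank, bucket in enumerate(buckets):
        if bucket.numel() > 0:
            last_work_handle = dist.broadcast(tensor=bucket, src=src_rank, group=group, async_op=True)

    # Only wait on the last handle, as in the original loop
    if last_work_handle:
        last_work_handle.wait()


if __name__ == "__main__":
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29501")
    dist.init_process_group(backend="gloo", rank=0, world_size=1)

    buckets = [torch.ones(4), torch.zeros(0)]  # index 1 plays the part of an empty shard
    broadcast_buckets(buckets)
    print(buckets[0])  # unchanged, broadcast from rank 0 to itself

    dist.destroy_process_group()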
tests/optim/test_oss.py (19 additions, 7 deletions)
@@ -262,19 +262,31 @@ def test_zero_grad():
     mp.spawn(run_test_zero_grad, args=(world_size, temp_file_name), nprocs=world_size, join=True)
 
 
-def run_test_catch_empty_shardd(rank, world_size, tempfile_name):
-    dist_init(rank, world_size, tempfile_name, backend="gloo")
+def run_test_empty_shard(rank, world_size, tempfile_name, backend):
+    dist_init(rank, world_size, tempfile_name, backend=backend)
     m = torch.nn.Linear(1, 1)
-    with pytest.raises(AssertionError):
-        _ = optim.OSS(m.parameters(), lr=0.1)
+    x = torch.rand(20, 1)
+
+    if torch.cuda.is_available():
+        m = m.to(rank)
+        x = x.to(rank)
+
+    o = optim.OSS(m.parameters(), lr=0.1)
+    y = m(x).sum()
+    y.backward()
+    o.step()
 
     dist.destroy_process_group()
 
 
-def test_empty_shard():
+@pytest.mark.parametrize("backend", ["gloo", "nccl"])
+def test_empty_shard(backend):
     world_size = 4
-
-    mp.spawn(run_test_catch_empty_shardd, args=(world_size, tempfile.mkstemp()[1]), nprocs=world_size, join=True)
+    if torch.cuda.is_available() and torch.cuda.device_count() < world_size:
+        world_size = min(world_size, torch.cuda.device_count())
+    if world_size == 1 or (backend == "nccl" and not torch.cuda.is_available()):
+        pytest.skip("Not enough GPUs to test with NCCL, or CUDA not present")
+    mp.spawn(run_test_empty_shard, args=(world_size, tempfile.mkstemp()[1], backend), nprocs=world_size, join=True)
 
 
 def run_test_step(rank, world_size, tempfile_name):
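As a usage note, the scenario the updated test covers can also be reproduced outside the test harness with a short sketch; the worker helper, the gloo backend, and the port number below are assumptions, not part of the change. torch.nn.Linear(1, 1) has only two parameters, so sharding it across four ranks leaves two ranks with empty shards; with this commit OSS builds and steps normally instead of tripping the removed assert.

# Standalone sketch of the empty-shard scenario, assuming a fairscale build that
# includes this change; the worker name and port are illustrative.
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

from fairscale.optim import OSS


def worker(rank, world_size):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29502"
    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)

    model = torch.nn.Linear(1, 1)  # 2 parameters, so only 2 of the 4 shards are non-empty
    optimizer = OSS(model.parameters(), lr=0.1)  # previously raised AssertionError here

    loss = model(torch.rand(20, 1)).sum()
    loss.backward()
    optimizer.step()  # ranks holding empty shards simply skip their broadcast

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(4,), nprocs=4, join=True)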
