broadcast the whole optimizer state to each rank
Summary:
OSS removed the 'partition' key from its state dict to accommodate changing partition sizes. This requires an update on the fairseq side: stop looking into the parameter partition, just broadcast everything, and let the optimizer on each rank decide which parameters are relevant to it.
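
In other words, rank 0 no longer slices state_dict["param_groups"] and state_dict["state"] per destination rank; it sends the complete dict and each rank keeps only what it owns when loading. A minimal sketch of that flow, assuming an initialized torch.distributed process group and a sharded optimizer (e.g. fairscale's OSS) whose load_state_dict ignores parameters the rank does not own; broadcast_full_optimizer_state is a hypothetical helper, not fairseq API:

    import torch.distributed as dist

    def broadcast_full_optimizer_state(optimizer, state_dict=None, src_rank=0):
        # every rank joins the same collective; only src_rank supplies the payload
        obj_list = [state_dict if dist.get_rank() == src_rank else None]
        dist.broadcast_object_list(obj_list, src=src_rank)
        full_state = obj_list[0]
        # the sharded optimizer is expected to keep only the entries for the
        # parameters this rank actually owns when loading the full dict
        optimizer.load_state_dict(full_state)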

This diff also needs D26419095 to function completely, and blefaudeux has made fixes upstream in facebookresearch/fairscale#383

Reviewed By: myleott

Differential Revision: D26382917

fbshipit-source-id: 95af1022be59e88814748acaee36a1a350f7dc5b
Weiyi Zheng authored and harkash committed Feb 23, 2021
1 parent d7d045d commit dcfae95
Showing 1 changed file with 10 additions and 48 deletions.
58 changes: 10 additions & 48 deletions fairseq/optim/shard.py
@@ -5,11 +5,11 @@
 
 from typing import Any, Dict
 
-import torch
+from fairseq.distributed import utils
 
 
 try:
-    from fairscale.optim import OSS, utils
+    from fairscale.optim import OSS
 
     _has_fairscale = True
 except ImportError:
@@ -38,53 +38,15 @@ def broadcast_global_state_dict(
             self, state_dict: Dict[str, Any]
         ) -> Dict[str, Any]:
             """
-            Broadcasts the relevant parts of a global state dict from rank 0 to
-            all other ranks.
+            Broadcasts the entire state_dict to all other ranks
+            each rank is responsible to load their own partition of data
             """
-            if self.rank == 0:
-
-                # Create template state dict for all other keys not related to sharding
-                template_state_dict = {
-                    key: state_dict[key]
-                    for key in state_dict
-                    if key not in ("param_groups", "state")
-                }
-                template_state_dict["local_state_dict"] = True
-
-                for dst_rank in range(self.world_size):
-                    # Get the dst_rank's param_groups shard
-                    send_state = {
-                        "param_groups": state_dict["param_groups"][
-                            state_dict["partition"][dst_rank][0] : state_dict[
-                                "partition"
-                            ][dst_rank][1]
-                        ],
-                        "state": state_dict["state"][dst_rank],
-                    }
-                    send_state.update(template_state_dict)
-
-                    if dst_rank == 0:
-                        recv_state = send_state
-                    else:
-                        utils.broadcast_object(
-                            send_state,
-                            src_rank=0,
-                            group=self.group,
-                            dist_device=self._device,
-                        )
-            else:
-                empty_buffer = torch.tensor([0], dtype=torch.uint8, device=self._device)
-                for dst_rank in range(1, self.world_size):
-                    state = utils.broadcast_object(
-                        empty_buffer,
-                        src_rank=0,
-                        group=self.group,
-                        dist_device=self._device,
-                    )
-                    if dst_rank == self.rank:
-                        recv_state = state
-
-            return recv_state
+            return utils.broadcast_object(
+                state_dict,
+                src_rank=0,
+                group=self.group,
+                dist_device=self._device,
+            )
 
     torch_optimizer = optimizer.optimizer
     optim_cls = type(torch_optimizer)
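
For context, an object broadcast like the utils.broadcast_object call above is commonly implemented by pickling the object and broadcasting the byte length followed by the bytes. The sketch below is illustrative only and is not the actual fairseq.distributed.utils implementation (which D26419095 provides):

    import pickle

    import torch
    import torch.distributed as dist

    def broadcast_object(obj, src_rank=0, group=None, dist_device="cpu"):
        # serialize on the source rank and broadcast the buffer length first,
        # so the other ranks can allocate a receive buffer of the right size
        if dist.get_rank() == src_rank:
            data = list(pickle.dumps(obj))
            length = torch.tensor([len(data)], dtype=torch.long, device=dist_device)
            buffer = torch.tensor(data, dtype=torch.uint8, device=dist_device)
        else:
            length = torch.zeros(1, dtype=torch.long, device=dist_device)

        dist.broadcast(length, src=src_rank, group=group)
        if dist.get_rank() != src_rank:
            buffer = torch.empty(int(length.item()), dtype=torch.uint8, device=dist_device)
        dist.broadcast(buffer, src=src_rank, group=group)

        # every rank now holds identical bytes and can rebuild the object
        return pickle.loads(bytes(buffer.cpu().tolist()))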
