Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FSDP] Fix for optim state dict #102901

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 58 additions & 1 deletion test/distributed/fsdp/test_fsdp_hybrid_shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
import torch.nn as nn

from torch.distributed.distributed_c10d import _rank_not_in_group
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
from torch.distributed.fsdp import (
FullyShardedDataParallel as FSDP,
ShardingStrategy,
StateDictType,
)
from torch.distributed.fsdp._init_utils import HYBRID_SHARDING_STRATEGIES
from torch.distributed.fsdp.wrap import ModuleWrapPolicy
from torch.nn import TransformerDecoderLayer, TransformerEncoderLayer
Expand Down Expand Up @@ -76,6 +80,9 @@ def __init__(self):
self.lin2 = nn.Linear(10, 10)
self.lin3 = nn.Linear(10, 10)

def forward(self, x):
    """Run ``x`` through ``lin1``, ``lin2`` and ``lin3`` in that order."""
    out = self.lin1(x)
    out = self.lin2(out)
    out = self.lin3(out)
    return out


class ShardingStrategyMode(Enum):
ALL_HYBRID_SHARD = auto()
Expand Down Expand Up @@ -144,6 +151,56 @@ def test_hybrid_shard_pg_mismatch_raises(self):
):
model(inp)

@skip_if_lt_x_gpu(4)
def test_hsdp_save_load_state_dict(self):
    """Save and reload sharded model/optimizer state with HYBRID_SHARD.

    The visible devices are split into two halves, each forming one shard
    group; replicate groups pair up the ranks that hold the same position
    in each half. An FSDP model built on those groups takes one optimizer
    step, its sharded model and optimizer state dicts are captured, and
    both are loaded into a freshly constructed model/optimizer pair.
    """
    model = MyModel().cuda()
    num_node_devices = torch.cuda.device_count()
    half = num_node_devices // 2

    # Two shard groups: the lower and the upper half of the device ranks.
    shard_rank_lists = (
        list(range(0, half)),
        list(range(half, num_node_devices)),
    )
    # Every rank calls new_group for every group, members or not.
    shard_groups = (
        dist.new_group(shard_rank_lists[0]),
        dist.new_group(shard_rank_lists[1]),
    )
    my_shard_group = (
        shard_groups[0] if self.rank in shard_rank_lists[0] else shard_groups[1]
    )

    my_replicate_group = None
    my_rank = self.rank
    # Create groups like (0, 4), (1, 5), (2, 6) etc and assign appropriately
    shard_factor = len(shard_rank_lists[0])
    for i in range(half):
        replicate_group_ranks = list(range(i, num_node_devices, shard_factor))
        replicate_group = dist.new_group(replicate_group_ranks)
        if my_rank in replicate_group_ranks:
            my_replicate_group = replicate_group

    fsdp_ctor = partial(
        FSDP,
        sharding_strategy=ShardingStrategy.HYBRID_SHARD,
        use_orig_params=True,
        process_group=(my_shard_group, my_replicate_group),
    )
    model = fsdp_ctor(model)
    optim = torch.optim.AdamW(model.parameters())
    # Initialize optimizer states
    model(torch.randn(2, 10)).sum().backward()
    optim.step()

    # The wrapped model must use exactly the groups that were passed in.
    shard_g = model.process_group
    replicate_g = model._inter_node_state.process_group
    assert shard_g == my_shard_group
    assert replicate_g == my_replicate_group

    with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
        msd = model.state_dict()
        osd = FSDP.optim_state_dict(model, optim)

    load_model = fsdp_ctor(MyModel().cuda())
    load_optim = torch.optim.AdamW(load_model.parameters())
    with FSDP.state_dict_type(load_model, StateDictType.SHARDED_STATE_DICT):
        load_model.load_state_dict(msd)
        # optim_state_dict_to_load returns the converted state dict; that
        # return value, not the raw ``osd``, is what the optimizer loads.
        osd_to_load = FSDP.optim_state_dict_to_load(load_model, load_optim, osd)
    load_optim.load_state_dict(osd_to_load)

@skip_if_lt_x_gpu(2)
def test_invalid_pg_specification_raises(self):
pol = ModuleWrapPolicy({nn.Linear})
Expand Down
2 changes: 1 addition & 1 deletion torch/distributed/fsdp/_optim_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1497,7 +1497,7 @@ def _all_gather_optim_state(
object_list: List[StateInfo] = [
processed_state for _ in range(fsdp_state.world_size)
]
dist.all_gather_object(object_list, processed_state)
dist.all_gather_object(object_list, processed_state, group=fsdp_state.process_group)

# Convert the gathered, pre-processed state of each rank to the original one.
gathered_state: Dict[str, Any] = {}
Expand Down