[FSDP][feature] optimizer state dict save and load #537
@@ -1346,6 +1352,244 @@ def assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None
            traceback.print_stack()
            raise ValueError(msg)

    # Optim State dict functions
I considered moving these to a separate `FSDPOptimizerMixin` in `fsdp_optimizer_utils.py`, but decided it wasn't really a mixin since it depends heavily on FSDP.
About half-way through, leaving initial comments and will post rest in second batch
if rank == self.rank:
    sd = optim.state_dict()
    sd["num_padded"] = [m.num_padded for m in self.modules() if isinstance(m, FullyShardedDataParallel)]
    if should_collect_state:
        _all_optimizer_states.append(
            recursive_copy_to_device(sd, non_blocking=True, device=torch.device("cpu"))
        )

    # Sync with other replicas
    state_to_share = (
        sd if should_send_state else torch.tensor([0], dtype=torch.uint8, device=_default_device)
    )
    broadcast_object(
        state_to_share, src_rank=self.rank, group=self.process_group, dist_device=_default_device,
    )
else:
    # Fetch the optim state from the other replicas
    replica_state = broadcast_object(
        torch.tensor([0], dtype=torch.uint8, device=_default_device),
        src_rank=rank,
        group=self.process_group,
        dist_device=_default_device,
    )

    if should_collect_state:
        _all_optimizer_states.append(
            recursive_copy_to_device(replica_state, non_blocking=True, device=torch.device("cpu"))
        )
Can this be rearranged to remove some duplication? Something like:
    for rank in range(self.world_size):
        if rank == self.rank:
            state = optim.state_dict()
            state["num_padded"] = ...
            state = broadcast_object(state, src_rank=rank, ...)
        else:
            state = broadcast_object(None, src_rank=rank, ...)
        if should_collect_state:
            _all_optimizer_states.append(recursive_copy_to_device(state, device=torch.device("cpu")))
Just copy-pasted this func from OSS. I think the reason for the extra append is to save useless communication from `recipient_rank` to `recipient_rank`.
I have the simplified implem working with `torch.distributed.broadcast_object_list`.
I no longer need compute_device. Still calling `lazy_init_` for safety.
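The simplified gather loop discussed above can be sketched without a process group. This is a single-process simulation under stated assumptions, not the PR's code: `fake_broadcast` is a hypothetical stand-in for a collective such as `torch.distributed.broadcast_object_list`, and `gather_states` and the per-rank state dicts are made up for illustration.

```python
from copy import deepcopy
from typing import Any, Dict, List, Optional


def fake_broadcast(obj: Any, src_rank: int, per_rank_payloads: List[Any]) -> Any:
    """Hypothetical stand-in for a broadcast: every caller receives src_rank's object."""
    return deepcopy(per_rank_payloads[src_rank])


def gather_states(
    local_states: List[Dict[str, Any]], my_rank: int, recipient_rank: Optional[int] = None
) -> List[Dict[str, Any]]:
    """Simulate what rank `my_rank` ends up holding after one broadcast per rank."""
    world_size = len(local_states)
    # recipient_rank=None means every rank keeps a copy of all states.
    should_collect = recipient_rank is None or my_rank == recipient_rank
    collected: List[Dict[str, Any]] = []
    for rank in range(world_size):
        # The source rank sends its own state; all other ranks pass a placeholder.
        payload = local_states[my_rank] if rank == my_rank else None
        state = fake_broadcast(payload, src_rank=rank, per_rank_payloads=local_states)
        if should_collect:
            collected.append(state)
    return collected
```

With three simulated ranks, `gather_states(states, my_rank=0, recipient_rank=0)` returns all three state dicts, while the same call with `my_rank=1` returns an empty list. Unlike the original version, the recipient also "broadcasts to itself" here, which is the extra communication mentioned in the reply above.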
Looks like really cool stuff. Some high-level questions about the context:
It seems that this is only needed when we change world size between save and restore? If the world size has not changed, is a normal save/restore with only the sharded data OK?
Do we envision a use case for calling `consolidate_optim_state_dict` without calling `gather_full_optim_state_dict`?
If not, perhaps simplify the interface to:
    fsdp = FSDP(world_size=4)
    optim = Adam(fsdp.parameters())
    full_state_dict = fsdp.gather_full_optim_state_dict(optim, recipient_rank=-1)
# combined_state refers to tensor values in sd[state][param_id].
# Here we just aggregate them into a list inside the dictionary from a list of dictionaries.
combined_state = self._combine_tensor_optim_state(
    [x["state"] for x in self._all_optimizer_states], self.world_size
)

# constant_state refers to entries in sd[state][param_id] that are not tensors, like "step"
# we check that these are identical across workers and then take the first
constant_state = [self._extract_constant_state(combined_state, id) for id in combined_state]
these comments/helper methods are very nice 😄
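The two steps described in those code comments can be illustrated with a small, hedged sketch. It is not the PR's implementation: `combine_state` and `extract_constant_state` are hypothetical names, and plain Python lists stand in for tensor shards so that the example runs without torch.

```python
from typing import Any, Dict, List


def combine_state(per_rank_state: List[Dict[int, Dict[str, Any]]]) -> Dict[int, Dict[str, List[Any]]]:
    """Turn a list of {param_id: {name: value}} into {param_id: {name: [value per rank]}}."""
    combined: Dict[int, Dict[str, List[Any]]] = {}
    for rank_state in per_rank_state:
        for param_id, entries in rank_state.items():
            slot = combined.setdefault(param_id, {})
            for name, value in entries.items():
                slot.setdefault(name, []).append(value)
    return combined


def extract_constant_state(combined: Dict[int, Dict[str, List[Any]]], param_id: int) -> Dict[str, Any]:
    """Keep non-sharded entries (e.g. "step"), checking that all ranks agree on them."""
    constant: Dict[str, Any] = {}
    for name, values in combined[param_id].items():
        if isinstance(values[0], list):  # assumption: lists stand in for tensor shards
            continue
        assert all(v == values[0] for v in values), f"{name} differs across ranks"
        constant[name] = values[0]
    return constant
```

For two ranks that each hold `{"step": 3, "exp_avg": [...]}` for parameter 0, `combine_state` yields one list of shards under `"exp_avg"`, and `extract_constant_state` returns `{"step": 3}`.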
if next_global_param_id == 0:  # stateless optimizer
    num_params = sum([len(m._param_numels) for m in instance_list])  # type: ignore
    new_state_dict["param_groups"][pg_id]["params"] = list(range(num_params))
this list could be quite large, right? I guess this only affects SGD w/o momentum, but I wonder if there's a more compact way. Let's not worry about it for now, but perhaps put a note or TODO to make it more efficient
Are you talking about `list(range(num_params))`? If so, it affects both cases.
I'll leave a TODO.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""These files are used by fsdp to help consolidate and shard optimizer states."""
❤️ this
Seems like super solid work. I added some minor comments. I didn't check the logic in detail, mainly because I have two high-level questions:
1. Should we consider an optimizer wrapper that works together with fsdp to get the full state? Right now everything seems to be in fsdp. Would an optimizer wrapper help more? I haven't thought this through.
2. I have been thinking that fsdp should support a "streaming" mode for full state so that no single rank needs to hold all of the (non-sharded) state in memory. Should this PR try to do streaming to avoid overly big state?
Both 1 and 2 above are kind of independent of this PR. Just wanted to put them out there in case they are helpful. If not, just let me know and I will dive deep into this version of the code and give it a more detailed review. Thanks!
@@ -19,9 +19,10 @@
from torch.nn import Parameter
import torch.nn.functional as F

import fairscale.nn.data_parallel.fsdp_optim_utils as ou
A relative import like `import .fsdp_optim_utils as ou` is more portable?
`SyntaxError: invalid syntax` :(
Got it. Perhaps `from . import fsdp_optim_utils as ou`?
That works!
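The difference between the two import spellings in this exchange can be checked with the built-in `compile`, which parses source without running it. This is a small illustrative sketch; `compiles` is a hypothetical helper, not part of the PR.

```python
def compiles(source: str) -> bool:
    """Return True if `source` is syntactically valid Python."""
    try:
        compile(source, "<example>", "exec")
        return True
    except SyntaxError:
        return False


# A plain `import` statement takes an absolute dotted name, so a leading dot
# is rejected by the parser.
print(compiles("import .fsdp_optim_utils as ou"))

# The relative form has to be spelled with `from ... import ...`.
print(compiles("from . import fsdp_optim_utils as ou"))
```

The first call reports a syntax error and the second does not. Whether the relative import then resolves at runtime still depends on the module being imported as part of a package.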
@@ -88,8 +89,8 @@ class FullyShardedDataParallel(nn.Module):
         import torch
         from fairscale.nn.auto_wrap import enable_wrap, auto_wrap
         from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
-        fsdp_params = dict(mixed_precision=True, flatten_parameters=True)
-        with enable_wrap(wrapper_cls=FSDP, **fsdp_params):
+        fsdp_params = dict(wrapper_cls=FSDP, mixed_precision=True, flatten_parameters=True)
Thanks for fixing the doc here!
Reviewed the test file first. It looks very good. Minor comments.
from parameterized import parameterized
import torch
from torch.optim import SGD, Adadelta, Adam  # type: ignore
we usually don't have typing in test files. so "type: ignore" is not needed?
I was getting "torch.optim has no Attribute Adadelta" from mypy without this, using `mypy --ignore-missing-imports --scripts-are-modules --pretty .` from `fs_test`.
I see. magic mypy. I thought it would skip the whole file since there isn't any type annotation in it.
try:
    fsdp_optim = optim_fn(fsdp.parameters(), lr=0.01,)
    optim_unwrapped = optim_fn(unwrapped_model.parameters(), lr=0.01)
except TypeError:  # AdaScale
do you actually mean "AdaScale" here? I don't see AdaScale being used here in this test.
yes, nice catch
# Switching from fairscale.optim.utils.broadcast_object to torch.broadcast_object_list will cause this to raise
assert duration < fsdp.world_size, f"gather optim state took {duration} seconds, suspect change in _consolidate"
This is interesting. Thanks for the comment. What is the usual value for `duration`? I am surprised that it is somehow connected with `world_size`, which is not even in units of seconds.
It takes longer to gather from 8 nodes than 4 than 2.
This actually takes 4 ms, but I accidentally regressed it during development and caused it to take 8 seconds for world size 2, 13 for world size 4.
Now that it's fixed I want to prevent it happening again, agreed that the units are arbitrary.
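A timing guard of this kind can be written with an explicit budget in seconds, which sidesteps the unit mismatch raised above. The sketch below is an assumption-laden illustration, not the test in this PR: `timed` and `assert_within_budget` are hypothetical helpers, and the budget value is arbitrary.

```python
import time
from typing import Callable


def timed(fn: Callable[[], object]) -> float:
    """Run fn once and return the elapsed wall-clock time in seconds."""
    start = time.perf_counter()
    fn()
    return time.perf_counter() - start


def assert_within_budget(fn: Callable[[], object], budget_seconds: float, label: str) -> float:
    """Fail loudly if fn takes longer than the budget; return the measured duration."""
    duration = timed(fn)
    assert duration < budget_seconds, (
        f"{label} took {duration:.3f} seconds, budget was {budget_seconds} seconds"
    )
    return duration


# Toy workload standing in for the gather call; a generous budget keeps the
# check from flaking on slow CI machines while still catching a 1000x regression.
assert_within_budget(lambda: sum(range(1000)), budget_seconds=1.0, label="toy workload")
```

A budget that scales with world size, as in the test under review, could be expressed as `budget_seconds=1.0 * world_size` to make the unit explicit.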
sum([first_tensor_shape(v) for k, v in sd["state"].items()]),
sum([first_tensor_shape(v) for k, v in unwrapped_sd["state"].items()]),
Perhaps `norm` will be slightly better than `sum` for comparison in case both tensors sum to the same values? Same with lines 110, 111.
This just checks that we have the same number of elements as the base model after unflattening.
I renamed `first_tensor_shape` -> `first_tensor_numel` to make it clearer.
Reviewed fsdp changes. I am not sure if nested FSDP cases are well supported by this change.
- Are the APIs really only intended for the root instance?
- Should the root and all inner instances have flatten == True?
- Do all instances need to have world_size == default world_size?
If so, can you assert that these hold in the APIs so that we don't accidentally produce incorrect optim states or crash with non-obvious errors?
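The three preconditions listed above could be checked up front along the following lines. This is only a sketch under stated assumptions: `FakeInstance` is a hypothetical stand-in for an FSDP instance, `is_root` is an assumed attribute name, and `check_optim_state_preconditions` is not a function from this PR.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class FakeInstance:
    """Hypothetical stand-in exposing only the attributes the checks read."""

    flatten_parameters: bool
    world_size: int
    is_root: bool = False  # assumed attribute name, for illustration only


def check_optim_state_preconditions(root: FakeInstance, instances: List[FakeInstance]) -> None:
    """Raise AssertionError with a specific message if an assumption is violated."""
    assert root.is_root, "optimizer state APIs should be called on the root instance"
    for i, inst in enumerate(instances):
        assert inst.flatten_parameters, f"instance {i} has flatten_parameters=False"
        assert inst.world_size == root.world_size, (
            f"instance {i} has world_size={inst.world_size}, expected {root.world_size}"
        )
```

Failing early with a message that names the offending instance is the point of the request: a nested instance with its own process group would otherwise surface later as a shape or count mismatch.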
def _consolidate_optim_state_dict(
    self, optim: torch.optim.Optimizer, recipient_rank: Optional[int] = None
) -> List[Dict]:
    """Update the consolidated state_dict list, one per rank.
should this be called only on the root FSDP instance?
Yes, more specifically it should be called on the instance that was the argument to `optimizer(model.parameters())`. Are there other cases?
should_collect_state = recipient_rank is None or (self.rank == recipient_rank)
all_states: List[Dict[str, Any]] = []
dummy_tensor = torch.tensor([0], dtype=torch.uint8, device=self.compute_device)
for rank in range(self.world_size):
There might be complications here when nested FSDP instances have different world_size, right? For example, if BN layers are in their own world_size == 1 process groups, then we collect duplicated states for them? Add a TODO?
added TODO in the caller
# Assert nesting is the same as it was at save time
instance_list = self._fsdp_instances
ou.check_param_counts_before_sharding(full_optim_state_dict, len(instance_list))
if self.flatten_parameters:
does this assume all inner FSDP instances also have flatten == True?
Yes, will assert
@@ -122,15 +122,15 @@ def _flatten_params(self, flat_param: Optional[nn.Parameter] = None) -> None:
         # register the views as plain attributes
         self._unflatten_params_as_views()

-    def _get_param_views(self, flat_param: Tensor) -> Generator:
+    def get_param_views(self, flat_param: Tensor) -> Generator:
Since this is becoming a public method, can you please:
- add a docstring with proper documentation
- assert that flat_param is valid before using it?
Thanks for the comments!
Finished reviewing. Great step forward. I wish there were more comments in fsdp_optim_utils.py so that I could follow along better. I tried my best and it seems to make sense. It could probably be simplified and individually tested, but we can iterate on that later as we learn more.
    return unflat_state, global_to_local_id


def build_unflat_state_dict(instance_list: List[torch.nn.Module], world_optim_states: List[Dict]) -> Dict:
add a docstring?
Co-authored-by: Min Xu <24926999+min-xu-ai@users.noreply.github.com>
I'm gonna merge this tomorrow AM unless there are further comments or a CI failure @myleott
Overview
Future Work
flatten_parameters=False
On the fairseq side, I tested running with 4 gpus and loading with 2 and this worked.
Assumptions
(0) flatten_parameters=True
(1) If there is a tensor in optimizer state, it is the same size as, and corresponds to, a tensor in model state. If there are singleton tensors in the optimizer, or tensors that correspond to the average update for a column of params (so shaped differently), things will break.
(2) We assume that these two lists are the same if we account for padding:
We use this assumption to call mlist[i].get_params_view(flat_param=params_unpadded[i]).
New overhead introduced
_get_shard now returns how many padding elements it introduced.
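
Why the padding count matters can be shown with a toy example. The sketch below is not fairscale's `_get_shard`: `get_shard` and `consolidate` are hypothetical functions operating on plain Python lists, assuming the flat parameter is split into equal-sized chunks with zeros appended to the short ones.

```python
from typing import List, Tuple


def get_shard(flat: List[float], rank: int, world_size: int) -> Tuple[List[float], int]:
    """Return this rank's equal-sized chunk and the number of padding zeros added to it."""
    chunk_size = -(-len(flat) // world_size)  # ceiling division
    shard = flat[rank * chunk_size : (rank + 1) * chunk_size]
    num_padded = chunk_size - len(shard)
    return shard + [0.0] * num_padded, num_padded


def consolidate(shards: List[List[float]], num_padded: List[int]) -> List[float]:
    """Strip each shard's padding and concatenate to recover the unpadded flat list."""
    full: List[float] = []
    for shard, pad in zip(shards, num_padded):
        full.extend(shard[: len(shard) - pad])
    return full


flat = [1.0, 2.0, 3.0, 4.0, 5.0]
results = [get_shard(flat, rank, world_size=4) for rank in range(4)]
shards = [s for s, _ in results]
pads = [p for _, p in results]
assert consolidate(shards, pads) == flat
```

Without the recorded padding counts, the consolidated state would contain extra zeros and no longer line up element for element with the unflattened model parameters, which is what assumption (2) relies on.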