[FSDP][feature] optimizer state dict save and load #537
New file (imported elsewhere in this diff as `fairscale.nn.data_parallel.fsdp_optim_utils`):

@@ -0,0 +1,151 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""These functions are used by FSDP to help consolidate and shard optimizer states."""
import copy
from typing import Dict, Generator, List, Tuple

import torch


# This function helps shard a full optimizer state dict
def flatten_optim_state_dict(sd: Dict) -> Dict:
    """Called by FSDP.get_shard_from_optim_state_dict"""
    param_id_map = sd["param_id_map"]
    num_local_params = len(set(param_id_map.values()))
    if sd["state"]:
        new_state: Dict = {local_id: {} for local_id in range(num_local_params)}
    else:
        new_state = {}
    constant_state = {}

    # assumes sd is sorted
    for expanded_pid, buffers in sd["state"].items():
        consolidated_pid = param_id_map[expanded_pid]
        for buffer_name, p in buffers.items():
            if torch.is_tensor(p):
                if buffer_name not in new_state[consolidated_pid]:
                    new_state[consolidated_pid][buffer_name] = []
                new_state[consolidated_pid][buffer_name].append(p.reshape(-1))
            else:
                assert isinstance(p, (float, int)), f"unexpected type {type(p)} in optimizer state[{buffer_name}]"
                constant_state[buffer_name] = p
                # TODO(SS): THIS COULD BE WRONG. What if step is different for different params... At least check

    for consolidated_pid, state in new_state.items():
        for buffer_name, tensors in state.items():
            new_state[consolidated_pid][buffer_name] = torch.cat(tensors)
        new_state[consolidated_pid].update(constant_state)
    new_sd = {"state": new_state, "param_groups": sd["param_groups"]}

    for pg_id, _ in enumerate(sd["param_groups"]):
        # TODO: this list could be huge. Can we avoid materializing?
        new_sd["param_groups"][pg_id]["params"] = list(range(num_local_params))

    return new_sd

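For intuition, here is a minimal toy sketch of the consolidation this function performs. It assumes the new module is importable as `fairscale.nn.data_parallel.fsdp_optim_utils`, and uses made-up Adam-style buffers (`step`, `exp_avg`) for two original parameters that were flattened into a single local flat parameter:

import torch

from fairscale.nn.data_parallel import fsdp_optim_utils as ou

# Hypothetical "full" optimizer state: global param ids 0 and 1 both map to local flat param 0.
sd = {
    "param_id_map": {0: 0, 1: 0},
    "param_groups": [{"lr": 0.1, "params": [0, 1]}],
    "state": {
        0: {"step": 10, "exp_avg": torch.zeros(2, 3)},  # e.g. a 2x3 weight
        1: {"step": 10, "exp_avg": torch.zeros(3)},     # e.g. its bias
    },
}

flat = ou.flatten_optim_state_dict(sd)
assert list(flat["state"].keys()) == [0]          # one entry per local flat parameter
assert flat["state"][0]["exp_avg"].shape == (9,)  # 6 + 3 elements, flattened and concatenated
assert flat["state"][0]["step"] == 10             # non-tensor state is carried over unchanged
assert flat["param_groups"][0]["params"] == [0]
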
# All the functions below help with saving the list of optimizer states, one from each rank.
# build_unflat_state_dict is the interface used by FSDP.
def _extract_constant_state(combined_state: Dict[int, Dict[str, List]], param_id: int) -> Dict:
    constant_state = {}  # This state is like "step" in Adam: not a tensor, so we don't unpad or cat it.
    for k, v in combined_state[param_id].items():
        if torch.is_tensor(v[0]):
            continue
        elif len(set(v)) == 1:
            constant_state[k] = v[0]
        else:
            raise TypeError(f"Don't know how to expand optimizer param {k} with values {v}")
    return constant_state

def _combine_tensor_optim_state(states: List[Dict]) -> Dict[int, Dict]:
    combined_state = states[0]
    for param_id in combined_state:
        combined_state[param_id] = {k: [v] for k, v in combined_state[param_id].items()}
    if len(states) == 1:
        return combined_state

    for rank, s in enumerate(states[1:]):
        for param_id, param_state in s.items():
            for k, tensor in param_state.items():
                combined_state[param_id][k].append(tensor)
    return combined_state

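A quick sketch of what this combining step produces for two ranks, assuming `_combine_tensor_optim_state` is in scope (for example, imported from the new module). The buffer names and values are made up, and since the function mutates its first argument in place, the example deep-copies its inputs:

import copy
import torch

rank0_state = {0: {"step": 10, "exp_avg": torch.zeros(4)}}
rank1_state = {0: {"step": 10, "exp_avg": torch.ones(4)}}

combined = _combine_tensor_optim_state([copy.deepcopy(rank0_state), copy.deepcopy(rank1_state)])
assert len(combined[0]["exp_avg"]) == 2  # one shard per rank, concatenated later by _unflatten_optim_state
assert combined[0]["step"] == [10, 10]   # non-tensor values are also collected per rank
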
def _unflatten_optim_state(
    combined_state: Dict[int, Dict], instance_list: List[torch.nn.Module], world_pad_info: List[List[List[int]]],
) -> Tuple[Dict[int, Dict], Dict[int, int]]:
    local_to_global_param_id: Dict[int, List[int]] = {}
    next_global_id = 0  # gets incremented
    unflat_state = {}
    pad_info = {id: [s[id][0] for s in world_pad_info] for id in combined_state}

    # constant_state refers to entries in sd[state][param_id] that are not tensors, like "step".
    # We check that these are identical across workers and then take the first.
    constant_state = [_extract_constant_state(combined_state, id) for id in combined_state]

    # Loop over parameters in state.
    # Tensor state will be unpadded, concatenated, and then restored to its original
    # shape with FlattenParamsWrapper.get_param_views, which yields multiple tensors,
    # each of which becomes a new parameter with a new "global" id.
    for local_id in combined_state:
        local_to_global_param_id[local_id] = []
        # undo the work of shard_parameters
        for k, v in combined_state[local_id].items():
            if k in constant_state[local_id]:
                continue
            assert isinstance(v, list), f"got {k}: {v} for {local_id}"
            v_unpad = [t[:-np] if np > 0 else t for t, np in zip(v, pad_info[local_id])]
            flat_buffer = torch.cat(v_unpad)
            param_views: Generator = instance_list[local_id].get_param_views(flat_buffer)  # type: ignore
            for i, param_view in enumerate(param_views):
                if i == len(local_to_global_param_id[local_id]):  # make a new ID
                    local_to_global_param_id[local_id].append(next_global_id)
                    next_global_id += 1
                global_id = local_to_global_param_id[local_id][i]
                if global_id not in unflat_state:
                    unflat_state[global_id] = copy.deepcopy(constant_state[local_id])

                assert k not in unflat_state[global_id], f"already added {k} to new[{global_id}]"
                unflat_state[global_id][k] = param_view

    global_to_local_id = {
        new_id: old_pid for old_pid, global_ids in local_to_global_param_id.items() for new_id in global_ids
    }

    return unflat_state, global_to_local_id

Review comment (on build_unflat_state_dict): add a docstring?

def build_unflat_state_dict(instance_list: List[torch.nn.Module], world_optim_states: List[Dict]) -> Dict:
    world_pad_info: List[List[List[int]]] = [s.pop("num_padded") for s in world_optim_states]
    assert all(len(s) == len(instance_list) for s in world_pad_info)
    assert all(len(s[0]) == 1 for s in world_pad_info)
    param_groups = copy.deepcopy(world_optim_states[0]["param_groups"])
    assert len(param_groups) == 1
    # combined_state refers to tensor values in sd[state][param_id].
    # Here we just aggregate them into a dictionary of lists (from a list of dictionaries).
    combined_state = _combine_tensor_optim_state([x["state"] for x in world_optim_states])
    # cleanup all_optimizer_states_list
    del world_optim_states
    new_state_dict = {"state": {}, "param_groups": param_groups}
    # Local ids are in the current state; global ids will be in the returned state.
    unflat_state, global_to_local_id = _unflatten_optim_state(combined_state, instance_list, world_pad_info)
    num_params = sum([len(m._param_numels) for m in instance_list])  # type: ignore
    new_state_dict["param_groups"][0]["params"] = list(range(num_params))
    new_state_dict["param_id_map"] = global_to_local_id
    # Make sure that the parameters are sorted in the state, as expected for a pytorch dict
    new_state_dict["state"] = dict(sorted(unflat_state.items()))
    return new_state_dict

def check_param_counts_before_sharding(full_optim_state_dict: Dict, n_instances: int) -> None:
    n_local_params_in_opt = len(set(full_optim_state_dict["param_id_map"].values()))
    msg = (
        f"Including itself, this model has {n_instances} nested instances. When the optimizer state was saved "
        f"there were {n_local_params_in_opt}"
    )
    stateless = len(full_optim_state_dict["state"]) == 0
    assert stateless or (n_instances == n_local_params_in_opt), msg
Changes to the FullyShardedDataParallel module:

@@ -19,9 +19,10 @@
from torch.nn import Parameter
import torch.nn.functional as F

import fairscale.nn.data_parallel.fsdp_optim_utils as ou
from fairscale.nn.misc import FlattenParamsWrapper
from fairscale.nn.wrap import auto_wrap, default_auto_wrap_policy, enable_wrap
from fairscale.optim.utils import calc_grad_norm
from fairscale.optim.utils import broadcast_object, calc_grad_norm, recursive_copy_to_device
from fairscale.utils.containers import apply_to_tensors
from fairscale.utils.parallel import chunk_and_pad, enable_pytorch_sync_bn, validate_process_group
from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer

Review comment (on the new `ou` import): a relative import like `import .fsdp_optim_utils as ou` is more portable?
Reply: got it. Perhaps `from . import fsdp_optim_utils as ou`?
Reply: That works!

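For reference, the relative form proposed in the conversation above would look like this inside the same package; note that `import .module` is not valid Python syntax, which is why the `from . import` spelling is suggested:

# Inside another module of the fairscale.nn.data_parallel package:
from . import fsdp_optim_utils as ou  # relative equivalent of `import fairscale.nn.data_parallel.fsdp_optim_utils as ou`
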
@@ -88,8 +89,8 @@ class FullyShardedDataParallel(nn.Module):
import torch
from fairscale.nn.auto_wrap import enable_wrap, auto_wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
fsdp_params = dict(mixed_precision=True, flatten_parameters=True)
with enable_wrap(wrapper_cls=FSDP, **fsdp_params):
fsdp_params = dict(wrapper_cls=FSDP, mixed_precision=True, flatten_parameters=True)
with enable_wrap(**fsdp_params):
    # Wraps layer in FSDP by default if within context
    self.l1 = wrap(torch.nn.Linear(5, 5))
    assert isinstance(self.l1, FSDP)

Review comment: Thanks for fixing the doc here!

@@ -185,6 +186,8 @@ def __init__(
    self.buffer_dtype = buffer_dtype or self.compute_dtype
    self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu
    self.bucket_cap_mb = bucket_cap_mb

    self.num_padded: List[int] = []
    self.compute_device = compute_device

    if self.fp32_reduce_scatter and not self.mixed_precision:
@@ -412,6 +415,7 @@ def _shard_parameters_(self) -> None:
    allocate less memory for optimizer state, avoiding redundancy across
    data parallel workers.
    """
    self.num_padded = []
    for p in self.params:
        assert not hasattr(p, "_is_sharded")
        assert p.is_floating_point()
@@ -423,16 +427,19 @@ def _shard_parameters_(self) -> None:
        p._orig_size = p.data.size()

        if not p._is_sharded:
            self.num_padded.append(0)
            continue
        p._is_sharded = True

        # Replace p.data with the relevant shard.
        orig_data = p.data
        p.data = self._get_shard(p.data)
        p.data, num_padded = self._get_shard(p.data)
        self.num_padded.append(num_padded)
        free_storage_(orig_data)
    assert len(self.num_padded) == len(self.params)

def _get_shard(self, tensor: torch.Tensor) -> torch.Tensor:
    """Return the local shard of a given full tensor."""
def _get_shard(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, int]:
    """Return the local shard of a full tensor."""
    # Shard using torch.chunk to match all-gather/reduce-scatter.
    chunks = list(torch.flatten(tensor).chunk(self.world_size))
    while len(chunks) < self.world_size:
@@ -445,7 +452,7 @@ def _get_shard(self, tensor: torch.Tensor) -> torch.Tensor:
    shard = chunks[self.rank].clone()
    if num_to_pad > 0:
        shard = F.pad(shard, [0, num_to_pad])
    return shard
    return shard, num_to_pad

def extra_repr(self) -> str:
    return (

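To make the new `num_padded` bookkeeping concrete, here is a standalone sketch of the chunk-and-pad behavior. The middle of the function is not shown in these hunks, so the empty-chunk filler and the `num_to_pad` computation below are assumptions about the elided lines:

import torch
import torch.nn.functional as F

def get_shard(tensor: torch.Tensor, rank: int, world_size: int):
    # Shard using torch.chunk to match all-gather/reduce-scatter.
    chunks = list(torch.flatten(tensor).chunk(world_size))
    while len(chunks) < world_size:
        chunks.append(chunks[0].new_empty(0))               # assumed filler for ranks past the end
    num_to_pad = chunks[0].numel() - chunks[rank].numel()   # assumed padding computation
    shard = chunks[rank].clone()
    if num_to_pad > 0:
        shard = F.pad(shard, [0, num_to_pad])
    return shard, num_to_pad

full = torch.arange(10.0)                        # 10 elements over world_size=4 -> chunks of 3, 3, 3, 1
shard, num_padded = get_shard(full, rank=3, world_size=4)
assert shard.numel() == 3 and num_padded == 2    # last rank's shard is padded with 2 zeros
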
@@ -684,7 +691,7 @@ def summon_full_params(self, recurse: bool = True, volatile: bool = False) -> Ge
                if not volatile:
                    # Copy any changes made to the full params back into
                    # the corresponding local shards.
                    local_shard = self._get_shard(full_tensor)
                    local_shard, _ = self._get_shard(full_tensor)
                    p._fp32_shard.copy_(local_shard.view_as(p._fp32_shard))
                if safe_to_free:
                    free_storage_(full_tensor)

@@ -1346,6 +1353,83 @@ def assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None
        traceback.print_stack()
        raise ValueError(msg)

def _consolidate_optim_state_dict(
    self, optim: torch.optim.Optimizer, recipient_rank: Optional[int] = None
) -> List[Dict]:
    """Update the consolidated state_dict list, one per rank.

    Args:
        recipient_rank (int): on which rank to materialize the full state dict.
            None is a special value, which means that all ranks should have the state.

    .. warning: This needs to be called on all replicas"""
    self._lazy_init()
    # NOTE(SS): we do not support param groups yet, as they seem to break FSDP
    # Pull the sharded state from all the other replicas
    # Store all the states in order, rank by rank
    should_collect_state = recipient_rank is None or (self.rank == recipient_rank)
    all_states: List[Dict[str, Any]] = []
    dummy_tensor = torch.tensor([0], dtype=torch.uint8, device=self.compute_device)
    for rank in range(self.world_size):
        if rank == self.rank:
            sd = optim.state_dict()
            sd["num_padded"] = [m.num_padded for m in self._fsdp_instances]
        else:
            sd = dummy_tensor  # type: ignore
        sd = broadcast_object(sd, src_rank=rank, group=self.process_group, dist_device=self.compute_device)  # type: ignore
        if should_collect_state:
            assert isinstance(sd, dict), f"{self.rank} received {type(sd)} from {rank}, expected dict"
            all_states.append(recursive_copy_to_device(sd, non_blocking=False, device=torch.device("cpu")))

    return all_states

Review comment (on the docstring): should this be called only on the root FSDP instance?
Reply: Yes, more specifically it should be called on the instance that was the argument to ...

Review comment (on the per-rank loop): there might be complications here when nested FSDP instances have different world_size, right? For example, if BN layers are in their own world_size == 1 process groups, then we collect duplicated states for them? Add a TODO?
Reply: added TODO in the caller

def gather_full_optim_state_dict(
    self, optim: torch.optim.Optimizer, recipient_rank: Optional[int] = 0
) -> Optional[Dict[str, Any]]:
    """Return the last known global optimizer state. The returned state is compatible with PyTorch, in that the
    sharded properties are not exposed. Multiple parameter groups are not yet supported.

    Returns:
        a dict with two entries
            * state - a dict holding current optimization state. Its content
                differs between optimizer classes.
            * param_groups - a dict containing all parameter groups
    """
    if not self.flatten_parameters:
        raise NotImplementedError("optim state dict requires flatten_parameters=True")
    world_optim_states = self._consolidate_optim_state_dict(optim, recipient_rank)
    if self.rank != recipient_rank and recipient_rank is not None:
        return None
    # Unify the shard states by concatenating tensors and unflattening params
    new_state_dict = ou.build_unflat_state_dict(self._fsdp_instances, world_optim_states)
    return new_state_dict

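A sketch of how the save side could look, assuming `fsdp_model` is the root FullyShardedDataParallel instance whose parameters were given to `optim`, and that this runs on every rank; the names and checkpoint path are made up:

import torch

# Must be called on all ranks; only recipient_rank (0 by default) gets the dict back.
full_osd = fsdp_model.gather_full_optim_state_dict(optim, recipient_rank=0)
if full_osd is not None:                    # i.e. we are rank 0
    torch.save(full_osd, "optim_state.pt")  # hypothetical checkpoint path
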
@property
def _fsdp_instances(self) -> List[nn.Module]:
    """Returns all fsdp modules in self.modules() including self."""
    return [m for m in self.modules() if isinstance(m, FullyShardedDataParallel)]

||
def get_shard_from_optim_state_dict(self, full_optim_state_dict: Dict[str, Any]) -> Dict[str, Any]: | ||
"""Get the portion of the optimizer state dict associated with the shard""" | ||
sshleifer marked this conversation as resolved.
Show resolved
Hide resolved
|
||
# Assert nesting is the same as it was at save time | ||
instance_list = self._fsdp_instances | ||
ou.check_param_counts_before_sharding(full_optim_state_dict, len(instance_list)) | ||
if self.flatten_parameters: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. does this assume all inner FSDP instances also have flatten == True? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, will assert |
||
full_optim_state_dict = ou.flatten_optim_state_dict(full_optim_state_dict) | ||
assert len(full_optim_state_dict["state"]) in (0, len(instance_list)) | ||
|
||
# get the portion of dict associated with the shard, in place | ||
for id, s in full_optim_state_dict["state"].items(): | ||
for k, v in s.items(): | ||
if torch.is_tensor(v): | ||
v_shard, _ = self._get_shard(v) | ||
else: | ||
v_shard = v # dont shard entries that are not tensors | ||
full_optim_state_dict["state"][id][k] = v_shard | ||
|
||
return full_optim_state_dict | ||
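And the corresponding load side, again a sketch with assumed names and path; each rank extracts its own shard from the consolidated dict before handing it to the optimizer:

import torch

full_osd = torch.load("optim_state.pt", map_location="cpu")  # hypothetical path from the save sketch
shard_osd = fsdp_model.get_shard_from_optim_state_dict(full_osd)
optim.load_state_dict(shard_osd)
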

@torch.no_grad()
def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]:

Changes to FlattenParamsWrapper:

@@ -122,15 +122,15 @@ def _flatten_params(self, flat_param: Optional[nn.Parameter] = None) -> None:
    # register the views as plain attributes
    self._unflatten_params_as_views()

def _get_param_views(self, flat_param: Tensor) -> Generator:
def get_param_views(self, flat_param: Tensor) -> Generator:
    return (t.view(s) for (t, s) in zip(flat_param.split(self._param_numels), self._param_shapes))

def _unflatten_params(self, flat_param: Optional[Tensor] = None) -> None:
    assert self.is_flattened or flat_param is not None
    self.is_flattened = False
    flat_param = flat_param if flat_param is not None else self.flat_param

    ps = self._get_param_views(flat_param)
    ps = self.get_param_views(flat_param)
    for (m, n), p in zip(self._param_infos, ps):
        if hasattr(m, n):
            delattr(m, n)

Review comment (on renaming _get_param_views to get_param_views): since this is becoming a public method, can you please: ...
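A small sketch of what the now-public get_param_views yields, assuming a hypothetical FlattenParamsWrapper instance `wrapper` that flattened two parameters of shapes (2, 3) and (3,):

import torch

flat = torch.arange(9.0)                     # 6 + 3 elements in one flat buffer
views = list(wrapper.get_param_views(flat))  # `wrapper` is assumed to exist as described above
assert [tuple(v.shape) for v in views] == [(2, 3), (3,)]
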
@@ -144,7 +144,7 @@ def _unflatten_params(self, flat_param: Optional[Tensor] = None) -> None:

def _unflatten_params_as_views(self) -> None:
    assert self.is_flattened
    ps = self._get_param_views(self.flat_param)
    ps = self.get_param_views(self.flat_param)
    for (m, n), p in zip(self._param_infos, ps):
        setattr(m, n, p)  # This will set as plain attr
    for (m, n, shared_m, shared_n) in self._shared_param_infos:

Review comment (on the PR): ❤️ this