Skip to content

Commit

Permalink
[FSDP][feature] optimizer state dict save and load (#537)
Browse files Browse the repository at this point in the history
Co-authored-by: Min Xu <24926999+min-xu-ai@users.noreply.github.com>
  • Loading branch information
sshleifer and min-xu-ai authored Mar 25, 2021
1 parent df493a2 commit 9474d75
Show file tree
Hide file tree
Showing 5 changed files with 401 additions and 11 deletions.
157 changes: 157 additions & 0 deletions fairscale/nn/data_parallel/fsdp_optim_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
"""These functions are used by FullyShardedDataParallel to help consolidate and shard optimizer states."""
import copy
from typing import Dict, Generator, List, Tuple

import torch


# This function helps shard a full optimizer state dict
def flatten_optim_state_dict(sd: Dict) -> Dict:
    """Collapse a full optimizer state dict to one entry per flattened parameter.

    Called by ``FSDP.get_shard_from_optim_state_dict``. Despite living next to
    the sharding code, this function does not slice anything per rank: it only
    merges the buffers of all expanded parameters that map to the same
    consolidated parameter.

    Args:
        sd: a dict with the keys ``"state"`` (expanded param id -> buffers),
            ``"param_id_map"`` (expanded param id -> consolidated param id) and
            ``"param_groups"`` (a list of dicts).

    Returns:
        A new dict with ``"state"`` (consolidated param id -> buffers, where
        every tensor buffer is the 1-D concatenation of the matching expanded
        buffers) and ``"param_groups"`` (copies of the input groups whose
        ``"params"`` entry lists the consolidated ids). ``sd`` is not modified.
    """
    param_id_map = sd["param_id_map"]
    num_local_params = len(set(param_id_map.values()))
    # A stateless optimizer keeps an empty state; otherwise pre-create one
    # (possibly empty) entry per consolidated parameter.
    new_state: Dict = {local_id: {} for local_id in range(num_local_params)} if sd["state"] else {}
    # Non-tensor entries (anything torch.is_tensor rejects) are collected into
    # a single dict shared by every consolidated parameter; the last value
    # seen for a given name wins.
    non_tensor_state: Dict = {}

    # Group the flattened tensor buffers by consolidated id, in the iteration
    # order of sd["state"] (assumed to be sorted by expanded id).
    for expanded_pid, buffers in sd["state"].items():
        consolidated_pid = param_id_map[expanded_pid]
        for buffer_name, p in buffers.items():
            if torch.is_tensor(p):
                new_state[consolidated_pid].setdefault(buffer_name, []).append(p.reshape(-1))
            else:
                non_tensor_state[buffer_name] = p

    # Concatenate each buffer's pieces, then attach the non-tensor entries.
    for state in new_state.values():
        for buffer_name, tensors in state.items():
            state[buffer_name] = torch.cat(tensors)
        state.update(non_tensor_state)

    # Build fresh group dicts instead of writing into sd["param_groups"], so
    # the caller's full state dict keeps its original "params" lists.
    # TODO: this list could be huge. Can we avoid materializing?
    param_groups = [{**group, "params": list(range(num_local_params))} for group in sd["param_groups"]]

    return {"state": new_state, "param_groups": param_groups}


def check_param_counts_before_sharding(full_optim_state_dict: Dict, n_instances: int) -> None:
    """Assert that the saved optimizer state matches the current number of nested instances.

    The number of distinct values in ``full_optim_state_dict["param_id_map"]``
    must equal ``n_instances`` unless the saved ``"state"`` is empty, in which
    case nothing is checked.

    Raises:
        AssertionError: if the state is non-empty and the counts differ.
    """
    if len(full_optim_state_dict["state"]) == 0:
        # Stateless optimizer: there is nothing to shard, so any nesting is fine.
        return
    n_local_params_in_opt = len(set(full_optim_state_dict["param_id_map"].values()))
    assert n_instances == n_local_params_in_opt, (
        f"Including itself, this model has {n_instances} nested instances. When the optimizer state was saved "
        f"there were {n_local_params_in_opt}"
    )


# All functions below help consolidate the list of optimizer states (one from each rank) when saving.
# build_unflat_state_dict is the entry point used by FSDP.
def _extract_non_tensor_state(combined_state: Dict[int, Dict[str, List]], param_id: int) -> Dict:
non_tensor_state = {} # This state is like the `step` count in Adam, not a tensor so we dont unpad or cat it.
for k, v in combined_state[param_id].items():
if torch.is_tensor(v[0]):
continue
elif len(set(v)) == 1:
non_tensor_state[k] = v[0]
else:
raise TypeError(f"Dont know how to consolidate optimizer param {k} with values {v}")
return non_tensor_state


def _combine_state(states: List[Dict]) -> Dict[int, Dict]:
combined_state = states[0]
for param_id in combined_state:
combined_state[param_id] = {k: [v] for k, v in combined_state[param_id].items()}
if len(states) == 1:
return combined_state

for rank, s in enumerate(states[1:]):
for param_id, param_state in s.items():
for k, tensor in param_state.items():
combined_state[param_id][k].append(tensor)
return combined_state


def _unflatten_optim_state(
combined_state: Dict[int, Dict], instance_list: List[torch.nn.Module], world_pad_info: List[List[List[int]]],
) -> Tuple[Dict[int, Dict], Dict[int, int]]:
# local ids are the keys in the current state (combined_state), (usually fewer)
# global ids will be the keys in the unflattened state
next_global_id = 0 # gets incremented
pad_info = {id: [s[id][0] for s in world_pad_info] for id in combined_state}
local_ids = [id for id in sorted(combined_state.keys())]

# non_tensor_state refers to entries in sd[state][param_id] that are not tensors, like "step".
# we check that these are identical across workers and then take the first
non_tensor_state = [_extract_non_tensor_state(combined_state, id) for id in combined_state]

# local corresponds to flattened, global corresponds to unflattened
num_unflat_params = [len(m._param_numels) for m in instance_list] # type: ignore
global_to_local_id = {}
for local_id, num_unflat in enumerate(num_unflat_params):
for _ in range(num_unflat):
global_to_local_id[next_global_id] = local_id
next_global_id += 1
if not combined_state:
return {}, global_to_local_id

# If the constant state is the same as the combined state, copy it N times, no unflattening needed.
unflat_state = {i: copy.deepcopy(non_tensor_state[0]) for i in range(sum(num_unflat_params))}
if non_tensor_state[0].keys() == combined_state[0].keys():
return unflat_state, global_to_local_id

local_to_global: Dict[int, List] = {i: [] for i in local_ids}
for g, l in global_to_local_id.items():
local_to_global[l].append(g)
# loop over parameters in state.
# Tensor state will be padded, concatenated, and restored to original shape with FlattenParamsWrapper.get_views
# get_views returns multiple tensors, each of which is a new parameter with a new "global" id.
for local_id in local_ids:
# undo the work of shard_parameters
for k, v in combined_state[local_id].items():
if k in non_tensor_state[local_id]:
continue
assert isinstance(v, list), f"got {k}: {v} for {local_id}"
v_unpad = [t[:-np] if np > 0 else t for t, np in zip(v, pad_info[local_id])]
flat_buffer = torch.cat(v_unpad)
param_views: Generator = instance_list[local_id].get_param_views(flat_buffer) # type: ignore
for global_id, param_view in zip(sorted(local_to_global[local_id]), param_views):
assert k not in unflat_state[global_id], f"already added {k} to {global_id} {local_id}"
unflat_state[global_id][k] = param_view

return unflat_state, global_to_local_id


def build_unflat_state_dict(instance_list: List[torch.nn.Module], world_optim_states: List[Dict]) -> Dict:
    """Build an unflattened optimizer state dict from the flattened state dicts of all ranks.

    Args:
        instance_list: the modules whose ``_param_numels`` determine how many
            unflattened parameters there are.
        world_optim_states: one optimizer state dict per rank, each carrying an
            extra ``"num_padded"`` entry. That entry is popped from every dict.

    Returns:
        A dict with ``"state"`` (sorted by parameter id), ``"param_id_map"``
        (unflattened id -> flattened id) and ``"param_groups"`` (a deep copy of
        the first rank's single group, with ``"params"`` listing every
        unflattened id).
    """
    world_pad_info: List[List[List[int]]] = []
    for rank_state in world_optim_states:
        world_pad_info.append(rank_state.pop("num_padded"))
    assert all(len(rank_pads) == len(instance_list) for rank_pads in world_pad_info)
    assert all(len(rank_pads[0]) == 1 for rank_pads in world_pad_info)

    param_groups = copy.deepcopy(world_optim_states[0]["param_groups"])
    assert len(param_groups) == 1

    # Go from a list of dictionaries (one per rank) to a dictionary of lists.
    per_rank_states = [rank_state["state"] for rank_state in world_optim_states]
    combined_state = _combine_state(per_rank_states)
    del world_optim_states

    # Keys of combined_state are local (flattened) ids; the returned state is
    # keyed by global (unflattened) ids.
    unflat_state, global_to_local_id = _unflatten_optim_state(combined_state, instance_list, world_pad_info)

    num_params = 0
    for m in instance_list:
        num_params += len(m._param_numels)  # type: ignore
    param_groups[0]["params"] = list(range(num_params))  # This could be a large list. #TODO: is it essential

    result = {
        "state": dict(sorted(unflat_state.items())),  # NOTE: this is probably already sorted
        "param_id_map": global_to_local_id,
        "param_groups": param_groups,
    }
    return result
129 changes: 121 additions & 8 deletions fairscale/nn/data_parallel/fully_sharded_data_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,14 @@

from fairscale.nn.misc import FlattenParamsWrapper
from fairscale.nn.wrap import auto_wrap, default_auto_wrap_policy, enable_wrap
from fairscale.optim.utils import calc_grad_norm
from fairscale.optim.utils import broadcast_object, calc_grad_norm, recursive_copy_to_device
from fairscale.utils.containers import apply_to_tensors
from fairscale.utils.parallel import chunk_and_pad, enable_pytorch_sync_bn, validate_process_group
from fairscale.utils.reduce_scatter_bucketer import ReduceScatterBucketer
from fairscale.utils.state_dict import replace_by_prefix_

from . import fsdp_optim_utils as ou

if TYPE_CHECKING:
from collections import OrderedDict # noqa: F401

Expand Down Expand Up @@ -88,8 +90,8 @@ class FullyShardedDataParallel(nn.Module):
import torch
from fairscale.nn.auto_wrap import enable_wrap, auto_wrap
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP
fsdp_params = dict(mixed_precision=True, flatten_parameters=True)
with enable_wrap(wrapper_cls=FSDP, **fsdp_params):
fsdp_params = dict(wrapper_cls=FSDP, mixed_precision=True, flatten_parameters=True)
with enable_wrap(**fsdp_params):
# Wraps layer in FSDP by default if within context
self.l1 = wrap(torch.nn.Linear(5, 5))
assert isinstance(self.l1, FSDP)
Expand Down Expand Up @@ -185,6 +187,8 @@ def __init__(
self.buffer_dtype = buffer_dtype or self.compute_dtype
self.move_grads_to_cpu = cpu_offload if move_grads_to_cpu is None else move_grads_to_cpu
self.bucket_cap_mb = bucket_cap_mb

self.numel_padded_per_param: List[int] = []
self.compute_device = compute_device

if self.fp32_reduce_scatter and not self.mixed_precision:
Expand Down Expand Up @@ -412,6 +416,7 @@ def _shard_parameters_(self) -> None:
allocate less memory for optimizer state, avoiding redundancy across
data parallel workers.
"""
self.numel_padded_per_param = []
for p in self.params:
assert not hasattr(p, "_is_sharded")
assert p.is_floating_point()
Expand All @@ -423,16 +428,19 @@ def _shard_parameters_(self) -> None:
p._orig_size = p.data.size()

if not p._is_sharded:
self.numel_padded_per_param.append(0)
continue
p._is_sharded = True

# Replace p.data with the relevant shard.
orig_data = p.data
p.data = self._get_shard(p.data)
p.data, num_padded = self._get_shard(p.data)
self.numel_padded_per_param.append(num_padded)
free_storage_(orig_data)
assert len(self.numel_padded_per_param) == len(self.params)

def _get_shard(self, tensor: torch.Tensor) -> torch.Tensor:
"""Return the local shard of a given full tensor."""
def _get_shard(self, tensor: torch.Tensor) -> Tuple[torch.Tensor, int]:
"""Return the local shard of a full tensor."""
# Shard using torch.chunk to match all-gather/reduce-scatter.
chunks = list(torch.flatten(tensor).chunk(self.world_size))
while len(chunks) < self.world_size:
Expand All @@ -445,7 +453,7 @@ def _get_shard(self, tensor: torch.Tensor) -> torch.Tensor:
shard = chunks[self.rank].clone()
if num_to_pad > 0:
shard = F.pad(shard, [0, num_to_pad])
return shard
return shard, num_to_pad

def extra_repr(self) -> str:
return (
Expand Down Expand Up @@ -684,7 +692,7 @@ def summon_full_params(self, recurse: bool = True, volatile: bool = False) -> Ge
if not volatile:
# Copy any changes made to the full params back into
# the corresponding local shards.
local_shard = self._get_shard(full_tensor)
local_shard, _ = self._get_shard(full_tensor)
p._fp32_shard.copy_(local_shard.view_as(p._fp32_shard))
if safe_to_free:
free_storage_(full_tensor)
Expand Down Expand Up @@ -1346,6 +1354,111 @@ def assert_state(self, state: Union[TrainingState, List[TrainingState]]) -> None
traceback.print_stack()
raise ValueError(msg)

def _consolidate_optim_state_dict(
    self, optim: torch.optim.Optimizer, recipient_rank: Optional[int] = None
) -> List[Dict]:
    """Collect the optimizer state dict of every rank, in rank order.

    Each rank in turn broadcasts ``optim.state_dict()`` extended with a
    ``"num_padded"`` entry (the ``numel_padded_per_param`` of every FSDP
    instance). Receiving ranks that should hold the result store a CPU copy.

    Args:
        optim (Optimizer): the optimizer instance for this FSDP rank.
        recipient_rank (int): the rank on which to materialize the states.
            ``None`` means every rank collects them.

    Returns:
        all_states (list[dict]): one state dict per rank on collecting ranks,
        an empty list elsewhere.

    .. warning: This needs to be called on all replicas"""
    self._lazy_init()
    # NOTE(SS): we do not support param groups yet, as they seem to break FSDP
    collect_here = recipient_rank is None or (self.rank == recipient_rank)
    cpu = torch.device("cpu")
    # Stand-in payload passed to broadcast_object on the non-source ranks.
    placeholder = torch.tensor([0], dtype=torch.uint8, device=self.compute_device)
    gathered: List[Dict[str, Any]] = []

    for src in range(self.world_size):
        if src == self.rank:
            payload = optim.state_dict()
            payload["num_padded"] = [m.numel_padded_per_param for m in self._fsdp_instances]
        else:
            payload = placeholder  # type: ignore
        received = broadcast_object(  # type: ignore
            payload, src_rank=src, group=self.process_group, dist_device=self.compute_device
        )
        if not collect_here:
            continue
        assert isinstance(received, dict), f"{self.rank} received {type(received)} from {src}, expected dict"
        gathered.append(recursive_copy_to_device(received, non_blocking=False, device=cpu))

    return gathered

def gather_full_optim_state_dict(
    self, optim: torch.optim.Optimizer, recipient_rank: Optional[int] = 0
) -> Optional[Dict[str, Any]]:
    """Return the consolidated, unflattened optimizer state.

    The sharded layout is not exposed in the result. Multiple parameter groups
    are not yet supported. This should be called only on the root FSDP
    instance, and nested FSDP instances with different world sizes are not
    supported.

    Args:
        optim (Optimizer): the optimizer instance for this FSDP rank.
        recipient_rank (int): the rank on which to materialize the full state
            dict. ``None`` means every rank builds it.

    Returns:
        ``None`` on ranks other than ``recipient_rank``. Otherwise the dict
        built by ``build_unflat_state_dict``, with the entries

        * state - gathered optimization state, 1 entry per unflat parameter
        * param_id_map - unflat parameter id -> flattened parameter id
        * param_groups - a list containing the 1 parameter group

    Raises:
        NotImplementedError: if ``flatten_parameters`` is False.
    """
    if not self.flatten_parameters:
        raise NotImplementedError("optim state dict requires flatten_parameters=True")

    # Every rank must take part in the exchange, including those that discard the result.
    per_rank_states = self._consolidate_optim_state_dict(optim, recipient_rank)
    is_recipient = recipient_rank is None or self.rank == recipient_rank
    if not is_recipient:
        return None

    # Unify the shard states by concatenating tensors and unflattening params.
    # TODO: check if this code supports nested instances with different world size
    return ou.build_unflat_state_dict(self._fsdp_instances, per_rank_states)

@property
def _fsdp_instances(self) -> List[nn.Module]:
    """List every FullyShardedDataParallel module yielded by ``self.modules()``, in that order."""
    found: List[nn.Module] = []
    for module in self.modules():
        if isinstance(module, FullyShardedDataParallel):
            found.append(module)
    return found

def get_shard_from_optim_state_dict(self, full_optim_state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Get the portion of a full optimizer state dict that belongs to this rank.

    The result can be loaded into the sharded optimizer for this FSDP rank.

    Args:
        full_optim_state_dict (dict): consolidated optimizer state returned by
            ``gather_full_optim_state_dict``, or loaded from a checkpoint.

    Returns:
        (dict): the state dict with every tensor entry replaced by this rank's
        shard. Non-tensor entries are kept as they are. The entries are
        replaced in place on the dict being sharded.
    """
    # The nesting must be the same as it was at save time.
    instance_list = self._fsdp_instances
    assert all(
        x.world_size == self.world_size for x in instance_list
    ), "all nested instances must have same world size"
    ou.check_param_counts_before_sharding(full_optim_state_dict, len(instance_list))
    if self.flatten_parameters:
        full_optim_state_dict = ou.flatten_optim_state_dict(full_optim_state_dict)
        assert len(full_optim_state_dict["state"]) in (0, len(instance_list))

    # Swap each tensor for its local shard; only values of existing keys are
    # replaced, so iterating while assigning is safe.
    for param_state in full_optim_state_dict["state"].values():
        for name, value in param_state.items():
            if torch.is_tensor(value):
                shard, _ = self._get_shard(value)
                param_state[name] = shard

    return full_optim_state_dict


@torch.no_grad()
def cast_inputs_to_fp16(*args: Any, **kwargs: Any) -> Tuple[Any, Any]:
Expand Down
6 changes: 3 additions & 3 deletions fairscale/nn/misc/flatten_params_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,15 +122,15 @@ def _flatten_params(self, flat_param: Optional[nn.Parameter] = None) -> None:
# register the views as plain attributes
self._unflatten_params_as_views()

def _get_param_views(self, flat_param: Tensor) -> Generator:
def get_param_views(self, flat_param: Tensor) -> Generator:
return (t.view(s) for (t, s) in zip(flat_param.split(self._param_numels), self._param_shapes))

def _unflatten_params(self, flat_param: Optional[Tensor] = None) -> None:
assert self.is_flattened or flat_param is not None
self.is_flattened = False
flat_param = flat_param if flat_param is not None else self.flat_param

ps = self._get_param_views(flat_param)
ps = self.get_param_views(flat_param)
for (m, n), p in zip(self._param_infos, ps):
if hasattr(m, n):
delattr(m, n)
Expand All @@ -144,7 +144,7 @@ def _unflatten_params(self, flat_param: Optional[Tensor] = None) -> None:

def _unflatten_params_as_views(self) -> None:
assert self.is_flattened
ps = self._get_param_views(self.flat_param)
ps = self.get_param_views(self.flat_param)
for (m, n), p in zip(self._param_infos, ps):
setattr(m, n, p) # This will set as plain attr
for (m, n, shared_m, shared_n) in self._shared_param_infos:
Expand Down
1 change: 1 addition & 0 deletions tests/ci_test_list_3.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ tests/nn/data_parallel/test_fsdp_summon_full_params.py
tests/nn/data_parallel/test_fsdp_input.py
tests/nn/data_parallel/test_fsdp_multiple_forward.py
tests/nn/data_parallel/test_fsdp_regnet.py
tests/nn/data_parallel/test_fsdp_optimizer_utils.py
tests/nn/data_parallel/test_sharded_ddp_features.py
tests/nn/data_parallel/test_sharded_ddp_pytorch_parity.py
tests/nn/pipe/skip/test_gpipe.py
Expand Down
Loading

0 comments on commit 9474d75

Please sign in to comment.