
add util for ram efficient loading of model when using fsdp #25107

Merged
merged 16 commits on Aug 17, 2023
67 changes: 48 additions & 19 deletions src/transformers/modeling_utils.py
@@ -73,6 +73,7 @@
is_torch_tpu_available,
logging,
replace_return_docstrings,
strtobool,
)
from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files
from .utils.import_utils import ENV_VARS_TRUE_VALUES, is_sagemaker_mp_enabled, is_torch_fx_proxy
@@ -106,6 +107,14 @@
_init_weights = True


def is_fsdp_enabled():
return strtobool(os.environ.get("ACCELERATE_USE_FSDP", "False")) == 1


def is_fsdp_enabled_and_dist_rank_0():
return is_fsdp_enabled() and torch.distributed.is_initialized() and torch.distributed.get_rank() == 0


if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
from smdistributed.modelparallel import __version__ as SMP_VERSION
@@ -458,7 +467,11 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
)
return safe_load_file(checkpoint_file)
try:
if is_deepspeed_zero3_enabled() and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0:
if (
(is_deepspeed_zero3_enabled() or is_fsdp_enabled())
and torch.distributed.is_initialized()
and torch.distributed.get_rank() > 0
):
map_location = "meta"
else:
map_location = "cpu"
@@ -2225,6 +2238,9 @@ def from_pretrained(
commit_hash = kwargs.pop("_commit_hash", None)
variant = kwargs.pop("variant", None)

if is_fsdp_enabled():
low_cpu_mem_usage = True

if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
@@ -3180,7 +3196,8 @@ def _fix_key(key):
model_buffers = {".".join([prefix, key]) for key in model_buffers}
unexpected_keys = list(unexpected_keys - model_buffers)

if device_map is None:
model.tie_weights()
if device_map is None and not is_fsdp_enabled():
ptrs = collections.defaultdict(list)
for name, tensor in model.state_dict().items():
id_tensor = id_tensor_storage(tensor)
@@ -3385,23 +3402,35 @@ def _find_mismatched_keys(
)

if low_cpu_mem_usage:
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
model_to_load,
state_dict,
loaded_keys,
start_prefix,
expected_keys,
device_map=device_map,
offload_folder=offload_folder,
offload_index=offload_index,
state_dict_folder=state_dict_folder,
state_dict_index=state_dict_index,
dtype=dtype,
is_quantized=is_quantized,
is_safetensors=is_safetensors,
keep_in_fp32_modules=keep_in_fp32_modules,
)
error_msgs += new_error_msgs
if not is_fsdp_enabled() or is_fsdp_enabled_and_dist_rank_0():
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
model_to_load,
state_dict,
loaded_keys,
start_prefix,
expected_keys,
device_map=device_map,
offload_folder=offload_folder,
offload_index=offload_index,
state_dict_folder=state_dict_folder,
state_dict_index=state_dict_index,
dtype=dtype,
is_quantized=is_quantized,
is_safetensors=is_safetensors,
keep_in_fp32_modules=keep_in_fp32_modules,
)
error_msgs += new_error_msgs
else:
for key, param in model_to_load.state_dict().items():
if param.device == torch.device("meta"):
if not (is_quantized):
set_module_tensor_to_device(
model, key, "cpu", torch.empty(*param.size(), dtype=dtype)
)
else:
set_module_quantized_tensor_to_device(
model, key, "cpu", torch.empty(*param.size(), dtype=dtype)
)
else:
error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)

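Taken together, the modeling_utils.py changes make FSDP loading RAM-efficient: `from_pretrained` forces `low_cpu_mem_usage=True` whenever the accelerate launcher signals FSDP, non-zero ranks read checkpoints onto the `meta` device, and only rank 0 materializes the real weights on CPU while the other ranks allocate empty tensors and rely on `sync_module_states` to receive rank 0's weights when the model is wrapped. The sketch below is not part of the diff; it only mirrors the gating logic, with `_strtobool` standing in for `transformers.utils.strtobool` and `plan_weight_loading` as a hypothetical helper.

```python
import os


def _strtobool(value: str) -> int:
    # Stand-in for transformers.utils.strtobool used by is_fsdp_enabled() above.
    return 1 if value.lower() in {"y", "yes", "t", "true", "on", "1"} else 0


def is_fsdp_enabled() -> bool:
    # Same check the PR adds: the accelerate launcher exports ACCELERATE_USE_FSDP
    # when an FSDP config is active.
    return _strtobool(os.environ.get("ACCELERATE_USE_FSDP", "False")) == 1


def plan_weight_loading(rank: int) -> str:
    # Hypothetical summary of from_pretrained() after this PR: under FSDP only rank 0
    # materializes the checkpoint on CPU; other ranks allocate empty tensors and get
    # rank 0's weights via sync_module_states when FSDP wraps the model.
    if not is_fsdp_enabled() or rank == 0:
        return "materialize full state dict on CPU"
    return "allocate empty CPU tensors only"


os.environ["ACCELERATE_USE_FSDP"] = "true"
for rank in range(2):
    print(rank, plan_weight_loading(rank))
# 0 materialize full state dict on CPU
# 1 allocate empty CPU tensors only
```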
20 changes: 10 additions & 10 deletions src/transformers/trainer.py
@@ -465,10 +465,6 @@ def __init__(
):
self.backward_prefetch = BackwardPrefetch.BACKWARD_POST

self.forward_prefetch = False
if self.args.fsdp_config.get("forward_prefetch", False):
self.forward_prefetch = True

self.limit_all_gathers = False
if self.args.fsdp_config.get("limit_all_gathers", False):
self.limit_all_gathers = True
@@ -1379,12 +1375,12 @@ def _wrap_model(self, model, training=True, dataloader=None):
auto_wrapper_callable = None
default_transformer_cls_names_to_wrap = getattr(model, "_no_split_modules", None)
fsdp_transformer_layer_cls_to_wrap = self.args.fsdp_config.get(
"fsdp_transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap
"transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap
)

if self.args.fsdp_config["fsdp_min_num_params"] > 0:
if self.args.fsdp_config["min_num_params"] > 0:
auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["min_num_params"]
)
elif fsdp_transformer_layer_cls_to_wrap is not None:
transformer_cls_to_wrap = set()
@@ -1517,7 +1513,12 @@ def train(
if resume_from_checkpoint is None:
raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})")

if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled() and not self.is_deepspeed_enabled:
if (
resume_from_checkpoint is not None
and not is_sagemaker_mp_enabled()
and not self.is_deepspeed_enabled
and not self.is_fsdp_enabled
):
self._load_from_checkpoint(resume_from_checkpoint)

# If model was re-initialized, put it on the right device and update self.model_wrapped
@@ -1651,7 +1652,7 @@ def _inner_training_loop(

model = self._wrap_model(self.model_wrapped)

if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None:
if (is_sagemaker_mp_enabled() or self.is_fsdp_enabled) and resume_from_checkpoint is not None:
self._load_from_checkpoint(resume_from_checkpoint, model)

# as the model is wrapped, don't use `accelerator.prepare`
@@ -3886,7 +3887,6 @@ def create_accelerator_and_postprocess(self):
fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get(
"limit_all_gathers", fsdp_plugin.limit_all_gathers
)
fsdp_plugin.use_orig_params = self.args.fsdp_config.get("use_orig_params", fsdp_plugin.use_orig_params)

if self.is_deepspeed_enabled:
if getattr(self.args, "hf_deepspeed_config", None) is None:
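On the trainer side, the wrapping lookups now use the unprefixed fsdp_config keys (`min_num_params`, `transformer_layer_cls_to_wrap`) and, when FSDP is enabled, checkpoint resumption is deferred until after the model is wrapped. Below is a hedged usage sketch with the new key names, assuming a transformers build that includes this PR plus torch and accelerate installed; model and dataset setup are omitted, and the class name `T5Block` is only an example.

```python
from transformers import TrainingArguments

# fsdp_config passed as a dict uses the new unprefixed key names directly;
# legacy "fsdp_"-prefixed spellings are rewritten to these names in __post_init__.
args = TrainingArguments(
    output_dir="out",
    fsdp="full_shard auto_wrap",
    fsdp_config={
        "transformer_layer_cls_to_wrap": ["T5Block"],  # was fsdp_transformer_layer_cls_to_wrap
        "backward_prefetch": "backward_pre",
        "forward_prefetch": "false",
        "sync_module_states": "true",  # broadcast rank 0's weights after wrapping
        "use_orig_params": "false",
    },
)
```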
80 changes: 47 additions & 33 deletions src/transformers/training_args.py
@@ -436,13 +436,13 @@ class TrainingArguments:
deepspeed json config file (e.g., `ds_config.json`) or an already loaded json file as `dict`.

A List of config and its options:
- fsdp_min_num_params (`int`, *optional*, defaults to `0`):
- min_num_params (`int`, *optional*, defaults to `0`):
FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `fsdp` field is
passed).
- fsdp_transformer_layer_cls_to_wrap (`List[str]`, *optional*):
- transformer_layer_cls_to_wrap (`List[str]`, *optional*):
List of transformer layer class names (case-sensitive) to wrap, e.g., `BertLayer`, `GPTJBlock`,
`T5Block` .... (useful only when `fsdp` flag is passed).
- fsdp_backward_prefetch (`str`, *optional*)
- backward_prefetch (`str`, *optional*)
FSDP's backward prefetch mode. Controls when to prefetch next set of parameters (useful only when
`fsdp` field is passed).

@@ -454,14 +462,22 @@
- `"backward_post"` : This prefetches the next set of parameters after the current set of
parameter’s
gradient computation.
- fsdp_forward_prefetch (`bool`, *optional*, defaults to `False`)
- forward_prefetch (`bool`, *optional*, defaults to `False`)
FSDP's forward prefetch mode (useful only when `fsdp` field is passed).
If `"True"`, then FSDP explicitly prefetches the next upcoming all-gather while executing in the
forward pass.
- limit_all_gathers (`bool`, *optional*, defaults to `False`)
FSDP's limit_all_gathers (useful only when `fsdp` field is passed).
If `"True"`, FSDP explicitly synchronizes the CPU thread to prevent too many in-flight
all-gathers.
- use_orig_params (`bool`, *optional*, defaults to `False`)
If `"True"`, allows non-uniform `requires_grad` during init, which means support for interspersed
frozen and trainable parameters. Useful in cases such as parameter-efficient fine-tuning. Please
refer to this
[blog](https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019).
- sync_module_states (`bool`, *optional*, defaults to `True`)
If `"True"`, each individually wrapped FSDP unit will broadcast module parameters from rank 0 to
ensure they are the same across all ranks after initialization.
- xla (`bool`, *optional*, defaults to `False`):
Whether to use PyTorch/XLA Fully Sharded Data Parallel Training. This is an experimental feature
and its API may evolve in the future.
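For the file-based form of these options, the config is passed as `--fsdp_config path/to/file.json` (or `fsdp_config="path/to/file.json"`); `__post_init__` strips the legacy `fsdp_` prefix when it loads the file, so older key names keep working. A small illustrative snippet, not taken from the repository, that writes such a file:

```python
import json

# Hypothetical fsdp_config file mixing legacy prefixed keys and new unprefixed ones;
# after json.load, __post_init__ rewrites the "fsdp_*" keys to the unprefixed names.
fsdp_config = {
    "fsdp_transformer_layer_cls_to_wrap": ["T5Block"],  # becomes transformer_layer_cls_to_wrap
    "fsdp_backward_prefetch": "backward_pre",           # becomes backward_prefetch
    "forward_prefetch": "false",
    "sync_module_states": "true",
    "use_orig_params": "false",
}

with open("fsdp_config.json", "w", encoding="utf-8") as f:
    json.dump(fsdp_config, f, indent=2)
```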
@@ -1520,44 +1528,44 @@ def __post_init__(self):
self.fsdp_config = {}

if isinstance(self.fsdp_config, str):
if len(self.fsdp) == 0:
warnings.warn("`--fsdp_config` is useful only when `--fsdp` is specified.")
with io.open(self.fsdp_config, "r", encoding="utf-8") as f:
self.fsdp_config = json.load(f)
for k, v in list(self.fsdp_config.items()):
if k.startswith("fsdp_"):
self.fsdp_config[k.replace("fsdp_", "")] = v
del self.fsdp_config[k]

if self.fsdp_min_num_params > 0:
warnings.warn("using `--fsdp_min_num_params` is deprecated. Use fsdp_config instead ", FutureWarning)

self.fsdp_config["fsdp_min_num_params"] = max(
self.fsdp_config.get("fsdp_min_num_params", 0), self.fsdp_min_num_params
)
self.fsdp_config["min_num_params"] = max(self.fsdp_config.get("min_num_params", 0), self.fsdp_min_num_params)

# if fsdp_config["fsdp_transformer_layer_cls_to_wrap"] is specified as a string, convert it to a list with a single object
if isinstance(self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None), str):
self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"] = [
self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]
]
# if fsdp_config["transformer_layer_cls_to_wrap"] is specified as a string, convert it to a list with a single object
if isinstance(self.fsdp_config.get("transformer_layer_cls_to_wrap", None), str):
self.fsdp_config["transformer_layer_cls_to_wrap"] = [self.fsdp_config["transformer_layer_cls_to_wrap"]]

if self.fsdp_transformer_layer_cls_to_wrap is not None:
warnings.warn(
"using `--fsdp_transformer_layer_cls_to_wrap` is deprecated. Use fsdp_config instead ", FutureWarning
)
self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"] = self.fsdp_config.get(
"fsdp_transformer_layer_cls_to_wrap", []
self.fsdp_config["transformer_layer_cls_to_wrap"] = self.fsdp_config.get(
"transformer_layer_cls_to_wrap", []
) + [self.fsdp_transformer_layer_cls_to_wrap]

if len(self.fsdp) == 0 and self.fsdp_config["fsdp_min_num_params"] > 0:
warnings.warn("`--fsdp_min_num_params` is useful only when `--fsdp` is specified.")
if len(self.fsdp) == 0 and self.fsdp_config["min_num_params"] > 0:
warnings.warn("`min_num_params` is useful only when `--fsdp` is specified.")

if len(self.fsdp) == 0 and self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
warnings.warn("`--fsdp_transformer_layer_cls_to_wrap` is useful only when `--fsdp` is specified.")
if len(self.fsdp) == 0 and self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None:
warnings.warn("`transformer_layer_cls_to_wrap` is useful only when `--fsdp` is specified.")

if (
len(self.fsdp) > 0
and self.fsdp_config["fsdp_min_num_params"] > 0
and self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None
and self.fsdp_config["min_num_params"] > 0
and self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None
):
raise ValueError(
"`--fsdp_min_num_params` and `--fsdp_transformer_layer_cls_to_wrap` are mutually exclusive."
)
raise ValueError("`min_num_params` and `transformer_layer_cls_to_wrap` are mutually exclusive.")
self.fsdp_config["xla"] = self.fsdp_config.get("xla", False)
self.fsdp_config["xla_fsdp_grad_ckpt"] = self.fsdp_config.get("xla_fsdp_grad_ckpt", False)
if self.fsdp_config["xla"]:
@@ -1583,23 +1591,29 @@ def __post_init__(self):
FSDP_SHARDING_STRATEGY,
)

prefix = "FSDP_"
for fsdp_option in self.fsdp:
if fsdp_option.upper() in FSDP_SHARDING_STRATEGY:
# set environment variable for FSDP sharding strategy
os.environ["FSDP_SHARDING_STRATEGY"] = str(FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1)
os.environ[f"{prefix}SHARDING_STRATEGY"] = str(
FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1
)
elif fsdp_option == FSDPOption.OFFLOAD:
os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
os.environ[f"{prefix}OFFLOAD_PARAMS"] = "true"
elif fsdp_option == FSDPOption.AUTO_WRAP:
os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0]
if self.fsdp_config["fsdp_min_num_params"] > 0:
os.environ["FSDP_MIN_NUM_PARAMS"] = str(self.fsdp_config["fsdp_min_num_params"])
os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1]
elif self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = ",".join(
self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]
os.environ[f"{prefix}AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0]
if self.fsdp_config["min_num_params"] > 0:
os.environ[f"{prefix}MIN_NUM_PARAMS"] = str(self.fsdp_config["min_num_params"])
os.environ[f"{prefix}AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1]
elif self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None:
os.environ[f"{prefix}TRANSFORMER_CLS_TO_WRAP"] = ",".join(
self.fsdp_config["transformer_layer_cls_to_wrap"]
)
prefetch_policy = self.fsdp_config.get("backward_prefetch", "NO_PREFETCH")
os.environ["FSDP_BACKWARD_PREFETCH"] = prefetch_policy.upper()
os.environ[f"{prefix}BACKWARD_PREFETCH"] = prefetch_policy.upper()
os.environ[f"{prefix}FORWARD_PREFETCH"] = self.fsdp_config.get("forward_prefect", "false")
os.environ[f"{prefix}SYNC_MODULE_STATES"] = self.fsdp_config.get("sync_module_states", "true")
os.environ[f"{prefix}USE_ORIG_PARAMS"] = self.fsdp_config.get("use_orig_params", "false")

if self.tpu_metrics_debug:
warnings.warn(
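The net effect of the `prefix = "FSDP_"` block above is that the normalized fsdp_config is translated into `FSDP_*` environment variables, which accelerate's FSDP plugin reads when the `Accelerator` is created. Below is a standalone sketch of that export step; the dict literal and variable names are illustrative, not transformers code.

```python
import os

# Already-normalized config (prefixes stripped), as __post_init__ would hold it.
fsdp_config = {"min_num_params": 2000, "backward_prefetch": "backward_pre"}

prefix = "FSDP_"  # shared prefix for the env vars accelerate's FSDP plugin consumes
os.environ[f"{prefix}MIN_NUM_PARAMS"] = str(fsdp_config["min_num_params"])
os.environ[f"{prefix}BACKWARD_PREFETCH"] = fsdp_config["backward_prefetch"].upper()

print(os.environ[f"{prefix}MIN_NUM_PARAMS"], os.environ[f"{prefix}BACKWARD_PREFETCH"])
# -> 2000 BACKWARD_PRE
```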