add util for ram efficient loading of model when using fsdp #25107

Merged
16 commits merged on Aug 17, 2023
2 changes: 2 additions & 0 deletions docs/source/en/internal/trainer_utils.md
@@ -32,6 +32,8 @@ Most of those are only useful if you are studying the code of the Trainer in the

[[autodoc]] torch_distributed_zero_first

[[autodoc]] load_pretrained_model_only_on_rank0

## Callbacks internals

[[autodoc]] trainer_callback.CallbackHandler
7 changes: 5 additions & 2 deletions src/transformers/__init__.py
@@ -3059,7 +3059,10 @@
_import_structure["sagemaker"] = []
_import_structure["time_series_utils"] = []
_import_structure["trainer"] = ["Trainer"]
_import_structure["trainer_pt_utils"] = ["torch_distributed_zero_first"]
_import_structure["trainer_pt_utils"] = [
"load_pretrained_model_only_on_rank0",
"torch_distributed_zero_first",
]
_import_structure["trainer_seq2seq"] = ["Seq2SeqTrainer"]

# TensorFlow-backed objects
@@ -6598,7 +6601,7 @@

# Trainer
from .trainer import Trainer
from .trainer_pt_utils import torch_distributed_zero_first
from .trainer_pt_utils import load_pretrained_model_only_on_rank0, torch_distributed_zero_first
from .trainer_seq2seq import Seq2SeqTrainer

# TensorFlow
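
For reference, a minimal sketch of what the export above enables (assuming transformers with this PR installed): the new helper resolves from the top-level namespace alongside the existing one.

    # Both helpers are importable from the package root after this change.
    from transformers import (
        load_pretrained_model_only_on_rank0,
        torch_distributed_zero_first,
    )
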
58 changes: 37 additions & 21 deletions src/transformers/modeling_utils.py
@@ -105,6 +105,14 @@
_init_weights = True


def is_fsdp_enabled():
return os.environ["ACCELERATE_USE_FSDP"]


def is_fsdp_enabled_and_dist_rank_0():
return is_fsdp_enabled() and torch.distributed.is_initialized() and torch.distributed.get_rank() == 0


if is_sagemaker_mp_enabled():
import smdistributed.modelparallel.torch as smp
from smdistributed.modelparallel import __version__ as SMP_VERSION
@@ -457,7 +465,11 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
)
return safe_load_file(checkpoint_file)
try:
if is_deepspeed_zero3_enabled() and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0:
if (
(is_deepspeed_zero3_enabled() or is_fsdp_enabled())
and torch.distributed.is_initialized()
and torch.distributed.get_rank() > 0
):
map_location = "meta"
else:
map_location = "cpu"
@@ -541,7 +553,7 @@ def load(module: nn.Module, state_dict, prefix=""):
with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0):
if torch.distributed.get_rank() == 0:
module._load_from_state_dict(*args)
else:
elif not is_fsdp_enabled() or is_fsdp_enabled_and_dist_rank_0():
module._load_from_state_dict(*args)

for name, child in module._modules.items():
@@ -1481,7 +1493,7 @@ def _get_resized_embeddings(
with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=0):
if torch.distributed.get_rank() == 0:
new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
else:
elif not is_fsdp_enabled() or is_fsdp_enabled_and_dist_rank_0():
new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]

return new_embeddings
@@ -1565,7 +1577,7 @@ def _get_resized_lm_head(
# Copy bias weights to new lm head
if has_new_lm_head_bias:
new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy]
else:
elif not is_fsdp_enabled() or is_fsdp_enabled_and_dist_rank_0():
# Copy old lm head weights to new lm head
if not transposed:
new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :]
@@ -2193,6 +2205,9 @@ def from_pretrained(
commit_hash = kwargs.pop("_commit_hash", None)
variant = kwargs.pop("variant", None)

if is_fsdp_enabled():
low_cpu_mem_usage = True

if use_auth_token is not None:
warnings.warn(
"The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
@@ -3265,23 +3280,24 @@ def _find_mismatched_keys(
)

if low_cpu_mem_usage:
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
model_to_load,
state_dict,
loaded_keys,
start_prefix,
expected_keys,
device_map=device_map,
offload_folder=offload_folder,
offload_index=offload_index,
state_dict_folder=state_dict_folder,
state_dict_index=state_dict_index,
dtype=dtype,
is_quantized=is_quantized,
is_safetensors=is_safetensors,
keep_in_fp32_modules=keep_in_fp32_modules,
)
error_msgs += new_error_msgs
if not is_fsdp_enabled() or is_fsdp_enabled_and_dist_rank_0():
new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model(
model_to_load,
state_dict,
loaded_keys,
start_prefix,
expected_keys,
device_map=device_map,
offload_folder=offload_folder,
offload_index=offload_index,
state_dict_folder=state_dict_folder,
state_dict_index=state_dict_index,
dtype=dtype,
is_quantized=is_quantized,
is_safetensors=is_safetensors,
keep_in_fp32_modules=keep_in_fp32_modules,
)
error_msgs += new_error_msgs
else:
error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix)

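
Taken together, the modeling_utils.py changes above make `from_pretrained` materialize real weights only on rank 0 when Accelerate's FSDP plugin is active, while the other ranks load onto the meta device. Below is a minimal standalone sketch of that gating pattern, not the library code itself; it assumes `ACCELERATE_USE_FSDP` is the variable the Accelerate launcher sets, and it normalizes the value with a `.get(...) == "true"` check rather than the bare lookup used in the diff.

    import os

    import torch
    import torch.distributed as dist


    def is_fsdp_enabled() -> bool:
        # Accelerate exports ACCELERATE_USE_FSDP when its FSDP plugin is active.
        return os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"


    def is_fsdp_rank0() -> bool:
        return dist.is_initialized() and dist.get_rank() == 0


    def load_state_dict_ram_efficiently(checkpoint_file: str):
        # Mirrors the branch added to load_state_dict: rank 0 loads real tensors
        # on CPU, every other rank loads onto "meta", so only one full copy of
        # the weights ever sits in host RAM per node.
        if is_fsdp_enabled() and dist.is_initialized() and dist.get_rank() > 0:
            map_location = "meta"
        else:
            map_location = "cpu"
        return torch.load(checkpoint_file, map_location=map_location)
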
11 changes: 3 additions & 8 deletions src/transformers/trainer.py
@@ -461,10 +461,6 @@ def __init__(
):
self.backward_prefetch = BackwardPrefetch.BACKWARD_POST

self.forward_prefetch = False
if self.args.fsdp_config.get("forward_prefect", False):
self.forward_prefetch = True

self.limit_all_gathers = False
if self.args.fsdp_config.get("limit_all_gathers", False):
self.limit_all_gathers = True
@@ -1379,12 +1375,12 @@ def _wrap_model(self, model, training=True, dataloader=None):
auto_wrapper_callable = None
default_transformer_cls_names_to_wrap = getattr(model, "_no_split_modules", None)
fsdp_transformer_layer_cls_to_wrap = self.args.fsdp_config.get(
"fsdp_transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap
"transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap
)

if self.args.fsdp_config["fsdp_min_num_params"] > 0:
if self.args.fsdp_config["min_num_params"] > 0:
auto_wrap_policy = functools.partial(
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"]
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["min_num_params"]
)
elif fsdp_transformer_layer_cls_to_wrap is not None:
transformer_cls_to_wrap = set()
@@ -3825,7 +3821,6 @@ def create_accelerator_and_postprocess(self):
fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get(
"limit_all_gathers", fsdp_plugin.limit_all_gathers
)
fsdp_plugin.use_orig_params = self.args.fsdp_config.get("use_orig_params", fsdp_plugin.use_orig_params)

if self.is_deepspeed_enabled:
if getattr(self.args, "hf_deepspeed_config", None) is None:
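
For context on the key rename in `_wrap_model` above, here is a hedged sketch (not the Trainer code itself) of how the un-prefixed `min_num_params` and `transformer_layer_cls_to_wrap` entries drive auto-wrap policy selection; `get_module_class_from_name` is the existing helper in `trainer_pt_utils`, and the fallback to `_no_split_modules` follows the diff.

    import functools

    from torch.distributed.fsdp.wrap import (
        size_based_auto_wrap_policy,
        transformer_auto_wrap_policy,
    )
    from transformers.trainer_pt_utils import get_module_class_from_name


    def pick_auto_wrap_policy(fsdp_config: dict, model):
        # Keys are read without the old "fsdp_" prefix after this PR.
        layer_cls_names = fsdp_config.get(
            "transformer_layer_cls_to_wrap", getattr(model, "_no_split_modules", None)
        )
        if fsdp_config.get("min_num_params", 0) > 0:
            # Wrap any submodule whose parameter count exceeds the threshold.
            return functools.partial(
                size_based_auto_wrap_policy, min_num_params=fsdp_config["min_num_params"]
            )
        if layer_cls_names is not None:
            # Wrap each named transformer block class as its own FSDP unit.
            layer_classes = set()
            for name in layer_cls_names:
                cls = get_module_class_from_name(model, name)
                if cls is not None:
                    layer_classes.add(cls)
            return functools.partial(
                transformer_auto_wrap_policy, transformer_layer_cls=layer_classes
            )
        return None
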
13 changes: 13 additions & 0 deletions src/transformers/trainer_pt_utils.py
@@ -1125,3 +1125,16 @@ def smp_nested_concat(tensor):
# It doesn't seem possible to check here if `tensor` is a StepOutput because StepOutput lives in `smp.step`
# which is also the name of the decorator so Python is confused.
return tensor.concat().detach().cpu()


def load_pretrained_model_only_on_rank0(model_cls, config_cls, model_name_or_path):
from accelerate.state import PartialState

state = PartialState()
if state.is_main_process:
model = model_cls.from_pretrained(model_name_or_path, return_dict=True)
else:
with torch.device("meta"):
config = config_cls.from_pretrained(model_name_or_path)
model = model_cls.from_config(config)
return model
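
A hedged usage sketch for the new helper, assuming the merged version of this PR and a script launched with `torchrun` or `accelerate launch`: rank 0 materializes the full pretrained weights on CPU, every other rank builds a meta-device shell from the config, and FSDP's `sync_module_states` broadcast (wired up in training_args.py below) fills in the real values after wrapping. The checkpoint name is a placeholder.

    from transformers import (
        AutoConfig,
        AutoModelForCausalLM,
        Trainer,
        TrainingArguments,
        load_pretrained_model_only_on_rank0,
    )

    # Only the main process loads real weights into host RAM; other ranks get
    # an empty meta-device model built from the config alone.
    model = load_pretrained_model_only_on_rank0(
        AutoModelForCausalLM, AutoConfig, "gpt2"  # placeholder checkpoint
    )

    training_args = TrainingArguments(
        output_dir="out",
        fsdp="full_shard auto_wrap",
        # sync_module_states broadcasts rank-0 parameters to the other ranks
        # once FSDP wraps the model; passed as a string because the env var
        # export in training_args.py expects string values.
        fsdp_config={"sync_module_states": "true"},
    )

    trainer = Trainer(model=model, args=training_args)  # add datasets as usual before trainer.train()
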
59 changes: 38 additions & 21 deletions src/transformers/training_args.py
@@ -458,10 +458,17 @@ class TrainingArguments:
FSDP's forward prefetch mode (useful only when `fsdp` field is passed).
If `"True"`, then FSDP explicitly prefetches the next upcoming all-gather while executing in the
forward pass.
- limit_all_gathers (`bool`, *optional*, defaults to `False`)
- fsdp_limit_all_gathers (`bool`, *optional*, defaults to `False`)
FSDP's limit_all_gathers (useful only when `fsdp` field is passed).
If `"True"`, FSDP explicitly synchronizes the CPU thread to prevent too many in-flight
all-gathers.
- fsdp_use_orig_params (`bool`, *optional*, defaults to `False`)
If `"True"`, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable parameters.
Useful in cases such as parameter-efficient fine-tuning.
Please refer to this [blog](https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019).
- fsdp_sync_module_states (`bool`, *optional*, defaults to `False`)
If `"True"`, each individually wrapped FSDP unit will broadcast module parameters from rank 0
to ensure they are the same across all ranks after initialization.
- xla (`bool`, *optional*, defaults to `False`):
Whether to use PyTorch/XLA Fully Sharded Data Parallel Training. This is an experimental feature
and its API may evolve in the future.
@@ -1511,40 +1518,44 @@ def __post_init__(self):
self.fsdp_config = {}

if isinstance(self.fsdp_config, str):
if len(self.fsdp) == 0:
warnings.warn("`--fsdp_config` is useful only when `--fsdp` is specified.")
with io.open(self.fsdp_config, "r", encoding="utf-8") as f:
self.fsdp_config = json.load(f)
for k, v in self.fsdp_config.items():
if k.startswith("fsdp_"):
self.fsdp_config[k.replace("fsdp", "")] = v
del self.fsdp_config[k]

if self.fsdp_min_num_params > 0:
warnings.warn("using `--fsdp_min_num_params` is deprecated. Use fsdp_config instead ", FutureWarning)

self.fsdp_config["fsdp_min_num_params"] = max(
self.fsdp_config.get("fsdp_min_num_params", 0), self.fsdp_min_num_params
self.fsdp_config.get("min_num_params", 0), self.fsdp_min_num_params
)

# if fsdp_config["fsdp_transformer_layer_cls_to_wrap"] is specified as a string, convert it to a list with a single object
if isinstance(self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None), str):
self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"] = [
self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]
]
# if fsdp_config["transformer_layer_cls_to_wrap"] is specified as a string, convert it to a list with a single object
if isinstance(self.fsdp_config.get("transformer_layer_cls_to_wrap", None), str):
self.fsdp_config["transformer_layer_cls_to_wrap"] = [self.fsdp_config["transformer_layer_cls_to_wrap"]]

if self.fsdp_transformer_layer_cls_to_wrap is not None:
warnings.warn(
"using `--fsdp_transformer_layer_cls_to_wrap` is deprecated. Use fsdp_config instead ", FutureWarning
)
self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"] = self.fsdp_config.get(
"fsdp_transformer_layer_cls_to_wrap", []
self.fsdp_config["transformer_layer_cls_to_wrap"] = self.fsdp_config.get(
"transformer_layer_cls_to_wrap", []
) + [self.fsdp_transformer_layer_cls_to_wrap]

if len(self.fsdp) == 0 and self.fsdp_config["fsdp_min_num_params"] > 0:
if len(self.fsdp) == 0 and self.fsdp_config["min_num_params"] > 0:
warnings.warn("`--fsdp_min_num_params` is useful only when `--fsdp` is specified.")

if len(self.fsdp) == 0 and self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
if len(self.fsdp) == 0 and self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None:
warnings.warn("`--fsdp_transformer_layer_cls_to_wrap` is useful only when `--fsdp` is specified.")

if (
len(self.fsdp) > 0
and self.fsdp_config["fsdp_min_num_params"] > 0
and self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None
and self.fsdp_config["min_num_params"] > 0
and self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None
):
raise ValueError(
"`--fsdp_min_num_params` and `--fsdp_transformer_layer_cls_to_wrap` are mutually exclusive."
@@ -1574,23 +1585,29 @@ def __post_init__(self):
FSDP_SHARDING_STRATEGY,
)

prefix = "FSDP_"
for fsdp_option in self.fsdp:
if fsdp_option.upper() in FSDP_SHARDING_STRATEGY:
# set environment variable for FSDP sharding strategy
os.environ["FSDP_SHARDING_STRATEGY"] = str(FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1)
os.environ[f"{prefix}SHARDING_STRATEGY"] = str(
FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1
)
elif fsdp_option == FSDPOption.OFFLOAD:
os.environ["FSDP_OFFLOAD_PARAMS"] = "true"
os.environ[f"{prefix}OFFLOAD_PARAMS"] = "true"
elif fsdp_option == FSDPOption.AUTO_WRAP:
os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0]
os.environ[f"{prefix}AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0]
if self.fsdp_config["fsdp_min_num_params"] > 0:
os.environ["FSDP_MIN_NUM_PARAMS"] = str(self.fsdp_config["fsdp_min_num_params"])
os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1]
os.environ[f"{prefix}MIN_NUM_PARAMS"] = str(self.fsdp_config["min_num_params"])
os.environ[f"{prefix}AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1]
elif self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None:
os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = ",".join(
self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"]
os.environ[f"{prefix}TRANSFORMER_CLS_TO_WRAP"] = ",".join(
self.fsdp_config["transformer_layer_cls_to_wrap"]
)
prefetch_policy = self.fsdp_config.get("fsdp_backward_prefetch", "NO_PREFETCH")
os.environ["FSDP_BACKWARD_PREFETCH"] = prefetch_policy.upper()
os.environ[f"{prefix}BACKWARD_PREFETCH"] = prefetch_policy.upper()
os.environ[f"{prefix}FORWARD_PREFETCH"] = self.fsdp_config.get("forward_prefect", "false")
os.environ[f"{prefix}SYNC_MODULE_STATES"] = self.fsdp_config.get("sync_module_states", "false")
os.environ[f"{prefix}USE_ORIG_PARAMS"] = self.fsdp_config.get("use_orig_params", "false")

if self.tpu_metrics_debug:
warnings.warn(
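
To see the effect of the single `FSDP_` prefix and the new `FORWARD_PREFETCH` / `SYNC_MODULE_STATES` / `USE_ORIG_PARAMS` exports, here is a small hedged check (assuming the merged version of this PR and accelerate installed; `BertLayer` is an illustrative class name):

    import os

    from transformers import TrainingArguments

    args = TrainingArguments(
        output_dir="out",
        fsdp="full_shard auto_wrap",
        fsdp_config={
            "transformer_layer_cls_to_wrap": ["BertLayer"],
            "sync_module_states": "true",
            "use_orig_params": "true",
        },
    )

    # __post_init__ exports the FSDP_* variables that Accelerate's FSDP plugin reads.
    for key in sorted(os.environ):
        if key.startswith("FSDP_"):
            print(f"{key}={os.environ[key]}")
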
4 changes: 4 additions & 0 deletions src/transformers/utils/dummy_pt_objects.py
@@ -8493,6 +8493,10 @@ def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])


def load_pretrained_model_only_on_rank0(*args, **kwargs):
requires_backends(load_pretrained_model_only_on_rank0, ["torch"])


def torch_distributed_zero_first(*args, **kwargs):
requires_backends(torch_distributed_zero_first, ["torch"])
