From e8873d958728a71e5bd3fad2eb0a90a44f70b323 Mon Sep 17 00:00:00 2001 From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Wed, 26 Jul 2023 16:15:18 +0530 Subject: [PATCH 01/14] add util for ram efficient loading of model when using fsdp --- src/transformers/__init__.py | 2 +- src/transformers/trainer_pt_utils.py | 19 +++++++++++++++++++ 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 5b4a75215aa72f..002d1d15fd53f2 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -6598,7 +6598,7 @@ # Trainer from .trainer import Trainer - from .trainer_pt_utils import torch_distributed_zero_first + from .trainer_pt_utils import load_model_from_pretrained_only_on_rank0, torch_distributed_zero_first from .trainer_seq2seq import Seq2SeqTrainer # TensorFlow diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index b57770f33b29b3..c7a524cffdea18 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -1125,3 +1125,22 @@ def smp_nested_concat(tensor): # It doesn't seem possible to check here if `tensor` is a StepOutput because StepOutput lives in `smp.step` # which is also the name of the decorator so Python is confused. return tensor.concat().detach().cpu() + + +def load_model_from_pretrained_only_on_rank0(model_cls, config_cls, model_name_or_path): + from accelerate.state import PartialState + + state = PartialState() + if state.is_main_process: + model = model_cls.from_pretrained(model_name_or_path, return_dict=True) + param_init_fn = None + else: + with torch.device("meta"): + config = config_cls.from_pretrained(model_name_or_path) + model = model_cls.from_config(config) + + def param_init_fn(x): + return x.to_empty(device=state.device, recurse=False) + + model.train() + return model, param_init_fn From 6f342297df54c5bd71e42ef61cf4694fa4c04a1c Mon Sep 17 00:00:00 2001 From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Wed, 26 Jul 2023 16:22:50 +0530 Subject: [PATCH 02/14] make fix-copies --- src/transformers/utils/dummy_pt_objects.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 65fedf02d8c9cf..50431bffbf0c3a 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -8493,6 +8493,10 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) +def load_model_from_pretrained_only_on_rank0(*args, **kwargs): + requires_backends(load_model_from_pretrained_only_on_rank0, ["torch"]) + + def torch_distributed_zero_first(*args, **kwargs): requires_backends(torch_distributed_zero_first, ["torch"]) From 67d13dfbdb4c4743d12397b78a70a02db90484fa Mon Sep 17 00:00:00 2001 From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Wed, 26 Jul 2023 16:31:26 +0530 Subject: [PATCH 03/14] =?UTF-8?q?fixes=20=F0=9F=98=85?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/transformers/__init__.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 002d1d15fd53f2..1d7aeeea25dc3a 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -3059,7 +3059,10 @@ _import_structure["sagemaker"] = [] _import_structure["time_series_utils"] = [] _import_structure["trainer"] = 
["Trainer"] - _import_structure["trainer_pt_utils"] = ["torch_distributed_zero_first"] + _import_structure["trainer_pt_utils"] = [ + "load_model_from_pretrained_only_on_rank0", + "torch_distributed_zero_first", + ] _import_structure["trainer_seq2seq"] = ["Seq2SeqTrainer"] # TensorFlow-backed objects From fb996334e6e2aa5e5781d00de2d2e3a138542eef Mon Sep 17 00:00:00 2001 From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Wed, 26 Jul 2023 16:44:18 +0530 Subject: [PATCH 04/14] docs --- docs/source/en/internal/trainer_utils.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/en/internal/trainer_utils.md b/docs/source/en/internal/trainer_utils.md index e3f8a9b04536fa..47e0b9598b8d44 100644 --- a/docs/source/en/internal/trainer_utils.md +++ b/docs/source/en/internal/trainer_utils.md @@ -32,6 +32,8 @@ Most of those are only useful if you are studying the code of the Trainer in the [[autodoc]] torch_distributed_zero_first +[[autodoc]] load_model_from_pretrained_only_on_rank0 + ## Callbacks internals [[autodoc]] trainer_callback.CallbackHandler From a48cc1aa4e4389421ba1fb67ef08f6b696d11e17 Mon Sep 17 00:00:00 2001 From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Wed, 26 Jul 2023 17:00:05 +0530 Subject: [PATCH 05/14] making it further easier to use --- src/transformers/trainer_pt_utils.py | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index c7a524cffdea18..482a898943c9fd 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -1133,14 +1133,8 @@ def load_model_from_pretrained_only_on_rank0(model_cls, config_cls, model_name_o state = PartialState() if state.is_main_process: model = model_cls.from_pretrained(model_name_or_path, return_dict=True) - param_init_fn = None else: with torch.device("meta"): config = config_cls.from_pretrained(model_name_or_path) model = model_cls.from_config(config) - - def param_init_fn(x): - return x.to_empty(device=state.device, recurse=False) - - model.train() - return model, param_init_fn + return model From 63885961aba61c6a0505ade2c976b78039c17331 Mon Sep 17 00:00:00 2001 From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Wed, 26 Jul 2023 17:21:47 +0530 Subject: [PATCH 06/14] rename the function --- docs/source/en/internal/trainer_utils.md | 2 +- src/transformers/__init__.py | 4 ++-- src/transformers/trainer_pt_utils.py | 2 +- src/transformers/utils/dummy_pt_objects.py | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/docs/source/en/internal/trainer_utils.md b/docs/source/en/internal/trainer_utils.md index 47e0b9598b8d44..8035a6f8bf703e 100644 --- a/docs/source/en/internal/trainer_utils.md +++ b/docs/source/en/internal/trainer_utils.md @@ -32,7 +32,7 @@ Most of those are only useful if you are studying the code of the Trainer in the [[autodoc]] torch_distributed_zero_first -[[autodoc]] load_model_from_pretrained_only_on_rank0 +[[autodoc]] load_pretrained_model_only_on_rank0 ## Callbacks internals diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index 1d7aeeea25dc3a..d6c60531526836 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -3060,7 +3060,7 @@ _import_structure["time_series_utils"] = [] _import_structure["trainer"] = ["Trainer"] _import_structure["trainer_pt_utils"] = [ - "load_model_from_pretrained_only_on_rank0", + "load_pretrained_model_only_on_rank0", 
"torch_distributed_zero_first", ] _import_structure["trainer_seq2seq"] = ["Seq2SeqTrainer"] @@ -6601,7 +6601,7 @@ # Trainer from .trainer import Trainer - from .trainer_pt_utils import load_model_from_pretrained_only_on_rank0, torch_distributed_zero_first + from .trainer_pt_utils import load_pretrained_model_only_on_rank0, torch_distributed_zero_first from .trainer_seq2seq import Seq2SeqTrainer # TensorFlow diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index 482a898943c9fd..5ed096a4fe479c 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -1127,7 +1127,7 @@ def smp_nested_concat(tensor): return tensor.concat().detach().cpu() -def load_model_from_pretrained_only_on_rank0(model_cls, config_cls, model_name_or_path): +def load_pretrained_model_only_on_rank0(model_cls, config_cls, model_name_or_path): from accelerate.state import PartialState state = PartialState() diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 50431bffbf0c3a..5521bd74b95002 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -8493,8 +8493,8 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -def load_model_from_pretrained_only_on_rank0(*args, **kwargs): - requires_backends(load_model_from_pretrained_only_on_rank0, ["torch"]) +def load_pretrained_model_only_on_rank0(*args, **kwargs): + requires_backends(load_pretrained_model_only_on_rank0, ["torch"]) def torch_distributed_zero_first(*args, **kwargs): From b18a1038c78fb7dad02b08fcde58a5e38b13b262 Mon Sep 17 00:00:00 2001 From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Wed, 26 Jul 2023 19:40:13 +0530 Subject: [PATCH 07/14] refactor to handle fsdp ram efficiency in `from_pretrained` --- src/transformers/modeling_utils.py | 58 ++++++++++++++++++----------- src/transformers/trainer.py | 11 ++---- src/transformers/training_args.py | 59 +++++++++++++++++++----------- 3 files changed, 78 insertions(+), 50 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 4d41ff49841c01..f902ae98e4cd75 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -105,6 +105,14 @@ _init_weights = True +def is_fsdp_enabled(): + return os.environ["ACCELERATE_USE_FSDP"] + + +def is_fsdp_enabled_and_dist_rank_0(): + return is_fsdp_enabled() and torch.distributed.is_initialized() and torch.distributed.get_rank() == 0 + + if is_sagemaker_mp_enabled(): import smdistributed.modelparallel.torch as smp from smdistributed.modelparallel import __version__ as SMP_VERSION @@ -457,7 +465,11 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]): ) return safe_load_file(checkpoint_file) try: - if is_deepspeed_zero3_enabled() and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0: + if ( + (is_deepspeed_zero3_enabled() or is_fsdp_enabled()) + and torch.distributed.is_initialized() + and torch.distributed.get_rank() > 0 + ): map_location = "meta" else: map_location = "cpu" @@ -541,7 +553,7 @@ def load(module: nn.Module, state_dict, prefix=""): with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0): if torch.distributed.get_rank() == 0: module._load_from_state_dict(*args) - else: + elif not is_fsdp_enabled() or is_fsdp_enabled_and_dist_rank_0(): module._load_from_state_dict(*args) for name, child in module._modules.items(): @@ -1481,7 
+1493,7 @@ def _get_resized_embeddings( with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=0): if torch.distributed.get_rank() == 0: new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :] - else: + elif not is_fsdp_enabled() or is_fsdp_enabled_and_dist_rank_0(): new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :] return new_embeddings @@ -1565,7 +1577,7 @@ def _get_resized_lm_head( # Copy bias weights to new lm head if has_new_lm_head_bias: new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy] - else: + elif not is_fsdp_enabled() or is_fsdp_enabled_and_dist_rank_0(): # Copy old lm head weights to new lm head if not transposed: new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :] @@ -2193,6 +2205,9 @@ def from_pretrained( commit_hash = kwargs.pop("_commit_hash", None) variant = kwargs.pop("variant", None) + if is_fsdp_enabled(): + low_cpu_mem_usage = True + if use_auth_token is not None: warnings.warn( "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning @@ -3265,23 +3280,24 @@ def _find_mismatched_keys( ) if low_cpu_mem_usage: - new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model( - model_to_load, - state_dict, - loaded_keys, - start_prefix, - expected_keys, - device_map=device_map, - offload_folder=offload_folder, - offload_index=offload_index, - state_dict_folder=state_dict_folder, - state_dict_index=state_dict_index, - dtype=dtype, - is_quantized=is_quantized, - is_safetensors=is_safetensors, - keep_in_fp32_modules=keep_in_fp32_modules, - ) - error_msgs += new_error_msgs + if not is_fsdp_enabled() or is_fsdp_enabled_and_dist_rank_0(): + new_error_msgs, offload_index, state_dict_index = _load_state_dict_into_meta_model( + model_to_load, + state_dict, + loaded_keys, + start_prefix, + expected_keys, + device_map=device_map, + offload_folder=offload_folder, + offload_index=offload_index, + state_dict_folder=state_dict_folder, + state_dict_index=state_dict_index, + dtype=dtype, + is_quantized=is_quantized, + is_safetensors=is_safetensors, + keep_in_fp32_modules=keep_in_fp32_modules, + ) + error_msgs += new_error_msgs else: error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 4ccad5b276d2f7..43b8f03148335a 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -461,10 +461,6 @@ def __init__( ): self.backward_prefetch = BackwardPrefetch.BACKWARD_POST - self.forward_prefetch = False - if self.args.fsdp_config.get("forward_prefect", False): - self.forward_prefetch = True - self.limit_all_gathers = False if self.args.fsdp_config.get("limit_all_gathers", False): self.limit_all_gathers = True @@ -1379,12 +1375,12 @@ def _wrap_model(self, model, training=True, dataloader=None): auto_wrapper_callable = None default_transformer_cls_names_to_wrap = getattr(model, "_no_split_modules", None) fsdp_transformer_layer_cls_to_wrap = self.args.fsdp_config.get( - "fsdp_transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap + "transformer_layer_cls_to_wrap", default_transformer_cls_names_to_wrap ) - if self.args.fsdp_config["fsdp_min_num_params"] > 0: + if self.args.fsdp_config["min_num_params"] > 0: auto_wrap_policy = functools.partial( - size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["fsdp_min_num_params"] + 
size_based_auto_wrap_policy, min_num_params=self.args.fsdp_config["min_num_params"] ) elif fsdp_transformer_layer_cls_to_wrap is not None: transformer_cls_to_wrap = set() @@ -3825,7 +3821,6 @@ def create_accelerator_and_postprocess(self): fsdp_plugin.limit_all_gathers = self.args.fsdp_config.get( "limit_all_gathers", fsdp_plugin.limit_all_gathers ) - fsdp_plugin.use_orig_params = self.args.fsdp_config.get("use_orig_params", fsdp_plugin.use_orig_params) if self.is_deepspeed_enabled: if getattr(self.args, "hf_deepspeed_config", None) is None: diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index b409a84bed8b8d..ce5b016d790c5d 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -458,10 +458,17 @@ class TrainingArguments: FSDP's forward prefetch mode (useful only when `fsdp` field is passed). If `"True"`, then FSDP explicitly prefetches the next upcoming all-gather while executing in the forward pass. - - limit_all_gathers (`bool`, *optional*, defaults to `False`) + - fsdp_limit_all_gathers (`bool`, *optional*, defaults to `False`) FSDP's limit_all_gathers (useful only when `fsdp` field is passed). If `"True"`, FSDP explicitly synchronizes the CPU thread to prevent too many in-flight all-gathers. + - fsdp_use_orig_params (`bool`, *optional*, defaults to `False`) + If `"True"`, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable paramteres. + Useful in cases such as parameter-efficient fine-tuning. + Please refer this [blog](https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019 + - fsdp_sync_module_states (`bool`, *optional*, defaults to `False`) + If `"True"`, each individually wrapped FSDP unit will broadcast module parameters from rank 0 + to ensure they are the same across all ranks after initialization - xla (`bool`, *optional*, defaults to `False`): Whether to use PyTorch/XLA Fully Sharded Data Parallel Training. This is an experimental feature and its API may evolve in the future. @@ -1511,40 +1518,44 @@ def __post_init__(self): self.fsdp_config = {} if isinstance(self.fsdp_config, str): + if len(self.fsdp) == 0: + warnings.warn("`--fsdp_config` is useful only when `--fsdp` is specified.") with io.open(self.fsdp_config, "r", encoding="utf-8") as f: self.fsdp_config = json.load(f) + for k, v in self.fsdp_config.items(): + if k.startswith("fsdp_"): + self.fsdp_config[k.replace("fsdp", "")] = v + del self.fsdp_config[k] if self.fsdp_min_num_params > 0: warnings.warn("using `--fsdp_min_num_params` is deprecated. 
Use fsdp_config instead ", FutureWarning) self.fsdp_config["fsdp_min_num_params"] = max( - self.fsdp_config.get("fsdp_min_num_params", 0), self.fsdp_min_num_params + self.fsdp_config.get("min_num_params", 0), self.fsdp_min_num_params ) - # if fsdp_config["fsdp_transformer_layer_cls_to_wrap"] is specified as a string, convert it to a list with a single object - if isinstance(self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None), str): - self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"] = [ - self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"] - ] + # if fsdp_config["transformer_layer_cls_to_wrap"] is specified as a string, convert it to a list with a single object + if isinstance(self.fsdp_config.get("transformer_layer_cls_to_wrap", None), str): + self.fsdp_config["transformer_layer_cls_to_wrap"] = [self.fsdp_config["transformer_layer_cls_to_wrap"]] if self.fsdp_transformer_layer_cls_to_wrap is not None: warnings.warn( "using `--fsdp_transformer_layer_cls_to_wrap` is deprecated. Use fsdp_config instead ", FutureWarning ) - self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"] = self.fsdp_config.get( - "fsdp_transformer_layer_cls_to_wrap", [] + self.fsdp_config["transformer_layer_cls_to_wrap"] = self.fsdp_config.get( + "transformer_layer_cls_to_wrap", [] ) + [self.fsdp_transformer_layer_cls_to_wrap] - if len(self.fsdp) == 0 and self.fsdp_config["fsdp_min_num_params"] > 0: + if len(self.fsdp) == 0 and self.fsdp_config["min_num_params"] > 0: warnings.warn("`--fsdp_min_num_params` is useful only when `--fsdp` is specified.") - if len(self.fsdp) == 0 and self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None: + if len(self.fsdp) == 0 and self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None: warnings.warn("`--fsdp_transformer_layer_cls_to_wrap` is useful only when `--fsdp` is specified.") if ( len(self.fsdp) > 0 - and self.fsdp_config["fsdp_min_num_params"] > 0 - and self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None + and self.fsdp_config["min_num_params"] > 0 + and self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None ): raise ValueError( "`--fsdp_min_num_params` and `--fsdp_transformer_layer_cls_to_wrap` are mutually exclusive." 
@@ -1574,23 +1585,29 @@ def __post_init__(self): FSDP_SHARDING_STRATEGY, ) + prefix = "FSDP_" for fsdp_option in self.fsdp: if fsdp_option.upper() in FSDP_SHARDING_STRATEGY: # set environment variable for FSDP sharding strategy - os.environ["FSDP_SHARDING_STRATEGY"] = str(FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1) + os.environ[f"{prefix}SHARDING_STRATEGY"] = str( + FSDP_SHARDING_STRATEGY.index(fsdp_option.upper()) + 1 + ) elif fsdp_option == FSDPOption.OFFLOAD: - os.environ["FSDP_OFFLOAD_PARAMS"] = "true" + os.environ[f"{prefix}OFFLOAD_PARAMS"] = "true" elif fsdp_option == FSDPOption.AUTO_WRAP: - os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0] + os.environ[f"{prefix}AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0] if self.fsdp_config["fsdp_min_num_params"] > 0: - os.environ["FSDP_MIN_NUM_PARAMS"] = str(self.fsdp_config["fsdp_min_num_params"]) - os.environ["FSDP_AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1] + os.environ[f"{prefix}MIN_NUM_PARAMS"] = str(self.fsdp_config["min_num_params"]) + os.environ[f"{prefix}AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1] elif self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None: - os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = ",".join( - self.fsdp_config["fsdp_transformer_layer_cls_to_wrap"] + os.environ[f"{prefix}TRANSFORMER_CLS_TO_WRAP"] = ",".join( + self.fsdp_config["transformer_layer_cls_to_wrap"] ) prefetch_policy = self.fsdp_config.get("fsdp_backward_prefetch", "NO_PREFETCH") - os.environ["FSDP_BACKWARD_PREFETCH"] = prefetch_policy.upper() + os.environ[f"{prefix}BACKWARD_PREFETCH"] = prefetch_policy.upper() + os.environ[f"{prefix}FORWARD_PREFETCH"] = self.fsdp_config.get("forward_prefect", "false") + os.environ[f"{prefix}SYNC_MODULE_STATES"] = self.fsdp_config.get("sync_module_states", "false") + os.environ[f"{prefix}USE_ORIG_PARAMS"] = self.fsdp_config.get("use_orig_params", "false") if self.tpu_metrics_debug: warnings.warn( From 353fa8ac5e9443a535fb96339f0a0fb2957cf4ef Mon Sep 17 00:00:00 2001 From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Wed, 26 Jul 2023 21:49:41 +0530 Subject: [PATCH 08/14] fixes --- src/transformers/modeling_utils.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index f902ae98e4cd75..aecab420fc9981 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -72,6 +72,7 @@ is_torch_tpu_available, logging, replace_return_docstrings, + strtobool, ) from .utils.hub import convert_file_size_to_int, get_checkpoint_shard_files from .utils.import_utils import ENV_VARS_TRUE_VALUES, is_sagemaker_mp_enabled @@ -106,7 +107,7 @@ def is_fsdp_enabled(): - return os.environ["ACCELERATE_USE_FSDP"] + return strtobool(os.environ.get("ACCELERATE_USE_FSDP", "False")) == 1 def is_fsdp_enabled_and_dist_rank_0(): @@ -1154,6 +1155,12 @@ def _from_config(cls, config, **kwargs): # and memory copying it on CPU or each GPU first with deepspeed.zero.Init(config_dict_or_path=deepspeed_config()): model = cls(config, **kwargs) + elif is_fsdp_enabled(): + if is_fsdp_enabled_and_dist_rank_0(): + model = cls(config, **kwargs) + else: + with torch.device("meta"): + model = cls(config, **kwargs) else: model = cls(config, **kwargs) @@ -2205,9 +2212,6 @@ def from_pretrained( commit_hash = kwargs.pop("_commit_hash", None) variant = kwargs.pop("variant", None) - if is_fsdp_enabled(): - low_cpu_mem_usage = True - if use_auth_token is not None: 
warnings.warn( "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning @@ -2262,7 +2266,7 @@ def from_pretrained( # The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info. require_version_core("torch>=1.10") - if is_deepspeed_zero3_enabled(): + if is_deepspeed_zero3_enabled() or is_fsdp_enabled(): raise ValueError( "DeepSpeed Zero-3 is not compatible with `low_cpu_mem_usage=True` or with passing a `device_map`." ) @@ -2714,6 +2718,9 @@ def from_pretrained( init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config())] + init_contexts elif load_in_8bit or load_in_4bit or low_cpu_mem_usage: init_contexts.append(init_empty_weights()) + elif is_fsdp_enabled(): + if not is_fsdp_enabled_and_dist_rank_0(): + init_contexts.append(torch.device("meta")) with ContextManagers(init_contexts): model = cls(config, *model_args, **model_kwargs) From 34421001a39f91607e2d436e224b45b3b32d82ba Mon Sep 17 00:00:00 2001 From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Thu, 27 Jul 2023 03:18:45 +0530 Subject: [PATCH 09/14] fixes --- src/transformers/modeling_utils.py | 40 ++++++++++++++++-------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index aecab420fc9981..c293a38d0fdd95 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -466,11 +466,7 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]): ) return safe_load_file(checkpoint_file) try: - if ( - (is_deepspeed_zero3_enabled() or is_fsdp_enabled()) - and torch.distributed.is_initialized() - and torch.distributed.get_rank() > 0 - ): + if is_deepspeed_zero3_enabled() and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0: map_location = "meta" else: map_location = "cpu" @@ -554,7 +550,7 @@ def load(module: nn.Module, state_dict, prefix=""): with deepspeed.zero.GatheredParameters(params_to_gather, modifier_rank=0): if torch.distributed.get_rank() == 0: module._load_from_state_dict(*args) - elif not is_fsdp_enabled() or is_fsdp_enabled_and_dist_rank_0(): + else: module._load_from_state_dict(*args) for name, child in module._modules.items(): @@ -1155,12 +1151,6 @@ def _from_config(cls, config, **kwargs): # and memory copying it on CPU or each GPU first with deepspeed.zero.Init(config_dict_or_path=deepspeed_config()): model = cls(config, **kwargs) - elif is_fsdp_enabled(): - if is_fsdp_enabled_and_dist_rank_0(): - model = cls(config, **kwargs) - else: - with torch.device("meta"): - model = cls(config, **kwargs) else: model = cls(config, **kwargs) @@ -1500,7 +1490,7 @@ def _get_resized_embeddings( with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=0): if torch.distributed.get_rank() == 0: new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :] - elif not is_fsdp_enabled() or is_fsdp_enabled_and_dist_rank_0(): + else: new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :] return new_embeddings @@ -1584,7 +1574,7 @@ def _get_resized_lm_head( # Copy bias weights to new lm head if has_new_lm_head_bias: new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy] - elif not is_fsdp_enabled() or is_fsdp_enabled_and_dist_rank_0(): + else: # Copy old lm head weights to new lm head if not transposed: new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :] @@ -2212,6 +2202,9 @@ def 
from_pretrained( commit_hash = kwargs.pop("_commit_hash", None) variant = kwargs.pop("variant", None) + if is_fsdp_enabled(): + low_cpu_mem_usage = True + if use_auth_token is not None: warnings.warn( "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning @@ -2266,7 +2259,7 @@ def from_pretrained( # The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info. require_version_core("torch>=1.10") - if is_deepspeed_zero3_enabled() or is_fsdp_enabled(): + if is_deepspeed_zero3_enabled(): raise ValueError( "DeepSpeed Zero-3 is not compatible with `low_cpu_mem_usage=True` or with passing a `device_map`." ) @@ -2718,9 +2711,6 @@ def from_pretrained( init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config())] + init_contexts elif load_in_8bit or load_in_4bit or low_cpu_mem_usage: init_contexts.append(init_empty_weights()) - elif is_fsdp_enabled(): - if not is_fsdp_enabled_and_dist_rank_0(): - init_contexts.append(torch.device("meta")) with ContextManagers(init_contexts): model = cls(config, *model_args, **model_kwargs) @@ -3082,7 +3072,8 @@ def _fix_key(key): unexpected_keys = list(unexpected_keys - model_buffers) model.tie_weights() - if device_map is None: + if device_map is None and not is_fsdp_enabled(): + print("no device map") ptrs = collections.defaultdict(list) for name, tensor in model.state_dict().items(): id_tensor = id_tensor_storage(tensor) @@ -3305,6 +3296,17 @@ def _find_mismatched_keys( keep_in_fp32_modules=keep_in_fp32_modules, ) error_msgs += new_error_msgs + else: + for key, param in model_to_load.state_dict().items(): + if param.device == torch.device("meta"): + if not (is_quantized): + set_module_tensor_to_device( + model, key, "cpu", torch.empty(*param.size(), dtype=dtype) + ) + else: + set_module_quantized_tensor_to_device( + model, key, "cpu", torch.empty(*param.size(), dtype=dtype) + ) else: error_msgs += _load_state_dict_into_model(model_to_load, state_dict, start_prefix) From e75646da2b1eb051c9a5231aedb11519c2c8cc15 Mon Sep 17 00:00:00 2001 From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Thu, 27 Jul 2023 03:45:54 +0530 Subject: [PATCH 10/14] fixes --- src/transformers/modeling_utils.py | 6 ++++- src/transformers/training_args.py | 39 ++++++++++++++---------------- 2 files changed, 23 insertions(+), 22 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index c293a38d0fdd95..8d7a9c19fb1d15 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -466,7 +466,11 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike]): ) return safe_load_file(checkpoint_file) try: - if is_deepspeed_zero3_enabled() and torch.distributed.is_initialized() and torch.distributed.get_rank() > 0: + if ( + (is_deepspeed_zero3_enabled() or is_fsdp_enabled) + and torch.distributed.is_initialized() + and torch.distributed.get_rank() > 0 + ): map_location = "meta" else: map_location = "cpu" diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index ce5b016d790c5d..5567bd5696574f 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -436,13 +436,13 @@ class TrainingArguments: deepspeed json config file (e.g., `ds_config.json`) or an already loaded json file as `dict`. 
A List of config and its options: - - fsdp_min_num_params (`int`, *optional*, defaults to `0`): + - min_num_params (`int`, *optional*, defaults to `0`): FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `fsdp` field is passed). - - fsdp_transformer_layer_cls_to_wrap (`List[str]`, *optional*): + - transformer_layer_cls_to_wrap (`List[str]`, *optional*): List of transformer layer class names (case-sensitive) to wrap, e.g, `BertLayer`, `GPTJBlock`, `T5Block` .... (useful only when `fsdp` flag is passed). - - fsdp_backward_prefetch (`str`, *optional*) + - backward_prefetch (`str`, *optional*) FSDP's backward prefetch mode. Controls when to prefetch next set of parameters (useful only when `fsdp` field is passed). @@ -454,21 +454,22 @@ class TrainingArguments: - `"backward_post"` : This prefetches the next set of parameters after the current set of parameter’s gradient computation. - - fsdp_forward_prefetch (`bool`, *optional*, defaults to `False`) + - forward_prefetch (`bool`, *optional*, defaults to `False`) FSDP's forward prefetch mode (useful only when `fsdp` field is passed). If `"True"`, then FSDP explicitly prefetches the next upcoming all-gather while executing in the forward pass. - - fsdp_limit_all_gathers (`bool`, *optional*, defaults to `False`) + - limit_all_gathers (`bool`, *optional*, defaults to `False`) FSDP's limit_all_gathers (useful only when `fsdp` field is passed). If `"True"`, FSDP explicitly synchronizes the CPU thread to prevent too many in-flight all-gathers. - - fsdp_use_orig_params (`bool`, *optional*, defaults to `False`) - If `"True"`, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable paramteres. - Useful in cases such as parameter-efficient fine-tuning. - Please refer this [blog](https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019 - - fsdp_sync_module_states (`bool`, *optional*, defaults to `False`) - If `"True"`, each individually wrapped FSDP unit will broadcast module parameters from rank 0 - to ensure they are the same across all ranks after initialization + - use_orig_params (`bool`, *optional*, defaults to `False`) + If `"True"`, allows non-uniform `requires_grad` during init, which means support for interspersed + frozen and trainable paramteres. Useful in cases such as parameter-efficient fine-tuning. Please + refer this + [blog](https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019 + - sync_module_states (`bool`, *optional*, defaults to `False`) + If `"True"`, each individually wrapped FSDP unit will broadcast module parameters from rank 0 to + ensure they are the same across all ranks after initialization - xla (`bool`, *optional*, defaults to `False`): Whether to use PyTorch/XLA Fully Sharded Data Parallel Training. This is an experimental feature and its API may evolve in the future. @@ -1524,15 +1525,13 @@ def __post_init__(self): self.fsdp_config = json.load(f) for k, v in self.fsdp_config.items(): if k.startswith("fsdp_"): - self.fsdp_config[k.replace("fsdp", "")] = v + self.fsdp_config[k.replace("fsdp_", "")] = v del self.fsdp_config[k] if self.fsdp_min_num_params > 0: warnings.warn("using `--fsdp_min_num_params` is deprecated. 
Use fsdp_config instead ", FutureWarning) - self.fsdp_config["fsdp_min_num_params"] = max( - self.fsdp_config.get("min_num_params", 0), self.fsdp_min_num_params - ) + self.fsdp_config["min_num_params"] = max(self.fsdp_config.get("min_num_params", 0), self.fsdp_min_num_params) # if fsdp_config["transformer_layer_cls_to_wrap"] is specified as a string, convert it to a list with a single object if isinstance(self.fsdp_config.get("transformer_layer_cls_to_wrap", None), str): @@ -1547,19 +1546,17 @@ def __post_init__(self): ) + [self.fsdp_transformer_layer_cls_to_wrap] if len(self.fsdp) == 0 and self.fsdp_config["min_num_params"] > 0: - warnings.warn("`--fsdp_min_num_params` is useful only when `--fsdp` is specified.") + warnings.warn("`min_num_params` is useful only when `--fsdp` is specified.") if len(self.fsdp) == 0 and self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None: - warnings.warn("`--fsdp_transformer_layer_cls_to_wrap` is useful only when `--fsdp` is specified.") + warnings.warn("`transformer_layer_cls_to_wrap` is useful only when `--fsdp` is specified.") if ( len(self.fsdp) > 0 and self.fsdp_config["min_num_params"] > 0 and self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None ): - raise ValueError( - "`--fsdp_min_num_params` and `--fsdp_transformer_layer_cls_to_wrap` are mutually exclusive." - ) + raise ValueError("`min_num_params` and `transformer_layer_cls_to_wrap` are mutually exclusive.") self.fsdp_config["xla"] = self.fsdp_config.get("xla", False) self.fsdp_config["xla_fsdp_grad_ckpt"] = self.fsdp_config.get("xla_fsdp_grad_ckpt", False) if self.fsdp_config["xla"]: From 79351ae19fa8498f76bd4736ec67dab570e8a085 Mon Sep 17 00:00:00 2001 From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Thu, 27 Jul 2023 04:29:37 +0530 Subject: [PATCH 11/14] update --- src/transformers/training_args.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 5567bd5696574f..abea5cb33e37c1 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -467,7 +467,7 @@ class TrainingArguments: frozen and trainable paramteres. Useful in cases such as parameter-efficient fine-tuning. 
Please refer this [blog](https://dev-discuss.pytorch.org/t/rethinking-pytorch-fully-sharded-data-parallel-fsdp-from-first-principles/1019 - - sync_module_states (`bool`, *optional*, defaults to `False`) + - sync_module_states (`bool`, *optional*, defaults to `True`) If `"True"`, each individually wrapped FSDP unit will broadcast module parameters from rank 0 to ensure they are the same across all ranks after initialization - xla (`bool`, *optional*, defaults to `False`): @@ -1603,7 +1603,7 @@ def __post_init__(self): prefetch_policy = self.fsdp_config.get("fsdp_backward_prefetch", "NO_PREFETCH") os.environ[f"{prefix}BACKWARD_PREFETCH"] = prefetch_policy.upper() os.environ[f"{prefix}FORWARD_PREFETCH"] = self.fsdp_config.get("forward_prefect", "false") - os.environ[f"{prefix}SYNC_MODULE_STATES"] = self.fsdp_config.get("sync_module_states", "false") + os.environ[f"{prefix}SYNC_MODULE_STATES"] = self.fsdp_config.get("sync_module_states", "true") os.environ[f"{prefix}USE_ORIG_PARAMS"] = self.fsdp_config.get("use_orig_params", "false") if self.tpu_metrics_debug: From 2a5acd184b7dd8e73b6c4dd1a935cc2dca4f07c0 Mon Sep 17 00:00:00 2001 From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Thu, 27 Jul 2023 04:44:57 +0530 Subject: [PATCH 12/14] fixes --- src/transformers/training_args.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index abea5cb33e37c1..72ccf969638011 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -1593,10 +1593,10 @@ def __post_init__(self): os.environ[f"{prefix}OFFLOAD_PARAMS"] = "true" elif fsdp_option == FSDPOption.AUTO_WRAP: os.environ[f"{prefix}AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[0] - if self.fsdp_config["fsdp_min_num_params"] > 0: + if self.fsdp_config["min_num_params"] > 0: os.environ[f"{prefix}MIN_NUM_PARAMS"] = str(self.fsdp_config["min_num_params"]) os.environ[f"{prefix}AUTO_WRAP_POLICY"] = FSDP_AUTO_WRAP_POLICY[1] - elif self.fsdp_config.get("fsdp_transformer_layer_cls_to_wrap", None) is not None: + elif self.fsdp_config.get("transformer_layer_cls_to_wrap", None) is not None: os.environ[f"{prefix}TRANSFORMER_CLS_TO_WRAP"] = ",".join( self.fsdp_config["transformer_layer_cls_to_wrap"] ) From 449afd4df2998644ec9b9f2d0cc669525a48d1c2 Mon Sep 17 00:00:00 2001 From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Thu, 27 Jul 2023 11:56:23 +0530 Subject: [PATCH 13/14] revert `load_pretrained_model_only_on_rank0` --- docs/source/en/internal/trainer_utils.md | 2 -- src/transformers/__init__.py | 7 ++----- src/transformers/trainer_pt_utils.py | 13 ------------- src/transformers/utils/dummy_pt_objects.py | 4 ---- 4 files changed, 2 insertions(+), 24 deletions(-) diff --git a/docs/source/en/internal/trainer_utils.md b/docs/source/en/internal/trainer_utils.md index 8035a6f8bf703e..e3f8a9b04536fa 100644 --- a/docs/source/en/internal/trainer_utils.md +++ b/docs/source/en/internal/trainer_utils.md @@ -32,8 +32,6 @@ Most of those are only useful if you are studying the code of the Trainer in the [[autodoc]] torch_distributed_zero_first -[[autodoc]] load_pretrained_model_only_on_rank0 - ## Callbacks internals [[autodoc]] trainer_callback.CallbackHandler diff --git a/src/transformers/__init__.py b/src/transformers/__init__.py index d6c60531526836..5b4a75215aa72f 100644 --- a/src/transformers/__init__.py +++ b/src/transformers/__init__.py @@ -3059,10 +3059,7 @@ _import_structure["sagemaker"] = [] 
_import_structure["time_series_utils"] = [] _import_structure["trainer"] = ["Trainer"] - _import_structure["trainer_pt_utils"] = [ - "load_pretrained_model_only_on_rank0", - "torch_distributed_zero_first", - ] + _import_structure["trainer_pt_utils"] = ["torch_distributed_zero_first"] _import_structure["trainer_seq2seq"] = ["Seq2SeqTrainer"] # TensorFlow-backed objects @@ -6601,7 +6598,7 @@ # Trainer from .trainer import Trainer - from .trainer_pt_utils import load_pretrained_model_only_on_rank0, torch_distributed_zero_first + from .trainer_pt_utils import torch_distributed_zero_first from .trainer_seq2seq import Seq2SeqTrainer # TensorFlow diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py index 5ed096a4fe479c..b57770f33b29b3 100644 --- a/src/transformers/trainer_pt_utils.py +++ b/src/transformers/trainer_pt_utils.py @@ -1125,16 +1125,3 @@ def smp_nested_concat(tensor): # It doesn't seem possible to check here if `tensor` is a StepOutput because StepOutput lives in `smp.step` # which is also the name of the decorator so Python is confused. return tensor.concat().detach().cpu() - - -def load_pretrained_model_only_on_rank0(model_cls, config_cls, model_name_or_path): - from accelerate.state import PartialState - - state = PartialState() - if state.is_main_process: - model = model_cls.from_pretrained(model_name_or_path, return_dict=True) - else: - with torch.device("meta"): - config = config_cls.from_pretrained(model_name_or_path) - model = model_cls.from_config(config) - return model diff --git a/src/transformers/utils/dummy_pt_objects.py b/src/transformers/utils/dummy_pt_objects.py index 5521bd74b95002..65fedf02d8c9cf 100644 --- a/src/transformers/utils/dummy_pt_objects.py +++ b/src/transformers/utils/dummy_pt_objects.py @@ -8493,10 +8493,6 @@ def __init__(self, *args, **kwargs): requires_backends(self, ["torch"]) -def load_pretrained_model_only_on_rank0(*args, **kwargs): - requires_backends(load_pretrained_model_only_on_rank0, ["torch"]) - - def torch_distributed_zero_first(*args, **kwargs): requires_backends(torch_distributed_zero_first, ["torch"]) From 3adfddd30356c87308bd60d17dc05cd2a3abb4f0 Mon Sep 17 00:00:00 2001 From: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com> Date: Thu, 27 Jul 2023 11:59:23 +0530 Subject: [PATCH 14/14] resolve `load_from_checkpoint` --- src/transformers/trainer.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 43b8f03148335a..762906228133da 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -1513,7 +1513,12 @@ def train( if resume_from_checkpoint is None: raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})") - if resume_from_checkpoint is not None and not is_sagemaker_mp_enabled() and not self.is_deepspeed_enabled: + if ( + resume_from_checkpoint is not None + and not is_sagemaker_mp_enabled() + and not self.is_deepspeed_enabled + and not self.is_fsdp_enabled + ): self._load_from_checkpoint(resume_from_checkpoint) # If model was re-initialized, put it on the right device and update self.model_wrapped @@ -1625,7 +1630,7 @@ def _inner_training_loop( model = self._wrap_model(self.model_wrapped) - if is_sagemaker_mp_enabled() and resume_from_checkpoint is not None: + if (is_sagemaker_mp_enabled() or self.is_fsdp_enabled) and resume_from_checkpoint is not None: self._load_from_checkpoint(resume_from_checkpoint, model) # as the model is wrapped, don't use 
`accelerator.prepare`
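
Taken together, the series makes `from_pretrained` RAM-efficient under FSDP: the full checkpoint is loaded only on rank 0, the other ranks keep empty placeholder weights, and FSDP's `sync_module_states=True` broadcast (made the default in PATCH 11/14) copies the real values to every rank after wrapping. The sketch below restates that rank-0-only pattern as a standalone helper so the control flow is easier to follow. It is a minimal illustration, not code from the patches: it assumes `torch.distributed` has already been initialized by the launcher, and the `Auto*` classes and checkpoint name are placeholders.

import torch
import torch.distributed as dist

from transformers import AutoConfig, AutoModelForCausalLM


def load_checkpoint_only_on_rank0(model_name_or_path: str):
    # Same idea as the (later reverted) `load_model_from_pretrained_only_on_rank0`
    # helper from PATCH 01/14: only rank 0 pays the CPU-RAM cost of the real weights.
    if dist.get_rank() == 0:
        model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
    else:
        # The other ranks build the same architecture on the meta device, so no
        # parameter memory is allocated; after FSDP wrapping, a `param_init_fn`
        # (or the rank-0 `sync_module_states` broadcast) materializes the weights.
        config = AutoConfig.from_pretrained(model_name_or_path)
        with torch.device("meta"):
            model = AutoModelForCausalLM.from_config(config)
    return model

With the final form of the series no such helper is needed: launching through an FSDP-enabled Accelerate config exports `ACCELERATE_USE_FSDP=true`, which is all `is_fsdp_enabled()` checks, so `from_pretrained` takes the rank-0-only path automatically, and the `Trainer` picks up the de-prefixed `fsdp_config` keys (`min_num_params`, `transformer_layer_cls_to_wrap`, `sync_module_states`, ...) introduced in PATCH 07/14.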