Commit 87faff1

Fixes for reset_parameters()
1 parent 0feb245 commit 87faff1

3 files changed, +17 -35 lines

04-fully-sharded-data-parallel/README.md

+7 -20 lines changed

@@ -81,7 +81,6 @@ Here is our [FSDP constructor](https://pytorch.org/docs/stable/fsdp.html#torch.d
 model = FullyShardedDataParallel(
     model,
     device_id=local_rank,
-    param_init_fn=safe_param_init_fn,
     sync_module_states=True,
     auto_wrap_policy=wrap_policy,
     sharding_strategy=ShardingStrategy.FULL_SHARD,
@@ -93,28 +92,16 @@ model = FullyShardedDataParallel(
 
 ##### reset_parameters()
 
-In most cases, if you just want to apply `reset_parameters()` - you actually don't have to specify this parameter. However some models (e.g. Llama 2/3.1) have modules that do not implement `reset_parameters()`. In this chapter we show how to implement a simple version of param_init_fn that is identical to the default FSDP, but just checks for the existence of reset_parameters.
+In most cases, if you just want to apply `reset_parameters()` - you actually don't have to specify this parameter. However some models (e.g. Llama 2/3.1) have modules that do not implement `reset_parameters()`.
 
-From pytorch documentation:
-
-> As of v1.12, FSDP detects modules with parameters or buffers on meta device via is_meta and either applies `param_init_fn` if specified or calls nn.Module.reset_parameters() otherwise.
-
-You can see how the default behavior is specified in the pytorch source code [torch/distributed/fsdp/_init_utils.py#L889-L890](https://github.com/pytorch/pytorch/blob/v2.4.0/torch/distributed/fsdp/_init_utils.py#L889-L890)
+It is suggested that you implement them manually. Here is what we do for our llama models:
 
 ```python
-def safe_param_init_fn(module: torch.nn.Module):
-    """
-    For use in FSDP constructor. This is identical to default behavior of FSDP when dealing with meta device,
-    except pytorch code doesn't check for existence of `reset_parameters()` before calling it. Some modules
-    don't have this implemented, so this is our "fix" for it.
-    """
-    # NOTE: according to FSDP.__init__.param_init_fn documnetaiton, we should set recurse=False
-    module.to_empty(device=device, recurse=False)
-    # NOTE: Since we are training from scratch here, we just reset the parameters,
-    # otherwise we may want to load in weights directly here, or load
-    # parameters on rank 0 and use sync_module_states=True in FSDP constructor.
-    if hasattr(module, "reset_parameters"):
-        module.reset_parameters()
+from transformers.models.llama.modeling_llama import LlamaRMSNorm, LlamaRotaryEmbedding
+
+# fixes for reset_parameters not existing
+LlamaRMSNorm.reset_parameters = lambda self: torch.nn.init.ones_(self.weight)
+LlamaRotaryEmbedding.reset_parameters = lambda _: None
 ```
 
 ##### Loading a checkpoint
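
The patch above works because of FSDP's default handling of meta-device modules: when no `param_init_fn` is given, FSDP materializes each meta module with `to_empty()` and then calls its `reset_parameters()`. A minimal sketch of that path (not part of this commit, assuming a recent torch and transformers install; the hidden size is an arbitrary illustration):

```python
import torch
from transformers.models.llama.modeling_llama import LlamaRMSNorm

# Apply the same patch as this commit.
LlamaRMSNorm.reset_parameters = lambda self: torch.nn.init.ones_(self.weight)

# Build the module on the meta device, as the training scripts do for the full model.
with torch.device("meta"):
    norm = LlamaRMSNorm(4096)  # 4096 is an arbitrary hidden size for illustration

# Emulate FSDP's default materialization: allocate real storage, then re-initialize.
norm.to_empty(device="cpu", recurse=False)
norm.reset_parameters()  # exists now, so the default FSDP path no longer fails
assert bool((norm.weight == 1.0).all())
```

Without the patch, the final call would raise because `nn.Module` does not define `reset_parameters`, which is the gap the removed `safe_param_init_fn` papered over.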

04-fully-sharded-data-parallel/train_llm.py

+5 -15 lines changed

@@ -37,6 +37,11 @@
     AutoTokenizer,
     default_data_collator,
 )
+from transformers.models.llama.modeling_llama import LlamaRMSNorm, LlamaRotaryEmbedding
+
+# fixes for reset_parameters not existing
+LlamaRMSNorm.reset_parameters = lambda self: torch.nn.init.ones_(self.weight)
+LlamaRotaryEmbedding.reset_parameters = lambda _: None
 
 LOGGER = logging.getLogger(__name__)
 
@@ -76,27 +81,12 @@ def main():
 
     LOGGER.info(f"Before FSDP: {get_mem_stats(device)}")
 
-    def safe_param_init_fn(module: torch.nn.Module):
-        """
-        For use in FSDP constructor. This is identical to default behavior of FSDP when dealing with meta device,
-        except pytorch code doesn't check for existence of `reset_parameters()` before calling it. Some modules
-        don't have this implemented, so this is our "fix" for it.
-        """
-        # NOTE: according to FSDP.__init__.param_init_fn documnetaiton, we should set recurse=False
-        module.to_empty(device=device, recurse=False)
-        # NOTE: Since we are training from scratch here, we just reset the parameters,
-        # otherwise we may want to load in weights directly here, or load
-        # parameters on rank 0 and use sync_module_states=True in FSDP constructor.
-        if hasattr(module, "reset_parameters"):
-            module.reset_parameters()
-
     wrap_policy = functools.partial(
         size_based_auto_wrap_policy, min_num_params=int(args.numel_to_wrap)
     )
     model = FullyShardedDataParallel(
         model,
         device_id=local_rank,
-        param_init_fn=safe_param_init_fn,
         sync_module_states=True,
         # NOTE: FULL_SHARD is equivalent to deepspeed ZeRO stage 3
         auto_wrap_policy=wrap_policy,
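
Both training scripts now rely on FSDP's default meta-device path, so a quick way to gain confidence in the patch is to build the model on the meta device and list any submodules that own parameters or buffers but still lack `reset_parameters()`. A hypothetical check along those lines (not part of the commit; the tiny `LlamaConfig` values are made up so it runs quickly):

```python
import torch
from transformers import LlamaConfig, LlamaForCausalLM
from transformers.models.llama.modeling_llama import LlamaRMSNorm, LlamaRotaryEmbedding

# The two patches from this commit.
LlamaRMSNorm.reset_parameters = lambda self: torch.nn.init.ones_(self.weight)
LlamaRotaryEmbedding.reset_parameters = lambda _: None

# Tiny, made-up config so the check runs quickly; real runs use the full model config.
config = LlamaConfig(hidden_size=64, intermediate_size=128,
                     num_hidden_layers=2, num_attention_heads=4,
                     num_key_value_heads=4, vocab_size=1000)
with torch.device("meta"):
    model = LlamaForCausalLM(config)

# FSDP calls reset_parameters() on every meta module that owns parameters or
# buffers when no param_init_fn is given, so none of these should be missing it.
missing = [
    name
    for name, module in model.named_modules()
    if (list(module.parameters(recurse=False)) or list(module.buffers(recurse=False)))
    and not hasattr(module, "reset_parameters")
]
print("modules missing reset_parameters:", missing)  # expect an empty list
```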

05-training-llama-405b/train_llm.py

+5 lines changed

@@ -41,6 +41,11 @@
     AutoTokenizer,
     default_data_collator,
 )
+from transformers.models.llama.modeling_llama import LlamaRMSNorm, LlamaRotaryEmbedding
+
+# fixes for reset_parameters not existing
+LlamaRMSNorm.reset_parameters = lambda self: torch.nn.init.ones_(self.weight)
+LlamaRotaryEmbedding.reset_parameters = lambda _: None
 
 LOGGER = logging.getLogger(__name__)
