04-fully-sharded-data-parallel/README.md (+7 −20)
@@ -81,7 +81,6 @@ Here is our [FSDP constructor](https://pytorch.org/docs/stable/fsdp.html#torch.d
 model = FullyShardedDataParallel(
     model,
     device_id=local_rank,
-    param_init_fn=safe_param_init_fn,
     sync_module_states=True,
     auto_wrap_policy=wrap_policy,
     sharding_strategy=ShardingStrategy.FULL_SHARD,
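
The hunk above shows only the constructor call itself. Below is a minimal sketch of how it might be assembled end to end, assuming a Hugging Face Llama model bound to `model` earlier in the script and a per-decoder-layer wrap policy; these assumptions are illustrative and not part of this diff.

```python
# A sketch under stated assumptions: `model` is a HF Llama model built earlier
# (e.g. with real weights on rank 0 and on the meta device elsewhere), and
# LOCAL_RANK comes from the torchrun environment.
import functools
import os

from torch.distributed.fsdp import FullyShardedDataParallel, ShardingStrategy
from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

local_rank = int(os.environ["LOCAL_RANK"])

# Assumed policy: wrap each decoder layer into its own FSDP unit so its
# parameters are gathered and freed independently during forward/backward.
wrap_policy = functools.partial(
    transformer_auto_wrap_policy,
    transformer_layer_cls={LlamaDecoderLayer},
)

model = FullyShardedDataParallel(
    model,                                         # assumed to exist already
    device_id=local_rank,
    sync_module_states=True,                       # broadcast rank 0's module states to all ranks
    auto_wrap_policy=wrap_policy,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
)
```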
@@ -93,28 +92,16 @@ model = FullyShardedDataParallel(
 
 ##### reset_parameters()
 
-In most cases, if you just want to apply `reset_parameters()` - you actually don't have to specify this parameter. However some models (e.g. Llama 2/3.1) have modules that do not implement `reset_parameters()`. In this chapter we show how to implement a simple version of param_init_fn that is identical to the default FSDP, but just checks for the existence of reset_parameters.
+In most cases, if you just want to apply `reset_parameters()`, you don't actually have to specify `param_init_fn` at all. However, some models (e.g. Llama 2/3.1) have modules that do not implement `reset_parameters()`.
 
-From pytorch documentation:
-
-> As of v1.12, FSDP detects modules with parameters or buffers on meta device via is_meta and either applies `param_init_fn` if specified or calls nn.Module.reset_parameters() otherwise.
-
-You can see how the default behavior is specified in the pytorch source code: [torch/distributed/fsdp/_init_utils.py#L889-L890](https://github.com/pytorch/pytorch/blob/v2.4.0/torch/distributed/fsdp/_init_utils.py#L889-L890)
+It is suggested that you implement `reset_parameters()` for these modules manually. Here is what we do for our llama models:
 
 ```python
-def safe_param_init_fn(module: torch.nn.Module):
-    """
-    For use in FSDP constructor. This is identical to default behavior of FSDP when dealing with meta device,
-    except pytorch code doesn't check for existence of `reset_parameters()` before calling it. Some modules
-    don't have this implemented, so this is our "fix" for it.
-    """
-    # NOTE: according to FSDP.__init__.param_init_fn documentation, we should set recurse=False
-    module.to_empty(device=device, recurse=False)
-    # NOTE: Since we are training from scratch here, we just reset the parameters,
-    # otherwise we may want to load in weights directly here, or load
-    # parameters on rank 0 and use sync_module_states=True in FSDP constructor.
-    if hasattr(module, "reset_parameters"):
-        module.reset_parameters()
+from transformers.models.llama.modeling_llama import LlamaRMSNorm, LlamaRotaryEmbedding
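
The rest of the added code block is truncated in this hunk. As a purely illustrative sketch of the idea the new text describes, one way to give these two modules a `reset_parameters()` implementation that FSDP's default meta-device materialization can call is shown below; the helper names and initialization details are assumptions, not the chapter's actual code.

```python
import torch
from transformers.models.llama.modeling_llama import LlamaRMSNorm, LlamaRotaryEmbedding

# Illustrative only: attach reset_parameters() to the llama modules that lack it,
# so FSDP's default handling of meta-device modules can call it after to_empty().

def _rms_norm_reset_parameters(self: LlamaRMSNorm) -> None:
    # LlamaRMSNorm has a single learnable weight, conventionally initialized to ones.
    torch.nn.init.ones_(self.weight)

def _rotary_embedding_reset_parameters(self: LlamaRotaryEmbedding) -> None:
    # The rotary embedding has no learnable parameters, only buffers derived from
    # fixed frequencies; depending on the transformers version you may need to
    # recompute its inv_freq buffer here after the module is materialized.
    pass

LlamaRMSNorm.reset_parameters = _rms_norm_reset_parameters
LlamaRotaryEmbedding.reset_parameters = _rotary_embedding_reset_parameters
```

With `reset_parameters()` available on every module, FSDP's default meta-device handling can call it directly, which is why the first hunk drops `param_init_fn` from the constructor.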