# hf_fsdp.py
# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

# helper functions from https://github.com/CarperAI/trlx/blob/main/trlx/utils/modeling.py
# which is MIT licensed

import functools
from typing import Any, Iterable, List, Optional

import torch
from transformers import PreTrainedModel
from transformers.models.opt.modeling_opt import OPTDecoder


# helper functions
def rhasattr(obj: Any, attr: str) -> bool:
    """A chain-able attribute version of hasattr.

    For example, to check if `obj` has the attribute `foo.bar.baz`, you can use:
        `rhasattr(obj, "foo.bar.baz")`
    Reference: https://stackoverflow.com/a/67303315
    """
    _nested_attrs = attr.split('.')
    _curr_obj = obj
    for _a in _nested_attrs[:-1]:
        if hasattr(_curr_obj, _a):
            _curr_obj = getattr(_curr_obj, _a)
        else:
            return False
    return hasattr(_curr_obj, _nested_attrs[-1])


def rgetattr(obj: Any, attr: str, *args: List[Any]) -> Any:
    """A chain-able attribute version of getattr.

    For example, to get the attribute `foo.bar.baz` from `obj`, you can use:
        `rgetattr(obj, "foo.bar.baz")`
    Reference: https://stackoverflow.com/a/31174427
    """

    def _getattr(obj: Any, attr: str):
        return getattr(obj, attr, *args)

    return functools.reduce(_getattr, [obj] + attr.split('.'))


def findattr(obj: Any, attrs: Iterable[str]) -> Optional[Any]:
    for attr in attrs:
        if rhasattr(obj, attr):
            return rgetattr(obj, attr)
    return None
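

# Illustrative sketch (not part of the original module): shows how the
# dotted-attribute helpers above traverse nested objects. The function name
# `_example_nested_attr_helpers` and the SimpleNamespace fixture are
# assumptions made purely for demonstration; nothing in the library calls this.
def _example_nested_attr_helpers() -> None:
    """Minimal usage sketch for rhasattr / rgetattr / findattr."""
    from types import SimpleNamespace

    obj = SimpleNamespace(foo=SimpleNamespace(bar=SimpleNamespace(baz=1)))
    assert rhasattr(obj, 'foo.bar.baz')  # True: the full chain exists
    assert not rhasattr(obj, 'foo.missing.baz')  # False: chain breaks early
    assert rgetattr(obj, 'foo.bar.baz') == 1  # walks the chain with getattr
    # findattr returns the first attribute chain that resolves, or None
    assert findattr(obj, ('nope', 'foo.bar')) is obj.foo.bar
    assert findattr(obj, ('nope', 'also.nope')) is None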
def hf_get_causal_base_model(model: PreTrainedModel) -> Any:
    """Returns the causal decoder backbone of the specified HuggingFace model.

    Newer HF models have a `self.get_decoder()` method. Older models do not.

    NOTE: Different model configurations have different causal decoder attribute
    names.
        - transformer: (GPT2LMHeadModel, GPTJConfig)
        - model.decoder: (OPTConfig, BloomConfig)
        - gpt_neox: (GPTNeoXConfig)
    """
    if hasattr(model, 'get_decoder'):
        return model.get_decoder()

    decoder_attrs = ('transformer', 'model.decoder', 'gpt_neox')
    causal_base_model = findattr(model, decoder_attrs)
    if causal_base_model is None:
        raise ValueError(
            f'Unable to FSDP-wrap model {model}. Please open a github issue to add support.'
        )
    return causal_base_model
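

# Illustrative sketch (not part of the original module): assuming the small
# public 'gpt2' checkpoint, the backbone returned here is the GPT-2 transformer
# stack stored at `model.transformer`, whether it is found via `get_decoder()`
# or via attribute lookup. Downloads weights when run.
def _example_get_causal_base_model() -> None:
    """Minimal usage sketch for hf_get_causal_base_model."""
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained('gpt2')
    backbone = hf_get_causal_base_model(model)
    assert backbone is model.transformer  # GPT-2 keeps its decoder stack here
    assert isinstance(backbone, torch.nn.Module)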
def hf_get_hidden_layers(model: PreTrainedModel) -> Any:
    """Returns the hidden layers of the specified model.

    NOTE: Different model configurations have different hidden layer attribute names.
        - transformer.h: (BloomForCausalLM, GPT2LMHeadModel, GPTJForCausalLM)
        - model.decoder.layers: (OPTForCausalLM)
        - gpt_neox.layers: (GPTNeoXForCausalLM)
        - model.layers: (LlamaForCausalLM)
        - transformer.blocks: (MPTForCausalLM)
    """
    hidden_layers_attrs = (
        'transformer.h',  # BLOOM, GPT2, GPTJ
        'model.decoder.layers',  # OPT
        'gpt_neox.layers',  # GPTNeoX
        'block',  # T5, BART, Pegasus (from encoder)
        'layers',  # ProphetNet, Marian (from encoder)
        'model.layers',  # LLaMA
        'transformer.blocks',  # MPT
    )
    layers = findattr(model, hidden_layers_attrs)
    if layers is None:
        raise ValueError(
            f'Unable to find hidden layer for {model}. Model must have one of the following attributes: {hidden_layers_attrs}'
        )
    return layers
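

# Illustrative sketch (not part of the original module): assuming the public
# 'facebook/opt-125m' checkpoint, the hidden layers resolve through the
# 'model.decoder.layers' attribute chain to a ModuleList of OPT decoder layers.
# Downloads weights when run.
def _example_get_hidden_layers() -> None:
    """Minimal usage sketch for hf_get_hidden_layers."""
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained('facebook/opt-125m')
    layers = hf_get_hidden_layers(model)
    assert layers is model.model.decoder.layers  # matched 'model.decoder.layers'
    assert isinstance(layers, torch.nn.ModuleList)
    assert len(layers) == model.config.num_hidden_layers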
def hf_get_init_device(init_device: Optional[str]) -> Optional[str]:
    """Returns the appropriate device to initialize models."""
    from composer.utils import dist
    if init_device == 'mixed':
        if dist.get_local_rank() == 0:
            return 'cpu'
        return 'meta'
    return init_device
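

# Illustrative sketch (not part of the original module): with 'mixed'
# initialization, real weights are materialized only on local rank 0 and the
# 'meta' device is used on every other rank, so each node holds one CPU copy
# for FSDP to shard from; any other value passes through unchanged. Requires
# composer to be installed, and the 'mixed' result depends on the caller's
# local rank.
def _example_get_init_device() -> None:
    """Minimal usage sketch for hf_get_init_device."""
    assert hf_get_init_device('cpu') == 'cpu'  # non-'mixed' values pass through
    assert hf_get_init_device(None) is None
    # 'cpu' on local rank 0, 'meta' on all other ranks
    assert hf_get_init_device('mixed') in ('cpu', 'meta')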
# /end helper functions


def prepare_hf_model_for_fsdp(model: PreTrainedModel,
                              init_device: Optional[str]) -> None:
    """FSDP wrap a HuggingFace model.

    Dispatches to the encoder/decoder or causal LM preparation function based
    on the model architecture.
    """
    if model.config.is_encoder_decoder:
        prepare_hf_enc_dec_model_for_fsdp(model, init_device)
    else:
        # many common decoder-only models do not set the flag
        # model.config.is_decoder, so we can't trust it
        prepare_hf_causal_lm_model_for_fsdp(model, init_device)
def prepare_hf_causal_lm_model_for_fsdp(model: PreTrainedModel,
                                        init_device: Optional[str]) -> None:
    """FSDP wrap a HuggingFace decoder.

    Wrap any model for FSDP which follows one of the 3 existing conventions from
    HuggingFace for decoder-only LLMs.
    """
    causal_base_model = hf_get_causal_base_model(model)

    # OPT has an extra layer of wrapping, so special case here
    if isinstance(causal_base_model, OPTDecoder):
        model.model._fsdp_wrap = False
    model_block = hf_get_hidden_layers(model)
    lm_head = model.get_output_embeddings()
    # some models (OPT) implement .get_input_embeddings for the causal subclass
    # but all of them implement it for the base model
    tied_embeddings = causal_base_model.get_input_embeddings()
    modules = {
        'base_model': causal_base_model,
        'model_block': model_block,
        'lm_head': lm_head,
        'tied_embeddings': tied_embeddings
    }

    for mod_name, module in modules.items():
        if module is None:
            raise ValueError(
                f'Unable to FSDP-wrap this model! `{mod_name}` does not ' +
                'follow common layer/weight naming conventions.')

    block_type = type(model_block[0])
    if init_device == 'mixed':
        # For FSDP with `mixed` initialization, which initializes the model on
        # rank 0 on `cpu` and on all other ranks on `meta`, we need to tag all
        # child modules that are torch.nn.Modules with `_fsdp_wrap`.
        for child in model.children():
            if isinstance(child, type(causal_base_model)):
                continue
            if isinstance(child, torch.nn.Module):
                child._fsdp_wrap = True

        for child in causal_base_model.children():
            if isinstance(child, torch.nn.ModuleList):
                continue
            if isinstance(child, torch.nn.Module):
                child._fsdp_wrap = True

        if model.config.tie_word_embeddings and model.config.model_type != 'mpt':
            raise ValueError(
                'The passed in HuggingFaceModel has tied word embeddings '
                'and the passed in initialization device is `mixed`. '
                'In order to support this initialization scheme, we would need to break '
                'the weight tying. As a result, either use a different initialization scheme '
                'or in the model config set `tie_word_embeddings=False`.')
    else:
        # When using the HF LM models, the weights of self.lm_head and
        # self.transformer.wte are tied. This tying occurs inside the
        # `self.post_init()` function and is a hurdle for FSDP because the tied
        # modules need to live in the same FSDP block. These lines ensure that
        # both modules stay together in the top-most block when the model has
        # tying enabled (almost all do; this property defaults to True).
        if model.config.tie_word_embeddings:
            causal_base_model._fsdp_wrap = False
            tied_embeddings._fsdp_wrap = False
            lm_head._fsdp_wrap = False

    # FSDP Wrap and Activation Checkpoint every model block
    model.fsdp_wrap_fn = lambda module: isinstance(module, block_type)
    model.activation_checkpointing_fn = lambda module: isinstance(
        module, block_type)
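

# Illustrative sketch (not part of the original module): after preparation, the
# model carries `fsdp_wrap_fn` / `activation_checkpointing_fn` hooks that an
# FSDP-aware trainer (e.g. Composer) can query per submodule. Assumes the
# public 'gpt2' checkpoint; downloads weights when run.
def _example_prepare_causal_lm() -> None:
    """Minimal usage sketch for prepare_hf_causal_lm_model_for_fsdp."""
    from transformers import AutoModelForCausalLM

    model = AutoModelForCausalLM.from_pretrained('gpt2')
    prepare_hf_causal_lm_model_for_fsdp(model, init_device='cpu')
    # Each GPT-2 block is selected for wrapping and activation checkpointing...
    assert model.fsdp_wrap_fn(model.transformer.h[0])
    assert model.activation_checkpointing_fn(model.transformer.h[0])
    # ...while the tied embedding / lm_head modules are pinned to the root block.
    assert model.transformer._fsdp_wrap is False
    assert model.lm_head._fsdp_wrap is False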
def prepare_hf_enc_dec_model_for_fsdp(model: PreTrainedModel,
                                      init_device: Optional[str]) -> None:
    """Wrap an encoder/decoder HF model.

    This works for T5, BART, Pegasus, and PegasusX, but not all enc/dec models
    (e.g. ProphetNet). Such models have model.shared, model.encoder,
    model.decoder, and model.lm_head, where model.shared are the embeddings
    which are tied to model.lm_head, and model.shared ==
    model.encoder.embed_tokens and model.shared == model.decoder.embed_tokens.
    """
    tied_embeddings = model.get_input_embeddings()
    encoder = model.get_encoder()
    decoder = model.get_decoder()
    lm_head = model.get_output_embeddings()
    # some encoder/decoders have different layers for encoder vs decoder
    encoder_block = hf_get_hidden_layers(encoder)
    decoder_block = hf_get_hidden_layers(decoder)

    modules = {
        'encoder': encoder,
        'decoder': decoder,
        'encoder_block': encoder_block,
        'decoder_block': decoder_block,
        'lm_head': lm_head,
        'tied_embeddings': tied_embeddings
    }

    for mod_name, module in modules.items():
        if module is None:
            raise ValueError(
                f'Unable to FSDP-wrap this model! `{mod_name}` does not ' +
                'follow common layer/weight naming conventions.')

    decoder_block_type = type(decoder_block[0])
    encoder_block_type = type(encoder_block[0])

    if model.config.tie_word_embeddings:
        # it is possible to train an enc/dec without tied embeddings, hence the check
        tied_embeddings._fsdp_wrap = False
        encoder._fsdp_wrap = False
        decoder._fsdp_wrap = False
        lm_head._fsdp_wrap = False

    # FSDP Wrap and Activation Checkpoint every decoder block
    model.fsdp_wrap_fn = lambda module: isinstance(module, decoder_block_type)
    model.activation_checkpointing_fn = lambda module: isinstance(
        module, decoder_block_type)

    if encoder_block_type == decoder_block_type:
        return

    # need to wrap encoder blocks separately for ProphetNet and Marian, without
    # dropping the decoder blocks selected above
    model.fsdp_wrap_fn = lambda module: isinstance(
        module, (encoder_block_type, decoder_block_type))
    model.activation_checkpointing_fn = lambda module: isinstance(
        module, (encoder_block_type, decoder_block_type))
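

# Illustrative sketch (not part of the original module): assuming the public
# 't5-small' checkpoint, the encoder and decoder share the same block type, so
# a single wrap predicate covers both stacks. Downloads weights when run.
def _example_prepare_enc_dec() -> None:
    """Minimal usage sketch for prepare_hf_enc_dec_model_for_fsdp."""
    from transformers import AutoModelForSeq2SeqLM

    model = AutoModelForSeq2SeqLM.from_pretrained('t5-small')
    prepare_hf_enc_dec_model_for_fsdp(model, init_device='cpu')
    # Both encoder and decoder blocks are selected for FSDP wrapping.
    assert model.fsdp_wrap_fn(model.encoder.block[0])
    assert model.fsdp_wrap_fn(model.decoder.block[0])
    # The shared embedding, encoder, decoder, and lm_head stay in the root
    # block because t5-small ties its word embeddings.
    assert model.shared._fsdp_wrap is False
    assert model.lm_head._fsdp_wrap is False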