-
Notifications
You must be signed in to change notification settings - Fork 251
/
Copy pathgemma_backbone.py
316 lines (296 loc) · 13.1 KB
/
gemma_backbone.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
import keras
from keras import ops
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.modeling.reversible_embedding import (
ReversibleEmbedding,
)
from keras_hub.src.models.backbone import Backbone
from keras_hub.src.models.gemma.gemma_decoder_block import GemmaDecoderBlock
from keras_hub.src.models.gemma.rms_normalization import RMSNormalization
@keras_hub_export("keras_hub.models.GemmaBackbone")
class GemmaBackbone(Backbone):
    """Gemma core network with hyperparameters.

    This backbone implements the base Transformer network for the Gemma model.
    It includes the embedding lookups and transformer layers. This backbone
    will output the final hidden states for each token, not generative
    predictions over the vocabulary space. For a higher-level object for text
    generation, see `keras_hub.models.GemmaCausalLM`.

    The default constructor gives a fully customizable, randomly initialized
    Gemma model with any number of layers, heads, and embedding dimensions. To
    load preset architectures and weights, use the `from_preset` constructor.

    Args:
        vocabulary_size: int. The size of the token vocabulary.
        num_layers: int. The number of transformer layers.
        num_query_heads: int. The number of heads for the query projections in
            the attention layer.
        num_key_value_heads: int. The number of heads for the key and value
            projections in the attention layer.
        hidden_dim: int. The size of the transformer hidden state at the end
            of each transformer layer.
        intermediate_dim: int. The output dimension of the first Dense layer in
            a two-layer feedforward network for each transformer.
        head_dim: int. The size of each attention head.
        layer_norm_epsilon: float. The epsilon value used for every layer norm
            in the transformer model. Defaults to `1e-6`.
        dropout: float. Dropout probability for the Transformer decoder
            blocks. Defaults to `0`.
        query_head_dim_normalize: boolean. If `True` normalize the query before
            attention with `head_dim`. If `False`, normalize the query with
            `hidden_dim / num_query_heads`. Defaults to True.
        use_post_ffw_norm: boolean. Whether to normalize after the feedforward
            block. Defaults to False.
        use_post_attention_norm: boolean. Whether to normalize after the
            attention block. Defaults to False.
        attention_logit_soft_cap: None or float. Soft cap for the attention
            logits. Defaults to None.
        final_logit_soft_cap: None or float. Soft cap for the final logits.
            Defaults to None.
        use_sliding_window_attention: boolean. Whether to use sliding local
            window attention. When enabled, it is applied to every other
            (even-indexed) decoder block. Defaults to False.
        sliding_window_size: int. Size of the sliding local window. Defaults to
            4096.
        dtype: string or `keras.mixed_precision.DTypePolicy`. The dtype to use
            for the models computations and weights. Note that some
            computations, such as softmax and layer normalization will always
            be done at float32 precision regardless of dtype.

    Example:
    ```python
    import numpy as np
    import keras_hub

    input_data = {
        "token_ids": np.ones(shape=(1, 12), dtype="int32"),
        "padding_mask": np.array([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]),
    }

    # Pretrained Gemma decoder.
    model = keras_hub.models.GemmaBackbone.from_preset("gemma_2b_en")
    model(input_data)

    # Randomly initialized Gemma decoder with custom config.
    model = keras_hub.models.GemmaBackbone(
        vocabulary_size=50257,
        num_layers=12,
        num_query_heads=12,
        num_key_value_heads=1,
        hidden_dim=768,
        intermediate_dim=3072,
        head_dim=64,
    )
    model(input_data)
    ```
    """

    def __init__(
        self,
        vocabulary_size,
        num_layers,
        num_query_heads,
        num_key_value_heads,
        hidden_dim,
        intermediate_dim,
        head_dim,
        query_head_dim_normalize=True,
        use_post_ffw_norm=False,
        use_post_attention_norm=False,
        attention_logit_soft_cap=None,
        final_logit_soft_cap=None,
        use_sliding_window_attention=False,
        sliding_window_size=4096,
        layer_norm_epsilon=1e-6,
        dropout=0,
        dtype=None,
        **kwargs,
    ):
        # === Layers ===
        # The embedding is built with `tie_weights=True` and receives the
        # final logit soft cap, so it presumably also serves as the output
        # projection in reverse mode -- see `ReversibleEmbedding`.
        self.token_embedding = ReversibleEmbedding(
            input_dim=vocabulary_size,
            output_dim=hidden_dim,
            tie_weights=True,
            embeddings_initializer=keras.initializers.VarianceScaling(
                scale=1.0,
                mode="fan_in",
                distribution="untruncated_normal",
                seed=None,
            ),
            dtype=dtype,
            logit_soft_cap=final_logit_soft_cap,
            name="token_embedding",
        )
        self.transformer_layers = []
        for i in range(num_layers):
            # When sliding-window attention is requested, only even-indexed
            # blocks use it; odd-indexed blocks are built with it disabled.
            sliding_window = use_sliding_window_attention and (i % 2 == 0)
            layer = GemmaDecoderBlock(
                intermediate_dim=intermediate_dim,
                hidden_dim=hidden_dim,
                num_query_heads=num_query_heads,
                head_dim=head_dim,
                num_key_value_heads=num_key_value_heads,
                query_head_dim_normalize=query_head_dim_normalize,
                use_post_ffw_norm=use_post_ffw_norm,
                use_post_attention_norm=use_post_attention_norm,
                logit_soft_cap=attention_logit_soft_cap,
                use_sliding_window_attention=sliding_window,
                sliding_window_size=sliding_window_size,
                dropout=dropout,
                dtype=dtype,
                name=f"decoder_block_{i}",
            )
            self.transformer_layers.append(layer)
        self.layer_norm = RMSNormalization(
            epsilon=layer_norm_epsilon,
            dtype=dtype,
            name="final_normalization",
        )

        # === Functional Model ===
        # NOTE(review): token ids are declared as float32 rather than an
        # integer dtype; this looks deliberate, so it is left unchanged --
        # confirm before switching to int32.
        token_id_input = keras.Input(
            shape=(None,), dtype="float32", name="token_ids"
        )
        padding_mask_input = keras.Input(
            shape=(None,), dtype="float32", name="padding_mask"
        )
        x = self.token_embedding(token_id_input)
        # Scale the embeddings by sqrt(hidden_dim) before the decoder stack.
        x = x * ops.cast(ops.sqrt(hidden_dim), x.dtype)
        for transformer_layer in self.transformer_layers:
            x = transformer_layer(x, padding_mask=padding_mask_input)
        sequence_output = self.layer_norm(x)
        super().__init__(
            inputs={
                "token_ids": token_id_input,
                "padding_mask": padding_mask_input,
            },
            outputs=sequence_output,
            dtype=dtype,
            **kwargs,
        )

        # === Config ===
        self.vocabulary_size = vocabulary_size
        self.num_layers = num_layers
        self.num_query_heads = num_query_heads
        self.num_key_value_heads = num_key_value_heads
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.head_dim = head_dim
        self.layer_norm_epsilon = layer_norm_epsilon
        self.dropout = dropout
        self.query_head_dim_normalize = query_head_dim_normalize
        self.use_post_ffw_norm = use_post_ffw_norm
        self.use_post_attention_norm = use_post_attention_norm
        self.attention_logit_soft_cap = attention_logit_soft_cap
        self.final_logit_soft_cap = final_logit_soft_cap
        self.sliding_window_size = sliding_window_size
        self.use_sliding_window_attention = use_sliding_window_attention

    def get_config(self):
        """Return the constructor arguments needed to rebuild this backbone."""
        config = super().get_config()
        config.update(
            {
                "vocabulary_size": self.vocabulary_size,
                "num_layers": self.num_layers,
                "num_query_heads": self.num_query_heads,
                "num_key_value_heads": self.num_key_value_heads,
                "hidden_dim": self.hidden_dim,
                "intermediate_dim": self.intermediate_dim,
                "head_dim": self.head_dim,
                "layer_norm_epsilon": self.layer_norm_epsilon,
                "dropout": self.dropout,
                "query_head_dim_normalize": self.query_head_dim_normalize,
                "use_post_ffw_norm": self.use_post_ffw_norm,
                "use_post_attention_norm": self.use_post_attention_norm,
                "final_logit_soft_cap": self.final_logit_soft_cap,
                "attention_logit_soft_cap": self.attention_logit_soft_cap,
                "sliding_window_size": self.sliding_window_size,
                "use_sliding_window_attention": (
                    self.use_sliding_window_attention
                ),
            }
        )
        return config

    @staticmethod
    def get_layout_map(
        device_mesh,
        model_parallel_dim_name="model",
        data_parallel_dim_name="batch",
    ):
        """Get a `keras.distribution.LayoutMap` for model parallel distribution.

        The returned `LayoutMap` contains the sharding spec for the gemma
        backbone weights, so that you can use it to distribute weights across
        the accelerators.

        Example:
        ```
        # Feel free to change the mesh shape to balance data and model
        # parallelism
        mesh = keras.distribution.DeviceMesh(
            shape=(1, 8), axis_names=('batch', 'model'),
            devices=keras.distribution.list_devices())
        layout_map = GemmaBackbone.get_layout_map(
            mesh, model_parallel_dim_name="model")

        distribution = keras.distribution.ModelParallel(
            layout_map=layout_map, batch_dim_name='batch')
        with distribution.scope():
            gemma_model = keras_hub.models.GemmaCausalLM.from_preset(
                "gemma_2b_en"
            )
        ```

        To see how the layout map was applied, load the model then run (for one
        decoder block):
        ```
        embedding_layer = gemma_model.backbone.get_layer("token_embedding")
        decoder_block_1 = gemma_model.backbone.get_layer('decoder_block_1')
        for variable in embedding_layer.weights + decoder_block_1.weights:
            print(
                f'{variable.path:<58} {str(variable.shape):<16} '
                f'{str(variable.value.sharding.spec)}'
            )
        ```

        Args:
            device_mesh: The `keras.distribution.DeviceMesh` instance for
                distribution.
            model_parallel_dim_name: The axis name of the device mesh, where
                the weights should be partitioned on.
            data_parallel_dim_name: The axis name of the device mesh, where
                the data should be partitioned on.

        Returns:
            `keras.distribution.LayoutMap` that contains the sharding spec
            for all the model weights.

        Raises:
            ValueError: If `device_mesh` is not a
                `keras.distribution.DeviceMesh`, or if either axis name is not
                one of `device_mesh.axis_names`.
        """
        # The weight path and shape of the Gemma backbone is like below
        # (for Gemma 2B)
        # token_embedding/embeddings,  (256128, 2048)
        # repeat block for decoder
        # ...
        # decoder_block_17/pre_attention_norm/scale,  (2048,)
        # decoder_block_17/attention/query/kernel,  (8, 2048, 256)
        # decoder_block_17/attention/key/kernel,  (8, 2048, 256)
        # decoder_block_17/attention/value/kernel,  (8, 2048, 256)
        # decoder_block_17/attention/attention_output/kernel,  (8, 256, 2048)
        # decoder_block_17/pre_ffw_norm/scale,  (2048,)
        # decoder_block_17/ffw_gating/kernel,  (2048, 16384)
        # decoder_block_17/ffw_gating_2/kernel,  (2048, 16384)
        # decoder_block_17/ffw_linear/kernel,  (16384, 2048)
        if not isinstance(device_mesh, keras.distribution.DeviceMesh):
            raise ValueError(
                "Invalid device_mesh type. Expected "
                f"`keras.distribution.DeviceMesh`, got {type(device_mesh)}"
            )
        # Validate the model axis first, then the data axis. The message uses
        # `axis_names`; the previous `axis_name` attribute does not exist on
        # DeviceMesh and made the error path itself raise AttributeError.
        for dim_name in (model_parallel_dim_name, data_parallel_dim_name):
            if dim_name not in device_mesh.axis_names:
                raise ValueError(
                    f"{dim_name} is not found in the "
                    f"device_mesh.axis_names. {device_mesh.axis_names=}"
                )
        # Note that it is possible to further config the mesh to be 3D, eg
        # (data, seq, model). We leave it as 2D for now for simplicity.
        data_dim = data_parallel_dim_name
        model_dim = model_parallel_dim_name
        # The sharding config is based on the Gemma team training config.
        # See https://arxiv.org/abs/2403.08295
        # Each large kernel is split across both mesh axes; dimensions mapped
        # to `None` are not partitioned.
        layout_map = keras.distribution.LayoutMap(device_mesh)
        layout_map["token_embedding/embeddings"] = (model_dim, data_dim)
        layout_map["decoder_block.*attention.*(query|key|value).kernel"] = (
            model_dim,
            data_dim,
            None,
        )
        layout_map["decoder_block.*attention_output.kernel"] = (
            model_dim,
            None,
            data_dim,
        )
        layout_map["decoder_block.*ffw_gating.kernel"] = (data_dim, model_dim)
        layout_map["decoder_block.*ffw_gating_2.kernel"] = (data_dim, model_dim)
        layout_map["decoder_block.*ffw_linear.kernel"] = (model_dim, data_dim)
        return layout_map