-
Notifications
You must be signed in to change notification settings - Fork 43
/
Copy pathmodeling_utils.py
376 lines (354 loc) · 13.5 KB
/
modeling_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
# -----------------------------------------------------------------------------
#
# Copyright (c) 2024 Qualcomm Innovation Center, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# -----------------------------------------------------------------------------
from collections import namedtuple
from typing import Dict, Optional, Tuple, Type
import torch
import torch.nn as nn
from transformers.models.codegen.modeling_codegen import (
CodeGenAttention,
CodeGenBlock,
CodeGenForCausalLM,
CodeGenModel,
)
from transformers.models.falcon.modeling_falcon import (
FalconAttention,
FalconForCausalLM,
FalconModel,
)
from transformers.models.gemma.modeling_gemma import (
GemmaAttention,
GemmaDecoderLayer,
GemmaForCausalLM,
GemmaModel,
GemmaRMSNorm,
)
from transformers.models.gemma2.modeling_gemma2 import (
Gemma2Attention,
Gemma2DecoderLayer,
Gemma2ForCausalLM,
Gemma2Model,
Gemma2RMSNorm,
)
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention, GPT2Block, GPT2LMHeadModel, GPT2Model
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import (
GPTBigCodeAttention,
GPTBigCodeBlock,
GPTBigCodeForCausalLM,
GPTBigCodeModel,
)
from transformers.models.gptj.modeling_gptj import GPTJAttention, GPTJForCausalLM, GPTJModel
from transformers.models.llama.modeling_llama import (
LlamaAttention,
LlamaDecoderLayer,
LlamaForCausalLM,
LlamaModel,
LlamaRMSNorm,
)
from transformers.models.mistral.modeling_mistral import (
MistralAttention,
MistralDecoderLayer,
MistralForCausalLM,
MistralModel,
MistralRMSNorm,
)
from transformers.models.mistral3.modeling_mistral3 import (
Mistral3ForConditionalGeneration,
Mistral3RMSNorm,
)
from transformers.models.mixtral.modeling_mixtral import (
MixtralAttention,
MixtralDecoderLayer,
MixtralForCausalLM,
MixtralModel,
MixtralRMSNorm,
MixtralSparseMoeBlock,
)
from transformers.models.mllama.modeling_mllama import MllamaForCausalLM
from transformers.models.mpt.modeling_mpt import MptAttention, MptBlock, MptForCausalLM, MptModel
from transformers.models.phi.modeling_phi import PhiAttention, PhiForCausalLM, PhiModel
from transformers.models.phi3.modeling_phi3 import Phi3Attention, Phi3ForCausalLM, Phi3Model, Phi3RMSNorm
from transformers.models.pixtral.modeling_pixtral import PixtralRMSNorm
from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, Qwen2ForCausalLM, Qwen2Model, Qwen2RMSNorm
from transformers.models.starcoder2.modeling_starcoder2 import (
Starcoder2Attention,
Starcoder2DecoderLayer,
Starcoder2ForCausalLM,
Starcoder2Model,
)
from transformers.models.whisper.modeling_whisper import (
WhisperAttention,
WhisperDecoder,
WhisperDecoderLayer,
WhisperEncoder,
WhisperForConditionalGeneration,
WhisperModel,
WhisperPositionalEmbedding,
)
from QEfficient.customop import CustomRMSNormAIC
from QEfficient.transformers.models.mistral3.modeling_mistral3 import QEffMistral3ForConditionalGeneration
from .models.codegen.modeling_codegen import (
QEffCodeGenAttention,
QeffCodeGenBlock,
QEffCodeGenForCausalLM,
QEffCodeGenModel,
)
from .models.falcon.modeling_falcon import (
QEffFalconAttention,
QEffFalconForCausalLM,
QEffFalconModel,
)
from .models.gemma.modeling_gemma import QEffGemmaAttention, QEffGemmaDecoderLayer, QEffGemmaForCausalLM, QEffGemmaModel
from .models.gemma2.modeling_gemma2 import (
QEffGemma2Attention,
QEffGemma2DecoderLayer,
QEffGemma2ForCausalLM,
QEffGemma2Model,
)
from .models.gpt2.modeling_gpt2 import QEffGPT2Attention, QEffGPT2Block, QEffGPT2LMHeadModel, QEffGPT2Model
from .models.gpt_bigcode.modeling_gpt_bigcode import (
QEffGPTBigCodeAttention,
QEffGPTBigCodeBlock,
QEffGPTBigCodeForCausalLM,
QEffGPTBigCodeModel,
)
from .models.gptj.modeling_gptj import QEffGPTJAttention, QEffGPTJForCausalLM, QEffGPTJModel
from .models.llama.modeling_llama import (
QEffLlamaAttention,
QEffLlamaDecoderLayer,
QEffLlamaForCausalLM,
QEffLlamaModel,
)
from .models.mistral.modeling_mistral import (
QEffMistralAttention,
QEffMistralDecoderLayer,
QEffMistralForCausalLM,
QEffMistralModel,
)
from .models.mixtral_moe.modeling_mixtral import (
QEffMixtralAttention,
QeffMixtralDecoderLayer,
QEffMixtralForCausalLM,
QEffMixtralModel,
QEffMixtralSparseMoeBlock,
)
from .models.mpt.modeling_mpt import QEffMptAttention, QEffMptBlock, QEffMptForCausalLM, QEFfMptModel
from .models.phi.modeling_phi import QEffPhiAttention, QEffPhiForCausalLM, QEffPhiModel
from .models.phi3.modeling_phi3 import QEffPhi3Attention, QEffPhi3ForCausalLM, QEffPhi3Model
from .models.qwen2.modeling_qwen2 import QEffQwen2Attention, QEffQwen2ForCausalLM, QEffQwen2Model
from .models.starcoder2.modeling_starcoder2 import (
QEffStarcoder2Attention,
QEFFStarcoder2DecoderLayer,
QEffStarcoder2ForCausalLM,
QEffStarcoder2Model,
)
from .models.whisper.modeling_whisper import (
QEffWhisperAttention,
QEffWhisperDecoder,
QEffWhisperDecoderLayer,
QEffWhisperEncoder,
QEffWhisperForConditionalGeneration,
QEffWhisperModel,
QEffWhisperPositionalEmbedding,
)
# ModelArchitectures: a named tuple with a single ``architectures`` field.
# Required for the Automation tool.
ModelArchitectures = namedtuple("ModelArchitectures", ["architectures"])

# Class names (``__name__``) of the transformers top-level model classes listed as
# QEfficient-supported, recorded in exactly this order.
qeff_supported_architectures = ModelArchitectures(
    architectures=[
        model_cls.__name__
        for model_cls in (
            GPT2LMHeadModel,
            GPTJForCausalLM,
            MptForCausalLM,
            CodeGenForCausalLM,
            LlamaForCausalLM,
            GemmaForCausalLM,
            Gemma2ForCausalLM,
            MistralForCausalLM,
            MixtralForCausalLM,
            Phi3ForCausalLM,
            PhiForCausalLM,
            FalconForCausalLM,
            Qwen2ForCausalLM,
            Starcoder2ForCausalLM,
            GPTBigCodeForCausalLM,
            MllamaForCausalLM,
            WhisperForConditionalGeneration,
            Mistral3ForConditionalGeneration,
        )
    ]
)
# Map each transformers module class to the QEfficient class that replaces it.
# When onboarding a new model, make sure to add its layer mappings to this dictionary.
# Every RMSNorm variant listed here maps to the same CustomRMSNormAIC class.
# NOTE(review): the code that consumes this mapping is not in this file (presumably a
# module-swapping transform) — confirm there whether entry order matters before reordering.
TransformersToQEffModulesDict: Dict[Type[nn.Module], Type[nn.Module]] = {
    # GPT2 model layers
    GPT2Model: QEffGPT2Model,
    GPT2Block: QEffGPT2Block,
    GPT2Attention: QEffGPT2Attention,
    GPT2LMHeadModel: QEffGPT2LMHeadModel,
    # GPTJ model layers
    GPTJModel: QEffGPTJModel,
    GPTJAttention: QEffGPTJAttention,
    GPTJForCausalLM: QEffGPTJForCausalLM,
    # Llama model layers
    LlamaModel: QEffLlamaModel,
    LlamaAttention: QEffLlamaAttention,
    LlamaForCausalLM: QEffLlamaForCausalLM,
    LlamaDecoderLayer: QEffLlamaDecoderLayer,
    LlamaRMSNorm: CustomRMSNormAIC,
    # Gemma model layers
    GemmaModel: QEffGemmaModel,
    GemmaAttention: QEffGemmaAttention,
    GemmaForCausalLM: QEffGemmaForCausalLM,
    GemmaDecoderLayer: QEffGemmaDecoderLayer,
    GemmaRMSNorm: CustomRMSNormAIC,
    # Gemma2 model layers
    Gemma2Model: QEffGemma2Model,
    Gemma2Attention: QEffGemma2Attention,
    Gemma2ForCausalLM: QEffGemma2ForCausalLM,
    Gemma2DecoderLayer: QEffGemma2DecoderLayer,
    Gemma2RMSNorm: CustomRMSNormAIC,
    # MPT model layers
    MptAttention: QEffMptAttention,
    MptBlock: QEffMptBlock,
    MptModel: QEFfMptModel,
    MptForCausalLM: QEffMptForCausalLM,
    # CodeGen model layers
    CodeGenAttention: QEffCodeGenAttention,
    CodeGenModel: QEffCodeGenModel,
    CodeGenForCausalLM: QEffCodeGenForCausalLM,
    CodeGenBlock: QeffCodeGenBlock,
    # Mistral model layers
    MistralAttention: QEffMistralAttention,
    MistralDecoderLayer: QEffMistralDecoderLayer,
    MistralModel: QEffMistralModel,
    MistralForCausalLM: QEffMistralForCausalLM,
    MistralRMSNorm: CustomRMSNormAIC,
    # Mistral3 model layers
    Mistral3ForConditionalGeneration: QEffMistral3ForConditionalGeneration,
    Mistral3RMSNorm: CustomRMSNormAIC,
    # Mixtral model layers
    MixtralAttention: QEffMixtralAttention,
    MixtralDecoderLayer: QeffMixtralDecoderLayer,
    MixtralModel: QEffMixtralModel,
    MixtralForCausalLM: QEffMixtralForCausalLM,
    MixtralRMSNorm: CustomRMSNormAIC,
    MixtralSparseMoeBlock: QEffMixtralSparseMoeBlock,
    # Phi3 model layers
    Phi3Attention: QEffPhi3Attention,
    Phi3Model: QEffPhi3Model,
    Phi3ForCausalLM: QEffPhi3ForCausalLM,
    Phi3RMSNorm: CustomRMSNormAIC,
    # Phi model layers
    PhiAttention: QEffPhiAttention,
    PhiModel: QEffPhiModel,
    PhiForCausalLM: QEffPhiForCausalLM,
    # Pixtral model layers (only the RMSNorm is replaced)
    PixtralRMSNorm: CustomRMSNormAIC,
    # Falcon model layers
    FalconAttention: QEffFalconAttention,
    FalconForCausalLM: QEffFalconForCausalLM,
    FalconModel: QEffFalconModel,
    # Qwen2 model layers
    Qwen2Attention: QEffQwen2Attention,
    Qwen2ForCausalLM: QEffQwen2ForCausalLM,
    Qwen2Model: QEffQwen2Model,
    Qwen2RMSNorm: CustomRMSNormAIC,
    # Starcoder2 model layers
    Starcoder2Attention: QEffStarcoder2Attention,
    Starcoder2ForCausalLM: QEffStarcoder2ForCausalLM,
    Starcoder2Model: QEffStarcoder2Model,
    Starcoder2DecoderLayer: QEFFStarcoder2DecoderLayer,
    # GPTBigCode model layers
    GPTBigCodeForCausalLM: QEffGPTBigCodeForCausalLM,
    GPTBigCodeAttention: QEffGPTBigCodeAttention,
    GPTBigCodeBlock: QEffGPTBigCodeBlock,
    GPTBigCodeModel: QEffGPTBigCodeModel,
    # Whisper encoder and decoder layers
    WhisperAttention: QEffWhisperAttention,
    WhisperDecoderLayer: QEffWhisperDecoderLayer,
    WhisperEncoder: QEffWhisperEncoder,
    WhisperDecoder: QEffWhisperDecoder,
    WhisperPositionalEmbedding: QEffWhisperPositionalEmbedding,
    WhisperModel: QEffWhisperModel,
    WhisperForConditionalGeneration: QEffWhisperForConditionalGeneration,
}
def _prepare_cross_attention_mask(
cross_attention_mask: torch.Tensor,
num_vision_tokens: int,
dtype: str,
) -> Tuple[torch.Tensor, torch.Tensor]:
# reshape so it can be used by attn module
batch_size, text_total_length, *_ = cross_attention_mask.shape
cross_attention_mask = cross_attention_mask.repeat_interleave(num_vision_tokens, dim=3)
cross_attention_mask = cross_attention_mask.view(batch_size, text_total_length, -1)
cross_attention_mask = cross_attention_mask.unsqueeze(1)
# invert the mask
inverted_cross_attn_mask = (1.0 - cross_attention_mask).to(dtype)
cross_attention_mask = inverted_cross_attn_mask.masked_fill(
inverted_cross_attn_mask.to(torch.bool), torch.tensor(-10000.0, dtype=torch.float32)
)
# apply full-row bias, which return 4D tensor of shape [B, H, S1, 1] where value is 0 if the a full row in cross attn mask's
# last dimension contains negative infinity values, otherwise it's 1
negative_inf_value = torch.tensor(-10000.0, dtype=torch.float32)
full_text_row_masked_out_mask = (
(cross_attention_mask != negative_inf_value).any(dim=-1).type_as(cross_attention_mask)[..., None]
)
cross_attention_mask *= full_text_row_masked_out_mask
return cross_attention_mask, full_text_row_masked_out_mask
def _prepare_aspect_ratio_attention_mask(
aspect_ratio_mask: torch.Tensor,
num_patches: int,
target_length: int,
dtype: torch.dtype,
) -> torch.Tensor:
# Expand aspect ratio mask to target_length
batch_size, max_num_tiles = aspect_ratio_mask.shape
attention_mask = aspect_ratio_mask.view(batch_size, max_num_tiles, 1, 1).to(dtype)
attention_mask = attention_mask.repeat(1, 1, target_length, 1)
# Mask padding patches
pad_patches = target_length - num_patches
attention_mask[:, :, -pad_patches:] = 0
# Invert the mask (0 -> 1, 1 -> 0)
attention_mask = 1 - attention_mask
# Reshape to 2D and create 4D attention mask
# (batch_size, 1, max_num_tiles * target_length, max_num_tiles * target_length)
attention_mask = attention_mask.reshape(batch_size, max_num_tiles * target_length, 1)
attention_mask = attention_mask @ attention_mask.transpose(-1, -2) * torch.tensor(-10000.0, dtype=torch.float32)
attention_mask = attention_mask.unsqueeze(1)
return attention_mask
def _create_causal_mask(
position_ids,
target_length,
sliding_window: Optional[int] = None,
):
"""
A utility attention mask class that allows one to:
- Create a causal 4d mask
- Create a causal 4d mask with sliding window
"""
if sliding_window is not None:
query_indices = position_ids.unsqueeze(-1)
kv_indices = torch.arange(target_length).view(1, -1)
# --- Rolling buffer ---
pos_max = position_ids.max(1, keepdim=True).values
kv_start = (pos_max // target_length) * target_length
kv_indices_high = kv_indices + kv_start
kv_indices_low = torch.where(kv_indices_high < target_length, kv_indices, kv_indices_high - target_length)
kv_indices = torch.where(kv_indices_high > pos_max, kv_indices_low, kv_indices_high)
kv_indices = kv_indices.unsqueeze(1)
# ------
causal_mask = kv_indices > query_indices
attention_mask = causal_mask
window_indices = query_indices - sliding_window + 1
window_mask = kv_indices < window_indices
attention_mask = attention_mask | window_mask
attention_mask = attention_mask.unsqueeze(1)
else:
query_indices = position_ids.unsqueeze(-1)
kv_indices = torch.arange(target_length).view(1, 1, -1)
attention_mask = kv_indices > query_indices
attention_mask = attention_mask.unsqueeze(1)
return attention_mask