Skip to content

Commit 647754f

Browse files
authored
Extend GQA fusion for Qwen (#2662)
A couple of extensions to the GQA fusion pattern: * Support the case where there is no past key/value cache, and * Normalization and Transpose occur in the opposite order in Qwen (which has the same behavior). Support this pattern variation. TODO: add test-cases to cover and validate this --------- Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
1 parent fe50b83 commit 647754f

File tree

1 file changed

+41
-7
lines changed
  • onnxscript/rewriter/ort_fusions

1 file changed

+41
-7
lines changed

onnxscript/rewriter/ort_fusions/gqa.py

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,13 @@ def pattern(
163163
):
164164
# Reshape query from (B, S, D) to (B, S, H, D/H)
165165
query_BSHDh = op.Reshape(query_BSD, pattern.ANY_VALUE, _outputs=["query_BSHDh"])
166+
# Qwen variant uses normalization of query/key before rotary embedding:
167+
# The normalization can happen before (eg., Qwen) or after the Transpose (eg., Gemma).
168+
query_BSHDh_normalized = op.SimplifiedLayerNormalization(
169+
query_BSHDh, pattern.ANY_VALUE, axis=-1, _outputs=["query_BSHDh_normalized"]
170+
)
171+
query_BSHDh = pattern.OrValue([query_BSHDh, query_BSHDh_normalized])
172+
166173
# Transpose from (B, S, H, D/H) to (B, H, S, D/H)
167174
query_BHSDh = op.Transpose(query_BSHDh, perm=[0, 2, 1, 3])
168175

@@ -174,6 +181,11 @@ def pattern(
174181

175182
# Reshape key from (B, S, Dkv) to (B, S, Hkv, D/H)
176183
key_BSHkvDh = op.Reshape(key_BSDkv, pattern.ANY_VALUE, _outputs=["key_BSHkvDh"])
184+
key_BSHkvDh_normalized = op.SimplifiedLayerNormalization(
185+
key_BSHkvDh, pattern.ANY_VALUE, axis=-1, _outputs=["key_BSHkvDh_normalized"]
186+
)
187+
key_BSHkvDh = pattern.OrValue([key_BSHkvDh, key_BSHkvDh_normalized])
188+
177189
# Transpose from (B, S, Hkv, D/H) to (B, Hkv, S, D/H)
178190
key_BHkvSDh = op.Transpose(key_BSHkvDh, perm=[0, 2, 1, 3])
179191

@@ -209,6 +221,8 @@ def pattern(
209221
# that share key/value.
210222

211223
key_seq_BHkvTDh = op.Concat(past_key, key_BHkvSDh_rope, axis=-2)
224+
# Concat with past_key is optional:
225+
key_seq_BHkvTDh = pattern.OrValue([key_seq_BHkvTDh, key_BHkvSDh_rope])
212226
key_seq_BHkv1TDh = op.Unsqueeze(key_seq_BHkvTDh, 2)
213227
key_seq_BHkvGTDh = op.Expand(key_seq_BHkv1TDh, pattern.ANY_VALUE)
214228
key_seq_BHTDh = op.Reshape(
@@ -218,6 +232,8 @@ def pattern(
218232
# Concatenate past_value cache and current value, expand across heads
219233
# that share key/value.
220234
value_seq_BHkvTDh = op.Concat(past_value, value_BHkvSDh, axis=-2)
235+
# Concat with past_value is optional:
236+
value_seq_BHkvTDh = pattern.OrValue([value_seq_BHkvTDh, value_BHkvSDh])
221237
value_seq_BHkv1TDh = op.Unsqueeze(value_seq_BHkvTDh, 2)
222238
value_seq_BHkvGTDh = op.Expand(value_seq_BHkv1TDh, pattern.ANY_VALUE)
223239
value_seq_BHTDh = op.Reshape(
@@ -254,8 +270,23 @@ def check(
254270
query_BSHDh,
255271
key_BSHkvDh,
256272
mask,
273+
query_BSHDh_normalized=None,
274+
query_BHSDh_normalized=None,
275+
key_BSHkvDh_normalized=None,
276+
key_BHkvSDh_normalized=None,
257277
**_,
258278
):
279+
result = pattern.MatchResult()
280+
if query_BSHDh_normalized is not None and query_BHSDh_normalized is not None:
281+
return result.fail(
282+
"Query normalized twice",
283+
[query_BSHDh_normalized, query_BHSDh_normalized],
284+
)
285+
if key_BSHkvDh_normalized is not None and key_BHkvSDh_normalized is not None:
286+
return result.fail(
287+
"Key normalized twice",
288+
[key_BSHkvDh_normalized, key_BHkvSDh_normalized],
289+
)
259290
bindings: dict[str, Dim] = {}
260291

261292
def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
@@ -268,17 +299,16 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
268299
if no_match(value_BSDkv, ["B", "S", "Dkv"]):
269300
return False
270301

271-
if no_match(past_key, ["B", "Hkv", "P", "Dh"]):
302+
if past_key is not None and no_match(past_key, ["B", "Hkv", "P", "Dh"]):
272303
return False
273-
if no_match(past_value, ["B", "Hkv", "P", "Dv"]):
304+
if past_value is not None and no_match(past_value, ["B", "Hkv", "P", "Dv"]):
274305
return False
275306

276307
# TODO: verify Reshapes:
277308
# eg.: verify bindings["B"] * bindings["H"] == bindings["B*H"]:
278309
# and bindings["H"] * bindings["Dh"] == bindings["H*Dh"]:
279310
# or check Reshape's shape-input value
280311

281-
result = pattern.MatchResult()
282312
num_heads = _ir_utils.get_dim(query_BSHDh, 2)
283313
kv_num_heads = _ir_utils.get_dim(key_BSHkvDh, 2)
284314
if not isinstance(num_heads, int):
@@ -330,7 +360,9 @@ def rewrite(
330360
mask,
331361
query_BSHDh,
332362
key_BSHkvDh,
363+
query_BSHDh_normalized=None,
333364
query_BHSDh_normalized=None,
365+
key_BSHkvDh_normalized=None,
334366
key_BHkvSDh_normalized=None,
335367
**_,
336368
):
@@ -352,9 +384,10 @@ def rewrite(
352384
max_seq_length = op.ReduceMax(seqlens_k, zero_int64_1d, keepdims=0)
353385
total_seq_length_int32 = op.Add(max_seq_length, one_int32_0d)
354386

355-
if query_BHSDh_normalized is not None:
387+
normalized_query = query_BHSDh_normalized or query_BSHDh_normalized
388+
if normalized_query is not None:
356389
# We apply normalization without the transpose, which is fused into GQA
357-
norm_node = query_BHSDh_normalized.producer()
390+
norm_node = normalized_query.producer()
358391
norm_attrs = norm_node.attributes
359392
norm_scale = norm_node.inputs[1]
360393
query_BSHDh_normalized = op.SimplifiedLayerNormalization(
@@ -363,9 +396,10 @@ def rewrite(
363396
reshape_BSHDh_to_BSD = op.Constant(value_ints=[0, 0, -1])
364397
query_BSD = op.Reshape(query_BSHDh_normalized, reshape_BSHDh_to_BSD)
365398

366-
if key_BHkvSDh_normalized is not None:
399+
normalized_key = key_BHkvSDh_normalized or key_BSHkvDh_normalized
400+
if normalized_key is not None:
367401
# We apply normalization without the transpose, which is fused into GQA
368-
norm_node = key_BHkvSDh_normalized.producer()
402+
norm_node = normalized_key.producer()
369403
norm_attrs = norm_node.attributes
370404
norm_scale = norm_node.inputs[1]
371405
key_BSHkvDh_normalized = op.SimplifiedLayerNormalization(

0 commit comments

Comments
 (0)