@@ -163,6 +163,13 @@ def pattern(
163163 ):
164164 # Reshape query from (B, S, D) to (B, S, H, D/H)
165165 query_BSHDh = op .Reshape (query_BSD , pattern .ANY_VALUE , _outputs = ["query_BSHDh" ])
166+ # Qwen variant uses normalization of query/key before rotary embedding:
167+ # The normalization can happen before (e.g., Qwen) or after the Transpose (e.g., Gemma).
168+ query_BSHDh_normalized = op .SimplifiedLayerNormalization (
169+ query_BSHDh , pattern .ANY_VALUE , axis = - 1 , _outputs = ["query_BSHDh_normalized" ]
170+ )
171+ query_BSHDh = pattern .OrValue ([query_BSHDh , query_BSHDh_normalized ])
172+
166173 # Transpose from (B, S, H, D/H) to (B, H, S, D/H)
167174 query_BHSDh = op .Transpose (query_BSHDh , perm = [0 , 2 , 1 , 3 ])
168175
@@ -174,6 +181,11 @@ def pattern(
174181
175182 # Reshape key from (B, S, Dkv) to (B, S, Hkv, D/H)
176183 key_BSHkvDh = op .Reshape (key_BSDkv , pattern .ANY_VALUE , _outputs = ["key_BSHkvDh" ])
184+ key_BSHkvDh_normalized = op .SimplifiedLayerNormalization (
185+ key_BSHkvDh , pattern .ANY_VALUE , axis = - 1 , _outputs = ["key_BSHkvDh_normalized" ]
186+ )
187+ key_BSHkvDh = pattern .OrValue ([key_BSHkvDh , key_BSHkvDh_normalized ])
188+
177189 # Transpose from (B, S, Hkv, D/H) to (B, Hkv, S, D/H)
178190 key_BHkvSDh = op .Transpose (key_BSHkvDh , perm = [0 , 2 , 1 , 3 ])
179191
@@ -209,6 +221,8 @@ def pattern(
209221 # that share key/value.
210222
211223 key_seq_BHkvTDh = op .Concat (past_key , key_BHkvSDh_rope , axis = - 2 )
224+ # Concat with past_key is optional:
225+ key_seq_BHkvTDh = pattern .OrValue ([key_seq_BHkvTDh , key_BHkvSDh_rope ])
212226 key_seq_BHkv1TDh = op .Unsqueeze (key_seq_BHkvTDh , 2 )
213227 key_seq_BHkvGTDh = op .Expand (key_seq_BHkv1TDh , pattern .ANY_VALUE )
214228 key_seq_BHTDh = op .Reshape (
@@ -218,6 +232,8 @@ def pattern(
218232 # Concatenate past_value cache and current value, expand across heads
219233 # that share key/value.
220234 value_seq_BHkvTDh = op .Concat (past_value , value_BHkvSDh , axis = - 2 )
235+ # Concat with past_value is optional:
236+ value_seq_BHkvTDh = pattern .OrValue ([value_seq_BHkvTDh , value_BHkvSDh ])
221237 value_seq_BHkv1TDh = op .Unsqueeze (value_seq_BHkvTDh , 2 )
222238 value_seq_BHkvGTDh = op .Expand (value_seq_BHkv1TDh , pattern .ANY_VALUE )
223239 value_seq_BHTDh = op .Reshape (
@@ -254,8 +270,23 @@ def check(
254270 query_BSHDh ,
255271 key_BSHkvDh ,
256272 mask ,
273+ query_BSHDh_normalized = None ,
274+ query_BHSDh_normalized = None ,
275+ key_BSHkvDh_normalized = None ,
276+ key_BHkvSDh_normalized = None ,
257277 ** _ ,
258278 ):
279+ result = pattern .MatchResult ()
280+ if query_BSHDh_normalized is not None and query_BHSDh_normalized is not None :
281+ return result .fail (
282+ "Query normalized twice" ,
283+ [query_BSHDh_normalized , query_BHSDh_normalized ],
284+ )
285+ if key_BSHkvDh_normalized is not None and key_BHkvSDh_normalized is not None :
286+ return result .fail (
287+ "Key normalized twice" ,
288+ [key_BSHkvDh_normalized , key_BHkvSDh_normalized ],
289+ )
259290 bindings : dict [str , Dim ] = {}
260291
261292 def no_match (val : ir .Value , dims : Sequence [str ]) -> bool :
@@ -268,17 +299,16 @@ def no_match(val: ir.Value, dims: Sequence[str]) -> bool:
268299 if no_match (value_BSDkv , ["B" , "S" , "Dkv" ]):
269300 return False
270301
271- if no_match (past_key , ["B" , "Hkv" , "P" , "Dh" ]):
302+ if past_key is not None and no_match (past_key , ["B" , "Hkv" , "P" , "Dh" ]):
272303 return False
273- if no_match (past_value , ["B" , "Hkv" , "P" , "Dv" ]):
304+ if past_value is not None and no_match (past_value , ["B" , "Hkv" , "P" , "Dv" ]):
274305 return False
275306
276307 # TODO: verify Reshapes:
277308 # e.g., verify bindings["B"] * bindings["H"] == bindings["B*H"]:
278309 # and bindings["H"] * bindings["Dh"] == bindings["H*Dh"]:
279310 # or check Reshape's shape-input value
280311
281- result = pattern .MatchResult ()
282312 num_heads = _ir_utils .get_dim (query_BSHDh , 2 )
283313 kv_num_heads = _ir_utils .get_dim (key_BSHkvDh , 2 )
284314 if not isinstance (num_heads , int ):
@@ -330,7 +360,9 @@ def rewrite(
330360 mask ,
331361 query_BSHDh ,
332362 key_BSHkvDh ,
363+ query_BSHDh_normalized = None ,
333364 query_BHSDh_normalized = None ,
365+ key_BSHkvDh_normalized = None ,
334366 key_BHkvSDh_normalized = None ,
335367 ** _ ,
336368 ):
@@ -352,9 +384,10 @@ def rewrite(
352384 max_seq_length = op .ReduceMax (seqlens_k , zero_int64_1d , keepdims = 0 )
353385 total_seq_length_int32 = op .Add (max_seq_length , one_int32_0d )
354386
355- if query_BHSDh_normalized is not None :
387+ normalized_query = query_BHSDh_normalized or query_BSHDh_normalized
388+ if normalized_query is not None :
356389 # We apply normalization without the transpose, which is fused into GQA
357- norm_node = query_BHSDh_normalized .producer ()
390+ norm_node = normalized_query .producer ()
358391 norm_attrs = norm_node .attributes
359392 norm_scale = norm_node .inputs [1 ]
360393 query_BSHDh_normalized = op .SimplifiedLayerNormalization (
@@ -363,9 +396,10 @@ def rewrite(
363396 reshape_BSHDh_to_BSD = op .Constant (value_ints = [0 , 0 , - 1 ])
364397 query_BSD = op .Reshape (query_BSHDh_normalized , reshape_BSHDh_to_BSD )
365398
366- if key_BHkvSDh_normalized is not None :
399+ normalized_key = key_BHkvSDh_normalized or key_BSHkvDh_normalized
400+ if normalized_key is not None :
367401 # We apply normalization without the transpose, which is fused into GQA
368- norm_node = key_BHkvSDh_normalized .producer ()
402+ norm_node = normalized_key .producer ()
369403 norm_attrs = norm_node .attributes
370404 norm_scale = norm_node .inputs [1 ]
371405 key_BSHkvDh_normalized = op .SimplifiedLayerNormalization (
0 commit comments