Skip to content

Commit 3b850e5

Browse files
authored
Fix fusion ordering for partial rotary embedding (#2402)
The partial-rotary-embedding fusion depends on the cos-sin-cache fusion. Fix the fusion ordering. This is necessary for GQA fusion in models like Phi4 (with partial-rotary-embedding). TODO: Add test-case. The one I have is huge. Need to create a smaller test-case. Signed-off-by: Ganesan Ramalingam <grama@microsoft.com>
1 parent e71c889 commit 3b850e5

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

onnxscript/rewriter/ort_fusions/_core.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,8 +80,9 @@ def fuse(func, **kwargs):
8080
fusion_count["skip_layer_normalization"] = fuse(fuse_skip_layer_normalization)
8181
fusion_count["skip_rms_normalization"] = fuse(fuse_skip_rms_normalization)
8282
fusion_count["rotary_embedding"] = fuse(fuse_rotary_embedding)
83-
fusion_count["partial_rotary_embedding"] = fuse(fuse_partial_rotary_embedding)
8483
fusion_count["cos_sin_cache"] = fuse(fuse_cos_sin_cache)
84+
fusion_count["partial_rotary_embedding"] = fuse(fuse_partial_rotary_embedding)
85+
8586
# We apply shape inference after the SDPA fusion as new nodes are added
8687
# in the rewrite rule for certain patterns of SDPA.
8788
fusion_count["sdpa"] = fuse(fuse_sdpa, apply_shape_inference=True)

0 commit comments

Comments
 (0)