Add rotary embedding fusion rule (part 1) #1981
base: main
Although this logic is correct and makes sense, it doesn't match the function logic in the op definition. Is it correct to assume that the pattern logic should mimic the ONNX function in the op schema? Currently, the function in the op schema (which uses Split instead of Slice for the non-interleaved case) computes x1 and x2 first and multiplies them by cos/sin before concatenating, so the Concat happens after the multiplication.
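For what it's worth, here is a minimal NumPy sketch (shapes and names are illustrative, not taken from this PR) showing that the two orderings compute the same values for the non-interleaved case: the exporter-style "concatenate the rotated halves, then multiply by full-width cos/sin" versus the op-schema-style "multiply the halves by half-width cos/sin, then concatenate":

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 8, 16)).astype(np.float32)   # (batch, seq, head_dim) -- illustrative
half = x.shape[-1] // 2
cos_half = rng.standard_normal((2, 8, half)).astype(np.float32)
sin_half = rng.standard_normal((2, 8, half)).astype(np.float32)

x1, x2 = x[..., :half], x[..., half:]

# Exporter-style (rotate_half): Concat first, then Mul/Add, with cos/sin
# duplicated to full width.
cos = np.concatenate([cos_half, cos_half], axis=-1)
sin = np.concatenate([sin_half, sin_half], axis=-1)
out_exporter = x * cos + np.concatenate([-x2, x1], axis=-1) * sin

# Op-schema-style: multiply the halves first, then Concat.
out_schema = np.concatenate(
    [x1 * cos_half - x2 * sin_half, x2 * cos_half + x1 * sin_half], axis=-1
)

np.testing.assert_allclose(out_exporter, out_schema, rtol=1e-5)
```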
It does not have to match the function logic in the op definition, but it does have to match the graph produced by the ONNX exporter from the logic defined in the source (e.g., the transformers implementation). What we do have to guarantee is that replacing this logic with the pattern in `rewrite` is fine: that both will produce the same values.
Specifically, it is more important to match the source logic, like this transformers code (sketched below).
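For reference, this is roughly what the LLaMA-style rotary implementation in HuggingFace transformers looks like (paraphrased; the exact code varies by model and library version):

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```

This is, roughly, the source of the Slice/Neg/Concat/Mul/Add subgraph that the `pattern` is written to match.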
Putting all this together, there are three parts to these rewrite rules (see the sketch after this list):
- `pattern` should typically be aligned with the subgraph pattern we see in the ONNX graphs produced by the exporter (which itself depends on the source PyTorch code).
- `rewrite` is aligned with the (fused) op definition (existing in ORT or being introduced to ONNX).
- `check` is a condition that has to be strong enough to guarantee that the replacement is sound, so that we can be sure we will produce the same outputs with or without the optimization.
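A purely schematic sketch of that three-part structure (hypothetical container, for illustration only; the real onnxscript rewriter API differs in names and signatures):

```python
from dataclasses import dataclass
from typing import Any, Callable

@dataclass
class FusionRule:
    # Hypothetical illustration only -- not the onnxscript rewriter API.
    pattern: Callable[..., Any]   # subgraph to match, mirroring what the exporter emits
    rewrite: Callable[..., Any]   # replacement builder, mirroring the fused op definition
    check: Callable[..., bool]    # soundness condition guarding the replacement

def rotary_check(input_rank: int) -> bool:
    # Example condition: only fuse when the matched input is 4D, where the fused
    # op is known to compute the same values as the matched subgraph.
    return input_rank == 4
```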
According to the schema, x can be a 3D tensor as well, and num_heads needs to be known in the 3D case.
That means the optimization is safe and correct (in this regard). To generalize and also allow 3D inputs here, we would need to guarantee that the entire fusion is semantically correct; it is not enough to know that the RotaryEmbedding op permits 3D inputs. What do you think about the correctness of this fusion optimization? Do you think it is fine to generalize and allow 3D here?
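As a concrete (hypothetical) example, a guard along these lines would accept 4D inputs unconditionally and 3D inputs only when num_heads is statically known; the helper name and shape layout are assumptions, not code from this PR:

```python
from typing import Optional, Sequence

def valid_rotary_input(x_shape: Optional[Sequence[int]], num_heads: Optional[int] = None) -> bool:
    # Hypothetical check: mirrors the schema requirement that 3D inputs need num_heads.
    if x_shape is None:
        return False
    rank = len(x_shape)
    if rank == 4:        # (batch, num_heads, seq_len, head_dim)
        return True
    if rank == 3:        # (batch, seq_len, hidden) -- needs num_heads to recover head_dim
        return num_heads is not None
    return False
```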
Just curious about the domain here
Good question. I am currently using these to split the fusion optimization into multiple stages; we may need to clean this up eventually. For now, we also need to target the existing RotaryEmbedding op in ORT (which is what we can test against as well). Eventually, we can target the newly proposed RotaryEmbedding op, so we may also need to support some variations in the fusion optimization (depending on the target ORT/ONNX versions).
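For illustration, targeting the existing ORT contrib op means emitting a node in the `com.microsoft` domain, roughly like this (built with `onnx.helper`; the input and attribute names are my recollection of the contrib-op spec, so treat them as approximate):

```python
import onnx.helper as oh

# Fused node targeting ONNX Runtime's existing contrib op.
rotary_node = oh.make_node(
    "RotaryEmbedding",
    inputs=["x", "position_ids", "cos_cache", "sin_cache"],
    outputs=["x_rotary"],
    domain="com.microsoft",   # ORT contrib-op domain; a future ONNX-standard op would use the default domain
    interleaved=0,            # non-interleaved (rotate_half) variant
)
```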