rewriting: Add convert-ml-program-to-memref #2580
Conversation
…nd use bufferization to_tensor instead of cast
Codecov Report: All modified and coverable lines are covered by tests ✅

```
@@            Coverage Diff            @@
##              main    #2580    +/-  ##
========================================
  Coverage    89.61%   89.62%
========================================
  Files          360      361      +1
  Lines        46209    46243     +34
  Branches      6986     6987      +1
========================================
+ Hits         41410    41444     +34
  Misses        3724     3724
  Partials      1075     1075
```
```python
from typing import Any, cast

from xdsl.dialects import bufferization, memref, ml_program
from xdsl.dialects.builtin import TensorType
from xdsl.pattern_rewriter import (
    PatternRewriter,
    RewritePattern,
    op_type_rewrite_pattern,
)


class ConvertGlobalLoadConst(RewritePattern):
    @op_type_rewrite_pattern
    def match_and_rewrite(
        self, op: ml_program.GlobalLoadConstant, rewriter: PatternRewriter
    ) -> None:
        assert isinstance(op_type := op.result.type, TensorType)
        op_type = cast(TensorType[Any], op_type)
        new_type = memref.MemRefType(op_type.element_type, op_type.shape)
        rewriter.replace_matched_op(
            (
                mem := memref.GetGlobal.get(op.global_attr, new_type),
                bufferization.ToTensorOp(mem.memref),
            )
        )
```
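For context, the pattern above replaces a load of a global constant tensor with a lookup of the backing memref global, followed by `bufferization.to_tensor` to bridge the result back into tensor land. A sketch of the before/after IR, where the symbol `@foo` and the `4xf32` type are illustrative:

```
// Before: loading a global constant as a tensor value.
%0 = ml_program.global_load_const @foo : tensor<4xf32>

// After: fetch the backing memref global, then bridge back to a tensor.
%mem = memref.get_global @foo : memref<4xf32>
%0 = bufferization.to_tensor %mem : memref<4xf32>
```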
This is dark magic 🧙
But it's necessary dark magic when you have to deal with linking and tensors, I guess.
Could you simply add a comment on that here?
NB: As said, this is necessary in general, since you want it to work with external tensor symbols.
What about special-casing it for when the symbol actually is in the same module?
I'm not 100% sure this is the intended way that ml_program should be lowered, but it seems to work for our ONNX kernels.