Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rewriting: Add convert-ml-program-to-memref #2580

Merged
merged 9 commits into from
May 29, 2024

Conversation

superlopuh
Copy link
Member

I'm not 100% sure that this is the intended way that ml_program should be lowered but it seems to work for our ONNX kernels.

@superlopuh superlopuh added the rewriting changes on xDSL rewrite engine label May 13, 2024
@superlopuh superlopuh self-assigned this May 13, 2024
Copy link

codecov bot commented May 13, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 89.62%. Comparing base (fcaed29) to head (d719f69).

Additional details and impacted files
@@           Coverage Diff           @@
##             main    #2580   +/-   ##
=======================================
  Coverage   89.61%   89.62%           
=======================================
  Files         360      361    +1     
  Lines       46209    46243   +34     
  Branches     6986     6987    +1     
=======================================
+ Hits        41410    41444   +34     
  Misses       3724     3724           
  Partials     1075     1075           

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

xdsl/dialects/memref.py Outdated Show resolved Hide resolved
Comment on lines +45 to +58
class ConvertGlobalLoadConst(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(
self, op: ml_program.GlobalLoadConstant, rewriter: PatternRewriter
) -> None:
assert isinstance(op_type := op.result.type, TensorType)
op_type = cast(TensorType[Any], op_type)
new_type = memref.MemRefType(op_type.element_type, op_type.shape)
rewriter.replace_matched_op(
(
mem := memref.GetGlobal.get(op.global_attr, new_type),
bufferization.ToTensorOp(mem.memref),
)
)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is dark magic 🧙
But necessary dark magic when you have to deal with linking and tensor I guess.
Can you simply comment on that here?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NB: As said, this is necessary in general as you want it to work with external tensor symbols.
What about special casing it for when the symbol actually is in the same module?

@AntonLydike AntonLydike changed the title rewriting: add convert-ml-program-to-memref rewriting: Add convert-ml-program-to-memref May 22, 2024
@superlopuh superlopuh merged commit 8fbb3b2 into main May 29, 2024
10 checks passed
@superlopuh superlopuh deleted the sasha/ml_program/to-tensor-const branch May 29, 2024 08:33
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
rewriting changes on xDSL rewrite engine
Projects
No open projects
Status: Done
Development

Successfully merging this pull request may close these issues.

2 participants