transformations: (csl-stencil) Add pass to handle async ops and enclosing cf #3192
Codecov Report
All modified and coverable lines are covered by tests ✅

Additional details and impacted files

@@            Coverage Diff             @@
##             main    #3192      +/-   ##
==========================================
+ Coverage   90.01%   90.03%   +0.02%
==========================================
  Files         431      432       +1
  Lines       54282    54435     +153
  Branches     8410     8443      +33
==========================================
+ Hits        48863    49012     +149
+ Misses       4061     4056       -5
- Partials     1358     1367       +9
Nice progress! I have some comments here and there
)

no_params = FunctionType.from_lists([], [])
cond_task_id = self.counter + 1
Are you sure that this is a robust enough solution? I remember us thinking about having task-id allocation and referring to them by symbols in the meantime. What happened to that plan? (just curious, not necessarily criticism)
Yes and no. With our current comms library we have ids 1-3 free for local task ids. We could probably add an assert not to exceed that, but I think it will simply fail to compile at a later stage. The question is how much knowledge this pass should have of things downstream. That said, it is the first thing in the pipeline to handle tasks and task ids, so it has a bit of freedom.
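A rough sketch of the bounded, symbol-based allocation idea discussed above, not code from this PR. The class name `TaskIdAllocator` and the task symbols are hypothetical; only the free range 1-3 comes from the comment above.

```python
class TaskIdAllocator:
    """Hypothetical helper: binds task symbols to ids from a fixed free range."""

    def __init__(self, first_free: int = 1, last_free: int = 3) -> None:
        # The default range 1-3 is taken from the discussion above (assumption
        # about the current comms library; other setups may differ).
        self._next = first_free
        self._last = last_free
        self._ids: dict[str, int] = {}

    def allocate(self, symbol: str) -> int:
        """Return the id bound to `symbol`, allocating a fresh one on first use."""
        if symbol in self._ids:
            return self._ids[symbol]
        if self._next > self._last:
            # Fail here rather than at a later compile stage.
            raise ValueError(f"no free local task id left for '{symbol}'")
        self._ids[symbol] = self._next
        self._next += 1
        return self._ids[symbol]


allocator = TaskIdAllocator()
cond_id = allocator.allocate("for_cond0")  # hypothetical task symbols
inc_id = allocator.allocate("for_inc0")
```

Raising at allocation time would surface the "more than three local tasks" case in the pass itself, at the cost of the pass knowing about a downstream limit.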
terminator = op.body.block.last_op
assert isinstance(terminator, scf.Yield)
assert all(
    arg in op.body.block.args for arg in terminator.arguments
), "Can only yield unmodified iter_args (in any order)"

# limitation: currently only loops built from arith.constant are supported
assert isinstance(op.lb, OpResult)
assert isinstance(op.ub, OpResult)
assert isinstance(op.step, OpResult)
assert isinstance(op.lb.op, arith.Constant)
assert isinstance(op.ub.op, arith.Constant)
assert isinstance(op.step.op, arith.Constant)
assert isa(op.lb.op.value, IntegerAttr[IndexType])
assert isa(op.ub.op.value, IntegerAttr[IndexType])
assert isa(op.step.op.value, IntegerAttr[IndexType])

# limitation: all iter_args must be memrefs (stencil buffers) and have the same data type
assert isa(op.iter_args[0].type, MemRefType[csl.ZerosOp.T])
element_type = op.iter_args[0].type.get_element_type()
assert all(
    isa(a.type, MemRefType[csl.ZerosOp.T])
    and element_type == a.type.get_element_type()
    for a in op.iter_args
)
Nit: I'm generally not a fan of asserts in rewrite patterns, but I can see the appeal here. When the pass just aborts you may end up with silent failures, which are a pain to debug... I guess this shows that xDSL is lacking some sort of warning system for not-applied passes...
I think in MLIR there are separate match and apply functions, which separates out code that checks if a pattern should be applied. We also need quite a few of these asserts to make pyright happy, and the new make precommit forces us to split assert isinstance(op.lb, OpResult) and isinstance(op.lb.op, arith.Constant) into separate asserts.
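A minimal sketch of the match/apply split mentioned above, using stand-in classes rather than the real xDSL types so that it runs on its own. `OpResult`, `BlockArgument`, and `Constant` here are simplified stand-ins, and `match_constant` and `trip_count` are hypothetical helpers. The match step returns `None` instead of asserting, and the two `isinstance` checks stay separate so a type checker can narrow one step at a time.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Constant:
    """Stand-in for an op that produces a constant integer."""
    value: int


@dataclass
class OpResult:
    """Stand-in for an SSA value defined by an op."""
    op: object


@dataclass
class BlockArgument:
    """Stand-in for an SSA value that no op defines."""
    index: int


def match_constant(value: OpResult | BlockArgument) -> int | None:
    """Match step: return the constant behind `value`, or None if there is none."""
    if not isinstance(value, OpResult):
        return None
    # `value` is now narrowed to OpResult, so `.op` is known to exist.
    defining_op = value.op
    if not isinstance(defining_op, Constant):
        return None
    return defining_op.value


def trip_count(lb, ub, step) -> int | None:
    """Apply step: only computes a result once all three bounds matched."""
    lo, hi, st = match_constant(lb), match_constant(ub), match_constant(step)
    if lo is None or hi is None or st is None or st <= 0:
        return None  # pattern does not apply; the caller can decide whether to warn
    return max(0, -(-(hi - lo) // st))  # ceiling division


applied = trip_count(OpResult(Constant(0)), OpResult(Constant(10)), OpResult(Constant(3)))
skipped = trip_count(BlockArgument(0), OpResult(Constant(10)), OpResult(Constant(1)))
```

Returning `None` keeps the non-matching case explicit, which is one way a caller could report not-applied patterns instead of failing silently or aborting.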
Just a couple of things from me
@@ -400,7 +400,7 @@ def get_element_type(self):
        return self.res.type.get_element_type()

    @staticmethod
    def from_type(child_type: TypeAttribute) -> VariableOp:
Should this not be a TypeAttribute?
Then pyright complains in this line: [csl.VariableOp.from_type(arg_t) for arg_t in op.iter_args.types]
This seems like the wrong place to fix this then..?
Although, @AntonLydike, is there a reason some types do not inherit TypeAttribute? In which case this should be an Attribute.
Well, good question. All types should inherit from TypeAttribute, as that will have them be printed as !dialect.name instead of #dialect.name (which is attribute specific and can't be used as a type according to MLIR).
I've added some functionality (and a test) to support non-loop, sequential stencil apply ops.
Looks good AFAICT
Adds a pass to handle the async control flow of csl_stencil.apply and any enclosing loops by translating control flow into a csl.func call graph.
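To illustrate the general idea of turning a loop around an asynchronous op into a graph of functions, here is a toy Python model. It is not the pass's actual output and not CSL: the function names (`for_cond`, `for_body`, `for_inc`, `for_post`), the `async_exchange` stand-in, and the callback queue are all assumptions made for the sketch.

```python
from collections import deque
from typing import Callable

# Hypothetical stand-in for the runtime: queued callbacks run after the caller returns.
pending: deque[Callable[[], None]] = deque()
trace: list[str] = []

LB, UB, STEP = 0, 3, 1  # constant loop bounds, as the pass currently requires
state = {"i": LB}       # loop counter kept outside the functions


def async_exchange(callback: Callable[[], None]) -> None:
    """Stand-in for an async op: returns immediately, fires `callback` later."""
    pending.append(callback)


def for_cond() -> None:
    # Loop header: either enter the body or leave the loop.
    if state["i"] < UB:
        for_body()
    else:
        for_post()


def for_body() -> None:
    trace.append(f"body {state['i']}")
    # The rest of the iteration cannot follow inline; it continues in the callback.
    async_exchange(for_inc)


def for_inc() -> None:
    state["i"] += STEP
    for_cond()


def for_post() -> None:
    trace.append("done")


def run(entry: Callable[[], None]) -> None:
    """Drive the call graph until no callbacks remain."""
    entry()
    while pending:
        pending.popleft()()


run(for_cond)
```

Because the async op returns before it completes, the loop back-edge cannot be ordinary control flow; in this model it becomes the chain `for_body` → callback `for_inc` → `for_cond`.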