-
Notifications
You must be signed in to change notification settings - Fork 77
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
transformations: Support constant inits in memref_stream.generic lowering [2/3] #2764
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #2764 +/- ##
=======================================
Coverage 89.79% 89.80%
=======================================
Files 381 381
Lines 48331 48341 +10
Branches 7404 7407 +3
=======================================
+ Hits 43400 43411 +11
+ Misses 3773 3772 -1
Partials 1158 1158 ☔ View full report in Codecov by Sentry. |
…ttom-up-constants
…ref_stream/const-init
…m/const-init-lowering
…m/const-init-lowering
…m/const-init-lowering
…m/const-init-lowering
…m/const-init-lowering
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Neato!
%c0 = arith.constant 0 : index | ||
%c1 = arith.constant 1 : index | ||
%c3 = arith.constant 3 : index | ||
%c6 = arith.constant 6 : index | ||
|
||
scf.for %i0 = %c0 to %c1 step %c1 { | ||
scf.for %i1 = %c0 to %c1 step %c1 { | ||
scf.for %i2 = %c0 to %c6 step %c1 { | ||
scf.for %i3 = %c0 to %c6 step %c1 { | ||
%z = scf.for %i = %c0 to %c3 step %c1 iter_args(%acc0 = %zero_float) -> (f64) { | ||
%z3 = scf.for %j = %c0 to %c3 step %c1 iter_args(%acc1 = %acc0) -> (f64) { | ||
%x = memref_stream.read from %x_stream : f64 | ||
%y = memref_stream.read from %y_stream : f64 | ||
%prod = arith.mulf %x, %y fastmath<fast> : f64 | ||
%res = arith.addf %prod, %acc1 fastmath<fast> : f64 | ||
scf.yield %res : f64 | ||
} | ||
scf.yield %z3 : f64 | ||
} | ||
|
||
memref_stream.write %z to %z_stream : f64 | ||
} | ||
} | ||
} | ||
memref_stream.generic { | ||
bounds = [#builtin.int<1>, #builtin.int<1>, #builtin.int<6>, #builtin.int<6>, #builtin.int<1>, #builtin.int<3>, #builtin.int<3>], | ||
indexing_maps = [ | ||
affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>, | ||
affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>, | ||
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> | ||
], | ||
iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"] | ||
} ins(%x_stream, %y_stream : !stream.readable<f64>, !stream.readable<f64>) outs(%z_stream : !stream.writable<f64>) inits(%zero_float : f64) { | ||
^0(%x : f64, %y : f64, %acc : f64): | ||
%prod = arith.mulf %x, %y fastmath<fast> : f64 | ||
%res = arith.addf %prod, %acc fastmath<fast> : f64 | ||
memref_stream.yield %res : f64 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
😱 Amazing ❤️
} ins(%x_stream : !stream.readable<f64>) outs(%y_stream : !stream.writable<f64>) inits(%min_val : f64) { | ||
^0(%x : f64, %acc : f64): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Quick question on the semantics of this:
You have two reduction dimensions, but only one reduction parameter int he loop, and only one initial value. Is there a place I can read up on how this works exactly?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this follows the linalg.generic quite closely. The dimensions and operands are related by the shapes of the operands. In this case, there are two reduction dimensions because the pooling happens over a 3x3 mini tile. The reduction dimensions are not present in the shapes of the ins or outs. I'd say the documentation for the linalg dialect and the linalg.generic op specifically is probably the best place to look.
This PR adds support for lowering
memref_stream.generic
s with inits to loops. It also updates the bottom-up tests to leverage this feature. As you can see, there is no functional change in the final assembly, modulo a single register difference.I updated the
insert_load
interface in the loop lowering helpers to communicate the index of the operand that for which the load is being inserted, this allows us to use the const value directly.Note stacked PR.