-
Notifications
You must be signed in to change notification settings - Fork 83
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
transformations: Support constant inits in memref_stream.generic lowering [2/3] #2764
Changes from all commits
8896dab
f89d526
299de80
2e5a024
b2acbca
2fb3274
c2a03e6
7f12d54
7b7f027
7378f6d
3f70240
333ce98
a207627
f9ae2aa
aed0cd4
182f2c0
968c53d
81018b5
f114076
9ecbc32
73b0ec9
a80c473
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,30 +14,19 @@ func.func public @conv_2d_nchw_fchw_d1_s1_3x3( | |
] | ||
} ins(%X, %Y : memref<1x1x8x8xf64>, memref<1x1x3x3xf64>) outs(%Z : memref<1x1x6x6xf64>) { | ||
^0(%x_stream : !stream.readable<f64>, %y_stream : !stream.readable<f64>, %z_stream : !stream.writable<f64>): | ||
%c0 = arith.constant 0 : index | ||
%c1 = arith.constant 1 : index | ||
%c3 = arith.constant 3 : index | ||
%c6 = arith.constant 6 : index | ||
|
||
scf.for %i0 = %c0 to %c1 step %c1 { | ||
scf.for %i1 = %c0 to %c1 step %c1 { | ||
scf.for %i2 = %c0 to %c6 step %c1 { | ||
scf.for %i3 = %c0 to %c6 step %c1 { | ||
%z = scf.for %i = %c0 to %c3 step %c1 iter_args(%acc0 = %zero_float) -> (f64) { | ||
%z3 = scf.for %j = %c0 to %c3 step %c1 iter_args(%acc1 = %acc0) -> (f64) { | ||
%x = memref_stream.read from %x_stream : f64 | ||
%y = memref_stream.read from %y_stream : f64 | ||
%prod = arith.mulf %x, %y fastmath<fast> : f64 | ||
%res = arith.addf %prod, %acc1 fastmath<fast> : f64 | ||
scf.yield %res : f64 | ||
} | ||
scf.yield %z3 : f64 | ||
} | ||
|
||
memref_stream.write %z to %z_stream : f64 | ||
} | ||
} | ||
} | ||
memref_stream.generic { | ||
bounds = [#builtin.int<1>, #builtin.int<1>, #builtin.int<6>, #builtin.int<6>, #builtin.int<1>, #builtin.int<3>, #builtin.int<3>], | ||
indexing_maps = [ | ||
affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>, | ||
affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>, | ||
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> | ||
], | ||
iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"] | ||
} ins(%x_stream, %y_stream : !stream.readable<f64>, !stream.readable<f64>) outs(%z_stream : !stream.writable<f64>) inits(%zero_float : f64) { | ||
^0(%x : f64, %y : f64, %acc : f64): | ||
%prod = arith.mulf %x, %y fastmath<fast> : f64 | ||
%res = arith.addf %prod, %acc fastmath<fast> : f64 | ||
memref_stream.yield %res : f64 | ||
} | ||
} | ||
|
||
|
@@ -413,29 +402,17 @@ func.func public @pooling_nchw_max_d1_s2_3x3( | |
] | ||
} ins(%X : memref<1x1x16x16xf64>) outs(%Y : memref<1x1x7x7xf64>) { | ||
^0(%x_stream : !stream.readable<f64>, %y_stream : !stream.writable<f64>): | ||
%c0 = arith.constant 0 : index | ||
%c1 = arith.constant 1 : index | ||
%c3 = arith.constant 3 : index | ||
%c7 = arith.constant 7 : index | ||
%c512 = arith.constant 512 : index | ||
|
||
scf.for %i0 = %c0 to %c1 step %c1 { | ||
scf.for %i1 = %c0 to %c1 step %c1 { | ||
scf.for %i2 = %c0 to %c7 step %c1 { | ||
scf.for %i3 = %c0 to %c7 step %c1 { | ||
%y = scf.for %i = %c0 to %c3 step %c1 iter_args(%acc0 = %min_val) -> (f64) { | ||
%y3 = scf.for %j = %c0 to %c3 step %c1 iter_args(%acc1 = %acc0) -> (f64) { | ||
%x = memref_stream.read from %x_stream : f64 | ||
%res = arith.maximumf %x, %acc1 : f64 | ||
scf.yield %res : f64 | ||
} | ||
scf.yield %y3 : f64 | ||
} | ||
|
||
memref_stream.write %y to %y_stream : f64 | ||
} | ||
} | ||
} | ||
memref_stream.generic { | ||
bounds = [#builtin.int<1>, #builtin.int<1>, #builtin.int<7>, #builtin.int<7>, #builtin.int<3>, #builtin.int<3>], | ||
indexing_maps = [ | ||
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 * 2 + d4, d3 * 2 + d5)>, | ||
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> | ||
], | ||
iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"] | ||
} ins(%x_stream : !stream.readable<f64>) outs(%y_stream : !stream.writable<f64>) inits(%min_val : f64) { | ||
^0(%x : f64, %acc : f64): | ||
Comment on lines
+412
to
+413
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Quick question on the semantics of this: You have two reduction dimensions, but only one reduction parameter in the loop, and only one initial value. Is there a place I can read up on how this works exactly? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This follows linalg.generic quite closely. The dimensions and operands are related by the shapes of the operands. In this case, there are two reduction dimensions because the pooling happens over a 3x3 mini tile. The reduction dimensions are not present in the shapes of the ins or outs. I'd say the documentation for the linalg dialect and the linalg.generic op specifically is probably the best place to look. |
||
%res = arith.maximumf %x, %acc : f64 | ||
memref_stream.yield %res : f64 | ||
} | ||
} | ||
|
||
|
@@ -545,29 +522,17 @@ func.func public @pooling_nchw_sum_d1_s2_3x3( | |
] | ||
} ins(%X : memref<1x1x16x16xf64>) outs(%Y : memref<1x1x7x7xf64>) { | ||
^0(%x_stream : !stream.readable<f64>, %y_stream : !stream.writable<f64>): | ||
%c0 = arith.constant 0 : index | ||
%c1 = arith.constant 1 : index | ||
%c3 = arith.constant 3 : index | ||
%c7 = arith.constant 7 : index | ||
%c512 = arith.constant 512 : index | ||
|
||
scf.for %i0 = %c0 to %c1 step %c1 { | ||
scf.for %i1 = %c0 to %c1 step %c1 { | ||
scf.for %i2 = %c0 to %c7 step %c1 { | ||
scf.for %i3 = %c0 to %c7 step %c1 { | ||
%y = scf.for %i = %c0 to %c3 step %c1 iter_args(%acc0 = %zero_float) -> (f64) { | ||
%y3 = scf.for %j = %c0 to %c3 step %c1 iter_args(%acc1 = %acc0) -> (f64) { | ||
%x = memref_stream.read from %x_stream : f64 | ||
%res = arith.addf %x, %acc1 : f64 | ||
scf.yield %res : f64 | ||
} | ||
scf.yield %y3 : f64 | ||
} | ||
|
||
memref_stream.write %y to %y_stream : f64 | ||
} | ||
} | ||
} | ||
memref_stream.generic { | ||
bounds = [#builtin.int<1>, #builtin.int<1>, #builtin.int<7>, #builtin.int<7>, #builtin.int<3>, #builtin.int<3>], | ||
indexing_maps = [ | ||
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 * 2 + d4, d3 * 2 + d5)>, | ||
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)> | ||
], | ||
iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"] | ||
} ins(%x_stream : !stream.readable<f64>) outs(%y_stream : !stream.writable<f64>) inits(%zero_float : f64) { | ||
^0(%x : f64, %acc : f64): | ||
%res = arith.addf %x, %acc : f64 | ||
memref_stream.yield %res : f64 | ||
} | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
😱 Amazing ❤️