[TKW] Rework vector mask generation #172
Conversation
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
lgtm, modulo some minor comments
# CHECK-SAME: memref<1x3xf16, strided<[3, 1], offset: ?>>
# CHECK: vector.maskedstore %[[D27]][%[[D5]], %[[D8]]], %[[D25]], %[[D26]] :
# CHECK-SAME: memref<1x3xf16, strided<[3, 1], offset: ?>>, vector<4xi1>, vector<4xf16>
# CHECK-DAG: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4xf16>
nice!
) -> Optional[list[IndexExpr]]:
    bounds = []
    for constraint in constraints:
        if not isinstance(constraint, (WorkgroupConstraint, TilingConstraint)):
Why do we ignore the WaveConstraints here? Does that mean we are assuming that the workgroup tile size is a multiple of the wave tile size?
Yeah, I implicitly assumed the WG size is divisible by the wave size. Do we have any potential examples where that's false?
Good question. I was just thinking in terms of generality (so workgroup tile size = 27 and wave tile size = 17), but maybe it's not such a common use case. We can ignore it for now.
As I am reading through some of the history, allow me to share a thought.
It is generally good practice to be as exhaustive and general as possible and to fail loudly.
This way, in the future we can come back and immediately understand the problem and that it is NYI.
In the current form, it seems to me the case proposed by Harsh would silently pass but generate wrong code?
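For example, a loud check along these lines would turn the silent miscompile into an explicit NYI error. This is only a sketch: the function and parameter names are hypothetical and not the actual constraint API.

```python
# Sketch only: hypothetical names, not the actual TKW constraint API.
def check_wave_divisibility(workgroup_tile_size: int, wave_tile_size: int) -> None:
    # The mask generation assumes the workgroup tile splits evenly across
    # waves; reject anything else instead of generating wrong code.
    if workgroup_tile_size % wave_tile_size != 0:
        raise NotImplementedError(
            f"NYI: workgroup tile size {workgroup_tile_size} is not a "
            f"multiple of wave tile size {wave_tile_size}"
        )
```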
pos = arith_d.ConstantOp(IndexType.get(), i)
mask = vector_d.insertelement(cond, mask, position=pos)
mask_expr = functools.reduce(
very cool!
@@ -171,6 +172,32 @@ def get_type_or_element_type(operand_type: IrType):
def gen_sympy_index(emitter: WaveEmitter, expr: sympy.Expr) -> OpResult:
    stack: list[OpResult] = []

def _broadcast(a, b):
Seems like you can refactor this to avoid the duplication for a and b.
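Something along these lines, as a sketch only (plain Python lists stand in for the MLIR vector values, and `_splat` is a hypothetical helper), so both operands go through one code path:

```python
# Sketch of the suggested refactor, not the emitter code.
def _splat(value, length):
    # Hypothetical stand-in for splatting a scalar to a vector of `length`.
    return value if isinstance(value, list) else [value] * length

def _broadcast(a, b, length=4):
    # Apply the same normalization to both operands once, instead of
    # duplicating the scalar/vector handling for `a` and `b` separately.
    a, b = (_splat(x, length) for x in (a, b))
    return a, b
```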
mask_expr = functools.reduce(
    lambda a, b: sympy.And(a, b), (new_index[dim] < dim for dim in bounds)
)
mask = gen_sympy_index(emitter, mask_expr)
No action needed, but just putting this down for the record: at some point we should evaluate generating these masks prior to codegen and measure the performance impact.
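For reference, the same mask construction can be run standalone in sympy. The symbol names below are illustrative; in the emitter, `new_index` and `bounds` come from the constraints.

```python
import functools
import sympy

# Illustrative stand-ins for the per-dimension index and bound symbols.
i, j = sympy.symbols("i j", integer=True)
M, N = sympy.symbols("M N", integer=True, positive=True)
new_index = {M: i, N: j}
bounds = [M, N]

# Conjunction of one "index < bound" comparison per masked dimension.
mask_expr = functools.reduce(
    lambda a, b: sympy.And(a, b), (new_index[dim] < dim for dim in bounds)
)
print(mask_expr)  # e.g. (i < M) & (j < N)
```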
Instead of generating individual element comparisons and doing `vector.insertelement`, generate the whole mask using vector ops. Add support for vector codegen when generating MLIR IR from sympy expressions. Add method `IndexingContext.iota` to generate special symbols which map to `(1,2 ... n-1)` vec expressions. `gen_sympy_index` will start to generate vector ops when encountering such symbols, inserting proper `splat`s between scalar vals when necessary.

---------

Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
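As an illustration of the change in mask construction, here is a plain-Python/numpy sketch (the values are made up and this is not the emitter code; it only contrasts per-element insertion with a single vector comparison over an iota-style lane vector):

```python
import numpy as np

elements_per_thread = 4
base = 1   # illustrative per-thread start index
bound = 3  # illustrative dimension bound

# Old approach: compare lane by lane and insert each bit into the mask,
# analogous to the vector.insertelement loop.
mask_scalar = np.zeros(elements_per_thread, dtype=bool)
for i in range(elements_per_thread):
    mask_scalar[i] = base + i < bound

# New approach: build the whole mask at once from an iota-style lane
# vector, analogous to splatting the scalars and doing one vector compare.
iota = np.arange(elements_per_thread)
mask_vector = (base + iota) < bound

assert (mask_scalar == mask_vector).all()
```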