Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TKW] Rework vector mask generation #172

Merged
merged 6 commits into from
Sep 30, 2024

Conversation

Hardcode84
Copy link
Contributor

Instead of generating individual element comparisons and doing vector.insertelement generate the whole mask using vector ops.

Add support for vector codegen when generating MLIR IR from sympy expressions. Add method IndexingContext.iota to generate special symbols which map to (1,2 ... n-1) vec expressions. gen_sympy_index will start to generate vector ops when encountering such symbols, inserting proper splat's between scalar vals when necessary.

Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
@Hardcode84 Hardcode84 marked this pull request as ready for review September 27, 2024 14:55
Copy link
Contributor

@harsh-nod harsh-nod left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

lgtm, modulo some minor comments

# CHECK-SAME: memref<1x3xf16, strided<[3, 1], offset: ?>>
# CHECK: vector.maskedstore %[[D27]][%[[D5]], %[[D8]]], %[[D25]], %[[D26]] :
# CHECK-SAME: memref<1x3xf16, strided<[3, 1], offset: ?>>, vector<4xi1>, vector<4xf16>
# CHECK-DAG: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4xf16>
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice!

) -> Optional[list[IndexExpr]]:
bounds = []
for constraint in constraints:
if not isinstance(constraint, (WorkgroupConstraint, TilingConstraint)):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we ignore the WaveConstraints here? Does that mean that we assuming that the workgroup tile size is a multiple of the wave tile size?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I implicitly assumed WG size is divisible by wave size, do we have any potential examples when it's false?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good question. I was just thinking in terms of generality (so workgroup tile size = 27 and wave tile size = 17) but maybe its not such a common use case. We can ignore for now.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As I am reading through some of the history, allow me to share a thought.
It is generally good practice to be as exhaustive and general as possible and failing loudly.
This way, in the future we can come back and immediately understand the problem and that it is NYI.
In the the current form, it seems to me the case proposed by Harsh would silently pass but generate wrong code?


pos = arith_d.ConstantOp(IndexType.get(), i)
mask = vector_d.insertelement(cond, mask, position=pos)
mask_expr = functools.reduce(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

very cool!

@@ -171,6 +172,32 @@ def get_type_or_element_type(operand_type: IrType):
def gen_sympy_index(emitter: WaveEmitter, expr: sympy.Expr) -> OpResult:
stack: list[OpResult] = []

def _broadcast(a, b):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Seems like you can refactor this to avoid the duplication for a and b.

mask_expr = functools.reduce(
lambda a, b: sympy.And(a, b), (new_index[dim] < dim for dim in bounds)
)
mask = gen_sympy_index(emitter, mask_expr)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No action needed, but just putting this down for the record that at some point we should evaluate generating these masks prior to codegen and evaluate the performance impact.

@Hardcode84 Hardcode84 merged commit 92ad900 into iree-org:main Sep 30, 2024
8 checks passed
@Hardcode84 Hardcode84 deleted the masking-vectorize branch September 30, 2024 14:46
IanNod pushed a commit to IanNod/iree-turbine that referenced this pull request Sep 30, 2024
Instead of generating individual element comparisons and doing
`vector.insertelement` generate the whole mask using vector ops.

Add support for vector codegen when generating MLIR IR from sympy
expressions. Add method `IndexingContext.iota` to generate special
symbols which map to `(1,2 ... n-1)` vec expressions. `gen_sympy_index`
will start to generate vector ops when encountering such symbols,
inserting proper `splat`'s between scalar vals when necessary.

---------

Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
IanNod pushed a commit to IanNod/iree-turbine that referenced this pull request Sep 30, 2024
Instead of generating individual element comparisons and doing
`vector.insertelement` generate the whole mask using vector ops.

Add support for vector codegen when generating MLIR IR from sympy
expressions. Add method `IndexingContext.iota` to generate special
symbols which map to `(1,2 ... n-1)` vec expressions. `gen_sympy_index`
will start to generate vector ops when encountering such symbols,
inserting proper `splat`'s between scalar vals when necessary.

---------

Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
Signed-off-by: Ian <ian.nordeng@amd.com>
stellaraccident pushed a commit that referenced this pull request Oct 13, 2024
Instead of generating individual element comparisons and doing
`vector.insertelement` generate the whole mask using vector ops.

Add support for vector codegen when generating MLIR IR from sympy
expressions. Add method `IndexingContext.iota` to generate special
symbols which map to `(1,2 ... n-1)` vec expressions. `gen_sympy_index`
will start to generate vector ops when encountering such symbols,
inserting proper `splat`'s between scalar vals when necessary.

---------

Signed-off-by: Ivan Butygin <ivan.butygin@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants