[WIP]: Jax tensor donation #3224
base: main
Codecov Report
All modified and coverable lines are covered by tests ✅
Additional details and impacted files
@@            Coverage Diff             @@
##             main    #3224      +/-   ##
==========================================
+ Coverage   90.01%   90.03%   +0.02%
==========================================
  Files         445      446       +1
  Lines       55850    55953     +103
  Branches     5351     5357       +6
==========================================
+ Hits        50274    50378     +104
  Misses       4169     4169
+ Partials     1407     1406       -1
Thank you for this, it's a great first step! A few bigger comments before the nitty-gritty:
- I'm a little confused by the name of the pass: it's not really converting JAX to linalg, just making sure that the operands marked as donating are actually reused as destinations, so another name feels more appropriate. (I would recommend a prefix other than convert, as that's usually used for converting one dialect to another, and in this case it's more of an optimisation.) Maybe jax-use-donated-arguments?
- The test is quite specific, and uses operations from dialects that aren't strictly involved. I would recommend using the "test" dialect as much as possible to generate values that are necessary for the test.
- It's quite a bit more difficult to review PRs with lots of Pyright errors. Could you please fix the errors locally and ping me again to take a look?
- It would be good to be very clear about the limitations of the proposed approach, and to raise helpful messages if used in an unexpected context. Could you please add DiagnosticExceptions if the assumptions are not met, and tests for those?
xdsl/tools/command_line_tool.py
Outdated
@@ -86,6 +86,11 @@ def get_convert_stencil_to_ll_mlir():

    return convert_stencil_to_ll_mlir.ConvertStencilToLLMLIRPass

def get_jax_use_donated_arguments():
    from xdsl.transforms.experimental import jax_use_donated_arguments
I would recommend against experimental: it's a place for untested code that we will probably remove at some point. Let's just put it in the main transforms folder.
donated_inputs: list[BlockArgument] = []
for inp, attr in zip(op.args, op.arg_attrs):
    if type(inp.type) is TensorType and "tf.aliasing_output" in attr.data:
        donated_inputs.append(inp)
Does this work?
-donated_inputs: list[BlockArgument] = []
-for inp, attr in zip(op.args, op.arg_attrs):
-    if type(inp.type) is TensorType and "tf.aliasing_output" in attr.data:
-        donated_inputs.append(inp)
+donated_inputs = [
+    inp
+    for inp, attr in zip(op.args, op.arg_attrs)
+    if isinstance(inp.type, TensorType) and "tf.aliasing_output" in attr.data
+]
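The suggestion changes two things at once: the loop becomes a comprehension, and the exact type check becomes an `isinstance` check. A minimal sketch of how the two checks can differ follows. The classes here (`TensorType`, `RankedTensorType`, `Arg`) are hypothetical stand-ins written for this illustration, not the xDSL definitions, and the attribute dictionary is a plain `dict` rather than an xDSL attribute.

```python
from dataclasses import dataclass, field


class TensorType:
    """Hypothetical stand-in for a tensor type (not the xDSL class)."""


class RankedTensorType(TensorType):
    """Hypothetical subclass, used to show where the two checks differ."""


@dataclass
class Arg:
    """Hypothetical function argument: a type plus its argument attributes."""

    ty: object
    attrs: dict[str, int] = field(default_factory=dict)


args = [
    Arg(TensorType(), {"tf.aliasing_output": 0}),
    Arg(RankedTensorType(), {"tf.aliasing_output": 1}),
    Arg(TensorType()),  # tensor, but not marked as donated
    Arg(42, {"tf.aliasing_output": 2}),  # marked, but not a tensor
]

# Exact type comparison: instances of subclasses are rejected.
exact = [
    a for a in args if type(a.ty) is TensorType and "tf.aliasing_output" in a.attrs
]

# isinstance: instances of subclasses are accepted as well.
donated = [
    a for a in args if isinstance(a.ty, TensorType) and "tf.aliasing_output" in a.attrs
]
```

In this sketch `exact` keeps only the first argument, while `donated` keeps the first two. Whether the distinction matters in the actual pass depends on whether the real tensor type has subclasses, which this example does not establish.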
    walk_regions_first=True,
)
the_one_pass.rewrite_module(op)
MLIROptPass(arguments=("--linalg-fuse-elementwise-ops",)).apply(ctx, op)
This should be done separately by the user, and not included in this pass
There's no need to put the test in the with-mlir folder. It's better to keep it as a pure xDSL test, and let the user call into mlir-opt if they want to.
arg_attrs = getattr(func_op, "arg_attrs")
args = getattr(func_op, "args")
We try to avoid getattr when possible. Could you please check that the parent is a func, and then access its properties directly?
xdsl/dialects/tensor.py
Outdated
@irdl_op_definition
class CollapseShapeOp(IRDLOperation):
    name = "tensor.collapse_shape"

    src = operand_def(TensorType[Attribute])
    result = result_def(TensorType[Attribute])
    reassociation = prop_def(ReassociationAttr)
    assembly_format = (
        "$src $reassociation attr-dict `:` type($src) `into` type($result)"
    )

    traits = frozenset([NoMemoryEffect()])
I don't think we need this as part of this PR
Would be great to add in a separate PR, though
xdsl/dialects/bufferization.py
Outdated
@@ -168,12 +168,33 @@ class ToMemrefOp(IRDLOperation):
    assembly_format = "$tensor (`read_only` $read_only^)? `:` attr-dict type($memref)"


# now only works for (tensor, tensor) arguments. need to add memref support as well.
@irdl_op_definition
class MaterializeInDestination(IRDLOperation):
Can you please add this in a separate PR?
// CHECK-NEXT: }
// CHECK-NEXT: }

builtin.module {
-builtin.module {
And in other tests
// RUN: xdsl-opt %s -p jax-use-donated-arguments --split-input-file --verify-diagnostics | filecheck %s

builtin.module {
func.func public @main(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>, %arg2: tensor<2x4xf32> {tf.aliasing_output = 0 : i32}) -> (tensor<2x4xf32>) {
I don't think all the functions have to be called "main", and we can use the function name to document the purpose of the test, like so:
-func.func public @main(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>, %arg2: tensor<2x4xf32> {tf.aliasing_output = 0 : i32}) -> (tensor<2x4xf32>) {
+func.func public @one_arg(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>, %arg2: tensor<2x4xf32> {tf.aliasing_output = 0 : i32}) -> (tensor<2x4xf32>) {
xdsl/dialects/bufferization.py
Outdated
    name = "bufferization.materialize_in_destination"

    source = operand_def(
        TensorMemrefInferenceConstraint("T", AnyOf([TensorType, UnrankedTensorType]))
-        TensorMemrefInferenceConstraint("T", AnyOf([TensorType, UnrankedTensorType]))
+        TensorMemrefInferenceConstraint("T", AnyTensorTypeConstr | AnyUnrankedTensorTypeConstr)
(will need to add the second constraint definition in builtin.py)

donated_inputs = [
    inp
    for inp, attr in zip(func_op.args, func_op.arg_attrs)
-    for inp, attr in zip(func_op.args, func_op.arg_attrs)
+    for inp, attr in zip(func_op.args, func_op.arg_attrs, strict=True)
@op_type_rewrite_pattern
def match_and_rewrite(self, op: Return, rewriter: PatternRewriter, /):
    func_op = op.parent_op()
    if func_op is None or type(func_op) is not FuncOp:
-    if func_op is None or type(func_op) is not FuncOp:
+    assert isinstance(func_op, FuncOp)
The input can be assumed to be verified. There's no way to handle a VerifyException here anyway, so we might as well assert.
93a8703 to 7869d09
This is still a work in progress; I just wanted to get some feedback on this version as well.
The main area for improvement: the donation logic should be more in line with the JAX documentation.
There are currently two problems here: