
[WIP]: Jax tensor donation #3224

Draft: mamanain wants to merge 22 commits into main

Conversation

mamanain (Collaborator) commented:

This is still a work in progress; I just wanted to get some feedback on this version as well.

The main area for improvement: the donation logic should be brought more in line with the JAX documentation (a short example of the intended semantics follows the list).
There are currently two problems here:

  1. A donated buffer can still be used later in the function, but here it is reused as an output buffer straight away. Its data can then be overwritten and read by subsequent computations, which will lead to wrong results. So we need to change the logic so that a buffer becomes available only after its last use.
  2. Right now a buffer can be reused only once. Should we keep it in the dictionary and reuse it in other places where it fits?
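
For context, this is the donation contract the pass needs to honour. A minimal JAX-side sketch (the function and shapes are illustrative, not from this PR):

    from functools import partial

    import jax
    import jax.numpy as jnp

    # Donating x tells XLA it may reuse x's device buffer for an output.
    # After the call, x must not be read again -- exactly the hazard in
    # point 1: reusing a donated buffer before its last use corrupts reads.
    @partial(jax.jit, donate_argnums=(0,))
    def update(x, y):
        return x + y

    x = jnp.ones((2, 4))
    y = jnp.ones((2, 4))
    z = update(x, y)  # x's buffer may now back z; reading x afterwards is invalid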

codecov bot commented Sep 27, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 90.03%. Comparing base (70fa878) to head (8fc4d95).
Report is 1 commit behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #3224      +/-   ##
==========================================
+ Coverage   90.01%   90.03%   +0.02%     
==========================================
  Files         445      446       +1     
  Lines       55850    55953     +103     
  Branches     5351     5357       +6     
==========================================
+ Hits        50274    50378     +104     
  Misses       4169     4169              
+ Partials     1407     1406       -1     


superlopuh marked this pull request as draft on September 28, 2024, 22:50.
superlopuh (Member) left a comment:

Thank you for this, it's a great first step! A few bigger comments before the nitty-gritty:

  • I'm a little confused by the name of the pass: it's not really converting JAX to linalg, just making sure that the operands marked as donated are actually reused as destinations, so another name feels more appropriate. (I would recommend a pass prefix other than convert, as that's usually used for converting one dialect to another; this is more of an optimisation.) Maybe jax-use-donated-arguments?
  • The test is quite specific, and uses operations from dialects that aren't strictly involved. I would recommend using the "test" dialect as much as possible to generate the values that are needed for the test.
  • It's quite a bit more difficult to review PRs with lots of Pyright errors; could you please fix the errors locally and ping me again to take a look?
  • It would be good to be very clear about the limitations of the proposed approach, and to raise helpful messages when the pass is used in an unexpected context. Could you please add DiagnosticExceptions where the assumptions are not met, plus tests for those? (A minimal sketch follows.)
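
A minimal sketch of such a guard, assuming xDSL's DiagnosticException from xdsl.utils.exceptions (the checks and messages here are illustrative, not the PR's actual code):

    from xdsl.dialects.builtin import TensorType
    from xdsl.dialects.func import FuncOp
    from xdsl.utils.exceptions import DiagnosticException

    def check_donation_assumptions(func_op: FuncOp) -> None:
        # Fail loudly instead of silently miscompiling when the pass runs in
        # a context it doesn't support.
        if func_op.arg_attrs is None:
            raise DiagnosticException(
                "jax-use-donated-arguments expects argument attributes on the function"
            )
        for arg, attr in zip(func_op.args, func_op.arg_attrs):
            if "tf.aliasing_output" in attr.data and not isinstance(arg.type, TensorType):
                raise DiagnosticException(
                    f"donated argument {arg.index} must be a ranked tensor, got {arg.type}"
                )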

@@ -86,6 +86,11 @@ def get_convert_stencil_to_ll_mlir():

return convert_stencil_to_ll_mlir.ConvertStencilToLLMLIRPass

def get_jax_use_donated_arguments():
from xdsl.transforms.experimental import jax_use_donated_arguments
superlopuh (Member):

I would recommend against experimental; it's a place for untested code that we will probably remove at some point. Let's just put it in the main transforms folder.

Comment on lines 27 to 30
donated_inputs: list[BlockArgument] = []
for inp, attr in zip(op.args, op.arg_attrs):
    if type(inp.type) is TensorType and "tf.aliasing_output" in attr.data:
        donated_inputs.append(inp)
superlopuh (Member):

Does this work?

Suggested change
- donated_inputs: list[BlockArgument] = []
- for inp, attr in zip(op.args, op.arg_attrs):
-     if type(inp.type) is TensorType and "tf.aliasing_output" in attr.data:
-         donated_inputs.append(inp)
+ donated_inputs = [
+     inp
+     for inp, attr in zip(op.args, op.arg_attrs)
+     if isinstance(inp.type, TensorType) and "tf.aliasing_output" in attr.data
+ ]

walk_regions_first=True,
)
the_one_pass.rewrite_module(op)
MLIROptPass(arguments=("--linalg-fuse-elementwise-ops",)).apply(ctx, op)
superlopuh (Member):

This should be done separately by the user, and not included in this pass. (Example invocation below.)
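
The user-side pipeline would then be something like this (a hypothetical invocation; the file name is a placeholder):

    xdsl-opt input.mlir -p jax-use-donated-arguments | mlir-opt --linalg-fuse-elementwise-ops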

superlopuh (Member):

There's no need to put the test in the with-mlir folder; it's better to keep it a pure xDSL test and let the user call into mlir-opt if they want to.

Comment on lines 32 to 33
arg_attrs = getattr(func_op, "arg_attrs")
args = getattr(func_op, "args")
superlopuh (Member):

We try to avoid getattr when possible; could you please check that the parent is a func and then access the properties directly, as in the sketch below?
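
A minimal sketch of that pattern, mirroring the assert suggested further down in this review:

    from xdsl.dialects.func import FuncOp

    func_op = op.parent_op()
    # A verified Return always lives directly inside a FuncOp, so assert and
    # then read the properties directly instead of going through getattr.
    assert isinstance(func_op, FuncOp)
    arg_attrs = func_op.arg_attrs
    args = func_op.args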

Comment on lines 188 to 199
@irdl_op_definition
class CollapseShapeOp(IRDLOperation):
    name = "tensor.collapse_shape"

    src = operand_def(TensorType[Attribute])
    result = result_def(TensorType[Attribute])
    reassociation = prop_def(ReassociationAttr)
    assembly_format = (
        "$src $reassociation attr-dict `:` type($src) `into` type($result)"
    )

    traits = frozenset([NoMemoryEffect()])
superlopuh (Member):

I don't think we need this as part of this PR

superlopuh (Member):

Would be great to add in a separate PR, though

@@ -168,12 +168,33 @@ class ToMemrefOp(IRDLOperation):
    assembly_format = "$tensor (`read_only` $read_only^)? `:` attr-dict type($memref)"


# now only works for (tensor, tensor) arguments. need to add memref support as well.
@irdl_op_definition
class MaterializeInDestination(IRDLOperation):
superlopuh (Member):

Can you please add this in a separate PR?

// CHECK-NEXT: }
// CHECK-NEXT: }

builtin.module {
superlopuh (Member):

Suggested change
- builtin.module {

And in the other tests.

// RUN: xdsl-opt %s -p jax-use-donated-arguments --split-input-file --verify-diagnostics | filecheck %s

builtin.module {
func.func public @main(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>, %arg2: tensor<2x4xf32> {tf.aliasing_output = 0 : i32}) -> (tensor<2x4xf32>) {
superlopuh (Member):

I don't think all the functions have to be called "main", and we can use the function name to document the purpose of the test, like so:

Suggested change
- func.func public @main(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>, %arg2: tensor<2x4xf32> {tf.aliasing_output = 0 : i32}) -> (tensor<2x4xf32>) {
+ func.func public @one_arg(%arg0: tensor<2x3xf32>, %arg1: tensor<3x4xf32>, %arg2: tensor<2x4xf32> {tf.aliasing_output = 0 : i32}) -> (tensor<2x4xf32>) {

    name = "bufferization.materialize_in_destination"

    source = operand_def(
        TensorMemrefInferenceConstraint("T", AnyOf([TensorType, UnrankedTensorType]))
superlopuh (Member):

Suggested change
- TensorMemrefInferenceConstraint("T", AnyOf([TensorType, UnrankedTensorType]))
+ TensorMemrefInferenceConstraint("T", AnyTensorTypeConstr | AnyUnrankedTensorTypeConstr)

(will need to add the second constraint definition in builtin.py)
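
A sketch of what that addition to builtin.py might look like, mirroring the existing AnyTensorTypeConstr (assumed naming; BaseAttr comes from xdsl.irdl):

    from typing import TypeAlias

    from xdsl.irdl import BaseAttr

    # Assumed addition next to the existing tensor constraints in builtin.py:
    AnyUnrankedTensorType: TypeAlias = UnrankedTensorType[Attribute]
    AnyUnrankedTensorTypeConstr = BaseAttr[AnyUnrankedTensorType](UnrankedTensorType)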


donated_inputs = [
    inp
    for inp, attr in zip(func_op.args, func_op.arg_attrs)
superlopuh (Member):

Suggested change
-     for inp, attr in zip(func_op.args, func_op.arg_attrs)
+     for inp, attr in zip(func_op.args, func_op.arg_attrs, strict=True)

@op_type_rewrite_pattern
def match_and_rewrite(self, op: Return, rewriter: PatternRewriter, /):
    func_op = op.parent_op()
    if func_op is None or type(func_op) is not FuncOp:
superlopuh (Member):

Suggested change
-     if func_op is None or type(func_op) is not FuncOp:
+     assert isinstance(func_op, FuncOp)

The input can be assumed to be verified. There's no way to handle a VerifyException here anyway, so we might as well assert.
