Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transformations: Support constant inits in memref_stream.generic lowering [2/3] #2764

Merged
merged 22 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
8896dab
transformations: do not insert affine.apply ops when streaming
superlopuh Jun 20, 2024
f89d526
transformations: fix yielding of values in memref_stream.generic lowe…
superlopuh Jun 20, 2024
299de80
tests: move constant initialisation around in bottom-up tests
superlopuh Jun 21, 2024
2e5a024
dialects: (memref_stream) add an inits field to memref_stream.generic
superlopuh Jun 14, 2024
b2acbca
transformations: support constant inits in memref_stream.generic lowe…
superlopuh Jun 18, 2024
2fb3274
Merge remote-tracking branch 'origin/main' into sasha/memref_stream/y…
superlopuh Jun 21, 2024
c2a03e6
Merge branch 'sasha/memref_stream/yields' into sasha/memref_stream/bo…
superlopuh Jun 21, 2024
7f12d54
Merge branch 'sasha/memref_stream/bottom-up-constants' into sasha/mem…
superlopuh Jun 21, 2024
7b7f027
Merge branch 'sasha/memref_stream/const-init' into sasha/memref_strea…
superlopuh Jun 21, 2024
7378f6d
Merge branch 'main' into sasha/memref_stream/const-init
superlopuh Jun 21, 2024
3f70240
Merge branch 'main' into sasha/memref_stream/const-init
superlopuh Jun 23, 2024
333ce98
Merge branch 'sasha/memref_stream/const-init' into sasha/memref_strea…
superlopuh Jun 23, 2024
a207627
inits now values not attributes
superlopuh Jun 24, 2024
f9ae2aa
inits now values not attributes
superlopuh Jun 24, 2024
aed0cd4
straggling inits
superlopuh Jun 24, 2024
182f2c0
Merge branch 'sasha/memref_stream/const-init' into sasha/memref_strea…
superlopuh Jun 24, 2024
968c53d
fix printing
superlopuh Jun 24, 2024
81018b5
Merge branch 'sasha/memref_stream/const-init' into sasha/memref_strea…
superlopuh Jun 24, 2024
f114076
fix fix
superlopuh Jun 24, 2024
9ecbc32
fix fix fix
superlopuh Jun 24, 2024
73b0ec9
Merge branch 'sasha/memref_stream/const-init' into sasha/memref_strea…
superlopuh Jun 24, 2024
a80c473
Merge branch 'main' into sasha/memref_stream/const-init-lowering
superlopuh Jun 26, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 35 additions & 70 deletions tests/filecheck/projects/riscv-backend-paper/bottom_up.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -14,30 +14,19 @@ func.func public @conv_2d_nchw_fchw_d1_s1_3x3(
]
} ins(%X, %Y : memref<1x1x8x8xf64>, memref<1x1x3x3xf64>) outs(%Z : memref<1x1x6x6xf64>) {
^0(%x_stream : !stream.readable<f64>, %y_stream : !stream.readable<f64>, %z_stream : !stream.writable<f64>):
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c3 = arith.constant 3 : index
%c6 = arith.constant 6 : index

scf.for %i0 = %c0 to %c1 step %c1 {
scf.for %i1 = %c0 to %c1 step %c1 {
scf.for %i2 = %c0 to %c6 step %c1 {
scf.for %i3 = %c0 to %c6 step %c1 {
%z = scf.for %i = %c0 to %c3 step %c1 iter_args(%acc0 = %zero_float) -> (f64) {
%z3 = scf.for %j = %c0 to %c3 step %c1 iter_args(%acc1 = %acc0) -> (f64) {
%x = memref_stream.read from %x_stream : f64
%y = memref_stream.read from %y_stream : f64
%prod = arith.mulf %x, %y fastmath<fast> : f64
%res = arith.addf %prod, %acc1 fastmath<fast> : f64
scf.yield %res : f64
}
scf.yield %z3 : f64
}

memref_stream.write %z to %z_stream : f64
}
}
}
memref_stream.generic {
bounds = [#builtin.int<1>, #builtin.int<1>, #builtin.int<6>, #builtin.int<6>, #builtin.int<1>, #builtin.int<3>, #builtin.int<3>],
indexing_maps = [
affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d4, d2 + d5, d3 + d6)>,
affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1, d4, d5, d6)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
],
iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]
} ins(%x_stream, %y_stream : !stream.readable<f64>, !stream.readable<f64>) outs(%z_stream : !stream.writable<f64>) inits(%zero_float : f64) {
^0(%x : f64, %y : f64, %acc : f64):
%prod = arith.mulf %x, %y fastmath<fast> : f64
%res = arith.addf %prod, %acc fastmath<fast> : f64
memref_stream.yield %res : f64
Comment on lines -17 to +29
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😱 Amazing ❤️

}
}

Expand Down Expand Up @@ -413,29 +402,17 @@ func.func public @pooling_nchw_max_d1_s2_3x3(
]
} ins(%X : memref<1x1x16x16xf64>) outs(%Y : memref<1x1x7x7xf64>) {
^0(%x_stream : !stream.readable<f64>, %y_stream : !stream.writable<f64>):
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c3 = arith.constant 3 : index
%c7 = arith.constant 7 : index
%c512 = arith.constant 512 : index

scf.for %i0 = %c0 to %c1 step %c1 {
scf.for %i1 = %c0 to %c1 step %c1 {
scf.for %i2 = %c0 to %c7 step %c1 {
scf.for %i3 = %c0 to %c7 step %c1 {
%y = scf.for %i = %c0 to %c3 step %c1 iter_args(%acc0 = %min_val) -> (f64) {
%y3 = scf.for %j = %c0 to %c3 step %c1 iter_args(%acc1 = %acc0) -> (f64) {
%x = memref_stream.read from %x_stream : f64
%res = arith.maximumf %x, %acc1 : f64
scf.yield %res : f64
}
scf.yield %y3 : f64
}

memref_stream.write %y to %y_stream : f64
}
}
}
memref_stream.generic {
bounds = [#builtin.int<1>, #builtin.int<1>, #builtin.int<7>, #builtin.int<7>, #builtin.int<3>, #builtin.int<3>],
indexing_maps = [
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 * 2 + d4, d3 * 2 + d5)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
],
iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]
} ins(%x_stream : !stream.readable<f64>) outs(%y_stream : !stream.writable<f64>) inits(%min_val : f64) {
^0(%x : f64, %acc : f64):
Comment on lines +412 to +413
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Quick question on the semantics of this:

You have two reduction dimensions, but only one reduction parameter in the loop, and only one initial value. Is there a place I can read up on how this works exactly?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this follows the linalg.generic quite closely. The dimensions and operands are related by the shapes of the operands. In this case, there are two reduction dimensions because the pooling happens over a 3x3 mini tile. The reduction dimensions are not present in the shapes of the ins or outs. I'd say the documentation for the linalg dialect and the linalg.generic op specifically is probably the best place to look.

%res = arith.maximumf %x, %acc : f64
memref_stream.yield %res : f64
}
}

Expand Down Expand Up @@ -545,29 +522,17 @@ func.func public @pooling_nchw_sum_d1_s2_3x3(
]
} ins(%X : memref<1x1x16x16xf64>) outs(%Y : memref<1x1x7x7xf64>) {
^0(%x_stream : !stream.readable<f64>, %y_stream : !stream.writable<f64>):
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
%c3 = arith.constant 3 : index
%c7 = arith.constant 7 : index
%c512 = arith.constant 512 : index

scf.for %i0 = %c0 to %c1 step %c1 {
scf.for %i1 = %c0 to %c1 step %c1 {
scf.for %i2 = %c0 to %c7 step %c1 {
scf.for %i3 = %c0 to %c7 step %c1 {
%y = scf.for %i = %c0 to %c3 step %c1 iter_args(%acc0 = %zero_float) -> (f64) {
%y3 = scf.for %j = %c0 to %c3 step %c1 iter_args(%acc1 = %acc0) -> (f64) {
%x = memref_stream.read from %x_stream : f64
%res = arith.addf %x, %acc1 : f64
scf.yield %res : f64
}
scf.yield %y3 : f64
}

memref_stream.write %y to %y_stream : f64
}
}
}
memref_stream.generic {
bounds = [#builtin.int<1>, #builtin.int<1>, #builtin.int<7>, #builtin.int<7>, #builtin.int<3>, #builtin.int<3>],
indexing_maps = [
affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2 * 2 + d4, d3 * 2 + d5)>,
affine_map<(d0, d1, d2, d3) -> (d0, d1, d2, d3)>
],
iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]
} ins(%x_stream : !stream.readable<f64>) outs(%y_stream : !stream.writable<f64>) inits(%zero_float : f64) {
^0(%x : f64, %acc : f64):
%res = arith.addf %x, %acc : f64
memref_stream.yield %res : f64
}
}

Expand Down
51 changes: 51 additions & 0 deletions tests/filecheck/transforms/convert_memref_stream_to_loops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -255,4 +255,55 @@ func.func @nested_imperfect(%A : memref<2x3x4xf64>, %B : memref<f64>) -> memref<
// CHECK-NEXT: func.return %{{.*}} : memref<f64>
// CHECK-NEXT: }

func.func @main_inits(%A : memref<4x2xf64>, %B : memref<2x3xf64>, %C : memref<4x3xf64>) -> memref<4x3xf64> {
%zero_float = arith.constant 0.000000e+00 : f64
memref_stream.streaming_region {
patterns = [
#memref_stream.stride_pattern<ub = [4, 3, 2], index_map = (d0, d1, d2) -> (d0, d2)>,
#memref_stream.stride_pattern<ub = [4, 3, 2], index_map = (d0, d1, d2) -> (d2, d1)>
]
} ins(%A, %B : memref<4x2xf64>, memref<2x3xf64>) {
^0(%0 : !stream.readable<f64>, %1 : !stream.readable<f64>):
memref_stream.generic {
bounds = [#builtin.int<4>, #builtin.int<3>, #builtin.int<2>],
indexing_maps = [
affine_map<(d0, d1, d2) -> (d0, d2)>,
affine_map<(d0, d1, d2) -> (d2, d1)>,
affine_map<(d0, d1) -> (d0, d1)>
],
iterator_types = ["parallel", "parallel", "reduction"]
} ins(%0, %1 : !stream.readable<f64>, !stream.readable<f64>) outs(%C : memref<4x3xf64>) inits(%zero_float : f64) {
^1(%a : f64, %b : f64, %acc_old : f64):
%prod = arith.mulf %a, %b : f64
%acc_new = arith.addf %acc_old, %prod : f64
memref_stream.yield %acc_new : f64
}
}
func.return %C : memref<4x3xf64>
}
// CHECK-NEXT: func.func @main_inits(%{{.*}} : memref<4x2xf64>, %{{.*}} : memref<2x3xf64>, %{{.*}} : memref<4x3xf64>) -> memref<4x3xf64> {
// CHECK-NEXT: %zero_float = arith.constant 0.000000e+00 : f64
// CHECK-NEXT: memref_stream.streaming_region {patterns = [#memref_stream.stride_pattern<ub = [4, 3, 2], index_map = (d0, d1, d2) -> (d0, d2)>, #memref_stream.stride_pattern<ub = [4, 3, 2], index_map = (d0, d1, d2) -> (d2, d1)>]} ins(%{{.*}}, %{{.*}} : memref<4x2xf64>, memref<2x3xf64>) {
// CHECK-NEXT: ^{{.*}}(%{{.*}} : !stream.readable<f64>, %{{.*}} : !stream.readable<f64>):
// CHECK-NEXT: %{{.*}} = arith.constant 4 : index
// CHECK-NEXT: %{{.*}} = arith.constant 3 : index
// CHECK-NEXT: %{{.*}} = arith.constant 2 : index
// CHECK-NEXT: %{{.*}} = arith.constant 0 : index
// CHECK-NEXT: %{{.*}} = arith.constant 1 : index
// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
// CHECK-NEXT: %{{.*}} = scf.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} iter_args(%{{.*}} = %zero_float) -> (f64) {
// CHECK-NEXT: %{{.*}} = memref_stream.read from %{{.*}} : f64
// CHECK-NEXT: %{{.*}} = memref_stream.read from %{{.*}} : f64
// CHECK-NEXT: %{{.*}} = arith.mulf %{{.*}}, %{{.*}} : f64
// CHECK-NEXT: %{{.*}} = arith.addf %{{.*}}, %{{.*}} : f64
// CHECK-NEXT: scf.yield %{{.*}} : f64
// CHECK-NEXT: }
// CHECK-NEXT: memref.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}] : memref<4x3xf64>
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: }
// CHECK-NEXT: func.return %{{.*}} : memref<4x3xf64>
// CHECK-NEXT: }

// CHECK-NEXT: }
1 change: 1 addition & 0 deletions xdsl/transforms/convert_linalg_to_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@


def insert_load(
value_index: int,
value: SSAValue,
affine_map_attr: AffineMapAttr,
ind_vars: Sequence[SSAValue],
Expand Down
34 changes: 31 additions & 3 deletions xdsl/transforms/convert_memref_stream_to_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
)


def insert_load(
def _insert_load(
source_index: int,
source: SSAValue,
affine_map_attr: AffineMapAttr,
ind_vars: Sequence[SSAValue],
Expand Down Expand Up @@ -64,13 +65,40 @@ class LowerGenericOpPattern(RewritePattern):
def match_and_rewrite(
self, op: memref_stream.GenericOp, rewriter: PatternRewriter
) -> None:
ins_count = len(op.inputs)
if any(not isinstance(init, UnitAttr) for init in op.inits):
raise NotImplementedError("Operation has inits that are not UnitAttr")
constant_vals: list[SSAValue | None] = [None] * len(op.outputs)
for index, val in zip(op.init_indices, op.inits, strict=True):
constant_vals[index.data] = val

def insert_load(
source_index: int,
source: SSAValue,
affine_map_attr: AffineMapAttr,
ind_vars: Sequence[SSAValue],
rewriter: PatternRewriter,
insertion_point: InsertPoint,
) -> SSAValue:
if source_index >= ins_count:
constant_val = constant_vals[source_index - ins_count]
if constant_val is not None:
return constant_val

return _insert_load(
source_index,
source,
affine_map_attr,
ind_vars,
rewriter,
insertion_point,
)

else:
insert_load = _insert_load

outer_ubs, inner_ubs = op.get_static_loop_ranges()
if inner_ubs:
# Imperfectly nested
ins_count = len(op.inputs)
rewrite_generic_to_imperfect_loops(
rewriter,
InsertPoint.before(op),
Expand Down
9 changes: 7 additions & 2 deletions xdsl/transforms/loop_nest_lowering_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def indices_for_map(

INSERT_LOAD: TypeAlias = Callable[
[
int,
SSAValue,
AffineMapAttr,
Sequence[SSAValue],
Expand Down Expand Up @@ -161,6 +162,7 @@ def _insert_load_ops(
operands: Sequence[SSAValue],
args: Sequence[BlockArgument],
insert_load: INSERT_LOAD,
index_increment: int = 0,
) -> Sequence[tuple[int, SSAValue]]:
"""
Inserts the load operations at the specified insertion point.
Expand All @@ -172,6 +174,7 @@ def _insert_load_ops(
The `affine_map_attrs`, `operands`, and `args` must have the same length.
Returns a tuple of integers indicating the locations of the returned values, and
the values themselves.
The integer values are incremented by `index_increment`.
"""
res: list[tuple[int, SSAValue]] = []
for i, (affine_map_attr, operand, arg) in enumerate(
Expand All @@ -180,13 +183,14 @@ def _insert_load_ops(
if not arg.uses:
continue
res_val = insert_load(
i + index_increment,
operand,
affine_map_attr,
ind_vars,
rewriter,
insertion_point,
)
res.append((i, res_val))
res.append((i + index_increment, res_val))
return res


Expand Down Expand Up @@ -352,6 +356,7 @@ def outer_make_body(
outer_load_operands,
outer_load_block_args,
insert_load,
index_increment=len(inner_load_block_args),
)

def inner_make_body(
Expand All @@ -377,7 +382,7 @@ def inner_make_body(
inner_iter_args,
strict=True,
):
block.args[i + len(inner_loaded_values)].replace_by(arg)
block.args[i].replace_by(arg)

# Replace block argument use with load op results
for i, val in inner_loaded_values:
Expand Down
Loading