Skip to content

Commit

Permalink
rewriting: Add convert-ml-program-to-memref (#2580)
Browse files Browse the repository at this point in the history
I'm not 100% sure that this is the intended way that ml_program should
be lowered but it seems to work for our ONNX kernels.
  • Loading branch information
superlopuh authored May 29, 2024
1 parent 14b4f9c commit 8fbb3b2
Show file tree
Hide file tree
Showing 5 changed files with 111 additions and 9 deletions.
11 changes: 11 additions & 0 deletions tests/filecheck/transforms/convert_ml_program_to_memref.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// RUN: xdsl-opt %s -p convert-ml-program-to-memref | filecheck %s

ml_program.global private @global_same_type(dense<4> : tensor<4xi32>) : tensor<4xi32>

%0 = ml_program.global_load_const @global_same_type : tensor<4xi32>

// CHECK: builtin.module {
// CHECK-NEXT: "memref.global"() <{"sym_name" = "global_same_type", "type" = memref<4xi32>, "initial_value" = dense<4> : tensor<4xi32>, "sym_visibility" = "private", "constant"}> : () -> ()
// CHECK-NEXT: %0 = memref.get_global @global_same_type : memref<4xi32>
// CHECK-NEXT: %1 = "bufferization.to_tensor"(%0) : (memref<4xi32>) -> tensor<4xi32>
// CHECK-NEXT: }
2 changes: 1 addition & 1 deletion tests/interpreters/test_memref_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_memref_get_global():
sym_visibility=StringAttr("public"),
)
with ImplicitBuilder(func.FuncOp("main", ((), ())).body):
fetch = memref.GetGlobal.get("my_global", memref_type)
fetch = memref.GetGlobal("my_global", memref_type)

interpreter = Interpreter(module, index_bitwidth=32)
interpreter.register_implementations(MemrefFunctions())
Expand Down
13 changes: 9 additions & 4 deletions xdsl/dialects/memref.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@
SymbolOpInterface,
)
from xdsl.utils.bitwise_casts import is_power_of_two
from xdsl.utils.deprecation import deprecated_constructor
from xdsl.utils.exceptions import VerifyException
from xdsl.utils.hints import isa

Expand Down Expand Up @@ -358,11 +359,15 @@ class GetGlobal(IRDLOperation):
memref: OpResult = result_def(MemRefType[Attribute])
name_: SymbolRefAttr = prop_def(SymbolRefAttr, prop_name="name")

def __init__(self, name: str | SymbolRefAttr, return_type: Attribute):
if isinstance(name, str):
name = SymbolRefAttr(name)
super().__init__(result_types=[return_type], properties={"name": name})

@deprecated_constructor
@staticmethod
def get(name: str, return_type: Attribute) -> GetGlobal:
return GetGlobal.build(
result_types=[return_type], properties={"name": SymbolRefAttr(name)}
)
def get(name: str | SymbolRefAttr, return_type: Attribute) -> GetGlobal:
return GetGlobal(name, return_type)

assembly_format = "$name `:` type($memref) attr-dict"

Expand Down
14 changes: 10 additions & 4 deletions xdsl/tools/command_line_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -352,10 +352,10 @@ def get_convert_linalg_to_loops():

return convert_linalg_to_loops.ConvertLinalgToLoopsPass

def get_stencil_tensorize_z_dimension():
from xdsl.transforms.experimental import stencil_tensorize_z_dimension
def get_convert_ml_program_to_memref():
from xdsl.transforms import convert_ml_program_to_memref

return stencil_tensorize_z_dimension.StencilTensorizeZDimension
return convert_ml_program_to_memref.ConvertMlProgramToMemrefPass

def get_convert_riscv_scf_for_to_frep():
from xdsl.transforms import convert_riscv_scf_for_to_frep
Expand Down Expand Up @@ -577,6 +577,11 @@ def get_replace_incompatible_fpga():

return replace_incompatible_fpga.ReplaceIncompatibleFPGA

def get_stencil_tensorize_z_dimension():
from xdsl.transforms.experimental import stencil_tensorize_z_dimension

return stencil_tensorize_z_dimension.StencilTensorizeZDimension

def get_stencil_unroll():
from xdsl.transforms import stencil_unroll

Expand All @@ -597,9 +602,9 @@ def get_test_lower_snitch_stream_to_asm():
"convert-func-to-riscv-func": get_convert_func_to_riscv_func,
"convert-linalg-to-memref-stream": get_convert_linalg_to_memref_stream,
"convert-linalg-to-loops": get_convert_linalg_to_loops,
"stencil-tensorize-z-dimension": get_stencil_tensorize_z_dimension,
"convert-memref-stream-to-loops": get_convert_memref_stream_to_loops,
"convert-memref-to-riscv": get_convert_memref_to_riscv,
"convert-ml-program-to-memref": get_convert_ml_program_to_memref,
"convert-onnx-to-linalg": get_convert_onnx_to_linalg,
"convert-memref-stream-to-snitch": get_convert_memref_stream_to_snitch,
"convert-print-format-to-riscv-debug": get_convert_print_format_to_riscv_debug,
Expand Down Expand Up @@ -639,6 +644,7 @@ def get_test_lower_snitch_stream_to_asm():
"snitch-allocate-registers": get_snitch_register_allocation,
"stencil-shape-inference": get_stencil_shape_inference,
"stencil-storage-materialization": get_stencil_storage_materialization,
"stencil-tensorize-z-dimension": get_stencil_tensorize_z_dimension,
"stencil-unroll": get_stencil_unroll,
"test-lower-snitch-stream-to-asm": get_test_lower_snitch_stream_to_asm,
}
Expand Down
80 changes: 80 additions & 0 deletions xdsl/transforms/convert_ml_program_to_memref.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from typing import Any, cast

from xdsl.dialects import bufferization, memref, ml_program
from xdsl.dialects.builtin import (
ModuleOp,
TensorType,
UnitAttr,
)
from xdsl.ir import MLContext
from xdsl.passes import ModulePass
from xdsl.pattern_rewriter import (
GreedyRewritePatternApplier,
PatternRewriter,
PatternRewriteWalker,
RewritePattern,
op_type_rewrite_pattern,
)


class ConvertGlobalPattern(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(
self, op: ml_program.Global, rewriter: PatternRewriter
) -> None:
if op.value is None:
raise NotImplementedError(
"Converting ml_program.global with no value not implemented"
)
assert isinstance(op_type := op.type, TensorType)
op_type = cast(TensorType[Any], op_type)
new_type = memref.MemRefType(op_type.element_type, op_type.shape)
rewriter.replace_matched_op(
(
memref.Global.get(
op.sym_name,
new_type,
op.value,
op.sym_visibility,
UnitAttr() if op.is_mutable is None else None,
),
)
)


class ConvertGlobalLoadConst(RewritePattern):
@op_type_rewrite_pattern
def match_and_rewrite(
self, op: ml_program.GlobalLoadConstant, rewriter: PatternRewriter
) -> None:
assert isinstance(op_type := op.result.type, TensorType)
op_type = cast(TensorType[Any], op_type)
new_type = memref.MemRefType(op_type.element_type, op_type.shape)
rewriter.replace_matched_op(
(
mem := memref.GetGlobal.get(op.global_attr, new_type),
bufferization.ToTensorOp(mem.memref),
)
)


class ConvertMlProgramToMemrefPass(ModulePass):
"""
Converts operations in the `ml_program` dialect to `memref`.
`ml_program` operations are at the `tensor` level of abstraction, so some of the
rewrites insert `bufferization` ops to bridge the gap to existing consumers of global
`tensor`s.
"""

name = "convert-ml-program-to-memref"

def apply(self, ctx: MLContext, op: ModuleOp) -> None:
PatternRewriteWalker(
GreedyRewritePatternApplier(
[
ConvertGlobalPattern(),
ConvertGlobalLoadConst(),
]
),
apply_recursively=False,
).rewrite_module(op)

0 comments on commit 8fbb3b2

Please sign in to comment.