From 8c50582303b1f9d9ad11ad8c981c9a92ae0d1efd Mon Sep 17 00:00:00 2001 From: Niansong Zhang Date: Sun, 24 Nov 2024 22:42:33 -0500 Subject: [PATCH 1/9] Struct build func init --- allo/ir/types.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/allo/ir/types.py b/allo/ir/types.py index 444d7e7b..d0871ad2 100644 --- a/allo/ir/types.py +++ b/allo/ir/types.py @@ -233,7 +233,12 @@ def __getitem__(self, key): return self.__getattr__(key) def build(self): - raise NotImplementedError("TODO") + fields = [] + types = [] + for name, dtype in self.dtype_dict.items(): + fields.append(StringAttr.get(name)) + types.append(dtype.build()) + return StructType.get(fields, types) class Stream(AlloType): From 1cc1088da1b8967c1ac94b59cb955c72dab5eaca Mon Sep 17 00:00:00 2001 From: Niansong Zhang Date: Mon, 25 Nov 2024 00:00:08 -0500 Subject: [PATCH 2/9] Support dict in usedef --- allo/ir/use_def.py | 7 +++++++ tests/test_types.py | 36 +++++++++++++++++++++++++++++++++++- 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/allo/ir/use_def.py b/allo/ir/use_def.py index 8d23da64..26feb8d5 100644 --- a/allo/ir/use_def.py +++ b/allo/ir/use_def.py @@ -151,6 +151,13 @@ def visit_For(self, node): res.append(self.visit(stmt)) return res raise RuntimeError("Unsupported for loop") + + + def visit_Dict(self, node): + res = set() + for value in node.values: + res = res.union(self.visit(value)) + return res def visit_Call(self, node): original_func_id = self.func_id diff --git a/tests/test_types.py b/tests/test_types.py index 6341fc31..5bbe1fe0 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -506,6 +506,39 @@ def kernel(A: bool[16]) -> bool[16]: np.testing.assert_array_equal(np_A, np_B) +def test_struct(): + struct_type = T.Struct({"x": int32, "y": float32}) + def kernel(x: int32[16], y: float32[16]) -> int32: + sum_val: int32 = 0 + for i in range(16): + # Create struct inside function + point: struct_type = {"x": x[i], "y": y[i]} + sum_val += point.x + sum_val += int(point.y) + return sum_val + + s = allo.customize(kernel) + print(s.module) + mod = s.build() + + # Create separate arrays for x and y + np_x = np.zeros(16, dtype=np.int32) + np_y = np.zeros(16, dtype=np.float32) + + # Fill with test data + for i in range(16): + np_x[i] = i + np_y[i] = float(i) + + allo_result = mod(np_x, np_y) + + # Calculate expected result + expected = sum(x + int(y) for x, y in zip(np_x, np_y)) + + assert allo_result == expected + + + ###################################################################### # Legacy tests ###################################################################### @@ -534,4 +567,5 @@ def test_type_comparison(): if __name__ == "__main__": - pytest.main([__file__]) + # pytest.main([__file__]) + test_struct() From b69054217206dbda594139e68037c17913e66616 Mon Sep 17 00:00:00 2001 From: Niansong Zhang Date: Mon, 25 Nov 2024 04:16:34 -0500 Subject: [PATCH 3/9] Add symbol resolver and type infer --- allo/ir/infer.py | 28 +++++++++++++++++++++++++++- allo/ir/symbol_resolver.py | 13 +++++++++++++ tests/test_types.py | 4 ++-- 3 files changed, 42 insertions(+), 3 deletions(-) diff --git a/allo/ir/infer.py b/allo/ir/infer.py index 62a0b80f..406448cf 100644 --- a/allo/ir/infer.py +++ b/allo/ir/infer.py @@ -21,6 +21,7 @@ uint1, int32, float32, + Struct, Stream, ) from .typing_rule import get_typing_rule @@ -122,6 +123,8 @@ def visit_Constant(ctx, node): node.dtype = int32 elif isinstance(node.value, float): node.dtype = float32 + elif isinstance(node.value, str): + node.dtype = str elif node.value is None: return ASTResolver.resolve_constant(node.value, ctx) else: @@ -135,6 +138,17 @@ def visit_Tuple(ctx, node): node.dtype = [elt.dtype for elt in node.elts] return node + @staticmethod + def visit_Dict(ctx, node): + # Visit all keys and values + visit_stmts(ctx, node.keys) + visit_stmts(ctx, node.values) + + # Dictionary type is a mapping of keys to value types + node.dtype = {k.value: v.dtype for k, v in zip(node.keys, node.values)} + node.shape = () # one dict is considered as one Struct-type scalar + return node + @staticmethod def visit_Index(ctx, node): value = visit_stmt(ctx, node.value) @@ -396,9 +410,22 @@ def visit_symbol(ctx, node): # pylint: disable=raising-bad-type raise None + @staticmethod @staticmethod def visit_Subscript(ctx, node): value = visit_stmt(ctx, node.value) + # Handle struct field access + if len(value.shape) == 0 and isinstance(value.dtype, Struct): + if not isinstance(node.slice, ast.Constant) or not isinstance(node.slice.value, str): + raise RuntimeError("Struct field access must use string literal") + field = node.slice.value + if field not in value.dtype.dtype_dict: + raise RuntimeError(f"Field {field} not found in struct type") + node.dtype = value.dtype.dtype_dict[field] + node.shape = tuple() + return node + + # Handle tensor subscript if len(value.shape) > 0: visit_stmt(ctx, node.slice) # calculate tensor slicing @@ -465,7 +492,6 @@ def visit_Subscript(ctx, node): else: raise RuntimeError("Can only access bit (slice) for integers") return node - @staticmethod def visit_ExtSlice(ctx, node): stmts = visit_stmts(ctx, node.dims) diff --git a/allo/ir/symbol_resolver.py b/allo/ir/symbol_resolver.py index 79a393ba..bd1eb26a 100644 --- a/allo/ir/symbol_resolver.py +++ b/allo/ir/symbol_resolver.py @@ -31,6 +31,19 @@ def resolve(node, scope): # pylint: disable=eval-used return eval(compile(ast.Expression(node), "", "eval"), scope) + if isinstance(node, ast.Dict): + # Resolve dictionary literals to struct types + from .types import Struct + keys = [k.value if isinstance(k, ast.Constant) else None for k in node.keys] + # If any key is not a string constant, this isn't a valid struct type + if any(not isinstance(k, str) for k in keys): + return None + values = [ASTResolver.resolve(v, scope) for v in node.values] + # If any value type couldn't be resolved, return None + if any(v is None for v in values): + return None + return Struct(dict(zip(keys, values))) + if isinstance(node, ast.Name): return scope.get(node.id) diff --git a/tests/test_types.py b/tests/test_types.py index 5bbe1fe0..1619b75d 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -513,8 +513,8 @@ def kernel(x: int32[16], y: float32[16]) -> int32: for i in range(16): # Create struct inside function point: struct_type = {"x": x[i], "y": y[i]} - sum_val += point.x - sum_val += int(point.y) + sum_val += point['x'] + sum_val += int(point['y']) return sum_val s = allo.customize(kernel) From da7981fba3accfc6adfbfffb59184439a9330bb1 Mon Sep 17 00:00:00 2001 From: Niansong Zhang Date: Mon, 25 Nov 2024 04:31:57 -0500 Subject: [PATCH 4/9] Add struct construct, get builder --- allo/ir/builder.py | 32 ++++++++++++++++++++++++++++++++ allo/ir/infer.py | 2 +- allo/ir/types.py | 4 +--- 3 files changed, 34 insertions(+), 4 deletions(-) diff --git a/allo/ir/builder.py b/allo/ir/builder.py index be79553a..9a32b51d 100644 --- a/allo/ir/builder.py +++ b/allo/ir/builder.py @@ -1235,9 +1235,41 @@ def build_Subscript(ctx, node, val=None, idx=0): return ASTTransformer.build_memory_access(ctx, node, val=val, idx=idx) elif len(node.value.shape) > 0 and ctx.enable_tensor: return ASTTransformer.build_tensor_access(ctx, node, val=val, idx=idx) + elif isinstance(node.value.dtype, Struct): + # Get the struct value + value = build_stmt(ctx, node.value) + # Get the field name from the string slice + field_name = node.slice.value + # Get the field index from the struct type + field_idx = list(node.value.dtype.dtype_dict.keys()).index(field_name) + # Create index attribute + idx_attr = IntegerAttr.get(IntegerType.get_signless(64), field_idx) + # Extract the field using struct get op + return allo_d.StructGetOp( + node.value.dtype[field_name].build(), + value.result, + idx_attr, + ip=ctx.get_ip() + ) else: # bit operation return ASTTransformer.build_bit_operation(ctx, node, val=val, idx=idx) + + @staticmethod + def build_Dict(ctx, node): + # Build each value in the dictionary + values = [build_stmt(ctx, value) for value in node.values] + + # Get the field names from the keys + field_names = [key.value for key in node.keys] + + # Create a struct construct op with the values + return allo_d.StructConstructOp( + node.dtype.build(), # The struct type should already be inferred + [value.result for value in values], + ip=ctx.get_ip() + ) + @staticmethod def build_AnnAssign(ctx, node): shape, dtype = node.shape, node.dtype diff --git a/allo/ir/infer.py b/allo/ir/infer.py index 406448cf..c5e004d9 100644 --- a/allo/ir/infer.py +++ b/allo/ir/infer.py @@ -145,7 +145,7 @@ def visit_Dict(ctx, node): visit_stmts(ctx, node.values) # Dictionary type is a mapping of keys to value types - node.dtype = {k.value: v.dtype for k, v in zip(node.keys, node.values)} + node.dtype = Struct({k.value: v.dtype for k, v in zip(node.keys, node.values)}) node.shape = () # one dict is considered as one Struct-type scalar return node diff --git a/allo/ir/types.py b/allo/ir/types.py index d0871ad2..4c478f79 100644 --- a/allo/ir/types.py +++ b/allo/ir/types.py @@ -233,12 +233,10 @@ def __getitem__(self, key): return self.__getattr__(key) def build(self): - fields = [] types = [] for name, dtype in self.dtype_dict.items(): - fields.append(StringAttr.get(name)) types.append(dtype.build()) - return StructType.get(fields, types) + return allo_d.StructType.get(types) class Stream(AlloType): From cc64d4ce11181438da463cf77f8bdf3ca2722291 Mon Sep 17 00:00:00 2001 From: Niansong Zhang Date: Mon, 25 Nov 2024 04:37:58 -0500 Subject: [PATCH 5/9] Add lower composite type pass in LLVM module --- allo/backend/llvm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/allo/backend/llvm.py b/allo/backend/llvm.py index 71e6ded8..4cd3023f 100644 --- a/allo/backend/llvm.py +++ b/allo/backend/llvm.py @@ -69,6 +69,8 @@ def __init__(self, mod, top_func_name, ext_libs=None): call_ext_libs_in_ptr(self.module, ext_libs) # Remove .partition() annotation allo_d.remove_stride_map(self.module) + # Lower composite (struct) types + allo_d.lower_composite_type(self.module) # Resolve FixedType allo_d.lower_fixed_to_int(self.module) allo_d.lower_bit_ops(self.module) From ef017a81cb38c9943d07b3c76192157a4e9162b7 Mon Sep 17 00:00:00 2001 From: Niansong Zhang Date: Mon, 25 Nov 2024 09:55:44 -0500 Subject: [PATCH 6/9] Lint, format --- allo/ir/builder.py | 10 +++------- allo/ir/infer.py | 15 +++++++++------ allo/ir/symbol_resolver.py | 1 + allo/ir/types.py | 2 +- allo/ir/use_def.py | 1 - tests/test_types.py | 23 +++++++++++------------ 6 files changed, 25 insertions(+), 27 deletions(-) diff --git a/allo/ir/builder.py b/allo/ir/builder.py index 9a32b51d..c34cda82 100644 --- a/allo/ir/builder.py +++ b/allo/ir/builder.py @@ -1249,25 +1249,21 @@ def build_Subscript(ctx, node, val=None, idx=0): node.value.dtype[field_name].build(), value.result, idx_attr, - ip=ctx.get_ip() + ip=ctx.get_ip(), ) else: # bit operation return ASTTransformer.build_bit_operation(ctx, node, val=val, idx=idx) - @staticmethod def build_Dict(ctx, node): # Build each value in the dictionary values = [build_stmt(ctx, value) for value in node.values] - - # Get the field names from the keys - field_names = [key.value for key in node.keys] - + # Create a struct construct op with the values return allo_d.StructConstructOp( node.dtype.build(), # The struct type should already be inferred [value.result for value in values], - ip=ctx.get_ip() + ip=ctx.get_ip(), ) @staticmethod diff --git a/allo/ir/infer.py b/allo/ir/infer.py index c5e004d9..6944ab08 100644 --- a/allo/ir/infer.py +++ b/allo/ir/infer.py @@ -72,9 +72,9 @@ def visit_type_hint(ctx, node): if dtype is Stream: # create an actual class instance base_type, base_shape = TypeInferer.visit_type_hint(ctx, node.slice) - dtype = Stream(base_type, base_shape) + stream_dtype = Stream(base_type, base_shape) shape = tuple() - return dtype, shape + return stream_dtype, shape assert dtype is not None, f"Unsupported type {node.value.id}" size = node.slice.value if isinstance(node.slice, ast.Index) else node.slice elts = size.elts if isinstance(size, ast.Tuple) else [size] @@ -143,10 +143,10 @@ def visit_Dict(ctx, node): # Visit all keys and values visit_stmts(ctx, node.keys) visit_stmts(ctx, node.values) - + # Dictionary type is a mapping of keys to value types node.dtype = Struct({k.value: v.dtype for k, v in zip(node.keys, node.values)}) - node.shape = () # one dict is considered as one Struct-type scalar + node.shape = () # one dict is considered as one Struct-type scalar return node @staticmethod @@ -416,7 +416,9 @@ def visit_Subscript(ctx, node): value = visit_stmt(ctx, node.value) # Handle struct field access if len(value.shape) == 0 and isinstance(value.dtype, Struct): - if not isinstance(node.slice, ast.Constant) or not isinstance(node.slice.value, str): + if not isinstance(node.slice, ast.Constant) or not isinstance( + node.slice.value, str + ): raise RuntimeError("Struct field access must use string literal") field = node.slice.value if field not in value.dtype.dtype_dict: @@ -424,7 +426,7 @@ def visit_Subscript(ctx, node): node.dtype = value.dtype.dtype_dict[field] node.shape = tuple() return node - + # Handle tensor subscript if len(value.shape) > 0: visit_stmt(ctx, node.slice) @@ -492,6 +494,7 @@ def visit_Subscript(ctx, node): else: raise RuntimeError("Can only access bit (slice) for integers") return node + @staticmethod def visit_ExtSlice(ctx, node): stmts = visit_stmts(ctx, node.dims) diff --git a/allo/ir/symbol_resolver.py b/allo/ir/symbol_resolver.py index bd1eb26a..810ff8b7 100644 --- a/allo/ir/symbol_resolver.py +++ b/allo/ir/symbol_resolver.py @@ -34,6 +34,7 @@ def resolve(node, scope): if isinstance(node, ast.Dict): # Resolve dictionary literals to struct types from .types import Struct + keys = [k.value if isinstance(k, ast.Constant) else None for k in node.keys] # If any key is not a string constant, this isn't a valid struct type if any(not isinstance(k, str) for k in keys): diff --git a/allo/ir/types.py b/allo/ir/types.py index 4c478f79..b3f78aef 100644 --- a/allo/ir/types.py +++ b/allo/ir/types.py @@ -234,7 +234,7 @@ def __getitem__(self, key): def build(self): types = [] - for name, dtype in self.dtype_dict.items(): + for _, dtype in self.dtype_dict.items(): types.append(dtype.build()) return allo_d.StructType.get(types) diff --git a/allo/ir/use_def.py b/allo/ir/use_def.py index 26feb8d5..028bc7ea 100644 --- a/allo/ir/use_def.py +++ b/allo/ir/use_def.py @@ -151,7 +151,6 @@ def visit_For(self, node): res.append(self.visit(stmt)) return res raise RuntimeError("Unsupported for loop") - def visit_Dict(self, node): res = set() diff --git a/tests/test_types.py b/tests/test_types.py index 1619b75d..ac1287f0 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -506,37 +506,37 @@ def kernel(A: bool[16]) -> bool[16]: np.testing.assert_array_equal(np_A, np_B) -def test_struct(): - struct_type = T.Struct({"x": int32, "y": float32}) +def test_struct_simple(): + struct_type = T.Struct({"x": int32, "y": float32}) + def kernel(x: int32[16], y: float32[16]) -> int32: sum_val: int32 = 0 for i in range(16): # Create struct inside function point: struct_type = {"x": x[i], "y": y[i]} - sum_val += point['x'] - sum_val += int(point['y']) + sum_val += point["x"] + sum_val += int(point["y"]) return sum_val s = allo.customize(kernel) print(s.module) mod = s.build() - + # Create separate arrays for x and y np_x = np.zeros(16, dtype=np.int32) np_y = np.zeros(16, dtype=np.float32) - + # Fill with test data for i in range(16): np_x[i] = i np_y[i] = float(i) - + allo_result = mod(np_x, np_y) - + # Calculate expected result expected = sum(x + int(y) for x, y in zip(np_x, np_y)) - - assert allo_result == expected + assert allo_result == expected ###################################################################### @@ -567,5 +567,4 @@ def test_type_comparison(): if __name__ == "__main__": - # pytest.main([__file__]) - test_struct() + pytest.main([__file__]) From 7b99c494e15dbb551f457120f95b205c5794faa9 Mon Sep 17 00:00:00 2001 From: Niansong Zhang Date: Mon, 25 Nov 2024 11:19:27 -0500 Subject: [PATCH 7/9] Fix an issue with DCE in LowerCompositeType --- allo/ir/infer.py | 1 - diff.txt | 233 +++++++++++++++++++++ mlir/lib/Conversion/LowerCompositeType.cpp | 23 ++ 3 files changed, 256 insertions(+), 1 deletion(-) create mode 100644 diff.txt diff --git a/allo/ir/infer.py b/allo/ir/infer.py index 6944ab08..bf7886dd 100644 --- a/allo/ir/infer.py +++ b/allo/ir/infer.py @@ -410,7 +410,6 @@ def visit_symbol(ctx, node): # pylint: disable=raising-bad-type raise None - @staticmethod @staticmethod def visit_Subscript(ctx, node): value = visit_stmt(ctx, node.value) diff --git a/diff.txt b/diff.txt new file mode 100644 index 00000000..fec56a73 --- /dev/null +++ b/diff.txt @@ -0,0 +1,233 @@ +diff --git a/allo/backend/llvm.py b/allo/backend/llvm.py +index 71e6ded..4cd3023 100644 +--- a/allo/backend/llvm.py ++++ b/allo/backend/llvm.py +@@ -69,6 +69,8 @@ class LLVMModule: + call_ext_libs_in_ptr(self.module, ext_libs) + # Remove .partition() annotation + allo_d.remove_stride_map(self.module) ++ # Lower composite (struct) types ++ allo_d.lower_composite_type(self.module) + # Resolve FixedType + allo_d.lower_fixed_to_int(self.module) + allo_d.lower_bit_ops(self.module) +diff --git a/allo/ir/builder.py b/allo/ir/builder.py +index be79553..c34cda8 100644 +--- a/allo/ir/builder.py ++++ b/allo/ir/builder.py +@@ -1235,9 +1235,37 @@ class ASTTransformer(ASTBuilder): + return ASTTransformer.build_memory_access(ctx, node, val=val, idx=idx) + elif len(node.value.shape) > 0 and ctx.enable_tensor: + return ASTTransformer.build_tensor_access(ctx, node, val=val, idx=idx) ++ elif isinstance(node.value.dtype, Struct): ++ # Get the struct value ++ value = build_stmt(ctx, node.value) ++ # Get the field name from the string slice ++ field_name = node.slice.value ++ # Get the field index from the struct type ++ field_idx = list(node.value.dtype.dtype_dict.keys()).index(field_name) ++ # Create index attribute ++ idx_attr = IntegerAttr.get(IntegerType.get_signless(64), field_idx) ++ # Extract the field using struct get op ++ return allo_d.StructGetOp( ++ node.value.dtype[field_name].build(), ++ value.result, ++ idx_attr, ++ ip=ctx.get_ip(), ++ ) + else: # bit operation + return ASTTransformer.build_bit_operation(ctx, node, val=val, idx=idx) + ++ @staticmethod ++ def build_Dict(ctx, node): ++ # Build each value in the dictionary ++ values = [build_stmt(ctx, value) for value in node.values] ++ ++ # Create a struct construct op with the values ++ return allo_d.StructConstructOp( ++ node.dtype.build(), # The struct type should already be inferred ++ [value.result for value in values], ++ ip=ctx.get_ip(), ++ ) ++ + @staticmethod + def build_AnnAssign(ctx, node): + shape, dtype = node.shape, node.dtype +diff --git a/allo/ir/infer.py b/allo/ir/infer.py +index 62a0b80..6944ab0 100644 +--- a/allo/ir/infer.py ++++ b/allo/ir/infer.py +@@ -21,6 +21,7 @@ from .types import ( + uint1, + int32, + float32, ++ Struct, + Stream, + ) + from .typing_rule import get_typing_rule +@@ -71,9 +72,9 @@ class TypeInferer(ASTVisitor): + if dtype is Stream: + # create an actual class instance + base_type, base_shape = TypeInferer.visit_type_hint(ctx, node.slice) +- dtype = Stream(base_type, base_shape) ++ stream_dtype = Stream(base_type, base_shape) + shape = tuple() +- return dtype, shape ++ return stream_dtype, shape + assert dtype is not None, f"Unsupported type {node.value.id}" + size = node.slice.value if isinstance(node.slice, ast.Index) else node.slice + elts = size.elts if isinstance(size, ast.Tuple) else [size] +@@ -122,6 +123,8 @@ class TypeInferer(ASTVisitor): + node.dtype = int32 + elif isinstance(node.value, float): + node.dtype = float32 ++ elif isinstance(node.value, str): ++ node.dtype = str + elif node.value is None: + return ASTResolver.resolve_constant(node.value, ctx) + else: +@@ -135,6 +138,17 @@ class TypeInferer(ASTVisitor): + node.dtype = [elt.dtype for elt in node.elts] + return node + ++ @staticmethod ++ def visit_Dict(ctx, node): ++ # Visit all keys and values ++ visit_stmts(ctx, node.keys) ++ visit_stmts(ctx, node.values) ++ ++ # Dictionary type is a mapping of keys to value types ++ node.dtype = Struct({k.value: v.dtype for k, v in zip(node.keys, node.values)}) ++ node.shape = () # one dict is considered as one Struct-type scalar ++ return node ++ + @staticmethod + def visit_Index(ctx, node): + value = visit_stmt(ctx, node.value) +@@ -396,9 +410,24 @@ class TypeInferer(ASTVisitor): + # pylint: disable=raising-bad-type + raise None + ++ @staticmethod + @staticmethod + def visit_Subscript(ctx, node): + value = visit_stmt(ctx, node.value) ++ # Handle struct field access ++ if len(value.shape) == 0 and isinstance(value.dtype, Struct): ++ if not isinstance(node.slice, ast.Constant) or not isinstance( ++ node.slice.value, str ++ ): ++ raise RuntimeError("Struct field access must use string literal") ++ field = node.slice.value ++ if field not in value.dtype.dtype_dict: ++ raise RuntimeError(f"Field {field} not found in struct type") ++ node.dtype = value.dtype.dtype_dict[field] ++ node.shape = tuple() ++ return node ++ ++ # Handle tensor subscript + if len(value.shape) > 0: + visit_stmt(ctx, node.slice) + # calculate tensor slicing +diff --git a/allo/ir/symbol_resolver.py b/allo/ir/symbol_resolver.py +index 79a393b..810ff8b 100644 +--- a/allo/ir/symbol_resolver.py ++++ b/allo/ir/symbol_resolver.py +@@ -31,6 +31,20 @@ class ASTResolver: + # pylint: disable=eval-used + return eval(compile(ast.Expression(node), "", "eval"), scope) + ++ if isinstance(node, ast.Dict): ++ # Resolve dictionary literals to struct types ++ from .types import Struct ++ ++ keys = [k.value if isinstance(k, ast.Constant) else None for k in node.keys] ++ # If any key is not a string constant, this isn't a valid struct type ++ if any(not isinstance(k, str) for k in keys): ++ return None ++ values = [ASTResolver.resolve(v, scope) for v in node.values] ++ # If any value type couldn't be resolved, return None ++ if any(v is None for v in values): ++ return None ++ return Struct(dict(zip(keys, values))) ++ + if isinstance(node, ast.Name): + return scope.get(node.id) + +diff --git a/allo/ir/types.py b/allo/ir/types.py +index 444d7e7..b3f78ae 100644 +--- a/allo/ir/types.py ++++ b/allo/ir/types.py +@@ -233,7 +233,10 @@ class Struct(AlloType): + return self.__getattr__(key) + + def build(self): +- raise NotImplementedError("TODO") ++ types = [] ++ for _, dtype in self.dtype_dict.items(): ++ types.append(dtype.build()) ++ return allo_d.StructType.get(types) + + + class Stream(AlloType): +diff --git a/allo/ir/use_def.py b/allo/ir/use_def.py +index 8d23da6..028bc7e 100644 +--- a/allo/ir/use_def.py ++++ b/allo/ir/use_def.py +@@ -152,6 +152,12 @@ class UseDefChain(ast.NodeVisitor): + return res + raise RuntimeError("Unsupported for loop") + ++ def visit_Dict(self, node): ++ res = set() ++ for value in node.values: ++ res = res.union(self.visit(value)) ++ return res ++ + def visit_Call(self, node): + original_func_id = self.func_id + if isinstance(node.func, ast.Name): +diff --git a/tests/test_types.py b/tests/test_types.py +index 6341fc3..ac1287f 100644 +--- a/tests/test_types.py ++++ b/tests/test_types.py +@@ -506,6 +506,39 @@ def test_boolean(): + np.testing.assert_array_equal(np_A, np_B) + + ++def test_struct_simple(): ++ struct_type = T.Struct({"x": int32, "y": float32}) ++ ++ def kernel(x: int32[16], y: float32[16]) -> int32: ++ sum_val: int32 = 0 ++ for i in range(16): ++ # Create struct inside function ++ point: struct_type = {"x": x[i], "y": y[i]} ++ sum_val += point["x"] ++ sum_val += int(point["y"]) ++ return sum_val ++ ++ s = allo.customize(kernel) ++ print(s.module) ++ mod = s.build() ++ ++ # Create separate arrays for x and y ++ np_x = np.zeros(16, dtype=np.int32) ++ np_y = np.zeros(16, dtype=np.float32) ++ ++ # Fill with test data ++ for i in range(16): ++ np_x[i] = i ++ np_y[i] = float(i) ++ ++ allo_result = mod(np_x, np_y) ++ ++ # Calculate expected result ++ expected = sum(x + int(y) for x, y in zip(np_x, np_y)) ++ ++ assert allo_result == expected ++ ++ + ###################################################################### + # Legacy tests + ###################################################################### diff --git a/mlir/lib/Conversion/LowerCompositeType.cpp b/mlir/lib/Conversion/LowerCompositeType.cpp index a6496d36..19378cd6 100644 --- a/mlir/lib/Conversion/LowerCompositeType.cpp +++ b/mlir/lib/Conversion/LowerCompositeType.cpp @@ -260,14 +260,37 @@ bool isLegal(func::FuncOp &func) { /// Pass entry point bool applyLowerCompositeType(ModuleOp &mod) { + // First check if there are any struct operations to lower + bool hasStructOps = false; + for (func::FuncOp func : mod.getOps()) { + func.walk([&](Operation *op) { + if (isa(op)) { + hasStructOps = true; + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + if (hasStructOps) break; + } + + // If no struct operations, return success without doing anything + if (!hasStructOps) { + return true; + } + + // Only apply transformations if we found struct operations for (func::FuncOp func : mod.getOps()) { lowerIntToStructOp(func); } + // Only run DCE if we actually did some transformations applyMemRefDCE(mod); + for (func::FuncOp func : mod.getOps()) { lowerStructType(func); } + // Run final DCE pass applyMemRefDCE(mod); + for (func::FuncOp func : mod.getOps()) { if (!isLegal(func)) { return false; From 0de70288351cde802ea4053d909818ee50f2177b Mon Sep 17 00:00:00 2001 From: Niansong Zhang Date: Mon, 25 Nov 2024 11:20:19 -0500 Subject: [PATCH 8/9] Fix format --- mlir/lib/Conversion/LowerCompositeType.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Conversion/LowerCompositeType.cpp b/mlir/lib/Conversion/LowerCompositeType.cpp index 19378cd6..6b04e39b 100644 --- a/mlir/lib/Conversion/LowerCompositeType.cpp +++ b/mlir/lib/Conversion/LowerCompositeType.cpp @@ -270,7 +270,8 @@ bool applyLowerCompositeType(ModuleOp &mod) { } return WalkResult::advance(); }); - if (hasStructOps) break; + if (hasStructOps) + break; } // If no struct operations, return success without doing anything @@ -284,13 +285,13 @@ bool applyLowerCompositeType(ModuleOp &mod) { } // Only run DCE if we actually did some transformations applyMemRefDCE(mod); - + for (func::FuncOp func : mod.getOps()) { lowerStructType(func); } // Run final DCE pass applyMemRefDCE(mod); - + for (func::FuncOp func : mod.getOps()) { if (!isLegal(func)) { return false; From e7acf14e526d71d9f08b16fae50aa760c232e044 Mon Sep 17 00:00:00 2001 From: Niansong Zhang Date: Mon, 25 Nov 2024 11:22:50 -0500 Subject: [PATCH 9/9] Delete diff file --- diff.txt | 233 ------------------------------------------------------- 1 file changed, 233 deletions(-) delete mode 100644 diff.txt diff --git a/diff.txt b/diff.txt deleted file mode 100644 index fec56a73..00000000 --- a/diff.txt +++ /dev/null @@ -1,233 +0,0 @@ -diff --git a/allo/backend/llvm.py b/allo/backend/llvm.py -index 71e6ded..4cd3023 100644 ---- a/allo/backend/llvm.py -+++ b/allo/backend/llvm.py -@@ -69,6 +69,8 @@ class LLVMModule: - call_ext_libs_in_ptr(self.module, ext_libs) - # Remove .partition() annotation - allo_d.remove_stride_map(self.module) -+ # Lower composite (struct) types -+ allo_d.lower_composite_type(self.module) - # Resolve FixedType - allo_d.lower_fixed_to_int(self.module) - allo_d.lower_bit_ops(self.module) -diff --git a/allo/ir/builder.py b/allo/ir/builder.py -index be79553..c34cda8 100644 ---- a/allo/ir/builder.py -+++ b/allo/ir/builder.py -@@ -1235,9 +1235,37 @@ class ASTTransformer(ASTBuilder): - return ASTTransformer.build_memory_access(ctx, node, val=val, idx=idx) - elif len(node.value.shape) > 0 and ctx.enable_tensor: - return ASTTransformer.build_tensor_access(ctx, node, val=val, idx=idx) -+ elif isinstance(node.value.dtype, Struct): -+ # Get the struct value -+ value = build_stmt(ctx, node.value) -+ # Get the field name from the string slice -+ field_name = node.slice.value -+ # Get the field index from the struct type -+ field_idx = list(node.value.dtype.dtype_dict.keys()).index(field_name) -+ # Create index attribute -+ idx_attr = IntegerAttr.get(IntegerType.get_signless(64), field_idx) -+ # Extract the field using struct get op -+ return allo_d.StructGetOp( -+ node.value.dtype[field_name].build(), -+ value.result, -+ idx_attr, -+ ip=ctx.get_ip(), -+ ) - else: # bit operation - return ASTTransformer.build_bit_operation(ctx, node, val=val, idx=idx) - -+ @staticmethod -+ def build_Dict(ctx, node): -+ # Build each value in the dictionary -+ values = [build_stmt(ctx, value) for value in node.values] -+ -+ # Create a struct construct op with the values -+ return allo_d.StructConstructOp( -+ node.dtype.build(), # The struct type should already be inferred -+ [value.result for value in values], -+ ip=ctx.get_ip(), -+ ) -+ - @staticmethod - def build_AnnAssign(ctx, node): - shape, dtype = node.shape, node.dtype -diff --git a/allo/ir/infer.py b/allo/ir/infer.py -index 62a0b80..6944ab0 100644 ---- a/allo/ir/infer.py -+++ b/allo/ir/infer.py -@@ -21,6 +21,7 @@ from .types import ( - uint1, - int32, - float32, -+ Struct, - Stream, - ) - from .typing_rule import get_typing_rule -@@ -71,9 +72,9 @@ class TypeInferer(ASTVisitor): - if dtype is Stream: - # create an actual class instance - base_type, base_shape = TypeInferer.visit_type_hint(ctx, node.slice) -- dtype = Stream(base_type, base_shape) -+ stream_dtype = Stream(base_type, base_shape) - shape = tuple() -- return dtype, shape -+ return stream_dtype, shape - assert dtype is not None, f"Unsupported type {node.value.id}" - size = node.slice.value if isinstance(node.slice, ast.Index) else node.slice - elts = size.elts if isinstance(size, ast.Tuple) else [size] -@@ -122,6 +123,8 @@ class TypeInferer(ASTVisitor): - node.dtype = int32 - elif isinstance(node.value, float): - node.dtype = float32 -+ elif isinstance(node.value, str): -+ node.dtype = str - elif node.value is None: - return ASTResolver.resolve_constant(node.value, ctx) - else: -@@ -135,6 +138,17 @@ class TypeInferer(ASTVisitor): - node.dtype = [elt.dtype for elt in node.elts] - return node - -+ @staticmethod -+ def visit_Dict(ctx, node): -+ # Visit all keys and values -+ visit_stmts(ctx, node.keys) -+ visit_stmts(ctx, node.values) -+ -+ # Dictionary type is a mapping of keys to value types -+ node.dtype = Struct({k.value: v.dtype for k, v in zip(node.keys, node.values)}) -+ node.shape = () # one dict is considered as one Struct-type scalar -+ return node -+ - @staticmethod - def visit_Index(ctx, node): - value = visit_stmt(ctx, node.value) -@@ -396,9 +410,24 @@ class TypeInferer(ASTVisitor): - # pylint: disable=raising-bad-type - raise None - -+ @staticmethod - @staticmethod - def visit_Subscript(ctx, node): - value = visit_stmt(ctx, node.value) -+ # Handle struct field access -+ if len(value.shape) == 0 and isinstance(value.dtype, Struct): -+ if not isinstance(node.slice, ast.Constant) or not isinstance( -+ node.slice.value, str -+ ): -+ raise RuntimeError("Struct field access must use string literal") -+ field = node.slice.value -+ if field not in value.dtype.dtype_dict: -+ raise RuntimeError(f"Field {field} not found in struct type") -+ node.dtype = value.dtype.dtype_dict[field] -+ node.shape = tuple() -+ return node -+ -+ # Handle tensor subscript - if len(value.shape) > 0: - visit_stmt(ctx, node.slice) - # calculate tensor slicing -diff --git a/allo/ir/symbol_resolver.py b/allo/ir/symbol_resolver.py -index 79a393b..810ff8b 100644 ---- a/allo/ir/symbol_resolver.py -+++ b/allo/ir/symbol_resolver.py -@@ -31,6 +31,20 @@ class ASTResolver: - # pylint: disable=eval-used - return eval(compile(ast.Expression(node), "", "eval"), scope) - -+ if isinstance(node, ast.Dict): -+ # Resolve dictionary literals to struct types -+ from .types import Struct -+ -+ keys = [k.value if isinstance(k, ast.Constant) else None for k in node.keys] -+ # If any key is not a string constant, this isn't a valid struct type -+ if any(not isinstance(k, str) for k in keys): -+ return None -+ values = [ASTResolver.resolve(v, scope) for v in node.values] -+ # If any value type couldn't be resolved, return None -+ if any(v is None for v in values): -+ return None -+ return Struct(dict(zip(keys, values))) -+ - if isinstance(node, ast.Name): - return scope.get(node.id) - -diff --git a/allo/ir/types.py b/allo/ir/types.py -index 444d7e7..b3f78ae 100644 ---- a/allo/ir/types.py -+++ b/allo/ir/types.py -@@ -233,7 +233,10 @@ class Struct(AlloType): - return self.__getattr__(key) - - def build(self): -- raise NotImplementedError("TODO") -+ types = [] -+ for _, dtype in self.dtype_dict.items(): -+ types.append(dtype.build()) -+ return allo_d.StructType.get(types) - - - class Stream(AlloType): -diff --git a/allo/ir/use_def.py b/allo/ir/use_def.py -index 8d23da6..028bc7e 100644 ---- a/allo/ir/use_def.py -+++ b/allo/ir/use_def.py -@@ -152,6 +152,12 @@ class UseDefChain(ast.NodeVisitor): - return res - raise RuntimeError("Unsupported for loop") - -+ def visit_Dict(self, node): -+ res = set() -+ for value in node.values: -+ res = res.union(self.visit(value)) -+ return res -+ - def visit_Call(self, node): - original_func_id = self.func_id - if isinstance(node.func, ast.Name): -diff --git a/tests/test_types.py b/tests/test_types.py -index 6341fc3..ac1287f 100644 ---- a/tests/test_types.py -+++ b/tests/test_types.py -@@ -506,6 +506,39 @@ def test_boolean(): - np.testing.assert_array_equal(np_A, np_B) - - -+def test_struct_simple(): -+ struct_type = T.Struct({"x": int32, "y": float32}) -+ -+ def kernel(x: int32[16], y: float32[16]) -> int32: -+ sum_val: int32 = 0 -+ for i in range(16): -+ # Create struct inside function -+ point: struct_type = {"x": x[i], "y": y[i]} -+ sum_val += point["x"] -+ sum_val += int(point["y"]) -+ return sum_val -+ -+ s = allo.customize(kernel) -+ print(s.module) -+ mod = s.build() -+ -+ # Create separate arrays for x and y -+ np_x = np.zeros(16, dtype=np.int32) -+ np_y = np.zeros(16, dtype=np.float32) -+ -+ # Fill with test data -+ for i in range(16): -+ np_x[i] = i -+ np_y[i] = float(i) -+ -+ allo_result = mod(np_x, np_y) -+ -+ # Calculate expected result -+ expected = sum(x + int(y) for x, y in zip(np_x, np_y)) -+ -+ assert allo_result == expected -+ -+ - ###################################################################### - # Legacy tests - ######################################################################