Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[IR] Add Support for Struct Types #265

Merged
merged 9 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions allo/backend/llvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ def __init__(self, mod, top_func_name, ext_libs=None):
call_ext_libs_in_ptr(self.module, ext_libs)
# Remove .partition() annotation
allo_d.remove_stride_map(self.module)
# Lower composite (struct) types
allo_d.lower_composite_type(self.module)
# Resolve FixedType
allo_d.lower_fixed_to_int(self.module)
allo_d.lower_bit_ops(self.module)
Expand Down
28 changes: 28 additions & 0 deletions allo/ir/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1235,9 +1235,37 @@ def build_Subscript(ctx, node, val=None, idx=0):
return ASTTransformer.build_memory_access(ctx, node, val=val, idx=idx)
elif len(node.value.shape) > 0 and ctx.enable_tensor:
return ASTTransformer.build_tensor_access(ctx, node, val=val, idx=idx)
elif isinstance(node.value.dtype, Struct):
# Get the struct value
value = build_stmt(ctx, node.value)
# Get the field name from the string slice
field_name = node.slice.value
# Get the field index from the struct type
field_idx = list(node.value.dtype.dtype_dict.keys()).index(field_name)
# Create index attribute
idx_attr = IntegerAttr.get(IntegerType.get_signless(64), field_idx)
# Extract the field using struct get op
return allo_d.StructGetOp(
node.value.dtype[field_name].build(),
value.result,
idx_attr,
ip=ctx.get_ip(),
)
else: # bit operation
return ASTTransformer.build_bit_operation(ctx, node, val=val, idx=idx)

@staticmethod
def build_Dict(ctx, node):
# Build each value in the dictionary
values = [build_stmt(ctx, value) for value in node.values]

# Create a struct construct op with the values
return allo_d.StructConstructOp(
node.dtype.build(), # The struct type should already be inferred
[value.result for value in values],
ip=ctx.get_ip(),
)

@staticmethod
def build_AnnAssign(ctx, node):
shape, dtype = node.shape, node.dtype
Expand Down
33 changes: 31 additions & 2 deletions allo/ir/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
uint1,
int32,
float32,
Struct,
Stream,
)
from .typing_rule import get_typing_rule
Expand Down Expand Up @@ -71,9 +72,9 @@ def visit_type_hint(ctx, node):
if dtype is Stream:
# create an actual class instance
base_type, base_shape = TypeInferer.visit_type_hint(ctx, node.slice)
dtype = Stream(base_type, base_shape)
stream_dtype = Stream(base_type, base_shape)
shape = tuple()
return dtype, shape
return stream_dtype, shape
assert dtype is not None, f"Unsupported type {node.value.id}"
size = node.slice.value if isinstance(node.slice, ast.Index) else node.slice
elts = size.elts if isinstance(size, ast.Tuple) else [size]
Expand Down Expand Up @@ -122,6 +123,8 @@ def visit_Constant(ctx, node):
node.dtype = int32
elif isinstance(node.value, float):
node.dtype = float32
elif isinstance(node.value, str):
node.dtype = str
elif node.value is None:
return ASTResolver.resolve_constant(node.value, ctx)
else:
Expand All @@ -135,6 +138,17 @@ def visit_Tuple(ctx, node):
node.dtype = [elt.dtype for elt in node.elts]
return node

@staticmethod
def visit_Dict(ctx, node):
# Visit all keys and values
visit_stmts(ctx, node.keys)
visit_stmts(ctx, node.values)

# Dictionary type is a mapping of keys to value types
node.dtype = Struct({k.value: v.dtype for k, v in zip(node.keys, node.values)})
node.shape = () # one dict is considered as one Struct-type scalar
return node

@staticmethod
def visit_Index(ctx, node):
value = visit_stmt(ctx, node.value)
Expand Down Expand Up @@ -396,9 +410,24 @@ def visit_symbol(ctx, node):
# pylint: disable=raising-bad-type
raise None

@staticmethod
@staticmethod
zzzDavid marked this conversation as resolved.
Show resolved Hide resolved
def visit_Subscript(ctx, node):
value = visit_stmt(ctx, node.value)
# Handle struct field access
if len(value.shape) == 0 and isinstance(value.dtype, Struct):
if not isinstance(node.slice, ast.Constant) or not isinstance(
node.slice.value, str
):
raise RuntimeError("Struct field access must use string literal")
field = node.slice.value
if field not in value.dtype.dtype_dict:
raise RuntimeError(f"Field {field} not found in struct type")
node.dtype = value.dtype.dtype_dict[field]
node.shape = tuple()
return node

# Handle tensor subscript
if len(value.shape) > 0:
visit_stmt(ctx, node.slice)
# calculate tensor slicing
Expand Down
14 changes: 14 additions & 0 deletions allo/ir/symbol_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,20 @@ def resolve(node, scope):
# pylint: disable=eval-used
return eval(compile(ast.Expression(node), "", "eval"), scope)

if isinstance(node, ast.Dict):
# Resolve dictionary literals to struct types
from .types import Struct

keys = [k.value if isinstance(k, ast.Constant) else None for k in node.keys]
# If any key is not a string constant, this isn't a valid struct type
if any(not isinstance(k, str) for k in keys):
return None
values = [ASTResolver.resolve(v, scope) for v in node.values]
# If any value type couldn't be resolved, return None
if any(v is None for v in values):
return None
return Struct(dict(zip(keys, values)))

if isinstance(node, ast.Name):
return scope.get(node.id)

Expand Down
5 changes: 4 additions & 1 deletion allo/ir/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,10 @@ def __getitem__(self, key):
return self.__getattr__(key)

def build(self):
raise NotImplementedError("TODO")
types = []
for _, dtype in self.dtype_dict.items():
types.append(dtype.build())
return allo_d.StructType.get(types)


class Stream(AlloType):
Expand Down
6 changes: 6 additions & 0 deletions allo/ir/use_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,12 @@ def visit_For(self, node):
return res
raise RuntimeError("Unsupported for loop")

def visit_Dict(self, node):
res = set()
for value in node.values:
res = res.union(self.visit(value))
return res

def visit_Call(self, node):
original_func_id = self.func_id
if isinstance(node.func, ast.Name):
Expand Down
33 changes: 33 additions & 0 deletions tests/test_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,6 +506,39 @@ def kernel(A: bool[16]) -> bool[16]:
np.testing.assert_array_equal(np_A, np_B)


def test_struct_simple():
struct_type = T.Struct({"x": int32, "y": float32})

def kernel(x: int32[16], y: float32[16]) -> int32:
sum_val: int32 = 0
for i in range(16):
# Create struct inside function
point: struct_type = {"x": x[i], "y": y[i]}
sum_val += point["x"]
sum_val += int(point["y"])
return sum_val

s = allo.customize(kernel)
print(s.module)
mod = s.build()

# Create separate arrays for x and y
np_x = np.zeros(16, dtype=np.int32)
np_y = np.zeros(16, dtype=np.float32)

# Fill with test data
for i in range(16):
np_x[i] = i
np_y[i] = float(i)

allo_result = mod(np_x, np_y)

# Calculate expected result
expected = sum(x + int(y) for x, y in zip(np_x, np_y))

assert allo_result == expected


######################################################################
# Legacy tests
######################################################################
Expand Down
Loading