Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: Implement simple helpers for construct type access. #3024

Merged
merged 7 commits into from
Aug 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tests/dialects/test_hw.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,8 @@ def test_instance_builder():
assert inst_op.arg_names.data == (StringAttr("foo"), StringAttr("bar"))
assert inst_op.result_names.data == (StringAttr("baz"), StringAttr("qux"))

assert [op.type for op in inst_op.operands] == [i32, i64]
assert [res.type for res in inst_op.results] == [i32, i64]
assert inst_op.operand_types == (i32, i64)
assert inst_op.result_types == (i32, i64)


def test_hwmoduleop_hwmodulelike():
Expand Down
2 changes: 1 addition & 1 deletion tests/dialects/test_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def test_matmul_on_memrefs():

matmul_op = linalg.MatmulOp(inputs=(a.memref, b.memref), outputs=(c.memref,))

assert tuple(result.type for result in matmul_op.results) == ()
assert matmul_op.result_types == ()


def test_loop_range_methods():
Expand Down
2 changes: 1 addition & 1 deletion tests/dialects/test_pdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def test_build_anr():
assert anr.constraint_name == StringAttr("anr")
assert anr.args == (type_val,)
assert len(anr.results) == 1
assert [r.type for r in anr.results] == [attribute_type]
assert anr.result_types == (attribute_type,)


def test_build_rewrite():
Expand Down
20 changes: 10 additions & 10 deletions tests/irdl/test_operation_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ class ResultOp(IRDLOperation):
def test_result_builder():
op = ResultOp.build(result_types=[StringAttr("")])
op.verify()
assert [res.type for res in op.results] == [StringAttr("")]
assert op.result_types == (StringAttr(""),)


def test_result_builder_exception():
Expand All @@ -79,7 +79,7 @@ def test_opt_result_builder():
op1.verify()
op2.verify()
op3.verify()
assert [res.type for res in op1.results] == [StringAttr("")]
assert op1.result_types == (StringAttr(""),)
assert len(op2.results) == 0
assert len(op3.results) == 0

Expand All @@ -99,10 +99,10 @@ class VarResultOp(IRDLOperation):
def test_var_result_builder():
op = VarResultOp.build(result_types=[[StringAttr("0"), StringAttr("1")]])
op.verify()
assert [res.type for res in op.results] == [
assert op.result_types == (
StringAttr("0"),
StringAttr("1"),
]
)


@irdl_op_definition
Expand All @@ -122,12 +122,12 @@ def test_two_var_result_builder():
]
)
op.verify()
assert [res.type for res in op.results] == [
assert op.result_types == (
StringAttr("0"),
StringAttr("1"),
StringAttr("2"),
StringAttr("3"),
]
)

assert op.attributes[
AttrSizedResultSegments.attribute_name
Expand All @@ -142,12 +142,12 @@ def test_two_var_result_builder2():
]
)
op.verify()
assert [res.type for res in op.results] == [
assert op.result_types == (
StringAttr("0"),
StringAttr("1"),
StringAttr("2"),
StringAttr("3"),
]
)
assert op.attributes[
AttrSizedResultSegments.attribute_name
] == DenseArrayBase.from_list(i32, [1, 3])
Expand All @@ -172,13 +172,13 @@ def test_var_mixed_builder():
]
)
op.verify()
assert [res.type for res in op.results] == [
assert op.result_types == (
StringAttr("0"),
StringAttr("1"),
StringAttr("2"),
StringAttr("3"),
StringAttr("4"),
]
)

assert op.attributes[
AttrSizedResultSegments.attribute_name
Expand Down
2 changes: 1 addition & 1 deletion tests/pattern_rewriter/test_pattern_rewriter.py
Original file line number Diff line number Diff line change
Expand Up @@ -1075,7 +1075,7 @@ def match_and_rewrite(self, op: ModuleOp, rewriter: PatternRewriter):
old_op = next(ops_iter)
assert isinstance(old_op, test.TestOp)
new_region = rewriter.move_region_contents_to_new_regions(old_op.regions[0])
res_types = [r.type for r in old_op.results]
res_types = old_op.result_types
new_op = test.TestOp.create(result_types=res_types, regions=[new_region])
rewriter.insert_op(new_op, InsertPoint.after(old_op))

Expand Down
4 changes: 2 additions & 2 deletions tests/test_ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,8 +488,8 @@ def test_split_block_args():

new_block = old_block.split_before(op, arg_types=(i32, i64))

arg_types = [a.type for a in new_block.args]
assert arg_types == [i32, i64]
arg_types = new_block.arg_types
assert arg_types == (i32, i64)


def test_region_clone_into_circular_blocks():
Expand Down
2 changes: 1 addition & 1 deletion xdsl/backend/riscv/lowering/convert_func_to_riscv_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def match_and_rewrite(self, op: func.FuncOp, rewriter: PatternRewriter):
if (first_block := op.body.blocks.first) is not None:
cast_block_args_from_a_regs(first_block, rewriter)

input_types = [arg.type for arg in first_block.args]
input_types = first_block.arg_types
else:
input_types = tuple(a_regs_for_types(op.function_type.inputs.data))
result_types = list(a_regs_for_types(op.function_type.outputs.data))
Expand Down
8 changes: 3 additions & 5 deletions xdsl/backend/riscv/lowering/convert_riscv_scf_to_riscv_cf.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,13 @@ def match_and_rewrite(self, op: riscv_scf.ForOp, rewriter: PatternRewriter, /):
init_block = op.parent_block()
assert init_block is not None

body_args = op.body.blocks[0].args
body = op.body.blocks[0]

# TODO: add method to rewriter
end_block = init_block.split_before(
op, arg_types=(arg.type for arg in body_args)
)
end_block = init_block.split_before(op, arg_types=body.arg_types)

# The first argument of the loop body block is the loop counter by SCF invariant.
loop_var_reg = body_args[0].type
loop_var_reg = body.args[0].type
assert isinstance(loop_var_reg, riscv.IntRegisterType)

# Use the first block of the loop body as the condition block since it is the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def match_and_rewrite(
block = op.body.block

rewriter.insert_op_before_matched_op(
enable_op := snitch.SsrEnable(tuple(arg.type for arg in block.args))
enable_op := snitch.SsrEnable(block.arg_types)
)

for val, arg in zip(enable_op.streams, block.args):
Expand Down
8 changes: 4 additions & 4 deletions xdsl/dialects/affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,14 +142,14 @@ def verify_(self) -> None:
raise VerifyException(
"Expected as many upper bound operands as upper bound dimensions and symbols."
)
iter_types = [op.type for op in self.inits]
if iter_types != [res.type for res in self.results]:
iter_types = tuple(op.type for op in self.inits)
if iter_types != self.result_types:
raise VerifyException(
"Expected all operands and result pairs to have matching types"
)
entry_block: Block = self.body.blocks[0]
block_arg_types = [IndexType()] + iter_types
arg_types = [arg.type for arg in entry_block.args]
block_arg_types = (IndexType(), *iter_types)
arg_types = entry_block.arg_types
if block_arg_types != arg_types:
raise VerifyException(
"Expected BlockArguments to have the same types as the operands"
Expand Down
6 changes: 2 additions & 4 deletions xdsl/dialects/csl/csl.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def _verify(self):
return

entry_block: Block = self.body.blocks[0]
block_arg_types = [arg.type for arg in entry_block.args]
block_arg_types = entry_block.arg_types
if self.function_type.inputs.data != tuple(block_arg_types):
raise VerifyException(
"Expected entry block arguments to have the same types as the function "
Expand Down Expand Up @@ -718,9 +718,7 @@ def verify_(self) -> None:
func_op = self.parent_op()
assert isinstance(func_op, FuncOp) or isinstance(func_op, TaskOp)

if tuple(func_op.function_type.outputs) != tuple(
val.type for val in self.operands
):
if tuple(func_op.function_type.outputs.data) != self.operand_types:
raise VerifyException(
"Expected arguments to have the same types as the function output types"
)
Expand Down
6 changes: 1 addition & 5 deletions xdsl/dialects/fsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,11 +205,7 @@ def verify_(self):
raise VerifyException("Transition regions should not output any value")
while (parent := parent.parent_op()) is not None:
if isinstance(parent, MachineOp):
if not (
[operand.type for operand in self.operands]
== [result for result in parent.function_type.outputs]
and len(self.operands) == len(parent.function_type.outputs)
):
if self.operand_types != parent.function_type.outputs.data:
raise VerifyException(
"OutputOp output type must be consistent with the machine "
+ str(parent.sym_name)
Expand Down
6 changes: 3 additions & 3 deletions xdsl/dialects/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def verify_(self) -> None:
# TODO: how to verify that there is a terminator?
entry_block = self.body.blocks.first
assert entry_block is not None
block_arg_types = [arg.type for arg in entry_block.args]
block_arg_types = entry_block.arg_types
if self.function_type.inputs.data != tuple(block_arg_types):
raise VerifyException(
"Expected entry block arguments to have the same types as the function "
Expand Down Expand Up @@ -222,10 +222,10 @@ def update_function_type(self):
not self.is_declaration
), "update_function_type does not work with function declarations!"
return_op = self.get_return_op()
return_type: tuple[Attribute, ...] = self.function_type.outputs.data
return_type = self.function_type.outputs.data

if return_op is not None:
return_type = tuple(arg.type for arg in return_op.operands)
return_type = return_op.operand_types

self.properties["function_type"] = FunctionType.from_lists(
[arg.type for arg in self.args],
Expand Down
18 changes: 8 additions & 10 deletions xdsl/dialects/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def verify_(self) -> None:
f"{self.operand.type}. They must be the same type for gpu.all_reduce"
)

non_empty_body = any(b.ops for b in self.body.blocks)
non_empty_body = bool(self.body.blocks)
op_attr = self.op is not None
if non_empty_body == op_attr:
if op_attr:
Expand All @@ -258,9 +258,8 @@ def verify_(self) -> None:
"gpu.all_reduce need either a non empty body or an op attribute."
)
if non_empty_body:
region_args = self.body.blocks[0].args
args_types = [r.type for r in region_args]
if args_types != [self.result.type, self.operand.type]:
args_types = self.body.blocks[0].arg_types
if args_types != (self.result.type, self.operand.type):
raise VerifyException(
f"Expected {[str(t) for t in [self.result.type, self.operand.type]]}, "
f"got {[str(t) for t in args_types]}. A gpu.all_reduce's body must "
Expand Down Expand Up @@ -431,7 +430,7 @@ def __init__(
def verify_(self):
entry_block: Block = self.body.blocks[0]
function_inputs = self.function_type.inputs.data
block_arg_types = tuple(a.type for a in entry_block.args)
block_arg_types = entry_block.arg_types
if function_inputs != block_arg_types:
raise VerifyException(
"Expected first entry block arguments to have the same types as the "
Expand Down Expand Up @@ -558,9 +557,8 @@ def __init__(
def verify_(self) -> None:
if not any(b.ops for b in self.body.blocks):
raise VerifyException("gpu.launch requires a non-empty body.")
body_args = self.body.blocks[0].args
args_type = [a.type for a in body_args]
if args_type != [IndexType()] * 12:
args_type = self.body.blocks[0].arg_types
if args_type != (IndexType(),) * 12:
raise VerifyException(
f"Expected [12 x {str(IndexType())}], got {[str(t) for t in args_type]}. "
"gpu.launch's body arguments are 12 index arguments, with 3 block "
Expand Down Expand Up @@ -759,8 +757,8 @@ def __init__(self, operands: Sequence[SSAValue | Operation]):
def verify_(self) -> None:
op = self.parent_op()
if op is not None:
yield_type = [o.type for o in self.values]
result_type = [r.type for r in op.results]
yield_type = tuple(o.type for o in self.values)
result_type = op.result_types
if yield_type != result_type:
raise VerifyException(
f"Expected {[str(t) for t in result_type]}, got {[str(t) for t in yield_type]}. The gpu.yield values "
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/hw.py
Original file line number Diff line number Diff line change
Expand Up @@ -1230,7 +1230,7 @@ def print_output_port(name: str, port_type: Attribute):
printer.print_list(
zip(
(name.data for name in self.result_names),
(result.type for result in self.results),
self.result_types,
),
lambda x: print_output_port(*x),
)
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/pdl.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def print(self, printer: Printer) -> None:
printer.print(")")
if len(self.results) != 0:
printer.print(" : ")
printer.print_list([res.type for res in self.results], printer.print)
printer.print_list(self.result_types, printer.print)


@irdl_op_definition
Expand Down
14 changes: 13 additions & 1 deletion xdsl/ir/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,14 @@ class Operation(IRNode):
def parent_node(self) -> IRNode | None:
return self.parent

@property
def result_types(self) -> Sequence[Attribute]:
return tuple(r.type for r in self.results)

@property
def operand_types(self) -> Sequence[Attribute]:
return tuple(operand.type for operand in self.operands)

def parent_op(self) -> Operation | None:
if p := self.parent_region():
return p.parent
Expand Down Expand Up @@ -894,7 +902,7 @@ def clone_without_regions(
(value_mapper[operand] if operand in value_mapper else operand)
for operand in self.operands
]
result_types = [res.type for res in self.results]
result_types = self.result_types
attributes = self.attributes.copy()
properties = self.properties.copy()
successors = [
Expand Down Expand Up @@ -1212,6 +1220,10 @@ def __init__(

self.add_ops(ops)

@property
def arg_types(self) -> Sequence[Attribute]:
return tuple(arg.type for arg in self._args)

@property
def parent_node(self) -> IRNode | None:
return self.parent
Expand Down
2 changes: 1 addition & 1 deletion xdsl/irdl/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -1390,7 +1390,7 @@ def irdl_op_verify_regions(
f"{len(region.blocks)} blocks"
)
if (first_block := region.blocks.first) is not None:
entry_args_types = tuple(a.type for a in first_block.args)
entry_args_types = first_block.arg_types
try:
region_def.entry_args.verify(entry_args_types, constraint_context)
except Exception as e:
Expand Down
4 changes: 1 addition & 3 deletions xdsl/printer.py
Original file line number Diff line number Diff line change
Expand Up @@ -938,9 +938,7 @@ def print_function_type(
self.print_string(")")

def print_operation_type(self, op: Operation) -> None:
self.print_function_type(
(o.type for o in op.operands), (r.type for r in op.results)
)
self.print_function_type(op.operand_types, op.result_types)
if self.print_debuginfo:
self.print_string(" loc(unknown)")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -570,7 +570,7 @@ def from_mutable(
op.name,
op_type,
operands,
[result.type for result in op.results],
op.result_types,
properties,
attributes,
successors,
Expand Down
8 changes: 3 additions & 5 deletions xdsl/transforms/common_subexpression_elimination.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __hash__(self):
self.name,
sum(hash(i) for i in self.op.attributes.items()),
sum(hash(i) for i in self.op.properties.items()),
hash(tuple(i.type for i in self.op.results)),
hash(self.op.result_types),
hash(self.op.operands),
)
)
Expand All @@ -55,12 +55,10 @@ def __eq__(self, other: object):
and self.op.attributes == other.op.attributes
and self.op.properties == other.op.properties
and self.op.operands == other.op.operands
and len(self.op.results) == len(other.op.results)
and all(r.type == o.type for r, o in zip(self.op.results, other.op.results))
and len(self.op.regions) == len(other.op.regions)
and self.op.result_types == other.op.result_types
and all(
s.is_structurally_equivalent(o)
for s, o in zip(self.op.regions, other.op.regions)
for s, o in zip(self.op.regions, other.op.regions, strict=True)
)
)

Expand Down
Loading
Loading