Skip to content

Commit

Permalink
dialects (func): Add SymbolUserOpInterface implementation for `func.c…
Browse files Browse the repository at this point in the history
…all` operation (#3652)

This PR:

- Adds support for the `SymbolUserOpInterface` interface and implements
it for `func.call`
- Adds tests (pytest and filecheck) of the above

Resolves #3497
  • Loading branch information
compor authored Dec 19, 2024
1 parent b394263 commit 4b15917
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 4 deletions.
41 changes: 39 additions & 2 deletions tests/dialects/test_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,23 @@

from xdsl.builder import Builder, ImplicitBuilder
from xdsl.dialects.arith import AddiOp, ConstantOp
from xdsl.dialects.builtin import IntegerAttr, IntegerType, ModuleOp, i32, i64
from xdsl.dialects.builtin import (
IntegerAttr,
IntegerType,
ModuleOp,
StringAttr,
i32,
i64,
)
from xdsl.dialects.func import CallOp, FuncOp, ReturnOp
from xdsl.ir import Block, Region
from xdsl.traits import CallableOpInterface
from xdsl.irdl import (
IRDLOperation,
attr_def,
irdl_op_definition,
traits_def,
)
from xdsl.traits import CallableOpInterface, SymbolOpInterface
from xdsl.utils.exceptions import VerifyException


Expand Down Expand Up @@ -261,6 +274,30 @@ def test_call_II():
assert_print_op(mod, expected, None)


def test_call_III():
"""Call a symbol that is not func.func"""

@irdl_op_definition
class SymbolOp(IRDLOperation):
name = "test.symbol"

sym_name = attr_def(StringAttr)

traits = traits_def(SymbolOpInterface())

def __init__(self, name: str):
return super().__init__(attributes={"sym_name": StringAttr(name)})

symop = SymbolOp("foo")
call0 = CallOp("foo", [], [])
mod = ModuleOp([symop, call0])

with pytest.raises(
VerifyException, match="'@foo' does not reference a valid function"
):
mod.verify()


def test_return():
# Create two constants and add them, then return
a = ConstantOp.from_int_and_width(1, i32)
Expand Down
57 changes: 57 additions & 0 deletions tests/filecheck/dialects/func/func_invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,60 @@ builtin.module {

// CHECK: Operation does not verify: Unexpected nested symbols in FlatSymbolRefAttr
// CHECK-NEXT: Underlying verification failure: expected empty array, but got ["invalid"]

// -----

func.func @bar() {
%1 = "test.op"() : () -> !test.type<"int">
%2 = func.call @foo(%1) : (!test.type<"int">) -> !test.type<"int">
func.return
}

// CHECK: '@foo' could not be found in symbol table

// -----

func.func @foo(%0 : !test.type<"int">) -> !test.type<"int">

func.func @bar() {
%1 = func.call @foo() : () -> !test.type<"int">
func.return
}

// CHECK: incorrect number of operands for callee

// -----

func.func @foo(%0 : !test.type<"int">)

func.func @bar() {
%1 = "test.op"() : () -> !test.type<"int">
%2 = func.call @foo(%1) : (!test.type<"int">) -> !test.type<"int">
func.return
}

// CHECK: incorrect number of results for callee

// -----

func.func @foo(%0 : !test.type<"int">) -> !test.type<"int">

func.func @bar() {
%1 = "test.op"() : () -> !test.type<"foo">
%2 = func.call @foo(%1) : (!test.type<"foo">) -> !test.type<"int">
func.return
}

// CHECK: operand type mismatch: expected operand type !test.type<"int">, but provided !test.type<"foo"> for operand number 0

// -----

func.func @foo(%0 : !test.type<"int">) -> !test.type<"int">

func.func @bar() {
%1 = "test.op"() : () -> !test.type<"int">
%2 = func.call @foo(%1) : (!test.type<"int">) -> !test.type<"foo">
func.return
}

// CHECK: result type mismatch: expected result type !test.type<"int">, but provided !test.type<"foo"> for result number 0
44 changes: 42 additions & 2 deletions xdsl/dialects/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
IsolatedFromAbove,
IsTerminator,
SymbolOpInterface,
SymbolTable,
SymbolUserOpInterface,
)
from xdsl.utils.exceptions import VerifyException

Expand All @@ -62,6 +64,42 @@ def get_result_types(cls, op: Operation) -> tuple[Attribute, ...]:
return op.function_type.outputs.data


class CallOpSymbolUserOpInterface(SymbolUserOpInterface):
def verify(self, op: Operation) -> None:
assert isinstance(op, CallOp)

found_callee = SymbolTable.lookup_symbol(op, op.callee)
if not found_callee:
raise VerifyException(f"'{op.callee}' could not be found in symbol table")

if not isinstance(found_callee, FuncOp):
raise VerifyException(f"'{op.callee}' does not reference a valid function")

if len(found_callee.function_type.inputs) != len(op.arguments):
raise VerifyException("incorrect number of operands for callee")

if len(found_callee.function_type.outputs) != len(op.result_types):
raise VerifyException("incorrect number of results for callee")

for idx, (found_operand, operand) in enumerate(
zip(found_callee.function_type.inputs, (arg.type for arg in op.arguments))
):
if found_operand != operand:
raise VerifyException(
f"operand type mismatch: expected operand type {found_operand}, but provided {operand} for operand number {idx}"
)

for idx, (found_res, res) in enumerate(
zip(found_callee.function_type.outputs, op.result_types)
):
if found_res != res:
raise VerifyException(
f"result type mismatch: expected result type {found_res}, but provided {res} for result number {idx}"
)

return


@irdl_op_definition
class FuncOp(IRDLOperation):
name = "func.func"
Expand Down Expand Up @@ -108,7 +146,6 @@ def verify_(self) -> None:
if len(self.body.blocks) == 0:
return

# TODO: how to verify that there is a terminator?
entry_block = self.body.blocks.first
assert entry_block is not None
block_arg_types = entry_block.arg_types
Expand Down Expand Up @@ -272,11 +309,14 @@ class CallOp(IRDLOperation):
callee = prop_def(FlatSymbolRefAttrConstr)
res = var_result_def()

traits = traits_def(
CallOpSymbolUserOpInterface(),
)

assembly_format = (
"$callee `(` $arguments `)` attr-dict `:` functional-type($arguments, $res)"
)

# TODO how do we verify that the types are correct?
def __init__(
self,
callee: str | SymbolRefAttr,
Expand Down
20 changes: 20 additions & 0 deletions xdsl/traits.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,26 @@ def verify(self, op: Operation) -> None:
regions += child_op.regions


class SymbolUserOpInterface(OpTrait, abc.ABC):
"""
Used to represent operations that reference Symbol operations. This provides the
ability to perform safe and efficient verification of symbol uses, as well as
additional functionality.
https://mlir.llvm.org/docs/Interfaces/#symbolinterfaces
"""

@abc.abstractmethod
def verify(self, op: Operation) -> None:
"""
This method should be adapted to the requirements of specific symbol users per
operation.
It corresponds to the verifySymbolUses in upstream MLIR.
"""
raise NotImplementedError()


class SymbolTable(OpTrait):
"""
SymbolTable operations are containers for Symbol operations. They offer lookup
Expand Down

0 comments on commit 4b15917

Please sign in to comment.