Skip to content

Commit

Permalink
misc: raise error if variadic results types aren't referenced in asse…
Browse files Browse the repository at this point in the history
…mbly format (#3416)

The number of results is not passed in when parsing operations. In the
generic format, the type of the operation always specifies the types of
the results, and `resultSegmentSizes` specifies the ranges of of the
results if multiple are variadic. In order to support variadic results,
the types an length of all variadic results must be present in the
custom syntax.
  • Loading branch information
superlopuh authored Nov 11, 2024
1 parent daab137 commit 97cfc1f
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 18 deletions.
40 changes: 39 additions & 1 deletion tests/irdl/test_declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,14 @@
import pytest

from xdsl.context import MLContext
from xdsl.dialects.builtin import I32, BoolAttr, IntegerAttr, ModuleOp, UnitAttr
from xdsl.dialects import test
from xdsl.dialects.builtin import (
I32,
BoolAttr,
IntegerAttr,
ModuleOp,
UnitAttr,
)
from xdsl.dialects.test import Test, TestType
from xdsl.ir import (
Attribute,
Expand All @@ -29,6 +36,8 @@
ParamAttrConstraint,
ParameterDef,
ParsePropInAttrDict,
RangeOf,
RangeVarConstraint,
VarConstraint,
VarOperand,
VarOpResult,
Expand Down Expand Up @@ -1745,6 +1754,35 @@ class OneOperandOneResultNestedOp(IRDLOperation):
check_roundtrip(program, ctx)


def test_variadic_length_inference():
@irdl_op_definition
class RangeVarOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
name = "test.range_var"
T: ClassVar = RangeVarConstraint("T", RangeOf(AnyAttr()))
ins = var_operand_def(T)
outs = var_result_def(T)

assembly_format = "$ins attr-dict `:` type($ins)"

with pytest.raises(
NotImplementedError,
match="Inference of length of variadic result 'outs' not implemented",
):
ctx = MLContext()
ctx.load_op(RangeVarOp)
ctx.load_dialect(Test)
program = textwrap.dedent("""\
%in0, %in1 = "test.op"() : () -> (index, index)
%out0, %out1 = test.range_var %in0, %in1 : index, index
""")

parser = Parser(ctx, program)
test_op = parser.parse_optional_operation()
assert isinstance(test_op, test.Operation)
my_op = parser.parse_optional_operation()
assert isinstance(my_op, RangeVarOp)


################################################################################
# Declarative Format Verification #
################################################################################
Expand Down
4 changes: 2 additions & 2 deletions xdsl/dialects/experimental/air.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ class ChannelPutOp(IRDLOperation):
src_sizes = var_operand_def(IndexType())
src_strides = var_operand_def(IndexType())

async_token = opt_result_def(AsyncTokenAttr())
async_token = result_def(AsyncTokenAttr())

irdl_options = [AttrSizedOperandSegments()]

Expand Down Expand Up @@ -242,7 +242,7 @@ class DmaMemcpyNdOp(IRDLOperation):
src_sizes = var_operand_def(IndexType())
src_strides = var_operand_def(IndexType())

async_token = opt_result_def(AsyncTokenAttr())
async_token = result_def(AsyncTokenAttr())

irdl_options = [AttrSizedOperandSegments(as_property=True), ParsePropInAttrDict()]

Expand Down
31 changes: 16 additions & 15 deletions xdsl/irdl/declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,26 +234,27 @@ def resolve_result_types(self, state: ParsingState, op_def: OpDef) -> None:
Use the inferred type resolutions to fill missing result types from other parsed
types.
"""
for i, (result_type, (_, result_def)) in enumerate(
for i, (result_type, (result_name, result_def)) in enumerate(
zip(state.result_types, op_def.results, strict=True)
):
if result_type is None:
result_type = state.result_types[i]
range_length = len(result_type) if isinstance(result_type, list) else 1
result_type = result_def.constr.infer(
# The number of results is not passed in when parsing operations.
# In the generic format, the type of the operation always specifies the
# types of the results, and `resultSegmentSizes` specifies the ranges of
# of the results if multiple are variadic.
# In order to support variadic results, the types an length of all
# variadic results must be present in the custom syntax.
if isinstance(result_def, OptionalDef | VariadicDef):
raise NotImplementedError(
f"Inference of length of variadic result '{result_name}' not "
"implemented"
)
range_length = 1
inferred_result_types = result_def.constr.infer(
range_length, state.constraint_context
)
if isinstance(result_def, OptionalDef):
result_type = (
list[Attribute | None]()
if len(result_type) == 0
else result_type[0]
)
elif isinstance(result_def, VariadicDef):
result_type = cast(list[Attribute | None], result_type)
else:
result_type = result_type[0]
state.result_types[i] = result_type
resolved_result_type = inferred_result_types[0]
state.result_types[i] = resolved_result_type

def print(self, printer: Printer, op: IRDLOperation) -> None:
"""
Expand Down

0 comments on commit 97cfc1f

Please sign in to comment.