diff --git a/tests/interpreters/test_pdl_interpreter.py b/tests/interpreters/test_pdl_interpreter.py index 25756b3603..8e1e3b5458 100644 --- a/tests/interpreters/test_pdl_interpreter.py +++ b/tests/interpreters/test_pdl_interpreter.py @@ -12,7 +12,7 @@ ) from xdsl.interpreter import Interpreter from xdsl.interpreters.pdl import PDLMatcher, PDLRewriteFunctions, PDLRewritePattern -from xdsl.ir import Attribute +from xdsl.ir import Attribute, Block from xdsl.pattern_rewriter import PatternRewriteWalker from xdsl.utils.test_value import TestSSAValue @@ -148,6 +148,89 @@ def test_match_operand(): assert matcher.matching_context == {ssa_value: xdsl_value} +def test_match_result(): + matcher = PDLMatcher() + + type_op = pdl.TypeOp(IntegerType(32)) + operation_op = pdl.OperationOp(op_name=None, type_values=(type_op.result,)) + result_op = pdl.ResultOp(0, operation_op.op) + xdsl_op = test.TestOp(result_types=(i32,)) + xdsl_value = xdsl_op.res[0] + + # New result + # If the result of an operation has the expected type we should match + assert matcher.match_result(result_op.val, result_op, xdsl_value) + assert matcher.matching_context == { + result_op.val: xdsl_value, + operation_op.op: xdsl_op, + type_op.result: i32, + } + + # Same result + # We should accept the same value given the same constraint + assert matcher.match_result(result_op.val, result_op, xdsl_value) + assert matcher.matching_context == { + result_op.val: xdsl_value, + operation_op.op: xdsl_op, + type_op.result: i32, + } + + # Other result + # We should not match again with a different value, even if it has the correct type + other_xdsl_op = test.TestOp(result_types=(i32,)) + other_xdsl_value = other_xdsl_op.res[0] + + assert not matcher.match_result(result_op.val, result_op, other_xdsl_value) + assert matcher.matching_context == { + result_op.val: xdsl_value, + operation_op.op: xdsl_op, + type_op.result: i32, + } + + # Wrong type + # Matching should fail if the result's type differs from the expected type + wrong_type_op = pdl.TypeOp(i64) + wrong_type_operation_op = pdl.OperationOp( + op_name=None, type_values=(wrong_type_op.result,) + ) + wrong_type_result_op = pdl.ResultOp(0, wrong_type_operation_op.op) + + assert not matcher.match_result( + wrong_type_result_op.val, wrong_type_result_op, xdsl_value + ) + assert matcher.matching_context == { + result_op.val: xdsl_value, + operation_op.op: xdsl_op, + type_op.result: i32, + } + + # Index out of range + # If the operation has only one result, we should not match results at different + # indices + out_of_range_result_op = pdl.ResultOp(1, operation_op.op) + assert not matcher.match_result( + out_of_range_result_op.val, out_of_range_result_op, xdsl_value + ) + assert matcher.matching_context == { + result_op.val: xdsl_value, + operation_op.op: xdsl_op, + type_op.result: i32, + } + + # Block argument + # Result patterns should not match on block arguments + block = Block(arg_types=(i32,)) + block_arg_result_op = pdl.ResultOp(1, operation_op.op) + assert not matcher.match_result( + block_arg_result_op.val, block_arg_result_op, block.args[0] + ) + assert matcher.matching_context == { + result_op.val: xdsl_value, + operation_op.op: xdsl_op, + type_op.result: i32, + } + + def test_native_constraint_constant_parameter(): """ Check that `pdl.apply_native_constraint` can take constant attribute parameters