From f428ffa68c7b352ca84400624076e007dafceea9 Mon Sep 17 00:00:00 2001 From: max Date: Wed, 11 Oct 2023 00:28:04 -0500 Subject: [PATCH 01/10] [mlir][python] value casting --- mlir/python/mlir/dialects/_ods_common.py | 58 +++++++++++++++- mlir/python/mlir/ir.py | 14 ++++ mlir/test/mlir-tblgen/op-python-bindings.td | 48 ++++++------- mlir/test/python/dialects/arith_dialect.py | 68 +++++++++++++++++-- mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 17 +++-- 5 files changed, 171 insertions(+), 34 deletions(-) diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py index 9cca7d659ec8c..dd41ee63c8bf7 100644 --- a/mlir/python/mlir/dialects/_ods_common.py +++ b/mlir/python/mlir/dialects/_ods_common.py @@ -1,11 +1,18 @@ # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +from collections import defaultdict # Provide a convenient name for sub-packages to resolve the main C-extension # with a relative import. from .._mlir_libs import _mlir as _cext -from typing import Sequence as _Sequence, Union as _Union +from typing import ( + Callable as _Callable, + Sequence as _Sequence, + Type as _Type, + TypeVar as _TypeVar, + Union as _Union, +) __all__ = [ "equally_sized_accessor", @@ -123,3 +130,52 @@ def get_op_result_or_op_results( if len(op.results) > 0 else op ) + + +U = _TypeVar("U", bound=_cext.ir.Value) +SubClassValueT = _Type[U] + +ValueCasterT = _Callable[ + [_Union[_cext.ir.Value, _cext.ir.OpResult]], _Union[SubClassValueT, None] +] + +_VALUE_CASTERS: defaultdict[ + _cext.ir.TypeID, + _Sequence[ValueCasterT], +] = defaultdict(list) + + +def has_value_caster(typeid: _cext.ir.TypeID): + if not isinstance(typeid, _cext.ir.TypeID): + raise ValueError(f"{typeid=} is not a TypeID") + if typeid in _VALUE_CASTERS: + return True + return False + + +def get_value_caster(typeid: _cext.ir.TypeID): + if not has_value_caster(typeid): + raise ValueError(f"no registered caster for {typeid=}") + return _VALUE_CASTERS[typeid] + + +def maybe_cast( + val: _Union[ + _cext.ir.Value, + _cext.ir.OpResult, + _Sequence[_cext.ir.Value], + _Sequence[_cext.ir.OpResult], + _cext.ir.Operation, + ] +) -> _Union[SubClassValueT, _Sequence[SubClassValueT], _cext.ir.Operation]: + if isinstance(val, (tuple, list)): + return tuple(map(maybe_cast, val)) + + if not isinstance(val, _cext.ir.Value) and not isinstance(val, _cext.ir.OpResult): + return val + + if has_value_caster(val.type.typeid): + for caster in get_value_caster(val.type.typeid): + if casted := caster(val): + return casted + return val diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py index cf4228c2a63a9..74b32bfc5de76 100644 --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -5,6 +5,20 @@ from ._mlir_libs._mlir.ir import * from ._mlir_libs._mlir.ir import _GlobalDebug from ._mlir_libs._mlir import register_type_caster +from .dialects._ods_common import ValueCasterT, _VALUE_CASTERS + + +def register_value_caster(typeid: TypeID, priority: int = None): + def wrapper(caster: ValueCasterT): + if not isinstance(typeid, TypeID): + raise ValueError(f"{typeid=} is not a TypeID") + if priority is None: + _VALUE_CASTERS[typeid].append(caster) + else: + _VALUE_CASTERS[typeid].insert(priority, caster) + return caster + + return wrapper # Convenience decorator for registering user-friendly Attribute builders. diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td index 63dad1cc901fe..96b0c170dc5bb 100644 --- a/mlir/test/mlir-tblgen/op-python-bindings.td +++ b/mlir/test/mlir-tblgen/op-python-bindings.td @@ -61,7 +61,7 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands", } // CHECK: def attr_sized_operands(variadic1, non_variadic, *, variadic2=None, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)) +// CHECK: return _maybe_cast(_get_op_result_or_op_results(AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))) // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class AttrSizedResultsOp(_ods_ir.OpView): @@ -108,7 +108,7 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results", } // CHECK: def attr_sized_results(variadic1, non_variadic, variadic2, *, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(AttrSizedResultsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)) +// CHECK: return _maybe_cast(_get_op_result_or_op_results(AttrSizedResultsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))) // CHECK: @_ods_cext.register_operation(_Dialect) @@ -158,7 +158,7 @@ def AttributedOp : TestOp<"attributed_op"> { } // CHECK: def attributed_op(i32attr, in_, *, optional_f32_attr=None, unit_attr=None, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip)) +// CHECK: return _maybe_cast(_get_op_result_or_op_results(AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip))) // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class AttributedOpWithOperands(_ods_ir.OpView): @@ -194,7 +194,7 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> { } // CHECK: def attributed_op_with_operands(_gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip)) +// CHECK: return _maybe_cast(_get_op_result_or_op_results(AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip))) // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class DefaultValuedAttrsOp(_ods_ir.OpView): @@ -218,7 +218,7 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> { } // CHECK: def default_valued_attrs(*, arr=None, unsupported=None, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip)) +// CHECK: return _maybe_cast(_get_op_result_or_op_results(DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip))) // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op" def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> { @@ -236,7 +236,7 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu } // CHECK: def derive_result_types_op(type_, *, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(DeriveResultTypesOp(type_=type_, loc=loc, ip=ip)) +// CHECK: return _maybe_cast(_get_op_result_or_op_results(DeriveResultTypesOp(type_=type_, loc=loc, ip=ip))) // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op" def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> { @@ -246,7 +246,7 @@ def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [Fir } // CHECK: def derive_result_types_variadic_op(res, _gen_res_1, type_, *, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(DeriveResultTypesVariadicOp(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip)) +// CHECK: return _maybe_cast(_get_op_result_or_op_results(DeriveResultTypesVariadicOp(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip))) // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class EmptyOp(_ods_ir.OpView): @@ -263,7 +263,7 @@ def EmptyOp : TestOp<"empty">; // CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)) // CHECK: def empty(*, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(EmptyOp(loc=loc, ip=ip)) +// CHECK: return _maybe_cast(_get_op_result_or_op_results(EmptyOp(loc=loc, ip=ip))) // CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op" def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> { @@ -276,7 +276,7 @@ def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> { } // CHECK: def infer_result_types_implied_op(*, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(InferResultTypesImpliedOp(loc=loc, ip=ip)) +// CHECK: return _maybe_cast(_get_op_result_or_op_results(InferResultTypesImpliedOp(loc=loc, ip=ip))) // CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_op" def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> { @@ -289,7 +289,7 @@ def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> } // CHECK: def infer_result_types_op(*, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(InferResultTypesOp(loc=loc, ip=ip)) +// CHECK: return _maybe_cast(_get_op_result_or_op_results(InferResultTypesOp(loc=loc, ip=ip))) // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class MissingNamesOp(_ods_ir.OpView): @@ -327,7 +327,7 @@ def MissingNamesOp : TestOp<"missing_names"> { } // CHECK: def missing_names(i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip)) +// CHECK: return _maybe_cast(_get_op_result_or_op_results(MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip))) // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class OneOptionalOperandOp(_ods_ir.OpView): @@ -358,7 +358,7 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> { } // CHECK: def one_optional_operand(non_optional, *, optional=None, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip)) +// CHECK: return _maybe_cast(_get_op_result_or_op_results(OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip))) // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class OneVariadicOperandOp(_ods_ir.OpView): @@ -390,7 +390,7 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> { } // CHECK: def one_variadic_operand(non_variadic, variadic, *, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip)) +// CHECK: return _maybe_cast(_get_op_result_or_op_results(OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip))) // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class OneVariadicResultOp(_ods_ir.OpView): @@ -423,7 +423,7 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> { } // CHECK: def one_variadic_result(variadic, non_variadic, *, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(OneVariadicResultOp(variadic=variadic, non_variadic=non_variadic, loc=loc, ip=ip)) +// CHECK: return _maybe_cast(_get_op_result_or_op_results(OneVariadicResultOp(variadic=variadic, non_variadic=non_variadic, loc=loc, ip=ip))) // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class PythonKeywordOp(_ods_ir.OpView): @@ -447,7 +447,7 @@ def PythonKeywordOp : TestOp<"python_keyword"> { } // CHECK: def python_keyword(in_, *, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(PythonKeywordOp(in_=in_, loc=loc, ip=ip)) +// CHECK: return _maybe_cast(_get_op_result_or_op_results(PythonKeywordOp(in_=in_, loc=loc, ip=ip))) // CHECK-LABEL: OPERATION_NAME = "test.same_results" def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> { @@ -461,7 +461,7 @@ def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> { } // CHECK: def same_results(in1, in2, *, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip)) +// CHECK: return _maybe_cast(_get_op_result_or_op_results(SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip))) // CHECK-LABEL: OPERATION_NAME = "test.same_results_variadic" def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResultType]> { @@ -471,7 +471,7 @@ def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResu } // CHECK: def same_results_variadic(res, in1, in2, *, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(SameResultsVariadicOp(res=res, in1=in1, in2=in2, loc=loc, ip=ip)) +// CHECK: return _maybe_cast(_get_op_result_or_op_results(SameResultsVariadicOp(res=res, in1=in1, in2=in2, loc=loc, ip=ip))) // CHECK: @_ods_cext.register_operation(_Dialect) @@ -498,7 +498,7 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand", } // CHECK: def same_variadic_operand(variadic1, non_variadic, variadic2, *, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)) +// CHECK: return _maybe_cast(_get_op_result_or_op_results(SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))) // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class SameVariadicResultSizeOp(_ods_ir.OpView): @@ -524,7 +524,7 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result", } // CHECK: def same_variadic_result(variadic1, non_variadic, variadic2, *, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(SameVariadicResultSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)) +// CHECK: return _maybe_cast(_get_op_result_or_op_results(SameVariadicResultSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))) // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class SimpleOp(_ods_ir.OpView): @@ -564,7 +564,7 @@ def SimpleOp : TestOp<"simple"> { } // CHECK: def simple(i64, f64, i32, f32, *, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip)) +// CHECK: return _maybe_cast(_get_op_result_or_op_results(SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip))) // CHECK: class VariadicAndNormalRegionOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.variadic_and_normal_region" @@ -591,7 +591,7 @@ def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> { } // CHECK: def variadic_and_normal_region(num_variadic, *, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip)) +// CHECK: return _maybe_cast(_get_op_result_or_op_results(VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))) // CHECK: class VariadicRegionOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.variadic_region" @@ -614,7 +614,7 @@ def VariadicRegionOp : TestOp<"variadic_region"> { } // CHECK: def variadic_region(num_variadic, *, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip)) +// CHECK: return _maybe_cast(_get_op_result_or_op_results(VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))) // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class WithSpecialCharactersOp(_ods_ir.OpView): @@ -623,7 +623,7 @@ def WithSpecialCharactersOp : TestOp<"123with--special.characters"> { } // CHECK: def _123with__special_characters(*, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(WithSpecialCharactersOp(loc=loc, ip=ip)) +// CHECK: return _maybe_cast(_get_op_result_or_op_results(WithSpecialCharactersOp(loc=loc, ip=ip))) // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class WithSuccessorsOp(_ods_ir.OpView): @@ -638,4 +638,4 @@ def WithSuccessorsOp : TestOp<"with_successors"> { } // CHECK: def with_successors(successor, successors, *, loc=None, ip=None) -// CHECK: return _get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip)) \ No newline at end of file +// CHECK: return _maybe_cast(_get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip))) \ No newline at end of file diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py index 6d1c5eab75898..180d30ff4cfb3 100644 --- a/mlir/test/python/dialects/arith_dialect.py +++ b/mlir/test/python/dialects/arith_dialect.py @@ -1,8 +1,9 @@ # RUN: %PYTHON %s | FileCheck %s +from functools import partialmethod from mlir.ir import * -import mlir.dialects.func as func import mlir.dialects.arith as arith +from mlir.dialects._ods_common import maybe_cast def run(f): @@ -35,14 +36,71 @@ def testFastMathFlags(): print(r) -# CHECK-LABEL: TEST: testArithValueBuilder +# CHECK-LABEL: TEST: testArithValue @run -def testArithValueBuilder(): +def testArithValue(): + def _binary_op(lhs, rhs, op: str): + op = op.capitalize() + if arith._is_float_type(lhs.type): + op += "F" + elif arith._is_integer_like_type(lhs.type): + op += "I" + else: + raise NotImplementedError(f"Unsupported '{op}' operands: {lhs}, {rhs}") + + op = getattr(arith, f"{op}Op") + return maybe_cast(op(lhs, rhs).result) + + @register_value_caster(F16Type.static_typeid) + @register_value_caster(F32Type.static_typeid) + @register_value_caster(F64Type.static_typeid) + @register_value_caster(IntegerType.static_typeid) + class ArithValue(Value): + __add__ = partialmethod(_binary_op, op="add") + __sub__ = partialmethod(_binary_op, op="sub") + __mul__ = partialmethod(_binary_op, op="mul") + + def __str__(self): + return super().__str__().replace("Value", "ArithValue") + + @register_value_caster(IntegerType.static_typeid, priority=0) + class ArithValue1(Value): + __mul__ = partialmethod(_binary_op, op="mul") + + def __str__(self): + return super().__str__().replace("Value", "ArithValue1") + + @register_value_caster(IntegerType.static_typeid, priority=0) + def no_op_caster(val): + print("no_op_caster", val) + return None + with Context() as ctx, Location.unknown(): module = Module.create() + f16_t = F16Type.get() f32_t = F32Type.get() + f64_t = F64Type.get() + i32 = IntegerType.get_signless(32) with InsertionPoint(module.body): + a = arith.constant(value=FloatAttr.get(f16_t, 42.42)) + b = a + a + # CHECK: ArithValue(%0 = arith.addf %cst, %cst : f16) + print(b) + a = arith.constant(value=FloatAttr.get(f32_t, 42.42)) - # CHECK: %cst = arith.constant 4.242000e+01 : f32 - print(a) + b = a - a + # CHECK: ArithValue(%1 = arith.subf %cst_0, %cst_0 : f32) + print(b) + + a = arith.constant(value=FloatAttr.get(f64_t, 42.42)) + b = a * a + # CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64) + print(b) + + # CHECK: no_op_caster Value(%c1_i32 = arith.constant 1 : i32) + a = arith.constant(value=IntegerAttr.get(i32, 1)) + b = a * a + # CHECK: no_op_caster Value(%3 = arith.muli %c1_i32, %c1_i32 : i32) + # CHECK: ArithValue1(%3 = arith.muli %c1_i32, %c1_i32 : i32) + print(b) diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp index c8ef84721090a..170ac6b87c693 100644 --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -30,7 +30,16 @@ constexpr const char *fileHeader = R"Py( # Autogenerated by mlir-tblgen; don't manually edit. from ._ods_common import _cext as _ods_cext -from ._ods_common import segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, get_op_result_or_op_results as _get_op_result_or_op_results +from ._ods_common import ( + SubClassValueT as _SubClassValueT, + equally_sized_accessor as _ods_equally_sized_accessor, + get_default_loc_context as _ods_get_default_loc_context, + get_op_result_or_op_results as _get_op_result_or_op_results, + get_op_result_or_value as _get_op_result_or_value, + get_op_results_or_values as _get_op_results_or_values, + maybe_cast as _maybe_cast, + segmented_accessor as _ods_segmented_accessor, +) _ods_ir = _ods_cext.ir import builtins @@ -263,7 +272,7 @@ constexpr const char *regionAccessorTemplate = R"Py( constexpr const char *valueBuilderTemplate = R"Py( def {0}({2}) -> {4}: - return _get_op_result_or_op_results({1}({3})) + return _maybe_cast(_get_op_result_or_op_results({1}({3}))) )Py"; static llvm::cl::OptionCategory @@ -1004,8 +1013,8 @@ static void emitValueBuilder(const Operator &op, llvm::join(valueBuilderParams, ", "), llvm::join(opBuilderArgs, ", "), (op.getNumResults() > 1 - ? "_Sequence[_ods_ir.OpResult]" - : (op.getNumResults() > 0 ? "_ods_ir.OpResult" + ? "_Sequence[_SubClassValueT]" + : (op.getNumResults() > 0 ? "_SubClassValueT" : "_ods_ir.Operation"))); } From f80ce47068a3e97fc4a24d3ae04bc41272230174 Mon Sep 17 00:00:00 2001 From: max Date: Mon, 23 Oct 2023 10:07:12 -0500 Subject: [PATCH 02/10] add new line to op-python-bindings.td --- mlir/test/mlir-tblgen/op-python-bindings.td | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td index 96b0c170dc5bb..9844040f8a33c 100644 --- a/mlir/test/mlir-tblgen/op-python-bindings.td +++ b/mlir/test/mlir-tblgen/op-python-bindings.td @@ -638,4 +638,4 @@ def WithSuccessorsOp : TestOp<"with_successors"> { } // CHECK: def with_successors(successor, successors, *, loc=None, ip=None) -// CHECK: return _maybe_cast(_get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip))) \ No newline at end of file +// CHECK: return _maybe_cast(_get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip))) From 42b5ebccfe167a7006e3debe996bace9e0de1ff5 Mon Sep 17 00:00:00 2001 From: max Date: Wed, 25 Oct 2023 14:23:52 -0500 Subject: [PATCH 03/10] WIP opresult and opoperand and blockarg casting --- .../mlir/Bindings/Python/PybindAdaptors.h | 1 + mlir/lib/Bindings/Python/Globals.h | 16 ++++ mlir/lib/Bindings/Python/IRCore.cpp | 52 +++++++++- mlir/lib/Bindings/Python/IRModule.cpp | 46 +++++++++ mlir/lib/Bindings/Python/IRModule.h | 10 +- mlir/lib/Bindings/Python/MainModule.cpp | 12 +++ mlir/lib/Bindings/Python/PybindUtils.h | 2 +- mlir/python/mlir/dialects/_ods_common.py | 52 +--------- mlir/python/mlir/ir.py | 16 +--- mlir/test/mlir-tblgen/op-python-bindings.td | 48 +++++----- mlir/test/python/dialects/arith_dialect.py | 28 ++---- mlir/test/python/ir/value.py | 96 +++++++++++++++++++ mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 3 +- 13 files changed, 260 insertions(+), 122 deletions(-) diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index 49680c8b79b13..acc90e4ab9a22 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -234,6 +234,7 @@ struct type_caster { return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) .attr("Value") .attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule) + .attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)() .release(); }; }; diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h index 976297257ced0..7baf029436408 100644 --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/lib/Bindings/Python/Globals.h @@ -66,6 +66,13 @@ class PyGlobals { void registerTypeCaster(MlirTypeID mlirTypeID, pybind11::function typeCaster, bool replace = false); + /// Adds a user-friendly value caster. Raises an exception if the mapping + /// already exists and replace == false. This is intended to be called by + /// implementation code. + void registerValueCaster(MlirTypeID mlirTypeID, + pybind11::function valueCaster, + bool replace = false); + /// Adds a concrete implementation dialect class. /// Raises an exception if the mapping already exists. /// This is intended to be called by implementation code. @@ -86,6 +93,10 @@ class PyGlobals { std::optional lookupTypeCaster(MlirTypeID mlirTypeID, MlirDialect dialect); + /// Returns the custom value caster for MlirTypeID mlirTypeID. + std::optional lookupValueCaster(MlirTypeID mlirTypeID, + MlirDialect dialect); + /// Looks up a registered dialect class by namespace. Note that this may /// trigger loading of the defining module and can arbitrarily re-enter. std::optional @@ -110,6 +121,11 @@ class PyGlobals { /// Map of MlirTypeID to custom type caster. llvm::DenseMap typeCasterMap; + /// Map of MlirTypeID to custom value caster. + llvm::DenseMap valueCasterMap; + /// Cache for map of MlirTypeID to custom value caster. + llvm::DenseMap valueCasterMapCache; + /// Set of dialect namespaces that we have attempted to import implementation /// modules for. llvm::StringSet<> loadedDialectModules; diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 7cfea31dbb2e8..2c7ffda4e0880 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1899,13 +1899,26 @@ bool PyTypeID::operator==(const PyTypeID &other) const { } //------------------------------------------------------------------------------ -// PyValue and subclases. +// PyValue and subclasses. //------------------------------------------------------------------------------ pybind11::object PyValue::getCapsule() { return py::reinterpret_steal(mlirPythonValueToCapsule(get())); } +pybind11::object PyValue::maybeDownCast() { + MlirType type = mlirValueGetType(get()); + MlirTypeID mlirTypeID = mlirTypeGetTypeID(type); + assert(!mlirTypeIDIsNull(mlirTypeID) && + "mlirTypeID was expected to be non-null."); + std::optional valueCaster = + PyGlobals::get().lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type)); + py::object this_ = py::cast(this, py::return_value_policy::move); + if (!valueCaster) + return this_; + return valueCaster.value()(this_); +} + PyValue PyValue::createFromCapsule(pybind11::object capsule) { MlirValue value = mlirPythonCapsuleToValue(capsule.ptr()); if (mlirValueIsNull(value)) @@ -2121,6 +2134,8 @@ class PyConcreteValue : public PyValue { return DerivedTy::isaFunction(otherValue); }, py::arg("other_value")); + cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, + [](DerivedTy &self) { return self.maybeDownCast(); }); DerivedTy::bindDerived(cls); } @@ -2193,6 +2208,7 @@ class PyBlockArgumentList : public Sliceable { public: static constexpr const char *pyClassName = "BlockArgumentList"; + using SliceableT = Sliceable; PyBlockArgumentList(PyOperationRef operation, MlirBlock block, intptr_t startIndex = 0, intptr_t length = -1, @@ -2202,6 +2218,13 @@ class PyBlockArgumentList step), operation(std::move(operation)), block(block) {} + pybind11::object getItem(intptr_t index) override { + auto item = this->SliceableT::getItem(index); + if (item.ptr() != nullptr) + return item.attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)(); + return item; + } + static void bindDerived(ClassTy &c) { c.def_property_readonly("types", [](PyBlockArgumentList &self) { return getValueTypes(self, self.operation->getContext()); @@ -2241,6 +2264,7 @@ class PyBlockArgumentList class PyOpOperandList : public Sliceable { public: static constexpr const char *pyClassName = "OpOperandList"; + using SliceableT = Sliceable; PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0, intptr_t length = -1, intptr_t step = 1) @@ -2250,6 +2274,13 @@ class PyOpOperandList : public Sliceable { step), operation(operation) {} + pybind11::object getItem(intptr_t index) override { + auto item = this->SliceableT::getItem(index); + if (item.ptr() != nullptr) + return item.attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)(); + return item; + } + void dunderSetItem(intptr_t index, PyValue value) { index = wrapIndex(index); mlirOperationSetOperand(operation->get(), index, value.get()); @@ -2296,6 +2327,7 @@ class PyOpOperandList : public Sliceable { class PyOpResultList : public Sliceable { public: static constexpr const char *pyClassName = "OpResultList"; + using SliceableT = Sliceable; PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0, intptr_t length = -1, intptr_t step = 1) @@ -2303,7 +2335,14 @@ class PyOpResultList : public Sliceable { length == -1 ? mlirOperationGetNumResults(operation->get()) : length, step), - operation(operation) {} + operation(std::move(operation)) {} + + pybind11::object getItem(intptr_t index) override { + auto item = this->SliceableT::getItem(index); + if (item.ptr() != nullptr) + return item.attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)(); + return item; + } static void bindDerived(ClassTy &c) { c.def_property_readonly("types", [](PyOpResultList &self) { @@ -2891,8 +2930,9 @@ void mlir::python::populateIRCore(py::module &m) { "single result)") .str()); } - return PyOpResult(operation.getRef(), - mlirOperationGetResult(operation, 0)); + PyOpResult result = PyOpResult( + operation.getRef(), mlirOperationGetResult(operation, 0)); + return result.maybeDownCast(); }, "Shortcut to get an op result if it has only one (throws an error " "otherwise).") @@ -3566,7 +3606,9 @@ void mlir::python::populateIRCore(py::module &m) { [](PyValue &self, PyValue &with) { mlirValueReplaceAllUsesOfWith(self.get(), with.get()); }, - kValueReplaceAllUsesWithDocstring); + kValueReplaceAllUsesWithDocstring) + .def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR, + [](PyValue &self) { return self.maybeDownCast(); }); PyBlockArgument::bind(m); PyOpResult::bind(m); PyOpOperand::bind(m); diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp index 6c5cde86236ce..b3fc67aa86551 100644 --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -88,6 +88,19 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID, found = std::move(typeCaster); } +void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID, + pybind11::function valueCaster, + bool replace) { + pybind11::object &found = valueCasterMap[mlirTypeID]; + if (found && !found.is_none() && !replace) + throw std::runtime_error("Value caster is already registered"); + found = std::move(valueCaster); + const auto foundIt = valueCasterMapCache.find(mlirTypeID); + if (foundIt != valueCasterMapCache.end() && !foundIt->second.is_none()) { + valueCasterMapCache[mlirTypeID] = found; + } +} + void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, py::object pyClass) { py::object &found = dialectClassMap[dialectNamespace]; @@ -134,6 +147,39 @@ std::optional PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID, return std::nullopt; } +std::optional PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID, + MlirDialect dialect) { + { + // Fast match against the value caster map first (common case). + const auto foundIt = valueCasterMapCache.find(mlirTypeID); + if (foundIt != valueCasterMapCache.end()) { + if (foundIt->second.is_none()) + return std::nullopt; + assert(foundIt->second && "py::function is defined"); + return foundIt->second; + } + } + + // Not found. Load the dialect namespace. + loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); + + // Attempt to find from the canonical map and cache. + { + const auto foundIt = valueCasterMap.find(mlirTypeID); + if (foundIt != valueCasterMap.end()) { + if (foundIt->second.is_none()) + return std::nullopt; + assert(foundIt->second && "py::object is defined"); + // Positive cache. + valueCasterMapCache[mlirTypeID] = foundIt->second; + return foundIt->second; + } + // Negative cache. + valueCasterMap[mlirTypeID] = py::none(); + return std::nullopt; + } +} + std::optional PyGlobals::lookupDialectClass(const std::string &dialectNamespace) { // Make sure dialect module is loaded. diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index 01ee4975d0e9a..b95c4578fbc22 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -761,7 +761,7 @@ class PyRegion { /// Wrapper around an MlirAsmState. class PyAsmState { - public: +public: PyAsmState(MlirValue value, bool useLocalScope) { flags = mlirOpPrintingFlagsCreate(); // The OpPrintingFlags are not exposed Python side, create locally and @@ -780,16 +780,14 @@ class PyAsmState { state = mlirAsmStateCreateForOperation(operation.getOperation().get(), flags); } - ~PyAsmState() { - mlirOpPrintingFlagsDestroy(flags); - } + ~PyAsmState() { mlirOpPrintingFlagsDestroy(flags); } // Delete copy constructors. PyAsmState(PyAsmState &other) = delete; PyAsmState(const PyAsmState &other) = delete; MlirAsmState get() { return state; } - private: +private: MlirAsmState state; MlirOpPrintingFlags flags; }; @@ -1124,6 +1122,8 @@ class PyValue { /// Gets a capsule wrapping the void* within the MlirValue. pybind11::object getCapsule(); + virtual pybind11::object maybeDownCast(); + /// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of /// the underlying MlirValue is still tied to the owning operation. static PyValue createFromCapsule(pybind11::object capsule); diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 2ba3a3677198c..68de57d389fbe 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -93,6 +93,18 @@ PYBIND11_MODULE(_mlir, m) { }, "typeid"_a, "type_caster"_a, "replace"_a = false, "Register a type caster for casting MLIR types to custom user types."); + m.def( + "register_value_caster", + [](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function { + return py::cpp_function( + [mlirTypeID, replace](py::object valueCaster) -> py::object { + PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster, + replace); + return valueCaster; + }); + }, + "typeid"_a, "replace"_a = false, + "Register a value caster for casting MLIR values to custom user values."); // Define and populate IR submodule. auto irModule = m.def_submodule("ir", "MLIR IR Bindings"); diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h index 2a8da20bee049..efb7b713f80a4 100644 --- a/mlir/lib/Bindings/Python/PybindUtils.h +++ b/mlir/lib/Bindings/Python/PybindUtils.h @@ -231,7 +231,7 @@ class Sliceable { /// Returns the element at the given slice index. Supports negative indices /// by taking elements in inverse order. Returns a nullptr object if out /// of bounds. - pybind11::object getItem(intptr_t index) { + virtual pybind11::object getItem(intptr_t index) { // Negative indices mean we count from the end. index = wrapIndex(index); if (index < 0) { diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py index dd41ee63c8bf7..fa73c197c17fa 100644 --- a/mlir/python/mlir/dialects/_ods_common.py +++ b/mlir/python/mlir/dialects/_ods_common.py @@ -7,7 +7,6 @@ # with a relative import. from .._mlir_libs import _mlir as _cext from typing import ( - Callable as _Callable, Sequence as _Sequence, Type as _Type, TypeVar as _TypeVar, @@ -132,50 +131,7 @@ def get_op_result_or_op_results( ) -U = _TypeVar("U", bound=_cext.ir.Value) -SubClassValueT = _Type[U] - -ValueCasterT = _Callable[ - [_Union[_cext.ir.Value, _cext.ir.OpResult]], _Union[SubClassValueT, None] -] - -_VALUE_CASTERS: defaultdict[ - _cext.ir.TypeID, - _Sequence[ValueCasterT], -] = defaultdict(list) - - -def has_value_caster(typeid: _cext.ir.TypeID): - if not isinstance(typeid, _cext.ir.TypeID): - raise ValueError(f"{typeid=} is not a TypeID") - if typeid in _VALUE_CASTERS: - return True - return False - - -def get_value_caster(typeid: _cext.ir.TypeID): - if not has_value_caster(typeid): - raise ValueError(f"no registered caster for {typeid=}") - return _VALUE_CASTERS[typeid] - - -def maybe_cast( - val: _Union[ - _cext.ir.Value, - _cext.ir.OpResult, - _Sequence[_cext.ir.Value], - _Sequence[_cext.ir.OpResult], - _cext.ir.Operation, - ] -) -> _Union[SubClassValueT, _Sequence[SubClassValueT], _cext.ir.Operation]: - if isinstance(val, (tuple, list)): - return tuple(map(maybe_cast, val)) - - if not isinstance(val, _cext.ir.Value) and not isinstance(val, _cext.ir.OpResult): - return val - - if has_value_caster(val.type.typeid): - for caster in get_value_caster(val.type.typeid): - if casted := caster(val): - return casted - return val +# This is the standard way to indicate subclass/inheritance relationship +# see the typing.Type doc string. +_U = _TypeVar("_U", bound=_cext.ir.Value) +SubClassValueT = _Type[_U] diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py index 74b32bfc5de76..18526ab8c3c02 100644 --- a/mlir/python/mlir/ir.py +++ b/mlir/python/mlir/ir.py @@ -4,21 +4,7 @@ from ._mlir_libs._mlir.ir import * from ._mlir_libs._mlir.ir import _GlobalDebug -from ._mlir_libs._mlir import register_type_caster -from .dialects._ods_common import ValueCasterT, _VALUE_CASTERS - - -def register_value_caster(typeid: TypeID, priority: int = None): - def wrapper(caster: ValueCasterT): - if not isinstance(typeid, TypeID): - raise ValueError(f"{typeid=} is not a TypeID") - if priority is None: - _VALUE_CASTERS[typeid].append(caster) - else: - _VALUE_CASTERS[typeid].insert(priority, caster) - return caster - - return wrapper +from ._mlir_libs._mlir import register_type_caster, register_value_caster # Convenience decorator for registering user-friendly Attribute builders. diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td index 9844040f8a33c..f7df8ba2df0ae 100644 --- a/mlir/test/mlir-tblgen/op-python-bindings.td +++ b/mlir/test/mlir-tblgen/op-python-bindings.td @@ -61,7 +61,7 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands", } // CHECK: def attr_sized_operands(variadic1, non_variadic, *, variadic2=None, loc=None, ip=None) -// CHECK: return _maybe_cast(_get_op_result_or_op_results(AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))) +// CHECK: return _get_op_result_or_op_results(AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)) // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class AttrSizedResultsOp(_ods_ir.OpView): @@ -108,7 +108,7 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results", } // CHECK: def attr_sized_results(variadic1, non_variadic, variadic2, *, loc=None, ip=None) -// CHECK: return _maybe_cast(_get_op_result_or_op_results(AttrSizedResultsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))) +// CHECK: return _get_op_result_or_op_results(AttrSizedResultsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)) // CHECK: @_ods_cext.register_operation(_Dialect) @@ -158,7 +158,7 @@ def AttributedOp : TestOp<"attributed_op"> { } // CHECK: def attributed_op(i32attr, in_, *, optional_f32_attr=None, unit_attr=None, loc=None, ip=None) -// CHECK: return _maybe_cast(_get_op_result_or_op_results(AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip))) +// CHECK: return _get_op_result_or_op_results(AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip)) // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class AttributedOpWithOperands(_ods_ir.OpView): @@ -194,7 +194,7 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> { } // CHECK: def attributed_op_with_operands(_gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None) -// CHECK: return _maybe_cast(_get_op_result_or_op_results(AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip))) +// CHECK: return _get_op_result_or_op_results(AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip)) // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class DefaultValuedAttrsOp(_ods_ir.OpView): @@ -218,7 +218,7 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> { } // CHECK: def default_valued_attrs(*, arr=None, unsupported=None, loc=None, ip=None) -// CHECK: return _maybe_cast(_get_op_result_or_op_results(DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip))) +// CHECK: return _get_op_result_or_op_results(DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip)) // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op" def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> { @@ -236,7 +236,7 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu } // CHECK: def derive_result_types_op(type_, *, loc=None, ip=None) -// CHECK: return _maybe_cast(_get_op_result_or_op_results(DeriveResultTypesOp(type_=type_, loc=loc, ip=ip))) +// CHECK: return _get_op_result_or_op_results(DeriveResultTypesOp(type_=type_, loc=loc, ip=ip)) // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op" def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> { @@ -246,7 +246,7 @@ def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [Fir } // CHECK: def derive_result_types_variadic_op(res, _gen_res_1, type_, *, loc=None, ip=None) -// CHECK: return _maybe_cast(_get_op_result_or_op_results(DeriveResultTypesVariadicOp(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip))) +// CHECK: return _get_op_result_or_op_results(DeriveResultTypesVariadicOp(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip)) // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class EmptyOp(_ods_ir.OpView): @@ -263,7 +263,7 @@ def EmptyOp : TestOp<"empty">; // CHECK: successors=_ods_successors, regions=regions, loc=loc, ip=ip)) // CHECK: def empty(*, loc=None, ip=None) -// CHECK: return _maybe_cast(_get_op_result_or_op_results(EmptyOp(loc=loc, ip=ip))) +// CHECK: return _get_op_result_or_op_results(EmptyOp(loc=loc, ip=ip)) // CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op" def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> { @@ -276,7 +276,7 @@ def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> { } // CHECK: def infer_result_types_implied_op(*, loc=None, ip=None) -// CHECK: return _maybe_cast(_get_op_result_or_op_results(InferResultTypesImpliedOp(loc=loc, ip=ip))) +// CHECK: return _get_op_result_or_op_results(InferResultTypesImpliedOp(loc=loc, ip=ip)) // CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_op" def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> { @@ -289,7 +289,7 @@ def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> } // CHECK: def infer_result_types_op(*, loc=None, ip=None) -// CHECK: return _maybe_cast(_get_op_result_or_op_results(InferResultTypesOp(loc=loc, ip=ip))) +// CHECK: return _get_op_result_or_op_results(InferResultTypesOp(loc=loc, ip=ip)) // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class MissingNamesOp(_ods_ir.OpView): @@ -327,7 +327,7 @@ def MissingNamesOp : TestOp<"missing_names"> { } // CHECK: def missing_names(i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None) -// CHECK: return _maybe_cast(_get_op_result_or_op_results(MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip))) +// CHECK: return _get_op_result_or_op_results(MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip)) // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class OneOptionalOperandOp(_ods_ir.OpView): @@ -358,7 +358,7 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> { } // CHECK: def one_optional_operand(non_optional, *, optional=None, loc=None, ip=None) -// CHECK: return _maybe_cast(_get_op_result_or_op_results(OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip))) +// CHECK: return _get_op_result_or_op_results(OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip)) // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class OneVariadicOperandOp(_ods_ir.OpView): @@ -390,7 +390,7 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> { } // CHECK: def one_variadic_operand(non_variadic, variadic, *, loc=None, ip=None) -// CHECK: return _maybe_cast(_get_op_result_or_op_results(OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip))) +// CHECK: return _get_op_result_or_op_results(OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip)) // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class OneVariadicResultOp(_ods_ir.OpView): @@ -423,7 +423,7 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> { } // CHECK: def one_variadic_result(variadic, non_variadic, *, loc=None, ip=None) -// CHECK: return _maybe_cast(_get_op_result_or_op_results(OneVariadicResultOp(variadic=variadic, non_variadic=non_variadic, loc=loc, ip=ip))) +// CHECK: return _get_op_result_or_op_results(OneVariadicResultOp(variadic=variadic, non_variadic=non_variadic, loc=loc, ip=ip)) // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class PythonKeywordOp(_ods_ir.OpView): @@ -447,7 +447,7 @@ def PythonKeywordOp : TestOp<"python_keyword"> { } // CHECK: def python_keyword(in_, *, loc=None, ip=None) -// CHECK: return _maybe_cast(_get_op_result_or_op_results(PythonKeywordOp(in_=in_, loc=loc, ip=ip))) +// CHECK: return _get_op_result_or_op_results(PythonKeywordOp(in_=in_, loc=loc, ip=ip)) // CHECK-LABEL: OPERATION_NAME = "test.same_results" def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> { @@ -461,7 +461,7 @@ def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> { } // CHECK: def same_results(in1, in2, *, loc=None, ip=None) -// CHECK: return _maybe_cast(_get_op_result_or_op_results(SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip))) +// CHECK: return _get_op_result_or_op_results(SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip)) // CHECK-LABEL: OPERATION_NAME = "test.same_results_variadic" def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResultType]> { @@ -471,7 +471,7 @@ def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResu } // CHECK: def same_results_variadic(res, in1, in2, *, loc=None, ip=None) -// CHECK: return _maybe_cast(_get_op_result_or_op_results(SameResultsVariadicOp(res=res, in1=in1, in2=in2, loc=loc, ip=ip))) +// CHECK: return _get_op_result_or_op_results(SameResultsVariadicOp(res=res, in1=in1, in2=in2, loc=loc, ip=ip)) // CHECK: @_ods_cext.register_operation(_Dialect) @@ -498,7 +498,7 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand", } // CHECK: def same_variadic_operand(variadic1, non_variadic, variadic2, *, loc=None, ip=None) -// CHECK: return _maybe_cast(_get_op_result_or_op_results(SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))) +// CHECK: return _get_op_result_or_op_results(SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)) // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class SameVariadicResultSizeOp(_ods_ir.OpView): @@ -524,7 +524,7 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result", } // CHECK: def same_variadic_result(variadic1, non_variadic, variadic2, *, loc=None, ip=None) -// CHECK: return _maybe_cast(_get_op_result_or_op_results(SameVariadicResultSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))) +// CHECK: return _get_op_result_or_op_results(SameVariadicResultSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)) // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class SimpleOp(_ods_ir.OpView): @@ -564,7 +564,7 @@ def SimpleOp : TestOp<"simple"> { } // CHECK: def simple(i64, f64, i32, f32, *, loc=None, ip=None) -// CHECK: return _maybe_cast(_get_op_result_or_op_results(SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip))) +// CHECK: return _get_op_result_or_op_results(SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip)) // CHECK: class VariadicAndNormalRegionOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.variadic_and_normal_region" @@ -591,7 +591,7 @@ def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> { } // CHECK: def variadic_and_normal_region(num_variadic, *, loc=None, ip=None) -// CHECK: return _maybe_cast(_get_op_result_or_op_results(VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))) +// CHECK: return _get_op_result_or_op_results(VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip)) // CHECK: class VariadicRegionOp(_ods_ir.OpView): // CHECK-LABEL: OPERATION_NAME = "test.variadic_region" @@ -614,7 +614,7 @@ def VariadicRegionOp : TestOp<"variadic_region"> { } // CHECK: def variadic_region(num_variadic, *, loc=None, ip=None) -// CHECK: return _maybe_cast(_get_op_result_or_op_results(VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))) +// CHECK: return _get_op_result_or_op_results(VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip)) // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class WithSpecialCharactersOp(_ods_ir.OpView): @@ -623,7 +623,7 @@ def WithSpecialCharactersOp : TestOp<"123with--special.characters"> { } // CHECK: def _123with__special_characters(*, loc=None, ip=None) -// CHECK: return _maybe_cast(_get_op_result_or_op_results(WithSpecialCharactersOp(loc=loc, ip=ip))) +// CHECK: return _get_op_result_or_op_results(WithSpecialCharactersOp(loc=loc, ip=ip)) // CHECK: @_ods_cext.register_operation(_Dialect) // CHECK: class WithSuccessorsOp(_ods_ir.OpView): @@ -638,4 +638,4 @@ def WithSuccessorsOp : TestOp<"with_successors"> { } // CHECK: def with_successors(successor, successors, *, loc=None, ip=None) -// CHECK: return _maybe_cast(_get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip))) +// CHECK: return _get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip)) diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py index 180d30ff4cfb3..39c3d5799a656 100644 --- a/mlir/test/python/dialects/arith_dialect.py +++ b/mlir/test/python/dialects/arith_dialect.py @@ -3,7 +3,7 @@ from mlir.ir import * import mlir.dialects.arith as arith -from mlir.dialects._ods_common import maybe_cast +import mlir.dialects.func as func def run(f): @@ -49,31 +49,22 @@ def _binary_op(lhs, rhs, op: str): raise NotImplementedError(f"Unsupported '{op}' operands: {lhs}, {rhs}") op = getattr(arith, f"{op}Op") - return maybe_cast(op(lhs, rhs).result) + return op(lhs, rhs).result @register_value_caster(F16Type.static_typeid) @register_value_caster(F32Type.static_typeid) @register_value_caster(F64Type.static_typeid) @register_value_caster(IntegerType.static_typeid) class ArithValue(Value): + def __init__(self, v): + super().__init__(v) + __add__ = partialmethod(_binary_op, op="add") __sub__ = partialmethod(_binary_op, op="sub") __mul__ = partialmethod(_binary_op, op="mul") def __str__(self): - return super().__str__().replace("Value", "ArithValue") - - @register_value_caster(IntegerType.static_typeid, priority=0) - class ArithValue1(Value): - __mul__ = partialmethod(_binary_op, op="mul") - - def __str__(self): - return super().__str__().replace("Value", "ArithValue1") - - @register_value_caster(IntegerType.static_typeid, priority=0) - def no_op_caster(val): - print("no_op_caster", val) - return None + return super().__str__().replace(Value.__name__, ArithValue.__name__) with Context() as ctx, Location.unknown(): module = Module.create() @@ -97,10 +88,3 @@ def no_op_caster(val): b = a * a # CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64) print(b) - - # CHECK: no_op_caster Value(%c1_i32 = arith.constant 1 : i32) - a = arith.constant(value=IntegerAttr.get(i32, 1)) - b = a * a - # CHECK: no_op_caster Value(%3 = arith.muli %c1_i32, %c1_i32 : i32) - # CHECK: ArithValue1(%3 = arith.muli %c1_i32, %c1_i32 : i32) - print(b) diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py index ddf653dcce278..1c3e1a6ae9654 100644 --- a/mlir/test/python/ir/value.py +++ b/mlir/test/python/ir/value.py @@ -270,3 +270,99 @@ def testValueSetType(): # CHECK: %[[VAL1]] = "custom.op1"() : () -> i64 print(value.owner) + + +# CHECK-LABEL: TEST: testValueCasters +@run +def testValueCasters(): + class NOPResult(OpResult): + def __init__(self, v): + super().__init__(v) + + def __str__(self): + return super().__str__().replace(Value.__name__, NOPResult.__name__) + + class NOPValue(Value): + def __init__(self, v): + super().__init__(v) + + def __str__(self): + return super().__str__().replace(Value.__name__, NOPValue.__name__) + + class NOPBlockArg(BlockArgument): + def __init__(self, v): + super().__init__(v) + + def __str__(self): + return super().__str__().replace(Value.__name__, NOPBlockArg.__name__) + + @register_value_caster(IntegerType.static_typeid) + def cast_int(v): + print("in caster", v.__class__.__name__) + if isinstance(v, OpResult): + return NOPResult(v) + if isinstance(v, BlockArgument): + return NOPBlockArg(v) + elif isinstance(v, Value): + return NOPValue(v) + + ctx = Context() + ctx.allow_unregistered_dialects = True + with Location.unknown(ctx): + i32 = IntegerType.get_signless(32) + module = Module.create() + with InsertionPoint(module.body): + values = Operation.create("custom.op1", results=[i32, i32]).results + # CHECK: in caster OpResult + # CHECK: result 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32)) + print("result", values[0].result_number, values[0]) + # CHECK: in caster OpResult + # CHECK: result 1 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32)) + print("result", values[1].result_number, values[0]) + + value0, value1 = values + # CHECK: in caster OpResult + # CHECK: result 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32)) + print("result", value0.result_number, values[0]) + # CHECK: in caster OpResult + # CHECK: result 1 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32)) + print("result", value1.result_number, values[0]) + + op1 = Operation.create("custom.op2", operands=[value0, value1]) + # CHECK: "custom.op2"(%0#0, %0#1) : (i32, i32) -> () + print(op1) + + # CHECK: in caster Value + # CHECK: operand 0 NOPValue(%0:2 = "custom.op1"() : () -> (i32, i32)) + print("operand 0", op1.operands[0]) + # CHECK: in caster Value + # CHECK: operand 1 NOPValue(%0:2 = "custom.op1"() : () -> (i32, i32)) + print("operand 1", op1.operands[1]) + + # CHECK: in caster BlockArgument + # CHECK: in caster BlockArgument + @func.FuncOp.from_py_func(i32, i32) + def reduction(arg0, arg1): + # CHECK: as func arg 0 NOPBlockArg + print("as func arg", arg0.arg_number, arg0.__class__.__name__) + # CHECK: as func arg 1 NOPBlockArg + print("as func arg", arg1.arg_number, arg1.__class__.__name__) + + @register_value_caster(IntegerType.static_typeid, replace=True) + def dont_cast_int(v): + print("don't cast", v.result_number, v) + return v + + with Location.unknown(ctx): + i32 = IntegerType.get_signless(32) + module = Module.create() + with InsertionPoint(module.body): + # CHECK: don't cast 0 Value(%0 = "custom.op1"() : () -> i32) + new_value = Operation.create("custom.op1", results=[i32]).result + # CHECK: result 0 Value(%0 = "custom.op1"() : () -> i32) + print("result", new_value.result_number, new_value) + + # CHECK: don't cast 0 Value(%1 = "custom.op2"() : () -> i32) + new_value = Operation.create("custom.op2", results=[i32]).results[0] + # CHECK: result 0 Value(%1 = "custom.op2"() : () -> i32) + print("result", new_value.result_number, new_value) diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp index 170ac6b87c693..0c0ad2cfeffdc 100644 --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -37,7 +37,6 @@ from ._ods_common import ( get_op_result_or_op_results as _get_op_result_or_op_results, get_op_result_or_value as _get_op_result_or_value, get_op_results_or_values as _get_op_results_or_values, - maybe_cast as _maybe_cast, segmented_accessor as _ods_segmented_accessor, ) _ods_ir = _ods_cext.ir @@ -272,7 +271,7 @@ constexpr const char *regionAccessorTemplate = R"Py( constexpr const char *valueBuilderTemplate = R"Py( def {0}({2}) -> {4}: - return _maybe_cast(_get_op_result_or_op_results({1}({3}))) + return _get_op_result_or_op_results({1}({3})) )Py"; static llvm::cl::OptionCategory From 31792be9d0b9da95ea35e27a2a3d4511273634c9 Mon Sep 17 00:00:00 2001 From: max Date: Mon, 30 Oct 2023 13:34:21 -0500 Subject: [PATCH 04/10] done with opresult, blockarg casting --- mlir/include/mlir-c/Bindings/Python/Interop.h | 18 ++++++++- mlir/lib/Bindings/Python/IRCore.cpp | 12 +++--- mlir/lib/Bindings/Python/IRModule.cpp | 3 +- mlir/lib/Bindings/Python/MainModule.cpp | 2 +- mlir/python/mlir/dialects/_ods_common.py | 1 - mlir/test/python/dialects/arith_dialect.py | 3 +- mlir/test/python/dialects/python_test.py | 6 +++ mlir/test/python/ir/value.py | 29 ++++++++++++-- mlir/test/python/lib/PythonTestModule.cpp | 40 +++++++++++++++---- 9 files changed, 90 insertions(+), 24 deletions(-) diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h index f79c10cb93838..9b026a6b922de 100644 --- a/mlir/include/mlir-c/Bindings/Python/Interop.h +++ b/mlir/include/mlir-c/Bindings/Python/Interop.h @@ -121,10 +121,26 @@ * def register_type_caster(MlirTypeID mlirTypeID, py::function typeCaster, * bool replace) * where replace indicates the typeCaster should replace any existing registered - * type casters (such as those for upstream ConcreteTypes). + * type casters (such as those for upstream ConcreteTypes). The interface of the + * typeCaster is: + * def type_caster(ir.Type) -> SubClassTypeT + * where SubClassTypeT indicates the result should be a subclass (inherit from) + * ir.Type. */ #define MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR "register_type_caster" +/** Attribute on main C extension module (_mlir) that corresponds to the + * value caster registration binding. The signature of the function is: + * def register_value_caster(MlirTypeID mlirTypeID, bool replace, + * py::function valueCaster) + * where replace indicates the valueCaster should replace any existing + * registered value casters. The interface of the valueCaster is: + * def value_caster(ir.Value) -> SubClassValueT + * where SubClassValueT indicates the result should be a subclass (inherit from) + * ir.Value. + */ +#define MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR "register_value_caster" + /// Gets a void* from a wrapped struct. Needed because const cast is different /// between C/C++. #ifdef __cplusplus diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 2c7ffda4e0880..53eb75f810c18 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1913,10 +1913,10 @@ pybind11::object PyValue::maybeDownCast() { "mlirTypeID was expected to be non-null."); std::optional valueCaster = PyGlobals::get().lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type)); - py::object this_ = py::cast(this, py::return_value_policy::move); + py::object thisObj = py::cast(this, py::return_value_policy::move); if (!valueCaster) - return this_; - return valueCaster.value()(this_); + return thisObj; + return valueCaster.value()(thisObj); } PyValue PyValue::createFromCapsule(pybind11::object capsule) { @@ -2930,9 +2930,9 @@ void mlir::python::populateIRCore(py::module &m) { "single result)") .str()); } - PyOpResult result = PyOpResult( - operation.getRef(), mlirOperationGetResult(operation, 0)); - return result.maybeDownCast(); + return PyOpResult(operation.getRef(), + mlirOperationGetResult(operation, 0)) + .maybeDownCast(); }, "Shortcut to get an op result if it has only one (throws an error " "otherwise).") diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp index b3fc67aa86551..e8973bd1884b7 100644 --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -93,7 +93,8 @@ void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID, bool replace) { pybind11::object &found = valueCasterMap[mlirTypeID]; if (found && !found.is_none() && !replace) - throw std::runtime_error("Value caster is already registered"); + throw std::runtime_error("Value caster is already registered: " + + py::repr(found).cast()); found = std::move(valueCaster); const auto foundIt = valueCasterMapCache.find(mlirTypeID); if (foundIt != valueCasterMapCache.end() && !foundIt->second.is_none()) { diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 68de57d389fbe..8693583f26c05 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -94,7 +94,7 @@ PYBIND11_MODULE(_mlir, m) { "typeid"_a, "type_caster"_a, "replace"_a = false, "Register a type caster for casting MLIR types to custom user types."); m.def( - "register_value_caster", + MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR, [](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function { return py::cpp_function( [mlirTypeID, replace](py::object valueCaster) -> py::object { diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py index fa73c197c17fa..60ce83c09f171 100644 --- a/mlir/python/mlir/dialects/_ods_common.py +++ b/mlir/python/mlir/dialects/_ods_common.py @@ -1,7 +1,6 @@ # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -from collections import defaultdict # Provide a convenient name for sub-packages to resolve the main C-extension # with a relative import. diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py index 39c3d5799a656..c8d21dfb62ed5 100644 --- a/mlir/test/python/dialects/arith_dialect.py +++ b/mlir/test/python/dialects/arith_dialect.py @@ -39,7 +39,7 @@ def testFastMathFlags(): # CHECK-LABEL: TEST: testArithValue @run def testArithValue(): - def _binary_op(lhs, rhs, op: str): + def _binary_op(lhs, rhs, op: str) -> "ArithValue": op = op.capitalize() if arith._is_float_type(lhs.type): op += "F" @@ -71,7 +71,6 @@ def __str__(self): f16_t = F16Type.get() f32_t = F32Type.get() f64_t = F64Type.get() - i32 = IntegerType.get_signless(32) with InsertionPoint(module.body): a = arith.constant(value=FloatAttr.get(f16_t, 42.42)) diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py index 472db7e5124db..157e5c28f19e7 100644 --- a/mlir/test/python/dialects/python_test.py +++ b/mlir/test/python/dialects/python_test.py @@ -425,6 +425,12 @@ def __str__(self): # And it should be equal to the in-tree concrete type assert test.TestIntegerRankedTensorType.static_typeid == t.type.typeid + d = tensor.EmptyOp([1, 2, 3], IntegerType.get_signless(5)).result + # CHECK: Value(%{{.*}} = tensor.empty() : tensor<1x2x3xi5>) + print(d) + # CHECK: SubClassValueT: print("in caster", v.__class__.__name__) if isinstance(v, OpResult): return NOPResult(v) @@ -318,7 +319,10 @@ def cast_int(v): print("result", values[0].result_number, values[0]) # CHECK: in caster OpResult # CHECK: result 1 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32)) - print("result", values[1].result_number, values[0]) + print("result", values[1].result_number, values[1]) + + # CHECK: results slice 0 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32)) + print("results slice", values[:1][0].result_number, values[:1][0]) value0, value1 = values # CHECK: in caster OpResult @@ -326,7 +330,7 @@ def cast_int(v): print("result", value0.result_number, values[0]) # CHECK: in caster OpResult # CHECK: result 1 NOPResult(%0:2 = "custom.op1"() : () -> (i32, i32)) - print("result", value1.result_number, values[0]) + print("result", value1.result_number, values[1]) op1 = Operation.create("custom.op2", operands=[value0, value1]) # CHECK: "custom.op2"(%0#0, %0#1) : (i32, i32) -> () @@ -348,8 +352,25 @@ def reduction(arg0, arg1): # CHECK: as func arg 1 NOPBlockArg print("as func arg", arg1.arg_number, arg1.__class__.__name__) + # CHECK: args slice 0 NOPBlockArg( of type 'i32' at index: 0) + print( + "args slice", + reduction.func_op.arguments[:1][0].arg_number, + reduction.func_op.arguments[:1][0], + ) + + try: + + @register_value_caster(IntegerType.static_typeid) + def dont_cast_int_shouldnt_register(v): + ... + + except RuntimeError as e: + # CHECK: Value caster is already registered: .cast_int at + print(e) + @register_value_caster(IntegerType.static_typeid, replace=True) - def dont_cast_int(v): + def dont_cast_int(v) -> Value: print("don't cast", v.result_number, v) return v diff --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp index f533082a0a147..1e584343d0f0a 100644 --- a/mlir/test/python/lib/PythonTestModule.cpp +++ b/mlir/test/python/lib/PythonTestModule.cpp @@ -42,6 +42,7 @@ PYBIND11_MODULE(_mlirPythonTest, m) { return cls(mlirPythonTestTestAttributeGet(ctx)); }, py::arg("cls"), py::arg("context") = py::none()); + mlir_type_subclass(m, "TestType", mlirTypeIsAPythonTestTestType, mlirPythonTestTestTypeGetTypeID) .def_classmethod( @@ -50,7 +51,8 @@ PYBIND11_MODULE(_mlirPythonTest, m) { return cls(mlirPythonTestTestTypeGet(ctx)); }, py::arg("cls"), py::arg("context") = py::none()); - auto cls = + + auto typeCls = mlir_type_subclass(m, "TestIntegerRankedTensorType", mlirTypeIsARankedIntegerTensor, py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) @@ -65,16 +67,38 @@ PYBIND11_MODULE(_mlirPythonTest, m) { encoding)); }, "cls"_a, "shape"_a, "width"_a, "context"_a = py::none()); - assert(py::hasattr(cls.get_class(), "static_typeid") && + + assert(py::hasattr(typeCls.get_class(), "static_typeid") && "TestIntegerRankedTensorType has no static_typeid"); - MlirTypeID mlirTypeID = mlirRankedTensorTypeGetTypeID(); + + MlirTypeID mlirRankedTensorTypeID = mlirRankedTensorTypeGetTypeID(); + py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)( - mlirTypeID, pybind11::cpp_function([cls](const py::object &mlirType) { - return cls.get_class()(mlirType); + mlirRankedTensorTypeID, + pybind11::cpp_function([typeCls](const py::object &mlirType) { + return typeCls.get_class()(mlirType); }), /*replace=*/true); - mlir_value_subclass(m, "TestTensorValue", - mlirTypeIsAPythonTestTestTensorValue) - .def("is_null", [](MlirValue &self) { return mlirValueIsNull(self); }); + + auto valueCls = mlir_value_subclass(m, "TestTensorValue", + mlirTypeIsAPythonTestTestTensorValue) + .def("is_null", [](MlirValue &self) { + return mlirValueIsNull(self); + }); + + py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) + .attr(MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR)( + mlirRankedTensorTypeID)( + pybind11::cpp_function([valueCls](const py::object &valueObj) { + py::object capsule = mlirApiObjectToCapsule(valueObj); + MlirValue v = mlirPythonCapsuleToValue(capsule.ptr()); + MlirType t = mlirValueGetType(v); + if (mlirShapedTypeHasStaticShape(t) && + mlirShapedTypeGetDimSize(t, 0) == 1 && + mlirShapedTypeGetDimSize(t, 1) == 2 && + mlirShapedTypeGetDimSize(t, 2) == 3) + return valueCls.get_class()(valueObj); + return valueObj; + })); } From 69ae05c2e2fa4dd15616eed9b86ec2f8d5fb2042 Mon Sep 17 00:00:00 2001 From: max Date: Tue, 31 Oct 2023 14:00:16 -0500 Subject: [PATCH 05/10] remove valuecastercache --- mlir/lib/Bindings/Python/Globals.h | 4 --- mlir/lib/Bindings/Python/IRModule.cpp | 38 +++++---------------------- 2 files changed, 6 insertions(+), 36 deletions(-) diff --git a/mlir/lib/Bindings/Python/Globals.h b/mlir/lib/Bindings/Python/Globals.h index 7baf029436408..a022067f5c7e5 100644 --- a/mlir/lib/Bindings/Python/Globals.h +++ b/mlir/lib/Bindings/Python/Globals.h @@ -120,12 +120,8 @@ class PyGlobals { llvm::StringMap attributeBuilderMap; /// Map of MlirTypeID to custom type caster. llvm::DenseMap typeCasterMap; - /// Map of MlirTypeID to custom value caster. llvm::DenseMap valueCasterMap; - /// Cache for map of MlirTypeID to custom value caster. - llvm::DenseMap valueCasterMapCache; - /// Set of dialect namespaces that we have attempted to import implementation /// modules for. llvm::StringSet<> loadedDialectModules; diff --git a/mlir/lib/Bindings/Python/IRModule.cpp b/mlir/lib/Bindings/Python/IRModule.cpp index e8973bd1884b7..5538924d24818 100644 --- a/mlir/lib/Bindings/Python/IRModule.cpp +++ b/mlir/lib/Bindings/Python/IRModule.cpp @@ -92,14 +92,10 @@ void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID, pybind11::function valueCaster, bool replace) { pybind11::object &found = valueCasterMap[mlirTypeID]; - if (found && !found.is_none() && !replace) + if (found && !replace) throw std::runtime_error("Value caster is already registered: " + py::repr(found).cast()); found = std::move(valueCaster); - const auto foundIt = valueCasterMapCache.find(mlirTypeID); - if (foundIt != valueCasterMapCache.end() && !foundIt->second.is_none()) { - valueCasterMapCache[mlirTypeID] = found; - } } void PyGlobals::registerDialectImpl(const std::string &dialectNamespace, @@ -150,35 +146,13 @@ std::optional PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID, std::optional PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID, MlirDialect dialect) { - { - // Fast match against the value caster map first (common case). - const auto foundIt = valueCasterMapCache.find(mlirTypeID); - if (foundIt != valueCasterMapCache.end()) { - if (foundIt->second.is_none()) - return std::nullopt; - assert(foundIt->second && "py::function is defined"); - return foundIt->second; - } - } - - // Not found. Load the dialect namespace. loadDialectModule(unwrap(mlirDialectGetNamespace(dialect))); - - // Attempt to find from the canonical map and cache. - { - const auto foundIt = valueCasterMap.find(mlirTypeID); - if (foundIt != valueCasterMap.end()) { - if (foundIt->second.is_none()) - return std::nullopt; - assert(foundIt->second && "py::object is defined"); - // Positive cache. - valueCasterMapCache[mlirTypeID] = foundIt->second; - return foundIt->second; - } - // Negative cache. - valueCasterMap[mlirTypeID] = py::none(); - return std::nullopt; + const auto foundIt = valueCasterMap.find(mlirTypeID); + if (foundIt != valueCasterMap.end()) { + assert(foundIt->second && "value caster is defined"); + return foundIt->second; } + return std::nullopt; } std::optional From a7ac3cf2eed7ef225f3dd4b2c371fe8aa2137fc5 Mon Sep 17 00:00:00 2001 From: max Date: Tue, 31 Oct 2023 15:30:03 -0500 Subject: [PATCH 06/10] use detection idiom inside of `getItem` instead of virtual member --- mlir/lib/Bindings/Python/IRCore.cpp | 21 --------------------- mlir/lib/Bindings/Python/IRModule.h | 3 ++- mlir/lib/Bindings/Python/PybindUtils.h | 17 ++++++++++++++--- mlir/test/python/dialects/arith_dialect.py | 3 +++ mlir/test/python/ir/value.py | 3 ++- 5 files changed, 21 insertions(+), 26 deletions(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 53eb75f810c18..5f6b7d380bc02 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -2218,13 +2218,6 @@ class PyBlockArgumentList step), operation(std::move(operation)), block(block) {} - pybind11::object getItem(intptr_t index) override { - auto item = this->SliceableT::getItem(index); - if (item.ptr() != nullptr) - return item.attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)(); - return item; - } - static void bindDerived(ClassTy &c) { c.def_property_readonly("types", [](PyBlockArgumentList &self) { return getValueTypes(self, self.operation->getContext()); @@ -2274,13 +2267,6 @@ class PyOpOperandList : public Sliceable { step), operation(operation) {} - pybind11::object getItem(intptr_t index) override { - auto item = this->SliceableT::getItem(index); - if (item.ptr() != nullptr) - return item.attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)(); - return item; - } - void dunderSetItem(intptr_t index, PyValue value) { index = wrapIndex(index); mlirOperationSetOperand(operation->get(), index, value.get()); @@ -2337,13 +2323,6 @@ class PyOpResultList : public Sliceable { step), operation(std::move(operation)) {} - pybind11::object getItem(intptr_t index) override { - auto item = this->SliceableT::getItem(index); - if (item.ptr() != nullptr) - return item.attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)(); - return item; - } - static void bindDerived(ClassTy &c) { c.def_property_readonly("types", [](PyOpResultList &self) { return getValueTypes(self, self.operation->getContext()); diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index b95c4578fbc22..c6442ce8fefde 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -1110,6 +1110,7 @@ class PyConcreteAttribute : public BaseTy { /// bindings so such operation always exists). class PyValue { public: + virtual ~PyValue() = default; PyValue(PyOperationRef parentOperation, MlirValue value) : parentOperation(std::move(parentOperation)), value(value) {} operator MlirValue() const { return value; } @@ -1122,7 +1123,7 @@ class PyValue { /// Gets a capsule wrapping the void* within the MlirValue. pybind11::object getCapsule(); - virtual pybind11::object maybeDownCast(); + pybind11::object maybeDownCast(); /// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of /// the underlying MlirValue is still tied to the owning operation. diff --git a/mlir/lib/Bindings/Python/PybindUtils.h b/mlir/lib/Bindings/Python/PybindUtils.h index efb7b713f80a4..38462ac8ba6db 100644 --- a/mlir/lib/Bindings/Python/PybindUtils.h +++ b/mlir/lib/Bindings/Python/PybindUtils.h @@ -10,6 +10,7 @@ #define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H #include "mlir-c/Support.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Twine.h" #include "llvm/Support/DataTypes.h" @@ -228,10 +229,15 @@ class Sliceable { return linearIndex; } + /// Trait to check if T provides a `maybeDownCast` method. + /// Note, you need the & to detect inherited members. + template + using has_maybe_downcast = decltype(&T::maybeDownCast); + /// Returns the element at the given slice index. Supports negative indices /// by taking elements in inverse order. Returns a nullptr object if out /// of bounds. - virtual pybind11::object getItem(intptr_t index) { + pybind11::object getItem(intptr_t index) { // Negative indices mean we count from the end. index = wrapIndex(index); if (index < 0) { @@ -239,8 +245,13 @@ class Sliceable { return {}; } - return pybind11::cast( - static_cast(this)->getRawElement(linearizeIndex(index))); + if constexpr (llvm::is_detected::value) + return static_cast(this) + ->getRawElement(linearizeIndex(index)) + .maybeDownCast(); + else + return pybind11::cast( + static_cast(this)->getRawElement(linearizeIndex(index))); } /// Returns a new instance of the pseudo-container restricted to the given diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py index c8d21dfb62ed5..25a258f3e3688 100644 --- a/mlir/test/python/dialects/arith_dialect.py +++ b/mlir/test/python/dialects/arith_dialect.py @@ -74,6 +74,9 @@ def __str__(self): with InsertionPoint(module.body): a = arith.constant(value=FloatAttr.get(f16_t, 42.42)) + # CHECK: ArithValue(%cst = arith.constant 4.240 + print(a) + b = a + a # CHECK: ArithValue(%0 = arith.addf %cst, %cst : f16) print(b) diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py index f665f05c9f1a8..3723c00785f03 100644 --- a/mlir/test/python/ir/value.py +++ b/mlir/test/python/ir/value.py @@ -370,7 +370,8 @@ def dont_cast_int_shouldnt_register(v): print(e) @register_value_caster(IntegerType.static_typeid, replace=True) - def dont_cast_int(v) -> Value: + def dont_cast_int(v) -> OpResult: + assert isinstance(v, OpResult) print("don't cast", v.result_number, v) return v From 700a237cec4967727ab1ad08ff2eb05e064c8b68 Mon Sep 17 00:00:00 2001 From: max Date: Tue, 31 Oct 2023 18:21:44 -0500 Subject: [PATCH 07/10] incorporate remaining comments --- mlir/lib/Bindings/Python/IRCore.cpp | 2 ++ mlir/test/python/ir/value.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Bindings/Python/IRCore.cpp b/mlir/lib/Bindings/Python/IRCore.cpp index 5f6b7d380bc02..0f2ca666ccc05 100644 --- a/mlir/lib/Bindings/Python/IRCore.cpp +++ b/mlir/lib/Bindings/Python/IRCore.cpp @@ -1913,6 +1913,8 @@ pybind11::object PyValue::maybeDownCast() { "mlirTypeID was expected to be non-null."); std::optional valueCaster = PyGlobals::get().lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type)); + // py::return_value_policy::move means use std::move to move the return value + // contents into a new instance that will be owned by Python. py::object thisObj = py::cast(this, py::return_value_policy::move); if (!valueCaster) return thisObj; diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py index 3723c00785f03..acbf463113a6d 100644 --- a/mlir/test/python/ir/value.py +++ b/mlir/test/python/ir/value.py @@ -366,7 +366,7 @@ def dont_cast_int_shouldnt_register(v): ... except RuntimeError as e: - # CHECK: Value caster is already registered: .cast_int at + # CHECK: Value caster is already registered: {{.*}}cast_int print(e) @register_value_caster(IntegerType.static_typeid, replace=True) From ea64522bfd9967342ffd70e5774415a80e7193da Mon Sep 17 00:00:00 2001 From: max Date: Fri, 3 Nov 2023 13:03:19 -0500 Subject: [PATCH 08/10] incorporate more comments --- mlir/lib/Bindings/Python/MainModule.cpp | 2 -- mlir/test/python/dialects/python_test.py | 2 +- mlir/test/python/lib/PythonTestModule.cpp | 3 +++ 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 8693583f26c05..098584483b185 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -12,8 +12,6 @@ #include "IRModule.h" #include "Pass.h" -#include - namespace py = pybind11; using namespace mlir; using namespace py::literals; diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py index 157e5c28f19e7..02c76caded6bf 100644 --- a/mlir/test/python/dialects/python_test.py +++ b/mlir/test/python/dialects/python_test.py @@ -428,7 +428,7 @@ def __str__(self): d = tensor.EmptyOp([1, 2, 3], IntegerType.get_signless(5)).result # CHECK: Value(%{{.*}} = tensor.empty() : tensor<1x2x3xi5>) print(d) - # CHECK: Date: Mon, 6 Nov 2023 16:32:30 -0600 Subject: [PATCH 09/10] kwonly replace arg --- mlir/include/mlir-c/Bindings/Python/Interop.h | 27 +++++++++---------- .../mlir/Bindings/Python/PybindAdaptors.h | 9 +++---- mlir/lib/Bindings/Python/MainModule.cpp | 18 ++++++++----- mlir/test/python/dialects/arith_dialect.py | 6 +++-- mlir/test/python/dialects/python_test.py | 12 ++++----- mlir/test/python/lib/PythonTestModule.cpp | 7 +++-- 6 files changed, 40 insertions(+), 39 deletions(-) diff --git a/mlir/include/mlir-c/Bindings/Python/Interop.h b/mlir/include/mlir-c/Bindings/Python/Interop.h index 9b026a6b922de..0a36e97c2ae68 100644 --- a/mlir/include/mlir-c/Bindings/Python/Interop.h +++ b/mlir/include/mlir-c/Bindings/Python/Interop.h @@ -118,25 +118,24 @@ /** Attribute on main C extension module (_mlir) that corresponds to the * type caster registration binding. The signature of the function is: - * def register_type_caster(MlirTypeID mlirTypeID, py::function typeCaster, - * bool replace) - * where replace indicates the typeCaster should replace any existing registered - * type casters (such as those for upstream ConcreteTypes). The interface of the - * typeCaster is: - * def type_caster(ir.Type) -> SubClassTypeT - * where SubClassTypeT indicates the result should be a subclass (inherit from) - * ir.Type. + * def register_type_caster(MlirTypeID mlirTypeID, *, bool replace) + * which then takes a typeCaster (register_type_caster is meant to be used as a + * decorator from python), and where replace indicates the typeCaster should + * replace any existing registered type casters (such as those for upstream + * ConcreteTypes). The interface of the typeCaster is: def type_caster(ir.Type) + * -> SubClassTypeT where SubClassTypeT indicates the result should be a + * subclass (inherit from) ir.Type. */ #define MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR "register_type_caster" /** Attribute on main C extension module (_mlir) that corresponds to the * value caster registration binding. The signature of the function is: - * def register_value_caster(MlirTypeID mlirTypeID, bool replace, - * py::function valueCaster) - * where replace indicates the valueCaster should replace any existing - * registered value casters. The interface of the valueCaster is: - * def value_caster(ir.Value) -> SubClassValueT - * where SubClassValueT indicates the result should be a subclass (inherit from) + * def register_value_caster(MlirTypeID mlirTypeID, *, bool replace) + * which then takes a valueCaster (register_value_caster is meant to be used as + * a decorator, from python), and where replace indicates the valueCaster should + * replace any existing registered value casters. The interface of the + * valueCaster is: def value_caster(ir.Value) -> SubClassValueT where + * SubClassValueT indicates the result should be a subclass (inherit from) * ir.Value. */ #define MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR "register_value_caster" diff --git a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h index acc90e4ab9a22..5e0e56fc00a67 100644 --- a/mlir/include/mlir/Bindings/Python/PybindAdaptors.h +++ b/mlir/include/mlir/Bindings/Python/PybindAdaptors.h @@ -497,11 +497,10 @@ class mlir_type_subclass : public pure_subclass { if (getTypeIDFunction) { py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)( - getTypeIDFunction(), - pybind11::cpp_function( - [thisClass = thisClass](const py::object &mlirType) { - return thisClass(mlirType); - })); + getTypeIDFunction())(pybind11::cpp_function( + [thisClass = thisClass](const py::object &mlirType) { + return thisClass(mlirType); + })); } } }; diff --git a/mlir/lib/Bindings/Python/MainModule.cpp b/mlir/lib/Bindings/Python/MainModule.cpp index 098584483b185..17272472ccca4 100644 --- a/mlir/lib/Bindings/Python/MainModule.cpp +++ b/mlir/lib/Bindings/Python/MainModule.cpp @@ -44,7 +44,8 @@ PYBIND11_MODULE(_mlir, m) { "dialect_namespace"_a, "dialect_class"_a, "Testing hook for directly registering a dialect") .def("_register_operation_impl", &PyGlobals::registerOperationImpl, - "operation_name"_a, "operation_class"_a, "replace"_a = false, + "operation_name"_a, "operation_class"_a, py::kw_only(), + "replace"_a = false, "Testing hook for directly registering an operation"); // Aside from making the globals accessible to python, having python manage @@ -80,16 +81,19 @@ PYBIND11_MODULE(_mlir, m) { return opClass; }); }, - "dialect_class"_a, "replace"_a = false, + "dialect_class"_a, py::kw_only(), "replace"_a = false, "Produce a class decorator for registering an Operation class as part of " "a dialect"); m.def( MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR, - [](MlirTypeID mlirTypeID, py::function typeCaster, bool replace) { - PyGlobals::get().registerTypeCaster(mlirTypeID, std::move(typeCaster), - replace); + [](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function { + return py::cpp_function([mlirTypeID, + replace](py::object typeCaster) -> py::object { + PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace); + return typeCaster; + }); }, - "typeid"_a, "type_caster"_a, "replace"_a = false, + "typeid"_a, py::kw_only(), "replace"_a = false, "Register a type caster for casting MLIR types to custom user types."); m.def( MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR, @@ -101,7 +105,7 @@ PYBIND11_MODULE(_mlir, m) { return valueCaster; }); }, - "typeid"_a, "replace"_a = false, + "typeid"_a, py::kw_only(), "replace"_a = false, "Register a value caster for casting MLIR values to custom user values."); // Define and populate IR submodule. diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py index 25a258f3e3688..f80f2c084a0f3 100644 --- a/mlir/test/python/dialects/arith_dialect.py +++ b/mlir/test/python/dialects/arith_dialect.py @@ -41,9 +41,11 @@ def testFastMathFlags(): def testArithValue(): def _binary_op(lhs, rhs, op: str) -> "ArithValue": op = op.capitalize() - if arith._is_float_type(lhs.type): + if arith._is_float_type(lhs.type) and arith._is_float_type(rhs.type): op += "F" - elif arith._is_integer_like_type(lhs.type): + elif arith._is_integer_like_type(lhs.type) and arith._is_integer_like_type( + lhs.type + ): op += "I" else: raise NotImplementedError(f"Unsupported '{op}' operands: {lhs}, {rhs}") diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py index 02c76caded6bf..f313a400b73c0 100644 --- a/mlir/test/python/dialects/python_test.py +++ b/mlir/test/python/dialects/python_test.py @@ -508,19 +508,18 @@ def testCustomTypeTypeCaster(): # CHECK: Type caster is already registered try: + @register_type_caster(c.typeid) def type_caster(pytype): return test.TestIntegerRankedTensorType(pytype) - register_type_caster(c.typeid, type_caster) except RuntimeError as e: print(e) - def type_caster(pytype): - return RankedTensorType(pytype) - # python_test dialect registers a caster for RankedTensorType in its extension (pybind) module. # So this one replaces that one (successfully). And then just to be sure we restore the original caster below. - register_type_caster(c.typeid, type_caster, replace=True) + @register_type_caster(c.typeid, replace=True) + def type_caster(pytype): + return RankedTensorType(pytype) d = tensor.EmptyOp([10, 10], IntegerType.get_signless(5)).result # CHECK: tensor<10x10xi5> @@ -528,11 +527,10 @@ def type_caster(pytype): # CHECK: ranked tensor type RankedTensorType(tensor<10x10xi5>) print("ranked tensor type", repr(d.type)) + @register_type_caster(c.typeid, replace=True) def type_caster(pytype): return test.TestIntegerRankedTensorType(pytype) - register_type_caster(c.typeid, type_caster, replace=True) - d = tensor.EmptyOp([10, 10], IntegerType.get_signless(5)).result # CHECK: tensor<10x10xi5> print(d.type) diff --git a/mlir/test/python/lib/PythonTestModule.cpp b/mlir/test/python/lib/PythonTestModule.cpp index b37843a8ef466..aff414894cb82 100644 --- a/mlir/test/python/lib/PythonTestModule.cpp +++ b/mlir/test/python/lib/PythonTestModule.cpp @@ -74,12 +74,11 @@ PYBIND11_MODULE(_mlirPythonTest, m) { MlirTypeID mlirRankedTensorTypeID = mlirRankedTensorTypeGetTypeID(); py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir")) - .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)( - mlirRankedTensorTypeID, + .attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(mlirRankedTensorTypeID, + "replace"_a = true)( pybind11::cpp_function([typeCls](const py::object &mlirType) { return typeCls.get_class()(mlirType); - }), - /*replace=*/true); + })); auto valueCls = mlir_value_subclass(m, "TestTensorValue", mlirTypeIsAPythonTestTestTensorValue) From 27712ea3f589844b78110503fed027bc53339ad3 Mon Sep 17 00:00:00 2001 From: Maksim Levental Date: Tue, 7 Nov 2023 03:03:35 -0600 Subject: [PATCH 10/10] Add comment about virtual --- mlir/lib/Bindings/Python/IRModule.h | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mlir/lib/Bindings/Python/IRModule.h b/mlir/lib/Bindings/Python/IRModule.h index c6442ce8fefde..af55693f18fbb 100644 --- a/mlir/lib/Bindings/Python/IRModule.h +++ b/mlir/lib/Bindings/Python/IRModule.h @@ -1110,6 +1110,9 @@ class PyConcreteAttribute : public BaseTy { /// bindings so such operation always exists). class PyValue { public: + // The virtual here is "load bearing" in that it enables RTTI + // for PyConcreteValue CRTP classes that support maybeDownCast. + // See PyValue::maybeDownCast. virtual ~PyValue() = default; PyValue(PyOperationRef parentOperation, MlirValue value) : parentOperation(std::move(parentOperation)), value(value) {}