[mlir][python] value casting #69644

Merged
merged 10 commits into from
Nov 7, 2023

Conversation

@makslevental makslevental commented Oct 19, 2023

This PR adds "value casting", i.e., a mechanism to wrap ir.Value in a proxy class that overloads dunders such as __add__, __sub__, and __mul__ for fun and great profit.

This is thematically similar to bfb1ba7 and 9566ee2. The example in the test demonstrates the value of the feature (no pun intended):

    @register_value_caster(F16Type.static_typeid)
    @register_value_caster(F32Type.static_typeid)
    @register_value_caster(F64Type.static_typeid)
    @register_value_caster(IntegerType.static_typeid)
    class ArithValue(Value):
        __add__ = partialmethod(_binary_op, op="add")
        __sub__ = partialmethod(_binary_op, op="sub")
        __mul__ = partialmethod(_binary_op, op="mul")

    a = arith.constant(value=FloatAttr.get(f16_t, 42.42))
    b = a + a
    # CHECK: ArithValue(%0 = arith.addf %cst, %cst : f16)
    print(b)

    a = arith.constant(value=FloatAttr.get(f32_t, 42.42))
    b = a - a
    # CHECK: ArithValue(%1 = arith.subf %cst_0, %cst_0 : f32)
    print(b)

    a = arith.constant(value=FloatAttr.get(f64_t, 42.42))
    b = a * a
    # CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64)
    print(b)

EDIT: this now goes through the bindings and thus supports automatic casting of OpResult (including as an element of OpResultList), BlockArgument (including as an element of BlockArgumentList), as well as Value.
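A minimal sketch (not from the PR's tests) of what that enables, assuming the ArithValue caster shown above is registered for IntegerType.static_typeid; the func.FuncOp usage here is only illustrative:

    from mlir.dialects import func
    from mlir.ir import Context, InsertionPoint, IntegerType, Location, Module

    with Context(), Location.unknown():
        module = Module.create()
        i32 = IntegerType.get_signless(32)
        with InsertionPoint(module.body):
            f = func.FuncOp("add_i32", ([i32, i32], [i32]))
            with InsertionPoint(f.add_entry_block()):
                # Elements of the BlockArgumentList come back cast, i.e., as
                # ArithValue instances rather than plain BlockArguments.
                a, b = f.entry_block.arguments
                func.ReturnOp([a + b])  # __add__ builds an arith.addi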

@makslevental makslevental marked this pull request as ready for review October 19, 2023 22:27
@llvmbot llvmbot added the mlir:core (MLIR Core Infrastructure), mlir:python (MLIR Python bindings), and mlir labels on Oct 19, 2023

llvmbot commented Oct 19, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Maksim Levental (makslevental)

Changes

This PR adds "value casting", i.e., a mechanism to wrap ir.Value in a proxy class that overloads dunders such as __add__, __sub__, and __mul__ for fun and great profit.

This is thematically similar to bfb1ba7 and 9566ee2 (and indeed relies on the former) but is implemented at the Python API level because

  1. it's quite tough to intervene in all the places ir.Value is returned from the pybind code, in particular PyOpResultList::slice;
  2. registering Python classes to be conjured up by the pybind code is fairly brittle IMHO;
  3. it's always better to have things be purely in Python so they can more easily be understood/patched/monkey-patched.

This has been my secret (lol) plan for quite a while now; it wasn't possible until #68308, but now we've arrived.

The design looks like this:

  • a ValueCasterT is a callable that takes an ir.Value (or ir.OpResult) and returns an instance of a subclass of ir.Value, or None;
  • the user registers a value caster for a TypeID (using register_value_caster); that caster is responsible for casting ir.Values whose type.typeid matches;
  • an ir.Value produced by a value builder is passed through maybe_cast, which applies a caster if one is registered for value.type.typeid.

In the absence of a registered value caster, the ir.Value or ir.OpResult is just returned unscathed.

The user can optionally provide a priority when registering the value caster in order to allow for more or less specific proxies; for example, one might want to attempt to wrap a tensor<10x10x!tt.ptr<f32>> (a tensor of Triton pointers), which has the same TypeID as RankedTensorType, for which a caster might already have been registered (possibly upstream). Should the higher-priority caster fail, it can return None and the next caster in priority order will be tried.
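A minimal sketch of that layered registration, using the register_value_caster decorator added here; TritonTensorValue and the "!tt.ptr" element-type check are hypothetical stand-ins for the downstream pieces:

    from mlir.ir import RankedTensorType, Value, register_value_caster


    @register_value_caster(RankedTensorType.static_typeid)
    class TensorValue(Value):
        """Generic proxy an upstream package might register for ranked tensors."""


    class TritonTensorValue(Value):
        """Hypothetical proxy specialized for tensors of !tt.ptr elements."""


    # priority=0 inserts this caster ahead of TensorValue, so it is tried first.
    @register_value_caster(RankedTensorType.static_typeid, priority=0)
    def cast_triton_tensor(val):
        if "!tt.ptr" in str(RankedTensorType(val.type).element_type):
            return TritonTensorValue(val)
        return None  # decline; maybe_cast falls through to TensorValue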

Note, this doesn't affect any paths other than through the "value builders".

cc @stephenneuendorffer


Patch is 21.46 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/69644.diff

5 Files Affected:

  • (modified) mlir/python/mlir/dialects/_ods_common.py (+57-1)
  • (modified) mlir/python/mlir/ir.py (+14)
  • (modified) mlir/test/mlir-tblgen/op-python-bindings.td (+24-24)
  • (modified) mlir/test/python/dialects/arith_dialect.py (+63-5)
  • (modified) mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp (+13-4)
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 9cca7d659ec8cb3..dd41ee63c8bf7af 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -1,11 +1,18 @@
 #  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 #  See https://llvm.org/LICENSE.txt for license information.
 #  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+from collections import defaultdict
 
 # Provide a convenient name for sub-packages to resolve the main C-extension
 # with a relative import.
 from .._mlir_libs import _mlir as _cext
-from typing import Sequence as _Sequence, Union as _Union
+from typing import (
+    Callable as _Callable,
+    Sequence as _Sequence,
+    Type as _Type,
+    TypeVar as _TypeVar,
+    Union as _Union,
+)
 
 __all__ = [
     "equally_sized_accessor",
@@ -123,3 +130,52 @@ def get_op_result_or_op_results(
         if len(op.results) > 0
         else op
     )
+
+
+U = _TypeVar("U", bound=_cext.ir.Value)
+SubClassValueT = _Type[U]
+
+ValueCasterT = _Callable[
+    [_Union[_cext.ir.Value, _cext.ir.OpResult]], _Union[SubClassValueT, None]
+]
+
+_VALUE_CASTERS: defaultdict[
+    _cext.ir.TypeID,
+    _Sequence[ValueCasterT],
+] = defaultdict(list)
+
+
+def has_value_caster(typeid: _cext.ir.TypeID):
+    if not isinstance(typeid, _cext.ir.TypeID):
+        raise ValueError(f"{typeid=} is not a TypeID")
+    if typeid in _VALUE_CASTERS:
+        return True
+    return False
+
+
+def get_value_caster(typeid: _cext.ir.TypeID):
+    if not has_value_caster(typeid):
+        raise ValueError(f"no registered caster for {typeid=}")
+    return _VALUE_CASTERS[typeid]
+
+
+def maybe_cast(
+    val: _Union[
+        _cext.ir.Value,
+        _cext.ir.OpResult,
+        _Sequence[_cext.ir.Value],
+        _Sequence[_cext.ir.OpResult],
+        _cext.ir.Operation,
+    ]
+) -> _Union[SubClassValueT, _Sequence[SubClassValueT], _cext.ir.Operation]:
+    if isinstance(val, (tuple, list)):
+        return tuple(map(maybe_cast, val))
+
+    if not isinstance(val, _cext.ir.Value) and not isinstance(val, _cext.ir.OpResult):
+        return val
+
+    if has_value_caster(val.type.typeid):
+        for caster in get_value_caster(val.type.typeid):
+            if casted := caster(val):
+                return casted
+    return val
diff --git a/mlir/python/mlir/ir.py b/mlir/python/mlir/ir.py
index 43553f3118a51fc..6e1f2b357f31711 100644
--- a/mlir/python/mlir/ir.py
+++ b/mlir/python/mlir/ir.py
@@ -5,6 +5,20 @@
 from ._mlir_libs._mlir.ir import *
 from ._mlir_libs._mlir.ir import _GlobalDebug
 from ._mlir_libs._mlir import register_type_caster
+from .dialects._ods_common import ValueCasterT, _VALUE_CASTERS
+
+
+def register_value_caster(typeid: TypeID, priority: int = None):
+    def wrapper(caster: ValueCasterT):
+        if not isinstance(typeid, TypeID):
+            raise ValueError(f"{typeid=} is not a TypeID")
+        if priority is None:
+            _VALUE_CASTERS[typeid].append(caster)
+        else:
+            _VALUE_CASTERS[typeid].insert(priority, caster)
+        return caster
+
+    return wrapper
 
 
 # Convenience decorator for registering user-friendly Attribute builders.
diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td
index 63dad1cc901fe2b..96b0c170dc5bb40 100644
--- a/mlir/test/mlir-tblgen/op-python-bindings.td
+++ b/mlir/test/mlir-tblgen/op-python-bindings.td
@@ -61,7 +61,7 @@ def AttrSizedOperandsOp : TestOp<"attr_sized_operands",
 }
 
 // CHECK: def attr_sized_operands(variadic1, non_variadic, *, variadic2=None, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(AttrSizedOperandsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class AttrSizedResultsOp(_ods_ir.OpView):
@@ -108,7 +108,7 @@ def AttrSizedResultsOp : TestOp<"attr_sized_results",
 }
 
 // CHECK: def attr_sized_results(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(AttrSizedResultsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(AttrSizedResultsOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)))
 
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
@@ -158,7 +158,7 @@ def AttributedOp : TestOp<"attributed_op"> {
 }
 
 // CHECK: def attributed_op(i32attr, in_, *, optional_f32_attr=None, unit_attr=None, loc=None, ip=None)
-// CHECK:     return _get_op_result_or_op_results(AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip))
+// CHECK:     return _maybe_cast(_get_op_result_or_op_results(AttributedOp(i32attr=i32attr, in_=in_, optionalF32Attr=optional_f32_attr, unitAttr=unit_attr, loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class AttributedOpWithOperands(_ods_ir.OpView):
@@ -194,7 +194,7 @@ def AttributedOpWithOperands : TestOp<"attributed_op_with_operands"> {
 }
 
 // CHECK: def attributed_op_with_operands(_gen_arg_0, _gen_arg_2, *, in_=None, is_=None, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(AttributedOpWithOperands(_gen_arg_0=_gen_arg_0, _gen_arg_2=_gen_arg_2, in_=in_, is_=is_, loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class DefaultValuedAttrsOp(_ods_ir.OpView):
@@ -218,7 +218,7 @@ def DefaultValuedAttrsOp : TestOp<"default_valued_attrs"> {
 }
 
 // CHECK: def default_valued_attrs(*, arr=None, unsupported=None, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(DefaultValuedAttrsOp(arr=arr, unsupported=unsupported, loc=loc, ip=ip)))
 
 // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_op"
 def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResultType]> {
@@ -236,7 +236,7 @@ def DeriveResultTypesOp : TestOp<"derive_result_types_op", [FirstAttrDerivedResu
 }
 
 // CHECK: def derive_result_types_op(type_, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(DeriveResultTypesOp(type_=type_, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(DeriveResultTypesOp(type_=type_, loc=loc, ip=ip)))
 
 // CHECK-LABEL: OPERATION_NAME = "test.derive_result_types_variadic_op"
 def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [FirstAttrDerivedResultType]> {
@@ -246,7 +246,7 @@ def DeriveResultTypesVariadicOp : TestOp<"derive_result_types_variadic_op", [Fir
 }
 
 // CHECK: def derive_result_types_variadic_op(res, _gen_res_1, type_, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(DeriveResultTypesVariadicOp(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(DeriveResultTypesVariadicOp(res=res, _gen_res_1=_gen_res_1, type_=type_, loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class EmptyOp(_ods_ir.OpView):
@@ -263,7 +263,7 @@ def EmptyOp : TestOp<"empty">;
   // CHECK:     successors=_ods_successors, regions=regions, loc=loc, ip=ip))
 
 // CHECK: def empty(*, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(EmptyOp(loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(EmptyOp(loc=loc, ip=ip)))
 
 // CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_implied_op"
 def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
@@ -276,7 +276,7 @@ def InferResultTypesImpliedOp : TestOp<"infer_result_types_implied_op"> {
 }
 
 // CHECK: def infer_result_types_implied_op(*, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(InferResultTypesImpliedOp(loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(InferResultTypesImpliedOp(loc=loc, ip=ip)))
 
 // CHECK-LABEL: OPERATION_NAME = "test.infer_result_types_op"
 def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]> {
@@ -289,7 +289,7 @@ def InferResultTypesOp : TestOp<"infer_result_types_op", [InferTypeOpInterface]>
 }
 
 // CHECK: def infer_result_types_op(*, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(InferResultTypesOp(loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(InferResultTypesOp(loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class MissingNamesOp(_ods_ir.OpView):
@@ -327,7 +327,7 @@ def MissingNamesOp : TestOp<"missing_names"> {
 }
 
 // CHECK: def missing_names(i32, _gen_res_1, i64, _gen_arg_0, f32, _gen_arg_2, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(MissingNamesOp(i32=i32, _gen_res_1=_gen_res_1, i64=i64, _gen_arg_0=_gen_arg_0, f32=f32, _gen_arg_2=_gen_arg_2, loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class OneOptionalOperandOp(_ods_ir.OpView):
@@ -358,7 +358,7 @@ def OneOptionalOperandOp : TestOp<"one_optional_operand"> {
 }
 
 // CHECK: def one_optional_operand(non_optional, *, optional=None, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(OneOptionalOperandOp(non_optional=non_optional, optional=optional, loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class OneVariadicOperandOp(_ods_ir.OpView):
@@ -390,7 +390,7 @@ def OneVariadicOperandOp : TestOp<"one_variadic_operand"> {
 }
 
 // CHECK: def one_variadic_operand(non_variadic, variadic, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(OneVariadicOperandOp(non_variadic=non_variadic, variadic=variadic, loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class OneVariadicResultOp(_ods_ir.OpView):
@@ -423,7 +423,7 @@ def OneVariadicResultOp : TestOp<"one_variadic_result"> {
 }
 
 // CHECK: def one_variadic_result(variadic, non_variadic, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(OneVariadicResultOp(variadic=variadic, non_variadic=non_variadic, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(OneVariadicResultOp(variadic=variadic, non_variadic=non_variadic, loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class PythonKeywordOp(_ods_ir.OpView):
@@ -447,7 +447,7 @@ def PythonKeywordOp : TestOp<"python_keyword"> {
 }
 
 // CHECK: def python_keyword(in_, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(PythonKeywordOp(in_=in_, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(PythonKeywordOp(in_=in_, loc=loc, ip=ip)))
 
 // CHECK-LABEL: OPERATION_NAME = "test.same_results"
 def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
@@ -461,7 +461,7 @@ def SameResultsOp : TestOp<"same_results", [SameOperandsAndResultType]> {
 }
 
 // CHECK: def same_results(in1, in2, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(SameResultsOp(in1=in1, in2=in2, loc=loc, ip=ip)))
 
 // CHECK-LABEL: OPERATION_NAME = "test.same_results_variadic"
 def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResultType]> {
@@ -471,7 +471,7 @@ def SameResultsVariadicOp : TestOp<"same_results_variadic", [SameOperandsAndResu
 }
 
 // CHECK: def same_results_variadic(res, in1, in2, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(SameResultsVariadicOp(res=res, in1=in1, in2=in2, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(SameResultsVariadicOp(res=res, in1=in1, in2=in2, loc=loc, ip=ip)))
 
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
@@ -498,7 +498,7 @@ def SameVariadicOperandSizeOp : TestOp<"same_variadic_operand",
 }
 
 // CHECK: def same_variadic_operand(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(SameVariadicOperandSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class SameVariadicResultSizeOp(_ods_ir.OpView):
@@ -524,7 +524,7 @@ def SameVariadicResultSizeOp : TestOp<"same_variadic_result",
 }
 
 // CHECK: def same_variadic_result(variadic1, non_variadic, variadic2, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(SameVariadicResultSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(SameVariadicResultSizeOp(variadic1=variadic1, non_variadic=non_variadic, variadic2=variadic2, loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class SimpleOp(_ods_ir.OpView):
@@ -564,7 +564,7 @@ def SimpleOp : TestOp<"simple"> {
 }
 
 // CHECK: def simple(i64, f64, i32, f32, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(SimpleOp(i64=i64, f64=f64, i32=i32, f32=f32, loc=loc, ip=ip)))
 
 // CHECK: class VariadicAndNormalRegionOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.variadic_and_normal_region"
@@ -591,7 +591,7 @@ def VariadicAndNormalRegionOp : TestOp<"variadic_and_normal_region"> {
 }
 
 // CHECK: def variadic_and_normal_region(num_variadic, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(VariadicAndNormalRegionOp(num_variadic=num_variadic, loc=loc, ip=ip)))
 
 // CHECK: class VariadicRegionOp(_ods_ir.OpView):
 // CHECK-LABEL: OPERATION_NAME = "test.variadic_region"
@@ -614,7 +614,7 @@ def VariadicRegionOp : TestOp<"variadic_region"> {
 }
 
 // CHECK: def variadic_region(num_variadic, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(VariadicRegionOp(num_variadic=num_variadic, loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class WithSpecialCharactersOp(_ods_ir.OpView):
@@ -623,7 +623,7 @@ def WithSpecialCharactersOp : TestOp<"123with--special.characters"> {
 }
 
 // CHECK: def _123with__special_characters(*, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(WithSpecialCharactersOp(loc=loc, ip=ip))
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(WithSpecialCharactersOp(loc=loc, ip=ip)))
 
 // CHECK: @_ods_cext.register_operation(_Dialect)
 // CHECK: class WithSuccessorsOp(_ods_ir.OpView):
@@ -638,4 +638,4 @@ def WithSuccessorsOp : TestOp<"with_successors"> {
 }
 
 // CHECK: def with_successors(successor, successors, *, loc=None, ip=None)
-// CHECK:   return _get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip))
\ No newline at end of file
+// CHECK:   return _maybe_cast(_get_op_result_or_op_results(WithSuccessorsOp(successor=successor, successors=successors, loc=loc, ip=ip)))
\ No newline at end of file
diff --git a/mlir/test/python/dialects/arith_dialect.py b/mlir/test/python/dialects/arith_dialect.py
index 6d1c5eab7589847..180d30ff4cfb3e5 100644
--- a/mlir/test/python/dialects/arith_dialect.py
+++ b/mlir/test/python/dialects/arith_dialect.py
@@ -1,8 +1,9 @@
 # RUN: %PYTHON %s | FileCheck %s
+from functools import partialmethod
 
 from mlir.ir import *
-import mlir.dialects.func as func
 import mlir.dialects.arith as arith
+from mlir.dialects._ods_common import maybe_cast
 
 
 def run(f):
@@ -35,14 +36,71 @@ def testFastMathFlags():
             print(r)
 
 
-# CHECK-LABEL: TEST: testArithValueBuilder
+# CHECK-LABEL: TEST: testArithValue
 @run
-def testArithValueBuilder():
+def testArithValue():
+    def _binary_op(lhs, rhs, op: str):
+        op = op.capitalize()
+        if arith._is_float_type(lhs.type):
+            op += "F"
+        elif arith._is_integer_like_type(lhs.type):
+            op += "I"
+        else:
+            raise NotImplementedError(f"Unsupported '{op}' operands: {lhs}, {rhs}")
+
+        op = getattr(arith, f"{op}Op")
+        return maybe_cast(op(lhs, rhs).result)
+
+    @register_value_caster(F16Type.static_typeid)
+    @register_value_caster(F32Type.static_typeid)
+    @register_value_caster(F64Type.static_typeid)
+    @register_value_caster(IntegerType.static_typeid)
+    class ArithValue(Value):
+        __add__ = partialmethod(_binary_op, op="add")
+        __sub__ = partialmethod(_binary_op, op="sub")
+        __mul__ = partialmethod(_binary_op, op="mul")
+
+        def __str__(self):
+            return super().__str__().replace("Value", "ArithValue")
+
+    @register_value_caster(IntegerType.static_typeid, priority=0)
+    class ArithValue1(Value):
+        __mul__ = partialmethod(_binary_op, op="mul")
+
+        def __str__(self):
+            return super().__str__().replace("Value", "ArithValue1")
+
+    @register_value_caster(IntegerType.static_typeid, priority=0)
+    def no_op_caster(val):
+        print("no_op_caster", val)
+        return None
+
     with Context() as ctx, Location.unknown():
         module = Module.create()
+        f16_t = F16Type.get()
         f32_t = F32Type.get()
+        f64_t = F64Type.get()
+        i32 = IntegerType.get_signless(32)
 
         with InsertionPoint(module.body):
+            a = arith.constant(value=FloatAttr.get(f16_t, 42.42))
+            b = a + a
+            # CHECK: ArithValue(%0 = arith.addf %cst, %cst : f16)
+            print(b)
+
             a = arith.constant(value=FloatAttr.get(f32_t, 42.42))
-            # CHECK: %cst = arith.constant 4.242000e+01 : f32
-            print(a)
+            b = a - a
+            # CHECK: ArithValue(%1 = arith.subf %cst_0, %cst_0 : f32)
+            print(b)
+
+            a = arith.constant(value=FloatAttr.get(f64_t, 42.42))
+            b = a * a
+            # CHECK: ArithValue(%2 = arith.mulf %cst_1, %cst_1 : f64)
+            print(b)
+
+            # CHECK: no_op_caster Value(%c1_i32 = arith.constant 1 : i32)
+            a = arith.constant(value=IntegerAttr.get(i32, 1))
+            b = a * a
+            # CHECK: no_op_caster Value(%3 = arith.muli %c1_i32, %c1_i32 : i32)
+            # CHECK: ArithValue1(%3 = arith.muli %c1_i32, %c1_i32 : i32)
+            print(b)
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index c8ef84721090ab9..170ac6b87c693d7 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -30,7 +30,16 @@ constexpr const char *fileHeader = R"Py(
 # Autogenerated by mlir-tblgen; don't manually edit.
 
 from ._ods_common import _cext as _ods_cext
-from ._ods_common import segmented_accessor as _ods_segmented_accessor, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_value as _get_op_result_or_value, get_op_resul...
[truncated]

@mikeurbach
Contributor

Nice! As someone who has worked on a Python eDSL in CIRCT, I think this would be useful to have upstream. I will take a look at the implementation.

@jpienaar jpienaar (Member) left a comment

Interesting. Am I correct in understanding that the "builder" being used is not dictated by what produced the value? So it doesn't matter whether it's an arith.constant vs. a tensor.constant or hlo.constant; the + behaves based on the type? And for upstream, they get to override by using a priority (and if there is someone upstream from them, then another priority).


makslevental commented Oct 23, 2023

Am I correct in understanding that the "builder" being used is not dictated by what produced the value? So it doesn't matter whether it's an arith.constant vs. a tensor.constant or hlo.constant; the + behaves based on the type?

Yes, that's correct. My reasoning is: x = arith.constant + arith.constant is no longer readily identified as an arith.constant, but it's still a value that you would want to perform arithmetic on. Of course, this is just a long way of saying that behaviors/interfaces are associated with types rather than values (as in general).

The reason it strikes the eye as weird is that mlir::Value only models the SSA-ness of the result of the operation that created it (if it was created by an operation) and not any of its behaviors. Now, if we started "refining" mlir::TypedValue like we do mlir::TypedAttr, we'd have a different story here, but I figured that's a large change for just this feature.

And for upstream, they get to override by using a priority (and if there is someone upstream from them, then another priority).

Yup, exactly. The Triton case I linked is the archetypal example: there might be an upstream "value caster" for all tensor types whose elements are upstream types (e.g., to support tensor arithmetic in general), but then you want to do something different for a tensor with tt.ptr element type while still getting the benefits of the upstream caster. There are other ways to make this work, of course (the user registers a single caster with their own branching on typeid), but this seems simplest.

@ftynse ftynse (Member) left a comment

Looks interesting.

The change request: all of this should also work on BlockArgument, which is another class of values we can see in addition to OpResult. Also, please add tests covering _maybe_cast for non-trivial cases where there is more than one op result or block argument, and for an op with no results, if that makes sense.

And a question: have you considered finer-grained type identification with, e.g., user-provided isa functions? It could be a second-stage check after typeid, so we don't go over a list of callbacks every time we return a value, or maybe embedded into the child class of Value somehow... The use case I have in mind is defining __add__ for values of vector type, where I'd want arith.addi for vectors of integers and arith.addf for vectors of floats, but both have the same typeid on the outside.
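For reference, that use case can be approximated with the mechanism as proposed by branching inside the proxy rather than at caster-selection time (a rough sketch; VectorValue is illustrative, and a float element type is assumed whenever the element type is not an integer):

    from mlir.dialects import arith
    from mlir.ir import IntegerType, Value, VectorType, register_value_caster


    @register_value_caster(VectorType.static_typeid)
    class VectorValue(Value):
        """Illustrative proxy that picks addi vs. addf from the element type."""

        def __add__(self, other):
            elem = VectorType(self.type).element_type
            if IntegerType.isinstance(elem):
                return arith.AddIOp(self, other).result
            # Otherwise assume a float element type.
            return arith.AddFOp(self, other).result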


github-actions bot commented Oct 25, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.


github-actions bot commented Oct 25, 2023

✅ With the latest revision this PR passed the Python code formatter.

@makslevental makslevental force-pushed the value_builders_1 branch 8 times, most recently from 0e8da41 to 85820c7, on October 30, 2023 19:16
@makslevental makslevental force-pushed the value_builders_1 branch 3 times, most recently from f72beb9 to 70c6ade, on October 31, 2023 23:26
@makslevental makslevental requested a review from ftynse October 31, 2023 23:33
@jpienaar jpienaar (Member) left a comment

I like the end result here!
