From 6ce5159945997126b8a0f40f55e876c9fd882fc5 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev <185856+superbobry@users.noreply.github.com> Date: Wed, 21 Feb 2024 11:59:23 +0000 Subject: [PATCH] [MLIR][Python] Use ir.Value directly instead of _SubClassValueT (#82341) _SubClassValueT is only useful when it is has >1 usage in a signature. This was not true for the signatures produced by tblgen. For example def call(result, callee, operands_, *, loc=None, ip=None) -> _SubClassValueT: ... here a type checker does not have enough information to infer a type argument for _SubClassValueT, and thus effectively treats it as Any. --- mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi | 2 +- mlir/python/mlir/dialects/_ods_common.py | 7 ------- mlir/python/mlir/dialects/arith.py | 3 +-- mlir/test/mlir-tblgen/op-python-bindings.td | 1 - mlir/test/python/ir/value.py | 3 +-- mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp | 18 +++++++----------- 6 files changed, 10 insertions(+), 24 deletions(-) diff --git a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi index 3ed1872f1cd5a2..93b978c75540f4 100644 --- a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi +++ b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi @@ -10,4 +10,4 @@ class _Globals: def _check_dialect_module_loaded(self, dialect_namespace: str) -> bool: ... def register_dialect(dialect_class: type) -> object: ... -def register_operation(dialect_class: type) -> object: ... +def register_operation(dialect_class: type, *, replace: bool = ...) -> object: ... diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py index 3af3b5ce73bc60..1e7e8244ed4420 100644 --- a/mlir/python/mlir/dialects/_ods_common.py +++ b/mlir/python/mlir/dialects/_ods_common.py @@ -8,7 +8,6 @@ Sequence as _Sequence, Tuple as _Tuple, Type as _Type, - TypeVar as _TypeVar, Union as _Union, ) @@ -143,12 +142,6 @@ def get_op_result_or_op_results( else op ) - -# This is the standard way to indicate subclass/inheritance relationship -# see the typing.Type doc string. -_U = _TypeVar("_U", bound=_cext.ir.Value) -SubClassValueT = _Type[_U] - ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value ResultValueT = _Union[ResultValueTypeTuple] VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]] diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py index 663a53660a6474..61c6917393f1f9 100644 --- a/mlir/python/mlir/dialects/arith.py +++ b/mlir/python/mlir/dialects/arith.py @@ -12,7 +12,6 @@ get_default_loc_context as _get_default_loc_context, _cext as _ods_cext, get_op_result_or_op_results as _get_op_result_or_op_results, - SubClassValueT as _SubClassValueT, ) from typing import Any, List, Union @@ -81,5 +80,5 @@ def literal_value(self) -> Union[int, float]: def constant( result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None -) -> _SubClassValueT: +) -> Value: return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip)) diff --git a/mlir/test/mlir-tblgen/op-python-bindings.td b/mlir/test/mlir-tblgen/op-python-bindings.td index dbed1164f1eb0b..9f202ba08608c6 100644 --- a/mlir/test/mlir-tblgen/op-python-bindings.td +++ b/mlir/test/mlir-tblgen/op-python-bindings.td @@ -7,7 +7,6 @@ include "mlir/Interfaces/InferTypeOpInterface.td" // CHECK: @_ods_cext.register_dialect // CHECK: class _Dialect(_ods_ir.Dialect): // CHECK: DIALECT_NAMESPACE = "test" - // CHECK: pass def Test_Dialect : Dialect { let name = "test"; let cppNamespace = "Test"; diff --git a/mlir/test/python/ir/value.py b/mlir/test/python/ir/value.py index 28ef0f2ef3e25c..50b0e8403a7f21 100644 --- a/mlir/test/python/ir/value.py +++ b/mlir/test/python/ir/value.py @@ -3,7 +3,6 @@ import gc from mlir.ir import * from mlir.dialects import func -from mlir.dialects._ods_common import SubClassValueT def run(f): @@ -270,7 +269,7 @@ def __str__(self): return super().__str__().replace(Value.__name__, NOPBlockArg.__name__) @register_value_caster(IntegerType.static_typeid) - def cast_int(v) -> SubClassValueT: + def cast_int(v) -> Value: print("in caster", v.__class__.__name__) if isinstance(v, OpResult): return NOPResult(v) diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp index 640360eff734a6..814008c2545114 100644 --- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp +++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp @@ -31,7 +31,6 @@ constexpr const char *fileHeader = R"Py( from ._ods_common import _cext as _ods_cext from ._ods_common import ( - SubClassValueT as _SubClassValueT, equally_sized_accessor as _ods_equally_sized_accessor, get_default_loc_context as _ods_get_default_loc_context, get_op_result_or_op_results as _get_op_result_or_op_results, @@ -52,8 +51,6 @@ constexpr const char *dialectClassTemplate = R"Py( @_ods_cext.register_dialect class _Dialect(_ods_ir.Dialect): DIALECT_NAMESPACE = "{0}" - pass - )Py"; constexpr const char *dialectExtensionTemplate = R"Py( @@ -1007,14 +1004,13 @@ static void emitValueBuilder(const Operator &op, }); std::string nameWithoutDialect = op.getOperationName().substr(op.getOperationName().find('.') + 1); - os << llvm::formatv(valueBuilderTemplate, sanitizeName(nameWithoutDialect), - op.getCppClassName(), - llvm::join(valueBuilderParams, ", "), - llvm::join(opBuilderArgs, ", "), - (op.getNumResults() > 1 - ? "_Sequence[_SubClassValueT]" - : (op.getNumResults() > 0 ? "_SubClassValueT" - : "_ods_ir.Operation"))); + os << llvm::formatv( + valueBuilderTemplate, sanitizeName(nameWithoutDialect), + op.getCppClassName(), llvm::join(valueBuilderParams, ", "), + llvm::join(opBuilderArgs, ", "), + (op.getNumResults() > 1 + ? "_Sequence[_ods_ir.Value]" + : (op.getNumResults() > 0 ? "_ods_ir.Value" : "_ods_ir.Operation"))); } /// Emits bindings for a specific Op to the given output stream.