Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR][Python] Use ir.Value directly instead of _SubClassValueT #82341

Merged
merged 1 commit into from
Feb 21, 2024

Conversation

superbobry
Copy link
Contributor

_SubClassValueT is only useful when it is has >1 usage in a signature. This was not true for the signatures produced by tblgen.

For example

def call(result, callee, operands_, *, loc=None, ip=None) -> _SubClassValueT:
    ...

here a type checker does not have enough information to infer a type argument for _SubClassValueT, and thus effectively treats it as Any.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir:python MLIR Python bindings mlir labels Feb 20, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Feb 20, 2024

@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Sergei Lebedev (superbobry)

Changes

_SubClassValueT is only useful when it is has >1 usage in a signature. This was not true for the signatures produced by tblgen.

For example

def call(result, callee, operands_, *, loc=None, ip=None) -> _SubClassValueT:
    ...

here a type checker does not have enough information to infer a type argument for _SubClassValueT, and thus effectively treats it as Any.


Full diff: https://github.com/llvm/llvm-project/pull/82341.diff

4 Files Affected:

  • (modified) mlir/python/mlir/_mlir_libs/_mlir/init.pyi (+1-1)
  • (modified) mlir/python/mlir/dialects/_ods_common.py (-7)
  • (modified) mlir/python/mlir/dialects/arith.py (+1-2)
  • (modified) mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp (+7-11)
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi
index 3ed1872f1cd5a2..93b978c75540f4 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi
@@ -10,4 +10,4 @@ class _Globals:
     def _check_dialect_module_loaded(self, dialect_namespace: str) -> bool: ...
 
 def register_dialect(dialect_class: type) -> object: ...
-def register_operation(dialect_class: type) -> object: ...
+def register_operation(dialect_class: type, *, replace: bool = ...) -> object: ...
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 3af3b5ce73bc60..1e7e8244ed4420 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -8,7 +8,6 @@
     Sequence as _Sequence,
     Tuple as _Tuple,
     Type as _Type,
-    TypeVar as _TypeVar,
     Union as _Union,
 )
 
@@ -143,12 +142,6 @@ def get_op_result_or_op_results(
         else op
     )
 
-
-# This is the standard way to indicate subclass/inheritance relationship
-# see the typing.Type doc string.
-_U = _TypeVar("_U", bound=_cext.ir.Value)
-SubClassValueT = _Type[_U]
-
 ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value
 ResultValueT = _Union[ResultValueTypeTuple]
 VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]]
diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py
index 663a53660a6474..61c6917393f1f9 100644
--- a/mlir/python/mlir/dialects/arith.py
+++ b/mlir/python/mlir/dialects/arith.py
@@ -12,7 +12,6 @@
         get_default_loc_context as _get_default_loc_context,
         _cext as _ods_cext,
         get_op_result_or_op_results as _get_op_result_or_op_results,
-        SubClassValueT as _SubClassValueT,
     )
 
     from typing import Any, List, Union
@@ -81,5 +80,5 @@ def literal_value(self) -> Union[int, float]:
 
 def constant(
     result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
-) -> _SubClassValueT:
+) -> Value:
     return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip))
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 0770ed562309e7..6c06b86fdf751f 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -31,7 +31,6 @@ constexpr const char *fileHeader = R"Py(
 
 from ._ods_common import _cext as _ods_cext
 from ._ods_common import (
-    SubClassValueT as _SubClassValueT,
     equally_sized_accessor as _ods_equally_sized_accessor,
     get_default_loc_context as _ods_get_default_loc_context,
     get_op_result_or_op_results as _get_op_result_or_op_results,
@@ -52,8 +51,6 @@ constexpr const char *dialectClassTemplate = R"Py(
 @_ods_cext.register_dialect
 class _Dialect(_ods_ir.Dialect):
   DIALECT_NAMESPACE = "{0}"
-  pass
-
 )Py";
 
 constexpr const char *dialectExtensionTemplate = R"Py(
@@ -1007,14 +1004,13 @@ static void emitValueBuilder(const Operator &op,
       });
   std::string nameWithoutDialect =
       op.getOperationName().substr(op.getOperationName().find('.') + 1);
-  os << llvm::formatv(valueBuilderTemplate, sanitizeName(nameWithoutDialect),
-                      op.getCppClassName(),
-                      llvm::join(valueBuilderParams, ", "),
-                      llvm::join(opBuilderArgs, ", "),
-                      (op.getNumResults() > 1
-                           ? "_Sequence[_SubClassValueT]"
-                           : (op.getNumResults() > 0 ? "_SubClassValueT"
-                                                     : "_ods_ir.Operation")));
+  os << llvm::formatv(
+      valueBuilderTemplate, sanitizeName(nameWithoutDialect),
+      op.getCppClassName(), llvm::join(valueBuilderParams, ", "),
+      llvm::join(opBuilderArgs, ", "),
+      (op.getNumResults() > 1
+           ? "_Sequence[_ods_ir.Value]"
+           : (op.getNumResults() > 0 ? "_ods_ir.Value" : "_ods_ir.Operation")));
 }
 
 /// Emits bindings for a specific Op to the given output stream.

@superbobry superbobry force-pushed the piper_export_cl_608539531 branch 2 times, most recently from 489d60c to 4479294 Compare February 20, 2024 11:19
@@ -10,4 +10,4 @@ class _Globals:
def _check_dialect_module_loaded(self, dialect_namespace: str) -> bool: ...

def register_dialect(dialect_class: type) -> object: ...
def register_operation(dialect_class: type) -> object: ...
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is just a drive by change fixing a type error I discovered in arith.py.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"drive by fix" - I'm gonna use that from now on

_SubClassValueT is only useful when it is has >1 usage in a signature.
This was not true for the signatures produced by tblgen.

For example

    def call(result, callee, operands_, *, loc=None, ip=None) -> _SubClassValueT:
      ...

here a type checker does not have enough information to infer a type argument
for _SubClassValueT, and thus effectively treats it as Any.
Copy link
Contributor

@makslevental makslevental left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's funny - I think I noticed recently that something broke with type hints on results from these "value builders" but I didn't get around to figuring it out. Your point that that type only propagates type info if used twice in a signature makes sense as the root cause for the vanishing typing hints. Thanks a lot for figuring it out.

@ftynse ftynse merged commit 6ce5159 into llvm:main Feb 21, 2024
4 checks passed
@superbobry superbobry deleted the piper_export_cl_608539531 branch February 21, 2024 12:48
makslevental added a commit to Xilinx/mlir-aie that referenced this pull request Feb 21, 2024
makslevental added a commit to Xilinx/mlir-aie that referenced this pull request Feb 21, 2024
makslevental added a commit to Xilinx/mlir-aie that referenced this pull request Feb 21, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir:python MLIR Python bindings mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants