-
Notifications
You must be signed in to change notification settings - Fork 11.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR][Python] Use ir.Value directly instead of _SubClassValueT #82341
Conversation
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir Author: Sergei Lebedev (superbobry) Changes_SubClassValueT is only useful when it is has >1 usage in a signature. This was not true for the signatures produced by tblgen. For example
here a type checker does not have enough information to infer a type argument for _SubClassValueT, and thus effectively treats it as Any. Full diff: https://github.com/llvm/llvm-project/pull/82341.diff 4 Files Affected:
diff --git a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi
index 3ed1872f1cd5a2..93b978c75540f4 100644
--- a/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi
+++ b/mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi
@@ -10,4 +10,4 @@ class _Globals:
def _check_dialect_module_loaded(self, dialect_namespace: str) -> bool: ...
def register_dialect(dialect_class: type) -> object: ...
-def register_operation(dialect_class: type) -> object: ...
+def register_operation(dialect_class: type, *, replace: bool = ...) -> object: ...
diff --git a/mlir/python/mlir/dialects/_ods_common.py b/mlir/python/mlir/dialects/_ods_common.py
index 3af3b5ce73bc60..1e7e8244ed4420 100644
--- a/mlir/python/mlir/dialects/_ods_common.py
+++ b/mlir/python/mlir/dialects/_ods_common.py
@@ -8,7 +8,6 @@
Sequence as _Sequence,
Tuple as _Tuple,
Type as _Type,
- TypeVar as _TypeVar,
Union as _Union,
)
@@ -143,12 +142,6 @@ def get_op_result_or_op_results(
else op
)
-
-# This is the standard way to indicate subclass/inheritance relationship
-# see the typing.Type doc string.
-_U = _TypeVar("_U", bound=_cext.ir.Value)
-SubClassValueT = _Type[_U]
-
ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value
ResultValueT = _Union[ResultValueTypeTuple]
VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]]
diff --git a/mlir/python/mlir/dialects/arith.py b/mlir/python/mlir/dialects/arith.py
index 663a53660a6474..61c6917393f1f9 100644
--- a/mlir/python/mlir/dialects/arith.py
+++ b/mlir/python/mlir/dialects/arith.py
@@ -12,7 +12,6 @@
get_default_loc_context as _get_default_loc_context,
_cext as _ods_cext,
get_op_result_or_op_results as _get_op_result_or_op_results,
- SubClassValueT as _SubClassValueT,
)
from typing import Any, List, Union
@@ -81,5 +80,5 @@ def literal_value(self) -> Union[int, float]:
def constant(
result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
-) -> _SubClassValueT:
+) -> Value:
return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip))
diff --git a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
index 0770ed562309e7..6c06b86fdf751f 100644
--- a/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
+++ b/mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
@@ -31,7 +31,6 @@ constexpr const char *fileHeader = R"Py(
from ._ods_common import _cext as _ods_cext
from ._ods_common import (
- SubClassValueT as _SubClassValueT,
equally_sized_accessor as _ods_equally_sized_accessor,
get_default_loc_context as _ods_get_default_loc_context,
get_op_result_or_op_results as _get_op_result_or_op_results,
@@ -52,8 +51,6 @@ constexpr const char *dialectClassTemplate = R"Py(
@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
DIALECT_NAMESPACE = "{0}"
- pass
-
)Py";
constexpr const char *dialectExtensionTemplate = R"Py(
@@ -1007,14 +1004,13 @@ static void emitValueBuilder(const Operator &op,
});
std::string nameWithoutDialect =
op.getOperationName().substr(op.getOperationName().find('.') + 1);
- os << llvm::formatv(valueBuilderTemplate, sanitizeName(nameWithoutDialect),
- op.getCppClassName(),
- llvm::join(valueBuilderParams, ", "),
- llvm::join(opBuilderArgs, ", "),
- (op.getNumResults() > 1
- ? "_Sequence[_SubClassValueT]"
- : (op.getNumResults() > 0 ? "_SubClassValueT"
- : "_ods_ir.Operation")));
+ os << llvm::formatv(
+ valueBuilderTemplate, sanitizeName(nameWithoutDialect),
+ op.getCppClassName(), llvm::join(valueBuilderParams, ", "),
+ llvm::join(opBuilderArgs, ", "),
+ (op.getNumResults() > 1
+ ? "_Sequence[_ods_ir.Value]"
+ : (op.getNumResults() > 0 ? "_ods_ir.Value" : "_ods_ir.Operation")));
}
/// Emits bindings for a specific Op to the given output stream.
|
489d60c
to
4479294
Compare
@@ -10,4 +10,4 @@ class _Globals: | |||
def _check_dialect_module_loaded(self, dialect_namespace: str) -> bool: ... | |||
|
|||
def register_dialect(dialect_class: type) -> object: ... | |||
def register_operation(dialect_class: type) -> object: ... |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is just a drive by change fixing a type error I discovered in arith.py
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
"drive by fix" - I'm gonna use that from now on
_SubClassValueT is only useful when it is has >1 usage in a signature. This was not true for the signatures produced by tblgen. For example def call(result, callee, operands_, *, loc=None, ip=None) -> _SubClassValueT: ... here a type checker does not have enough information to infer a type argument for _SubClassValueT, and thus effectively treats it as Any.
4479294
to
f0ce602
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's funny - I think I noticed recently that something broke with type hints on results from these "value builders" but I didn't get around to figuring it out. Your point that that type only propagates type info if used twice in a signature makes sense as the root cause for the vanishing typing hints. Thanks a lot for figuring it out.
_SubClassValueT is only useful when it is has >1 usage in a signature. This was not true for the signatures produced by tblgen.
For example
here a type checker does not have enough information to infer a type argument for _SubClassValueT, and thus effectively treats it as Any.