Skip to content

Commit

Permalink
[MLIR][Python] Use ir.Value directly instead of _SubClassValueT
Browse files Browse the repository at this point in the history
_SubClassValueT is only useful when it is has >1 usage in a signature.
This was not true for the signatures produced by tblgen.

For example

    def call(result, callee, operands_, *, loc=None, ip=None) -> _SubClassValueT:
      ...

here a type checker does not have enough information to infer a type argument
for _SubClassValueT, and thus effectively treats it as Any.
  • Loading branch information
superbobry committed Feb 20, 2024
1 parent 3e77871 commit f0ce602
Show file tree
Hide file tree
Showing 6 changed files with 10 additions and 24 deletions.
2 changes: 1 addition & 1 deletion mlir/python/mlir/_mlir_libs/_mlir/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ class _Globals:
def _check_dialect_module_loaded(self, dialect_namespace: str) -> bool: ...

def register_dialect(dialect_class: type) -> object: ...
def register_operation(dialect_class: type) -> object: ...
def register_operation(dialect_class: type, *, replace: bool = ...) -> object: ...
7 changes: 0 additions & 7 deletions mlir/python/mlir/dialects/_ods_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
Sequence as _Sequence,
Tuple as _Tuple,
Type as _Type,
TypeVar as _TypeVar,
Union as _Union,
)

Expand Down Expand Up @@ -143,12 +142,6 @@ def get_op_result_or_op_results(
else op
)


# This is the standard way to indicate subclass/inheritance relationship
# see the typing.Type doc string.
_U = _TypeVar("_U", bound=_cext.ir.Value)
SubClassValueT = _Type[_U]

ResultValueTypeTuple = _cext.ir.Operation, _cext.ir.OpView, _cext.ir.Value
ResultValueT = _Union[ResultValueTypeTuple]
VariadicResultValueT = _Union[ResultValueT, _Sequence[ResultValueT]]
Expand Down
3 changes: 1 addition & 2 deletions mlir/python/mlir/dialects/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
get_default_loc_context as _get_default_loc_context,
_cext as _ods_cext,
get_op_result_or_op_results as _get_op_result_or_op_results,
SubClassValueT as _SubClassValueT,
)

from typing import Any, List, Union
Expand Down Expand Up @@ -81,5 +80,5 @@ def literal_value(self) -> Union[int, float]:

def constant(
result: Type, value: Union[int, float, Attribute], *, loc=None, ip=None
) -> _SubClassValueT:
) -> Value:
return _get_op_result_or_op_results(ConstantOp(result, value, loc=loc, ip=ip))
1 change: 0 additions & 1 deletion mlir/test/mlir-tblgen/op-python-bindings.td
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
// CHECK: @_ods_cext.register_dialect
// CHECK: class _Dialect(_ods_ir.Dialect):
// CHECK: DIALECT_NAMESPACE = "test"
// CHECK: pass
def Test_Dialect : Dialect {
let name = "test";
let cppNamespace = "Test";
Expand Down
3 changes: 1 addition & 2 deletions mlir/test/python/ir/value.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import gc
from mlir.ir import *
from mlir.dialects import func
from mlir.dialects._ods_common import SubClassValueT


def run(f):
Expand Down Expand Up @@ -270,7 +269,7 @@ def __str__(self):
return super().__str__().replace(Value.__name__, NOPBlockArg.__name__)

@register_value_caster(IntegerType.static_typeid)
def cast_int(v) -> SubClassValueT:
def cast_int(v) -> Value:
print("in caster", v.__class__.__name__)
if isinstance(v, OpResult):
return NOPResult(v)
Expand Down
18 changes: 7 additions & 11 deletions mlir/tools/mlir-tblgen/OpPythonBindingGen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ constexpr const char *fileHeader = R"Py(
from ._ods_common import _cext as _ods_cext
from ._ods_common import (
SubClassValueT as _SubClassValueT,
equally_sized_accessor as _ods_equally_sized_accessor,
get_default_loc_context as _ods_get_default_loc_context,
get_op_result_or_op_results as _get_op_result_or_op_results,
Expand All @@ -52,8 +51,6 @@ constexpr const char *dialectClassTemplate = R"Py(
@_ods_cext.register_dialect
class _Dialect(_ods_ir.Dialect):
DIALECT_NAMESPACE = "{0}"
pass
)Py";

constexpr const char *dialectExtensionTemplate = R"Py(
Expand Down Expand Up @@ -1007,14 +1004,13 @@ static void emitValueBuilder(const Operator &op,
});
std::string nameWithoutDialect =
op.getOperationName().substr(op.getOperationName().find('.') + 1);
os << llvm::formatv(valueBuilderTemplate, sanitizeName(nameWithoutDialect),
op.getCppClassName(),
llvm::join(valueBuilderParams, ", "),
llvm::join(opBuilderArgs, ", "),
(op.getNumResults() > 1
? "_Sequence[_SubClassValueT]"
: (op.getNumResults() > 0 ? "_SubClassValueT"
: "_ods_ir.Operation")));
os << llvm::formatv(
valueBuilderTemplate, sanitizeName(nameWithoutDialect),
op.getCppClassName(), llvm::join(valueBuilderParams, ", "),
llvm::join(opBuilderArgs, ", "),
(op.getNumResults() > 1
? "_Sequence[_ods_ir.Value]"
: (op.getNumResults() > 0 ? "_ods_ir.Value" : "_ods_ir.Operation")));
}

/// Emits bindings for a specific Op to the given output stream.
Expand Down

0 comments on commit f0ce602

Please sign in to comment.