Skip to content

[mlir][python] value casting #69644

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Nov 7, 2023
23 changes: 19 additions & 4 deletions mlir/include/mlir-c/Bindings/Python/Interop.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,28 @@

/** Attribute on main C extension module (_mlir) that corresponds to the
* type caster registration binding. The signature of the function is:
* def register_type_caster(MlirTypeID mlirTypeID, py::function typeCaster,
* bool replace)
* where replace indicates the typeCaster should replace any existing registered
* type casters (such as those for upstream ConcreteTypes).
* def register_type_caster(MlirTypeID mlirTypeID, *, bool replace)
* which then takes a typeCaster (register_type_caster is meant to be used as a
* decorator from python), and where replace indicates the typeCaster should
* replace any existing registered type casters (such as those for upstream
* ConcreteTypes). The interface of the typeCaster is: def type_caster(ir.Type)
* -> SubClassTypeT where SubClassTypeT indicates the result should be a
* subclass (inherit from) ir.Type.
*/
#define MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR "register_type_caster"

/** Attribute on main C extension module (_mlir) that corresponds to the
* value caster registration binding. The signature of the function is:
* def register_value_caster(MlirTypeID mlirTypeID, *, bool replace)
* which then takes a valueCaster (register_value_caster is meant to be used as
* a decorator, from python), and where replace indicates the valueCaster should
* replace any existing registered value casters. The interface of the
* valueCaster is: def value_caster(ir.Value) -> SubClassValueT where
* SubClassValueT indicates the result should be a subclass (inherit from)
* ir.Value.
*/
#define MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR "register_value_caster"

/// Gets a void* from a wrapped struct. Needed because const cast is different
/// between C/C++.
#ifdef __cplusplus
Expand Down
10 changes: 5 additions & 5 deletions mlir/include/mlir/Bindings/Python/PybindAdaptors.h
Original file line number Diff line number Diff line change
Expand Up @@ -234,6 +234,7 @@ struct type_caster<MlirValue> {
return py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
.attr("Value")
.attr(MLIR_PYTHON_CAPI_FACTORY_ATTR)(capsule)
.attr(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR)()
.release();
};
};
Expand Down Expand Up @@ -496,11 +497,10 @@ class mlir_type_subclass : public pure_subclass {
if (getTypeIDFunction) {
py::module::import(MAKE_MLIR_PYTHON_QUALNAME("ir"))
.attr(MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR)(
getTypeIDFunction(),
pybind11::cpp_function(
[thisClass = thisClass](const py::object &mlirType) {
return thisClass(mlirType);
}));
getTypeIDFunction())(pybind11::cpp_function(
[thisClass = thisClass](const py::object &mlirType) {
return thisClass(mlirType);
}));
}
}
};
Expand Down
14 changes: 13 additions & 1 deletion mlir/lib/Bindings/Python/Globals.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,13 @@ class PyGlobals {
void registerTypeCaster(MlirTypeID mlirTypeID, pybind11::function typeCaster,
bool replace = false);

/// Adds a user-friendly value caster. Raises an exception if the mapping
/// already exists and replace == false. This is intended to be called by
/// implementation code.
void registerValueCaster(MlirTypeID mlirTypeID,
pybind11::function valueCaster,
bool replace = false);

/// Adds a concrete implementation dialect class.
/// Raises an exception if the mapping already exists.
/// This is intended to be called by implementation code.
Expand All @@ -86,6 +93,10 @@ class PyGlobals {
std::optional<pybind11::function> lookupTypeCaster(MlirTypeID mlirTypeID,
MlirDialect dialect);

/// Returns the custom value caster for MlirTypeID mlirTypeID.
std::optional<pybind11::function> lookupValueCaster(MlirTypeID mlirTypeID,
MlirDialect dialect);

/// Looks up a registered dialect class by namespace. Note that this may
/// trigger loading of the defining module and can arbitrarily re-enter.
std::optional<pybind11::object>
Expand All @@ -109,7 +120,8 @@ class PyGlobals {
llvm::StringMap<pybind11::object> attributeBuilderMap;
/// Map of MlirTypeID to custom type caster.
llvm::DenseMap<MlirTypeID, pybind11::object> typeCasterMap;

/// Map of MlirTypeID to custom value caster.
llvm::DenseMap<MlirTypeID, pybind11::object> valueCasterMap;
/// Set of dialect namespaces that we have attempted to import implementation
/// modules for.
llvm::StringSet<> loadedDialectModules;
Expand Down
31 changes: 27 additions & 4 deletions mlir/lib/Bindings/Python/IRCore.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1899,13 +1899,28 @@ bool PyTypeID::operator==(const PyTypeID &other) const {
}

//------------------------------------------------------------------------------
// PyValue and subclases.
// PyValue and subclasses.
//------------------------------------------------------------------------------

pybind11::object PyValue::getCapsule() {
return py::reinterpret_steal<py::object>(mlirPythonValueToCapsule(get()));
}

pybind11::object PyValue::maybeDownCast() {
MlirType type = mlirValueGetType(get());
MlirTypeID mlirTypeID = mlirTypeGetTypeID(type);
assert(!mlirTypeIDIsNull(mlirTypeID) &&
"mlirTypeID was expected to be non-null.");
std::optional<pybind11::function> valueCaster =
PyGlobals::get().lookupValueCaster(mlirTypeID, mlirTypeGetDialect(type));
// py::return_value_policy::move means use std::move to move the return value
// contents into a new instance that will be owned by Python.
py::object thisObj = py::cast(this, py::return_value_policy::move);
if (!valueCaster)
return thisObj;
return valueCaster.value()(thisObj);
}

PyValue PyValue::createFromCapsule(pybind11::object capsule) {
MlirValue value = mlirPythonCapsuleToValue(capsule.ptr());
if (mlirValueIsNull(value))
Expand Down Expand Up @@ -2121,6 +2136,8 @@ class PyConcreteValue : public PyValue {
return DerivedTy::isaFunction(otherValue);
},
py::arg("other_value"));
cls.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
[](DerivedTy &self) { return self.maybeDownCast(); });
DerivedTy::bindDerived(cls);
}

Expand Down Expand Up @@ -2193,6 +2210,7 @@ class PyBlockArgumentList
: public Sliceable<PyBlockArgumentList, PyBlockArgument> {
public:
static constexpr const char *pyClassName = "BlockArgumentList";
using SliceableT = Sliceable<PyBlockArgumentList, PyBlockArgument>;

PyBlockArgumentList(PyOperationRef operation, MlirBlock block,
intptr_t startIndex = 0, intptr_t length = -1,
Expand Down Expand Up @@ -2241,6 +2259,7 @@ class PyBlockArgumentList
class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
public:
static constexpr const char *pyClassName = "OpOperandList";
using SliceableT = Sliceable<PyOpOperandList, PyValue>;

PyOpOperandList(PyOperationRef operation, intptr_t startIndex = 0,
intptr_t length = -1, intptr_t step = 1)
Expand Down Expand Up @@ -2296,14 +2315,15 @@ class PyOpOperandList : public Sliceable<PyOpOperandList, PyValue> {
class PyOpResultList : public Sliceable<PyOpResultList, PyOpResult> {
public:
static constexpr const char *pyClassName = "OpResultList";
using SliceableT = Sliceable<PyOpResultList, PyOpResult>;

PyOpResultList(PyOperationRef operation, intptr_t startIndex = 0,
intptr_t length = -1, intptr_t step = 1)
: Sliceable(startIndex,
length == -1 ? mlirOperationGetNumResults(operation->get())
: length,
step),
operation(operation) {}
operation(std::move(operation)) {}

static void bindDerived(ClassTy &c) {
c.def_property_readonly("types", [](PyOpResultList &self) {
Expand Down Expand Up @@ -2892,7 +2912,8 @@ void mlir::python::populateIRCore(py::module &m) {
.str());
}
return PyOpResult(operation.getRef(),
mlirOperationGetResult(operation, 0));
mlirOperationGetResult(operation, 0))
.maybeDownCast();
},
"Shortcut to get an op result if it has only one (throws an error "
"otherwise).")
Expand Down Expand Up @@ -3566,7 +3587,9 @@ void mlir::python::populateIRCore(py::module &m) {
[](PyValue &self, PyValue &with) {
mlirValueReplaceAllUsesOfWith(self.get(), with.get());
},
kValueReplaceAllUsesWithDocstring);
kValueReplaceAllUsesWithDocstring)
.def(MLIR_PYTHON_MAYBE_DOWNCAST_ATTR,
[](PyValue &self) { return self.maybeDownCast(); });
PyBlockArgument::bind(m);
PyOpResult::bind(m);
PyOpOperand::bind(m);
Expand Down
21 changes: 21 additions & 0 deletions mlir/lib/Bindings/Python/IRModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,16 @@ void PyGlobals::registerTypeCaster(MlirTypeID mlirTypeID,
found = std::move(typeCaster);
}

void PyGlobals::registerValueCaster(MlirTypeID mlirTypeID,
pybind11::function valueCaster,
bool replace) {
pybind11::object &found = valueCasterMap[mlirTypeID];
if (found && !replace)
throw std::runtime_error("Value caster is already registered: " +
py::repr(found).cast<std::string>());
found = std::move(valueCaster);
}

void PyGlobals::registerDialectImpl(const std::string &dialectNamespace,
py::object pyClass) {
py::object &found = dialectClassMap[dialectNamespace];
Expand Down Expand Up @@ -134,6 +144,17 @@ std::optional<py::function> PyGlobals::lookupTypeCaster(MlirTypeID mlirTypeID,
return std::nullopt;
}

std::optional<py::function> PyGlobals::lookupValueCaster(MlirTypeID mlirTypeID,
MlirDialect dialect) {
loadDialectModule(unwrap(mlirDialectGetNamespace(dialect)));
const auto foundIt = valueCasterMap.find(mlirTypeID);
if (foundIt != valueCasterMap.end()) {
assert(foundIt->second && "value caster is defined");
return foundIt->second;
}
return std::nullopt;
}

std::optional<py::object>
PyGlobals::lookupDialectClass(const std::string &dialectNamespace) {
// Make sure dialect module is loaded.
Expand Down
14 changes: 9 additions & 5 deletions mlir/lib/Bindings/Python/IRModule.h
Original file line number Diff line number Diff line change
Expand Up @@ -761,7 +761,7 @@ class PyRegion {

/// Wrapper around an MlirAsmState.
class PyAsmState {
public:
public:
PyAsmState(MlirValue value, bool useLocalScope) {
flags = mlirOpPrintingFlagsCreate();
// The OpPrintingFlags are not exposed Python side, create locally and
Expand All @@ -780,16 +780,14 @@ class PyAsmState {
state =
mlirAsmStateCreateForOperation(operation.getOperation().get(), flags);
}
~PyAsmState() {
mlirOpPrintingFlagsDestroy(flags);
}
~PyAsmState() { mlirOpPrintingFlagsDestroy(flags); }
// Delete copy constructors.
PyAsmState(PyAsmState &other) = delete;
PyAsmState(const PyAsmState &other) = delete;

MlirAsmState get() { return state; }

private:
private:
MlirAsmState state;
MlirOpPrintingFlags flags;
};
Expand Down Expand Up @@ -1112,6 +1110,10 @@ class PyConcreteAttribute : public BaseTy {
/// bindings so such operation always exists).
class PyValue {
public:
// The virtual here is "load bearing" in that it enables RTTI
// for PyConcreteValue CRTP classes that support maybeDownCast.
// See PyValue::maybeDownCast.
virtual ~PyValue() = default;
PyValue(PyOperationRef parentOperation, MlirValue value)
: parentOperation(std::move(parentOperation)), value(value) {}
operator MlirValue() const { return value; }
Expand All @@ -1124,6 +1126,8 @@ class PyValue {
/// Gets a capsule wrapping the void* within the MlirValue.
pybind11::object getCapsule();

pybind11::object maybeDownCast();

/// Creates a PyValue from the MlirValue wrapped by a capsule. Ownership of
/// the underlying MlirValue is still tied to the owning operation.
static PyValue createFromCapsule(pybind11::object capsule);
Expand Down
30 changes: 22 additions & 8 deletions mlir/lib/Bindings/Python/MainModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
#include "IRModule.h"
#include "Pass.h"

#include <tuple>

namespace py = pybind11;
using namespace mlir;
using namespace py::literals;
Expand Down Expand Up @@ -46,7 +44,8 @@ PYBIND11_MODULE(_mlir, m) {
"dialect_namespace"_a, "dialect_class"_a,
"Testing hook for directly registering a dialect")
.def("_register_operation_impl", &PyGlobals::registerOperationImpl,
"operation_name"_a, "operation_class"_a, "replace"_a = false,
"operation_name"_a, "operation_class"_a, py::kw_only(),
"replace"_a = false,
"Testing hook for directly registering an operation");

// Aside from making the globals accessible to python, having python manage
Expand Down Expand Up @@ -82,17 +81,32 @@ PYBIND11_MODULE(_mlir, m) {
return opClass;
});
},
"dialect_class"_a, "replace"_a = false,
"dialect_class"_a, py::kw_only(), "replace"_a = false,
"Produce a class decorator for registering an Operation class as part of "
"a dialect");
m.def(
MLIR_PYTHON_CAPI_TYPE_CASTER_REGISTER_ATTR,
[](MlirTypeID mlirTypeID, py::function typeCaster, bool replace) {
PyGlobals::get().registerTypeCaster(mlirTypeID, std::move(typeCaster),
replace);
[](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function {
return py::cpp_function([mlirTypeID,
replace](py::object typeCaster) -> py::object {
PyGlobals::get().registerTypeCaster(mlirTypeID, typeCaster, replace);
return typeCaster;
});
},
"typeid"_a, "type_caster"_a, "replace"_a = false,
"typeid"_a, py::kw_only(), "replace"_a = false,
"Register a type caster for casting MLIR types to custom user types.");
m.def(
MLIR_PYTHON_CAPI_VALUE_CASTER_REGISTER_ATTR,
[](MlirTypeID mlirTypeID, bool replace) -> py::cpp_function {
return py::cpp_function(
[mlirTypeID, replace](py::object valueCaster) -> py::object {
PyGlobals::get().registerValueCaster(mlirTypeID, valueCaster,
replace);
return valueCaster;
});
},
"typeid"_a, py::kw_only(), "replace"_a = false,
"Register a value caster for casting MLIR values to custom user values.");

// Define and populate IR submodule.
auto irModule = m.def_submodule("ir", "MLIR IR Bindings");
Expand Down
15 changes: 13 additions & 2 deletions mlir/lib/Bindings/Python/PybindUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#define MLIR_BINDINGS_PYTHON_PYBINDUTILS_H

#include "mlir-c/Support.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/Twine.h"
#include "llvm/Support/DataTypes.h"

Expand Down Expand Up @@ -228,6 +229,11 @@ class Sliceable {
return linearIndex;
}

/// Trait to check if T provides a `maybeDownCast` method.
/// Note, you need the & to detect inherited members.
template <typename T, typename... Args>
using has_maybe_downcast = decltype(&T::maybeDownCast);

/// Returns the element at the given slice index. Supports negative indices
/// by taking elements in inverse order. Returns a nullptr object if out
/// of bounds.
Expand All @@ -239,8 +245,13 @@ class Sliceable {
return {};
}

return pybind11::cast(
static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)));
if constexpr (llvm::is_detected<has_maybe_downcast, ElementTy>::value)
return static_cast<Derived *>(this)
->getRawElement(linearizeIndex(index))
.maybeDownCast();
else
return pybind11::cast(
static_cast<Derived *>(this)->getRawElement(linearizeIndex(index)));
}

/// Returns a new instance of the pseudo-container restricted to the given
Expand Down
13 changes: 12 additions & 1 deletion mlir/python/mlir/dialects/_ods_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
# Provide a convenient name for sub-packages to resolve the main C-extension
# with a relative import.
from .._mlir_libs import _mlir as _cext
from typing import Sequence as _Sequence, Union as _Union
from typing import (
Sequence as _Sequence,
Type as _Type,
TypeVar as _TypeVar,
Union as _Union,
)

__all__ = [
"equally_sized_accessor",
Expand Down Expand Up @@ -123,3 +128,9 @@ def get_op_result_or_op_results(
if len(op.results) > 0
else op
)


# This is the standard way to indicate subclass/inheritance relationship
# see the typing.Type doc string.
_U = _TypeVar("_U", bound=_cext.ir.Value)
SubClassValueT = _Type[_U]
2 changes: 1 addition & 1 deletion mlir/python/mlir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from ._mlir_libs._mlir.ir import *
from ._mlir_libs._mlir.ir import _GlobalDebug
from ._mlir_libs._mlir import register_type_caster
from ._mlir_libs._mlir import register_type_caster, register_value_caster


# Convenience decorator for registering user-friendly Attribute builders.
Expand Down
Loading