Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FFI] Re-introduce the boxed primitive values #17257

Merged
merged 3 commits into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 58 additions & 18 deletions include/tvm/ir/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,16 @@ class DictAttrs : public Attrs {

auto it = node->dict.find(attr_key);
if (it != node->dict.end()) {
return Downcast<Optional<TObjectRef>>((*it).second);
// For backwards compatibility, return through TVMRetValue.
// This triggers any automatic conversions registered with
// PackedFuncValueConverter. Importantly, this allows use of
// `GetAttr<Integer>` and `GetAttr<Bool>` for properties that
// are stored internally as `runtime::Box<int64_t>` and
// `runtime::Box<bool>`.
TVMRetValue ret;
ret = (*it).second;
Optional<TObjectRef> obj = ret;
return obj;
} else {
return default_value;
}
Expand Down Expand Up @@ -315,6 +324,46 @@ inline TAttrs AttrsWithDefaultValues() {
return TAttrs(n);
}

/*!
* \brief Copy the DictAttrs, but overrides attributes with the
* entries from \p attrs.
*
* \param attrs The DictAttrs to update
*
* \param new_attrs Key/values attributes to add to \p attrs.
*
* \returns The new DictAttrs with updated attributes.
*/
DictAttrs WithAttrs(DictAttrs attrs, Map<String, ObjectRef> new_attrs);

/*!
* \brief Copy the DictAttrs, but overrides a single attribute.
*
* \param attrs The DictAttrs to update
*
* \param key The update to insert or update.
*
* \param value The new value of the attribute
*
* \returns The new DictAttrs with updated attributes.
*/
DictAttrs WithAttr(DictAttrs attrs, String key, ObjectRef value);

inline DictAttrs WithAttr(DictAttrs attrs, const std::string& key, ObjectRef value) {
return WithAttr(std::move(attrs), String(key), std::move(value));
}

/*!
* \brief Copy the DictAttrs, but without a specific attribute.
*
* \param attrs The DictAttrs to update
*
* \param key The key to remove
*
* \returns The new DictAttrs with updated attributes.
*/
DictAttrs WithoutAttr(DictAttrs attrs, const std::string& key);

/*!
* \brief Copy the function or module, but overrides
* the attribute value key with the value.
Expand Down Expand Up @@ -347,12 +396,8 @@ inline TFunc WithAttr(TFunc input, const std::string& attr_key, ObjectRef attr_v
using TNode = typename TFunc::ContainerType;
static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
TNode* node = input.CopyOnWrite();
if (node->attrs.defined()) {
node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value);
} else {
Map<String, ObjectRef> dict = {{attr_key, attr_value}};
node->attrs = DictAttrs(dict);
}
node->attrs = WithAttr(std::move(node->attrs), attr_key, attr_value);

return input;
}

Expand All @@ -371,13 +416,9 @@ inline TFunc WithAttrs(TFunc input, Map<String, ObjectRef> attrs) {
using TNode = typename TFunc::ContainerType;
static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
TNode* node = input.CopyOnWrite();
if (node->attrs.defined()) {
for (const auto& pair : attrs) {
node->attrs.CopyOnWrite()->dict.Set(pair.first, pair.second);
}
} else {
node->attrs = DictAttrs(std::move(attrs));
}

node->attrs = WithAttrs(std::move(node->attrs), attrs);

return input;
}

Expand Down Expand Up @@ -412,10 +453,9 @@ inline TFunc WithoutAttr(TFunc input, const std::string& attr_key) {
using TNode = typename TFunc::ContainerType;
static_assert(TNode::_type_final, "Can only operate on the leaf nodes");

if (input->attrs.defined()) {
TNode* node = input.CopyOnWrite();
node->attrs.CopyOnWrite()->dict.erase(attr_key);
}
TNode* node = input.CopyOnWrite();
node->attrs = WithoutAttr(std::move(node->attrs), attr_key);

return input;
}

Expand Down
130 changes: 99 additions & 31 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -770,53 +770,121 @@ inline const TTypeNode* RelayExprNode::type_as() const {

namespace tvm {
namespace runtime {
// common rule for RetValue and ArgValue

// Automatic conversion into IntImm, Integer, and Bool, when called
// through the FFI. Automatic conversions into PrimExpr are
// registered in "tvm/tir/expr.h", as it includes conversions to the
// TIR-only StringImm.
//
// While the FFI only requires the From() method, these
// implementations also define a TryFrom() method to avoid duplicate
// logic in the PrimExpr conversion.

template <>
struct PackedFuncValueConverter<PrimExpr> {
static PrimExpr From(const TVMPODValue_& val) {
if (val.type_code() == kTVMNullptr) {
return PrimExpr(ObjectPtr<Object>(nullptr));
}
if (val.type_code() == kDLInt) {
int64_t value = val.operator int64_t();
if (value > std::numeric_limits<int>::max() || value < std::numeric_limits<int>::min()) {
return IntImm(runtime::DataType::Int(64), value);
}
return IntImm(runtime::DataType::Int(32), val.operator int());
}
if (val.type_code() == kDLFloat) {
return FloatImm(runtime::DataType::Float(32), val.operator double());
struct PackedFuncValueConverter<tvm::IntImm> {
template <typename PODSubclass>
static Optional<tvm::IntImm> TryFrom(const PODSubclass& val) {
if (auto opt = val.TryAsInt()) {
int64_t value = opt.value();
auto dtype =
(value > std::numeric_limits<int>::max() || value < std::numeric_limits<int>::min())
? DataType::Int(64)
: DataType::Int(32);
return IntImm(dtype, value);
} else if (auto opt = val.TryAsBool()) {
return IntImm(DataType::Int(32), opt.value());
} else {
return NullOpt;
}
}

return PrimExpr::FromObject_(val.AsObjectRef<ObjectRef>());
template <typename PODSubclass>
static tvm::IntImm From(const PODSubclass& val) {
if (auto opt = TryFrom(val)) {
return opt.value();
} else {
return val.template AsObjectRef<tvm::IntImm>();
}
}
};

template <>
struct PackedFuncValueConverter<tvm::Integer> {
static tvm::Integer From(const TVMPODValue_& val) {
if (val.type_code() == kTVMNullptr) {
return Integer(ObjectPtr<Object>(nullptr));
template <typename PODSubclass>
static tvm::Integer From(const PODSubclass& val) {
if (auto opt = PackedFuncValueConverter<tvm::IntImm>::TryFrom(val)) {
return Integer(opt.value());
} else {
return val.template AsObjectRef<tvm::Integer>();
}
if (val.type_code() == kTVMArgInt) {
return Integer(val.operator int());
}
return val.AsObjectRef<tvm::Integer>();
}
};

template <>
struct PackedFuncValueConverter<tvm::Bool> {
static tvm::Bool From(const TVMPODValue_& val) {
if (val.type_code() == kTVMNullptr) {
return Bool(ObjectPtr<Object>(nullptr));
template <typename PODSubclass>
static Optional<tvm::Bool> TryFrom(const PODSubclass& val) {
if (auto opt = val.TryAsBool()) {
return tvm::Bool(opt.value());
} else if (auto opt = val.TryAsInt()) {
int value = opt.value();
ICHECK(value == 0 || value == 1)
<< "ValueError: boolean value can only be 0 or 1, but get " << value;
return tvm::Bool(static_cast<bool>(value));
} else {
return NullOpt;
}
}

template <typename PODSubclass>
static tvm::Bool From(const PODSubclass& val) {
if (auto opt = TryFrom(val)) {
return opt.value();
} else {
return val.template AsObjectRef<tvm::Bool>();
}
if (val.type_code() == kTVMArgInt) {
int v = val.operator int();
ICHECK(v == 0 || v == 1) << "ValueError: boolean value can only be 0 or 1, but get " << v;
return Bool(static_cast<bool>(v));
}
};

template <>
struct PackedFuncValueConverter<tvm::FloatImm> {
static Optional<tvm::FloatImm> TryFrom(const TVMPODValue_& val) {
if (auto opt = val.TryAsFloat()) {
return FloatImm(runtime::DataType::Float(32), opt.value());
} else {
return NullOpt;
}
}

template <typename PODSubclass>
static tvm::FloatImm From(const PODSubclass& val) {
if (auto opt = TryFrom(val)) {
return opt.value();
} else {
return val.template AsObjectRef<tvm::FloatImm>();
}
}
};

/* \brief Backwards compatibility wrapper for IntImm arguments
*
* In previous versions of TVM, IntImm was the default FFI type for
* integer arguments, instead of runtime::Int. For backwards
* compatibility where the callee has been updated to expected a
* runtime::Int, the caller has not been updated to provide a
* runtime::Int (e.g. relay script parsing), and the auto-unboxing of
* runtime::Int does not apply (e.g. making an `Array<runtime::Int>`),
* allow the IntImm to be generated.
*/
template <>
struct PackedFuncValueConverter<runtime::Int> {
template <typename PODSubclass>
static runtime::Int From(const PODSubclass& val) {
if (val.template IsObjectRef<tvm::IntImm>()) {
return runtime::Int(val.template AsObjectRef<tvm::IntImm>()->value);
} else {
return val.template AsObjectRef<runtime::Int>();
}
return val.AsObjectRef<tvm::Bool>();
}
};

Expand Down
34 changes: 32 additions & 2 deletions include/tvm/ir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,36 @@ class PassContext : public ObjectRef {
using ValueNodeType = typename ValueType::ContainerType;
// NOTE: we could further update the function later.
uint32_t tindex = ValueNodeType::_GetOrAllocRuntimeTypeIndex();
RegisterConfigOption(key, tindex);
auto type_key = runtime::Object::TypeIndex2Key(tindex);

auto* reflection = ReflectionVTable::Global();

auto legalization = [=](ObjectRef obj) -> ObjectRef {
if (obj->IsInstance<Map<String, ObjectRef>::ContainerType>()) {
return reflection->CreateObject(type_key, Downcast<Map<String, ObjectRef>>(obj));
} else {
// Backwards compatibility for config options defined prior to
// https://github.com/apache/tvm/pull/16183. This commit
// changed the default FFI conversion of python integers from
// `tvm::IntImm` to `runtime::Int`.
//
// This backwards compatibility fix can be removed when all
// options registered with TVM_REGISTER_PASS_CONFIG_OPTION are
// updated to use `runtime::Int` and `runtime::Bool`.
TVMRetValue ret;
ret = obj;
try {
ValueType legalized = ret;
return legalized;
} catch (Error& err) {
LOG(FATAL) << "AttributeError: expect config " << key << " to have type " << type_key
<< ", but received error when converting to this type.\n"
<< err.what();
}
}
};

RegisterConfigOption(key, tindex, legalization);
return tindex;
}

Expand All @@ -285,7 +314,8 @@ class PassContext : public ObjectRef {
// The exit of a pass context scope.
TVM_DLL void ExitWithScope();
// Register configuration key value type.
TVM_DLL static void RegisterConfigOption(const char* key, uint32_t value_type_index);
TVM_DLL static void RegisterConfigOption(const char* key, uint32_t value_type_index,
std::function<ObjectRef(ObjectRef)> legalization);

// Classes to get the Python `with` like syntax.
friend class Internal;
Expand Down
8 changes: 4 additions & 4 deletions include/tvm/meta_schedule/schedule_rule.h
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ class ScheduleRule : public runtime::ObjectRef {
* \param thread_extents Candidates of thread axis extent (values are required to be positive).
* \return The schedule rule created
*/
TVM_DLL static ScheduleRule CrossThreadReduction(Array<Integer> thread_extents);
TVM_DLL static ScheduleRule CrossThreadReduction(Array<runtime::Int> thread_extents);
/*!
* \brief A rule that randomly select a compute-at location for a free block
* \return The schedule rule created
Expand All @@ -260,9 +260,9 @@ class ScheduleRule : public runtime::ObjectRef {
* \param unroll_explicit Whether to explicitly unroll the loop, or just add an "unroll" pragma.
* \return The schedule rule created
*/
TVM_DLL static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core, //
int max_vectorize_extent, //
Array<Integer> unroll_max_steps, //
TVM_DLL static ScheduleRule ParallelizeVectorizeUnroll(int max_jobs_per_core, //
int max_vectorize_extent, //
Array<runtime::Int> unroll_max_steps, //
bool unroll_explicit);
/*!
* \brief Auto bind loops around the block to BlockIdx and ThreadIdx
Expand Down
2 changes: 1 addition & 1 deletion include/tvm/relay/attrs/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ struct SqueezeAttrs : public tvm::AttrsNode<SqueezeAttrs> {
}; // struct SqueezeAttrs

struct SplitAttrs : public tvm::AttrsNode<SplitAttrs> {
ObjectRef indices_or_sections;
Variant<runtime::Int, Array<runtime::Int>> indices_or_sections;
int axis;

TVM_DECLARE_ATTRS(SplitAttrs, "relay.attrs.SplitAttrs") {
Expand Down
5 changes: 4 additions & 1 deletion include/tvm/runtime/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@
#ifdef __cplusplus
extern "C" {
#endif
#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

Expand Down Expand Up @@ -186,11 +187,12 @@ typedef enum {
kTVMBytes = 12U,
kTVMNDArrayHandle = 13U,
kTVMObjectRValueRefArg = 14U,
kTVMArgBool = 15U,
// Extension codes for other frameworks to integrate TVM PackedFunc.
// To make sure each framework's id do not conflict, use first and
// last sections to mark ranges.
// Open an issue at the repo if you need a section of code.
kTVMExtBegin = 15U,
kTVMExtBegin = 16U,
kTVMNNVMFirst = 16U,
kTVMNNVMLast = 20U,
// The following section of code is used for non-reserved types.
Expand All @@ -207,6 +209,7 @@ typedef DLTensor* TVMArrayHandle;
*/
typedef union {
int64_t v_int64;
bool v_bool;
double v_float64;
void* v_handle;
const char* v_str;
Expand Down
Loading
Loading