Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DictAttrs to IRModule and refactor DictAttrs utility functions #8750

Merged
merged 9 commits into from
Aug 17, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 108 additions & 0 deletions include/tvm/ir/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ class DictAttrsNode : public BaseAttrsNode {
void VisitNonDefaultAttrs(AttrVisitor* v) final;
void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final;
Array<AttrFieldInfo> ListFieldInfo() const final;

// type info
static constexpr const char* _type_key = "DictAttrs";
TVM_DECLARE_FINAL_OBJECT_INFO(DictAttrsNode, BaseAttrsNode);
Expand All @@ -232,6 +233,72 @@ class DictAttrs : public Attrs {
*/
TVM_DLL explicit DictAttrs(Map<String, ObjectRef> dict);

// Utils for accessing attributes
// This needs to be on DictAttrs, not DictAttrsNode because we return the default
// value if DictAttrsNode is not defined.
/*!
* \brief Get a function attribute.
*
* \param attr_key The attribute key.
* \param default_value The default value if the key does not exist, defaults to nullptr.
*
* \return The result
*
* \tparam TOBjectRef the expected object type.
* \throw Error if the key exists but the value does not match TObjectRef
*
* \code
*
* void GetAttrExample(const BaseFunc& f) {
* auto value = f->attrs.GetAttr<Integer>("AttrKey", 0);
Copy link
Member

@tqchen tqchen Aug 16, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would be great to keep the function GetAttr and HasNonZeroAttr function in BaseFunc and IRModule, based on the considerations:

  • C0: It will remove one level of indirection(f->GetAttr vs f->attrs.GetAttr) and gives more clear documentation(since developer usually looks up doc on the Function or IRModule themselves)
  • C1: API consitency: WithAttr directly operates on the function and module, and the functions with related functionalities should ideally be made consistent with this usage.
  • C2: If there is a future refactor that changes DictAttr => Map, the API can be made consistent in a backward compaitble way

We can of course keep a common impl as well and redirect in the function and module case.

* }
*
* \endcode
*/
template <typename TObjectRef>
Optional<TObjectRef> GetAttr(
const std::string& attr_key,
Optional<TObjectRef> default_value = Optional<TObjectRef>(nullptr)) const {
static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
"Can only call GetAttr with ObjectRef types.");
if (!defined()) return default_value;
const DictAttrsNode* node = this->as<DictAttrsNode>();

auto it = node->dict.find(attr_key);
if (it != node->dict.end()) {
return Downcast<Optional<TObjectRef>>((*it).second);
} else {
return default_value;
}
}
// variant that uses TObjectRef to enable implicit conversion to default value.
template <typename TObjectRef>
Optional<TObjectRef> GetAttr(const std::string& attr_key, TObjectRef default_value) const {
return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value));
}
/*!
* \brief Check whether the function has an non-zero integer attr.
*
* This function can be used to check whether an optional
* attribute mark(e.g. inline) exists.
*
* \param attr_key The key to the attribute.
* \return The check result.
*
* \code
*
* void HasNonzeroAttrExample(const BaseFunc& f) {
* if (f->HasNonzeroAttr(attr::kInline)) {
* // inline the function.
* }
* }
*
* \endcode
*/
bool HasNonzeroAttr(const std::string& attr_key) const {
return GetAttr<Integer>(attr_key, 0) != 0;
}

TVM_DEFINE_OBJECT_REF_METHODS(DictAttrs, Attrs, DictAttrsNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode);
};
Expand All @@ -249,6 +316,47 @@ inline TAttrs AttrsWithDefaultValues() {
return TAttrs(n);
}

/*!
* \brief Copy the function or module, but overrides
* the attribute value key with the value.
*
* \param input The thing to annotate (BaseFunc or IRModule)
* \param attr_key The attribute key.
* \param attr_value The value attribute value.
*
* \tparam TFunc The corresponding function or module type.
*
* \returns The new function or module with updated attributes.
*
* \note This function performs copy on write optimization for func and module.
* If we move a uniquely referenced func or module into WithAttr,
* then no additional copy will be performed.
*
* This is also why we make it as a function instead of a member function
* and why we pass by value in the first argument.
*
* \code
*
* // Recommended way to trigger copy on write
* func = WithAttr(std::move(func), "key1", value1);
* func = WithAttr(std::move(func), "key2", value2);
*
* \endcode
*/
template <typename TFunc>
inline TFunc WithAttr(TFunc input, const std::string& attr_key, ObjectRef attr_value) {
using TNode = typename TFunc::ContainerType;
static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
TNode* node = input.CopyOnWrite();
if (node->attrs.defined()) {
node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value);
} else {
Map<String, ObjectRef> dict = {{attr_key, attr_value}};
node->attrs = DictAttrs(dict);
}
return input;
}

// Namespace containing detail implementations
namespace detail {
using runtime::TVMArgValue;
Expand Down
57 changes: 3 additions & 54 deletions include/tvm/ir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,21 +102,14 @@ class BaseFuncNode : public RelayExprNode {
Optional<TObjectRef> GetAttr(
const std::string& attr_key,
Optional<TObjectRef> default_value = Optional<TObjectRef>(nullptr)) const {
static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
"Can only call GetAttr with ObjectRef types.");
if (!attrs.defined()) return default_value;
auto it = attrs->dict.find(attr_key);
if (it != attrs->dict.end()) {
return Downcast<Optional<TObjectRef>>((*it).second);
} else {
return default_value;
}
return attrs.GetAttr(attr_key, default_value);
}
// variant that uses TObjectRef to enable implicit conversion to default value.
template <typename TObjectRef>
Optional<TObjectRef> GetAttr(const std::string& attr_key, TObjectRef default_value) const {
return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value));
}

/*!
* \brief Check whether the function has an non-zero integer attr.
*
Expand All @@ -136,9 +129,7 @@ class BaseFuncNode : public RelayExprNode {
*
* \endcode
*/
bool HasNonzeroAttr(const std::string& attr_key) const {
return GetAttr<Integer>(attr_key, 0) != 0;
}
bool HasNonzeroAttr(const std::string& attr_key) const { return attrs.HasNonzeroAttr(attr_key); }

static constexpr const char* _type_key = "BaseFunc";
static constexpr const uint32_t _type_child_slots = 2;
Expand All @@ -154,48 +145,6 @@ class BaseFunc : public RelayExpr {
TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode);
};

/*!
* \brief Create a new function that copies func, but overrides
* the attribute value key with the value.
*
* \param func The input function.
* \param attr_key The attribute key.
* \param attr_value The value attribute value.
*
* \tparam TFunc The corresponding function type.
*
* \returns The new function with updated attributes.
*
* \note This function performs copy on write optimization for func.
* If we move a uniquely referenced func into WithAttr,
* then no additional copy will be performed.
*
* This is also why we make it as a function instead of a member function
* and why we pass by value in the first argument.
*
* \code
*
* // Recommended way to trigger copy on write
* func = WithAttr(std::move(func), "key1", value1);
* func = WithAttr(std::move(func), "key2", value2);
*
* \endcode
*/
template <typename TFunc,
typename = typename std::enable_if<std::is_base_of<BaseFunc, TFunc>::value>::type>
inline TFunc WithAttr(TFunc func, const std::string& attr_key, ObjectRef attr_value) {
using TNode = typename TFunc::ContainerType;
static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
TNode* node = func.CopyOnWrite();
if (node->attrs.defined()) {
node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value);
} else {
Map<String, ObjectRef> dict = {{attr_key, attr_value}};
node->attrs = DictAttrs(dict);
}
return func;
}

/*!
* \brief Generic attribute names that can be attached to any function.
*
Expand Down
54 changes: 54 additions & 0 deletions include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,60 @@ class IRModuleNode : public Object {
Map<GlobalTypeVar, TypeData> type_definitions;
/*! \brief The source map for the module. */
parser::SourceMap source_map;
/* \brief Additional attributes storing meta-data about the module. */
DictAttrs attrs;

/*!
* \brief Get a module attribute.
*
* \param attr_key The attribute key.
* \param default_value The default value if the key does not exist, defaults to nullptr.
*
* \return The result
*
* \tparam TOBjectRef the expected object type.
* \throw Error if the key exists but the value does not match TObjectRef
*
* \code
*
* void GetAttrExample(const IRModule& mod) {
* auto value = f->GetAttr<Integer>("AttrKey", 0);
* }
*
* \endcode
*/
template <typename TObjectRef>
Optional<TObjectRef> GetAttr(
const std::string& attr_key,
Optional<TObjectRef> default_value = Optional<TObjectRef>(nullptr)) const {
return attrs.GetAttr(attr_key, default_value);
}
// variant that uses TObjectRef to enable implicit conversion to default value.
template <typename TObjectRef>
Optional<TObjectRef> GetAttr(const std::string& attr_key, TObjectRef default_value) const {
return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value));
}

/*!
* \brief Check whether the module has an non-zero integer attr.
*
* This function can be used to check whether an optional
* attribute mark(e.g. inline) exists.
*
* \param attr_key The key to the attribute.
* \return The check result.
*
* \code
*
* void HasNonzeroAttrExample(const IRModule& mod) {
* if (mod->HasNonzeroAttr(attr::kInline)) {
* // inline the function.
* }
* }
*
* \endcode
*/
bool HasNonzeroAttr(const std::string& attr_key) const { return attrs.HasNonzeroAttr(attr_key); }

IRModuleNode() : source_map() {}

Expand Down