From 4aa216cbca300fd26de289f4d14636713138279b Mon Sep 17 00:00:00 2001
From: Lily Orth-Smith <lilyorthsmith@gmail.com>
Date: Mon, 16 Aug 2021 22:44:59 -0700
Subject: [PATCH] Add DictAttrs to IRModule and refactor DictAttrs utility
 functions (#8750)

* Add DictAttrs to IRModuleNode

Move GetAttrs to be a member of DictAttrs

Generalize WithAttrs to work with IRModule and move to attrs.h

Change func->GetAttr to func->attrs.GetAttr

* lint

* Fix documentation

* fix typo

* Another typo!

* Revert GetAttrs to ->attrs.GetAttrs change

* Didn't mean to revert these

* Revert a few more things

* Add GetAttrs to IRModuleNode
---
 include/tvm/ir/attrs.h    | 108 ++++++++++++++++++++++++++++++++++++++
 include/tvm/ir/function.h |  57 ++------------------
 include/tvm/ir/module.h   |  54 +++++++++++++++++++
 3 files changed, 165 insertions(+), 54 deletions(-)

diff --git a/include/tvm/ir/attrs.h b/include/tvm/ir/attrs.h
index da7bc12619bd..fa1861051e2f 100644
--- a/include/tvm/ir/attrs.h
+++ b/include/tvm/ir/attrs.h
@@ -214,6 +214,7 @@ class DictAttrsNode : public BaseAttrsNode {
   void VisitNonDefaultAttrs(AttrVisitor* v) final;
   void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final;
   Array<AttrFieldInfo> ListFieldInfo() const final;
+
   // type info
   static constexpr const char* _type_key = "DictAttrs";
   TVM_DECLARE_FINAL_OBJECT_INFO(DictAttrsNode, BaseAttrsNode);
@@ -232,6 +233,72 @@ class DictAttrs : public Attrs {
    */
   TVM_DLL explicit DictAttrs(Map<String, ObjectRef> dict);
 
+  // Utils for accessing attributes
+  // This needs to be on DictAttrs, not DictAttrsNode because we return the default
+  // value if DictAttrsNode is not defined.
+  /*!
+   * \brief Get a function attribute.
+   *
+   * \param attr_key The attribute key.
+   * \param default_value The default value if the key does not exist, defaults to nullptr.
+   *
+   * \return The result
+   *
+   * \tparam TOBjectRef the expected object type.
+   * \throw Error if the key exists but the value does not match TObjectRef
+   *
+   * \code
+   *
+   *  void GetAttrExample(const BaseFunc& f) {
+   *    auto value = f->attrs.GetAttr<Integer>("AttrKey", 0);
+   *  }
+   *
+   * \endcode
+   */
+  template <typename TObjectRef>
+  Optional<TObjectRef> GetAttr(
+      const std::string& attr_key,
+      Optional<TObjectRef> default_value = Optional<TObjectRef>(nullptr)) const {
+    static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
+                  "Can only call GetAttr with ObjectRef types.");
+    if (!defined()) return default_value;
+    const DictAttrsNode* node = this->as<DictAttrsNode>();
+
+    auto it = node->dict.find(attr_key);
+    if (it != node->dict.end()) {
+      return Downcast<Optional<TObjectRef>>((*it).second);
+    } else {
+      return default_value;
+    }
+  }
+  // variant that uses TObjectRef to enable implicit conversion to default value.
+  template <typename TObjectRef>
+  Optional<TObjectRef> GetAttr(const std::string& attr_key, TObjectRef default_value) const {
+    return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value));
+  }
+  /*!
+   * \brief Check whether the function has an non-zero integer attr.
+   *
+   * This function can be used to check whether an optional
+   * attribute mark(e.g. inline) exists.
+   *
+   * \param attr_key The key to the attribute.
+   * \return The check result.
+   *
+   * \code
+   *
+   *  void HasNonzeroAttrExample(const BaseFunc& f) {
+   *    if (f->HasNonzeroAttr(attr::kInline)) {
+   *      // inline the function.
+   *    }
+   *  }
+   *
+   * \endcode
+   */
+  bool HasNonzeroAttr(const std::string& attr_key) const {
+    return GetAttr<Integer>(attr_key, 0) != 0;
+  }
+
   TVM_DEFINE_OBJECT_REF_METHODS(DictAttrs, Attrs, DictAttrsNode);
   TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode);
 };
@@ -249,6 +316,47 @@ inline TAttrs AttrsWithDefaultValues() {
   return TAttrs(n);
 }
 
+/*!
+ * \brief Copy the function or module, but overrides
+ *        the attribute value key with the value.
+ *
+ * \param input The thing to annotate (BaseFunc or IRModule)
+ * \param attr_key The attribute key.
+ * \param attr_value The value attribute value.
+ *
+ * \tparam TFunc The corresponding function or module type.
+ *
+ * \returns The new function or module with updated attributes.
+ *
+ * \note This function performs copy on write optimization for func and module.
+ *       If we move a uniquely referenced func or module into WithAttr,
+ *       then no additional copy will be performed.
+ *
+ *       This is also why we make it as a function instead of a member function
+ *       and why we pass by value in the first argument.
+ *
+ * \code
+ *
+ *  // Recommended way to trigger copy on write
+ *  func = WithAttr(std::move(func), "key1", value1);
+ *  func = WithAttr(std::move(func), "key2", value2);
+ *
+ * \endcode
+ */
+template <typename TFunc>
+inline TFunc WithAttr(TFunc input, const std::string& attr_key, ObjectRef attr_value) {
+  using TNode = typename TFunc::ContainerType;
+  static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
+  TNode* node = input.CopyOnWrite();
+  if (node->attrs.defined()) {
+    node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value);
+  } else {
+    Map<String, ObjectRef> dict = {{attr_key, attr_value}};
+    node->attrs = DictAttrs(dict);
+  }
+  return input;
+}
+
 // Namespace containing detail implementations
 namespace detail {
 using runtime::TVMArgValue;
diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h
index 09c074cb71bd..13b984d9cb35 100644
--- a/include/tvm/ir/function.h
+++ b/include/tvm/ir/function.h
@@ -102,21 +102,14 @@ class BaseFuncNode : public RelayExprNode {
   Optional<TObjectRef> GetAttr(
       const std::string& attr_key,
       Optional<TObjectRef> default_value = Optional<TObjectRef>(nullptr)) const {
-    static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
-                  "Can only call GetAttr with ObjectRef types.");
-    if (!attrs.defined()) return default_value;
-    auto it = attrs->dict.find(attr_key);
-    if (it != attrs->dict.end()) {
-      return Downcast<Optional<TObjectRef>>((*it).second);
-    } else {
-      return default_value;
-    }
+    return attrs.GetAttr(attr_key, default_value);
   }
   // variant that uses TObjectRef to enable implicit conversion to default value.
   template <typename TObjectRef>
   Optional<TObjectRef> GetAttr(const std::string& attr_key, TObjectRef default_value) const {
     return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value));
   }
+
   /*!
    * \brief Check whether the function has an non-zero integer attr.
    *
@@ -136,9 +129,7 @@ class BaseFuncNode : public RelayExprNode {
    *
    * \endcode
    */
-  bool HasNonzeroAttr(const std::string& attr_key) const {
-    return GetAttr<Integer>(attr_key, 0) != 0;
-  }
+  bool HasNonzeroAttr(const std::string& attr_key) const { return attrs.HasNonzeroAttr(attr_key); }
 
   static constexpr const char* _type_key = "BaseFunc";
   static constexpr const uint32_t _type_child_slots = 2;
@@ -154,48 +145,6 @@ class BaseFunc : public RelayExpr {
   TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode);
 };
 
-/*!
- * \brief Create a new function that copies func, but overrides
- *        the attribute value key with the value.
- *
- * \param func The input function.
- * \param attr_key The attribute key.
- * \param attr_value The value attribute value.
- *
- * \tparam TFunc The corresponding function type.
- *
- * \returns The new function with updated attributes.
- *
- * \note This function performs copy on write optimization for func.
- *       If we move a uniquely referenced func into WithAttr,
- *       then no additional copy will be performed.
- *
- *       This is also why we make it as a function instead of a member function
- *       and why we pass by value in the first argument.
- *
- * \code
- *
- *  // Recommended way to trigger copy on write
- *  func = WithAttr(std::move(func), "key1", value1);
- *  func = WithAttr(std::move(func), "key2", value2);
- *
- * \endcode
- */
-template <typename TFunc,
-          typename = typename std::enable_if<std::is_base_of<BaseFunc, TFunc>::value>::type>
-inline TFunc WithAttr(TFunc func, const std::string& attr_key, ObjectRef attr_value) {
-  using TNode = typename TFunc::ContainerType;
-  static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
-  TNode* node = func.CopyOnWrite();
-  if (node->attrs.defined()) {
-    node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value);
-  } else {
-    Map<String, ObjectRef> dict = {{attr_key, attr_value}};
-    node->attrs = DictAttrs(dict);
-  }
-  return func;
-}
-
 /*!
  * \brief Generic attribute names that can be attached to any function.
  *
diff --git a/include/tvm/ir/module.h b/include/tvm/ir/module.h
index 638f132e3179..9ca27ec3b661 100644
--- a/include/tvm/ir/module.h
+++ b/include/tvm/ir/module.h
@@ -58,6 +58,60 @@ class IRModuleNode : public Object {
   Map<GlobalTypeVar, TypeData> type_definitions;
   /*! \brief The source map for the module. */
   parser::SourceMap source_map;
+  /* \brief Additional attributes storing meta-data about the module. */
+  DictAttrs attrs;
+
+  /*!
+   * \brief Get a module attribute.
+   *
+   * \param attr_key The attribute key.
+   * \param default_value The default value if the key does not exist, defaults to nullptr.
+   *
+   * \return The result
+   *
+   * \tparam TOBjectRef the expected object type.
+   * \throw Error if the key exists but the value does not match TObjectRef
+   *
+   * \code
+   *
+   *  void GetAttrExample(const IRModule& mod) {
+   *    auto value = f->GetAttr<Integer>("AttrKey", 0);
+   *  }
+   *
+   * \endcode
+   */
+  template <typename TObjectRef>
+  Optional<TObjectRef> GetAttr(
+      const std::string& attr_key,
+      Optional<TObjectRef> default_value = Optional<TObjectRef>(nullptr)) const {
+    return attrs.GetAttr(attr_key, default_value);
+  }
+  // variant that uses TObjectRef to enable implicit conversion to default value.
+  template <typename TObjectRef>
+  Optional<TObjectRef> GetAttr(const std::string& attr_key, TObjectRef default_value) const {
+    return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value));
+  }
+
+  /*!
+   * \brief Check whether the module has an non-zero integer attr.
+   *
+   * This function can be used to check whether an optional
+   * attribute mark(e.g. inline) exists.
+   *
+   * \param attr_key The key to the attribute.
+   * \return The check result.
+   *
+   * \code
+   *
+   *  void HasNonzeroAttrExample(const IRModule& mod) {
+   *    if (mod->HasNonzeroAttr(attr::kInline)) {
+   *      // inline the function.
+   *    }
+   *  }
+   *
+   * \endcode
+   */
+  bool HasNonzeroAttr(const std::string& attr_key) const { return attrs.HasNonzeroAttr(attr_key); }
 
   IRModuleNode() : source_map() {}