Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add DictAttrs to IRModule and refactor DictAttrs utility functions #8750

Merged
merged 9 commits into from
Aug 17, 2021
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 108 additions & 0 deletions include/tvm/ir/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ class DictAttrsNode : public BaseAttrsNode {
void VisitNonDefaultAttrs(AttrVisitor* v) final;
void InitByPackedArgs(const runtime::TVMArgs& args, bool allow_unknown) final;
Array<AttrFieldInfo> ListFieldInfo() const final;

// type info
static constexpr const char* _type_key = "DictAttrs";
TVM_DECLARE_FINAL_OBJECT_INFO(DictAttrsNode, BaseAttrsNode);
Expand All @@ -232,6 +233,72 @@ class DictAttrs : public Attrs {
*/
TVM_DLL explicit DictAttrs(Map<String, ObjectRef> dict);

// Utils for accessing attributes
// This needs to be on DictAttrs, not DictAttrsNode because we return the default
// value if DictAttrsNode is not defined.
/*!
* \brief Get a function attribute.
*
* \param attr_key The attribute key.
* \param default_value The default value if the key does not exist, defaults to nullptr.
*
* \return The result
*
* \tparam TOBjectRef the expected object type.
* \throw Error if the key exists but the value does not match TObjectRef
*
* \code
*
* void GetAttrExample(const BaseFunc& f) {
* auto value = f->attrs.GetAttr<Integer>("AttrKey", 0);
Copy link
Member

@tqchen tqchen Aug 16, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

would be great to keep the function GetAttr and HasNonZeroAttr function in BaseFunc and IRModule, based on the considerations:

  • C0: It will remove one level of indirection(f->GetAttr vs f->attrs.GetAttr) and gives more clear documentation(since developer usually looks up doc on the Function or IRModule themselves)
  • C1: API consitency: WithAttr directly operates on the function and module, and the functions with related functionalities should ideally be made consistent with this usage.
  • C2: If there is a future refactor that changes DictAttr => Map, the API can be made consistent in a backward compaitble way

We can of course keep a common impl as well and redirect in the function and module case.

* }
*
* \endcode
*/
template <typename TObjectRef>
Optional<TObjectRef> GetAttr(
const std::string& attr_key,
Optional<TObjectRef> default_value = Optional<TObjectRef>(nullptr)) const {
static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
"Can only call GetAttr with ObjectRef types.");
if (!defined()) return default_value;
const DictAttrsNode* node = this->as<DictAttrsNode>();

auto it = node->dict.find(attr_key);
if (it != node->dict.end()) {
return Downcast<Optional<TObjectRef>>((*it).second);
} else {
return default_value;
}
}
// variant that uses TObjectRef to enable implicit conversion to default value.
template <typename TObjectRef>
Optional<TObjectRef> GetAttr(const std::string& attr_key, TObjectRef default_value) const {
return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value));
}
/*!
* \brief Check whether the function has an non-zero integer attr.
*
* This function can be used to check whether an optional
* attribute mark(e.g. inline) exists.
*
* \param attr_key The key to the attribute.
* \return The check result.
*
* \code
*
* void HasNonzeroAttrExample(const BaseFunc& f) {
* if (f->attrs.HasNonzeroAttr(attr::kInline)) {
* // inline the function.
* }
* }
*
* \endcode
*/
bool HasNonzeroAttr(const std::string& attr_key) const {
return GetAttr<Integer>(attr_key, 0) != 0;
}

TVM_DEFINE_OBJECT_REF_METHODS(DictAttrs, Attrs, DictAttrsNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode);
};
Expand All @@ -249,6 +316,47 @@ inline TAttrs AttrsWithDefaultValues() {
return TAttrs(n);
}

/*!
* \brief Copy the function or module, but overrides
* the attribute value key with the value.
*
* \param input The thing to annotate (BaseFunc or IRModule)
* \param attr_key The attribute key.
* \param attr_value The value attribute value.
*
* \tparam TFunc The corresponding function or module type.
*
* \returns The new function or module with updated attributes.
*
* \note This function performs copy on write optimization for func and module.
* If we move a uniquely referenced func or module into WithAttr,
* then no additional copy will be performed.
*
* This is also why we make it as a function instead of a member function
* and why we pass by value in the first argument.
*
* \code
*
* // Recommended way to trigger copy on write
* func = WithAttr(std::move(func), "key1", value1);
* func = WithAttr(std::move(func), "key2", value2);
*
* \endcode
*/
template <typename TFunc>
inline TFunc WithAttr(TFunc input, const std::string& attr_key, ObjectRef attr_value) {
using TNode = typename TFunc::ContainerType;
static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
TNode* node = input.CopyOnWrite();
if (node->attrs.defined()) {
node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value);
} else {
Map<String, ObjectRef> dict = {{attr_key, attr_value}};
node->attrs = DictAttrs(dict);
}
return input;
}

// Namespace containing detail implementations
namespace detail {
using runtime::TVMArgValue;
Expand Down
103 changes: 0 additions & 103 deletions include/tvm/ir/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,67 +79,6 @@ class BaseFuncNode : public RelayExprNode {
/*! \brief Additional attributes storing the meta-data */
DictAttrs attrs;

/*!
* \brief Get a function attribute.
*
* \param attr_key The attribute key.
* \param default_value The default value if the key does not exist, defaults to nullptr.
*
* \return The result
*
* \tparam TOBjectRef the expected object type.
* \throw Error if the key exists but the value does not match TObjectRef
*
* \code
*
* void GetAttrExample(const BaseFunc& f) {
* auto value = f->GetAttr<Integer>("AttrKey", 0);
* }
*
* \endcode
*/
template <typename TObjectRef>
Optional<TObjectRef> GetAttr(
const std::string& attr_key,
Optional<TObjectRef> default_value = Optional<TObjectRef>(nullptr)) const {
static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
"Can only call GetAttr with ObjectRef types.");
if (!attrs.defined()) return default_value;
auto it = attrs->dict.find(attr_key);
if (it != attrs->dict.end()) {
return Downcast<Optional<TObjectRef>>((*it).second);
} else {
return default_value;
}
}
// variant that uses TObjectRef to enable implicit conversion to default value.
template <typename TObjectRef>
Optional<TObjectRef> GetAttr(const std::string& attr_key, TObjectRef default_value) const {
return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value));
}
/*!
* \brief Check whether the function has an non-zero integer attr.
*
* This function can be used to check whether an optional
* attribute mark(e.g. inline) exists.
*
* \param attr_key The key to the attribute.
* \return The check result.
*
* \code
*
* void HasNonzeroAttrExample(const BaseFunc& f) {
* if (f->HasNonzeroAttr(attr::kInline)) {
* // inline the function.
* }
* }
*
* \endcode
*/
bool HasNonzeroAttr(const std::string& attr_key) const {
return GetAttr<Integer>(attr_key, 0) != 0;
}

static constexpr const char* _type_key = "BaseFunc";
static constexpr const uint32_t _type_child_slots = 2;
TVM_DECLARE_BASE_OBJECT_INFO(BaseFuncNode, RelayExprNode);
Expand All @@ -154,48 +93,6 @@ class BaseFunc : public RelayExpr {
TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode);
};

/*!
* \brief Create a new function that copies func, but overrides
* the attribute value key with the value.
*
* \param func The input function.
* \param attr_key The attribute key.
* \param attr_value The value attribute value.
*
* \tparam TFunc The corresponding function type.
*
* \returns The new function with updated attributes.
*
* \note This function performs copy on write optimization for func.
* If we move a uniquely referenced func into WithAttr,
* then no additional copy will be performed.
*
* This is also why we make it as a function instead of a member function
* and why we pass by value in the first argument.
*
* \code
*
* // Recommended way to trigger copy on write
* func = WithAttr(std::move(func), "key1", value1);
* func = WithAttr(std::move(func), "key2", value2);
*
* \endcode
*/
template <typename TFunc,
typename = typename std::enable_if<std::is_base_of<BaseFunc, TFunc>::value>::type>
inline TFunc WithAttr(TFunc func, const std::string& attr_key, ObjectRef attr_value) {
using TNode = typename TFunc::ContainerType;
static_assert(TNode::_type_final, "Can only operate on the leaf nodes");
TNode* node = func.CopyOnWrite();
if (node->attrs.defined()) {
node->attrs.CopyOnWrite()->dict.Set(attr_key, attr_value);
} else {
Map<String, ObjectRef> dict = {{attr_key, attr_value}};
node->attrs = DictAttrs(dict);
}
return func;
}

/*!
* \brief Generic attribute names that can be attached to any function.
*
Expand Down
2 changes: 2 additions & 0 deletions include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ class IRModuleNode : public Object {
Map<GlobalTypeVar, TypeData> type_definitions;
/*! \brief The source map for the module. */
parser::SourceMap source_map;
/* \brief Additional attributes storing meta-data about the module. */
DictAttrs attrs;

IRModuleNode() : source_map() {}

Expand Down
3 changes: 2 additions & 1 deletion include/tvm/target/target.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class TargetNode : public Object {
/*! \brief Keys for this target */
Array<String> keys;
/*! \brief Collection of attributes */
Map<String, ObjectRef> attrs;
Map<String, ObjectRef> attrs; // TODO(@electriclilies): Unify with DictAttrs on IRModule
electriclilies marked this conversation as resolved.
Show resolved Hide resolved
/*!
* \brief The raw string representation of the target
* \return the full device string to pass to codegen::Build
Expand Down Expand Up @@ -101,6 +101,7 @@ class TargetNode : public Object {
* \param default_value The value returned if the key is not present
* \return An optional, NullOpt if not found, otherwise the value found
*/
// TODO(@electriclilies): Remove once we have removed the target attrs
template <typename TObjectRef>
Optional<TObjectRef> GetAttr(const std::string& attr_key, TObjectRef default_value) const {
return GetAttr<TObjectRef>(attr_key, Optional<TObjectRef>(default_value));
Expand Down
4 changes: 2 additions & 2 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ std::pair<IRModule, IRModule> SplitDevHostFuncs(IRModule mod_mixed, const Target

auto host_pass_list = {
Filter([](const tir::PrimFunc& f) {
return f->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) !=
return f->attrs.GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) !=
CallingConv::kDeviceKernelLaunch;
}),
BindTarget(target_host),
Expand All @@ -418,7 +418,7 @@ std::pair<IRModule, IRModule> SplitDevHostFuncs(IRModule mod_mixed, const Target
// device pipeline
auto device_pass_list = {
Filter([](const tir::PrimFunc& f) {
return f->GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) ==
return f->attrs.GetAttr<Integer>(tvm::attr::kCallingConv, Integer(CallingConv::kDefault)) ==
CallingConv::kDeviceKernelLaunch;
}),
BindTarget(target),
Expand Down
2 changes: 1 addition & 1 deletion src/ir/module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -356,7 +356,7 @@ IRModule IRModule::FromExpr(const RelayExpr& expr,

if (auto* func_node = expr.as<BaseFuncNode>()) {
func = GetRef<BaseFunc>(func_node);
if (auto opt = func->GetAttr<String>(tvm::attr::kGlobalSymbol)) {
if (auto opt = func->attrs.GetAttr<String>(tvm::attr::kGlobalSymbol)) {
gv_name = opt.value();
}

Expand Down
6 changes: 4 additions & 2 deletions src/relay/analysis/context_analysis.cc
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ class ContextAnalyzer : public MixedModeVisitor {
auto func = GetRef<Function>(fn);
// No need to step into fused primitive functions as they are handled as
// a whole.
if (fn->HasNonzeroAttr(attr::kPrimitive)) {
if (fn->attrs.HasNonzeroAttr(attr::kPrimitive)) {
return;
}

Expand Down Expand Up @@ -432,7 +432,9 @@ class ContextAnalyzer : public MixedModeVisitor {
}

// Check if a function is a closure.
bool IsClosure(const Function& func) { return func->GetAttr<Integer>(attr::kClosure, 0) != 0; }
bool IsClosure(const Function& func) {
return func->attrs.GetAttr<Integer>(attr::kClosure, 0) != 0;
}

// Check if a function is a currying function.
bool IsCurrying(const Function& func) {
Expand Down
2 changes: 1 addition & 1 deletion src/relay/analysis/extract_fused_functions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class FusedFunctionExtractorWrapper : private ExprVisitor {
Map<String, Function> functions;

void VisitExpr_(const FunctionNode* n) final {
if (n->HasNonzeroAttr(attr::kPrimitive)) {
if (n->attrs.HasNonzeroAttr(attr::kPrimitive)) {
// Add function to functions, keyed by function hash string
Function func = Function(n->params, n->body, n->ret_type, n->type_params, n->attrs);
size_t hash_ = tvm::StructuralHash()(func);
Expand Down
2 changes: 1 addition & 1 deletion src/relay/analysis/feature.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ FeatureSet DetectFeature(const Expr& expr) {
DETECT_DEFAULT_CONSTRUCT(Tuple)
DETECT_DEFAULT_CONSTRUCT(TupleGetItem)
DETECT_CONSTRUCT(Function, {
if (!op->HasNonzeroAttr(attr::kPrimitive)) {
if (!op->attrs.HasNonzeroAttr(attr::kPrimitive)) {
ExprVisitor::VisitExpr_(op);
}
})
Expand Down
6 changes: 3 additions & 3 deletions src/relay/analysis/get_calibration_data.cc
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class Collector : public ExprRewriter {
ICHECK(module_->ContainGlobalVar(var->name_hint)) << "Function " << var << " is not defined";
// we only handle functions with Compiler attribute set
auto func = Downcast<Function>(module_->Lookup(var));
if (func->GetAttr<String>(attr::kCompiler)) {
if (func->attrs.GetAttr<String>(attr::kCompiler)) {
// collect all the inputs and outputs
for (const auto& it : call->args) new_outputs_.push_back(it);
new_outputs_.push_back(post);
Expand Down Expand Up @@ -110,7 +110,7 @@ IRModule GetCalibrateModule(IRModule module) {
for (const auto& pair : glob_funcs) {
if (auto* fn = pair.second.as<FunctionNode>()) {
auto func = GetRef<Function>(fn);
if (func->GetAttr<String>(attr::kCompiler)) {
if (func->attrs.GetAttr<String>(attr::kCompiler)) {
// we need to inline the functions in order to run grpah runtime
func = WithAttr(std::move(func), attr::kInline, tvm::Integer(1));
// reset the compiler attribute to null for llvm execution
Expand Down Expand Up @@ -145,7 +145,7 @@ class OutputMapper : public ExprRewriter {
<< "Repeated function call " << var << " is not supported.";
auto func = Downcast<Function>(module_->Lookup(var));
// we only handle functions with Compiler attribute set
if (func->GetAttr<String>(attr::kCompiler)) {
if (func->attrs.GetAttr<String>(attr::kCompiler)) {
Array<Integer> info;
// the first value is the offset
info.push_back(Integer(*offset_));
Expand Down
2 changes: 1 addition & 1 deletion src/relay/backend/aot_executor_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -440,7 +440,7 @@ class AOTExecutorCodegen : public ExprVisitor {
void VisitExpr_(const GlobalVarNode* op) override { throw std::runtime_error(""); }
void VisitExpr_(const IfNode* op) override { throw std::invalid_argument("if not supported"); }
void VisitExpr_(const FunctionNode* op) override {
ICHECK(op->GetAttr<String>(attr::kCompiler).defined())
ICHECK(op->attrs.GetAttr<String>(attr::kCompiler).defined())
<< "FunctionNode only supported by custom codegen";
}
void VisitExpr_(const RefCreateNode* op) override {
Expand Down
Loading