Skip to content

Commit

Permalink
[REFACTOR] Streamline Function Attr interface. (#5045)
Browse files Browse the repository at this point in the history
* [REFACTOR] Streamline Function Attr interface.

There has been quite a few recent changes that depends heavily on
the function attr interface. This PR streamlines that interface by introducing
two APIs that covers most of the usages.

- GetAttr which gets a typed object for a given key
  - HasNonzeroAttr is a quick helper that calls GetAttr to quickly check an attribute
- WithAttr that creates a new function object with the given attr
  - The API comes with copy on write optimization to avoid multiple copies
  - We deliberately pick the prefix With(instead of Set) to indicate this
    function does not mutate the original input.

On the python side:
- We allow read access via func.attrs (which is a DictAttr)
- func.with_attrs to create a new instance with updated attrs.

We also get rid of the small wrapper functions and make sure the API centered around
the GetAttr and HasNonzeroAttr interface.

This PR also changes the function construction to follow the new convention.

* Address review comments

* Address review comments

* Fix doxygen path
  • Loading branch information
tqchen authored Mar 12, 2020
1 parent a950536 commit ec86d7f
Show file tree
Hide file tree
Showing 70 changed files with 810 additions and 587 deletions.
47 changes: 23 additions & 24 deletions include/tvm/ir/attrs.h
Original file line number Diff line number Diff line change
Expand Up @@ -277,26 +277,13 @@ class BaseAttrsNode : public Object {
TVM_DECLARE_BASE_OBJECT_INFO(BaseAttrsNode, Object);
};

/*! \brief Base attribute container for all attributes */
/*!
* \brief Managed reference to BaseAttrsNode.
* \sa AttrsNode, BaseAttrsNode
*/
class Attrs : public ObjectRef {
public:
// normal constructor
Attrs() {}
// construct from shared ptr.
explicit Attrs(ObjectPtr<Object> n) : ObjectRef(n) {}

/*! \return The attribute node */
const BaseAttrsNode* operator->() const {
return ptr();
}
/*! \brief specify container node */
using ContainerType = BaseAttrsNode;

private:
/*! \return the internal attribute node */
const BaseAttrsNode* ptr() const {
return static_cast<const BaseAttrsNode*>(get());
}
TVM_DEFINE_OBJECT_REF_METHODS(Attrs, ObjectRef, BaseAttrsNode);
};

/*!
Expand All @@ -309,12 +296,7 @@ class DictAttrsNode : public BaseAttrsNode {
public:
/*! \brief internal attrs map */
Map<std::string, ObjectRef> dict;
/*!
* \brief Consruct a Attrs backed by DictAttrsNode.
* \param dict The attributes.
* \return The dict attributes.
*/
TVM_DLL static Attrs make(Map<std::string, ObjectRef> dict);

// implementations
void VisitAttrs(AttrVisitor* v) final;
void VisitNonDefaultAttrs(AttrVisitor* v) final;
Expand All @@ -327,6 +309,23 @@ class DictAttrsNode : public BaseAttrsNode {
TVM_DECLARE_FINAL_OBJECT_INFO(DictAttrsNode, BaseAttrsNode);
};

/*!
* \brief Managed reference to DictAttrsNode
* \sa DictAttrsNode.
*/
class DictAttrs : public Attrs {
public:
/*!
* \brief Consruct a Attrs backed by DictAttrsNode.
* \param dict The attributes.
* \return The dict attributes.
*/
TVM_DLL explicit DictAttrs(Map<std::string, ObjectRef> dict);


TVM_DEFINE_OBJECT_REF_METHODS(DictAttrs, Attrs, DictAttrsNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(DictAttrsNode);
};

// Namespace containing detail implementations
namespace detail {
Expand Down
24 changes: 0 additions & 24 deletions include/tvm/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -211,30 +211,6 @@ class GlobalVar : public RelayExpr {
TVM_DEFINE_OBJECT_REF_METHODS(GlobalVar, RelayExpr, GlobalVarNode);
};

/*!
* \brief Base node of all functions.
*
* We support several variants of functions throughout the stack.
* All of the functions shares the same type system(via checked_type)
* to support cross variant calls.
*
* \sa BaseFunc
*/
class BaseFuncNode : public RelayExprNode {
public:
static constexpr const char* _type_key = "BaseFunc";
TVM_DECLARE_BASE_OBJECT_INFO(BaseFuncNode, RelayExprNode);
};

/*!
* \brief Managed reference to BaseFuncNode.
* \sa BaseFuncNode
*/
class BaseFunc : public RelayExpr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode);
};

// PrimExprs that are useful as runtime containers.
//
/*!
Expand Down
119 changes: 119 additions & 0 deletions include/tvm/ir/function.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/ir/function.h
* \brief Function nodes.
*/
#ifndef TVM_IR_FUNCTION_H_
#define TVM_IR_FUNCTION_H_

#include <tvm/ir/expr.h>
#include <tvm/ir/attrs.h>
#include <type_traits>
#include <string>


namespace tvm {

/*!
* \brief Base node of all functions.
*
* We support several variants of functions throughout the stack.
* All of the functions share the same type system(via checked_type)
* to support cross variant calls.
*
* \sa BaseFunc
*/
class BaseFuncNode : public RelayExprNode {
public:
/*! \brief Additional attributes storing the meta-data */
DictAttrs attrs;

/*!
* \brief Get a function attribute.
*
* \param attr_key The attribute key.
* \param default_value The default value if the key does not exist, defaults to nullptr.
*
* \return The result
*
* \tparam TOBjectRef the expected object type.
* \throw Error if the key exists but the value does not match TObjectRef
*
* \code
*
* void GetAttrExample(const BaseFunc& f) {
* Integer value = f->GetAttr<Integer>("AttrKey", 0);
* }
*
* \endcode
*/
template<typename TObjectRef>
TObjectRef GetAttr(const std::string& attr_key,
TObjectRef default_value = NullValue<TObjectRef>()) const {
static_assert(std::is_base_of<ObjectRef, TObjectRef>::value,
"Can only call GetAttr with ObjectRef types.");
if (!attrs.defined()) return default_value;
auto it = attrs->dict.find(attr_key);
if (it != attrs->dict.end()) {
return Downcast<TObjectRef>((*it).second);
} else {
return default_value;
}
}

/*!
* \brief Check whether the function has an non-zero integer attr.
*
* This function can be used to check whether an optional
* attribute mark(e.g. inline) exists.
*
* \param attr_key The key to the attribute.
* \return The check result.
*
* \code
*
* void HasNonzeroAttrExample(const BaseFunc& f) {
* if (f->HasNonzeroAttr(attr::kInline)) {
* // inline the function.
* }
* }
*
* \endcode
*/
bool HasNonzeroAttr(const std::string& attr_key) const {
return GetAttr<Integer>(attr_key, 0)->value != 0;
}

static constexpr const char* _type_key = "BaseFunc";
TVM_DECLARE_BASE_OBJECT_INFO(BaseFuncNode, RelayExprNode);
};

/*!
* \brief Managed reference to BaseFuncNode.
* \sa BaseFuncNode
*/
class BaseFunc : public RelayExpr {
public:
TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode);
};

} // namespace tvm
#endif // TVM_IR_FUNCTION_H_
1 change: 1 addition & 0 deletions include/tvm/ir/module.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

#include <tvm/ir/type.h>
#include <tvm/ir/expr.h>
#include <tvm/ir/function.h>
#include <tvm/ir/adt.h>

#include <string>
Expand Down
1 change: 1 addition & 0 deletions include/tvm/relay/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

#include <tvm/relay/adt.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/function.h>
#include <tvm/ir/module.h>
#include <tvm/relay/type.h>
#include <string>
Expand Down
131 changes: 0 additions & 131 deletions include/tvm/relay/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,113 +164,6 @@ class Var : public Expr {
TVM_DEFINE_OBJECT_REF_METHODS(Var, RelayExpr, VarNode);
};

/*!
* \brief Function (subgraph in computational graph)
*/
class Function;
/*! \brief Function container */
class FunctionNode : public BaseFuncNode {
public:
/*! \brief Function parameters */
tvm::Array<Var> params;
/*!
* \brief
* The expression which represents the computation of the function,
* the expression may reference the parameters, and the type of it
* or sub-expressions may reference the type variables.
*/
Expr body;
/*! \brief User annotated return type of the function. */
Type ret_type;
/*!
* \brief Type parameters of the function.
* Enables the function to vary its type based on these.
* This corresponds to template paramaters in c++'s terminology.
*
* \note This can be usually empty for non-polymorphic functions.
*/
tvm::Array<TypeVar> type_params;

/*!
* \brief The attributes which store metadata about functions.
*/
tvm::Attrs attrs;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("params", &params);
v->Visit("body", &body);
v->Visit("ret_type", &ret_type);
v->Visit("type_params", &type_params);
v->Visit("attrs", &attrs);
v->Visit("span", &span);
v->Visit("_checked_type_", &checked_type_);
}

/*!
* \brief Return the derived function annotation of this expression.
*
* \return The function type annotation.
* \note The function type annotation can contain IncompleteType.
*/
TVM_DLL FuncType func_type_annotation() const;

/*!
* \brief Check whether the function is a primitive function.
*
* \return Whether the function is primitive or not.
*/
bool IsPrimitive() const;

/*!
* \brief Check whether the function is marked as inline.
*
* \return Whether the function should be inlined or not.
*/
bool IsMarkedInline() const;

/*!
* \brief Check whether the function should use the TVM default compiler to build, or
* use other compilers.
*
* \return Whether the function will be compiled using the default compiler
* (e.g. those are used in the TVM stack).
*/
bool UseDefaultCompiler() const;

TVM_DLL static Function make(tvm::Array<Var> params,
Expr body,
Type ret_type,
tvm::Array<TypeVar> ty_params,
tvm::Attrs attrs = Attrs());

/*!
* \brief Attach the function's parameters to its attributes for use in analysis.
* \return The function with its parameters attached.
*/
Function SetParams(const tvm::Map<Var, Constant>& parameters) const;

/*!
* \brief Retrieve the function's parameters.
*
* \return The function's parameter.
*/
tvm::Map<Var, Constant> GetParams() const;

static constexpr const char* _type_key = "relay.Function";
TVM_DECLARE_FINAL_OBJECT_INFO(FunctionNode, BaseFuncNode);
};

class Function : public BaseFunc {
public:
TVM_DEFINE_OBJECT_REF_METHODS(Function, BaseFunc, FunctionNode);
};


TVM_DLL ObjectRef FunctionGetAttr(const Function& func, const std::string& key);
TVM_DLL Function FunctionSetAttr(const Function& func,
const std::string& key,
const ObjectRef& data);

/*!
* \brief Call corresponds to operator invocation.
* Corresponds to the operator in computational graph terminology.
Expand Down Expand Up @@ -550,30 +443,6 @@ class TempExpr : public Expr {
TVM_DEFINE_OBJECT_REF_METHODS(TempExpr, RelayExpr, TempExprNode);
};


/*! \brief namespace of the attributes that are attached to a function. */
namespace attr {
/*! \brief Mark the function as a primitive function. */
constexpr const char* kPrimitive = "Primitive";
/*!
* \brief Indicate the compiler that should be used for builing this function.
* When this is unset or set to "default", the default compilation pipeline will be used.
*/
constexpr const char* kCompiler = "Compiler";
/*! \brief Indicate if the function is a closure. */
constexpr const char* kClosure = "Closure";
/*! \brief Store a Var to parameter/Constant mapping on a Function. */
constexpr const char* kParams = "__params__";
/*! \brief Store the unique external symbol for external compilers. */
constexpr const char* kExternalSymbol = "ExternalSymbol";
/*! \brief Mark if the function should be avoided being optimized. */
constexpr const char* kSkipOptimization = "SkipOptimization";
/*! \brief Treat the function as a composite operator. */
constexpr const char* kComposite = "Composite";
/*! \brief Mark the function to be inlined. */
constexpr const char* kInline = "Inline";
} // namespace attr

} // namespace relay
} // namespace tvm
#endif // TVM_RELAY_EXPR_H_
Loading

0 comments on commit ec86d7f

Please sign in to comment.