-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[API/Refactor] Unified PackedFunc for API and Generated Functions (#26)
- Loading branch information
Showing
46 changed files
with
2,375 additions
and
1,723 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,7 @@ language: cpp | |
|
||
os: | ||
- linux | ||
- osx | ||
# - osx | ||
|
||
env: | ||
# code analysis | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
/*! | ||
* Copyright (c) 2016 by Contributors | ||
* \file api_registry.h | ||
* \brief This file defines the TVM API registry. | ||
* | ||
* The API registry stores type-erased functions. | ||
* Each registered function is automatically exposed | ||
* to front-end language(e.g. python). | ||
* Front-end can also pass callbacks as PackedFunc, or register | ||
* then into the same global registry in C++. | ||
* The goal is to mix the front-end language and the TVM back-end. | ||
* | ||
* \code | ||
* // register the function as MyAPIFuncName | ||
* TVM_REGISTER_API(MyAPIFuncName) | ||
* .set_body([](TVMArgs args, TVMRetValue* rv) { | ||
* // my code. | ||
* }); | ||
* \endcode | ||
*/ | ||
#ifndef TVM_API_REGISTRY_H_ | ||
#define TVM_API_REGISTRY_H_ | ||
|
||
#include <dmlc/base.h> | ||
#include <string> | ||
#include "./base.h" | ||
#include "./runtime/packed_func.h" | ||
#include "./packed_func_ext.h" | ||
|
||
namespace tvm { | ||
|
||
/*! \brief Utility to register API. */ | ||
class APIRegistry { | ||
public: | ||
/*! | ||
* \brief set the body of the function to be f | ||
* \param f The body of the function. | ||
*/ | ||
APIRegistry& set_body(PackedFunc f); // NOLINT(*) | ||
/*! | ||
* \brief set the body of the function to be f | ||
* \param f The body of the function. | ||
*/ | ||
APIRegistry& set_body(PackedFunc::FType f) { // NOLINT(*) | ||
return set_body(PackedFunc(f)); | ||
} | ||
/*! | ||
* \brief Register a function with given name | ||
* \param name The name of the function. | ||
*/ | ||
static APIRegistry& __REGISTER__(const std::string& name); // NOLINT(*) | ||
|
||
private: | ||
/*! \brief name of the function */ | ||
std::string name_; | ||
}; | ||
|
||
/*! | ||
* \brief Get API function by name. | ||
* | ||
* \param name The name of the function. | ||
* \return the corresponding API function. | ||
* \note It is really PackedFunc::GetGlobal under the hood. | ||
*/ | ||
inline PackedFunc GetAPIFunc(const std::string& name) { | ||
return PackedFunc::GetGlobal(name); | ||
} | ||
|
||
#define _TVM_REGISTER_VAR_DEF_ \ | ||
static DMLC_ATTRIBUTE_UNUSED ::tvm::APIRegistry& __make_TVMRegistry_ | ||
|
||
/*! | ||
* \brief Register API function globally. | ||
* \code | ||
* TVM_REGISTER_API(MyPrint) | ||
* .set_body([](TVMArgs args, TVMRetValue* rv) { | ||
* // my code. | ||
* }); | ||
* \endcode | ||
*/ | ||
#define TVM_REGISTER_API(OpName) \ | ||
DMLC_STR_CONCAT(_TVM_REGISTER_VAR_DEF_, __COUNTER__) = \ | ||
::tvm::APIRegistry::__REGISTER__(#OpName) | ||
} // namespace tvm | ||
#endif // TVM_API_REGISTRY_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,196 @@ | ||
/*! | ||
* Copyright (c) 2016 by Contributors | ||
* \file packed_func_ext.h | ||
* \brief Extension package to PackedFunc | ||
* This enales pass NodeRef types into/from PackedFunc. | ||
*/ | ||
#ifndef TVM_PACKED_FUNC_EXT_H_ | ||
#define TVM_PACKED_FUNC_EXT_H_ | ||
|
||
#include <sstream> | ||
#include <string> | ||
#include <memory> | ||
#include <type_traits> | ||
|
||
#include "./base.h" | ||
#include "./expr.h" | ||
|
||
namespace tvm { | ||
using runtime::TVMArgs; | ||
using runtime::TVMRetValue; | ||
using runtime::PackedFunc; | ||
|
||
namespace runtime { | ||
/*! | ||
* \brief Runtime type checker for node type. | ||
* \tparam T the type to be checked. | ||
*/ | ||
template<typename T> | ||
struct NodeTypeChecker { | ||
static inline bool Check(Node* sptr) { | ||
// This is the only place in the project where RTTI is used | ||
// It can be turned off, but will make non strict checking. | ||
// TODO(tqchen) possibly find alternative to turn of RTTI | ||
using ContainerType = typename T::ContainerType; | ||
return (dynamic_cast<ContainerType*>(sptr) != nullptr); | ||
} | ||
static inline void PrintName(std::ostringstream& os) { // NOLINT(*) | ||
using ContainerType = typename T::ContainerType; | ||
os << ContainerType::_type_key; | ||
} | ||
}; | ||
|
||
template<typename T> | ||
struct NodeTypeChecker<Array<T> > { | ||
static inline bool Check(Node* sptr) { | ||
if (sptr == nullptr) return false; | ||
if (!sptr->is_type<ArrayNode>()) return false; | ||
ArrayNode* n = static_cast<ArrayNode*>(sptr); | ||
for (const auto& p : n->data) { | ||
if (!NodeTypeChecker<T>::Check(p.get())) return false; | ||
} | ||
return true; | ||
} | ||
static inline void PrintName(std::ostringstream& os) { // NOLINT(*) | ||
os << "array<"; | ||
NodeTypeChecker<T>::PrintName(os); | ||
os << ">"; | ||
} | ||
}; | ||
|
||
template<typename K, typename V> | ||
struct NodeTypeChecker<Map<K, V> > { | ||
static inline bool Check(Node* sptr) { | ||
if (sptr == nullptr) return false; | ||
if (!sptr->is_type<MapNode>()) return false; | ||
MapNode* n = static_cast<MapNode*>(sptr); | ||
for (const auto& kv : n->data) { | ||
if (!NodeTypeChecker<K>::Check(kv.first.get())) return false; | ||
if (!NodeTypeChecker<V>::Check(kv.second.get())) return false; | ||
} | ||
return true; | ||
} | ||
static inline void PrintName(std::ostringstream& os) { // NOLINT(*) | ||
os << "map<"; | ||
NodeTypeChecker<K>::PrintName(os); | ||
os << ','; | ||
NodeTypeChecker<V>::PrintName(os); | ||
os << '>'; | ||
} | ||
}; | ||
|
||
template<typename T> | ||
inline std::string NodeTypeName() { | ||
std::ostringstream os; | ||
NodeTypeChecker<T>::PrintName(os); | ||
return os.str(); | ||
} | ||
|
||
// extensions for tvm arg value | ||
|
||
template<typename TNodeRef, typename> | ||
inline TVMArgValue::operator TNodeRef() const { | ||
static_assert( | ||
std::is_base_of<NodeRef, TNodeRef>::value, | ||
"Conversion only works for NodeRef"); | ||
if (type_code_ == kNull) return TNodeRef(); | ||
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); | ||
std::shared_ptr<Node>& sptr = *ptr<std::shared_ptr<Node> >(); | ||
CHECK(NodeTypeChecker<TNodeRef>::Check(sptr.get())) | ||
<< "Expected type " << NodeTypeName<TNodeRef>() | ||
<< " but get " << sptr->type_key(); | ||
return TNodeRef(sptr); | ||
} | ||
|
||
inline TVMArgValue::operator Halide::Expr() const { | ||
if (type_code_ == kNull) return Expr(); | ||
if (type_code_ == kInt) { | ||
return Expr(static_cast<int>(value_.v_int64)); | ||
} | ||
if (type_code_ == kFloat) { | ||
return Expr(static_cast<float>(value_.v_float64)); | ||
} | ||
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); | ||
std::shared_ptr<Node>& sptr = *ptr<std::shared_ptr<Node> >(); | ||
if (sptr->is_type<IterVarNode>()) { | ||
return IterVar(sptr)->var; | ||
} | ||
CHECK(NodeTypeChecker<Expr>::Check(sptr.get())) | ||
<< "Expected type " << NodeTypeName<Expr>() | ||
<< " but get " << sptr->type_key(); | ||
return Expr(sptr); | ||
} | ||
|
||
inline std::shared_ptr<Node>& TVMArgValue::node_sptr() { | ||
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); | ||
return *ptr<std::shared_ptr<Node> >(); | ||
} | ||
|
||
|
||
template<typename TNodeRef, typename> | ||
inline bool TVMArgValue::IsNodeType() const { | ||
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); | ||
std::shared_ptr<Node>& sptr = | ||
*ptr<std::shared_ptr<Node> >(); | ||
return NodeTypeChecker<TNodeRef>::Check(sptr.get()); | ||
} | ||
|
||
// extensions for TVMRetValue | ||
inline TVMRetValue& TVMRetValue::operator=( | ||
const std::shared_ptr<Node>& other) { | ||
SwitchToClass<std::shared_ptr<Node> >(kNodeHandle, other); | ||
return *this; | ||
} | ||
|
||
inline TVMRetValue& TVMRetValue::operator=(const NodeRef& other) { | ||
SwitchToClass<std::shared_ptr<Node> >(kNodeHandle, other.node_); | ||
return *this; | ||
} | ||
|
||
template<typename TNodeRef, typename> | ||
inline TVMRetValue::operator TNodeRef() const { | ||
static_assert( | ||
std::is_base_of<NodeRef, TNodeRef>::value, | ||
"Conversion only works for NodeRef"); | ||
if (type_code_ == kNull) return TNodeRef(); | ||
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle); | ||
return TNodeRef(*ptr<std::shared_ptr<Node> >()); | ||
} | ||
|
||
inline void TVMArgsSetter::operator()(size_t i, NodeRef& other) const { // NOLINT(*) | ||
values_[i].v_handle = &(other.node_); | ||
type_codes_[i] = kNodeHandle; | ||
} | ||
|
||
// Type related stuffs | ||
inline Type TVMType2Type(TVMType t) { | ||
return Type(static_cast<halide_type_code_t>(t.code), t.bits, t.lanes); | ||
} | ||
|
||
inline TVMType Type2TVMType(Type t) { | ||
TVMType ret; | ||
ret.code = static_cast<uint8_t>(t.code()); | ||
ret.bits = static_cast<uint8_t>(t.bits()); | ||
ret.lanes = static_cast<uint16_t>(t.lanes()); | ||
return ret; | ||
} | ||
|
||
inline TVMRetValue& TVMRetValue::operator=(const Halide::Type& t) { | ||
return this->operator=(Type2TVMType(t)); | ||
} | ||
|
||
inline TVMRetValue::operator Halide::Type() const { | ||
return TVMType2Type(operator TVMType()); | ||
} | ||
|
||
inline TVMArgValue::operator Halide::Type() const { | ||
return TVMType2Type(operator TVMType()); | ||
} | ||
|
||
inline void TVMArgsSetter::operator()( | ||
size_t i, const Halide::Type& t) const { | ||
this->operator()(i, Type2TVMType(t)); | ||
} | ||
} // namespace runtime | ||
} // namespace tvm | ||
#endif // TVM_PACKED_FUNC_EXT_H_ |
Oops, something went wrong.