Skip to content

Commit

Permalink
[API/Refactor] Unified PackedFunc for API and Generated Functions (#26)
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen authored Jan 29, 2017
1 parent 4242b9c commit ff06917
Show file tree
Hide file tree
Showing 46 changed files with 2,375 additions and 1,723 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ language: cpp

os:
- linux
- osx
# - osx

env:
# code analysis
Expand Down
85 changes: 85 additions & 0 deletions include/tvm/api_registry.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*!
* Copyright (c) 2016 by Contributors
* \file api_registry.h
* \brief This file defines the TVM API registry.
*
* The API registry stores type-erased functions.
* Each registered function is automatically exposed
* to front-end language(e.g. python).
* Front-end can also pass callbacks as PackedFunc, or register
* then into the same global registry in C++.
* The goal is to mix the front-end language and the TVM back-end.
*
* \code
* // register the function as MyAPIFuncName
* TVM_REGISTER_API(MyAPIFuncName)
* .set_body([](TVMArgs args, TVMRetValue* rv) {
* // my code.
* });
* \endcode
*/
#ifndef TVM_API_REGISTRY_H_
#define TVM_API_REGISTRY_H_

#include <dmlc/base.h>
#include <string>
#include "./base.h"
#include "./runtime/packed_func.h"
#include "./packed_func_ext.h"

namespace tvm {

/*! \brief Utility to register API. */
class APIRegistry {
public:
/*!
* \brief set the body of the function to be f
* \param f The body of the function.
*/
APIRegistry& set_body(PackedFunc f); // NOLINT(*)
/*!
* \brief set the body of the function to be f
* \param f The body of the function.
*/
APIRegistry& set_body(PackedFunc::FType f) { // NOLINT(*)
return set_body(PackedFunc(f));
}
/*!
* \brief Register a function with given name
* \param name The name of the function.
*/
static APIRegistry& __REGISTER__(const std::string& name); // NOLINT(*)

private:
/*! \brief name of the function */
std::string name_;
};

/*!
* \brief Get API function by name.
*
* \param name The name of the function.
* \return the corresponding API function.
* \note It is really PackedFunc::GetGlobal under the hood.
*/
inline PackedFunc GetAPIFunc(const std::string& name) {
return PackedFunc::GetGlobal(name);
}

#define _TVM_REGISTER_VAR_DEF_ \
static DMLC_ATTRIBUTE_UNUSED ::tvm::APIRegistry& __make_TVMRegistry_

/*!
* \brief Register API function globally.
* \code
* TVM_REGISTER_API(MyPrint)
* .set_body([](TVMArgs args, TVMRetValue* rv) {
* // my code.
* });
* \endcode
*/
#define TVM_REGISTER_API(OpName) \
DMLC_STR_CONCAT(_TVM_REGISTER_VAR_DEF_, __COUNTER__) = \
::tvm::APIRegistry::__REGISTER__(#OpName)
} // namespace tvm
#endif // TVM_API_REGISTRY_H_
74 changes: 7 additions & 67 deletions include/tvm/c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,83 +2,23 @@
* Copyright (c) 2016 by Contributors
* \file c_api.h
* \brief C API of TVM DSL
*
* \note The API is designed in a minimum way.
* Most of the API functions are registered and can be pulled out.
*
* The common flow is:
* - Use TVMFuncListGlobalNames to get global function name
* - Use TVMFuncCall to call these functions.
*/
#ifndef TVM_C_API_H_
#define TVM_C_API_H_

#include "./runtime/c_runtime_api.h"

TVM_EXTERN_C {
/*! \brief handle to functions */
typedef void* APIFuncHandle;
/*! \brief handle to node */
typedef void* NodeHandle;

/*!
* \brief List all the node function name
* \param out_size The number of functions
* \param out_array The array of function names.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMListAPIFuncNames(int *out_size,
const char*** out_array);
/*!
* \brief get function handle by name
* \param name The name of function
* \param handle The returning function handle
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMGetAPIFuncHandle(const char* name,
APIFuncHandle *handle);

/*!
* \brief Get the detailed information about function.
* \param handle The operator handle.
* \param real_name The returned name of the function.
* This name is not the alias name of the atomic symbol.
* \param description The returned description of the symbol.
* \param num_doc_args Number of arguments that contain documents.
* \param arg_names Name of the arguments of doc args
* \param arg_type_infos Type informations about the arguments.
* \param arg_descriptions Description information about the arguments.
* \param return_type Return type of the function, if any.
* \return 0 when success, -1 when failure happens
*/
TVM_DLL int TVMGetAPIFuncInfo(APIFuncHandle handle,
const char **real_name,
const char **description,
int *num_doc_args,
const char ***arg_names,
const char ***arg_type_infos,
const char ***arg_descriptions,
const char **return_type);

/*!
* \brief Push an argument to the function calling stack.
* If push fails, the stack will be reset to empty
*
* \param arg The argument
* \param type_code The type_code of argument as in TVMTypeCode
* \return 0 when success, -1 when failure happens
* \note API calls always exchanges with type bits=64, lanes=1
*/
TVM_DLL int TVMAPIPushStack(TVMValue arg,
int type_code);

/*!
* \brief call a function by using arguments in the stack.
* The stack will be cleanup to empty after this call, whether the call is successful.
*
* \param handle The function handle
* \param ret_val The return value.
* \param ret_type_code the type code of return value.
* \return 0 when success, -1 when failure happens
* \note API calls always exchanges with type bits=64, lanes=1
*/
TVM_DLL int TVMAPIFuncCall(APIFuncHandle handle,
TVMValue* ret_val,
int* ret_type_code);

/*!
* \brief free the node handle
* \param handle The node handle to be freed.
Expand Down
1 change: 1 addition & 0 deletions include/tvm/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <string>
#include <algorithm>
#include "./base.h"
#include "./runtime/packed_func.h"

namespace tvm {

Expand Down
196 changes: 196 additions & 0 deletions include/tvm/packed_func_ext.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
/*!
* Copyright (c) 2016 by Contributors
* \file packed_func_ext.h
* \brief Extension package to PackedFunc
* This enales pass NodeRef types into/from PackedFunc.
*/
#ifndef TVM_PACKED_FUNC_EXT_H_
#define TVM_PACKED_FUNC_EXT_H_

#include <sstream>
#include <string>
#include <memory>
#include <type_traits>

#include "./base.h"
#include "./expr.h"

namespace tvm {
using runtime::TVMArgs;
using runtime::TVMRetValue;
using runtime::PackedFunc;

namespace runtime {
/*!
* \brief Runtime type checker for node type.
* \tparam T the type to be checked.
*/
template<typename T>
struct NodeTypeChecker {
static inline bool Check(Node* sptr) {
// This is the only place in the project where RTTI is used
// It can be turned off, but will make non strict checking.
// TODO(tqchen) possibly find alternative to turn of RTTI
using ContainerType = typename T::ContainerType;
return (dynamic_cast<ContainerType*>(sptr) != nullptr);
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
using ContainerType = typename T::ContainerType;
os << ContainerType::_type_key;
}
};

template<typename T>
struct NodeTypeChecker<Array<T> > {
static inline bool Check(Node* sptr) {
if (sptr == nullptr) return false;
if (!sptr->is_type<ArrayNode>()) return false;
ArrayNode* n = static_cast<ArrayNode*>(sptr);
for (const auto& p : n->data) {
if (!NodeTypeChecker<T>::Check(p.get())) return false;
}
return true;
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
os << "array<";
NodeTypeChecker<T>::PrintName(os);
os << ">";
}
};

template<typename K, typename V>
struct NodeTypeChecker<Map<K, V> > {
static inline bool Check(Node* sptr) {
if (sptr == nullptr) return false;
if (!sptr->is_type<MapNode>()) return false;
MapNode* n = static_cast<MapNode*>(sptr);
for (const auto& kv : n->data) {
if (!NodeTypeChecker<K>::Check(kv.first.get())) return false;
if (!NodeTypeChecker<V>::Check(kv.second.get())) return false;
}
return true;
}
static inline void PrintName(std::ostringstream& os) { // NOLINT(*)
os << "map<";
NodeTypeChecker<K>::PrintName(os);
os << ',';
NodeTypeChecker<V>::PrintName(os);
os << '>';
}
};

template<typename T>
inline std::string NodeTypeName() {
std::ostringstream os;
NodeTypeChecker<T>::PrintName(os);
return os.str();
}

// extensions for tvm arg value

template<typename TNodeRef, typename>
inline TVMArgValue::operator TNodeRef() const {
static_assert(
std::is_base_of<NodeRef, TNodeRef>::value,
"Conversion only works for NodeRef");
if (type_code_ == kNull) return TNodeRef();
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
std::shared_ptr<Node>& sptr = *ptr<std::shared_ptr<Node> >();
CHECK(NodeTypeChecker<TNodeRef>::Check(sptr.get()))
<< "Expected type " << NodeTypeName<TNodeRef>()
<< " but get " << sptr->type_key();
return TNodeRef(sptr);
}

inline TVMArgValue::operator Halide::Expr() const {
if (type_code_ == kNull) return Expr();
if (type_code_ == kInt) {
return Expr(static_cast<int>(value_.v_int64));
}
if (type_code_ == kFloat) {
return Expr(static_cast<float>(value_.v_float64));
}
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
std::shared_ptr<Node>& sptr = *ptr<std::shared_ptr<Node> >();
if (sptr->is_type<IterVarNode>()) {
return IterVar(sptr)->var;
}
CHECK(NodeTypeChecker<Expr>::Check(sptr.get()))
<< "Expected type " << NodeTypeName<Expr>()
<< " but get " << sptr->type_key();
return Expr(sptr);
}

inline std::shared_ptr<Node>& TVMArgValue::node_sptr() {
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
return *ptr<std::shared_ptr<Node> >();
}


template<typename TNodeRef, typename>
inline bool TVMArgValue::IsNodeType() const {
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
std::shared_ptr<Node>& sptr =
*ptr<std::shared_ptr<Node> >();
return NodeTypeChecker<TNodeRef>::Check(sptr.get());
}

// extensions for TVMRetValue
inline TVMRetValue& TVMRetValue::operator=(
const std::shared_ptr<Node>& other) {
SwitchToClass<std::shared_ptr<Node> >(kNodeHandle, other);
return *this;
}

inline TVMRetValue& TVMRetValue::operator=(const NodeRef& other) {
SwitchToClass<std::shared_ptr<Node> >(kNodeHandle, other.node_);
return *this;
}

template<typename TNodeRef, typename>
inline TVMRetValue::operator TNodeRef() const {
static_assert(
std::is_base_of<NodeRef, TNodeRef>::value,
"Conversion only works for NodeRef");
if (type_code_ == kNull) return TNodeRef();
TVM_CHECK_TYPE_CODE(type_code_, kNodeHandle);
return TNodeRef(*ptr<std::shared_ptr<Node> >());
}

inline void TVMArgsSetter::operator()(size_t i, NodeRef& other) const { // NOLINT(*)
values_[i].v_handle = &(other.node_);
type_codes_[i] = kNodeHandle;
}

// Type related stuffs
inline Type TVMType2Type(TVMType t) {
return Type(static_cast<halide_type_code_t>(t.code), t.bits, t.lanes);
}

inline TVMType Type2TVMType(Type t) {
TVMType ret;
ret.code = static_cast<uint8_t>(t.code());
ret.bits = static_cast<uint8_t>(t.bits());
ret.lanes = static_cast<uint16_t>(t.lanes());
return ret;
}

inline TVMRetValue& TVMRetValue::operator=(const Halide::Type& t) {
return this->operator=(Type2TVMType(t));
}

inline TVMRetValue::operator Halide::Type() const {
return TVMType2Type(operator TVMType());
}

inline TVMArgValue::operator Halide::Type() const {
return TVMType2Type(operator TVMType());
}

inline void TVMArgsSetter::operator()(
size_t i, const Halide::Type& t) const {
this->operator()(i, Type2TVMType(t));
}
} // namespace runtime
} // namespace tvm
#endif // TVM_PACKED_FUNC_EXT_H_
Loading

0 comments on commit ff06917

Please sign in to comment.