Skip to content

Commit

Permalink
[REFACTOR] TVM_REGISTER_API -> TVM_REGISTER_GLOBAL
Browse files Browse the repository at this point in the history
TVM_REGSISTER_API is an alias of TVM_REGISTER_GLOBAL.
In the spirit of simplify redirections, this PR removes
the original TVM_REGISTER_API macro and directly use TVM_REGISTER_GLOBAL.

This type of refactor will also simplify the IDE navigation tools
such as FFI navigator to provide better code reading experiences.

Move EnvFunc's definition to node.
  • Loading branch information
tqchen committed Jan 4, 2020
1 parent 1ecd3ee commit 2a25230
Show file tree
Hide file tree
Showing 127 changed files with 579 additions and 527 deletions.
1 change: 0 additions & 1 deletion include/tvm/codegen.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
#include "base.h"
#include "expr.h"
#include "lowered_func.h"
#include "api_registry.h"
#include "runtime/packed_func.h"

namespace tvm {
Expand Down
32 changes: 9 additions & 23 deletions include/tvm/api_registry.h → include/tvm/node/env_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,33 +18,19 @@
*/

/*!
* \file tvm/api_registry.h
* \brief This file contains utilities related to
* the TVM's global function registry.
* \file tvm/env_func.h
* \brief Env function
*/
#ifndef TVM_API_REGISTRY_H_
#define TVM_API_REGISTRY_H_
#ifndef TVM_NODE_ENV_FUNC_H_
#define TVM_NODE_ENV_FUNC_H_

#include <tvm/node/reflection.h>

#include <string>
#include <utility>
#include "base.h"
#include "packed_func_ext.h"
#include "runtime/registry.h"

namespace tvm {
/*!
* \brief Register an API function globally.
* It simply redirects to TVM_REGISTER_GLOBAL
*
* \code
* TVM_REGISTER_API(MyPrint)
* .set_body([](TVMArgs args, TVMRetValue* rv) {
* // my code.
* });
* \endcode
*/
#define TVM_REGISTER_API(OpName) TVM_REGISTER_GLOBAL(OpName)

namespace tvm {
/*!
* \brief Node container of EnvFunc
* \sa EnvFunc
Expand All @@ -54,7 +40,7 @@ class EnvFuncNode : public Object {
/*! \brief Unique name of the global function */
std::string name;
/*! \brief The internal packed function */
PackedFunc func;
runtime::PackedFunc func;
/*! \brief constructor */
EnvFuncNode() {}

Expand Down Expand Up @@ -154,4 +140,4 @@ class TypedEnvFunc<R(Args...)> : public ObjectRef {
};

} // namespace tvm
#endif // TVM_API_REGISTRY_H_
#endif // TVM_NODE_ENV_FUNC_H_
2 changes: 1 addition & 1 deletion include/tvm/relay/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
#ifndef TVM_RELAY_BASE_H_
#define TVM_RELAY_BASE_H_

#include <tvm/api_registry.h>

#include <tvm/ir/span.h>
#include <tvm/ir.h>
#include <tvm/node/node.h>
Expand Down
6 changes: 5 additions & 1 deletion include/tvm/relay/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,12 @@
#ifndef TVM_RELAY_TYPE_H_
#define TVM_RELAY_TYPE_H_

#include <tvm/api_registry.h>

#include <tvm/ir/type.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>
#include <tvm/node/env_func.h>

#include <tvm/ir.h>
#include <string>

Expand Down
12 changes: 6 additions & 6 deletions include/tvm/runtime/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ class Registry {
*
* \code
*
* TVM_REGISTER_API("addone")
* TVM_REGISTER_GLOBAL("addone")
* .set_body_typed<int(int)>([](int x) { return x + 1; });
*
* \endcode
Expand All @@ -96,7 +96,7 @@ class Registry {
* return x * y;
* }
*
* TVM_REGISTER_API("multiply")
* TVM_REGISTER_GLOBAL("multiply")
* .set_body_typed(multiply); // will have type int(int, int)
*
* \endcode
Expand All @@ -120,7 +120,7 @@ class Registry {
* struct Example {
* int doThing(int x);
* }
* TVM_REGISTER_API("Example_doThing")
* TVM_REGISTER_GLOBAL("Example_doThing")
* .set_body_method(&Example::doThing); // will have type int(Example, int)
*
* \endcode
Expand Down Expand Up @@ -148,7 +148,7 @@ class Registry {
* struct Example {
* int doThing(int x);
* }
* TVM_REGISTER_API("Example_doThing")
* TVM_REGISTER_GLOBAL("Example_doThing")
* .set_body_method(&Example::doThing); // will have type int(Example, int)
*
* \endcode
Expand Down Expand Up @@ -181,7 +181,7 @@ class Registry {
* // noderef subclass
* struct Example;
*
* TVM_REGISTER_API("Example_doThing")
* TVM_REGISTER_GLOBAL("Example_doThing")
* .set_body_method<Example>(&ExampleNode::doThing); // will have type int(Example, int)
*
* // note that just doing:
Expand Down Expand Up @@ -221,7 +221,7 @@ class Registry {
* // noderef subclass
* struct Example;
*
* TVM_REGISTER_API("Example_doThing")
* TVM_REGISTER_GLOBAL("Example_doThing")
* .set_body_method<Example>(&ExampleNode::doThing); // will have type int(Example, int)
*
* // note that just doing:
Expand Down
32 changes: 17 additions & 15 deletions src/api/api_arith.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,29 +23,31 @@
*/
#include <tvm/expr.h>
#include <tvm/ir.h>
#include <tvm/api_registry.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>

#include <tvm/tensor.h>

namespace tvm {
namespace arith {

TVM_REGISTER_API("arith.intset_single_point")
TVM_REGISTER_GLOBAL("arith.intset_single_point")
.set_body_typed(IntSet::single_point);

TVM_REGISTER_API("arith.intset_vector")
TVM_REGISTER_GLOBAL("arith.intset_vector")
.set_body_typed(IntSet::vector);

TVM_REGISTER_API("arith.intset_interval")
TVM_REGISTER_GLOBAL("arith.intset_interval")
.set_body_typed(IntSet::interval);


TVM_REGISTER_API("arith.DetectLinearEquation")
TVM_REGISTER_GLOBAL("arith.DetectLinearEquation")
.set_body_typed(DetectLinearEquation);

TVM_REGISTER_API("arith.DetectClipBound")
TVM_REGISTER_GLOBAL("arith.DetectClipBound")
.set_body_typed(DetectClipBound);

TVM_REGISTER_API("arith.DeduceBound")
TVM_REGISTER_GLOBAL("arith.DeduceBound")
.set_body_typed<IntSet(Expr, Expr, Map<Var, IntSet>, Map<Var, IntSet>)>([](
Expr v, Expr cond,
const Map<Var, IntSet> hint_map,
Expand All @@ -55,36 +57,36 @@ TVM_REGISTER_API("arith.DeduceBound")
});


TVM_REGISTER_API("arith.DomainTouched")
TVM_REGISTER_GLOBAL("arith.DomainTouched")
.set_body_typed(DomainTouched);

TVM_REGISTER_API("_IntervalSetGetMin")
TVM_REGISTER_GLOBAL("_IntervalSetGetMin")
.set_body_method(&IntSet::min);

TVM_REGISTER_API("_IntervalSetGetMax")
TVM_REGISTER_GLOBAL("_IntervalSetGetMax")
.set_body_method(&IntSet::max);

TVM_REGISTER_API("_IntSetIsNothing")
TVM_REGISTER_GLOBAL("_IntSetIsNothing")
.set_body_method(&IntSet::is_nothing);

TVM_REGISTER_API("_IntSetIsEverything")
TVM_REGISTER_GLOBAL("_IntSetIsEverything")
.set_body_method(&IntSet::is_everything);

ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) {
return ConstIntBound(min_value, max_value);
}

TVM_REGISTER_API("arith._make_ConstIntBound")
TVM_REGISTER_GLOBAL("arith._make_ConstIntBound")
.set_body_typed(MakeConstIntBound);

ModularSet MakeModularSet(int64_t coeff, int64_t base) {
return ModularSet(coeff, base);
}

TVM_REGISTER_API("arith._make_ModularSet")
TVM_REGISTER_GLOBAL("arith._make_ModularSet")
.set_body_typed(MakeModularSet);

TVM_REGISTER_API("arith._CreateAnalyzer")
TVM_REGISTER_GLOBAL("arith._CreateAnalyzer")
.set_body([](TVMArgs args, TVMRetValue* ret) {
using runtime::PackedFunc;
using runtime::TypedPackedFunc;
Expand Down
16 changes: 9 additions & 7 deletions src/api/api_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,34 +24,36 @@
#include <dmlc/memory_io.h>
#include <tvm/expr.h>
#include <tvm/tensor.h>
#include <tvm/api_registry.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>

#include <tvm/node/serialization.h>

namespace tvm {
TVM_REGISTER_API("_format_str")
TVM_REGISTER_GLOBAL("_format_str")
.set_body([](TVMArgs args, TVMRetValue *ret) {
CHECK(args[0].type_code() == kObjectHandle);
std::ostringstream os;
os << args[0].operator ObjectRef();
*ret = os.str();
});

TVM_REGISTER_API("_raw_ptr")
TVM_REGISTER_GLOBAL("_raw_ptr")
.set_body([](TVMArgs args, TVMRetValue *ret) {
CHECK(args[0].type_code() == kObjectHandle);
*ret = reinterpret_cast<int64_t>(args[0].value().v_handle);
});

TVM_REGISTER_API("_save_json")
TVM_REGISTER_GLOBAL("_save_json")
.set_body_typed<std::string(ObjectRef)>(SaveJSON);

TVM_REGISTER_API("_load_json")
TVM_REGISTER_GLOBAL("_load_json")
.set_body_typed<ObjectRef(std::string)>(LoadJSON);

TVM_REGISTER_API("_TVMSetStream")
TVM_REGISTER_GLOBAL("_TVMSetStream")
.set_body_typed(TVMSetStream);

TVM_REGISTER_API("_save_param_dict")
TVM_REGISTER_GLOBAL("_save_param_dict")
.set_body([](TVMArgs args, TVMRetValue *rv) {
CHECK_EQ(args.size() % 2, 0u);
constexpr uint64_t TVMNDArrayListMagic = 0xF7E58D4F05049CB7;
Expand Down
8 changes: 5 additions & 3 deletions src/api/api_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,14 @@
#include <tvm/ir.h>
#include <tvm/codegen.h>
#include <tvm/lowered_func.h>
#include <tvm/api_registry.h>
#include <tvm/runtime/registry.h>
#include <tvm/packed_func_ext.h>


namespace tvm {
namespace codegen {

TVM_REGISTER_API("codegen._Build")
TVM_REGISTER_GLOBAL("codegen._Build")
.set_body([](TVMArgs args, TVMRetValue *ret) {
if (args[0].IsObjectRef<LoweredFunc>()) {
*ret = Build({args[0]}, args[1]);
Expand All @@ -39,7 +41,7 @@ TVM_REGISTER_API("codegen._Build")
}
});

TVM_REGISTER_API("module._PackImportsToC")
TVM_REGISTER_GLOBAL("module._PackImportsToC")
.set_body_typed(PackImportsToC);
} // namespace codegen
} // namespace tvm
Loading

0 comments on commit 2a25230

Please sign in to comment.