Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[REFACTOR] Use more TypedPackedFuncs #2981

Merged
merged 15 commits into from
Apr 10, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 163 additions & 0 deletions include/tvm/runtime/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,169 @@ class Registry {
Registry& set_body_typed(FLambda f) {
return set_body(TypedPackedFunc<FType>(f).packed());
}

/*!
* \brief set the body of the function to the given function pointer.
* Note that this doesn't work with lambdas, you need to
* explicitly give a type for those.
* Note that this will ignore default arg values and always require all arguments to be provided.
*
* \code
*
* int multiply(int x, int y) {
* return x * y;
* }
*
* TVM_REGISTER_API("multiply")
* .set_body_typed(multiply); // will have type int(int, int)
*
* \endcode
*
* \param f The function to forward to.
* \tparam R the return type of the function (inferred).
* \tparam Args the argument types of the function (inferred).
*/
template<typename R, typename ...Args>
Registry& set_body_typed(R (*f)(Args...)) {
return set_body(TypedPackedFunc<R(Args...)>(f));
}

/*!
* \brief set the body of the function to be the passed method pointer.
* Note that this will ignore default arg values and always require all arguments to be provided.
*
* \code
*
* // node subclass:
* struct Example {
* int doThing(int x);
* }
* TVM_REGISTER_API("Example_doThing")
* .set_body_method(&Example::doThing); // will have type int(Example, int)
*
* \endcode
*
* \param f the method pointer to forward to.
* \tparam T the type containing the method (inferred).
* \tparam R the return type of the function (inferred).
* \tparam Args the argument types of the function (inferred).
*/
template<typename T, typename R, typename ...Args>
Registry& set_body_method(R (T::*f)(Args...)) {
return set_body_typed<R(T, Args...)>([f](T target, Args... params) -> R {
// call method pointer
return (target.*f)(params...);
});
}

/*!
* \brief set the body of the function to be the passed method pointer.
* Note that this will ignore default arg values and always require all arguments to be provided.
*
* \code
*
* // node subclass:
* struct Example {
* int doThing(int x);
* }
* TVM_REGISTER_API("Example_doThing")
* .set_body_method(&Example::doThing); // will have type int(Example, int)
*
* \endcode
*
* \param f the method pointer to forward to.
* \tparam T the type containing the method (inferred).
* \tparam R the return type of the function (inferred).
* \tparam Args the argument types of the function (inferred).
*/
template<typename T, typename R, typename ...Args>
Registry& set_body_method(R (T::*f)(Args...) const) {
return set_body_typed<R(T, Args...)>([f](const T target, Args... params) -> R {
// call method pointer
return (target.*f)(params...);
});
}

/*!
* \brief set the body of the function to be the passed method pointer.
* Used when calling a method on a Node subclass through a NodeRef subclass.
* Note that this will ignore default arg values and always require all arguments to be provided.
*
* \code
*
* // node subclass:
* struct ExampleNode: BaseNode {
* int doThing(int x);
* }
*
* // noderef subclass
* struct Example;
*
* TVM_REGISTER_API("Example_doThing")
* .set_body_method<Example>(&ExampleNode::doThing); // will have type int(Example, int)
*
* // note that just doing:
* // .set_body_method(&ExampleNode::doThing);
* // wouldn't work, because ExampleNode can't be taken from a TVMArgValue.
*
* \endcode
*
* \param f the method pointer to forward to.
* \tparam TNodeRef the node reference type to call the method on
* \tparam TNode the node type containing the method (inferred).
* \tparam R the return type of the function (inferred).
* \tparam Args the argument types of the function (inferred).
*/
template<typename TNodeRef, typename TNode, typename R, typename ...Args,
typename = typename std::enable_if<std::is_base_of<NodeRef, TNodeRef>::value>::type>
Registry& set_body_method(R (TNode::*f)(Args...)) {
return set_body_typed<R(TNodeRef, Args...)>([f](TNodeRef ref, Args... params) {
TNode* target = ref.operator->();
// call method pointer
return (target->*f)(params...);
});
}

/*!
* \brief set the body of the function to be the passed method pointer.
* Used when calling a method on a Node subclass through a NodeRef subclass.
* Note that this will ignore default arg values and always require all arguments to be provided.
*
* \code
*
* // node subclass:
* struct ExampleNode: BaseNode {
* int doThing(int x);
* }
*
* // noderef subclass
* struct Example;
*
* TVM_REGISTER_API("Example_doThing")
* .set_body_method<Example>(&ExampleNode::doThing); // will have type int(Example, int)
*
* // note that just doing:
* // .set_body_method(&ExampleNode::doThing);
* // wouldn't work, because ExampleNode can't be taken from a TVMArgValue.
*
* \endcode
*
* \param f the method pointer to forward to.
* \tparam TNodeRef the node reference type to call the method on
* \tparam TNode the node type containing the method (inferred).
* \tparam R the return type of the function (inferred).
* \tparam Args the argument types of the function (inferred).
*/
template<typename TNodeRef, typename TNode, typename R, typename ...Args,
typename = typename std::enable_if<std::is_base_of<NodeRef, TNodeRef>::value>::type>
Registry& set_body_method(R (TNode::*f)(Args...) const) {
return set_body_typed<R(TNodeRef, Args...)>([f](TNodeRef ref, Args... params) {
const TNode* target = ref.operator->();
// call method pointer
return (target->*f)(params...);
});
}

/*!
* \brief Register a function with given name
* \param name The name of the function.
Expand Down
4 changes: 1 addition & 3 deletions nnvm/src/compiler/compile_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -360,9 +360,7 @@ TVM_REGISTER_GLOBAL("nnvm.compiler.GraphKeyGetGraph")
});

TVM_REGISTER_GLOBAL("nnvm.compiler.MakeGraphKey")
.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) {
*rv = GraphKeyNode::make(args[0], args[1], args[2]);
});
.set_body_typed(GraphKeyNode::make);

// This can be used to extract workloads from nnvm compiler
TVM_REGISTER_GLOBAL("nnvm.compiler.CacheItem2ScheduleArgs")
Expand Down
4 changes: 1 addition & 3 deletions nnvm/src/compiler/graph_hash.cc
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,6 @@ std::string GraphDeepCompare(const Graph& a,
}

TVM_REGISTER_GLOBAL("nnvm.graph.DeepCompare")
.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) {
*rv = GraphDeepCompare(args[0], args[1], args[2]);
});
.set_body_typed(GraphDeepCompare);
} // namespace compiler
} // namespace nnvm
60 changes: 19 additions & 41 deletions src/api/api_arith.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,73 +31,51 @@ namespace tvm {
namespace arith {

TVM_REGISTER_API("arith.intset_single_point")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = IntSet::single_point(args[0]);
});
.set_body_typed(IntSet::single_point);

TVM_REGISTER_API("arith.intset_vector")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = IntSet::vector(args[0]);
});
.set_body_typed(IntSet::vector);

TVM_REGISTER_API("arith.intset_interval")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = IntSet::interval(args[0], args[1]);
});
.set_body_typed(IntSet::interval);

TVM_REGISTER_API("arith.DetectLinearEquation")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = DetectLinearEquation(args[0], args[1]);
});
.set_body_typed(DetectLinearEquation);

TVM_REGISTER_API("arith.DetectClipBound")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = DetectClipBound(args[0], args[1]);
});
.set_body_typed(DetectClipBound);

TVM_REGISTER_API("arith.DeduceBound")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = DeduceBound(args[0], args[1],
args[2].operator Map<Var, IntSet>(),
args[3].operator Map<Var, IntSet>());
});
.set_body_typed<IntSet(Expr, Expr, Map<Var, IntSet>, Map<Var, IntSet>)>([](
Expr v, Expr cond,
const Map<Var, IntSet> hint_map,
const Map<Var, IntSet> relax_map
) {
return DeduceBound(v, cond, hint_map, relax_map);
});


TVM_REGISTER_API("arith.DomainTouched")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = DomainTouched(args[0], args[1], args[2], args[3]);
});
.set_body_typed(DomainTouched);


TVM_REGISTER_API("_IntervalSetGetMin")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator IntSet().min();
});
.set_body_method(&IntSet::min);

TVM_REGISTER_API("_IntervalSetGetMax")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator IntSet().max();
});
.set_body_method(&IntSet::max);

TVM_REGISTER_API("_IntSetIsNothing")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator IntSet().is_nothing();
});
.set_body_method(&IntSet::is_nothing);

TVM_REGISTER_API("_IntSetIsEverything")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator IntSet().is_everything();
});
.set_body_method(&IntSet::is_everything);

TVM_REGISTER_API("arith._make_ConstIntBound")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ConstIntBoundNode::make(args[0], args[1]);
});
.set_body_typed(ConstIntBoundNode::make);

TVM_REGISTER_API("arith._make_ModularSet")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ModularSetNode::make(args[0], args[1]);
});
.set_body_typed(ModularSetNode::make);

TVM_REGISTER_API("arith._CreateAnalyzer")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Expand Down
5 changes: 2 additions & 3 deletions src/api/api_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,8 @@ TVM_REGISTER_API("_load_json")
.set_body_typed<NodeRef(std::string)>(LoadJSON<NodeRef>);

TVM_REGISTER_API("_TVMSetStream")
.set_body([](TVMArgs args, TVMRetValue *ret) {
TVMSetStream(args[0], args[1], args[2]);
});
.set_body_typed(TVMSetStream);

TVM_REGISTER_API("_save_param_dict")
.set_body([](TVMArgs args, TVMRetValue *rv) {
CHECK_EQ(args.size() % 2, 0u);
Expand Down
4 changes: 1 addition & 3 deletions src/api/api_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,6 @@ TVM_REGISTER_API("codegen._Build")
});

TVM_REGISTER_API("module._PackImportsToC")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = PackImportsToC(args[0], args[1]);
});
.set_body_typed(PackImportsToC);
} // namespace codegen
} // namespace tvm
Loading