Skip to content

Commit

Permalink
[REFACTOR] Use more TypedPackedFuncs (#2981)
Browse files Browse the repository at this point in the history
* Add `set_body_simple` to Registry, refactor a lot of code to use it

* Add more types to Relay PackedFuncs

* Add Registry::set_body_method to easily make Node methods into
PackedFuncs

* Add set_body_method, set_body_node_method; start typing api_lang

* Add some docs, remove unused script

* Fix mysterious linter problem

* Touch up api_ir.cc

* Fix some issues with TOPI argument counts

* Revert changes to topi.cc to avoid problems with optional arguments

* A little more cleanup

* Type more of the api _ functions

* Whitespace

* Finalize names and docs for new registry helpers

* Update docs
  • Loading branch information
kazimuth authored and tqchen committed Apr 10, 2019
1 parent 57f47a1 commit 5178506
Show file tree
Hide file tree
Showing 68 changed files with 635 additions and 1,090 deletions.
163 changes: 163 additions & 0 deletions include/tvm/runtime/registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,169 @@ class Registry {
Registry& set_body_typed(FLambda f) {
return set_body(TypedPackedFunc<FType>(f).packed());
}

/*!
* \brief set the body of the function to the given function pointer.
* Note that this doesn't work with lambdas, you need to
* explicitly give a type for those.
* Note that this will ignore default arg values and always require all arguments to be provided.
*
* \code
*
* int multiply(int x, int y) {
* return x * y;
* }
*
* TVM_REGISTER_API("multiply")
* .set_body_typed(multiply); // will have type int(int, int)
*
* \endcode
*
* \param f The function to forward to.
* \tparam R the return type of the function (inferred).
* \tparam Args the argument types of the function (inferred).
*/
template<typename R, typename ...Args>
Registry& set_body_typed(R (*f)(Args...)) {
return set_body(TypedPackedFunc<R(Args...)>(f));
}

/*!
* \brief set the body of the function to be the passed method pointer.
* Note that this will ignore default arg values and always require all arguments to be provided.
*
* \code
*
* // node subclass:
* struct Example {
* int doThing(int x);
* }
* TVM_REGISTER_API("Example_doThing")
* .set_body_method(&Example::doThing); // will have type int(Example, int)
*
* \endcode
*
* \param f the method pointer to forward to.
* \tparam T the type containing the method (inferred).
* \tparam R the return type of the function (inferred).
* \tparam Args the argument types of the function (inferred).
*/
template<typename T, typename R, typename ...Args>
Registry& set_body_method(R (T::*f)(Args...)) {
return set_body_typed<R(T, Args...)>([f](T target, Args... params) -> R {
// call method pointer
return (target.*f)(params...);
});
}

/*!
* \brief set the body of the function to be the passed method pointer.
* Note that this will ignore default arg values and always require all arguments to be provided.
*
* \code
*
* // node subclass:
* struct Example {
* int doThing(int x);
* }
* TVM_REGISTER_API("Example_doThing")
* .set_body_method(&Example::doThing); // will have type int(Example, int)
*
* \endcode
*
* \param f the method pointer to forward to.
* \tparam T the type containing the method (inferred).
* \tparam R the return type of the function (inferred).
* \tparam Args the argument types of the function (inferred).
*/
template<typename T, typename R, typename ...Args>
Registry& set_body_method(R (T::*f)(Args...) const) {
return set_body_typed<R(T, Args...)>([f](const T target, Args... params) -> R {
// call method pointer
return (target.*f)(params...);
});
}

/*!
* \brief set the body of the function to be the passed method pointer.
* Used when calling a method on a Node subclass through a NodeRef subclass.
* Note that this will ignore default arg values and always require all arguments to be provided.
*
* \code
*
* // node subclass:
* struct ExampleNode: BaseNode {
* int doThing(int x);
* }
*
* // noderef subclass
* struct Example;
*
* TVM_REGISTER_API("Example_doThing")
* .set_body_method<Example>(&ExampleNode::doThing); // will have type int(Example, int)
*
* // note that just doing:
* // .set_body_method(&ExampleNode::doThing);
* // wouldn't work, because ExampleNode can't be taken from a TVMArgValue.
*
* \endcode
*
* \param f the method pointer to forward to.
* \tparam TNodeRef the node reference type to call the method on
* \tparam TNode the node type containing the method (inferred).
* \tparam R the return type of the function (inferred).
* \tparam Args the argument types of the function (inferred).
*/
template<typename TNodeRef, typename TNode, typename R, typename ...Args,
typename = typename std::enable_if<std::is_base_of<NodeRef, TNodeRef>::value>::type>
Registry& set_body_method(R (TNode::*f)(Args...)) {
return set_body_typed<R(TNodeRef, Args...)>([f](TNodeRef ref, Args... params) {
TNode* target = ref.operator->();
// call method pointer
return (target->*f)(params...);
});
}

/*!
* \brief set the body of the function to be the passed method pointer.
* Used when calling a method on a Node subclass through a NodeRef subclass.
* Note that this will ignore default arg values and always require all arguments to be provided.
*
* \code
*
* // node subclass:
* struct ExampleNode: BaseNode {
* int doThing(int x);
* }
*
* // noderef subclass
* struct Example;
*
* TVM_REGISTER_API("Example_doThing")
* .set_body_method<Example>(&ExampleNode::doThing); // will have type int(Example, int)
*
* // note that just doing:
* // .set_body_method(&ExampleNode::doThing);
* // wouldn't work, because ExampleNode can't be taken from a TVMArgValue.
*
* \endcode
*
* \param f the method pointer to forward to.
* \tparam TNodeRef the node reference type to call the method on
* \tparam TNode the node type containing the method (inferred).
* \tparam R the return type of the function (inferred).
* \tparam Args the argument types of the function (inferred).
*/
template<typename TNodeRef, typename TNode, typename R, typename ...Args,
typename = typename std::enable_if<std::is_base_of<NodeRef, TNodeRef>::value>::type>
Registry& set_body_method(R (TNode::*f)(Args...) const) {
return set_body_typed<R(TNodeRef, Args...)>([f](TNodeRef ref, Args... params) {
const TNode* target = ref.operator->();
// call method pointer
return (target->*f)(params...);
});
}

/*!
* \brief Register a function with given name
* \param name The name of the function.
Expand Down
4 changes: 1 addition & 3 deletions nnvm/src/compiler/compile_engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -360,9 +360,7 @@ TVM_REGISTER_GLOBAL("nnvm.compiler.GraphKeyGetGraph")
});

TVM_REGISTER_GLOBAL("nnvm.compiler.MakeGraphKey")
.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) {
*rv = GraphKeyNode::make(args[0], args[1], args[2]);
});
.set_body_typed(GraphKeyNode::make);

// This can be used to extract workloads from nnvm compiler
TVM_REGISTER_GLOBAL("nnvm.compiler.CacheItem2ScheduleArgs")
Expand Down
4 changes: 1 addition & 3 deletions nnvm/src/compiler/graph_hash.cc
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,6 @@ std::string GraphDeepCompare(const Graph& a,
}

TVM_REGISTER_GLOBAL("nnvm.graph.DeepCompare")
.set_body([](tvm::runtime::TVMArgs args, tvm::runtime::TVMRetValue *rv) {
*rv = GraphDeepCompare(args[0], args[1], args[2]);
});
.set_body_typed(GraphDeepCompare);
} // namespace compiler
} // namespace nnvm
60 changes: 19 additions & 41 deletions src/api/api_arith.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,73 +31,51 @@ namespace tvm {
namespace arith {

TVM_REGISTER_API("arith.intset_single_point")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = IntSet::single_point(args[0]);
});
.set_body_typed(IntSet::single_point);

TVM_REGISTER_API("arith.intset_vector")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = IntSet::vector(args[0]);
});
.set_body_typed(IntSet::vector);

TVM_REGISTER_API("arith.intset_interval")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = IntSet::interval(args[0], args[1]);
});
.set_body_typed(IntSet::interval);

TVM_REGISTER_API("arith.DetectLinearEquation")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = DetectLinearEquation(args[0], args[1]);
});
.set_body_typed(DetectLinearEquation);

TVM_REGISTER_API("arith.DetectClipBound")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = DetectClipBound(args[0], args[1]);
});
.set_body_typed(DetectClipBound);

TVM_REGISTER_API("arith.DeduceBound")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = DeduceBound(args[0], args[1],
args[2].operator Map<Var, IntSet>(),
args[3].operator Map<Var, IntSet>());
});
.set_body_typed<IntSet(Expr, Expr, Map<Var, IntSet>, Map<Var, IntSet>)>([](
Expr v, Expr cond,
const Map<Var, IntSet> hint_map,
const Map<Var, IntSet> relax_map
) {
return DeduceBound(v, cond, hint_map, relax_map);
});


TVM_REGISTER_API("arith.DomainTouched")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = DomainTouched(args[0], args[1], args[2], args[3]);
});
.set_body_typed(DomainTouched);


TVM_REGISTER_API("_IntervalSetGetMin")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator IntSet().min();
});
.set_body_method(&IntSet::min);

TVM_REGISTER_API("_IntervalSetGetMax")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator IntSet().max();
});
.set_body_method(&IntSet::max);

TVM_REGISTER_API("_IntSetIsNothing")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator IntSet().is_nothing();
});
.set_body_method(&IntSet::is_nothing);

TVM_REGISTER_API("_IntSetIsEverything")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = args[0].operator IntSet().is_everything();
});
.set_body_method(&IntSet::is_everything);

TVM_REGISTER_API("arith._make_ConstIntBound")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ConstIntBoundNode::make(args[0], args[1]);
});
.set_body_typed(ConstIntBoundNode::make);

TVM_REGISTER_API("arith._make_ModularSet")
.set_body([](TVMArgs args, TVMRetValue* ret) {
*ret = ModularSetNode::make(args[0], args[1]);
});
.set_body_typed(ModularSetNode::make);

TVM_REGISTER_API("arith._CreateAnalyzer")
.set_body([](TVMArgs args, TVMRetValue* ret) {
Expand Down
5 changes: 2 additions & 3 deletions src/api/api_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,8 @@ TVM_REGISTER_API("_load_json")
.set_body_typed<NodeRef(std::string)>(LoadJSON<NodeRef>);

TVM_REGISTER_API("_TVMSetStream")
.set_body([](TVMArgs args, TVMRetValue *ret) {
TVMSetStream(args[0], args[1], args[2]);
});
.set_body_typed(TVMSetStream);

TVM_REGISTER_API("_save_param_dict")
.set_body([](TVMArgs args, TVMRetValue *rv) {
CHECK_EQ(args.size() % 2, 0u);
Expand Down
4 changes: 1 addition & 3 deletions src/api/api_codegen.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,6 @@ TVM_REGISTER_API("codegen._Build")
});

TVM_REGISTER_API("module._PackImportsToC")
.set_body([](TVMArgs args, TVMRetValue *ret) {
*ret = PackImportsToC(args[0], args[1]);
});
.set_body_typed(PackImportsToC);
} // namespace codegen
} // namespace tvm
Loading

0 comments on commit 5178506

Please sign in to comment.