From bb824227eac09a433d0406a52f4d50987d0b2805 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Fri, 5 Aug 2022 00:51:01 -0400 Subject: [PATCH 1/6] Add traced_object_functor Co-authored-by: Junru Shao Co-authored-by: Greg Bonik --- include/tvm/script/printer/traced_object.h | 2 + .../script/printer/traced_object_functor.h | 168 +++++++++++++++++ src/script/printer/traced_object_functor.cc | 77 ++++++++ ...ript_printer_traced_object_functor_test.cc | 174 ++++++++++++++++++ 4 files changed, 421 insertions(+) create mode 100644 include/tvm/script/printer/traced_object_functor.h create mode 100644 src/script/printer/traced_object_functor.cc create mode 100644 tests/cpp/tvmscript_printer_traced_object_functor_test.cc diff --git a/include/tvm/script/printer/traced_object.h b/include/tvm/script/printer/traced_object.h index 6f04b66cec97..4c09b0a41b79 100644 --- a/include/tvm/script/printer/traced_object.h +++ b/include/tvm/script/printer/traced_object.h @@ -86,6 +86,8 @@ class TracedObject { using ObjectType = typename RefT::ContainerType; public: + using ObjectRefType = RefT; + // Don't use this direcly. For convenience, call MakeTraced() instead. explicit TracedObject(const RefT& object_ref, ObjectPath path) : ref_(object_ref), path_(std::move(path)) {} diff --git a/include/tvm/script/printer/traced_object_functor.h b/include/tvm/script/printer/traced_object_functor.h new file mode 100644 index 000000000000..d10a2d6e649d --- /dev/null +++ b/include/tvm/script/printer/traced_object_functor.h @@ -0,0 +1,168 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#ifndef TVM_SCRIPT_PRINTER_TRACED_OBJECT_FUNCTOR_H_ +#define TVM_SCRIPT_PRINTER_TRACED_OBJECT_FUNCTOR_H_ + +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +namespace tvm { +namespace script { +namespace printer { + +namespace { + +namespace detail { +/*! + * \brief Helper template class to extract the type of first argument of a function + * \tparam FType The function type. + */ +template +struct first_arg_type_helper; + +template +struct first_arg_type_helper { + using T = ArgOne; +}; + +/*! + * \brief Template alias for the type of first argument of a function + * \tparam FType The function type. + * + * The name of public functions are in snake case to be consistent with + * tvm/node/functor.h + */ +template +using first_arg_type = typename detail::first_arg_type_helper< + typename tvm::runtime::detail::function_signature::FType>::T; +} // namespace detail + +} // namespace + +namespace dispatch_table { +/* + * Functions in dispatch_table namespace is created to reduce the binary bloat + * from template and also hide implementation details from this header + */ + +using DispatchTable = std::unordered_map>; + +constexpr const char* kDefaultDispatchToken = ""; + +const runtime::PackedFunc& GetDispatchFunction(const DispatchTable& dispatch_table, + const String& token, uint32_t type_index); +void SetDispatchFunction(DispatchTable* dispatch_table, const String& token, uint32_t type_index, + runtime::PackedFunc f); +} // namespace dispatch_table + +/*! + * \brief Dynamic dispatch functor based on TracedObject. + * + * This functor dispatches based on the type of object ref inside the input TracedObject, + * and the input dispatch token. + */ +template +class TracedObjectFunctor { + private: + using TSelf = TracedObjectFunctor; + + template + using IsDispatchFunction = + typename std::is_convertible, Args...)>>; + + public: + /*! + * \brief Call the dispatch function. + * \param token The dispatch token. + * \param traced_object The traced object. + * \param args Other args. + * + * \return The return value of the dispatch function + * + * If the TObjectRef isn't registered with the token, it will try to find + * dispatch function for TObjectRef with kDefaultDispatchToken. + */ + template + R operator()(const String& token, TracedObject traced_object, Args... args) const { + const runtime::PackedFunc& dispatch_function = dispatch_table::GetDispatchFunction( + dispatch_table_, token, traced_object.Get()->type_index()); + return dispatch_function(traced_object.Get(), traced_object.GetPath(), + std::forward(args)...); + } + + /*! + * \brief Set the dispatch function + * \param f The dispatch function. + * + * This takes a type-erased packed function as input. It should be used + * through FFI boundary, for example, registering dispatch function from Python. + */ + TSelf& set_dispatch(String token, uint32_t type_index, runtime::PackedFunc f) { + dispatch_table::SetDispatchFunction(&dispatch_table_, token, type_index, std::move(f)); + return *this; + } + + /*! + * \brief Set the dispatch function + * \param token The dispatch token. + * \param f The dispatch function. + * + * The diaptch function should have signature `R(TracedObject, Args...)`. + */ + template ::ObjectRefType, + typename = std::enable_if_t::value>> + TSelf& set_dispatch(String token, TCallable f) { + return set_dispatch(token, // + TObjectRef::ContainerType::RuntimeTypeIndex(), // + runtime::TypedPackedFunc( + [f](TObjectRef object, ObjectPath path, Args... args) -> R { + return f(MakeTraced(object, path), std::forward(args)...); + })); + } + /*! + * \brief Set the default dispatch function + * \param f The dispatch function. + * + * Default dispatch function will be used if there is no function registered + * with the requested dispatch token. + * + * Default dispatch function has an empty string as dispatch token. + */ + template + TSelf& set_dispatch(TCallable f) { + return set_dispatch(dispatch_table::kDefaultDispatchToken, std::forward(f)); + } + + private: + dispatch_table::DispatchTable dispatch_table_; +}; + +} // namespace printer +} // namespace script +} // namespace tvm +#endif // TVM_SCRIPT_PRINTER_TRACED_OBJECT_FUNCTOR_H_ diff --git a/src/script/printer/traced_object_functor.cc b/src/script/printer/traced_object_functor.cc new file mode 100644 index 000000000000..64483602a83c --- /dev/null +++ b/src/script/printer/traced_object_functor.cc @@ -0,0 +1,77 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include + +namespace tvm { +namespace script { +namespace printer { +namespace dispatch_table { + +const runtime::PackedFunc* GetDispatchFunctionForToken(const DispatchTable& table, + const String& token, uint32_t type_index) { + auto it = table.find(token); + if (it == table.end()) { + return nullptr; + } + const std::vector& tab = it->second; + if (type_index >= tab.size()) { + return nullptr; + } + const PackedFunc* f = &tab[type_index]; + if (f->defined()) { + return f; + } else { + return nullptr; + } +} + +const runtime::PackedFunc& GetDispatchFunction(const DispatchTable& dispatch_table, + const String& token, uint32_t type_index) { + if (const runtime::PackedFunc* pf = + GetDispatchFunctionForToken(dispatch_table, token, type_index)) { + return *pf; + } else if (const runtime::PackedFunc* pf = + GetDispatchFunctionForToken(dispatch_table, kDefaultDispatchToken, type_index)) { + // Fallback to function with the default dispatch token + return *pf; + } else { + ICHECK(false) << "ObjectFunctor calls un-registered function on type: " + << runtime::Object::TypeIndex2Key(type_index) << " (token: " << token << ")"; + throw; + } +} + +void SetDispatchFunction(DispatchTable* dispatch_table, const String& token, uint32_t type_index, + runtime::PackedFunc f) { + std::vector* table = &(*dispatch_table)[token]; + if (table->size() <= type_index) { + table->resize(type_index + 1, nullptr); + } + runtime::PackedFunc& slot = (*table)[type_index]; + if (slot != nullptr) { + ICHECK(false) << "Dispatch for type is already registered: " + << runtime::Object::TypeIndex2Key(type_index); + } + slot = f; +} +} // namespace dispatch_table +} // namespace printer +} // namespace script +} // namespace tvm diff --git a/tests/cpp/tvmscript_printer_traced_object_functor_test.cc b/tests/cpp/tvmscript_printer_traced_object_functor_test.cc new file mode 100644 index 000000000000..73f9e40dfe10 --- /dev/null +++ b/tests/cpp/tvmscript_printer_traced_object_functor_test.cc @@ -0,0 +1,174 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#include +#include +#include +#include +#include + +#include "tvm/node/object_path.h" +#include "tvm/runtime/packed_func.h" +#include "tvm/script/printer/traced_object.h" + +using namespace tvm; +using namespace tvm::script::printer; + +namespace { + +class FooObjectNode : public Object { + public: + void VisitAttrs(AttrVisitor* v) {} + + static constexpr const char* _type_key = "test.FooObject"; + TVM_DECLARE_FINAL_OBJECT_INFO(FooObjectNode, Object); +}; + +class FooObject : public ObjectRef { + public: + FooObject() { this->data_ = make_object(); } + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(FooObject, ObjectRef, FooObjectNode); +}; + +TVM_REGISTER_NODE_TYPE(FooObjectNode); + +class BarObjectNode : public Object { + public: + void VisitAttrs(AttrVisitor* v) {} + + static constexpr const char* _type_key = "test.BarObject"; + TVM_DECLARE_FINAL_OBJECT_INFO(BarObjectNode, Object); +}; + +class BarObject : public ObjectRef { + public: + BarObject() { this->data_ = make_object(); } + TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(BarObject, ObjectRef, BarObjectNode); +}; + +TVM_REGISTER_NODE_TYPE(BarObjectNode); + +String ComputeFoo(TracedObject foo) { return "Foo"; } + +} // anonymous namespace + +TEST(TracedObjectFunctorTest, NormalRegistration) { + TracedObjectFunctor functor; + ObjectPath path = ObjectPath::Root(); + + functor.set_dispatch([](TracedObject o) -> String { return "Foo"; }); + functor.set_dispatch([](TracedObject o) -> String { return "Bar"; }); + + ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "Foo"); + ICHECK_EQ(functor("", MakeTraced(BarObject(), path)), "Bar"); +} + +TEST(TracedObjectFunctorTest, RegistrationWithFunction) { + TracedObjectFunctor functor; + ObjectPath path = ObjectPath::Root(); + + functor.set_dispatch([](TracedObject o) -> String { return "FooLambda"; }); + functor.set_dispatch("tir", ComputeFoo); + + ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "FooLambda"); + ICHECK_EQ(functor("tir", MakeTraced(FooObject(), path)), "Foo"); +} + +TEST(TracedObjectFunctorTest, RegistrationWithDispatchToken) { + TracedObjectFunctor functor; + ObjectPath path = ObjectPath::Root(); + + functor.set_dispatch([](TracedObject o) -> String { return "Foo"; }); + functor.set_dispatch("tir", [](TracedObject o) -> String { return "Foo tir"; }); + functor.set_dispatch("relax", [](TracedObject o) -> String { return "Foo relax"; }); + + ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "Foo"); + ICHECK_EQ(functor("tir", MakeTraced(FooObject(), path)), "Foo tir"); + ICHECK_EQ(functor("relax", MakeTraced(FooObject(), path)), "Foo relax"); + ICHECK_EQ(functor("xyz", MakeTraced(FooObject(), path)), "Foo"); +} + +TEST(TracedObjectFunctorTest, RegistrationWithPackedFunc) { + TracedObjectFunctor functor; + ObjectPath path = ObjectPath::Root(); + + auto f_default = [](runtime::TVMArgs, runtime::TVMRetValue* ret) { *ret = String("default"); }; + auto f_tir = [](runtime::TVMArgs, runtime::TVMRetValue* ret) { *ret = String("tir"); }; + + functor.set_dispatch("", FooObjectNode::RuntimeTypeIndex(), runtime::PackedFunc(f_default)); + functor.set_dispatch("tir", FooObjectNode::RuntimeTypeIndex(), runtime::PackedFunc(f_tir)); + + ICHECK_EQ(functor("", MakeTraced(FooObject(), path)), "default"); + ICHECK_EQ(functor("tir", MakeTraced(FooObject(), path)), "tir"); +} + +TEST(TracedObjectFunctorTest, ExtraArg) { + TracedObjectFunctor functor; + ObjectPath path = ObjectPath::Root(); + + functor.set_dispatch([](TracedObject o, int x) { return x; }); + functor.set_dispatch([](TracedObject o, int x) { return x + 1; }); + + ICHECK_EQ(functor("", MakeTraced(FooObject(), path), 2), 2); + ICHECK_EQ(functor("", MakeTraced(BarObject(), path), 2), 3); + ICHECK_EQ(functor("tir", MakeTraced(BarObject(), path), 2), 3); +} + +TEST(TracedObjectFunctorTest, CallWithUnregisteredType) { + TracedObjectFunctor functor; + ObjectPath path = ObjectPath::Root(); + + bool failed = false; + try { + ICHECK_EQ(functor("", MakeTraced(FooObject(), path), 2), 2); + } catch (...) { + failed = true; + } + ASSERT_EQ(failed, true); +} + +TEST(TracedObjectFunctorTest, DuplicateRegistration_WithoutToken) { + TracedObjectFunctor functor; + ObjectPath path = ObjectPath::Root(); + + functor.set_dispatch([](TracedObject o, int x) { return x; }); + + bool failed = false; + try { + functor.set_dispatch([](TracedObject o, int x) { return x; }); + } catch (...) { + failed = true; + } + ASSERT_EQ(failed, true); +} + +TEST(TracedObjectFunctorTest, DuplicateRegistration_WithToken) { + TracedObjectFunctor functor; + ObjectPath path = ObjectPath::Root(); + + functor.set_dispatch("tir", [](TracedObject o, int x) { return x; }); + + bool failed = false; + try { + functor.set_dispatch("tir", [](TracedObject o, int x) { return x; }); + } catch (...) { + failed = true; + } + ASSERT_EQ(failed, true); +} From 33b45493a70844d74f3a20f9c807898ac2f71147 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Fri, 5 Aug 2022 09:38:28 -0400 Subject: [PATCH 2/6] Add missing doc --- include/tvm/script/printer/traced_object_functor.h | 2 ++ 1 file changed, 2 insertions(+) diff --git a/include/tvm/script/printer/traced_object_functor.h b/include/tvm/script/printer/traced_object_functor.h index d10a2d6e649d..3da075b76f7e 100644 --- a/include/tvm/script/printer/traced_object_functor.h +++ b/include/tvm/script/printer/traced_object_functor.h @@ -116,6 +116,8 @@ class TracedObjectFunctor { /*! * \brief Set the dispatch function + * \param token The dispatch token. + * \param type_index The TVM object type index for this dispatch function. * \param f The dispatch function. * * This takes a type-erased packed function as input. It should be used From 6285e2236dd41c51992c119e7e6f01cb88c83965 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Mon, 8 Aug 2022 10:01:24 -0400 Subject: [PATCH 3/6] Fix include style --- tests/cpp/tvmscript_printer_traced_object_functor_test.cc | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tests/cpp/tvmscript_printer_traced_object_functor_test.cc b/tests/cpp/tvmscript_printer_traced_object_functor_test.cc index 73f9e40dfe10..b78abcdd64c3 100644 --- a/tests/cpp/tvmscript_printer_traced_object_functor_test.cc +++ b/tests/cpp/tvmscript_printer_traced_object_functor_test.cc @@ -19,14 +19,13 @@ #include #include +#include #include #include +#include +#include #include -#include "tvm/node/object_path.h" -#include "tvm/runtime/packed_func.h" -#include "tvm/script/printer/traced_object.h" - using namespace tvm; using namespace tvm::script::printer; From bb7e1328c2cb4a7ff22e3e18f3a8057c3bfe37ff Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Mon, 8 Aug 2022 10:08:13 -0400 Subject: [PATCH 4/6] Remove unused headers from test --- tests/cpp/tvmscript_printer_traced_object_functor_test.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/tests/cpp/tvmscript_printer_traced_object_functor_test.cc b/tests/cpp/tvmscript_printer_traced_object_functor_test.cc index b78abcdd64c3..3fd52d44aa8c 100644 --- a/tests/cpp/tvmscript_printer_traced_object_functor_test.cc +++ b/tests/cpp/tvmscript_printer_traced_object_functor_test.cc @@ -20,8 +20,6 @@ #include #include #include -#include -#include #include #include #include From 34ab326acba94dbf61eecc5d3898cfea67c37eb9 Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Mon, 8 Aug 2022 14:14:52 -0400 Subject: [PATCH 5/6] Improve code structure --- .../script/printer/traced_object_functor.h | 43 ++++++++++++------- src/script/printer/traced_object_functor.cc | 2 - 2 files changed, 28 insertions(+), 17 deletions(-) diff --git a/include/tvm/script/printer/traced_object_functor.h b/include/tvm/script/printer/traced_object_functor.h index 3da075b76f7e..e1ec287072e5 100644 --- a/include/tvm/script/printer/traced_object_functor.h +++ b/include/tvm/script/printer/traced_object_functor.h @@ -42,10 +42,10 @@ namespace detail { * \tparam FType The function type. */ template -struct first_arg_type_helper; +struct FirstArgTypeGetter; template -struct first_arg_type_helper { +struct FirstArgTypeGetter { using T = ArgOne; }; @@ -57,27 +57,40 @@ struct first_arg_type_helper { * tvm/node/functor.h */ template -using first_arg_type = typename detail::first_arg_type_helper< +using FirstArgType = typename detail::FirstArgTypeGetter< typename tvm::runtime::detail::function_signature::FType>::T; } // namespace detail } // namespace -namespace dispatch_table { /* - * Functions in dispatch_table namespace is created to reduce the binary bloat + * This type alias and the following free functions are created to reduce the binary bloat * from template and also hide implementation details from this header */ - using DispatchTable = std::unordered_map>; -constexpr const char* kDefaultDispatchToken = ""; - +/*! + * \brief Get function from dispatch table. + * \param dispatch_table The dispatch table. + * \param token The dispatch token. + * \param type_index The type index of the Object type to be dispatched. + * + * \return The dispatch function. + */ const runtime::PackedFunc& GetDispatchFunction(const DispatchTable& dispatch_table, const String& token, uint32_t type_index); + +/*! + * \brief Set function in dispatch table. + * \param dispatch_table The dispatch table. + * \param token The dispatch token. + * \param type_index The type index of the Object type to be dispatched. + * \param f The dispatch function. + */ void SetDispatchFunction(DispatchTable* dispatch_table, const String& token, uint32_t type_index, runtime::PackedFunc f); -} // namespace dispatch_table + +constexpr const char* kDefaultDispatchToken = ""; /*! * \brief Dynamic dispatch functor based on TracedObject. @@ -108,8 +121,8 @@ class TracedObjectFunctor { */ template R operator()(const String& token, TracedObject traced_object, Args... args) const { - const runtime::PackedFunc& dispatch_function = dispatch_table::GetDispatchFunction( - dispatch_table_, token, traced_object.Get()->type_index()); + const runtime::PackedFunc& dispatch_function = + GetDispatchFunction(dispatch_table_, token, traced_object.Get()->type_index()); return dispatch_function(traced_object.Get(), traced_object.GetPath(), std::forward(args)...); } @@ -124,7 +137,7 @@ class TracedObjectFunctor { * through FFI boundary, for example, registering dispatch function from Python. */ TSelf& set_dispatch(String token, uint32_t type_index, runtime::PackedFunc f) { - dispatch_table::SetDispatchFunction(&dispatch_table_, token, type_index, std::move(f)); + SetDispatchFunction(&dispatch_table_, token, type_index, std::move(f)); return *this; } @@ -136,7 +149,7 @@ class TracedObjectFunctor { * The diaptch function should have signature `R(TracedObject, Args...)`. */ template ::ObjectRefType, + typename TObjectRef = typename detail::FirstArgType::ObjectRefType, typename = std::enable_if_t::value>> TSelf& set_dispatch(String token, TCallable f) { return set_dispatch(token, // @@ -157,11 +170,11 @@ class TracedObjectFunctor { */ template TSelf& set_dispatch(TCallable f) { - return set_dispatch(dispatch_table::kDefaultDispatchToken, std::forward(f)); + return set_dispatch(kDefaultDispatchToken, std::forward(f)); } private: - dispatch_table::DispatchTable dispatch_table_; + DispatchTable dispatch_table_; }; } // namespace printer diff --git a/src/script/printer/traced_object_functor.cc b/src/script/printer/traced_object_functor.cc index 64483602a83c..a018099a1de0 100644 --- a/src/script/printer/traced_object_functor.cc +++ b/src/script/printer/traced_object_functor.cc @@ -22,7 +22,6 @@ namespace tvm { namespace script { namespace printer { -namespace dispatch_table { const runtime::PackedFunc* GetDispatchFunctionForToken(const DispatchTable& table, const String& token, uint32_t type_index) { @@ -71,7 +70,6 @@ void SetDispatchFunction(DispatchTable* dispatch_table, const String& token, uin } slot = f; } -} // namespace dispatch_table } // namespace printer } // namespace script } // namespace tvm From 0bdd60ab1d6e4d9a6da9e1f3e3c044fe80e378ab Mon Sep 17 00:00:00 2001 From: Lite Ye Date: Mon, 8 Aug 2022 14:50:31 -0400 Subject: [PATCH 6/6] Correct the use of forward --- .../tvm/script/printer/traced_object_functor.h | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/include/tvm/script/printer/traced_object_functor.h b/include/tvm/script/printer/traced_object_functor.h index e1ec287072e5..05fbbf79f2ee 100644 --- a/include/tvm/script/printer/traced_object_functor.h +++ b/include/tvm/script/printer/traced_object_functor.h @@ -123,8 +123,7 @@ class TracedObjectFunctor { R operator()(const String& token, TracedObject traced_object, Args... args) const { const runtime::PackedFunc& dispatch_function = GetDispatchFunction(dispatch_table_, token, traced_object.Get()->type_index()); - return dispatch_function(traced_object.Get(), traced_object.GetPath(), - std::forward(args)...); + return dispatch_function(traced_object.Get(), traced_object.GetPath(), args...); } /*! @@ -152,12 +151,13 @@ class TracedObjectFunctor { typename TObjectRef = typename detail::FirstArgType::ObjectRefType, typename = std::enable_if_t::value>> TSelf& set_dispatch(String token, TCallable f) { - return set_dispatch(token, // - TObjectRef::ContainerType::RuntimeTypeIndex(), // - runtime::TypedPackedFunc( - [f](TObjectRef object, ObjectPath path, Args... args) -> R { - return f(MakeTraced(object, path), std::forward(args)...); - })); + return set_dispatch( + token, // + TObjectRef::ContainerType::RuntimeTypeIndex(), // + runtime::TypedPackedFunc( + [f = std::move(f)](TObjectRef object, ObjectPath path, Args... args) -> R { + return f(MakeTraced(object, path), args...); + })); } /*! * \brief Set the default dispatch function @@ -169,7 +169,7 @@ class TracedObjectFunctor { * Default dispatch function has an empty string as dispatch token. */ template - TSelf& set_dispatch(TCallable f) { + TSelf& set_dispatch(TCallable&& f) { return set_dispatch(kDefaultDispatchToken, std::forward(f)); }