From 37288a40c376a5c384e81d0309fa2432d9639724 Mon Sep 17 00:00:00 2001 From: "Edward Z. Yang" Date: Thu, 9 Jul 2020 08:51:37 -0700 Subject: [PATCH] Update torch_script_custom_classes to use TORCH_LIBRARY Signed-off-by: Edward Z. Yang --- .../torch_script_custom_classes.rst | 47 +++-- .../custom_class_project/class.cpp | 165 ++++++++---------- 2 files changed, 106 insertions(+), 106 deletions(-) diff --git a/advanced_source/torch_script_custom_classes.rst b/advanced_source/torch_script_custom_classes.rst index c5d34b13100..be13759cdc6 100644 --- a/advanced_source/torch_script_custom_classes.rst +++ b/advanced_source/torch_script_custom_classes.rst @@ -25,12 +25,14 @@ There are several things to note: with your custom class. - Notice that whenever we are working with instances of the custom class, we do it via instances of ``c10::intrusive_ptr<>``. Think of ``intrusive_ptr`` - as a smart pointer like ``std::shared_ptr``. The reason for using this smart pointer - is to ensure consistent lifetime management of the object instances between languages - (C++, Python and TorchScript). + as a smart pointer like ``std::shared_ptr``, but the reference count is stored + directly in the object, as opposed to a separate metadata block (as is done in + ``std::shared_ptr``. ``torch::Tensor`` internally uses the same pointer type; + and custom classes have to also use this pointer type so that we can + consistently manage different object types. - The second thing to notice is that the user-defined class must inherit from - ``torch::CustomClassHolder``. This ensures that everything is set up to handle - the lifetime management system previously mentioned. + ``torch::CustomClassHolder``. This ensures that the custom class has space to + store the reference count. Now let's take a look at how we will make this class visible to TorchScript, a process called *binding* the class: @@ -39,6 +41,9 @@ Now let's take a look at how we will make this class visible to TorchScript, a p :language: cpp :start-after: BEGIN binding :end-before: END binding + :append: + ; + } @@ -269,13 +274,13 @@ the special ``def_pickle`` method on ``class_``. `read more `_ about how we use these methods. -Here is an example of how we can update the registration code for our -``MyStackClass`` class to include serialization methods: +Here is an example of the ``def_pickle`` call we can add to the registration of +``MyStackClass`` to include serialization methods: .. literalinclude:: ../advanced_source/torch_script_custom_classes/custom_class_project/class.cpp :language: cpp - :start-after: BEGIN pickle_binding - :end-before: END pickle_binding + :start-after: BEGIN def_pickle + :end-before: END def_pickle .. note:: We take a different approach from pybind11 in the pickle API. Whereas pybind11 @@ -295,14 +300,22 @@ Defining Custom Operators that Take or Return Bound C++ Classes --------------------------------------------------------------- Once you've defined a custom C++ class, you can also use that class -as an argument or return from a custom operator (i.e. free functions). Here's an -example of how to do that: +as an argument or return from a custom operator (i.e. free functions). Suppose +you have the following free function: .. literalinclude:: ../advanced_source/torch_script_custom_classes/custom_class_project/class.cpp :language: cpp :start-after: BEGIN free_function :end-before: END free_function +You can register it running the following code inside your ``TORCH_LIBRARY`` +block: + +.. literalinclude:: ../advanced_source/torch_script_custom_classes/custom_class_project/class.cpp + :language: cpp + :start-after: BEGIN def_free + :end-before: END def_free + Refer to the `custom op tutorial `_ for more details on the registration API. @@ -321,12 +334,12 @@ Once this is done, you can use the op like the following example: .. note:: Registration of an operator that takes a C++ class as an argument requires that - the custom class has already been registered. This is fine if your op is - registered after your class in a single compilation unit, however, if your - class is registered in a separate compilation unit from the op you will need - to enforce that dependency. One way to do this is to wrap the class registration - in a `Meyer's singleton `_, which can be - called from the compilation unit that does the operator registration. + the custom class has already been registered. You can enforce this by + making sure the custom class registration and your free function definitions + are in the same ``TORCH_LIBRARY`` block, and that the custom class + registration comes first. In the future, we may relax this requirement, + so that these can be registered in any order. + Conclusion ---------- diff --git a/advanced_source/torch_script_custom_classes/custom_class_project/class.cpp b/advanced_source/torch_script_custom_classes/custom_class_project/class.cpp index dd3480bee75..c5849cef102 100644 --- a/advanced_source/torch_script_custom_classes/custom_class_project/class.cpp +++ b/advanced_source/torch_script_custom_classes/custom_class_project/class.cpp @@ -37,7 +37,12 @@ struct MyStackClass : torch::CustomClassHolder { }; // END class -#ifdef NO_PICKLE +// BEGIN free_function +c10::intrusive_ptr> manipulate_instance(const c10::intrusive_ptr>& instance) { + instance->pop(); + return instance; +} +// END free_function // BEGIN binding // Notice a few things: @@ -52,94 +57,76 @@ struct MyStackClass : torch::CustomClassHolder { // Python and C++ as `torch.classes.my_classes.MyStackClass`. We call // the first argument the "namespace" and the second argument the // actual class name. -static auto testStack = - torch::class_>("my_classes", "MyStackClass") - // The following line registers the contructor of our MyStackClass - // class that takes a single `std::vector` argument, - // i.e. it exposes the C++ method `MyStackClass(std::vector init)`. - // Currently, we do not support registering overloaded - // constructors, so for now you can only `def()` one instance of - // `torch::init`. - .def(torch::init>()) - // The next line registers a stateless (i.e. no captures) C++ lambda - // function as a method. Note that a lambda function must take a - // `c10::intrusive_ptr` (or some const/ref version of that) - // as the first argument. Other arguments can be whatever you want. - .def("top", [](const c10::intrusive_ptr>& self) { - return self->stack_.back(); - }) - // The following four lines expose methods of the MyStackClass - // class as-is. `torch::class_` will automatically examine the - // argument and return types of the passed-in method pointers and - // expose these to Python and TorchScript accordingly. Finally, notice - // that we must take the *address* of the fully-qualified method name, - // i.e. use the unary `&` operator, due to C++ typing rules. - .def("push", &MyStackClass::push) - .def("pop", &MyStackClass::pop) - .def("clone", &MyStackClass::clone) - .def("merge", &MyStackClass::merge); +TORCH_LIBRARY(my_classes, m) { + m.class_>("MyStackClass") + // The following line registers the contructor of our MyStackClass + // class that takes a single `std::vector` argument, + // i.e. it exposes the C++ method `MyStackClass(std::vector init)`. + // Currently, we do not support registering overloaded + // constructors, so for now you can only `def()` one instance of + // `torch::init`. + .def(torch::init>()) + // The next line registers a stateless (i.e. no captures) C++ lambda + // function as a method. Note that a lambda function must take a + // `c10::intrusive_ptr` (or some const/ref version of that) + // as the first argument. Other arguments can be whatever you want. + .def("top", [](const c10::intrusive_ptr>& self) { + return self->stack_.back(); + }) + // The following four lines expose methods of the MyStackClass + // class as-is. `torch::class_` will automatically examine the + // argument and return types of the passed-in method pointers and + // expose these to Python and TorchScript accordingly. Finally, notice + // that we must take the *address* of the fully-qualified method name, + // i.e. use the unary `&` operator, due to C++ typing rules. + .def("push", &MyStackClass::push) + .def("pop", &MyStackClass::pop) + .def("clone", &MyStackClass::clone) + .def("merge", &MyStackClass::merge) // END binding +#ifndef NO_PICKLE +// BEGIN def_pickle + // class_<>::def_pickle allows you to define the serialization + // and deserialization methods for your C++ class. + // Currently, we only support passing stateless lambda functions + // as arguments to def_pickle + .def_pickle( + // __getstate__ + // This function defines what data structure should be produced + // when we serialize an instance of this class. The function + // must take a single `self` argument, which is an intrusive_ptr + // to the instance of the object. The function can return + // any type that is supported as a return value of the TorchScript + // custom operator API. In this instance, we've chosen to return + // a std::vector as the salient data to preserve + // from the class. + [](const c10::intrusive_ptr>& self) + -> std::vector { + return self->stack_; + }, + // __setstate__ + // This function defines how to create a new instance of the C++ + // class when we are deserializing. The function must take a + // single argument of the same type as the return value of + // `__getstate__`. The function must return an intrusive_ptr + // to a new instance of the C++ class, initialized however + // you would like given the serialized state. + [](std::vector state) + -> c10::intrusive_ptr> { + // A convenient way to instantiate an object and get an + // intrusive_ptr to it is via `make_intrusive`. We use + // that here to allocate an instance of MyStackClass + // and call the single-argument std::vector + // constructor with the serialized state. + return c10::make_intrusive>(std::move(state)); + }); +// END def_pickle +#endif // NO_PICKLE -#else - -// BEGIN pickle_binding -static auto testStack = - torch::class_>("my_classes", "MyStackClass") - .def(torch::init>()) - .def("top", [](const c10::intrusive_ptr>& self) { - return self->stack_.back(); - }) - .def("push", &MyStackClass::push) - .def("pop", &MyStackClass::pop) - .def("clone", &MyStackClass::clone) - .def("merge", &MyStackClass::merge) - // class_<>::def_pickle allows you to define the serialization - // and deserialization methods for your C++ class. - // Currently, we only support passing stateless lambda functions - // as arguments to def_pickle - .def_pickle( - // __getstate__ - // This function defines what data structure should be produced - // when we serialize an instance of this class. The function - // must take a single `self` argument, which is an intrusive_ptr - // to the instance of the object. The function can return - // any type that is supported as a return value of the TorchScript - // custom operator API. In this instance, we've chosen to return - // a std::vector as the salient data to preserve - // from the class. - [](const c10::intrusive_ptr>& self) - -> std::vector { - return self->stack_; - }, - // __setstate__ - // This function defines how to create a new instance of the C++ - // class when we are deserializing. The function must take a - // single argument of the same type as the return value of - // `__getstate__`. The function must return an intrusive_ptr - // to a new instance of the C++ class, initialized however - // you would like given the serialized state. - [](std::vector state) - -> c10::intrusive_ptr> { - // A convenient way to instantiate an object and get an - // intrusive_ptr to it is via `make_intrusive`. We use - // that here to allocate an instance of MyStackClass - // and call the single-argument std::vector - // constructor with the serialized state. - return c10::make_intrusive>(std::move(state)); - }); -// END pickle_binding - -// BEGIN free_function -c10::intrusive_ptr> manipulate_instance(const c10::intrusive_ptr>& instance) { - instance->pop(); - return instance; +// BEGIN def_free + m.def( + "foo::manipulate_instance(__torch__.torch.classes.my_classes.MyStackClass x) -> __torch__.torch.classes.my_classes.MyStackClass Y", + manipulate_instance + ); +// END def_free } - -static auto instance_registry = torch::RegisterOperators().op( -torch::RegisterOperators::options() - .schema( - "foo::manipulate_instance(__torch__.torch.classes.my_classes.MyStackClass x) -> __torch__.torch.classes.my_classes.MyStackClass Y") - .catchAllKernel()); -// END free_function - -#endif