Update torch_script_custom_classes to use TORCH_LIBRARY #1062

Merged
merged 1 commit into from
Jul 14, 2020
47 changes: 30 additions & 17 deletions advanced_source/torch_script_custom_classes.rst
@@ -25,12 +25,14 @@ There are several things to note:
with your custom class.
- Notice that whenever we are working with instances of the custom
class, we do it via instances of ``c10::intrusive_ptr<>``. Think of ``intrusive_ptr``
as a smart pointer like ``std::shared_ptr``. The reason for using this smart pointer
is to ensure consistent lifetime management of the object instances between languages
(C++, Python and TorchScript).
as a smart pointer like ``std::shared_ptr``, but the reference count is stored
directly in the object, as opposed to a separate metadata block (as is done in
``std::shared_ptr``). ``torch::Tensor`` internally uses the same pointer type,
and custom classes must also use this pointer type so that we can
consistently manage different object types.
- The second thing to notice is that the user-defined class must inherit from
``torch::CustomClassHolder``. This ensures that everything is set up to handle
the lifetime management system previously mentioned.
``torch::CustomClassHolder``. This ensures that the custom class has space to
store the reference count.
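
To make the difference from ``std::shared_ptr`` concrete, here is a minimal, single-threaded sketch of an intrusive pointer. It is not the ``c10::intrusive_ptr`` implementation: ``RefCounted``, ``IntrusivePtr``, ``Counter`` and ``shared_count_demo`` are hypothetical names invented for this illustration, and ``RefCounted`` only plays the role that the text assigns to ``torch::CustomClassHolder`` (giving the object space for its own count).

```cpp
#include <cassert>
#include <cstddef>
#include <utility>

// Hypothetical base class: every managed object carries its own count.
struct RefCounted {
  std::size_t refcount = 0;
  virtual ~RefCounted() = default;
};

// Hypothetical, single-threaded intrusive pointer (NOT c10::intrusive_ptr).
// T is assumed to derive from RefCounted.
template <typename T>
class IntrusivePtr {
 public:
  IntrusivePtr() = default;
  explicit IntrusivePtr(T* p) : ptr_(p) { retain(); }
  IntrusivePtr(const IntrusivePtr& other) : ptr_(other.ptr_) { retain(); }
  IntrusivePtr(IntrusivePtr&& other) noexcept : ptr_(other.ptr_) {
    other.ptr_ = nullptr;
  }
  // Copy-and-swap: the by-value parameter holds the old pointer and
  // releases it when it goes out of scope.
  IntrusivePtr& operator=(IntrusivePtr other) noexcept {
    std::swap(ptr_, other.ptr_);
    return *this;
  }
  ~IntrusivePtr() { release(); }

  T* operator->() const { return ptr_; }
  T* get() const { return ptr_; }
  std::size_t use_count() const { return ptr_ ? ptr_->refcount : 0; }

 private:
  // The count lives inside the pointee, so no separate control block exists.
  void retain() {
    if (ptr_) ++ptr_->refcount;
  }
  void release() {
    if (ptr_ && --ptr_->refcount == 0) delete ptr_;
    ptr_ = nullptr;
  }
  T* ptr_ = nullptr;
};

// Hypothetical user class that opts in by inheriting from the base.
struct Counter : RefCounted {
  int value = 0;
};

// Returns the count observed while two handles share one object.
inline std::size_t shared_count_demo() {
  IntrusivePtr<Counter> a(new Counter());
  IntrusivePtr<Counter> b = a;  // the copy bumps the count stored in the object
  b->value = 7;
  assert(a->value == 7);        // both handles refer to the same object
  return a.use_count();
}
```

Because the count is a member of the object, a raw ``T*`` is enough to rebuild an owning handle, which is one reason a single pointer type can be shared across different kinds of objects. A production implementation would also need atomic counting, which this sketch omits.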

Now let's take a look at how we will make this class visible to TorchScript, a process called
*binding* the class:
@@ -39,6 +41,9 @@ Now let's take a look at how we will make this class visible to TorchScript, a p
:language: cpp
:start-after: BEGIN binding
:end-before: END binding
:append:
;
}



@@ -269,13 +274,13 @@ the special ``def_pickle`` method on ``class_``.
`read more <https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/docs/serialization.md#getstate-and-setstate>`_
about how we use these methods.

Here is an example of how we can update the registration code for our
``MyStackClass`` class to include serialization methods:
Here is an example of the ``def_pickle`` call we can add to the registration of
``MyStackClass`` to include serialization methods:

.. literalinclude:: ../advanced_source/torch_script_custom_classes/custom_class_project/class.cpp
:language: cpp
:start-after: BEGIN pickle_binding
:end-before: END pickle_binding
:start-after: BEGIN def_pickle
:end-before: END def_pickle
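
The contract between the two ``def_pickle`` lambdas can be checked without libtorch. The sketch below is an approximation under stated assumptions: ``StackSketch``, ``get_state``, ``set_state`` and ``round_trip`` are hypothetical names, and ``std::shared_ptr`` stands in for ``c10::intrusive_ptr`` purely so that the example builds with the standard library alone. It only shows that the return type of the ``__getstate__`` analogue is the argument type of the ``__setstate__`` analogue, and that deserializing produces a new, equal instance.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>
#include <vector>

// Hypothetical stand-in for MyStackClass<std::string>.
struct StackSketch {
  std::vector<std::string> stack_;
  explicit StackSketch(std::vector<std::string> init)
      : stack_(std::move(init)) {}
  void push(std::string x) { stack_.push_back(std::move(x)); }
};

// __getstate__ analogue: takes `self`, returns the data worth preserving.
inline std::vector<std::string> get_state(
    const std::shared_ptr<StackSketch>& self) {
  return self->stack_;
}

// __setstate__ analogue: takes exactly what get_state returned and
// builds a fresh instance from it.
inline std::shared_ptr<StackSketch> set_state(std::vector<std::string> state) {
  return std::make_shared<StackSketch>(std::move(state));
}

// Serialize then deserialize; the result is a distinct object with equal state.
inline std::shared_ptr<StackSketch> round_trip(
    const std::shared_ptr<StackSketch>& original) {
  return set_state(get_state(original));
}
```

If the two functions disagree on the state type, the round trip does not compile, which mirrors the requirement that ``__setstate__`` accept the type that ``__getstate__`` returns.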

.. note::
We take a different approach from pybind11 in the pickle API. Whereas pybind11
@@ -295,14 +300,22 @@ Defining Custom Operators that Take or Return Bound C++ Classes
---------------------------------------------------------------

Once you've defined a custom C++ class, you can also use that class
as an argument or return from a custom operator (i.e. free functions). Here's an
example of how to do that:
as an argument or return from a custom operator (i.e. free functions). Suppose
you have the following free function:

.. literalinclude:: ../advanced_source/torch_script_custom_classes/custom_class_project/class.cpp
:language: cpp
:start-after: BEGIN free_function
:end-before: END free_function

You can register it by running the following code inside your ``TORCH_LIBRARY``
block:

.. literalinclude:: ../advanced_source/torch_script_custom_classes/custom_class_project/class.cpp
:language: cpp
:start-after: BEGIN def_free
:end-before: END def_free

Refer to the `custom op tutorial <https://pytorch.org/tutorials/advanced/torch_script_custom_ops.html>`_
for more details on the registration API.

@@ -321,12 +334,12 @@ Once this is done, you can use the op like the following example:
.. note::

Registration of an operator that takes a C++ class as an argument requires that
the custom class has already been registered. This is fine if your op is
registered after your class in a single compilation unit. However, if your
class is registered in a separate compilation unit from the op, you will need
to enforce that dependency. One way to do this is to wrap the class registration
in a `Meyers singleton <https://stackoverflow.com/q/1661529>`_, which can be
called from the compilation unit that does the operator registration.
the custom class has already been registered. You can enforce this by
making sure the custom class registration and your free function definitions
are in the same ``TORCH_LIBRARY`` block, and that the custom class
registration comes first. In the future, we may relax this requirement,
so that these can be registered in any order.
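
The ordering requirement can be pictured with a toy registry. This is only a mental model, not how PyTorch resolves schemas internally: ``ToyLibrary``, ``register_class``, ``register_op`` and ``registers_ok`` are hypothetical names that do not exist in PyTorch. The assumption being illustrated is simply that an operator whose schema names a class can be accepted only if that class is already known.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical toy registry; none of these names exist in PyTorch.
class ToyLibrary {
 public:
  void register_class(const std::string& name) { classes_.insert(name); }

  // Accepts an op whose schema mentions `arg_class` only if that class
  // has already been registered.
  void register_op(const std::string& op_name, const std::string& arg_class) {
    if (classes_.count(arg_class) == 0) {
      throw std::runtime_error("unknown class in schema: " + arg_class);
    }
    ops_.push_back(op_name);
  }

  std::size_t op_count() const { return ops_.size(); }

 private:
  std::set<std::string> classes_;
  std::vector<std::string> ops_;
};

// Returns true if registering the op succeeds for the given ordering.
inline bool registers_ok(bool class_first) {
  ToyLibrary lib;
  try {
    if (class_first) lib.register_class("MyStackClass");
    lib.register_op("manipulate_instance", "MyStackClass");
    return true;
  } catch (const std::runtime_error&) {
    return false;
  }
}
```

In this model ``registers_ok(true)`` succeeds and ``registers_ok(false)`` fails, which corresponds to placing the class registration before the free function definition in the same block.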


Conclusion
----------
@@ -37,7 +37,12 @@ struct MyStackClass : torch::CustomClassHolder {
};
// END class

#ifdef NO_PICKLE
// BEGIN free_function
c10::intrusive_ptr<MyStackClass<std::string>> manipulate_instance(const c10::intrusive_ptr<MyStackClass<std::string>>& instance) {
instance->pop();
return instance;
}
// END free_function

// BEGIN binding
// Notice a few things:
@@ -52,94 +57,76 @@ struct MyStackClass : torch::CustomClassHolder {
// Python and C++ as `torch.classes.my_classes.MyStackClass`. We call
// the first argument the "namespace" and the second argument the
// actual class name.
static auto testStack =
torch::class_<MyStackClass<std::string>>("my_classes", "MyStackClass")
// The following line registers the constructor of our MyStackClass
// class that takes a single `std::vector<std::string>` argument,
// i.e. it exposes the C++ method `MyStackClass(std::vector<T> init)`.
// Currently, we do not support registering overloaded
// constructors, so for now you can only `def()` one instance of
// `torch::init`.
.def(torch::init<std::vector<std::string>>())
// The next line registers a stateless (i.e. no captures) C++ lambda
// function as a method. Note that a lambda function must take a
// `c10::intrusive_ptr<YourClass>` (or some const/ref version of that)
// as the first argument. Other arguments can be whatever you want.
.def("top", [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) {
return self->stack_.back();
})
// The following four lines expose methods of the MyStackClass<std::string>
// class as-is. `torch::class_` will automatically examine the
// argument and return types of the passed-in method pointers and
// expose these to Python and TorchScript accordingly. Finally, notice
// that we must take the *address* of the fully-qualified method name,
// i.e. use the unary `&` operator, due to C++ typing rules.
.def("push", &MyStackClass<std::string>::push)
.def("pop", &MyStackClass<std::string>::pop)
.def("clone", &MyStackClass<std::string>::clone)
.def("merge", &MyStackClass<std::string>::merge);
TORCH_LIBRARY(my_classes, m) {
m.class_<MyStackClass<std::string>>("MyStackClass")
// The following line registers the constructor of our MyStackClass
// class that takes a single `std::vector<std::string>` argument,
// i.e. it exposes the C++ method `MyStackClass(std::vector<T> init)`.
// Currently, we do not support registering overloaded
// constructors, so for now you can only `def()` one instance of
// `torch::init`.
.def(torch::init<std::vector<std::string>>())
// The next line registers a stateless (i.e. no captures) C++ lambda
// function as a method. Note that a lambda function must take a
// `c10::intrusive_ptr<YourClass>` (or some const/ref version of that)
// as the first argument. Other arguments can be whatever you want.
.def("top", [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) {
return self->stack_.back();
})
// The following four lines expose methods of the MyStackClass<std::string>
// class as-is. `torch::class_` will automatically examine the
// argument and return types of the passed-in method pointers and
// expose these to Python and TorchScript accordingly. Finally, notice
// that we must take the *address* of the fully-qualified method name,
// i.e. use the unary `&` operator, due to C++ typing rules.
.def("push", &MyStackClass<std::string>::push)
.def("pop", &MyStackClass<std::string>::pop)
.def("clone", &MyStackClass<std::string>::clone)
.def("merge", &MyStackClass<std::string>::merge)
// END binding
#ifndef NO_PICKLE
// BEGIN def_pickle
// class_<>::def_pickle allows you to define the serialization
// and deserialization methods for your C++ class.
// Currently, we only support passing stateless lambda functions
// as arguments to def_pickle
.def_pickle(
// __getstate__
// This function defines what data structure should be produced
// when we serialize an instance of this class. The function
// must take a single `self` argument, which is an intrusive_ptr
// to the instance of the object. The function can return
// any type that is supported as a return value of the TorchScript
// custom operator API. In this instance, we've chosen to return
// a std::vector<std::string> as the salient data to preserve
// from the class.
[](const c10::intrusive_ptr<MyStackClass<std::string>>& self)
-> std::vector<std::string> {
return self->stack_;
},
// __setstate__
// This function defines how to create a new instance of the C++
// class when we are deserializing. The function must take a
// single argument of the same type as the return value of
// `__getstate__`. The function must return an intrusive_ptr
// to a new instance of the C++ class, initialized however
// you would like given the serialized state.
[](std::vector<std::string> state)
-> c10::intrusive_ptr<MyStackClass<std::string>> {
// A convenient way to instantiate an object and get an
// intrusive_ptr to it is via `make_intrusive`. We use
// that here to allocate an instance of MyStackClass<std::string>
// and call the single-argument std::vector<std::string>
// constructor with the serialized state.
return c10::make_intrusive<MyStackClass<std::string>>(std::move(state));
});
// END def_pickle
#endif // NO_PICKLE

#else

// BEGIN pickle_binding
static auto testStack =
torch::class_<MyStackClass<std::string>>("my_classes", "MyStackClass")
.def(torch::init<std::vector<std::string>>())
.def("top", [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) {
return self->stack_.back();
})
.def("push", &MyStackClass<std::string>::push)
.def("pop", &MyStackClass<std::string>::pop)
.def("clone", &MyStackClass<std::string>::clone)
.def("merge", &MyStackClass<std::string>::merge)
// class_<>::def_pickle allows you to define the serialization
// and deserialization methods for your C++ class.
// Currently, we only support passing stateless lambda functions
// as arguments to def_pickle
.def_pickle(
// __getstate__
// This function defines what data structure should be produced
// when we serialize an instance of this class. The function
// must take a single `self` argument, which is an intrusive_ptr
// to the instance of the object. The function can return
// any type that is supported as a return value of the TorchScript
// custom operator API. In this instance, we've chosen to return
// a std::vector<std::string> as the salient data to preserve
// from the class.
[](const c10::intrusive_ptr<MyStackClass<std::string>>& self)
-> std::vector<std::string> {
return self->stack_;
},
// __setstate__
// This function defines how to create a new instance of the C++
// class when we are deserializing. The function must take a
// single argument of the same type as the return value of
// `__getstate__`. The function must return an intrusive_ptr
// to a new instance of the C++ class, initialized however
// you would like given the serialized state.
[](std::vector<std::string> state)
-> c10::intrusive_ptr<MyStackClass<std::string>> {
// A convenient way to instantiate an object and get an
// intrusive_ptr to it is via `make_intrusive`. We use
// that here to allocate an instance of MyStackClass<std::string>
// and call the single-argument std::vector<std::string>
// constructor with the serialized state.
return c10::make_intrusive<MyStackClass<std::string>>(std::move(state));
});
// END pickle_binding

// BEGIN free_function
c10::intrusive_ptr<MyStackClass<std::string>> manipulate_instance(const c10::intrusive_ptr<MyStackClass<std::string>>& instance) {
instance->pop();
return instance;
// BEGIN def_free
m.def(
"foo::manipulate_instance(__torch__.torch.classes.my_classes.MyStackClass x) -> __torch__.torch.classes.my_classes.MyStackClass Y",
manipulate_instance
);
// END def_free
}

static auto instance_registry = torch::RegisterOperators().op(
torch::RegisterOperators::options()
.schema(
"foo::manipulate_instance(__torch__.torch.classes.my_classes.MyStackClass x) -> __torch__.torch.classes.my_classes.MyStackClass Y")
.catchAllKernel<decltype(manipulate_instance), &manipulate_instance>());
// END free_function

#endif