@@ -92,28 +92,30 @@ vanilla Pytorch model::
92
92
93
93
Because the ``forward `` method of this module uses control flow that is
94
94
dependent on the input, it is not suitable for tracing. Instead, we can convert
95
- it to a ``ScriptModule `` by subclassing it from ``torch.jit.ScriptModule `` and
96
- adding a ``@torch.jit.script_method `` annotation to the model's ``forward ``
97
- method::
98
-
99
- import torch
100
-
101
- class MyModule(torch.jit.ScriptModule):
102
- def __init__(self, N, M):
103
- super(MyModule, self).__init__()
104
- self.weight = torch.nn.Parameter(torch.rand(N, M))
105
-
106
- @torch.jit.script_method
107
- def forward(self, input):
108
- if bool(input.sum() > 0):
109
- output = self.weight.mv(input)
110
- else:
111
- output = self.weight + input
112
- return output
113
-
114
- my_script_module = MyModule(2, 3)
115
-
116
- Creating a new ``MyModule `` object now directly produces an instance of
95
+ it to a ``ScriptModule ``.
96
+ In order to convert the module to the ``ScriptModule ``, one needs to
97
+ compile the module with ``torch.jit.script `` as follows::
98
+
99
+ class MyModule(torch.nn.Module):
100
+ def __init__(self, N, M):
101
+ super(MyModule, self).__init__()
102
+ self.weight = torch.nn.Parameter(torch.rand(N, M))
103
+
104
+ def forward(self, input):
105
+ if input.sum() > 0:
106
+ output = self.weight.mv(input)
107
+ else:
108
+ output = self.weight + input
109
+ return output
110
+
111
+ my_module = MyModule(10,20)
112
+ sm = torch.jit.script(my_module)
113
+
114
+ If you need to exclude some methods in your ``nn.Module ``
115
+ because they use Python features that TorchScript doesn't support yet,
116
+ you could annotate those with ``@torch.jit.ignore ``
117
+
118
+ ``my_module `` is an instance of
117
119
``ScriptModule `` that is ready for serialization.
118
120
119
121
Step 2: Serializing Your Script Module to a File
@@ -152,32 +154,38 @@ do:
152
154
153
155
.. code-block :: cpp
154
156
155
- #include <torch/script.h> // One-stop header.
156
-
157
- #include <iostream>
158
- #include <memory>
159
-
160
- int main(int argc, const char* argv[]) {
161
- if (argc != 2) {
162
- std::cerr << "usage: example-app <path-to-exported-script-module>\n";
163
- return -1;
157
+ #include <torch/script.h> // One-stop header.
158
+
159
+ #include <iostream>
160
+ #include <memory>
161
+
162
+ int main(int argc, const char* argv[]) {
163
+ if (argc != 2) {
164
+ std::cerr << "usage: example-app <path-to-exported-script-module>\n";
165
+ return -1;
166
+ }
167
+
168
+
169
+ torch::jit::script::Module module;
170
+ try {
171
+ // Deserialize the ScriptModule from a file using torch::jit::load().
172
+ module = torch::jit::load(argv[1]);
173
+ }
174
+ catch (const c10::Error& e) {
175
+ std::cerr << "error loading the model\n";
176
+ return -1;
177
+ }
178
+
179
+ std::cout << "ok\n";
164
180
}
165
181
166
- // Deserialize the ScriptModule from a file using torch::jit::load().
167
- std::shared_ptr<torch::jit::script::Module> module = torch::jit::load(argv[1]);
168
-
169
- assert(module != nullptr);
170
- std::cout << "ok\n";
171
- }
172
182
173
183
The ``<torch/script.h> `` header encompasses all relevant includes from the
174
184
LibTorch library necessary to run the example. Our application accepts the file
175
185
path to a serialized PyTorch ``ScriptModule `` as its only command line argument
176
186
and then proceeds to deserialize the module using the ``torch::jit::load() ``
177
- function, which takes this file path as input. In return we receive a shared
178
- pointer to a ``torch::jit::script::Module ``, the equivalent to a
179
- ``torch.jit.ScriptModule `` in C++. For now, we only verify that this pointer is
180
- not null. We will examine how to execute it in a moment.
187
+ function, which takes this file path as input. In return we receive a ``torch::jit::script::Module ``
188
+ object. We will examine how to execute it in a moment.
181
189
182
190
Depending on LibTorch and Building the Application
183
191
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -300,8 +308,7 @@ application's ``main()`` function:
300
308
inputs.push_back(torch::ones({1, 3, 224, 224}));
301
309
302
310
// Execute the model and turn its output into a tensor.
303
- at::Tensor output = module->forward(inputs).toTensor ();
304
-
311
+ at::Tensor output = module.forward(inputs).toTensor ();
305
312
std::cout << output.slice(/*dim=*/1, /*start=*/0, /*end=*/5) << '\n';
306
313
307
314
The first two lines set up the inputs to our model. We create a vector of
@@ -344,7 +351,7 @@ Looks like a good match!
344
351
345
352
.. tip::
346
353
347
- To move your model to GPU memory, you can write ` ` model-> to(at::kCUDA);` ` .
354
+ To move your model to GPU memory, you can write ` ` model. to(at::kCUDA);` ` .
348
355
Make sure the inputs to a model living in CUDA memory are also in CUDA memory
349
356
by calling ` ` tensor.to(at::kCUDA)` ` , which will return a new tensor in CUDA
350
357
memory.
0 commit comments