@@ -92,21 +92,21 @@ class definitions.
9292~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
9393
9494We can implement part one as a pure python function as below. Notice, to
95- trace this function we add the ``@ torch.jit.trace`` decorator. Since the
96- trace requires a dummy input of the expected runtime type and shape, we
97- also include the ``torch.rand`` to generate a single valued torch
98- tensor.
95+ trace this function we call ``torch.jit.trace`` and pass in the function
96+ to be traced. Since the trace requires a dummy input of the expected
97+ runtime type and shape, we also include the ``torch.rand`` to generate a
98+ single valued torch tensor.
9999
100100"""
101101
102102import torch
103103
104- # This is how you define a traced function
105- # Pass in an example input to this decorator and then apply it to the function
106- @torch.jit.trace(torch.rand(()))
107- def traced_fn(x):
104+ def fn(x):
108105 return torch.abs(2 * x)
109106
107+ # This is how you define a traced function
108+ # Pass in both the function to be traced and an example input to ``torch.jit.trace``
109+ traced_fn = torch.jit.trace(fn, torch.rand(()))
110110
111111######################################################################
112112# Part 2 - Scripting a pure python function
@@ -124,7 +124,7 @@ def traced_fn(x):
124124@torch.jit.script
125125def script_fn(x):
126126 z = torch.ones([1], dtype=torch.int64)
127- for i in range(x):
127+ for i in range(int(x)):
128128 z = z * (i + 1)
129129 return z
130130
@@ -163,7 +163,7 @@ class ScriptModule(torch.jit.ScriptModule):
163163 @torch.jit.script_method
164164 def forward(self, x):
165165 r = -x
166- if torch.fmod(x, 2.0) == 0.0:
166+ if int(torch.fmod(x, 2.0)) == 0.0:
167167 r = x / 2.0
168168 return r
169169
@@ -201,7 +201,7 @@ def __init__(self):
201201 # Modules must be attributes on the Module because if you want to trace
202202 # or script this Module, we must be able to inherit the submodules'
203203 # params.
204- self.traced_module = torch.jit.trace(torch.rand(()))(TracedModule())
204+ self.traced_module = torch.jit.trace(TracedModule(), torch.rand(()))
205205 self .script_module = ScriptModule ()
206206
207207 print('traced_fn graph', traced_fn.graph)
@@ -244,8 +244,6 @@ def forward(self, x):
244244# Tracing the Top-Level Model
245245# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
246246#
247- # **NOTE:** Open issue https://github.com/pytorch/pytorch/issues/8755
248- #
249247# The last part of the example is to trace the top-level module, ``Net``.
250248# As mentioned previously, since the traced/scripted modules are
251249# attributes of Net, we are able to trace ``Net`` as it inherits the
@@ -254,11 +252,9 @@ def forward(self, x):
254252# Also, check out the graph that is created.
255253#
256254
257- # TODO: this fails with some weird bug https://github.com/pytorch/pytorch/issues/8755
258- #n_traced = torch.jit.trace(torch.tensor([5]))(n)
259- #print(n_traced(torch.tensor([5])))
260-
261- # TODO: print the graph of the traced module
255+ n_traced = torch.jit.trace(n, torch.tensor([5]))
256+ print(n_traced(torch.tensor([5])))
257+ print('n_traced graph', n_traced.graph)
262258
263259
264260######################################################################
0 commit comments