# -*- coding: utf-8 -*-
"""
- Deploying a Seq2Seq Model with the Hybrid Frontend
+ Deploying a Seq2Seq Model with TorchScript
==================================================
**Author:** `Matthew Inkawhich <https://github.com/MatthewInkawhich>`_
"""

######################################################################
# This tutorial will walk through the process of transitioning a
- # sequence-to-sequence model to Torch Script using PyTorch’s Hybrid
- # Frontend. The model that we will convert is the chatbot model from the
- # `Chatbot tutorial <https://pytorch.org/tutorials/beginner/chatbot_tutorial.html>`__.
+ # sequence-to-sequence model to TorchScript using the TorchScript
+ # API. The model that we will convert is the chatbot model from the
+ # `Chatbot tutorial <https://pytorch.org/tutorials/beginner/chatbot_tutorial.html>`__.
# You can either treat this tutorial as a “Part 2” to the Chatbot tutorial
# and deploy your own pretrained model, or you can start with this
# document and use a pretrained model that we host. In the latter case,
# you can reference the original Chatbot tutorial for details
# regarding data preprocessing, model theory and definition, and model
# training.
#
- # What is the Hybrid Frontend?
+ # What is TorchScript?
# ----------------------------
#
# During the research and development phase of a deep learning-based
# to target highly optimized hardware architectures. Also, a graph-based
# representation enables framework-agnostic model exportation. PyTorch
# provides mechanisms for incrementally converting eager-mode code into
- # Torch Script, a statically analyzable and optimizable subset of Python
+ # TorchScript, a statically analyzable and optimizable subset of Python
# that Torch uses to represent deep learning programs independently from
# the Python runtime.
#
- # The API for converting eager-mode PyTorch programs into Torch Script is
+ # The API for converting eager-mode PyTorch programs into TorchScript is
# found in the ``torch.jit`` module. This module has two core modalities for
- # converting an eager-mode model to a Torch Script graph representation:
+ # converting an eager-mode model to a TorchScript graph representation:
# **tracing** and **scripting**. The ``torch.jit.trace`` function takes a
# module or function and a set of example inputs. It then runs the example
# input through the function or module while tracing the computational
# operations called along the execution route taken by the example input
# will be recorded. In other words, the control flow itself is not
# captured. To convert modules and functions containing data-dependent
- # control flow, a **scripting** mechanism is provided. Scripting
- # explicitly converts the module or function code to Torch Script,
- # including all possible control flow routes. To use script mode, be sure
- # to inherit from the ``torch.jit.ScriptModule`` base class (instead
- # of ``torch.nn.Module``) and add a ``torch.jit.script`` decorator to your
- # Python function or a ``torch.jit.script_method`` decorator to your
- # module’s methods. The one caveat with using scripting is that it only
- # supports a restricted subset of Python. For all details relating to the
- # supported features, see the Torch Script `language
- # reference <https://pytorch.org/docs/master/jit.html>`__. To provide the
- # maximum flexibility, the modes of Torch Script can be composed to
- # represent your whole program, and these techniques can be applied
- # incrementally.
+ # control flow, a **scripting** mechanism is provided. The
+ # ``torch.jit.script`` function/decorator takes a module or function and
+ # does not require example inputs. Scripting then explicitly converts
+ # the module or function code to TorchScript, including all control flow paths.
+ # One caveat with using scripting is that it only supports a subset of
+ # Python, so you might need to rewrite the code to make it compatible
+ # with the TorchScript syntax.
+ #
+ # For all details relating to the supported features, see the `TorchScript
+ # language reference <https://pytorch.org/docs/master/jit.html>`__.
+ # To provide the maximum flexibility, you can also mix tracing and scripting
+ # modes together to represent your whole program, and these techniques can
+ # be applied incrementally.
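#
# As a minimal, self-contained illustration of the difference (a toy
# function of our own, not part of the chatbot model), compare tracing and
# scripting on a function whose control flow depends on its input values::
#
#    import torch
#
#    def clip_sign(x):
#        # data-dependent control flow: which branch runs depends on x's values
#        if bool(x.sum() > 0):
#            return x
#        return -x
#
#    traced = torch.jit.trace(clip_sign, torch.ones(3))  # records only the branch taken
#    scripted = torch.jit.script(clip_sign)              # compiles both branches
#
#    print(traced(-torch.ones(3)))    # replays the traced branch: tensor([-1., -1., -1.])
#    print(scripted(-torch.ones(3)))  # re-evaluates the if: tensor([1., 1., 1.])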
#
# .. figure:: /_static/img/chatbot/pytorch_workflow.png
#    :align: center
@@ -273,7 +273,7 @@ def indexesFromSentence(voc, sentence):
# used by the ``torch.nn.utils.rnn.pack_padded_sequence`` function when
# padding.
#
- # Hybrid Frontend Notes:
+ # TorchScript Notes:
# ~~~~~~~~~~~~~~~~~~~~~~
#
# Since the encoder’s ``forward`` function does not contain any
@@ -296,6 +296,7 @@ def __init__(self, hidden_size, embedding, n_layers=1, dropout=0):
                          dropout=(0 if n_layers == 1 else dropout), bidirectional=True)

    def forward(self, input_seq, input_lengths, hidden=None):
+       # type: (Tensor, Tensor, Optional[Tensor]) -> Tuple[Tensor, Tensor]
        # Convert word indexes to embeddings
        embedded = self.embedding(input_seq)
        # Pack padded batch of sequences for RNN module
@@ -325,18 +326,18 @@ def forward(self, input_seq, input_lengths, hidden=None):
#

# Luong attention layer
- class Attn(torch.nn.Module):
+ class Attn(nn.Module):
    def __init__(self, method, hidden_size):
        super(Attn, self).__init__()
        self.method = method
        if self.method not in ['dot', 'general', 'concat']:
            raise ValueError(self.method, "is not an appropriate attention method.")
        self.hidden_size = hidden_size
        if self.method == 'general':
-             self.attn = torch.nn.Linear(self.hidden_size, hidden_size)
+             self.attn = nn.Linear(self.hidden_size, hidden_size)
        elif self.method == 'concat':
-             self.attn = torch.nn.Linear(self.hidden_size * 2, hidden_size)
-             self.v = torch.nn.Parameter(torch.FloatTensor(hidden_size))
+             self.attn = nn.Linear(self.hidden_size * 2, hidden_size)
+             self.v = nn.Parameter(torch.FloatTensor(hidden_size))

    def dot_score(self, hidden, encoder_output):
        return torch.sum(hidden * encoder_output, dim=2)
@@ -383,14 +384,14 @@ def forward(self, hidden, encoder_outputs):
# weighted sum indicating what parts of the encoder’s output to pay
# attention to. From here, we use a linear layer and softmax normalization
# to select the next word in the output sequence.
- #
- # Hybrid Frontend Notes:
+ #
+ # TorchScript Notes:
# ~~~~~~~~~~~~~~~~~~~~~~
#
# Similarly to the ``EncoderRNN``, this module does not contain any
# data-dependent control flow. Therefore, we can once again use
- # **tracing** to convert this model to Torch Script after it is
- # initialized and its parameters are loaded.
+ # **tracing** to convert this model to TorchScript after it
+ # is initialized and its parameters are loaded.
#

class LuongAttnDecoderRNN(nn.Module):
@@ -465,18 +466,18 @@ def forward(self, input_step, last_hidden, encoder_outputs):
# terminates either if the ``decoded_words`` list has reached a length of
# *MAX_LENGTH* or if the predicted word is the *EOS_token*.
#
- # Hybrid Frontend Notes:
+ # TorchScript Notes:
# ~~~~~~~~~~~~~~~~~~~~~~
#
# The ``forward`` method of this module involves iterating over the range
# of :math:`[0, max\_length)` when decoding an output sequence one word at
# a time. Because of this, we should use **scripting** to convert this
- # module to Torch Script. Unlike with our encoder and decoder models,
+ # module to TorchScript. Unlike with our encoder and decoder models,
# which we can trace, we must make some necessary changes to the
# ``GreedySearchDecoder`` module in order to initialize an object without
# error. In other words, we must ensure that our module adheres to the
- # rules of the scripting mechanism, and does not utilize any language
- # features outside of the subset of Python that Torch Script includes.
+ # rules of the TorchScript mechanism, and does not utilize any language
+ # features outside of the subset of Python that TorchScript includes.
#
# To get an idea of some manipulations that may be required, we will go
# over the diffs between the ``GreedySearchDecoder`` implementation from
@@ -491,12 +492,6 @@ def forward(self, input_step, last_hidden, encoder_outputs):
# Changes:
# ^^^^^^^^
#
- # -  ``nn.Module`` -> ``torch.jit.ScriptModule``
- #
- #    -  In order to use PyTorch’s scripting mechanism on a module, that
- #       module must inherit from the ``torch.jit.ScriptModule``.
- #
- #
# -  Added ``decoder_n_layers`` to the constructor arguments
#
#    -  This change stems from the fact that the encoder and decoder
@@ -523,16 +518,9 @@ def forward(self, input_step, last_hidden, encoder_outputs):
#       ``self._SOS_token``.
#
#
- # -  Add the ``torch.jit.script_method`` decorator to the ``forward``
- #    method
- #
- #    -  Adding this decorator lets the JIT compiler know that the function
- #       that it is decorating should be scripted.
- #
- #
# -  Enforce types of ``forward`` method arguments
#
- #    -  By default, all parameters to a Torch Script function are assumed
+ #    -  By default, all parameters to a TorchScript function are assumed
#       to be Tensor. If we need to pass an argument of a different type,
#       we can use function type annotations as introduced in `PEP
#       3107 <https://www.python.org/dev/peps/pep-3107/>`__. In addition,
@@ -553,7 +541,7 @@ def forward(self, input_step, last_hidden, encoder_outputs):
#       ``self._SOS_token``.
#

- class GreedySearchDecoder(torch.jit.ScriptModule):
+ class GreedySearchDecoder(nn.Module):
    def __init__(self, encoder, decoder, decoder_n_layers):
        super(GreedySearchDecoder, self).__init__()
        self.encoder = encoder
@@ -564,7 +552,6 @@ def __init__(self, encoder, decoder, decoder_n_layers):

    __constants__ = ['_device', '_SOS_token', '_decoder_n_layers']

-    @torch.jit.script_method
    def forward(self, input_seq: torch.Tensor, input_length: torch.Tensor, max_length: int):
        # Forward input through encoder model
        encoder_outputs, encoder_hidden = self.encoder(input_seq, input_length)
@@ -613,7 +600,7 @@ def forward(self, input_seq : torch.Tensor, input_length : torch.Tensor, max_len
# an argument, normalizes it, evaluates it, and prints the response.
#

- def evaluate(encoder, decoder, searcher, voc, sentence, max_length=MAX_LENGTH):
+ def evaluate(searcher, voc, sentence, max_length=MAX_LENGTH):
    ### Format input sentence as a batch
    # words -> indexes
    indexes_batch = [indexesFromSentence(voc, sentence)]
@@ -632,7 +619,7 @@ def evaluate(encoder, decoder, searcher, voc, sentence, max_length=MAX_LENGTH):


# Evaluate inputs from user input (stdin)
- def evaluateInput(encoder, decoder, searcher, voc):
+ def evaluateInput(searcher, voc):
    input_sentence = ''
    while(1):
        try:
@@ -643,7 +630,7 @@ def evaluateInput(encoder, decoder, searcher, voc):
            # Normalize sentence
            input_sentence = normalizeString(input_sentence)
            # Evaluate sentence
-             output_words = evaluate(encoder, decoder, searcher, voc, input_sentence)
+             output_words = evaluate(searcher, voc, input_sentence)
            # Format and print response sentence
            output_words[:] = [x for x in output_words if not (x == 'EOS' or x == 'PAD')]
            print('Bot:', ' '.join(output_words))
@@ -652,12 +639,12 @@ def evaluateInput(encoder, decoder, searcher, voc):
            print("Error: Encountered unknown word.")

# Normalize input sentence and call evaluate()
- def evaluateExample(sentence, encoder, decoder, searcher, voc):
+ def evaluateExample(sentence, searcher, voc):
    print("> " + sentence)
    # Normalize sentence
    input_sentence = normalizeString(sentence)
    # Evaluate sentence
-     output_words = evaluate(encoder, decoder, searcher, voc, input_sentence)
+     output_words = evaluate(searcher, voc, input_sentence)
    output_words[:] = [x for x in output_words if not (x == 'EOS' or x == 'PAD')]
    print('Bot:', ' '.join(output_words))
@@ -700,14 +687,17 @@ def evaluateExample(sentence, encoder, decoder, searcher, voc):
# ``checkpoint = torch.load(loadFilename, map_location=torch.device('cpu'))``
# line.
#
- # Hybrid Frontend Notes:
+ # TorchScript Notes:
# ~~~~~~~~~~~~~~~~~~~~~~
#
# Notice that we initialize and load parameters into our encoder and
- # decoder models as usual. Also, we must call ``.to(device)`` to set the
- # device options of the models and ``.eval()`` to set the dropout layers
- # to test mode **before** we trace the models. ``TracedModule`` objects do
- # not inherit the ``to`` or ``eval`` methods.
+ # decoder models as usual. If you are using tracing mode (``torch.jit.trace``)
+ # for some part of your models, you must call ``.to(device)`` to set the device
+ # options of the models and ``.eval()`` to set the dropout layers to test mode
+ # **before** tracing the models. ``TracedModule`` objects do not inherit the
+ # ``to`` or ``eval`` methods. Since in this tutorial we are only using
+ # scripting instead of tracing, we only need to do this before we run
+ # evaluation (which is the same as we would normally do in eager mode).
#
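# As a minimal sketch of that ordering when tracing (``model`` and
# ``example`` here are hypothetical placeholders, not objects from this
# tutorial)::
#
#    model.to(device)   # set device options first
#    model.eval()       # put dropout layers in eval mode
#    traced = torch.jit.trace(model, example)  # only then trace
#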

save_dir = os.path.join("data", "save")
@@ -766,16 +756,14 @@ def evaluateExample(sentence, encoder, decoder, searcher, voc):


######################################################################
- # Convert Model to Torch Script
+ # Convert Model to TorchScript
# -----------------------------
#
# Encoder
# ~~~~~~~
#
- # As previously mentioned, to convert the encoder model to Torch Script,
- # we use **tracing**. Tracing any module requires running an example input
- # through the model’s ``forward`` method and trace the computational graph
- # that the data encounters. The encoder model takes an input sequence and
+ # As previously mentioned, to convert the encoder model to TorchScript,
+ # we use **scripting**. The encoder model takes an input sequence and
# a corresponding lengths tensor. Therefore, we create an example input
# sequence tensor ``test_seq``, which is of appropriate size (MAX_LENGTH,
# 1), contains numbers in the appropriate range
@@ -803,13 +791,13 @@ def evaluateExample(sentence, encoder, decoder, searcher, voc):
# ~~~~~~~~~~~~~~~~~~~
#
# Recall that we scripted our searcher module due to the presence of
- # data-dependent control flow. In the case of scripting, we do the
- # conversion work up front by adding the decorator and making sure the
- # implementation complies with scripting rules. We initialize the scripted
- # searcher the same way that we would initialize an un-scripted variant.
+ # data-dependent control flow. In the case of scripting, we make the
+ # necessary language changes to ensure the implementation complies with
+ # TorchScript. We initialize the scripted searcher the same way that we
+ # would initialize an un-scripted variant.
#

- ### Convert encoder model
+ ### Compile the whole greedy search model to a TorchScript model
# Create artificial inputs
test_seq = torch.LongTensor(MAX_LENGTH, 1).random_(0, voc.num_words).to(device)
test_seq_length = torch.LongTensor([test_seq.size()[0]]).to(device)
@@ -824,19 +812,21 @@ def evaluateExample(sentence, encoder, decoder, searcher, voc):
# Trace the model
traced_decoder = torch.jit.trace(decoder, (test_decoder_input, test_decoder_hidden, test_encoder_outputs))

- ### Initialize searcher module
- scripted_searcher = GreedySearchDecoder(traced_encoder, traced_decoder, decoder.n_layers)
+ ### Initialize searcher module by wrapping it in a ``torch.jit.script`` call
+ scripted_searcher = torch.jit.script(GreedySearchDecoder(traced_encoder, traced_decoder, decoder.n_layers))
+
+


######################################################################
# Print Graphs
# ------------
#
- # Now that our models are in Torch Script form, we can print the graphs of
+ # Now that our models are in TorchScript form, we can print the graphs of
# each to ensure that we captured the computational graph appropriately.
- # Since our ``scripted_searcher`` contains our ``traced_encoder`` and
- # ``traced_decoder``, these graphs will print inline.
- #
+ # Since TorchScript allows us to recursively compile the whole model
+ # hierarchy and inline the ``encoder`` and ``decoder`` graphs into a single
+ # graph, we just need to print the ``scripted_searcher`` graph.

print('scripted_searcher graph:\n', scripted_searcher.graph)
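
# If you want a more readable view, a scripted module also exposes a
# ``code`` attribute, which renders the compiled ``forward`` as
# Python-like source rather than the raw graph:
print('scripted_searcher code:\n', scripted_searcher.code)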
842
832
@@ -845,19 +835,25 @@ def evaluateExample(sentence, encoder, decoder, searcher, voc):
845
835
# Run Evaluation
846
836
# --------------
847
837
#
848
- # Finally, we will run evaluation of the chatbot model using the Torch
849
- # Script models. If converted correctly, the models will behave exactly as
850
- # they would in their eager-mode representation.
838
+ # Finally, we will run evaluation of the chatbot model using the TorchScript
839
+ # models. If converted correctly, the models will behave exactly as they
840
+ # would in their eager-mode representation.
851
841
#
852
842
# By default, we evaluate a few common query sentences. If you want to
853
843
# chat with the bot yourself, uncomment the ``evaluateInput`` line and
854
844
# give it a spin.
855
845
#
856
846
847
+
848
+ # Use appropriate device
849
+ scripted_searcher .to (device )
850
+ # Set dropout layers to eval mode
851
+ scripted_searcher .eval ()
852
+
857
853
# Evaluate examples
858
854
sentences = ["hello" , "what's up?" , "who are you?" , "where am I?" , "where are you from?" ]
859
855
for s in sentences :
860
- evaluateExample (s , traced_encoder , traced_decoder , scripted_searcher , voc )
856
+ evaluateExample (s , scripted_searcher , voc )
861
857
862
858
# Evaluate your input
863
859
#evaluateInput(traced_encoder, traced_decoder, scripted_searcher, voc)
@@ -867,7 +863,7 @@ def evaluateExample(sentence, encoder, decoder, searcher, voc):
867
863
# Save Model
868
864
# ----------
869
865
#
870
- # Now that we have successfully converted our model to Torch Script , we
866
+ # Now that we have successfully converted our model to TorchScript , we
871
867
# will serialize it for use in a non-Python deployment environment. To do
872
868
# this, we can simply save our ``scripted_searcher`` module, as this is
873
869
# the user-facing interface for running inference against the chatbot
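#
# For example, saving and re-loading the scripted module might look like
# this (the file name is an illustrative placeholder)::
#
#    scripted_searcher.save("scripted_chatbot.pth")
#    loaded_searcher = torch.jit.load("scripted_chatbot.pth")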