This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Add gpu option to torchscript BART models #3979

Merged: 11 commits, Aug 26, 2021
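In short: instead of always forcing the model onto CPU before scripting, the export script can now keep a BART model on the GPU, so the saved TorchScript module runs greedy-search inference on CUDA. A minimal sketch of driving the updated script from Python; the paths are placeholders, and the standard ParlAI --model/--model-file flags are assumed:

from parlai.scripts.torchscript import export_model, setup_args

# Placeholder model and output paths; --script-for-gpu is the flag added in this PR.
opt = setup_args().parse_args(
    [
        "--model", "bart",
        "--model-file", "/path/to/bart/model",
        "--scripted-model-file", "/path/to/scripted_bart.pt",
        "--script-for-gpu", "True",
    ]
)
export_model(opt)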
21 changes: 16 additions & 5 deletions parlai/scripts/torchscript.py
@@ -20,8 +20,6 @@
def export_model(opt: Opt):
"""
Export a model to TorchScript so that inference can be run outside of ParlAI.

Currently, only CPU greedy-search inference on BART models is supported.
"""

if version.parse(torch.__version__) < version.parse("1.7.0"):
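As the docstring says, the point of the export is inference outside of ParlAI. A minimal sketch of consuming the artifact with nothing but torch, assuming the placeholder output path from above and the default TorchScriptGreedySearch module (whose forward takes a context string and a max length and returns the generated string, see modules.py below):

import torch

module = torch.jit.load("/path/to/scripted_bart.pt")
reply = module("hello world", 32)  # forward(context: str, max_len: int) -> str
print(reply)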
@@ -34,9 +32,12 @@ def export_model(opt: Opt):
from parlai.torchscript.modules import TorchScriptGreedySearch

overrides = {
"no_cuda": True, # TorchScripting is CPU only
"model_parallel": False, # model_parallel is not currently supported when TorchScripting
"model_parallel": False, # model_parallel is not currently supported when TorchScripting,
}
if opt.get("script_for_gpu", False):
Contributor: Hmm, instead of adding the --script-for-gpu flag, could we just reuse the value of the no_cuda key?

opt["no_cuda"] = False
else:
opt["no_cuda"] = True
if opt.get("script_module"):
script_module_name, script_class_name = opt["script_module"].split(":", 1)
script_module = importlib.import_module(script_module_name)
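For reference, the inline review suggestion above (reuse the existing no_cuda key rather than introduce --script-for-gpu) would look roughly like this; a sketch of the proposed alternative, not code from this PR:

overrides = {
    "model_parallel": False,  # model_parallel is not currently supported when TorchScripting
}
# Honor the --no-cuda flag that ParlAI agents already accept: script on CPU only
# when the user explicitly asked for it, otherwise leave the model where it loads.
if opt.get("no_cuda", False):
    overrides["no_cuda"] = True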
@@ -54,7 +55,10 @@ def export_model(opt: Opt):
original_module = script_class(agent)

# Script the module and save
scripted_module = torch.jit.script(script_class(agent))
instantiated = script_class(agent)
if not opt["no_cuda"]:
instantiated = instantiated.cuda()
scripted_module = torch.jit.script(instantiated)
with PathManager.open(opt["scripted_model_file"], "wb") as f:
torch.jit.save(scripted_module, f)
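The order of operations above matters: torch.jit.script does not move any tensors, it wraps the module's parameters wherever they currently live, and torch.jit.save then records that placement. A standalone check of this behavior, assuming a CUDA device is available:

import torch

m = torch.nn.Linear(4, 4).cuda()
sm = torch.jit.script(m)
print(next(sm.parameters()).device)  # cuda:0, scripting keeps parameters on the GPU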

@@ -90,6 +94,13 @@ def setup_args() -> ParlaiParser:
default="parlai.torchscript.modules:TorchScriptGreedySearch",
help="module to TorchScript. Example: parlai.torchscript.modules:TorchScriptGreedySearch",
)
parser.add_argument(
"-sfg",
"--script-for-gpu",
type=bool,
default=False,
help="whether to TorchScript the model for GPU inference",
)
return parser
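One caveat on the flag definition: with plain argparse semantics, type=bool turns any non-empty string (including "False") into True. ParlAI's parser also accepts the string form type='bool', which parses true/false text properly; if that behavior is wanted, the argument could be declared along these lines (a sketch, not the PR's code):

parser.add_argument(
    "-sfg",
    "--script-for-gpu",
    type="bool",  # assumed: ParlAI's registered 'bool' type, which parses "true"/"false"
    default=False,
    help="whether to TorchScript the model for GPU inference",
)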


16 changes: 13 additions & 3 deletions parlai/torchscript/modules.py
@@ -40,7 +40,7 @@ def __init__(self, agent: TorchAgent):
super().__init__()

self.is_bart = agent.opt["model"] == "bart"

self.device = agent.model.encoder.embeddings.weight.device
# Dictionary/tokenization setup
for key, val in self.CAIRAOKE_DICT_PARAMS.items():
assert (
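self.device here is a snapshot taken at construction time, so it reflects wherever the agent's model already lives when the wrapper is built; the get_device() helper added further down re-reads the device from the embedding weight at call time, which stays correct even if the module is moved (for example by the .cuda() call in torchscript.py) after __init__. A small illustration of the difference with a hypothetical toy module, assuming a CUDA device is available:

import torch
import torch.nn as nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.embeddings = nn.Embedding(8, 2)
        self.device = self.embeddings.weight.device  # snapshot at construction

    def get_device(self):
        return self.embeddings.weight.device  # live lookup

toy = Toy().cuda()
print(toy.device)        # cpu: the snapshot is stale after .cuda()
print(toy.get_device())  # cuda:0: the live lookup follows the parameters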
@@ -98,7 +98,10 @@ def __init__(self, agent: TorchAgent):
wrapped_model = ModelIncrStateFlattener(agent.model)

# Create sample inputs for tracing
sample_tokens = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.long)
sample_tokens = torch.tensor(
[[1, 2, 3, 4, 5]], dtype=torch.long, device=self.device
)
sample_tokens = sample_tokens.to(self.device)
encoder_states = agent.model.encoder(sample_tokens)
initial_generations = self._get_initial_decoder_input(sample_tokens)
latent, initial_incr_state = wrapped_decoder(
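These sample inputs are run through the encoder and decoder eagerly and are later handed to torch.jit.trace, which executes the module on them; once the model sits on the GPU they therefore have to be created on (or moved to) the same device, or the calls fail with a device mismatch. A generic illustration with a toy module, assuming CUDA is available:

import torch
import torch.nn as nn

model = nn.Linear(5, 5).cuda()
sample = torch.randn(1, 5, device="cuda")  # sample input must live on the model's device
traced = torch.jit.trace(model, (sample,))
print(traced(sample).shape)  # torch.Size([1, 5])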
@@ -137,6 +140,9 @@ def __init__(self, agent: TorchAgent):
wrapped_decoder, (generations, encoder_states, incr_state), strict=False
)

def get_device(self):
return self.encoder.embeddings.weight.device

def _get_initial_decoder_input(self, x: torch.Tensor) -> torch.Tensor:
"""
Workaround because we can't use TGM._get_initial_decoder_input() directly.
@@ -147,7 +153,9 @@ def _get_initial_decoder_input(self, x: torch.Tensor) -> torch.Tensor:
"""
bsz = x.size(0)
return (
torch.tensor(self.initial_decoder_input, dtype=torch.long)
torch.tensor(
self.initial_decoder_input, dtype=torch.long, device=self.device
)
.expand(bsz, len(self.initial_decoder_input))
.to(x.device)
)
@@ -213,6 +221,8 @@ def forward(self, context: str, max_len: int = 128) -> str:
)

# Pass through the encoder and decoder to generate tokens

flattened_text_vec = flattened_text_vec.to(self.get_device())
batch_text_vec = torch.unsqueeze(flattened_text_vec, dim=0) # Add batch dim
encoder_states = self.encoder(batch_text_vec)
generations = self._get_initial_decoder_input(batch_text_vec)
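With the input vector moved onto the module's device, the exported GPU artifact can be used end to end: torch.jit.load restores parameters to the device they were saved from (a map_location argument can override that), and the scripted forward handles tokenizing the string input itself. A closing sketch, reusing the placeholder path from earlier and assuming a CUDA device is present:

import torch

module = torch.jit.load("/path/to/scripted_bart.pt")  # parameters come back on the GPU
print(module("hello, how are you?", 32))  # greedy-search reply as a plain string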