Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Add gpu option to torchscript BART models #3979

Merged
merged 11 commits into from
Aug 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .circleci/config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,7 @@ commands:
name: check for bad links
working_directory: ~/ParlAI/
command: |
sudo apt-get update
sudo apt-get install linkchecker
pip install linkchecker
python -m http.server --directory website/build >/dev/null &
linkchecker http://localhost:8000/
kill %1
Expand Down
10 changes: 5 additions & 5 deletions parlai/scripts/torchscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
def export_model(opt: Opt):
"""
Export a model to TorchScript so that inference can be run outside of ParlAI.

Currently, only CPU greedy-search inference on BART models is supported.
"""

if version.parse(torch.__version__) < version.parse("1.7.0"):
Expand All @@ -34,8 +32,7 @@ def export_model(opt: Opt):
from parlai.torchscript.modules import TorchScriptGreedySearch

overrides = {
"no_cuda": True, # TorchScripting is CPU only
"model_parallel": False, # model_parallel is not currently supported when TorchScripting
"model_parallel": False # model_parallel is not currently supported when TorchScripting,
}
if opt.get("script_module"):
script_module_name, script_class_name = opt["script_module"].split(":", 1)
Expand All @@ -54,7 +51,10 @@ def export_model(opt: Opt):
original_module = script_class(agent)

# Script the module and save
scripted_module = torch.jit.script(script_class(agent))
instantiated = script_class(agent)
if not opt["no_cuda"]:
instantiated = instantiated.cuda()
scripted_module = torch.jit.script(instantiated)
with PathManager.open(opt["scripted_model_file"], "wb") as f:
torch.jit.save(scripted_module, f)

Expand Down
17 changes: 14 additions & 3 deletions parlai/torchscript/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(self, agent: TorchAgent):
super().__init__()

self.is_bart = agent.opt["model"] == "bart"

self.device = agent.model.encoder.embeddings.weight.device
# Dictionary/tokenization setup
for key, val in self.CAIRAOKE_DICT_PARAMS.items():
assert (
Expand Down Expand Up @@ -98,7 +98,10 @@ def __init__(self, agent: TorchAgent):
wrapped_model = ModelIncrStateFlattener(agent.model)

# Create sample inputs for tracing
sample_tokens = torch.tensor([[1, 2, 3, 4, 5]], dtype=torch.long)
sample_tokens = torch.tensor(
[[1, 2, 3, 4, 5]], dtype=torch.long, device=self.device
)
sample_tokens = sample_tokens.to(self.device)
encoder_states = agent.model.encoder(sample_tokens)
initial_generations = self._get_initial_decoder_input(sample_tokens)
latent, initial_incr_state = wrapped_decoder(
Expand Down Expand Up @@ -137,6 +140,9 @@ def __init__(self, agent: TorchAgent):
wrapped_decoder, (generations, encoder_states, incr_state), strict=False
)

def get_device(self):
return self.encoder.embeddings.weight.device

def _get_initial_decoder_input(self, x: torch.Tensor) -> torch.Tensor:
"""
Workaround because we can't use TGM._get_initial_decoder_input() directly.
Expand All @@ -147,7 +153,9 @@ def _get_initial_decoder_input(self, x: torch.Tensor) -> torch.Tensor:
"""
bsz = x.size(0)
return (
torch.tensor(self.initial_decoder_input, dtype=torch.long)
torch.tensor(
self.initial_decoder_input, dtype=torch.long, device=self.device
)
.expand(bsz, len(self.initial_decoder_input))
.to(x.device)
)
Expand Down Expand Up @@ -213,6 +221,8 @@ def forward(self, context: str, max_len: int = 128) -> str:
)

# Pass through the encoder and decoder to generate tokens

flattened_text_vec = flattened_text_vec.to(self.get_device())
batch_text_vec = torch.unsqueeze(flattened_text_vec, dim=0) # Add batch dim
encoder_states = self.encoder(batch_text_vec)
generations = self._get_initial_decoder_input(batch_text_vec)
Expand Down Expand Up @@ -255,6 +265,7 @@ def forward(self, context: str, max_len: int = 128) -> str:
def postprocess_output_generations(self, label: str) -> str:
"""
Post-process the model output.

Returns the model output by default, override to add custom logic
"""
return label
Expand Down
78 changes: 52 additions & 26 deletions tests/nightly/gpu/test_torchscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,29 +38,27 @@ def test_token_splitter(self):
tasks = ['taskmaster2', 'convai2']
compiled_pattern = regex.compile(Gpt2BpeHelper.PATTERN)

with testing_utils.tempdir() as tmpdir:
for task in tasks:
opt = TorchScript.setup_args().parse_kwargs(
task=task, datatype='train:ordered'
)
agent = RepeatLabelAgent(opt)
# TODO(roller): make a proper create_teacher helper
teacher = create_task(opt, agent).get_task_agent()
num_examples = teacher.num_examples()

print(
f'\nStarting to test {num_examples:d} examples for the '
f'{task} task.'
)
for idx, message in enumerate(teacher):
if idx % 10000 == 0:
print(f'Testing example #{idx:d}.')
text = message['text']
canonical_tokens = regex.findall(compiled_pattern, text)
scriptable_tokens = ScriptableGpt2BpeHelper.findall(text)
self.assertEqual(canonical_tokens, scriptable_tokens)
if idx + 1 == num_examples:
break
for task in tasks:
opt = TorchScript.setup_args().parse_kwargs(
task=task, datatype='train:ordered'
)
agent = RepeatLabelAgent(opt)
# TODO(roller): make a proper create_teacher helper
teacher = create_task(opt, agent).get_task_agent()
num_examples = teacher.num_examples()

print(
f'\nStarting to test {num_examples:d} examples for the ' f'{task} task.'
)
for idx, message in enumerate(teacher):
if idx % 10000 == 0:
print(f'Testing example #{idx:d}.')
text = message['text']
canonical_tokens = regex.findall(compiled_pattern, text)
scriptable_tokens = ScriptableGpt2BpeHelper.findall(text)
self.assertEqual(canonical_tokens, scriptable_tokens)
if idx + 1 == num_examples:
break

def test_special_tokenization(self):
from parlai.core.dict import DictionaryAgent
Expand Down Expand Up @@ -102,7 +100,6 @@ def test_special_tokenization(self):
tokenized = sda.txt2vec(text)
assert len(tokenized) == 15
assert sda.vec2txt(tokenized) == text
nice_tok = [sda.ind2tok[i] for i in tokenized]

orig_dict = DictionaryAgent(opt)
orig_dict.add_additional_special_tokens(SPECIAL)
Expand All @@ -126,7 +123,6 @@ def test_special_tokenization(self):
assert len(special_tokenized) == 15
assert sda.vec2txt(special_tokenized) == text
assert special_tokenized != tokenized
nice_specialtok = [sda.ind2tok[i] for i in special_tokenized]

def test_torchscript_agent(self):
"""
Expand All @@ -143,7 +139,7 @@ def test_torchscript_agent(self):

# Export the BART model
export_opt = TorchScript.setup_args().parse_kwargs(
model='bart', scripted_model_file=scripted_model_file
model='bart', scripted_model_file=scripted_model_file, no_cuda=True
)
TorchScript(export_opt).run()

Expand All @@ -157,6 +153,36 @@ def test_torchscript_agent(self):
act = bart.act()
self.assertEqual(act['text'], test_phrase)

def test_gpu_torchscript_agent(self):
"""
Test exporting a model to TorchScript for GPU and then testing it on sample
data.
"""

from parlai.scripts.torchscript import TorchScript

test_phrase = "Don't have a cow, man!" # From test_bart.py

with testing_utils.tempdir() as tmpdir:

scripted_model_file = os.path.join(tmpdir, 'scripted_model.pt')

# Export the BART model for GPU
export_opt = TorchScript.setup_args().parse_kwargs(
model='bart', scripted_model_file=scripted_model_file, no_cuda=False
)
TorchScript(export_opt).run()

# Test the scripted GPU BART model
scripted_opt = ParlaiParser(True, True).parse_kwargs(
model='parlai.torchscript.agents:TorchScriptAgent',
model_file=scripted_model_file,
)
bart = create_agent(scripted_opt)
bart.observe({'text': test_phrase, 'episode_done': True})
act = bart.act()
self.assertEqual(act['text'], test_phrase)


if __name__ == '__main__':
unittest.main()