This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

[Torchscript] Enable inference optimizations on the scripted model #4499

Merged (1 commit) on May 19, 2022
19 changes: 16 additions & 3 deletions parlai/scripts/torchscript.py
@@ -54,17 +54,23 @@ def export_model(opt: Opt):
     instantiated = script_class(agent)
     if not opt["no_cuda"]:
         instantiated = instantiated.cuda()
-    scripted_module = torch.jit.script(instantiated)
+    if opt.get("enable_inference_optimizations"):
+        scripted_model = torch.jit.optimize_for_inference(
+            torch.jit.script(instantiated.eval())
+        )
+    else:
+        scripted_model = torch.jit.script(instantiated)
+
     with PathManager.open(opt["scripted_model_file"], "wb") as f:
-        torch.jit.save(scripted_module, f)
+        torch.jit.save(scripted_model, f)
 
     # Compare the original module to the scripted module against the test inputs
     if len(opt["input"]) > 0:
         inputs = opt["input"].split("|")
         print("\nGenerating given the original unscripted module:")
         _run_conversation(module=original_module, inputs=inputs)
         print("\nGenerating given the scripted module:")
-        _run_conversation(module=scripted_module, inputs=inputs)
+        _run_conversation(module=scripted_model, inputs=inputs)
 
 
 def setup_args() -> ParlaiParser:
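
For background on the new branch above: torch.jit.optimize_for_inference freezes a scripted module and applies inference-only graph rewrites (conv-batchnorm folding and similar), and freezing requires the module to be in eval mode, which is why the diff calls instantiated.eval() before scripting. A minimal standalone sketch of the same pattern, with a hypothetical ToyEncoder standing in for the ParlAI agent wrapper:

import torch
import torch.nn as nn

# Hypothetical stand-in for the ParlAI agent wrapper that the script exports.
class ToyEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(16, 16)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.linear(x))

module = ToyEncoder()
# optimize_for_inference requires eval mode: it freezes the scripted module
# and then applies inference-only rewrites to the frozen graph.
optimized = torch.jit.optimize_for_inference(torch.jit.script(module.eval()))
print(optimized(torch.randn(2, 16)))
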
@@ -90,6 +96,13 @@ def setup_args() -> ParlaiParser:
         default="parlai.torchscript.modules:TorchScriptGreedySearch",
         help="module to TorchScript. Example: parlai.torchscript.modules:TorchScriptGreedySearch",
     )
+    parser.add_argument(
+        "-eio",
+        "--enable-inference-optimizations",
+        type=bool,
+        default=False,
+        help="Enable inference optimizations on the scripted model.",
+    )
     return parser
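
Assuming ParlAI's usual script entry point, exporting with the new flag would look something like the following; the model file and test inputs here are illustrative, and the flag names for the model and output paths are inferred from the opt keys in the diff:

python -m parlai.scripts.torchscript \
    --model-file zoo:blender/blender_90M/model \
    --scripted-model-file scripted_blender_90M.pt \
    --enable-inference-optimizations true \
    --input "hello|how are you?"

The --input string is split on "|" into separate test turns, which the script runs through both the original and the scripted module so the two outputs can be compared.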