Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
Rebase (#4499)
Browse files Browse the repository at this point in the history
[Torchscript] Enable inference optimizations on the scripted model

Enable inference optimizations with a flag

remove extra line

Use _ in argument name

Use _ in argument name

Use _ in argument name

address black lint and comment

address comment
  • Loading branch information
pooyanamini authored and kushalarora committed Jun 15, 2022
1 parent b8883e8 commit 147a954
Showing 1 changed file with 16 additions and 3 deletions.
19 changes: 16 additions & 3 deletions parlai/scripts/torchscript.py
Original file line number Diff line number Diff line change
@@ -54,17 +54,23 @@ def export_model(opt: Opt):
instantiated = script_class(agent)
if not opt["no_cuda"]:
instantiated = instantiated.cuda()
scripted_module = torch.jit.script(instantiated)
if opt.get("enable_inference_optimizations"):
scripted_model = torch.jit.optimize_for_inference(
torch.jit.script(instantiated.eval())
)
else:
scripted_model = torch.jit.script(instantiated)

with PathManager.open(opt["scripted_model_file"], "wb") as f:
torch.jit.save(scripted_module, f)
torch.jit.save(scripted_model, f)

# Compare the original module to the scripted module against the test inputs
if len(opt["input"]) > 0:
inputs = opt["input"].split("|")
print("\nGenerating given the original unscripted module:")
_run_conversation(module=original_module, inputs=inputs)
print("\nGenerating given the scripted module:")
_run_conversation(module=scripted_module, inputs=inputs)
_run_conversation(module=scripted_model, inputs=inputs)


def setup_args() -> ParlaiParser:
@@ -90,6 +96,13 @@ def setup_args() -> ParlaiParser:
default="parlai.torchscript.modules:TorchScriptGreedySearch",
help="module to TorchScript. Example: parlai.torchscript.modules:TorchScriptGreedySearch",
)
parser.add_argument(
"-eio",
"--enable-inference-optimization",
type=bool,
default=False,
help="Enable inference optimizations on the scripted model.",
)
return parser


Expand Down

0 comments on commit 147a954

Please sign in to comment.