support loading & running scripted models after scripting
Summary: work around an upstream PyTorch issue (pytorch/pytorch#46944)

Reviewed By: alexander-kirillov

Differential Revision: D27328284

fbshipit-source-id: 42d6ccf3bd19b1b3fe591278ec56adb2938d9742
ppwwyyxx authored and facebook-github-bot committed on Mar 25, 2021
Parent: 34bd206 · Commit: 22c5c01
Showing 3 changed files with 54 additions and 6 deletions.
28 changes: 27 additions & 1 deletion detectron2/export/torchscript_patch.py
@@ -11,7 +11,7 @@
 
 # need some explicit imports due to https://github.com/pytorch/pytorch/issues/38964
 import detectron2  # noqa F401
-from detectron2.structures import Instances
+from detectron2.structures import Boxes, Instances
 from detectron2.utils.env import _import_file
 
 _counter = 0
@@ -229,6 +229,32 @@ def __getitem__(self, item) -> "{cls_name}":
         return ret
         """
     )
 
+    # support method `get_fields()`
+    lines.append(
+        """
+    def get_fields(self) -> Dict[str, Tensor]:
+        ret = {}
+        """
+    )
+    for f in fields:
+        if f.type_ == Boxes:
+            stmt = "t.tensor"
+        elif f.type_ == torch.Tensor:
+            stmt = "t"
+        else:
+            stmt = f'assert False, "unsupported type {str(f.type_)}"'
+        lines.append(
+            f"""
+        t = self._{f.name}
+        if t is not None:
+            ret["{f.name}"] = {stmt}
+        """
+        )
+    lines.append(
+        """
+        return ret"""
+    )
     return cls_name, os.linesep.join(lines)
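
For intuition, here is a hand-written approximation of the get_fields() method that the generator above emits for an Instances class scripted with the fields {"pred_boxes": Boxes, "scores": torch.Tensor} (illustrative field names, not taken from this diff). A minimal, runnable sketch:

# Sketch only (assumed field names): roughly what the generated get_fields()
# boils down to for a class scripted with {"pred_boxes": Boxes, "scores": Tensor}.
from typing import Dict, Optional

from torch import Tensor

from detectron2.structures import Boxes


class ScriptedInstancesSketch:
    def __init__(self, pred_boxes: Optional[Boxes] = None, scores: Optional[Tensor] = None):
        self._pred_boxes = pred_boxes
        self._scores = scores

    def get_fields(self) -> Dict[str, Tensor]:
        ret = {}
        t = self._pred_boxes
        if t is not None:
            ret["pred_boxes"] = t.tensor  # Boxes fields are unwrapped to their raw Nx4 tensor
        t = self._scores
        if t is not None:
            ret["scores"] = t
        return ret

Returning a plain Dict[str, Tensor] rather than the Instances object itself is what makes the exported module consumable outside detectron2 (see the adapter in export_model.py below).
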
4 changes: 1 addition & 3 deletions docs/tutorials/deployment.md
@@ -26,7 +26,7 @@ We currently support the following combination and each has some limitations:
 +----------------------------+-------------+-------------+-----------------------------+
 | **Runtime** | PyTorch | PyTorch | Caffe2, PyTorch |
 +----------------------------+-------------+-------------+-----------------------------+
-| C++/Python inference | ✅ | ❌ (WIP_) | ✅ |
+| C++/Python inference | ✅ | ✅ | ✅ |
 +----------------------------+-------------+-------------+-----------------------------+
 | Dynamic resolution | ✅ | ✅ | ✅ |
 +----------------------------+-------------+-------------+-----------------------------+
@@ -43,8 +43,6 @@ We currently support the following combination and each has some limitations:
 | PointRend R-CNN | ✅ | ❌ | ❌ |
 +----------------------------+-------------+-------------+-----------------------------+
 
-.. _WIP: https://github.com/pytorch/pytorch/issues/46944
-
 ```
 
 We don't plan to work on additional support for other formats/runtime, but contributions are welcome.
28 changes: 26 additions & 2 deletions tools/deploy/export_model.py
@@ -2,9 +2,10 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
 import argparse
 import os
+from typing import Dict, List, Tuple
 import onnx
 import torch
-from torch import Tensor
+from torch import Tensor, nn
 
 from detectron2.checkpoint import DetectionCheckpointer
 from detectron2.config import get_cfg
@@ -70,7 +71,30 @@ def export_scripting(torch_model):
         "pred_keypoint_heatmaps": torch.Tensor,
     }
     assert args.format == "torchscript", "Scripting only supports torchscript format."
-    ts_model = scripting_with_instances(torch_model, fields)
+
+    class ScriptableAdapterBase(nn.Module):
+        # Use this adapter to work around https://github.com/pytorch/pytorch/issues/46944
+        # by not returning instances but dicts. Otherwise the exported model is not deployable
+        def __init__(self):
+            super().__init__()
+            self.model = torch_model
+            self.eval()
+
+    if isinstance(torch_model, GeneralizedRCNN):
+
+        class ScriptableAdapter(ScriptableAdapterBase):
+            def forward(self, inputs: Tuple[Dict[str, torch.Tensor]]) -> List[Dict[str, Tensor]]:
+                instances = self.model.inference(inputs, do_postprocess=False)
+                return [i.get_fields() for i in instances]
+
+    else:
+
+        class ScriptableAdapter(ScriptableAdapterBase):
+            def forward(self, inputs: Tuple[Dict[str, torch.Tensor]]) -> List[Dict[str, Tensor]]:
+                instances = self.model(inputs)
+                return [i.get_fields() for i in instances]
+
+    ts_model = scripting_with_instances(ScriptableAdapter(), fields)
     with PathManager.open(os.path.join(args.output, "model.ts"), "wb") as f:
         torch.jit.save(ts_model, f)
     dump_torchscript_IR(ts_model, args.output)
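
With the adapter above, the saved model.ts can be loaded and run with plain PyTorch. A minimal sketch, assuming the scripting export path was used and that a single CHW image tensor is fed in (the path and tensor values are illustrative, not part of this commit):

import torch

# Sketch: load the TorchScript module written above and run one image through it.
# "output/model.ts" is an assumed path; real inputs should follow cfg.INPUT.FORMAT.
ts_model = torch.jit.load("output/model.ts")
ts_model.eval()

image = torch.rand(3, 480, 640) * 255.0  # CHW image tensor with illustrative values
with torch.no_grad():
    # forward() takes a tuple of per-image dicts and returns List[Dict[str, Tensor]]:
    # one dict of prediction fields (boxes, scores, classes, ...) per image.
    outputs = ts_model(({"image": image},))

print(sorted(outputs[0].keys()))
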
