From 22c5c011f1a8f32d0c96d5764b3e1ce71c2b9db5 Mon Sep 17 00:00:00 2001 From: Yuxin Wu Date: Thu, 25 Mar 2021 11:28:27 -0700 Subject: [PATCH] support loading & running scripted models after scripting Summary: workaround upstream pytorch issue Reviewed By: alexander-kirillov Differential Revision: D27328284 fbshipit-source-id: 42d6ccf3bd19b1b3fe591278ec56adb2938d9742 --- detectron2/export/torchscript_patch.py | 28 +++++++++++++++++++++++++- docs/tutorials/deployment.md | 4 +--- tools/deploy/export_model.py | 28 ++++++++++++++++++++++++-- 3 files changed, 54 insertions(+), 6 deletions(-) diff --git a/detectron2/export/torchscript_patch.py b/detectron2/export/torchscript_patch.py index 8c0d3eb935..618e7e0c4b 100644 --- a/detectron2/export/torchscript_patch.py +++ b/detectron2/export/torchscript_patch.py @@ -11,7 +11,7 @@ # need some explicit imports due to https://github.com/pytorch/pytorch/issues/38964 import detectron2 # noqa F401 -from detectron2.structures import Instances +from detectron2.structures import Boxes, Instances from detectron2.utils.env import _import_file _counter = 0 @@ -229,6 +229,32 @@ def __getitem__(self, item) -> "{cls_name}": return ret """ ) + + # support method `get_fields()` + lines.append( + """ + def get_fields(self) -> Dict[str, Tensor]: + ret = {} + """ + ) + for f in fields: + if f.type_ == Boxes: + stmt = "t.tensor" + elif f.type_ == torch.Tensor: + stmt = "t" + else: + stmt = f'assert False, "unsupported type {str(f.type_)}"' + lines.append( + f""" + t = self._{f.name} + if t is not None: + ret["{f.name}"] = {stmt} + """ + ) + lines.append( + """ + return ret""" + ) return cls_name, os.linesep.join(lines) diff --git a/docs/tutorials/deployment.md b/docs/tutorials/deployment.md index b9ccab0908..7f4ff7fb6e 100644 --- a/docs/tutorials/deployment.md +++ b/docs/tutorials/deployment.md @@ -26,7 +26,7 @@ We currently support the following combination and each has some limitations: +----------------------------+-------------+-------------+-----------------------------+ | **Runtime** | PyTorch | PyTorch | Caffe2, PyTorch | +----------------------------+-------------+-------------+-----------------------------+ -| C++/Python inference | ✅ | ❌ (WIP_) | ✅ | +| C++/Python inference | ✅ | ✅ | ✅ | +----------------------------+-------------+-------------+-----------------------------+ | Dynamic resolution | ✅ | ✅ | ✅ | +----------------------------+-------------+-------------+-----------------------------+ @@ -43,8 +43,6 @@ We currently support the following combination and each has some limitations: | PointRend R-CNN | ✅ | ❌ | ❌ | +----------------------------+-------------+-------------+-----------------------------+ -.. _WIP: https://github.com/pytorch/pytorch/issues/46944 - ``` We don't plan to work on additional support for other formats/runtime, but contributions are welcome. diff --git a/tools/deploy/export_model.py b/tools/deploy/export_model.py index fe2fe304d7..520e4b8dc1 100755 --- a/tools/deploy/export_model.py +++ b/tools/deploy/export_model.py @@ -2,9 +2,10 @@ # Copyright (c) Facebook, Inc. and its affiliates. import argparse import os +from typing import Dict, List, Tuple import onnx import torch -from torch import Tensor +from torch import Tensor, nn from detectron2.checkpoint import DetectionCheckpointer from detectron2.config import get_cfg @@ -70,7 +71,30 @@ def export_scripting(torch_model): "pred_keypoint_heatmaps": torch.Tensor, } assert args.format == "torchscript", "Scripting only supports torchscript format." - ts_model = scripting_with_instances(torch_model, fields) + + class ScriptableAdapterBase(nn.Module): + # Use this adapter to workaround https://github.com/pytorch/pytorch/issues/46944 + # by not retuning instances but dicts. Otherwise the exported model is not deployable + def __init__(self): + super().__init__() + self.model = torch_model + self.eval() + + if isinstance(torch_model, GeneralizedRCNN): + + class ScriptableAdapter(ScriptableAdapterBase): + def forward(self, inputs: Tuple[Dict[str, torch.Tensor]]) -> List[Dict[str, Tensor]]: + instances = self.model.inference(inputs, do_postprocess=False) + return [i.get_fields() for i in instances] + + else: + + class ScriptableAdapter(ScriptableAdapterBase): + def forward(self, inputs: Tuple[Dict[str, torch.Tensor]]) -> List[Dict[str, Tensor]]: + instances = self.model(inputs) + return [i.get_fields() for i in instances] + + ts_model = scripting_with_instances(ScriptableAdapter(), fields) with PathManager.open(os.path.join(args.output, "model.ts"), "wb") as f: torch.jit.save(ts_model, f) dump_torchscript_IR(ts_model, args.output)